The GFlowNets and Amortized Marginalization Tutorial
https://milayb.notion.site/The-GFlowNets-and-Amortized-Marginalization-Tutorial-01755ca312834e15ab0ae9ef46bcb1bb
Prerequisites
- MC simulation: the method I learned from the MIT OpenCourseWare class to estimate pi by dropping balls into a square with a circle inscribed inside (see the sketch after this list).
- The ball dropping is the process of sampling from the underlying distribution (i.e., whether a ball lands inside the circle is a Bernoulli random variable whose success probability is the ratio of the circle's area to the square's).
- Randomness ensures sample selection is unbiased.
- The law of large numbers is also needed (the sample average converges to the true probability as the number of samples grows).
- Monte-Carlo Approximation is a specific application of Monte-Carlo methods/simulation used to approximate complex mathematical integrals or expectations that are difficult or impossible to compute analytically.
- Markov Chain Monte Carlo (MCMC): MCMC is a class of algorithms for sampling from a probability distribution when direct sampling is difficult.
- An example is again the estimation of pi, but instead of dropping balls independently, we perform a random walk. Each new point depends only on the previous point, which makes the sequence a Markov chain (see the second sketch after this list).
- A sampling policy (e.g., $\pi(y|x)$) is a prob dist that is actively used to generate samples
- Relationship between Universal Approximation Theorem and the role of data and training: (AI gen) while the UAT tells us about the potential of neural networks to represent complex functions, the amount and quality of data determine our ability to actually find and learn those representations in practice (the learning problem).
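A minimal sketch (mine, not from the tutorial) of the ball-dropping estimate of pi from the first bullet above: drop points uniformly in a square and count the fraction landing inside the inscribed circle.

```python
import random

def estimate_pi_mc(n_samples: int = 100_000) -> float:
    """Plain Monte-Carlo estimate of pi.

    Drop points uniformly in the square [-1, 1] x [-1, 1]; the fraction
    landing inside the unit circle estimates (circle area) / (square area)
    = pi / 4.  Each drop is an i.i.d. Bernoulli trial, and the law of
    large numbers makes the average converge to that probability.
    """
    inside = 0
    for _ in range(n_samples):
        x, y = random.uniform(-1, 1), random.uniform(-1, 1)
        if x * x + y * y <= 1.0:
            inside += 1
    return 4.0 * inside / n_samples

print(estimate_pi_mc())  # ~3.14
```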
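And a sketch of the random-walk (MCMC) variant mentioned above, assuming a simple Metropolis chain whose stationary distribution is uniform over the square (the step size and chain length are illustrative choices): each proposal depends only on the current point, and the same inside-the-circle fraction again estimates pi/4 once the chain has mixed.

```python
import random

def estimate_pi_mcmc(n_steps: int = 200_000, step: float = 0.5) -> float:
    """Metropolis random walk targeting the uniform distribution on
    [-1, 1] x [-1, 1]: proposals that leave the square are rejected,
    which keeps the uniform distribution invariant.  The chain is
    Markov because each new point depends only on the previous one.
    """
    x, y = 0.0, 0.0
    inside = 0
    for _ in range(n_steps):
        # Propose a small symmetric move around the current point.
        x_new = x + random.uniform(-step, step)
        y_new = y + random.uniform(-step, step)
        # Uniform target: acceptance probability is 1 inside the square,
        # 0 outside (on rejection the chain stays put and is re-counted).
        if -1.0 <= x_new <= 1.0 and -1.0 <= y_new <= 1.0:
            x, y = x_new, y_new
        if x * x + y * y <= 1.0:
            inside += 1
    return 4.0 * inside / n_steps

print(estimate_pi_mcmc())  # ~3.14, but the samples are correlated
```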
Simple MSE Criterion to Amortize an Intractable Expectation
- For the MSE loss below, we use $y \sim p(y|x)$ (y is distributed as $p(y|x)$).
$$L = \left( \widehat{S}(x) - R(x,y) \right)^2$$
- How is R(x,y) computed when I only know y’s dist but not its value? AN: the distribution already tells me the sample space (i.e., the values y can take); once a particular y is sampled, R(x,y) can be evaluated at it.
- Why are gradients $\frac{\partial L}{\partial \theta}$ stochastic? AN: because $R(x,y)$ is stochastic (it depends on the random variable $y$), and therefore the gradients themselves are random variables. Note that in regular supervised learning training, what sits in R(x, y)’s place is a label, which is NOT a random variable.
- Why does $\widehat{S}$’s convergence require enough capacity and long enough training? GUESS: the stochastic nature of the gradients. AI GEN AN: the Universal Approximation Theorem, which states that a neural network with enough capacity can approximate any continuous function to arbitrary precision. My AN: the mention of stochastic gradients implies the assumption that we can sample from the true distribution indefinitely, and therefore the assumption of infinite data.
- This is what Bengio always talks about: “For any new x, we would then have an amortized estimator $\widehat{S}(x)$ which in one pass through the network would give us an approximation of the intractable sum. We can consider this an efficient alternative to doing a Monte-Carlo approximation”
- The MC approximation for $S(x) = \sum_y p(y|x) R(x,y)$ is $\widehat{S}_{MC}(x) = {\rm mean}_{y \sim p(y|x)} R(x,y)$ (see the sketch at the end of this section). The steps of the approximation are:
- Sample multiple y values from the distribution p(y|x)
- For each sampled y, we compute R(x,y)
- We take the average of all these R(x,y) values. Pay attention to the fact that we are averaging over R(x,y) values ONLY. So the concept of MC approx. of the intractable expectation is really quite simple.
- “Instead of computing the exact intractable expectation, we’re approximating it by taking the average over many random samples”
- The key to understanding the amortization is "observing a training set of $(x, R(x,y))$ pairs". With MC, the observation of $(x, R(x,y))$ pairs is done at runtime for a particular $x$. With DL, the observation is done during training. The stored knowledge is then used for inference on particular $x$’s.
- Another advantage of the amortized approach is the ability to exploit the generalizable structure in $p(y|x) R(x,y)$, implying it might not be necessary to train $\widehat{S}$ with an exponential number of examples before it captures that generalizable structure and provides good answers (i.e., $E_{Y|x}[R(x,Y)]$) on new x’s.
- Now if we can’t easily sample from $p(y|x)$, we can use MCMC. However, the difficulty is that when the modes of the prob dist occupy a small volume and are well-separated, it can take exponential time for the MCMC chain to mix and converge to the desired $p(y|x)$. And again, the amortized ML samplers can potentially do better than MCMC samplers because of the generalizable structure discussed in the above point.
- How does the generalizable structure help in the above two situations? It helps the first by reducing the number of samples needed to train the amortized model. It helps the second by recovering the underlying distribution that is otherwise difficult to sample from.
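A hedged, self-contained sketch of both ideas above (the toy $p(y|x)$, the reward $R(x,y)$, the network size, and the training loop are all illustrative assumptions, not from the tutorial): first the per-query Monte-Carlo approximation $\widehat{S}_{MC}(x)$, then an amortized estimator $\widehat{S}_\theta(x)$ trained with the single-sample MSE loss $L = (\widehat{S}_\theta(x) - R(x,y))^2$, whose minimizer is $E_{Y|x}[R(x,Y)]$.

```python
import torch
import torch.nn as nn

# --- Toy setup (illustrative assumptions, not from the tutorial) -----------
# x is a scalar context; p(y|x) is Gaussian with mean x; R(x, y) = y**2.
# Then the true S(x) = E[R(x, Y)] = x**2 + 1, so we can check both methods.
def sample_y(x: torch.Tensor) -> torch.Tensor:
    return x + torch.randn_like(x)            # y ~ N(x, 1)

def reward(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return y ** 2                             # R(x, y)

# --- 1) Monte-Carlo approximation, done at runtime for a particular x ------
def s_mc(x: torch.Tensor, n_samples: int = 10_000) -> torch.Tensor:
    ys = sample_y(x.expand(n_samples))        # many samples y ~ p(y|x)
    return reward(x, ys).mean()               # average of the R(x, y) values only

# --- 2) Amortized estimator: trained once, then reused for any new x -------
s_hat = nn.Sequential(nn.Linear(1, 64), nn.ReLU(), nn.Linear(64, 1))
opt = torch.optim.Adam(s_hat.parameters(), lr=1e-3)

for _ in range(5_000):
    x = torch.rand(256, 1) * 4 - 2            # training distribution over x
    y = sample_y(x)                           # one y per x => stochastic gradients
    loss = ((s_hat(x).squeeze(-1) - reward(x, y).squeeze(-1)) ** 2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()

x0 = torch.tensor([1.5])
print("true S(x):     ", (x0 ** 2 + 1).item())
print("MC estimate:   ", s_mc(x0).item())
print("amortized net: ", s_hat(x0.unsqueeze(0)).item())
```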
GFN Criterion to Obtain a Sampler and Estimate Intractable Sums
- the objective is to approximate a set of intractable sums when $p(y|x)$ isn’t available
- Approximating this sum gives us the normalization constant in the context of energy-based models, or the evidence (i.e., the prob dist of the observed data independently of any parameter value) in the context of Bayesian inference (i.e., Bayesian posteriors).
- $S(x)$ is the intractable sum and normalizing constant
- $\pi(y|x) = \frac{R(x,y)}{S(x)} \propto R(x,y)$ is the result of policy design. There is nothing inherent that makes the sampling policy equal $\frac{R(x,y)}{S(x)}$, even though we do need it to be a proper prob dist (i.e., normalized). The reason for such a design, i.e., why sampling from $\pi(y|x)$ is useful, is that it allows us to generate samples that reflect the relative magnitudes of $R(x,y)$.
- GFN Loss definition: Based on the above policy design, we can define estimators $\widehat{\pi}$ and $\widehat{S}$ (i.e., they are neural nets) and train them with a loss such as $L(x,y) = \left( \widehat{\pi}(y|x) \widehat{S}(x) - R(x,y)\right)^2$ where $(x,y)$ are sampled from a training distribution $\widetilde{p}(x,y)$ (see the sketch below). By the UAT, the two estimators converge to their desired values given enough capacity and training time.
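A hedged sketch of that loss for a tiny discrete case (the reward table, the parametrization of $\widehat{\pi}$ via a softmax, and that of $\widehat{S}$ via a log-space scalar are all illustrative assumptions): we train $\widehat{\pi}(y|x)$ and $\widehat{S}(x)$ jointly so that $\widehat{\pi}(y|x)\,\widehat{S}(x) \approx R(x,y)$; at the optimum, $\widehat{S}(x)$ estimates the intractable sum $\sum_y R(x,y)$ and $\widehat{\pi}$ samples proportionally to $R$.

```python
import torch
import torch.nn as nn

N_X, N_Y = 5, 8                                   # small discrete x and y spaces
R = torch.rand(N_X, N_Y) + 0.1                    # toy positive reward R(x, y)

# pi_hat(y|x): a softmax over y for each x; S_hat(x): a positive scalar per x.
logits = nn.Parameter(torch.zeros(N_X, N_Y))      # parameters of pi_hat
log_S = nn.Parameter(torch.zeros(N_X))            # parameters of S_hat (log-space)
opt = torch.optim.Adam([logits, log_S], lr=0.05)

for _ in range(3_000):
    x = torch.randint(0, N_X, (64,))              # x ~ training distribution
    y = torch.randint(0, N_Y, (64,))              # y ~ training distribution (full support)
    pi_hat = torch.softmax(logits[x], dim=-1).gather(1, y[:, None]).squeeze(1)
    S_hat = log_S[x].exp()
    loss = ((pi_hat * S_hat - R[x, y]) ** 2).mean()   # GFN-style matching loss
    opt.zero_grad()
    loss.backward()
    opt.step()

# At the optimum, S_hat(x) ~= sum_y R(x, y) and pi_hat(y|x) ~= R(x, y) / S(x).
print(log_S.exp().detach())
print(R.sum(dim=1))
```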
Marginalizing over Compositional Random Variables
- Marginalization isn’t discussed in this section. Is it?
- How is $\widehat{\pi}(y|x)$ related to a GFN trajectory?
- “The policy $\pi(y|x)$ is now specified by a forward transition distribution $P_F(s|s’)$”
- $y$ in the policy $\pi(y|x)$ represents the fully constructed object. The trajectory, consisting of a sequence of states $s$, describes the history of the sequential construction of $y$.
- “there may be many ways (in fact exponentially many trajectories) to construct y from some starting point and context $x$.”
- summary: starting point $x$ -> GFN trajectory -> fully constructed object $y$
- How is $\widehat{S}(x)$ related to a GFN trajectory?
- “$F(s)$ is called the flow at state $s$ and plays a role similar to $S(x)$ above, i.e., it is an intractable sum”
- Multiple parent states mean the GFN DAG of a fully constructed y isn’t a tree; in a tree, each node has exactly one parent.
- Global (i.e., entire flow graph) vs Local (i.e., a node/GFN state in the flow graph):
- Global flow: $\forall (x,y):\quad \pi(y|x) S(x) = R(x,y)$
- Local flow: $\forall (s,s'):\quad P_F(s|s') F(s') = P_B(s'|s) F(s)$. Either side can be thought of as the flow through the edge between $s'$ and $s$.
- The graphical equivalence of $F(s_0) = S(x) = \sum_y R(x,y)$ below shows that the flow of the initial state, $F(s_0) = S(x)$, is the sum of the flows of all terminal states $F(s)$ where $s = y$ (i.e., those with a T on the incoming path). This equation describes the ENTIRETY of the graph.
- $R(x,y)$ seems to be the flow of the terminal state $y$, but I am not entirely sure because the meaning of $R$ isn’t described in the tutorial.
![[Screenshot 2024-07-22 at 08.31.37.png]]
- Deriving the GFN loss from the intractable sum estimation:
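- A hedged sketch of how that derivation might go (my reconstruction, not quoted from the tutorial), by analogy with the global case: the global constraint $\pi(y|x) S(x) = R(x,y)$ gave the loss $\left( \widehat{\pi}(y|x) \widehat{S}(x) - R(x,y) \right)^2$; applying the same recipe to the local constraint $P_F(s|s') F(s') = P_B(s'|s) F(s)$, with log-space residuals for numerical stability, gives a per-edge loss
$$L(s', s) = \left( \log \widehat{P}_F(s|s') + \log \widehat{F}(s') - \log \widehat{P}_B(s'|s) - \log \widehat{F}(s) \right)^2$$
with the boundary conditions that tie it back to the intractable sum: $\widehat{F}(s) = R(x,y)$ when $s$ is the terminal state for object $y$, and $\widehat{F}(s_0) = \widehat{S}(x)$ at the root.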