def p_sample(self, xt: torch.Tensor, t: torch.Tensor, eps_theta: torch.Tensor):
    r"""
    #### Sample from $\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)$

    \begin{align}
    \textcolor{lightgreen}{p_\theta}(x_{t-1} | x_t) &= \mathcal{N}\big(x_{t-1};
    \textcolor{lightgreen}{\mu_\theta}(x_t, t), \sigma_t^2 \mathbf{I} \big) \\
    \textcolor{lightgreen}{\mu_\theta}(x_t, t)
      &= \frac{1}{\sqrt{\alpha_t}} \Big(x_t -
        \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\textcolor{lightgreen}{\epsilon_\theta}(x_t, t) \Big)
    \end{align}

    For batch entries whose timestep is $t = 0$ (the last reverse step), the
    mean $\textcolor{lightgreen}{\mu_\theta}(x_t, t)$ is returned without added
    noise, as in Algorithm 2 of Ho et al. (2020).

    * `xt` is the batch of noisy samples $x_t$.
    * `t` holds the timestep index for each batch entry; it is used both to
      `gather` the schedule coefficients and to mask the noise. It presumably
      has shape `(batch,)` — confirm against callers.
    * `eps_theta` is the model's noise prediction
      $\textcolor{lightgreen}{\epsilon_\theta}(x_t, t)$. It is combined
      element-wise with `xt`, so it must broadcast against it.

    Returns one sample of $x_{t-1}$.

    NOTE(review): assumes `gather` reshapes the per-timestep coefficients so
    they broadcast against `xt` (e.g. `(batch, 1, 1, 1)`) — confirm in utils.
    """
    # [gather](utils.html) $\bar\alpha_t$
    alpha_bar = gather(self.alpha_bar, t)
    # $\alpha_t$
    alpha = gather(self.alpha, t)
    # $\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}$, using $\beta_t = 1 - \alpha_t$
    eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5
    # $$\frac{1}{\sqrt{\alpha_t}} \Big(x_t -
    # \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\textcolor{lightgreen}{\epsilon_\theta}(x_t, t) \Big)$$
    mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)
    # $\sigma_t^2$
    var = gather(self.sigma2, t)
    # $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$; `randn_like` matches
    # both the device and the dtype of `xt`.
    eps = torch.randn_like(xt)
    # 1 where t > 0, 0 where t == 0. Reshaped to (batch, 1, ..., 1) so it
    # broadcasts over every non-batch dimension of `xt`. At the final step
    # no noise is added.
    nonzero = (t != 0).to(xt.dtype).reshape(-1, *([1] * (xt.dim() - 1)))
    # Sample
    return mean + nonzero * (var ** .5) * eps
def p_x0(self, xt: torch.Tensor, t: torch.Tensor, eps: torch.Tensor):
    r"""
    #### Estimate $x_0$

    $$x_0 \approx \hat{x}_0 = \frac{1}{\sqrt{\bar\alpha_t}}
      \Big( x_t - \sqrt{1 - \bar\alpha_t} \textcolor{lightgreen}{\epsilon_\theta}(x_t, t) \Big)$$

    This solves $x_t = \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\, \epsilon$
    for $x_0$, with the noise $\epsilon$ replaced by the estimate `eps`.

    * `xt` is the batch of noisy samples $x_t$.
    * `t` holds the timestep index used to `gather` $\bar\alpha_t$.
    * `eps` is the noise estimate
      $\textcolor{lightgreen}{\epsilon_\theta}(x_t, t)$. It is combined
      element-wise with `xt`, so it must broadcast against it.

    Returns the estimate $\hat{x}_0$.

    NOTE(review): assumes `gather` returns $\bar\alpha_t$ in a shape that
    broadcasts against `xt` — confirm in utils.
    """
    # [gather](utils.html) $\bar\alpha_t$
    alpha_bar = gather(self.alpha_bar, t)
    # $$x_0 \approx \hat{x}_0 = \frac{1}{\sqrt{\bar\alpha_t}}
    # \Big( x_t - \sqrt{1 - \bar\alpha_t} \textcolor{lightgreen}{\epsilon_\theta}(x_t, t) \Big)$$
    return (xt - (1 - alpha_bar) ** 0.5 * eps) / (alpha_bar ** 0.5)