def forward(self, x, *, u=None, dequant_logd=None):
    """Run ``x`` (plus dequantization noise) through the main flow and score it.

    ``u`` and ``dequant_logd`` must be passed together or not at all. When
    both are omitted they are obtained from ``self.calc_dequant_noise(x)``.
    ``x + u`` is then mapped by ``self.main_flow`` to ``z``, and ``z`` is
    scored with ``sumflat(standard_normal_logp(z))``.

    Args:
        x: input batch; the batch size is taken as ``x.shape[0]``.
        u: optional noise tensor; must have ``x``'s shape with every entry
            in [0, 1].
        dequant_logd: optional per-example term of shape ``(batch,)`` that
            accompanies ``u``.

    Returns:
        dict with keys ``'u'``, ``'z'``, ``'total_logd'``, ``'dequant_logd'``.
        ``total_logd`` is ``dequant_logd + main_logd + z_logp`` per example.
        NOTE(review): presumably the variational-dequantization lower bound
        on log p(x); confirm against the training objective.
    """
    noise_given = u is not None
    # u and dequant_logd form a pair: either both supplied or both sampled.
    assert noise_given == (dequant_logd is not None)
    if not noise_given:
        u, dequant_logd = self.calc_dequant_noise(x)

    batch = x.shape[0]
    assert u.shape == x.shape and dequant_logd.shape == (batch,)
    assert (u >= 0).all() and (u <= 1).all()

    z, main_logd = self.main_flow(x + u, aux=None, inverse=False)
    z_logp = sumflat(standard_normal_logp(z))
    total_logd = dequant_logd + main_logd + z_logp

    # The main flow must keep the batch size and the total element count.
    assert z.shape[0] == batch and z.numel() == x.numel()
    assert (
        main_logd.shape
        == dequant_logd.shape
        == total_logd.shape
        == z_logp.shape
        == (batch,)
    )
    return {
        'u': u,
        'z': z,
        'total_logd': total_logd,
        'dequant_logd': dequant_logd,
    }
def calc_dequant_noise(self, x):
    """Draw dequantization noise ``u`` for ``x`` through ``self.dequant_flow``.

    Standard-normal noise shaped like ``x`` (``torch.randn_like``) is pushed
    through ``self.dequant_flow`` with ``x`` as the conditioning ``aux``
    input. The second return value is the flow's second output minus
    ``sumflat(standard_normal_logp(eps))``. NOTE(review): this is
    presumably ``-log q(u | x)`` per example; confirm against the flow
    definitions.

    Returns:
        ``(u, dequant_logd)`` where ``u.shape == x.shape`` and
        ``dequant_logd.shape == (x.shape[0],)``.
    """
    noise = torch.randn_like(x)
    u, flow_logd = self.dequant_flow(eps=noise, aux=x, inverse=False)
    assert u.shape == x.shape and flow_logd.shape == (x.shape[0],)
    # Subtract the base-noise log-density term from the flow's output.
    noise_logp = sumflat(standard_normal_logp(noise))
    return u, flow_logd - noise_logp