def forward(self, input: torch.Tensor, logdet=None, reverse=False, ft=None):
    """Apply, or invert, the two-stage conditional affine transform.

    Encode direction (``reverse=False``):
      1. Feature-conditional affine on the whole tensor:
         ``z = (z + shiftFt) * scaleFt``, where ``scaleFt, shiftFt`` come from
         ``self.feature_extract(ft, self.fFeatures)``.
      2. Self-conditional affine coupling: ``z`` is split by ``self.split``
         into ``z1, z2``. ``z2 = (z2 + shift) * scale``, where ``scale, shift``
         come from ``self.feature_extract_aff(z1, ft, self.fAffine)``.
         ``z1`` passes through unchanged, and the halves are re-joined with
         ``thops.cat_feature``.

    Decode direction (``reverse=True``) undoes stage 2 and then stage 1, each
    as divide-by-scale followed by subtract-shift.

    Args:
        input: Tensor to transform. In the encode direction ``input.shape[1]``
            must equal ``self.in_channels`` (checked with ``assert``).
        logdet: Running log-determinant accumulator. ``self.get_logdet(...)``
            of each scale is added when encoding and subtracted when decoding.
            ``None`` (the default) is treated as 0.
        reverse: ``False`` to encode, ``True`` to decode.
        ft: Conditioning features, passed to both feature extractors.

    Returns:
        Tuple ``(output, logdet)``: the transformed tensor and the updated
        accumulator.
    """
    if logdet is None:
        # Previously the default crashed on ``None + tensor``; start the
        # accumulator at zero instead.
        logdet = 0.0

    if not reverse:
        z = input
        assert z.shape[1] == self.in_channels, (z.shape[1], self.in_channels)

        # Feature conditional: affine transform driven only by ``ft``.
        scaleFt, shiftFt = self.feature_extract(ft, self.fFeatures)
        z = z + shiftFt
        z = z * scaleFt
        logdet = logdet + self.get_logdet(scaleFt)

        # Self conditional: transform z2 using parameters predicted from z1
        # (and ft), so z1 is available to invert the step.
        z1, z2 = self.split(z)
        scale, shift = self.feature_extract_aff(z1, ft, self.fAffine)
        self.asserts(scale, shift, z1, z2)
        z2 = z2 + shift
        z2 = z2 * scale
        logdet = logdet + self.get_logdet(scale)
        output = thops.cat_feature(z1, z2)
    else:
        z = input

        # Self conditional, inverted first because it was applied last.
        z1, z2 = self.split(z)
        scale, shift = self.feature_extract_aff(z1, ft, self.fAffine)
        self.asserts(scale, shift, z1, z2)
        z2 = z2 / scale
        z2 = z2 - shift
        z = thops.cat_feature(z1, z2)
        logdet = logdet - self.get_logdet(scale)

        # Feature conditional, inverted last.
        scaleFt, shiftFt = self.feature_extract(ft, self.fFeatures)
        z = z / scaleFt
        z = z - shiftFt
        logdet = logdet - self.get_logdet(scaleFt)
        output = z

    return output, logdet
def forward(self, input, logdet=0., reverse=False, eps_std=None, eps=None, ft=None, y_onehot=None):
    """Split channels off under a learned conditional prior, or merge them back.

    Encode direction (``reverse=False``): ``input`` is divided by
    ``self.split_ratio`` into a kept part and a factored-out part. The prior
    ``(mean, logs)`` is predicted from the kept part and ``ft`` by
    ``self.split2d_prior``. The factored-out part is whitened to
    ``eps = (z2 - mean) / self.exp_eps(logs)``, and
    ``self.get_logdet(logs, mean, z2)`` is added to ``logdet``. The ``eps``
    argument is ignored in this direction.

    Decode direction (``reverse=True``): ``input`` is the kept part. The
    factored-out part is rebuilt as ``mean + self.exp_eps(logs) * eps``. If
    ``eps`` is ``None``, it is drawn with ``GaussianDiag.sample_eps(mean.shape,
    eps_std)`` on ``mean``'s device. The two parts are concatenated with
    ``thops.cat_feature``, and the same log-det term is subtracted.

    ``y_onehot`` is accepted for interface compatibility but not used here.

    Returns:
        Encode: ``(z1, logdet, eps)``.
        Decode: ``(z, logdet)``.
    """
    if reverse:
        kept = input
        prior_mean, prior_logs = self.split2d_prior(kept, ft)
        if eps is None:
            # No latent supplied: sample one and move it to the prior's device.
            eps = GaussianDiag.sample_eps(prior_mean.shape, eps_std).to(prior_mean.device)
        restored = prior_mean + self.exp_eps(prior_logs) * eps
        merged = thops.cat_feature(kept, restored)
        return merged, logdet - self.get_logdet(prior_logs, prior_mean, restored)

    kept, dropped = self.split_ratio(input)
    prior_mean, prior_logs = self.split2d_prior(kept, ft)
    whitened = (dropped - prior_mean) / self.exp_eps(prior_logs)
    return kept, logdet + self.get_logdet(prior_logs, prior_mean, dropped), whitened