def reverse_flow(self, estimator, flow_embed, multgate, i_flow, input, logdet):
    """Invert one flow step: coupling -> permutation -> actnorm (reverse order of the forward pass)."""
    assert input.size(1) % 2 == 0
    # Broadcast the conditioning embedding to the batch and normalize it.
    # NOTE: reverse=False is intentional even though this is the reverse
    # pass — the embedding is always normalized in the forward direction.
    flow_embed = flow_embed.expand(
        (input.shape[0],) + tuple(flow_embed.shape[1:]))
    flow_embed, _ = self.actnorm_embed(flow_embed, reverse=False)

    # 1. invert the coupling layer
    za, zb = split_feature(input, "split")
    if self.flow_coupling == "additive":
        zb = zb - self.conv_proj(estimator(flow_embed, multgate, za))
    elif self.flow_coupling == "affine":
        h = self.conv_proj(estimator(flow_embed, multgate, za))
        shift, scale = split_feature(h, "cross")
        scale = torch.sigmoid(scale + 2.)
        # inverse of zb = (zb + shift) * scale
        zb = zb / scale - shift
        logdet = logdet - torch.sum(torch.log(scale), dim=[1, 2, 3])
    z = torch.cat((za, zb), dim=1)

    # 2. invert the permutation
    z, logdet = self.flow_permutation(z, logdet, True)
    # 3. invert actnorm
    z, logdet = self.actnorm(z, logdet=logdet, reverse=True)
    return z, logdet
def normal_flow(self, estimator, flow_embed, multgate, i_flow, input, logdet):
    """One forward flow step: actnorm -> permutation -> conditional coupling."""
    assert input.size(1) % 2 == 0
    # 1. actnorm on the input
    z, logdet = self.actnorm(input, logdet=logdet, reverse=False)
    # 1.1 broadcast the conditioning embedding to the batch and normalize
    # it (its logdet does not enter the flow objective, hence discarded)
    flow_embed = flow_embed.expand(
        (z.shape[0],) + tuple(flow_embed.shape[1:]))
    flow_embed, _ = self.actnorm_embed(flow_embed)
    # 2. permutation
    z, logdet = self.flow_permutation(z, logdet, False)
    # 3. coupling
    za, zb = split_feature(z, "split")
    if self.flow_coupling == "additive":
        zb = zb + self.conv_proj(estimator(flow_embed, multgate, za))
    elif self.flow_coupling == "affine":
        h = self.conv_proj(estimator(flow_embed, multgate, za))
        shift, scale = split_feature(h, "cross")
        scale = torch.sigmoid(scale + 2.)
        zb = (zb + shift) * scale
        logdet = logdet + torch.sum(torch.log(scale), dim=[1, 2, 3])
    return torch.cat((za, zb), dim=1), logdet
def normal_flow(self, input, y_onehot, logdet):
    """One forward flow step, with the coupling network optionally conditioned on y_onehot."""
    assert input.size(1) % 2 == 0
    # 1. actnorm
    z, logdet = self.actnorm(input, logdet=logdet, reverse=False)
    # 2. permutation
    z, logdet = self.flow_permutation(z, logdet, False)
    # 3. coupling — select the block input once, shared by both variants
    za, zb = split_feature(z, "split")
    block_in = (za, y_onehot) if self.extra_condition else za
    if self.flow_coupling == "additive":
        zb = zb + self.block(block_in)
    elif self.flow_coupling == "affine":
        h = self.block(block_in)
        shift, scale = split_feature(h, "cross")
        scale = torch.sigmoid(scale + 2.)
        zb = (zb + shift) * scale
        logdet = logdet + torch.sum(torch.log(scale), dim=[1, 2, 3])
    return torch.cat((za, zb), dim=1), logdet
def normal_flow(self, x, y_onehot):
    """Run the full forward flow and return (z, bits-per-dim, y_logits).

    y_logits is None when the model is not class-conditioned.
    """
    b, c, h, w = x.shape
    # Dequantize the input and initialize the log-det accumulator.
    x, logdet = uniform_binning_correction(x)
    z, objective = self.flow(x, logdet=logdet, reverse=False)  # z_size = C
    # If conditioned, the learned prior covers only half of the channels
    # (mean_size = C//2); otherwise it covers all C channels.
    mean, logs = self.prior(x, y_onehot)
    if self.y_condition:
        z_y, z_n = split_feature(z, "split")
        # Fix: allocate on z_y's device/dtype instead of a hard-coded
        # .cuda(), so CPU inference no longer crashes.
        self.mean_normal = torch.zeros_like(z_y)
        # NOTE(review): logs == 1 gives std == e, not a standard normal;
        # value kept as-is to preserve trained behavior — confirm intent.
        self.logs_normal = torch.ones_like(z_y)
        y_logits = self.project_class(z_y.mean(2).mean(2))
        objective += gaussian_likelihood(mean, logs, z_y) + gaussian_likelihood(
            self.mean_normal, self.logs_normal, z_n)
    else:
        objective += gaussian_likelihood(mean, logs, z)
        y_logits = None
    # Full objective — converted to bits per dimension.
    bpd = (-objective) / (math.log(2.) * c * h * w)
    return z, bpd, y_logits
def prior(self, data, y_onehot=None, batch_size=32):
    """Return (mean, logs) of the latent prior, optionally class-conditioned.

    The batch size is taken from `data` when it is given, otherwise from
    the `batch_size` argument (default 32).

    Raises AssertionError if a conditioning flag is set but y_onehot is None.
    """
    # Resolve the batch size once; used for both repeat() and view().
    n = data.shape[0] if data is not None else batch_size
    h = self.prior_h.repeat(n, 1, 1, 1)
    channels = h.size(1)
    if self.learn_top:
        h = self.learn_top_fn(h)
    # Fix: removed leftover debug print("no data") calls; the two
    # conditioning branches are kept separate on purpose — if both flags
    # are set, the projection is added twice, as in the original.
    if self.y_condition:
        assert y_onehot is not None
        yp = self.project_ycond(y_onehot)
        h = h + yp.view(n, channels, 1, 1)
    if self.yd_condition:
        assert y_onehot is not None
        yp = self.project_ycond(y_onehot)
        h = h + yp.view(n, channels, 1, 1)
    return split_feature(h, "split")
def normal_flow(self, input, logdet):
    """One forward flow step with selectable coupling parameterizations
    (additive / affine / gaffine / naffine)."""
    assert input.size(1) % 2 == 0
    # 1. actnorm (can be disabled via no_actnorm)
    if self.no_actnorm:
        z = input
    else:
        z, logdet = self.actnorm(input, logdet=logdet, reverse=False)
    # 2. permutation
    z, logdet = self.flow_permutation(z, logdet, False)
    # 3. coupling
    za, zb = split_feature(z, "split")
    coupling = self.flow_coupling
    if coupling == "additive":
        zb = zb + self.block(za)
    elif coupling in ("affine", "gaffine", "naffine"):
        shift, scale = split_feature(self.block(za), "cross")
        if coupling == "affine":
            # sigmoid-parameterized scale, floored by affine_eps
            scale = torch.sigmoid(scale + self.affine_scale_eps) + self.affine_eps
        elif coupling == "gaffine":
            # tanh-bounded scale in [1/max_scale, max_scale]
            scale = torch.exp(np.log(self.max_scale) * torch.tanh(scale))
            self.last_scale = scale
        else:  # naffine
            eps = self.affine_eps
            scale = (2 * torch.sigmoid(scale) - 1) * (1 - eps) + 1
        zb = (zb + shift) * scale
        logdet = logdet + torch.sum(torch.log(scale), dim=[1, 2, 3])
    # Unknown coupling names fall through with zb unchanged, as before.
    return torch.cat((za, zb), dim=1), logdet
def forward(self, input, logdet=0., reverse=False, temperature=None):
    """Split2d: factor half the channels out (forward) or sample them
    back from the learned prior (reverse)."""
    if not reverse:
        z1, z2 = split_feature(input, "split")
        mean, logs = self.split2d_prior(z1)
        # Score the factored-out half under the prior.
        return z1, logdet + gaussian_likelihood(mean, logs, z2)
    # reverse: sample the factored-out half and re-attach it
    z1 = input
    mean, logs = self.split2d_prior(z1)
    z2 = gaussian_sample(mean, logs, temperature)
    return torch.cat((z1, z2), dim=1), logdet
def reverse_flow(self, input, logdet):
    """Invert one flow step: coupling -> permutation -> actnorm."""
    assert input.size(1) % 2 == 0
    # 1. invert the coupling layer
    za, zb = split_feature(input, "split")
    if self.flow_coupling == "additive":
        zb = zb - self.block(za)
    elif self.flow_coupling == "affine":
        shift, scale = split_feature(self.block(za), "cross")
        scale = torch.sigmoid(scale + 2.)
        # inverse of zb = (zb + shift) * scale
        zb = zb / scale - shift
        logdet = logdet - torch.sum(torch.log(scale), dim=[1, 2, 3])
    z = torch.cat((za, zb), dim=1)
    # 2. invert the permutation
    z, logdet = self.flow_permutation(z, logdet, True)
    # 3. invert actnorm
    z, logdet = self.actnorm(z, logdet=logdet, reverse=True)
    return z, logdet
def split2d_prior(self, z, condition):
    """Compute (mean, logs) of the split prior, optionally fusing a condition.

    Fix: removed the debug print('split prior') that fired on every call.
    The `is False` identity check is kept deliberately: any other value
    of sp_condition (including None/0) takes the conditional path, as in
    the original.
    """
    if self.sp_condition is False:
        # Unconditional prior.
        h = self.conv(z)
    else:
        # Project the condition, reshape it to match z, and fuse the two
        # by channel concatenation before the prior conv.
        cond = self.cond_fc(condition)
        cond = cond.view(z.size())
        cond = F.relu(self.cond_conv(cond), inplace=False)
        h = self.conv(torch.cat([z, cond], dim=1))
    return split_feature(h, "cross")
def normal_flow(self, estimator, flow_embed, i_flow, input, logdet):
    """One forward flow step whose coupling network is an external
    estimator conditioned on flow_embed."""
    assert input.size(1) % 2 == 0
    # 1. actnorm
    z, logdet = self.actnorm(input, logdet=logdet, reverse=False)
    # 2. permutation
    z, logdet = self.flow_permutation(z, logdet, False)
    # 3. conditional coupling
    za, zb = split_feature(z, "split")
    if self.flow_coupling == "additive":
        zb = zb + estimator(flow_embed, za)
    elif self.flow_coupling == "affine":
        shift, scale = split_feature(estimator(flow_embed, za), "cross")
        scale = torch.sigmoid(scale + 2.)
        zb = (zb + shift) * scale
        logdet = logdet + torch.sum(torch.log(scale), dim=[1, 2, 3])
    return torch.cat((za, zb), dim=1), logdet
def forward(self, input, zn, logdet=0., reverse=False, temperature=None):
    """Split2d variant that records/restores the factored half via the
    `zn` stack instead of sampling it from the prior on the reverse pass."""
    if reverse:
        z1 = input
        # Restore the exact z2 recorded on the matching forward pass.
        z2 = zn.pop()
        # Prior is still evaluated for parity with the forward branch;
        # sampling from it is intentionally disabled here.
        mean, logs = self.split2d_prior(z1)
        return torch.cat((z1, z2), dim=1), logdet, zn
    z1, z2 = split_feature(input, "split")
    zn.append(z2.clone())
    mean, logs = self.split2d_prior(z1)
    logdet = logdet + gaussian_likelihood(mean, logs, z2)
    return z1, logdet, zn
def prior(self, data, y_onehot=None):
    """Return (mean, logs) of the latent prior, optionally class-conditioned.

    Raises AssertionError if y_condition is set but y_onehot is None.
    """
    if data is not None:
        h = self.prior_h.repeat(data.shape[0], 1, 1, 1)
    else:
        # Hardcoded fallback batch size of 32 (kept from the original).
        h = self.prior_h.repeat(32, 1, 1, 1)
    channels = h.size(1)
    if self.learn_top:
        h = self.learn_top_fn(h)
    if self.y_condition:
        assert y_onehot is not None
        y_onehot = y_onehot.float()
        yp = self.project_ycond(y_onehot)
        # Fix: size the view from h itself so the data=None path no
        # longer crashes on data.shape[0]; identical when data is given.
        h = h + yp.view(h.shape[0], channels, 1, 1)
    return split_feature(h, "split")
def forward(self, input, logdet=0., reverse=False, temperature=None):
    """Split2d with a one-shot cache: when `self.use_last` is set, the
    next reverse pass reuses the cached z2 instead of sampling."""
    if reverse:
        z1 = input
        mean, logs = self.split2d_prior(z1)
        if self.use_last:
            # Reuse the cached half (made differentiable), exactly once.
            self._last_z2.requires_grad_()
            z2 = self._last_z2
            self.use_last = False
        else:
            z2 = gaussian_sample(mean, logs, temperature)
            self._last_z2 = z2  # cached without clone, as in original
        return torch.cat((z1, z2), dim=1), logdet
    # forward: factor z2 out, cache it, and score it under the prior
    z1, z2 = split_feature(input, "split")
    self._last_z2 = z2.clone()
    mean, logs = self.split2d_prior(z1)
    ll = gaussian_likelihood(mean, logs, z2)
    self._last_logdet = ll
    return z1, ll + logdet
def split2d_prior(self, z):
    """Project z through the prior conv and split into (mean, logs)."""
    return split_feature(self.conv(z), "cross")