def _kld(self, z, q_param, p_param=None):
    """
    Computes the KL-divergence of some element z.

    KL(q||p) = -∫ q(z) log [ p(z) / q(z) ]
             = -E[log p(z) - log q(z)]

    :param z: sample from the q-distribution
    :param q_param: (mu, log_var) of the q-distribution
    :param p_param: (mu, log_var) of the p-distribution
    :return: KL(q||p)
    """
    (q_mu, q_log_var) = q_param
    qz = log_gaussian(z, q_mu, q_log_var)
    if p_param is None:
        # No explicit prior parameters: use a standard Gaussian prior p(z) = N(0, I).
        pz = log_standard_gaussian(z)
    else:
        (p_mu, p_log_var) = p_param
        pz = log_gaussian(z, p_mu, p_log_var)

    kl = qz - pz

    return kl
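# The single-sample KL estimate above relies on two log-density helpers,
# `log_gaussian` and `log_standard_gaussian`, imported elsewhere in this repository.
# Below is a minimal sketch of what they are assumed to compute (diagonal-covariance
# Gaussian log-densities summed over the feature dimension); the repository's own
# definitions may differ in signature or reduction behaviour.
import math


def _log_gaussian_sketch(x, mu, log_var, dim=-1):
    # log N(x; mu, diag(exp(log_var))), summed over `dim`.
    log_pdf = -0.5 * (math.log(2 * math.pi) + log_var + (x - mu) ** 2 / log_var.exp())
    return log_pdf.sum(dim)


def _log_standard_gaussian_sketch(x, dim=-1):
    # log N(x; 0, I), summed over `dim`.
    return (-0.5 * (math.log(2 * math.pi) + x ** 2)).sum(dim)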
def _kld(self, z, q_param, i, h_last, p_param=None, sylvester_params=None, auxiliary=False):
    """
    Computes the KL-divergence of some element z, optionally transforming z with a
    normalizing flow before evaluating the densities.

    KL(q||p) = -∫ q(z) log [ p(z) / q(z) ]
             = -E[log p(z) - log q(z)]

    :param z: sample from the q-distribution
    :param q_param: parameters of the q-distribution; (mu, log_var), or
                    (mu, log_var, r1, r2, q_ortho, b) for Sylvester flows
    :param i: index of the current flow/layer
    :param h_last: last hidden activation (used by the amortized "hf"/"ccLinIAF" flows)
    :param p_param: (mu, log_var) of the p-distribution
    :param auxiliary: whether to use the auxiliary flow
    :return: KL(q||p)
    """
    if self.flow_type == "nf" and self.n_flows > 0:
        (mu, log_var) = q_param
        if not auxiliary:
            f_z, log_det_z = self.flow(z, i, False)
        else:
            f_z, log_det_z = self.flow_a(z, i, True)
        # Change of variables: log q(z_K) = log q(z_0) - sum_k log|det dz_k/dz_{k-1}|
        qz = log_gaussian(z, mu, log_var) - sum(log_det_z)
        z = f_z
    elif self.flow_type in ["hf", "ccLinIAF"] and self.n_flows > 0:
        (mu, log_var) = q_param
        if not auxiliary:
            f_z = self.flow(z, i, h_last, False)
        else:
            f_z = self.flow_a(z, i, h_last, True)
        qz = log_gaussian(z, mu, log_var)
        z = f_z
    elif self.flow_type in ["o-sylvester", "h-sylvester", "t-sylvester"] and self.n_flows > 0:
        mu, log_var, r1, r2, q_ortho, b = q_param
        if not auxiliary:
            f_z = self.flow(z, r1, r2, q_ortho, b, i, False)
        else:
            f_z = self.flow_a(z, r1, r2, q_ortho, b, i, True)
        qz = log_gaussian(z, mu, log_var)
        z = f_z
    else:
        (mu, log_var) = q_param
        qz = log_gaussian(z, mu, log_var)

    if p_param is None:
        pz = log_standard_gaussian(z)
    else:
        (mu, log_var) = p_param
        pz = log_gaussian(z, mu, log_var)

    kl = qz - pz

    return kl
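# A small, self-contained sanity check (not part of the model) of the
# change-of-variables identity used above whenever a flow transforms z:
#     log q_K(z_K) = log q_0(z_0) - sum_k log |det dz_k / dz_{k-1}|
# Here an elementwise affine map z_k = a * z_0 + b stands in for the flow, so the
# log-det term reduces to numel(z) * log|a|.
import torch
from torch.distributions import Normal


def _affine_flow_change_of_variables_check():
    torch.manual_seed(0)
    z0 = torch.randn(4)
    a, b = torch.tensor(2.0), torch.tensor(0.5)
    zk = a * z0 + b
    log_det = z0.numel() * torch.log(a.abs())

    log_q0 = Normal(0.0, 1.0).log_prob(z0).sum()
    log_qk = Normal(b, a.abs()).log_prob(zk).sum()
    # Both sides of the identity agree up to numerical precision.
    assert torch.allclose(log_q0 - log_det, log_qk, atol=1e-5)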
def binary_loss_array(recon_x, x, z_mu, z_var, z_0, z_k, ldj, beta=1.):
    """
    Computes the binary (Bernoulli) loss without averaging or summing over the batch dimension.
    """
    batch_size = x.size(0)

    # If the log-det-Jacobian has not been summed over the batch dimension yet, do it here.
    if len(ldj.size()) > 1:
        ldj = ldj.view(ldj.size(0), -1).sum(-1)

    # TODO: newer PyTorch versions expose nn.BCELoss with reduce=False / reduction='none',
    # which avoids summing over the batch dimension.
    bce = -log_bernoulli(x.view(batch_size, -1), recon_x.view(batch_size, -1), dim=1)
    # ln p(z_k) (not averaged)
    log_p_zk = log_standard_gaussian(z_k, dim=1)
    # ln q(z_0) (not averaged)
    log_q_z0 = log_gaussian(z_0, mean=z_mu, log_var=z_var.log(), dim=1)
    # ln q(z_0) - ln p(z_k)
    logs = log_q_z0 - log_p_zk

    loss = bce + beta * (logs - ldj)

    return loss
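# The `log_bernoulli` helper used above is imported from elsewhere in the repository.
# It is assumed to compute the Bernoulli log-likelihood
#     log p(x | theta) = sum_d [ x_d * log(theta_d) + (1 - x_d) * log(1 - theta_d) ]
# reduced over the given dimension; a minimal sketch under that assumption:
def _log_bernoulli_sketch(x, theta, dim=1, eps=1e-8):
    theta = theta.clamp(eps, 1 - eps)  # avoid log(0)
    return (x * theta.log() + (1 - x) * (1 - theta).log()).sum(dim)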
def calculate_losses(self, data, lambda1=0., lambda2=0., beta=1., likelihood=F.mse_loss):
    if self.ladder:
        ladder = "ladder"
    else:
        ladder = "not_ladder"
    self.images_path = self.results_path + "/images/examples/generative/" + ladder + "/" + self.flavour + "/"
    create_missing_folders(self.images_path)

    data = torch.tanh(data)
    if self.flow_type in ["o-sylvester", "t-sylvester", "h-sylvester"] and not self.ladder:
        z_q = {0: None, -1: None}
        reconstruction, mu, log_var, self.log_det_j, z_q[0], z_q[-1] = \
            self.run_sylvester(data, auxiliary=self.auxiliary)
        # ln p(z_k)
        log_p_zk = log_standard_gaussian(z_q[-1])
        # ln q(z_0) (not averaged)
        log_q_z0 = log_gaussian(z_q[0], mu, log_var=log_var) - self.log_det_j
        # N E_q0[ ln q(z_0) - ln p(z_k) ]
        self.kl_divergence = log_q_z0 - log_p_zk
        del log_q_z0, log_p_zk
    else:
        reconstruction, z_q = self(data)

    kl = beta * self.kl_divergence
    likelihood = torch.sum(likelihood(reconstruction, data.float(), reduce=False), dim=-1)

    # L1/L2 regularization on the parameters of the reconstruction layer.
    if self.ladder:
        params = torch.cat([x.view(-1) for x in self.reconstruction.parameters()])
    else:
        params = torch.cat([x.view(-1) for x in self.decoder.reconstruction.parameters()])
    l1_regularization = lambda1 * torch.norm(params, 1).cuda()
    l2_regularization = lambda2 * torch.norm(params, 2).cuda()
    try:
        assert l1_regularization >= 0. and l2_regularization >= 0.
    except AssertionError:
        print(l1_regularization, l2_regularization)

    loss = torch.mean(likelihood + kl.cuda() + l1_regularization + l2_regularization)

    del data, params, l1_regularization, l2_regularization, lambda1, lambda2
    return loss, torch.mean(likelihood), torch.mean(kl), reconstruction, z_q
def multinomial_loss_function(x_logit, x, z_mu, z_var, z_0, z_k, ldj, args, beta=1.):
    """
    Computes the cross-entropy loss function while summing over the batch dimension, not averaged!

    :param x_logit: shape (batch_size, num_classes * num_channels, pixel_width, pixel_height), real-valued logits
    :param x: shape (batch_size, num_channels, pixel_width, pixel_height), pixel values rescaled between [0, 1]
    :param z_mu: mean of z_0
    :param z_var: variance of z_0
    :param z_0: first stochastic latent variable
    :param z_k: last stochastic latent variable
    :param ldj: log det jacobian
    :param args: global parameter settings
    :param beta: beta for kl loss
    :return: loss, ce, kl
    """
    num_classes = 256
    batch_size = x.size(0)

    x_logit = x_logit.view(batch_size, num_classes, args.input_size[0], args.input_size[1], args.input_size[2])

    # Make integer class labels from pixel intensities in [0, 1].
    target = (x * (num_classes - 1)).long()

    # - N E_q0 [ ln p(x|z_k) ]
    # Sums over the batch dimension (and feature dimensions).
    ce = cross_entropy(x_logit, target, size_average=False)

    # ln p(z_k) (not averaged)
    log_p_zk = log_standard_gaussian(z_k, dim=1)
    # ln q(z_0) (not averaged)
    log_q_z0 = log_gaussian(z_0, mean=z_mu, log_var=z_var.log(), dim=1)
    # N E_q0[ ln q(z_0) - ln p(z_k) ]
    summed_logs = torch.sum(log_q_z0 - log_p_zk)

    # Sum over batches
    summed_ldj = torch.sum(ldj)

    # ldj = N E_q_z0[ \sum_k log |det dz_k/dz_{k-1}| ]
    kl = summed_logs - summed_ldj
    loss = ce + beta * kl

    loss /= float(batch_size)
    ce /= float(batch_size)
    kl /= float(batch_size)

    return loss, ce, kl
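# A minimal shape sketch (hypothetical sizes, not the repository's own settings) of
# the tensors consumed above: the logits carry one of 256 bins per pixel, and the
# targets are integer bin indices obtained by scaling pixel values from [0, 1].
import torch
import torch.nn.functional as F


def _multinomial_shape_sketch():
    batch_size, num_classes, channels, height, width = 2, 256, 1, 4, 4
    x_logit = torch.randn(batch_size, num_classes, channels, height, width)
    x = torch.rand(batch_size, channels, height, width)      # pixels in [0, 1]
    target = (x * (num_classes - 1)).long()                  # shape (2, 1, 4, 4)
    # F.cross_entropy accepts (N, K, d1, ...) logits with (N, d1, ...) targets.
    ce = F.cross_entropy(x_logit, target, reduction='sum')   # scalar
    return ce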
def multinomial_loss_array(x_logit, x, z_mu, z_var, z_0, z_k, ldj, args, beta=1.):
    """
    Computes the multinomial (cross-entropy) loss without averaging or summing over the batch dimension.
    """
    num_classes = 256
    batch_size = x.size(0)

    x_logit = x_logit.view(batch_size, num_classes, args.input_size[0], args.input_size[1], args.input_size[2])

    # Make integer class labels from pixel intensities in [0, 1].
    target = (x * (num_classes - 1)).long()

    # - N E_q0 [ ln p(x|z_k) ]
    # Computes the cross entropy over all dimensions separately:
    ce = cross_entropy(x_logit, target, size_average=False, reduce=False)
    # Sum over the feature dimension.
    ce = ce.view(batch_size, -1).sum(dim=1)

    # ln p(z_k) (not averaged)
    log_p_zk = log_standard_gaussian(z_k.view(batch_size, -1), dim=1)
    # ln q(z_0) (not averaged)
    log_q_z0 = log_gaussian(z_0.view(batch_size, -1),
                            mean=z_mu.view(batch_size, -1),
                            log_var=z_var.log().view(batch_size, -1),
                            dim=1)
    # ln q(z_0) - ln p(z_k)
    logs = log_q_z0 - log_p_zk

    loss = ce + beta * (logs - ldj)

    return loss
def mse_loss_function(recon_x, x, z_mu, z_var, z_0, z_k, ldj, beta=1.):
    """
    Computes the MSE loss function while summing over the batch dimension, not averaged!

    :param recon_x: shape (batch_size, num_channels, pixel_width, pixel_height), reconstructed x
    :param x: shape (batch_size, num_channels, pixel_width, pixel_height), pixel values rescaled between [0, 1]
    :param z_mu: mean of z_0
    :param z_var: variance of z_0
    :param z_0: first stochastic latent variable
    :param z_k: last stochastic latent variable
    :param ldj: log det jacobian
    :param beta: beta for kl loss
    :return: loss, mse, kl
    """
    x = torch.tanh(x)
    reconstruction_function = nn.MSELoss(size_average=False, reduce=False)
    batch_size = x.size(0)

    # - N E_q0 [ ln p(x|z_k) ]
    mse = Variable(torch.sum(reconstruction_function(recon_x, x), dim=-1))

    # ln p(z_k) (not averaged)
    log_p_zk = log_standard_gaussian(z_k)
    # ln q(z_0) (not averaged)
    log_q_z0 = log_gaussian(z_0, z_mu, log_var=z_var) - ldj
    # N E_q0[ ln q(z_0) - ln p(z_k) ]
    kl = abs(log_q_z0 - log_p_zk)

    # Sum over batches
    # ldj = N E_q_z0[ \sum_k log |det dz_k/dz_{k-1}| ]
    loss = mse + beta * kl
    loss = torch.sum(loss)
    mse = torch.sum(mse)
    kl = torch.sum(kl)

    loss /= float(batch_size)
    mse /= float(batch_size)
    kl /= float(batch_size)

    return loss, mse, kl
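# Note: the `size_average`/`reduce` keyword arguments and `torch.autograd.Variable`
# used above are deprecated in recent PyTorch releases. A hedged modern equivalent
# of the per-sample MSE term (a sketch, not the repository's code) would be:
import torch.nn as nn


def _per_sample_mse_sketch(recon_x, x):
    # Elementwise squared error, summed over the feature dimension,
    # leaving one value per sample.
    return nn.MSELoss(reduction='none')(recon_x, x).sum(dim=-1)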
def binary_loss_function(recon_x, x, z_mu, z_var, z_0, z_k, ldj, beta=1.):
    """
    Computes the binary loss function while summing over the batch dimension, not averaged!

    :param recon_x: shape (batch_size, num_channels, pixel_width, pixel_height), bernoulli parameters p(x=1)
    :param x: shape (batch_size, num_channels, pixel_width, pixel_height), pixel values rescaled between [0, 1]
    :param z_mu: mean of z_0
    :param z_var: variance of z_0
    :param z_0: first stochastic latent variable
    :param z_k: last stochastic latent variable
    :param ldj: log det jacobian
    :param beta: beta for kl loss
    :return: loss, bce, kl
    """
    reconstruction_function = nn.BCELoss()
    reconstruction_function.size_average = False

    batch_size = x.size(0)

    # - N E_q0 [ ln p(x|z_k) ]
    bce = reconstruction_function(recon_x, x)

    # ln p(z_k) (not averaged)
    log_p_zk = log_standard_gaussian(z_k, dim=1)
    # ln q(z_0) (not averaged)
    log_q_z0 = log_gaussian(z_0, mean=z_mu, log_var=z_var.log(), dim=1)
    # N E_q0[ ln q(z_0) - ln p(z_k) ]
    summed_logs = torch.sum(log_q_z0 - log_p_zk)

    # Sum over batches
    summed_ldj = torch.sum(ldj)

    # ldj = N E_q_z0[ \sum_k log |det dz_k/dz_{k-1}| ]
    kl = summed_logs - summed_ldj
    loss = bce + beta * kl

    loss /= float(batch_size)
    bce /= float(batch_size)
    kl /= float(batch_size)

    return loss, bce, kl
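# As with the MSE variant above, `size_average` is deprecated; in recent PyTorch the
# summed reconstruction term can be written with `reduction='sum'` (a sketch):
import torch.nn as nn


def _summed_bce_sketch(recon_x, x):
    # recon_x holds Bernoulli parameters in (0, 1); returns a scalar summed
    # over both the batch and feature dimensions.
    return nn.BCELoss(reduction='sum')(recon_x, x)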
def run_sylvester(self, x, y=torch.Tensor([]).cuda(), a=torch.Tensor([]).cuda(), k=0, auxiliary=True):
    r"""
    Forward pass with Sylvester flows for the transformation z_0 -> z_1 -> ... -> z_k.
    The log determinant is accumulated as log_det_j = N E_q_z0[ \sum_k log |det dz_k/dz_{k-1}| ].
    """
    if len(x.shape) == 2:
        x = x.view(-1, self.input_shape[0], self.input_shape[1], self.input_shape[2])
    self.log_det_j = 0.

    (z_mu, z_var, r1, r2, q, b), x, z_q = self.encode(x, y, a, i=k, auxiliary=auxiliary)

    # Orthogonalize all q matrices
    if self.flow_type == "o-sylvester":
        q_ortho = self.batch_construct_orthogonal(q, auxiliary)
    elif self.flow_type == "h-sylvester":
        q_ortho = self.batch_construct_householder_orthogonal(q, auxiliary)
    else:
        q_ortho = None

    # Sample z_0
    z = [self.reparameterize(z_mu, z_var)]

    # Normalizing flows
    for i in range(self.n_flows):
        # Per-step flow modules are registered under names of the form 'flow_{k}_{i}_{auxiliary}'.
        flow_k = getattr(self, 'flow_' + str(k) + "_" + str(i) + "_" + str(auxiliary))
        if self.flow_type in ["o-sylvester"]:
            try:
                z_k, log_det_jacobian = flow_k(zk=z[i],
                                               r1=r1[:, :, :, i],
                                               r2=r2[:, :, :, i],
                                               q_ortho=q_ortho[i, :, :, :],
                                               b=b[:, :, :, i])
            except Exception:
                # Fall back to tensor-style indexing of z.
                z_k, log_det_jacobian = flow_k(zk=z[:, i],
                                               r1=r1[:, :, :, i],
                                               r2=r2[:, :, :, i],
                                               q_ortho=q_ortho[i, :, :, :],
                                               b=b[:, :, :, i])
        elif self.flow_type in ["h-sylvester"]:
            q_k = q_ortho[i]
            z_k, log_det_jacobian = flow_k(z[i], r1[:, :, :, i], r2[:, :, :, i], q_k, b[:, :, :, i])
        elif self.flow_type in ["t-sylvester"]:
            if k % 2 == 1:
                # Alternate with reordering z for triangular flow
                permute_z = self.flip_idx
            else:
                permute_z = None
            z_k, log_det_jacobian = flow_k(zk=z[i],
                                           r1=r1[:, :, :, i],
                                           r2=r2[:, :, :, i],
                                           b=b[:, :, :, i],
                                           permute_z=permute_z,
                                           sum_ldj=True,
                                           auxiliary=auxiliary)
        else:
            raise NotImplementedError("flow type not implemented")
        z.append(z_k)
        self.log_det_j += log_det_jacobian

    # ln p(z_k)
    log_p_zk = log_standard_gaussian(z[-1])
    # ln q(z_0) (not averaged)
    log_q_z0 = log_gaussian(z[0], z_mu, log_var=z_var) - self.log_det_j
    # N E_q0[ ln q(z_0) - ln p(z_k) ]
    self.kl_divergence = log_q_z0 - log_p_zk

    if auxiliary:
        x_mean = None
    else:
        x_mean = self.sample(z[-1], y)

    return x_mean, z_mu, z_var, self.log_det_j, z[0], z[-1]
def forward(self, x, y=torch.Tensor([]).cuda(), a=torch.Tensor([]).cuda(), k=0, auxiliary=False):
    r"""
    Forward pass with Sylvester flows for the transformation z_0 -> z_1 -> ... -> z_k.
    The log determinant is accumulated as log_det_j = N E_q_z0[ \sum_k log |det dz_k/dz_{k-1}| ].
    """
    self.log_det_j = 0.

    (z_mu, z_var, r1, r2, q, b), x, z_q = self.encode(torch.cat([x, y], 1), auxiliary=auxiliary)
    self.sylvester_params = (r1, r2, q, b)

    # Orthogonalize all q matrices
    if self.flow_type == "o-sylvester":
        q_ortho = self.batch_construct_orthogonal(q)
    elif self.flow_type == "h-sylvester":
        q_ortho = self.batch_construct_householder_orthogonal(q)
    else:
        q_ortho = None

    # Sample z_0
    z = [self.reparameterize(z_mu, z_var)]

    # Normalizing flows
    for i in range(self.n_flows):
        flow_k = getattr(self, 'flow_' + str(k) + "_" + str(i) + "_" + str(auxiliary))
        if self.flow_type in ["o-sylvester"]:
            z_k, log_det_jacobian = flow_k(z[i], r1[:, :, :, i], r2[:, :, :, i], q_ortho[i, :, :, :], b[:, :, :, i])
        elif self.flow_type in ["h-sylvester"]:
            q_k = q_ortho[i]
            z_k, log_det_jacobian = flow_k(z[i], r1[:, :, :, i], r2[:, :, :, i], q_k, b[:, :, :, i])
        elif self.flow_type in ["t-sylvester"]:
            if k % 2 == 1:
                # Alternate with reordering z for triangular flow
                permute_z = self.flip_idx
            else:
                permute_z = None
            z_k, log_det_jacobian = flow_k(z[i], r1[:, :, :, i], r2[:, :, :, i], b[:, :, :, i], permute_z, sum_ldj=True)
        else:
            raise NotImplementedError("flow type not implemented")
        z.append(z_k)
        self.log_det_j += log_det_jacobian

    # ln p(z_k)
    log_p_zk = log_standard_gaussian(z[-1])
    # ln q(z_0) (not averaged)
    log_q_z0 = log_gaussian(z[0], z_mu, log_var=z_var) - self.log_det_j
    # N E_q0[ ln q(z_0) - ln p(z_k) ]
    self.model.kl_divergence = log_q_z0 - log_p_zk

    x_mean, _ = self.sample(z[-1], y)

    return x_mean, z_mu, z_var, self.log_det_j, z[0], z[-1]
def log_gaussians(self, x, mus, logvars):
    # Evaluate log N(x; mu_c, diag(exp(logvar_c))) for every centroid c and
    # stack the results along dimension 1, giving one column per centroid.
    G = []
    for c in range(self.n_centroids):
        G.append(log_gaussian(x, mus[c:c + 1, :], logvars[c:c + 1, :]).view(-1, 1))
    return torch.cat(G, 1)
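# A usage sketch for `log_gaussians` (hypothetical shapes): with n_centroids = 3 and
# a 5-dimensional latent, x of shape (batch, 5) and mus/logvars of shape (3, 5), the
# call returns a (batch, 3) matrix whose column c holds log N(x; mu_c, diag(exp(logvar_c)))
# for every sample, e.g. for computing responsibilities under a Gaussian-mixture prior:
#     G = self.log_gaussians(x, mus, logvars)          # (batch, 3)
#     responsibilities = torch.softmax(G + log_pi, 1)  # log_pi: mixture log-weights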
def forward(self, x, i=None):
    # Gather latent representations from the encoders along with the final z.
    latents = []
    x = torch.tanh(x)
    if self.flow_type in ["o-sylvester", "h-sylvester", "t-sylvester"]:
        for i in range(len(self.encoder)):
            q_param, x, z = self.encoder(x, i)
            latents.append(q_param)
    else:
        for i, encoder in enumerate(self.encoder):
            q_param, x = encoder(x)
            z = q_param[0]
            q_param = q_param[1:]
            latents.append(q_param)

    latents = list(reversed(latents))
    kl_divergence = 0
    h = x
    self.log_det_j = 0

    # The leading -1 is a sentinel: at the top of the ladder there is no decoder and
    # the encoder output is compared against the prior for the KL term.
    for k, decoder in enumerate([-1, *self.decoder]):
        q_param = latents[k]
        if self.sylvester_flow:
            mu, log_var, r1, r2, q, b = q_param
            if k > 0:
                z = [self.reparameterize(mu, log_var)]
            else:
                z = [z]
            l = -1 - k
            q_ortho = self.batch_construct_orthogonal(q, l)

            # Normalizing flows applied to the sampled z_0
            for i in range(self.n_flows):
                flow_k = getattr(self, 'flow_' + str(k) + "_" + str(i))
                z_k, log_det_jacobian = flow_k(z[i], r1[:, :, :, i], r2[:, :, :, i], q_ortho[i, :, :, :], b[:, :, :, i])
                z.append(z_k)
                self.log_det_j += log_det_jacobian

            # KL
            # ln p(z_k)
            log_p_zk = log_standard_gaussian(z[-1])
            # ln q(z_0) (not averaged)
            log_q_z0 = log_gaussian(z[0], mu, log_var=log_var) - self.log_det_j
            # N E_q0[ ln q(z_0) - ln p(z_k) ]
            kl = log_q_z0 - log_p_zk
            kl_divergence += kl
        elif k == 0:
            # At the top, use the prior for the KL.
            kl_divergence += self._kld(z, q_param=q_param, i=k, h_last=h).abs()
        else:
            (mu, log_var) = q_param
            z, kl = decoder(z, mu, log_var)
            (q_z, q_param, p_param) = kl
            kl_divergence += self._kld(z, q_param=q_param, i=k, h_last=h, p_param=p_param).abs()

    try:
        x_mu = self.reconstruction(z)
    except Exception:
        # z may still be a list of flow states; reconstruct from the last one.
        x_mu = self.reconstruction(z[-1])

    del latents, x, self.log_det_j, q_param
    if self.sylvester_flow:
        del r1, r2, q, b, q_ortho
    self.kl_divergence = Variable(kl_divergence)
    return x_mu, z