def test_gaussian_ll():

    # test basic
    n_batch = 5
    n_dims = 3
    std = 1
    x = torch.rand((n_batch, n_dims))
    ll = losses.gaussian_ll(x, x, masks=None, std=std)
    assert ll == - (0.5 * LN2PI + 0.5 * np.log(std ** 2)) * n_dims

    # test with masks
    x = torch.ones(n_batch, n_dims)
    y = torch.zeros(n_batch, n_dims)
    m = torch.zeros(n_batch, n_dims)
    for b in range(n_batch):
        m[b, 0] = 1
    ll = losses.gaussian_ll(x, y, masks=m, std=std)
    assert ll == - (0.5 * LN2PI + 0.5 * np.log(std ** 2)) * n_dims - (0.5 / (std ** 2))
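
# The assertions above follow from the closed form of an isotropic Gaussian log-density with
# fixed std: the constant terms are summed over every dimension, the mask is applied only to
# the squared-error term, and the result is averaged over the batch. A minimal reference
# sketch consistent with those asserts (`_gaussian_ll_reference` is a hypothetical helper
# inferred from the tests, not the library implementation):
def _gaussian_ll_reference(x, y, masks=None, std=1):
    """Hypothetical reference for losses.gaussian_ll, inferred from the asserts above."""
    if masks is None:
        masks = torch.ones_like(x)
    # constant terms summed over all dims; squared-error term masked
    log_prob = \
        - (0.5 * LN2PI + 0.5 * np.log(std ** 2)) \
        - 0.5 * masks * (y - x) ** 2 / (std ** 2)
    return torch.mean(torch.sum(log_prob, dim=1))
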
def test_gaussian_ll_to_mse():

    n_batch = 5
    n_dims = 3
    std = 1
    x = torch.ones(n_batch, n_dims)
    y = torch.zeros(n_batch, n_dims)
    ll = losses.gaussian_ll(x, y, std=std)
    mse_ = 2 * (-ll - (0.5 * LN2PI + 0.5 * np.log(std ** 2)) * n_dims) / n_dims
    mse = losses.gaussian_ll_to_mse(ll.detach().numpy(), n_dims, gaussian_std=std, mse_std=1)
    assert np.allclose(mse, mse_.detach().numpy())
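
# The conversion tested above inverts the same closed form: with fixed std and the mse
# averaged over dimensions, ll = -(0.5 * LN2PI + 0.5 * log(std ** 2)) * n_dims
# - 0.5 * n_dims * mse / std ** 2, which rearranges to the `mse_` expression in the test.
# A minimal sketch of that algebra (hypothetical helper; the `mse_std` rescaling is an
# assumption, since the test only exercises mse_std=1):
def _gaussian_ll_to_mse_reference(ll, n_dims, gaussian_std=1, mse_std=1):
    """Hypothetical inverse of the Gaussian log-likelihood formula used above."""
    const = (0.5 * LN2PI + 0.5 * np.log(gaussian_std ** 2)) * n_dims
    mse = -2 * (gaussian_std ** 2) * (ll + const) / n_dims
    # assumed rescaling when reporting the mse relative to a different noise std
    return mse / (mse_std ** 2)
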
def loss(self, data, dataset=0, accumulate_grad=True, chunk_size=200):
    """Calculate modified ELBO loss for PSVAE.

    The batch is split into chunks if larger than a hard-coded `chunk_size` to keep memory
    requirements low; gradients are accumulated across all chunks before a gradient step is
    taken.

    Parameters
    ----------
    data : :obj:`dict`
        batch of data; keys should include 'images' and 'labels', as well as 'masks' and
        'labels_masks' if necessary
    dataset : :obj:`int`, optional
        used for session-specific io layers
    accumulate_grad : :obj:`bool`, optional
        accumulate gradient for training step
    chunk_size : :obj:`int`, optional
        batch is split into chunks of this size to keep memory requirements low

    Returns
    -------
    :obj:`dict`
        - 'loss' (:obj:`float`): total weighted loss (modified elbo)
        - 'loss_data_ll' (:obj:`float`): image log-likelihood portion of the loss
        - 'loss_label_ll' (:obj:`float`): label log-likelihood portion of the loss
        - 'loss_zs_kl' (:obj:`float`): kl of the supervised latents
        - 'loss_zu_mi' (:obj:`float`): index-code mutual information of the unsupervised latents
        - 'loss_zu_tc' (:obj:`float`): total correlation of the unsupervised latents
        - 'loss_zu_dwkl' (:obj:`float`): dimension-wise kl of the unsupervised latents
        - 'loss_AB_orth' (:obj:`float`): subspace overlap between supervised and unsupervised mappings
        - 'loss_data_mse' (:obj:`float`): image mse (without gaussian constants)
        - 'label_r2' (:obj:`float`): variance-weighted r2 of the label predictions
        - 'alpha' (:obj:`float`): weight in front of the label log-likelihood term
        - 'beta' (:obj:`float`): weight in front of the total correlation term
        - 'gamma' (:obj:`float`): weight in front of the orthogonality term

    """

    x = data['images'][0]
    y = data['labels'][0]
    m = data['masks'][0] if 'masks' in data else None
    n = data['labels_masks'][0] if 'labels_masks' in data else None
    batch_size = x.shape[0]
    n_chunks = int(np.ceil(batch_size / chunk_size))

    n_labels = self.hparams['n_labels']
    # n_latents = self.hparams['n_ae_latents']

    # compute hyperparameters
    alpha = self.hparams['ps_vae.alpha']
    beta = self.beta_vals[self.curr_epoch]
    gamma = self.hparams['ps_vae.gamma']
    kl = self.kl_anneal_vals[self.curr_epoch]

    loss_strs = [
        'loss', 'loss_data_ll', 'loss_label_ll', 'loss_zs_kl', 'loss_zu_mi', 'loss_zu_tc',
        'loss_zu_dwkl', 'loss_AB_orth']

    loss_dict_vals = {loss: 0 for loss in loss_strs}
    loss_dict_vals['loss_data_mse'] = 0

    y_hat_all = []

    for chunk in range(n_chunks):

        idx_beg = chunk * chunk_size
        idx_end = np.min([(chunk + 1) * chunk_size, batch_size])

        x_in = x[idx_beg:idx_end]
        y_in = y[idx_beg:idx_end]
        m_in = m[idx_beg:idx_end] if m is not None else None
        n_in = n[idx_beg:idx_end] if n is not None else None
        x_hat, sample, mu, logvar, y_hat = self.forward(x_in, dataset=dataset, use_mean=False)

        # reset losses
        loss_dict_torch = {loss: 0 for loss in loss_strs}

        # data log-likelihood
        loss_dict_torch['loss_data_ll'] = losses.gaussian_ll(x_in, x_hat, m_in)
        loss_dict_torch['loss'] -= loss_dict_torch['loss_data_ll']

        # label log-likelihood
        loss_dict_torch['loss_label_ll'] = losses.gaussian_ll(y_in, y_hat, n_in)
        loss_dict_torch['loss'] -= alpha * loss_dict_torch['loss_label_ll']

        # supervised latents kl
        loss_dict_torch['loss_zs_kl'] = losses.kl_div_to_std_normal(
            mu[:, :n_labels], logvar[:, :n_labels])
        loss_dict_torch['loss'] += loss_dict_torch['loss_zs_kl']

        # compute all terms of decomposed elbo at once
        index_code_mi, total_correlation, dimension_wise_kl = losses.decomposed_kl(
            sample[:, n_labels:], mu[:, n_labels:], logvar[:, n_labels:])

        # unsupervised latents index-code mutual information
        loss_dict_torch['loss_zu_mi'] = index_code_mi
        loss_dict_torch['loss'] += kl * loss_dict_torch['loss_zu_mi']

        # unsupervised latents total correlation
        loss_dict_torch['loss_zu_tc'] = total_correlation
        loss_dict_torch['loss'] += beta * loss_dict_torch['loss_zu_tc']

        # unsupervised latents dimension-wise kl
        loss_dict_torch['loss_zu_dwkl'] = dimension_wise_kl
        loss_dict_torch['loss'] += kl * loss_dict_torch['loss_zu_dwkl']

        # orthogonality between A and B
        # A shape: [n_labels, n_latents]
        # B shape: [n_latents - n_labels, n_latents]
        # compute ||AB^T||^2
        loss_dict_torch['loss_AB_orth'] = losses.subspace_overlap(
            self.encoding.A.weight, self.encoding.B.weight)
        loss_dict_torch['loss'] += gamma * loss_dict_torch['loss_AB_orth']

        if accumulate_grad:
            loss_dict_torch['loss'].backward()

        # get loss value (weighted by batch size)
        bs = idx_end - idx_beg
        for key, val in loss_dict_torch.items():
            loss_dict_vals[key] += val.item() * bs
        # convert this chunk's log-likelihood to an mse value
        loss_dict_vals['loss_data_mse'] += losses.gaussian_ll_to_mse(
            loss_dict_torch['loss_data_ll'].item(), np.prod(x.shape[1:])) * bs

        # collect predicted labels to compute R2
        y_hat_all.append(y_hat.cpu().detach().numpy())

    # use variance-weighted r2s to ignore small-variance latents
    y_hat_all = np.concatenate(y_hat_all, axis=0)
    y_all = y.cpu().detach().numpy()
    if n is not None:
        n_np = n.cpu().detach().numpy()
        r2 = r2_score(y_all[n_np == 1], y_hat_all[n_np == 1], multioutput='variance_weighted')
    else:
        r2 = r2_score(y_all, y_hat_all, multioutput='variance_weighted')

    # compile (properly weighted) loss terms
    for key in loss_dict_vals.keys():
        loss_dict_vals[key] /= batch_size

    # store hyperparams
    loss_dict_vals['alpha'] = alpha
    loss_dict_vals['beta'] = beta
    loss_dict_vals['gamma'] = gamma
    loss_dict_vals['label_r2'] = r2

    return loss_dict_vals
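
# A minimal usage sketch for the PS-VAE loss above (hypothetical shapes and names; assumes
# `model` is an instance of this class and `optimizer` has already been constructed):
#
#     data = {
#         'images': [torch.randn(128, 2, 96, 96)],
#         'labels': [torch.randn(128, model.hparams['n_labels'])],
#     }
#     optimizer.zero_grad()
#     loss_dict = model.loss(data, dataset=0, accumulate_grad=True)  # calls backward() internally
#     optimizer.step()
#     print(loss_dict['loss'], loss_dict['label_r2'])
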
def loss(self, data, dataset=0, accumulate_grad=True, chunk_size=200):
    """Calculate (decomposed) ELBO loss for VAE.

    The batch is split into chunks if larger than a hard-coded `chunk_size` to keep memory
    requirements low; gradients are accumulated across all chunks before a gradient step is
    taken.

    Parameters
    ----------
    data : :obj:`dict`
        batch of data; keys should include 'images' and 'masks', if necessary
    dataset : :obj:`int`, optional
        used for session-specific io layers
    accumulate_grad : :obj:`bool`, optional
        accumulate gradient for training step
    chunk_size : :obj:`int`, optional
        batch is split into chunks of this size to keep memory requirements low

    Returns
    -------
    :obj:`dict`
        - 'loss' (:obj:`float`): full elbo
        - 'loss_ll' (:obj:`float`): log-likelihood portion of elbo
        - 'loss_mi' (:obj:`float`): index-code mutual info portion of kl of elbo
        - 'loss_tc' (:obj:`float`): total correlation portion of kl of elbo
        - 'loss_dwkl' (:obj:`float`): dim-wise kl portion of kl of elbo
        - 'loss_mse' (:obj:`float`): mse (without gaussian constants)
        - 'beta' (:obj:`float`): weight in front of kl term

    """

    x = data['images'][0]
    m = data['masks'][0] if 'masks' in data else None
    beta = self.beta_vals[self.curr_epoch]
    kl = self.kl_anneal_vals[self.curr_epoch]

    batch_size = x.shape[0]
    n_chunks = int(np.ceil(batch_size / chunk_size))

    loss_strs = ['loss', 'loss_ll', 'loss_mi', 'loss_tc', 'loss_dwkl']

    loss_dict_vals = {loss: 0 for loss in loss_strs}
    loss_dict_vals['loss_mse'] = 0

    for chunk in range(n_chunks):

        idx_beg = chunk * chunk_size
        idx_end = np.min([(chunk + 1) * chunk_size, batch_size])

        x_in = x[idx_beg:idx_end]
        m_in = m[idx_beg:idx_end] if m is not None else None
        x_hat, sample, mu, logvar = self.forward(x_in, dataset=dataset, use_mean=False)

        # reset losses
        loss_dict_torch = {loss: 0 for loss in loss_strs}

        # data log-likelihood
        loss_dict_torch['loss_ll'] = losses.gaussian_ll(x_in, x_hat, m_in)
        loss_dict_torch['loss'] -= loss_dict_torch['loss_ll']

        # compute all terms of decomposed elbo at once
        index_code_mi, total_correlation, dimension_wise_kl = losses.decomposed_kl(
            sample, mu, logvar)

        # unsupervised latents index-code mutual information
        loss_dict_torch['loss_mi'] = index_code_mi
        loss_dict_torch['loss'] += kl * loss_dict_torch['loss_mi']

        # unsupervised latents total correlation
        loss_dict_torch['loss_tc'] = total_correlation
        loss_dict_torch['loss'] += beta * loss_dict_torch['loss_tc']

        # unsupervised latents dimension-wise kl
        loss_dict_torch['loss_dwkl'] = dimension_wise_kl
        loss_dict_torch['loss'] += kl * loss_dict_torch['loss_dwkl']

        if accumulate_grad:
            loss_dict_torch['loss'].backward()

        # get loss value (weighted by batch size)
        bs = idx_end - idx_beg
        for key, val in loss_dict_torch.items():
            loss_dict_vals[key] += val.item() * bs
        # convert this chunk's log-likelihood to an mse value
        loss_dict_vals['loss_mse'] += losses.gaussian_ll_to_mse(
            loss_dict_torch['loss_ll'].item(), np.prod(x.shape[1:])) * bs

    # compile (properly weighted) loss terms
    for key in loss_dict_vals.keys():
        loss_dict_vals[key] /= batch_size

    # store hyperparams
    loss_dict_vals['beta'] = beta

    return loss_dict_vals
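
# For reference, the per-chunk objective assembled above is
#
#     loss = -E_q[log p(x | z)] + kl * MI + beta * TC + kl * DWKL
#
# where MI, TC, and DWKL are the index-code mutual information, total correlation, and
# dimension-wise KL returned by `losses.decomposed_kl`. With beta = kl = 1 these three
# pieces recombine (up to the minibatch estimators used inside `decomposed_kl`) into the
# usual aggregate KL term, recovering the standard ELBO.
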
def loss(self, data, dataset=0, accumulate_grad=True, chunk_size=200):
    """Calculate ELBO loss for ConditionalVAE.

    The batch is split into chunks if larger than a hard-coded `chunk_size` to keep memory
    requirements low; gradients are accumulated across all chunks before a gradient step is
    taken.

    Parameters
    ----------
    data : :obj:`dict`
        batch of data; keys should include 'images' and 'labels', as well as 'masks' if
        necessary and 'labels_sc' if the encoder is conditioned on the labels
    dataset : :obj:`int`, optional
        used for session-specific io layers
    accumulate_grad : :obj:`bool`, optional
        accumulate gradient for training step
    chunk_size : :obj:`int`, optional
        batch is split into chunks of this size to keep memory requirements low

    Returns
    -------
    :obj:`dict`
        - 'loss' (:obj:`float`): full elbo
        - 'loss_ll' (:obj:`float`): log-likelihood portion of elbo
        - 'loss_kl' (:obj:`float`): kl portion of elbo
        - 'loss_mse' (:obj:`float`): mse (without gaussian constants)
        - 'beta' (:obj:`float`): weight in front of kl term

    """

    x = data['images'][0]
    y = data['labels'][0]
    m = data['masks'][0] if 'masks' in data else None
    if self.hparams['conditional_encoder']:
        # continuous labels transformed into 2d one-hot array as input to encoder
        y_2d = data['labels_sc'][0]
    else:
        y_2d = None

    beta = self.beta_vals[self.curr_epoch]

    batch_size = x.shape[0]
    n_chunks = int(np.ceil(batch_size / chunk_size))

    loss_val = 0
    loss_ll_val = 0
    loss_kl_val = 0
    loss_mse_val = 0

    for chunk in range(n_chunks):

        idx_beg = chunk * chunk_size
        idx_end = np.min([(chunk + 1) * chunk_size, batch_size])

        x_in = x[idx_beg:idx_end]
        y_in = y[idx_beg:idx_end]
        m_in = m[idx_beg:idx_end] if m is not None else None
        y_2d_in = y_2d[idx_beg:idx_end] if y_2d is not None else None
        x_hat, _, mu, logvar = self.forward(
            x_in, dataset=dataset, use_mean=False, labels=y_in, labels_2d=y_2d_in)

        # log-likelihood
        loss_ll = losses.gaussian_ll(x_in, x_hat, m_in)

        # kl
        loss_kl = losses.kl_div_to_std_normal(mu, logvar)

        # combine
        loss = -loss_ll + beta * loss_kl

        if accumulate_grad:
            loss.backward()

        # get loss value (weighted by batch size)
        loss_val += loss.item() * (idx_end - idx_beg)
        loss_ll_val += loss_ll.item() * (idx_end - idx_beg)
        loss_kl_val += loss_kl.item() * (idx_end - idx_beg)
        loss_mse_val += losses.gaussian_ll_to_mse(
            loss_ll.item(), np.prod(x.shape[1:])) * (idx_end - idx_beg)

    loss_val /= batch_size
    loss_ll_val /= batch_size
    loss_kl_val /= batch_size
    loss_mse_val /= batch_size

    loss_dict = {
        'loss': loss_val,
        'loss_ll': loss_ll_val,
        'loss_kl': loss_kl_val,
        'loss_mse': loss_mse_val,
        'beta': beta}

    return loss_dict
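
# A minimal usage sketch for the conditional VAE loss above (hypothetical shapes and names;
# the 'labels_sc' entry is only needed when hparams['conditional_encoder'] is True, and its
# exact 2d one-hot layout depends on the data generator, so it is omitted here):
#
#     data = {
#         'images': [torch.randn(128, 2, 96, 96)],
#         'labels': [torch.randn(128, n_labels)],
#     }
#     loss_dict = model.loss(data, dataset=0, accumulate_grad=False)
#     print(loss_dict['loss_ll'], loss_dict['loss_kl'])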