Example #1
def test_gaussian_ll():

    # test basic
    n_batch = 5
    n_dims = 3
    std = 1
    x = torch.rand((n_batch, n_dims))
    ll = losses.gaussian_ll(x, x, masks=None, std=std)
    assert ll == - (0.5 * LN2PI + 0.5 * np.log(std ** 2)) * n_dims

    # test with masks
    x = torch.ones(n_batch, n_dims)
    y = torch.zeros(n_batch, n_dims)
    m = torch.zeros(n_batch, n_dims)
    for b in range(n_batch):
        m[b, 0] = 1
    ll = losses.gaussian_ll(x, y, masks=m, std=std)
    assert ll == - (0.5 * LN2PI + 0.5 * np.log(std ** 2)) * n_dims - (0.5 / (std ** 2))
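
The assertions above pin down the behavior of `losses.gaussian_ll`: with `LN2PI` standing for
log(2*pi), it returns the mean (over the batch) Gaussian log-likelihood summed over dimensions,
with an optional element-wise mask applied to the squared-error term. Below is a minimal sketch
consistent with those assertions; the library's actual implementation may differ in detail.

import numpy as np
import torch

LN2PI = np.log(2 * np.pi)

def gaussian_ll_sketch(y_pred, y_true, masks=None, std=1):
    """Mean Gaussian log-likelihood per sample, summed over dimensions (sketch)."""
    n_batch = y_pred.shape[0]
    n_dims = int(np.prod(y_pred.shape[1:]))
    log_var = np.log(std ** 2)
    diff_sq = (y_pred - y_true) ** 2
    if masks is not None:
        diff_sq = diff_sq * masks  # zero out unobserved entries
    # constant term over all dims plus the (masked) squared-error term
    ll = - (0.5 * LN2PI + 0.5 * log_var) * n_dims \
         - (0.5 / std ** 2) * diff_sq.reshape(n_batch, -1).sum(dim=1)
    return torch.mean(ll)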
Example #2
def test_gaussian_ll_to_mse():

    n_batch = 5
    n_dims = 3
    std = 1
    x = torch.ones(n_batch, n_dims)
    y = torch.zeros(n_batch, n_dims)
    ll = losses.gaussian_ll(x, y, std=std)
    mse_ = 2 * (-ll - (0.5 * LN2PI + 0.5 * np.log(std ** 2)) * n_dims) / n_dims
    mse = losses.gaussian_ll_to_mse(ll.detach().numpy(), n_dims, gaussian_std=std, mse_std=1)
    assert np.allclose(mse, mse_.detach().numpy())
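
The relation exercised here likewise determines what `losses.gaussian_ll_to_mse` must do for unit
standard deviations: strip the Gaussian constants from the log-likelihood and rescale to a
per-dimension mean squared error. A sketch consistent with the test for the
`gaussian_std = mse_std = 1` case follows; the scaling for other values is an assumption.

import numpy as np

LN2PI = np.log(2 * np.pi)

def gaussian_ll_to_mse_sketch(ll, n_dims, gaussian_std=1, mse_std=1):
    """Convert a mean Gaussian log-likelihood into a per-dimension mse (sketch)."""
    # remove the constant terms of the gaussian log-likelihood
    ll_centered = ll + (0.5 * LN2PI + 0.5 * np.log(gaussian_std ** 2)) * n_dims
    # undo the -0.5 / std^2 factor and average over dimensions
    mse = -2 * ll_centered * (gaussian_std ** 2) / n_dims
    return mse / (mse_std ** 2)  # optional rescaling; assumed behavior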
Example #3
    def loss(self, data, dataset=0, accumulate_grad=True, chunk_size=200):
        """Calculate modified ELBO loss for PSVAE.

        The batch is split into chunks if it is larger than `chunk_size` to keep memory
        requirements low; gradients are accumulated across all chunks before a gradient step is
        taken.

        Parameters
        ----------
        data : :obj:`dict`
            batch of data; keys should include 'images' and 'labels', plus 'masks' and
            'labels_masks' if used
        dataset : :obj:`int`, optional
            used for session-specific io layers
        accumulate_grad : :obj:`bool`, optional
            accumulate gradient for training step
        chunk_size : :obj:`int`, optional
            batch is split into chunks of this size to keep memory requirements low

        Returns
        -------
        :obj:`dict`
            - 'loss' (:obj:`float`): full (modified) elbo
            - 'loss_data_ll' (:obj:`float`): data log-likelihood portion of the elbo
            - 'loss_label_ll' (:obj:`float`): label log-likelihood portion of the elbo
            - 'loss_zs_kl' (:obj:`float`): kl of the supervised latents
            - 'loss_zu_mi' (:obj:`float`): index-code mutual information of the unsupervised latents
            - 'loss_zu_tc' (:obj:`float`): total correlation of the unsupervised latents
            - 'loss_zu_dwkl' (:obj:`float`): dimension-wise kl of the unsupervised latents
            - 'loss_AB_orth' (:obj:`float`): orthogonality penalty between the A and B subspaces
            - 'loss_data_mse' (:obj:`float`): mse (without gaussian constants)
            - 'alpha' (:obj:`float`): weight on the label log-likelihood term
            - 'beta' (:obj:`float`): weight on the total correlation term
            - 'gamma' (:obj:`float`): weight on the orthogonality term
            - 'label_r2' (:obj:`float`): variance-weighted R^2 of the label reconstructions

        """

        x = data['images'][0]
        y = data['labels'][0]
        m = data['masks'][0] if 'masks' in data else None
        n = data['labels_masks'][0] if 'labels_masks' in data else None
        batch_size = x.shape[0]
        n_chunks = int(np.ceil(batch_size / chunk_size))
        n_labels = self.hparams['n_labels']
        # n_latents = self.hparams['n_ae_latents']

        # compute hyperparameters
        alpha = self.hparams['ps_vae.alpha']
        beta = self.beta_vals[self.curr_epoch]
        gamma = self.hparams['ps_vae.gamma']
        kl = self.kl_anneal_vals[self.curr_epoch]

        loss_strs = [
            'loss', 'loss_data_ll', 'loss_label_ll', 'loss_zs_kl',
            'loss_zu_mi', 'loss_zu_tc', 'loss_zu_dwkl', 'loss_AB_orth'
        ]

        loss_dict_vals = {loss: 0 for loss in loss_strs}
        loss_dict_vals['loss_data_mse'] = 0

        y_hat_all = []

        for chunk in range(n_chunks):

            idx_beg = chunk * chunk_size
            idx_end = np.min([(chunk + 1) * chunk_size, batch_size])

            x_in = x[idx_beg:idx_end]
            y_in = y[idx_beg:idx_end]
            m_in = m[idx_beg:idx_end] if m is not None else None
            n_in = n[idx_beg:idx_end] if n is not None else None
            x_hat, sample, mu, logvar, y_hat = self.forward(x_in,
                                                            dataset=dataset,
                                                            use_mean=False)

            # reset losses
            loss_dict_torch = {loss: 0 for loss in loss_strs}

            # data log-likelihood
            loss_dict_torch['loss_data_ll'] = losses.gaussian_ll(
                x_in, x_hat, m_in)
            loss_dict_torch['loss'] -= loss_dict_torch['loss_data_ll']

            # label log-likelihood
            loss_dict_torch['loss_label_ll'] = losses.gaussian_ll(
                y_in, y_hat, n_in)
            loss_dict_torch['loss'] -= alpha * loss_dict_torch['loss_label_ll']

            # supervised latents kl
            loss_dict_torch['loss_zs_kl'] = losses.kl_div_to_std_normal(
                mu[:, :n_labels], logvar[:, :n_labels])
            loss_dict_torch['loss'] += loss_dict_torch['loss_zs_kl']

            # compute all terms of decomposed elbo at once
            index_code_mi, total_correlation, dimension_wise_kl = losses.decomposed_kl(
                sample[:, n_labels:], mu[:, n_labels:], logvar[:, n_labels:])

            # unsupervised latents index-code mutual information
            loss_dict_torch['loss_zu_mi'] = index_code_mi
            loss_dict_torch['loss'] += kl * loss_dict_torch['loss_zu_mi']

            # unsupervised latents total correlation
            loss_dict_torch['loss_zu_tc'] = total_correlation
            loss_dict_torch['loss'] += beta * loss_dict_torch['loss_zu_tc']

            # unsupervised latents dimension-wise kl
            loss_dict_torch['loss_zu_dwkl'] = dimension_wise_kl
            loss_dict_torch['loss'] += kl * loss_dict_torch['loss_zu_dwkl']

            # orthogonality between A and B
            # A shape: [n_labels, n_latents]
            # B shape: [n_latents - n_labels, n_latents]
            # compute ||AB^T||^2
            loss_dict_torch['loss_AB_orth'] = losses.subspace_overlap(
                self.encoding.A.weight, self.encoding.B.weight)

            loss_dict_torch['loss'] += gamma * loss_dict_torch['loss_AB_orth']

            if accumulate_grad:
                loss_dict_torch['loss'].backward()

            # get loss value (weighted by batch size)
            bs = idx_end - idx_beg
            for key, val in loss_dict_torch.items():
                loss_dict_vals[key] += val.item() * bs
            # convert this chunk's log-likelihood (not the running sum) to an mse
            loss_dict_vals['loss_data_mse'] += losses.gaussian_ll_to_mse(
                loss_dict_torch['loss_data_ll'].item(), np.prod(x.shape[1:])) * bs

            # collect predicted labels to compute R2
            y_hat_all.append(y_hat.cpu().detach().numpy())

        # use variance-weighted r2s to ignore small-variance latents
        y_hat_all = np.concatenate(y_hat_all, axis=0)
        y_all = y.cpu().detach().numpy()
        if n is not None:
            n_np = n.cpu().detach().numpy()
            r2 = r2_score(y_all[n_np == 1],
                          y_hat_all[n_np == 1],
                          multioutput='variance_weighted')
        else:
            r2 = r2_score(y_all, y_hat_all, multioutput='variance_weighted')

        # compile (properly weighted) loss terms
        for key in loss_dict_vals.keys():
            loss_dict_vals[key] /= batch_size

        # store hyperparams
        loss_dict_vals['alpha'] = alpha
        loss_dict_vals['beta'] = beta
        loss_dict_vals['gamma'] = gamma
        loss_dict_vals['label_r2'] = r2

        return loss_dict_vals
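
The 'loss_AB_orth' term above penalizes overlap between the supervised read-out matrix A and the
unsupervised matrix B; per the inline comments it is ||AB^T||^2. A minimal sketch of such a
penalty follows, assuming a mean over the entries of A @ B.T; the exact normalization used by
`losses.subspace_overlap` is an assumption.

import torch

def subspace_overlap_sketch(A, B):
    """Mean squared entry of A @ B.T (sketch of an orthogonality penalty).

    A: [n_labels, n_latents], B: [n_latents - n_labels, n_latents].
    """
    overlap = torch.matmul(A, B.t())  # [n_labels, n_latents - n_labels]
    return torch.mean(overlap ** 2)   # zero iff every row of A is orthogonal to every row of B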
Example #4
    def loss(self, data, dataset=0, accumulate_grad=True, chunk_size=200):
        """Calculate (decomposed) ELBO loss for VAE.

        The batch is split into chunks if it is larger than `chunk_size` to keep memory
        requirements low; gradients are accumulated across all chunks before a gradient step is
        taken.

        Parameters
        ----------
        data : :obj:`dict`
            batch of data; keys should include 'images' and 'masks', if necessary
        dataset : :obj:`int`, optional
            used for session-specific io layers
        accumulate_grad : :obj:`bool`, optional
            accumulate gradient for training step
        chunk_size : :obj:`int`, optional
            batch is split into chunks of this size to keep memory requirements low

        Returns
        -------
        :obj:`dict`
            - 'loss' (:obj:`float`): full elbo
            - 'loss_ll' (:obj:`float`): log-likelihood portion of elbo
            - 'loss_mi' (:obj:`float`): index-code mutual information portion of the kl of the elbo
            - 'loss_tc' (:obj:`float`): total correlation portion of the kl of the elbo
            - 'loss_dwkl' (:obj:`float`): dimension-wise kl portion of the kl of the elbo
            - 'loss_mse' (:obj:`float`): mse (without gaussian constants)
            - 'beta' (:obj:`float`): weight in front of kl term

        """

        x = data['images'][0]
        m = data['masks'][0] if 'masks' in data else None
        beta = self.beta_vals[self.curr_epoch]
        kl = self.kl_anneal_vals[self.curr_epoch]

        batch_size = x.shape[0]
        n_chunks = int(np.ceil(batch_size / chunk_size))

        loss_strs = ['loss', 'loss_ll', 'loss_mi', 'loss_tc', 'loss_dwkl']

        loss_dict_vals = {loss: 0 for loss in loss_strs}
        loss_dict_vals['loss_mse'] = 0

        for chunk in range(n_chunks):

            idx_beg = chunk * chunk_size
            idx_end = np.min([(chunk + 1) * chunk_size, batch_size])

            x_in = x[idx_beg:idx_end]
            m_in = m[idx_beg:idx_end] if m is not None else None
            x_hat, sample, mu, logvar = self.forward(x_in,
                                                     dataset=dataset,
                                                     use_mean=False)

            # reset losses
            loss_dict_torch = {loss: 0 for loss in loss_strs}

            # data log-likelihood
            loss_dict_torch['loss_ll'] = losses.gaussian_ll(x_in, x_hat, m_in)
            loss_dict_torch['loss'] -= loss_dict_torch['loss_ll']

            # compute all terms of decomposed elbo at once
            index_code_mi, total_correlation, dimension_wise_kl = losses.decomposed_kl(
                sample, mu, logvar)

            # unsupervised latents index-code mutual information
            loss_dict_torch['loss_mi'] = index_code_mi
            loss_dict_torch['loss'] += kl * loss_dict_torch['loss_mi']

            # unsupervised latents total correlation
            loss_dict_torch['loss_tc'] = total_correlation
            loss_dict_torch['loss'] += beta * loss_dict_torch['loss_tc']

            # unsupervised latents dimension-wise kl
            loss_dict_torch['loss_dwkl'] = dimension_wise_kl
            loss_dict_torch['loss'] += kl * loss_dict_torch['loss_dwkl']

            if accumulate_grad:
                loss_dict_torch['loss'].backward()

            # get loss value (weighted by batch size)
            bs = idx_end - idx_beg
            for key, val in loss_dict_torch.items():
                loss_dict_vals[key] += val.item() * bs
            # convert this chunk's log-likelihood (not the running sum) to an mse
            loss_dict_vals['loss_mse'] += losses.gaussian_ll_to_mse(
                loss_dict_torch['loss_ll'].item(), np.prod(x.shape[1:])) * bs

        # compile (properly weighted) loss terms
        for key in loss_dict_vals.keys():
            loss_dict_vals[key] /= batch_size
        # store hyperparams
        loss_dict_vals['beta'] = beta

        return loss_dict_vals
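
Because `accumulate_grad=True` only calls backward() on each chunk, the optimizer step is expected
to happen outside this method. A hypothetical training-loop sketch follows; the `model`,
`optimizer`, and `train_loader` names are illustrative and not part of the library.

# hypothetical outer loop: loss() accumulates gradients chunk by chunk,
# then a single optimizer step is taken for the whole batch
for data in train_loader:          # illustrative data loader
    optimizer.zero_grad()          # clear gradients from the previous batch
    loss_dict = model.loss(data, accumulate_grad=True)
    optimizer.step()               # one step after all chunks have contributed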
Example #5
    def loss(self, data, dataset=0, accumulate_grad=True, chunk_size=200):
        """Calculate ELBO loss for ConditionalVAE.

        The batch is split into chunks if it is larger than `chunk_size` to keep memory
        requirements low; gradients are accumulated across all chunks before a gradient step is
        taken.

        Parameters
        ----------
        data : :obj:`dict`
            batch of data; keys should include 'images' and 'labels' ('labels_sc' when using a
            conditional encoder), plus 'masks' if necessary
        dataset : :obj:`int`, optional
            used for session-specific io layers
        accumulate_grad : :obj:`bool`, optional
            accumulate gradient for training step
        chunk_size : :obj:`int`, optional
            batch is split into chunks of this size to keep memory requirements low

        Returns
        -------
        :obj:`dict`
            - 'loss' (:obj:`float`): full elbo
            - 'loss_ll' (:obj:`float`): log-likelihood portion of elbo
            - 'loss_kl' (:obj:`float`): kl portion of elbo
            - 'loss_mse' (:obj:`float`): mse (without gaussian constants)
            - 'beta' (:obj:`float`): weight in front of kl term

        """

        x = data['images'][0]
        y = data['labels'][0]
        m = data['masks'][0] if 'masks' in data else None
        if self.hparams['conditional_encoder']:
            # continuous labels transformed into 2d one-hot array as input to encoder
            y_2d = data['labels_sc'][0]
        else:
            y_2d = None
        beta = self.beta_vals[self.curr_epoch]

        batch_size = x.shape[0]
        n_chunks = int(np.ceil(batch_size / chunk_size))

        loss_val = 0
        loss_ll_val = 0
        loss_kl_val = 0
        loss_mse_val = 0
        for chunk in range(n_chunks):

            idx_beg = chunk * chunk_size
            idx_end = np.min([(chunk + 1) * chunk_size, batch_size])

            x_in = x[idx_beg:idx_end]
            y_in = y[idx_beg:idx_end]
            m_in = m[idx_beg:idx_end] if m is not None else None
            y_2d_in = y_2d[idx_beg:idx_end] if y_2d is not None else None
            x_hat, _, mu, logvar = self.forward(x_in,
                                                dataset=dataset,
                                                use_mean=False,
                                                labels=y_in,
                                                labels_2d=y_2d_in)

            # log-likelihood
            loss_ll = losses.gaussian_ll(x_in, x_hat, m_in)

            # kl
            loss_kl = losses.kl_div_to_std_normal(mu, logvar)

            # combine
            loss = -loss_ll + beta * loss_kl

            if accumulate_grad:
                loss.backward()

            # get loss value (weighted by batch size)
            loss_val += loss.item() * (idx_end - idx_beg)
            loss_ll_val += loss_ll.item() * (idx_end - idx_beg)
            loss_kl_val += loss_kl.item() * (idx_end - idx_beg)
            loss_mse_val += losses.gaussian_ll_to_mse(
                loss_ll.item(), np.prod(x.shape[1:])) * (idx_end - idx_beg)

        loss_val /= batch_size
        loss_ll_val /= batch_size
        loss_kl_val /= batch_size
        loss_mse_val /= batch_size

        loss_dict = {
            'loss': loss_val,
            'loss_ll': loss_ll_val,
            'loss_kl': loss_kl_val,
            'loss_mse': loss_mse_val,
            'beta': beta
        }

        return loss_dict
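
The KL term here comes from `losses.kl_div_to_std_normal(mu, logvar)`, the analytic KL divergence
between the approximate posterior N(mu, exp(logvar)) and a standard normal prior. A sketch of that
closed form follows, assuming a sum over latent dimensions and a mean over the batch; the
reduction actually used by the library is an assumption.

import torch

def kl_div_to_std_normal_sketch(mu, logvar):
    """KL( N(mu, exp(logvar)) || N(0, I) ), summed over latents, averaged over the batch."""
    # closed form per dimension: 0.5 * (exp(logvar) + mu^2 - 1 - logvar)
    kl = 0.5 * torch.sum(torch.exp(logvar) + mu ** 2 - 1.0 - logvar, dim=1)
    return torch.mean(kl)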