Ejemplo n.º 1
0
def test(model, epoch, batch_size, num_imp_samples, data_loader):
    """Evaluate ``model`` and print its ELBO and an importance-sampled log-likelihood.

    For each minibatch this accumulates ``model.loss(data)`` and draws
    ``num_imp_samples`` latent samples per data point from
    q(z|x) = N(mu, exp(0.5 * logvar)), which are handed to ``marginal`` to
    estimate the log-likelihood of the minibatch.

    Args:
        model: object with ``eval()`` and ``loss(data)`` returning
            ``(loss, mu, logvar)``.
        epoch: unused here; kept for interface compatibility.
        batch_size: unused here; kept for interface compatibility.
        num_imp_samples: number of importance samples per data point.
        data_loader: iterable of input tensors; must support ``len()``.
    """
    model.eval()
    test_loss = 0
    num_samples = 0
    marginals = 0

    with torch.no_grad():
        for i, data in enumerate(data_loader):
            data = data.to(device)
            loss, mu, logvar = model.loss(data)
            test_loss += loss.item()

            batchsize = data.shape[0]
            q_z = Normal(mu, torch.exp(0.5 * logvar))
            # sample() yields (num_imp_samples, batch, ...). Swap the leading
            # axes with transpose: viewing the sampled tensor directly as
            # (batch, num_imp_samples, -1) would reinterpret its memory and
            # give each data point samples drawn from other data points'
            # posteriors.
            zs = q_z.sample((num_imp_samples,)).transpose(0, 1).reshape(
                batchsize, num_imp_samples, -1)

            ll_batch = marginal(model, data, zs)
            marginals = marginals + torch.sum(ll_batch).item()
            print(
                f"Minibatch [{i}/{len(data_loader)}] elbo = {-loss.item()/batchsize}, Log Likelihood of minibatch  = {ll_batch.mean().item()}"
            )
            num_samples += batchsize

    print("\nELBO = {}, marginal probability = {}\n".format(
        -test_loss / num_samples, marginals / num_samples))
Ejemplo n.º 2
0
class TanhNormal:
    """
    Represent distribution of X where
        X ~ tanh(Z), where Z ~ N(mean, std)
    Note: this is not very numerically stable.
    Source: https://github.com/vitchyr/rlkit/blob/f136e140a57078c4f0f665051df74dffb1351f33/rlkit/torch/distributions.py
    """
    def __init__(self, normal_mean, normal_std, epsilon=1e-6):
        """
        :param normal_mean: Mean of the normal distribution
        :param normal_std: Std of the normal distribution
        :param epsilon: Numerical stability epsilon when computing log-prob.
        """
        self.normal_mean = normal_mean
        self.normal_std = normal_std
        self.normal = Normal(normal_mean, normal_std)
        self.epsilon = epsilon

    def sample_n(self, n):
        """
        Draw ``n`` samples (no reparameterization).

        :param n: number of samples, prepended as the leading dimension
        :return: tuple ``(tanh(z), z)``
        """
        # Distribution.sample((n,)) replaces the deprecated sample_n(n).
        z = self.normal.sample((n,))
        return torch.tanh(z), z

    def log_prob(self, value, pre_tanh_value=None):
        """
        Log-density of ``value`` under tanh(N(mean, std)), summed over the
        last dimension.

        :param value (torch.Tensor): some value, x. May be None when
            ``pre_tanh_value`` is given; it is then recomputed as tanh of it.
        :param pre_tanh_value (torch.Tensor): arctanh(x). Recomputed from
            ``value`` when None.
        :return: log-density with the last dimension kept at size 1
        """
        if pre_tanh_value is None:
            # arctanh(x) = 1/2 * (log(1 + x) - log(1 - x)); log1p keeps
            # precision for small |x|.
            pre_tanh_value = (torch.log1p(value) - torch.log1p(-value)) / 2.0

        if value is None:
            value = torch.tanh(pre_tanh_value)

        # Change of variables: log p(x) = log N(z) - log(1 - tanh(z)^2);
        # epsilon keeps the log finite when |x| rounds to 1.
        log_prob = self.normal.log_prob(pre_tanh_value) - torch.log(
            1 - value.pow(2) + self.epsilon)
        return log_prob.sum(-1, keepdim=True)

    def sample(self):
        """
        Gradients will and should not pass through this operation.
        See https://github.com/pytorch/pytorch/issues/4620 for discussion.
        """
        z = self.normal.sample().detach()
        return torch.tanh(z), z

    def rsample(self):
        """
        Sampling in the reparameterization case.
        """
        z = self.normal.rsample()
        z.requires_grad_()
        return torch.tanh(z), z
Ejemplo n.º 3
0
class TanhNormal(Distribution):
    """
    Represent distribution of X where
        X ~ tanh(Z)
        Z ~ N(mean, std)

    Note: this is not very numerically stable.
    """
    def __init__(self, normal_mean, normal_std, epsilon=1e-6):
        """
        :param normal_mean: Mean of the normal distribution
        :param normal_std: Std of the normal distribution
        :param epsilon: Numerical stability epsilon when computing log-prob.
        """
        self.normal_mean = normal_mean
        self.normal_std = normal_std
        self.normal = Normal(normal_mean, normal_std)
        self.epsilon = epsilon

    def sample_n(self, n, return_pre_tanh_value=False):
        """Draw ``n`` samples; optionally also return the pre-tanh values."""
        # Distribution.sample((n,)) replaces the deprecated sample_n(n).
        z = self.normal.sample((n,))
        if return_pre_tanh_value:
            return torch.tanh(z), z
        return torch.tanh(z)

    def log_prob(self, value, pre_tanh_value=None):
        """
        :param value: some value, x
        :param pre_tanh_value: arctanh(x); computed from ``value`` when None
        :return: elementwise log-density of ``value``
        """
        if pre_tanh_value is None:
            # arctanh(x) = 1/2 * (log(1 + x) - log(1 - x))
            pre_tanh_value = (torch.log1p(value) - torch.log1p(-value)) / 2
        # Change of variables; epsilon keeps the log finite near |x| = 1.
        return self.normal.log_prob(pre_tanh_value) - torch.log(
            1 - value * value + self.epsilon)

    def sample(self, return_pretanh_value=False):
        """Draw one sample (or batch); optionally also return the pre-tanh value."""
        z = self.normal.sample()
        if return_pretanh_value:
            return torch.tanh(z), z
        return torch.tanh(z)

    def rsample(self, return_pretanh_value=False):
        """
        Reparameterized sample: gradients flow to normal_mean and normal_std.
        """
        # Normal.rsample() computes mean + std * eps with eps ~ N(0, I) drawn
        # on the parameters' own device, replacing the deprecated Variable
        # wrapper and the ptu global-device helpers.
        z = self.normal.rsample()
        if return_pretanh_value:
            return torch.tanh(z), z
        return torch.tanh(z)
Ejemplo n.º 4
0
 def sample_action(self, state, batch_size=1):
     """Sample a tanh-squashed, scaled action and its elementwise log-density.

     NOTE(review): assumes ``self.Actor(state)`` returns ``(mu, log_std)``
     where ``log_std`` is a log standard deviation, as its name says; confirm
     against the Actor definition.

     Returns:
         ``(action, log_prob)``, both shaped ``(batch_size, action_size)``.
     """
     mu, log_std = self.Actor(state)
     std = log_std.exp()
     dst = Normal(0, 1)
     # Reparameterization: pre_tanh = mu + std * eps, eps ~ N(0, 1).
     eps = dst.sample((batch_size, self.action_size)).to(self.device)
     pre_tanh = mu + std * eps
     squashed = torch.tanh(pre_tanh)
     action = squashed * self.action_scale
     # Change of variables for a = scale * tanh(mu + std * eps):
     #   log pi(a) = log N(eps) - log std - log(scale * (1 - tanh^2))
     # 1e-6 keeps the log finite when tanh saturates.
     log_prob = (dst.log_prob(eps) - log_std
                 - torch.log(self.action_scale * (1 - squashed.pow(2)) + 1e-6))
     return action, log_prob
Ejemplo n.º 5
0
class Normal(Distribution):
    """Thin wrapper around ``PyNormal`` that adds a closed-form KL helper."""

    def __init__(self, mean, std):
        self.normal = PyNormal(mean, std)

    def mean(self):
        """Return the location parameter of the wrapped distribution."""
        return self.normal.mean

    def _std(self):
        """Return the scale parameter of the wrapped distribution.

        Current PyTorch exposes it as ``stddev``. The ``std`` attribute this
        class originally read only exists on older releases, so fall back to
        it when ``stddev`` is missing.
        """
        normal = self.normal
        return normal.stddev if hasattr(normal, 'stddev') else normal.std

    def params(self):
        """Return ``[mean, std]`` of the wrapped distribution."""
        return [self.normal.mean, self._std()]

    def sample(self):
        """
        Generates a single sample or single batch of samples if the distribution
        parameters are batched.
        """
        return self.normal.sample()

    def sample_n(self, n):
        """
        Generates n samples or n batches of samples if the distribution parameters
        are batched.
        """
        return self.normal.sample_n(n)

    def log_prob(self, value):
        """
        Returns the log of the probability density/mass function evaluated at
        `value`.

        Args:
            value (Tensor or Variable): point(s) at which to evaluate the density.
        """
        return self.normal.log_prob(value)

    def kl(self, other):
        """
        KL-divergence between two Normals: KL[N(u_i, s_i) || N(u_j, s_j)]
        where params_i = [u_i, s_i] and similarly for j.
        Returns a tensor with the dimensionality of the location variable.

        Raises:
            ValueError: if ``other`` is not a ``Normal``.
        """
        if not isinstance(other, Normal):
            raise ValueError('Impossible')
        location_i, scale_i = self.params()  # [mean, std]
        location_j, scale_j = other.params()  # [mean, std]
        var_i = scale_i**2.
        var_j = scale_j**2.
        # ((u_i - u_j)^2 + var_i) / (2 var_j) - 1/2, written as a single fraction
        term1 = 1. / (2. * var_j) * (
            (location_i - location_j)**2. + var_i - var_j)
        term2 = torch.log(scale_j) - torch.log(scale_i)
        return term1 + term2
Ejemplo n.º 6
0
class TanhNormal(Distribution):
    """
    Represent distribution of X where
        X ~ tanh(Z)
        Z ~ N(mean, std)
    """
    def __init__(self, normal_mean, normal_std, epsilon=1e-8):
        """
        :param normal_mean: Mean of the normal distribution
        :param normal_std: Std of the normal distribution
        :param epsilon: added inside log(1 - x^2) in log_prob to avoid log(0)
        """
        self.normal_mean = normal_mean
        self.normal_std = normal_std
        self.normal = Normal(normal_mean, normal_std)
        self.epsilon = epsilon

    def sample_n(self, n, return_pre_tanh_value=False):
        """Draw ``n`` samples; optionally also return the pre-tanh values."""
        # Distribution.sample((n,)) replaces the deprecated sample_n(n).
        z = self.normal.sample((n,))
        if return_pre_tanh_value:
            return torch.tanh(z), z
        return torch.tanh(z)

    def log_prob(self, value, pre_tanh_value=None):
        """
        Elementwise log-density of ``value``.

        :param value: some value, x
        :param pre_tanh_value: arctanh(x); computed from ``value`` when None
        """
        if pre_tanh_value is None:
            # arctanh(x) = 1/2 * (log(1 + x) - log(1 - x))
            pre_tanh_value = (torch.log1p(value) - torch.log1p(-value)) / 2

        return self.normal.log_prob(pre_tanh_value) - torch.log(
            1 - value * value + self.epsilon)

    def sample(self, return_pretanh_value=False):
        """Draw a sample that gradients do not flow through."""
        z = self.normal.sample().detach()

        if return_pretanh_value:
            return torch.tanh(z), z
        return torch.tanh(z)

    def rsample(self, return_pretanh_value=False):
        """Reparameterized sample: gradients flow to normal_mean and normal_std."""
        # Normal.rsample() is mean + std * eps with eps ~ N(0, I) drawn on the
        # parameters' device, so no separate standard Normal is built per call.
        z = self.normal.rsample()
        z.requires_grad_()

        if return_pretanh_value:
            return torch.tanh(z), z
        return torch.tanh(z)
Ejemplo n.º 7
0
class TanhNormal(Distribution):
    """
    Represent distribution of X where
        X ~ tanh(Z)
        Z ~ N(mean, std)
    Note: this is not very numerically stable.
    """
    def __init__(self, normal_mean, normal_std, epsilon=1e-6):
        """
        :param normal_mean: Mean of the normal distribution
        :param normal_std: Std of the normal distribution
        :param epsilon: Numerical stability epsilon when computing log-prob.
        """
        super(TanhNormal, self).__init__()
        self.normal = Normal(normal_mean, normal_std)
        self.epsilon = epsilon

    def sample_n(self, n, return_pre_tanh_value=False):
        """Draw ``n`` samples; optionally also return the pre-tanh values."""
        # Distribution.sample((n,)) replaces the deprecated sample_n(n).
        z = self.normal.sample((n,))
        if return_pre_tanh_value:
            return torch.tanh(z), z
        return torch.tanh(z)

    def log_prob(self, value, pre_tanh_value=None):
        """
        Elementwise log-density of ``value``.

        :param value: some value, x
        :param pre_tanh_value: arctanh(x); computed from ``value`` when None
        :return: log-density tensor
        """
        if pre_tanh_value is None:
            # arctanh(x) = 1/2 * (log(1 + x) - log(1 - x))
            pre_tanh_value = (torch.log1p(value) - torch.log1p(-value)) / 2
        return self.normal.log_prob(pre_tanh_value) - torch.log(
            1 - value * value + self.epsilon)

    def sample(self, return_pre_tanh_value=False):
        """Draw one sample (or batch); optionally also return the pre-tanh value."""
        z = self.normal.sample()
        if return_pre_tanh_value:
            return torch.tanh(z), z
        return torch.tanh(z)
Ejemplo n.º 8
0
class TanhNormal(Distribution):
    """
    Represent distribution of X where
        X ~ tanh(Z)
        Z ~ N(mean, std)
    Note: this is not very numerically stable.
    """

    def __init__(self, normal_mean, normal_std, epsilon=1e-6):
        """
        :param normal_mean: Mean of the normal distribution
        :param normal_std: Std of the normal distribution
        :param epsilon: Numerical stability epsilon when computing log-prob.
        """
        self.normal_mean = normal_mean
        self.normal_std = normal_std
        self.normal = Normal(normal_mean, normal_std)
        self.epsilon = epsilon

    def sample_n(self, n, return_pre_tanh_value=False):
        """Draw ``n`` samples; optionally also return the pre-tanh values."""
        # Distribution.sample((n,)) replaces the deprecated sample_n(n).
        z = self.normal.sample((n,))
        if return_pre_tanh_value:
            return torch.tanh(z), z
        return torch.tanh(z)

    def log_prob(self, value, pre_tanh_value=None):
        """
        Elementwise log-density of ``value``.

        :param value: some value, x
        :param pre_tanh_value: arctanh(x); computed from ``value`` when None
        :return: log-density tensor
        """
        if pre_tanh_value is None:
            # arctanh(x) = 1/2 * (log(1 + x) - log(1 - x))
            pre_tanh_value = (torch.log1p(value) - torch.log1p(-value)) / 2
        return self.normal.log_prob(pre_tanh_value) - torch.log(
            1 - value * value + self.epsilon
        )

    def sample(self, return_pretanh_value=False):
        """
        Gradients will and should *not* pass through this operation.
        See https://github.com/pytorch/pytorch/issues/4620 for discussion.
        """
        z = self.normal.sample().detach()

        if return_pretanh_value:
            return torch.tanh(z), z
        return torch.tanh(z)

    def rsample(self, return_pretanh_value=False):
        """
        Sampling in the reparameterization case.
        """
        # Normal.rsample() is mean + std * eps with eps ~ N(0, I) drawn on the
        # parameters' own device, so no ptu global-device helpers are needed.
        z = self.normal.rsample()
        z.requires_grad_()

        if return_pretanh_value:
            return torch.tanh(z), z
        return torch.tanh(z)
Ejemplo n.º 9
0
class TanhNormal(Distribution):
    """
    [[Source]](https://github.com/seba-1511/cherry/blob/master/cherry/models/tabular.py)

    **Description**

    Implements a Normal distribution followed by a Tanh, often used with the Soft Actor-Critic.

    This implementation also exposes `sample_and_log_prob` and `rsample_and_log_prob`,
    which returns both samples and log-densities.
    The log-densities are computed using the pre-activation values for numerical stability.

    **Credit**

    Adapted from Vitchyr Pong's RLkit:
    https://github.com/vitchyr/rlkit/blob/master/rlkit/torch/distributions.py

    **References**

    1. Haarnoja et al. 2018. “Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor.” arXiv [cs.LG].
    2. Haarnoja et al. 2018. “Soft Actor-Critic Algorithms and Applications.” arXiv [cs.LG].

    **Arguments**

    * **normal_mean** (tensor) - Mean of the Normal distribution.
    * **normal_std** (tensor) - Standard deviation of the Normal distribution.

    **Example**
    ~~~python
    mean = th.zeros(5)
    std = th.ones(5)
    dist = TanhNormal(mean, std)
    samples = dist.rsample()
    logprobs = dist.log_prob(samples)  # Numerically unstable :(
    samples, logprobs = dist.rsample_and_log_prob()  # Stable :)
    ~~~

    """
    # log(2), used by the pre-activation form of log(1 - tanh(z)^2).
    _LOG_2 = 0.6931471805599453

    def __init__(self, normal_mean, normal_std):
        self.normal_mean = normal_mean
        self.normal_std = normal_std
        self.normal = Normal(normal_mean, normal_std)

    def _log_abs_det_jacobian(self, pre_tanh_value):
        """Return log(1 - tanh(z)^2), evaluated from the pre-activation ``z``.

        Uses the identity log(1 - tanh(z)^2) = 2 * (log 2 - z - softplus(-2z)).
        It stays finite even where tanh(z) rounds to +-1, unlike
        log(1 - tanh(z)^2 + eps).
        """
        return 2.0 * (self._LOG_2 - pre_tanh_value
                      - th.nn.functional.softplus(-2.0 * pre_tanh_value))

    def sample_n(self, n):
        """Draw ``n`` squashed samples (leading dimension ``n``)."""
        # Distribution.sample((n,)) replaces the deprecated sample_n(n).
        z = self.normal.sample((n,))
        return th.tanh(z)

    def log_prob(self, value):
        """Elementwise log-density of ``value``; unstable as |value| -> 1."""
        pre_tanh_value = (th.log1p(value) - th.log1p(-value)).mul(0.5)
        offset = th.log1p(-value**2 + 1e-6)
        return self.normal.log_prob(pre_tanh_value) - offset

    def sample(self):
        """Draw a squashed sample that gradients do not flow through."""
        z = self.normal.sample().detach()
        return th.tanh(z)

    def sample_and_log_prob(self):
        """Return ``(value, log_prob)`` for a non-reparameterized sample."""
        z = self.normal.sample().detach()
        value = th.tanh(z)
        log_prob = self.normal.log_prob(z) - self._log_abs_det_jacobian(z)
        return value, log_prob

    def rsample_and_log_prob(self):
        """Return ``(value, log_prob)`` for a reparameterized sample."""
        z = self.normal.rsample()
        value = th.tanh(z)
        log_prob = self.normal.log_prob(z) - self._log_abs_det_jacobian(z)
        return value, log_prob

    def rsample(self):
        """Draw a reparameterized squashed sample."""
        z = self.normal.rsample()
        z.requires_grad_()
        return th.tanh(z)
Ejemplo n.º 10
0
class VRNN(nn.Module):
    """Variational-RNN caption model built around a single LSTM.

    At every step the LSTM consumes a word embedding concatenated with a
    latent vector ``z``. At VRNN steps ``z`` is sampled from the approximate
    posterior q(z | x_t, h_{t-1}) with the reparameterization trick. At
    deterministic steps ``z`` is a zero vector.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, latent_size,
                 num_layers_lstm):
        """Set the hyper-parameters and build the layers."""
        super(VRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size + latent_size,
                            hidden_size,
                            num_layers_lstm,
                            batch_first=True)
        # q(z|x, h): output is split into two halves of size latent_size
        self.q_z = nn.Linear(embed_size + hidden_size, latent_size * 2)
        # p(z|h): output is split into two halves of size latent_size
        self.prior = nn.Linear(hidden_size, latent_size * 2)
        # gaussian noise generator for re-parameterization trick
        self.normal = Normal(torch.zeros(latent_size, ),
                             torch.ones(latent_size, ))
        # q(x|z) backwards inference
        self.q_x = nn.Linear(latent_size + hidden_size, vocab_size)
        # deterministically computed distribution over dictionary output (vanilla output)
        self.det_x = nn.Linear(hidden_size, vocab_size)
        self.init_weights()

    def init_weights(self):
        """Initialize weights uniformly in [-0.1, 0.1] and biases to zero."""
        self.embed.weight.data.uniform_(-0.1, 0.1)
        for layer in (self.q_z, self.prior, self.q_x, self.det_x):
            layer.weight.data.uniform_(-0.1, 0.1)
            layer.bias.data.fill_(0)

    def _z_padding(self, batch_size):
        """Return a (batch_size, latent_size) zero latent for deterministic steps."""
        # q_z emits two latent_size halves; integer division keeps the size an
        # int (true division gives a float, which torch.zeros rejects).
        return to_var(torch.zeros(batch_size, self.q_z.out_features // 2))

    def _sample_z(self, mu, sigma):
        """Reparameterized draw ``eps * sigma + mu`` with eps ~ N(0, I)."""
        # Distribution.sample((n,)) replaces the deprecated sample_n(n).
        noise = to_var(self.normal.sample((mu.shape[0],)))
        return noise * sigma + mu

    def forward(self,
                features,
                captions,
                lengths,
                states=None,
                z_0=None,
                z_step='all'):
        """Decode image feature vectors and generate captions.

        Conducts a forward pass and outputs the prior, posterior and inference
        distributions.

        Args:
            features: image feature vectors fed at the initial LSTM step.
            captions: ground-truth token ids, indexed as ``captions[:, i]``.
            lengths: caption lengths; ``max(lengths)`` steps are run.
            states: optional initial LSTM states.
            z_0: latent concatenated to ``features``; zeros when None.
            z_step: 'all' to run the VRNN cell at every step, or an int t to
                run it only at step t and compute deterministic outputs for
                all following steps.

        Returns:
            ``((p_mu, p_sigma), (q_mu, q_sigma), q_x, det_xs)``.

        NOTE(review): the tail keeps only the first VRNN step's outputs and
        packs ``det_xs`` with ``l - z_step - 1`` lengths, so only an int
        ``z_step`` completes; with 'all', ``det_xs`` is empty and
        ``torch.stack`` raises.
        """
        z_padding = self._z_padding(features.shape[0])
        if z_0 is None:
            z_0 = z_padding
        features = torch.cat([features, z_0], dim=1).unsqueeze(1)

        # conduct initial lstm step to embed image features in internal states
        h, states = self.lstm(features, states)
        h = h.squeeze(1)
        embeddings = self.embed(captions)

        p_mus, p_sigmas, q_mus, q_sigmas, q_xs, det_xs = [], [], [], [], [], []
        for i in range(max(lengths)):
            if z_step == 'all' or i == z_step:
                # VRNN cell: prior from h_{t-1}, posterior from (x_t, h_{t-1})
                p_mu, p_sigma = self.prior(h).chunk(2, dim=1)
                q_mu, q_sigma = self.q_z(
                    torch.cat([embeddings[:, i], h], dim=1)).chunk(2, dim=1)
                # one reparameterized sample per batch element
                z = self._sample_z(q_mu, q_sigma)
                # q_x from z_t and h_{t-1}
                q_x = self.q_x(torch.cat([z, h], dim=1))
                q_x = nn.functional.log_softmax(q_x, dim=1)
                # lstm step to get h_t from x_t, z_t, h_{t-1}
                inputs = torch.cat([embeddings[:, i], z], dim=1).unsqueeze(1)
                h, states = self.lstm(inputs, states)
                h = h.squeeze(1)

                p_mus.append(p_mu)
                p_sigmas.append(p_sigma)
                q_mus.append(q_mu)
                q_sigmas.append(q_sigma)
                q_xs.append(q_x)

            else:
                # deterministic LSTM cell with a zero latent
                inputs = torch.cat([embeddings[:, i], z_padding],
                                   dim=1).unsqueeze(1)
                h, states = self.lstm(inputs, states)
                h = h.squeeze(1)

                # after the VRNN step, compute distribution over dictionary
                if i > z_step:
                    det_xs.append(self.det_x(h))

        [p_mus, p_sigmas, q_mus, q_sigmas,
         q_xs] = [t[0] for t in [p_mus, p_sigmas, q_mus, q_sigmas, q_xs]]
        # pack to mask out padding and flatten batch + step dimensions
        det_xs = pack_padded_sequence(torch.stack(det_xs, dim=1),
                                      [l - z_step - 1 for l in lengths],
                                      batch_first=True)[0]

        return (p_mus, p_sigmas), (q_mus, q_sigmas), q_xs, det_xs

    def encode(self, features, ground_truth, z_step, states=None, z_0=None):
        """Encode the first ``z_step`` ground truths, then run the VRNN cell at ``z_step``.

        The resulting q(z_t|x_t, h_{t-1}) and p(z_t|h_{t-1}), together with
        x_t and h_{t-1}, define a distribution over h_t. That in turn defines
        a distribution over decoded captions that begin with the given
        ground truths.

        Returns:
            ``((p_mu, p_sigma), (q_mu, q_sigma), states, embeddings[:, z_step])``.
        """
        z_padding = self._z_padding(features.shape[0])
        if z_0 is None:
            z_0 = z_padding
        features = torch.cat([features, z_0], dim=1).unsqueeze(1)

        # Initial lstm step embeds the image features. Keep its output so h is
        # defined even when z_step == 0 and the loop below does not run.
        h, states = self.lstm(features, states)
        h = h.squeeze(1)
        embeddings = self.embed(ground_truth)

        # deterministic forward pass for all steps before z_step
        for i in range(z_step):
            inputs = torch.cat([embeddings[:, i], z_padding],
                               dim=1).unsqueeze(1)
            h, states = self.lstm(inputs, states)
            h = h.squeeze(1)

        # VRNN cell during z_step: prior from h_{t-1}, posterior from (x_t, h_{t-1})
        p_mu, p_sigma = self.prior(h).chunk(2, dim=1)
        q_mu, q_sigma = self.q_z(torch.cat([embeddings[:, z_step], h],
                                           dim=1)).chunk(2, dim=1)

        return (p_mu, p_sigma), (q_mu, q_sigma), states, embeddings[:, z_step]

    def decode(self, states, x, z_step, max_length=20):
        """Greedily decode the remainder of a caption.

        Args:
            states: LSTM states ``(h, c)``, each [layers X batch X hidden].
            x: [batch X embed] embedding of the previous token.
            z_step: step at which the latent was injected.
            max_length: total caption length budget; ``max_length - z_step - 1``
                tokens are generated.

        Returns:
            Tensor [batch X steps] of argmax token ids.
        """
        z_padding = self._z_padding(x.shape[0])
        caption = []
        for _ in range(max_length - z_step - 1):
            inputs = torch.cat([x, z_padding], dim=1).unsqueeze(1)
            h, states = self.lstm(inputs, states)
            h = h.squeeze(1)

            x = self.det_x(h)  # [batch X dictionary]
            x = x.max(1)[1]  # [batch]

            caption.append(x.data)

            x = self.embed(x)

        return torch.stack(caption, dim=1)

    def sample(self, features, ground_truth, states=None, z_0=None, z_step=3):
        """Sample a latent at ``z_step`` and decode the rest of the caption.

        Encodes up to ``z_step``, samples z from both the prior and the
        posterior, and decodes one caption continuation for each.

        Returns:
            ``(captions_p, captions_q)`` decoded from the prior and posterior
            samples respectively.
        """
        # get prior and posterior densities at z_step
        (p_mu, p_sigma), (q_mu, q_sigma), states, x = self.encode(
            features, ground_truth, z_step, states=states, z_0=z_0)

        # sample to get z
        z_p = self._sample_z(p_mu, p_sigma)
        z_q = self._sample_z(q_mu, q_sigma)

        # get hidden states to give to decoder
        inputs_p = torch.cat([x, z_p], dim=1).unsqueeze(1)
        inputs_q = torch.cat([x, z_q], dim=1).unsqueeze(1)

        h_p, states_p = self.lstm(inputs_p, states)
        h_q, states_q = self.lstm(inputs_q, states)
        h_p, h_q = h_p.squeeze(1), h_q.squeeze(1)

        # get predicted embedding of next step from h
        x_p, x_q = self.det_x(h_p), self.det_x(h_q)  # [batch X dictionary]
        x_p, x_q = x_p.max(1)[1], x_q.max(1)[1]  # [batch]
        x_p, x_q = self.embed(x_p), self.embed(x_q)  # [batch X embed]

        # decode from z, x, hidden states the remaining caption
        captions_p = self.decode(states_p, x_p, z_step)
        captions_q = self.decode(states_q, x_q, z_step)

        return captions_p, captions_q
Ejemplo n.º 11
0
class TanhNormal(Distribution):
    """
    Represent distribution of X where
        X ~ tanh(Z)
        Z ~ N(mean, std)

    Note: this is not very numerically stable.
    """
    def __init__(self, normal_mean, normal_std, epsilon=1e-6):
        """

        Args:
            normal_mean (Tensor): Mean of the normal distribution
            normal_std (Tensor): Std of the normal distribution
            epsilon (Double): Numerical stability epsilon when computing
                log-prob.
        """
        super(TanhNormal, self).__init__()
        self._normal_mean = normal_mean
        self._normal_std = normal_std
        self._normal = Normal(normal_mean, normal_std)
        self._epsilon = epsilon

    @property
    def mean(self):
        """Mean of the underlying (pre-tanh) Normal, not of tanh(Z)."""
        return self._normal.mean

    @property
    def variance(self):
        """Variance of the underlying (pre-tanh) Normal."""
        return self._normal.variance

    @property
    def stddev(self):
        """Standard deviation of the underlying (pre-tanh) Normal."""
        return self._normal.stddev

    @property
    def epsilon(self):
        return self._epsilon

    @staticmethod
    def _atanh(value):
        """arctanh(x) = 1/2 * (log(1 + x) - log(1 - x)), via log1p for precision."""
        return (torch.log1p(value) - torch.log1p(-value)) / 2

    def sample(self, return_pretanh_value=False):
        """Draw a sample that gradients do not flow through."""
        z = self._normal.sample().detach()
        if return_pretanh_value:
            return torch.tanh(z), z
        return torch.tanh(z)

    def rsample(self, return_pretanh_value=False):
        """Draw a reparameterized sample; gradients flow to mean and std."""
        z = self._normal.rsample()
        if return_pretanh_value:
            return torch.tanh(z), z
        return torch.tanh(z)

    def sample_n(self, n, return_pre_tanh_value=False):
        """Draw ``n`` samples; optionally also return the pre-tanh values."""
        # Distribution.sample((n,)) replaces the deprecated sample_n(n).
        z = self._normal.sample((n,))
        if return_pre_tanh_value:
            return torch.tanh(z), z
        return torch.tanh(z)

    def log_prob(self, value, pre_tanh_value=None):
        """
        Returns the log of the probability density function evaluated at
        `value`.

        Args:
            value (Tensor): point(s) in (-1, 1) at which to evaluate.
            pre_tanh_value (Tensor): arctanh(value); computed from ``value``
                when None.

        Returns:
            log_prob (Tensor)

        """
        if pre_tanh_value is None:
            pre_tanh_value = self._atanh(value)

        return self._normal.log_prob(pre_tanh_value) - \
            torch.log(1. - value * value + self._epsilon)

    def cdf(self, value, pre_tanh_value=None):
        """P(X <= value): the Normal CDF at arctanh(value), since tanh is increasing."""
        if pre_tanh_value is None:
            pre_tanh_value = self._atanh(value)
        return self._normal.cdf(pre_tanh_value)
Ejemplo n.º 12
0
    # KL-divergence between a diagonal multivariate normal,
    # and a standard normal distribution (with zero mean and unit variance)
    # In other words, we are punishing it if it's distribution moves away from a standard normal dist
    KLD = -0.5 * torch.sum(1 + dist.scale.log() - dist.loc.pow(2) - dist.scale)
    return BCE + KLD


# +
# You can try the KLD here with different distributions
p = Normal(loc=1, scale=2)
q = Normal(loc=0, scale=1)
kld = torch.distributions.kl.kl_divergence(p, q)

# plot the distributions from 10,000 samples each
# (Distribution.sample((n,)) replaces the deprecated sample_n(n))
ps = p.sample((10000,)).numpy()
qs = q.sample((10000,)).numpy()

sns.kdeplot(ps, label='p')
sns.kdeplot(qs, label='q')
plt.title(f"KLD(p|q) = {kld:2.2f}\nKLD({p}|{q})")
plt.legend()
plt.show()
# -

# ## Exercise 1: KLD
#
# Re-run the cell above while changing q.
#
# - Use the code above to test whether the KLD is higher for distributions that overlap more.
#
 def sample_n(self, n):
     """Draw ``n`` samples from the base Normal and squash them into (0, 1)."""
     return Normal.sample_n(self, n).sigmoid()
 def sample_n(self, n):
     """Draw ``n`` samples from the base Normal and exponentiate them."""
     return Normal.sample_n(self, n).exp()
Ejemplo n.º 15
0
import matplotlib.pyplot as plt
import torch
from torch.distributions import Normal

from mmvae_hub.experiment_vis.utils import load_experiment

exp_uid = 'polymnist_iwmogfm_multiloss_2021_09_28_11_08_24_742311'
exp = load_experiment(_id=exp_uid)

exp.set_eval_mode()
num_samples = 10
# Isotropic Gaussian with variance 1/2 (std = sqrt(1/2)) over class_dim dims.
Gf = Normal(
    torch.zeros(exp.flags.class_dim, device=exp.flags.device),
    torch.tensor(1 / 2).sqrt() *
    torch.ones(exp.flags.class_dim, device=exp.flags.device))
# Distribution.sample((n,)) replaces the deprecated sample_n(n).
z_Gf = Gf.sample((num_samples,))

# presumably the inverse pass of the flow, mapping samples to class embeddings
zss, log_det_J = exp.mm_vae.flow.rev(z_Gf)

rec_mods = {}
for mod_key, mod in exp.mm_vae.modalities.items():
    rec_mods[mod_key] = mod.calc_likelihood(None, class_embeddings=zss).mean

# One figure per sample, one subplot per modality.
# NOTE(review): plt.show() is never called; figures only appear in an
# interactive backend.
for k in range(num_samples):
    plt.figure()
    for col, mod_key in enumerate(('m0', 'm1', 'm2'), start=1):
        plt.subplot(1, 3, col)
        plt.imshow(rec_mods[mod_key][k].detach().cpu().numpy().swapaxes(0, -1))
Ejemplo n.º 16
0
class TanhNormal(Distribution):
    """
    Represent distribution of X where
        X ~ tanh(Z)
        Z ~ N(mean, std)

    Note: this is not very numerically stable.
    """
    # log(2) as a Python float: broadcasts with any tensor on any device
    # without building a new tensor per log_prob call.
    _LOG_2 = 0.6931471805599453

    def __init__(self, normal_mean, normal_std, epsilon=1e-6):
        """
        :param normal_mean: Mean of the normal distribution
        :param normal_std: Std of the normal distribution
        :param epsilon: Numerical stability epsilon when computing log-prob.
        """
        self.normal_mean = normal_mean
        self.normal_std = normal_std
        self.normal = Normal(normal_mean, normal_std)
        self.epsilon = epsilon

    def sample_n(self, n, return_pre_tanh_value=False):
        """Draw ``n`` samples; optionally also return the pre-tanh values."""
        # Distribution.sample((n,)) replaces the deprecated sample_n(n).
        z = self.normal.sample((n,))
        if return_pre_tanh_value:
            return torch.tanh(z), z
        return torch.tanh(z)

    def log_prob(self, value, pre_tanh_value=None):
        """
        Adapted from
        https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/bijectors/tanh.py#L73

        This formula is mathematically equivalent to log(1 - tanh(x)^2).

        Derivation:

        log(1 - tanh(x)^2)
         = log(sech(x)^2)
         = 2 * log(sech(x))
         = 2 * log(2e^-x / (e^-2x + 1))
         = 2 * (log(2) - x - log(e^-2x + 1))
         = 2 * (log(2) - x - softplus(-2x))

        :param value: some value, x
        :param pre_tanh_value: arctanh(x); computed from ``value`` when None
        :return: elementwise log-density
        """
        if pre_tanh_value is None:
            # Keep arctanh finite at +-1.
            value = torch.clamp(value, -0.999999, 0.999999)
            # arctanh(x) = 1/2 * (log(1 + x) - log(1 - x))
            pre_tanh_value = (torch.log1p(value) - torch.log1p(-value)) / 2
        return self.normal.log_prob(pre_tanh_value) - 2. * (
            self._LOG_2 - pre_tanh_value -
            torch.nn.functional.softplus(-2. * pre_tanh_value))

    def sample(self, return_pretanh_value=False):
        """
        Gradients will and should *not* pass through this operation.

        See https://github.com/pytorch/pytorch/issues/4620 for discussion.
        """
        z = self.normal.sample().detach()

        if return_pretanh_value:
            return torch.tanh(z), z
        return torch.tanh(z)

    def rsample(self, return_pretanh_value=False):
        """
        Sampling in the reparameterization case.
        """
        # Normal.rsample() is mean + std * eps with eps ~ N(0, I) drawn on the
        # parameters' own device, so no ptu global-device helpers are needed.
        z = self.normal.rsample()
        z.requires_grad_()

        if return_pretanh_value:
            return torch.tanh(z), z
        return torch.tanh(z)