Exemplo n.º 1
0
    def forward(self, x):
        """Generate a message symbol by symbol with a Gumbel-Softmax RNN.

        ``self.agent(x)`` provides the initial hidden state. At each of
        ``self.max_len`` steps the recurrent cell consumes the embedding of the
        previously emitted symbol (``self.sos_embedding`` at the first step).
        The next symbol comes from a RelaxedOneHotCategorical over
        ``log_softmax(self.hidden_to_output(h))``. In training mode it is a
        reparameterised relaxed sample; otherwise it is the hard one-hot argmax.

        Returns:
            Tensor of shape (batch, max_len, vocab). When ``self.force_eos`` is
            set, one extra step is appended whose symbol 0 is 1.
        """
        hidden = self.agent(x)
        cell_state = torch.zeros_like(hidden)  # only consumed by an LSTM cell
        uses_lstm = isinstance(self.cell, nn.LSTMCell)

        step_input = torch.stack([self.sos_embedding] * hidden.size(0))
        symbols = []

        for _ in range(self.max_len):
            if uses_lstm:
                hidden, cell_state = self.cell(step_input, (hidden, cell_state))
            else:
                hidden = self.cell(step_input, hidden)

            log_probs = F.log_softmax(self.hidden_to_output(hidden), dim=1)
            relaxed = RelaxedOneHotCategorical(logits=log_probs,
                                               temperature=self.temperature)

            if not self.training:
                # Deterministic decoding: hard one-hot of the most likely symbol.
                best = log_probs.argmax(dim=-1, keepdim=True)
                symbol = torch.zeros_like(log_probs).scatter_(-1, best, 1.0)
            else:
                symbol = relaxed.rsample()

            symbols.append(symbol)
            step_input = self.embedding(symbol)

        message = torch.stack(symbols, dim=1)  # (batch, max_len, vocab)

        if self.force_eos:
            eos = message.new_zeros(message.size(0), 1, message.size(2))
            eos[:, 0, 0] = 1
            message = torch.cat([message, eos], dim=1)

        return message
Exemplo n.º 2
0
    def multi_dist_sample(self, logits, temp, return_only_action):
        """Sample one discrete choice per action from relaxed categoricals.

        The columns of ``logits`` are split into ``self.num_actions``
        consecutive chunks of width ``self.num_heads``. Each chunk
        parameterises a RelaxedOneHotCategorical at temperature ``temp``, and
        the argmax of its relaxed sample is that action's choice.

        When ``action`` has exactly one row, epsilon-random perturbation is
        applied to it and ``self.epsilon`` is decayed (this mutates ``self``).

        Args:
            logits: presumably (batch, num_heads * num_actions) — TODO confirm.
            temp: relaxation temperature.
            return_only_action: if truthy, return only ``action`` and skip
                the log-probability computation.

        Returns:
            ``action`` (float tensor, one column per action), or
            ``(action, log_prob)`` where ``log_prob`` is the mean over actions
            of the collected log-probabilities, unsqueezed to a column.
        """
        action = []
        log_prob = []
        # One relaxed categorical per action, each over `num_heads` logits.
        for i in range(0, self.num_heads * self.num_actions, self.num_heads):
            dist = RelaxedOneHotCategorical(logits=logits[:, i:i +
                                                          self.num_heads],
                                            temperature=temp)
            raw_action = dist.rsample()
            # Discrete choice = index of the largest relaxed component.
            action.append(raw_action.argmax(1).unsqueeze(1))

            if not return_only_action:
                # NOTE(review): `.t()[:, 0:1]` requires log_prob to be 2-D
                # (e.g. when `temp` broadcasts against the batch); on a 1-D
                # log_prob this indexing raises — confirm expected shape.
                log_prob.append(dist.log_prob(raw_action).t()[:, 0:1])

        action = torch.cat(action, 1).float()

        # Epsilon exploration, only applied when there is a single row.
        if len(action) == 1 and random.random() < self.epsilon:
            for _ in range(4):
                # NOTE(review): random.randint includes its upper bound, and the
                # bound is num_actions / num_heads rather than the number of
                # action columns (num_actions). With num_heads == 1 this can
                # index one past the last column — confirm intended range.
                action_ind = random.randint(
                    0, int(self.num_actions / self.num_heads))
                action[0, action_ind] = 0.0 if random.random() < 0.5 else 1.0

            if self.epsilon > self.epsilon_end:
                self.epsilon -= self.epsilon_decay_rate

        #action -= 1 #Translate 0,1,2 to -1,0,1
        if return_only_action: return action

        # Average log-probability over actions, as a column vector.
        log_prob = torch.cat(log_prob, 1).mean(1).unsqueeze(1)
        return action, log_prob
Exemplo n.º 3
0
def gumbel_softmax_sample(logits: torch.Tensor,
                          temperature: float = 1.0,
                          training: bool = True,
                          straight_through: bool = False):
    """Draw a Gumbel-Softmax (Concrete) sample from ``logits``.

    Args:
        logits: unnormalised log-probabilities; the last dimension indexes
            categories and any leading shape is allowed.
        temperature: relaxation temperature of the Concrete distribution.
        training: when False, skip sampling and return the hard one-hot
            argmax of ``logits``.
        straight_through: when True (training only), the forward value is the
            hard one-hot of the relaxed sample while gradients flow through
            the relaxed sample.

    Returns:
        A tensor with the same shape as ``logits``.
    """

    def _hard_one_hot(scores):
        # One-hot of the argmax over the last dim, built on a 2-D view.
        shape = scores.size()
        winners = scores.argmax(dim=-1).view(-1, 1)
        flat = torch.zeros_like(scores).view(-1, shape[-1])
        flat.scatter_(1, winners, 1)
        return flat.view(*shape)

    if not training:
        return _hard_one_hot(logits)

    relaxed = RelaxedOneHotCategorical(logits=logits,
                                       temperature=temperature).rsample()
    if not straight_through:
        return relaxed

    hard = _hard_one_hot(relaxed)
    # Straight-through estimator: forward value is `hard`, gradient is `relaxed`'s.
    return relaxed + (hard - relaxed).detach()
Exemplo n.º 4
0
    def forward(self, input, args, n_particles, test=False):
        """Compute the summed per-step loss (NLL + KL) of the latent-state model.

        n_particles is interpreted as 1 for now to not screw anything up: the
        argument is accepted for interface compatibility but overridden.

        Args:
            input: token ids; unpacked as (seq_len, batch_sz).
            args: namespace; only ``args.temp`` (Concrete temperature) is read.
            n_particles: ignored (forced to 1).
            test: unused.

        Returns:
            (summed loss, summed per-step NLL, number of scored positions, 0)
        """
        n_particles = 1
        T = nn.Softmax(dim=0)(self.T)  # NOTE: not in log-space
        pi = nn.Softmax(dim=0)(self.pi)
        emit = self.calc_emit()

        # run the input through the embedding layer and the encoder
        seq_len, batch_sz = input.size()
        emb = self.inp_embedding(input)
        hidden = self.init_hidden(batch_sz, self.nhid, 2)  # bidirectional
        hidden_states, (_, _) = self.encoder(emb, hidden)
        hidden_states = hidden_states.repeat(1, n_particles, 1)

        # Per-step NLLs, recorded without gradient for reporting. Every row is
        # overwritten in the loop, so uninitialised storage is fine.
        nlls = hidden_states.new_empty(seq_len, batch_sz * n_particles)
        loss = 0

        # Prior over the latent state, in probability space (not log-space).
        prior_probs = pi.unsqueeze(0).expand(batch_sz * n_particles,
                                             self.z_dim)

        # Temperature on the same device/dtype as the activations, instead of
        # a hard-coded .cuda(), so the model also runs on CPU or another GPU.
        temperature = hidden_states.new_tensor([args.temp])

        for i in range(seq_len):
            logits = self.logits(hidden_states[i])

            # build the next z sample (.sample(): not reparameterised)
            q = RelaxedOneHotCategorical(temperature=temperature,
                                         logits=logits)
            z = q.sample()

            # normalise logits into log-probabilities
            lse = log_sum_exp(logits, dim=1).view(-1, 1)
            log_probs = logits - lse

            # now, compute the log-likelihood of the data given this z-sample
            # so emission is [batch_sz x z_dim], i.e. emission[i, j] is the log-probability of getting this
            # data for element i given choice z
            emission = F.embedding(input[i].repeat(n_particles), emit)

            NLL = -log_sum_exp(emission + log_probs, 1)
            nlls[i] = NLL.detach()
            # Categorical KL(q || prior); 1e-16 guards log(0).
            KL = (log_probs.exp() * (log_probs -
                                     (prior_probs + 1e-16).log())).sum(1)
            loss += (NLL + KL)

            if i != seq_len - 1:
                # next prior: prior_probs[b, j] = sum_k T[j, k] * z[b, k]
                prior_probs = (T.unsqueeze(0) * z.unsqueeze(1)).sum(2)

        # now, we calculate the final log-marginal estimator
        return loss.sum(), nlls.sum(), (seq_len * batch_sz * n_particles), 0
Exemplo n.º 5
0
 def reparameterize(self, p):
     """Draw a one-hot-like latent sample from category probabilities ``p``.

     Training: a reparameterised sample from a relaxed (Gumbel-Softmax /
     Concrete) distribution at temperature ``TEMPERATURE``. The samples are
     continuous and approach one-hot vectors as the temperature is
     *decreased* towards 0.

     Evaluation: an exact one-hot sample from the categorical distribution.
     """
     if self.training:
         relaxed = RelaxedOneHotCategorical(TEMPERATURE, probs=p)
         return relaxed.rsample()
     # At test time draw a hard one-hot sample from the categorical.
     return OneHotCategorical(probs=p).sample()
    def noisy_action(self, obs, return_only_action=True):
        """Sample an exploratory discrete action through a relaxed categorical.

        ``self.clean_action(obs, return_only_action=False)`` supplies the
        log-temperature and the logits. A RelaxedOneHotCategorical with probs
        softmax(logits) and temperature exp(log_temp) is then sampled with
        rsample.

        Returns:
            The argmax along dim 1 of the relaxed sample. If
            ``return_only_action`` is False, also returns the per-sample
            log-probability (diagonal of the log_prob result, as a column)
            and the raw logits.
        """
        _, log_temp, logits = self.clean_action(obs, return_only_action=False)
        relaxed = RelaxedOneHotCategorical(temperature=log_temp.exp(),
                                           probs=F.softmax(logits, dim=1))
        sample = relaxed.rsample()
        chosen = sample.argmax(1)

        if not return_only_action:
            # NOTE: log_prob is expected to be 2-D here (presumably a per-row
            # temperature broadcasting against the per-row scores); the
            # diagonal pairs each sample with its own entry — confirm shape.
            per_sample = torch.diagonal(relaxed.log_prob(sample), offset=0)
            return chosen, per_sample.unsqueeze(1), logits
        return chosen
Exemplo n.º 7
0
    def forward(self, input: torch.Tensor,
                proposal: distributions.RelaxedOneHotCategorical,
                proposal_sample: torch.Tensor,
                reconstruction: torch.Tensor) -> torch.Tensor:
        """Negative ELBO of a VAE with a relaxed-categorical latent.

        Args:
            input: observed data; its first dimension is treated as the batch.
            proposal: approximate posterior q(z|x); its logits must be 2-D.
            proposal_sample: a sample drawn from ``proposal``.
            reconstruction: decoder output, same shape as ``input``. It holds
                Bernoulli probabilities when ``self.likelihood == 'bernoulli'``
                and otherwise the mean of a unit-variance Normal.

        Returns:
            (negative ELBO, negative reconstruction log-likelihood,
            regularization). The regularization is a single-sample Monte Carlo
            estimate of KL(q || uniform relaxed prior).
        """
        if self.likelihood == 'bernoulli':
            likelihood = distributions.Bernoulli(probs=reconstruction)
        else:
            likelihood = distributions.Normal(reconstruction,
                                              torch.ones_like(reconstruction))

        # Treat every non-batch dimension as part of the event, so log_prob
        # yields one value per example. This is explicit instead of relying on
        # an undocumented negative reinterpreted_batch_ndims.
        likelihood = distributions.Independent(
            likelihood, reinterpreted_batch_ndims=reconstruction.dim() - 1)
        reconstruction_loss = likelihood.log_prob(input).mean()

        assert proposal.logits.dim(
        ) == 2, "proposal.logits must be 2-D: [batch, num_categories]"

        # Uniform prior: equal logits at the proposal's temperature.
        prior = distributions.RelaxedOneHotCategorical(proposal.temperature,
                                                       logits=torch.ones_like(
                                                           proposal.logits))
        regularization = (proposal.log_prob(proposal_sample) -
                          prior.log_prob(proposal_sample)).mean()

        # evidence lower bound (maximize)
        total_loss = reconstruction_loss - self.beta * regularization

        return -total_loss, -reconstruction_loss, regularization
    def forward(self, state, action):
        """Encode (state, action) into relaxed-categorical latents and decode.

        Both inputs get a leading dimension of size 1 before encoding.

        Returns:
            A pair of values:
            - the output of ``self.decode(state, sample)``;
            - the encoder logits viewed as (1, latent_dim, categorical_dim),
              softmaxed over dim 1 and flattened back to ``q``'s shape.
        """
        # encoding
        state = state.unsqueeze(0)
        action = action.unsqueeze(0)
        q = F.relu(self.e1(torch.cat([state, action], dim=1)))
        q = F.relu(self.e2(q))
        q = F.relu(self.e3(q))
        q_y = q.view(q.size(0), self.latent_dim, self.categorical_dim)

        # decoding
        # NOTE(review): the distribution uses the flat logits `q` (a single
        # categorical over latent_dim * categorical_dim classes), while the
        # returned probabilities use the per-latent view `q_y` — confirm intent.
        # The temperature is created on q's device/dtype so GPU inputs work.
        z = RelaxedOneHotCategorical(q.new_tensor([self.temperature]), logits=q)
        sample = z.sample()
        recon = self.decode(state, sample)

        # NOTE(review): softmax over dim=1 normalises across latent_dim, not
        # across categorical_dim (dim=-1) — confirm intent.
        return recon, F.softmax(q_y, dim=1).reshape(*q.size())
Exemplo n.º 9
0
    def forward(self, *args, **kwargs):
        """Turn the wrapped agent's logits into a (relaxed) one-hot vector.

        Training: a reparameterised Gumbel-Softmax sample at
        ``self.temperature``. Evaluation: the deterministic tempered softmax
        ``softmax(logits / temperature)`` over dim 1.
        """
        logits = self.agent(*args, **kwargs)

        if not self.training:
            return torch.softmax(logits / self.temperature, dim=1)

        relaxed = RelaxedOneHotCategorical(logits=logits,
                                           temperature=self.temperature)
        return relaxed.rsample()
Exemplo n.º 10
0
    def forward(self, x):
        """Map ``x`` to a distribution over one-hot vectors.

        When ``self.beta`` is set to something other than None, 0 or inf, the
        result is the relaxed ``RelaxedOneHotCategorical`` with temperature
        ``1 / beta``. Otherwise it is the exact ``OneHotCategorical``. Both use
        the network output as logits.
        """
        params = self.network(x)

        if self.beta not in (None, 0., np.inf):
            return RelaxedOneHotCategorical(temperature=1. / self.beta,
                                            logits=params)

        return OneHotCategorical(logits=params)
Exemplo n.º 11
0
 def prob_dists(self, obs, temperature=1.0):
     """Build one RelaxedOneHotCategorical per action component.

     The output of ``self.forward(obs)`` is split along its last dimension
     into chunks of sizes ``self.action_split``. Each chunk becomes the logits
     of its own relaxed categorical at ``temperature``.

     Returns:
         A list of RelaxedOneHotCategorical, one per chunk.
     """
     logits = self.forward(obs)
     # Put the temperature on the logits' device/dtype rather than a global
     # DEVICE that may not match; as_tensor also accepts tensor inputs.
     temperature = torch.as_tensor(temperature, dtype=logits.dtype,
                                   device=logits.device)
     return [
         RelaxedOneHotCategorical(temperature, logits=chunk)
         for chunk in torch.split(logits, self.action_split, dim=-1)
     ]
Exemplo n.º 12
0
 def concrete(self, state):
     """Run the conv/FC encoder and draw a relaxed one-hot sample.

     Returns:
         A pair ``(sample, c)``:
         - ``sample`` is drawn with ``.sample()`` (no reparameterisation)
           from ``RelaxedOneHotCategorical(self.temperature, logits=fc_out)``;
         - ``c`` is a NumPy array holding ``sign(fc_out)`` clamped at 0 for the
           first batch row, i.e. 1 for positive entries and 0 otherwise.
     """
     features = state
     for conv_layer in (self.c1, self.c2, self.c3):
         features = F.relu(conv_layer(features))
     flat = features.view(state.size(0), -1)
     fc_out = self.fc2(F.relu(self.fc1(flat)))
     positive_mask = torch.sign(fc_out).clamp(min=0.0)
     c = positive_mask.detach()[0].cpu().numpy()
     sample = RelaxedOneHotCategorical(self.temperature, logits=fc_out).sample()
     return sample, c
Exemplo n.º 13
0
    def forward(self, *args, **kwargs):
        """Turn the wrapped agent's logits into (relaxed) one-hot vectors.

        Training: a reparameterised Gumbel-Softmax sample at
        ``self.temperature``. Evaluation: a hard one-hot of the argmax along
        the last dimension.
        """
        logits = self.agent(*args, **kwargs)

        if not self.training:
            winners = logits.argmax(dim=-1, keepdim=True)
            return torch.zeros_like(logits).scatter_(-1, winners, 1.0)

        return RelaxedOneHotCategorical(logits=logits,
                                        temperature=self.temperature).rsample()
Exemplo n.º 14
0
    def forward(self, logits: torch.Tensor):
        """Turn ``logits`` into a (relaxed) one-hot tensor of the same shape.

        Evaluation: a hard one-hot of the argmax over the last dimension.
        Training: a reparameterised Gumbel-Softmax sample at
        ``self.temperature``. With ``self.straight_through``, the forward value
        is the hard one-hot of that sample while gradients follow the relaxed
        sample.
        """
        def _argmax_one_hot(scores):
            # Put a 1 at each row's argmax on a 2-D view, then restore the shape.
            shape = scores.size()
            flat = torch.zeros_like(scores).view(-1, shape[-1])
            flat.scatter_(1, scores.argmax(dim=-1).view(-1, 1), 1)
            return flat.view(*shape)

        if not self.training:
            return _argmax_one_hot(logits)

        relaxed = RelaxedOneHotCategorical(
            logits=logits, temperature=self.temperature).rsample()
        if self.straight_through:
            hard = _argmax_one_hot(relaxed)
            relaxed = relaxed + (hard - relaxed).detach()
        return relaxed
Exemplo n.º 15
0
    def forward(self, tensor_list):
        """Build a RelaxedOneHotCategorical conditioned on the inputs.

        Parameters
        ----------
        tensor_list: list of torch.Tensor
            tensors that are concatenated along their last dimension and
            passed through ``self.w_dense`` to produce the logits.

        Returns
        -------
        RelaxedOneHotCategorical
            distribution with temperature ``self._temperature``.
        """
        logits = self.w_dense(torch.cat(tensor_list, dim=-1))
        return RelaxedOneHotCategorical(self._temperature, logits=logits)
Exemplo n.º 16
0
    def forward(ctx, input, temperature):
        """Forward pass: draw a relaxed one-hot sample.

        Parameters
        ==========
        :param ctx: autograd context (unused; nothing is saved for backward)
        :param input: category probabilities, with the last dimension indexing
            categories (passed as ``probs`` to RelaxedOneHotCategorical)
        :param temperature: relaxation temperature
        Returns
        =======
        :return: a sample (drawn with ``.sample()``) from
            RelaxedOneHotCategorical(temperature, probs=input). It is a soft
            point on the probability simplex, not a hard one-hot vector.
        """
        dist = RelaxedOneHotCategorical(temperature, probs=input)
        return dist.sample()
Exemplo n.º 17
0
    def forward(self, x):
        """Quantise ``x`` against ``self.embedding`` through a relaxed categorical.

        ``x`` is (B, C, H, W) with C == N * D, where ``self.embedding`` is
        (N, M, D). Negative squared distances to each of the M codes are the
        logits of a RelaxedOneHotCategorical at temperature 0.5. Training uses
        a reparameterised relaxed sample as soft code weights; evaluation uses
        the hard one-hot argmax code.

        Returns:
            A triple of values:
            - the quantized tensor, shape (B, C, H, W);
            - KL to the uniform prior, summed over codebooks and positions and
              averaged over the batch;
            - the code-usage perplexity, summed over codebooks.
        """
        B, C, H, W = x.size()
        N, M, D = self.embedding.size()
        assert C == N * D

        # (N, B, H, W, D): one slice per codebook, features last.
        x = x.view(B, N, D, H, W).permute(1, 0, 3, 4, 2)
        flat = x.reshape(N, -1, D)

        # ||e||^2 + ||x||^2 - 2 <x, e> for every (codebook, position, code).
        squared_norms = (torch.sum(self.embedding**2, dim=2).unsqueeze(1) +
                         torch.sum(flat**2, dim=2, keepdim=True))
        distances = torch.baddbmm(squared_norms,
                                  flat,
                                  self.embedding.transpose(1, 2),
                                  alpha=-2.0,
                                  beta=1.0).view(N, B, H, W, M)

        dist = RelaxedOneHotCategorical(0.5, logits=-distances)
        if not self.training:
            codes = F.one_hot(torch.argmax(dist.probs, dim=-1), M).float()
            weights = codes.view(N, -1, M)
        else:
            weights = dist.rsample().view(N, -1, M)

        quantized = torch.bmm(weights, self.embedding).view_as(x)

        # KL(q || Uniform(M)) = sum p * (log p + log M); entries with zero
        # probability are forced to contribute exactly 0.
        kl = dist.probs * (dist.logits + math.log(M))
        kl = kl.masked_fill(dist.probs == 0, 0.0)
        kl = kl.sum(dim=(0, 2, 3, 4)).mean()

        usage = torch.mean(weights, dim=1)
        perplexity = torch.exp(-torch.sum(usage * torch.log(usage + 1e-10),
                                          dim=-1))

        restored = quantized.permute(1, 0, 4, 2, 3).reshape(B, C, H, W)
        return restored, kl, perplexity.sum()
Exemplo n.º 18
0
def gumbel_softmax_sample(logits: torch.Tensor,
                          temperature: float = 1.0,
                          straight_through: bool = False):
    """Sample from the Gumbel-Softmax/Concrete relaxation of a Categorical distribution.

    More details in:
    - Gumbel-Softmax: https://arxiv.org/abs/1611.01144
    - Concrete distribution: https://arxiv.org/abs/1611.00712

    Arguments:
        logits {torch.Tensor} -- tensor of logits, the output of an inference network.
            Size: [batch_size, n_categories]

    Keyword Arguments:
        temperature {float} -- temperature of the softmax relaxation. The lower the
            temperature (-->0), the closer the sample is to a discrete sample.
            (default: {1.0})
        straight_through {bool} -- Whether to use the straight-through estimator:
            the forward value is the hard one-hot of the sample while gradients
            flow through the relaxed sample. (default: {False})

    Returns:
        torch.Tensor -- the relaxed (or straight-through) sample.
            Size: [batch_size, n_categories]
    """
    relaxed = RelaxedOneHotCategorical(logits=logits,
                                       temperature=temperature).rsample()
    if not straight_through:
        return relaxed

    shape = relaxed.size()
    winners = relaxed.argmax(dim=-1).view(-1, 1)
    hard = (torch.zeros_like(relaxed)
            .view(-1, shape[-1])
            .scatter_(1, winners, 1)
            .view(*shape))
    # Forward value is `hard`; the gradient is that of `relaxed`.
    return relaxed + (hard - relaxed).detach()
Exemplo n.º 19
0
    def sampled_elbo(self, input, args, n_particles, emb, hidden_states):
        """Accumulate a sampled per-step loss (NLL + KL) and backpropagate it.

        At every step a relaxed posterior q (temperature ``self.temp``) is
        built from ``self.logits(hidden_states[i])`` and a reparameterised
        sample z is drawn. The next step's prior is obtained by pushing z
        through the transition matrix. The summed loss, normalised by the
        number of scored positions, is backpropagated here with
        retain_graph=True.

        Args:
            input: token ids, unpacked as (seq_len, batch_sz).
            args, emb: not used by this method.
            n_particles: samples per sequence; hidden states are tiled
                n_particles times along dim 1.
            hidden_states: encoder states, indexed per step along dim 0.

        Returns:
            (per-sample loss vector, 0, number of scored positions, 0)
        """
        seq_len, batch_sz = input.size()
        T = nn.Softmax(dim=0)(self.T)  # NOTE: not in log-space
        pi = nn.Softmax(dim=0)(self.pi)
        emit = self.calc_emit()

        hidden_states = hidden_states.repeat(1, n_particles, 1)
        loss = 0

        # now a value in probability space
        prior_probs = pi.unsqueeze(0).expand(batch_sz * n_particles,
                                             self.z_dim)

        for i in range(seq_len):
            logits = self.logits(hidden_states[i])

            # relaxed prior p and posterior q; z is a reparameterised sample
            p = RelaxedOneHotCategorical(temperature=self.temp_prior,
                                         probs=prior_probs)
            q = RelaxedOneHotCategorical(temperature=self.temp, logits=logits)
            z = q.rsample()

            log_probs = F.log_softmax(logits, dim=1)

            # now, compute the log-likelihood of the data given this z-sample
            # so emission is [batch_sz x z_dim], i.e. emission[i, j] is the log-probability of getting this
            # data for element i given choice z
            emission = F.embedding(input[i].repeat(n_particles), emit)

            NLL = -log_sum_exp(emission + log_probs, 1)
            # Single-sample Monte Carlo estimate of KL(q || p) between the
            # relaxed distributions (pretty inexact).
            KL = q.log_prob(z) - p.log_prob(z)
            loss += (NLL + KL)

            if i != seq_len - 1:
                # next prior: prior_probs[b, j] = sum_k T[j, k] * z[b, k]
                prior_probs = (T.unsqueeze(0) * z.unsqueeze(1)).sum(2)

        (loss.sum() /
         (seq_len * batch_sz * n_particles)).backward(retain_graph=True)
        return loss, 0, seq_len * batch_sz * n_particles, 0
Exemplo n.º 20
0
 def rsample(self):
     """Draw a reparameterised sample from the mixture.

     Every component is sampled with ``rsample``. The results are combined
     with soft weights drawn from a RelaxedOneHotCategorical over
     ``self.weights`` at temperature 0.1, so the result stays differentiable.

     Raises:
         NotImplementedError: if the mixture does not support rsample.
     """
     if not self.has_rsample:
         raise NotImplementedError("Mixture does not support rsample.")
     # Stack component samples along a new leading dim: (K, ...).
     samples = torch.cat(
         [distribution.rsample().unsqueeze(0)
          for distribution in self.distributions],
         dim=0)
     # Trailing dims beyond (K, batch) that the weights must broadcast over.
     expand = samples.dim() - 2
     choice = RelaxedOneHotCategorical(probs=self.weights, temperature=0.1)
     # assumes self.weights is (batch, K), so the sample is permuted to (K, batch) — TODO confirm
     choice = choice.rsample().permute(1, 0)
     # Append `expand` singleton dims (a list of 1s, not a bare int) so the
     # weights broadcast against `samples`.
     choice = choice.view(choice.size(0), choice.size(1), *([1] * expand))
     return (samples * choice).sum(dim=0)
Exemplo n.º 21
0
 def add_noise_(self, batch):
     """Perturb, in place, the observations and actions of every other agent.

     For each agent index other than ``self.index``:
     - observations get additive Gaussian noise scaled by
       ``self.sigma_noise`` (when it is not None);
     - actions are replaced by a sample from a RelaxedOneHotCategorical whose
       probs are the original actions, at temperature ``self.temp_noise``
       (when it is not None).
     ``batch.observations`` and ``batch.actions`` are updated item by item.
     """
     for agent_idx, actions in enumerate(batch.actions):
         if agent_idx == self.index:
             continue
         obs = batch.observations[agent_idx]
         if self.sigma_noise is not None:
             obs = obs + self.sigma_noise * torch.randn_like(obs)
         if self.temp_noise is not None:
             temperature = torch.tensor(self.temp_noise,
                                        dtype=torch.float,
                                        device=actions.device)
             # a tiny offset keeps zero probabilities from yielding NaN samples
             shifted = actions + 1e-45
             actions = RelaxedOneHotCategorical(temperature,
                                                probs=shifted).sample()
         batch.observations[agent_idx] = obs
         batch.actions[agent_idx] = actions
Exemplo n.º 22
0
def rsample_gumbel_softmax(
    distr: Distribution,
    n: int,
    temperature: torch.Tensor,
    straight_through: bool = False,
) -> torch.Tensor:
    """Draw ``n`` reparameterised Gumbel-Softmax samples that relax ``distr``.

    Categorical and OneHotCategorical are relaxed to a RelaxedOneHotCategorical;
    Bernoulli is relaxed to a RelaxedBernoulli. With ``straight_through``, the
    straight-through variant of each is used. All of them share
    ``distr.probs`` and ``temperature``.

    Returns:
        ``rsample((n,))`` of the relaxed distribution.

    Raises:
        ValueError: if ``distr`` is not one of the supported discrete types.
    """
    if isinstance(distr, (Categorical, OneHotCategorical)):
        relaxed_cls = (RelaxedOneHotCategoricalStraightThrough
                       if straight_through else RelaxedOneHotCategorical)
    elif isinstance(distr, Bernoulli):
        relaxed_cls = (RelaxedBernoulliStraightThrough
                       if straight_through else RelaxedBernoulli)
    else:
        raise ValueError("Using Gumbel Softmax with non-discrete distribution")
    return relaxed_cls(temperature, probs=distr.probs).rsample((n, ))
Exemplo n.º 23
0
 def sample(self, mean, logvar, probabilities):
     """Draw reparameterised samples for the Gaussian and categorical latents.

     Returns:
         A pair ``(z_continuous, z_discrete)``:
         - an rsample from Normal(mean, exp(0.5 * logvar));
         - an rsample from
           RelaxedOneHotCategorical(self.temperature, probs=probabilities).
     """
     std = torch.exp(0.5 * logvar)
     gaussian = Normal(mean, std)
     relaxed = RelaxedOneHotCategorical(self.temperature, probs=probabilities)
     continuous = gaussian.rsample()
     discrete = relaxed.rsample()
     return continuous, discrete
Exemplo n.º 24
0
    def forward(self, x, mask, num_particles=4):
        """Run a particle-filtered VRNN over ``x`` and return the negative FIVO bound.

        At every time step, each of ``num_particles`` particles per sequence:
        - samples a latent from the encoder;
        - scores the step as -(Bernoulli NLL + KL(encoder || prior));
        - adds that score to its log-weight.

        When the log effective sample size drops to log(num_particles / 2) or
        below, and mask[t] is set, the top-layer recurrent state is
        resampled. This uses a hard Categorical ancestor draw, or, when
        ``self.use_resampling_gradient`` is set, a weighted combination of
        particle states drawn from a RelaxedOneHotCategorical (temperature 0.1).

        Args:
            x: indexed per time step as x[t]; x.size(1) is used as the batch
                size (assumes (seq_len, batch, features) — TODO confirm).
            mask: indexed as mask[t] with one entry per sequence (presumably
                0/1 floats — verify against caller).
            num_particles: particles per sequence.

        Returns:
            (-fivo_bound, log_hat_p_acc, logweight_acc, kl_acc,
            log_hat_p_iwae_acc), where fivo_bound is the sum over sequences of
            the accumulated log marginal-likelihood estimate log_hat_p_acc.
        """

        # NOTE(review): `device` and `Variable` are not defined in this block;
        # presumably module-level names — confirm.
        logweight_acc = torch.zeros(x.size(1), num_particles).to(
            device)  # (batch_size, num_particles)
        log_hat_p_acc = torch.zeros(x.size(1)).to(device)  # (batch_size, )
        log_hat_p_iwae_acc = torch.zeros(x.size(1)).to(device)
        kl_acc = torch.zeros(x.size(1)).to(device)  # (batch_size, )

        # Identity row indices [0, 1, 2, 3, 4, 5, 6, 7, ... ], selected for
        # sequences that are not resampled at a step.
        noresampleidxs = torch.arange(x.size(1) * num_particles).to(device)

        # Recurrent state with one row per (sequence, particle) pair.
        h = Variable(
            torch.zeros(self.n_layers,
                        x.size(1) * num_particles, self.h_dim)).to(device)
        c = Variable(
            torch.zeros(self.n_layers,
                        x.size(1) * num_particles, self.h_dim)).to(device)

        # with torch.autograd.set_detect_anomaly(True):
        for t in range(x.size(0)):
            # VRNN Cell
            # Tile the step input so each sequence occupies num_particles
            # consecutive rows.
            xts = x[t].repeat((1, num_particles)).reshape(
                (x.size(1) * num_particles, -1))
            phi_x_ts = self.phi_x(
                xts)  # [batch_size * num_particle, embed_size]

            # Proposal q(z_t | x_t, h_{t-1}): Gaussian with diagonal scale.
            enc_t = self.enc(torch.cat([phi_x_ts, h[-1]], 1))
            enc_mean_t = self.enc_mean(enc_t)
            enc_std_t = self.enc_std(enc_t)

            encoder_dist = MultivariateNormal(
                enc_mean_t, scale_tril=torch.diag_embed(enc_std_t))

            # Prior p(z_t | h_{t-1}): Gaussian with diagonal scale.
            prior_t = self.prior(h[-1])
            prior_mean_t = self.prior_mean(prior_t)
            prior_std_t = self.prior_std(prior_t)

            prior_dist = MultivariateNormal(
                prior_mean_t, scale_tril=torch.diag_embed(prior_std_t))

            z_t_is = encoder_dist.rsample(
            )  # reparameterised; one row per (sequence, particle)

            phi_z_ts = self.phi_z(z_t_is)

            # Decoder: Bernoulli over x_t with mean dec_mean_t.
            dec_t = self.dec(torch.cat([phi_z_ts, h[-1]], 1))
            dec_mean_t = self.dec_mean(dec_t)
            decoder_dist = Bernoulli(probs=dec_mean_t)

            # NOTE(review): these three log-probs are not used below; they
            # only appear in the commented-out log_alpha_ti formula.
            prior_logprob_ti = prior_dist.log_prob(z_t_is.detach()) + 1e-7
            encoder_logprob_ti = encoder_dist.log_prob(z_t_is.detach()) + 1e-7
            decoder_logprob_ti = decoder_dist.log_prob(xts).sum(-1) + 1e-7

            # recurrence
            _, (h, c) = self.rnn(
                torch.cat([phi_x_ts, phi_z_ts], 1).unsqueeze(0), (h, c))

            kl = torch.distributions.kl_divergence(encoder_dist, prior_dist)
            # NOTE(review): kl appears to hold one value per (sequence,
            # particle) row, so kl.mean(-1) averages over all rows and every
            # sequence receives the same (masked) value — confirm intent.
            kl_acc += kl.mean(-1) * mask[t]
            nll = self._nll_bernoulli(dec_mean_t, xts)

            # Incremental log-weight per particle: -(NLL + KL). Unused alternative:
            # log_alpha_ti = prior_logprob_ti + decoder_logprob_ti - encoder_logprob_ti  # [batch_size, ]
            log_alpha_ti = -(nll + kl)
            log_alpha_ti = log_alpha_ti.reshape(
                x.size(1), -1)  # [batch_size, num_particles]
            # Steps where mask[t] is 0 contribute zero log-weight.
            log_alpha_ti = log_alpha_ti * mask[t][
                None].T  # [batch_size, num_particles] * [batch_size, 1]

            # hat_p = torch.exp(logweight_acc + log_alpha_ti)  # [batch_size, num_particles]
            logweight_acc += log_alpha_ti

            # Resampling criterion: log effective sample size computed stably as
            # 2 * logsumexp(w) - logsumexp(2 * w), rather than as
            # ess = 1. / (torch.exp(logweight_acc) ** 2).sum(-1).
            logess_num = 2 * torch.logsumexp(logweight_acc, dim=-1)
            logess_denom = torch.logsumexp(2 * logweight_acc, dim=-1)
            logess = logess_num - logess_denom

            if not self.use_resampling_gradient:
                # Hard resampling: draw num_particles ancestor indices per
                # sequence from Categorical(logits=log-weights).
                resample_dist = Categorical(
                    logits=logweight_acc.reshape(x.size(1), num_particles))
                resampled_idxs = resample_dist.sample([num_particles]).T

                # Offsets that map per-sequence ancestor indices to flat rows:
                # [0, 0, 0, 0, 4, 4, 4, 4, ... ]
                sample_offset = torch.arange(x.size(1)).repeat([
                    num_particles, 1
                ]).T.reshape(-1).to(device) * num_particles
                resampled_idxs = resampled_idxs.reshape(-1) + sample_offset

                # Resample only where log ESS <= log(num_particles / 2) and the
                # step is not masked out.
                should_resample = logess <= torch.log(
                    torch.ones_like(logess).to(device) * num_particles / 2.0)
                should_resample = should_resample & mask[t].bool()
                should_resample_tiled = should_resample.repeat(
                    [num_particles, 1]).T.reshape(-1)

                new_idxs = torch.where(should_resample_tiled, resampled_idxs,
                                       noresampleidxs)

                # Only the top layer's recurrent state is reindexed.
                h[-1] = h[-1][new_idxs]
                c[-1] = c[-1][new_idxs]

                # For resampled sequences, fold log(mean weight) into the
                # running estimate ...
                log_hat_p = torch.logsumexp(logweight_acc.clone(),
                                            dim=-1) - math.log(
                                                float(num_particles))
                log_hat_p_acc += log_hat_p * should_resample.float()

                # ... and reset their log-weights to zero.
                logweight_acc *= (1. - should_resample_tiled.reshape(
                    x.size(1), num_particles).float())

            else:
                # Relaxed resampling: each new particle state is a weighted
                # combination of the sequence's old particle states, with
                # weights drawn from a RelaxedOneHotCategorical.
                resample_dist = RelaxedOneHotCategorical(
                    logits=logweight_acc.reshape(x.size(1), num_particles),
                    temperature=0.1)
                resampled_onehot_relaxedidxs = resample_dist.rsample(
                    [num_particles]).permute(1, 0,
                                             2)  # [batch_size, new particle, old particle]

                should_resample = logess <= torch.log(
                    torch.ones_like(logess).to(device) * num_particles / 2.0)
                should_resample = should_resample & mask[t].bool()
                should_resample_tiled = should_resample.repeat(
                    [num_particles, 1]).T.reshape(-1)

                for batch_idx in range(x.size(1)):
                    if should_resample[batch_idx]:
                        # Rows batch_idx * num_particles ... + num_particles hold
                        # this sequence's particles; only the top layer is mixed.
                        h[-1][(batch_idx * num_particles) : (batch_idx * num_particles + num_particles)] = \
                         resampled_onehot_relaxedidxs[batch_idx] @ h[-1][(batch_idx * num_particles) : (batch_idx * num_particles + num_particles)].clone()
                        c[-1][(batch_idx * num_particles) : (batch_idx * num_particles + num_particles)] = \
                         resampled_onehot_relaxedidxs[batch_idx] @ c[-1][(batch_idx * num_particles) : (batch_idx * num_particles + num_particles)].clone()

                # Same bookkeeping as the hard branch: accumulate log(mean
                # weight) for resampled sequences, then reset their weights.
                log_hat_p = torch.logsumexp(logweight_acc.clone(),
                                            dim=-1) - math.log(
                                                float(num_particles))
                log_hat_p_acc += log_hat_p * should_resample.float()

                logweight_acc *= (1. - should_resample_tiled.reshape(
                    x.size(1), num_particles).float())

            # Accumulate log-mean-exp of this step's incremental log-weights
            # (no gradient), masked per sequence.
            log_hat_p_iwae_acc += (
                torch.logsumexp(log_alpha_ti.detach(), dim=-1) -
                math.log(float(num_particles))) * mask[t]

        # Fold in the log-weights remaining after the final step.
        log_hat_p_acc += torch.logsumexp(logweight_acc, dim=-1) - math.log(
            float(num_particles))
        fivo_bound = torch.sum(log_hat_p_acc)

        return -fivo_bound, log_hat_p_acc, logweight_acc, kl_acc, log_hat_p_iwae_acc
Exemplo n.º 25
0
    def forward(self, X, index, length, hidden=None, temp=1.0):
        # A mode for learning to predict words, ignoring how to distinguish between parser states
        # Set to true for pre-training, then set to False for fine-tuning
        pretrain = False
        dice = random.random()

        prev_a, prev_b, prev_depth, _ = hidden
        batch_size = X.shape[0]
        hidden_size = prev_b.shape[-1]

        batch_range = range(batch_size)
        device = next(self.parameters()).device

        # Depth "0" is initialized to 0 (needed for conditioning of depth 1)
        ab_00 = [torch.zeros(batch_size, 2 * hidden_size, device=device)]
        ab_01 = [torch.zeros(batch_size, 2 * hidden_size, device=device)]
        ab_10 = [torch.zeros(batch_size, 2 * hidden_size, device=device)]
        ab_11 = [torch.zeros(batch_size, 2 * hidden_size, device=device)]

        sect_start = time.time()

        for d in range(1, self.depth + 1):
            fork_join_a = prev_a[:, d, :]
            nofork_nojoin_a = self.w_a00(
                torch.cat((X, prev_b[:, d - 1, :], prev_a[:, d, :]), 1))
            fork_nojoin_a = self.w_a10(torch.cat((X, prev_b[:, d - 1, :]), 1))
            nofork_join_a = prev_a[:, d - 1, :]

            ## At next depth, need to update a and/or b
            next_a_d_00 = (
                torch.eq(prev_depth, float(d)).float() * nofork_nojoin_a
                +  # at shallower depth, copy over
                torch.gt(prev_depth, float(d)).float() * prev_a[:, d, :]
                +  # at deeper depth, zero out
                torch.lt(prev_depth, float(d)).float() *
                torch.zeros_like(prev_a[:, d, :]))

            next_a_d_11 = (
                torch.eq(prev_depth, float(d)).float() * fork_join_a
                +  # at shallower depth, copy over
                torch.gt(prev_depth, float(d)).float() * prev_a[:, d, :]
                +  # at deeper depth, zero out
                torch.lt(prev_depth, float(d)).float() *
                torch.zeros_like(prev_a[:, d, :]))

            next_a_d_10 = (
                torch.eq(prev_depth + 1, float(d)).float() * fork_nojoin_a
                +  # at shallower depth, copy over
                torch.gt(prev_depth + 1, float(d)).float() * prev_a[:, d, :]
                +  # at deeper depth, zero out
                torch.lt(prev_depth + 1, float(d)).float() *
                torch.zeros_like(prev_a[:, d, :]))

            next_a_d_01 = (
                torch.eq(prev_depth - 1, float(d)).float() * nofork_join_a
                +  # at shallower depth, copy over
                torch.gt(prev_depth - 1, float(d)).float() * prev_a[:, d, :]
                +  # at deeper depth, zero out
                torch.lt(prev_depth - 1, float(d)).float() *
                torch.zeros_like(prev_a[:, d, :]))

            fork_join_b = self.w_b11(torch.cat((X, prev_b[:, d, :]), 1))
            nofork_nojoin_b = self.w_b00(
                torch.cat((X, prev_a[:, d, :], next_a_d_00), 1))
            fork_nojoin_b = self.w_b10(torch.cat((X, next_a_d_10), 1))
            nofork_join_b = self.w_b01(
                torch.cat((X, prev_b[:, d - 1, :], prev_a[:, d, :]), 1))

            next_b_d_00 = (
                torch.eq(prev_depth, float(d)).float() * nofork_nojoin_b
                +  # at shallower depth, copy over
                torch.gt(prev_depth, float(d)).float() * prev_b[:, d, :]
                +  # at deeper depth, zero out
                torch.lt(prev_depth, float(d)).float() *
                torch.zeros_like(prev_b[:, d, :]))
            next_b_d_11 = (
                torch.eq(prev_depth, float(d)).float() * fork_join_b
                +  # at shallower depth, copy over
                torch.gt(prev_depth, float(d)).float() * prev_b[:, d, :]
                +  # at deeper depth, zero out
                torch.lt(prev_depth, float(d)).float() *
                torch.zeros_like(prev_b[:, d, :]))
            next_b_d_10 = (
                torch.eq(prev_depth + 1, float(d)).float() * fork_nojoin_b
                +  # at shallower depth, copy over
                torch.gt(prev_depth + 1, float(d)).float() * prev_b[:, d, :]
                +  # at deeper depth, zero out
                torch.lt(prev_depth + 1, float(d)).float() *
                torch.zeros_like(prev_b[:, d, :]))
            next_b_d_01 = (
                torch.eq(prev_depth - 1, float(d)).float() * nofork_join_b
                +  # at shallower depth, copy over
                torch.gt(prev_depth - 1, float(d)).float() * prev_b[:, d, :]
                +  # at deeper depth, zero out
                torch.lt(prev_depth - 1, float(d)).float() *
                torch.zeros_like(prev_b[:, d, :]))

            next_ab_00 = torch.cat((next_a_d_00, next_b_d_00), 1)
            next_ab_01 = torch.cat((next_a_d_01, next_b_d_01), 1)
            next_ab_10 = torch.cat((next_a_d_10, next_b_d_10), 1)
            next_ab_11 = torch.cat((next_a_d_11, next_b_d_11), 1)

            ab_00.append(next_ab_00)
            ab_01.append(next_ab_01)
            ab_10.append(next_ab_10)
            ab_11.append(next_ab_11)

        sect_end = time.time()
        self.state_compute += (sect_end - sect_start)

        sect_start = time.time()
        # Now flatten the depth and predict the attention variables:
        # next_state_flat = torch.squeeze( next_state.view(batch_size, 1, -1) )
        ab_00_flat = torch.stack(ab_00, 1).view(batch_size, 1, -1)
        ab_01_flat = torch.stack(ab_01, 1).view(batch_size, 1, -1)
        ab_10_flat = torch.stack(ab_10, 1).view(batch_size, 1, -1)
        ab_11_flat = torch.stack(ab_11, 1).view(batch_size, 1, -1)

        next_state_flat = torch.squeeze(
            torch.cat((ab_00_flat, ab_01_flat, ab_10_flat, ab_11_flat), 2), 1)

        ## These are our deterministic masks: (Dis)Allow certain states at start, end, and depth limits
        # At time 0, and only at time 0, prev_Depth is 0, so we must choose 1/0
        mask = torch.ones(batch_size, 4, device=device)
        # if we're at depth 0, we can only allow 1/0
        mask[:, (0, 1, 3)] *= (1 - torch.eq(prev_depth, 0).float())
        # if we're at depth d, we cannot allow 1/0
        mask[:, (2, )] *= (1 - torch.eq(prev_depth, self.depth).float())
        # if we're at depth 1, we cannot allow 0/1 (reduce in the middle of the sentence)
        mask[:, (1, )] *= (1 - torch.eq(prev_depth, 1).float())

        #  if we're in word-learning mode, disallow either 0/0 or 1/1. (Logic above takes care of forcing 1/0 where necessary)
        if pretrain:
            if dice > 0.5:
                mask[:, 0] *= 0
            else:
                mask[:, 3] *= 0

        # Get the attention variables
        dist = RelaxedOneHotCategorical(
            temp,
            torch.nn.functional.softmax(torch.sigmoid(
                self.attention(next_state_flat[:, self.depth_size:])),
                                        dim=1))
        att_vars = torch.nn.functional.normalize(mask * dist.sample())
        # att_vars = torch.nn.functional.softmax(mask * SampleST(torch.nn.functional.softmax( torch.sigmoid(self.attention( next_state_flat[:, self.depth_size:] ) ), dim=1 ), temp))

        if self.parsing:
            selection = ArgmaxST(att_vars)
        else:
            selection = att_vars

        sect_end = time.time()
        self.mask_time += (sect_end - sect_start)

        sect_start = time.time()
        striped = torch.mm(selection, self.selection_striper)

        ## It's ok up to here. Now we need to broadcast a dot product for each
        ## stripe across the stacked identity matrix to mask out the unwanted
        ## parts of the state space
        mask_list = []
        for b in range(batch_size):
            batch_mask = []
            expanded_stripe = striped[b].repeat(self.stride, 1)
            batch_mask = expanded_stripe * self.stripe_expander.t()
            mask_list.append(batch_mask.t())
        striped_identity = torch.stack(mask_list, 0)

        sect_end = time.time()
        self.batch_mask_time += (sect_end - sect_start)

        sect_start = time.time()

        # It's ok below here.
        hidden = torch.squeeze(
            torch.bmm(torch.unsqueeze(next_state_flat, 1), striped_identity))

        # now hidden needs to be re-viewed as batch x depth x hidden
        hidden = hidden.view(batch_size, self.depth + 1, self.hidden_size * 2)
        next_a = hidden[:, :, :self.hidden_size]
        next_b = hidden[:, :, self.hidden_size:]

        # compute the selected next depth and equivalent f/j variables from the selection:
        next_depth = torch.unsqueeze(
            (selection[:, 0].long() * prev_depth[:, 0] +
             selection[:, 1].long() *
             (prev_depth[:, 0] - 1) + selection[:, 2].long() *
             (prev_depth[:, 0] + 1) +
             selection[:, 3].long() * prev_depth[:, 0]), 1)
        f = torch.unsqueeze((selection[:, 2] * 1 + selection[:, 3] * 1), 1)
        j = torch.unsqueeze((selection[:, 1] * 1 + selection[:, 3] * 1), 1)

        sect_end = time.time()
        self.finish_time += (sect_end - sect_start)

        return next_a, next_b, next_depth, (f, j)
Exemplo n.º 26
0
    def forward(self, input, targets, args, n_particles, criterion, test=False):
        """Run the particle-filter objective and compute the losses directly.

        This version does not expose the decoder logits; it evaluates the
        filtering bound and the per-token NLLs internally.

        Args:
            input: token indices; ``input.size()`` gives ``(seq_len, batch_sz)``.
            targets: teacher-forcing inputs fed to ``self.dec_embedding``.
            args: namespace; only ``args.anneal`` (weight on the
                prior/proposal log-ratio) is read here.
            n_particles: number of particles per batch element.
            criterion: called as ``criterion(decoder_logits, target_indices)``;
                its output is reshaped per particle, so it presumably must not
                reduce over the batch — TODO confirm with callers.
            test: if True, use hard ``OneHotCategorical`` samples instead of
                the relaxed (Gumbel-softmax) distributions.

        Returns:
            ``(-loss.sum(), nll, seq_len * batch_sz, resamples)`` where ``loss``
            accumulates the per-step log normalizers ``Z``, ``nll`` is the
            particle-averaged NLL summed over steps and batch, and
            ``resamples`` counts the steps that triggered resampling.
        """
        # Run the input and teacher-forcing inputs through the embedding layers.
        seq_len, batch_sz = input.size()
        emb = self.inp_embedding(input)
        hidden = self.init_hidden(batch_sz, self.nhid, 2)  # bidirectional
        hidden_states, _ = self.encoder(emb, hidden)

        # Teacher forcing.
        out_emb = self.dropout(self.dec_embedding(targets))

        # Replicate once per particle along dim 1:
        # [seq_len x batch_sz x ...] -> [seq_len x (n_particles * batch_sz) x ...].
        # Particle p of batch element b therefore lives at flat index
        # p * batch_sz + b, matching alpha.view(n_particles, batch_sz) below.
        hidden_states = hidden_states.repeat(1, n_particles, 1)
        out_emb = out_emb.repeat(1, n_particles, 1)

        n_flat = batch_sz * n_particles
        p_h = self.init_hidden(n_flat, self.z_dim, squeeze=True)  # prior logits
        h = self.init_hidden(n_flat, self.z_dim, squeeze=True)
        d_h = self.init_hidden(n_flat, self.nhid, squeeze=True)

        nlls = hidden_states.data.new(seq_len, n_flat)
        loss = 0

        accumulated_weights = -math.log(n_particles)  # log w_{t-1}, uniform at start
        resamples = 0

        for i in range(seq_len):
            h = self.z_decoder(hidden_states[i], h)
            logits = self.logits(h)

            # Proposal q (sampled) and prior p (scored) over the next z.
            if test:
                q = OneHotCategorical(logits=logits)
                z = q.sample()
                p = OneHotCategorical(logits=p_h)
            else:
                q = RelaxedOneHotCategorical(temperature=self.temp, logits=logits)
                z = q.rsample()
                p = RelaxedOneHotCategorical(temperature=self.temp_prior, logits=p_h)
            h = z

            # Log-likelihood of the data given z and the teacher-forced input.
            d_h = self.decoder(torch.cat([z, out_emb[i]], 1), d_h)
            decoder_logits = self.out_embedding(d_h)
            NLL = criterion(decoder_logits, input[i].repeat(n_particles))
            nlls[i] = NLL.data

            # Incremental importance weight (`reweight`, page 4).
            f_term = p.log_prob(z)  # prior
            r_term = q.log_prob(z)  # proposal
            alpha = -NLL + args.anneal * (f_term - r_term)

            wa = accumulated_weights + alpha.view(n_particles, batch_sz)

            Z = log_sum_exp(wa, dim=0)  # line 7
            loss += Z  # line 8
            accumulated_weights = wa - Z  # line 9

            probs = accumulated_weights.data.exp()
            probs += 0.01
            probs = probs / probs.sum(0, keepdim=True)  # [n_particles, batch_sz]
            effective_sample_size = 1. / probs.pow(2).sum(0)

            # Resample every batch element if ANY element's ESS ratio is low.
            if ((effective_sample_size / n_particles) < 0.3).sum() > 0:
                resamples += 1
                # ancestors[b, k]: old particle copied into slot k of element b.
                ancestors = torch.multinomial(probs.transpose(0, 1), n_particles, True)

                # Old particle a of element b sits at flat index a * batch_sz + b;
                # build the indices in [slot, batch] order so that slot k of
                # element b lands at flat position k * batch_sz + b.
                batch_idx = torch.arange(batch_sz, dtype=torch.long,
                                         device=ancestors.device)
                unrolled_idx = (ancestors.t() * batch_sz
                                + batch_idx.unsqueeze(0)).view(-1)
                h = torch.index_select(h, 0, unrolled_idx)
                p_h = torch.index_select(p_h, 0, unrolled_idx)
                d_h = torch.index_select(d_h, 0, unrolled_idx)

                # Weights are uniform again after resampling.
                accumulated_weights = -math.log(n_particles)

            if i != seq_len - 1:
                # Next prior prediction, fed with the (possibly resampled) ancestor.
                p_h = self.ar_prior_logits(torch.cat([h, out_emb[i]], 1), p_h)

        # Final log-marginal estimator; NLL is averaged over particles.
        nll = nlls.view(seq_len, n_particles, batch_sz).mean(1).sum()
        return -loss.sum(), nll, (seq_len * batch_sz), resamples
Exemplo n.º 27
0
def gumbel_softmax(input: torch.Tensor, dim: int, temp: float) -> torch.Tensor:
    """Draw a reparameterized Gumbel-softmax (relaxed one-hot) sample.

    Args:
        input: unnormalized logits.
        dim: the category dimension; the returned sample sums to 1 along it.
        temp: relaxation temperature; lower values give samples closer to
            one-hot vectors.

    Returns:
        A tensor with the same shape as ``input``, differentiable with
        respect to ``input``.
    """
    # RelaxedOneHotCategorical always treats the LAST dimension as the
    # category axis, so move ``dim`` there, sample, and move it back.
    # (transpose(d, -1) is a no-op when ``dim`` already is the last axis.)
    # Passing logits directly is equivalent to softmax -> probs but avoids
    # the probs -> log round trip and its clamping of tiny probabilities.
    moved = input.transpose(dim, -1)
    sample = RelaxedOneHotCategorical(temp, logits=moved).rsample()
    return sample.transpose(dim, -1).contiguous()
Exemplo n.º 28
0
    def forward(self, input, args, n_particles, test=False):
        """Compute the training loss, or the IWAE-10 bound when ``test`` is True.

        The ``n_particles`` argument is overridden: 10 particles are used when
        ``test`` is True and 1 otherwise. ``args`` is not read.

        Args:
            input: token indices; ``input.size()`` gives ``(seq_len, batch_sz)``.
            args: unused.
            n_particles: ignored (see above).
            test: selects hard one-hot sampling and the 10-particle bound.

        Returns:
            ``(loss.sum(), NLL.sum(), seq_len * batch_sz, 0)``. With one
            particle, ``loss`` is the per-sequence sum of NLL + KL. With more
            particles, it is the importance-weighted (log-sum-exp) estimate.
        """
        n_particles = 10 if test else 1
        pi = F.log_softmax(self.pi, 0)

        # Run the input through the embedding layer and the bidirectional encoder.
        seq_len, batch_sz = input.size()
        n_flat = batch_sz * n_particles
        emb = self.inp_embedding(input)
        hidden = self.init_hidden(batch_sz, self.nhid, 2)  # bidirectional
        hidden_states, _ = self.encoder(emb, hidden)
        hidden_states = hidden_states.repeat(1, n_particles, 1)

        def zeros(width):
            # Allocate on the encoder output's device/dtype rather than
            # assuming CUDA, so the model also runs on CPU.
            return Variable(hidden_states.data.new(n_flat, width).zero_())

        # (h, c) state for self.hidden_rnn.
        h = (zeros(self.hidden_size), zeros(self.hidden_size))
        # (h, c) state for self.z_decoder; the width 50 is hard-coded by the model.
        prior_h = (zeros(50), zeros(50))

        nlls = hidden_states.data.new(seq_len, n_flat)
        loss = 0

        # Log-probabilities of the prior over z at the first step.
        prior_logits = pi.unsqueeze(0).expand(n_flat, self.z_dim)

        x_emb = self.lockdrop(emb, self.dropout_x)
        # Unreduced per-element cross-entropy, built once instead of per step.
        token_nll = nn.CrossEntropyLoss(reduce=False)

        for i in range(seq_len):
            # Approximate posterior over z (log-probabilities).
            logits = F.log_softmax(self.logits(hidden_states[i]), 1)

            if test:
                z = OneHotCategorical(logits=logits).sample()
            else:
                z = RelaxedOneHotCategorical(temperature=self.temp,
                                             logits=logits).rsample()

            # Emission scores over the vocabulary: [n_flat x vocab].
            scores = torch.mm(self.project(torch.cat([h[0], z], 1)),
                              self.emit.t())

            NLL = token_nll(scores, input[i].repeat(n_particles))
            # Analytic KL(q || prior): both logits tensors are log-probabilities.
            KL = (logits.exp() * (logits - prior_logits)).sum(1)
            loss += (NLL + KL)

            nlls[i] = NLL.data

            # Set things up for the next step.
            if i != seq_len - 1:
                feed = torch.cat(
                    [emb[i].repeat(n_particles, 1), self.z_emb(z)], 1)
                prior_h = self.z_decoder(feed, prior_h)
                prior_logits = F.log_softmax(self.project_z(prior_h[0]), 1)
                # Feed the next word into the RNN.
                h = self.hidden_rnn(x_emb[i].repeat(n_particles, 1), h)

        if n_particles != 1:
            loss = -log_sum_exp(-loss.view(n_particles, batch_sz),
                                0) + math.log(n_particles)
            # Per-step importance estimate; only an approximation of the true NLL.
            NLL = -log_sum_exp(
                -nlls.view(seq_len, n_particles, batch_sz), 1) + math.log(
                    n_particles)
        else:
            NLL = nlls

        return loss.sum(), NLL.sum(), (seq_len * batch_sz), 0
Exemplo n.º 29
0
    def sampled_filter(self, input, args, n_particles, emb, hidden_states):
        """Particle-filter bound for the discrete latent chain, using sampled z.

        Args:
            input: token indices; ``input.size()`` gives ``(seq_len, batch_sz)``.
            args: namespace; ``args.filter`` enables ESS-triggered resampling.
            n_particles: number of particles per batch element.
            emb: not used by this method.
            hidden_states: per-step encoder outputs; replicated ``n_particles``
                times along dim 1.

        Returns:
            ``(-loss.sum(), nlls.sum(), seq_len * batch_sz * n_particles, 0)``.

        Side effect: when ``self.training`` is set, this backpropagates
        ``-loss.sum() / (seq_len * batch_sz * n_particles)`` with
        ``retain_graph=True`` before returning.
        """
        seq_len, batch_sz = input.size()
        T = F.log_softmax(self.T, 0)  # NOTE: in log-space
        pi = F.log_softmax(self.pi, 0)  # NOTE: in log-space
        emit = self.calc_emit()

        # Replicate once per particle along dim 1; particle p of batch
        # element b then lives at flat index p * batch_sz + b, matching
        # alpha.view(n_particles, batch_sz) below.
        hidden_states = hidden_states.repeat(1, n_particles, 1)

        nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
        loss = 0

        accumulated_weights = -math.log(n_particles)  # log w_{t-1}, uniform at start
        resamples = 0

        # Log-probabilities of the initial-state prior.
        prior_logits = pi.unsqueeze(0).expand(batch_sz * n_particles,
                                              self.z_dim)

        for i in range(seq_len):
            # The approximate posterior comes from the encoder states.
            logits = self.logits(hidden_states[i])

            if not self.training:
                # Hard one-hot samples at evaluation time (this is crucial!).
                p = OneHotCategorical(logits=prior_logits)
                q = OneHotCategorical(logits=logits)
                z = q.sample()
            else:
                p = RelaxedOneHotCategorical(temperature=self.temp_prior,
                                             logits=prior_logits)
                q = RelaxedOneHotCategorical(temperature=self.temp,
                                             logits=logits)
                z = q.rsample()

            # z-weighted emission score of the observed token (differentiable
            # w.r.t. z); rows of ``emit`` are presumably log emission
            # probabilities — confirm against calc_emit.
            emission = F.embedding(input[i].repeat(n_particles), emit)
            NLL = -(emission * z).sum(1)
            nlls[i] = NLL.data

            # Incremental importance weight (`reweight`, page 4).
            f_term = p.log_prob(z)  # prior
            r_term = q.log_prob(z)  # proposal
            alpha = -NLL + (f_term - r_term)

            wa = accumulated_weights + alpha.view(n_particles, batch_sz)

            Z = log_sum_exp(wa, dim=0)  # line 7
            loss += Z  # line 8
            accumulated_weights = wa - Z  # line 9

            if args.filter:
                probs = accumulated_weights.data.exp()
                probs += 0.01
                probs = probs / probs.sum(0, keepdim=True)  # [n_particles, batch_sz]
                effective_sample_size = 1. / probs.pow(2).sum(0)

                # Resample every batch element if ANY element's ESS ratio is low.
                if ((effective_sample_size / n_particles) < 0.3).sum() > 0:
                    resamples += 1
                    # ancestors[b, k]: old particle copied into slot k of element b.
                    ancestors = torch.multinomial(probs.transpose(0, 1),
                                                  n_particles, True)

                    # Old particle a of element b sits at flat index
                    # a * batch_sz + b; build indices in [slot, batch] order
                    # so slot k of element b lands at k * batch_sz + b.
                    batch_idx = torch.arange(batch_sz, dtype=torch.long,
                                             device=ancestors.device)
                    unrolled_idx = (ancestors.t() * batch_sz
                                    + batch_idx.unsqueeze(0)).view(-1)
                    z = torch.index_select(z, 0, unrolled_idx)

                    # Weights are uniform again after resampling.
                    accumulated_weights = -math.log(n_particles)

            if i != seq_len - 1:
                # Next prior in log-probability space.
                # NOTE(review): z is a (relaxed) one-hot vector, not a
                # log-probability, yet it is added to the log-space T here. If
                # the intent is log(sum_j exp(T[i, j]) * z_j), this should use
                # z.log() — confirm before changing.
                prior_logits = log_sum_exp(T.unsqueeze(0) + z.unsqueeze(1), 2)

        if self.training:
            (-loss.sum() /
             (seq_len * batch_sz * n_particles)).backward(retain_graph=True)
        return -loss.sum(), nlls.sum(), seq_len * batch_sz * n_particles, 0
    def inference(
        self,
        x,
        y=None,
        temperature=None,
        n_samples=1,
        reparam=True,
        encoder_key="default",
        counts=None,
    ):
        """Sample the latents and score them under the generative and variational models.

        Sampling order: z1 ~ q(z1 | x), c ~ q(c | z1) (or the observed label
        ``y``), z2 ~ q(z2 | z1, c). Generative terms: p(z2) (standard normal),
        p(c) (``self.y_prior``), p(z1 | z2, c) and p(x | z1) (Bernoulli).

        Shapes: latents are ``(n_samples, n_batch, n_latent)`` and per-sample
        log densities are ``(n_samples, n_batch)``.

        Args:
            x: observations; ``len(x)`` is the batch size. They are scored with
                a Bernoulli likelihood, so presumably binary.
            y: optional integer labels of length ``n_batch``. If None, labels
                are sampled from q(c | z1).
            temperature: relaxation temperature for label sampling when
                ``reparam`` is True. It is required whenever ``counts`` is None.
            n_samples: number of importance samples of z1.
            reparam: use reparameterized (relaxed) sampling.
            encoder_key: selects the encoder/classifier in the keyed modules.
            counts: if given, delegate to ``self.inference_defensive_sampling``.

        Returns:
            dict of latents, distribution parameters and log-density terms.
            ``log_ratio = generative_density - variational_density``.
            ``variational_density`` excludes ``log_qc_z1``, which is returned
            separately.

        Raises:
            ValueError: if ``temperature`` is None.
        """
        if temperature is None:
            raise ValueError(
                "Please provide a temperature for the relaxed OneHot distribution"
            )

        if counts is not None:
            return self.inference_defensive_sampling(
                x=x, y=y, temperature=temperature, counts=counts
            )
        n_cat = self.n_labels
        n_batch = len(x)

        # Z1 | X
        q_z1 = self.encoder_z1[encoder_key](
            x, n_samples=n_samples, reparam=reparam, squeeze=False
        )
        qz1_m = q_z1["q_m"]
        qz1_v = q_z1["q_v"]
        z1 = q_z1["latent"]
        assert z1.dim() == 3
        log_qz1_x = q_z1["dist"].log_prob(z1)
        dfs = q_z1.get("df", None)
        if q_z1["sum_last"]:
            log_qz1_x = log_qz1_x.sum(-1)

        # C | Z1
        qc_z1 = self.classifier[encoder_key](z1)
        log_qc_z1 = qc_z1.log()
        qc_z1_all_probas = qc_z1
        if y is None:
            if reparam:
                cat_dist = RelaxedOneHotCategorical(
                    temperature=temperature, probs=qc_z1
                )
                ys_probs = cat_dist.rsample()
            else:
                cat_dist = OneHotCategorical(probs=qc_z1)
                ys_probs = cat_dist.sample()
            # Harden the sample to a one-hot vector at its largest entry.
            ys = (ys_probs == ys_probs.max(-1, keepdim=True).values).float()
            y_int = ys.argmax(-1)
        else:
            # One-hot encode the observed labels on the latents' device
            # (previously hard-coded to torch.cuda.FloatTensor, which fails on CPU).
            ys = torch.zeros(n_batch, n_cat, device=z1.device)
            ys.scatter_(1, y.view(-1, 1), 1)
            ys = ys.view(1, n_batch, n_cat).expand(n_samples, n_batch, n_cat)
            y_int = y.view(1, -1).expand(n_samples, n_batch)
        log_pc = self.y_prior.log_prob(y_int)
        assert y_int.unsqueeze(-1).shape == (n_samples, n_batch, 1), y_int.shape
        # Keep only the probability of the selected label.
        log_qc_z1 = torch.gather(log_qc_z1, dim=-1, index=y_int.unsqueeze(-1)).squeeze(
            -1
        )
        qc_z1 = torch.gather(qc_z1, dim=-1, index=y_int.unsqueeze(-1)).squeeze(-1)
        assert qc_z1.shape == (n_samples, n_batch)
        pc = log_pc.exp()

        # Z2 | Z1, C
        z1_y = torch.cat([z1, ys], dim=-1)
        q_z2_z1 = self.encoder_z2_z1[encoder_key](z1_y, n_samples=1, reparam=reparam)
        z2 = q_z2_z1["latent"]
        qz2_z1_m = q_z2_z1["q_m"]
        qz2_z1_v = q_z2_z1["q_v"]
        log_qz2_z1 = q_z2_z1["dist"].log_prob(z2)
        if q_z2_z1["sum_last"]:
            log_qz2_z1 = log_qz2_z1.sum(-1)

        # Generative model: p(z1 | z2, c), p(z2), p(x | z1).
        z2_y = torch.cat([z2, ys], dim=-1)
        pz1_z2m, pz1_z2_v = self.decoder_z1_z2(z2_y)
        log_pz1_z2 = Normal(pz1_z2m, pz1_z2_v.sqrt()).log_prob(z1).sum(-1)

        log_pz2 = Normal(torch.zeros_like(z2), torch.ones_like(z2)).log_prob(z2).sum(-1)

        px_z_loc = self.x_decoder(z1)
        log_px_z = Bernoulli(px_z_loc).log_prob(x).sum(-1)
        generative_density = log_pz2 + log_pc + log_pz1_z2 + log_px_z
        variational_density = log_qz1_x + log_qz2_z1
        log_ratio = generative_density - variational_density

        variables = dict(
            z1=z1,
            ys=ys,
            z2=z2,
            qz1_m=qz1_m,
            qz1_v=qz1_v,
            qz2_z1_m=qz2_z1_m,
            qz2_z1_v=qz2_z1_v,
            pz1_z2m=pz1_z2m,
            pz1_z2_v=pz1_z2_v,
            px_z_m=px_z_loc,
            log_qz1_x=log_qz1_x,
            qc_z1=qc_z1,
            log_qc_z1=log_qc_z1,
            log_qz2_z1=log_qz2_z1,
            log_pz2=log_pz2,
            log_pc=log_pc,
            pc=pc,
            log_pz1_z2=log_pz1_z2,
            log_px_z=log_px_z,
            generative_density=generative_density,
            variational_density=variational_density,
            log_ratio=log_ratio,
            qc_z1_all_probas=qc_z1_all_probas,
            df=dfs,
        )
        return variables