Example #1
    def sample(self, sample_shape=torch.Size([])):
        """
        :param ~torch.Size sample_shape: Sample shape; the last dimension must
            be ``num_steps`` and the shape must be broadcastable to
            ``(batch_size, num_steps)``. ``batch_size`` must be an int, not a
            tuple.
        """
        # shape: batch_size x num_steps x categorical_size
        shape = broadcast_shape(
            torch.Size(list(self.batch_shape) + [1, 1]),
            torch.Size(list(sample_shape) + [1]),
            torch.Size((1, 1, self.event_shape[-1])),
        )
        # state: batch_size x state_dim
        state = OneHotCategorical(logits=self.initial_logits).sample()
        # sample: batch_size x num_steps x categorical_size
        sample = torch.zeros(shape)
        for i in range(shape[-2]):
            # batch_size x 1 x state_dim @
            # batch_size x state_dim x categorical_size
            obs_logits = torch.matmul(state.unsqueeze(-2),
                                      self.observation_logits).squeeze(-2)
            sample[:, i, :] = OneHotCategorical(logits=obs_logits).sample()
            # batch_size x 1 x state_dim @
            # batch_size x state_dim x state_dim
            trans_logits = torch.matmul(state.unsqueeze(-2),
                                        self.transition_logits).squeeze(-2)
            state = OneHotCategorical(logits=trans_logits).sample()

        return sample
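
For context, here is a minimal, self-contained sketch of the ancestral-sampling loop above with concrete toy shapes (the batch size, step count, and dimensions are made up for illustration):

import torch
from torch.distributions import OneHotCategorical

# hypothetical toy sizes: 4 chains, 5 steps, 3 hidden states, 2 observed symbols
batch_size, num_steps, state_dim, obs_dim = 4, 5, 3, 2
initial_logits = torch.zeros(batch_size, state_dim)
transition_logits = torch.zeros(batch_size, state_dim, state_dim)
observation_logits = torch.zeros(batch_size, state_dim, obs_dim)

state = OneHotCategorical(logits=initial_logits).sample()  # (4, 3)
sample = torch.zeros(batch_size, num_steps, obs_dim)
for i in range(num_steps):
    obs_logits = torch.matmul(state.unsqueeze(-2), observation_logits).squeeze(-2)
    sample[:, i, :] = OneHotCategorical(logits=obs_logits).sample()
    trans_logits = torch.matmul(state.unsqueeze(-2), transition_logits).squeeze(-2)
    state = OneHotCategorical(logits=trans_logits).sample()
print(sample.shape)  # torch.Size([4, 5, 2])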
Example #2
    def test_one_hot_categorical_1d(self):
        p = Variable(torch.Tensor([0.1, 0.2, 0.3]), requires_grad=True)
        self.assertEqual(OneHotCategorical(p).sample().size(), (3,))
        self.assertEqual(OneHotCategorical(p).sample((2, 2)).size(), (2, 2, 3))
        self.assertEqual(OneHotCategorical(p).sample_n(1).size(), (1, 3))
        self._gradcheck_log_prob(OneHotCategorical, (p,))
        self.assertRaises(NotImplementedError, OneHotCategorical(p).rsample)
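
The test above uses the legacy Variable and sample_n APIs. A rough modernized equivalent on current PyTorch (Variable is merged into Tensor, and sample((n,)) replaces sample_n) might look like:

import torch
from torch.distributions import OneHotCategorical

p = torch.tensor([0.1, 0.2, 0.3], requires_grad=True)
d = OneHotCategorical(p)                 # probs are renormalized internally
assert d.sample().size() == (3,)
assert d.sample((2, 2)).size() == (2, 2, 3)
assert d.sample((1,)).size() == (1, 3)   # replaces the deprecated sample_n(1)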
Example #3
    def get_random_actions(self, obs, available_actions=None):
        batch_size = obs.shape[0]
        if available_actions is not None:
            logits = torch.ones(batch_size, self.act_dim)
            random_actions = avail_choose(logits, available_actions)
            random_actions = random_actions.sample()
            random_actions = make_onehot(random_actions, batch_size,
                                         self.act_dim).cpu().numpy()
        else:
            if self.discrete_action:
                if self.multidiscrete:
                    random_actions = [
                        OneHotCategorical(logits=torch.ones(
                            batch_size, self.act_dim[i])).sample().numpy()
                        for i in range(len(self.act_dim))
                    ]
                    random_actions = np.concatenate(random_actions, axis=-1)
                else:
                    random_actions = OneHotCategorical(logits=torch.ones(
                        batch_size, self.act_dim)).sample().numpy()
            else:
                random_actions = np.random.uniform(self.act_space.low,
                                                   self.act_space.high,
                                                   size=(batch_size,
                                                         self.act_dim))

        return random_actions
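
avail_choose and make_onehot are project-specific helpers. A minimal sketch of the same masked uniform sampling using only torch.distributions (masking is done by pushing unavailable logits to a large negative value; all names here are illustrative) could be:

import torch
from torch.distributions import OneHotCategorical

batch_size, act_dim = 8, 5
available_actions = torch.randint(0, 2, (batch_size, act_dim)).bool()
available_actions[:, 0] = True                 # ensure one legal action per row
logits = torch.zeros(batch_size, act_dim).masked_fill(~available_actions, -1e10)
random_actions = OneHotCategorical(logits=logits).sample()  # one-hot, legal only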
Example #4
    def rsample(self, sample_shape=torch.Size()):
        if len(sample_shape) > 0:
            cat = OneHotCategorical(logits=self.logits).sample(
                sample_shape=sample_shape)
        else:
            cat = OneHotCategorical(logits=self.logits.transpose(1, -1))
            cat = cat.sample(sample_shape=sample_shape).transpose(1, -1)
        cat = cat[:, :, None]
        loc = (self.locs * cat).sum(dim=1)
        scale = (self.scales * cat).sum(dim=1)
        dist = Normal(loc, scale)
        # use rsample so gradients flow through loc and scale
        return dist.rsample()
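
The one-hot indicator selects one mixture component per batch element; multiplying and summing over the component axis picks that component's parameters. A self-contained sketch of the selection step (with made-up shapes: batch 2, K=3 components, 4-dim events):

import torch
from torch.distributions import Normal, OneHotCategorical

locs = torch.randn(2, 3, 4)
scales = torch.rand(2, 3, 4) + 0.1
logits = torch.randn(2, 3)

cat = OneHotCategorical(logits=logits).sample()[:, :, None]  # (2, 3, 1)
loc = (locs * cat).sum(dim=1)                                # (2, 4)
scale = (scales * cat).sum(dim=1)
x = Normal(loc, scale).rsample()  # reparameterized in loc/scale, not in cat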
Example #5
    def test_one_hot_categorical_2d(self):
        probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
        probabilities_1 = [[1.0, 0.0], [0.0, 1.0]]
        p = Variable(torch.Tensor(probabilities), requires_grad=True)
        s = Variable(torch.Tensor(probabilities_1), requires_grad=True)
        self.assertEqual(OneHotCategorical(p).sample().size(), (2, 3))
        self.assertEqual(OneHotCategorical(p).sample(sample_shape=(3, 4)).size(), (3, 4, 2, 3))
        self.assertEqual(OneHotCategorical(p).sample_n(6).size(), (6, 2, 3))
        self._gradcheck_log_prob(OneHotCategorical, (p,))

        dist = OneHotCategorical(p)
        x = dist.sample()
        self.assertEqual(dist.log_prob(x), Categorical(p).log_prob(x.max(-1)[1]))
Example #6
    def setUp(self) -> None:
        self.test_probs = torch.tensor([[0.3, 0.2, 0.4, 0.1, 0.25, 0.5, 0.25, 0.3, 0.4, 0.1, 0.1, 0.1],
                                        [0.2, 0.3, 0.1, 0.4, 0.5, 0.3, 0.2, 0.2, 0.3, 0.2, 0.2, 0.1]])
        self.test_sections = (4, 3, 5)

        self.test_actions = torch.tensor([[0., 0., 1., 0., 0., 1., 0., 0., 1., 0., 0., 0.],
                                          [0., 0., 0., 1., 1., 0., 0., 0., 0., 1., 0., 0.]]).long()

        self.test_sected_actions = torch.split(self.test_actions, self.test_sections, dim=-1)

        self.test_multi_onehot_categorical = MultiOneHotCategorical(self.test_probs, self.test_sections)

        self.test_onehot_categorical1 = OneHotCategorical(self.test_probs[:, :4])
        self.test_onehot_categorical2 = OneHotCategorical(self.test_probs[:, 4:7])
        self.test_onehot_categorical3 = OneHotCategorical(self.test_probs[:, 7:])
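
Assuming MultiOneHotCategorical factorizes over the sections, its joint log-probability should equal the sum of the per-section OneHotCategorical log-probabilities, which is what this fixture sets up. A self-contained sketch of that identity:

import torch
from torch.distributions import OneHotCategorical

probs = torch.rand(2, 12)
sections = (4, 3, 5)
parts = torch.split(probs, list(sections), dim=-1)
actions = [OneHotCategorical(p).sample() for p in parts]
joint_log_prob = sum(OneHotCategorical(p).log_prob(a)
                     for p, a in zip(parts, actions))  # shape (2,)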
Example #7
    def sample_angles(self, mean, concentration, factor, weights):
        assert (weights >= 0).all(), "mixture weights must be non-negative"
        weights = OneHotCategorical(probs=weights).sample()
        factor_0 = factor[torch.arange(factor.size(0)), 0,
                          weights.argmax(dim=-1)]
        factor_1 = factor[torch.arange(factor.size(0)), 1,
                          weights.argmax(dim=-1)]
        factor_2 = factor[torch.arange(factor.size(0)), 2,
                          weights.argmax(dim=-1)]
        angles_0 = self.mixture_of_von_mises(mean[:, 0], concentration[:, 0],
                                             weights)
        mean[:, 1] = mean[:, 1] + (factor_0 * angles_0).unsqueeze(-1)

        angles_1 = self.mixture_of_von_mises(mean[:, 1], concentration[:, 1],
                                             weights)
        mean[:, 2] = mean[:, 2] + (factor_1 * angles_0 +
                                   factor_2 * angles_1).unsqueeze(-1)

        angles_2 = self.mixture_of_von_mises(mean[:, 2], concentration[:, 2],
                                             weights)

        angles = torch.cat(
            (angles_0[:, None], angles_1[:, None], angles_2[:, None]), dim=1)
        return angles
Example #8
    def forward(self,
                inp: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, None]:
        """
        Returns a sample of the policy on the input with the mean and log
        probability of the sample.

        Args:
            inp: The input tensor to put through the network.categorie

        Returns:
            A multi-categorical distribution of the network.
        """
        linear = self.linear(inp)

        value = self.value(linear)
        value = value.view(*value.shape[:-1], -1, self.num_classes)

        probs = self.probs(value)
        dist = OneHotCategorical(probs)

        sample = dist.sample()
        log_prob = dist.log_prob(sample)

        # Straight through gradient trick
        sample = sample + probs - probs.detach()

        mean = torch.argmax(probs, dim=-1)

        return sample, log_prob, None
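
The straight-through line adds probs and subtracts its detached copy, so the forward value is unchanged (still the discrete one-hot sample) while gradients flow through probs. A minimal sketch of the trick in isolation:

import torch
from torch.distributions import OneHotCategorical

logits = torch.randn(4, 6, requires_grad=True)
probs = torch.softmax(logits, dim=-1)
sample = OneHotCategorical(probs).sample()   # discrete, no gradient
sample = sample + probs - probs.detach()     # same value, now differentiable
sample.sum().backward()
assert logits.grad is not None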
Example #9
    def forward(self, x):
        params = self.network(x)

        if self.beta not in (None, 0., np.inf):
            return RelaxedOneHotCategorical(temperature=1. / self.beta,
                                            logits=params)

        return OneHotCategorical(logits=params)
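
Unlike OneHotCategorical, RelaxedOneHotCategorical supports rsample, so gradients can flow through the (continuous) sample. A minimal sketch:

import torch
from torch.distributions import RelaxedOneHotCategorical

params = torch.randn(5, requires_grad=True)
dist = RelaxedOneHotCategorical(temperature=torch.tensor(0.5), logits=params)
y = dist.rsample()       # continuous relaxation of a one-hot sample
y.sum().backward()       # gradients reach params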
Example #10
    def forward(self, e0, eg):
        """Returns the logits of a OneHotCategorical distribution."""
        output = AttrDict()
        output.seq_len_logits = remove_spatial(self.p(e0, eg))
        output.seq_len_pred = OneHotCategorical(logits=output.seq_len_logits)

        return output
Example #11
def main(n_skills, path):
    '''
    Train DIAYN on MountainCar with a uniform skill prior.

    :param n_skills: number of skills for the uniform OneHotCategorical prior
    :param path: output directory for plots and checkpoints
    :return: None
    '''
    env = gym.make('MountainCar-v0')
    alpha = .1
    gamma = .9
    prior = OneHotCategorical(torch.ones((1, n_skills)))
    hidden_sizes = {s: [30, 30] for s in ("actor", "discriminator", "critic")}
    trainer = DIAYN(env, prior, hidden_sizes, alpha=alpha, gamma=gamma)

    path_plot = path + "plot_diayn\\"
    if not os.path.exists(path_plot):
        os.makedirs(path_plot)

    path_save = path + "save_diayn\\"
    if not os.path.exists(path_save):
        os.makedirs(path_save)

    for k in range(0, 1):
        iter_ = 200
        trainer.train(iter_)
        trainer.plot_rewards(path_plot + "diyan_train_rewards_" +
                             str((k + 1) * iter_))
        #input("Press Enter to see skills")
        mountaincar(
            trainer, path_plot + "diyan_train_trajectoires_" + str(
                (k + 1) * iter_))
        # plt.show() not needed since plt.ion() is called in diayn.py
        # plt.pause(1)
        trainer.save(path_save)
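
OneHotCategorical normalizes the all-ones tensor to a uniform distribution, so the prior above draws a uniformly random one-hot skill indicator of shape (1, n_skills). For example:

import torch
from torch.distributions import OneHotCategorical

prior = OneHotCategorical(torch.ones((1, 4)))
skill = prior.sample()   # e.g. tensor([[0., 0., 1., 0.]])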
Example #12
    def sample(self, letter, race, gender):
        """Sample name from start letter, race and gender"""
        with torch.no_grad():
            assert letter in self.vocab.start_letters, "Invalid letter"
            assert race in self.races.available_races, "Invalid race"
            assert gender in self.genders.available_genders, "Invalid gender"

            # Prepare inputs
            letter_t, race_t, gender_t = self._transform_input(letter, race, gender)
            letter_t, race_t, gender_t = self._expand_dims(letter_t, race_t, gender_t)

            # Merge all input tensors
            input = torch.cat([letter_t, race_t, gender_t], 2)
            outputs = [letter]

            # Initialize hidden states
            hx, cx = self.model.init_states(batch_size=1, device=self.device)

            while True:
                output, hx, cx = self.model(input, hx, cx, lengths=torch.tensor([1]))

                sample = OneHotCategorical(logits=output).sample()
                index = torch.argmax(sample)
                char = self.vocab.get_char(index.item())

                if char == '.' or len(outputs) == self.max_len:
                    break

                outputs.append(char)
                input = torch.cat([sample, race_t, gender_t], 2)

            name = ''.join(map(str, outputs))
            return name
Example #13
    def generate(self, num_samples):
        with torch.no_grad():
            print("_" * 20)
            for _ in range(num_samples):
                hx, cx = self.model.init_states(batch_size=1, device=self.device)

                letter, race, gender = self._init_random_input()
                letter_t, race_t, gender_t = self._transform_input(letter, race, gender)

                input = torch.cat([letter_t, race_t, gender_t], 1)
                outputs = [letter]

                while True:
                    output, hx, cx = self.model(input, hx, cx)

                    sample = OneHotCategorical(logits=output).sample()
                    index = torch.argmax(sample)
                    char = self.vocab.idx2char[index.item()]
                    outputs.append(char)

                    input = torch.cat([sample, race_t, gender_t], 1)

                    if char == '.' or len(outputs) == 50:
                        break

                print("Start letter: {}, Race: {}, Gender: {}".format(letter, race, gender))
                print("Generated sample: {}".format(''.join(map(str, outputs))))

            print("_" * 20)
Example #14
    def forward(self, zb, a):
        h1 = F.elu(self.lin1(zb))
        h2 = F.elu(self.lin2(h1))

        # forward pass for the binary and categorical covariates
        bin_out_dict = dict()

        # for each categorical variable
        for i in range(len(self.headnames)):
            # calculate the probability parameter
            p_a0 = self.binheads_a0[i](h2)
            p_a1 = self.binheads_a1[i](h2)
            dist_p_a0 = torch.sigmoid(p_a0)
            dist_p_a1 = torch.sigmoid(p_a1)
            # create distribution in dict
            if self.headnames[i] == 'BINARY':
                bin_out_dict[self.headnames[i]] = bernoulli.Bernoulli((1-a)*dist_p_a0 + a*dist_p_a1)
            else:
                bin_out_dict[self.headnames[i]] = OneHotCategorical((1-a)*dist_p_a0 + a*dist_p_a1)

        # forward the continuous vars through the right TAR head
        mu_a0 = self.mu_a0(h2)
        mu_a1 = self.mu_a1(h2)
        sigma_a0 = self.softplus(self.sigma_a0(h2))
        sigma_a1 = self.softplus(self.sigma_a1(h2))
        # floor sigma to prevent the variance of continuous vars collapsing to 0
        sigma_a0 = torch.clamp(sigma_a0, min=0.1)
        sigma_a1 = torch.clamp(sigma_a1, min=0.1)
        con_out = normal.Normal((1-a) * mu_a0 + a * mu_a1, (1-a)* sigma_a0 + a * sigma_a1)

        return con_out, bin_out_dict
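
With a binary treatment indicator a in {0, 1}, the expression (1 - a) * head_a0 + a * head_a1 selects one head per sample while staying differentiable. A tiny sketch of the selection:

import torch

a = torch.tensor([[0.], [1.]])        # treatment per sample
p_a0 = torch.tensor([[0.2], [0.2]])   # control-head output
p_a1 = torch.tensor([[0.9], [0.9]])   # treated-head output
p = (1 - a) * p_a0 + a * p_a1         # tensor([[0.2], [0.9]])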
Example #15
    def sampleiter(self, bs=1):
        """
        Ancestral sampling with probability tables.

        1 sample is a tensor (1, 2*N).
        A minibatch of samples is a tensor (bs, 2*N).
        1 variable is a tensor (bs, N).
        """
        while True:
            with torch.no_grad():
                pA = self.pAgt.softmax(dim=0).expand(bs, -1)
                a = OneHotCategorical(pA).sample()
                pB = torch.einsum("ij,bi->bj", self.pAtoBgt.softmax(dim=1), a)
                b = OneHotCategorical(pB).sample()
                s = torch.cat([a, b], dim=1)
            yield s
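
A self-contained version of one draw from this A -> B ancestral sampler, with made-up tables standing in for pAgt and pAtoBgt:

import torch
from torch.distributions import OneHotCategorical

N, bs = 3, 5
pAgt = torch.randn(N)            # logits over A
pAtoBgt = torch.randn(N, N)      # row i: logits over B given A = i
pA = pAgt.softmax(dim=0).expand(bs, -1)
a = OneHotCategorical(pA).sample()                           # (bs, N)
pB = torch.einsum("ij,bi->bj", pAtoBgt.softmax(dim=1), a)    # (bs, N)
b = OneHotCategorical(pB).sample()
s = torch.cat([a, b], dim=1)                                 # (bs, 2*N)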
Example #16
    def backward(ctx, grad_output):
        from torch.distributions import OneHotCategorical
        # ctx.saved_tensors is a tuple; unpack the single saved tensor
        (x,), dim = ctx.saved_tensors, ctx.dim
        if ctx.needs_input_grad[0]:
            return grad_output.unsqueeze(dim).mul(
                OneHotCategorical(logits=x.movedim(dim, -1)).sample().movedim(
                    -1, dim)), None
        return None, None
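
This backward routes the incoming gradient to a single index sampled from softmax(x) along dim. The forward pass is not shown; a hedged reconstruction, assuming the Function computes a max along dim (the forward here is hypothetical, only the backward comes from the snippet above):

import torch
from torch.distributions import OneHotCategorical

class StochasticMax(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, dim):
        # hypothetical forward; only the backward is taken from the example
        ctx.save_for_backward(x)
        ctx.dim = dim
        return x.max(dim).values

    @staticmethod
    def backward(ctx, grad_output):
        (x,), dim = ctx.saved_tensors, ctx.dim
        if ctx.needs_input_grad[0]:
            one_hot = OneHotCategorical(
                logits=x.movedim(dim, -1)).sample().movedim(-1, dim)
            return grad_output.unsqueeze(dim).mul(one_hot), None
        return None, None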
Example #17
    def __init__(self, normal_means, normal_stds, weights):
        self.num_gaussians = weights.shape[1]
        self.normal_means = normal_means
        self.normal_stds = normal_stds
        self.normal = MultivariateDiagonalNormal(normal_means, normal_stds)
        self.normals = [
            MultivariateDiagonalNormal(normal_means[:, :, i], normal_stds[:, :, i])
            for i in range(self.num_gaussians)
        ]
        self.weights = weights
        self.categorical = OneHotCategorical(self.weights[:, :, 0])
Example #18
    def forward(self, x):
        # For convenience we use torch.distributions to sample and compute the
        # values of interest; see https://pytorch.org/docs/stable/distributions.html
        # for details.
        probs = self.encode(x.view(-1, 784))
        m = OneHotCategorical(probs)
        action = m.sample()
        log_prob = m.log_prob(action)
        entropy = m.entropy()
        return self.decode(action), log_prob, entropy
Example #19
    def test_one_hot_categorical_shape(self):
        dist = OneHotCategorical(torch.Tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
        self.assertEqual(dist._batch_shape, torch.Size((3,)))
        self.assertEqual(dist._event_shape, torch.Size((2,)))
        self.assertEqual(dist.sample().size(), torch.Size((3, 2)))
        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
        self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3,)))
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
        self.assertEqual(dist.log_prob(dist.enumerate_support()).size(), torch.Size((2, 3)))
Example #20
    def get_random_actions_discrete(self, obs, available_actions=None):
        assert len(obs.shape) == 2, "No random actions on sequence"
        batch_size = obs.shape[0]
        if available_actions is not None:
            logits = torch.ones(batch_size, self.act_dim)
            random_actions = avail_choose(logits, available_actions)
            random_actions = random_actions.sample()
            random_actions = make_onehot(random_actions, batch_size, self.act_dim).cpu().numpy()
        else:
            if self.multidiscrete:
                random_actions = [OneHotCategorical(logits=torch.ones(batch_size, self.act_dim[i])).sample().numpy() for
                                  i in
                                  range(len(self.act_dim))]
                random_actions = np.concatenate(random_actions, axis=-1)
            else:
                random_actions = OneHotCategorical(logits=torch.ones(batch_size, self.act_dim)).sample().numpy()

        return random_actions
Example #21
    def reparameterize(self, p):
        if self.training:
            # At training time we sample from a relaxed Gumbel-Softmax
            # distribution. The samples are continuous, but as the temperature
            # decreases they get closer to one-hot Categorical samples.
            m = RelaxedOneHotCategorical(TEMPERATURE, p)
            return m.rsample()
        else:
            # At test time we sample from a Categorical distribution.
            m = OneHotCategorical(p)
            return m.sample()
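
A minimal sketch of the temperature intuition (note the direction: lower temperature means closer to one-hot):

import torch
from torch.distributions import RelaxedOneHotCategorical

p = torch.tensor([0.1, 0.6, 0.3])
hot = RelaxedOneHotCategorical(torch.tensor(0.05), probs=p).rsample()
soft = RelaxedOneHotCategorical(torch.tensor(5.0), probs=p).rsample()
# hot is nearly one-hot; soft is spread toward uniform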
Example #22
    def sampleiter(self, bs=1):
        """Ancestral sampling from conditional probability tables."""
        while True:
            h = []
            h.append(OneHotCategorical(torch.einsum(        "i->i",  self.table_asia_gt        ))      .sample((bs,)))
            h.append(OneHotCategorical(torch.einsum(    "ai,za->zi", self.table_tub_gt,    h[0]))      .sample())
            h.append(OneHotCategorical(torch.einsum(        "i->i",  self.table_smoke_gt       ))      .sample((bs,)))
            h.append(OneHotCategorical(torch.einsum(    "ai,za->zi", self.table_lung_gt,   h[2]))      .sample())
            h.append(OneHotCategorical(torch.einsum(    "ai,za->zi", self.table_bronc_gt,  h[2]))      .sample())
            h.append(OneHotCategorical(torch.einsum("bai,za,zb->zi", self.table_either_gt, h[1], h[3])).sample())
            h.append(OneHotCategorical(torch.einsum(    "ai,za->zi", self.table_xray_gt,   h[5]))      .sample())
            h.append(OneHotCategorical(torch.einsum("bai,za,zb->zi", self.table_dysp_gt,   h[4], h[5])).sample())
            yield torch.stack(h, dim=1)
Example #23
    def forward(self, x):
        x = x[:, None, :]
        x = self.deconv1(x)
        x = self.batch_norm1(x)
        x = self.deconv2(x)
        x = self.batch_norm2(x)
        x = self.deconv3(x)
        x = self.batch_norm3(x)
        x = self.deconv4(x)
        x = self.batch_norm4(x)[:, 0, :]
        params = self.mlp(x)
        return OneHotCategorical(logits=params)
Example #24
def test_fusion():
    # threshold only 1e-4 for float32 --> verify with float64
    torch.set_default_dtype(torch.float64)
    for distributions in [
        [
            Normal(loc=torch.rand(10), scale=torch.rand(10) + 1),
            Normal(loc=torch.rand(10), scale=torch.rand(10) + 1),
            Normal(loc=torch.rand(10), scale=torch.rand(10) + 1),
        ],
        [
            MultivariateNormal(
                loc=torch.rand(10),
                covariance_matrix=torch.diag(torch.rand(10)),
            ),
            MultivariateNormal(
                loc=torch.rand(10),
                covariance_matrix=torch.diag(torch.rand(10)),
            ),
            MultivariateNormal(
                loc=torch.rand(10),
                covariance_matrix=torch.diag(torch.rand(10)),
            ),
        ],
        [
            Bernoulli(logits=torch.rand(10)),
            Bernoulli(logits=torch.rand(10)),
            Bernoulli(logits=torch.rand(10)),
        ],
        [
            Categorical(logits=torch.rand(10)),
            Categorical(logits=torch.rand(10)),
            Categorical(logits=torch.rand(10)),
        ],
        [
            OneHotCategorical(logits=torch.rand(10)),
            OneHotCategorical(logits=torch.rand(10)),
            OneHotCategorical(logits=torch.rand(10)),
        ],
    ]:
        test_fusion_manually(distributions, threshold=1e-10)
Example #25
    def similarity(self, x, y):
        logits = x.log_softmax(dim=2)
        if self.hard:
            cat = OneHotCategorical(logits=y).sample()
        else:
            cat = y.softmax(dim=2)
        sim = (logits[None, :] * cat[:, None]).view(x.size(0), x.size(0), x.size(1), -1)
        sim = sim.sum(dim=-1)
        ind = torch.arange(sim.size(0), dtype=torch.long, device=sim.device)
        sim[ind, ind] = 0.0
        count = sim.size(0) ** 2 - sim.size(0)
        sim = sim.view(-1, x.size(1)).sum(dim=0) / count
        return sim
Example #26
def simulate_p_cond__rt_ch(p_cond__rt_ch, n_sample=1):
    """

    @param p_cond__rt_ch: [condition, frame, ch]
    @type p_cond__rt_ch: torch.Tensor
    @param ev:
    @return: p_ch_rt_sim[condition, frame, ch]
    @rtype: torch.Tensor
    """
    p_cond__rt_ch_sim = OneHotCategorical(
        p_cond__rt_ch.reshape([p_cond__rt_ch.shape[0], -1])).sample(
            [n_sample]).sum(0).reshape(p_cond__rt_ch.shape)

    return p_cond__rt_ch_sim
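
Summing n_sample one-hot draws produces a count histogram with the same shape as the probability table. In miniature:

import torch
from torch.distributions import OneHotCategorical

p = torch.rand(3, 4)                                # [condition, outcome]
counts = OneHotCategorical(p).sample([100]).sum(0)  # (3, 4); each row sums to 100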
Example #27
def generative_story(
    method: storch.method.Method, model: DiscreteVAE, data: torch.Tensor
):
    x = storch.denote_independent(data.view(-1, 784), 0, "data")

    # Encode data. Shape: (data, 2 * 10)
    q_logits = model.encode(x)
    # Shape: (data, 2, 10)
    q_logits = q_logits.reshape(-1, 2, 10)
    q = OneHotCategorical(probs=q_logits.softmax(dim=-1))
    # Sample from variational posterior
    z = method(q)

    prior = OneHotCategorical(probs=torch.ones_like(q.probs) / 10.0)
    # Shape: (data)
    KL_div = torch.distributions.kl_divergence(q, prior).sum(-1)
    storch.add_cost(KL_div, "kl-div")

    z_in = z.reshape(z.shape[:-2] + (2 * 10,))
    reconstruction = model.decode(z_in)
    bce = torch.nn.BCELoss(reduction="none")(reconstruction, x).sum(-1)
    # bce = torch.nn.BCELoss(reduction="sum")(reconstruction, x)
    storch.add_cost(bce, "reconstruction")
    return z
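
kl_divergence is registered for pairs of OneHotCategorical distributions, and the extra sum(-1) reduces over the two independent latent factors. The KL term in isolation:

import torch
from torch.distributions import OneHotCategorical, kl_divergence

q = OneHotCategorical(probs=torch.randn(8, 2, 10).softmax(dim=-1))
prior = OneHotCategorical(probs=torch.ones_like(q.probs) / 10.0)
kl = kl_divergence(q, prior).sum(-1)   # shape (8,)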
Example #28
    def given_states(self, states):
        """
        Distribution conditional on the state variable.

        :param ~torch.Tensor states: State trajectory. Must be
            integer-valued (long) and broadcastable to
            ``(batch_size, num_steps)``.
        """
        shape = broadcast_shape(
            list(self.batch_shape) + [1, 1],
            list(states.shape[:-1]) + [1, 1],
            [1, 1, self.observation_logits.shape[-1]],
        )
        states_index = states.unsqueeze(-1) * torch.ones(shape,
                                                         dtype=torch.long)
        obs_logits = self.observation_logits * torch.ones(shape)
        logits = torch.gather(obs_logits, -2, states_index)
        return OneHotCategorical(logits=logits)
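
When the observation logits carry no batch dimension, the same state-indexed lookup can be written as plain advanced indexing; the gather/broadcast dance above only matters once batch dimensions need aligning. A tiny sketch of the lookup itself:

import torch
from torch.distributions import OneHotCategorical

observation_logits = torch.randn(3, 5)   # state_dim x categorical_size
states = torch.randint(0, 3, (2, 7))     # batch x num_steps, integer states
logits = observation_logits[states]      # (2, 7, 5): per-step emission logits
dist = OneHotCategorical(logits=logits)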
Example #29
    def latent_prior_sample(self, y, n_batch, n_samples):
        n_cat = self.n_labels
        n_latent = self.n_latent

        u = Normal(
            torch.zeros(n_latent, device="cuda"), torch.ones(n_latent, device="cuda"),
        ).sample((n_samples, n_batch))

        if y is None:
            ys = OneHotCategorical(
                probs=(1.0 / n_cat) * torch.ones(n_cat, device="cuda")
            ).sample((n_samples, n_batch))
        else:
            ys = torch.cuda.FloatTensor(n_batch, n_cat)
            ys.zero_()
            ys.scatter_(1, y.view(-1, 1), 1)
            ys = ys.view(1, n_batch, n_cat).expand(n_samples, n_batch, n_cat)

        z2_y = torch.cat([u, ys], dim=-1)
        pz1_z2m, pz1_z2_v = self.decoder_z1_z2(z2_y)
        z = Normal(pz1_z2m, pz1_z2_v).sample()
        return dict(z1=z, z2=u, ys=ys)
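
The zero_/scatter_ pair builds one-hot rows from integer labels; in brief:

import torch

n_batch, n_cat = 4, 3
y = torch.randint(0, n_cat, (n_batch,))
ys = torch.zeros(n_batch, n_cat).scatter_(1, y.view(-1, 1), 1)  # one-hot rows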
Example #30
    def sampleiter(self, bs=1):
        """
        Ancestral sampling with an MLP.

        1 sample is a tensor (1, M, N).
        A minibatch of samples is a tensor (bs, M, N).
        1 variable is a tensor (bs, 1, N).
        """
        while True:
            with torch.no_grad():
                h = []   # hard (one-hot) samples, each (bs, 1, N)
                for i in range(self.M):
                    O = torch.zeros(bs, self.M - i, self.N)  # (bs, M-i, N)
                    v = torch.cat(h + [O], dim=1)            # (bs, M-i, N) + (bs, 1, N)*i
                    v = torch.einsum("hik,i,bik->bh", self.W0gt[i], self.gammagt[i], v)
                    v = v + self.B0gt[i].unsqueeze(0)
                    v = self.leaky_relu(v)
                    v = torch.einsum("oh,bh->bo", self.W1gt[i], v)
                    v = v + self.B1gt[i].unsqueeze(0)
                    v = v.softmax(dim=1).unsqueeze(1)
                    h.append(OneHotCategorical(v).sample())
                s = torch.cat(h, dim=1)
            yield s