Example 1
import torch
from torch.distributions import OneHotCategorical, RelaxedOneHotCategorical


def cat_softmax(probs, mode, tau=1, hard=False, dim=-1):
    """Sample a (relaxed) one-hot vector from `probs` under the given mode."""
    if mode in ('REINFORCE', 'SCST'):
        # Plain categorical sample for policy-gradient training; the entropy
        # is returned alongside for an entropy-regularization bonus.
        cat_distr = OneHotCategorical(probs=probs)
        return cat_distr.sample(), cat_distr.entropy()
    elif mode == 'GUMBEL':
        # Gumbel-Softmax relaxation: rsample() keeps the sample differentiable.
        cat_distr = RelaxedOneHotCategorical(tau, probs=probs)
        y_soft = cat_distr.rsample()
    else:
        raise ValueError(f"unknown mode: {mode!r}")

    if hard:
        # Straight-through estimator: the forward value is a hard one-hot
        # vector, but gradients flow through the soft sample y_soft.
        index = y_soft.max(dim, keepdim=True)[1]
        y_hard = torch.zeros_like(probs).scatter_(dim, index, 1.0)
        ret = y_hard - y_soft.detach() + y_soft
    else:
        # Reparametrization trick: use the soft sample directly.
        ret = y_soft
    return ret, ret
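
A minimal usage sketch for cat_softmax (the batch size, vocabulary size, temperature, and dummy loss below are assumptions for illustration, not from the original):

import torch

logits = torch.randn(4, 10, requires_grad=True)   # assumed: batch of 4, vocab of 10
probs = torch.softmax(logits, dim=-1)

sample, entropy = cat_softmax(probs, mode='REINFORCE')   # hard one-hot sample + entropy
y, _ = cat_softmax(probs, mode='GUMBEL', tau=0.5, hard=True)
loss = (y * torch.randn_like(y)).sum()   # dummy downstream loss
loss.backward()                          # straight-through: gradients reach logits
print(sample.shape, y.shape, logits.grad.abs().sum() > 0)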
Example 2
    def forward(self, x):
        # Five-layer MLP encoder (fc1..fc5, act, and NG1 come from the class).
        h1 = self.act(self.fc1(x))
        h2 = self.act(self.fc2(h1))
        h3 = self.act(self.fc3(h2))
        h4 = self.act(self.fc4(h3))
        out = self.fc5(h4)
        # Split the logits into two symbol slots and normalize each slot.
        probs1 = F.softmax(out[:, :NG1], dim=1)
        probs2 = F.softmax(out[:, NG1:], dim=1)
        # Sample one one-hot symbol per slot (non-differentiable).
        distr1 = OneHotCategorical(probs=probs1)
        distr2 = OneHotCategorical(probs=probs2)
        msg_oht1 = distr1.sample()
        msg_oht2 = distr2.sample()

        # Joint log-probability of the sampled two-symbol message, stashed on
        # the module for a policy-gradient loss computed elsewhere.
        self.get_log_probs = torch.log((probs1 * msg_oht1).sum(1)) + torch.log(
            (probs2 * msg_oht2).sum(1))
        # Note: only the second slot's entropy is stored here.
        self.get_entropy = distr2.entropy()
        # Convert the one-hot samples to integer symbol indices.
        msg1 = msg_oht1.argmax(1)
        msg2 = msg_oht2.argmax(1)
        msgs_value = torch.cat((msg1.unsqueeze(1), msg2.unsqueeze(1)), dim=1)
        return out, msgs_value
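
One detail worth verifying: the masked-sum expression torch.log((probs * msg_oht).sum(1)) is exactly the log-probability of the sampled one-hot symbol. A self-contained check (the shapes below are arbitrary assumptions):

import torch
from torch.distributions import OneHotCategorical

probs = torch.softmax(torch.randn(4, 5), dim=1)   # assumed: batch of 4, 5 symbols
distr = OneHotCategorical(probs=probs)
oht = distr.sample()
manual = torch.log((probs * oht).sum(1))          # the expression used above
assert torch.allclose(manual, distr.log_prob(oht))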
Example 3
    def forward(self, tgt_x):
        batch_size = tgt_x.shape[0]
        # Embed the target and use it as the initial LSTM hidden/cell state.
        tgt_hid = self.x_to_embd(tgt_x)
        lstm_input = torch.zeros((batch_size, NG1), device=tgt_x.device)
        lstm_hid = tgt_hid.squeeze(1)
        lstm_cell = tgt_hid.squeeze(1)
        msgs = []
        msgs_value = []
        logits = []
        log_probs = 0.

        # Emit a two-symbol message, one symbol per LSTM step.
        for _ in range(2):
            lstm_hid, lstm_cell = self.lstm(lstm_input, (lstm_hid, lstm_cell))
            logit = self.out_layer(lstm_hid)
            logits.append(logit)
            probs = nn.functional.softmax(logit, dim=1)
            if self.training:
                # Stochastic one-hot sample for policy-gradient training.
                cat_distr = OneHotCategorical(probs=probs)
                msg_oht, entropy = cat_distr.sample(), cat_distr.entropy()
                self.get_entropy = entropy
            else:
                # Greedy decoding at evaluation time.
                msg_oht = nn.functional.one_hot(
                    torch.argmax(probs, dim=1),
                    num_classes=self.out_size).float()
            # Accumulate the log-probability of the sampled symbol.
            log_probs += torch.log((probs * msg_oht).sum(1))
            msgs.append(msg_oht)
            msgs_value.append(msg_oht.argmax(1))
            # Feed the sampled symbol back in as the next step's input.
            lstm_input = msg_oht

        msgs = torch.stack(msgs)  # stacked one-hot symbols; unused below
        msgs_value = torch.stack(msgs_value).transpose(0, 1)
        logits = torch.stack(logits)
        logits = logits.transpose(0, 1).reshape(batch_size, -1)

        self.get_log_probs = log_probs
        return logits, msgs_value
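
Downstream, the accumulated get_log_probs typically drives a REINFORCE-style update. A self-contained toy version of that pattern (the linear policy, two-symbol loop length, and random reward are all invented for illustration, not taken from the original training code):

import torch
from torch.distributions import OneHotCategorical

torch.manual_seed(0)
logits = torch.zeros(4, 5, requires_grad=True)    # assumed: batch of 4, 5 symbols
log_probs = 0.
for _ in range(2):                                # two-symbol message, as above
    probs = torch.softmax(logits, dim=1)
    msg_oht = OneHotCategorical(probs=probs).sample()
    log_probs = log_probs + torch.log((probs * msg_oht).sum(1))

reward = torch.rand(4)                            # hypothetical per-sample reward
loss = -(reward * log_probs).mean()               # policy-gradient surrogate
loss.backward()
print(logits.grad.shape)                          # torch.Size([4, 5])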