Example #1
    def _ditribution(self, input):
        # Collapse the feature dimension to one score per object.
        logits = input.sum(-1)
        batch_size = logits.shape[0]
        # Flatten to (batch_size, num_objects).
        logits = logits.view(batch_size, -1)
        dist = modulars.GumbelCategorical(logits)
        return dist
Example #2
    def _ditribution(self, input, encoder, v_mask):
        # Encode the input into one logit per object.
        logits = encoder(input)
        batch_size = logits.shape[0]
        logits = logits.view(batch_size, -1)
        # Suppress padded objects with a large negative logit.
        logits.masked_fill_(v_mask == 0, -1e5)
        dist = modulars.GumbelCategorical(logits)
        return dist
Example #3
    def _ditribution(self, input):
        # logits = encoder(input)
        # Collapse the feature dimension to one score per object.
        logits = input.sum(-1)
        batch_size = logits.shape[0]
        logits = logits.view(batch_size, -1)
        # Push exactly-zero scores towards negligible probability.
        logits.masked_fill_(logits == 0, -1e3)
        dist = modulars.GumbelCategorical(logits)
        return dist
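The three _ditribution helpers above share one pattern: produce one logit per object, optionally mask invalid slots with a large negative value, and wrap the result in modulars.GumbelCategorical. The sketch below is not from the original repository; GumbelCategorical here is a hypothetical stand-in (assumed, as in the examples, to take a logits tensor, expose .logits, and return log-probabilities from .sample()) used only to show the pattern end to end on dummy tensors.

    import torch
    import torch.nn.functional as F

    class GumbelCategorical:
        """Hypothetical stand-in for modulars.GumbelCategorical."""

        def __init__(self, logits, tau=1.0):
            self.logits = logits
            self.tau = tau

        def sample(self):
            # Gumbel-softmax reparameterised sample, returned as log-probabilities
            # (Example #4 calls state['z'].sample().exp() in a comment).
            gumbel = -torch.log(-torch.log(torch.rand_like(self.logits) + 1e-20) + 1e-20)
            return F.log_softmax((self.logits + gumbel) / self.tau, dim=-1)

    def make_distribution(features, v_mask):
        # features: (batch, num_obj, feat_dim); v_mask: (batch, num_obj) with 1 = valid.
        logits = features.sum(-1)                       # one score per object
        logits = logits.view(features.shape[0], -1)     # (batch, num_obj)
        logits = logits.masked_fill(v_mask == 0, -1e5)  # suppress padded objects
        return GumbelCategorical(logits)

    dist = make_distribution(torch.randn(2, 5, 8), torch.ones(2, 5))
    print(dist.sample().shape)  # torch.Size([2, 5])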
Example #4
    def forward(self, hidden_entity, hidden_trans, object_list, e_label, e_s,
                e_e, e_mask, edge, r_mask, v_mask):
        batch_size, max_len = e_mask.shape
        hidden_h = object_list.mean(1)
        # Initial state: uniform logits over the state space (10. * zeros is still all zeros).
        logits0 = 10. * torch.zeros(batch_size, self.state_size).cuda().detach()
        z0 = modulars.GumbelCategorical(logits0)
        state = {'hidden_pre': object_list, 'hidden_h': hidden_h, 'z': z0}
        v_input = object_list

        z_ = z0
        Loss_time = []
        for i in range(max_len - 1):
            entity_emb = hidden_entity[i]
            trans_emb = hidden_trans[i]

            state_pre = state['z']

            q_input = e_label[:, i, :]

            # atten, state = self.entity_cell(entity_emb, state, v_input, q_out.unsqueeze(1))
            state = self.cell[0](entity_emb, v_input, v_mask)
            recon, l_forward, l_back = self._kl_divergence_recon(
                z_, state['z'], state_pre, state['z_back'], q_input,
                object_list, v_mask)

            z_ = self.trans_cell(trans_emb, edge, r_mask, state['z'], v_mask)
            v_input = state['hidden_pre']

            if i == 0:
                l_t = recon
            else:
                l_t = recon + l_forward  #+ l_back

            Loss_time.append(l_t.unsqueeze(-1))

        # Zero out padded time steps before summing the per-step losses.
        L_t = torch.cat(Loss_time, dim=1) * e_mask[:, :-1]  #/ (e_mask[:, :-1].sum(1) + 1).unsqueeze(-1).float()

        state = self.cell[-1](hidden_entity[-1], v_input, v_mask)
        recon, l_forward, l_back = self._kl_divergence_recon(
            z_, state['z'], state_pre, state['z_back'], q_input, object_list,
            v_mask)

        loss_time = torch.sum(L_t, dim=1).mean() + l_forward.mean()
        # h_out = state['hidden_pre']
        # h_out = torch.matmul(state['hidden_pre'].transpose(-1, -2), state['z'].sample().exp().unsqueeze(-1)).squeeze(-1)
        # Read out h_out as the soft-attention-weighted average of object features.
        node = F.softmax(state['z'].logits, dim=1)
        h_out = torch.matmul(state['hidden_pre'].transpose(-1, -2), node.unsqueeze(-1)).squeeze(-1)

        return loss_time, h_out
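For reference, the per-step loss bookkeeping in the loop above reduces to a simple pattern: stack the scalar losses along the time dimension, zero out padded steps with e_mask[:, :-1], then sum over time and average over the batch. A minimal sketch with assumed shapes (not code from the repository):

    import torch

    batch_size, max_len = 2, 5
    Loss_time = [torch.rand(batch_size, 1) for _ in range(max_len - 1)]
    e_mask = torch.tensor([[1., 1., 1., 0., 0.],
                           [1., 1., 1., 1., 1.]])
    L_t = torch.cat(Loss_time, dim=1) * e_mask[:, :-1]  # (batch, max_len - 1)
    loss_time = torch.sum(L_t, dim=1).mean()            # scalar
    print(loss_time.shape)  # torch.Size([])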
Example #5
    def forward(self, trans_mat, z):
        # trans_mat = self.linear_r(trans_mat)
        # q = self.linear_q(q)
        # q_ = q.unsqueeze(1).unsqueeze(1).expand(trans_mat.shape)
        # relation = self.linear_out(trans_mat * q_).squeeze(-1)

        # relation_mat = F.softmax(relation.masked_fill(r_mask == 0., -1e12), dim=1)
        # Propagate the current state logits through the (fixed) relation matrix.
        relation_mat = trans_mat
        logits = torch.matmul(relation_mat, z.logits.unsqueeze(-1)).squeeze(-1)
        z_ = modulars.GumbelCategorical(logits)

        return z_
Example #6
    def forward(self, q, trans_mat, r_mask, z, v_mask):
        trans_mat = self.linear_r(trans_mat)
        batch_size, num_obj, _ = r_mask.shape
        q = self.linear_q(q)
        q_ = q.unsqueeze(1).unsqueeze(1).expand(trans_mat.shape)
        relation = self.linear_out(trans_mat * q_).squeeze(-1)

        # Question-conditioned relation weights; masked pairs get zero probability.
        relation_mat = F.softmax(relation.masked_fill(r_mask == 0., -float('inf')), dim=1)

        logits = torch.matmul(relation_mat, z.logits.unsqueeze(-1)).squeeze(-1)
        z_ = modulars.GumbelCategorical(logits)

        return z_
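Examples #5 and #6 implement the transition step the same way: build a (soft) relation matrix over object pairs and multiply it into the current state's logits to get next-step logits. A self-contained sketch of just that matmul, with assumed shapes and a hand-made mask (not repository code):

    import torch
    import torch.nn.functional as F

    batch, num_obj = 2, 4
    relation = torch.rand(batch, num_obj, num_obj)
    r_mask = torch.ones(batch, num_obj, num_obj)
    r_mask[:, 0, :] = 0.  # pretend relations out of object 0 are invalid
    # Normalise over dim=1 as in Example #6; masked pairs get zero weight.
    relation_mat = F.softmax(relation.masked_fill(r_mask == 0., -float('inf')), dim=1)
    z_logits = torch.randn(batch, num_obj)  # current state logits
    next_logits = torch.matmul(relation_mat, z_logits.unsqueeze(-1)).squeeze(-1)
    print(next_logits.shape)  # torch.Size([2, 4])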
Example #7
    def forward(self, q_embeddings, object_list, e_label, e_s, e_e, e_mask,
                edge, r_mask, logits, v_mask):
        batch_size, max_len = e_mask.shape
        # Initial state: uniform logits over the state space.
        logits0 = 10. * torch.zeros(batch_size, self.state_size).cuda().detach()
        z0 = modulars.GumbelCategorical(logits0)
        state = {'hidden_pre': object_list, 'z': z0}
        v_input = object_list

        z_ = z0
        Loss_time = []
        Loss_recon = []
        for i in range(max_len - 1):

            trans_emb = q_embeddings[i]

            state_pre = state['z']

            logit_input = logits[:, i, :, :]

            # q_input = q[entity_end, torch.arange(batch_size)]
            q_input = e_label[:, i, :]

            state = self.entity_cell(logit_input, v_mask)
            recon, l_forward, l_back = self._kl_divergence_recon(
                z_, state['z'], state_pre, state['z_back'], q_input,
                object_list, v_mask)

            z_ = self.trans_cell(trans_emb, edge, r_mask, state['z'], v_mask)

            if i == 0:
                l_t = recon
            else:
                l_t = recon + l_forward  #+ l_back

            Loss_time.append(l_t.unsqueeze(-1))
            Loss_recon.append(recon.unsqueeze(-1))

        # Zero out padded time steps before summing the per-step losses.
        L_t = torch.cat(Loss_time, dim=1) * e_mask[:, :-1]  #/ (e_mask[:, :-1].sum(1) + 1.).unsqueeze(-1)
        L_recon = torch.cat(Loss_recon, dim=1) * e_mask[:, :-1]
        loss_time = torch.sum(L_t, dim=1).mean(0)
        # Note: the reconstruction term is computed but not returned below.
        loss_r = torch.sum(L_recon, dim=1).mean(0)

        return loss_time
Example #8
    def forward(self, object_list, e_label, e_s, e_e, e_mask, logits,
                trans_mats, v_mask):
        batch_size, max_len = e_mask.shape
        # Initial state: uniform logits, with padded objects suppressed.
        logits0 = 10. * torch.zeros(batch_size, self.state_size).cuda().detach()
        z0 = modulars.GumbelCategorical(logits0.masked_fill_(v_mask == 0, -1e5))
        state = {'hidden_pre': object_list, 'z': z0}
        v_input = object_list

        z_ = z0
        Loss_time = []
        for i in range(max_len - 1):

            state_pre = state['z']

            logit_input = logits[i]
            trans_input = trans_mats[i]

            q_input = e_label[:, i, :]

            state = self.entity_cell(logit_input)
            recon, l_forward, l_back = self._kl_divergence_recon(
                z_, state['z'], state_pre, state['z_back'], q_input,
                object_list, v_mask)

            z_ = self.trans_cell(trans_input, state['z'])

            if i == 0:
                l_t = recon
            else:
                l_t = 10 * recon + 0.1 * l_forward  #+ l_back

            Loss_time.append(l_t.unsqueeze(-1))

        # Zero out padded time steps before summing the per-step losses.
        L_t = torch.cat(Loss_time, dim=1) * e_mask[:, :-1]  #/ (e_mask[:, :-1].sum(1) + 1.).unsqueeze(-1)
        loss_time = torch.sum(L_t, dim=1).mean(0)

        return loss_time