Example #1
        def ans_match(src_seq, ans_seq):
            import torch.nn.functional as F
            BF_ans_mask = sequence_mask(ans_lengths)  # [batch, ans_seq_len]
            BF_src_mask = sequence_mask(src_lengths)  # [batch, src_seq_len]
            BF_src_outputs = src_seq.transpose(0, 1)  # [batch, src_seq_len, 2*hidden_size]
            BF_ans_outputs = ans_seq.transpose(0, 1)  # [batch, ans_seq_len, 2*hidden_size]

            # compute bi-att scores
            src_scores = BF_src_outputs.bmm(BF_ans_outputs.transpose(2, 1))  # [batch, src_seq_len, ans_seq_len]
            ans_scores = BF_ans_outputs.bmm(BF_src_outputs.transpose(2, 1))  # [batch, ans_seq_len, src_seq_len]

            # mask padding
            Expand_BF_ans_mask = BF_ans_mask.unsqueeze(1).expand(src_scores.size())  # [batch, src_seq_len, ans_seq_len]
            src_scores.data.masked_fill_(~(Expand_BF_ans_mask).bool(), -float('inf'))

            Expand_BF_src_mask = BF_src_mask.unsqueeze(1).expand(ans_scores.size())  # [batch, ans_seq_len, src_seq_len]
            ans_scores.data.masked_fill_(~(Expand_BF_src_mask).bool(), -float('inf'))

            # normalize with softmax
            src_alpha = F.softmax(src_scores, dim=2)  # [batch, src_seq_len, ans_seq_len]
            ans_alpha = F.softmax(ans_scores, dim=2)  # [batch, ans_seq_len, src_seq_len]

            # take the weighted average
            BF_src_matched_seq = src_alpha.bmm(BF_ans_outputs)  # [batch, src_seq_len, 2*hidden_size]
            src_matched_seq = BF_src_matched_seq.transpose(0, 1)  # [src_seq_len, batch, 2*hidden_size]

            BF_ans_matched_seq = ans_alpha.bmm(BF_src_outputs)  # [batch, ans_seq_len, 2*hidden_size]
            ans_matched_seq = BF_ans_matched_seq.transpose(0, 1)  # [ans_seq_len, batch, 2*hidden_size]

            return src_matched_seq, ans_matched_seq
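
Every example on this page calls a sequence_mask(lengths) helper that is not shown here, and in Example #1 ans_lengths and src_lengths come from the enclosing scope rather than from the function arguments. A minimal sketch of what such a helper typically looks like, assuming it returns a boolean [batch, max_len] mask (the exact signature in each project may differ):

import torch

def sequence_mask(lengths, max_len=None):
    """Boolean mask that is True inside each sequence and False on padding.

    lengths: LongTensor [batch]; returns a [batch, max_len] mask that can be
    inverted and fed to masked_fill_ before a softmax, as in the examples.
    """
    max_len = max_len or int(lengths.max())
    steps = torch.arange(max_len, device=lengths.device)   # [max_len]
    return steps.unsqueeze(0) < lengths.unsqueeze(1)       # [batch, max_len]
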
Example #2
    def forward(self, src, tgt, src_lengths=None, src_emb=None, tgt_emb=None):
        src_final, src_memory_bank = self.src_encoder(src,
                                                      src_lengths,
                                                      emb=src_emb)
        src_length, batch_size, rnn_size = src_memory_bank.size()

        tgt_final, tgt_memory_bank = self.tgt_encoder(tgt, emb=tgt_emb)
        self.q_src_h = src_memory_bank
        self.q_tgt_h = tgt_memory_bank

        src_memory_bank = src_memory_bank.transpose(
            0, 1)  # batch_size, src_length, rnn_size
        src_memory_bank = src_memory_bank.transpose(
            1, 2)  # batch_size, rnn_size, src_length
        tgt_memory_bank = self.W(tgt_memory_bank.transpose(
            0, 1))  # batch_size, tgt_length, rnn_size

        if self.dist_type == "categorical":
            scores = torch.bmm(tgt_memory_bank, src_memory_bank)
            # mask source attention
            assert (self.mask_val == -float('inf'))
            if src_lengths is not None:
                mask = sequence_mask(src_lengths)
                mask = mask.unsqueeze(1)
                scores.data.masked_fill_(1 - mask, self.mask_val)
            # normalize the scores with softmax (also keep log-probabilities)
            log_scores = F.log_softmax(scores, dim=-1)
            scores = F.softmax(scores, dim=-1)

            # Make scores : T x N x S
            scores = scores.transpose(0, 1)
            log_scores = log_scores.transpose(0, 1)

            scores = Params(
                alpha=scores,
                log_alpha=log_scores,
                dist_type=self.dist_type,
            )
        elif self.dist_type == "none":
            scores = torch.bmm(tgt_memory_bank, src_memory_bank)
            # mask source attention
            if src_lengths is not None:
                mask = sequence_mask(src_lengths)
                mask = mask.unsqueeze(1)
                scores.data.masked_fill_(1 - mask, self.mask_val)
            scores = Params(
                alpha=scores.transpose(0, 1),
                dist_type=self.dist_type,
            )
        else:
            raise Exception("Unsupported dist_type")

        # T x N x S
        return scores
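
Several examples (e.g. #2, #5 and #6 below) clear padding with scores.data.masked_fill_(1 - mask, mask_val); the 1 - mask idiom assumes the old ByteTensor masks, whereas on recent PyTorch sequence_mask usually returns a bool tensor and the inversion has to be ~mask. A small self-contained sketch of the same mask-then-softmax step, with illustrative names:

import torch
import torch.nn.functional as F

def masked_softmax(scores, src_lengths):
    """Softmax over the source dimension with padded positions zeroed out.

    scores: [batch, tgt_len, src_len], src_lengths: [batch].
    """
    steps = torch.arange(scores.size(-1), device=scores.device)
    mask = steps.unsqueeze(0) < src_lengths.unsqueeze(1)      # [batch, src_len], bool
    scores = scores.masked_fill(~mask.unsqueeze(1), float('-inf'))
    return F.softmax(scores, dim=-1)
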
Example #3
        def ans_match(src_seq, ans_seq):
            import torch.nn.functional as F
            BF_ans_mask = sequence_mask(ans_lengths)  # [batch, ans_seq_len]
            BF_src_mask = sequence_mask(src_lengths)  # [batch, src_seq_len]
            BF_src_outputs = src_seq.transpose(
                0, 1)  # [batch, src_seq_len, 2*hidden_size]
            BF_ans_outputs = ans_seq.transpose(
                0, 1)  # [batch, ans_seq_len, 2*hidden_size]

            # compute bi-att scores
            src_scores = BF_src_outputs.bmm(BF_ans_outputs.transpose(
                2, 1))  # [batch, src_seq_len, ans_seq_len]
            src_scores = bm25.view(bm25.shape[0], 1, -1).expand(
                src_scores.shape) * src_scores

            ans_scores = BF_ans_outputs.bmm(BF_src_outputs.transpose(
                2, 1))  # [batch, ans_seq_len, src_seq_len]

            # mask padding
            Expand_BF_ans_mask = BF_ans_mask.unsqueeze(1).expand(
                src_scores.size())  # [batch, src_seq_len, ans_seq_len]
            src_scores.data.masked_fill_(~(Expand_BF_ans_mask).bool(),
                                         -float('inf'))
            # for i in range(src_scores.size()[0]):
            #     src_scores[i] = bm25[i] * src_scores[i]
            # UNIFORM ATTENTION
            # src_scores = torch.ones(src_scores.shape).to(ans_seq.device)

            Expand_BF_src_mask = BF_src_mask.unsqueeze(1).expand(
                ans_scores.size())  # [batch, ans_seq_len, src_seq_len]
            ans_scores.data.masked_fill_(~(Expand_BF_src_mask).bool(),
                                         -float('inf'))

            # normalize with softmax
            src_alpha = F.softmax(
                src_scores,
                dim=2)  # [batch, src_seq_len, ans_seq_len]  news2src
            ans_alpha = F.softmax(ans_scores,
                                  dim=2)  # [batch, ans_seq_len, src_seq_len]

            # take the weighted average
            BF_src_matched_seq = src_alpha.bmm(
                BF_ans_outputs)  # [batch, src_seq_len, 2*hidden_size]
            src_matched_seq = BF_src_matched_seq.transpose(
                0, 1)  # [src_seq_len, batch, 2*hidden_size]

            BF_ans_matched_seq = ans_alpha.bmm(
                BF_src_outputs)  # [batch, ans_seq_len, 2*hidden_size]
            ans_matched_seq = BF_ans_matched_seq.transpose(
                0, 1)  # [ans_seq_len, batch, 2*hidden_size]

            return src_matched_seq, ans_matched_seq
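
Example #3 is the same bi-attention as Example #1, except that the source-to-answer scores are reweighted by a per-answer-token BM25 weight; bm25, src_lengths and ans_lengths again come from the enclosing scope. A rough shape check of that reweighting step with made-up sizes, assuming bm25 is [batch, ans_seq_len]:

import torch

batch, src_len, ans_len = 2, 7, 5
src_scores = torch.randn(batch, src_len, ans_len)   # raw bi-attention scores
bm25 = torch.rand(batch, ans_len)                    # one weight per answer token

# broadcast the BM25 weights over the source dimension, as in the example
weighted = bm25.view(batch, 1, -1).expand(src_scores.shape) * src_scores
assert weighted.shape == (batch, src_len, ans_len)
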
Example #4
    def forward(self, input, context, context_lengths=None, coverage=None):
        """

        Args:
          input (`FloatTensor`): query vectors `[batch x tgt_len x hidden_size]`
          context (`FloatTensor`): source vectors `[batch x src_len x hidden_size]`
          context_lengths (`LongTensor`): the source context lengths `[batch]`
          coverage (`FloatTensor`): None (not supported yet)

        Returns:
          (`FloatTensor`, `FloatTensor`):

          * Computed vector `[tgt_len x batch x hidden_size]`
          * Attention distributions for each query
             `[tgt_len x batch x src_len]`
        """

        batch, sourceL, context_size = context.size()
        batch_, targetL, hidden_size = input.size()
        aeq(batch, batch_)

        # compute attention scores, as in Luong et al.
        align = self.score(input,
                           context)  # BS x tgt_len x src_len   64 x 19 x 13

        # pdb.set_trace()

        if context_lengths is not None:
            mask = sequence_mask(context_lengths)
            mask = mask.unsqueeze(1)  # Make it broadcastable.
            align.data.masked_fill_(1 - mask, -float('inf'))

        # Softmax to normalize attention weights
        align_vectors = self.sm(align.view(batch * targetL, sourceL))
        align_vectors = align_vectors.view(batch, targetL, sourceL)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        c = torch.bmm(align_vectors, context)

        # concatenate
        concat_c = torch.cat([c, input], 2).view(batch * targetL, -1)
        attn_h = self.linear_out(concat_c).view(batch, targetL, hidden_size)
        if self.attn_type in ["general", "dot"]:
            attn_h = self.tanh(attn_h)

        attn_h = attn_h.transpose(0, 1).contiguous()
        align_vectors = align_vectors.transpose(0, 1).contiguous()

        # Check output sizes
        targetL_, batch_, dim_ = attn_h.size()
        # aeq(targetL, targetL_)
        # aeq(batch, batch_)
        # aeq(hidden_size, dim_)
        targetL_, batch_, sourceL_ = align_vectors.size()
        # aeq(targetL, targetL_)
        # aeq(batch, batch_)
        # aeq(sourceL, sourceL_)

        return attn_h, align_vectors
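
Example #4 and most of the later examples build the raw attention scores with align = self.score(input, context), which is not shown on this page. Below is a minimal "dot"/"general" scorer in the spirit of Luong et al.; the class name, the linear_in layer and the overall layout are assumptions, not the original code:

import torch
import torch.nn as nn

class LuongScore(nn.Module):
    """Dot / general attention scores of shape [batch, tgt_len, src_len]."""

    def __init__(self, dim, attn_type="general"):
        super().__init__()
        self.attn_type = attn_type
        if attn_type == "general":
            self.linear_in = nn.Linear(dim, dim, bias=False)

    def forward(self, query, context):
        # query: [batch, tgt_len, dim], context: [batch, src_len, dim]
        if self.attn_type == "general":
            query = self.linear_in(query)
        return torch.bmm(query, context.transpose(1, 2))
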
Example #5
    def forward(self, src, tgt, src_lengths=None, memory_bank=None):
        #src_final, src_memory_bank = self.src_encoder(src, src_lengths)
        #src_length, batch_size, rnn_size = src_memory_bank.size()
        src_memory_bank = memory_bank.transpose(0, 1).transpose(1, 2)
        if self.inference_network_type == 'embedding_only':
            tgt_memory_bank = self.tgt_encoder(tgt)
        else:
            tgt_final, tgt_memory_bank = self.tgt_encoder(tgt)
        #src_memory_bank = src_memory_bank.transpose(0,1) # batch_size, src_length, rnn_size
        #src_memory_bank = src_memory_bank.contiguous().view(-1, rnn_size) # batch_size*src_length, rnn_size
        #src_memory_bank = self.W(src_memory_bank) \
        #                      .view(batch_size, src_length, rnn_size)
        #src_memory_bank = src_memory_bank.transpose(1,2) # batch_size, rnn_size, src_length
        tgt_memory_bank = tgt_memory_bank.transpose(
            0, 1)  # batch_size, tgt_length, rnn_size

        if self.dist_type == "dirichlet":
            # probably broken
            scores = torch.bmm(tgt_memory_bank, src_memory_bank)
            scores = [scores]
        elif self.dist_type == "normal":
            # log normal
            src_memory_bank = src_memory_bank.transpose(1, 2)
            #assert src_memory_bank.size() == (batch_size, src_length, rnn_size)
            scores = self.get_normal_scores(src_memory_bank, tgt_memory_bank)
        elif self.dist_type == "none":
            scores = [torch.bmm(tgt_memory_bank, src_memory_bank)]
        else:
            raise Exception("Unsupported dist_type")

        nparam = len(scores)
        # length
        if src_lengths is not None:
            mask = sequence_mask(src_lengths)
            mask = mask.unsqueeze(1)
            if self.dist_type == 'normal':
                scores[0].data.masked_fill_(1 - mask, -999)
                scores[1].data.masked_fill_(1 - mask, 0.001)
            else:
                for i in range(nparam):
                    scores[i].data.masked_fill_(1 - mask, self.mask_val)
        return scores
Example #6
    def forward(self, src, tgt, src_lengths=None):
        src_final, src_memory_bank = self.src_encoder(src, src_lengths)
        src_length, batch_size, rnn_size = src_memory_bank.size()
        tgt_final, tgt_memory_bank = self.tgt_encoder(tgt)
        src_memory_bank = src_memory_bank.transpose(
            0, 1)  # batch_size, src_length, rnn_size
        src_memory_bank = src_memory_bank.contiguous().view(
            -1, rnn_size)  # batch_size*src_length, rnn_size
        src_memory_bank = self.W(src_memory_bank) \
                              .view(batch_size, src_length, rnn_size)
        src_memory_bank = src_memory_bank.transpose(
            1, 2)  # batch_size, rnn_size, src_length
        tgt_memory_bank = tgt_memory_bank.transpose(
            0, 1)  # batch_size, tgt_length, rnn_size
        if self.dist_type == "dirichlet":
            scores = torch.bmm(tgt_memory_bank, src_memory_bank)
            #print("max: {}, min: {}".format(scores.max(), scores.min()))
            # affine
            scores = scores - scores.min(-1)[0].unsqueeze(-1) + 1e-2
            # exp
            #scores = scores.clamp(-1, 1).exp()
            #scores = scores.clamp(min=1e-2)
            scores = [scores]
        elif self.dist_type == "normal":
            # log normal
            src_memory_bank = src_memory_bank.transpose(1, 2)
            assert src_memory_bank.size() == (batch_size, src_length, rnn_size)
            scores = self.get_normal_scores(src_memory_bank, tgt_memory_bank)
        elif self.dist_type == "none":
            scores = [torch.bmm(tgt_memory_bank, src_memory_bank)]
        else:
            raise Exception("Unsupported dist_type")

        nparam = len(scores)
        # length
        if src_lengths is not None:
            mask = sequence_mask(src_lengths)
            mask = mask.unsqueeze(1)
            for i in range(nparam):
                scores[i].data.masked_fill_(1 - mask, self.mask_val)
        return scores
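
For the "dirichlet" branch, Example #6 shifts the scores so that every entry is strictly positive before they are used as Dirichlet concentration parameters. A tiny standalone check of that affine step (sizes are illustrative):

import torch

scores = torch.randn(2, 3, 4)                              # [batch, tgt_len, src_len]
positive = scores - scores.min(-1)[0].unsqueeze(-1) + 1e-2
assert bool((positive > 0).all())                          # valid concentration params
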
Example #7
    def encode_seq(self, seq, seq_lengths):
        """ Encode sequence `seq` using its lengths `seq_lengths`
            to mask paddings and return the average sequence encoding.

            seq  (sequence_length, batch_size, feats):
                Sequence encodings.
            seq_lengths (batch_size):
                Sequence lengths.
        """
        # mask => [B,n], seq_lengths => [B]
        mask = sequence_mask(seq_lengths)
        # mask => [B,n,1]
        mask = mask.unsqueeze(2)  # Make it broadcastable.
        mask = Variable(mask.type(torch.Tensor),
                        requires_grad=False)  # convert to a float variable
        # x => [B,n,d]
        seq = seq.transpose(0, 1)
        seq = seq * mask
        # average/sum
        h = seq.sum(1) / mask.sum(1)  # [B,d] / [B,1]
        return h
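
encode_seq above averages the encoder states while ignoring padding, but still wraps the mask in the deprecated Variable API. The same masked average on current PyTorch might look like the sketch below, assuming seq is [seq_len, batch, feats] and seq_lengths is [batch]:

import torch

def masked_mean(seq, seq_lengths):
    """Average seq over time, ignoring padded positions.

    seq: [seq_len, batch, feats], seq_lengths: [batch] -> [batch, feats]
    """
    steps = torch.arange(seq.size(0), device=seq.device)
    mask = (steps.unsqueeze(0) < seq_lengths.unsqueeze(1)).float()  # [batch, seq_len]
    seq = seq.transpose(0, 1)                        # [batch, seq_len, feats]
    summed = (seq * mask.unsqueeze(2)).sum(1)        # [batch, feats]
    return summed / mask.sum(1, keepdim=True)        # divide by the true lengths
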
Example #8
    def forward(self, input, memory_bank, memory_lengths=None, coverage=None):
        """

        Args:
          input (`FloatTensor`): query vectors `[batch x tgt_len x dim]`
          memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]`
          memory_lengths (`LongTensor`): the source context lengths `[batch]`
          coverage (`FloatTensor`): None (not supported yet)

        Returns:
          (`FloatTensor`, `FloatTensor`):

          * Computed vector `[tgt_len x batch x dim]`
          * Attention distributions for each query
             `[tgt_len x batch x src_len]`
        """

        # one step input
        if input.dim() == 2:
            one_step = True
            input = input.unsqueeze(1)
        else:
            one_step = False

        batch, sourceL, dim = memory_bank.size()
        batch_, targetL, dim_ = input.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        aeq(self.dim, dim)
        if coverage is not None:
            batch_, sourceL_ = coverage.size()
            aeq(batch, batch_)
            aeq(sourceL, sourceL_)

        if coverage is not None:
            cover = coverage.view(-1).unsqueeze(1)
            memory_bank += self.linear_cover(cover).view_as(memory_bank)
            memory_bank = self.tanh(memory_bank)

        # compute attention scores, as in Luong et al.
        align = self.score(input, memory_bank)

        if memory_lengths is not None:
            mask = sequence_mask(memory_lengths)
            mask = mask.unsqueeze(1)  # Make it broadcastable.
            align.data.masked_fill_(1 - mask, -float('inf'))

        # Softmax to normalize attention weights
        align_vectors = self.sm(align.view(batch*targetL, sourceL))
        align_vectors = align_vectors.view(batch, targetL, sourceL)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        c = torch.bmm(align_vectors, memory_bank)

        # concatenate
        concat_c = torch.cat([c, input], 2).view(batch*targetL, dim*2)
        attn_h = self.linear_out(concat_c).view(batch, targetL, dim)
        if self.attn_type in ["general", "dot"]:
            attn_h = self.tanh(attn_h)

        if one_step:
            attn_h = attn_h.squeeze(1)
            align_vectors = align_vectors.squeeze(1)

            # Check output sizes
            batch_, dim_ = attn_h.size()
            aeq(batch, batch_)
            aeq(dim, dim_)
            batch_, sourceL_ = align_vectors.size()
            aeq(batch, batch_)
            aeq(sourceL, sourceL_)
        else:
            attn_h = attn_h.transpose(0, 1).contiguous()
            align_vectors = align_vectors.transpose(0, 1).contiguous()

            # Check output sizes
            targetL_, batch_, dim_ = attn_h.size()
            aeq(targetL, targetL_)
            aeq(batch, batch_)
            aeq(dim, dim_)
            targetL_, batch_, sourceL_ = align_vectors.size()
            aeq(targetL, targetL_)
            aeq(batch, batch_)
            aeq(sourceL, sourceL_)

        return attn_h, align_vectors
Example #9
    def forward(self, input, memory_bank, memory_lengths=None, coverage=None):
        """

        Args:
          input (`FloatTensor`): query vectors `[batch x tgt_len x dim]`
          memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]`
          memory_lengths (`LongTensor`): the source context lengths `[batch]`
          coverage (`FloatTensor`): None (not supported yet)

        Returns:
          (`FloatTensor`, `FloatTensor`):

          * Computed vector `[tgt_len x batch x dim]`
          * Attention distributions for each query
             `[tgt_len x batch x src_len]`
        """

        # one step input
        if input.dim() == 2:
            one_step = True
            input = input.unsqueeze(1)
        else:
            one_step = False

        batch, sourceL, dim = memory_bank.size()
        batch_, targetL, dim_ = input.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        aeq(self.dim, dim)
        if coverage is not None:
            batch_, sourceL_ = coverage.size()
            aeq(batch, batch_)
            aeq(sourceL, sourceL_)

        if coverage is not None:
            cover = coverage.view(-1).unsqueeze(1)
            memory_bank += self.linear_cover(cover).view_as(memory_bank)
            memory_bank = self.tanh(memory_bank)

        # compute attention scores, as in Luong et al.
        align = self.score(input, memory_bank)
        assert memory_lengths is not None
        mask = sequence_mask(memory_lengths)
        mask = mask.unsqueeze(1)  # Make it broadcastable.
        # mask the time step of self
        mask = mask.repeat(1, sourceL, 1)
        mask_self_index = list(range(sourceL))
        mask[:, mask_self_index, mask_self_index] = 0

        if self.attn_type == "fine":
            mask = mask.unsqueeze(3)

        align.data.masked_fill_(1 - mask, -float('inf'))

        # Softmax to normalize attention weights
        align_vectors = self.sm(align)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        if self.attn_type == "fine":
            c = memory_bank.unsqueeze(1).mul(align_vectors).sum(dim=2,
                                                                keepdim=False)
        else:
            c = torch.bmm(align_vectors, memory_bank)

        # concatenate
        concat_c = torch.cat([c, input], 2)
        attn_h = self.linear_out(concat_c)
        if self.attn_type in ["general", "dot"]:
            # attn_h = F.elu(attn_h, 0.1)
            # attn_h = F.elu(self.dropout(attn_h) + input, 0.1)

            # content selection gate
            if not self.no_gate:
                attn_h = F.sigmoid(attn_h).mul(input)

        if one_step:
            attn_h = attn_h.squeeze(1)
            align_vectors = align_vectors.squeeze(1)

            # Check output sizes
            batch_, dim_ = attn_h.size()
            aeq(batch, batch_)
            aeq(dim, dim_)
            batch_, sourceL_ = align_vectors.size()
            aeq(batch, batch_)
            aeq(sourceL, sourceL_)
        else:
            attn_h = attn_h.transpose(0, 1).contiguous()
            align_vectors = align_vectors.transpose(0, 1).contiguous()

            # Check output sizes
            targetL_, batch_, dim_ = attn_h.size()
            aeq(targetL, targetL_)
            aeq(batch, batch_)
            aeq(dim, dim_)
            targetL_, batch_, sourceL_ = align_vectors.size()
            aeq(targetL, targetL_)
            aeq(batch, batch_)
            aeq(sourceL, sourceL_)

        return attn_h, align_vectors
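
Example #9 masks padding and additionally zeroes the diagonal of the mask so that a position cannot attend to itself (mask[:, i, i] = 0). A compact way to build that self-excluding mask, with illustrative names:

import torch

def self_attention_mask(lengths, src_len):
    """[batch, src_len, src_len] mask: True where attention is allowed,
    False on padding and on the diagonal (no attending to oneself)."""
    steps = torch.arange(src_len, device=lengths.device)
    keep = steps.unsqueeze(0) < lengths.unsqueeze(1)            # [batch, src_len]
    mask = keep.unsqueeze(1).repeat(1, src_len, 1)              # [batch, src_len, src_len]
    eye = torch.eye(src_len, dtype=torch.bool, device=lengths.device)
    return mask & ~eye
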
Example #10
    def forward(self, input, memory_bank, memory_lengths=None, stage1_target=None, plan_attn=None,
                player_row_indices=None, team_row_indices=None):
        """

        Args:
          input (`FloatTensor`): query vectors `[batch x tgt_len x dim]`
          memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]`
          memory_lengths (`LongTensor`): the source context lengths `[batch]`

        Returns:
          (`FloatTensor`, `FloatTensor`):

          * Computed vector `[tgt_len x batch x dim]`
          * Attention distributions for each query
             `[tgt_len x batch x src_len]`
        """
        PLAYER_ROWS = 26
        TEAM_ROWS = 2
        EXTRA_RECORDS = 4
        PLAYER_COLS = 22
        TEAM_COLS = 15
        PLAYER_RECORDS_MAX=EXTRA_RECORDS+PLAYER_ROWS*PLAYER_COLS

        # one step input
        if input.dim() == 2:
            one_step = True
            input = input.unsqueeze(1)
        else:
            one_step = False

        batch, sourceL, dim = memory_bank.size()
        #print 'batch, sourceL, dim',batch, sourceL, dim
        batch_, targetL, dim_ = input.size()
        #print 'batch_, targetL, dim_',batch_, targetL, dim_
        aeq(batch, batch_)
        aeq(dim, dim_)
        aeq(self.dim, dim)
        SOURCEL = sourceL
        targetL_st1_tgt, batch_st1_tgt,_= stage1_target.size()
        batch_plan, target_plan = plan_attn.size()
        aeq(batch_plan, batch)
        aeq(batch_plan, batch_st1_tgt)
        aeq(target_plan, targetL_st1_tgt)

        target_player_indices_L, batch_player_ind, player_rows_len = player_row_indices.size()
        aeq(target_player_indices_L, targetL_st1_tgt)
        aeq(batch_player_ind, batch)
        aeq(player_rows_len, PLAYER_ROWS)

        target_team_indices_L, batch_team_ind, team_rows_len = team_row_indices.size()
        aeq(target_team_indices_L, targetL_st1_tgt)
        aeq(batch_team_ind, batch)
        aeq(team_rows_len, TEAM_ROWS)

        # compute attention scores, as in Luong et al.
        align = self.score(input, memory_bank)

        if memory_lengths is not None:
            mask = sequence_mask(memory_lengths)
            mask = mask.unsqueeze(1)  # Make it broadcastable.
            align.data.masked_fill_(1 - mask, -float('inf'))

        # Softmax to normalize attention weights
        align = align.view(batch * targetL, sourceL)
        align_player_cells = self.sm(align[:,EXTRA_RECORDS:PLAYER_RECORDS_MAX].contiguous().view(-1, PLAYER_COLS))
        align_team_cells = self.sm(align[:,PLAYER_RECORDS_MAX:SOURCEL].contiguous().view(-1, TEAM_COLS))
        row_indices = (stage1_target.data.squeeze(2)-EXTRA_RECORDS)/PLAYER_COLS
        prob_prod = plan_attn.t() * Variable(row_indices.lt(PLAYER_ROWS).float(), requires_grad=False)    #stores probabilities for player records
        # (batch, 1, t_len_plan) x (batch, t_len_plan, 26) --> (batch, 1, 26)
        player_prob = torch.bmm(prob_prod.t().unsqueeze(1), player_row_indices.transpose(0,1).float()).squeeze(1)
        player_prob = player_prob.unsqueeze(2).expand(-1,-1,PLAYER_COLS).contiguous().view(-1,PLAYER_COLS)
        player_prob_table = align_player_cells*player_prob

        prob_prod = plan_attn.t() * Variable(row_indices.ge(PLAYER_ROWS).float(), requires_grad=False)    #stores probabilities for team records
        # (batch, 1, t_len_plan) x (batch, t_len_plan, 2) --> (batch, 1, 2)
        team_prob = torch.bmm(prob_prod.t().unsqueeze(1), team_row_indices.transpose(0,1).float()).squeeze(1)
        team_prob = team_prob.unsqueeze(2).expand(-1,-1,TEAM_COLS).contiguous().view(-1,TEAM_COLS)
        team_prob_table = align_team_cells*team_prob

        extra_prob_table = Variable(self.tt.FloatTensor(batch, EXTRA_RECORDS).fill_(0), requires_grad=False)
        align_vectors = torch.cat([extra_prob_table, player_prob_table.view(batch,-1), team_prob_table.view(batch,-1)],1)
        align_vectors = align_vectors.view(batch, targetL, sourceL)

        batch_table, sourceL_table, dim_table = memory_bank.size()
        aeq(batch, batch_table)
        aeq(dim, dim_table)


        if one_step:
            align_vectors = align_vectors.squeeze(1)

            # Check output sizes
            batch_, sourceL_ = align_vectors.size()
            aeq(batch, batch_)
            aeq(sourceL, sourceL_)
        else:
            align_vectors = align_vectors.transpose(0, 1).contiguous()

            aeq(dim, dim_)
            targetL_, batch_, sourceL_ = align_vectors.size()
            aeq(targetL, targetL_)
            aeq(batch, batch_)
            aeq(sourceL, sourceL_)

        return align_vectors
Example #11
    def forward(self, input, context, context_lengths=None, coverage=None):
        """
        input (FloatTensor): batch x tgt_len x dim: decoder's rnn's output.
        context (FloatTensor): batch x src_len x dim: src hidden states
        context_lengths (LongTensor): the source context lengths.
        coverage (FloatTensor): None (not supported yet)
        """

        # one step input
        if input.dim() == 2:
            one_step = True
            input = input.unsqueeze(1)
        else:
            one_step = False

        batch, sourceL, dim = context.size()
        batch_, targetL, dim_ = input.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        aeq(self.dim, dim)
        if coverage is not None:
            batch_, sourceL_ = coverage.size()
            aeq(batch, batch_)
            aeq(sourceL, sourceL_)

        if coverage is not None:
            cover = coverage.view(-1).unsqueeze(1)
            context += self.linear_cover(cover).view_as(context)
            context = self.tanh(context)

        # compute attention scores, as in Luong et al.
        align = self.score(input, context)

        if context_lengths is not None:
            mask = sequence_mask(context_lengths)
            mask = mask.unsqueeze(1)  # Make it broadcastable.
            align.data.masked_fill_(1 - mask, -float('inf'))

        # Softmax to normalize attention weights
        align_vectors = self.sm(align.view(batch * targetL, sourceL))
        align_vectors = align_vectors.view(batch, targetL, sourceL)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        c = torch.bmm(align_vectors, context)

        # concatenate
        concat_c = torch.cat([c, input], 2).view(batch * targetL, dim * 2)
        attn_h = self.linear_out(concat_c).view(batch, targetL, dim)
        if self.attn_type in ["general", "dot"]:
            attn_h = self.tanh(attn_h)

        if one_step:
            attn_h = attn_h.squeeze(1)
            align_vectors = align_vectors.squeeze(1)

            # Check output sizes
            batch_, dim_ = attn_h.size()
            aeq(batch, batch_)
            aeq(dim, dim_)
            batch_, sourceL_ = align_vectors.size()
            aeq(batch, batch_)
            aeq(sourceL, sourceL_)
        else:
            attn_h = attn_h.transpose(0, 1).contiguous()
            align_vectors = align_vectors.transpose(0, 1).contiguous()

            # Check output sizes
            targetL_, batch_, dim_ = attn_h.size()
            aeq(targetL, targetL_)
            aeq(batch, batch_)
            aeq(dim, dim_)
            targetL_, batch_, sourceL_ = align_vectors.size()
            aeq(targetL, targetL_)
            aeq(batch, batch_)
            aeq(sourceL, sourceL_)

        return attn_h, align_vectors
Example #12
    def forward(self, input, memory_bank, memory_lengths=None, coverage=None):
        """

        Args:
          input (`FloatTensor`): query vectors `[batch x tgt_len x dim]`
          memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]`
          memory_lengths (`LongTensor`): the source context lengths `[batch]`
          coverage (`FloatTensor`): None (not supported yet)

        Returns:
          (`FloatTensor`, `FloatTensor`):

          * Computed vector `[tgt_len x batch x dim]`
          * Attention distributions for each query
             `[tgt_len x batch x src_len]`
        """

        # one step input
        if input.dim() == 2:
            one_step = True
            input = input.unsqueeze(1)
        else:
            one_step = False

        batch, sourcel, dim = memory_bank.size()
        batch_, targetl, dim_ = input.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        aeq(self.dim, dim)
        if coverage is not None:
            batch_, sourcel_ = coverage.size()
            aeq(batch, batch_)
            aeq(sourcel, sourcel_)

        if coverage is not None:
            cover = coverage.view(-1).unsqueeze(1)
            memory_bank += self.linear_cover(cover).view_as(memory_bank)
            memory_bank = self.tanh(memory_bank)

        # compute attention scores, as in Luong et al.
        align = self.score(input, memory_bank)

        if memory_lengths is not None:
            mask = sequence_mask(memory_lengths)
            mask = mask.unsqueeze(1)  # Make it broadcastable.
            align.data.masked_fill_(1 - mask, -float('inf'))

        # Softmax to normalize attention weights
        align_vectors = self.sm(align.view(batch * targetl, sourcel))
        align_vectors = align_vectors.view(batch, targetl, sourcel)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        c = torch.bmm(align_vectors, memory_bank)

        # concatenate
        concat_c = torch.cat([c, input], 2).view(batch * targetl, dim * 2)
        attn_h = self.linear_out(concat_c).view(batch, targetl, dim)
        if self.attn_type in ["general", "dot"]:
            attn_h = self.tanh(attn_h)

        if one_step:
            attn_h = attn_h.squeeze(1)
            align_vectors = align_vectors.squeeze(1)

            # Check output sizes
            batch_, dim_ = attn_h.size()
            aeq(batch, batch_)
            aeq(dim, dim_)
            batch_, sourcel_ = align_vectors.size()
            aeq(batch, batch_)
            aeq(sourcel, sourcel_)
        else:
            attn_h = attn_h.transpose(0, 1).contiguous()
            align_vectors = align_vectors.transpose(0, 1).contiguous()

            # Check output sizes
            targetl_, batch_, dim_ = attn_h.size()
            aeq(targetl, targetl_)
            aeq(batch, batch_)
            aeq(dim, dim_)
            targetl_, batch_, sourcel_ = align_vectors.size()
            aeq(targetl, targetl_)
            aeq(batch, batch_)
            aeq(sourcel, sourcel_)

        return attn_h, align_vectors
Example #13
    def forward(self,
                input,
                context,
                context_lengths=None,
                coverage=None,
                embedding_now=None,
                embedding_copy=None,
                word_freq=None):
        """

        Args:
          input (`FloatTensor`): query vectors `[batch x tgt_len x dim]`, decoder hidden state at each timestep
          context (`FloatTensor`): source vectors `[batch x src_len x dim]`, encoder hidden state at each timestep
          context_lengths (`LongTensor`): the source context lengths `[batch]`
          coverage (`FloatTensor`): None (not supported yet)
          embedding_copy (`FloatTensor`): the original input sequence embeddings with affect `[batch x src_len x emb_dim]`
          word_freq (`FloatTensor`): the word frequency `[batch x src_len]`

        Returns:
          (`FloatTensor`, `FloatTensor`):

          * Computed vector `[tgt_len x batch x dim]`
          * Attention distributions for each query
             `[tgt_len x batch x src_len]`
        """

        # one step input for InputFeedDecoder
        if input.dim() == 2:
            one_step = True
            input = input.unsqueeze(1)
        else:
            one_step = False

        batch, sourceL, dim = context.size()
        batch_, targetL, dim_ = input.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        aeq(self.dim, dim)
        if coverage is not None:
            batch_, sourceL_ = coverage.size()
            aeq(batch, batch_)
            aeq(sourceL, sourceL_)

        if coverage is not None:
            cover = coverage.view(-1).unsqueeze(1)
            context += self.linear_cover(cover).view_as(context)
            context = self.tanh(context)

        # compute attention scores, as in Luong et al.
        # Add affective attention here px
        align = self.score(input, context, embedding_now, embedding_copy,
                           word_freq)

        if context_lengths is not None:
            mask = sequence_mask(context_lengths)
            mask = mask.unsqueeze(1)  # Make it broadcastable.
            align.data.masked_fill_(1 - mask, -float('inf'))

        # Softmax to normalize attention weights
        align_vectors = self.sm(align.view(batch * targetL, sourceL))
        align_vectors = align_vectors.view(batch, targetL, sourceL)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        c = torch.bmm(align_vectors, context)

        # concatenate
        concat_c = torch.cat([c, input], 2).view(batch * targetL, dim * 2)
        attn_h = self.linear_out(concat_c).view(batch, targetL, dim)
        if self.attn_type in ["general", "dot"]:
            attn_h = self.tanh(attn_h)

        if one_step:
            attn_h = attn_h.squeeze(1)
            align_vectors = align_vectors.squeeze(1)

            # Check output sizes
            batch_, dim_ = attn_h.size()
            aeq(batch, batch_)
            aeq(dim, dim_)
            batch_, sourceL_ = align_vectors.size()
            aeq(batch, batch_)
            aeq(sourceL, sourceL_)
        else:
            attn_h = attn_h.transpose(0, 1).contiguous()
            align_vectors = align_vectors.transpose(0, 1).contiguous()

            # Check output sizes
            targetL_, batch_, dim_ = attn_h.size()
            aeq(targetL, targetL_)
            aeq(batch, batch_)
            aeq(dim, dim_)
            targetL_, batch_, sourceL_ = align_vectors.size()
            aeq(targetL, targetL_)
            aeq(batch, batch_)
            aeq(sourceL, sourceL_)

        return attn_h, align_vectors
Example #14
    def forward(self,
                input,
                memory_bank,
                memory_lengths=None,
                coverage=None,
                emb_weight=None,
                idf_weights=None):
        """

        Args:
          input (`FloatTensor`): query vectors `[batch x tgt_len x dim]`
          memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]`
          memory_lengths (`LongTensor`): the source context lengths `[batch]`
          coverage (`FloatTensor`): None (not supported yet)
          
          # thkim
          emb_weight : possibly related to intra attention ...
          idf_weights : idf values, multiplied into the attention weights

        Returns:
          (`FloatTensor`, `FloatTensor`):

          * Computed vector `[tgt_len x batch x dim]`
          * Attention distributions for each query
             `[tgt_len x batch x src_len]`
        """

        # one step input
        if input.dim() == 2:
            one_step = True
            input = input.unsqueeze(1)
        else:
            one_step = False

        batch, sourceL, dim = memory_bank.size()
        batch_, targetL, dim_ = input.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        aeq(self.dim, dim)
        if coverage is not None:
            batch_, sourceL_ = coverage.size()
            aeq(batch, batch_)
            aeq(sourceL, sourceL_)

        if coverage is not None:
            cover = coverage.view(-1).unsqueeze(1)
            memory_bank += self.linear_cover(cover).view_as(memory_bank)
            memory_bank = self.tanh(memory_bank)

        # compute attention scores, as in Luong et al.
        align = self.score(input, memory_bank)

        if memory_lengths is not None:
            mask = sequence_mask(memory_lengths)
            mask = mask.unsqueeze(1)  # Make it broadcastable.
            align.data.masked_fill_(1 - mask, -float('inf'))

        ## Intra-temporal attention
        ## assum train is going on the gpu

        align = torch.exp(align)  # batch * 1(target_length) * input_length
        #         print("globalattn line 203: align")

        if len(self.attn_outputs) < 1:  # t=1
            #             print("global attn line:208, attn_outputs")
            #             print(len(self.attn_outputs))
            align_vectors = self.sm(align.view(batch * targetL, sourceL))
            align_vectors = align_vectors.view(batch, targetL, sourceL)
        else:  # t > 1
            #             print("global attn line:209, attn_outputs")
            #             print(len(self.attn_outputs))
            temporal_attns = torch.cat(self.attn_outputs,
                                       1)  # batch * len(t-1) * input_length
            normalizing_factor = torch.sum(temporal_attns, 1).unsqueeze(1)
            #             print("global attn line:214, normalizing factor")

            # wrong implementation
            # normalizing_factor = torch.autograd.Variable(torch.cat([torch.ones(align.size()[0], 1, 1).cuda(), torch.cumsum(torch.exp(align), 2).data[:,:,:-1]],2))
            #             align = torch.exp(align) / normalizing_factor
            #             align_vectors = align / torch.sum(align, 2).unsqueeze(2)

            align_vectors = align / normalizing_factor
            align_vectors = self.sm(align.view(batch * targetL, sourceL))
            align_vectors = align_vectors.view(batch, targetL, sourceL)

        # Softmax to normalize attention weights
        ## original attention
#         align_vectors = self.sm(align.view(batch*targetL, sourceL))
#         align_vectors = align_vectors.view(batch, targetL, sourceL)

#         print("global attn line:270 idf_weights", torch.autograd.Variable(idf_weights.t().unsqueeze(1), requires_grad=False))
#         print("global attn line:270", align_vectors)
        if idf_weights is not None:
            align_vectors = align_vectors * torch.autograd.Variable(
                idf_weights.t().unsqueeze(1), requires_grad=False)
#         input()

        # each context vector c_t is the weighted average
        # over all the source hidden states
        c = torch.bmm(align_vectors,
                      memory_bank)  # for intra-temporal attention
        self.attn_outputs.append(align)
        #         print("gb attn line:237 len attn_outputs", len(self.attn_outputs))

        # ======== intra-decoder attention
        if len(self.decoder_outputs) < 1:
            # TO DO : change initial value to zero vector
            # ? what is the size of the zero vector? the decoder attn below also looks a bit odd
            # set zero vector to first case
            c_dec = input * 0
#             print("glbal-attn", "dd")
        else:
            decoder_history = torch.cat(self.decoder_outputs,
                                        1)  # batch * tgt_len(?) * dim
            decoder_align = self.score(input, decoder_history, "dec_attn")
            #             print("global attn line:223 decoder align")
            #             print(decoder_align)
            #             input()

            #             print("global-attn line:225", decoder_history)
            #             if len(self.decoder_outputs) == 5:
            #                 input()

            history_len = len(self.decoder_outputs)
            decoder_align_vectors = self.sm(
                decoder_align.view(batch * targetL, history_len))
            decoder_align_vectors = decoder_align_vectors.view(
                batch, targetL, history_len)
            #             print("global-attn line:232", decoder_align_vectors)
            c_dec = torch.bmm(decoder_align_vectors, decoder_history)

        self.decoder_outputs.append(input)

        # ========
        ##
        #         print("gb-attn line:239", self.linear_out.weight.data.size())
        #         if emb_weight is not None:
        #             print("gb-attn line:240", emb_weight.data.size())
        #             self.linear_out.weight = self.tanh(emb_weight * self.linear_out.weight)
        # print("gb-attn line:240", (self.linear_out.weight.data*emb_weight.data).size())
        # input()

        #         print("h attn line:371 c", c.size())
        #         print("h attn line:372 input", input.size())
        #         print("h attn line:372 c_dec", c_dec.size())
        #         input()

        # concatenate
        concat_c = torch.cat([c, input, c_dec],
                             2).view(batch * targetL, dim * 3)
        attn_h = self.linear_out(concat_c).view(batch, targetL, dim)
        if self.attn_type in ["general", "dot"]:
            attn_h = self.tanh(attn_h)

        if one_step:
            attn_h = attn_h.squeeze(1)
            align_vectors = align_vectors.squeeze(1)

            # Check output sizes
            batch_, dim_ = attn_h.size()
            aeq(batch, batch_)
            aeq(dim, dim_)
            batch_, sourceL_ = align_vectors.size()
            aeq(batch, batch_)
            aeq(sourceL, sourceL_)
        else:
            attn_h = attn_h.transpose(0, 1).contiguous()
            align_vectors = align_vectors.transpose(0, 1).contiguous()

            # Check output sizes
            targetL_, batch_, dim_ = attn_h.size()
            aeq(targetL, targetL_)
            aeq(batch, batch_)
            aeq(dim, dim_)
            targetL_, batch_, sourceL_ = align_vectors.size()
            aeq(targetL, targetL_)
            aeq(batch, batch_)
            aeq(sourceL, sourceL_)

        return attn_h, align_vectors
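
Example #14 adds intra-temporal attention: after the first step, the exponentiated scores are divided by the sum of the exponentiated scores from earlier decoder steps. Note that the example then recomputes a plain softmax over the raw scores, so the temporally normalized weights are overwritten; the sketch below keeps only the temporal normalization, under the assumption that past_scores stores the exp'd scores of the previous steps:

import torch

def intra_temporal_weights(align, past_scores):
    """align: [batch, 1, src_len] raw scores for the current decoder step.
    past_scores: list of exp'd score tensors from previous steps (same shape).
    Returns normalized attention weights and appends this step's exp'd scores."""
    exp_align = torch.exp(align)
    if not past_scores:                     # t = 1: ordinary softmax
        weights = torch.softmax(align, dim=-1)
    else:                                   # t > 1: divide by the temporal sum
        denom = torch.cat(past_scores, 1).sum(1, keepdim=True)  # [batch, 1, src_len]
        temporal = exp_align / denom
        weights = temporal / temporal.sum(-1, keepdim=True)
    past_scores.append(exp_align)
    return weights
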
Example #15
    def forward(self,
                input,
                memory_bank,
                memory_lengths=None,
                coverage=None,
                q_scores=None):
        """

        Args:
          input (`FloatTensor`): query vectors `[batch x tgt_len x dim]`
          memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]`
          memory_lengths (`LongTensor`): the source context lengths `[batch]`
          coverage (`FloatTensor`): None (not supported yet)
          q_scores (`FloatTensor`): the attention params from the inference network

        Returns:
          (`FloatTensor`, `FloatTensor`):

          * Weighted context vector `[tgt_len x batch x dim]`
          * Attention distributions for each query
             `[tgt_len x batch x src_len]`
          * Unnormalized attention scores for each query
            `[batch x tgt_len x src_len]`
        """

        # one step input
        if input.dim() == 2:
            one_step = True
            input = input.unsqueeze(1)
            if q_scores is not None:
                # oh, I guess this is super messy
                if q_scores.alpha is not None:
                    q_scores = Params(
                        alpha=q_scores.alpha.unsqueeze(1),
                        log_alpha=q_scores.log_alpha.unsqueeze(1),
                        dist_type=q_scores.dist_type,
                    )
        else:
            one_step = False

        batch, sourceL, dim = memory_bank.size()
        batch_, targetL, dim_ = input.size()
        aeq(batch, batch_)

        # compute attention scores, as in Luong et al.
        # Params should be T x N x S
        if self.p_dist_type == "categorical":
            scores = self.score(input, memory_bank)
            if memory_lengths is not None:
                # mask : N x T x S
                mask = sequence_mask(memory_lengths)
                mask = mask.unsqueeze(1)  # Make it broadcastable.
                scores.data.masked_fill_(~mask, -float('inf'))
            if self.k > 0 and self.k < scores.size(-1):
                topk, idx = scores.data.topk(self.k)
                new_attn_score = torch.zeros_like(scores.data).fill_(
                    float("-inf"))
                new_attn_score = new_attn_score.scatter_(2, idx, topk)
                scores = new_attn_score
            log_scores = F.log_softmax(scores, dim=-1)
            scores = log_scores.exp()

            c_align_vectors = scores

            p_scores = Params(
                alpha=scores,
                log_alpha=log_scores,
                dist_type=self.p_dist_type,
            )

        # each context vector c_t is the weighted average
        # over all the source hidden states
        context_c = torch.bmm(c_align_vectors, memory_bank)
        if self.mode != 'wsram':
            concat_c = torch.cat([input, context_c], -1)
            # N x T x H
            h_c = self.tanh(self.linear_out(concat_c))
        else:
            h_c = None

        # sample or enumerate
        # y_align_vectors: K x N x T x S
        q_sample, p_sample, sample_log_probs = None, None, None
        sample_log_probs_q, sample_log_probs_p, sample_p_div_q_log = None, None, None
        if self.mode == "sample":
            if q_scores is None or self.use_prior:
                p_sample, sample_log_probs = self.sample_attn(
                    p_scores,
                    n_samples=self.n_samples,
                    lengths=memory_lengths,
                    mask=mask if memory_lengths is not None else None)
                y_align_vectors = p_sample
            else:
                q_sample, sample_log_probs = self.sample_attn(
                    q_scores,
                    n_samples=self.n_samples,
                    lengths=memory_lengths,
                    mask=mask if memory_lengths is not None else None)
                y_align_vectors = q_sample
        elif self.mode == "gumbel":
            if q_scores is None or self.use_prior:
                p_sample, _ = self.sample_attn_gumbel(
                    p_scores,
                    self.temperature,
                    n_samples=self.n_samples,
                    lengths=memory_lengths,
                    mask=mask if memory_lengths is not None else None)
                y_align_vectors = p_sample
            else:
                q_sample, _ = self.sample_attn_gumbel(
                    q_scores,
                    self.temperature,
                    n_samples=self.n_samples,
                    lengths=memory_lengths,
                    mask=mask if memory_lengths is not None else None)
                y_align_vectors = q_sample
        elif self.mode == "enum" or self.mode == "exact":
            y_align_vectors = None
        elif self.mode == "wsram":
            assert q_scores is not None
            q_sample, sample_log_probs_q, sample_log_probs_p, sample_p_div_q_log = self.sample_attn_wsram(
                q_scores,
                p_scores,
                n_samples=self.n_samples,
                lengths=memory_lengths,
                mask=mask if memory_lengths is not None else None)
            y_align_vectors = q_sample

        # context_y: K x N x T x H
        if y_align_vectors is not None:
            context_y = torch.bmm(
                y_align_vectors.view(-1, targetL, sourceL),
                memory_bank.unsqueeze(0).repeat(self.n_samples, 1, 1, 1).view(
                    -1, sourceL, dim)).view(self.n_samples, batch, targetL,
                                            dim)
        else:
            # For enumerate, K = S.
            # memory_bank: N x S x H
            context_y = (
                memory_bank.unsqueeze(0).repeat(targetL, 1, 1, 1)  # T, N, S, H
                .permute(2, 1, 0, 3))  # S, N, T, H
        input = input.unsqueeze(0).repeat(context_y.size(0), 1, 1, 1)
        concat_y = torch.cat([input, context_y], -1)
        # K x N x T x H
        h_y = self.tanh(self.linear_out(concat_y))

        if one_step:
            if h_c is not None:
                # N x H
                h_c = h_c.squeeze(1)
            # N x S
            c_align_vectors = c_align_vectors.squeeze(1)
            context_c = context_c.squeeze(1)

            # K x N x H
            h_y = h_y.squeeze(2)
            # K x N x S
            #y_align_vectors = y_align_vectors.squeeze(2)

            q_scores = Params(
                alpha=q_scores.alpha.squeeze(1)
                if q_scores.alpha is not None else None,
                dist_type=q_scores.dist_type,
                samples=q_sample.squeeze(2) if q_sample is not None else None,
                sample_log_probs=sample_log_probs.squeeze(2)
                if sample_log_probs is not None else None,
                sample_log_probs_q=sample_log_probs_q.squeeze(2)
                if sample_log_probs_q is not None else None,
                sample_log_probs_p=sample_log_probs_p.squeeze(2)
                if sample_log_probs_p is not None else None,
                sample_p_div_q_log=sample_p_div_q_log.squeeze(2)
                if sample_p_div_q_log is not None else None,
            ) if q_scores is not None else None
            p_scores = Params(
                alpha=p_scores.alpha.squeeze(1),
                log_alpha=log_scores.squeeze(1),
                dist_type=p_scores.dist_type,
                samples=p_sample.squeeze(2) if p_sample is not None else None,
            )

            if h_c is not None:
                # Check output sizes
                batch_, dim_ = h_c.size()
                aeq(batch, batch_)
                batch_, sourceL_ = c_align_vectors.size()
                aeq(batch, batch_)
                aeq(sourceL, sourceL_)
        else:
            assert False
            # Only support input feeding.
            # T x N x H
            h_c = h_c.transpose(0, 1).contiguous()
            # T x N x S
            c_align_vectors = c_align_vectors.transpose(0, 1).contiguous()

            # T x K x N x H
            h_y = h_y.permute(2, 0, 1, 3).contiguous()
            # T x K x N x S
            #y_align_vectors = y_align_vectors.permute(2, 0, 1, 3).contiguous()

            q_scores = Params(
                alpha=q_scores.alpha.transpose(0, 1).contiguous(),
                dist_type=q_scores.dist_type,
                samples=q_sample.permute(2, 0, 1, 3).contiguous(),
            )
            p_scores = Params(
                alpha=p_scores.alpha.transpose(0, 1).contiguous(),
                log_alpha=log_scores.transpose(0, 1).contiguous(),
                dist_type=p_scores.dist_type,
                samples=p_sample.permute(2, 0, 1, 3).contiguous(),
            )

            # Check output sizes
            targetL_, batch_, dim_ = h_c.size()
            aeq(targetL, targetL_)
            aeq(batch, batch_)
            aeq(dim, dim_)
            targetL_, batch_, sourceL_ = c_align_vectors.size()
            aeq(targetL, targetL_)
            aeq(batch, batch_)
            aeq(sourceL, sourceL_)

        # For now, don't include samples.
        dist_info = DistInfo(
            q=q_scores,
            p=p_scores,
        )

        # h_y: samples from simplex
        #   either K x N x H, or T x K x N x H
        # h_c: convex combination of memory_bank for input feeding
        #   either N x H, or T x N x H
        # align_vectors: convex coefficients / boltzmann dist
        #   either N x S, or T x N x S
        # raw_scores: unnormalized scores
        #   either N x S, or T x N x S
        return h_y, h_c, context_c, c_align_vectors, dist_info
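
Examples #2 and #15 wrap the attention parameters in Params and DistInfo containers that are not defined on this page; they are used like namedtuples whose unused fields stay None. A plausible minimal stand-in, with the field names taken from how they are accessed above and everything else assumed:

from collections import namedtuple

# All fields default to None so callers can fill only what a dist_type needs.
Params = namedtuple(
    "Params",
    ["alpha", "log_alpha", "dist_type", "samples",
     "sample_log_probs", "sample_log_probs_q",
     "sample_log_probs_p", "sample_p_div_q_log"],
)
Params.__new__.__defaults__ = (None,) * len(Params._fields)

DistInfo = namedtuple("DistInfo", ["q", "p"])
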
Example #16
    def forward2(self, input, memory_bank, memory_lengths=None, coverage=None):
        """

        Args:
          input (`FloatTensor`): query vectors `[batch x tgt_len x dim]`
          memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]`
          memory_lengths (`LongTensor`): the source context lengths `[batch]`
          coverage (`FloatTensor`): None (not supported yet)

        Returns:
          (`FloatTensor`, `FloatTensor`):

          * Computed vector `[tgt_len x batch x dim]`
          * Attention distributions for each query
             `[tgt_len x batch x src_len]`
        """

        memory_bank1, memory_bank2 = memory_bank
        if memory_lengths is not None:
            memory_lengths1, memory_lengths2 = memory_lengths

        # one step input
        if input.dim() == 2:
            one_step = True
            input = input.unsqueeze(1)
        else:
            one_step = False

        batch1, sourceL1, dim1 = memory_bank1.size()
        batch2, sourceL2, dim2 = memory_bank2.size()
        batch_, targetL, dim_ = input.size()
        aeq(batch1, batch2)
        aeq(batch1, batch_)
        aeq(dim1, dim2)
        aeq(dim1, dim_)
        aeq(self.dim, dim1)
        if coverage is not None:
            batch_, sourceL_ = coverage.size()
            aeq(batch1, batch_)
            aeq(sourceL2, sourceL_)

        if coverage is not None:
            # Todo: do not support
            cover = coverage.view(-1).unsqueeze(1)
            memory_bank2 += self.linear_cover(cover).view_as(memory_bank2)
            memory_bank2 = self.tanh(memory_bank2)

        # compute attention scores, as in Luong et al.
        align1 = self.score(input, memory_bank1)
        align2 = self.score(input, memory_bank2)

        if memory_lengths is not None:
            mask1 = sequence_mask(memory_lengths1)
            mask1 = mask1.unsqueeze(1)  # Make it broadcastable.
            align1.data.masked_fill_(~mask1, -float('inf'))

            mask2 = sequence_mask(memory_lengths2)
            mask2 = mask2.unsqueeze(1)  # Make it broadcastable.
            align2.data.masked_fill_(~mask2, -float('inf'))

        # Softmax to normalize attention weights
        align_vectors1 = self.sm(align1.view(batch1*targetL, sourceL1))
        align_vectors1 = align_vectors1.view(batch1, targetL, sourceL1)

        align_vectors2 = self.sm(align2.view(batch2 * targetL, sourceL2))
        align_vectors2 = align_vectors2.view(batch2, targetL, sourceL2)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        c1 = torch.bmm(align_vectors1, memory_bank1)    # 64 * 1 * 256
        c2 = torch.bmm(align_vectors2, memory_bank2)    # 64 * 1 * 256

        # concatenate
        concat_c = torch.cat([c1, c2, input], 2).view(batch1*targetL, dim1*3)
        attn_h = self.linear_out2(concat_c).view(batch1, targetL, dim1)        # decoding output
        if self.attn_type in ["general", "dot"]:
            attn_h = self.tanh(attn_h)

        if one_step:
            attn_h = attn_h.squeeze(1)
            align_vectors1 = align_vectors1.squeeze(1)
            align_vectors2 = align_vectors2.squeeze(1)
            # Check output sizes
            batch_, dim_ = attn_h.size()
            aeq(batch1, batch_)
            aeq(dim1, dim_)
            batch_, sourceL_ = align_vectors1.size()
            aeq(batch1, batch_)
            aeq(sourceL1, sourceL_)
        else:
            attn_h = attn_h.transpose(0, 1).contiguous()
            align_vectors1 = align_vectors1.transpose(0, 1).contiguous()
            align_vectors2 = align_vectors2.transpose(0, 1).contiguous()

            # Check output sizes
            targetL_, batch_, dim_ = attn_h.size()
            aeq(targetL, targetL_)
            aeq(batch1, batch_)
            aeq(dim1, dim_)
            targetL_, batch_, sourceL_ = align_vectors1.size()
            aeq(targetL, targetL_)
            aeq(batch1, batch_)
            aeq(sourceL1, sourceL_)

        return attn_h, (align_vectors1, align_vectors2)
Example #17
    def forward(self,
                input,
                memory_bank,
                entity_representation,
                memory_lengths=None,
                coverage=None,
                count_entities=None,
                total_entities_list=None):
        """

        Args:
          input (`FloatTensor`): query vectors `[batch x tgt_len x dim]`
          memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]`
          entity_representation (`FloatTensor`): source vectors `[batch x num_entities x eu_k_dim]`
          memory_lengths (`LongTensor`): the source context lengths `[batch]`
          coverage (`FloatTensor`): None (not supported yet)
          count_entities (`LongTensor`): entity lengths `[batch]`
          total_entities_list (`FloatTensor`): source vectors `[batch x num_entities x src_len]`

        Returns:
          (`FloatTensor`, `FloatTensor`):

          * Computed vector `[tgt_len x batch x dim]`
          * Attention distributions for each query
             `[tgt_len x batch x src_len]`
        """

        # one step input
        if input.dim() == 2:
            one_step = True
            input = input.unsqueeze(1)
        else:
            one_step = False

        batch, sourceL, dim = memory_bank.size()
        batch__, sourceL__, dim__ = entity_representation.size()
        batch_, targetL, dim_ = input.size()
        batch___, num_entities, src_len = total_entities_list.size()
        aeq(batch, batch_)
        aeq(batch, batch__)
        aeq(self.dim, dim_)
        aeq(self.entity_dim, dim__)
        aeq(self.dim, dim)
        aeq(sourceL, src_len)
        aeq(num_entities, sourceL__)
        if coverage is not None:
            batch_, sourceL_ = coverage.size()
            aeq(batch, batch_)
            aeq(sourceL, sourceL_)

        if coverage is not None:
            cover = coverage.view(-1).unsqueeze(1)
            memory_bank += self.linear_cover(cover).view_as(memory_bank)
            memory_bank = self.tanh(memory_bank)

        # compute attention scores, as in Luong et al.
        entity_align = self.score(input,
                                  entity_representation,
                                  entity_attn=True)

        if count_entities is not None:
            count_entities_mask = sequence_mask(count_entities.data)
            count_entities_mask = count_entities_mask.unsqueeze(
                1)  # Make it broadcastable.
            entity_align.data.masked_fill_(~count_entities_mask, -float('inf'))
        entity_align_vectors = self.sm(
            entity_align.view(batch * targetL, sourceL__))
        entity_align_vectors = entity_align_vectors.unsqueeze(2).expand(
            -1, -1, sourceL)

        align = self.score(input, memory_bank)
        align = align.unsqueeze(2).expand(-1, -1, sourceL__, -1)
        total_entities_list = total_entities_list.unsqueeze(1).expand(
            -1, targetL, -1, -1)
        align = align * total_entities_list  # apply mask of records belonging to entities
        mask = total_entities_list.eq(0)
        align.data.masked_fill_(mask.data, -float('inf'))

        # Softmax to normalize attention weights
        align_vectors = self.sm(
            align.view(batch * targetL * sourceL__, sourceL))

        count_entities_mask = count_entities_mask.unsqueeze(
            3)  #.expand(-1, -1, -1, sourceL)
        align_vectors = align_vectors.view(batch, targetL, sourceL__, sourceL)
        align_vectors.data.masked_fill_(~count_entities_mask, 0)
        align_vectors = align_vectors.view(batch * targetL, sourceL__, sourceL)

        # marginalize the entity dimension: p(record) = sum_e p(entity=e) * p(record | entity=e)
        align_vectors = (entity_align_vectors * align_vectors).sum(1)
        align_vectors = align_vectors.view(batch, targetL, sourceL)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        c = torch.bmm(align_vectors, memory_bank)

        # concatenate
        concat_c = torch.cat([c, input], 2).view(batch * targetL, dim * 2)
        attn_h = self.linear_out(concat_c).view(batch, targetL, dim)
        if self.attn_type in ["general", "dot"]:
            attn_h = self.tanh(attn_h)

        if one_step:
            attn_h = attn_h.squeeze(1)
            align_vectors = align_vectors.squeeze(1)

            # Check output sizes
            batch_, dim_ = attn_h.size()
            aeq(batch, batch_)
            aeq(dim, dim_)
            batch_, sourceL_ = align_vectors.size()
            aeq(batch, batch_)
            aeq(sourceL, sourceL_)
        else:
            attn_h = attn_h.transpose(0, 1).contiguous()
            align_vectors = align_vectors.transpose(0, 1).contiguous()

            # Check output sizes
            targetL_, batch_, dim_ = attn_h.size()
            aeq(targetL, targetL_)
            aeq(batch, batch_)
            aeq(dim, dim_)
            targetL_, batch_, sourceL_ = align_vectors.size()
            aeq(targetL, targetL_)
            aeq(batch, batch_)
            aeq(sourceL, sourceL_)

        return attn_h, align_vectors
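The forward above composes two softmaxes: a distribution over entities and, per entity, a distribution over the records that belong to it, then sums out the entity dimension. A stand-alone sketch of that marginalization with made-up sizes (all names hypothetical):

import torch

bt, n_entities, n_records = 6, 3, 10          # batch*tgt_len, entities, records
p_entity = torch.softmax(torch.randn(bt, n_entities), dim=-1)
p_record_given_entity = torch.softmax(torch.randn(bt, n_entities, n_records), dim=-1)
# p(record) = sum_e p(entity = e) * p(record | entity = e)
p_record = (p_entity.unsqueeze(2) * p_record_given_entity).sum(dim=1)  # [bt x n_records]
assert torch.allclose(p_record.sum(dim=-1), torch.ones(bt))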
Exemple #18
    def forward(self,
                input,
                context,
                ctl,
                ctl_iter,
                context_lengths=None,
                coverage=None):
        """

        Args:
          input (`FloatTensor`): query vectors `[batch x tgt_len x dim]`
          context (`FloatTensor`): source vectors `[batch x src_len x dim]`
          ctl (`FloatTensor`): control values `[batch x 1]`
          ctl_iter (`FloatTensor`): per-step control values `[batch x tgt_len]`
          context_lengths (`LongTensor`): the source context lengths `[batch]`
          coverage (`FloatTensor`): None (not supported yet)

        Returns:
          (`FloatTensor`, `FloatTensor`):

          * Computed vector `[tgt_len x batch x dim]`
          * Attention distributions for each query
             `[tgt_len x batch x src_len]`
        """

        # one step input
        if input.dim() == 2:
            one_step = True
            input = input.unsqueeze(1)
        else:
            one_step = False

        batch, sourceL, dim = context.size()
        batch_, targetL, dim_ = input.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        aeq(self.dim, dim)
        if coverage is not None:
            batch_, sourceL_ = coverage.size()
            aeq(batch, batch_)
            aeq(sourceL, sourceL_)

        if coverage is not None:
            cover = coverage.view(-1).unsqueeze(1)
            context += self.linear_cover(cover).view_as(context)
            context = self.tanh(context)

        # compute attention scores, as in Luong et al.
        align = self.score(input, context)

        if context_lengths is not None:
            mask = sequence_mask(context_lengths)
            mask = mask.unsqueeze(1)  # Make it broadcastable.
            align.data.masked_fill_(~mask, -float('inf'))

        # Softmax to normalize attention weights
        align_vectors = self.sm(align.view(batch * targetL, sourceL))
        align_vectors = align_vectors.view(batch, targetL, sourceL)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        c = torch.bmm(align_vectors, context)

        ctl_diff = ctl.expand(-1, ctl_iter.size()[1]) - ctl_iter
        weights_attn = self.sigmoid(ctl_diff)
        weights_lm = 1 - self.sigmoid(ctl_diff)

        weights_attn = weights_attn.expand(c.size()[2], -1, -1)
        weights_lm = weights_lm.expand(c.size()[2], -1, -1)
        weights_attn = torch.transpose(weights_attn, 0, 1)
        weights_attn = torch.transpose(weights_attn, 1, 2)
        weights_lm = torch.transpose(weights_lm, 0, 1)
        weights_lm = torch.transpose(weights_lm, 1, 2)

        # concatenate
        concat_c = torch.cat([c * weights_attn, input * weights_lm],
                             2).view(batch * targetL, dim * 2)
        attn_h = self.linear_trans(concat_c).view(batch, targetL, dim)

        if self.attn_type in ["general", "dot"]:
            attn_h = self.tanh(attn_h)

        if one_step:
            attn_h = attn_h.squeeze(1)
            align_vectors = align_vectors.squeeze(1)

            # Check output sizes
            batch_, dim_ = attn_h.size()
            aeq(batch, batch_)
            aeq(dim, dim_)
            batch_, sourceL_ = align_vectors.size()
            aeq(batch, batch_)
            aeq(sourceL, sourceL_)
        else:
            attn_h = attn_h.transpose(0, 1).contiguous()
            align_vectors = align_vectors.transpose(0, 1).contiguous()

            # Check output sizes
            targetL_, batch_, dim_ = attn_h.size()
            aeq(targetL, targetL_)
            aeq(batch, batch_)
            aeq(dim, dim_)
            targetL_, batch_, sourceL_ = align_vectors.size()
            aeq(targetL, targetL_)
            aeq(batch, batch_)
            aeq(sourceL, sourceL_)

        return attn_h, align_vectors
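The `ctl` / `ctl_iter` pair above acts as a soft switch: `sigmoid(ctl - ctl_iter)` scales the attention context while its complement scales the decoder input before the output projection. A stand-alone sketch of the gating with hypothetical shapes and values:

import torch

batch, tgt_len, dim = 2, 3, 8
c = torch.randn(batch, tgt_len, dim)                   # attention context vectors
dec_input = torch.randn(batch, tgt_len, dim)           # decoder-side query vectors
ctl = torch.full((batch, 1), 4.0)                      # per-example control value
ctl_iter = torch.arange(tgt_len).float().expand(batch, tgt_len)  # per-step control values
gate = torch.sigmoid(ctl - ctl_iter).unsqueeze(2)      # [batch x tgt_len x 1]
mixed = torch.cat([c * gate, dec_input * (1 - gate)], dim=2)     # [batch x tgt_len x 2*dim]
# in the original code, `mixed` is then passed through a linear output layer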
Exemple #19
    def forward(self, input, memory_bank, memory_lengths=None, coverage=None):
        """

        Args:
          input (`FloatTensor`): query vectors `[batch x tgt_len x dim]`
          memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]`
          memory_lengths (`LongTensor`): the source context lengths `[batch]`
          coverage (`FloatTensor`): None (not supported yet)

        Returns:
          (`FloatTensor`, `FloatTensor`):

          * Computed vector `[tgt_len x batch x dim]`
          * Attention distributions for each query
             `[tgt_len x batch x src_len]`
        """

        # one step input
        if input.dim() == 2:
            one_step = True
            input = input.unsqueeze(1)
        else:
            one_step = False

        batch, sourceL, dim = memory_bank.size()
        batch_, targetL, dim_ = input.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        aeq(self.dim, dim)
        if coverage is not None:
            batch_, sourceL_ = coverage.size()
            aeq(batch, batch_)
            aeq(sourceL, sourceL_)

        if coverage is not None:
            cover = coverage.view(-1).unsqueeze(1)
            memory_bank += self.linear_cover(cover).view_as(memory_bank)
            memory_bank = self.tanh(memory_bank)

        # compute attention scores, as in Luong et al.
        # Local attention
        # Generate aligned position p_t
        if self.attn_model == "local-p": # If predictive alignment model
            p_t = torch.zeros((batch, targetL, 1), device=input.device) + (sourceL - 1)
            p_t = p_t * self.sigmoid(self.v_predictive(self.tanh(self.linear_predictive(input.view(-1, dim))))).view(batch, targetL, 1)
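            # i.e. p_t = (S - 1) * sigmoid(v_p^T tanh(W_p h_t)), so the predicted center lies in [0, S - 1]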
        elif self.attn_model == "local-m": # If monotonic alignment model
            p_t = torch.arange(targetL, device=input.device).repeat(batch, 1).view(batch, targetL, 1)
        # Create a mask to filter all scores that are outside of the window with size 2D
        indices_of_sources = torch.arange(sourceL, device=input.device).repeat(batch, targetL, 1)  # batch x tgt_len x src_len
        mask_local = (indices_of_sources >= p_t - self.D).int() & (indices_of_sources <= p_t + self.D).int()  # batch x tgt_len x src_len
        # Calculate alignment scores
        align = self.score(input, memory_bank, mask_local)

        if memory_lengths is not None:
            mask = sequence_mask(memory_lengths)
            mask = mask.unsqueeze(1)  # Make it broadcastable.
            align.data.masked_fill_(~mask, -float('inf'))

        # Softmax to normalize attention weights
        align_vectors = self.sm(align.view(batch*targetL, sourceL)).view(batch, targetL, sourceL)
        # align_vectors = align_vectors.view(batch, targetL, sourceL)
        # Local attention
        if self.attn_model == "local-p": # If predictive alignment model
            # Favor alignment points near p_t  by truncated Gaussian distribution
            gaussian = torch.exp(-1.0*(((indices_of_sources - p_t) ** 2))/(2*(self.D/2.0)**2)) * mask_local.float() # batch x tgt_len x src_len
            align_vectors = align_vectors * gaussian
        # each context vector c_t is the weighted average
        # over all the source hidden states
        c = torch.bmm(align_vectors, memory_bank)

        # concatenate
        concat_c = torch.cat([c, input], 2).view(batch*targetL, dim*2)
        attn_h = self.linear_out(concat_c).view(batch, targetL, dim)
        if self.attn_score_func in ["general", "dot"]:
            attn_h = self.tanh(attn_h)

        if one_step:
            attn_h = attn_h.squeeze(1)
            align_vectors = align_vectors.squeeze(1)

            # Check output sizes
            batch_, dim_ = attn_h.size()
            aeq(batch, batch_)
            aeq(dim, dim_)
            batch_, sourceL_ = align_vectors.size()
            aeq(batch, batch_)
            aeq(sourceL, sourceL_)
        else:
            attn_h = attn_h.transpose(0, 1).contiguous()
            align_vectors = align_vectors.transpose(0, 1).contiguous()

            # Check output sizes
            targetL_, batch_, dim_ = attn_h.size()
            aeq(targetL, targetL_)
            aeq(batch, batch_)
            aeq(dim, dim_)
            targetL_, batch_, sourceL_ = align_vectors.size()
            aeq(targetL, targetL_)
            aeq(batch, batch_)
            aeq(sourceL, sourceL_)

        return attn_h, align_vectors
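The local-p branch above re-weights the normalized scores with a truncated Gaussian centered on the predicted position, using sigma = D/2 as in Luong et al. A stand-alone numeric sketch of the window mask and the Gaussian factor (made-up values):

import torch

src_len, D = 10, 2
p_t = torch.tensor([[[4.3]]])                          # predicted center, [batch x tgt_len x 1]
positions = torch.arange(src_len).float().view(1, 1, src_len)
window = (positions >= p_t - D) & (positions <= p_t + D)          # keep a window of width 2*D
sigma = D / 2.0
gaussian = torch.exp(-((positions - p_t) ** 2) / (2 * sigma ** 2)) * window.float()
# positions near 4.3 keep most of their weight; everything outside [2.3, 6.3] is zeroed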
Exemple #20
    def _run_forward_pass(self,
                          tgt,
                          memory_bank,
                          state,
                          memory_lengths=None,
                          q_scores_sample=None,
                          q_scores=None):
        """
        See StdRNNDecoder._run_forward_pass() for description
        of arguments and return values.
        """
        # Additional args check.
        input_feed = state.input_feed.squeeze(0)
        input_feed_batch, _ = input_feed.size()
        tgt_len, tgt_batch, _ = tgt.size()
        aeq(tgt_batch, input_feed_batch)
        # END Additional args check.

        if self.dist_type == "dirichlet":
            p_a_scores = [[]]
        elif self.dist_type == "normal":
            p_a_scores = [[], []]
        else:
            p_a_scores = [[]]
        n_param = len(p_a_scores)
        # Initialize local and return variables.
        decoder_outputs = []
        attns = {"std": []}
        if q_scores_sample is not None:
            attns["q"] = []
        if q_scores is not None:
            attns["q_raw_mean"] = []
            attns["q_raw_std"] = []
            attns["p_raw_mean"] = []
            attns["p_raw_std"] = []
        if self._copy:
            attns["copy"] = []
        if self._coverage:
            attns["coverage"] = []

        emb = self.embeddings(tgt)
        assert emb.dim() == 3  # len x batch x embedding_dim

        tgt_len, batch_size = emb.size(0), emb.size(1)
        src_len = memory_bank.size(0)

        hidden = state.hidden
        #[item.fill_(0) for item in hidden]
        coverage = state.coverage.squeeze(0) \
            if state.coverage is not None else None

        # Input feed concatenates hidden state with
        # input at every time step.
        if q_scores is not None:
            q_scores_mean = q_scores[0].view(batch_size, tgt_len,
                                             -1).transpose(0, 1)
            q_scores_std = q_scores[1].view(batch_size, tgt_len,
                                            -1).transpose(0, 1)
        for i, emb_t in enumerate(emb.split(1)):
            emb_t = emb_t.squeeze(0)
            decoder_input = torch.cat([emb_t, input_feed], 1)

            rnn_output, hidden = self.rnn(decoder_input, hidden)
            if q_scores_sample is not None:
                q_sample = q_scores_sample[i]
            else:
                q_sample = None
            decoder_output, p_attn, raw_scores = self.attn(
                rnn_output,
                memory_bank.transpose(0, 1),
                memory_lengths=memory_lengths,
                q_scores_sample=q_sample)
            if q_sample is not None:
                attns["q"] += [q_sample]
                attns["q_raw_mean"] += [q_scores_mean[i]]
                attns["q_raw_std"] += [q_scores_std[i]]
                attns["p_raw_mean"] += [raw_scores[0].view(-1, src_len)]
                attns["p_raw_std"] += [raw_scores[1].view(-1, src_len)]

            # raw_scores: [batch x tgt_len x src_len]
            #assert raw_scores.size() == (batch_size, 1, src_len)

            assert len(raw_scores) == n_param
            for j in range(n_param):
                p_a_scores[j] += [raw_scores[j]]
            if self.context_gate is not None:
                # TODO: context gate should be employed
                # instead of second RNN transform.
                decoder_output = self.context_gate(decoder_input, rnn_output,
                                                   decoder_output)
            decoder_output = self.dropout(decoder_output)
            input_feed = decoder_output

            decoder_outputs += [decoder_output]
            attns["std"] += [p_attn]

            # Update the coverage attention.
            if self._coverage:
                coverage = coverage + p_attn \
                    if coverage is not None else p_attn
                attns["coverage"] += [coverage]

            # Run the forward pass of the copy attention layer.
            if self._copy and not self._reuse_copy_attn:
                _, copy_attn = self.copy_attn(decoder_output,
                                              memory_bank.transpose(0, 1))
                attns["copy"] += [copy_attn]
            elif self._copy:
                attns["copy"] = attns["std"]

        for i in range(n_param):
            p_a_scores[i] = torch.cat(p_a_scores[i], dim=1)
        if memory_lengths is not None:
            mask = sequence_mask(memory_lengths)
            mask = mask.unsqueeze(1)
            if self.dist_type == 'normal':
                p_a_scores[0].data.masked_fill_(~mask, -999)
                p_a_scores[1].data.masked_fill_(~mask, 0.001)
            else:
                for i in range(n_param):
                    p_a_scores[i].data.masked_fill_(~mask, 1e-2)
        return hidden, decoder_outputs, attns, p_a_scores
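A caller of `_run_forward_pass` typically stacks the per-step lists into `[tgt_len x batch x ...]` tensors before computing losses. A hedged usage sketch with made-up sizes (the caller shown here is hypothetical, not part of the original code):

import torch

# outputs of a hypothetical call: lists with one entry per target step
tgt_len, batch, rnn_size, src_len = 7, 4, 256, 11
decoder_outputs = [torch.randn(batch, rnn_size) for _ in range(tgt_len)]
std_attn = [torch.softmax(torch.randn(batch, src_len), dim=-1) for _ in range(tgt_len)]

outputs = torch.stack(decoder_outputs)   # [tgt_len x batch x rnn_size]
attn = torch.stack(std_attn)             # [tgt_len x batch x src_len]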