Example #1
    def forward(self, tgt, memory_bank, memory_bank_t, step=None, **kwargs):
        """Decode, possibly stepwise."""
        if step == 0:
            self._init_cache(memory_bank)

        tgt_words = tgt[:, :, 0].transpose(0, 1)

        emb = self.embeddings(tgt, step=step)
        assert emb.dim() == 3  # len x batch x embedding_dim

        output = emb.transpose(0, 1).contiguous()
        src_memory_bank = memory_bank.transpose(0, 1).contiguous()
        # second memory bank (memory_bank_t) for the additional encoder states
        tt_memory_bank = memory_bank_t.transpose(0, 1).contiguous()

        pad_idx = self.embeddings.word_padding_idx
        src_lens = kwargs["memory_lengths"]
        src_max_len = self.state["src"].shape[0]
        src_pad_mask = ~sequence_mask(src_lens, src_max_len).unsqueeze(1)
        # padding mask for the second memory bank (memory_lengths_t)
        tt_lens = kwargs["memory_lengths_t"]
        tt_max_len = self.state["tt"].shape[0]
        tt_pad_mask = ~sequence_mask(tt_lens, tt_max_len).unsqueeze(1)

        tgt_pad_mask = tgt_words.data.eq(pad_idx).unsqueeze(1)  # [B, 1, T_tgt]

        with_align = kwargs.pop('with_align', False)
        attn_aligns = []

        for i, layer in enumerate(self.transformer_layers):
            layer_cache = self.state["cache"]["layer_{}".format(i)] \
                if step is not None else None
            output, attn, attn_align = layer(output,
                                             src_memory_bank,
                                             src_pad_mask,
                                             tt_memory_bank,
                                             tt_pad_mask,
                                             tgt_pad_mask,
                                             layer_cache=layer_cache,
                                             step=step,
                                             with_align=with_align)
            if attn_align is not None:
                attn_aligns.append(attn_align)

        output = self.layer_norm(output)
        dec_outs = output.transpose(0, 1).contiguous()
        attn = attn.transpose(0, 1).contiguous()

        attns = {"std": attn}
        if self._copy:
            attns["copy"] = attn
        if with_align:
            attns["align"] = attn_aligns[self.alignment_layer]  # `(B, Q, K)`
            # attns["align"] = torch.stack(attn_aligns, 0).mean(0)  # All avg

        # TODO change the way attns is returned dict => list or tuple (onnx)
        return dec_outs, attns
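All of these examples lean on the same sequence_mask(lengths, max_len=None) helper. A minimal sketch, assuming OpenNMT-py-style semantics (True at valid positions, False at padding), of how the (batch, 1, max_len) pad masks used above are derived:

import torch

def sequence_mask(lengths, max_len=None):
    # True for valid positions, False for padding; ~mask gives a pad mask
    max_len = max_len or int(lengths.max())
    positions = torch.arange(max_len, device=lengths.device)
    return positions.unsqueeze(0) < lengths.unsqueeze(1)   # (batch, max_len)

lengths = torch.tensor([3, 1, 4])
src_pad_mask = ~sequence_mask(lengths, max_len=4).unsqueeze(1)  # (batch, 1, max_len), True = pad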
Example #2
    def forward(self, source, memory_bank, memory_lengths=None, coverage=None):
        # here we do not need to calculate the align
        # because the answer vector is already averaged representations
        if source.dim() == 2:
            source = source.unsqueeze(1)

        batch, source_l, dim = memory_bank.size()
        batch_, target_l, dim_ = source.size()
        aeq(batch, batch_)

        # compute attention scores, as in Luong et al.
        align = self.score(source, memory_bank)

        if memory_lengths is not None:
            mask = sequence_mask(memory_lengths, max_len=align.size(-1))
            mask = mask.unsqueeze(1)  # Make it broadcastable.
            align.masked_fill_(~mask, -float('inf'))

        # Softmax or sparsemax to normalize attention weights
        if self.attn_func == "softmax":
            align_vectors = F.softmax(align.view(batch * target_l, source_l), -1)
        else:
            align_vectors = sparsemax(align.view(batch * target_l, source_l), -1)
        align_vectors = align_vectors.view(batch, target_l, source_l)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        c = torch.bmm(align_vectors, memory_bank)

        return c.squeeze(1), align_vectors
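For reference, a self-contained sketch of the mask-then-normalize pattern this forward uses, with a plain dot product standing in for self.score and made-up sizes:

import torch
import torch.nn.functional as F

batch, target_l, source_l, dim = 2, 1, 5, 8
source = torch.randn(batch, target_l, dim)         # query vectors
memory_bank = torch.randn(batch, source_l, dim)    # source states
memory_lengths = torch.tensor([5, 3])

align = torch.bmm(source, memory_bank.transpose(1, 2))          # (batch, target_l, source_l)
mask = (torch.arange(source_l) < memory_lengths.unsqueeze(1)).unsqueeze(1)
align = align.masked_fill(~mask, float('-inf'))                  # hide padded source positions
align_vectors = F.softmax(align, dim=-1)
c = torch.bmm(align_vectors, memory_bank)                        # context, (batch, target_l, dim)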
Example #3
    def forward(self, source, memory_bank, memory_lengths=None,
                memory_turns = None, coverage=None):
        # here we implement a hierarchical attention
        if source.dim() == 2:
            source = source.unsqueeze(1)

        batch, source_tl, source_wl, dim = memory_bank.size()
        batch_, target_l, dim_ = source.size()
        aeq(batch, batch_)
        # word level attention
        word_align = self.word_score(source, memory_bank.contiguous()
                           .view(batch, -1, dim))

        # transform align (b, 1, tl * wl) -> (b * tl, 1, wl)
        word_align = word_align.view(batch * source_tl, 1, source_wl)
        if memory_lengths is not None:
            word_mask = sequence_mask_herd(memory_lengths.view(-1), max_len=word_align.size(-1))
            word_mask = word_mask.unsqueeze(1)  # Make it broadcastable.
            word_align.masked_fill_(~word_mask, -float('inf'))

        # Softmax or sparsemax to normalize attention weights
        if self.attn_func == "softmax":
            word_align_vectors = F.softmax(word_align.view(batch * source_tl, source_wl), -1)
        else:
            word_align_vectors = sparsemax(word_align.view(batch * source_tl, source_wl), -1)

        # mask the all padded sentences
        sent_pad_mask = memory_lengths.view(-1).eq(0).unsqueeze(1)
        word_align_vectors = torch.mul(word_align_vectors,
                                       (~sent_pad_mask).type_as(word_align_vectors))
        word_align_vectors = word_align_vectors.view(batch * source_tl, target_l, source_wl)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        cw = torch.bmm(word_align_vectors, memory_bank.view(batch * source_tl, source_wl, -1))
        cw = cw.view(batch, source_tl, -1)
        # concat_cw = torch.cat([cw, source.repeat(1, source_tl, 1)], 2).view(batch*source_tl, -1)
        # attn_hw = self.word_linear_out(concat_cw).view(batch, source_tl, -1)
        # attn_hw = torch.tanh(attn_hw)

        # turn level attention
        turn_align = self.turn_score(source, cw)

        if memory_turns is not None:
            turn_mask = sequence_mask(memory_turns, max_len=turn_align.size(-1))
            turn_mask = turn_mask.unsqueeze(1)  # Make it broadcastable.
            turn_align.masked_fill_(~turn_mask, -float('inf'))

        # Softmax or sparsemax to normalize attention weights
        if self.attn_func == "softmax":
            turn_align_vectors = F.softmax(turn_align.view(batch * target_l, source_tl), -1)
        else:
            turn_align_vectors = sparsemax(turn_align.view(batch * target_l, source_tl), -1)
        turn_align_vectors = turn_align_vectors.view(batch, target_l, source_tl)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        ct = torch.bmm(turn_align_vectors, cw)

        return ct.squeeze(1), None
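The delicate part of this hierarchical variant is the shape bookkeeping between the word level and the turn level. A small sketch of the regrouping done by the view() calls above, with dummy sizes and masking omitted:

import torch

batch, source_tl, source_wl, dim = 2, 3, 4, 8
memory_bank = torch.randn(batch, source_tl, source_wl, dim)
word_align = torch.randn(batch, 1, source_tl * source_wl)        # scores over the flattened word axis

word_align = word_align.view(batch * source_tl, 1, source_wl)    # one row of word scores per turn
word_ctx = torch.bmm(torch.softmax(word_align, dim=-1),
                     memory_bank.view(batch * source_tl, source_wl, dim))
cw = word_ctx.view(batch, source_tl, dim)                        # one context vector per turn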
Example #4
    def forward(self, src, lengths=None):
        """See :func:`EncoderBase.forward()`"""
        self._check_args(src, lengths)

        emb = self.embeddings(src)

        out = emb.transpose(0, 1).contiguous()
        '''
        indicators = []
        for idx, l in enumerate(lengths):
            src_item = src[:l, idx, 0]
            src_seq = str(l.item()) + '_' + ' '.join([self.vocab[i] for i in src_item])
            if src_seq not in self.reaction_atoms:
                print('error')
            reaction_atom_indicator = self.reaction_atoms[src_seq].tolist()
            reaction_atom_indicator.extend([0] * (out.shape[1] - l.item()))
            indicators.append(reaction_atom_indicator)
        indicators = np.array(indicators, dtype='float32')
        indicators = torch.from_numpy(indicators).float().cuda()
        indicators = indicators.unsqueeze(2)
        out = torch.cat((out, indicators), dim=2)
        out = self.linear(out)
        out = self.linear_dropout(out)        
        '''
        mask = ~sequence_mask(lengths).unsqueeze(1)
        # Run the forward pass of every layer of the transformer.
        for layer in self.transformer:
            out = layer(out, mask)
        out = self.layer_norm(out)

        # User added to incorporate reaction atom indicator
        out = out.transpose(0, 1).contiguous()

        return emb, out, lengths
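The mask built here marks padded positions with True. For comparison only (not part of this project), a sketch of feeding the same padding information to torch.nn.TransformerEncoderLayer, whose src_key_padding_mask also expects True at padded positions; assumes PyTorch >= 1.9 for batch_first:

import torch
import torch.nn as nn

batch, seq_len, d_model = 2, 6, 32
lengths = torch.tensor([6, 3])
out = torch.randn(batch, seq_len, d_model)

pad_mask = torch.arange(seq_len) >= lengths.unsqueeze(1)   # (batch, seq_len), True = pad
layer = nn.TransformerEncoderLayer(d_model, nhead=4, batch_first=True)
out = layer(out, src_key_padding_mask=pad_mask)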
Example #5
    def forward(self, memory_bank, memory_lengths=None, coverage=None):
        """

        Args:
          source (FloatTensor): query vectors ``(batch, tgt_len, dim)``
          memory_bank (FloatTensor): source vectors ``(batch, src_len, dim)``
          memory_lengths (LongTensor): the source context lengths ``(batch,)``
          coverage (FloatTensor): None (not supported yet)

        Returns:
          (FloatTensor, FloatTensor):

          * Computed vector ``(tgt_len, batch, dim)``
          * Attention distributions for each query
            ``(tgt_len, batch, src_len)``
        """
        batch, source_l, dim = memory_bank.size()
        source = self.source
        source = source.expand(batch, -1)
        source = source.unsqueeze(1)
        batch_, target_l, dim_ = source.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        aeq(self.dim, dim)
        if coverage is not None:
            batch_, source_l_ = coverage.size()
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        if coverage is not None:
            cover = coverage.view(-1).unsqueeze(1)
            memory_bank += self.linear_cover(cover).view_as(memory_bank)
            memory_bank = torch.tanh(memory_bank)

        # compute attention scores, as in Luong et al.
        align = self.score(source, memory_bank)

        if memory_lengths is not None:
            mask = sequence_mask(memory_lengths, max_len=align.size(-1))
            mask = mask.unsqueeze(1)  # Make it broadcastable.
            align.masked_fill_(~mask, -float('inf'))

        # Softmax or sparsemax to normalize attention weights
        if self.attn_func == "softmax":
            align_vectors = F.softmax(align.view(batch*target_l, source_l), -1)
        else:
            align_vectors = sparsemax(align.view(batch*target_l, source_l), -1)
        align_vectors = align_vectors.view(batch, target_l, source_l)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        c = torch.bmm(align_vectors, memory_bank)  # batch, target_l, dim
        c = c.mean(dim=1)  # batch, dim

        # Check output sizes
        batch_, dim_ = c.size()
        aeq(batch, batch_)
        aeq(dim, dim_)

        return c
Example #6
    def forward(self, src, grh, lengths=None):
        """See :func:`EncoderBase.forward()`"""
        self._check_args(src, lengths)
        assert src.size(0) == grh.size(-1), "srclen != grh_n"

        emb = self.embeddings(src) # batch, srclen, dim

        out = emb.transpose(0, 1).contiguous()
        mask = ~sequence_mask(lengths).unsqueeze(1)
        # Run the forward pass of every layer of the transformer.
        out_list = []
        out_list.append(out)
        if self.aggregation == "dense":
            for i in range(len(self.transformer)):
                out = self.dense_linear[i](torch.cat(out_list, dim=-1))
                out = self.transformer[i](out, grh, mask)
                out_list.append(out)
            aggregate_out = torch.cat(out_list, dim=-1)
            aggregate_out = self.aggregate_layer_norm(aggregate_out)
            aggregate_out = self.aggregate_linear(aggregate_out)
            return emb, aggregate_out.transpose(0, 1).contiguous(), lengths
        else:
            for layer in self.transformer:
                out = layer(out, grh, mask)
                out_list.append(out)
            out = self.layer_norm(out)
            if self.aggregation == "jump":
                aggregate_out = torch.cat(out_list, dim=-1)
                aggregate_out = self.aggregate_layer_norm(aggregate_out)
                aggregate_out = self.aggregate_linear(aggregate_out)
                return emb, aggregate_out.transpose(0, 1).contiguous(), lengths
        return emb, out.transpose(0, 1).contiguous(), lengths
Example #7
    def forward(self, src, lengths=None, batch=None):
        """See :func:`EncoderBase.forward()`"""
        self._check_args(src, lengths)

        emb = self.embeddings(src)

        batch_size = emb.size()[1]
        seq_len = emb.size()[0]

        out = emb.transpose(0, 1).contiguous()
        mask = ~sequence_mask(lengths).unsqueeze(1)

        batch_geometric = get_embs_graph(batch, out)

        if self.decoder_dim != self.d_model:
            out = self.linear_before(out)

        for i, layer in enumerate(self.transformer):
            out, attns = layer(out, mask)

            out_gnn = out.view((batch_size * seq_len, self.d_model))
            memory_bank = self.gnns[i](out_gnn, batch_geometric.edge_index, edge_type=batch_geometric.y)

            out = self.rnn(memory_bank, out_gnn)

            out = out.view((batch_size, seq_len, self.d_model))

        out = self.layer_norm(out)
        if self.decoder_dim != self.d_model:
            out = self.linear(out)
        attn_weight = F.softmax(self.FC(out), dim=1)
        outputs2 = attn_weight * out
        outputs2 = self.FC2(outputs2.sum(dim=1))

        return emb, out.transpose(0, 1).contiguous(), lengths, outputs2
Example #8
    def forward(self, src, lengths=None, segment_count=None, padding_value=None):
        self._check_args(src, lengths)
        encoder_final, memory_bank, lengths = self.rnn_encoder(src, lengths, enforce_sorted=False)
        context = self.attn(memory_bank.transpose(0,1), memory_lengths=lengths)
        segment_count_, dim_ = context.size()
        assert dim_ == self.hidden_size
        assert segment_count.sum() == segment_count_
        batch_ = segment_count.size()
        segment_representation = context.split(segment_count.tolist())
        segment_representation_padded = pad_sequence(segment_representation, padding_value=padding_value)
        memory_bank, _ = self.content_selection_attn(segment_representation_padded.transpose(0, 1).contiguous(),
                                        segment_representation_padded.transpose(0, 1),
                                        memory_lengths=segment_count)

        _, batch, emb_dim = memory_bank.size()
        assert batch == batch_[0]
        assert emb_dim == self.hidden_size
        if segment_count is not None:
            # we avoid padding while mean pooling
            mask = sequence_mask(segment_count).float()
            mask = mask / segment_count.unsqueeze(1).float()
            mean = torch.bmm(mask.unsqueeze(1), memory_bank.transpose(0, 1)).squeeze(1)
        else:
            mean = memory_bank.mean(0)

        mean = mean.expand(self.num_layers, batch, emb_dim)
        encoder_final = (mean, mean)
        return encoder_final, memory_bank, segment_count
Example #9
    def forward(self, src, lengths=None, batch=None):
        """See :func:`EncoderBase.forward()`"""
        self._check_args(src, lengths)

        emb = self.embeddings(src)

        batch_size = emb.size()[1]
        seq_len = emb.size()[0]

        out = emb.transpose(0, 1).contiguous()
        mask = ~sequence_mask(lengths).unsqueeze(1)

        if self.decoder_dim != self.d_model:
            out = self.linear_before(out)

        batch_geometric = get_embs_graph(batch, out)
        memory_bank = batch_geometric.x

        for layer in self.gnns:
            new_memory_bank = layer(memory_bank, batch_geometric.edge_index, edge_type=batch_geometric.y)
            memory_bank = self.rnn(new_memory_bank, memory_bank)

        memory_bank = memory_bank.view((batch_size, seq_len, self.d_model))

        for layer in self.transformer:
            out, attns = layer(out, mask)
        out = self.layer_norm(out)

        out = torch.cat([out, memory_bank], dim=2)

        out = self.final_linear(out)

        return emb, out.transpose(0, 1).contiguous(), lengths
Example #10
    def _compute_te_loss(self, target_attns):
        """
        :param target_attns: a tuple (stacked_target_attns, target_attns_lens, src_states_target_list)
        :return: target encoding loss
        """
        # stacked_target_attns: [b_size, max_sep_num, sample_size+1]
        # target_attns_lens: [b_size]
        # src_states_target_list: [b_size]
        stacked_target_attns, target_attns_lens, src_states_target_list = target_attns
        b_size, max_sep_num, cls_num = stacked_target_attns.size()
        device = stacked_target_attns.device

        gt_tensor = torch.Tensor(src_states_target_list).view(b_size, 1).repeat(1, max_sep_num).to(device)

        # class_dist_flat: [b_size * max_sep_num, sample_size+1]
        class_dist_flat = stacked_target_attns.view(-1, cls_num)
        log_dist_flat = torch.log(class_dist_flat + EPS)
        target_flat = gt_tensor.view(-1, 1)
        # [b_size * max_sep_num, 1]
        losses_flat = -torch.gather(log_dist_flat, dim=1, index=target_flat.long())
        losses = losses_flat.view(b_size, max_sep_num)

        mask = sequence_mask(torch.Tensor(target_attns_lens)).to(device)
        losses = losses * mask.float()
        losses = losses.sum(dim=1)
        return losses
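The gather over the log distribution is a per-position negative log likelihood, masked by the number of valid separators. A tiny sketch of the same computation on dummy distributions (EPS is assumed to be a small constant, as in the code above):

import torch

EPS = 1e-12
b_size, max_sep_num, cls_num = 2, 3, 4
dist = torch.softmax(torch.randn(b_size, max_sep_num, cls_num), dim=-1)
targets = torch.randint(cls_num, (b_size, max_sep_num))
lens = torch.tensor([3, 2])

nll = -torch.log(dist + EPS).gather(2, targets.unsqueeze(-1)).squeeze(-1)   # (b_size, max_sep_num)
mask = torch.arange(max_sep_num) < lens.unsqueeze(1)
loss_per_example = (nll * mask.float()).sum(dim=1)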
Example #11
    def _compute_orthogonal_loss(self, sep_states):
        """
        The orthogonal loss computation function
        sep_states: a tuple (stacked_sep_states, sep_states_lens)
        :return: a scalar, the orthogonal loss
        """
        # stacked_sep_states: [b_size, max_sep_num, src_h_size]
        stacked_sep_states, sep_states_lens = sep_states
        b_size, max_sep_num, src_h_size = stacked_sep_states.size()
        b_size_ = len(sep_states_lens)
        aeq(b_size, b_size_)

        device = stacked_sep_states.device

        # obtain the mask
        # [b_size, max_sep_num]
        mask = sequence_mask(torch.Tensor(sep_states_lens)).to(device)
        mask = mask.float()
        # [b_size, 1, max_sep_num]
        mask = mask.unsqueeze(1)
        # [b_size, max_sep_num, max_sep_num]
        mask_2d = torch.bmm(mask.transpose(1, 2), mask)

        # compute the loss
        # [b_size, max_sep_num, max_sep_num]
        identity = torch.eye(max_sep_num).unsqueeze(0).repeat(b_size, 1, 1).to(device)
        # [b_size, max_sep_num, max_sep_num]
        orthogonal_loss_ = torch.bmm(stacked_sep_states, stacked_sep_states.transpose(1, 2)) - identity
        orthogonal_loss_ = orthogonal_loss_ * mask_2d
        # [b_size]
        orthogonal_loss = torch.norm(orthogonal_loss_.view(b_size, -1), p=2, dim=1)
        return orthogonal_loss
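A quick sketch of what the penalty measures, on dummy separator states: the Gram matrix is compared against the identity, so the norm is zero only when the states are mutually orthogonal unit vectors.

import torch
import torch.nn.functional as F

b_size, max_sep_num, src_h_size = 2, 3, 5
states = F.normalize(torch.randn(b_size, max_sep_num, src_h_size), dim=-1)

gram = torch.bmm(states, states.transpose(1, 2))                        # pairwise dot products
identity = torch.eye(max_sep_num).expand(b_size, -1, -1)
penalty = torch.norm((gram - identity).view(b_size, -1), p=2, dim=1)    # [b_size]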
Example #12
File: loss.py Project: takatomo-k/s2s
 def _make_shard_state(self, batch, range_, result):
     mel, mel_lengths = batch.tgt
     #import pdb;pdb.set_trace()
     return {
         "output": result["dec_out"],
         "target": mel[1:],
         "lengths": sequence_mask(mel_lengths-1).transpose(0,1).unsqueeze(-1).type(torch.FloatTensor).cuda()
     }
Example #13
    def __call__(self,
                 tgt: torch.Tensor,
                 memory_bank: torch.Tensor,
                 step: Optional[int] = None,
                 **kwargs):
        """Decode, possibly stepwise."""
        if step == 0:
            self._init_cache(memory_bank)

        tgt_words = tgt[:, :, 0].transpose(0, 1)

        emb = self.embeddings(tgt, step=step)
        assert emb.dim() == 3  # len x batch x embedding_dim

        output = emb.transpose(0, 1).contiguous()
        src_memory_bank = memory_bank.transpose(0, 1).contiguous()

        pad_idx = self.embeddings.word_padding_idx
        src_lens = kwargs["memory_lengths"]
        src_max_len = self.state["src"].shape[0]
        #Turbo add bool -> float
        src_pad_mask = ~sequence_mask(src_lens, src_max_len).unsqueeze(1)
        tgt_pad_mask = tgt_words.data.eq(pad_idx).unsqueeze(1)  # [B, 1, T_tgt]

        with_align = kwargs.pop('with_align', False)
        if with_align:
            raise "with_align must be False"
        attn_aligns = []

        # It's Turbo's show time!
        for i, layer in enumerate(self.transformer_layers):
            layer_cache = self.state["cache"]["layer_{}".format(i)] \
                if step is not None else None
            output, attn, attn_align = layer(output,
                                             src_memory_bank,
                                             src_pad_mask,
                                             tgt_pad_mask,
                                             layer_cache=layer_cache,
                                             step=step,
                                             with_align=with_align)
            if attn_align is not None:
                attn_aligns.append(attn_align)

        # Turbo finished.
        output = self.layer_norm(output)
        dec_outs = output.transpose(0, 1).contiguous()
        attn = attn.transpose(0, 1).contiguous()

        attns = {"std": attn}
        if self._copy:
            attns["copy"] = attn
        if with_align:
            attns["align"] = attn_aligns[self.alignment_layer]  # `(B, Q, K)`
            # attns["align"] = torch.stack(attn_aligns, 0).mean(0)  # All avg

        # TODO(OpenNMT-py) change the way attns is returned dict => list or tuple (onnx)

        return dec_outs, attns
Example #14
        def ans_match(src_seq, ans_seq):
            import torch.nn.functional as F
            BF_ans_mask = sequence_mask(ans_lengths)  # [batch, ans_seq_len]
            BF_src_mask = sequence_mask(src_lengths)  # [batch, src_seq_len]
            BF_src_outputs = src_seq.transpose(
                0, 1)  # [batch, src_seq_len, 2*hidden_size]
            BF_ans_outputs = ans_seq.transpose(
                0, 1)  # [batch, ans_seq_len, 2*hidden_size]

            # compute bi-att scores
            src_scores = BF_src_outputs.bmm(BF_ans_outputs.transpose(
                2, 1))  # [batch, src_seq_len, ans_seq_len]
            ans_scores = BF_ans_outputs.bmm(BF_src_outputs.transpose(
                2, 1))  # [batch, ans_seq_len, src_seq_len]

            # mask padding
            Expand_BF_ans_mask = BF_ans_mask.unsqueeze(1).expand(
                src_scores.size())  # [batch, src_seq_len, ans_seq_len]
            src_scores.data.masked_fill_(~(Expand_BF_ans_mask).bool(),
                                         -float('inf'))
            # src_scores = torch.ones(src_scores.shape).to(ans_seq.device)

            Expand_BF_src_mask = BF_src_mask.unsqueeze(1).expand(
                ans_scores.size())  # [batch, ans_seq_len, src_seq_len]
            ans_scores.data.masked_fill_(~(Expand_BF_src_mask).bool(),
                                         -float('inf'))

            # normalize with softmax
            src_alpha = F.softmax(src_scores,
                                  dim=2)  # [batch, src_seq_len, ans_seq_len]
            ans_alpha = F.softmax(ans_scores,
                                  dim=2)  # [batch, ans_seq_len, src_seq_len]

            # take the weighted average
            BF_src_matched_seq = src_alpha.bmm(
                BF_ans_outputs)  # [batch, src_seq_len, 2*hidden_size]
            src_matched_seq = BF_src_matched_seq.transpose(
                0, 1)  # [src_seq_len, batch, 2*hidden_size]

            BF_ans_matched_seq = ans_alpha.bmm(
                BF_src_outputs)  # [batch, ans_seq_len, 2*hidden_size]
            ans_matched_seq = BF_ans_matched_seq.transpose(
                0, 1)  # [ans_seq_len, batch, 2*hidden_size]

            return src_matched_seq, ans_matched_seq
Example #15
    def forward(self, src, lengths=None):
        """See :func:`EncoderBase.forward()`"""
        self._check_args(src, lengths)
        # print("First:", lengths, lengths.dtype)
        emb = self.embeddings(src)
        # print("before", emb.size())

        if self.conv_net or self.max_pooling or self.highway_net or self.linear_mapping:
            emb = emb.permute(1, 0, 2)

        if self.conv_net or self.max_pooling:
            emb = emb.unsqueeze(1)

        if self.conv_net:
            emb = self.conv_net(emb)

        if self.max_pooling:
            emb = self.max_pooling(emb)

        if self.conv_net or self.max_pooling:
            emb = emb.squeeze(1)

        if self.highway_net:
            emb = self.highway_net(emb)

        if self.linear_mapping:
            emb = self.linear_mapping(emb)

        if self.conv_net or self.max_pooling or self.highway_net or self.linear_mapping:
            emb = emb.permute(1, 0, 2)
        # print(emb.size())
        if self.pos_encoding:
            emb = self.pos_encoding(emb)

        out = emb.transpose(0, 1).contiguous()
        # print("-----", out.size())
        # out = emb.contiguous()
        len_type = lengths.dtype
        if self.max_pooling:
            lengths = torch.ceil(lengths.float() /
                                 self.conv_pooling).to(dtype=len_type)

        if lengths.max().tolist() != out.size(1):
            # print("ERRoor", lengths.max().tolist(), out.size(1))
            lengths = torch.tensor(emb.size(1) * [emb.size(0)],
                                   device=emb.device)
        mask = ~sequence_mask(lengths).unsqueeze(1)
        # print("Second:", lengths, lengths.dtype, mask.size(), out.size())
        # print("")
        for layer in self.transformer:
            out = layer(out, mask)
        out = self.layer_norm(out)

        # print(emb.size(), out.transpose(0, 1).contiguous().size(), lengths.size())

        return emb, out.transpose(0, 1).contiguous(), lengths
Example #16
    def __call__(self, batch, rl_forward, baseline_forward):
        """
        There's no better way for now than a for-loop...
        """
        
        rl_sentences, rl_log_probs, rl_attns = rl_forward
        baseline_sentences, baseline_log_probs, baseline_attns = baseline_forward
        
        device = batch.tgt.device
        
        rl_lengths = list()
        baseline_lengths = list()
        decoded_sequences = list()
        rl_scores = list()
        baseline_scores = list()
        for b in range(batch.batch_size):
            
            rl_candidate, rl_length = self.cleaner.clean_candidate_tokens(rl_sentences[:, b],
                                                                    batch.src_map[:, b], 
                                                                    batch.src_ex_vocab[b],
                                                                    rl_attns[:, b])
            baseline_candidate, baseline_length = self.cleaner.clean_candidate_tokens(baseline_sentences[:, b],
                                                                    batch.src_map[:, b], 
                                                                    batch.src_ex_vocab[b],
                                                                    baseline_attns[:, b])
            
            if rl_length == 0:
                rl_length = 1
            rl_lengths.append(rl_length)
            baseline_lengths.append(baseline_length)
            decoded_sequences.append((self.references[batch.indices[b].item()],
                                      " ".join(baseline_candidate), " ".join(rl_candidate)))

            rl_scores.append(self.metric(rl_candidate, batch.indices[b].item()))
            baseline_scores.append(self.metric(baseline_candidate, batch.indices[b].item()))
            
        rl_lengths = torch.LongTensor(rl_lengths).to(device)
        baseline_lengths = torch.LongTensor(baseline_lengths).to(device)
        mask = sequence_mask(rl_lengths, max_len=len(rl_sentences))
        
        sequences_scores = rl_log_probs.masked_fill(~mask.transpose(0,1), 0)
        sequences_scores = sequences_scores.sum(dim=0) / rl_lengths.float()
        
        # we reward the model according to f1_score
        
        rl_rewards = torch.FloatTensor(rl_scores).to(device)
        baseline_rewards = torch.FloatTensor(baseline_scores).to(device)
        rewards = baseline_rewards - rl_rewards
    
        loss = (rewards * sequences_scores).mean()
        stats = self._stats(loss, baseline_rewards.mean(), rl_rewards.mean(),
                            baseline_lengths, rl_lengths,
                            decoded_sequences)
        
        return loss, stats
Example #17
def build_chunk_mask(lengths, ent_size):
    """
    [bsz, n_ents, n_ents]
    Filled with -inf where self-attention shouldn't attend, and zeros elsewhere.
    """
    ones = lengths // ent_size
    ones = sequence_mask(ones).unsqueeze(1).repeat(1, ones.max(),
                                                   1).to(lengths.device)
    mask = torch.full(ones.shape, float('-inf')).to(lengths.device)
    mask.masked_fill_(ones, 0)
    return mask
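A sketch of the resulting additive mask for hypothetical lengths, assuming each entity spans ent_size positions: column j of every row is -inf when entity j is padding, 0 otherwise, so it can be added directly to attention logits.

import torch

lengths = torch.tensor([6, 4])
ent_size = 2
n_ents = lengths // ent_size                                      # tensor([3, 2])
valid = torch.arange(int(n_ents.max())) < n_ents.unsqueeze(1)     # (bsz, max_n_ents)

chunk_mask = torch.zeros(len(lengths), int(n_ents.max()), int(n_ents.max()))
chunk_mask.masked_fill_(~valid.unsqueeze(1), float('-inf'))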
Example #18
    def forward(self, src_enc, tgt, src_lengths):
        dec_outputs = []
        dec_states = []
        dec_attns = []
        tgt_len = tgt.shape[0]
        src_mask = sequence_mask(src_lengths,
                                 max_len=src_enc.size(0)).transpose(0, 1)

        # Precompute all target-side embeddings
        tgt_embed = self.embedding(tgt.squeeze(-1))
        # Initialize decoder state
        s = torch.tanh(self.Ws(src_enc[0, :, :]))
        #dec_states += [s.clone().unsqueeze(0),]
        # Recurrence
        for i in range(1, tgt_len):
            # Condition attention on current decoder state
            s_ = s.unsqueeze(0).expand(src_enc.shape[0], s.shape[0],
                                       s.shape[1])
            inpt = torch.cat([s_, src_enc], -1)
            ei = self.va(torch.tanh(self.Wa(inpt))).squeeze(-1)
            ei = ei.masked_fill(~src_mask, -float('inf'))
            ai = torch.exp(torch.log_softmax(ei, 0))
            #print (ai.shape, src_enc.shape); sys.exit(0)
            ci = torch.sum(ai.unsqueeze(-1) * src_enc, 0)

            # Compute decoder output (single layer, no drop-out)
            # note: uses decoder state s_(i-1), before the state update
            inpt = torch.cat([tgt_embed[i - 1, :, :], s, ci], 1)
            ti = self.Wo(inpt)

            # Store decoder state and output, attention distributions
            dec_states += [
                s.clone().unsqueeze(0),
            ]
            dec_outputs += [
                ti.clone().unsqueeze(0),
            ]
            dec_attns += [
                ai.clone().transpose(0, 1).unsqueeze(0),
            ]

            # Update decoder state
            inpt = torch.cat([tgt_embed[i - 1, :, :], s, ci], -1)
            zi = torch.sigmoid(self.Wz(inpt))  # update
            ri = torch.sigmoid(self.Wr(inpt))  # reset
            inpt = torch.cat([tgt_embed[i - 1, :, :], ri * s, ci], -1)
            ni = torch.tanh(self.Wn(inpt))  # proposal
            s = (1.0 - zi) * s + zi * ni  # new state

        dec_outputs = torch.cat(dec_outputs)  # (tgt_len, batch_len, nhidden)
        dec_states = torch.cat(dec_states)  # (tgt_len, batch_len, nhidden)
        dec_attns = torch.cat(dec_attns)  # (tgt_len, batch_len, src_lengths)
        return dec_outputs, dec_states, dec_attns
Example #19
    def forward(self, src, lengths=None):
        self._check_args(src, lengths)

        emb = self.embeddings(src)

        out = emb.transpose(0, 1).contiguous()
        mask = ~sequence_mask(lengths).unsqueeze(1)
        # Run the forward pass of every layer of the transformer.
        for layer in self.transformer:
            out = layer(out, mask)
        out = self.layer_norm(out)

        return emb, out.transpose(0, 1).contiguous(), lengths
Example #20
File: loss.py Project: takatomo-k/s2s
 def _make_shard_state(self, batch, range_, result):
     #import pdb;pdb.set_trace()
     tgt, tgt_lengths = batch.tgt
     txt = batch.txt[0]
     
     #import pdb;pdb.set_trace()
     return {
         "output": result["dec_out"],
         "target": tgt[1:],
         "tgt_lengths":sequence_mask(tgt_lengths).transpose(0,1).unsqueeze(-1).type(torch.FloatTensor).cuda(),
         "txt_out": result["txt_out"],
         "txt": txt[1:]
        }
Example #21
 def forward(self, src, batch, lengths=None):
     """See :func:`EncoderBase.forward()`"""
     self._check_args(src, lengths)
     emb = self.embeddings(src)
     out = emb.transpose(0, 1).contiguous()
     mask = ~sequence_mask(lengths).unsqueeze(1)
     # Run the forward pass of every layer of the transformer.
     for i, layer in enumerate(self.transformer):
         out, at_self_attn = layer(out, mask)
         self.build_visualization(batch, i, at_self_attn)
     out = self.layer_norm(out)
     self.batch_count += 1
     return emb, out.transpose(0, 1).contiguous(), lengths
Example #22
    def forward(self, src, lengths=None):
        """See :func:`EncoderBase.forward()`"""
        self._check_args(src, lengths)

        emb = self.embeddings(src)  #src[300,13,1]  emb[300,13,512]

        out = emb.transpose(0, 1).contiguous()  #[13,300,512]
        mask = ~sequence_mask(lengths).unsqueeze(1)
        # Run the forward pass of every layer of the transformer.
        for layer in self.transformer:
            out = layer(out, mask)
        out = self.layer_norm(out)

        return emb, out.transpose(0, 1).contiguous(), lengths
Example #23
    def forward(self, memory_bank, lengths):  #encoder_out
        # memory_bank: [maxlen, B, H]
        # lengths: [B, ]
        mask = sequence_mask(lengths).float()  # [B, maxlen]
        mask = mask / lengths.unsqueeze(1).float()  # [B, maxlen]
        # arg1: [B, 1, maxlen], arg2: [B, maxlen, H]] ==> [B, H]
        mean = torch.bmm(mask.unsqueeze(1),
                         memory_bank.transpose(0, 1)).squeeze(1)

        x = torch.tanh(self.fc1(mean))
        if self.dropout is not None:
            x = self.dropout(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=-1, dtype=torch.float32).type_as(x)
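The bmm above computes a mean over valid positions only, by giving each valid position weight 1/length. A self-contained sketch with dummy tensors:

import torch

maxlen, B, H = 5, 2, 4
memory_bank = torch.randn(maxlen, B, H)
lengths = torch.tensor([5, 2])

mask = (torch.arange(maxlen) < lengths.unsqueeze(1)).float()      # [B, maxlen]
mask = mask / lengths.unsqueeze(1).float()                        # weight 1/length per valid position
mean = torch.bmm(mask.unsqueeze(1), memory_bank.transpose(0, 1)).squeeze(1)   # [B, H]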
Example #24
def get_transformer_encoder_attn(model, src, lengths=None):
    emb = model.embeddings(src)
    model._check_args(src, lengths)

    out = emb.transpose(0, 1).contiguous()
    mask = ~sequence_mask(lengths).unsqueeze(1)
    attn_matrices = list()
    # Run the forward pass of every layer of the transformer.
    for layer in model.transformer:
        out, attns = transformer_encoder_forward_with_attn(layer, out, mask)
        attns = attns.detach()
        attn_matrices.append(attns)

    return emb, out.transpose(0, 1).contiguous(), lengths, attn_matrices
Example #25
    def forward(self, tgt, memory_bank, step=None, emotion=None, **kwargs):
        """Decode, possibly stepwise."""
        if step == 0:
            self._init_cache(memory_bank)

        tgt_words = tgt[:, :, 0].transpose(0, 1)

        emb = self.embeddings(tgt, step=step)
        assert emb.dim() == 3  # len x batch x embedding_dim

        # add emotion embedding using linear transformation
        if emotion is not None:
            batch_emotion_embedding = self.emo_embedding(
                emotion)  # (batch, emotion_emb_size)
            batch_emotion_embedding = batch_emotion_embedding.unsqueeze(
                0).repeat(emb.size(0), 1, 1)  # (len, batch, emotion_emb_size)
            emb = self.emo_mlp(
                torch.cat([emb, batch_emotion_embedding],
                          dim=2))  # emb: (len, batch, embedding_dim)

        output = emb.transpose(0, 1).contiguous()
        src_memory_bank = memory_bank.transpose(0, 1).contiguous()

        pad_idx = self.embeddings.word_padding_idx
        src_lens = kwargs["memory_lengths"]
        src_max_len = self.state["src"].shape[0]
        src_pad_mask = ~sequence_mask(src_lens, src_max_len).unsqueeze(1)
        tgt_pad_mask = tgt_words.data.eq(pad_idx).unsqueeze(1)  # [B, 1, T_tgt]

        for i, layer in enumerate(self.transformer_layers):
            layer_cache = self.state["cache"]["layer_{}".format(i)] \
                if step is not None else None
            output, attn = layer(output,
                                 src_memory_bank,
                                 src_pad_mask,
                                 tgt_pad_mask,
                                 layer_cache=layer_cache,
                                 step=step)

        output = self.layer_norm(output)
        dec_outs = output.transpose(0, 1).contiguous()
        attn = attn.transpose(0, 1).contiguous()

        attns = {"std": attn}
        if self._copy:
            attns["copy"] = attn

        # TODO change the way attns is returned dict => list or tuple (onnx)
        return dec_outs, attns
Example #26
File: loss.py Project: takatomo-k/s2s
    def _make_shard_state(self, batch, range_, result):
        tgt, tgt_lengths = batch.tgt
        src_txt = batch.src_txt[0]
        tgt_txt = batch.tgt_txt[0]

        return {
            "output": result["dec_out"],
            "target": tgt[1:],
            "tgt_lengths":sequence_mask(tgt_lengths).transpose(0,1).unsqueeze(-1).type(torch.FloatTensor).cuda(),
            "asr_dec_out": result["asr_dec_out"],
            "src_txt": src_txt[1:],
            "tgt_txt_out": result["tgt_txt_out"],
            "tgt_txt":tgt_txt[1:],
            
           }
Example #27
    def forward(self, src, lengths=None):
        #import pdb;pdb.set_trace()
        """See :func:`EncoderBase.forward()`"""
        self._check_args(src, lengths)

        emb = self.embeddings(src) if self.embeddings is not None else src
        emb = self.pe(self.noise(emb))
        out = emb.transpose(0, 1).contiguous()
        mask = ~sequence_mask(lengths).unsqueeze(1)  # [B, 1, T]
        # Run the forward pass of every layer of the transformer.

        for layer in self.transformer:
            out = layer(out, mask)
        out = self.layer_norm(out)
        return emb, out.transpose(0, 1).contiguous(), lengths
Example #28
    def forward(self, src, imgs=None, lengths=None):
        """See :func:`EncoderBase.forward()`"""
        self._check_args(src, lengths)
        visual_out = self.video_encoder(imgs)

        emb = self.embeddings(src)

        out = emb.transpose(0, 1).contiguous()
        mask = ~sequence_mask(lengths).unsqueeze(1)
        # Run the forward pass of every layer of the transformer.
        for layer in self.transformer:
            out = layer(out, mask, imgs=visual_out)
        out = self.layer_norm(out)

        return emb, (visual_out.transpose(0, 1).contiguous(),
                     out.transpose(0, 1).contiguous()), lengths
Example #29
    def forward(self, source, memory_bank, memory_lengths=None, coverage=None):
        """
                Args:
                  source (`FloatTensor`): query vectors `[batch x tgt_len x dim]`
                  memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]`
                  memory_lengths (`LongTensor`): the source context lengths `[batch]`
                  coverage (`FloatTensor`): None (not supported yet)

                Returns:
                  (`FloatTensor`, `FloatTensor`):

                  * Computed vector `[tgt_len x batch x dim]`
                  * Attention distributions for each query
                     `[tgt_len x batch x src_len]`
                """
        if source.dim() == 2:
            source = source.unsqueeze(1)

        batch, source_l, dim = memory_bank.size()
        batch_, target_l, dim_ = source.size()
        aeq(batch, batch_)

        # compute attention scores, as in Luong et al.
        align = self.score(source, memory_bank)

        if memory_lengths is not None:
            mask = sequence_mask(memory_lengths, max_len=align.size(-1))
            mask = mask.unsqueeze(1)  # Make it broadcastable.
            align.masked_fill_(~mask, -float('inf'))

        # Softmax or sparsemax to normalize attention weights
        if self.attn_func == "softmax":
            align_vectors = F.softmax(align.view(batch * target_l, source_l), -1)
        else:
            align_vectors = sparsemax(align.view(batch * target_l, source_l), -1)
        align_vectors = align_vectors.view(batch, target_l, source_l)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        c = torch.bmm(align_vectors, memory_bank)

        # concatenate
        concat_c = torch.cat([c, source], 2).view(batch * target_l,
                                                  self.indim + self.outdim)
        attn_h = self.linear_out(concat_c).view(batch, target_l, self.outdim)

        return attn_h.squeeze(1), align_vectors.squeeze(1)
Example #30
    def forward(self, tgt, memory_bank, step=None, **kwargs):
        """Decode, possibly stepwise."""
        if step == 0:
            self._init_cache(memory_bank)

        tgt_words = tgt[:, :, 0].transpose(0, 1)

        emb = self.embeddings(tgt, step=step)
        if self.n_latent > 1:
            emb = emb + self.latent_embedding(kwargs["latent_input"].to(
                self.latent_embedding.weight.device))
        assert emb.dim() == 3  # len x batch x embedding_dim

        if self.n_segments > 0:
            emb = emb + self.segment_embedding(kwargs["segment_input"])
        assert emb.dim() == 3  # len x batch x embedding_dim

        output = emb.transpose(0, 1).contiguous()
        src_memory_bank = memory_bank.transpose(0, 1).contiguous()

        pad_idx = self.embeddings.word_padding_idx
        src_lens = kwargs["memory_lengths"]
        src_max_len = self.state["src"].shape[0]
        src_pad_mask = ~sequence_mask(src_lens, src_max_len).unsqueeze(1)
        tgt_pad_mask = tgt_words.data.eq(pad_idx).unsqueeze(1)  # [B, 1, T_tgt]

        for i, layer in enumerate(self.transformer_layers):
            layer_cache = self.state["cache"]["layer_{}".format(i)] \
                if step is not None else None
            output, attn = layer(output,
                                 src_memory_bank,
                                 src_pad_mask,
                                 tgt_pad_mask,
                                 layer_cache=layer_cache,
                                 step=step)

        output = self.layer_norm(output)
        dec_outs = output.transpose(0, 1).contiguous()
        attn = attn.transpose(0, 1).contiguous()

        attns = {"std": attn}
        if self._copy:
            attns["copy"] = attn

        # TODO change the way attns is returned dict => list or tuple (onnx)
        return dec_outs, attns
Example #31
    def forward(self, source, memory_bank, memory_lengths=None, coverage=None):
        """

        Args:
          source (FloatTensor): query vectors ``(batch, tgt_len, dim)``
          memory_bank (FloatTensor): source vectors ``(batch, src_len, dim)``
          memory_lengths (LongTensor): the source context lengths ``(batch,)``
          coverage (FloatTensor): None (not supported yet)

        Returns:
          (FloatTensor, FloatTensor):

          * Computed vector ``(tgt_len, batch, dim)``
          * Attention distributions for each query
            ``(tgt_len, batch, src_len)``
        """

        # one step input
        if source.dim() == 2:
            one_step = True
            source = source.unsqueeze(1)
        else:
            one_step = False

        batch, source_l, dim = memory_bank.size()
        batch_, target_l, dim_ = source.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        aeq(self.dim, dim)
        if coverage is not None:
            batch_, source_l_ = coverage.size()
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        if coverage is not None:
            cover = coverage.view(-1).unsqueeze(1)
            memory_bank += self.linear_cover(cover).view_as(memory_bank)
            memory_bank = torch.tanh(memory_bank)

        # compute attention scores, as in Luong et al.
        align = self.score(source, memory_bank)

        if memory_lengths is not None:
            mask = sequence_mask(memory_lengths, max_len=align.size(-1))
            mask = mask.unsqueeze(1)  # Make it broadcastable.
            align.masked_fill_(~mask, -float('inf'))

        # Softmax or sparsemax to normalize attention weights
        if self.attn_func == "softmax":
            align_vectors = F.softmax(align.view(batch*target_l, source_l), -1)
        else:
            align_vectors = sparsemax(align.view(batch*target_l, source_l), -1)
        align_vectors = align_vectors.view(batch, target_l, source_l)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        c = torch.bmm(align_vectors, memory_bank)

        # concatenate
        concat_c = torch.cat([c, source], 2).view(batch*target_l, dim*2)
        attn_h = self.linear_out(concat_c).view(batch, target_l, dim)
        if self.attn_type in ["general", "dot"]:
            attn_h = torch.tanh(attn_h)

        if one_step:
            attn_h = attn_h.squeeze(1)
            align_vectors = align_vectors.squeeze(1)

            # Check output sizes
            batch_, dim_ = attn_h.size()
            aeq(batch, batch_)
            aeq(dim, dim_)
            batch_, source_l_ = align_vectors.size()
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        else:
            attn_h = attn_h.transpose(0, 1).contiguous()
            align_vectors = align_vectors.transpose(0, 1).contiguous()
            # Check output sizes
            target_l_, batch_, dim_ = attn_h.size()
            aeq(target_l, target_l_)
            aeq(batch, batch_)
            aeq(dim, dim_)
            target_l_, batch_, source_l_ = align_vectors.size()
            aeq(target_l, target_l_)
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        return attn_h, align_vectors
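A minimal sketch of the final concatenate-and-project step on dummy tensors, assuming linear_out maps dim*2 -> dim as in the general/dot variants:

import torch
import torch.nn as nn

batch, target_l, dim = 2, 3, 8
c = torch.randn(batch, target_l, dim)         # attention context
source = torch.randn(batch, target_l, dim)    # decoder queries
linear_out = nn.Linear(dim * 2, dim, bias=False)

concat_c = torch.cat([c, source], dim=2).view(batch * target_l, dim * 2)
attn_h = torch.tanh(linear_out(concat_c)).view(batch, target_l, dim)
attn_h = attn_h.transpose(0, 1).contiguous()  # (target_l, batch, dim), as returned above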