Example #1
    def _run_forward_pass(self, tgt, memory_bank, memory_lengths=None):
        """
		Private helper for running the specific RNN forward pass.
		Must be overriden by all subclasses.
		Args:
			tgt (LongTensor): a sequence of input tokens tensors
								 [len x batch x nfeats].
			memory_bank (FloatTensor): output(tensor sequence) from the
						  encoder RNN of size (src_len x batch x hidden_size).
			state (FloatTensor): hidden state from the encoder RNN for
								 initializing the decoder.
			memory_lengths (LongTensor): the source memory_bank lengths.
		Returns:
			dec_state (Tensor): final hidden state from the decoder.
			dec_outs ([FloatTensor]): an array of output of every time
									 step from the decoder.
			attns (dict of (str, [FloatTensor]): a dictionary of different
							type of attention Tensor array of every time
							step from the decoder.
		"""
        assert not self._copy  # TODO: copy not supported yet.
        assert not self._coverage  # TODO: coverage not supported yet.

        # Initialize local and return variables.
        attns = {}
        emb = self.embeddings(tgt)

        # Run the forward pass of the RNN.
        if isinstance(self.rnn, nn.GRU):
            rnn_output, dec_state = self.rnn(emb, self.state["hidden"][0])
        else:
            rnn_output, dec_state = self.rnn(emb, self.state["hidden"])

        # Check
        tgt_len, tgt_batch, _ = tgt.size()
        output_len, output_batch, _ = rnn_output.size()
        aeq(tgt_len, output_len)
        aeq(tgt_batch, output_batch)
        # END

        # Calculate the attention.
        dec_outs, p_attn = self.attn(rnn_output.transpose(0, 1).contiguous(),
                                     memory_bank.transpose(0, 1),
                                     memory_lengths=memory_lengths)
        attns["std"] = p_attn

        # Apply the context gate.
        if self.context_gate is not None:
            dec_outs = self.context_gate(
                emb.view(-1, emb.size(2)),
                rnn_output.view(-1, rnn_output.size(2)),
                dec_outs.view(-1, dec_outs.size(2)))
            dec_outs = dec_outs.view(tgt_len, tgt_batch, self.hidden_size)

        dec_outs = self.dropout(dec_outs)
        return dec_state, dec_outs, attns
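
# Illustrative sketch (not part of the example above, sizes assumed): nn.LSTM
# carries its state as an (h_n, c_n) tuple while nn.GRU carries a single
# tensor, which is why the branch above passes self.state["hidden"][0] only
# in the GRU case.
import torch
import torch.nn as nn

emb = torch.randn(5, 3, 16)                 # [len x batch x emb_dim]
lstm, gru = nn.LSTM(16, 32), nn.GRU(16, 32)
_, lstm_state = lstm(emb)                   # (h_n, c_n) tuple
_, gru_state = gru(emb)                     # single h_n tensor
assert isinstance(lstm_state, tuple)
assert isinstance(gru_state, torch.Tensor)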
Example #2
    def forward(self, base_target_emb, input_from_dec, encoder_out_top,
                encoder_out_combine):
        """
        Args:
            base_target_emb: target emb tensor
            input: output of decode conv
            encoder_out_t: the key matrix for calculation of attetion weight,
                which is the top output of encode conv
            encoder_out_combine:
                the value matrix for the attention-weighted sum,
                which is the combination of base emb and top output of encode

        """
        # checks
        # batch, channel, height, width = base_target_emb.size()
        batch, _, height, _ = base_target_emb.size()
        # batch_, channel_, height_, width_ = input_from_dec.size()
        batch_, _, height_, _ = input_from_dec.size()
        aeq(batch, batch_)
        aeq(height, height_)

        # enc_batch, enc_channel, enc_height = encoder_out_top.size()
        enc_batch, _, enc_height = encoder_out_top.size()
        # enc_batch_, enc_channel_, enc_height_ = encoder_out_combine.size()
        enc_batch_, _, enc_height_ = encoder_out_combine.size()

        aeq(enc_batch, enc_batch_)
        aeq(enc_height, enc_height_)

        preatt = seq_linear(self.linear_in, input_from_dec)
        target = (base_target_emb + preatt) * SCALE_WEIGHT
        target = torch.squeeze(target, 3)
        target = torch.transpose(target, 1, 2)
        pre_attn = torch.bmm(target, encoder_out_top)

        if self.mask is not None:
            pre_attn.masked_fill_(self.mask, -float('inf'))

        attn = F.softmax(pre_attn, dim=2)

        context_output = torch.bmm(attn,
                                   torch.transpose(encoder_out_combine, 1, 2))
        context_output = torch.transpose(
            torch.unsqueeze(context_output, 3), 1, 2)
        return context_output, attn
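
# Shape sketch of the two batched matrix products above (sizes are assumed
# for illustration): queries x keys -> attention scores, then scores x
# values -> context vectors.
import torch
import torch.nn.functional as F

B, C, T, S = 2, 8, 5, 7                     # batch, channels, tgt len, src len
target = torch.randn(B, T, C)               # squeezed/transposed decoder side
keys = torch.randn(B, C, S)                 # plays the role of encoder_out_top
values = torch.randn(B, S, C)               # encoder_out_combine, transposed
scores = torch.bmm(target, keys)            # [B, T, S]
attn = F.softmax(scores, dim=2)
context = torch.bmm(attn, values)           # [B, T, C]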
Example #3
    def score(self, h_t, h_s):
        """
		Args:
		  h_t (`FloatTensor`): sequence of queries `[batch x tgt_len x dim]`
		  h_s (`FloatTensor`): sequence of sources `[batch x src_len x dim]`

		Returns:
		  :obj:`FloatTensor`:
		   raw attention scores (unnormalized) for each src index
		  `[batch x tgt_len x src_len]`

		"""

        # Check input sizes
        src_batch, src_len, src_dim = h_s.size()
        tgt_batch, tgt_len, tgt_dim = h_t.size()
        aeq(src_batch, tgt_batch)
        aeq(src_dim, tgt_dim)
        aeq(self.dim, src_dim)

        if self.attn_type in ["general", "dot"]:
            if self.attn_hidden > 0:
                h_t = self.transform_in(h_t)
                h_s = self.transform_in(h_s)

            if self.attn_type == "general":
                h_t = self.linear_in(h_t)
            h_s_ = h_s.transpose(1, 2)
            # (batch, t_len, d) x (batch, d, s_len) --> (batch, t_len, s_len)
            return torch.bmm(h_t, h_s_)
        else:
            dim = self.dim
            d = self.attn_hidden if self.attn_hidden > 0 else dim
            wq = self.linear_query(h_t.view(-1, dim))
            wq = wq.view(tgt_batch, tgt_len, 1, d)
            wq = wq.expand(tgt_batch, tgt_len, src_len, d)

            uh = self.linear_context(h_s.contiguous().view(-1, dim))
            uh = uh.view(src_batch, 1, src_len, d)
            uh = uh.expand(src_batch, tgt_len, src_len, d)

            # (batch, t_len, s_len, d)
            wquh = self.tanh(wq + uh)

            if self.attn_type == "mlp":
                return self.v(wquh.view(-1, d)).view(tgt_batch, tgt_len,
                                                     src_len)
            elif self.attn_type == "fine":
                return self.v(wquh.view(-1, d)).view(tgt_batch, tgt_len,
                                                     src_len, dim)
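
# Minimal stand-alone rendering (assumed sizes) of the "dot" and "general"
# branches above: both reduce to one batched matrix product, with "general"
# first applying a learned linear map to the queries.
import torch
import torch.nn as nn

batch, tgt_len, src_len, dim = 2, 4, 6, 8
h_t = torch.randn(batch, tgt_len, dim)
h_s = torch.randn(batch, src_len, dim)
linear_in = nn.Linear(dim, dim, bias=False)

dot_scores = torch.bmm(h_t, h_s.transpose(1, 2))      # [batch, tgt_len, src_len]
general_scores = torch.bmm(linear_in(h_t), h_s.transpose(1, 2))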
Example #4
    def forward(self, input, memory_bank, memory_lengths=None, coverage=None):
        """

		Args:
		  input (`FloatTensor`): query vectors `[batch x tgt_len x dim]`
		  memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]`
		  memory_lengths (`LongTensor`): the source context lengths `[batch]`
		  coverage (`FloatTensor`): None (not supported yet)

		Returns:
		  (`FloatTensor`, `FloatTensor`):

		  * Computed vector `[tgt_len x batch x dim]`
		  * Attention distribtutions for each query
			 `[tgt_len x batch x src_len]`
		"""

        # one step input
        if input.dim() == 2:
            one_step = True
            input = input.unsqueeze(1)
        else:
            one_step = False

        batch, sourceL, dim = memory_bank.size()
        batch_, targetL, dim_ = input.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        aeq(self.dim, dim)
        if coverage is not None:
            batch_, sourceL_ = coverage.size()
            aeq(batch, batch_)
            aeq(sourceL, sourceL_)

        if coverage is not None:
            cover = coverage.view(-1).unsqueeze(1)
            memory_bank += self.linear_cover(cover).view_as(memory_bank)
            memory_bank = self.tanh(memory_bank)

        # compute attention scores, as in Luong et al.
        align = self.score(input, memory_bank)
        assert memory_lengths is not None
        mask = sequence_mask(memory_lengths)
        mask = mask.unsqueeze(1)  # Make it broadcastable.
        # mask the time step of self
        mask = mask.repeat(1, sourceL, 1)
        mask_self_index = list(range(sourceL))
        mask[:, mask_self_index, mask_self_index] = 0

        if self.attn_type == "fine":
            mask = mask.unsqueeze(3)

        align.masked_fill_(~mask, -float('inf'))

        # Softmax to normalize attention weights
        align_vectors = self.sm(align)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        if self.attn_type == "fine":
            c = memory_bank.unsqueeze(1).mul(align_vectors).sum(dim=2,
                                                                keepdim=False)
        else:
            c = torch.bmm(align_vectors, memory_bank)

        # concatenate
        concat_c = torch.cat([c, input], 2)
        attn_h = self.linear_out(concat_c)
        if self.attn_type in ["general", "dot"]:
            # attn_h = F.elu(attn_h, 0.1)
            # attn_h = F.elu(self.dropout(attn_h) + input, 0.1)

            # content selection gate
            attn_h = torch.sigmoid(attn_h).mul(input)

        if one_step:
            attn_h = attn_h.squeeze(1)
            align_vectors = align_vectors.squeeze(1)

            # Check output sizes
            batch_, dim_ = attn_h.size()
            aeq(batch, batch_)
            aeq(dim, dim_)
            batch_, sourceL_ = align_vectors.size()
            aeq(batch, batch_)
            aeq(sourceL, sourceL_)
        else:
            attn_h = attn_h.transpose(0, 1).contiguous()
            align_vectors = align_vectors.transpose(0, 1).contiguous()

            # Check output sizes
            targetL_, batch_, dim_ = attn_h.size()
            aeq(targetL, targetL_)
            aeq(batch, batch_)
            aeq(dim, dim_)
            targetL_, batch_, sourceL_ = align_vectors.size()
            aeq(targetL, targetL_)
            aeq(batch, batch_)
            aeq(sourceL, sourceL_)

        return attn_h, align_vectors
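
# Minimal sketch of what sequence_mask is assumed to do here: build a
# [batch x max_len] boolean mask that is True on real positions and False on
# padding, so padded scores can be filled with -inf before the softmax.
import torch

def sequence_mask_sketch(lengths, max_len=None):
    max_len = max_len or int(lengths.max())
    steps = torch.arange(max_len, device=lengths.device)
    return steps.unsqueeze(0) < lengths.unsqueeze(1)

print(sequence_mask_sketch(torch.tensor([3, 5]), max_len=5))
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True]])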
Example #5
    def _check_args(self, src, lengths=None, hidden=None):
        _, n_batch, _ = src.size()
        if lengths is not None:
            n_batch_, = lengths.size()
            aeq(n_batch, n_batch_)
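
# Illustrative sketch of the aeq helper used throughout these examples
# (assumed semantics: assert that every argument equals the first one).
def aeq(*args):
    first = args[0]
    assert all(arg == first for arg in args[1:]), \
        "arguments are not all equal: %s" % str(args)

aeq(4, 4, 4)  # passes silently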
Example #6
    def forward(self, source, memory_bank, memory_lengths=None, coverage=None):
        """

        Args:
          source (`FloatTensor`): query vectors `[batch x tgt_len x dim]`
          memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]`
          memory_lengths (`LongTensor`): the source context lengths `[batch]`
          coverage (`FloatTensor`): None (not supported yet)

        Returns:
          (`FloatTensor`, `FloatTensor`):

          * Computed vector `[tgt_len x batch x dim]`
          * Attention distributions for each query
             `[tgt_len x batch x src_len]`
        """

        # one step input
        if source.dim() == 2:
            one_step = True
            source = source.unsqueeze(1)
        else:
            one_step = False

        batch, source_l, dim = memory_bank.size()
        batch_, target_l, dim_ = source.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        aeq(self.dim, dim)
        if coverage is not None:
            batch_, source_l_ = coverage.size()
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        if coverage is not None:
            cover = coverage.view(-1).unsqueeze(1)
            memory_bank += self.linear_cover(cover).view_as(memory_bank)
            memory_bank = torch.tanh(memory_bank)

        # compute attention scores, as in Luong et al.
        align = self.score(source, memory_bank)

        if memory_lengths is not None:
            mask = sequence_mask(memory_lengths, max_len=align.size(-1))
            mask = mask.unsqueeze(1)  # Make it broadcastable.
            align.masked_fill_(~mask, -float('inf'))

        # Softmax or sparsemax to normalize attention weights
        if self.attn_func == "softmax":
            align_vectors = F.softmax(align.view(batch * target_l, source_l),
                                      -1)
        else:
            align_vectors = sparsemax(align.view(batch * target_l, source_l),
                                      -1)
        align_vectors = align_vectors.view(batch, target_l, source_l)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        c = torch.bmm(align_vectors, memory_bank)

        # concatenate
        concat_c = torch.cat([c, source], 2).view(batch * target_l, dim * 2)
        attn_h = self.linear_out(concat_c).view(batch, target_l, dim)
        if self.attn_type in ["general", "dot"]:
            attn_h = torch.tanh(attn_h)

        if one_step:
            attn_h = attn_h.squeeze(1)
            align_vectors = align_vectors.squeeze(1)

            # Check output sizes
            batch_, dim_ = attn_h.size()
            aeq(batch, batch_)
            aeq(dim, dim_)
            batch_, source_l_ = align_vectors.size()
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        else:
            attn_h = attn_h.transpose(0, 1).contiguous()
            align_vectors = align_vectors.transpose(0, 1).contiguous()
            # Check output sizes
            target_l_, batch_, dim_ = attn_h.size()
            aeq(target_l, target_l_)
            aeq(batch, batch_)
            aeq(dim, dim_)
            target_l_, batch_, source_l_ = align_vectors.size()
            aeq(target_l, target_l_)
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        return attn_h, align_vectors
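
# Sketch of the Luong-style output layer above (sizes assumed): the context
# vectors are concatenated with the queries, projected back to `dim`, and
# passed through tanh for the "general"/"dot" attention types.
import torch
import torch.nn as nn

batch, target_l, source_l, dim = 2, 4, 6, 8
align_vectors = torch.softmax(torch.randn(batch, target_l, source_l), dim=-1)
memory_bank = torch.randn(batch, source_l, dim)
source = torch.randn(batch, target_l, dim)
linear_out = nn.Linear(dim * 2, dim, bias=False)

c = torch.bmm(align_vectors, memory_bank)              # [batch, target_l, dim]
concat_c = torch.cat([c, source], 2).view(batch * target_l, dim * 2)
attn_h = torch.tanh(linear_out(concat_c)).view(batch, target_l, dim)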
Example #7
	def _run_forward_pass(self, tgt, memory_bank1, memory_bank2,
	                      memory_lengths1=None, memory_lengths2=None):
		"""
		See StdRNNDecoder._run_forward_pass() for description
		of arguments and return values.
		Args:
			tgt (`LongTensor`): sequences of padded tokens
				 `[tgt_len x batch x nfeats]`.
			memory_bank1 (`FloatTensor`): vectors from the encoder1
				 `[src_len x batch x hidden]`.
			memory_lengths1 (`LongTensor`): the padded source lengths
				`[batch]`.
			memory_bank2 (`FloatTensor`): vectors from the encoder2
				 `[tmpl_len x batch x hidden]`.
			memory_lengths2 (`LongTensor`): the padded template lengths
				`[batch]`.

		"""
		# Additional args check.
		input_feed = self.state["input_feed"].squeeze(0)
		input_feed_batch, _ = input_feed.size()
		_, tgt_batch, _ = tgt.size()
		aeq(tgt_batch, input_feed_batch)
		# END Additional args check.

		# Initialize local and return variables.
		dec_outs = []
		attns = {"std": []}  # std attn for src1
		attns["std2"] = []  # std attn for src2
		if self._copy:
			attns["copy"] = []  # copy attn for src1
		if self._coverage:
			# TODO: necessary for src2?
			attns["coverage"] = []  # coverage for src1

		emb = self.embeddings(tgt)
		assert emb.dim() == 3  # len x batch x embedding_dim

		dec_state = self.state["hidden"]
		coverage = self.state["coverage"].squeeze(0) \
			if self.state["coverage"] is not None else None

		# Input feed concatenates hidden state with
		# input at every time step.
		for _, emb_t in enumerate(emb.split(1)):
			emb_t = emb_t.squeeze(0)
			decoder_input = torch.cat([emb_t, input_feed], 1)
			rnn_output, dec_state = self.rnn(decoder_input, dec_state)

			decoder_output1, p_attn1 = self.attn(
				rnn_output,
				memory_bank1.transpose(0, 1),
				memory_lengths=memory_lengths1)

			decoder_output2, p_attn2 = self.attn(
				rnn_output,
				memory_bank2.transpose(0, 1),
				memory_lengths=memory_lengths2)

			if self.context_gate is not None:
				# TODO: context gate should be employed
				# instead of second RNN transform.
				decoder_output = self.context_gate(
					decoder_input, rnn_output, decoder_output1, decoder_output2
				)
			else:
				# Without a context gate there is no defined way to merge the
				# two attention outputs; fail loudly rather than fall through
				# to an undefined decoder_output below.
				raise RuntimeError(
					"context_gate is required to combine both attention outputs")
			decoder_output = self.dropout(decoder_output)
			input_feed = decoder_output

			dec_outs += [decoder_output]
			attns["std"] += [p_attn1]
			attns["std2"] += [p_attn2]

			# Update the coverage attention.
			if self._coverage:
				coverage = coverage + p_attn1 if coverage is not None else p_attn1
				attns["coverage"] += [coverage]

			# Run the forward pass of the copy attention layer.
			if self._copy and not self._reuse_copy_attn:
				_, copy_attn1 = self.copy_attn(decoder_output,
				                               memory_bank1.transpose(0, 1))
				attns["copy"] += [copy_attn1]
			elif self._copy:
				attns["copy"] = attns["std"]
		# Return result.
		return dec_state, dec_outs, attns
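
# Illustrative input-feed loop with assumed sizes (nn.LSTMCell stands in for
# the decoder's stacked RNN): the previous attentional output is concatenated
# to each embedding, so the per-step input size is emb_dim + hidden.
import torch
import torch.nn as nn

batch, emb_dim, hidden = 3, 10, 20
rnn = nn.LSTMCell(emb_dim + hidden, hidden)
input_feed = torch.zeros(batch, hidden)
state = (torch.zeros(batch, hidden), torch.zeros(batch, hidden))

for emb_t in torch.randn(4, batch, emb_dim):     # four decoding steps
    decoder_input = torch.cat([emb_t, input_feed], 1)
    state = rnn(decoder_input, state)
    input_feed = state[0]                        # stands in for the attn output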
Example #8
    def _run_forward_pass(self, tgt, memory_bank, memory_lengths=None):
        """
		See StdRNNDecoder._run_forward_pass() for description
		of arguments and return values.
		"""
        # Additional args check.
        input_feed = self.state["input_feed"].squeeze(0)
        input_feed_batch, _ = input_feed.size()
        _, tgt_batch, _ = tgt.size()
        aeq(tgt_batch, input_feed_batch)
        # END Additional args check.

        # Initialize local and return variables.
        dec_outs = []
        attns = {"std": []}
        if self._copy:
            attns["copy"] = []
        if self._coverage:
            attns["coverage"] = []

        emb = self.embeddings(tgt)
        assert emb.dim() == 3  # len x batch x embedding_dim

        dec_state = self.state["hidden"]
        coverage = self.state["coverage"].squeeze(0) \
            if self.state["coverage"] is not None else None

        # Input feed concatenates hidden state with
        # input at every time step.
        for _, emb_t in enumerate(emb.split(1)):
            emb_t = emb_t.squeeze(0)
            decoder_input = torch.cat([emb_t, input_feed], 1)
            rnn_output, dec_state = self.rnn(decoder_input, dec_state)
            decoder_output, p_attn = self.attn(rnn_output,
                                               memory_bank.transpose(0, 1),
                                               memory_lengths=memory_lengths)
            if self.context_gate is not None:
                # TODO: context gate should be employed
                # instead of second RNN transform.
                decoder_output = self.context_gate(decoder_input, rnn_output,
                                                   decoder_output)
            decoder_output = self.dropout(decoder_output)
            input_feed = decoder_output

            dec_outs += [decoder_output]
            attns["std"] += [p_attn]

            # Update the coverage attention.
            if self._coverage:
                coverage = p_attn if coverage is None else coverage + p_attn
                attns["coverage"] += [coverage]

            # Run the forward pass of the copy attention layer.
            if self._copy and not self._reuse_copy_attn:
                _, copy_attn = self.copy_attn(decoder_output,
                                              memory_bank.transpose(0, 1))
                attns["copy"] += [copy_attn]
            elif self._copy:
                attns["copy"] = attns["std"]
        # Return result.
        return dec_state, dec_outs, attns
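
# Minimal sketch of the coverage update in the loop above: coverage is the
# running sum of the per-step attention distributions over source positions.
import torch

src_len, batch = 6, 2
coverage = None
for _ in range(3):                               # three decoding steps
    p_attn = torch.softmax(torch.randn(batch, src_len), dim=-1)
    coverage = p_attn if coverage is None else coverage + p_attn
# each row of coverage now sums to the number of steps taken so far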