Example #1
    def forward(self,
                input,
                memory_bank,
                memory_lengths=None,
                coverage=None,
                q_scores=None):
        """

        Args:
          input (`FloatTensor`): query vectors `[batch x tgt_len x dim]`
          memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]`
          memory_lengths (`LongTensor`): the source context lengths `[batch]`
          coverage (`FloatTensor`): None (not supported yet)
          q_scores (`FloatTensor`): the attention params from the inference network

        Returns:
          (`FloatTensor`, `FloatTensor`):

          * Weighted context vector `[tgt_len x batch x dim]`
          * Attention distribtutions for each query
             `[tgt_len x batch x src_len]`
          * Unormalized attention scores for each query 
            `[batch x tgt_len x src_len]`
        """

        # one step input
        if input.dim() == 2:
            one_step = True
            input = input.unsqueeze(1)
            if q_scores is not None:
                # add a time dimension to the inference network's
                # parameters so they align with the unsqueezed input
                if q_scores.alpha is not None:
                    q_scores = Params(
                        alpha=q_scores.alpha.unsqueeze(1),
                        log_alpha=q_scores.log_alpha.unsqueeze(1),
                        dist_type=q_scores.dist_type,
                    )
        else:
            one_step = False

        batch, sourceL, dim = memory_bank.size()
        batch_, targetL, dim_ = input.size()
        aeq(batch, batch_)

        # compute attention scores, as in Luong et al.
        # (sequence-level Params are T x N x S; scores here are N x T x S)
        if self.p_dist_type == "categorical":
            scores = self.score(input, memory_bank)
            if memory_lengths is not None:
                # mask: N x 1 x S, broadcastable over the target dimension
                mask = sequence_mask(memory_lengths)
                mask = mask.unsqueeze(1)
                scores.data.masked_fill_(~mask, -float('inf'))
            # keep only the k largest scores per query and push the rest
            # to -inf before normalizing (hard top-k sparsification)
            if 0 < self.k < scores.size(-1):
                topk, idx = scores.data.topk(self.k)
                new_attn_score = torch.full_like(scores.data, float("-inf"))
                new_attn_score = new_attn_score.scatter_(2, idx, topk)
                scores = new_attn_score
            # normalize in log space for numerical stability, then
            # exponentiate to recover the attention distribution
            log_scores = F.log_softmax(scores, dim=-1)
            scores = log_scores.exp()

            c_align_vectors = scores

            p_scores = Params(
                alpha=scores,
                log_alpha=log_scores,
                dist_type=self.p_dist_type,
            )

        # each context vector c_t is the weighted average over all the
        # source hidden states: (N x T x S) bmm (N x S x H) -> N x T x H
        context_c = torch.bmm(c_align_vectors, memory_bank)
        if self.mode != 'wsram':
            concat_c = torch.cat([input, context_c], -1)
            # N x T x H
            h_c = self.tanh(self.linear_out(concat_c))
        else:
            h_c = None

        # sample or enumerate
        # y_align_vectors: K x N x T x S
        q_sample, p_sample, sample_log_probs = None, None, None
        sample_log_probs_q, sample_log_probs_p, sample_p_div_q_log = None, None, None
        if self.mode == "sample":
            if q_scores is None or self.use_prior:
                p_sample, sample_log_probs = self.sample_attn(
                    p_scores,
                    n_samples=self.n_samples,
                    lengths=memory_lengths,
                    mask=mask if memory_lengths is not None else None)
                y_align_vectors = p_sample
            else:
                q_sample, sample_log_probs = self.sample_attn(
                    q_scores,
                    n_samples=self.n_samples,
                    lengths=memory_lengths,
                    mask=mask if memory_lengths is not None else None)
                y_align_vectors = q_sample
        elif self.mode == "gumbel":
            if q_scores is None or self.use_prior:
                p_sample, _ = self.sample_attn_gumbel(
                    p_scores,
                    self.temperature,
                    n_samples=self.n_samples,
                    lengths=memory_lengths,
                    mask=mask if memory_lengths is not None else None)
                y_align_vectors = p_sample
            else:
                q_sample, _ = self.sample_attn_gumbel(
                    q_scores,
                    self.temperature,
                    n_samples=self.n_samples,
                    lengths=memory_lengths,
                    mask=mask if memory_lengths is not None else None)
                y_align_vectors = q_sample
        elif self.mode == "enum" or self.mode == "exact":
            y_align_vectors = None
        elif self.mode == "wsram":
            assert q_scores is not None
            q_sample, sample_log_probs_q, sample_log_probs_p, sample_p_div_q_log = self.sample_attn_wsram(
                q_scores,
                p_scores,
                n_samples=self.n_samples,
                lengths=memory_lengths,
                mask=mask if memory_lengths is not None else None)
            y_align_vectors = q_sample

        # context_y: K x N x T x H
        if y_align_vectors is not None:
            context_y = torch.bmm(
                y_align_vectors.view(-1, targetL, sourceL),
                memory_bank.unsqueeze(0).repeat(self.n_samples, 1, 1, 1).view(
                    -1, sourceL, dim)).view(self.n_samples, batch, targetL,
                                            dim)
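            # the first view flattens K into the batch dim so bmm sees
            # (K*N x T x S) @ (K*N x S x H); the final view restores
            # the K x N x T x H layout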
        else:
            # For enumerate, K = S.
            # memory_bank: N x S x H
            context_y = (
                memory_bank.unsqueeze(0).repeat(targetL, 1, 1, 1)  # T, N, S, H
                .permute(2, 1, 0, 3))  # S, N, T, H
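            # with K = S, "sample" s is the hard one-hot alignment on
            # source position s: context_y[s, n, t] == memory_bank[n, s],
            # so enumeration covers every categorical alignment exactly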
        input = input.unsqueeze(0).repeat(context_y.size(0), 1, 1, 1)
        concat_y = torch.cat([input, context_y], -1)
        # K x N x T x H
        h_y = self.tanh(self.linear_out(concat_y))

        if one_step:
            if h_c is not None:
                # N x H
                h_c = h_c.squeeze(1)
            # N x S
            c_align_vectors = c_align_vectors.squeeze(1)
            context_c = context_c.squeeze(1)

            # K x N x H
            h_y = h_y.squeeze(2)
            # K x N x S
            #y_align_vectors = y_align_vectors.squeeze(2)

            q_scores = Params(
                alpha=q_scores.alpha.squeeze(1)
                if q_scores.alpha is not None else None,
                dist_type=q_scores.dist_type,
                samples=q_sample.squeeze(2) if q_sample is not None else None,
                sample_log_probs=sample_log_probs.squeeze(2)
                if sample_log_probs is not None else None,
                sample_log_probs_q=sample_log_probs_q.squeeze(2)
                if sample_log_probs_q is not None else None,
                sample_log_probs_p=sample_log_probs_p.squeeze(2)
                if sample_log_probs_p is not None else None,
                sample_p_div_q_log=sample_p_div_q_log.squeeze(2)
                if sample_p_div_q_log is not None else None,
            ) if q_scores is not None else None
            p_scores = Params(
                alpha=p_scores.alpha.squeeze(1),
                log_alpha=log_scores.squeeze(1),
                dist_type=p_scores.dist_type,
                samples=p_sample.squeeze(2) if p_sample is not None else None,
            )

            if h_c is not None:
                # Check output sizes
                batch_, dim_ = h_c.size()
                aeq(batch, batch_)
                batch_, sourceL_ = c_align_vectors.size()
                aeq(batch, batch_)
                aeq(sourceL, sourceL_)
        else:
            # Multi-step decoding is not exercised; only one-step
            # input feeding is supported, so this branch is dead code.
            assert False
            # T x N x H
            h_c = h_c.transpose(0, 1).contiguous()
            # T x N x S
            c_align_vectors = c_align_vectors.transpose(0, 1).contiguous()

            # T x K x N x H
            h_y = h_y.permute(2, 0, 1, 3).contiguous()
            # T x K x N x S
            #y_align_vectors = y_align_vectors.permute(2, 0, 1, 3).contiguous()

            q_scores = Params(
                alpha=q_scores.alpha.transpose(0, 1).contiguous(),
                dist_type=q_scores.dist_type,
                samples=q_sample.permute(2, 0, 1, 3).contiguous(),
            )
            p_scores = Params(
                alpha=p_scores.alpha.transpose(0, 1).contiguous(),
                log_alpha=log_scores.transpose(0, 1).contiguous(),
                dist_type=p_scores.dist_type,
                samples=p_sample.permute(2, 0, 1, 3).contiguous(),
            )

            # Check output sizes
            targetL_, batch_, dim_ = h_c.size()
            aeq(targetL, targetL_)
            aeq(batch, batch_)
            aeq(dim, dim_)
            targetL_, batch_, sourceL_ = c_align_vectors.size()
            aeq(targetL, targetL_)
            aeq(batch, batch_)
            aeq(sourceL, sourceL_)

        # Package the prior (p) and variational (q) attention
        # parameters for the downstream loss computation.
        dist_info = DistInfo(
            q=q_scores,
            p=p_scores,
        )

        # h_y: attentional hidden states from sampled alignments
        #   either K x N x H, or T x K x N x H
        # h_c: baseline attentional states for input feeding
        #   either N x H, or T x N x H
        # context_c: convex combination of memory_bank
        #   either N x H, or T x N x H
        # c_align_vectors: convex coefficients / Boltzmann distribution
        #   either N x S, or T x N x S
        # dist_info: prior (p) and variational (q) attention parameters
        return h_y, h_c, context_c, c_align_vectors, dist_info
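
The categorical branch above is mostly shape bookkeeping. Below is a minimal, self-contained sketch of that step on dummy tensors, assuming a plain dot product in place of self.score and a local re-implementation of sequence_mask; all names and sizes are illustrative, not part of the original module.

import torch
import torch.nn.functional as F

def sequence_mask(lengths, max_len):
    # [batch] -> [batch x max_len] boolean mask, True inside each sequence
    steps = torch.arange(max_len, device=lengths.device)
    return steps.unsqueeze(0) < lengths.unsqueeze(1)

batch, tgt_len, src_len, dim, k = 2, 3, 5, 4, 2
query = torch.randn(batch, tgt_len, dim)        # stands in for `input`
memory_bank = torch.randn(batch, src_len, dim)
memory_lengths = torch.tensor([5, 3])

# dot-product scores, N x T x S (stands in for self.score)
scores = torch.bmm(query, memory_bank.transpose(1, 2))

# mask out padded source positions
mask = sequence_mask(memory_lengths, src_len).unsqueeze(1)  # N x 1 x S
scores = scores.masked_fill(~mask, float("-inf"))

# hard top-k sparsification, as in the k > 0 branch
topk, idx = scores.topk(k, dim=-1)
sparse = torch.full_like(scores, float("-inf")).scatter_(2, idx, topk)

# normalize in log space, then take the weighted average of source states
log_alpha = F.log_softmax(sparse, dim=-1)
alpha = log_alpha.exp()                  # N x T x S
context = torch.bmm(alpha, memory_bank)  # N x T x H

print(alpha.sum(-1))    # every row sums to 1
print(context.shape)    # torch.Size([2, 3, 4])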
Example #2
    def _run_forward_pass(self,
                          tgt,
                          memory_bank,
                          state,
                          memory_lengths=None,
                          q_scores=None,
                          tgt_emb=None):
        """
        See StdRNNDecoder._run_forward_pass() for description
        of arguments and return values.
        """
        # Additional args check.
        input_feed = state.input_feed.squeeze(0)
        input_feed_batch, _ = input_feed.size()
        tgt_len, tgt_batch, _ = tgt.size()
        aeq(tgt_batch, input_feed_batch)
        # END Additional args check.

        # Initialize local and return variables.
        decoder_outputs = []
        decoder_outputs_baseline = []
        dist_infos = []
        attns = {"std": []}
        if q_scores is not None:
            attns["q"] = []
        if self._copy:
            attns["copy"] = []
        if self._coverage:
            attns["coverage"] = []

        emb = self.dropout(
            self.embeddings(tgt)) if tgt_emb is None else tgt_emb
        assert emb.dim() == 3  # len x batch x embedding_dim

        tgt_len, batch_size = emb.size(0), emb.size(1)
        src_len = memory_bank.size(0)

        hidden = state.hidden
        coverage = state.coverage.squeeze(0) \
            if state.coverage is not None else None

        # Input feed concatenates hidden state with
        # input at every time step.
        for i, emb_t in enumerate(emb.split(1)):
            emb_t = emb_t.squeeze(0)
            decoder_input = torch.cat([emb_t, input_feed], -1)

            rnn_output, hidden = self.rnn(decoder_input.unsqueeze(0), hidden)
            rnn_output = rnn_output.squeeze(0)
            if q_scores is not None:
                # take the i-th time step of the inference network's parameters
                q_scores_i = Params(
                    alpha=q_scores.alpha[i],
                    log_alpha=q_scores.log_alpha[i],
                    dist_type=q_scores.dist_type,
                )
            else:
                q_scores_i = None
            decoder_output_y, decoder_output_c, context_c, attn_c, dist_info = self.attn(
                rnn_output,
                memory_bank.transpose(0, 1),
                memory_lengths=memory_lengths,
                q_scores=q_scores_i)

            dist_infos += [dist_info]
            if self.context_gate is not None and decoder_output_c is not None:
                # TODO: context gate should be employed
                # instead of second RNN transform.
                decoder_output_c = self.context_gate(decoder_input, rnn_output,
                                                     decoder_output_c)
            if decoder_output_c is not None:
                decoder_output_c = self.dropout(decoder_output_c)
            input_feed = context_c

            # decoder_output_y : K x N x H
            decoder_output_y = self.dropout(decoder_output_y)

            decoder_outputs += [decoder_output_y]
            if decoder_output_c is not None:
                decoder_outputs_baseline += [decoder_output_c]
            attns["std"] += [attn_c]
            if q_scores is not None:
                attns["q"] += [q_scores.alpha[i]]

            # Update the coverage attention.
            if self._coverage:
                coverage = coverage + attn_c \
                    if coverage is not None else attn_c
                attns["coverage"] += [coverage]

            # Run the forward pass of the copy attention layer on the
            # baseline decoder output.
            if self._copy and not self._reuse_copy_attn:
                _, copy_attn = self.copy_attn(decoder_output_c,
                                              memory_bank.transpose(0, 1))
                attns["copy"] += [copy_attn]
            elif self._copy:
                attns["copy"] = attns["std"]

        q_info = Params(
            alpha=q_scores.alpha,
            dist_type=q_scores.dist_type,
            samples=torch.stack([d.q.samples for d in dist_infos], dim=0)
            if dist_infos[0].q.samples is not None else None,
            log_alpha=q_scores.log_alpha,
            sample_log_probs=torch.stack(
                [d.q.sample_log_probs for d in dist_infos], dim=0)
            if dist_infos[0].q.sample_log_probs is not None else None,
            sample_log_probs_q=torch.stack(
                [d.q.sample_log_probs_q for d in dist_infos], dim=0)
            if dist_infos[0].q.sample_log_probs_q is not None else None,
            sample_log_probs_p=torch.stack(
                [d.q.sample_log_probs_p for d in dist_infos], dim=0)
            if dist_infos[0].q.sample_log_probs_p is not None else None,
            sample_p_div_q_log=torch.stack(
                [d.q.sample_p_div_q_log for d in dist_infos], dim=0)
            if dist_infos[0].q.sample_p_div_q_log is not None else None,
        ) if q_scores is not None else None
        p_info = Params(
            alpha=torch.stack([d.p.alpha for d in dist_infos], dim=0),
            dist_type=dist_infos[0].p.dist_type,
            log_alpha=torch.stack([d.p.log_alpha for d in dist_infos], dim=0)
            if dist_infos[0].p.log_alpha is not None else None,
            samples=torch.stack([d.p.samples for d in dist_infos], dim=0)
            if dist_infos[0].p.samples is not None else None,
        )
        dist_info = DistInfo(
            q=q_info,
            p=p_info,
        )

        return hidden, decoder_outputs, input_feed, attns, dist_info, decoder_outputs_baseline
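
The loop above follows the standard input-feeding pattern: the previous attentional vector is concatenated with the current target embedding before each recurrent step. Below is a minimal sketch of just that pattern, assuming an nn.LSTMCell and dummy sizes, with the attention call elided (the hidden state stands in for the attentional vector); it is an illustration, not the decoder's actual implementation.

import torch
import torch.nn as nn

# dummy sizes; the RNN input is emb_dim + hidden_dim, as in the decoder
batch, tgt_len, emb_dim, hidden_dim = 2, 4, 8, 16
rnn = nn.LSTMCell(emb_dim + hidden_dim, hidden_dim)

emb = torch.randn(tgt_len, batch, emb_dim)   # target embeddings, T x N x E
input_feed = torch.zeros(batch, hidden_dim)  # previous attentional vector
h = torch.zeros(batch, hidden_dim)
c = torch.zeros(batch, hidden_dim)

outputs = []
for emb_t in emb.split(1):
    emb_t = emb_t.squeeze(0)
    # input feeding: concatenate the previous attentional vector with
    # the current embedding before the recurrent step
    decoder_input = torch.cat([emb_t, input_feed], -1)
    h, c = rnn(decoder_input, (h, c))
    # a real decoder computes attention here; the hidden state stands in
    input_feed = h
    outputs.append(h)

outputs = torch.stack(outputs, dim=0)  # T x N x H, like decoder_outputs
print(outputs.shape)                   # torch.Size([4, 2, 16])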