Example 1
    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
        bbsz = prev_output_tokens.size(0)
        vocab = len(self.dictionary)
        src_len = encoder_out.size(1)
        tgt_len = prev_output_tokens.size(1)

        # determine number of steps
        if incremental_state is not None:
            # cache step number
            step = utils.get_incremental_state(self, incremental_state, 'step')
            if step is None:
                step = 0
            utils.set_incremental_state(self, incremental_state, 'step', step + 1)
            steps = [step]
        else:
            steps = list(range(tgt_len))

        # define output in terms of raw probs
        if hasattr(self.args, 'probs'):
            assert self.args.probs.dim() == 3, \
                'expected probs to have size bsz*steps*vocab'
            probs = self.args.probs.index_select(1, torch.LongTensor(steps))
        else:
            probs = torch.FloatTensor(bbsz, len(steps), vocab).zero_()
            for i, step in enumerate(steps):
                # args.beam_probs gives the probability for every vocab element,
                # starting with eos, then unknown, and then the rest of the vocab
                if step < len(self.args.beam_probs):
                    probs[:, i, self.dictionary.eos():] = self.args.beam_probs[step]
                else:
                    probs[:, i, self.dictionary.eos()] = 1.0

        # random attention
        attn = torch.rand(bbsz, tgt_len, src_len)

        return probs, attn
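Note: all of the examples in this listing rely on fairseq-style incremental-state helpers. As a rough, self-contained sketch (not the actual fairseq implementation), the helpers behave like a per-module key/value cache stored inside the incremental_state dict:

def _full_key(module, key):
    # namespace the key per module instance so two decoders don't collide
    return f"{type(module).__name__}.{id(module)}.{key}"

def get_incremental_state(module, incremental_state, key):
    if incremental_state is None:
        return None
    return incremental_state.get(_full_key(module, key))

def set_incremental_state(module, incremental_state, key, value):
    if incremental_state is not None:
        incremental_state[_full_key(module, key)] = value

During non-incremental forward passes incremental_state is None, so both helpers are effectively no-ops.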
Example 2
    def forward(
        self,
        input_tokens,
        encoder_out,
        incremental_state=None,
        possible_translation_tokens=None,
    ):
        if incremental_state is not None:
            input_tokens = input_tokens[:, -1:]
        bsz, seqlen = input_tokens.size()

        # get outputs from encoder
        (encoder_outs, final_hidden, final_cell, src_lengths,
         src_tokens) = encoder_out

        # embed tokens
        x = self.embed_tokens(input_tokens)
        x = F.dropout(x, p=self.dropout_in, training=self.training)
        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # initialize previous states (or get from cache during incremental generation)
        cached_state = utils.get_incremental_state(self, incremental_state,
                                                   "cached_state")
        if cached_state is not None:
            prev_hiddens, prev_cells, input_feed = cached_state
        else:
            # first time step, initialize previous states
            prev_hiddens, prev_cells = self._init_prev_states(encoder_out)
            input_feed = self.initial_attn_context.expand(
                bsz, self.encoder_hidden_dim)

        attn_scores_per_step = []
        outs = []
        for j in range(seqlen):
            # input feeding: concatenate context vector from previous time step
            if self.attention is not None:
                step_input = torch.cat((x[j, :, :], input_feed), dim=1)
            else:
                step_input = x[j, :, :]
            previous_layer_input = step_input
            for i, rnn in enumerate(self.layers):
                # recurrent cell
                hidden, cell = rnn(step_input,
                                   (prev_hiddens[i], prev_cells[i]))

                # hidden state becomes the input to the next layer
                layer_output = F.dropout(hidden,
                                         p=self.dropout_out,
                                         training=self.training)

                if self.residual_level is not None and i >= self.residual_level:
                    # TODO add an assert related to sizes here
                    step_input = layer_output + previous_layer_input
                else:
                    step_input = layer_output
                previous_layer_input = step_input

                # save state for next time step
                prev_hiddens[i] = hidden
                prev_cells[i] = cell

            if self.attention is not None:
                out, step_attn_scores = self.attention(
                    hidden,
                    encoder_outs,
                    src_lengths,
                )
                input_feed = out
                combined_output_and_context = torch.cat((hidden, out), dim=1)
            else:
                combined_output_and_context = hidden
                step_attn_scores = Variable(
                    torch.ones(src_lengths.shape[0],
                               src_lengths.max()).type_as(encoder_outs),
                    requires_grad=False,
                ).t()
            attn_scores_per_step.append(step_attn_scores.unsqueeze(1))
            attn_scores = torch.cat(attn_scores_per_step, dim=1)
            # srclen x tgtlen x bsz -> bsz x tgtlen x srclen
            attn_scores = attn_scores.transpose(0, 2)

            # save final output
            outs.append(combined_output_and_context)

        # cache previous states (no-op except during incremental generation)
        utils.set_incremental_state(
            self,
            incremental_state,
            "cached_state",
            (prev_hiddens, prev_cells, input_feed),
        )

        # collect outputs across time steps
        x = torch.cat(outs, dim=0).view(
            seqlen,
            bsz,
            self.combined_output_and_context_dim,
        )

        # T x B x C -> B x T x C
        x = x.transpose(1, 0)

        # bottleneck layer
        if hasattr(self, "additional_fc"):
            x = self.additional_fc(x)
            x = F.dropout(x, p=self.dropout_out, training=self.training)

        output_projection_w = self.output_projection_w
        output_projection_b = self.output_projection_b
        decoder_input_tokens = input_tokens if self.training else None

        if self.vocab_reduction_module and possible_translation_tokens is None:
            possible_translation_tokens = self.vocab_reduction_module(
                src_tokens, decoder_input_tokens=decoder_input_tokens)

        if possible_translation_tokens is not None:
            output_projection_w = output_projection_w.index_select(
                dim=0, index=possible_translation_tokens)
            output_projection_b = output_projection_b.index_select(
                dim=0, index=possible_translation_tokens)

        # avoiding transpose of projection weights during ONNX tracing
        batch_time_hidden = torch.onnx.operators.shape_as_tensor(x)
        x_flat_shape = torch.cat(
            (torch.LongTensor([-1]), batch_time_hidden[2].view(1)))
        x_flat = torch.onnx.operators.reshape_from_tensor_shape(
            x, x_flat_shape)

        projection_flat = torch.matmul(output_projection_w, x_flat.t()).t()
        logits_shape = torch.cat(
            (batch_time_hidden[:2], torch.LongTensor([-1])))
        logits = (torch.onnx.operators.reshape_from_tensor_shape(
            projection_flat, logits_shape) + output_projection_b)

        return logits, attn_scores, possible_translation_tokens
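Aside: ignoring the ONNX-specific reshaping above, vocabulary reduction amounts to selecting a subset of rows from the output projection before computing logits. A minimal sketch, assuming candidate_tokens is a LongTensor of kept vocabulary indices:

import torch.nn.functional as F

def reduced_logits(x, weight, bias, candidate_tokens=None):
    # x: [batch, tgt_len, hidden]; weight: [vocab, hidden]; bias: [vocab]
    if candidate_tokens is not None:
        weight = weight.index_select(0, candidate_tokens)
        bias = bias.index_select(0, candidate_tokens)
    return F.linear(x, weight, bias)  # [batch, tgt_len, reduced_vocab (or vocab)]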
Example 3
    def forward(self, input_token, target_token, timestep, *inputs):
        """
        Decoder step inputs correspond one-to-one to encoder outputs.
        """
        log_probs_per_model = []
        state_outputs = []

        next_state_input = len(self.models)

        # underlying assumption is each model has same vocab_reduction_module
        vocab_reduction_module = self.models[0].decoder.vocab_reduction_module
        if vocab_reduction_module is not None:
            possible_translation_tokens = inputs[len(self.models)]
            next_state_input += 1
        else:
            possible_translation_tokens = None

        for i, model in enumerate(self.models):
            encoder_output = inputs[i]
            prev_hiddens = []
            prev_cells = []

            for _ in range(len(model.decoder.layers)):
                prev_hiddens.append(inputs[next_state_input])
                prev_cells.append(inputs[next_state_input + 1])
                next_state_input += 2
            prev_input_feed = inputs[next_state_input].view(1, -1)
            next_state_input += 1

            # no batching, we only care about "max" length
            src_length_int = int(encoder_output.size()[0])
            src_length = torch.LongTensor(np.array([src_length_int]))

            # notional, not actually used for decoder computation
            src_tokens = torch.LongTensor(np.array([[0] * src_length_int]))
            src_embeddings = encoder_output.new_zeros(encoder_output.shape)

            encoder_out = (
                encoder_output,
                prev_hiddens,
                prev_cells,
                src_length,
                src_tokens,
                src_embeddings,
            )

            # store cached states, use evaluation mode
            model.decoder._is_incremental_eval = True
            model.eval()

            # placeholder
            incremental_state = {}

            # cache previous state inputs
            utils.set_incremental_state(
                model.decoder,
                incremental_state,
                "cached_state",
                (prev_hiddens, prev_cells, prev_input_feed),
            )

            decoder_output = model.decoder(
                input_token.view(1, 1),
                encoder_out,
                incremental_state=incremental_state,
                possible_translation_tokens=possible_translation_tokens,
            )
            logits, _, _ = decoder_output

            log_probs = F.log_softmax(logits, dim=2)

            log_probs_per_model.append(log_probs)

            (next_hiddens, next_cells,
             next_input_feed) = utils.get_incremental_state(
                 model.decoder, incremental_state, "cached_state")

            for h, c in zip(next_hiddens, next_cells):
                state_outputs.extend([h, c])
            state_outputs.append(next_input_feed)

        average_log_probs = torch.mean(torch.cat(log_probs_per_model, dim=0),
                                       dim=0,
                                       keepdim=True)

        if possible_translation_tokens is not None:
            reduced_indices = torch.zeros(self.vocab_size).long().fill_(
                self.unk_token)
            # ONNX-exportable arange (ATen op)
            possible_translation_token_range = torch._dim_arange(
                like=possible_translation_tokens, dim=0)
            reduced_indices[
                possible_translation_tokens] = possible_translation_token_range
            reduced_index = reduced_indices.index_select(dim=0,
                                                         index=target_token)
            score = average_log_probs.view(
                (-1, )).index_select(dim=0, index=reduced_index)
        else:
            score = average_log_probs.view(
                (-1, )).index_select(dim=0, index=target_token)

        word_reward = self.word_rewards.index_select(0, target_token)
        score += word_reward

        self.input_names = ["prev_token", "target_token", "timestep"]
        for i in range(len(self.models)):
            self.input_names.append(f"fixed_input_{i}")

        if possible_translation_tokens is not None:
            self.input_names.append("possible_translation_tokens")

        outputs = [score]
        self.output_names = ["score"]

        for i in range(len(self.models)):
            self.output_names.append(f"fixed_input_{i}")
            outputs.append(inputs[i])

        if possible_translation_tokens is not None:
            self.output_names.append("possible_translation_tokens")
            outputs.append(possible_translation_tokens)

        for i, state in enumerate(state_outputs):
            outputs.append(state)
            self.output_names.append(f"state_output_{i}")
            self.input_names.append(f"state_input_{i}")

        return tuple(outputs)
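Aside: the ensemble step above simply averages per-model log-probabilities for the single decoding step. A minimal sketch, assuming each list entry has shape [1, 1, vocab]:

import torch

def average_ensemble_log_probs(log_probs_per_model):
    stacked = torch.cat(log_probs_per_model, dim=0)  # [num_models, 1, vocab]
    return torch.mean(stacked, dim=0, keepdim=True)  # [1, 1, vocab]

Averaging log-probabilities corresponds to a (renormalized) geometric mean of the individual model distributions.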
Example 4
    def extract_features(
        self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused,
    ):
        """
        Similar to *forward* but only return features.

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - attention weights of shape `(batch, tgt_len, src_len)`
        """
        if self.attention is not None:
            assert encoder_out is not None
            encoder_padding_mask = encoder_out['encoder_padding_mask']
            encoder_out = encoder_out['encoder_out']
            # get outputs from encoder
            encoder_outs = encoder_out[0]
            srclen = encoder_outs.size(0)

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
        bsz, seqlen = prev_output_tokens.size()

        # embed tokens
        x = self.embed_tokens(prev_output_tokens)
        x = F.dropout(x, p=self.dropout_in, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # initialize previous states (or get from cache during incremental generation)
        cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state')
        if cached_state is not None:
            prev_hiddens, prev_cells, input_feed = cached_state
        else:
            num_layers = len(self.layers)
            prev_hiddens = [x.new_zeros(bsz, self.hidden_size) for i in range(num_layers)]
            prev_cells = [x.new_zeros(bsz, self.hidden_size) for i in range(num_layers)]
            input_feed = x.new_zeros(bsz, self.encoder_output_units) \
                if self.attention is not None else None

        if self.attention is not None:
            attn_scores = x.new_zeros(srclen, seqlen, bsz)
        outs = []
        for j in range(seqlen):
            # input feeding: concatenate context vector from previous time step
            input = torch.cat((x[j, :, :], input_feed), dim=1) \
                if input_feed is not None else x[j, :, :]

            for i, rnn in enumerate(self.layers):
                # recurrent cell
                hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i]))
                if self.residual and i > 0:  # residual connection starts from the 2nd layer
                    prev_layer_hidden = input[:, :hidden.size(1)]

                # compute and apply attention using the 1st layer's hidden state
                if self.attention is not None:
                    if i == 0:
                        context, attn_scores[:, j, :], _ = self.attention(
                            hidden, encoder_outs, encoder_padding_mask,
                        )

                    # hidden state concatenated with context vector becomes the
                    # input to the next layer
                    input = torch.cat((hidden, context), dim=1)
                else:
                    input = hidden
                input = F.dropout(input, p=self.dropout_out, training=self.training)
                if self.residual and i > 0:
                    if self.attention is not None:
                        hidden_sum = input[:, :hidden.size(1)] + prev_layer_hidden
                        input = torch.cat((hidden_sum, input[:, hidden.size(1):]), dim=1)
                    else:
                        input = input + prev_layer_hidden

                # save state for next time step
                prev_hiddens[i] = hidden
                prev_cells[i] = cell

            # input feeding
            input_feed = context if self.attention is not None else None

            # save final output
            outs.append(input)

        # cache previous states (no-op except during incremental generation)
        utils.set_incremental_state(
            self, incremental_state, 'cached_state',
            (prev_hiddens, prev_cells, input_feed),
        )

        # collect outputs across time steps
        x = torch.cat(outs, dim=0).view(seqlen, bsz, -1)
        assert x.size(2) == self.hidden_size + self.encoder_output_units

        # T x B x C -> B x T x C
        x = x.transpose(1, 0)

        if hasattr(self, 'additional_fc') and self.adaptive_softmax is None:
            x = self.additional_fc(x)
            x = F.dropout(x, p=self.dropout_out, training=self.training)

        # srclen x tgtlen x bsz -> bsz x tgtlen x srclen
        if not self.training and self.attention is not None and self.need_attn:
            attn_scores = attn_scores.transpose(0, 2)
        else:
            attn_scores = None

        return x, attn_scores
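Aside: the residual connections above (from the second layer on) add the previous layer's hidden slice to the new hidden state while leaving any concatenated attention context untouched. An illustrative sketch of that rule, not the original module:

import torch

def residual_combine(layer_in, hidden, context=None):
    # layer_in: [bsz, hidden(+ctx)] input to this layer; hidden: [bsz, hidden]
    prev_hidden = layer_in[:, :hidden.size(1)]
    if context is None:
        return hidden + prev_hidden
    return torch.cat((hidden + prev_hidden, context), dim=1)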
Example 5
    def forward(
            self,
            prev_output_tokens: torch.Tensor,  # Z_Tokens[Batch, SeqLength]
            encoder_out=None,
            incremental_state: Dict[str, Any] = None):
        assert incremental_state is not None, 'This model is for incremental decoding only'
        prev_output_tokens = prev_output_tokens[:,
                                                -1:]  # Z_Tokens[Batch, Len=1]
        bsz = prev_output_tokens.size(0)

        if prev_output_tokens.device != self.tree.word_idx.device:
            self.tree.to_cuda(device=prev_output_tokens.device)

        # Move the batched state to the next state according to the automaton
        batch_space_mask = prev_output_tokens.squeeze(-1).eq(
            self.subword_space_idx)  # B[Batch]
        cached_state = utils.get_incremental_state(self.lm_decoder,
                                                   incremental_state,
                                                   'cached_state')

        if cached_state is None:  # First step
            assert (prev_output_tokens == self.subword_eos_idx).all(), \
                'expecting the input to the first time step to be <eos>'
            w: torch.Tensor = prev_output_tokens.new_full(
                [bsz, 1], self.word_eos_idx)  # Z[Batch, Len=1]
            lm_probs: torch.Tensor = self.lm_decoder.get_normalized_probs(
                self.lm_decoder(w, incremental_state=incremental_state),
                log_probs=False,
                sample=None)  # R[Batch, 1, Vocab]
            cumsum_probs: torch.Tensor = lm_probs.cumsum(
                dim=-1)  # R[Batch, 1, Vocab]
            nodes: torch.Tensor = prev_output_tokens.new_full(
                [bsz], self.tree.root_id)  # Z_NodeId[Batch]
            all_children = self.tree.children[
                nodes, :]  # Z[Batch, PossibleChildren]

        else:  # Not the first step
            cumsum_probs: torch.Tensor = utils.get_incremental_state(
                self, incremental_state, 'cumsum_probs')  # R[Batch, 1, Vocab]
            nodes: torch.Tensor = utils.get_incremental_state(
                self, incremental_state, 'nodes')  # Z_NodeId[Batch]
            assert nodes.size(0) == bsz
            w: torch.Tensor = self.tree.word_idx[nodes].unsqueeze(
                1)  # Z[Batch, Len=1]
            w[w < 0] = self.word_unk_idx

            old_cached_state = _clone_cached_state(cached_state)
            # recompute cumsum_probs from inter-word transition probabilities
            # only for those whose prev_output_token is <space>
            lm_probs: torch.Tensor = self.lm_decoder.get_normalized_probs(
                self.lm_decoder(w, incremental_state=incremental_state),
                log_probs=False,
                sample=None)  # R[Batch, 1, Vocab]
            self.lm_decoder.masked_copy_incremental_state(
                incremental_state, old_cached_state,
                batch_space_mask)  # restore those not masked
            cumsum_probs[batch_space_mask] = lm_probs.cumsum(
                dim=-1)[batch_space_mask]

            prev_all_children = self.tree.children[
                nodes, :]  # Z[Batch, PossibleChildren]
            prev_possible_tokens = self.tree.prev_subword_idx[
                prev_all_children]  # Z[Batch, PossibleChildren]
            # intra-word transition: go to child; oov transition: go to "None" node
            mask = prev_possible_tokens.eq(
                prev_output_tokens.expand_as(prev_possible_tokens))
            nodes: torch.Tensor = (prev_all_children * mask.long()).sum(
                dim=1)  # Z[Batch]
            # inter-word transition: go back to root
            nodes[batch_space_mask] = self.tree.root_id  # Z[Batch]
            all_children = self.tree.children[
                nodes, :]  # Z[Batch, PossibleChildren]

        utils.set_incremental_state(self, incremental_state, 'cumsum_probs',
                                    cumsum_probs)
        utils.set_incremental_state(self, incremental_state, 'nodes', nodes)

        # Compute probabilities
        # initialize out_probs [Batch, 1, Vocab]
        if self.open_vocab:
            # set out_probs to oov_penalty * P(<unk>|h) (case 3 in Eqn. 15)
            out_probs = self.oov_penalty * (
                cumsum_probs[:, :, self.word_unk_idx] -
                cumsum_probs[:, :, self.word_unk_idx - 1]
            ).unsqueeze(-1).repeat(1, 1, self.subword_vocab_size)

            # set the probability of emitting <space> to 0 if prev_output_tokens
            # is <space> or <eos>, and that of emitting <eos> to 0 if
            # prev_output_tokens is not <space>
            batch_space_eos_mask = batch_space_mask | \
                prev_output_tokens.squeeze(-1).eq(self.subword_eos_idx)
            out_probs[batch_space_eos_mask, :,
                      self.subword_space_idx] = self.zero
            out_probs[~batch_space_mask, :, self.subword_eos_idx] = self.zero

            # set transition probability to 1 for those whose node is out of the
            # tree, i.e. node is None (case 4 in Eqn. 15)
            batch_node_none_mask = nodes.eq(self.tree.none_id)  # B[Batch]
            out_probs[batch_node_none_mask] = 1.
        else:
            # set out_probs to 0
            out_probs = cumsum_probs.new_full(
                [bsz, 1, self.subword_vocab_size], self.zero)

        # compute parent probabilities for those whose node is not None
        left_ranges = self.tree.word_set_idx[nodes, 0]  # Z[Batch]
        right_ranges = self.tree.word_set_idx[nodes, 1]  # Z[Batch]
        batch_node_not_root_mask = nodes.ne(self.tree.none_id) & nodes.ne(
            self.tree.root_id)  # B[Batch]
        sum_probs = torch.where(
            batch_node_not_root_mask,
            (cumsum_probs.squeeze(1).gather(-1, right_ranges.unsqueeze(-1)) -
             cumsum_probs.squeeze(1).gather(
                 -1, left_ranges.unsqueeze(-1))).squeeze(-1),
            cumsum_probs.new([1.0]))  # R[Batch]

        # compute transition probabilities to child nodes (case 2 in Eqn. 15)
        left_ranges_of_all_children = self.tree.word_set_idx[
            all_children, 0]  # Z[Batch, PossibleChildren]
        right_ranges_of_all_children = self.tree.word_set_idx[
            all_children, 1]  # Z[Batch, PossibleChildren]
        cumsum_probs_of_all_children = (
            cumsum_probs.squeeze(1).gather(-1, right_ranges_of_all_children) -
            cumsum_probs.squeeze(1).gather(-1, left_ranges_of_all_children)
        ).unsqueeze(1) / sum_probs.unsqueeze(-1).unsqueeze(
            -1)  # R[Batch, 1, PossibleChildren]
        cumsum_probs_of_all_children[sum_probs < self.zero, :, :] = self.zero
        next_possible_tokens = self.tree.prev_subword_idx[
            all_children]  # Z[Batch, PossibleChildren]
        out_probs.scatter_(-1, next_possible_tokens.unsqueeze(1),
                           cumsum_probs_of_all_children)
        # assume self.subword_pad_idx is the padding index in self.tree.prev_subword_idx
        out_probs[:, :, self.subword_pad_idx] = self.zero

        # apply word-level probabilities for <space> (case 1 in Eqn. 15)
        word_idx = self.tree.word_idx[nodes]  # Z[Batch]
        batch_node_word_end_mask = word_idx.ge(0)  # B[Batch]
        # get rid of -1's (word idx of root or non-terminal states). It doesn't
        # matter what the "dummy" index it would be replaced with (as it will
        # always be masked out by batch_node_word_end_mask), as long as it is > 0
        word_idx[word_idx < 0] = 1
        word_probs = torch.where(
            sum_probs < self.zero, cumsum_probs.new([self.zero]),
            (cumsum_probs.squeeze(1).gather(-1, word_idx.unsqueeze(-1)) -
             cumsum_probs.squeeze(1).gather(
                 -1,
                 word_idx.unsqueeze(-1) - 1)).squeeze(-1) /
            sum_probs)  # R[Batch]
        out_probs[batch_node_word_end_mask, 0, self.subword_space_idx] = \
            word_probs[batch_node_word_end_mask]

        # take log of probs and clip it from below to avoid log(0)
        out_logprobs = torch.log(
            torch.max(out_probs, out_probs.new([self.zero])))

        # assign log-probs of emitting word <eos> to that of emitting subword <eos>
        out_logprobs[batch_space_mask, :, self.subword_eos_idx] = \
            torch.log(lm_probs)[batch_space_mask, :, self.word_eos_idx]

        utils.set_incremental_state(self, incremental_state, 'out_logprobs',
                                    out_logprobs)

        # note that here we return log-probs rather than logits, and the second
        # element is None, which is usually a tensor of attention weights in
        # attention-based models
        return out_logprobs, None
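Aside: the core trick above is that, given a cumulative LM distribution, the total probability mass of a contiguous word range [left, right) is the difference of two cumulative values. A minimal sketch with illustrative shapes:

import torch

def range_probs(cumsum_probs, left, right):
    # cumsum_probs: [batch, vocab] cumulative word probabilities
    # left, right: [batch] LongTensors delimiting each node's word set
    upper = cumsum_probs.gather(-1, right.unsqueeze(-1))
    lower = cumsum_probs.gather(-1, left.unsqueeze(-1))
    return (upper - lower).squeeze(-1)  # [batch]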
Example 6
    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
        bsz, seqlen = prev_output_tokens.size()

        # get outputs from encoder
        encoder_outs, _, _ = encoder_out
        srclen = encoder_outs.size(0)

        # embed tokens
        x = self.embed_tokens(prev_output_tokens)
        x = F.dropout(x, p=self.dropout_in, training=self.training)
        embed_dim = x.size(2)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # initialize previous states (or get from cache during incremental generation)
        cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state')
        if cached_state is not None:
            prev_hiddens, prev_cells, input_feed = cached_state
        else:
            _, encoder_hiddens, encoder_cells = encoder_out
            num_layers = len(self.layers)
            prev_hiddens = [encoder_hiddens[i] for i in range(num_layers)]
            prev_cells = [encoder_cells[i] for i in range(num_layers)]
            input_feed = Variable(x.data.new(bsz, embed_dim).zero_())

        attn_scores = Variable(x.data.new(srclen, seqlen, bsz).zero_())
        outs = []
        for j in range(seqlen):
            # input feeding: concatenate context vector from previous time step
            input = torch.cat((x[j, :, :], input_feed), dim=1)

            for i, rnn in enumerate(self.layers):
                # recurrent cell
                hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i]))

                # hidden state becomes the input to the next layer
                input = F.dropout(hidden, p=self.dropout_out, training=self.training)

                # save state for next time step
                prev_hiddens[i] = hidden
                prev_cells[i] = cell

            # apply attention using the last layer's hidden state
            out, attn_scores[:, j, :] = self.attention(hidden, encoder_outs)
            out = F.dropout(out, p=self.dropout_out, training=self.training)

            # input feeding
            input_feed = out

            # save final output
            outs.append(out)

        # cache previous states (no-op except during incremental generation)
        utils.set_incremental_state(
            self, incremental_state, 'cached_state', (prev_hiddens, prev_cells, input_feed))

        # collect outputs across time steps
        x = torch.cat(outs, dim=0).view(seqlen, bsz, embed_dim)

        # T x B x C -> B x T x C
        x = x.transpose(1, 0)

        # srclen x tgtlen x bsz -> bsz x tgtlen x srclen
        attn_scores = attn_scores.transpose(0, 2)

        # project back to size of vocabulary
        if hasattr(self, 'additional_fc'):
            x = self.additional_fc(x)
            x = F.dropout(x, p=self.dropout_out, training=self.training)
        x = self.fc_out(x)

        return x, attn_scores
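Aside: self.attention(hidden, encoder_outs) above is assumed to return a context vector plus per-source attention scores. A rough dot-product sketch of that interface (not the actual AttentionLayer):

import torch.nn.functional as F

def dot_product_attention(hidden, encoder_outs):
    # hidden: [bsz, dim]; encoder_outs: [srclen, bsz, dim]
    scores = (encoder_outs * hidden.unsqueeze(0)).sum(dim=2)  # [srclen, bsz]
    attn = F.softmax(scores, dim=0)                           # over source positions
    context = (attn.unsqueeze(2) * encoder_outs).sum(dim=0)   # [bsz, dim]
    return context, attn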
Example 7
    def forward(self,
                prev_output_tokens,
                shapes=None,
                tgt_tok_bounds=None,
                sort_order=None,
                encoder_out=None,
                src_lengths=None,
                incremental_state=None,
                **kwargs):
        if incremental_state is not None:
            for i in range(len(self.decoders)):
                es = utils.get_incremental_state(self, incremental_state,
                                                 'decoder-' + str(i + 1))
                if es is None:
                    utils.set_incremental_state(self, incremental_state,
                                                'decoder-' + str(i + 1), {})

        char_flag = (
            not self.args.token_sequences) and self.args.char_sequences
        if len(self.decoders) > 1:
            g_shapes = shapes if not char_flag else None
            toks_prev_output_tokens = split_on_sep(prev_output_tokens[0],
                                                   self.sequence_separator,
                                                   shapes=g_shapes)
            g_shapes = shapes if char_flag else None
            char_prev_output_tokens = split_on_sep(prev_output_tokens[1],
                                                   self.sequence_separator,
                                                   shapes=g_shapes)

            assert len(toks_prev_output_tokens) == len(self.decoders)
            assert len(char_prev_output_tokens) == len(self.decoders)
        else:
            toks_prev_output_tokens = [prev_output_tokens[0]]
            char_prev_output_tokens = [prev_output_tokens[1]]

        outputs = []
        decoder_input = [encoder_out]
        if self.training:
            self.decoder_hidden_states = [[]
                                          for i in range(len(self.decoders))]
        for i in range(self.first_decoder, self.last_decoder):
            incremental_state_i = utils.get_incremental_state(
                self, incremental_state, 'decoder-' + str(i + 1))
            if not self.training and self.first_decoder > 0 and i == self.first_decoder:
                assert len(decoder_input) == 1
                for d_idx in range(0, i):
                    if self.args.model_type == 'lstm':
                        new_decoder_input = {
                            'encoder_out':
                            (torch.cat(self.decoder_hidden_states[d_idx],
                                       0), None, None),
                            'encoder_padding_mask':
                            None
                        }
                        decoder_input.append(new_decoder_input)
                    else:
                        new_decoder_input = EncoderOut(
                            encoder_out=torch.cat(
                                self.decoder_hidden_states[d_idx], 0),
                            encoder_padding_mask=None,
                            encoder_embedding=None,
                            encoder_states=None)
                        decoder_input.append(new_decoder_input)

            feats_only = False
            prev_output_tokens_i = [
                toks_prev_output_tokens[i], char_prev_output_tokens[i]
            ]
            src_lengths_i = [src_lengths[0][0], src_lengths[1][0]
                             ] if src_lengths is not None else [[], []]
            decoder_out = self.decoders[i](
                prev_output_tokens_i,
                tgt_tok_bounds[i],
                sort_order,
                encoder_out=decoder_input,
                features_only=feats_only,
                incremental_state=incremental_state_i,
                src_lengths=src_lengths_i if src_lengths is not None else None,
                return_all_hiddens=False,
            )
            outputs.append(decoder_out)

            hidden_state = decoder_out[1]['hidden'].transpose(0, 1)
            if not self.training:
                self.decoder_hidden_states[i].append(hidden_state)

            if self.args.model_type == 'transformer':
                new_decoder_input = EncoderOut(
                    encoder_out=hidden_state,
                    encoder_padding_mask=None,
                    encoder_embedding=None,
                    encoder_states=None,
                )
                decoder_input.append(new_decoder_input)
            else:
                new_decoder_input = {
                    'encoder_out': (hidden_state, None, None),
                    'encoder_padding_mask': None
                }
                decoder_input.append(new_decoder_input)

        if not self.training and len(outputs) != len(self.decoders):
            assert len(outputs) == self.last_decoder - self.first_decoder
            output_clone = outputs[-1]
            for o_idx in range(0, len(self.decoders) - len(outputs)):
                outputs = [output_clone] + outputs

        x = []
        attn_scores = []
        for o in outputs:
            x.append(o[0])
            attn_scores.append(o[1])

        return x, attn_scores
Example 8
    def forward(
        self,
        prev_output_tokens,
        encoder_out,
        incremental_state=None,
        possible_translation_tokens=None,
        timestep=None,
    ):
        (encoder_x, src_tokens,
         encoder_padding_mask) = self._unpack_encoder_out(encoder_out)

        bsz, seqlen = prev_output_tokens.size()
        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]

        x = self.embed_tokens(prev_output_tokens)
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        state_outputs = []
        if incremental_state is not None:
            prev_states = utils.get_incremental_state(self, incremental_state,
                                                      "cached_state")
            if prev_states is None:
                prev_states = self._init_prev_states(encoder_out)

            # final 2 states of list are projected key and value
            saved_state = {
                "prev_key": prev_states[-2],
                "prev_value": prev_states[-1]
            }
            self.attention._set_input_buffer(incremental_state, saved_state)

        if incremental_state is not None:
            # first num_layers pairs of states are (prev_hidden, prev_cell)
            # for each layer
            h_prev = prev_states[0]
            c_prev = prev_states[1]
        else:
            h_prev = self._init_hidden(encoder_out, bsz).type_as(x)
            c_prev = torch.zeros([1, bsz, self.lstm_units]).type_as(x)

        x = self._concat_latent_code(x, encoder_out)
        x, (h_next, c_next) = self.initial_rnn_layer(x, (h_prev, c_prev))
        if incremental_state is not None:
            state_outputs.extend([h_next, c_next])

        x = F.dropout(x, p=self.dropout, training=self.training)

        attention_in = x
        if self.proj_layer is not None:
            attention_in = self.proj_layer(x)

        attention_out, attention_weights = self.attention(
            query=attention_in,
            key=encoder_x,
            value=encoder_x,
            key_padding_mask=encoder_padding_mask,
            incremental_state=incremental_state,
            static_kv=True,
            need_weights=(not self.training),
        )

        for i, layer in enumerate(self.extra_rnn_layers):
            residual = x
            rnn_input = torch.cat([x, attention_out], dim=2)
            rnn_input = self._concat_latent_code(rnn_input, encoder_out)

            if incremental_state is not None:
                # first num_layers pairs of states are (prev_hidden, prev_cell)
                # for each layer
                h_prev = prev_states[2 * i + 2]
                c_prev = prev_states[2 * i + 3]
            else:
                h_prev = self._init_hidden(encoder_out, bsz).type_as(x)
                c_prev = torch.zeros([1, bsz, self.lstm_units]).type_as(x)

            x, (h_next, c_next) = layer(rnn_input, (h_prev, c_prev))
            if incremental_state is not None:
                state_outputs.extend([h_next, c_next])
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = x + residual

        x = torch.cat([x, attention_out], dim=2)
        x = self._concat_latent_code(x, encoder_out)
        x = self.bottleneck_layer(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if (self.vocab_reduction_module is not None
                and possible_translation_tokens is None):
            decoder_input_tokens = prev_output_tokens.contiguous()
            possible_translation_tokens = self.vocab_reduction_module(
                src_tokens, decoder_input_tokens=decoder_input_tokens)

        output_weights = self.embed_out
        if possible_translation_tokens is not None:
            output_weights = output_weights.index_select(
                dim=0, index=possible_translation_tokens)

        logits = F.linear(x, output_weights)

        if incremental_state is not None:
            # encoder projections can be reused at each incremental step
            state_outputs.extend([prev_states[-2], prev_states[-1]])
            utils.set_incremental_state(self, incremental_state,
                                        "cached_state", state_outputs)

        return logits, attention_weights, possible_translation_tokens
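Aside: the cached prev_key/prev_value entries above exist because, with static_kv=True, the encoder-side key/value projections never change across incremental steps. A hypothetical sketch of that caching idea (names are illustrative, not the actual attention API):

def get_or_project_kv(cache, encoder_x, k_proj, v_proj):
    # compute the encoder-side projections once and reuse them on later steps
    if "prev_key" not in cache:
        cache["prev_key"] = k_proj(encoder_x)
        cache["prev_value"] = v_proj(encoder_x)
    return cache["prev_key"], cache["prev_value"]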
Example 9
    def getCached(self, incremental_state, key):
        x = utils.get_incremental_state(self, incremental_state, key)
        return x
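A matching setter (hypothetical, mirroring the getter above) would simply forward to utils.set_incremental_state:

    def setCached(self, incremental_state, key, value):
        # hypothetical companion to getCached; stores `value` under this
        # module's slice of the incremental state
        utils.set_incremental_state(self, incremental_state, key, value)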
Example 10
    def forward(self,
                decoder_in_dict,
                encoder_out_dict,
                incremental_state=None,
                incr_doc_step=False,
                batch_idxs=None,
                new_incr_cached=None):

        prev_output_tokens = decoder_in_dict['prev_output_tokens']

        encoder_padding_mask = encoder_out_dict[
            'encoder_padding_mask']  # [b x w]
        if encoder_padding_mask is not None:
            encoder_padding_mask = encoder_padding_mask.transpose(0, 1)
        srcbsz, srcdoclen, srcdim = encoder_out_dict['encoder_out'][0].size(
        )  # [b x w x d]

        # summarise whole input for h0 decoder use, verbose but clearer
        src_summary_h0 = encoder_out_dict['encoder_out'][0].mean(1)  # [b x d]

        bsz, doclen, sentlen = prev_output_tokens.size(
        )  # these sizes are target ones

        start_doc = 0
        if incremental_state is not None:
            doclen = 1

        # get initial input embedding for document RNN decoder
        x = prev_output_tokens.data.new(bsz).fill_(self.sod_idx)
        x = self.decoder.embed_tokens(x)

        ## Decode sentence states ##

        # initialize previous states (or get from cache during incremental generation)
        cached_state_rnn = utils.get_incremental_state(self, incremental_state,
                                                       'cached_state_rnn')
        if incr_doc_step and cached_state_rnn is not None:
            # doing the first step of the ith (i>1) sentence in incremental generation
            prev_hiddens, prev_cells, input = cached_state_rnn
            outs = [input]

        elif incremental_state is not None \
                and new_incr_cached is not None:  # doing subsequent steps of a sentence in incremental generation
            bidxs, old_bsz, reorder_state = batch_idxs
            if reorder_state is not None:  # needed when some hypotheses have finished during generation
                # reduce decoding to a smaller number of hypotheses
                new_incr_cached = new_incr_cached.index_select(
                    0, reorder_state)
                outs = [new_incr_cached]
            else:
                outs = [new_incr_cached]
        else:
            # first state of the first sentence in incremental generation,
            # or first time here to generate the whole sentence in training/scoring
            # previous is h0 with encoder output summary
            outs = []
            encoder_hiddens_cells = src_summary_h0  # [b x d]
            prev_hiddens = []
            prev_cells = []
            for i in range(len(self.layers)):
                prev_hiddens.append(encoder_hiddens_cells)
                prev_cells.append(encoder_hiddens_cells)
            input = x

        # attn of document decoder over input aggregated units (e.g. encoded sequence of paragraphs)
        attn_scores = x.data.new(srcdoclen, doclen, bsz).zero_()

        if (incremental_state is not None and incr_doc_step) \
                or incremental_state is None:
            for j in range(start_doc, doclen):
                for i, rnn in enumerate(self.layers):
                    # recurrent cell
                    hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i]))

                    # hidden state becomes the input to the next layer
                    input = hidden

                    # save state for next time step
                    prev_hiddens[i] = hidden
                    prev_cells[i] = cell

                # apply attention using the last layer's hidden state (sentence vector)
                if self.wordAttention is not None:

                    # inputs to attention are of the form
                    # input: bsz x input_embed_dim
                    # source_hids: srclen x bsz x output_embed_dim
                    # attend either to the representation produced by the CNN encoder or to its combination with the input embeddings
                    if self.hidemb:
                        attn_h, attn_scores_out = self.wordAttention(hidden, \
                                                            encoder_out_dict['encoder_out'][1].transpose(0, 1),\
                                                            encoder_padding_mask)
                    else:
                        attn_h, attn_scores_out = self.wordAttention(hidden, \
                                                                     encoder_out_dict['encoder_out'][0].transpose(0, 1), \
                                                                     encoder_padding_mask)

                    out = attn_h  # [b x d]
                else:
                    out = hidden

                # input to next time step
                input = out
                new_incr_cached = out.clone()

                # save final output
                if incremental_state is not None:
                    outs.append(out)
                else:
                    outs.append(out.unsqueeze(1))

        ## Decode sentences ##

        # during training/validation, all sentence s_t decoding steps are done in parallel (else branch below)
        sent_states = None
        if incremental_state is not None:
            dec_outs = x.data.new(bsz, doclen, sentlen,
                                  len(self.decoder.dictionary)).zero_()
            # decode by sentence s_j
            for j in range(doclen):

                sp = self.embed_sent_positions(
                    decoder_in_dict['sentence_position'])

                dec_out, word_atte_scores = self.decoder(
                    prev_output_tokens[:, j, :],
                    outs[j],
                    sp,
                    encoder_out_dict,
                    incremental_state,
                    firstfeed=self.firstfeed,
                    normpos=self.normpos)
                # prev_output_tokens is [ b x s x w ], at each time step decode sentence j [b x w]
                # dec_out is [b x w x vocabulary]
                if j == 0:
                    dec_outs = dec_out
                else:
                    dec_outs = torch.cat((dec_outs, dec_out), 1)
                    # dec_outs is [bxs x w x  vocabulary], dim=0
                    # dec_outs is [b x s*w x  vocabulary], dim=1

        else:
            # decode everything in parallel
            sent_states = torch.cat(outs, dim=1).view(bsz * doclen, -1)
            ys = prev_output_tokens.view(bsz * doclen, -1)
            sp = make_sent_positions(prev_output_tokens,
                                     self.padding_idx).view(bsz * doclen)
            sp = self.embed_sent_positions(sp)

            # Replicate encoder_out_dict for the new nb of batches to do all in parallel

            ebsz, esrclen, edim = encoder_out_dict['encoder_out'][0].size()
            new_enc_out_dict = {}
            # repeat the input for each target sentence
            new_enc_out_dict['encoder_out'] = (
                encoder_out_dict['encoder_out'][0].view(
                    ebsz, 1, esrclen, edim).expand(ebsz, doclen, esrclen,
                                                   edim).contiguous().view(
                                                       ebsz * doclen, esrclen,
                                                       edim),
                encoder_out_dict['encoder_out'][1].view(
                    ebsz, 1, esrclen, edim).expand(ebsz, doclen, esrclen,
                                                   edim).contiguous().view(
                                                       ebsz * doclen, esrclen,
                                                       edim))

            new_enc_out_dict['encoder_padding_mask'] = None
            if encoder_out_dict['encoder_padding_mask'] is not None:
                new_enc_out_dict['encoder_padding_mask'] = encoder_out_dict['encoder_padding_mask']\
                                                    .view(ebsz, 1, esrclen).expand(ebsz, doclen, esrclen)\
                                                    .contiguous().view(ebsz*doclen, -1)

            # decode all target sentences of all documents in parallel
            dec_out, word_atte_scores = self.decoder(ys,
                                                     sent_states,
                                                     sp,
                                                     new_enc_out_dict,
                                                     firstfeed=self.firstfeed,
                                                     normpos=self.normpos)
            dec_outs = dec_out.view(bsz, doclen * sentlen,
                                    len(self.decoder.dictionary))

        if incremental_state is not None and incr_doc_step:
            # only if we moved to the next document sentence
            # cache previous states (no-op except during incremental generation)
            utils.set_incremental_state(self, incremental_state,
                                        'cached_state_rnn',
                                        (prev_hiddens, prev_cells, out))

        # srclen x tgtlen x bsz -> bsz x tgtlen x srclen
        attn_scores = attn_scores.transpose(0, 2)

        tkeys = None
        if sent_states is not None:
            tkeys = self.state2key(sent_states)
        else:
            tkeys = self.state2key(outs[j])

        return dec_outs, (attn_scores,
                          word_atte_scores), new_incr_cached, tkeys
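Aside: the parallel branch above replicates the encoder output once per target sentence, so every sentence-level decoding step attends over the same source states. A minimal sketch of that tiling:

def tile_encoder_out(enc, doclen):
    # enc: [bsz, srclen, dim] -> [bsz * doclen, srclen, dim]
    bsz, srclen, dim = enc.size()
    return (enc.view(bsz, 1, srclen, dim)
               .expand(bsz, doclen, srclen, dim)
               .contiguous()
               .view(bsz * doclen, srclen, dim))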
Example 11
    def extract_features(self,
                         prev_output_tokens,
                         encoder_out,
                         incremental_state=None):
        """
        Similar to *forward* but only return features.
        """
        encoder_sentemb = encoder_out['sentemb']
        encoder_padding_mask = encoder_out['encoder_padding_mask']
        lang = encoder_out['decoder_lang']
        encoder_out = encoder_out['encoder_out']

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
        bsz, seqlen = prev_output_tokens.size()

        # get outputs from encoder
        encoder_outs, encoder_hiddens, encoder_cells = encoder_out[:3]
        srclen = encoder_outs.size(0)

        # embed tokens
        x = self.embed_tokens(prev_output_tokens)
        x = F.dropout(x, p=self.dropout_in, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # embed language
        lang_tensor = torch.LongTensor([self.lang_dictionary[lang]] * bsz).to(
            device=prev_output_tokens.device)
        l = self.embed_langs(lang_tensor)

        # initialize previous states (or get from cache during incremental generation)
        cached_state = utils.get_incremental_state(self, incremental_state,
                                                   'cached_state')
        if cached_state is not None:
            prev_hiddens, prev_cells, input_feed = cached_state
        else:
            num_layers = len(self.layers)
            prev_hiddens = [encoder_sentemb for i in range(num_layers)]
            prev_cells = [encoder_sentemb for i in range(num_layers)]
            prev_hiddens = [self.encoder_hidden_proj(x) for x in prev_hiddens]
            prev_cells = [self.encoder_cell_proj(x) for x in prev_cells]
            input_feed = x.new_zeros(bsz, self.hidden_size)

        outs = []
        for j in range(seqlen):
            # input feeding: concatenate context vector from previous time step
            input = torch.cat((x[j, :, :], encoder_sentemb, input_feed, l),
                              dim=1)

            for i, rnn in enumerate(self.layers):
                # recurrent cell
                hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i]))

                # hidden state becomes the input to the next layer
                input = F.dropout(hidden,
                                  p=self.dropout_out,
                                  training=self.training)

                # save state for next time step
                prev_hiddens[i] = hidden
                prev_cells[i] = cell

            out = hidden
            out = F.dropout(out, p=self.dropout_out, training=self.training)

            # input feeding
            input_feed = out

            # save final output
            outs.append(out)

        # cache previous states (no-op except during incremental generation)
        utils.set_incremental_state(
            self,
            incremental_state,
            'cached_state',
            (prev_hiddens, prev_cells, input_feed),
        )

        # collect outputs across time steps
        x = torch.cat(outs, dim=0).view(seqlen, bsz, self.hidden_size)

        # T x B x C -> B x T x C
        x = x.transpose(1, 0)

        if hasattr(self, 'additional_fc') and self.adaptive_softmax is None:
            x = self.additional_fc(x)
            x = F.dropout(x, p=self.dropout_out, training=self.training)

        return x, None
Example 12
    def forward(self,
                prev_output_tokens,
                encoder_out=None,
                incremental_state=None,
                **kwargs):
        encoder_padding_mask = encoder_out["encoder_padding_mask"]
        encoder_outs = encoder_out["encoder_out"]

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
        bsz, seqlen = prev_output_tokens.size()

        srclen = encoder_outs.size(0)

        # embed tokens
        embeddings = self.embed_tokens(prev_output_tokens)
        x = embeddings
        if self.dropout is not None:
            x = self.dropout(x)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # initialize previous states (or get from cache during incremental
        # generation)
        cached_state = utils.get_incremental_state(self, incremental_state,
                                                   "cached_state")
        if cached_state is not None:
            prev_hiddens, prev_cells = cached_state
        else:
            prev_hiddens = [encoder_out["encoder_out"].mean(dim=0)
                            ] * self.num_layers
            prev_cells = [x.new_zeros(bsz, self.hidden_size)] * self.num_layers

        attn_scores = x.new_zeros(bsz, srclen)
        attention_outs = []
        outs = []
        for j in range(seqlen):
            input = x[j, :, :]
            attention_out = None
            for i, layer in enumerate(self.layers):
                # the previous state is one layer below except for the bottom
                # layer where the previous state is the state emitted by the
                # top layer
                hidden, cell = layer(
                    input,
                    (
                        prev_hiddens[(i - 1) % self.num_layers],
                        prev_cells[(i - 1) % self.num_layers],
                    ),
                )
                if self.dropout is not None:
                    hidden = self.dropout(hidden)
                prev_hiddens[i] = hidden
                prev_cells[i] = cell
                if attention_out is None:
                    attention_out, attn_scores = self.attention(
                        hidden, encoder_outs, encoder_padding_mask)
                    if self.dropout is not None:
                        attention_out = self.dropout(attention_out)
                    attention_outs.append(attention_out)
                input = attention_out

            # collect the output of the top layer
            outs.append(hidden)

        # cache previous states (no-op except during incremental generation)
        utils.set_incremental_state(self, incremental_state, "cached_state",
                                    (prev_hiddens, prev_cells))

        # collect outputs across time steps
        x = torch.cat(outs, dim=0).view(seqlen, bsz, self.hidden_size)
        attention_outs_concat = torch.cat(attention_outs,
                                          dim=0).view(seqlen, bsz,
                                                      self.context_dim)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)
        attention_outs_concat = attention_outs_concat.transpose(0, 1)

        # concat LSTM output, attention output and embedding
        # before output projection
        x = torch.cat((x, attention_outs_concat, embeddings), dim=2)
        x = self.deep_output_layer(x)
        x = torch.tanh(x)
        if self.dropout is not None:
            x = self.dropout(x)
        # project back to size of vocabulary
        x = self.output_projection(x)

        # to return the full attn_scores tensor, we need to fix the decoder
        # to account for subsampling input frames
        # return x, attn_scores
        return x, None
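Aside: the "deep output" projection above concatenates the LSTM output, the attention context and the token embedding before a tanh bottleneck and the vocabulary projection. A small illustrative module (dimensions and names are assumptions, not the original):

import torch
import torch.nn as nn

class DeepOutput(nn.Module):
    def __init__(self, hidden_size, context_dim, embed_dim, out_embed_dim, vocab_size):
        super().__init__()
        self.deep_output_layer = nn.Linear(
            hidden_size + context_dim + embed_dim, out_embed_dim)
        self.output_projection = nn.Linear(out_embed_dim, vocab_size)

    def forward(self, lstm_out, attn_out, embeddings):
        # all inputs: [batch, tgt_len, *]; output: [batch, tgt_len, vocab_size]
        x = torch.cat((lstm_out, attn_out, embeddings), dim=2)
        x = torch.tanh(self.deep_output_layer(x))
        return self.output_projection(x)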
Example 13
    def forward(self, input_tokens, prev_scores, timestep, *inputs):
        """
        Decoder step inputs correspond one-to-one to encoder outputs.
        HOWEVER: after the first step, encoder outputs (i.e., the first
        len(self.models) elements of inputs) must be tiled k (beam size)
        times on the batch dimension (axis 1).
        """
        log_probs_per_model = []
        attn_weights_per_model = []
        state_outputs = []

        # from flat to (batch x 1)
        input_tokens = input_tokens.unsqueeze(1)

        next_state_input = len(self.models)

        # size of "batch" dimension of input as tensor
        batch_size = torch.onnx.operators.shape_as_tensor(input_tokens)[0]

        # underlying assumption is each model has same vocab_reduction_module
        vocab_reduction_module = self.models[0].decoder.vocab_reduction_module
        if vocab_reduction_module is not None:
            possible_translation_tokens = inputs[len(self.models)]
            next_state_input += 1
        else:
            possible_translation_tokens = None

        for i, model in enumerate(self.models):
            encoder_output = inputs[i]
            prev_hiddens = []
            prev_cells = []

            for _ in range(len(model.decoder.layers)):
                prev_hiddens.append(inputs[next_state_input])
                prev_cells.append(inputs[next_state_input + 1])
                next_state_input += 2

            # ensure previous attention context has batch dimension
            input_feed_shape = torch.cat((
                batch_size.view(1),
                torch.LongTensor([-1]),
            ), )
            prev_input_feed = torch.onnx.operators.reshape_from_tensor_shape(
                inputs[next_state_input],
                input_feed_shape,
            )
            next_state_input += 1

            # no batching, we only care about "max" length
            src_length_int = encoder_output.size()[0]
            src_length = torch.LongTensor(np.array([src_length_int]))

            # notional, not actually used for decoder computation
            src_tokens = torch.LongTensor(np.array([[0] * src_length_int]))

            encoder_out = (
                encoder_output,
                prev_hiddens,
                prev_cells,
                src_length,
                src_tokens,
            )

            # store cached states, use evaluation mode
            model.decoder._is_incremental_eval = True
            model.eval()

            # placeholder
            incremental_state = {}

            # cache previous state inputs
            utils.set_incremental_state(
                model.decoder,
                incremental_state,
                'cached_state',
                (prev_hiddens, prev_cells, prev_input_feed),
            )

            decoder_output = model.decoder(
                input_tokens,
                encoder_out,
                incremental_state=incremental_state,
                possible_translation_tokens=possible_translation_tokens,
            )
            logits, attn_scores, _ = decoder_output

            log_probs = F.log_softmax(logits, dim=2)

            log_probs_per_model.append(log_probs)
            attn_weights_per_model.append(attn_scores)

            (
                next_hiddens,
                next_cells,
                next_input_feed,
            ) = utils.get_incremental_state(
                model.decoder,
                incremental_state,
                'cached_state',
            )

            for h, c in zip(next_hiddens, next_cells):
                state_outputs.extend([h, c])
            state_outputs.append(next_input_feed)

        average_log_probs = torch.mean(
            torch.cat(log_probs_per_model, dim=1),
            dim=1,
            keepdim=True,
        )

        average_attn_weights = torch.mean(
            torch.cat(attn_weights_per_model, dim=1),
            dim=1,
            keepdim=True,
        )

        best_scores_k_by_k, best_tokens_k_by_k = torch.topk(
            average_log_probs.squeeze(1),
            k=self.beam_size,
        )

        prev_scores_k_by_k = prev_scores.view(-1, 1).expand(-1, self.beam_size)
        total_scores_k_by_k = best_scores_k_by_k + prev_scores_k_by_k

        # flatten to take top k over all (beam x beam) hypos
        total_scores_flat = total_scores_k_by_k.view(-1)
        best_tokens_flat = best_tokens_k_by_k.view(-1)

        best_scores, best_indices = torch.topk(
            total_scores_flat,
            k=self.beam_size,
        )

        best_tokens = best_tokens_flat.index_select(
            dim=0,
            index=best_indices,
        ).view(-1)

        # integer division to determine which input produced each successor
        prev_hypos = best_indices // self.beam_size

        attention_weights = average_attn_weights.index_select(
            dim=0,
            index=prev_hypos,
        )

        if possible_translation_tokens is not None:
            best_tokens = possible_translation_tokens.index_select(
                dim=0,
                index=best_tokens,
            )

        word_rewards_for_best_tokens = self.word_rewards.index_select(
            0,
            best_tokens,
        )
        best_scores += word_rewards_for_best_tokens

        self.input_names = ['prev_tokens', 'prev_scores', 'timestep']
        for i in range(len(self.models)):
            self.input_names.append('fixed_input_{}'.format(i))

        if possible_translation_tokens is not None:
            self.input_names.append('possible_translation_tokens')

        # 'attention_weights_average' output shape: (src_length x beam_size)
        attention_weights = attention_weights.squeeze(1)

        outputs = [
            best_tokens,
            best_scores,
            prev_hypos,
            attention_weights,
        ]
        self.output_names = [
            'best_tokens_indices',
            'best_scores',
            'prev_hypos_indices',
            'attention_weights_average',
        ]
        for i in range(len(self.models)):
            self.output_names.append('fixed_input_{}'.format(i))
            if self.tile_internal:
                outputs.append(inputs[i].repeat(1, self.beam_size, 1))
            else:
                outputs.append(inputs[i])

        if possible_translation_tokens is not None:
            self.output_names.append('possible_translation_tokens')
            outputs.append(possible_translation_tokens)

        for i, state in enumerate(state_outputs):
            next_state = state.index_select(
                dim=0,
                index=prev_hypos,
            )
            outputs.append(next_state)
            self.output_names.append('state_output_{}'.format(i))
            self.input_names.append('state_input_{}'.format(i))

        return tuple(outputs)
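
The candidate selection in this step works on a (beam x beam) grid: every live hypothesis proposes its own top-k tokens, the running scores are added, the grid is flattened, and a second top-k keeps the k best continuations overall; integer division by the beam size then recovers which parent hypothesis produced each survivor. A small standalone illustration with made-up tensors:

import torch

beam_size = 3
# per-hypothesis top-k token scores and ids (beam_size x beam_size), illustrative values
best_scores_k_by_k = torch.randn(beam_size, beam_size)
best_tokens_k_by_k = torch.randint(0, 100, (beam_size, beam_size))
prev_scores = torch.randn(beam_size)

total_scores_k_by_k = best_scores_k_by_k + prev_scores.view(-1, 1)  # add running scores
best_scores, best_indices = torch.topk(total_scores_k_by_k.view(-1), k=beam_size)

best_tokens = best_tokens_k_by_k.view(-1).index_select(0, best_indices)
prev_hypos = best_indices // beam_size  # parent hypothesis of each surviving candidate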
Example no. 14
 def _get_hidden_state(self, incremental_state):
     return utils.get_incremental_state(self, incremental_state,
                                        'hidden_state')
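
This accessor, like several below, just delegates to fairseq's utils.get_incremental_state / set_incremental_state, which cache per-module tensors in a dictionary threaded through the decoding loop. A rough, simplified stand-in for what those helpers do (the keying scheme here is an assumption for illustration, not fairseq's actual implementation):

def set_incremental_state(module, incremental_state, key, value):
    # cache a value under a key that is unique to this module instance
    if incremental_state is not None:
        incremental_state[(id(module), key)] = value

def get_incremental_state(module, incremental_state, key):
    # return the cached value, or None when not generating incrementally
    if incremental_state is None:
        return None
    return incremental_state.get((id(module), key))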
Example no. 15
 def reorder_incremental_state(self, incremental_state, new_order):
     super().reorder_incremental_state(incremental_state, new_order)
     encoder_out = utils.get_incremental_state(self, incremental_state, 'encoder_out')
     if encoder_out is not None:
         encoder_out = tuple(eo.index_select(0, new_order) for eo in encoder_out)
         utils.set_incremental_state(self, incremental_state, 'encoder_out', encoder_out)
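
Reordering is needed because beam search reshuffles (and duplicates) the active hypotheses at every step, so any cached tensor whose leading dimension indexes hypotheses must be gathered with the same new_order. A minimal sketch, assuming the cached entry is a tuple of batch-first tensors as in the example above:

def reorder_cached_tensors(cached, new_order):
    # new_order: LongTensor of hypothesis indices chosen by the search at this step
    return tuple(t.index_select(0, new_order) for t in cached)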
Example no. 16
    def forward(
        self,
        prev_output_tokens,
        encoder_out=None,
        incremental_state=None,
        possible_translation_tokens=None,
        timestep=None,
    ):
        (encoder_x, src_tokens, encoder_padding_mask) = encoder_out

        # embed positions
        positions = (self.embed_positions(prev_output_tokens,
                                          incremental_state=incremental_state)
                     if self.embed_positions is not None else None)

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

            if self.onnx_trace:
                assert type(incremental_state) is list
                assert timestep is not None

                state_list = incremental_state
                incremental_state = {}
                state_index = 0

                for layer in self.layers:
                    utils.set_incremental_state(
                        layer.avg_attn,
                        incremental_state,
                        "prev_vec",
                        state_list[state_index],
                    )
                    utils.set_incremental_state(
                        layer.avg_attn,
                        incremental_state,
                        "prev_sum",
                        state_list[state_index + 1],
                    )
                    state_index += 2
                    utils.set_incremental_state(layer.avg_attn,
                                                incremental_state, "prev_pos",
                                                timestep.float())

                    if layer.encoder_attn is not None:

                        utils.set_incremental_state(
                            layer.encoder_attn,
                            incremental_state,
                            "prev_key",
                            state_list[state_index],
                        )
                        utils.set_incremental_state(
                            layer.encoder_attn,
                            incremental_state,
                            "prev_value",
                            state_list[state_index + 1],
                        )
                        state_index += 2

        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(prev_output_tokens)

        if self.project_in_dim is not None:
            x = self.project_in_dim(x)

        if positions is not None:
            x += positions
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)
        attn = None

        inner_states = [x]

        # decoder layers
        for layer in self.layers:
            x, attn = layer(
                x,
                encoder_x,
                encoder_padding_mask,
                incremental_state,
                self_attn_mask=self.buffered_future_mask(x)
                if incremental_state is None else None,
            )
            inner_states.append(x)

        if self.normalize:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if self.project_out_dim is not None:
            x = self.project_out_dim(x)

        # project back to size of vocabulary
        if self.share_input_output_embed:
            output_weights = self.embed_tokens.weight
        else:
            output_weights = self.embed_out

        if (self.vocab_reduction_module is not None
                and possible_translation_tokens is None):
            decoder_input_tokens = prev_output_tokens.contiguous()
            possible_translation_tokens = self.vocab_reduction_module(
                src_tokens, decoder_input_tokens=decoder_input_tokens)
        if possible_translation_tokens is not None:
            output_weights = output_weights.index_select(
                dim=0, index=possible_translation_tokens)

        if self.adaptive_softmax is None:
            logits = F.linear(x, output_weights)
        else:
            assert (
                possible_translation_tokens is None
            ), "vocabulary reduction and adaptive softmax are incompatible!"
            logits = x

        if self.onnx_trace:
            state_outputs = []
            for layer in self.layers:
                prev_vec = utils.get_incremental_state(layer.avg_attn,
                                                       incremental_state,
                                                       "prev_vec")
                prev_sum = utils.get_incremental_state(layer.avg_attn,
                                                       incremental_state,
                                                       "prev_sum")
                state_outputs.extend([prev_vec, prev_sum])

                if layer.encoder_attn is not None:
                    prev_key = utils.get_incremental_state(
                        layer.encoder_attn, incremental_state, "prev_key")
                    prev_value = utils.get_incremental_state(
                        layer.encoder_attn, incremental_state, "prev_value")
                    state_outputs.extend([prev_key, prev_value])

            return logits, attn, possible_translation_tokens, state_outputs

        return logits, attn, possible_translation_tokens
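
When tracing for ONNX, the per-layer incremental state is carried as one flat list of tensors: two entries per layer for the average attention (prev_vec, prev_sum), plus two more (prev_key, prev_value) when the layer also attends over the encoder. Helper functions in that spirit, written here only for illustration:

def flatten_states(layer_states):
    # layer_states: one dict per decoder layer
    flat = []
    for state in layer_states:
        flat.extend([state["prev_vec"], state["prev_sum"]])
        if state.get("prev_key") is not None:
            flat.extend([state["prev_key"], state["prev_value"]])
    return flat

def unflatten_states(flat, has_encoder_attn):
    # has_encoder_attn: one bool per decoder layer
    layer_states, idx = [], 0
    for enc_attn in has_encoder_attn:
        state = {"prev_vec": flat[idx], "prev_sum": flat[idx + 1]}
        idx += 2
        if enc_attn:
            state.update(prev_key=flat[idx], prev_value=flat[idx + 1])
            idx += 2
        layer_states.append(state)
    return layer_states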
Example no. 17
    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
        encoder_padding_mask = encoder_out['encoder_padding_mask']
        encoder_out = encoder_out['encoder_out']

        if incremental_state is not None:
            print(prev_output_tokens.size())
            # prev_output_tokens = prev_output_tokens[:, -1:]
            prev_output_tokens = prev_output_tokens[:, -1:, :]

        # bsz, one_input_size = prev_output_tokens.size()
        # self.seq_len = 360
        # seqlen = self.seq_len
        bsz, seqlen, segment_units = prev_output_tokens.size()

        # get outputs from encoder
        encoder_outs, encoder_hiddens, encoder_cells = encoder_out[:3]
        srclen = encoder_outs.size(0)

        # embed tokens
        # x = self.embed_tokens(prev_output_tokens)

        x = prev_output_tokens.view(-1, self.seq_len, self.input_size).float()
        # print(x.size())
        # x = F.dropout(x, p=self.dropout_in, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # initialize previous states (or get from cache during incremental generation)
        cached_state = utils.get_incremental_state(self, incremental_state,
                                                   'cached_state')
        if cached_state is not None:
            prev_hiddens, prev_cells, input_feed = cached_state
        else:
            num_layers = len(self.layers)
            prev_hiddens = [encoder_hiddens[i] for i in range(num_layers)]
            prev_cells = [encoder_cells[i] for i in range(num_layers)]
            if self.encoder_hidden_proj is not None:
                prev_hiddens = [
                    self.encoder_hidden_proj(x) for x in prev_hiddens
                ]
                prev_cells = [self.encoder_cell_proj(x) for x in prev_cells]
            input_feed = x.new_ones(bsz, self.hidden_size) * encoder_outs[
                -1, :bsz, :self.hidden_size]  #[0.5,0.1,1.0,0.0,0.0]#0.5
            input_feed = nn.functional.relu(input_feed)

        attn_scores = x.new_zeros(
            srclen, seqlen, bsz
        )  #x.new_zeros(segment_units, seqlen, bsz)  #x.new_zeros(srclen, seqlen, bsz)
        outs = []
        boundry_param_list = []
        segment_param_list = []

        for j in range(seqlen):
            # from fairseq import pdb; pdb.set_trace()
            # input feeding: concatenate context vector from previous time step
            input_d = F.dropout(x[j, :, :], p=0.5, training=self.training)
            input_mask = input_d > 1e-6  #0.#-1e-6
            input_in = (x[j, :, :] * input_mask.float()) + (
                (1 - input_mask.float()) * input_feed)
            #input = torch.clamp(input, min=-1.0, max=1.0)
            #import pdb; pdb.set_trace()
            self.print_count += 1
            if self.print_count % 1000 == 0:  #random.random() > 0.9999:
                #from fairseq import pdb; pdb.set_trace()
                means = (input_in * (self.max_vals + 1e-6)).view(
                    -1, 18, 5).mean(dim=1).cpu().detach().numpy()
                print("\n\ninput means\t", means)
                wandb.log({"input0": wandb.Histogram(means[:, 0])})
                wandb.log({"input1": wandb.Histogram(means[:, 1])})
                wandb.log({"input2": wandb.Histogram(means[:, 2])})
                wandb.log({"input3": wandb.Histogram(means[:, 3])})
                #wandb.log({"input4": wandb.Histogram(means[4])})
                mean_x = x[j, :, :].view(-1, 18, 5).mean(dim=1)
                print("x[j, :, :] means\t", mean_x.cpu().detach().numpy())
                mean_feed = input_feed.view(-1, 18, 5).mean(dim=1)
                print("input_feed means\t", mean_feed.cpu().detach().numpy())

            # if random.random()>0.0:
            #     input = x[j, :, :]#torch.cat((x[j, :, :], input_feed), dim=1)
            # else:
            #     input = input_feed
            input = input_in
            for i, rnn in enumerate(self.layers):
                # recurrent cell
                hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i]))

                # hidden state becomes the input to the next layer
                #input = F.dropout(hidden, p=self.dropout_out, training=self.training)

                # save state for next time step
                prev_hiddens[i] = hidden
                prev_cells[i] = cell

            # apply attention using the last layer's hidden state
            if self.attention is not None:
                out, attn_scores[:, j, :] = self.attention(
                    hidden, encoder_outs, encoder_padding_mask)
            else:
                out = hidden
            # from fairseq import pdb; pdb.set_trace()
            ntf_input = self.ntf_projection(out)
            boundry_params, segment_params = torch.split(
                ntf_input, [3, 8 * self.num_segments], dim=1)
            segment_params = segment_params.view((-1, 8, self.num_segments))
            boundry_param_list.append(boundry_params)
            segment_param_list.append(segment_params)
            # boundry_params = torch.Tensor([200.0,10000.0,200.0]).to(self.device)*torch.sigmoid(boundry_params)
            boundry_params = torch.Tensor([150.0, 10000.0, 100.0]).to(
                self.device) * torch.sigmoid(boundry_params)

            segment_params = torch.cat([
                torch.sigmoid(segment_params[:, :4, :]),
                torch.tanh(segment_params[:, 4:, :])
            ],
                                       dim=1)
            # vf, a, rhocr, g, omegar, omegas, epsq, epsv
            segment_params = segment_params * torch.Tensor([[150.0], [
                2.0
            ], [100.0], [5.0], [10.0], [10.0], [10.0], [10.0]]).to(self.device)
            segment_params = segment_params.permute(0, 2, 1)
            unscaled_input = input_in * self.max_vals

            # print("boundry_params",boundry_params[0,::5].mean().item(),boundry_params.size())
            # print("segment_params",segment_params[0,::5,0].mean().item(),segment_params.size())
            # print(unscaled_input)

            model_steps = []
            num_steps = 3  #18
            for _ in range(num_steps):
                out1 = self.ntf_module(unscaled_input, segment_params,
                                       boundry_params)
                model_steps.append(out1)
                unscaled_input = out1

            out = torch.stack(model_steps, dim=0).mean(dim=0)

            avg_sum_max_vals = self.max_vals  # summing everything above but speed and occ need to be avg
            # avg_sum_max_vals[1::5] *= num_steps #mean occupancy
            # avg_sum_max_vals[2::5] *= num_steps #mean speed

            out = out / (avg_sum_max_vals + 1e-6)

            # print(out.mean().item())

            # out = out / (self.max_vals+1e-6)

            #from fairseq import pdb; pdb.set_trace()

            #out = F.dropout(out, p=self.dropout_out, training=self.training)

            # input feeding
            input_feed = out  #.view(-1,360,90)

            # save final output
            outs.append(out)

        # cache previous states (no-op except during incremental generation)
        utils.set_incremental_state(
            self,
            incremental_state,
            'cached_state',
            (prev_hiddens, prev_cells, input_feed),
        )

        # collect outputs across time steps
        #print(torch.stack(outs, dim=0).size())
        # from fairseq import pdb; pdb.set_trace();
        x = torch.stack(outs, dim=1)  #.view(seqlen, bsz, self.hidden_size)

        self.all_boundry_params = torch.stack(boundry_param_list, dim=1)
        self.all_segment_params = torch.stack(segment_param_list, dim=1)

        # T x B x C -> B x T x C
        #x = x.transpose(1, 0)

        # srclen x tgtlen x bsz -> bsz x tgtlen x srclen
        if not self.training and self.need_attn:
            attn_scores = attn_scores.transpose(0, 2)
        else:
            attn_scores = None

        # project back to size of vocabulary
        if self.adaptive_softmax is None:
            if hasattr(self, 'additional_fc'):
                x = self.additional_fc(x)
                x = F.dropout(x, p=self.dropout_out, training=self.training)
            if self.share_input_output_embed:
                x = F.linear(x, self.embed_tokens.weight)
            # else:
            #     x = self.fc_out(x)
        # import fairseq.pdb as pdb; pdb.set_trace()#[:,-1,:]

        #x = x.contiguous().view(bsz,-1)#self.output_size)#self.fc_out(x)
        return x, attn_scores
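
One detail worth isolating from this example is how the decoder input is assembled at every step: dropout is applied to the ground-truth frame only to obtain a random keep-mask, and the dropped positions are filled from the previous model output (input_feed), which mixes teacher forcing with the model's own predictions. A stripped-down sketch of that masking trick (function and tensor names are illustrative):

import torch.nn.functional as F

def mix_teacher_forcing(x_t, input_feed, p=0.5, training=True):
    # zero out roughly p of the ground-truth entries, then backfill them
    # from the model's own previous prediction
    dropped = F.dropout(x_t, p=p, training=training)
    keep_mask = (dropped > 1e-6).float()
    return x_t * keep_mask + (1.0 - keep_mask) * input_feed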
Example no. 18
 def _get_input_buffer(
         self, incremental_state: Optional[Dict[str,
                                                Dict[str,
                                                     Optional[Tensor]]]]):
     return utils.get_incremental_state(self, incremental_state,
                                        "input_buffer")
Example no. 19
 def _get_cached_bert(self, incremental_state):
     return utils.get_incremental_state(
         self,
         incremental_state,
         'cached_bert',
     )
Example no. 20
    def extract_features(self,
                         prev_output_tokens,
                         encoder_out,
                         incremental_state=None):
        """
        Similar to *forward* but only return features.
        """
        if encoder_out is not None:
            encoder_padding_mask = encoder_out['encoder_padding_mask']
            encoder_out = encoder_out['encoder_out']
        else:
            encoder_padding_mask = None
            encoder_out = None

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
        bsz, seqlen = prev_output_tokens.size()

        # get outputs from encoder
        if encoder_out is not None:
            encoder_outs, encoder_hiddens, encoder_cells = encoder_out[:3]
            srclen = encoder_outs.size(0)
        else:
            srclen = None

        # embed tokens
        x = self.embed_tokens(prev_output_tokens)
        x = F.dropout(x, p=self.dropout_in, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # initialize previous states (or get from cache during incremental generation)
        cached_state = utils.get_incremental_state(self, incremental_state,
                                                   'cached_state')
        if cached_state is not None:
            prev_hiddens, prev_cells, input_feed = cached_state
        elif encoder_out is not None:
            # setup recurrent cells
            num_layers = len(self.layers)
            prev_hiddens = [encoder_hiddens[i] for i in range(num_layers)]
            prev_cells = [encoder_cells[i] for i in range(num_layers)]
            if self.encoder_hidden_proj is not None:
                prev_hiddens = [
                    self.encoder_hidden_proj(x) for x in prev_hiddens
                ]
                prev_cells = [self.encoder_cell_proj(x) for x in prev_cells]
            input_feed = x.new_zeros(bsz, self.hidden_size)
        else:
            # setup zero cells, since there is no encoder
            num_layers = len(self.layers)
            zero_state = x.new_zeros(bsz, self.hidden_size)
            prev_hiddens = [zero_state for i in range(num_layers)]
            prev_cells = [zero_state for i in range(num_layers)]
            input_feed = None

        assert srclen is not None or self.attention is None, \
            "attention is not supported if there are no encoder outputs"
        attn_scores = x.new_zeros(srclen, seqlen,
                                  bsz) if self.attention is not None else None
        outs = []
        for j in range(seqlen):
            # input feeding: concatenate context vector from previous time step
            if input_feed is not None:
                input = torch.cat((x[j, :, :], input_feed), dim=1)
            else:
                input = x[j]

            for i, rnn in enumerate(self.layers):
                # recurrent cell
                hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i]))

                # hidden state becomes the input to the next layer
                input = F.dropout(hidden,
                                  p=self.dropout_out,
                                  training=self.training)

                # save state for next time step
                prev_hiddens[i] = hidden
                prev_cells[i] = cell

            # apply attention using the last layer's hidden state
            if self.attention is not None:
                out, attn_scores[:, j, :] = self.attention(
                    hidden, encoder_outs, encoder_padding_mask)
            else:
                out = hidden
            out = F.dropout(out, p=self.dropout_out, training=self.training)

            # input feeding
            if input_feed is not None:
                input_feed = out

            # save final output
            outs.append(out)

        # cache previous states (no-op except during incremental generation)
        utils.set_incremental_state(
            self,
            incremental_state,
            'cached_state',
            (prev_hiddens, prev_cells, input_feed),
        )

        # collect outputs across time steps
        x = torch.cat(outs, dim=0).view(seqlen, bsz, self.hidden_size)

        # T x B x C -> B x T x C
        x = x.transpose(1, 0)

        if hasattr(self, 'additional_fc') and self.adaptive_softmax is None:
            x = self.additional_fc(x)
            x = F.dropout(x, p=self.dropout_out, training=self.training)

        # srclen x tgtlen x bsz -> bsz x tgtlen x srclen
        if not self.training and self.need_attn and self.attention is not None:
            attn_scores = attn_scores.transpose(0, 2)
        else:
            attn_scores = None
        return x, attn_scores
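
The time-step loop above is the classic input-feeding pattern: the attention output from the previous step is concatenated to the current embedding before it enters the stacked LSTM cells, and the new attention output is fed back for the next step. Reduced to its essentials (single layer, stand-in attention, illustrative dimensions):

import torch
import torch.nn as nn

embed_dim, hidden_size, bsz, seqlen = 8, 16, 4, 5
cell = nn.LSTMCell(embed_dim + hidden_size, hidden_size)
attn = lambda h, enc_outs: h  # stand-in: a real module would return a context over enc_outs

x = torch.randn(seqlen, bsz, embed_dim)        # T x B x C token embeddings
enc_outs = torch.randn(10, bsz, hidden_size)   # srclen x B x C encoder states
h = torch.zeros(bsz, hidden_size)
c = torch.zeros(bsz, hidden_size)
input_feed = torch.zeros(bsz, hidden_size)

outs = []
for j in range(seqlen):
    step_input = torch.cat((x[j], input_feed), dim=1)  # input feeding
    h, c = cell(step_input, (h, c))
    out = attn(h, enc_outs)                            # attention context from the last state
    input_feed = out                                   # fed back into the next step's input
    outs.append(out)
result = torch.stack(outs, dim=1)                      # B x T x C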
Example no. 21
    def forward(self, prev_output_tokens, encoder_out_dict, incremental_state=None):
        if encoder_out_dict is not None:
            encoder_out = encoder_out_dict['encoder_out']
            encoder_padding_mask = encoder_out_dict['encoder_padding_mask']

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
        bsz, seqlen = prev_output_tokens.size()

        # get outputs from encoder
        encoder_outs, _, _ = encoder_out[:3]
        srclen = encoder_outs.size(0)

        if bsz != encoder_outs.size(1):
            prev_output_tokens = prev_output_tokens.t()
            bsz, seqlen = seqlen, bsz

        # embed tokens
        x = self.embed_tokens(prev_output_tokens)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        x_in = F.dropout(x, p=self.dropout_in, training=self.training)

        # initialize previous states (or get from cache during incremental generation)
        cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state')
        if cached_state is not None:
            prev_hiddens, prev_cells = cached_state
        else:
            _, encoder_hiddens, encoder_cells = encoder_out[:3]
            if self.initial_state == 'same':
                prev_hiddens = encoder_hiddens
                prev_cells = encoder_cells
            elif self.initial_state == 'linear':
                prev_hiddens = self.tanh(self.proj_hidden(encoder_hiddens))
                prev_cells = self.tanh(self.proj_cell(encoder_cells))
            else:
                raise NotImplementedError()

        attn_scores = x.data.new(srclen, seqlen, bsz).zero_()
        outs = []
        ctxs = []

        if hasattr(self, 'ctx_proj'):
            encoder_ctx = self.ctx_proj(encoder_outs)
        else:
            encoder_ctx = encoder_outs
        for j in range(seqlen):
            input = x_in[j]

            for i, rnn in enumerate(self.layers):
                # recurrent cell
                hidden, cell = rnn(input, (prev_hiddens, prev_cells))

                # apply attention using the last layer's hidden state
                if self.attention is not None and i == 0:
                    # attention output becomes the input to the next layer
                    attn_input = F.dropout(hidden, p=self.dropout_out, training=self.training)
                    ctx, attn_scores[:, j, :] = self.attention(attn_input, encoder_ctx, encoder_padding_mask)
                    ctxs.append(ctx)
                    input = F.dropout(ctx, p=self.dropout_out, training=self.training)
                else:
                    out = hidden
                # save state for next time step
                prev_hiddens = hidden
                prev_cells = cell

            # save final output
            outs.append(out)

        # cache previous states (no-op except during incremental generation)
        utils.set_incremental_state(
            self, incremental_state, 'cached_state', (prev_hiddens, prev_cells))

        # collect outputs across time steps
        outs = torch.stack(outs)
        ctxs = torch.stack(ctxs)

        out = torch.cat([outs, ctxs, x], dim=2)
        out = F.dropout(out, p=self.dropout_out, training=self.training)

        # T x B x C -> B x T x C
        out = out.transpose(1, 0)

        # srclen x tgtlen x bsz -> bsz x tgtlen x srclen
        if not self.training and self.need_attn:
            attn_scores = attn_scores.transpose(0, 2)
        else:
            attn_scores = None

        # project back to size of vocabulary
        out = self.tanh(self.additional_fc(out))
        out = F.dropout(out, p=self.dropout_out, training=self.training)
        out = self.fc_out(out)

        return out, attn_scores
Example no. 22
 def _get_input_buffer(self, incremental_state, incremental_clone_id=""):
     return (utils.get_incremental_state(
         self, incremental_state, "attn_state" + incremental_clone_id)
             or {})
Example no. 23
 def _get_monotonic_buffer(self, incremental_state):
     return utils.get_incremental_state(
         self,
         incremental_state,
         'monotonic',
     ) or {}
Example no. 24
 def get_pointer(self, incremental_state):
     return utils.get_incremental_state(
         self,
         incremental_state,
         'monotonic',
     ) or {}
Example no. 25
    def forward(self, input_token, timestep, *inputs):
        """
        Decoder step inputs correspond one-to-one to encoder outputs.
        """
        log_probs_per_model = []
        attn_weights_per_model = []
        state_outputs = []

        next_state_input = len(self.models)

        # underlying assumption is each model has same vocab_reduction_module
        vocab_reduction_module = self.models[0].decoder.vocab_reduction_module
        if vocab_reduction_module is not None:
            possible_translation_tokens = inputs[len(self.models)]
            next_state_input += 1
        else:
            possible_translation_tokens = None

        for i, model in enumerate(self.models):
            encoder_output = inputs[i]
            prev_hiddens = []
            prev_cells = []

            for _ in range(len(model.decoder.layers)):
                prev_hiddens.append(inputs[next_state_input])
                prev_cells.append(inputs[next_state_input + 1])
                next_state_input += 2
            prev_input_feed = inputs[next_state_input].view((1, -1))
            next_state_input += 1

            # no batching, we only care about "max" length
            src_length_int = int(encoder_output.size()[0])
            src_length = torch.LongTensor(np.array([src_length_int]))

            # notional, not actually used for decoder computation
            src_tokens = torch.LongTensor(np.array([[0] * src_length_int]))

            encoder_out = (
                encoder_output,
                prev_hiddens,
                prev_cells,
                src_length,
                src_tokens,
            )

            # store cached states, use evaluation mode
            model.decoder._is_incremental_eval = True
            model.eval()

            # placeholder
            incremental_state = {}

            # cache previous state inputs
            utils.set_incremental_state(
                model.decoder,
                incremental_state,
                "cached_state",
                (prev_hiddens, prev_cells, prev_input_feed),
            )

            decoder_output = model.decoder(
                input_token,
                encoder_out,
                incremental_state=incremental_state,
                possible_translation_tokens=possible_translation_tokens,
            )
            logits, attn_scores, _ = decoder_output

            log_probs = F.log_softmax(logits, dim=2)

            log_probs_per_model.append(log_probs)
            attn_weights_per_model.append(attn_scores)

            (next_hiddens, next_cells,
             next_input_feed) = utils.get_incremental_state(
                 model.decoder, incremental_state, "cached_state")

            for h, c in zip(next_hiddens, next_cells):
                state_outputs.extend([h, c])
            state_outputs.append(next_input_feed)

        average_log_probs = torch.mean(torch.cat(log_probs_per_model, dim=0),
                                       dim=0,
                                       keepdim=True)

        average_attn_weights = torch.mean(torch.cat(attn_weights_per_model,
                                                    dim=0),
                                          dim=0,
                                          keepdim=True)

        best_scores, best_tokens = torch.topk(average_log_probs.view(1, -1),
                                              k=self.beam_size)

        if possible_translation_tokens is not None:
            best_tokens = possible_translation_tokens.index_select(
                dim=0, index=best_tokens.view(-1)).view(1, -1)

        word_rewards_for_best_tokens = self.word_rewards.index_select(
            0, best_tokens.view(-1))
        best_scores += word_rewards_for_best_tokens

        self.input_names = ["prev_token", "timestep"]
        for i in range(len(self.models)):
            self.input_names.append(f"fixed_input_{i}")

        if possible_translation_tokens is not None:
            self.input_names.append("possible_translation_tokens")

        outputs = [best_tokens, best_scores, average_attn_weights]
        self.output_names = [
            "best_tokens_indices",
            "best_scores",
            "attention_weights_average",
        ]
        for i, state in enumerate(state_outputs):
            outputs.append(state)
            self.output_names.append(f"state_output_{i}")
            self.input_names.append(f"state_input_{i}")

        return tuple(outputs)
Example no. 26
    def forward(self, input_tokens, prev_scores, timestep, *inputs):
        """
        Decoder step inputs correspond one-to-one to encoder outputs.
        HOWEVER: after the first step, encoder outputs (i.e., the first
        len(self.models) elements of inputs) must be tiled k (beam size)
        times on the batch dimension (axis 1).
        """
        log_probs_per_model = []
        attn_weights_per_model = []
        state_outputs = []
        beam_axis_per_state = []

        # from flat to (batch x 1)
        input_tokens = input_tokens.unsqueeze(1)

        next_state_input = len(self.models)

        # size of "batch" dimension of input as tensor
        batch_size = torch.onnx.operators.shape_as_tensor(input_tokens)[0]

        # underlying assumption is each model has same vocab_reduction_module
        vocab_reduction_module = self.models[0].decoder.vocab_reduction_module
        if vocab_reduction_module is not None:
            possible_translation_tokens = inputs[len(self.models)]
            next_state_input += 1
        else:
            possible_translation_tokens = None

        for i, model in enumerate(self.models):
            if (isinstance(model, rnn.RNNModel)
                    or isinstance(model, char_source_model.CharSourceModel)
                    or isinstance(model,
                                  word_prediction_model.WordPredictionModel)):
                encoder_output = inputs[i]
                prev_hiddens = []
                prev_cells = []

                for _ in range(len(model.decoder.layers)):
                    prev_hiddens.append(inputs[next_state_input])
                    prev_cells.append(inputs[next_state_input + 1])
                    next_state_input += 2

                # ensure previous attention context has batch dimension
                input_feed_shape = torch.cat(
                    (batch_size.view(1), torch.LongTensor([-1])))
                prev_input_feed = torch.onnx.operators.reshape_from_tensor_shape(
                    inputs[next_state_input], input_feed_shape)
                next_state_input += 1

                # no batching, we only care about "max" length
                src_length_int = int(encoder_output.size()[0])
                src_length = torch.LongTensor(np.array([src_length_int]))

                # notional, not actually used for decoder computation
                src_tokens = torch.LongTensor(np.array([[0] * src_length_int]))
                src_embeddings = encoder_output.new_zeros(encoder_output.shape)

                encoder_out = (
                    encoder_output,
                    prev_hiddens,
                    prev_cells,
                    src_length,
                    src_tokens,
                    src_embeddings,
                )

                # store cached states, use evaluation mode
                model.decoder._is_incremental_eval = True
                model.eval()

                # placeholder
                incremental_state = {}

                # cache previous state inputs
                utils.set_incremental_state(
                    model.decoder,
                    incremental_state,
                    "cached_state",
                    (prev_hiddens, prev_cells, prev_input_feed),
                )

                decoder_output = model.decoder(
                    input_tokens,
                    encoder_out,
                    incremental_state=incremental_state,
                    possible_translation_tokens=possible_translation_tokens,
                )
                logits, attn_scores, _ = decoder_output

                log_probs = F.log_softmax(logits, dim=2)

                log_probs_per_model.append(log_probs)
                attn_weights_per_model.append(attn_scores)

                (
                    next_hiddens,
                    next_cells,
                    next_input_feed,
                ) = utils.get_incremental_state(model.decoder,
                                                incremental_state,
                                                "cached_state")

                for h, c in zip(next_hiddens, next_cells):
                    state_outputs.extend([h, c])
                    beam_axis_per_state.extend([0, 0])

                state_outputs.append(next_input_feed)
                beam_axis_per_state.append(0)

            elif isinstance(model, transformer.TransformerModel):
                encoder_output = inputs[i]

                # store cached states, use evaluation mode
                model.decoder._is_incremental_eval = True
                model.eval()

                # placeholder
                incremental_state = {}

                state_inputs = []
                for _ in model.decoder.layers:
                    # (prev_key, prev_value) for self- and encoder-attention
                    state_inputs.extend(
                        inputs[next_state_input:next_state_input + 4])
                    next_state_input += 4

                encoder_out = (encoder_output, None, None)

                decoder_output = model.decoder(
                    input_tokens,
                    encoder_out,
                    incremental_state=state_inputs,
                    possible_translation_tokens=possible_translation_tokens,
                    timestep=timestep,
                )
                logits, attn_scores, _, attention_states = decoder_output

                log_probs = F.log_softmax(logits, dim=2)
                log_probs_per_model.append(log_probs)
                attn_weights_per_model.append(attn_scores)

                state_outputs.extend(attention_states)
                beam_axis_per_state.extend([0 for _ in attention_states])
            else:
                raise RuntimeError(f"Not a supported model: {type(model)}")

        average_log_probs = torch.mean(torch.cat(log_probs_per_model, dim=1),
                                       dim=1,
                                       keepdim=True)

        average_attn_weights = torch.mean(torch.cat(attn_weights_per_model,
                                                    dim=1),
                                          dim=1,
                                          keepdim=True)

        best_scores_k_by_k, best_tokens_k_by_k = torch.topk(
            average_log_probs.squeeze(1), k=self.beam_size)

        prev_scores_k_by_k = prev_scores.view(-1, 1).expand(-1, self.beam_size)
        total_scores_k_by_k = best_scores_k_by_k + prev_scores_k_by_k

        # flatten to take top k over all (beam x beam) hypos
        total_scores_flat = total_scores_k_by_k.view(-1)
        best_tokens_flat = best_tokens_k_by_k.view(-1)

        best_scores, best_indices = torch.topk(total_scores_flat,
                                               k=self.beam_size)

        best_tokens = best_tokens_flat.index_select(
            dim=0, index=best_indices).view(-1)

        # integer division to determine which input produced each successor
        prev_hypos = best_indices // self.beam_size

        attention_weights = average_attn_weights.index_select(dim=0,
                                                              index=prev_hypos)

        if possible_translation_tokens is not None:
            best_tokens = possible_translation_tokens.index_select(
                dim=0, index=best_tokens)

        word_rewards_for_best_tokens = self.word_rewards.index_select(
            0, best_tokens)
        best_scores += word_rewards_for_best_tokens

        self.input_names = ["prev_tokens", "prev_scores", "timestep"]
        for i in range(len(self.models)):
            self.input_names.append(f"fixed_input_{i}")

        if possible_translation_tokens is not None:
            self.input_names.append("possible_translation_tokens")

        # 'attention_weights_average' output shape: (src_length x beam_size)
        attention_weights = attention_weights.squeeze(1)

        outputs = [best_tokens, best_scores, prev_hypos, attention_weights]
        self.output_names = [
            "best_tokens_indices",
            "best_scores",
            "prev_hypos_indices",
            "attention_weights_average",
        ]
        for i in range(len(self.models)):
            self.output_names.append(f"fixed_input_{i}")
            if self.tile_internal:
                outputs.append(inputs[i].repeat(1, self.beam_size, 1))
            else:
                outputs.append(inputs[i])

        if possible_translation_tokens is not None:
            self.output_names.append("possible_translation_tokens")
            outputs.append(possible_translation_tokens)

        for i, state in enumerate(state_outputs):
            beam_axis = beam_axis_per_state[i]
            next_state = state.index_select(dim=beam_axis, index=prev_hypos)
            outputs.append(next_state)
            self.output_names.append(f"state_output_{i}")
            self.input_names.append(f"state_input_{i}")

        return tuple(outputs)
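
Ensembling in these exporters is done in probability space: each model produces log-softmax scores, the per-model tensors are concatenated and averaged, which amounts to a geometric mean of the model distributions (up to renormalization). A minimal sketch of that averaging step with made-up logits:

import torch
import torch.nn.functional as F

logits_per_model = [torch.randn(1, 1, 100) for _ in range(3)]  # (bsz, 1, vocab) per model
log_probs_per_model = [F.log_softmax(l, dim=2) for l in logits_per_model]
# each model contributes a single decoding step, so dim 1 doubles as the model axis;
# averaging it away (keepdim=True) restores the single-step shape
average_log_probs = torch.mean(torch.cat(log_probs_per_model, dim=1), dim=1, keepdim=True)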
Example no. 27
    def forward_unprojected(self, input_tokens, encoder_out, incremental_state=None):
        if incremental_state is not None:
            input_tokens = input_tokens[:, -1:]
        bsz, seqlen = input_tokens.size()

        # get outputs from encoder
        (encoder_outs, final_hidden, final_cell, src_lengths, src_tokens) = encoder_out

        # embed tokens
        x = self.embed_tokens(input_tokens)
        x = F.dropout(x, p=self.dropout_in, training=self.training)
        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # initialize previous states (or get from cache during incremental generation)
        cached_state = utils.get_incremental_state(
            self, incremental_state, "cached_state"
        )
        input_feed = None
        if cached_state is not None:
            prev_hiddens, prev_cells, input_feed = cached_state
        else:
            # first time step, initialize previous states
            prev_hiddens, prev_cells = self._init_prev_states(encoder_out)
            if self.attention.context_dim:
                input_feed = self.initial_attn_context.expand(
                    bsz, self.attention.context_dim
                )

        attn_scores_per_step = []
        outs = []
        for j in range(seqlen):
            # input feeding: concatenate context vector from previous time step
            step_input = maybe_cat((x[j, :, :], input_feed), dim=1)
            previous_layer_input = step_input
            for i, rnn in enumerate(self.layers):
                # recurrent cell
                hidden, cell = rnn(step_input, (prev_hiddens[i], prev_cells[i]))

                # hidden state becomes the input to the next layer
                layer_output = F.dropout(
                    hidden, p=self.dropout_out, training=self.training
                )

                if self.residual_level is not None and i >= self.residual_level:
                    # TODO add an assert related to sizes here
                    step_input = layer_output + previous_layer_input
                else:
                    step_input = layer_output
                previous_layer_input = step_input

                # save state for next time step
                prev_hiddens[i] = hidden
                prev_cells[i] = cell

            out, step_attn_scores = self.attention(hidden, encoder_outs, src_lengths)
            input_feed = out
            attn_scores_per_step.append(step_attn_scores.unsqueeze(1))
            attn_scores = torch.cat(attn_scores_per_step, dim=1)
            # srclen x tgtlen x bsz -> bsz x tgtlen x srclen
            attn_scores = attn_scores.transpose(0, 2)
            combined_output_and_context = maybe_cat((hidden, out), dim=1)

            # save final output
            outs.append(combined_output_and_context)

        # cache previous states (no-op except during incremental generation)
        utils.set_incremental_state(
            self,
            incremental_state,
            "cached_state",
            (prev_hiddens, prev_cells, input_feed),
        )

        # collect outputs across time steps
        x = torch.cat(outs, dim=0).view(
            seqlen, bsz, self.combined_output_and_context_dim
        )

        # T x B x C -> B x T x C
        x = x.transpose(1, 0)

        # bottleneck layer
        if hasattr(self, "additional_fc"):
            x = self.additional_fc(x)
            x = F.dropout(x, p=self.dropout_out, training=self.training)
        return x, attn_scores
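
maybe_cat here presumably concatenates its inputs while skipping None entries, so the same code path works whether or not input feeding provides a context vector. A plausible stand-in (an assumption for illustration, not the library's actual helper):

import torch

def maybe_cat(tensors, dim=1):
    # drop None entries and concatenate whatever is left
    tensors = [t for t in tensors if t is not None]
    return tensors[0] if len(tensors) == 1 else torch.cat(tensors, dim=dim)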
Example no. 28
    def forward(self,
                prev_output_tokens,
                encoder_out_dict,
                incremental_state=None):
        encoder_out = encoder_out_dict['encoder_out']
        encoder_padding_mask = encoder_out_dict['encoder_padding_mask']

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
        bsz, seqlen = prev_output_tokens.size()

        # get outputs from encoder
        encoder_outs, encoder_hiddens, encoder_cells = encoder_out[:3]
        srclen = encoder_outs.size(0)

        # embed tokens
        x = self.embed_tokens(prev_output_tokens)
        x = F.dropout(x, p=self.dropout_in, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # initialize previous states (or get from cache during incremental generation)
        cached_state = utils.get_incremental_state(self, incremental_state,
                                                   'cached_state')
        if cached_state is not None:
            prev_hiddens, prev_cells, input_feed = cached_state
        else:
            num_layers = len(self.layers)
            prev_hiddens = [encoder_hiddens[i] for i in range(num_layers)]
            prev_cells = [encoder_cells[i] for i in range(num_layers)]
            if self.encoder_hidden_proj is not None:
                prev_hiddens = [
                    self.encoder_hidden_proj(x) for x in prev_hiddens
                ]
                prev_cells = [self.encoder_cell_proj(x) for x in prev_cells]
            input_feed = x.new_zeros(bsz, self.hidden_size)

        attn_scores = x.new_zeros(srclen, seqlen, bsz)
        outs = []
        for j in range(seqlen):
            # input feeding: concatenate context vector from previous time step
            input = torch.cat((x[j, :, :], input_feed), dim=1)

            for i, rnn in enumerate(self.layers):
                # recurrent cell
                hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i]))

                # hidden state becomes the input to the next layer
                input = F.dropout(hidden,
                                  p=self.dropout_out,
                                  training=self.training)

                # save state for next time step
                prev_hiddens[i] = hidden
                prev_cells[i] = cell

            # apply attention using the last layer's hidden state
            if self.attention is not None:
                out, attn_scores[:, j, :] = self.attention(
                    hidden, encoder_outs, encoder_padding_mask)
            else:
                out = hidden
            out = F.dropout(out, p=self.dropout_out, training=self.training)

            # input feeding
            input_feed = out

            # save final output
            outs.append(out)

        # cache previous states (no-op except during incremental generation)
        utils.set_incremental_state(
            self,
            incremental_state,
            'cached_state',
            (prev_hiddens, prev_cells, input_feed),
        )

        # collect outputs across time steps
        x = torch.cat(outs, dim=0).view(seqlen, bsz, self.hidden_size)

        # T x B x C -> B x T x C
        x = x.transpose(1, 0)

        # srclen x tgtlen x bsz -> bsz x tgtlen x srclen
        if not self.training and self.need_attn:
            attn_scores = attn_scores.transpose(0, 2)
        else:
            attn_scores = None

        # project back to size of vocabulary
        if self.adaptive_softmax is None:
            if hasattr(self, 'additional_fc'):
                x = self.additional_fc(x)
                x = F.dropout(x, p=self.dropout_out, training=self.training)
            if self.share_input_output_embed:
                x = F.linear(x, self.embed_tokens.weight)
            else:
                x = self.fc_out(x)
        return x, attn_scores
Example no. 29
    def forward(self,
                prev_output_tokens,
                encoder_out,
                lang,
                incremental_state=None):
        encoder_sentemb = encoder_out['sentemb']
        encoder_padding_mask = encoder_out['encoder_padding_mask']
        encoder_out = encoder_out['encoder_out']

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
        bsz, seqlen = prev_output_tokens.size()

        # get outputs from encoder
        encoder_outs, encoder_hiddens, encoder_cells = encoder_out[:3]

        # embed tokens
        x = self.embed_tokens(prev_output_tokens)
        x = F.dropout(x, p=self.dropout_in, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # embed language
        lang_tensor = torch.cuda.LongTensor([self.lang_dictionary[lang]] * bsz)
        l = self.embed_langs(lang_tensor)

        # B x T x C -> T x B x C
        #l = l.transpose(0, 1)

        # initialize previous states (or get from cache during incremental generation)
        cached_state = utils.get_incremental_state(self, incremental_state,
                                                   'cached_state')
        if cached_state is not None:
            prev_hiddens, prev_cells, input_feed = cached_state
            print(len(prev_cells[0]))
        else:
            num_layers = len(self.layers)
            # Hiddens and cells are initialized with a linear transformation of the embedding produced by the encoder
            prev_hiddens = [encoder_sentemb for i in range(num_layers)]
            prev_cells = [encoder_sentemb for i in range(num_layers)]
            prev_hiddens = [self.encoder_hidden_proj(x) for x in prev_hiddens]
            prev_cells = [self.encoder_cell_proj(x) for x in prev_cells]
            input_feed = x.new_zeros(bsz, self.hidden_size)

        outs = []
        for j in range(seqlen):
            # input feeding: concatenate context vector from previous time step
            input = torch.cat((x[j, :, :], encoder_sentemb, input_feed, l),
                              dim=1)

            for i, rnn in enumerate(self.layers):
                # recurrent cell
                hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i]))

                # hidden state becomes the input to the next layer
                input = F.dropout(hidden,
                                  p=self.dropout_out,
                                  training=self.training)

                # save state for next time step
                prev_hiddens[i] = hidden
                prev_cells[i] = cell

            out = hidden
            out = F.dropout(out, p=self.dropout_out, training=self.training)

            # input feeding
            input_feed = out

            # save final output
            outs.append(out)

        # cache previous states (no-op except during incremental generation)
        utils.set_incremental_state(
            self,
            incremental_state,
            'cached_state',
            (prev_hiddens, prev_cells, input_feed),
        )

        # collect outputs across time steps
        x = torch.cat(outs, dim=0).view(seqlen, bsz, self.hidden_size)

        # T x B x C -> B x T x C
        x = x.transpose(1, 0)

        # project back to size of vocabulary
        if self.adaptive_softmax is None:
            if hasattr(self, 'additional_fc'):
                x = self.additional_fc(x)
                x = F.dropout(x, p=self.dropout_out, training=self.training)
            if self.share_input_output_embed:
                x = F.linear(x, self.embed_tokens.weight)
            else:
                x = self.fc_out(x)
        return x, None
Example no. 30
 def _get_input_buffer(self, incremental_state):
     return utils.get_incremental_state(self, incremental_state,
                                        'input_buffer')
Example no. 31
 def _get_input_buffer(self, incremental_state):
     return utils.get_incremental_state(self, incremental_state, 'input_buffer')
Example no. 32
 def _get_input_buffer(self, incremental_state):
     return utils.get_incremental_state(
         self,
         incremental_state,
         'attn_state',
     ) or {}
Example no. 33
 def _get_input_buffer(self, incremental_state):
     return utils.get_incremental_state(
         self,
         incremental_state,
         'attn_state',
     ) or {}