    def step(self, helper: Helper, time: int,
             inputs: torch.Tensor, state: Optional[Cache]) \
            -> Tuple[TransformerDecoderOutput, Cache,
                     torch.Tensor, torch.ByteTensor]:
        assert state is not None
        outputs, state = self._inputs_to_outputs(inputs, state)
        sample_ids = helper.sample(time=time, outputs=outputs)
        if self._state_context is not None:
            assert self._state_context_sequence_length is not None
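            # While the current step is still inside the provided context,
            # override the sampled token with the corresponding context token.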
            sample_ids = torch.where(
                self._state_context_sequence_length > time,
                self._state_context[:, time], sample_ids)

        if time + 1 == self._state_max_decoding_length:
            # Maximum decoding length reached, mark all batches as finished.
            # This requires special handling because performing lookup on
            # position embeddings with `time + 1` may result in IndexError.
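            # (`torch_bool` is presumably the library's compatibility alias
            # for `torch.bool`.)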
            finished = torch.ones_like(sample_ids, dtype=torch_bool)
            # Since `next_inputs` will not be used, simply create an empty
            # tensor.
            next_inputs = torch.empty(0)
        else:
            finished, next_inputs = helper.next_inputs(self.embed_tokens, time,
                                                       outputs, sample_ids)
        next_state = state
        outputs = TransformerDecoderOutput(logits=outputs,
                                           sample_id=sample_ids)
        return outputs, next_state, next_inputs, finished
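Every `step` implementation in these examples follows the same contract: given the helper, the current time step, the embedded inputs, and the decoder state, it returns `(outputs, next_state, next_inputs, finished)`. A minimal sketch of the loop that drives such a method is shown below; the driver itself, and the names `greedy_decode`, `initial_inputs`, and `max_steps`, are illustrative assumptions, not part of the library source.

def greedy_decode(decoder, helper, initial_inputs, initial_state, max_steps):
    # Hypothetical driver: repeatedly call `step` until every sequence in
    # the batch is marked finished or the step budget runs out.
    inputs, state = initial_inputs, initial_state
    finished = None
    collected = []
    for time in range(max_steps):
        outputs, state, inputs, step_finished = decoder.step(
            helper, time, inputs, state)
        collected.append(outputs)
        # Accumulate the finished mask across steps.
        finished = step_finished if finished is None \
            else finished | step_finished
        if bool(finished.all()):
            break
    return collected, state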
    def step(self,
             helper: Helper,
             time: int,
             inputs: torch.Tensor,
             state: Optional[State]) \
            -> Tuple[Output, Optional[State], torch.Tensor, torch.ByteTensor]:
        self._state_previous_inputs.append(inputs)
        if self._state_recompute_memory:
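            # Recompute the attention memory from scratch over the most
            # recent `cache_len` inputs rather than reusing a cached state.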
            net_output, memory = self._forward(
                two_stream=True,
                **self._create_input(
                    self._state_previous_inputs[-self._state_cache_len:]))
        else:
            assert state is not None
            net_output, memory = self._forward(
                memory=state, cache_len=self._state_cache_len, two_stream=True,
                **self._create_input(self._state_previous_inputs[-1:]))
            assert memory is not None
            # Omit memory for the dummy token.
            memory = [mem[:, :-1] for mem in memory]

        logits = F.linear(net_output, self.word_embed.weight, self.lm_bias)
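        # Keep only the logits for the newest position in the sequence.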
        logits = logits[:, -1]
        sample_ids = helper.sample(time=time, outputs=logits)
        finished, next_inputs = helper.next_inputs(
            self.embed_tokens, time, logits, sample_ids)
        outputs = XLNetDecoderOutput(logits=logits, sample_id=sample_ids)
        return outputs, memory, next_inputs, finished
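The logits projection in the snippet above reuses the input embedding matrix as the output layer ("weight tying") via `F.linear`. A minimal self-contained sketch of that pattern follows; the sizes `vocab_size` and `hidden_dim` are illustrative assumptions, not values from the original class.

import torch
import torch.nn as nn
import torch.nn.functional as F

vocab_size, hidden_dim = 1000, 64            # illustrative sizes
word_embed = nn.Embedding(vocab_size, hidden_dim)
lm_bias = nn.Parameter(torch.zeros(vocab_size))

net_output = torch.randn(2, 5, hidden_dim)   # (batch, time, hidden)
# F.linear(x, W, b) computes x @ W.T + b, so the (vocab_size, hidden_dim)
# embedding matrix maps hidden states back to vocabulary logits.
logits = F.linear(net_output, word_embed.weight, lm_bias)
assert logits.shape == (2, 5, vocab_size)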
    def next_inputs(self, helper: Helper, time: int,
                    outputs: TransformerDecoderOutput) -> \
            Tuple[torch.Tensor, torch.ByteTensor]:
        finished, next_inputs = helper.next_inputs(self.embed_tokens, time,
                                                   outputs.logits,
                                                   outputs.sample_id)
        return next_inputs, finished
    def step(self, helper: Helper, time: int,
             inputs: torch.Tensor, state: Optional[HiddenState]) \
            -> Tuple[BasicRNNDecoderOutput, HiddenState,
                     torch.Tensor, torch.ByteTensor]:
        cell_outputs, cell_state = self._cell(inputs, state)
        logits = self._output_layer(cell_outputs)
        sample_ids = helper.sample(time=time, outputs=logits)
        finished, next_inputs = helper.next_inputs(self.embed_tokens, time,
                                                   logits, sample_ids)
        next_state = cell_state
        outputs = BasicRNNDecoderOutput(logits, sample_ids, cell_outputs)
        return outputs, next_state, next_inputs, finished
    def step(self, helper: Helper, time: int,
             inputs: torch.Tensor, state: Optional[AttentionWrapperState]) -> \
            Tuple[AttentionRNNDecoderOutput, AttentionWrapperState,
                  torch.Tensor, torch.ByteTensor]:
        wrapper_outputs, wrapper_state = self._cell(
            inputs, state, self.memory, self.memory_sequence_length)

        # Essentially the same as in BasicRNNDecoder.step().
        logits = self._output_layer(wrapper_outputs)
        sample_ids = helper.sample(time=time, outputs=logits)
        finished, next_inputs = helper.next_inputs(self.embed_tokens, time,
                                                   logits, sample_ids)

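        # The wrapper state exposes the attention weights (`alignments`) and
        # the attended context vector (`attention`) for this step.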
        attention_scores = wrapper_state.alignments
        attention_context = wrapper_state.attention
        outputs = AttentionRNNDecoderOutput(logits, sample_ids,
                                            wrapper_outputs, attention_scores,
                                            attention_context)
        next_state = wrapper_state

        return outputs, next_state, next_inputs, finished