Exemple #1
0
    def step(self, helper: Helper, time: int,
             inputs: torch.Tensor, state: Optional[Cache]) \
            -> Tuple[TransformerDecoderOutput, Cache,
                     torch.Tensor, torch.ByteTensor]:
        """Run a single decoding step at position ``time``.

        Args:
            helper: Object providing ``sample`` and ``next_inputs``.
            time: Index of the current decoding step.
            inputs: Inputs for this step, forwarded to
                ``self._inputs_to_outputs``.
            state: Decoding cache; must not be ``None``.

        Returns:
            A tuple ``(outputs, next_state, next_inputs, finished)`` where
            ``outputs`` bundles the step logits with the chosen sample ids
            and ``next_state`` is the cache returned by
            ``self._inputs_to_outputs``.
        """
        assert state is not None
        logits, cache = self._inputs_to_outputs(inputs, state)
        token_ids = helper.sample(time=time, outputs=logits)

        context = self._state_context
        if context is not None:
            context_lengths = self._state_context_sequence_length
            assert context_lengths is not None
            # Wherever the context is longer than `time`, take the context
            # entry at this position instead of the sampled id.
            inside_context = context_lengths > time
            token_ids = torch.where(
                inside_context, context[:, time], token_ids)

        hit_length_limit = (time + 1 == self._state_max_decoding_length)
        if not hit_length_limit:
            finished, next_inputs = helper.next_inputs(
                time=time, outputs=logits, sample_ids=token_ids)
        else:
            # Maximum decoding length reached: flag every batch entry as
            # finished. Handled separately because looking up position
            # embeddings at `time + 1` may raise an IndexError.
            finished = torch.ones_like(token_ids, dtype=torch.uint8)
            # `next_inputs` is never consumed after the final step, so an
            # empty placeholder tensor is enough.
            next_inputs = torch.empty(0)

        step_output = TransformerDecoderOutput(logits=logits,
                                               sample_id=token_ids)
        return step_output, cache, next_inputs, finished
Exemple #2
0
    def step(self,
             helper: Helper,
             time: int,
             inputs: torch.Tensor,
             state: Optional[State]) \
            -> Tuple[Output, Optional[State], torch.Tensor, torch.ByteTensor]:
        """Run a single decoding step.

        The step inputs are appended to ``self._state_previous_inputs``
        before the network is run, so this method mutates decoder state.

        Args:
            helper: Object providing ``sample`` and ``next_inputs``.
            time: Index of the current decoding step.
            inputs: Inputs for this step.
            state: Memory from the previous step. Required (non-``None``)
                unless ``self._state_recompute_memory`` is set.

        Returns:
            A tuple ``(outputs, memory, next_inputs, finished)``.
        """
        history = self._state_previous_inputs
        history.append(inputs)

        if self._state_recompute_memory:
            # Re-run the network over the trailing window of stored inputs
            # rather than relying on the incoming memory.
            window = history[-self._state_cache_len:]
            net_output, memory = self._forward(
                two_stream=True, **self._create_input(window))
        else:
            assert state is not None
            latest = history[-1:]
            net_output, memory = self._forward(
                memory=state,
                cache_len=self._state_cache_len,
                two_stream=True,
                **self._create_input(latest))
            assert memory is not None
            # Drop the memory entry belonging to the dummy token.
            memory = [layer_mem[:, :-1] for layer_mem in memory]

        # Project onto the word-embedding weights and keep only the last
        # position along dimension 1.
        all_logits = F.linear(net_output, self.word_embed.weight,
                              self.lm_bias)
        logits = all_logits[:, -1]

        token_ids = helper.sample(time=time, outputs=logits)
        finished, next_inputs = helper.next_inputs(
            self.embed_tokens, time, logits, token_ids)
        step_output = XLNetDecoderOutput(logits=logits, sample_id=token_ids)
        return step_output, memory, next_inputs, finished
 def step(self, helper: Helper, time: int,
          inputs: torch.Tensor, state: Optional[HiddenState]) \
         -> Tuple[BasicRNNDecoderOutput, HiddenState,
                  torch.Tensor, torch.ByteTensor]:
     """Run a single decoding step.

     Args:
         helper: Object providing ``sample`` and ``next_inputs``.
         time: Index of the current decoding step.
         inputs: Inputs fed to ``self._cell`` for this step.
         state: State passed to ``self._cell`` alongside ``inputs``.

     Returns:
         A tuple ``(outputs, next_state, next_inputs, finished)``;
         ``next_state`` is the state returned by ``self._cell``.
     """
     rnn_output, next_state = self._cell(inputs, state)
     logits = self._output_layer(rnn_output)
     token_ids = helper.sample(time=time, outputs=logits)
     finished, next_inputs = helper.next_inputs(
         self.embed_tokens, time, logits, token_ids)
     step_output = BasicRNNDecoderOutput(logits, token_ids, rnn_output)
     return step_output, next_state, next_inputs, finished
    def initialize(  # type: ignore
            self, helper: Helper,
            inputs: Optional[torch.Tensor],
            sequence_length: Optional[torch.LongTensor],
            initial_state: Optional[MaybeList[MaybeTuple[torch.Tensor]]]) -> \
            Tuple[torch.ByteTensor, torch.Tensor,
                  Optional[AttentionWrapperState]]:
        """Produce the first-step values for decoding.

        Args:
            helper: Object whose ``initialize`` yields the initial
                ``finished`` flags and inputs.
            inputs: Forwarded to ``helper.initialize``.
            sequence_length: Forwarded to ``helper.initialize``.
            initial_state: Optional (possibly nested) cell state. When
                given, it is installed as ``cell_state`` of a fresh zero
                state obtained from ``self._cell``.

        Returns:
            A tuple ``(initial_finished, initial_inputs, state)``; ``state``
            is ``None`` when ``initial_state`` is ``None``.
        """
        initial_finished, initial_inputs = helper.initialize(
            inputs, sequence_length)
        if initial_state is None:
            return initial_finished, initial_inputs, None

        # Collect the size of dimension 0 of every tensor in the structure.
        # The last one visited is used as the batch size; it stays `None`
        # if the structure holds no tensor at all.
        leading_dims = []

        def _record_leading_dim(item):
            if isinstance(item, torch.Tensor):
                leading_dims.append(item.size(0))

        utils.map_structure(_record_leading_dim, initial_state)
        batch_size = leading_dims[-1] if leading_dims else None

        wrapper_state = self._cell.zero_state(  # type: ignore
            batch_size=batch_size)
        wrapper_state = wrapper_state._replace(cell_state=initial_state)
        return initial_finished, initial_inputs, wrapper_state
Exemple #5
0
 def initialize(self, helper: Helper, inputs: Optional[torch.Tensor],
                sequence_length: Optional[torch.LongTensor],
                initial_state: Optional[State]) \
         -> Tuple[torch.ByteTensor, torch.Tensor, Optional[State]]:
     """Produce the first-step values for decoding.

     Args:
         helper: Object whose ``initialize`` yields the initial
             ``finished`` flags and inputs.
         inputs: Forwarded to ``helper.initialize``.
         sequence_length: Forwarded to ``helper.initialize``.
         initial_state: State to start decoding from. When ``None``, the
             default state from ``self._cell.init_batch()`` is used.

     Returns:
         A tuple ``(initial_finished, initial_inputs, state)``.
     """
     initial_finished, initial_inputs = helper.initialize(
         inputs, sequence_length)
     # Test against `None` explicitly rather than using `or`: truth-testing
     # a tensor with more than one element raises a RuntimeError, and a
     # single-element zero tensor would be silently replaced by the default.
     if initial_state is None:
         state = self._cell.init_batch()
     else:
         state = initial_state
     return (initial_finished, initial_inputs, state)
Exemple #6
0
 def initialize(self, helper: Helper, inputs: Optional[torch.Tensor],
                sequence_length: Optional[torch.LongTensor],
                initial_state: Optional[Cache]) \
         -> Tuple[torch.ByteTensor, torch.Tensor, Cache]:
     """Produce the first-step values for decoding.

     Args:
         helper: Object whose ``initialize`` yields the initial
             ``finished`` flags and inputs.
         inputs: Forwarded to ``helper.initialize``.
         sequence_length: Forwarded to ``helper.initialize``.
         initial_state: Cache to start from. Any falsy value (for
             example ``None``) falls back to ``self._state_cache``.

     Returns:
         A tuple ``(initial_finished, initial_inputs, state)``.
     """
     first_finished, first_inputs = helper.initialize(
         inputs, sequence_length)
     # Truthiness test, not an `is None` test: falsy caches also fall back.
     if initial_state:
         cache = initial_state
     else:
         cache = self._state_cache
     return first_finished, first_inputs, cache
Exemple #7
0
 def initialize(self,  # pylint: disable=no-self-use
                helper: Helper,
                inputs: Optional[torch.Tensor],
                sequence_length: Optional[torch.LongTensor],
                initial_state: Optional[State]) \
         -> Tuple[torch.ByteTensor, torch.Tensor, Optional[State]]:
     """Produce the first-step values for decoding.

     Args:
         helper: Object whose ``initialize`` is called with
             ``self.embed_tokens``, ``inputs`` and ``sequence_length``.
         inputs: Forwarded to ``helper.initialize``.
         sequence_length: Forwarded to ``helper.initialize``.
         initial_state: Returned unchanged as the third tuple element.

     Returns:
         A tuple ``(initial_finished, initial_inputs, initial_state)``.
     """
     helper_result = helper.initialize(
         self.embed_tokens, inputs, sequence_length)
     # The helper is expected to return exactly two values.
     first_finished, first_inputs = helper_result
     return first_finished, first_inputs, initial_state
    def step(self, helper: Helper, time: int,
             inputs: torch.Tensor, state: Optional[AttentionWrapperState]) -> \
            Tuple[AttentionRNNDecoderOutput, AttentionWrapperState,
                  torch.Tensor, torch.ByteTensor]:
        """Run a single decoding step through the attention-wrapped cell.

        Args:
            helper: Object providing ``sample`` and ``next_inputs``.
            time: Index of the current decoding step.
            inputs: Inputs for this step.
            state: Wrapper state passed to ``self._cell``.

        Returns:
            A tuple ``(outputs, next_state, next_inputs, finished)``.
            ``outputs`` carries the logits, sample ids and cell output plus
            the ``alignments`` and ``attention`` fields of the new state.
        """
        cell_output, next_state = self._cell(
            inputs, state, self.memory, self.memory_sequence_length)

        # From here on this mirrors the plain RNN decoder step.
        logits = self._output_layer(cell_output)
        token_ids = helper.sample(time=time, outputs=logits)
        finished, next_inputs = helper.next_inputs(
            self.embed_tokens, time, logits, token_ids)

        step_output = AttentionRNNDecoderOutput(
            logits, token_ids, cell_output,
            next_state.alignments, next_state.attention)
        return step_output, next_state, next_inputs, finished
    def initialize(  # type: ignore
            self, helper: Helper,
            inputs: Optional[torch.Tensor],
            sequence_length: Optional[torch.LongTensor],
            initial_state: Optional[MaybeList[MaybeTuple[torch.Tensor]]]) -> \
            Tuple[torch.ByteTensor, torch.Tensor,
                  Optional[AttentionWrapperState]]:
        """Produce the first-step values for decoding.

        Args:
            helper: Object whose ``initialize`` is called with
                ``self.embed_tokens``, ``inputs`` and ``sequence_length``.
            inputs: Forwarded to ``helper.initialize``.
            sequence_length: Forwarded to ``helper.initialize``.
            initial_state: Optional (possibly nested) cell state. When
                given, it is installed as ``cell_state`` of a fresh zero
                state obtained from ``self._cell``.

        Returns:
            A tuple ``(initial_finished, initial_inputs, state)``; ``state``
            is ``None`` when ``initial_state`` is ``None``.
        """
        initial_finished, initial_inputs = helper.initialize(
            self.embed_tokens, inputs, sequence_length)
        if initial_state is None:
            return initial_finished, initial_inputs, None

        # The size of dimension 0 of the first tensor in the structure is
        # used as the batch size for the zero state.
        first_tensor = utils.get_first_in_structure(initial_state)
        assert first_tensor is not None
        zero_state = self._cell.zero_state(batch_size=first_tensor.size(0))
        wrapper_state = zero_state._replace(cell_state=initial_state)
        return initial_finished, initial_inputs, wrapper_state