def step(self, helper: Helper, time: int, inputs: torch.Tensor,
         state: Optional[Cache]) \
        -> Tuple[TransformerDecoderOutput, Cache,
                 torch.Tensor, torch.ByteTensor]:
    assert state is not None
    outputs, state = self._inputs_to_outputs(inputs, state)
    sample_ids = helper.sample(time=time, outputs=outputs)
    if self._state_context is not None:
        # While still inside the provided context, force-feed the
        # ground-truth context token instead of the sampled one.
        assert self._state_context_sequence_length is not None
        sample_ids = torch.where(
            self._state_context_sequence_length > time,
            self._state_context[:, time],
            sample_ids)

    if time + 1 == self._state_max_decoding_length:
        # Maximum decoding length reached, mark all batches as finished.
        # This requires special handling because performing lookup on
        # position embeddings with `time + 1` may result in IndexError.
        finished = torch.ones_like(sample_ids, dtype=torch.bool)
        # Since `next_inputs` will not be used, simply create a null tensor.
        next_inputs = torch.empty(0)
    else:
        finished, next_inputs = helper.next_inputs(
            self.embed_tokens, time, outputs, sample_ids)

    next_state = state
    outputs = TransformerDecoderOutput(
        logits=outputs, sample_id=sample_ids)
    return outputs, next_state, next_inputs, finished
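# The `torch.where` call above force-feeds ground-truth context tokens for
# as long as each example is still inside its context. A minimal standalone
# illustration; all tensor values below are made up, not part of the decoder.
import torch

context = torch.tensor([[11, 12, 13],
                        [21, 22, 0]])      # padded context tokens
context_lengths = torch.tensor([3, 2])     # per-example context lengths
sampled = torch.tensor([99, 98])           # ids sampled by the helper

time = 2
forced = torch.where(context_lengths > time, context[:, time], sampled)
# -> tensor([13, 98]): example 0 is still inside its context and keeps the
#    ground-truth token 13; example 1 has run out and keeps the sampled 98.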
def step(self, helper: Helper, time: int, inputs: torch.Tensor,
         state: Optional[State]) \
        -> Tuple[Output, Optional[State], torch.Tensor, torch.ByteTensor]:
    self._state_previous_inputs.append(inputs)
    if self._state_recompute_memory:
        # Recompute the memory from the most recent inputs instead of
        # carrying it over in `state`.
        net_output, memory = self._forward(
            two_stream=True,
            **self._create_input(
                self._state_previous_inputs[-self._state_cache_len:]))
    else:
        assert state is not None
        net_output, memory = self._forward(
            memory=state, cache_len=self._state_cache_len,
            two_stream=True,
            **self._create_input(self._state_previous_inputs[-1:]))
    assert memory is not None
    # Omit memory for the dummy token.
    memory = [mem[:, :-1] for mem in memory]
    # Project hidden states to vocabulary logits with the (tied) input
    # embedding matrix, then keep only the last position.
    logits = F.linear(net_output, self.word_embed.weight, self.lm_bias)
    logits = logits[:, -1]
    sample_ids = helper.sample(time=time, outputs=logits)
    finished, next_inputs = helper.next_inputs(
        self.embed_tokens, time, logits, sample_ids)
    outputs = XLNetDecoderOutput(logits=logits, sample_id=sample_ids)
    return outputs, memory, next_inputs, finished
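# `F.linear(net_output, self.word_embed.weight, self.lm_bias)` computes
# `net_output @ weight.T + bias`, so the output projection reuses (ties)
# the input embedding matrix. A small sketch with made-up shapes:
import torch
import torch.nn.functional as F

hidden = torch.randn(2, 4, 16)       # (batch, steps, hidden) -- made-up sizes
embed_weight = torch.randn(100, 16)  # (vocab, hidden), shared with input embedding
bias = torch.zeros(100)

logits = F.linear(hidden, embed_weight, bias)  # shape (2, 4, 100)
last_step_logits = logits[:, -1]               # shape (2, 100), as in `step` above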
def step(self, helper: Helper, time: int, inputs: torch.Tensor,
         state: Optional[HiddenState]) \
        -> Tuple[BasicRNNDecoderOutput, HiddenState]:
    # Run one RNN cell step, project to vocabulary logits, and sample.
    cell_outputs, cell_state = self._cell(inputs, state)
    logits = self._output_layer(cell_outputs)
    sample_ids = helper.sample(time=time, outputs=logits)
    next_state = cell_state
    outputs = BasicRNNDecoderOutput(logits, sample_ids, cell_outputs)
    return outputs, next_state
def step(self, helper: Helper, time: int, inputs: torch.Tensor,
         state: Optional[Cache]) -> \
        Tuple[TransformerDecoderOutput, Cache]:
    assert state is not None
    outputs, state = self._inputs_to_outputs(inputs, state)
    sample_ids = helper.sample(time=time, outputs=outputs)
    if self._state_context is not None:
        # Same context force-feeding as in the four-tuple variant above.
        assert self._state_context_sequence_length is not None
        sample_ids = torch.where(
            self._state_context_sequence_length > time,
            self._state_context[:, time],
            sample_ids)
    next_state = state
    outputs = TransformerDecoderOutput(
        logits=outputs, sample_id=sample_ids)
    return outputs, next_state
def step(self, helper: Helper, time: int, inputs: torch.Tensor,
         state: Optional[AttentionWrapperState]) -> \
        Tuple[AttentionRNNDecoderOutput, AttentionWrapperState]:
    wrapper_outputs, wrapper_state = self._cell(
        inputs, state, self.memory, self.memory_sequence_length)
    # Essentially the same as in BasicRNNDecoder.step()
    logits = self._output_layer(wrapper_outputs)
    sample_ids = helper.sample(time=time, outputs=logits)

    attention_scores = wrapper_state.alignments
    attention_context = wrapper_state.attention
    outputs = AttentionRNNDecoderOutput(
        logits, sample_ids, wrapper_outputs,
        attention_scores, attention_context)
    next_state = wrapper_state
    return outputs, next_state
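# All `step` variants above are meant to be driven by an outer decoding
# loop. The sketch below is a hypothetical driver for the four-tuple
# signatures (TransformerDecoder / XLNetDecoder style); for the two-tuple
# RNN variants the driver would additionally obtain `finished` and
# `next_inputs` itself, e.g. via `helper.next_inputs`. Everything here
# (`decoder`, `helper`, `start_inputs`, `initial_state`, `max_steps`) is
# an assumption, not part of the excerpted code.
import torch

def greedy_decode_loop(decoder, helper, start_inputs: torch.Tensor,
                       initial_state, max_steps: int):
    inputs, state = start_inputs, initial_state
    step_outputs = []
    for time in range(max_steps):
        outputs, state, next_inputs, finished = decoder.step(
            helper, time, inputs, state)
        step_outputs.append(outputs)
        if bool(finished.all()):
            # Every example in the batch has finished decoding.
            break
        inputs = next_inputs
    return step_outputs, state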