コード例 #1
0
 def initialize(self, helper: Helper, inputs: Optional[torch.Tensor],
                sequence_length: Optional[torch.LongTensor],
                initial_state: Optional[Cache]) \
         -> Tuple[torch.ByteTensor, torch.Tensor, Cache]:
     r"""Return the initial finished flags, inputs, and decoding cache.

     Args:
         helper: Decoding helper; its ``initialize`` is called with
             ``self.embed_tokens``, ``inputs`` and ``sequence_length``.
         inputs: Forwarded unchanged to ``helper.initialize``.
         sequence_length: Forwarded unchanged to ``helper.initialize``.
         initial_state: Cache to start decoding from. When ``None``,
             ``self._state_cache`` is used instead.

     Returns:
         A tuple ``(initial_finished, initial_inputs, state)`` where the
         first two items are whatever ``helper.initialize`` returned and
         ``state`` is the selected cache.
     """
     initial_finished, initial_inputs = helper.initialize(
         self.embed_tokens, inputs, sequence_length)
     # Compare against ``None`` explicitly: ``initial_state or ...`` would
     # also discard a caller-supplied cache that merely evaluates as falsy
     # (e.g. an empty container), not just a missing one.
     state = (initial_state if initial_state is not None
              else self._state_cache)
     return initial_finished, initial_inputs, state
コード例 #2
0
 def initialize(self, helper: Helper, inputs: Optional[torch.Tensor],
                sequence_length: Optional[torch.LongTensor],
                initial_state: Optional[State]) \
         -> Tuple[torch.ByteTensor, torch.Tensor, Optional[State]]:
     r"""Return the initial finished flags, inputs, and RNN cell state.

     Args:
         helper: Decoding helper; its ``initialize`` is called with
             ``self.embed_tokens``, ``inputs`` and ``sequence_length``.
         inputs: Forwarded unchanged to ``helper.initialize``.
         sequence_length: Forwarded unchanged to ``helper.initialize``.
         initial_state: State to start decoding from. When ``None``, a
             fresh state is created with ``self._cell.init_batch()``.

     Returns:
         A tuple ``(initial_finished, initial_inputs, state)`` where the
         first two items are whatever ``helper.initialize`` returned.
     """
     initial_finished, initial_inputs = helper.initialize(
         self.embed_tokens, inputs, sequence_length)
     # Do not use ``initial_state or ...``: that calls ``bool()`` on the
     # state, which raises RuntimeError when the state is a multi-element
     # ``torch.Tensor`` and silently replaces any falsy state. Only a
     # missing (``None``) state should fall back to a fresh one.
     if initial_state is None:
         state = self._cell.init_batch()
     else:
         state = initial_state
     return (initial_finished, initial_inputs, state)
コード例 #3
0
    def initialize(  # type: ignore
            self, helper: Helper,
            inputs: Optional[torch.Tensor],
            sequence_length: Optional[torch.LongTensor],
            initial_state: Optional[MaybeList[MaybeTuple[torch.Tensor]]]) -> \
            Tuple[torch.ByteTensor, torch.Tensor,
                  Optional[AttentionWrapperState]]:
        r"""Return the initial finished flags, inputs, and attention state.

        Args:
            helper: Decoding helper; its ``initialize`` is called with
                ``self.embed_tokens``, ``inputs`` and ``sequence_length``.
            inputs: Forwarded unchanged to ``helper.initialize``.
            sequence_length: Forwarded unchanged to ``helper.initialize``.
            initial_state: Optional, possibly nested, cell state. When
                given, it replaces the ``cell_state`` field of a fresh
                ``self._cell.zero_state(...)``.

        Returns:
            ``(initial_finished, initial_inputs, state)``; ``state`` is
            ``None`` when ``initial_state`` is ``None``.
        """
        finished, first_inputs = helper.initialize(
            self.embed_tokens, inputs, sequence_length)

        # Nothing to wrap: leave the attention state unset.
        if initial_state is None:
            return finished, first_inputs, None

        # Batch size is read from dim 0 of the first tensor found in the
        # supplied structure.
        probe = utils.get_first_in_structure(initial_state)
        assert probe is not None
        probe: torch.Tensor
        wrapper_state = self._cell.zero_state(
            batch_size=probe.size(0))._replace(cell_state=initial_state)
        return finished, first_inputs, wrapper_state