def initialize(  # type: ignore
            self, helper: Helper,
            inputs: Optional[torch.Tensor],
            sequence_length: Optional[torch.LongTensor],
            initial_state: Optional[MaybeList[MaybeTuple[torch.Tensor]]]) -> \
            Tuple[torch.ByteTensor, torch.Tensor,
                  Optional[AttentionWrapperState]]:
        """Initialize the decoding loop.

        Delegates to ``helper.initialize`` for the initial ``finished``
        flags and first-step inputs, then wraps the caller-provided RNN
        state (if any) into the attention wrapper's zero state.

        Args:
            helper: The decoding helper producing inputs per step.
            inputs: Optional input tensor forwarded to the helper.
            sequence_length: Optional lengths forwarded to the helper.
            initial_state: Optional (possibly nested) structure of
                tensors used as the wrapped cell state.

        Returns:
            A tuple ``(initial_finished, initial_inputs, state)`` where
            ``state`` is ``None`` iff ``initial_state`` is ``None``.
        """
        initial_finished, initial_inputs = helper.initialize(
            inputs, sequence_length)
        if initial_state is None:
            state = None
        else:
            # Read the batch size off the first tensor in the (possibly
            # nested) structure. This replaces a mutating closure passed to
            # ``utils.map_structure``, which scanned the whole structure,
            # kept only the *last* tensor's batch size, and silently left
            # ``batch_size`` as None for tensor-free structures. Matches
            # the sibling ``initialize`` implementation in this file.
            tensor = utils.get_first_in_structure(initial_state)
            assert tensor is not None
            state = self._cell.zero_state(  # type: ignore
                batch_size=tensor.size(0))
            # Keep the wrapper's bookkeeping fields but substitute the
            # caller-supplied cell state.
            state = state._replace(cell_state=initial_state)

        return initial_finished, initial_inputs, state
Example #2
0
 def initialize(self, helper: Helper, inputs: Optional[torch.Tensor],
                sequence_length: Optional[torch.LongTensor],
                initial_state: Optional[State]) \
         -> Tuple[torch.ByteTensor, torch.Tensor, Optional[State]]:
     """Prepare the decoding loop.

     Asks the helper for the initial ``finished`` flags and first-step
     inputs, and chooses the starting state: the caller-supplied one if
     truthy, otherwise a freshly initialized batch state from the cell.
     """
     finished, first_inputs = helper.initialize(inputs, sequence_length)
     # ``if initial_state`` (not ``is None``) deliberately mirrors the
     # original ``or`` short-circuit: any falsy state falls back to the
     # cell's init_batch().
     if initial_state:
         state = initial_state
     else:
         state = self._cell.init_batch()
     return (finished, first_inputs, state)
Example #3
0
 def initialize(self, helper: Helper, inputs: Optional[torch.Tensor],
                sequence_length: Optional[torch.LongTensor],
                initial_state: Optional[Cache]) \
         -> Tuple[torch.ByteTensor, torch.Tensor, Cache]:
     """Prepare the decoding loop.

     Obtains the initial ``finished`` flags and first-step inputs from
     the helper; the starting cache is the caller-supplied one if
     truthy, otherwise this decoder's own state cache.
     """
     finished, first_inputs = helper.initialize(inputs, sequence_length)
     # ``if initial_state`` (not ``is None``) mirrors the original
     # ``or`` short-circuit: any falsy cache falls back to
     # ``self._state_cache``.
     if initial_state:
         cache = initial_state
     else:
         cache = self._state_cache
     return finished, first_inputs, cache
Example #4
0
 def initialize(self,  # pylint: disable=no-self-use
                helper: Helper,
                inputs: Optional[torch.Tensor],
                sequence_length: Optional[torch.LongTensor],
                initial_state: Optional[State]) \
         -> Tuple[torch.ByteTensor, torch.Tensor, Optional[State]]:
     """Prepare the decoding loop.

     The helper receives this decoder's embedding function in addition
     to the inputs; the caller-supplied state is passed through
     untouched, since this decoder keeps no recurrent state of its own.
     """
     finished, first_inputs = helper.initialize(
         self.embed_tokens, inputs, sequence_length)
     return finished, first_inputs, initial_state
    def initialize(  # type: ignore
            self, helper: Helper,
            inputs: Optional[torch.Tensor],
            sequence_length: Optional[torch.LongTensor],
            initial_state: Optional[MaybeList[MaybeTuple[torch.Tensor]]]) -> \
            Tuple[torch.ByteTensor, torch.Tensor,
                  Optional[AttentionWrapperState]]:
        """Prepare the decoding loop.

        Passes this decoder's embedding function to the helper to obtain
        the initial ``finished`` flags and first-step inputs, then wraps
        the caller-provided RNN state (if any) into the attention
        wrapper's zero state.
        """
        finished, first_inputs = helper.initialize(
            self.embed_tokens, inputs, sequence_length)
        if initial_state is None:
            return finished, first_inputs, None
        # Batch size is read off the first tensor found in the (possibly
        # nested) initial state structure.
        first_tensor = utils.get_first_in_structure(initial_state)
        assert first_tensor is not None
        wrapper_state = self._cell.zero_state(
            batch_size=first_tensor.size(0))
        # Keep the wrapper's bookkeeping fields but substitute the
        # caller-supplied cell state.
        return (finished, first_inputs,
                wrapper_state._replace(cell_state=initial_state))