def _shape(batch_size, from_shape):
    """ Prepends the batch dimension to `from_shape`.
        :param batch_size: The batch size. It is converted to a tensor and the value reported by
                           `tensor_util.constant_value` is used as the leading dimension.
        :param from_shape: The shape to extend with a leading batch dimension.
        :return: `TensorShape(None)` when `from_shape` is not a `TensorShape` or has `ndims == 0`,
                 otherwise the shape `[batch_size]` concatenated with `from_shape`.
    """
    is_tensor_shape = isinstance(from_shape, tensor_shape.TensorShape)
    if not is_tensor_shape or from_shape.ndims == 0:
        return tensor_shape.TensorShape(None)

    batch_tensor = ops.convert_to_tensor(batch_size, name='batch_size')
    static_batch_size = tensor_util.constant_value(batch_tensor)
    batch_dim = tensor_shape.TensorShape([static_batch_size])
    return batch_dim.concatenate(from_shape)
Beispiel #2
0
 def state_size(self):
     """ The `state_size` property of `TransformerCell`.
         :return: A `TransformerCellState` holding the shape of the past attentions, the state size of the
                  feeder cell (an empty shape when there is no feeder cell), and an empty shape for `time`.
     """
     head_size = self._emb_size // self._nb_heads
     past_attentions_dims = [self._nb_layers, 2, self._nb_heads, None, head_size]

     # Without a feeder cell, the feeder state is described by an empty shape
     if self._feeder_cell is None:
         feeder_state_size = tensor_shape.TensorShape([])
     else:
         feeder_state_size = self._feeder_cell.state_size

     return TransformerCellState(past_attentions=tensor_shape.TensorShape(past_attentions_dims),
                                 feeder_state=feeder_state_size,
                                 time=tensor_shape.TensorShape([]))
Beispiel #3
0
 def state_size(self):
     """ Describes the size of every field of the wrapper's state.
         :return: A `SelfAttentionWrapperState` tuple with the wrapped cell's state size, an empty shape
                  for `time`, and the memory size.
     """
     cell_state_size = self._cell.state_size
     time_shape = tensor_shape.TensorShape([])
     memory_size = self._memory_size
     return SelfAttentionWrapperState(cell_state=cell_state_size, time=time_shape, memory=memory_size)
Beispiel #4
0
 def state_size(self):
     """ Describes the size of every field of the wrapper's state.
         :return: An `AttentionWrapperState` tuple containing the sizes used by this object. The alignments
                  and the attention state both use `self._memory_time`; the alignment history is empty.
     """
     cell_state_size = self._cell.state_size
     time_shape = tensor_shape.TensorShape([])
     attention_size = self._attention_layer_size
     alignments_size = self._memory_time
     attention_state_size = self._memory_time
     return AttentionWrapperState(cell_state=cell_state_size,
                                  time=time_shape,
                                  attention=attention_size,
                                  alignments=alignments_size,
                                  attention_state=attention_state_size,
                                  alignment_history=())
Beispiel #5
0
 def output_size(self):
     """ Returns the sizes of the cell output and of the sample id.
         :return: A `seq2seq.BasicDecoderOutput` (rnn_output, sample_id) when `extract_state` is falsy,
                  otherwise a `BasicDecoderWithStateOutput` that additionally carries `rnn_state`.
     """
     if not self.extract_state:
         return seq2seq.BasicDecoderOutput(rnn_output=self._rnn_output_size(),
                                           sample_id=self._helper.sample_ids_shape)

     # The state entry is described by a shape built from the cell's output size
     return BasicDecoderWithStateOutput(rnn_output=self._rnn_output_size(),
                                        rnn_state=tensor_shape.TensorShape([self._cell.output_size]),
                                        sample_id=self._helper.sample_ids_shape)
Beispiel #6
0
 def state_size(self):
     """ Describes the size of every field of the wrapper's state.
         :return: An `AttentionWrapperState` tuple containing the sizes used by this object. The per-mechanism
                  entries are combined with `self._item_or_tuple`.
     """
     mechanisms = self._attention_mechanisms

     # Generators are kept lazy: `_item_or_tuple` is the one consuming them
     alignments_sizes = (mechanism.alignments_size for mechanism in mechanisms)
     attention_state_sizes = (mechanism.state_size for mechanism in mechanisms)
     alignment_history_sizes = (() if not self._alignment_history else mechanism.alignments_size
                                for mechanism in mechanisms)

     return AttentionWrapperState(cell_state=self._cell.state_size,
                                  time=tensor_shape.TensorShape([]),
                                  attention=self._attention_layer_size,
                                  alignments=self._item_or_tuple(alignments_sizes),
                                  attention_state=self._item_or_tuple(attention_state_sizes),
                                  alignment_history=self._item_or_tuple(alignment_history_sizes))
Beispiel #7
0
    def output_size(self):
        """ Returns the size of the RNN output, taking the optional output layer into account.
            :return: The cell's `output_size` when there is no output layer, otherwise the structure of shapes
                     computed by the output layer with the leading (batch) dimension removed.
        """
        cell_output_size = self._cell.output_size
        if self._output_layer is None:
            return cell_output_size

        def _with_unknown_batch(shape):
            """ Prepends an unknown batch dimension to `shape` """
            return tensor_shape.TensorShape([None]).concatenate(shape)

        def _without_batch(shape):
            """ Drops the first (batch) dimension of `shape` """
            return shape[1:]

        # The layer's compute_output_shape works on batched shapes, so every entry of the cell's output size
        # gets an unknown batch dimension before the call, and that dimension is removed from the result.
        batched_shapes = nest.map_structure(_with_unknown_batch, cell_output_size)
        layer_shapes = self._output_layer.compute_output_shape(batched_shapes)
        return nest.map_structure(_without_batch, layer_shapes)
Beispiel #8
0
 def compute_output_shape(self, input_shape):
     """ Computes the output shape of the layer for the given input shape.
         :param input_shape: The shape of the input (converted with `TensorShape`). Its rank is checked
                             with `with_rank_at_least(2)`.
         :return: The input shape without its last dimension, concatenated with `self.output_size`.
     """
     checked_shape = tensor_shape.TensorShape(input_shape).with_rank_at_least(2)
     leading_dims = checked_shape[:-1]
     return leading_dims.concatenate(self.output_size)
Beispiel #9
0
 def state_size(self):
     """ The `state_size` property of `IdentityCell`.
         :return: A `TensorShape` built from an empty list of dimensions.
     """
     no_dims = []
     return tensor_shape.TensorShape(no_dims)
Beispiel #10
0
 def state_size(self):
     """ Describes the size of every field of the wrapper's state.
         :return: An `ArrayConcatWrapperState` tuple with the wrapped cell's state size and an empty shape
                  for `time`.
     """
     cell_state_size = self._cell.state_size
     time_shape = tensor_shape.TensorShape([])
     return ArrayConcatWrapperState(cell_state=cell_state_size, time=time_shape)
Beispiel #11
0
 def sample_ids_shape(self):
     """ Shape of the sample ids.
         :return: A `TensorShape` built from an empty list of dimensions.
     """
     no_dims = []
     return tensor_shape.TensorShape(no_dims)