def _shape(batch_size, from_shape):
    """Prepend the (statically known) batch size to ``from_shape``.

    If ``from_shape`` is not a ``TensorShape`` or is a scalar shape, the
    result is fully unknown; otherwise the batch dimension (resolved to a
    constant when possible) is concatenated in front of ``from_shape``.
    """
    # A scalar or non-TensorShape input carries no usable dimensions.
    if not isinstance(from_shape, tensor_shape.TensorShape) or from_shape.ndims == 0:
        return tensor_shape.TensorShape(None)
    # Try to resolve the batch size to a static value (None if dynamic).
    static_batch = tensor_util.constant_value(
        ops.convert_to_tensor(batch_size, name='batch_size'))
    return tensor_shape.TensorShape([static_batch]).concatenate(from_shape)
def state_size(self):
    """The `state_size` property of `TransformerCell`.

    Returns a `TransformerCellState` of shapes: the cached attention
    tensors, the feeder cell state (scalar shape if no feeder), and a
    scalar time step.
    """
    # Per-head size of the cached key/value attention tensors.
    head_dim = self._emb_size // self._nb_heads
    # [layers, 2 (key/value), heads, time (unknown), head_dim]
    attns_shape = tensor_shape.TensorShape(
        [self._nb_layers, 2, self._nb_heads, None, head_dim])
    if self._feeder_cell is None:
        feeder_shape = tensor_shape.TensorShape([])
    else:
        feeder_shape = self._feeder_cell.state_size
    return TransformerCellState(past_attentions=attns_shape,
                                feeder_state=feeder_shape,
                                time=tensor_shape.TensorShape([]))
def state_size(self):
    """The `state_size` property of `AttentionWrapper`.

    :return: A `SelfAttentionWrapperState` tuple containing shapes used by this object.
    """
    # Time is tracked as a scalar counter.
    scalar_shape = tensor_shape.TensorShape([])
    return SelfAttentionWrapperState(cell_state=self._cell.state_size,
                                     time=scalar_shape,
                                     memory=self._memory_size)
def state_size(self):
    """The `state_size` property of `AttentionWrapper`.

    :return: An `AttentionWrapperState` tuple containing shapes used by this object.
    """
    # Alignments and attention state share the memory-time dimension;
    # alignment history is disabled (empty tuple).
    memory_time = self._memory_time
    return AttentionWrapperState(cell_state=self._cell.state_size,
                                 time=tensor_shape.TensorShape([]),
                                 attention=self._attention_layer_size,
                                 alignments=memory_time,
                                 attention_state=memory_time,
                                 alignment_history=())
def output_size(self):
    """Returns the decoder output sizes (cell output, optional cell state, and sample ids)."""
    rnn_size = self._rnn_output_size()
    ids_shape = self._helper.sample_ids_shape
    # Without state extraction, fall back to the plain seq2seq output tuple.
    if not self.extract_state:
        return seq2seq.BasicDecoderOutput(rnn_output=rnn_size,
                                          sample_id=ids_shape)
    return BasicDecoderWithStateOutput(
        rnn_output=rnn_size,
        rnn_state=tensor_shape.TensorShape([self._cell.output_size]),
        sample_id=ids_shape)
def state_size(self):
    """Returns an `AttentionWrapperState` tuple containing shapes used by this object."""
    mechanisms = self._attention_mechanisms
    keep_history = self._alignment_history
    # NOTE: generator expressions are passed through as-is, matching what
    # `_item_or_tuple` expects for single vs. multiple mechanisms.
    return AttentionWrapperState(
        cell_state=self._cell.state_size,
        time=tensor_shape.TensorShape([]),
        attention=self._attention_layer_size,
        alignments=self._item_or_tuple(
            a.alignments_size for a in mechanisms),
        attention_state=self._item_or_tuple(
            a.state_size for a in mechanisms),
        alignment_history=self._item_or_tuple(
            a.alignments_size if keep_history else () for a in mechanisms))
def output_size(self):
    """Returns the size of the RNN output (after the optional output layer)."""
    if self._output_layer is None:
        return self._cell.output_size

    # To use the layer's compute_output_shape, convert each of the cell's
    # output_size entries into a shape with an unknown batch dimension,
    # run it through the layer, then strip the batch dimension back off.
    def _prepend_unknown_batch(shape):
        return tensor_shape.TensorShape([None]).concatenate(shape)

    shapes_with_batch = nest.map_structure(_prepend_unknown_batch,
                                           self._cell.output_size)
    layer_shapes = self._output_layer.compute_output_shape(shapes_with_batch)
    return nest.map_structure(lambda s: s[1:], layer_shapes)
def compute_output_shape(self, input_shape):
    """Computes the output shape of the given layer.

    The last input dimension is replaced by this layer's output size;
    the input must have rank >= 2.
    """
    shape = tensor_shape.TensorShape(input_shape).with_rank_at_least(2)
    return shape[:-1].concatenate(self.output_size)
def state_size(self):
    """The `state_size` property of `IdentityCell`."""
    # The identity cell keeps no state beyond a scalar placeholder.
    empty_state = tensor_shape.TensorShape([])
    return empty_state
def state_size(self):
    """The `state_size` property of `ArrayConcatWrapper`.

    :return: An `ArrayConcatWrapperState` tuple containing shapes used by this object.
    """
    scalar_shape = tensor_shape.TensorShape([])
    return ArrayConcatWrapperState(cell_state=self._cell.state_size,
                                   time=scalar_shape)
def sample_ids_shape(self):
    """Returns the shape of the sample ids (scalar per batch entry)."""
    scalar_shape = tensor_shape.TensorShape([])
    return scalar_shape