def test_custom_cell(self):
    """assert_like_rnncell must accept a cell without evaluating its properties.

    The `output_size` property below raises ValueError when read, so the
    check only passes if the validator does not access it.
    """

    class _PropertyTrapCell(tf.keras.layers.AbstractRNNCell):
        @property
        def output_size(self):
            raise ValueError("assert_like_rnncell should not run code")

    trap_cell = _PropertyTrapCell()
    keras_utils.assert_like_rnncell("cell", trap_cell)
def __init__(self, cell, sampler, output_layer=None, **kwargs):
    """Initialize BasicDecoder.

    Args:
      cell: An `RNNCell` instance.
      sampler: A `Sampler` instance.
      output_layer: (Optional) An instance of `tf.keras.layers.Layer`, i.e.,
        `tf.keras.layers.Dense`. Optional layer to apply to the RNN output
        prior to storing the result or sampling.
      **kwargs: Other keyword arguments for layer creation.

    Raises:
      TypeError: if `cell`, `sampler` or `output_layer` have an incorrect
        type.
    """
    # Raises TypeError when `cell` does not look like an RNN cell.
    keras_utils.assert_like_rnncell("cell", cell)
    if not isinstance(sampler, sampler_py.Sampler):
        raise TypeError(
            "sampler must be a Sampler, received: {}".format(sampler))
    if (output_layer is not None
            and not isinstance(output_layer, tf.keras.layers.Layer)):
        raise TypeError(
            "output_layer must be a Layer, received: {}".format(
                output_layer))
    self.cell = cell
    self.sampler = sampler
    self.output_layer = output_layer
    super().__init__(**kwargs)
def __init__(
    self,
    cell,
    beam_width,
    output_layer=None,
    length_penalty_weight=0.0,
    coverage_penalty_weight=0.0,
    reorder_tensor_arrays=True,
    **kwargs
):
    """Set up the shared state of a beam-search decoder.

    Args:
      cell: An `RNNCell` instance (validated with
        `keras_utils.assert_like_rnncell`).
      beam_width: Python integer, the number of beams.
      output_layer: (Optional) An instance of `tf.keras.layers.Layer`, i.e.,
        `tf.keras.layers.Dense`. Optional layer to apply to the RNN output
        prior to storing the result or sampling.
      length_penalty_weight: Float weight to penalize length. Disabled with
        0.0.
      coverage_penalty_weight: Float weight to penalize the coverage of the
        source sentence. Disabled with 0.0.
      reorder_tensor_arrays: If `True`, `TensorArray`s' elements within the
        cell state will be reordered according to the beam search path. If
        the `TensorArray` can be reordered, the stacked form will be
        returned. Otherwise, the `TensorArray` will be returned as is. Set
        this flag to `False` if the cell state contains `TensorArray`s that
        are not amenable to reordering.
      **kwargs: Dict, other keyword arguments forwarded to the parent class.

    Raises:
      TypeError: if `cell` does not look like an RNN cell, or `output_layer`
        is neither None nor a `tf.keras.layers.Layer`.
    """
    keras_utils.assert_like_rnncell("cell", cell)
    layer_ok = output_layer is None or isinstance(
        output_layer, tf.keras.layers.Layer
    )
    if not layer_ok:
        raise TypeError(
            "output_layer must be a Layer, received: {}".format(
                type(output_layer)
            )
        )

    # Assignment order is kept as-is: Keras tracks sub-layers in the order
    # they are set on the object.
    self._cell = cell
    self._output_layer = output_layer
    self._reorder_tensor_arrays = reorder_tensor_arrays
    # Placeholders, filled in later by the decoder.
    self._start_tokens = self._end_token = self._batch_size = None
    self._beam_width = beam_width
    self._length_penalty_weight = length_penalty_weight
    self._coverage_penalty_weight = coverage_penalty_weight
    super().__init__(**kwargs)
def __init__(
    self,
    cell: tf.keras.layers.Layer,
    sampler: sampler_py.Sampler,
    output_layer: Optional[tf.keras.layers.Layer] = None,
    **kwargs
):
    """Build a BasicDecoder from a cell, a sampler and an optional projection.

    Args:
      cell: A layer implementing the `tf.keras.layers.AbstractRNNCell`
        interface; checked with `keras_utils.assert_like_rnncell`.
      sampler: A `Sampler` instance.
      output_layer: (Optional) An instance of `tf.keras.layers.Layer`, e.g.
        `tf.keras.layers.Dense`, applied to the RNN output before it is
        stored or sampled.
      **kwargs: Further keyword arguments for layer creation.
    """
    keras_utils.assert_like_rnncell("cell", cell)
    # Tuple assignment sets cell, sampler, output_layer in that order.
    self.cell, self.sampler, self.output_layer = cell, sampler, output_layer
    super().__init__(**kwargs)
def test_non_cell(self):
    """A plain Dense layer is not an RNN cell and must be rejected."""
    with self.assertRaises(TypeError):
        dense_layer = tf.keras.layers.Dense(10)
        keras_utils.assert_like_rnncell("cell", dense_layer)
def test_standard_cell(self):
    """A stock LSTMCell passes the RNN-cell check without raising."""
    lstm_cell = tf.keras.layers.LSTMCell(10)
    keras_utils.assert_like_rnncell("cell", lstm_cell)
def test_non_cell():
    """assert_like_rnncell raises TypeError for a non-cell Dense layer."""
    with pytest.raises(TypeError):
        dense_layer = tf.keras.layers.Dense(10)
        keras_utils.assert_like_rnncell("cell", dense_layer)