Beispiel #1
0
    def test_custom_cell(self):
        """Accept a custom `AbstractRNNCell` subclass as a valid cell.

        The cell's `output_size` property raises `ValueError` when read, so
        this test fails if `assert_like_rnncell` lets that error propagate
        while validating the cell.
        """

        class CustomCell(tf.keras.layers.AbstractRNNCell):
            """Cell whose `output_size` fails loudly if it is ever read."""

            @property
            def output_size(self):
                raise ValueError("assert_like_rnncell should not run code")

        booby_trapped_cell = CustomCell()
        keras_utils.assert_like_rnncell("cell", booby_trapped_cell)
Beispiel #2
0
    def __init__(self, cell, sampler, output_layer=None, **kwargs):
        """Validate the decoder components and store them on the instance.

        Args:
          cell: The recurrent cell to decode with; it is checked by
            `keras_utils.assert_like_rnncell`.
          sampler: A `sampler_py.Sampler` instance.
          output_layer: (Optional) A `tf.keras.layers.Layer` instance, e.g.
            `tf.keras.layers.Dense`, or `None` for no output layer.
          **kwargs: Other keyword arguments, forwarded unchanged to the parent
            class constructor.

        Raises:
          TypeError: if `sampler` is not a `Sampler`, or if `output_layer` is
            neither `None` nor a `tf.keras.layers.Layer`. The `cell` check is
            delegated to `keras_utils.assert_like_rnncell`.
        """
        # Validation runs in argument order: cell, then sampler, then layer.
        keras_utils.assert_like_rnncell("cell", cell)

        if not isinstance(sampler, sampler_py.Sampler):
            sampler_error = "sampler must be a Sampler, received: {}".format(
                sampler)
            raise TypeError(sampler_error)

        layer_is_acceptable = output_layer is None or isinstance(
            output_layer, tf.keras.layers.Layer)
        if not layer_is_acceptable:
            layer_error = "output_layer must be a Layer, received: {}".format(
                output_layer)
            raise TypeError(layer_error)

        # Attributes are assigned before the parent constructor runs, in the
        # same order as the arguments.
        self.cell = cell
        self.sampler = sampler
        self.output_layer = output_layer
        super().__init__(**kwargs)
Beispiel #3
0
    def __init__(
        self,
        cell,
        beam_width,
        output_layer=None,
        length_penalty_weight=0.0,
        coverage_penalty_weight=0.0,
        reorder_tensor_arrays=True,
        **kwargs
    ):
        """Validate and store the beam search configuration.

        Args:
          cell: The recurrent cell to decode with; it is checked by
            `keras_utils.assert_like_rnncell`.
          beam_width: Python integer, the number of beams.
          output_layer: (Optional) A `tf.keras.layers.Layer` instance, e.g.
            `tf.keras.layers.Dense`, applied to the RNN output prior to
            storing the result or sampling.
          length_penalty_weight: Float weight to penalize length. Disabled
            with 0.0.
          coverage_penalty_weight: Float weight to penalize the coverage of
            the source sentence. Disabled with 0.0.
          reorder_tensor_arrays: If `True`, `TensorArray` elements within the
            cell state will be reordered according to the beam search path.
            If the `TensorArray` can be reordered, the stacked form will be
            returned; otherwise it will be returned as is. Set this flag to
            `False` if the cell state contains `TensorArray`s that are not
            amenable to reordering.
          **kwargs: Other keyword arguments, forwarded unchanged to the parent
            class constructor.

        Raises:
          TypeError: if `output_layer` is neither `None` nor a
            `tf.keras.layers.Layer`. The `cell` check is delegated to
            `keras_utils.assert_like_rnncell`.
        """
        keras_utils.assert_like_rnncell("cell", cell)

        layer_is_acceptable = output_layer is None or isinstance(
            output_layer, tf.keras.layers.Layer
        )
        if not layer_is_acceptable:
            raise TypeError(
                "output_layer must be a Layer, received: %s" % type(output_layer)
            )

        # Components supplied by the caller.
        self._cell = cell
        self._output_layer = output_layer

        # Search configuration supplied by the caller.
        self._beam_width = beam_width
        self._length_penalty_weight = length_penalty_weight
        self._coverage_penalty_weight = coverage_penalty_weight
        self._reorder_tensor_arrays = reorder_tensor_arrays

        # Start out unset; nothing in this constructor gives them a value.
        self._start_tokens = None
        self._end_token = None
        self._batch_size = None

        super().__init__(**kwargs)
Beispiel #4
0
    def __init__(
        self,
        cell: tf.keras.layers.Layer,
        sampler: sampler_py.Sampler,
        output_layer: Optional[tf.keras.layers.Layer] = None,
        **kwargs
    ):
        """Check the cell and store the decoder components on the instance.

        Only `cell` is validated at runtime, through
        `keras_utils.assert_like_rnncell`; `sampler` and `output_layer` are
        stored as given.

        Args:
          cell: A layer that implements the `tf.keras.layers.AbstractRNNCell`
            interface.
          sampler: A `Sampler` instance.
          output_layer: (Optional) A `tf.keras.layers.Layer` instance, e.g.
            `tf.keras.layers.Dense`, applied to the RNN output prior to
            storing the result or sampling.
          **kwargs: Other keyword arguments, forwarded unchanged to the parent
            class constructor.
        """
        keras_utils.assert_like_rnncell("cell", cell)

        # Attributes are assigned before the parent constructor runs, in the
        # same order as the arguments.
        self.cell = cell
        self.sampler = sampler
        self.output_layer = output_layer

        super().__init__(**kwargs)
Beispiel #5
0
 def test_non_cell(self):
     """Expect `TypeError` when a `Dense` layer is passed as the cell."""
     with self.assertRaises(TypeError):
         dense_layer = tf.keras.layers.Dense(10)
         keras_utils.assert_like_rnncell("cell", dense_layer)
Beispiel #6
0
 def test_standard_cell(self):
     """Expect no exception when an `LSTMCell` is passed as the cell."""
     lstm_cell = tf.keras.layers.LSTMCell(10)
     keras_utils.assert_like_rnncell("cell", lstm_cell)
Beispiel #7
0
def test_non_cell():
    """Expect `TypeError` when a `Dense` layer is passed as the cell."""
    with pytest.raises(TypeError):
        dense_layer = tf.keras.layers.Dense(10)
        keras_utils.assert_like_rnncell("cell", dense_layer)