Esempio n. 1
0
    def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
        """Call the RNN layer, routing `initial_state`/`constants` as inputs or kwargs.

        `inputs`, `initial_state` and `constants` are first normalized by
        `rnn_utils.standardize_args`. If neither `initial_state` nor
        `constants` is given, the parent `__call__` is invoked directly.

        If the extra tensors are Keras tensors, they are appended to the layer
        inputs. The layer's `input_spec` is temporarily extended with matching
        `InputSpec`s for the duration of the call. Otherwise the extra tensors
        are forwarded to the parent `__call__` as the `initial_state` /
        `constants` keyword arguments.

        Side effects: sets `self.state_spec` when `initial_state` is given, and
        sets `self.constants_spec` and `self._num_constants` when `constants`
        is given.

        Args:
          inputs: Layer input(s).
          initial_state: Optional initial state tensor(s).
          constants: Optional constant tensor(s).
          **kwargs: Forwarded to the parent `__call__`.

        Returns:
          Whatever the parent `__call__` returns.

        Raises:
          ValueError: If `initial_state`/`constants` mix Keras tensors and
            non-Keras tensors.
        """
        inputs, initial_state, constants = rnn_utils.standardize_args(
            inputs, initial_state, constants, self._num_constants)

        if initial_state is None and constants is None:
            return super(RNN, self).__call__(inputs, **kwargs)

        # If any of `initial_state` or `constants` are specified and are Keras
        # tensors, then add them to the inputs and temporarily modify the
        # input_spec to include them.

        additional_inputs = []
        additional_specs = []
        if initial_state is not None:
            additional_inputs += initial_state
            self.state_spec = tf.nest.map_structure(
                lambda s: InputSpec(shape=backend.int_shape(s)), initial_state)
            additional_specs += self.state_spec
        if constants is not None:
            additional_inputs += constants
            self.constants_spec = [
                InputSpec(shape=backend.int_shape(constant))
                for constant in constants
            ]
            self._num_constants = len(constants)
            additional_specs += self.constants_spec
        # additional_inputs can be empty if initial_state or constants are provided
        # but empty (e.g. the cell is stateless).
        flat_additional_inputs = tf.nest.flatten(additional_inputs)
        is_keras_tensor = backend.is_keras_tensor(
            flat_additional_inputs[0]) if flat_additional_inputs else True
        for tensor in flat_additional_inputs:
            if backend.is_keras_tensor(tensor) != is_keras_tensor:
                raise ValueError(
                    'The initial state or constants of an RNN layer cannot be '
                    'specified via a mix of Keras tensors and non-Keras tensors '
                    '(a "Keras tensor" is a tensor that was returned by a Keras layer '
                    ' or by `Input` during Functional model construction). '
                    f'Received: initial_state={initial_state}, constants={constants}'
                )

        if is_keras_tensor:
            # Compute the full input spec, including state and constants
            full_input = [inputs] + additional_inputs
            if self.built:
                # Keep the input_spec since it has been populated in build() method.
                full_input_spec = self.input_spec + additional_specs
            else:
                # The original input_spec is None since there could be a nested tensor
                # input. Update the input_spec to match the inputs.
                full_input_spec = generic_utils.to_list(
                    tf.nest.map_structure(lambda _: None,
                                          inputs)) + additional_specs
            # Perform the call with temporarily replaced input_spec
            self.input_spec = full_input_spec
            try:
                output = super(RNN, self).__call__(full_input, **kwargs)
            finally:
                # Remove the additional_specs from input spec and keep the rest. It is
                # important to keep since the input spec was populated by build(), and
                # will be reused in the stateful=True case. Guard against an empty
                # `additional_specs`: `spec[:-0]` is `spec[:0] == []`, which would
                # discard the whole spec. Runs in `finally` so a failed call does
                # not leave the temporary spec installed.
                if additional_specs:
                    self.input_spec = self.input_spec[:-len(additional_specs)]
            return output
        else:
            if initial_state is not None:
                kwargs['initial_state'] = initial_state
            if constants is not None:
                kwargs['constants'] = constants
            return super(RNN, self).__call__(inputs, **kwargs)
Esempio n. 2
0
    def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
        """`Bidirectional.__call__` implements the same API as the wrapped
        `RNN`.

        `initial_state` must contain an even number of states. The first half
        is assigned to the forward layer's `state_spec` and the second half to
        the backward layer's. `constants` specs and the constant count are
        shared by both sub-layers.

        If the extra tensors are Keras tensors, they are appended to the layer
        inputs. The layer's `input_spec` is temporarily replaced for the call
        and restored afterwards. Otherwise they are forwarded to the parent
        `__call__` via the `initial_state` / `constants` keyword arguments.

        Args:
          inputs: Layer input(s).
          initial_state: Optional list of forward + backward initial states.
          constants: Optional constant tensor(s).
          **kwargs: Forwarded to the parent `__call__`.

        Returns:
          Whatever the parent `__call__` returns.

        Raises:
          ValueError: If `initial_state` has an odd length, or if the extra
            tensors mix Keras tensors and non-Keras tensors.
        """
        inputs, initial_state, constants = rnn_utils.standardize_args(
            inputs, initial_state, constants, self._num_constants
        )

        if isinstance(inputs, list):
            if len(inputs) > 1:
                initial_state = inputs[1:]
            inputs = inputs[0]

        if initial_state is None and constants is None:
            return super().__call__(inputs, **kwargs)

        # Applies the same workaround as in `RNN.__call__`
        additional_inputs = []
        additional_specs = []
        if initial_state is not None:
            # Check if `initial_state` can be split into half
            num_states = len(initial_state)
            if num_states % 2 > 0:
                raise ValueError(
                    "When passing `initial_state` to a Bidirectional RNN, "
                    "the state should be a list containing the states of "
                    "the underlying RNNs. "
                    f"Received: {initial_state}"
                )

            kwargs["initial_state"] = initial_state
            additional_inputs += initial_state
            state_specs = tf.nest.map_structure(
                lambda state: InputSpec(shape=backend.int_shape(state)),
                initial_state,
            )
            # First half of the states feeds the forward layer, second half
            # the backward layer.
            self.forward_layer.state_spec = state_specs[: num_states // 2]
            self.backward_layer.state_spec = state_specs[num_states // 2 :]
            additional_specs += state_specs
        if constants is not None:
            kwargs["constants"] = constants
            additional_inputs += constants
            constants_spec = [
                InputSpec(shape=backend.int_shape(constant))
                for constant in constants
            ]
            self.forward_layer.constants_spec = constants_spec
            self.backward_layer.constants_spec = constants_spec
            additional_specs += constants_spec

            self._num_constants = len(constants)
            self.forward_layer._num_constants = self._num_constants
            self.backward_layer._num_constants = self._num_constants

        # `additional_inputs` may be empty when `initial_state` or `constants`
        # is passed as an empty list; indexing `[0]` would then raise
        # IndexError, so default to the Keras-tensor path in that case.
        flat_additional_inputs = tf.nest.flatten(additional_inputs)
        is_keras_tensor = (
            backend.is_keras_tensor(flat_additional_inputs[0])
            if flat_additional_inputs
            else True
        )
        for tensor in flat_additional_inputs:
            if backend.is_keras_tensor(tensor) != is_keras_tensor:
                raise ValueError(
                    "The initial state of a Bidirectional"
                    " layer cannot be specified with a mix of"
                    " Keras tensors and non-Keras tensors"
                    ' (a "Keras tensor" is a tensor that was'
                    " returned by a Keras layer, or by `Input`)"
                )

        if is_keras_tensor:
            # Compute the full input spec, including state
            full_input = [inputs] + additional_inputs
            # The original input_spec is None since there could be a nested
            # tensor input. Update the input_spec to match the inputs.
            full_input_spec = [
                None for _ in range(len(tf.nest.flatten(inputs)))
            ] + additional_specs
            # Removing kwargs since the value are passed with input list.
            kwargs["initial_state"] = None
            kwargs["constants"] = None

            # Perform the call with temporarily replaced input_spec; restore
            # it even if the call raises so the temporary spec does not leak.
            original_input_spec = self.input_spec
            self.input_spec = full_input_spec
            try:
                output = super().__call__(full_input, **kwargs)
            finally:
                self.input_spec = original_input_spec
            return output
        else:
            return super().__call__(inputs, **kwargs)