Example 1
0
    def get_initial_state(self, inputs):
        """Build the initial state tensors for the wrapped cell.

        Args:
          inputs: Input tensor (or nested structure of tensors) whose batch
            dimension and dtype determine the shape/dtype of the state.

        Returns:
          A list of initial state tensors, even when the cell has a single
          state.
        """
        initial_state_builder = getattr(self.cell, "get_initial_state", None)

        if tf.nest.is_nested(inputs):
            # For nested input structures, the first flat element supplies
            # the batch size and dtype used to build the state.
            inputs = tf.nest.flatten(inputs)[0]

        shape = tf.shape(inputs)
        batch = shape[1] if self.time_major else shape[0]
        if initial_state_builder is None:
            init_state = rnn_utils.generate_zero_filled_state(
                batch, self.cell.state_size, inputs.dtype)
        else:
            init_state = initial_state_builder(inputs=None,
                                               batch_size=batch,
                                               dtype=inputs.dtype)
        # Keras RNN expects states as a list, even for a single state tensor.
        if not tf.nest.is_nested(init_state):
            init_state = [init_state]
        # list() also normalizes namedtuple states such as LSTMStateTuple.
        return list(init_state)
Example 2
0
    def reset_states(self, states=None):
        """Reset the recorded states for the stateful RNN layer.

    Can only be used when RNN layer is constructed with `stateful` = `True`.
    Args:
      states: Numpy arrays that contains the value for the initial state, which
        will be feed to cell at the first time step. When the value is None,
        zero filled numpy array will be created based on the cell state size.

    Raises:
      AttributeError: When the RNN layer is not stateful.
      ValueError: When the batch size of the RNN layer is unknown.
      ValueError: When the input numpy array is not compatible with the RNN
        layer state, either size wise or dtype wise.
    """
        if not self.stateful:
            raise AttributeError('Layer must be stateful.')
        spec_shape = None
        if self.input_spec is not None:
            spec_shape = tf.nest.flatten(self.input_spec[0])[0].shape
        if spec_shape is None:
            # It is possible to have spec shape to be None, eg when construct a RNN
            # with a custom cell, or standard RNN layers (LSTM/GRU) which we only know
            # it has 3 dim input, but not its full shape spec before build().
            batch_size = None
        else:
            # Batch dimension position depends on the layer's time-major layout.
            batch_size = spec_shape[1] if self.time_major else spec_shape[0]
        if not batch_size:
            raise ValueError('If a RNN is stateful, it needs to know '
                             'its batch size. Specify the batch size '
                             'of your input tensors: \n'
                             '- If using a Sequential model, '
                             'specify the batch size by passing '
                             'a `batch_input_shape` '
                             'argument to your first layer.\n'
                             '- If using the functional API, specify '
                             'the batch size by passing a '
                             '`batch_shape` argument to your Input layer.')
        # initialize state if None
        if tf.nest.flatten(self.states)[0] is None:
            if getattr(self.cell, 'get_initial_state', None):
                flat_init_state_values = tf.nest.flatten(
                    self.cell.get_initial_state(
                        inputs=None,
                        batch_size=batch_size,
                        # Use variable_dtype instead of compute_dtype, since the state is
                        # stored in a variable
                        dtype=self.variable_dtype or backend.floatx()))
            else:
                flat_init_state_values = tf.nest.flatten(
                    rnn_utils.generate_zero_filled_state(
                        batch_size, self.cell.state_size, self.variable_dtype
                        or backend.floatx()))
            # Wrap each initial value in a backend variable so the state
            # persists across batches, then restore the cell's structure.
            flat_states_variables = tf.nest.map_structure(
                backend.variable, flat_init_state_values)
            self.states = tf.nest.pack_sequence_as(self.cell.state_size,
                                                   flat_states_variables)
            if not tf.nest.is_nested(self.states):
                self.states = [self.states]
        elif states is None:
            # No explicit values supplied: zero out every existing state
            # variable, matching each state's declared size.
            for state, size in zip(tf.nest.flatten(self.states),
                                   tf.nest.flatten(self.cell.state_size)):
                backend.set_value(
                    state,
                    np.zeros([batch_size] + tf.TensorShape(size).as_list()))
        else:
            # Explicit values supplied: validate count and shapes, then
            # assign all states in one batched update.
            flat_states = tf.nest.flatten(self.states)
            flat_input_states = tf.nest.flatten(states)
            if len(flat_input_states) != len(flat_states):
                raise ValueError(
                    f'Layer {self.name} expects {len(flat_states)} '
                    f'states, but it received {len(flat_input_states)} '
                    f'state values. States received: {states}')
            set_value_tuples = []
            for i, (value,
                    state) in enumerate(zip(flat_input_states, flat_states)):
                if value.shape != state.shape:
                    # Bug fix: previously this interpolated `(batch_size, state)`,
                    # which printed the state Variable's repr instead of the
                    # expected shape. Report the actual expected shape.
                    raise ValueError(
                        f'State {i} is incompatible with layer {self.name}: '
                        f'expected shape={state.shape} '
                        f'but found shape={value.shape}')
                set_value_tuples.append((state, value))
            backend.batch_set_value(set_value_tuples)