Example #1
  def call(self, observation, step_type=None, network_state=None):
    num_outer_dims = nest_utils.get_outer_rank(observation,
                                               self.observation_spec)
    if num_outer_dims not in (1, 2):
      raise ValueError(
          'Input observation must have a batch or batch x time outer shape.')

    has_time_dim = num_outer_dims == 2
    if not has_time_dim:
      # Add a time dimension to the inputs.
      observation = nest.map_structure(lambda t: tf.expand_dims(t, 1),
                                       observation)
      step_type = nest.map_structure(lambda t: tf.expand_dims(t, 1), step_type)

    states = tf.to_float(nest.flatten(observation)[0])
    batch_squash = utils.BatchSquash(2)  # Squash B, and T dims.
    states = batch_squash.flatten(states)

    for layer in self._input_layers:
      states = layer(states)

    states = batch_squash.unflatten(states)

    with tf.name_scope('reset_mask'):
      reset_mask = tf.equal(step_type, time_step.StepType.FIRST)
    # Unroll over the time sequence.
    states, network_state, _ = rnn_utils.dynamic_unroll(
        self._cell,
        states,
        reset_mask,
        initial_state=network_state,
        dtype=tf.float32)

    states = batch_squash.flatten(states)

    for layer in self._output_layers:
      states = layer(states)

    value = self._value_projection_layer(states)
    value = tf.reshape(value, [-1])
    value = batch_squash.unflatten(value)
    return value, network_state
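
Every example on this page leans on the same contract: BatchSquash(batch_dims) merges the outer batch (and time) dimensions before per-step layers run and restores them afterwards. The class below is only a rough illustrative stand-in for that contract built from plain reshapes, assuming TF2 eager execution; it is not the tf_agents.utils.BatchSquash implementation.

import tensorflow as tf

class SimpleBatchSquash:
  """Illustrative stand-in: merges the outer `batch_dims` axes and restores them."""

  def __init__(self, batch_dims):
    self._batch_dims = batch_dims
    self._outer_shape = None

  def flatten(self, t):
    # Remember the outer shape (e.g. [B, T]) and merge it into one dimension.
    self._outer_shape = tf.shape(t)[:self._batch_dims]
    inner_shape = tf.shape(t)[self._batch_dims:]
    return tf.reshape(t, tf.concat([[-1], inner_shape], axis=0))

  def unflatten(self, t):
    # Restore the saved outer shape: [B*T, ...] -> [B, T, ...].
    return tf.reshape(t, tf.concat([self._outer_shape, tf.shape(t)[1:]], axis=0))

# Usage: squash [B, T, ...] -> [B*T, ...] around per-step layers.
x = tf.random.uniform([4, 7, 3])        # [B=4, T=7, features=3]
squash = SimpleBatchSquash(2)           # squash the B and T dims
flat = squash.flatten(x)                # shape [28, 3]
flat = tf.keras.layers.Dense(5)(flat)   # any per-step computation
restored = squash.unflatten(flat)       # back to shape [4, 7, 5]
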
Example #2
    def call(self, observation, step_type=None, network_state=()):
        del step_type  # unused.

        if self._batch_squash:
            outer_rank = nest_utils.get_outer_rank(observation,
                                                   self.observation_spec)
            batch_squash = utils.BatchSquash(outer_rank)

        # Get single observation out regardless of nesting.
        states = tf.cast(nest.flatten(observation)[0], tf.float32)

        if self._batch_squash:
            states = batch_squash.flatten(states)

        for layer in self.layers:
            states = layer(states)

        if self._batch_squash:
            states = batch_squash.unflatten(states)
        return states, network_state
Example #3
    def call(self, observation, step_type, network_state=None):
        outer_rank = nest_utils.get_outer_rank(observation,
                                               self.input_tensor_spec)
        batch_squash = utils.BatchSquash(outer_rank)

        observation, network_state = self._lstm_encoder(
            observation, step_type=step_type, network_state=network_state)

        states = batch_squash.flatten(observation)

        actions = []
        for layer, spec in zip(self._action_layers, self._flat_action_spec):
            action = layer(states)
            action = common.scale_to_spec(action, spec)
            action = batch_squash.unflatten(action)
            actions.append(action)

        output_actions = tf.nest.pack_sequence_as(self._output_tensor_spec,
                                                  actions)
        return output_actions, network_state
Example #4
    def call(self,
             observations,
             step_type=(),
             network_state=(),
             training=False,
             mask=None):
        outer_rank = nest_utils.get_outer_rank(observations,
                                               self.input_tensor_spec)
        # We use batch_squash here in case the observations have a time sequence
        # component.
        batch_squash = utils.BatchSquash(outer_rank)
        observations = tf.nest.map_structure(batch_squash.flatten,
                                             observations)

        # we ignore next_state from the encoder
        state, network_state = self.encoder(observations,
                                            step_type=step_type,
                                            network_state=network_state)
        l_hidden = self.selector(state)
        selection_vector = tf.transpose(l_hidden, perm=[0, 2, 1])

        options = [option(state) for option in self.options]
        options = tf.stack(options)
        options = tf.transpose(options, perm=[1, 0, 2])

        # select an option using the selection_vector from the master network
        state = tf.matmul(selection_vector, options)

        state = batch_squash.unflatten(state)

        def call_projection_net(proj_net):
            distribution, _ = proj_net(state,
                                       outer_rank,
                                       training=training,
                                       mask=mask)
            return distribution

        output_actions = tf.nest.map_structure(call_projection_net,
                                               self.projection_nets)

        return output_actions, network_state
Example #5
    def call(self,
             observation,
             step_type=None,
             network_state=(),
             training=False):
        del step_type  # unused.

        if self._batch_squash:
            outer_rank = nest_utils.get_outer_rank(observation,
                                                   self.input_tensor_spec)
            batch_squash = utils.BatchSquash(outer_rank)
            observation = tf.nest.map_structure(batch_squash.flatten,
                                                observation)

        if self._flat_preprocessing_layers is None:
            processed = observation
        else:
            processed = []
            for obs, layer in zip(
                    nest.flatten_up_to(self._preprocessing_nest,
                                       observation,
                                       check_types=False),
                    self._flat_preprocessing_layers):
                processed.append(layer(obs, training=training))
            if len(processed) == 1 and self._preprocessing_combiner is None:
                # If only one observation is passed and the preprocessing_combiner
                # is unspecified, use the preprocessed version of this observation.
                processed = processed[0]

        states = processed

        if self._preprocessing_combiner is not None:
            states = self._preprocessing_combiner(states)

        for layer in self._postprocessing_layers:
            states = layer(states, training=training)

        if self._batch_squash:
            states = tf.nest.map_structure(batch_squash.unflatten, states)

        return states, network_state
Example #6
    def call(self, observations, step_type=(), network_state=()):
        outer_rank = nest_utils.get_outer_rank(observations,
                                               self.input_tensor_spec)
        # batch_squash, in case observations have a time sequence component.
        batch_squash = utils.BatchSquash(outer_rank)
        observations = tf.nest.map_structure(batch_squash.flatten,
                                             observations)

        flat_x = self._flat_x(observations)
        cnn_1 = self._cnn_1(observations)
        cnn_2 = self._cnn_2(cnn_1)
        flat_cnn = self._flat_cnn(cnn_2)
        concat_cnn_x = tf.keras.layers.concatenate([flat_x, flat_cnn])
        dense_1 = self._dense_1(concat_cnn_x)
        dense_2 = self._dense_2(dense_1)
        concat_dense_x = tf.keras.layers.concatenate([flat_x, dense_2])
        # policy_dense_1 = self._policy_dense_1(concat_dense_x)
        actions = self._action_projection_layer(concat_dense_x)

        return tf.nest.pack_sequence_as(self._action_spec,
                                        [actions]), network_state
Example #7
    def call(self, inputs, step_type=(), network_state=(), training=False):
        observation, action = inputs
        observation_spec, _ = self.input_tensor_spec
        num_outer_dims = nest_utils.get_outer_rank(observation,
                                                   observation_spec)
        has_time_dim = num_outer_dims == 2

        if has_time_dim:
            batch_squash = utils.BatchSquash(2)  # Squash B, and T dims.
            # Flatten: [B, T, ...] -> [BxT, ...]
            observation = batch_squash.flatten(observation)
            action = batch_squash.flatten(action)
        q_value, network_state = super(CriticNet, self).call(
            (observation, action),
            step_type=step_type,
            network_state=network_state,
            training=training)
        if has_time_dim:
            q_value = batch_squash.unflatten(
                q_value)  # [B x T, ...] -> [B, T, ...]
        return q_value, network_state
Example #8
  def call(self, observations, step_type, network_state):
    del step_type  # unused.
    outer_rank = nest_utils.get_outer_rank(observations, self._observation_spec)
    observations = nest.flatten(observations)
    states = tf.to_float(observations[0])

    # Reshape to only a single batch dimension for neural network functions.
    batch_squash = utils.BatchSquash(outer_rank)
    states = batch_squash.flatten(states)

    for layer in self._mlp_layers:
      states = layer(states)

    # TODO(oars): Can we avoid unflattening to flatten again
    states = batch_squash.unflatten(states)
    outputs = [
        projection(states, outer_rank)
        for projection in self._projection_networks
    ]

    return nest.pack_sequence_as(self._action_spec, outputs), network_state
Example #9
    def call(self,
             observations,
             step_type=(),
             network_state=(),
             training=False,
             mask=None):
        outer_rank = nest_utils.get_outer_rank(observations,
                                               self.input_tensor_spec)
        batch_squash = utils.BatchSquash(outer_rank)
        observations = tf.nest.map_structure(batch_squash.flatten,
                                             observations)

        state, network_state = self._encoder(observations,
                                             step_type=step_type,
                                             network_state=network_state)

        actions = self._action_projection_layer(state)
        actions = common_utils.scale_to_spec(actions, self._single_action_spec)
        actions = batch_squash.unflatten(actions)
        return tf.nest.pack_sequence_as(self._action_spec,
                                        [actions]), network_state
Example #10
    def call(self, inputs, outer_rank):
        # outer_rank is needed because the projection is not done on the raw
        # observations so getting the outer rank is hard as there is no spec to
        # compare to.
        batch_squash = utils.BatchSquash(outer_rank)
        inputs = batch_squash.flatten(inputs)

        means = self._projection_layer(inputs)
        means = tf.reshape(means, [-1] + self._output_spec.shape.as_list())
        means = self._mean_transform(means, self._output_spec)
        means = tf.cast(means, self._output_spec.dtype)

        stds = self._bias(tf.zeros_like(means))
        stds = tf.reshape(stds, [-1] + self._output_spec.shape.as_list())
        stds = self._std_transform(stds)
        stds = tf.cast(stds, self._output_spec.dtype)

        means = batch_squash.unflatten(means)
        stds = batch_squash.unflatten(stds)

        return tfp.distributions.Normal(means, stds)
Example #11
    def call(self, observations, step_type=(), network_state=()):
        outer_rank = nest_utils.get_outer_rank(observations,
                                               self.input_tensor_spec)
        # We use batch_squash here in case the observations have a time sequence
        # component.
        batch_squash = utils.BatchSquash(outer_rank)
        observations = tf.nest.map_structure(batch_squash.flatten,
                                             observations)

        state, network_state = self._encoder(observations,
                                             step_type=step_type,
                                             network_state=network_state)
        actions = self._action_projection_layer(state)
        actions = common_utils.scale_to_spec(actions, self._action_spec)
        actions = batch_squash.unflatten(actions)
        return tf.nest.pack_sequence_as(self._action_spec,
                                        [actions]), network_state


####ACTOR TEST####
#action_spec = array_spec.BoundedArraySpec((6,), np.float32, minimum=0, maximum=10)
#observation_spec = array_spec.BoundedArraySpec((64, 64, 3), np.float32, minimum=0,
#                                        maximum=255)
#
#random_env = random_py_environment.RandomPyEnvironment(observation_spec, action_spec=action_spec)
#
## Convert the environment to a TFEnv to generate tensors.
#tf_env = tf_py_environment.TFPyEnvironment(random_env)
#
##preprocessing_layers = {
##    'image': tf.keras.models.Sequential([tf.keras.layers.Conv2D(8, 4),
##                                        tf.keras.layers.Flatten()]),
##    'vector': tf.keras.layers.Dense(5)
##    }
##preprocessing_combiner = tf.keras.layers.Concatenate(axis=-1)
#actor = ActorNetwork(tf_env.observation_spec(),
#                     tf_env.action_spec())
#
#time_step = tf_env.reset()
##print(actor(time_step.observation,time_step.step_type))
Example #12
    def call(self, inputs: tf.Tensor, batch_dims: int) \
        -> Tuple[tfp.distributions.OneHotCategorical, Tuple]:
        """
        Maps from a shared layer of hidden activations of the overall action net (inputs) to a
        distribution over actions for this head alone.

        :param inputs: The hidden activation from the final shared layer of the action network.
        :param batch_dims: The number of batch dimensions in the inputs.
        :return: A (OneHotCategorical) distribution over actions for this head and the network
            state (an empty tuple as this network is not stateful).
        """
        # outer_rank is needed because the projection is not done on the raw observations so getting
        # the outer rank is hard as there is no spec to compare to.
        # BatchSquash is used to flatten and unflatten a tensor caching the original batch
        # dimension(s).
        batch_squash = network_utils.BatchSquash(batch_dims)
        # We project the logits via a linear transformation to the right dimension for the action
        # head.
        inputs = batch_squash.flatten(inputs)
        logits = self._projection_layer(inputs)
        # We finally return the appropriate TensorFlow distribution and the (empty) network state.
        return self.output_spec.build_distribution(logits=logits), ()
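
For a rough sense of what a discrete head like this returns, the snippet below builds the same kind of OneHotCategorical distribution directly with TensorFlow Probability; the action count, hidden width, and projection layer are made-up placeholders rather than this network's actual configuration.

import tensorflow as tf
import tensorflow_probability as tfp

num_actions = 4                                         # hypothetical action count
projection_layer = tf.keras.layers.Dense(num_actions)   # linear projection to logits

hidden = tf.random.normal([32, 128])    # [batch, shared hidden width]
logits = projection_layer(hidden)       # [32, 4]
dist = tfp.distributions.OneHotCategorical(logits=logits, dtype=tf.int32)
sample = dist.sample()                  # one-hot action, shape [32, 4]
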
Example #13
    def call(self, inputs, step_type=(), network_state=()):
        outer_rank = nest_utils.get_outer_rank(inputs, self.input_tensor_spec)
        batch_squash = utils.BatchSquash(outer_rank)

        observations, actions = inputs
        observations, network_state = self._encoder(
            observations, step_type=step_type, network_state=network_state)

        observations = batch_squash.flatten(observations)

        actions = tf.cast(tf.nest.flatten(actions)[0], tf.float32)
        actions = batch_squash.flatten(actions)
        for layer in self._action_layers:
            actions = layer(actions)

        joint = tf.concat([observations, actions], -1)
        for layer in self._joint_layers:
            joint = layer(joint)

        q_value = tf.reshape(joint, [-1])
        q_value = batch_squash.unflatten(q_value)
        return q_value, network_state
Example #14
    def call(self, inputs, outer_rank):
        if inputs.dtype != self._sample_spec.dtype:
            raise ValueError(
                'Inputs to NormalProjectionNetwork must match the sample_spec.dtype.'
            )
        # outer_rank is needed because the projection is not done on the raw
        # observations so getting the outer rank is hard as there is no spec to
        # compare to.
        batch_squash = network_utils.BatchSquash(outer_rank)
        inputs = batch_squash.flatten(inputs)

        means = self._means_projection_layer(inputs)
        means = tf.reshape(means, [-1] + self._sample_spec.shape.as_list())

        if self._state_dependent_std:
            stds = self._stddev_projection_layer(inputs)
        else:
            stds = self._bias(tf.zeros_like(means))
            stds = tf.reshape(stds, [-1] + self._sample_spec.shape.as_list())

        inv_stds = self._std_transform(stds)
        if self._max_std is not None:
            inv_stds += 1 / (self._max_std - self._min_std)
        stds = 1. / inv_stds
        if self._min_std > 0:
            stds += self._min_std
        stds = tf.cast(stds, self._sample_spec.dtype)

        means = means * stds

        # If not scaling the distribution later, use a normalized mean.
        if not self._scale_distribution and self._mean_transform is not None:
            means = self._mean_transform(means, self._sample_spec)
        means = tf.cast(means, self._sample_spec.dtype)

        means = batch_squash.unflatten(means)
        stds = batch_squash.unflatten(stds)

        return self.output_spec.build_distribution(loc=means, scale=stds)
Example #15
    def call(self, inputs, outer_rank, training=False, mask=None):
        if inputs.dtype != self._sample_spec.dtype:
            raise ValueError(
                'Inputs to NormalProjectionNetwork must match the sample_spec.dtype.'
            )

        if mask is not None:
            raise NotImplementedError(
                'NormalProjectionNetwork does not yet implement action masking; got '
                'mask={}'.format(mask))

        # outer_rank is needed because the projection is not done on the raw
        # observations so getting the outer rank is hard as there is no spec to
        # compare to.
        batch_squash = network_utils.BatchSquash(outer_rank)
        inputs = batch_squash.flatten(inputs)

        means = self._means_projection_layer(inputs, training=training)
        means = tf.reshape(means, [-1] + self._sample_spec.shape.as_list())

        # If not scaling the distribution later, use a normalized mean.
        if not self._scale_distribution and self._mean_transform is not None:
            means = self._mean_transform(means, self._sample_spec)
        means = tf.cast(means, self._sample_spec.dtype)

        if self._state_dependent_std:
            stds = self._stddev_projection_layer(inputs, training=training)
        else:
            stds = self._bias(tf.zeros_like(means), training=training)
            stds = tf.reshape(stds, [-1] + self._sample_spec.shape.as_list())

        if self._std_transform is not None:
            stds = self._std_transform(stds)
        stds = tf.cast(stds, self._sample_spec.dtype)

        means = batch_squash.unflatten(means)
        stds = batch_squash.unflatten(stds)

        return self.output_spec.build_distribution(loc=means, scale=stds), ()
Example #16
    def call(self, inputs, step_type, network_state=None):
        outer_rank = nest_utils.get_outer_rank(inputs, self.input_tensor_spec)
        batch_squash = utils.BatchSquash(outer_rank)  # Squash B, and T dims.

        observation, action = inputs
        observation, _ = self._obs_encoder(observation,
                                           step_type=step_type,
                                           network_state=network_state)

        output, network_state = self._lstm_encoder(inputs=(observation,
                                                           action),
                                                   step_type=step_type,
                                                   network_state=network_state)

        output = batch_squash.flatten(output)
        for layer in self._output_layers:
            output = layer(output)

        q_value = tf.reshape(output, [-1])
        q_value = batch_squash.unflatten(q_value)

        return q_value, network_state
Example #17
  def call(self, inputs, unused_step_type=None, network_state=()):
    hidden_state = tf.cast(tf.nest.flatten(inputs), tf.float32)[0]

    # Calls coming from agent.train() have a time dimension, while direct loss
    # calls may not. In order to make BatchSquash work, we need to specify the
    # outer rank properly.
    has_time_dim = nest_utils.get_outer_rank(inputs,
                                             self.input_tensor_spec) == 2
    outer_rank = 2 if has_time_dim else 1
    batch_squash = network_utils.BatchSquash(outer_rank)
    hidden_state = batch_squash.flatten(hidden_state)

    for layer in self.layers:
      hidden_state = layer(hidden_state)

    actions, stdevs = tf.split(hidden_state, 2, axis=1)
    actions = batch_squash.unflatten(actions)
    stdevs = batch_squash.unflatten(stdevs)
    actions = tf.nest.pack_sequence_as(self._action_spec, [actions])
    stdevs = tf.nest.pack_sequence_as(self._action_spec, [stdevs])

    return self.output_spec.build_distribution(
        loc=actions, scale=stdevs), network_state
Example #18
 def call(self,
          observations,
          step_type=(),
          network_state=(),
          training=False):
     num_outer_dims = nest_utils.get_outer_rank(observations,
                                                self.input_tensor_spec)
     has_time_dim = num_outer_dims == 2
     if has_time_dim:
         batch_squash = utils.BatchSquash(2)  # Squash B, and T dims.
         # Flattening: [B, T, ...] -> [BxT, ...]
         observations = batch_squash.flatten(observations)
     z = self._z_encoder(observations, training=training)
     z = z.sample()
     if has_time_dim:
         z = batch_squash.unflatten(z)
     self._input_tensor_spec = self._z_spec
     output = super(RecurrentActorNet,
                    self).call(z,
                               step_type=step_type,
                               network_state=network_state,
                               training=training)
     self._input_tensor_spec = self._s_spec
     return output
Example #19
  def call(self, inputs, outer_rank):
    if inputs.dtype != self._sample_spec.dtype:
      raise ValueError(
          'Inputs to NormalProjectionNetwork must match the sample_spec.dtype.')
    # outer_rank is needed because the projection is not done on the raw
    # observations so getting the outer rank is hard as there is no spec to
    # compare to.
    batch_squash = utils.BatchSquash(outer_rank)
    inputs = batch_squash.flatten(inputs)

    means = self._projection_layer(inputs)
    means = tf.reshape(means, [-1] + self._sample_spec.shape.as_list())
    means = self._mean_transform(means, self._sample_spec)
    means = tf.cast(means, self._sample_spec.dtype)

    stds = self._bias(tf.zeros_like(means))
    stds = tf.reshape(stds, [-1] + self._sample_spec.shape.as_list())
    stds = self._std_transform(stds)
    stds = tf.cast(stds, self._sample_spec.dtype)

    means = batch_squash.unflatten(means)
    stds = batch_squash.unflatten(stds)

    return self.output_spec.build_distribution(loc=means, scale=stds)
Example #20
def normal(inputs,
           output_spec,
           outer_rank=1,
           projection_layer=default_fully_connected,
           mean_transform=tanh_squash_to_spec,
           std_initializer=tf.zeros_initializer(),
           std_transform=tf.exp,
           distribution_cls=tfp.distributions.Normal):
    """Project a batch of inputs to a batch of means and standard deviations.

  Given an output spec for a single tensor continuous action, produces a
  neural net layer converting inputs to a normal distribution matching
  the spec.  The mean is derived from a fully connected linear layer as
  mean_transform(layer_output, output_spec).  The std is fixed to a single
  trainable tensor (thus independent of the inputs).  Specifically, std is
  parameterized as std_transform(variable).

  Args:
    inputs: An input Tensor of shape [batch_size, ?].
    output_spec: An output spec (either BoundedArraySpec or BoundedTensorSpec).
    outer_rank: The number of outer dimensions of inputs to consider batch
      dimensions and to treat as batch dimensions of output distribution.
    projection_layer: Function taking in inputs, num_elements, scope and
      returning a projection of inputs to a Tensor of width num_elements.
    mean_transform: A function taking in layer output and the output_spec,
      returning the means.  Defaults to tanh_squash_to_spec.
    std_initializer: Initializer for std_dev variables.
    std_transform: The function applied to the trainable std variable. For
      example, tf.exp (default), tf.nn.softplus.
    distribution_cls: The distribution class to use for output distribution.
      Default is tfp.distributions.Normal.

  Returns:
    A `distribution_cls` instance (by default `tfp.distributions.Normal`) in
      which the standard deviation does not depend on the input.

  Raises:
    ValueError: If output_spec is invalid.
  """
    if not tensor_spec.is_bounded(output_spec):
        raise ValueError('Input output_spec is of invalid type '
                         '%s.' % type(output_spec))
    if not tensor_spec.is_continuous(output_spec):
        raise ValueError('Output is not continuous.')

    batch_squash = utils.BatchSquash(outer_rank)
    inputs = batch_squash.flatten(inputs)
    means = projection_layer(inputs,
                             output_spec.shape.num_elements(),
                             scope='means')
    stds = tf.contrib.layers.bias_add(
        tf.zeros_like(means),  # Independent of inputs.
        initializer=std_initializer,
        scope='stds',
        activation_fn=None)

    means = tf.reshape(means, [-1] + output_spec.shape.as_list())
    means = mean_transform(means, output_spec)
    means = tf.cast(means, output_spec.dtype)

    stds = tf.reshape(stds, [-1] + output_spec.shape.as_list())
    stds = std_transform(stds)
    stds = tf.cast(stds, output_spec.dtype)

    means, stds = batch_squash.unflatten(means), batch_squash.unflatten(stds)
    return distribution_cls(means, stds)
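
The normal() helper above is written against TF1-era APIs (tf.contrib.layers). The sketch below reproduces the same idea, a tanh-squashed mean plus an input-independent standard deviation parameterized as std_transform(variable), using current TF2/TFP primitives; the sizes and variable names are illustrative assumptions only.

import tensorflow as tf
import tensorflow_probability as tfp

action_dim = 6                                   # hypothetical output_spec shape
means_layer = tf.keras.layers.Dense(action_dim)  # projection_layer analogue
std_var = tf.Variable(tf.zeros([action_dim]))    # std_initializer analogue

inputs = tf.random.normal([32, 64])              # [batch_size, features]
means = tf.tanh(means_layer(inputs))             # mean_transform analogue (squash)
stds = tf.exp(std_var)                           # std_transform=tf.exp, input-independent
dist = tfp.distributions.Normal(loc=means, scale=stds)
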
Example #21
    def call(self, inputs, step_type, network_state=(), training=False):
        observation, action = inputs
        observation_spec, _ = self.input_tensor_spec
        num_outer_dims = nest_utils.get_outer_rank(observation,
                                                   observation_spec)
        if num_outer_dims not in (1, 2):
            raise ValueError(
                'Input observation must have a batch or batch x time outer shape.'
            )

        has_time_dim = num_outer_dims == 2
        if not has_time_dim:
            # Add a time dimension to the inputs.
            observation = tf.nest.map_structure(lambda t: tf.expand_dims(t, 1),
                                                observation)
            action = tf.nest.map_structure(lambda t: tf.expand_dims(t, 1),
                                           action)
            step_type = tf.nest.map_structure(lambda t: tf.expand_dims(t, 1),
                                              step_type)

        observation = tf.cast(tf.nest.flatten(observation)[0], tf.float32)
        action = tf.cast(tf.nest.flatten(action)[0], tf.float32)

        batch_squash = utils.BatchSquash(2)  # Squash B, and T dims.
        observation = batch_squash.flatten(
            observation)  # [B, T, ...] -> [BxT, ...]
        action = batch_squash.flatten(action)

        for layer in self._observation_layers:
            observation = layer(observation, training=training)

        for layer in self._action_layers:
            action = layer(action, training=training)

        joint = tf.concat([observation, action], -1)
        for layer in self._joint_layers:
            joint = layer(joint, training=training)

        joint = batch_squash.unflatten(joint)  # [B x T, ...] -> [B, T, ...]

        network_kwargs = {}
        if isinstance(self._lstm_network, dynamic_unroll_layer.DynamicUnroll):
            network_kwargs['reset_mask'] = tf.equal(step_type,
                                                    time_step.StepType.FIRST,
                                                    name='mask')

        # Unroll over the time sequence.
        output = self._lstm_network(inputs=joint,
                                    initial_state=network_state,
                                    training=training,
                                    **network_kwargs)
        if isinstance(self._lstm_network, dynamic_unroll_layer.DynamicUnroll):
            joint, network_state = output
        else:
            joint = output[0]
            network_state = tf.nest.pack_sequence_as(
                self._lstm_network.cell.state_size,
                tf.nest.flatten(output[1:]))

        output = batch_squash.flatten(joint)  # [B, T, ...] -> [B x T, ...]

        for layer in self._output_layers:
            output = layer(output, training=training)

        q_value = tf.reshape(output, [-1])
        q_value = batch_squash.unflatten(
            q_value)  # [B x T, ...] -> [B, T, ...]
        if not has_time_dim:
            q_value = tf.squeeze(q_value, axis=1)

        return q_value, network_state
Example #22
 def squash_dataset_element(sequence, info):
     return tf.nest.map_structure(
         utils.BatchSquash(2).flatten, (sequence, info))
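
A hedged usage sketch: assuming a tf.data dataset whose elements are (sequence, info) tuples with [B, T, ...] leading dimensions (for example, from a replay buffer's as_dataset()), mapping the helper over it flattens every element to [B*T, ...]. The dataset name is hypothetical.

# Illustrative only: flatten each element's [B, T, ...] tensors to [B*T, ...].
squashed_dataset = dataset.map(squash_dataset_element)
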
Example #23
  def call(self, observation, step_type, network_state=None, training=False):
    # Preprocess for multiple observations
    if self._flat_preprocessing_layers is None:
      processed = observation
    else:
      processed = []
      for obs, layer in zip(
          nest.flatten_up_to(
              self._preprocessing_nest, observation, check_types=False),
          self._flat_preprocessing_layers):
        processed.append(layer(obs, training=training))
      if len(processed) == 1 and self._preprocessing_combiner is None:
        # If only one observation is passed and the preprocessing_combiner
        # is unspecified, use the preprocessed version of this observation.
        processed = processed[0]
    observation = processed
    if self._preprocessing_combiner is not None:
      observation = self._preprocessing_combiner(observation)
    observation_spec = tensor_spec.TensorSpec((observation.shape[-1],), dtype=observation.dtype)

    num_outer_dims = nest_utils.get_outer_rank(observation,
                                               observation_spec)
    if num_outer_dims not in (1, 2):
      raise ValueError(
          'Input observation must have a batch or batch x time outer shape.')

    has_time_dim = num_outer_dims == 2
    if not has_time_dim:
      # Add a time dimension to the inputs.
      observation = tf.nest.map_structure(lambda t: tf.expand_dims(t, 1),
                                          observation)
      step_type = tf.nest.map_structure(lambda t: tf.expand_dims(t, 1),
                                        step_type)

    states = tf.cast(tf.nest.flatten(observation)[0], tf.float32)
    batch_squash = utils.BatchSquash(2)  # Squash B, and T dims.
    states = batch_squash.flatten(states)  # [B, T, ...] -> [B x T, ...]

    for layer in self._input_layers:
      states = layer(states, training=training)

    states = batch_squash.unflatten(states)  # [B x T, ...] -> [B, T, ...]

    with tf.name_scope('reset_mask'):
      reset_mask = tf.equal(step_type, time_step.StepType.FIRST)
    # Unroll over the time sequence.
    states, network_state = self._dynamic_unroll(
        states,
        reset_mask,
        initial_state=network_state,
        training=training)

    states = batch_squash.flatten(states)  # [B, T, ...] -> [B x T, ...]

    for layer in self._output_layers:
      states = layer(states, training=training)

    actions = []
    for layer, spec in zip(self._action_layers, self._flat_action_spec):
      action = layer(states, training=training)
      action = common.scale_to_spec(action, spec)
      action = batch_squash.unflatten(action)  # [B x T, ...] -> [B, T, ...]
      if not has_time_dim:
        action = tf.squeeze(action, axis=1)
      actions.append(action)

    output_actions = tf.nest.pack_sequence_as(self._output_tensor_spec, actions)
    return output_actions, network_state
Example #24
    def _apply_actor_network(self, time_step, policy_state):
        has_batch_dim = time_step.step_type.shape.as_list()[0] > 1
        observation = time_step.observation
        if self._observation_normalizer:
            observation = self._observation_normalizer.normalize(observation)
        actions, policy_state = self._actor_network(observation,
                                                    time_step.step_type,
                                                    policy_state,
                                                    training=self._training)
        if has_batch_dim:
            return actions, policy_state

        # samples "best" safe action out of 50
        sampled_ac = actions.sample(50)
        obs = nest_utils.stack_nested_tensors(
            [time_step.observation for _ in range(50)])
        obs_outer_rank = nest_utils.get_outer_rank(
            obs, self.time_step_spec.observation)
        ac_outer_rank = nest_utils.get_outer_rank(sampled_ac, self.action_spec)
        obs_batch_squash = utils.BatchSquash(obs_outer_rank)
        ac_batch_squash = utils.BatchSquash(ac_outer_rank)
        obs = tf.nest.map_structure(obs_batch_squash.flatten, obs)
        sampled_ac = tf.nest.map_structure(ac_batch_squash.flatten, sampled_ac)
        q_val, _ = self._safety_critic_network((obs, sampled_ac),
                                               time_step.step_type)
        fail_prob = tf.nn.sigmoid(q_val)
        safe_ac_mask = fail_prob < self._safety_threshold
        safe_ac_idx = tf.where(safe_ac_mask)

        resample_count = 0
        start_time = time.time()
        while self._training and resample_count < 4 and not safe_ac_idx.shape.as_list(
        )[0]:
            if self._resample_counter is not None:
                self._resample_counter()
            resample_count += 1
            if isinstance(actions, dist_utils.SquashToSpecNormal):
                scale = actions.input_distribution.scale * 1.5  # increase variance by constant 1.5
                ac_mean = actions.mean()
            else:
                scale = actions.scale * 1.5
                ac_mean = actions.mean()
            actions = self._actor_network.output_spec.build_distribution(
                loc=ac_mean, scale=scale)
            sampled_ac = actions.sample(50)
            sampled_ac = tf.nest.map_structure(ac_batch_squash.flatten,
                                               sampled_ac)
            q_val, _ = self._safety_critic_network((obs, sampled_ac),
                                                   time_step.step_type)

            fail_prob = tf.nn.sigmoid(q_val)
            safe_ac_idx = tf.where(fail_prob < self._safety_threshold)
        # logging.debug('resampled {} times, {} seconds'.format(resample_count, time.time() - start_time))
        sampled_ac = ac_batch_squash.unflatten(sampled_ac)
        if None in safe_ac_idx.shape.as_list() or not np.prod(
                safe_ac_idx.shape.as_list()):  # return safest action
            safe_idx = tf.argmin(fail_prob)
        else:
            sampled_ac = tf.gather(sampled_ac, safe_ac_idx)
            fail_prob_safe = tf.gather(fail_prob, safe_ac_idx)
            if self._training:
                safe_idx = tf.argmax(fail_prob_safe)[
                    0]  # picks most unsafe action out of "safe" options
            else:
                safe_idx = tf.argmin(fail_prob_safe)[0]
        ac = sampled_ac[safe_idx]
        assert ac.shape.as_list(
        )[0] == 1, 'action shape is not correct: {}'.format(ac.shape.as_list())
        return ac, policy_state
Example #25
    def _loss(self,
              experience,
              td_errors_loss_fn=tf.losses.huber_loss,
              gamma=1.0,
              reward_scale_factor=1.0,
              weights=None):
        """Computes critic loss for CategoricalDQN training.

    See Algorithm 1 and the discussion immediately preceding it on page 6 of
    "A Distributional Perspective on Reinforcement Learning"
      Bellemare et al., 2017
      https://arxiv.org/abs/1707.06887

    Args:
      experience: A batch of experience data in the form of a `Trajectory`. The
        structure of `experience` must match that of `self.policy.step_spec`.
        All tensors in `experience` must be shaped `[batch, time, ...]` where
        `time` must be equal to `self.required_experience_time_steps` if that
        property is not `None`.
      td_errors_loss_fn: A function(td_targets, predictions) to compute loss.
      gamma: Discount for future rewards.
      reward_scale_factor: Multiplicative factor to scale rewards.
      weights: Optional weights used for importance sampling.
    Returns:
      critic_loss: A scalar critic loss.
    Raises:
      ValueError:
        if the number of actions is greater than 1.
    """
        # Check that `experience` includes two outer dimensions [B, T, ...]. This
        # method requires a time dimension to compute the loss properly.
        self._check_trajectory_dimensions(experience)

        if self._n_step_update == 1:
            time_steps, actions, next_time_steps = self._experience_to_transitions(
                experience)
        else:
            # To compute n-step returns, we need the first time steps, the first
            # actions, and the last time steps. Therefore we extract the first and
            # last transitions from our Trajectory.
            first_two_steps = tf.nest.map_structure(lambda x: x[:, :2],
                                                    experience)
            last_two_steps = tf.nest.map_structure(lambda x: x[:, -2:],
                                                   experience)
            time_steps, actions, _ = self._experience_to_transitions(
                first_two_steps)
            _, _, next_time_steps = self._experience_to_transitions(
                last_two_steps)

        with tf.name_scope('critic_loss'):
            tf.nest.assert_same_structure(actions, self.action_spec)
            tf.nest.assert_same_structure(time_steps, self.time_step_spec)
            tf.nest.assert_same_structure(next_time_steps, self.time_step_spec)

            rank = nest_utils.get_outer_rank(time_steps.observation,
                                             self._time_step_spec.observation)

            # If inputs have a time dimension and the q_network is stateful,
            # combine the batch and time dimension.
            batch_squash = (None if rank <= 1 or self._q_network.state_spec
                            in ((), None) else utils.BatchSquash(rank))

            # q_logits contains the Q-value logits for all actions.
            q_logits, _ = self._q_network(time_steps.observation,
                                          time_steps.step_type)
            next_q_distribution = self._next_q_distribution(
                next_time_steps, batch_squash)

            if batch_squash is not None:
                # Squash the outer dimensions into a single dimension to make
                # computing the loss below easier. Required to support temporal
                # inputs, for example.
                q_logits = batch_squash.flatten(q_logits)
                actions = batch_squash.flatten(actions)
                next_time_steps = tf.nest.map_structure(
                    batch_squash.flatten, next_time_steps)

            actions = tf.nest.flatten(actions)[0]
            if actions.shape.ndims > 1:
                actions = tf.squeeze(actions, range(1, actions.shape.ndims))

            # Project the sample Bellman update \hat{T}Z_{\theta} onto the original
            # support of Z_{\theta} (see Figure 1 in paper).
            batch_size = tf.shape(q_logits)[0]
            tiled_support = tf.tile(self._support, [batch_size])
            tiled_support = tf.reshape(tiled_support,
                                       [batch_size, self._num_atoms])

            if self._n_step_update == 1:
                discount = next_time_steps.discount
                if discount.shape.ndims == 1:
                    # We expect discount to have a shape of [batch_size], while
                    # tiled_support will have a shape of [batch_size, num_atoms]. To
                    # multiply these, we add a second dimension of 1 to the discount.
                    discount = discount[:, None]
                next_value_term = tf.multiply(discount,
                                              tiled_support,
                                              name='next_value_term')

                reward = next_time_steps.reward
                if reward.shape.ndims == 1:
                    # See the explanation above.
                    reward = reward[:, None]
                reward_term = tf.multiply(reward_scale_factor,
                                          reward,
                                          name='reward_term')

                target_support = tf.add(reward_term,
                                        gamma * next_value_term,
                                        name='target_support')
            else:
                # When computing discounted return, we need to throw out the last time
                # index of both reward and discount, which are filled with dummy values
                # to match the dimensions of the observation.
                rewards = reward_scale_factor * experience.reward[:, :-1]
                discounts = gamma * experience.discount[:, :-1]

                # TODO(b/134618876): Properly handle Trajectories that include episode
                # boundaries with nonzero discount.

                # TODO(b/131557265): Replace value_ops.discounted_return with a method
                # that only computes the single value needed.
                discounted_rewards = value_ops.discounted_return(
                    rewards=rewards,
                    discounts=discounts,
                    final_value=tf.zeros([batch_size], dtype=discounts.dtype),
                    time_major=False)

                # We only need the first value within the time dimension which
                # corresponds to the full final return. The remaining values are only
                # partial returns.
                discounted_rewards = discounted_rewards[:, :1]

                final_value_discount = tf.reduce_prod(discounts, axis=1)
                final_value_discount = final_value_discount[:, None]

                # Save the values of discounted_rewards and final_value_discount in
                # order to check them in unit tests.
                self._discounted_rewards = discounted_rewards
                self._final_value_discount = final_value_discount

                target_support = tf.add(discounted_rewards,
                                        final_value_discount * tiled_support,
                                        name='target_support')

            target_distribution = tf.stop_gradient(
                project_distribution(target_support, next_q_distribution,
                                     self._support))

            # Obtain the current Q-value logits for the selected actions.
            indices = tf.range(tf.shape(q_logits)[0])[:, None]
            indices = tf.cast(indices, actions.dtype)
            reshaped_actions = tf.concat([indices, actions[:, None]], 1)
            chosen_action_logits = tf.gather_nd(q_logits, reshaped_actions)

            # Compute the cross-entropy loss between the logits. If inputs have
            # a time dimension, compute the sum over the time dimension before
            # computing the mean over the batch dimension.
            if batch_squash is not None:
                target_distribution = batch_squash.unflatten(
                    target_distribution)
                chosen_action_logits = batch_squash.unflatten(
                    chosen_action_logits)
                critic_loss = tf.reduce_mean(
                    tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits_v2(
                        labels=target_distribution,
                        logits=chosen_action_logits),
                                  axis=1))
            else:
                critic_loss = tf.reduce_mean(
                    tf.nn.softmax_cross_entropy_with_logits_v2(
                        labels=target_distribution,
                        logits=chosen_action_logits))

            with tf.name_scope('Losses/'):
                tf.compat.v2.summary.scalar('critic_loss',
                                            critic_loss,
                                            step=self.train_step_counter)

            if self._debug_summaries:
                distribution_errors = target_distribution - chosen_action_logits
                with tf.name_scope('distribution_errors'):
                    common.generate_tensor_summaries(
                        'distribution_errors',
                        distribution_errors,
                        step=self.train_step_counter)
                    tf.compat.v2.summary.scalar(
                        'mean',
                        tf.reduce_mean(distribution_errors),
                        step=self.train_step_counter)
                    tf.compat.v2.summary.scalar(
                        'mean_abs',
                        tf.reduce_mean(tf.abs(distribution_errors)),
                        step=self.train_step_counter)
                    tf.compat.v2.summary.scalar(
                        'max',
                        tf.reduce_max(distribution_errors),
                        step=self.train_step_counter)
                    tf.compat.v2.summary.scalar(
                        'min',
                        tf.reduce_min(distribution_errors),
                        step=self.train_step_counter)
                with tf.name_scope('target_distribution'):
                    common.generate_tensor_summaries(
                        'target_distribution',
                        target_distribution,
                        step=self.train_step_counter)

            # TODO(b/127318640): Give appropriate values for td_loss and td_error for
            # prioritized replay.
            return tf_agent.LossInfo(
                critic_loss, dqn_agent.DqnLossInfo(td_loss=(), td_error=()))
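
The last step above boils down to a softmax cross-entropy between the projected target distribution and the logits of the chosen actions. A self-contained sketch of just that step, with made-up shapes and random tensors standing in for the real values:

import tensorflow as tf

batch_size, num_atoms = 8, 51
target_distribution = tf.nn.softmax(tf.random.normal([batch_size, num_atoms]))
chosen_action_logits = tf.random.normal([batch_size, num_atoms])

# Cross-entropy per example, then mean over the batch (the branch above that
# has no time dimension).
critic_loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(
        labels=target_distribution, logits=chosen_action_logits))
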
Example #26
  def actor_loss(self,
                 time_steps,
                 actions,
                 next_time_steps,
                 weights=None):
    """Computes the actor_loss for SAC training.

    Args:
      time_steps: A batch of timesteps.
      actions: A batch of actions.
      next_time_steps: A batch of next timesteps.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.

    Returns:
      actor_loss: A scalar actor loss.
    """
    prev_time_steps, prev_actions, time_steps = time_steps, actions, next_time_steps  # pylint: disable=line-too-long
    with tf.name_scope('actor_loss'):
      nest_utils.assert_same_structure(time_steps, self.time_step_spec)

      actions, log_pi = self._actions_and_log_probs(time_steps)
      target_input = (time_steps.observation, actions)
      target_q_values1, _ = self._critic_network_1(
          target_input, step_type=time_steps.step_type, training=False)
      target_q_values2, _ = self._critic_network_2(
          target_input, step_type=time_steps.step_type, training=False)
      target_q_values = tf.minimum(target_q_values1, target_q_values2)
      actor_loss = tf.exp(self._log_alpha) * log_pi - target_q_values

      ### Flatten time dimension. We'll add it back when adding the loss.
      num_outer_dims = nest_utils.get_outer_rank(time_steps,
                                                 self.time_step_spec)
      has_time_dim = (num_outer_dims == 2)
      if has_time_dim:
        batch_squash = utils.BatchSquash(2)  # Squash B, and T dims.
        obs = batch_squash.flatten(time_steps.observation)
        prev_obs = batch_squash.flatten(prev_time_steps.observation)
        prev_actions = batch_squash.flatten(prev_actions)
      else:
        obs = time_steps.observation
        prev_obs = prev_time_steps.observation
      z = self._actor_network._z_encoder(obs, training=True)  # pylint: disable=protected-access
      prior = self._actor_network._predictor((prev_obs, prev_actions),  # pylint: disable=protected-access
                                             training=True)

      # kl is a vector of length batch_size, which has already been summed over
      # the latent dimension z.
      kl = tfp.distributions.kl_divergence(z, prior)
      if has_time_dim:
        kl = batch_squash.unflatten(kl)

      kl_coef = tf.stop_gradient(
          tf.exp(self._actor_network._log_kl_coefficient))  # pylint: disable=protected-access
      # The actor loss trains both the predictor and the encoder.
      actor_loss += kl_coef * kl

      if actor_loss.shape.rank > 1:
        # Sum over the time dimension.
        actor_loss = tf.reduce_sum(
            actor_loss, axis=range(1, actor_loss.shape.rank))
      reg_loss = self._actor_network.losses if self._actor_network else None
      agg_loss = common.aggregate_losses(
          per_example_loss=actor_loss,
          sample_weight=weights,
          regularization_loss=reg_loss)
      actor_loss = agg_loss.total_loss
      self._actor_loss_debug_summaries(actor_loss, actions, log_pi,
                                       target_q_values, time_steps)
      tf.compat.v2.summary.scalar(
          name='encoder_kl',
          data=tf.reduce_mean(kl),
          step=self.train_step_counter)

      return actor_loss
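
Before aggregation, the per-example loss assembled above has the form exp(log_alpha) * log_pi - min(Q1, Q2) + kl_coef * KL. A minimal numeric sketch of that combination, with random placeholders standing in for the batch quantities:

import tensorflow as tf

# Hypothetical per-example quantities, each of shape [batch_size].
log_pi = tf.random.normal([8])
q1, q2 = tf.random.normal([8]), tf.random.normal([8])
kl = tf.random.uniform([8])
log_alpha = tf.Variable(0.0)
kl_coef = 0.1

per_example_loss = tf.exp(log_alpha) * log_pi - tf.minimum(q1, q2) + kl_coef * kl
actor_loss = tf.reduce_mean(per_example_loss)   # stand-in for aggregate_losses
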