Example 1
  def call(self, observation, step_type=None, network_state=(), training=False):
    del step_type  # unused.

    if self._batch_squash:
      outer_rank = nest_utils.get_outer_rank(observation,
                                             self.input_tensor_spec)
      batch_squash = utils.BatchSquash(outer_rank)
      observation = tf.nest.map_structure(batch_squash.flatten, observation)

    if self._flat_preprocessing_layers is None:
      processed = observation
    else:
      processed = []
      for obs, layer in zip(
          nest_utils.flatten_up_to(self._preprocessing_nest, observation),
          self._flat_preprocessing_layers):
        processed.append(layer(obs, training=training))
      if len(processed) == 1 and self._preprocessing_combiner is None:
        # If only one observation is passed and the preprocessing_combiner
        # is unspecified, use the preprocessed version of this observation.
        processed = processed[0]

    states = processed
    attention_weights = None  # Stays None when no preprocessing_combiner is set.

    if self._preprocessing_combiner is not None:
      states, attention_weights = self._preprocessing_combiner(states)

    for layer in self._postprocessing_layers:
      states = layer(states, training=training)

    if self._batch_squash:
      states = tf.nest.map_structure(batch_squash.unflatten, states)

    return states, network_state, attention_weights
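For context, a minimal sketch of the outer-rank squashing this method performs around its preprocessing layers; it assumes `utils` is `tf_agents.networks.utils`, as used by the surrounding class.

import tensorflow as tf
from tf_agents.networks import utils  # assumed source of BatchSquash

# A batched sequence of image observations: [batch, time, H, W, C].
observation = tf.zeros([4, 7, 11, 11, 3])

# Collapse the two outer dims so ordinary Keras layers see a flat batch...
batch_squash = utils.BatchSquash(2)          # outer_rank == 2
flat = batch_squash.flatten(observation)     # shape [28, 11, 11, 3]

# ...and restore them once the layers have been applied.
restored = batch_squash.unflatten(flat)      # shape [4, 7, 11, 11, 3]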
Example 2
    def _adversary_to_obs_space_dtype(self, observation):
        # Make sure we handle cases where observations are provided as a list.
        flat_obs = nest_utils.flatten_up_to(self.adversary_observation_spec,
                                            observation)

        matched_observations = []
        for spec, obs in zip(self.adversary_flat_obs_spec, flat_obs):
            matched_observations.append(np.asarray(obs, dtype=spec.dtype))
        return tf.nest.pack_sequence_as(self.adversary_observation_spec,
                                        matched_observations)
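A self-contained sketch of the same spec/observation dtype matching, using a plain dict spec; the field names below are illustrative, not the real adversary specs.

import numpy as np
import tensorflow as tf
from tf_agents.specs import array_spec

# Illustrative flat spec for two observation fields.
obs_spec = {
    'image': array_spec.ArraySpec(shape=(5, 5, 3), dtype=np.uint8),
    'time_step': array_spec.ArraySpec(shape=(), dtype=np.int32),
}

# Observations arriving with Python/float types.
raw_obs = {'image': np.zeros((5, 5, 3)), 'time_step': 3.0}

# Cast each leaf to the dtype its spec asks for, preserving the nest.
matched = tf.nest.map_structure(
    lambda spec, obs: np.asarray(obs, dtype=spec.dtype), obs_spec, raw_obs)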
Example 3
    def call(self, inputs, network_state=(), **kwargs):
        nest_utils.assert_same_structure(
            self._nested_layers,
            inputs,
            allow_shallow_nest1=True,
            message=
            ('`self.nested_layers` and `inputs` do not have matching structures'
             ))

        if network_state:
            nest_utils.assert_same_structure(
                self.state_spec,
                network_state,
                allow_shallow_nest1=True,
                message=
                ('network_state and state_spec do not have matching structure'
                 ))
            nested_layers_state = network_state
        else:
            nested_layers_state = tf.nest.map_structure(
                lambda _: (), self._nested_layers)

        # Here we must use map_structure_up_to because nested_layers_state has a
        # "deeper" structure than self._nested_layers.  For example, an LSTM
        # layer's state is composed of a list with two tensors.  The
        # tf.nest.map_structure function would raise an error if two
        # "incompatible" structures are passed in this way.
        def _mapper(inp, layer, state):  # pylint: disable=invalid-name
            return layer(inp, network_state=state, **kwargs)

        outputs_and_next_state = nest_utils.map_structure_up_to(
            self._nested_layers, _mapper, inputs, self._nested_layers,
            nested_layers_state)

        flat_outputs_and_next_state = nest_utils.flatten_up_to(
            self._nested_layers, outputs_and_next_state)
        flat_outputs, flat_next_state = zip(*flat_outputs_and_next_state)

        outputs = tf.nest.pack_sequence_as(self._nested_layers, flat_outputs)
        next_network_state = tf.nest.pack_sequence_as(self._nested_layers,
                                                      flat_next_state)

        return outputs, next_network_state
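The shallow-versus-deep mapping described in the comment above can be shown in isolation; this assumes the same `nest_utils` module used by the method.

from tf_agents.utils import nest_utils  # assumed source of map_structure_up_to

# Shallow structure: one entry per layer.
layers = {'dense': 'dense_layer', 'lstm': 'lstm_layer'}
# Deeper structure: the LSTM entry carries a two-tensor (h, c) state.
states = {'dense': (), 'lstm': ('h', 'c')}

# Descend only as deep as `layers`, so ('h', 'c') reaches the callback as
# one object; tf.nest.map_structure would reject the mismatched nests.
paired = nest_utils.map_structure_up_to(
    layers, lambda layer, state: (layer, state), layers, states)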
Example 4
def construct_attention_networks(
    observation_spec,
    action_spec,
    use_rnns=True,
    actor_fc_layers=(200, 100),
    value_fc_layers=(200, 100),
    lstm_size=(128, ),
    conv_filters=8,
    conv_kernel=3,
    scalar_fc=5,
    scalar_name="direction",
    scalar_dim=4,
    use_stacks=False,
):
    """Creates an actor and critic network designed for use with MultiGrid.

  A convolution layer processes the image and a dense layer processes the
  direction the agent is facing. These are fed into some fully connected layers
  and an LSTM.

  Args:
    observation_spec: A tf-agents observation spec.
    action_spec: A tf-agents action spec.
    use_rnns: If True, will construct RNN networks. Non-recurrent networks are
      not currently supported.
    actor_fc_layers: Dimension and number of fully connected layers in actor.
    value_fc_layers: Dimension and number of fully connected layers in critic.
    lstm_size: Number of cells in each LSTM layer.
    conv_filters: Number of convolution filters.
    conv_kernel: Size of the convolution kernel.
    scalar_fc: Number of neurons in the fully connected layer processing the
      scalar input.
    scalar_name: Name of the scalar input.
    scalar_dim: Highest possible value for the scalar input. Used to convert to
      one-hot representation.
    use_stacks: Use ResNet stacks (compresses the image).

  Returns:
    A tf-agents ActorDistributionRnnNetwork for the actor, and a ValueRnnNetwork
    for the critic.
  """
    if not use_rnns:
        raise NotImplementedError(
            "Non-recurrent attention networks are not suppported.")
    preprocessing_layers = {
        "policy_state": tf.keras.layers.Lambda(lambda x: x)
    }
    if use_stacks:
        preprocessing_layers["image"] = tf.keras.models.Sequential([
            multigrid_networks.cast_and_scale(),
            _Stack(conv_filters // 2, 2),
            _Stack(conv_filters, 2),
            tf.keras.layers.ReLU(),
        ])
    else:
        preprocessing_layers["image"] = tf.keras.models.Sequential([
            multigrid_networks.cast_and_scale(),
            tf.keras.layers.Conv2D(conv_filters, conv_kernel, padding="same"),
            tf.keras.layers.ReLU(),
        ])
    if scalar_name in observation_spec:
        preprocessing_layers[scalar_name] = tf.keras.models.Sequential([
            multigrid_networks.one_hot_layer(scalar_dim),
            tf.keras.layers.Dense(scalar_fc)
        ])
    if "position" in observation_spec:
        preprocessing_layers["position"] = tf.keras.models.Sequential([
            multigrid_networks.cast_and_scale(),
            tf.keras.layers.Dense(scalar_fc)
        ])

    preprocessing_nest = tf.nest.map_structure(lambda l: None,
                                               preprocessing_layers)
    flat_observation_spec = nest_utils.flatten_up_to(
        preprocessing_nest,
        observation_spec,
    )
    image_index_flat = flat_observation_spec.index(observation_spec["image"])
    network_state_index_flat = flat_observation_spec.index(
        observation_spec["policy_state"])
    if use_stacks:
        image_shape = [i // 4
                       for i in observation_spec["image"].shape]  # H x W x D
    else:
        image_shape = observation_spec["image"].shape
    preprocessing_combiner = AttentionCombinerConv(image_index_flat,
                                                   network_state_index_flat,
                                                   image_shape)

    custom_objects = {"_Stack": _Stack}
    with tf.keras.utils.custom_object_scope(custom_objects):
        actor_net = AttentionActorDistributionRnnNetwork(
            observation_spec,
            action_spec,
            preprocessing_layers=preprocessing_layers,
            preprocessing_combiner=preprocessing_combiner,
            input_fc_layer_params=actor_fc_layers,
            output_fc_layer_params=None,
            lstm_size=lstm_size)
        value_net = AttentionValueRnnNetwork(
            observation_spec,
            preprocessing_layers=preprocessing_layers,
            preprocessing_combiner=preprocessing_combiner,
            input_fc_layer_params=value_fc_layers,
            output_fc_layer_params=None)

    return actor_net, value_net
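A hedged usage sketch of this constructor; the spec keys and shapes below are illustrative placeholders rather than the real MultiGrid specs.

import tensorflow as tf
from tf_agents.specs import tensor_spec

# Illustrative observation/action specs with the keys this constructor expects.
observation_spec = {
    'image': tensor_spec.BoundedTensorSpec((15, 15, 3), tf.uint8, 0, 255),
    'direction': tensor_spec.BoundedTensorSpec((), tf.int64, 0, 3),
    'position': tensor_spec.BoundedTensorSpec((2,), tf.int64, 0, 14),
    'policy_state': tf.TensorSpec((128,), tf.float32),
}
action_spec = tensor_spec.BoundedTensorSpec((), tf.int64, 0, 6)

actor_net, value_net = construct_attention_networks(
    observation_spec, action_spec, use_stacks=True)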
Example 5
def construct_attention_networks(observation_spec,
                                 action_spec,
                                 use_rnns=True,
                                 actor_fc_layers=(200, 100),
                                 value_fc_layers=(200, 100),
                                 lstm_size=(128,),
                                 conv_filters=8,
                                 conv_kernel=3,
                                 scalar_fc=5,
                                 scalar_name='direction',
                                 scalar_dim=4):
  """Creates an actor and critic network designed for use with MultiGrid.

  A convolution layer processes the image and a dense layer processes the
  direction the agent is facing. These are fed into some fully connected layers
  and an LSTM.

  Args:
    observation_spec: A tf-agents observation spec.
    action_spec: A tf-agents action spec.
    use_rnns: If True, will construct RNN networks.
    actor_fc_layers: Dimension and number of fully connected layers in actor.
    value_fc_layers: Dimension and number of fully connected layers in critic.
    lstm_size: Number of cells in each LSTM layer.
    conv_filters: Number of convolution filters.
    conv_kernel: Size of the convolution kernel.
    scalar_fc: Number of neurons in the fully connected layer processing the
      scalar input.
    scalar_name: Name of the scalar input.
    scalar_dim: Highest possible value for the scalar input. Used to convert to
      one-hot representation.

  Returns:
    A tf-agents ActorDistributionRnnNetwork for the actor, and a ValueRnnNetwork
    for the critic.
  """
  preprocessing_layers = {
      'image':
          tf.keras.models.Sequential([
              cast_and_scale(),
              tf.keras.layers.Conv2D(conv_filters, conv_kernel, padding='same'),
              tf.keras.layers.ReLU(),
          ]),
      'policy_state':
          tf.keras.layers.Lambda(lambda x: x)
  }
  if scalar_name in observation_spec:
    preprocessing_layers[scalar_name] = tf.keras.models.Sequential(
        [one_hot_layer(scalar_dim),
         tf.keras.layers.Dense(scalar_fc)])
  if 'position' in observation_spec:
    preprocessing_layers['position'] = tf.keras.models.Sequential(
        [cast_and_scale(), tf.keras.layers.Dense(scalar_fc)])

  preprocessing_nest = tf.nest.map_structure(lambda l: None,
                                             preprocessing_layers)
  flat_observation_spec = nest_utils.flatten_up_to(
      preprocessing_nest,
      observation_spec,
  )
  image_index_flat = flat_observation_spec.index(observation_spec['image'])
  network_state_index_flat = flat_observation_spec.index(
      observation_spec['policy_state'])
  image_shape = observation_spec['image'].shape  # N x H x W x D
  preprocessing_combiner = AttentionCombinerConv(image_index_flat,
                                                 network_state_index_flat,
                                                 image_shape)

  if use_rnns:
    actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
        observation_spec,
        action_spec,
        preprocessing_layers=preprocessing_layers,
        preprocessing_combiner=preprocessing_combiner,
        input_fc_layer_params=actor_fc_layers,
        output_fc_layer_params=None,
        lstm_size=lstm_size)
    value_net = value_rnn_network.ValueRnnNetwork(
        observation_spec,
        preprocessing_layers=preprocessing_layers,
        preprocessing_combiner=preprocessing_combiner,
        input_fc_layer_params=value_fc_layers,
        output_fc_layer_params=None)
  else:
    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_spec,
        action_spec,
        preprocessing_layers=preprocessing_layers,
        preprocessing_combiner=preprocessing_combiner,
        fc_layer_params=actor_fc_layers,
        activation_fn=tf.keras.activations.tanh)
    value_net = value_network.ValueNetwork(
        observation_spec,
        preprocessing_layers=preprocessing_layers,
        preprocessing_combiner=preprocessing_combiner,
        fc_layer_params=value_fc_layers,
        activation_fn=tf.keras.activations.tanh)

  return actor_net, value_net
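To make the scalar path concrete, a small sketch of what the one-hot preprocessing amounts to, assuming `one_hot_layer(scalar_dim)` behaves like `tf.one_hot` over the scalar input, as the docstring describes.

import tensorflow as tf

scalar_dim, scalar_fc = 4, 5
direction = tf.constant([2])                          # agent's facing direction
one_hot = tf.one_hot(direction, depth=scalar_dim)     # -> [[0., 0., 1., 0.]]
embedded = tf.keras.layers.Dense(scalar_fc)(one_hot)  # shape [1, 5]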