def call(self, observation, step_type=None, network_state=(), training=False):
  del step_type  # unused.

  if self._batch_squash:
    outer_rank = nest_utils.get_outer_rank(observation, self.input_tensor_spec)
    batch_squash = utils.BatchSquash(outer_rank)
    observation = tf.nest.map_structure(batch_squash.flatten, observation)

  if self._flat_preprocessing_layers is None:
    processed = observation
  else:
    processed = []
    for obs, layer in zip(
        nest_utils.flatten_up_to(self._preprocessing_nest, observation),
        self._flat_preprocessing_layers):
      processed.append(layer(obs, training=training))
    if len(processed) == 1 and self._preprocessing_combiner is None:
      # If only one observation is passed and the preprocessing_combiner
      # is unspecified, use the preprocessed version of this observation.
      processed = processed[0]

  states = processed

  # Initialize so the return below is well-defined even when no
  # preprocessing_combiner is configured.
  attention_weights = ()
  if self._preprocessing_combiner is not None:
    states, attention_weights = self._preprocessing_combiner(states)

  for layer in self._postprocessing_layers:
    states = layer(states, training=training)

  if self._batch_squash:
    states = tf.nest.map_structure(batch_squash.unflatten, states)

  return states, network_state, attention_weights
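# Example: the batch-squash round trip used in call() above. BatchSquash
# folds all outer (batch and time) dimensions into a single batch dimension
# before the preprocessing layers run, then restores them afterwards. A
# minimal sketch using tf_agents' real utility; the shapes are illustrative
# assumptions, not taken from the module.
import tensorflow as tf
from tf_agents.networks import utils

x = tf.zeros([4, 5, 3])  # [batch=4, time=5, feature=3], i.e. outer_rank == 2.
batch_squash = utils.BatchSquash(2)      # Fold the two outer dims together.
flat = batch_squash.flatten(x)           # Shape becomes [20, 3].
restored = batch_squash.unflatten(flat)  # Shape is [4, 5, 3] again.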
def _adversary_to_obs_space_dtype(self, observation):
  # Make sure we handle cases where observations are provided as a list.
  flat_obs = nest_utils.flatten_up_to(self.adversary_observation_spec,
                                      observation)
  matched_observations = []
  for spec, obs in zip(self.adversary_flat_obs_spec, flat_obs):
    matched_observations.append(np.asarray(obs, dtype=spec.dtype))
  return tf.nest.pack_sequence_as(self.adversary_observation_spec,
                                  matched_observations)
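# Example: the dtype-matching pattern above, in isolation. Each flattened
# observation is coerced to the dtype its spec demands, then the nest is
# re-packed. A minimal sketch; FakeSpec and the sample values are
# assumptions, not part of the module.
import collections
import numpy as np
import tensorflow as tf

FakeSpec = collections.namedtuple('FakeSpec', ['dtype'])
spec_nest = {'image': FakeSpec(np.uint8), 'direction': FakeSpec(np.int32)}
obs_nest = {'image': [[0.0, 255.0]], 'direction': 3.0}

# tf.nest.flatten orders dict entries by key, so the two flat lists align.
matched = [
    np.asarray(obs, dtype=spec.dtype)
    for spec, obs in zip(tf.nest.flatten(spec_nest), tf.nest.flatten(obs_nest))
]
result = tf.nest.pack_sequence_as(spec_nest, matched)
# result['image'].dtype == np.uint8, result['direction'].dtype == np.int32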
def call(self, inputs, network_state=(), **kwargs):
  nest_utils.assert_same_structure(
      self._nested_layers,
      inputs,
      allow_shallow_nest1=True,
      message=('`self.nested_layers` and `inputs` do not have matching '
               'structures'))

  if network_state:
    nest_utils.assert_same_structure(
        self.state_spec,
        network_state,
        allow_shallow_nest1=True,
        message='network_state and state_spec do not have matching structure')
    nested_layers_state = network_state
  else:
    nested_layers_state = tf.nest.map_structure(
        lambda _: (), self._nested_layers)

  # Here we must use map_structure_up_to because nested_layers_state has a
  # "deeper" structure than self._nested_layers.  For example, an LSTM
  # layer's state is composed of a list with two tensors.  The
  # tf.nest.map_structure function would raise an error if two
  # "incompatible" structures are passed in this way.
  def _mapper(inp, layer, state):  # pylint: disable=invalid-name
    return layer(inp, network_state=state, **kwargs)

  outputs_and_next_state = nest_utils.map_structure_up_to(
      self._nested_layers, _mapper, inputs, self._nested_layers,
      nested_layers_state)

  flat_outputs_and_next_state = nest_utils.flatten_up_to(
      self._nested_layers, outputs_and_next_state)
  flat_outputs, flat_next_state = zip(*flat_outputs_and_next_state)

  outputs = tf.nest.pack_sequence_as(self._nested_layers, flat_outputs)
  next_network_state = tf.nest.pack_sequence_as(self._nested_layers,
                                                flat_next_state)

  return outputs, next_network_state
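# Example: why map_structure_up_to matters in call() above. It only descends
# as far as the shallow (layer) nest, so a layer state that is itself a nest,
# like an LSTM's [h, c] pair, is handed to the mapped function whole rather
# than element-wise. The nests below are invented for illustration.
import tensorflow as tf

layers = {'lstm': None, 'dense': None}      # One entry per layer.
states = {'lstm': ['h', 'c'], 'dense': ()}  # LSTM state is itself a nest.

result = tf.nest.map_structure_up_to(
    layers, lambda state: ('applied', state), states)
# result['lstm'] == ('applied', ['h', 'c'])
# result['dense'] == ('applied', ())
# Plain tf.nest.map_structure(fn, layers, states) would instead raise a
# structure-mismatch error.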
def construct_attention_networks(
    observation_spec,
    action_spec,
    use_rnns=True,
    actor_fc_layers=(200, 100),
    value_fc_layers=(200, 100),
    lstm_size=(128,),
    conv_filters=8,
    conv_kernel=3,
    scalar_fc=5,
    scalar_name="direction",
    scalar_dim=4,
    use_stacks=False,
):
  """Creates an actor and critic network designed for use with MultiGrid.

  A convolution layer processes the image and a dense layer processes the
  direction the agent is facing. These are fed into some fully connected
  layers and an LSTM.

  Args:
    observation_spec: A tf-agents observation spec.
    action_spec: A tf-agents action spec.
    use_rnns: If True, will construct RNN networks. Non-recurrent networks
      are currently not supported.
    actor_fc_layers: Dimension and number of fully connected layers in actor.
    value_fc_layers: Dimension and number of fully connected layers in critic.
    lstm_size: Number of cells in each LSTM layer.
    conv_filters: Number of convolution filters.
    conv_kernel: Size of the convolution kernel.
    scalar_fc: Number of neurons in the fully connected layer processing the
      scalar input.
    scalar_name: Name of the scalar input.
    scalar_dim: Highest possible value for the scalar input. Used to convert
      to one-hot representation.
    use_stacks: Use ResNet stacks (compresses the image).

  Returns:
    A tf-agents ActorDistributionRnnNetwork for the actor, and a
    ValueRnnNetwork for the critic.
  """
  if not use_rnns:
    raise NotImplementedError(
        "Non-recurrent attention networks are not supported.")
  preprocessing_layers = {
      "policy_state": tf.keras.layers.Lambda(lambda x: x)
  }
  if use_stacks:
    preprocessing_layers["image"] = tf.keras.models.Sequential([
        multigrid_networks.cast_and_scale(),
        _Stack(conv_filters // 2, 2),
        _Stack(conv_filters, 2),
        tf.keras.layers.ReLU(),
    ])
  else:
    preprocessing_layers["image"] = tf.keras.models.Sequential([
        multigrid_networks.cast_and_scale(),
        tf.keras.layers.Conv2D(conv_filters, conv_kernel, padding="same"),
        tf.keras.layers.ReLU(),
    ])
  if scalar_name in observation_spec:
    preprocessing_layers[scalar_name] = tf.keras.models.Sequential([
        multigrid_networks.one_hot_layer(scalar_dim),
        tf.keras.layers.Dense(scalar_fc)
    ])
  if "position" in observation_spec:
    preprocessing_layers["position"] = tf.keras.models.Sequential([
        multigrid_networks.cast_and_scale(),
        tf.keras.layers.Dense(scalar_fc)
    ])

  preprocessing_nest = tf.nest.map_structure(lambda l: None,
                                             preprocessing_layers)
  flat_observation_spec = nest_utils.flatten_up_to(
      preprocessing_nest,
      observation_spec,
  )
  image_index_flat = flat_observation_spec.index(observation_spec["image"])
  network_state_index_flat = flat_observation_spec.index(
      observation_spec["policy_state"])
  if use_stacks:
    # The two stacks compress the image, shrinking H and W by a factor of 4.
    image_shape = [i // 4 for i in observation_spec["image"].shape]  # H x W x D
  else:
    image_shape = observation_spec["image"].shape
  preprocessing_combiner = AttentionCombinerConv(image_index_flat,
                                                 network_state_index_flat,
                                                 image_shape)

  custom_objects = {"_Stack": _Stack}
  with tf.keras.utils.custom_object_scope(custom_objects):
    actor_net = AttentionActorDistributionRnnNetwork(
        observation_spec,
        action_spec,
        preprocessing_layers=preprocessing_layers,
        preprocessing_combiner=preprocessing_combiner,
        input_fc_layer_params=actor_fc_layers,
        output_fc_layer_params=None,
        lstm_size=lstm_size)
    value_net = AttentionValueRnnNetwork(
        observation_spec,
        preprocessing_layers=preprocessing_layers,
        preprocessing_combiner=preprocessing_combiner,
        input_fc_layer_params=value_fc_layers,
        output_fc_layer_params=None)
  return actor_net, value_net
def construct_attention_networks(observation_spec,
                                 action_spec,
                                 use_rnns=True,
                                 actor_fc_layers=(200, 100),
                                 value_fc_layers=(200, 100),
                                 lstm_size=(128,),
                                 conv_filters=8,
                                 conv_kernel=3,
                                 scalar_fc=5,
                                 scalar_name='direction',
                                 scalar_dim=4):
  """Creates an actor and critic network designed for use with MultiGrid.

  A convolution layer processes the image and a dense layer processes the
  direction the agent is facing. These are fed into some fully connected
  layers and an LSTM.

  Args:
    observation_spec: A tf-agents observation spec.
    action_spec: A tf-agents action spec.
    use_rnns: If True, will construct RNN networks.
    actor_fc_layers: Dimension and number of fully connected layers in actor.
    value_fc_layers: Dimension and number of fully connected layers in critic.
    lstm_size: Number of cells in each LSTM layer.
    conv_filters: Number of convolution filters.
    conv_kernel: Size of the convolution kernel.
    scalar_fc: Number of neurons in the fully connected layer processing the
      scalar input.
    scalar_name: Name of the scalar input.
    scalar_dim: Highest possible value for the scalar input. Used to convert
      to one-hot representation.

  Returns:
    A tf-agents ActorDistributionRnnNetwork for the actor, and a
    ValueRnnNetwork for the critic.
  """
  preprocessing_layers = {
      'image':
          tf.keras.models.Sequential([
              cast_and_scale(),
              tf.keras.layers.Conv2D(conv_filters, conv_kernel,
                                     padding='same'),
              tf.keras.layers.ReLU(),
          ]),
      'policy_state':
          tf.keras.layers.Lambda(lambda x: x)
  }
  if scalar_name in observation_spec:
    preprocessing_layers[scalar_name] = tf.keras.models.Sequential(
        [one_hot_layer(scalar_dim),
         tf.keras.layers.Dense(scalar_fc)])
  if 'position' in observation_spec:
    preprocessing_layers['position'] = tf.keras.models.Sequential(
        [cast_and_scale(), tf.keras.layers.Dense(scalar_fc)])

  preprocessing_nest = tf.nest.map_structure(lambda l: None,
                                             preprocessing_layers)
  flat_observation_spec = nest_utils.flatten_up_to(
      preprocessing_nest,
      observation_spec,
  )
  image_index_flat = flat_observation_spec.index(observation_spec['image'])
  network_state_index_flat = flat_observation_spec.index(
      observation_spec['policy_state'])
  image_shape = observation_spec['image'].shape  # N x H x W x D
  preprocessing_combiner = AttentionCombinerConv(image_index_flat,
                                                 network_state_index_flat,
                                                 image_shape)

  if use_rnns:
    actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
        observation_spec,
        action_spec,
        preprocessing_layers=preprocessing_layers,
        preprocessing_combiner=preprocessing_combiner,
        input_fc_layer_params=actor_fc_layers,
        output_fc_layer_params=None,
        lstm_size=lstm_size)
    value_net = value_rnn_network.ValueRnnNetwork(
        observation_spec,
        preprocessing_layers=preprocessing_layers,
        preprocessing_combiner=preprocessing_combiner,
        input_fc_layer_params=value_fc_layers,
        output_fc_layer_params=None)
  else:
    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_spec,
        action_spec,
        preprocessing_layers=preprocessing_layers,
        preprocessing_combiner=preprocessing_combiner,
        fc_layer_params=actor_fc_layers,
        activation_fn=tf.keras.activations.tanh)
    value_net = value_network.ValueNetwork(
        observation_spec,
        preprocessing_layers=preprocessing_layers,
        preprocessing_combiner=preprocessing_combiner,
        fc_layer_params=value_fc_layers,
        activation_fn=tf.keras.activations.tanh)
  return actor_net, value_net
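# Example usage: a minimal sketch of driving construct_attention_networks.
# The spec shapes and bounds below are invented placeholders; in practice the
# specs come from the environment via tf_env.observation_spec() and
# tf_env.action_spec(), and the 'policy_state' spec must line up with what
# AttentionCombinerConv expects.
import numpy as np
from tf_agents.specs import tensor_spec

observation_spec = {
    'image': tensor_spec.BoundedTensorSpec([15, 15, 3], np.uint8, 0, 255),
    'direction': tensor_spec.BoundedTensorSpec([], np.int32, 0, 3),
    'position': tensor_spec.BoundedTensorSpec([2], np.int32, 0, 14),
    'policy_state': tensor_spec.TensorSpec([128], np.float32),
}
action_spec = tensor_spec.BoundedTensorSpec([], np.int32, 0, 6)

actor_net, value_net = construct_attention_networks(
    observation_spec, action_spec, use_rnns=True)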