def call(self, observation, step_type, network_state=(), training=False):
    """Encode the observation and project it to action outputs.

    Args:
        observation: Input forwarded to `self._lstm_encoder`; also compared
            against `self.input_tensor_spec` to find the outer rank.
        step_type: Step type forwarded to the encoder.
        network_state: State forwarded to the encoder.
        training: Forwarded to the encoder and to every projection network.

    Returns:
        A tuple `(output_actions, network_state)`. `output_actions` mirrors
        the structure of `self._projection_networks` and holds the first
        element returned by each projection network; `network_state` is the
        state returned by the encoder.
    """
    encoded, next_state = self._lstm_encoder(
        observation,
        step_type=step_type,
        network_state=network_state,
        training=training)
    num_outer_dims = nest_utils.get_outer_rank(
        observation, self.input_tensor_spec)

    def _project(projection_network):
        # Only the first element of the projection result is kept.
        return projection_network(
            encoded, num_outer_dims, training=training)[0]

    output_actions = tf.nest.map_structure(
        _project, self._projection_networks)
    return output_actions, next_state
def call(self, observation, step_type, network_state=(), training=False):
    """Run input layers, a recurrent unroll, output layers and action heads.

    Args:
        observation: Nest of tensors; only the first flattened entry is used.
            Must have one or two outer dimensions relative to
            `self.input_tensor_spec`.
        step_type: Step types, compared with `StepType.FIRST` to build the
            reset mask for the unroll.
        network_state: Initial state passed to `self._dynamic_unroll`.
        training: Forwarded to every layer and to the unroll.

    Returns:
        A tuple `(output_actions, network_state)` where `output_actions` is
        packed like `self._output_tensor_spec`.

    Raises:
        ValueError: If the observation's outer rank is neither 1 nor 2.
    """
    outer_rank = nest_utils.get_outer_rank(
        observation, self.input_tensor_spec)
    if outer_rank not in (1, 2):
        raise ValueError(
            'Input observation must have a batch or batch x time outer shape.'
        )

    time_axis_added = outer_rank != 2
    if time_axis_added:
        # Inputs without a time axis get one inserted at position 1.
        def _add_time_axis(tensor):
            return tf.expand_dims(tensor, 1)

        observation = tf.nest.map_structure(_add_time_axis, observation)
        step_type = tf.nest.map_structure(_add_time_axis, step_type)

    hidden = tf.cast(tf.nest.flatten(observation)[0], tf.float32)

    # The two outer dims are squashed for the feed-forward layers and
    # restored again before the recurrent unroll.
    squasher = utils.BatchSquash(2)
    hidden = squasher.flatten(hidden)
    for input_layer in self._input_layers:
        hidden = input_layer(hidden, training=training)
    hidden = squasher.unflatten(hidden)

    with tf.name_scope('reset_mask'):
        reset_mask = tf.equal(step_type, time_step.StepType.FIRST)

    hidden, network_state = self._dynamic_unroll(
        hidden,
        reset_mask=reset_mask,
        initial_state=network_state,
        training=training)

    hidden = squasher.flatten(hidden)
    for output_layer in self._output_layers:
        hidden = output_layer(hidden, training=training)

    def _action_from(action_layer, action_spec):
        scaled = common.scale_to_spec(
            action_layer(hidden, training=training), action_spec)
        restored = squasher.unflatten(scaled)
        if time_axis_added:
            # Drop the time axis that was inserted above.
            restored = tf.squeeze(restored, axis=1)
        return restored

    flat_actions = [
        _action_from(action_layer, action_spec)
        for action_layer, action_spec in zip(
            self._action_layers, self._flat_action_spec)
    ]
    output_actions = tf.nest.pack_sequence_as(
        self._output_tensor_spec, flat_actions)
    return output_actions, network_state
def average_outer_dims(tensor, spec):
    """Average a tensor over its outer dimensions.

    Args:
        tensor (tf.Tensor): a single Tensor.
        spec (tf.TensorSpec): spec compared against `tensor` to determine how
            many leading dimensions count as outer dimensions.
    Returns:
        The mean over axis 0 of `tensor` after its outer dims have been
        squashed by `BatchSquash`.
    """
    num_outer_dims = get_outer_rank(tensor, spec)
    squashed = BatchSquash(num_outer_dims).flatten(tensor)
    return tf.reduce_mean(squashed, axis=0)
def call(self, observation, step_type=None, network_state=()):
    """Compute a value tensor from the first entry of the observation nest.

    Args:
        observation: Nest of tensors; only the first flattened entry is used,
            cast to float32.
        step_type: Accepted for interface compatibility; not read here.
        network_state: Returned unchanged.

    Returns:
        A tuple `(value, network_state)`. `value` is the output of the
        post-processing layers reshaped to 1-D and then unflattened back to
        the observation's outer dims.
    """
    num_outer_dims = nest_utils.get_outer_rank(
        observation, self.input_tensor_spec)
    squasher = utils.BatchSquash(num_outer_dims)

    first_entry = tf.nest.flatten(observation)[0]
    hidden = squasher.flatten(tf.cast(first_entry, tf.float32))
    for layer in self._postprocessing_layers:
        hidden = layer(hidden)

    flat_value = tf.reshape(hidden, [-1])
    return squasher.unflatten(flat_value), network_state
def call(self, observation, step_type, network_state=(), training=False):
    """Apply the network.

    Args:
        observation: A tuple of tensors matching `input_tensor_spec`.
        step_type: A tensor of `StepType`.
        network_state: (optional.) The network state.
        training: Whether the output is being used for training.

    Returns:
        `(outputs, network_state)` - the network output and next network
        state.

    Raises:
        ValueError: If observation tensors lack outer `(batch,)` or
            `(batch, time)` axes.
    """
    outer_rank = nest_utils.get_outer_rank(
        observation, self.input_tensor_spec)
    if outer_rank not in (1, 2):
        raise ValueError(
            'Input observation must have a batch or batch x time outer shape.'
        )

    time_axis_added = outer_rank != 2
    if time_axis_added:
        # Insert a time axis at position 1 so the unroll sees a sequence.
        def _add_time_axis(tensor):
            return tf.expand_dims(tensor, 1)

        observation = tf.nest.map_structure(_add_time_axis, observation)
        step_type = tf.nest.map_structure(_add_time_axis, step_type)

    # The input encoder is called with an empty state; its returned state is
    # discarded.
    encoded, _ = self._input_encoder(
        observation,
        step_type=step_type,
        network_state=(),
        training=training)

    with tf.name_scope('reset_mask'):
        reset_mask = tf.equal(step_type, time_step.StepType.FIRST)

    encoded, network_state = self._dynamic_unroll(
        encoded,
        reset_mask,
        initial_state=network_state,
        training=training)

    for output_layer in self._output_encoder:
        encoded = output_layer(encoded, training=training)

    if time_axis_added:
        # Drop the time axis that was inserted above.
        encoded = tf.squeeze(encoded, [1])
    return encoded, network_state
def train_step(self, inputs, state=None):
    """Perform training on one batch of inputs.

    Args:
        inputs (tuple(Tensor, Tensor)): tuple of x and y
        state: not used
    Returns:
        AlgorithmStep
            outputs (Tensor): per-sample MI estimate, unflattened back to the
                outer dims of x; its mean is the estimated MI
            state: not used
            info (LossInfo): info.loss is the loss
    Raises:
        ValueError: if `self._type` is not one of 'DV', 'KLD' or 'JSD'.
    """
    x, y = inputs
    num_outer_dims = get_outer_rank(x, self._x_spec)
    batch_squash = BatchSquash(num_outer_dims)
    x = batch_squash.flatten(x)
    y = batch_squash.flatten(y)
    # (x1, y1) come from the sampler; (x, y) are the given pairs.
    x1, y1 = self._sampler(x, y)
    log_ratio = self._model([x, y])[0]
    t1 = self._model([x1, y1])[0]
    if self._type == 'DV':
        # t1 is capped at 20 before exponentiation.
        ratio = tf.math.exp(tf.minimum(t1, 20))
        mean = tf.stop_gradient(tf.reduce_mean(ratio))
        if self._mean_averager:
            self._mean_averager.update(mean)
            unbiased_mean = tf.stop_gradient(self._mean_averager.get())
        else:
            unbiased_mean = mean
        # estimated MI = reduce_mean(mi)
        # ratio/mean-1 does not contribute to the final estimated MI, since
        # mean(ratio/mean-1) = 0. We add it so that we can have an estimation
        # of the variance of the MI estimator
        mi = log_ratio - (tf.math.log(mean) + ratio / mean - 1)
        loss = ratio / unbiased_mean - log_ratio
    elif self._type == 'KLD':
        ratio = tf.math.exp(tf.minimum(t1, 20))
        mi = log_ratio - ratio + 1
        loss = -mi
    elif self._type == 'JSD':
        mi = -tf.nn.softplus(-log_ratio) - tf.nn.softplus(t1) + math.log(4)
        loss = -mi
    else:
        # Without this branch an unknown type would surface later as an
        # unbound-local error on `mi`, hiding the real cause.
        raise ValueError(
            "Unsupported estimator type: %r (expected 'DV', 'KLD' or 'JSD')"
            % (self._type,))
    mi = batch_squash.unflatten(mi)
    loss = batch_squash.unflatten(loss)
    return AlgorithmStep(
        outputs=mi, state=(), info=LossInfo(loss, extra=()))
def call(self, observations, step_type, network_state, training=False):
    """Encode observations, generate an action state and project it.

    Args:
        observations: Input to the encoder; also passed, together with the
            encoder output, to `self._action_generator`.
        step_type: Step type forwarded to the encoder.
        network_state: State forwarded to the encoder.
        training: Forwarded to the encoder only.

    Returns:
        A tuple `(output_actions, network_state)`. `output_actions` mirrors
        `self._projection_networks` and holds whatever each projection
        network returns.
    """
    encoded, network_state = self._encoder(
        observations,
        step_type=step_type,
        network_state=network_state,
        training=training)
    num_outer_dims = nest_utils.get_outer_rank(
        observations, self.input_tensor_spec)

    # The generator is fed a float64 copy of the encoder output and its
    # result is cast back to float32 for the projection networks.
    latent = tf.dtypes.cast(encoded, dtype=tf.float64)
    generated = self._action_generator((observations, latent))
    generated = tf.dtypes.cast(generated, dtype=tf.float32)

    def _project(projection_network):
        return projection_network(generated, num_outer_dims)

    output_actions = tf.nest.map_structure(
        _project, self._projection_networks)
    return output_actions, network_state
def _kl_divergence(self, time_steps, action_distribution_parameters,
                   current_policy_distribution):
    """Compute the nested KL divergence between old and current policies.

    Args:
        time_steps: Time steps whose outer rank (relative to
            `self.time_step_spec`) determines the `outer_dims` argument.
        action_distribution_parameters: Parameters used to rebuild the old
            action distributions from `self._action_distribution_spec`.
        current_policy_distribution: Distribution(s) of the current policy.

    Returns:
        The result of `ppo_utils.nested_kl_divergence` for the old versus
        current distributions.
    """
    num_outer_dims = nest_utils.get_outer_rank(
        time_steps, self.time_step_spec)
    previous_distribution = (
        distribution_spec.nested_distributions_from_specs(
            self._action_distribution_spec,
            action_distribution_parameters))
    return ppo_utils.nested_kl_divergence(
        previous_distribution,
        current_policy_distribution,
        outer_dims=list(range(num_outer_dims)))
def call(self, observations, step_type, network_state, training=False):
    """Optionally drop the first two features, then encode and project.

    When `self._mask_xy` is set, the first two entries along the feature axis
    of `observations["observation"]` are removed: `[2:]` for a rank-1 tensor,
    `[:, 2:]` otherwise when the leading dimension is not 0. The slicing is
    applied to a shallow copy, so the caller's mapping is left untouched and
    reusing it across calls does not slice it again.

    Args:
        observations: Mapping containing an `"observation"` entry.
        step_type: Step type forwarded to the encoder.
        network_state: State forwarded to the encoder.
        training: Forwarded to the encoder only.

    Returns:
        A tuple `(output_actions, network_state)`. `output_actions` mirrors
        `self._projection_networks` and holds whatever each projection
        network returns.
    """
    import copy

    if self._mask_xy:
        features = observations["observation"]
        if len(features.shape) == 1:
            features = features[2:]
        elif features.shape[0] != 0:
            features = features[:, 2:]
        # Shallow copy keeps the mapping type while protecting the caller's
        # object from in-place modification.
        observations = copy.copy(observations)
        observations["observation"] = features

    state, network_state = self._encoder(
        observations,
        step_type=step_type,
        network_state=network_state,
        training=training)
    outer_rank = nest_utils.get_outer_rank(
        observations, self.input_tensor_spec)
    output_actions = tf.nest.map_structure(
        lambda proj_net: proj_net(state, outer_rank),
        self._projection_networks)
    return output_actions, network_state
def call(self, observation, step_type=None, network_state=()):
    """Compute a value tensor from the first entry of the observation nest.

    Args:
        observation: Nest of tensors; only the first flattened entry is used,
            cast to float32.
        step_type: Unused.
        network_state: Returned unchanged.

    Returns:
        A tuple `(value, network_state)`. `value` is the output of
        `self.layers` reshaped to 1-D and then unflattened back to the
        observation's outer dims.
    """
    del step_type  # unused.

    num_outer_dims = nest_utils.get_outer_rank(
        observation, self.observation_spec)
    squasher = utils.BatchSquash(num_outer_dims)

    first_entry = nest.flatten(observation)[0]
    hidden = squasher.flatten(tf.cast(first_entry, tf.float32))
    for layer in self.layers:
        hidden = layer(hidden)

    flat_value = tf.reshape(hidden, [-1])
    return squasher.unflatten(flat_value), network_state
def _kl_divergence(self, time_steps, action_distribution_parameters,
                   current_policy_distribution):
    """Compute the nested KL divergence between old and current policies.

    Args:
        time_steps: Time steps whose outer rank (relative to
            `self.time_step_spec`) determines the `outer_dims` argument.
        action_distribution_parameters: Mapping whose `"dist_params"` entry
            is used to rebuild the old action distributions.
        current_policy_distribution: Distribution(s) of the current policy.

    Returns:
        The result of `ppo_utils.nested_kl_divergence` for the old versus
        current distributions.
    """
    num_outer_dims = nest_utils.get_outer_rank(
        time_steps, self.time_step_spec)
    previous_params = action_distribution_parameters["dist_params"]
    previous_distribution = (
        distribution_spec.nested_distributions_from_specs(
            self._action_distribution_spec, previous_params))
    return ppo_utils.nested_kl_divergence(
        previous_distribution,
        current_policy_distribution,
        outer_dims=list(range(num_outer_dims)))
def call(self, inputs, step_type=None, network_state=(), training=False):
    """Run the encoder, optionally rescaling inputs and squashing outer dims.

    Args:
        inputs: Encoder input. When `self._uint8_input` is set it is cast to
            float32 and divided by 255.
        step_type: Unused.
        network_state: Returned unchanged.
        training: Forwarded to the encoder.

    Returns:
        A tuple `(states, network_state)` where `states` is the encoder
        output, unflattened back to the input's outer dims when
        `self._batch_squash` is set.
    """
    del step_type  # unused.

    if self._uint8_input:
        inputs = tf.cast(inputs, tf.float32) / 255.00

    if not self._batch_squash:
        return self._encoder(inputs, training=training), network_state

    num_outer_dims = nest_utils.get_outer_rank(
        inputs, self.input_tensor_spec)
    squasher = utils.BatchSquash(num_outer_dims)
    squashed_inputs = tf.nest.map_structure(squasher.flatten, inputs)
    encoded = self._encoder(squashed_inputs, training=training)
    restored = tf.nest.map_structure(squasher.unflatten, encoded)
    return restored, network_state
def call(self, observations, step_type=(), network_state=()):
    """Encode squashed observations and project them to a single action.

    Args:
        observations: Nest of tensors; every entry is passed through
            `BatchSquash.flatten` before encoding.
        step_type: Step type forwarded to the encoder.
        network_state: State forwarded to the encoder.

    Returns:
        A tuple `(actions, network_state)` with `actions` packed like
        `self._action_spec`.
    """
    num_outer_dims = nest_utils.get_outer_rank(
        observations, self.input_tensor_spec)
    # Squash outer dims in case observations carry a time sequence component.
    squasher = utils.BatchSquash(num_outer_dims)
    squashed_observations = tf.nest.map_structure(
        squasher.flatten, observations)

    encoded, network_state = self._encoder(
        squashed_observations,
        step_type=step_type,
        network_state=network_state)
    actions = self._action_projection_layer(encoded)
    # NOTE: scaling to the action spec and unflattening the outer dims are
    # intentionally not performed here; the projection output is returned
    # as produced.
    packed_actions = tf.nest.pack_sequence_as(self._action_spec, [actions])
    return packed_actions, network_state
def call(self, observation, step_type, network_state=None):
    """Run input layers, an RNN unroll, output layers and projections.

    Args:
        observation: Nest of tensors; only the first flattened entry is used.
            Must have one or two outer dimensions relative to
            `self.observation_spec`.
        step_type: Step types, compared with `StepType.FIRST` to build the
            reset mask for the unroll.
        network_state: Initial state for `rnn_utils.dynamic_unroll`.

    Returns:
        A tuple `(outputs, network_state)` with `outputs` packed like
        `self._action_spec`.

    Raises:
        ValueError: If the observation's outer rank is neither 1 nor 2.
    """
    num_outer_dims = nest_utils.get_outer_rank(
        observation, self.observation_spec)
    if num_outer_dims not in (1, 2):
        raise ValueError(
            'Input observation must have a batch or batch x time outer shape.'
        )

    has_time_dim = num_outer_dims == 2
    if not has_time_dim:
        # Add a time dimension to the inputs.
        observation = nest.map_structure(
            lambda t: tf.expand_dims(t, 1), observation)
        step_type = nest.map_structure(
            lambda t: tf.expand_dims(t, 1), step_type)

    # tf.cast replaces the deprecated tf.to_float, which is not available in
    # the TF2 namespace; the result is the same float32 tensor.
    states = tf.cast(nest.flatten(observation)[0], tf.float32)

    batch_squash = utils.BatchSquash(2)  # Squash B, and T dims.
    states = batch_squash.flatten(states)
    for layer in self._input_layers:
        states = layer(states)
    states = batch_squash.unflatten(states)

    with tf.name_scope('reset_mask'):
        reset_mask = tf.equal(step_type, time_step.StepType.FIRST)

    # Unroll over the time sequence.
    states, network_state, _ = rnn_utils.dynamic_unroll(
        self._cell,
        states,
        reset_mask,
        initial_state=network_state,
        dtype=tf.float32)

    states = batch_squash.flatten(states)
    for layer in self._output_layers:
        states = layer(states)
    states = batch_squash.unflatten(states)

    # NOTE(review): when the input had no time dimension, `states` still
    # carries the axis added above while `num_outer_dims` is 1 - confirm the
    # projection networks expect that.
    outputs = [
        projection(states, num_outer_dims)
        for projection in self._projection_networks
    ]
    return nest.pack_sequence_as(self._action_spec, outputs), network_state
def call(self, observation, step_type, network_state=None):
    """Run the input encoder, an RNN unroll and the output encoder layers.

    Args:
        observation: Nest of tensors; only the first flattened entry is used.
            Must have one or two outer dimensions relative to
            `self._observation_spec`.
        step_type: Step types, forwarded to the input encoder and compared
            with `StepType.FIRST` to build the reset mask for the unroll.
        network_state: State passed to the input encoder; the state it
            returns becomes the initial state of the unroll.

    Returns:
        A tuple `(state, network_state)`.

    Raises:
        ValueError: If the observation's outer rank is neither 1 nor 2, or
            if the observation tensor has fewer dims than the expected number
            of feature dims.
    """
    num_outer_dims = nest_utils.get_outer_rank(
        observation, self._observation_spec)
    if num_outer_dims not in (1, 2):
        raise ValueError(
            'Input observation must have a batch or batch x time outer shape.'
        )

    has_time_dim = num_outer_dims == 2
    if not has_time_dim:
        # Add a time dimension to the inputs.
        observation = nest.map_structure(
            lambda t: tf.expand_dims(t, 1), observation)
        step_type = nest.map_structure(
            lambda t: tf.expand_dims(t, 1), step_type)

    # tf.cast replaces the deprecated tf.to_float, which is not available in
    # the TF2 namespace; the result is the same float32 tensor.
    state = tf.cast(nest.flatten(observation)[0], tf.float32)

    # Three trailing feature dims when conv layers are configured, else one.
    num_feature_dims = 3 if self._conv_layer_params else 1
    # Called for validation only: raises if the rank is too small.
    state.shape.with_rank_at_least(num_feature_dims)
    batch_squash = utils.BatchSquash(state.shape.ndims - num_feature_dims)

    state = batch_squash.flatten(state)
    state, network_state = self._input_encoder(
        state, step_type, network_state)
    state = batch_squash.unflatten(state)

    with tf.name_scope('reset_mask'):
        reset_mask = tf.equal(step_type, time_step.StepType.FIRST)

    # Unroll over the time sequence.
    state, network_state, _ = rnn_utils.dynamic_unroll(
        self._cell,
        state,
        reset_mask,
        initial_state=network_state,
        dtype=tf.float32)

    state = batch_squash.flatten(state)
    for layer in self._output_encoder:
        state = layer(state)
    state = batch_squash.unflatten(state)

    if not has_time_dim:
        # Remove time dimension from the state.
        state = tf.squeeze(state, [1])
    return state, network_state
def _maybe_reset_state(self, time_step, policy_state):
    """Replace the policy state with the initial state on first steps.

    Args:
        time_step: Time step(s); `discount` supplies the batch size and
            `is_first()` supplies the reset condition.
        policy_state: Current policy state. An empty tuple means there is no
            state to reset and is returned as is.

    Returns:
        `policy_state` unchanged when it is an empty tuple; otherwise the
        result of `nest_utils.where(condition, zero_state, policy_state)`.
    """
    # Type-based emptiness check instead of `policy_state is ()`: identity
    # comparison with a literal triggers a SyntaxWarning on Python >= 3.8.
    # `==` is avoided because policy_state may hold tensors.
    if isinstance(policy_state, tuple) and not policy_state:
        return policy_state

    batch_size = tf.compat.dimension_value(time_step.discount.shape[0])
    if batch_size is None:
        # Static batch size unknown; fall back to the dynamic shape.
        batch_size = tf.shape(time_step.discount)[0]

    # Make sure we call this with a kwarg as it may be wrapped in tf.function
    # which would expect a tensor if it was not a kwarg.
    zero_state = self.get_initial_state(batch_size=batch_size)

    condition = time_step.is_first()
    # When experience is a sequence we only reset automatically for the first
    # time_step in the sequence as we can't easily generalize how the policy
    # is unrolled over the sequence.
    if nest_utils.get_outer_rank(time_step, self._time_step_spec) > 1:
        condition = time_step.is_first()[:, 0, ...]
    return nest_utils.where(condition, zero_state, policy_state)
def call(self, observations, step_type=(), network_state=()):
    """Encode squashed observations and emit a single scaled action.

    Args:
        observations: Nest of tensors; every entry is passed through
            `BatchSquash.flatten` before encoding.
        step_type: Step type forwarded to the encoder.
        network_state: State forwarded to the encoder.

    Returns:
        A tuple `(actions, network_state)` with `actions` packed like
        `self._action_spec`.
    """
    num_outer_dims = nest_utils.get_outer_rank(
        observations, self.input_tensor_spec)
    squasher = BatchSquash(num_outer_dims)
    squashed_observations = nest.map_structure(
        squasher.flatten, observations)

    encoded, network_state = self._encoder(
        squashed_observations,
        step_type=step_type,
        network_state=network_state)

    raw_actions = self._action_projection_layer(encoded)
    scaled_actions = scale_to_spec(raw_actions, self._single_action_spec)
    restored_actions = squasher.unflatten(scaled_actions)

    packed_actions = nest.pack_sequence_as(
        self._action_spec, [restored_actions])
    return packed_actions, network_state
def call(self, observations, step_type=(), network_state=(), training=False):
    """Encode observations and project them to an action distribution.

    Args:
        observations: Input for the image encoder when one is configured;
            otherwise a mapping with `'position'` and `'velocity'` entries.
        step_type: Not read here.
        network_state: Overwritten by the image encoder (when used) and by
            the projection network before being returned.
        training: Forwarded to the image encoder and projection network.

    Returns:
        A tuple `(action_distribution, network_state)` as returned by
        `self._distribution_projection_network`.
    """
    if self._image_encoder:
        features, network_state = self._image_encoder(
            observations, training=training)
        features = self._fc_encoder(features)
    else:
        # dm_control state observations need to be flattened as they are
        # structured as a dict(position, velocity)
        joined = tf.keras.layers.concatenate(
            [observations['position'], observations['velocity']])
        features = self._dense_layers(joined)

    num_outer_dims = nest_utils.get_outer_rank(
        observations, self.input_tensor_spec)
    projection = self._distribution_projection_network(
        features, num_outer_dims, training=training)
    action_distribution, network_state = projection
    return action_distribution, network_state
def _normalize(m, m2, spec, t):
    """Normalize `t` with mean `m` and variance `m2 - m**2`, then clip.

    Reads `self._variance_epsilon` and `clip_value` from the enclosing scope.
    Clipping to `[-clip_value, clip_value]` is applied only when
    `clip_value > 0`.
    """
    # Floating-point error can make m2 - m^2 slightly negative; relu clamps
    # the variance at zero in that case.
    variance = tf.nn.relu(m2 - tf.square(m))

    squasher = BatchSquash(get_outer_rank(t, spec))
    normalized = tf.nn.batch_normalization(
        squasher.flatten(t),
        m,
        variance,
        offset=None,
        scale=None,
        variance_epsilon=self._variance_epsilon)
    if clip_value > 0:
        normalized = tf.clip_by_value(normalized, -clip_value, clip_value)
    return squasher.unflatten(normalized)
def call(self, observations, step_type, network_state):
    """Run the MLP layers on the first observation and project the result.

    Args:
        observations: Nest of tensors; only the first flattened entry is
            used, cast to float32.
        step_type: Unused.
        network_state: Returned unchanged.

    Returns:
        A tuple `(output_actions, network_state)`. `output_actions` mirrors
        `self._projection_networks` and holds whatever each projection
        network returns.
    """
    del step_type  # unused.

    num_outer_dims = nest_utils.get_outer_rank(
        observations, self.input_tensor_spec)
    first_entry = tf.nest.flatten(observations)[0]
    hidden = tf.cast(first_entry, tf.float32)

    # The MLP layers are applied on squashed outer dims.
    squasher = utils.BatchSquash(num_outer_dims)
    hidden = squasher.flatten(hidden)
    for mlp_layer in self._mlp_layers:
        hidden = mlp_layer(hidden)

    # TODO(oars): Can we avoid unflattening to flatten again
    hidden = squasher.unflatten(hidden)

    def _project(projection_network):
        return projection_network(hidden, num_outer_dims)

    output_actions = tf.nest.map_structure(
        _project, self._projection_networks)
    return output_actions, network_state
def call(self, observations, step_type, network_state, training=False, mask=None):
    """Embed observations with the GNN, encode node 0 and project to actions.

    Args:
        observations: Observation tensor. Rank-1 and rank-2 inputs are
            reshaped to `[1, -1]`; rank-3 inputs have axis 0 squeezed away.
        step_type: Step type forwarded to the encoder.
        network_state: State forwarded to the encoder.
        training: Forwarded to the GNN, the encoder and the projection
            networks.
        mask: Forwarded to every projection network.

    Returns:
        A tuple `(output_actions, network_state)`. `output_actions` mirrors
        `self._projection_networks` and holds the first of the two values
        returned by each projection network.
    """
    # NOTE(review): the rank checks use Python `len()` on `tf.shape(...)`,
    # which presumably relies on eager execution or autograph - confirm
    # before restructuring these conditions.
    if len(tf.shape(observations)) == 2 or len(
            tf.shape(observations)) == 1:
        observations = tf.reshape(observations, [1, -1])
    if len(tf.shape(observations)) == 3:
        observations = tf.squeeze(observations, axis=0)
    embeddings = self._gnn(observations, training=training)
    # extract ego state (node 0); skipped when the leading dim is empty
    if tf.shape(embeddings)[0] > 0:
        embeddings = embeddings[:, 0]
    with tf.name_scope("PPOActorNetwork"):
        tf.summary.histogram("embedding", embeddings)
    state, network_state = self._encoder(embeddings, step_type=step_type, network_state=network_state, training=training)
    # The outer rank is computed on the reshaped/squeezed observations.
    outer_rank = nest_utils.get_outer_rank(observations, self.input_tensor_spec)
    def call_projection_net(proj_net):
        # Keep the distribution; the second returned value is discarded.
        distribution, _ = proj_net(state, outer_rank, training=training, mask=mask)
        return distribution
    output_actions = tf.nest.map_structure(call_projection_net, self._projection_networks)
    return output_actions, network_state
def call(self, observation, step_type=None, network_state=(), training=False):
    """Runs the given observation through the network.

    Args:
        observation: The observation to provide to the network, shaped
            [batch_size, seq_len, ...] or [batch_size, ...].
        step_type: The step type for the given observation. See `StepType`
            in time_step.py.
        network_state: A state tuple to pass to the network, mainly used by
            RNNs.
        training: Whether the output is being used for training.

    Returns:
        A tuple `(q_value, network_state)`.
    """
    num_outer_dims = nest_utils.get_outer_rank(
        observation, self.input_tensor_spec)
    has_time_axis = num_outer_dims == 2

    # Without a time axis the sequence length is taken to be 1.
    seq_length = observation.shape[1] if has_time_axis else 1
    causal_mask = self._create_look_ahead_mask(seq_length)  # (seq_len, seq_len)

    encoded, network_state = self._encoder(
        observation,
        step_type,
        network_state=network_state,
        training=training,
        mask=causal_mask)
    q_value = self._q_value_layer(encoded, training=training)

    # Outside of training, and when configured to output the last state, the
    # time axis is squeezed away so the result is (batch_size, ...) rather
    # than (batch_size, 1, ...).
    if not training and self._output_last_state and has_time_axis:
        q_value = tf.squeeze(q_value, axis=1)
    return q_value, network_state
def call(self, observation, step_type=None, network_state=()):
    """Apply `self.layers` to the first entry of the observation nest.

    Args:
        observation: Nest of tensors; only the first flattened entry is used,
            cast to float32.
        step_type: Unused.
        network_state: Returned unchanged.

    Returns:
        A tuple `(states, network_state)`. When `self._batch_squash` is set,
        outer dims are squashed before the layers and restored afterwards.
    """
    del step_type  # unused.

    if self._batch_squash:
        outer_rank = nest_utils.get_outer_rank(
            observation, self.observation_spec)
        batch_squash = utils.BatchSquash(outer_rank)

    # Get single observation out regardless of nesting. tf.cast replaces the
    # deprecated tf.to_float, which is not available in the TF2 namespace.
    states = tf.cast(nest.flatten(observation)[0], tf.float32)

    if self._batch_squash:
        states = batch_squash.flatten(states)

    for layer in self.layers:
        states = layer(states)

    if self._batch_squash:
        states = batch_squash.unflatten(states)
    return states, network_state
def call(self, observations, step_type, network_state=(), training=False):
    """Pre-encode observation and alpha, then encode jointly and project.

    Args:
        observations: A 3-tuple `(obs, ac, alpha)`.
        step_type: Step type forwarded to all three encoders.
        network_state: State forwarded to all three encoders; only the state
            returned by the joint encoder is kept.
        training: Forwarded to all three encoders.

    Returns:
        A tuple `(q_distribution, network_state)`.
    """
    obs, ac, alpha = observations

    # All three encoder calls receive the same incoming step type and state.
    shared_kwargs = dict(
        step_type=step_type,
        network_state=network_state,
        training=training)

    pre_obs, _ = self._obs_encoder(obs, **shared_kwargs)
    pre_alpha, _ = self._alph_encoder(alpha, **shared_kwargs)

    observations = (pre_obs, ac, pre_alpha)
    state, network_state = self._encoder(observations, **shared_kwargs)

    num_outer_dims = nest_utils.get_outer_rank(
        observations, self.encoder_input_tensor_spec)
    q_distribution, _ = self._projection_network(state, num_outer_dims)
    return q_distribution, network_state
def call(self, observations, step_type=(), network_state=()):
    """Combine flattened and CNN features through two dense layers.

    Args:
        observations: Nest of tensors; every entry is passed through
            `BatchSquash.flatten` before the layers are applied.
        step_type: Not read here.
        network_state: Returned unchanged.

    Returns:
        A tuple `(actions, network_state)` with `actions` packed like
        `self._action_spec`.
    """
    num_outer_dims = nest_utils.get_outer_rank(
        observations, self.input_tensor_spec)
    # Squash outer dims in case observations carry a time sequence component.
    squasher = utils.BatchSquash(num_outer_dims)
    squashed = tf.nest.map_structure(squasher.flatten, observations)

    raw_features = self._flat_x(squashed)
    conv_features = self._flat_cnn(self._cnn_1(squashed))

    # The flattened raw features are concatenated in twice: once with the
    # CNN features, and again with the second dense layer's output.
    merged = tf.keras.layers.concatenate([raw_features, conv_features])
    hidden = self._dense_2(self._dense_1(merged))
    head_input = tf.keras.layers.concatenate([raw_features, hidden])

    actions = self._action_projection_layer(head_input)
    packed_actions = tf.nest.pack_sequence_as(self._action_spec, [actions])
    return packed_actions, network_state
def call(self, inputs, step_type=None, network_state=(), training=False):
    """Produce `(loc, scale_diag)` from the concatenated inputs.

    Args:
        inputs: Tensors concatenated along the last axis before encoding.
        step_type: Unused.
        network_state: Returned unchanged.
        training: Forwarded to `self._fc_encoder`.

    Returns:
        A tuple `((loc, scale_diag), network_state)`. `loc` is the first
        `self.output_dim` channels of the encoder output. `scale_diag` comes
        from the remaining channels when `self.scale` is None, and is the
        constant `self.scale` broadcast to the shape of `loc` otherwise.
    """
    del step_type  # unused.

    squash_enabled = self._batch_squash
    if squash_enabled:
        num_outer_dims = nest_utils.get_outer_rank(
            inputs, self.input_tensor_spec)
        squasher = utils.BatchSquash(num_outer_dims)
        inputs = tf.nest.map_structure(squasher.flatten, inputs)

    features = self._fc_encoder(
        tf.concat(inputs, axis=-1), training=training)

    if squash_enabled:
        features = tf.nest.map_structure(squasher.unflatten, features)

    loc = features[..., :self.output_dim]
    if self.scale is not None:
        scale_diag = tf.ones_like(loc) * self.scale
    else:
        raw_scale = features[..., self.output_dim:]
        # softplus(0) = ln 2 ~= 0.693, so the multiplier is approximately 1;
        # the small constant keeps the scale strictly positive.
        scale_diag = tf.nn.softplus(raw_scale)
        scale_diag = scale_diag * (0.693 / tf.nn.softplus(0.))
        scale_diag = scale_diag + 1e-6
    return (loc, scale_diag), network_state
def call(self, observation, step_type, network_state=None):
    """Encode the observation and map it through every action layer.

    Args:
        observation: Input to `self._lstm_encoder`; its outer rank relative
            to `self.input_tensor_spec` configures the batch squashing.
        step_type: Step type forwarded to the encoder.
        network_state: State forwarded to the encoder.

    Returns:
        A tuple `(output_actions, network_state)` with `output_actions`
        packed like `self._output_tensor_spec`.
    """
    num_outer_dims = nest_utils.get_outer_rank(
        observation, self.input_tensor_spec)
    squasher = utils.BatchSquash(num_outer_dims)

    encoded, network_state = self._lstm_encoder(
        observation, step_type=step_type, network_state=network_state)
    hidden = squasher.flatten(encoded)

    # One action per (layer, spec) pair: project, scale to the spec, then
    # restore the outer dims.
    flat_actions = [
        squasher.unflatten(
            common.scale_to_spec(action_layer(hidden), action_spec))
        for action_layer, action_spec in zip(
            self._action_layers, self._flat_action_spec)
    ]
    output_actions = tf.nest.pack_sequence_as(
        self._output_tensor_spec, flat_actions)
    return output_actions, network_state
def call(self, observation, step_type=None, network_state=(), training=False):
    """Preprocess, combine and post-process the observation.

    Args:
        observation: Nest of tensors. When `self._batch_squash` is set every
            entry has its outer dims squashed first.
        step_type: Unused.
        network_state: Returned unchanged.
        training: Forwarded to preprocessing and postprocessing layers.

    Returns:
        A tuple `(states, network_state)`.
    """
    del step_type  # unused.

    squash_enabled = self._batch_squash
    if squash_enabled:
        num_outer_dims = nest_utils.get_outer_rank(
            observation, self.input_tensor_spec)
        squasher = utils.BatchSquash(num_outer_dims)
        observation = tf.nest.map_structure(squasher.flatten, observation)

    if self._flat_preprocessing_layers is None:
        states = observation
    else:
        flat_observations = nest.flatten_up_to(
            self._preprocessing_nest, observation, check_types=False)
        states = [
            preprocessing_layer(entry, training=training)
            for entry, preprocessing_layer in zip(
                flat_observations, self._flat_preprocessing_layers)
        ]
        # If only one observation is passed and the preprocessing_combiner
        # is unspecified, use the preprocessed version of this observation.
        if len(states) == 1 and self._preprocessing_combiner is None:
            states = states[0]

    if self._preprocessing_combiner is not None:
        states = self._preprocessing_combiner(states)

    for postprocessing_layer in self._postprocessing_layers:
        states = postprocessing_layer(states, training=training)

    if squash_enabled:
        states = tf.nest.map_structure(squasher.unflatten, states)
    return states, network_state
def call(self, observation, step_type=None, network_state=None):
    """Run input layers, a recurrent unroll, output layers and a value head.

    Args:
        observation: Nest of tensors; only the first flattened entry is used.
            Must have one or two outer dimensions relative to
            `self.input_tensor_spec`.
        step_type: Step types, compared with `StepType.FIRST` to build the
            reset mask for the unroll.
        network_state: Initial state passed to `self._dynamic_unroll`.

    Returns:
        A tuple `(value, network_state)`.

    Raises:
        ValueError: If the observation's outer rank is neither 1 nor 2.
    """
    outer_rank = nest_utils.get_outer_rank(
        observation, self.input_tensor_spec)
    if outer_rank not in (1, 2):
        raise ValueError(
            'Input observation must have a batch or batch x time outer shape.')

    if outer_rank != 2:
        # Inputs without a time axis get one inserted at position 1.
        def _add_time_axis(tensor):
            return tf.expand_dims(tensor, 1)

        observation = tf.nest.map_structure(_add_time_axis, observation)
        step_type = tf.nest.map_structure(_add_time_axis, step_type)

    hidden = tf.cast(tf.nest.flatten(observation)[0], tf.float32)

    squasher = utils.BatchSquash(2)  # Squash B, and T dims.
    hidden = squasher.flatten(hidden)
    for input_layer in self._input_layers:
        hidden = input_layer(hidden)
    hidden = squasher.unflatten(hidden)

    with tf.name_scope('reset_mask'):
        reset_mask = tf.equal(step_type, time_step.StepType.FIRST)

    # Unroll over the time sequence.
    hidden, network_state = self._dynamic_unroll(
        hidden, reset_mask, initial_state=network_state)

    hidden = squasher.flatten(hidden)
    for output_layer in self._output_layers:
        hidden = output_layer(hidden)

    flat_value = tf.reshape(self._value_projection_layer(hidden), [-1])
    # NOTE(review): the time axis inserted above is not removed again, so an
    # input without a time dimension presumably yields a value that still
    # has that axis - confirm this is what callers expect.
    value = squasher.unflatten(flat_value)
    return value, network_state
def call(
    self,
    observations: Union[tf.Tensor, np.ndarray],
    step_type: Optional[Any],
    network_state: Union[Tuple, Tuple[Union[tf.Tensor, np.ndarray]]] = ()
) -> Tuple[Union[tfp.distributions.OneHotCategorical,
                 Tuple[tfp.distributions.OneHotCategorical]],
           Union[Tuple, Tuple[Union[tf.Tensor, np.ndarray]]]]:
    """
    Run a forward pass of the action network, mapping observations to a distribution over
    actions.

    :param observations: Tensor/Array of observation values from the environment.
    :param step_type: Not used in this network. Kept as an argument to be consistent with the
        standard TensorFlow Agents interface.
    :param network_state: The state of the network. It is returned unchanged.
    :return: A distribution over actions (or a nest of them, one per action head) and the
        network state.
    """
    # The shared layers produce the features that every head consumes.
    shared_features = tf.cast(observations, tf.float32)
    for shared_layer in self._shared_layers:
        shared_features = shared_layer(shared_features)

    # The number of batch dimensions is derived from the raw observations, since it requires
    # comparison against the input tensor spec.
    num_outer_dims = nest_utils.get_outer_rank(observations, self.input_tensor_spec)

    def _head_distribution(action_head):
        # Each head returns a sequence whose first element is the distribution.
        return action_head(shared_features, num_outer_dims)[0]

    action_dist = tf.nest.map_structure(_head_distribution, self._action_heads)

    # A single action head is unpacked so that the distribution itself is returned rather
    # than a 1-tuple.
    single_head = len(self._action_subspace_dimensions) == 1
    if single_head:
        action_dist = action_dist[0]

    return action_dist, network_state