コード例 #1
0
 def call(self, observation, step_type, network_state=(), training=False):
     """Encode `observation` with the LSTM encoder and project it to actions.

     Args:
       observation: Nested observation matching `self.input_tensor_spec`.
       step_type: Step type(s) forwarded to the LSTM encoder.
       network_state: Recurrent state forwarded to the LSTM encoder.
       training: Whether the call is part of training.

     Returns:
       `(output_actions, network_state)`: `output_actions` mirrors the
       structure of `self._projection_networks` and holds the first element
       returned by each projection network; `network_state` is the LSTM
       encoder's new state.
     """
     encoded, next_state = self._lstm_encoder(
         observation,
         step_type=step_type,
         network_state=network_state,
         training=training)
     rank = nest_utils.get_outer_rank(observation, self.input_tensor_spec)

     def _project(projection_net):
         return projection_net(encoded, rank, training=training)[0]

     actions = tf.nest.map_structure(_project, self._projection_networks)
     return actions, next_state
コード例 #2
0
    def call(self, observation, step_type, network_state=(), training=False):
        """Run the recurrent actor and return one spec-scaled action per head.

        Only the first entry of the flattened `observation` nest is used.
        Inputs without a time dimension get one added (and removed again from
        each action) so the same unroll path serves both cases.

        Args:
            observation: Nested observation whose outer shape is [B] or [B, T].
            step_type: Step types; a time dim is added alongside the
                observation's when the input has none.
            network_state: Initial recurrent state for `self._dynamic_unroll`.
            training: Forwarded to every layer and to the unroll.

        Returns:
            `(output_actions, network_state)`: actions packed into
            `self._output_tensor_spec` and the state after the unroll.

        Raises:
            ValueError: if the observation's outer rank is neither 1 nor 2.
        """
        num_outer_dims = nest_utils.get_outer_rank(observation,
                                                   self.input_tensor_spec)
        if num_outer_dims not in (1, 2):
            raise ValueError(
                'Input observation must have a batch or batch x time outer shape.'
            )

        has_time_dim = num_outer_dims == 2
        if not has_time_dim:
            # Add a time dimension to the inputs.
            observation = tf.nest.map_structure(lambda t: tf.expand_dims(t, 1),
                                                observation)
            step_type = tf.nest.map_structure(lambda t: tf.expand_dims(t, 1),
                                              step_type)

        # Only the first (flattened) observation tensor feeds the network.
        states = tf.cast(tf.nest.flatten(observation)[0], tf.float32)
        batch_squash = utils.BatchSquash(2)  # Squash B, and T dims.
        states = batch_squash.flatten(states)  # [B, T, ...] -> [B x T, ...]

        for layer in self._input_layers:
            states = layer(states, training=training)

        states = batch_squash.unflatten(states)  # [B x T, ...] -> [B, T, ...]

        # The recurrent state is reset wherever a new episode starts.
        with tf.name_scope('reset_mask'):
            reset_mask = tf.equal(step_type, time_step.StepType.FIRST)
        # Unroll over the time sequence.
        states, network_state = self._dynamic_unroll(
            states,
            reset_mask=reset_mask,
            initial_state=network_state,
            training=training)

        states = batch_squash.flatten(states)  # [B, T, ...] -> [B x T, ...]

        for layer in self._output_layers:
            states = layer(states, training=training)

        # One action layer per flattened action spec; each output is scaled
        # into its spec and given back its [B, T] outer shape.
        actions = []
        for layer, spec in zip(self._action_layers, self._flat_action_spec):
            action = layer(states, training=training)
            action = common.scale_to_spec(action, spec)
            action = batch_squash.unflatten(
                action)  # [B x T, ...] -> [B, T, ...]
            if not has_time_dim:
                # Undo the time dim added at the top of this method.
                action = tf.squeeze(action, axis=1)
            actions.append(action)

        output_actions = tf.nest.pack_sequence_as(self._output_tensor_spec,
                                                  actions)
        return output_actions, network_state
コード例 #3
0
def average_outer_dims(tensor, spec):
    """Average `tensor` over all of its outer dimensions.

    The number of outer dimensions is taken from `get_outer_rank(tensor,
    spec)`; those dimensions are collapsed into a single leading axis, which
    is then reduced with a mean.

    Args:
        tensor (tf.Tensor): a single Tensor
        spec (tf.TensorSpec): spec of the inner (non-outer) part of `tensor`.

    Returns:
        the average tensor across outer dims
    """
    squasher = BatchSquash(get_outer_rank(tensor, spec))
    return tf.reduce_mean(squasher.flatten(tensor), axis=0)
コード例 #4
0
ファイル: value_network.py プロジェクト: ssghost/agents
  def call(self, observation, step_type=None, network_state=()):
    """Compute a value estimate for each outer (batch/time) element.

    Only the first entry of the flattened `observation` nest is used.

    Args:
      observation: Nested observation matching `self.input_tensor_spec`.
      step_type: Unused.
      network_state: Returned unchanged.

    Returns:
      `(value, network_state)`. The post-processing output is flattened to
      1-D and then given back the observation's outer shape (presumably the
      last layer emits one unit per element -- TODO confirm).
    """
    squash = utils.BatchSquash(
        nest_utils.get_outer_rank(observation, self.input_tensor_spec))

    hidden = squash.flatten(
        tf.cast(tf.nest.flatten(observation)[0], tf.float32))
    for postprocessing_layer in self._postprocessing_layers:
      hidden = postprocessing_layer(hidden)

    return squash.unflatten(tf.reshape(hidden, [-1])), network_state
コード例 #5
0
    def call(self, observation, step_type, network_state=(), training=False):
        """Apply the network.

        The observation is encoded by `self._input_encoder`, unrolled through
        the recurrent core with resets at episode starts, and passed through
        the layers in `self._output_encoder`. Inputs without a time dimension
        get one temporarily added so a single unroll path handles both cases.

        Args:
          observation: A nest of tensors matching `input_tensor_spec`.
          step_type: A tensor of `StepType`.
          network_state: (optional.) The network state.
          training: Whether the output is being used for training.

        Returns:
          `(outputs, network_state)` - the network output and next network state.

        Raises:
          ValueError: If observation tensors lack outer `(batch,)` or
            `(batch, time)` axes.
        """
        num_outer_dims = nest_utils.get_outer_rank(observation,
                                                   self.input_tensor_spec)
        if num_outer_dims not in (1, 2):
            raise ValueError(
                'Input observation must have a batch or batch x time outer shape.'
            )

        has_time_dim = num_outer_dims == 2
        if not has_time_dim:
            # Add a time dimension to the inputs.
            observation = tf.nest.map_structure(lambda t: tf.expand_dims(t, 1),
                                                observation)
            step_type = tf.nest.map_structure(lambda t: tf.expand_dims(t, 1),
                                              step_type)

        # The input encoder is called statelessly (network_state=()) and its
        # returned state is discarded; presumably it squashes the [B, T] outer
        # dims itself -- TODO confirm against its batch_squash setting.
        state, _ = self._input_encoder(observation,
                                       step_type=step_type,
                                       network_state=(),
                                       training=training)

        # The recurrent state is reset wherever a new episode starts.
        with tf.name_scope('reset_mask'):
            reset_mask = tf.equal(step_type, time_step.StepType.FIRST)

        # Unroll over the time sequence.
        state, network_state = self._dynamic_unroll(
            state, reset_mask, initial_state=network_state, training=training)

        for layer in self._output_encoder:
            state = layer(state, training=training)

        if not has_time_dim:
            # Remove time dimension from the state.
            state = tf.squeeze(state, [1])

        return state, network_state
コード例 #6
0
ファイル: mi_estimator.py プロジェクト: zhaoyinfu123/alf
    def train_step(self, inputs, state=None):
        """Perform training on one batch of inputs.

        Args:
            inputs (tuple(Tensor, Tensor)): tuple of x and y
            state: not used
        Returns:
            AlgorithmStep
                outputs (Tensor): shape=[batch_size], its mean is the estimated
                    MI
                state: not used
                info (LossInfo): info.loss is the loss
        Raises:
            ValueError: if ``self._type`` is not one of 'DV', 'KLD' or 'JSD'.
        """
        x, y = inputs
        num_outer_dims = get_outer_rank(x, self._x_spec)
        batch_squash = BatchSquash(num_outer_dims)
        x = batch_squash.flatten(x)
        y = batch_squash.flatten(y)
        # (x1, y1) come from `self._sampler`; presumably they are samples from
        # the product of marginals -- TODO confirm against the sampler.
        x1, y1 = self._sampler(x, y)

        log_ratio = self._model([x, y])[0]
        t1 = self._model([x1, y1])[0]

        if self._type == 'DV':
            # Cap the exponent at 20 so exp() cannot overflow to inf.
            ratio = tf.math.exp(tf.minimum(t1, 20))
            mean = tf.stop_gradient(tf.reduce_mean(ratio))
            if self._mean_averager:
                self._mean_averager.update(mean)
                unbiased_mean = tf.stop_gradient(self._mean_averager.get())
            else:
                unbiased_mean = mean
            # estimated MI = reduce_mean(mi)
            # ratio/mean-1 does not contribute to the final estimated MI, since
            # mean(ratio/mean-1) = 0. We add it so that we can have an estimation
            # of the variance of the MI estimator
            mi = log_ratio - (tf.math.log(mean) + ratio / mean - 1)
            loss = ratio / unbiased_mean - log_ratio
        elif self._type == 'KLD':
            # Same exponent cap as the DV branch.
            ratio = tf.math.exp(tf.minimum(t1, 20))
            mi = log_ratio - ratio + 1
            loss = -mi
        elif self._type == 'JSD':
            mi = -tf.nn.softplus(-log_ratio) - tf.nn.softplus(t1) + math.log(4)
            loss = -mi
        else:
            # Without this branch `mi`/`loss` would be unbound below and the
            # caller would see an opaque UnboundLocalError.
            raise ValueError(
                "Unsupported MI estimator type %r; expected 'DV', 'KLD' or "
                "'JSD'." % (self._type, ))

        mi = batch_squash.unflatten(mi)
        loss = batch_squash.unflatten(loss)

        return AlgorithmStep(outputs=mi,
                             state=(),
                             info=LossInfo(loss, extra=()))
コード例 #7
0
 def call(self, observations, step_type, network_state, training=False):
   """Encode observations, run the action generator and project to actions.

   The encoder output is cast to float64 and fed, together with the raw
   observations, to `self._action_generator`. The generator output is cast
   back to float32 before every projection network is applied to it.

   Returns:
     `(output_actions, network_state)`: `output_actions` mirrors the
     structure of `self._projection_networks` and holds each projection
     network's full return value; `network_state` comes from the encoder.
   """
   encoded, network_state = self._encoder(
       observations,
       step_type=step_type,
       network_state=network_state,
       training=training)
   rank = nest_utils.get_outer_rank(observations, self.input_tensor_spec)
   latent = tf.dtypes.cast(encoded, dtype=tf.float64)
   generated = tf.dtypes.cast(
       self._action_generator((observations, latent)), dtype=tf.float32)
   output_actions = tf.nest.map_structure(
       lambda projection: projection(generated, rank),
       self._projection_networks)
   return output_actions, network_state
コード例 #8
0
ファイル: ppo_agent.py プロジェクト: hhy5277/agents-1
  def _kl_divergence(self, time_steps, action_distribution_parameters,
                     current_policy_distribution):
    """KL divergence between the old and the current action distributions.

    Args:
      time_steps: Used only to determine how many outer dims there are.
      action_distribution_parameters: Parameters of the old policy's
        distributions, rebuilt via `self._action_distribution_spec`.
      current_policy_distribution: Distribution(s) of the current policy.

    Returns:
      The result of `ppo_utils.nested_kl_divergence` over all outer dims.
    """
    num_outer = nest_utils.get_outer_rank(time_steps, self.time_step_spec)
    old_distribution = distribution_spec.nested_distributions_from_specs(
        self._action_distribution_spec, action_distribution_parameters)
    return ppo_utils.nested_kl_divergence(
        old_distribution,
        current_policy_distribution,
        outer_dims=list(range(num_outer)))
コード例 #9
0
 def call(self, observations, step_type, network_state, training=False):
   """Optionally drop the x/y features, encode, and project to actions.

   When `self._mask_xy` is set, the first two entries of the last axis of
   `observations["observation"]` are removed (for rank-1 input or a
   non-empty batch). The caller's mapping is shallow-copied before the
   replacement, so the caller's object is not mutated and repeated calls on
   the same dict do not strip two more features each time.

   Returns:
     `(output_actions, network_state)`: `output_actions` mirrors the
     structure of `self._projection_networks`; `network_state` comes from
     the encoder.
   """
   if self._mask_xy:
     import copy  # Local import: keeps this block self-contained.
     raw = observations["observation"]
     if len(raw.shape) == 1:
       masked = raw[2:]
     elif raw.shape[0] != 0:
       masked = raw[:, 2:]
     else:
       masked = None  # Zero-sized leading dim: nothing to strip.
     if masked is not None:
       # copy.copy keeps the mapping type (dict / OrderedDict) for tf.nest.
       observations = copy.copy(observations)
       observations["observation"] = masked
   state, network_state = self._encoder(
       observations,
       step_type=step_type,
       network_state=network_state,
       training=training)
   outer_rank = nest_utils.get_outer_rank(observations, self.input_tensor_spec)
   output_actions = tf.nest.map_structure(
       lambda proj_net: proj_net(state, outer_rank), self._projection_networks)
   return output_actions, network_state
コード例 #10
0
ファイル: value_network.py プロジェクト: kuanfang/agents
    def call(self, observation, step_type=None, network_state=()):
        """Return a value estimate for every outer element of `observation`.

        Only the first tensor of the flattened observation nest is used. It is
        cast to float32, its outer dims are collapsed for `self.layers`, and the
        1-D result is given back the original outer shape.

        Args:
            observation: Nested observation matching `self.observation_spec`.
            step_type: Unused.
            network_state: Returned unchanged.

        Returns:
            `(value, network_state)`.
        """
        del step_type  # unused.

        squash = utils.BatchSquash(
            nest_utils.get_outer_rank(observation, self.observation_spec))
        hidden = squash.flatten(
            tf.cast(nest.flatten(observation)[0], tf.float32))
        for net_layer in self.layers:
            hidden = net_layer(hidden)
        value = squash.unflatten(tf.reshape(hidden, [-1]))
        return value, network_state
コード例 #11
0
    def _kl_divergence(self, time_steps, action_distribution_parameters,
                       current_policy_distribution):
        """KL divergence between the stored and the current policy distributions.

        Args:
            time_steps: Used only to determine the number of outer dims.
            action_distribution_parameters: Mapping whose "dist_params" entry
                holds the old policy's distribution parameters.
            current_policy_distribution: The current policy's distribution(s).

        Returns:
            The value of `ppo_utils.nested_kl_divergence` over all outer dims.
        """
        outer_rank = nest_utils.get_outer_rank(time_steps, self.time_step_spec)
        old_params = action_distribution_parameters["dist_params"]
        old_distribution = distribution_spec.nested_distributions_from_specs(
            self._action_distribution_spec, old_params)
        return ppo_utils.nested_kl_divergence(
            old_distribution,
            current_policy_distribution,
            outer_dims=list(range(outer_rank)))
コード例 #12
0
    def call(self, inputs, step_type=None, network_state=(), training=False):
        """Run `self._encoder`, optionally collapsing outer dims around it.

        Args:
            inputs: Nested input matching `self.input_tensor_spec`.
            step_type: Unused.
            network_state: Returned unchanged.
            training: Forwarded to the encoder.

        Returns:
            `(encoded, network_state)`.
        """
        del step_type  # unused.

        if self._uint8_input:
            # Divide by 255 to map byte-valued pixels into [0, 1].
            inputs = tf.cast(inputs, tf.float32) / 255.00

        squash = None
        if self._batch_squash:
            squash = utils.BatchSquash(
                nest_utils.get_outer_rank(inputs, self.input_tensor_spec))
            inputs = tf.nest.map_structure(squash.flatten, inputs)

        encoded = self._encoder(inputs, training=training)

        if self._batch_squash:
            encoded = tf.nest.map_structure(squash.unflatten, encoded)
        return encoded, network_state
コード例 #13
0
    def call(self, observations, step_type=(), network_state=()):
        """Map observations to actions.

        All outer dims (e.g. batch x time) are collapsed into one before the
        encoder runs and restored on the resulting actions. An input with
        outer shape [B, T] therefore yields actions with outer shape [B, T]
        rather than [B * T].

        Args:
            observations: Nested observations matching `self.input_tensor_spec`.
            step_type: Forwarded to the encoder.
            network_state: Forwarded to the encoder.

        Returns:
            `(actions, network_state)`: actions packed into `self._action_spec`
            and the encoder's returned state.
        """
        outer_rank = nest_utils.get_outer_rank(observations,
                                               self.input_tensor_spec)
        # batch_squash, in case observations have a time sequence component.
        batch_squash = utils.BatchSquash(outer_rank)
        observations = tf.nest.map_structure(batch_squash.flatten,
                                             observations)

        state, network_state = self._encoder(observations,
                                             step_type=step_type,
                                             network_state=network_state)
        # NOTE: the projection output is used as-is; it is not rescaled to
        # the action spec here.
        actions = self._action_projection_layer(state)
        # Restore the outer dims collapsed above.
        actions = batch_squash.unflatten(actions)
        return tf.nest.pack_sequence_as(self._action_spec,
                                        [actions]), network_state
コード例 #14
0
    def call(self, observation, step_type, network_state=None):
        """Run the recurrent actor and return one projection output per action.

        Only the first entry of the flattened `observation` nest is used.
        Inputs without a time dimension get one added for the unroll. It is
        removed again before the projection networks run, so their inputs
        have exactly `num_outer_dims` outer dims.

        Args:
            observation: Nested observation whose outer shape is [B] or [B, T].
            step_type: Step types; a time dim is added alongside the
                observation's when the input has none.
            network_state: Initial RNN state for `rnn_utils.dynamic_unroll`.

        Returns:
            `(outputs, network_state)`: `outputs` packed into the structure of
            `self._action_spec` and the RNN state after the unroll.

        Raises:
            ValueError: if the observation's outer rank is neither 1 nor 2.
        """
        num_outer_dims = nest_utils.get_outer_rank(observation,
                                                   self.observation_spec)
        if num_outer_dims not in (1, 2):
            raise ValueError(
                'Input observation must have a batch or batch x time outer shape.'
            )

        has_time_dim = num_outer_dims == 2
        if not has_time_dim:
            # Add a time dimension to the inputs.
            observation = nest.map_structure(lambda t: tf.expand_dims(t, 1),
                                             observation)
            step_type = nest.map_structure(lambda t: tf.expand_dims(t, 1),
                                           step_type)

        states = tf.to_float(nest.flatten(observation)[0])
        batch_squash = utils.BatchSquash(2)  # Squash B, and T dims.
        states = batch_squash.flatten(states)

        for layer in self._input_layers:
            states = layer(states)

        states = batch_squash.unflatten(states)

        # The recurrent state is reset wherever a new episode starts.
        with tf.name_scope('reset_mask'):
            reset_mask = tf.equal(step_type, time_step.StepType.FIRST)
        # Unroll over the time sequence.
        states, network_state, _ = rnn_utils.dynamic_unroll(
            self._cell,
            states,
            reset_mask,
            initial_state=network_state,
            dtype=tf.float32)

        states = batch_squash.flatten(states)

        for layer in self._output_layers:
            states = layer(states)

        states = batch_squash.unflatten(states)
        if not has_time_dim:
            # Drop the time dim added above. Otherwise [B, 1, ...] states would
            # be handed to projections told the outer rank is 1.
            states = tf.squeeze(states, [1])
        outputs = [
            projection(states, num_outer_dims)
            for projection in self._projection_networks
        ]

        return nest.pack_sequence_as(self._action_spec, outputs), network_state
コード例 #15
0
    def call(self, observation, step_type, network_state=None):
        """Encode observations, unroll the RNN cell, and apply output layers.

        Only the first entry of the flattened `observation` nest is used.
        Inputs without a time dimension get one added for the unroll and
        removed again from the returned state.

        Args:
            observation: Nested observation whose outer shape is [B] or [B, T].
            step_type: Step types; a time dim is added alongside the
                observation's when the input has none.
            network_state: Initial RNN state for `rnn_utils.dynamic_unroll`.

        Returns:
            `(state, network_state)`: the output-encoded sequence (time dim
            removed if it was added) and the RNN state after the unroll.

        Raises:
            ValueError: if the observation's outer rank is neither 1 nor 2.
        """
        num_outer_dims = nest_utils.get_outer_rank(observation,
                                                   self._observation_spec)
        if num_outer_dims not in (1, 2):
            raise ValueError(
                'Input observation must have a batch or batch x time outer shape.'
            )

        has_time_dim = num_outer_dims == 2
        if not has_time_dim:
            # Add a time dimension to the inputs.
            observation = nest.map_structure(lambda t: tf.expand_dims(t, 1),
                                             observation)
            step_type = nest.map_structure(lambda t: tf.expand_dims(t, 1),
                                           step_type)

        state = tf.to_float(nest.flatten(observation)[0])

        # Feature rank: 3 (e.g. H, W, C) when conv layers are configured,
        # otherwise 1. Everything in front of it is treated as outer dims.
        num_feature_dims = 3 if self._conv_layer_params else 1
        # Called for its shape check only; the returned shape is discarded.
        state.shape.with_rank_at_least(num_feature_dims)
        batch_squash = utils.BatchSquash(state.shape.ndims - num_feature_dims)

        state = batch_squash.flatten(state)
        # NOTE(review): the RNN's `network_state` is passed into the input
        # encoder, and whatever it returns becomes the unroll's initial state.
        # This only works if the encoder passes the state through unchanged
        # -- confirm. `step_type` is also passed un-flattened ([B, T]) while
        # `state` is flattened; presumably the encoder ignores it.
        state, network_state = self._input_encoder(state, step_type,
                                                   network_state)
        state = batch_squash.unflatten(state)

        # The recurrent state is reset wherever a new episode starts.
        with tf.name_scope('reset_mask'):
            reset_mask = tf.equal(step_type, time_step.StepType.FIRST)
        # Unroll over the time sequence.
        state, network_state, _ = rnn_utils.dynamic_unroll(
            self._cell,
            state,
            reset_mask,
            initial_state=network_state,
            dtype=tf.float32)

        state = batch_squash.flatten(state)
        for layer in self._output_encoder:
            state = layer(state)
        state = batch_squash.unflatten(state)

        if not has_time_dim:
            # Remove time dimension from the state.
            state = tf.squeeze(state, [1])

        return state, network_state
コード例 #16
0
ファイル: tf_policy.py プロジェクト: ymodak/agents
    def _maybe_reset_state(self, time_step, policy_state):
        """Replace the policy state with the initial state at episode starts.

        Args:
            time_step: Batched time step; the batch size is read from
                `time_step.discount`.
            policy_state: Current (possibly nested) policy state.

        Returns:
            `policy_state` itself if it is the empty tuple. Otherwise a state
            where batch entries whose time step is FIRST take the value of
            `self.get_initial_state(...)` and the rest keep `policy_state`.
        """
        # A stateless policy uses `()` as its state, so there is nothing to
        # reset. This exact check replaces `policy_state is ()`, an identity
        # comparison with a literal that emits a SyntaxWarning on Python 3.8+.
        if type(policy_state) is tuple and not policy_state:
            return policy_state

        batch_size = tf.compat.dimension_value(time_step.discount.shape[0])
        if batch_size is None:
            batch_size = tf.shape(time_step.discount)[0]

        # Make sure we call this with a kwarg as it may be wrapped in tf.function
        # which would expect a tensor if it was not a kwarg.
        zero_state = self.get_initial_state(batch_size=batch_size)
        condition = time_step.is_first()
        # When experience is a sequence we only reset automatically for the first
        # time_step in the sequence as we can't easily generalize how the policy is
        # unrolled over the sequence.
        if nest_utils.get_outer_rank(time_step, self._time_step_spec) > 1:
            condition = condition[:, 0, ...]
        return nest_utils.where(condition, zero_state, policy_state)
コード例 #17
0
ファイル: actor_critic.py プロジェクト: SyNthw8ve/Thruster
    def call(self, observations, step_type=(), network_state=()):
        """Encode observations and emit actions scaled to the action spec.

        Outer dims are collapsed before the encoder and restored on the
        actions.

        Args:
            observations: Nested observations matching `self.input_tensor_spec`.
            step_type: Forwarded to the encoder.
            network_state: Forwarded to the encoder.

        Returns:
            `(actions, network_state)`: actions packed into `self._action_spec`
            and the encoder's returned state.
        """
        squasher = BatchSquash(
            nest_utils.get_outer_rank(observations, self.input_tensor_spec))
        flat_observations = nest.map_structure(squasher.flatten, observations)

        encoded, network_state = self._encoder(flat_observations,
                                               step_type=step_type,
                                               network_state=network_state)

        raw_actions = self._action_projection_layer(encoded)
        scaled = scale_to_spec(raw_actions, self._single_action_spec)
        actions = squasher.unflatten(scaled)
        return nest.pack_sequence_as(self._action_spec,
                                     [actions]), network_state
コード例 #18
0
  def call(self, observations, step_type=(), network_state=(), training=False):
    """Encode observations and produce an action distribution.

    With an image encoder, observations pass through `self._image_encoder`
    and then `self._fc_encoder`. Otherwise the 'position' and 'velocity'
    entries are concatenated. The features then go through
    `self._dense_layers` and the distribution projection network.

    Note that the incoming `network_state` is not consumed. The returned
    state is the one produced by the projection network.

    Returns:
      `(action_distribution, network_state)`.
    """
    if not self._image_encoder:
      # dm_control state observations need to be flattened as they are
      # structured as a dict(position, velocity)
      features = tf.keras.layers.concatenate(
          [observations['position'], observations['velocity']])
    else:
      image_features, _ = self._image_encoder(observations, training=training)
      features = self._fc_encoder(image_features)

    features = self._dense_layers(features)

    rank = nest_utils.get_outer_rank(observations, self.input_tensor_spec)
    distribution, projected_state = self._distribution_projection_network(
        features, rank, training=training)
    return distribution, projected_state
コード例 #19
0
ファイル: normalizers.py プロジェクト: runjerry/alf
 def _normalize(m, m2, spec, t):
     """Normalize `t` with mean `m` and variance `m2 - m**2`, then clip.

     This is a closure: `self._variance_epsilon` and `clip_value` are read
     from the enclosing scope. Clipping to [-clip_value, clip_value] happens
     only when `clip_value > 0`.

     Returns:
         The normalized tensor with the same outer shape as `t`.
     """
     # m2 - m**2 can come out slightly negative due to floating-point error
     # when the true variance is close to 0; relu clamps it to 0.
     variance = tf.nn.relu(m2 - tf.square(m))
     squasher = BatchSquash(get_outer_rank(t, spec))
     normalized = tf.nn.batch_normalization(
         squasher.flatten(t),
         m,
         variance,
         offset=None,
         scale=None,
         variance_epsilon=self._variance_epsilon)
     if clip_value > 0:
         normalized = tf.clip_by_value(normalized, -clip_value, clip_value)
     return squasher.unflatten(normalized)
コード例 #20
0
  def call(self, observations, step_type, network_state):
    """Run the MLP on the first observation and apply each projection network.

    Returns:
      `(output_actions, network_state)`: `output_actions` mirrors the
      structure of `self._projection_networks`; `network_state` is returned
      unchanged.
    """
    del step_type  # unused.
    rank = nest_utils.get_outer_rank(observations, self.input_tensor_spec)
    hidden = tf.cast(tf.nest.flatten(observations)[0], tf.float32)

    # Collapse every outer dim into one batch dim for the dense layers.
    squash = utils.BatchSquash(rank)
    hidden = squash.flatten(hidden)
    for mlp_layer in self._mlp_layers:
      hidden = mlp_layer(hidden)

    # TODO(oars): Can we avoid unflattening to flatten again
    hidden = squash.unflatten(hidden)
    output_actions = tf.nest.map_structure(
        lambda proj_net: proj_net(hidden, rank), self._projection_networks)
    return output_actions, network_state
コード例 #21
0
    def call(self,
             observations,
             step_type,
             network_state,
             training=False,
             mask=None):
        """Embed observations with the GNN, encode, and project to actions.

        Args:
            observations: Graph observation tensor. NOTE(review): the shape
                handling below suggests [nodes, features] or
                [1, nodes, features] -- TODO confirm.
            step_type: Forwarded to the encoder.
            network_state: Forwarded to the encoder.
            training: Forwarded to the GNN, encoder and projection networks.
            mask: Forwarded to every projection network.

        Returns:
            `(output_actions, network_state)`: the first element returned by
            each projection network, structured like
            `self._projection_networks`, and the encoder's returned state.
        """

        # Rank-1 and rank-2 inputs are both reshaped to [1, -1]. Note that a
        # rank-2 input [N, F] therefore becomes [1, N * F].
        if len(tf.shape(observations)) == 2 or len(
                tf.shape(observations)) == 1:
            observations = tf.reshape(observations, [1, -1])

        # Rank-3 input: drop the leading axis (only valid if it has size 1).
        if len(tf.shape(observations)) == 3:
            observations = tf.squeeze(observations, axis=0)

        embeddings = self._gnn(observations, training=training)
        # Extract the ego state (node 0) when the result is non-empty.

        if tf.shape(embeddings)[0] > 0:
            embeddings = embeddings[:, 0]

        with tf.name_scope("PPOActorNetwork"):
            tf.summary.histogram("embedding", embeddings)

        state, network_state = self._encoder(embeddings,
                                             step_type=step_type,
                                             network_state=network_state,
                                             training=training)

        # NOTE(review): the outer rank is computed on the reshaped/squeezed
        # observations, not the caller's input; confirm this matches
        # `input_tensor_spec`.
        outer_rank = nest_utils.get_outer_rank(observations,
                                               self.input_tensor_spec)

        def call_projection_net(proj_net):
            # Keep only the distribution; the projection's state is dropped.
            distribution, _ = proj_net(state,
                                       outer_rank,
                                       training=training,
                                       mask=mask)
            return distribution

        output_actions = tf.nest.map_structure(call_projection_net,
                                               self._projection_networks)
        return output_actions, network_state
コード例 #22
0
    def call(self,
             observation,
             step_type=None,
             network_state=(),
             training=False):
        """Runs the given observation through the network.

        Args:
            observation: The observation to provide to the network, shaped
                [batch_size, seq_len, ...] or [batch_size, ...].
            step_type: The step type for the given observation. See `StepType`
                in time_step.py.
            network_state: A state tuple to pass to the network, mainly used by
                RNNs.
            training: Whether the output is being used for training.

        Returns:
            A tuple `(q_value, network_state)`. When not training and
            `self._output_last_state` is set, a sequence input yields only the
            Q-values of the last time step.
        """
        # observation shape = [batch_size, seq_len, ...] or [batch_size, ...]
        num_outer_dims = nest_utils.get_outer_rank(observation,
                                                   self.input_tensor_spec)
        if num_outer_dims == 2:
            seq_length = observation.shape[1]
        else:
            seq_length = 1

        look_ahead_mask = self._create_look_ahead_mask(
            seq_length)  # (seq_len, seq_len)

        output, network_state = self._encoder(observation,
                                              step_type,
                                              network_state=network_state,
                                              training=training,
                                              mask=look_ahead_mask)

        q_value = self._q_value_layer(output, training=training)

        if not training and self._output_last_state and num_outer_dims == 2:
            # During inference/evaluation keep only the last element of the
            # output sequence, so actions are (batch_size, ...) rather than
            # (batch_size, seq_len, ...). Indexing matches the former
            # tf.squeeze(axis=1) when seq_len == 1 and, unlike squeeze, also
            # works when seq_len > 1.
            q_value = q_value[:, -1]

        return q_value, network_state
コード例 #23
0
ファイル: encoding_network.py プロジェクト: GitHub30/agents
  def call(self, observation, step_type=None, network_state=()):
    """Encode the first observation of the nest with `self.layers`.

    When `self._batch_squash` is set, all outer dims are collapsed into one
    before the layers run and restored afterwards.

    Returns:
      `(encoded, network_state)`; `network_state` is returned unchanged.
    """
    del step_type  # unused.

    squash = None
    if self._batch_squash:
      squash = utils.BatchSquash(
          nest_utils.get_outer_rank(observation, self.observation_spec))

    # Only the first tensor of a nested observation is encoded.
    encoded = tf.to_float(nest.flatten(observation)[0])
    if self._batch_squash:
      encoded = squash.flatten(encoded)

    for layer in self.layers:
      encoded = layer(encoded)

    if self._batch_squash:
      encoded = squash.unflatten(encoded)
    return encoded, network_state
コード例 #24
0
ファイル: agents.py プロジェクト: krishpop/sqrl
 def call(self, observations, step_type, network_state=(), training=False):
     """Q-distribution from an (observation, action, alpha) triple.

     The observation and alpha parts are first run through their own
     encoders. Both receive the incoming `network_state`, and the states they
     return are discarded. The results are then combined with the action,
     encoded by `self._encoder`, and projected by `self._projection_network`.

     Returns:
       `(q_distribution, network_state)`: `network_state` is the one returned
       by `self._encoder`.
     """
     obs, ac, alpha = observations
     encoder_kwargs = dict(step_type=step_type,
                           network_state=network_state,
                           training=training)
     pre_obs, _ = self._obs_encoder(obs, **encoder_kwargs)
     pre_alpha, _ = self._alph_encoder(alpha, **encoder_kwargs)
     encoder_inputs = (pre_obs, ac, pre_alpha)
     state, network_state = self._encoder(encoder_inputs, **encoder_kwargs)
     outer_rank = nest_utils.get_outer_rank(encoder_inputs,
                                            self.encoder_input_tensor_spec)
     q_distribution, _ = self._projection_network(state, outer_rank)
     return q_distribution, network_state
コード例 #25
0
    def call(self, observations, step_type=(), network_state=()):
        """Map observations to actions via a CNN branch and a flat branch.

        Outer dims (e.g. batch x time) are collapsed before the layers run and
        restored on the actions, so [B, T] input yields [B, T] actions rather
        than [B * T].

        Args:
            observations: Observations matching `self.input_tensor_spec`.
            step_type: Unused.
            network_state: Returned unchanged.

        Returns:
            `(actions, network_state)`, with actions packed into
            `self._action_spec`.
        """
        outer_rank = nest_utils.get_outer_rank(observations,
                                               self.input_tensor_spec)
        # batch_squash, in case observations have a time sequence component.
        batch_squash = utils.BatchSquash(outer_rank)
        observations = tf.nest.map_structure(batch_squash.flatten,
                                             observations)

        flat_x = self._flat_x(observations)
        cnn = self._cnn_1(observations)
        flat_cnn = self._flat_cnn(cnn)
        concat_cnn_x = tf.keras.layers.concatenate([flat_x, flat_cnn])
        dense_1 = self._dense_1(concat_cnn_x)
        dense_2 = self._dense_2(dense_1)
        # The flattened input is concatenated again with the dense output.
        concat_dense_x = tf.keras.layers.concatenate([flat_x, dense_2])
        actions = self._action_projection_layer(concat_dense_x)
        # Restore the outer dims collapsed above.
        actions = batch_squash.unflatten(actions)

        return tf.nest.pack_sequence_as(self._action_spec,
                                        [actions]), network_state
コード例 #26
0
 def call(self, inputs, step_type=None, network_state=(), training=False):
     """Produce diagonal-Gaussian parameters `(loc, scale_diag)`.

     The inputs are concatenated on the last axis and run through
     `self._fc_encoder`. The first `self.output_dim` features are the mean.
     When `self.scale` is None, the remaining features pass through a
     softplus to give the scale; otherwise the scale is the constant
     `self.scale`.

     Returns:
       `((loc, scale_diag), network_state)`; `network_state` is returned
       unchanged.
     """
     del step_type  # unused.
     if self._batch_squash:
         outer_rank = nest_utils.get_outer_rank(inputs, self.input_tensor_spec)
         batch_squash = utils.BatchSquash(outer_rank)
         inputs = tf.nest.map_structure(batch_squash.flatten, inputs)
     features = self._fc_encoder(tf.concat(inputs, axis=-1),
                                 training=training)
     if self._batch_squash:
         features = tf.nest.map_structure(batch_squash.unflatten, features)

     loc = features[..., :self.output_dim]
     if self.scale is not None:
         scale_diag = tf.ones_like(loc) * self.scale
     else:
         # NOTE(review): softplus(0) = ln 2 ~= 0.6931, so this factor is ~1;
         # confirm the intended constant. 1e-6 keeps the scale positive.
         raw_scale = tf.nn.softplus(features[..., self.output_dim:])
         scale_diag = raw_scale * (0.693 / tf.nn.softplus(0.)) + 1e-6
     return (loc, scale_diag), network_state
コード例 #27
0
ファイル: actor_rnn_network.py プロジェクト: ruizhaogit/alf
    def call(self, observation, step_type, network_state=None):
        """Run the LSTM encoder and map its output to spec-scaled actions.

        The encoder output's outer dims (same count as the observation's)
        are collapsed for the action layers and restored on each action.

        Returns:
            `(output_actions, network_state)`: actions packed into
            `self._output_tensor_spec` and the LSTM encoder's new state.
        """
        squash = utils.BatchSquash(
            nest_utils.get_outer_rank(observation, self.input_tensor_spec))

        encoded, network_state = self._lstm_encoder(
            observation, step_type=step_type, network_state=network_state)
        flat_encoded = squash.flatten(encoded)

        def _action_for(action_layer, action_spec):
            scaled = common.scale_to_spec(action_layer(flat_encoded),
                                          action_spec)
            return squash.unflatten(scaled)

        actions = [
            _action_for(action_layer, action_spec) for action_layer,
            action_spec in zip(self._action_layers, self._flat_action_spec)
        ]
        return (tf.nest.pack_sequence_as(self._output_tensor_spec, actions),
                network_state)
コード例 #28
0
ファイル: encoding_network.py プロジェクト: xzxzxzxz/agents
    def call(self,
             observation,
             step_type=None,
             network_state=(),
             training=False):
        """Preprocess, combine and post-process `observation`.

        Steps: optionally collapse outer dims; apply one preprocessing layer
        per entry of `self._preprocessing_nest`; merge the results with
        `self._preprocessing_combiner` if one is set; run the
        post-processing layers; restore the outer dims.

        Returns:
            `(states, network_state)`; `network_state` is returned unchanged.
        """
        del step_type  # unused.

        if self._batch_squash:
            outer_rank = nest_utils.get_outer_rank(observation,
                                                   self.input_tensor_spec)
            batch_squash = utils.BatchSquash(outer_rank)
            observation = tf.nest.map_structure(batch_squash.flatten,
                                                observation)

        if self._flat_preprocessing_layers is None:
            states = observation
        else:
            flat_observation = nest.flatten_up_to(self._preprocessing_nest,
                                                  observation,
                                                  check_types=False)
            states = [
                preprocess(obs, training=training)
                for obs, preprocess in zip(flat_observation,
                                           self._flat_preprocessing_layers)
            ]
            if len(states) == 1 and self._preprocessing_combiner is None:
                # A single preprocessed input with no combiner is used as-is.
                states = states[0]

        if self._preprocessing_combiner is not None:
            states = self._preprocessing_combiner(states)

        for postprocess in self._postprocessing_layers:
            states = postprocess(states, training=training)

        if self._batch_squash:
            states = tf.nest.map_structure(batch_squash.unflatten, states)

        return states, network_state
コード例 #29
0
ファイル: value_rnn_network.py プロジェクト: weiddeng/agents
  def call(self, observation, step_type=None, network_state=None):
    """Estimate a value for each outer (batch or batch x time) element.

    Only the first entry of the flattened `observation` nest is used.
    Inputs without a time dimension get one added for the unroll and
    removed again from the returned value.

    Args:
      observation: Nested observation whose outer shape is [B] or [B, T].
      step_type: Step types; a time dim is added alongside the
        observation's when the input has none.
      network_state: Initial recurrent state for `self._dynamic_unroll`.

    Returns:
      `(value, network_state)`: `value` has shape [B, T] for batch x time
      input and [B] for batch-only input (assuming the value projection
      emits one unit per element), plus the state after the unroll.

    Raises:
      ValueError: if the observation's outer rank is neither 1 nor 2.
    """
    num_outer_dims = nest_utils.get_outer_rank(observation,
                                               self.input_tensor_spec)
    if num_outer_dims not in (1, 2):
      raise ValueError(
          'Input observation must have a batch or batch x time outer shape.')

    has_time_dim = num_outer_dims == 2
    if not has_time_dim:
      # Add a time dimension to the inputs.
      observation = tf.nest.map_structure(lambda t: tf.expand_dims(t, 1),
                                          observation)
      step_type = tf.nest.map_structure(lambda t: tf.expand_dims(t, 1),
                                        step_type)

    states = tf.cast(tf.nest.flatten(observation)[0], tf.float32)
    batch_squash = utils.BatchSquash(2)  # Squash B, and T dims.
    states = batch_squash.flatten(states)

    for layer in self._input_layers:
      states = layer(states)

    states = batch_squash.unflatten(states)

    # The recurrent state is reset wherever a new episode starts.
    with tf.name_scope('reset_mask'):
      reset_mask = tf.equal(step_type, time_step.StepType.FIRST)
    # Unroll over the time sequence.
    states, network_state = self._dynamic_unroll(
        states,
        reset_mask,
        initial_state=network_state)

    states = batch_squash.flatten(states)

    for layer in self._output_layers:
      states = layer(states)

    value = self._value_projection_layer(states)
    value = tf.reshape(value, [-1])
    value = batch_squash.unflatten(value)  # [B x T] -> [B, T]
    if not has_time_dim:
      # Undo the time dim added above so batch-only input gets a [B] value.
      value = tf.squeeze(value, axis=1)
    return value, network_state
コード例 #30
0
    def call(
        self,
        observations: Union[tf.Tensor, np.ndarray],
        step_type: Optional[Any],
        network_state: Union[Tuple, Tuple[Union[tf.Tensor, np.ndarray]]] = ()
    ) -> Tuple[Union[tfp.distributions.OneHotCategorical,
                     Tuple[tfp.distributions.OneHotCategorical]], Union[
                         Tuple, Tuple[Union[tf.Tensor, np.ndarray]]]]:
        """
        Map observations to a distribution over actions for each action head.

        :param observations: Tensor/Array of observation values from the environment.
        :param step_type: Unused; kept for consistency with the TF-Agents network interface.
        :param network_state: Returned unchanged; this network is not recurrent.
        :return: The per-head action distributions (a single distribution when there is
            only one head) and the unchanged network state.
        """
        # Shared trunk: every head consumes the same hidden activations.
        hidden = tf.cast(observations, tf.float32)
        for shared_layer in self._shared_layers:
            hidden = shared_layer(hidden)

        # The shared layers keep the batch dims, so the outer rank is measured
        # against the raw observations and the input tensor spec.
        outer_rank = nest_utils.get_outer_rank(observations,
                                               self.input_tensor_spec)

        # Each head returns a tuple whose first element is its distribution.
        per_head = tf.nest.map_structure(
            lambda head: head(hidden, outer_rank)[0], self._action_heads)

        # With a single head, unwrap the 1-tuple to the distribution itself.
        single_head = len(self._action_subspace_dimensions) == 1
        return (per_head[0] if single_head else per_head), network_state