def __call__(self, inputs):
    """Build a diagonal Gaussian distribution from the given inputs.

    Args:
      inputs: A (possibly nested) structure of tensors. It is merged into a
        single tensor with `tf_utils.batch_concat` before the mean and scale
        layers are applied.

    Returns:
      `tfd.Independent(tfd.Normal(...))` when `self._use_tfd_independent` is
      truthy, otherwise `tfd.MultivariateNormalDiag(...)`, both parameterised
      by the computed mean and scale.
    """
    inputs = tf_utils.batch_concat(inputs)
    zero = tf.constant(0, dtype=inputs.dtype)
    mean = self._mean_layer(inputs)
    if self._binary_grip_action:
        # Collapse the first two mean columns into one value: 2.0 when
        # column 0 holds the larger entry, -2.0 when column 1 does.
        grip_pred = tf.gather(
            tf.constant([2.0, -2.0]), tf.math.argmax(mean[:, :2], axis=1))
        mean = tf.concat(
            [tf.expand_dims(grip_pred, axis=1), mean[:, 2:]], axis=1)

    if self._fixed_scale:
        # Constant scale; the scale layer is not evaluated at all.
        scale = tf.ones_like(mean) * self._init_scale
    else:
        # Learned scale. Dividing by softplus(0) makes a zero pre-activation
        # map exactly to `_init_scale` (before `_min_scale` is added).
        scale = tf.nn.softplus(self._scale_layer(inputs))
        scale *= self._init_scale / tf.nn.softplus(zero)
        scale += self._min_scale

    # Maybe transform the mean.
    if self._tanh_mean:
        mean = tf.tanh(mean)

    if self._use_tfd_independent:
        dist = tfd.Independent(tfd.Normal(loc=mean, scale=scale))
    else:
        dist = tfd.MultivariateNormalDiag(loc=mean, scale_diag=scale)

    return dist
Example #2
  def test_batch_concat(self):
    """Checks that `batch_concat` flattens and joins a mixed nest of tensors.

    The nest holds a bare tensor, a dict entry and a one-element list, so the
    expected feature size is the sum of each leaf's non-batch element count.
    """
    batch_size = 32
    inputs = [
        tf.zeros(shape=(batch_size, 2)),
        {
            'foo': tf.zeros(shape=(batch_size, 5, 3))
        },
        [tf.zeros(shape=(batch_size, 1))],
    ]

    output_shape = tf2_utils.batch_concat(inputs).shape.as_list()
    # 2 (first leaf) + 5 * 3 (flattened 'foo') + 1 (list leaf).
    expected_shape = [batch_size, 2 + 5 * 3 + 1]
    self.assertSequenceEqual(output_shape, expected_shape)
Example #3
    def _density_loss(self, current_obs, action, discount, next_obs):
        """Negative log-likelihood loss of the mixture density model.

        Args:
          current_obs: Observation passed to `self.mixture_density`.
          action: Action passed to `self.mixture_density`.
          discount: Target scored under the predicted discount distribution.
          next_obs: Next observation; merged with `tf2_utils.batch_concat`
            and scored under the predicted observation distribution.

        Returns:
          A tuple `(loss, obs_loss, discount_loss)` where `loss` is the sum of
          the two mean negative log-probabilities.
        """
        flat_next_obs = tf2_utils.batch_concat(next_obs)

        predicted_obs, predicted_discount = self.mixture_density(
            current_obs, action)

        # Mean negative log-likelihood of each target under its distribution.
        obs_nll = -predicted_obs.log_prob(flat_next_obs)
        obs_loss = tf.reduce_mean(obs_nll)

        discount_nll = -predicted_discount.log_prob(discount)
        discount_loss = tf.reduce_mean(discount_nll)

        return obs_loss + discount_loss, obs_loss, discount_loss
Example #4
def batch_concat_selection(observation_dict: Dict[str, types.NestedTensor],
                           concat_keys: Optional[Iterable[str]] = None,
                           output_dtype=tf.float32) -> tf.Tensor:
  """Cast the selected observations and join them with `batch_concat`.

  Args:
    observation_dict: Mapping from observation name to its value.
    concat_keys: Names to select, in the order they should be concatenated.
      When falsy (e.g. `None`), all keys are used in sorted order.
    output_dtype: Dtype every selected observation is cast to.

  Returns:
    The result of `tf2_utils.batch_concat` over the cast observations.

  Raises:
    KeyError: if a requested name is not present in `observation_dict`.
  """
  if not concat_keys:
    concat_keys = sorted(observation_dict.keys())

  def _lookup(name):
    # Fail with the list of available names when the request is unknown.
    if name in observation_dict:
      return observation_dict[name]
    raise KeyError(
        'Missing observation. Requested: {} (available: {})'.format(
            name, list(observation_dict.keys())))

  selected = [tf.cast(_lookup(name), output_dtype) for name in concat_keys]
  return tf2_utils.batch_concat(selected)
Example #5
    def __call__(self, observation: tf.Tensor, action: tf.Tensor) -> tf.Tensor:
        """Combine an observation and an action and optionally run the critic.

        Each of the three sub-networks is applied only when it is truthy.

        Args:
          observation: Observation input, optionally passed through
            `self._observation_network` first.
          action: Action input, optionally passed through
            `self._action_network` first.

        Returns:
          The concatenated observation/action, passed through
          `self._critic_network` when one is set.
        """
        # Optional per-input preprocessing.
        obs_features = (
            self._observation_network(observation)
            if self._observation_network else observation)
        act_features = (
            self._action_network(action) if self._action_network else action)

        # Join both inputs into one tensor via `batch_concat`.
        joined = tf2_utils.batch_concat([obs_features, act_features])

        # Optional post-processing of the joined tensor.
        if not self._critic_network:
            return joined
        return self._critic_network(joined)
Example #6
    def __call__(
        self,
        x: snt.Module,
        state: Optional[snt.Module] = None,
        message: Optional[snt.Module] = None,
    ) -> snt.Module:
        """Run one step: encode inputs, advance the core, emit action and message.

        NOTE(review): the `snt.Module` annotations are kept from the original
        signature, but the body treats `x` as something with a `.shape` and
        returns a tuple. They look like they should be tensor types; confirm.

        Args:
          x: Observation input; `x.shape[0]` sizes the default state/message.
          state: Core state. When `None`, `self.initial_state` supplies it.
          message: Incoming message. When `None`, `self.initial_message`
            supplies it.

        Returns:
          `((action, message), state)`: the outputs of the action and
          communication heads, plus the updated core state.
        """
        if state is None:
            state = self.initial_state(x.shape[0])
        if message is None:
            message = self.initial_message(x.shape[0])

        obs_in = self._obs_in_network(x)
        comm_in = self._comm_in_network(message)

        # Encoded observation and message are fed to the core as one tensor.
        core_in = tf2_utils.batch_concat([obs_in, comm_in])

        core_out, state = self._core_network(core_in, state)

        # Both heads read the same core output.
        action = self._action_head(core_out)
        message = self._comm_head(core_out)

        return (action, message), state
Example #7
    def __call__(self, observation: types.NestedTensor,
                 action: types.NestedTensor) -> tf.Tensor:
        """Combine an observation and an action and optionally run the critic.

        Each of the three sub-networks is applied only when it is truthy.

        Args:
          observation: Observation input, optionally passed through
            `self._observation_network` first.
          action: Action input, optionally passed through
            `self._action_network` first. It is cast to the observation's
            dtype when both expose a `dtype` and the two differ.

        Returns:
          The concatenated observation/action, passed through
          `self._critic_network` when one is set.
        """
        # Optional per-input preprocessing.
        obs_features = (
            self._observation_network(observation)
            if self._observation_network else observation)
        act_features = (
            self._action_network(action) if self._action_network else action)

        # Concatenation needs matching dtypes; align the action to the
        # observation when both sides are single objects exposing `dtype`.
        both_typed = (
            hasattr(obs_features, 'dtype') and hasattr(act_features, 'dtype'))
        if both_typed and obs_features.dtype != act_features.dtype:
            act_features = tf.cast(act_features, obs_features.dtype)

        # Join both inputs into one tensor via `batch_concat`.
        joined = tf2_utils.batch_concat([obs_features, act_features])

        # Optional post-processing of the joined tensor.
        if not self._critic_network:
            return joined
        return self._critic_network(joined)
Example #8
    def __call__(self, inputs, action: tf.Tensor = None, task=None):
        """Evaluates the ControlNetwork.

        Args:
          inputs: A dictionary of agent observation tensors. Any non-dict
            value is wrapped as `{'inputs': inputs}`.
          action: Agent actions. When given, they are appended after the
            observations (i.e. for critic networks).
          task: Optional encoding of the task. Not read by this method.

        Returns:
          Processed network output: `self._proprio_encoder` applied to the
          concatenated inputs.

        Raises:
          ValueError: if some proprio input looks suspiciously like pixel
            inputs, i.e. the product of its dimensions after the first
            exceeds 32 * 32 * 3.
        """
        # Function-scope import so this method does not depend on a
        # module-level numpy import being present.
        import numpy as np

        if not isinstance(inputs, dict):
            inputs = {'inputs': inputs}

        proprio_input = []
        # By default, treat all observations as proprioceptive.
        # The resolved key list is cached on the instance for later calls.
        if self._proprio_keys is None:
            self._proprio_keys = list(sorted(inputs.keys()))
        for key in self._proprio_keys:
            # NOTE(review): the size check and the append below were
            # reconstructed from a damaged source; confirm against upstream.
            if np.prod(inputs[key].shape[1:]) > 32 * 32 * 3:
                raise ValueError(
                    'This input does not resemble a proprioceptive '
                    'state: {} with shape {}'.format(key, inputs[key].shape))
            proprio_input.append(inputs[key])

        # Append optional action input (i.e. for critic networks).
        if action is not None:
            proprio_input.append(action)

        proprio_input = tf2_utils.batch_concat(proprio_input)
        proprio_state = self._proprio_encoder(proprio_input)

        return proprio_state
Example #9
 def __call__(self, observations: types.Nest) -> tf.Tensor:
     """Run the policy network on the batch-concatenated observations."""
     flat_observations = tf2_utils.batch_concat(observations)
     return self._network(flat_observations)