Example #1
def denormalize(batch: types.NestedArray,
                mean_std: NestedMeanStd) -> types.NestedArray:
  """Denormalizes values in a nested structure using the given mean/std.

  Only values of inexact types are denormalized.
  See https://numpy.org/doc/stable/_images/dtype-hierarchy.png for Numpy type
  hierarchy.

  Args:
    batch: a nested structure containing batch of data.
    mean_std: mean and standard deviation used for denormalization.

  Returns:
    Nested structure with denormalized values.
  """

  def denormalize_leaf(data: jnp.ndarray, mean: jnp.ndarray,
                       std: jnp.ndarray) -> jnp.ndarray:
    # Only denormalize inexact types.
    if not np.issubdtype(data.dtype, np.inexact):
      return data
    return data * std + mean

  return tree_utils.fast_map_structure(denormalize_leaf, batch, mean_std.mean,
                                       mean_std.std)
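
For illustration, a minimal usage sketch. It assumes NestedMeanStd is a simple container with mean and std fields mirroring the batch structure, and that tree_utils.fast_map_structure behaves like tree.map_structure; the container below is a hypothetical stand-in, not the module's actual definition.

from typing import Any, NamedTuple

import jax.numpy as jnp


class NestedMeanStd(NamedTuple):  # hypothetical stand-in for the real container
  mean: Any
  std: Any


batch = {'observation': jnp.array([0.0, 1.0, 2.0])}
stats = NestedMeanStd(mean={'observation': jnp.array(1.0)},
                      std={'observation': jnp.array(2.0)})
# Each inexact leaf is mapped x -> x * std + mean.
denormalize(batch, stats)  # {'observation': [1., 3., 5.]}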
Example #2
def _validate_batch_shapes(batch: types.NestedArray,
                           reference_sample: types.NestedArray,
                           batch_dims: Tuple[int, ...]) -> None:
    """Verifies shapes of the batch leaves against the reference sample.

  Checks that batch dimensions are the same in all leaves in the batch.
  Checks that non-batch dimensions for all leaves in the batch are the same
  as in the reference sample.

  Arguments:
    batch: the nested batch of data to be verified.
    reference_sample: the nested array to check non-batch dimensions.
    batch_dims: a Tuple of indices of batch dimensions in the batch shape.

  Returns:
    None.
  """
    def validate_node_shape(reference_sample: jnp.ndarray,
                            batch: jnp.ndarray) -> None:
        expected_shape = batch_dims + reference_sample.shape
        assert batch.shape == expected_shape, f'{batch.shape} != {expected_shape}'

    tree_utils.fast_map_structure(validate_node_shape, reference_sample, batch)
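
A quick sketch of what the check enforces, using plain numpy arrays; here batch_dims is the expected leading batch shape (32,).

import numpy as np

reference = {'action': np.zeros(3)}    # per-example shape (3,)
batch = {'action': np.zeros((32, 3))}  # 32 stacked examples
_validate_batch_shapes(batch, reference, batch_dims=(32,))  # passes

bad = {'action': np.zeros((32, 4))}
# _validate_batch_shapes(bad, reference, (32,))
# -> AssertionError: (32, 4) != (32, 3)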
Example #3
def normalize(batch: types.NestedArray,
              mean_std: NestedMeanStd,
              max_abs_value: Optional[float] = None) -> types.NestedArray:
    """Normalizes data using running statistics."""
    def normalize_leaf(data: jnp.ndarray, mean: jnp.ndarray,
                       std: jnp.ndarray) -> jnp.ndarray:
        # Only normalize inexact types.
        if not jnp.issubdtype(data.dtype, jnp.inexact):
            return data
        data = (data - mean) / std
        if max_abs_value is not None:
            # TODO(b/124318564): remove pylint directive
            data = jnp.clip(data, -max_abs_value, +max_abs_value)  # pylint: disable=invalid-unary-operand-type
        return data

    return tree_utils.fast_map_structure(normalize_leaf, batch, mean_std.mean,
                                         mean_std.std)
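
Usage mirrors denormalize in Example #1; with max_abs_value set, standardised values are additionally clipped. Reusing the hypothetical NestedMeanStd stand-in from the sketch there:

data = {'reward': jnp.array([-20.0, 0.0, 20.0])}
stats = NestedMeanStd(mean={'reward': jnp.array(0.0)},
                      std={'reward': jnp.array(2.0)})
# Standardise to [-10, 0, 10], then clip into [-5, 5].
normalize(data, stats, max_abs_value=5.0)  # [-5., 0., 5.]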
Example #4
    def update(self, last: bool = False):
        """Perform a meta-training update."""
        for _ in range(self._minibatch_size):
            # Retrieve a batch of data from replay.
            data = self._queue.sample()
            data = tree_utils.fast_map_structure(
                lambda x: tf.convert_to_tensor(x), data)

            # Do a batch of SGD.
            results, gradients, logits = self._step(data=data)

            # Optionally check gradients (left disabled):
            # for g, v in zip(gradients, self._network.trainable_variables):
            #   name = v.name.replace('/', '-')
            #   results.update({name: tf.reduce_mean(g)})

            # Compute elapsed time.
            timestamp = time.time()
            elapsed_time = timestamp - self._timestamp if self._timestamp else 0
            self._timestamp = timestamp

            # Update our counts and record it.
            counts = self._counter.increment(steps=1, walltime=elapsed_time)
            results.update(counts)

            # Compute the KL divergence between the current policy and the one
            # from the previous update; there is no previous policy on step 1.
            pi = tfd.Categorical(logits=logits[:-1])
            if counts['steps'] > 1:
                kl_divergence = tf.reduce_mean(pi.kl_divergence(self._pi_old))
                results.update({'kl_divergence': kl_divergence})
            else:
                results.update({'kl_divergence': 0.0})
            self._pi_old = pi

            # Update learning rate.
            self._learning_rate.assign(
                self._lr_scheduler(step=counts['steps']))

            # Snapshot and attempt to write logs.
            if hasattr(self, '_snapshotter'):
                self._snapshotter.save()

            if self._logger:
                self._logger.write(results)
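
The KL bookkeeping above compares the current policy to the one saved on the previous update. A standalone sketch of just that piece, with illustrative logits; it assumes tfd is tensorflow_probability.distributions, as the tfd.Categorical call suggests.

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

pi_old = tfd.Categorical(logits=tf.constant([[1.0, 0.0], [0.5, 0.5]]))
pi_new = tfd.Categorical(logits=tf.constant([[1.2, -0.1], [0.4, 0.6]]))
# KL(pi_new || pi_old), averaged over the batch, as in the update above.
kl = tf.reduce_mean(pi_new.kl_divergence(pi_old))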
Example #5
def clip(batch: types.NestedArray,
         clipping_config: NestClippingConfig) -> types.NestedArray:
    """Clips the batch."""
    def max_abs_value_for_path(path: Path, x: jnp.ndarray) -> Optional[float]:
        del x  # Unused, needed by interface.
        return next(
            (max_abs_value
             for clipping_path, max_abs_value in clipping_config.path_map
             if _is_prefix(clipping_path, path)), None)

    max_abs_values = tree_utils.fast_map_structure_with_path(
        max_abs_value_for_path, batch)

    def clip_leaf(data: jnp.ndarray,
                  max_abs_value: Optional[float]) -> jnp.ndarray:
        if max_abs_value is not None:
            # TODO(b/124318564): remove pylint directive
            data = jnp.clip(data, -max_abs_value, +max_abs_value)  # pylint: disable=invalid-unary-operand-type
        return data

    return tree_utils.fast_map_structure(clip_leaf, batch, max_abs_values)
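
Path, _is_prefix and NestClippingConfig come from the surrounding module. A plausible minimal stand-in for experimenting with clip; these definitions are assumptions, not the module's actual code.

import dataclasses
from typing import Tuple

Path = Tuple[str, ...]


@dataclasses.dataclass(frozen=True)
class NestClippingConfig:
  # (path prefix, max absolute value) pairs; the first matching prefix wins.
  path_map: Tuple[Tuple[Path, float], ...] = ()


def _is_prefix(prefix: Path, path: Path) -> bool:
  return path[:len(prefix)] == prefix


config = NestClippingConfig(path_map=((('rewards',), 1.0),))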
Example #6
def preprocess_observation(ob: Dict[str, np.ndarray]):
  # Type conversion to float32.
  ob = tree_utils.fast_map_structure(lambda x: tf.cast(x, tf.float32), ob)

  # Apply logarithm to remaining steps.
  # tf.math.log maps zeros to -inf; replace any infinities with a small epsilon.
  def avoid_inf(x: tf.Tensor, epsilon: float = 1e-7):
    return tf.where(tf.math.is_inf(x), epsilon, x)
  ob['remaining_steps'] = avoid_inf(tf.math.log(ob['remaining_steps']))
  ob['trial_remaining_steps'] = avoid_inf(tf.math.log(ob['trial_remaining_steps']))

  # Pop the spatial observation; it is returned separately from the flat features.
  spatial_ob = ob.pop('observation', None)

  # Add a trailing dimension so the scalar features can be concatenated.
  ob['remaining_steps'] = tf.expand_dims(ob['remaining_steps'], axis=-1)
  ob['trial_remaining_steps'] = tf.expand_dims(ob['trial_remaining_steps'], axis=-1)
  ob['termination'] = tf.expand_dims(ob['termination'], axis=-1)
  ob['step_done'] = tf.expand_dims(ob['step_done'], axis=-1)
  ob['action_mask'] = tf.expand_dims(ob['action_mask'], axis=-1)
  ob['option_success'] = tf.expand_dims(ob['option_success'], axis=-1)
  # Concatenate the remaining flat features into a single tensor.
  flat_ob = tf.concat(tree.flatten(ob), axis=-1)
  return spatial_ob, flat_ob
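
A hypothetical observation dict exercising the function; the field names are taken from the code above, the shapes are illustrative.

import numpy as np

ob = {
    'observation': np.zeros((8, 8, 3), np.float32),  # spatial
    'remaining_steps': np.array([100.0]),
    'trial_remaining_steps': np.array([10.0]),
    'termination': np.array([0.0]),
    'step_done': np.array([0.0]),
    'action_mask': np.array([1.0]),
    'option_success': np.array([0.0]),
}
spatial_ob, flat_ob = preprocess_observation(ob)
# spatial_ob has shape (8, 8, 3); flat_ob concatenates the six scalar
# features into shape (1, 6).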
Example #7
  def step(self):
    """Does a step of SGD and logs the results."""

    # Retrieve a batch of data from replay.
    data = self._dataset.sample()
    data = tree_utils.fast_map_structure(lambda x: tf.convert_to_tensor(x), data)

    # Do a batch of SGD.
    results = self._step(data=data)

    # Compute elapsed time.
    timestamp = time.time()
    elapsed_time = timestamp - self._timestamp if self._timestamp else 0
    self._timestamp = timestamp

    # Update our counts and record it.
    counts = self._counter.increment(steps=1, walltime=elapsed_time)
    results.update(counts)

    # Snapshot and attempt to write logs.
    self._snapshotter.save()
    self._logger.write(results)
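
One small note on the pattern shared with Example #4: the lambda around tf.convert_to_tensor is redundant, since the function can be passed to fast_map_structure directly:

data = tree_utils.fast_map_structure(tf.convert_to_tensor, data)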