def denormalize(batch: types.NestedArray,
                mean_std: NestedMeanStd) -> types.NestedArray:
  """Denormalizes values in a nested structure using the given mean/std.

  Only values of inexact types are denormalized.
  See https://numpy.org/doc/stable/_images/dtype-hierarchy.png for Numpy type
  hierarchy.

  Args:
    batch: a nested structure containing batch of data.
    mean_std: mean and standard deviation used for denormalization.

  Returns:
    Nested structure with denormalized values.
  """

  def denormalize_leaf(data: jnp.ndarray, mean: jnp.ndarray,
                       std: jnp.ndarray) -> jnp.ndarray:
    # Only denormalize inexact types.
    if not np.issubdtype(data.dtype, np.inexact):
      return data
    return data * std + mean

  return tree_utils.fast_map_structure(denormalize_leaf, batch, mean_std.mean,
                                       mean_std.std)
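# Illustrative usage sketch for `denormalize` (not part of the original
# source). It assumes NestedMeanStd is a simple container exposing `mean` and
# `std` nests that mirror the batch structure, as the attribute access above
# suggests; the values here are made up.
def _example_denormalize() -> types.NestedArray:
  batch = {'reward': jnp.array([0.5, 1.5], dtype=jnp.float32),
           'step': jnp.array([1, 2], dtype=jnp.int32)}
  mean_std = NestedMeanStd(
      mean={'reward': jnp.array(1.0), 'step': jnp.array(0.0)},
      std={'reward': jnp.array(2.0), 'step': jnp.array(1.0)})
  # The float leaf is mapped back to `data * std + mean`; the int32 leaf is
  # returned unchanged because it is not an inexact dtype.
  return denormalize(batch, mean_std)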
def _validate_batch_shapes(batch: types.NestedArray,
                           reference_sample: types.NestedArray,
                           batch_dims: Tuple[int, ...]) -> None:
  """Verifies shapes of the batch leaves against the reference sample.

  Checks that batch dimensions are the same in all leaves in the batch.
  Checks that non-batch dimensions for all leaves in the batch are the same
  as in the reference sample.

  Arguments:
    batch: the nested batch of data to be verified.
    reference_sample: the nested array to check non-batch dimensions.
    batch_dims: a tuple of the leading batch dimensions expected on every
      leaf of the batch.

  Returns:
    None.
  """

  def validate_node_shape(reference_sample: jnp.ndarray,
                          batch: jnp.ndarray) -> None:
    expected_shape = batch_dims + reference_sample.shape
    assert batch.shape == expected_shape, f'{batch.shape} != {expected_shape}'

  tree_utils.fast_map_structure(validate_node_shape, reference_sample, batch)
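# Illustrative sketch for `_validate_batch_shapes` (not part of the original
# source): every leaf in the batch is expected to carry a leading (2, 3)
# prefix on top of the corresponding reference shape; the shapes are made up.
def _example_validate_batch_shapes() -> None:
  reference = {'obs': np.zeros((4,)), 'reward': np.zeros(())}
  batch = {'obs': np.zeros((2, 3, 4)), 'reward': np.zeros((2, 3))}
  # Passes silently; a leaf with a mismatched shape would trip the assertion.
  _validate_batch_shapes(batch, reference, batch_dims=(2, 3))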
def normalize(batch: types.NestedArray,
              mean_std: NestedMeanStd,
              max_abs_value: Optional[float] = None) -> types.NestedArray:
  """Normalizes data using running statistics."""

  def normalize_leaf(data: jnp.ndarray, mean: jnp.ndarray,
                     std: jnp.ndarray) -> jnp.ndarray:
    # Only normalize inexact types.
    if not jnp.issubdtype(data.dtype, jnp.inexact):
      return data
    data = (data - mean) / std
    if max_abs_value is not None:
      # TODO(b/124318564): remove pylint directive
      data = jnp.clip(data, -max_abs_value, +max_abs_value)  # pylint: disable=invalid-unary-operand-type
    return data

  return tree_utils.fast_map_structure(normalize_leaf, batch, mean_std.mean,
                                       mean_std.std)
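# Illustrative sketch for `normalize` (not part of the original source),
# reusing the hypothetical statistics from the `denormalize` example above.
# With `max_abs_value` set, normalized values are additionally clipped.
def _example_normalize() -> types.NestedArray:
  batch = {'reward': jnp.array([0.5, 100.0], dtype=jnp.float32),
           'step': jnp.array([1, 2], dtype=jnp.int32)}
  mean_std = NestedMeanStd(
      mean={'reward': jnp.array(1.0), 'step': jnp.array(0.0)},
      std={'reward': jnp.array(2.0), 'step': jnp.array(1.0)})
  # (0.5 - 1) / 2 = -0.25 is kept as is; (100 - 1) / 2 = 49.5 is clipped to 5.
  return normalize(batch, mean_std, max_abs_value=5.0)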
def update(self, last: bool = False):
  """Performs a meta-training update."""
  for _ in range(self._minibatch_size):
    # Retrieve a batch of data from replay.
    data = self._queue.sample()
    data = tree_utils.fast_map_structure(tf.convert_to_tensor, data)

    # Do a batch of SGD.
    results, gradients, logits = self._step(data=data)

    # Check gradients.
    # for g, v in zip(gradients, self._network.trainable_variables):
    #   name = v.name.replace('/', '-')
    #   results.update({name: tf.reduce_mean(g)})

    # Compute elapsed time.
    timestamp = time.time()
    elapsed_time = timestamp - self._timestamp if self._timestamp else 0
    self._timestamp = timestamp

    # Update our counts and record them.
    counts = self._counter.increment(steps=1, walltime=elapsed_time)
    results.update(counts)

    # Compute the KL divergence between the current and previous policy.
    pi = tfd.Categorical(logits=logits[:-1])
    if counts['steps'] > 1:
      kl_divergence = tf.reduce_mean(pi.kl_divergence(self._pi_old))
      results.update({'kl_divergence': kl_divergence})
    else:
      results.update({'kl_divergence': 0.0})
    self._pi_old = pi

    # Update the learning rate.
    self._learning_rate.assign(self._lr_scheduler(step=counts['steps']))

    # Snapshot and attempt to write logs.
    if hasattr(self, '_snapshotter'):
      self._snapshotter.save()
    if self._logger:
      self._logger.write(results)
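# Illustrative sketch (not part of the original source) of a scheduler shape
# that is compatible with the `self._lr_scheduler(step=...)` call above: any
# callable accepting a `step` keyword and returning a scalar works. The
# horizon and rates below are made-up values.
def make_linear_lr_schedule(initial_lr: float = 3e-4,
                            final_lr: float = 0.0,
                            total_steps: int = 1_000_000):
  def schedule(step: int) -> float:
    # Linearly interpolate between the initial and final learning rates.
    fraction = min(float(step) / total_steps, 1.0)
    return initial_lr + fraction * (final_lr - initial_lr)
  return schedule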
def clip(batch: types.NestedArray,
         clipping_config: NestClippingConfig) -> types.NestedArray:
  """Clips the batch."""

  def max_abs_value_for_path(path: Path, x: jnp.ndarray) -> Optional[float]:
    del x  # Unused, needed by interface.
    return next(
        (max_abs_value
         for clipping_path, max_abs_value in clipping_config.path_map
         if _is_prefix(clipping_path, path)), None)

  max_abs_values = tree_utils.fast_map_structure_with_path(
      max_abs_value_for_path, batch)

  def clip_leaf(data: jnp.ndarray,
                max_abs_value: Optional[float]) -> jnp.ndarray:
    if max_abs_value is not None:
      # TODO(b/124318564): remove pylint directive
      data = jnp.clip(data, -max_abs_value, +max_abs_value)  # pylint: disable=invalid-unary-operand-type
    return data

  return tree_utils.fast_map_structure(clip_leaf, batch, max_abs_values)
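# Illustrative sketch for `clip` (not part of the original source). It assumes
# NestClippingConfig simply carries the `path_map` of (path prefix,
# max_abs_value) pairs consumed above, with path prefixes given as key tuples.
def _example_clip() -> types.NestedArray:
  batch = {'observation': {'position': jnp.array([5.0, -7.0]),
                           'velocity': jnp.array([0.3, -0.2])}}
  config = NestClippingConfig(path_map=((('observation', 'position'), 1.0),))
  # Only the 'position' leaf matches a configured prefix and is clipped to
  # [-1, 1]; 'velocity' is returned unchanged.
  return clip(batch, config)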
def preprocess_observation(ob: Dict[str, np.ndarray]):
  # Type conversion to float32.
  ob = tree_utils.fast_map_structure(lambda x: tf.cast(x, tf.float32), ob)

  # Apply a logarithm to the remaining-steps counters, replacing any
  # resulting infinities with a small epsilon.
  def avoid_inf(x: tf.Tensor, epsilon: float = 1e-7):
    return tf.where(tf.math.is_inf(x), epsilon, x)

  ob['remaining_steps'] = avoid_inf(tf.math.log(ob['remaining_steps']))
  ob['trial_remaining_steps'] = avoid_inf(
      tf.math.log(ob['trial_remaining_steps']))

  # Pop the spatial observation.
  spatial_ob = ob.pop('observation', None)

  # Add a trailing feature dimension to the scalar observations.
  ob['remaining_steps'] = tf.expand_dims(ob['remaining_steps'], axis=-1)
  ob['trial_remaining_steps'] = tf.expand_dims(
      ob['trial_remaining_steps'], axis=-1)
  ob['termination'] = tf.expand_dims(ob['termination'], axis=-1)
  ob['step_done'] = tf.expand_dims(ob['step_done'], axis=-1)
  ob['action_mask'] = tf.expand_dims(ob['action_mask'], axis=-1)
  ob['option_success'] = tf.expand_dims(ob['option_success'], axis=-1)

  # Concatenate the remaining (non-spatial) observations into a flat vector.
  flat_ob = tf.concat(tree.flatten(ob), axis=-1)
  return spatial_ob, flat_ob
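# Illustrative call sketch (not part of the original source). The shapes are
# made up; the point is that every key used above must be present, with
# 'observation' holding the spatial part that is split off from the flat
# features.
def _example_preprocess_observation():
  ob = {
      'observation': np.zeros((64, 64, 3), dtype=np.float32),
      'remaining_steps': np.array(100.0, dtype=np.float32),
      'trial_remaining_steps': np.array(10.0, dtype=np.float32),
      'termination': np.array(0.0, dtype=np.float32),
      'step_done': np.array(0.0, dtype=np.float32),
      'action_mask': np.array(1.0, dtype=np.float32),
      'option_success': np.array(0.0, dtype=np.float32),
  }
  # Returns the (64, 64, 3) spatial tensor and a flat vector concatenating
  # the six remaining scalar features.
  spatial_ob, flat_ob = preprocess_observation(ob)
  return spatial_ob, flat_ob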
def step(self):
  """Does a step of SGD and logs the results."""
  # Retrieve a batch of data from replay.
  data = self._dataset.sample()
  data = tree_utils.fast_map_structure(tf.convert_to_tensor, data)

  # Do a batch of SGD.
  results = self._step(data=data)

  # Compute elapsed time.
  timestamp = time.time()
  elapsed_time = timestamp - self._timestamp if self._timestamp else 0
  self._timestamp = timestamp

  # Update our counts and record them.
  counts = self._counter.increment(steps=1, walltime=elapsed_time)
  results.update(counts)

  # Snapshot and attempt to write logs.
  self._snapshotter.save()
  self._logger.write(results)
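# Illustrative driver sketch (not part of the original source): whoever owns
# this learner typically just calls `step()` in a loop; the step count here
# is an arbitrary choice.
def run_training_loop(learner, num_steps: int = 1_000) -> None:
  for _ in range(num_steps):
    learner.step()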