def _adjust_obs_actions_for_policy(self, json_data: dict,
                                   policy: Policy) -> dict:
    """Handle nested action/observation spaces for policies.

    Translates nested lists/dicts from the json into proper
    np.ndarrays, according to the (nested) observation- and action-
    spaces of the given policy. Providing nested lists w/o this
    preprocessing step would confuse a SampleBatch constructor.
    """
    for k, v in policy.view_requirements.items():
        if k not in json_data:
            continue
        if policy.config.get("_disable_action_flattening") and \
                (k == SampleBatch.ACTIONS or
                 v.data_col == SampleBatch.ACTIONS):
            json_data[k] = tree.map_structure_up_to(
                policy.action_space_struct,
                lambda comp: np.array(comp),
                json_data[k],
                check_types=False,
            )
        elif policy.config.get("_disable_preprocessor_api") and \
                (k == SampleBatch.OBS or
                 v.data_col == SampleBatch.OBS):
            json_data[k] = tree.map_structure_up_to(
                policy.observation_space_struct,
                lambda comp: np.array(comp),
                json_data[k],
                check_types=False,
            )
    return json_data
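# Illustrative sketch (not from the source): how map_structure_up_to turns
# nested JSON lists into np.ndarrays while stopping at the depth of a space
# struct. `space_struct` below is a hypothetical stand-in for something like
# `policy.action_space_struct`.
import numpy as np
import tree

space_struct = {"pos": None, "cmd": (None, None)}  # leaves mark array slots
json_actions = {"pos": [[0.1, 0.2], [0.3, 0.4]], "cmd": ([1, 2], [3, 4])}
arrays = tree.map_structure_up_to(
    space_struct, np.array, json_actions, check_types=False)
assert arrays["pos"].shape == (2, 2)  # inner lists became one 2-D array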
def get_params_fn(state):
  params = tree.map_structure_up_to(encoders, lambda e, s: e.get_params(s),
                                    encoders, state)
  encode_params = _slice(encoders, params, 0)
  decode_before_sum_params = _slice(encoders, params, 1)
  decode_after_sum_params = _slice(encoders, params, 2)
  return encode_params, decode_before_sum_params, decode_after_sum_params
def __next__(self) -> types.NestedArray:
  try:
    if not self.pmapped_user:
      item = next(self.iterator)
      if self.split_fn is None:
        return jax.device_put(item, self.devices[0])
      item_split = self.split_fn(item)
      return PrefetchingSplit(
          host=item_split.host,
          device=jax.device_put(item_split.device, self.devices[0]))

    items = itertools.islice(self.iterator, self.num_devices)
    items = tuple(items)
    if len(items) < self.num_devices:
      raise StopIteration
    if self.split_fn is None:
      return jax.device_put_sharded(tuple(items), self.devices)
    else:
      # ((host: x1, device: y1), ..., (host: xN, device: yN)).
      items_split = (self.split_fn(item) for item in items)
      # (host: (x1, ..., xN), device: (y1, ..., yN)).
      split = tree.map_structure_up_to(
          PrefetchingSplit(None, None), lambda *x: x, *items_split)
      return PrefetchingSplit(
          host=np.stack(split.host),
          device=jax.device_put_sharded(split.device, self.devices))
  except StopIteration:
    raise
  except Exception:  # pylint: disable=broad-except
    logging.exception('Error for %s', self.iterable)
    raise
def producer():
  """Enqueues batched items from `iterable` on a given thread."""
  try:
    # Build a new iterable for each thread. This is crucial if working with
    # tensorflow datasets because tf.Graph objects are thread local.
    it = iter(iterable)
    while True:
      # Materialize the slice: an islice object is always truthy, so the
      # emptiness check below only works on a concrete tuple.
      items = tuple(itertools.islice(it, len(devices)))
      if not items:
        break
      if split_fn is None:
        # `jax.api` is the legacy alias; newer JAX exposes
        # device_put_sharded at the top level.
        buffer.put(jax.api.device_put_sharded(tuple(items), devices))
      else:
        # ((host: x1, device: y1), ..., (host: xN, device: yN)).
        items_split = (split_fn(item) for item in items)
        # (host: (x1, ..., xN), device: (y1, ..., yN)).
        split = tree.map_structure_up_to(
            PrefetchingSplit(None, None), lambda *x: x, *items_split)
        buffer.put(
            PrefetchingSplit(
                host=np.stack(split.host),
                device=jax.api.device_put_sharded(split.device, devices)))
  except Exception as e:  # pylint: disable=broad-except
    logging.exception('Error in producer thread for %s', iterable.__name__)
    producer_error.append(e)
  finally:
    buffer.put(end)
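# Illustrative sketch (not from the source): the transpose pattern used in
# the two snippets above, where map_structure_up_to with a shallow
# PrefetchingSplit(None, None) turns a sequence of splits into a split of
# sequences. PrefetchingSplit is redefined locally for self-containment.
import collections
import tree

PrefetchingSplit = collections.namedtuple('PrefetchingSplit',
                                          ['host', 'device'])
splits = [PrefetchingSplit(host=i, device=10 * i) for i in range(3)]
# ((host: 0, device: 0), (host: 1, device: 10), (host: 2, device: 20))
# -> (host: (0, 1, 2), device: (0, 10, 20)).
transposed = tree.map_structure_up_to(
    PrefetchingSplit(None, None), lambda *xs: xs, *splits)
assert transposed.host == (0, 1, 2)
assert transposed.device == (0, 10, 20)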
def _map_to_queries(self, fn, *inputs, **kwargs):
  """Maps DPQuery methods to the subqueries."""

  def caller(query, *args):
    return getattr(query, fn)(*args, **kwargs)

  return tree.map_structure_up_to(self._queries, caller, self._queries,
                                  *inputs)
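# Illustrative sketch (not from the source): dispatching a method name across
# a nested structure of query-like objects, as _map_to_queries does. The
# _Sum class is a hypothetical stand-in for a DPQuery.
import tree

class _Sum:
  def accumulate(self, record):
    return sum(record)

queries = {'a': _Sum(), 'b': (_Sum(), _Sum())}
records = {'a': [1, 2], 'b': ([3], [4, 5])}

def caller(query, *args):
  return getattr(query, 'accumulate')(*args)

out = tree.map_structure_up_to(queries, caller, queries, records)
assert out == {'a': 3, 'b': (3, 9)}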
def encode_fn(x, encode_params, decode_before_sum_params):
  encoded_structure = tree.map_structure_up_to(
      encoders, lambda e, *args: e.encode(*args), encoders, x, encode_params)
  encoded_x = _slice(encoders, encoded_structure, 0)
  state_update_tensors = _slice(encoders, encoded_structure, 1)
  return encoded_x, decode_before_sum_params, state_update_tensors
def feedforward_Q_function(input_shapes,
                           *args,
                           preprocessors=None,
                           observation_keys=None,
                           name='feedforward_Q',
                           **kwargs):
    inputs = create_inputs(input_shapes)

    if preprocessors is None:
        preprocessors = tree.map_structure(lambda _: None, inputs)

    preprocessors = tree.map_structure_up_to(
        inputs, preprocessors_lib.deserialize, preprocessors)

    preprocessed_inputs = apply_preprocessors(preprocessors, inputs)

    # NOTE(hartikainen): `feedforward_model` would do the `cast_and_concat`
    # step for us, but tf2.2 broke the sequential multi-input handling. See:
    # https://github.com/tensorflow/tensorflow/issues/37061.
    out = tf.keras.layers.Lambda(cast_and_concat)(preprocessed_inputs)
    Q_model_body = feedforward_model(*args, output_shape=[1], name=name,
                                     **kwargs)
    Q_model = tf.keras.Model(inputs, Q_model_body(out), name=name)
    Q_function = StateActionValueFunction(
        model=Q_model, observation_keys=observation_keys, name=name)

    return Q_function
def encode(state, value):
  """Encode tf_computation."""
  encoded_structure = tree.map_structure_up_to(
      encoders, lambda state, value, e: e.encode(value, state), state, value,
      encoders)
  encoded_value = _slice(encoders, encoded_structure, 0)
  new_state = _slice(encoders, encoded_structure, 1)
  return new_state, encoded_value
def outer_traverse_fn(subtree):
  res = traverse_fn(subtree)
  if is_nested(res):

    def inner_traverse_fn(do_traverse, subtree):
      if do_traverse:
        return dm_tree.traverse(outer_traverse_fn, subtree)
      else:
        return _FALSE_SENTINEL

    return dm_tree.map_structure_up_to(res, inner_traverse_fn, res, subtree)
  else:
    return None if res else _FALSE_SENTINEL
def testMapStructureUpTo(self):
  # Named tuples.
  ab_tuple = collections.namedtuple("ab_tuple", "a, b")
  op_tuple = collections.namedtuple("op_tuple", "add, mul")
  inp_val = ab_tuple(a=2, b=3)
  inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3))
  out = tree.map_structure_up_to(
      inp_val, lambda val, ops: (val + ops.add) * ops.mul, inp_val, inp_ops,
      check_types=False)
  self.assertEqual(out.a, 6)
  self.assertEqual(out.b, 15)

  # Lists.
  data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]]
  name_list = ["evens", ["odds", "primes"]]
  out = tree.map_structure_up_to(
      name_list, lambda name, sec: "first_{}_{}".format(len(sec), name),
      name_list, data_list)
  self.assertEqual(out,
                   ["first_4_evens", ["first_5_odds", "first_3_primes"]])
def step(self, params, grads, states, itr=None):
  """Takes a single optimizer step.

  Args:
    params: a dict containing the parameters to be updated.
    grads: a dict containing the gradients for each parameter in params.
    states: a dict containing any optimizer buffers (momentum, etc.) for
      each parameter in params.
    itr: an optional integer indicating the current step, for scheduling.

  Returns:
    The updated params and optimizer buffers.
  """
  get_hyper = lambda k, v: self.get_opt_params('/'.join(k), itr)
  hypers = tree.map_structure_with_path(get_hyper, params)
  outs = tree.map_structure_up_to(params, self.update_param, params, grads,
                                  states, hypers)
  return utils.split_tree(outs, params, 2)
def _slice(encoders, nested_value, idx):
  """Takes a slice of nested values.

  We use a collection of encoders to encode a collection of values. When a
  method of the encoder returns a tuple, e.g., the encode/decode params
  returned by the get_params method, we need to recover the matching
  collection of encode params and collection of decode params. This method
  is a utility to achieve this.

  Args:
    encoders: A collection of encoders.
    nested_value: A collection of indexable values of the same structure as
      `encoders`.
    idx: An integer. Index of the values in `nested_value` along which to
      take the slice.

  Returns:
    A collection of values of the same structure as `encoders`.
  """
  return tree.map_structure_up_to(encoders, lambda t: t[idx], nested_value)
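# Illustrative sketch (not from the source): slicing a structure of tuples
# into per-index structures, as _slice does. The encoder collection here is
# a hypothetical stand-in dict; each "method result" is an
# (encode, decode) pair.
import tree

encoders = {'w': None, 'b': None}  # leaves mark one encoder each
params = {'w': ('w_encode', 'w_decode'), 'b': ('b_encode', 'b_decode')}
encode_params = tree.map_structure_up_to(encoders, lambda t: t[0], params)
decode_params = tree.map_structure_up_to(encoders, lambda t: t[1], params)
assert encode_params == {'w': 'w_encode', 'b': 'b_encode'}
assert decode_params == {'w': 'w_decode', 'b': 'b_decode'}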
def __init__(self,
             input_shapes,
             output_shape,
             observation_keys=None,
             preprocessors=None,
             name='policy'):
    self._input_shapes = input_shapes
    self._output_shape = output_shape
    self._observation_keys = observation_keys
    self._create_inputs(input_shapes)

    if preprocessors is None:
        preprocessors = tree.map_structure(lambda x: None, input_shapes)

    preprocessors = tree.map_structure_up_to(
        input_shapes, preprocessors_lib.deserialize, preprocessors)

    self._preprocessors = preprocessors
    self._name = name
def __call__(self, x: TensorStructType,
             update: bool = True) -> TensorStructType:
    if self.no_preprocessor:
        x = tree.map_structure(lambda x_: np.asarray(x_), x)
    else:
        x = np.asarray(x)

    def _helper(x, rs, buffer, shape):
        # Discrete|MultiDiscrete spaces -> No normalization.
        if shape is None:
            return x

        # Keep dtype as-is throughout this filter.
        orig_dtype = x.dtype

        if update:
            if len(x.shape) == len(rs.shape) + 1:
                # The vectorized case.
                for i in range(x.shape[0]):
                    rs.push(x[i])
                    buffer.push(x[i])
            else:
                # The unvectorized case.
                rs.push(x)
                buffer.push(x)

        if self.demean:
            x = x - rs.mean
        if self.destd:
            x = x / (rs.std + SMALL_NUMBER)
        if self.clip:
            x = np.clip(x, -self.clip, self.clip)

        return x.astype(orig_dtype)

    if self.no_preprocessor:
        return tree.map_structure_up_to(x, _helper, x, self.rs, self.buffer,
                                        self.shape)
    else:
        return _helper(x, self.rs, self.buffer, self.shape)
def update_state_fn(state, state_update_tensors):
  return tree.map_structure_up_to(encoders,
                                  lambda e, *args: e.update_state(*args),
                                  encoders, state, state_update_tensors)
def decode_after_sum_fn(summed_values, decode_after_sum_params):
  part_decoded_aggregated_x, num_summands = summed_values
  return tree.map_structure_up_to(
      encoders,
      lambda e, x, params: e.decode_after_sum(x, params, num_summands),
      encoders, part_decoded_aggregated_x, decode_after_sum_params)
def decode_before_sum_tf_function(encoded_x, decode_before_sum_params):
  part_decoded_x = tree.map_structure_up_to(
      encoders, lambda e, *args: e.decode_before_sum(*args), encoders,
      encoded_x, decode_before_sum_params)
  one = tf.constant((1,), tf.int32)
  return part_decoded_x, one
def decode(encoded_value):
  """Decode tf_computation."""
  return tree.map_structure_up_to(encoders, lambda e, val: e.decode(val),
                                  encoders, encoded_value)
def split_tree(tuple_tree, base_tree, n):
  """Splits tuple_tree with n-tuple leaves into n trees."""
  return [
      tree.map_structure_up_to(base_tree, lambda x: x[i], tuple_tree)  # pylint: disable=cell-var-from-loop
      for i in range(n)
  ]
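# Illustrative sketch (not from the source): split_tree applied to a tree
# whose leaves are 2-tuples, yielding one tree of first elements and one of
# second elements. Binding i=i in the lambda avoids the late-binding pitfall
# the pylint disable above refers to.
import tree

base = {'layer': {'w': None, 'b': None}}  # leaves mark the split points
tuples = {'layer': {'w': (1, 10), 'b': (2, 20)}}
firsts, seconds = [
    tree.map_structure_up_to(base, lambda x, i=i: x[i], tuples)
    for i in range(2)
]
assert firsts == {'layer': {'w': 1, 'b': 2}}
assert seconds == {'layer': {'w': 10, 'b': 20}}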
def update(state: RunningStatisticsState,
           batch: types.NestedArray,
           *,
           config: NestStatisticsConfig = NestStatisticsConfig(),
           weights: Optional[jnp.ndarray] = None,
           std_min_value: float = 1e-6,
           std_max_value: float = 1e6,
           pmap_axis_name: Optional[str] = None,
           validate_shapes: bool = True) -> RunningStatisticsState:
  """Updates the running statistics with the given batch of data.

  Note: data batch and state elements (mean, etc.) must have the same
  structure.

  Note: by default will use int32 for counts and float32 for accumulated
  variance. This results in an integer overflow after 2^31 data points and
  degrading precision after 2^24 batch updates or even earlier if variance
  updates have large dynamic range. To improve precision, consider setting
  jax_enable_x64 to True, see
  https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision

  Arguments:
    state: The running statistics before the update.
    batch: The data to be used to update the running statistics.
    config: The config that specifies which leaves of the nested structure
      the running statistics should be computed for.
    weights: Weights of the batch data. Should match the batch dimensions.
      Passing a weight of 2. should be equivalent to updating on the
      corresponding data point twice.
    std_min_value: Minimum value for the standard deviation.
    std_max_value: Maximum value for the standard deviation.
    pmap_axis_name: Name of the pmapped axis, if any.
    validate_shapes: If true, the shapes of all leaves of the batch will be
      validated. Enabled by default. Doesn't impact performance when jitted.

  Returns:
    Updated running statistics.
  """
  # We require exactly the same structure to avoid issues when flattened
  # batch and state have different order of elements.
  tree.assert_same_structure(batch, state.mean)
  batch_shape = tree.flatten(batch)[0].shape
  # We assume the batch dimensions always go first.
  batch_dims = batch_shape[:len(batch_shape) -
                           tree.flatten(state.mean)[0].ndim]
  batch_axis = range(len(batch_dims))
  if weights is None:
    step_increment = np.prod(batch_dims)
  else:
    step_increment = jnp.sum(weights)
  if pmap_axis_name is not None:
    step_increment = jax.lax.psum(step_increment, axis_name=pmap_axis_name)
  count = state.count + step_increment

  # Validation is important. If the shapes don't match exactly, but are
  # compatible, arrays will be silently broadcasted resulting in incorrect
  # statistics.
  if validate_shapes:
    if weights is not None:
      if weights.shape != batch_dims:
        raise ValueError(f'{weights.shape} != {batch_dims}')
    _validate_batch_shapes(batch, state.mean, batch_dims)

  def _compute_node_statistics(
      path: Path, mean: jnp.ndarray, summed_variance: jnp.ndarray,
      batch: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    assert isinstance(mean, jnp.ndarray), type(mean)
    assert isinstance(summed_variance, jnp.ndarray), type(summed_variance)
    if not _is_path_included(config, path):
      # Return unchanged.
      return mean, summed_variance
    # The mean and the sum of past variances are updated with Welford's
    # algorithm using batches (see https://stackoverflow.com/q/56402955).
    diff_to_old_mean = batch - mean
    if weights is not None:
      expanded_weights = jnp.reshape(
          weights,
          list(weights.shape) + [1] * (batch.ndim - weights.ndim))
      diff_to_old_mean = diff_to_old_mean * expanded_weights
    mean_update = jnp.sum(diff_to_old_mean, axis=batch_axis) / count
    if pmap_axis_name is not None:
      mean_update = jax.lax.psum(mean_update, axis_name=pmap_axis_name)
    mean = mean + mean_update

    diff_to_new_mean = batch - mean
    variance_update = diff_to_old_mean * diff_to_new_mean
    variance_update = jnp.sum(variance_update, axis=batch_axis)
    if pmap_axis_name is not None:
      variance_update = jax.lax.psum(variance_update,
                                     axis_name=pmap_axis_name)
    summed_variance = summed_variance + variance_update
    return mean, summed_variance

  updated_stats = tree_utils.fast_map_structure_with_path(
      _compute_node_statistics, state.mean, state.summed_variance, batch)
  # map_structure_up_to is slow, so shortcut if we know the input is not
  # structured.
  if isinstance(state.mean, jnp.ndarray):
    mean, summed_variance = updated_stats
  else:
    # Reshape the updated stats from `nest(mean, summed_variance)` to
    # `nest(mean), nest(summed_variance)`.
    mean, summed_variance = [
        tree.map_structure_up_to(state.mean, lambda s, i=idx: s[i],
                                 updated_stats) for idx in range(2)
    ]

  def compute_std(path: Path, summed_variance: jnp.ndarray,
                  std: jnp.ndarray) -> jnp.ndarray:
    assert isinstance(summed_variance, jnp.ndarray)
    if not _is_path_included(config, path):
      return std
    # Summed variance can get negative due to rounding errors.
    summed_variance = jnp.maximum(summed_variance, 0)
    std = jnp.sqrt(summed_variance / count)
    std = jnp.clip(std, std_min_value, std_max_value)
    return std

  std = tree_utils.fast_map_structure_with_path(compute_std, summed_variance,
                                                state.std)

  return RunningStatisticsState(
      count=count, mean=mean, summed_variance=summed_variance, std=std)
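# Illustrative sketch (not from the source): the batched Welford update used
# in `update` above, written in plain numpy for a single unweighted leaf.
import numpy as np

def welford_batch_update(mean, summed_variance, count, batch):
  """One batched Welford step; `batch` has shape (n,) + mean.shape."""
  count = count + batch.shape[0]
  diff_to_old_mean = batch - mean
  mean = mean + diff_to_old_mean.sum(axis=0) / count
  diff_to_new_mean = batch - mean
  summed_variance = summed_variance + (
      diff_to_old_mean * diff_to_new_mean).sum(axis=0)
  return mean, summed_variance, count

mean, summed_variance, count = np.zeros(2), np.zeros(2), 0
for batch in np.split(np.random.randn(100, 2), 10):
  mean, summed_variance, count = welford_batch_update(
      mean, summed_variance, count, batch)
# std as computed in compute_std above:
std = np.sqrt(np.maximum(summed_variance, 0) / count)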