示例#1
0
    def _adjust_obs_actions_for_policy(self, json_data: dict,
                                       policy: Policy) -> dict:
        """Convert nested json observations/actions into np.ndarrays.

        Iterates over the policy's view requirements. For every column that
        is present in `json_data`:

        - action columns are rebuilt along `policy.action_space_struct`, but
          only if the policy config sets `_disable_action_flattening`;
        - otherwise, observation columns are rebuilt along
          `policy.observation_space_struct`, but only if the policy config
          sets `_disable_preprocessor_api`.

        Each component at the level of the space struct is turned into an
        np.ndarray. Raw nested lists would otherwise confuse a SampleBatch
        constructor.

        Args:
            json_data: Mapping from column name to (possibly nested) data
                read from json. Updated in place.
            policy: Policy providing `view_requirements`, `config` and the
                action/observation space structs.

        Returns:
            The same `json_data` dict, with the matching columns converted.
        """

        def _to_arrays(space_struct, column):
            # Only descend as deep as `space_struct`; everything below one of
            # its leaves becomes a single array.
            return tree.map_structure_up_to(
                space_struct,
                np.array,
                column,
                check_types=False,
            )

        for col, view_req in policy.view_requirements.items():
            if col not in json_data:
                continue

            if policy.config.get("_disable_action_flattening") and (
                    col == SampleBatch.ACTIONS
                    or view_req.data_col == SampleBatch.ACTIONS):
                json_data[col] = _to_arrays(policy.action_space_struct,
                                            json_data[col])
            elif policy.config.get("_disable_preprocessor_api") and (
                    col == SampleBatch.OBS
                    or view_req.data_col == SampleBatch.OBS):
                json_data[col] = _to_arrays(policy.observation_space_struct,
                                            json_data[col])

        return json_data
示例#2
0
 def get_params_fn(state):
   """Collect the params of every encoder and split them by role.

   Calls `get_params` on each encoder with its matching entry of `state`,
   then takes slices 0, 1 and 2 of the per-encoder results.

   Returns:
     A tuple `(encode_params, decode_before_sum_params,
     decode_after_sum_params)`.
   """

   def _params_of(encoder, encoder_state):
     return encoder.get_params(encoder_state)

   all_params = tree.map_structure_up_to(encoders, _params_of, encoders,
                                         state)
   return tuple(_slice(encoders, all_params, position)
                for position in range(3))
示例#3
0
文件: utils.py 项目: deepmind/acme
    def __next__(self) -> types.NestedArray:
        """Return the next element, placed on the configured device(s).

        When `self.pmapped_user` is false, a single item is pulled from
        `self.iterator` and put on `self.devices[0]`. Otherwise
        `self.num_devices` items are pulled and sharded across
        `self.devices`. If `self.split_fn` is set, each item is split first
        and the result is returned as a `PrefetchingSplit`.

        Raises:
            StopIteration: When the iterator is exhausted, or (in the
                pmapped case) fewer than `self.num_devices` items are left.
                Those leftover items are consumed and discarded.
        """
        try:
            if not self.pmapped_user:
                element = next(self.iterator)
                if self.split_fn is None:
                    return jax.device_put(element, self.devices[0])
                parts = self.split_fn(element)
                host_part = parts.host
                device_part = jax.device_put(parts.device, self.devices[0])
                return PrefetchingSplit(host=host_part, device=device_part)

            group = tuple(itertools.islice(self.iterator, self.num_devices))
            if len(group) < self.num_devices:
                raise StopIteration

            if self.split_fn is None:
                return jax.device_put_sharded(group, self.devices)

            # ((host: x1, device: y1), ..., (host: xN, device: yN)).
            per_item = (self.split_fn(element) for element in group)
            # (host: (x1, ..., xN), device: (y1, ..., yN)).
            merged = tree.map_structure_up_to(PrefetchingSplit(None, None),
                                              lambda *leaves: leaves,
                                              *per_item)
            stacked_host = np.stack(merged.host)
            sharded_device = jax.device_put_sharded(merged.device,
                                                    self.devices)
            return PrefetchingSplit(host=stacked_host, device=sharded_device)

        except StopIteration:
            # Normal end of iteration: propagate without logging.
            raise

        except Exception:  # pylint: disable=broad-except
            logging.exception('Error for %s', self.iterable)
            raise
示例#4
0
    def producer():
        """Enqueues batched items from `iterable` on a given thread.

        Repeatedly pulls `len(devices)` items from a fresh iterator over
        `iterable`, shards them across `devices` (after applying `split_fn`
        if one is set) and puts the result on `buffer`. The loop stops once
        the iterator cannot supply a full group any more.

        Any exception is logged and appended to `producer_error`. The `end`
        marker is always put on `buffer` before the function returns.
        """
        try:
            # Build a new iterable for each thread. This is crucial if working with
            # tensorflow datasets because tf.Graph objects are thread local.
            it = iter(iterable)
            while True:
                # An `islice` object is always truthy, so it has to be
                # materialized before exhaustion can be detected.
                items = tuple(itertools.islice(it, len(devices)))
                if not items or len(items) < len(devices):
                    # Exhausted. An incomplete trailing group cannot be
                    # sharded one item per device, so it is dropped.
                    break
                if split_fn is None:
                    buffer.put(jax.device_put_sharded(items, devices))
                else:
                    # ((host: x1, device: y1), ..., (host: xN, device: yN)).
                    items_split = (split_fn(item) for item in items)
                    # (host: (x1, ..., xN), device: (y1, ..., yN)).
                    split = tree.map_structure_up_to(
                        PrefetchingSplit(None, None), lambda *x: x,
                        *items_split)

                    buffer.put(
                        PrefetchingSplit(host=np.stack(split.host),
                                         device=jax.device_put_sharded(
                                             split.device, devices)))
        except Exception as e:  # pylint: disable=broad-except
            # Not every iterable has a `__name__`; fall back to the object
            # itself so that logging cannot raise and hide the real error.
            logging.exception('Error in producer thread for %s',
                              getattr(iterable, '__name__', iterable))
            producer_error.append(e)
        finally:
            buffer.put(end)
示例#5
0
    def _map_to_queries(self, fn, *inputs, **kwargs):
        """Call the method named `fn` on every subquery.

        Args:
            fn: Name of the method to look up on each subquery.
            *inputs: Structures mapped alongside `self._queries`; the
                matching entry of each is passed positionally.
            **kwargs: Keyword arguments forwarded unchanged to every call.

        Returns:
            The per-subquery results, arranged like `self._queries`.
        """

        def invoke(subquery, *positional):
            method = getattr(subquery, fn)
            return method(*positional, **kwargs)

        return tree.map_structure_up_to(self._queries, invoke, self._queries,
                                        *inputs)
示例#6
0
 def encode_fn(x, encode_params, decode_before_sum_params):
     """Encode `x` with every encoder and split the per-encoder outputs.

     Returns:
       A tuple `(encoded_x, decode_before_sum_params, state_update_tensors)`.
       The first and last entries are slices 0 and 1 of the encode results;
       `decode_before_sum_params` is passed through unchanged.
     """

     def _encode_one(encoder, *encode_args):
         return encoder.encode(*encode_args)

     per_encoder = tree.map_structure_up_to(encoders, _encode_one, encoders,
                                            x, encode_params)
     encoded_x, state_update_tensors = (
         _slice(encoders, per_encoder, position) for position in (0, 1))
     return encoded_x, decode_before_sum_params, state_update_tensors
示例#7
0
def feedforward_Q_function(input_shapes,
                           *args,
                           preprocessors=None,
                           observation_keys=None,
                           name='feedforward_Q',
                           **kwargs):
    """Build a feedforward state-action value function.

    Args:
        input_shapes: Structure of shapes handed to `create_inputs`.
        *args: Positional arguments forwarded to `feedforward_model`.
        preprocessors: Optional structure of preprocessor specs matching
            the inputs. When None, every input gets `None`.
        observation_keys: Forwarded to `StateActionValueFunction`.
        name: Name used for the model body, the Keras model and the
            returned value function.
        **kwargs: Keyword arguments forwarded to `feedforward_model`.

    Returns:
        A `StateActionValueFunction` wrapping the assembled Keras model.
    """
    inputs = create_inputs(input_shapes)

    raw_preprocessors = (
        tree.map_structure(lambda _: None, inputs)
        if preprocessors is None else preprocessors)
    deserialized = tree.map_structure_up_to(inputs,
                                            preprocessors_lib.deserialize,
                                            raw_preprocessors)

    preprocessed = apply_preprocessors(deserialized, inputs)

    # NOTE(hartikainen): `feedforward_model` would do the `cast_and_concat`
    # step for us, but tf2.2 broke the sequential multi-input handling: See:
    # https://github.com/tensorflow/tensorflow/issues/37061.
    features = tf.keras.layers.Lambda(cast_and_concat)(preprocessed)
    body = feedforward_model(*args, output_shape=[1], name=name, **kwargs)

    q_model = tf.keras.Model(inputs, body(features), name=name)

    return StateActionValueFunction(model=q_model,
                                    observation_keys=observation_keys,
                                    name=name)
示例#8
0
 def encode(state, value):
     """Encode tf_computation.

     Calls `encode(value, state)` on every encoder with its matching entries
     of `value` and `state`.

     Returns:
       A tuple `(new_state, encoded_value)`, taken from slices 1 and 0 of the
       per-encoder results respectively.
     """

     def _encode_leaf(encoder_state, leaf_value, encoder):
         return encoder.encode(leaf_value, encoder_state)

     per_encoder = tree.map_structure_up_to(encoders, _encode_leaf, state,
                                            value, encoders)
     encoded_value = _slice(encoders, per_encoder, 0)
     new_state = _slice(encoders, per_encoder, 1)
     return new_state, encoded_value
示例#9
0
  def outer_traverse_fn(subtree):
    """Adapt `traverse_fn`'s result into the form used with `dm_tree.traverse`.

    A non-nested result maps to `None` when truthy and to `_FALSE_SENTINEL`
    otherwise. A nested result is treated as a structure of per-child
    decisions: children flagged truthy are traversed recursively, the others
    are replaced by `_FALSE_SENTINEL`.
    """
    decision = traverse_fn(subtree)

    if not is_nested(decision):
      return None if decision else _FALSE_SENTINEL

    def inner_traverse_fn(do_traverse, child):
      if not do_traverse:
        return _FALSE_SENTINEL
      return dm_tree.traverse(outer_traverse_fn, child)

    return dm_tree.map_structure_up_to(decision, inner_traverse_fn, decision,
                                       subtree)
示例#10
0
    def testMapStructureUpTo(self):
        """Checks `map_structure_up_to` on named tuples and nested lists."""
        # Named tuples: the shallow structure holds plain values, the deeper
        # one holds an `op_tuple` per field.
        ab_tuple = collections.namedtuple("ab_tuple", "a, b")
        op_tuple = collections.namedtuple("op_tuple", "add, mul")
        values = ab_tuple(a=2, b=3)
        operations = ab_tuple(a=op_tuple(add=1, mul=2),
                              b=op_tuple(add=2, mul=3))

        def apply_ops(val, ops):
            return (val + ops.add) * ops.mul

        result = tree.map_structure_up_to(values,
                                          apply_ops,
                                          values,
                                          operations,
                                          check_types=False)
        self.assertEqual(result.a, 6)
        self.assertEqual(result.b, 15)

        # Lists: every name is paired with the whole sub-list sitting at the
        # same position of the deeper structure.
        data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]]
        name_list = ["evens", ["odds", "primes"]]

        def describe(name, sec):
            return "first_{}_{}".format(len(sec), name)

        result = tree.map_structure_up_to(name_list, describe, name_list,
                                          data_list)
        expected = ["first_4_evens", ["first_5_odds", "first_3_primes"]]
        self.assertEqual(result, expected)
示例#11
0
    def step(self, params, grads, states, itr=None):
        """Takes a single optimizer step.

        Hyperparameters are looked up per parameter via
        `self.get_opt_params`, keyed by the parameter's path joined with
        '/'. `self.update_param` is then applied to every parameter together
        with its gradient, optimizer state and hyperparameters.

        Args:
          params: a dict containing the parameters to be updated.
          grads: a dict containing the gradients for each parameter in
            params.
          states: a dict containing any optimizer buffers (momentum, etc)
            for each parameter in params.
          itr: an optional integer indicating the current step, for
            scheduling.

        Returns:
          The updated params and optimizer buffers.
        """

        def hyper_for(path, unused_leaf):
            return self.get_opt_params('/'.join(path), itr)

        hypers = tree.map_structure_with_path(hyper_for, params)
        updated = tree.map_structure_up_to(params, self.update_param, params,
                                           grads, states, hypers)
        return utils.split_tree(updated, params, 2)
示例#12
0
def _slice(encoders, nested_value, idx):
    """Pick entry `idx` from every value of a nested collection.

    A collection of encoders is used to encode a collection of values. When
    an encoder method returns a tuple (for example the encode / decode params
    from `get_params`), this helper recovers one collection per tuple
    position, e.g. the collection of encode params.

    Args:
      encoders: A collection of encoders; defines the structure to map over.
      nested_value: A collection of indexable values with the same structure
        as `encoders`.
      idx: An integer. Index to take from each value in `nested_value`.

    Returns:
      A collection of values with the same structure as `encoders`.
    """

    def _pick(indexable):
        return indexable[idx]

    return tree.map_structure_up_to(encoders, _pick, nested_value)
示例#13
0
    def __init__(self,
                 input_shapes,
                 output_shape,
                 observation_keys=None,
                 preprocessors=None,
                 name='policy'):
        """Store the configuration and deserialize the preprocessors.

        Args:
            input_shapes: Structure of input shapes; also handed to
                `self._create_inputs`.
            output_shape: Stored as `self._output_shape`.
            observation_keys: Stored as `self._observation_keys`.
            preprocessors: Optional structure of preprocessor specs matching
                `input_shapes`. When None, every input gets `None`.
            name: Stored as `self._name`.
        """
        self._input_shapes = input_shapes
        self._output_shape = output_shape
        self._observation_keys = observation_keys
        self._create_inputs(input_shapes)

        raw_preprocessors = (
            tree.map_structure(lambda x: None, input_shapes)
            if preprocessors is None else preprocessors)

        self._preprocessors = tree.map_structure_up_to(
            input_shapes, preprocessors_lib.deserialize, raw_preprocessors)

        self._name = name
示例#14
0
文件: filter.py 项目: ray-project/ray
    def __call__(self,
                 x: TensorStructType,
                 update: bool = True) -> TensorStructType:
        """Filter `x`, optionally updating the running statistics first.

        With `self.no_preprocessor` set, `x` is treated as a nested
        structure and each component is filtered with the matching entries
        of `self.rs`, `self.buffer` and `self.shape`. Otherwise `x` is
        filtered as a single array.

        Args:
            x: The (possibly nested) input to filter.
            update: Whether to push `x` into the running stats and the
                buffer before filtering.

        Returns:
            The filtered input. Components whose shape entry is None are
            returned unnormalized; all others keep their original dtype.
        """
        structured = self.no_preprocessor
        if structured:
            x = tree.map_structure(np.asarray, x)
        else:
            x = np.asarray(x)

        def _normalize(value, running_stat, stat_buffer, shape):
            # Discrete|MultiDiscrete spaces -> No normalization.
            if shape is None:
                return value

            # Keep dtype as is throughout this filter.
            dtype_in = value.dtype

            if update:
                vectorized = len(value.shape) == len(running_stat.shape) + 1
                if vectorized:
                    # One extra leading axis: push the entries one by one.
                    for i in range(value.shape[0]):
                        running_stat.push(value[i])
                        stat_buffer.push(value[i])
                else:
                    # A single, unbatched sample.
                    running_stat.push(value)
                    stat_buffer.push(value)

            if self.demean:
                value = value - running_stat.mean
            if self.destd:
                value = value / (running_stat.std + SMALL_NUMBER)
            if self.clip:
                value = np.clip(value, -self.clip, self.clip)
            return value.astype(dtype_in)

        if not structured:
            return _normalize(x, self.rs, self.buffer, self.shape)
        return tree.map_structure_up_to(x, _normalize, x, self.rs,
                                        self.buffer, self.shape)
示例#15
0
 def update_state_fn(state, state_update_tensors):
     """Call `update_state` on every encoder.

     Each encoder receives its matching entries of `state` and
     `state_update_tensors`; the results are arranged like `encoders`.
     """

     def _update_one(encoder, *update_args):
         return encoder.update_state(*update_args)

     return tree.map_structure_up_to(encoders, _update_one, encoders, state,
                                     state_update_tensors)
示例#16
0
 def decode_after_sum_fn(summed_values, decode_after_sum_params):
     """Call `decode_after_sum` on every encoder.

     Args:
       summed_values: A pair `(part_decoded_aggregated_x, num_summands)`.
       decode_after_sum_params: Per-encoder params for `decode_after_sum`.

     Returns:
       The per-encoder results, arranged like `encoders`. `num_summands` is
       passed to every encoder as the third argument.
     """
     part_decoded_aggregated_x, num_summands = summed_values

     def _finish_decode(encoder, aggregated, params):
         return encoder.decode_after_sum(aggregated, params, num_summands)

     return tree.map_structure_up_to(encoders, _finish_decode, encoders,
                                     part_decoded_aggregated_x,
                                     decode_after_sum_params)
示例#17
0
 def decode_before_sum_tf_function(encoded_x, decode_before_sum_params):
     """Call `decode_before_sum` on every encoder.

     Returns:
       A pair `(part_decoded_x, one)`, where `one` is an int32 constant
       tensor holding `(1, )`.
     """

     def _partial_decode(encoder, *decode_args):
         return encoder.decode_before_sum(*decode_args)

     part_decoded_x = tree.map_structure_up_to(encoders, _partial_decode,
                                               encoders, encoded_x,
                                               decode_before_sum_params)
     return part_decoded_x, tf.constant((1, ), tf.int32)
示例#18
0
 def decode(encoded_value):
     """Decode tf_computation.

     Calls `decode` on every encoder with its matching entry of
     `encoded_value` and returns the results arranged like `encoders`.
     """

     def _decode_one(encoder, encoded):
         return encoder.decode(encoded)

     return tree.map_structure_up_to(encoders, _decode_one, encoders,
                                     encoded_value)
示例#19
0
def split_tree(tuple_tree, base_tree, n):
  """Splits tuple_tree with n-tuple leaves into n trees.

  Args:
    tuple_tree: A tree arranged like `base_tree` whose leaves (at the depth
      of `base_tree`) are indexable with 0..n-1.
    base_tree: The structure that determines how deep to map.
    n: Number of trees to produce.

  Returns:
    A list of `n` trees; tree `i` holds element `i` of every leaf.
  """

  def take(position):
    return tree.map_structure_up_to(base_tree, lambda leaf: leaf[position],
                                    tuple_tree)

  return [take(position) for position in range(n)]
示例#20
0
def update(state: RunningStatisticsState,
           batch: types.NestedArray,
           *,
           config: NestStatisticsConfig = NestStatisticsConfig(),
           weights: Optional[jnp.ndarray] = None,
           std_min_value: float = 1e-6,
           std_max_value: float = 1e6,
           pmap_axis_name: Optional[str] = None,
           validate_shapes: bool = True) -> RunningStatisticsState:
    """Updates the running statistics with the given batch of data.

  Note: data batch and state elements (mean, etc.) must have the same structure.

  Note: by default will use int32 for counts and float32 for accumulated
  variance. This results in an integer overflow after 2^31 data points and
  degrading precision after 2^24 batch updates or even earlier if variance
  updates have large dynamic range.
  To improve precision, consider setting jax_enable_x64 to True, see
  https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision

  Arguments:
    state: The running statistics before the update.
    batch: The data to be used to update the running statistics.
    config: The config that specifies which leaves of the nested structure
      should the running statistics be computed for.
    weights: Weights of the batch data. Should match the batch dimensions.
      Passing a weight of 2. should be equivalent to updating on the
      corresponding data point twice.
    std_min_value: Minimum value for the standard deviation.
    std_max_value: Maximum value for the standard deviation.
    pmap_axis_name: Name of the pmapped axis, if any.
    validate_shapes: If true, the shapes of all leaves of the batch will be
      validated. Enabled by default. Doesn't impact performance when jitted.

  Returns:
    Updated running statistics.

  Raises:
    ValueError: If `validate_shapes` is true, `weights` is given and
      `weights.shape` differs from the inferred batch dimensions.
      `_validate_batch_shapes` may raise as well (not visible here).
  """
    # We require exactly the same structure to avoid issues when flattened
    # batch and state have different order of elements.
    tree.assert_same_structure(batch, state.mean)
    # The batch dimensions are inferred from the first flattened leaf only.
    batch_shape = tree.flatten(batch)[0].shape
    # We assume the batch dimensions always go first.
    batch_dims = batch_shape[:len(batch_shape) -
                             tree.flatten(state.mean)[0].ndim]
    # Axes to reduce over: all leading batch dimensions.
    batch_axis = range(len(batch_dims))
    if weights is None:
        # Unweighted: every entry along the batch dimensions counts once.
        step_increment = np.prod(batch_dims)
    else:
        step_increment = jnp.sum(weights)
    if pmap_axis_name is not None:
        # Sum the increment over the named pmapped axis.
        step_increment = jax.lax.psum(step_increment, axis_name=pmap_axis_name)
    count = state.count + step_increment

    # Validation is important. If the shapes don't match exactly, but are
    # compatible, arrays will be silently broadcasted resulting in incorrect
    # statistics.
    if validate_shapes:
        if weights is not None:
            if weights.shape != batch_dims:
                raise ValueError(f'{weights.shape} != {batch_dims}')
        _validate_batch_shapes(batch, state.mean, batch_dims)

    def _compute_node_statistics(
            path: Path, mean: jnp.ndarray, summed_variance: jnp.ndarray,
            batch: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
        # Computes the updated (mean, summed_variance) pair for one leaf.
        assert isinstance(mean, jnp.ndarray), type(mean)
        assert isinstance(summed_variance, jnp.ndarray), type(summed_variance)
        if not _is_path_included(config, path):
            # Return unchanged.
            return mean, summed_variance
        # The mean and the sum of past variances are updated with Welford's
        # algorithm using batches (see https://stackoverflow.com/q/56402955).
        diff_to_old_mean = batch - mean
        if weights is not None:
            # Append trailing singleton axes so that the weights broadcast
            # against the leaf.
            expanded_weights = jnp.reshape(
                weights,
                list(weights.shape) + [1] * (batch.ndim - weights.ndim))
            diff_to_old_mean = diff_to_old_mean * expanded_weights
        mean_update = jnp.sum(diff_to_old_mean, axis=batch_axis) / count
        if pmap_axis_name is not None:
            mean_update = jax.lax.psum(mean_update, axis_name=pmap_axis_name)
        mean = mean + mean_update

        # Uses the already updated mean together with the (possibly
        # weighted) difference to the old mean.
        diff_to_new_mean = batch - mean
        variance_update = diff_to_old_mean * diff_to_new_mean
        variance_update = jnp.sum(variance_update, axis=batch_axis)
        if pmap_axis_name is not None:
            variance_update = jax.lax.psum(variance_update,
                                           axis_name=pmap_axis_name)
        summed_variance = summed_variance + variance_update
        return mean, summed_variance

    updated_stats = tree_utils.fast_map_structure_with_path(
        _compute_node_statistics, state.mean, state.summed_variance, batch)
    # map_structure_up_to is slow, so shortcut if we know the input is not
    # structured.
    if isinstance(state.mean, jnp.ndarray):
        mean, summed_variance = updated_stats
    else:
        # Reshape the updated stats from `nest(mean, summed_variance)` to
        # `nest(mean), nest(summed_variance)`.
        mean, summed_variance = [
            # `i=idx` binds the index when the lambda is created.
            tree.map_structure_up_to(state.mean,
                                     lambda s, i=idx: s[i],
                                     updated_stats) for idx in range(2)
        ]

    def compute_std(path: Path, summed_variance: jnp.ndarray,
                    std: jnp.ndarray) -> jnp.ndarray:
        # Derives the clipped standard deviation for one leaf; leaves
        # excluded by `config` keep their previous std.
        assert isinstance(summed_variance, jnp.ndarray)
        if not _is_path_included(config, path):
            return std
        # Summed variance can get negative due to rounding errors.
        summed_variance = jnp.maximum(summed_variance, 0)
        std = jnp.sqrt(summed_variance / count)
        std = jnp.clip(std, std_min_value, std_max_value)
        return std

    std = tree_utils.fast_map_structure_with_path(compute_std, summed_variance,
                                                  state.std)

    return RunningStatisticsState(count=count,
                                  mean=mean,
                                  summed_variance=summed_variance,
                                  std=std)
示例#21
0
    def _map_to_queries(self, fn, *inputs, **kwargs):
        """Maps the method named `fn` over all subqueries.

        For each subquery in `self._queries`, the attribute `fn` is looked
        up and called with the matching entries of `inputs` plus `kwargs`.
        The results are arranged like `self._queries`.
        """

        def dispatch(subquery, *call_args):
            bound_method = getattr(subquery, fn)
            return bound_method(*call_args, **kwargs)

        return tree.map_structure_up_to(self._queries, dispatch,
                                        self._queries, *inputs)