示例#1
0
    def test_fast_map_structure_with_path(self):
        """fast_map_structure_with_path must agree with the dm-tree reference.

        Covers mapping over a single structure and over two structures in
        lockstep, on a nest that mixes dicts, tuples and lists.
        """
        nested = {
            'a': {'b': np.array([0.0])},
            'c': (np.array([1.0]), np.array([2.0])),
            'd': [np.array(3.0), np.array(4.0)],
        }

        def add_with_depth(path: Sequence[str], x: np.ndarray, y: np.ndarray):
            # Mixing len(path) into the result makes path handling observable.
            return x + y + len(path)

        unary_fn = functools.partial(add_with_depth, y=np.array([0.0]))

        # One structure.
        reference = tree.map_structure_with_path(unary_fn, nested)
        actual = tree_utils.fast_map_structure_with_path(unary_fn, nested)
        self.assertEqual(actual, reference)

        # Two structures mapped together (second one is the result above).
        reference_pair = tree.map_structure_with_path(
            add_with_depth, nested, actual)
        actual_pair = tree_utils.fast_map_structure_with_path(
            add_with_depth, nested, actual)
        self.assertEqual(actual_pair, reference_pair)
示例#2
0
 def _save_value_functions(self, checkpoint_dir):
     """Save the weights of every Q-function in ``self.Qs`` in TF format.

     Each Q-function is written to ``checkpoint_dir/Q-<path>``, where
     ``<path>`` is its position in ``self.Qs`` with parts joined by '-'.
     """
     def _save_one(path, q_function):
         filename = '-'.join(['Q'] + [str(part) for part in path])
         q_function.save_weights(
             os.path.join(checkpoint_dir, filename), save_format='tf')

     tree.map_structure_with_path(_save_one, self.Qs)
示例#3
0
    def add_samples(self, samples):
        """Write a batch of samples into the buffer, wrapping at capacity.

        Args:
            samples: Mapping with the same nested layout as ``self.data`` and
                ``self.fields``. Every leaf must have ``num_samples`` rows
                along its first axis. ``'episode_index_forwards'`` and
                ``'episode_index_backwards'`` must both be present or both be
                absent. When absent, they are filled with the fields' default
                values. The caller's mapping is not modified.
        """
        import copy

        num_samples = tree.flatten(samples)[0].shape[0]

        # The two episode-index fields must be supplied together or not at all.
        assert (('episode_index_forwards' in samples.keys())
                is ('episode_index_backwards' in samples.keys()))
        if 'episode_index_forwards' not in samples.keys():
            # Fill defaults on a shallow copy so the caller's dict does not
            # silently gain these keys as a side effect of this call.
            samples = copy.copy(samples)
            for key in ('episode_index_forwards', 'episode_index_backwards'):
                field = self.fields[key]
                samples[key] = np.full(
                    (num_samples, *field.shape),
                    field.default_value,
                    dtype=field.dtype)

        # Destination rows start at the write pointer and wrap modulo the
        # buffer capacity.
        index = np.arange(
            self._pointer, self._pointer + num_samples) % self._max_size

        def add_sample(path, data, new_values, field):
            # `path` and `field` are unused. The callback receives one leaf
            # from each of the three structures mapped below.
            assert new_values.shape[0] == num_samples, (
                new_values.shape, num_samples)
            data[index] = new_values

        tree.map_structure_with_path(
            add_sample, self.data, samples, self.fields)

        self._advance(num_samples)
    def add_samples(self, samples):
        """Append a batch of samples to the ring buffer.

        Args:
            samples: Mapping laid out like ``self.data`` / ``self.fields``.
                Each leaf holds one row per sample along axis 0. The two
                episode-index fields must be provided together or not at all.
                If missing, they are filled with their default values. The
                caller's mapping is not modified.
        """
        import copy

        # samples: presumably a dict of 7 fields with 1000 rows each in the
        # author's setup (original note "dict(7(1000))"); not enforced here.
        num_samples = tree.flatten(samples)[0].shape[0]

        assert (('episode_index_forwards' in samples.keys()) is
                ('episode_index_backwards' in samples.keys()))
        if 'episode_index_forwards' not in samples.keys():
            # Work on a shallow copy so the caller's dict is not mutated.
            samples = copy.copy(samples)
            samples['episode_index_forwards'] = np.full(
                (num_samples, *self.fields['episode_index_forwards'].shape),
                self.fields['episode_index_forwards'].default_value,
                dtype=self.fields['episode_index_forwards'].dtype)
            samples['episode_index_backwards'] = np.full(
                (num_samples, *self.fields['episode_index_backwards'].shape),
                self.fields['episode_index_backwards'].default_value,
                dtype=self.fields['episode_index_backwards'].dtype)

        # Indices wrap modulo the maximum size, so writes never go past the
        # buffer's capacity; rows at wrapped indices are overwritten.
        index = np.arange(self._pointer,
                          self._pointer + num_samples) % self._max_size

        def add_sample(path, data, new_values, field):
            assert new_values.shape[0] == num_samples, (new_values.shape,
                                                        num_samples)
            data[index] = new_values

        # In map_structure_with_path, dicts are internal structure, not
        # leaves. All structures must have the same keys so that matching
        # leaves can be processed together. The `path` passed to the callback
        # is the tuple of keys leading to the leaf (one key per nesting
        # level). Here, each leaf of self.data is written with the matching
        # leaf of samples.
        tree.map_structure_with_path(add_sample, self.data, samples,
                                     self.fields)

        self._advance(num_samples)
示例#5
0
    def compress(
        self,
        bulk: bool = False,
        columns: Set[str] = frozenset(["obs", "new_obs"])
    ) -> "SampleBatch":
        """Compresses the data buffers (by column) in place.

        Every leaf whose top-level key (first path element) is in `columns`
        is replaced, inside its parent container, by its packed form. Leaves
        under other keys are left untouched.

        Args:
            bulk: Whether to compress across the batch dimension (0)
                as well. If False will compress n separate list items, where n
                is the batch size.
            columns: The columns to compress. Default: Only
                compress the obs and new_obs columns.

        Returns:
            This very (now compressed) SampleBatch.
        """
        def _compress_in_place(path, value):
            # Only touch leaves that live under one of the selected columns.
            if path[0] not in columns:
                return
            # Walk from `self` down the leaf's path. At the last path element,
            # overwrite the leaf inside its parent container.
            curr = self
            for i, p in enumerate(path):
                if i == len(path) - 1:
                    if bulk:
                        # One packed blob for the whole buffer.
                        curr[p] = pack(value)
                    else:
                        # Pack each element of `value` separately.
                        curr[p] = np.array([pack(o) for o in value])
                # NOTE: on the last iteration this re-reads the value just
                # written; the result is unused.
                curr = curr[p]

        # Called for its side effects only; the mapped structure is discarded.
        tree.map_structure_with_path(_compress_in_place, self)

        return self
示例#6
0
    def decompress_if_needed(
        self, columns: Set[str] = frozenset(["obs",
                                             "new_obs"])) -> "SampleBatch":
        """Decompress the selected column buffers in place, where compressed.

        A leaf that is itself compressed (bulk compression) is unpacked as a
        whole. Otherwise, if the leaf is non-empty and its first element is
        compressed, every element is unpacked and the results are collected
        into a numpy array. Leaves outside `columns`, or not compressed, are
        left untouched.

        Args:
            columns: Top-level keys whose buffers should be decompressed.
                Defaults to the obs and new_obs columns.

        Returns:
            This very (now uncompressed) SampleBatch.
        """
        def _unpack_leaf(path, value):
            if path[0] in columns:
                parent = self
                for key in path[:-1]:
                    parent = parent[key]
                leaf_key = path[-1]
                if is_compressed(value):
                    parent[leaf_key] = unpack(value)
                elif len(value) > 0 and is_compressed(value[0]):
                    parent[leaf_key] = np.array(
                        [unpack(item) for item in value])

        tree.map_structure_with_path(_unpack_leaf, self)
        return self
示例#7
0
文件: asserts.py 项目: graingert/chex
def assert_tree_shape_prefix(tree: ArrayTree,
                             shape_prefix: Sequence[int],
                             *,
                             ignore_nones: bool = False):
    """Asserts all tree leaves' shapes have the same prefix.

    Args:
        tree: a tree to assert.
        shape_prefix: an expected shapes' prefix. Any sequence of ints is
            accepted (e.g. a list or a tuple).
        ignore_nones: whether to ignore `None`s in the tree.

    Raises:
        AssertionError: if some leaf's shape doesn't start with the expected
            prefix; if `ignore_nones` isn't set and the tree contains `None`s.
    """
    if not ignore_nones:
        assert_tree_no_nones(tree)

    # Array shapes (e.g. numpy/JAX) are tuples, and a tuple never compares
    # equal to a list. Normalise both sides so a list prefix, allowed by the
    # `Sequence[int]` annotation, is not always reported as a mismatch.
    expected = tuple(shape_prefix)

    def _assert_fn(path, leaf):
        if leaf is None:
            return

        prefix = tuple(leaf.shape[:len(expected)])
        if prefix != expected:
            raise AssertionError(
                f"Tree leaf '{_format_tree_path(path)}' has a shape prefix "
                f"different from expected: {prefix} != {expected}.")

    # The `tree` parameter shadows the dm-tree module name, hence `dm_tree`.
    dm_tree.map_structure_with_path(_assert_fn, tree)
示例#8
0
def log_stats(ex, stats, step=None, sep='.'):
  """Log every leaf of `stats` as a scalar on `ex`.

  A leaf's key is its path in `stats`, parts converted to str and joined by
  `sep`. `tf.Tensor` leaves are converted with `.numpy()`, and numpy arrays
  are reduced to their mean before being logged at `step`.
  """
  def _log_leaf(path, leaf):
    scalar = leaf.numpy() if isinstance(leaf, tf.Tensor) else leaf
    if isinstance(scalar, np.ndarray):
      scalar = scalar.mean()
    ex.log_scalar(sep.join(str(part) for part in path), scalar, step=step)

  tree.map_structure_with_path(_log_leaf, stats)
示例#9
0
def _merge_adresses_in(synthesis, not_criterias, to_merge):
    """Merge the leaves of `to_merge` into `synthesis`, grouped by path.

    `synthesis` is a list of `(path, values)` entries updated in place. Each
    leaf of `to_merge` whose path (as a list) is not in `not_criterias` is
    handled as follows:
    - if an entry with the same path exists, the leaf value is appended to
      that entry's values unless already present;
    - otherwise a new `(path, [value])` entry is appended.
    """
    def _record(path, value):
        if list(path) in not_criterias:
            return
        known_paths = tuple(entry[0] for entry in synthesis)
        if path in known_paths:
            merged_values = synthesis[known_paths.index(path)][1]
            if value not in merged_values:
                merged_values.append(value)
        else:
            synthesis.append((path, [value]))

    tree.map_structure_with_path(_record, to_merge)
示例#10
0
    def add_hyperparam_group(self, group, suffix, defaults):
        """Adds new hyperparameters to the hyperparams dict.

        Every key of ``group`` other than ``'params'`` becomes a
        hyperparameter named ``'<key>_<suffix>'`` holding ``group[key]``.
        Keys of ``defaults`` that ``group`` does not override keep their own
        name. Every leaf of ``group['params']`` is then mapped, by its
        '/'-joined path, to this shared name mapping.
        """
        # Use default hyperparams unless overridden by group hyperparams.
        group_dict = {key: key for key in defaults if key not in group}
        for key in group:
            if key != 'params':  # Reserved keyword 'params'
                group_dict[key] = '%s_%s' % (key, suffix)
                self._hyperparameters[group_dict[key]] = group[key]

        # Set up params2hyperparams. Path parts are stringified so params
        # nested in lists/tuples (integer path parts) can be joined too.
        def set_p2h(path, _):
            self._params2hyperparams['/'.join(map(str, path))] = group_dict

        tree.map_structure_with_path(set_p2h, group['params'])
示例#11
0
    def create_param_groups(self, params, defaults):
        """Creates param-hyperparam mappings.

        If ``params`` is a list, each element is treated as a param group and
        registered via ``add_hyperparam_group`` with its index as suffix.
        Otherwise every leaf of ``params`` is mapped, by its '/'-joined path,
        to an identity mapping over the current hyperparameter names.
        """
        if isinstance(params, list):
            for group_index, group in enumerate(params):
                # Add group to hyperparams and get this group's full
                # hyperparameters.
                self.add_hyperparam_group(group, group_index, defaults)
        else:
            mapping = {key: key for key in self._hyperparameters}

            # Path parts are stringified so params nested in lists/tuples
            # (integer path parts) can be joined too.
            def set_p2h(path, _):
                self._params2hyperparams['/'.join(map(str, path))] = mapping

            tree.map_structure_with_path(set_p2h, params)
示例#12
0
    def assert_dtype(self, test_dtype, module_fn: ModuleFn, shape,
                     input_dtype):
        """Checks that modules accepting float32 input_dtype output test_dtype.

        The module is built under a custom parameter creator that:
        - checks each parameter's requested dtype equals ``test_dtype``
          (except the vector-quantizer embeddings);
        - initialises every parameter with ones.

        Floating state is cast to ``test_dtype`` after init. The module is
        then shape-evaluated twice, threading state through. Every floating
        leaf of the output and of the state must have ``test_dtype``.

        Args:
            test_dtype: dtype expected for parameters, outputs and state.
            module_fn: zero-argument callable that constructs the module.
            shape: shape of the input passed to the module.
            input_dtype: the test is skipped unless this is ``jnp.float32``.
        """

        if input_dtype != jnp.float32:
            self.skipTest('Skipping module with non-f32 input')

        # Custom creator. Its `shape` argument is the parameter shape and
        # shadows the method's input `shape` inside this function.
        def ones_creator(next_creator, shape, dtype, init, context):
            if context.full_name == 'vector_quantizer/embeddings':
                # NOTE: vector_quantizer/embeddings is created using a ctor argument
                # so dtype is not expected to follow input to __call__.
                dtype = test_dtype
            else:
                self.assertEqual(dtype, test_dtype, msg=context.full_name)

            # NOTE: We need to do this since some initializers (e.g. random.uniform)
            # do not support <32bit dtypes. This also makes the test run a bit faster.
            init = jnp.ones
            return next_creator(shape, dtype, init)

        def g(x):
            with hk.custom_creator(ones_creator):
                mod = module_fn()
                return mod(x)

        # Rebind `g` to its transformed (init/apply) pair.
        g = hk.transform_with_state(g)

        # No custom creator for state so we need to do this manually.
        def cast_if_floating(x):
            if jnp.issubdtype(x.dtype, jnp.floating):
                x = x.astype(test_dtype)
            return x

        def init_fn(rng, x):
            params, state = g.init(rng, x)
            state = jax.tree_map(cast_if_floating, state)
            return params, state

        x = np.ones(shape, test_dtype)
        rng = jax.random.PRNGKey(42)
        # eval_shape only computes output shapes/dtypes; nothing is executed.
        params, state = jax.eval_shape(init_fn, rng, x)

        # Apply twice so the state produced by the first call is checked too.
        for _ in range(2):
            y, state = jax.eval_shape(g.apply, params, state, rng, x)

            # Shadows the method name locally; only floating leaves are checked.
            def assert_dtype(path, v):
                if jnp.issubdtype(v.dtype, jnp.floating):
                    self.assertEqual(v.dtype, test_dtype, msg=path)

            tree.map_structure_with_path(assert_dtype, y)
            tree.map_structure_with_path(assert_dtype, state)
示例#13
0
            def call(*args, **kwargs):
                """Run the wrapped `fn` through a TF1 session with feeds.

                On the first call, placeholders mirroring the (nested) args
                and kwargs are created, and `fn` is called once on them to
                build `symbolic_out[0]`. Later calls reuse those placeholders
                and only feed new values.

                Positional args whose exact type is `list` are spliced into
                the positional arg list first.

                `symbolic_out`, `session_or_none`, `dynamic_shape`,
                `args_placeholders`, `kwargs_placeholders` and `fn` come from
                the enclosing scope (not visible here).
                """
                # Flatten top-level list args (exact `list` only, not
                # subclasses or tuples) into one positional list.
                args_flat = []
                for a in args:
                    if type(a) is list:
                        args_flat.extend(a)
                    else:
                        args_flat.append(a)
                args = args_flat

                # We have not built any placeholders yet: Do this once here,
                # then reuse the same placeholders each time we call this
                # function again.
                if symbolic_out[0] is None:
                    with session_or_none.graph.as_default():

                        # One placeholder per leaf, named by its '.'-joined
                        # path. With `dynamic_shape`, the leading dim of
                        # non-scalar leaves becomes None.
                        def _create_placeholders(path, value):
                            if dynamic_shape:
                                if len(value.shape) > 0:
                                    shape = (None,) + value.shape[1:]
                                else:
                                    shape = ()
                            else:
                                shape = value.shape
                            return tf1.placeholder(
                                dtype=value.dtype,
                                shape=shape,
                                name=".".join([str(p) for p in path]),
                            )

                        placeholders = tree.map_structure_with_path(
                            _create_placeholders, args
                        )
                        # Positional placeholders are stored flat, in
                        # tree.flatten order (matching the feeds below).
                        for ph in tree.flatten(placeholders):
                            args_placeholders.append(ph)

                        placeholders = tree.map_structure_with_path(
                            _create_placeholders, kwargs
                        )
                        # Keyword placeholders keep their (possibly nested)
                        # structure under each top-level key.
                        for k, ph in placeholders.items():
                            kwargs_placeholders[k] = ph

                        symbolic_out[0] = fn(*args_placeholders, **kwargs_placeholders)
                # Feed positional values in the same flatten order as the
                # placeholders were stored.
                feed_dict = dict(zip(args_placeholders, tree.flatten(args)))
                # Pair each kwargs placeholder leaf with its value leaf.
                tree.map_structure(
                    lambda ph, v: feed_dict.__setitem__(ph, v),
                    kwargs_placeholders,
                    kwargs,
                )
                ret = session_or_none.run(symbolic_out[0], feed_dict)
                return ret
示例#14
0
def assert_tree_all_equal_comparator(equality_comparator: TLeavesEqCmpFn,
                                     error_msg_fn: TLeavesEqCmpErrorFn,
                                     *trees: Sequence[ArrayTree]):
  """Assert all trees are equal using custom comparator for leaves.

  Fewer than two trees pass trivially. Otherwise the trees must first share
  the same structure. Then, for each leaf path, the leaves are checked by
  `_assert_leaves_all_eq_comparator` (defined elsewhere) using
  `equality_comparator`. Failure messages wrap `error_msg_fn`'s text with the
  tree indices and the leaf path.
  """
  if len(trees) >= 2:
    assert_tree_all_equal_structs(*trees)

    def _tree_error_msg_fn(l_1: TLeaf, l_2: TLeaf, path: str, i_1: int,
                           i_2: int):
      leaf_msg = error_msg_fn(l_1, l_2)
      return f"Trees {i_1} and {i_2} differ in leaves '{path}': {leaf_msg}."

    leaves_checker = functools.partial(_assert_leaves_all_eq_comparator,
                                       equality_comparator, _tree_error_msg_fn)
    tree.map_structure_with_path(leaves_checker, *trees)
示例#15
0
    def test_bias_parameter_shape(self):
        """The network must have exactly two parameters: a bias and a weight.

        The bias (leaf named 'b') must have shape (1,) and differ from the
        output shape. The weight (leaf named 'w') must have
        ``self.weights_shape``. Any other parameter fails the test.
        """
        params = self.network.init(self.init_rng_key,
                                   _sample_input(self.input_shape))
        self.assertLen(tree.flatten(params), 2)

        def check_params(path, param):
            if path[-1] == 'b':
                self.assertNotEqual(self.output_shape, param.shape)
                chex.assert_shape(param, (1, ))
            elif path[-1] == 'w':
                chex.assert_shape(param, self.weights_shape)
            else:
                # `path` is a tuple. Wrap it so %-formatting treats it as a
                # single value rather than as the argument tuple, which would
                # raise TypeError instead of reaching `self.fail`.
                self.fail('Unexpected parameter %s.' % (path,))

        tree.map_structure_with_path(check_params, params)
示例#16
0
文件: asserts.py 项目: graingert/chex
def assert_tree_no_nones(tree: ArrayTree):
    """Asserts tree does not contain `None`s.

    Args:
        tree: a tree to assert.

    Raises:
        AssertionError: if the tree contains `None`s. The message names the
            path of the offending leaf.
    """
    def _reject_none(path, leaf):
        if leaf is not None:
            return
        location = _format_tree_path(path)
        raise AssertionError(f"`None` detected at '{location}'.")

    # The `tree` parameter shadows the dm-tree module name, hence `dm_tree`.
    dm_tree.map_structure_with_path(_reject_none, tree)
示例#17
0
文件: tf_ops.py 项目: holdenk/ray
def get_placeholder(*,
                    space=None,
                    value=None,
                    name=None,
                    time_axis=False,
                    flatten=True):
    """Returns a tf1.placeholder given a space or a sample value as hint.

    The placeholder always has a leading batch dimension (None), plus a time
    dimension (None) if `time_axis` is set. float64 dtypes become float32.

    Args:
        space: Optional gym space giving shape and dtype. Dict/Tuple spaces
            are flattened into one action placeholder if `flatten`. Otherwise
            a nested structure of per-component placeholders is returned.
        value: Used when `space` is None. Its first axis is replaced by the
            batch dimension.
        name: Optional placeholder name. Components of an unflattened
            Dict/Tuple space are named "<name>.<path>" ("<path>" if None).
        time_axis: Whether to add a time dimension after the batch dimension.
        flatten: See `space`.
    """
    from ray.rllib.models.catalog import ModelCatalog

    if space is not None:
        if isinstance(space, (gym.spaces.Dict, gym.spaces.Tuple)):
            if flatten:
                return ModelCatalog.get_action_placeholder(space, None)
            else:
                # `name + "."` would raise TypeError for name=None; fall back
                # to naming components by their path alone.
                prefix = "" if name is None else name + "."
                # NOTE(review): components do not receive `time_axis`;
                # confirm this is intended.
                return tree.map_structure_with_path(
                    lambda path, component: get_placeholder(
                        space=component,
                        name=prefix + ".".join([str(p) for p in path]),
                    ),
                    get_base_struct_from_space(space),
                )
        return tf1.placeholder(
            shape=(None, ) + ((None, ) if time_axis else ()) + space.shape,
            dtype=tf.float32 if space.dtype == np.float64 else space.dtype,
            name=name,
        )
    else:
        assert value is not None
        shape = value.shape[1:]
        return tf1.placeholder(
            shape=(None, ) + ((None, ) if time_axis else ()) +
            (shape if isinstance(shape, tuple) else tuple(shape.as_list())),
            dtype=tf.float32 if value.dtype == np.float64 else value.dtype,
            name=name,
        )
示例#18
0
    def signature(cls,
                  environment_spec: specs.EnvironmentSpec,
                  extras_spec: types.NestedSpec = ()):
        """This is a helper method for generating signatures for Reverb tables.

        Signatures are useful for validating data types and shapes, see
        Reverb's documentation for details on how they are used.

        Args:
          environment_spec: A `specs.EnvironmentSpec` whose fields are nested
            structures with leaf nodes that have `.shape` and `.dtype`
            attributes. This should come from the environment that will be
            used to generate the data inserted into the Reverb table.
          extras_spec: A nested structure with leaf nodes that have `.shape`
            and `.dtype` attributes. The structure (and shapes/dtypes) of this
            must be the same as the `extras` passed into `ReverbAdder.add`.

        Returns:
          A `Step` whose leaf nodes are `tf.TensorSpec` objects.
        """
        step_fields = {
            'observation': environment_spec.observations,
            'action': environment_spec.actions,
            'reward': environment_spec.rewards,
            'discount': environment_spec.discounts,
            'extras': extras_spec,
        }
        return tree.map_structure_with_path(array_spec_to_tensor_spec,
                                            Step(**step_fields))
示例#19
0
    def signature(cls,
                  environment_spec: specs.EnvironmentSpec,
                  extras_spec: types.NestedSpec = ()):
        """Build the Reverb signature for a `types.Transition`.

        Rewards are broadcast against discounts (via `_broadcast_specs`).
        Discount specs are deep-copied. The transition's leaves are converted
        with `base.spec_like_to_tensor_spec`, using each leaf's path.
        """
        # This function currently assumes that self._discount is a scalar.
        # If it ever becomes a nested structure and/or a np.ndarray, this
        # method will need to know its structure / shape, because the
        # signature discount shape is the environment's discount shape and
        # this adder's discount shape broadcast together, and the reward
        # shape is that discount shape broadcast with the environment reward
        # shape. While self._discount is a scalar it affects neither, so we
        # can ignore it.
        broadcast_rewards, broadcast_discounts = (
            tree_utils.broadcast_structures(environment_spec.rewards,
                                            environment_spec.discounts))
        reward_spec = tree.map_structure(_broadcast_specs, broadcast_rewards,
                                         broadcast_discounts)
        discount_spec = tree.map_structure(copy.deepcopy,
                                           broadcast_discounts)

        transition_spec = types.Transition(
            environment_spec.observations,
            environment_spec.actions,
            reward_spec,
            discount_spec,
            environment_spec.observations,  # next_observation
            extras_spec,
        )
        return tree.map_structure_with_path(base.spec_like_to_tensor_spec,
                                            transition_spec)
示例#20
0
    def add_learn_on_batch_results(
        self,
        results: Dict,
        policy_id: PolicyID = DEFAULT_POLICY_ID,
    ) -> None:
        """Adds a policy.learn_on_(loaded)?_batch() result to this builder.

        Single-device results are stored as-is. Multi-device results have
        their "tower_<i>" entries popped from `results` (mutating it) and
        reduced leaf by leaf with `all_tower_reduce`. The remaining top-level
        entries are then copied into the reduced result; entries under
        LEARNER_STATS_KEY are merged into its learner-stats dict.

        Args:
            results: The results returned by Policy.learn_on_batch or
                Policy.learn_on_loaded_batch.
            policy_id: The policy's ID, whose learn_on_(loaded)_batch method
                returned `results`.
        """
        assert (not self.is_finalized
                ), "LearnerInfo already finalized! Cannot add more results."

        history = self.results_all_towers[policy_id]

        if "tower_0" not in results:
            # No towers: Single CPU.
            history.append(results)
            return

        # Multi-GPU case.
        tower_results = [
            results.pop("tower_{}".format(tower_num))
            for tower_num in range(self.num_devices)
        ]
        history.append(
            tree.map_structure_with_path(all_tower_reduce, *tower_results))
        reduced = history[-1]
        for key, val in results.items():
            if key == LEARNER_STATS_KEY:
                for stat_key, stat_val in val.items():
                    reduced[LEARNER_STATS_KEY][stat_key] = stat_val
            else:
                reduced[key] = val
示例#21
0
    def rows(self) -> Iterator[Dict[str, TensorType]]:
        """Returns an iterator over data rows, i.e. dicts with column values.

        Note that if `seq_lens` is set in self, we set it to [1] in the rows
        (the same `np.array([1])` object is shared by every yielded row).

        Yields:
            Dict[str, TensorType]: The column values of the row in this
                iteration.

        Examples:
            >>> batch = SampleBatch({
            ...    "a": [1, 2, 3],
            ...    "b": [4, 5, 6],
            ...    "seq_lens": [1, 2]
            ... })
            >>> for row in batch.rows():
            ...    print(row)
            {"a": 1, "b": 4, "seq_lens": [1]}
            {"a": 2, "b": 5, "seq_lens": [1]}
            {"a": 3, "b": 6, "seq_lens": [1]}
        """
        seq_lens = (None if self.get(SampleBatch.SEQ_LENS) is None
                    else np.array([1]))
        columns = dict(self.items())

        def _row(row_idx):
            return tree.map_structure_with_path(
                lambda path, column: (
                    seq_lens if path[0] == self.SEQ_LENS
                    else column[row_idx]),
                columns,
            )

        for row_idx in range(self.count):
            yield _row(row_idx)
示例#22
0
    def rows(self) -> Iterator[Dict[str, TensorType]]:
        """Returns an iterator over data rows, i.e. dicts with column values.

        Note that if `seq_lens` is set in self, we set it to 1 in the rows.

        Yields:
            The column values of the row in this iteration.

        Examples:
            >>> from ray.rllib.policy.sample_batch import SampleBatch
            >>> batch = SampleBatch({ # doctest: +SKIP
            ...    "a": [1, 2, 3],
            ...    "b": [4, 5, 6],
            ...    "seq_lens": [1, 2]
            ... })
            >>> for row in batch.rows(): # doctest: +SKIP
            ...    print(row) # doctest: +SKIP
            {"a": 1, "b": 4, "seq_lens": 1}
            {"a": 2, "b": 5, "seq_lens": 1}
            {"a": 3, "b": 6, "seq_lens": 1}
        """
        if self.get(SampleBatch.SEQ_LENS, 1) is None:
            seq_lens = None
        else:
            seq_lens = 1

        columns = {key: column for key, column in self.items()}

        for row_idx in range(self.count):

            def _cell(path, column):
                if path[0] == self.SEQ_LENS:
                    return seq_lens
                return column[row_idx]

            yield tree.map_structure_with_path(_cell, columns)
示例#23
0
    def _slice(self, slice_: slice):
        """Helper method to handle SampleBatch slicing using a slice object.

        The returned SampleBatch uses the same underlying data object as
        `self`, so changing the slice will also change `self`.

        Note that only zero or positive bounds are allowed for both start
        and stop values. The slice step must be 1 (or None, which is the
        same). A stop of 0 yields an empty batch; only a missing (None) stop
        means "until the end".

        If `seq_lens` are present, the bounds are translated through a cached
        map from each row to (index of its sequence, first row of that
        sequence). `seq_lens` and `state_in_*` columns are sliced by sequence
        index. All other columns are sliced by row, or by
        `sequence index * max_seq_len` when zero-padded.

        Args:
            slice_ (slice): The python slice object to slice by.

        Returns:
            SampleBatch: A new SampleBatch, however "linking" into the same
                data (sliced) as self.
        """
        start = slice_.start or 0
        # `slice_.stop or len(self)` would turn an explicit stop of 0
        # (e.g. `batch[0:0]`) into "the whole batch".
        stop = len(self) if slice_.stop is None else slice_.stop
        assert start >= 0 and stop >= 0 and slice_.step in [1, None]

        if (self.get(SampleBatch.SEQ_LENS) is not None
                and len(self[SampleBatch.SEQ_LENS]) > 0):
            # Build our slice-map, if not done already.
            if not self._slice_map:
                sum_ = 0
                for i, l in enumerate(self[SampleBatch.SEQ_LENS]):
                    for _ in range(l):
                        self._slice_map.append((i, sum_))
                    sum_ += l
                # Sentinel entry so `stop == number of rows` is valid.
                self._slice_map.append((len(self[SampleBatch.SEQ_LENS]), sum_))

            start_seq_len, start = self._slice_map[start]
            stop_seq_len, stop = self._slice_map[stop]
            if self.zero_padded:
                start = start_seq_len * self.max_seq_len
                stop = stop_seq_len * self.max_seq_len

            def map_(path, value):
                if path[0] != SampleBatch.SEQ_LENS and not path[0].startswith(
                        "state_in_"):
                    return value[start:stop]
                else:
                    return value[start_seq_len:stop_seq_len]

            data = tree.map_structure_with_path(map_, self)
            return SampleBatch(
                data,
                _is_training=self.is_training,
                _time_major=self.time_major,
                _zero_padded=self.zero_padded,
                _max_seq_len=self.max_seq_len if self.zero_padded else None,
            )
        else:
            data = tree.map_structure(lambda value: value[start:stop], self)
            return SampleBatch(
                data,
                _is_training=self.is_training,
                _time_major=self.time_major,
            )
示例#24
0
    def assert_dtype(
        self,
        test_dtype: DType,
        module_fn: descriptors.ModuleFn,
        shape: Shape,
        input_dtype: DType,
    ):
        """Checks that modules accepting float32 input_dtype output test_dtype.

        Skipped unless running on TPU and unless ``input_dtype`` is float32.
        Steps:
        - infer parameter/state shapes in float32;
        - re-create them as ones, with float32 leaves cast to ``test_dtype``;
        - apply the module twice, threading the state through.
        Every non-int32 leaf of the output and state must have ``test_dtype``.
        """
        if jax.local_devices()[0].platform != 'tpu':
            self.skipTest('bfloat16 only supported on TPU')
        if input_dtype != jnp.float32:
            self.skipTest('Skipping module without float32 input')

        rng = jax.random.PRNGKey(42)

        def forward(x):
            return module_fn()(x)

        init_fn, apply_fn = hk.transform_with_state(forward)

        # Infer shapes in f32 first: some initializers (e.g. random.uniform)
        # do not support <32bit dtypes.
        x = jax.random.uniform(rng, shape)
        params, state = jax.eval_shape(init_fn, rng, x)

        def ones_like_spec(spec):
            target = test_dtype if spec.dtype == jnp.float32 else spec.dtype
            return jnp.ones(spec.shape, target)

        params, state = jax.tree_map(ones_like_spec, (params, state))

        # test_dtype in should result in test_dtype out.
        x = x.astype(test_dtype)

        def check_leaf(path, leaf):
            if leaf.dtype != jnp.int32:
                self.assertEqual(leaf.dtype, test_dtype, msg=path)

        for _ in range(2):
            out, state = jax.eval_shape(apply_fn, params, state, rng, x)
            tree.map_structure_with_path(check_leaf, out)
            tree.map_structure_with_path(check_leaf, state)
示例#25
0
def field_from_gym_space(name, space):
    """Convert a gym space into a `Field`, or a nested structure of Fields.

    A `Box` becomes a single `Field`. A list/tuple `name` (such as a path
    from `tree.map_structure_with_path`) is joined with '/'. A `Dict` space
    is converted recursively over `space.spaces`, each sub-space named by its
    path there.

    Raises:
        NotImplementedError: for any other space type.
    """
    if isinstance(space, spaces.Dict):
        # NOTE(review): `name` is ignored here, so sub-space names do not
        # carry an outer prefix; confirm this is intended for nested Dicts.
        return tree.map_structure_with_path(field_from_gym_space, space.spaces)
    if not isinstance(space, spaces.Box):
        raise NotImplementedError(space)
    if isinstance(name, (list, tuple)):
        name = '/'.join(name)
    return Field(name=name, dtype=space.dtype, shape=space.shape)
示例#26
0
def get_placeholder(
    *,
    space: Optional[gym.Space] = None,
    value: Optional[Any] = None,
    name: Optional[str] = None,
    time_axis: bool = False,
    flatten: bool = True
) -> "tf1.placeholder":
    """Returns a tf1.placeholder object given optional hints, such as a space.

    Note that the returned placeholder will always have a leading batch
    dimension (None). float64 dtypes are mapped to float32.

    Args:
        space: An optional gym.Space to hint the shape and dtype of the
            placeholder.
        value: An optional value to hint the shape and dtype of the
            placeholder (used when `space` is None). Its first axis is
            replaced by the batch dimension.
        name: An optional name for the placeholder. Components of an
            unflattened Dict/Tuple space are named "<name>.<path>", or just
            "<path>" when `name` is None.
        time_axis: Whether the placeholder should also receive a time
            dimension (None).
        flatten: Whether to flatten the given space into a plain Box space
            and then create the placeholder from the resulting space.

    Returns:
        The tf1 placeholder, or (for an unflattened Dict/Tuple space) a
        nested structure of placeholders matching the space.
    """
    from ray.rllib.models.catalog import ModelCatalog

    if space is not None:
        if isinstance(space, (gym.spaces.Dict, gym.spaces.Tuple)):
            if flatten:
                return ModelCatalog.get_action_placeholder(space, None)
            else:
                # `name + "."` would raise TypeError for name=None; fall back
                # to naming components by their path alone.
                prefix = "" if name is None else name + "."
                # NOTE(review): components do not receive `time_axis`;
                # confirm this is intended.
                return tree.map_structure_with_path(
                    lambda path, component: get_placeholder(
                        space=component,
                        name=prefix + ".".join([str(p) for p in path]),
                    ),
                    get_base_struct_from_space(space),
                )
        return tf1.placeholder(
            shape=(None,) + ((None,) if time_axis else ()) + space.shape,
            dtype=tf.float32 if space.dtype == np.float64 else space.dtype,
            name=name,
        )
    else:
        assert value is not None
        shape = value.shape[1:]
        return tf1.placeholder(
            shape=(None,)
            + ((None,) if time_axis else ())
            + (shape if isinstance(shape, tuple) else tuple(shape.as_list())),
            dtype=tf.float32 if value.dtype == np.float64 else value.dtype,
            name=name,
        )
示例#27
0
  def observation_spec(self):
    """Returns an `ArraySpec` per leaf of the current observation.

    Each spec copies the leaf's shape and dtype. Its name is the leaf's
    path, parts joined with '_', followed by '_spec'.
    """
    # Note: this function will be called before reset call is issued and
    # observation might not be ready.
    # But for this specific environment it works.
    obs = self._observation()

    def mk_spec(path_tuple, np_arr):
      # Path parts can be ints (positions in lists/tuples), so stringify them
      # before joining.
      name = '_'.join(str(part) for part in path_tuple) + '_spec'
      return ArraySpec(np_arr.shape, np_arr.dtype, name=name)

    return nest.map_structure_with_path(mk_spec, obs)
示例#28
0
    def testMapWithPathCompatibleStructures(self, s1, s2, check_types,
                                            expected):
        """Each leaf maps to (path, sum of the matching leaves of s1 and s2)."""
        def path_and_sum(path, *leaves):
            return path, sum(leaves)

        self.assertEqual(
            expected,
            tree.map_structure_with_path(
                path_and_sum, s1, s2, check_types=check_types))
示例#29
0
    def finalize(self):
        """Mark this builder as finalized and reduce the collected results.

        Returns:
            Dict mapping each policy ID to the leaf-wise reduction, via
            ``all_tower_reduce``, of all results collected for that policy.
        """
        self.is_finalized = True
        # Reduce mean across all minibatch SGD steps (axis=0 to keep
        # all shapes as-is).
        return {
            policy_id: tree.map_structure_with_path(all_tower_reduce,
                                                    *collected)
            for policy_id, collected in self.results_all_towers.items()
        }
示例#30
0
    def signature(
        cls,
        environment_spec: mava_specs.EnvironmentSpec,
        extras_spec: tf.TypeSpec = {},
    ) -> tf.TypeSpec:
        """Build the Reverb signature for a multi-agent transition.

        For every agent, rewards are broadcast against discounts (via
        `_broadcast_specs`) and discount specs are deep-copied. The
        environment's extra specs are merged into `extras_spec`. The result
        is the tuple (observations, actions, extras, rewards, discounts,
        next observations, extras) with leaves converted by
        `base.spec_like_to_tensor_spec`.

        Args:
            environment_spec: Multi-agent environment spec.
            extras_spec: Extra specs to include. Neither this dict nor the
                (never-mutated) default is modified by this method.
        """
        # This function currently assumes that self._discount is a scalar.
        # If it ever becomes a nested structure and/or a np.ndarray, this method
        # will need to know its structure / shape. This is because the signature
        # discount shape is the environment's discount shape and this adder's
        # discount shape broadcasted together. Also, the reward shape is this
        # signature discount shape broadcasted together with the environment
        # reward shape. As long as self._discount is a scalar, it will not affect
        # either the signature discount shape nor the signature reward shape, so we
        # can ignore it.

        agent_specs = environment_spec.get_agent_specs()
        agents = environment_spec.get_agent_ids()
        env_extras_spec = environment_spec.get_extra_specs()
        # Merge into a fresh dict. `extras_spec.update(...)` would mutate the
        # caller's dict and, when the default is used, the shared `{}`
        # default object, leaking env extras into later calls.
        merged_extras_spec = dict(extras_spec)
        merged_extras_spec.update(env_extras_spec)

        obs_specs = {}
        act_specs = {}
        reward_specs = {}
        step_discount_specs = {}
        for agent in agents:

            rewards_spec, step_discounts_spec = tree_utils.broadcast_structures(
                agent_specs[agent].rewards, agent_specs[agent].discounts
            )

            rewards_spec = tree.map_structure(
                _broadcast_specs, rewards_spec, step_discounts_spec
            )
            step_discounts_spec = tree.map_structure(copy.deepcopy, step_discounts_spec)

            obs_specs[agent] = agent_specs[agent].observations
            act_specs[agent] = agent_specs[agent].actions
            reward_specs[agent] = rewards_spec
            step_discount_specs[agent] = step_discounts_spec

        transition_spec = [
            obs_specs,
            act_specs,
            merged_extras_spec,
            reward_specs,
            step_discount_specs,
            obs_specs,  # next_observation
            merged_extras_spec,
        ]

        return tree.map_structure_with_path(
            base.spec_like_to_tensor_spec, tuple(transition_spec)
        )