Пример #1
0
def _convert_to_tf(x, dtype=None):
    """Recursively convert `x` into TensorFlow tensors.

    Args:
        x: The value to convert. May be a SampleBatch, a Policy, a
            RepeatedValues, None, or any structure accepted by
            `tf.nest.map_structure`.
        dtype: Optional dtype handed to `tf.convert_to_tensor` for the
            leaves of a plain structure. It is not forwarded to the
            recursive calls made for SampleBatch or RepeatedValues
            contents; those are converted with the default (None).

    Returns:
        - SampleBatch: a converted dict of its items, without the entry
          keyed by `SampleBatch.INFOS`.
        - Policy: the same object, untouched.
        - RepeatedValues: a new RepeatedValues with converted `values` and
          the original `lengths` and `max_len`.
        - None: None.
        - Anything else: the same structure with every non-None leaf
          converted to a tensor (None leaves stay None).
    """
    if isinstance(x, SampleBatch):
        # Drop the INFOS entry before converting the remaining columns.
        without_infos = {
            key: value
            for key, value in x.items() if key != SampleBatch.INFOS
        }
        return tf.nest.map_structure(_convert_to_tf, without_infos)

    if isinstance(x, Policy):
        return x

    # Special handling of "Repeated" values: only the values are converted.
    if isinstance(x, RepeatedValues):
        converted_values = tf.nest.map_structure(_convert_to_tf, x.values)
        return RepeatedValues(converted_values, x.lengths, x.max_len)

    if x is None:
        return x

    def _leaf_to_tensor(leaf):
        # None leaves are preserved as-is inside the structure.
        if leaf is None:
            return None
        return tf.convert_to_tensor(leaf, dtype)

    return tf.nest.map_structure(_leaf_to_tensor, x)
Пример #2
0
def _convert_to_tf(x, dtype=None):
    """Recursively convert `x` into TensorFlow tensors.

    Args:
        x: The value to convert. May be a SampleBatch, a Policy, a
            RepeatedValues, None, or any structure accepted by
            `tree.map_structure`.
        dtype: Optional dtype handed to `tf.convert_to_tensor` for the
            leaves of a plain structure (and to the recursive call made for
            RepeatedValues leaves found inside such a structure). It is not
            forwarded when converting the contents of a top-level
            SampleBatch or RepeatedValues.

    Returns:
        - SampleBatch: a converted dict of its items, without the entry
          keyed by `SampleBatch.INFOS`.
        - Policy: the same object, untouched.
        - RepeatedValues: a new RepeatedValues with converted `values` and
          the original `lengths` and `max_len`.
        - None: None.
        - Anything else: the same structure where each leaf is converted to
          a tensor, unless it is None or already a tensor, in which case it
          is returned unchanged.
    """
    if isinstance(x, SampleBatch):
        # Drop the INFOS entry before converting the remaining columns.
        without_infos = {
            key: value
            for key, value in x.items() if key != SampleBatch.INFOS
        }
        return tree.map_structure(_convert_to_tf, without_infos)

    if isinstance(x, Policy):
        return x

    # Special handling of "Repeated" values: only the values are converted.
    if isinstance(x, RepeatedValues):
        converted_values = tree.map_structure(_convert_to_tf, x.values)
        return RepeatedValues(converted_values, x.lengths, x.max_len)

    if x is None:
        return x

    def _convert_leaf(leaf):
        # RepeatedValues leaves go back through the full conversion logic.
        if isinstance(leaf, RepeatedValues):
            return _convert_to_tf(leaf, dtype)
        # Nothing to do for missing values or values that are tensors already.
        if leaf is None or tf.is_tensor(leaf):
            return leaf
        return tf.convert_to_tensor(leaf, dtype)

    return tree.map_structure(_convert_leaf, x)
Пример #3
0
def _unpack_obs(obs: TensorType,
                space: gym.Space,
                tensorlib: Any = tf) -> TensorStructType:
    """Unpack a flattened Dict, Tuple or Repeated observation array/tensor.

    Args:
        obs: The flattened observation tensor, with last dimension equal to
            the flat size and any number of batch dimensions. For example, for
            Box(4,), the obs may have shape [B, 4], or [B, N, M, 4] in case
            the Box was nested under two Repeated spaces.
        space: The original space prior to flattening.
        tensorlib: The library used to unflatten (reshape) the array/tensor.

    Returns:
        `obs` itself if `space` is not a Dict, Tuple or Repeated space.
        Otherwise the recursively unpacked components: a list for Tuple, an
        OrderedDict (keyed like `space.spaces`) for Dict, or a RepeatedValues
        for Repeated.

    Raises:
        ValueError: If `obs` has fewer than two dimensions or its last
            dimension differs from the preprocessor's flat size.
    """

    # Non-container spaces were not flattened by us: nothing to unpack.
    if not isinstance(space,
                      (gym.spaces.Dict, gym.spaces.Tuple, Repeated)):
        return obs

    space_id = id(space)
    if space_id in _cache:
        prep = _cache[space_id]
    else:
        prep = get_preprocessor(space)(space)
        # Make an attempt to cache the result, if enough space left.
        if len(_cache) < 999:
            _cache[space_id] = prep

    if len(obs.shape) < 2 or obs.shape[-1] != prep.shape[0]:
        raise ValueError(
            "Expected flattened obs shape of [..., {}], got {}".format(
                prep.shape[0], obs.shape))

    if tensorlib == tf:
        # Shape entries may be ints, plain None (unknown dim when the shape
        # yields raw values) or objects exposing the size via `.value`.
        # Unknown dims become -1 so that reshape can infer them.
        batch_dims = []
        for dim in obs.shape[:-1]:
            size = dim if dim is None or isinstance(dim, int) else dim.value
            batch_dims.append(-1 if size is None else size)
    else:
        batch_dims = list(obs.shape[:-1])

    if isinstance(space, gym.spaces.Tuple):
        assert len(prep.preprocessors) == len(space.spaces), \
            (len(prep.preprocessors), len(space.spaces))
        offset = 0
        unpacked_list = []
        for p, child_space in zip(prep.preprocessors, space.spaces):
            # Each child occupies a contiguous slice of the last dimension.
            obs_slice = obs[..., offset:offset + p.size]
            offset += p.size
            unpacked_list.append(
                _unpack_obs(tensorlib.reshape(obs_slice,
                                              batch_dims + list(p.shape)),
                            child_space,
                            tensorlib=tensorlib))
        return unpacked_list

    if isinstance(space, gym.spaces.Dict):
        assert len(prep.preprocessors) == len(space.spaces), \
            (len(prep.preprocessors), len(space.spaces))
        offset = 0
        unpacked_dict = OrderedDict()
        for p, (key, child_space) in zip(prep.preprocessors,
                                         space.spaces.items()):
            # Each child occupies a contiguous slice of the last dimension.
            obs_slice = obs[..., offset:offset + p.size]
            offset += p.size
            unpacked_dict[key] = _unpack_obs(tensorlib.reshape(
                obs_slice, batch_dims + list(p.shape)),
                                             child_space,
                                             tensorlib=tensorlib)
        return unpacked_dict

    # Only the Repeated case is left at this point.
    assert isinstance(prep, RepeatedValuesPreprocessor), prep
    child_size = prep.child_preprocessor.size
    # The list lengths are stored in the first slot of the flat obs.
    lengths = obs[..., 0]
    # [B, ..., 1 + max_len * child_sz] -> [B, ..., max_len, child_sz]
    with_repeat_dim = tensorlib.reshape(
        obs[..., 1:], batch_dims + [space.max_len, child_size])
    # Retry the unpack, dropping the List container space.
    unpacked_children = _unpack_obs(with_repeat_dim,
                                    space.child_space,
                                    tensorlib=tensorlib)
    return RepeatedValues(unpacked_children,
                          lengths=lengths,
                          max_len=prep._obs_space.max_len)