def mapping(item):
    """Convert a single leaf of a nested structure into a torch tensor.

    `device` is read from the enclosing scope; when it is None the result
    is not moved to any device.

    Args:
        item: A torch tensor, a RepeatedValues container, a numpy array, or
            anything `np.asarray` can convert.

    Returns:
        The converted torch tensor (float64 is down-cast to float32), a new
        RepeatedValues with converted values, or `item` itself when it is an
        object-dtype numpy array.
    """
    # Torch tensors only need to be moved to the requested device (if any).
    if torch.is_tensor(item):
        return item.to(device) if device is not None else item

    # "Repeated" values: convert the nested values, keep lengths/max_len.
    if isinstance(item, RepeatedValues):
        converted_values = tree.map_structure(mapping, item.values)
        return RepeatedValues(converted_values, item.lengths, item.max_len)

    if not isinstance(item, np.ndarray):
        # Anything that is not yet a numpy array: go through numpy first.
        result = torch.from_numpy(np.asarray(item))
    elif item.dtype == object:
        # Object arrays (e.g. info dicts in a train batch) are left as-is.
        return item
    elif item.flags.writeable is False:
        # Wrapping a read-only array makes PyTorch emit a warning; mute it.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            result = torch.from_numpy(item)
    else:
        # Plain writable numpy array: wrap directly.
        result = torch.from_numpy(item)

    # Down-cast all float64 tensors to float32.
    if result.dtype == torch.double:
        result = result.float()

    if device is None:
        return result
    return result.to(device)
def mapping(item):
    """Convert a single leaf of a nested structure into a torch tensor.

    `device` is read from the enclosing scope; when it is None the result
    is not moved to any device.

    Args:
        item: A torch tensor, a RepeatedValues container, a numpy array, or
            anything `np.asarray` can convert.

    Returns:
        The converted torch tensor (float64 is down-cast to float32), a new
        RepeatedValues with converted values, or `item` itself when it is an
        object-dtype numpy array.
    """
    # Already torch tensor -> make sure it's on right device.
    if torch.is_tensor(item):
        return item if device is None else item.to(device)

    # Special handling of "Repeated" values.
    if isinstance(item, RepeatedValues):
        return RepeatedValues(
            tree.map_structure(mapping, item.values), item.lengths, item.max_len
        )

    if isinstance(item, np.ndarray):
        # Object-dtype arrays (e.g. info dicts) cannot be wrapped by
        # `torch.from_numpy`; hand them back untouched instead of raising.
        if item.dtype == object:
            return item
        if not item.flags.writeable:
            # Wrapping a read-only array makes PyTorch emit a warning on
            # every call; silence it locally.
            import warnings

            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                tensor = torch.from_numpy(item)
        else:
            tensor = torch.from_numpy(item)
    else:
        # Everything else: convert to numpy, then wrap as torch tensor.
        tensor = torch.from_numpy(np.asarray(item))

    # Floatify all float64 tensors.
    if tensor.dtype == torch.double:
        tensor = tensor.float()
    return tensor if device is None else tensor.to(device)
def _convert_to_tf(x, dtype=None):
    """Convert a (possibly nested) structure into TensorFlow tensors.

    Args:
        x: A SampleBatch, a Policy, a RepeatedValues container, None, or any
            structure accepted by `tf.nest.map_structure`.
        dtype: Optional dtype passed to `tf.convert_to_tensor` for the leaves
            of a plain structure. It is not forwarded when recursing into a
            SampleBatch or into the values of a RepeatedValues.

    Returns:
        The structure with its leaves converted to tensors. A Policy and None
        are returned unchanged; for a SampleBatch the `SampleBatch.INFOS`
        entry is dropped before conversion.
    """
    if isinstance(x, SampleBatch):
        # Info dicts are not convertible; strip them before mapping.
        x = {k: v for k, v in x.items() if k != SampleBatch.INFOS}
        return tf.nest.map_structure(_convert_to_tf, x)
    elif isinstance(x, Policy):
        return x
    # Special handling of "Repeated" values.
    elif isinstance(x, RepeatedValues):
        return RepeatedValues(
            tf.nest.map_structure(_convert_to_tf, x.values), x.lengths, x.max_len
        )

    if x is not None:

        def _convert_leaf(leaf):
            # None leaves are preserved as None.
            if leaf is None:
                return None
            # A RepeatedValues nested inside a plain structure cannot be
            # passed to `tf.convert_to_tensor`; convert it via the dedicated
            # branch above instead.
            if isinstance(leaf, RepeatedValues):
                return _convert_to_tf(leaf, dtype)
            return tf.convert_to_tensor(leaf, dtype)

        x = tf.nest.map_structure(_convert_leaf, x)
    return x
def _convert_to_tf(x, dtype=None):
    """Turn a (possibly nested) structure into TensorFlow tensors.

    Args:
        x: A SampleBatch, a Policy, a RepeatedValues container, None, or any
            structure accepted by `tree.map_structure`.
        dtype: Optional dtype handed to `tf.convert_to_tensor` for the leaves
            of a plain structure.

    Returns:
        The converted structure. A Policy and None come back unchanged, the
        `SampleBatch.INFOS` entry of a SampleBatch is dropped, and leaves
        that already are tensors (or None) are passed through untouched.
    """
    if isinstance(x, SampleBatch):
        # Drop the infos column, then convert every remaining entry.
        without_infos = {
            key: value for key, value in x.items() if key != SampleBatch.INFOS
        }
        return tree.map_structure(_convert_to_tf, without_infos)

    if isinstance(x, Policy):
        return x

    # "Repeated" values: convert the nested values, keep lengths/max_len.
    if isinstance(x, RepeatedValues):
        converted_values = tree.map_structure(_convert_to_tf, x.values)
        return RepeatedValues(converted_values, x.lengths, x.max_len)

    if x is None:
        return x

    def _convert_leaf(leaf):
        # Nested RepeatedValues go back through the dedicated branch above.
        if isinstance(leaf, RepeatedValues):
            return _convert_to_tf(leaf, dtype)
        # None and existing tensors are returned as-is.
        if leaf is None or tf.is_tensor(leaf):
            return leaf
        return tf.convert_to_tensor(leaf, dtype)

    return tree.map_structure(_convert_leaf, x)
def _unpack_obs(obs: TensorType, space: gym.Space, tensorlib: Any = tf) -> TensorStructType:
    """Unpack a flattened Dict, Tuple or Repeated observation array/tensor.

    Args:
        obs: The flattened observation tensor, with last dimension equal to
            the flat size and any number of batch dimensions. For example,
            for Box(4,), the obs may have shape [B, 4], or [B, N, M, 4] in
            case the Box was nested under two Repeated spaces.
        space: The original space prior to flattening.
        tensorlib: The library used to unflatten (reshape) the array/tensor.

    Returns:
        `obs` unchanged when `space` is not a Dict, Tuple or Repeated space.
        Otherwise a list (Tuple), an OrderedDict (Dict) or a RepeatedValues
        (Repeated) holding the recursively unpacked components.

    Raises:
        ValueError: If `obs` has fewer than two dimensions or its last
            dimension does not match the preprocessor's flat size.
    """
    if not isinstance(space, (gym.spaces.Dict, gym.spaces.Tuple, Repeated)):
        return obs

    # Preprocessors are cached by the identity of the space object.
    if id(space) in _cache:
        prep = _cache[id(space)]
    else:
        prep = get_preprocessor(space)(space)
        # Make an attempt to cache the result, if enough space left.
        if len(_cache) < 999:
            _cache[id(space)] = prep

    if len(obs.shape) < 2 or obs.shape[-1] != prep.shape[0]:
        raise ValueError(
            "Expected flattened obs shape of [..., {}], got {}".format(
                prep.shape[0], obs.shape))

    offset = 0
    if tensorlib == tf:
        # Shape entries may be plain ints, None (unknown dimension) or
        # Dimension objects exposing `.value`; `getattr` handles all three
        # without raising on None.
        batch_dims = [getattr(v, "value", v) for v in obs.shape[:-1]]
        # Unknown dimensions become -1 so they can be used in a reshape.
        batch_dims = [-1 if v is None else v for v in batch_dims]
    else:
        batch_dims = list(obs.shape[:-1])

    if isinstance(space, gym.spaces.Tuple):
        assert len(prep.preprocessors) == len(space.spaces), (
            len(prep.preprocessors), len(space.spaces))
        u = []
        for p, v in zip(prep.preprocessors, space.spaces):
            # Each component occupies a contiguous slice of the flat obs.
            obs_slice = obs[..., offset:offset + p.size]
            offset += p.size
            u.append(
                _unpack_obs(
                    tensorlib.reshape(obs_slice, batch_dims + list(p.shape)),
                    v,
                    tensorlib=tensorlib))
    elif isinstance(space, gym.spaces.Dict):
        assert len(prep.preprocessors) == len(space.spaces), (
            len(prep.preprocessors), len(space.spaces))
        u = OrderedDict()
        for p, (k, v) in zip(prep.preprocessors, space.spaces.items()):
            obs_slice = obs[..., offset:offset + p.size]
            offset += p.size
            u[k] = _unpack_obs(
                tensorlib.reshape(obs_slice, batch_dims + list(p.shape)),
                v,
                tensorlib=tensorlib)
    elif isinstance(space, Repeated):
        assert isinstance(prep, RepeatedValuesPreprocessor), prep
        child_size = prep.child_preprocessor.size
        # The list lengths are stored in the first slot of the flat obs.
        lengths = obs[..., 0]
        # [B, ..., 1 + max_len * child_sz] -> [B, ..., max_len, child_sz]
        with_repeat_dim = tensorlib.reshape(
            obs[..., 1:], batch_dims + [space.max_len, child_size])
        # Retry the unpack, dropping the List container space.
        u = _unpack_obs(
            with_repeat_dim, space.child_space, tensorlib=tensorlib)
        return RepeatedValues(
            u, lengths=lengths, max_len=prep._obs_space.max_len)
    else:
        # Guarded by the isinstance check at the top; raised explicitly so
        # the check also survives `python -O`.
        raise AssertionError(space)
    return u