import h5py
import numpy as np
import tensorflow as tf

# AttrDict, file_utils, and logger are assumed to come from the surrounding
# project (a nested-dict container, hdf5/file helpers, and a module logger).


def get_batch(self, indices=None, is_tf=True):
    # Sample random window start indices if none were provided. Dropping the
    # last self._horizon valid starts keeps every sampled window of length
    # horizon+1 inside the concatenated data.
    if indices is None:
        indices = np.random.choice(self._valid_start_indices[:-self._horizon],
                                   size=self._batch_size)

    # For each sampled start index, slice out a (horizon+1)-step window and
    # stack the windows along a new leading batch dimension.
    sampled_datadict = self._datadict.leaf_apply(
        lambda arr: np.stack([arr[idx:idx + self._horizon + 1] for idx in indices], axis=0))

    # Split each key into model inputs and outputs:
    #   observations        -> first timestep only
    #   actions             -> all but the last timestep
    #   output observations -> all but the first timestep
    inputs = AttrDict()
    outputs = AttrDict()
    for key in self._env_spec.names:
        value = sampled_datadict[key]
        if key in self._env_spec.observation_names:
            inputs[key] = value[:, 0]
        elif key in self._env_spec.action_names:
            inputs[key] = value[:, :-1]
        if key in self._env_spec.output_observation_names:
            outputs[key] = value[:, 1:]

    # Once a rollout terminates, every later timestep in the window is also
    # marked done (cumsum makes the first True "sticky" for the rest of the window).
    outputs.done = sampled_datadict.done[:, 1:].cumsum(axis=1).astype(bool)

    if is_tf:
        for d in (inputs, outputs):
            d.leaf_modify(lambda x: tf.convert_to_tensor(x))

    return inputs, outputs
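# Example (illustrative, not part of the dataset class): a minimal sketch of
# the sliding-window sampling get_batch performs. Stacking per-index slices of
# a time-major array yields a (batch, horizon + 1, ...) array; the data,
# horizon, and batch size below are made-up assumptions.
def _example_window_batch():
    data = np.arange(20).reshape(10, 2)                       # 10 timesteps of 2-dim data
    horizon, batch_size = 3, 4
    valid_starts = np.arange(len(data) - horizon)             # starts whose window fits
    idxs = np.random.choice(valid_starts, size=batch_size)    # sample with replacement
    batch = np.stack([data[i:i + horizon + 1] for i in idxs], axis=0)
    assert batch.shape == (batch_size, horizon + 1, 2)
    return batch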
def _load_hdf5s(self):
    hdf5_fnames = file_utils.get_files_ending_with(self._hdf5_folders, '.hdf5')

    # initialize each key to an empty list
    datadict = AttrDict()
    for key in self._env_spec.names:
        datadict[key] = []
    datadict.done = []
    datadict.hdf5_fname = []
    datadict.rollout_timestep = []

    # load each hdf5 rollout, skipping malformed files
    for hdf5_fname in hdf5_fnames:
        logger.debug('Loading ' + hdf5_fname)
        with h5py.File(hdf5_fname, 'r') as f:
            hdf5_names = file_utils.get_hdf5_leaf_names(f)
            if len(hdf5_names) == 0:
                logger.warning('Empty hdf5, skipping!')
                continue
            hdf5_lens = np.array([len(f[name]) for name in hdf5_names])
            if not np.all(hdf5_lens == hdf5_lens[0]):
                logger.warning('data lengths not all the same, skipping!')
                continue
            if hdf5_lens[0] == 0:
                logger.warning('data lengths are 0, skipping!')
                continue

            for key in self._env_spec.names:
                assert key in f, '"{0}" from env spec not in hdf5'.format(key)
                value = self._parse_hdf5(key, f[key])
                datadict[key].append(value)

            # mark only the final timestep of each rollout as done, and record
            # per-timestep bookkeeping (source file and timestep within rollout)
            datadict.done.append([False] * (len(value) - 1) + [True])
            datadict.hdf5_fname.append([hdf5_fname] * len(value))
            datadict.rollout_timestep.append(np.arange(len(value)))

    # concatenate the per-rollout lists into single numpy arrays
    datadict.leaf_modify(lambda arr_list: np.concatenate(arr_list, axis=0))

    datadict_len = len(datadict.done)
    datadict.leaf_assert(lambda arr: len(arr) == datadict_len)
    logger.debug('Dataset length: {}'.format(datadict_len))

    # valid window starts are everywhere the rollout has not terminated
    valid_start_indices = np.where(np.logical_not(datadict.done))[0]

    return datadict, valid_start_indices
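# Example (illustrative data, not read from any real hdf5): a sketch of how
# the per-rollout `done` flags built above turn into valid window start
# indices. Only the last timestep of each rollout is done, so every other
# timestep is a valid start; the episode lengths here are hypothetical.
def _example_valid_starts():
    episode_lens = [4, 3]    # two hypothetical rollouts
    done = np.concatenate([[False] * (n - 1) + [True] for n in episode_lens])
    valid_start_indices = np.where(np.logical_not(done))[0]
    assert valid_start_indices.tolist() == [0, 1, 2, 4, 5]   # last step of each rollout excluded
    return valid_start_indices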