def create_dataset(data_file,
                   grid_size,
                   n_plants,
                   return_trajs=False,
                   seed=None,
                   num_envs=None,
                   use_gold_trajs=False,
                   buffer_scorer=None):
  """Create the environment dataset to be used for training/evaluation."""
  data_dict = load_pickle(data_file)
  env_names = list(data_dict.keys())
  if num_envs is not None:
    env_names = sorted(env_names)[:num_envs]
  env_dict = {}
  all_trajs = []
  np.random.seed(seed)
  for name in env_names:
    data = data_dict[name]
    env_state, trajs = data['env_state'], data['trajs']
    # `env_seed` is the environment's own seed, distinct from the
    # dataset-level `seed` argument consumed by np.random.seed above.
    goal_id, env_seed = env_state['goal_id'], env_state['seed']
    env = environment.TextEnvironment(
        name=name,
        goal_id=goal_id,
        grid_size=grid_size,
        n_plants=n_plants,
        seed=env_seed)
    env.reset()
    env.grid.grid = env_state['grid']
    index = np.random.choice(len(trajs))
    # Reverse the trajectory to create the context.
    env.context = list(reversed(trajs[index]['actions']))
    env_dict[name] = env
    if use_gold_trajs:
      trajs = [trajs[index]]
    if return_trajs:
      features = [
          create_joint_features(traj['actions'], env.context) for traj in trajs
      ]
      new_trajs = [
          Traj(features=f, env_name=name, **x) for f, x in zip(features, trajs)
      ]
      if buffer_scorer is not None:
        new_trajs = get_top_trajs(buffer_scorer, new_trajs)
      all_trajs += new_trajs
  if return_trajs:
    return env_dict, all_trajs
  return env_dict
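
# A minimal usage sketch (hypothetical path and sizes; the pickle must map
# env names to dicts holding 'env_state' and 'trajs' entries, as consumed
# above):
#
#   env_dict, trajs = create_dataset(
#       'data/envs.pkl', grid_size=8, n_plants=4,
#       return_trajs=True, seed=0, num_envs=100)
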
def sample_trajs(self, envs, greedy=False):
  """Sample one trajectory per environment in a single batched rollout."""
  env_names = [env.name for env in envs]
  env_dict = {env.name: env for env in envs}
  for env in envs:
    env.reset()
  rews = {env.name: [] for env in envs}
  actions = {env.name: [] for env in envs}
  kwargs = {'return_state': True}
  # Zero is not a valid index, so shift every context id up by one.
  contexts = [np.asarray(env.context) + 1 for env in envs]
  contexts, context_lengths, _ = pad_sequences(contexts, 0)
  contexts = tf.stack(contexts, axis=0)
  context_lengths = np.array(context_lengths, dtype=np.int32)
  encoded_context, initial_state = self.pi.encode_context(
      contexts, context_lengths=context_lengths)
  kwargs.update(enc_output=encoded_context)
  # pylint: disable=protected-access
  model_fn = self.pi._call
  # pylint: enable=protected-access
  while env_names:
    logprobs, next_state = model_fn(
        num_inputs=1, initial_state=initial_state, **kwargs)
    logprobs = logprobs[:, 0]
    acs = self._sample_action(logprobs, greedy=greedy)
    dones = []
    new_env_names = []
    for ac, name in zip(acs, env_names):
      rew, done = env_dict[name].step(ac)
      actions[name].append(ac)
      rews[name].append(rew)
      dones.append(done)
      if not done:
        new_env_names.append(name)
    env_names = new_env_names
    if env_names:
      # Drop the states of `done` environments from the recurrent state only
      # if at least one environment is still running.
      if isinstance(next_state, tf.Tensor):
        next_state = [
            next_state[i] for i, done in enumerate(dones) if not done
        ]
        initial_state = tf.stack(next_state, axis=0)
      else:  # LSTM tuple states are not supported yet.
        raise NotImplementedError
      enc_output = kwargs['enc_output']
      enc_output = [enc_output[i] for i, done in enumerate(dones) if not done]
      kwargs['enc_output'] = tf.stack(enc_output, axis=0)
  env_names = {env.name for env in envs}
  features = [
      create_joint_features(actions[name], env_dict[name].context)
      for name in env_names
  ]
  # pylint: disable=g-complex-comprehension
  trajs = [
      Traj(
          env_name=name,
          actions=actions[name],
          features=f,
          rewards=rews[name]) for name, f in zip(env_names, features)
  ]
  # pylint: enable=g-complex-comprehension
  return trajs
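
# The rollout above shrinks the batch as environments terminate: rows of the
# recurrent state and of `enc_output` that belong to `done` environments are
# dropped and the survivors re-stacked. A minimal sketch of the same pattern,
# assuming `state` is a plain tf.Tensor of shape [batch, hidden]:
#
#   keep = [i for i, done in enumerate(dones) if not done]
#   state = tf.gather(state, keep, axis=0)  # equivalent to the list + stack
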
def create_replay_traj(env, actions):
  """Replay `actions` in `env` and package the result as a `Traj`."""
  traj_dict = create_traj(env, actions)
  features = create_joint_features(traj_dict['actions'], env.context)
  return Traj(features=features, **traj_dict)
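
# Usage sketch (hypothetical env name and action ids; assumes `env_dict`
# came from create_dataset above):
#
#   traj = create_replay_traj(env_dict['env_0'], actions=[2, 0, 1, 3])
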