# Example 1
def create_dataset(data_file,
                   grid_size,
                   n_plants,
                   return_trajs=False,
                   seed=None,
                   num_envs=None,
                   use_gold_trajs=False,
                   buffer_scorer=None):
    """Create the environment dataset to be used for training/evaluation.

    Args:
        data_file: Path to a pickle file mapping env name -> dict with
            'env_state' and 'trajs' entries.
        grid_size: Grid size forwarded to each `environment.TextEnvironment`.
        n_plants: Number of plants forwarded to each environment.
        return_trajs: If True, also build and return `Traj` objects from the
            stored trajectories.
        seed: Seed for the global numpy RNG used to pick each environment's
            context trajectory.
        num_envs: If not None, keep only the first `num_envs` environments
            after sorting names.
        use_gold_trajs: If True, keep only the trajectory chosen as the
            context for each environment.
        buffer_scorer: Optional scorer; when given, built trajectories are
            filtered through `get_top_trajs`.

    Returns:
        `env_dict` mapping name -> environment, or the tuple
        `(env_dict, all_trajs)` when `return_trajs` is True.
    """
    data_dict = load_pickle(data_file)
    env_names = list(data_dict.keys())
    if num_envs is not None:
        env_names = sorted(env_names)[:num_envs]
    env_dict = {}
    all_trajs = []
    # Seed the global numpy RNG so the per-env context choice is reproducible.
    np.random.seed(seed)
    for name in env_names:
        data = data_dict[name]
        env_state, trajs = data['env_state'], data['trajs']
        # Use a distinct name for the stored per-env seed: the original code
        # assigned it to `seed`, silently shadowing the function parameter.
        goal_id, env_seed = env_state['goal_id'], env_state['seed']
        env = environment.TextEnvironment(name=name,
                                          goal_id=goal_id,
                                          grid_size=grid_size,
                                          n_plants=n_plants,
                                          seed=env_seed)
        env.reset()
        env.grid.grid = env_state['grid']
        index = np.random.choice(len(trajs))
        # Reverse the trajectory to create the context.
        env.context = list(reversed(trajs[index]['actions']))
        env_dict[name] = env
        if use_gold_trajs:
            trajs = [trajs[index]]
        if return_trajs:
            features = [
                create_joint_features(traj['actions'], env.context)
                for traj in trajs
            ]
            new_trajs = [
                Traj(features=f, env_name=name, **x)
                for f, x in zip(features, trajs)
            ]
            if buffer_scorer is not None:
                new_trajs = get_top_trajs(buffer_scorer, new_trajs)
            all_trajs += new_trajs

    if return_trajs:
        return env_dict, all_trajs
    else:
        return env_dict
# Example 2
  def sample_trajs(self, envs, greedy=False):
    """Roll out the policy `self.pi` in every environment until all finish.

    Args:
      envs: Environments to sample from; each must expose `name`, `reset()`,
        a `context` attribute, and `step(action)` returning `(reward, done)`.
      greedy: Forwarded to `self._sample_action`; presumably switches from
        sampling to argmax action selection — confirm against that method.

    Returns:
      A list of `Traj` objects, one per input environment, holding the
      sampled actions, per-step rewards, and joint features.
    """
    env_names = [env.name for env in envs]
    env_dict = {env.name: env for env in envs}

    for env in envs:
      env.reset()
    # Per-environment accumulators, keyed by env name.
    rews = {env.name: [] for env in envs}
    actions = {env.name: [] for env in envs}

    kwargs = {'return_state': True}
    # Shift every context index up by one so that 0 is free to act as the
    # padding value passed to pad_sequences below.
    contexts = [env.context + 1 for env in envs]  # Zero is not a valid index
    contexts, context_lengths, _ = pad_sequences(contexts, 0)
    contexts = tf.stack(contexts, axis=0)
    context_lengths = np.array(context_lengths, dtype=np.int32)
    encoded_context, initial_state = self.pi.encode_context(
        contexts, context_lengths=context_lengths)
    kwargs.update(enc_output=encoded_context)
    # pylint: disable=protected-access
    model_fn = self.pi._call
    # pylint: enable=protected-access

    # Step all still-running environments in lockstep, one action per
    # iteration; finished environments are dropped from the batch.
    while env_names:
      logprobs, next_state = model_fn(
          num_inputs=1, initial_state=initial_state, **kwargs)
      logprobs = logprobs[:, 0]
      acs = self._sample_action(logprobs, greedy=greedy)
      dones = []
      new_env_names = []
      for ac, name in zip(acs, env_names):
        rew, done = env_dict[name].step(ac)
        actions[name].append(ac)
        rews[name].append(rew)
        dones.append(done)
        if not done:
          new_env_names.append(name)
      env_names = new_env_names

      if env_names:
        # Remove the states of `done` environments from recurrent_state only if
        # at least one of them is not `done`
        if isinstance(next_state, tf.Tensor):
          next_state = [
              next_state[i] for i, done in enumerate(dones) if not done
          ]
          initial_state = tf.stack(next_state, axis=0)
        else:  # LSTM Tuple
          raise NotImplementedError

        # Keep the encoder outputs row-aligned with the surviving envs.
        enc_output = kwargs['enc_output']
        enc_output = [enc_output[i] for i, done in enumerate(dones) if not done]
        kwargs['enc_output'] = tf.stack(enc_output, axis=0)

    # NOTE(review): iterating a set makes the name order arbitrary, but both
    # comprehensions below iterate this same set object, so `env_names` and
    # `features` stay pairwise aligned in the zip.
    env_names = set([env.name for env in envs])
    features = [
        create_joint_features(actions[name], env_dict[name].context)
        for name in env_names
    ]
    # pylint: disable=g-complex-comprehension
    trajs = [
        Traj(
            env_name=name,
            actions=actions[name],
            features=f,
            rewards=rews[name]) for name, f in zip(env_names, features)
    ]
    # pylint: enable=g-complex-comprehension
    return trajs
# Example 3
def create_replay_traj(env, actions):
    """Replay `actions` in `env` and wrap the result in a `Traj` with joint features."""
    replayed = create_traj(env, actions)
    joint_features = create_joint_features(replayed['actions'], env.context)
    return Traj(features=joint_features, **replayed)