Example #1

import pathlib
import pickle

import numpy as np
from tqdm import tqdm

# NOTE: GridWorld is assumed to be provided by the surrounding project
# (a local grid-world environment class); its import is not shown here.

def make_evaluation_index_dataset(data_path,
                                  n_envs,
                                  n_negatives,
                                  expert_eps,
                                  fake_eps,
                                  dtype,
                                  start_seed=10000,
                                  filename_suffix='evaluate'):
    """Creates a single dataset of the 'evaluation' format

    Includes positive examples from the expert and negative from random walks

    :param data_path: path to a directory to save files
    :param n_envs: number of environments (or expert transitions) in the dataset
    :param n_negatives: number of negative for each positive
    :param expert_eps: probability if a random action for the expert
    :param fake_eps: probability if a random action for the random walks (can be made less random and more expert-like)
    :param dtype: 'bool' or 'int' data type for the observations
    :param start_seed: seed for the first environment (for other environments uses next ones)
    :param filename_suffix: dataset name to append to the filenames
    """
    if dtype == 'bool':
        np_dtype = np.bool_
    elif dtype == 'int':
        np_dtype = int
    else:
        raise ValueError(f"dtype must be 'bool' or 'int', got {dtype!r}")
    dataset = []
    state_data = []
    button_idxs_used = set()
    seed = start_seed
    cur_num = 0
    idx = {}
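    # each environment contributes one positive pair and n_negatives negative pairs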
    for ind_env in tqdm(range(n_envs)):
        # generate an env with NEW positions of the buttons
        repeats = True
        while repeats:
            seed += 1
            env = GridWorld(height=5,
                            width=5,
                            n_buttons=3,
                            seed=seed,
                            obs_dtype=dtype)
            observation = env.reset()
            repeats = env.button_idx in button_idxs_used
        # remember this button layout so later environments are forced to differ
        button_idxs_used.add(env.button_idx)
        states = [observation.astype(np_dtype)]
        state_datas = [env.get_state_data()]
        done = False
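        # roll out the (epsilon-noisy) expert policy until the episode ends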
        while not done:
            observation, _, done, _ = env.step(
                env.get_expert_action(eps=expert_eps))
            states.append(observation.astype(np_dtype))
            state_datas.append(env.get_state_data())
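        # pick two distinct time steps from the expert trajectory; the earlier
        # observation becomes s and the later one s_prime (the positive pair)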
        n_states = len(states)
        s_ind, s_prime_ind = np.random.choice(np.arange(n_states),
                                              2,
                                              replace=False)
        s_ind, s_prime_ind = min(s_ind, s_prime_ind), max(s_ind, s_prime_ind)
        s, s_prime = states[s_ind], states[s_prime_ind]
        s_data, s_prime_data = state_datas[s_ind], state_datas[s_prime_ind]
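        # L1 distance between the two states' positions (state_data index 5
        # appears to hold the agent position, cf. the env2.pos comparison below)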
        distance = np.abs(s_data[5] - s_prime_data[5]).sum()
        negatives = []
        negatives_data = []
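        # build negatives: restart from s and take mostly random actions until
        # roughly the same distance from s is reached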
        for ind_neg in range(n_negatives):
            env2 = GridWorld(height=5,
                             width=5,
                             n_buttons=3,
                             seed=ind_env * n_negatives + ind_neg,
                             obs_dtype=dtype)
            env2.load_state_data(*s_data)
            # Random actions until the target distance is reached
            cur_distance = 0
            target_distance = max(
                1, distance +
                np.random.randint(-1, 2))  # might be greater than maxdist?
            while cur_distance != target_distance:
                negative, _, done, _ = env2.step(
                    env2.get_expert_action(eps=fake_eps))
                if done:
                    break
                cur_distance = np.abs(s_data[5] - env2.pos).sum()
            negative_data = env2.get_state_data()
            negatives.append(negative.astype(np_dtype))  # cast to match the positives' dtype
            negatives_data.append(negative_data)
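        # store the positive pair first, followed by its n_negatives negative pairs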
        dataset += [(s, s_prime)] + [(s, s_neg) for s_neg in negatives]
        state_data += [(s_data, s_prime_data)] + [
            (s_data, s_neg_data) for s_neg_data in negatives_data
        ]
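        # remember which slice of the dataset belongs to this environment, keyed by its seed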
        start = cur_num
        stop = cur_num + (1 + n_negatives)
        idx[seed] = (start, stop)
        cur_num += 1 + n_negatives

    # save the data
    path = pathlib.Path(data_path)
    path.mkdir(parents=True, exist_ok=True)
    with open(path / f'{filename_suffix}.pkl', 'wb') as f:
        pickle.dump(dataset, f)
    with open(path / f'state_data_{filename_suffix}.pkl', 'wb') as f:
        pickle.dump(state_data, f)
    with open(path / f'idx_{filename_suffix}.pkl', 'wb') as f:
        pickle.dump(idx, f)
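
# Usage sketch (not part of the original snippet): the directory name and the
# argument values below are illustrative assumptions only.
if __name__ == '__main__':
    make_evaluation_index_dataset(data_path='data/eval_dataset',  # hypothetical path
                                  n_envs=100,
                                  n_negatives=5,
                                  expert_eps=0.0,
                                  fake_eps=0.5,
                                  dtype='bool')
    # writes evaluate.pkl, state_data_evaluate.pkl and idx_evaluate.pkl
    # into data/eval_dataset/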