import pathlib
import pickle

import numpy as np
from tqdm import tqdm

# GridWorld is assumed to be defined/imported elsewhere in this module
# (its import path is not shown in this snippet).


def make_evaluation_index_dataset(data_path, n_envs, n_negatives, expert_eps,
                                  fake_eps, dtype, start_seed=10000,
                                  filename_suffix='evaluate'):
    """Creates a single dataset of the 'evaluation' format.

    Includes positive examples from the expert and negatives from random walks.

    :param data_path: path to a directory to save files
    :param n_envs: number of environments (or expert transitions) in the dataset
    :param n_negatives: number of negatives for each positive
    :param expert_eps: probability of a random action for the expert
    :param fake_eps: probability of a random action for the random walks
        (can be made less random and more expert-like)
    :param dtype: 'bool' or 'int' data type for the observations
    :param start_seed: seed for the first environment (subsequent environments
        use the following seeds)
    :param filename_suffix: dataset name to append to the filenames
    """
    if dtype == 'bool':
        np_dtype = np.bool_
    elif dtype == 'int':
        np_dtype = int
    else:
        raise ValueError(f"Unsupported dtype: {dtype!r}, expected 'bool' or 'int'")

    dataset = []
    state_data = []
    button_idxs_used = set()
    seed = start_seed
    cur_num = 0
    idx = {}
    for ind_env in tqdm(range(n_envs)):
        # Generate an env with NEW positions of the buttons (skip layouts seen before)
        repeats = True
        while repeats:
            seed += 1
            env = GridWorld(height=5, width=5, n_buttons=3, seed=seed,
                            obs_dtype=dtype)
            observation = env.reset()
            repeats = env.button_idx in button_idxs_used
        button_idxs_used.add(env.button_idx)  # remember this layout so it is not reused

        # Roll out the (epsilon-greedy) expert, recording every observation and state
        states = [observation.astype(np_dtype)]
        state_datas = [env.get_state_data()]
        done = False
        while not done:
            observation, _, done, _ = env.step(
                env.get_expert_action(eps=expert_eps))
            states.append(observation.astype(np_dtype))
            state_datas.append(env.get_state_data())

        # Sample an ordered pair (s, s') from the expert trajectory as the positive example
        n_states = len(states)
        s_ind, s_prime_ind = np.random.choice(np.arange(n_states), 2, replace=False)
        s_ind, s_prime_ind = min(s_ind, s_prime_ind), max(s_ind, s_prime_ind)
        s, s_prime = states[s_ind], states[s_prime_ind]
        s_data, s_prime_data = state_datas[s_ind], state_datas[s_prime_ind]
        distance = np.abs(s_data[5] - s_prime_data[5]).sum()

        # Build negatives: random walks from s until they end up roughly as far from s as s' is
        negatives = []
        negatives_data = []
        for ind_neg in range(n_negatives):
            env2 = GridWorld(height=5, width=5, n_buttons=3,
                             seed=ind_env * n_negatives + ind_neg,
                             obs_dtype=dtype)
            env2.load_state_data(*s_data)
            # Random actions until the target distance is reached
            cur_distance = 0
            target_distance = max(
                1, distance + np.random.randint(-1, 2))  # might be greater than maxdist?
            while cur_distance != target_distance:
                negative, _, done, _ = env2.step(
                    env2.get_expert_action(eps=fake_eps))
                if done:
                    break
                cur_distance = np.abs(s_data[5] - env2.pos).sum()
            negative_data = env2.get_state_data()
            negatives.append(negative.astype(np_dtype))  # cast to match the positive observations
            negatives_data.append(negative_data)

        # Append the positive pair followed by its negatives, and record this block's index range
        dataset += [(s, s_prime)] + [(s, s_neg) for s_neg in negatives]
        state_data += [(s_data, s_prime_data)] + [
            (s_data, s_neg_data) for s_neg_data in negatives_data
        ]
        start = cur_num
        stop = cur_num + (1 + n_negatives)
        idx[seed] = (start, stop)
        cur_num += 1 + n_negatives

    # Save the data
    path = pathlib.Path(data_path)
    path.mkdir(parents=True, exist_ok=True)
    with open(path / f'{filename_suffix}.pkl', 'wb') as f:
        pickle.dump(dataset, f)
    with open(path / f'state_data_{filename_suffix}.pkl', 'wb') as f:
        pickle.dump(state_data, f)
    with open(path / f'idx_{filename_suffix}.pkl', 'wb') as f:
        pickle.dump(idx, f)
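

# A minimal usage sketch. The directory name, dataset sizes, and epsilon settings
# below are illustrative assumptions, not values prescribed elsewhere in this module.
# With filename_suffix='evaluate' the call writes evaluate.pkl, state_data_evaluate.pkl,
# and idx_evaluate.pkl into the given directory.
if __name__ == '__main__':
    make_evaluation_index_dataset(
        data_path='data/eval_index',  # hypothetical output directory
        n_envs=100,                   # 100 expert transitions / environments
        n_negatives=5,                # 5 negatives per positive pair
        expert_eps=0.0,               # fully greedy expert
        fake_eps=0.5,                 # half-random walks for the negatives
        dtype='bool',
        start_seed=10000,
        filename_suffix='evaluate',
    )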