def _evaluate(self, epoch): """Perform evaluation for the current policy. We always use the most recent policy, but for computational efficiency we sometimes use a stale version of the metapolicy. During evaluation, our policy expects an un-augmented observation. :param epoch: The epoch number. :return: None """ if self._eval_n_episodes < 1: return if epoch % self._find_best_skill_interval == 0: self._single_option_policy = self._get_best_single_option_policy() for (policy, policy_name) in [(self._single_option_policy, 'best_single_option_policy')]: with logger.tabular_prefix(policy_name + '/'), logger.prefix(policy_name + '/'): with self._policy.deterministic(self._eval_deterministic): if self._eval_render: paths = rollouts(self._eval_env, policy, self._max_path_length, self._eval_n_episodes, render=True, render_mode='rgb_array') else: paths = rollouts(self._eval_env, policy, self._max_path_length, self._eval_n_episodes) total_returns = [path['rewards'].sum() for path in paths] episode_lengths = [len(p['rewards']) for p in paths] logger.record_tabular('return-average', np.mean(total_returns)) logger.record_tabular('return-min', np.min(total_returns)) logger.record_tabular('return-max', np.max(total_returns)) logger.record_tabular('return-std', np.std(total_returns)) logger.record_tabular('episode-length-avg', np.mean(episode_lengths)) logger.record_tabular('episode-length-min', np.min(episode_lengths)) logger.record_tabular('episode-length-max', np.max(episode_lengths)) logger.record_tabular('episode-length-std', np.std(episode_lengths)) self._eval_env.log_diagnostics(paths) batch = self._pool.random_batch(self._batch_size) self.log_diagnostics(batch)
def _evaluate(self, epoch): """Perform evaluation for the current policy. :param epoch: The epoch number. :return: None """ if self._eval_n_episodes < 1: return with self._policy.deterministic(self._eval_deterministic): paths = rollouts(self._eval_env, self._policy, self._max_path_length, self._eval_n_episodes, False) total_returns = [path['rewards'].sum() for path in paths] episode_lengths = [len(p['rewards']) for p in paths] logger.record_tabular('return-average', np.mean(total_returns)) logger.record_tabular('return-min', np.min(total_returns)) logger.record_tabular('return-max', np.max(total_returns)) logger.record_tabular('return-std', np.std(total_returns)) logger.record_tabular('episode-length-avg', np.mean(episode_lengths)) logger.record_tabular('episode-length-min', np.min(episode_lengths)) logger.record_tabular('episode-length-max', np.max(episode_lengths)) logger.record_tabular('episode-length-std', np.std(episode_lengths)) self._eval_env.log_diagnostics(paths) if self._eval_render: self._eval_env.render(paths) batch = self._pool.random_batch(self._batch_size) self.log_diagnostics(batch)
def _save_traces(self, filename):
    utils._make_dir(filename)

    obs_vec = []
    for z in range(self._num_skills):
        fixed_z_policy = FixedOptionPolicy(self._policy,
                                           self._num_skills, z)
        paths = rollouts(self._eval_env, fixed_z_policy,
                         self._max_path_length, n_paths=3,
                         render=False)
        obs_vec.append([path['observations'].tolist() for path in paths])

    with open(filename, 'w') as f:
        json.dump(obs_vec, f)
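
# The JSON written by _save_traces is a nested list: one entry per skill, each
# holding that skill's per-rollout observation sequences. A minimal sketch of
# reading such a file back for inspection; the 'traces.json' filename is only
# an example, not a path the code above produces by default.
import json

import numpy as np

with open('traces.json', 'r') as f:
    obs_vec = json.load(f)

for z, skill_paths in enumerate(obs_vec):
    # skill_paths is a list of observation sequences, one per rollout.
    lengths = [len(obs) for obs in skill_paths]
    print('skill %d: %d rollouts, mean length %.1f'
          % (z, len(skill_paths), np.mean(lengths)))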
def _get_best_single_option_policy(self):
    best_returns = float('-inf')
    best_z = None
    for z in range(self._num_skills):
        fixed_z_policy = FixedOptionPolicy(self._policy,
                                           self._num_skills, z)
        paths = rollouts(self._eval_env, fixed_z_policy,
                         self._max_path_length,
                         self._best_skill_n_rollouts, render=False)
        total_returns = np.mean([path['rewards'].sum() for path in paths])
        if total_returns > best_returns:
            best_returns = total_returns
            best_z = z
    return FixedOptionPolicy(self._policy, self._num_skills, best_z)
def collect_expert_trajectories(expert_snapshot, max_path_length):
    tf.logging.info('Collecting expert trajectories')
    with tf.Session() as sess:
        data = joblib.load(expert_snapshot)
        policy = data['policy']
        env = data['env']
        num_skills = (data['policy'].observation_space.flat_dim -
                      data['env'].spec.observation_space.flat_dim)

        traj_vec = []
        with policy.deterministic(True):
            for z in range(num_skills):
                fixed_z_policy = FixedOptionPolicy(policy, num_skills, z)
                new_paths = rollouts(env, fixed_z_policy,
                                     max_path_length, n_paths=1)
                path = new_paths[0]
                traj_vec.append(path)

    tf.reset_default_graph()
    return traj_vec
def _evaluate(self, epoch): """Perform evaluation for the current policy. :param epoch: The epoch number. :return: None """ if self._eval_n_episodes < 1: return with self._policy.deterministic(self._eval_deterministic): paths = rollouts( self._eval_env, self._policy, self.sampler._max_path_length, self._eval_n_episodes, ) total_returns = [path['rewards'].sum() for path in paths] total_violate = [ 10 / 3 * sum([x['cost'] for x in path['env_infos']]) for path in paths ] episode_lengths = [len(p['rewards']) for p in paths] logger.record_tabular('return-average', np.mean(total_returns)) logger.record_tabular('return-min', np.min(total_returns)) logger.record_tabular('return-max', np.max(total_returns)) logger.record_tabular('return-std', np.std(total_returns)) logger.record_tabular('violate-average', np.mean(total_violate)) logger.record_tabular('violate-min', np.min(total_violate)) logger.record_tabular('violate-max', np.max(total_violate)) logger.record_tabular('violate-std', np.std(total_violate)) logger.record_tabular('episode-length-avg', np.mean(episode_lengths)) logger.record_tabular('episode-length-min', np.min(episode_lengths)) logger.record_tabular('episode-length-max', np.max(episode_lengths)) logger.record_tabular('episode-length-std', np.std(episode_lengths)) self._eval_env.log_diagnostics(paths) if self._eval_render: self._eval_env.render(paths) iteration = epoch * self._epoch_length batch = self.sampler.random_batch() self.log_diagnostics(iteration, batch)
def _evaluate(self, epoch, initial_exploration_done, sub_level_policies, g):
    """Perform evaluation for the current policy.

    :param epoch: The epoch number.
    :return: None
    """
    if self._eval_n_episodes < 1:
        return

    with self._policy.deterministic(self._eval_deterministic):
        paths = rollouts(
            self._eval_env,
            self._policy,
            sub_level_policies,
            initial_exploration_done,
            self.sampler._max_path_length,
            self._eval_n_episodes,
            g,
        )

    total_returns = [path['rewards'].sum() for path in paths]
    episode_lengths = [len(p['rewards']) for p in paths]

    logger.record_tabular('return-average', np.mean(total_returns))
    logger.record_tabular('return-min', np.min(total_returns))
    logger.record_tabular('return-max', np.max(total_returns))
    logger.record_tabular('return-std', np.std(total_returns))
    logger.record_tabular('episode-length-avg', np.mean(episode_lengths))
    logger.record_tabular('episode-length-min', np.min(episode_lengths))
    logger.record_tabular('episode-length-max', np.max(episode_lengths))
    logger.record_tabular('episode-length-std', np.std(episode_lengths))

    self._eval_env.log_diagnostics(paths)
    if self._eval_render:
        self._eval_env.render(paths)

    iteration = epoch * self._epoch_length
    batch = self.sampler.random_batch()
    self.log_diagnostics(iteration, batch)
def get_best_skill(policy, env, num_skills, max_path_length):
    tf.logging.info('Finding best skill to finetune...')
    reward_list = []
    with policy.deterministic(True):
        for z in range(num_skills):
            fixed_z_policy = FixedOptionPolicy(policy, num_skills, z)
            new_paths = rollouts(env, fixed_z_policy,
                                 max_path_length, n_paths=2)
            total_returns = np.mean(
                [path['rewards'].sum() for path in new_paths])
            tf.logging.info('Reward for skill %d = %.3f', z, total_returns)
            reward_list.append(total_returns)

    best_z = np.argmax(reward_list)
    tf.logging.info('Best skill found: z = %d, reward = %.3f',
                    best_z, reward_list[best_z])
    return best_z
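
# A minimal sketch of how get_best_skill might be driven from a saved
# snapshot, mirroring the joblib loading pattern used elsewhere in these
# scripts. The snapshot path and the max_path_length value are illustrative
# only; FixedOptionPolicy and rollouts are the helpers defined above.
import joblib
import tensorflow as tf

with tf.Session() as sess:
    data = joblib.load('/path/to/snapshot.pkl')  # hypothetical path
    policy = data['policy']
    env = data['env']
    # Number of skills = size of the skill-augmented observation minus the
    # size of the raw environment observation.
    num_skills = (policy.observation_space.flat_dim -
                  env.spec.observation_space.flat_dim)

    best_z = get_best_skill(policy, env, num_skills, max_path_length=1000)

    # Freeze the chosen skill so downstream rollouts see a single option.
    best_policy = FixedOptionPolicy(policy, num_skills, best_z)
    eval_paths = rollouts(env, best_policy, 1000, n_paths=1)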
def _evaluate(self, epoch): """Perform evaluation for the current policy. :param epoch: The epoch number. :return: None """ if self._eval_n_episodes < 1: return with self._policy.deterministic(self._eval_deterministic): paths = rollouts( self._eval_env, self._policy, self.sampler._max_path_length, self._eval_n_episodes, ) total_returns = [path['rewards'].sum() for path in paths] episode_lengths = [len(p['rewards']) for p in paths] eval_obs = np.vstack([path['observations'] for path in paths]) cloud_state = np.vstack( [path['env_infos']['cloud_cpu_used'] for path in paths]) print(eval_obs.shape) print(cloud_state) ### before (for ~v5) # eval_edge_s = np.transpose(eval_obs)[:40].reshape(-1, 8, len(eval_obs)) # eval_cloud_s = np.transpose(eval_obs)[40:] # eval_edge_queue = eval_edge_s[2]-eval_edge_s[1] # shape (8, episode length) # eval_edge_cpu = eval_edge_s[3] # eval_workload = eval_edge_s[4] # eval_cloud_queue = eval_cloud_s[2] # eval_cloud_cpu = eval_cloud_s[3] # eval_edge_queue_avg = eval_edge_queue[:3].mean() # shape (,) # eval_cloud_queue_avg = eval_cloud_queue.mean() # float # # eval_edge_power = 10 * (40*eval_edge_cpu.sum(axis=0) * (10 ** 9) / 10) ** 3 # shape (5000,) # eval_cloud_power = 54 * (216*eval_cloud_cpu * (10 ** 9) / 54) ** 3 # shape (5000,) # # eval_edge_power_avg = eval_edge_power.mean() # eval_cloud_power_avg = eval_cloud_power.mean() # # eval_power = eval_edge_power_avg + eval_cloud_power_avg eval_edge_s = np.transpose(eval_obs)[:15].reshape(-1, 3, len(eval_obs)) # eval_edge_s = np.transpose(eval_obs)[:40].reshape(-1, 8, len(eval_obs)) # eval_edge_queue = eval_edge_s[2]-eval_edge_s[1] # shape (8, episode length) eval_edge_queue = eval_edge_s[1] # shape (8, episode length) eval_edge_cpu = eval_edge_s[3] eval_workload = eval_edge_s[4] eval_cloud_cpu = cloud_state[0] eval_edge_queue_avg = eval_edge_queue[:3].mean() # shape (,) # eval_edge_queue_avg = eval_edge_queue.mean() # shape (,) eval_edge_power = 10 * (40 * eval_edge_cpu.sum(axis=0) * (10**9) / 10)**3 # shape (5000,) eval_cloud_power = 54 * (216 * eval_cloud_cpu * (10**9) / 54)**3 # shape (5000,) eval_edge_power_avg = eval_edge_power.mean() eval_cloud_power_avg = eval_cloud_power.mean() eval_power = eval_edge_power_avg + eval_cloud_power_avg for i in range(int(len(eval_edge_power) / 100)): start_ = i * 100 end_ = (i + 1) * 100 logger.record_tabular("eval_q_1_%02d%02d" % (i, i + 1), np.mean(eval_edge_queue[0, start_:end_])) logger.record_tabular("eval_q_2_%02d%02d" % (i, i + 1), np.mean(eval_edge_queue[1, start_:end_])) logger.record_tabular("eval_q_3_%02d%02d" % (i, i + 1), np.mean(eval_edge_queue[2, start_:end_])) # logger.record_tabular("eval_q_4_%02d%02d"%(i,i+1), np.mean(eval_edge_queue[3, start_:end_])) # logger.record_tabular("eval_q_5_%02d%02d"%(i,i+1), np.mean(eval_edge_queue[4, start_:end_])) # logger.record_tabular("eval_q_6_%02d%02d"%(i,i+1), np.mean(eval_edge_queue[5, start_:end_])) # logger.record_tabular("eval_q_7_%02d%02d"%(i,i+1), np.mean(eval_edge_queue[6, start_:end_])) # logger.record_tabular("eval_q_8_%02d%02d"%(i,i+1), np.mean(eval_edge_queue[7, start_:end_])) # logger.record_tabular("eval_q_var_%02d%02d"%(i,i+1), np.mean(np.var(eval_edge_queue[:8, start_:end_], axis=0))) logger.record_tabular( "eval_q_var_%02d%02d" % (i, i + 1), np.mean(np.var(eval_edge_queue[:3, start_:end_], axis=0))) logger.record_tabular("eval_q_1_all", np.mean(eval_edge_queue[0, :])) logger.record_tabular("eval_q_2_all", np.mean(eval_edge_queue[1, :])) logger.record_tabular("eval_q_3_all", np.mean(eval_edge_queue[2, :])) # 
logger.record_tabular("eval_q_4_all", np.mean(eval_edge_queue[3,:])) # logger.record_tabular("eval_q_5_all", np.mean(eval_edge_queue[4,:])) # logger.record_tabular("eval_q_6_all", np.mean(eval_edge_queue[5,:])) # logger.record_tabular("eval_q_7_all", np.mean(eval_edge_queue[6,:])) # logger.record_tabular("eval_q_8_all", np.mean(eval_edge_queue[7,:])) # logger.record_tabular("eval_q_var_all", np.mean(np.var(eval_edge_queue[:8,:],axis=0))) logger.record_tabular("eval_q_var_all", np.mean(np.var(eval_edge_queue[:3, :], axis=0))) logger.record_tabular("eval_edge_queue_avg", eval_edge_queue_avg) logger.record_tabular("eval_power", eval_power) logger.record_tabular('return-average', np.mean(total_returns)) logger.record_tabular('return-average', np.mean(total_returns)) logger.record_tabular('return-average', np.mean(total_returns)) logger.record_tabular('return-min', np.min(total_returns)) logger.record_tabular('return-max', np.max(total_returns)) logger.record_tabular('return-std', np.std(total_returns)) logger.record_tabular('episode-length-avg', np.mean(episode_lengths)) logger.record_tabular('episode-length-min', np.min(episode_lengths)) logger.record_tabular('episode-length-max', np.max(episode_lengths)) logger.record_tabular('episode-length-std', np.std(episode_lengths)) self._eval_env.log_diagnostics(paths) if self._eval_render: self._eval_env.render(paths) iteration = epoch * self._epoch_length batch = self.sampler.random_batch() self.log_diagnostics(iteration, batch)
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('file', type=str, help='Path to the snapshot file.')
    parser.add_argument('--max-path-length', '-l', type=int, default=100)
    parser.add_argument('--speedup', '-s', type=float, default=1)
    parser.add_argument('--deterministic', '-d', dest='deterministic',
                        action='store_true')
    parser.add_argument('--no-deterministic', '-nd', dest='deterministic',
                        action='store_false')
    parser.add_argument('--separate_videos', type=bool, default=False)
    parser.set_defaults(deterministic=True)

    # unity_env args
    parser.add_argument('--idx', type=int, default=0)
    parser.add_argument('--no_graphics', type=bool, default=False)

    args = parser.parse_args()

    filename = os.path.splitext(args.file)[0] + '.avi'
    best_filename = os.path.splitext(args.file)[0] + '_best.avi'
    worst_filename = os.path.splitext(args.file)[0] + '_worst.avi'

    path_list = []
    reward_list = []

    with tf.Session() as sess:
        data = joblib.load(args.file)
        policy = data['policy']
        env = data['env']
        num_skills = (data['policy'].observation_space.flat_dim -
                      data['env'].spec.observation_space.flat_dim)
        with policy.deterministic(args.deterministic):
            for z in range(num_skills):
                fixed_z_policy = FixedOptionPolicy(policy, num_skills, z)
                new_paths = rollouts(env, fixed_z_policy,
                                     args.max_path_length, n_paths=1,
                                     render=True, render_mode='rgb_array')
                path_list.append(new_paths)
                total_returns = np.mean(
                    [path['rewards'].sum() for path in new_paths])
                reward_list.append(total_returns)

                if args.separate_videos:
                    base = os.path.splitext(args.file)[0]
                    end = '_skill_%02d.avi' % z
                    skill_filename = base + end
                    utils._save_video(new_paths, skill_filename)

                import csv
                file_path = args.file.split('/')
                file_path = file_path[-1].split('.')[0]
                file_path = './data/' + file_path
                if not os.path.exists(file_path):
                    os.mkdir(file_path)
                print(file_path)
                with open(file_path + '/path%02d.csv' % z, 'w',
                          newline='') as csvfile:
                    spamwriter = csv.writer(csvfile, delimiter=' ',
                                            quotechar='|',
                                            quoting=csv.QUOTE_MINIMAL)
                    spamwriter.writerow(
                        ['X', '-X', 'Y', '-Y', 'X_speed', 'Y_speed'])
                    for ob in path_list[-1][0]['observations']:
                        spamwriter.writerow(ob)

        if not args.separate_videos:
            paths = [path for paths in path_list for path in paths]
            utils._save_video(paths, filename)

        print('Best reward: %d' % np.max(reward_list))
        print('Worst reward: %d' % np.min(reward_list))

        # Record extra long videos for best and worst skills:
        best_z = np.argmax(reward_list)
        worst_z = np.argmin(reward_list)
        for (z, filename) in [(best_z, best_filename),
                              (worst_z, worst_filename)]:
            fixed_z_policy = FixedOptionPolicy(policy, num_skills, z)
            new_paths = rollouts(env, fixed_z_policy,
                                 3 * args.max_path_length, n_paths=1,
                                 render=True, render_mode='rgb_array')
            utils._save_video(new_paths, filename)

        env.terminate()
def train(self):
    """
    CG: the function that conducts ensemble training.
    :return:
    """
    # Set up parameters for the training process.
    self._n_epochs = self._base_ac_params['n_epochs']
    self._epoch_length = self._base_ac_params['epoch_length']
    self._n_train_repeat = self._base_ac_params['n_train_repeat']
    self._n_initial_exploration_steps = self._base_ac_params[
        'n_initial_exploration_steps']
    self._eval_render = self._base_ac_params['eval_render']
    self._eval_n_episodes = self._base_ac_params['eval_n_episodes']
    self._eval_deterministic = self._base_ac_params['eval_deterministic']

    # Set up the evaluation environment.
    if self._eval_n_episodes > 0:
        with tf.variable_scope("low_level_policy", reuse=True):
            self._eval_env = deep_clone(self._env)

    # Set up the tensor flow session.
    self._sess = tf_utils.get_default_session()

    # Import required libraries for training.
    import random
    import math
    import operator
    import numpy as np

    # Initialize the sampler.
    alg_ins = random.choice(self._alg_instances)
    self._sampler.initialize(self._env, alg_ins[0].policy, self._pool)

    # Perform the training/evaluation process.
    num_episode = 0.
    with self._sess.as_default():
        gt.rename_root('RLAlgorithm')
        gt.reset()
        gt.set_def_unique(False)

        for epoch in gt.timed_for(range(self._n_epochs + 1),
                                  save_itrs=True):
            logger.push_prefix('Epoch #%d | ' % epoch)

            for t in range(self._epoch_length):
                isEpisodeEnd = self._sampler.sample()

                # If an episode is ended, we need to update performance
                # statistics for each AC instance and pick randomly another
                # AC instance for next episode of exploration.
                if isEpisodeEnd:
                    num_episode = num_episode + 1.
                    alg_ins[1] = 0.9 * alg_ins[
                        1] + 0.1 * self._sampler._last_path_return
                    alg_ins[2] = alg_ins[2] + 1.

                    if self._use_ucb:
                        # Select an algorithm instance based on UCB.
                        selected = False
                        for ains in self._alg_instances:
                            if ains[2] < 1.:
                                alg_ins = ains
                                selected = True
                                break
                            else:
                                ains[3] = ains[1] + math.sqrt(
                                    2.0 * math.log(num_episode) / ains[2])

                        if not selected:
                            alg_ins = max(self._alg_instances,
                                          key=operator.itemgetter(3))
                    else:
                        # Select an algorithm instance uniformly at random.
                        alg_ins = random.choice(self._alg_instances)

                    self._sampler.set_policy(alg_ins[0].policy)

                if not self._sampler.batch_ready():
                    continue
                gt.stamp('sample')

                # Perform training over all AC instances.
                for i in range(self._n_train_repeat):
                    batch = self._sampler.random_batch()
                    for ains in self._alg_instances:
                        ains[0]._do_training(
                            iteration=t + epoch * self._epoch_length,
                            batch=batch)
                gt.stamp('train')

            # Perform evaluation after one full epoch of training is completed.
            if self._eval_n_episodes < 1:
                continue

            if self._evaluation_strategy == 'ensemble':
                # Use a whole ensemble of AC instances for evaluation.
                paths = rollouts(self._eval_env, self,
                                 self._sampler._max_path_length,
                                 self._eval_n_episodes)
            elif self._evaluation_strategy == 'best-policy':
                # Choose the AC instance with the highest observed
                # performance so far for evaluation.
                eval_alg_ins = max(self._alg_instances,
                                   key=operator.itemgetter(1))
                with eval_alg_ins[0].policy.deterministic(
                        self._eval_deterministic):
                    paths = rollouts(self._eval_env,
                                     eval_alg_ins[0].policy,
                                     self._sampler._max_path_length,
                                     self._eval_n_episodes)
            else:
                paths = None

            if paths is not None:
                total_returns = [path['rewards'].sum() for path in paths]
                episode_lengths = [len(p['rewards']) for p in paths]
                logger.record_tabular('return-average',
                                      np.mean(total_returns))
                logger.record_tabular('return-min', np.min(total_returns))
                logger.record_tabular('return-max', np.max(total_returns))
                logger.record_tabular('return-std', np.std(total_returns))
                logger.record_tabular('episode-length-avg',
                                      np.mean(episode_lengths))
                logger.record_tabular('episode-length-min',
                                      np.min(episode_lengths))
                logger.record_tabular('episode-length-max',
                                      np.max(episode_lengths))
                logger.record_tabular('episode-length-std',
                                      np.std(episode_lengths))

                self._eval_env.log_diagnostics(paths)
                if self._eval_render:
                    self._eval_env.render(paths)

            # Produce log info after each episode of training and evaluation.
            times_itrs = gt.get_times().stamps.itrs
            eval_time = times_itrs['eval'][-1] if epoch > 1 else 0
            total_time = gt.get_times().total
            logger.record_tabular('time-train', times_itrs['train'][-1])
            logger.record_tabular('time-eval', eval_time)
            logger.record_tabular('time-sample', times_itrs['sample'][-1])
            logger.record_tabular('time-total', total_time)
            logger.record_tabular('epoch', epoch)

            self._sampler.log_diagnostics()
            logger.dump_tabular(with_prefix=False)
            logger.pop_prefix()

            gt.stamp('eval')

    # Terminate the sampler after the training process is completed.
    self._sampler.terminate()
def sample_skills_to_bd(self, **kwargs):
    """ Evaluate all the latent skills and extract the behaviour descriptors """
    list_traj_main = []
    list_traj_aux = []
    list_outcomes = []
    list_skill_bd = []
    log_freq = max(self._num_skills // 10, 10)
    eval_time = time.time()
    logger.log("EVALUATING: {} skills.".format(self._num_skills))

    ### PARALLELIZE THIS
    for z in range(self._num_skills):
        # Make policy deterministic
        fixed_z_policy = DeterministicFixedOptionPolicy(
            self._policy, self._num_skills, z)

        # Evaluate skill
        self._eval_env._wrapped_env.env.initialize(seed_task=SEED_TASK)
        paths = rollouts(env=self._eval_env,
                         policy=fixed_z_policy,
                         path_length=self._max_path_length,
                         n_paths=1,
                         render=False)

        # Extract trajectory from paths
        traj_main = paths[0]['env_infos']['position']
        traj_aux = paths[0]['env_infos']['position_aux']
        list_traj_main.append(traj_main)
        list_traj_aux.append(traj_aux)

        # Extract outcomes from paths
        trial_outcome = self._eval_env._wrapped_env.env.finalize(
            state=paths[0]['last_obs'],
            rew_list=paths[0]['rewards'],
            traj=traj_main,
            traj_aux=traj_aux)
        list_outcomes.append(trial_outcome)

        # Extract and convert bd from paths
        self.bd_metric.restart()
        if self.bd_metric.metric_name == 'contact_grid':
            for idict in paths[0]['env_infos']['contact_objects']:
                self.bd_metric.update({'contact_objects': idict})
        trial_metric = self.bd_metric.calculate(traj=traj_main,
                                                traj_aux=traj_aux)
        list_skill_bd.append(np.argmax(trial_metric))

        if not z % log_freq:
            logger.log("\teval {:5} skills - {} behaviours.".format(
                z, len(np.unique(list_skill_bd))))

    eval_time = time.time() - eval_time

    # Extract unique data
    unique_bds, unique_idx = np.unique(list_skill_bd, return_index=True)
    n_bd = len(unique_bds)
    unique_outcomes = np.array(list_outcomes)[unique_idx]
    unique_traj_main = itemgetter(*unique_idx)(list_traj_main)
    unique_traj_aux = itemgetter(*unique_idx)(list_traj_aux)

    # Save to file at exact point: num epochs, num episodes
    self._write_discovery(eval_time=eval_time,
                          unique_bds=unique_bds,
                          unique_outcomes=unique_outcomes,
                          **kwargs)

    # Save data at the end
    unique_bds_1hot = np.zeros((n_bd, self.bd_metric.metric_size))
    unique_bds_1hot[np.arange(n_bd), unique_bds] = 1
    self._save_dataset(
        **{
            "outcomes": unique_outcomes,
            "traj_main": unique_traj_main,
            "traj_aux": unique_traj_aux,
            "metric_bd": unique_bds_1hot
        })