class MetaRLAlgorithm(metaclass=abc.ABCMeta):
    def __init__(
            self,
            env,
            policy,
            train_tasks,
            eval_tasks,
            meta_batch=64,
            num_iterations=100,
            num_train_steps_per_itr=1000,
            num_tasks_sample=100,
            num_steps_per_task=100,
            num_evals=10,
            num_steps_per_eval=1000,
            batch_size=1024,
            embedding_batch_size=1024,
            embedding_mini_batch_size=1024,
            max_path_length=1000,
            discount=0.99,
            replay_buffer_size=1000000,
            reward_scale=1,
            train_embedding_source='posterior_only',
            eval_embedding_source='initial_pool',
            eval_deterministic=True,
            render=False,
            save_replay_buffer=False,
            save_algorithm=False,
            save_environment=False,
            obs_emb_dim=0):
        """
        Base class for meta-RL algorithms.

        :param env: training env
        :param policy: policy conditioned on a latent variable z that the RL
            algorithm is responsible for feeding in
        :param train_tasks: list of tasks used for training
        :param eval_tasks: list of tasks used for eval
        :param meta_batch: number of tasks used per meta-update
        :param num_iterations: number of meta-training iterations
        :param num_train_steps_per_itr: number of meta-updates performed per iteration
        :param num_tasks_sample: number of train tasks to sample data from per iteration
        :param num_steps_per_task: number of transitions to collect per task
        :param num_evals: number of independent evaluation runs, with separate task encodings
        :param num_steps_per_eval: number of transitions to sample for evaluation
        :param batch_size: size of batches used to compute the RL update
        :param embedding_batch_size: size of batches used to compute the embedding
        :param embedding_mini_batch_size: size of batch used for the encoder update
        :param max_path_length: max episode length
        :param discount: discount factor
        :param replay_buffer_size: max replay buffer size
        :param reward_scale: scaling applied to rewards
        :param render: whether to render the environment
        :param save_replay_buffer: whether to include the replay buffer in snapshots
        :param save_algorithm: whether to include the algorithm object in snapshots
        :param save_environment: whether to include the environment in snapshots
        :param obs_emb_dim: dimension of the per-step observation embedding
            stored alongside observations in the replay buffers
        """
        self.env = env
        self.policy = policy
        # A different policy could be used purely for exploration rather than
        # also solving tasks; currently not used.
        self.exploration_policy = policy
        self.train_tasks = train_tasks
        self.eval_tasks = eval_tasks
        self.meta_batch = meta_batch
        self.num_iterations = num_iterations
        self.num_train_steps_per_itr = num_train_steps_per_itr
        self.num_tasks_sample = num_tasks_sample
        self.num_steps_per_task = num_steps_per_task
        self.num_evals = num_evals
        self.num_steps_per_eval = num_steps_per_eval
        self.batch_size = batch_size
        self.embedding_batch_size = embedding_batch_size
        self.embedding_mini_batch_size = embedding_mini_batch_size
        self.max_path_length = max_path_length
        self.discount = discount
        # Per-task buffer size (note the hard cap of 1000 transitions per task).
        self.replay_buffer_size = min(
            int(replay_buffer_size / len(train_tasks)), 1000)
        self.reward_scale = reward_scale
        self.train_embedding_source = train_embedding_source
        # TODO: add options for computing embeddings on train tasks too
        self.eval_embedding_source = eval_embedding_source
        self.eval_deterministic = eval_deterministic
        self.render = render
        self.save_replay_buffer = save_replay_buffer
        self.save_algorithm = save_algorithm
        self.save_environment = save_environment

        self.eval_sampler = InPlacePathSampler(
            env=env,
            policy=policy,
            max_samples=self.num_steps_per_eval,
            max_path_length=self.max_path_length,
        )

        # Separate replay buffers for:
        # - training the RL update
        # - training the encoder
        # - testing the encoder
        self.replay_buffer = MultiTaskReplayBuffer(self.replay_buffer_size,
                                                   env,
                                                   self.train_tasks,
                                                   state_dim=obs_emb_dim)
        self.enc_replay_buffer = MultiTaskReplayBuffer(self.replay_buffer_size,
                                                       env,
                                                       self.train_tasks,
                                                       state_dim=obs_emb_dim)
        self.eval_enc_replay_buffer = MultiTaskReplayBuffer(
            self.replay_buffer_size, env, self.eval_tasks, state_dim=obs_emb_dim)

        self._n_env_steps_total = 0
        self._n_train_steps_total = 0
        self._n_rollouts_total = 0
        self._do_train_time = 0
        self._epoch_start_time = None
        self._algo_start_time = None
        self._old_table_keys = None
        self._current_path_builder = PathBuilder()
        self._exploration_paths = []

    def make_exploration_policy(self, policy):
        return policy

    def make_eval_policy(self, policy):
        return policy

    def sample_task(self, is_eval=False):
        '''
        Sample a task index uniformly at random.
        '''
        if is_eval:
            idx = np.random.randint(len(self.eval_tasks))
        else:
            idx = np.random.randint(len(self.train_tasks))
        return idx

    def train(self):
        '''
        Meta-training loop.
        '''
        self.pretrain()
        params = self.get_epoch_snapshot(-1)
        logger.save_itr_params(-1, params)
        gt.reset()
        gt.set_def_unique(False)
        self._current_path_builder = PathBuilder()
        self.train_obs = self._start_new_rollout()

        # At each iteration, we first collect data from tasks, perform
        # meta-updates, then try to evaluate.
        for it_ in gt.timed_for(
                range(self.num_iterations),
                save_itrs=True,
        ):
            self._start_epoch(it_)
            self.training_mode(True)
            if it_ == 0:
                print('collecting initial pool of data for train and eval')
                # temp for evaluating
                for idx in self.train_tasks:
                    print('train task', idx)
                    self.task_idx = idx
                    self.env.reset_task(idx)
                    self.collect_data_sampling_from_prior(
                        num_samples=self.max_path_length * 10,
                        resample_z_every_n=self.max_path_length,
                        eval_task=False)
                """
                for idx in self.eval_tasks:
                    self.task_idx = idx
                    self.env.reset_task(idx)
                    # TODO: make number of initial trajectories a parameter
                    self.collect_data_sampling_from_prior(
                        num_samples=self.max_path_length * 20,
                        resample_z_every_n=self.max_path_length,
                        eval_task=True)
                """
            # Sample data from train tasks.
            for i in range(self.num_tasks_sample):
                idx = np.random.randint(len(self.train_tasks))
                self.task_idx = idx
                self.env.reset_task(idx)

                # TODO: there may be more permutations of sampling/adding to
                # the encoding buffer we may wish to try.
                if self.train_embedding_source == 'initial_pool':
                    # Embeddings are computed using only the initial pool of
                    # data; sample data from the posterior to train the RL
                    # algorithm.
                    self.collect_data_from_task_posterior(
                        idx=idx,
                        num_samples=self.num_steps_per_task,
                        add_to_enc_buffer=False)
                elif self.train_embedding_source == 'posterior_only':
                    self.collect_data_from_task_posterior(
                        idx=idx,
                        num_samples=self.num_steps_per_task,
                        eval_task=False,
                        add_to_enc_buffer=True)
                elif self.train_embedding_source == 'online_exploration_trajectories':
                    # Embeddings are computed using only data collected with
                    # the prior; sample data from the posterior to train the
                    # RL algorithm.
                    self.enc_replay_buffer.task_buffers[idx].clear()
                    # Resample using the current policy, conditioned on the prior.
                    self.collect_data_sampling_from_prior(
                        num_samples=self.num_steps_per_task,
                        resample_z_every_n=self.max_path_length,
                        add_to_enc_buffer=True)
                    self.env.reset_task(idx)
                    self.collect_data_from_task_posterior(
                        idx=idx,
                        num_samples=self.num_steps_per_task,
                        add_to_enc_buffer=False,
                        viz=True)
                elif self.train_embedding_source == 'online_on_policy_trajectories':
                    # Sample from the prior, then sample more from the
                    # posterior; embeddings are computed from both.
                    self.enc_replay_buffer.task_buffers[idx].clear()
                    self.collect_data_online(
                        idx=idx,
                        num_samples=self.num_steps_per_task,
                        add_to_enc_buffer=True)
                else:
                    raise Exception(
                        "Invalid option for computing train embedding {}".
                        format(self.train_embedding_source))

            # Sample train tasks and compute gradient updates on parameters.
            for train_step in range(self.num_train_steps_per_itr):
                indices = np.random.choice(self.train_tasks, self.meta_batch)
                self._do_training(indices, train_step)
                self._n_train_steps_total += 1
            gt.stamp('train')

            # self.training_mode(False)

            # eval
            self._try_to_eval(it_)
            gt.stamp('eval')

            self._end_epoch()

    def pretrain(self):
        """
        Do anything before the main training phase.
        """
        pass

    def sample_z_from_prior(self):
        """
        Sample z from the prior distribution, which can be either a delta
        function at 0 or a standard Gaussian, depending on whether we use the
        information bottleneck.

        :return: latent z as a numpy array
        """
        pass

    def sample_z_from_posterior(self, idx, eval_task):
        """
        Sample z from the posterior distribution given data from task `idx`,
        where the data comes from the encoding buffer.

        :param idx: task index from which to compute the posterior
        :param eval_task: whether or not the task is an eval task
        :return: latent z as a numpy array
        """
        pass

    # TODO: maybe find a better name for resample_z_every_n?
    def collect_data_sampling_from_prior(self,
                                         num_samples=1,
                                         resample_z_every_n=None,
                                         eval_task=False,
                                         add_to_enc_buffer=True):
        # Do not resample z if resample_z_every_n is None.
        if resample_z_every_n is None:
            self.policy.clear_z()
            self.collect_data(self.policy,
                              num_samples=num_samples,
                              eval_task=eval_task,
                              add_to_enc_buffer=add_to_enc_buffer)
        else:
            # Collect data in batches of resample_z_every_n until done.
            while num_samples > 0:
                self.collect_data_sampling_from_prior(
                    num_samples=min(resample_z_every_n, num_samples),
                    resample_z_every_n=None,
                    eval_task=eval_task,
                    add_to_enc_buffer=add_to_enc_buffer)
                num_samples -= resample_z_every_n

    def collect_data_from_task_posterior(self,
                                         idx,
                                         num_samples=1,
                                         resample_z_every_n=None,
                                         eval_task=False,
                                         add_to_enc_buffer=True,
                                         viz=False):
        # Do not resample z if resample_z_every_n is None.
        if resample_z_every_n is None:
            self.sample_z_from_posterior(idx, eval_task=eval_task)
            self.collect_data(self.policy,
                              num_samples=num_samples,
                              eval_task=eval_task,
                              add_to_enc_buffer=add_to_enc_buffer,
                              viz=viz)
        else:
            # Collect data in batches of resample_z_every_n until done.
            while num_samples > 0:
                self.collect_data_from_task_posterior(
                    idx=idx,
                    num_samples=min(resample_z_every_n, num_samples),
                    resample_z_every_n=None,
                    eval_task=eval_task,
                    add_to_enc_buffer=add_to_enc_buffer,
                    viz=viz)
                num_samples -= resample_z_every_n

    # Split the number of samples between the prior and the posterior.
    def collect_data_online(self, idx, num_samples, eval_task=False,
                            add_to_enc_buffer=True):
        self.collect_data_sampling_from_prior(
            num_samples=num_samples,
            resample_z_every_n=self.max_path_length,
            eval_task=eval_task,
            add_to_enc_buffer=True)
        self.env.reset_task(idx)
        self.collect_data_from_task_posterior(
            idx=idx,
            num_samples=num_samples,
            resample_z_every_n=self.max_path_length,
            eval_task=eval_task,
            add_to_enc_buffer=add_to_enc_buffer,
            viz=True)

    # TODO: since switching tasks now resets the environment, we are not
    # correctly handling episode termination. We also aren't using the
    # episodes anywhere, but we should probably change this to gather paths
    # until we have more samples than num_samples, so that every episode
    # cleanly terminates when intended.
    # @profile
    def collect_data(self,
                     agent,
                     num_samples=1,
                     max_resets=None,
                     eval_task=False,
                     add_to_enc_buffer=True,
                     viz=False):
        '''
        Collect data from the current env in batch mode with the given policy.
        '''
        images = []
        env_time = self.env.time
        rews = []
        terms = []
        n_resets = 0
        for step in range(num_samples):
            action, agent_info = self._get_action_and_info(
                agent, self.train_obs)
            if self.render:
                self.env.render()
            next_ob, raw_reward, terminal, env_info = self.env.step(action)
            if viz:
                images.append(next_ob)
            reward = raw_reward
            rews += [reward]
            terms += [terminal]
            terminal = np.array([terminal])
            reward = np.array([reward])
            self._handle_step(
                self.task_idx,
                np.concatenate(
                    [self.train_obs.flatten()[None], agent_info['obs_emb']],
                    axis=-1),
                action,
                reward,
                np.concatenate([
                    next_ob.flatten()[None],
                    torch.zeros(agent_info['obs_emb'].shape)
                ],
                               axis=-1),
                terminal,
                eval_task=eval_task,
                add_to_enc_buffer=add_to_enc_buffer,
                agent_info=agent_info,
                env_info=env_info,
            )
            # TODO: use masking here to handle terminal episodes.
            if terminal or len(
                    self._current_path_builder) >= self.max_path_length:
                self._handle_rollout_ending(eval_task=eval_task)
                self.train_obs = self._start_new_rollout()
                n_resets += 1
                if step + self.max_path_length > num_samples - 1:
                    break
                if max_resets is not None and n_resets > max_resets:
                    break
            else:
                self.train_obs = next_ob

        if viz and np.random.random() < 0.3:
            vis.images(np.stack(images)[:, -1:])
            vis.line(np.array([rews, terms]).T,
                     opts=dict(width=400, height=320))
            vis.text('', opts=dict(width=10000, height=5))

        if not eval_task:
            self._n_env_steps_total += num_samples
        gt.stamp('sample')

    def _try_to_eval(self, epoch):
        logger.save_extra_data(self.get_extra_data_to_save(epoch))
        if self._can_evaluate():
            self.evaluate(epoch)

            params = self.get_epoch_snapshot(epoch)
            logger.save_itr_params(epoch, params)
            table_keys = logger.get_table_key_set()
            if self._old_table_keys is not None:
                assert table_keys == self._old_table_keys, (
                    "Table keys cannot change from iteration to iteration.")
            self._old_table_keys = table_keys

            logger.record_tabular(
                "Number of train steps total",
                self._n_train_steps_total,
            )
            logger.record_tabular(
                "Number of env steps total",
                self._n_env_steps_total,
            )
            logger.record_tabular(
                "Number of rollouts total",
                self._n_rollouts_total,
            )

            times_itrs = gt.get_times().stamps.itrs
            train_time = times_itrs['train'][-1]
            sample_time = times_itrs['sample'][-1]
            eval_time = times_itrs['eval'][-1] if epoch > 0 else 0
            epoch_time = train_time + sample_time + eval_time
            total_time = gt.get_times().total

            logger.record_tabular('Train Time (s)', train_time)
            logger.record_tabular('(Previous) Eval Time (s)', eval_time)
            logger.record_tabular('Sample Time (s)', sample_time)
            logger.record_tabular('Epoch Time (s)', epoch_time)
            logger.record_tabular('Total Train Time (s)', total_time)

            logger.record_tabular("Epoch", epoch)
            logger.dump_tabular(with_prefix=False, with_timestamp=False)
        else:
            logger.log("Skipping eval for now.")

    def _can_evaluate(self):
        """
        One annoying thing about the logger table is that the keys at each
        iteration need to be exactly the same. So unless you can compute
        everything, skip evaluation.

        A common example of why you might want to skip evaluation is that at
        the beginning of training, you may not have enough data for
        validation and training sets.
        :return:
        """
        return (self.replay_buffer.num_steps_can_sample(self.task_idx) >=
                self.batch_size)

    def _can_train(self):
        return all([
            self.replay_buffer.num_steps_can_sample(idx) >= self.batch_size
            for idx in self.train_tasks
        ])

    def _get_action_and_info(self, agent, observation):
        """
        Get an action to take in the environment.

        :param observation:
        :return:
        """
        agent.set_num_steps_total(self._n_env_steps_total)
        return agent.get_action(observation, )

    def _start_epoch(self, epoch):
        self._epoch_start_time = time.time()
        self._exploration_paths = []
        self._do_train_time = 0
        logger.push_prefix('Iteration #%d | ' % epoch)

    def _end_epoch(self):
        logger.log("Epoch Duration: {0}".format(time.time() -
                                                self._epoch_start_time))
        logger.log("Started Training: {0}".format(self._can_train()))
        logger.pop_prefix()

    def _start_new_rollout(self):
        ret = self.env.reset()
        if isinstance(ret, tuple):
            ret = ret[0]
        return ret

    # not used
    def _handle_path(self, path):
        """
        Naive implementation: just loop through each transition.

        :param path:
        :return:
        """
        for (ob, action, reward, next_ob, terminal, agent_info,
             env_info) in zip(
                 path["observations"],
                 path["actions"],
                 path["rewards"],
                 path["next_observations"],
                 path["terminals"],
                 path["agent_infos"],
                 path["env_infos"],
             ):
            self._handle_step(
                ob.reshape(-1),
                action,
                reward,
                next_ob.reshape(-1),
                terminal,
                agent_info=agent_info,
                env_info=env_info,
            )
        self._handle_rollout_ending()

    def _handle_step(
        self,
        task_idx,
        observation,
        action,
        reward,
        next_observation,
        terminal,
        agent_info,
        env_info,
        eval_task=False,
        add_to_enc_buffer=True,
    ):
        """
        Implement anything that needs to happen after every step.

        :return:
        """
        self._current_path_builder.add_all(
            task=task_idx,
            observations=observation,
            actions=action,
            rewards=reward,
            next_observations=next_observation,
            terminals=terminal,
            agent_infos=agent_info,
            env_infos=env_info,
        )
        if eval_task:
            self.eval_enc_replay_buffer.add_sample(
                task=task_idx,
                observation=observation,
                action=action,
                reward=reward,
                terminal=terminal,
                next_observation=next_observation,
                agent_info=agent_info,
                env_info=env_info,
            )
        else:
            self.replay_buffer.add_sample(
                task=task_idx,
                observation=observation,
                action=action,
                reward=reward,
                terminal=terminal,
                next_observation=next_observation,
                agent_info=agent_info,
                env_info=env_info,
            )
            if add_to_enc_buffer:
                self.enc_replay_buffer.add_sample(
                    task=task_idx,
                    observation=observation,
                    action=action,
                    reward=reward,
                    terminal=terminal,
                    next_observation=next_observation,
                    agent_info=agent_info,
                    env_info=env_info,
                )

    def _handle_rollout_ending(self, eval_task=False):
        """
        Implement anything that needs to happen after every rollout.
        """
        if eval_task:
            self.eval_enc_replay_buffer.terminate_episode(self.task_idx)
        else:
            self.replay_buffer.terminate_episode(self.task_idx)
            self.enc_replay_buffer.terminate_episode(self.task_idx)
        self._n_rollouts_total += 1
        if len(self._current_path_builder) > 0:
            # self._exploration_paths.append(
            #     self._current_path_builder.get_all_stacked()
            # )
            self._current_path_builder = PathBuilder()

    def get_epoch_snapshot(self, epoch):
        data_to_save = dict(
            epoch=epoch,
            exploration_policy=self.exploration_policy,
        )
        if self.save_environment:
            data_to_save['env'] = self.env
        return data_to_save

    def get_extra_data_to_save(self, epoch):
        """
        Save things that shouldn't be saved every snapshot but rather
        overwritten every time.
        :param epoch:
        :return:
        """
        if self.render:
            self.env.render(close=True)
        data_to_save = dict(epoch=epoch, )
        if self.save_environment:
            data_to_save['env'] = self.env
        if self.save_replay_buffer:
            data_to_save['replay_buffer'] = self.replay_buffer
        if self.save_algorithm:
            data_to_save['algorithm'] = self
        return data_to_save

    @abc.abstractmethod
    def training_mode(self, mode):
        """
        Set training mode to `mode`.

        :param mode: If True, training will happen (e.g. set the dropout
            probabilities to not all ones).
        """
        pass

    @abc.abstractmethod
    def evaluate(self, epoch):
        """
        Evaluate the policy, e.g. save/print progress.

        :param epoch:
        :return:
        """
        pass

    @abc.abstractmethod
    def _do_training(self):
        """
        Perform some update, e.g. perform one gradient step.

        :return:
        """
        pass
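
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original codebase): a minimal concrete
# subclass showing how the abstract hooks of MetaRLAlgorithm are typically
# wired together. The class name `ExampleMetaRL`, the `self.networks` list,
# and the method bodies below are hypothetical placeholders; a real learner
# (e.g. a SAC-style update) would sample per-task batches from
# `self.replay_buffer` / `self.enc_replay_buffer` and take gradient steps.
#
# class ExampleMetaRL(MetaRLAlgorithm):
#     def training_mode(self, mode):
#         # Toggle train/eval mode on all trainable modules.
#         for net in self.networks:
#             net.train(mode)
#
#     def evaluate(self, epoch):
#         # Roll out the z-conditioned policy on eval tasks via
#         # self.eval_sampler and log average returns.
#         pass
#
#     def _do_training(self, indices, train_step):
#         # For each task index in `indices`, sample a batch of size
#         # self.batch_size from self.replay_buffer and an embedding batch
#         # from self.enc_replay_buffer, then apply one RL + encoder update.
#         pass
# ---------------------------------------------------------------------------
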
class OfflineMetaRLAlgorithm(metaclass=abc.ABCMeta):
    def __init__(self,
                 env,
                 agent,
                 train_tasks,
                 eval_tasks,
                 goal_radius,
                 eval_deterministic=True,
                 render=False,
                 render_eval_paths=False,
                 plotter=None,
                 **kwargs):
        """
        :param env: training env
        :param agent: agent that is conditioned on a latent variable z that
            the RL algorithm is responsible for feeding in
        :param train_tasks: list of tasks used for training
        :param eval_tasks: list of tasks used for eval
        :param goal_radius: reward threshold for defining sparse rewards

        See the default experiment config file for descriptions of the rest
        of the arguments.
        """
        self.env = env
        self.agent = agent
        self.train_tasks = train_tasks
        self.eval_tasks = eval_tasks
        self.goal_radius = goal_radius

        self.meta_batch = kwargs['meta_batch']
        self.batch_size = kwargs['batch_size']
        self.num_iterations = kwargs['num_iterations']
        self.num_train_steps_per_itr = kwargs['num_train_steps_per_itr']
        self.num_initial_steps = kwargs['num_initial_steps']
        self.num_tasks_sample = kwargs['num_tasks_sample']
        self.num_steps_prior = kwargs['num_steps_prior']
        self.num_steps_posterior = kwargs['num_steps_posterior']
        self.num_extra_rl_steps_posterior = kwargs[
            'num_extra_rl_steps_posterior']
        self.num_evals = kwargs['num_evals']
        self.num_steps_per_eval = kwargs['num_steps_per_eval']
        self.embedding_batch_size = kwargs['embedding_batch_size']
        self.embedding_mini_batch_size = kwargs['embedding_mini_batch_size']
        self.max_path_length = kwargs['max_path_length']
        self.discount = kwargs['discount']
        self.replay_buffer_size = kwargs['replay_buffer_size']
        self.reward_scale = kwargs['reward_scale']
        self.update_post_train = kwargs['update_post_train']
        self.num_exp_traj_eval = kwargs['num_exp_traj_eval']
        self.save_replay_buffer = kwargs['save_replay_buffer']
        self.save_algorithm = kwargs['save_algorithm']
        self.save_environment = kwargs['save_environment']
        self.dump_eval_paths = kwargs['dump_eval_paths']
        self.data_dir = kwargs['data_dir']
        self.train_epoch = kwargs['train_epoch']
        self.eval_epoch = kwargs['eval_epoch']
        self.sample = kwargs['sample']
        self.n_trj = kwargs['n_trj']
        self.allow_eval = kwargs['allow_eval']
        self.mb_replace = kwargs['mb_replace']

        self.eval_deterministic = eval_deterministic
        self.render = render
        self.eval_statistics = None
        self.render_eval_paths = render_eval_paths
        self.plotter = plotter

        self.train_buffer = MultiTaskReplayBuffer(self.replay_buffer_size,
                                                  env, self.train_tasks,
                                                  self.goal_radius)
        self.eval_buffer = MultiTaskReplayBuffer(self.replay_buffer_size, env,
                                                 self.eval_tasks,
                                                 self.goal_radius)
        self.replay_buffer = MultiTaskReplayBuffer(self.replay_buffer_size,
                                                   env, self.train_tasks,
                                                   self.goal_radius)
        self.enc_replay_buffer = MultiTaskReplayBuffer(self.replay_buffer_size,
                                                       env, self.train_tasks,
                                                       self.goal_radius)
        # Offline sampler, which samples from the train/eval buffer.
        self.offline_sampler = OfflineInPlacePathSampler(
            env=env, policy=agent, max_path_length=self.max_path_length)
        # Online sampler for evaluation (used when collecting on-policy
        # context; for offline context, use self.offline_sampler).
        self.sampler = InPlacePathSampler(env=env,
                                          policy=agent,
                                          max_path_length=self.max_path_length)

        self._n_env_steps_total = 0
        self._n_train_steps_total = 0
        self._n_rollouts_total = 0
        self._do_train_time = 0
        self._epoch_start_time = None
        self._algo_start_time = None
        self._old_table_keys = None
        self._current_path_builder = PathBuilder()
        self._exploration_paths = []

        self.init_buffer()

    def init_buffer(self):
        train_trj_paths = []
        eval_trj_paths = []
        # trj entry format: [obs, action, reward, new_obs]
        if self.sample:
            for n in range(self.n_trj):
                if self.train_epoch is None:
                    train_trj_paths += glob.glob(
                        os.path.join(self.data_dir, "goal_idx*",
                                     "trj_evalsample%d_step*.npy" % (n)))
                else:
                    train_trj_paths += glob.glob(
                        os.path.join(
                            self.data_dir, "goal_idx*",
                            "trj_evalsample%d_step%d.npy" %
                            (n, self.train_epoch)))
                if self.eval_epoch is None:
                    eval_trj_paths += glob.glob(
                        os.path.join(self.data_dir, "goal_idx*",
                                     "trj_evalsample%d_step*.npy" % (n)))
                else:
                    eval_trj_paths += glob.glob(
                        os.path.join(
                            self.data_dir, "goal_idx*",
                            "trj_evalsample%d_step%d.npy" %
                            (n, self.eval_epoch)))
        else:
            if self.train_epoch is None:
                train_trj_paths = glob.glob(
                    os.path.join(self.data_dir, "goal_idx*",
                                 "trj_eval[0-%d]_step*.npy") % (self.n_trj))
            else:
                train_trj_paths = glob.glob(
                    os.path.join(
                        self.data_dir, "goal_idx*",
                        "trj_eval[0-%d]_step%d.npy" %
                        (self.n_trj, self.train_epoch)))
            if self.eval_epoch is None:
                eval_trj_paths = glob.glob(
                    os.path.join(self.data_dir, "goal_idx*",
                                 "trj_eval[0-%d]_step*.npy") % (self.n_trj))
            else:
                eval_trj_paths = glob.glob(
                    os.path.join(
                        self.data_dir, "goal_idx*",
                        "trj_eval[0-%d]_step%d.npy" %
                        (self.n_trj, self.eval_epoch)))

        train_paths = [
            train_trj_path for train_trj_path in train_trj_paths
            if int(train_trj_path.split('/')[-2].split('goal_idx')[-1]) in
            self.train_tasks
        ]
        train_task_idxs = [
            int(train_trj_path.split('/')[-2].split('goal_idx')[-1])
            for train_trj_path in train_trj_paths
            if int(train_trj_path.split('/')[-2].split('goal_idx')[-1]) in
            self.train_tasks
        ]
        eval_paths = [
            eval_trj_path for eval_trj_path in eval_trj_paths
            if int(eval_trj_path.split('/')[-2].split('goal_idx')[-1]) in
            self.eval_tasks
        ]
        eval_task_idxs = [
            int(eval_trj_path.split('/')[-2].split('goal_idx')[-1])
            for eval_trj_path in eval_trj_paths
            if int(eval_trj_path.split('/')[-2].split('goal_idx')[-1]) in
            self.eval_tasks
        ]

        obs_train_lst = []
        action_train_lst = []
        reward_train_lst = []
        next_obs_train_lst = []
        terminal_train_lst = []
        task_train_lst = []
        obs_eval_lst = []
        action_eval_lst = []
        reward_eval_lst = []
        next_obs_eval_lst = []
        terminal_eval_lst = []
        task_eval_lst = []

        for train_path, train_task_idx in zip(train_paths, train_task_idxs):
            trj_npy = np.load(train_path, allow_pickle=True)
            obs_train_lst += list(trj_npy[:, 0])
            action_train_lst += list(trj_npy[:, 1])
            reward_train_lst += list(trj_npy[:, 2])
            next_obs_train_lst += list(trj_npy[:, 3])
            terminal = [0 for _ in range(trj_npy.shape[0])]
            terminal[-1] = 1
            terminal_train_lst += terminal
            task_train = [train_task_idx for _ in range(trj_npy.shape[0])]
            task_train_lst += task_train
        for eval_path, eval_task_idx in zip(eval_paths, eval_task_idxs):
            trj_npy = np.load(eval_path, allow_pickle=True)
            obs_eval_lst += list(trj_npy[:, 0])
            action_eval_lst += list(trj_npy[:, 1])
            reward_eval_lst += list(trj_npy[:, 2])
            next_obs_eval_lst += list(trj_npy[:, 3])
            terminal = [0 for _ in range(trj_npy.shape[0])]
            terminal[-1] = 1
            terminal_eval_lst += terminal
            task_eval = [eval_task_idx for _ in range(trj_npy.shape[0])]
            task_eval_lst += task_eval

        # load training buffer
        for i, (
                task_train,
                obs,
                action,
                reward,
                next_obs,
                terminal,
        ) in enumerate(
                zip(
                    task_train_lst,
                    obs_train_lst,
                    action_train_lst,
                    reward_train_lst,
                    next_obs_train_lst,
                    terminal_train_lst,
                )):
            self.train_buffer.add_sample(
                task_train,
                obs,
                action,
                reward,
                terminal,
                next_obs,
                **{'env_info': {}},
            )

        # load evaluation buffer
        for i, (
                task_eval,
                obs,
                action,
                reward,
                next_obs,
                terminal,
        ) in enumerate(
                zip(
                    task_eval_lst,
                    obs_eval_lst,
                    action_eval_lst,
                    reward_eval_lst,
                    next_obs_eval_lst,
                    terminal_eval_lst,
                )):
            self.eval_buffer.add_sample(
                task_eval,
                obs,
                action,
                reward,
                terminal,
                next_obs,
                **{'env_info': {}},
            )

    def _try_to_eval(self, epoch):
        # logger.save_extra_data(self.get_extra_data_to_save(epoch))
        if self._can_evaluate():
            self.evaluate(epoch)
            # params = self.get_epoch_snapshot(epoch)
            # logger.save_itr_params(epoch, params)
            table_keys = logger.get_table_key_set()
            if self._old_table_keys is not None:
                assert table_keys == self._old_table_keys, (
                    "Table keys cannot change from iteration to iteration.")
            self._old_table_keys = table_keys

            logger.record_tabular("Number of train steps total",
                                  self._n_train_steps_total)
            logger.record_tabular("Number of env steps total",
                                  self._n_env_steps_total)
            logger.record_tabular("Number of rollouts total",
                                  self._n_rollouts_total)

            times_itrs = gt.get_times().stamps.itrs
            train_time = times_itrs['train'][-1]
            sample_time = times_itrs['sample'][-1]
            eval_time = times_itrs['eval'][-1] if epoch > 0 else 0
            epoch_time = train_time + sample_time + eval_time
            total_time = gt.get_times().total

            logger.record_tabular('Train Time (s)', train_time)
            logger.record_tabular('(Previous) Eval Time (s)', eval_time)
            logger.record_tabular('Sample Time (s)', sample_time)
            logger.record_tabular('Epoch Time (s)', epoch_time)
            logger.record_tabular('Total Train Time (s)', total_time)

            logger.record_tabular("Epoch", epoch)
            logger.dump_tabular(with_prefix=False, with_timestamp=False)
        else:
            logger.log("Skipping eval for now.")

    def _can_evaluate(self):
        """
        One annoying thing about the logger table is that the keys at each
        iteration need to be exactly the same. So unless you can compute
        everything, skip evaluation.

        A common example of why you might want to skip evaluation is that at
        the beginning of training, you may not have enough data for
        validation and training sets.

        :return:
        """
        # Eval collects its own context, so we can evaluate at any time.
        return True

    def _can_train(self):
        return all([
            self.replay_buffer.num_steps_can_sample(idx) >= self.batch_size
            for idx in self.train_tasks
        ])

    def _get_action_and_info(self, agent, observation):
        """
        Get an action to take in the environment.

        :param observation:
        :return:
        """
        agent.set_num_steps_total(self._n_env_steps_total)
        return agent.get_action(observation, )

    def _start_epoch(self, epoch):
        self._epoch_start_time = time.time()
        self._exploration_paths = []
        self._do_train_time = 0
        logger.push_prefix('Iteration #%d | ' % epoch)

    def _end_epoch(self):
        logger.log("Epoch Duration: {0}".format(time.time() -
                                                self._epoch_start_time))
        logger.log("Started Training: {0}".format(self._can_train()))
        logger.pop_prefix()

    ##### Snapshotting utils #####
    def get_epoch_snapshot(self, epoch):
        data_to_save = dict(
            epoch=epoch,
            exploration_policy=self.agent,
        )
        if self.save_environment:
            data_to_save['env'] = self.env
        return data_to_save

    def get_extra_data_to_save(self, epoch):
        """
        Save things that shouldn't be saved every snapshot but rather
        overwritten every time.
        :param epoch:
        :return:
        """
        if self.render:
            self.env.render(close=True)
        data_to_save = dict(epoch=epoch, )
        if self.save_environment:
            data_to_save['env'] = self.env
        if self.save_replay_buffer:
            data_to_save['replay_buffer'] = self.replay_buffer
        if self.save_algorithm:
            data_to_save['algorithm'] = self
        return data_to_save

    def _do_eval(self, indices, epoch, buffer):
        final_returns = []
        online_returns = []
        for idx in indices:
            all_rets = []
            for r in range(self.num_evals):
                paths = self.collect_paths(idx, epoch, r, buffer)
                all_rets.append(
                    [eval_util.get_average_returns([p]) for p in paths])
            final_returns.append(np.mean([a[-1] for a in all_rets]))
            # Record online returns for the first n trajectories.
            n = min([len(a) for a in all_rets])
            all_rets = [a[:n] for a in all_rets]
            # Average return per nth rollout.
            all_rets = np.mean(np.stack(all_rets), axis=0)
            online_returns.append(all_rets)
        n = min([len(t) for t in online_returns])
        online_returns = [t[:n] for t in online_returns]
        return final_returns, online_returns

    def test(self, log_dir, end_point=-1):
        assert os.path.exists(log_dir)
        gt.reset()
        gt.set_def_unique(False)
        self._current_path_builder = PathBuilder()

        # At each iteration, we first collect data from tasks, perform
        # meta-updates, then try to evaluate.
        for it_ in gt.timed_for(range(self.num_iterations), save_itrs=True):
            self._start_epoch(it_)
            if it_ == 0:
                print('collecting initial pool of data for test')
                # temp for evaluating
                for idx in self.train_tasks:
                    self.task_idx = idx
                    self.env.reset_task(idx)
                    self.collect_data(self.num_initial_steps,
                                      1,
                                      np.inf,
                                      buffer=self.train_buffer)
            # Sample data from train tasks.
            for i in range(self.num_tasks_sample):
                idx = np.random.choice(self.train_tasks, 1)[0]
                self.task_idx = idx
                self.env.reset_task(idx)
                self.enc_replay_buffer.task_buffers[idx].clear()

                # Collect some trajectories with z ~ prior.
                if self.num_steps_prior > 0:
                    self.collect_data(self.num_steps_prior,
                                      1,
                                      np.inf,
                                      buffer=self.train_buffer)
                # Collect some trajectories with z ~ posterior.
                if self.num_steps_posterior > 0:
                    self.collect_data(self.num_steps_posterior,
                                      1,
                                      self.update_post_train,
                                      buffer=self.train_buffer)
                # Even if the encoder is trained only on samples from the
                # prior, the policy needs to learn to handle z ~ posterior.
                if self.num_extra_rl_steps_posterior > 0:
                    self.collect_data(self.num_extra_rl_steps_posterior,
                                      1,
                                      self.update_post_train,
                                      buffer=self.train_buffer,
                                      add_to_enc_buffer=False)

            print([
                self.replay_buffer.task_buffers[idx]._size
                for idx in self.train_tasks
            ])
            print([
                self.enc_replay_buffer.task_buffers[idx]._size
                for idx in self.train_tasks
            ])

            # No gradient updates in test mode; just advance the counter.
            for train_step in range(self.num_train_steps_per_itr):
                self._n_train_steps_total += 1
            gt.stamp('train')

            # eval
            self.training_mode(False)
            if it_ % 5 == 0 and it_ > end_point:
                status = self.load_epoch_model(it_, log_dir)
                if status:
                    self._try_to_eval(it_)
            gt.stamp('eval')

            self._end_epoch()

    def train(self):
        '''
        Meta-training loop.
        '''
        params = self.get_epoch_snapshot(-1)
        logger.save_itr_params(-1, params)
        gt.reset()
        gt.set_def_unique(False)
        self._current_path_builder = PathBuilder()

        # At each iteration, we first collect data from tasks, perform
        # meta-updates, then try to evaluate.
        for it_ in gt.timed_for(range(self.num_iterations), save_itrs=True):
            self._start_epoch(it_)
            self.training_mode(True)
            if it_ == 0:
                print('collecting initial pool of data for train and eval')
                # temp for evaluating
                for idx in self.train_tasks:
                    self.task_idx = idx
                    self.env.reset_task(idx)
                    self.collect_data(self.num_initial_steps,
                                      1,
                                      np.inf,
                                      buffer=self.train_buffer)
            # Sample data from train tasks.
            for i in range(self.num_tasks_sample):
                idx = np.random.choice(self.train_tasks, 1)[0]
                self.task_idx = idx
                self.env.reset_task(idx)
                self.enc_replay_buffer.task_buffers[idx].clear()

                # Collect some trajectories with z ~ prior.
                if self.num_steps_prior > 0:
                    self.collect_data(self.num_steps_prior,
                                      1,
                                      np.inf,
                                      buffer=self.train_buffer)
                # Collect some trajectories with z ~ posterior.
                if self.num_steps_posterior > 0:
                    self.collect_data(self.num_steps_posterior,
                                      1,
                                      self.update_post_train,
                                      buffer=self.train_buffer)
                # Even if the encoder is trained only on samples from the
                # prior, the policy needs to learn to handle z ~ posterior.
                if self.num_extra_rl_steps_posterior > 0:
                    self.collect_data(self.num_extra_rl_steps_posterior,
                                      1,
                                      self.update_post_train,
                                      buffer=self.train_buffer,
                                      add_to_enc_buffer=False)

            indices_lst = []
            z_means_lst = []
            z_vars_lst = []
            # Sample train tasks and compute gradient updates on parameters.
            for train_step in range(self.num_train_steps_per_itr):
                indices = np.random.choice(self.train_tasks,
                                           self.meta_batch,
                                           replace=self.mb_replace)
                z_means, z_vars = self._do_training(indices, zloss=True)
                indices_lst.append(indices)
                z_means_lst.append(z_means)
                z_vars_lst.append(z_vars)
                self._n_train_steps_total += 1

            indices = np.concatenate(indices_lst)
            z_means = np.concatenate(z_means_lst)
            z_vars = np.concatenate(z_vars_lst)
            data_dict = self.data_dict(indices, z_means, z_vars)
            logger.save_itr_data(it_, **data_dict)
            gt.stamp('train')
            self.training_mode(False)

            # eval
            params = self.get_epoch_snapshot(it_)
            logger.save_itr_params(it_, params)
            if self.allow_eval:
                logger.save_extra_data(self.get_extra_data_to_save(it_))
                self._try_to_eval(it_)
                gt.stamp('eval')

            self._end_epoch()

    def data_dict(self, indices, z_means, z_vars):
        data_dict = {}
        data_dict['task_idx'] = indices
        for i in range(z_means.shape[1]):
            data_dict['z_means%d' % i] = list(z_means[:, i])
        for i in range(z_vars.shape[1]):
            data_dict['z_vars%d' % i] = list(z_vars[:, i])
        return data_dict

    def evaluate(self, epoch):
        if self.eval_statistics is None:
            self.eval_statistics = OrderedDict()

        ### sample trajectories from prior for debugging / visualization
        if self.dump_eval_paths:
            # 100 arbitrarily chosen for visualizations of point_robot
            # trajectories; we just want stochasticity of z, not the policy.
            self.agent.clear_z()
            prior_paths, _ = self.offline_sampler.obtain_samples(
                buffer=self.train_buffer,
                deterministic=self.eval_deterministic,
                max_samples=self.max_path_length * 20,
                accum_context=False,
                resample=1)
            logger.save_extra_data(
                prior_paths,
                path='eval_trajectories/prior-epoch{}'.format(epoch))

        ### train tasks
        # Eval on a subset of train tasks for speed.
        # {}-dir envs
        if len(self.train_tasks) == 2 and len(self.eval_tasks) == 2:
            indices = self.train_tasks
            eval_util.dprint('evaluating on {} train tasks'.format(
                len(indices)))
            ### eval train tasks with posterior sampled from the training replay buffer
            train_returns = []
            for idx in indices:
                self.task_idx = idx
                self.env.reset_task(idx)
                paths = []
                print(self.num_steps_per_eval, self.max_path_length)
                for _ in range(self.num_steps_per_eval //
                               self.max_path_length):
                    context = self.sample_context(idx)
                    self.agent.infer_posterior(context, idx)
                    p, _ = self.offline_sampler.obtain_samples(
                        buffer=self.train_buffer,
                        deterministic=self.eval_deterministic,
                        max_samples=self.max_path_length,
                        accum_context=False,
                        max_trajs=1,
                        resample=np.inf)
                    paths += p

                if self.sparse_rewards:
                    for p in paths:
                        sparse_rewards = np.stack([
                            e['sparse_reward'] for e in p['env_infos']
                        ]).reshape(-1, 1)
                        p['rewards'] = sparse_rewards
                train_returns.append(eval_util.get_average_returns(paths))

            ### eval train tasks with on-policy data to match eval of test tasks
            train_final_returns, train_online_returns = self._do_eval(
                indices, epoch, buffer=self.train_buffer)
            eval_util.dprint('train online returns')
            eval_util.dprint(train_online_returns)

            ### test tasks
            eval_util.dprint('evaluating on {} test tasks'.format(
                len(self.eval_tasks)))
            test_final_returns, test_online_returns = self._do_eval(
                self.eval_tasks, epoch, buffer=self.eval_buffer)
            eval_util.dprint('test online returns')
            eval_util.dprint(test_online_returns)

            # save the final posterior
            self.agent.log_diagnostics(self.eval_statistics)
            if hasattr(self.env, "log_diagnostics"):
                self.env.log_diagnostics(paths, prefix=None)

            avg_train_online_return = np.mean(np.stack(train_online_returns),
                                              axis=0)
            avg_test_online_return = np.mean(np.stack(test_online_returns),
                                             axis=0)
            for i in indices:
                self.eval_statistics[
                    f'AverageTrainReturn_train_task{i}'] = train_returns[i]
                self.eval_statistics[
                    f'AverageReturn_all_train_task{i}'] = train_final_returns[
                        i]
                self.eval_statistics[
                    f'AverageReturn_all_test_tasks{i}'] = test_final_returns[i]
        # non {}-dir envs
        else:
            indices = np.random.choice(self.train_tasks, len(self.eval_tasks))
            eval_util.dprint('evaluating on {} train tasks'.format(
                len(indices)))
            ### eval train tasks with posterior sampled from the training replay buffer
            train_returns = []
            for idx in indices:
                self.task_idx = idx
                self.env.reset_task(idx)
                paths = []
                for _ in range(self.num_steps_per_eval //
                               self.max_path_length):
                    context = self.sample_context(idx)
                    self.agent.infer_posterior(context, idx)
                    p, _ = self.offline_sampler.obtain_samples(
                        buffer=self.train_buffer,
                        deterministic=self.eval_deterministic,
                        max_samples=self.max_path_length,
                        accum_context=False,
                        max_trajs=1,
                        resample=np.inf)
                    paths += p

                if self.sparse_rewards:
                    for p in paths:
                        sparse_rewards = np.stack([
                            e['sparse_reward'] for e in p['env_infos']
                        ]).reshape(-1, 1)
                        p['rewards'] = sparse_rewards

                train_returns.append(eval_util.get_average_returns(paths))
            train_returns = np.mean(train_returns)

            ### eval train tasks with on-policy data to match eval of test tasks
            train_final_returns, train_online_returns = self._do_eval(
                indices, epoch, buffer=self.train_buffer)
            eval_util.dprint('train online returns')
            eval_util.dprint(train_online_returns)

            ### test tasks
            eval_util.dprint('evaluating on {} test tasks'.format(
                len(self.eval_tasks)))
            test_final_returns, test_online_returns = self._do_eval(
                self.eval_tasks, epoch, buffer=self.eval_buffer)
            eval_util.dprint('test online returns')
            eval_util.dprint(test_online_returns)

            # save the final posterior
            self.agent.log_diagnostics(self.eval_statistics)
            if hasattr(self.env, "log_diagnostics"):
                self.env.log_diagnostics(paths, prefix=None)

            avg_train_return = np.mean(train_final_returns)
            avg_test_return = np.mean(test_final_returns)
            avg_train_online_return = np.mean(np.stack(train_online_returns),
                                              axis=0)
            avg_test_online_return = np.mean(np.stack(test_online_returns),
                                             axis=0)
            self.eval_statistics[
                'AverageTrainReturn_all_train_tasks'] = train_returns
            self.eval_statistics[
                'AverageReturn_all_train_tasks'] = avg_train_return
            self.eval_statistics[
                'AverageReturn_all_test_tasks'] = avg_test_return
            self.loss['train_returns'] = train_returns
            self.loss['avg_train_return'] = avg_train_return
            self.loss['avg_test_return'] = avg_test_return
            self.loss['avg_train_online_return'] = np.mean(
                avg_train_online_return)
            self.loss['avg_test_online_return'] = np.mean(
                avg_test_online_return)
        logger.save_extra_data(avg_train_online_return,
                               path='online-train-epoch{}'.format(epoch))
        logger.save_extra_data(avg_test_online_return,
                               path='online-test-epoch{}'.format(epoch))

        for key, value in self.eval_statistics.items():
            logger.record_tabular(key, value)
        self.eval_statistics = None

        if self.render_eval_paths:
            self.env.render_paths(paths)
        if self.plotter:
            self.plotter.draw()

    def collect_paths(self, idx, epoch, run, buffer):
        self.task_idx = idx
        self.env.reset_task(idx)
        self.agent.clear_z()

        paths = []
        num_transitions = 0
        while num_transitions < self.num_steps_per_eval:
            path, num = self.offline_sampler.obtain_samples(
                buffer=buffer,
                deterministic=self.eval_deterministic,
                max_samples=self.num_steps_per_eval - num_transitions,
                max_trajs=1,
                accum_context=True,
                rollout=True)
            paths += path
            num_transitions += num

        if self.sparse_rewards:
            for p in paths:
                sparse_rewards = np.stack([
                    e['sparse_reward'] for e in p['env_infos']
                ]).reshape(-1, 1)
                p['rewards'] = sparse_rewards

        goal = self.env._goal
        for path in paths:
            path['goal'] = goal

        # Save the paths for visualization; only useful for point mass.
        if self.dump_eval_paths:
            logger.save_extra_data(
                paths,
                path='eval_trajectories/task{}-epoch{}-run{}'.format(
                    idx, epoch, run))
        return paths

    def collect_data(self,
                     num_samples,
                     resample_z_rate,
                     update_posterior_rate,
                     buffer,
                     add_to_enc_buffer=True):
        '''
        Get trajectories from the current env in batch mode with the given
        policy; collect complete trajectories until the number of collected
        transitions >= num_samples.

        :param num_samples: total number of transitions to sample
        :param resample_z_rate: how often to resample the latent context z
            (in units of trajectories)
        :param update_posterior_rate: how often to update q(z | c), from which
            z is sampled (in units of trajectories)
        :param buffer: buffer from which the offline sampler draws transitions
        :param add_to_enc_buffer: whether to add collected data to the encoder
            replay buffer
        '''
        # start from the prior
        self.agent.clear_z()

        num_transitions = 0
        while num_transitions < num_samples:
            paths, n_samples = self.offline_sampler.obtain_samples(
                buffer=buffer,
                max_samples=num_samples - num_transitions,
                max_trajs=update_posterior_rate,
                accum_context=False,
                resample=resample_z_rate,
                rollout=False)
            num_transitions += n_samples
            self.replay_buffer.add_paths(self.task_idx, paths)
            if add_to_enc_buffer:
                self.enc_replay_buffer.add_paths(self.task_idx, paths)
            if update_posterior_rate != np.inf:
                context = self.sample_context(self.task_idx)
                self.agent.infer_posterior(context,
                                           task_indices=np.array(
                                               [self.task_idx]))
        self._n_env_steps_total += num_transitions
        gt.stamp('sample')

    @abc.abstractmethod
    def training_mode(self, mode):
        """
        Set training mode to `mode`.

        :param mode: If True, training will happen (e.g. set the dropout
            probabilities to not all ones).
        """
        pass

    @abc.abstractmethod
    def _do_training(self):
        """
        Perform some update, e.g. perform one gradient step.

        :return:
        """
        pass
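
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original codebase): how a
# concrete offline meta-RL learner built on OfflineMetaRLAlgorithm is
# typically constructed and driven. The subclass name `ExampleOfflineMetaRL`,
# the `variant` dict, and its specific values are hypothetical placeholders;
# the required keys mirror exactly what __init__ reads from **kwargs, and the
# offline data is expected under data_dir/goal_idx*/trj_*.npy as loaded by
# init_buffer(). `env`, `agent`, `train_tasks`, `eval_tasks`, and
# `goal_radius` are assumed to be constructed elsewhere.
#
# variant = dict(
#     meta_batch=16, batch_size=256, num_iterations=500,
#     num_train_steps_per_itr=2000, num_initial_steps=2000,
#     num_tasks_sample=5, num_steps_prior=400, num_steps_posterior=0,
#     num_extra_rl_steps_posterior=600, num_evals=1,
#     num_steps_per_eval=600, embedding_batch_size=100,
#     embedding_mini_batch_size=100, max_path_length=200, discount=0.99,
#     replay_buffer_size=1000000, reward_scale=5.0, update_post_train=1,
#     num_exp_traj_eval=1, save_replay_buffer=False, save_algorithm=False,
#     save_environment=False, dump_eval_paths=False, data_dir='./data',
#     train_epoch=None, eval_epoch=None, sample=True, n_trj=5,
#     allow_eval=True, mb_replace=False,
# )
# algo = ExampleOfflineMetaRL(env, agent, train_tasks, eval_tasks,
#                             goal_radius=0.2, **variant)
# algo.train()  # buffers are filled from disk in __init__, then meta-training runs
# ---------------------------------------------------------------------------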