def __init__(self, args):
    self.args = args
    utl.seed(self.args.seed, self.args.deterministic_execution)

    # calculate number of updates and keep count of frames/iterations
    self.num_updates = int(args.num_frames) // args.policy_num_steps // args.num_processes
    self.frames = 0
    self.iter_idx = 0

    # initialise tensorboard logger
    self.logger = TBLogger(self.args, self.args.exp_label)

    # initialise environments
    self.envs = make_vec_envs(
        env_name=args.env_name,
        seed=args.seed,
        num_processes=args.num_processes,
        gamma=args.policy_gamma,
        device=device,
        episodes_per_task=self.args.max_rollouts_per_task,
        normalise_rew=args.norm_rew_for_policy,
        ret_rms=None,
    )

    # calculate what the maximum length of the trajectories is
    self.args.max_trajectory_len = self.envs._max_episode_steps
    self.args.max_trajectory_len *= self.args.max_rollouts_per_task

    # get policy input dimensions
    self.args.state_dim = self.envs.observation_space.shape[0]
    self.args.task_dim = self.envs.task_dim
    self.args.belief_dim = self.envs.belief_dim
    self.args.num_states = self.envs.num_states

    # get policy output (action) dimensions
    self.args.action_space = self.envs.action_space
    if isinstance(self.envs.action_space, gym.spaces.discrete.Discrete):
        self.args.action_dim = 1
    elif isinstance(self.envs.action_space, gym.spaces.multi_discrete.MultiDiscrete):
        self.args.action_dim = self.envs.action_space.nvec[0]
    else:
        self.args.action_dim = self.envs.action_space.shape[0]

    # initialise VAE and policy
    self.vae = VaribadVAE(self.args, self.logger, lambda: self.iter_idx)
    self.policy_storage = self.initialise_policy_storage()
    self.policy = self.initialise_policy()
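# Illustrative sketch (not part of the original class): how the branching above maps a Gym
# action space to `action_dim`. The helper name `infer_action_dim` is hypothetical.
import gym.spaces as spaces
import numpy as np

def infer_action_dim(action_space):
    if isinstance(action_space, spaces.Discrete):
        return 1  # a single integer action index
    elif isinstance(action_space, spaces.MultiDiscrete):
        return action_space.nvec[0]  # number of choices in the first sub-action
    else:
        return action_space.shape[0]  # length of a continuous action vector

assert infer_action_dim(spaces.Discrete(5)) == 1
assert infer_action_dim(spaces.MultiDiscrete([4, 4])) == 4
assert infer_action_dim(spaces.Box(low=-1.0, high=1.0, shape=(6,), dtype=np.float32)) == 6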
def __init__(self, args): """ Seeds everything. Initialises: logger, environments, policy (+storage +optimiser). """ self.args = args # make sure everything has the same seed utl.seed(self.args.seed) # initialize tensorboard logger if self.args.log_tensorboard: self.tb_logger = TBLogger(self.args) # initialise environment self.env = make_env(self.args.env_name, self.args.max_rollouts_per_task, seed=self.args.seed, n_tasks=self.args.num_tasks) # unwrapped env to get some info about the environment unwrapped_env = self.env.unwrapped # split to train/eval tasks shuffled_tasks = np.random.permutation( unwrapped_env.get_all_task_idx()) self.train_tasks = shuffled_tasks[:self.args.num_train_tasks] if self.args.num_eval_tasks > 0: self.eval_tasks = shuffled_tasks[-self.args.num_eval_tasks:] else: self.eval_tasks = [] # calculate what the maximum length of the trajectories is args.max_trajectory_len = unwrapped_env._max_episode_steps args.max_trajectory_len *= self.args.max_rollouts_per_task self.args.max_trajectory_len = args.max_trajectory_len # get action / observation dimensions if isinstance(self.env.action_space, gym.spaces.discrete.Discrete): self.args.action_dim = 1 else: self.args.action_dim = self.env.action_space.shape[0] self.args.obs_dim = self.env.observation_space.shape[0] self.args.num_states = unwrapped_env.num_states if hasattr( unwrapped_env, 'num_states') else None self.args.act_space = self.env.action_space # initialize policy self.initialize_policy() # initialize buffer for RL updates self.policy_storage = MultiTaskPolicyStorage( max_replay_buffer_size=int(self.args.policy_buffer_size), obs_dim=self._get_augmented_obs_dim(), action_space=self.env.action_space, tasks=self.train_tasks, trajectory_len=args.max_trajectory_len, ) self.current_experience_storage = None self.args.belief_reward = False # initialize arg to not use belief rewards
def __init__(self, args): """ Seeds everything. Initialises: logger, environments, policy (+storage +optimiser). """ self.args = args # make sure everything has the same seed utl.seed(self.args.seed) # initialize tensorboard logger if self.args.log_tensorboard: self.tb_logger = TBLogger(self.args) self.args, env = off_utl.expand_args(self.args, include_act_space=True) if self.args.act_space.__class__.__name__ == "Discrete": self.args.policy = 'dqn' else: self.args.policy = 'sac' # load buffers with data if 'load_data' not in self.args or self.args.load_data: goals, augmented_obs_dim = self.load_buffer( env) # env is input just for possible relabelling option self.args.augmented_obs_dim = augmented_obs_dim self.goals = goals # initialize policy self.initialize_policy() # load vae for inference in evaluation self.load_vae() # create environment for evaluation self.env = make_env( args.env_name, args.max_rollouts_per_task, presampled_tasks=args.presampled_tasks, seed=args.seed, ) # n_tasks=self.args.num_eval_tasks) if self.args.env_name == 'GridNavi-v2': self.env.unwrapped.goals = [ tuple(goal.astype(int)) for goal in self.goals ]
class OfflineMetaLearner:
    """
    Off-line Meta-Learner class, a.k.a no interaction with env.
    """

    def __init__(self, args):
        """
        Seeds everything.
        Initialises: logger, environments, policy (+storage +optimiser).
        """

        self.args = args

        # make sure everything has the same seed
        utl.seed(self.args.seed)

        # initialize tensorboard logger
        if self.args.log_tensorboard:
            self.tb_logger = TBLogger(self.args)

        self.args, env = off_utl.expand_args(self.args, include_act_space=True)
        if self.args.act_space.__class__.__name__ == "Discrete":
            self.args.policy = 'dqn'
        else:
            self.args.policy = 'sac'

        # load buffers with data
        if 'load_data' not in self.args or self.args.load_data:
            # env is input just for possible relabelling option
            goals, augmented_obs_dim = self.load_buffer(env)
            self.args.augmented_obs_dim = augmented_obs_dim
            self.goals = goals

        # initialize policy
        self.initialize_policy()

        # load vae for inference in evaluation
        self.load_vae()

        # create environment for evaluation
        self.env = make_env(
            args.env_name,
            args.max_rollouts_per_task,
            presampled_tasks=args.presampled_tasks,
            seed=args.seed,
        )  # n_tasks=self.args.num_eval_tasks)
        if self.args.env_name == 'GridNavi-v2':
            self.env.unwrapped.goals = [
                tuple(goal.astype(int)) for goal in self.goals
            ]

    def initialize_policy(self):
        if self.args.policy == 'dqn':
            q_network = FlattenMlp(input_size=self.args.augmented_obs_dim,
                                   output_size=self.args.act_space.n,
                                   hidden_sizes=self.args.dqn_layers).to(ptu.device)
            self.agent = DQN(
                q_network,
                # optimiser_vae=self.optimizer_vae,
                lr=self.args.policy_lr,
                gamma=self.args.gamma,
                tau=self.args.soft_target_tau,
            ).to(ptu.device)
        else:
            # assert self.args.act_space.__class__.__name__ == "Box", (
            #     "Can't train SAC with discrete action space!")
            q1_network = FlattenMlp(
                input_size=self.args.augmented_obs_dim + self.args.action_dim,
                output_size=1,
                hidden_sizes=self.args.dqn_layers).to(ptu.device)
            q2_network = FlattenMlp(
                input_size=self.args.augmented_obs_dim + self.args.action_dim,
                output_size=1,
                hidden_sizes=self.args.dqn_layers).to(ptu.device)
            policy = TanhGaussianPolicy(
                obs_dim=self.args.augmented_obs_dim,
                action_dim=self.args.action_dim,
                hidden_sizes=self.args.policy_layers).to(ptu.device)
            self.agent = SAC(
                policy,
                q1_network,
                q2_network,
                actor_lr=self.args.actor_lr,
                critic_lr=self.args.critic_lr,
                gamma=self.args.gamma,
                tau=self.args.soft_target_tau,
                use_cql=self.args.use_cql if 'use_cql' in self.args else False,
                alpha_cql=self.args.alpha_cql if 'alpha_cql' in self.args else None,
                entropy_alpha=self.args.entropy_alpha,
                automatic_entropy_tuning=self.args.automatic_entropy_tuning,
                alpha_lr=self.args.alpha_lr,
                clip_grad_value=self.args.clip_grad_value,
            ).to(ptu.device)

    def load_vae(self):
        self.vae = VAE(self.args)
        vae_models_path = os.path.join(self.args.vae_dir, self.args.env_name,
                                       self.args.vae_model_name, 'models')
        off_utl.load_trained_vae(self.vae, vae_models_path)

    def load_buffer(self, env):
        if self.args.hindsight_relabelling:
            # without arr_type loading -- GPU will explode
            dataset, goals = off_utl.load_dataset(
                data_dir=self.args.relabelled_data_dir,
                args=self.args,
                num_tasks=self.args.num_train_tasks,
                allow_dense_data_loading=False,
                arr_type='numpy')
            dataset = off_utl.batch_to_trajectories(dataset, self.args)
            dataset, goals = off_utl.mix_task_rollouts(dataset, env, goals, self.args)  # reward relabelling
            dataset = off_utl.trajectories_to_batch(dataset)
        else:
            dataset, goals = off_utl.load_dataset(
                data_dir=self.args.relabelled_data_dir,
                args=self.args,
                num_tasks=self.args.num_train_tasks,
                allow_dense_data_loading=False,
                arr_type='numpy')
        augmented_obs_dim = dataset[0][0].shape[1]

        self.storage = MultiTaskPolicyStorage(
            max_replay_buffer_size=max([d[0].shape[0] for d in dataset]),
            obs_dim=dataset[0][0].shape[1],
            action_space=self.args.act_space,
            tasks=range(len(goals)),
            trajectory_len=self.args.trajectory_len)

        for task, set in enumerate(dataset):
            self.storage.add_samples(task,
                                     observations=set[0],
                                     actions=set[1],
                                     rewards=set[2],
                                     next_observations=set[3],
                                     terminals=set[4])
        return goals, augmented_obs_dim

    def train(self):
        self._start_training()
        for iter_ in range(self.args.num_iters):
            self.training_mode(True)
            indices = np.random.choice(len(self.goals), self.args.meta_batch)
            train_stats = self.update(indices)
            self.training_mode(False)
            self.log(iter_ + 1, train_stats)

    def update(self, tasks):
        rl_losses_agg = {}
        for update in range(self.args.rl_updates_per_iter):
            # sample random RL batch
            obs, actions, rewards, next_obs, terms = self.sample_rl_batch(
                tasks, self.args.batch_size)

            # flatten out task dimension
            t, b, _ = obs.size()
            obs = obs.view(t * b, -1)
            actions = actions.view(t * b, -1)
            rewards = rewards.view(t * b, -1)
            next_obs = next_obs.view(t * b, -1)
            terms = terms.view(t * b, -1)

            # RL update
            rl_losses = self.agent.update(obs, actions, rewards, next_obs, terms,
                                          action_space=self.env.action_space)

            for k, v in rl_losses.items():
                if update == 0:  # first iterate - create list
                    rl_losses_agg[k] = [v]
                else:  # append values
                    rl_losses_agg[k].append(v)
        # take mean
        for k in rl_losses_agg:
            rl_losses_agg[k] = np.mean(rl_losses_agg[k])
        self._n_rl_update_steps_total += self.args.rl_updates_per_iter

        return rl_losses_agg

    def evaluate(self):
        num_episodes = self.args.max_rollouts_per_task
        num_steps_per_episode = self.env.unwrapped._max_episode_steps
        num_tasks = self.args.num_eval_tasks
        obs_size = self.env.unwrapped.observation_space.shape[0]

        returns_per_episode = np.zeros((num_tasks, num_episodes))
        success_rate = np.zeros(num_tasks)

        rewards = np.zeros((num_tasks, self.args.trajectory_len))
        reward_preds = np.zeros((num_tasks, self.args.trajectory_len))
        observations = np.zeros((num_tasks, self.args.trajectory_len + 1, obs_size))
        if self.args.policy == 'sac':
            log_probs = np.zeros((num_tasks, self.args.trajectory_len))

        # This part is very specific for the Semi-Circle env
        # if self.args.env_name == 'PointRobotSparse-v0':
        #     reward_belief = np.zeros((num_tasks, self.args.trajectory_len))
        #
        #     low_x, high_x, low_y, high_y = -2., 2., -1., 2.
        #     resolution = 0.1
        #     grid_x = np.arange(low_x, high_x + resolution, resolution)
        #     grid_y = np.arange(low_y, high_y + resolution, resolution)
        #     centers_x = (grid_x[:-1] + grid_x[1:]) / 2
        #     centers_y = (grid_y[:-1] + grid_y[1:]) / 2
        #     yv, xv = np.meshgrid(centers_y, centers_x, sparse=False, indexing='ij')
        #     centers = np.vstack([xv.ravel(), yv.ravel()]).T
        #     n_grid_points = centers.shape[0]
        #     reward_belief_discretized = np.zeros((num_tasks, self.args.trajectory_len, centers.shape[0]))

        for task_loop_i, task in enumerate(self.env.unwrapped.get_all_eval_task_idx()):
            obs = ptu.from_numpy(self.env.reset(task))
            obs = obs.reshape(-1, obs.shape[-1])
            step = 0

            # get prior parameters
            with torch.no_grad():
                task_sample, task_mean, task_logvar, hidden_state = self.vae.encoder.prior(batch_size=1)

            observations[task_loop_i, step, :] = ptu.get_numpy(obs[0, :obs_size])

            for episode_idx in range(num_episodes):
                running_reward = 0.
                for step_idx in range(num_steps_per_episode):
                    # add distribution parameters to observation - policy is conditioned on posterior
                    augmented_obs = self.get_augmented_obs(obs, task_mean, task_logvar)
                    if self.args.policy == 'dqn':
                        action, value = self.agent.act(obs=augmented_obs, deterministic=True)
                    else:
                        action, _, _, log_prob = self.agent.act(
                            obs=augmented_obs,
                            deterministic=self.args.eval_deterministic,
                            return_log_prob=True)

                    # observe reward and next obs
                    next_obs, reward, done, info = utl.env_step(self.env, action.squeeze(dim=0))
                    running_reward += reward.item()
                    # done_rollout = False if ptu.get_numpy(done[0][0]) == 0. else True

                    # update encoding
                    task_sample, task_mean, task_logvar, hidden_state = self.update_encoding(
                        obs=next_obs,
                        action=action,
                        reward=reward,
                        done=done,
                        hidden_state=hidden_state)

                    rewards[task_loop_i, step] = reward.item()
                    reward_preds[task_loop_i, step] = ptu.get_numpy(
                        self.vae.reward_decoder(task_sample, next_obs, obs, action)[0, 0])

                    # This part is very specific for the Semi-Circle env
                    # if self.args.env_name == 'PointRobotSparse-v0':
                    #     reward_belief[task, step] = ptu.get_numpy(
                    #         self.vae.compute_belief_reward(task_mean, task_logvar, obs, next_obs, action)[0])
                    #
                    #     reward_belief_discretized[task, step, :] = ptu.get_numpy(
                    #         self.vae.compute_belief_reward(task_mean.repeat(n_grid_points, 1),
                    #                                        task_logvar.repeat(n_grid_points, 1),
                    #                                        None,
                    #                                        torch.cat((ptu.FloatTensor(centers),
                    #                                                   ptu.zeros(centers.shape[0], 1)), dim=-1).unsqueeze(0),
                    #                                        None)[:, 0])

                    observations[task_loop_i, step + 1, :] = ptu.get_numpy(next_obs[0, :obs_size])
                    if self.args.policy != 'dqn':
                        log_probs[task_loop_i, step] = ptu.get_numpy(log_prob[0])

                    if "is_goal_state" in dir(self.env.unwrapped) and self.env.unwrapped.is_goal_state():
                        success_rate[task_loop_i] = 1.

                    # set: obs <- next_obs
                    obs = next_obs.clone()
                    step += 1

                returns_per_episode[task_loop_i, episode_idx] = running_reward

        if self.args.policy == 'dqn':
            return returns_per_episode, success_rate, observations, rewards, reward_preds
        # This part is very specific for the Semi-Circle env
        # elif self.args.env_name == 'PointRobotSparse-v0':
        #     return returns_per_episode, success_rate, log_probs, observations, \
        #            rewards, reward_preds, reward_belief, reward_belief_discretized, centers
        else:
            return returns_per_episode, success_rate, log_probs, observations, rewards, reward_preds

    def log(self, iteration, train_stats):
        # --- save model ---
        if iteration % self.args.save_interval == 0:
            save_path = os.path.join(self.tb_logger.full_output_folder, 'models')
            if not os.path.exists(save_path):
                os.mkdir(save_path)
            torch.save(self.agent.state_dict(),
                       os.path.join(save_path, "agent{0}.pt".format(iteration)))

        if iteration % self.args.log_interval == 0:
            if self.args.policy == 'dqn':
                returns, success_rate, observations, rewards, reward_preds = self.evaluate()
            # This part is super specific for the Semi-Circle env
            # elif self.args.env_name == 'PointRobotSparse-v0':
            #     returns, success_rate, log_probs, observations, \
            #     rewards, reward_preds, reward_belief, reward_belief_discretized, points = self.evaluate()
            else:
                returns, success_rate, log_probs, observations, rewards, reward_preds = self.evaluate()

            if self.args.log_tensorboard:
                tasks_to_vis = np.random.choice(self.args.num_eval_tasks, 5)
                for i, task in enumerate(tasks_to_vis):
                    self.env.reset(task)
                    if PLOT_VIS:
                        self.tb_logger.writer.add_figure(
                            'policy_vis/task_{}'.format(i),
                            utl_eval.plot_rollouts(observations[task, :], self.env),
                            self._n_rl_update_steps_total)
                    self.tb_logger.writer.add_figure(
                        'reward_prediction_train/task_{}'.format(i),
                        utl_eval.plot_rew_pred_vs_rew(rewards[task, :], reward_preds[task, :]),
                        self._n_rl_update_steps_total)
                    # self.tb_logger.writer.add_figure('reward_prediction_train/task_{}'.format(i),
                    #                                  utl_eval.plot_rew_pred_vs_reward_belief_vs_rew(
                    #                                      rewards[task, :], reward_preds[task, :], reward_belief[task, :]),
                    #                                  self._n_rl_update_steps_total)
                    # if self.args.env_name == 'PointRobotSparse-v0':   # This part is super specific for the Semi-Circle env
                    #     for t in range(0, int(self.args.trajectory_len/4), 3):
                    #         self.tb_logger.writer.add_figure('discrete_belief_reward_pred_task_{}/timestep_{}'.format(i, t),
                    #                                          utl_eval.plot_discretized_belief_halfcircle(
                    #                                              reward_belief_discretized[task, t, :], points,
                    #                                              self.env, observations[task, :t+1]),
                    #                                          self._n_rl_update_steps_total)

                if self.args.max_rollouts_per_task > 1:
                    for episode_idx in range(self.args.max_rollouts_per_task):
                        self.tb_logger.writer.add_scalar(
                            'returns_multi_episode/episode_{}'.format(episode_idx + 1),
                            np.mean(returns[:, episode_idx]),
                            self._n_rl_update_steps_total)
                    self.tb_logger.writer.add_scalar(
                        'returns_multi_episode/sum',
                        np.mean(np.sum(returns, axis=-1)),
                        self._n_rl_update_steps_total)
                    self.tb_logger.writer.add_scalar(
                        'returns_multi_episode/success_rate',
                        np.mean(success_rate),
                        self._n_rl_update_steps_total)
                else:
                    self.tb_logger.writer.add_scalar(
                        'returns/returns_mean', np.mean(returns),
                        self._n_rl_update_steps_total)
                    self.tb_logger.writer.add_scalar(
                        'returns/returns_std', np.std(returns),
                        self._n_rl_update_steps_total)
                    self.tb_logger.writer.add_scalar(
                        'returns/success_rate', np.mean(success_rate),
                        self._n_rl_update_steps_total)

                if self.args.policy == 'dqn':
                    self.tb_logger.writer.add_scalar(
                        'rl_losses/qf_loss_vs_n_updates', train_stats['qf_loss'],
                        self._n_rl_update_steps_total)
                    self.tb_logger.writer.add_scalar(
                        'weights/q_network',
                        list(self.agent.qf.parameters())[0].mean(),
                        self._n_rl_update_steps_total)
                    if list(self.agent.qf.parameters())[0].grad is not None:
                        param_list = list(self.agent.qf.parameters())
                        self.tb_logger.writer.add_scalar(
                            'gradients/q_network',
                            sum([param_list[i].grad.mean() for i in range(len(param_list))]),
                            self._n_rl_update_steps_total)
                    self.tb_logger.writer.add_scalar(
                        'weights/q_target',
                        list(self.agent.target_qf.parameters())[0].mean(),
                        self._n_rl_update_steps_total)
                    if list(self.agent.target_qf.parameters())[0].grad is not None:
                        param_list = list(self.agent.target_qf.parameters())
                        self.tb_logger.writer.add_scalar(
                            'gradients/q_target',
                            sum([param_list[i].grad.mean() for i in range(len(param_list))]),
                            self._n_rl_update_steps_total)
                else:
                    self.tb_logger.writer.add_scalar(
                        'policy/log_prob', np.mean(log_probs),
                        self._n_rl_update_steps_total)
                    self.tb_logger.writer.add_scalar(
                        'rl_losses/qf1_loss', train_stats['qf1_loss'],
                        self._n_rl_update_steps_total)
                    self.tb_logger.writer.add_scalar(
                        'rl_losses/qf2_loss', train_stats['qf2_loss'],
                        self._n_rl_update_steps_total)
                    self.tb_logger.writer.add_scalar(
                        'rl_losses/policy_loss', train_stats['policy_loss'],
                        self._n_rl_update_steps_total)
                    self.tb_logger.writer.add_scalar(
                        'rl_losses/alpha_entropy_loss', train_stats['alpha_entropy_loss'],
                        self._n_rl_update_steps_total)

                    # weights and gradients
                    self.tb_logger.writer.add_scalar(
                        'weights/q1_network',
                        list(self.agent.qf1.parameters())[0].mean(),
                        self._n_rl_update_steps_total)
                    if list(self.agent.qf1.parameters())[0].grad is not None:
                        param_list = list(self.agent.qf1.parameters())
                        self.tb_logger.writer.add_scalar(
                            'gradients/q1_network',
                            sum([param_list[i].grad.mean() for i in range(len(param_list))]),
                            self._n_rl_update_steps_total)
                    self.tb_logger.writer.add_scalar(
                        'weights/q1_target',
                        list(self.agent.qf1_target.parameters())[0].mean(),
                        self._n_rl_update_steps_total)
                    if list(self.agent.qf1_target.parameters())[0].grad is not None:
                        param_list = list(self.agent.qf1_target.parameters())
                        self.tb_logger.writer.add_scalar(
                            'gradients/q1_target',
                            sum([param_list[i].grad.mean() for i in range(len(param_list))]),
                            self._n_rl_update_steps_total)
                    self.tb_logger.writer.add_scalar(
                        'weights/q2_network',
                        list(self.agent.qf2.parameters())[0].mean(),
                        self._n_rl_update_steps_total)
                    if list(self.agent.qf2.parameters())[0].grad is not None:
                        param_list = list(self.agent.qf2.parameters())
                        self.tb_logger.writer.add_scalar(
                            'gradients/q2_network',
                            sum([param_list[i].grad.mean() for i in range(len(param_list))]),
                            self._n_rl_update_steps_total)
                    self.tb_logger.writer.add_scalar(
                        'weights/q2_target',
                        list(self.agent.qf2_target.parameters())[0].mean(),
                        self._n_rl_update_steps_total)
                    if list(self.agent.qf2_target.parameters())[0].grad is not None:
                        param_list = list(self.agent.qf2_target.parameters())
                        self.tb_logger.writer.add_scalar(
                            'gradients/q2_target',
                            sum([param_list[i].grad.mean() for i in range(len(param_list))]),
                            self._n_rl_update_steps_total)
                    self.tb_logger.writer.add_scalar(
                        'weights/policy',
                        list(self.agent.policy.parameters())[0].mean(),
                        self._n_rl_update_steps_total)
                    if list(self.agent.policy.parameters())[0].grad is not None:
                        param_list = list(self.agent.policy.parameters())
                        self.tb_logger.writer.add_scalar(
                            'gradients/policy',
                            sum([param_list[i].grad.mean() for i in range(len(param_list))]),
                            self._n_rl_update_steps_total)

                for k, v in [('num_rl_updates', self._n_rl_update_steps_total),
                             ('time_elapsed', time.time() - self._start_time),
                             ('iteration', iteration)]:
                    self.tb_logger.writer.add_scalar(k, v, self._n_rl_update_steps_total)
                self.tb_logger.finish_iteration(iteration)

            print("Iteration -- {}, Success rate -- {:.3f}, Avg. return -- {:.3f}, Elapsed time {:5d}[s]"
                  .format(iteration, np.mean(success_rate),
                          np.mean(np.sum(returns, axis=-1)),
                          int(time.time() - self._start_time)))

    def sample_rl_batch(self, tasks, batch_size):
        ''' sample batch of unordered rl training data from a list/array of tasks '''
        # this batch consists of transitions sampled randomly from replay buffer
        batches = [
            ptu.np_to_pytorch_batch(self.storage.random_batch(task, batch_size))
            for task in tasks
        ]
        unpacked = [utl.unpack_batch(batch) for batch in batches]
        # group elements together
        unpacked = [[x[i] for x in unpacked] for i in range(len(unpacked[0]))]
        unpacked = [torch.cat(x, dim=0) for x in unpacked]
        return unpacked

    def _start_training(self):
        self._n_rl_update_steps_total = 0
        self._start_time = time.time()

    def training_mode(self, mode):
        self.agent.train(mode)

    def update_encoding(self, obs, action, reward, done, hidden_state):
        # reset hidden state of the recurrent net when the task is done
        hidden_state = self.vae.encoder.reset_hidden(hidden_state, done)
        with torch.no_grad():  # size should be (batch, dim)
            task_sample, task_mean, task_logvar, hidden_state = self.vae.encoder(
                actions=action.float(),
                states=obs,
                rewards=reward,
                hidden_state=hidden_state,
                return_prior=False)
        return task_sample, task_mean, task_logvar, hidden_state

    @staticmethod
    def get_augmented_obs(obs, mean, logvar):
        mean = mean.reshape((-1, mean.shape[-1]))
        logvar = logvar.reshape((-1, logvar.shape[-1]))
        return torch.cat((obs, mean, logvar), dim=-1)

    def load_model(self, agent_path, device='cpu'):
        self.agent.load_state_dict(torch.load(agent_path, map_location=device))
        self.load_vae()
        self.training_mode(False)
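# Illustrative sketch (not part of the original file): what get_augmented_obs above does,
# shown on dummy tensors. The policy input is the observation concatenated with the
# task-posterior mean and log-variance produced by the VAE encoder.
import torch

obs = torch.zeros(1, 4)         # batch of 1 observation with 4 features
task_mean = torch.zeros(1, 5)   # posterior mean over a 5-dim latent task variable
task_logvar = torch.zeros(1, 5)

augmented = torch.cat((obs, task_mean, task_logvar), dim=-1)
assert augmented.shape == (1, 4 + 5 + 5)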
class MetaLearner: """ Meta-Learner class with the main training loop for variBAD. """ def __init__(self, args): self.args = args utl.seed(self.args.seed, self.args.deterministic_execution) # count number of frames and number of meta-iterations self.frames = 0 self.iter_idx = 0 # initialise tensorboard logger self.logger = TBLogger(self.args, self.args.exp_label) # initialise environments self.envs = make_vec_envs( env_name=args.env_name, seed=args.seed, num_processes=args.num_processes, gamma=args.policy_gamma, log_dir=args.agent_log_dir, device=device, allow_early_resets=False, episodes_per_task=self.args.max_rollouts_per_task, obs_rms=None, ret_rms=None, ) # calculate what the maximum length of the trajectories is args.max_trajectory_len = self.envs._max_episode_steps args.max_trajectory_len *= self.args.max_rollouts_per_task # calculate number of meta updates self.args.num_updates = int( args.num_frames) // args.policy_num_steps // args.num_processes # get action / observation dimensions if isinstance(self.envs.action_space, gym.spaces.discrete.Discrete): self.args.action_dim = 1 else: self.args.action_dim = self.envs.action_space.shape[0] self.args.obs_dim = self.envs.observation_space.shape[0] self.args.num_states = self.envs.num_states if str.startswith( self.args.env_name, 'Grid') else None self.args.act_space = self.envs.action_space self.vae = VaribadVAE(self.args, self.logger, lambda: self.iter_idx) self.initialise_policy() def initialise_policy(self): # initialise rollout storage for the policy self.policy_storage = OnlineStorage( self.args, self.args.policy_num_steps, self.args.num_processes, self.args.obs_dim, self.args.act_space, hidden_size=self.args.aggregator_hidden_size, latent_dim=self.args.latent_dim, normalise_observations=self.args.norm_obs_for_policy, normalise_rewards=self.args.norm_rew_for_policy, ) # initialise policy network input_dim = self.args.obs_dim * int( self.args.condition_policy_on_state) input_dim += ( 1 + int(not self.args.sample_embeddings)) * self.args.latent_dim if hasattr(self.envs.action_space, 'low'): action_low = self.envs.action_space.low action_high = self.envs.action_space.high else: action_low = action_high = None policy_net = Policy( state_dim=input_dim, action_space=self.args.act_space, init_std=self.args.policy_init_std, hidden_layers=self.args.policy_layers, activation_function=self.args.policy_activation_function, normalise_actions=self.args.normalise_actions, action_low=action_low, action_high=action_high, ).to(device) # initialise policy trainer if self.args.policy == 'a2c': self.policy = A2C( policy_net, self.args.policy_value_loss_coef, self.args.policy_entropy_coef, optimiser_vae=self.vae.optimiser_vae, lr=self.args.lr_policy, eps=self.args.policy_eps, alpha=self.args.a2c_alpha, ) elif self.args.policy == 'ppo': self.policy = PPO( policy_net, self.args.policy_value_loss_coef, self.args.policy_entropy_coef, optimiser_vae=self.vae.optimiser_vae, lr=self.args.lr_policy, eps=self.args.policy_eps, ppo_epoch=self.args.ppo_num_epochs, num_mini_batch=self.args.ppo_num_minibatch, use_huber_loss=self.args.ppo_use_huberloss, use_clipped_value_loss=self.args.ppo_use_clipped_value_loss, clip_param=self.args.ppo_clip_param, ) else: raise NotImplementedError def train(self): """ Given some stream of environments and a logger (tensorboard), (meta-)trains the policy. 
""" start_time = time.time() # reset environments (prev_obs_raw, prev_obs_normalised) = self.envs.reset() prev_obs_raw = prev_obs_raw.to(device) prev_obs_normalised = prev_obs_normalised.to(device) # insert initial observation / embeddings to rollout storage self.policy_storage.prev_obs_raw[0].copy_(prev_obs_raw) self.policy_storage.prev_obs_normalised[0].copy_(prev_obs_normalised) self.policy_storage.to(device) vae_is_pretrained = False for self.iter_idx in range(self.args.num_updates): # First, re-compute the hidden states given the current rollouts (since the VAE might've changed) # compute latent embedding (will return prior if current trajectory is empty) with torch.no_grad(): latent_sample, latent_mean, latent_logvar, hidden_state = self.encode_running_trajectory( ) # check if we flushed the policy storage assert len(self.policy_storage.latent_mean) == 0 # add this initial hidden state to the policy storage self.policy_storage.hidden_states[0].copy_(hidden_state) self.policy_storage.latent_samples.append(latent_sample.clone()) self.policy_storage.latent_mean.append(latent_mean.clone()) self.policy_storage.latent_logvar.append(latent_logvar.clone()) # rollout policies for a few steps for step in range(self.args.policy_num_steps): # sample actions from policy with torch.no_grad(): value, action, action_log_prob = utl.select_action( args=self.args, policy=self.policy, obs=prev_obs_normalised if self.args.norm_obs_for_policy else prev_obs_raw, deterministic=False, latent_sample=latent_sample, latent_mean=latent_mean, latent_logvar=latent_logvar, ) # observe reward and next obs (next_obs_raw, next_obs_normalised), ( rew_raw, rew_normalised), done, infos = utl.env_step( self.envs, action) tasks = torch.FloatTensor([info['task'] for info in infos]).to(device) done = torch.from_numpy(np.array( done, dtype=int)).to(device).float().view((-1, 1)) # create mask for episode ends masks_done = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done]).to(device) # bad_mask is true if episode ended because time limit was reached bad_masks = torch.FloatTensor( [[0.0] if 'bad_transition' in info.keys() else [1.0] for info in infos]).to(device) # compute next embedding (for next loop and/or value prediction bootstrap) latent_sample, latent_mean, latent_logvar, hidden_state = utl.update_encoding( encoder=self.vae.encoder, next_obs=next_obs_raw, action=action, reward=rew_raw, done=done, hidden_state=hidden_state) # before resetting, update the embedding and add to vae buffer # (last state might include useful task info) if not (self.args.disable_decoder and self.args.disable_stochasticity_in_latent): self.vae.rollout_storage.insert(prev_obs_raw.clone(), action.detach().clone(), next_obs_raw.clone(), rew_raw.clone(), done.clone(), tasks.clone()) # add the obs before reset to the policy storage # (only used to recompute embeddings if rlloss is backpropagated through encoder) self.policy_storage.next_obs_raw[step] = next_obs_raw.clone() self.policy_storage.next_obs_normalised[ step] = next_obs_normalised.clone() # reset environments that are done done_indices = np.argwhere( done.cpu().detach().flatten()).flatten() if len(done_indices) == self.args.num_processes: [next_obs_raw, next_obs_normalised] = self.envs.reset() if not self.args.sample_embeddings: latent_sample = latent_sample else: for i in done_indices: [next_obs_raw[i], next_obs_normalised[i]] = self.envs.reset(index=i) if not self.args.sample_embeddings: latent_sample[i] = latent_sample[i] # # add experience to policy buffer 
self.policy_storage.insert( obs_raw=next_obs_raw, obs_normalised=next_obs_normalised, actions=action, action_log_probs=action_log_prob, rewards_raw=rew_raw, rewards_normalised=rew_normalised, value_preds=value, masks=masks_done, bad_masks=bad_masks, done=done, hidden_states=hidden_state.squeeze(0).detach(), latent_sample=latent_sample.detach(), latent_mean=latent_mean.detach(), latent_logvar=latent_logvar.detach(), ) prev_obs_normalised = next_obs_normalised prev_obs_raw = next_obs_raw self.frames += self.args.num_processes # --- UPDATE --- if self.args.precollect_len <= self.frames: # check if we are pre-training the VAE if self.args.pretrain_len > 0 and not vae_is_pretrained: for _ in range(self.args.pretrain_len): self.vae.compute_vae_loss(update=True) vae_is_pretrained = True # otherwise do the normal update (policy + vae) else: train_stats = self.update( obs=prev_obs_normalised if self.args.norm_obs_for_policy else prev_obs_raw, latent_sample=latent_sample, latent_mean=latent_mean, latent_logvar=latent_logvar) # log run_stats = [action, action_log_prob, value] if train_stats is not None: self.log(run_stats, train_stats, start_time) # clean up after update self.policy_storage.after_update() def encode_running_trajectory(self): """ (Re-)Encodes (for each process) the entire current trajectory. Returns sample/mean/logvar and hidden state (if applicable) for the current timestep. :return: """ # for each process, get the current batch (zero-padded obs/act/rew + length indicators) prev_obs, next_obs, act, rew, lens = self.vae.rollout_storage.get_running_batch( ) # get embedding - will return (1+sequence_len) * batch * input_size -- includes the prior! all_latent_samples, all_latent_means, all_latent_logvars, all_hidden_states = self.vae.encoder( actions=act, states=next_obs, rewards=rew, hidden_state=None, return_prior=True) # get the embedding / hidden state of the current time step (need to do this since we zero-padded) latent_sample = (torch.stack([ all_latent_samples[lens[i]][i] for i in range(len(lens)) ])).detach().to(device) latent_mean = (torch.stack([ all_latent_means[lens[i]][i] for i in range(len(lens)) ])).detach().to(device) latent_logvar = (torch.stack([ all_latent_logvars[lens[i]][i] for i in range(len(lens)) ])).detach().to(device) hidden_state = (torch.stack([ all_hidden_states[lens[i]][i] for i in range(len(lens)) ])).detach().to(device) return latent_sample, latent_mean, latent_logvar, hidden_state def get_value(self, obs, latent_sample, latent_mean, latent_logvar): obs = utl.get_augmented_obs(self.args, obs, latent_sample, latent_mean, latent_logvar) return self.policy.actor_critic.get_value(obs).detach() def update(self, obs, latent_sample, latent_mean, latent_logvar): """ Meta-update. Here the policy is updated for good average performance across tasks. :return: """ # update policy (if we are not pre-training, have enough data in the vae buffer, and are not at iteration 0) if self.iter_idx >= self.args.pretrain_len and self.iter_idx > 0: # bootstrap next value prediction with torch.no_grad(): next_value = self.get_value(obs=obs, latent_sample=latent_sample, latent_mean=latent_mean, latent_logvar=latent_logvar) # compute returns for current rollouts self.policy_storage.compute_returns( next_value, self.args.policy_use_gae, self.args.policy_gamma, self.args.policy_tau, use_proper_time_limits=self.args.use_proper_time_limits) # update agent (this will also call the VAE update!) 
policy_train_stats = self.policy.update( args=self.args, policy_storage=self.policy_storage, encoder=self.vae.encoder, rlloss_through_encoder=self.args.rlloss_through_encoder, compute_vae_loss=self.vae.compute_vae_loss) else: policy_train_stats = 0, 0, 0, 0 # pre-train the VAE if self.iter_idx < self.args.pretrain_len: self.vae.compute_vae_loss(update=True) return policy_train_stats, None def log(self, run_stats, train_stats, start_time): train_stats, meta_train_stats = train_stats # --- visualise behaviour of policy --- if self.iter_idx % self.args.vis_interval == 0: obs_rms = self.envs.venv.obs_rms if self.args.norm_obs_for_policy else None ret_rms = self.envs.venv.ret_rms if self.args.norm_rew_for_policy else None utl_eval.visualise_behaviour( args=self.args, policy=self.policy, image_folder=self.logger.full_output_folder, iter_idx=self.iter_idx, obs_rms=obs_rms, ret_rms=ret_rms, encoder=self.vae.encoder, reward_decoder=self.vae.reward_decoder, state_decoder=self.vae.state_decoder, task_decoder=self.vae.task_decoder, compute_rew_reconstruction_loss=self.vae. compute_rew_reconstruction_loss, compute_state_reconstruction_loss=self.vae. compute_state_reconstruction_loss, compute_task_reconstruction_loss=self.vae. compute_task_reconstruction_loss, compute_kl_loss=self.vae.compute_kl_loss, ) # --- evaluate policy ---- if self.iter_idx % self.args.eval_interval == 0: obs_rms = self.envs.venv.obs_rms if self.args.norm_obs_for_policy else None ret_rms = self.envs.venv.ret_rms if self.args.norm_rew_for_policy else None returns_per_episode = utl_eval.evaluate(args=self.args, policy=self.policy, obs_rms=obs_rms, ret_rms=ret_rms, encoder=self.vae.encoder, iter_idx=self.iter_idx) # log the return avg/std across tasks (=processes) returns_avg = returns_per_episode.mean(dim=0) returns_std = returns_per_episode.std(dim=0) for k in range(len(returns_avg)): self.logger.add('return_avg_per_iter/episode_{}'.format(k + 1), returns_avg[k], self.iter_idx) self.logger.add( 'return_avg_per_frame/episode_{}'.format(k + 1), returns_avg[k], self.frames) self.logger.add('return_std_per_iter/episode_{}'.format(k + 1), returns_std[k], self.iter_idx) self.logger.add( 'return_std_per_frame/episode_{}'.format(k + 1), returns_std[k], self.frames) print( "Updates {}, num timesteps {}, FPS {}, {} \n Mean return (train): {:.5f} \n" .format(self.iter_idx, self.frames, int(self.frames / (time.time() - start_time)), self.vae.rollout_storage.prev_obs.shape, returns_avg[-1].item())) # --- save models --- if self.iter_idx % self.args.save_interval == 0: save_path = os.path.join(self.logger.full_output_folder, 'models') if not os.path.exists(save_path): os.mkdir(save_path) torch.save( self.policy.actor_critic, os.path.join(save_path, "policy{0}.pt".format(self.iter_idx))) torch.save( self.vae.encoder, os.path.join(save_path, "encoder{0}.pt".format(self.iter_idx))) if self.vae.state_decoder is not None: torch.save( self.vae.state_decoder, os.path.join(save_path, "state_decoder{0}.pt".format(self.iter_idx))) if self.vae.reward_decoder is not None: torch.save( self.vae.reward_decoder, os.path.join(save_path, "reward_decoder{0}.pt".format(self.iter_idx))) if self.vae.task_decoder is not None: torch.save( self.vae.task_decoder, os.path.join(save_path, "task_decoder{0}.pt".format(self.iter_idx))) # save normalisation params of envs if self.args.norm_rew_for_policy: # save rolling mean and std rew_rms = self.envs.venv.ret_rms utl.save_obj(rew_rms, save_path, "env_rew_rms{0}.pkl".format(self.iter_idx)) if self.args.norm_obs_for_policy: 
obs_rms = self.envs.venv.obs_rms utl.save_obj(obs_rms, save_path, "env_obs_rms{0}.pkl".format(self.iter_idx)) # --- log some other things --- if self.iter_idx % self.args.log_interval == 0: self.logger.add('policy_losses/value_loss', train_stats[0], self.iter_idx) self.logger.add('policy_losses/action_loss', train_stats[1], self.iter_idx) self.logger.add('policy_losses/dist_entropy', train_stats[2], self.iter_idx) self.logger.add('policy_losses/sum', train_stats[3], self.iter_idx) self.logger.add('policy/action', run_stats[0][0].float().mean(), self.iter_idx) if hasattr(self.policy.actor_critic, 'logstd'): self.logger.add('policy/action_logstd', self.policy.actor_critic.dist.logstd.mean(), self.iter_idx) self.logger.add('policy/action_logprob', run_stats[1].mean(), self.iter_idx) self.logger.add('policy/value', run_stats[2].mean(), self.iter_idx) self.logger.add('encoder/latent_mean', torch.cat(self.policy_storage.latent_mean).mean(), self.iter_idx) self.logger.add( 'encoder/latent_logvar', torch.cat(self.policy_storage.latent_logvar).mean(), self.iter_idx) # log the average weights and gradients of all models (where applicable) for [model, name ] in [[self.policy.actor_critic, 'policy'], [self.vae.encoder, 'encoder'], [self.vae.reward_decoder, 'reward_decoder'], [self.vae.state_decoder, 'state_transition_decoder'], [self.vae.task_decoder, 'task_decoder']]: if model is not None: param_list = list(model.parameters()) param_mean = np.mean([ param_list[i].data.cpu().numpy().mean() for i in range(len(param_list)) ]) self.logger.add('weights/{}'.format(name), param_mean, self.iter_idx) if name == 'policy': self.logger.add('weights/policy_std', param_list[0].data.mean(), self.iter_idx) if param_list[0].grad is not None: param_grad_mean = np.mean([ param_list[i].grad.cpu().numpy().mean() for i in range(len(param_list)) ]) self.logger.add('gradients/{}'.format(name), param_grad_mean, self.iter_idx) def load_and_render(self, load_iter): #save_path = os.path.join('/ext/varibad_github/v2/varibad/logs/logs_HalfCheetahJoint-v0/varibad_73__15:05_17:14:07', 'models') #save_path = os.path.join('/ext/varibad_github/v2/varibad/logs/hfield', 'models') save_path = os.path.join( '/ext/varibad_github/v2/varibad/logs/logs_HalfCheetahBlocks-v0/varibad_73__15:05_20:20:25', 'models') self.policy.actor_critic = torch.load( os.path.join(save_path, "policy{0}.pt".format(load_iter))) self.vae.encoder = torch.load( os.path.join(save_path, "encoder{0}.pt").format(load_iter)) args = self.args device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") num_processes = 1 num_episodes = 100 num_steps = 1999 #import pdb; pdb.set_trace() # initialise environments envs = make_vec_envs( env_name=args.env_name, seed=args.seed, num_processes=num_processes, # 1 gamma=args.policy_gamma, log_dir=args.agent_log_dir, device=device, allow_early_resets=False, episodes_per_task=self.args.max_rollouts_per_task, obs_rms=None, ret_rms=None, ) # reset latent state to prior latent_sample, latent_mean, latent_logvar, hidden_state = self.vae.encoder.prior( num_processes) for episode_idx in range(num_episodes): (prev_obs_raw, prev_obs_normalised) = envs.reset() prev_obs_raw = prev_obs_raw.to(device) prev_obs_normalised = prev_obs_normalised.to(device) for step_idx in range(num_steps): with torch.no_grad(): _, action, _ = utl.select_action( args=self.args, policy=self.policy, obs=prev_obs_normalised if self.args.norm_obs_for_policy else prev_obs_raw, latent_sample=latent_sample, latent_mean=latent_mean, latent_logvar=latent_logvar, 
deterministic=True) # observe reward and next obs (next_obs_raw, next_obs_normalised), ( rew_raw, rew_normalised), done, infos = utl.env_step(envs, action) # render envs.venv.venv.envs[0].env.env.env.env.render() # update the hidden state latent_sample, latent_mean, latent_logvar, hidden_state = utl.update_encoding( encoder=self.vae.encoder, next_obs=next_obs_raw, action=action, reward=rew_raw, done=None, hidden_state=hidden_state) prev_obs_normalised = next_obs_normalised prev_obs_raw = next_obs_raw if done[0]: break
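# Illustrative sketch (not part of the original file): the indexing pattern used by
# encode_running_trajectory in the MetaLearner above. The encoder returns per-timestep
# encodings of shape (1 + seq_len, num_processes, dim) for zero-padded trajectories
# (index 0 holds the prior); each process i keeps the encoding at its own length lens[i].
import torch

seq_len, num_processes, dim = 5, 3, 4
all_latent_means = torch.randn(1 + seq_len, num_processes, dim)
lens = [0, 2, 5]  # current trajectory length of each process

latent_mean = torch.stack([all_latent_means[lens[i]][i] for i in range(len(lens))])
assert latent_mean.shape == (num_processes, dim)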
class Learner:
    """
    Learner (no meta-learning), can be used to train avg/oracle/belief-oracle policies.
    """

    def __init__(self, args):
        self.args = args
        utl.seed(self.args.seed, self.args.deterministic_execution)

        # calculate number of updates and keep count of frames/iterations
        self.num_updates = int(args.num_frames) // args.policy_num_steps // args.num_processes
        self.frames = 0
        self.iter_idx = 0

        # initialise tensorboard logger
        self.logger = TBLogger(self.args, self.args.exp_label)

        # initialise environments
        self.envs = make_vec_envs(
            env_name=args.env_name,
            seed=args.seed,
            num_processes=args.num_processes,
            gamma=args.policy_gamma,
            device=device,
            episodes_per_task=self.args.max_rollouts_per_task,
            normalise_rew=args.norm_rew_for_policy,
            ret_rms=None,
        )

        # calculate what the maximum length of the trajectories is
        args.max_trajectory_len = self.envs._max_episode_steps
        args.max_trajectory_len *= self.args.max_rollouts_per_task

        # get policy input dimensions
        self.args.state_dim = self.envs.observation_space.shape[0]
        self.args.task_dim = self.envs.task_dim
        self.args.belief_dim = self.envs.belief_dim
        self.args.num_states = self.envs.num_states

        # get policy output (action) dimensions
        self.args.action_space = self.envs.action_space
        if isinstance(self.envs.action_space, gym.spaces.discrete.Discrete):
            self.args.action_dim = 1
        else:
            self.args.action_dim = self.envs.action_space.shape[0]

        # initialise policy
        self.policy_storage = self.initialise_policy_storage()
        self.policy = self.initialise_policy()

    def initialise_policy_storage(self):
        return OnlineStorage(
            args=self.args,
            num_steps=self.args.policy_num_steps,
            num_processes=self.args.num_processes,
            state_dim=self.args.state_dim,
            latent_dim=0,  # use metalearner.py if you want to use the VAE
            belief_dim=self.args.belief_dim,
            task_dim=self.args.task_dim,
            action_space=self.args.action_space,
            hidden_size=0,
            normalise_rewards=self.args.norm_rew_for_policy,
        )

    def initialise_policy(self):

        if hasattr(self.envs.action_space, 'low'):
            action_low = self.envs.action_space.low
            action_high = self.envs.action_space.high
        else:
            action_low = action_high = None

        # initialise policy network
        policy_net = Policy(
            args=self.args,
            #
            pass_state_to_policy=self.args.pass_state_to_policy,
            pass_latent_to_policy=False,  # use metalearner.py if you want to use the VAE
            pass_belief_to_policy=self.args.pass_belief_to_policy,
            pass_task_to_policy=self.args.pass_task_to_policy,
            dim_state=self.args.state_dim,
            dim_latent=0,
            dim_belief=self.args.belief_dim,
            dim_task=self.args.task_dim,
            #
            hidden_layers=self.args.policy_layers,
            activation_function=self.args.policy_activation_function,
            policy_initialisation=self.args.policy_initialisation,
            #
            action_space=self.envs.action_space,
            init_std=self.args.policy_init_std,
            norm_actions_of_policy=self.args.norm_actions_of_policy,
            action_low=action_low,
            action_high=action_high,
        ).to(device)

        # initialise policy trainer
        if self.args.policy == 'a2c':
            policy = A2C(
                self.args,
                policy_net,
                self.args.policy_value_loss_coef,
                self.args.policy_entropy_coef,
                policy_optimiser=self.args.policy_optimiser,
                policy_anneal_lr=self.args.policy_anneal_lr,
                train_steps=self.num_updates,
                lr=self.args.lr_policy,
                eps=self.args.policy_eps,
            )
        elif self.args.policy == 'ppo':
            policy = PPO(
                self.args,
                policy_net,
                self.args.policy_value_loss_coef,
                self.args.policy_entropy_coef,
                policy_optimiser=self.args.policy_optimiser,
                policy_anneal_lr=self.args.policy_anneal_lr,
                train_steps=self.num_updates,
                lr=self.args.lr_policy,
                eps=self.args.policy_eps,
                ppo_epoch=self.args.ppo_num_epochs,
                num_mini_batch=self.args.ppo_num_minibatch,
                use_huber_loss=self.args.ppo_use_huberloss,
                use_clipped_value_loss=self.args.ppo_use_clipped_value_loss,
                clip_param=self.args.ppo_clip_param,
            )
        else:
            raise NotImplementedError

        return policy

    def train(self):
        """ Main training loop """
        start_time = time.time()

        # reset environments
        state, belief, task = utl.reset_env(self.envs, self.args)

        # insert initial observation / embeddings to rollout storage
        self.policy_storage.prev_state[0].copy_(state)

        # log once before training
        with torch.no_grad():
            self.log(None, None, start_time)

        for self.iter_idx in range(self.num_updates):

            # rollout policies for a few steps
            for step in range(self.args.policy_num_steps):

                # sample actions from policy
                with torch.no_grad():
                    value, action, action_log_prob = utl.select_action(
                        args=self.args,
                        policy=self.policy,
                        state=state,
                        belief=belief,
                        task=task,
                        deterministic=False)

                # observe reward and next obs
                [state, belief, task], (rew_raw, rew_normalised), done, infos = utl.env_step(
                    self.envs, action, self.args)

                # create mask for episode ends
                masks_done = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done]).to(device)
                # bad_mask is true if episode ended because time limit was reached
                bad_masks = torch.FloatTensor(
                    [[0.0] if 'bad_transition' in info.keys() else [1.0] for info in infos]).to(device)

                # reset environments that are done
                done_indices = np.argwhere(done.flatten()).flatten()
                if len(done_indices) > 0:
                    state, belief, task = utl.reset_env(self.envs, self.args,
                                                        indices=done_indices, state=state)

                # add experience to policy buffer
                self.policy_storage.insert(
                    state=state,
                    belief=belief,
                    task=task,
                    actions=action,
                    action_log_probs=action_log_prob,
                    rewards_raw=rew_raw,
                    rewards_normalised=rew_normalised,
                    value_preds=value,
                    masks=masks_done,
                    bad_masks=bad_masks,
                    done=torch.from_numpy(np.array(done, dtype=float)).unsqueeze(1),
                )

                self.frames += self.args.num_processes

            # --- UPDATE ---

            train_stats = self.update(state=state, belief=belief, task=task)

            # log
            run_stats = [action, action_log_prob, value]
            if train_stats is not None:
                with torch.no_grad():
                    self.log(run_stats, train_stats, start_time)

            # clean up after update
            self.policy_storage.after_update()

    def get_value(self, state, belief, task):
        return self.policy.actor_critic.get_value(state=state, belief=belief, task=task, latent=None).detach()

    def update(self, state, belief, task):
        """
        Meta-update.
        Here the policy is updated for good average performance across tasks.
        :return: policy_train_stats which are: value_loss_epoch, action_loss_epoch, dist_entropy_epoch, loss_epoch
        """
        # bootstrap next value prediction
        with torch.no_grad():
            next_value = self.get_value(state=state, belief=belief, task=task)

        # compute returns for current rollouts
        self.policy_storage.compute_returns(
            next_value, self.args.policy_use_gae, self.args.policy_gamma,
            self.args.policy_tau,
            use_proper_time_limits=self.args.use_proper_time_limits)

        policy_train_stats = self.policy.update(policy_storage=self.policy_storage)

        return policy_train_stats, None

    def log(self, run_stats, train_stats, start):
        """
        Evaluate policy, save model, write to tensorboard logger.
        """

        # --- visualise behaviour of policy ---
        if self.iter_idx % self.args.vis_interval == 0:
            ret_rms = self.envs.venv.ret_rms if self.args.norm_rew_for_policy else None
            utl_eval.visualise_behaviour(
                args=self.args,
                policy=self.policy,
                image_folder=self.logger.full_output_folder,
                iter_idx=self.iter_idx,
                ret_rms=ret_rms,
            )

        # --- evaluate policy ----
        if self.iter_idx % self.args.eval_interval == 0:
            ret_rms = self.envs.venv.ret_rms if self.args.norm_rew_for_policy else None
            returns_per_episode = utl_eval.evaluate(args=self.args,
                                                    policy=self.policy,
                                                    ret_rms=ret_rms,
                                                    iter_idx=self.iter_idx)

            # log the average return across tasks (=processes)
            returns_avg = returns_per_episode.mean(dim=0)
            returns_std = returns_per_episode.std(dim=0)
            for k in range(len(returns_avg)):
                self.logger.add('return_avg_per_iter/episode_{}'.format(k + 1), returns_avg[k], self.iter_idx)
                self.logger.add('return_avg_per_frame/episode_{}'.format(k + 1), returns_avg[k], self.frames)
                self.logger.add('return_std_per_iter/episode_{}'.format(k + 1), returns_std[k], self.iter_idx)
                self.logger.add('return_std_per_frame/episode_{}'.format(k + 1), returns_std[k], self.frames)

            print("Updates {}, num timesteps {}, FPS {} \n Mean return (train): {:.5f} \n"
                  .format(self.iter_idx, self.frames,
                          int(self.frames / (time.time() - start)),
                          returns_avg[-1].item()))

        # save model
        if self.iter_idx % self.args.save_interval == 0:
            save_path = os.path.join(self.logger.full_output_folder, 'models')
            if not os.path.exists(save_path):
                os.mkdir(save_path)

            idx_labels = ['']
            if self.args.save_intermediate_models:
                idx_labels.append(int(self.iter_idx))

            for idx_label in idx_labels:
                torch.save(self.policy.actor_critic, os.path.join(save_path, f"policy{idx_label}.pt"))

                # save normalisation params of envs
                if self.args.norm_rew_for_policy:
                    rew_rms = self.envs.venv.ret_rms
                    utl.save_obj(rew_rms, save_path, f"env_rew_rms{idx_label}")
                # TODO: grab from policy and save?
                # if self.args.norm_obs_for_policy:
                #     obs_rms = self.envs.venv.obs_rms
                #     utl.save_obj(obs_rms, save_path, f"env_obs_rms{idx_label}")

        # --- log some other things ---
        if (self.iter_idx % self.args.log_interval == 0) and (train_stats is not None):

            train_stats, _ = train_stats

            self.logger.add('policy_losses/value_loss', train_stats[0], self.iter_idx)
            self.logger.add('policy_losses/action_loss', train_stats[1], self.iter_idx)
            self.logger.add('policy_losses/dist_entropy', train_stats[2], self.iter_idx)
            self.logger.add('policy_losses/sum', train_stats[3], self.iter_idx)

            # writer.add_scalar('policy/action', action.mean(), j)
            self.logger.add('policy/action', run_stats[0][0].float().mean(), self.iter_idx)
            if hasattr(self.policy.actor_critic, 'logstd'):
                self.logger.add('policy/action_logstd', self.policy.actor_critic.dist.logstd.mean(), self.iter_idx)
            self.logger.add('policy/action_logprob', run_stats[1].mean(), self.iter_idx)
            self.logger.add('policy/value', run_stats[2].mean(), self.iter_idx)

            param_list = list(self.policy.actor_critic.parameters())
            param_mean = np.mean([param_list[i].data.cpu().numpy().mean() for i in range(len(param_list))])
            param_grad_mean = np.mean([param_list[i].grad.cpu().numpy().mean() for i in range(len(param_list))])
            self.logger.add('weights/policy', param_mean, self.iter_idx)
            self.logger.add('weights/policy_std', param_list[0].data.cpu().mean(), self.iter_idx)
            self.logger.add('gradients/policy', param_grad_mean, self.iter_idx)
            self.logger.add('gradients/policy_std', param_list[0].grad.cpu().numpy().mean(), self.iter_idx)
def __init__(self, args):
    self.args = args
    utl.seed(self.args.seed, self.args.deterministic_execution)

    # calculate number of updates and keep count of frames/iterations
    self.num_updates = int(args.num_frames) // args.policy_num_steps // args.num_processes
    self.frames = 0
    self.iter_idx = -1

    # initialise tensorboard logger
    self.logger = TBLogger(self.args, self.args.exp_label)

    # initialise environments
    self.envs = make_vec_envs(
        env_name=args.env_name,
        seed=args.seed,
        num_processes=args.num_processes,
        gamma=args.policy_gamma,
        device=device,
        episodes_per_task=self.args.max_rollouts_per_task,
        normalise_rew=args.norm_rew_for_policy,
        ret_rms=None,
        tasks=None)

    if self.args.single_task_mode:
        # get the current tasks (which will be num_process many different tasks)
        self.train_tasks = self.envs.get_task()
        # set the tasks to the first task (i.e. just a random task)
        self.train_tasks[1:] = self.train_tasks[0]
        # make it a list
        self.train_tasks = [t for t in self.train_tasks]
        # re-initialise environments with those tasks
        self.envs = make_vec_envs(
            env_name=args.env_name,
            seed=args.seed,
            num_processes=args.num_processes,
            gamma=args.policy_gamma,
            device=device,
            episodes_per_task=self.args.max_rollouts_per_task,
            normalise_rew=args.norm_rew_for_policy,
            ret_rms=None,
            tasks=self.train_tasks,
        )
        # save the training tasks so we can evaluate on the same envs later
        utl.save_obj(self.train_tasks, self.logger.full_output_folder, "train_tasks")
    else:
        self.train_tasks = None

    # calculate what the maximum length of the trajectories is
    args.max_trajectory_len = self.envs._max_episode_steps
    args.max_trajectory_len *= self.args.max_rollouts_per_task

    # get policy input dimensions
    self.args.state_dim = self.envs.observation_space.shape[0]
    self.args.task_dim = self.envs.task_dim
    self.args.belief_dim = self.envs.belief_dim
    self.args.num_states = self.envs.num_states

    # get policy output (action) dimensions
    self.args.action_space = self.envs.action_space
    if isinstance(self.envs.action_space, gym.spaces.discrete.Discrete):
        self.args.action_dim = 1
    else:
        self.args.action_dim = self.envs.action_space.shape[0]

    # initialise policy
    self.policy_storage = self.initialise_policy_storage()
    self.policy = self.initialise_policy()
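# Illustrative sketch (not part of the original file): what single_task_mode does to the
# per-process tasks above. Every process is forced onto the first sampled task, so the
# learner trains on a single task despite running several parallel environments.
import numpy as np

train_tasks = np.array([[0.3, 0.7], [0.9, 0.1], [0.5, 0.5]])  # one task per process
train_tasks[1:] = train_tasks[0]        # collapse to the first task
train_tasks = [t for t in train_tasks]  # make it a list
assert all((t == train_tasks[0]).all() for t in train_tasks)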
def __init__(self, args): """ Seeds everything. Initialises: logger, environments, policy (+storage +optimiser). """ self.args = args # make sure everything has the same seed utl.seed(self.args.seed) # initialise environment self.env = make_env( self.args.env_name, self.args.max_rollouts_per_task, seed=self.args.seed, n_tasks=1, modify_init_state_dist=self.args.modify_init_state_dist if 'modify_init_state_dist' in self.args else False, on_circle_init_state=self.args.on_circle_init_state if 'on_circle_init_state' in self.args else True) # saving buffer with task in name folder if hasattr(self.args, 'save_buffer') and self.args.save_buffer: env_dir = os.path.join(self.args.main_save_dir, '{}'.format(self.args.env_name)) goal = self.env.unwrapped._goal self.output_dir = os.path.join( env_dir, self.args.save_dir, 'seed_{}_'.format(self.args.seed) + off_utl.create_goal_path_ext_from_goal(goal)) if self.args.save_models or self.args.save_buffer: os.makedirs(self.output_dir, exist_ok=True) config_utl.save_config_file(args, self.output_dir) # initialize tensorboard logger if self.args.log_tensorboard: self.tb_logger = TBLogger(self.args) # if not self.args.log_tensorboard: # self.save_config_json_file() # unwrapped env to get some info about the environment unwrapped_env = self.env.unwrapped # calculate what the maximum length of the trajectories is args.max_trajectory_len = unwrapped_env._max_episode_steps args.max_trajectory_len *= self.args.max_rollouts_per_task self.args.max_trajectory_len = args.max_trajectory_len # get action / observation dimensions if isinstance(self.env.action_space, gym.spaces.discrete.Discrete): self.args.action_dim = 1 else: self.args.action_dim = self.env.action_space.shape[0] self.args.obs_dim = self.env.observation_space.shape[0] self.args.num_states = unwrapped_env.num_states if hasattr( unwrapped_env, 'num_states') else None self.args.act_space = self.env.action_space # simulate env step to get reward types _, _, _, info = unwrapped_env.step(unwrapped_env.action_space.sample()) reward_types = [ reward_type for reward_type in list(info.keys()) if reward_type.startswith('reward') ] # support dense rewards training (if exists) self.args.dense_train_sparse_test = self.args.dense_train_sparse_test \ if 'dense_train_sparse_test' in self.args else False # initialize policy self.initialize_policy() # initialize buffer for RL updates self.policy_storage = MultiTaskPolicyStorage( max_replay_buffer_size=int(self.args.policy_buffer_size), obs_dim=self.args.obs_dim, action_space=self.env.action_space, tasks=[0], trajectory_len=args.max_trajectory_len, num_reward_arrays=len(reward_types) if reward_types and self.args.dense_train_sparse_test else 1, reward_types=reward_types, ) self.args.belief_reward = False # initialize arg to not use belief rewards
class Learner: """ Learner (no meta-learning), can be used to train Oracle policies. """ def __init__(self, args): self.args = args # make sure everything has the same seed utl.seed(self.args.seed, self.args.deterministic_execution) # initialise tensorboard logger self.logger = TBLogger(self.args, self.args.exp_label) # initialise environments self.envs = make_vec_envs( env_name=args.env_name, seed=args.seed, num_processes=args.num_processes, gamma=args.policy_gamma, log_dir=args.agent_log_dir, device=device, allow_early_resets=False, episodes_per_task=self.args.max_rollouts_per_task, obs_rms=None, ret_rms=None, ) # calculate what the maximum length of the trajectories is args.max_trajectory_len = self.envs._max_episode_steps args.max_trajectory_len *= self.args.max_rollouts_per_task # calculate number of meta updates self.args.num_updates = int( args.num_frames) // args.policy_num_steps // args.num_processes # get action / observation dimensions if isinstance(self.envs.action_space, gym.spaces.discrete.Discrete): self.args.action_dim = 1 else: self.args.action_dim = self.envs.action_space.shape[0] self.args.obs_dim = self.envs.observation_space.shape[0] self.args.num_states = self.envs.num_states if str.startswith( self.args.env_name, 'Grid') else None self.args.act_space = self.envs.action_space self.initialise_policy() # count number of frames and updates self.frames = 0 self.iter_idx = 0 def initialise_policy(self): # variables for task encoder (used for oracle) state_dim = self.envs.observation_space.shape[0] # TODO: this isn't ideal, find a nicer way to get the task dimension! if 'BeliefOracle' in self.args.env_name: task_dim = gym.make(self.args.env_name).observation_space.shape[0] - \ gym.make(self.args.env_name.replace('BeliefOracle', '')).observation_space.shape[0] latent_dim = self.args.latent_dim state_embedding_size = self.args.state_embedding_size use_task_encoder = True elif 'Oracle' in self.args.env_name: task_dim = gym.make(self.args.env_name).observation_space.shape[0] - \ gym.make(self.args.env_name.replace('Oracle', '')).observation_space.shape[0] latent_dim = self.args.latent_dim state_embedding_size = self.args.state_embedding_size use_task_encoder = True else: task_dim = latent_dim = state_embedding_size = 0 use_task_encoder = False # initialise rollout storage for the policy self.policy_storage = OnlineStorage( self.args, self.args.policy_num_steps, self.args.num_processes, self.args.obs_dim, self.args.act_space, hidden_size=0, latent_dim=self.args.latent_dim, normalise_observations=self.args.norm_obs_for_policy, normalise_rewards=self.args.norm_rew_for_policy, ) if hasattr(self.envs.action_space, 'low'): action_low = self.envs.action_space.low action_high = self.envs.action_space.high else: action_low = action_high = None # initialise policy network policy_net = Policy( # general state_dim=int(self.args.condition_policy_on_state) * state_dim, action_space=self.envs.action_space, init_std=self.args.policy_init_std, hidden_layers=self.args.policy_layers, activation_function=self.args.policy_activation_function, use_task_encoder=use_task_encoder, # task encoding things (for oracle) task_dim=task_dim, latent_dim=latent_dim, state_embed_dim=state_embedding_size, # normalise_actions=self.args.normalise_actions, action_low=action_low, action_high=action_high, ).to(device) # initialise policy if self.args.policy == 'a2c': # initialise policy trainer (A2C) self.policy = A2C( policy_net, self.args.policy_value_loss_coef, self.args.policy_entropy_coef, lr=self.args.lr_policy, 
eps=self.args.policy_eps, alpha=self.args.a2c_alpha, ) elif self.args.policy == 'ppo': # initialise policy network self.policy = PPO( policy_net, self.args.policy_value_loss_coef, self.args.policy_entropy_coef, lr=self.args.lr_policy, eps=self.args.policy_eps, ppo_epoch=self.args.ppo_num_epochs, num_mini_batch=self.args.ppo_num_minibatch, use_huber_loss=self.args.ppo_use_huberloss, use_clipped_value_loss=self.args.ppo_use_clipped_value_loss, clip_param=self.args.ppo_clip_param, ) else: raise NotImplementedError def train(self): """ Given some stream of environments and a logger (tensorboard), (meta-)trains the policy. """ start_time = time.time() # reset environments (prev_obs_raw, prev_obs_normalised) = self.envs.reset() prev_obs_raw = prev_obs_raw.to(device) prev_obs_normalised = prev_obs_normalised.to(device) # insert initial observation / embeddings to rollout storage self.policy_storage.prev_obs_raw[0].copy_(prev_obs_raw) self.policy_storage.prev_obs_normalised[0].copy_(prev_obs_normalised) self.policy_storage.to(device) for self.iter_idx in range(self.args.num_updates): # check if we flushed the policy storage assert len(self.policy_storage.latent_mean) == 0 # rollouts policies for a few steps for step in range(self.args.policy_num_steps): # sample actions from policy with torch.no_grad(): value, action, action_log_prob = utl.select_action( policy=self.policy, args=self.args, obs=prev_obs_normalised if self.args.norm_obs_for_policy else prev_obs_raw, deterministic=False) # observe reward and next obs (next_obs_raw, next_obs_normalised), ( rew_raw, rew_normalised), done, infos = utl.env_step( self.envs, action) action = action.float() # create mask for episode ends masks_done = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done]).to(device) # bad_mask is true if episode ended because time limit was reached bad_masks = torch.FloatTensor( [[0.0] if 'bad_transition' in info.keys() else [1.0] for info in infos]).to(device) # add the obs before reset to the policy storage self.policy_storage.next_obs_raw[step] = next_obs_raw.clone() self.policy_storage.next_obs_normalised[ step] = next_obs_normalised.clone() # reset environments that are done done_indices = np.argwhere(done.flatten()).flatten() if len(done_indices) == self.args.num_processes: [next_obs_raw, next_obs_normalised] = self.envs.reset() if not self.args.sample_embeddings: latent_sample = latent_sample else: for i in done_indices: [next_obs_raw[i], next_obs_normalised[i]] = self.envs.reset(index=i) if not self.args.sample_embeddings: latent_sample[i] = latent_sample[i] # add experience to policy buffer self.policy_storage.insert( obs_raw=next_obs_raw.clone(), obs_normalised=next_obs_normalised.clone(), actions=action.clone(), action_log_probs=action_log_prob.clone(), rewards_raw=rew_raw.clone(), rewards_normalised=rew_normalised.clone(), value_preds=value.clone(), masks=masks_done.clone(), bad_masks=bad_masks.clone(), done=torch.from_numpy(np.array( done, dtype=float)).unsqueeze(1).clone(), ) prev_obs_normalised = next_obs_normalised prev_obs_raw = next_obs_raw self.frames += self.args.num_processes # --- UPDATE --- train_stats = self.update(prev_obs_normalised if self.args. 
norm_obs_for_policy else prev_obs_raw) # log run_stats = [action, action_log_prob, value] if train_stats is not None: self.log(run_stats, train_stats, start_time) # clean up after update self.policy_storage.after_update() def get_value(self, obs): obs = utl.get_augmented_obs(args=self.args, obs=obs) return self.policy.actor_critic.get_value(obs).detach() def update(self, obs): """ Meta-update. Here the policy is updated for good average performance across tasks. :return: policy_train_stats which are: value_loss_epoch, action_loss_epoch, dist_entropy_epoch, loss_epoch """ # bootstrap next value prediction with torch.no_grad(): next_value = self.get_value(obs) # compute returns for current rollouts self.policy_storage.compute_returns( next_value, self.args.policy_use_gae, self.args.policy_gamma, self.args.policy_tau, use_proper_time_limits=self.args.use_proper_time_limits) policy_train_stats = self.policy.update( args=self.args, policy_storage=self.policy_storage) return policy_train_stats, None def log(self, run_stats, train_stats, start): """ Evaluate policy, save model, write to tensorboard logger. """ train_stats, meta_train_stats = train_stats # --- visualise behaviour of policy --- if self.iter_idx % self.args.vis_interval == 0: obs_rms = self.envs.venv.obs_rms if self.args.norm_obs_for_policy else None ret_rms = self.envs.venv.ret_rms if self.args.norm_rew_for_policy else None utl_eval.visualise_behaviour( args=self.args, policy=self.policy, image_folder=self.logger.full_output_folder, iter_idx=self.iter_idx, obs_rms=obs_rms, ret_rms=ret_rms, ) # --- evaluate policy ---- if self.iter_idx % self.args.eval_interval == 0: obs_rms = self.envs.venv.obs_rms if self.args.norm_obs_for_policy else None ret_rms = self.envs.venv.ret_rms if self.args.norm_rew_for_policy else None returns_per_episode = utl_eval.evaluate(args=self.args, policy=self.policy, obs_rms=obs_rms, ret_rms=ret_rms, iter_idx=self.iter_idx) # log the average return across tasks (=processes) returns_avg = returns_per_episode.mean(dim=0) returns_std = returns_per_episode.std(dim=0) for k in range(len(returns_avg)): self.logger.add('return_avg_per_iter/episode_{}'.format(k + 1), returns_avg[k], self.iter_idx) self.logger.add( 'return_avg_per_frame/episode_{}'.format(k + 1), returns_avg[k], self.frames) self.logger.add('return_std_per_iter/episode_{}'.format(k + 1), returns_std[k], self.iter_idx) self.logger.add( 'return_std_per_frame/episode_{}'.format(k + 1), returns_std[k], self.frames) print( "Updates {}, num timesteps {}, FPS {} \n Mean return (train): {:.5f} \n" .format(self.iter_idx, self.frames, int(self.frames / (time.time() - start)), returns_avg[-1].item())) # save model if self.iter_idx % self.args.save_interval == 0: save_path = os.path.join(self.logger.full_output_folder, 'models') if not os.path.exists(save_path): os.mkdir(save_path) torch.save( self.policy.actor_critic, os.path.join(save_path, "policy{0}.pt".format(self.iter_idx))) # save normalisation params of envs if self.args.norm_rew_for_policy: # save rolling mean and std rew_rms = self.envs.venv.ret_rms utl.save_obj(rew_rms, save_path, "env_rew_rms{0}.pkl".format(self.iter_idx)) if self.args.norm_obs_for_policy: obs_rms = self.envs.venv.obs_rms utl.save_obj(obs_rms, save_path, "env_obs_rms{0}.pkl".format(self.iter_idx)) # --- log some other things --- if self.iter_idx % self.args.log_interval == 0: self.logger.add('policy_losses/value_loss', train_stats[0], self.iter_idx) self.logger.add('policy_losses/action_loss', train_stats[1], self.iter_idx) 
            self.logger.add('policy_losses/dist_entropy', train_stats[2],
                            self.iter_idx)
            self.logger.add('policy_losses/sum', train_stats[3],
                            self.iter_idx)
            # writer.add_scalar('policy/action', action.mean(), j)
            self.logger.add('policy/action', run_stats[0][0].float().mean(),
                            self.iter_idx)
            if hasattr(self.policy.actor_critic, 'logstd'):
                self.logger.add('policy/action_logstd',
                                self.policy.actor_critic.dist.logstd.mean(),
                                self.iter_idx)
            self.logger.add('policy/action_logprob', run_stats[1].mean(),
                            self.iter_idx)
            self.logger.add('policy/value', run_stats[2].mean(),
                            self.iter_idx)

            param_list = list(self.policy.actor_critic.parameters())
            param_mean = np.mean(
                [param_list[i].data.mean() for i in range(len(param_list))])
            param_grad_mean = np.mean(
                [param_list[i].grad.mean() for i in range(len(param_list))])
            self.logger.add('weights/policy', param_mean, self.iter_idx)
            self.logger.add('weights/policy_std', param_list[0].data.mean(),
                            self.iter_idx)
            self.logger.add('gradients/policy', param_grad_mean,
                            self.iter_idx)
            self.logger.add('gradients/policy_std', param_list[0].grad.mean(),
                            self.iter_idx)
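The `update` method of this Learner bootstraps a value prediction for the last observation and then lets the rollout storage compute returns, optionally with GAE (controlled by `policy_use_gae`, `policy_gamma` and `policy_tau`). `OnlineStorage.compute_returns` is defined elsewhere, so the snippet below is only a minimal sketch of what such a GAE-style return computation typically looks like; the function name and tensor layout are illustrative assumptions, not the storage's actual API.

import torch

def compute_gae_returns_sketch(rewards, values, next_value, masks, gamma=0.99, tau=0.95):
    """Illustrative GAE return computation (not the real OnlineStorage.compute_returns).

    rewards, values, masks: [T, num_processes, 1]; masks[t] is 0 where the episode ended at step t.
    next_value: bootstrapped value for the state after the final step, [num_processes, 1].
    """
    T = rewards.shape[0]
    values = torch.cat([values, next_value.unsqueeze(0)], dim=0)  # [T+1, P, 1]
    returns = torch.zeros_like(rewards)
    gae = torch.zeros_like(next_value)
    for t in reversed(range(T)):
        delta = rewards[t] + gamma * values[t + 1] * masks[t] - values[t]  # TD error
        gae = delta + gamma * tau * masks[t] * gae
        returns[t] = gae + values[t]
    return returns

With tau close to 1 this approaches Monte-Carlo returns; with tau = 0 it reduces to one-step TD targets, which is the bias/variance trade-off the `policy_tau` argument controls.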
class Learner: """ Learner class. """ def __init__(self, args): """ Seeds everything. Initialises: logger, environments, policy (+storage +optimiser). """ self.args = args # make sure everything has the same seed utl.seed(self.args.seed) # initialise environment self.env = make_env(self.args.env_name, self.args.max_rollouts_per_task, seed=self.args.seed, n_tasks=1, modify_init_state_dist=self.args.modify_init_state_dist if 'modify_init_state_dist' in self.args else False, on_circle_init_state=self.args.on_circle_init_state if 'on_circle_init_state' in self.args else True) # saving buffer with task in name folder if hasattr(self.args, 'save_buffer') and self.args.save_buffer: env_dir = os.path.join(self.args.main_save_dir, '{}'.format(self.args.env_name)) goal = self.env.unwrapped._goal self.output_dir = os.path.join(env_dir, self.args.save_dir, 'seed_{}_'.format(self.args.seed) + off_utl.create_goal_path_ext_from_goal(goal)) if self.args.save_models or self.args.save_buffer: os.makedirs(self.output_dir, exist_ok=True) config_utl.save_config_file(args, self.output_dir) # initialize tensorboard logger if self.args.log_tensorboard: self.tb_logger = TBLogger(self.args) # if not self.args.log_tensorboard: # self.save_config_json_file() # unwrapped env to get some info about the environment unwrapped_env = self.env.unwrapped # calculate what the maximum length of the trajectories is args.max_trajectory_len = unwrapped_env._max_episode_steps args.max_trajectory_len *= self.args.max_rollouts_per_task self.args.max_trajectory_len = args.max_trajectory_len # get action / observation dimensions if isinstance(self.env.action_space, gym.spaces.discrete.Discrete): self.args.action_dim = 1 else: self.args.action_dim = self.env.action_space.shape[0] self.args.obs_dim = self.env.observation_space.shape[0] self.args.num_states = unwrapped_env.num_states if hasattr(unwrapped_env, 'num_states') else None self.args.act_space = self.env.action_space # simulate env step to get reward types _, _, _, info = unwrapped_env.step(unwrapped_env.action_space.sample()) reward_types = [reward_type for reward_type in list(info.keys()) if reward_type.startswith('reward')] # support dense rewards training (if exists) self.args.dense_train_sparse_test = self.args.dense_train_sparse_test \ if 'dense_train_sparse_test' in self.args else False # initialize policy self.initialize_policy() # initialize buffer for RL updates self.policy_storage = MultiTaskPolicyStorage( max_replay_buffer_size=int(self.args.policy_buffer_size), obs_dim=self.args.obs_dim, action_space=self.env.action_space, tasks=[0], trajectory_len=args.max_trajectory_len, num_reward_arrays=len(reward_types) if reward_types and self.args.dense_train_sparse_test else 1, reward_types=reward_types, ) self.args.belief_reward = False # initialize arg to not use belief rewards def initialize_policy(self): if self.args.policy == 'dqn': assert self.args.act_space.__class__.__name__ == "Discrete", ( "Can't train DQN with continuous action space!") q_network = FlattenMlp(input_size=self.args.obs_dim, output_size=self.args.act_space.n, hidden_sizes=self.args.dqn_layers) self.agent = DQN( q_network, # optimiser_vae=self.optimizer_vae, lr=self.args.policy_lr, gamma=self.args.gamma, eps_init=self.args.dqn_epsilon_init, eps_final=self.args.dqn_epsilon_final, exploration_iters=self.args.dqn_exploration_iters, tau=self.args.soft_target_tau, ).to(ptu.device) # elif self.args.policy == 'ddqn': # assert self.args.act_space.__class__.__name__ == "Discrete", ( # "Can't train DDQN with 
continuous action space!") # q_network = FlattenMlp(input_size=self.args.obs_dim, # output_size=self.args.act_space.n, # hidden_sizes=self.args.dqn_layers) # self.agent = DoubleDQN( # q_network, # # optimiser_vae=self.optimizer_vae, # lr=self.args.policy_lr, # eps_optim=self.args.dqn_eps, # alpha_optim=self.args.dqn_alpha, # gamma=self.args.gamma, # eps_init=self.args.dqn_epsilon_init, # eps_final=self.args.dqn_epsilon_final, # exploration_iters=self.args.dqn_exploration_iters, # tau=self.args.soft_target_tau, # ).to(ptu.device) elif self.args.policy == 'sac': assert self.args.act_space.__class__.__name__ == "Box", ( "Can't train SAC with discrete action space!") q1_network = FlattenMlp(input_size=self.args.obs_dim + self.args.action_dim, output_size=1, hidden_sizes=self.args.dqn_layers) q2_network = FlattenMlp(input_size=self.args.obs_dim + self.args.action_dim, output_size=1, hidden_sizes=self.args.dqn_layers) policy = TanhGaussianPolicy(obs_dim=self.args.obs_dim, action_dim=self.args.action_dim, hidden_sizes=self.args.policy_layers) self.agent = SAC( policy, q1_network, q2_network, actor_lr=self.args.actor_lr, critic_lr=self.args.critic_lr, gamma=self.args.gamma, tau=self.args.soft_target_tau, entropy_alpha=self.args.entropy_alpha, automatic_entropy_tuning=self.args.automatic_entropy_tuning, alpha_lr=self.args.alpha_lr ).to(ptu.device) else: raise NotImplementedError def train(self): """ meta-training loop """ self._start_training() self.task_idx = 0 for iter_ in range(self.args.num_iters): self.training_mode(True) if iter_ == 0: print('Collecting initial pool of data..') self.env.reset_task(idx=self.task_idx) self.collect_rollouts(num_rollouts=self.args.num_init_rollouts_pool, random_actions=True) print('Done!') # collect data from subset of train tasks self.env.reset_task(idx=self.task_idx) self.collect_rollouts(num_rollouts=self.args.num_rollouts_per_iter) # update train_stats = self.update([self.task_idx]) self.training_mode(False) if self.args.policy == 'dqn': self.agent.set_exploration_parameter(iter_ + 1) # evaluate and log if (iter_ + 1) % self.args.log_interval == 0: self.log(iter_ + 1, train_stats) def update(self, tasks): ''' RL updates :param tasks: list/array of task indices. 
perform update based on the tasks :return: ''' # --- RL TRAINING --- rl_losses_agg = {} for update in range(self.args.rl_updates_per_iter): # sample random RL batch obs, actions, rewards, next_obs, terms = self.sample_rl_batch(tasks, self.args.batch_size) # flatten out task dimension t, b, _ = obs.size() obs = obs.view(t * b, -1) actions = actions.view(t * b, -1) rewards = rewards.view(t * b, -1) next_obs = next_obs.view(t * b, -1) terms = terms.view(t * b, -1) # RL update rl_losses = self.agent.update(obs, actions, rewards, next_obs, terms) for k, v in rl_losses.items(): if update == 0: # first iterate - create list rl_losses_agg[k] = [v] else: # append values rl_losses_agg[k].append(v) # take mean for k in rl_losses_agg: rl_losses_agg[k] = np.mean(rl_losses_agg[k]) self._n_rl_update_steps_total += self.args.rl_updates_per_iter return rl_losses_agg def evaluate(self, tasks): num_episodes = self.args.max_rollouts_per_task num_steps_per_episode = self.env.unwrapped._max_episode_steps returns_per_episode = np.zeros((len(tasks), num_episodes)) success_rate = np.zeros(len(tasks)) if self.args.policy == 'dqn': values = np.zeros((len(tasks), self.args.max_trajectory_len)) else: obs_size = self.env.unwrapped.observation_space.shape[0] observations = np.zeros((len(tasks), self.args.max_trajectory_len + 1, obs_size)) log_probs = np.zeros((len(tasks), self.args.max_trajectory_len)) for task_idx, task in enumerate(tasks): obs = ptu.from_numpy(self.env.reset(task)) obs = obs.reshape(-1, obs.shape[-1]) step = 0 if self.args.policy == 'sac': observations[task_idx, step, :] = ptu.get_numpy(obs[0, :obs_size]) for episode_idx in range(num_episodes): running_reward = 0. for step_idx in range(num_steps_per_episode): # add distribution parameters to observation - policy is conditioned on posterior if self.args.policy == 'dqn': action, value = self.agent.act(obs=obs, deterministic=True) else: action, _, _, log_prob = self.agent.act(obs=obs, deterministic=self.args.eval_deterministic, return_log_prob=True) # observe reward and next obs next_obs, reward, done, info = utl.env_step(self.env, action.squeeze(dim=0)) running_reward += reward.item() if self.args.policy == 'dqn': values[task_idx, step] = value.item() else: observations[task_idx, step + 1, :] = ptu.get_numpy(next_obs[0, :obs_size]) log_probs[task_idx, step] = ptu.get_numpy(log_prob[0]) if "is_goal_state" in dir(self.env.unwrapped) and self.env.unwrapped.is_goal_state(): success_rate[task_idx] = 1. 
# set: obs <- next_obs obs = next_obs.clone() step += 1 returns_per_episode[task_idx, episode_idx] = running_reward if self.args.policy == 'dqn': return returns_per_episode, success_rate, values else: return returns_per_episode, success_rate, log_probs, observations def log(self, iteration, train_stats): # --- save models --- if iteration % self.args.save_interval == 0: if self.args.save_models: if self.args.log_tensorboard: save_path = os.path.join(self.tb_logger.full_output_folder, 'models') else: save_path = os.path.join(self.output_dir, 'models') if not os.path.exists(save_path): os.mkdir(save_path) torch.save(self.agent.state_dict(), os.path.join(save_path, "agent{0}.pt".format(iteration))) if hasattr(self.args, 'save_buffer') and self.args.save_buffer: self.save_buffer() # evaluate to get more stats if self.args.policy == 'dqn': # get stats on train tasks returns_train, success_rate_train, values = self.evaluate([0]) else: # get stats on train tasks returns_train, success_rate_train, log_probs, observations = self.evaluate([0]) if self.args.log_tensorboard: if self.args.policy != 'dqn': # self.env.reset(0) # self.tb_logger.writer.add_figure('policy_vis_train/task_0', # utl_eval.plot_rollouts(observations[0, :], self.env), # self._n_env_steps_total) # obs, _, _, _, _ = self.sample_rl_batch(tasks=[0], # batch_size=self.policy_storage.task_buffers[0].size()) # self.tb_logger.writer.add_figure('state_space_coverage/task_0', # utl_eval.plot_visited_states(ptu.get_numpy(obs[0][:, :2]), self.env), # self._n_env_steps_total) pass # some metrics self.tb_logger.writer.add_scalar('metrics/successes_in_buffer', self._successes_in_buffer / self._n_env_steps_total, self._n_env_steps_total) if self.args.max_rollouts_per_task > 1: for episode_idx in range(self.args.max_rollouts_per_task): self.tb_logger.writer.add_scalar('returns_multi_episode/episode_{}'. 
format(episode_idx + 1), np.mean(returns_train[:, episode_idx]), self._n_env_steps_total) self.tb_logger.writer.add_scalar('returns_multi_episode/sum', np.mean(np.sum(returns_train, axis=-1)), self._n_env_steps_total) self.tb_logger.writer.add_scalar('returns_multi_episode/success_rate', np.mean(success_rate_train), self._n_env_steps_total) else: self.tb_logger.writer.add_scalar('returns/returns_mean_train', np.mean(returns_train), self._n_env_steps_total) self.tb_logger.writer.add_scalar('returns/returns_std_train', np.std(returns_train), self._n_env_steps_total) self.tb_logger.writer.add_scalar('returns/success_rate_train', np.mean(success_rate_train), self._n_env_steps_total) # policy if self.args.policy == 'dqn': self.tb_logger.writer.add_scalar('policy/value_init', np.mean(values[:, 0]), self._n_env_steps_total) self.tb_logger.writer.add_scalar('policy/value_halfway', np.mean(values[:, int(values.shape[-1]/2)]), self._n_env_steps_total) self.tb_logger.writer.add_scalar('policy/value_final', np.mean(values[:, -1]), self._n_env_steps_total) self.tb_logger.writer.add_scalar('policy/exploration_epsilon', self.agent.eps, self._n_env_steps_total) # RL losses self.tb_logger.writer.add_scalar('rl_losses/qf_loss_vs_n_updates', train_stats['qf_loss'], self._n_rl_update_steps_total) self.tb_logger.writer.add_scalar('rl_losses/qf_loss_vs_n_env_steps', train_stats['qf_loss'], self._n_env_steps_total) else: self.tb_logger.writer.add_scalar('policy/log_prob', np.mean(log_probs), self._n_env_steps_total) self.tb_logger.writer.add_scalar('rl_losses/qf1_loss', train_stats['qf1_loss'], self._n_env_steps_total) self.tb_logger.writer.add_scalar('rl_losses/qf2_loss', train_stats['qf2_loss'], self._n_env_steps_total) self.tb_logger.writer.add_scalar('rl_losses/policy_loss', train_stats['policy_loss'], self._n_env_steps_total) self.tb_logger.writer.add_scalar('rl_losses/alpha_entropy_loss', train_stats['alpha_entropy_loss'], self._n_env_steps_total) # weights and gradients if self.args.policy == 'dqn': self.tb_logger.writer.add_scalar('weights/q_network', list(self.agent.qf.parameters())[0].mean(), self._n_env_steps_total) if list(self.agent.qf.parameters())[0].grad is not None: param_list = list(self.agent.qf.parameters()) self.tb_logger.writer.add_scalar('gradients/q_network', sum([param_list[i].grad.mean() for i in range(len(param_list))]), self._n_env_steps_total) self.tb_logger.writer.add_scalar('weights/q_target', list(self.agent.target_qf.parameters())[0].mean(), self._n_env_steps_total) if list(self.agent.target_qf.parameters())[0].grad is not None: param_list = list(self.agent.target_qf.parameters()) self.tb_logger.writer.add_scalar('gradients/q_target', sum([param_list[i].grad.mean() for i in range(len(param_list))]), self._n_env_steps_total) else: self.tb_logger.writer.add_scalar('weights/q1_network', list(self.agent.qf1.parameters())[0].mean(), self._n_env_steps_total) if list(self.agent.qf1.parameters())[0].grad is not None: param_list = list(self.agent.qf1.parameters()) self.tb_logger.writer.add_scalar('gradients/q1_network', sum([param_list[i].grad.mean() for i in range(len(param_list))]), self._n_env_steps_total) self.tb_logger.writer.add_scalar('weights/q1_target', list(self.agent.qf1_target.parameters())[0].mean(), self._n_env_steps_total) if list(self.agent.qf1_target.parameters())[0].grad is not None: param_list = list(self.agent.qf1_target.parameters()) self.tb_logger.writer.add_scalar('gradients/q1_target', sum([param_list[i].grad.mean() for i in range(len(param_list))]), 
self._n_env_steps_total) self.tb_logger.writer.add_scalar('weights/q2_network', list(self.agent.qf2.parameters())[0].mean(), self._n_env_steps_total) if list(self.agent.qf2.parameters())[0].grad is not None: param_list = list(self.agent.qf2.parameters()) self.tb_logger.writer.add_scalar('gradients/q2_network', sum([param_list[i].grad.mean() for i in range(len(param_list))]), self._n_env_steps_total) self.tb_logger.writer.add_scalar('weights/q2_target', list(self.agent.qf2_target.parameters())[0].mean(), self._n_env_steps_total) if list(self.agent.qf2_target.parameters())[0].grad is not None: param_list = list(self.agent.qf2_target.parameters()) self.tb_logger.writer.add_scalar('gradients/q2_target', sum([param_list[i].grad.mean() for i in range(len(param_list))]), self._n_env_steps_total) self.tb_logger.writer.add_scalar('weights/policy', list(self.agent.policy.parameters())[0].mean(), self._n_env_steps_total) if list(self.agent.policy.parameters())[0].grad is not None: param_list = list(self.agent.policy.parameters()) self.tb_logger.writer.add_scalar('gradients/policy', sum([param_list[i].grad.mean() for i in range(len(param_list))]), self._n_env_steps_total) self.tb_logger.finish_iteration(iteration) print("Iteration -- {}, Success rate -- {:.3f}, Avg. return -- {:.3f}, Elapsed time {:5d}[s]" .format(iteration, np.mean(success_rate_train), np.mean(np.sum(returns_train, axis=-1)), int(time.time() - self._start_time))) # output to user # print("Iteration -- {:3d}, Num. RL updates -- {:6d}, Elapsed time {:5d}[s]". # format(iteration, # self._n_rl_update_steps_total, # int(time.time() - self._start_time))) def training_mode(self, mode): self.agent.train(mode) def collect_rollouts(self, num_rollouts, random_actions=False): ''' :param num_rollouts: :param random_actions: whether to use policy to sample actions, or randomly sample action space :return: ''' for rollout in range(num_rollouts): obs = ptu.from_numpy(self.env.reset(self.task_idx)) obs = obs.reshape(-1, obs.shape[-1]) done_rollout = False while not done_rollout: if random_actions: if self.args.policy == 'dqn': action = ptu.FloatTensor([[[self.env.action_space.sample()]]]).long() # Sample random action else: action = ptu.FloatTensor([self.env.action_space.sample()]) # Sample random action else: if self.args.policy == 'dqn': action, _ = self.agent.act(obs=obs) # DQN else: action, _, _, _ = self.agent.act(obs=obs) # SAC # observe reward and next obs next_obs, reward, done, info = utl.env_step(self.env, action.squeeze(dim=0)) done_rollout = False if ptu.get_numpy(done[0][0]) == 0. 
else True # add data to policy buffer - (s+, a, r, s'+, term) term = self.env.unwrapped.is_goal_state() if "is_goal_state" in dir(self.env.unwrapped) else False if self.args.dense_train_sparse_test: rew_to_buffer = {rew_type: rew for rew_type, rew in info.items() if rew_type.startswith('reward')} else: rew_to_buffer = ptu.get_numpy(reward.squeeze(dim=0)) self.policy_storage.add_sample(task=self.task_idx, observation=ptu.get_numpy(obs.squeeze(dim=0)), action=ptu.get_numpy(action.squeeze(dim=0)), reward=rew_to_buffer, terminal=np.array([term], dtype=float), next_observation=ptu.get_numpy(next_obs.squeeze(dim=0))) # set: obs <- next_obs obs = next_obs.clone() # update statistics self._n_env_steps_total += 1 if "is_goal_state" in dir(self.env.unwrapped) and self.env.unwrapped.is_goal_state(): # count successes self._successes_in_buffer += 1 self._n_rollouts_total += 1 def sample_rl_batch(self, tasks, batch_size): ''' sample batch of unordered rl training data from a list/array of tasks ''' # this batch consists of transitions sampled randomly from replay buffer batches = [ptu.np_to_pytorch_batch( self.policy_storage.random_batch(task, batch_size)) for task in tasks] unpacked = [utl.unpack_batch(batch) for batch in batches] # group elements together unpacked = [[x[i] for x in unpacked] for i in range(len(unpacked[0]))] unpacked = [torch.cat(x, dim=0) for x in unpacked] return unpacked def _start_training(self): self._n_env_steps_total = 0 self._n_rl_update_steps_total = 0 self._n_vae_update_steps_total = 0 self._n_rollouts_total = 0 self._successes_in_buffer = 0 self._start_time = time.time() def load_model(self, device='cpu', **kwargs): if "agent_path" in kwargs: self.agent.load_state_dict(torch.load(kwargs["agent_path"], map_location=device)) self.training_mode(False) def save_buffer(self): size = self.policy_storage.task_buffers[0].size() np.save(os.path.join(self.output_dir, 'obs'), self.policy_storage.task_buffers[0]._observations[:size]) np.save(os.path.join(self.output_dir, 'actions'), self.policy_storage.task_buffers[0]._actions[:size]) if self.args.dense_train_sparse_test: for reward_type, reward_arr in self.policy_storage.task_buffers[0]._rewards.items(): np.save(os.path.join(self.output_dir, reward_type), reward_arr[:size]) else: np.save(os.path.join(self.output_dir, 'rewards'), self.policy_storage.task_buffers[0]._rewards[:size]) np.save(os.path.join(self.output_dir, 'next_obs'), self.policy_storage.task_buffers[0]._next_obs[:size]) np.save(os.path.join(self.output_dir, 'terminals'), self.policy_storage.task_buffers[0]._terminals[:size])
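Two hyperparameters used by this off-policy Learner are handled inside the agent classes rather than here: the DQN exploration epsilon (annealed via `set_exploration_parameter(iter_ + 1)` between `dqn_epsilon_init` and `dqn_epsilon_final` over `dqn_exploration_iters`) and `soft_target_tau`, which both the DQN and SAC branches of `initialize_policy` pass to their target-network updates. The agents are not part of this file, so the following is only a rough sketch under the assumption of a linear epsilon schedule and a standard Polyak update; the function names are hypothetical.

import torch

def linear_epsilon_sketch(iteration, eps_init, eps_final, exploration_iters):
    # anneal exploration epsilon linearly from eps_init to eps_final over exploration_iters iterations
    frac = min(iteration / exploration_iters, 1.0)
    return eps_init + frac * (eps_final - eps_init)

def soft_target_update_sketch(network, target_network, tau):
    # Polyak averaging: target <- tau * online + (1 - tau) * target
    with torch.no_grad():
        for param, target_param in zip(network.parameters(), target_network.parameters()):
            target_param.data.mul_(1.0 - tau).add_(tau * param.data)

A small tau keeps the target networks as slowly moving copies of the online networks, which is what stabilises the bootstrapped Q-targets in both the DQN and SAC cases.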
class MetaLearner: """ Meta-Learner class with the main training loop for variBAD. """ def __init__(self, args): self.args = args utl.seed(self.args.seed, self.args.deterministic_execution) # calculate number of updates and keep count of frames/iterations self.num_updates = int( args.num_frames) // args.policy_num_steps // args.num_processes self.frames = 0 self.iter_idx = -1 # initialise tensorboard logger self.logger = TBLogger(self.args, self.args.exp_label) # initialise environments self.envs = make_vec_envs( env_name=args.env_name, seed=args.seed, num_processes=args.num_processes, gamma=args.policy_gamma, device=device, episodes_per_task=self.args.max_rollouts_per_task, normalise_rew=args.norm_rew_for_policy, ret_rms=None, tasks=None) if self.args.single_task_mode: # get the current tasks (which will be num_process many different tasks) self.train_tasks = self.envs.get_task() # set the tasks to the first task (i.e. just a random task) self.train_tasks[1:] = self.train_tasks[0] # make it a list self.train_tasks = [t for t in self.train_tasks] # re-initialise environments with those tasks self.envs = make_vec_envs( env_name=args.env_name, seed=args.seed, num_processes=args.num_processes, gamma=args.policy_gamma, device=device, episodes_per_task=self.args.max_rollouts_per_task, normalise_rew=args.norm_rew_for_policy, ret_rms=None, tasks=self.train_tasks) # save the training tasks so we can evaluate on the same envs later utl.save_obj(self.train_tasks, self.logger.full_output_folder, "train_tasks") else: self.train_tasks = None # calculate what the maximum length of the trajectories is self.args.max_trajectory_len = self.envs._max_episode_steps self.args.max_trajectory_len *= self.args.max_rollouts_per_task # get policy input dimensions self.args.state_dim = self.envs.observation_space.shape[0] self.args.task_dim = self.envs.task_dim self.args.belief_dim = self.envs.belief_dim self.args.num_states = self.envs.num_states # get policy output (action) dimensions self.args.action_space = self.envs.action_space if isinstance(self.envs.action_space, gym.spaces.discrete.Discrete): self.args.action_dim = 1 else: self.args.action_dim = self.envs.action_space.shape[0] # initialise VAE and policy self.vae = VaribadVAE(self.args, self.logger, lambda: self.iter_idx) self.policy_storage = self.initialise_policy_storage() self.policy = self.initialise_policy() def initialise_policy_storage(self): return OnlineStorage( args=self.args, num_steps=self.args.policy_num_steps, num_processes=self.args.num_processes, state_dim=self.args.state_dim, latent_dim=self.args.latent_dim, belief_dim=self.args.belief_dim, task_dim=self.args.task_dim, action_space=self.args.action_space, hidden_size=self.args.encoder_gru_hidden_size, normalise_rewards=self.args.norm_rew_for_policy, ) def initialise_policy(self): # initialise policy network policy_net = Policy( args=self.args, # pass_state_to_policy=self.args.pass_state_to_policy, pass_latent_to_policy=self.args.pass_latent_to_policy, pass_belief_to_policy=self.args.pass_belief_to_policy, pass_task_to_policy=self.args.pass_task_to_policy, dim_state=self.args.state_dim, dim_latent=self.args.latent_dim * 2, dim_belief=self.args.belief_dim, dim_task=self.args.task_dim, # hidden_layers=self.args.policy_layers, activation_function=self.args.policy_activation_function, policy_initialisation=self.args.policy_initialisation, # action_space=self.envs.action_space, init_std=self.args.policy_init_std, ).to(device) # initialise policy trainer if self.args.policy == 'a2c': policy = 
A2C( self.args, policy_net, self.args.policy_value_loss_coef, self.args.policy_entropy_coef, policy_optimiser=self.args.policy_optimiser, policy_anneal_lr=self.args.policy_anneal_lr, train_steps=self.num_updates, optimiser_vae=self.vae.optimiser_vae, lr=self.args.lr_policy, eps=self.args.policy_eps, ) elif self.args.policy == 'ppo': policy = PPO( self.args, policy_net, self.args.policy_value_loss_coef, self.args.policy_entropy_coef, policy_optimiser=self.args.policy_optimiser, policy_anneal_lr=self.args.policy_anneal_lr, train_steps=self.num_updates, lr=self.args.lr_policy, eps=self.args.policy_eps, ppo_epoch=self.args.ppo_num_epochs, num_mini_batch=self.args.ppo_num_minibatch, use_huber_loss=self.args.ppo_use_huberloss, use_clipped_value_loss=self.args.ppo_use_clipped_value_loss, clip_param=self.args.ppo_clip_param, optimiser_vae=self.vae.optimiser_vae, ) else: raise NotImplementedError return policy def train(self): """ Main Meta-Training loop """ start_time = time.time() # reset environments prev_state, belief, task = utl.reset_env(self.envs, self.args) # insert initial observation / embeddings to rollout storage self.policy_storage.prev_state[0].copy_(prev_state) # log once before training with torch.no_grad(): self.log(None, None, start_time) for self.iter_idx in range(self.num_updates): # First, re-compute the hidden states given the current rollouts (since the VAE might've changed) with torch.no_grad(): latent_sample, latent_mean, latent_logvar, hidden_state = self.encode_running_trajectory( ) # add this initial hidden state to the policy storage assert len(self.policy_storage.latent_mean ) == 0 # make sure we emptied buffers self.policy_storage.hidden_states[0].copy_(hidden_state) self.policy_storage.latent_samples.append(latent_sample.clone()) self.policy_storage.latent_mean.append(latent_mean.clone()) self.policy_storage.latent_logvar.append(latent_logvar.clone()) # rollout policies for a few steps for step in range(self.args.policy_num_steps): # sample actions from policy with torch.no_grad(): value, action = utl.select_action( args=self.args, policy=self.policy, state=prev_state, belief=belief, task=task, deterministic=False, latent_sample=latent_sample, latent_mean=latent_mean, latent_logvar=latent_logvar, ) # take step in the environment [next_state, belief, task], (rew_raw, rew_normalised), done, infos = utl.env_step( self.envs, action, self.args) done = torch.from_numpy(np.array( done, dtype=int)).to(device).float().view((-1, 1)) # create mask for episode ends masks_done = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done]).to(device) # bad_mask is true if episode ended because time limit was reached bad_masks = torch.FloatTensor( [[0.0] if 'bad_transition' in info.keys() else [1.0] for info in infos]).to(device) with torch.no_grad(): # compute next embedding (for next loop and/or value prediction bootstrap) latent_sample, latent_mean, latent_logvar, hidden_state = utl.update_encoding( encoder=self.vae.encoder, next_obs=next_state, action=action, reward=rew_raw, done=done, hidden_state=hidden_state) # before resetting, update the embedding and add to vae buffer # (last state might include useful task info) if not (self.args.disable_decoder and self.args.disable_kl_term): self.vae.rollout_storage.insert( prev_state.clone(), action.detach().clone(), next_state.clone(), rew_raw.clone(), done.clone(), task.clone() if task is not None else None) # add the obs before reset to the policy storage self.policy_storage.next_state[step] = next_state.clone() # reset 
environments that are done done_indices = np.argwhere(done.cpu().flatten()).flatten() if len(done_indices) > 0: next_state, belief, task = utl.reset_env( self.envs, self.args, indices=done_indices, state=next_state) # TODO: deal with resampling for posterior sampling algorithm # latent_sample = latent_sample # latent_sample[i] = latent_sample[i] # add experience to policy buffer self.policy_storage.insert( state=next_state, belief=belief, task=task, actions=action, rewards_raw=rew_raw, rewards_normalised=rew_normalised, value_preds=value, masks=masks_done, bad_masks=bad_masks, done=done, hidden_states=hidden_state.squeeze(0), latent_sample=latent_sample, latent_mean=latent_mean, latent_logvar=latent_logvar, ) prev_state = next_state self.frames += self.args.num_processes # --- UPDATE --- if self.args.precollect_len <= self.frames: # check if we are pre-training the VAE if self.args.pretrain_len > self.iter_idx: for p in range(self.args.num_vae_updates_per_pretrain): self.vae.compute_vae_loss( update=True, pretrain_index=self.iter_idx * self.args.num_vae_updates_per_pretrain + p) # otherwise do the normal update (policy + vae) else: train_stats = self.update(state=prev_state, belief=belief, task=task, latent_sample=latent_sample, latent_mean=latent_mean, latent_logvar=latent_logvar) # log run_stats = [ action, self.policy_storage.action_log_probs, value ] with torch.no_grad(): self.log(run_stats, train_stats, start_time) # clean up after update self.policy_storage.after_update() self.envs.close() def encode_running_trajectory(self): """ (Re-)Encodes (for each process) the entire current trajectory. Returns sample/mean/logvar and hidden state (if applicable) for the current timestep. :return: """ # for each process, get the current batch (zero-padded obs/act/rew + length indicators) prev_obs, next_obs, act, rew, lens = self.vae.rollout_storage.get_running_batch( ) # get embedding - will return (1+sequence_len) * batch * input_size -- includes the prior! all_latent_samples, all_latent_means, all_latent_logvars, all_hidden_states = self.vae.encoder( actions=act, states=next_obs, rewards=rew, hidden_state=None, return_prior=True) # get the embedding / hidden state of the current time step (need to do this since we zero-padded) latent_sample = (torch.stack([ all_latent_samples[lens[i]][i] for i in range(len(lens)) ])).to(device) latent_mean = (torch.stack([ all_latent_means[lens[i]][i] for i in range(len(lens)) ])).to(device) latent_logvar = (torch.stack([ all_latent_logvars[lens[i]][i] for i in range(len(lens)) ])).to(device) hidden_state = (torch.stack([ all_hidden_states[lens[i]][i] for i in range(len(lens)) ])).to(device) return latent_sample, latent_mean, latent_logvar, hidden_state def get_value(self, state, belief, task, latent_sample, latent_mean, latent_logvar): latent = utl.get_latent_for_policy(self.args, latent_sample=latent_sample, latent_mean=latent_mean, latent_logvar=latent_logvar) return self.policy.actor_critic.get_value(state=state, belief=belief, task=task, latent=latent).detach() def update(self, state, belief, task, latent_sample, latent_mean, latent_logvar): """ Meta-update. Here the policy is updated for good average performance across tasks. 
:return: """ # update policy (if we are not pre-training, have enough data in the vae buffer, and are not at iteration 0) if self.iter_idx >= self.args.pretrain_len and self.iter_idx > 0: # bootstrap next value prediction with torch.no_grad(): next_value = self.get_value(state=state, belief=belief, task=task, latent_sample=latent_sample, latent_mean=latent_mean, latent_logvar=latent_logvar) # compute returns for current rollouts self.policy_storage.compute_returns( next_value, self.args.policy_use_gae, self.args.policy_gamma, self.args.policy_tau, use_proper_time_limits=self.args.use_proper_time_limits) # update agent (this will also call the VAE update!) policy_train_stats = self.policy.update( policy_storage=self.policy_storage, encoder=self.vae.encoder, rlloss_through_encoder=self.args.rlloss_through_encoder, compute_vae_loss=self.vae.compute_vae_loss) else: policy_train_stats = 0, 0, 0, 0 # pre-train the VAE if self.iter_idx < self.args.pretrain_len: self.vae.compute_vae_loss(update=True) return policy_train_stats def log(self, run_stats, train_stats, start_time): # --- visualise behaviour of policy --- if (self.iter_idx + 1) % self.args.vis_interval == 0: ret_rms = self.envs.venv.ret_rms if self.args.norm_rew_for_policy else None utl_eval.visualise_behaviour( args=self.args, policy=self.policy, image_folder=self.logger.full_output_folder, iter_idx=self.iter_idx, ret_rms=ret_rms, encoder=self.vae.encoder, reward_decoder=self.vae.reward_decoder, state_decoder=self.vae.state_decoder, task_decoder=self.vae.task_decoder, compute_rew_reconstruction_loss=self.vae. compute_rew_reconstruction_loss, compute_state_reconstruction_loss=self.vae. compute_state_reconstruction_loss, compute_task_reconstruction_loss=self.vae. compute_task_reconstruction_loss, compute_kl_loss=self.vae.compute_kl_loss, tasks=self.train_tasks, ) # --- evaluate policy ---- if (self.iter_idx + 1) % self.args.eval_interval == 0: ret_rms = self.envs.venv.ret_rms if self.args.norm_rew_for_policy else None returns_per_episode = utl_eval.evaluate( args=self.args, policy=self.policy, ret_rms=ret_rms, encoder=self.vae.encoder, iter_idx=self.iter_idx, tasks=self.train_tasks, ) # log the return avg/std across tasks (=processes) returns_avg = returns_per_episode.mean(dim=0) returns_std = returns_per_episode.std(dim=0) for k in range(len(returns_avg)): self.logger.add('return_avg_per_iter/episode_{}'.format(k + 1), returns_avg[k], self.iter_idx) self.logger.add( 'return_avg_per_frame/episode_{}'.format(k + 1), returns_avg[k], self.frames) self.logger.add('return_std_per_iter/episode_{}'.format(k + 1), returns_std[k], self.iter_idx) self.logger.add( 'return_std_per_frame/episode_{}'.format(k + 1), returns_std[k], self.frames) print(f"Updates {self.iter_idx}, " f"Frames {self.frames}, " f"FPS {int(self.frames / (time.time() - start_time))}, " f"\n Mean return (train): {returns_avg[-1].item()} \n") # --- save models --- if (self.iter_idx + 1) % self.args.save_interval == 0: save_path = os.path.join(self.logger.full_output_folder, 'models') if not os.path.exists(save_path): os.mkdir(save_path) idx_labels = [''] if self.args.save_intermediate_models: idx_labels.append(int(self.iter_idx)) for idx_label in idx_labels: torch.save(self.policy.actor_critic, os.path.join(save_path, f"policy{idx_label}.pt")) torch.save(self.vae.encoder, os.path.join(save_path, f"encoder{idx_label}.pt")) if self.vae.state_decoder is not None: torch.save( self.vae.state_decoder, os.path.join(save_path, f"state_decoder{idx_label}.pt")) if self.vae.reward_decoder is 
not None: torch.save( self.vae.reward_decoder, os.path.join(save_path, f"reward_decoder{idx_label}.pt")) if self.vae.task_decoder is not None: torch.save( self.vae.task_decoder, os.path.join(save_path, f"task_decoder{idx_label}.pt")) # save normalisation params of envs if self.args.norm_rew_for_policy: rew_rms = self.envs.venv.ret_rms utl.save_obj(rew_rms, save_path, f"env_rew_rms{idx_label}") # TODO: grab from policy and save? # if self.args.norm_obs_for_policy: # obs_rms = self.envs.venv.obs_rms # utl.save_obj(obs_rms, save_path, f"env_obs_rms{idx_label}") # --- log some other things --- if ((self.iter_idx + 1) % self.args.log_interval == 0) and (train_stats is not None): self.logger.add('environment/state_max', self.policy_storage.prev_state.max(), self.iter_idx) self.logger.add('environment/state_min', self.policy_storage.prev_state.min(), self.iter_idx) self.logger.add('environment/rew_max', self.policy_storage.rewards_raw.max(), self.iter_idx) self.logger.add('environment/rew_min', self.policy_storage.rewards_raw.min(), self.iter_idx) self.logger.add('policy_losses/value_loss', train_stats[0], self.iter_idx) self.logger.add('policy_losses/action_loss', train_stats[1], self.iter_idx) self.logger.add('policy_losses/dist_entropy', train_stats[2], self.iter_idx) self.logger.add('policy_losses/sum', train_stats[3], self.iter_idx) self.logger.add('policy/action', run_stats[0][0].float().mean(), self.iter_idx) if hasattr(self.policy.actor_critic, 'logstd'): self.logger.add('policy/action_logstd', self.policy.actor_critic.dist.logstd.mean(), self.iter_idx) self.logger.add('policy/action_logprob', run_stats[1].mean(), self.iter_idx) self.logger.add('policy/value', run_stats[2].mean(), self.iter_idx) self.logger.add('encoder/latent_mean', torch.cat(self.policy_storage.latent_mean).mean(), self.iter_idx) self.logger.add( 'encoder/latent_logvar', torch.cat(self.policy_storage.latent_logvar).mean(), self.iter_idx) # log the average weights and gradients of all models (where applicable) for [model, name ] in [[self.policy.actor_critic, 'policy'], [self.vae.encoder, 'encoder'], [self.vae.reward_decoder, 'reward_decoder'], [self.vae.state_decoder, 'state_transition_decoder'], [self.vae.task_decoder, 'task_decoder']]: if model is not None: param_list = list(model.parameters()) param_mean = np.mean([ param_list[i].data.cpu().numpy().mean() for i in range(len(param_list)) ]) self.logger.add('weights/{}'.format(name), param_mean, self.iter_idx) if name == 'policy': self.logger.add('weights/policy_std', param_list[0].data.mean(), self.iter_idx) if param_list[0].grad is not None: param_grad_mean = np.mean([ param_list[i].grad.cpu().numpy().mean() for i in range(len(param_list)) ]) self.logger.add('gradients/{}'.format(name), param_grad_mean, self.iter_idx)
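One detail of `initialise_policy` in the MetaLearner worth spelling out: the policy network is built with `dim_latent=self.args.latent_dim * 2`, which suggests that `utl.get_latent_for_policy` conditions the policy on the full posterior (mean and log-variance concatenated) rather than on a single latent sample, unless sampled embeddings are requested. The helper is defined elsewhere, so this is only a sketch of that concatenation under the stated assumption; the function name and flag handling are illustrative.

import torch

def get_latent_for_policy_sketch(args, latent_sample=None, latent_mean=None, latent_logvar=None):
    # illustrative only: how the latent input to the policy could be assembled
    if not args.pass_latent_to_policy:
        return None
    if getattr(args, 'sample_embeddings', False):
        # condition on a single sample from the posterior (latent_dim-dimensional)
        return latent_sample
    # default: concatenate mean and logvar -> 2 * latent_dim, matching dim_latent above
    return torch.cat((latent_mean, latent_logvar), dim=-1)

Passing mean and log-variance rather than a sample lets the policy see not only the current task estimate but also how uncertain the VAE still is, which is how the variBAD policy can take its remaining task uncertainty into account when acting.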