def simulate_policy(args):
    data = torch.load(args.file)
    policy = data['evaluation/policy']
    env = data['evaluation/env']
    print("Policy loaded")
    if args.gpu:
        set_gpu_mode(True)
        policy.cuda()
    paths = []
    while True:
        path = rollout(
            env,
            policy,
            max_path_length=args.H,
            render=True,
        )
        paths.append(path)
        if hasattr(env, "log_diagnostics"):
            env.log_diagnostics(paths)
        if hasattr(env, "get_diagnostics"):
            for k, v in env.get_diagnostics(paths).items():
                logger.record_tabular(k, v)
        else:
            logger.record_dict(
                eval_util.get_generic_path_information(paths),
                prefix="evaluation/",
            )
        logger.dump_tabular()
def train_ae(ae_trainer, training_distrib, num_epochs=100,
             num_batches_per_epoch=500, batch_size=512,
             goal_key='image_desired_goal', rl_csv_fname='progress.csv'):
    from rlkit.core import logger
    logger.remove_tabular_output(rl_csv_fname, relative_to_snapshot_dir=True)
    logger.add_tabular_output('ae_progress.csv', relative_to_snapshot_dir=True)
    for epoch in range(num_epochs):
        for batch_num in range(num_batches_per_epoch):
            goals = ptu.from_numpy(
                training_distrib.sample(batch_size)[goal_key])
            batch = dict(raw_next_observations=goals)
            ae_trainer.train_from_torch(batch)
        log = OrderedDict()
        log['epoch'] = epoch
        append_log(log, ae_trainer.eval_statistics, prefix='ae/')
        logger.record_dict(log)
        logger.dump_tabular(with_prefix=True, with_timestamp=False)
        ae_trainer.end_epoch(epoch)
    logger.add_tabular_output(rl_csv_fname, relative_to_snapshot_dir=True)
    logger.remove_tabular_output('ae_progress.csv',
                                 relative_to_snapshot_dir=True)
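# The remove/add tabular-output swap in train_ae above (and in several of the
# pretraining functions below) is a recurring pattern: redirect
# logger.dump_tabular() to a side CSV for a sub-phase, then restore
# progress.csv. A minimal illustrative sketch of that pattern as a context
# manager; `tabular_redirect` is a hypothetical helper, not part of rlkit, and
# assumes only the logger.add_tabular_output / logger.remove_tabular_output
# calls already used in this file.
from contextlib import contextmanager


@contextmanager
def tabular_redirect(logger, tmp_csv, main_csv='progress.csv'):
    """Temporarily send tabular logs to `tmp_csv` instead of `main_csv`."""
    logger.remove_tabular_output(main_csv, relative_to_snapshot_dir=True)
    logger.add_tabular_output(tmp_csv, relative_to_snapshot_dir=True)
    try:
        yield
    finally:
        # Restore the main progress file so later epochs log to it again.
        logger.remove_tabular_output(tmp_csv, relative_to_snapshot_dir=True)
        logger.add_tabular_output(main_csv, relative_to_snapshot_dir=True)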
def _log_stats(self, epoch):
    logger.log("Epoch {} finished".format(epoch), with_timestamp=True)

    """
    Trainer
    """
    logger.record_dict(self.trainer.get_diagnostics(), prefix='trainer/')

    """
    Evaluation
    """
    logger.record_dict(
        self.eval_data_collector.get_diagnostics(),
        prefix='evaluation/',
    )
    eval_paths = self.eval_data_collector.get_epoch_paths()
    if hasattr(self.eval_env, 'get_diagnostics'):
        logger.record_dict(
            self.eval_env.get_diagnostics(eval_paths),
            prefix='evaluation/',
        )
    logger.record_dict(
        eval_util.get_generic_path_information(eval_paths),
        prefix="evaluation/",
    )

    """
    Misc
    """
    gt.stamp('logging')
    logger.record_dict(_get_epoch_timings())
    logger.record_tabular('Epoch', epoch)
    logger.dump_tabular(with_prefix=False, with_timestamp=False)
def run(self):
    if self.progress_csv_file_name != 'progress.csv':
        logger.remove_tabular_output('progress.csv',
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output(
            self.progress_csv_file_name,
            relative_to_snapshot_dir=True,
        )
    timer.return_global_times = True
    for _ in range(self.num_iters):
        self._begin_epoch()
        timer.start_timer('saving')
        logger.save_itr_params(self.epoch, self._get_snapshot())
        timer.stop_timer('saving')
        log_dict, _ = self._train()
        logger.record_dict(log_dict)
        logger.dump_tabular(with_prefix=True, with_timestamp=False)
        self._end_epoch()
    logger.save_itr_params(self.epoch, self._get_snapshot())
    if self.progress_csv_file_name != 'progress.csv':
        logger.remove_tabular_output(
            self.progress_csv_file_name,
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )
def _log_vae_stats(self):
    logger.record_dict(
        self.vae_trainer_original.get_diagnostics(),
        prefix='vae_trainer_original/',
    )
    logger.record_dict(
        self.vae_trainer_segmented.get_diagnostics(),
        prefix='vae_trainer_segmented/',
    )
def train(self):
    timer.return_global_times = True
    for _ in range(self.num_epochs):
        self._begin_epoch()
        timer.start_timer('saving')
        logger.save_itr_params(self.epoch, self._get_snapshot())
        timer.stop_timer('saving')
        log_dict, _ = self._train()
        logger.record_dict(log_dict)
        logger.dump_tabular(with_prefix=True, with_timestamp=False)
        self._end_epoch()
    logger.save_itr_params(self.epoch, self._get_snapshot())
def pretrain_q_with_bc_data(self, batch_size):
    logger.remove_tabular_output('progress.csv',
                                 relative_to_snapshot_dir=True)
    logger.add_tabular_output('pretrain_q.csv',
                              relative_to_snapshot_dir=True)

    self.update_policy = True  # then train policy and Q function together
    prev_time = time.time()
    for i in range(self.num_pretrain_steps):
        self.eval_statistics = dict()
        if i % self.pretraining_logging_period == 0:
            self._need_to_update_eval_statistics = True
        train_data = self.replay_buffer.random_batch(batch_size)
        train_data = np_to_pytorch_batch(train_data)
        obs = train_data['observations']
        next_obs = train_data['next_observations']
        # goals = train_data['resampled_goals']
        train_data['observations'] = obs  # torch.cat((obs, goals), dim=1)
        train_data['next_observations'] = next_obs  # torch.cat((next_obs, goals), dim=1)
        self.train_from_torch(train_data)

        if self.do_pretrain_rollouts and i % self.pretraining_env_logging_period == 0:
            total_ret = self.do_rollouts()
            print("Return at step {} : {}".format(i, total_ret / 20))

        if i % self.pretraining_logging_period == 0:
            if self.do_pretrain_rollouts:
                self.eval_statistics["pretrain_bc/avg_return"] = total_ret / 20
            self.eval_statistics["batch"] = i
            self.eval_statistics["epoch_time"] = time.time() - prev_time
            logger.record_dict(self.eval_statistics)
            logger.dump_tabular(with_prefix=True, with_timestamp=False)
            prev_time = time.time()

    logger.remove_tabular_output(
        'pretrain_q.csv',
        relative_to_snapshot_dir=True,
    )
    logger.add_tabular_output(
        'progress.csv',
        relative_to_snapshot_dir=True,
    )

    self._need_to_update_eval_statistics = True
    self.eval_statistics = dict()
def train(self):
    # first train only the Q function
    iteration = 0
    for i in range(self.num_batches):
        train_data = self.replay_buffer.random_batch(self.batch_size)
        train_data = np_to_pytorch_batch(train_data)
        obs = train_data['observations']
        next_obs = train_data['next_observations']
        train_data['observations'] = obs
        train_data['next_observations'] = next_obs
        self.trainer.train_from_torch(train_data)
        if i % self.logging_period == 0:
            stats_with_prefix = add_prefix(
                self.trainer.eval_statistics, prefix="trainer/")
            self.trainer.end_epoch(iteration)
            iteration += 1
            logger.record_dict(stats_with_prefix)
            logger.dump_tabular(with_prefix=True, with_timestamp=False)
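# The periodic-logging idiom in train() above (and in the pretraining loops
# elsewhere in this file) reduces to: build a flat stats dict, namespace it with
# add_prefix, record it, then dump one tabular row. A minimal sketch of that
# idiom; `log_training_step`, `loss`, and `num_batches` are illustrative names,
# and it assumes the same module-level `logger` and `add_prefix` imports used by
# the surrounding functions.
def log_training_step(step, loss, num_batches=1000, logging_period=100):
    if step % logging_period != 0:
        return
    stats = {
        'loss': loss,
        'batch': step,
        'progress': step / num_batches,
    }
    stats_with_prefix = add_prefix(stats, prefix="trainer/")
    logger.record_dict(stats_with_prefix)
    # One dump per call writes a single row to the active tabular outputs.
    logger.dump_tabular(with_prefix=True, with_timestamp=False)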
def simulate_policy(fpath, env_name, seed, max_path_length, num_eval_steps,
                    headless, max_eps, verbose=True, pause=False):
    data = torch.load(fpath, map_location=ptu.device)
    policy = data['evaluation/policy']
    policy.to(ptu.device)
    # make a new env; reloading with data['evaluation/env'] seems to cause a bug
    env = gym.make(env_name, **{"headless": headless, "verbose": False})
    env.seed(seed)
    if pause:
        input("Waiting to start.")
    path_collector = MdpPathCollector(env, policy)
    paths = path_collector.collect_new_paths(
        max_path_length,
        num_eval_steps,
        discard_incomplete_paths=True,
    )
    if max_eps:
        paths = paths[:max_eps]
    if verbose:
        completions = sum([
            info["completed"] for path in paths for info in path["env_infos"]
        ])
        print("Completed {} out of {}".format(completions, len(paths)))
    # plt.plot(paths[0]["actions"])
    # plt.show()
    # plt.plot(paths[2]["observations"])
    # plt.show()
    logger.record_dict(
        eval_util.get_generic_path_information(paths),
        prefix="evaluation/",
    )
    logger.dump_tabular()
    return paths
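# Hypothetical command-line entry point for simulate_policy above. The argument
# names mirror the function signature, but the flags and defaults are
# illustrative assumptions, not taken from this repository.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("fpath", type=str, help="path to a torch snapshot")
    parser.add_argument("env_name", type=str)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--max_path_length", type=int, default=200)
    parser.add_argument("--num_eval_steps", type=int, default=2000)
    parser.add_argument("--headless", action="store_true")
    parser.add_argument("--max_eps", type=int, default=0)
    args = parser.parse_args()

    simulate_policy(
        args.fpath,
        args.env_name,
        args.seed,
        args.max_path_length,
        args.num_eval_steps,
        args.headless,
        args.max_eps,
    )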
def pretrain_q_with_bc_data(self, batch_size):
    logger.remove_tabular_output('progress.csv',
                                 relative_to_snapshot_dir=True)
    logger.add_tabular_output('pretrain_q.csv',
                              relative_to_snapshot_dir=True)

    prev_time = time.time()
    for i in range(self.num_pretrain_steps):
        self.eval_statistics = dict()
        if i % self.pretraining_logging_period == 0:
            self._need_to_update_eval_statistics = True
        train_data = self.replay_buffer.random_batch(self.bc_batch_size)
        train_data = np_to_pytorch_batch(train_data)
        obs = train_data['observations']
        next_obs = train_data['next_observations']
        # goals = train_data['resampled_goals']
        train_data['observations'] = obs  # torch.cat((obs, goals), dim=1)
        train_data['next_observations'] = next_obs  # torch.cat((next_obs, goals), dim=1)
        self.train_from_torch(train_data, pretrain=True)

        if i % self.pretraining_logging_period == 0:
            self.eval_statistics["batch"] = i
            self.eval_statistics["epoch_time"] = time.time() - prev_time
            stats_with_prefix = add_prefix(self.eval_statistics,
                                           prefix="trainer/")
            logger.record_dict(stats_with_prefix)
            logger.dump_tabular(with_prefix=True, with_timestamp=False)
            prev_time = time.time()

    logger.remove_tabular_output(
        'pretrain_q.csv',
        relative_to_snapshot_dir=True,
    )
    logger.add_tabular_output(
        'progress.csv',
        relative_to_snapshot_dir=True,
    )

    self._need_to_update_eval_statistics = True
    self.eval_statistics = dict()
def train(self): # first train only the Q function iteration = 0 timer.return_global_times = True timer.reset() for i in range(self.num_batches): if self.use_meta_learning_buffer: train_data = self.meta_replay_buffer.sample_meta_batch( rl_batch_size=self.batch_size, meta_batch_size=self.meta_batch_size, embedding_batch_size=self.task_embedding_batch_size, ) train_data = np_to_pytorch_batch(train_data) else: task_indices = np.random.choice( self.train_tasks, self.meta_batch_size, ) train_data = self.replay_buffer.sample_batch( task_indices, self.batch_size, ) train_data = np_to_pytorch_batch(train_data) obs = train_data['observations'] next_obs = train_data['next_observations'] train_data['observations'] = obs train_data['next_observations'] = next_obs train_data['context'] = ( self.task_embedding_replay_buffer.sample_context( task_indices, self.task_embedding_batch_size, )) timer.start_timer('train', unique=False) self.trainer.train_from_torch(train_data) timer.stop_timer('train') if i % self.logging_period == 0 or i == self.num_batches - 1: stats_with_prefix = add_prefix( self.trainer.eval_statistics, prefix="trainer/") self.trainer.end_epoch(iteration) logger.record_dict(stats_with_prefix) timer.start_timer('extra_fns', unique=False) for fn in self._extra_eval_fns: extra_stats = fn() logger.record_dict(extra_stats) timer.stop_timer('extra_fns') # TODO: evaluate during offline RL # eval_stats = self.get_eval_statistics() # eval_stats_with_prefix = add_prefix(eval_stats, prefix="eval/") # logger.record_dict(eval_stats_with_prefix) logger.record_tabular('iteration', iteration) logger.record_dict(_get_epoch_timings()) try: import os import psutil process = psutil.Process(os.getpid()) logger.record_tabular('RAM Usage (Mb)', int(process.memory_info().rss / 1000000)) except ImportError: pass logger.dump_tabular(with_prefix=True, with_timestamp=False) iteration += 1
def _log_stats(self, epoch, num_epochs_per_eval=0):
    logger.log("Epoch {} finished".format(epoch), with_timestamp=True)

    """
    Replay Buffer
    """
    logger.record_dict(self.replay_buffer.get_diagnostics(),
                       prefix='replay_buffer/')

    """
    Trainer
    """
    logger.record_dict(self.trainer.get_diagnostics(), prefix='trainer/')

    """
    Exploration
    """
    logger.record_dict(self.expl_data_collector.get_diagnostics(),
                       prefix='exploration/')
    expl_paths = self.expl_data_collector.get_epoch_paths()
    if hasattr(self.expl_env, 'get_diagnostics'):
        logger.record_dict(
            self.expl_env.get_diagnostics(expl_paths),
            prefix='exploration/',
        )
    logger.record_dict(
        eval_util.get_generic_path_information(expl_paths),
        prefix="exploration/",
    )

    """
    Evaluation
    """
    try:
        # if epoch % num_epochs_per_eval == 0:
        logger.record_dict(
            self.eval_data_collector.get_diagnostics(),
            prefix='evaluation/',
        )
        eval_paths = self.eval_data_collector.get_epoch_paths()
        if hasattr(self.eval_env, 'get_diagnostics'):
            logger.record_dict(
                self.eval_env.get_diagnostics(eval_paths),
                prefix='evaluation/',
            )
        logger.record_dict(
            eval_util.get_generic_path_information(eval_paths),
            prefix="evaluation/",
        )
    except Exception:
        # Evaluation stats may be unavailable (e.g. no eval paths collected yet).
        pass

    """
    Misc
    """
    gt.stamp('logging')
    logger.record_dict(_get_epoch_timings())
    logger.record_tabular('Epoch', epoch)
    logger.dump_tabular(with_prefix=False, with_timestamp=False)
def _log_stats(self, epoch):
    logger.log("Epoch {} finished".format(epoch), with_timestamp=True)

    """
    Replay Buffer
    """
    logger.record_dict(self.replay_buffer.get_diagnostics(),
                       prefix='replay_buffer/')

    """
    Trainer
    """
    logger.record_dict(self.trainer.get_diagnostics(), prefix='trainer/')

    """
    Exploration
    """
    logger.record_dict(self.expl_data_collector.get_diagnostics(),
                       prefix='exploration/')
    expl_paths = self.expl_data_collector.get_epoch_paths()
    if hasattr(self.expl_env, 'get_diagnostics'):
        logger.record_dict(
            self.expl_env.get_diagnostics(expl_paths),
            prefix='exploration/',
        )
    if not self.batch_rl or self.eval_both:
        logger.record_dict(
            eval_util.get_generic_path_information(expl_paths),
            prefix="exploration/",
        )

    """
    Evaluation
    """
    logger.record_dict(
        self.eval_data_collector.get_diagnostics(),
        prefix='evaluation/',
    )
    eval_paths = self.eval_data_collector.get_epoch_paths()
    if hasattr(self.eval_env, 'get_diagnostics'):
        logger.record_dict(
            self.eval_env.get_diagnostics(eval_paths),
            prefix='evaluation/',
        )
    logger.record_dict(
        eval_util.get_generic_path_information(eval_paths),
        prefix="evaluation/",
    )

    """
    Misc
    """
    gt.stamp('logging')
    logger.record_dict(_get_epoch_timings())
    logger.record_tabular('Epoch', epoch)
    logger.dump_tabular(with_prefix=False, with_timestamp=False)
def pretrain_q_with_bc_data(self): """ :return: """ logger.remove_tabular_output('progress.csv', relative_to_snapshot_dir=True) logger.add_tabular_output('pretrain_q.csv', relative_to_snapshot_dir=True) self.update_policy = False # first train only the Q function for i in range(self.q_num_pretrain1_steps): self.eval_statistics = dict() train_data = self.replay_buffer.random_batch(self.bc_batch_size) train_data = np_to_pytorch_batch(train_data) obs = train_data['observations'] next_obs = train_data['next_observations'] # goals = train_data['resampled_goals'] train_data['observations'] = obs # torch.cat((obs, goals), dim=1) train_data[ 'next_observations'] = next_obs # torch.cat((next_obs, goals), dim=1) self.train_from_torch(train_data, pretrain=True) if i % self.pretraining_logging_period == 0: stats_with_prefix = add_prefix(self.eval_statistics, prefix="trainer/") logger.record_dict(stats_with_prefix) logger.dump_tabular(with_prefix=True, with_timestamp=False) self.update_policy = True # then train policy and Q function together prev_time = time.time() for i in range(self.q_num_pretrain2_steps): self.eval_statistics = dict() if i % self.pretraining_logging_period == 0: self._need_to_update_eval_statistics = True train_data = self.replay_buffer.random_batch(self.bc_batch_size) train_data = np_to_pytorch_batch(train_data) obs = train_data['observations'] next_obs = train_data['next_observations'] # goals = train_data['resampled_goals'] train_data['observations'] = obs # torch.cat((obs, goals), dim=1) train_data[ 'next_observations'] = next_obs # torch.cat((next_obs, goals), dim=1) self.train_from_torch(train_data, pretrain=True) if i % self.pretraining_logging_period == 0: self.eval_statistics["batch"] = i self.eval_statistics["epoch_time"] = time.time() - prev_time stats_with_prefix = add_prefix(self.eval_statistics, prefix="trainer/") logger.record_dict(stats_with_prefix) logger.dump_tabular(with_prefix=True, with_timestamp=False) prev_time = time.time() logger.remove_tabular_output( 'pretrain_q.csv', relative_to_snapshot_dir=True, ) logger.add_tabular_output( 'progress.csv', relative_to_snapshot_dir=True, ) self._need_to_update_eval_statistics = True self.eval_statistics = dict() if self.post_pretrain_hyperparams: self.set_algorithm_weights(**self.post_pretrain_hyperparams)
def pretrain_policy_with_bc( self, policy, train_buffer, test_buffer, steps, label="policy", ): """Given a policy, first get its optimizer, then run the policy on the train buffer, get the losses, and back propagate the loss. After training on a batch, test on the test buffer and get the statistics :param policy: :param train_buffer: :param test_buffer: :param steps: :param label: :return: """ logger.remove_tabular_output( 'progress.csv', relative_to_snapshot_dir=True, ) logger.add_tabular_output( 'pretrain_%s.csv' % label, relative_to_snapshot_dir=True, ) optimizer = self.optimizers[policy] prev_time = time.time() for i in range(steps): train_policy_loss, train_logp_loss, train_mse_loss, train_stats = self.run_bc_batch( train_buffer, policy) train_policy_loss = train_policy_loss * self.bc_weight optimizer.zero_grad() train_policy_loss.backward() optimizer.step() test_policy_loss, test_logp_loss, test_mse_loss, test_stats = self.run_bc_batch( test_buffer, policy) test_policy_loss = test_policy_loss * self.bc_weight if i % self.pretraining_logging_period == 0: stats = { "pretrain_bc/batch": i, "pretrain_bc/Train Logprob Loss": ptu.get_numpy(train_logp_loss), "pretrain_bc/Test Logprob Loss": ptu.get_numpy(test_logp_loss), "pretrain_bc/Train MSE": ptu.get_numpy(train_mse_loss), "pretrain_bc/Test MSE": ptu.get_numpy(test_mse_loss), "pretrain_bc/train_policy_loss": ptu.get_numpy(train_policy_loss), "pretrain_bc/test_policy_loss": ptu.get_numpy(test_policy_loss), "pretrain_bc/epoch_time": time.time() - prev_time, } logger.record_dict(stats) logger.dump_tabular(with_prefix=True, with_timestamp=False) pickle.dump( self.policy, open(logger.get_snapshot_dir() + '/bc_%s.pkl' % label, "wb")) prev_time = time.time() logger.remove_tabular_output( 'pretrain_%s.csv' % label, relative_to_snapshot_dir=True, ) logger.add_tabular_output( 'progress.csv', relative_to_snapshot_dir=True, ) if self.post_bc_pretrain_hyperparams: self.set_algorithm_weights(**self.post_bc_pretrain_hyperparams)
def _log_stats(self, epoch):
    logger.log("Epoch {} finished".format(epoch), with_timestamp=True)

    """
    Replay Buffer
    """
    logger.record_dict(
        self.replay_buffer.get_diagnostics(),
        global_step=epoch,
        prefix="replay_buffer/",
    )

    """
    Trainer
    """
    logger.record_dict(self.trainer.get_diagnostics(),
                       global_step=epoch,
                       prefix="trainer/")

    """
    Exploration
    """
    logger.record_dict(
        self.expl_data_collector.get_diagnostics(),
        global_step=epoch,
        prefix="exploration/",
    )
    expl_paths = self.expl_data_collector.get_epoch_paths()
    if hasattr(self.expl_env, "get_diagnostics"):
        logger.record_dict(
            self.expl_env.get_diagnostics(expl_paths),
            global_step=epoch,
            prefix="exploration/",
        )
    logger.record_dict(
        eval_util.get_generic_path_information(expl_paths),
        global_step=epoch,
        prefix="exploration/",
    )

    """
    Evaluation
    """
    logger.record_dict(
        self.eval_data_collector.get_diagnostics(),
        global_step=epoch,
        prefix="evaluation/",
    )
    eval_paths = self.eval_data_collector.get_epoch_paths()
    if hasattr(self.eval_env, "get_diagnostics"):
        logger.record_dict(
            self.eval_env.get_diagnostics(eval_paths),
            global_step=epoch,
            prefix="evaluation/",
        )
    logger.record_dict(
        eval_util.get_generic_path_information(eval_paths),
        global_step=epoch,
        prefix="evaluation/",
    )

    """
    Misc
    """
    gt.stamp("logging")
    logger.record_dict(_get_epoch_timings(), global_step=epoch)
    logger.record_tabular("Epoch", epoch)
    logger.dump_tabular(with_prefix=False, with_timestamp=False)
def run_eval_loop(sample_stochastically=True): start_time = time.time() prefix = "stochastic_" if sample_stochastically else "" for i in range(num_episodes): obs = env.reset() video.init(enabled=(i == 0)) done = False episode_reward = 0 ep_infos = [] while not done: # center crop image if encoder_type == "pixel" and "crop" in data_augs: obs = utils.center_crop_image(obs, image_size) if encoder_type == "pixel" and "translate" in data_augs: # first crop the center with pre_transform_image_size obs = utils.center_crop_image(obs, pre_transform_image_size) # then translate cropped to center obs = utils.center_translate(obs, image_size) with utils.eval_mode(agent): if sample_stochastically: action = agent.sample_action(obs / 255.0) else: action = agent.select_action(obs / 255.0) obs, reward, done, info = env.step(action) video.record(env) episode_reward += reward ep_infos.append(info) video.save("%d.mp4" % step) L.log("eval/" + prefix + "episode_reward", episode_reward, step) all_ep_rewards.append(episode_reward) all_infos.append(ep_infos) L.log("eval/" + prefix + "eval_time", time.time() - start_time, step) mean_ep_reward = np.mean(all_ep_rewards) best_ep_reward = np.max(all_ep_rewards) std_ep_reward = np.std(all_ep_rewards) L.log("eval/" + prefix + "mean_episode_reward", mean_ep_reward, step) L.log("eval/" + prefix + "best_episode_reward", best_ep_reward, step) rlkit_logger.record_dict( {"Average Returns": mean_ep_reward}, prefix="evaluation/" ) statistics = compute_path_info(all_infos) rlkit_logger.record_dict(statistics, prefix="evaluation/") filename = ( work_dir + "/" + env_name + "-" + data_augs + "--s" + str(seed) + "--eval_scores.npy" ) key = env_name + "-" + data_augs try: log_data = np.load(filename, allow_pickle=True) log_data = log_data.item() except: log_data = {} if key not in log_data: log_data[key] = {} log_data[key][step] = {} log_data[key][step]["step"] = step log_data[key][step]["mean_ep_reward"] = mean_ep_reward log_data[key][step]["max_ep_reward"] = best_ep_reward log_data[key][step]["std_ep_reward"] = std_ep_reward log_data[key][step]["env_step"] = step * action_repeat np.save(filename, log_data)
def experiment(variant): gym.logger.set_level(40) work_dir = rlkit_logger.get_snapshot_dir() args = parse_args() seed = int(variant["seed"]) utils.set_seed_everywhere(seed) os.makedirs(work_dir, exist_ok=True) agent_kwargs = variant["agent_kwargs"] data_augs = agent_kwargs["data_augs"] encoder_type = agent_kwargs["encoder_type"] discrete_continuous_dist = agent_kwargs["discrete_continuous_dist"] env_suite = variant["env_suite"] env_name = variant["env_name"] env_kwargs = variant["env_kwargs"] pre_transform_image_size = variant["pre_transform_image_size"] image_size = variant["image_size"] frame_stack = variant["frame_stack"] batch_size = variant["batch_size"] replay_buffer_capacity = variant["replay_buffer_capacity"] num_train_steps = variant["num_train_steps"] num_eval_episodes = variant["num_eval_episodes"] eval_freq = variant["eval_freq"] action_repeat = variant["action_repeat"] init_steps = variant["init_steps"] log_interval = variant["log_interval"] use_raw_actions = variant["use_raw_actions"] pre_transform_image_size = ( pre_transform_image_size if "crop" in data_augs else image_size ) pre_transform_image_size = pre_transform_image_size if data_augs == "crop": pre_transform_image_size = 100 image_size = image_size elif data_augs == "translate": pre_transform_image_size = 100 image_size = 108 if env_suite == 'kitchen': env_kwargs['imwidth'] = pre_transform_image_size env_kwargs['imheight'] = pre_transform_image_size else: env_kwargs['image_kwargs']['imwidth'] = pre_transform_image_size env_kwargs['image_kwargs']['imheight'] = pre_transform_image_size expl_env = primitives_make_env.make_env(env_suite, env_name, env_kwargs) eval_env = primitives_make_env.make_env(env_suite, env_name, env_kwargs) # stack several consecutive frames together if encoder_type == "pixel": expl_env = utils.FrameStack(expl_env, k=frame_stack) eval_env = utils.FrameStack(eval_env, k=frame_stack) # make directory ts = time.gmtime() ts = time.strftime("%m-%d", ts) env_name = env_name exp_name = ( env_name + "-" + ts + "-im" + str(image_size) + "-b" + str(batch_size) + "-s" + str(seed) + "-" + encoder_type ) work_dir = work_dir + "/" + exp_name utils.make_dir(work_dir) video_dir = utils.make_dir(os.path.join(work_dir, "video")) model_dir = utils.make_dir(os.path.join(work_dir, "model")) buffer_dir = utils.make_dir(os.path.join(work_dir, "buffer")) video = VideoRecorder(video_dir if args.save_video else None) with open(os.path.join(work_dir, "args.json"), "w") as f: json.dump(vars(args), f, sort_keys=True, indent=4) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if use_raw_actions: continuous_action_dim = expl_env.action_space.low.size discrete_action_dim = 0 else: num_primitives = expl_env.num_primitives max_arg_len = expl_env.max_arg_len if discrete_continuous_dist: continuous_action_dim = max_arg_len discrete_action_dim = num_primitives else: continuous_action_dim = max_arg_len + num_primitives discrete_action_dim = 0 if encoder_type == "pixel": obs_shape = (3 * frame_stack, image_size, image_size) pre_aug_obs_shape = ( 3 * frame_stack, pre_transform_image_size, pre_transform_image_size, ) else: obs_shape = env.observation_space.shape pre_aug_obs_shape = obs_shape replay_buffer = utils.ReplayBuffer( obs_shape=pre_aug_obs_shape, action_size=continuous_action_dim + discrete_action_dim, capacity=replay_buffer_capacity, batch_size=batch_size, device=device, image_size=image_size, pre_image_size=pre_transform_image_size, ) agent = make_agent( obs_shape=obs_shape, 
continuous_action_dim=continuous_action_dim, discrete_action_dim=discrete_action_dim, args=args, device=device, agent_kwargs=agent_kwargs, ) L = Logger(work_dir, use_tb=args.save_tb) episode, episode_reward, done = 0, 0, True start_time = time.time() epoch_start_time = time.time() train_expl_st = time.time() total_train_expl_time = 0 all_infos = [] ep_infos = [] num_train_calls = 0 for step in range(num_train_steps): # evaluate agent periodically if step % eval_freq == 0: total_train_expl_time += time.time()-train_expl_st L.log("eval/episode", episode, step) evaluate( eval_env, agent, video, num_eval_episodes, L, step, encoder_type, data_augs, image_size, pre_transform_image_size, env_name, action_repeat, work_dir, seed, ) if args.save_model: agent.save_curl(model_dir, step) if args.save_buffer: replay_buffer.save(buffer_dir) train_expl_st = time.time() if done: if step > 0: if step % log_interval == 0: L.log("train/duration", time.time() - epoch_start_time, step) L.dump(step) if step % log_interval == 0: L.log("train/episode_reward", episode_reward, step) obs = expl_env.reset() done = False episode_reward = 0 episode_step = 0 episode += 1 if step % log_interval == 0: all_infos.append(ep_infos) L.log("train/episode", episode, step) statistics = compute_path_info(all_infos) rlkit_logger.record_dict(statistics, prefix="exploration/") rlkit_logger.record_tabular( "time/epoch (s)", time.time() - epoch_start_time ) rlkit_logger.record_tabular("time/total (s)", time.time() - start_time) rlkit_logger.record_tabular("time/training and exploration (s)", total_train_expl_time) rlkit_logger.record_tabular("trainer/num train calls", num_train_calls) rlkit_logger.record_tabular("exploration/num steps total", step) rlkit_logger.record_tabular("Epoch", step // log_interval) rlkit_logger.dump_tabular(with_prefix=False, with_timestamp=False) all_infos = [] epoch_start_time = time.time() ep_infos = [] # sample action for data collection if step < init_steps: action = expl_env.action_space.sample() else: with utils.eval_mode(agent): action = agent.sample_action(obs / 255.0) # run training update if step >= init_steps: num_updates = 1 for _ in range(num_updates): agent.update(replay_buffer, L, step) num_train_calls += 1 next_obs, reward, done, info = expl_env.step(action) ep_infos.append(info) # allow infinit bootstrap done_bool = ( 0 if episode_step + 1 == expl_env._max_episode_steps else float(done) ) episode_reward += reward replay_buffer.add(obs, action, reward, next_obs, done_bool) obs = next_obs episode_step += 1
def _log_stats(self, epoch):
    logger.log(f"Epoch {epoch} finished", with_timestamp=True)

    """
    Replay Buffer
    """
    logger.record_dict(
        self.replay_buffer.get_diagnostics(), prefix="replay_buffer/"
    )

    """
    Trainer
    """
    logger.record_dict(self.trainer.get_diagnostics(), prefix="trainer/")

    """
    Exploration
    """
    logger.record_dict(
        self.expl_data_collector.get_diagnostics(), prefix="exploration/"
    )
    expl_paths = self.expl_data_collector.get_epoch_paths()
    if len(expl_paths) > 0:
        if hasattr(self.expl_env, "get_diagnostics"):
            logger.record_dict(
                self.expl_env.get_diagnostics(expl_paths),
                prefix="exploration/",
            )
        logger.record_dict(
            eval_util.get_generic_path_information(expl_paths),
            prefix="exploration/",
        )

    """
    Evaluation
    """
    logger.record_dict(
        self.eval_data_collector.get_diagnostics(),
        prefix="evaluation/",
    )
    eval_paths = self.eval_data_collector.get_epoch_paths()
    if hasattr(self.eval_env, "get_diagnostics"):
        logger.record_dict(
            self.eval_env.get_diagnostics(eval_paths),
            prefix="evaluation/",
        )
    logger.record_dict(
        eval_util.get_generic_path_information(eval_paths),
        prefix="evaluation/",
    )

    """
    Misc
    """
    gt.stamp("logging")
    timings = _get_epoch_timings()
    timings["time/training and exploration (s)"] = self.total_train_expl_time
    logger.record_dict(timings)
    logger.record_tabular("Epoch", epoch)
    logger.dump_tabular(with_prefix=False, with_timestamp=False)
def pretrain_policy_with_bc(self): if self.buffer_for_bc_training == "demos": self.bc_training_buffer = self.demo_train_buffer self.bc_test_buffer = self.demo_test_buffer elif self.buffer_for_bc_training == "replay_buffer": self.bc_training_buffer = self.replay_buffer.train_replay_buffer self.bc_test_buffer = self.replay_buffer.validation_replay_buffer else: self.bc_training_buffer = None self.bc_test_buffer = None if self.load_policy_path: self.policy = load_local_or_remote_file(self.load_policy_path) ptu.copy_model_params_from_to(self.policy, self.target_policy) return logger.remove_tabular_output('progress.csv', relative_to_snapshot_dir=True) logger.add_tabular_output('pretrain_policy.csv', relative_to_snapshot_dir=True) if self.do_pretrain_rollouts: total_ret = self.do_rollouts() print("INITIAL RETURN", total_ret / 20) prev_time = time.time() for i in range(self.bc_num_pretrain_steps): train_policy_loss, train_logp_loss, train_mse_loss, train_log_std = self.run_bc_batch( self.demo_train_buffer, self.policy) train_policy_loss = train_policy_loss * self.bc_weight self.policy_optimizer.zero_grad() train_policy_loss.backward() self.policy_optimizer.step() test_policy_loss, test_logp_loss, test_mse_loss, test_log_std = self.run_bc_batch( self.demo_test_buffer, self.policy) test_policy_loss = test_policy_loss * self.bc_weight if self.do_pretrain_rollouts and i % self.pretraining_env_logging_period == 0: total_ret = self.do_rollouts() print("Return at step {} : {}".format(i, total_ret / 20)) if i % self.pretraining_logging_period == 0: stats = { "pretrain_bc/batch": i, "pretrain_bc/Train Logprob Loss": ptu.get_numpy(train_logp_loss), "pretrain_bc/Test Logprob Loss": ptu.get_numpy(test_logp_loss), "pretrain_bc/Train MSE": ptu.get_numpy(train_mse_loss), "pretrain_bc/Test MSE": ptu.get_numpy(test_mse_loss), "pretrain_bc/train_policy_loss": ptu.get_numpy(train_policy_loss), "pretrain_bc/test_policy_loss": ptu.get_numpy(test_policy_loss), "pretrain_bc/epoch_time": time.time() - prev_time, } if self.do_pretrain_rollouts: stats["pretrain_bc/avg_return"] = total_ret / 20 logger.record_dict(stats) logger.dump_tabular(with_prefix=True, with_timestamp=False) pickle.dump(self.policy, open(logger.get_snapshot_dir() + '/bc.pkl', "wb")) prev_time = time.time() logger.remove_tabular_output( 'pretrain_policy.csv', relative_to_snapshot_dir=True, ) logger.add_tabular_output( 'progress.csv', relative_to_snapshot_dir=True, ) ptu.copy_model_params_from_to(self.policy, self.target_policy) if self.post_bc_pretrain_hyperparams: self.set_algorithm_weights(**self.post_bc_pretrain_hyperparams)
def _try_to_eval(self, epoch): if epoch % self.logging_period != 0: return if epoch in self.save_extra_manual_epoch_set: logger.save_extra_data( self.get_extra_data_to_save(epoch), file_name='extra_snapshot_itr{}'.format(epoch), mode='cloudpickle', ) if self._save_extra_every_epoch: logger.save_extra_data(self.get_extra_data_to_save(epoch)) gt.stamp('save-extra') if self._can_evaluate(): self.evaluate(epoch) gt.stamp('eval') params = self.get_epoch_snapshot(epoch) logger.save_itr_params(epoch, params) gt.stamp('save-snapshot') table_keys = logger.get_table_key_set() if self._old_table_keys is not None: assert table_keys == self._old_table_keys, ( "Table keys cannot change from iteration to iteration.") self._old_table_keys = table_keys logger.record_dict( self.trainer.get_diagnostics(), prefix='trainer/', ) logger.record_tabular( "Number of train steps total", self._n_train_steps_total, ) logger.record_tabular( "Number of env steps total", self._n_env_steps_total, ) logger.record_tabular( "Number of rollouts total", self._n_rollouts_total, ) times_itrs = gt.get_times().stamps.itrs train_time = times_itrs['train'][-1] sample_time = times_itrs['sample'][-1] save_extra_time = times_itrs['save-extra'][-1] save_snapshot_time = times_itrs['save-snapshot'][-1] eval_time = times_itrs['eval'][-1] if epoch > 0 else 0 epoch_time = train_time + sample_time + save_extra_time + eval_time total_time = gt.get_times().total logger.record_tabular('in_unsupervised_model', float(self.in_unsupervised_phase)) logger.record_tabular('Train Time (s)', train_time) logger.record_tabular('(Previous) Eval Time (s)', eval_time) logger.record_tabular('Sample Time (s)', sample_time) logger.record_tabular('Save Extra Time (s)', save_extra_time) logger.record_tabular('Save Snapshot Time (s)', save_snapshot_time) logger.record_tabular('Epoch Time (s)', epoch_time) logger.record_tabular('Total Train Time (s)', total_time) logger.record_tabular("Epoch", epoch) logger.dump_tabular(with_prefix=False, with_timestamp=False) else: logger.log("Skipping eval for now.")
def experiment(variant): import numpy as np import torch from torch import nn, optim from tqdm import tqdm import rlkit.torch.pytorch_util as ptu from rlkit.core import logger from rlkit.envs.primitives_make_env import make_env from rlkit.torch.model_based.dreamer.mlp import Mlp, MlpResidual from rlkit.torch.model_based.dreamer.train_world_model import ( compute_world_model_loss, get_dataloader, get_dataloader_rt, get_dataloader_separately, update_network, visualize_rollout, world_model_loss_rt, ) from rlkit.torch.model_based.dreamer.world_models import ( LowlevelRAPSWorldModel, WorldModel, ) env_suite, env_name, env_kwargs = ( variant["env_suite"], variant["env_name"], variant["env_kwargs"], ) max_path_length = variant["env_kwargs"]["max_path_length"] low_level_primitives = variant["low_level_primitives"] num_low_level_actions_per_primitive = variant[ "num_low_level_actions_per_primitive"] low_level_action_dim = variant["low_level_action_dim"] dataloader_kwargs = variant["dataloader_kwargs"] env = make_env(env_suite, env_name, env_kwargs) world_model_kwargs = variant["model_kwargs"] optimizer_kwargs = variant["optimizer_kwargs"] gradient_clip = variant["gradient_clip"] if low_level_primitives: world_model_kwargs["action_dim"] = low_level_action_dim else: world_model_kwargs["action_dim"] = env.action_space.low.shape[0] image_shape = env.image_shape world_model_kwargs["image_shape"] = image_shape scaler = torch.cuda.amp.GradScaler() world_model_loss_kwargs = variant["world_model_loss_kwargs"] clone_primitives = variant["clone_primitives"] clone_primitives_separately = variant["clone_primitives_separately"] clone_primitives_and_train_world_model = variant.get( "clone_primitives_and_train_world_model", False) batch_len = variant.get("batch_len", 100) num_epochs = variant["num_epochs"] loss_to_use = variant.get("loss_to_use", "both") logdir = logger.get_snapshot_dir() if clone_primitives_separately: ( train_dataloaders, test_dataloaders, train_datasets, test_datasets, ) = get_dataloader_separately( variant["datafile"], num_low_level_actions_per_primitive= num_low_level_actions_per_primitive, num_primitives=env.num_primitives, env=env, **dataloader_kwargs, ) elif clone_primitives_and_train_world_model: print("LOADING DATA") ( train_dataloader, test_dataloader, train_dataset, test_dataset, ) = get_dataloader_rt( variant["datafile"], max_path_length=max_path_length * num_low_level_actions_per_primitive + 1, **dataloader_kwargs, ) elif low_level_primitives or clone_primitives: print("LOADING DATA") ( train_dataloader, test_dataloader, train_dataset, test_dataset, ) = get_dataloader( variant["datafile"], max_path_length=max_path_length * num_low_level_actions_per_primitive + 1, **dataloader_kwargs, ) else: train_dataloader, test_dataloader, train_dataset, test_dataset = get_dataloader( variant["datafile"], max_path_length=max_path_length + 1, **dataloader_kwargs, ) if clone_primitives_and_train_world_model: if variant["mlp_act"] == "elu": mlp_act = nn.functional.elu elif variant["mlp_act"] == "relu": mlp_act = nn.functional.relu if variant["mlp_res"]: mlp_class = MlpResidual else: mlp_class = Mlp criterion = nn.MSELoss() primitive_model = mlp_class( hidden_sizes=variant["mlp_hidden_sizes"], output_size=low_level_action_dim, input_size=250 + env.action_space.low.shape[0] + 1, hidden_activation=mlp_act, ).to(ptu.device) world_model_class = LowlevelRAPSWorldModel world_model = world_model_class( primitive_model=primitive_model, **world_model_kwargs, ).to(ptu.device) optimizer = optim.Adam( 
world_model.parameters(), **optimizer_kwargs, ) best_test_loss = np.inf for i in tqdm(range(num_epochs)): eval_statistics = OrderedDict() print("Epoch: ", i) total_primitive_loss = 0 total_world_model_loss = 0 total_div_loss = 0 total_image_pred_loss = 0 total_transition_loss = 0 total_entropy_loss = 0 total_pred_discount_loss = 0 total_reward_pred_loss = 0 total_train_steps = 0 for data in train_dataloader: with torch.cuda.amp.autocast(): ( high_level_actions, obs, rewards, terminals, ), low_level_actions = data obs = obs.to(ptu.device).float() low_level_actions = low_level_actions.to( ptu.device).float() high_level_actions = high_level_actions.to( ptu.device).float() rewards = rewards.to(ptu.device).float() terminals = terminals.to(ptu.device).float() assert all(terminals[:, -1] == 1) rt_idxs = np.arange( num_low_level_actions_per_primitive, obs.shape[1], num_low_level_actions_per_primitive, ) rt_idxs = np.concatenate( [[0], rt_idxs] ) # reset obs, effect of first primitive, second primitive, so on batch_start = np.random.randint(0, obs.shape[1] - batch_len, size=(obs.shape[0])) batch_indices = np.linspace( batch_start, batch_start + batch_len, batch_len, endpoint=False, ).astype(int) ( post, prior, post_dist, prior_dist, image_dist, reward_dist, pred_discount_dist, _, action_preds, ) = world_model( obs, (high_level_actions, low_level_actions), use_network_action=False, batch_indices=batch_indices, rt_idxs=rt_idxs, ) obs = world_model.flatten_obs( obs[np.arange(batch_indices.shape[1]), batch_indices].permute(1, 0, 2), (int(np.prod(image_shape)), ), ) rewards = rewards.reshape(-1, rewards.shape[-1]) terminals = terminals.reshape(-1, terminals.shape[-1]) ( world_model_loss, div, image_pred_loss, reward_pred_loss, transition_loss, entropy_loss, pred_discount_loss, ) = world_model_loss_rt( world_model, image_shape, image_dist, reward_dist, { key: value[np.arange(batch_indices.shape[1]), batch_indices].permute(1, 0, 2).reshape( -1, value.shape[-1]) for key, value in prior.items() }, { key: value[np.arange(batch_indices.shape[1]), batch_indices].permute(1, 0, 2).reshape( -1, value.shape[-1]) for key, value in post.items() }, prior_dist, post_dist, pred_discount_dist, obs, rewards, terminals, **world_model_loss_kwargs, ) batch_start = np.random.randint( 0, low_level_actions.shape[1] - batch_len, size=(low_level_actions.shape[0]), ) batch_indices = np.linspace( batch_start, batch_start + batch_len, batch_len, endpoint=False, ).astype(int) primitive_loss = criterion( action_preds[np.arange(batch_indices.shape[1]), batch_indices].permute(1, 0, 2).reshape( -1, action_preds.shape[-1]), low_level_actions[:, 1:] [np.arange(batch_indices.shape[1]), batch_indices].permute(1, 0, 2).reshape( -1, action_preds.shape[-1]), ) total_primitive_loss += primitive_loss.item() total_world_model_loss += world_model_loss.item() total_div_loss += div.item() total_image_pred_loss += image_pred_loss.item() total_transition_loss += transition_loss.item() total_entropy_loss += entropy_loss.item() total_pred_discount_loss += pred_discount_loss.item() total_reward_pred_loss += reward_pred_loss.item() if loss_to_use == "wm": loss = world_model_loss elif loss_to_use == "primitive": loss = primitive_loss else: loss = world_model_loss + primitive_loss total_train_steps += 1 update_network(world_model, optimizer, loss, gradient_clip, scaler) scaler.update() eval_statistics["train/primitive_loss"] = (total_primitive_loss / total_train_steps) eval_statistics["train/world_model_loss"] = ( total_world_model_loss / total_train_steps) 
eval_statistics["train/image_pred_loss"] = (total_image_pred_loss / total_train_steps) eval_statistics["train/transition_loss"] = (total_transition_loss / total_train_steps) eval_statistics["train/entropy_loss"] = (total_entropy_loss / total_train_steps) eval_statistics["train/pred_discount_loss"] = ( total_pred_discount_loss / total_train_steps) eval_statistics["train/reward_pred_loss"] = ( total_reward_pred_loss / total_train_steps) latest_state_dict = world_model.state_dict().copy() with torch.no_grad(): total_primitive_loss = 0 total_world_model_loss = 0 total_div_loss = 0 total_image_pred_loss = 0 total_transition_loss = 0 total_entropy_loss = 0 total_pred_discount_loss = 0 total_reward_pred_loss = 0 total_loss = 0 total_test_steps = 0 for data in test_dataloader: with torch.cuda.amp.autocast(): ( high_level_actions, obs, rewards, terminals, ), low_level_actions = data obs = obs.to(ptu.device).float() low_level_actions = low_level_actions.to( ptu.device).float() high_level_actions = high_level_actions.to( ptu.device).float() rewards = rewards.to(ptu.device).float() terminals = terminals.to(ptu.device).float() assert all(terminals[:, -1] == 1) rt_idxs = np.arange( num_low_level_actions_per_primitive, obs.shape[1], num_low_level_actions_per_primitive, ) rt_idxs = np.concatenate( [[0], rt_idxs] ) # reset obs, effect of first primitive, second primitive, so on batch_start = np.random.randint(0, obs.shape[1] - batch_len, size=(obs.shape[0])) batch_indices = np.linspace( batch_start, batch_start + batch_len, batch_len, endpoint=False, ).astype(int) ( post, prior, post_dist, prior_dist, image_dist, reward_dist, pred_discount_dist, _, action_preds, ) = world_model( obs, (high_level_actions, low_level_actions), use_network_action=False, batch_indices=batch_indices, rt_idxs=rt_idxs, ) obs = world_model.flatten_obs( obs[np.arange(batch_indices.shape[1]), batch_indices].permute(1, 0, 2), (int(np.prod(image_shape)), ), ) rewards = rewards.reshape(-1, rewards.shape[-1]) terminals = terminals.reshape(-1, terminals.shape[-1]) ( world_model_loss, div, image_pred_loss, reward_pred_loss, transition_loss, entropy_loss, pred_discount_loss, ) = world_model_loss_rt( world_model, image_shape, image_dist, reward_dist, { key: value[np.arange(batch_indices.shape[1]), batch_indices].permute( 1, 0, 2).reshape( -1, value.shape[-1]) for key, value in prior.items() }, { key: value[np.arange(batch_indices.shape[1]), batch_indices].permute( 1, 0, 2).reshape( -1, value.shape[-1]) for key, value in post.items() }, prior_dist, post_dist, pred_discount_dist, obs, rewards, terminals, **world_model_loss_kwargs, ) batch_start = np.random.randint( 0, low_level_actions.shape[1] - batch_len, size=(low_level_actions.shape[0]), ) batch_indices = np.linspace( batch_start, batch_start + batch_len, batch_len, endpoint=False, ).astype(int) primitive_loss = criterion( action_preds[np.arange(batch_indices.shape[1]), batch_indices].permute( 1, 0, 2).reshape( -1, action_preds.shape[-1]), low_level_actions[:, 1:] [np.arange(batch_indices.shape[1]), batch_indices].permute(1, 0, 2).reshape( -1, action_preds.shape[-1]), ) total_primitive_loss += primitive_loss.item() total_world_model_loss += world_model_loss.item() total_div_loss += div.item() total_image_pred_loss += image_pred_loss.item() total_transition_loss += transition_loss.item() total_entropy_loss += entropy_loss.item() total_pred_discount_loss += pred_discount_loss.item() total_reward_pred_loss += reward_pred_loss.item() total_loss += world_model_loss.item( ) + 
primitive_loss.item() total_test_steps += 1 eval_statistics["test/primitive_loss"] = ( total_primitive_loss / total_test_steps) eval_statistics["test/world_model_loss"] = ( total_world_model_loss / total_test_steps) eval_statistics["test/image_pred_loss"] = ( total_image_pred_loss / total_test_steps) eval_statistics["test/transition_loss"] = ( total_transition_loss / total_test_steps) eval_statistics["test/entropy_loss"] = (total_entropy_loss / total_test_steps) eval_statistics["test/pred_discount_loss"] = ( total_pred_discount_loss / total_test_steps) eval_statistics["test/reward_pred_loss"] = ( total_reward_pred_loss / total_test_steps) if (total_loss / total_test_steps) <= best_test_loss: best_test_loss = total_loss / total_test_steps os.makedirs(logdir + "/models/", exist_ok=True) best_wm_state_dict = world_model.state_dict().copy() torch.save( best_wm_state_dict, logdir + "/models/world_model.pt", ) if i % variant["plotting_period"] == 0: print("Best test loss", best_test_loss) world_model.load_state_dict(best_wm_state_dict) visualize_wm( env, world_model, train_dataset.outputs, train_dataset.inputs[1], test_dataset.outputs, test_dataset.inputs[1], logdir, max_path_length, low_level_primitives, num_low_level_actions_per_primitive, primitive_model=primitive_model, ) world_model.load_state_dict(latest_state_dict) logger.record_dict(eval_statistics, prefix="") logger.dump_tabular(with_prefix=False, with_timestamp=False) elif clone_primitives_separately: world_model.load_state_dict(torch.load(variant["world_model_path"])) criterion = nn.MSELoss() primitives = [] for p in range(env.num_primitives): arguments_size = train_datasets[p].inputs[0].shape[-1] m = Mlp( hidden_sizes=variant["mlp_hidden_sizes"], output_size=low_level_action_dim, input_size=world_model.feature_size + arguments_size, hidden_activation=torch.nn.functional.relu, ).to(ptu.device) if variant.get("primitives_path", None): m.load_state_dict( torch.load(variant["primitives_path"] + "primitive_model_{}.pt".format(p))) primitives.append(m) optimizers = [ optim.Adam(p.parameters(), **optimizer_kwargs) for p in primitives ] for i in tqdm(range(num_epochs)): if i % variant["plotting_period"] == 0: visualize_rollout( env, None, None, world_model, logdir, max_path_length, use_env=True, forcing="none", tag="none", low_level_primitives=low_level_primitives, num_low_level_actions_per_primitive= num_low_level_actions_per_primitive, primitive_model=primitives, use_separate_primitives=True, ) visualize_rollout( env, None, None, world_model, logdir, max_path_length, use_env=True, forcing="teacher", tag="none", low_level_primitives=low_level_primitives, num_low_level_actions_per_primitive= num_low_level_actions_per_primitive, primitive_model=primitives, use_separate_primitives=True, ) visualize_rollout( env, None, None, world_model, logdir, max_path_length, use_env=True, forcing="self", tag="none", low_level_primitives=low_level_primitives, num_low_level_actions_per_primitive= num_low_level_actions_per_primitive, primitive_model=primitives, use_separate_primitives=True, ) eval_statistics = OrderedDict() print("Epoch: ", i) for p, ( train_dataloader, test_dataloader, primitive_model, optimizer, ) in enumerate( zip(train_dataloaders, test_dataloaders, primitives, optimizers)): total_loss = 0 total_train_steps = 0 for data in train_dataloader: with torch.cuda.amp.autocast(): (arguments, obs), actions = data obs = obs.to(ptu.device).float() actions = actions.to(ptu.device).float() arguments = arguments.to(ptu.device).float() action_preds = 
world_model( obs, (arguments, actions), primitive_model, use_network_action=False, )[-1] loss = criterion(action_preds, actions) total_loss += loss.item() total_train_steps += 1 update_network(primitive_model, optimizer, loss, gradient_clip, scaler) scaler.update() eval_statistics["train/primitive_loss {}".format(p)] = ( total_loss / total_train_steps) best_test_loss = np.inf with torch.no_grad(): total_loss = 0 total_test_steps = 0 for data in test_dataloader: with torch.cuda.amp.autocast(): (high_level_actions, obs), actions = data obs = obs.to(ptu.device).float() actions = actions.to(ptu.device).float() high_level_actions = high_level_actions.to( ptu.device).float() action_preds = world_model( obs, (high_level_actions, actions), primitive_model, use_network_action=False, )[-1] loss = criterion(action_preds, actions) total_loss += loss.item() total_test_steps += 1 eval_statistics["test/primitive_loss {}".format(p)] = ( total_loss / total_test_steps) if (total_loss / total_test_steps) <= best_test_loss: best_test_loss = total_loss / total_test_steps os.makedirs(logdir + "/models/", exist_ok=True) torch.save( primitive_model.state_dict(), logdir + "/models/primitive_model_{}.pt".format(p), ) logger.record_dict(eval_statistics, prefix="") logger.dump_tabular(with_prefix=False, with_timestamp=False) visualize_rollout( env, None, None, world_model, logdir, max_path_length, use_env=True, forcing="none", tag="none", low_level_primitives=low_level_primitives, num_low_level_actions_per_primitive= num_low_level_actions_per_primitive, primitive_model=primitives, use_separate_primitives=True, ) elif clone_primitives: world_model.load_state_dict(torch.load(variant["world_model_path"])) criterion = nn.MSELoss() primitive_model = Mlp( hidden_sizes=variant["mlp_hidden_sizes"], output_size=low_level_action_dim, input_size=world_model.feature_size + env.action_space.low.shape[0] + 1, hidden_activation=torch.nn.functional.relu, ).to(ptu.device) optimizer = optim.Adam( primitive_model.parameters(), **optimizer_kwargs, ) for i in tqdm(range(num_epochs)): if i % variant["plotting_period"] == 0: visualize_rollout( env, None, None, world_model, logdir, max_path_length, use_env=True, forcing="none", tag="none", low_level_primitives=low_level_primitives, num_low_level_actions_per_primitive= num_low_level_actions_per_primitive, primitive_model=primitive_model, ) visualize_rollout( env, train_dataset.outputs, train_dataset.inputs[1], world_model, logdir, max_path_length, use_env=False, forcing="teacher", tag="train", low_level_primitives=low_level_primitives, num_low_level_actions_per_primitive= num_low_level_actions_per_primitive - 1, ) visualize_rollout( env, test_dataset.outputs, test_dataset.inputs[1], world_model, logdir, max_path_length, use_env=False, forcing="teacher", tag="test", low_level_primitives=low_level_primitives, num_low_level_actions_per_primitive= num_low_level_actions_per_primitive - 1, ) eval_statistics = OrderedDict() print("Epoch: ", i) total_loss = 0 total_train_steps = 0 for data in train_dataloader: with torch.cuda.amp.autocast(): (high_level_actions, obs), actions = data obs = obs.to(ptu.device).float() actions = actions.to(ptu.device).float() high_level_actions = high_level_actions.to( ptu.device).float() action_preds = world_model( obs, (high_level_actions, actions), primitive_model, use_network_action=False, )[-1] loss = criterion(action_preds, actions) total_loss += loss.item() total_train_steps += 1 update_network(primitive_model, optimizer, loss, gradient_clip, scaler) 
scaler.update() eval_statistics[ "train/primitive_loss"] = total_loss / total_train_steps best_test_loss = np.inf with torch.no_grad(): total_loss = 0 total_test_steps = 0 for data in test_dataloader: with torch.cuda.amp.autocast(): (high_level_actions, obs), actions = data obs = obs.to(ptu.device).float() actions = actions.to(ptu.device).float() high_level_actions = high_level_actions.to( ptu.device).float() action_preds = world_model( obs, (high_level_actions, actions), primitive_model, use_network_action=False, )[-1] loss = criterion(action_preds, actions) total_loss += loss.item() total_test_steps += 1 eval_statistics[ "test/primitive_loss"] = total_loss / total_test_steps if (total_loss / total_test_steps) <= best_test_loss: best_test_loss = total_loss / total_test_steps os.makedirs(logdir + "/models/", exist_ok=True) torch.save( primitive_model.state_dict(), logdir + "/models/primitive_model.pt", ) logger.record_dict(eval_statistics, prefix="") logger.dump_tabular(with_prefix=False, with_timestamp=False) else: world_model = WorldModel(**world_model_kwargs).to(ptu.device) optimizer = optim.Adam( world_model.parameters(), **optimizer_kwargs, ) for i in tqdm(range(num_epochs)): if i % variant["plotting_period"] == 0: visualize_wm( env, world_model, train_dataset.inputs, train_dataset.outputs, test_dataset.inputs, test_dataset.outputs, logdir, max_path_length, low_level_primitives, num_low_level_actions_per_primitive, ) eval_statistics = OrderedDict() print("Epoch: ", i) total_wm_loss = 0 total_div_loss = 0 total_image_pred_loss = 0 total_transition_loss = 0 total_entropy_loss = 0 total_train_steps = 0 for data in train_dataloader: with torch.cuda.amp.autocast(): actions, obs = data obs = obs.to(ptu.device).float() actions = actions.to(ptu.device).float() post, prior, post_dist, prior_dist, image_dist = world_model( obs, actions)[:5] obs = world_model.flatten_obs(obs.permute( 1, 0, 2), (int(np.prod(image_shape)), )) ( world_model_loss, div, image_pred_loss, transition_loss, entropy_loss, ) = compute_world_model_loss( world_model, image_shape, image_dist, prior, post, prior_dist, post_dist, obs, **world_model_loss_kwargs, ) total_wm_loss += world_model_loss.item() total_div_loss += div.item() total_image_pred_loss += image_pred_loss.item() total_transition_loss += transition_loss.item() total_entropy_loss += entropy_loss.item() total_train_steps += 1 update_network(world_model, optimizer, world_model_loss, gradient_clip, scaler) scaler.update() eval_statistics[ "train/wm_loss"] = total_wm_loss / total_train_steps eval_statistics[ "train/div_loss"] = total_div_loss / total_train_steps eval_statistics["train/image_pred_loss"] = (total_image_pred_loss / total_train_steps) eval_statistics["train/transition_loss"] = (total_transition_loss / total_train_steps) eval_statistics["train/entropy_loss"] = (total_entropy_loss / total_train_steps) best_test_loss = np.inf with torch.no_grad(): total_wm_loss = 0 total_div_loss = 0 total_image_pred_loss = 0 total_transition_loss = 0 total_entropy_loss = 0 total_train_steps = 0 total_test_steps = 0 for data in test_dataloader: with torch.cuda.amp.autocast(): actions, obs = data obs = obs.to(ptu.device).float() actions = actions.to(ptu.device).float() post, prior, post_dist, prior_dist, image_dist = world_model( obs, actions)[:5] obs = world_model.flatten_obs(obs.permute( 1, 0, 2), (int(np.prod(image_shape)), )) ( world_model_loss, div, image_pred_loss, transition_loss, entropy_loss, ) = compute_world_model_loss( world_model, image_shape, image_dist, prior, 
post, prior_dist, post_dist, obs, **world_model_loss_kwargs, ) total_wm_loss += world_model_loss.item() total_div_loss += div.item() total_image_pred_loss += image_pred_loss.item() total_transition_loss += transition_loss.item() total_entropy_loss += entropy_loss.item() total_test_steps += 1 eval_statistics[ "test/wm_loss"] = total_wm_loss / total_test_steps eval_statistics[ "test/div_loss"] = total_div_loss / total_test_steps eval_statistics["test/image_pred_loss"] = ( total_image_pred_loss / total_test_steps) eval_statistics["test/transition_loss"] = ( total_transition_loss / total_test_steps) eval_statistics["test/entropy_loss"] = (total_entropy_loss / total_test_steps) if (total_wm_loss / total_test_steps) <= best_test_loss: best_test_loss = total_wm_loss / total_test_steps os.makedirs(logdir + "/models/", exist_ok=True) torch.save( world_model.state_dict(), logdir + "/models/world_model.pt", ) logger.record_dict(eval_statistics, prefix="") logger.dump_tabular(with_prefix=False, with_timestamp=False) world_model.load_state_dict( torch.load(logdir + "/models/world_model.pt")) visualize_wm( env, world_model, train_dataset, test_dataset, logdir, max_path_length, low_level_primitives, num_low_level_actions_per_primitive, )
def _log_vae_stats(self):
    logger.record_dict(
        self.vae_trainer.get_diagnostics(),
        prefix='vae_trainer/',
    )
def _log_stats(self, epoch): logger.log("Epoch {} finished".format(epoch), with_timestamp=True) """ dump video of policy """ if epoch % self.dump_video_interval == 0: imsize = self.expl_env.imsize env = self.eval_env dump_path = logger.get_snapshot_dir() rollout_func = rollout video_name = "{}_sample_goal.gif".format(epoch) latent_distance_name = "{}_latent_sample_goal.png".format(epoch) dump_video(env, self.eval_data_collector._policy, osp.join(dump_path, video_name), rollout_function=rollout_func, imsize=imsize, horizon=self.max_path_length, rows=1, columns=8) plot_latent_dist(env, self.eval_data_collector._policy, save_name=osp.join(dump_path, latent_distance_name), rollout_function=rollout_func, horizon=self.max_path_length) old_goal_sampling_mode, old_decode_goals = env._goal_sampling_mode, env.decode_goals env._goal_sampling_mode = 'reset_of_env' env.decode_goals = False video_name = "{}_reset.gif".format(epoch) latent_distance_name = "{}_latent_reset.png".format(epoch) dump_video(env, self.eval_data_collector._policy, osp.join(dump_path, video_name), rollout_function=rollout_func, imsize=imsize, horizon=self.max_path_length, rows=1, columns=8) plot_latent_dist(env, self.eval_data_collector._policy, save_name=osp.join(dump_path, latent_distance_name), rollout_function=rollout_func, horizon=self.max_path_length) self.eval_env._goal_sampling_mode = old_goal_sampling_mode self.eval_env.decode_goals = old_decode_goals """ Replay Buffer """ logger.record_dict(self.replay_buffer.get_diagnostics(), prefix='replay_buffer/') """ Trainer """ logger.record_dict(self.trainer.get_diagnostics(), prefix='trainer/') """ Exploration """ logger.record_dict(self.expl_data_collector.get_diagnostics(), prefix='exploration/') expl_paths = self.expl_data_collector.get_epoch_paths() if hasattr(self.expl_env, 'get_diagnostics'): logger.record_dict( self.expl_env.get_diagnostics(expl_paths), prefix='exploration/', ) logger.record_dict( eval_util.get_generic_path_information(expl_paths), prefix="exploration/", ) """ Evaluation """ logger.record_dict( self.eval_data_collector.get_diagnostics(), prefix='evaluation/', ) eval_paths = self.eval_data_collector.get_epoch_paths() if hasattr(self.eval_env, 'get_diagnostics'): logger.record_dict( self.eval_env.get_diagnostics(eval_paths), prefix='evaluation/', ) logger.record_dict( eval_util.get_generic_path_information(eval_paths), prefix="evaluation/", ) """ Misc """ gt.stamp('logging') logger.record_dict(_get_epoch_timings()) logger.record_tabular('Epoch', epoch) logger.dump_tabular(with_prefix=False, with_timestamp=False)