def train_ae(ae_trainer, training_distrib, num_epochs=100,
             num_batches_per_epoch=500, batch_size=512,
             goal_key='image_desired_goal', rl_csv_fname='progress.csv'):
    from rlkit.core import logger
    logger.remove_tabular_output(rl_csv_fname, relative_to_snapshot_dir=True)
    logger.add_tabular_output('ae_progress.csv',
                              relative_to_snapshot_dir=True)
    for epoch in range(num_epochs):
        for batch_num in range(num_batches_per_epoch):
            goals = ptu.from_numpy(
                training_distrib.sample(batch_size)[goal_key])
            batch = dict(raw_next_observations=goals)
            ae_trainer.train_from_torch(batch)
        log = OrderedDict()
        log['epoch'] = epoch
        append_log(log, ae_trainer.eval_statistics, prefix='ae/')
        logger.record_dict(log)
        logger.dump_tabular(with_prefix=True, with_timestamp=False)
        ae_trainer.end_epoch(epoch)
    logger.add_tabular_output(rl_csv_fname, relative_to_snapshot_dir=True)
    logger.remove_tabular_output('ae_progress.csv',
                                 relative_to_snapshot_dir=True)
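Many of the functions in this collection repeat the same pattern: temporarily redirect the logger's tabular output to a side CSV while a sub-phase (VAE/AE pretraining, BC pretraining, etc.) runs, then restore progress.csv afterwards. The following is a minimal sketch of that pattern as a reusable context manager, assuming the rlkit logger calls used above; the helper name tabular_output_redirected is hypothetical and not part of any of these codebases.

from contextlib import contextmanager

from rlkit.core import logger


@contextmanager
def tabular_output_redirected(temp_csv, main_csv='progress.csv'):
    """Temporarily log tabular output to `temp_csv` instead of `main_csv`."""
    logger.remove_tabular_output(main_csv, relative_to_snapshot_dir=True)
    logger.add_tabular_output(temp_csv, relative_to_snapshot_dir=True)
    try:
        yield
    finally:
        # Restore the main progress file even if the sub-phase raises.
        logger.remove_tabular_output(temp_csv, relative_to_snapshot_dir=True)
        logger.add_tabular_output(main_csv, relative_to_snapshot_dir=True)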
def train_vae_and_update_variant(variant): from rlkit.core import logger skewfit_variant = variant["skewfit_variant"] train_vae_variant = variant["train_vae_variant"] if skewfit_variant.get("vae_path", None) is None: logger.remove_tabular_output("progress.csv", relative_to_snapshot_dir=True) logger.add_tabular_output("vae_progress.csv", relative_to_snapshot_dir=True) vae, vae_train_data, vae_test_data = train_vae(train_vae_variant, return_data=True) if skewfit_variant.get("save_vae_data", False): skewfit_variant["vae_train_data"] = vae_train_data skewfit_variant["vae_test_data"] = vae_test_data logger.save_extra_data(vae, "vae.pkl", mode="pickle") logger.remove_tabular_output("vae_progress.csv", relative_to_snapshot_dir=True) logger.add_tabular_output("progress.csv", relative_to_snapshot_dir=True) skewfit_variant["vae_path"] = vae # just pass the VAE directly else: if skewfit_variant.get("save_vae_data", False): vae_train_data, vae_test_data, info = generate_vae_dataset( train_vae_variant["generate_vae_dataset_kwargs"]) skewfit_variant["vae_train_data"] = vae_train_data skewfit_variant["vae_test_data"] = vae_test_data
def train_vae_and_update_variant(variant):
    from rlkit.core import logger
    skewfit_variant = variant['skewfit_variant']
    train_vae_variant = variant['train_vae_variant']
    if skewfit_variant.get('vae_path', None) is None:
        logger.remove_tabular_output('progress.csv',
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output('vae_progress.csv',
                                  relative_to_snapshot_dir=True)
        vae, vae_train_data, vae_test_data = train_vae(
            train_vae_variant, variant['other_variant'], return_data=True)
        if skewfit_variant.get('save_vae_data', False):
            skewfit_variant['vae_train_data'] = vae_train_data
            skewfit_variant['vae_test_data'] = vae_test_data
        logger.save_extra_data(vae, 'vae.pkl', mode='pickle')
        logger.remove_tabular_output(
            'vae_progress.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )
        skewfit_variant['vae_path'] = vae  # just pass the VAE directly
    else:
        if skewfit_variant.get('save_vae_data', False):
            vae_train_data, vae_test_data, info = generate_vae_dataset(
                train_vae_variant['generate_vae_dataset_kwargs'])
            skewfit_variant['vae_train_data'] = vae_train_data
            skewfit_variant['vae_test_data'] = vae_test_data
def train_vae_and_update_config(cfgs):
    from rlkit.core import logger
    if cfgs.VAE_TRAINER.get('vae_path', None) is None:
        logger.remove_tabular_output('progress.csv',
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output('vae_progress.csv',
                                  relative_to_snapshot_dir=True)
        vae, vae_train_data, vae_test_data = train_vae(cfgs, return_data=True)
        cfgs.VAE_TRAINER.vae_path = vae  # just pass the VAE directly
        if cfgs.VAE_TRAINER.save_vae_data:
            cfgs.VAE_TRAINER.train_data = vae_train_data
            cfgs.VAE_TRAINER.test_data = vae_test_data
        logger.remove_tabular_output(
            'vae_progress.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )
    else:
        vae_train_data, vae_test_data, _ = generate_vae_dataset(cfgs)
        if cfgs.VAE_TRAINER.save_vae_data:
            cfgs.VAE_TRAINER.train_data = vae_train_data
            cfgs.VAE_TRAINER.test_data = vae_test_data
def run(self):
    if self.progress_csv_file_name != 'progress.csv':
        logger.remove_tabular_output('progress.csv',
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output(
            self.progress_csv_file_name,
            relative_to_snapshot_dir=True,
        )
    timer.return_global_times = True
    for _ in range(self.num_iters):
        self._begin_epoch()
        timer.start_timer('saving')
        logger.save_itr_params(self.epoch, self._get_snapshot())
        timer.stop_timer('saving')
        log_dict, _ = self._train()
        logger.record_dict(log_dict)
        logger.dump_tabular(with_prefix=True, with_timestamp=False)
        self._end_epoch()
    logger.save_itr_params(self.epoch, self._get_snapshot())
    if self.progress_csv_file_name != 'progress.csv':
        logger.remove_tabular_output(
            self.progress_csv_file_name,
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )
def continue_experiment(args):
    logger.add_text_output('./d_text.txt')
    logger.add_tabular_output('./d_tabular.txt')
    logger.set_snapshot_dir('./snaps')

    extra = joblib.load(args.extra)
    algorithm = extra['algorithm']
    algorithm.farmlist_base = [('0.0.0.0', 15)]
    algorithm.refarm()
    if ptu.gpu_enabled():
        algorithm.cuda()
    algorithm.train()
def pretrain_q_with_bc_data(self, batch_size):
    logger.remove_tabular_output('progress.csv',
                                 relative_to_snapshot_dir=True)
    logger.add_tabular_output('pretrain_q.csv',
                              relative_to_snapshot_dir=True)

    self.update_policy = True  # then train policy and Q function together
    prev_time = time.time()
    for i in range(self.num_pretrain_steps):
        self.eval_statistics = dict()
        if i % self.pretraining_logging_period == 0:
            self._need_to_update_eval_statistics = True
        train_data = self.replay_buffer.random_batch(batch_size)
        train_data = np_to_pytorch_batch(train_data)
        obs = train_data['observations']
        next_obs = train_data['next_observations']
        # goals = train_data['resampled_goals']
        train_data['observations'] = obs  # torch.cat((obs, goals), dim=1)
        train_data['next_observations'] = next_obs  # torch.cat((next_obs, goals), dim=1)
        self.train_from_torch(train_data)

        if self.do_pretrain_rollouts and i % self.pretraining_env_logging_period == 0:
            total_ret = self.do_rollouts()
            print("Return at step {} : {}".format(i, total_ret / 20))

        if i % self.pretraining_logging_period == 0:
            if self.do_pretrain_rollouts:
                self.eval_statistics["pretrain_bc/avg_return"] = total_ret / 20
            self.eval_statistics["batch"] = i
            self.eval_statistics["epoch_time"] = time.time() - prev_time
            logger.record_dict(self.eval_statistics)
            logger.dump_tabular(with_prefix=True, with_timestamp=False)
            prev_time = time.time()

    logger.remove_tabular_output(
        'pretrain_q.csv',
        relative_to_snapshot_dir=True,
    )
    logger.add_tabular_output(
        'progress.csv',
        relative_to_snapshot_dir=True,
    )

    self._need_to_update_eval_statistics = True
    self.eval_statistics = dict()
def experiment(variant):
    logger.add_text_output('./d_text.txt')
    logger.add_tabular_output('./d_tabular.txt')
    logger.set_snapshot_dir('./snaps')

    farmer = Farmer([('0.0.0.0', 1)])
    remote_env = farmer.force_acq_env()
    remote_env.set_spaces()
    env = NormalizedBoxEnv(remote_env)
    es = GaussianStrategy(
        action_space=env.action_space,
        max_sigma=0.1,
        min_sigma=0.1,  # Constant sigma
    )
    obs_dim = env.observation_space.low.size
    action_dim = env.action_space.low.size
    qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[256, 256],
    )
    qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[256, 256],
    )
    policy = TanhMlpPolicy(
        input_size=obs_dim,
        output_size=action_dim,
        hidden_sizes=[256, 256],
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    algorithm = TD3(env,
                    qf1=qf1,
                    qf2=qf2,
                    policy=policy,
                    exploration_policy=exploration_policy,
                    **variant['algo_kwargs'])
    if ptu.gpu_enabled():
        algorithm.cuda()
    algorithm.train()
def transfer_encoder_goal_conditioned_sac_experiment(train_variant,
                                                     transfer_variant=None):
    rl_csv_fname = 'progress.csv'
    retrain_rl_csv_fname = 'retrain_progress.csv'
    train_results = encoder_goal_conditioned_sac_experiment(
        **train_variant, rl_csv_fname=rl_csv_fname)
    logger.remove_tabular_output(rl_csv_fname, relative_to_snapshot_dir=True)
    logger.add_tabular_output(retrain_rl_csv_fname,
                              relative_to_snapshot_dir=True)
    if not transfer_variant:
        transfer_variant = {}
    transfer_variant = recursive_dict_update(train_variant, transfer_variant)
    print('Starting transfer exp')
    encoder_goal_conditioned_sac_experiment(
        **transfer_variant,
        is_retraining_from_scratch=True,
        train_results=train_results,
        rl_csv_fname=retrain_rl_csv_fname)
def pretrain_q_with_bc_data(self, batch_size):
    logger.remove_tabular_output('progress.csv',
                                 relative_to_snapshot_dir=True)
    logger.add_tabular_output('pretrain_q.csv',
                              relative_to_snapshot_dir=True)

    prev_time = time.time()
    for i in range(self.num_pretrain_steps):
        self.eval_statistics = dict()
        if i % self.pretraining_logging_period == 0:
            self._need_to_update_eval_statistics = True
        train_data = self.replay_buffer.random_batch(self.bc_batch_size)
        train_data = np_to_pytorch_batch(train_data)
        obs = train_data['observations']
        next_obs = train_data['next_observations']
        # goals = train_data['resampled_goals']
        train_data['observations'] = obs  # torch.cat((obs, goals), dim=1)
        train_data['next_observations'] = next_obs  # torch.cat((next_obs, goals), dim=1)
        self.train_from_torch(train_data, pretrain=True)

        if i % self.pretraining_logging_period == 0:
            self.eval_statistics["batch"] = i
            self.eval_statistics["epoch_time"] = time.time() - prev_time
            stats_with_prefix = add_prefix(self.eval_statistics,
                                           prefix="trainer/")
            logger.record_dict(stats_with_prefix)
            logger.dump_tabular(with_prefix=True, with_timestamp=False)
            prev_time = time.time()

    logger.remove_tabular_output(
        'pretrain_q.csv',
        relative_to_snapshot_dir=True,
    )
    logger.add_tabular_output(
        'progress.csv',
        relative_to_snapshot_dir=True,
    )

    self._need_to_update_eval_statistics = True
    self.eval_statistics = dict()
def grill_her_full_experiment(variant, mode='td3'):
    train_vae_variant = variant['train_vae_variant']
    grill_variant = variant['grill_variant']
    env_class = variant['env_class']
    env_kwargs = variant['env_kwargs']
    init_camera = variant['init_camera']
    train_vae_variant['generate_vae_dataset_kwargs']['env_class'] = env_class
    train_vae_variant['generate_vae_dataset_kwargs']['env_kwargs'] = env_kwargs
    train_vae_variant['generate_vae_dataset_kwargs']['init_camera'] = (
        init_camera)
    grill_variant['env_class'] = env_class
    grill_variant['env_kwargs'] = env_kwargs
    grill_variant['init_camera'] = init_camera
    if 'vae_paths' not in grill_variant:
        logger.remove_tabular_output('progress.csv',
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output('vae_progress.csv',
                                  relative_to_snapshot_dir=True)
        vae = train_vae(train_vae_variant)
        rdim = train_vae_variant['representation_size']
        vae_file = logger.save_extra_data(vae, 'vae.pkl', mode='pickle')
        logger.remove_tabular_output(
            'vae_progress.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )
        grill_variant['vae_paths'] = {
            str(rdim): vae_file,
        }
        grill_variant['rdim'] = str(rdim)
    if mode == 'td3':
        grill_her_td3_experiment(variant['grill_variant'])
    elif mode == 'twin-sac':
        grill_her_twin_sac_experiment(variant['grill_variant'])
    elif mode == 'sac':
        grill_her_sac_experiment(variant['grill_variant'])
def experiment(variant):
    logger.add_text_output('./d_text.txt')
    logger.add_tabular_output('./d_tabular.txt')
    logger.set_snapshot_dir('./snaps')

    farmer = Farmer([('0.0.0.0', 1)])
    remote_env = farmer.force_acq_env()
    remote_env.set_spaces()
    env = NormalizedBoxEnv(remote_env)
    es = OUStrategy(action_space=env.action_space)
    obs_dim = env.observation_space.low.size
    action_dim = env.action_space.low.size
    net_size = variant['net_size']
    qf = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[net_size, net_size],
    )
    policy = TanhMlpPolicy(
        input_size=obs_dim,
        output_size=action_dim,
        hidden_sizes=[net_size, net_size],
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    algorithm = DDPG(env,
                     qf=qf,
                     policy=policy,
                     exploration_policy=exploration_policy,
                     **variant['algo_params'])
    if ptu.gpu_enabled():
        algorithm.cuda()
    algorithm.train()
def experiment(variant):
    logger.add_text_output('./d_text.txt')
    logger.add_tabular_output('./d_tabular.txt')
    logger.set_snapshot_dir('./snaps')

    farmer = Farmer([('0.0.0.0', 1)])
    remote_env = farmer.force_acq_env()
    remote_env.set_spaces()
    env = NormalizedBoxEnv(remote_env)
    obs_dim = int(np.prod(env.observation_space.shape))
    action_dim = int(np.prod(env.action_space.shape))
    net_size = variant['net_size']
    qf = FlattenMlp(
        hidden_sizes=[net_size, net_size],
        input_size=obs_dim + action_dim,
        output_size=1,
    )
    vf = FlattenMlp(
        hidden_sizes=[net_size, net_size],
        input_size=obs_dim,
        output_size=1,
    )
    policy = TanhGaussianPolicy(
        hidden_sizes=[net_size, net_size],
        obs_dim=obs_dim,
        action_dim=action_dim,
    )
    algorithm = SoftActorCritic(env=env,
                                training_env=env,
                                policy=policy,
                                qf=qf,
                                vf=vf,
                                **variant['algo_params'])
    if ptu.gpu_enabled():
        algorithm.cuda()
    algorithm.train()
def train_vae(variant, env_kwargs, env_id, env_class, imsize, init_camera,
              return_data=False):
    from rlkit.misc.ml_util import PiecewiseLinearSchedule, ConstantSchedule
    from rlkit.torch.vae.conv_vae import (
        ConvVAE,
        SpatialAutoEncoder,
        AutoEncoder,
    )
    import rlkit.torch.vae.conv_vae as conv_vae
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    from rlkit.core import logger
    import rlkit.torch.pytorch_util as ptu
    from rlkit.pythonplusplus import identity
    import torch

    logger.remove_tabular_output(
        'progress.csv', relative_to_snapshot_dir=True
    )
    logger.add_tabular_output(
        'model_progress.csv', relative_to_snapshot_dir=True
    )

    beta = variant["beta"]
    representation_size = variant.get(
        "representation_size",
        variant.get("latent_sizes", variant.get("embedding_dim", None)))
    use_linear_dynamics = variant.get('use_linear_dynamics', False)
    generate_vae_dataset_fctn = variant.get('generate_vae_data_fctn',
                                            generate_vae_dataset)
    variant['generate_vae_dataset_kwargs']['use_linear_dynamics'] = (
        use_linear_dynamics)
    variant['generate_vae_dataset_kwargs']['batch_size'] = (
        variant['algo_kwargs']['batch_size'])
    train_dataset, test_dataset, info = generate_vae_dataset_fctn(
        env_kwargs, env_id, env_class, imsize, init_camera,
        **variant['generate_vae_dataset_kwargs'])

    if use_linear_dynamics:
        action_dim = train_dataset.data['actions'].shape[2]

    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None
    if 'context_schedule' in variant:
        schedule = variant['context_schedule']
        if type(schedule) is dict:
            context_schedule = PiecewiseLinearSchedule(**schedule)
        else:
            context_schedule = ConstantSchedule(schedule)
        variant['algo_kwargs']['context_schedule'] = context_schedule
    if variant.get('decoder_activation', None) == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity
    architecture = variant['vae_kwargs'].get('architecture', None)
    if not architecture and imsize == 84:
        architecture = conv_vae.imsize84_default_architecture
    elif not architecture and imsize == 48:
        architecture = conv_vae.imsize48_default_architecture
    variant['vae_kwargs']['architecture'] = architecture
    variant['vae_kwargs']['imsize'] = imsize

    if variant['algo_kwargs'].get('is_auto_encoder', False):
        model = AutoEncoder(representation_size,
                            decoder_output_activation=decoder_activation,
                            **variant['vae_kwargs'])
    elif variant.get('use_spatial_auto_encoder', False):
        model = SpatialAutoEncoder(
            representation_size,
            decoder_output_activation=decoder_activation,
            **variant['vae_kwargs'])
    else:
        vae_class = variant.get('vae_class', ConvVAE)
        if use_linear_dynamics:
            model = vae_class(representation_size,
                              decoder_output_activation=decoder_activation,
                              action_dim=action_dim,
                              **variant['vae_kwargs'])
        else:
            model = vae_class(representation_size,
                              decoder_output_activation=decoder_activation,
                              **variant['vae_kwargs'])
    model.to(ptu.device)

    vae_trainer_class = variant.get('vae_trainer_class', ConvVAETrainer)
    trainer = vae_trainer_class(model,
                                beta=beta,
                                beta_schedule=beta_schedule,
                                **variant['algo_kwargs'])
    save_period = variant['save_period']
    dump_skew_debug_plots = variant.get('dump_skew_debug_plots', False)
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        trainer.train_epoch(epoch, train_dataset)
        trainer.test_epoch(epoch, test_dataset)
        if should_save_imgs:
            trainer.dump_reconstructions(epoch)
            trainer.dump_samples(epoch)
            if dump_skew_debug_plots:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)
        stats = trainer.get_diagnostics()
        for k, v in stats.items():
            logger.record_tabular(k, v)
        logger.dump_tabular()
        trainer.end_epoch(epoch)
        if epoch % 50 == 0:
            logger.save_itr_params(epoch, model)
    logger.save_extra_data(model, 'vae.pkl', mode='pickle')
    logger.remove_tabular_output(
        'model_progress.csv',
        relative_to_snapshot_dir=True,
    )
    logger.add_tabular_output(
        'progress.csv',
        relative_to_snapshot_dir=True,
    )
    if return_data:
        return model, train_dataset, test_dataset
    return model
def setup_logger( exp_prefix="default", variant=None, text_log_file="debug.log", variant_log_file="variant.json", tabular_log_file="progress.csv", snapshot_mode="last", snapshot_gap=1, log_tabular_only=False, log_dir=None, git_infos=None, script_name=None, **create_log_dir_kwargs ): """ Set up logger to have some reasonable default settings. Will save log output to based_log_dir/exp_prefix/exp_name. exp_name will be auto-generated to be unique. If log_dir is specified, then that directory is used as the output dir. :param exp_prefix: The sub-directory for this specific experiment. :param variant: :param text_log_file: :param variant_log_file: :param tabular_log_file: :param snapshot_mode: :param log_tabular_only: :param snapshot_gap: :param log_dir: :param git_infos: :param script_name: If set, save the script name to this. :return: """ if git_infos is None: git_infos = get_git_infos(conf.CODE_DIRS_TO_MOUNT) first_time = log_dir is None if first_time: log_dir = create_log_dir(exp_prefix, **create_log_dir_kwargs) if variant is not None: logger.log("Variant:") logger.log(json.dumps(dict_to_safe_json(variant), indent=2)) variant_log_path = osp.join(log_dir, variant_log_file) logger.log_variant(variant_log_path, variant) tabular_log_path = osp.join(log_dir, tabular_log_file) text_log_path = osp.join(log_dir, text_log_file) logger.add_text_output(text_log_path) if first_time: logger.add_tabular_output(tabular_log_path) else: logger._add_output(tabular_log_path, logger._tabular_outputs, logger._tabular_fds, mode='a') for tabular_fd in logger._tabular_fds: logger._tabular_header_written.add(tabular_fd) logger.set_snapshot_dir(log_dir) logger.set_snapshot_mode(snapshot_mode) logger.set_snapshot_gap(snapshot_gap) logger.set_log_tabular_only(log_tabular_only) exp_name = log_dir.split("/")[-1] logger.push_prefix("[%s] " % exp_name) if git_infos is not None: for ( directory, code_diff, code_diff_staged, commit_hash, branch_name ) in git_infos: if directory[-1] == '/': directory = directory[:-1] diff_file_name = directory[1:].replace("/", "-") + ".patch" diff_staged_file_name = ( directory[1:].replace("/", "-") + "_staged.patch" ) if code_diff is not None and len(code_diff) > 0: with open(osp.join(log_dir, diff_file_name), "w") as f: f.write(code_diff + '\n') if code_diff_staged is not None and len(code_diff_staged) > 0: with open(osp.join(log_dir, diff_staged_file_name), "w") as f: f.write(code_diff_staged + '\n') with open(osp.join(log_dir, "git_infos.txt"), "a") as f: f.write("directory: {}\n".format(directory)) f.write("git hash: {}\n".format(commit_hash)) f.write("git branch name: {}\n\n".format(branch_name)) if script_name is not None: with open(osp.join(log_dir, "script_name.txt"), "w") as f: f.write(script_name) return log_dir
def setup_logger( exp_prefix="default", exp_id=0, seed=0, variant=None, base_log_dir=None, text_log_file="debug.log", variant_log_file="variant.json", tabular_log_file="progress.csv", snapshot_mode="last", snapshot_gap=1, log_tabular_only=False, log_dir=None, git_info=None, script_name=None, ): """ Set up logger to have some reasonable default settings. Will save log output to based_log_dir/exp_prefix/exp_name. exp_name will be auto-generated to be unique. If log_dir is specified, then that directory is used as the output dir. :param exp_prefix: The sub-directory for this specific experiment. :param exp_id: The number of the specific experiment run within this experiment. :param variant: :param base_log_dir: The directory where all log should be saved. :param text_log_file: :param variant_log_file: :param tabular_log_file: :param snapshot_mode: :param log_tabular_only: :param snapshot_gap: :param log_dir: :param git_info: :param script_name: If set, save the script name to this. :return: """ first_time = log_dir is None if first_time: log_dir = create_log_dir(exp_prefix, exp_id=exp_id, seed=seed, base_log_dir=base_log_dir) if variant is not None: logger.log("Variant:") logger.log(json.dumps(dict_to_safe_json(variant), indent=2)) variant_log_path = osp.join(log_dir, variant_log_file) logger.log_variant(variant_log_path, variant) tabular_log_path = osp.join(log_dir, tabular_log_file) text_log_path = osp.join(log_dir, text_log_file) logger.add_text_output(text_log_path) if first_time: logger.add_tabular_output(tabular_log_path) else: logger._add_output(tabular_log_path, logger._tabular_outputs, logger._tabular_fds, mode='a') for tabular_fd in logger._tabular_fds: logger._tabular_header_written.add(tabular_fd) logger.set_snapshot_dir(log_dir) logger.set_snapshot_mode(snapshot_mode) logger.set_snapshot_gap(snapshot_gap) logger.set_log_tabular_only(log_tabular_only) exp_name = log_dir.split("/")[-1] logger.push_prefix("[%s] " % exp_name) if git_info is not None: code_diff, commit_hash, branch_name = git_info if code_diff is not None: with open(osp.join(log_dir, "code.diff"), "w") as f: f.write(code_diff) with open(osp.join(log_dir, "git_info.txt"), "w") as f: f.write("git hash: {}".format(commit_hash)) f.write('\n') f.write("git branch name: {}".format(branch_name)) if script_name is not None: with open(osp.join(log_dir, "script_name.txt"), "w") as f: f.write(script_name) return log_dir
def pretrain_q_with_bc_data(self):
    """
    :return:
    """
    logger.remove_tabular_output('progress.csv',
                                 relative_to_snapshot_dir=True)
    logger.add_tabular_output('pretrain_q.csv',
                              relative_to_snapshot_dir=True)

    self.update_policy = False  # first train only the Q function
    for i in range(self.q_num_pretrain1_steps):
        self.eval_statistics = dict()

        train_data = self.replay_buffer.random_batch(self.bc_batch_size)
        train_data = np_to_pytorch_batch(train_data)
        obs = train_data['observations']
        next_obs = train_data['next_observations']
        # goals = train_data['resampled_goals']
        train_data['observations'] = obs  # torch.cat((obs, goals), dim=1)
        train_data['next_observations'] = next_obs  # torch.cat((next_obs, goals), dim=1)
        self.train_from_torch(train_data, pretrain=True)
        if i % self.pretraining_logging_period == 0:
            stats_with_prefix = add_prefix(self.eval_statistics,
                                           prefix="trainer/")
            logger.record_dict(stats_with_prefix)
            logger.dump_tabular(with_prefix=True, with_timestamp=False)

    self.update_policy = True  # then train policy and Q function together
    prev_time = time.time()
    for i in range(self.q_num_pretrain2_steps):
        self.eval_statistics = dict()
        if i % self.pretraining_logging_period == 0:
            self._need_to_update_eval_statistics = True
        train_data = self.replay_buffer.random_batch(self.bc_batch_size)
        train_data = np_to_pytorch_batch(train_data)
        obs = train_data['observations']
        next_obs = train_data['next_observations']
        # goals = train_data['resampled_goals']
        train_data['observations'] = obs  # torch.cat((obs, goals), dim=1)
        train_data['next_observations'] = next_obs  # torch.cat((next_obs, goals), dim=1)
        self.train_from_torch(train_data, pretrain=True)

        if i % self.pretraining_logging_period == 0:
            self.eval_statistics["batch"] = i
            self.eval_statistics["epoch_time"] = time.time() - prev_time
            stats_with_prefix = add_prefix(self.eval_statistics,
                                           prefix="trainer/")
            logger.record_dict(stats_with_prefix)
            logger.dump_tabular(with_prefix=True, with_timestamp=False)
            prev_time = time.time()

    logger.remove_tabular_output(
        'pretrain_q.csv',
        relative_to_snapshot_dir=True,
    )
    logger.add_tabular_output(
        'progress.csv',
        relative_to_snapshot_dir=True,
    )

    self._need_to_update_eval_statistics = True
    self.eval_statistics = dict()

    if self.post_pretrain_hyperparams:
        self.set_algorithm_weights(**self.post_pretrain_hyperparams)
def pretrain_policy_with_bc(
        self,
        policy,
        train_buffer,
        test_buffer,
        steps,
        label="policy",
):
    """Pretrain a policy with behavior cloning.

    Given a policy, look up its optimizer, run the policy on the train
    buffer, compute the BC losses, and back-propagate. After each training
    batch, evaluate on the test buffer and log the statistics.

    :param policy:
    :param train_buffer:
    :param test_buffer:
    :param steps:
    :param label:
    :return:
    """
    logger.remove_tabular_output(
        'progress.csv',
        relative_to_snapshot_dir=True,
    )
    logger.add_tabular_output(
        'pretrain_%s.csv' % label,
        relative_to_snapshot_dir=True,
    )

    optimizer = self.optimizers[policy]
    prev_time = time.time()
    for i in range(steps):
        train_policy_loss, train_logp_loss, train_mse_loss, train_stats = (
            self.run_bc_batch(train_buffer, policy))
        train_policy_loss = train_policy_loss * self.bc_weight

        optimizer.zero_grad()
        train_policy_loss.backward()
        optimizer.step()

        test_policy_loss, test_logp_loss, test_mse_loss, test_stats = (
            self.run_bc_batch(test_buffer, policy))
        test_policy_loss = test_policy_loss * self.bc_weight

        if i % self.pretraining_logging_period == 0:
            stats = {
                "pretrain_bc/batch": i,
                "pretrain_bc/Train Logprob Loss": ptu.get_numpy(
                    train_logp_loss),
                "pretrain_bc/Test Logprob Loss": ptu.get_numpy(
                    test_logp_loss),
                "pretrain_bc/Train MSE": ptu.get_numpy(train_mse_loss),
                "pretrain_bc/Test MSE": ptu.get_numpy(test_mse_loss),
                "pretrain_bc/train_policy_loss": ptu.get_numpy(
                    train_policy_loss),
                "pretrain_bc/test_policy_loss": ptu.get_numpy(
                    test_policy_loss),
                "pretrain_bc/epoch_time": time.time() - prev_time,
            }
            logger.record_dict(stats)
            logger.dump_tabular(with_prefix=True, with_timestamp=False)
            pickle.dump(
                self.policy,
                open(logger.get_snapshot_dir() + '/bc_%s.pkl' % label, "wb"))
            prev_time = time.time()

    logger.remove_tabular_output(
        'pretrain_%s.csv' % label,
        relative_to_snapshot_dir=True,
    )
    logger.add_tabular_output(
        'progress.csv',
        relative_to_snapshot_dir=True,
    )

    if self.post_bc_pretrain_hyperparams:
        self.set_algorithm_weights(**self.post_bc_pretrain_hyperparams)
def train_vae(
        vae: VAE,
        dset: NpGroundTruthData,
        num_epochs: int = 0,  # do not pre-train by default
        save_period: int = 5,
        test_p: float = 0.1,  # data proportion to use for test
        vae_trainer_kwargs: Dict = {},
):
    logger.remove_tabular_output('progress.csv',
                                 relative_to_snapshot_dir=True)
    logger.add_tabular_output('vae_progress.csv',
                              relative_to_snapshot_dir=True)

    # Flatten images.
    n = dset.data.shape[0]
    data = dset.data.transpose(0, 3, 2, 1)
    data = data.reshape((n, -1))
    # Un-normalize images.
    if data.dtype != np.uint8:
        assert np.min(data) >= 0.0
        assert np.max(data) <= 1.0
        data = (data * 255).astype(np.uint8)
    # Make sure factors are normalized.
    assert np.min(dset.factors) >= 0.0
    assert np.max(dset.factors) <= 1.0
    # Split into train and test set.
    test_size = int(n * test_p)
    test_data = data[:test_size]
    train_data = data[test_size:]
    train_factors = dset.factors[test_size:]
    test_factors = dset.factors[:test_size]

    logger.get_snapshot_dir()
    trainer = VAETrainer(vae,
                         train_data,
                         test_data,
                         train_factors=train_factors,
                         test_factors=test_factors,
                         **vae_trainer_kwargs)

    for epoch in range(num_epochs):
        should_save_imgs = (epoch % save_period == 0)
        trainer.train_epoch(epoch)
        trainer.test_epoch(epoch, save_reconstruction=should_save_imgs)
        if should_save_imgs:
            trainer.dump_samples(epoch)
        trainer.update_train_weights()
    logger.save_extra_data(vae, 'vae.pkl', mode='pickle')

    logger.remove_tabular_output(
        'vae_progress.csv',
        relative_to_snapshot_dir=True,
    )
    logger.add_tabular_output(
        'progress.csv',
        relative_to_snapshot_dir=True,
    )
    return vae, trainer
def smac_experiment(
        trainer_kwargs=None,
        algo_kwargs=None,
        qf_kwargs=None,
        policy_kwargs=None,
        context_encoder_kwargs=None,
        context_decoder_kwargs=None,
        env_name=None,
        env_params=None,
        path_loader_kwargs=None,
        latent_dim=None,
        policy_class="TanhGaussianPolicy",
        # video/debug
        debug=False,
        use_dummy_encoder=False,
        networks_ignore_context=False,
        use_ground_truth_context=False,
        save_video=False,
        save_video_period=False,
        # Pre-train params
        pretrain_rl=False,
        pretrain_offline_algo_kwargs=None,
        pretrain_buffer_kwargs=None,
        load_buffer_kwargs=None,
        saved_tasks_path=None,
        macaw_format_base_path=None,  # overrides saved_tasks_path and load_buffer_kwargs
        load_macaw_buffer_kwargs=None,
        train_task_idxs=None,
        eval_task_idxs=None,
        relabel_offline_dataset=False,
        skip_initial_data_collection_if_pretrained=False,
        relabel_kwargs=None,
        # PEARL
        n_train_tasks=0,
        n_eval_tasks=0,
        use_next_obs_in_context=False,
        tags=None,
        online_trainer_kwargs=None,
):
    if not skip_initial_data_collection_if_pretrained:
        raise NotImplementedError("deprecated! make sure to skip it!")
    if relabel_kwargs is None:
        relabel_kwargs = {}
    del tags
    pretrain_buffer_kwargs = pretrain_buffer_kwargs or {}
    context_decoder_kwargs = context_decoder_kwargs or {}
    pretrain_offline_algo_kwargs = pretrain_offline_algo_kwargs or {}
    online_trainer_kwargs = online_trainer_kwargs or {}
    register_pearl_envs()
    env_params = env_params or {}
    context_encoder_kwargs = context_encoder_kwargs or {}
    trainer_kwargs = trainer_kwargs or {}
    path_loader_kwargs = path_loader_kwargs or {}
    load_macaw_buffer_kwargs = load_macaw_buffer_kwargs or {}

    base_env = ENVS[env_name](**env_params)
    if saved_tasks_path:
        task_data = load_local_or_remote_file(saved_tasks_path,
                                              file_type='joblib')
        tasks = task_data['tasks']
        train_task_idxs = task_data['train_task_indices']
        eval_task_idxs = task_data['eval_task_indices']
        base_env.tasks = tasks
    elif macaw_format_base_path is not None:
        tasks = pickle.load(
            open('{}/tasks.pkl'.format(macaw_format_base_path), 'rb'))
        base_env.tasks = tasks
    else:
        tasks = base_env.tasks
        task_indices = base_env.get_all_task_idx()
        train_task_idxs = list(task_indices[:n_train_tasks])
        eval_task_idxs = list(task_indices[-n_eval_tasks:])
    if hasattr(base_env, 'task_to_vec'):
        train_tasks = [base_env.task_to_vec(tasks[i])
                       for i in train_task_idxs]
        eval_tasks = [base_env.task_to_vec(tasks[i]) for i in eval_task_idxs]
    else:
        train_tasks = [tasks[i] for i in train_task_idxs]
        eval_tasks = [tasks[i] for i in eval_task_idxs]
    if use_ground_truth_context:
        latent_dim = len(train_tasks[0])
    expl_env = NormalizedBoxEnv(base_env)

    reward_dim = 1

    if debug:
        algo_kwargs['max_path_length'] = 50
        algo_kwargs['batch_size'] = 5
        algo_kwargs['num_epochs'] = 5
        algo_kwargs['num_eval_steps_per_epoch'] = 100
        algo_kwargs['num_expl_steps_per_train_loop'] = 100
        algo_kwargs['num_trains_per_train_loop'] = 10
        algo_kwargs['min_num_steps_before_training'] = 100

    obs_dim = expl_env.observation_space.low.size
    action_dim = expl_env.action_space.low.size

    if use_next_obs_in_context:
        context_encoder_input_dim = 2 * obs_dim + action_dim + reward_dim
    else:
        context_encoder_input_dim = obs_dim + action_dim + reward_dim
    context_encoder_output_dim = latent_dim * 2

    def create_qf():
        return ConcatMlp(input_size=obs_dim + action_dim + latent_dim,
                         output_size=1,
                         **qf_kwargs)

    qf1 = create_qf()
    qf2 = create_qf()
    target_qf1 = create_qf()
    target_qf2 = create_qf()

    if isinstance(policy_class, str):
        policy_class = policy_class_from_str(policy_class)
    policy = policy_class(
        obs_dim=obs_dim + latent_dim,
        action_dim=action_dim,
        **policy_kwargs,
    )
    encoder_class = DummyMlpEncoder if use_dummy_encoder else MlpEncoder
    context_encoder = encoder_class(
        input_size=context_encoder_input_dim,
        output_size=context_encoder_output_dim,
        hidden_sizes=[200, 200, 200],
        use_ground_truth_context=use_ground_truth_context,
        **context_encoder_kwargs)
    context_decoder = MlpDecoder(
        input_size=obs_dim + action_dim + latent_dim,
        output_size=1,
        **context_decoder_kwargs)
    reward_predictor = context_decoder
    agent = SmacAgent(
        latent_dim,
        context_encoder,
        policy,
        reward_predictor,
        use_next_obs_in_context=use_next_obs_in_context,
        _debug_ignore_context=networks_ignore_context,
        _debug_use_ground_truth_context=use_ground_truth_context,
    )
    trainer = SmacTrainer(
        agent=agent,
        env=expl_env,
        latent_dim=latent_dim,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        reward_predictor=reward_predictor,
        context_encoder=context_encoder,
        context_decoder=context_decoder,
        _debug_ignore_context=networks_ignore_context,
        _debug_use_ground_truth_context=use_ground_truth_context,
        **trainer_kwargs)
    algorithm = MetaRLAlgorithm(
        agent=agent,
        env=expl_env,
        trainer=trainer,
        train_task_indices=train_task_idxs,
        eval_task_indices=eval_task_idxs,
        train_tasks=train_tasks,
        eval_tasks=eval_tasks,
        use_next_obs_in_context=use_next_obs_in_context,
        use_ground_truth_context=use_ground_truth_context,
        env_info_sizes=get_env_info_sizes(expl_env),
        **algo_kwargs)

    if macaw_format_base_path:
        load_macaw_buffer_onto_algo(algo=algorithm,
                                    base_directory=macaw_format_base_path,
                                    train_task_idxs=train_task_idxs,
                                    **load_macaw_buffer_kwargs)
    elif load_buffer_kwargs:
        load_buffer_onto_algo(algorithm, **load_buffer_kwargs)
    if relabel_offline_dataset:
        relabel_offline_data(algorithm,
                             tasks=tasks,
                             env=expl_env.wrapped_env,
                             **relabel_kwargs)
    if path_loader_kwargs:
        replay_buffer = algorithm.replay_buffer.task_buffers[0]
        enc_replay_buffer = algorithm.enc_replay_buffer.task_buffers[0]
        demo_test_buffer = EnvReplayBuffer(env=expl_env,
                                           **pretrain_buffer_kwargs)
        path_loader = MDPPathLoader(trainer,
                                    replay_buffer=replay_buffer,
                                    demo_train_buffer=enc_replay_buffer,
                                    demo_test_buffer=demo_test_buffer,
                                    **path_loader_kwargs)
        path_loader.load_demos()

    if pretrain_rl:
        eval_pearl_fn = EvalPearl(algorithm, train_task_idxs, eval_task_idxs)
        pretrain_algo = OfflineMetaRLAlgorithm(
            meta_replay_buffer=algorithm.meta_replay_buffer,
            replay_buffer=algorithm.replay_buffer,
            task_embedding_replay_buffer=algorithm.enc_replay_buffer,
            trainer=trainer,
            train_tasks=train_task_idxs,
            extra_eval_fns=[eval_pearl_fn],
            use_meta_learning_buffer=algorithm.use_meta_learning_buffer,
            **pretrain_offline_algo_kwargs)
        pretrain_algo.to(ptu.device)
        logger.remove_tabular_output('progress.csv',
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output('pretrain.csv',
                                  relative_to_snapshot_dir=True)
        pretrain_algo.train()
        logger.remove_tabular_output('pretrain.csv',
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )
        if skip_initial_data_collection_if_pretrained:
            algorithm.num_initial_steps = 0

    algorithm.trainer.configure(**online_trainer_kwargs)
    algorithm.to(ptu.device)
    algorithm.train()
def pretrain_policy_with_bc(self):
    if self.buffer_for_bc_training == "demos":
        self.bc_training_buffer = self.demo_train_buffer
        self.bc_test_buffer = self.demo_test_buffer
    elif self.buffer_for_bc_training == "replay_buffer":
        self.bc_training_buffer = self.replay_buffer.train_replay_buffer
        self.bc_test_buffer = self.replay_buffer.validation_replay_buffer
    else:
        self.bc_training_buffer = None
        self.bc_test_buffer = None

    if self.load_policy_path:
        self.policy = load_local_or_remote_file(self.load_policy_path)
        ptu.copy_model_params_from_to(self.policy, self.target_policy)
        return

    logger.remove_tabular_output('progress.csv',
                                 relative_to_snapshot_dir=True)
    logger.add_tabular_output('pretrain_policy.csv',
                              relative_to_snapshot_dir=True)
    if self.do_pretrain_rollouts:
        total_ret = self.do_rollouts()
        print("INITIAL RETURN", total_ret / 20)

    prev_time = time.time()
    for i in range(self.bc_num_pretrain_steps):
        train_policy_loss, train_logp_loss, train_mse_loss, train_log_std = (
            self.run_bc_batch(self.demo_train_buffer, self.policy))
        train_policy_loss = train_policy_loss * self.bc_weight

        self.policy_optimizer.zero_grad()
        train_policy_loss.backward()
        self.policy_optimizer.step()

        test_policy_loss, test_logp_loss, test_mse_loss, test_log_std = (
            self.run_bc_batch(self.demo_test_buffer, self.policy))
        test_policy_loss = test_policy_loss * self.bc_weight

        if self.do_pretrain_rollouts and i % self.pretraining_env_logging_period == 0:
            total_ret = self.do_rollouts()
            print("Return at step {} : {}".format(i, total_ret / 20))

        if i % self.pretraining_logging_period == 0:
            stats = {
                "pretrain_bc/batch": i,
                "pretrain_bc/Train Logprob Loss": ptu.get_numpy(
                    train_logp_loss),
                "pretrain_bc/Test Logprob Loss": ptu.get_numpy(
                    test_logp_loss),
                "pretrain_bc/Train MSE": ptu.get_numpy(train_mse_loss),
                "pretrain_bc/Test MSE": ptu.get_numpy(test_mse_loss),
                "pretrain_bc/train_policy_loss": ptu.get_numpy(
                    train_policy_loss),
                "pretrain_bc/test_policy_loss": ptu.get_numpy(
                    test_policy_loss),
                "pretrain_bc/epoch_time": time.time() - prev_time,
            }
            if self.do_pretrain_rollouts:
                stats["pretrain_bc/avg_return"] = total_ret / 20
            logger.record_dict(stats)
            logger.dump_tabular(with_prefix=True, with_timestamp=False)
            pickle.dump(self.policy,
                        open(logger.get_snapshot_dir() + '/bc.pkl', "wb"))
            prev_time = time.time()

    logger.remove_tabular_output(
        'pretrain_policy.csv',
        relative_to_snapshot_dir=True,
    )
    logger.add_tabular_output(
        'progress.csv',
        relative_to_snapshot_dir=True,
    )

    ptu.copy_model_params_from_to(self.policy, self.target_policy)

    if self.post_bc_pretrain_hyperparams:
        self.set_algorithm_weights(**self.post_bc_pretrain_hyperparams)
def train_vae_and_update_variant(variant):
    # actually pretrain the VAE and ROLL
    skewfit_variant = variant['skewfit_variant']
    train_vae_variant = variant['train_vae_variant']

    # prepare the background subtractor needed to perform segmentation
    if 'unet' in skewfit_variant['segmentation_method']:
        print("training opencv background model!")
        v = train_vae_variant['generate_lstm_dataset_kwargs']
        env_id = v.get('env_id', None)
        env_id_invis = invisiable_env_id[env_id]

        import gym
        import multiworld
        multiworld.register_all_envs()
        obj_invisible_env = gym.make(env_id_invis)
        init_camera = v.get('init_camera', None)

        presampled_goals = None
        if skewfit_variant.get("presampled_goals_path") is not None:
            presampled_goals = load_local_or_remote_file(
                skewfit_variant['presampled_goals_path']).item()
            print("presampled goal path is: ",
                  skewfit_variant['presampled_goals_path'])

        obj_invisible_env = ImageEnv(
            obj_invisible_env,
            v.get('imsize'),
            init_camera=init_camera,
            transpose=True,
            normalize=True,
            presampled_goals=presampled_goals,
        )

        train_num = 2000 if 'Push' in env_id else 4000
        train_bgsb(obj_invisible_env, train_num=train_num)

    if skewfit_variant.get('vae_path', None) is None:  # train new VAEs
        logger.remove_tabular_output('progress.csv',
                                     relative_to_snapshot_dir=True)
        # one original VAE, one segmented ROLL
        vaes, vae_train_datas, vae_test_datas = train_vae(
            train_vae_variant,
            skewfit_variant=skewfit_variant,
            return_data=True)
        if skewfit_variant.get('save_vae_data', False):
            skewfit_variant['vae_train_data'] = vae_train_datas
            skewfit_variant['vae_test_data'] = vae_test_datas
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )
        skewfit_variant['vae_path'] = vaes  # just pass the VAE directly
    else:  # load pre-trained VAEs
        print("load pretrained scene-/object-VAE from: {}".format(
            skewfit_variant['vae_path']))
        data = torch.load(osp.join(skewfit_variant['vae_path'], 'params.pkl'))
        vae_original = data['vae_original']
        vae_segmented = data['lstm_segmented']
        skewfit_variant['vae_path'] = [vae_segmented, vae_original]

        generate_vae_dataset_fctn = train_vae_variant.get(
            'generate_vae_data_fctn', generate_vae_dataset)
        generate_lstm_dataset_fctn = train_vae_variant.get(
            'generate_lstm_data_fctn')
        assert generate_lstm_dataset_fctn is not None, \
            "Must provide a custom generate lstm pretraining dataset function!"

        train_data_lstm, test_data_lstm, info_lstm = generate_lstm_dataset_fctn(
            train_vae_variant['generate_lstm_dataset_kwargs'],
            segmented=True,
            segmentation_method=skewfit_variant['segmentation_method'])
        train_data_ori, test_data_ori, info_ori = generate_vae_dataset_fctn(
            train_vae_variant['generate_vae_dataset_kwargs'])
        train_datas = [train_data_lstm, train_data_ori]
        test_datas = [test_data_lstm, test_data_ori]

        if skewfit_variant.get('save_vae_data', False):
            skewfit_variant['vae_train_data'] = train_datas
            skewfit_variant['vae_test_data'] = test_datas
def train_vae(variant, return_data=False, skewfit_variant=None):
    # actually train both the VAE and the LSTM
    from rlkit.util.ml_util import PiecewiseLinearSchedule
    from rlkit.torch.vae.conv_vae import (
        ConvVAE,
    )
    import rlkit.torch.vae.conv_vae as conv_vae
    import ROLL.LSTM_model as LSTM_model
    from ROLL.LSTM_model import ConvLSTM2
    from ROLL.LSTM_trainer import ConvLSTMTrainer
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    import rlkit.torch.pytorch_util as ptu
    from rlkit.pythonplusplus import identity
    import torch

    seg_pretrain = variant['seg_pretrain']
    ori_pretrain = variant['ori_pretrain']

    generate_vae_dataset_fctn = variant.get('generate_vae_data_fctn',
                                            generate_vae_dataset)
    generate_lstm_dataset_fctn = variant.get('generate_lstm_data_fctn')
    assert generate_lstm_dataset_fctn is not None, \
        "Must provide a custom generate lstm pretraining dataset function!"

    train_data_lstm, test_data_lstm, info_lstm = generate_lstm_dataset_fctn(
        variant['generate_lstm_dataset_kwargs'],
        segmented=True,
        segmentation_method=skewfit_variant['segmentation_method'])
    train_data_ori, test_data_ori, info_ori = generate_vae_dataset_fctn(
        variant['generate_vae_dataset_kwargs'])

    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None

    if variant.get('decoder_activation', None) == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity

    architecture = variant['vae_kwargs'].get('architecture', None)
    if not architecture and variant.get('imsize') == 84:
        architecture = conv_vae.imsize84_default_architecture
    elif not architecture and variant.get('imsize') == 48:
        architecture = conv_vae.imsize48_default_architecture
    variant['vae_kwargs']['architecture'] = architecture
    variant['vae_kwargs']['imsize'] = variant.get('imsize')

    architecture = variant['lstm_kwargs'].get('architecture', None)
    if not architecture and variant.get('imsize') == 84:
        architecture = None  # TODO LSTM: wrap an imsize-84 lstm architecture
    elif not architecture and variant.get('imsize') == 48:
        architecture = LSTM_model.imsize48_default_architecture
    variant['lstm_kwargs']['architecture'] = architecture
    variant['lstm_kwargs']['imsize'] = variant.get('imsize')

    train_datas = [
        train_data_lstm,
        train_data_ori,
    ]
    test_datas = [
        test_data_lstm,
        test_data_ori,
    ]
    names = [
        'lstm_seg_pretrain',
        'vae_ori_pretrain',
    ]
    vaes = []

    env_id = variant['generate_lstm_dataset_kwargs'].get('env_id')
    assert env_id is not None
    lstm_pretrain_vae_only = variant.get('lstm_pretrain_vae_only', False)

    for idx in range(2):
        train_data, test_data, name = (
            train_datas[idx], test_datas[idx], names[idx])
        logger.add_tabular_output('{}_progress.csv'.format(name),
                                  relative_to_snapshot_dir=True)

        if idx == 1:  # train the original VAE
            representation_size = variant.get(
                "vae_representation_size",
                variant.get('representation_size'))
            beta = variant.get('vae_beta', variant.get('beta'))
            m = ConvVAE(representation_size,
                        decoder_output_activation=decoder_activation,
                        **variant['vae_kwargs'])
            t = ConvVAETrainer(train_data,
                               test_data,
                               m,
                               beta=beta,
                               beta_schedule=beta_schedule,
                               **variant['algo_kwargs'])
        else:  # train the segmentation LSTM
            lstm_version = variant.get('lstm_version', 2)
            if lstm_version == 2:
                lstm_class = ConvLSTM2
            representation_size = variant.get(
                "lstm_representation_size",
                variant.get('representation_size'))
            beta = variant.get('lstm_beta', variant.get('beta'))
            m = lstm_class(representation_size,
                           decoder_output_activation=decoder_activation,
                           **variant['lstm_kwargs'])
            t = ConvLSTMTrainer(train_data,
                                test_data,
                                m,
                                beta=beta,
                                beta_schedule=beta_schedule,
                                **variant['algo_kwargs'])

        m.to(ptu.device)
        vaes.append(m)

        print("test data len: ", len(test_data))
        print("train data len: ", len(train_data))

        save_period = variant['save_period']

        pjhome = os.environ['PJHOME']
        if env_id == 'SawyerPushHurdle-v0' and osp.exists(
                osp.join(
                    pjhome, 'data/local/pre-train-lstm',
                    '{}-{}-{}-0.3-0.5.npy'.format('SawyerPushHurdle-v0',
                                                  'seg-color', '500'))):
            data_file_path = osp.join(
                pjhome, 'data/local/pre-train-lstm',
                '{}-{}-{}-0.3-0.5.npy'.format(env_id, 'seg-color', 500))
            puck_pos_path = osp.join(
                pjhome, 'data/local/pre-train-lstm',
                '{}-{}-{}-0.3-0.5-puck-pos.npy'.format(env_id, 'seg-color',
                                                       500))
            all_data = np.load(data_file_path)
            puck_pos = np.load(puck_pos_path)
            all_data = normalize_image(all_data.copy())
            obj_states = puck_pos
        else:
            all_data = np.concatenate([train_data_lstm, test_data_lstm],
                                      axis=0)
            all_data = normalize_image(all_data.copy())
            obj_states = info_lstm['obj_state']

        obj = 'door' if 'Door' in env_id else 'puck'

        num_epochs = (variant['num_lstm_epochs'] if idx == 0
                      else variant['num_vae_epochs'])

        if (idx == 0 and seg_pretrain) or (idx == 1 and ori_pretrain):
            for epoch in range(num_epochs):
                should_save_imgs = (epoch % save_period == 0)
                if idx == 0:
                    # only the LSTM trainer has an 'only_train_vae' argument
                    t.train_epoch(epoch,
                                  only_train_vae=lstm_pretrain_vae_only)
                    t.test_epoch(epoch,
                                 save_reconstruction=should_save_imgs,
                                 save_prefix='r_' + name,
                                 only_train_vae=lstm_pretrain_vae_only)
                else:
                    t.train_epoch(epoch)
                    t.test_epoch(
                        epoch,
                        save_reconstruction=should_save_imgs,
                        save_prefix='r_' + name,
                    )

                if should_save_imgs:
                    t.dump_samples(epoch, save_prefix='s_' + name)
                    if idx == 0:
                        compare_latent_distance(
                            m,
                            all_data,
                            obj_states,
                            obj_name=obj,
                            save_dir=logger.get_snapshot_dir(),
                            save_name='lstm_latent_distance_{}.png'.format(
                                epoch))
                        test_lstm_traj(
                            env_id,
                            m,
                            save_path=logger.get_snapshot_dir(),
                            save_name='lstm_test_traj_{}.png'.format(epoch))
                        test_masked_traj_lstm(
                            env_id,
                            m,
                            save_dir=logger.get_snapshot_dir(),
                            save_name='masked_test_{}.png'.format(epoch))
                t.update_train_weights()

            logger.save_extra_data(m, '{}.pkl'.format(name), mode='pickle')

        logger.remove_tabular_output('{}_progress.csv'.format(name),
                                     relative_to_snapshot_dir=True)

        if idx == 0 and variant.get("only_train_lstm", False):
            exit()

    if return_data:
        return vaes, train_datas, test_datas
    return m
def get_n_train_vae(latent_dim,
                    env,
                    vae_train_epochs,
                    num_image_examples,
                    vae_kwargs,
                    vae_trainer_kwargs,
                    vae_architecture,
                    vae_save_period=10,
                    vae_test_p=.9,
                    decoder_activation='sigmoid',
                    vae_class='VAE',
                    **kwargs):
    env.goal_sampling_mode = 'test'
    image_examples = unnormalize_image(
        env.sample_goals(num_image_examples)['desired_goal'])
    n = int(num_image_examples * vae_test_p)
    train_dataset = ImageObservationDataset(image_examples[:n, :])
    test_dataset = ImageObservationDataset(image_examples[n:, :])

    if decoder_activation == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()

    vae_class = vae_class.lower()
    if vae_class == 'VAE'.lower():
        vae_class = ConvVAE
    elif vae_class == 'SpatialVAE'.lower():
        vae_class = SpatialAutoEncoder
    else:
        raise RuntimeError("Invalid VAE Class: {}".format(vae_class))

    vae = vae_class(latent_dim,
                    architecture=vae_architecture,
                    decoder_output_activation=decoder_activation,
                    **vae_kwargs)

    trainer = ConvVAETrainer(vae, **vae_trainer_kwargs)
    logger.remove_tabular_output('progress.csv',
                                 relative_to_snapshot_dir=True)
    logger.add_tabular_output('vae_progress.csv',
                              relative_to_snapshot_dir=True)
    for epoch in range(vae_train_epochs):
        should_save_imgs = (epoch % vae_save_period == 0)
        trainer.train_epoch(epoch, train_dataset)
        trainer.test_epoch(epoch, test_dataset)
        if should_save_imgs:
            trainer.dump_reconstructions(epoch)
            trainer.dump_samples(epoch)
        stats = trainer.get_diagnostics()
        for k, v in stats.items():
            logger.record_tabular(k, v)
        logger.dump_tabular()
        trainer.end_epoch(epoch)
        if epoch % 50 == 0:
            logger.save_itr_params(epoch, vae)
    logger.save_extra_data(vae, 'vae.pkl', mode='pickle')
    logger.remove_tabular_output('vae_progress.csv',
                                 relative_to_snapshot_dir=True)
    logger.add_tabular_output('progress.csv',
                              relative_to_snapshot_dir=True)
    return vae
def run_experiment(argv):
    default_log_dir = config.LOCAL_LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')
    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--n_parallel',
        type=int,
        default=1,
        help='Number of parallel workers to perform rollouts. '
             '0 => don\'t start any workers')
    parser.add_argument('--exp_name',
                        type=str,
                        default=default_exp_name,
                        help='Name of the experiment.')
    parser.add_argument('--log_dir',
                        type=str,
                        default=None,
                        help='Path to save the log and iteration snapshot.')
    parser.add_argument('--snapshot_mode',
                        type=str,
                        default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                             '(all iterations will be saved), "last" (only '
                             'the last iteration will be saved), "gap" '
                             '(every `snapshot_gap` iterations are saved), '
                             'or "none" (do not save snapshots)')
    parser.add_argument('--snapshot_gap',
                        type=int,
                        default=1,
                        help='Gap between snapshot iterations.')
    parser.add_argument('--tabular_log_file',
                        type=str,
                        default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file',
                        type=str,
                        default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file',
                        type=str,
                        default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--variant_log_file',
                        type=str,
                        default='variant.json',
                        help='Name of the variant log file (in json).')
    parser.add_argument(
        '--resume_from',
        type=str,
        default=None,
        help='Name of the pickle file to resume experiment from.')
    parser.add_argument('--plot',
                        type=ast.literal_eval,
                        default=False,
                        help='Whether to plot the iteration results')
    parser.add_argument(
        '--log_tabular_only',
        type=ast.literal_eval,
        default=False,
        help='Whether to only print the tabular log information '
             '(in a horizontal format)')
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument('--args_data',
                        type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--variant_data',
                        type=str,
                        help='Pickled data for variant configuration')
    parser.add_argument('--use_cloudpickle',
                        type=ast.literal_eval,
                        default=False)
    parser.add_argument('--code_diff',
                        type=str,
                        help='A string of the code diff to save.')
    parser.add_argument('--commit_hash',
                        type=str,
                        help='A string of the commit hash')
    parser.add_argument('--script_name',
                        type=str,
                        help='Name of the launched script')
    args = parser.parse_args(argv[1:])

    if args.seed is not None:
        set_seed(args.seed)

    if args.n_parallel > 0:
        from rllab.sampler import parallel_sampler
        parallel_sampler.initialize(n_parallel=args.n_parallel)
        if args.seed is not None:
            parallel_sampler.set_seed(args.seed)

    if args.plot:
        from rllab.plotter import plotter
        plotter.init_worker()

    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    if args.variant_data is not None:
        variant_data = pickle.loads(base64.b64decode(args.variant_data))
        variant_log_file = osp.join(log_dir, args.variant_log_file)
        logger.log_variant(variant_log_file, variant_data)
    else:
        variant_data = None

    if not args.use_cloudpickle:
        raise NotImplementedError("Not supporting non-cloud-pickle")

    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_snapshot_gap(args.snapshot_gap)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    # Save information for code reproducibility.
    if args.code_diff is not None:
        code_diff_str = cloudpickle.loads(base64.b64decode(args.code_diff))
        with open(osp.join(log_dir, "code.diff"), "w") as f:
            f.write(code_diff_str)
    if args.commit_hash is not None:
        with open(osp.join(log_dir, "commit_hash.txt"), "w") as f:
            f.write(args.commit_hash)
    if args.script_name is not None:
        with open(osp.join(log_dir, "script_name.txt"), "w") as f:
            f.write(args.script_name)

    if args.resume_from is not None:
        data = joblib.load(args.resume_from)
        assert 'algo' in data
        algo = data['algo']
        algo.train()
    else:
        # read from stdin
        if args.use_cloudpickle:
            method_call = cloudpickle.loads(base64.b64decode(args.args_data))
            method_call(variant_data)
        else:
            data = pickle.loads(base64.b64decode(args.args_data))
            maybe_iter = concretize(data)
            if is_iterable(maybe_iter):
                for _ in maybe_iter:
                    pass

    logger.set_snapshot_mode(prev_mode)
    logger.set_snapshot_dir(prev_snapshot_dir)
    logger.remove_tabular_output(tabular_log_file)
    logger.remove_text_output(text_log_file)
    logger.pop_prefix()