def train_vae_and_update_variant(variant):
    from railrl.core import logger
    grill_variant = variant['grill_variant']
    train_vae_variant = variant['train_vae_variant']
    if grill_variant.get('vae_path', None) is None:
        logger.remove_tabular_output('progress.csv', relative_to_snapshot_dir=True)
        logger.add_tabular_output('vae_progress.csv', relative_to_snapshot_dir=True)
        vae, vae_train_data, vae_test_data = train_vae(
            train_vae_variant, return_data=True)
        if grill_variant.get('save_vae_data', False):
            grill_variant['vae_train_data'] = vae_train_data
            grill_variant['vae_test_data'] = vae_test_data
        logger.save_extra_data(vae, 'vae.pkl', mode='pickle')
        logger.remove_tabular_output(
            'vae_progress.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )
        grill_variant['vae_path'] = vae  # just pass the VAE directly
    else:
        if grill_variant.get('save_vae_data', False):
            vae_train_data, vae_test_data, info = generate_vae_dataset(
                train_vae_variant['generate_vae_dataset_kwargs'])
            grill_variant['vae_train_data'] = vae_train_data
            grill_variant['vae_test_data'] = vae_test_data
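# A minimal sketch (not part of the original codebase) of a context manager
# that factors out the remove/add tabular-output dance repeated throughout this
# file: swap logging to a side CSV, then restore 'progress.csv' on exit. It only
# uses the logger calls already seen above; the helper name `tabular_output_to`
# is hypothetical.
from contextlib import contextmanager


@contextmanager
def tabular_output_to(logger, filename):
    """Temporarily redirect tabular logging from progress.csv to `filename`."""
    logger.remove_tabular_output('progress.csv', relative_to_snapshot_dir=True)
    logger.add_tabular_output(filename, relative_to_snapshot_dir=True)
    try:
        yield
    finally:
        logger.remove_tabular_output(filename, relative_to_snapshot_dir=True)
        logger.add_tabular_output('progress.csv', relative_to_snapshot_dir=True)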
def pretrain_policy_with_bc(self):
    logger.remove_tabular_output(
        'progress.csv', relative_to_snapshot_dir=True,
    )
    logger.add_tabular_output(
        'pretrain_policy.csv', relative_to_snapshot_dir=True,
    )
    if self.do_pretrain_rollouts:
        total_ret = self.do_rollouts()
        print("INITIAL RETURN", total_ret / 20)

    prev_time = time.time()
    for i in range(self.bc_num_pretrain_steps):
        train_policy_loss, train_logp_loss, train_mse_loss, train_log_std = \
            self.run_bc_batch(self.demo_train_buffer, self.policy)
        train_policy_loss = train_policy_loss * self.bc_weight

        self.policy_optimizer.zero_grad()
        train_policy_loss.backward()
        self.policy_optimizer.step()

        test_policy_loss, test_logp_loss, test_mse_loss, test_log_std = \
            self.run_bc_batch(self.demo_test_buffer, self.policy)
        test_policy_loss = test_policy_loss * self.bc_weight

        if self.do_pretrain_rollouts and i % self.pretraining_env_logging_period == 0:
            total_ret = self.do_rollouts()
            print("Return at step {} : {}".format(i, total_ret / 20))

        if i % self.pretraining_logging_period == 0:
            stats = {
                "pretrain_bc/batch": i,
                "pretrain_bc/Train Logprob Loss": ptu.get_numpy(train_logp_loss),
                "pretrain_bc/Test Logprob Loss": ptu.get_numpy(test_logp_loss),
                "pretrain_bc/Train MSE": ptu.get_numpy(train_mse_loss),
                "pretrain_bc/Test MSE": ptu.get_numpy(test_mse_loss),
                "pretrain_bc/train_policy_loss": ptu.get_numpy(train_policy_loss),
                "pretrain_bc/test_policy_loss": ptu.get_numpy(test_policy_loss),
                "pretrain_bc/epoch_time": time.time() - prev_time,
            }
            if self.do_pretrain_rollouts:
                stats["pretrain_bc/avg_return"] = total_ret / 20
            logger.record_dict(stats)
            logger.dump_tabular(with_prefix=True, with_timestamp=False)
            with open(logger.get_snapshot_dir() + '/bc.pkl', "wb") as f:
                pickle.dump(self.policy, f)
            prev_time = time.time()

    logger.remove_tabular_output(
        'pretrain_policy.csv',
        relative_to_snapshot_dir=True,
    )
    logger.add_tabular_output(
        'progress.csv',
        relative_to_snapshot_dir=True,
    )
    if self.post_bc_pretrain_hyperparams:
        self.set_algorithm_weights(**self.post_bc_pretrain_hyperparams)
def pretrain_q_with_bc_data(self):
    logger.remove_tabular_output('progress.csv', relative_to_snapshot_dir=True)
    logger.add_tabular_output('pretrain_q.csv', relative_to_snapshot_dir=True)

    self.update_policy = False
    # first train only the Q function
    for i in range(self.q_num_pretrain_steps):
        self.eval_statistics = dict()
        self._need_to_update_eval_statistics = True

        train_data = self.replay_buffer.random_batch(128)
        train_data = np_to_pytorch_batch(train_data)
        obs = train_data['observations']
        next_obs = train_data['next_observations']
        if self.goal_conditioned:
            goals = train_data['resampled_goals']
            train_data['observations'] = torch.cat((obs, goals), dim=1)
            train_data['next_observations'] = torch.cat((next_obs, goals), dim=1)
        self.train_from_torch(train_data)

        logger.record_dict(self.eval_statistics)
        logger.dump_tabular(with_prefix=True, with_timestamp=False)

    self.update_policy = True
    # then train policy and Q function together
    for i in range(self.q_num_pretrain_steps):
        self.eval_statistics = dict()
        self._need_to_update_eval_statistics = True

        train_data = self.replay_buffer.random_batch(128)
        train_data = np_to_pytorch_batch(train_data)
        obs = train_data['observations']
        next_obs = train_data['next_observations']
        if self.goal_conditioned:
            goals = train_data['resampled_goals']
            train_data['observations'] = torch.cat((obs, goals), dim=1)
            train_data['next_observations'] = torch.cat((next_obs, goals), dim=1)
        self.train_from_torch(train_data)

        logger.record_dict(self.eval_statistics)
        logger.dump_tabular(with_prefix=True, with_timestamp=False)

    logger.remove_tabular_output(
        'pretrain_q.csv',
        relative_to_snapshot_dir=True,
    )
    logger.add_tabular_output(
        'progress.csv',
        relative_to_snapshot_dir=True,
    )
def grill_her_full_experiment(variant, mode='td3'):
    train_vae_variant = variant['train_vae_variant']
    grill_variant = variant['grill_variant']
    env_class = variant['env_class']
    env_kwargs = variant['env_kwargs']
    init_camera = variant['init_camera']
    train_vae_variant['generate_vae_dataset_kwargs']['env_class'] = env_class
    train_vae_variant['generate_vae_dataset_kwargs']['env_kwargs'] = env_kwargs
    train_vae_variant['generate_vae_dataset_kwargs']['init_camera'] = init_camera
    grill_variant['env_class'] = env_class
    grill_variant['env_kwargs'] = env_kwargs
    grill_variant['init_camera'] = init_camera
    if 'vae_paths' not in grill_variant:
        logger.remove_tabular_output(
            'progress.csv', relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'vae_progress.csv', relative_to_snapshot_dir=True,
        )
        vae = train_vae(train_vae_variant)
        rdim = train_vae_variant['representation_size']
        vae_file = logger.save_extra_data(vae, 'vae.pkl', mode='pickle')
        logger.remove_tabular_output(
            'vae_progress.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )
        grill_variant['vae_paths'] = {
            str(rdim): vae_file,
        }
        grill_variant['rdim'] = str(rdim)
    if mode == 'td3':
        grill_her_td3_experiment(variant['grill_variant'])
    elif mode == 'twin-sac':
        grill_her_twin_sac_experiment(variant['grill_variant'])
    elif mode == 'sac':
        grill_her_sac_experiment(variant['grill_variant'])
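# A minimal usage sketch for grill_her_full_experiment. Only the keys read by
# the function above are shown; every concrete value is an illustrative
# placeholder, not a setting from the original experiments.
def _example_grill_launch():
    variant = dict(
        env_class=None,      # a multiworld-style env class goes here
        env_kwargs=dict(),
        init_camera=None,
        train_vae_variant=dict(
            representation_size=4,
            generate_vae_dataset_kwargs=dict(),
            # ... remaining VAE-training settings consumed by train_vae
        ),
        grill_variant=dict(
            # ... RL settings consumed by grill_her_td3_experiment
        ),
    )
    grill_her_full_experiment(variant, mode='td3')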
def get_n_train_vae(latent_dim, env, vae_train_epochs, num_image_examples,
                    vae_kwargs, vae_trainer_kwargs, vae_architecture,
                    vae_save_period=10, vae_test_p=.9,
                    decoder_activation='sigmoid', vae_class='VAE', **kwargs):
    env.goal_sampling_mode = 'test'
    image_examples = unnormalize_image(
        env.sample_goals(num_image_examples)['desired_goal'])
    # note: despite its name, vae_test_p is the fraction of examples used for
    # *training*; the remainder becomes the test set
    n = int(num_image_examples * vae_test_p)
    train_dataset = ImageObservationDataset(image_examples[:n, :])
    test_dataset = ImageObservationDataset(image_examples[n:, :])

    if decoder_activation == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()

    vae_class = vae_class.lower()
    if vae_class == 'vae':
        vae_class = ConvVAE
    elif vae_class == 'spatialvae':
        vae_class = SpatialAutoEncoder
    else:
        raise RuntimeError("Invalid VAE Class: {}".format(vae_class))

    vae = vae_class(latent_dim,
                    architecture=vae_architecture,
                    decoder_output_activation=decoder_activation,
                    **vae_kwargs)
    trainer = ConvVAETrainer(vae, **vae_trainer_kwargs)

    logger.remove_tabular_output('progress.csv', relative_to_snapshot_dir=True)
    logger.add_tabular_output('vae_progress.csv', relative_to_snapshot_dir=True)
    for epoch in range(vae_train_epochs):
        should_save_imgs = (epoch % vae_save_period == 0)
        trainer.train_epoch(epoch, train_dataset)
        trainer.test_epoch(epoch, test_dataset)
        if should_save_imgs:
            trainer.dump_reconstructions(epoch)
            trainer.dump_samples(epoch)
        stats = trainer.get_diagnostics()
        for k, v in stats.items():
            logger.record_tabular(k, v)
        logger.dump_tabular()
        trainer.end_epoch(epoch)
        if epoch % 50 == 0:
            logger.save_itr_params(epoch, vae)
    logger.save_extra_data(vae, 'vae.pkl', mode='pickle')
    logger.remove_tabular_output('vae_progress.csv', relative_to_snapshot_dir=True)
    logger.add_tabular_output('progress.csv', relative_to_snapshot_dir=True)
    return vae
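# A hypothetical call sketch for get_n_train_vae. Every value below is an
# illustrative placeholder, not a setting from the original experiments; `env`
# must expose sample_goals() and a goal_sampling_mode attribute as used above.
def _example_train_goal_vae(env):
    return get_n_train_vae(
        latent_dim=4,
        env=env,
        vae_train_epochs=100,
        num_image_examples=10000,
        vae_kwargs=dict(),
        vae_trainer_kwargs=dict(),
        vae_architecture=None,  # architecture spec expected by ConvVAE
        vae_class='VAE',
    )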
def pretrain_q_with_bc_data(self):
    logger.remove_tabular_output(
        'progress.csv', relative_to_snapshot_dir=True,
    )
    logger.add_tabular_output(
        'pretrain_q.csv', relative_to_snapshot_dir=True,
    )

    self.update_policy = False
    # first train only the Q function
    for i in range(self.q_num_pretrain1_steps):
        self.eval_statistics = dict()

        train_data = self.replay_buffer.random_batch(self.bc_batch_size)
        train_data = np_to_pytorch_batch(train_data)
        obs = train_data['observations']
        next_obs = train_data['next_observations']
        # goals = train_data['resampled_goals']
        train_data['observations'] = obs  # torch.cat((obs, goals), dim=1)
        train_data['next_observations'] = next_obs  # torch.cat((next_obs, goals), dim=1)
        self.train_from_torch(train_data)

        if i % self.pretraining_logging_period == 0:
            logger.record_dict(self.eval_statistics)
            logger.dump_tabular(with_prefix=True, with_timestamp=False)

    self.update_policy = True
    # then train policy and Q function together
    prev_time = time.time()
    for i in range(self.q_num_pretrain2_steps):
        self.eval_statistics = dict()
        if i % self.pretraining_logging_period == 0:
            self._need_to_update_eval_statistics = True

        train_data = self.replay_buffer.random_batch(self.bc_batch_size)
        train_data = np_to_pytorch_batch(train_data)
        obs = train_data['observations']
        next_obs = train_data['next_observations']
        # goals = train_data['resampled_goals']
        train_data['observations'] = obs  # torch.cat((obs, goals), dim=1)
        train_data['next_observations'] = next_obs  # torch.cat((next_obs, goals), dim=1)
        self.train_from_torch(train_data)

        if self.do_pretrain_rollouts and i % self.pretraining_env_logging_period == 0:
            total_ret = self.do_rollouts()
            print("Return at step {} : {}".format(i, total_ret / 20))

        if i % self.pretraining_logging_period == 0:
            if self.do_pretrain_rollouts:
                self.eval_statistics["pretrain_bc/avg_return"] = total_ret / 20
            self.eval_statistics["batch"] = i
            self.eval_statistics["epoch_time"] = time.time() - prev_time
            logger.record_dict(self.eval_statistics)
            logger.dump_tabular(with_prefix=True, with_timestamp=False)
            prev_time = time.time()

    logger.remove_tabular_output(
        'pretrain_q.csv',
        relative_to_snapshot_dir=True,
    )
    logger.add_tabular_output(
        'progress.csv',
        relative_to_snapshot_dir=True,
    )

    self._need_to_update_eval_statistics = True
    self.eval_statistics = dict()

    if self.post_pretrain_hyperparams:
        self.set_algorithm_weights(**self.post_pretrain_hyperparams)
def pretrain_policy_with_bc(self):
    logger.remove_tabular_output('progress.csv', relative_to_snapshot_dir=True)
    logger.add_tabular_output('pretrain_policy.csv', relative_to_snapshot_dir=True)
    for i in range(self.bc_num_pretrain_steps):
        train_batch = self.get_batch_from_buffer(self.demo_train_buffer)
        train_o = train_batch["observations"]
        train_u = train_batch["actions"]
        if self.goal_conditioned:
            train_g = train_batch["resampled_goals"]
            train_o = torch.cat((train_o, train_g), dim=1)
        train_pred_u = self.policy(train_o)
        train_error = (train_pred_u - train_u) ** 2
        train_bc_loss = train_error.mean()

        policy_loss = self.bc_weight * train_bc_loss

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        test_batch = self.get_batch_from_buffer(self.demo_test_buffer)
        test_o = test_batch["observations"]
        test_u = test_batch["actions"]
        if self.goal_conditioned:
            test_g = test_batch["resampled_goals"]
            test_o = torch.cat((test_o, test_g), dim=1)
        test_pred_u = self.policy(test_o)
        test_error = (test_pred_u - test_u) ** 2
        test_bc_loss = test_error.mean()

        train_loss_mean = np.mean(ptu.get_numpy(train_bc_loss))
        test_loss_mean = np.mean(ptu.get_numpy(test_bc_loss))

        stats = {
            "Train BC Loss": train_loss_mean,
            "Test BC Loss": test_loss_mean,
            "policy_loss": ptu.get_numpy(policy_loss),
            "batch": i,
        }
        logger.record_dict(stats)
        logger.dump_tabular(with_prefix=True, with_timestamp=False)

        if i % 1000 == 0:
            logger.save_itr_params(i, {
                "evaluation/policy": self.policy,
                "evaluation/env": self.env.wrapped_env,
            })
    logger.remove_tabular_output(
        'pretrain_policy.csv',
        relative_to_snapshot_dir=True,
    )
    logger.add_tabular_output(
        'progress.csv',
        relative_to_snapshot_dir=True,
    )
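# A minimal standalone sketch of the behavior-cloning update used above,
# pulled out of the class for clarity: regress predicted actions onto
# demonstration actions with an MSE loss. `policy` is any nn.Module mapping
# observations to actions; all names here are placeholders.
import torch


def bc_step(policy, optimizer, demo_obs, demo_actions, bc_weight=1.0):
    """One gradient step of MSE behavior cloning; returns the scalar loss."""
    pred_actions = policy(demo_obs)
    loss = bc_weight * ((pred_actions - demo_actions) ** 2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()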
def train_vae(variant):
    from railrl.misc.ml_util import PiecewiseLinearSchedule
    from railrl.torch.vae.conv_vae import ConvVAE
    from railrl.torch.vae.conv_vae_trainer import ConvVAETrainer
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu
    from multiworld.core.image_env import ImageEnv
    from railrl.envs.vae_wrappers import VAEWrappedEnv
    from railrl.misc.asset_loader import local_path_from_s3_or_local_path

    logger.remove_tabular_output('progress.csv', relative_to_snapshot_dir=True)
    logger.add_tabular_output('vae_progress.csv', relative_to_snapshot_dir=True)

    env_id = variant['generate_vae_dataset_kwargs'].get('env_id', None)
    if env_id is not None:
        import gym
        env = gym.make(env_id)
    else:
        env_class = variant['generate_vae_dataset_kwargs']['env_class']
        env_kwargs = variant['generate_vae_dataset_kwargs']['env_kwargs']
        env = env_class(**env_kwargs)

    representation_size = variant["representation_size"]
    beta = variant["beta"]
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(**variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None

    # obtain training and testing data
    dataset_path = variant['generate_vae_dataset_kwargs'].get('dataset_path', None)
    # note: despite its name, test_p is the fraction of the dataset used for
    # *training*; the remainder becomes the test set
    test_p = variant['generate_vae_dataset_kwargs'].get('test_p', 0.9)
    filename = local_path_from_s3_or_local_path(dataset_path)
    dataset = np.load(filename, allow_pickle=True).item()
    N = dataset['obs'].shape[0]
    n = int(N * test_p)
    train_data = {}
    test_data = {}
    for k in dataset.keys():
        train_data[k] = dataset[k][:n, :]
        test_data[k] = dataset[k][n:, :]

    # setup vae
    variant['vae_kwargs']['action_dim'] = train_data['actions'].shape[1]
    if variant.get('vae_type', None) == "VAE-state":
        from railrl.torch.vae.vae import VAE
        input_size = train_data['obs'].shape[1]
        variant['vae_kwargs']['input_size'] = input_size
        m = VAE(representation_size, **variant['vae_kwargs'])
    elif variant.get('vae_type', None) == "VAE2":
        from railrl.torch.vae.conv_vae2 import ConvVAE2
        variant['vae_kwargs']['imsize'] = variant['imsize']
        m = ConvVAE2(representation_size, **variant['vae_kwargs'])
    else:
        variant['vae_kwargs']['imsize'] = variant['imsize']
        m = ConvVAE(representation_size, **variant['vae_kwargs'])
    if ptu.gpu_enabled():
        m.cuda()

    # setup vae trainer
    if variant.get('vae_type', None) == "VAE-state":
        from railrl.torch.vae.vae_trainer import VAETrainer
        t = VAETrainer(train_data, test_data, m, beta=beta,
                       beta_schedule=beta_schedule, **variant['algo_kwargs'])
    else:
        t = ConvVAETrainer(train_data, test_data, m, beta=beta,
                           beta_schedule=beta_schedule, **variant['algo_kwargs'])

    # visualization
    vis_variant = variant.get('vis_kwargs', {})
    save_video = vis_variant.get('save_video', False)
    if isinstance(env, ImageEnv):
        image_env = env
    else:
        image_env = ImageEnv(
            env,
            variant['generate_vae_dataset_kwargs'].get('imsize'),
            init_camera=variant['generate_vae_dataset_kwargs'].get('init_camera'),
            transpose=True,
            normalize=True,
        )
    render = variant.get('render', False)
    reward_params = variant.get("reward_params", dict())
    vae_env = VAEWrappedEnv(
        image_env,
        m,
        imsize=image_env.imsize,
        decode_goals=render,
        render_goals=render,
        render_rollouts=render,
        reward_params=reward_params,
        **variant.get('vae_wrapped_env_kwargs', {})
    )
    vae_env.reset()
    vae_env.add_mode("video_env", 'video_env')
    vae_env.add_mode("video_vae", 'video_vae')

    if save_video:
        import railrl.samplers.rollout_functions as rf
        from railrl.policies.simple import RandomPolicy
        random_policy = RandomPolicy(vae_env.action_space)
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=100,
            observation_key='latent_observation',
            desired_goal_key='latent_desired_goal',
            vis_list=vis_variant.get('vis_list', []),
            dont_terminate=True,
        )
        dump_video_kwargs = variant.get("dump_video_kwargs", dict())
        dump_video_kwargs['imsize'] = vae_env.imsize
        dump_video_kwargs['vis_list'] = [
            'image_observation',
            'reconstr_image_observation',
            'image_latent_histogram_2d',
            'image_latent_histogram_mu_2d',
            'image_plt',
            'image_rew',
            'image_rew_euclidean',
            'image_rew_mahalanobis',
            'image_rew_logp',
            'image_rew_kl',
            'image_rew_kl_rev',
        ]

    def visualization_post_processing(save_vis, save_video, epoch):
        vis_list = vis_variant.get('vis_list', [])
        if save_vis:
            if vae_env.vae_input_key_prefix == 'state':
                vae_env.dump_reconstructions(
                    epoch, n_recon=vis_variant.get('n_recon', 16))
            vae_env.dump_samples(
                epoch, n_samples=vis_variant.get('n_samples', 64))
            if 'latent_representation' in vis_list:
                vae_env.dump_latent_plots(epoch)
            if any(elem in vis_list for elem in [
                'latent_histogram', 'latent_histogram_mu',
                'latent_histogram_2d', 'latent_histogram_mu_2d',
            ]):
                vae_env.compute_latent_histogram()
            if not save_video and ('latent_histogram' in vis_list):
                vae_env.dump_latent_histogram(
                    epoch=epoch, noisy=True, use_true_prior=True)
            if not save_video and ('latent_histogram_mu' in vis_list):
                vae_env.dump_latent_histogram(
                    epoch=epoch, noisy=False, use_true_prior=True)
        if save_video and save_vis:
            from railrl.envs.vae_wrappers import temporary_mode
            from railrl.misc.video_gen import dump_video
            from railrl.core import logger
            vae_env.compute_goal_encodings()
            logdir = logger.get_snapshot_dir()
            filename = osp.join(logdir, 'video_{epoch}.mp4'.format(epoch=epoch))
            variant['dump_video_kwargs']['epoch'] = epoch
            temporary_mode(vae_env, mode='video_env', func=dump_video,
                           args=(vae_env, random_policy, filename, rollout_function),
                           kwargs=variant['dump_video_kwargs'])
            if not vis_variant.get('save_video_env_only', True):
                filename = osp.join(
                    logdir, 'video_{epoch}_vae.mp4'.format(epoch=epoch))
                temporary_mode(vae_env, mode='video_vae', func=dump_video,
                               args=(vae_env, random_policy, filename, rollout_function),
                               kwargs=variant['dump_video_kwargs'])

    # train vae
    for epoch in range(variant['num_epochs']):
        save_vis = (epoch % vis_variant['save_period'] == 0
                    or epoch == variant['num_epochs'] - 1)
        save_vae = (epoch % variant['snapshot_gap'] == 0
                    or epoch == variant['num_epochs'] - 1)

        t.train_epoch(epoch)
        t.test_epoch(
            epoch,
            save_reconstruction=save_vis,
            save_interpolation=save_vis,
            save_vae=save_vae,
        )
        if epoch % 300 == 0 or epoch == variant['num_epochs'] - 1:
            visualization_post_processing(save_vis, save_video, epoch)

    logger.save_extra_data(m, 'vae.pkl', mode='pickle')
    logger.remove_tabular_output(
        'vae_progress.csv',
        relative_to_snapshot_dir=True,
    )
    logger.add_tabular_output(
        'progress.csv',
        relative_to_snapshot_dir=True,
    )
    print("finished training vae")
    return m
def train_reprojection_network_and_update_variant(variant):
    from railrl.core import logger
    from railrl.misc.asset_loader import load_local_or_remote_file
    import railrl.torch.pytorch_util as ptu

    rl_variant = variant.get("rl_variant", {})
    vae_wrapped_env_kwargs = rl_variant.get('vae_wrapped_env_kwargs', {})
    if vae_wrapped_env_kwargs.get("use_reprojection_network", False):
        train_reprojection_network_variant = variant.get(
            "train_reprojection_network_variant", {})
        if train_reprojection_network_variant.get("use_cached_network", False):
            vae_path = rl_variant.get("vae_path", None)
            reprojection_network = load_local_or_remote_file(
                osp.join(vae_path, 'reproj_network.pkl'))
            logger.save_extra_data(
                reprojection_network, 'reproj_network.pkl', mode='pickle')
            if ptu.gpu_enabled():
                reprojection_network.cuda()
            vae_wrapped_env_kwargs['reprojection_network'] = reprojection_network
        else:
            logger.remove_tabular_output(
                'progress.csv', relative_to_snapshot_dir=True)
            logger.add_tabular_output(
                'reproj_progress.csv', relative_to_snapshot_dir=True)

            vae_path = rl_variant.get("vae_path", None)
            ckpt = rl_variant.get("ckpt", None)
            if type(ckpt) is str:
                vae = load_local_or_remote_file(osp.join(ckpt, 'vae.pkl'))
                logger.save_extra_data(vae, 'vae.pkl', mode='pickle')
            elif type(vae_path) is str:
                vae = load_local_or_remote_file(
                    osp.join(vae_path, 'vae_params.pkl'))
                logger.save_extra_data(vae, 'vae.pkl', mode='pickle')
            else:
                vae = vae_path
            if type(vae) is str:
                vae = load_local_or_remote_file(vae)
            if ptu.gpu_enabled():
                vae.cuda()

            train_reprojection_network_variant['vae'] = vae
            reprojection_network = train_reprojection_network(
                train_reprojection_network_variant)
            vae_wrapped_env_kwargs['reprojection_network'] = reprojection_network
def run_experiment(argv):
    default_log_dir = config.LOCAL_LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')
    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--n_parallel', type=int, default=1,
        help="Number of parallel workers to perform rollouts. "
             "0 => don't start any workers.")
    parser.add_argument(
        '--exp_name', type=str, default=default_exp_name,
        help='Name of the experiment.')
    parser.add_argument(
        '--log_dir', type=str, default=None,
        help='Path to save the log and iteration snapshots.')
    parser.add_argument(
        '--snapshot_mode', type=str, default='all',
        help='Mode to save the snapshot. Can be either "all" '
             '(all iterations will be saved), "last" (only '
             'the last iteration will be saved), "gap" (every '
             '`snapshot_gap` iterations are saved), or "none" '
             '(do not save snapshots).')
    parser.add_argument(
        '--snapshot_gap', type=int, default=1,
        help='Gap between snapshot iterations.')
    parser.add_argument(
        '--tabular_log_file', type=str, default='progress.csv',
        help='Name of the tabular log file (CSV).')
    parser.add_argument(
        '--text_log_file', type=str, default='debug.log',
        help='Name of the text log file (plain text).')
    parser.add_argument(
        '--params_log_file', type=str, default='params.json',
        help='Name of the parameter log file (JSON).')
    parser.add_argument(
        '--variant_log_file', type=str, default='variant.json',
        help='Name of the variant log file (JSON).')
    parser.add_argument(
        '--resume_from', type=str, default=None,
        help='Name of the pickle file to resume the experiment from.')
    parser.add_argument(
        '--plot', type=ast.literal_eval, default=False,
        help='Whether to plot the iteration results.')
    parser.add_argument(
        '--log_tabular_only', type=ast.literal_eval, default=False,
        help='Whether to only print the tabular log information '
             '(in a horizontal format).')
    parser.add_argument('--seed', type=int, help='Random seed for numpy.')
    parser.add_argument('--args_data', type=str,
                        help='Pickled data for stub objects.')
    parser.add_argument('--variant_data', type=str,
                        help='Pickled data for the variant configuration.')
    parser.add_argument('--use_cloudpickle', type=ast.literal_eval,
                        default=False)
    parser.add_argument('--code_diff', type=str,
                        help='A string of the code diff to save.')
    parser.add_argument('--commit_hash', type=str,
                        help='A string of the commit hash.')
    parser.add_argument('--script_name', type=str,
                        help='Name of the launched script.')
    args = parser.parse_args(argv[1:])

    if args.seed is not None:
        set_seed(args.seed)

    if args.n_parallel > 0:
        from rllab.sampler import parallel_sampler
        parallel_sampler.initialize(n_parallel=args.n_parallel)
        if args.seed is not None:
            parallel_sampler.set_seed(args.seed)

    if args.plot:
        from rllab.plotter import plotter
        plotter.init_worker()

    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    if args.variant_data is not None:
        variant_data = pickle.loads(base64.b64decode(args.variant_data))
        variant_log_file = osp.join(log_dir, args.variant_log_file)
        logger.log_variant(variant_log_file, variant_data)
    else:
        variant_data = None

    if not args.use_cloudpickle:
        raise NotImplementedError("Only cloudpickle serialization is supported.")

    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_snapshot_gap(args.snapshot_gap)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    # Save information for code reproducibility.
    if args.code_diff is not None:
        code_diff_str = cloudpickle.loads(base64.b64decode(args.code_diff))
        with open(osp.join(log_dir, "code.diff"), "w") as f:
            f.write(code_diff_str)
    if args.commit_hash is not None:
        with open(osp.join(log_dir, "commit_hash.txt"), "w") as f:
            f.write(args.commit_hash)
    if args.script_name is not None:
        with open(osp.join(log_dir, "script_name.txt"), "w") as f:
            f.write(args.script_name)

    if args.resume_from is not None:
        data = joblib.load(args.resume_from)
        assert 'algo' in data
        algo = data['algo']
        algo.train()
    else:
        # unpickle the method call and run it
        if args.use_cloudpickle:
            method_call = cloudpickle.loads(base64.b64decode(args.args_data))
            method_call(variant_data)
        else:
            # legacy non-cloudpickle path; unreachable given the
            # NotImplementedError check above
            data = pickle.loads(base64.b64decode(args.args_data))
            maybe_iter = concretize(data)
            if is_iterable(maybe_iter):
                for _ in maybe_iter:
                    pass

    logger.set_snapshot_mode(prev_mode)
    logger.set_snapshot_dir(prev_snapshot_dir)
    logger.remove_tabular_output(tabular_log_file)
    logger.remove_text_output(text_log_file)
    logger.pop_prefix()
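# A minimal launch sketch for run_experiment, inferred from the argument
# parsing above: the experiment function is cloudpickled and base64-encoded
# into --args_data, and the variant dict likewise into --variant_data.
# `my_experiment` and its variant are hypothetical placeholders.
import base64
import pickle
import sys

import cloudpickle


def my_experiment(variant):
    print("running with variant:", variant)


if __name__ == '__main__':
    args_data = base64.b64encode(cloudpickle.dumps(my_experiment)).decode()
    variant_data = base64.b64encode(pickle.dumps({'seed': 0})).decode()
    run_experiment([
        sys.argv[0],
        '--n_parallel', '0',  # skip the parallel-sampler setup
        '--use_cloudpickle', 'True',
        '--args_data', args_data,
        '--variant_data', variant_data,
    ])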