def test_epoch(
        self,
        epoch,
):
    self.model.eval()
    val_losses = []
    per_dim_losses = np.zeros((self.num_batches, self.y_train.shape[1]))
    for batch in range(self.num_batches):
        inputs_np, labels_np = self.random_batch(
            self.X_test, self.y_test, batch_size=self.batch_size)
        inputs = ptu.Variable(ptu.from_numpy(inputs_np))
        labels = ptu.Variable(ptu.from_numpy(labels_np))
        outputs = self.model(inputs)
        loss = self.criterion(outputs, labels)
        val_losses.append(loss.data[0])
        per_dim_loss = np.mean(
            np.power(ptu.get_numpy(outputs - labels), 2),
            axis=0,
        )
        per_dim_losses[batch] = per_dim_loss
    logger.record_tabular("test/epoch", epoch)
    logger.record_tabular("test/loss", np.mean(np.array(val_losses)))
    for i in range(self.y_train.shape[1]):
        logger.record_tabular("test/dim " + str(i) + " loss",
                              np.mean(per_dim_losses[:, i]))
    logger.dump_tabular()
def simulate_policy(args):
    data = joblib.load(args.file)
    model = data['model']
    env = data['env']
    orig_policy = data['mpc_controller']
    print("Policy loaded")
    if args.pause:
        import ipdb
        ipdb.set_trace()
    policy = GradientBasedMPCController(
        env,
        model,
        mpc_horizon=1,
        num_grad_steps=10,
        learning_rate=1e-1,
        warm_start=False,
    )
    while True:
        path = rollout(
            env,
            policy,
            orig_policy,
            max_path_length=args.H,
            animated=True,
        )
        if hasattr(env, "log_diagnostics"):
            env.log_diagnostics([path])
        logger.dump_tabular()
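# A minimal sketch of the command-line entry point that simulate_policy above
# assumes. It only wires up the three attributes the function reads
# (args.file, args.H, args.pause); the defaults are illustrative and not taken
# from the original script.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('file', type=str, help='path to the snapshot file')
    parser.add_argument('--H', type=int, default=300,
                        help='max rollout length')
    parser.add_argument('--pause', action='store_true',
                        help='drop into ipdb after loading the snapshot')
    simulate_policy(parser.parse_args())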
def experiment(variant):
    num_rollouts = variant['num_rollouts']
    H = variant['H']
    render = variant['render']
    data = joblib.load(variant['qf_path'])
    qf = data['qf']
    env = data['env']
    qf_policy = data['policy']
    if ptu.gpu_enabled():
        qf.to(ptu.device)
        qf_policy.to(ptu.device)
    policy_class = variant['policy_class']
    if policy_class == StateOnlySdqBasedSqpOcPolicy:
        policy = policy_class(qf, env, qf_policy, **variant['policy_params'])
    else:
        policy = policy_class(qf, env, **variant['policy_params'])
    paths = []
    for _ in range(num_rollouts):
        goal = env.sample_goal_for_rollout()
        path = multitask_rollout(
            env,
            policy,
            goal,
            discount=variant['discount'],
            max_path_length=H,
            animated=render,
        )
        paths.append(path)
    env.log_diagnostics(paths)
    logger.dump_tabular(with_timestamp=False)
def test_epoch(self, epoch, save_network=True, batches=100):
    self.model.eval()
    mses = []
    losses = []
    for batch_idx in range(batches):
        data = self.get_batch(train=False)
        z = data["z"]
        z_proj = data['z_proj']
        z_proj_hat = self.model(z)
        mse = self.mse_loss(z_proj_hat, z_proj)
        loss = mse
        mses.append(mse.data[0])
        losses.append(loss.data[0])
    logger.record_tabular("test/epoch", epoch)
    logger.record_tabular("test/MSE", np.mean(mses))
    logger.record_tabular("test/loss", np.mean(losses))
    logger.dump_tabular()
    if save_network:
        logger.save_itr_params(epoch, self.model, prefix='reproj',
                               save_anyway=True)
def experiment(variant):
    num_rollouts = variant['num_rollouts']
    H = variant['H']
    render = variant['render']
    data = joblib.load(variant['qf_path'])
    policy_params = variant['policy_params']
    if 'model' in data:
        model = data['model']
    else:
        qf = data['qf']
        model = ModelExtractor(qf)
        policy_params['model_learns_deltas'] = False
    env = data['env']
    if ptu.gpu_enabled():
        model.to(ptu.device)
    policy = variant['policy_class'](
        model,
        env,
        **policy_params
    )
    paths = []
    for _ in range(num_rollouts):
        goal = env.sample_goal_for_rollout()
        path = multitask_rollout(
            env,
            policy,
            goal,
            discount=0,
            max_path_length=H,
            animated=render,
        )
        paths.append(path)
    env.log_diagnostics(paths)
    logger.dump_tabular(with_timestamp=False)
def evaluate(self, epoch):
    """
    Perform evaluation for this algorithm.

    :param epoch: The epoch number.
    """
    statistics = OrderedDict()

    train_batch = self.get_batch()
    statistics.update(self._statistics_from_batch(train_batch, "Train"))

    logger.log("Collecting samples for evaluation")
    test_paths = self._sample_eval_paths()
    statistics.update(get_generic_path_information(
        test_paths, stat_prefix="Test",
    ))
    statistics.update(self._statistics_from_paths(test_paths, "Test"))
    average_returns = get_average_returns(test_paths)
    statistics['AverageReturn'] = average_returns
    statistics['Epoch'] = epoch

    for key, value in statistics.items():
        logger.record_tabular(key, value)

    self.env.log_diagnostics(test_paths)
    logger.dump_tabular(with_prefix=False, with_timestamp=False)
def train_vae(variant, return_data=False):
    from railrl.misc.ml_util import PiecewiseLinearSchedule
    from railrl.torch.vae.vae_trainer import ConvVAETrainer
    from railrl.core import logger

    beta = variant["beta"]
    use_linear_dynamics = variant.get('use_linear_dynamics', False)
    generate_vae_dataset_fctn = variant.get('generate_vae_data_fctn',
                                            generate_vae_dataset)
    variant['generate_vae_dataset_kwargs']['use_linear_dynamics'] = (
        use_linear_dynamics
    )
    train_dataset, test_dataset, info = generate_vae_dataset_fctn(
        variant['generate_vae_dataset_kwargs'])
    if use_linear_dynamics:
        action_dim = train_dataset.data['actions'].shape[2]
    else:
        action_dim = 0
    model = get_vae(variant, action_dim)

    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None

    vae_trainer_class = variant.get('vae_trainer_class', ConvVAETrainer)
    trainer = vae_trainer_class(model,
                                beta=beta,
                                beta_schedule=beta_schedule,
                                **variant['algo_kwargs'])
    save_period = variant['save_period']
    dump_skew_debug_plots = variant.get('dump_skew_debug_plots', False)

    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        trainer.train_epoch(epoch, train_dataset)
        trainer.test_epoch(epoch, test_dataset)

        if should_save_imgs:
            trainer.dump_reconstructions(epoch)
            trainer.dump_samples(epoch)
            if dump_skew_debug_plots:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)

        stats = trainer.get_diagnostics()
        for k, v in stats.items():
            logger.record_tabular(k, v)
        logger.dump_tabular()
        trainer.end_epoch(epoch)

        if epoch % 50 == 0:
            logger.save_itr_params(epoch, model)
    logger.save_extra_data(model, 'vae.pkl', mode='pickle')
    if return_data:
        return model, train_dataset, test_dataset
    return model
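# A hedged sketch of the variant dictionary train_vae reads. Only keys that
# appear directly in the function body are listed; get_vae and the dataset
# generator will expect additional keys of their own, and the values shown
# here are placeholders rather than recommended settings.
example_train_vae_variant = dict(
    beta=2.5,
    num_epochs=500,
    save_period=25,
    use_linear_dynamics=False,
    generate_vae_dataset_kwargs=dict(),   # forwarded to the dataset generator
    algo_kwargs=dict(),                   # forwarded to the VAE trainer
    # optional keys: beta_schedule_kwargs, vae_trainer_class,
    # generate_vae_data_fctn, dump_skew_debug_plots
)
# model = train_vae(example_train_vae_variant)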
def test_epoch(self, epoch, save_vae=True, **kwargs):
    self.model.eval()
    losses = []
    kles = []
    zs = []
    recon_logging_dict = {
        'MSE': [],
        'WSE': [],
    }
    for k in self.extra_recon_logging:
        recon_logging_dict[k] = []
    beta = self.beta_schedule.get_value(epoch)
    for batch_idx in range(100):
        data = self.get_batch(train=False)
        obs = data['obs']
        next_obs = data['next_obs']
        actions = data['actions']
        recon_batch, mu, logvar = self.model(next_obs)
        mse = self.logprob(recon_batch, next_obs)
        wse = self.logprob(recon_batch, next_obs,
                           unorm_weights=self.recon_weights)
        for k, idx in self.extra_recon_logging.items():
            recon_loss = self.logprob(recon_batch, next_obs, idx=idx)
            recon_logging_dict[k].append(recon_loss.data[0])
        kle = self.kl_divergence(mu, logvar)
        if self.recon_loss_type == 'mse':
            loss = mse + beta * kle
        elif self.recon_loss_type == 'wse':
            loss = wse + beta * kle
        z_data = ptu.get_numpy(mu.cpu())
        for i in range(len(z_data)):
            zs.append(z_data[i, :])
        losses.append(loss.data[0])
        recon_logging_dict['WSE'].append(wse.data[0])
        recon_logging_dict['MSE'].append(mse.data[0])
        kles.append(kle.data[0])

    zs = np.array(zs)
    self.model.dist_mu = zs.mean(axis=0)
    self.model.dist_std = zs.std(axis=0)

    for k in recon_logging_dict:
        logger.record_tabular("/".join(["test", k]),
                              np.mean(recon_logging_dict[k]))
    logger.record_tabular("test/KL", np.mean(kles))
    logger.record_tabular("test/loss", np.mean(losses))
    logger.record_tabular("beta", beta)

    process = psutil.Process(os.getpid())
    logger.record_tabular("RAM Usage (Mb)",
                          int(process.memory_info().rss / 1000000))

    num_active_dims = 0
    for std in self.model.dist_std:
        if std > 0.15:
            num_active_dims += 1
    logger.record_tabular("num_active_dims", num_active_dims)

    logger.dump_tabular()
    if save_vae:
        logger.save_itr_params(epoch, self.model, prefix='vae',
                               save_anyway=True)  # slow...
def run_task(variant):
    from railrl.core import logger
    print(variant)
    logger.log("Hello from script")
    logger.log("variant: " + str(variant))
    logger.record_tabular("value", 1)
    logger.dump_tabular()
    logger.log("snapshot_dir: " + str(logger.get_snapshot_dir()))
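# A hedged sketch of invoking run_task directly. In the original launchers a
# run_experiment-style wrapper configures the logger first; here a tabular
# output file is attached by hand using the same logger API seen elsewhere in
# this file. The file path and variant contents are illustrative only.
if __name__ == "__main__":
    from railrl.core import logger

    logger.add_tabular_output('/tmp/run_task_demo.csv')
    run_task(dict(seed=0, note="smoke test"))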
def pretrain_policy_with_bc(self):
    logger.remove_tabular_output(
        'progress.csv', relative_to_snapshot_dir=True,
    )
    logger.add_tabular_output(
        'pretrain_policy.csv', relative_to_snapshot_dir=True,
    )
    if self.do_pretrain_rollouts:
        total_ret = self.do_rollouts()
        print("INITIAL RETURN", total_ret / 20)

    prev_time = time.time()
    for i in range(self.bc_num_pretrain_steps):
        train_policy_loss, train_logp_loss, train_mse_loss, train_log_std = (
            self.run_bc_batch(self.demo_train_buffer, self.policy)
        )
        train_policy_loss = train_policy_loss * self.bc_weight

        self.policy_optimizer.zero_grad()
        train_policy_loss.backward()
        self.policy_optimizer.step()

        test_policy_loss, test_logp_loss, test_mse_loss, test_log_std = (
            self.run_bc_batch(self.demo_test_buffer, self.policy)
        )
        test_policy_loss = test_policy_loss * self.bc_weight

        if self.do_pretrain_rollouts and i % self.pretraining_env_logging_period == 0:
            total_ret = self.do_rollouts()
            print("Return at step {} : {}".format(i, total_ret / 20))

        if i % self.pretraining_logging_period == 0:
            stats = {
                "pretrain_bc/batch": i,
                "pretrain_bc/Train Logprob Loss": ptu.get_numpy(train_logp_loss),
                "pretrain_bc/Test Logprob Loss": ptu.get_numpy(test_logp_loss),
                "pretrain_bc/Train MSE": ptu.get_numpy(train_mse_loss),
                "pretrain_bc/Test MSE": ptu.get_numpy(test_mse_loss),
                "pretrain_bc/train_policy_loss": ptu.get_numpy(train_policy_loss),
                "pretrain_bc/test_policy_loss": ptu.get_numpy(test_policy_loss),
                "pretrain_bc/epoch_time": time.time() - prev_time,
            }
            if self.do_pretrain_rollouts:
                stats["pretrain_bc/avg_return"] = total_ret / 20
            logger.record_dict(stats)
            logger.dump_tabular(with_prefix=True, with_timestamp=False)
            pickle.dump(self.policy,
                        open(logger.get_snapshot_dir() + '/bc.pkl', "wb"))
            prev_time = time.time()

    logger.remove_tabular_output(
        'pretrain_policy.csv', relative_to_snapshot_dir=True,
    )
    logger.add_tabular_output(
        'progress.csv', relative_to_snapshot_dir=True,
    )

    if self.post_bc_pretrain_hyperparams:
        self.set_algorithm_weights(**self.post_bc_pretrain_hyperparams)
def simulate_policy(args):
    data = pickle.load(open(args.file, "rb"))
    policy_key = args.policy_type + '/policy'
    if policy_key in data:
        policy = data[policy_key]
    else:
        raise Exception("No policy found in loaded dict. Keys: {}".format(
            data.keys()))
    env_key = args.env_type + '/env'
    if env_key in data:
        env = data[env_key]
    else:
        raise Exception("No environment found in loaded dict. Keys: {}".format(
            data.keys()))

    # robosuite env specific things
    env._wrapped_env.has_renderer = True
    env.reset()
    env.viewer.set_camera(camera_id=0)

    if isinstance(env, RemoteRolloutEnv):
        env = env._wrapped_env
    print("Policy loaded")
    if args.enable_render:
        # some environments need to be reconfigured for visualization
        env.enable_render()
    if args.gpu:
        ptu.set_gpu_mode(True)
        if hasattr(policy, "to"):
            policy.to(ptu.device)
        if hasattr(env, "vae"):
            env.vae.to(ptu.device)
    if args.pause:
        import ipdb
        ipdb.set_trace()
    if isinstance(policy, PyTorchModule):
        policy.train(False)
    paths = []
    while True:
        paths.append(deprecated_rollout(
            env,
            policy,
            max_path_length=args.H,
            render=not args.hide,
        ))
        if args.log_diagnostics:
            if hasattr(env, "log_diagnostics"):
                env.log_diagnostics(paths, logger)
            for k, v in eval_util.get_generic_path_information(paths).items():
                logger.record_tabular(k, v)
            logger.dump_tabular()
def train(self):
    timer.return_global_times = True
    for _ in range(self.num_epochs):
        self._begin_epoch()
        # logger.save_itr_params(self.epoch, self._get_snapshot())
        # timer.stamp('saving')
        log_dict, _ = self._train()
        logger.record_dict(log_dict)
        logger.dump_tabular(with_prefix=True, with_timestamp=False)
        logger.save_itr_params(self.epoch, self._get_snapshot())
        self._end_epoch()
def _try_to_eval(self, epoch, eval_paths=None):
    logger.save_extra_data(self.get_extra_data_to_save(epoch))
    params = self.get_epoch_snapshot(epoch)
    logger.save_itr_params(epoch, params)

    if self._can_evaluate():
        self.evaluate(epoch, eval_paths=eval_paths)
        # params = self.get_epoch_snapshot(epoch)
        # logger.save_itr_params(epoch, params)
        table_keys = logger.get_table_key_set()
        if self._old_table_keys is not None:
            assert table_keys == self._old_table_keys, (
                "Table keys cannot change from iteration to iteration."
            )
        self._old_table_keys = table_keys

        logger.record_tabular(
            "Number of train steps total",
            self._n_train_steps_total,
        )
        logger.record_tabular(
            "Number of env steps total",
            self._n_env_steps_total,
        )
        logger.record_tabular(
            "Number of rollouts total",
            self._n_rollouts_total,
        )

        if self.collection_mode != 'online-parallel':
            times_itrs = gt.get_times().stamps.itrs
            train_time = times_itrs['train'][-1]
            sample_time = times_itrs['sample'][-1]
            if 'eval' in times_itrs:
                eval_time = times_itrs['eval'][-1] if epoch > 0 else -1
            else:
                eval_time = -1
            epoch_time = train_time + sample_time + eval_time
            total_time = gt.get_times().total

            logger.record_tabular('Train Time (s)', train_time)
            logger.record_tabular('(Previous) Eval Time (s)', eval_time)
            logger.record_tabular('Sample Time (s)', sample_time)
            logger.record_tabular('Epoch Time (s)', epoch_time)
            logger.record_tabular('Total Train Time (s)', total_time)
        else:
            logger.record_tabular('Epoch Time (s)',
                                  time.time() - self._epoch_start_time)

        logger.record_tabular("Epoch", epoch)
        logger.dump_tabular(with_prefix=False, with_timestamp=False)
    else:
        logger.log("Skipping eval for now.")
def pretrain_q_with_bc_data(self):
    logger.remove_tabular_output('progress.csv',
                                 relative_to_snapshot_dir=True)
    logger.add_tabular_output('pretrain_q.csv',
                              relative_to_snapshot_dir=True)

    self.update_policy = False
    # first train only the Q function
    for i in range(self.q_num_pretrain_steps):
        self.eval_statistics = dict()
        self._need_to_update_eval_statistics = True

        train_data = self.replay_buffer.random_batch(128)
        train_data = np_to_pytorch_batch(train_data)
        obs = train_data['observations']
        next_obs = train_data['next_observations']
        if self.goal_conditioned:
            goals = train_data['resampled_goals']
            train_data['observations'] = torch.cat((obs, goals), dim=1)
            train_data['next_observations'] = torch.cat((next_obs, goals), dim=1)
        self.train_from_torch(train_data)

        logger.record_dict(self.eval_statistics)
        logger.dump_tabular(with_prefix=True, with_timestamp=False)

    self.update_policy = True
    # then train policy and Q function together
    for i in range(self.q_num_pretrain_steps):
        self.eval_statistics = dict()
        self._need_to_update_eval_statistics = True

        train_data = self.replay_buffer.random_batch(128)
        train_data = np_to_pytorch_batch(train_data)
        obs = train_data['observations']
        next_obs = train_data['next_observations']
        if self.goal_conditioned:
            goals = train_data['resampled_goals']
            train_data['observations'] = torch.cat((obs, goals), dim=1)
            train_data['next_observations'] = torch.cat((next_obs, goals), dim=1)
        self.train_from_torch(train_data)

        logger.record_dict(self.eval_statistics)
        logger.dump_tabular(with_prefix=True, with_timestamp=False)

    logger.remove_tabular_output(
        'pretrain_q.csv',
        relative_to_snapshot_dir=True,
    )
    logger.add_tabular_output(
        'progress.csv',
        relative_to_snapshot_dir=True,
    )
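# The remove/add pattern above (also used in pretrain_policy_with_bc and
# get_n_train_vae) temporarily redirects tabular logging to a side CSV and
# then restores progress.csv. A hedged sketch of the same idea as a context
# manager; temporary_tabular_output is not part of the railrl logger, it only
# wraps the logger calls already used in this file.
from contextlib import contextmanager

@contextmanager
def temporary_tabular_output(temp_file, main_file='progress.csv'):
    logger.remove_tabular_output(main_file, relative_to_snapshot_dir=True)
    logger.add_tabular_output(temp_file, relative_to_snapshot_dir=True)
    try:
        yield
    finally:
        logger.remove_tabular_output(temp_file, relative_to_snapshot_dir=True)
        logger.add_tabular_output(main_file, relative_to_snapshot_dir=True)

# usage sketch:
# with temporary_tabular_output('pretrain_q.csv'):
#     ...run the pretraining loops and dump_tabular() as above...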
def _try_to_offline_eval(self, epoch):
    start_time = time.time()
    logger.save_extra_data(self.get_extra_data_to_save(epoch))
    self.offline_evaluate(epoch)
    params = self.get_epoch_snapshot(epoch)
    logger.save_itr_params(epoch, params)
    table_keys = logger.get_table_key_set()
    if self._old_table_keys is not None:
        assert table_keys == self._old_table_keys, (
            "Table keys cannot change from iteration to iteration."
        )
    self._old_table_keys = table_keys
    logger.dump_tabular(with_prefix=False, with_timestamp=False)
    logger.log("Eval Time: {0}".format(time.time() - start_time))
def simulate_policy(args):
    if args.pause:
        import ipdb; ipdb.set_trace()
    data = pickle.load(open(args.file, "rb"))  # joblib.load(args.file)
    if 'policy' in data:
        policy = data['policy']
    elif 'evaluation/policy' in data:
        policy = data['evaluation/policy']
    if 'env' in data:
        env = data['env']
    elif 'evaluation/env' in data:
        env = data['evaluation/env']
    if isinstance(env, RemoteRolloutEnv):
        env = env._wrapped_env
    print("Policy loaded")

    if args.gpu:
        ptu.set_gpu_mode(True)
        policy.to(ptu.device)
    else:
        ptu.set_gpu_mode(False)
        policy.to(ptu.device)

    if isinstance(env, VAEWrappedEnv):
        env.mode(args.mode)
    if args.enable_render or hasattr(env, 'enable_render'):
        # some environments need to be reconfigured for visualization
        env.enable_render()
    if args.multitaskpause:
        env.pause_on_goal = True
    if isinstance(policy, PyTorchModule):
        policy.train(False)

    paths = []
    while True:
        paths.append(multitask_rollout(
            env,
            policy,
            max_path_length=args.H,
            render=not args.hide,
            observation_key=data.get('evaluation/observation_key', 'observation'),
            desired_goal_key=data.get('evaluation/desired_goal_key', 'desired_goal'),
        ))
        if hasattr(env, "log_diagnostics"):
            env.log_diagnostics(paths)
        if hasattr(env, "get_diagnostics"):
            for k, v in env.get_diagnostics(paths).items():
                logger.record_tabular(k, v)
        logger.dump_tabular()
def simulate_policy(args):
    dir = args.path
    data = joblib.load("{}/params.pkl".format(dir))
    env = data['env']
    model_params = data['model_params']
    mpc_params = data['mpc_params']
    # dyn_model = NNDynamicsModel(env=env, **model_params)
    # mpc_controller = MPCcontroller(env=env,
    #                                dyn_model=dyn_model,
    #                                **mpc_params)
    tf_path_meta = "{}/tf_out-0.meta".format(dir)
    tf_path = "{}/tf_out-0".format(dir)
    with tf.Session() as sess:
        new_saver = tf.train.import_meta_graph(tf_path_meta)
        new_saver.restore(sess, tf_path)
        env = data['env']
        if isinstance(env, RemoteRolloutEnv):
            env = env._wrapped_env
        print("Policy loaded")
        # NOTE: `policy` is never assigned in this snippet; presumably the
        # commented-out MPCcontroller above (or an object restored from the
        # checkpoint) was meant to provide it.
        if args.gpu:
            set_gpu_mode(True)
            policy.to(ptu.device)
        if args.pause:
            import ipdb
            ipdb.set_trace()
        if isinstance(policy, PyTorchModule):
            policy.train(False)
        while True:
            try:
                path = rollout(
                    env,
                    policy,
                    max_path_length=args.H,
                    animated=True,
                )
                env.log_diagnostics([path])
                policy.log_diagnostics([path])
                logger.dump_tabular()
            # Hack for now. Not sure why rollout assumes that close is a
            # keyword argument.
            except TypeError as e:
                if (str(e) != "render() got an unexpected keyword "
                              "argument 'close'"):
                    raise e
def create_policy(variant):
    bottom_snapshot = joblib.load(variant['bottom_path'])
    column_snapshot = joblib.load(variant['column_path'])
    policy = variant['combiner_class'](
        policy1=bottom_snapshot['naf_policy'],
        policy2=column_snapshot['naf_policy'],
    )
    env = bottom_snapshot['env']
    logger.save_itr_params(0, dict(
        policy=policy,
        env=env,
    ))
    path = rollout(
        env,
        policy,
        max_path_length=variant['max_path_length'],
        animated=variant['render'],
    )
    env.log_diagnostics([path])
    logger.dump_tabular()
def simulate_policy(args):
    data = joblib.load(args.file)
    qfs = data['qfs']
    env = data['env']
    print("Data loaded")
    if args.pause:
        import ipdb; ipdb.set_trace()
    for qf in qfs:
        qf.train(False)
    paths = []
    while True:
        paths.append(finite_horizon_rollout(
            env,
            qfs,
            max_path_length=args.H,
            max_T=args.mt,
        ))
        if hasattr(env, "log_diagnostics"):
            env.log_diagnostics(paths)
        logger.dump_tabular()
def experiment(variant):
    num_rollouts = variant['num_rollouts']
    H = variant['H']
    render = variant['render']
    env = MultitaskPoint2DEnv()
    qf = PerfectPoint2DQF()
    policy = variant['policy_class'](qf, env, **variant['policy_params'])
    paths = []
    for _ in range(num_rollouts):
        goal = env.sample_goal_state_for_rollout()
        path = multitask_rollout(
            env,
            policy,
            goal,
            discount=0,
            max_path_length=H,
            animated=render,
        )
        paths.append(path)
    env.log_diagnostics(paths)
    logger.dump_tabular(with_timestamp=False)
def evaluate(self, epoch):
    """
    Perform evaluation for this algorithm.

    :param epoch: The epoch number.
    """
    statistics = OrderedDict()

    train_batch = self.get_batch(training=True)
    statistics.update(self._statistics_from_batch(train_batch, "Train"))

    validation_batch = self.get_batch(training=False)
    statistics.update(
        self._statistics_from_batch(validation_batch, "Validation")
    )

    statistics['QF Loss Validation - Train Gap'] = (
        statistics['Validation QF Loss Mean']
        - statistics['Train QF Loss Mean']
    )
    statistics['Epoch'] = epoch

    for key, value in statistics.items():
        logger.record_tabular(key, value)

    logger.dump_tabular(with_prefix=False, with_timestamp=False)
def experiment(variant):
    path = variant['path']
    policy_class = variant['policy_class']
    policy_params = variant['policy_params']
    horizon = variant['horizon']
    num_rollouts = variant['num_rollouts']
    discount = variant['discount']
    stat_name = variant['stat_name']

    data = joblib.load(path)
    env = data['env']
    qf = data['qf']
    qf_argmax_policy = data['policy']
    policy = policy_class(
        qf,
        env,
        qf_argmax_policy,
        **policy_params
    )
    paths = []
    for _ in range(num_rollouts):
        goal = env.sample_goal_for_rollout()
        path = multitask_rollout(
            env,
            policy,
            goal,
            discount,
            max_path_length=horizon,
            animated=False,
            decrement_discount=False,
        )
        paths.append(path)
    env.log_diagnostics(paths)
    results = logger.get_table_dict()
    logger.dump_tabular()
    return results[stat_name]
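# Because this experiment returns a single statistic pulled from the logged
# table, it can double as an objective for a simple parameter sweep. A hedged
# sketch only: the variant keys mirror what experiment reads above, and the
# sweep values and helper name are illustrative.
def sweep_horizons(base_variant, horizons):
    scores = {}
    for horizon in horizons:
        variant = dict(base_variant, horizon=horizon)
        scores[horizon] = experiment(variant)
    return scores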
def simulate_policy(args):
    data = joblib.load(args.file)
    policy = data['mpc_controller']
    env = data['env']
    print("Policy loaded")
    if args.pause:
        import ipdb
        ipdb.set_trace()
    policy.cost_fn = env.cost_fn
    policy.env = env
    if args.T:
        policy.mpc_horizon = args.T
    paths = []
    while True:
        paths.append(rollout(
            env,
            policy,
            max_path_length=args.H,
            animated=True,
        ))
        if hasattr(env, "log_diagnostics"):
            env.log_diagnostics(paths)
        logger.dump_tabular()
def experiment(variant):
    num_rollouts = variant['num_rollouts']
    data = joblib.load(variant['qf_path'])
    qf = data['qf']
    env = data['env']
    qf_policy = data['policy']
    if ptu.gpu_enabled():
        qf.to(ptu.device)
        qf_policy.to(ptu.device)
    if isinstance(qf, VectorizedGoalStructuredUniversalQfunction):
        policy = UnconstrainedOcWithImplicitModel(
            qf,
            env,
            qf_policy,
            **variant['policy_params']
        )
    else:
        policy = UnconstrainedOcWithGoalConditionedModel(
            qf,
            env,
            qf_policy,
            **variant['policy_params']
        )
    paths = []
    for _ in range(num_rollouts):
        goal = env.sample_goal_for_rollout()
        print("goal", goal)
        path = multitask_rollout(
            env,
            policy,
            goal,
            **variant['rollout_params']
        )
        paths.append(path)
    env.log_diagnostics(paths)
    logger.dump_tabular(with_timestamp=False)
def train_rfeatures_model(variant, return_data=False):
    from railrl.misc.ml_util import PiecewiseLinearSchedule
    # from railrl.torch.vae.conv_vae import (
    #     ConvVAE, ConvResnetVAE
    # )
    import railrl.torch.vae.conv_vae as conv_vae
    # from railrl.torch.vae.vae_trainer import ConvVAETrainer
    from railrl.launchers.experiments.ashvin.rfeatures.rfeatures_model import (
        TimestepPredictionModel
    )
    from railrl.launchers.experiments.ashvin.rfeatures.rfeatures_trainer import (
        TimePredictionTrainer
    )
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu
    from railrl.pythonplusplus import identity
    import torch

    output_classes = variant["output_classes"]
    representation_size = variant["representation_size"]
    batch_size = variant["batch_size"]
    variant['dataset_kwargs']["output_classes"] = output_classes
    train_dataset, test_dataset, info = get_data(variant['dataset_kwargs'])

    num_train_workers = variant.get("num_train_workers", 0)  # 0 uses main process (good for pdb)
    train_dataset_loader = InfiniteBatchLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=num_train_workers,
    )
    test_dataset_loader = InfiniteBatchLoader(
        test_dataset,
        batch_size=batch_size,
    )

    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if variant.get('decoder_activation', None) == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity
    architecture = variant['model_kwargs'].get('architecture', None)
    if not architecture and variant.get('imsize') == 84:
        architecture = conv_vae.imsize84_default_architecture
    elif not architecture and variant.get('imsize') == 48:
        architecture = conv_vae.imsize48_default_architecture
    variant['model_kwargs']['architecture'] = architecture

    model_class = variant.get('model_class', TimestepPredictionModel)
    model = model_class(
        representation_size,
        decoder_output_activation=decoder_activation,
        output_classes=output_classes,
        **variant['model_kwargs'],
    )
    # model = torch.nn.DataParallel(model)
    model.to(ptu.device)

    variant['trainer_kwargs']['batch_size'] = batch_size
    trainer_class = variant.get('trainer_class', TimePredictionTrainer)
    trainer = trainer_class(
        model,
        **variant['trainer_kwargs'],
    )
    save_period = variant['save_period']

    trainer.dump_trajectory_rewards(
        "initial",
        dict(train=train_dataset.dataset, test=test_dataset.dataset),
    )

    dump_skew_debug_plots = variant.get('dump_skew_debug_plots', False)
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        trainer.train_epoch(epoch, train_dataset_loader, batches=10)
        trainer.test_epoch(epoch, test_dataset_loader, batches=1)
        if should_save_imgs:
            trainer.dump_reconstructions(epoch)
        trainer.dump_trajectory_rewards(
            epoch,
            dict(train=train_dataset.dataset, test=test_dataset.dataset),
            should_save_imgs,
        )
        stats = trainer.get_diagnostics()
        for k, v in stats.items():
            logger.record_tabular(k, v)
        logger.dump_tabular()
        trainer.end_epoch(epoch)
        if epoch % 50 == 0:
            logger.save_itr_params(epoch, model)
    logger.save_extra_data(model, 'vae.pkl', mode='pickle')
    if return_data:
        return model, train_dataset, test_dataset
    return model
def get_n_train_vae(latent_dim,
                    env,
                    vae_train_epochs,
                    num_image_examples,
                    vae_kwargs,
                    vae_trainer_kwargs,
                    vae_architecture,
                    vae_save_period=10,
                    vae_test_p=.9,
                    decoder_activation='sigmoid',
                    vae_class='VAE',
                    **kwargs):
    env.goal_sampling_mode = 'test'
    image_examples = unnormalize_image(
        env.sample_goals(num_image_examples)['desired_goal'])
    n = int(num_image_examples * vae_test_p)
    train_dataset = ImageObservationDataset(image_examples[:n, :])
    test_dataset = ImageObservationDataset(image_examples[n:, :])

    if decoder_activation == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()

    vae_class = vae_class.lower()
    if vae_class == 'VAE'.lower():
        vae_class = ConvVAE
    elif vae_class == 'SpatialVAE'.lower():
        vae_class = SpatialAutoEncoder
    else:
        raise RuntimeError("Invalid VAE Class: {}".format(vae_class))

    vae = vae_class(latent_dim,
                    architecture=vae_architecture,
                    decoder_output_activation=decoder_activation,
                    **vae_kwargs)
    trainer = ConvVAETrainer(vae, **vae_trainer_kwargs)

    logger.remove_tabular_output('progress.csv',
                                 relative_to_snapshot_dir=True)
    logger.add_tabular_output('vae_progress.csv',
                              relative_to_snapshot_dir=True)
    for epoch in range(vae_train_epochs):
        should_save_imgs = (epoch % vae_save_period == 0)
        trainer.train_epoch(epoch, train_dataset)
        trainer.test_epoch(epoch, test_dataset)
        if should_save_imgs:
            trainer.dump_reconstructions(epoch)
            trainer.dump_samples(epoch)
        stats = trainer.get_diagnostics()
        for k, v in stats.items():
            logger.record_tabular(k, v)
        logger.dump_tabular()
        trainer.end_epoch(epoch)
        if epoch % 50 == 0:
            logger.save_itr_params(epoch, vae)
    logger.save_extra_data(vae, 'vae.pkl', mode='pickle')
    logger.remove_tabular_output('vae_progress.csv',
                                 relative_to_snapshot_dir=True)
    logger.add_tabular_output('progress.csv',
                              relative_to_snapshot_dir=True)
    return vae
        goal = env.sample_goal_for_rollout()
        goal[7:14] = 0
        path = multitask_rollout(
            env,
            original_policy,
            # env.multitask_goal,
            goal,
            init_tau=10,
            max_path_length=args.H,
            animated=not args.hide,
            cycle_tau=True,
            decrement_tau=False,
        )
        if hasattr(env, "log_diagnostics"):
            env.log_diagnostics([path])
        logger.dump_tabular()
    else:
        for weight in [1]:
            for num_simulated_paths in [args.npath]:
                print("")
                print("weight", weight)
                print("num_simulated_paths", num_simulated_paths)
                policy = CollocationMpcController(
                    env,
                    implicit_model,
                    original_policy,
                    num_simulated_paths=num_simulated_paths,
                    feasibility_weight=weight,
                )
                policy.train(False)
                paths = []
def train(
        dataset_generator,
        n_start_samples,
        projection=project_samples_square_np,
        n_samples_to_add_per_epoch=1000,
        n_epochs=100,
        save_period=10,
        append_all_data=True,
        full_variant=None,
        dynamics_noise=0,
        num_bins=5,
        weight_type='sqrt_inv_p',
        **kwargs
):
    report = HTMLReport(
        logger.get_snapshot_dir() + '/report.html',
        images_per_row=3,
    )
    dynamics = Dynamics(projection, dynamics_noise)
    if full_variant:
        report.add_header("Variant")
        report.add_text(
            json.dumps(
                ppp.dict_to_safe_json(full_variant, sort=True),
                indent=2,
            )
        )

    orig_train_data = dataset_generator(n_start_samples)
    train_data = orig_train_data

    heatmap_imgs = []
    sample_imgs = []
    entropies = []
    tvs_to_uniform = []
    """
    p_theta = previous iteration's model
    p_new = this iteration's distribution
    """
    p_theta = Histogram(num_bins, weight_type=weight_type)
    for epoch in range(n_epochs):
        logger.record_tabular('Epoch', epoch)
        logger.record_tabular('Entropy ', p_theta.entropy())
        logger.record_tabular('KL from uniform', p_theta.kl_from_uniform())
        logger.record_tabular('TV to uniform', p_theta.tv_to_uniform())
        entropies.append(p_theta.entropy())
        tvs_to_uniform.append(p_theta.tv_to_uniform())

        samples = p_theta.sample(n_samples_to_add_per_epoch)
        empirical_samples = dynamics(samples)
        if append_all_data:
            train_data = np.vstack((train_data, empirical_samples))
        else:
            train_data = np.vstack((orig_train_data, empirical_samples))

        if epoch == 0 or (epoch + 1) % save_period == 0:
            report.add_text("Epoch {}".format(epoch))
            heatmap_img = visualize_histogram(epoch, p_theta, report)
            sample_img = visualize_samples(
                epoch, train_data, p_theta, report, dynamics)
            heatmap_imgs.append(heatmap_img)
            sample_imgs.append(sample_img)
            report.save()

            from PIL import Image
            Image.fromarray(heatmap_img).save(
                logger.get_snapshot_dir() + '/heatmap{}.png'.format(epoch))
            Image.fromarray(sample_img).save(
                logger.get_snapshot_dir() + '/samples{}.png'.format(epoch))

        weights = p_theta.compute_per_elem_weights(train_data)
        p_new = Histogram(num_bins, weight_type=weight_type)
        p_new.fit(
            train_data,
            weights=weights,
        )
        p_theta = p_new
        logger.dump_tabular()

    plot_curves(
        [
            ("Entropy", entropies),
            ("TVs to Uniform", tvs_to_uniform),
        ],
        report,
    )
    report.add_text("Max entropy: {}".format(p_theta.max_entropy()))
    report.save()

    heatmap_video = np.stack(heatmap_imgs)
    sample_video = np.stack(sample_imgs)
    vwrite(
        logger.get_snapshot_dir() + '/heatmaps.mp4',
        heatmap_video,
    )
    vwrite(
        logger.get_snapshot_dir() + '/samples.mp4',
        sample_video,
    )
    try:
        gif(
            logger.get_snapshot_dir() + '/samples.gif',
            sample_video,
        )
        gif(
            logger.get_snapshot_dir() + '/heatmaps.gif',
            heatmap_video,
        )
        report.add_image(
            logger.get_snapshot_dir() + '/samples.gif',
            "Samples GIF",
            is_url=True,
        )
        report.add_image(
            logger.get_snapshot_dir() + '/heatmaps.gif',
            "Heatmaps GIF",
            is_url=True,
        )
        report.save()
    except ImportError as e:
        print(e)
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('file', type=str,
                        help='path to the snapshot file')
    parser.add_argument('--H', type=int, default=300,
                        help='Max length of rollout')
    parser.add_argument('--nrolls', type=int, default=1,
                        help='Number of rollouts per eval')
    parser.add_argument('--verbose', action='store_true')
    parser.add_argument('--mtau', type=float,
                        help='Max tau value')
    parser.add_argument('--grid', action='store_true')
    parser.add_argument('--gpu', action='store_true')
    parser.add_argument('--load', action='store_true')
    parser.add_argument('--hide', action='store_true')
    parser.add_argument('--pause', action='store_true')
    parser.add_argument('--cycle', help='cycle tau', action='store_true')
    args = parser.parse_args()

    data = joblib.load(args.file)
    env = data['env']
    if 'policy' in data:
        policy = data['policy']
    else:
        policy = data['exploration_policy']
    qf = data['qf']
    policy.train(False)
    qf.train(False)

    if args.pause:
        import ipdb
        ipdb.set_trace()

    if args.gpu:
        ptu.set_gpu_mode(True)
        policy.to(ptu.device)

    if args.mtau is None:
        print("Defaulting max tau to 10.")
        max_tau = 10
    else:
        max_tau = args.mtau

    while True:
        paths = []
        for _ in range(args.nrolls):
            goal = env.sample_goal_for_rollout()
            print("goal", goal)
            env.set_goal(goal)
            policy.set_goal(goal)
            policy.set_tau(max_tau)
            path = rollout(
                env,
                policy,
                qf,
                init_tau=max_tau,
                max_path_length=args.H,
                animated=not args.hide,
                cycle_tau=args.cycle,
            )
            paths.append(path)
        env.log_diagnostics(paths)
        for key, value in get_generic_path_information(paths).items():
            logger.record_tabular(key, value)
        logger.dump_tabular()
def simulate_policy(args):
    data = joblib.load(args.file)
    if 'eval_policy' in data:
        policy = data['eval_policy']
    elif 'policy' in data:
        policy = data['policy']
    elif 'exploration_policy' in data:
        policy = data['exploration_policy']
    else:
        raise Exception("No policy found in loaded dict. Keys: {}".format(
            data.keys()))
    env = data['env']
    env.mode("video_env")
    env.decode_goals = True

    if hasattr(env, 'enable_render'):
        # some environments need to be reconfigured for visualization
        env.enable_render()

    if args.gpu:
        set_gpu_mode(True)
        policy.to(ptu.device)
        if hasattr(env, "vae"):
            env.vae.to(ptu.device)
    else:
        # make sure everything is on the CPU
        set_gpu_mode(False)
        policy.cpu()
        if hasattr(env, "vae"):
            env.vae.cpu()

    if args.pause:
        import ipdb
        ipdb.set_trace()
    if isinstance(policy, PyTorchModule):
        policy.train(False)

    ROWS = 3
    COLUMNS = 6
    dirname = osp.dirname(args.file)
    input_file_name = os.path.splitext(os.path.basename(args.file))[0]
    filename = osp.join(dirname, "video_{}.mp4".format(input_file_name))
    rollout_function = create_rollout_function(
        multitask_rollout,
        observation_key='observation',
        desired_goal_key='desired_goal',
    )
    paths = dump_video(
        env,
        policy,
        filename,
        rollout_function,
        ROWS=ROWS,
        COLUMNS=COLUMNS,
        horizon=args.H,
        dirname_to_save_images=dirname,
        subdirname="rollouts_" + input_file_name,
    )

    if hasattr(env, "log_diagnostics"):
        env.log_diagnostics(paths)
    logger.dump_tabular()