def log_diagnostics(self, paths, **kwargs):
    list_of_rewards, terminals, obs, actions, next_obs = split_paths(paths)

    returns = []
    for rewards in list_of_rewards:
        returns.append(np.sum(rewards))
    statistics = OrderedDict()
    statistics.update(create_stats_ordered_dict(
        'Undiscounted Returns', returns,
    ))
    statistics.update(create_stats_ordered_dict(
        'Rewards', list_of_rewards,
    ))
    statistics.update(create_stats_ordered_dict(
        'Actions', actions,
    ))
    fraction_of_time_on_platform = [o[1] for o in obs]
    statistics['Fraction of time on platform'] = np.mean(
        fraction_of_time_on_platform)

    for key, value in statistics.items():
        logger.record_tabular(key, value)
    return returns
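# A minimal sketch of the helper these log_diagnostics methods rely on,
# assuming the railrl implementation of create_stats_ordered_dict: given a
# name and an array-like, it returns an OrderedDict of summary statistics,
# e.g.
#
#   create_stats_ordered_dict('Rewards', [1.0, 2.0, 3.0])
#   # -> OrderedDict([('Rewards Mean', 2.0), ('Rewards Std', 0.8165),
#   #                 ('Rewards Max', 3.0), ('Rewards Min', 1.0)])
#
# which is why each (key, value) pair can be handed directly to
# logger.record_tabular.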
def evaluate(self, epoch, eval_paths=None):
    statistics = OrderedDict()
    statistics.update(self.eval_statistics)

    logger.log("Collecting samples for evaluation")
    if eval_paths:
        test_paths = eval_paths
    else:
        test_paths = self.get_eval_paths()
    statistics.update(eval_util.get_generic_path_information(
        test_paths, stat_prefix="Test",
    ))
    # if len(self._exploration_paths) > 0:
    #     statistics.update(eval_util.get_generic_path_information(
    #         self._exploration_paths, stat_prefix="Exploration",
    #     ))

    if hasattr(self.env, "log_diagnostics"):
        self.env.log_diagnostics(test_paths, logger=logger)
    if hasattr(self.env, "get_diagnostics"):
        statistics.update(self.env.get_diagnostics(test_paths))

    average_returns = eval_util.get_average_returns(test_paths)
    statistics['AverageReturn'] = average_returns
    for key, value in statistics.items():
        logger.record_tabular(key, value)

    self.need_to_update_eval_statistics = True
def log_diagnostics(self, paths, **kwargs):
    list_of_rewards, terminals, obs, actions, next_obs = split_paths(paths)

    returns = []
    for rewards in list_of_rewards:
        returns.append(np.sum(rewards))
    last_statistics = OrderedDict()
    last_statistics.update(create_stats_ordered_dict(
        'UndiscountedReturns', returns,
    ))
    last_statistics.update(create_stats_ordered_dict(
        'Rewards', list_of_rewards,
    ))
    last_statistics.update(create_stats_ordered_dict(
        'Actions', actions,
    ))

    for key, value in last_statistics.items():
        logger.record_tabular(key, value)
    return returns
def log_diagnostics(self, paths):
    final_values = []
    final_unclipped_rewards = []
    final_rewards = []
    for path in paths:
        final_value = path["actions"][-1][0]
        final_values.append(final_value)
        score = path["observations"][0][0] * final_value
        final_unclipped_rewards.append(score)
        final_rewards.append(clip_magnitude(score, 1))

    last_statistics = OrderedDict()
    last_statistics.update(create_stats_ordered_dict(
        'Final Value', final_values,
    ))
    last_statistics.update(create_stats_ordered_dict(
        'Unclipped Final Rewards', final_unclipped_rewards,
    ))
    last_statistics.update(create_stats_ordered_dict(
        'Final Rewards', final_rewards,
    ))

    for key, value in last_statistics.items():
        logger.record_tabular(key, value)
    return final_unclipped_rewards
def evaluate(self, epoch):
    """
    Perform evaluation for this algorithm.

    :param epoch: The epoch number.
    """
    statistics = OrderedDict()

    train_batch = self.get_batch()
    statistics.update(self._statistics_from_batch(train_batch, "Train"))

    logger.log("Collecting samples for evaluation")
    test_paths = self._sample_eval_paths()
    statistics.update(get_generic_path_information(
        test_paths, stat_prefix="Test",
    ))
    statistics.update(self._statistics_from_paths(test_paths, "Test"))
    average_returns = get_average_returns(test_paths)
    statistics['AverageReturn'] = average_returns
    statistics['Epoch'] = epoch

    for key, value in statistics.items():
        logger.record_tabular(key, value)

    self.env.log_diagnostics(test_paths)
    logger.dump_tabular(with_prefix=False, with_timestamp=False)
def train_vae(variant, return_data=False):
    from railrl.misc.ml_util import PiecewiseLinearSchedule
    from railrl.torch.vae.vae_trainer import ConvVAETrainer
    from railrl.core import logger

    beta = variant["beta"]
    use_linear_dynamics = variant.get('use_linear_dynamics', False)
    generate_vae_dataset_fctn = variant.get('generate_vae_data_fctn',
                                            generate_vae_dataset)
    variant['generate_vae_dataset_kwargs']['use_linear_dynamics'] = (
        use_linear_dynamics)
    train_dataset, test_dataset, info = generate_vae_dataset_fctn(
        variant['generate_vae_dataset_kwargs'])
    if use_linear_dynamics:
        action_dim = train_dataset.data['actions'].shape[2]
    else:
        action_dim = 0
    model = get_vae(variant, action_dim)

    logger.save_extra_data(info)
    logger.get_snapshot_dir()

    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None

    vae_trainer_class = variant.get('vae_trainer_class', ConvVAETrainer)
    trainer = vae_trainer_class(model,
                                beta=beta,
                                beta_schedule=beta_schedule,
                                **variant['algo_kwargs'])
    save_period = variant['save_period']

    dump_skew_debug_plots = variant.get('dump_skew_debug_plots', False)
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        trainer.train_epoch(epoch, train_dataset)
        trainer.test_epoch(epoch, test_dataset)

        if should_save_imgs:
            trainer.dump_reconstructions(epoch)
            trainer.dump_samples(epoch)
            if dump_skew_debug_plots:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)

        stats = trainer.get_diagnostics()
        for k, v in stats.items():
            logger.record_tabular(k, v)
        logger.dump_tabular()
        trainer.end_epoch(epoch)

        if epoch % 50 == 0:
            logger.save_itr_params(epoch, model)
    logger.save_extra_data(model, 'vae.pkl', mode='pickle')
    if return_data:
        return model, train_dataset, test_dataset
    return model
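# A minimal `variant` sketch for train_vae, assembled from the keys the
# function reads above; the values are illustrative placeholders, and
# get_vae reads additional model-specific keys not shown here:
#
#   variant = dict(
#       beta=1.0,
#       num_epochs=500,
#       save_period=25,
#       generate_vae_dataset_kwargs=dict(N=10000),
#       algo_kwargs=dict(lr=1e-3),
#   )
#   model = train_vae(variant)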
def run_task(variant):
    from railrl.core import logger
    print(variant)
    logger.log("Hello from script")
    logger.log("variant: " + str(variant))
    logger.record_tabular("value", 1)
    logger.dump_tabular()
    # logger.log takes a single message string, so build it explicitly
    # rather than passing the directory as a second positional argument.
    logger.log("snapshot_dir: " + logger.get_snapshot_dir())
def simulate_policy(args):
    data = pickle.load(open(args.file, "rb"))
    policy_key = args.policy_type + '/policy'
    if policy_key in data:
        policy = data[policy_key]
    else:
        raise Exception("No policy found in loaded dict. Keys: {}".format(
            data.keys()))

    env_key = args.env_type + '/env'
    if env_key in data:
        env = data[env_key]
    else:
        raise Exception("No environment found in loaded dict. Keys: {}".format(
            data.keys()))

    # robosuite env specific things
    env._wrapped_env.has_renderer = True
    env.reset()
    env.viewer.set_camera(camera_id=0)
    if isinstance(env, RemoteRolloutEnv):
        env = env._wrapped_env
    print("Policy loaded")
    if args.enable_render:
        # some environments need to be reconfigured for visualization
        env.enable_render()
    if args.gpu:
        ptu.set_gpu_mode(True)
    if hasattr(policy, "to"):
        policy.to(ptu.device)
    if hasattr(env, "vae"):
        env.vae.to(ptu.device)
    if args.pause:
        import ipdb
        ipdb.set_trace()
    if isinstance(policy, PyTorchModule):
        policy.train(False)
    paths = []
    while True:
        paths.append(deprecated_rollout(
            env,
            policy,
            max_path_length=args.H,
            render=not args.hide,
        ))
        if args.log_diagnostics:
            if hasattr(env, "log_diagnostics"):
                env.log_diagnostics(paths, logger)
            for k, v in eval_util.get_generic_path_information(paths).items():
                logger.record_tabular(k, v)
            logger.dump_tabular()
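# A hedged sketch of the command-line interface this script-style entry
# point assumes; the argument names are inferred from the `args.*` accesses
# above, and the defaults are illustrative, not taken from any repo.
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('file', type=str, help='path to the pickled snapshot')
    parser.add_argument('--policy_type', type=str, default='evaluation')
    parser.add_argument('--env_type', type=str, default='evaluation')
    parser.add_argument('--H', type=int, default=300,
                        help='max rollout length')
    parser.add_argument('--gpu', action='store_true')
    parser.add_argument('--pause', action='store_true',
                        help='drop into ipdb before rolling out')
    parser.add_argument('--hide', action='store_true',
                        help='disable rendering during rollouts')
    parser.add_argument('--enable_render', action='store_true')
    parser.add_argument('--log_diagnostics', action='store_true')
    simulate_policy(parser.parse_args())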
def log_diagnostics(self, paths):
    n_goal = len(self.goal_positions)
    goal_reached = [False] * n_goal
    for path in paths:
        last_obs = path["observations"][-1]
        for i, goal in enumerate(self.goal_positions):
            if np.linalg.norm(last_obs - goal) < self.goal_threshold:
                goal_reached[i] = True
    logger.record_tabular('env:goals_reached', goal_reached.count(True))
def log_diagnostics(self, paths):
    statistics = OrderedDict()
    for stat_name in [
        'arm to object distance',
        'object to goal distance',
        'arm to goal distance',
    ]:
        stat = get_stat_in_paths(paths, 'env_infos', stat_name)
        statistics.update(create_stats_ordered_dict(stat_name, stat))
    for key, value in statistics.items():
        logger.record_tabular(key, value)
def plot_performance(policy, env, nrolls):
    print("max_tau, distance")
    # fixed_goals = [-40, -30, 30, 40]
    fixed_goals = [-5, -3, 3, 5]
    taus = np.arange(10) * 10
    for row, fix_tau in enumerate([True, False]):
        for col, horizon_fixed in enumerate([True, False]):
            plot_num = row + 2 * col + 1
            plt.subplot(2, 2, plot_num)
            for fixed_goal in fixed_goals:
                distances = []
                for max_tau in taus:
                    paths = []
                    for _ in range(nrolls):
                        goal = env.sample_goal_for_rollout()
                        goal[0] = fixed_goal
                        path = multitask_rollout(
                            env,
                            policy,
                            goal,
                            init_tau=max_tau,
                            max_path_length=100 if horizon_fixed else max_tau + 1,
                            animated=False,
                            cycle_tau=True,
                            decrement_tau=not fix_tau,
                        )
                        paths.append(path)
                    env.log_diagnostics(paths)
                    for key, value in get_generic_path_information(
                            paths).items():
                        logger.record_tabular(key, value)
                    distance = float(
                        dict(logger._tabular)['Final Distance to goal Mean'])
                    distances.append(distance)
                plt.plot(taus, distances)
                print("line done")
            plt.legend([str(goal) for goal in fixed_goals])
            if fix_tau:
                plt.xlabel("Tau (Horizon-1)")
            else:
                plt.xlabel("Initial tau (=Horizon-1)")
            plt.ylabel("Final distance to goal")
            plt.title("Fix Tau = {}, Horizon Fixed to 100 = {}".format(
                fix_tau,
                horizon_fixed,
            ))
    # Save before show: with most backends, savefig after show writes an
    # empty figure.
    plt.savefig('results/iclr2018/cheetah-sweep-tau-eval-5-3.jpg')
    plt.show()
def train_epoch(self, epoch, batches=100):
    self.model.train()
    losses = []
    kles = []
    mses = []
    beta = self.beta_schedule.get_value(epoch)
    for batch_idx in range(batches):
        data = self.get_batch()
        obs = data['obs']
        next_obs = data['next_obs']
        actions = data['actions']

        self.optimizer.zero_grad()
        recon_batch, mu, logvar = self.model(next_obs)
        mse = self.logprob(recon_batch, next_obs)
        kle = self.kl_divergence(mu, logvar)
        if self.recon_loss_type == 'mse':
            loss = mse + beta * kle
        elif self.recon_loss_type == 'wse':
            wse = self.logprob(recon_batch, next_obs,
                               unorm_weights=self.recon_weights)
            loss = wse + beta * kle
        else:
            # Guard against a silently undefined loss.
            raise ValueError("Invalid recon_loss_type: {}".format(
                self.recon_loss_type))
        loss.backward()

        losses.append(loss.data[0])
        mses.append(mse.data[0])
        kles.append(kle.data[0])

        self.optimizer.step()

    logger.record_tabular("train/epoch", epoch)
    logger.record_tabular("train/MSE", np.mean(mses))
    logger.record_tabular("train/KL", np.mean(kles))
    logger.record_tabular("train/loss", np.mean(losses))
def evaluate(self, epoch, eval_paths=None):
    statistics = OrderedDict()
    if isinstance(self.epoch_discount_schedule, StatConditionalSchedule):
        table_dict = logger.get_table_dict()
        # rllab converts things to strings for some reason
        value = float(
            table_dict[self.epoch_discount_schedule.statistic_name])
        self.epoch_discount_schedule.update(value)
    if not isinstance(self.epoch_discount_schedule, ConstantSchedule):
        statistics['Discount Factor'] = self.discount

    for key, value in statistics.items():
        logger.record_tabular(key, value)
    super().evaluate(epoch, eval_paths=eval_paths)
def simulate_policy(args):
    if args.pause:
        import ipdb; ipdb.set_trace()
    data = pickle.load(open(args.file, "rb"))  # joblib.load(args.file)
    if 'policy' in data:
        policy = data['policy']
    elif 'evaluation/policy' in data:
        policy = data['evaluation/policy']
    else:
        raise Exception("No policy found in loaded dict. Keys: {}".format(
            data.keys()))
    if 'env' in data:
        env = data['env']
    elif 'evaluation/env' in data:
        env = data['evaluation/env']
    else:
        raise Exception("No environment found in loaded dict. Keys: {}".format(
            data.keys()))
    if isinstance(env, RemoteRolloutEnv):
        env = env._wrapped_env
    print("Policy loaded")
    if args.gpu:
        ptu.set_gpu_mode(True)
    else:
        ptu.set_gpu_mode(False)
    policy.to(ptu.device)
    if isinstance(env, VAEWrappedEnv):
        env.mode(args.mode)
    if args.enable_render or hasattr(env, 'enable_render'):
        # some environments need to be reconfigured for visualization
        env.enable_render()
    if args.multitaskpause:
        env.pause_on_goal = True
    if isinstance(policy, PyTorchModule):
        policy.train(False)
    paths = []
    while True:
        paths.append(multitask_rollout(
            env,
            policy,
            max_path_length=args.H,
            render=not args.hide,
            observation_key=data.get('evaluation/observation_key',
                                     'observation'),
            desired_goal_key=data.get('evaluation/desired_goal_key',
                                      'desired_goal'),
        ))
        if hasattr(env, "log_diagnostics"):
            env.log_diagnostics(paths)
        if hasattr(env, "get_diagnostics"):
            for k, v in env.get_diagnostics(paths).items():
                logger.record_tabular(k, v)
        logger.dump_tabular()
def log_loss_under_uniform(self, model, data, priority_function_kwargs):
    import torch.nn.functional as F
    log_probs_prior = []
    log_probs_biased = []
    log_probs_importance = []
    kles = []
    mses = []
    for i in range(0, data.shape[0], self.batch_size):
        img = normalize_image(
            data[i:min(data.shape[0], i + self.batch_size), :])
        torch_img = ptu.from_numpy(img)
        reconstructions, obs_distribution_params, latent_distribution_params = \
            self.model(torch_img)

        priority_function_kwargs['sampling_method'] = 'true_prior_sampling'
        log_p, log_q, log_d = compute_log_p_log_q_log_d(
            model, img, **priority_function_kwargs)
        log_prob_prior = log_d.mean()

        priority_function_kwargs['sampling_method'] = 'biased_sampling'
        log_p, log_q, log_d = compute_log_p_log_q_log_d(
            model, img, **priority_function_kwargs)
        log_prob_biased = log_d.mean()

        priority_function_kwargs['sampling_method'] = 'importance_sampling'
        log_p, log_q, log_d = compute_log_p_log_q_log_d(
            model, img, **priority_function_kwargs)
        log_prob_importance = (log_p - log_q + log_d).mean()

        kle = model.kl_divergence(latent_distribution_params)
        mse = F.mse_loss(torch_img, reconstructions,
                         reduction='elementwise_mean')
        mses.append(mse.item())
        kles.append(kle.item())
        log_probs_prior.append(log_prob_prior.item())
        log_probs_biased.append(log_prob_biased.item())
        log_probs_importance.append(log_prob_importance.item())

    logger.record_tabular("Uniform Data Log Prob (True Prior)",
                          np.mean(log_probs_prior))
    logger.record_tabular("Uniform Data Log Prob (Biased)",
                          np.mean(log_probs_biased))
    logger.record_tabular("Uniform Data Log Prob (Importance)",
                          np.mean(log_probs_importance))
    logger.record_tabular("Uniform Data KL", np.mean(kles))
    logger.record_tabular("Uniform Data MSE", np.mean(mses))
def log_diagnostics(self, paths):
    Ntraj = len(paths)
    acts = np.array([traj['actions'] for traj in paths])
    obs = np.array([traj['observations'] for traj in paths])

    state_count = np.sum(obs, axis=1)
    states_visited = np.sum(state_count > 0, axis=-1)

    # log states visited
    logger.record_tabular('AvgStatesVisited', np.mean(states_visited))

    # log action block lengths
    traj_idx, _, acts_idx = np.where(acts == 1)
    acts_idx = np.array([acts_idx[traj_idx == i] for i in range(Ntraj)])

    if self.zero_reward:
        task_reward = np.array(
            [traj['env_infos']['task_reward'] for traj in paths])
        logger.record_tabular('ZeroedTaskReward',
                              np.mean(np.sum(task_reward, axis=1)))
def test_epoch(
        self,
        epoch,
):
    self.model.eval()
    val_losses = []
    per_dim_losses = np.zeros((self.num_batches, self.y_train.shape[1]))
    for batch in range(self.num_batches):
        inputs_np, labels_np = self.random_batch(
            self.X_test, self.y_test, batch_size=self.batch_size)
        inputs = ptu.Variable(ptu.from_numpy(inputs_np))
        labels = ptu.Variable(ptu.from_numpy(labels_np))
        outputs = self.model(inputs)
        loss = self.criterion(outputs, labels)
        val_losses.append(loss.data[0])
        per_dim_loss = np.mean(
            np.power(ptu.get_numpy(outputs - labels), 2), axis=0)
        per_dim_losses[batch] = per_dim_loss
    logger.record_tabular("test/epoch", epoch)
    logger.record_tabular("test/loss", np.mean(np.array(val_losses)))
    for i in range(self.y_train.shape[1]):
        logger.record_tabular("test/dim " + str(i) + " loss",
                              np.mean(per_dim_losses[:, i]))
    logger.dump_tabular()
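# For a 3-dimensional regression target, a test epoch above would emit
# tabular keys of the form (illustrative, derived from the record_tabular
# calls):
#
#   test/epoch, test/loss, test/dim 0 loss, test/dim 1 loss, test/dim 2 loss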
def evaluate(self, epoch, eval_paths=None):
    statistics = OrderedDict()
    statistics.update(self.eval_statistics)

    logger.log("Collecting samples for evaluation")
    if eval_paths:
        test_paths = eval_paths
    else:
        test_paths = self.get_eval_paths()
    statistics.update(eval_util.get_generic_path_information(
        test_paths, stat_prefix="Test",
    ))
    if len(self._exploration_paths) > 0:
        statistics.update(eval_util.get_generic_path_information(
            self._exploration_paths, stat_prefix="Exploration",
        ))

    if hasattr(self.env, "log_diagnostics"):
        self.env.log_diagnostics(test_paths, logger=logger)
    if hasattr(self.env, "get_diagnostics"):
        statistics.update(self.env.get_diagnostics(test_paths))
    if hasattr(self.eval_policy, "log_diagnostics"):
        self.eval_policy.log_diagnostics(test_paths, logger=logger)
    if hasattr(self.eval_policy, "get_diagnostics"):
        statistics.update(self.eval_policy.get_diagnostics(test_paths))

    process = psutil.Process(os.getpid())
    statistics['RAM Usage (Mb)'] = int(process.memory_info().rss / 1000000)
    statistics['Exploration Policy Noise'] = self._exploration_policy_noise

    average_returns = eval_util.get_average_returns(test_paths)
    statistics['AverageReturn'] = average_returns
    for key, value in statistics.items():
        logger.record_tabular(key, value)

    self.need_to_update_eval_statistics = True
def test_epoch(self, epoch, save_network=True, batches=100):
    self.model.eval()
    mses = []
    losses = []
    for batch_idx in range(batches):
        data = self.get_batch(train=False)
        z = data["z"]
        z_proj = data['z_proj']
        z_proj_hat = self.model(z)
        mse = self.mse_loss(z_proj_hat, z_proj)
        loss = mse
        mses.append(mse.data[0])
        losses.append(loss.data[0])

    logger.record_tabular("test/epoch", epoch)
    logger.record_tabular("test/MSE", np.mean(mses))
    logger.record_tabular("test/loss", np.mean(losses))
    logger.dump_tabular()
    if save_network:
        logger.save_itr_params(epoch, self.model, prefix='reproj',
                               save_anyway=True)
def train_epoch(self, epoch):
    self.model.train()
    losses = []
    per_dim_losses = np.zeros((self.num_batches, self.y_train.shape[1]))
    for batch in range(self.num_batches):
        inputs_np, labels_np = self.random_batch(
            self.X_train, self.y_train, batch_size=self.batch_size)
        inputs = ptu.Variable(ptu.from_numpy(inputs_np))
        labels = ptu.Variable(ptu.from_numpy(labels_np))
        self.optimizer.zero_grad()
        outputs = self.model(inputs)
        loss = self.criterion(outputs, labels)
        loss.backward()
        self.optimizer.step()
        losses.append(loss.data[0])
        per_dim_loss = np.mean(
            np.power(ptu.get_numpy(outputs - labels), 2), axis=0)
        per_dim_losses[batch] = per_dim_loss
    logger.record_tabular("train/epoch", epoch)
    logger.record_tabular("train/loss", np.mean(np.array(losses)))
    for i in range(self.y_train.shape[1]):
        logger.record_tabular("train/dim " + str(i) + " loss",
                              np.mean(per_dim_losses[:, i]))
def log_diagnostics(self, paths):
    target_onehots = []
    for path in paths:
        first_observation = path["observations"][0][:self.n + 1]
        target_onehots.append(first_observation)

    final_predictions = []     # each element has shape (dim,)
    nonfinal_predictions = []  # each element has shape (seq_length-1, dim)
    for path in paths:
        actions = path["actions"]
        if self._softmax_action:
            actions = softmax(actions, axis=-1)
        final_predictions.append(actions[-1])
        nonfinal_predictions.append(actions[:-1])
    nonfinal_predictions_sequence_dimension_flattened = np.vstack(
        nonfinal_predictions)  # shape = (N, dim)
    nonfinal_prob_zero = [
        softmax[0]
        for softmax in nonfinal_predictions_sequence_dimension_flattened
    ]
    final_probs_correct = []
    for final_prediction, target_onehot in zip(final_predictions,
                                               target_onehots):
        correct_pred_idx = np.argmax(target_onehot)
        final_probs_correct.append(final_prediction[correct_pred_idx])
    final_prob_zero = [softmax[0] for softmax in final_predictions]

    last_statistics = OrderedDict()
    last_statistics.update(create_stats_ordered_dict(
        'Final P(correct)', final_probs_correct))
    last_statistics.update(create_stats_ordered_dict(
        'Non-final P(zero)', nonfinal_prob_zero))
    last_statistics.update(create_stats_ordered_dict(
        'Final P(zero)', final_prob_zero))
    for key, value in last_statistics.items():
        logger.record_tabular(key, value)
    return final_probs_correct
def evaluate(self, epoch, exploration_paths):
    """
    Perform evaluation for this algorithm.

    :param epoch: The epoch number.
    :param exploration_paths: List of dicts, each representing a path.
    """
    logger.log("Collecting samples for evaluation")
    paths = self._sample_eval_paths(epoch)
    statistics = OrderedDict()

    statistics.update(self._statistics_from_paths(paths, "Test"))
    statistics.update(self._get_other_statistics())
    statistics.update(self._statistics_from_paths(exploration_paths,
                                                  "Exploration"))

    statistics['AverageReturn'] = get_average_returns(paths)
    statistics['Epoch'] = epoch

    for key, value in statistics.items():
        logger.record_tabular(key, value)

    self.log_diagnostics(paths)
def evaluate(self, epoch):
    """
    Perform evaluation for this algorithm.

    :param epoch: The epoch number.
    """
    statistics = OrderedDict()

    train_batch = self.get_batch(training=True)
    statistics.update(self._statistics_from_batch(train_batch, "Train"))
    validation_batch = self.get_batch(training=False)
    statistics.update(
        self._statistics_from_batch(validation_batch, "Validation"))

    statistics['QF Loss Validation - Train Gap'] = (
        statistics['Validation QF Loss Mean']
        - statistics['Train QF Loss Mean']
    )
    statistics['Epoch'] = epoch
    for key, value in statistics.items():
        logger.record_tabular(key, value)

    logger.dump_tabular(with_prefix=False, with_timestamp=False)
def train_epoch(self, epoch, batches=100):
    self.model.train()
    mses = []
    losses = []
    for batch_idx in range(batches):
        data = self.get_batch()
        z = data["z"]
        z_proj = data['z_proj']
        self.optimizer.zero_grad()
        z_proj_hat = self.model(z)
        mse = self.mse_loss(z_proj_hat, z_proj)
        loss = mse
        loss.backward()
        mses.append(mse.data[0])
        losses.append(loss.data[0])
        self.optimizer.step()

    logger.record_tabular("train/epoch", epoch)
    logger.record_tabular("train/MSE", np.mean(mses))
    logger.record_tabular("train/loss", np.mean(losses))
def get_n_train_vae(latent_dim, env,
                    vae_train_epochs, num_image_examples,
                    vae_kwargs, vae_trainer_kwargs,
                    vae_architecture, vae_save_period=10,
                    vae_test_p=.9, decoder_activation='sigmoid',
                    vae_class='VAE', **kwargs):
    env.goal_sampling_mode = 'test'
    image_examples = unnormalize_image(
        env.sample_goals(num_image_examples)['desired_goal'])
    n = int(num_image_examples * vae_test_p)
    train_dataset = ImageObservationDataset(image_examples[:n, :])
    test_dataset = ImageObservationDataset(image_examples[n:, :])

    if decoder_activation == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()

    vae_class = vae_class.lower()
    if vae_class == 'vae':
        vae_class = ConvVAE
    elif vae_class == 'spatialvae':
        vae_class = SpatialAutoEncoder
    else:
        raise RuntimeError("Invalid VAE Class: {}".format(vae_class))

    vae = vae_class(latent_dim,
                    architecture=vae_architecture,
                    decoder_output_activation=decoder_activation,
                    **vae_kwargs)

    trainer = ConvVAETrainer(vae, **vae_trainer_kwargs)

    logger.remove_tabular_output('progress.csv',
                                 relative_to_snapshot_dir=True)
    logger.add_tabular_output('vae_progress.csv',
                              relative_to_snapshot_dir=True)
    for epoch in range(vae_train_epochs):
        should_save_imgs = (epoch % vae_save_period == 0)
        trainer.train_epoch(epoch, train_dataset)
        trainer.test_epoch(epoch, test_dataset)
        if should_save_imgs:
            trainer.dump_reconstructions(epoch)
            trainer.dump_samples(epoch)
        stats = trainer.get_diagnostics()
        for k, v in stats.items():
            logger.record_tabular(k, v)
        logger.dump_tabular()
        trainer.end_epoch(epoch)
        if epoch % 50 == 0:
            logger.save_itr_params(epoch, vae)
    logger.save_extra_data(vae, 'vae.pkl', mode='pickle')
    logger.remove_tabular_output('vae_progress.csv',
                                 relative_to_snapshot_dir=True)
    logger.add_tabular_output('progress.csv',
                              relative_to_snapshot_dir=True)
    return vae
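# A hypothetical invocation sketch; the argument values are illustrative
# only, and the architecture dict and env are assumed to come from the
# surrounding experiment config:
#
#   vae = get_n_train_vae(
#       latent_dim=4,
#       env=image_env,
#       vae_train_epochs=100,
#       num_image_examples=10000,
#       vae_kwargs=dict(input_channels=3),
#       vae_trainer_kwargs=dict(lr=1e-3),
#       vae_architecture=imsize48_default_architecture,
#   )
#
# Note that the tabular output is swapped to vae_progress.csv for the
# duration of pretraining and restored afterwards, so VAE statistics do not
# interleave with the RL run's progress.csv.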
def train_rfeatures_model(variant, return_data=False):
    from railrl.misc.ml_util import PiecewiseLinearSchedule
    # from railrl.torch.vae.conv_vae import (
    #     ConvVAE, ConvResnetVAE
    # )
    import railrl.torch.vae.conv_vae as conv_vae
    # from railrl.torch.vae.vae_trainer import ConvVAETrainer
    from railrl.launchers.experiments.ashvin.rfeatures.rfeatures_model import \
        TimestepPredictionModel
    from railrl.launchers.experiments.ashvin.rfeatures.rfeatures_trainer import \
        TimePredictionTrainer
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu
    from railrl.pythonplusplus import identity
    import torch

    output_classes = variant["output_classes"]
    representation_size = variant["representation_size"]
    batch_size = variant["batch_size"]
    variant['dataset_kwargs']["output_classes"] = output_classes
    train_dataset, test_dataset, info = get_data(variant['dataset_kwargs'])

    num_train_workers = variant.get("num_train_workers", 0)  # 0 uses main process (good for pdb)
    train_dataset_loader = InfiniteBatchLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=num_train_workers,
    )
    test_dataset_loader = InfiniteBatchLoader(
        test_dataset,
        batch_size=batch_size,
    )

    logger.save_extra_data(info)
    logger.get_snapshot_dir()

    if variant.get('decoder_activation', None) == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity
    architecture = variant['model_kwargs'].get('architecture', None)
    if not architecture and variant.get('imsize') == 84:
        architecture = conv_vae.imsize84_default_architecture
    elif not architecture and variant.get('imsize') == 48:
        architecture = conv_vae.imsize48_default_architecture
    variant['model_kwargs']['architecture'] = architecture

    model_class = variant.get('model_class', TimestepPredictionModel)
    model = model_class(
        representation_size,
        decoder_output_activation=decoder_activation,
        output_classes=output_classes,
        **variant['model_kwargs'],
    )
    # model = torch.nn.DataParallel(model)
    model.to(ptu.device)

    variant['trainer_kwargs']['batch_size'] = batch_size
    trainer_class = variant.get('trainer_class', TimePredictionTrainer)
    trainer = trainer_class(
        model,
        **variant['trainer_kwargs'],
    )
    save_period = variant['save_period']

    trainer.dump_trajectory_rewards(
        "initial",
        dict(train=train_dataset.dataset, test=test_dataset.dataset),
    )

    dump_skew_debug_plots = variant.get('dump_skew_debug_plots', False)
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        trainer.train_epoch(epoch, train_dataset_loader, batches=10)
        trainer.test_epoch(epoch, test_dataset_loader, batches=1)

        if should_save_imgs:
            trainer.dump_reconstructions(epoch)
        trainer.dump_trajectory_rewards(
            epoch,
            dict(train=train_dataset.dataset, test=test_dataset.dataset),
            should_save_imgs,
        )
        stats = trainer.get_diagnostics()
        for k, v in stats.items():
            logger.record_tabular(k, v)
        logger.dump_tabular()
        trainer.end_epoch(epoch)

        if epoch % 50 == 0:
            logger.save_itr_params(epoch, model)
    logger.save_extra_data(model, 'vae.pkl', mode='pickle')

    if return_data:
        return model, train_dataset, test_dataset
    return model
def log_diagnostics(self, paths):
    reward_fwd = get_stat_in_paths(paths, 'env_infos', 'reward_fwd')
    reward_ctrl = get_stat_in_paths(paths, 'env_infos', 'reward_ctrl')
    logger.record_tabular('AvgRewardDist', np.mean(reward_fwd))
    logger.record_tabular('AvgRewardCtrl', np.mean(reward_ctrl))

    if len(paths) > 0:
        progs = [
            path["observations"][-1][-3] - path["observations"][0][-3]
            for path in paths
        ]
        logger.record_tabular('AverageForwardProgress', np.mean(progs))
        logger.record_tabular('MaxForwardProgress', np.max(progs))
        logger.record_tabular('MinForwardProgress', np.min(progs))
        logger.record_tabular('StdForwardProgress', np.std(progs))
    else:
        logger.record_tabular('AverageForwardProgress', np.nan)
        logger.record_tabular('MaxForwardProgress', np.nan)
        logger.record_tabular('MinForwardProgress', np.nan)
        logger.record_tabular('StdForwardProgress', np.nan)
def test_epoch(self, epoch, save_vae=True, **kwargs):
    self.model.eval()
    losses = []
    kles = []
    zs = []
    recon_logging_dict = {
        'MSE': [],
        'WSE': [],
    }
    for k in self.extra_recon_logging:
        recon_logging_dict[k] = []
    beta = self.beta_schedule.get_value(epoch)
    for batch_idx in range(100):
        data = self.get_batch(train=False)
        obs = data['obs']
        next_obs = data['next_obs']
        actions = data['actions']
        recon_batch, mu, logvar = self.model(next_obs)
        mse = self.logprob(recon_batch, next_obs)
        wse = self.logprob(recon_batch, next_obs,
                           unorm_weights=self.recon_weights)
        for k, idx in self.extra_recon_logging.items():
            recon_loss = self.logprob(recon_batch, next_obs, idx=idx)
            recon_logging_dict[k].append(recon_loss.data[0])
        kle = self.kl_divergence(mu, logvar)
        if self.recon_loss_type == 'mse':
            loss = mse + beta * kle
        elif self.recon_loss_type == 'wse':
            loss = wse + beta * kle
        else:
            # Guard against a silently undefined loss.
            raise ValueError("Invalid recon_loss_type: {}".format(
                self.recon_loss_type))

        z_data = ptu.get_numpy(mu.cpu())
        for i in range(len(z_data)):
            zs.append(z_data[i, :])
        losses.append(loss.data[0])
        recon_logging_dict['WSE'].append(wse.data[0])
        recon_logging_dict['MSE'].append(mse.data[0])
        kles.append(kle.data[0])
    zs = np.array(zs)
    self.model.dist_mu = zs.mean(axis=0)
    self.model.dist_std = zs.std(axis=0)

    for k in recon_logging_dict:
        logger.record_tabular("/".join(["test", k]),
                              np.mean(recon_logging_dict[k]))
    logger.record_tabular("test/KL", np.mean(kles))
    logger.record_tabular("test/loss", np.mean(losses))
    logger.record_tabular("beta", beta)

    process = psutil.Process(os.getpid())
    logger.record_tabular("RAM Usage (Mb)",
                          int(process.memory_info().rss / 1000000))

    num_active_dims = 0
    for std in self.model.dist_std:
        if std > 0.15:
            num_active_dims += 1
    logger.record_tabular("num_active_dims", num_active_dims)

    logger.dump_tabular()
    if save_vae:
        logger.save_itr_params(epoch, self.model, prefix='vae',
                               save_anyway=True)  # slow...
def offline_evaluate(self, epoch):
    for key, value in self.eval_statistics.items():
        logger.record_tabular(key, value)
    self.need_to_update_eval_statistics = True
def evaluate(self, epoch):
    statistics = OrderedDict()
    for key, value in statistics.items():
        logger.record_tabular(key, value)
    super().evaluate(epoch)