def _try_to_eval(self, epoch, eval_paths=None):
    logger.save_extra_data(self.get_extra_data_to_save(epoch))
    if self._can_evaluate():
        self.evaluate(epoch, eval_paths=eval_paths)

        params = self.get_epoch_snapshot(epoch)
        logger.save_itr_params(epoch, params)
        table_keys = logger.get_table_key_set()
        if self._old_table_keys is not None:
            assert table_keys == self._old_table_keys, (
                "Table keys cannot change from iteration to iteration."
            )
        self._old_table_keys = table_keys

        logger.record_tabular(
            "Number of train steps total",
            self._n_train_steps_total,
        )
        logger.record_tabular(
            "Number of env steps total",
            self._n_env_steps_total,
        )
        logger.record_tabular(
            "Number of rollouts total",
            self._n_rollouts_total,
        )

        times_itrs = gt.get_times().stamps.itrs
        train_time = times_itrs['train'][-1]
        sample_time = times_itrs['sample'][-1]
        eval_time = times_itrs['eval'][-1] if epoch > 0 else 0
        epoch_time = train_time + sample_time + eval_time
        total_time = gt.get_times().total

        logger.record_tabular('Train Time (s)', train_time)
        logger.record_tabular('(Previous) Eval Time (s)', eval_time)
        logger.record_tabular('Sample Time (s)', sample_time)
        logger.record_tabular('Epoch Time (s)', epoch_time)
        logger.record_tabular('Total Train Time (s)', total_time)

        logger.record_tabular("Epoch", epoch)
        logger.dump_tabular(with_prefix=False, with_timestamp=False)
    else:
        logger.log("Skipping eval for now.")

def simulate_policy(args):
    data = torch.load(str(args.file))
    # data = joblib.load(str(args.file))
    policy = data['evaluation/policy']
    env = NormalizedBoxEnv(HalfCheetahEnv())
    # env = data['evaluation/env']
    print("Policy loaded")
    if args.gpu:
        set_gpu_mode(True)
        policy.cuda()
    while True:
        path = rollout(
            env,
            policy,
            max_path_length=args.H,
            render=True,
        )
        if hasattr(env, "log_diagnostics"):
            env.log_diagnostics([path])
        logger.dump_tabular()

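# Hypothetical usage sketch (not part of the original source): the argparse
# setup that a simulation script like the one above typically assumes. The
# flag names mirror the attributes the function reads (args.file, args.H,
# args.gpu); the defaults are illustrative only.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('file', type=str, help='path to the snapshot file')
    parser.add_argument('--H', type=int, default=300, help='max path length')
    parser.add_argument('--gpu', action='store_true')
    args = parser.parse_args()

    simulate_policy(args)
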
def create_policy(variant):
    bottom_snapshot = joblib.load(variant['bottom_path'])
    column_snapshot = joblib.load(variant['column_path'])
    policy = variant['combiner_class'](
        policy1=bottom_snapshot['naf_policy'],
        policy2=column_snapshot['naf_policy'],
    )
    env = bottom_snapshot['env']
    logger.save_itr_params(0, dict(
        policy=policy,
        env=env,
    ))
    path = rollout(
        env,
        policy,
        max_path_length=variant['max_path_length'],
        animated=variant['render'],
    )
    env.log_diagnostics([path])
    logger.dump_tabular()

def simulate_policy(args):
    data = joblib.load(args.file)  # the snapshot is unpickled via joblib
    policy = data['policy']
    env = data['env']
    print("Policy loaded")
    if args.gpu:
        set_gpu_mode(True)
        policy.cuda()
    if isinstance(policy, PyTorchModule):
        policy.train(False)
    while True:
        path = rollout(
            env,
            policy,
            max_path_length=args.H,
            animated=False,
        )
        if hasattr(env, "log_diagnostics"):
            env.log_diagnostics([path])
        logger.dump_tabular()

def simulate_policy(args): data = torch.load(args.file) policy = data['evaluation/policy'] env = data['evaluation/env'] print("Policy loaded") if args.gpu: set_gpu_mode(True) policy.cuda() while True: path = multitask_rollout( env, policy, max_path_length=args.H, render=True, observation_key='observation', desired_goal_key='desired_goal', ) if hasattr(env, "log_diagnostics"): env.log_diagnostics([path]) logger.dump_tabular()
def pretrain_q_with_bc_data(self, batch_size):
    logger.remove_tabular_output('progress.csv', relative_to_snapshot_dir=True)
    logger.add_tabular_output('pretrain_q.csv', relative_to_snapshot_dir=True)

    prev_time = time.time()
    for i in range(self.num_pretrain_steps):
        self.eval_statistics = dict()
        if i % self.pretraining_logging_period == 0:
            self._need_to_update_eval_statistics = True
        train_data = self.replay_buffer.random_batch(self.bc_batch_size)
        train_data = np_to_pytorch_batch(train_data)
        obs = train_data['observations']
        next_obs = train_data['next_observations']
        # goals = train_data['resampled_goals']
        train_data['observations'] = obs  # torch.cat((obs, goals), dim=1)
        train_data['next_observations'] = next_obs  # torch.cat((next_obs, goals), dim=1)
        self.train_from_torch(train_data, pretrain=True)

        if i % self.pretraining_logging_period == 0:
            self.eval_statistics["batch"] = i
            self.eval_statistics["epoch_time"] = time.time() - prev_time
            stats_with_prefix = add_prefix(self.eval_statistics, prefix="trainer/")
            logger.record_dict(stats_with_prefix)
            logger.dump_tabular(with_prefix=True, with_timestamp=False)
            prev_time = time.time()

    logger.remove_tabular_output(
        'pretrain_q.csv',
        relative_to_snapshot_dir=True,
    )
    logger.add_tabular_output(
        'progress.csv',
        relative_to_snapshot_dir=True,
    )

    self._need_to_update_eval_statistics = True
    self.eval_statistics = dict()

def simulate_policy(fpath, env_name, seed, max_path_length, num_eval_steps,
                    headless, max_eps, verbose=True, pause=False):
    data = torch.load(fpath, map_location=ptu.device)
    policy = data['evaluation/policy']
    policy.to(ptu.device)
    # Make a fresh env; reloading data['evaluation/env'] seems to trigger a bug.
    env = gym.make(env_name, **{"headless": headless, "verbose": False})
    env.seed(seed)
    if pause:
        input("Waiting to start.")
    path_collector = MdpPathCollector(env, policy)
    paths = path_collector.collect_new_paths(
        max_path_length,
        num_eval_steps,
        discard_incomplete_paths=True,
    )
    if max_eps:
        paths = paths[:max_eps]
    if verbose:
        completions = sum([
            info["completed"]
            for path in paths
            for info in path["env_infos"]
        ])
        print("Completed {} out of {}".format(completions, len(paths)))
    # plt.plot(paths[0]["actions"])
    # plt.show()
    # plt.plot(paths[2]["observations"])
    # plt.show()
    logger.record_dict(
        eval_util.get_generic_path_information(paths),
        prefix="evaluation/",
    )
    logger.dump_tabular()
    return paths

def evaluate(self):
    eval_statistics = OrderedDict()
    self.mlp.eval()
    self.encoder.eval()
    # for i in range(self.min_context_size, self.max_context_size + 1):
    for i in range(1, 12):
        # prep the batches
        context_batch, mask, obs_task_params, classification_inputs, classification_labels = \
            self._get_eval_batch(self.num_tasks_per_eval, i)

        post_dist = self.encoder(context_batch, mask)
        z = post_dist.sample()  # N_tasks x Dim
        # z = post_dist.mean

        obs_task_params = Variable(ptu.from_numpy(obs_task_params))
        if self.training_regression:
            preds = self.mlp(z)
            loss = self.mse(preds, obs_task_params)
            eval_statistics['Loss for %d' % i] = np.mean(ptu.get_numpy(loss))
        else:
            repeated_z = z.repeat(
                1, self.classification_batch_size_per_task
            ).view(-1, z.size(1))
            mlp_input = torch.cat([classification_inputs, repeated_z], dim=-1)
            preds = self.mlp(mlp_input)
            # loss = self.bce(preds, classification_labels)
            class_preds = (preds > 0).type(preds.data.type())
            accuracy = (class_preds == classification_labels).type(torch.FloatTensor).mean()
            eval_statistics['Acc for %d' % i] = np.mean(ptu.get_numpy(accuracy))

    for key, value in eval_statistics.items():
        logger.record_tabular(key, value)
    logger.dump_tabular(with_prefix=False, with_timestamp=False)

def _try_to_eval(self, epoch):
    if self._can_evaluate():
        # save if it's time to save
        if epoch % self.freq_saving == 0:
            logger.save_extra_data(self.get_extra_data_to_save(epoch))
            params = self.get_epoch_snapshot(epoch)
            logger.save_itr_params(epoch, params)

        self.evaluate(epoch)

        logger.record_tabular(
            "Number of train calls total",
            self._n_train_steps_total,
        )
        logger.record_tabular(
            "Number of env steps total",
            self._n_env_steps_total,
        )
        logger.record_tabular(
            "Number of rollouts total",
            self._n_rollouts_total,
        )

        times_itrs = gt.get_times().stamps.itrs
        train_time = times_itrs['train'][-1]
        sample_time = times_itrs['sample'][-1]
        eval_time = times_itrs['eval'][-1] if epoch > 0 else 0
        epoch_time = train_time + sample_time + eval_time
        total_time = gt.get_times().total

        logger.record_tabular('Train Time (s)', train_time)
        logger.record_tabular('(Previous) Eval Time (s)', eval_time)
        logger.record_tabular('Sample Time (s)', sample_time)
        logger.record_tabular('Epoch Time (s)', epoch_time)
        logger.record_tabular('Total Train Time (s)', total_time)

        logger.record_tabular("Epoch", epoch)
        logger.dump_tabular(with_prefix=False, with_timestamp=False)
    else:
        logger.log("Skipping eval for now.")

def simulate_policy(args):
    data = joblib.load(args.file)
    import ipdb; ipdb.set_trace()  # drop into a debugger to inspect the snapshot
    policy = data['exploration_policy']  # TODO: should this use the evaluation policy instead?
    env = data['env']
    print("Policy loaded")
    if args.gpu:
        set_gpu_mode(True)
        policy.cuda()
    if isinstance(policy, PyTorchModule):
        policy.train(False)
    while True:
        path = rollout(
            env,
            policy,
            max_path_length=args.H,
            animated=True,
        )
        if hasattr(env, "log_diagnostics"):
            env.log_diagnostics([path])
        logger.dump_tabular()

def test_epoch(
        self,
        epoch,
):
    self.model.eval()
    val_losses = []
    per_dim_losses = np.zeros((self.num_batches, self.y_train.shape[1]))
    for batch in range(self.num_batches):
        inputs_np, labels_np = self.random_batch(
            self.X_test, self.y_test, batch_size=self.batch_size)
        inputs = ptu.Variable(ptu.from_numpy(inputs_np))
        labels = ptu.Variable(ptu.from_numpy(labels_np))
        outputs = self.model(inputs)
        loss = self.criterion(outputs, labels)
        val_losses.append(loss.data[0])
        per_dim_loss = np.mean(np.power(ptu.get_numpy(outputs - labels), 2), axis=0)
        per_dim_losses[batch] = per_dim_loss

    logger.record_tabular("test/epoch", epoch)
    logger.record_tabular("test/loss", np.mean(np.array(val_losses)))
    for i in range(self.y_train.shape[1]):
        logger.record_tabular("test/dim " + str(i) + " loss", np.mean(per_dim_losses[:, i]))
    logger.dump_tabular()

def simulate_policy(args): data = torch.load(args.file) policy = data['evaluation/policy'] env = data['evaluation/env'] print("Policy loaded") if args.gpu: set_gpu_mode(True) policy.cuda() num_fail = 0 for _ in range(args.ep): path = rollout( env, policy, max_path_length=args.H, render=False, sleep=args.S, ) if np.any(path['rewards'] == -1): num_fail += 1 if args.de: last_obs = np.moveaxis( np.reshape(path['observations'][-1], (3, 33, 33)), 0, -1) last_next_obs = np.moveaxis( np.reshape(path['next_observations'][-1], (3, 33, 33)), 0, -1) last_obs = (last_obs * 33 + 128).astype(np.uint8) last_next_obs = (last_next_obs * 33 + 128).astype(np.uint8) fig = plt.figure(figsize=(10, 10)) fig.add_subplot(2, 1, 1) plt.imshow(last_obs) fig.add_subplot(2, 1, 2) plt.imshow(last_next_obs) plt.show() plt.close() if hasattr(env, "log_diagnostics"): env.log_diagnostics([path]) logger.dump_tabular() print('number of failures:', num_fail)
def simulate_policy(args):
    data = joblib.load(args.file)
    qfs = data['qfs']
    env = data['env']
    print("Data loaded")
    if args.pause:
        import ipdb
        ipdb.set_trace()
    for qf in qfs:
        qf.train(False)
    paths = []
    while True:
        paths.append(finite_horizon_rollout(
            env,
            qfs,
            max_path_length=args.H,
            max_T=args.mt,
        ))
        if hasattr(env, "log_diagnostics"):
            env.log_diagnostics(paths)
        logger.dump_tabular()

def simulate_policy(args):
    data = joblib.load(args.file)
    policy = data['policy']
    # env = data['env']
    from rlkit.envs.mujoco_manip_env import MujocoManipEnv
    env = MujocoManipEnv("SawyerLiftEnv", render=True)
    print("Policy loaded")
    if args.gpu:
        set_gpu_mode(True)
        policy.cuda()
    if isinstance(policy, PyTorchModule):
        policy.train(False)
    while True:
        path = rollout(
            env,
            policy,
            max_path_length=args.H,
            animated=True,
        )
        if hasattr(env, "log_diagnostics"):
            env.log_diagnostics([path])
        logger.dump_tabular()

def evaluate(self, epoch):
    """
    Perform evaluation for this algorithm.

    :param epoch: The epoch number.
    """
    statistics = OrderedDict()

    train_batch = self.get_batch(training=True)
    statistics.update(self._statistics_from_batch(train_batch, "Train"))
    validation_batch = self.get_batch(training=False)
    statistics.update(
        self._statistics_from_batch(validation_batch, "Validation")
    )

    statistics['QF Loss Validation - Train Gap'] = (
        statistics['Validation QF Loss Mean']
        - statistics['Train QF Loss Mean']
    )
    statistics['Epoch'] = epoch
    for key, value in statistics.items():
        logger.record_tabular(key, value)
    logger.dump_tabular(with_prefix=False, with_timestamp=False)

def simulate_policy(args):
    manager_data = torch.load(args.manager_file)
    worker_data = torch.load(args.worker_file)
    policy = manager_data['evaluation/policy']
    worker = worker_data['evaluation/policy']
    env = NormalizedBoxEnv(gym.make(str(args.env)))
    print("Policy loaded")
    if args.gpu:
        set_gpu_mode(True)
        policy.cuda()

    import cv2
    video = cv2.VideoWriter(
        'ppo_dirichlet_diayn_bipedal_walker_hardcore.avi',
        cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'),
        30,
        (1200, 800),
    )
    index = 0

    path = rollout(
        env,
        policy,
        worker,
        continuous=True,
        max_path_length=args.H,
        render=True,
    )
    if hasattr(env, "log_diagnostics"):
        env.log_diagnostics([path])
    logger.dump_tabular()

    for i, img in enumerate(path['images']):
        print(i)
        video.write(img[:, :, ::-1].astype(np.uint8))
        # cv2.imwrite("frames/ppo_dirichlet_diayn_policy_bipedal_walker_hardcore/%06d.png" % index, img[:, :, ::-1])
        index += 1
    video.release()
    print("wrote video")

def simulate_policy(args):
    if args.pause:
        import ipdb
        ipdb.set_trace()
    data = pickle.load(open(args.file, "rb"))
    policy = data['policy']
    env = data['env']
    print("Policy and environment loaded")
    if args.gpu:
        ptu.set_gpu_mode(True)
        policy.to(ptu.device)
    # if isinstance(env, VAEWrappedEnv):
    #     env.mode(args.mode)
    if args.enable_render or hasattr(env, 'enable_render'):
        # some environments need to be reconfigured for visualization
        env.enable_render()
    policy.train(False)
    paths = []
    for i in range(1000):
        paths.append(multitask_rollout(
            env,
            policy,
            max_path_length=args.H,
            animated=not args.hide,
            observation_key='observation',
            desired_goal_key='context',
        ))
        if hasattr(env, "log_diagnostics"):
            env.log_diagnostics(paths)
        if hasattr(env, "get_diagnostics"):
            for k, v in env.get_diagnostics(paths).items():
                logger.record_tabular(k, v)
        logger.dump_tabular()
    if 'point2d' in type(env.wrapped_env).__name__.lower():
        point2d(paths, args)

def simulate_policy(args):
    data = joblib.load(args.file)
    policy = data['mpc_controller']
    env = data['env']
    print("Policy loaded")
    if args.pause:
        import ipdb
        ipdb.set_trace()
    policy.cost_fn = env.cost_fn
    policy.env = env
    if args.T:
        policy.mpc_horizon = args.T
    paths = []
    while True:
        paths.append(rollout(
            env,
            policy,
            max_path_length=args.H,
            animated=True,
        ))
        if hasattr(env, "log_diagnostics"):
            env.log_diagnostics(paths)
        logger.dump_tabular()

def simulate_policy(args):
    data = torch.load(args.file)
    print(data)
    # policy = data['evaluation/policy']
    # This snapshot does not include an evaluation policy, so fall back to the
    # trainer's policy instead.
    policy = data['trainer/policy']
    env = data['evaluation/env']
    print("Policy loaded")
    if args.gpu:
        set_gpu_mode(True)
        policy.cuda()
    while True:
        path = rollout(
            env,
            policy,
            max_path_length=args.H,
            render=True,
        )
        if hasattr(env, "log_diagnostics"):
            env.log_diagnostics([path])
        logger.dump_tabular()

def simulate_policy(args): data = joblib.load(args.file) policy = data['policy'] env = data['env'] print("Policy loaded") farmer = Farmer([('0.0.0.0', 1)]) env_to_sim = farmer.force_acq_env() if args.gpu: set_gpu_mode(True) policy.cuda() if isinstance(policy, PyTorchModule): policy.train(False) while True: path = rollout( env_to_sim, policy, max_path_length=args.H, animated=False, ) if hasattr(env, "log_diagnostics"): env.log_diagnostics([path]) logger.dump_tabular()
def example(variant):
    import torch
    import rlkit.torch.pytorch_util as ptu

    print("Starting")
    logger.log(torch.__version__)

    date_format = "%m/%d/%Y %H:%M:%S %Z"
    date = datetime.now(tz=pytz.utc)
    logger.log("start")
    logger.log("Current date & time is: {}".format(date.strftime(date_format)))
    logger.log("Cuda available: {}".format(torch.cuda.is_available()))
    if torch.cuda.is_available():
        x = torch.randn(3)
        logger.log(str(x.to(ptu.device)))

    date = date.astimezone(timezone("US/Pacific"))
    logger.log("Local date & time is: {}".format(date.strftime(date_format)))
    for i in range(variant["num_seconds"]):
        logger.log("Tick, {}".format(i))
        time.sleep(1)
    logger.log("end")
    logger.log("Local date & time is: {}".format(date.strftime(date_format)))

    logger.log("start mujoco")
    from gym.envs.mujoco import HalfCheetahEnv
    e = HalfCheetahEnv()
    img = e.sim.render(32, 32)
    logger.log(str(sum(img)))
    logger.log("end mujoco")

    logger.record_tabular("Epoch", 1)
    logger.dump_tabular()
    logger.record_tabular("Epoch", 2)
    logger.dump_tabular()
    logger.record_tabular("Epoch", 3)
    logger.dump_tabular()
    print("Done")

def pretrain_q_with_bc_data(self):
    """Pretrain the Q function, then the policy and Q function together, on
    batches from the replay buffer, logging to pretrain_q.csv."""
    logger.remove_tabular_output('progress.csv', relative_to_snapshot_dir=True)
    logger.add_tabular_output('pretrain_q.csv', relative_to_snapshot_dir=True)

    self.update_policy = False
    # first train only the Q function
    for i in range(self.q_num_pretrain1_steps):
        self.eval_statistics = dict()

        train_data = self.replay_buffer.random_batch(self.bc_batch_size)
        train_data = np_to_pytorch_batch(train_data)
        obs = train_data['observations']
        next_obs = train_data['next_observations']
        # goals = train_data['resampled_goals']
        train_data['observations'] = obs  # torch.cat((obs, goals), dim=1)
        train_data['next_observations'] = next_obs  # torch.cat((next_obs, goals), dim=1)
        self.train_from_torch(train_data, pretrain=True)
        if i % self.pretraining_logging_period == 0:
            stats_with_prefix = add_prefix(self.eval_statistics, prefix="trainer/")
            logger.record_dict(stats_with_prefix)
            logger.dump_tabular(with_prefix=True, with_timestamp=False)

    self.update_policy = True
    # then train policy and Q function together
    prev_time = time.time()
    for i in range(self.q_num_pretrain2_steps):
        self.eval_statistics = dict()
        if i % self.pretraining_logging_period == 0:
            self._need_to_update_eval_statistics = True
        train_data = self.replay_buffer.random_batch(self.bc_batch_size)
        train_data = np_to_pytorch_batch(train_data)
        obs = train_data['observations']
        next_obs = train_data['next_observations']
        # goals = train_data['resampled_goals']
        train_data['observations'] = obs  # torch.cat((obs, goals), dim=1)
        train_data['next_observations'] = next_obs  # torch.cat((next_obs, goals), dim=1)
        self.train_from_torch(train_data, pretrain=True)

        if i % self.pretraining_logging_period == 0:
            self.eval_statistics["batch"] = i
            self.eval_statistics["epoch_time"] = time.time() - prev_time
            stats_with_prefix = add_prefix(self.eval_statistics, prefix="trainer/")
            logger.record_dict(stats_with_prefix)
            logger.dump_tabular(with_prefix=True, with_timestamp=False)
            prev_time = time.time()

    logger.remove_tabular_output(
        'pretrain_q.csv',
        relative_to_snapshot_dir=True,
    )
    logger.add_tabular_output(
        'progress.csv',
        relative_to_snapshot_dir=True,
    )

    self._need_to_update_eval_statistics = True
    self.eval_statistics = dict()

    if self.post_pretrain_hyperparams:
        self.set_algorithm_weights(**self.post_pretrain_hyperparams)

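# Not part of the original source: a minimal sketch of how the repeated
# "redirect progress.csv to a pretraining CSV, then restore it" pattern in the
# pretraining routines above could be factored into a context manager. It only
# uses the add_tabular_output / remove_tabular_output calls already present in
# these snippets; the helper name is hypothetical.
from contextlib import contextmanager


@contextmanager
def tabular_output_redirected_to(filename, default='progress.csv'):
    logger.remove_tabular_output(default, relative_to_snapshot_dir=True)
    logger.add_tabular_output(filename, relative_to_snapshot_dir=True)
    try:
        yield
    finally:
        logger.remove_tabular_output(filename, relative_to_snapshot_dir=True)
        logger.add_tabular_output(default, relative_to_snapshot_dir=True)
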
def pretrain_policy_with_bc(
        self,
        policy,
        train_buffer,
        test_buffer,
        steps,
        label="policy",
):
    """Pretrain a policy with behavior cloning.

    For each step, sample a batch from the train buffer, backpropagate the
    (weighted) BC loss through the policy's optimizer, then evaluate the same
    losses on the test buffer and periodically log the statistics.

    :param policy: Policy to pretrain.
    :param train_buffer: Buffer of demonstrations to train on.
    :param test_buffer: Held-out buffer used only for evaluation.
    :param steps: Number of gradient steps.
    :param label: Suffix used for the log file and saved snapshot.
    """
    logger.remove_tabular_output(
        'progress.csv',
        relative_to_snapshot_dir=True,
    )
    logger.add_tabular_output(
        'pretrain_%s.csv' % label,
        relative_to_snapshot_dir=True,
    )
    optimizer = self.optimizers[policy]
    prev_time = time.time()
    for i in range(steps):
        train_policy_loss, train_logp_loss, train_mse_loss, train_stats = self.run_bc_batch(
            train_buffer, policy)
        train_policy_loss = train_policy_loss * self.bc_weight

        optimizer.zero_grad()
        train_policy_loss.backward()
        optimizer.step()

        test_policy_loss, test_logp_loss, test_mse_loss, test_stats = self.run_bc_batch(
            test_buffer, policy)
        test_policy_loss = test_policy_loss * self.bc_weight

        if i % self.pretraining_logging_period == 0:
            stats = {
                "pretrain_bc/batch": i,
                "pretrain_bc/Train Logprob Loss": ptu.get_numpy(train_logp_loss),
                "pretrain_bc/Test Logprob Loss": ptu.get_numpy(test_logp_loss),
                "pretrain_bc/Train MSE": ptu.get_numpy(train_mse_loss),
                "pretrain_bc/Test MSE": ptu.get_numpy(test_mse_loss),
                "pretrain_bc/train_policy_loss": ptu.get_numpy(train_policy_loss),
                "pretrain_bc/test_policy_loss": ptu.get_numpy(test_policy_loss),
                "pretrain_bc/epoch_time": time.time() - prev_time,
            }
            logger.record_dict(stats)
            logger.dump_tabular(with_prefix=True, with_timestamp=False)
            pickle.dump(
                self.policy,
                open(logger.get_snapshot_dir() + '/bc_%s.pkl' % label, "wb"),
            )
            prev_time = time.time()

    logger.remove_tabular_output(
        'pretrain_%s.csv' % label,
        relative_to_snapshot_dir=True,
    )
    logger.add_tabular_output(
        'progress.csv',
        relative_to_snapshot_dir=True,
    )

    if self.post_bc_pretrain_hyperparams:
        self.set_algorithm_weights(**self.post_bc_pretrain_hyperparams)

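# Hypothetical usage sketch (not in the original source): how a trainer object
# might invoke the BC pretraining helper above. `demo_train_buffer`,
# `demo_test_buffer`, and the step count are illustrative placeholders.
trainer.pretrain_policy_with_bc(
    policy=trainer.policy,
    train_buffer=demo_train_buffer,
    test_buffer=demo_test_buffer,
    steps=10000,
    label="policy",
)
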
def offpolicy_inference():
    import time
    from gym import wrappers

    filename = str(uuid.uuid4())
    gpu = True

    env, _, _ = prepare_env(args.env_name, args.visionmodel_path, **env_kwargs)

    snapshot = torch.load(args.load_name)
    policy = snapshot['evaluation/policy']
    if args.env_name.find('doorenv') > -1:
        policy.knob_noisy = args.knob_noisy
        policy.nn = env._wrapped_env.nn
        policy.visionnet_input = env_kwargs['visionnet_input']

    epi_counter = 1
    dooropen_counter = 0
    total_time = 0
    test_num = 100

    if evaluation:
        render = False
    else:
        render = not args.unity

    start_time = int(time.mktime(time.localtime()))

    if gpu:
        set_gpu_mode(True)

    while True:
        if args.env_name.find('doorenv') > -1:
            path, door_opened, opening_time = rollout(
                env,
                policy,
                max_path_length=512,
                doorenv=True,
                render=render,
                evaluate=True,
            )
            print("done first")
            if hasattr(env, "log_diagnostics"):
                env.log_diagnostics([path])
            logger.dump_tabular()
            if evaluation:
                env, _, _ = prepare_env(args.env_name, args.visionmodel_path, **env_kwargs)
            if door_opened:
                dooropen_counter += 1
                total_time += opening_time
                eval_print(dooropen_counter, epi_counter, start_time, total_time)
        else:
            path = rollout(
                env,
                policy,
                max_path_length=512,
                doorenv=False,
                render=render,
            )
            if hasattr(env, "log_diagnostics"):
                env.log_diagnostics([path])
            logger.dump_tabular()

        if evaluation:
            print("{} ep end >>>>>>>>>>>>>>>>>>>>>>>>".format(epi_counter))
            epi_counter += 1

            if args.env_name.find('door') > -1 and epi_counter > test_num:
                eval_print(dooropen_counter, epi_counter, start_time, total_time)
                break

    else:
        max_tau = args.mtau

    env = data['env']
    policy = data['policy']
    policy.train(False)
    if args.gpu:
        ptu.set_gpu_mode(True)
        policy.cuda()
    while True:
        paths = []
        for _ in range(args.nrolls):
            goal = env.sample_goal_for_rollout()
            path = multitask_rollout(
                env,
                policy,
                init_tau=max_tau,
                goal=goal,
                max_path_length=args.H,
                animated=not args.hide,
                cycle_tau=True,
                decrement_tau=True,
            )
            paths.append(path)
        env.log_diagnostics(paths)
        for key, value in get_generic_path_information(paths).items():
            logger.record_tabular(key, value)
        logger.dump_tabular()

def train_pixelcnn(
        vqvae=None,
        vqvae_path=None,
        num_epochs=100,
        batch_size=32,
        n_layers=15,
        dataset_path=None,
        save=True,
        save_period=10,
        cached_dataset_path=False,
        trainer_kwargs=None,
        model_kwargs=None,
        data_filter_fn=lambda x: x,
        debug=False,
        data_size=float('inf'),
        num_train_batches_per_epoch=None,
        num_test_batches_per_epoch=None,
        train_img_loader=None,
        test_img_loader=None,
):
    trainer_kwargs = {} if trainer_kwargs is None else trainer_kwargs
    model_kwargs = {} if model_kwargs is None else model_kwargs

    # Load VQ-VAE + define args
    if vqvae is None:
        vqvae = load_local_or_remote_file(vqvae_path)
        vqvae.to(ptu.device)
        vqvae.eval()

    root_len = vqvae.root_len
    num_embeddings = vqvae.num_embeddings
    embedding_dim = vqvae.embedding_dim
    cond_size = vqvae.num_embeddings
    imsize = vqvae.imsize
    discrete_size = root_len * root_len
    representation_size = embedding_dim * discrete_size
    input_channels = vqvae.input_channels
    imlength = imsize * imsize * input_channels

    log_dir = logger.get_snapshot_dir()

    # Define data loading info
    new_path = osp.join(log_dir, 'pixelcnn_data.npy')

    def prep_sample_data(cached_path):
        data = load_local_or_remote_file(cached_path).item()
        train_data = data['train']
        test_data = data['test']
        return train_data, test_data

    def encode_dataset(path, object_list):
        data = load_local_or_remote_file(path)
        data = data.item()
        data = data_filter_fn(data)

        all_data = []
        n = min(data["observations"].shape[0], data_size)
        for i in tqdm(range(n)):
            obs = ptu.from_numpy(data["observations"][i] / 255.0)
            latent = vqvae.encode(obs, cont=False)
            all_data.append(latent)

        encodings = ptu.get_numpy(torch.stack(all_data, dim=0))
        return encodings

    if train_img_loader:
        _, test_loader, test_batch_loader = create_conditional_data_loader(
            test_img_loader, 80, vqvae, "test2")  # 80
        _, train_loader, train_batch_loader = create_conditional_data_loader(
            train_img_loader, 2000, vqvae, "train2")  # 2000
    else:
        if cached_dataset_path:
            train_data, test_data = prep_sample_data(cached_dataset_path)
        else:
            train_data = encode_dataset(dataset_path['train'], None)  # object_list)
            test_data = encode_dataset(dataset_path['test'], None)
        dataset = {'train': train_data, 'test': test_data}
        np.save(new_path, dataset)

        _, _, train_loader, test_loader, _ = \
            rlkit.torch.vae.pixelcnn_utils.load_data_and_data_loaders(
                new_path, 'COND_LATENT_BLOCK', batch_size)

    # train_dataset = InfiniteBatchLoader(train_loader)
    # test_dataset = InfiniteBatchLoader(test_loader)

    print("Finished loading data")

    model = GatedPixelCNN(
        num_embeddings,
        root_len ** 2,
        n_classes=representation_size,
        **model_kwargs,
    ).to(ptu.device)
    trainer = PixelCNNTrainer(
        model,
        vqvae,
        batch_size=batch_size,
        **trainer_kwargs,
    )

    print("Starting training")

    BEST_LOSS = 999
    for epoch in range(num_epochs):
        should_save = (epoch % save_period == 0) and (epoch > 0)
        trainer.train_epoch(epoch, train_loader, num_train_batches_per_epoch)
        trainer.test_epoch(epoch, test_loader, num_test_batches_per_epoch)

        test_data = test_batch_loader.random_batch(bz)["x"]
        train_data = train_batch_loader.random_batch(bz)["x"]
        trainer.dump_samples(epoch, test_data, test=True)
        trainer.dump_samples(epoch, train_data, test=False)

        if should_save:
            logger.save_itr_params(epoch, model)

        stats = trainer.get_diagnostics()

        cur_loss = stats["test/loss"]
        if cur_loss < BEST_LOSS:
            BEST_LOSS = cur_loss
            vqvae.set_pixel_cnn(model)
            logger.save_extra_data(vqvae, 'best_vqvae', mode='torch')
        else:
            return vqvae

        for k, v in stats.items():
            logger.record_tabular(k, v)
        logger.dump_tabular()
        trainer.end_epoch(epoch)

    return vqvae

def test_epoch(
        self,
        epoch,
        sample_batch=None,
        key=None,
        save_reconstruction=True,
        save_vae=True,
        from_rl=False,
        save_prefix='r',
        only_train_vae=False,
):
    self.model.eval()
    losses = []
    log_probs = []
    triplet_losses = []
    matching_losses = []
    vae_matching_losses = []
    kles = []
    lstm_kles = []
    ae_losses = []
    contrastive_losses = []
    beta = float(self.beta_schedule.get_value(epoch))
    for batch_idx in range(10):
        if sample_batch is not None:
            data = sample_batch(self.batch_size, key=key)
            next_obs = data['next_obs']
        else:
            next_obs = self.get_batch(epoch=epoch)

        reconstructions, obs_distribution_params, vae_latent_distribution_params, lstm_latent_encodings = \
            self.model(next_obs)
        latent_encodings = lstm_latent_encodings
        vae_mu = vae_latent_distribution_params[0]  # this is the LSTM input
        latent_distribution_params = vae_latent_distribution_params

        triplet_loss = ptu.zeros(1)
        for tri_idx, triplet_type in enumerate(self.triplet_loss_type):
            if triplet_type == 1 and not only_train_vae:
                triplet_loss += self.triplet_loss_coef[tri_idx] * self.triplet_loss(latent_encodings)
            elif triplet_type == 2 and not only_train_vae:
                triplet_loss += self.triplet_loss_coef[tri_idx] * self.triplet_loss_2(next_obs)
            elif triplet_type == 3 and not only_train_vae:
                triplet_loss += self.triplet_loss_coef[tri_idx] * self.triplet_loss_3(next_obs)

        if self.matching_loss_coef > 0 and not only_train_vae:
            matching_loss = self.matching_loss(next_obs)
        else:
            matching_loss = ptu.zeros(1)

        if self.vae_matching_loss_coef > 0:
            matching_loss_vae = self.matching_loss_vae(next_obs)
        else:
            matching_loss_vae = ptu.zeros(1)

        if self.contrastive_loss_coef > 0 and not only_train_vae:
            contrastive_loss = self.contrastive_loss(next_obs)
        else:
            contrastive_loss = ptu.zeros(1)

        log_prob = self.model.logprob(next_obs, obs_distribution_params)
        kle = self.model.kl_divergence(latent_distribution_params)
        lstm_kle = ptu.zeros(1)

        ae_loss = F.mse_loss(
            latent_encodings.view((-1, self.model.representation_size)),
            vae_mu.detach())
        ae_losses.append(ae_loss.item())

        loss = -self.recon_loss_coef * log_prob + beta * kle + \
            self.matching_loss_coef * matching_loss + self.ae_loss_coef * ae_loss + triplet_loss + \
            self.vae_matching_loss_coef * matching_loss_vae + self.contrastive_loss_coef * contrastive_loss

        losses.append(loss.item())
        log_probs.append(log_prob.item())
        triplet_losses.append(triplet_loss.item())
        matching_losses.append(matching_loss.item())
        vae_matching_losses.append(matching_loss_vae.item())
        kles.append(kle.item())
        lstm_kles.append(lstm_kle.item())
        contrastive_losses.append(contrastive_loss.item())

        if batch_idx == 0 and save_reconstruction:
            seq_len, batch_size, feature_size = next_obs.shape
            show_obs = next_obs[0][:8]
            reconstructions = reconstructions.view((seq_len, batch_size, feature_size))[0][:8]
            comparison = torch.cat([
                show_obs.narrow(start=0, length=self.imlength, dim=1)
                    .contiguous()
                    .view(-1, self.input_channels, self.imsize, self.imsize)
                    .transpose(2, 3),
                reconstructions.view(
                    -1,
                    self.input_channels,
                    self.imsize,
                    self.imsize,
                ).transpose(2, 3),
            ])
            save_dir = osp.join(
                logger.get_snapshot_dir(),
                '{}{}.png'.format(save_prefix, epoch))
            save_image(comparison.data.cpu(), save_dir, nrow=8)

    self.eval_statistics['epoch'] = epoch
    self.eval_statistics['test/log prob'] = np.mean(log_probs)
    self.eval_statistics['test/triplet loss'] = np.mean(triplet_losses)
    self.eval_statistics['test/vae matching loss'] = np.mean(vae_matching_losses)
    self.eval_statistics['test/matching loss'] = np.mean(matching_losses)
    self.eval_statistics['test/KL'] = np.mean(kles)
    self.eval_statistics['test/lstm KL'] = np.mean(lstm_kles)
    self.eval_statistics['test/loss'] = np.mean(losses)
    self.eval_statistics['test/contrastive loss'] = np.mean(contrastive_losses)
    self.eval_statistics['beta'] = beta
    self.eval_statistics['test/ae loss'] = np.mean(ae_losses)

    if not from_rl:
        for k, v in self.eval_statistics.items():
            logger.record_tabular(k, v)
        logger.dump_tabular()
        if save_vae:
            logger.save_itr_params(epoch, self.model)

    torch.cuda.empty_cache()

def train(dataset_generator,
          n_start_samples,
          projection=project_samples_square_np,
          n_samples_to_add_per_epoch=1000,
          n_epochs=100,
          z_dim=1,
          hidden_size=32,
          save_period=10,
          append_all_data=True,
          full_variant=None,
          dynamics_noise=0,
          decoder_output_var='learned',
          num_bins=5,
          skew_config=None,
          use_perfect_samples=False,
          use_perfect_density=False,
          vae_reset_period=0,
          vae_kwargs=None,
          use_dataset_generator_first_epoch=True,
          **kwargs):
    """
    Sanitize inputs.
    """
    assert skew_config is not None
    if not (use_perfect_density and use_perfect_samples):
        assert vae_kwargs is not None
    if vae_kwargs is None:
        vae_kwargs = {}

    report = HTMLReport(
        logger.get_snapshot_dir() + '/report.html',
        images_per_row=10,
    )
    dynamics = Dynamics(projection, dynamics_noise)
    if full_variant:
        report.add_header("Variant")
        report.add_text(
            json.dumps(
                ppp.dict_to_safe_json(full_variant, sort=True),
                indent=2,
            ))

    vae, decoder, decoder_opt, encoder, encoder_opt = get_vae(
        decoder_output_var,
        hidden_size,
        z_dim,
        vae_kwargs,
    )
    vae.to(ptu.device)

    epochs = []
    losses = []
    kls = []
    log_probs = []
    hist_heatmap_imgs = []
    vae_heatmap_imgs = []
    sample_imgs = []
    entropies = []
    tvs_to_uniform = []
    entropy_gains_from_reweighting = []
    p_theta = Histogram(num_bins)
    p_new = Histogram(num_bins)

    orig_train_data = dataset_generator(n_start_samples)
    train_data = orig_train_data
    start = time.time()
    for epoch in progressbar(range(n_epochs)):
        p_theta = Histogram(num_bins)
        if epoch == 0 and use_dataset_generator_first_epoch:
            vae_samples = dataset_generator(n_samples_to_add_per_epoch)
        else:
            if use_perfect_samples and epoch != 0:
                # Ideally the VAE = p_new, but in practice, it won't be...
                vae_samples = p_new.sample(n_samples_to_add_per_epoch)
            else:
                vae_samples = vae.sample(n_samples_to_add_per_epoch)
        projected_samples = dynamics(vae_samples)
        if append_all_data:
            train_data = np.vstack((train_data, projected_samples))
        else:
            train_data = np.vstack((orig_train_data, projected_samples))

        p_theta.fit(train_data)
        if use_perfect_density:
            prob = p_theta.compute_density(train_data)
        else:
            prob = vae.compute_density(train_data)
        all_weights = prob_to_weight(prob, skew_config)
        p_new.fit(train_data, weights=all_weights)

        if epoch == 0 or (epoch + 1) % save_period == 0:
            epochs.append(epoch)
            report.add_text("Epoch {}".format(epoch))
            hist_heatmap_img = visualize_histogram(p_theta, skew_config, report)
            vae_heatmap_img = visualize_vae(
                vae,
                skew_config,
                report,
                resolution=num_bins,
            )
            sample_img = visualize_vae_samples(
                epoch,
                train_data,
                vae,
                report,
                dynamics,
            )
            visualize_samples(
                p_theta.sample(n_samples_to_add_per_epoch),
                report,
                title="P Theta/RB Samples",
            )
            visualize_samples(
                p_new.sample(n_samples_to_add_per_epoch),
                report,
                title="P Adjusted Samples",
            )
            hist_heatmap_imgs.append(hist_heatmap_img)
            vae_heatmap_imgs.append(vae_heatmap_img)
            sample_imgs.append(sample_img)
            report.save()
            Image.fromarray(hist_heatmap_img).save(
                logger.get_snapshot_dir() + '/hist_heatmap{}.png'.format(epoch))
            # note: the original saved this to hist_heatmap{}.png as well,
            # overwriting the histogram image; use a separate filename instead
            Image.fromarray(vae_heatmap_img).save(
                logger.get_snapshot_dir() + '/vae_heatmap{}.png'.format(epoch))
            Image.fromarray(sample_img).save(
                logger.get_snapshot_dir() + '/samples{}.png'.format(epoch))

        """
        Train VAE to look like p_new
        """
        if sum(all_weights) == 0:
            all_weights[:] = 1
        if vae_reset_period > 0 and epoch % vae_reset_period == 0:
            vae, decoder, decoder_opt, encoder, encoder_opt = get_vae(
                decoder_output_var,
                hidden_size,
                z_dim,
                vae_kwargs,
            )
            vae.to(ptu.device)
        vae.fit(train_data, weights=all_weights)
        epoch_stats = vae.get_epoch_stats()

        losses.append(np.mean(epoch_stats['losses']))
        kls.append(np.mean(epoch_stats['kls']))
        log_probs.append(np.mean(epoch_stats['log_probs']))
        entropies.append(p_theta.entropy())
        tvs_to_uniform.append(p_theta.tv_to_uniform())
        entropy_gain = p_new.entropy() - p_theta.entropy()
        entropy_gains_from_reweighting.append(entropy_gain)

        for k in sorted(epoch_stats.keys()):
            logger.record_tabular(k, epoch_stats[k])
        logger.record_tabular("Epoch", epoch)
        logger.record_tabular('Entropy ', p_theta.entropy())
        logger.record_tabular('KL from uniform', p_theta.kl_from_uniform())
        logger.record_tabular('TV to uniform', p_theta.tv_to_uniform())
        logger.record_tabular('Entropy gain from reweight', entropy_gain)
        logger.record_tabular('Total Time (s)', time.time() - start)
        logger.dump_tabular()
        logger.save_itr_params(epoch, {
            'vae': vae,
            'train_data': train_data,
            'vae_samples': vae_samples,
            'dynamics': dynamics,
        })

    report.add_header("Training Curves")
    plot_curves(
        [
            ("Training Loss", losses),
            ("KL", kls),
            ("Log Probs", log_probs),
            ("Entropy Gain from Reweighting", entropy_gains_from_reweighting),
        ],
        report,
    )
    plot_curves(
        [
            ("Entropy", entropies),
            ("TV to Uniform", tvs_to_uniform),
        ],
        report,
    )
    report.add_text("Max entropy: {}".format(p_theta.max_entropy()))
    report.save()

    for filename, imgs in [
        ("hist_heatmaps", hist_heatmap_imgs),
        ("vae_heatmaps", vae_heatmap_imgs),
        ("samples", sample_imgs),
    ]:
        video = np.stack(imgs)
        vwrite(
            logger.get_snapshot_dir() + '/{}.mp4'.format(filename),
            video,
        )
        local_gif_file_path = '{}.gif'.format(filename)
        gif_file_path = '{}/{}'.format(logger.get_snapshot_dir(), local_gif_file_path)
        gif(gif_file_path, video)
        report.add_image(local_gif_file_path, txt=filename, is_url=True)

    report.save()

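# Not part of the original source: a minimal sketch of what a prob_to_weight-style
# helper could compute under the usual density-skewing scheme (weights proportional
# to a negative power of the estimated density), assuming skew_config carries the
# exponent under the key 'alpha'. The real prob_to_weight may normalize or clip
# differently.
import numpy as np


def prob_to_weight_sketch(prob, skew_config):
    alpha = skew_config['alpha']      # typically alpha < 0, upweighting rare samples
    weights = prob ** alpha
    return weights / weights.sum()    # normalize so the weights sum to one
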
def _log_stats(self, epoch): logger.log(f"Epoch {epoch} finished", with_timestamp=True) """ Replay Buffer """ logger.record_dict( self.replay_buffer.get_diagnostics(), prefix="replay_buffer/" ) """ Trainer """ logger.record_dict(self.trainer.get_diagnostics(), prefix="trainer/") """ Exploration """ logger.record_dict( self.expl_data_collector.get_diagnostics(), prefix="exploration/" ) expl_paths = self.expl_data_collector.get_epoch_paths() if len(expl_paths) > 0: if hasattr(self.expl_env, "get_diagnostics"): logger.record_dict( self.expl_env.get_diagnostics(expl_paths), prefix="exploration/", ) logger.record_dict( eval_util.get_generic_path_information(expl_paths), prefix="exploration/", ) """ Evaluation """ logger.record_dict( self.eval_data_collector.get_diagnostics(), prefix="evaluation/", ) eval_paths = self.eval_data_collector.get_epoch_paths() if hasattr(self.eval_env, "get_diagnostics"): logger.record_dict( self.eval_env.get_diagnostics(eval_paths), prefix="evaluation/", ) logger.record_dict( eval_util.get_generic_path_information(eval_paths), prefix="evaluation/", ) """ Misc """ gt.stamp("logging") timings = _get_epoch_timings() timings["time/training and exploration (s)"] = self.total_train_expl_time logger.record_dict(timings) logger.record_tabular("Epoch", epoch) logger.dump_tabular(with_prefix=False, with_timestamp=False)
def _log_stats(self, epoch): logger.log("Epoch {} finished".format(epoch), with_timestamp=True) """ Replay Buffer """ logger.record_dict( self.replay_buffer.get_diagnostics(), global_step=epoch, prefix="replay_buffer/", ) """ Trainer """ logger.record_dict(self.trainer.get_diagnostics(), global_step=epoch, prefix="trainer/") """ Exploration """ logger.record_dict( self.expl_data_collector.get_diagnostics(), global_step=epoch, prefix="exploration/", ) expl_paths = self.expl_data_collector.get_epoch_paths() if hasattr(self.expl_env, "get_diagnostics"): logger.record_dict( self.expl_env.get_diagnostics(expl_paths), global_step=epoch, prefix="exploration/", ) logger.record_dict( eval_util.get_generic_path_information(expl_paths), global_step=epoch, prefix="exploration/", ) """ Evaluation """ logger.record_dict( self.eval_data_collector.get_diagnostics(), global_step=epoch, prefix="evaluation/", ) eval_paths = self.eval_data_collector.get_epoch_paths() if hasattr(self.eval_env, "get_diagnostics"): logger.record_dict( self.eval_env.get_diagnostics(eval_paths), global_step=epoch, prefix="evaluation/", ) logger.record_dict( eval_util.get_generic_path_information(eval_paths), global_step=epoch, prefix="evaluation/", ) """ Misc """ gt.stamp("logging") logger.record_dict(_get_epoch_timings(), global_step=epoch) logger.record_tabular("Epoch", epoch) logger.dump_tabular(with_prefix=False, with_timestamp=False)