def __getitem__(self, idx):
    """Return one augmented (observation, env) pair as flat normalized arrays.

    The flat index is decoded into (trajectory, step); both the observation
    and the per-trajectory 'env' frame receive the SAME crop and jitter so
    they stay spatially aligned.
    """
    traj_idx, step_idx = divmod(idx, self.traj_length)
    obs_img = Image.fromarray(
        self.data['observations'][traj_idx, step_idx].reshape(48, 48, 3),
        mode='RGB')
    env_img = Image.fromarray(
        self.data['env'][traj_idx].reshape(48, 48, 3), mode='RGB')
    # upsampling gives bad images, so the random-resize scale/ratio ranges
    # are pinned to 1 for now (crop is effectively the identity)
    top, left, height, width = self.crop.get_params(env_img, (1, 1), (1, 1))
    jitter = self.jitter.get_params(0.5, 0.1, 0.1, 0.1)
    obs_img = jitter(
        F.resized_crop(obs_img, top, left, height, width, (48, 48),
                       Image.BICUBIC))
    env_img = jitter(
        F.resized_crop(env_img, top, left, height, width, (48, 48),
                       Image.BICUBIC))
    return {
        'x_t': normalize_image(np.array(obs_img).flatten()).squeeze(),
        'env': normalize_image(np.array(env_img).flatten()).squeeze(),
    }
def random_batch(self, batch_size):
    """Sample a random batch of transitions plus one short episode prefix.

    Args:
        batch_size: number of (x_t, x_next) transition pairs to draw.

    Returns:
        Torch-tensor dict (via np_to_pytorch_batch) with keys
        'x_t', 'x_next', 'env', 'actions', 'episode_obs', 'episode_acts'.
    """
    traj_i = np.random.choice(self.size, batch_size)
    # sample in [0, traj_length - 1) so trans_i + 1 stays in range for x_next
    trans_i = np.random.choice(self.traj_length - 1, batch_size)
    try:
        env = normalize_image(self.data['env'][traj_i, :])
    except (KeyError, IndexError):
        # Older datasets ship without a separate 'env' frame; fall back to
        # the first observation of each trajectory. (Was a bare `except:`,
        # which also swallowed KeyboardInterrupt/SystemExit.)
        env = normalize_image(self.data['observations'][traj_i, 0, :])
    x_t = normalize_image(self.data['observations'][traj_i, trans_i, :])
    x_next = normalize_image(self.data['observations'][traj_i, trans_i + 1, :])
    episode_num = np.random.randint(0, self.size)
    # first 8 observations / 7 actions of one random episode
    episode_obs = normalize_image(
        self.data['observations'][episode_num, :8, :])
    data_dict = {
        'x_t': x_t,
        'x_next': x_next,
        'env': env,
        'actions': self.data['actions'][traj_i, trans_i, :],
        'episode_obs': episode_obs,
        'episode_acts': self.data['actions'][episode_num, :7, :],
    }
    return np_to_pytorch_batch(data_dict)
def add_heatmap_imgs_to_o_dict(env, agent, observation_key, full_o,
                               v_function, vectorized=False):
    """Render per-goal-dimension value heatmaps into ``full_o``.

    Tiles the current observation against the env's goal mesh grid,
    evaluates ``v_function`` on the stacked grid, and stores normalized
    heatmap images under 'v_vals' (non-vectorized only) and
    'v_vals_dim_<k>' keys. Mutates ``full_o`` in place; returns None.
    """
    o = full_o[observation_key]
    goal_grid = env.get_mesh_grid(observation_key)
    o_grid = np.c_[np.tile(o, (len(goal_grid), 1)), goal_grid]
    v_vals, indiv_v_vals = v_function(o_grid)
    v_vals = ptu.get_numpy(v_vals)
    indiv_v_vals = [ptu.get_numpy(indiv_v_val)
                    for indiv_v_val in indiv_v_vals]
    # shared color scale across all per-dimension heatmaps
    vmin = np.array(indiv_v_vals).min()
    vmax = np.array(indiv_v_vals).max()
    # Assuming square observation space, how many points on x axis
    vary_len = int(len(goal_grid)**(1 / 2))
    if not vectorized:
        # widen the color scale to also cover the combined v_vals image
        vmin = min(np.array(indiv_v_vals).min(), v_vals.min())
        # BUG FIX: previously took max(...min(), v_vals.max()), which
        # ignored the per-dimension maxima when computing vmax
        vmax = max(np.array(indiv_v_vals).max(), v_vals.max())
        full_o['v_vals'] = normalize_image(
            env.get_image_plt(v_vals.reshape((vary_len, vary_len)),
                              imsize=env.imsize,
                              vmin=vmin,
                              vmax=vmax))
    for goal_dim in range(len(indiv_v_vals)):
        full_o['v_vals_dim_{}'.format(goal_dim)] = normalize_image(
            env.get_image_plt(indiv_v_vals[goal_dim].reshape(
                (vary_len, vary_len)),
                imsize=env.imsize,
                vmin=vmin,
                vmax=vmax))
def __getitem__(self, idx):
    """Decode flat index into (trajectory, step) and return that transition.

    Returns a dict with the normalized observation 'x_t' and the
    per-trajectory 'env' image.
    """
    traj_idx, step_idx = divmod(idx, self.traj_length)
    env = normalize_image(self.data['env'][traj_idx, :])
    x_t = normalize_image(self.data['observations'][traj_idx, step_idx, :])
    return {'x_t': x_t, 'env': env}
def get_dataset_stats(self, data):
    """Encode a dataset and return (latent means, their mean, their std).

    The per-sample encoder means are returned together with their
    elementwise mean and standard deviation over the dataset axis.
    """
    encoded_input = ptu.from_numpy(normalize_image(data))
    mus, _ = self.model.encode(encoded_input)
    mus = ptu.get_numpy(mus)
    return mus, mus.mean(axis=0), mus.std(axis=0)
def get_batch(self, test_data=False, epoch=None):
    """Draw one batch of images as a torch tensor.

    With parallel dataloading, pulls from the appropriate DataLoader.
    Otherwise samples indices (skew-weighted after start_skew_epoch when
    enabled), normalizes, and optionally mean-shifts / background-subtracts.
    """
    if self.use_parallel_dataloading:
        loader = self.test_dataloader if test_data else self.train_dataloader
        return next(loader).to(ptu.device)

    dataset = self.test_dataset if test_data else self.train_dataset
    # skewing only kicks in strictly after start_skew_epoch
    skew = epoch is not None and self.start_skew_epoch < epoch
    if not test_data and self.skew_dataset and skew:
        weights = self._train_weights / np.sum(self._train_weights)
        ind = np.random.choice(len(weights), self.batch_size, p=weights)
    else:
        ind = np.random.randint(0, len(dataset), self.batch_size)
    batch = normalize_image(dataset[ind, :])
    if self.normalize:
        batch = (batch - self.train_data_mean + 1) / 2
    if self.background_subtract:
        batch = batch - self.train_data_mean
    return ptu.from_numpy(batch)
def vis(args):
    """Show the worst and best VAE reconstructions over a dataset.

    Loads images from ``args.ds`` and a VAE from ``args.file``, scores every
    image by summed squared reconstruction error, then displays the
    NUM_SHOWN highest- and lowest-error pairs with OpenCV windows.
    Blocks on a keypress before closing all windows.
    """
    imgs = np.load(args.ds)
    vae = joblib.load(args.file)

    def _show_pair(rank, i, error):
        # Display original and reconstruction for dataset index i.
        # (Extracted: this body was duplicated verbatim for both loops.)
        image_obs = imgs[i]
        recon, *_ = eval_np(vae, normalize_image(image_obs))
        img = image_obs.reshape(3, 48, 48).transpose()
        rimg = recon.reshape(3, 48, 48).transpose()
        cv2.imshow(
            "image, rank {}, loss {}".format(rank, error), img
        )
        cv2.imshow(
            "recon, rank {}, loss {}".format(rank, error), rimg
        )
        print("rank {}\terror {}".format(rank, error))

    losses = []
    for i, image_obs in enumerate(imgs):
        img = normalize_image(image_obs)
        recon, *_ = eval_np(vae, img)
        error = ((recon - img)**2).sum()
        losses.append((i, error))
    # sort descending by error: front = worst reconstructions
    losses.sort(key=lambda x: -x[1])
    for rank, (i, error) in enumerate(losses[:NUM_SHOWN]):
        _show_pair(rank, i, error)
    for j, (i, error) in enumerate(losses[-NUM_SHOWN:]):
        _show_pair(len(losses) - j - 1, i, error)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
def random_batch(self, batch_size):
    """Sample ``batch_size`` observations as a torch-tensor batch.

    Sampling is without replacement unless the buffer holds fewer than
    ``batch_size`` items, in which case replacement is allowed.
    """
    idx = np.random.choice(self.size, batch_size,
                           replace=(self.size < batch_size))
    obs = self.data[idx, :]
    if self.normalize:
        obs = normalize_image(obs)
    return np_to_pytorch_batch({'observations': obs})
def sample_goals(self, batch_size):
    """Sample latent/image goal pairs from the goal buffer.

    Returns None when the goal buffer is empty; otherwise a dict with
    'latent_desired_goal' and normalized 'image_desired_goal'.
    """
    if self.goal_buffer._size == 0:
        return None
    idxs = self.goal_buffer._sample_indices(batch_size)
    next_obs = self.goal_buffer._next_obs
    return {
        'latent_desired_goal': next_obs['latent_observation'][idxs],
        'image_desired_goal': normalize_image(
            next_obs['image_observation'][idxs]),
    }
def random_batch(self, batch_size):
    """Sample a random batch of observations plus one episode prefix.

    Args:
        batch_size: number of observations to draw.

    Returns:
        Torch-tensor dict (via np_to_pytorch_batch) with keys
        'x_t', 'env', 'episode_obs'.
    """
    traj_i = np.random.choice(self.size, batch_size)
    trans_i = np.random.choice(self.traj_length, batch_size)
    try:
        env = normalize_image(self.data['env'][traj_i, :])
    except (KeyError, IndexError):
        # Older datasets ship without a separate 'env' frame; fall back to
        # the first observation of each trajectory. (Was a bare `except:`,
        # which also swallowed KeyboardInterrupt/SystemExit.)
        env = normalize_image(self.data['observations'][traj_i, 0, :])
    x_t = normalize_image(self.data['observations'][traj_i, trans_i, :])
    episode_num = np.random.randint(0, self.size)
    # first 8 observations of one random episode
    episode_obs = normalize_image(
        self.data['observations'][episode_num, :8, :])
    data_dict = {
        'x_t': x_t,
        'env': env,
        'episode_obs': episode_obs,
    }
    return np_to_pytorch_batch(data_dict)
def log_loss_under_uniform(self, model, data, priority_function_kwargs):
    """Log VAE diagnostics (log-prob, KL, MSE) on uniformly sampled data.

    Iterates ``data`` in minibatches, evaluating log-probability under three
    sampling schemes ('true_prior_sampling', 'biased_sampling',
    'importance_sampling') plus KL divergence and reconstruction MSE, and
    records the means via ``logger``. Mutates ``priority_function_kwargs``
    (its 'sampling_method' entry) as a side effect.
    """
    import torch.nn.functional as F
    log_probs_prior = []
    log_probs_biased = []
    log_probs_importance = []
    kles = []
    mses = []
    for i in range(0, data.shape[0], self.batch_size):
        img = normalize_image(data[i:min(data.shape[0], i + self.batch_size), :])
        torch_img = ptu.from_numpy(img)
        reconstructions, obs_distribution_params, latent_distribution_params = self.model(
            torch_img)

        priority_function_kwargs['sampling_method'] = 'true_prior_sampling'
        log_p, log_q, log_d = compute_log_p_log_q_log_d(
            model, img, **priority_function_kwargs)
        log_prob_prior = log_d.mean()

        priority_function_kwargs['sampling_method'] = 'biased_sampling'
        log_p, log_q, log_d = compute_log_p_log_q_log_d(
            model, img, **priority_function_kwargs)
        log_prob_biased = log_d.mean()

        priority_function_kwargs['sampling_method'] = 'importance_sampling'
        log_p, log_q, log_d = compute_log_p_log_q_log_d(
            model, img, **priority_function_kwargs)
        log_prob_importance = (log_p - log_q + log_d).mean()

        kle = model.kl_divergence(latent_distribution_params)
        # FIX: 'elementwise_mean' is the long-deprecated (and in modern
        # torch, removed) alias of 'mean'; behavior is identical.
        mse = F.mse_loss(torch_img, reconstructions, reduction='mean')
        mses.append(mse.item())
        kles.append(kle.item())
        log_probs_prior.append(log_prob_prior.item())
        log_probs_biased.append(log_prob_biased.item())
        log_probs_importance.append(log_prob_importance.item())

    logger.record_tabular("Uniform Data Log Prob (True Prior)",
                          np.mean(log_probs_prior))
    logger.record_tabular("Uniform Data Log Prob (Biased)",
                          np.mean(log_probs_biased))
    logger.record_tabular("Uniform Data Log Prob (Importance)",
                          np.mean(log_probs_importance))
    logger.record_tabular("Uniform Data KL", np.mean(kles))
    logger.record_tabular("Uniform Data MSE", np.mean(mses))
def add_heatmap_img_to_o_dict(env, agent, observation_key, full_o, v_function):
    """Render a value-function heatmap over the goal grid into full_o.

    Tiles the current observation against every goal-grid point, evaluates
    ``v_function``, and stores the normalized heatmap image under 'v_vals'.
    Mutates ``full_o`` in place; returns None.
    """
    obs = full_o[observation_key]
    goal_grid = env.get_mesh_grid(observation_key)
    stacked = np.c_[np.tile(obs, (len(goal_grid), 1)), goal_grid]
    values = ptu.get_numpy(v_function(stacked))
    # goal grid is assumed square: side length = sqrt(number of points)
    side = int(len(goal_grid) ** 0.5)
    heatmap = env.get_image_plt(
        values.reshape((side, side)),
        imsize=env.imsize,
        vmin=values.min(),
        vmax=values.max(),
    )
    full_o['v_vals'] = normalize_image(heatmap)
def viz_rewards(model, id, savefile=None):
    """Plot latent-space distance to the clip's final frame over time.

    Encodes every frame of clip ``id``, treats the last frame's latent as
    the goal, plots the per-frame L2 distance to it (optionally saved to
    ``savefile``), and returns the distances as a numpy array.
    """
    frames = transform_batch(load_clip(id))
    batch = ptu.from_numpy(normalize_image(frames)).to("cuda")
    latents = ptu.get_numpy(model.encoder(batch).cpu())
    goal = latents[-1, :]
    distances = [np.linalg.norm(frame - goal) for frame in latents]
    plt.figure()
    plt.plot(distances)
    if savefile:
        plt.savefig(savefile)
    return np.array(distances)
def _dump_imgs_and_reconstructions(self, idxs, filename):
    """Save a grid image: originals on top, reconstructions beneath.

    For each training-set index in ``idxs``, the normalized original and
    its model reconstruction are stacked and written to
    ``log_dir/filename`` with one column per index.
    """
    originals, reconstructions = [], []
    img_shape = (self.input_channels, self.imsize, self.imsize)
    for i in idxs:
        raw = ptu.from_numpy(normalize_image(self.train_dataset[i]))
        recon, *_ = self.model(raw.view(1, -1))
        # transpose(1, 2) swaps H/W for display
        originals.append(raw.view(*img_shape).transpose(1, 2))
        reconstructions.append(recon.view(*img_shape).transpose(1, 2))
    grid = torch.stack(originals + reconstructions)
    save_image(
        grid.data,
        osp.join(self.log_dir, filename),
        nrow=len(idxs),
    )
def dump_uniform_imgs_and_reconstructions(self, dataset, epoch):
    """Save originals + reconstructions for 4 random rows of ``dataset``.

    Output goes to ``log_dir/uniform<epoch>.png`` with originals on the
    top row and reconstructions beneath.
    """
    idxs = np.random.choice(range(dataset.shape[0]), 4)
    originals, reconstructions = [], []
    img_shape = (self.input_channels, self.imsize, self.imsize)
    for i in idxs:
        raw = ptu.from_numpy(normalize_image(dataset[i]))
        recon, *_ = self.model(raw.view(1, -1))
        # transpose(1, 2) swaps H/W for display
        originals.append(raw.view(*img_shape).transpose(1, 2))
        reconstructions.append(recon.view(*img_shape).transpose(1, 2))
    grid = torch.stack(originals + reconstructions)
    save_image(
        grid.data,
        osp.join(self.log_dir, 'uniform{}.png'.format(epoch)),
        nrow=4,
    )
def __getitem__(self, idxs):
    """Fetch the requested rows, optionally normalized, as float32."""
    batch = self.dataset[idxs, :]
    if self.should_normalize:
        batch = normalize_image(batch)
    return np.float32(batch)
def _kl_np_to_np(self, np_imgs):
    """Per-image KL divergence of the encoder posterior from N(0, I).

    Encodes the normalized images and returns the analytic Gaussian KL,
    summed over latent dimensions, as a numpy array.
    """
    encoded = ptu.from_numpy(normalize_image(np_imgs))
    mu, log_var = self.model.encode(encoded)
    kl = -torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1)
    return ptu.get_numpy(kl)
def get_clip_as_batch(id, max_frames=-1):
    """Load clip ``id``, transform it, and return a normalized torch batch.

    ``max_frames`` caps the number of frames loaded (-1 means no cap,
    per load_clip's convention — TODO confirm against load_clip).
    """
    frames = transform_batch(load_clip(id, max_frames))
    return ptu.from_numpy(normalize_image(frames))
def __init__(
        self,
        model,
        batch_size=128,
        log_interval=0,
        beta=0.5,
        beta_schedule=None,
        lr=None,
        do_scatterplot=False,
        normalize=False,
        mse_weight=0.1,
        is_auto_encoder=False,
        background_subtract=False,
        linearity_weight=0.0,
        distance_weight=0.0,
        loss_weights=None,
        use_linear_dynamics=False,
        use_parallel_dataloading=False,
        train_data_workers=2,
        skew_dataset=False,
        skew_config=None,
        priority_function_kwargs=None,
        start_skew_epoch=0,
        weight_decay=0,
        key_to_reconstruct="observations",
):
    """Set up a VAE trainer: optimizer, beta schedule, and (optionally)
    parallel dataloaders and dataset-skewing configuration.

    ``is_auto_encoder`` forces beta to 0 and defaults lr to 1e-2 (else 1e-3).
    ``use_parallel_dataloading`` is currently disallowed by the assert below.
    """
    #TODO:steven fix pickling
    assert not use_parallel_dataloading, "Have to fix pickling the dataloaders first"
    if skew_config is None:
        skew_config = {}
    self.log_interval = log_interval
    self.batch_size = batch_size
    self.beta = beta
    if is_auto_encoder:
        # a plain autoencoder has no KL term
        self.beta = 0
    if lr is None:
        if is_auto_encoder:
            lr = 1e-2
        else:
            lr = 1e-3
    self.beta_schedule = beta_schedule
    if self.beta_schedule is None or is_auto_encoder:
        # fall back to a constant beta when no schedule was given (or beta is 0)
        self.beta_schedule = ConstantSchedule(self.beta)
    self.imsize = model.imsize
    self.do_scatterplot = do_scatterplot

    model.to(ptu.device)

    self.model = model
    self.representation_size = model.representation_size
    self.input_channels = model.input_channels
    self.imlength = model.imlength
    self.lr = lr
    params = list(self.model.parameters())
    self.optimizer = optim.Adam(
        params,
        lr=self.lr,
        weight_decay=weight_decay,
    )
    self.key_to_reconstruct = key_to_reconstruct
    # NOTE(review): self.batch_size was already assigned above — this
    # re-assignment is redundant.
    self.batch_size = batch_size
    self.use_parallel_dataloading = use_parallel_dataloading
    self.train_data_workers = train_data_workers
    self.skew_dataset = skew_dataset
    self.skew_config = skew_config
    self.start_skew_epoch = start_skew_epoch
    if priority_function_kwargs is None:
        self.priority_function_kwargs = dict()
    else:
        self.priority_function_kwargs = priority_function_kwargs

    # NOTE(review): this branch is unreachable while the assert above holds.
    # It references `train_dataset`/`test_dataset`, which are NOT parameters
    # of this __init__ — running it would raise NameError. `base_sampler` is
    # also computed but never used (the loaders build their own samplers).
    if use_parallel_dataloading:
        self.train_dataset_pt = ImageDataset(train_dataset,
                                             should_normalize=True)
        self.test_dataset_pt = ImageDataset(test_dataset,
                                            should_normalize=True)

        if self.skew_dataset:
            base_sampler = InfiniteWeightedRandomSampler(
                self.train_dataset, self._train_weights)
        else:
            base_sampler = InfiniteRandomSampler(self.train_dataset)
        self.train_dataloader = DataLoader(
            self.train_dataset_pt,
            sampler=InfiniteRandomSampler(self.train_dataset),
            batch_size=batch_size,
            drop_last=False,
            num_workers=train_data_workers,
            pin_memory=True,
        )
        self.test_dataloader = DataLoader(
            self.test_dataset_pt,
            sampler=InfiniteRandomSampler(self.test_dataset),
            batch_size=batch_size,
            drop_last=False,
            num_workers=0,
            pin_memory=True,
        )
        self.train_dataloader = iter(self.train_dataloader)
        self.test_dataloader = iter(self.test_dataloader)

    self.normalize = normalize
    self.mse_weight = mse_weight
    self.background_subtract = background_subtract

    if self.normalize or self.background_subtract:
        # NOTE(review): self.train_dataset is not set anywhere in this
        # __init__ — presumably assigned elsewhere before this is reached;
        # verify against callers.
        self.train_data_mean = np.mean(self.train_dataset, axis=0)
        self.train_data_mean = normalize_image(
            np.uint8(self.train_data_mean))
    self.linearity_weight = linearity_weight
    self.distance_weight = distance_weight
    self.loss_weights = loss_weights
    self.use_linear_dynamics = use_linear_dynamics
    self._extra_stats_to_log = None

    # stateful tracking variables, reset every epoch
    self.eval_statistics = collections.defaultdict(list)
    self.eval_data = collections.defaultdict(list)
def _reconstruction_squared_error_np_to_np(self, np_imgs):
    """Per-image summed squared reconstruction error as a numpy array."""
    inputs = ptu.from_numpy(normalize_image(np_imgs))
    recons, *_ = self.model(inputs)
    diff = inputs - recons
    return ptu.get_numpy(diff.pow(2).sum(dim=1))