def _annotate_image(self, image, text, color=(0, 0, 255)):
    """Overlay ``text`` onto a flattened, normalized image.

    Returns the input unchanged when annotated images are disabled;
    otherwise returns a re-flattened, re-normalized annotated copy.
    """
    from multiworld.core.image_env import normalize_image
    from multiworld.core.image_env import unormalize_image

    if self.disable_annotated_images:
        return image

    # Flat normalized vector -> (H, W, C) uint8, then RGB -> BGR for OpenCV.
    img = unormalize_image(image).reshape(
        3, self.imsize, self.imsize).transpose((1, 2, 0))
    img = img[::, :, ::-1].copy()

    # Font scale tuned per resolution; 0.50 is the fallback.
    scale_for_imsize = {84: 0.30, 48: 0.25}
    fontScale = scale_for_imsize.get(self.imsize, 0.50)

    cv2.putText(
        img=img,
        text=text,
        org=(0, self.imsize - 3),  # anchor near the bottom-left corner
        fontFace=0,                # 0 == cv2.FONT_HERSHEY_SIMPLEX
        fontScale=fontScale,
        color=color,
        thickness=1,
    )

    # BGR back to RGB, then re-flatten into the buffer layout callers expect.
    img = img[::, :, ::-1]
    return normalize_image(
        img.transpose((2, 0, 1)).reshape(self.imsize * self.imsize * 3))
def get_batch(self, train=True, epoch=None):
    """Sample one normalized image batch.

    Uses the parallel dataloader when enabled; otherwise draws indices
    uniformly, or skew-weighted once `epoch` passes `start_skew_epoch`.
    """
    if self.use_parallel_dataloading:
        loader = self.train_dataloader if train else self.test_dataloader
        return next(loader).to(ptu.device)

    dataset = self.train_dataset if train else self.test_dataset
    skewing_active = epoch is not None and self.start_skew_epoch < epoch
    if train and self.skew_dataset and skewing_active:
        probs = self._train_weights / np.sum(self._train_weights)
        ind = np.random.choice(len(probs), self.batch_size, p=probs)
    else:
        ind = np.random.randint(0, len(dataset), self.batch_size)

    batch = normalize_image(dataset[ind, :])
    if self.normalize:
        batch = ((batch - self.train_data_mean) + 1) / 2
    if self.background_subtract:
        batch = batch - self.train_data_mean
    return ptu.from_numpy(batch)
def get_batch(self, train=True, epoch=None, sample_factors: bool = False):
    """Sample a normalized image batch, optionally with matching factors.

    Skew-weighted sampling kicks in for training batches once `epoch`
    exceeds `start_skew_epoch`. When `sample_factors` is True, returns a
    (images, factors) pair indexed by the same draw.
    """
    dataset = self.train_dataset if train else self.test_dataset
    factors = self.train_factors if train else self.test_factors

    skewing_active = epoch is not None and (self.start_skew_epoch < epoch)
    if train and self.skew_dataset and skewing_active:
        probs = self._train_weights / np.sum(self._train_weights)
        ind = np.random.choice(len(probs), self.batch_size, p=probs)
    else:
        ind = np.random.randint(0, len(dataset), self.batch_size)

    batch = normalize_image(dataset[ind, :])
    if self.normalize:
        batch = ((batch - self.train_data_mean) + 1) / 2
    if self.background_subtract:
        batch = batch - self.train_data_mean
    batch = ptu.from_numpy(batch)

    if sample_factors:
        return batch, ptu.from_numpy(factors[ind, :])
    return batch
def sample_buffer_goals(self, batch_size, skew=True, key='image_observation_segmented'):
    """
    Samples goals from weighted replay buffer for relabeling or exploration.
    Returns None if replay buffer is empty.

    Example of what might be returned:
    dict(
        image_desired_goals: image_achieved_goals[weighted_indices],
        latent_desired_goals: latent_desired_goals[weighted_indices],
    )
    """
    if self._size == 0:
        return None

    idxs = self.sample_weighted_indices(batch_size, skew=skew)
    # The upstream RLkit version sampled self.decoded_obs_key
    # ('image_observation') here, which is wrong for the segmentation
    # setting: the segmented images stored under `key` are what should
    # serve as the image_desired_goal. Pass a different `key` if the
    # segmentation key ever changes.
    latent_goals = self._next_obs[self.achieved_goal_key][idxs]
    image_goals = normalize_image(self._next_obs[key][idxs])
    return {
        self.decoded_desired_goal_key: image_goals,
        self.desired_goal_key: latent_goals
    }
def get_dataset_stats(self, data):
    """Encode `data` and return (latent means, their mean, their std)."""
    encoded = ptu.from_numpy(normalize_image(data))
    mus, log_vars = self.model.encode(encoded)
    mus = ptu.get_numpy(mus)
    return mus, np.mean(mus, axis=0), np.std(mus, axis=0)
def _compute_train_weights(self):
    """Compute per-example skew weights over the whole training set.

    Processes the dataset in fixed-size chunks to bound memory use.
    Only the "vae_prob" method is implemented.
    """
    method = self.skew_config.get("method", "squared_error")
    power = self.skew_config.get("power", 1)
    batch_size = 512
    size = self.train_dataset.shape[0]
    next_idx = min(batch_size, size)
    cur_idx = 0
    weights = np.zeros(size)
    # Chunked sweep: [cur_idx, next_idx) windows over the dataset.
    while cur_idx < self.train_dataset.shape[0]:
        idxs = np.arange(cur_idx, next_idx)
        data = self.train_dataset[idxs, :]
        if method == "vae_prob":
            data = normalize_image(data)
            weights[idxs] = compute_p_x_np_to_np(
                self.model, data, power=power, **self.priority_function_kwargs
            )
        else:
            raise NotImplementedError("Method {} not supported".format(method))
        cur_idx = next_idx
        next_idx += batch_size
        next_idx = min(next_idx, size)
    if method == "vae_prob":
        # Convert log-probabilities into probabilities relative to the max.
        weights = relative_probs_from_log_probs(weights)
    return weights
def __init__(
        self,
        train_dataset,
        test_dataset,
        model,
        train_seedsteps,
        test_seedsteps,
        train_actions=None,
        test_actions=None,
        batch_size=128,
        log_interval=0,
        gamma=0.5,
        lr=1e-3,
        do_scatterplot=False,
        normalize=False,
        mse_weight=0.1,
        is_auto_encoder=False,
        background_subtract=False,
):
    """Trainer setup: store datasets/hyperparameters and build the optimizer.

    Datasets must be uint8 image arrays; they are normalized lazily per
    batch. (Fix: the original assigned train/test datasets and batch_size
    twice — the redundant second assignments are removed.)
    """
    self.quick_init(locals())
    self.log_interval = log_interval
    self.batch_size = batch_size
    self.beta = model.beta
    self.gamma = gamma
    self.imsize = model.imsize
    self.do_scatterplot = do_scatterplot

    model.to(ptu.device)
    self.model = model
    self.representation_size = model.representation_size
    self.input_channels = model.input_channels
    self.imlength = model.imlength
    self.train_seedsteps = train_seedsteps
    self.test_seedsteps = test_seedsteps

    self.lr = lr
    params = list(self.model.parameters())
    self.optimizer = optim.Adam(params, lr=self.lr)

    self.train_dataset, self.test_dataset = train_dataset, test_dataset
    assert self.train_dataset.dtype == np.uint8
    assert self.test_dataset.dtype == np.uint8
    self.train_actions = train_actions
    self.test_actions = test_actions

    self.normalize = normalize
    self.mse_weight = mse_weight
    self.background_subtract = background_subtract
    if self.normalize or self.background_subtract:
        # Mean image (as normalized floats) used for whitening/subtraction.
        self.train_data_mean = np.mean(self.train_dataset, axis=0)
        self.train_data_mean = normalize_image(
            np.uint8(self.train_data_mean))
    self.vae_logger_stats_for_rl = {}
    self._extra_stats_to_log = None
def random_vae_training_data(self, batch_size, epoch):
    """Sample normalized next-observations for VAE training.

    `epoch` is unused: skewing is controlled by self.skew inside
    sample_weighted_indices.
    """
    idxs = self.sample_weighted_indices(batch_size, )
    imgs = normalize_image(self._next_obs[self.decoded_obs_key][idxs])
    return dict(next_obs=ptu.from_numpy(imgs))
def postprocess_obs_dict(obs_dict):
    """
    Undo internal replay buffer representation changes: save images as bytes
    """
    for key in obs_dict:
        value = obs_dict[key]
        if value is not None and 'image' in key:
            obs_dict[key] = normalize_image(value)
    return obs_dict
def get_batch(self, train=True):
    """Uniformly sample one normalized batch from the train or test set."""
    dataset = self.train_dataset if train else self.test_dataset
    idxs = np.random.randint(0, len(dataset), self.batch_size)
    batch = normalize_image(dataset[idxs, :])
    if self.normalize:
        batch = ((batch - self.train_data_mean) + 1) / 2
    if self.background_subtract:
        batch = batch - self.train_data_mean
    return ptu.from_numpy(batch)
def compute_sampled_latents(vae_env):
    """Populate `vae_env.sampled_latents{,_noisy,_reproj}` for histograms.

    Picks sample images either from a saved VAE dataset or from the
    wrapped env's goal sampler, encodes them in batches, and concatenates
    the resulting latent means / noisy samples / reprojections.
    """
    # A latent dimension counts as "active" if its std exceeds 0.15.
    vae_env.num_active_dims = 0
    for std in vae_env.vae.dist_std:
        if std > 0.15:
            vae_env.num_active_dims += 1
    vae_env.active_dims = vae_env.vae.dist_std.argsort()[-vae_env.num_active_dims:][::-1]
    vae_env.inactive_dims = vae_env.vae.dist_std.argsort()[:-vae_env.num_active_dims][::-1]
    if vae_env.use_vae_dataset and vae_env.vae_dataset_path is not None:
        from multiworld.core.image_env import normalize_image
        from railrl.misc.asset_loader import local_path_from_s3_or_local_path
        filename = local_path_from_s3_or_local_path(vae_env.vae_dataset_path)
        # NOTE(review): np.load(...).item() assumes a pickled dict — old
        # numpy API; newer numpy needs allow_pickle=True. Confirm version.
        dataset = np.load(filename).item()
        vae_env.num_samples_for_latent_histogram = min(
            dataset['next_obs'].shape[0],
            vae_env.num_samples_for_latent_histogram)
        sampled_idx = np.random.choice(
            dataset['next_obs'].shape[0],
            vae_env.num_samples_for_latent_histogram)
        if vae_env.vae_input_key_prefix == 'state':
            vae_dataset_samples = dataset['next_obs'][sampled_idx]
        else:
            vae_dataset_samples = normalize_image(dataset['next_obs'][sampled_idx])
        del dataset  # free the full dataset early
    else:
        vae_dataset_samples = None
    n = vae_env.num_samples_for_latent_histogram
    if vae_dataset_samples is not None:
        imgs = vae_dataset_samples
    else:
        # No dataset: sample goals from the environment instead.
        if vae_env.vae_input_key_prefix == 'state':
            imgs = vae_env.wrapped_env.wrapped_env.sample_goals(n)['state_desired_goal']
        else:
            imgs = vae_env.wrapped_env.sample_goals(n)['image_desired_goal']
    # Encode in chunks of 2500 to bound memory, accumulating results.
    batch_size = 2500
    latents, latents_noisy, latents_reproj = None, None, None
    for i in range(0, n, batch_size):
        batch_latents_mean, batch_latents_logvar = vae_env.encode_imgs(
            imgs[i:i + batch_size], clip_std=False)
        batch_latents_noisy = vae_env.reparameterize(
            batch_latents_mean, batch_latents_logvar, noisy=True)
        if vae_env.use_reprojection_network:
            batch_latents_reproj = ptu.get_numpy(
                vae_env.reproject_encoding(ptu.np_to_var(batch_latents_noisy)))
        if latents is None:
            latents = batch_latents_mean
            latents_noisy = batch_latents_noisy
            if vae_env.use_reprojection_network:
                latents_reproj = batch_latents_reproj
        else:
            latents = np.concatenate((latents, batch_latents_mean), axis=0)
            latents_noisy = np.concatenate((latents_noisy, batch_latents_noisy), axis=0)
            if vae_env.use_reprojection_network:
                latents_reproj = np.concatenate((latents_reproj, batch_latents_reproj), axis=0)
    vae_env.sampled_latents = latents
    vae_env.sampled_latents_noisy = latents_noisy
    vae_env.sampled_latents_reproj = latents_reproj
def log_loss_under_uniform(self, model, data, priority_function_kwargs):
    """Evaluate VAE losses over a uniformly-collected dataset and log them.

    Computes log-probabilities under three sampling schemes plus KL and
    MSE, averaged over mini-batches, and records them via the tabular
    logger. `data` is a uint8 image array; `priority_function_kwargs` is
    mutated in place (its "sampling_method" key is overwritten).

    NOTE(review): the forward pass uses self.model while the log-prob
    helpers use the `model` argument — presumably these are the same
    object; confirm at call sites.
    """
    import torch.nn.functional as F
    log_probs_prior = []
    log_probs_biased = []
    log_probs_importance = []
    kles = []
    mses = []
    for i in range(0, data.shape[0], self.batch_size):
        img = normalize_image(data[i : min(data.shape[0], i + self.batch_size), :])
        torch_img = ptu.from_numpy(img)
        reconstructions, obs_distribution_params, latent_distribution_params = self.model(
            torch_img
        )

        priority_function_kwargs["sampling_method"] = "true_prior_sampling"
        log_p, log_q, log_d = compute_log_p_log_q_log_d(
            model, img, **priority_function_kwargs
        )
        log_prob_prior = log_d.mean()

        priority_function_kwargs["sampling_method"] = "biased_sampling"
        log_p, log_q, log_d = compute_log_p_log_q_log_d(
            model, img, **priority_function_kwargs
        )
        log_prob_biased = log_d.mean()

        priority_function_kwargs["sampling_method"] = "importance_sampling"
        log_p, log_q, log_d = compute_log_p_log_q_log_d(
            model, img, **priority_function_kwargs
        )
        log_prob_importance = (log_p - log_q + log_d).mean()

        kle = model.kl_divergence(latent_distribution_params)
        # BUG FIX: reduction="elementwise_mean" was deprecated in PyTorch
        # 1.0 and later removed; "mean" is the equivalent spelling.
        mse = F.mse_loss(torch_img, reconstructions, reduction="mean")
        mses.append(mse.item())
        kles.append(kle.item())
        log_probs_prior.append(log_prob_prior.item())
        log_probs_biased.append(log_prob_biased.item())
        log_probs_importance.append(log_prob_importance.item())

    logger.record_tabular(
        "Uniform Data Log Prob (True Prior)", np.mean(log_probs_prior)
    )
    logger.record_tabular(
        "Uniform Data Log Prob (Biased)", np.mean(log_probs_biased)
    )
    logger.record_tabular(
        "Uniform Data Log Prob (Importance)", np.mean(log_probs_importance)
    )
    logger.record_tabular("Uniform Data KL", np.mean(kles))
    logger.record_tabular("Uniform Data MSE", np.mean(mses))
def random_vae_training_data(self, batch_size, epoch, key=None):
    """Sample normalized next-observations under `key` for VAE training.

    `epoch` is unused: skewing is handled by self.skew inside
    sample_weighted_indices. When `key` is None it falls back to
    self.decoded_obs_key (but the raw None is still passed to the
    index sampler, matching the original behavior).
    """
    idxs = self.sample_weighted_indices(batch_size, key)
    if key is None:
        key = self.decoded_obs_key
    imgs = normalize_image(self._next_obs[key][idxs])
    return dict(next_obs=ptu.from_numpy(imgs))
def dump_reconstructions(vae_env, epoch, n_recon=16):
    """Save a side-by-side grid of dataset images and VAE reconstructions.

    Only runs when a saved VAE dataset is configured; otherwise a no-op.
    Writes r_<epoch>.png (or r.png when epoch is None) into the snapshot
    directory.
    """
    from railrl.core import logger
    import os.path as osp
    from torchvision.utils import save_image
    if vae_env.use_vae_dataset and vae_env.vae_dataset_path is not None:
        from multiworld.core.image_env import normalize_image
        from railrl.misc.asset_loader import local_path_from_s3_or_local_path
        filename = local_path_from_s3_or_local_path(vae_env.vae_dataset_path)
        # NOTE(review): np.load(...).item() assumes a pickled dict — old
        # numpy API; newer numpy needs allow_pickle=True.
        dataset = np.load(filename).item()
        sampled_idx = np.random.choice(dataset['next_obs'].shape[0], n_recon)
        if vae_env.vae_input_key_prefix == 'state':
            # State inputs: render both originals and reconstructed states
            # to images for visual comparison.
            states = dataset['next_obs'][sampled_idx]
            imgs = ptu.np_to_var(
                vae_env.wrapped_env.states_to_images(states)
            )
            recon_samples, _, _ = vae_env.vae(ptu.np_to_var(states))
            recon_imgs = ptu.np_to_var(
                vae_env.wrapped_env.states_to_images(ptu.get_numpy(recon_samples))
            )
        else:
            imgs = ptu.np_to_var(
                normalize_image(dataset['next_obs'][sampled_idx])
            )
            recon_imgs, _, _, _ = vae_env.vae(imgs)
        del dataset  # free the full dataset early
    else:
        return
    # Top rows: originals (first image_length entries of each row vector);
    # bottom rows: reconstructions.
    comparison = torch.cat([
        imgs.narrow(start=0, length=vae_env.wrapped_env.image_length, dimension=1).contiguous().view(
            -1,
            vae_env.wrapped_env.channels,
            vae_env.wrapped_env.imsize,
            vae_env.wrapped_env.imsize
        ),
        recon_imgs.contiguous().view(
            n_recon,
            vae_env.wrapped_env.channels,
            vae_env.wrapped_env.imsize,
            vae_env.wrapped_env.imsize
        )[:n_recon]
    ])
    if epoch is not None:
        save_dir = osp.join(logger.get_snapshot_dir(), 'r_%d.png' % epoch)
    else:
        save_dir = osp.join(logger.get_snapshot_dir(), 'r.png')
    save_image(comparison.data.cpu(), save_dir, nrow=n_recon)
def get_batch(self, train=True):
    """Return an (obs, actions, next_obs) dict from the parallel loader,
    or a flat normalized image batch when parallel loading is off."""
    if self.use_parallel_dataloading:
        loader = self.train_dataloader if train else self.test_dataloader
        sample = next(loader)
        return {
            'obs': ptu.Variable(sample[0][0]),
            'actions': ptu.Variable(sample[1][0]),
            'next_obs': ptu.Variable(sample[2][0]),
        }
    dataset = self.train_dataset if train else self.test_dataset
    idxs = np.random.randint(0, len(dataset), self.batch_size)
    return ptu.np_to_var(normalize_image(dataset[idxs, :]))
def get_batch(self, train=True, epoch=None):
    """Sample one normalized image batch; `epoch` is accepted but unused."""
    if self.use_parallel_dataloading:
        loader = self.train_dataloader if train else self.test_dataloader
        return next(loader).to(ptu.device)
    dataset = self.train_dataset if train else self.test_dataset
    idxs = np.random.randint(0, len(dataset), self.batch_size)
    batch = normalize_image(dataset[idxs, :])
    if self.normalize:
        batch = ((batch - self.train_data_mean) + 1) / 2
    if self.background_subtract:
        batch = batch - self.train_data_mean
    return ptu.from_numpy(batch)
def random_lstm_training_data(self, batch_size, key=None):
    """Sample whole trajectories for LSTM training.

    Returns dict(next_obs=tensor) of shape (traj_len, batch_size, imlen),
    normalized to floats. Trajectories are read from self._obs under
    `key` (default: self.decoded_obs_key), each occupying a contiguous
    max_path_length slice.
    """
    if key is None:
        key = self.decoded_obs_key
    traj_idxes = np.random.randint(0, self._traj_num, batch_size)
    imlen = self._next_obs[key].shape[-1]
    data = np.zeros((batch_size, self.max_path_length, imlen), dtype=np.uint8)
    for row, traj in enumerate(traj_idxes):
        start = traj * self.max_path_length
        data[row] = self._obs[key][start:start + self.max_path_length]
    # (batch, traj_len, imlen) -> (traj_len, batch, imlen)
    data = np.swapaxes(normalize_image(data), 0, 1)
    return dict(next_obs=ptu.from_numpy(data))
def _dump_imgs_and_reconstructions(self, idxs, filename):
    """Save a grid: originals on top, reconstructions below, one column
    per index in `idxs`."""
    originals = []
    reconstructions = []
    shape = (self.input_channels, self.imsize, self.imsize)
    for idx in idxs:
        img_torch = ptu.from_numpy(normalize_image(self.train_dataset[idx]))
        recon, *_ = self.model(img_torch.view(1, -1))
        originals.append(img_torch.view(*shape).transpose(1, 2))
        reconstructions.append(recon.view(*shape).transpose(1, 2))
    grid = torch.stack(originals + reconstructions)
    save_file = osp.join(logger.get_snapshot_dir(), filename)
    save_image(grid.data, save_file, nrow=len(idxs))
def dump_uniform_imgs_and_reconstructions(self, dataset, epoch):
    """Save 4 random images from `dataset` with their reconstructions
    to uniform<epoch>.png in the snapshot directory."""
    idxs = np.random.choice(range(dataset.shape[0]), 4)
    filename = "uniform{}.format".replace("format", "png").format(epoch) if False else "uniform{}.png".format(epoch)
    originals = []
    reconstructions = []
    shape = (self.input_channels, self.imsize, self.imsize)
    for idx in idxs:
        img_torch = ptu.from_numpy(normalize_image(dataset[idx]))
        recon, *_ = self.model(img_torch.view(1, -1))
        originals.append(img_torch.view(*shape).transpose(1, 2))
        reconstructions.append(recon.view(*shape).transpose(1, 2))
    grid = torch.stack(originals + reconstructions)
    save_file = osp.join(logger.get_snapshot_dir(), filename)
    save_image(grid.data, save_file, nrow=4)
def sample_buffer_goals(self, batch_size):
    """
    Samples goals from weighted replay buffer for relabeling or exploration.
    Returns None if replay buffer is empty.

    Example of what might be returned:
    dict(
        image_desired_goals: image_achieved_goals[weighted_indices],
        latent_desired_goals: latent_desired_goals[weighted_indices],
    )
    """
    if self._size == 0:
        return None

    idxs = self.sample_weighted_indices(batch_size, )
    image_goals = normalize_image(
        self._next_obs[self.decoded_obs_key][idxs])
    latent_goals = self._next_obs[self.achieved_goal_key][idxs]
    return {
        self.decoded_desired_goal_key: image_goals,
        self.desired_goal_key: latent_goals,
    }
def drop_puck(img, imsize=48):
    """Paint puck-colored pixels with a flat background color, effectively
    removing the puck from a flattened, normalized image.

    `img` is assumed to be a normalized (0..1) flat vector of length
    3*imsize*imsize in channel-first layout — TODO confirm at call sites.
    Returns the edited image re-normalized and flattened.
    """
    tempImg = copy.deepcopy(img)
    # Flat vector -> (imsize, imsize, 3) via full-axis transpose.
    tempImg = tempImg.reshape(3, imsize, imsize).transpose()
    # Back to uint8 pixel values for OpenCV.
    tempImg = tempImg * 255
    tempImg = tempImg.astype(np.uint8)
    tempImgCopy = copy.deepcopy(tempImg)
    # this segments out puck
    # NOTE(review): the comments label these bounds RGB, but whether the
    # array is RGB or BGR at this point depends on the upstream layout —
    # confirm against the image source.
    lowerColorBounds = (0, 60, 150)  # RGB
    upperColorBounds = (10, 140, 255)  # RGB
    mask_puck = cv2.inRange(tempImgCopy, lowerColorBounds, upperColorBounds)
    # Dilate slightly so the mask covers the puck's anti-aliased edge.
    kernel = np.ones((2,2),np.uint8)
    mask_puck = cv2.dilate(mask_puck, kernel)
    mask_puck = mask_puck>200
    # Fill masked pixels with a flat (120, 120, 100) background color.
    tmp = np.zeros_like(tempImgCopy[mask_puck])
    tmp[:, 0]= 120
    tmp[:, 1]= 120
    tmp[:, 2]= 100
    tempImgCopy[mask_puck] = tmp
    # Back to channel-first flat layout, then renormalize to 0..1.
    tempImgCopy = tempImgCopy.transpose()
    tempImgCopy = tempImgCopy.reshape([1,-1])
    tempImgCopy = normalize_image(tempImgCopy)
    # show_image = copy.deepcopy(tempImgCopy)
    # img = show_image.reshape(3, imsize, imsize).transpose()
    # img = img[::-1, :, ::-1]
    # cv2.imshow("segmented image", img)
    # cv2.waitKey()
    return tempImgCopy.flatten()
def _dump_imgs_and_reconstructions(self, idxs, filename):
    """Save a grid of originals (top row) and reconstructions (bottom row)
    under ./result_image/<filename> relative to the current directory.
    """
    imgs = []
    recons = []
    for i in idxs:
        img_np = self.train_dataset[i]
        img_torch = ptu.from_numpy(normalize_image(img_np))
        recon, *_ = self.model(img_torch.view(1, -1))
        img = img_torch.view(self.input_channels, self.imsize, self.imsize).transpose(1, 2)
        rimg = recon.view(self.input_channels, self.imsize, self.imsize).transpose(1, 2)
        imgs.append(img)
        recons.append(rimg)
    all_imgs = torch.stack(imgs + recons)
    # save_file = osp.join(logger.get_snapshot_dir(), filename)
    project_path = osp.abspath(os.curdir)
    save_dir = osp.join(project_path + str('/result_image/'), filename)
    # BUG FIX: the original passed the undefined name `save_file` to
    # save_image (NameError at runtime); `save_dir` computed above is the
    # intended destination.
    save_image(
        all_imgs.data,
        save_dir,
        nrow=len(idxs),
    )
def _reconstruction_squared_error_np_to_np(self, np_imgs):
    """Per-image summed squared reconstruction error, returned as numpy."""
    inputs = ptu.from_numpy(normalize_image(np_imgs))
    recons, *_ = self.model(inputs)
    diff = inputs - recons
    return ptu.get_numpy((diff ** 2).sum(dim=1))
def refresh_latents(self, epoch):
    """Re-encode all stored images with the current VAE and refresh
    priorities, exploration rewards, and the VAE's latent statistics.

    Processes the buffer in chunks of 512. WARNING (original): only
    next_obs desired/achieved latents are refreshed; obs[desired/achieved]
    stay stale.
    """
    self.epoch = epoch
    self.skew = self.epoch > self.start_skew_epoch
    batch_size = 512
    next_idx = min(batch_size, self._size)

    if self.exploration_rewards_type == "hash_count":
        # you have to count everything then compute exploration rewards
        # NOTE(review): this pass only normalizes images and discards the
        # result — the counting step appears incomplete. It also leaves
        # next_idx == self._size, so the loop below processes the whole
        # buffer in one chunk on its first iteration. Confirm intent.
        cur_idx = 0
        next_idx = min(batch_size, self._size)
        while cur_idx < self._size:
            idxs = np.arange(cur_idx, next_idx)
            normalized_imgs = normalize_image(
                self._next_obs[self.decoded_obs_key][idxs])
            cur_idx = next_idx
            next_idx += batch_size
            next_idx = min(next_idx, self._size)

    cur_idx = 0
    # Running sums used to recompute the latent mean/std at the end.
    obs_sum = np.zeros(self.vae.representation_size)
    obs_square_sum = np.zeros(self.vae.representation_size)
    while cur_idx < self._size:
        idxs = np.arange(cur_idx, next_idx)
        # Refresh latent encodings of obs and next_obs from decoded images.
        self._obs[self.observation_key][idxs] = self.env._encode(
            normalize_image(self._obs[self.decoded_obs_key][idxs]))
        self._next_obs[self.observation_key][idxs] = self.env._encode(
            normalize_image(self._next_obs[self.decoded_obs_key][idxs]))
        # WARNING: we only refresh the desired/achieved latents for
        # "next_obs". This means that obs[desired/achieve] will be invalid,
        # so make sure there's no code that references this.
        # TODO: enforce this with code and not a comment
        self._next_obs[self.desired_goal_key][idxs] = self.env._encode(
            normalize_image(
                self._next_obs[self.decoded_desired_goal_key][idxs]))
        self._next_obs[self.achieved_goal_key][idxs] = self.env._encode(
            normalize_image(
                self._next_obs[self.decoded_achieved_goal_key][idxs]))
        normalized_imgs = normalize_image(
            self._next_obs[self.decoded_obs_key][idxs])
        if self._give_explr_reward_bonus:
            rewards = self.exploration_reward_func(
                normalized_imgs, idxs, **self.priority_function_kwargs)
            self._exploration_rewards[idxs] = rewards.reshape(-1, 1)
        if self._prioritize_vae_samples:
            # Reuse exploration rewards as priorities when they share the
            # same type; otherwise compute priorities separately.
            if (self.exploration_rewards_type == self.vae_priority_type
                    and self._give_explr_reward_bonus):
                self._vae_sample_priorities[
                    idxs] = self._exploration_rewards[idxs]
            else:
                self._vae_sample_priorities[
                    idxs] = self.vae_prioritization_func(
                        normalized_imgs, idxs,
                        **self.priority_function_kwargs).reshape(-1, 1)
        obs_sum += self._obs[self.observation_key][idxs].sum(axis=0)
        obs_square_sum += np.power(self._obs[self.observation_key][idxs], 2).sum(axis=0)
        cur_idx = next_idx
        next_idx += batch_size
        next_idx = min(next_idx, self._size)

    # Latent moments: mean and std = sqrt(E[x^2] - E[x]^2).
    self.vae.dist_mu = obs_sum / self._size
    self.vae.dist_std = np.sqrt(obs_square_sum / self._size
                                - np.power(self.vae.dist_mu, 2))

    if self._prioritize_vae_samples:
        """
        priority^power is calculated in the priority function for
        image_bernoulli_prob or image_gaussian_inv_prob and directly here
        if not.
        """
        if self.vae_priority_type == "vae_prob":
            self._vae_sample_priorities[:self._size] = relative_probs_from_log_probs(
                self._vae_sample_priorities[:self._size])
            self._vae_sample_probs = self._vae_sample_priorities[:self._size]
        else:
            self._vae_sample_probs = (
                self._vae_sample_priorities[:self._size]**self.power)
        p_sum = np.sum(self._vae_sample_probs)
        assert p_sum > 0, "Unnormalized p sum is {}".format(p_sum)
        self._vae_sample_probs /= np.sum(self._vae_sample_probs)
        self._vae_sample_probs = self._vae_sample_probs.flatten()
def __init__(
        self,
        train_dataset,
        test_dataset,
        model,
        positive_range=2,
        negative_range=10,
        triplet_sample_num=8,
        triplet_loss_margin=0.5,
        batch_size=128,
        log_interval=0,
        recon_loss_coef=1,
        triplet_loss_coef=[],
        triplet_loss_type=[],
        ae_loss_coef=1,
        matching_loss_coef=1,
        vae_matching_loss_coef=1,
        matching_loss_one_side=False,
        contrastive_loss_coef=0,
        lstm_kl_loss_coef=0,
        adaptive_margin=0,
        beta=0.5,
        beta_schedule=None,
        lr=None,
        do_scatterplot=False,
        normalize=False,
        mse_weight=0.1,
        is_auto_encoder=False,
        background_subtract=False,
        use_parallel_dataloading=False,
        train_data_workers=2,
        skew_dataset=False,
        skew_config=None,
        priority_function_kwargs=None,
        start_skew_epoch=0,
        weight_decay=0,
):
    """LSTM trainer setup: loss coefficients, optimizer, and dataloaders.

    Datasets must be uint8 image arrays. When `use_parallel_dataloading`
    is set, background DataLoaders feed batches; with `skew_dataset` the
    train loader samples according to the precomputed skew weights.
    """
    print("In LSTM trainer, ae_loss_coef is: ", ae_loss_coef)
    print("In LSTM trainer, matching_loss_coef is: ", matching_loss_coef)
    print("In LSTM trainer, vae_matching_loss_coef is: ", vae_matching_loss_coef)

    if skew_config is None:
        skew_config = {}
    self.log_interval = log_interval
    self.batch_size = batch_size
    self.beta = beta
    if is_auto_encoder:
        self.beta = 0  # a plain autoencoder has no KL term
    if lr is None:
        lr = 1e-2 if is_auto_encoder else 1e-3
    self.beta_schedule = beta_schedule
    if self.beta_schedule is None or is_auto_encoder:
        self.beta_schedule = ConstantSchedule(self.beta)
    self.imsize = model.imsize
    self.do_scatterplot = do_scatterplot

    self.recon_loss_coef = recon_loss_coef
    self.triplet_loss_coef = triplet_loss_coef
    self.ae_loss_coef = ae_loss_coef
    self.matching_loss_coef = matching_loss_coef
    self.vae_matching_loss_coef = vae_matching_loss_coef
    self.contrastive_loss_coef = contrastive_loss_coef
    self.lstm_kl_loss_coef = lstm_kl_loss_coef
    self.matching_loss_one_side = matching_loss_one_side

    # triplet loss range
    # NOTE: attribute keeps the historical misspelling "positve_range"
    # for backward compatibility with external references.
    self.positve_range = positive_range
    self.negative_range = negative_range
    self.triplet_sample_num = triplet_sample_num
    self.triplet_loss_margin = triplet_loss_margin
    self.triplet_loss_type = triplet_loss_type
    self.adaptive_margin = adaptive_margin

    model.to(ptu.device)
    self.model = model
    self.representation_size = model.representation_size
    self.input_channels = model.input_channels
    self.imlength = model.imlength

    self.lr = lr
    params = list(self.model.parameters())
    self.optimizer = optim.Adam(
        params,
        lr=self.lr,
        weight_decay=weight_decay,
    )
    self.train_dataset, self.test_dataset = train_dataset, test_dataset
    assert self.train_dataset.dtype == np.uint8
    assert self.test_dataset.dtype == np.uint8

    self.use_parallel_dataloading = use_parallel_dataloading
    self.train_data_workers = train_data_workers
    self.skew_dataset = skew_dataset
    self.skew_config = skew_config
    self.start_skew_epoch = start_skew_epoch
    if priority_function_kwargs is None:
        self.priority_function_kwargs = dict()
    else:
        self.priority_function_kwargs = priority_function_kwargs

    if self.skew_dataset:
        self._train_weights = self._compute_train_weights()
    else:
        self._train_weights = None

    if use_parallel_dataloading:
        self.train_dataset_pt = ImageDataset(
            train_dataset,
            should_normalize=True)
        self.test_dataset_pt = ImageDataset(
            test_dataset,
            should_normalize=True)

        if self.skew_dataset:
            base_sampler = InfiniteWeightedRandomSampler(
                self.train_dataset, self._train_weights)
        else:
            base_sampler = InfiniteRandomSampler(self.train_dataset)
        # BUG FIX: the train DataLoader previously hard-coded
        # InfiniteRandomSampler and ignored `base_sampler`, silently
        # disabling skew-weighted sampling when skew_dataset=True.
        self.train_dataloader = DataLoader(
            self.train_dataset_pt,
            sampler=base_sampler,
            batch_size=batch_size,
            drop_last=False,
            num_workers=train_data_workers,
            pin_memory=True,
        )
        self.test_dataloader = DataLoader(
            self.test_dataset_pt,
            sampler=InfiniteRandomSampler(self.test_dataset),
            batch_size=batch_size,
            drop_last=False,
            num_workers=0,
            pin_memory=True,
        )
        self.train_dataloader = iter(self.train_dataloader)
        self.test_dataloader = iter(self.test_dataloader)

    self.normalize = normalize
    self.mse_weight = mse_weight
    self.background_subtract = background_subtract
    if self.normalize or self.background_subtract:
        self.train_data_mean = np.mean(self.train_dataset, axis=0)
        self.train_data_mean = normalize_image(
            np.uint8(self.train_data_mean))
    self.eval_statistics = OrderedDict()
    self._extra_stats_to_log = None
def __init__(
        self,
        train_dataset,
        test_dataset,
        model,
        batch_size=128,
        log_interval=0,
        beta=0.5,
        beta_schedule=None,
        lr=None,
        do_scatterplot=False,
        normalize=False,
        mse_weight=0.1,
        is_auto_encoder=False,
        background_subtract=False,
        use_parallel_dataloading=True,
        train_data_workers=2,
        skew_dataset=False,
        skew_config=None,
        priority_function_kwargs=None,
        start_skew_epoch=0,
        weight_decay=0,
):
    """VAE trainer setup: hyperparameters, optimizer, and dataloaders.

    Datasets must be uint8 image arrays. When `use_parallel_dataloading`
    is set, background DataLoaders feed batches; with `skew_dataset` the
    train loader samples according to the precomputed skew weights.
    (Fixes: redundant duplicate dataset/batch_size assignments removed;
    the weighted `base_sampler` is now actually passed to the train
    DataLoader instead of being silently discarded.)
    """
    if skew_config is None:
        skew_config = {}
    self.log_interval = log_interval
    self.batch_size = batch_size
    self.beta = beta
    if is_auto_encoder:
        self.beta = 0  # a plain autoencoder has no KL term
    if lr is None:
        lr = 1e-2 if is_auto_encoder else 1e-3
    self.beta_schedule = beta_schedule
    if self.beta_schedule is None or is_auto_encoder:
        self.beta_schedule = ConstantSchedule(self.beta)
    self.imsize = model.imsize
    self.do_scatterplot = do_scatterplot

    model.to(ptu.device)
    self.model = model
    self.representation_size = model.representation_size
    self.input_channels = model.input_channels
    self.imlength = model.imlength

    self.lr = lr
    params = list(self.model.parameters())
    self.optimizer = optim.Adam(
        params,
        lr=self.lr,
        weight_decay=weight_decay,
    )
    self.train_dataset, self.test_dataset = train_dataset, test_dataset
    assert self.train_dataset.dtype == np.uint8
    assert self.test_dataset.dtype == np.uint8

    self.use_parallel_dataloading = use_parallel_dataloading
    self.train_data_workers = train_data_workers
    self.skew_dataset = skew_dataset
    self.skew_config = skew_config
    self.start_skew_epoch = start_skew_epoch
    if priority_function_kwargs is None:
        self.priority_function_kwargs = dict()
    else:
        self.priority_function_kwargs = priority_function_kwargs

    if self.skew_dataset:
        self._train_weights = self._compute_train_weights()
    else:
        self._train_weights = None

    if use_parallel_dataloading:
        self.train_dataset_pt = ImageDataset(
            train_dataset,
            should_normalize=True)
        self.test_dataset_pt = ImageDataset(
            test_dataset,
            should_normalize=True)

        if self.skew_dataset:
            base_sampler = InfiniteWeightedRandomSampler(
                self.train_dataset, self._train_weights)
        else:
            base_sampler = InfiniteRandomSampler(self.train_dataset)
        # BUG FIX: the train DataLoader previously hard-coded
        # InfiniteRandomSampler and ignored `base_sampler`, silently
        # disabling skew-weighted sampling when skew_dataset=True.
        self.train_dataloader = DataLoader(
            self.train_dataset_pt,
            sampler=base_sampler,
            batch_size=batch_size,
            drop_last=False,
            num_workers=train_data_workers,
            pin_memory=True,
        )
        self.test_dataloader = DataLoader(
            self.test_dataset_pt,
            sampler=InfiniteRandomSampler(self.test_dataset),
            batch_size=batch_size,
            drop_last=False,
            num_workers=0,
            pin_memory=True,
        )
        self.train_dataloader = iter(self.train_dataloader)
        self.test_dataloader = iter(self.test_dataloader)

    self.normalize = normalize
    self.mse_weight = mse_weight
    self.background_subtract = background_subtract
    if self.normalize or self.background_subtract:
        self.train_data_mean = np.mean(self.train_dataset, axis=0)
        self.train_data_mean = normalize_image(
            np.uint8(self.train_data_mean))
    self.eval_statistics = OrderedDict()
    self._extra_stats_to_log = None
def _test_lstm(lstm_trainer, epoch, replay_buffer, env_id,
               lstm_save_period=1, uniform_dataset=None, save_prefix='r',
               lstm_segmentation_method='color', lstm_test_N=500,
               key='image_observation_segmented'):
    """Run the periodic LSTM evaluation/visualization pass for one epoch.

    Logs losses on a uniform dataset, runs a test epoch, and — on save
    epochs — dumps samples, reconstructions, sampling histograms, and
    env-specific latent-distance / trajectory diagnostics read from
    pre-generated datasets under $PJHOME.
    """
    batch_sampler = replay_buffer.random_lstm_training_data
    save_imgs = epoch % lstm_save_period == 0
    log_fit_skew_stats = replay_buffer._prioritize_vae_samples and uniform_dataset is not None
    if uniform_dataset is not None:
        replay_buffer.log_loss_under_uniform(
            uniform_dataset,
            lstm_trainer.batch_size,
            rl_logger=lstm_trainer.vae_logger_stats_for_rl)
    lstm_trainer.test_epoch(
        epoch,
        from_rl=True,
        key=key,
        sample_batch=batch_sampler,
        save_reconstruction=save_imgs,
        save_prefix=save_prefix)
    if save_imgs:
        # Samples use the same prefix family as reconstructions ('r' -> 's').
        sample_save_prefix = save_prefix.replace('r', 's')
        lstm_trainer.dump_samples(epoch, save_prefix=sample_save_prefix)
        if log_fit_skew_stats:
            replay_buffer.dump_best_reconstruction(epoch)
            replay_buffer.dump_worst_reconstruction(epoch)
            replay_buffer.dump_sampling_histogram(
                epoch, batch_size=lstm_trainer.batch_size)
        if uniform_dataset is not None:
            replay_buffer.dump_uniform_imgs_and_reconstructions(
                dataset=uniform_dataset, epoch=epoch)
        m = lstm_trainer.model
        pjhome = os.environ['PJHOME']
        # NOTE(review): hard-coded to the 'color' segmentation dataset name
        # even though lstm_segmentation_method is a parameter — confirm.
        seg_name = 'seg-' + 'color'
        if env_id in [
                'SawyerPushNIPSEasy-v0', 'SawyerPushHurdle-v0',
                'SawyerPushHurdleMiddle-v0'
        ]:
            N = 500
            data_file_path = osp.join(
                pjhome, 'data/local/pre-train-lstm',
                '{}-{}-{}-0.3-0.5.npy'.format(env_id, seg_name, N))
            puck_pos_path = osp.join(
                pjhome, 'data/local/pre-train-lstm',
                '{}-{}-{}-0.3-0.5-puck-pos.npy'.format(env_id, seg_name, N))
            if osp.exists(data_file_path):
                all_data = np.load(data_file_path)
                puck_pos = np.load(puck_pos_path)
                all_data = normalize_image(all_data.copy())
                compare_latent_distance(
                    m,
                    all_data,
                    puck_pos,
                    save_dir=logger.get_snapshot_dir(),
                    obj_name='puck',
                    save_name='online_lstm_latent_distance_{}.png'.format(
                        epoch))
        elif env_id == 'SawyerDoorHookResetFreeEnv-v1':
            N = 1000
            seg_name = 'seg-' + 'unet'
            data_file_path = osp.join(
                pjhome, 'data/local/pre-train-lstm',
                'vae-only-{}-{}-{}-0-0.npy'.format(env_id, seg_name, N))
            door_angle_path = osp.join(
                pjhome, 'data/local/pre-train-lstm',
                'vae-only-{}-{}-{}-0-0-door-angle.npy'.format(
                    env_id, seg_name, N))
            if osp.exists(data_file_path):
                all_data = np.load(data_file_path)
                door_angle = np.load(door_angle_path)
                all_data = normalize_image(all_data.copy())
                compare_latent_distance(
                    m,
                    all_data,
                    door_angle,
                    save_dir=logger.get_snapshot_dir(),
                    obj_name='door',
                    save_name='online_lstm_latent_distance_{}.png'.format(
                        epoch))
        elif env_id == 'SawyerPushHurdleResetFreeEnv-v0':
            N = 2000
            data_file_path = osp.join(
                pjhome, 'data/local/pre-train-lstm',
                'vae-only-{}-{}-{}-0.3-0.5.npy'.format(env_id, seg_name, N))
            puck_pos_path = osp.join(
                pjhome, 'data/local/pre-train-lstm',
                'vae-only-{}-{}-{}-0.3-0.5-puck-pos.npy'.format(
                    env_id, seg_name, N))
            if osp.exists(data_file_path):
                all_data = np.load(data_file_path)
                puck_pos = np.load(puck_pos_path)
                all_data = normalize_image(all_data.copy())
                compare_latent_distance(
                    m,
                    all_data,
                    puck_pos,
                    save_dir=logger.get_snapshot_dir(),
                    obj_name='puck',
                    save_name='online_lstm_latent_distance_{}.png'.format(
                        epoch))
        test_lstm_traj(
            env_id,
            m,
            save_path=logger.get_snapshot_dir(),
            save_name='online_lstm_test_traj_{}.png'.format(epoch))
        test_masked_traj_lstm(
            env_id,
            m,
            save_dir=logger.get_snapshot_dir(),
            save_name='online_masked_test_{}.png'.format(epoch))
def refresh_latents(self, epoch):
    """Re-encode every stored image through the (freshly trained) VAEs and
    refresh exploration rewards / sampling priorities.

    Processes the buffer in batches of 512 images. Observations go through
    ``env.vae_original`` (non-segmented); goal latents through
    ``env.vae_segmented``. Also updates ``vae.dist_mu`` / ``vae.dist_std``
    from the refreshed latents and recomputes the skewed sampling
    distributions.
    """
    self.epoch = epoch
    self.skew = (self.epoch > self.start_skew_epoch)
    batch_size = 512
    next_idx = min(batch_size, self._size)

    if self.exploration_rewards_type == 'hash_count':
        # You have to count everything first, then compute exploration rewards.
        cur_idx = 0
        next_idx = min(batch_size, self._size)
        while cur_idx < self._size:
            idxs = np.arange(cur_idx, next_idx)
            # NOTE(review): this normalized batch is computed but never used;
            # the hash-count update it presumably fed appears to have been
            # removed. Kept to preserve behavior — confirm before deleting.
            normalized_imgs = (normalize_image(
                self._next_obs[self.decoded_obs_key][idxs]))
            cur_idx = next_idx
            next_idx += batch_size
            next_idx = min(next_idx, self._size)

    cur_idx = 0
    # BUG FIX: next_idx must be re-initialized here. After the hash_count
    # pass above it equals self._size, so the loop below used to process the
    # entire buffer as one giant batch, defeating the 512-image batching
    # (identical per-index results, but unbounded peak memory).
    next_idx = min(batch_size, self._size)
    obs_sum = np.zeros(self.vae.representation_size)
    obs_square_sum = np.zeros(self.vae.representation_size)
    while cur_idx < self._size:
        idxs = np.arange(cur_idx, next_idx)
        # Observations use env.vae_original (non-segmented images).
        self._obs[self.observation_key][idxs] = \
            self.env._encode(
                normalize_image(self._obs[self.decoded_obs_key][idxs]),
                self.env.vae_original
            )
        self._next_obs[self.observation_key][idxs] = \
            self.env._encode(
                normalize_image(self._next_obs[self.decoded_obs_key][idxs]),
                self.env.vae_original
            )
        # WARNING: we only refresh the desired/achieved latents for
        # "next_obs". This means that obs[desired/achieved] will be invalid,
        # so make sure there's no code that references this.
        # TODO: enforce this with code and not a comment
        # Goal latents use env.vae_segmented.
        self._next_obs[self.desired_goal_key][idxs] = \
            self.env._encode(
                normalize_image(
                    self._next_obs[self.decoded_desired_goal_key][idxs]),
                self.env.vae_segmented
            )
        self._next_obs[self.achieved_goal_key][idxs] = \
            self.env._encode(
                normalize_image(
                    self._next_obs[self.decoded_achieved_goal_key][idxs]),
                self.env.vae_segmented
            )
        normalized_imgs = (normalize_image(
            self._next_obs[self.decoded_obs_key][idxs]))
        normalized_imgs_seg = (normalize_image(
            self._next_obs[self.decoded_obs_key_seg][idxs]))
        if self._give_explr_reward_bonus:
            rewards = self.exploration_reward_func(
                normalized_imgs, idxs, **self.priority_function_kwargs)
            self._exploration_rewards[idxs] = rewards.reshape(-1, 1)
        if self._prioritize_vae_samples:
            if (self.exploration_rewards_type == self.vae_priority_type
                    and self._give_explr_reward_bonus):
                self._vae_sample_priorities[idxs] = (
                    self._exploration_rewards[idxs])
            else:
                # NOTE: this is the branch skew-fit actually uses, so only
                # this branch maintains the segmented priorities as well.
                self._vae_sample_priorities[idxs] = (
                    self.vae_prioritization_func(
                        self.vae, normalized_imgs, idxs,
                        **self.priority_function_kwargs).reshape(-1, 1))
                self._vae_sample_priorities_seg[idxs] = (
                    self.vae_prioritization_func(
                        self.vae_seg, normalized_imgs_seg, idxs,
                        **self.priority_function_kwargs).reshape(-1, 1))
        obs_sum += self._obs[self.observation_key][idxs].sum(axis=0)
        obs_square_sum += np.power(
            self._obs[self.observation_key][idxs], 2).sum(axis=0)
        cur_idx = next_idx
        next_idx += batch_size
        next_idx = min(next_idx, self._size)

    # Running latent statistics: mean and std over the whole buffer.
    self.vae.dist_mu = obs_sum / self._size
    self.vae.dist_std = np.sqrt(
        obs_square_sum / self._size - np.power(self.vae.dist_mu, 2))

    if self._prioritize_vae_samples:
        # priority^power is calculated inside the priority function for
        # image_bernoulli_prob / image_gaussian_inv_prob, and directly here
        # otherwise.
        if self.vae_priority_type == 'vae_prob':
            self._vae_sample_priorities[:self._size] = (
                relative_probs_from_log_probs(
                    self._vae_sample_priorities[:self._size]))
            self._vae_sample_probs = self._vae_sample_priorities[:self._size]
            self._vae_sample_priorities_seg[:self._size] = (
                relative_probs_from_log_probs(
                    self._vae_sample_priorities_seg[:self._size]))
            self._vae_sample_probs_seg = (
                self._vae_sample_priorities_seg[:self._size])
        else:
            self._vae_sample_probs = (
                self._vae_sample_priorities[:self._size] ** self.power)
            self._vae_sample_probs_seg = (
                self._vae_sample_priorities_seg[:self._size] ** self.power)
        p_sum = np.sum(self._vae_sample_probs)
        assert p_sum > 0, "Unnormalized p sum is {}".format(p_sum)
        self._vae_sample_probs /= np.sum(self._vae_sample_probs)
        self._vae_sample_probs = self._vae_sample_probs.flatten()
        p_sum = np.sum(self._vae_sample_probs_seg)
        assert p_sum > 0, "Unnormalized p sum is {}".format(p_sum)
        self._vae_sample_probs_seg /= np.sum(self._vae_sample_probs_seg)
        self._vae_sample_probs_seg = self._vae_sample_probs_seg.flatten()
def __init__(
        self,
        train_dataset,
        test_dataset,
        model,
        batch_size=128,
        log_interval=0,
        lr=None,
        do_scatterplot=False,
        normalize=False,
        mse_weight=0.1,
        is_auto_encoder=False,
        background_subtract=False,
        use_parallel_dataloading=True,
        train_data_workers=2,
        priority_function_kwargs=None,
        start_skew_epoch=0,
        weight_decay=0,
        **kwargs):
    """Trainer for a VQ-VAE with an attached PixelCNN prior.

    :param train_dataset: np.uint8 array of flattened training images.
    :param test_dataset: np.uint8 array of flattened test images.
    :param model: model exposing .pixelcnn, .imsize, .imlength,
        .input_channels and .representation_size; moved to ptu.device here.
    :param batch_size: minibatch size for both dataloaders and sampling.
    :param lr: learning rate for both optimizers (required, no default).
    :param normalize: shift samples by the training-set mean image.
    :param background_subtract: subtract the training-set mean image.
    :param use_parallel_dataloading: wrap the datasets in DataLoaders with
        infinite random samplers.
    :param priority_function_kwargs: extra kwargs for priority functions.
    :param weight_decay: L2 regularization for both optimizers.

    NOTE(review): is_auto_encoder, start_skew_epoch and **kwargs are accepted
    but unused in this constructor — presumably consumed by callers/subclasses;
    confirm before removing.
    """
    self.log_interval = log_interval
    self.batch_size = batch_size
    assert lr is not None
    self.lr = lr
    self.imsize = model.imsize
    self.do_scatterplot = do_scatterplot
    self.representation_size = model.representation_size
    model.to(ptu.device)
    self.model = model
    self.pixelcnn = model.pixelcnn
    self.input_channels = model.input_channels
    self.imlength = model.imlength
    # Dataset skewing is disabled for this trainer.
    self.skew_dataset = False
    self._train_weights = None

    # Split parameters: the PixelCNN prior gets its own optimizer; every
    # other parameter belongs to the VQ-VAE optimizer.
    vqvae_params = []
    pixel_params = []
    for name, param in model.named_parameters():
        if 'pixelcnn' in name:
            pixel_params.append(param)
        else:
            vqvae_params.append(param)
    self.optimizer = optim.Adam(
        vqvae_params,
        lr=self.lr,
        weight_decay=weight_decay,
    )
    self.pixelcnn_opt = optim.Adam(
        pixel_params,
        lr=self.lr,
        weight_decay=weight_decay,
        amsgrad=True)
    self.pixelcnn_criterion = nn.CrossEntropyLoss().to(ptu.device)

    # CLEANUP: the original assigned train/test datasets, batch_size and
    # _train_weights twice; each is now assigned exactly once.
    assert train_dataset.dtype == np.uint8
    assert test_dataset.dtype == np.uint8
    self.train_dataset = train_dataset
    self.test_dataset = test_dataset
    self.use_parallel_dataloading = use_parallel_dataloading
    self.train_data_workers = train_data_workers
    if priority_function_kwargs is None:
        self.priority_function_kwargs = dict()
    else:
        self.priority_function_kwargs = priority_function_kwargs

    if use_parallel_dataloading:
        self.train_dataset_pt = ImageDataset(train_dataset,
                                             should_normalize=True)
        self.test_dataset_pt = ImageDataset(test_dataset,
                                            should_normalize=True)
        self.train_dataloader = DataLoader(
            self.train_dataset_pt,
            sampler=InfiniteRandomSampler(self.train_dataset),
            batch_size=batch_size,
            drop_last=False,
            num_workers=train_data_workers,
            pin_memory=True,
        )
        self.test_dataloader = DataLoader(
            self.test_dataset_pt,
            sampler=InfiniteRandomSampler(self.test_dataset),
            batch_size=batch_size,
            drop_last=False,
            num_workers=0,
            pin_memory=True,
        )
        self.train_dataloader = iter(self.train_dataloader)
        self.test_dataloader = iter(self.test_dataloader)

    self.normalize = normalize
    self.mse_weight = mse_weight
    self.background_subtract = background_subtract
    if self.normalize or self.background_subtract:
        self.train_data_mean = np.mean(self.train_dataset, axis=0)
        self.train_data_mean = normalize_image(
            np.uint8(self.train_data_mean))
    self.eval_statistics = OrderedDict()
    self._extra_stats_to_log = None
def _kl_np_to_np(self, np_imgs):
    """Encode raw uint8 images and return the per-sample KL term
    -sum(1 + log_var - mu^2 - exp(log_var)) as a numpy array."""
    encoded_input = ptu.from_numpy(normalize_image(np_imgs))
    mu, log_var = self.model.encode(encoded_input)
    kl_per_sample = -(1 + log_var - mu ** 2 - torch.exp(log_var)).sum(dim=1)
    return ptu.get_numpy(kl_per_sample)