def initialize_dynamics_model(self):
    obs_dim = self._obs[self.observation_key].shape[1]
    self.dynamics_model = Mlp(
        hidden_sizes=[128, 128],
        output_size=obs_dim,
        input_size=obs_dim + self._action_dim,
    )
    self.dynamics_model.to(ptu.device)
    self.dynamics_optimizer = Adam(self.dynamics_model.parameters())
    self.dynamics_loss = MSELoss()
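# Note: the method above relies on names imported elsewhere in its module. A set
# of imports consistent with how they are used here (the project-specific paths
# are an assumption, not taken from this snippet) would be roughly:
#
#     import numpy as np
#     from torch.nn import MSELoss
#     from torch.optim import Adam
#     from <project>.torch.networks import Mlp
#     import <project>.torch.pytorch_util as ptu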
def __init__(
        self,
        vae,
        hidden_sizes=(64, 128, 64),
        init_w=1e-3,
        hidden_init=ptu.fanin_init,
        output_activation=identity,
        layer_norm=False,
        **kwargs
):
    self.save_init_params(locals())
    super().__init__()
    self.vae = vae
    self.representation_size = self.vae.representation_size
    self.hidden_init = hidden_init
    self.output_activation = output_activation
    # self.dist_mu = np.zeros(self.representation_size)
    # self.dist_std = np.ones(self.representation_size)
    self.dist_mu = self.vae.dist_mu
    self.dist_std = self.vae.dist_std
    self.relu = nn.ReLU()
    self.init_w = init_w
    hidden_sizes = list(hidden_sizes)
    self.network = Mlp(
        hidden_sizes,
        self.representation_size,
        self.representation_size,
        layer_norm=layer_norm,
        hidden_init=hidden_init,
        output_activation=output_activation,
        init_w=init_w,
    )
def __init__(
        self,
        representation_size,
        input_size,
        hidden_sizes,
        init_w=1e-3,
        hidden_init=ptu.fanin_init,
        output_activation=identity,
        output_scale=1,
        layer_norm=False,
):
    super().__init__()
    self.representation_size = representation_size
    self.hidden_init = hidden_init
    self.output_activation = output_activation
    self.dist_mu = np.zeros(self.representation_size)
    self.dist_std = np.ones(self.representation_size)
    self.relu = nn.ReLU()
    self.sigmoid = nn.Sigmoid()
    self.init_w = init_w
    hidden_sizes = list(hidden_sizes)
    self.encoder = TwoHeadMlp(
        hidden_sizes,
        representation_size,
        representation_size,
        input_size,
        layer_norm=layer_norm,
    )
    hidden_sizes.reverse()
    self.decoder = Mlp(
        hidden_sizes,
        input_size,
        representation_size,
        layer_norm=layer_norm,
        output_activation=output_activation,
        output_bias=None,
    )
    self.output_scale = output_scale
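# --- Illustrative sketch, not from the original code: how the TwoHeadMlp encoder
# and Mlp decoder built above are typically combined via the standard Gaussian
# reparameterization trick. `vae` is assumed to be an instance of the class whose
# __init__ is shown above, and the two encoder heads are assumed to be
# (mean, log-variance).
import torch


def vae_forward_sketch(vae, x):
    mu, logvar = vae.encoder(x)               # two heads: mean and log-variance
    std = (0.5 * logvar).exp()
    z = mu + std * torch.randn_like(std)      # reparameterized latent sample
    recon = vae.decoder(z) * vae.output_scale
    return recon, mu, logvar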
def experiment(variant):
    env = DiscreteSwimmerEnv(**variant['env_params'])
    qf = Mlp(
        input_size=int(np.prod(env.observation_space.shape)),
        output_size=env.action_space.n,
        **variant['qf_kwargs']
    )
    algorithm = DQN(env, qf=qf, **variant['algo_params'])
    algorithm.to(ptu.device)
    algorithm.train()
def experiment(variant):
    env = variant['env_class'](**variant['env_kwargs'])
    env = DiscretizeEnv(env, variant['num_bins'])
    # env = DiscreteReacherEnv(**variant['env_kwargs'])
    qf = Mlp(
        input_size=int(np.prod(env.observation_space.shape)),
        output_size=env.action_space.n,
        **variant['qf_kwargs']
    )
    algorithm = FiniteHorizonDQN(env, qf, **variant['algo_kwargs'])
    algorithm.to(ptu.device)
    algorithm.train()
def experiment(variant):
    env = DiscreteReacherEnv(**variant['env_params'])
    qf = Mlp(
        hidden_sizes=[32, 32],
        input_size=int(np.prod(env.observation_space.shape)),
        output_size=env.action_space.n,
    )
    algorithm = DQN(env, qf=qf, **variant['algo_params'])
    if ptu.gpu_enabled():
        algorithm.cuda()
    algorithm.train()
def __init__(
        self,
        representation_size,
        input_size,
        hidden_sizes=(64, 128, 64),
        init_w=1e-3,
        hidden_init=ptu.fanin_init,
        output_activation=identity,
        output_scale=1,
        layer_norm=False,
        normalize=True,
        train_data_mean=None,
        train_data_std=None,
        **kwargs
):
    self.save_init_params(locals())
    super().__init__()
    self.representation_size = representation_size
    self.hidden_init = hidden_init
    self.output_activation = output_activation
    self.dist_mu = np.zeros(self.representation_size)
    self.dist_std = np.ones(self.representation_size)
    self.relu = nn.ReLU()
    self.sigmoid = nn.Sigmoid()
    self.init_w = init_w
    hidden_sizes = list(hidden_sizes)
    self.input_size = input_size
    self.encoder = TwoHeadMlp(
        hidden_sizes,
        representation_size,
        representation_size,
        input_size,
        layer_norm=layer_norm,
        hidden_init=hidden_init,
        output_activation=output_activation,
        init_w=init_w,
    )
    hidden_sizes.reverse()
    self.decoder = Mlp(
        hidden_sizes,
        input_size,
        representation_size,
        layer_norm=layer_norm,
        hidden_init=hidden_init,
        output_activation=output_activation,
        init_w=init_w,
    )
    self.output_scale = output_scale
    self.normalize = normalize
    if train_data_mean is None:
        self.train_data_mean = ptu.np_to_var(np.zeros(input_size))
    else:
        self.train_data_mean = train_data_mean
    if train_data_std is None:
        self.train_data_std = ptu.np_to_var(np.ones(input_size))
    else:
        self.train_data_std = train_data_std
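# --- Illustrative sketch, not from the original code: what the `normalize`,
# `train_data_mean`, and `train_data_std` attributes above are for. Inputs are
# standardized with dataset statistics before encoding; where exactly the
# original model applies this is an assumption.
def standardize_input_sketch(vae, x):
    if vae.normalize:
        x = (x - vae.train_data_mean) / vae.train_data_std
    return x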
def experiment(variant):
    # register_grid_envs()
    # env = gym.make("GridMaze1-v0")
    env = gym.make("Pong-ram-v0")
    qf = Mlp(
        hidden_sizes=[100, 100],
        input_size=int(np.prod(env.observation_space.shape)),
        output_size=env.action_space.n,
    )
    algorithm = DQN(env, qf=qf, **variant['algo_params'])
    algorithm.to(ptu.device)
    algorithm.train()
def __init__(
        self,
        obs_processor,
        obs_processor_output_dim,
        action_dim,
        hidden_sizes,
):
    super().__init__()
    self.obs_processor = obs_processor
    self.obs_processor_output_dim = obs_processor_output_dim
    self.mean_and_log_std_net = Mlp(
        hidden_sizes=hidden_sizes,
        output_size=action_dim * 2,
        input_size=obs_processor_output_dim,
    )
    self.action_dim = action_dim
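# --- Illustrative sketch, not from the original code: one plausible forward pass
# for the policy above, assuming mean_and_log_std_net outputs the action mean and
# log-std concatenated along the last dimension and that a tanh-squashed Gaussian
# action is drawn from them.
import torch


def policy_forward_sketch(policy, obs):
    h = policy.obs_processor(obs)
    mean, log_std = torch.split(
        policy.mean_and_log_std_net(h), policy.action_dim, dim=1
    )
    std = log_std.clamp(-20, 2).exp()          # clamp log-std for numerical stability
    return torch.tanh(mean + std * torch.randn_like(std))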
def experiment(variant):
    env = variant['env_class']()
    if variant['multitask']:
        env = MultitaskToFlatEnv(env)
    qf = Mlp(
        hidden_sizes=[32, 32],
        input_size=int(np.prod(env.observation_space.shape)),
        output_size=env.action_space.n,
    )
    qf_criterion = variant['qf_criterion_class']()
    algorithm = variant['algo_class'](
        env,
        qf=qf,
        qf_criterion=qf_criterion,
        **variant['algo_params']
    )
    algorithm.to(ptu.device)
    algorithm.train()
def experiment(variant):
    env = gym.make(variant['env_id'])
    training_env = gym.make(variant['env_id'])
    qf = Mlp(
        hidden_sizes=[32, 32],
        input_size=int(np.prod(env.observation_space.shape)),
        output_size=env.action_space.n,
    )
    qf_criterion = variant['qf_criterion_class']()
    algorithm = variant['algo_class'](
        env,
        training_env=training_env,
        qf=qf,
        qf_criterion=qf_criterion,
        **variant['algo_params']
    )
    algorithm.to(ptu.device)
    algorithm.train()
def experiment(variant):
    # env = gym.make('CartPole-v0')
    # training_env = gym.make('CartPole-v0')
    # env = DiscreteReacherEnv(**variant['env_kwargs'])
    # env = DiscreteSwimmerEnv()
    env = variant['env_class'](**variant['env_kwargs'])
    env = DiscretizeEnv(env, variant['num_bins'])
    qf = Mlp(
        input_size=int(np.prod(env.observation_space.shape)),
        output_size=env.action_space.n,
        **variant['qf_kwargs']
    )
    qf_criterion = nn.MSELoss()
    # Use this to switch to DoubleDQN
    # algorithm = DoubleDQN(
    algorithm = variant['algo_class'](
        env,
        qf=qf,
        qf_criterion=qf_criterion,
        **variant['algo_kwargs']
    )
    algorithm.to(ptu.device)
    algorithm.train()
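# --- Illustrative sketch, not from the original code: the shape of the `variant`
# dict experiment() above expects. Keys mirror how the function indexes the dict;
# the values are placeholders, not the original settings.
#
#     variant = dict(
#         env_class=...,                        # a continuous-action env class
#         env_kwargs=dict(),
#         num_bins=5,                           # bins per action dimension
#         qf_kwargs=dict(hidden_sizes=[32, 32]),
#         algo_class=DQN,                       # or DoubleDQN, per the comment above
#         algo_kwargs=dict(),                   # algorithm hyperparameters
#     )
#     experiment(variant)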
class TimestepPredictionModel(torch.nn.Module):
    def __init__(
            self,
            representation_size,
            architecture,
            normalize=True,
            output_classes=100,
            encoder_class=CNN,
            decoder_class=DCNN,
            decoder_output_activation=identity,
            decoder_distribution='bernoulli',
            input_channels=1,
            imsize=224,
            init_w=1e-3,
            min_variance=1e-3,
            hidden_init=ptu.fanin_init,
            delta_features=False,
            pretrained_features=False,
    ):
        """
        :param representation_size:
        :param conv_args: must be a dictionary specifying the following:
            kernel_sizes
            n_channels
            strides
        :param conv_kwargs: a dictionary specifying the following:
            hidden_sizes
            batch_norm
        :param deconv_args: must be a dictionary specifying the following:
            hidden_sizes
            deconv_input_width
            deconv_input_height
            deconv_input_channels
            deconv_output_kernel_size
            deconv_output_strides
            deconv_output_channels
            kernel_sizes
            n_channels
            strides
        :param deconv_kwargs:
            batch_norm
        :param encoder_class:
        :param decoder_class:
        :param decoder_output_activation:
        :param decoder_distribution:
        :param input_channels:
        :param imsize:
        :param init_w:
        :param min_variance:
        :param hidden_init:
        """
        super().__init__()
        # super().__init__(representation_size)
        if min_variance is None:
            self.log_min_variance = None
        else:
            self.log_min_variance = float(np.log(min_variance))
        self.input_channels = input_channels
        self.imsize = imsize
        self.imlength = self.imsize * self.imsize * self.input_channels
        self.representation_size = representation_size
        self.output_classes = output_classes
        self.normalize = normalize
        self.img_mean = torch.tensor([0.485, 0.456, 0.406])
        self.img_std = torch.tensor([0.229, 0.224, 0.225])
        self.img_mean = self.img_mean.repeat(
            epic.CROP_WIDTH, epic.CROP_HEIGHT, 1
        ).transpose(0, 2).to(ptu.device)
        self.img_std = self.img_std.repeat(
            epic.CROP_WIDTH, epic.CROP_HEIGHT, 1
        ).transpose(0, 2).to(ptu.device)
        # self.img_normalizer = torchvision.transforms.Normalize(self.img_mean, self.img_std)
        self.encoder = torchvision.models.resnet.ResNet(
            torchvision.models.resnet.BasicBlock,
            [2, 2, 2, 2],
            num_classes=representation_size,
        )
        self.encoder.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # self.encoder = nn.DataParallel(self.encoder)
        if pretrained_features:
            exclude_names = ["fc"]
            state_dict = load_state_dict_from_url(
                "https://download.pytorch.org/models/resnet18-5c106cde.pth",
                progress=True,
            )
            new_state_dict = state_dict.copy()
            for key in state_dict:
                for name in exclude_names:
                    if name in key:
                        del new_state_dict[key]
                        break
            self.encoder.load_state_dict(new_state_dict, strict=False)
        self.delta_features = delta_features
        input_size = (
            representation_size * 2 if delta_features else representation_size * 3
        )
        self.predictor = Mlp(
            output_size=output_classes,
            input_size=input_size,
            **architecture
        )
        self.predictor = self.predictor.to("cuda:0")
        self.epoch = 0

    def get_latents(self, x0, xt, xT):
        bz = x0.shape[0]
        x = torch.cat([x0, xt, xT], dim=0).view(
            -1, 3, epic.CROP_HEIGHT, epic.CROP_WIDTH,
        )
        z = self.encode(x)
        # # import pdb; pdb.set_trace()
        # if self.normalize:
        #     x = x - self.img_mean
        #     x = x / self.img_std
        #     # x = self.img_normalizer(x)
        # zs = []
        # for i in range(0, 3 * bz, MAX_BATCH_SIZE):
        #     z = self.encoder(x[i:i+MAX_BATCH_SIZE, :, :, :])
        #     zs.append(z)
        # # z = self.encoder(x)
        # z = torch.cat(zs)  # self.encoder(x)  # .to("cuda:0")
        z0, zt, zT = z[:bz, :], z[bz:2 * bz, :], z[2 * bz:3 * bz, :]
        return z0, zt, zT

    def forward(self, x0, xt, xT):
        z0, zt, zT = self.get_latents(x0, xt, xT)
        # z0 = self.encoder(x0.view(-1, 3, 456, 256)).to("cuda:0")  # .view((-1, 3, 240, 240))[:, :, :224, :224])
        # zt = self.encoder(xt.view(-1, 3, 456, 256)).to("cuda:0")  # .view((-1, 3, 240, 240))[:, :, :224, :224])
        # zT = self.encoder(xT.view(-1, 3, 456, 256)).to("cuda:0")  # .view((-1, 3, 240, 240))[:, :, :224, :224])
        if self.delta_features:
            dt = zt - z0
            dT = zT - z0
            z = torch.cat([dt, dT], dim=1)
        else:
            z = torch.cat([z0, zt, zT], dim=1)
        out = self.predictor(z)
        return out

    def encode(self, x):
        bz = x.shape[0]
        if self.normalize:
            x = x - self.img_mean
            x = x / self.img_std
        zs = []
        for i in range(0, bz, MAX_BATCH_SIZE):
            z = self.encoder(x[i:i + MAX_BATCH_SIZE, :, :, :])
            zs.append(z)
        z = torch.cat(zs)
        return z
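# --- Illustrative sketch, not from the original code: one plausible training
# step for the model above, treating the `output_classes` logits as a
# classification over discretized timesteps. The cross-entropy loss and the
# optimizer argument are assumptions, not taken from the original training code.
import torch
import torch.nn.functional as F


def timestep_prediction_step_sketch(model, optimizer, x0, xt, xT, labels):
    logits = model(x0, xt, xT)                 # (batch, output_classes)
    loss = F.cross_entropy(logits, labels)     # labels: (batch,) int64 class ids
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()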
    def forward(self, input):
        h = input
        z = input
        for fc, gate in zip(self.fcs, self.gates):
            z = torch.sin(gate(z))
            h = fc(h) * z
        return self.last_fc(h)


def num_params(net):
    return nn.utils.parameters_to_vector(net.parameters()).shape[0]


mean_y = np.mean(test_y_np, axis=0)
print("Mean y", mean_y)
print("Constant error", np.mean((test_y_np - mean_y) ** 2))
plt.figure()

mlp = Mlp(hidden_sizes=[100, 100], output_size=out_dim, input_size=in_dim)
mlp_n_params = num_params(mlp)

h_size = 100
jac_net = JacobianNet(hidden_sizes=[h_size, h_size], output_size=out_dim, input_size=in_dim)
# keep the number of parameters roughly the same as the baseline Mlp
while num_params(jac_net) > mlp_n_params:
    h_size -= 5
    jac_net = JacobianNet(hidden_sizes=[h_size, h_size], output_size=out_dim, input_size=in_dim)
print("jac_net h_size:", h_size)

linear_net = Mlp(hidden_sizes=[], output_size=out_dim, input_size=in_dim)
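# --- Illustrative sketch, not from the original code: num_params() applied to a
# small module, to show what the parameter-matching loop above is balancing.
import torch.nn as nn

tiny = nn.Linear(3, 2)
assert num_params(tiny) == 3 * 2 + 2   # weight matrix plus bias vector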
class OnlineVaeRelabelingBuffer(ObsDictRelabelingBuffer):
    def __init__(
            self,
            vae,
            *args,
            decoded_obs_key='image_observation',
            decoded_achieved_goal_key='image_achieved_goal',
            decoded_desired_goal_key='image_desired_goal',
            exploration_rewards_type='None',
            exploration_rewards_scale=1.0,
            vae_priority_type='None',
            start_skew_epoch=0,
            power=1.0,
            internal_keys=None,
            exploration_schedule_kwargs=None,
            priority_function_kwargs=None,
            exploration_counter_kwargs=None,
            relabeling_goal_sampling_mode='vae_prior',
            decode_vae_goals=False,
            **kwargs
    ):
        if internal_keys is None:
            internal_keys = []
        for key in [
            decoded_obs_key,
            decoded_achieved_goal_key,
            decoded_desired_goal_key
        ]:
            if key not in internal_keys:
                internal_keys.append(key)
        super().__init__(internal_keys=internal_keys, *args, **kwargs)
        # assert isinstance(self.env, VAEWrappedEnv)
        self.vae = vae
        self.decoded_obs_key = decoded_obs_key
        self.decoded_desired_goal_key = decoded_desired_goal_key
        self.decoded_achieved_goal_key = decoded_achieved_goal_key
        self.exploration_rewards_type = exploration_rewards_type
        self.exploration_rewards_scale = exploration_rewards_scale
        self.start_skew_epoch = start_skew_epoch
        self.vae_priority_type = vae_priority_type
        self.power = power
        self._relabeling_goal_sampling_mode = relabeling_goal_sampling_mode
        self.decode_vae_goals = decode_vae_goals

        if exploration_schedule_kwargs is None:
            self.explr_reward_scale_schedule = \
                ConstantSchedule(self.exploration_rewards_scale)
        else:
            self.explr_reward_scale_schedule = \
                PiecewiseLinearSchedule(**exploration_schedule_kwargs)

        self._give_explr_reward_bonus = (
            exploration_rewards_type != 'None'
            and exploration_rewards_scale != 0.
        )
        self._exploration_rewards = np.zeros((self.max_size, 1), dtype=np.float32)
        self._prioritize_vae_samples = (
            vae_priority_type != 'None'
            and power != 0.
        )
        self._vae_sample_priorities = np.zeros((self.max_size, 1), dtype=np.float32)
        self._vae_sample_probs = None

        self.use_dynamics_model = (
            self.exploration_rewards_type == 'forward_model_error'
        )
        if self.use_dynamics_model:
            self.initialize_dynamics_model()

        type_to_function = {
            'reconstruction_error': self.reconstruction_mse,
            'bce': self.binary_cross_entropy,
            'latent_distance': self.latent_novelty,
            'latent_distance_true_prior': self.latent_novelty_true_prior,
            'forward_model_error': self.forward_model_error,
            'gaussian_inv_prob': self.gaussian_inv_prob,
            'bernoulli_inv_prob': self.bernoulli_inv_prob,
            'vae_prob': self.vae_prob,
            'hash_count': self.hash_count_reward,
            'None': self.no_reward,
        }

        self.exploration_reward_func = (
            type_to_function[self.exploration_rewards_type]
        )
        self.vae_prioritization_func = (
            type_to_function[self.vae_priority_type]
        )

        if priority_function_kwargs is None:
            self.priority_function_kwargs = dict()
        else:
            self.priority_function_kwargs = priority_function_kwargs

        if self.exploration_rewards_type == 'hash_count':
            if exploration_counter_kwargs is None:
                exploration_counter_kwargs = dict()
            self.exploration_counter = CountExploration(env=self.env, **exploration_counter_kwargs)
        self.epoch = 0

    def add_path(self, path):
        if self.decode_vae_goals:
            self.add_decoded_vae_goals_to_path(path)
        super().add_path(path)

    def add_decoded_vae_goals_to_path(self, path):
        # decoding the self-sampled vae images should be done in batch (here)
        # rather than in the env for efficiency
        desired_goals = flatten_dict(
            path['observations'],
            [self.desired_goal_key]
        )[self.desired_goal_key]
        desired_decoded_goals = self.env._decode(desired_goals)
        desired_decoded_goals = desired_decoded_goals.reshape(
            len(desired_decoded_goals),
            -1
        )
        for idx, next_obs in enumerate(path['observations']):
            path['observations'][idx][self.decoded_desired_goal_key] = \
                desired_decoded_goals[idx]
            path['next_observations'][idx][self.decoded_desired_goal_key] = \
                desired_decoded_goals[idx]

    def random_batch(self, batch_size):
        batch = super().random_batch(batch_size)
        exploration_rewards_scale = float(self.explr_reward_scale_schedule.get_value(self.epoch))
        if self._give_explr_reward_bonus:
            batch_idxs = batch['indices'].flatten()
            batch['exploration_rewards'] = self._exploration_rewards[batch_idxs]
            batch['rewards'] += exploration_rewards_scale * batch['exploration_rewards']
        return batch

    def get_diagnostics(self):
        if self._vae_sample_probs is None or self._vae_sample_priorities is None:
            stats = create_stats_ordered_dict(
                'VAE Sample Weights',
                np.zeros(self._size),
            )
            stats.update(create_stats_ordered_dict(
                'VAE Sample Probs',
                np.zeros(self._size),
            ))
        else:
            vae_sample_priorities = self._vae_sample_priorities[:self._size]
            vae_sample_probs = self._vae_sample_probs[:self._size]
            stats = create_stats_ordered_dict(
                'VAE Sample Weights',
                vae_sample_priorities,
            )
            stats.update(create_stats_ordered_dict(
                'VAE Sample Probs',
                vae_sample_probs,
            ))
        return stats

    def refresh_latents(self, epoch):
        self.epoch = epoch
        self.skew = (self.epoch > self.start_skew_epoch)
        batch_size = 512
        next_idx = min(batch_size, self._size)

        if self.exploration_rewards_type == 'hash_count':
            # you have to count everything then compute exploration rewards
            cur_idx = 0
            next_idx = min(batch_size, self._size)
            while cur_idx < self._size:
                idxs = np.arange(cur_idx, next_idx)
                normalized_imgs = self._next_obs[self.decoded_obs_key][idxs]
                self.update_hash_count(normalized_imgs)
                cur_idx = next_idx
                next_idx += batch_size
                next_idx = min(next_idx, self._size)

        cur_idx = 0
        obs_sum = np.zeros(self.vae.representation_size)
        obs_square_sum = np.zeros(self.vae.representation_size)
        while cur_idx < self._size:
            idxs = np.arange(cur_idx, next_idx)
            self._obs[self.observation_key][idxs] = \
                self.env._encode(self._obs[self.decoded_obs_key][idxs])
            self._next_obs[self.observation_key][idxs] = \
                self.env._encode(self._next_obs[self.decoded_obs_key][idxs])
            # WARNING: we only refresh the desired/achieved latents for
            # "next_obs". This means that obs[desired/achieve] will be invalid,
            # so make sure there's no code that references this.
            # TODO: enforce this with code and not a comment
            self._next_obs[self.desired_goal_key][idxs] = \
                self.env._encode(self._next_obs[self.decoded_desired_goal_key][idxs])
            self._next_obs[self.achieved_goal_key][idxs] = \
                self.env._encode(self._next_obs[self.decoded_achieved_goal_key][idxs])
            normalized_imgs = self._next_obs[self.decoded_obs_key][idxs]
            if self._give_explr_reward_bonus:
                rewards = self.exploration_reward_func(
                    normalized_imgs,
                    idxs,
                    **self.priority_function_kwargs
                )
                self._exploration_rewards[idxs] = rewards.reshape(-1, 1)
            if self._prioritize_vae_samples:
                if (
                        self.exploration_rewards_type == self.vae_priority_type
                        and self._give_explr_reward_bonus
                ):
                    self._vae_sample_priorities[idxs] = (
                        self._exploration_rewards[idxs]
                    )
                else:
                    self._vae_sample_priorities[idxs] = (
                        self.vae_prioritization_func(
                            normalized_imgs,
                            idxs,
                            **self.priority_function_kwargs
                        ).reshape(-1, 1)
                    )
            obs_sum += self._obs[self.observation_key][idxs].sum(axis=0)
            obs_square_sum += np.power(self._obs[self.observation_key][idxs], 2).sum(axis=0)

            cur_idx = next_idx
            next_idx += batch_size
            next_idx = min(next_idx, self._size)

        self.vae.dist_mu = obs_sum / self._size
        self.vae.dist_std = np.sqrt(obs_square_sum / self._size - np.power(self.vae.dist_mu, 2))

        if self._prioritize_vae_samples:
            """
            priority^power is calculated in the priority function for
            image_bernoulli_prob or image_gaussian_inv_prob and directly here
            if not.
            """
            if self.vae_priority_type == 'vae_prob':
                self._vae_sample_priorities[:self._size] = relative_probs_from_log_probs(
                    self._vae_sample_priorities[:self._size]
                )
                self._vae_sample_probs = self._vae_sample_priorities[:self._size]
            else:
                self._vae_sample_probs = self._vae_sample_priorities[:self._size] ** self.power
            p_sum = np.sum(self._vae_sample_probs)
            assert p_sum > 0, "Unnormalized p sum is {}".format(p_sum)
            self._vae_sample_probs /= np.sum(self._vae_sample_probs)
            self._vae_sample_probs = self._vae_sample_probs.flatten()

    def sample_weighted_indices(self, batch_size):
        if (
                self._prioritize_vae_samples
                and self._vae_sample_probs is not None
                and self.skew
        ):
            indices = np.random.choice(
                len(self._vae_sample_probs),
                batch_size,
                p=self._vae_sample_probs,
            )
            assert (
                np.max(self._vae_sample_probs) <= 1
                and np.min(self._vae_sample_probs) >= 0
            )
        else:
            indices = self._sample_indices(batch_size)
        return indices

    def _sample_goals_from_env(self, batch_size):
        self.env.goal_sampling_mode = self._relabeling_goal_sampling_mode
        return self.env.sample_goals(batch_size)

    def sample_buffer_goals(self, batch_size):
        """
        Samples goals from weighted replay buffer for relabeling or exploration.
        Returns None if replay buffer is empty.

        Example of what might be returned:
        dict(
            image_desired_goals: image_achieved_goals[weighted_indices],
            latent_desired_goals: latent_desired_goals[weighted_indices],
        )
        """
        if self._size == 0:
            return None
        weighted_idxs = self.sample_weighted_indices(
            batch_size,
        )
        next_image_obs = self._next_obs[self.decoded_obs_key][weighted_idxs]
        next_latent_obs = self._next_obs[self.achieved_goal_key][weighted_idxs]
        return {
            self.decoded_desired_goal_key: next_image_obs,
            self.desired_goal_key: next_latent_obs
        }

    def random_vae_training_data(self, batch_size, epoch):
        # epoch no longer needed. Using self.skew in sample_weighted_indices
        # instead.
        weighted_idxs = self.sample_weighted_indices(
            batch_size,
        )
        next_image_obs = self._next_obs[self.decoded_obs_key][weighted_idxs]
        observations = ptu.from_numpy(next_image_obs)
        return dict(
            observations=observations,
        )

    def reconstruction_mse(self, next_vae_obs, indices):
        torch_input = ptu.from_numpy(next_vae_obs)
        recon_next_vae_obs, _, _ = self.vae(torch_input)
        error = torch_input - recon_next_vae_obs
        mse = torch.sum(error ** 2, dim=1)
        return ptu.get_numpy(mse)

    def gaussian_inv_prob(self, next_vae_obs, indices):
        return np.exp(self.reconstruction_mse(next_vae_obs, indices))

    def binary_cross_entropy(self, next_vae_obs, indices):
        torch_input = ptu.from_numpy(next_vae_obs)
        recon_next_vae_obs, _, _ = self.vae(torch_input)
        error = - torch_input * torch.log(
            torch.clamp(
                recon_next_vae_obs,
                min=1e-30,  # corresponds to about -70
            )
        )
        bce = torch.sum(error, dim=1)
        return ptu.get_numpy(bce)

    def bernoulli_inv_prob(self, next_vae_obs, indices):
        torch_input = ptu.from_numpy(next_vae_obs)
        recon_next_vae_obs, _, _ = self.vae(torch_input)
        prob = (
            torch_input * recon_next_vae_obs
            + (1 - torch_input) * (1 - recon_next_vae_obs)
        ).prod(dim=1)
        return ptu.get_numpy(1 / prob)

    def vae_prob(self, next_vae_obs, indices, **kwargs):
        return compute_p_x_np_to_np(
            self.vae,
            next_vae_obs,
            power=self.power,
            **kwargs
        )

    def forward_model_error(self, next_vae_obs, indices):
        obs = self._obs[self.observation_key][indices]
        next_obs = self._next_obs[self.observation_key][indices]
        actions = self._actions[indices]

        state_action_pair = ptu.from_numpy(np.c_[obs, actions])
        prediction = self.dynamics_model(state_action_pair)
        mse = self.dynamics_loss(prediction, ptu.from_numpy(next_obs))
        return ptu.get_numpy(mse)

    def latent_novelty(self, next_vae_obs, indices):
        distances = ((self.env._encode(next_vae_obs) - self.vae.dist_mu) /
                     self.vae.dist_std) ** 2
        return distances.sum(axis=1)

    def latent_novelty_true_prior(self, next_vae_obs, indices):
        distances = self.env._encode(next_vae_obs) ** 2
        return distances.sum(axis=1)

    def _kl_np_to_np(self, next_vae_obs, indices):
        torch_input = ptu.from_numpy(next_vae_obs)
        mu, log_var = self.vae.encode(torch_input)
        return ptu.get_numpy(
            - torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1)
        )

    def update_hash_count(self, next_vae_obs):
        torch_input = ptu.from_numpy(next_vae_obs)
        mus, log_vars = self.vae.encode(torch_input)
        mus = ptu.get_numpy(mus)
        self.exploration_counter.increment_counts(mus)
        return None

    def hash_count_reward(self, next_vae_obs, indices):
        obs = self.env._encode(next_vae_obs)
        return self.exploration_counter.compute_count_based_reward(obs)

    def no_reward(self, next_vae_obs, indices):
        return np.zeros((len(next_vae_obs), 1))

    def initialize_dynamics_model(self):
        obs_dim = self._obs[self.observation_key].shape[1]
        self.dynamics_model = Mlp(
            hidden_sizes=[128, 128],
            output_size=obs_dim,
            input_size=obs_dim + self._action_dim,
        )
        self.dynamics_model.to(ptu.device)
        self.dynamics_optimizer = Adam(self.dynamics_model.parameters())
        self.dynamics_loss = MSELoss()

    def train_dynamics_model(self, batches=50, batch_size=100):
        if not self.use_dynamics_model:
            return
        for _ in range(batches):
            indices = self._sample_indices(batch_size)
            self.dynamics_optimizer.zero_grad()
            obs = self._obs[self.observation_key][indices]
            next_obs = self._next_obs[self.observation_key][indices]
            actions = self._actions[indices]
            if self.exploration_rewards_type == 'inverse_model_error':
                obs, next_obs = next_obs, obs

            state_action_pair = ptu.from_numpy(np.c_[obs, actions])
            prediction = self.dynamics_model(state_action_pair)
            mse = self.dynamics_loss(prediction, ptu.from_numpy(next_obs))

            mse.backward()
            self.dynamics_optimizer.step()

    def log_loss_under_uniform(self, model, data, batch_size, rl_logger, priority_function_kwargs):
        import torch.nn.functional as F
        log_probs_prior = []
        log_probs_biased = []
        log_probs_importance = []
        kles = []
        mses = []
        for i in range(0, data.shape[0], batch_size):
            img = data[i:min(data.shape[0], i + batch_size), :]
            torch_img = ptu.from_numpy(img)
            reconstructions, obs_distribution_params, latent_distribution_params = self.vae(torch_img)

            priority_function_kwargs['sampling_method'] = 'true_prior_sampling'
            log_p, log_q, log_d = compute_log_p_log_q_log_d(model, img, **priority_function_kwargs)
            log_prob_prior = log_d.mean()

            priority_function_kwargs['sampling_method'] = 'biased_sampling'
            log_p, log_q, log_d = compute_log_p_log_q_log_d(model, img, **priority_function_kwargs)
            log_prob_biased = log_d.mean()

            priority_function_kwargs['sampling_method'] = 'importance_sampling'
            log_p, log_q, log_d = compute_log_p_log_q_log_d(model, img, **priority_function_kwargs)
            log_prob_importance = (log_p - log_q + log_d).mean()

            kle = model.kl_divergence(latent_distribution_params)
            mse = F.mse_loss(torch_img, reconstructions, reduction='elementwise_mean')
            mses.append(mse.item())
            kles.append(kle.item())
            log_probs_prior.append(log_prob_prior.item())
            log_probs_biased.append(log_prob_biased.item())
            log_probs_importance.append(log_prob_importance.item())

        rl_logger["Uniform Data Log Prob (Prior)"] = np.mean(log_probs_prior)
        rl_logger["Uniform Data Log Prob (Biased)"] = np.mean(log_probs_biased)
        rl_logger["Uniform Data Log Prob (Importance)"] = np.mean(log_probs_importance)
        rl_logger["Uniform Data KL"] = np.mean(kles)
        rl_logger["Uniform Data MSE"] = np.mean(mses)

    def _get_sorted_idx_and_train_weights(self):
        idx_and_weights = zip(range(len(self._vae_sample_probs)),
                              self._vae_sample_probs)
        return sorted(idx_and_weights, key=lambda x: x[1])
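# --- Illustrative sketch, not from the original code: the call order the buffer
# above appears to support during online VAE training. `buffer` and `epoch` are
# placeholders and the batch sizes are arbitrary.
def online_vae_buffer_epoch_sketch(buffer, epoch):
    # (re)fit the forward model used by the 'forward_model_error' reward, if any
    buffer.train_dynamics_model(batches=50, batch_size=100)
    # re-encode stored images, recompute exploration rewards and skew priorities
    buffer.refresh_latents(epoch)
    # skew-weighted images for VAE training, and an RL batch with bonus rewards
    vae_batch = buffer.random_vae_training_data(batch_size=64, epoch=epoch)
    rl_batch = buffer.random_batch(batch_size=128)
    return vae_batch, rl_batch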