def compute_log_p_log_q_log_d(
        model,
        batch,
        decoder_distribution='bernoulli',
        num_latents_to_sample=1,
        sampling_method='importance_sampling',
):
    x_0 = ptu.from_numpy(batch["x_0"])
    data = batch["x_t"]
    imgs = ptu.from_numpy(data)
    latent_distribution_params = model.encode(imgs, x_0)
    r1 = model.latent_sizes[0]
    batch_size = data.shape[0]
    log_p, log_q, log_d = (
        ptu.zeros((batch_size, num_latents_to_sample)),
        ptu.zeros((batch_size, num_latents_to_sample)),
        ptu.zeros((batch_size, num_latents_to_sample)),
    )
    true_prior = Normal(ptu.zeros((batch_size, r1)),
                        ptu.ones((batch_size, r1)))
    mus, logvars = latent_distribution_params[:2]
    for i in range(num_latents_to_sample):
        if sampling_method == 'importance_sampling':
            latents = model.rsample(latent_distribution_params[:2])
        elif sampling_method == 'biased_sampling':
            latents = model.rsample(latent_distribution_params[:2])
        elif sampling_method == 'true_prior_sampling':
            latents = true_prior.rsample()
        else:
            raise ValueError('Invalid Sampling Method Provided')
        stds = logvars.exp().pow(.5)
        vae_dist = Normal(mus, stds)
        log_p_z = true_prior.log_prob(latents).sum(dim=1)
        log_q_z_given_x = vae_dist.log_prob(latents).sum(dim=1)
        if len(latent_distribution_params) == 3:
            # add conditioning for CVAEs
            latents = torch.cat((latents, latent_distribution_params[2]),
                                dim=1)
        if decoder_distribution == 'bernoulli':
            decoded = model.decode(latents)[0]
            log_d_x_given_z = torch.log(
                imgs * decoded + (1 - imgs) * (1 - decoded) + 1e-8
            ).sum(dim=1)
        elif decoder_distribution == 'gaussian_identity_variance':
            _, obs_distribution_params = model.decode(latents)
            dec_mu, dec_logvar = obs_distribution_params
            dec_var = dec_logvar.exp()
            decoder_dist = Normal(dec_mu, dec_var.pow(.5))
            log_d_x_given_z = decoder_dist.log_prob(imgs).sum(dim=1)
        else:
            raise ValueError('Invalid Decoder Distribution Provided')
        log_p[:, i] = log_p_z
        log_q[:, i] = log_q_z_given_x
        log_d[:, i] = log_d_x_given_z
    return log_p, log_q, log_d
def compute_log_p_log_q_log_d(
        model,
        data,
        decoder_distribution='bernoulli',
        num_latents_to_sample=1,
        sampling_method='importance_sampling',
):
    assert data.dtype != np.uint8, 'images should be normalized'
    imgs = ptu.from_numpy(data)
    latent_distribution_params = model.encode(imgs)
    representation_size = model.representation_size
    batch_size = data.shape[0]
    log_p, log_q, log_d = (
        ptu.zeros((batch_size, num_latents_to_sample)),
        ptu.zeros((batch_size, num_latents_to_sample)),
        ptu.zeros((batch_size, num_latents_to_sample)),
    )
    true_prior = Normal(ptu.zeros((batch_size, representation_size)),
                        ptu.ones((batch_size, representation_size)))
    mus, logvars = latent_distribution_params
    for i in range(num_latents_to_sample):
        if sampling_method == 'importance_sampling':
            latents = model.rsample(latent_distribution_params)
        elif sampling_method == 'biased_sampling':
            latents = model.rsample(latent_distribution_params)
        elif sampling_method == 'true_prior_sampling':
            latents = true_prior.rsample()
        else:
            raise ValueError('Invalid Sampling Method Provided')
        stds = logvars.exp().pow(.5)
        vae_dist = Normal(mus, stds)
        log_p_z = true_prior.log_prob(latents).sum(dim=1)
        log_q_z_given_x = vae_dist.log_prob(latents).sum(dim=1)
        if decoder_distribution == 'bernoulli':
            decoded = model.decode(latents)[0]
            log_d_x_given_z = torch.log(
                imgs * decoded + (1 - imgs) * (1 - decoded) + 1e-8
            ).sum(dim=1)
        elif decoder_distribution == 'gaussian_identity_variance':
            _, obs_distribution_params = model.decode(latents)
            dec_mu, dec_logvar = obs_distribution_params
            dec_var = dec_logvar.exp()
            decoder_dist = Normal(dec_mu, dec_var.pow(.5))
            log_d_x_given_z = decoder_dist.log_prob(imgs).sum(dim=1)
        else:
            raise ValueError('Invalid Decoder Distribution Provided')
        log_p[:, i] = log_p_z
        log_q[:, i] = log_q_z_given_x
        log_d[:, i] = log_d_x_given_z
    return log_p, log_q, log_d
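# The log_p / log_q / log_d terms above are the ingredients for an
# importance-sampling estimate of the data log-likelihood,
#     log p(x) ~= logsumexp_i[ log d(x | z_i) + log p(z_i) - log q(z_i | x) ] - log K,
# with K = num_latents_to_sample. A minimal usage sketch (not part of the
# original code; it assumes `model`, `data`, `ptu`, `torch`, and `np` as used
# above):
def estimate_log_likelihood(model, data, num_latents_to_sample=10):
    log_p, log_q, log_d = compute_log_p_log_q_log_d(
        model,
        data,
        decoder_distribution='bernoulli',
        num_latents_to_sample=num_latents_to_sample,
        sampling_method='importance_sampling',
    )
    log_weights = log_d + log_p - log_q  # shape: (batch_size, K)
    log_likelihoods = (torch.logsumexp(log_weights, dim=1)
                       - np.log(num_latents_to_sample))
    return ptu.get_numpy(log_likelihoods).mean()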
def rsample(self):
    z = (
        self.normal_means
        + self.normal_stds * Normal(
            ptu.zeros(self.normal_means.size()),
            ptu.ones(self.normal_stds.size()),
        ).sample()
    )
    z.requires_grad_()
    c = self.categorical.sample()[:, :, None]
    s = torch.matmul(z, c)
    return torch.squeeze(s, 2)
def mean(self):
    """Return the mean of the most likely mixture component.

    c = argmax(weights); returns mu_c.

    Note: despite the name, this is not the mean of the full mixture. It
    often coincides with the mode of the distribution, but not always.
    """
    c = ptu.zeros(self.weights.shape[:2])
    ind = torch.argmax(self.weights, dim=1)  # [:, 0]
    c.scatter_(1, ind, 1)
    s = torch.matmul(self.normal_means, c[:, :, None])
    return torch.squeeze(s, 2)
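# A minimal shape sketch (illustrative only, not taken from the original
# class) of the one-hot selection trick that rsample() and mean() above rely
# on: multiplying a (batch, dim, num_components) tensor by a one-hot
# (batch, num_components, 1) vector extracts the selected component.
def _one_hot_selection_demo():
    batch, dim, num_components = 2, 3, 4
    z = torch.randn(batch, dim, num_components)
    idx = torch.randint(num_components, (batch,))
    c = torch.nn.functional.one_hot(idx, num_components).float()[:, :, None]
    selected = torch.squeeze(torch.matmul(z, c), 2)  # shape: (batch, dim)
    assert torch.allclose(selected[0], z[0, :, idx[0]])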
def compute_density(self, data):
    orig_data_length = len(data)
    data = np.vstack([data for _ in range(self.n_average)])
    data = ptu.from_numpy(data)
    if self.mode == 'biased':
        latents, means, log_vars, stds = (
            self.encoder.get_encoding_and_suff_stats(data)
        )
        importance_weights = ptu.ones(data.shape[0])
    elif self.mode == 'prior':
        latents = ptu.randn(len(data), self.z_dim)
        importance_weights = ptu.ones(data.shape[0])
    elif self.mode == 'importance_sampling':
        latents, means, log_vars, stds = (
            self.encoder.get_encoding_and_suff_stats(data)
        )
        prior = Normal(ptu.zeros(1), ptu.ones(1))
        prior_log_prob = prior.log_prob(latents).sum(dim=1)

        encoder_distrib = Normal(means, stds)
        encoder_log_prob = encoder_distrib.log_prob(latents).sum(dim=1)

        importance_weights = (prior_log_prob - encoder_log_prob).exp()
    else:
        raise NotImplementedError()
    unweighted_data_log_prob = self.compute_log_prob(
        data, self.decoder, latents
    ).squeeze(1)
    unweighted_data_prob = unweighted_data_log_prob.exp()
    unnormalized_data_prob = unweighted_data_prob * importance_weights
    """
    Average over `n_average`
    """
    dp_split = torch.split(unnormalized_data_prob, orig_data_length, dim=0)
    # pre_avg.shape = ORIG_LEN x N_AVERAGE
    dp_stacked = torch.stack(dp_split, dim=1)
    # final.shape = ORIG_LEN
    unnormalized_dp = torch.sum(dp_stacked, dim=1, keepdim=False)
    """
    Compute the importance weight denominators.
    This requires summing across the `n_average` dimension.
    """
    iw_split = torch.split(importance_weights, orig_data_length, dim=0)
    iw_stacked = torch.stack(iw_split, dim=1)
    iw_denominators = iw_stacked.sum(dim=1, keepdim=False)

    final = unnormalized_dp / iw_denominators
    return ptu.get_numpy(final)
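# For reference, a stripped-down sketch (hypothetical helper, not part of the
# original class) of the self-normalized importance-sampling estimate that
# compute_density implements per data point:
#     p(x) ~= sum_i w_i * p(x | z_i) / sum_i w_i,
# where w_i = p(z_i) / q(z_i | x) in 'importance_sampling' mode and w_i = 1
# otherwise.
def self_normalized_is_estimate(decoder_probs, importance_weights):
    # decoder_probs, importance_weights: tensors of shape (n_average,)
    return (decoder_probs * importance_weights).sum() / importance_weights.sum()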
def rsample(self, return_pretanh_value=False):
    """
    Sampling in the reparameterization case.
    """
    z = (
        self.normal_mean
        + self.normal_std * Normal(
            ptu.zeros(self.normal_mean.size()),
            ptu.ones(self.normal_std.size()),
        ).sample()
    )
    z.requires_grad_()
    if return_pretanh_value:
        return torch.tanh(z), z
    else:
        return torch.tanh(z)
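# A minimal sketch (assumed, following the standard tanh-Gaussian recipe; the
# original class's log_prob is not shown here) of the change-of-variables
# correction that usually accompanies this squashed sample: if z ~ N(mu, std)
# and a = tanh(z), then
#     log p(a) = log N(z; mu, std) - sum_j log(1 - tanh(z_j)^2).
def tanh_gaussian_log_prob(normal_dist, pre_tanh_value, epsilon=1e-6):
    action = torch.tanh(pre_tanh_value)
    correction = torch.log(1 - action ** 2 + epsilon)
    return (normal_dist.log_prob(pre_tanh_value) - correction).sum(dim=1)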
def __init__(
        self,
        env,
        dsp,
        policy,
        classifier,
        search_buffer,
        policy_lr=1e-3,
        classifier_lr=1e-3,
        optimizer_class=optim.Adam,
        use_automatic_entropy_tuning=True,
        target_entropy=None,
):
    super().__init__()
    self.env = env
    self.dsp = dsp
    self.policy = policy
    self.classifier = classifier
    self.search_buffer = search_buffer

    self.use_automatic_entropy_tuning = use_automatic_entropy_tuning
    if self.use_automatic_entropy_tuning:
        if target_entropy:
            self.target_entropy = target_entropy
        else:
            self.target_entropy = -np.prod(
                self.env.action_space.shape).item()
        self.log_alpha = ptu.zeros(1, requires_grad=True)
        self.alpha_optimizer = optimizer_class(
            [self.log_alpha],
            lr=policy_lr,
        )

    self.classifier_criterion = nn.MSELoss()
    self.dsp_optimizer = optimizer_class(
        self.dsp.parameters(),
        lr=policy_lr,
    )
    self.policy_optimizer = optimizer_class(
        self.policy.parameters(),
        lr=policy_lr,
    )
    self.classifier_optimizer = optimizer_class(
        self.classifier.parameters(),
        lr=classifier_lr,
    )

    self.eval_statistics = OrderedDict()
    self._n_train_steps_total = 0
    self._need_to_update_eval_statistics = True
def __init__(self,
             hidden_sizes,
             obs_dim,
             action_dim,
             std=None,
             init_w=1e-3,
             min_log_std=None,
             max_log_std=None,
             num_gaussians=1,
             std_architecture="shared",
             **kwargs):
    super().__init__(
        hidden_sizes,
        input_size=obs_dim,
        output_size=action_dim * num_gaussians,
        init_w=init_w,
        # output_activation=torch.tanh,
        **kwargs)
    self.action_dim = action_dim
    self.num_gaussians = num_gaussians
    self.min_log_std = min_log_std
    self.max_log_std = max_log_std
    self.log_std = None
    self.std = std
    self.std_architecture = std_architecture

    last_hidden_size = obs_dim
    if len(hidden_sizes) > 0:
        last_hidden_size = hidden_sizes[-1]

    if std is None:
        if self.std_architecture == "shared":
            self.last_fc_log_std = nn.Linear(last_hidden_size,
                                             action_dim * num_gaussians)
            self.last_fc_log_std.weight.data.uniform_(-init_w, init_w)
            self.last_fc_log_std.bias.data.uniform_(-init_w, init_w)
        elif self.std_architecture == "values":
            self.log_std_logits = nn.Parameter(
                ptu.zeros(action_dim * num_gaussians, requires_grad=True))
        else:
            raise ValueError(self.std_architecture)
    else:
        self.log_std = np.log(std)
        assert LOG_SIG_MIN <= self.log_std <= LOG_SIG_MAX

    self.last_fc_weights = nn.Linear(last_hidden_size, num_gaussians)
    self.last_fc_weights.weight.data.uniform_(-init_w, init_w)
    self.last_fc_weights.bias.data.uniform_(-init_w, init_w)
def __init__(
        self,
        downsample_size,
        input_channels,
        num_feat_points,
        temperature=1.0,
        init_w=1e-3,
        input_size=32,
        hidden_init=ptu.fanin_init,
        output_activation=identity,
):
    super().__init__()
    self.downsample_size = downsample_size
    self.temperature = temperature
    self.num_feat_points = num_feat_points
    self.hidden_init = hidden_init
    self.output_activation = output_activation
    self.input_channels = input_channels
    self.input_size = input_size

    # self.bn1 = nn.BatchNorm2d(1)
    self.conv1 = nn.Conv2d(input_channels, 48, kernel_size=5, stride=2)
    # self.bn1 = nn.BatchNorm2d(16)
    self.conv2 = nn.Conv2d(48, 48, kernel_size=5, stride=1)
    self.conv3 = nn.Conv2d(48, self.num_feat_points, kernel_size=5, stride=1)

    # Run a dummy forward pass through the conv stack to infer the flattened
    # output size.
    test_mat = ptu.zeros(1, self.input_channels, self.input_size,
                         self.input_size)
    test_mat = self.conv1(test_mat)
    test_mat = self.conv2(test_mat)
    test_mat = self.conv3(test_mat)
    self.out_size = int(np.prod(test_mat.shape))

    self.fc1 = nn.Linear(2 * self.num_feat_points, 400)
    self.fc2 = nn.Linear(400, 300)
    self.last_fc = nn.Linear(
        300,
        self.input_channels * self.downsample_size * self.downsample_size)

    self.init_weights(init_w)
    self.i = 0
def __init__(
        self,
        env,
        policy,
        qf1,
        qf2,
        target_qf1,
        target_qf2,
        discount=0.99,
        reward_scale=1.0,
        policy_lr=1e-3,
        qf_lr=1e-3,
        optimizer_class=optim.Adam,
        soft_target_tau=1e-2,
        target_update_period=1,
        plotter=None,
        render_eval_paths=False,
        use_automatic_entropy_tuning=True,
        target_entropy=None,
):
    super().__init__()
    self.env = env
    self.policy = policy
    self.qf1 = qf1
    self.qf2 = qf2
    self.target_qf1 = target_qf1
    self.target_qf2 = target_qf2
    self.soft_target_tau = soft_target_tau
    self.target_update_period = target_update_period

    self.use_automatic_entropy_tuning = use_automatic_entropy_tuning
    if self.use_automatic_entropy_tuning:
        if target_entropy:
            self.target_entropy = target_entropy
        else:
            # heuristic value from Tuomas
            self.target_entropy = -np.prod(
                self.env.action_space.shape).item()
        self.log_alpha = ptu.zeros(1, requires_grad=True)
        self.alpha_optimizer = optimizer_class(
            [self.log_alpha],
            lr=policy_lr,
        )

    self.plotter = plotter
    self.render_eval_paths = render_eval_paths

    self.qf_criterion = nn.MSELoss()
    self.vf_criterion = nn.MSELoss()

    self.policy_optimizer = optimizer_class(
        self.policy.parameters(),
        lr=policy_lr,
    )
    self.qf1_optimizer = optimizer_class(
        self.qf1.parameters(),
        lr=qf_lr,
    )
    self.qf2_optimizer = optimizer_class(
        self.qf2.parameters(),
        lr=qf_lr,
    )

    self.discount = discount
    self.reward_scale = reward_scale
    self.eval_statistics = OrderedDict()
    self._n_train_steps_total = 0
    self._need_to_update_eval_statistics = True
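# A minimal sketch (assumed, following the usual SAC recipe rather than the
# original trainer's train step; the method name is hypothetical) of how
# self.log_alpha and self.target_entropy set up above are typically used:
# alpha is adjusted so the policy's entropy tracks the target entropy, with
# `log_pi` the log-probability of freshly sampled actions.
def alpha_update_step(self, log_pi):
    alpha_loss = -(self.log_alpha
                   * (log_pi + self.target_entropy).detach()).mean()
    self.alpha_optimizer.zero_grad()
    alpha_loss.backward()
    self.alpha_optimizer.step()
    return self.log_alpha.exp()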
def __init__(
        self,
        env,
        policy,
        qf1,
        qf2,
        target_qf1,
        target_qf2,
        behavior_policy=None,
        dim_mult=1,
        discount=0.99,
        reward_scale=1.0,
        policy_lr=1e-3,
        qf_lr=1e-3,
        optimizer_class=optim.Adam,
        soft_target_tau=1e-2,
        target_update_period=1,
        plotter=None,
        render_eval_paths=False,
        use_automatic_entropy_tuning=True,
        target_entropy=None,
        use_target_nets=True,
        policy_eval_start=0,
        num_qs=2,

        ## For min_Q runs
        with_min_q=False,
        new_min_q=False,
        min_q_version=0,
        temp=1.0,
        hinge_bellman=False,
        use_projected_grad=False,
        normalize_magnitudes=False,
        regress_constant=False,
        min_q_weight=1.0,
        data_subtract=True,

        ## sort of backup
        max_q_backup=False,
        deterministic_backup=False,
        num_random=4,

        ## handle lagrange
        with_lagrange=False,
        lagrange_thresh=10.0,

        ## Handling discrete actions
        discrete=False,
        *args,
        **kwargs
):
    super().__init__()
    self.env = env
    self.policy = policy
    self.qf1 = qf1
    self.qf2 = qf2
    self.target_qf1 = target_qf1
    self.target_qf2 = target_qf2
    self.soft_target_tau = soft_target_tau
    self.target_update_period = target_update_period

    self.use_automatic_entropy_tuning = use_automatic_entropy_tuning
    if self.use_automatic_entropy_tuning:
        if target_entropy:
            self.target_entropy = target_entropy
        else:
            if self.env is None:
                self.target_entropy = -2
            else:
                # heuristic value from Tuomas
                self.target_entropy = -np.prod(
                    self.env.action_space.shape).item()
        self.log_alpha = ptu.zeros(1, requires_grad=True)
        self.alpha_optimizer = optimizer_class(
            [self.log_alpha],
            lr=policy_lr,
        )

    self.with_lagrange = with_lagrange
    if self.with_lagrange:
        self.target_action_gap = lagrange_thresh
        self.log_alpha_prime = ptu.zeros(1, requires_grad=True)
        self.alpha_prime_optimizer = optimizer_class(
            [self.log_alpha_prime],
            lr=qf_lr,
        )

    self.plotter = plotter
    self.render_eval_paths = render_eval_paths

    self.qf_criterion = nn.MSELoss()
    self.vf_criterion = nn.MSELoss()

    self.policy_optimizer = optimizer_class(
        self.policy.parameters(),
        lr=policy_lr,
    )
    self.qf1_optimizer = optimizer_class(
        self.qf1.parameters(),
        lr=qf_lr,
    )
    self.qf2_optimizer = optimizer_class(
        self.qf2.parameters(),
        lr=qf_lr,
    )

    self.discount = discount
    self.reward_scale = reward_scale
    self.eval_statistics = OrderedDict()
    self._n_train_steps_total = 0
    self._need_to_update_eval_statistics = True
    self._use_target_nets = use_target_nets
    self.policy_eval_start = policy_eval_start
    self._current_epoch = 0
    self._policy_update_ctr = 0
    self._num_q_update_steps = 0
    self._num_policy_update_steps = 0
    self._num_policy_steps = 1

    if not self._use_target_nets:
        self.target_qf1 = qf1
        self.target_qf2 = qf2

    self.softmax = torch.nn.Softmax(dim=-1)
    self.num_qs = num_qs

    ## min Q
    self.with_min_q = with_min_q
    self.new_min_q = new_min_q
    self.temp = temp
    self.min_q_version = min_q_version
    self.use_projected_grad = use_projected_grad
    self.normalize_magnitudes = normalize_magnitudes
    self.regress_constant = regress_constant
    self.min_q_weight = min_q_weight
    self.softmax = torch.nn.Softmax(dim=1)
    self.hinge_bellman = hinge_bellman
    self.softplus = torch.nn.Softplus(beta=self.temp, threshold=20)
    self.data_subtract = data_subtract

    self.max_q_backup = max_q_backup
    self.deterministic_backup = deterministic_backup
    self.num_random = num_random
    self.discrete = discrete
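# A minimal sketch (assumed, following the usual conservative min-Q recipe;
# the helper name and argument layout are hypothetical) of the penalty that
# min_q_weight, temp, and num_random above configure: push down a softmax
# (logsumexp) aggregate of Q over sampled actions while pushing up Q on the
# dataset actions.
def min_q_penalty(qf, obs, dataset_actions, random_actions, temp, min_q_weight):
    # random_actions: (batch, num_random, action_dim)
    batch, num_random, action_dim = random_actions.shape
    obs_rep = obs.unsqueeze(1).expand(-1, num_random, -1).reshape(
        batch * num_random, -1)
    q_rand = qf(obs_rep, random_actions.reshape(batch * num_random, action_dim))
    q_rand = q_rand.view(batch, num_random)
    logsumexp_q = temp * torch.logsumexp(q_rand / temp, dim=1, keepdim=True)
    q_data = qf(obs, dataset_actions)
    return min_q_weight * (logsumexp_q - q_data).mean()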
def __init__(
        self,
        env,
        policy,
        qf1,
        qf2,
        target_qf1,
        target_qf2,
        buffer_policy=None,
        discount=0.99,
        reward_scale=1.0,
        beta=1.0,
        beta_schedule_kwargs=None,
        policy_lr=1e-3,
        qf_lr=1e-3,
        policy_weight_decay=0,
        q_weight_decay=0,
        optimizer_class=optim.Adam,
        soft_target_tau=1e-2,
        target_update_period=1,
        plotter=None,
        render_eval_paths=False,
        use_automatic_entropy_tuning=True,
        target_entropy=None,
        bc_num_pretrain_steps=0,
        q_num_pretrain1_steps=0,
        q_num_pretrain2_steps=0,
        bc_batch_size=128,
        bc_loss_type="mle",
        awr_loss_type="mle",
        save_bc_policies=0,
        alpha=1.0,
        policy_update_period=1,
        q_update_period=1,
        weight_loss=True,
        compute_bc=True,
        bc_weight=0.0,
        rl_weight=1.0,
        use_awr_update=True,
        use_reparam_update=False,
        reparam_weight=1.0,
        awr_weight=1.0,
        post_pretrain_hyperparams=None,
        post_bc_pretrain_hyperparams=None,
        awr_use_mle_for_vf=False,
        awr_sample_actions=False,
        awr_min_q=False,
        reward_transform_class=None,
        reward_transform_kwargs=None,
        terminal_transform_class=None,
        terminal_transform_kwargs=None,
        pretraining_env_logging_period=100000,
        pretraining_logging_period=1000,
        do_pretrain_rollouts=False,
        train_bc_on_rl_buffer=False,
        use_automatic_beta_tuning=False,
        beta_epsilon=1e-10,
):
    super().__init__()
    self.env = env
    self.policy = policy
    self.qf1 = qf1
    self.qf2 = qf2
    self.target_qf1 = target_qf1
    self.target_qf2 = target_qf2
    self.buffer_policy = buffer_policy
    self.soft_target_tau = soft_target_tau
    self.target_update_period = target_update_period
    self.use_awr_update = use_awr_update

    self.use_automatic_entropy_tuning = use_automatic_entropy_tuning
    if self.use_automatic_entropy_tuning:
        if target_entropy:
            self.target_entropy = target_entropy
        else:
            # heuristic value from Tuomas
            self.target_entropy = -np.prod(
                self.env.action_space.shape).item()
        self.log_alpha = ptu.zeros(1, requires_grad=True)
        self.alpha_optimizer = optimizer_class(
            [self.log_alpha],
            lr=policy_lr,
        )

    self.awr_use_mle_for_vf = awr_use_mle_for_vf
    self.awr_sample_actions = awr_sample_actions
    self.awr_min_q = awr_min_q

    self.plotter = plotter
    self.render_eval_paths = render_eval_paths

    self.qf_criterion = nn.MSELoss()
    self.vf_criterion = nn.MSELoss()

    self.policy_optimizer = optimizer_class(
        self.policy.parameters(),
        weight_decay=policy_weight_decay,
        lr=policy_lr,
    )
    self.qf1_optimizer = optimizer_class(
        self.qf1.parameters(),
        weight_decay=q_weight_decay,
        lr=qf_lr,
    )
    self.qf2_optimizer = optimizer_class(
        self.qf2.parameters(),
        weight_decay=q_weight_decay,
        lr=qf_lr,
    )
    if buffer_policy and train_bc_on_rl_buffer:
        self.buffer_policy_optimizer = optimizer_class(
            self.buffer_policy.parameters(),
            weight_decay=policy_weight_decay,
            lr=policy_lr,
        )

    self.use_automatic_beta_tuning = (
        use_automatic_beta_tuning and buffer_policy and train_bc_on_rl_buffer)
    self.beta_epsilon = beta_epsilon
    if self.use_automatic_beta_tuning:
        self.log_beta = ptu.zeros(1, requires_grad=True)
        self.beta_optimizer = optimizer_class(
            [self.log_beta],
            lr=policy_lr,
        )
    else:
        self.beta = beta
        self.beta_schedule_kwargs = beta_schedule_kwargs
        if beta_schedule_kwargs is None:
            self.beta_schedule = ConstantSchedule(beta)
        else:
            schedule_class = beta_schedule_kwargs.pop(
                "schedule_class", PiecewiseLinearSchedule)
            self.beta_schedule = schedule_class(**beta_schedule_kwargs)

    self.discount = discount
    self.reward_scale = reward_scale
    self.eval_statistics = OrderedDict()
    self._n_train_steps_total = 0
    self._need_to_update_eval_statistics = True

    self.bc_num_pretrain_steps = bc_num_pretrain_steps
    self.q_num_pretrain1_steps = q_num_pretrain1_steps
    self.q_num_pretrain2_steps = q_num_pretrain2_steps
    self.bc_batch_size = bc_batch_size
    self.bc_loss_type = bc_loss_type
    self.awr_loss_type = awr_loss_type
    self.rl_weight = rl_weight
    self.bc_weight = bc_weight
    self.save_bc_policies = save_bc_policies
    self.eval_policy = MakeDeterministic(self.policy)
    self.compute_bc = compute_bc
    self.alpha = alpha
    self.q_update_period = q_update_period
    self.policy_update_period = policy_update_period
    self.weight_loss = weight_loss

    self.reparam_weight = reparam_weight
    self.awr_weight = awr_weight
    self.post_pretrain_hyperparams = post_pretrain_hyperparams
    self.post_bc_pretrain_hyperparams = post_bc_pretrain_hyperparams
    self.update_policy = True
    self.pretraining_env_logging_period = pretraining_env_logging_period
    self.pretraining_logging_period = pretraining_logging_period
    self.do_pretrain_rollouts = do_pretrain_rollouts

    self.reward_transform_class = reward_transform_class or LinearTransform
    self.reward_transform_kwargs = reward_transform_kwargs or dict(m=1, b=0)
    self.terminal_transform_class = terminal_transform_class or LinearTransform
    self.terminal_transform_kwargs = terminal_transform_kwargs or dict(m=1, b=0)
    self.reward_transform = self.reward_transform_class(
        **self.reward_transform_kwargs)
    self.terminal_transform = self.terminal_transform_class(
        **self.terminal_transform_kwargs)

    self.use_reparam_update = use_reparam_update
    self.train_bc_on_rl_buffer = train_bc_on_rl_buffer and buffer_policy
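# A minimal sketch (assumed, not the original training loop; the helper name
# is hypothetical) of the advantage-weighted policy objective that the beta /
# log_beta quantities above parameterize: buffer actions are reweighted by
# exp(A(s, a) / beta), where A(s, a) = Q(s, a) - V(s).
def awr_policy_loss(policy_log_prob, advantage, beta):
    # policy_log_prob: log pi(a | s) for buffer actions, shape (batch,)
    # advantage: Q(s, a) - V(s), shape (batch,)
    weights = torch.exp(advantage.detach() / beta)
    return -(weights * policy_log_prob).mean()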
def __init__(self,
             env,
             policy,
             qf1,
             qf2,
             vf,
             policy_lr=1e-3,
             qf_lr=1e-3,
             vf_lr=1e-3,
             policy_mean_reg_weight=1e-3,
             policy_std_reg_weight=1e-3,
             policy_pre_activation_weight=0.,
             optimizer_class=optim.Adam,
             train_policy_with_reparameterization=True,
             soft_target_tau=1e-3,
             policy_update_period=1,
             target_update_period=1,
             plotter=None,
             render_eval_paths=False,
             eval_deterministic=True,
             eval_policy=None,
             exploration_policy=None,
             use_automatic_entropy_tuning=True,
             target_entropy=None,
             **kwargs):
    if eval_policy is None:
        if eval_deterministic:
            eval_policy = MakeDeterministic(policy)
        else:
            eval_policy = policy
    super().__init__(
        env=env,
        exploration_policy=exploration_policy or policy,
        eval_policy=eval_policy,
        **kwargs
    )
    self.policy = policy
    self.qf1 = qf1
    self.qf2 = qf2
    self.vf = vf
    self.soft_target_tau = soft_target_tau
    self.policy_update_period = policy_update_period
    self.target_update_period = target_update_period
    self.policy_mean_reg_weight = policy_mean_reg_weight
    self.policy_std_reg_weight = policy_std_reg_weight
    self.policy_pre_activation_weight = policy_pre_activation_weight
    self.train_policy_with_reparameterization = (
        train_policy_with_reparameterization)

    self.use_automatic_entropy_tuning = use_automatic_entropy_tuning
    if self.use_automatic_entropy_tuning:
        if target_entropy:
            self.target_entropy = target_entropy
        else:
            # heuristic value from Tuomas
            self.target_entropy = -np.prod(
                self.env.action_space.shape).item()
        self.log_alpha = ptu.Variable(ptu.zeros(1), requires_grad=True)
        self.alpha_optimizer = optimizer_class(
            [self.log_alpha],
            lr=policy_lr,
        )

    self.plotter = plotter
    self.render_eval_paths = render_eval_paths

    self.target_vf = vf.copy()
    self.qf_criterion = nn.MSELoss()
    self.vf_criterion = nn.MSELoss()

    self.policy_optimizer = optimizer_class(
        self.policy.parameters(),
        lr=policy_lr,
    )
    self.qf1_optimizer = optimizer_class(
        self.qf1.parameters(),
        lr=qf_lr,
    )
    self.qf2_optimizer = optimizer_class(
        self.qf2.parameters(),
        lr=qf_lr,
    )
    self.vf_optimizer = optimizer_class(
        self.vf.parameters(),
        lr=vf_lr,
    )
def __init__(
        self,
        env,
        policy,
        qf1,
        qf2,
        target_qf1,
        target_qf2,
        vae,
        discount=0.99,
        reward_scale=1.0,
        policy_lr=1e-3,
        qf_lr=1e-3,
        optimizer_class=optim.Adam,
        soft_target_tau=1e-2,
        target_update_period=1,
        plotter=None,
        render_eval_paths=False,

        # BEAR specific params
        mode='auto',
        kernel_choice='laplacian',
        policy_update_style=0,
        mmd_sigma=10.0,
        target_mmd_thresh=0.05,
        num_samples_mmd_match=4,
        with_grad_penalty_v1=False,
        with_grad_penalty_v2=False,
        grad_coefficient_policy=0.0,
        grad_coefficient_q=0.0,
        use_target_nets=False,
        policy_update_delay=100,
        start_epoch_grad_penalty=0,
        num_steps_policy_update_only=50,
        bc_pretrain_steps=20000,
        target_update_method='default',
        use_adv_weighting=False,
        positive_reward=False,
):
    super().__init__()
    self.env = env
    self.policy = policy
    self.qf1 = qf1
    self.qf2 = qf2
    self.target_qf1 = target_qf1
    self.target_qf2 = target_qf2
    self.vae = vae
    self.soft_target_tau = soft_target_tau
    self.target_update_period = target_update_period
    self.plotter = plotter
    self.render_eval_paths = render_eval_paths

    self.qf_criterion = nn.MSELoss()
    self.vf_criterion = nn.MSELoss()

    self.policy_optimizer = optimizer_class(
        self.policy.parameters(),
        lr=policy_lr,
    )
    self.qf1_optimizer = optimizer_class(
        self.qf1.parameters(),
        lr=qf_lr,
    )
    self.qf2_optimizer = optimizer_class(
        self.qf2.parameters(),
        lr=qf_lr,
    )
    self.vae_optimizer = optimizer_class(
        self.vae.parameters(),
        lr=3e-4,
    )

    self.mode = mode
    if self.mode == 'auto':
        self.log_alpha = ptu.zeros(1, requires_grad=True)
        self.alpha_optimizer = optimizer_class(
            [self.log_alpha],
            lr=1e-3,
        )
    self.mmd_sigma = mmd_sigma
    self.kernel_choice = kernel_choice
    self.num_samples_mmd_match = num_samples_mmd_match
    self.policy_update_style = policy_update_style
    self.target_mmd_thresh = target_mmd_thresh

    self.discount = discount
    self.reward_scale = reward_scale
    self.eval_statistics = OrderedDict()
    self._n_train_steps_total = 0
    self._need_to_update_eval_statistics = True

    self._with_gradient_penalty_v1 = with_grad_penalty_v1
    self._with_gradient_penalty_v2 = with_grad_penalty_v2
    self._grad_coefficient_q = grad_coefficient_q
    self._grad_coefficient_policy = grad_coefficient_policy
    self._use_target_nets = use_target_nets
    self._policy_delay_update = policy_update_delay
    self._num_policy_steps = num_steps_policy_update_only
    self._start_epoch_grad_penalty = start_epoch_grad_penalty
    self._bc_pretrain_steps = bc_pretrain_steps
    self._target_update_method = target_update_method
    self._use_adv_weighting = use_adv_weighting
    self._positive_reward = positive_reward

    if self._target_update_method == 'distillation':
        self.target_qf1_opt = optimizer_class(self.target_qf1.parameters(),
                                              lr=qf_lr)
        self.target_qf2_opt = optimizer_class(self.target_qf2.parameters(),
                                              lr=qf_lr)

    self._current_epoch = 0
    self._policy_update_ctr = 0
    self._num_q_update_steps = 0
    self._num_policy_update_steps = 0

    if not self._use_target_nets:
        self.target_qf1 = qf1
        self.target_qf2 = qf2
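# A minimal sketch (assumed; the original MMD implementation is not shown
# here, and the helper name is hypothetical) of the sampled MMD penalty that
# kernel_choice='laplacian', mmd_sigma, and num_samples_mmd_match above refer
# to. samples_1 and samples_2 are action samples of shape
# (batch, num_samples, action_dim) from the current policy and the behavior
# (VAE) model respectively.
def mmd_loss_laplacian(samples_1, samples_2, sigma=10.0):
    def kernel(x, y):
        diff = x.unsqueeze(2) - y.unsqueeze(1)           # (B, n, m, d)
        return torch.exp(-diff.abs().sum(-1) / sigma)    # (B, n, m)
    k_xx = kernel(samples_1, samples_1).mean(dim=-1).mean(dim=-1)
    k_yy = kernel(samples_2, samples_2).mean(dim=-1).mean(dim=-1)
    k_xy = kernel(samples_1, samples_2).mean(dim=-1).mean(dim=-1)
    return (k_xx + k_yy - 2 * k_xy).clamp(min=1e-6).sqrt()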