def set_flat_params(model, flat_params, trainable_only=True):
    """Copy values from the 1-D array `flat_params` into `model`'s parameters."""
    idx = 0
    for p in model.parameters():
        flat_shape = int(np.prod(list(p.data.shape)))
        flat_params_to_assign = flat_params[idx:idx + flat_shape]
        if len(p.data.shape):
            p.data = ptu.tensor(flat_params_to_assign.reshape(*p.data.shape))
        else:
            # Scalar parameter: assign the single value directly.
            p.data = ptu.tensor(flat_params_to_assign[0])
        idx += flat_shape
    return model
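# A minimal sketch of the inverse operation, assuming the same flat layout as
# `set_flat_params` above. `get_flat_params` is not part of this file; it is a
# hypothetical helper shown so that the round-trip
# `set_flat_params(model, get_flat_params(model))` leaves the model unchanged.
def get_flat_params(model):
    return np.concatenate(
        [p.data.cpu().numpy().reshape(-1) for p in model.parameters()]
    )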
def compute_world_model_loss(
    world_model,
    image_shape,
    image_dist,
    prior,
    post,
    prior_dist,
    post_dist,
    obs,
    forward_kl,
    free_nats,
    transition_loss_scale,
    kl_loss_scale,
    image_loss_scale,
):
    preprocessed_obs = world_model.flatten_obs(
        world_model.preprocess(obs), image_shape)
    image_pred_loss = -1 * image_dist.log_prob(preprocessed_obs).mean()
    post_detached_dist = world_model.get_detached_dist(post)
    prior_detached_dist = world_model.get_detached_dist(prior)
    if forward_kl:
        div = kld(post_dist, prior_dist).mean()
        div = torch.max(div, ptu.tensor(free_nats))
        prior_kld = kld(post_detached_dist, prior_dist).mean()
        post_kld = kld(post_dist, prior_detached_dist).mean()
    else:
        div = kld(prior_dist, post_dist).mean()
        div = torch.max(div, ptu.tensor(free_nats))
        prior_kld = kld(prior_dist, post_detached_dist).mean()
        post_kld = kld(prior_detached_dist, post_dist).mean()
    # KL balancing: clip each one-sided KL at `free_nats`, then split the
    # remaining (1 - kl_loss_scale) weight between the transition (prior) and
    # entropy (posterior) terms.
    transition_loss = torch.max(prior_kld, ptu.tensor(free_nats))
    entropy_loss = torch.max(post_kld, ptu.tensor(free_nats))
    entropy_loss_scale = 1 - transition_loss_scale
    entropy_loss_scale = (1 - kl_loss_scale) * entropy_loss_scale
    transition_loss_scale = (1 - kl_loss_scale) * transition_loss_scale
    world_model_loss = (
        kl_loss_scale * div
        + image_loss_scale * image_pred_loss
        + transition_loss_scale * transition_loss
        + entropy_loss_scale * entropy_loss
    )
    return world_model_loss, div, image_pred_loss, transition_loss, entropy_loss
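# Sketch of how the loss scales above combine (helper name and values are
# illustrative, not from this repo). With kl_loss_scale=0.0 and
# transition_loss_scale=0.8, the balanced KL puts 0.8 of the weight on the
# transition term, ~0.2 on the entropy term, and the raw `div` term drops out.
def _kl_balance_weights(kl_loss_scale, transition_loss_scale):
    entropy_scale = (1 - kl_loss_scale) * (1 - transition_loss_scale)
    transition_scale = (1 - kl_loss_scale) * transition_loss_scale
    return transition_scale, entropy_scale

# _kl_balance_weights(0.0, 0.8) -> (0.8, ~0.2)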
def __init__(
    self,
    hidden_size,
    obs_dim,
    num_layers=4,
    discrete_continuous_dist=False,
    discrete_action_dim=0,
    continuous_action_dim=0,
    hidden_activation=F.elu,
    min_std=0.1,
    init_std=0.0,
    mean_scale=5.0,
    use_tanh_normal=True,
    dist="trunc_normal",
    **kwargs,
):
    self.discrete_continuous_dist = discrete_continuous_dist
    self.discrete_action_dim = discrete_action_dim
    self.continuous_action_dim = continuous_action_dim
    if self.discrete_continuous_dist:
        self.output_size = self.discrete_action_dim + self.continuous_action_dim * 2
    else:
        self.output_size = self.continuous_action_dim * 2
    super().__init__(
        [hidden_size] * num_layers,
        input_size=obs_dim,
        output_size=self.output_size,
        hidden_activation=hidden_activation,
        hidden_init=torch.nn.init.xavier_uniform_,
        **kwargs,
    )
    self._min_std = min_std
    self._mean_scale = mean_scale
    self.use_tanh_normal = use_tanh_normal
    self._dist = dist
    self.raw_init_std = torch.log(torch.exp(ptu.tensor(init_std)) - 1)
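# Worked example of the head sizing above (illustrative numbers, not from this
# repo): with discrete_continuous_dist=True, discrete_action_dim=6 and
# continuous_action_dim=3, the network needs 6 logits for the discrete part
# plus (mean, std) per continuous dimension, so output_size = 6 + 3 * 2 = 12.
# `raw_init_std` is the inverse softplus of init_std, so that
# softplus(raw_init_std) == init_std at initialization.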
def __init__(
    self,
    env,
    context_graph,
    qf1,
    target_qf1,
    policy_n,
    cactor,
    qf2,
    target_qf2,
    deterministic_cactor_in_graph=True,
    deterministic_next_action=False,
    use_entropy_loss=True,
    use_entropy_reward=True,
    sum_n_loss=False,  # use sum instead of mean for the n agent losses
    use_cactor_entropy_loss=True,
    use_automatic_entropy_tuning=True,
    state_dependent_alpha=False,
    target_entropy=None,
    negative_sampling=False,
    discount=0.99,
    reward_scale=1.0,
    policy_learning_rate=1e-4,
    context_graph_learning_rate=1e-3,
    qf_learning_rate=1e-3,  # not used
    qf_weight_decay=0.,
    init_alpha=1.,
    cactor_learning_rate=1e-4,
    target_hard_update_period=1000,
    tau=1e-2,
    use_soft_update=False,
    qf_criterion=None,
    pre_activation_weight=0.,
    optimizer_class=optim.Adam,
    min_q_value=-np.inf,
    max_q_value=np.inf,
    context_graph_optimizer=None,
    cactor_optimizer=None,
    policy_optimizer_n=None,
    alpha_optimizer_n=None,
    calpha_optimizer_n=None,
    log_alpha_n=None,
    log_calpha_n=None,
):
    super().__init__()
    self.env = env
    if qf_criterion is None:
        qf_criterion = nn.MSELoss()
    self.context_graph = context_graph
    self.qf1 = qf1
    self.target_qf1 = target_qf1
    self.qf2 = qf2
    self.target_qf2 = target_qf2
    self.policy_n = policy_n
    self.cactor = cactor
    self.deterministic_cactor_in_graph = deterministic_cactor_in_graph
    self.deterministic_next_action = deterministic_next_action
    self.sum_n_loss = sum_n_loss
    self.negative_sampling = negative_sampling
    self.discount = discount
    self.reward_scale = reward_scale
    self.policy_learning_rate = policy_learning_rate
    self.context_graph_learning_rate = context_graph_learning_rate
    self.qf_learning_rate = qf_learning_rate
    self.qf_weight_decay = qf_weight_decay
    self.cactor_learning_rate = cactor_learning_rate
    self.target_hard_update_period = target_hard_update_period
    self.tau = tau
    self.use_soft_update = use_soft_update
    self.qf_criterion = qf_criterion
    self.pre_activation_weight = pre_activation_weight
    self.min_q_value = min_q_value
    self.max_q_value = max_q_value

    if context_graph_optimizer:
        self.context_graph_optimizer = context_graph_optimizer
    else:
        # The context graph and both critics share a single optimizer.
        self.context_graph_optimizer = optimizer_class(
            list(self.context_graph.parameters())
            + list(self.qf1.parameters())
            + list(self.qf2.parameters()),
            lr=self.context_graph_learning_rate,
        )
    if policy_optimizer_n:
        self.policy_optimizer_n = policy_optimizer_n
    else:
        self.policy_optimizer_n = [
            optimizer_class(
                self.policy_n[i].parameters(),
                lr=self.policy_learning_rate,
            ) for i in range(len(self.policy_n))]
    if cactor_optimizer:
        self.cactor_optimizer = cactor_optimizer
    else:
        self.cactor_optimizer = optimizer_class(
            self.cactor.parameters(),
            lr=self.cactor_learning_rate,
        )

    self.init_alpha = init_alpha
    self.use_entropy_loss = use_entropy_loss
    self.use_entropy_reward = use_entropy_reward
    self.use_cactor_entropy_loss = use_cactor_entropy_loss
    self.use_automatic_entropy_tuning = use_automatic_entropy_tuning
    self.state_dependent_alpha = state_dependent_alpha
    if self.use_automatic_entropy_tuning:
        if target_entropy:
            self.target_entropy = target_entropy
        else:
            # Heuristic value from Tuomas Haarnoja: -dim(action space).
            self.target_entropy = -np.prod(self.env.action_space.shape).item()
        if self.use_entropy_loss:
            if log_alpha_n:
                self.log_alpha_n = log_alpha_n
            else:
                self.log_alpha_n = [
                    ptu.tensor([np.log(self.init_alpha)],
                               requires_grad=True,
                               dtype=torch.float32)
                    for i in range(len(self.policy_n))]
            if alpha_optimizer_n:
                self.alpha_optimizer_n = alpha_optimizer_n
            else:
                if self.state_dependent_alpha:
                    # log_alpha is a network when alpha is state-dependent.
                    self.alpha_optimizer_n = [
                        optimizer_class(
                            self.log_alpha_n[i].parameters(),
                            lr=self.policy_learning_rate,
                        ) for i in range(len(self.log_alpha_n))]
                else:
                    self.alpha_optimizer_n = [
                        optimizer_class(
                            [self.log_alpha_n[i]],
                            lr=self.policy_learning_rate,
                        ) for i in range(len(self.log_alpha_n))]
        if self.use_cactor_entropy_loss:
            if log_calpha_n:
                self.log_calpha_n = log_calpha_n
            else:
                self.log_calpha_n = [
                    ptu.tensor([np.log(self.init_alpha)],
                               requires_grad=True,
                               dtype=torch.float32)
                    for i in range(len(self.policy_n))]
            if calpha_optimizer_n:
                self.calpha_optimizer_n = calpha_optimizer_n
            else:
                if self.state_dependent_alpha:
                    self.calpha_optimizer_n = [
                        optimizer_class(
                            self.log_calpha_n[i].parameters(),
                            lr=self.policy_learning_rate,
                        ) for i in range(len(self.log_calpha_n))]
                else:
                    self.calpha_optimizer_n = [
                        optimizer_class(
                            [self.log_calpha_n[i]],
                            lr=self.policy_learning_rate,
                        ) for i in range(len(self.log_calpha_n))]

    self.eval_statistics = OrderedDict()
    self._n_train_steps_total = 0
    self._need_to_update_eval_statistics = True
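# Sketch of the target-entropy heuristic used above (the standard SAC default;
# the helper name is hypothetical): the target is -dim(A) nats, so a Box
# action space of shape (3,) gives default_target_entropy((3,)) == -3.0.
def default_target_entropy(action_shape):
    return -np.prod(action_shape).item()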
def __init__(
    self,
    env,
    cg1,
    target_cg1,
    qf1_n,
    target_qf1_n,
    cg2,
    target_cg2,
    qf2_n,
    target_qf2_n,
    cgca,
    cactor_n,
    policy_n,
    deterministic_cactor_in_graph=True,
    deterministic_next_action=False,
    use_entropy_loss=True,
    use_entropy_reward=True,
    use_cactor_entropy_loss=True,
    use_automatic_entropy_tuning=True,
    target_entropy=None,
    discount=0.99,
    reward_scale=1.0,
    policy_learning_rate=1e-4,
    qf_learning_rate=1e-3,
    qf_weight_decay=0.,
    init_alpha=1.,
    cactor_learning_rate=1e-4,
    target_hard_update_period=1000,
    tau=1e-2,
    use_soft_update=False,
    qf_criterion=None,
    pre_activation_weight=0.,
    optimizer_class=optim.Adam,
    min_q_value=-np.inf,
    max_q_value=np.inf,
    qf1_optimizer=None,
    qf2_optimizer=None,
    cactor_optimizer=None,
    policy_optimizer_n=None,
    alpha_optimizer_n=None,
    calpha_optimizer=None,
    log_alpha_n=None,
    log_calpha_n=None,
):
    super().__init__()
    self.env = env
    if qf_criterion is None:
        qf_criterion = nn.MSELoss()
    self.cg1 = cg1
    self.target_cg1 = target_cg1
    self.qf1_n = qf1_n
    self.target_qf1_n = target_qf1_n
    self.cg2 = cg2
    self.target_cg2 = target_cg2
    self.qf2_n = qf2_n
    self.target_qf2_n = target_qf2_n
    self.cgca = cgca
    self.cactor_n = cactor_n
    self.policy_n = policy_n
    self.deterministic_cactor_in_graph = deterministic_cactor_in_graph
    self.deterministic_next_action = deterministic_next_action
    self.discount = discount
    self.reward_scale = reward_scale
    self.policy_learning_rate = policy_learning_rate
    self.qf_learning_rate = qf_learning_rate
    self.qf_weight_decay = qf_weight_decay
    self.cactor_learning_rate = cactor_learning_rate
    self.target_hard_update_period = target_hard_update_period
    self.tau = tau
    self.use_soft_update = use_soft_update
    self.qf_criterion = qf_criterion
    self.pre_activation_weight = pre_activation_weight
    self.min_q_value = min_q_value
    self.max_q_value = max_q_value

    if qf1_optimizer:
        self.qf1_optimizer = qf1_optimizer
    else:
        # One optimizer over the first context graph and all per-agent Q1 heads.
        qf1_parameters = list(self.cg1.parameters())
        for qf1 in qf1_n:
            qf1_parameters += list(qf1.parameters())
        self.qf1_optimizer = optimizer_class(
            qf1_parameters,
            lr=self.qf_learning_rate,
        )
    if qf2_optimizer:
        self.qf2_optimizer = qf2_optimizer
    else:
        qf2_parameters = list(self.cg2.parameters())
        for qf2 in qf2_n:
            qf2_parameters += list(qf2.parameters())
        self.qf2_optimizer = optimizer_class(
            qf2_parameters,
            lr=self.qf_learning_rate,
        )
    if policy_optimizer_n:
        self.policy_optimizer_n = policy_optimizer_n
    else:
        self.policy_optimizer_n = [
            optimizer_class(
                self.policy_n[i].parameters(),
                lr=self.policy_learning_rate,
            ) for i in range(len(self.policy_n))]
    if cactor_optimizer:
        self.cactor_optimizer = cactor_optimizer
    else:
        cactor_parameters = list(self.cgca.parameters())
        for cactor in cactor_n:
            cactor_parameters += list(cactor.parameters())
        self.cactor_optimizer = optimizer_class(
            cactor_parameters,
            lr=self.cactor_learning_rate,
        )

    self.init_alpha = init_alpha
    self.use_entropy_loss = use_entropy_loss
    self.use_entropy_reward = use_entropy_reward
    self.use_cactor_entropy_loss = use_cactor_entropy_loss
    self.use_automatic_entropy_tuning = use_automatic_entropy_tuning
    if self.use_automatic_entropy_tuning:
        if target_entropy:
            self.target_entropy = target_entropy
        else:
            # Heuristic value from Tuomas Haarnoja: -dim(action space).
            self.target_entropy = -np.prod(self.env.action_space.shape).item()
        if self.use_entropy_loss:
            if log_alpha_n:
                self.log_alpha_n = log_alpha_n
            else:
                self.log_alpha_n = [
                    ptu.tensor([np.log(self.init_alpha)],
                               requires_grad=True,
                               dtype=torch.float32)
                    for i in range(len(self.policy_n))]
            if alpha_optimizer_n:
                self.alpha_optimizer_n = alpha_optimizer_n
            else:
                self.alpha_optimizer_n = [
                    optimizer_class(
                        [self.log_alpha_n[i]],
                        lr=self.policy_learning_rate,
                    ) for i in range(len(self.log_alpha_n))]
        if self.use_cactor_entropy_loss:
            if log_calpha_n:
                self.log_calpha_n = log_calpha_n
            else:
                self.log_calpha_n = [
                    ptu.tensor([np.log(self.init_alpha)],
                               requires_grad=True,
                               dtype=torch.float32)
                    for i in range(len(self.policy_n))]
            if calpha_optimizer:
                self.calpha_optimizer = calpha_optimizer
            else:
                # A single optimizer over all cactor temperatures.
                self.calpha_optimizer = optimizer_class(
                    self.log_calpha_n,
                    lr=self.policy_learning_rate,
                )

    self.eval_statistics = OrderedDict()
    self._n_train_steps_total = 0
    self._need_to_update_eval_statistics = True
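# Sketch of the two target-update modes configured above, assuming rlkit-style
# helpers (`ptu.soft_update_from_to` and `ptu.copy_model_params_from_to`);
# `trainer` is a hypothetical stand-in for the object built by this __init__.
# Soft updates Polyak-average every step; hard updates copy weights every
# `target_hard_update_period` steps.
def _update_target(trainer, net, target_net):
    if trainer.use_soft_update:
        # target <- tau * source + (1 - tau) * target
        ptu.soft_update_from_to(net, target_net, trainer.tau)
    elif trainer._n_train_steps_total % trainer.target_hard_update_period == 0:
        ptu.copy_model_params_from_to(net, target_net)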
def train_from_torch(
    self,
    batch,
    train=True,
    pretrain=False,
):
    """Run one AWAC-style training step on a batch of transitions."""
    rewards = batch['rewards']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']
    weights = batch.get('weights', None)
    if self.reward_transform:
        rewards = self.reward_transform(rewards)
    if self.terminal_transform:
        terminals = self.terminal_transform(terminals)

    """
    Policy and Alpha Loss
    """
    dist = self.policy(obs)
    new_obs_actions, log_pi = dist.rsample_and_logprob()
    policy_mle = dist.mle_estimate()

    if self.brac:
        buf_dist = self.buffer_policy(obs)
        buf_log_pi = buf_dist.log_prob(actions)
        rewards = rewards + buf_log_pi

    if self.use_automatic_entropy_tuning:
        alpha_loss = -(self.log_alpha *
                       (log_pi + self.target_entropy).detach()).mean()
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()
        alpha = self.log_alpha.exp()
    else:
        alpha_loss = 0
        alpha = self.alpha

    """
    QF Loss
    """
    q1_pred = self.qf1(obs, actions)
    q2_pred = self.qf2(obs, actions)
    # Make sure policy accounts for squashing functions like tanh correctly!
    next_dist = self.policy(next_obs)
    new_next_actions, new_log_pi = next_dist.rsample_and_logprob()
    target_q_values = torch.min(
        self.target_qf1(next_obs, new_next_actions),
        self.target_qf2(next_obs, new_next_actions),
    ) - alpha * new_log_pi

    q_target = self.reward_scale * rewards + \
        (1. - terminals) * self.discount * target_q_values
    qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
    qf2_loss = self.qf_criterion(q2_pred, q_target.detach())

    """
    Policy Loss
    """
    qf1_new_actions = self.qf1(obs, new_obs_actions)
    qf2_new_actions = self.qf2(obs, new_obs_actions)
    q_new_actions = torch.min(
        qf1_new_actions,
        qf2_new_actions,
    )

    # Advantage-weighted regression
    if self.awr_use_mle_for_vf:
        v1_pi = self.qf1(obs, policy_mle)
        v2_pi = self.qf2(obs, policy_mle)
        v_pi = torch.min(v1_pi, v2_pi)
    else:
        if self.vf_K > 1:
            # Monte Carlo estimate of V(s) using vf_K sampled actions.
            vs = []
            for i in range(self.vf_K):
                u = dist.sample()
                q1 = self.qf1(obs, u)
                q2 = self.qf2(obs, u)
                v = torch.min(q1, q2)
                vs.append(v)
            v_pi = torch.cat(vs, 1).mean(dim=1)
        else:
            v1_pi = self.qf1(obs, new_obs_actions)
            v2_pi = self.qf2(obs, new_obs_actions)
            v_pi = torch.min(v1_pi, v2_pi)

    if self.awr_sample_actions:
        u = new_obs_actions
        if self.awr_min_q:
            q_adv = q_new_actions
        else:
            q_adv = qf1_new_actions
    elif self.buffer_policy_sample_actions:
        buf_dist = self.buffer_policy(obs)
        u, _ = buf_dist.rsample_and_logprob()
        qf1_buffer_actions = self.qf1(obs, u)
        qf2_buffer_actions = self.qf2(obs, u)
        q_buffer_actions = torch.min(
            qf1_buffer_actions,
            qf2_buffer_actions,
        )
        if self.awr_min_q:
            q_adv = q_buffer_actions
        else:
            q_adv = qf1_buffer_actions
    else:
        u = actions
        if self.awr_min_q:
            q_adv = torch.min(q1_pred, q2_pred)
        else:
            q_adv = q1_pred

    policy_logpp = dist.log_prob(u)

    if self.use_automatic_beta_tuning:
        buffer_dist = self.buffer_policy(obs)
        beta = self.log_beta.exp()
        kldiv = torch.distributions.kl.kl_divergence(dist, buffer_dist)
        beta_loss = -1 * (beta * (kldiv - self.beta_epsilon).detach()).mean()
        self.beta_optimizer.zero_grad()
        beta_loss.backward()
        self.beta_optimizer.step()
    else:
        beta = self.beta_schedule.get_value(self._n_train_steps_total)

    if self.normalize_over_state == "advantage":
        score = q_adv - v_pi
        if self.mask_positive_advantage:
            score = torch.sign(score)
    elif self.normalize_over_state == "Z":
        # Estimate the partition function Z from Z_K buffer-policy samples,
        # then normalize the advantage in log space.
        buffer_dist = self.buffer_policy(obs)
        K = self.Z_K
        buffer_obs = []
        buffer_actions = []
        log_bs = []
        log_pis = []
        for i in range(K):
            u = buffer_dist.sample()
            log_b = buffer_dist.log_prob(u)
            log_pi = dist.log_prob(u)
            buffer_obs.append(obs)
            buffer_actions.append(u)
            log_bs.append(log_b)
            log_pis.append(log_pi)
        buffer_obs = torch.cat(buffer_obs, 0)
        buffer_actions = torch.cat(buffer_actions, 0)
        p_buffer = torch.exp(torch.cat(log_bs, 0).sum(dim=1))
        log_pi = torch.cat(log_pis, 0)
        log_pi = log_pi.sum(dim=1)
        q1_b = self.qf1(buffer_obs, buffer_actions)
        q2_b = self.qf2(buffer_obs, buffer_actions)
        q_b = torch.min(q1_b, q2_b)
        q_b = torch.reshape(q_b, (-1, K))
        adv_b = q_b - v_pi
        logK = torch.log(ptu.tensor(float(K)))
        logZ = torch.logsumexp(adv_b / beta - logK, dim=1, keepdim=True)
        logS = (q_adv - v_pi) / beta - logZ
        score = F.softmax(logS, dim=0)
    else:
        raise ValueError(
            "Invalid normalize_over_state: {}".format(
                self.normalize_over_state))

    if self.clip_score is not None:
        score = torch.clamp(score, max=self.clip_score)

    if self.weight_loss and weights is None:
        # NOTE: compare against True explicitly; the string options below are
        # also truthy, so a bare `if self.normalize_over_batch:` would shadow
        # every elif branch.
        if self.normalize_over_batch == True:
            weights = F.softmax(score / beta, dim=0)
        elif self.normalize_over_batch == "whiten":
            adv_mean = torch.mean(score)
            adv_std = torch.std(score) + 1e-5
            normalized_score = (score - adv_mean) / adv_std
            weights = torch.exp(normalized_score / beta)
        elif self.normalize_over_batch == "exp":
            weights = torch.exp(score / beta)
        elif self.normalize_over_batch == "step_fn":
            weights = (score > 0).float()
        elif not self.normalize_over_batch:
            weights = score
        else:
            raise ValueError(
                "Invalid normalize_over_batch: {}".format(
                    self.normalize_over_batch))
    weights = weights[:, 0]

    policy_loss = alpha * log_pi.mean()

    if self.use_awr_update and self.weight_loss:
        policy_loss = policy_loss + self.awr_weight * (
            -policy_logpp * len(weights) * weights.detach()).mean()
    elif self.use_awr_update:
        policy_loss = policy_loss + self.awr_weight * (-policy_logpp).mean()

    if self.use_reparam_update:
        policy_loss = policy_loss + self.reparam_weight * (
            -q_new_actions).mean()

    policy_loss = self.rl_weight * policy_loss
    if self.compute_bc:
        train_policy_loss, train_logp_loss, train_mse_loss, _ = \
            self.run_bc_batch(self.demo_train_buffer, self.policy)
        policy_loss = policy_loss + self.bc_weight * train_policy_loss

    if not pretrain and self.buffer_policy_reset_period > 0 and \
            self._n_train_steps_total % self.buffer_policy_reset_period == 0:
        del self.buffer_policy_optimizer
        self.buffer_policy_optimizer = self.optimizer_class(
            self.buffer_policy.parameters(),
            weight_decay=self.policy_weight_decay,
            lr=self.policy_lr,
        )
        self.optimizers[self.buffer_policy] = self.buffer_policy_optimizer
        for i in range(self.num_buffer_policy_train_steps_on_reset):
            if self.train_bc_on_rl_buffer:
                if self.advantage_weighted_buffer_loss:
                    buffer_dist = self.buffer_policy(obs)
                    buffer_u = actions
                    buffer_new_obs_actions, _ = \
                        buffer_dist.rsample_and_logprob()
                    buffer_policy_logpp = buffer_dist.log_prob(buffer_u)
                    buffer_policy_logpp = buffer_policy_logpp[:, None]
                    buffer_q1_pred = self.qf1(obs, buffer_u)
                    buffer_q2_pred = self.qf2(obs, buffer_u)
                    buffer_q_adv = torch.min(buffer_q1_pred, buffer_q2_pred)
                    buffer_v1_pi = self.qf1(obs, buffer_new_obs_actions)
                    buffer_v2_pi = self.qf2(obs, buffer_new_obs_actions)
                    buffer_v_pi = torch.min(buffer_v1_pi, buffer_v2_pi)
                    buffer_score = buffer_q_adv - buffer_v_pi
                    buffer_weights = F.softmax(buffer_score / beta, dim=0)
                    buffer_policy_loss = self.awr_weight * (
                        -buffer_policy_logpp * len(buffer_weights) *
                        buffer_weights.detach()).mean()
                else:
                    buffer_policy_loss, buffer_train_logp_loss, \
                        buffer_train_mse_loss, _ = self.run_bc_batch(
                            self.replay_buffer.train_replay_buffer,
                            self.buffer_policy)
                self.buffer_policy_optimizer.zero_grad()
                buffer_policy_loss.backward(retain_graph=True)
                self.buffer_policy_optimizer.step()

    if self.train_bc_on_rl_buffer:
        if self.advantage_weighted_buffer_loss:
            buffer_dist = self.buffer_policy(obs)
            buffer_u = actions
            buffer_new_obs_actions, _ = buffer_dist.rsample_and_logprob()
            buffer_policy_logpp = buffer_dist.log_prob(buffer_u)
            buffer_policy_logpp = buffer_policy_logpp[:, None]
            buffer_q1_pred = self.qf1(obs, buffer_u)
            buffer_q2_pred = self.qf2(obs, buffer_u)
            buffer_q_adv = torch.min(buffer_q1_pred, buffer_q2_pred)
            buffer_v1_pi = self.qf1(obs, buffer_new_obs_actions)
            buffer_v2_pi = self.qf2(obs, buffer_new_obs_actions)
            buffer_v_pi = torch.min(buffer_v1_pi, buffer_v2_pi)
            buffer_score = buffer_q_adv - buffer_v_pi
            buffer_weights = F.softmax(buffer_score / beta, dim=0)
            buffer_policy_loss = self.awr_weight * (
                -buffer_policy_logpp * len(buffer_weights) *
                buffer_weights.detach()).mean()
        else:
            buffer_policy_loss, buffer_train_logp_loss, \
                buffer_train_mse_loss, _ = self.run_bc_batch(
                    self.replay_buffer.train_replay_buffer,
                    self.buffer_policy)

    """
    Update networks
    """
    if self._n_train_steps_total % self.q_update_period == 0:
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

    if self._n_train_steps_total % self.policy_update_period == 0 \
            and self.update_policy:
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

    if self.train_bc_on_rl_buffer and \
            self._n_train_steps_total % self.policy_update_period == 0:
        self.buffer_policy_optimizer.zero_grad()
        buffer_policy_loss.backward()
        self.buffer_policy_optimizer.step()

    """
    Soft Updates
    """
    if self._n_train_steps_total % self.target_update_period == 0:
        ptu.soft_update_from_to(self.qf1, self.target_qf1,
                                self.soft_target_tau)
        ptu.soft_update_from_to(self.qf2, self.target_qf2,
                                self.soft_target_tau)

    """
    Save some statistics for eval
    """
    if self._need_to_update_eval_statistics:
        self._need_to_update_eval_statistics = False
        """
        Eval should set this to None.
        This way, these statistics are only computed for one batch.
        """
        policy_loss = (log_pi - q_new_actions).mean()

        self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
        self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
        self.eval_statistics['Policy Loss'] = np.mean(
            ptu.get_numpy(policy_loss))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q1 Predictions',
            ptu.get_numpy(q1_pred),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q2 Predictions',
            ptu.get_numpy(q2_pred),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q Targets',
            ptu.get_numpy(q_target),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Log Pis',
            ptu.get_numpy(log_pi),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'rewards',
            ptu.get_numpy(rewards),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'terminals',
            ptu.get_numpy(terminals),
        ))
        policy_statistics = add_prefix(dist.get_diagnostics(), "policy/")
        self.eval_statistics.update(policy_statistics)
        self.eval_statistics.update(create_stats_ordered_dict(
            'Advantage Weights',
            ptu.get_numpy(weights),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Advantage Score',
            ptu.get_numpy(score),
        ))
        if self.normalize_over_state == "Z":
            self.eval_statistics.update(create_stats_ordered_dict(
                'logZ',
                ptu.get_numpy(logZ),
            ))
        if self.use_automatic_entropy_tuning:
            self.eval_statistics['Alpha'] = alpha.item()
            self.eval_statistics['Alpha Loss'] = alpha_loss.item()

        if self.compute_bc:
            test_policy_loss, test_logp_loss, test_mse_loss, _ = \
                self.run_bc_batch(self.demo_test_buffer, self.policy)
            self.eval_statistics.update({
                "bc/Train Logprob Loss": ptu.get_numpy(train_logp_loss),
                "bc/Test Logprob Loss": ptu.get_numpy(test_logp_loss),
                "bc/Train MSE": ptu.get_numpy(train_mse_loss),
                "bc/Test MSE": ptu.get_numpy(test_mse_loss),
                "bc/train_policy_loss": ptu.get_numpy(train_policy_loss),
                "bc/test_policy_loss": ptu.get_numpy(test_policy_loss),
            })
        if self.train_bc_on_rl_buffer:
            _, buffer_train_logp_loss, _, _ = self.run_bc_batch(
                self.replay_buffer.train_replay_buffer,
                self.buffer_policy)
            _, buffer_test_logp_loss, _, _ = self.run_bc_batch(
                self.replay_buffer.validation_replay_buffer,
                self.buffer_policy)
            buffer_dist = self.buffer_policy(obs)
            kldiv = torch.distributions.kl.kl_divergence(dist, buffer_dist)
            _, train_offline_logp_loss, _, _ = self.run_bc_batch(
                self.demo_train_buffer, self.buffer_policy)
            _, test_offline_logp_loss, _, _ = self.run_bc_batch(
                self.demo_test_buffer, self.buffer_policy)
            self.eval_statistics.update({
                "buffer_policy/Train Online Logprob":
                    -1 * ptu.get_numpy(buffer_train_logp_loss),
                "buffer_policy/Test Online Logprob":
                    -1 * ptu.get_numpy(buffer_test_logp_loss),
                "buffer_policy/Train Offline Logprob":
                    -1 * ptu.get_numpy(train_offline_logp_loss),
                "buffer_policy/Test Offline Logprob":
                    -1 * ptu.get_numpy(test_offline_logp_loss),
                "buffer_policy/train_policy_loss":
                    ptu.get_numpy(buffer_policy_loss),
                "buffer_policy/kl_div": ptu.get_numpy(kldiv.mean()),
            })
        if self.use_automatic_beta_tuning:
            self.eval_statistics.update({
                "adaptive_beta/beta": ptu.get_numpy(beta.mean()),
                "adaptive_beta/beta loss": ptu.get_numpy(beta_loss.mean()),
            })

    if self.validation_qlearning:
        train_data = self.replay_buffer.validation_replay_buffer.random_batch(
            self.bc_batch_size)
        train_data = np_to_pytorch_batch(train_data)
        obs = train_data['observations']
        next_obs = train_data['next_observations']
        train_data['observations'] = obs
        train_data['next_observations'] = next_obs
        self.test_from_torch(train_data)

    self._n_train_steps_total += 1
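# Minimal numeric sketch of the default advantage weighting above
# (normalize_over_batch == True): weights = softmax((Q - V) / beta) over the
# batch, later multiplied by len(weights) so the mean keeps unit scale.
# Tensor values are illustrative.
# >>> adv = torch.tensor([1.0, 0.0, -1.0])   # q_adv - v_pi for a 3-sample batch
# >>> F.softmax(adv / 1.0, dim=0)
# tensor([0.6652, 0.2447, 0.0900])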
def __init__(
    self,
    env,
    qf1_n,
    target_qf1_n,
    policy_n,
    cactor_n,
    online_action,
    qf2_n,
    target_qf2_n,
    deterministic_cactor_in_graph=True,
    deterministic_next_action=False,
    prg_next_action=True,
    use_entropy_loss=True,
    use_entropy_reward=True,
    use_cactor_entropy_loss=True,
    use_automatic_entropy_tuning=True,
    state_dependent_alpha=False,
    target_entropy=None,
    dec_cactor=True,  # each cactor only gets its own observation
    logit_level=1,
    discount=0.99,
    reward_scale=1.0,
    policy_learning_rate=1e-4,
    qf_learning_rate=1e-3,
    qf_weight_decay=0.,
    init_alpha=1.,
    cactor_learning_rate=1e-4,
    target_hard_update_period=1000,
    tau=1e-2,
    use_soft_update=False,
    qf_criterion=None,
    pre_activation_weight=0.,
    optimizer_class=optim.Adam,
    min_q_value=-np.inf,
    max_q_value=np.inf,
    qf1_optimizer_n=None,
    qf2_optimizer_n=None,
    policy_optimizer_n=None,
    cactor_optimizer_n=None,
    alpha_optimizer_n=None,
    calpha_optimizer_n=None,
    log_alpha_n=None,
    log_calpha_n=None,
):
    super().__init__()
    self.env = env
    if qf_criterion is None:
        qf_criterion = nn.MSELoss()
    self.qf1_n = qf1_n
    self.target_qf1_n = target_qf1_n
    self.qf2_n = qf2_n
    self.target_qf2_n = target_qf2_n
    self.policy_n = policy_n
    self.cactor_n = cactor_n
    self.online_action = online_action
    self.logit_level = logit_level
    self.deterministic_cactor_in_graph = deterministic_cactor_in_graph
    self.deterministic_next_action = deterministic_next_action
    self.prg_next_action = prg_next_action
    self.dec_cactor = dec_cactor
    self.discount = discount
    self.reward_scale = reward_scale
    self.policy_learning_rate = policy_learning_rate
    self.qf_learning_rate = qf_learning_rate
    self.qf_weight_decay = qf_weight_decay
    self.cactor_learning_rate = cactor_learning_rate
    self.target_hard_update_period = target_hard_update_period
    self.tau = tau
    self.use_soft_update = use_soft_update
    self.qf_criterion = qf_criterion
    self.pre_activation_weight = pre_activation_weight
    self.min_q_value = min_q_value
    self.max_q_value = max_q_value

    if qf1_optimizer_n:
        self.qf1_optimizer_n = qf1_optimizer_n
    else:
        self.qf1_optimizer_n = [
            optimizer_class(
                self.qf1_n[i].parameters(),
                lr=self.qf_learning_rate,
            ) for i in range(len(self.qf1_n))
        ]
    if qf2_optimizer_n:
        self.qf2_optimizer_n = qf2_optimizer_n
    else:
        self.qf2_optimizer_n = [
            optimizer_class(
                self.qf2_n[i].parameters(),
                lr=self.qf_learning_rate,
            ) for i in range(len(self.qf2_n))
        ]
    if policy_optimizer_n:
        self.policy_optimizer_n = policy_optimizer_n
    else:
        self.policy_optimizer_n = [
            optimizer_class(
                self.policy_n[i].parameters(),
                lr=self.policy_learning_rate,
            ) for i in range(len(self.policy_n))
        ]
    if cactor_optimizer_n:
        self.cactor_optimizer_n = cactor_optimizer_n
    else:
        self.cactor_optimizer_n = [
            optimizer_class(
                self.cactor_n[i].parameters(),
                lr=self.cactor_learning_rate,
            ) for i in range(len(self.cactor_n))
        ]

    self.init_alpha = init_alpha
    self.use_entropy_loss = use_entropy_loss
    self.use_entropy_reward = use_entropy_reward
    self.use_cactor_entropy_loss = use_cactor_entropy_loss
    self.use_automatic_entropy_tuning = use_automatic_entropy_tuning
    self.state_dependent_alpha = state_dependent_alpha
    if self.use_automatic_entropy_tuning:
        if target_entropy:
            self.target_entropy = target_entropy
        else:
            # Heuristic value from Tuomas Haarnoja: -dim(action space).
            self.target_entropy = -np.prod(
                self.env.action_space.shape).item()
        if self.use_entropy_loss:
            if log_alpha_n:
                self.log_alpha_n = log_alpha_n
            else:
                self.log_alpha_n = [
                    ptu.tensor([np.log(self.init_alpha)],
                               requires_grad=True,
                               dtype=torch.float32)
                    for i in range(len(self.policy_n))
                ]
            if alpha_optimizer_n:
                self.alpha_optimizer_n = alpha_optimizer_n
            else:
                if self.state_dependent_alpha:
                    # log_alpha is a network when alpha is state-dependent.
                    self.alpha_optimizer_n = [
                        optimizer_class(
                            self.log_alpha_n[i].parameters(),
                            lr=self.policy_learning_rate,
                        ) for i in range(len(self.log_alpha_n))
                    ]
                else:
                    self.alpha_optimizer_n = [
                        optimizer_class(
                            [self.log_alpha_n[i]],
                            lr=self.policy_learning_rate,
                        ) for i in range(len(self.log_alpha_n))
                    ]
        if self.use_cactor_entropy_loss:
            if log_calpha_n:
                self.log_calpha_n = log_calpha_n
            else:
                self.log_calpha_n = [
                    ptu.tensor([np.log(self.init_alpha)],
                               requires_grad=True,
                               dtype=torch.float32)
                    for i in range(len(self.policy_n))
                ]
            if calpha_optimizer_n:
                self.calpha_optimizer_n = calpha_optimizer_n
            else:
                if self.state_dependent_alpha:
                    self.calpha_optimizer_n = [
                        optimizer_class(
                            self.log_calpha_n[i].parameters(),
                            lr=self.policy_learning_rate,
                        ) for i in range(len(self.log_calpha_n))
                    ]
                else:
                    self.calpha_optimizer_n = [
                        optimizer_class(
                            [self.log_calpha_n[i]],
                            lr=self.policy_learning_rate,
                        ) for i in range(len(self.log_calpha_n))
                    ]

    self.eval_statistics = OrderedDict()
    self._n_train_steps_total = 0
    self._need_to_update_eval_statistics = True
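# Sketch of why the temperatures above are stored as log-alpha (standard SAC
# practice; the values are illustrative): optimizing log_alpha keeps
# alpha = exp(log_alpha) strictly positive without any explicit constraint.
# >>> log_alpha = ptu.tensor([np.log(1.0)], requires_grad=True)
# >>> alpha = log_alpha.exp()   # == 1.0; gradient steps can never push it <= 0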
def log_abs_det_jacobian(self, x, y):
    return 2.0 * (torch.log(ptu.tensor(2.0)) - x - F.softplus(-2.0 * x))
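# The expression above is the numerically stable form of the tanh Jacobian:
#   log|d tanh(x)/dx| = log(1 - tanh(x)^2) = 2 * (log 2 - x - softplus(-2x)),
# which avoids catastrophic cancellation in 1 - tanh(x)^2 for large |x|.
# Quick standalone check (assumes `import math`; float64 so the naive form
# stays accurate enough to compare against):
# >>> x = torch.linspace(-5.0, 5.0, steps=11, dtype=torch.float64)
# >>> naive = torch.log(1 - torch.tanh(x) ** 2)
# >>> stable = 2.0 * (math.log(2.0) - x - F.softplus(-2.0 * x))
# >>> torch.allclose(naive, stable, atol=1e-8)
# True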
def world_model_loss_rt(
    world_model,
    image_shape,
    image_dist,
    reward_dist,
    prior,
    post,
    prior_dist,
    post_dist,
    pred_discount_dist,
    obs,
    rewards,
    terminals,
    forward_kl,
    free_nats,
    transition_loss_scale,
    kl_loss_scale,
    image_loss_scale,
    reward_loss_scale,
    pred_discount_loss_scale,
    discount,
):
    preprocessed_obs = world_model.flatten_obs(
        world_model.preprocess(obs), image_shape)
    image_pred_loss = -1 * image_dist.log_prob(preprocessed_obs).mean()
    post_detached_dist = world_model.get_detached_dist(post)
    prior_detached_dist = world_model.get_detached_dist(prior)
    reward_pred_loss = -1 * reward_dist.log_prob(rewards).mean()
    # Train the discount head toward `discount` on non-terminal steps and
    # toward 0 at episode ends.
    pred_discount_target = discount * (1 - terminals.float())
    pred_discount_loss = -1 * pred_discount_dist.log_prob(
        pred_discount_target).mean()
    if forward_kl:
        div = kld(post_dist, prior_dist).mean()
        div = torch.max(div, ptu.tensor(free_nats))
        prior_kld = kld(post_detached_dist, prior_dist).mean()
        post_kld = kld(post_dist, prior_detached_dist).mean()
    else:
        div = kld(prior_dist, post_dist).mean()
        div = torch.max(div, ptu.tensor(free_nats))
        prior_kld = kld(prior_dist, post_detached_dist).mean()
        post_kld = kld(prior_detached_dist, post_dist).mean()
    transition_loss = torch.max(prior_kld, ptu.tensor(free_nats))
    entropy_loss = torch.max(post_kld, ptu.tensor(free_nats))
    entropy_loss_scale = 1 - transition_loss_scale
    entropy_loss_scale = (1 - kl_loss_scale) * entropy_loss_scale
    transition_loss_scale = (1 - kl_loss_scale) * transition_loss_scale
    world_model_loss = (
        kl_loss_scale * div
        + image_loss_scale * image_pred_loss
        + transition_loss_scale * transition_loss
        + entropy_loss_scale * entropy_loss
        + reward_loss_scale * reward_pred_loss
        + pred_discount_loss_scale * pred_discount_loss
    )
    return (
        world_model_loss,
        div,
        image_pred_loss,
        reward_pred_loss,
        transition_loss,
        entropy_loss,
        pred_discount_loss,
    )
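# Example of the discount-prediction target above (illustrative values): with
# discount=0.99 the head learns ~discount on non-terminal steps and 0 where
# the episode ends.
# >>> terminals = torch.tensor([0.0, 0.0, 1.0])
# >>> 0.99 * (1 - terminals)
# tensor([0.9900, 0.9900, 0.0000])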
def __init__(
    self,
    actor,
    vf,
    target_vf,
    world_model,
    image_shape,
    imagination_horizon=15,
    discount=0.99,
    actor_lr=8e-5,
    vf_lr=8e-5,
    world_model_lr=3e-4,
    world_model_gradient_clip=100.0,
    actor_gradient_clip=100.0,
    value_gradient_clip=100.0,
    adam_eps=1e-5,
    weight_decay=0.0,
    soft_target_tau=1,
    target_update_period=100,
    lam=0.95,
    free_nats=1.0,
    kl_loss_scale=0.0,
    pred_discount_loss_scale=10.0,
    image_loss_scale=1.0,
    reward_loss_scale=2.0,
    transition_loss_scale=0.8,
    detach_rewards=False,
    forward_kl=False,
    policy_gradient_loss_scale=0.0,
    actor_entropy_loss_schedule="1e-4",
    use_pred_discount=False,
    reward_scale=1,
    num_imagination_iterations=1,
    use_baseline=True,
    use_ppo_loss=False,
    ppo_clip_param=0.2,
    num_actor_value_updates=1,
    use_advantage_normalization=False,
    use_clipped_value_loss=False,
    actor_value_lr=8e-5,
    use_actor_value_optimizer=False,
    binarize_rewards=False,
):
    super().__init__()
    torch.backends.cudnn.benchmark = True
    self.scaler = torch.cuda.amp.GradScaler()
    self.use_pred_discount = use_pred_discount
    self.actor = actor.to(ptu.device)
    self.world_model = world_model.to(ptu.device)
    self.vf = vf.to(ptu.device)
    self.target_vf = target_vf.to(ptu.device)

    optimizer_class = optim.Adam
    self.actor_lr = actor_lr
    self.adam_eps = adam_eps
    self.weight_decay = weight_decay
    self.vf_lr = vf_lr
    self.world_model_lr = world_model_lr
    self.actor_optimizer = optimizer_class(
        self.actor.parameters(),
        lr=actor_lr,
        eps=adam_eps,
        weight_decay=weight_decay,
    )
    self.vf_optimizer = optimizer_class(
        self.vf.parameters(),
        lr=vf_lr,
        eps=adam_eps,
        weight_decay=weight_decay,
    )
    self.world_model_optimizer = optimizer_class(
        self.world_model.parameters(),
        lr=world_model_lr,
        eps=adam_eps,
        weight_decay=weight_decay,
    )
    self.use_actor_value_optimizer = use_actor_value_optimizer
    # Optional joint optimizer over the actor and value function.
    self.actor_value_optimizer = optimizer_class(
        list(self.actor.parameters()) + list(self.vf.parameters()),
        lr=actor_value_lr,
        eps=adam_eps,
        weight_decay=weight_decay,
    )

    self.discount = discount
    self.lam = lam
    self.imagination_horizon = imagination_horizon
    self.free_nats = ptu.tensor(free_nats)
    self.kl_loss_scale = kl_loss_scale
    self.pred_discount_loss_scale = pred_discount_loss_scale
    self.image_loss_scale = image_loss_scale
    self.reward_loss_scale = reward_loss_scale
    self.transition_loss_scale = transition_loss_scale
    self.policy_gradient_loss_scale = policy_gradient_loss_scale
    self.actor_entropy_loss_schedule = actor_entropy_loss_schedule
    # Evaluate the (possibly annealed) entropy-loss scale at the current step.
    self.actor_entropy_loss_scale = lambda x=actor_entropy_loss_schedule: schedule(
        x, self._n_train_steps_total)
    self.forward_kl = forward_kl
    self.soft_target_tau = soft_target_tau
    self.target_update_period = target_update_period
    self.image_shape = image_shape
    self.use_baseline = use_baseline
    self.use_ppo_loss = use_ppo_loss
    self.ppo_clip_param = ppo_clip_param
    self.num_actor_value_updates = num_actor_value_updates
    self.world_model_gradient_clip = world_model_gradient_clip
    self.actor_gradient_clip = actor_gradient_clip
    self.value_gradient_clip = value_gradient_clip
    self.use_advantage_normalization = use_advantage_normalization
    self.detach_rewards = detach_rewards
    self.num_imagination_iterations = num_imagination_iterations
    self.use_clipped_value_loss = use_clipped_value_loss
    self.reward_scale = reward_scale
    self._n_train_steps_total = 0
    self._need_to_update_eval_statistics = True
    self.eval_statistics = OrderedDict()
    # Backprop through the learned dynamics unless using pure policy gradients.
    self.use_dynamics_backprop = self.policy_gradient_loss_scale < 1.0
    self.binarize_rewards = binarize_rewards
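# Sketch of the mixed-precision update the GradScaler above is set up for
# (standard torch.cuda.amp usage; the helper and its arguments are
# hypothetical, not this trainer's actual method). Unscaling before clipping
# matters so the gradient-clip thresholds above see true gradient magnitudes.
def _amp_step(scaler, optimizer, loss, params, grad_clip):
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)  # so clipping sees unscaled gradients
    torch.nn.utils.clip_grad_norm_(params, grad_clip)
    scaler.step(optimizer)
    scaler.update()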
def __init__(
    self,
    env,
    qf1,
    target_qf1,
    qf2,
    target_qf2,
    policy_n,
    shared_gnn=None,
    discount=0.99,
    reward_scale=1.0,
    policy_learning_rate=1e-4,
    qf_learning_rate=1e-3,
    qf_weight_decay=0.,
    log_alpha_n=None,
    init_alpha=1.,
    target_hard_update_period=1000,
    tau=1e-2,
    use_soft_update=False,
    qf_criterion=None,
    deterministic_next_action=False,
    use_entropy_reward=False,
    use_automatic_entropy_tuning=True,
    target_entropy=None,
    optimizer_class=optim.Adam,
    log_grad=False,
    shared_obs=False,
    min_q_value=-np.inf,
    max_q_value=np.inf,
    qf1_optimizer=None,
    qf2_optimizer=None,
    policy_optimizer_n=None,
    alpha_optimizer_n=None,
    shared_gnn_optimizer=None,
):
    super().__init__()
    if qf_criterion is None:
        qf_criterion = nn.MSELoss()
    self.env = env
    self.qf1 = qf1
    self.target_qf1 = target_qf1
    self.qf2 = qf2
    self.target_qf2 = target_qf2
    self.policy_n = policy_n
    self.shared_gnn = shared_gnn
    self.deterministic_next_action = deterministic_next_action
    self.use_entropy_reward = use_entropy_reward
    self.discount = discount
    self.reward_scale = reward_scale
    self.policy_learning_rate = policy_learning_rate
    self.qf_learning_rate = qf_learning_rate
    self.qf_weight_decay = qf_weight_decay
    self.target_hard_update_period = target_hard_update_period
    self.tau = tau
    self.use_soft_update = use_soft_update
    self.qf_criterion = qf_criterion
    self.min_q_value = min_q_value
    self.max_q_value = max_q_value
    self.log_grad = log_grad
    self.shared_obs = shared_obs

    self.init_alpha = init_alpha
    self.use_automatic_entropy_tuning = use_automatic_entropy_tuning
    if self.use_automatic_entropy_tuning:
        if target_entropy:
            self.target_entropy = target_entropy
        else:
            # Heuristic value from Tuomas Haarnoja: -dim(action space).
            self.target_entropy = -np.prod(self.env.action_space.shape).item()
        if log_alpha_n:
            self.log_alpha_n = log_alpha_n
        else:
            self.log_alpha_n = [
                ptu.tensor([np.log(self.init_alpha)],
                           requires_grad=True,
                           dtype=torch.float32)
                for i in range(len(self.policy_n))]
        if alpha_optimizer_n:
            self.alpha_optimizer_n = alpha_optimizer_n
        else:
            self.alpha_optimizer_n = [
                optimizer_class(
                    [self.log_alpha_n[i]],
                    lr=self.policy_learning_rate,
                ) for i in range(len(self.log_alpha_n))]

    if qf1_optimizer:
        self.qf1_optimizer = qf1_optimizer
    else:
        self.qf1_optimizer = optimizer_class(
            self.qf1.parameters(),
            lr=self.qf_learning_rate,
        )
    if qf2_optimizer:
        self.qf2_optimizer = qf2_optimizer
    else:
        self.qf2_optimizer = optimizer_class(
            self.qf2.parameters(),
            lr=self.qf_learning_rate,
        )
    if policy_optimizer_n:
        self.policy_optimizer_n = policy_optimizer_n
    else:
        self.policy_optimizer_n = [
            optimizer_class(
                self.policy_n[i].parameters(),
                lr=self.policy_learning_rate,
            ) for i in range(len(self.policy_n))]
    if shared_gnn:
        if shared_gnn_optimizer:
            self.shared_gnn_optimizer = shared_gnn_optimizer
        else:
            self.shared_gnn_optimizer = optimizer_class(
                self.shared_gnn.parameters(),
                lr=self.policy_learning_rate / len(self.policy_n),
            )

    self.eval_statistics = OrderedDict()
    self._n_train_steps_total = 0
    self._need_to_update_eval_statistics = True
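# Note on the shared-GNN learning rate above (an interpretation, not stated in
# the source): the shared GNN accumulates gradients from all len(policy_n)
# per-agent policy losses each step, so dividing its lr by the number of
# agents keeps the effective step size comparable to a single policy's.
# Illustrative arithmetic: policy_learning_rate=1e-4 with 4 agents gives a
# shared-GNN lr of 2.5e-5.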