def _compute_sup_loss(self, obs, actions, labels, valid_mask):
    obs = torch_ify(obs)
    actions = torch_ify(actions)
    valid_mask = torch_ify(valid_mask).bool()
    labels = torch_ify(labels).clone()
    # NaN labels mark unsupervised entries; zero them so log_prob is
    # well-defined, then mask them out of the loss and accuracy
    valids = ~torch.isnan(labels)
    labels[~valids] = 0
    if self._recurrent:
        pre_actions = actions[:, :-1, :]
        policy_input = (obs, pre_actions)
    else:
        policy_input = obs
    valid_num = (valid_mask.unsqueeze(-1) * valids).float().sum()
    dists = self.policy.get_sup_distribution(policy_input)
    lls = dists.log_prob(labels)
    lls[~valids] = 0
    lls[~valid_mask] = 0
    loss = -lls.sum() / valid_num
    accuracy = (torch.argmax(dists.probs, -1) == labels).float()
    accuracy[~valids] = 0.
    accuracy[~valid_mask] = 0.
    accuracy = accuracy.sum() / valid_num
    return loss, accuracy
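# Tiny self-contained illustration of the NaN-masking pattern used above:
# NaN labels mark unsupervised entries; they get a safe placeholder for
# log_prob and are excluded from the average. Values are made up.
import torch

labels = torch.tensor([0., float('nan'), 1.])
valids = ~torch.isnan(labels)
labels = labels.clone()
labels[~valids] = 0                              # placeholder for log_prob
log_probs = torch.tensor([-0.1, -9.9, -0.3])     # assumed per-label log-probs
log_probs = log_probs.masked_fill(~valids, 0)
loss = -log_probs.sum() / valids.float().sum()   # mean over valid entries only
print(loss)  # 0.2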
def _start_new_rollout(self):
    self.exploration_policy.reset()
    # Note: we assume we're using a silent env.
    o = self.training_env.reset()
    rgp = self.rollout_goal_params
    if rgp is None:
        self._rollout_goal = o[self.desired_goal_key]
    elif rgp["strategy"] == "ensemble_qs":
        exploration_temperature = rgp["exploration_temperature"]
        assert len(self.ensemble_qs) > 0
        N = 128
        obs = np.tile(o[self.observation_key], (N, 1))
        proposed_goals = self.training_env.sample_goals(N)[
            self.desired_goal_key]
        new_obs = np.hstack((obs, proposed_goals))
        actions = torch_ify(self.policy.get_action(new_obs)[0])
        q_values = np.zeros((len(self.ensemble_qs), N))
        for i, q in enumerate(self.ensemble_qs):
            q_values[i, :] = np_ify(q(torch_ify(new_obs), actions)).flatten()
        q_std = q_values.std(axis=0)
        p = softmax(q_std / exploration_temperature)
        ind = np.random.choice(np.arange(N), p=p)
        self._rollout_goal = {}
        self._rollout_goal[self.desired_goal_key] = proposed_goals[ind, :]
    elif rgp["strategy"] == "vae_q":
        pass
    else:
        assert False, "bad rollout goal strategy"
    return o
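# A minimal sketch of the "ensemble_qs" goal-selection idea above: score
# candidate goals by the disagreement (std) of an ensemble of Q-estimates
# and sample one through a temperature-controlled softmax. The shapes and
# the local softmax helper are illustrative assumptions, not the repo's API.
import numpy as np

def _softmax(x):
    z = x - x.max()          # numerically stable softmax
    e = np.exp(z)
    return e / e.sum()

def pick_goal_by_disagreement(q_values, temperature=1.0):
    """q_values: (num_ensemble, num_goals) array of Q estimates."""
    q_std = q_values.std(axis=0)              # disagreement per goal
    p = _softmax(q_std / temperature)         # high-std goals more likely
    return np.random.choice(len(p), p=p)

# usage with random stand-in Q values
qs = np.random.randn(5, 128)
goal_idx = pick_goal_by_disagreement(qs, temperature=0.5)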
def _add_exploration_bonus(self, paths):
    paths = copy.deepcopy(paths)
    entropy_decreases = []
    with torch.no_grad():
        for path in paths:
            for i in range(len(path['observations']) - 1):
                obs1 = path['observations'][i]
                labels1 = torch.tensor(path['env_infos']['sup_labels'][i])
                valid_mask1 = ~torch.isnan(labels1)[None, :]
                entropy_1 = self.policy.get_sup_distribution(
                    torch_ify(obs1)[None, :]).entropy()
                entropy_1 = torch.mean(entropy_1[valid_mask1])
                obs2 = path['observations'][i + 1]
                labels2 = torch.tensor(path['env_infos']['sup_labels'][i + 1])
                valid_mask2 = ~torch.isnan(labels2)[None, :]
                entropy_2 = self.policy.get_sup_distribution(
                    torch_ify(obs2)[None, :]).entropy()
                entropy_2 = torch.mean(entropy_2[valid_mask2])
                entropy_decrease = (entropy_1 - entropy_2).item()
                entropy_decreases.append(entropy_decrease)
                path['rewards'][i] += \
                    self.exploration_bonus * entropy_decrease
    if self._need_to_update_eval_statistics:
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Entropy Decrease',
                entropy_decreases,
            ))
    return paths
def get_action(self, obs, labels=None, deterministic=False):
    assert len(obs.shape) == 1
    assert (self.policy.a_p == self.sup_learner.a_p).all()
    with torch.no_grad():
        obs_action = (torch_ify(obs)[None, None, :],
                      self.policy.a_p[None, None, :])
        if labels is not None:
            labels = torch_ify(labels)[None, None, :]
        pis, info = self.forward(obs_action,
                                 labels=labels,
                                 latent=self.policy.latent_p,
                                 sup_latent=self.sup_learner.latent_p,
                                 return_info=True)
        sup_probs = Categorical(logits=info['sup_preactivation']).probs
    pis = np_ify(pis[0, 0, :])
    sup_probs = np_ify(sup_probs[0, 0, :, :])
    if deterministic:
        action = np.argmax(pis)
    else:
        action = np.random.choice(np.arange(pis.shape[0]), p=pis)
    self.policy.a_p = torch_ify(np.array([action]))
    self.policy.latent_p = info['latent']
    self.sup_learner.a_p = torch_ify(np.array([action]))
    self.sup_learner.latent_p = info['sup_latent']
    return action, {'intentions': sup_probs}
def forward(self, obs, action, network_idx=None, return_net_outputs=False):
    # TODO: is this the usage I want?
    obs = torch_ify(obs)
    action = torch_ify(action)
    if network_idx is None:
        network = random.choice(self._nets)
    else:
        network = self._nets[network_idx]
    output = network(obs, action)
    # TODO: possibly wrap this in a Probabilistic Network class
    mean = output[:, :self.obs_dim]
    logvar = output[:, self.obs_dim:2 * self.obs_dim]
    if self.predict_reward:
        reward = output[:, -1:]
    # variance pinning: softly clamp logvar into [min_logvar, max_logvar]
    logvar = self.max_logvar - F.softplus(self.max_logvar - logvar)
    logvar = self.min_logvar + F.softplus(logvar - self.min_logvar)
    var = torch.exp(logvar)
    # reparameterization-style sample from the predicted Gaussian
    next_obs = mean + torch.randn_like(mean) * var.sqrt()
    if not self.predict_reward:
        if self.env:
            action = self.denormalize_action(action)
        reward = self.rew_function(obs, action, next_obs)
    if return_net_outputs:
        return mean, logvar, reward
    return next_obs, reward
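# Illustrative check of the variance-pinning trick used above: a raw logvar
# is softly squashed into the learned [min_logvar, max_logvar] band, so
# gradients still flow (unlike a hard clamp). Bounds and values are made up.
import torch
import torch.nn.functional as F

max_logvar = torch.full((1, 3), 0.5)
min_logvar = torch.full((1, 3), -10.0)
raw = torch.tensor([[-20.0, 0.0, 20.0]])
pinned = max_logvar - F.softplus(max_logvar - raw)
pinned = min_logvar + F.softplus(pinned - min_logvar)
print(pinned)  # every entry now lies (softly) within [-10, 0.5]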
def _compute_sup_loss(self, obs, labels):
    obs = torch_ify(obs)
    # clone so we don't mutate label arrays owned by the replay buffer
    labels = torch_ify(labels).clone()
    valid_mask = ~torch.isnan(labels)
    labels[~valid_mask] = 0
    lls = self.policy.sup_log_prob(obs, labels)
    # reduce to a scalar so the caller can call .backward() on it
    return -lls[valid_mask].mean()
def _train_sup_learner(self, observations, labels):
    observations = torch_ify(observations)
    labels = torch_ify(labels)
    self._sup_optimizer.zero_grad()
    sup_loss = self._compute_sup_loss(observations, labels)
    sup_loss.backward()
    self._sup_optimizer.step()
    return sup_loss
def beta_eval(goals):
    N = len(goals)
    observations = np.tile(obs, (N, 1))
    new_obs = np.hstack((observations, goals))
    actions = torch_ify(policy.get_action(new_obs)[0])
    return np_ify(q(torch_ify(new_obs), actions)).flatten()
def _compute_sup_loss(self, obs, labels):
    obs = torch_ify(obs)
    labels = torch_ify(labels).clone()
    valid_mask = ~torch.isnan(labels)
    labels[~valid_mask] = 0
    lls = self.policy.sup_log_prob(obs, labels)
    lls[~valid_mask] = 0
    # note: this averages over all entries (invalid ones contribute 0),
    # unlike -lls[valid_mask].mean(), which averages over valid entries only
    return -lls.mean()
def __init__(
        self,
        env,
        env_name,
        rew_function=None,
):
    self.env = env
    self.env_name = env_name
    self.rew_function = rew_function
    self.lb = torch_ify(self.env._wrapped_env.action_space.low)
    self.ub = torch_ify(self.env._wrapped_env.action_space.high)
    self.trained_at_all = True
def get_action(self, observation):
    self.is_update_on_last_action = False
    action = self.policy_base.get_action(observation)[0]
    torch_observation = torch_ify(np.expand_dims(observation, axis=0))
    torch_action = torch_ify(np.expand_dims(action, axis=0))
    if self.qf_danger_probability(torch_observation,
                                  torch_action) >= self.threshold:
        self.is_update_on_last_action = True
        action = self.policy_danger.get_action(observation)[0]
    return action, {}
def __init__(self,
             hidden_sizes,
             obs_dim,
             action_dim,
             num_bootstrap,
             rew_function=None,
             env=None):
    '''
    Usage:
        model = Model(...)
        next_obs = model(obs, action)
        trajectory = model.unroll(obs, action_sequence)

    Note: only pass in env if you want to denormalize states/actions
    before the reward function in rollouts.
    TODO: handle the different PETS sampling strategies.
    '''
    super().__init__()
    self.rew_function = rew_function
    self.predict_reward = self.rew_function is None
    self.env = env
    if self.env is not None:
        self.lb = torch_ify(self.env._wrapped_env.action_space.low)
        self.ub = torch_ify(self.env._wrapped_env.action_space.high)
    self.obs_dim = obs_dim
    self.action_dim = action_dim
    self.input_size = obs_dim + action_dim
    # each net outputs a Gaussian over next_obs (mean + logvar),
    # plus a reward head if no reward function is given
    if self.predict_reward:
        self.output_dim = obs_dim * 2 + 1
    else:
        self.output_dim = obs_dim * 2
    self.num_bootstrap = num_bootstrap
    self._nets = nn.ModuleList()
    for i in range(num_bootstrap):
        # TODO: figure out what the network architecture should be
        self._nets.append(
            FlattenMlp(hidden_sizes,
                       self.output_dim,
                       self.input_size,
                       hidden_activation=swish))
    self.max_logvar = nn.Parameter(
        torch.ones(1, self.obs_dim, dtype=torch.float32) / 2.0)
    self.min_logvar = nn.Parameter(
        -torch.ones(1, self.obs_dim, dtype=torch.float32) * 10.0)
    self.trained_at_all = False
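# A hedged usage sketch following the docstring above; it assumes the
# Model class and its rlkit dependencies are importable, and the sizes
# here (hidden layers, obs/action dims, batch size) are arbitrary.
import torch

model = Model(hidden_sizes=[200, 200],
              obs_dim=11,
              action_dim=3,
              num_bootstrap=5)          # rew_function=None, so reward is predicted
obs = torch.zeros(32, 11)               # batch of observations
action = torch.zeros(32, 3)             # batch of actions
next_obs, reward = model(obs, action)   # one stochastic step from a random net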
def get_action(self, obs, deterministic=False):
    assert len(obs.shape) == 1
    with torch.no_grad():
        obs_action = (torch_ify(obs)[None, None, :], self.a_p[None, None, :])
        pis, info = self.forward(obs_action,
                                 latent=self.latent_p,
                                 return_info=True)
        sup_probs = self.sup_prob(obs_action, latent=self.latent_p)
    pis = np_ify(pis[0, 0, :])
    sup_probs = np_ify(sup_probs[0, 0, :, :])
    if deterministic:
        action = np.argmax(pis)
    else:
        action = np.random.choice(np.arange(pis.shape[0]), p=pis)
    self.a_p = torch_ify(np.array([action]))
    self.latent_p = info['latent']
    return action, {'intentions': sup_probs}
def __init__(
        self,
        input_dim,
        node_num,
        ego_init=np.array([0., 1.]),
        other_init=np.array([1., 0.]),
):
    super(TrafficGraphBuilder, self).__init__()
    self.input_dim = input_dim
    self.node_num = node_num
    self.ego_init = torch_ify(ego_init)
    self.other_init = torch_ify(other_init)
    self.output_dim = input_dim + self.ego_init.shape[0]
def forward(self, obs, valid_mask=None):
    # x: (batch*num_node) x output_dim
    # edge_index: 2 x node_edge
    # messages from nodes in edge_index[0] are sent to nodes in edge_index[1]
    batch_size, node_num, obs_dim = obs.shape
    x = torch.zeros(batch_size, self.node_num,
                    self.output_dim).to(ptu.device)
    x[:, :, :self.input_dim] = obs
    # node 0 is the ego vehicle; all other nodes share the "other" embedding
    x[:, 0, self.input_dim:] = self.ego_init[None, :]
    x[:, 1:, self.input_dim:] = self.other_init[None, None, :]
    x = x.reshape(int(batch_size * self.node_num), self.output_dim)
    obs = np_ify(obs)
    edge_index = get_edge_index(obs)  # batch x 2 x max_edge_num
    edge_index = np.swapaxes(edge_index, 0, 1).reshape(2, -1)
    edge_index = np.unique(edge_index, axis=1)
    edge_index = torch_ify(edge_index).long()
    edge_index = pyg_utils.remove_self_loops(edge_index)[0]
    return x, edge_index
def unroll(self, obs, action_sequence, sampling_strategy):
    '''
    obs: batch_size * obs_dim (Tensor)
    action_sequence: batch_size * timesteps * action_dim (Tensor)
    sampling_strategy: one of "TS1" or "TSinf"

    returns
        observations: batch_size * timesteps * obs_dim
        rewards: batch_size * timesteps
    '''
    obs = torch_ify(obs)
    action_sequence = torch_ify(action_sequence)
    batch_size = action_sequence.shape[0]
    n_timesteps = action_sequence.shape[1]
    obs_output = []
    rew_output = []
    # sample the initial bootstrap assignment for every particle
    bootstrap_assignments = np.random.randint(self.num_bootstrap,
                                              size=batch_size)
    for i in range(n_timesteps):
        bs_obs = []
        bs_rew = []
        # run all networks on all inputs, then do the sampling afterwards
        for network_idx in range(self.num_bootstrap):
            # relies on self.forward() being probabilistic
            with torch.no_grad():
                next_obs, reward = self.forward(obs,
                                                action_sequence[:, i, :],
                                                network_idx)
            bs_obs.append(next_obs)
            bs_rew.append(reward)
        bs_obs = torch.stack(bs_obs)
        bs_rew = torch.stack(bs_rew)
        # select each particle's transition from its assigned bootstrap
        next_obs = bs_obs[bootstrap_assignments, range(batch_size), :]
        next_rew = bs_rew[bootstrap_assignments, range(batch_size)]
        # move observation forward
        obs = next_obs
        obs_output.append(next_obs)
        rew_output.append(next_rew)
        # TS1 resamples the bootstrap assignment at every step;
        # TSinf keeps it fixed for the whole trajectory
        if sampling_strategy == 'TS1':
            bootstrap_assignments = np.random.randint(self.num_bootstrap,
                                                      size=batch_size)
    observations = torch.stack(obs_output, dim=1)
    rewards = torch.stack(rew_output, dim=1)
    return observations, rewards
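# Sketch of a TS1 rollout with the model above (shapes per the docstring);
# the batch size and horizon are arbitrary assumptions.
obs0 = torch.zeros(32, 11)
action_seq = torch.zeros(32, 25, 3)     # 25-step action sequence per particle
observations, rewards = model.unroll(obs0, action_seq, sampling_strategy='TS1')
# observations: (32, 25, 11); rewards: (32, 25)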
def _compute_sup_loss(self, obs, actions, labels, valids):
    obs = torch_ify(obs)
    actions = torch_ify(actions)
    valids = torch_ify(valids).bool()
    labels = torch_ify(labels).clone()
    valid_mask = ~torch.isnan(labels)
    labels[~valid_mask] = 0
    if self._recurrent:
        pre_actions = actions[:, :-1, :]
        policy_input = (obs, pre_actions)
    else:
        policy_input = obs
    lls = self.policy.sup_log_prob(policy_input, labels)
    lls[~valids] = 0
    lls[~valid_mask] = 0
    # masked mean: divide by the number of entries that are both valid
    # timesteps and non-NaN labels
    return -lls.sum() / (valids[:, :, None] * valid_mask).float().sum()
def _train_sup_learners(self, observations, n_labels):
    sup_losses = []
    observations = torch_ify(observations)
    for learner, labels, optimizer in zip(self.sup_learners, n_labels,
                                          self._sup_optimizers):
        labels = torch_ify(labels)
        valid_mask = ~torch.isnan(labels).squeeze(-1)
        if torch.sum(valid_mask) > 0.:
            optimizer.zero_grad()
            loss = self._compute_sup_loss(learner, observations, labels,
                                          valid_mask)
            loss.backward()
            optimizer.step()
            sup_losses.append(loss.item())
        else:
            sup_losses.append(0.)
    return sup_losses
def keyDownCb(keyName):
    if keyName == 'BACKSPACE':
        resetEnv()
        return
    if keyName == 'ESCAPE':
        sys.exit(0)

    action = 0
    if keyName == 'LEFT':
        action = env.actions.west
    elif keyName == 'RIGHT':
        action = env.actions.east
    elif keyName == 'UP':
        action = env.actions.north
    elif keyName == 'DOWN':
        action = env.actions.south
    elif keyName == 'SPACE':
        action = env.actions.mine
    elif keyName == 'PAGE_UP':
        if hasattr(env.actions, 'eat'):
            action = env.actions.eat
        else:
            action = env.actions.dispense
    elif keyName == 'PAGE_DOWN':
        action = env.actions.place
    elif keyName == '0':
        action = env.actions.place0
    elif keyName == '1':
        action = env.actions.place1
    elif keyName == '2':
        action = env.actions.place2
    elif keyName == '3':
        action = env.actions.place3
    elif keyName == '4':
        action = env.actions.place4
    elif keyName == 'RETURN':
        action = env.actions.done
    else:
        print("unknown key %s" % keyName)
        return

    obs, reward, done, info = env.step(action)
    if pkl is not None:
        qs = qf(torch_ify(obs)).data.numpy()[0]
        print(qs)
        print(qs.argmax())
    if hasattr(env, 'health'):
        print('step=%s, reward=%.2f, health=%d' %
              (env.step_count, reward, env.health))
    else:
        print('step=%s, reward=%.2f' % (env.step_count, reward))
    if done:
        print('done!')
        resetEnv()
def __init__(
        self,
        a_0,
        latent_0,
        obs_dim,
        action_dim,
        lstm_net,
        post_net,
):
    super().__init__()
    self.a_0 = torch_ify(a_0).clone().detach()
    self.latent_0 = tuple(
        [torch_ify(h).clone().detach() for h in latent_0])
    self.a_p = self.a_0.clone().detach()
    self.latent_p = tuple([h.clone().detach() for h in self.latent_0])
    self.obs_dim = obs_dim
    self.action_dim = action_dim
    self.lstm_net = lstm_net
    self.post_net = post_net
def __init__(
        self,
        a_0,
        latent_0,
        obs_dim,
        action_dim,
        lstm_net,
        decoder,
        sup_learner,
):
    super().__init__()
    self.a_0 = torch_ify(a_0).clone().detach()
    self.latent_0 = tuple(
        [torch_ify(h).clone().detach() for h in latent_0])
    self.a_p = self.a_0.clone().detach()
    self.latent_p = tuple([h.clone().detach() for h in self.latent_0])
    self.obs_dim = obs_dim
    self.action_dim = action_dim
    self.lstm_net = lstm_net
    self.decoder = decoder
    self.sup_learner = sup_learner
def add_advantages(self, path, path_len, flag):
    if flag:
        next_vf = self.vf(torch_ify(path["next_observations"]))
        cur_vf = self.vf(torch_ify(path["observations"]))
        rewards = torch_ify(path["rewards"])
        term = (1 - torch_ify(path["terminals"].astype(np.float32)))
        # TD residual: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
        delta = rewards + term * self.discount * next_vf - cur_vf
        advantages = torch.zeros(path_len)
        returns = torch.zeros(path_len)
        gae = 0
        R = 0
        # GAE backward recursion: A_t = delta_t + gamma * lambda * A_{t+1}
        for i in reversed(range(path_len)):
            advantages[i] = delta[i] + term[i] * (self.discount *
                                                  self.gae_lambda) * gae
            gae = advantages[i]
            returns[i] = rewards[i] + term[i] * self.discount * R
            R = returns[i]
        advantages = np_ify(advantages)
        if advantages.std() != 0.0:
            advantages = (advantages - advantages.mean()) / advantages.std()
        else:
            advantages = advantages - advantages.mean()
        returns = np_ify(returns)
    else:
        advantages = np.zeros(path_len)
        returns = np.zeros(path_len)
    return dict(observations=path["observations"],
                actions=path["actions"],
                rewards=path["rewards"],
                next_observations=path["next_observations"],
                terminals=path["terminals"],
                agent_infos=path["agent_infos"],
                env_infos=path["env_infos"],
                advantages=advantages,
                returns=returns)
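# Tiny numeric check of the GAE recursion above on a 3-step path with
# made-up TD residuals (gamma=0.9, lambda=0.95, no terminals).
import numpy as np

gamma, lam = 0.9, 0.95
delta = np.array([1.0, 0.5, -0.2])        # assumed TD residuals
gae, advantages = 0.0, np.zeros(3)
for i in reversed(range(3)):
    advantages[i] = delta[i] + gamma * lam * gae
    gae = advantages[i]
# advantages[2] = -0.2; advantages[1] = 0.5 + 0.855*(-0.2) = 0.329; ...
print(advantages)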
def forward(self, obs):
    if len(obs.shape) < 2:
        obs = torch_ify(obs).unsqueeze(0)
    cumsum = 0
    arrs = []
    for size in self.sizes:
        arrs.append(obs.narrow(dim=1, start=cumsum, length=size))
        cumsum += size
    assert cumsum == obs.shape[1], 'not all of obs used'
    full_img, health, pos, pantry, shelf = arrs
    x_full_img = self.full_img_network(full_img)
    x_inventory = self.inventory_network(pantry, shelf, health, pos)
    out = self.final_network(torch.cat((x_full_img, x_inventory), dim=1))
    return out
def forward(self, obs):
    if len(obs.shape) < 2:
        obs = torch_ify(obs).unsqueeze(0)
    cumsum = 0
    arrs = []
    for size in self.sizes:
        arrs.append(obs.narrow(dim=1, start=cumsum, length=size))
        cumsum += size
    assert cumsum == obs.shape[1], 'not all of obs used'
    img, shelf, health = arrs
    x_img = self.img_network(img)
    x_inventory = self.inventory_network(shelf)
    out = self.final_network(x_img, x_inventory, health)
    return out
def forward(self, obs):
    if len(obs.shape) < 2:
        obs = torch_ify(obs).unsqueeze(0)
    x = self.img_network(obs.contiguous().view((obs.shape[0], -1)))
    out = self.final_network(x)
    return out
def _get_dist_from_np(self, *args, **kwargs):
    torch_args = tuple(torch_ify(x) for x in args)
    torch_kwargs = {k: torch_ify(v) for k, v in kwargs.items()}
    dist = self(*torch_args, **torch_kwargs)
    return dist
def step(self, action):
    self.just_made_obj_type = None
    self.just_eaten_type = None
    self.just_placed_on = None
    obs, reward, done, info = super().step(action,
                                           incl_health=self.include_health)
    shelf_obs = self.gen_shelf_obs()

    """ Generate obs """
    extra_obs_count_string = shelf_obs.sum(axis=0).tostring()
    extra_obs = shelf_obs.flatten()
    # magic number: repeat the shelf 8 times to fill up more of the obs
    extra_obs = np.repeat(extra_obs, 8)
    num_objs = np.repeat(self.info_last['pickup_%s' % self.task[1]], 8)
    obs = np.concatenate((obs, extra_obs, num_objs))

    """ Generate reward """
    solved = self.solved_task()
    if 'make' in self.task[0]:
        reward = self.get_make_reward()
        if self.task[0] == 'make':
            info.update({
                'progress': (self.max_make_idx + 1) / len(self.make_sequence)
            })
    else:
        reward = int(solved)

    """ Generate info """
    info.update({'health': self.health})
    info.update(self.info_last)
    if solved:
        if self.end_on_task_completion:
            done = True
        info.update({'solved': True})
        if self.lifelong:
            # remove obj so we can keep making
            self.carrying = None
    else:
        info.update({'solved': False})
    if self.time_horizon and self.step_count % self.time_horizon == 0:
        done = True

    """ Exploration bonuses """
    self.obs_count[extra_obs_count_string] = self.obs_count.get(
        extra_obs_count_string, 0) + 1
    if self.cbe:
        # count-based exploration bonus: 1 / sqrt(N(s))
        reward += 1 / np.sqrt(self.obs_count[extra_obs_count_string])
    elif self.rnd:
        torch_obs = torch_ify(extra_obs)
        true_rnd = self.rnd_network(torch_obs)
        pred_rnd = self.rnd_target_network(torch_obs)
        loss = self.rnd_loss(true_rnd, pred_rnd)
        self.rnd_optimizer.zero_grad()
        loss.backward()
        self.rnd_optimizer.step()
        # RND exploration bonus; detach so the running statistics don't
        # keep the computation graph alive across steps
        loss = loss.detach()
        self.sum_rnd += loss
        self.sum_square_rnd += loss**2
        # note: despite the name, this is the running variance of the loss
        stdev = (self.sum_square_rnd / self.step_count) - \
            (self.sum_rnd / self.step_count)**2
        reward += loss / (stdev * self.health_cap)
    # funny ordering because otherwise we'd get the transpose
    # due to how the grid indices work
    self.visit_count[self.agent_pos[1], self.agent_pos[0]] += 1
    return obs, reward, done, info
def step(self, action):
    self.env_shaping_step_count += 1
    self.just_made_obj_type = None
    self.just_eaten_type = None
    self.just_placed_on = None
    self.just_mined_type = None
    obs, reward, done, info = super().step(action,
                                           incl_health=self.include_health)
    shelf_obs = self.gen_shelf_obs()

    """ Generate obs """
    obs_grid_string = obs.tostring()
    extra_obs = shelf_obs.flatten()
    # magic number: repeat the shelf 8 times to fill up more of the obs
    extra_obs = np.repeat(extra_obs, 8)
    num_objs = np.repeat(self.info_last['pickup_%s' % self.task[1]], 8)
    if self.include_num_objs:
        obs = np.concatenate((obs, extra_obs, num_objs))
    else:
        obs = np.concatenate((obs, extra_obs))

    """ Generate reward """
    solved = self.solved_task()
    if 'make' in self.task[0]:
        reward = self.get_make_reward()
        if self.task[0] == 'make':
            info.update({
                'progress': (self.max_make_idx + 1) / len(self.make_sequence)
            })
    else:
        reward = int(solved)

    """ Generate info """
    info.update({'health': self.health})
    info.update(self.info_last)
    if solved:
        if self.end_on_task_completion:
            done = True
        info.update({'solved': True})
        if self.lifelong:
            # remove obj so we can keep making
            self.carrying = None
    else:
        info.update({'solved': False})
    if self.time_horizon and self.step_count % self.time_horizon == 0:
        done = True

    """ Exploration bonuses """
    self.obs_count[obs_grid_string] = self.obs_count.get(
        obs_grid_string, 0) + 1
    if self.cbe:
        # count-based exploration bonus: 1 / sqrt(N(s))
        reward += 1 / np.sqrt(self.obs_count[obs_grid_string])
    elif self.rnd:
        self.sum_rnd_obs += obs
        torch_obs = torch_ify(obs)
        true_rnd = self.rnd_network(torch_obs)
        pred_rnd = self.rnd_target_network(torch_obs)
        loss = self.rnd_loss(true_rnd, pred_rnd)
        self.rnd_optimizer.zero_grad()
        loss.backward()
        self.rnd_optimizer.step()
        # RND exploration bonus
        self.sum_rnd_loss += loss.detach()
        self.sum_square_rnd_loss += loss.detach()**2
        mean = self.sum_rnd_loss / self.step_count
        # note: despite the name, this is the running variance of the loss
        stdev = (self.sum_square_rnd_loss / self.step_count) - mean**2
        # dividing by a zero tensor yields inf rather than raising
        # ZeroDivisionError, so guard explicitly; stdev is 0 only on
        # the first timestep
        if float(stdev) == 0.:
            bonus = 1
        else:
            bonus = np.clip((loss / stdev).detach().numpy(), -1, 1)
        reward += bonus
    if self.hitting_time == 0 and reward > 0:
        self.hitting_time = self.step_count
    # funny ordering because otherwise we'd get the transpose
    # due to how the grid indices work
    self.visit_count[self.agent_pos[1], self.agent_pos[0]] += 1
    return obs, reward, done, info
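# Minimal self-contained sketch of the RND bonus pattern used in the two
# step() methods above: a fixed random "target" net and a trained predictor;
# the prediction error is the novelty bonus. Network sizes are assumptions.
import torch
import torch.nn as nn

obs_dim, feat_dim = 16, 32
target = nn.Linear(obs_dim, feat_dim)
for p in target.parameters():
    p.requires_grad_(False)            # target network stays fixed
predictor = nn.Linear(obs_dim, feat_dim)
opt = torch.optim.Adam(predictor.parameters(), lr=1e-4)

obs = torch.randn(obs_dim)
loss = ((predictor(obs) - target(obs))**2).mean()
opt.zero_grad()
loss.backward()
opt.step()
bonus = loss.item()                    # large for novel observations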
def get_action(self, obs_np):
    dist_vec = eval_np(self, obs_np)
    return Categorical(torch_ify(dist_vec)).sample().item(), {}
def train_from_torch(self, batch):
    rewards = batch['rewards']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']

    """
    Policy and Alpha Loss
    """
    dist = Categorical(self.policy(obs))
    new_obs_actions = dist.sample()
    log_pi = dist.log_prob(new_obs_actions)
    # log-probs of every discrete action, shape (batch, action_dim)
    log_pis = torch.stack([
        dist.log_prob(torch.tensor(ac, device=ptu.device))
        for ac in range(self.policy.action_dim)
    ]).permute(1, 0)

    if self.use_automatic_entropy_tuning:
        alpha_loss = -(self.log_alpha *
                       (log_pi + self.target_entropy).detach()).mean()
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()
        alpha = self.log_alpha.exp()
    else:
        alpha_loss = 0
        alpha = 1

    q_new_actions = torch.min(
        self.qf1(obs),
        self.qf2(obs),
    )
    # expected policy loss over the discrete action distribution:
    # sum_a pi(a|s) * (alpha * log pi(a|s) - Q(s, a))
    policy_loss = (torch.exp(log_pis) *
                   (alpha * log_pis - q_new_actions)).sum(dim=1).mean()

    """
    QF Loss
    """
    action_idx = actions.argmax(dim=1).unsqueeze(1)
    q1_pred = self.qf1(obs).gather(1, action_idx)
    q2_pred = self.qf2(obs).gather(1, action_idx)
    new_dist = Categorical(self.policy(next_obs))
    new_next_actions = new_dist.sample()
    new_log_pi = new_dist.log_prob(new_next_actions).unsqueeze(1)

    target_q_values = torch.min(
        self.target_qf1(next_obs).gather(1, new_next_actions.unsqueeze(1)),
        self.target_qf2(next_obs).gather(1, new_next_actions.unsqueeze(1)),
    ) - alpha * new_log_pi

    q_target = self.reward_scale * rewards + \
        (1. - terminals) * self.discount * target_q_values
    qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
    qf2_loss = self.qf_criterion(q2_pred, q_target.detach())

    """
    Update networks
    """
    self.qf1_optimizer.zero_grad()
    qf1_loss.backward()
    self.qf1_optimizer.step()

    self.qf2_optimizer.zero_grad()
    qf2_loss.backward()
    self.qf2_optimizer.step()

    self.policy_optimizer.zero_grad()
    policy_loss.backward()
    self.policy_optimizer.step()

    """
    Soft Updates
    """
    if self._n_train_steps_total % self.target_update_period == 0:
        ptu.soft_update_from_to(self.qf1, self.target_qf1,
                                self.soft_target_tau)
        ptu.soft_update_from_to(self.qf2, self.target_qf2,
                                self.soft_target_tau)

    """
    Save some statistics for eval
    """
    if self._need_to_update_eval_statistics:
        self._need_to_update_eval_statistics = False
        """
        Eval should set this to None.
        This way, these statistics are only computed for one batch.
        """
        policy_loss = (torch.exp(log_pis) *
                       (log_pis - q_new_actions)).sum(dim=1).mean()
        self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
        self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
        self.eval_statistics['Policy Loss'] = np.mean(
            ptu.get_numpy(policy_loss))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q1 Predictions',
            ptu.get_numpy(q1_pred),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q2 Predictions',
            ptu.get_numpy(q2_pred),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q Targets',
            ptu.get_numpy(q_target),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Log Pis',
            ptu.get_numpy(log_pi),
        ))
        if self.use_automatic_entropy_tuning:
            self.eval_statistics['Alpha'] = alpha.item()
            self.eval_statistics['Alpha Loss'] = alpha_loss.item()
    self._n_train_steps_total += 1
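# Worked shape-check of the discrete policy-loss term above with made-up
# numbers: probs * (alpha * log_probs - Q), summed over actions, averaged
# over the batch.
import torch

probs = torch.tensor([[0.7, 0.3], [0.5, 0.5]])   # pi(a|s), batch of 2
log_pis = probs.log()
q = torch.tensor([[1.0, 0.0], [0.2, 0.4]])       # assumed Q(s, a) values
alpha = 0.1
policy_loss = (probs * (alpha * log_pis - q)).sum(dim=1).mean()
print(policy_loss)  # a scalar, differentiable w.r.t. the policy probs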