def get_action(self, state, deterministic, epsilon=1e-6):
    mean, log_std = self.forward(state)
    # Standard-normal noise for the reparameterised action sample.
    normal = Normal(torch.zeros(mean.shape), torch.ones(log_std.shape))
    z = normal.sample()
    if self.args.stochastic_actor:
        std = log_std.exp()
        # Reparameterised pre-squash action, then tanh-squash and rescale to the action range.
        action_0 = mean + torch.mul(z, std)
        action_1 = torch.tanh(action_0)
        action = torch.mul(self.action_range, action_1) + self.action_bias
        # Log-probability with the tanh change-of-variables correction and the scale term.
        log_prob = Normal(mean, std).log_prob(action_0) \
            - torch.log(1. - action_1.pow(2) + epsilon) \
            - torch.log(self.action_range)
        log_prob = log_prob.sum(dim=-1, keepdim=True)
        # Deterministic action: the squashed mean, without sampling noise.
        action_mean = torch.mul(self.action_range, torch.tanh(mean)) + self.action_bias
        action = action_mean.detach().cpu().numpy() if deterministic \
            else action.detach().cpu().numpy()
        return action, log_prob.detach().item()
    else:
        # Deterministic actor with additive exploration noise, clipped to the action bounds.
        action_mean = torch.mul(self.action_range, torch.tanh(mean)) + self.action_bias
        action = action_mean + 0.1 * torch.mul(self.action_range, z)
        action = torch.min(action, self.action_high)
        action = torch.max(action, self.action_low)
        action = action_mean.detach().cpu().numpy() if deterministic \
            else action.detach().cpu().numpy()
        return action, 0
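# For reference only (not part of the class above): a minimal, self-contained sketch of the
# squashed-Gaussian sampling and tanh log-prob correction used in get_action, assuming a
# single action dimension with action_range = 2.0 and action_bias = 0.0 (both illustrative).
import torch
from torch.distributions import Normal

mean, log_std = torch.zeros(1), torch.zeros(1)   # toy policy-head outputs
std = log_std.exp()
u = mean + std * torch.randn(1)                  # reparameterised pre-squash sample
a = 2.0 * torch.tanh(u)                          # squashed and scaled action
log_prob = (Normal(mean, std).log_prob(u)
            - torch.log(1. - torch.tanh(u).pow(2) + 1e-6)   # tanh correction
            - torch.log(torch.tensor(2.0))).sum()           # scaling term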
def forward(self, ss: List, phase_use_mode: bool = False) -> Tuple:
    p_pres_logits, p_where_mean, p_where_std, p_depth_mean, \
        p_depth_std, p_what_mean, p_what_std = ss

    if phase_use_mode:
        z_pres = (p_pres_logits > 0).float()
    else:
        z_pres = RelaxedBernoulli(logits=p_pres_logits,
                                  temperature=self.args.train.tau_pres).rsample()

    # z_where_scale, z_where_shift: (bs, dim, num_cell, num_cell)
    if phase_use_mode:
        z_where_scale, z_where_shift = p_where_mean.chunk(2, 1)
    else:
        z_where_scale, z_where_shift = \
            Normal(p_where_mean, p_where_std).rsample().chunk(2, 1)

    # z_where_origin: (bs, dim, num_cell, num_cell)
    z_where_origin = \
        torch.cat([z_where_scale.detach(), z_where_shift.detach()], dim=1)

    # Map the per-cell shift into normalised [-1, 1] image coordinates.
    z_where_shift = \
        (2. / self.args.arch.num_cell) * \
        (self.offset + 0.5 + torch.tanh(z_where_shift)) - 1.

    # Decompose the scale into an overall size and an aspect ratio.
    scale, ratio = z_where_scale.chunk(2, 1)
    scale = scale.sigmoid()
    ratio = torch.exp(ratio)
    ratio_sqrt = ratio.sqrt()
    z_where_scale = torch.cat([scale / ratio_sqrt, scale * ratio_sqrt], dim=1)

    # z_where: (bs, dim, num_cell, num_cell)
    z_where = torch.cat([z_where_scale, z_where_shift], dim=1)

    if phase_use_mode:
        z_depth = p_depth_mean
        z_what = p_what_mean
    else:
        z_depth = Normal(p_depth_mean, p_depth_std).rsample()
        z_what = Normal(p_what_mean, p_what_std).rsample()

    z_what_reshape = z_what.permute(0, 2, 3, 1).reshape(-1, self.args.z.z_what_dim). \
        view(-1, self.args.z.z_what_dim, 1, 1)

    if self.args.data.inp_channel == 1 or not self.args.arch.phase_overlap:
        o = self.z_what_decoder_net(z_what_reshape)
        o = o.sigmoid()
        a = o.new_ones(o.size())
    elif self.args.arch.phase_overlap:
        o, a = self.z_what_decoder_net(z_what_reshape).split(
            [self.args.data.inp_channel, 1], dim=1)
        o, a = o.sigmoid(), a.sigmoid()
    else:
        raise NotImplementedError

    lv = [z_pres, z_where, z_depth, z_what, z_where_origin]
    pa = [o, a]

    return pa, lv
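# Standalone illustration (num_cell = 4 is an assumed value, not taken from the model config):
# the z_where_shift transform above places each grid cell's centre, plus a tanh-bounded offset,
# into normalised [-1, 1] image coordinates.
import torch

num_cell = 4
offset = torch.arange(num_cell).float()   # per-cell index, analogous to self.offset
shift_raw = torch.zeros(num_cell)         # tanh(0) = 0, so we recover the exact cell centres
centres = (2. / num_cell) * (offset + 0.5 + torch.tanh(shift_raw)) - 1.
print(centres)  # tensor([-0.7500, -0.2500,  0.2500,  0.7500])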
def optimize_model(policy, q1_net, q2_net, v_net, v_target_net, memory,
                   actor_optimizer, q_net_optimizer, v_net_optimizer):
    if len(memory) < train_batch_size:
        return 0, 0, 0  # dummy losses for consistency in presenting results

    st_b, ac_b, rew_b, nst_b, dn_b = memory.sample(train_batch_size)
    states_th = torch.tensor(st_b).float().to(device)
    actions_th = torch.tensor(ac_b).to(device)
    rewards_th = torch.tensor(rew_b).unsqueeze(1).to(device)
    next_states_th = torch.tensor(nst_b).float().to(device)
    dones_th = torch.tensor(dn_b).float().unsqueeze(1).to(device)

    # Critic and value estimates for the sampled transitions.
    Q1_vals = q1_net(states_th, actions_th)
    Q2_vals = q2_net(states_th, actions_th)
    V_vals = v_net(states_th)
    V_next_state_vals = v_target_net(next_states_th)

    # Reparameterised actions sampled from the current policy.
    pi_action_means, pi_action_logstd = policy(states_th)
    pi_action_stds = torch.exp(pi_action_logstd)
    z = Normal(torch.zeros_like(pi_action_means),
               torch.ones_like(pi_action_stds)).sample()
    newly_sampled_actions = pi_action_means + z * pi_action_stds
    newly_sampled_action_log_probs = Normal(
        pi_action_means, pi_action_stds).log_prob(newly_sampled_actions)

    newly_sampled_Q1_vals = q1_net(states_th, newly_sampled_actions)
    newly_sampled_Q2_vals = q2_net(states_th, newly_sampled_actions)
    newly_sampled_Q_minvals = torch.min(newly_sampled_Q1_vals,
                                        newly_sampled_Q2_vals)

    # Value loss: regress V(s) towards min Q(s, a~pi) - alpha * log pi(a|s).
    J_v = torch.mean(
        (V_vals - (newly_sampled_Q_minvals.detach()
                   - entropy_coeff * newly_sampled_action_log_probs.detach()))**2)
    v_net_optimizer.zero_grad()
    J_v.backward()
    v_net_optimizer.step()

    # Critic losses: regress both Q networks towards the soft Bellman target.
    J_q1 = torch.mean((Q1_vals - (rewards_th + gamma * V_next_state_vals *
                                  (1 - dones_th)))**2)
    J_q2 = torch.mean((Q2_vals - (rewards_th + gamma * V_next_state_vals *
                                  (1 - dones_th)))**2)
    J_q = J_q1 + J_q2
    q_net_optimizer.zero_grad()
    J_q.backward()
    q_net_optimizer.step()

    # Policy loss: maximise min Q(s, a~pi) - alpha * log pi(a|s).
    J_pi = torch.mean(entropy_coeff * newly_sampled_action_log_probs -
                      newly_sampled_Q_minvals)
    actor_optimizer.zero_grad()
    J_pi.backward()
    actor_optimizer.step()

    return J_v, J_q, J_pi
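# Not shown in optimize_model above: v_target_net is typically kept in sync with v_net via a
# soft (Polyak) update after each optimisation step. A minimal sketch; the function name and
# soft_tau value are illustrative assumptions, not taken from the code above.
import torch

def soft_update(net, target_net, soft_tau=0.005):
    with torch.no_grad():
        for param, target_param in zip(net.parameters(), target_net.parameters()):
            target_param.data.copy_(
                soft_tau * param.data + (1.0 - soft_tau) * target_param.data)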
        (epoch, i + 1, running_loss / 20))
        running_loss = 0
        optimizer.step()

print('Done!')

### EVALUATE POLICY
max_steps = env.spec.timestep_limit
pi.cpu()
returns = []
for i in range(10):
    print('iter', i)
    obs = env.reset()
    done = False
    totalr = 0.
    steps = 0
    while not done:
        # Sample an action from the Gaussian policy and step the environment.
        a_mu, a_sigma = pi(torch.from_numpy(obs).float())
        a = Normal(loc=a_mu, scale=a_sigma).sample()
        obs, r, done, _ = env.step(a.detach().numpy())
        if RENDER:
            env.render()
        totalr += r
        steps += 1
        if steps % 100 == 0:
            print("%i/%i" % (steps, max_steps))
        if steps >= max_steps:
            break
    returns.append(totalr)

print('returns', returns)
print('mean return', np.mean(returns))
print('std of return', np.std(returns))