def get_action(self, state, deterministic, epsilon=1e-6):
    mean, log_std = self.forward(state)
    # Standard-normal noise with the same shape/device as the policy outputs.
    normal = Normal(torch.zeros_like(mean), torch.ones_like(log_std))
    z = normal.sample()
    if self.args.stochastic_actor:
        std = log_std.exp()
        # Reparameterized sample, squashed by tanh and rescaled to the action bounds.
        action_0 = mean + z * std
        action_1 = torch.tanh(action_0)
        action = self.action_range * action_1 + self.action_bias
        # Log-probability with the tanh / rescaling change-of-variables correction.
        log_prob = Normal(mean, std).log_prob(action_0) - torch.log(
            1. - action_1.pow(2) + epsilon) - torch.log(self.action_range)
        log_prob = log_prob.sum(dim=-1, keepdim=True)
        action_mean = self.action_range * torch.tanh(mean) + self.action_bias
        action = action_mean if deterministic else action
        return action.detach().cpu().numpy(), log_prob.detach().item()
    else:
        # Deterministic actor: add exploration noise, then clip to the action bounds.
        action_mean = self.action_range * torch.tanh(mean) + self.action_bias
        action = action_mean + 0.1 * self.action_range * z
        action = torch.min(action, self.action_high)
        action = torch.max(action, self.action_low)
        action = action_mean if deterministic else action
        return action.detach().cpu().numpy(), 0
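The stochastic branch above is the standard tanh-squashed Gaussian policy with a change-of-variables correction on the log-probability. A minimal, self-contained sketch of just that sampling step, assuming the action range and bias are plain floats (an illustrative simplification of the class attributes):

import torch
from torch.distributions import Normal

def squashed_gaussian_sample(mean, log_std, action_range=1.0, action_bias=0.0,
                             epsilon=1e-6):
    # Reparameterized sample: tanh(mean + z * std), rescaled to the action bounds.
    std = log_std.exp()
    z = Normal(torch.zeros_like(mean), torch.ones_like(std)).sample()
    pre_tanh = mean + z * std
    squashed = torch.tanh(pre_tanh)
    action = action_range * squashed + action_bias
    # Gaussian log-density corrected for the tanh squashing and the rescaling.
    log_prob = (Normal(mean, std).log_prob(pre_tanh)
                - torch.log(1. - squashed.pow(2) + epsilon)
                - torch.log(torch.tensor(action_range)))
    return action, log_prob.sum(dim=-1, keepdim=True)

# Illustrative call: a batch of one 2-dimensional action in [-2, 2].
action, log_prob = squashed_gaussian_sample(torch.zeros(1, 2), torch.zeros(1, 2),
                                            action_range=2.0)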
Example #2
    def forward(self, ss: List, phase_use_mode: bool = False) -> Tuple:

        p_pres_logits, p_where_mean, p_where_std, p_depth_mean, \
        p_depth_std, p_what_mean, p_what_std = ss

        if phase_use_mode:
            z_pres = (p_pres_logits > 0).float()
        else:
            z_pres = RelaxedBernoulli(logits=p_pres_logits, temperature=self.args.train.tau_pres).rsample()

        # z_where_scale, z_where_shift: (bs, dim, num_cell, num_cell)
        if phase_use_mode:
            z_where_scale, z_where_shift = p_where_mean.chunk(2, 1)
        else:
            z_where_scale, z_where_shift = \
                Normal(p_where_mean, p_where_std).rsample().chunk(2, 1)

        # z_where_origin: (bs, dim, num_cell, num_cell)
        z_where_origin = \
            torch.cat([z_where_scale.detach(), z_where_shift.detach()], dim=1)

        z_where_shift = \
            (2. / self.args.arch.num_cell) * \
            (self.offset + 0.5 + torch.tanh(z_where_shift)) - 1.

        # Split z_where_scale into an overall scale and an aspect ratio, then
        # convert them into per-axis scales.
        scale, ratio = z_where_scale.chunk(2, 1)
        scale = scale.sigmoid()
        ratio = torch.exp(ratio)
        ratio_sqrt = ratio.sqrt()
        z_where_scale = torch.cat([scale / ratio_sqrt, scale * ratio_sqrt], dim=1)
        # z_where: (bs, dim, num_cell, num_cell)
        z_where = torch.cat([z_where_scale, z_where_shift], dim=1)

        if phase_use_mode:
            z_depth = p_depth_mean
            z_what = p_what_mean
        else:
            z_depth = Normal(p_depth_mean, p_depth_std).rsample()
            z_what = Normal(p_what_mean, p_what_std).rsample()

        z_what_reshape = z_what.permute(0, 2, 3, 1).reshape(-1, self.args.z.z_what_dim). \
            view(-1, self.args.z.z_what_dim, 1, 1)

        if self.args.data.inp_channel == 1 or not self.args.arch.phase_overlap:
            o = self.z_what_decoder_net(z_what_reshape)
            o = o.sigmoid()
            a = o.new_ones(o.size())
        elif self.args.arch.phase_overlap:
            o, a = self.z_what_decoder_net(z_what_reshape).split([self.args.data.inp_channel, 1], dim=1)
            o, a = o.sigmoid(), a.sigmoid()
        else:
            raise NotImplementedError

        lv = [z_pres, z_where, z_depth, z_what, z_where_origin]
        pa = [o, a]

        return pa, lv
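For z_pres above, training draws a differentiable RelaxedBernoulli sample while `phase_use_mode` hard-thresholds the logits. A small self-contained sketch of that contrast, with made-up tensor shapes and an assumed temperature value:

import torch
from torch.distributions import RelaxedBernoulli

p_pres_logits = torch.randn(2, 1, 4, 4)  # (bs, 1, num_cell, num_cell), illustrative
tau_pres = 0.5                           # assumed temperature

# Training path: reparameterized, differentiable sample in (0, 1).
z_pres_soft = RelaxedBernoulli(logits=p_pres_logits, temperature=tau_pres).rsample()

# Mode path: hard presence indicator (logit > 0 means probability > 0.5).
z_pres_hard = (p_pres_logits > 0).float()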
Example #3
def optimize_model(policy, q1_net, q2_net, v_net, v_target_net, memory,
                   actor_optimizer, q_net_optimizer, v_net_optimizer):
    if len(memory) < train_batch_size:
        return 0, 0, 0  # dummy losses for consistency in presenting results

    # Sample a batch of transitions and move it to the training device.
    st_b, ac_b, rew_b, nst_b, dn_b = memory.sample(train_batch_size)
    states_th = torch.tensor(st_b).float().to(device)
    actions_th = torch.tensor(ac_b).float().to(device)
    rewards_th = torch.tensor(rew_b).float().unsqueeze(1).to(device)
    next_states_th = torch.tensor(nst_b).float().to(device)
    dones_th = torch.tensor(dn_b).float().unsqueeze(1).to(device)

    Q1_vals = q1_net(states_th, actions_th)
    Q2_vals = q2_net(states_th, actions_th)
    V_vals = v_net(states_th)
    # The target value network only provides bootstrap targets; no gradient needed.
    V_next_state_vals = v_target_net(next_states_th).detach()
    pi_action_means, pi_action_logstd = policy(states_th)
    pi_action_stds = torch.exp(pi_action_logstd)

    # Reparameterization trick: fresh actions sampled from the current policy.
    z = Normal(torch.zeros_like(pi_action_means),
               torch.ones_like(pi_action_stds)).sample()
    newly_sampled_actions = pi_action_means + z * pi_action_stds
    # Sum per-dimension log-probabilities so shapes match the (batch, 1) Q-values.
    newly_sampled_action_log_probs = Normal(
        pi_action_means,
        pi_action_stds).log_prob(newly_sampled_actions).sum(dim=-1, keepdim=True)
    newly_sampled_Q1_vals = q1_net(states_th, newly_sampled_actions)
    newly_sampled_Q2_vals = q2_net(states_th, newly_sampled_actions)
    newly_sampled_Q_minvals = torch.min(newly_sampled_Q1_vals,
                                        newly_sampled_Q2_vals)

    # Value loss: regress V(s) toward the entropy-regularized soft Q-value target.
    J_v = torch.mean(
        (V_vals -
         (newly_sampled_Q_minvals.detach() -
          entropy_coeff * newly_sampled_action_log_probs.detach()))**2)
    v_net_optimizer.zero_grad()
    J_v.backward()
    v_net_optimizer.step()

    # Critic losses: regress both Q-networks toward the one-step TD target.
    J_q1 = torch.mean((Q1_vals - (rewards_th + gamma * V_next_state_vals *
                                  (1 - dones_th)))**2)
    J_q2 = torch.mean((Q2_vals - (rewards_th + gamma * V_next_state_vals *
                                  (1 - dones_th)))**2)
    J_q = J_q1 + J_q2
    q_net_optimizer.zero_grad()
    J_q.backward()
    q_net_optimizer.step()

    # Policy loss: maximize the soft Q-value, penalized by the scaled log-probability.
    J_pi = torch.mean(entropy_coeff * newly_sampled_action_log_probs -
                      newly_sampled_Q_minvals)
    actor_optimizer.zero_grad()
    J_pi.backward()
    actor_optimizer.step()
    # Return scalars, consistent with the dummy (0, 0, 0) early return above.
    return J_v.item(), J_q.item(), J_pi.item()
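optimize_model reads from v_target_net but never updates it; a common companion step, assumed here since it is not shown in this example, is a Polyak (soft) update of the target value network after each call:

import torch

def soft_update(v_net, v_target_net, tau=0.005):
    # target <- (1 - tau) * target + tau * online; `tau` is an illustrative value.
    with torch.no_grad():
        for param, target_param in zip(v_net.parameters(),
                                       v_target_net.parameters()):
            target_param.mul_(1.0 - tau).add_(tau * param)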
Example #4
            # Snippet is truncated above; the print format string is a reconstruction.
            print('epoch %d, batch %d, loss %.3f' %
                  (epoch, i + 1, running_loss / 20))
            running_loss = 0
        optimizer.step()
print('Done!')

### EVALUATE POLICY
max_steps = env.spec.timestep_limit  # older gym API; newer versions use env.spec.max_episode_steps
pi.cpu()
returns = []
for i in range(10):
    print('iter', i)
    obs = env.reset()
    done = False
    totalr = 0.
    steps = 0
    while not done:
        a_mu, a_sigma = pi(torch.from_numpy(obs).float())
        a = Normal(loc=a_mu, scale=a_sigma).sample()
        obs, r, done, _ = env.step(a.detach().numpy())
        if RENDER:
            env.render()
        totalr += r
        steps += 1
        if steps % 100 == 0: print("%i/%i" % (steps, max_steps))
        if steps >= max_steps:
            break
    returns.append(totalr)

print('returns', returns)
print('mean return', np.mean(returns))
print('std of return', np.std(returns))
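The evaluation loop samples from the policy's Gaussian at every step; a deterministic evaluation would act with the mean instead. A tiny self-contained sketch of the two choices (values are made up):

import torch
from torch.distributions import Normal

a_mu = torch.tensor([0.3, -0.1])
a_sigma = torch.tensor([0.2, 0.2])

stochastic_a = Normal(loc=a_mu, scale=a_sigma).sample()  # as in the loop above
deterministic_a = a_mu                                   # mean action for evaluation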