import numpy as np
import torch


def update_params(batch):
    # value_net, policy_net, args, dtype and device are module-level globals.
    # Stack the sampled transitions into tensors on the target device.
    states = torch.from_numpy(np.stack(batch.state)).to(dtype).to(device)
    actions = torch.from_numpy(np.stack(batch.action)).to(dtype).to(device)
    rewards = torch.from_numpy(np.stack(batch.reward)).to(dtype).to(device)
    masks = torch.from_numpy(np.stack(batch.mask)).to(dtype).to(device)
    with torch.no_grad():
        values = value_net(states)
    """get advantage estimation from the trajectories"""
    advantages, returns = estimate_advantages(rewards, masks, values,
                                              args.gamma, args.tau, device)
    """perform TRPO update"""
    trpo_step(policy_net, value_net, states, actions, returns, advantages,
              args.max_kl, args.damping, args.l2_reg)
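
Both examples delegate advantage computation to an estimate_advantages helper that is not shown here. As a rough guide, a minimal GAE-style sketch with the same (rewards, masks, values, gamma, tau, device) signature could look like the following; the actual implementation in the source repository may differ:

def estimate_advantages(rewards, masks, values, gamma, tau, device):
    # Generalized Advantage Estimation, computed backwards over the flattened
    # trajectories; masks[i] == 0 marks the last step of an episode and stops
    # bootstrapping across episode boundaries.
    rewards, masks, values = rewards.cpu(), masks.cpu(), values.cpu()
    deltas = torch.zeros(rewards.size(0), 1, dtype=values.dtype)
    advantages = torch.zeros(rewards.size(0), 1, dtype=values.dtype)
    prev_value, prev_advantage = 0.0, 0.0
    for i in reversed(range(rewards.size(0))):
        deltas[i] = rewards[i] + gamma * prev_value * masks[i] - values[i]
        advantages[i] = deltas[i] + gamma * tau * prev_advantage * masks[i]
        prev_value = values[i, 0]
        prev_advantage = advantages[i, 0]
    returns = values + advantages
    # normalize advantages before the policy step
    advantages = (advantages - advantages.mean()) / advantages.std()
    return advantages.to(device), returns.to(device)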
Example #2
import torch


def update_params_trpo(batch, i_iter):
    # value_np, value_net, policy_net, args, dtype, device, state_dim and
    # action_dim are module-level globals.
    value_np.training = False  # switch the value model out of training mode
    # record the index of the last transition of each episode so its mask can
    # be zeroed (episode boundary), and count the total number of transitions
    num_transition = 0
    idx = []
    for b in batch:
        l = len(b)
        idx.append(l + num_transition - 1)
        num_transition += l
    # allocate buffers with the module-level dtype/device
    masks = torch.ones(num_transition, dtype=dtype, device=device)
    masks[idx] = 0
    states = torch.zeros(1, num_transition, state_dim, dtype=dtype, device=device)
    actions = torch.zeros(1, num_transition, action_dim, dtype=dtype, device=device)
    disc_rewards = torch.zeros(1, num_transition, 1, dtype=dtype, device=device)
    rewards = torch.zeros(1, num_transition, 1, dtype=dtype, device=device)
    # flatten every episode into one long sequence of transitions
    i = 0
    for ep in batch:
        for tr in ep:
            states[:, i, :] = torch.from_numpy(tr.state).to(dtype).to(device)
            actions[:, i, :] = torch.from_numpy(tr.action).to(dtype).to(device)
            disc_rewards[:, i, :] = torch.tensor(tr.disc_rew).to(dtype).to(device)
            rewards[:, i, :] = torch.tensor(tr.reward).to(dtype).to(device)
            i += 1

    with torch.no_grad():
        # estimate the value of each state with the value model; it returns a
        # distribution whose mean is used as the value estimate
        values_distr = value_np(states, disc_rewards, states)
        values = values_distr.mean
    """get advantage estimation from the trajectories"""
    advantages, returns = estimate_advantages(rewards.squeeze_(0), masks,
                                              values.squeeze_(0), args.gamma,
                                              args.tau, device)
    """plot"""
    plot_values(value_np, value_net, states, values, advantages, disc_rewards,
                env, args, i_iter)
    """perform TRPO update"""
    trpo_step(policy_net, value_net, states.squeeze(0), actions.squeeze(0),
              returns.squeeze(0), advantages, args.max_kl, args.damping,
              args.l2_reg)
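
For reference, the first example expects batch to expose field-wise sequences (batch.state, batch.action, batch.reward, batch.mask) that np.stack can turn into arrays. One hypothetical way to collect such a batch, assuming a Transition namedtuple, a Gym-style env with the classic 4-tuple step API, and a select_action helper (all stand-ins, not taken from the source), is sketched below:

from collections import namedtuple

import numpy as np

# hypothetical container matching the fields accessed in update_params
Transition = namedtuple('Transition', ('state', 'action', 'reward', 'mask'))

def collect_batch(env, select_action, horizon=1000):
    # roll out one episode and return a field-wise batch for update_params
    memory = []
    state = env.reset()
    for _ in range(horizon):
        action = select_action(state)                 # e.g. sample from policy_net
        next_state, reward, done, _ = env.step(action)
        # mask is 0 on the last transition of the episode, 1 otherwise
        memory.append(Transition(state, action, reward, 0 if done else 1))
        state = next_state
        if done:
            break
    # zip the list of transitions into per-field tuples
    return Transition(*zip(*memory))

# batch = collect_batch(env, select_action)
# update_params(batch)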