Example #1
 def backtrack_fn(s):
     assign(initial_parameters + s.data, net.parameters())
     test_pds = net(all_states)
     test_action_log_probs = net.get_loglikelihood(test_pds, actions)
     new_reward = surrogate_reward(advs, new=test_action_log_probs, old=old_log_ps).mean()
     if new_reward <= surr_rew or net.calc_kl(pds, test_pds).mean() > params.MAX_KL:
         return -float('inf')
     return new_reward - surr_rew
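
The `surrogate_reward` helper called above is not shown on this page. As the comments in the later examples put it, it is basically `exp(new_log_ps - old_log_ps) * advantage`, optionally clipped. A minimal sketch under that reading, assuming `ch` aliases PyTorch as elsewhere in this code; the version called in these examples also accepts `mask` and `normalize` arguments (see Example #5), omitted here:

import torch as ch

def surrogate_reward_sketch(advs, *, new, old, clip_eps=None):
    # Importance ratio between the current policy and the rollout policy:
    # exp(log pi_new(a|s) - log pi_old(a|s)).
    ratio = ch.exp(new - old)
    if clip_eps is not None:
        # PPO-style clipping keeps the ratio inside [1 - eps, 1 + eps].
        ratio = ch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps)
    # Per-sample surrogate reward: ratio times the (GAE) advantage.
    return advs * ratio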
Example #2
 def backtrack_fn(s):
     assign(initial_parameters + s.data, net.parameters())
     test_pds = net(all_states)
     test_action_log_probs = net.get_loglikelihood(test_pds, actions)
     new_reward = surrogate_reward(advs, new=test_action_log_probs, old=old_log_ps).mean()
     # surr_rew is the surrogate reward before optimization.
     # We need to make sure the objective is improving and that the KL divergence from the old action distribution is not too large.
     if params.TRPO_KL_REDUCE_FUNC == 'mean':
         kl_metric = net.calc_kl(pds, test_pds).mean()
     elif params.TRPO_KL_REDUCE_FUNC == 'max':
         kl_metric = net.calc_kl(pds, test_pds).max()
     else:
         raise ValueError("unknown reduce function " + params.TRPO_KL_REDUCE_FUNC)
     if new_reward <= surr_rew or kl_metric > params.MAX_KL:
         return -float('inf')
     return new_reward - surr_rew
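
`net.calc_kl` above returns a per-state KL divergence between the old and the candidate action distributions, which is then reduced with `.mean()` or `.max()`. For a diagonal Gaussian policy parameterized as (mean, std), which is what the later examples suggest, such a KL could be computed as sketched below; this is an illustrative assumption, not the repository's implementation:

import torch as ch

def calc_kl_diag_gaussian(p, q):
    # p and q are (mean, std) pairs, each of shape (batch, action_dim).
    (mean_p, std_p), (mean_q, std_q) = p, q
    var_p, var_q = std_p.pow(2), std_q.pow(2)
    # Closed-form KL(N(mean_p, var_p) || N(mean_q, var_q)) per dimension.
    kl = ch.log(std_q / std_p) + (var_p + (mean_p - mean_q).pow(2)) / (2.0 * var_q) - 0.5
    # Sum over action dimensions -> one KL value per state.
    return kl.sum(dim=-1)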
Example #3
 def backtrack_fn(s):
     assign(initial_parameters + s.data, net.parameters())
     test_pds = net(all_states)
     test_action_log_probs = net.get_loglikelihood(test_pds, actions)
     new_reward = surrogate_reward(advs,
                                   new=test_action_log_probs,
                                   old=old_log_ps).mean()
     if params.USE_CONS == 'all':
         if new_reward <= surr_rew or net.calc_kl(
                 pds, test_pds).mean() > params.MAX_KL:
             return -float('inf')
     elif params.USE_CONS == 'kl':
         if net.calc_kl(pds, test_pds).mean() > params.MAX_KL:
             return -float('inf')
     elif params.USE_CONS == 'rew':
         if new_reward <= surr_rew:
             return -float('inf')
     elif params.USE_CONS == 'none':
         pass
     else:
         raise NotImplementedError("No such constraints")
     return new_reward - surr_rew
Example #4
def trpo_step(all_states, actions, old_log_ps, rewards, returns, not_dones, advs, net, params, store, opt_step):
    '''
    Trust Region Policy Optimization
    Runs a single TRPO update as in https://arxiv.org/abs/1502.05477
    Inputs:
    - all_states, the historical value of all the states
    - actions, the actions that the policy sampled
    - old_log_ps, the log probability of the actions that the policy sampled
    - advs, advantages as estimated by GAE
    - net, policy network to train [WILL BE MUTATED]
    - params, additional placeholder for parameters like EPS
    Returns:
    - The TRPO loss; main job is to mutate the net
    '''    
    # Initial setup
    initial_parameters = flatten(net.parameters()).clone()
    # all_states is in shape (experience_size, observation_size). Usually 2048 experiences.
    # Get mean and std of action distribution for all experiences.
    pds = net(all_states)
    # And compute the log probabilities for the actions chosen at rollout time.
    action_log_probs = net.get_loglikelihood(pds, actions)

    # Calculate losses
    surr_rew = surrogate_reward(advs, new=action_log_probs, old=old_log_ps).mean()
    grad = ch.autograd.grad(surr_rew, net.parameters(), retain_graph=True)
    # This is the flattened first-order gradient; the graph is retained because the forward pass is reused below for the second-order (Fisher) computation.
    flat_grad = flatten(grad)

    # Build the Fisher-vector product estimator, using only a fraction of the examples.
    num_samples = int(all_states.shape[0] * params.FISHER_FRAC_SAMPLES)
    selected = np.random.choice(range(all_states.shape[0]), num_samples, replace=False)
    
    detached_selected_pds = select_prob_dists(pds, selected, detach=True)
    selected_pds = select_prob_dists(pds, selected, detach=False)
    
    # Construct the KL divergence we will work with. Its value is essentially 0 at the current parameters, but what we care about is its Hessian:
    # how the KL divergence of the network output changes when the network parameters change.
    kl = net.calc_kl(detached_selected_pds, selected_pds).mean()
    # g is the gradient of the KL divergence w.r.t. the parameters; it is 0 at the starting point.
    g = flatten(ch.autograd.grad(kl, net.parameters(), create_graph=True))
    '''
    Fisher matrix-vector product with a vector x. Essentially, a Hessian-vector product of the KL divergence w.r.t. the network parameters.
    '''
    def fisher_product(x, damp_coef=1.):
        contig_flat = lambda q: ch.cat([y.contiguous().view(-1) for y in q])
        # z is the gradient-vector product; differentiating it gives the Hessian-vector product.
        z = g @ x
        hv = ch.autograd.grad(z, net.parameters(), retain_graph=True)
        return contig_flat(hv).detach() + x*params.DAMPING * damp_coef

    # Find the KL-constrained gradient step.
    # The Fisher matrix A is never formed explicitly, but we can compute products with it.
    # flat_grad is the right-hand side b; we solve Ax = b for x.
    step = cg_solve(fisher_product, flat_grad, params.CG_STEPS)
    # "step" is the solution and has the same size as the flattened network parameters.

    max_step_coeff = (2 * params.MAX_KL / (step @ fisher_product(step)))**(0.5)
    max_trpo_step = max_step_coeff * step

    if store and params.SHOULD_LOG_KL:
        kl_approximation_logging(all_states, pds, flat_grad, step, net, store)
        kl_vs_second_order_approx(all_states, pds, net, max_trpo_step, params, store, opt_step)

    # Backtracking line search
    with ch.no_grad():
        # Backtracking function, which gives the improvement in the objective for an update step s.
        def backtrack_fn(s):
            assign(initial_parameters + s.data, net.parameters())
            test_pds = net(all_states)
            test_action_log_probs = net.get_loglikelihood(test_pds, actions)
            new_reward = surrogate_reward(advs, new=test_action_log_probs, old=old_log_ps).mean()
            # surr_rew is the surrogate reward before optimization.
            # We need to make sure the objective is improving and that the KL divergence from the old action distribution is not too large.
            if params.TRPO_KL_REDUCE_FUNC == 'mean':
                kl_metric = net.calc_kl(pds, test_pds).mean()
            elif params.TRPO_KL_REDUCE_FUNC == 'max':
                kl_metric = net.calc_kl(pds, test_pds).max()
            else:
                raise ValueError("unknown reduce function " + params.TRPO_KL_REDUCE_FUNC)
            if new_reward <= surr_rew or kl_metric > params.MAX_KL:
                return -float('inf')
            return new_reward - surr_rew
        expected_improve = flat_grad @ max_trpo_step
        # max_trpo_step is the search direction; the backtracking line search finds a scaling factor for it.
        # expected_improve is the improvement in the objective predicted by the first-order (gradient) approximation.
        # backtracking_line_search tries scaling factors 0.5, 0.25, 0.125, etc. until an acceptable fraction of the expected improvement is achieved.
        final_step = backtracking_line_search(backtrack_fn, max_trpo_step,
                                              expected_improve,
                                              num_tries=params.MAX_BACKTRACK)

        assign(initial_parameters + final_step, net.parameters())

    # Entropy regularization is not used for TRPO, so return 0 for the remaining terms.
    return surr_rew.item(), 0.0, 0.0
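
`cg_solve` is characterized only by its call site above: it solves Ax = b, where the Fisher matrix A is available solely through `fisher_product` and b is `flat_grad`. A standard conjugate-gradient sketch against that interface, assuming `ch` aliases PyTorch; the solver actually used here may add its own stopping criteria:

import torch as ch

def cg_solve_sketch(mat_vec_product, b, num_steps, tol=1e-10):
    # Solve A x = b for symmetric positive-definite A, given only x -> A x.
    x = ch.zeros_like(b)
    r = b.clone()          # residual b - A x, with x = 0 initially
    p = b.clone()          # search direction
    rs_old = r @ r
    for _ in range(num_steps):
        Ap = mat_vec_product(p)
        alpha = rs_old / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        if rs_new < tol:
            break
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x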
Example #5
def robust_ppo_step(all_states, actions, old_log_ps, rewards, returns, not_dones, 
                advs, net, params, store, opt_step, relaxed_net, eps_scheduler, beta_scheduler):
    '''
    Proximal Policy Optimization with robustness regularizer
    Runs K epochs of PPO as in https://arxiv.org/abs/1707.06347
    Inputs:
    - all_states, the historical value of all the states
    - actions, the actions that the policy sampled
    - old_log_ps, the log probability of the actions that the policy sampled
    - advs, advantages as estimated by GAE
    - net, policy network to train [WILL BE MUTATED]
    - params, additional placeholder for parameters like EPS
    Returns:
    - The PPO loss; main job is to mutate the net
    '''
    # Storing batches of stuff
    # if store is not None:
    #     orig_dists = net(all_states)

    ### ACTUAL PPO OPTIMIZATION START
    if params.SHARE_WEIGHTS:
        orig_vs = net.get_value(all_states).squeeze(-1).view([params.NUM_ACTORS, -1])
        old_vs = orig_vs.detach()

    # We treat all PPO epochs of this update as a single scheduler epoch.
    eps_scheduler.set_epoch_length(params.PPO_EPOCHS * params.NUM_MINIBATCHES)
    beta_scheduler.set_epoch_length(params.PPO_EPOCHS * params.NUM_MINIBATCHES)
    # The schedulers count epochs starting from 1.
    eps_scheduler.step_epoch()
    beta_scheduler.step_epoch()

    if params.HISTORY_LENGTH > 0:
        # LSTM policy. Need to go over all episodes instead of states.
        # We normalize all advantages at once instead of batch by batch, since each batch may contain a different number of samples.
        normalized_advs = adv_normalize(advs)
        batches, alive_masks, time_masks, lengths = pack_history([all_states, actions, old_log_ps, normalized_advs], not_dones, max_length=params.HISTORY_LENGTH)


    for _ in range(params.PPO_EPOCHS):
        if params.HISTORY_LENGTH > 0:
            # LSTM policy. Need to go over all episodes instead of states.
            params.POLICY_ADAM.zero_grad()
            hidden = None
            surrogate = 0.0
            for i, batch in enumerate(batches):
                # Now we get chunks of time sequences, each of them with a maximum length of params.HISTORY_LENGTH.
                # select log probabilities, advantages of this minibatch.
                batch_states, batch_actions, batch_old_log_ps, batch_advs = batch
                mask = time_masks[i]
                """
                print('batch states', batch_states.size())
                print('batch actions', batch_actions.size())
                print('batch old_log_ps', batch_old_log_ps.size())
                print('batch advs', batch_advs.size())
                print('alive mask', alive_masks[i].size(), alive_masks[i].sum())
                print('mask', mask.size())
                """
                # keep only the alive hidden states.
                if hidden is not None:
                    # print('hidden[0]', hidden[0].size())
                    hidden = [h[:, alive_masks[i], :].detach() for h in hidden]
                    # print('hidden[0]', hidden[0].size())
                # dist contains the mean and standard deviation of the Gaussian.
                mean, std, hidden = net.multi_forward(batch_states, hidden=hidden)
                dist = mean, std
                # Compute the log-likelihood of the taken actions under this distribution.
                new_log_ps = net.get_loglikelihood(dist, batch_actions)
                # print('batch new_log_ps', new_log_ps.size())
                """
                print('old')
                print(batch_old_log_ps)
                print('new')
                print(new_log_ps * mask)
                print('diff')
                print((batch_old_log_ps - new_log_ps * mask).pow(2).sum().item())
                """

                shape_equal_cmp(new_log_ps, batch_old_log_ps)

                # Calculate rewards
                # The surrogate reward is basically exp(new_log_ps - old_log_ps) * advantage.
                # dimension is the same as minibatch size.
                # We already normalized advs before. No need to normalize here.
                unclp_rew = surrogate_reward(batch_advs, new=new_log_ps, old=batch_old_log_ps, mask=mask, normalize=False)
                clp_rew = surrogate_reward(batch_advs, new=new_log_ps, old=batch_old_log_ps,
                                           clip_eps=params.CLIP_EPS, mask=mask, normalize=False)


                # The batch loss is the minimum of the clipped and unclipped rewards for each state, negated, masked, and summed; the average over all samples is taken below.
                surrogate_batch = (-ch.min(unclp_rew, clp_rew) * mask).sum()
                # We sum the batch loss here because each batch contains an uneven number of trajectories.
                surrogate = surrogate + surrogate_batch

            # Divide the accumulated surrogate loss by the total number of samples.
            surrogate = surrogate / all_states.size(0)
            # Calculate entropy bonus
            # So far, the entropy only depends on std and does not depend on time. No need to mask.
            entropy_bonus = net.entropies(dist)
            # Calculate regularizer under state perturbation.
            eps_scheduler.step_batch()
            beta_scheduler.step_batch()
            batch_action_means = None
            current_eps = eps_scheduler.get_eps()
            stdev = ch.exp(net.log_stdev)
            if params.ROBUST_PPO_DETACH_STDEV:
                # Detach stdev so that the regularizer's gradient cannot inflate it.
                stdev = stdev.detach()
            if params.ROBUST_PPO_METHOD == "sgld":
                kl_upper_bound = get_state_kl_bound_sgld(net, all_states, None,
                        eps=current_eps, steps=params.ROBUST_PPO_PGD_STEPS,
                        stdev=stdev, not_dones=not_dones).mean()
            else:
                raise ValueError(f"Unsupported robust PPO method {params.ROBUST_PPO_METHOD}")
            entropy = -params.ENTROPY_COEFF * entropy_bonus
            loss = surrogate + entropy + params.ROBUST_PPO_REG * kl_upper_bound
            # optimizer (only ADAM)
            loss.backward()
            if params.CLIP_GRAD_NORM != -1:
                ch.nn.utils.clip_grad_norm(net.parameters(), params.CLIP_GRAD_NORM)
            params.POLICY_ADAM.step()
        else:
            # Memoryless policy.
            # State is in shape (experience_size, observation_size). Usually 2048.
            state_indices = np.arange(all_states.shape[0])
            np.random.shuffle(state_indices)
            # We use a minibatch of states to do optimization, and each epoch contains several iterations.
            splits = np.array_split(state_indices, params.NUM_MINIBATCHES)
            # A typical mini-batch size is 2048/32=64
            for selected in splits:
                def sel(*args):
                    return [v[selected] for v in args]

                # old_log_ps: log probabilities of the actions sampled, taken from the experience buffer.
                # advs: advantages of these states.
                # Both old_log_ps and advs have shape (experience_size,) = 2048.
                tup = sel(all_states, actions, old_log_ps, advs)
                # select log probabilities, advantages of this minibatch.
                batch_states, batch_actions, batch_old_log_ps, batch_advs = tup

                # Forward pass with the current parameters (which are updated continually) to get the action distributions for these states.
                # dist contains the mean and standard deviation of the Gaussian.
                dist = net(batch_states)
                # Compute the log-likelihood of the taken actions under this distribution.
                new_log_ps = net.get_loglikelihood(dist, batch_actions)

                shape_equal_cmp(new_log_ps, batch_old_log_ps)

                # Calculate rewards
                # The surrogate reward is basically exp(new_log_ps - old_log_ps) * advantage.
                # dimension is the same as minibatch size.
                unclp_rew = surrogate_reward(batch_advs, new=new_log_ps, old=batch_old_log_ps)
                clp_rew = surrogate_reward(batch_advs, new=new_log_ps, old=batch_old_log_ps,
                                           clip_eps=params.CLIP_EPS)

                # Calculate entropy bonus
                entropy_bonus = net.entropies(dist).mean()

                # Calculate regularizer under state perturbation.
                eps_scheduler.step_batch()
                beta_scheduler.step_batch()
                batch_action_means = dist[0]
                current_eps = eps_scheduler.get_eps()
                stdev = ch.exp(net.log_stdev)
                if params.ROBUST_PPO_DETACH_STDEV:
                    # Detach stdev so that the regularizer's gradient cannot inflate it.
                    stdev = stdev.detach()
                if params.ROBUST_PPO_METHOD == "convex-relax":
                    kl_upper_bound = get_state_kl_bound(relaxed_net, batch_states, batch_action_means,
                            eps=current_eps, beta=beta_scheduler.get_eps(),
                            stdev=stdev).mean()
                elif params.ROBUST_PPO_METHOD == "sgld":
                    kl_upper_bound = get_state_kl_bound_sgld(net, batch_states, batch_action_means,
                            eps=current_eps, steps=params.ROBUST_PPO_PGD_STEPS,
                            stdev=stdev).mean()
                else:
                    raise ValueError(f"Unsupported robust PPO method {params.ROBUST_PPO_METHOD}")

                # The total loss is the minimum of the clipped and unclipped rewards for each state, negated and averaged.
                surrogate = -ch.min(unclp_rew, clp_rew).mean()
                entropy = -params.ENTROPY_COEFF * entropy_bonus
                loss = surrogate + entropy + params.ROBUST_PPO_REG * kl_upper_bound
                
                # If we are sharing weights, take the value step simultaneously 
                # (since the policy and value networks depend on the same weights)
                if params.SHARE_WEIGHTS:
                    tup = sel(returns, not_dones, old_vs)
                    batch_returns, batch_not_dones, batch_old_vs = tup
                    val_loss = value_step(batch_states, batch_returns, batch_advs,
                                          batch_not_dones, net.get_value, None, params,
                                          store, old_vs=batch_old_vs, opt_step=opt_step)
                    loss += params.VALUE_MULTIPLIER * val_loss

                # Optimizer step (Adam or SGD)
                if params.POLICY_ADAM is None:
                    grad = ch.autograd.grad(loss, net.parameters())
                    flat_grad = flatten(grad)
                    if params.CLIP_GRAD_NORM != -1:
                        norm_grad = ch.norm(flat_grad)
                        flat_grad = flat_grad if norm_grad <= params.CLIP_GRAD_NORM else \
                                    flat_grad / norm_grad * params.CLIP_GRAD_NORM

                    assign(flatten(net.parameters()) - params.PPO_LR * flat_grad, net.parameters())
                else:
                    params.POLICY_ADAM.zero_grad()
                    loss.backward()
                    if params.CLIP_GRAD_NORM != -1:
                        ch.nn.utils.clip_grad_norm(net.parameters(), params.CLIP_GRAD_NORM)
                    params.POLICY_ADAM.step()
        # Logging.
        kl_upper_bound = kl_upper_bound.item()
        surrogate = surrogate.item()
        entropy_bonus = entropy_bonus.item()
        print(f'eps={eps_scheduler.get_eps():8.6f}, beta={beta_scheduler.get_eps():8.6f}, kl={kl_upper_bound:10.5g}, '
              f'surrogate={surrogate:8.5f}, entropy={entropy_bonus:8.5f}, loss={loss.item():8.5f}')
    std = ch.exp(net.log_stdev)
    print(f'std_min={std.min().item():8.5f}, std_max={std.max().item():8.5f}, std_mean={std.mean().item():8.5f}')

    if store is not None:
        # TODO: ADV: add row name suffix
        row ={
            'eps': eps_scheduler.get_eps(),
            'beta': beta_scheduler.get_eps(),
            'kl': kl_upper_bound,
            'surrogate': surrogate,
            'entropy': entropy_bonus,
            'loss': loss.item(),
        }
        store.log_table_and_tb('robust_ppo_data', row)

    return loss.item(), surrogate, entropy_bonus
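
`get_state_kl_bound_sgld` is not shown in these examples. The following is a rough, assumption-laden sketch of one way such a regularizer can be estimated, not this repository's implementation: run a few steps of noisy projected sign-gradient ascent on a state perturbation inside an L-infinity ball of radius eps, maximizing the KL divergence between the policy at the clean and perturbed states; for a Gaussian policy with a shared, state-independent stdev, that KL reduces to a scaled squared difference of the action means. The helper name `policy_mean_fn`, the step sizes, and the noise scale below are hypothetical.

import torch as ch

def state_kl_bound_sgld_sketch(policy_mean_fn, states, eps, steps, stdev, noise_scale=0.01):
    # policy_mean_fn(s) -> action means for states s; stdev is the shared action std.
    clean_means = policy_mean_fn(states).detach()
    var = stdev.detach().pow(2)
    # Start from a random perturbation inside the L-infinity ball of radius eps.
    delta = ch.zeros_like(states).uniform_(-eps, eps)
    for _ in range(steps):
        delta.requires_grad_(True)
        perturbed_means = policy_mean_fn(states + delta)
        # KL between Gaussians that share a diagonal covariance reduces to a
        # scaled squared difference of the means.
        kl = ((perturbed_means - clean_means).pow(2) / (2.0 * var)).sum(dim=-1).mean()
        grad, = ch.autograd.grad(kl, delta)
        with ch.no_grad():
            # Sign-gradient ascent step plus Gaussian noise, then project back into the ball.
            delta = delta + (eps / steps) * grad.sign() + noise_scale * ch.randn_like(delta)
            delta = ch.clamp(delta, -eps, eps)
    # Evaluate the KL at the found perturbation, keeping gradients w.r.t. the policy parameters.
    perturbed_means = policy_mean_fn(states + delta.detach())
    return ((perturbed_means - clean_means).pow(2) / (2.0 * var)).sum(dim=-1)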
Example #6
def robust_ppo_step(all_states, actions, old_log_ps, rewards, returns,
                    not_dones, advs, net, params, store, opt_step, relaxed_net,
                    eps_scheduler, beta_scheduler):
    '''
    Proximal Policy Optimization with robustness regularizer
    Runs K epochs of PPO as in https://arxiv.org/abs/1707.06347
    Inputs:
    - all_states, the historical value of all the states
    - actions, the actions that the policy sampled
    - old_log_ps, the log probability of the actions that the policy sampled
    - advs, advantages as estimated by GAE
    - net, policy network to train [WILL BE MUTATED]
    - params, additional placeholder for parameters like EPS
    Returns:
    - The PPO loss; main job is to mutate the net
    '''
    # Storing batches of stuff
    # if store is not None:
    #     orig_dists = net(all_states)

    ### ACTUAL PPO OPTIMIZATION START
    if params.SHARE_WEIGHTS:
        orig_vs = net.get_value(all_states).squeeze(-1).view(
            [params.NUM_ACTORS, -1])
        old_vs = orig_vs.detach()

    # We treat all PPO epochs of this update as a single scheduler epoch.
    eps_scheduler.set_epoch_length(params.PPO_EPOCHS * params.NUM_MINIBATCHES)
    beta_scheduler.set_epoch_length(params.PPO_EPOCHS * params.NUM_MINIBATCHES)
    # The schedulers count epochs starting from 1.
    eps_scheduler.step_epoch()
    beta_scheduler.step_epoch()

    for _ in range(params.PPO_EPOCHS):
        # State is in shape (experience_size, observation_size). Usually 2048.
        state_indices = np.arange(all_states.shape[0])
        np.random.shuffle(state_indices)
        # We use a minibatch of states to do optimization, and each epoch contains several iterations.
        splits = np.array_split(state_indices, params.NUM_MINIBATCHES)
        # A typical mini-batch size is 2048/32=64
        for selected in splits:

            def sel(*args):
                return [v[selected] for v in args]

            # old_log_ps: log probabilities of the actions sampled, taken from the experience buffer.
            # advs: advantages of these states.
            # Both old_log_ps and advs have shape (experience_size,) = 2048.
            tup = sel(all_states, actions, old_log_ps, advs)
            # select log probabilities, advantages of this minibatch.
            batch_states, batch_actions, batch_old_log_ps, batch_advs = tup

            # Forward pass with the current parameters (which are updated continually) to get the action distributions for these states.
            # dist contains the mean and standard deviation of the Gaussian.
            dist = net(batch_states)
            # Compute the log-likelihood of the taken actions under this distribution.
            new_log_ps = net.get_loglikelihood(dist, batch_actions)

            shape_equal_cmp(new_log_ps, batch_old_log_ps)

            # Calculate rewards
            # The surrogate reward is basically exp(new_log_ps - old_log_ps) * advantage.
            # dimension is the same as minibatch size.
            unclp_rew = surrogate_reward(batch_advs,
                                         new=new_log_ps,
                                         old=batch_old_log_ps)
            clp_rew = surrogate_reward(batch_advs,
                                       new=new_log_ps,
                                       old=batch_old_log_ps,
                                       clip_eps=params.CLIP_EPS)

            # Calculate entropy bonus
            entropy_bonus = net.entropies(dist).mean()

            # Calculate regularizer under state perturbation.
            eps_scheduler.step_batch()
            beta_scheduler.step_batch()
            batch_action_means = dist[0]
            current_eps = eps_scheduler.get_eps()
            stdev = ch.exp(net.log_stdev)
            if params.ROBUST_PPO_DETACH_STDEV:
                # Detach stdev so that the regularizer's gradient cannot inflate it.
                stdev = stdev.detach()
            if params.ROBUST_PPO_METHOD == "convex-relax":
                kl_upper_bound = get_state_kl_bound(
                    relaxed_net,
                    batch_states,
                    batch_action_means,
                    eps=current_eps,
                    beta=beta_scheduler.get_eps(),
                    stdev=stdev).mean()
            elif params.ROBUST_PPO_METHOD == "sgld":
                kl_upper_bound = get_state_kl_bound_sgld(
                    net,
                    batch_states,
                    batch_action_means,
                    eps=current_eps,
                    steps=params.ROBUST_PPO_PGD_STEPS,
                    stdev=stdev).mean()
            else:
                raise ValueError(
                    f"Unsupported robust PPO method {params.ROBUST_PPO_METHOD}"
                )

            # The total loss is the minimum of the clipped and unclipped rewards for each state, negated and averaged.
            surrogate = -ch.min(unclp_rew, clp_rew).mean()
            entropy = -params.ENTROPY_COEFF * entropy_bonus
            loss = surrogate + entropy + params.ROBUST_PPO_REG * kl_upper_bound

            # If we are sharing weights, take the value step simultaneously
            # (since the policy and value networks depend on the same weights)
            if params.SHARE_WEIGHTS:
                tup = sel(returns, not_dones, old_vs)
                batch_returns, batch_not_dones, batch_old_vs = tup
                val_loss = value_step(batch_states,
                                      batch_returns,
                                      batch_advs,
                                      batch_not_dones,
                                      net.get_value,
                                      None,
                                      params,
                                      store,
                                      old_vs=batch_old_vs,
                                      opt_step=opt_step)
                loss += params.VALUE_MULTIPLIER * val_loss

            # Optimizer step (Adam or SGD)
            if params.POLICY_ADAM is None:
                grad = ch.autograd.grad(loss, net.parameters())
                flat_grad = flatten(grad)
                if params.CLIP_GRAD_NORM != -1:
                    norm_grad = ch.norm(flat_grad)
                    flat_grad = flat_grad if norm_grad <= params.CLIP_GRAD_NORM else \
                                flat_grad / norm_grad * params.CLIP_GRAD_NORM

                assign(
                    flatten(net.parameters()) - params.PPO_LR * flat_grad,
                    net.parameters())
            else:
                params.POLICY_ADAM.zero_grad()
                loss.backward()
                if params.CLIP_GRAD_NORM != -1:
                    ch.nn.utils.clip_grad_norm(net.parameters(),
                                               params.CLIP_GRAD_NORM)
                params.POLICY_ADAM.step()
        # Logging.
        kl_upper_bound = kl_upper_bound.item()
        surrogate = surrogate.item()
        entropy_bonus = entropy_bonus.item()
        print(
            f'eps={eps_scheduler.get_eps():8.6f}, beta={beta_scheduler.get_eps():8.6f}, kl={kl_upper_bound:8.5f}, '
            f'surrogate={surrogate:8.5f}, entropy={entropy_bonus:8.5f}, loss={loss.item():8.5f}'
        )
    std = ch.exp(net.log_stdev)
    print(
        f'std_min={std.min().item():8.5f}, std_max={std.max().item():8.5f}, std_mean={std.mean().item():8.5f}'
    )

    if store is not None:
        row = {
            'eps': eps_scheduler.get_eps(),
            'beta': beta_scheduler.get_eps(),
            'kl': kl_upper_bound,
            'surrogate': surrogate,
            'entropy': entropy_bonus,
            'loss': loss.item(),
        }
        store.log_table_and_tb('robust_ppo_data', row)

    return loss
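
`eps_scheduler` and `beta_scheduler` are exercised only through `set_epoch_length`, `step_epoch`, `step_batch`, and `get_eps`. A minimal linear-ramp scheduler that is compatible with those calls is sketched below purely as an assumption about the interface; the schedulers used here may implement warm-up periods or other shapes.

class LinearEpsScheduler:
    """Ramps a value from 0 to target_eps over ramp_epochs epochs of epoch_length batches."""

    def __init__(self, target_eps, ramp_epochs):
        self.target_eps = target_eps
        self.ramp_epochs = ramp_epochs
        self.epoch_length = 1
        self.epoch = 0   # step_epoch() is called before the first batch, so epochs count from 1
        self.batch = 0

    def set_epoch_length(self, epoch_length):
        self.epoch_length = epoch_length

    def step_epoch(self):
        self.epoch += 1
        self.batch = 0

    def step_batch(self):
        self.batch += 1

    def get_eps(self):
        # Fraction of the ramp completed so far, measured in (possibly fractional) epochs.
        progress = (self.epoch - 1 + self.batch / self.epoch_length) / self.ramp_epochs
        return self.target_eps * min(max(progress, 0.0), 1.0)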
Example #7
def ppo_step(all_states, actions, old_log_ps, rewards, returns, not_dones,
             advs, net, params, store, opt_step):
    '''
    Proximal Policy Optimization
    Runs K epochs of PPO as in https://arxiv.org/abs/1707.06347
    Inputs:
    - all_states, the historical value of all the states
    - actions, the actions that the policy sampled
    - old_log_ps, the log probability of the actions that the policy sampled
    - advs, advantages as estimated by GAE
    - net, policy network to train [WILL BE MUTATED]
    - params, additional placeholder for parameters like EPS
    Returns:
    - The PPO loss; main job is to mutate the net
    '''
    # Storing batches of stuff
    # if store is not None:
    #     orig_dists = net(all_states)

    ### ACTUAL PPO OPTIMIZATION START
    if params.SHARE_WEIGHTS:
        orig_vs = net.get_value(all_states).squeeze(-1).view(
            [params.NUM_ACTORS, -1])
        old_vs = orig_vs.detach()

    for _ in range(params.PPO_EPOCHS):
        # State is in shape (experience_size, observation_size). Usually 2048.
        state_indices = np.arange(all_states.shape[0])
        np.random.shuffle(state_indices)
        # We use a minibatch of states to do optimization, and each epoch contains several iterations.
        splits = np.array_split(state_indices, params.NUM_MINIBATCHES)
        # A typical mini-batch size is 2048/32=64
        for selected in splits:

            def sel(*args):
                return [v[selected] for v in args]

            # old_log_ps: log probabilities of the actions sampled, taken from the experience buffer.
            # advs: advantages of these states.
            # Both old_log_ps and advs have shape (experience_size,) = 2048.
            tup = sel(all_states, actions, old_log_ps, advs)
            # select log probabilities, advantages of this minibatch.
            batch_states, batch_actions, batch_old_log_ps, batch_advs = tup

            # Forward pass with the current parameters (which are updated continually) to get the action distributions for these states.
            # dist contains the mean and standard deviation of the Gaussian.
            dist = net(batch_states)
            # Compute the log-likelihood of the taken actions under this distribution.
            new_log_ps = net.get_loglikelihood(dist, batch_actions)

            shape_equal_cmp(new_log_ps, batch_old_log_ps)

            # Calculate rewards
            # The surrogate reward is basically exp(new_log_ps - old_log_ps) * advantage.
            # dimension is the same as minibatch size.
            unclp_rew = surrogate_reward(batch_advs,
                                         new=new_log_ps,
                                         old=batch_old_log_ps)
            clp_rew = surrogate_reward(batch_advs,
                                       new=new_log_ps,
                                       old=batch_old_log_ps,
                                       clip_eps=params.CLIP_EPS)

            # Calculate entropy bonus
            entropy_bonus = net.entropies(dist).mean()

            # The total loss is the minimum of the clipped and unclipped rewards for each state, negated and averaged.
            surrogate = -ch.min(unclp_rew, clp_rew).mean()
            entropy = -params.ENTROPY_COEFF * entropy_bonus
            loss = surrogate + entropy

            # If we are sharing weights, take the value step simultaneously
            # (since the policy and value networks depend on the same weights)
            if params.SHARE_WEIGHTS:
                tup = sel(returns, not_dones, old_vs)
                batch_returns, batch_not_dones, batch_old_vs = tup
                val_loss = value_step(batch_states,
                                      batch_returns,
                                      batch_advs,
                                      batch_not_dones,
                                      net.get_value,
                                      None,
                                      params,
                                      store,
                                      old_vs=batch_old_vs,
                                      opt_step=opt_step)
                loss += params.VALUE_MULTIPLIER * val_loss

            # Optimizer step (Adam or SGD)
            if params.POLICY_ADAM is None:
                grad = ch.autograd.grad(loss, net.parameters())
                flat_grad = flatten(grad)
                if params.CLIP_GRAD_NORM != -1:
                    norm_grad = ch.norm(flat_grad)
                    flat_grad = flat_grad if norm_grad <= params.CLIP_GRAD_NORM else \
                                flat_grad / norm_grad * params.CLIP_GRAD_NORM

                assign(
                    flatten(net.parameters()) - params.PPO_LR * flat_grad,
                    net.parameters())
            else:
                params.POLICY_ADAM.zero_grad()
                loss.backward()
                if params.CLIP_GRAD_NORM != -1:
                    ch.nn.utils.clip_grad_norm(net.parameters(),
                                               params.CLIP_GRAD_NORM)
                params.POLICY_ADAM.step()
        print(
            f'surrogate={surrogate.item():8.5f}, entropy={entropy_bonus.item():8.5f}, loss={loss.item():8.5f}'
        )

    std = ch.exp(net.log_stdev)
    print(
        f'std_min={std.min().item():8.5f}, std_max={std.max().item():8.5f}, std_mean={std.mean().item():8.5f}'
    )

    return loss
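
`net.entropies(dist)` above supplies the entropy bonus. For a diagonal Gaussian policy, the entropy depends only on the standard deviation (which is why the LSTM variant in Example #5 notes that it needs no time mask). A minimal sketch under that Gaussian assumption; the actual method on `net` may differ in shape conventions:

import math
import torch as ch

def gaussian_entropies_sketch(dist):
    # dist is a (mean, std) pair; the entropy of N(mean, std^2) is
    # 0.5 * log(2 * pi * e * std^2) per dimension, independent of the mean.
    _, std = dist
    return (0.5 * ch.log(2.0 * math.pi * math.e * std.pow(2))).sum(dim=-1)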
Example #8
def trpo_step(all_states, actions, old_log_ps, rewards, returns, not_dones,
              advs, net, params, store, opt_step):
    '''
    Trust Region Policy Optimization
    Runs a single TRPO update as in https://arxiv.org/abs/1502.05477
    Inputs:
    - all_states, the historical value of all the states
    - actions, the actions that the policy sampled
    - old_log_ps, the log probability of the actions that the policy sampled
    - advs, advantages as estimated by GAE
    - net, policy network to train [WILL BE MUTATED]
    - params, additional placeholder for parameters like EPS
    Returns:
    - The TRPO loss; main job is to mutate the net
    '''
    # Initial setup
    initial_parameters = flatten(net.parameters()).clone()
    pds = net(all_states)
    action_log_probs = net.get_loglikelihood(pds, actions)

    # Calculate losses
    surr_rew = surrogate_reward(advs, new=action_log_probs,
                                old=old_log_ps).mean()
    grad = ch.autograd.grad(surr_rew, net.parameters(), retain_graph=True)
    flat_grad = flatten(grad)

    if params.USE_CONJ:
        # Build the Fisher-vector product estimator, using only a fraction of the examples.
        num_samples = int(all_states.shape[0] * params.FISHER_FRAC_SAMPLES)
        selected = np.random.choice(range(all_states.shape[0]),
                                    num_samples,
                                    replace=False)

        detached_selected_pds = select_prob_dists(pds, selected, detach=True)
        selected_pds = select_prob_dists(pds, selected, detach=False)

        kl = net.calc_kl(detached_selected_pds, selected_pds).mean()
        g = flatten(ch.autograd.grad(kl, net.parameters(), create_graph=True))

        def fisher_product(x, damp_coef=1.):
            contig_flat = lambda q: ch.cat(
                [y.contiguous().view(-1) for y in q])
            z = g @ x
            hv = ch.autograd.grad(z, net.parameters(), retain_graph=True)
            return contig_flat(hv).detach() + x * params.DAMPING * damp_coef

        # Find KL constrained gradient step
        step = cg_solve(fisher_product, flat_grad, params.CG_STEPS)

        max_step_coeff = (2 * params.MAX_KL /
                          (step @ fisher_product(step)))**(0.5)
        max_trpo_step = max_step_coeff * step

        if store and params.SHOULD_LOG_KL:
            kl_approximation_logging(all_states, pds, flat_grad, step, net,
                                     store)
            kl_vs_second_order_approx(all_states, pds, net, max_trpo_step,
                                      params, store, opt_step)
    else:
        max_trpo_step = flat_grad.clone() * params.PPO_LR_ADAM

    # Backtracking line search
    with ch.no_grad():
        # Backtracking function
        def backtrack_fn(s):
            assign(initial_parameters + s.data, net.parameters())
            test_pds = net(all_states)
            test_action_log_probs = net.get_loglikelihood(test_pds, actions)
            new_reward = surrogate_reward(advs,
                                          new=test_action_log_probs,
                                          old=old_log_ps).mean()
            if params.USE_CONS == 'all':
                if new_reward <= surr_rew or net.calc_kl(
                        pds, test_pds).mean() > params.MAX_KL:
                    return -float('inf')
            elif params.USE_CONS == 'kl':
                if net.calc_kl(pds, test_pds).mean() > params.MAX_KL:
                    return -float('inf')
            elif params.USE_CONS == 'rew':
                if new_reward <= surr_rew:
                    return -float('inf')
            elif params.USE_CONS == 'none':
                pass
            else:
                raise NotImplementedError("No such constraints")
            return new_reward - surr_rew

        expected_improve = flat_grad @ max_trpo_step
        final_step = backtracking_line_search(backtrack_fn,
                                              max_trpo_step,
                                              expected_improve,
                                              num_tries=params.MAX_BACKTRACK)

        assign(initial_parameters + final_step, net.parameters())

    return surr_rew
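
`backtracking_line_search` is described only by the comments in Example #4: it tries progressively smaller multiples of the proposed step and accepts the first one whose measured improvement (`backtrack_fn` returns -inf when a constraint is violated) is an acceptable fraction of the improvement predicted by the gradient. A sketch under those assumptions, assuming `ch` aliases PyTorch; the starting coefficient and the acceptance ratio below are guesses:

import torch as ch

def backtracking_line_search_sketch(backtrack_fn, max_step, expected_improve,
                                    num_tries, accept_ratio=0.1):
    for i in range(num_tries):
        scale = 0.5 ** i  # halve the candidate step on every retry
        improve = backtrack_fn(scale * max_step)
        # Reject constraint-violating steps (improve == -inf) and steps that
        # realize too little of the linearly predicted improvement.
        if improve != -float('inf') and improve >= accept_ratio * scale * expected_improve:
            return scale * max_step
    # No acceptable step found: return a zero step so the caller restores the initial parameters.
    return ch.zeros_like(max_step)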
Example #9
def ppo_step(all_states, actions, old_log_ps, rewards, returns, not_dones,
             advs, net, params, store, opt_step):
    '''
    Proximal Policy Optimization
    Runs K epochs of PPO as in https://arxiv.org/abs/1707.06347
    Inputs:
    - all_states, the historical value of all the states
    - actions, the actions that the policy sampled
    - old_log_ps, the log probability of the actions that the policy sampled
    - advs, advantages as estimated by GAE
    - net, policy network to train [WILL BE MUTATED]
    - params, additional placeholder for parameters like EPS
    Returns:
    - The PPO loss; main job is to mutate the net
    '''
    # Storing batches of stuff
    if store is not None:
        orig_dists = net(all_states)

    ### ACTUAL PPO OPTIMIZATION START
    if params.SHARE_WEIGHTS:
        orig_vs = net.get_value(all_states).squeeze(-1).view(
            [params.NUM_ACTORS, -1])
        old_vs = orig_vs.detach()

    for _ in range(params.PPO_EPOCHS):
        state_indices = np.arange(all_states.shape[0])
        np.random.shuffle(state_indices)
        splits = np.array_split(state_indices, params.NUM_MINIBATCHES)
        for selected in splits:

            def sel(*args):
                return [v[selected] for v in args]

            tup = sel(all_states, actions, old_log_ps, advs)
            batch_states, batch_actions, batch_old_log_ps, batch_advs = tup

            dist = net(batch_states)
            new_log_ps = net.get_loglikelihood(dist, batch_actions)

            shape_equal_cmp(new_log_ps, batch_old_log_ps)

            # Calculate rewards
            unclp_rew = surrogate_reward(batch_advs,
                                         new=new_log_ps,
                                         old=batch_old_log_ps)
            clp_rew = surrogate_reward(batch_advs,
                                       new=new_log_ps,
                                       old=batch_old_log_ps,
                                       clip_eps=params.CLIP_EPS)

            # Calculate entropy bonus
            entropy_bonus = net.entropies(dist).mean()

            # Total loss
            surrogate = -ch.min(unclp_rew, clp_rew).mean()
            entropy = -params.ENTROPY_COEFF * entropy_bonus
            loss = surrogate + entropy

            # If we are sharing weights, take the value step simultaneously
            # (since the policy and value networks depend on the same weights)
            if params.SHARE_WEIGHTS:
                tup = sel(returns, not_dones, old_vs)
                batch_returns, batch_not_dones, batch_old_vs = tup
                val_loss = value_step(batch_states,
                                      batch_returns,
                                      batch_advs,
                                      batch_not_dones,
                                      net.get_value,
                                      None,
                                      params,
                                      store,
                                      old_vs=batch_old_vs,
                                      opt_step=opt_step)
                loss += params.VALUE_MULTIPLIER * val_loss

            # Optimizer step (Adam or SGD)
            if params.POLICY_ADAM is None:
                grad = ch.autograd.grad(loss, net.parameters())
                flat_grad = flatten(grad)
                if params.CLIP_GRAD_NORM != -1:
                    norm_grad = ch.norm(flat_grad)
                    flat_grad = flat_grad if norm_grad <= params.CLIP_GRAD_NORM else \
                                flat_grad / norm_grad * params.CLIP_GRAD_NORM

                assign(
                    flatten(net.parameters()) - params.PPO_LR * flat_grad,
                    net.parameters())
            else:
                params.POLICY_ADAM.zero_grad()
                loss.backward()
                if params.CLIP_GRAD_NORM != -1:
                    ch.nn.utils.clip_grad_norm(net.parameters(),
                                               params.CLIP_GRAD_NORM)
                params.POLICY_ADAM.step()

    return loss