Ejemplo n.º 1
0
def learn(
        env,
        model_path,
        data_path,
        policy_fn,
        *,
        horizon=150,  # timesteps per actor per update
        rolloutSize=50,
        clip_param=0.2,
        entcoeff=0.02,  # clipping parameter epsilon, entropy coeff
        optim_epochs=10,
        optim_stepsize=3e-4,
        optim_batchsize=32,  # optimization hypers
        gamma=0.99,
        lam=0.95,  # advantage estimation
        max_iters=0,  # time constraint; 0 means no iteration limit
        adam_epsilon=1e-4,
        schedule='constant',  # 'constant' or 'linear' annealing of the Adam step size
        retrain=False,
        render_after_iters=70):
    """Train a policy with PPO (clipped surrogate objective) on ``env``.

    Builds TF graphs for the current policy ``pi`` and a frozen copy
    ``oldpi``. Each iteration then:

    1. collects rollouts with ``sample_trajectory``;
    2. pickles every rollout collected so far to
       ``data_path + 'rollout_data.pkl'``;
    3. runs ``optim_epochs`` epochs of minibatch ``MpiAdam`` updates on
       surrogate + entropy penalty + value loss;
    4. logs episode statistics gathered over ``MPI.COMM_WORLD``.

    Args:
        env: environment exposing ``observation_space`` and ``action_space``.
        model_path: state passed to ``U.load_state`` when ``retrain`` is true.
        data_path: prefix that is string-concatenated (not path-joined) with
            ``'rollout_data.pkl'`` to form the rollout pickle path.
        policy_fn: called as ``policy_fn(name, ob_space, ac_space)``.
        horizon, rolloutSize: forwarded to ``sample_trajectory``. Their
            product times ``max_iters`` is the timestep budget used by the
            'linear' schedule.
        clip_param: PPO clipping epsilon.
        entcoeff: entropy bonus coefficient.
        optim_epochs, optim_stepsize, optim_batchsize: optimizer settings.
            A falsy ``optim_batchsize`` means full-batch updates.
        gamma, lam: discount and GAE lambda for ``add_vtarg_and_adv``.
        max_iters: stop after this many iterations; 0 runs forever.
        adam_epsilon: epsilon for ``MpiAdam``.
        schedule: 'constant' or 'linear' multiplier on ``optim_stepsize``.
        retrain: if true, load weights from ``model_path`` before training.
        render_after_iters: rollouts are rendered once more than this many
            iterations have completed.

    Returns:
        The trained policy ``pi``.

    Raises:
        NotImplementedError: ``schedule`` is not 'constant' or 'linear'.
        ValueError: ``schedule='linear'`` with ``max_iters <= 0``. There is
            then no budget to anneal over; this used to surface as a
            ZeroDivisionError after graph setup.
    """
    # Validate before building any graph.
    if schedule not in ('constant', 'linear'):
        raise NotImplementedError("Unknown schedule: %r" % (schedule,))
    if schedule == 'linear' and max_iters <= 0:
        raise ValueError("schedule='linear' requires max_iters > 0")
    max_timesteps = int(horizon * rolloutSize * max_iters)

    # Setup losses and policy
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_fn("pi", ob_space, ac_space)  # Construct network for new policy
    oldpi = policy_fn("oldpi", ob_space, ac_space)  # Network for old policy
    # Target advantage function (if applicable)
    atarg = tf.placeholder(dtype=tf.float32, shape=[None])
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return
    # Learning-rate multiplier, fed each update from the schedule.
    lrmult = tf.placeholder(name='lrmult', dtype=tf.float32, shape=[])

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    pol_entpen = (-entcoeff) * meanent

    # NOTE(review): lrmult is fed to lossandgrad but no graph op uses it.
    # Only the Adam step size is annealed; clip_param is never scaled by
    # lrmult. Confirm whether the clip range was also meant to be annealed.
    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = tf.clip_by_value(ratio, 1.0 - clip_param,
                             1.0 + clip_param) * atarg
    # PPO's pessimistic surrogate (L^CLIP)
    pol_surr = -tf.reduce_mean(tf.minimum(surr1, surr2))
    vf_loss = tf.reduce_mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    loss_terms = [pol_surr, pol_entpen, vf_loss, meankl, meanent]

    var_list = pi.get_trainable_variables()
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult],
                             loss_terms + [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    # Copies the current policy's variables into oldpi.
    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv, newv) in zipsame(oldpi.get_variables(),
                                        pi.get_variables())
        ])

    U.initialize()
    adam.sync()

    # Prepare for rollouts
    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=5)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=5)  # rolling buffer for episode rewards

    saved_rollouts = []  # every rollout so far, re-pickled each iteration

    if retrain:
        print("Retraining the policy from saved path")
        time.sleep(2)
        U.load_state(model_path)

    while True:
        if max_iters and iters_so_far >= max_iters:
            break
        if schedule == 'constant':
            cur_lrmult = 1.0
        else:  # 'linear'; validated above, so max_timesteps > 0
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)

        logger.log("********** Iteration %i ************" % iters_so_far)
        print("Collecting samples for policy optimization !! ")
        rollouts = sample_trajectory(pi,
                                     env,
                                     horizon=horizon,
                                     rolloutSize=rolloutSize,
                                     stochastic=True,
                                     render=iters_so_far > render_after_iters)

        # Save rollouts. The whole history is rewritten each iteration; the
        # handle is closed deterministically instead of being leaked.
        saved_rollouts.append({'rollouts': rollouts})
        with open(data_path + 'rollout_data.pkl', "wb") as data_file:
            pickle.dump(saved_rollouts, data_file)

        add_vtarg_and_adv(rollouts, gamma, lam)

        # NumPy batch data, named distinctly from the TF placeholders above.
        ob_np, ac_np = rollouts["ob"], rollouts["ac"]
        adv_np, tdlamret = rollouts["adv"], rollouts["tdlamret"]
        # standardized advantage function estimate
        adv_np = (adv_np - adv_np.mean()) / adv_np.std()
        d = Dataset(dict(ob=ob_np, ac=ac_np, atarg=adv_np, vtarg=tdlamret),
                    deterministic=pi.recurrent)
        # Fall back to full-batch updates without overwriting the argument,
        # so each iteration uses its own rollout size.
        batch_size = optim_batchsize or ob_np.shape[0]

        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob_np)  # update running mean/std for policy

        assign_old_eq_new()  # set old parameter values to new parameter values
        logger.log("Optimizing...")
        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            for batch in d.iterate_once(batch_size):
                *_unused_losses, grad = lossandgrad(batch["ob"], batch["ac"],
                                                    batch["atarg"],
                                                    batch["vtarg"],
                                                    cur_lrmult)
                adam.update(grad, optim_stepsize * cur_lrmult)

        lrlocal = (rollouts["ep_lens"], rollouts["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        logger.record_tabular("Success", rollouts["success"])
        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.dump_tabular()

    return pi
Ejemplo n.º 2
0
def _mpi_average_gradients(agent, comm):
    """Replace each parameter gradient of ``agent`` with its mean over ``comm``.

    Parameters whose ``.grad`` is ``None`` (not reached by the loss) are
    skipped. This presumes every rank runs the same agent/loss code, so all
    ranks skip the same parameters and the collective calls stay matched.
    Verify if ranks could ever build different graphs.
    """
    world_size = comm.Get_size()
    with tc.no_grad():
        for param in agent.parameters():
            if param.grad is None:
                continue
            local_grad = param.grad.numpy()
            summed_grad = np.zeros_like(local_grad)
            comm.Allreduce(sendbuf=local_grad, recvbuf=summed_grad, op=MPI.SUM)
            param.grad.copy_(tc.tensor(summed_grad).float() / world_size)


def _mpi_broadcast_parameters(agent, comm, root=0):
    """Overwrite ``agent``'s parameters with the values held by rank ``root``."""
    with tc.no_grad():
        for param in agent.parameters():
            param_data = param.data.numpy()
            comm.Bcast(param_data, root=root)
            param.data.copy_(tc.tensor(param_data).float())


def learn(
        env,
        agent,
        optimizer,
        scheduler,
        comm,
        timesteps_per_actorbatch,  # timesteps per actor per update
        clip_param,
        entcoeff,  # clipping parameter epsilon, entropy coeff
        optim_epochs,
        optim_batchsize,  # optimization hypers
        gamma,
        lam,  # advantage estimation
        checkpoint_dir,
        model_name,
        max_timesteps=0,
        max_episodes=0,
        max_iters=0,
        max_seconds=0,
        schedule='linear'):
    """Train ``agent`` with PPO, averaging gradients across ``comm``.

    Each iteration:

    1. pulls a segment from ``traj_segment_generator``;
    2. computes GAE targets with ``add_vtarg_and_adv``;
    3. runs ``optim_epochs`` epochs of minibatch updates. Each minibatch gets
       ``optimizer.step()`` and ``scheduler.step()`` after MPI gradient
       averaging.
    4. logs losses and episode statistics gathered over ``comm``.

    Rank 0 writes a checkpoint to
    ``checkpoint_dir/model_name/model.pth`` every 10 iterations.

    Args:
        env: environment passed to ``traj_segment_generator``.
        agent: torch module trained in place.
        optimizer, scheduler: torch optimizer and LR scheduler. The scheduler
            is stepped once per gradient step.
        comm: MPI communicator used for all collective operations.
        timesteps_per_actorbatch: segment length per actor per iteration.
        clip_param: base PPO clipping epsilon, annealed by ``schedule``.
        entcoeff: entropy bonus coefficient.
        optim_epochs, optim_batchsize: optimization epochs and minibatch size.
        gamma, lam: discount and GAE lambda.
        checkpoint_dir, model_name: checkpoint location.
        max_timesteps, max_episodes, max_iters, max_seconds: stopping
            criteria. Exactly one must be positive.
        schedule: 'constant' or 'linear' multiplier on ``clip_param``.
            'linear' decays over ``max_timesteps``.

    Raises:
        ValueError: not exactly one stopping criterion is positive, or
            ``schedule='linear'`` without ``max_timesteps > 0``.
        NotImplementedError: unknown ``schedule``.
    """
    # Validate up front. These were an ``assert`` (stripped under -O) and an
    # eagerly-evaluated dict whose 'linear' entry divided by max_timesteps
    # even when another stopping criterion (max_timesteps == 0) was used.
    n_constraints = sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0, max_seconds > 0])
    if n_constraints != 1:
        raise ValueError(
            "Exactly one time constraint (max_iters, max_timesteps, "
            "max_episodes, max_seconds) must be positive")
    if schedule not in ('constant', 'linear'):
        raise NotImplementedError("Unknown schedule: %r" % (schedule,))
    if schedule == 'linear' and max_timesteps <= 0:
        raise ValueError("schedule='linear' requires max_timesteps > 0")

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(agent, env, timesteps_per_actorbatch)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    gradient_steps_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "ent"]

    while True:
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        logger.log("********** Iteration %i ************" % iters_so_far)

        # Anneal the clipping epsilon; computed lazily so 'constant' never
        # touches max_timesteps.
        if schedule == 'constant':
            clip_mult = 1.0
        else:  # 'linear'; validated above, so max_timesteps > 0
            clip_mult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        current_clip_param = clip_mult * clip_param

        seg = next(seg_gen)
        add_vtarg_and_adv(seg, gamma, lam)

        ob, ac, logprobs, adv, tdlamret = seg["ob"], seg["ac"], seg[
            "logprobs"], seg["adv"], seg["tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update
        # standardized advantage function estimate
        adv = (adv - adv.mean()) / adv.std()
        d = Dataset(dict(ob=ob,
                         ac=ac,
                         logprobs=logprobs,
                         adv=adv,
                         vtarg=tdlamret),
                    deterministic=False)  # nonrecurrent

        logger.log("Optimizing...")
        logger.log(fmt_row(13, loss_names))
        # Here we do a bunch of optimization epochs over the data
        agent.train()
        for _ in range(optim_epochs):
            losses = []  # one tuple of loss values per minibatch
            for batch in d.iterate_once(optim_batchsize):
                pol_surr, pol_entpen, vf_loss, ent = compute_losses(
                    batch, agent, entcoeff, current_clip_param)
                total_loss = pol_surr + pol_entpen + vf_loss

                optimizer.zero_grad()
                total_loss.backward()
                _mpi_average_gradients(agent, comm)

                optimizer.step()
                scheduler.step()
                gradient_steps_so_far += 1

                # Sync agent parameters from rank zero. They should stay
                # synced automatically; this is just a failsafe.
                if gradient_steps_so_far % 100 == 0:
                    _mpi_broadcast_parameters(agent, comm, root=0)

                losses.append((pol_surr.detach().numpy(),
                               pol_entpen.detach().numpy(),
                               vf_loss.detach().numpy(),
                               ent.detach().numpy()))
            logger.log(fmt_row(13, np.mean(losses, axis=0)))

        logger.log("Evaluating losses...")
        losses = []
        # Evaluation only: no autograd graph is needed for these losses.
        with tc.no_grad():
            for batch in d.iterate_once(optim_batchsize):
                newlosses = compute_losses(batch, agent, entcoeff,
                                           current_clip_param)
                losses.append(
                    tuple(loss.detach().numpy() for loss in newlosses))
        meanlosses, _, _ = mpi_moments(losses, axis=0)
        logger.log(fmt_row(13, meanlosses))
        for (lossval, name) in zipsame(meanlosses, loss_names):
            logger.record_tabular("loss_" + name, lossval)
        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))
        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        # Gather over the same communicator used for training (previously
        # MPI.COMM_WORLD, which differs when comm is a sub-communicator).
        listoflrpairs = comm.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        if comm.Get_rank() == 0:
            logger.dump_tabular()
            if iters_so_far % 10 == 0:
                print("Saving checkpoint...")
                os.makedirs(os.path.join(checkpoint_dir, model_name),
                            exist_ok=True)
                tc.save(agent.state_dict(),
                        os.path.join(checkpoint_dir, model_name, 'model.pth'))