Example #1
def learn(
        env,
        policy_func,
        disc,
        *,
        timesteps_per_batch,  # timesteps per actor per update
        clip_param,
        entcoeff,  # clipping parameter epsilon, entropy coeff
        optim_epochs,
        optim_stepsize,
        optim_batchsize,  # optimization hypers
        gamma,
        lam,  # advantage estimation
        max_timesteps=0,
        max_episodes=0,
        max_iters=0,
        max_seconds=0,  # time constraint
        callback=None,  # you can do anything in the callback, since it takes locals(), globals()
        adam_epsilon=1e-5,
        schedule='constant',  # annealing for stepsize parameters (epsilon and adam)
        logdir=".",
        agentName="PPO-Agent",
        resume=0,
        num_parallel=0,
        num_cpu=1,
        num_extra=0,
        gan_batch_size=128,
        gan_num_epochs=5,
        gan_display_step=40,
        resume_disc=0,
        resume_non_disc=0,
        mocap_path="",
        gan_replay_buffer_size=1000000,
        gan_prob_to_put_in_replay=0.01,
        gan_reward_to_retrain_discriminator=5,
        use_distance=0,
        use_blend=0):
    # Deal with GAN
    if not use_distance:
        replay_buf = MyReplayBuffer(gan_replay_buffer_size)
    data = np.loadtxt(
        mocap_path + ".dat"
    )  #"D:/p4sw/devrel/libdev/flex/dev/rbd/data/bvh/motion_simple.dat");
    label = np.concatenate((np.ones(
        (data.shape[0], 1)), np.zeros((data.shape[0], 1))),
                           axis=1)

    print("Real data label = " + str(label))

    mocap_set = Dataset(dict(data=data, label=label), shuffle=True)
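    # The mocap frames form the "real" class for the discriminator: each row
    # gets the one-hot label [1, 0]; rollout states generated by the policy
    # are labeled [0, 1] further below.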

    # Setup losses and stuff
    # ----------------------------------------
    rank = MPI.COMM_WORLD.Get_rank()
    ob_space = env.observation_space
    ac_space = env.action_space

    ob_size = ob_space.shape[0]
    ac_size = ac_space.shape[0]

    #print("rank = " + str(rank) + " ob_space = "+str(ob_space.shape) + " ac_space = "+str(ac_space.shape))
    #exit(0)
    pi = policy_func("pi", ob_space,
                     ac_space)  # Construct network for new policy
    oldpi = policy_func("oldpi", ob_space, ac_space)  # Network for old policy
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule
    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = U.mean(kloldnew)
    meanent = U.mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = U.clip(ratio, 1.0 - clip_param, 1.0 + clip_param) * atarg  #
    pol_surr = -U.mean(tf.minimum(
        surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)
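    # i.e. L^CLIP = E[ min(r_t(theta) * A_t,
    #                      clip(r_t(theta), 1 - eps, 1 + eps) * A_t) ],
    # with r_t(theta) = pi(a_t|s_t) / pi_old(a_t|s_t); the sign is flipped
    # because we minimize.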
    vfloss1 = tf.square(pi.vpred - ret)
    vpredclipped = oldpi.vpred + tf.clip_by_value(pi.vpred - oldpi.vpred,
                                                  -clip_param, clip_param)
    vfloss2 = tf.square(vpredclipped - ret)
    vf_loss = .5 * U.mean(
        tf.maximum(vfloss1, vfloss2)
    )  # we do the same clipping-based trust region for the value function
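    # The value loss clips vpred to stay within clip_param of the old value
    # prediction and takes the elementwise maximum of the clipped and
    # unclipped squared errors, mirroring the policy's clipped objective.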
    #vf_loss = U.mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult],
                             losses + [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
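    # assign_old_eq_new copies pi's variables into oldpi, so that on the next
    # update the probability ratio is measured against the pre-update policy.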
    compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)

    U.initialize()
    adam.sync()

    # Prepare for rollouts
    # ----------------------------------------
    sess = tf.get_default_session()

    avars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    non_disc_vars = [
        a for a in avars
        if not a.name.split("/")[0].startswith("discriminator")
    ]
    disc_vars = [
        a for a in avars if a.name.split("/")[0].startswith("discriminator")
    ]
    #print(str(non_disc_names))
    #print(str(disc_names))
    #exit(0)
    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards

    disc_saver = tf.train.Saver(disc_vars, max_to_keep=None)
    non_disc_saver = tf.train.Saver(non_disc_vars, max_to_keep=None)
    saver = tf.train.Saver(max_to_keep=None)
    if resume > 0:
        saver.restore(
            tf.get_default_session(),
            os.path.join(os.path.abspath(logdir),
                         "{}-{}".format(agentName, resume)))
        if not use_distance:
            if os.path.exists(logdir + "\\" + 'replay_buf_' +
                              str(int(resume / 100) * 100) + '.pkl'):
                print("Load replay buf")
                with open(
                        logdir + "\\" + 'replay_buf_' +
                        str(int(resume / 100) * 100) + '.pkl', 'rb') as f:
                    replay_buf = pickle.load(f)
            else:
                print("Can't load replay buf " + logdir + "\\" +
                      'replay_buf_' + str(int(resume / 100) * 100) + '.pkl')
    iters_so_far = resume

    if resume_non_disc > 0:
        non_disc_saver.restore(
            tf.get_default_session(),
            os.path.join(
                os.path.abspath(logdir),
                "{}-{}".format(agentName + "_non_disc", resume_non_disc)))
        iters_so_far = resume_non_disc

    if use_distance:
        print("Use distance")
        nn = NearestNeighbors(n_neighbors=1, algorithm='ball_tree').fit(data)
    else:
        nn = None
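    # When use_distance is set, the discriminator-based reward path is skipped;
    # `nn` is passed into the segment generator, presumably to score states by
    # their 1-nearest-neighbor distance to the mocap data instead of by the
    # discriminator output.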
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     disc,
                                     timesteps_per_batch,
                                     stochastic=True,
                                     num_parallel=num_parallel,
                                     num_cpu=num_cpu,
                                     rank=rank,
                                     ob_size=ob_size,
                                     ac_size=ac_size,
                                     com=MPI.COMM_WORLD,
                                     num_extra=num_extra,
                                     iters_so_far=iters_so_far,
                                     use_distance=use_distance,
                                     nn=nn)

    if resume_disc > 0:
        disc_saver.restore(
            tf.get_default_session(),
            os.path.join(os.path.abspath(logdir),
                         "{}-{}".format(agentName + "_disc", resume_disc)))

    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"
    logF = open(logdir + "\\" + 'log.txt', 'a')
    logR = open(logdir + "\\" + 'log_rew.txt', 'a')
    logStats = open(logdir + "\\" + 'log_stats.txt', 'a')
    if os.path.exists(logdir + "\\" + 'ob_list_' + str(rank) + '.pkl'):
        with open(logdir + "\\" + 'ob_list_' + str(rank) + '.pkl', 'rb') as f:
            ob_list = pickle.load(f)
    else:
        ob_list = []

    dump_training = 0
    learn_from_training = 0
    if dump_training:
        # , "mean": pi.ob_rms.mean, "std": pi.ob_rms.std
        saverRMS = tf.train.Saver({
            "_sum": pi.ob_rms._sum,
            "_sumsq": pi.ob_rms._sumsq,
            "_count": pi.ob_rms._count
        })
        saverRMS.save(tf.get_default_session(),
                      os.path.join(os.path.abspath(logdir), "rms.tf"))

        ob_np_a = np.asarray(ob_list)
        ob_np = np.reshape(ob_np_a, (-1, ob_size))
        [vpred, pdparam] = pi._vpred_pdparam(ob_np)

        print("vpred = " + str(vpred))
        print("pd_param = " + str(pdparam))
        with open('training.pkl', 'wb') as f:
            pickle.dump(ob_np, f)
            pickle.dump(vpred, f)
            pickle.dump(pdparam, f)
        exit(0)
    if learn_from_training:
        # , "mean": pi.ob_rms.mean, "std": pi.ob_rms.std

        with open('training.pkl', 'rb') as f:
            ob_np = pickle.load(f)
            vpred = pickle.load(f)
            pdparam = pickle.load(f)
        num = ob_np.shape[0]
        for i in range(num):
            xp = ob_np[i][1]
            ob_np[i][1] = 0.0
            ob_np[i][18] -= xp
            ob_np[i][22] -= xp
            ob_np[i][24] -= xp
            ob_np[i][26] -= xp
            ob_np[i][28] -= xp
            ob_np[i][30] -= xp
            ob_np[i][32] -= xp
            ob_np[i][34] -= xp
        print("ob_np = " + str(ob_np))
        print("vpred = " + str(vpred))
        print("pdparam = " + str(pdparam))
        batch_size = 128

        y_vpred = tf.placeholder(tf.float32, [
            batch_size,
        ])
        y_pdparam = tf.placeholder(tf.float32, [batch_size, pdparam.shape[1]])

        vpred_loss = U.mean(tf.square(pi.vpred - y_vpred))
        vpdparam_loss = U.mean(tf.square(pi.pdparam - y_pdparam))

        total_train_loss = vpred_loss + vpdparam_loss
        #total_train_loss = vpdparam_loss
        #total_train_loss = vpred_loss
        #coef = 0.01
        #dense_all = U.dense_all
        #for a in dense_all:
        #   total_train_loss += coef * tf.nn.l2_loss(a)
        #total_train_loss = vpdparam_loss
        optimizer = tf.train.AdamOptimizer(
            learning_rate=0.001).minimize(total_train_loss)
        d = Dataset(dict(ob=ob_np, vpred=vpred, pdparam=pdparam),
                    shuffle=not pi.recurrent)
        sess = tf.get_default_session()
        sess.run(tf.global_variables_initializer())
        saverRMS = tf.train.Saver({
            "_sum": pi.ob_rms._sum,
            "_sumsq": pi.ob_rms._sumsq,
            "_count": pi.ob_rms._count
        })
        saverRMS.restore(tf.get_default_session(),
                         os.path.join(os.path.abspath(logdir), "rms.tf"))
        if resume > 0:
            saver.restore(
                tf.get_default_session(),
                os.path.join(os.path.abspath(logdir),
                             "{}-{}".format(agentName, resume)))

        for q in range(100):
            sumLoss = 0
            for batch in d.iterate_once(batch_size):
                tl, _ = sess.run(
                    [total_train_loss, optimizer],
                    feed_dict={
                        pi.ob: batch["ob"],
                        y_vpred: batch["vpred"],
                        y_pdparam: batch["pdparam"]
                    })
                sumLoss += tl
            print("Iteration " + str(q) + " Loss = " + str(sumLoss))
        assign_old_eq_new()  # set old parameter values to new parameter values

        # Save as frame 1
        try:
            saver.save(tf.get_default_session(),
                       os.path.join(logdir, agentName),
                       global_step=1)
        except:
            pass
        #exit(0)
    if resume > 0:
        firstTime = False
    else:
        firstTime = True

    # Check accuracy
    #amocap = sess.run([disc.accuracy],
    #                feed_dict={disc.input: data,
    #                           disc.label: label})
    #print("Mocap accuracy = " + str(amocap))
    #print("Mocap label is " + str(label))

    #adata = np.array(replay_buf._storage)
    #print("adata shape = " + str(adata.shape))
    #alabel = np.concatenate((np.zeros((adata.shape[0], 1)), np.ones((adata.shape[0], 1))), axis=1)

    #areplay = sess.run([disc.accuracy],
    #                feed_dict={disc.input: adata,
    #                           disc.label: alabel})
    #print("Replay accuracy = " + str(areplay))
    #print("Replay label is " + str(alabel))
    #exit(0)
    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % iters_so_far)

        seg = seg_gen.__next__()

        add_vtarg_and_adv(seg, gamma, lam, timesteps_per_batch, num_parallel,
                          num_cpu)
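        # add_vtarg_and_adv fills seg["adv"] with GAE(lambda) advantages,
        # A_t = sum_l (gamma * lam)^l * delta_{t+l}, where
        # delta_t = r_t + gamma * V(s_{t+1}) - V(s_t), and seg["tdlamret"]
        # (= adv + vpred in the standard baselines implementation) is the
        # value-function target.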
        #print(" ob= " + str(seg["ob"])+ " rew= " + str(seg["rew"])+ " vpred= " + str(seg["vpred"])+ " new= " + str(seg["new"])+ " ac= " + str(seg["ac"])+ " prevac= " + str(seg["prevac"])+ " nextvpred= " + str(seg["nextvpred"])+ " ep_rets= " + str(seg["ep_rets"])+ " ep_lens= " + str(seg["ep_lens"]))

        #exit(0)
        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret, extra = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"], seg["extra"]

        #ob_list.append(ob.tolist())
        vpredbefore = seg["vpred"]  # predicted value function before update
        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate
        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                    shuffle=not pi.recurrent)
        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        assign_old_eq_new()  # set old parameter values to new parameter values
        logger.log("Optimizing...")
        logger.log(fmt_row(13, loss_names))
        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"],
                                            batch["atarg"], batch["vtarg"],
                                            cur_lrmult)
                adam.update(g, optim_stepsize * cur_lrmult)
                losses.append(newlosses)
            #print(str(losses))
            logger.log(fmt_row(13, np.mean(losses, axis=0)))

        logger.log("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob"], batch["ac"],
                                       batch["atarg"], batch["vtarg"],
                                       cur_lrmult)
            losses.append(newlosses)
        meanlosses, _, _ = mpi_moments(losses, axis=0)
        logger.log(fmt_row(13, meanlosses))
        for (lossval, name) in zipsame(meanlosses, loss_names):
            logger.record_tabular("loss_" + name, lossval)
        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))
        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        rewmean = np.mean(rewbuffer)
        logger.record_tabular("EpRewMean", rewmean)
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)

        # Train discriminator
        if not use_distance:
            print("Put in replay buf " +
                  str((int)(gan_prob_to_put_in_replay * extra.shape[0] + 1)))
            replay_buf.add(extra[np.random.choice(
                extra.shape[0],
                (int)(gan_prob_to_put_in_replay * extra.shape[0] + 1),
                replace=True)])
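            # Archive roughly gan_prob_to_put_in_replay of this batch's
            # generated feature vectors so later discriminator updates can
            # sample a mixture of old and new policy behavior.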
            #if iters_so_far == 1:
            if not use_blend:
                if firstTime:
                    firstTime = False
                    # Train with everything we got
                    lb = np.concatenate((np.zeros(
                        (extra.shape[0], 1)), np.ones((extra.shape[0], 1))),
                                        axis=1)
                    extra_set = Dataset(dict(data=extra, label=lb),
                                        shuffle=True)
                    for e in range(10):
                        i = 0
                        for mbatch in mocap_set.iterate_once(gan_batch_size):
                            batch = extra_set.next_batch(gan_batch_size)
                            _, l = sess.run(
                                [disc.optimizer_first, disc.loss],
                                feed_dict={
                                    disc.input:
                                    np.concatenate(
                                        (mbatch['data'], batch['data'])),
                                    disc.label:
                                    np.concatenate(
                                        (mbatch['label'], batch['label']))
                                })
                            i = i + 1
                            # Display logs per step
                            if i % gan_display_step == 0 or i == 1:
                                print(
                                    'discriminator epoch %i Step %i: Minibatch Loss: %f'
                                    % (e, i, l))
                        print(
                            'discriminator epoch %i Step %i: Minibatch Loss: %f'
                            % (e, i, l))
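                # The discriminator is only retrained when seg['mean_ext_rew']
                # (presumably the mean discriminator-based reward) exceeds the
                # threshold, i.e. when the policy has started to fool the
                # current discriminator.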
                if seg['mean_ext_rew'] > gan_reward_to_retrain_discriminator:
                    for e in range(gan_num_epochs):
                        i = 0
                        for mbatch in mocap_set.iterate_once(gan_batch_size):
                            data = replay_buf.sample(mbatch['data'].shape[0])
                            lb = np.concatenate((np.zeros(
                                (data.shape[0], 1)), np.ones(
                                    (data.shape[0], 1))),
                                                axis=1)
                            _, l = sess.run(
                                [disc.optimizer, disc.loss],
                                feed_dict={
                                    disc.input:
                                    np.concatenate((mbatch['data'], data)),
                                    disc.label:
                                    np.concatenate((mbatch['label'], lb))
                                })
                            i = i + 1
                            # Display logs per step
                            if i % gan_display_step == 0 or i == 1:
                                print(
                                    'discriminator epoch %i Step %i: Minibatch Loss: %f'
                                    % (e, i, l))
                        print(
                            'discriminator epoch %i Step %i: Minibatch Loss: %f'
                            % (e, i, l))
            else:
                if firstTime:
                    firstTime = False
                    # Train with everything we got
                    extra_set = Dataset(dict(data=extra), shuffle=True)
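                    # Blend mode: instead of hard real/fake labels, each
                    # training sample is a convex mixture of a mocap frame and
                    # a generated frame, and the label is the pair of mixing
                    # weights (a mixup-style soft-label scheme).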
                    for e in range(10):
                        i = 0
                        for mbatch in mocap_set.iterate_once(gan_batch_size):
                            batch = extra_set.next_batch(gan_batch_size)
                            bf = np.random.uniform(0, 1, (gan_batch_size, 1))
                            onembf = 1 - bf
                            my_label = np.concatenate((bf, onembf), axis=1)
                            my_data = np.multiply(mbatch['data'],
                                                  bf) + np.multiply(
                                                      batch['data'], onembf)
                            _, l = sess.run([disc.optimizer_first, disc.loss],
                                            feed_dict={
                                                disc.input: my_data,
                                                disc.label: my_label
                                            })
                            i = i + 1
                            # Display logs per step
                            if i % gan_display_step == 0 or i == 1:
                                print(
                                    'discriminator epoch %i Step %i: Minibatch Loss: %f'
                                    % (e, i, l))
                        print(
                            'discriminator epoch %i Step %i: Minibatch Loss: %f'
                            % (e, i, l))
                if seg['mean_ext_rew'] > gan_reward_to_retrain_discriminator:
                    for e in range(gan_num_epochs):
                        i = 0
                        for mbatch in mocap_set.iterate_once(gan_batch_size):
                            data = replay_buf.sample(mbatch['data'].shape[0])

                            bf = np.random.uniform(0, 1, (gan_batch_size, 1))
                            onembf = 1 - bf
                            my_label = np.concatenate((bf, onembf), axis=1)
                            my_data = np.multiply(mbatch['data'],
                                                  bf) + np.multiply(
                                                      data, onembf)

                            _, l = sess.run([disc.optimizer_first, disc.loss],
                                            feed_dict={
                                                disc.input: my_data,
                                                disc.label: my_label
                                            })
                            i = i + 1
                            # Display logs per step
                            if i % gan_display_step == 0 or i == 1:
                                print(
                                    'discriminator epoch %i Step %i: Minibatch Loss: %f'
                                    % (e, i, l))
                        print(
                            'discriminator epoch %i Step %i: Minibatch Loss: %f'
                            % (e, i, l))

        # if True:
        #     lb = np.concatenate((np.zeros((extra.shape[0],1)),np.ones((extra.shape[0],1))),axis=1)
        #     extra_set = Dataset(dict(data=extra,label=lb), shuffle=True)
        #     num_r = 1
        #     if iters_so_far == 1:
        #         num_r = gan_num_epochs
        #     for e in range(num_r):
        #         i = 0
        #         for batch in extra_set.iterate_once(gan_batch_size):
        #             mbatch = mocap_set.next_batch(gan_batch_size)
        #             _, l = sess.run([disc.optimizer, disc.loss], feed_dict={disc.input: np.concatenate((mbatch['data'],batch['data'])), disc.label: np.concatenate((mbatch['label'],batch['label']))})
        #             i = i + 1
        #             # Display logs per step
        #             if i % gan_display_step == 0 or i == 1:
        #                 print('discriminator epoch %i Step %i: Minibatch Loss: %f' % (e, i, l))
        #         print('discriminator epoch %i Step %i: Minibatch Loss: %f' % (e, i, l))

        if not use_distance:
            if iters_so_far % 100 == 0:
                with open(
                        logdir + "\\" + 'replay_buf_' + str(iters_so_far) +
                        '.pkl', 'wb') as f:
                    pickle.dump(replay_buf, f)

        with open(logdir + "\\" + 'ob_list_' + str(rank) + '.pkl', 'wb') as f:
            pickle.dump(ob_list, f)
        if MPI.COMM_WORLD.Get_rank() == 0:
            logF.write(str(rewmean) + "\n")
            logR.write(str(seg['mean_ext_rew']) + "\n")
            logStats.write(logger.get_str() + "\n")
            logF.flush()
            logStats.flush()
            logR.flush()

            logger.dump_tabular()

            try:
                os.remove(logdir + "/checkpoint")
            except OSError:
                pass
            try:
                saver.save(tf.get_default_session(),
                           os.path.join(logdir, agentName),
                           global_step=iters_so_far)
            except:
                pass
            try:
                non_disc_saver.save(tf.get_default_session(),
                                    os.path.join(logdir,
                                                 agentName + "_non_disc"),
                                    global_step=iters_so_far)
            except:
                pass
            try:
                disc_saver.save(tf.get_default_session(),
                                os.path.join(logdir, agentName + "_disc"),
                                global_step=iters_so_far)
            except:
                pass
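
# A minimal invocation sketch for the learn() above. The helper names
# (make_env, MlpPolicy, Discriminator) and the hyperparameter values are
# illustrative assumptions, not part of this module:
#
#   def policy_fn(name, ob_space, ac_space):
#       return MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space,
#                        hid_size=64, num_hid_layers=2)
#
#   env = make_env()
#   disc = Discriminator(...)  # must expose input, label, loss, optimizer,
#                              # optimizer_first, as used above
#   learn(env, policy_fn, disc,
#         timesteps_per_batch=2048, clip_param=0.2, entcoeff=0.0,
#         optim_epochs=10, optim_stepsize=3e-4, optim_batchsize=64,
#         gamma=0.99, lam=0.95, max_iters=10000,
#         mocap_path="path/to/motion",  # loads "<mocap_path>.dat"
#         logdir="./logs", agentName="PPO-Agent")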
Example #2
def learn(
    env,
    policy_func,
    *,
    timesteps_per_actorbatch,  # timesteps per actor per update
    clip_param,
    entcoeff,  # clipping parameter epsilon, entropy coeff
    optim_epochs,
    optim_stepsize,
    optim_batchsize,  # optimization hypers
    gamma,
    lam,  # advantage estimation
    max_timesteps=0,
    max_episodes=0,
    max_iters=0,
    max_seconds=0,  # time constraint
    callback=None,  # you can do anything in the callback, since it takes locals(), globals()
    adam_epsilon=1e-5,
    schedule='constant'  # annealing for stepsize parameters (epsilon and adam)
):
    print(
        "----------------------------------------------------------------------------Learning is here ..."
    )
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_func("pi", ob_space,
                     ac_space)  # Construct network for new policy
    oldpi = policy_func("oldpi", ob_space, ac_space)  # Network for old policy
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule
    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = tf.clip_by_value(ratio, 1.0 - clip_param,
                             1.0 + clip_param) * atarg  #
    pol_surr = -tf.reduce_mean(tf.minimum(
        surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)
    vf_loss = tf.reduce_mean(tf.square(pi.vpred - ret))
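    # Note: unlike Example #1, the value loss here is the plain (unclipped)
    # mean squared error against the empirical return.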
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult],
                             losses + [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)

    U.initialize()
    adam.sync()

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_actorbatch,
                                     stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards

    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"

    # my
    #sess = tf.get_default_session()
    #writer = tf.summary.FileWriter("/home/chris/openai_logdir/tb2")
    #placeholder_rews = tf.placeholder(tf.float32)
    #placeholder_vpreds = tf.placeholder(tf.float32)
    #placeholder_advs = tf.placeholder(tf.float32)
    #placeholder_news = tf.placeholder(tf.float32)
    #placeholder_ep_rews = tf.placeholder(tf.float32)

    #tf.summary.histogram("rews", placeholder_rews)
    #tf.summary.histogram("vpreds", placeholder_vpreds)
    #tf.summary.histogram("advs", placeholder_advs)
    #tf.summary.histogram("news", placeholder_news)
    #tf.summary.scalar("ep_rews", placeholder_ep_rews)

    #writer = SummaryWriter("/home/chris/openai_logdir/tb2")

    #placeholder_ep_rews = tf.placeholder(tf.float32)
    #placeholder_ep_vpred = tf.placeholder(tf.float32)
    #placeholder_ep_atarg = tf.placeholder(tf.float32)

    #sess = tf.get_default_session()
    #writer = tf.summary.FileWriter("/home/chris/openai_logdir/x_new/tb2")
    #tf.summary.scalar("EpRews", placeholder_ep_rews)
    #tf.summary.scalar("EpVpred", placeholder_ep_vpred)

    #tf.summary.scalar("EpRews", placeholder_ep_rews)
    #tf.summary.scalar("EpVpred", placeholder_ep_vpred)
    #tf.summary.scalar("EpAtarg", placeholder_ep_atarg)
    #summ = tf.summary.merge_all()

    time_step_idx = 0

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % iters_so_far)

        seg = seg_gen.__next__()

        # Compute target value using TD(lambda) estimator, and advantage with GAE(lambda)
        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update

        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate

        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                    shuffle=not pi.recurrent)
        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        assign_old_eq_new()  # set old parameter values to new parameter values
        logger.log("Optimizing...")
        logger.log(fmt_row(13, loss_names))
        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"],
                                            batch["atarg"], batch["vtarg"],
                                            cur_lrmult)
                adam.update(g, optim_stepsize * cur_lrmult)
                losses.append(newlosses)
            logger.log(fmt_row(13, np.mean(losses, axis=0)))

        logger.log("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob"], batch["ac"],
                                       batch["atarg"], batch["vtarg"],
                                       cur_lrmult)
            losses.append(newlosses)
        meanlosses, _, _ = mpi_moments(losses, axis=0)
        logger.log(fmt_row(13, meanlosses))
        for (lossval, name) in zipsame(meanlosses, loss_names):
            logger.record_tabular("loss_" + name, lossval)
        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))

        # ep_lens and ep_rets
        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        #TODO make MPI compatible
        path_lens = seg["path_lens"]
        path_rews = seg["path_rets"]

        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        timesteps_so_far += sum(path_lens)  # my add path lens
        print("timesteps_so_far: %i" % timesteps_so_far)

        iters_so_far += 1
        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)

        #my
        logger.record_tabular("PathLenMean", np.mean(path_lens))
        logger.record_tabular("PathRewMean", np.mean(path_rews))

        # MY
        dir = "/home/chris/openai_logdir/"
        env_name = env.unwrapped.spec.id

        #plots

        plt.figure(figsize=(10, 4))
        # plt.plot(vpredbefore, animated=True, label="vpredbefore")
        #new_with_nones = seg["new"]
        #np.place(new_with_nones, new_with_nones == 0, [6])
        #plt.plot(new_with_nones, 'r.', animated=True, label="new")

        for dd in np.where(seg["new"] == 1)[0]:
            plt.axvline(x=dd, color='green', linewidth=1)
            #plt.annotate('asd', xy=(2, 1), xytext=(3, 1.5),)

        #np.place(new_episode_with_nones, new_episode_with_nones == 0, [90])
        #aaaa = np.where(seg["news_episode"] > 0)
        #for dd in aaaa[0]:
        #    plt.annotate('ne'+str(dd), xy=(0, 6), xytext=(3, 1.5),)
        #plt.axvline(x=dd, color='green', linewidth=1)

        #plt.plot(ac, animated=True, label="ac")

        plt.plot(vpredbefore, 'g', label="vpredbefore", antialiased=True)

        # plot advantage of episode time steps
        #plt.plot(seg["adv"], 'b', animated=True, label="adv")
        #plt.plot(atarg, 'r', animated=True, label="atarg")
        plt.plot(tdlamret, 'y', animated=True, label="tdlamret")

        plt.legend()
        plt.title('iters_so_far: ' + str(iters_so_far))
        plt.savefig(dir + env_name + '_plot.png', dpi=300)

        if iters_so_far % 2 == 0 or iters_so_far == 0:
            #plt.ylim(ymin=-10, ymax=100)
            #plt.ylim(ymin=-15, ymax=15)

            plt.savefig(dir + '/plotiters/' + env_name + '_plot' + '_iter' +
                        str(iters_so_far).zfill(3) + '.png',
                        dpi=300)

        plt.clf()
        plt.close()

        # 3d V obs
        #ac, vpred = pi.act(True, ob)
        # check ob dim
        freq = 1  # 5
        # <= 3?
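        # For low-dimensional observation spaces (<= 3 dims) the block below
        # sweeps a 30x30 grid over the first two observation dimensions,
        # queries the value head at each grid point, and saves the result as
        # a heat map of the learned value function.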
        if env.observation_space.shape[0] <= 3 and (iters_so_far % freq == 0
                                                    or iters_so_far == 1):

            figV, axV = plt.subplots()

            # surface
            #figV = plt.figure()
            #axV = Axes3D(figV)

            obs = env.observation_space

            X = np.arange(obs.low[0], obs.high[0],
                          (obs.high[0] - obs.low[0]) / 30)
            Y = np.arange(obs.low[1], obs.high[1],
                          (obs.high[1] - obs.low[1]) / 30)
            X, Y = np.meshgrid(X, Y)

            Z = np.zeros((len(X), len(Y)))

            for x in range(len(X)):
                for y in range(len(Y)):
                    #strange datatype needed??
                    myob = np.copy(ob[0])
                    myob[0] = X[0][x]
                    myob[1] = Y[y][0]
                    stochastic = True
                    ac, vpred = pi.act(stochastic, myob)
                    Z[x][y] = vpred

            plt.xlabel('First D')
            plt.ylabel('Second D')
            #plt.clabel('Value-Function')
            plt.title(env_name + ' iteration: ' + str(iters_so_far))

            # surface
            #axV.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap=cm.coolwarm)

            # heat map
            imgplot = plt.imshow(Z, interpolation='nearest')
            imgplot.set_cmap('hot')
            plt.colorbar()

            figV.savefig(dir + env_name + '_V.png', dpi=300)

            if iters_so_far % 2 == 0 or iters_so_far == 0:
                figV.savefig(dir + '/plotiters/' + env_name + '_plot' +
                             '_iter' + str(iters_so_far).zfill(3) + '_V' +
                             '.png',
                             dpi=300)

            figV.clf()
            plt.close(figV)
        """
        # transfer timesteps of iterations into timesteps of episodes
        idx_seg = 0
        for ep in range(len(lens)) :
            all_ep = episodes_so_far - len(lens) + ep

            if all_ep%100==0:
                break

            ep_rews = seg["rew"][idx_seg:idx_seg+lens[ep]]
            ep_vpred = seg["vpred"][idx_seg:idx_seg+lens[ep]]
            ep_atarg = atarg[idx_seg:idx_seg+lens[ep]]

            idx_seg += lens[ep]
            #writer.add_histogram("ep_vpred", data, iters_so_far)
            #hist_dict[placeholder_ep_rews] = ep_rews[ep]
            #sess2 = sess.run(summ, feed_dict=hist_dict)

            #summary = tf.Summary()

            #if test_ep:
            #    break
            for a in range(len(ep_rews)):
                d = ep_rews[a]
                d2 = ep_vpred[a]
                d3 = ep_atarg[a]

                #summary.value.add(tag="EpRews/"+str(all_ep), simple_value=d)
                #writer.add_summary(summary, a)

                sess2 = sess.run(summ, feed_dict={placeholder_ep_rews: d, placeholder_ep_vpred: d2, placeholder_ep_atarg: d3})
            #writer.add_summary(sess2, global_step=a)
                writer.add_summary(sess2, global_step=time_step_idx)
                time_step_idx += 1

            writer.flush()
            #writer.close()
            #logger.record_tabular("vpred_e" + str(all_ep), ep_rews[ep])

        """
        """
        ###
        sess2 = sess.run(summ, feed_dict={placeholder_rews: seg["rew"],
                                          placeholder_vpreds: seg["vpred"],
                                          placeholder_advs: seg["adv"],
                                          placeholder_news: seg["new"]})
        writer.add_summary(sess2, global_step=episodes_so_far)
        """

        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.dump_tabular()
Example #3
def enjoy(
        env,
        policy_func,
        *,
        timesteps_per_actorbatch,  # timesteps per actor per update
        clip_param,
        entcoeff,  # clipping parameter epsilon, entropy coeff
        optim_epochs,
        optim_stepsize,
        optim_batchsize,  # optimization hypers
        gamma,
        lam,  # advantage estimation
        max_timesteps=0,
        max_episodes=0,
        max_iters=0,
        max_seconds=0,  # time constraint
        callback=None,  # you can do anything in the callback, since it takes locals(), globals()
        adam_epsilon=1e-5,
        schedule='constant',  # annealing for stepsize parameters (epsilon and adam)
        save_name=None,
        save_per_acts=3,
        reload_name=None):
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_func("pi", ob_space,
                     ac_space)  # Construct network for new policy
    oldpi = policy_func("oldpi", ob_space, ac_space)  # Network for old policy
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule
    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = tf.clip_by_value(ratio, 1.0 - clip_param,
                             1.0 + clip_param) * atarg  #
    pol_surr = -tf.reduce_mean(tf.minimum(
        surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)
    vf_loss = tf.reduce_mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult],
                             losses + [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)

    U.initialize()
    adam.sync()

    if reload_name:
        saver = tf.train.Saver()
        saver.restore(tf.get_default_session(), reload_name)
        print("Loaded model successfully.")

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_actorbatch,
                                     stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards

    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % iters_so_far)

        seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update
        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate
        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                    shuffle=not pi.recurrent)
        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        assign_old_eq_new()  # set old parameter values to new parameter values
        logger.log("Optimizing...")
        logger.log(fmt_row(13, loss_names))
        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"],
                                            batch["atarg"], batch["vtarg"],
                                            cur_lrmult)
                adam.update(g, optim_stepsize * cur_lrmult)
                losses.append(newlosses)
Example #4
    def replay(self, seg_list, batch_size):
        print(self.scope + " training")

        if self.schedule == 'constant':
            cur_lrmult = 1.0
        elif self.schedule == 'linear':
            cur_lrmult = max(
                1.0 - float(self.timesteps_so_far) / self.max_timesteps, 0)

        # Here we do a bunch of optimization epochs over the data
        # Batched approach: on each pass, collect the gradient g from every
        # battle segment, average them, then take one optimization step.
        # Repeat for several epochs.
        newlosses_list = []
        logger.log("Optimizing...")
        loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]
        logger.log(fmt_row(13, loss_names))
        for _ in range(self.optim_epochs):
            g_list = []
            for seg in seg_list:

                self.add_vtarg_and_adv(seg, self.gamma, self.lam)

                # print(seg)

                # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
                ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg[
                    "adv"], seg["tdlamret"]
                vpredbefore = seg[
                    "vpred"]  # predicted value function before udpate
                atarg = (atarg - atarg.mean()) / atarg.std(
                )  # standardized advantage function estimate
                d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                            shuffle=not self.pi.recurrent)

                if hasattr(self.pi, "ob_rms"):
                    self.pi.ob_rms.update(
                        ob)  # update running mean/std for policy

                self.assign_old_eq_new(
                )  # set old parameter values to new parameter values

                # Take the full batch of transitions at once
                batch = d.next_batch(d.n)
                # print("ob", batch["ob"], "ac", batch["ac"], "atarg", batch["atarg"], "vtarg", batch["vtarg"])
                *newlosses, debug_atarg, pi_ac, opi_ac, vpred, pi_pd, opi_pd, kl_oldnew, total_loss, var_list, grads, g = \
                    self.lossandgrad(batch["ob"], batch["ac"], batch["atarg"], batch["vtarg"], cur_lrmult)
                # print("debug_atarg", debug_atarg, "pi_ac", pi_ac, "opi_ac", opi_ac, "vpred", vpred, "pi_pd", pi_pd,
                #       "opi_pd", opi_pd, "kl_oldnew", kl_oldnew, "var_mean", np.mean(g), "total_loss", total_loss)
                if np.isnan(np.mean(g)):
                    print('output nan, ignore it!')
                else:
                    g_list.append(g)
                    newlosses_list.append(newlosses)

            # Average the batched gradients, then update the model
            if len(g_list) > 0:
                avg_g = np.mean(g_list, axis=0)
                self.adam.update(avg_g, self.optim_stepsize * cur_lrmult)
                logger.log(fmt_row(13, np.mean(newlosses_list, axis=0)))

        logger.log("Evaluating losses...")
        losses = []
        for seg in seg_list:
            self.add_vtarg_and_adv(seg, self.gamma, self.lam)

            # print(seg)

            # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
            ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
                "tdlamret"]
            vpredbefore = seg[
                "vpred"]  # predicted value function before udpate
            atarg = (atarg - atarg.mean()) / atarg.std(
            )  # standardized advantage function estimate
            d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                        shuffle=not self.pi.recurrent)
            # Take the full batch of transitions at once
            batch = d.next_batch(d.n)
            newlosses = self.compute_losses(batch["ob"], batch["ac"],
                                            batch["atarg"], batch["vtarg"],
                                            cur_lrmult)
            losses.append(newlosses)
        print(losses)

        meanlosses, _, _ = mpi_moments(losses, axis=0)
        logger.log(fmt_row(13, meanlosses))
        for (lossval, name) in zipsame(meanlosses, loss_names):
            if np.isinf(lossval):
                debug = True
            logger.record_tabular("loss_" + name, lossval)
        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))
        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(self.flatten_lists, zip(*listoflrpairs))
        self.lenbuffer.extend(lens)
        self.rewbuffer.extend(rews)
        last_rew = self.rewbuffer[-1] if len(self.rewbuffer) > 0 else 0
        logger.record_tabular("LastRew", last_rew)
        logger.record_tabular(
            "LastLen", 0 if len(self.lenbuffer) <= 0 else self.lenbuffer[-1])
        logger.record_tabular("EpLenMean", np.mean(self.lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(self.rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        self.episodes_so_far += len(lens)
        self.timesteps_so_far += sum(lens)
        self.iters_so_far += 1
        logger.record_tabular("EpisodesSoFar", self.episodes_so_far)
        logger.record_tabular("TimestepsSoFar", self.timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - self.tstart)
        logger.record_tabular("IterSoFar", self.iters_so_far)
        logger.record_tabular("CalulateActions", self.act_times)
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.dump_tabular()
Example #5
def learn(
        env,
        policy_fn,
        timesteps_per_actorbatch,  # timesteps per actor per update
        clip_param,
        entcoeff,  # clipping parameter epsilon, entropy coeff
        optim_epochs,
        optim_stepsize,
        optim_batchsize,  # optimization hypers
        gamma,
        lam,  # advantage estimation
        max_timesteps=0,
        max_episodes=0,
        max_iters=0,
        max_seconds=0,  # time constraint
        callback=None,  # you can do anything in the callback, since it takes locals(), globals()
        adam_epsilon=1e-5,
        schedule='constant',  # annealing for stepsize parameters (epsilon and adam)
        save_per_iter=50,
        max_sample_traj=10,
        ckpt_dir=None,
        log_dir=None,
        task_name="origin",
        sample_stochastic=True,
        load_model_path=None,
        task="train"):
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_fn("pi", ob_space,
                   ac_space)  # Construct network for new policy
    oldpi = policy_fn("oldpi", ob_space, ac_space)  # Network for old policy
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule
    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    ob_p = U.get_placeholder_cached(name="ob_physics")
    ob_f = U.get_placeholder_cached(name="ob_frames")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = tf.clip_by_value(ratio, 1.0 - clip_param,
                             1.0 + clip_param) * atarg  #
    pol_surr = -tf.reduce_mean(tf.minimum(
        surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)
    vf_loss = tf.reduce_mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    if pi.recurrent:
        state_c = U.get_placeholder_cached(name="state_c")
        state_h = U.get_placeholder_cached(name="state_h")
        lossandgrad = U.function(
            [ob_p, ob_f, state_c, state_h, ac, atarg, ret, lrmult],
            losses + [U.flatgrad(total_loss, var_list)])
        compute_losses = U.function(
            [ob_p, ob_f, state_c, state_h, ac, atarg, ret, lrmult], losses)
    else:
        lossandgrad = U.function([ob_p, ob_f, ac, atarg, ret, lrmult],
                                 losses + [U.flatgrad(total_loss, var_list)])
        compute_losses = U.function([ob_p, ob_f, ac, atarg, ret, lrmult],
                                    losses)
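    # For recurrent policies the loss functions also take state_c / state_h,
    # presumably the cell and hidden state of an LSTM core, fed from cached
    # placeholders; otherwise only the physics and frame observations are fed.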

    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])

    writer = U.FileWriter(log_dir)
    U.initialize()
    adam.sync()

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_actorbatch,
                                     stochastic=True)
    traj_gen = traj_episode_generator(pi,
                                      env,
                                      timesteps_per_actorbatch,
                                      stochastic=sample_stochastic)
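    # traj_gen yields whole episodes; it is only used by the
    # "sample_trajectory" task below, which samples max_sample_traj
    # trajectories and then exits.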

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards

    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"

    loss_stats = stats(loss_names)
    ep_stats = stats(["Reward", "Episode_Length", "Episode_This_Iter"])
    if task == "sample_trajectory":
        sample_trajectory(load_model_path, max_sample_traj, traj_gen,
                          task_name, sample_stochastic)
        sys.exit()

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        # Save model
        if iters_so_far % save_per_iter == 0 and ckpt_dir is not None:
            U.save_state(os.path.join(ckpt_dir, task_name),
                         counter=iters_so_far)

        logger.log2("********** Iteration %i ************" % iters_so_far)

        seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"]
        ob_p, ob_f = zip(*ob)
        ob_p = np.array(ob_p)
        ob_f = np.array(ob_f)
        vpredbefore = seg["vpred"]  # predicted value function before udpate
        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate

        if pi.recurrent:
            sc, sh = zip(*seg["state"])
            d = Dataset(dict(ob_p=ob_p,
                             ob_f=ob_f,
                             ac=ac,
                             atarg=atarg,
                             vtarg=tdlamret,
                             state_c=np.array(sc),
                             state_h=np.array(sh)),
                        shuffle=not pi.recurrent)
        else:
            d = Dataset(dict(ob_p=ob_p,
                             ob_f=ob_f,
                             ac=ac,
                             atarg=atarg,
                             vtarg=tdlamret),
                        shuffle=not pi.recurrent)
        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob_p)  # update running mean/std for policy

        assign_old_eq_new()  # set old parameter values to new parameter values
        logger.log2("Optimizing...")
        logger.log2(fmt_row(13, loss_names))
        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                if pi.recurrent:
                    lg = lossandgrad(batch["ob_p"], batch["ob_f"],
                                     batch["state_c"], batch["state_h"],
                                     batch["ac"], batch["atarg"],
                                     batch["vtarg"], cur_lrmult)
                else:
                    lg = lossandgrad(batch["ob_p"], batch["ob_f"], batch["ac"],
                                     batch["atarg"], batch["vtarg"],
                                     cur_lrmult)
                newlosses = lg[:-1]
                g = lg[-1]
                adam.update(g, optim_stepsize * cur_lrmult)
                losses.append(newlosses)
            logger.log2(fmt_row(13, np.mean(losses, axis=0)))

        logger.log2("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            if pi.recurrent:
                newlosses = compute_losses(batch["ob_p"], batch["ob_f"],
                                           batch["state_c"], batch["state_h"],
                                           batch["ac"], batch["atarg"],
                                           batch["vtarg"], cur_lrmult)
            else:
                newlosses = compute_losses(batch["ob_p"], batch["ob_f"],
                                           batch["ac"], batch["atarg"],
                                           batch["vtarg"], cur_lrmult)
            losses.append(newlosses)
        meanlosses, _, _ = mpi_moments(losses, axis=0)
        logger.log2(fmt_row(13, meanlosses))
        for (lossval, name) in zipsame(meanlosses, loss_names):
            logger.record_tabular("loss_" + name, lossval)
        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))
        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.dump_tabular()
            loss_stats.add_all_summary(writer, meanlosses, iters_so_far)
            ep_stats.add_all_summary(
                writer, [np.mean(rewbuffer),
                         np.mean(lenbuffer),
                         len(lens)], iters_so_far)

    return pi
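
The examples in this section all call add_vtarg_and_adv to attach GAE(lambda) advantages and lambda-returns to a sampled segment, but the helper itself is not shown. A minimal sketch, assuming the segment dict layout implied by the code above (keys "rew", "vpred", "new", "nextvpred" as in the Baselines-style trajectory generators):

import numpy as np

def add_vtarg_and_adv(seg, gamma, lam):
    # new[t] marks the first step of a new episode; append 0 so new[t + 1] is defined at the end
    new = np.append(seg["new"], 0)
    # bootstrap with the value prediction for the state following the segment
    vpred = np.append(seg["vpred"], seg["nextvpred"])
    T = len(seg["rew"])
    seg["adv"] = gaelam = np.empty(T, "float32")
    rew = seg["rew"]
    lastgaelam = 0
    for t in reversed(range(T)):
        nonterminal = 1 - new[t + 1]
        # TD residual: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
        delta = rew[t] + gamma * vpred[t + 1] * nonterminal - vpred[t]
        # GAE(lambda): A_t = delta_t + gamma * lam * A_{t+1}, reset at episode boundaries
        gaelam[t] = lastgaelam = delta + gamma * lam * nonterminal * lastgaelam
    # lambda-return, used as the regression target for the value function ("tdlamret")
    seg["tdlamret"] = seg["adv"] + seg["vpred"]
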
Exemplo n.º 6
0
def learn(env, policy_func, reward_giver, expert_dataset, rank,
          pretrained, pretrained_weight, *,
          g_step, d_step, entcoeff, save_per_iter,
          ckpt_dir, log_dir, timesteps_per_batch, task_name,
          gamma, lam,
          max_kl, cg_iters, cg_damping=1e-2,
          vf_stepsize=3e-4, d_stepsize=3e-4, vf_iters=3,
          max_timesteps=0, max_episodes=0, max_iters=0,
          callback=None
          ):

    nworkers = MPI.COMM_WORLD.Get_size()
    rank = MPI.COMM_WORLD.Get_rank()
    np.set_printoptions(precision=3)
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_func("pi", ob_space, ac_space, reuse=(pretrained_weight != None))
    oldpi = policy_func("oldpi", ob_space, ac_space)
    atarg = tf.placeholder(dtype=tf.float32, shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    entbonus = entcoeff * meanent

    vferr = tf.reduce_mean(tf.square(pi.vpred - ret))

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # advantage * pnew / pold
    surrgain = tf.reduce_mean(ratio * atarg)

    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

    dist = meankl

    all_var_list = pi.get_trainable_variables()
    var_list = [v for v in all_var_list if v.name.startswith("pi/pol") or v.name.startswith("pi/logstd")]
    vf_var_list = [v for v in all_var_list if v.name.startswith("pi/vff")]
    assert len(var_list) == len(vf_var_list) + 1
    d_adam = MpiAdam(reward_giver.get_trainable_variables())
    vfadam = MpiAdam(vf_var_list)

    get_flat = U.GetFlat(var_list)
    set_from_flat = U.SetFromFlat(var_list)
    klgrads = tf.gradients(dist, var_list)
    flat_tangent = tf.placeholder(dtype=tf.float32, shape=[None], name="flat_tan")
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start+sz], shape))
        start += sz
    gvp = tf.add_n([tf.reduce_sum(g*tangent) for (g, tangent) in zipsame(klgrads, tangents)])  # pylint: disable=E1111
    fvp = U.flatgrad(gvp, var_list)
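    # Fisher-vector product trick: for a tangent vector v,
    #   F v = d/dtheta [ (d KL / d theta) . v ],
    # so gvp is the inner product of the KL gradient with the tangent and fvp is its
    # gradient. This yields curvature-vector products for CG without ever materializing
    # the Fisher matrix.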

    assign_old_eq_new = U.function([], [], updates=[tf.assign(oldv, newv)
                                                    for (oldv, newv) in zipsame(oldpi.get_variables(), pi.get_variables())])
    compute_losses = U.function([ob, ac, atarg], losses)
    compute_lossandgrad = U.function([ob, ac, atarg], losses + [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    compute_vflossandgrad = U.function([ob, ret], U.flatgrad(vferr, vf_var_list))

    @contextmanager
    def timed(msg):
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(colorize("done in %.3f seconds" % (time.time() - tstart), color='magenta'))
        else:
            yield

    def allmean(x):
        assert isinstance(x, np.ndarray)
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
        return out

    U.initialize()
    th_init = get_flat()
    MPI.COMM_WORLD.Bcast(th_init, root=0)
    set_from_flat(th_init)
    d_adam.sync()
    vfadam.sync()
    if rank == 0:
        print("Init param sum", th_init.sum(), flush=True)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi, env, reward_giver, timesteps_per_batch, stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards
    true_rewbuffer = deque(maxlen=40)

    assert sum([max_iters > 0, max_timesteps > 0, max_episodes > 0]) == 1

    g_loss_stats = stats(loss_names)
    d_loss_stats = stats(reward_giver.loss_name)
    ep_stats = stats(["True_rewards", "Rewards", "Episode_length"])
    # if a pretrained weight is provided
    if pretrained_weight is not None:
        U.load_state(pretrained_weight, var_list=pi.get_variables())

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break

        # Save model
        if rank == 0 and iters_so_far % save_per_iter == 0 and ckpt_dir is not None:
            fname = os.path.join(ckpt_dir, task_name)
            os.makedirs(os.path.dirname(fname), exist_ok=True)
            saver = tf.train.Saver()
            saver.save(tf.get_default_session(), fname)

        logger.log("********** Iteration %i ************" % iters_so_far)

        def fisher_vector_product(p):
            return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p
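        # The cg_damping * p term adds a small multiple of the identity to the Fisher
        # matrix, which keeps the conjugate-gradient solve well conditioned.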
        # ------------------ Update G ------------------
        logger.log("Optimizing Policy...")
        for _ in range(g_step):
            with timed("sampling"):
                seg = seg_gen.__next__()
            add_vtarg_and_adv(seg, gamma, lam)
            # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
            ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]
            vpredbefore = seg["vpred"]  # predicted value function before udpate
            atarg = (atarg - atarg.mean()) / atarg.std()  # standardized advantage function estimate

            if hasattr(pi, "ob_rms"): pi.ob_rms.update(ob)  # update running mean/std for policy

            args = seg["ob"], seg["ac"], atarg
            fvpargs = [arr[::5] for arr in args]

            assign_old_eq_new()  # set old parameter values to new parameter values
            with timed("computegrad"):
                *lossbefore, g = compute_lossandgrad(*args)
            lossbefore = allmean(np.array(lossbefore))
            g = allmean(g)
            if np.allclose(g, 0):
                logger.log("Got zero gradient. not updating")
            else:
                with timed("cg"):
                    stepdir = cg(fisher_vector_product, g, cg_iters=cg_iters, verbose=rank == 0)
                assert np.isfinite(stepdir).all()
                shs = .5*stepdir.dot(fisher_vector_product(stepdir))
                lm = np.sqrt(shs / max_kl)
                # logger.log("lagrange multiplier:", lm, "gnorm:", np.linalg.norm(g))
                fullstep = stepdir / lm
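                # shs = 0.5 * s^T H s for the CG solution s, so lm = sqrt(shs / max_kl)
                # rescales the step to sit exactly on the max_kl trust-region boundary
                # under the quadratic KL approximation; the backtracking line search
                # below only ever shrinks it.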
                expectedimprove = g.dot(fullstep)
                surrbefore = lossbefore[0]
                stepsize = 1.0
                thbefore = get_flat()
                for _ in range(10):
                    thnew = thbefore + fullstep * stepsize
                    set_from_flat(thnew)
                    meanlosses = surr, kl, *_ = allmean(np.array(compute_losses(*args)))
                    improve = surr - surrbefore
                    logger.log("Expected: %.3f Actual: %.3f" % (expectedimprove, improve))
                    if not np.isfinite(meanlosses).all():
                        logger.log("Got non-finite value of losses -- bad!")
                    elif kl > max_kl * 1.5:
                        logger.log("violated KL constraint. shrinking step.")
                    elif improve < 0:
                        logger.log("surrogate didn't improve. shrinking step.")
                    else:
                        logger.log("Stepsize OK!")
                        break
                    stepsize *= .5
                else:
                    logger.log("couldn't compute a good step")
                    set_from_flat(thbefore)
                if nworkers > 1 and iters_so_far % 20 == 0:
                    paramsums = MPI.COMM_WORLD.allgather((thnew.sum(), vfadam.getflat().sum()))  # list of tuples
                    assert all(np.allclose(ps, paramsums[0]) for ps in paramsums[1:])
            with timed("vf"):
                for _ in range(vf_iters):
                    for (mbob, mbret) in dataset.iterbatches((seg["ob"], seg["tdlamret"]),
                                                             include_final_partial_batch=False, batch_size=128):
                        if hasattr(pi, "ob_rms"):
                            pi.ob_rms.update(mbob)  # update running mean/std for policy
                        g = allmean(compute_vflossandgrad(mbob, mbret))
                        vfadam.update(g, vf_stepsize)

        g_losses = meanlosses
        for (lossname, lossval) in zip(loss_names, meanlosses):
            logger.record_tabular(lossname, lossval)
        logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))
        # ------------------ Update D ------------------
        logger.log("Optimizing Discriminator...")
        logger.log(fmt_row(13, reward_giver.loss_name))
        ob_expert, ac_expert = expert_dataset.get_next_batch(len(ob))
        batch_size = len(ob) // d_step
        d_losses = []  # list of tuples, each of which gives the loss for a minibatch
        for ob_batch, ac_batch in dataset.iterbatches((ob, ac),
                                                      include_final_partial_batch=False,
                                                      batch_size=batch_size):
            ob_expert, ac_expert = expert_dataset.get_next_batch(len(ob_batch))
            # update running mean/std for reward_giver
            if hasattr(reward_giver, "obs_rms"): reward_giver.obs_rms.update(np.concatenate((ob_batch, ob_expert), 0))
            *newlosses, g = reward_giver.lossandgrad(ob_batch, ac_batch, ob_expert, ac_expert)
            d_adam.update(allmean(g), d_stepsize)
            d_losses.append(newlosses)
        logger.log(fmt_row(13, np.mean(d_losses, axis=0)))

        lrlocal = (seg["ep_lens"], seg["ep_rets"], seg["ep_true_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews, true_rets = map(flatten_lists, zip(*listoflrpairs))
        true_rewbuffer.extend(true_rets)
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpTrueRewMean", np.mean(true_rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)

        if rank == 0:
            logger.dump_tabular()
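
The TRPO-style update above solves H x = g with cg(...), using only the Fisher-vector products supplied by fisher_vector_product. The solver itself is not included in the excerpt; a standard conjugate-gradient sketch matching the assumed interface (f_Ax(p) returns H p as a flat numpy array) is:

import numpy as np

def cg(f_Ax, b, cg_iters=10, verbose=False, residual_tol=1e-10):
    # Solve A x = b for symmetric positive-definite A, given only the product f_Ax.
    x = np.zeros_like(b)
    r = b.copy()   # residual b - A x (x starts at zero)
    p = b.copy()   # current search direction
    rdotr = r.dot(r)
    for i in range(cg_iters):
        if verbose:
            print("cg iter %d squared residual %g" % (i, rdotr))
        z = f_Ax(p)
        v = rdotr / p.dot(z)   # step length along p
        x += v * p
        r -= v * z
        newrdotr = r.dot(r)
        mu = newrdotr / rdotr  # coefficient making the next direction A-conjugate
        p = r + mu * p
        rdotr = newrdotr
        if rdotr < residual_tol:
            break
    return x
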
Exemplo n.º 7
0
def learn(
    env,
    policy_func,
    reward_giver,
    semi_dataset,
    rank,
    pretrained_weight,
    *,
    g_step,
    d_step,
    entcoeff,
    save_per_iter,
    ckpt_dir,
    log_dir,
    timesteps_per_batch,
    task_name,
    gamma,
    lam,
    max_kl,
    cg_iters,
    cg_damping=1e-2,
    vf_stepsize=3e-4,
    d_stepsize=3e-4,
    vf_iters=3,
    max_timesteps=0,
    max_episodes=0,
    max_iters=0,
    vf_batchsize=128,
    callback=None,
    freeze_g=False,
    freeze_d=False,
    pretrained_il=None,
    pretrained_semi=None,
    semi_loss=False,
    expert_reward_threshold=None,  # filter experts based on reward
    expert_label=get_semi_prefix(),
    sparse_reward=False  # filter experts based on success flag (sparse reward)
):

    semi_loss = semi_loss and semi_dataset is not None
    l2_w = 0.1

    nworkers = MPI.COMM_WORLD.Get_size()
    rank = MPI.COMM_WORLD.Get_rank()
    np.set_printoptions(precision=3)

    if rank == 0:
        writer = U.file_writer(log_dir)

        # print all the hyperparameters in the log...
        log_dict = {
            # "expert trajectories": expert_dataset.num_traj,
            "expert model": pretrained_semi,
            "algo": "trpo",
            "threads": nworkers,
            "timesteps_per_batch": timesteps_per_batch,
            "timesteps_per_thread": -(-timesteps_per_batch // nworkers),
            "entcoeff": entcoeff,
            "vf_iters": vf_iters,
            "vf_batchsize": vf_batchsize,
            "vf_stepsize": vf_stepsize,
            "d_stepsize": d_stepsize,
            "g_step": g_step,
            "d_step": d_step,
            "max_kl": max_kl,
            "gamma": gamma,
            "lam": lam,
        }

        if semi_dataset is not None:
            log_dict["semi trajectories"] = semi_dataset.num_traj
        if hasattr(semi_dataset, 'info'):
            log_dict["semi_dataset_info"] = semi_dataset.info
        if expert_reward_threshold is not None:
            log_dict["expert reward threshold"] = expert_reward_threshold
        log_dict["sparse reward"] = sparse_reward

        # print them all together for csv
        logger.log(",".join([str(elem) for elem in log_dict]))
        logger.log(",".join([str(elem) for elem in log_dict.values()]))

        # also print them separately for easy reading:
        for elem in log_dict:
            logger.log(str(elem) + ": " + str(log_dict[elem]))

    # divide the timesteps to the threads
    timesteps_per_batch = -(-timesteps_per_batch // nworkers
                            )  # get ceil of division

    # Setup losses and stuff
    # ----------------------------------------
    ob_space = OrderedDict([(label, env[label].observation_space)
                            for label in env])

    if semi_dataset and get_semi_prefix() in env:  # semi ob space is different
        semi_obs_space = semi_ob_space(env[get_semi_prefix()],
                                       semi_size=semi_dataset.semi_size)
        ob_space[get_semi_prefix()] = semi_obs_space
    else:
        print("no semi dataset")
        # raise RuntimeError

    vf_stepsize = {label: vf_stepsize for label in env}

    ac_space = {label: env[label].action_space for label in ob_space}
    pi = {
        label: policy_func("pi",
                           ob_space=ob_space[label],
                           ac_space=ac_space[label],
                           prefix=label)
        for label in ob_space
    }
    oldpi = {
        label: policy_func("oldpi",
                           ob_space=ob_space[label],
                           ac_space=ac_space[label],
                           prefix=label)
        for label in ob_space
    }
    atarg = {
        label: tf.placeholder(dtype=tf.float32, shape=[None])
        for label in ob_space
    }  # Target advantage function (if applicable)
    ret = {
        label: tf.placeholder(dtype=tf.float32, shape=[None])
        for label in ob_space
    }  # Empirical return

    ob = {
        label: U.get_placeholder_cached(name=label + "ob")
        for label in ob_space
    }
    ac = {
        label: pi[label].pdtype.sample_placeholder([None])
        for label in ob_space
    }

    kloldnew = {label: oldpi[label].pd.kl(pi[label].pd) for label in ob_space}
    ent = {label: pi[label].pd.entropy() for label in ob_space}
    meankl = {label: tf.reduce_mean(kloldnew[label]) for label in ob_space}
    meanent = {label: tf.reduce_mean(ent[label]) for label in ob_space}
    entbonus = {label: entcoeff * meanent[label] for label in ob_space}

    vferr = {
        label: tf.reduce_mean(tf.square(pi[label].vpred - ret[label]))
        for label in ob_space
    }

    ratio = {
        label:
        tf.exp(pi[label].pd.logp(ac[label]) - oldpi[label].pd.logp(ac[label]))
        for label in ob_space
    }  # advantage * pnew / pold
    surrgain = {
        label: tf.reduce_mean(ratio[label] * atarg[label])
        for label in ob_space
    }

    optimgain = {
        label: surrgain[label] + entbonus[label]
        for label in ob_space
    }
    losses = {
        label: [
            optimgain[label], meankl[label], entbonus[label], surrgain[label],
            meanent[label]
        ]
        for label in ob_space
    }
    loss_names = {
        label: [
            label + name for name in
            ["optimgain", "meankl", "entloss", "surrgain", "entropy"]
        ]
        for label in ob_space
    }

    vf_losses = {label: [vferr[label]] for label in ob_space}
    vf_loss_names = {label: [label + "vf_loss"] for label in ob_space}

    dist = {label: meankl[label] for label in ob_space}

    all_var_list = {
        label: pi[label].get_trainable_variables()
        for label in ob_space
    }
    var_list = {
        label: [
            v for v in all_var_list[label]
            if "pol" in v.name or "logstd" in v.name
        ]
        for label in ob_space
    }
    vf_var_list = {
        label: [v for v in all_var_list[label] if "vf" in v.name]
        for label in ob_space
    }
    for label in ob_space:
        assert len(var_list[label]) == len(vf_var_list[label]) + 1

    get_flat = {label: U.GetFlat(var_list[label]) for label in ob_space}
    set_from_flat = {
        label: U.SetFromFlat(var_list[label])
        for label in ob_space
    }
    klgrads = {
        label: tf.gradients(dist[label], var_list[label])
        for label in ob_space
    }
    flat_tangent = {
        label: tf.placeholder(dtype=tf.float32,
                              shape=[None],
                              name=label + "flat_tan")
        for label in ob_space
    }
    fvp = {}
    for label in ob_space:
        shapes = [var.get_shape().as_list() for var in var_list[label]]
        start = 0
        tangents = []
        for shape in shapes:
            sz = U.intprod(shape)
            tangents.append(
                tf.reshape(flat_tangent[label][start:start + sz], shape))
            start += sz
        gvp = tf.add_n([
            tf.reduce_sum(g * tangent)
            for (g, tangent) in zipsame(klgrads[label], tangents)
        ])  # pylint: disable=E1111
        fvp[label] = U.flatgrad(gvp, var_list[label])

    assign_old_eq_new = {
        label:
        U.function([], [],
                   updates=[
                       tf.assign(oldv, newv)
                       for (oldv,
                            newv) in zipsame(oldpi[label].get_variables(),
                                             pi[label].get_variables())
                   ])
        for label in ob_space
    }
    compute_losses = {
        label: U.function([ob[label], ac[label], atarg[label]], losses[label])
        for label in ob_space
    }

    compute_vf_losses = {
        label: U.function([ob[label], ac[label], atarg[label], ret[label]],
                          losses[label] + vf_losses[label])
        for label in ob_space
    }

    compute_lossandgrad = {
        label: U.function([ob[label], ac[label], atarg[label]], losses[label] +
                          [U.flatgrad(optimgain[label], var_list[label])])
        for label in ob_space
    }
    compute_fvp = {
        label:
        U.function([flat_tangent[label], ob[label], ac[label], atarg[label]],
                   fvp[label])
        for label in ob_space
    }

    compute_vflossandgrad = {
        label: U.function([ob[label], ret[label]], vf_losses[label] +
                          [U.flatgrad(vferr[label], vf_var_list[label])])
        for label in ob_space
    }

    @contextmanager
    def timed(msg):
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(
                colorize("done in %.3f seconds" % (time.time() - tstart),
                         color='magenta'))
        else:
            yield

    def allmean(x):
        assert isinstance(x, np.ndarray)
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
        return out

    episodes_so_far = {label: 0 for label in ob_space}
    timesteps_so_far = {label: 0 for label in ob_space}
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = {label: deque(maxlen=40)
                 for label in ob_space}  # rolling buffer for episode lengths
    rewbuffer = {label: deque(maxlen=40)
                 for label in ob_space}  # rolling buffer for episode rewards
    true_rewbuffer = {label: deque(maxlen=40) for label in ob_space}
    success_buffer = {label: deque(maxlen=40) for label in ob_space}
    # L2 only for semi network
    l2_rewbuffer = deque(
        maxlen=40) if semi_loss and semi_dataset is not None else None
    total_rewbuffer = deque(
        maxlen=40) if semi_loss and semi_dataset is not None else None

    not_update = 1 if not freeze_d else 0  # do not update G before D the first time
    loaded = False
    # if a pretrained weight is provided
    if not U.load_checkpoint_variables(pretrained_weight,
                                       include_no_prefix_vars=True):
        # if no general checkpoint available, check sub-checkpoints for both networks
        if U.load_checkpoint_variables(pretrained_il,
                                       prefix=get_il_prefix(),
                                       include_no_prefix_vars=False):
            if rank == 0:
                logger.log("loaded checkpoint variables from " + pretrained_il)
            loaded = True
        elif expert_label == get_il_prefix():
            logger.log("ERROR no available cat_dauggi expert model in ",
                       pretrained_il)
            exit(1)

        if U.load_checkpoint_variables(pretrained_semi,
                                       prefix=get_semi_prefix(),
                                       include_no_prefix_vars=False):
            if rank == 0:
                logger.log("loaded checkpoint variables from " +
                           pretrained_semi)
            loaded = True
        elif expert_label == get_semi_prefix():
            if rank == 0:
                logger.log("ERROR no available semi expert model in ",
                           pretrained_semi)
            exit(1)
    else:
        loaded = True
        if rank == 0:
            logger.log("loaded checkpoint variables from " + pretrained_weight)

    if loaded:
        not_update = 0 if any(
            [x.op.name.find("adversary") != -1
             for x in U.ALREADY_INITIALIZED]) else 1
        if pretrained_weight and pretrained_weight.rfind("iter_") != -1 and \
                pretrained_weight[pretrained_weight.rfind("iter_") + len("iter_"):].isdigit():
            curr_iter = int(
                pretrained_weight[pretrained_weight.rfind("iter_") +
                                  len("iter_"):]) + 1

            if rank == 0:
                print("loaded checkpoint at iteration: " + str(curr_iter))
            iters_so_far = curr_iter
            for label in timesteps_so_far:
                timesteps_so_far[label] = iters_so_far * timesteps_per_batch

    d_adam = MpiAdam(reward_giver.get_trainable_variables())
    vfadam = {label: MpiAdam(vf_var_list[label]) for label in ob_space}

    U.initialize()
    d_adam.sync()

    for label in ob_space:
        th_init = get_flat[label]()
        MPI.COMM_WORLD.Bcast(th_init, root=0)
        set_from_flat[label](th_init)
        vfadam[label].sync()
        if rank == 0:
            print(label + "Init param sum", th_init.sum(), flush=True)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = {
        label: traj_segment_generator(
            pi[label],
            env[label],
            reward_giver,
            timesteps_per_batch,
            stochastic=True,
            semi_dataset=semi_dataset if label == get_semi_prefix() else None,
            semi_loss=semi_loss,
            reward_threshold=expert_reward_threshold
            if label == expert_label else None,
            sparse_reward=sparse_reward if label == expert_label else False)
        for label in ob_space
    }

    g_losses = {}

    assert sum([max_iters > 0, max_timesteps > 0, max_episodes > 0]) == 1

    g_loss_stats = {
        label: stats(loss_names[label] + vf_loss_names[label])
        for label in ob_space if label != expert_label
    }
    d_loss_stats = stats(reward_giver.loss_name)
    ep_names = ["True_rewards", "Rewards", "Episode_length", "Success"]

    ep_stats = {label: None for label in ob_space}
    # cat_dauggi network stats
    if get_il_prefix() in ep_stats:
        ep_stats[get_il_prefix()] = stats([name for name in ep_names])

    # semi network stats
    if get_semi_prefix() in ep_stats:
        if semi_loss and semi_dataset is not None:
            ep_names.append("L2_loss")
            ep_names.append("total_rewards")
        ep_stats[get_semi_prefix()] = stats(
            [get_semi_prefix() + name for name in ep_names])

    if rank == 0:
        start_time = time.time()
        ch_count = 0
        env_type = env[expert_label].env.env.__class__.__name__

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and any(
            [timesteps_so_far[label] >= max_timesteps for label in ob_space]):
            break
        elif max_episodes and any(
            [episodes_so_far[label] >= max_episodes for label in ob_space]):
            break
        elif max_iters and iters_so_far >= max_iters:
            break

        # Save model
        if rank == 0 and iters_so_far % save_per_iter == 0 and ckpt_dir is not None:
            fname = os.path.join(ckpt_dir, task_name)
            if env_type.find("Pendulum") != -1 or save_per_iter != 1:
                fname = os.path.join(ckpt_dir, 'iter_' + str(iters_so_far),
                                     'iter_' + str(iters_so_far))
            os.makedirs(os.path.dirname(fname), exist_ok=True)
            saver = tf.train.Saver()
            saver.save(tf.get_default_session(), fname, write_meta_graph=False)

        if rank == 0 and time.time(
        ) - start_time >= 3600 * ch_count:  # save a different checkpoint every hour
            fname = os.path.join(ckpt_dir, 'hour' + str(ch_count).zfill(3))
            fname = os.path.join(fname, 'iter_' + str(iters_so_far))
            os.makedirs(os.path.dirname(fname), exist_ok=True)
            saver = tf.train.Saver()
            saver.save(tf.get_default_session(), fname, write_meta_graph=False)
            ch_count += 1

        logger.log("********** Iteration %i ************" % iters_so_far)

        def fisher_func_builder(label):
            def fisher_vector_product(p):
                return allmean(compute_fvp[label](p, *
                                                  fvpargs)) + cg_damping * p

            return fisher_vector_product

        # ------------------ Update G ------------------
        d = {label: None for label in ob_space}
        segs = {label: None for label in ob_space}
        logger.log("Optimizing Policy...")
        for curr_step in range(g_step):
            for label in ob_space:

                if curr_step and label == expert_label:  # collect expert trajectories only on the first g_step; the same batch is reused for the d_step updates
                    continue

                logger.log("Optimizing Policy " + label + "...")
                with timed("sampling"):
                    segs[label] = seg = seg_gen[label].__next__()

                seg["rew"] = seg["rew"] - seg["l2_loss"] * l2_w

                add_vtarg_and_adv(seg, gamma, lam)
                ob, ac, atarg, tdlamret, full_ob = seg["ob"], seg["ac"], seg[
                    "adv"], seg["tdlamret"], seg["full_ob"]
                vpredbefore = seg[
                    "vpred"]  # predicted value function before udpate
                atarg = (atarg - atarg.mean()) / atarg.std(
                )  # standardized advantage function estimate
                d[label] = Dataset(dict(ob=ob,
                                        ac=ac,
                                        atarg=atarg,
                                        vtarg=tdlamret),
                                   shuffle=True)

                if not_update or label == expert_label:
                    continue  # stop G from updating

                if hasattr(pi[label], "ob_rms"):
                    pi[label].ob_rms.update(
                        full_ob)  # update running mean/std for policy

                args = seg["full_ob"], seg["ac"], atarg
                fvpargs = [arr[::5] for arr in args]

                assign_old_eq_new[label](
                )  # set old parameter values to new parameter values
                with timed("computegrad"):
                    *lossbefore, g = compute_lossandgrad[label](*args)
                lossbefore = allmean(np.array(lossbefore))
                g = allmean(g)
                if np.allclose(g, 0):
                    logger.log("Got zero gradient. not updating")
                else:
                    with timed("cg"):
                        stepdir = cg(fisher_func_builder(label),
                                     g,
                                     cg_iters=cg_iters,
                                     verbose=rank == 0)
                    assert np.isfinite(stepdir).all()
                    shs = .5 * stepdir.dot(fisher_func_builder(label)(stepdir))
                    lm = np.sqrt(shs / max_kl)
                    # logger.log("lagrange multiplier:", lm, "gnorm:", np.linalg.norm(g))
                    fullstep = stepdir / lm
                    expectedimprove = g.dot(fullstep)
                    surrbefore = lossbefore[0]
                    stepsize = 1.0
                    thbefore = get_flat[label]()
                    for _ in range(10):
                        thnew = thbefore + fullstep * stepsize
                        set_from_flat[label](thnew)
                        meanlosses = surr, kl, *_ = allmean(
                            np.array(compute_losses[label](*args)))
                        if rank == 0:
                            print("Generator entropy " + str(meanlosses[4]) +
                                  ", loss " + str(meanlosses[2]))
                        improve = surr - surrbefore
                        logger.log("Expected: %.3f Actual: %.3f" %
                                   (expectedimprove, improve))
                        if not np.isfinite(meanlosses).all():
                            logger.log(
                                "Got non-finite value of losses -- bad!")
                        elif kl > max_kl * 1.5:
                            logger.log(
                                "violated KL constraint. shrinking step.")
                        elif improve < 0:
                            logger.log(
                                "surrogate didn't improve. shrinking step.")
                        else:
                            logger.log("Stepsize OK!")
                            break
                        stepsize *= .5
                    else:
                        logger.log("couldn't compute a good step")
                        set_from_flat[label](thbefore)
                    if nworkers > 1 and iters_so_far % 20 == 0:
                        paramsums = MPI.COMM_WORLD.allgather(
                            (thnew.sum(),
                             vfadam[label].getflat().sum()))  # list of tuples
                        assert all(
                            np.allclose(ps, paramsums[0])
                            for ps in paramsums[1:])

            expert_dataset = d[expert_label]

            if not_update:
                break

            for label in ob_space:
                if label == expert_label:
                    continue

                with timed("vf"):
                    logger.log(fmt_row(13, vf_loss_names[label]))
                    for _ in range(vf_iters):
                        vf_b_losses = []
                        for batch in d[label].iterate_once(vf_batchsize):
                            mbob = batch["ob"]
                            mbret = batch["vtarg"]
                            *newlosses, g = compute_vflossandgrad[label](mbob,
                                                                         mbret)
                            g = allmean(g)
                            newlosses = allmean(np.array(newlosses))

                            vfadam[label].update(g, vf_stepsize[label])
                            vf_b_losses.append(newlosses)
                        logger.log(fmt_row(13, np.mean(vf_b_losses, axis=0)))

                    logger.log("Evaluating losses...")
                    losses = []
                    for batch in d[label].iterate_once(vf_batchsize):
                        newlosses = compute_vf_losses[label](batch["ob"],
                                                             batch["ac"],
                                                             batch["atarg"],
                                                             batch["vtarg"])
                        losses.append(newlosses)
                    g_losses[label], _, _ = mpi_moments(losses, axis=0)

                #########################
                for ob_batch, ac_batch, full_ob_batch in dataset.iterbatches(
                    (segs[label]["ob"], segs[label]["ac"],
                     segs[label]["full_ob"]),
                        include_final_partial_batch=False,
                        batch_size=len(ob)):
                    expert_batch = expert_dataset.next_batch(len(ob))

                    ob_expert, ac_expert = expert_batch["ob"], expert_batch[
                        "ac"]

                    exp_rew = 0
                    exp_rews = None
                    for obs, acs in zip(ob_expert, ac_expert):
                        curr_rew = reward_giver.get_reward(obs, acs)[0][0] \
                                   if not hasattr(reward_giver, '_labels') else \
                                   reward_giver.get_reward(obs, acs, label)
                        if isinstance(curr_rew, tuple):
                            curr_rew, curr_rews = curr_rew
                            exp_rews = 1 - np.exp(
                                -curr_rews
                            ) if exp_rews is None else exp_rews + 1 - np.exp(
                                -curr_rews)
                        exp_rew += 1 - np.exp(-curr_rew)
                    mean_exp_rew = exp_rew / len(ob_expert)
                    mean_exp_rews = exp_rews / len(
                        ob_expert) if exp_rews is not None else None
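                    # Assuming reward_giver.get_reward returns -log(1 - D(s, a)) as in the
                    # standard GAIL adversary, 1 - exp(-r) recovers D(s, a); these means
                    # therefore compare the discriminator's average score on expert vs.
                    # generated transitions.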

                    gen_rew = 0
                    gen_rews = None
                    for obs, acs, full_obs in zip(ob_batch, ac_batch,
                                                  full_ob_batch):
                        curr_rew = reward_giver.get_reward(obs, acs)[0][0] \
                                   if not hasattr(reward_giver, '_labels') else \
                                   reward_giver.get_reward(obs, acs, label)
                        if isinstance(curr_rew, tuple):
                            curr_rew, curr_rews = curr_rew
                            gen_rews = 1 - np.exp(
                                -curr_rews
                            ) if gen_rews is None else gen_rews + 1 - np.exp(
                                -curr_rews)
                        gen_rew += 1 - np.exp(-curr_rew)
                    mean_gen_rew = gen_rew / len(ob_batch)
                    mean_gen_rews = gen_rews / len(
                        ob_batch) if gen_rews is not None else None
                    if rank == 0:
                        msg = "Network " + label + \
                            " Generator step " + str(curr_step) + ": Dicriminator reward of expert traj " \
                            + str(mean_exp_rew) + " vs gen traj " + str(mean_gen_rew)
                        if mean_exp_rews is not None and mean_gen_rews is not None:
                            msg += "\nDiscriminator multi rewards of expert " + str(mean_exp_rews) + " vs gen " \
                                    + str(mean_gen_rews)
                        logger.log(msg)
                #########################

        if not not_update:
            for label in g_losses:
                for (lossname,
                     lossval) in zip(loss_names[label] + vf_loss_names[label],
                                     g_losses[label]):
                    logger.record_tabular(lossname, lossval)
                logger.record_tabular(
                    label + "ev_tdlam_before",
                    explained_variance(segs[label]["vpred"],
                                       segs[label]["tdlamret"]))

        # ------------------ Update D ------------------
        if not freeze_d:
            logger.log("Optimizing Discriminator...")
            batch_size = len(list(segs.values())[0]['ob']) // d_step
            expert_dataset = d[expert_label]
            batch_gen = {
                label: dataset.iterbatches(
                    (segs[label]["ob"], segs[label]["ac"]),
                    include_final_partial_batch=False,
                    batch_size=batch_size)
                for label in segs if label != expert_label
            }

            d_losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for step in range(d_step):
                g_ob = {}
                g_ac = {}
                for label in batch_gen:  # get batches for different gens
                    g_ob[label], g_ac[label] = batch_gen[label].__next__()

                expert_batch = expert_dataset.next_batch(batch_size)

                ob_expert, ac_expert = expert_batch["ob"], expert_batch["ac"]

                for label in g_ob:
                    #########################
                    exp_rew = 0
                    exp_rews = None
                    for obs, acs in zip(ob_expert, ac_expert):
                        curr_rew = reward_giver.get_reward(obs, acs)[0][0] \
                            if not hasattr(reward_giver, '_labels') else \
                            reward_giver.get_reward(obs, acs, label)
                        if isinstance(curr_rew, tuple):
                            curr_rew, curr_rews = curr_rew
                            exp_rews = 1 - np.exp(
                                -curr_rews
                            ) if exp_rews is None else exp_rews + 1 - np.exp(
                                -curr_rews)
                        exp_rew += 1 - np.exp(-curr_rew)
                    mean_exp_rew = exp_rew / len(ob_expert)
                    mean_exp_rews = exp_rews / len(
                        ob_expert) if exp_rews is not None else None

                    gen_rew = 0
                    gen_rews = None
                    for obs, acs in zip(g_ob[label], g_ac[label]):
                        curr_rew = reward_giver.get_reward(obs, acs)[0][0] \
                            if not hasattr(reward_giver, '_labels') else \
                            reward_giver.get_reward(obs, acs, label)
                        if isinstance(curr_rew, tuple):
                            curr_rew, curr_rews = curr_rew
                            gen_rews = 1 - np.exp(
                                -curr_rews
                            ) if gen_rews is None else gen_rews + 1 - np.exp(
                                -curr_rews)
                        gen_rew += 1 - np.exp(-curr_rew)
                    mean_gen_rew = gen_rew / len(g_ob[label])
                    mean_gen_rews = gen_rews / len(
                        g_ob[label]) if gen_rews is not None else None
                    if rank == 0:
                        msg = "Dicriminator reward of expert traj " + str(mean_exp_rew) + " vs " + label + \
                            "gen traj " + str(mean_gen_rew)
                        if mean_exp_rews is not None and mean_gen_rews is not None:
                            msg += "\nDiscriminator multi expert rewards " + str(mean_exp_rews) + " vs " + label + \
                                   "gen " + str(mean_gen_rews)
                        logger.log(msg)
                        #########################

                # update running mean/std for reward_giver
                if hasattr(reward_giver, "obs_rms"):
                    reward_giver.obs_rms.update(
                        np.concatenate(list(g_ob.values()) + [ob_expert], 0))
                *newlosses, g = reward_giver.lossandgrad(
                    *(list(g_ob.values()) + list(g_ac.values()) + [ob_expert] +
                      [ac_expert]))
                d_adam.update(allmean(g), d_stepsize)
                d_losses.append(newlosses)
                logger.log(fmt_row(13, reward_giver.loss_name))
                logger.log(fmt_row(13, np.mean(d_losses, axis=0)))

        for label in ob_space:
            lrlocal = (segs[label]["ep_lens"], segs[label]["ep_rets"],
                       segs[label]["ep_true_rets"], segs[label]["ep_success"],
                       segs[label]["ep_semi_loss"])  # local values

            listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
            lens, rews, true_rets, success, semi_losses = map(
                flatten_lists, zip(*listoflrpairs))

            # success
            success = [
                float(elem) for elem in success
                if isinstance(elem, (int, float, bool))
            ]  # remove potential None types
            if not success:
                success = [-1]  # set success to -1 if env has no success flag
            success_buffer[label].extend(success)

            true_rewbuffer[label].extend(true_rets)
            lenbuffer[label].extend(lens)
            rewbuffer[label].extend(rews)

            if semi_loss and semi_dataset is not None and label == get_semi_prefix(
            ):
                semi_losses = [elem * l2_w for elem in semi_losses]
                total_rewards = rews
                total_rewards = [
                    re_elem - l2_elem
                    for re_elem, l2_elem in zip(total_rewards, semi_losses)
                ]
                l2_rewbuffer.extend(semi_losses)
                total_rewbuffer.extend(total_rewards)

            logger.record_tabular(label + "EpLenMean",
                                  np.mean(lenbuffer[label]))
            logger.record_tabular(label + "EpRewMean",
                                  np.mean(rewbuffer[label]))
            logger.record_tabular(label + "EpTrueRewMean",
                                  np.mean(true_rewbuffer[label]))
            logger.record_tabular(label + "EpSuccess",
                                  np.mean(success_buffer[label]))

            if semi_loss and semi_dataset is not None and label == get_semi_prefix(
            ):
                logger.record_tabular(label + "EpSemiLoss",
                                      np.mean(l2_rewbuffer))
                logger.record_tabular(label + "EpTotalLoss",
                                      np.mean(total_rewbuffer))
            logger.record_tabular(label + "EpThisIter", len(lens))
            episodes_so_far[label] += len(lens)
            timesteps_so_far[label] += sum(lens)

            logger.record_tabular(label + "EpisodesSoFar",
                                  episodes_so_far[label])
            logger.record_tabular(label + "TimestepsSoFar",
                                  timesteps_so_far[label])
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        iters_so_far += 1
        logger.record_tabular("ItersSoFar", iters_so_far)

        if rank == 0:
            logger.dump_tabular()
            if not not_update:
                for label in g_loss_stats:
                    g_loss_stats[label].add_all_summary(
                        writer, g_losses[label], iters_so_far)
            if not freeze_d:
                d_loss_stats.add_all_summary(writer, np.mean(d_losses, axis=0),
                                             iters_so_far)

            for label in ob_space:
                # default buffers
                ep_buffers = [
                    np.mean(true_rewbuffer[label]),
                    np.mean(rewbuffer[label]),
                    np.mean(lenbuffer[label]),
                    np.mean(success_buffer[label])
                ]

                if semi_loss and semi_dataset is not None and label == get_semi_prefix(
                ):
                    ep_buffers.append(np.mean(l2_rewbuffer))
                    ep_buffers.append(np.mean(total_rewbuffer))

                ep_stats[label].add_all_summary(writer, ep_buffers,
                                                iters_so_far)

        if not_update and not freeze_g:
            not_update -= 1
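
Every training loop in this section also logs through flatten_lists and explained_variance, which are likewise omitted from the excerpts. Minimal sketches consistent with how they are called here (flattening the MPI-allgathered per-worker episode lists, and scoring the value-function fit) might look like:

import numpy as np

def flatten_lists(listoflists):
    # concatenate the per-worker episode statistics gathered via MPI allgather
    return [el for sublist in listoflists for el in sublist]

def explained_variance(ypred, y):
    # 1 - Var[y - ypred] / Var[y]: 1 is a perfect fit, 0 is no better than predicting
    # the mean of y, and negative values are worse than the mean predictor
    vary = np.var(y)
    return np.nan if vary == 0 else 1 - np.var(y - ypred) / vary
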
Exemplo n.º 8
0
def learn(
        env,
        policy_func,
        *,
        timesteps_per_actorbatch,  # timesteps per actor per update
        clip_param,
        entcoeff,  # clipping parameter epsilon, entropy coeff
        optim_epochs,
        optim_stepsize,
        optim_batchsize,  # optimization hypers
        gamma,
        lam,  # advantage estimation
        max_timesteps=0,
        max_episodes=0,
        max_iters=0,
        max_seconds=0,  # time constraint
        callback=None,  # you can do anything in the callback, since it takes locals(), globals()
        adam_epsilon=1e-5,
        schedule='constant',  # annealing for stepsize parameters (epsilon and adam)
        seed=1
):

    # We want to log:
    num_options = 1  # Hacky solution -> lets us reuse the same logging!
    epoch = -1
    saves = True
    wsaves = True

    ### Book-keeping
    gamename = env.spec.id[:-3].lower()
    gamename += 'seed' + str(seed)
    #gamename += app
    version_name = 'officialPPO'

    dirname = '{}_{}_{}opts_saves/'.format(version_name, gamename, num_options)

    # retrieve everything using relative paths. Create a train_results folder where the repo has been cloned
    dirname_rel = os.path.dirname(__file__)
    splitted = dirname_rel.split("/")
    dirname_rel = ("/".join(dirname_rel.split("/")[:len(splitted) - 3]) + "/")
    dirname = dirname_rel + "train_results/" + dirname

    # Specify the paths where results shall be written to:
    src_code_path = dirname_rel + 'baselines/baselines/ppo1/'
    results_path = dirname_rel
    envs_path = dirname_rel + 'nfunk/envs_nf/'

    print(dirname)
    #input ("wait here after dirname")

    if wsaves:
        first = True
        if not os.path.exists(dirname):
            os.makedirs(dirname)
            first = False
        # while os.path.exists(dirname) and first:
        #     dirname += '0'

        files = ['pposgd_simple.py', 'mlp_policy.py', 'run_mujoco.py']
        first = True
        for i in range(len(files)):
            src = os.path.join(src_code_path) + files[i]
            print(src)
            #dest = os.path.join('/home/nfunk/results_NEW/ppo1/') + dirname
            dest = dirname + "src_code/"
            if (first):
                os.makedirs(dest)
                first = False
            print(dest)
            shutil.copy2(src, dest)
        # brute-force copy of the normal env file at the end of the copying process:
        env_files = ['pendulum_nf.py']
        for i in range(len(env_files)):
            src = os.path.join(envs_path + env_files[i])
            shutil.copy2(src, dest)
        os.makedirs(dest + "assets/")
        src = os.path.join(envs_path + "assets/clockwise.png")
        shutil.copy2(src, dest + "assets/")

    np.random.seed(seed)
    tf.set_random_seed(seed)
    env.seed(seed)

    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_func("pi", ob_space,
                     ac_space)  # Construct network for new policy
    oldpi = policy_func("oldpi", ob_space, ac_space)  # Network for old policy
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = U.mean(kloldnew)
    meanent = U.mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = U.clip(ratio, 1.0 - clip_param, 1.0 + clip_param) * atarg  #
    pol_surr = -U.mean(tf.minimum(
        surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)
    vf_loss = U.mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult],
                             losses + [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)

    U.initialize()
    adam.sync()

    saver = tf.train.Saver(max_to_keep=10000)
    saver_best = tf.train.Saver(max_to_keep=1)

    ### More book-keeping
    results = []
    if saves:
        results = open(
            dirname + version_name + '_' + gamename + '_' + str(num_options) +
            'opts_' + '_results.csv', 'w')
        results_best_model = open(
            dirname + version_name + '_' + gamename + '_' + str(num_options) +
            'opts_' + '_bestmodel.csv', 'w')

        out = 'epoch,avg_reward'

        for opt in range(num_options):
            out += ',option {} dur'.format(opt)
        for opt in range(num_options):
            out += ',option {} std'.format(opt)
        for opt in range(num_options):
            out += ',option {} term'.format(opt)
        for opt in range(num_options):
            out += ',option {} adv'.format(opt)
        out += '\n'
        results.write(out)
        # results.write('epoch,avg_reward,option 1 dur, option 2 dur, option 1 term, option 2 term\n')
        results.flush()

    if epoch >= 0:

        print("Loading weights from iteration: " + str(epoch))

        filename = dirname + '{}_epoch_{}.ckpt'.format(gamename, epoch)
        saver.restore(U.get_session(), filename)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_actorbatch,
                                     stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    max_val = -100000
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards

    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % iters_so_far)

        seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update
        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate
        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                    shuffle=not pi.recurrent)
        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        assign_old_eq_new()  # set old parameter values to new parameter values

        if iters_so_far % 50 == 0 and wsaves:
            print("weights are saved...")
            filename = dirname + '{}_epoch_{}.ckpt'.format(
                gamename, iters_so_far)
            save_path = saver.save(U.get_session(), filename)

        if rewbuffer and (np.mean(rewbuffer) > max_val) and wsaves:
            max_val = np.mean(rewbuffer)
            results_best_model.write('epoch: ' + str(iters_so_far) + ', rew: ' +
                                     str(np.mean(rewbuffer)) + '\n')
            results_best_model.flush()
            filename = dirname + 'best.ckpt'
            save_path = saver_best.save(U.get_session(), filename)

        logger.log("Optimizing...")
        logger.log(fmt_row(13, loss_names))
        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"],
                                            batch["atarg"], batch["vtarg"],
                                            cur_lrmult)
                adam.update(g, optim_stepsize * cur_lrmult)
                losses.append(newlosses)
            logger.log(fmt_row(13, np.mean(losses, axis=0)))

        logger.log("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob"], batch["ac"],
                                       batch["atarg"], batch["vtarg"],
                                       cur_lrmult)
            losses.append(newlosses)
        meanlosses, _, _ = mpi_moments(losses, axis=0)
        logger.log(fmt_row(13, meanlosses))
        for (lossval, name) in zipsame(meanlosses, loss_names):
            logger.record_tabular("loss_" + name, lossval)
        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))
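        # explained_variance(ypred, y) = 1 - Var(y - ypred) / Var(y); values close
        # to 1 mean the value function tracks the empirical returns well.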
        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.dump_tabular()

        ### Book keeping
        if saves:
            out = "{},{}"
            #for _ in range(num_options): #out+=",{},{},{},{}"
            out += "\n"

            info = [iters_so_far, np.mean(rewbuffer)]

            results.write(out.format(*info))
            results.flush()
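# add_vtarg_and_adv is called above but not shown in this snippet. A minimal
# sketch of the GAE(lambda) computation it is assumed to perform, following the
# standard baselines pposgd_simple layout of the rollout dict (seg):
import numpy as np


def add_vtarg_and_adv_sketch(seg, gamma, lam):
    # "new" flags episode starts; append 0 and the bootstrap value for the final state
    new = np.append(seg["new"], 0)
    vpred = np.append(seg["vpred"], seg["nextvpred"])
    T = len(seg["rew"])
    seg["adv"] = gaelam = np.empty(T, "float32")
    rew = seg["rew"]
    lastgaelam = 0
    for t in reversed(range(T)):
        nonterminal = 1 - new[t + 1]
        delta = rew[t] + gamma * vpred[t + 1] * nonterminal - vpred[t]
        gaelam[t] = lastgaelam = delta + gamma * lam * nonterminal * lastgaelam
    seg["tdlamret"] = seg["adv"] + seg["vpred"]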
Example #9
def core_train_def(
        env,
        pi,
        oldpi,
        env_att,
        pi_att,
        loss_names,
        lossandgrad,
        adam,
        assign_old_eq_new,
        compute_losses,
        timesteps_per_batch,
        optim_epochs,
        optim_stepsize,
        optim_batchsize,
        gamma,
        lam,
        max_timesteps=0,
        max_episodes=0,
        max_iters=0,
        max_seconds=0,  # time constraint
        callback=None,
        schedule='constant',
        test_envs=[]):
    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator_def(pi_att,
                                         pi,
                                         env_att,
                                         env,
                                         timesteps_per_batch,
                                         stochastic=True)

    if test_envs:
        test_gens = [
            pposgd_simple.traj_segment_generator(pi,
                                                 attenv,
                                                 timesteps_per_batch,
                                                 stochastic=True)
            for attenv in test_envs
        ]
    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=50)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=50)  # rolling buffer for episode rewards
    testbuffers = [deque(maxlen=50) for test_env in test_envs]

    # Maithra edits: add lists to return logs
    ep_lengths = []
    ep_rewards = []
    ep_labels = []
    ep_actions = []
    ep_correct_actions = []
    ep_obs = []
    ep_att_obs = []
    ep_att_actions = []
    ep_test_rewards = [[] for test_env in test_envs]

    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Defender Iteration %i ************" %
                   iters_so_far)

        seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        if test_envs:
            test_segs = [test_gen.__next__() for test_gen in test_gens]

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret, label, att_ob, att_ac  = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"], seg["label"], \
                                          seg["att_ob"], seg["att_ac"]

        if test_envs:
            for i, test_seg in enumerate(test_segs):
                test_rews = test_seg["ep_rets"]
                testbuffers[i].extend(test_rews)
                ep_test_rewards[i].append(np.mean(testbuffers[i]))

        vpredbefore = seg["vpred"]  # predicted value function before update
        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate
        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                    shuffle=not pi.recurrent)
        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        assign_old_eq_new()  # set old parameter values to new parameter values
        logger.log("Optimizing...")
        logger.log(fmt_row(13, loss_names))
        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"],
                                            batch["atarg"], batch["vtarg"],
                                            cur_lrmult)
                adam.update(g, optim_stepsize * cur_lrmult)
                losses.append(newlosses)

            logger.log(fmt_row(13, np.mean(losses, axis=0)))

        logger.log("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob"], batch["ac"],
                                       batch["atarg"], batch["vtarg"],
                                       cur_lrmult)
            losses.append(newlosses)
        meanlosses, _, _ = mpi_moments(losses, axis=0)
        logger.log(fmt_row(13, meanlosses))
        for (lossval, name) in zipsame(meanlosses, loss_names):
            logger.record_tabular("loss_" + name, lossval)
        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))
        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))

        # Maithra edit: append intermediate results onto returned logs
        ep_lengths.append(np.mean(lenbuffer))
        ep_rewards.append(np.mean(rewbuffer))
        ep_labels.append(deepcopy(label))
        ep_actions.append(deepcopy(ac))
        ep_obs.append(deepcopy(ob))
        ep_att_obs.append(deepcopy(att_ob))
        ep_att_actions.append(deepcopy(att_ac))
        # compute mean of correct actions and append, ignoring actions
        # where either choice could be right
        count = 0
        idxs = np.all((label == [1, 1]), axis=1)
        # removing for now: count += np.sum(idxs)
        new_label = label[np.invert(idxs)]
        new_ac = ac[np.invert(idxs)]
        count += np.sum((new_ac == np.argmax(new_label, axis=1)))
        # changing ep_correct_actions.append(count/len(label))
        ep_correct_actions.append(count / (len(label) - np.sum(idxs)))

        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.dump_tabular()

    info_dict = {
        "lengths": ep_lengths,
        "rewards": ep_rewards,
        "labels": ep_labels,
        "actions": ep_actions,
        "correct_actions": ep_correct_actions,
        "obs": ep_obs,
        "att_actions": ep_att_actions,
        "att_obs": ep_att_obs,
        "test_rews": ep_test_rewards
    }
    #Maithra edit
    return pi, oldpi, lossandgrad, adam, assign_old_eq_new, compute_losses, info_dict
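# flatten_lists is used together with MPI.COMM_WORLD.allgather above to merge the
# per-worker episode statistics into flat lists. A minimal sketch of the helper as
# it is assumed to behave (matching baselines.common.misc_util.flatten_lists):
def flatten_lists_sketch(listoflists):
    return [el for list_ in listoflists for el in list_]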
Example #10
def learn(
        env,
        policy_func,
        *,
        timesteps_per_actorbatch,  # timesteps per actor per update
        clip_param,
        entcoeff,  # clipping parameter epsilon, entropy coeff
        vfcoeff,
        optim_epochs,
        optim_stepsize,
        optim_batchsize,  # optimization hypers
        gamma,
        lam,  # advantage estimation
        max_timesteps=0,
        max_episodes=0,
        max_iters=0,
        max_seconds=0,  # time constraint
        callback=None,  # you can do anything in the callback, since it takes locals(), globals()
        adam_epsilon=1e-5,
        schedule='constant',  # annealing for stepsize parameters (epsilon and adam)
        sensor=False,
        save_name=None,
        save_per_acts=3,
        reload_name=None):
    # Setup losses and stuff
    # ----------------------------------------
    if sensor:
        ob_space = env.sensor_space
    else:
        ob_space = env.observation_space
    ac_space = env.action_space

    pi = policy_func("pi", ob_space,
                     ac_space)  # Construct network for new policy
    oldpi = policy_func("oldpi", ob_space, ac_space)  # Network for old policy
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return
    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule
    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = tf.clip_by_value(ratio, 1.0 - clip_param,
                             1.0 + clip_param) * atarg  #
    pol_surr = -tf.reduce_mean(tf.minimum(
        surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)
    vf_loss = vfcoeff * tf.reduce_mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult],
                             losses + [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)

    U.initialize()
    adam.sync()

    if reload_name:
        saver = tf.train.Saver()
        saver.restore(tf.get_default_session(), reload_name)
        print("Loaded model successfully.")

    # from IPython import embed; embed()
    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_actorbatch,
                                     stochastic=True,
                                     sensor=sensor)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards

    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % (iters_so_far + 1))

        seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update
        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate
        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                    shuffle=not pi.recurrent)
        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        assign_old_eq_new()  # set old parameter values to new parameter values
        logger.log("Optimizing...")
        logger.log(fmt_row(13, loss_names))
        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"],
                                            batch["atarg"], batch["vtarg"],
                                            cur_lrmult)
                adam.update(g, optim_stepsize * cur_lrmult)
                losses.append(newlosses)
            logger.log(fmt_row(13, np.mean(losses, axis=0)))

        logger.log("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob"], batch["ac"],
                                       batch["atarg"], batch["vtarg"],
                                       cur_lrmult)
            losses.append(newlosses)
        meanlosses, _, _ = mpi_moments(losses, axis=0)
        logger.log(fmt_row(13, meanlosses))
        for (lossval, name) in zipsame(meanlosses, loss_names):
            logger.record_tabular("loss_" + name, lossval)
        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))
        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        elapse = time.time() - tstart
        logger.record_tabular("TimeElapsed", elapse)

        #Iteration Recording
        record = 1
        if record:
            file_path = os.path.join(
                os.path.expanduser("~"),
                "PycharmProjects/Gibson_Exercise/gibson/utils/models/iterations"
            )
            try:
                os.mkdir(file_path)
            except OSError:
                pass

            # Write the CSV header on the first iteration, then append on later ones.
            csv_path = os.path.join(
                os.path.expanduser("~"),
                'PycharmProjects/Gibson_Exercise/gibson/utils/models/iterations/values.csv'
            )
            mode = 'w' if iters_so_far == 1 else 'a'
            with open(csv_path, mode, newline='') as csvfile:
                fieldnames = [
                    'Iteration', 'TimeSteps', 'Reward', 'LossEnt', 'LossVF',
                    'PolSur'
                ]
                writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
                if mode == 'w':
                    writer.writeheader()
                writer.writerow({
                    'Iteration': iters_so_far,
                    'TimeSteps': timesteps_so_far,
                    'Reward': np.mean(rews),
                    'LossEnt': meanlosses[4],
                    'LossVF': meanlosses[2],
                    'PolSur': meanlosses[1]
                })

        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.dump_tabular()

        load_number = 0
        if reload_name is not None:
            load_number = int(str(reload_name.split('_')[7]).split('.')[0])

        if save_name and (iters_so_far % save_per_acts == 0):
            base_path = os.path.dirname(os.path.abspath(__file__))
            print(base_path)
            out_name = os.path.join(
                base_path, 'models',
                save_name + '_' + str(iters_so_far + load_number) + ".model")
            U.save_state(out_name)
            print("Saved model successfully.")
def learn(env,
          policy_func,
          reward_giver,
          expert_dataset,
          rank,
          pretrained,
          pretrained_weight,
          *,
          g_step,
          d_step,
          entcoeff,
          save_per_iter,
          ckpt_dir,
          log_dir,
          timesteps_per_batch,
          task_name,
          gamma,
          lam,
          max_kl,
          cg_iters,
          cg_damping=1e-2,
          vf_stepsize=3e-4,
          d_stepsize=3e-4,
          vf_iters=3,
          max_timesteps=0,
          max_episodes=0,
          max_iters=0,
          vf_batchsize=128,
          callback=None,
          freeze_g=False,
          freeze_d=False,
          semi_dataset=None,
          semi_loss=False):

    semi_loss = semi_loss and semi_dataset is not None
    l2_w = 0.1

    nworkers = MPI.COMM_WORLD.Get_size()
    rank = MPI.COMM_WORLD.Get_rank()
    np.set_printoptions(precision=3)

    if rank == 0:
        writer = U.file_writer(log_dir)

        # print all the hyperparameters in the log...
        log_dict = {
            "expert trajectories": expert_dataset.num_traj,
            "algo": "trpo",
            "threads": nworkers,
            "timesteps_per_batch": timesteps_per_batch,
            "timesteps_per_thread": -(-timesteps_per_batch // nworkers),
            "entcoeff": entcoeff,
            "vf_iters": vf_iters,
            "vf_batchsize": vf_batchsize,
            "vf_stepsize": vf_stepsize,
            "d_stepsize": d_stepsize,
            "g_step": g_step,
            "d_step": d_step,
            "max_kl": max_kl,
            "gamma": gamma,
            "lam": lam,
            "l2_weight": l2_w
        }

        if semi_dataset is not None:
            log_dict["semi trajectories"] = semi_dataset.num_traj
        if hasattr(semi_dataset, 'info'):
            log_dict["semi_dataset_info"] = semi_dataset.info

        # print them all together for csv
        logger.log(",".join([str(elem) for elem in log_dict]))
        logger.log(",".join([str(elem) for elem in log_dict.values()]))

        # also print them separately for easy reading:
        for elem in log_dict:
            logger.log(str(elem) + ": " + str(log_dict[elem]))

    # divide the timesteps to the threads
    timesteps_per_batch = -(-timesteps_per_batch // nworkers
                            )  # get ceil of division

    # Setup losses and stuff
    # ----------------------------------------
    if semi_dataset:
        ob_space = semi_ob_space(env, semi_size=semi_dataset.semi_size)
    else:
        ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_func("pi",
                     ob_space=ob_space,
                     ac_space=ac_space,
                     reuse=(pretrained_weight is not None))
    oldpi = policy_func("oldpi", ob_space=ob_space, ac_space=ac_space)
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    entbonus = entcoeff * meanent

    vferr = tf.reduce_mean(tf.square(pi.vpred - ret))

    ratio = tf.exp(pi.pd.logp(ac) -
                   oldpi.pd.logp(ac))  # advantage * pnew / pold
    surrgain = tf.reduce_mean(ratio * atarg)

    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

    vf_losses = [vferr]
    vf_loss_names = ["vf_loss"]

    dist = meankl

    all_var_list = pi.get_trainable_variables()
    var_list = [
        v for v in all_var_list
        if v.name.startswith("pi/pol") or v.name.startswith("pi/logstd")
    ]
    vf_var_list = [v for v in all_var_list if v.name.startswith("pi/vf")]
    assert len(var_list) == len(vf_var_list) + 1

    get_flat = U.GetFlat(var_list)
    set_from_flat = U.SetFromFlat(var_list)
    klgrads = tf.gradients(dist, var_list)
    flat_tangent = tf.placeholder(dtype=tf.float32,
                                  shape=[None],
                                  name="flat_tan")
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start + sz], shape))
        start += sz
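    # gvp is the dot product of the KL gradient with an arbitrary tangent vector;
    # differentiating it again (fvp below) yields a Fisher-vector product without
    # ever materializing the full Fisher matrix.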
    gvp = tf.add_n([
        tf.reduce_sum(g * tangent)
        for (g, tangent) in zipsame(klgrads, tangents)
    ])  # pylint: disable=E1111
    fvp = U.flatgrad(gvp, var_list)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg], losses)
    compute_vf_losses = U.function([ob, ac, atarg, ret], losses + vf_losses)
    compute_lossandgrad = U.function([ob, ac, atarg], losses +
                                     [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    compute_vflossandgrad = U.function([ob, ret], vf_losses +
                                       [U.flatgrad(vferr, vf_var_list)])

    @contextmanager
    def timed(msg):
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(
                colorize("done in %.3f seconds" % (time.time() - tstart),
                         color='magenta'))
        else:
            yield

    def allmean(x):
        assert isinstance(x, np.ndarray)
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
        return out

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards
    true_rewbuffer = deque(maxlen=40)
    success_buffer = deque(maxlen=40)
    l2_rewbuffer = deque(
        maxlen=40) if semi_loss and semi_dataset is not None else None
    total_rewbuffer = deque(
        maxlen=40) if semi_loss and semi_dataset is not None else None

    assert sum([max_iters > 0, max_timesteps > 0, max_episodes > 0]) == 1

    not_update = 1 if not freeze_d else 0  # do not update G before D the first time
    # if provide pretrained weight
    loaded = False
    if not U.load_checkpoint_variables(pretrained_weight):
        if U.load_checkpoint_variables(pretrained_weight,
                                       check_prefix=get_il_prefix()):
            if rank == 0:
                logger.log("loaded checkpoint variables from " +
                           pretrained_weight)
            loaded = True
    else:
        loaded = True

    if loaded:
        not_update = 0 if any(
            [x.op.name.find("adversary") != -1
             for x in U.ALREADY_INITIALIZED]) else 1
        if pretrained_weight and pretrained_weight.rfind("iter_") and \
                pretrained_weight[pretrained_weight.rfind("iter_") + len("iter_"):].isdigit():
            curr_iter = int(
                pretrained_weight[pretrained_weight.rfind("iter_") +
                                  len("iter_"):]) + 1
            print("loaded checkpoint at iteration: " + str(curr_iter))
            iters_so_far = curr_iter
            timesteps_so_far = iters_so_far * timesteps_per_batch

    d_adam = MpiAdam(reward_giver.get_trainable_variables())
    vfadam = MpiAdam(vf_var_list)

    U.initialize()

    th_init = get_flat()
    MPI.COMM_WORLD.Bcast(th_init, root=0)
    set_from_flat(th_init)
    d_adam.sync()
    vfadam.sync()
    if rank == 0:
        print("Init param sum", th_init.sum(), flush=True)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(
        pi,
        env,
        reward_giver,
        timesteps_per_batch,
        stochastic=True,
        semi_dataset=semi_dataset,
        semi_loss=semi_loss)  # ADD L2 loss to semi trajectories

    g_loss_stats = stats(loss_names + vf_loss_names)
    d_loss_stats = stats(reward_giver.loss_name)
    ep_names = ["True_rewards", "Rewards", "Episode_length", "Success"]
    if semi_loss and semi_dataset is not None:
        ep_names.append("L2_loss")
        ep_names.append("total_rewards")
    ep_stats = stats(ep_names)

    if rank == 0:
        start_time = time.time()
        ch_count = 0
        env_type = env.env.env.__class__.__name__

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break

        # Save model
        if rank == 0 and iters_so_far % save_per_iter == 0 and ckpt_dir is not None:
            fname = os.path.join(ckpt_dir, task_name)
            if env_type.find(
                    "Pendulum"
            ) != -1 or save_per_iter != 1:  # allow pendulum to save all iterations
                fname = os.path.join(ckpt_dir, 'iter_' + str(iters_so_far),
                                     'iter_' + str(iters_so_far))
            os.makedirs(os.path.dirname(fname), exist_ok=True)
            saver = tf.train.Saver()
            saver.save(tf.get_default_session(), fname, write_meta_graph=False)

        if rank == 0 and time.time(
        ) - start_time >= 3600 * ch_count:  # save a different checkpoint every hour
            fname = os.path.join(ckpt_dir, 'hour' + str(ch_count).zfill(3))
            fname = os.path.join(fname, 'iter_' + str(iters_so_far))
            os.makedirs(os.path.dirname(fname), exist_ok=True)
            saver = tf.train.Saver()
            saver.save(tf.get_default_session(), fname, write_meta_graph=False)
            ch_count += 1

        logger.log("********** Iteration %i ************" % iters_so_far)

        def fisher_vector_product(p):
            return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p

        # ------------------ Update G ------------------
        logger.log("Optimizing Policy...")
        for curr_step in range(g_step):
            with timed("sampling"):
                seg = seg_gen.__next__()

            seg["rew"] = seg["rew"] - seg["l2_loss"] * l2_w

            add_vtarg_and_adv(seg, gamma, lam)
            # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
            ob, ac, atarg, tdlamret, full_ob = seg["ob"], seg["ac"], seg[
                "adv"], seg["tdlamret"], seg["full_ob"]
            vpredbefore = seg[
                "vpred"]  # predicted value function before udpate
            atarg = (atarg - atarg.mean()) / atarg.std(
            )  # standardized advantage function estimate
            d = Dataset(dict(ob=full_ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                        shuffle=True)

            if not_update:
                break  # stop G from updating

            if hasattr(pi, "ob_rms"):
                pi.ob_rms.update(full_ob)  # update running mean/std for policy

            args = seg["full_ob"], seg["ac"], atarg
            fvpargs = [arr[::5] for arr in args]
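            # subsample every 5th transition to make the Fisher-vector products cheaper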

            assign_old_eq_new(
            )  # set old parameter values to new parameter values
            with timed("computegrad"):
                *lossbefore, g = compute_lossandgrad(*args)
            lossbefore = allmean(np.array(lossbefore))
            g = allmean(g)
            if np.allclose(g, 0):
                logger.log("Got zero gradient. not updating")
            else:
                with timed("cg"):
                    stepdir = cg(fisher_vector_product,
                                 g,
                                 cg_iters=cg_iters,
                                 verbose=rank == 0)
                assert np.isfinite(stepdir).all()
                shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
                lm = np.sqrt(shs / max_kl)
                # logger.log("lagrange multiplier:", lm, "gnorm:", np.linalg.norm(g))
                fullstep = stepdir / lm
                expectedimprove = g.dot(fullstep)
                surrbefore = lossbefore[0]
                stepsize = 1.0
                thbefore = get_flat()
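                # Backtracking line search: try the full step (scaled so the quadratic
                # KL estimate equals max_kl), then halve it until the surrogate improves
                # and the KL constraint holds.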
                for _ in range(10):
                    thnew = thbefore + fullstep * stepsize
                    set_from_flat(thnew)
                    meanlosses = surr, kl, *_ = allmean(
                        np.array(compute_losses(*args)))
                    if rank == 0:
                        print("Generator entropy " + str(meanlosses[4]) +
                              ", loss " + str(meanlosses[2]))
                    improve = surr - surrbefore
                    logger.log("Expected: %.3f Actual: %.3f" %
                               (expectedimprove, improve))
                    if not np.isfinite(meanlosses).all():
                        logger.log("Got non-finite value of losses -- bad!")
                    elif kl > max_kl * 1.5:
                        logger.log("violated KL constraint. shrinking step.")
                    elif improve < 0:
                        logger.log("surrogate didn't improve. shrinking step.")
                    else:
                        logger.log("Stepsize OK!")
                        break
                    stepsize *= .5
                else:
                    logger.log("couldn't compute a good step")
                    set_from_flat(thbefore)
                if nworkers > 1 and iters_so_far % 20 == 0:
                    paramsums = MPI.COMM_WORLD.allgather(
                        (thnew.sum(),
                         vfadam.getflat().sum()))  # list of tuples
                    assert all(
                        np.allclose(ps, paramsums[0]) for ps in paramsums[1:])
            with timed("vf"):
                logger.log(fmt_row(13, vf_loss_names))
                for _ in range(vf_iters):
                    vf_b_losses = []
                    for batch in d.iterate_once(vf_batchsize):
                        mbob = batch["ob"]
                        mbret = batch["vtarg"]

                        if hasattr(pi, "ob_rms"):
                            pi.ob_rms.update(
                                mbob)  # update running mean/std for policy
                        *newlosses, g = compute_vflossandgrad(mbob, mbret)
                        g = allmean(g)
                        newlosses = allmean(np.array(newlosses))

                        vfadam.update(g, vf_stepsize)
                        vf_b_losses.append(newlosses)
                    logger.log(fmt_row(13, np.mean(vf_b_losses, axis=0)))

            logger.log("Evaluating losses...")
            losses = []
            for batch in d.iterate_once(vf_batchsize):
                newlosses = compute_vf_losses(batch["ob"], batch["ac"],
                                              batch["atarg"], batch["vtarg"])
                losses.append(newlosses)
            meanlosses, _, _ = mpi_moments(losses, axis=0)

            #########################
            '''
            For evaluation during training.
            Can be commented out for faster training...
            '''
            for ob_batch, ac_batch, full_ob_batch in dataset.iterbatches(
                (ob, ac, full_ob),
                    include_final_partial_batch=False,
                    batch_size=len(ob)):
                ob_expert, ac_expert = expert_dataset.get_next_batch(
                    len(ob_batch))
                exp_rew = 0
                for obs, acs in zip(ob_expert, ac_expert):
                    exp_rew += 1 - np.exp(
                        -reward_giver.get_reward(obs, acs)[0][0])
                mean_exp_rew = exp_rew / len(ob_expert)

                gen_rew = 0
                for obs, acs, full_obs in zip(ob_batch, ac_batch,
                                              full_ob_batch):
                    gen_rew += 1 - np.exp(
                        -reward_giver.get_reward(obs, acs)[0][0])
                mean_gen_rew = gen_rew / len(ob_batch)
                if rank == 0:
                    logger.log("Generator step " + str(curr_step) +
                               ": Dicriminator reward of expert traj " +
                               str(mean_exp_rew) + " vs gen traj " +
                               str(mean_gen_rew))
            #########################
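            # 1 - exp(-r) monotonically squashes the discriminator reward so expert
            # and generator batches can be compared on a common bounded scale.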

        if not not_update:
            g_losses = meanlosses
            for (lossname, lossval) in zip(loss_names + vf_loss_names,
                                           meanlosses):
                logger.record_tabular(lossname, lossval)
            logger.record_tabular("ev_tdlam_before",
                                  explained_variance(vpredbefore, tdlamret))

        # ------------------ Update D ------------------
        if not freeze_d:
            logger.log("Optimizing Discriminator...")
            batch_size = len(ob) // d_step
            d_losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for ob_batch, ac_batch, full_ob_batch in dataset.iterbatches(
                (ob, ac, full_ob),
                    include_final_partial_batch=False,
                    batch_size=batch_size):
                ob_expert, ac_expert = expert_dataset.get_next_batch(
                    len(ob_batch))
                #########################
                '''
                For evaluation during training.
                Can be commented out for faster training...
                '''
                exp_rew = 0
                for obs, acs in zip(ob_expert, ac_expert):
                    exp_rew += 1 - np.exp(
                        -reward_giver.get_reward(obs, acs)[0][0])
                mean_exp_rew = exp_rew / len(ob_expert)

                gen_rew = 0

                for obs, acs in zip(ob_batch, ac_batch):
                    gen_rew += 1 - np.exp(
                        -reward_giver.get_reward(obs, acs)[0][0])

                mean_gen_rew = gen_rew / len(ob_batch)
                if rank == 0:
                    logger.log("Dicriminator reward of expert traj " +
                               str(mean_exp_rew) + " vs gen traj " +
                               str(mean_gen_rew))
                #########################
                # update running mean/std for reward_giver
                if hasattr(reward_giver, "obs_rms"):
                    reward_giver.obs_rms.update(
                        np.concatenate((ob_batch, ob_expert), 0))
                loss_input = (ob_batch, ac_batch, ob_expert, ac_expert)
                *newlosses, g = reward_giver.lossandgrad(*loss_input)
                d_adam.update(allmean(g), d_stepsize)
                d_losses.append(newlosses)
            logger.log(fmt_row(13, reward_giver.loss_name))
            logger.log(fmt_row(13, np.mean(d_losses, axis=0)))

        lrlocal = (seg["ep_lens"], seg["ep_rets"], seg["ep_true_rets"],
                   seg["ep_success"], seg["ep_semi_loss"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews, true_rets, success, semi_losses = map(
            flatten_lists, zip(*listoflrpairs))

        # success
        success = [
            float(elem) for elem in success
            if isinstance(elem, (int, float, bool))
        ]  # remove potential None types
        if not success:
            success = [-1]  # set success to -1 if env has no success flag
        success_buffer.extend(success)

        true_rewbuffer.extend(true_rets)
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        if semi_loss and semi_dataset is not None:
            semi_losses = [elem * l2_w for elem in semi_losses]
            total_rewards = rews
            total_rewards = [
                re_elem - l2_elem
                for re_elem, l2_elem in zip(total_rewards, semi_losses)
            ]
            l2_rewbuffer.extend(semi_losses)
            total_rewbuffer.extend(total_rewards)

        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpTrueRewMean", np.mean(true_rewbuffer))
        logger.record_tabular("EpSuccess", np.mean(success_buffer))

        if semi_loss and semi_dataset is not None:
            logger.record_tabular("EpSemiLoss", np.mean(l2_rewbuffer))
            logger.record_tabular("EpTotalReward", np.mean(total_rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        logger.record_tabular("ItersSoFar", iters_so_far)

        if rank == 0:
            logger.dump_tabular()
            if not not_update:
                g_loss_stats.add_all_summary(writer, g_losses, iters_so_far)
            if not freeze_d:
                d_loss_stats.add_all_summary(writer, np.mean(d_losses, axis=0),
                                             iters_so_far)

            # default buffers
            ep_buffers = [
                np.mean(true_rewbuffer),
                np.mean(rewbuffer),
                np.mean(lenbuffer),
                np.mean(success_buffer)
            ]

            if semi_loss and semi_dataset is not None:
                ep_buffers.append(np.mean(l2_rewbuffer))
                ep_buffers.append(np.mean(total_rewbuffer))

            ep_stats.add_all_summary(writer, ep_buffers, iters_so_far)

        if not_update and not freeze_g:
            not_update -= 1
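# cg(...) above is the conjugate-gradient solver that implicitly inverts the
# Fisher-vector product. A minimal sketch with the same calling convention,
# assuming it solves f_Ax(x) = b (close to baselines.common.cg); verbose is
# accepted only to mirror the call site's signature:
import numpy as np


def cg_sketch(f_Ax, b, cg_iters=10, verbose=False, residual_tol=1e-10):
    p = b.copy()  # search direction
    r = b.copy()  # residual (starting from x = 0)
    x = np.zeros_like(b)
    rdotr = r.dot(r)
    for _ in range(cg_iters):
        z = f_Ax(p)
        v = rdotr / p.dot(z)
        x += v * p
        r -= v * z
        newrdotr = r.dot(r)
        mu = newrdotr / rdotr
        p = r + mu * p
        rdotr = newrdotr
        if rdotr < residual_tol:
            break
    return x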
Example #12
def learn(
    env,
    policy_func,
    *,
    timesteps_per_batch,  # timesteps per actor per update
    clip_param,
    entcoeff,  # clipping parameter epsilon, entropy coeff
    optim_epochs,
    optim_stepsize,
    optim_batchsize,  # optimization hypers
    gamma,
    lam,  # advantage estimation
    max_timesteps=0,
    max_episodes=0,
    max_iters=0,
    max_seconds=0,  # time constraint
    callback=None,  # you can do anything in the callback, since it takes locals(), globals()
    adam_epsilon=1e-5,
    schedule='constant',  # annealing for stepsize parameters (epsilon and adam)
    sym_loss_weight=0.0,
    return_threshold=None,  # terminate learning if the return reaches return_threshold
    op_after_init=None,
    init_policy_params=None,
    policy_scope=None,
    max_threshold=None,
    positive_rew_enforce=False,
    reward_drop_bound=True,
    min_iters=0,
    ref_policy_params=None,
    discrete_learning=None  # [obs_disc, act_disc, state_filter_fn, state_unfilter_fn, weight]
):

    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    if policy_scope is None:
        pi = policy_func("pi", ob_space,
                         ac_space)  # Construct network for new policy
        oldpi = policy_func("oldpi", ob_space,
                            ac_space)  # Network for old policy
    else:
        pi = policy_func(policy_scope, ob_space,
                         ac_space)  # Construct network for new policy
        oldpi = policy_func("old" + policy_scope, ob_space,
                            ac_space)  # Network for old policy

    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule
    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = U.mean(kloldnew)
    meanent = U.mean(ent)
    pol_entpen = (-entcoeff) * meanent

    sym_loss = sym_loss_weight * U.mean(
        tf.square(pi.mean - pi.mirrored_mean))  # mirror symmetric loss
    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = U.clip(ratio, 1.0 - clip_param, 1.0 + clip_param) * atarg  #
    pol_surr = -U.mean(tf.minimum(
        surr1, surr2)) + sym_loss  # PPO's pessimistic surrogate (L^CLIP)

    vf_loss = U.mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent, sym_loss]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent", "sym_loss"]

    var_list = pi.get_trainable_variables()
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult],
                             losses + [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)

    U.initialize()

    if init_policy_params is not None:
        cur_scope = pi.get_variables()[0].name[0:pi.get_variables()[0].name.
                                               find('/')]
        orig_scope = list(init_policy_params.keys()
                          )[0][0:list(init_policy_params.keys())[0].find('/')]
        for i in range(len(pi.get_variables())):
            assign_op = pi.get_variables()[i].assign(
                init_policy_params[pi.get_variables()[i].name.replace(
                    cur_scope, orig_scope, 1)])
            U.get_session().run(assign_op)
            assign_op = oldpi.get_variables()[i].assign(
                init_policy_params[pi.get_variables()[i].name.replace(
                    cur_scope, orig_scope, 1)])
            U.get_session().run(assign_op)

    if ref_policy_params is not None:
        ref_pi = policy_func("ref_pi", ob_space, ac_space)
        cur_scope = ref_pi.get_variables()[0].name[0:ref_pi.get_variables()[0].
                                                   name.find('/')]
        orig_scope = list(ref_policy_params.keys()
                          )[0][0:list(ref_policy_params.keys())[0].find('/')]
        for i in range(len(ref_pi.get_variables())):
            assign_op = ref_pi.get_variables()[i].assign(
                ref_policy_params[ref_pi.get_variables()[i].name.replace(
                    cur_scope, orig_scope, 1)])
            U.get_session().run(assign_op)
        #env.env.env.ref_policy = ref_pi

    adam.sync()

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_batch,
                                     stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards

    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"

    max_thres_satisfied = max_threshold is None
    adjust_ratio = 0.0
    prev_avg_rew = -1000000
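    # Keep a snapshot of the policy variables and progress counters so an iteration
    # can be reverted if the rolling reward drops sharply (see the reward_drop_bound
    # branch inside the training loop).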
    revert_parameters = {}
    variables = pi.get_variables()
    for i in range(len(variables)):
        cur_val = variables[i].eval()
        revert_parameters[variables[i].name] = cur_val
    revert_data = [0, 0, 0]
    all_collected_transition_data = []
    Vfunc = {}

    # temp
    import joblib
    path = 'data/value_iter_cartpole_discrete'
    [Vfunc, obs_disc, act_disc, state_filter_fn,
     state_unfilter_fn] = joblib.load(path + '/ref_policy_funcs.pkl')

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % iters_so_far)

        seg = seg_gen.__next__()

        if reward_drop_bound is not None:
            lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
            listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
            lens, rews = map(flatten_lists, zip(*listoflrpairs))
            lenbuffer.extend(lens)
            rewbuffer.extend(rews)
            revert_iteration = False
            if np.mean(
                    rewbuffer
            ) < prev_avg_rew - 50:  # detect significant drop in performance, revert to previous iteration
                print("Revert Iteration!!!!!")
                revert_iteration = True
            else:
                prev_avg_rew = np.mean(rewbuffer)
            logger.record_tabular("Revert Rew", prev_avg_rew)
            if revert_iteration:  # revert iteration
                for i in range(len(pi.get_variables())):
                    assign_op = pi.get_variables()[i].assign(
                        revert_parameters[pi.get_variables()[i].name])
                    U.get_session().run(assign_op)
                episodes_so_far = revert_data[0]
                timesteps_so_far = revert_data[1]
                iters_so_far = revert_data[2]
                continue
            else:
                variables = pi.get_variables()
                for i in range(len(variables)):
                    cur_val = variables[i].eval()
                    revert_parameters[variables[i].name] = np.copy(cur_val)
                revert_data[0] = episodes_so_far
                revert_data[1] = timesteps_so_far
                revert_data[2] = iters_so_far

        if positive_rew_enforce:
            rewlocal = (seg["pos_rews"], seg["neg_pens"], seg["rew"]
                        )  # local values
            listofrews = MPI.COMM_WORLD.allgather(rewlocal)  # list of tuples
            pos_rews, neg_pens, rews = map(flatten_lists, zip(*listofrews))
            if np.mean(rews) < 0.0:
                #min_id = np.argmin(rews)
                #adjust_ratio = pos_rews[min_id]/np.abs(neg_pens[min_id])
                adjust_ratio = np.max([
                    adjust_ratio,
                    np.mean(pos_rews) / np.abs(np.mean(neg_pens))
                ])
                for i in range(len(seg["rew"])):
                    if np.abs(seg["rew"][i] - seg["pos_rews"][i] -
                              seg["neg_pens"][i]) > 1e-5:
                        print(seg["rew"][i], seg["pos_rews"][i],
                              seg["neg_pens"][i])
                        print('Reward wrong!')
                        raise ValueError('Reward decomposition mismatch!')
                    seg["rew"][i] = seg["pos_rews"][
                        i] + seg["neg_pens"][i] * adjust_ratio
        if ref_policy_params is not None:
            rewed = 0
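            # Potential-based shaping: add 0.1 * (0.99 * V(s') - V(s)) from the loaded
            # reference value function whenever both discretized states appear in Vfunc.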
            for i in range(len(seg["rew"])):
                #pred_nexvf = np.max([ref_pi.act(False, seg["collected_transitions"][i][5])[1], pi.act(False, seg["collected_transitions"][i][5])[1]])
                #pred_curvf = np.max([ref_pi.act(False, seg["collected_transitions"][i][4])[1], pi.act(False, seg["collected_transitions"][i][4])[1]])

                if obs_disc(state_filter_fn(seg["collected_transitions"][i][2])) in Vfunc and \
                        obs_disc(state_filter_fn(seg["collected_transitions"][i][0])) in Vfunc:
                    pred_nexvf = Vfunc[obs_disc(
                        state_filter_fn(seg["collected_transitions"][i][2]))]
                    pred_curvf = Vfunc[obs_disc(
                        state_filter_fn(seg["collected_transitions"][i][0]))]
                    rewed += 1
                else:
                    pred_nexvf = 0
                    pred_curvf = 0

                vf_diff = 0.99 * pred_nexvf - pred_curvf
                seg["rew"][i] += vf_diff * 0.1
            print('fraction of transitions shaped: ', rewed / len(seg["rew"]))
        if discrete_learning is not None:
            rewlocal = (seg["collected_transitions"], seg["rew"]
                        )  # local values
            listofrews = MPI.COMM_WORLD.allgather(rewlocal)  # list of tuples
            collected_transitions, rews = map(flatten_lists, zip(*listofrews))
            processed_transitions = []
            for trans in collected_transitions:
                processed_transitions.append([
                    discrete_learning[2](trans[0]), trans[1],
                    discrete_learning[2](trans[2]), trans[3]
                ])
            all_collected_transition_data += processed_transitions
            if len(all_collected_transition_data) > 500000:
                all_collected_transition_data = all_collected_transition_data[
                    len(all_collected_transition_data) - 500000:]
            if len(all_collected_transition_data) > 50000:
                logger.log("Fitting discrete dynamic model...")
                dyn_model, obs_disc = fit_dyn_model(
                    discrete_learning[0], discrete_learning[1],
                    all_collected_transition_data)
                logger.log(
                    "Perform value iteration on the discrete dynamic model...")
                Vfunc, policy = optimize_policy(dyn_model, 0.99)
                discrete_learning[0] = obs_disc
                rewarded = 0
                for i in range(len(seg["rew"])):
                    vf_diff = 0.99*Vfunc[discrete_learning[0](discrete_learning[2](seg["collected_transitions"][i][2]))] - \
                        Vfunc[discrete_learning[0](discrete_learning[2](seg["collected_transitions"][i][0]))]
                    seg["rew"][i] += vf_diff * discrete_learning[4]
                    #if policy[discrete_learning[0](discrete_learning[2](seg["collected_transitions"][i][0]))] == discrete_learning[1](seg["collected_transitions"][i][1]):
                    #    seg["rew"][i] += 2.0
                    #    rewarded += 1
                #logger.log(str(rewarded*1.0/len(seg["rew"])) + ' rewarded')

        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update
        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate
        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                    shuffle=not pi.recurrent)
        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        assign_old_eq_new()  # set old parameter values to new parameter values
        logger.log("Optimizing...")
        logger.log(fmt_row(13, loss_names))
        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"],
                                            batch["atarg"], batch["vtarg"],
                                            cur_lrmult)
                adam.update(g, optim_stepsize * cur_lrmult)
                losses.append(newlosses)
            logger.log(fmt_row(13, np.mean(losses, axis=0)))
        logger.log("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob"], batch["ac"],
                                       batch["atarg"], batch["vtarg"],
                                       cur_lrmult)
            losses.append(newlosses)
        meanlosses, _, _ = mpi_moments(losses, axis=0)
        logger.log(fmt_row(13, meanlosses))
        for (lossval, name) in zipsame(meanlosses, loss_names):
            logger.record_tabular("loss_" + name, lossval)
        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))
        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        if reward_drop_bound is None:
            lenbuffer.extend(lens)
            rewbuffer.extend(rews)
        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        logger.record_tabular("Iter", iters_so_far)
        if positive_rew_enforce:
            if adjust_ratio is not None:
                logger.record_tabular("RewardAdjustRatio", adjust_ratio)
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.dump_tabular()
        if return_threshold is not None and max_thres_satisfied:
            if np.mean(
                    rewbuffer) > return_threshold and iters_so_far > min_iters:
                break
        if max_threshold is not None:
            print('Current max return: ', np.max(rewbuffer))
            if np.max(rewbuffer) > max_threshold:
                max_thres_satisfied = True
            else:
                max_thres_satisfied = False
    return pi
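
# The training loop above relies on add_vtarg_and_adv(seg, gamma, lam) to fill in
# seg["adv"] and seg["tdlamret"] before the PPO update, but the helper is not shown
# in this excerpt. Below is a minimal sketch of a GAE(lambda) implementation in the
# style of OpenAI Baselines, assuming seg carries "rew", "vpred", "new" and
# "nextvpred" arrays produced by the trajectory segment generator; the name
# add_vtarg_and_adv_sketch marks it as an illustration, not the original code.
import numpy as np

def add_vtarg_and_adv_sketch(seg, gamma, lam):
    # Append the bootstrap value and an episode-start flag for the step after the segment.
    new = np.append(seg["new"], 0)
    vpred = np.append(seg["vpred"], seg["nextvpred"])
    T = len(seg["rew"])
    seg["adv"] = gaelam = np.empty(T, 'float32')
    rew = seg["rew"]
    lastgaelam = 0
    for t in reversed(range(T)):
        nonterminal = 1 - new[t + 1]
        # TD residual, zeroed across episode boundaries.
        delta = rew[t] + gamma * vpred[t + 1] * nonterminal - vpred[t]
        gaelam[t] = lastgaelam = delta + gamma * lam * nonterminal * lastgaelam
    # lambda-return targets for the value function update.
    seg["tdlamret"] = seg["adv"] + seg["vpred"]
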
Exemplo n.º 13
0
def learn(
    env,
    policy_fn,
    *,
    timesteps_per_actorbatch,  # timesteps per actor per update
    clip_param,
    entcoeff,  # clipping parameter epsilon, entropy coeff
    optim_epochs,
    optim_stepsize,
    optim_batchsize,  # optimization hypers
    gamma,
    lam,  # advantage estimation
    max_timesteps=0,
    max_episodes=0,
    max_iters=0,
    max_seconds=0,  # time constraint
    callback=None,  # you can do anything in the callback, since it takes locals(), globals()
    adam_epsilon=1e-5,
    schedule='constant',  # annealing for stepsize parameters (epsilon and adam)
    sym_loss_weight=0.0,
    **kwargs,
):
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space

    pi = policy_fn("pi", ob_space,
                   ac_space)  # Construct network for new policy

    oldpi = policy_fn("oldpi", ob_space, ac_space)  # Network for old policy

    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    atarg_novel = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function for the novelty reward term
    ret_novel = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Empirical return for the novelty reward term

    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule

    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()

    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    pol_entpen = (-entcoeff) * meanent

    sym_loss = sym_loss_weight * tf.reduce_mean(
        tf.square(pi.mean - pi.mirr_mean))
    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold

    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = tf.clip_by_value(ratio, 1.0 - clip_param,
                             1.0 + clip_param) * atarg  #

    surr1_novel = ratio * atarg_novel  # surrogate loss of the novelty term
    surr2_novel = tf.clip_by_value(
        ratio, 1.0 - clip_param,
        1.0 + clip_param) * atarg_novel  # surrogate loss of the novelty term

    pol_surr = -tf.reduce_mean(tf.minimum(
        surr1, surr2)) + sym_loss  # PPO's pessimistic surrogate (L^CLIP)
    pol_surr_novel = -tf.reduce_mean(tf.minimum(
        surr1_novel,
        surr2_novel)) + sym_loss  # PPO's surrogate for the novelty part

    vf_loss = tf.reduce_mean(tf.square(pi.vpred - ret))
    vf_loss_novel = tf.reduce_mean(tf.square(pi.vpred_novel - ret_novel))

    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent, sym_loss]

    total_loss_novel = pol_surr_novel + pol_entpen + vf_loss_novel
    losses_novel = [
        pol_surr_novel, pol_entpen, vf_loss_novel, meankl, meanent, sym_loss
    ]

    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent", 'symm']

    policy_var_list = pi.get_trainable_variables(scope='pi/pol')

    policy_var_count = 0
    for var in policy_var_list:
        count_in_var = 1
        for dim in var.shape.as_list():
            count_in_var *= dim
        policy_var_count += count_in_var

    var_list = pi.get_trainable_variables(
        scope='pi/pol') + pi.get_trainable_variables(scope='pi/vf/')
    var_list_novel = pi.get_trainable_variables(
        scope='pi/pol') + pi.get_trainable_variables(scope='pi/vf_novel/')

    lossandgrad = U.function([ob, ac, atarg, ret, lrmult],
                             losses + [U.flatgrad(total_loss, var_list)])

    lossandgrad_novel = U.function(
        [ob, ac, atarg_novel, ret_novel, lrmult],
        losses_novel + [U.flatgrad(total_loss_novel, var_list_novel)])

    adam = MpiAdam(var_list, epsilon=adam_epsilon)
    adam_novel = MpiAdam(var_list_novel, epsilon=adam_epsilon)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])

    compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)
    compute_losses_novel = U.function([ob, ac, atarg_novel, ret_novel, lrmult],
                                      losses_novel)
    comp_sym_loss = U.function([], sym_loss)
    U.initialize()
    adam.sync()
    adam_novel.sync()

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_actorbatch,
                                     stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0

    novelty_update_iter_cycle = 10
    novelty_start_iter = 50
    novelty_update = True

    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards
    rewnovelbuffer = deque(
        maxlen=100)  # rolling buffer for episode novelty rewards

    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % iters_so_far)

        seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, atarg_novel, tdlamret, tdlamret_novel = seg["ob"], seg[
            "ac"], seg["adv"], seg["adv_novel"], seg["tdlamret"], seg[
                "tdlamret_novel"]

        vpredbefore = seg["vpred"]  # predicted value function before update
        vprednovelbefore = seg[
            'vpred_novel']  # predicted novelty value function before update

        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate
        atarg_novel = (atarg_novel - atarg_novel.mean()) / atarg_novel.std(
        )  # standardized novelty advantage function estimate

        d = Dataset(dict(ob=ob,
                         ac=ac,
                         atarg=atarg,
                         vtarg=tdlamret,
                         atarg_novel=atarg_novel,
                         vtarg_novel=tdlamret_novel),
                    shuffle=not pi.recurrent)

        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        assign_old_eq_new()  # set old parameter values to new parameter values
        logger.log("Optimizing...")
        logger.log(fmt_row(13, loss_names))
        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"],
                                            batch["atarg"], batch["vtarg"],
                                            cur_lrmult)

                *newlosses_novel, g_novel = lossandgrad_novel(
                    batch["ob"], batch["ac"], batch["atarg_novel"],
                    batch["vtarg_novel"], cur_lrmult)

                # pol_g = g[0:policy_var_count]
                # pol_g_novel = g_novel[0:policy_var_count]
                #
                # if (np.dot(pol_g, pol_g_novel) > 0):
                #     adam_novel.update(g_novel, optim_stepsize * cur_lrmult)
                #
                # else:
                #
                #     parallel_g = (np.dot(pol_g, pol_g_novel) / np.linalg.norm(pol_g_novel)) * pol_g_novel
                #     final_pol_gradient = pol_g - parallel_g
                #
                #     final_gradient = np.zeros(len(g))
                #     final_gradient[0:policy_var_count] = final_pol_gradient
                #     final_gradient[policy_var_count::] = g[policy_var_count::]
                #
                #     adam.update(final_gradient, optim_stepsize * cur_lrmult)

                # zigzag_update(novelty_update, adam, adam_novel, g, g_novel, step)

                adam.update(g, optim_stepsize * cur_lrmult)

                # adam_novel.update(g_novel, optim_stepsize * cur_lrmult)

                losses.append(newlosses)
            logger.log(fmt_row(13, np.mean(losses, axis=0)))

        logger.log("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob"], batch["ac"],
                                       batch["atarg"], batch["vtarg"],
                                       cur_lrmult)
            # newlosses_novel = compute_losses_novel(batch["ob"], batch["ac"], batch["atarg_novel"], batch["vtarg_novel"],
            #                                        cur_lrmult)
            losses.append(newlosses)
        meanlosses, _, _ = mpi_moments(losses, axis=0)
        logger.log(fmt_row(13, meanlosses))
        for (lossval, name) in zipsame(meanlosses, loss_names):
            logger.record_tabular("loss_" + name, lossval)
        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))
        lrlocal = (seg["ep_lens"], seg["ep_rets"], seg['ep_rets_novel']
                   )  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews, rews_novel = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        rewnovelbuffer.extend(rews_novel)
        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpRNoveltyRewMean", np.mean(rewnovelbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1
        if iters_so_far >= novelty_start_iter and iters_so_far % novelty_update_iter_cycle == 0:
            novelty_update = not novelty_update

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        logger.record_tabular("NoveltyUpdate", novelty_update)
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.dump_tabular()

    return pi
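
# The commented-out block inside the optimization loop above sketches a
# gradient-surgery idea: when the task gradient and the novelty gradient point
# in conflicting directions, remove the conflicting component before applying
# the update. The NumPy sketch below only illustrates that idea (it is not
# called by this file); it uses the standard projection
# g_novel - (<g, g_novel> / ||g||^2) * g whenever the dot product is negative.
import numpy as np

def project_out_conflict(g, g_novel, eps=1e-8):
    """Return g_novel with its component along g removed if the two conflict."""
    dot = np.dot(g, g_novel)
    if dot >= 0:
        return g_novel                      # no conflict: keep the novelty gradient as-is
    return g_novel - (dot / (np.dot(g, g) + eps)) * g

# Example: g = [1, 0] and g_novel = [-1, 1] conflict; the projected novelty
# gradient becomes [0, 1], which is orthogonal to g.
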
Exemplo n.º 14
0
def learn(
    env,
    policy_fn,
    *,
    timesteps_per_actorbatch,  # timesteps per actor per update
    clip_param,
    entcoeff,  # clipping parameter epsilon, entropy coeff
    optim_epochs,
    optim_stepsize,
    optim_batchsize,  # optimization hypers
    gamma,
    lam,  # advantage estimation
    max_timesteps=0,
    max_episodes=0,
    max_iters=0,
    max_seconds=0,  # time constraint
    callback=None,  # you can do anything in the callback, since it takes locals(), globals()
    adam_epsilon=1e-5,
    schedule='constant',  # annealing for stepsize parameters (epsilon and adam)
    restore_model_from_file=None,
    save_model_with_prefix,  # prefix used when naming the saved model files. Usually we encode the target goal here,
    # for example 3dof_ppo1_H.
    # That way we can choose which networks to run on the real robot without sending every file or folder.
    # The model file name should be self-explanatory.
    job_id=None,  # identifies the Spearmint iteration number; it is usually set by the Spearmint iterator
    outdir="/tmp/rosrl/experiments/continuous/ppo1/"):
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_fn("pi", ob_space,
                   ac_space)  # Construct network for new policy
    oldpi = policy_fn("oldpi", ob_space, ac_space)  # Network for old policy
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule
    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = tf.clip_by_value(ratio, 1.0 - clip_param,
                             1.0 + clip_param) * atarg  #
    pol_surr = -tf.reduce_mean(tf.minimum(
        surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)
    vf_loss = tf.reduce_mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult],
                             losses + [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)

    U.initialize()
    adam.sync()
    """
    Here we add a possibility to resume from a previously saved model if a model file is provided
    """
    if restore_model_from_file:
        # saver = tf.train.Saver(tf.all_variables())
        saver = tf.train.import_meta_graph(restore_model_from_file)
        saver.restore(
            tf.get_default_session(),
            tf.train.latest_checkpoint('./'))  #restore_model_from_file)
        logger.log("Loaded model from {}".format(restore_model_from_file))

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_actorbatch,
                                     stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards

    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"

    if save_model_with_prefix:
        if job_id is not None:
            basePath = '/tmp/rosrl/' + str(
                env.__class__.__name__) + '/ppo1/' + job_id
        else:
            basePath = '/tmp/rosrl/' + str(env.__class__.__name__) + '/ppo1/'

    # Create the writer for TensorBoard logs
    summary_writer = tf.summary.FileWriter(outdir,
                                           graph=tf.get_default_graph())

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % iters_so_far)

        seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update
        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate
        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                    shuffle=not pi.recurrent)
        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        assign_old_eq_new()  # set old parameter values to new parameter values
        logger.log("Optimizing...")
        logger.log(fmt_row(13, loss_names))
        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"],
                                            batch["atarg"], batch["vtarg"],
                                            cur_lrmult)
                adam.update(g, optim_stepsize * cur_lrmult)
                losses.append(newlosses)
            logger.log(fmt_row(13, np.mean(losses, axis=0)))

        logger.log("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob"], batch["ac"],
                                       batch["atarg"], batch["vtarg"],
                                       cur_lrmult)
            losses.append(newlosses)
        meanlosses, _, _ = mpi_moments(losses, axis=0)
        logger.log(fmt_row(13, meanlosses))
        for (lossval, name) in zipsame(meanlosses, loss_names):
            logger.record_tabular("loss_" + name, lossval)
        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))
        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpRewSEM", np.std(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        """
        Save the model at every itteration
        """

        if save_model_with_prefix:
            #if np.mean(rewbuffer) > -50.0:
            if iters_so_far % 10 == 0:
                basePath = outdir + "/models/"

                if not os.path.exists(basePath):
                    os.makedirs(basePath)
                modelF = basePath + save_model_with_prefix + "_afterIter_" + str(
                    iters_so_far) + ".model"
                U.save_state(modelF)
                logger.log("Saved model to file :{}".format(modelF))

        iters_so_far += 1
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.dump_tabular()

        summary = tf.Summary(value=[
            tf.Summary.Value(tag="EpRewMean", simple_value=np.mean(rewbuffer))
        ])
        summary_writer.add_summary(summary, timesteps_so_far)
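
# The loop above writes checkpoints with U.save_state(modelF) every 10
# iterations. Below is a minimal sketch (not part of this file) of reloading
# one of those checkpoints for evaluation. It assumes the same policy graph is
# rebuilt with the same policy_fn and scope name "pi" before restoring, and
# that U.save_state wrote the checkpoint with a plain tf.train.Saver, as in
# Baselines; the helper name and its arguments are illustrative only.
import tensorflow as tf

def load_saved_policy(policy_fn, env, model_path, sess=None):
    sess = sess or tf.get_default_session() or tf.Session()
    with sess.as_default():
        pi = policy_fn("pi", env.observation_space, env.action_space)  # rebuild the graph
        saver = tf.train.Saver()
        # e.g. model_path = outdir + "/models/<prefix>_afterIter_10.model"
        saver.restore(sess, model_path)
    return pi
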
Exemplo n.º 15
0
def learn(env, policy_fn, *, timesteps_per_actorbatch, clip_param, entcoeff, optim_epochs, optim_stepsize,
          optim_batchsize, gamma, lam, max_timesteps=0, max_episodes=0, max_iters=0, max_seconds=0, callback=None,
          adam_epsilon=1e-5, schedule='constant'):
    """
    Learning PPO with Stochastic Gradient Descent

    :param env: (Gym Environment) environment to train on
    :param policy_fn: (function (str, Gym Spaces, Gym Spaces): TensorFlow Tensor) creates the policy
    :param timesteps_per_actorbatch: (int) timesteps per actor per update
    :param clip_param: (float) clipping parameter epsilon
    :param entcoeff: (float) the entropy loss weight
    :param optim_epochs: (float) the optimizer's number of epochs
    :param optim_stepsize: (float) the optimizer's stepsize
    :param optim_batchsize: (int) the optimizer's batch size
    :param gamma: (float) discount factor
    :param lam: (float) advantage estimation
    :param max_timesteps: (int) number of env steps to optimize for
    :param max_episodes: (int) the maximum number of episodes
    :param max_iters: (int) the maximum number of iterations
    :param max_seconds: (int) the maximal duration
    :param callback: (function (dict, dict)) function called at every steps with state of the algorithm.
        It takes the local and global variables.
    :param adam_epsilon: (float) the epsilon value for the adam optimizer
    :param schedule: (str) The type of scheduler for the learning rate update ('linear', 'constant',
                                 'double_linear_con', 'middle_drop' or 'double_middle_drop')
    """

    # Setup losses and stuff
    ob_space = env.observation_space
    ac_space = env.action_space
    sess = tf_util.single_threaded_session()

    # Construct network for new policy
    policy = policy_fn("pi", ob_space, ac_space, sess=sess)

    # Network for old policy
    oldpi = policy_fn("oldpi", ob_space, ac_space, sess=sess,
                      placeholders={"obs": policy.obs_ph, "stochastic": policy.stochastic_ph})

    # Target advantage function (if applicable)
    atarg = tf.placeholder(dtype=tf.float32, shape=[None])

    # Empirical return
    ret = tf.placeholder(dtype=tf.float32, shape=[None])

    # learning rate multiplier, updated with schedule
    lrmult = tf.placeholder(name='lrmult', dtype=tf.float32, shape=[])

    # Annealed clipping parameter epsilon
    clip_param = clip_param * lrmult

    obs_ph = policy.obs_ph
    action_ph = policy.pdtype.sample_placeholder([None])

    kloldnew = oldpi.proba_distribution.kl(policy.proba_distribution)
    ent = policy.proba_distribution.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    pol_entpen = (-entcoeff) * meanent

    # pnew / pold
    ratio = tf.exp(policy.proba_distribution.logp(action_ph) - oldpi.proba_distribution.logp(action_ph))

    # surrogate from conservative policy iteration
    surr1 = ratio * atarg
    surr2 = tf.clip_by_value(ratio, 1.0 - clip_param, 1.0 + clip_param) * atarg

    # PPO's pessimistic surrogate (L^CLIP)
    pol_surr = - tf.reduce_mean(tf.minimum(surr1, surr2))
    vf_loss = tf.reduce_mean(tf.square(policy.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = policy.get_trainable_variables()
    lossandgrad = tf_util.function([obs_ph, action_ph, atarg, ret, lrmult],
                                   losses + [tf_util.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon, sess=sess)

    assign_old_eq_new = tf_util.function([], [], updates=[tf.assign(oldv, newv)
                                                          for (oldv, newv) in
                                                          zipsame(oldpi.get_variables(), policy.get_variables())])
    compute_losses = tf_util.function([obs_ph, action_ph, atarg, ret, lrmult], losses)

    tf_util.initialize(sess=sess)
    adam.sync()

    # Prepare for rollouts
    seg_gen = traj_segment_generator(policy, env, timesteps_per_actorbatch, stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    t_start = time.time()

    # rolling buffer for episode lengths
    lenbuffer = deque(maxlen=100)
    # rolling buffer for episode rewards
    rewbuffer = deque(maxlen=100)

    assert sum([max_iters > 0, max_timesteps > 0, max_episodes > 0,
                max_seconds > 0]) == 1, "Only one time constraint permitted"

    while True:
        if callback:
            callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - t_start >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % iters_so_far)

        seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        obs_ph, action_ph, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]

        # predicted value function before update
        vpredbefore = seg["vpred"]

        # standardized advantage function estimate
        atarg = (atarg - atarg.mean()) / atarg.std()
        dataset = Dataset(dict(ob=obs_ph, ac=action_ph, atarg=atarg, vtarg=tdlamret),
                          shuffle=not policy.recurrent)
        optim_batchsize = optim_batchsize or obs_ph.shape[0]

        if hasattr(policy, "ob_rms"):
            # update running mean/std for policy
            policy.ob_rms.update(obs_ph)

        # set old parameter values to new parameter values
        assign_old_eq_new(sess=sess)
        logger.log("Optimizing...")
        logger.log(fmt_row(13, loss_names))

        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            # list of tuples, each of which gives the loss for a minibatch
            losses = []
            for batch in dataset.iterate_once(optim_batchsize):
                *newlosses, grad = lossandgrad(batch["ob"], batch["ac"], batch["atarg"], batch["vtarg"], cur_lrmult,
                                               sess=sess)
                adam.update(grad, optim_stepsize * cur_lrmult)
                losses.append(newlosses)
            logger.log(fmt_row(13, np.mean(losses, axis=0)))

        logger.log("Evaluating losses...")
        losses = []
        for batch in dataset.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob"], batch["ac"], batch["atarg"], batch["vtarg"], cur_lrmult, sess=sess)
            losses.append(newlosses)
        mean_losses, _, _ = mpi_moments(losses, axis=0)
        logger.log(fmt_row(13, mean_losses))
        for (loss_val, name) in zipsame(mean_losses, loss_names):
            logger.record_tabular("loss_" + name, loss_val)
        logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))

        # local values
        lrlocal = (seg["ep_lens"], seg["ep_rets"])

        # list of tuples
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - t_start)
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.dump_tabular()

    return policy
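
# A small self-contained NumPy illustration (not part of the function above) of
# the pessimistic clipped surrogate it minimizes: each sample contributes the
# minimum of the unclipped and clipped ratio-weighted advantage, so once the
# probability ratio leaves [1 - eps, 1 + eps] in the direction favoured by the
# advantage, pushing it further no longer improves the objective.
import numpy as np

def clipped_surrogate_loss(ratio, adv, clip_eps=0.2):
    surr1 = ratio * adv
    surr2 = np.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * adv
    return -np.mean(np.minimum(surr1, surr2))   # negated because it is minimized

# With a positive advantage, a ratio of 1.5 is treated the same as 1.2:
# clipped_surrogate_loss(np.array([1.5]), np.array([1.0])) ==
# clipped_surrogate_loss(np.array([1.2]), np.array([1.0]))
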
Exemplo n.º 16
0
def learn(
        env,
        policy_func_pro,
        policy_func_adv,
        *,
        timesteps_per_batch,  # timesteps per actor per update
        clip_param,
        entcoeff,  # clipping parameter epsilon, entropy coeff
        optim_epochs,
        optim_stepsize,
        optim_batchsize,  # optimization hypers
        gamma,
        lam,
        lr_l,
        lr_a,
        max_steps_episode,  # advantage estimation
        max_timesteps=0,
        max_episodes=0,
        max_iters=0,
        max_seconds=0,  # time constraint
        callback=None,  # you can do anything in the callback, since it takes locals(), globals()
        adam_epsilon=1e-5,
        schedule='constant',  # annealing for stepsize parameters (epsilon and adam)
        clip_action=False,
        restore_dir=None,
        ckpt_dir=None,
        save_timestep_period=1000):
    # Setup losses and stuff
    # ----------------------------------------
    rew_mean = []

    ob_space = env.observation_space
    pro_ac_space = env.action_space
    adv_ac_space = env.adv_action_space

    # env.render()
    pro_pi = policy_func_pro("pro_pi", ob_space,
                             pro_ac_space)  # Construct network for new policy
    pro_oldpi = policy_func_pro("pro_oldpi", ob_space,
                                pro_ac_space)  # Network for old policy

    adv_pi = policy_func_adv(
        "adv_pi", ob_space,
        adv_ac_space)  # Construct network for new adv policy
    adv_oldpi = policy_func_adv("adv_oldpi", ob_space,
                                adv_ac_space)  # Network for old adv policy

    pro_atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    adv_atarg = tf.placeholder(dtype=tf.float32, shape=[None])

    ret_pro = tf.placeholder(dtype=tf.float32,
                             shape=[None])  # Empirical return
    ret_adv = tf.placeholder(dtype=tf.float32,
                             shape=[None])  # Empirical return

    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule
    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    ob_ = U.get_placeholder_cached(name="ob_")

    adv_ob = U.get_placeholder_cached(name="ob_adv")
    adv_ob_ = U.get_placeholder_cached(name="adv_ob_")

    pro_ac = pro_pi.pdtype.sample_placeholder([None])
    adv_ac = adv_pi.pdtype.sample_placeholder([None])

    # define Lyapunov net
    s_dim = ob_space.shape[0]
    a_dim = pro_ac_space.shape[0]
    d_dim = adv_ac_space.shape[0]

    use_lyapunov = True
    approx_value = True
    finite_horizon = True
    labda_init = 1.  # weight on ΔL (the Lagrange multiplier)
    alpha3 = 0.5  # L2 performance term; we can tune this later to study performance
    tau = 5e-3
    ita = 1.
    lr_labda = 0.033

    LN_R = tf.placeholder(tf.float32, [None, 1], 'r')  # reward
    LN_V = tf.placeholder(tf.float32, [None, 1], 'v')  # return (used as the Lyapunov target)
    LR_L = tf.placeholder(tf.float32, None, 'LR_L')  # learning rate of the Lyapunov network
    LR_A = tf.placeholder(tf.float32, None, 'LR_A')  # learning rate of the actor network
    L_terminal = tf.placeholder(tf.float32, [None, 1], 'terminal')
    # LN_S = tf.placeholder(tf.float32, [None, s_dim], 's')  # state
    # LN_a_input = tf.placeholder(tf.float32, [None, a_dim], 'a_input')  # action fed in from the batch
    # ob_ = tf.placeholder(tf.float32, [None, s_dim], 's_')  # next state
    LN_a_input_ = tf.placeholder(tf.float32, [None, a_dim], 'a_')  # next action
    LN_d_input = tf.placeholder(tf.float32, [None, d_dim],
                                'd_input')  # disturbance fed in from the batch
    labda = tf.placeholder(tf.float32, None, 'LR_lambda')

    # log_labda = tf.get_variable('lambda', None, tf.float32, initializer=tf.log(labda_init))  # log(λ), used for the adaptive coefficient
    # labda = tf.clip_by_value(tf.exp(log_labda), *SCALE_lambda_MIN_MAX)

    l = build_l(ob, pro_ac, s_dim, a_dim)  # Lyapunov network
    l_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                 scope='Lyapunov')

    ema = tf.train.ExponentialMovingAverage(decay=1 - tau)  # soft replacement (soft update of the target network)

    def ema_getter(getter, name, *args, **kwargs):
        return ema.average(getter(name, *args, **kwargs))

    target_update = [ema.apply(l_params)]

    a_ = pro_pi.ac_
    a_old_ = pro_oldpi.ac_

    # there are several possible ways to sample the next action here
    l_ = build_l(ob_, a_, s_dim, a_dim, reuse=True)

    l_d = build_l(adv_ob, pro_ac, s_dim, a_dim, reuse=True)
    l_d_ = build_l(adv_ob_, LN_a_input_, s_dim, a_dim, reuse=True)

    l_old_ = build_l(ob_,
                     a_old_,
                     s_dim,
                     a_dim,
                     reuse=True,
                     custom_getter=ema_getter)  # could a_ be used here instead of a_old?
    l_derta = tf.reduce_mean(
        l_ - l + (alpha3 + 1) * LN_R -
        ita * tf.expand_dims(tf.norm(LN_d_input, axis=1), axis=1))
    # The l term in the disturbance loss may also need changing: the main policy has already been
    # optimized here, so we should re-sample; but re-sampling does not quite fit either, because R
    # has changed. Another option is to update the two policies in turn.
    # Yet another option is to also take ac_ from the samples, for consistency, but then why was
    # this not a problem in the earlier SAC version?
    l_d_derta = tf.reduce_mean(
        l_d_ - l_d + (alpha3 + 1) * LN_R -
        ita * tf.expand_dims(tf.norm(adv_pi.ac, axis=1), axis=1)
    )  # the oscillation may come from here (adv_pi.ac); maybe the two Lyapunov networks really should be synchronized
    # labda_loss = -tf.reduce_mean(log_labda * l_derta)  # update loss for lambda
    # lambda_train = tf.train.AdamOptimizer(LR_A).minimize(labda_loss, var_list=log_labda)  # optimizer for lambda

    with tf.control_dependencies(target_update):
        if approx_value:
            if finite_horizon:
                l_target = LN_V  # would approximating this ourselves work better?
            else:
                l_target = LN_R + gamma * (1 - L_terminal) * tf.stop_gradient(
                    l_old_)  # Lyapunov critic - self.alpha * next_log_pis
        else:
            l_target = LN_R
        l_error = tf.losses.mean_squared_error(labels=l_target, predictions=l)
        ltrain = tf.train.AdamOptimizer(LR_L).minimize(l_error,
                                                       var_list=l_params)
        Lyapunov_train = [ltrain]
        Lyapunov_opt_input = [
            LN_R, LN_V, L_terminal, ob, ob_, pro_ac, LN_d_input, LR_A, LR_L
        ]

        Lyapunov_opt = U.function(Lyapunov_opt_input, Lyapunov_train)
        Lyapunov_opt_loss = U.function(Lyapunov_opt_input, [l_error])
    # Lyapunov function

    pro_kloldnew = pro_oldpi.pd.kl(pro_pi.pd)  # compute kl difference
    adv_kloldnew = adv_oldpi.pd.kl(adv_pi.pd)
    pro_ent = pro_pi.pd.entropy()
    adv_ent = adv_pi.pd.entropy()
    pro_meankl = tf.reduce_mean(pro_kloldnew)
    adv_meankl = tf.reduce_mean(adv_kloldnew)
    pro_meanent = tf.reduce_mean(pro_ent)
    adv_meanent = tf.reduce_mean(adv_ent)
    pro_pol_entpen = (-entcoeff) * pro_meanent
    adv_pol_entpen = (-entcoeff) * adv_meanent

    pro_ratio = tf.exp(pro_pi.pd.logp(pro_ac) -
                       pro_oldpi.pd.logp(pro_ac))  # pnew / pold
    adv_ratio = tf.exp(adv_pi.pd.logp(adv_ac) - adv_oldpi.pd.logp(adv_ac))
    pro_surr1 = pro_ratio * pro_atarg  # surrogate from conservative policy iteration
    adv_surr1 = adv_ratio * adv_atarg
    pro_surr2 = tf.clip_by_value(pro_ratio, 1.0 - clip_param,
                                 1.0 + clip_param) * pro_atarg  #
    adv_surr2 = tf.clip_by_value(adv_ratio, 1.0 - clip_param,
                                 1.0 + clip_param) * adv_atarg
    pro_pol_surr = -tf.reduce_mean(tf.minimum(
        pro_surr1, pro_surr2))  # PPO's pessimistic surrogate (L^CLIP)
    adv_pol_surr = -tf.reduce_mean(tf.minimum(adv_surr1, adv_surr2))
    pro_vf_loss = tf.reduce_mean(tf.square(pro_pi.vpred - ret_pro))
    adv_vf_loss = tf.reduce_mean(tf.square(adv_pi.vpred - ret_adv))
    pro_lyapunov_loss = tf.reduce_mean(-l_derta * labda)
    pro_total_loss = pro_pol_surr + pro_pol_entpen + pro_vf_loss + pro_lyapunov_loss

    adv_lyapunov_loss = tf.reduce_mean(-l_d_derta * labda)
    adv_total_loss = adv_pol_surr + adv_pol_entpen + adv_vf_loss + adv_lyapunov_loss

    pro_losses = [
        pro_pol_surr, pro_pol_entpen, pro_vf_loss, pro_meankl, pro_meanent,
        pro_lyapunov_loss
    ]
    pro_loss_names = [
        "pro_pol_surr", "pro_pol_entpen", "pro_vf_loss", "pro_kl", "pro_ent",
        "pro_lyapunov_loss"
    ]
    adv_losses = [
        adv_pol_surr, adv_pol_entpen, adv_vf_loss, adv_meankl, adv_meanent,
        adv_lyapunov_loss
    ]
    adv_loss_names = [
        "adv_pol_surr", "adv_pol_entpen", "adv_vf_loss", "adv_kl", "adv_ent",
        "adv_lyapunov_loss"
    ]

    pro_var_list = pro_pi.get_trainable_variables()
    adv_var_list = adv_pi.get_trainable_variables()
    pro_opt_input = [
        ob, pro_ac, pro_atarg, ret_pro, lrmult, LN_R, LN_V, ob_, LN_d_input,
        LR_A, LR_L, labda
    ]
    # pro_lossandgrad = U.function([ob, pro_ac, pro_atarg, ret, lrmult], pro_losses + [U.flatgrad(pro_total_loss, pro_var_list)])
    pro_lossandgrad = U.function(
        pro_opt_input, pro_losses + [U.flatgrad(pro_total_loss, pro_var_list)])

    # Lyapunov_grad = U.function(Lyapunov_opt_input, U.flatgrad(l_derta * labda, pro_var_list))

    adv_opt_input = [
        adv_ob, adv_ac, adv_atarg, ret_adv, lrmult, LN_R, adv_ob_, pro_ac,
        LN_a_input_, LR_A, LR_L, labda
    ]
    adv_lossandgrad = U.function(
        adv_opt_input, adv_losses + [U.flatgrad(adv_total_loss, adv_var_list)])
    pro_adam = MpiAdam(pro_var_list, epsilon=adam_epsilon)
    adv_adam = MpiAdam(adv_var_list, epsilon=adam_epsilon)

    pro_assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv) for (oldv, newv) in zipsame(
                pro_oldpi.get_variables(), pro_pi.get_variables())
        ])
    adv_assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv) for (oldv, newv) in zipsame(
                adv_oldpi.get_variables(), adv_pi.get_variables())
        ])
    pro_compute_losses = U.function(pro_opt_input, pro_losses)
    adv_compute_losses = U.function(adv_opt_input, adv_losses)

    # zp gymfc new
    saver = None
    if ckpt_dir:  # save model
        # Store for each one
        keep = int(max_timesteps /
                   float(save_timestep_period))  # number of model want to save
        print("[INFO] Keeping ", keep, " checkpoints")
        saver = tf.train.Saver(save_relative_paths=True, max_to_keep=keep)

    print('version: use discarl v2-v5')
    print('info:', 'alpha3', alpha3, 'SCALE_lambda_MIN_MAX', lr_labda,
          'Finite_horizon', finite_horizon, 'adv_mag',
          env.adv_action_space.high, 'timesteps', max_timesteps)

    U.initialize()
    pro_adam.sync()
    adv_adam.sync()

    if restore_dir:  # restore model
        ckpt = tf.train.get_checkpoint_state(restore_dir)
        if ckpt:
            # If there is one that already exists then restore it
            print("Restoring model from ", ckpt.model_checkpoint_path)
            saver.restore(tf.get_default_session(), ckpt.model_checkpoint_path)
        else:
            print("Trying to restore model from ", restore_dir,
                  " but doesn't exist")

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pro_pi,
                                     adv_pi,
                                     env,
                                     timesteps_per_batch,
                                     max_steps_episode,
                                     stochastic=True,
                                     clip_action=clip_action)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    stop_buffer = deque(maxlen=30)
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards
    costbuffer = deque(maxlen=100)
    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"

    next_ckpt_timestep = save_timestep_period

    while True:
        if callback: callback(locals(), globals())

        end = False
        if max_timesteps and timesteps_so_far >= max_timesteps:
            end = True
        elif max_episodes and episodes_so_far >= max_episodes:
            end = True
        elif max_iters and iters_so_far >= max_iters:
            end = True
        elif max_seconds and time.time() - tstart >= max_seconds:
            end = True

        if saver and ((timesteps_so_far >= next_ckpt_timestep) or end):
            task_name = "ppo-{}-{}.ckpt".format(env.spec.id, timesteps_so_far)
            fname = os.path.join(ckpt_dir, task_name)
            os.makedirs(os.path.dirname(fname), exist_ok=True)
            saver.save(tf.get_default_session(), fname)
            next_ckpt_timestep += save_timestep_period

        if end:  #and np.mean(stop_buffer) > zp_max_step:
            break

        if end and max_timesteps < 100:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        # learning-rate schedule for the Lyapunov function
        Lya_epsilon = 1.0 - (timesteps_so_far - 1.0) / max_timesteps
        if end:
            Lya_epsilon = 0.0001
        lr_a_this = lr_a * Lya_epsilon
        lr_l_this = lr_l * Lya_epsilon
        lr_labda_this = lr_labda * Lya_epsilon

        Percentage = min(timesteps_so_far / max_timesteps, 1) * 100
        logger.log("**********Iteration %i **Percentage %.2f **********" %
                   (iters_so_far, Percentage))

        seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, pro_ac, adv_ac, pro_atarg, adv_atarg, pro_tdlamret, adv_tdlamret = seg[
            "ob"], seg["pro_ac"], seg["adv_ac"], seg["pro_adv"], seg[
                "adv_adv"], seg["pro_tdlamret"], seg["adv_tdlamret"]
        rew = seg["rew"]
        ob_ = seg["ob_"]
        new = seg["new"]

        pro_vpredbefore = seg[
            "pro_vpred"]  # predicted value function before udpate
        adv_vpredbefore = seg["adv_vpred"]
        pro_atarg = (pro_atarg - pro_atarg.mean()) / pro_atarg.std(
        )  # standardized advantage function estimate
        adv_atarg = (adv_atarg - adv_atarg.mean()) / adv_atarg.std()

        # TODO
        # d = Dataset(dict(ob=ob, ac=pro_ac, atarg=pro_atarg, vtarg=pro_tdlamret), shuffle=not pro_pi.recurrent)
        d = Dataset(dict(ob=ob,
                         ob_=ob_,
                         rew=rew,
                         new=new,
                         ac=pro_ac,
                         adv=adv_ac,
                         atarg=pro_atarg,
                         vtarg=pro_tdlamret),
                    shuffle=not pro_pi.recurrent)  # load into the experience replay dataset
        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pro_pi, "ob_rms"):
            pro_pi.ob_rms.update(ob)  # update running mean/std for policy

        pro_assign_old_eq_new(
        )  # set old parameter values to new parameter values

        logger.log("Pro Optimizing...")
        logger.log(fmt_row(13, pro_loss_names))

        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            pro_losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                l_value = deepcopy(batch["atarg"])

                # [LN_R, LN_V, ob, ob_, pro_ac, LN_d_input, LR_A, LR_L]
                Lyapunov_opt(batch["rew"].reshape(-1,
                                                  1), l_value.reshape(-1, 1),
                             batch["new"], batch["ob"], batch["ob_"],
                             batch["ac"], batch["adv"], lr_a_this, lr_l_this)
                *newlosses, g = pro_lossandgrad(
                    batch["ob"], batch["ac"], batch["atarg"], batch["vtarg"],
                    cur_lrmult, batch["rew"].reshape(-1, 1),
                    l_value.reshape(-1, 1), batch["ob_"], batch["adv"],
                    lr_a_this, lr_l_this, lr_labda_this)

                pro_adam.update(g, optim_stepsize * cur_lrmult)
                pro_losses.append(newlosses)
            # logger.log(fmt_row(13, np.mean(pro_losses, axis=0)))

        # logger.log("Pro Evaluating losses...")
        pro_losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = pro_compute_losses(batch["ob"], batch["ac"],
                                           batch["atarg"], batch["vtarg"],
                                           cur_lrmult,
                                           batch["rew"].reshape(-1, 1),
                                           batch["atarg"].reshape(-1, 1),
                                           batch["ob_"], batch["adv"],
                                           lr_a_this, lr_l_this, lr_labda_this)
            pro_losses.append(newlosses)
        pro_meanlosses, _, _ = mpi_moments(pro_losses, axis=0)

        logger.log(fmt_row(13, pro_meanlosses))

        ac_ = sample_next_act(ob_, a_dim, pro_pi,
                              stochastic=True)  # ob, a_dim, policy, stochastic

        d = Dataset(dict(ob=ob,
                         adv_ac=adv_ac,
                         atarg=adv_atarg,
                         vtarg=adv_tdlamret,
                         ob_=ob_,
                         rew=rew,
                         ac_=ac_,
                         new=new,
                         pro_ac=pro_ac),
                    shuffle=not adv_pi.recurrent)

        if hasattr(adv_pi, "ob_rms"): adv_pi.ob_rms.update(ob)
        adv_assign_old_eq_new()

        # logger.log(fmt_row(13, adv_loss_names))
        logger.log("Adv Optimizing...")
        logger.log(fmt_row(13, adv_loss_names))
        for _ in range(optim_epochs):
            adv_losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                # adv_opt_input = [ob, adv_ac, adv_atarg, ret, lrmult, LN_R, ob_, LN_a_input_, LR_A, LR_L]
                # [ob, adv_ac, adv_atarg, ret, lrmult, LN_R, ob_, pro_ac, LN_a_input_, LR_A, LR_L]
                # ac_ = sample_next_act(batch["ob_"], a_dim, pro_pi, stochastic=True) # ob, a_dim, policy, stochastic
                *newlosses, g = adv_lossandgrad(
                    batch["ob"], batch["adv_ac"], batch["atarg"],
                    batch["vtarg"], cur_lrmult, batch["rew"].reshape(-1, 1),
                    batch["ob_"], batch["pro_ac"], batch["ac_"], lr_a_this,
                    lr_l_this, lr_labda_this)
                # *newlosses, g = adv_lossandgrad(batch["ob"], batch["adv_ac"], batch["atarg"], batch["vtarg"], cur_lrmult,
                #                                 batch["rew"].reshape(-1,1), batch["ob_"],
                #                                 batch["pro_ac"], batch["pro_ac"], lr_a_this, lr_l_this)
                adv_adam.update(g, optim_stepsize * cur_lrmult)
                adv_losses.append(newlosses)
            # logger.log(fmt_row(13, np.mean(adv_losses, axis=0)))
        # logger.log("Adv Evaluating losses...")
        adv_losses = []

        for batch in d.iterate_once(optim_batchsize):
            newlosses = adv_compute_losses(
                batch["ob"], batch["adv_ac"], batch["atarg"], batch["vtarg"],
                cur_lrmult, batch["rew"].reshape(-1, 1), batch["ob_"],
                batch["pro_ac"], batch["ac_"], lr_a_this, lr_l_this,
                lr_labda_this)
            adv_losses.append(newlosses)
        adv_meanlosses, _, _ = mpi_moments(adv_losses, axis=0)
        logger.log(fmt_row(13, adv_meanlosses))

        # curr_rew = evaluate(pro_pi, test_env)
        # rew_mean.append(curr_rew)
        # print(curr_rew)

        # logger.record_tabular("ev_tdlam_before", explained_variance(pro_vpredbefore, pro_tdlamret))
        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))

        # print(rews)

        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        cost = seg["ep_cost"]
        costbuffer.extend(cost)

        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpCostMean", np.mean(costbuffer))
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)

        stop_buffer.extend(lens)

        logger.record_tabular("stop_flag", np.mean(stop_buffer))
        logger.dump_tabular()

        # print(stop_buffer)
        print(lr_labda_this)

    print('version: use discarl v2-v5')
    print('info:', 'alpha3', alpha3, 'SCALE_lambda_MIN_MAX', lr_labda,
          'Finite_horizon', finite_horizon, 'adv_mag',
          env.adv_action_space.high, 'timesteps', max_timesteps)

    return pro_pi, np.mean(rewbuffer), timesteps_so_far, np.mean(lenbuffer)
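
# The function above calls build_l(...) to construct the Lyapunov critic whose
# decrease condition l_ - l + (alpha3 + 1) * LN_R - ita * ||d|| drives the
# pro_lyapunov_loss and adv_lyapunov_loss terms, and it collects the critic's
# parameters with tf.get_collection(..., scope='Lyapunov'); build_l itself is
# not included in this excerpt. Below is a minimal sketch of what a compatible
# build_l could look like (an assumption, not the original implementation): a
# small MLP over the concatenated state and action, built under a 'Lyapunov'
# variable scope so the collection lookup above can find its variables.
import tensorflow as tf

def build_l_sketch(s, a, s_dim, a_dim, reuse=None, custom_getter=None, hidden=64):
    # s_dim / a_dim are kept only to mirror the call sites above.
    # Reused copies (e.g. the EMA target built via custom_getter) are created as non-trainable.
    trainable = reuse is None
    with tf.variable_scope('Lyapunov', reuse=reuse, custom_getter=custom_getter):
        x = tf.concat([s, a], axis=1)                                          # (batch, s_dim + a_dim)
        h1 = tf.layers.dense(x, hidden, tf.nn.relu, trainable=trainable, name='l1')
        h2 = tf.layers.dense(h1, hidden, tf.nn.relu, trainable=trainable, name='l2')
        return tf.layers.dense(h2, 1, trainable=trainable, name='l_out')       # scalar L(s, a)
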
Exemplo n.º 17
0
    def update(self):
        global GLOBAL_UPDATE_COUNTER
        # seg_gen = traj_segment_generator(pi, env, horizon=timesteps_per_batch, stochastic=True)
        while not COORD.should_stop():
            if self.max_iters and self.iters_so_far >= self.max_iters:
                COORD.request_stop()
                break
            UPDATE_EVENT.wait()
            logger.log("********** Iteration %i ************" % iters_so_far)
        # saver.restore()


        seg = seg_gen.__next__()

        #TODO: I have to fix this function so it returns the right seg["adv"] and seg["tdlamret"]
        add_vtarg_and_adv(seg, gamma, lam)

        losses = [[] for i in range(len1)]
        meanlosses = [[] for i in range(len1)]
        for i in range(len1):
            ob[i], ac[i], atarg[i], tdlamret[i]= seg["ob"][i], seg["ac"][i], seg["adv"][i], seg["tdlamret"][i]
            # ob_extend = np.expand_dims(ob[i],axis=0)
            # ob[i] = ob_extend
            vpredbefore = seg["vpred"][i] # predicted value function before udpate
            atarg[i] = (atarg[i] - atarg[i].mean()) / atarg[i].std() # standardized advantage function estimate
            d= Dataset(dict(ob=ob[i], ac=ac[i], atarg=atarg[i], vtarg=tdlamret[i]), shuffle=not pi[i].recurrent)
            optim_batchsize = optim_batchsize or ob[i].shape[0]

            if hasattr(pi[i], "ob_rms"): pi[i].ob_rms.update(ob[i]) # update running mean/std for policy

            #TODO: make sure how assign_old_eq_new works and whether to assign it for each agent
            #Yes, I can confirm it works now

            # save network parameters using tf.train.Saver
            #     saver_name = "saver" + str(iters_so_far)

            U.function([], [], updates=[tf.assign(oldv, newv) for (oldv, newv) in zipsame(oldpi[i].get_variables(), pi[i].get_variables())])()
            # set old parameter values to new parameter values
            # Here we do a bunch of optimization epochs over the data
            logger.log("Optimizing the agent{}...".format(i))
            logger.log(fmt_row(13, loss_names))
            for _ in range(optim_epochs):
                losses[i] = [] # list of tuples, each of which gives the loss for a minibatch
                for batch in d.iterate_once(optim_batchsize):
                    # batch["ob"] = np.expand_dims(batch["ob"], axis=0)
                    *newlosses, g = lossandgrad[i](batch["ob"], batch["ac"], batch["atarg"], batch["vtarg"], cur_lrmult)
                    adam[i].update(g, optim_stepsize * cur_lrmult)
                    losses[i].append(newlosses)
                    logger.log(fmt_row(13, np.mean(losses[i], axis=0)))

            logger.log("Evaluating losses of agent{}...".format(i))
            losses[i] = []
            for batch in d.iterate_once(optim_batchsize):
                newlosses = compute_losses[i](batch["ob"], batch["ac"], batch["atarg"], batch["vtarg"], cur_lrmult)
                losses[i].append(newlosses)
            meanlosses[i], _, _ = mpi_moments(losses[i], axis=0)
            logger.log(fmt_row(13, meanlosses[i]))
            for (lossval, name) in zipsame(meanlosses[i], loss_names):
                logger.record_tabular("loss_" + name, lossval)
            logger.record_tabular("ev_tdlam_before{}".format(i), explained_variance(vpredbefore, tdlamret[i]))

            lrlocal = (seg["ep_lens"][i], seg["ep_rets"][i]) # local values
            listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal) # list of tuples
            lens, rews = map(flatten_lists, zip(*listoflrpairs))
            lenbuffer[i].extend(lens)
            rewbuffer[i].extend(rews)
            logger.record_tabular("EpLenMean {}".format(i), np.mean(lenbuffer[i]))
            logger.record_tabular("EpRewMean {}".format(i), np.mean(rewbuffer[i]))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        temp_pi = policy_func("temp_pi"+str(iters_so_far) , ob_space[0], ac_space[0], placeholder_name="temp_pi_observation" + str(iters_so_far))
        U.function([], [], updates=[tf.assign(oldv, newv) for (oldv, newv) in
                                    zipsame(temp_pi.get_variables(), pi[0].get_variables())])()
        parameters_savers.append(temp_pi)
        if iters_so_far % 3 == 0:
            sample_iteration = int(np.random.uniform(iters_so_far / 2, iters_so_far))
            print("now assign the {}th parameter of agent0 to agent1".format(sample_iteration))
            pi_restore = parameters_savers[sample_iteration]
            U.function([], [], updates=[tf.assign(oldv, newv) for (oldv, newv) in
                                    zipsame(pi[1].get_variables(), pi_restore.get_variables())])()

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.dump_tabular()


    # # now when the training iteration ends, save the trained model and test the win rate of the two.
    # pi0_variables = slim.get_variables(scope = "pi0")
    # pi1_variables = slim.get_variables(scope = "pi1")
    # parameters_to_save_list0 = [v for v in pi0_variables]
    # parameters_to_save_list1 = [v for v in pi1_variables]
    # parameters_to_save_list = parameters_to_save_list0 + parameters_to_save_list1
    # saver = tf.train.Saver(parameters_to_save_list)
    # parameters_path = 'parameter/'
    # tf.train.Saver()
    save_path = saver.save(U.get_session(), "saveparameter/35/35.pkl")
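Exemplo 17 stores a snapshot of agent 0's parameters every iteration (temp_pi) and, every third iteration, resets agent 1 to a snapshot drawn uniformly from the second half of training so far, a common opponent-sampling trick in self-play. A minimal sketch of that schedule with numpy arrays standing in for the TensorFlow variables; the names and the toy update are illustrative only.

import numpy as np

snapshots = []          # one parameter vector per iteration, like parameters_savers
agent0 = np.zeros(4)    # stand-in for pi[0]'s flattened parameters
agent1 = np.zeros(4)    # stand-in for pi[1]'s flattened parameters

for it in range(1, 13):
    agent0 = agent0 + 0.1 * np.random.randn(4)   # pretend training update
    snapshots.append(agent0.copy())              # like temp_pi above
    if it % 3 == 0:
        # sample an iteration index from [it / 2, it), as in the original code
        k = int(np.random.uniform(it / 2, it))
        agent1 = snapshots[k].copy()
        print("assign snapshot {} of agent0 to agent1".format(k))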
Exemplo n.º 18
0
def learn(env, policy_fn, *,
          timesteps_per_actorbatch,  # timesteps per actor per update
          clip_param, entcoeff,  # clipping parameter epsilon, entropy coeff
          optim_epochs, optim_stepsize, optim_batchsize,  # optimization hypers
          gamma, lam,  # advantage estimation
          max_timesteps=0, max_episodes=0, max_iters=0, max_seconds=0,  # time constraint
          callback=None,  # you can do anything in the callback, since it takes locals(), globals()
          adam_epsilon=1e-5,
          schedule='constant',  # annealing for stepsize parameters (epsilon and adam)
          gae_kstep=None,
          env_eval=None,
          saved_model=None,
          eval_at=50,
          save_at=50,
          normalize_atarg=True,
          experiment_spec=None,  # dict with: experiment_name, experiment_folder
          **extra_args
          ):
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_fn("pi", ob_space, ac_space)  # Construct network for new policy
    oldpi = policy_fn("oldpi", ob_space, ac_space)  # Network for old policy
    atarg = tf.placeholder(dtype=tf.float32, shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    lrmult = tf.placeholder(name='lrmult', dtype=tf.float32, shape=[])  # learning rate multiplier, updated with schedule
    entromult = tf.placeholder(name='entromult', dtype=tf.float32, shape=[])  # entropy penalty multiplier, updated with schedule
    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    pol_entpen = (-entcoeff) * meanent

    MPI_n_ranks = MPI.COMM_WORLD.Get_size()

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = tf.clip_by_value(ratio, 1.0 - clip_param, 1.0 + clip_param) * atarg  #
    pol_surr = - tf.reduce_mean(tf.minimum(surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)
    vf_loss = tf.reduce_mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult, entromult], losses + [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    assign_old_eq_new = U.function([], [], updates=[tf.assign(oldv, newv)
                                                    for (oldv, newv) in zipsame(oldpi.get_variables(), pi.get_variables())])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult, entromult], losses)

    acts = list()
    stats1 = list()
    stats2 = list()
    stats3 = list()
    for p in range(ac_space.shape[0]):
        acts.append(tf.placeholder(tf.float32, name="act_{}".format(p + 1)))
        stats1.append(tf.placeholder(tf.float32, name="stats1_{}".format(p + 1)))
        stats2.append(tf.placeholder(tf.float32, name="stats2_{}".format(p + 1)))
        stats3.append(tf.placeholder(tf.float32, name="stats3_{}".format(p + 1)))
        tf.summary.histogram("act_{}".format(p), acts[p])
        if pi.dist == 'gaussian':
            tf.summary.histogram("pd_mean_{}".format(p), stats1[p])
            tf.summary.histogram("pd_std_{}".format(p), stats2[p])
            tf.summary.histogram("pd_logstd_{}".format(p), stats3[p])
        else:
            tf.summary.histogram("pd_beta_{}".format(p), stats1[p])
            tf.summary.histogram("pd_alpha_{}".format(p), stats2[p])
            tf.summary.histogram("pd_alpha_beta_{}".format(p), stats3[p])

    rew = tf.placeholder(tf.float32, name="rew")
    tf.summary.histogram("rew", rew)
    summaries = tf.summary.merge_all()
    gather_summaries = U.function([ob, *acts, *stats1, *stats2, *stats3, rew], summaries)

    U.initialize()
    adam.sync()
    if saved_model is not None:
        U.load_state(saved_model)

    if (MPI.COMM_WORLD.Get_rank() == 0) & (experiment_spec is not None):
        # TensorBoard & Saver
        # ----------------------------------------
        if experiment_spec['experiment_folder'] is not None:
            path_tb = os.path.join(experiment_spec['experiment_folder'], 'tensorboard')
            path_logs = os.path.join(experiment_spec['experiment_folder'], 'logs')
            exp_name = '' if experiment_spec['experiment_name'] is None else experiment_spec['experiment_name']
            summary_file = tf.summary.FileWriter(os.path.join(path_tb, exp_name), U.get_session().graph)
            saver = tf.train.Saver(max_to_keep=None)
            logger.configure(dir=os.path.join(path_logs, exp_name))
    else:
        logger.configure(format_strs=[])

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi, env, timesteps_per_actorbatch, stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards

    assert sum([max_iters > 0, max_timesteps > 0, max_episodes > 0, max_seconds > 0]) == 1, "Only one time constraint permitted"

    while True:
        if callback:
            callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(iters_so_far) / max_iters, 0)
        elif 'exp' in schedule:
            current_lr = schedule.strip()
            _, d = current_lr.split('__')
            cur_lrmult = float(d) ** (float(iters_so_far) / max_iters)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % iters_so_far)

        seg = seg_gen.__next__()
        if gae_kstep is None:
            add_vtarg_and_adv(seg, gamma, lam)
            T = len(seg["rew"])
        else:
            calculate_advantage_and_vtarg(seg, gamma, lam, k_step=gae_kstep)
            T = len(seg["rew"]) - gae_kstep

        ob, ac, atarg, tdlamret = seg["ob"][:T], seg["ac"][:T], seg["adv"][:T], seg["tdlamret"][:T]
        vpredbefore = seg["vpred"][:T]  # predicted value function before update

        if normalize_atarg:
            eps = 1e-9
            atarg = (atarg - atarg.mean()) / (atarg.std() + eps)  # standardized advantage function estimate

        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret), shuffle=not pi.recurrent)
        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        assign_old_eq_new()  # set old parameter values to new parameter values
        logger.log("Optimizing...")
        logger.log(fmt_row(13, loss_names))
        # Here we do a bunch of optimization epochs over the data
        g_max = -np.Inf
        g_min = np.Inf
        g_mean = []
        for _ in range(optim_epochs):
            losses = []  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"], batch["atarg"], batch["vtarg"], cur_lrmult, cur_lrmult)
                if np.isnan(np.sum(g)):
                    print('NaN in Gradient, skipping this update')
                    continue
                g_max = g.max() if g.max() > g_max else g_max
                g_min = g.min() if g.min() < g_min else g_min
                g_mean.append(g.mean())
                adam.update(g, optim_stepsize * cur_lrmult)
                losses.append(newlosses)
                # logger.log(fmt_row(13, np.mean(losses, axis=0)))

        summary = tf.Summary()
        logger.log("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob"], batch["ac"], batch["atarg"], batch["vtarg"], cur_lrmult, cur_lrmult)
            losses.append(newlosses)
        meanlosses, _, _ = mpi_moments(losses, axis=0)
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.log(fmt_row(13, meanlosses))
            for (lossval, name) in zipsame(meanlosses, loss_names):
                logger.record_tabular("loss_" + name, lossval)
                summary.value.add(tag="loss_" + name, simple_value=lossval)

        logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))
        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("ItersSoFar (%)", iters_so_far / max_iters * 100)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        logger.record_tabular("TimePerIter", (time.time() - tstart) / (iters_so_far + 1))

        if MPI.COMM_WORLD.Get_rank() == 0:
            # Saves model
            if ((iters_so_far % save_at) == 0) & (iters_so_far != 0):
                if experiment_spec['experiment_folder'] is not None:
                    path_models = os.path.join(experiment_spec['experiment_folder'], 'models')
                    dir_path = os.path.join(path_models, exp_name)
                    if not os.path.exists(dir_path):
                        os.makedirs(dir_path)
                    saver.save(U.get_session(), os.path.join(dir_path, 'model'), global_step=iters_so_far)

        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.dump_tabular()

            summ = gather_summaries(ob,
                                    *np.split(ac, ac_space.shape[0], axis=1),
                                    *np.split(seg['stat1'], ac_space.shape[0], axis=1),
                                    *np.split(seg['stat2'], ac_space.shape[0], axis=1),
                                    *np.split(seg['stat3'], ac_space.shape[0], axis=1),
                                    seg['rew'])

            summary.value.add(tag="total_loss", simple_value=meanlosses[:3].sum())
            summary.value.add(tag="explained_variance", simple_value=explained_variance(vpredbefore, tdlamret))
            summary.value.add(tag='EpRewMean', simple_value=np.mean(rewbuffer))
            summary.value.add(tag='EpLenMean', simple_value=np.mean(lenbuffer))
            summary.value.add(tag='EpThisIter', simple_value=len(lens))
            summary.value.add(tag='atarg_max', simple_value=atarg.max())
            summary.value.add(tag='atarg_min', simple_value=atarg.min())
            summary.value.add(tag='atarg_mean', simple_value=atarg.mean())
            summary.value.add(tag='GMean', simple_value=np.mean(g_mean))
            summary.value.add(tag='GMax', simple_value=g_max)
            summary.value.add(tag='GMin', simple_value=g_min)
            summary.value.add(tag='learning_rate', simple_value=cur_lrmult * optim_stepsize)
            summary.value.add(tag='AcMAX', simple_value=np.mean(seg["ac"].max()))
            summary.value.add(tag='AcMIN', simple_value=np.mean(seg["ac"].min()))
            summary_file.add_summary(summary, iters_so_far)
            summary_file.add_summary(summ, iters_so_far)

        iters_so_far += 1
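Exemplo 18 extends the usual constant/linear step-size schedules with an exponential one encoded in the schedule string itself: for example 'exp__0.1' decays the multiplier from 1.0 towards 0.1 as iters_so_far approaches max_iters. A minimal sketch of that multiplier computation, assuming the same string convention; lr_multiplier is an illustrative name.

def lr_multiplier(schedule, iters_so_far, max_iters):
    # Returns the factor used to scale optim_stepsize (and the annealed clip range).
    if schedule == 'constant':
        return 1.0
    if schedule == 'linear':
        return max(1.0 - float(iters_so_far) / max_iters, 0.0)
    if 'exp' in schedule:
        _, decay = schedule.strip().split('__')   # e.g. 'exp__0.1' -> decay = '0.1'
        return float(decay) ** (float(iters_so_far) / max_iters)
    raise NotImplementedError(schedule)


print(lr_multiplier('exp__0.1', 50, 100))  # ~0.316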
Exemplo n.º 19
0
def learn(
    env,
    policy_fn,
    *,
    timesteps_per_actorbatch,  # timesteps per actor per update
    clip_param,
    entcoeff,  # clipping parameter epsilon, entropy coeff
    optim_epochs,
    optim_stepsize,
    optim_batchsize,  # optimization hypers
    gamma,
    lam,  # advantage estimation
    max_timesteps=0,
    max_episodes=0,
    max_iters=0,
    max_seconds=0,  # time constraint
    callback=None,  # you can do anything in the callback, since it takes locals(), globals()
    adam_epsilon=1e-5,
    schedule='constant'  # annealing for stepsize parameters (epsilon and adam)
):
    # to store (timestep, min_reward, max_reward, avg_reward) tuples from each
    # iteration
    graph_data = []

    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_fn("pi", ob_space,
                   ac_space)  # Construct network for new policy
    oldpi = policy_fn("oldpi", ob_space, ac_space)  # Network for old policy
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    # learning rate multiplier, updated according to the schedule
    lrmult = tf.placeholder(name='lrmult', dtype=tf.float32, shape=[])
    #clip_param = clip_param * lrmult # Annealed clipping parameter epsilon

    # retrieve ob placeholder
    ob = U.get_placeholder_cached(name="ob")
    # get a 2d placeholder for a list of 1d action
    ac = pi.pdtype.sample_placeholder([None])

    # get tensor for the KL-divergence between the old and new action gaussians
    kloldnew = oldpi.pd.kl(pi.pd)
    # get tensor for the entropy of the new action gaussian
    ent = pi.pd.entropy()
    # take the mean of all kl divergences and entropies
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = tf.clip_by_value(ratio, 1.0 - clip_param, 1.0 + clip_param) * atarg
    # take the average to get the expected value of the current batch
    pol_surr = -tf.reduce_mean(tf.minimum(
        surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)
    vf_loss = tf.reduce_mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult],
                             losses + [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)

    U.initialize()
    # TODO remove experimenting -
    name_correspondences = \
    {
        "fc1/kernel":"fc.0.weight",
        "fc1/bias":"fc.0.bias",
        "fc2/kernel":"fc.1.weight",
        "fc2/bias":"fc.1.bias",
        "final/kernel":"fc.2.weight",
        "final/bias":"fc.2.bias",
    }
    for tf_name, torch_name, out_var in [('pi/vf', 'value_net', pi.vpred),\
            ('pi/pol', 'pol_net', pi.mean)]:
        value_dict = {}

        prefix = tf_name

        for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, prefix):
            name = var.name[(len(prefix) + 1):-2]
            if "logstd" in name:
                continue
            value = tf.get_default_session().run(var)
            kernel = "kernel" in name
            if kernel:
                value = value.T
            value_dict[name_correspondences[name]] = value
        with open(torch_name + '_state_dict', \
            'wb+') as file:
            pickle.dump(value_dict, file)

        print(tf.get_default_session().run(out_var,\
            feed_dict={ob:np.array([[1.0, 2.0, 3.0, 4.0]])}))

    # - TODO remove experimenting

    adam.sync()

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_actorbatch,
                                     stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards

    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"

    result = 0

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % iters_so_far)

        seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)
        # TODO remove
        with open('acs', 'wb+') as file:
            pickle.dump(seg['ac'], file)
        # remove TODO

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"]

        assign_old_eq_new()  # set old parameter values to new parameter values
        # TODO remove
        *losses, g = lossandgrad(seg["ob"], seg["ac"], seg["adv"],
                                 seg["tdlamret"], cur_lrmult)
        ratio_res = tf.get_default_session().run(ratio, feed_dict=\
        {
            lossandgrad.inputs[0]: seg["ob"],
            lossandgrad.inputs[1]: seg["ac"],
            lossandgrad.inputs[2]: seg["adv"],
            lossandgrad.inputs[3]: seg["tdlamret"]
        }
        )
        surr_res = tf.get_default_session().run(surr1, feed_dict=\
        {
            lossandgrad.inputs[0]: seg["ob"],
            lossandgrad.inputs[1]: seg["ac"],
            lossandgrad.inputs[2]: seg["adv"],
            lossandgrad.inputs[3]: seg["tdlamret"]
        }
        )
        pred_vals = tf.get_default_session().run(pi.vpred, feed_dict=\
        {
            lossandgrad.inputs[0]: seg["ob"],
            lossandgrad.inputs[1]: seg["ac"],
            lossandgrad.inputs[2]: seg["adv"],
            lossandgrad.inputs[3]: seg["tdlamret"]
        })
        print(losses[0], losses[2])
        print(seg["tdlamret"])
        print(pred_vals)
        import sys
        sys.exit()
        # remove TODO

        vpredbefore = seg["vpred"]  # predicted value function before udpate
        #atarg = (atarg - atarg.mean()) / atarg.std() # standardized advantage function estimate
        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                    shuffle=not pi.recurrent)
        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy
        logger.log("Optimizing...")
        logger.log(fmt_row(13, loss_names))
        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"],
                                            batch["atarg"], batch["vtarg"],
                                            cur_lrmult)
                adam.update(g, optim_stepsize * cur_lrmult)
                #logger.log(fmt_row(13, newlosses))
                losses.append(newlosses)
            logger.log(fmt_row(13, np.mean(losses, axis=0)))

        logger.log("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob"], batch["ac"],
                                       batch["atarg"], batch["vtarg"],
                                       cur_lrmult)
            losses.append(newlosses)
        meanlosses, _, _ = mpi_moments(losses, axis=0)
        logger.log(fmt_row(13, meanlosses))
        for (lossval, name) in zipsame(meanlosses, loss_names):
            logger.record_tabular("loss_" + name, lossval)
        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))
        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        rewmean = np.mean(rewbuffer)
        rewmin = np.min(rewbuffer)
        rewmax = np.max(rewbuffer)
        timesteps_so_far += sum(lens)
        graph_data.append((timesteps_so_far, rewmin, rewmax, rewmean))
        result = rewmean
        logger.record_tabular("EpRewMean", rewmean)
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        iters_so_far += 1
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.dump_tabular()

        # TODO remove
        import sys
        sys.exit()

    return pi, result, graph_data
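The 'TODO remove experimenting' block in Exemplo 19 dumps the TensorFlow value and policy weights into PyTorch-style state dicts and transposes every dense kernel, because tf.layers stores kernels as [in_features, out_features] while torch.nn.Linear stores weights as [out_features, in_features]. A minimal numpy sketch of why the transpose preserves the layer's output; the shapes below are illustrative.

import numpy as np

x = np.random.randn(1, 4)            # one observation with 4 features
kernel_tf = np.random.randn(4, 64)   # TensorFlow dense kernel: [in, out]
bias = np.random.randn(64)

y_tf = x @ kernel_tf + bias          # TensorFlow convention: y = x W + b

weight_torch = kernel_tf.T           # torch.nn.Linear weight: [out, in]
y_torch = x @ weight_torch.T + bias  # Linear computes y = x W^T + b

assert np.allclose(y_tf, y_torch)    # same layer, two storage conventions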
Exemplo n.º 20
0
def learn(
        env,
        policy_fn,
        *,
        timesteps_per_actorbatch,  # timesteps per actor per update
        clip_param,
        entcoeff,  # clipping parameter epsilon, entropy coeff
        optim_epochs,
        optim_stepsize,
        optim_batchsize,  # optimization hypers
        gamma,
        lam,  # advantage estimation
        max_timesteps=0,
        max_episodes=0,
        max_iters=0,
        max_seconds=0,  # time constraint
        callback=None,  # you can do anything in the callback, since it takes locals(), globals()
        adam_epsilon=1e-5,
        schedule='constant',  # annealing for stepsize parameters (epsilon and adam)
        flight_log=None,
        restore_dir=None,
        ckpt_dir=None,
        save_timestep_period=1000):

    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_fn("pi", ob_space,
                   ac_space)  # Construct network for new policy
    oldpi = policy_fn("oldpi", ob_space, ac_space)  # Network for old policy
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule
    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = tf.clip_by_value(ratio, 1.0 - clip_param,
                             1.0 + clip_param) * atarg  #
    pol_surr = -tf.reduce_mean(tf.minimum(
        surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)
    vf_loss = tf.reduce_mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult],
                             losses + [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)

    saver = None
    if ckpt_dir:
        # Store for each one
        keep = int(max_timesteps / float(save_timestep_period))
        print("[INFO] Keeping ", keep, " checkpoints")
        saver = tf.train.Saver(save_relative_paths=True, max_to_keep=keep)

    U.initialize()
    adam.sync()

    if restore_dir:
        ckpt = tf.train.get_checkpoint_state(restore_dir)
        if ckpt:
            # If there is one that already exists then restore it
            print("Restoring model from ", ckpt.model_checkpoint_path)
            saver.restore(tf.get_default_session(), ckpt.model_checkpoint_path)
        else:
            print("Trying to restore model from ", restore_dir,
                  " but doesn't exist")

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_actorbatch,
                                     stochastic=True,
                                     flight_log=flight_log)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards

    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"
    next_ckpt_timestep = save_timestep_period
    while True:
        if callback: callback(locals(), globals())

        end = False
        if max_timesteps and timesteps_so_far >= max_timesteps:
            end = True
        elif max_episodes and episodes_so_far >= max_episodes:
            end = True
        elif max_iters and iters_so_far >= max_iters:
            end = True
        elif max_seconds and time.time() - tstart >= max_seconds:
            end = True

        # How often should we create checkpoints
        # Because of the iterations deployed in batches this might not happen exactly
        if saver and ((timesteps_so_far >= next_ckpt_timestep) or end):
            task_name = "ppo1-{}-{}.ckpt".format(env.spec.id, timesteps_so_far)
            fname = os.path.join(ckpt_dir, task_name)
            os.makedirs(os.path.dirname(fname), exist_ok=True)
            saver.save(tf.get_default_session(), fname)
            next_ckpt_timestep += save_timestep_period

        if end:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % iters_so_far)

        seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before udpate
        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate
        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                    shuffle=not pi.recurrent)
        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        assign_old_eq_new()  # set old parameter values to new parameter values
        logger.log("Optimizing...")
        logger.log(fmt_row(13, loss_names))
        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"],
                                            batch["atarg"], batch["vtarg"],
                                            cur_lrmult)
                adam.update(g, optim_stepsize * cur_lrmult)
                losses.append(newlosses)
            logger.log(fmt_row(13, np.mean(losses, axis=0)))

        logger.log("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob"], batch["ac"],
                                       batch["atarg"], batch["vtarg"],
                                       cur_lrmult)
            losses.append(newlosses)
        meanlosses, _, _ = mpi_moments(losses, axis=0)
        logger.log(fmt_row(13, meanlosses))
        for (lossval, name) in zipsame(meanlosses, loss_names):
            logger.record_tabular("loss_" + name, lossval)
        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))
        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.dump_tabular()
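Exemplo 20 adds periodic checkpointing on top of the standard loop: a checkpoint is written whenever the cumulative timestep count has passed the next multiple of save_timestep_period (and once more when training ends). A simplified sketch of that scheduling logic with a stubbed save function; run_with_checkpoints and save_fn are illustrative names.

def run_with_checkpoints(batch_sizes, save_timestep_period, save_fn):
    # batch_sizes: timesteps collected per iteration; save_fn stands in for saver.save.
    timesteps_so_far = 0
    next_ckpt_timestep = save_timestep_period
    for steps in batch_sizes:
        timesteps_so_far += steps
        if timesteps_so_far >= next_ckpt_timestep:
            save_fn(timesteps_so_far)
            next_ckpt_timestep += save_timestep_period


run_with_checkpoints([300, 400, 500, 350], 1000,
                     lambda t: print("checkpoint at", t, "timesteps"))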
Exemplo n.º 21
0
def learn(
        env,
        test_env,
        policy_fn,
        *,
        timesteps_per_actorbatch,  # timesteps per actor per update
        clip_param,
        entcoeff,  # clipping parameter epsilon, entropy coeff
        optim_epochs,
        optim_stepsize,
        optim_batchsize,  # optimization hypers
        gamma,
        lam,  # advantage estimation
        # CMAES
        max_fitness,  # has to be negative, as CMA-ES considers minimization
        popsize,
        gensize,
        bounds,
        sigma,
        eval_iters,
        max_v_train_iter,
        max_timesteps=0,
        max_episodes=0,
        max_iters=0,
        max_seconds=0,
        # time constraint
        callback=None,
        # you can do anything in the callback, since it takes locals(), globals()
        adam_epsilon=1e-5,
        schedule='constant',
        # annealing for stepsize parameters (epsilon and adam)
        seed,
        env_id):
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_fn("pi", ob_space,
                   ac_space)  # Construct network for new policy
    oldpi = policy_fn("oldpi", ob_space, ac_space)  # Network for old policy
    backup_pi = policy_fn(
        "backup_pi", ob_space, ac_space
    )  # Construct a network for every individual to adapt during the es evolution

    reward = tf.placeholder(dtype=tf.float32, shape=[None])  # step rewards
    pi_params = tf.placeholder(dtype=tf.float32, shape=[None])
    old_pi_params = tf.placeholder(dtype=tf.float32, shape=[None])
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule

    bound_coeff = tf.placeholder(
        name='bound_coeff', dtype=tf.float32,
        shape=[])  # bound coefficient

    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    next_ob = U.get_placeholder_cached(
        name="next_ob")  # next step observation for updating q function
    ac = U.get_placeholder_cached(
        name="act")  # action placeholder for computing q function
    mean_ac = U.get_placeholder_cached(
        name="mean_act")  # mean-action placeholder for computing the mean q function

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    pol_entpen = (-entcoeff) * meanent

    pi_adv = pi.qpred - pi.vpred
    adv_mean, adv_var = tf.nn.moments(pi_adv, axes=[0])
    normalized_pi_adv = (pi_adv - adv_mean) / tf.sqrt(adv_var)

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * normalized_pi_adv  # surrogate from conservative policy iteration
    surr2 = tf.clip_by_value(ratio, 1.0 - clip_param,
                             1.0 + clip_param) * normalized_pi_adv  #
    pol_surr = -tf.reduce_mean(tf.minimum(
        surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)

    qf_loss = tf.reduce_mean(
        tf.square(reward + gamma * pi.mean_qpred - pi.qpred))

    # qf_loss = tf.reduce_mean(U.huber_loss(pi.qpred - tf.stop_gradient(reward + gamma * pi.mean_qpred)))
    vf_loss = tf.reduce_mean(tf.square(pi.vpred - ret))
    qf_losses = [qf_loss]
    vf_losses = [vf_loss]

    # pol_loss = -tf.reduce_mean(pi_adv)
    pol_loss = pol_surr + pol_entpen
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    vf_var_list = [
        v for v in var_list if v.name.split("/")[1].startswith("vf")
    ]
    pol_var_list = [
        v for v in var_list if v.name.split("/")[1].startswith("pol")
    ]
    qf_var_list = [
        v for v in var_list if v.name.split("/")[1].startswith("qf")
    ]
    mean_qf_var_list = [
        v for v in var_list if v.name.split("/")[1].startswith("meanqf")
    ]

    vf_lossandgrad = U.function([ob, ac, atarg, ret, lrmult],
                                losses + [U.flatgrad(vf_loss, vf_var_list)])
    qf_lossandgrad = U.function(
        [ob, ac, next_ob, mean_ac, lrmult, reward, atarg],
        qf_losses + [U.flatgrad(qf_loss, qf_var_list)])

    qf_adam = MpiAdam(qf_var_list, epsilon=adam_epsilon)

    vf_adam = MpiAdam(vf_var_list, epsilon=adam_epsilon)

    assign_target_q_eq_eval_q = U.function(
        [], [],
        updates=[
            tf.assign(target_q, eval_q)
            for (target_q, eval_q) in zipsame(mean_qf_var_list, qf_var_list)
        ])

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])

    assign_backup_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(backup_v, newv) for (
                backup_v,
                newv) in zipsame(backup_pi.get_variables(), pi.get_variables())
        ])
    assign_new_eq_backup = U.function(
        [], [],
        updates=[
            tf.assign(newv, backup_v)
            for (newv, backup_v
                 ) in zipsame(pi.get_variables(), backup_pi.get_variables())
        ])
    # Compute all losses

    mean_pi_actions = U.function([ob], [pi.pd.mode()])
    compute_pol_losses = U.function([ob, ac, atarg, ret, lrmult],
                                    [pol_loss, pol_surr, pol_entpen, meankl])

    U.initialize()

    get_pi_flat_params = U.GetFlat(pol_var_list)
    set_pi_flat_params = U.SetFromFlat(pol_var_list)

    vf_adam.sync()

    global timesteps_so_far, episodes_so_far, iters_so_far, \
        tstart, lenbuffer, rewbuffer, tstart, ppo_timesteps_so_far, best_fitness

    episodes_so_far = 0
    timesteps_so_far = 0
    ppo_timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards

    eval_gen = traj_segment_generator_eval(pi,
                                           test_env,
                                           timesteps_per_actorbatch,
                                           stochastic=True)  # For evaluation
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_actorbatch,
                                     stochastic=True,
                                     eval_gen=eval_gen)  # For train V Func

    # Build generator for all solutions
    actors = []
    best_fitness = 0
    for i in range(popsize):
        newActor = traj_segment_generator(pi,
                                          env,
                                          timesteps_per_actorbatch,
                                          stochastic=True,
                                          eval_gen=eval_gen)
        actors.append(newActor)

    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"

    while True:
        if max_timesteps and timesteps_so_far >= max_timesteps:
            print("Max time steps")
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            print("Max episodes")
            break
        elif max_iters and iters_so_far >= max_iters:
            print("Max iterations")
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            print("Max time")
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)

        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % iters_so_far)

        # Generate new samples
        # Train V func
        for i in range(max_v_train_iter):
            logger.log("Iteration:" + str(iters_so_far) +
                       " - sub-train iter for V func:" + str(i))
            logger.log("Generate New Samples")
            seg = seg_gen.__next__()
            add_vtarg_and_adv(seg, gamma, lam)

            ob, ac, next_ob, atarg, reward, tdlamret, traj_idx = (
                seg["ob"], seg["ac"], seg["next_ob"], seg["adv"],
                seg["rew"], seg["tdlamret"], seg["traj_index"])
            atarg = (atarg - atarg.mean()) / atarg.std(
            )  # standardized advantage function estimate
            d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                        shuffle=not pi.recurrent)
            optim_batchsize = optim_batchsize or ob.shape[0]

            if hasattr(pi, "ob_rms"):
                pi.ob_rms.update(
                    ob)  # update running mean/std for normalization

            assign_old_eq_new(
            )  # set old parameter values to new parameter values
            # Train V function
            logger.log("Training V Func and Evaluating V Func Losses")
            for _ in range(optim_epochs):
                losses = [
                ]  # list of tuples, each of which gives the loss for a minibatch
                for batch in d.iterate_once(optim_batchsize):
                    *vf_losses, g = vf_lossandgrad(batch["ob"], batch["ac"],
                                                   batch["atarg"],
                                                   batch["vtarg"], cur_lrmult)
                    vf_adam.update(g, optim_stepsize * cur_lrmult)
                    losses.append(vf_losses)
                logger.log(fmt_row(13, np.mean(losses, axis=0)))

            d_q = Dataset(dict(ob=ob,
                               ac=ac,
                               next_ob=next_ob,
                               reward=reward,
                               atarg=atarg,
                               vtarg=tdlamret),
                          shuffle=not pi.recurrent)

            # Re-train q function
            logger.log("Training Q Func Evaluating Q Func Losses")
            for _ in range(optim_epochs):
                losses = [
                ]  # list of tuples, each of which gives the loss for a minibatch
                for batch in d_q.iterate_once(optim_batchsize):
                    *qf_losses, g = qf_lossandgrad(
                        batch["ob"], batch["ac"], batch["next_ob"],
                        mean_pi_actions(batch["ob"])[0], cur_lrmult,
                        batch["reward"], batch["atarg"])
                    qf_adam.update(g, optim_stepsize * cur_lrmult)
                    losses.append(qf_losses)
                logger.log(fmt_row(13, np.mean(losses, axis=0)))

            assign_target_q_eq_eval_q()

        # CMAES Train Policy
        assign_old_eq_new()  # set old parameter values to new parameter values
        assign_backup_eq_new()  # backup current policy
        flatten_weights = get_pi_flat_params()
        opt = cma.CMAOptions()
        opt['tolfun'] = max_fitness
        opt['popsize'] = popsize
        opt['maxiter'] = gensize
        opt['verb_disp'] = 0
        opt['verb_log'] = 0
        opt['seed'] = seed
        opt['AdaptSigma'] = True
        es = cma.CMAEvolutionStrategy(flatten_weights, sigma, opt)
        while True:
            if es.countiter >= gensize:
                logger.log("Max generations for current layer")
                break
            logger.log("Iteration:" + str(iters_so_far) +
                       " - sub-train Generation for Policy:" +
                       str(es.countiter))
            logger.log("Sigma=" + str(es.sigma))
            solutions = es.ask()
            costs = []
            lens = []

            assign_backup_eq_new()  # backup current policy

            for id, solution in enumerate(solutions):
                set_pi_flat_params(solution)
                losses = []
                cost = compute_pol_losses(ob, ac, atarg, tdlamret, cur_lrmult)
                costs.append(cost[0])
                assign_new_eq_backup()
            # Weights decay
            l2_decay = compute_weight_decay(0.99, solutions)
            costs += l2_decay
            costs, real_costs = fitness_rank(costs)
            es.tell_real_seg(solutions=solutions,
                             function_values=costs,
                             real_f=real_costs,
                             segs=None)
            best_solution = es.result[0]
            best_fitness = es.result[1]
            logger.log("Best Solution Fitness:" + str(best_fitness))
            set_pi_flat_params(best_solution)

        iters_so_far += 1
        episodes_so_far += sum(lens)
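Exemplo 21 trains the value and Q networks with gradient steps but updates the policy with CMA-ES: the flattened policy weights are the search point, each candidate is scored with compute_pol_losses on the current rollout, and the best solution is written back with set_pi_flat_params. A minimal ask/tell loop using the same cma package on a toy objective; the quadratic below is only a stand-in for the policy loss.

import cma
import numpy as np


def toy_policy_loss(flat_weights):
    # stand-in for compute_pol_losses on a fixed batch of rollout data
    return float(np.sum(np.square(flat_weights - 1.0)))


x0 = np.zeros(8)   # like get_pi_flat_params()
opts = {'popsize': 16, 'maxiter': 20, 'verb_disp': 0, 'verb_log': 0, 'seed': 1}
es = cma.CMAEvolutionStrategy(x0, 0.5, opts)

while not es.stop():
    solutions = es.ask()                              # candidate weight vectors
    costs = [toy_policy_loss(s) for s in solutions]   # CMA-ES minimizes
    es.tell(solutions, costs)

best_solution, best_fitness = es.result[0], es.result[1]
print(best_fitness)   # the original then calls set_pi_flat_params(best_solution)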
Exemplo n.º 22
0
def learn(
        env,
        policy_func,
        *,
        timesteps_per_batch,  # timesteps per actor per update
        clip_param,
        entcoeff,  # clipping parameter epsilon, entropy coeff
        optim_epochs,
        optim_stepsize,
        optim_batchsize,  # optimization hypers
        gamma,
        lam,  # advantage estimation
        max_timesteps=0,
        max_episodes=0,
        max_iters=0,
        max_seconds=0,  # time constraint
        callback=None,  # you can do anything in the callback, since it takes locals(), globals()
        adam_epsilon=1e-5,
        schedule='constant',  # annealing for stepsize parameters (epsilon and adam)
        init_policy_params=None,
        policy_scope='pi'):
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space

    pi = policy_func(policy_scope, ob_space,
                     ac_space)  # Construct network for new policy
    oldpi = policy_func("old" + policy_scope, ob_space,
                        ac_space)  # Network for old policy

    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return
    clip_tf = tf.placeholder(dtype=tf.float32)

    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)

    meanent = tf.reduce_mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = tf.clip_by_value(ratio, 1.0 - clip_tf,
                             1.0 + clip_tf * lrmult) * atarg
    pol_surr = -tf.reduce_mean(tf.minimum(
        surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)

    if hasattr(pi, 'additional_loss'):
        pol_surr += pi.additional_loss

    vf_loss = tf.reduce_mean(tf.square(pi.vpred - ret))

    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    pol_var_list = [
        v for v in pi.get_trainable_variables()
        if 'placehold' not in v.name and 'offset' not in v.name and 'secondary'
        not in v.name and 'vf' not in v.name and 'pol' in v.name
    ]
    pol_var_size = np.sum([np.prod(v.shape) for v in pol_var_list])

    get_pol_flat = U.GetFlat(pol_var_list)
    set_pol_from_flat = U.SetFromFlat(pol_var_list)

    total_loss = pol_surr + pol_entpen + vf_loss

    lossandgrad = U.function([ob, ac, atarg, ret, lrmult, clip_tf],
                             losses + [U.flatgrad(total_loss, var_list)])
    pol_lossandgrad = U.function([ob, ac, atarg, ret, lrmult, clip_tf],
                                 losses +
                                 [U.flatgrad(total_loss, pol_var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult, clip_tf], losses)
    compute_losses_cpo = U.function(
        [ob, ac, atarg, ret, lrmult, clip_tf],
        [tf.reduce_mean(surr1), pol_entpen, vf_loss, meankl, meanent])
    compute_ratios = U.function([ob, ac], ratio)
    compute_kls = U.function([ob, ac], kloldnew)

    compute_rollout_old_prob = U.function([ob, ac],
                                          tf.reduce_mean(oldpi.pd.logp(ac)))
    compute_rollout_new_prob = U.function([ob, ac],
                                          tf.reduce_mean(pi.pd.logp(ac)))
    compute_rollout_new_prob_min = U.function([ob, ac],
                                              tf.reduce_min(pi.pd.logp(ac)))

    update_ops = {}
    update_placeholders = {}
    for v in pi.get_trainable_variables():
        update_placeholders[v.name] = tf.placeholder(v.dtype,
                                                     shape=v.get_shape())
        update_ops[v.name] = v.assign(update_placeholders[v.name])

    # compute fisher information matrix
    dims = [int(np.prod(p.shape)) for p in pol_var_list]
    logprob_grad = U.flatgrad(tf.reduce_mean(pi.pd.logp(ac)), pol_var_list)
    compute_logprob_grad = U.function([ob, ac], logprob_grad)

    U.initialize()

    adam.sync()

    if init_policy_params is not None:
        cur_scope = pi.get_variables()[0].name[0:pi.get_variables()[0].name.
                                               find('/')]
        orig_scope = list(init_policy_params.keys()
                          )[0][0:list(init_policy_params.keys())[0].find('/')]
        print(cur_scope, orig_scope)
        for i in range(len(pi.get_variables())):
            if pi.get_variables()[i].name.replace(cur_scope, orig_scope,
                                                  1) in init_policy_params:
                assign_op = pi.get_variables()[i].assign(
                    init_policy_params[pi.get_variables()[i].name.replace(
                        cur_scope, orig_scope, 1)])
                tf.get_default_session().run(assign_op)
                assign_op = oldpi.get_variables()[i].assign(
                    init_policy_params[pi.get_variables()[i].name.replace(
                        cur_scope, orig_scope, 1)])
                tf.get_default_session().run(assign_op)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_batch,
                                     stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards

    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"

    prev_params = {}
    for v in var_list:
        if 'pol' in v.name:
            prev_params[v.name] = v.eval()

    optim_seg = None

    grad_scale = 1.0

    while True:
        if MPI.COMM_WORLD.Get_rank() == 0:
            print('begin')
            memory()
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % iters_so_far)

        seg = seg_gen.__next__()

        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update
        unstandardized_adv = np.copy(atarg)
        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate

        args = seg["ob"], seg["ac"], atarg
        fvpargs = [arr for arr in args]

        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        assign_old_eq_new()  # set old parameter values to new parameter values
        logger.log("Optimizing...")
        logger.log(fmt_row(13, loss_names))

        cur_clip_val = clip_param

        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                    shuffle=not pi.recurrent)

        for epoch in range(optim_epochs):
            losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"],
                                            batch["atarg"], batch["vtarg"],
                                            cur_lrmult, cur_clip_val)
                adam.update(g * grad_scale, optim_stepsize * cur_lrmult)
                losses.append(newlosses)

            logger.log(fmt_row(13, np.mean(losses, axis=0)))

        logger.log("Evaluating losses...")
        losses = []
        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                    shuffle=not pi.recurrent)
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob"], batch["ac"],
                                       batch["atarg"], batch["vtarg"],
                                       cur_lrmult, clip_param)
            losses.append(newlosses)
        meanlosses, _, _ = mpi_moments(losses, axis=0)

        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.log(fmt_row(13, meanlosses))
            for (lossval, name) in zipsame(meanlosses, loss_names):
                logger.record_tabular("loss_" + name, lossval)

        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.record_tabular("EpLenMean", np.mean(lenbuffer))
            logger.record_tabular("EpRewMean", np.mean(rewbuffer))
            logger.record_tabular("EpThisIter", len(lens))
            logger.record_tabular(
                "PolVariance",
                repr(adam.getflat()[-env.action_space.shape[0]:]))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.record_tabular("EpisodesSoFar", episodes_so_far)
            logger.record_tabular("TimestepsSoFar", timesteps_so_far)
            logger.record_tabular("TimeElapsed", time.time() - tstart)

        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.dump_tabular()
        if MPI.COMM_WORLD.Get_rank() == 0:
            print('end')
            memory()

    return pi, np.mean(rewbuffer)
def learn(
    env,
    policy_fn,
    timesteps_per_actorbatch,  # timesteps per actor per update
    clip_param,
    entcoeff,  # clipping parameter epsilon, entropy coeff
    optim_epochs,
    optim_stepsize,
    optim_batchsize,  # optimization hypers
    gamma,
    lam,  # advantage estimation
    max_timesteps=0,
    max_episodes=0,
    max_iters=0,
    max_seconds=0,  # time constraint
    callback=None,  # you can do anything in the callback, since it takes locals(), globals()
    adam_epsilon=1e-5,
    schedule='constant',  # annealing for stepsize parameters (epsilon and adam)
    resume_training=False,
):
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_fn("pi", ob_space,
                   ac_space)  # Construct network for new policy
    oldpi = policy_fn("oldpi", ob_space, ac_space)  # Network for old policy
    atarg = tf.placeholder(
        name="atarg", dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(name="ret", dtype=tf.float32,
                         shape=[None])  # Empirical return

    summ_writer = tf.summary.FileWriter("/tmp/tensorboard",
                                        U.get_session().graph)
    U.launch_tensorboard_in_background("/tmp/tensorboard")

    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule

    ob = U.get_placeholder_cached(name="ob")
    ob_2 = U.get_placeholder_cached(name="ob_2")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = tf.clip_by_value(ratio, 1.0 - clip_param,
                             1.0 + clip_param) * atarg  #
    pol_surr = -tf.reduce_mean(tf.minimum(
        surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)
    vf_loss = tf.reduce_mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    check_ops = tf.add_check_numerics_ops()
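    # tf.add_check_numerics_ops() groups a CheckNumerics op for every float
    # tensor in the graph; running it via updates=[check_ops] in lossandgrad
    # below makes the session raise an error if any tensor evaluated during
    # that call contains NaN or Inf.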

    var_list = pi.get_trainable_variables()
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    lossandgrad = U.function(
        [ob, ob_2, ac, atarg, ret, lrmult],
        losses + [U.flatgrad(total_loss, var_list)],
        updates=[check_ops],
    )
    debugnan = U.function([ob, ob_2, ac, atarg, ret, lrmult],
                          losses + [ratio, surr1, surr2])
    dbgnames = loss_names + ["ratio", "surr1", "surr2"]

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ob_2, ac, atarg, ret, lrmult], losses)

    U.initialize()
    adam.sync()

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_actorbatch,
                                     stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards

    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"

    if resume_training:
        pi.load_variables("/tmp/rlnav_model")
        oldpi.load_variables("/tmp/rlnav_model")
    else:
        # clear reward history log
        with open(rew_hist_filepath, 'w') as f:
            f.write('')

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % iters_so_far)

        seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ob_2, ac, atarg, tdlamret = seg["ob"], seg["ob_2"], seg["ac"], seg[
            "adv"], seg["tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update
        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate
        # flatten first two dimensions and put into batch maker
        dataset_dict = dict(ob=ob,
                            ob_2=ob_2,
                            ac=ac,
                            atarg=atarg,
                            vtarg=tdlamret)

        def flatten_horizon_and_agent_dims(array):
            """ using F order because we assume that the shape is (horizon, n_agents)
            and we want the new flattened first dimension to first run through
            horizon, then n_agents in order to not cut up the sequentiality of
            the experiences """
            new_shape = (array.shape[0] * array.shape[1], ) + array.shape[2:]
            return array.reshape(new_shape, order='F')

        for key in dataset_dict:
            dataset_dict[key] = flatten_horizon_and_agent_dims(
                dataset_dict[key])
        d = Dataset(dataset_dict, shuffle=not pi.recurrent)
        optim_batchsize = optim_batchsize or ob.shape[0]
        n_agents = ob.shape[1]

        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        #export rewards to log file
        rewplot = np.array(seg["rew"])
        with open(rew_hist_filepath, 'ab') as f:
            np.savetxt(f, rewplot, delimiter=',')

        assign_old_eq_new()  # set old parameter values to new parameter values
        logger.log("Optimizing...")
        logger.log(fmt_row(13, loss_names))
        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                landg = lossandgrad(batch["ob"], batch["ob_2"], batch["ac"],
                                    batch["atarg"], batch["vtarg"], cur_lrmult)
                newlosses, g = landg[:-1], landg[-1]
                # debug nans
                if np.any(np.isnan(newlosses)):
                    dbglosses = debugnan(batch["ob"], batch["ob_2"],
                                         batch["ac"], batch["atarg"],
                                         batch["vtarg"], cur_lrmult)
                    raise ValueError("Nan detected in losses: {} {}".format(
                        dbgnames, dbglosses))
                adam.update(g, optim_stepsize * cur_lrmult)
                losses.append(newlosses)
            logger.log(fmt_row(13, np.mean(losses, axis=0)))

        pi.save_variables("/tmp/rlnav_model")
        pi.load_variables("/tmp/rlnav_model")

        logger.log("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob"], batch["ob_2"], batch["ac"],
                                       batch["atarg"], batch["vtarg"],
                                       cur_lrmult)
            losses.append(newlosses)
        meanlosses, _, _ = mpi_moments(losses, axis=0)
        logger.log(fmt_row(13, meanlosses))
        for (lossval, name) in zipsame(meanlosses, loss_names):
            logger.record_tabular("loss_" + name, lossval)
        logger.record_tabular(
            "ev_tdlam_before",
            np.average([
                explained_variance(vpredbefore[:, i], tdlamret[:, i])
                for i in range(n_agents)
            ]))  # average of explained variance for each agent
        #         lrlocal = (seg["ep_lens"], seg["ep_rets"]) # local values
        #         listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal) # list of tuples
        #         lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lens, rews = (seg["ep_lens"], seg["ep_rets"])
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        tf.summary.scalar("EpLenMean", np.mean(lenbuffer))
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.dump_tabular()

    return pi
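
# A minimal, self-contained sketch (not part of the example above): it only
# illustrates the order='F' reshape that flatten_horizon_and_agent_dims relies
# on, using a hypothetical (horizon=3, n_agents=2) array. Fortran order keeps
# each agent's timesteps contiguous after flattening, which is the property the
# docstring asks for.
import numpy as np

demo = np.arange(6).reshape(3, 2)      # demo[t, agent] = [[0, 1], [2, 3], [4, 5]]
flat = demo.reshape(6, order='F')      # agent 0's steps first, then agent 1's
assert flat.tolist() == [0, 2, 4, 1, 3, 5]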
Example No. 24
0
def learn(
        env,
        agent,
        reward_giver,
        expert_dataset,
        g_step,
        d_step,
        d_stepsize=3e-4,
        timesteps_per_batch=1024,
        nb_train_steps=50,
        max_timesteps=0,
        max_iters=0,  # TODO: max_episodes
        callback=None,
        d_adam=None,
        sess=None,
        saver=None):

    nworkers = MPI.COMM_WORLD.Get_size()
    rank = MPI.COMM_WORLD.Get_rank()
    np.set_printoptions(precision=3)

    # Prepare for rollouts
    # ----------------------------------------
    timesteps_so_far = 0
    iters_so_far = 0

    assert sum([max_iters > 0, max_timesteps > 0]) == 1

    # TODO: implicit policy does not admit pretraining?

    # set up record
    policy_losses_record = {}
    discriminator_losses_record = {}

    while True:
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        #elif max_episodes and episodes_so_far >= max_episodes:
        #    break
        elif max_iters and iters_so_far >= max_iters:
            break

        logger.log("********** Iteration %i ************" % iters_so_far)
        logger.log("********** Steps %i ************" % timesteps_so_far)

        # ------------------ Update G ------------------
        logger.log("Optimizing Policy...")
        ob_policy, ac_policy, losses_record = train_one_batch(
            env, agent, reward_giver, timesteps_per_batch, nb_train_steps,
            g_step)
        assert len(ob_policy) == len(ac_policy) == timesteps_per_batch * g_step

        # ------------------ Update D ------------------
        logger.log("Optimizing Discriminator...")
        logger.log(fmt_row(13, reward_giver.loss_name))
        ob_expert, ac_expert = expert_dataset.get_next_batch(len(ob_policy))
        batch_size = len(ob_policy) // d_step
        d_losses = [
        ]  # list of tuples, each of which gives the loss for a minibatch
        for ob_batch, ac_batch in dataset.iterbatches(
            (ob_policy, ac_policy),
                include_final_partial_batch=False,
                batch_size=batch_size):
            ob_expert, ac_expert = expert_dataset.get_next_batch(len(ob_batch))
            # update running mean/std for reward_giver
            if hasattr(reward_giver, "obs_rms"):
                reward_giver.obs_rms.update(
                    np.concatenate((ob_batch, ob_expert), 0))
            *newlosses, g = reward_giver.lossandgrad(ob_batch, ac_batch,
                                                     ob_expert, ac_expert)
            d_adam.update(allmean(g, nworkers), d_stepsize)
            d_losses.append(newlosses)

        logger.log(fmt_row(13, np.mean(d_losses, axis=0)))
        timesteps_so_far += timesteps_per_batch * g_step
        iters_so_far += 1

        # record
        for k, v in losses_record.items():
            if k in policy_losses_record.keys():
                policy_losses_record[k] += v
            else:
                policy_losses_record[k] = v
        for idx, k in enumerate(reward_giver.loss_name):
            if k in discriminator_losses_record.keys():
                discriminator_losses_record[k] += [
                    np.mean(d_losses, axis=0)[idx]
                ]
            else:
                discriminator_losses_record[k] = [
                    np.mean(d_losses, axis=0)[idx]
                ]

        # logging
        logger.record_tabular("Epoch Actor Losses",
                              np.mean(losses_record['actor_loss']))
        logger.record_tabular("Epoch Critic Losses",
                              np.mean(losses_record['critic_loss']))
        logger.record_tabular("Epoch Classifier Losses",
                              np.mean(losses_record['classifier_loss']))
        logger.record_tabular("Epoch Entropy",
                              np.mean(losses_record['entropy']))
        if rank == 0:
            logger.dump_tabular()

        # Call callback
        if callback is not None:
            callback(locals(), globals())
Example No. 25
0
def compete_learn(
        env,
        policy_func,
        *,
        timesteps_per_batch,  # timesteps per actor per update
        clip_param,
        entcoeff,  # clipping parameter epsilon, entropy coeff
        optim_epochs,
        optim_stepsize,
        optim_batchsize,  # optimization hypers
        gamma,
        lam,  # advantage estimation
        max_timesteps=0,
        max_episodes=0,
        max_iters=20,
        max_seconds=0,  # time constraint
        callback=None,  # you can do anything in the callback, since it takes locals(), globals()
        adam_epsilon=1e-3,
        schedule='linear'  # annealing for stepsize parameters (epsilon and adam)
):
    # Setup losses and stuff
    # ----------------------------------------
    # at this stage, ob_space and ac_space are tuples of per-agent spaces
    #TODO: are these tuples right? Items in a tuple are not mutable
    #TODO: another way to store the two agents' states is to use with tf.variable_scope(scope, reuse=reuse):
    len1 = 2
    ob_space = env.observation_space.spaces
    ac_space = env.action_space.spaces
    pi = [
        policy_func("pi" + str(i),
                    ob_space[i],
                    ac_space[i],
                    placeholder_name="observation" + str(i))
        for i in range(len1)
    ]
    oldpi = [
        policy_func("oldpi" + str(i),
                    ob_space[i],
                    ac_space[i],
                    placeholder_name="observation" + str(i))
        for i in range(len1)
    ]
    atarg = [
        tf.placeholder(dtype=tf.float32, shape=[None]) for i in range(len1)
    ]
    ret = [tf.placeholder(dtype=tf.float32, shape=[None]) for i in range(len1)]
    tdlamret = [[] for i in range(len1)]
    # TODO: here I should revise lrmult back to what it was before
    # lrmult = 1.0 # here for simple I only use constant learning rate multiplier
    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule
    clip_param = clip_param * lrmult

    #TODO: I still cannot fully understand this point; originally it was:
    # ob=U.get_placeholder_cached(name="ob")
    #TODO: possible bug to fix: I think get_placeholder_cached is global, so an observation placeholder can only be cached once; the next time it finds that name it returns the previously cached placeholder. I don't know whether different name scopes have any effect on this.
    # ob1 = U.get_placeholder_cached(name="observation1") # Note: I am not sure about this point
    # # ob2 = U.get_placeholder_cached(name="observation2")
    # ob1 = U.get_placeholder_cached(name="observation0")  # Note: I am not sure about this point
    # ob2 = U.get_placeholder_cached(name="observation1")
    #TODO: the only remaining question is that the pi and oldpi networks both have an ob placeholder named "observation"; even in the original baselines implementation, do pi and oldpi share the observation placeholder? I think they do not.

    ob = [
        U.get_placeholder_cached(name="observation" + str(i))
        for i in range(len1)
    ]
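    # Assumption based on its use here: U.get_placeholder_cached returns the
    # placeholder previously registered under the given name, so pi[i] and
    # oldpi[i] (both built with placeholder_name="observation" + str(i)) end up
    # sharing one observation placeholder per agent.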
    # ac = tuple([pi[i].act(stochastic=True, observation=env.observation_space[i])[0]
    #      for i in range(len1)])
    # TODO: here for the policy to work I changed the observation parameter passed into the pi function to s which comes from env.reset()
    # s = env.reset()
    # ac = tuple([pi[i].act(stochastic=True, observation=s[i])[0]
    #             for i in range(len1)])

    ac = [pi[i].pdtype.sample_placeholder([None]) for i in range(len1)]
    kloldnew = [oldpi[i].pd.kl(pi[i].pd) for i in range(len1)]
    ent = [pi[i].pd.entropy() for i in range(len1)]
    print("ent1 and ent2 are {} and {}".format(ent[0], ent[1]))
    meankl = [U.mean(kloldnew[i]) for i in range(len1)]
    meanent = [U.mean(ent[i]) for i in range(len1)]

    pol_entpen = [(-entcoeff) * meanent[i] for i in range(len1)]
    ratio = [
        tf.exp(pi[i].pd.logp(ac[i]) - oldpi[i].pd.logp(ac[i]))
        for i in range(len1)
    ]
    # ratio = [tf.exp(pi[i].pd.logp(ac) - oldpi[i].pd.logp(ac[i])) for i in range(len1)] #pnew / pold
    surr1 = [ratio[i] * atarg[i] for i in range(len1)]
    # U.clip wraps tf.clip_by_value(t, clip_value_min, clip_value_max, name=None),
    # where t is a Tensor.
    surr2 = [
        U.clip(ratio[i], 1.0 - clip_param, 1.0 + clip_param) * atarg[i]
        for i in range(len1)
    ]  # clipped ratio times advantage (second term of PPO's pessimistic surrogate)
    pol_surr = [-U.mean(tf.minimum(surr1[i], surr2[i])) for i in range(len1)]
    vf_loss = [U.mean(tf.square(pi[i].vpred - ret[i])) for i in range(len1)]
    total_loss = [
        pol_surr[i] + pol_entpen[i] + vf_loss[i] for i in range(len1)
    ]
    # here I came to realize that the following miscellaneous losses are just per-agent operations, so they should
    # be made into lists to hold the info of the two agents
    # surr2 = U.clip(ratio[i], 1.0 - clip_param, 1.0 + clip_param)
    # pol_surr = -U.mean(tf.minimum(surr1[i], surr2[i]))
    # vf_loss = U.mean(tf.square(pi[i].vpred - ret[i]))
    # total_loss = pol_surr + pol_entpen + vf_loss

    #TODO: alternatively, I chose to revise the losses to the following:
    losses = [[pol_surr[i], pol_entpen[i], vf_loss[i], meankl[i], meanent[i]]
              for i in range(len1)]
    loss_names = ["pol_sur", "pol_entpen", "vf_loss", "kl", "ent"]
    var_list = [pi[i].get_trainable_variables() for i in range(len1)]

    lossandgrad = [
        U.function([ob[i], ac[i], atarg[i], ret[i], lrmult],
                   losses[i] + [U.flatgrad(total_loss[i], var_list[i])])
        for i in range(len1)
    ]
    adam = [MpiAdam(var_list[i], epsilon=adam_epsilon) for i in range(2)]

    #TODO: I suspect this cannot work as expected, because the result would be a list of U.function objects that will not execute automatically
    # assign_old_eq_new = [U.function([],[], updates=[tf.assign(oldv, newv)
    #     for (oldv, newv) in zipsame(oldpi[i].get_variables(), pi[i].get_variables())]) for i in range(len1)]

    # compute_losses is a function, so it should not be duplicated; instead the two agents'
    # parameters should be passed into it
    compute_losses = [
        U.function([ob[i], ac[i], atarg[i], ret[i], lrmult], losses[i])
        for i in range(len1)
    ]
    # sess = U.get_session()
    # writer = tf.summary.FileWriter(logdir='log-mlp',graph=sess.graph)
    # now when the training iteration ends, save the trained model and test the win rate of the two.
    pi0_variables = slim.get_variables(scope="pi0")
    pi1_variables = slim.get_variables(scope="pi1")
    parameters_to_save_list0 = [v for v in pi0_variables]
    parameters_to_save_list1 = [v for v in pi1_variables]
    parameters_to_save_list = parameters_to_save_list0 + parameters_to_save_list1
    saver = tf.train.Saver(parameters_to_save_list)
    restore = tf.train.Saver(parameters_to_save_list)
    U.initialize()
    restore.restore(U.get_session(), "parameter/500/500.pkl")
    U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(pi[1].get_variables(), pi[0].get_variables())
        ])()
    U.get_session().run  # no-op: the run method is referenced here but never called
    # [adam[i].sync() for i in range(2)]
    adam[0].sync()
    adam[1].sync()
    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     horizon=timesteps_per_batch,
                                     stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()

    lenbuffer = [deque(maxlen=100)
                 for i in range(len1)]  # rolling buffer for episode lengths
    rewbuffer = [deque(maxlen=100)
                 for i in range(len1)]  # rolling buffer for episode rewards

    parameters_savers = []
    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        # saver.restore()

        logger.log("********** Iteration %i ************" % iters_so_far)

        seg = seg_gen.__next__()
        #TODO: I have to fix this function so that it returns the right seg["adv"] and seg["tdlamret"]
        add_vtarg_and_adv(seg, gamma, lam)

        losses = [[] for i in range(len1)]
        meanlosses = [[] for i in range(len1)]
        for i in range(len1):
            ob[i], ac[i], atarg[i], tdlamret[i] = seg["ob"][i], seg["ac"][
                i], seg["adv"][i], seg["tdlamret"][i]
            # ob_extend = np.expand_dims(ob[i],axis=0)
            # ob[i] = ob_extend
            vpredbefore = seg["vpred"][
                i]  # predicted value function before update
            atarg[i] = (atarg[i] - atarg[i].mean()) / atarg[i].std(
            )  # standardized advantage function estimate
            d = Dataset(dict(ob=ob[i],
                             ac=ac[i],
                             atarg=atarg[i],
                             vtarg=tdlamret[i]),
                        shuffle=not pi[i].recurrent)
            optim_batchsize = optim_batchsize or ob[i].shape[0]

            if hasattr(pi[i], "ob_rms"):
                pi[i].ob_rms.update(
                    ob[i])  # update running mean/std for policy

            #TODO: I have to make sure how assign_old_eq_new works and whether to apply it for each agent
            #Yes, I can confirm now that it will work

            # save network parameters using tf.train.Saver
            #     saver_name = "saver" + str(iters_so_far)

            U.function([], [],
                       updates=[
                           tf.assign(oldv, newv) for (oldv, newv) in zipsame(
                               oldpi[i].get_variables(), pi[i].get_variables())
                       ])()
            # set old parameter values to new parameter values
            # Here we do a bunch of optimization epochs over the data
            logger.log("Optimizing the agent{}...".format(i))
            logger.log(fmt_row(13, loss_names))
            for _ in range(optim_epochs):
                losses[i] = [
                ]  # list of tuples, each of which gives the loss for a minibatch
                for batch in d.iterate_once(optim_batchsize):
                    # batch["ob"] = np.expand_dims(batch["ob"], axis=0)
                    *newlosses, g = lossandgrad[i](batch["ob"], batch["ac"],
                                                   batch["atarg"],
                                                   batch["vtarg"], cur_lrmult)
                    adam[i].update(g, optim_stepsize * cur_lrmult)
                    losses[i].append(newlosses)
                    logger.log(fmt_row(13, np.mean(losses[i], axis=0)))

            logger.log("Evaluating losses of agent{}...".format(i))
            losses[i] = []
            for batch in d.iterate_once(optim_batchsize):
                newlosses = compute_losses[i](batch["ob"], batch["ac"],
                                              batch["atarg"], batch["vtarg"],
                                              cur_lrmult)
                losses[i].append(newlosses)
            meanlosses[i], _, _ = mpi_moments(losses[i], axis=0)
            logger.log(fmt_row(13, meanlosses[i]))
            for (lossval, name) in zipsame(meanlosses[i], loss_names):
                logger.record_tabular("loss_" + name, lossval)
            logger.record_tabular("ev_tdlam_before{}".format(i),
                                  explained_variance(vpredbefore, tdlamret[i]))

            lrlocal = (seg["ep_lens"][i], seg["ep_rets"][i])  # local values
            listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
            lens, rews = map(flatten_lists, zip(*listoflrpairs))
            lenbuffer[i].extend(lens)
            rewbuffer[i].extend(rews)
            logger.record_tabular("EpLenMean {}".format(i),
                                  np.mean(lenbuffer[i]))
            logger.record_tabular("EpRewMean {}".format(i),
                                  np.mean(rewbuffer[i]))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        temp_pi = policy_func("temp_pi" + str(iters_so_far),
                              ob_space[0],
                              ac_space[0],
                              placeholder_name="temp_pi_observation" +
                              str(iters_so_far))
        U.function([], [],
                   updates=[
                       tf.assign(oldv, newv) for (oldv, newv) in zipsame(
                           temp_pi.get_variables(), pi[0].get_variables())
                   ])()
        parameters_savers.append(temp_pi)

        # now I think when the
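        # The block below is a simple self-play opponent-sampling scheme: every
        # 3rd iteration, agent 1's weights are overwritten with a snapshot of
        # agent 0 drawn uniformly from the second half of the snapshots saved
        # so far (parameters_savers), instead of always playing the latest
        # agent 0.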
        if iters_so_far % 3 == 0:
            sample_iteration = int(
                np.random.uniform(iters_so_far / 2, iters_so_far))
            print("now assign the {}th parameter of agent0 to agent1".format(
                sample_iteration))
            pi_restore = parameters_savers[sample_iteration]
            U.function([], [],
                       updates=[
                           tf.assign(oldv, newv)
                           for (oldv,
                                newv) in zipsame(pi[1].get_variables(),
                                                 pi_restore.get_variables())
                       ])()

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.dump_tabular()

    # # now when the training iteration ends, save the trained model and test the win rate of the two.
    # pi0_variables = slim.get_variables(scope = "pi0")
    # pi1_variables = slim.get_variables(scope = "pi1")
    # parameters_to_save_list0 = [v for v in pi0_variables]
    # parameters_to_save_list1 = [v for v in pi1_variables]
    # parameters_to_save_list = parameters_to_save_list0 + parameters_to_save_list1
    # saver = tf.train.Saver(parameters_to_save_list)
    # parameters_path = 'parameter/'
    # tf.train.Saver()
    save_path = saver.save(U.get_session(), "parameter/800/800.pkl")
Example No. 26
0
def learn(*,
          network,
          env,
          reward_giver,
          expert_dataset,
          g_step,
          d_step,
          d_stepsize=3e-4,
          total_timesteps,
          eval_env=None,
          seed=None,
          nsteps=2048,
          ent_coef=0.0,
          lr=3e-4,
          vf_coef=0.5,
          max_grad_norm=0.5,
          gamma=0.99,
          lam=0.95,
          log_interval=10,
          nminibatches=4,
          noptepochs=4,
          cliprange=0.2,
          save_interval=0,
          load_path=None,
          model_fn=None,
          update_fn=None,
          init_fn=None,
          mpi_rank_weight=1,
          comm=None,
          **network_kwargs):

    # from PPO learn
    set_global_seeds(seed)

    if isinstance(lr, float): lr = constfn(lr)
    else: assert callable(lr)
    if isinstance(cliprange, float): cliprange = constfn(cliprange)
    else: assert callable(cliprange)
    total_timesteps = int(total_timesteps)

    policy = build_policy(env, network, **network_kwargs)

    # nenvs = env.num_envs
    nenvs = 1

    ob_space = env.observation_space
    ac_space = env.action_space

    nbatch = nenvs * nsteps
    nbatch_train = nbatch // nminibatches
    is_mpi_root = (MPI is None or MPI.COMM_WORLD.Get_rank() == 0)

    if model_fn is None:
        from baselines.ppo2.model import Model
        model_fn = Model

    model = model_fn(policy=policy,
                     ob_space=ob_space,
                     ac_space=ac_space,
                     nbatch_act=nenvs,
                     nbatch_train=nbatch_train,
                     nsteps=nsteps,
                     ent_coef=ent_coef,
                     vf_coef=vf_coef,
                     max_grad_norm=max_grad_norm,
                     comm=comm,
                     mpi_rank_weight=mpi_rank_weight)

    if load_path is not None:
        model.load(load_path)

    runner = Runner(env=env,
                    model=model,
                    nsteps=nsteps,
                    gamma=gamma,
                    lam=lam,
                    reward_giver=reward_giver)
    if eval_env is not None:
        eval_runner = Runner(env=eval_env,
                             model=model,
                             nsteps=nsteps,
                             gamma=gamma,
                             lam=lam)

    epinfobuf = deque(maxlen=100)
    if eval_env is not None:
        eval_epinfobuf = deque(maxlen=100)

    if init_fn is not None:
        init_fn()

    tfirststart = time.perf_counter()

    nupdates = total_timesteps // nbatch

    # from TRPO MPI
    nworkers = MPI.COMM_WORLD.Get_size()

    ob = model.act_model.X
    ac = model.A

    d_adam = MpiAdam(reward_giver.get_trainable_variables())

    def allmean(x):
        assert isinstance(x, np.ndarray)
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
        return out

    # from PPO
    for update in range(1, nupdates + 1):
        assert nbatch % nminibatches == 0
        tstart = time.perf_counter()
        frac = 1.0 - (update - 1.0) / nupdates
        lrnow = lr(frac)
        cliprangenow = cliprange(frac)

        logger.log("Optimizing Policy...")
        for _ in range(g_step):
            if update % log_interval == 0 and is_mpi_root:
                logger.info('Stepping environment...')

            obs, returns, masks, actions, values, neglogpacs, states, epinfos = runner.run(
            )
            if eval_env is not None:
                eval_obs, eval_returns, eval_masks, eval_actions, eval_values, eval_neglogpacs, eval_states, eval_epinfos = eval_runner.run(
                )

            if update % log_interval == 0 and is_mpi_root: logger.info('Done.')

            epinfobuf.extend(epinfos)
            if eval_env is not None:
                eval_epinfobuf.extend(eval_epinfos)

            mblossvals = []
            if states is None:
                inds = np.arange(nbatch)
                for _ in range(noptepochs):
                    np.random.shuffle(inds)
                    for start in range(0, nbatch, nbatch_train):
                        end = start + nbatch_train
                        mbinds = inds[start:end]
                        slices = (arr[mbinds]
                                  for arr in (obs, returns, masks, actions,
                                              values, neglogpacs))
                        mblossvals.append(
                            model.train(lrnow, cliprangenow, *slices))
            else:
                assert False  # make sure we're not going here, so any bugs aren't from here
                assert nenvs % nminibatches == 0
                envsperbatch = nenvs // nminibatches
                envinds = np.arange(nenvs)
                flatinds = np.arange(nenvs * nsteps).reshape(nenvs, nsteps)
                for _ in range(noptepochs):
                    np.random.shuffle(envinds)
                    for start in range(0, nenvs, envsperbatch):
                        end = start + envsperbatch
                        mbenvinds = envinds[start:end]
                        mbflatinds = flatinds[mbenvinds].ravel()
                        slices = (arr[mbflatinds]
                                  for arr in (obs, returns, masks, actions,
                                              values, neglogpacs))
                        mbstates = states[mbenvinds]
                        mblossvals.append(
                            model.train(lrnow, cliprangenow, *slices,
                                        mbstates))

        lossvals = np.mean(mblossvals, axis=0)
        tnow = time.perf_counter()
        fps = int(nbatch / (tnow - tstart))

        # TRPO MPI
        logger.log("Optimizing Disciminator...")
        logger.log(fmt_row(13, reward_giver.loss_name))
        ob_expert, ac_expert = expert_dataset.get_next_batch(len(obs))
        batch_size = len(obs) // d_step
        d_losses = []
        for ob_batch, ac_batch in dataset.iterbatches(
            (obs, actions),
                include_final_partial_batch=False,
                batch_size=batch_size):
            ob_expert, ac_expert = expert_dataset.get_next_batch(len(ob_batch))
            if hasattr(reward_giver, "obs_rms"):
                reward_giver.obs_rms.update(
                    np.concatenate((ob_batch, ob_expert), 0))
            *newlosses, g = reward_giver.lossandgrad(ob_batch, ac_batch,
                                                     ob_expert, ac_expert)
            d_adam.update(allmean(g), d_stepsize)
            d_losses.append(newlosses)

        logger.log(fmt_row(13, np.mean(d_losses, axis=0)))

        if update_fn is not None:
            update_fn(update)

        if update % log_interval == 0 or update == 1:
            ev = explained_variance(values, returns)
            logger.logkv("misc/serial_timesteps", update * nsteps)
            logger.logkv("misc/nupdates", update)
            logger.logkv("misc/total_timesteps", update * nbatch)
            logger.logkv("fps", fps)
            logger.logkv("misc/explained_variance", float(ev))
            logger.logkv("eprewmean",
                         safemean([epinfo['r'] for epinfo in epinfobuf]))
            logger.logkv("eplenmean",
                         safemean([epinfo['l'] for epinfo in epinfobuf]))
            if eval_env is not None:
                logger.logkv(
                    "eval_eprewmean",
                    safemean([epinfo['r'] for epinfo in eval_epinfobuf]))
                logger.logkv(
                    "eval_eplenmean",
                    safemean([epinfo['l'] for epinfo in eval_epinfobuf]))
            logger.logkv("misc/time_elapsed", tnow - tfirststart)
            for (lossval, lossname) in zip(lossvals, model.loss_names):
                logger.logkv("loss/" + lossname, lossval)

            logger.dumpkvs()

        if save_interval and (update % save_interval == 0 or update
                              == 1) and logger.get_dir() and is_mpi_root:
            checkdir = osp.join(logger.get_dir(), 'checkpoints')
            os.makedirs(checkdir, exist_ok=True)
            savepath = osp.join(checkdir, '%.5i' % update)
            print("Saving to", savepath)
            model.save(savepath)

    return model
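
# A minimal numpy sketch (not taken from any example above) of the clipped
# surrogate objective L^CLIP that these learn() functions build in TensorFlow:
# the per-sample objective is min(ratio * adv, clip(ratio, 1-eps, 1+eps) * adv)
# and the loss is its negated mean. The numbers below are made up purely for
# illustration.
import numpy as np

eps = 0.2
ratio = np.array([0.7, 1.0, 1.5])   # pi_new / pi_old per sample
adv = np.array([1.0, -2.0, 0.5])    # standardized advantages
surr1 = ratio * adv
surr2 = np.clip(ratio, 1.0 - eps, 1.0 + eps) * adv
pol_surr = -np.mean(np.minimum(surr1, surr2))
print(pol_surr)  # ~0.233: the third sample's gain is capped at 1.2 * 0.5 by the clip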
Example No. 27
0
def learn(env, policy_func,
        timesteps_per_batch, # timesteps per actor per update
        clip_param, entcoeff, # clipping parameter epsilon, entropy coeff
        optim_epochs, optim_stepsize, optim_batchsize,# optimization hypers
        gamma, lam, # advantage estimation
        max_timesteps=0, max_episodes=0, max_iters=0, max_seconds=0,  # time constraint
        callback=None, # you can do anything in the callback, since it takes locals(), globals()
        adam_epsilon=1e-5,
        schedule='constant', # annealing for stepsize parameters (epsilon and adam)
        desired_kl=None
        ):
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    nenvs = env.num_envs
    pi = policy_func("pi", ob_space, ac_space) # Construct network for new policy
    oldpi = policy_func("oldpi", ob_space, ac_space) # Network for old policy
    atarg = tf.placeholder(dtype=tf.float32, shape=[None]) # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None]) # Empirical return

    lrmult = tf.placeholder(name='lrmult', dtype=tf.float32, shape=[]) # learning rate multiplier, updated with schedule
    clip_param = clip_param * lrmult # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = U.mean(kloldnew)
    meanent = U.mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac)) # pnew / pold
    surr1 = ratio * atarg # surrogate from conservative policy iteration
    surr2 = U.clip(ratio, 1.0 - clip_param, 1.0 + clip_param) * atarg #
    pol_surr = - U.mean(tf.minimum(surr1, surr2)) # PPO's pessimistic surrogate (L^CLIP)
    vf_loss = U.mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult], losses + [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    assign_old_eq_new = U.function([],[], updates=[tf.assign(oldv, newv)
        for (oldv, newv) in zipsame(oldpi.get_variables(), pi.get_variables())])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)

    U.initialize()
    adam.sync()

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator_vecenv(pi, env, timesteps_per_batch, stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100) # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100) # rolling buffer for episode rewards

    assert sum([max_iters>0, max_timesteps>0, max_episodes>0, max_seconds>0])==1, "Only one time constraint permitted"

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult =  max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        elif schedule == 'adapt':
            cur_lrmult = 1.0
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************"%iters_so_far)

        seg = seg_gen.next()
        add_vtarg_and_adv(seg, gamma, lam, nenvs)


        # merge all of the envs
        for k in seg.keys():
            if k != "ep_rets" and k != "ep_lens" and k != "nextvpred":
                seg[k] = np.concatenate(np.asarray(seg[k]), axis=0)
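        # Presumably each per-timestep entry of seg holds one array per parallel
        # env; concatenating along axis 0 merges them into a single flat batch,
        # while the per-episode stats and the bootstrap value stay per-env.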

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]
        vpredbefore = seg["vpred"] # predicted value function before update
        atarg = (atarg - atarg.mean()) / atarg.std() # standardized advantage function estimate
        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret), shuffle=not pi.recurrent)
        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"): pi.ob_rms.update(ob) # update running mean/std for policy

        assign_old_eq_new() # set old parameter values to new parameter values
        logger.log("Optimizing...")
        logger.log(fmt_row(13, loss_names))
        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            losses = [] # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                #*newlosses, g = lossandgrad(batch["ob"], batch["ac"], batch["atarg"], batch["vtarg"], cur_lrmult)
                lossandgradout = lossandgrad(batch["ob"], batch["ac"], batch["atarg"], batch["vtarg"], cur_lrmult)
                newlosses, g = lossandgradout[:-1], lossandgradout[-1]
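                # Adaptive-KL stepsize rule: newlosses[-2] is the mean KL (the
                # second-to-last entry of loss_names). If it exceeds twice the
                # target, the stepsize is shrunk by 1.5x; if it falls below half
                # the target, it is grown by 1.5x, clamped to [1e-8, 1].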
                if desired_kl is not None and schedule == 'adapt':
                    if newlosses[-2] > desired_kl * 2:
                        optim_stepsize = max(1e-8, optim_stepsize / 1.5)
                    elif newlosses[-2] < desired_kl / 2:
                        optim_stepsize = min(1e0, optim_stepsize * 1.5)
                adam.update(g, optim_stepsize * cur_lrmult)
                losses.append(newlosses)
            logger.log(fmt_row(13, np.mean(losses, axis=0)))

        logger.log("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob"], batch["ac"], batch["atarg"], batch["vtarg"], cur_lrmult)
            losses.append(newlosses)
        meanlosses,_,_ = mpi_moments(losses, axis=0)
        logger.log(fmt_row(13, meanlosses))
        for (lossval, name) in zipsame(meanlosses, loss_names):
            logger.record_tabular("loss_"+name, lossval)
        logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))
        lrlocal = (seg["ep_lens"], seg["ep_rets"]) # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal) # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        if MPI.COMM_WORLD.Get_rank()==0:
            logger.dump_tabular()
def learn(
        env,
        seed,
        policy_fn,
        *,
        timesteps_per_actorbatch,  # timesteps per actor per update
        clip_param,
        entcoeff,  # clipping parameter epsilon, entropy coeff
        optim_epochs,
        optim_stepsize,
        optim_batchsize,  # optimization hypers
        gamma,
        lam,  # advantage estimation
        save_model_with_prefix,  # Save the model
        save_prefix,
        restore_model_from_file,  # Load the states/model from this file.
        load_after_iters,
        save_after,
        max_timesteps=0,
        max_episodes=0,
        max_iters=0,
        max_seconds=0,  # time constraint
        callback=None,  # you can do anything in the callback, since it takes locals(), globals()
        adam_epsilon=1e-5,
        schedule='constant',  # annealing for stepsize parameters (epsilon and adam)
        stochastic=True):
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space

    g = tf.get_default_graph()
    with g.as_default():
        tf.set_random_seed(seed)

    pi = policy_fn("pi", ob_space,
                   ac_space)  # Construct network for new policy
    oldpi = policy_fn("oldpi", ob_space, ac_space)  # Network for old policy

    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule
    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = tf.clip_by_value(ratio, 1.0 - clip_param,
                             1.0 + clip_param) * atarg  #
    pol_surr = -tf.reduce_mean(tf.minimum(
        surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)

    vf_loss = tf.reduce_mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult],
                             losses + [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)

    U.initialize()
    adam.sync()

    if restore_model_from_file:
        saver = tf.train.Saver()
        basePath = os.path.dirname(os.path.abspath(__file__))
        modelF = basePath + '/' + save_prefix + "_afterIter_" + str(
            load_after_iters) + '.model'
        saver.restore(tf.get_default_session(), modelF)
        logger.log("Loaded model from {}".format(modelF))

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_actorbatch,
                                     stochastic=stochastic)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards
    truerewbuffer = deque(maxlen=100)

    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % iters_so_far)

        seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"]

        vpredbefore = seg["vpred"]  # predicted value function before update
        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate
        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                    shuffle=not pi.recurrent)
        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        assign_old_eq_new()  # set old parameter values to new parameter values
        logger.log("Optimizing...")
        logger.log(fmt_row(13, loss_names))

        for _ in range(optim_epochs):
            losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"],
                                            batch["atarg"], batch["vtarg"],
                                            cur_lrmult)
                adam.update(g, optim_stepsize * cur_lrmult)
                losses.append(newlosses)
            logger.log(fmt_row(13, np.mean(losses, axis=0)))

        logger.log("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob"], batch["ac"],
                                       batch["atarg"], batch["vtarg"],
                                       cur_lrmult)
            losses.append(newlosses)
        meanlosses = np.mean(losses, axis=0)
        logger.log(fmt_row(13, meanlosses))
        for (lossval, name) in zipsame(meanlosses, loss_names):
            logger.record_tabular("loss_" + name, lossval)

        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))
        lrlocal = (seg["ep_lens"], seg["ep_rets"], seg["ep_true_rets"]
                   )  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews, truerews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        truerewbuffer.extend(truerews)
        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpTrueRewMean", np.mean(truerewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)

        iters_so_far += 1
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)

        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.dump_tabular()

        if iters_so_far % save_after == 0:
            if save_model_with_prefix:
                basePath = os.path.dirname(os.path.abspath(__file__))
                modelF = basePath + '/' + save_prefix + "_afterIter_" + str(
                    iters_so_far) + ".model"
                U.save_state(modelF)
                logger.log("Saved model to file :{}".format(modelF))

    return pi
Example No. 29
0
def learn(env,
          policy_func,
          reward_giver,
          expert_dataset,
          rank,
          pretrained,
          pretrained_weight,
          *,
          g_step,
          d_step,
          entcoeff,
          save_per_iter,
          ckpt_dir,
          log_dir,
          timesteps_per_batch,
          task_name,
          gamma,
          lam,
          max_kl,
          cg_iters,
          cg_damping=1e-2,
          vf_stepsize=3e-4,
          d_stepsize=3e-4,
          vf_iters=3,
          max_timesteps=0,
          max_episodes=0,
          max_iters=0,
          callback=None):

    nworkers = MPI.COMM_WORLD.Get_size()
    rank = MPI.COMM_WORLD.Get_rank()
    np.set_printoptions(precision=3)
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_func("pi",
                     ob_space,
                     ac_space,
                     reuse=(pretrained_weight is not None))
    oldpi = policy_func("oldpi", ob_space, ac_space)
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    entbonus = entcoeff * meanent

    vferr = tf.reduce_mean(tf.square(pi.vpred - ret))

    ratio = tf.exp(pi.pd.logp(ac) -
                   oldpi.pd.logp(ac))  # advantage * pnew / pold
    surrgain = tf.reduce_mean(ratio * atarg)

    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

    dist = meankl

    all_var_list = pi.get_trainable_variables()
    var_list = [
        v for v in all_var_list
        if v.name.startswith("pi/pol") or v.name.startswith("pi/logstd")
    ]
    vf_var_list = [v for v in all_var_list if v.name.startswith("pi/vff")]
    assert len(var_list) == len(vf_var_list) + 1
    d_adam = MpiAdam(reward_giver.get_trainable_variables())
    vfadam = MpiAdam(vf_var_list)

    get_flat = U.GetFlat(var_list)
    set_from_flat = U.SetFromFlat(var_list)
    klgrads = tf.gradients(dist, var_list)
    flat_tangent = tf.placeholder(dtype=tf.float32,
                                  shape=[None],
                                  name="flat_tan")
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start + sz], shape))
        start += sz
    gvp = tf.add_n([
        tf.reduce_sum(g * tangent)
        for (g, tangent) in zipsame(klgrads, tangents)
    ])  # pylint: disable=E1111
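    # fvp is the Fisher-vector product: differentiating gvp = (dKL/dtheta) . tangent
    # with respect to theta yields H * tangent without ever forming the KL Hessian H.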
    fvp = U.flatgrad(gvp, var_list)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg], losses)
    compute_lossandgrad = U.function([ob, ac, atarg], losses +
                                     [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    compute_vflossandgrad = U.function([ob, ret],
                                       U.flatgrad(vferr, vf_var_list))

    @contextmanager
    def timed(msg):
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(
                colorize("done in %.3f seconds" % (time.time() - tstart),
                         color='magenta'))
        else:
            yield

    def allmean(x):
        assert isinstance(x, np.ndarray)
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
        return out

    U.initialize()
    th_init = get_flat()
    MPI.COMM_WORLD.Bcast(th_init, root=0)
    set_from_flat(th_init)
    d_adam.sync()
    vfadam.sync()
    if rank == 0:
        print("Init param sum", th_init.sum(), flush=True)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     reward_giver,
                                     timesteps_per_batch,
                                     stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards
    true_rewbuffer = deque(maxlen=40)

    assert sum([max_iters > 0, max_timesteps > 0, max_episodes > 0]) == 1

    g_loss_stats = stats(loss_names)
    d_loss_stats = stats(reward_giver.loss_name)
    ep_stats = stats(["True_rewards", "Rewards", "Episode_length"])
    # if a pretrained weight is provided
    if pretrained_weight is not None:
        U.load_state(pretrained_weight, var_list=pi.get_variables())

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break

        # Save model
        if rank == 0 and iters_so_far % save_per_iter == 0 and ckpt_dir is not None:
            fname = os.path.join(ckpt_dir, task_name)
            os.makedirs(os.path.dirname(fname), exist_ok=True)
            saver = tf.train.Saver()
            saver.save(tf.get_default_session(), fname)

        logger.log("********** Iteration %i ************" % iters_so_far)

        def fisher_vector_product(p):
            return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p
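        # The cg_damping * p term adds a small multiple of the identity to the Fisher
        # matrix, which keeps the conjugate-gradient solve below well conditioned.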

        # ------------------ Update G ------------------
        logger.log("Optimizing Policy...")
        for _ in range(g_step):
            with timed("sampling"):
                seg = seg_gen.__next__()
            add_vtarg_and_adv(seg, gamma, lam)
            # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
            ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
                "tdlamret"]
            vpredbefore = seg[
                "vpred"]  # predicted value function before udpate
            atarg = (atarg - atarg.mean()) / atarg.std(
            )  # standardized advantage function estimate

            if hasattr(pi, "ob_rms"):
                pi.ob_rms.update(ob)  # update running mean/std for policy

            args = seg["ob"], seg["ac"], atarg
            fvpargs = [arr[::5] for arr in args]

            assign_old_eq_new(
            )  # set old parameter values to new parameter values
            with timed("computegrad"):
                *lossbefore, g = compute_lossandgrad(*args)
            lossbefore = allmean(np.array(lossbefore))
            g = allmean(g)
            if np.allclose(g, 0):
                logger.log("Got zero gradient. not updating")
            else:
                with timed("cg"):
                    stepdir = cg(fisher_vector_product,
                                 g,
                                 cg_iters=cg_iters,
                                 verbose=rank == 0)
                assert np.isfinite(stepdir).all()
                shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
                lm = np.sqrt(shs / max_kl)
                # logger.log("lagrange multiplier:", lm, "gnorm:", np.linalg.norm(g))
                fullstep = stepdir / lm
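                # Scaling by 1/lm makes the quadratic KL model of the step hit max_kl:
                # 0.5 * fullstep^T H fullstep = shs / lm^2 = max_kl.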
                expectedimprove = g.dot(fullstep)
                surrbefore = lossbefore[0]
                stepsize = 1.0
                thbefore = get_flat()
                for _ in range(10):
                    thnew = thbefore + fullstep * stepsize
                    set_from_flat(thnew)
                    meanlosses = surr, kl, *_ = allmean(
                        np.array(compute_losses(*args)))
                    improve = surr - surrbefore
                    logger.log("Expected: %.3f Actual: %.3f" %
                               (expectedimprove, improve))
                    if not np.isfinite(meanlosses).all():
                        logger.log("Got non-finite value of losses -- bad!")
                    elif kl > max_kl * 1.5:
                        logger.log("violated KL constraint. shrinking step.")
                    elif improve < 0:
                        logger.log("surrogate didn't improve. shrinking step.")
                    else:
                        logger.log("Stepsize OK!")
                        break
                    stepsize *= .5
                else:
                    logger.log("couldn't compute a good step")
                    set_from_flat(thbefore)
                if nworkers > 1 and iters_so_far % 20 == 0:
                    paramsums = MPI.COMM_WORLD.allgather(
                        (thnew.sum(),
                         vfadam.getflat().sum()))  # list of tuples
                    assert all(
                        np.allclose(ps, paramsums[0]) for ps in paramsums[1:])
            with timed("vf"):
                for _ in range(vf_iters):
                    for (mbob, mbret) in dataset.iterbatches(
                        (seg["ob"], seg["tdlamret"]),
                            include_final_partial_batch=False,
                            batch_size=128):
                        if hasattr(pi, "ob_rms"):
                            pi.ob_rms.update(
                                mbob)  # update running mean/std for policy
                        g = allmean(compute_vflossandgrad(mbob, mbret))
                        vfadam.update(g, vf_stepsize)

        g_losses = meanlosses
        for (lossname, lossval) in zip(loss_names, meanlosses):
            logger.record_tabular(lossname, lossval)
        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))
        # ------------------ Update D ------------------
        logger.log("Optimizing Discriminator...")
        logger.log(fmt_row(13, reward_giver.loss_name))
        ob_expert, ac_expert = expert_dataset.get_next_batch(len(ob))
        batch_size = len(ob) // d_step
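        # len(ob) // d_step splits this iteration's generator batch into d_step
        # discriminator minibatches.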
        d_losses = [
        ]  # list of tuples, each of which gives the loss for a minibatch
        for ob_batch, ac_batch in dataset.iterbatches(
            (ob, ac), include_final_partial_batch=False,
                batch_size=batch_size):
            ob_expert, ac_expert = expert_dataset.get_next_batch(len(ob_batch))
            # update running mean/std for reward_giver
            if hasattr(reward_giver, "obs_rms"):
                reward_giver.obs_rms.update(
                    np.concatenate((ob_batch, ob_expert), 0))
            *newlosses, g = reward_giver.lossandgrad(ob_batch, ac_batch,
                                                     ob_expert, ac_expert)
            d_adam.update(allmean(g), d_stepsize)
            d_losses.append(newlosses)
        logger.log(fmt_row(13, np.mean(d_losses, axis=0)))

        lrlocal = (seg["ep_lens"], seg["ep_rets"], seg["ep_true_rets"]
                   )  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews, true_rets = map(flatten_lists, zip(*listoflrpairs))
        true_rewbuffer.extend(true_rets)
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpTrueRewMean", np.mean(true_rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)

        if rank == 0:
            logger.dump_tabular()
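
The policy update above solves H x = g with a conjugate-gradient routine cg that only needs the Fisher-vector product defined in fisher_vector_product. A minimal sketch, assuming it mirrors the baselines.common.cg helper:

import numpy as np

def cg(f_Ax, b, cg_iters=10, verbose=False, residual_tol=1e-10):
    # Solve A x = b for a symmetric positive-definite A, given only the
    # matrix-vector product f_Ax (here: the damped Fisher-vector product).
    # verbose is accepted for call-site compatibility but unused in this sketch.
    x = np.zeros_like(b)
    r = b.copy()      # residual b - A x (x starts at zero)
    p = b.copy()      # current search direction
    rdotr = r.dot(r)
    for _ in range(cg_iters):
        z = f_Ax(p)
        v = rdotr / p.dot(z)
        x += v * p
        r -= v * z
        newrdotr = r.dot(r)
        p = r + (newrdotr / rdotr) * p
        rdotr = newrdotr
        if rdotr < residual_tol:
            break
    return x
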
Example No. 30
0
def learn(
        env,
        policy_func,
        *,
        timesteps_per_batch,  # timesteps per actor per update
        clip_param,
        entcoeff,  # clipping parameter epsilon, entropy coeff
        optim_epochs,
        optim_stepsize,
        optim_batchsize,  # optimization hypers
        gamma,
        lam,  # advantage estimation
        max_timesteps=0,
        max_episodes=0,
        max_iters=0,
        max_seconds=0,  # time constraint
        callback=None,  # you can do anything in the callback, since it takes locals(), globals()
        adam_epsilon=1e-5,
        dropout_on_V,
        dropout_tau_V=0.05,
        lengthscale_V=0.0015,
        schedule='constant',  # annealing for stepsize parameters (epsilon and adam)
        override_reg=None):
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_func("pi", ob_space,
                     ac_space)  # Construct network for new policy
    oldpi = policy_func("oldpi", ob_space, ac_space)  # Network for old policy
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32,
                         shape=[None])  # Empirical return

    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule
    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = U.mean(kloldnew)
    meanent = U.mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = U.clip(ratio, 1.0 - clip_param, 1.0 + clip_param) * atarg  #
    pol_surr = -U.mean(tf.minimum(
        surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)

    ### MAIN CHANGES
    ### Fitting V
    if dropout_on_V:

        ## TUNING PARAMETERS
        alpha = 0.5

        x = ret
        flat = pi.vpred_dropout_networks
        flat_stacked = tf.stack(flat)  # K x M x outsize
        # squared error summed over the output dimension -> shape K x M
        sumsq = U.sum(tf.square(x - flat_stacked), -1)
        sumsq *= (-.5 * alpha * dropout_tau_V)
        vf_loss = (-1.0 * alpha**-1.0) * logsumexp(sumsq, 0)
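        # The log-sum-exp over the K dropout networks above acts as a temperature-
        # weighted soft minimum of the per-network squared errors (an alpha-divergence
        # style objective), rather than the plain mean-squared error used below.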

        if override_reg is not None:
            critic_l2_reg = override_reg
        else:
            critic_l2_reg = lengthscale_V**2.0 * (pi.V_keep_prob) / (
                2.0 * float(np.prod(ob_space.shape[0]) * dropout_tau_V))
        critic_reg_vars = [
            x for x in pi.get_trainable_variables()
            if 'value_function' in x.name
        ]

        critic_reg = tc.layers.apply_regularization(
            tc.layers.l2_regularizer(critic_l2_reg),  # use the computed L2 coefficient
            weights_list=critic_reg_vars)
        vf_loss += critic_reg
    else:
        vf_loss = U.mean(tf.square(pi.vpred - ret))
        if override_reg is not None:
            critic_l2_reg = override_reg
            critic_reg_vars = [
                x for x in pi.get_trainable_variables()
                if 'value_function' in x.name
            ]

            critic_reg = tc.layers.apply_regularization(
                tc.layers.l2_regularizer(critic_l2_reg),  # use the computed L2 coefficient
                weights_list=critic_reg_vars)
            vf_loss += critic_reg

    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult],
                             losses + [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)

    U.initialize()
    adam.sync()

    # Prepare for rollouts
    # ----------------------------------------

    timesteps_so_far = 0

    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_batch,
                                     stochastic=True)

    episodes_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards

    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % iters_so_far)

        seg = seg_gen.__next__()

        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before udpate
        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate
        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                    shuffle=not pi.recurrent)
        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        assign_old_eq_new()  # set old parameter values to new parameter values
        logger.log("Optimizing...")
        logger.log(fmt_row(13, loss_names))
        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"],
                                            batch["atarg"], batch["vtarg"],
                                            cur_lrmult)
                adam.update(g, optim_stepsize * cur_lrmult)
                losses.append(newlosses)
            logger.log(fmt_row(13, np.mean(losses, axis=0)))

        update_dropout_masks(
            [x for x in pi.get_variables() if 'dropout' in x.name],
            pi.V_keep_prob)
        assign_old_eq_new()

        logger.log("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob"], batch["ac"],
                                       batch["atarg"], batch["vtarg"],
                                       cur_lrmult)
            losses.append(newlosses)
        meanlosses, _, _ = mpi_moments(losses, axis=0)
        logger.log(fmt_row(13, meanlosses))
        for (lossval, name) in zipsame(meanlosses, loss_names):
            logger.record_tabular("loss_" + name, lossval)
        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))
        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.dump_tabular()
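
For intuition, here is a small NumPy analogue (toy shapes, hypothetical helper name) of the dropout value-function loss built in the example above; the log-sum-exp turns the per-dropout-mask squared errors into a temperature-weighted soft minimum rather than a plain average.

import numpy as np
from scipy.special import logsumexp

def dropout_vf_loss(vpreds, returns, alpha=0.5, tau=0.05):
    # vpreds: (K, B) value predictions from K dropout masks; returns: (B,) empirical returns.
    sumsq = -0.5 * alpha * tau * np.square(returns - vpreds)   # (K, B)
    per_state = -(1.0 / alpha) * logsumexp(sumsq, axis=0)      # soft-min over the K masks
    return per_state.mean()                                    # average over the batch

vpreds = np.random.randn(10, 256)   # K=10 dropout masks, B=256 states (toy data)
returns = np.random.randn(256)
print(dropout_vf_loss(vpreds, returns))
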
Example No. 31
0
def learn(env, policy_fn, *,
        timesteps_per_actorbatch, # timesteps per actor per update
        clip_param, entcoeff, # clipping parameter epsilon, entropy coeff
        optim_epochs, optim_stepsize, optim_batchsize,# optimization hypers
        gamma, lam, # advantage estimation
        max_timesteps=0, max_episodes=0, max_iters=0, max_seconds=0,  # time constraint
        callback=None, # you can do anything in the callback, since it takes locals(), globals()
        adam_epsilon=1e-5,
        schedule='constant' # annealing for stepsize parameters (epsilon and adam)
        ):
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_fn("pi", ob_space, ac_space) # Construct network for new policy
    oldpi = policy_fn("oldpi", ob_space, ac_space) # Network for old policy
    atarg = tf.placeholder(dtype=tf.float32, shape=[None]) # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None]) # Empirical return

    lrmult = tf.placeholder(name='lrmult', dtype=tf.float32, shape=[]) # learning rate multiplier, updated with schedule
    clip_param = clip_param * lrmult # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac)) # pnew / pold
    surr1 = ratio * atarg # surrogate from conservative policy iteration
    surr2 = tf.clip_by_value(ratio, 1.0 - clip_param, 1.0 + clip_param) * atarg #
    pol_surr = - tf.reduce_mean(tf.minimum(surr1, surr2)) # PPO's pessimistic surrogate (L^CLIP)
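    # i.e. the negated clipped objective
    # L^CLIP(theta) = E_t[ min(r_t(theta) * A_t, clip(r_t(theta), 1 - eps, 1 + eps) * A_t) ]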
    vf_loss = tf.reduce_mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult], losses + [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    assign_old_eq_new = U.function([],[], updates=[tf.assign(oldv, newv)
        for (oldv, newv) in zipsame(oldpi.get_variables(), pi.get_variables())])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)

    U.initialize()
    adam.sync()

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi, env, timesteps_per_actorbatch, stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100) # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100) # rolling buffer for episode rewards

    assert sum([max_iters>0, max_timesteps>0, max_episodes>0, max_seconds>0])==1, "Only one time constraint permitted"

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult =  max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************"%iters_so_far)

        seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]
        vpredbefore = seg["vpred"] # predicted value function before udpate
        atarg = (atarg - atarg.mean()) / atarg.std() # standardized advantage function estimate
        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret), shuffle=not pi.recurrent)
        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"): pi.ob_rms.update(ob) # update running mean/std for policy

        assign_old_eq_new() # set old parameter values to new parameter values
        logger.log("Optimizing...")
        logger.log(fmt_row(13, loss_names))
        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            losses = [] # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"], batch["atarg"], batch["vtarg"], cur_lrmult)
                adam.update(g, optim_stepsize * cur_lrmult)
                losses.append(newlosses)
            logger.log(fmt_row(13, np.mean(losses, axis=0)))

        logger.log("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob"], batch["ac"], batch["atarg"], batch["vtarg"], cur_lrmult)
            losses.append(newlosses)
        meanlosses,_,_ = mpi_moments(losses, axis=0)
        logger.log(fmt_row(13, meanlosses))
        for (lossval, name) in zipsame(meanlosses, loss_names):
            logger.record_tabular("loss_"+name, lossval)
        logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))
        lrlocal = (seg["ep_lens"], seg["ep_rets"]) # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal) # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        if MPI.COMM_WORLD.Get_rank()==0:
            logger.dump_tabular()
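
Every example here calls add_vtarg_and_adv to attach advantages and value targets to a sampled segment. A minimal sketch, assuming it matches the GAE(lambda) helper from OpenAI Baselines (seg["nextvpred"] is taken to be the value prediction for the state following the segment):

import numpy as np

def add_vtarg_and_adv(seg, gamma, lam):
    # Generalized Advantage Estimation: adv_t = sum_l (gamma*lam)^l * delta_{t+l},
    # with delta_t = r_t + gamma * V(s_{t+1}) - V(s_t), truncated at episode boundaries.
    new = np.append(seg["new"], 0)                   # 1 where a new episode begins
    vpred = np.append(seg["vpred"], seg["nextvpred"])
    T = len(seg["rew"])
    seg["adv"] = gaelam = np.empty(T, "float32")
    lastgaelam = 0
    for t in reversed(range(T)):
        nonterminal = 1 - new[t + 1]
        delta = seg["rew"][t] + gamma * vpred[t + 1] * nonterminal - vpred[t]
        gaelam[t] = lastgaelam = delta + gamma * lam * nonterminal * lastgaelam
    seg["tdlamret"] = seg["adv"] + seg["vpred"]      # TD(lambda) return, the vf target
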
Example No. 32
0
def learn(
        env,
        policy_func,
        *,
        timesteps_per_batch,  # timesteps per actor per update
        clip_param,
        entcoeff,  # clipping parameter epsilon, entropy coeff
        optim_epochs,
        optim_stepsize,
        optim_batchsize,  # optimization hypers
        gamma,
        lam,  # advantage estimation
        max_timesteps=0,
        max_episodes=0,
        max_iters=0,
        max_seconds=0,  # time constraint
        callback=None,  # you can do anything in the callback, since it takes locals(), globals()
        adam_epsilon=1e-5,
        schedule='constant',  # annealing for stepsize parameters (epsilon and adam)
        save_per_iter=100,
        ckpt_dir=None,
        task="train",
        sample_stochastic=True,
        load_model_path=None,
        task_name=None,
        max_sample_traj=1500):
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_func("pi", ob_space,
                     ac_space)  # Construct network for new policy
    oldpi = policy_func("oldpi", ob_space, ac_space)  # Network for old policy
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule
    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])
    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = U.mean(kloldnew)
    meanent = U.mean(ent)
    pol_entpen = (-entcoeff) * meanent
    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = U.clip(ratio, 1.0 - clip_param, 1.0 + clip_param) * atarg  #
    pol_surr = -U.mean(tf.minimum(
        surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)
    vf_loss = U.mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult],
                             losses + [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)

    U.initialize()
    adam.sync()

    # Prepare for rollouts
    # ----------------------------------------
    mode = 'stochastic' if sample_stochastic else 'deterministic'
    logger.log("Using %s policy" % mode)
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_batch,
                                     stochastic=sample_stochastic)
    traj_gen = traj_episode_generator(pi,
                                      env,
                                      timesteps_per_batch,
                                      stochastic=sample_stochastic)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    sample_trajs = []
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards
    # if a model path is provided
    if load_model_path is not None:
        U.load_state(load_model_path)

    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        if task == "sample_trajectory":
            logger.log("********** Iteration %i ************" % iters_so_far)
            traj = traj_gen.__next__()
            ob, new, ep_ret, ac, rew, ep_len = traj['ob'], traj['new'], traj[
                'ep_ret'], traj['ac'], traj['rew'], traj['ep_len']
            logger.record_tabular("ep_ret", ep_ret)
            logger.record_tabular("ep_len", ep_len)
            logger.record_tabular("immediate reward", np.mean(rew))
            if MPI.COMM_WORLD.Get_rank() == 0:
                logger.dump_tabular()
            traj_data = {"ob": ob, "ac": ac, "rew": rew, "ep_ret": ep_ret}
            sample_trajs.append(traj_data)
            if iters_so_far > max_sample_traj:
                sample_ep_rets = [traj["ep_ret"] for traj in sample_trajs]
                logger.log("Average total return: %f" %
                           (sum(sample_ep_rets) / len(sample_ep_rets)))
                if sample_stochastic:
                    task_name = 'stochastic.' + task_name
                else:
                    task_name = 'deterministic.' + task_name
                pkl.dump(sample_trajs, open(task_name + ".pkl", "wb"))
                break
            iters_so_far += 1
        elif task == "train":

            # Save model
            if iters_so_far % save_per_iter == 0 and ckpt_dir is not None:
                U.save_state(os.path.join(ckpt_dir, task_name),
                             counter=iters_so_far)

            logger.log("********** Iteration %i ************" % iters_so_far)

            seg = seg_gen.__next__()
            add_vtarg_and_adv(seg, gamma, lam)

            # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
            ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
                "tdlamret"]
            vpredbefore = seg[
                "vpred"]  # predicted value function before udpate
            atarg = (atarg - atarg.mean()) / atarg.std(
            )  # standardized advantage function estimate

            d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                        shuffle=not pi.recurrent)
            optim_batchsize = optim_batchsize or ob.shape[0]

            if hasattr(pi, "ob_rms"):
                pi.ob_rms.update(ob)  # update running mean/std for policy

            assign_old_eq_new(
            )  # set old parameter values to new parameter values
            logger.log("Optimizing...")
            logger.log(fmt_row(13, loss_names))
            # Here we do a bunch of optimization epochs over the data
            for _ in range(optim_epochs):
                losses = [
                ]  # list of tuples, each of which gives the loss for a minibatch
                for batch in d.iterate_once(optim_batchsize):
                    *newlosses, g = lossandgrad(batch["ob"], batch["ac"],
                                                batch["atarg"], batch["vtarg"],
                                                cur_lrmult)
                    adam.update(g, optim_stepsize * cur_lrmult)
                    losses.append(newlosses)
                logger.log(fmt_row(13, np.mean(losses, axis=0)))

            logger.log("Evaluating losses...")
            losses = []
            for batch in d.iterate_once(optim_batchsize):
                newlosses = compute_losses(batch["ob"], batch["ac"],
                                           batch["atarg"], batch["vtarg"],
                                           cur_lrmult)
                losses.append(newlosses)
            meanlosses, _, _ = mpi_moments(losses, axis=0)
            logger.log(fmt_row(13, meanlosses))
            for (lossval, name) in zipsame(meanlosses, loss_names):
                logger.record_tabular("loss_" + name, lossval)
            logger.record_tabular("ev_tdlam_before",
                                  explained_variance(vpredbefore, tdlamret))
            lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
            listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
            lens, rews = map(flatten_lists, zip(*listoflrpairs))
            lenbuffer.extend(lens)
            rewbuffer.extend(rews)
            logger.record_tabular("EpLenMean", np.mean(lenbuffer))
            logger.record_tabular("EpRewMean", np.mean(rewbuffer))
            logger.record_tabular("EpThisIter", len(lens))
            episodes_so_far += len(lens)
            timesteps_so_far += sum(lens)
            iters_so_far += 1
            logger.record_tabular("EpisodesSoFar", episodes_so_far)
            logger.record_tabular("TimestepsSoFar", timesteps_so_far)
            logger.record_tabular("TimeElapsed", time.time() - tstart)
            if MPI.COMM_WORLD.Get_rank() == 0:
                logger.dump_tabular()
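
When run with task="sample_trajectory", the example above pickles the collected trajectories to "<stochastic|deterministic>.<task_name>.pkl". A short, hypothetical consumer of that file (the file name below is an assumption):

import pickle as pkl
import numpy as np

with open("stochastic.Hopper-v2.pkl", "rb") as f:   # hypothetical file name
    sample_trajs = pkl.load(f)

ep_rets = [traj["ep_ret"] for traj in sample_trajs]  # each traj holds ob, ac, rew, ep_ret
print(len(sample_trajs), "trajectories, mean return", np.mean(ep_rets))
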