예제 #1
0
def learn(
        *,
        network,
        env,
        seed=None,
        beta,
        total_timesteps,
        sil_update,
        sil_loss,
        timesteps_per_batch=2048,  # what to train on
        epsilon=0.01,
        cg_iters=10,
        gamma=0.99,
        lam=0.98,  # advantage estimation
        entcoeff=0.0,
        lr=3e-4,
        cg_damping=0.1,
        vf_stepsize=1e-3,
        vf_iters=5,
        sil_value=0.01,
        sil_alpha=0.6,
        sil_beta=0.1,
        max_episodes=0,
        max_iters=0,  # time constraint
        callback=None,
        save_interval=0,
        load_path=None,
        model_fn=None,
        update_fn=None,
        init_fn=None,
        mpi_rank_weight=1,
        comm=None,
        vf_coef=0.5,
        max_grad_norm=0.5,
        log_interval=1,
        nminibatches=4,
        noptepochs=4,
        cliprange=0.2,
        TRPO=False,
        **network_kwargs):
    """Train a COPOS (or TRPO) policy combined with self-imitation learning.

    Every iteration:
      1. samples ``timesteps_per_batch`` steps with ``traj_segment_generator``;
      2. computes the surrogate-gain gradient and a natural-gradient direction
         via conjugate gradient on damped Fisher-vector products;
      3. updates the policy with either a TRPO backtracking line search
         (``TRPO=True``) or the COPOS eta/omega update (default);
      4. fits the value function with MpiAdam;
      5. calls ``model.sil_train`` once with the scheduled learning rate.

    Exactly one of ``total_timesteps``, ``max_episodes`` and ``max_iters``
    must be positive; it decides when training stops.

    Args (selection):
        network, **network_kwargs: forwarded to ``build_policy``.
        env: vectorized environment (``env.num_envs`` is read).
        beta, epsilon: forwarded to the eta/omega optimizer; ``epsilon`` is
            also the KL bound used by the line searches.
        lr: float/int constant or callable of the remaining training
            fraction; the resulting value is passed to ``model.sil_train``.
        cliprange: float/int constant or callable; validated but otherwise
            unused here.
        sil_update, sil_value, sil_alpha, sil_beta, sil_loss: forwarded to
            ``Model``.
        load_path: optional checkpoint passed to ``model.load``/``pi.load``.
        callback: optional ``callback(locals(), globals())`` called at the
            start of each iteration.
        TRPO: use the TRPO line search instead of the COPOS update.

    ``save_interval``, ``model_fn``, ``update_fn``, ``init_fn``,
    ``mpi_rank_weight``, ``comm``, ``log_interval`` and ``noptepochs`` are
    accepted for interface compatibility but are not used by this function.
    """
    set_global_seeds(seed)
    # Wrap constant schedules so that ``lr(frac)`` / ``cliprange(frac)`` work.
    if isinstance(lr, (float, int)):
        lr = constfn(lr)
    else:
        assert callable(lr)
    if isinstance(cliprange, (float, int)):
        cliprange = constfn(cliprange)
    else:
        assert callable(cliprange)
    nworkers = MPI.COMM_WORLD.Get_size()
    rank = MPI.COMM_WORLD.Get_rank()

    policy = build_policy(env,
                          network,
                          value_network='copy',
                          copos=True,
                          **network_kwargs)
    nenvs = env.num_envs
    np.set_printoptions(precision=3)

    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    nbatch = nenvs * timesteps_per_batch
    nbatch_train = nbatch // nminibatches
    discrete_ac_space = isinstance(ac_space, gym.spaces.Discrete)

    # Keyword arguments shared by the "pi" and "oldpi" SIL models.
    model_kwargs = dict(
        policy=policy,
        ob_space=ob_space,
        ac_space=ac_space,
        nbatch_act=nenvs,
        nbatch_train=nbatch_train,
        nsteps=timesteps_per_batch,
        ent_coef=entcoeff,
        vf_coef=vf_coef,
        max_grad_norm=max_grad_norm,
        sil_update=sil_update,
        sil_value=sil_value,
        sil_alpha=sil_alpha,
        sil_beta=sil_beta,
        sil_loss=sil_loss,
        fn_reward=None,
        fn_obs=None,
        ppo=False)

    ob = observation_placeholder(ob_space)
    with tf.variable_scope("pi", reuse=tf.AUTO_REUSE):
        pi = policy(observ_placeholder=ob)
        model = Model(prev_pi='pi', silm=pi, **model_kwargs)
        # NOTE(review): this load happens before U.initialize() below; if
        # initialize() resets these variables the later pi.load covers it --
        # confirm which load is meant to take effect.
        if load_path is not None:
            model.load(load_path)
    with tf.variable_scope("oldpi", reuse=tf.AUTO_REUSE):
        oldpi = policy(observ_placeholder=ob)
        # old_model is not referenced again here; presumably it is built so the
        # "oldpi" scope mirrors "pi" for assign_old_eq_new -- confirm.
        old_model = Model(prev_pi='oldpi', silm=oldpi, **model_kwargs)

    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    entbonus = entcoeff * meanent

    vferr = tf.reduce_mean(tf.square(pi.vf - ret))

    ratio = tf.exp(pi.pd.logp(ac) -
                   oldpi.pd.logp(ac))  # advantage * pnew / pold
    surrgain = tf.reduce_mean(ratio * atarg)

    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "Entropy"]

    dist = meankl

    var_list = get_pi_trainable_variables("pi")
    vf_var_list = get_vf_trainable_variables("pi")

    vfadam = MpiAdam(vf_var_list)

    get_flat = U.GetFlat(var_list)
    set_from_flat = U.SetFromFlat(var_list)
    klgrads = tf.gradients(dist, var_list)
    flat_tangent = tf.placeholder(dtype=tf.float32,
                                  shape=[None],
                                  name="flat_tan")
    # Split the flat tangent vector into per-variable shaped pieces.
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start + sz], shape))
        start += sz
    # gvp = <grad(KL), tangent>; differentiating it again yields the
    # Hessian-of-KL times tangent, i.e. the Fisher-vector product.
    gvp = tf.add_n([
        tf.reduce_sum(g * tangent)
        for (g, tangent) in zipsame(klgrads, tangents)
    ])  #pylint: disable=E1111
    fvp = U.flatgrad(gvp, var_list)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(get_variables("oldpi"), get_variables("pi"))
        ])
    compute_losses = U.function([ob, ac, atarg], losses)
    compute_lossandgrad = U.function([ob, ac, atarg], losses +
                                     [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    compute_vflossandgrad = U.function([ob, ret],
                                       U.flatgrad(vferr, vf_var_list))

    @contextmanager
    def timed(msg):
        """Print ``msg`` and the elapsed time of the block on rank 0 only."""
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(
                colorize("done in %.3f seconds" % (time.time() - tstart),
                         color='magenta'))
        else:
            yield

    def allmean(x):
        """Average a numpy array element-wise across all MPI workers."""
        assert isinstance(x, np.ndarray)
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
        return out

    def buffer_mean(buf):
        """Mean of a rolling buffer, or 0.0 while the buffer is still empty."""
        return np.mean(buf) if buf else 0.0

    U.initialize()
    if load_path is not None:
        pi.load(load_path)
    # Start every worker from rank 0's policy parameters.
    th_init = get_flat()
    if MPI is not None:
        MPI.COMM_WORLD.Bcast(th_init, root=0)
    set_from_flat(th_init)
    vfadam.sync()
    print("Init param sum", th_init.sum(), flush=True)

    # Initialize eta, omega optimizer
    if discrete_ac_space:
        init_eta = 1
        init_omega = 0.5
        eta_omega_optimizer = EtaOmegaOptimizerDiscrete(
            beta, epsilon, init_eta, init_omega)
    else:
        init_eta = 0.5
        init_omega = 2.0
        eta_omega_optimizer = EtaOmegaOptimizer(beta, epsilon, init_eta,
                                                init_omega)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(env,
                                     timesteps_per_batch,
                                     model,
                                     stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards

    assert sum([max_iters > 0, total_timesteps > 0, max_episodes > 0]) == 1

    while True:
        if callback:
            callback(locals(), globals())
        if total_timesteps and timesteps_so_far >= total_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        logger.log("********** Iteration %i ************" % iters_so_far)

        with timed("sampling"):
            seg = next(seg_gen)
        add_vtarg_and_adv(seg, gamma, lam)

        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update
        # The generator returns the model it rolled out with; SIL trains on it.
        model = seg["model"]
        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate

        if hasattr(pi, "ret_rms"):
            pi.ret_rms.update(tdlamret)
        if hasattr(pi, "rms"):
            pi.rms.update(ob)  # update running mean/std for policy

        args = seg["ob"], seg["ac"], atarg
        # Fisher-vector products are estimated on every 5th sample to save compute.
        fvpargs = [arr[::5] for arr in args]

        def fisher_vector_product(p):
            return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p

        assign_old_eq_new()  # set old parameter values to new parameter values
        with timed("computegrad"):
            *lossbefore, g = compute_lossandgrad(*args)
        lossbefore = allmean(np.array(lossbefore))
        g = allmean(g)
        if np.allclose(g, 0):
            logger.log("Got zero gradient. not updating")
            # Parameters are unchanged, so the pre-update losses are the ones
            # to report (previously meanlosses was unbound or stale here).
            meanlosses = lossbefore
        else:
            with timed("cg"):
                stepdir = cg(fisher_vector_product,
                             g,
                             cg_iters=cg_iters,
                             verbose=rank == 0)
            assert np.isfinite(stepdir).all()

            if TRPO:
                # TRPO: scale the step to the KL bound, then backtrack
                # (at most 10 halvings) until KL and surrogate are acceptable.
                shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
                lm = np.sqrt(shs / epsilon)
                fullstep = stepdir / lm
                expectedimprove = g.dot(fullstep)
                surrbefore = lossbefore[0]
                stepsize = 1.0
                thbefore = get_flat()
                for _ in range(10):
                    thnew = thbefore + fullstep * stepsize
                    set_from_flat(thnew)
                    meanlosses = surr, kl, *_ = allmean(
                        np.array(compute_losses(*args)))
                    improve = surr - surrbefore
                    logger.log("Expected: %.3f Actual: %.3f" %
                               (expectedimprove, improve))
                    if not np.isfinite(meanlosses).all():
                        logger.log("Got non-finite value of losses -- bad!")
                    elif kl > epsilon * 1.5:
                        logger.log("violated KL constraint. shrinking step.")
                    elif improve < 0:
                        logger.log("surrogate didn't improve. shrinking step.")
                    else:
                        logger.log("Stepsize OK!")
                        break
                    stepsize *= .5
                else:
                    logger.log("couldn't compute a good step")
                    set_from_flat(thbefore)
            else:
                # COPOS: split the direction into the log-linear 'w_theta'
                # part and the non-linear 'w_beta' part.
                copos_update_dir = stepdir
                w_theta, w_beta = pi.split_w(copos_update_dir)

                # Optimize eta and omega
                if discrete_ac_space:
                    entropy = lossbefore[4]
                    eta, omega = eta_omega_optimizer.optimize(
                        pi.compute_F_w(ob, copos_update_dir),
                        pi.get_log_action_prob(ob), timesteps_per_batch,
                        entropy)
                else:
                    Waa, Wsa = pi.w2W(w_theta)
                    wa = pi.get_wa(ob, w_beta)

                    varphis = pi.get_varphis(ob)

                    old_ent = lossbefore[4]
                    eta, omega = eta_omega_optimizer.optimize(
                        w_theta, Waa, Wsa, wa, varphis, pi.get_kt(),
                        pi.get_prec_matrix(), pi.is_new_policy_valid, old_ent)
                logger.log("Initial eta: " + str(eta) + " and omega: " +
                           str(omega))

                current_theta_beta = get_flat()
                prev_theta, prev_beta = pi.all_to_theta_beta(
                    current_theta_beta)

                if discrete_ac_space:
                    # Do a line search for both theta and beta parameters by adjusting only eta
                    eta = eta_search(w_theta, w_beta, eta, omega, allmean,
                                     compute_losses, get_flat, set_from_flat,
                                     pi, epsilon, args, discrete_ac_space)
                    logger.log("Updated eta, eta: " + str(eta))
                    set_from_flat(pi.theta_beta_to_all(prev_theta, prev_beta))
                    # Find proper omega for new eta. Use old policy parameters first.
                    eta, omega = eta_omega_optimizer.optimize(
                        pi.compute_F_w(ob, copos_update_dir),
                        pi.get_log_action_prob(ob), timesteps_per_batch,
                        entropy, eta)
                    logger.log("Updated omega, eta: " + str(eta) +
                               " and omega: " + str(omega))

                    # The beta-ratio line search is disabled; use the full
                    # beta step.
                    ratio = 1
                    cur_theta = (eta * prev_theta +
                                 w_theta.reshape(-1, )) / (eta + omega)
                    cur_beta = prev_beta + ratio * w_beta.reshape(-1, ) / eta
                else:
                    for _ in range(2):
                        # Do a line search for both theta and beta parameters by adjusting only eta
                        eta = eta_search(w_theta, w_beta, eta, omega, allmean,
                                         compute_losses, get_flat,
                                         set_from_flat, pi, epsilon, args)
                        logger.log("Updated eta, eta: " + str(eta) +
                                   " and omega: " + str(omega))

                        # Find proper omega for new eta. Use old policy parameters first.
                        set_from_flat(
                            pi.theta_beta_to_all(prev_theta, prev_beta))
                        eta, omega = \
                            eta_omega_optimizer.optimize(w_theta, Waa, Wsa, wa, varphis, pi.get_kt(),
                                                         pi.get_prec_matrix(), pi.is_new_policy_valid, old_ent, eta)
                        logger.log("Updated omega, eta: " + str(eta) +
                                   " and omega: " + str(omega))

                    # Use final policy
                    logger.log("Final eta: " + str(eta) + " and omega: " +
                               str(omega))
                    cur_theta = (eta * prev_theta +
                                 w_theta.reshape(-1, )) / (eta + omega)
                    cur_beta = prev_beta + w_beta.reshape(-1, ) / eta

                set_from_flat(pi.theta_beta_to_all(cur_theta, cur_beta))

                meanlosses = allmean(np.array(compute_losses(*args)))

            if nworkers > 1 and iters_so_far % 20 == 0:
                # Check the parameters actually in effect. ``thnew`` only
                # exists in the TRPO branch (NameError under COPOS), and it
                # may also have been rejected by the line search.
                paramsums = MPI.COMM_WORLD.allgather(
                    (get_flat().sum(), vfadam.getflat().sum()))  # list of tuples
                assert all(
                    np.allclose(ps, paramsums[0]) for ps in paramsums[1:])

        for (lossname, lossval) in zip(loss_names, meanlosses):
            logger.record_tabular(lossname, lossval)

        with timed("vf"):
            for _ in range(vf_iters):
                for (mbob, mbret) in dataset.iterbatches(
                    (seg["ob"], seg["tdlamret"]),
                        include_final_partial_batch=False,
                        batch_size=64):
                    g = allmean(compute_vflossandgrad(mbob, mbret))
                    vfadam.update(g, vf_stepsize)
        with timed('SIL'):
            # total_timesteps is 0 when training is bounded by iterations or
            # episodes; keep the schedule at its start instead of dividing by 0.
            if total_timesteps:
                frac = 1.0 - timesteps_so_far / total_timesteps
            else:
                frac = 1.0
            lrnow = lr(frac)
            l_loss, sil_adv, sil_samples, sil_nlogp = model.sil_train(lrnow)

        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))

        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        # max()/min() raise on an empty buffer (no episode finished yet).
        if rewbuffer:
            print("Reward max: " + str(max(rewbuffer)))
            print("Reward min: " + str(min(rewbuffer)))

        logger.record_tabular("EpLenMean", buffer_mean(lenbuffer))
        logger.record_tabular("EpRewMean", buffer_mean(rewbuffer))
        logger.record_tabular("AverageReturn", buffer_mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        if sil_update > 0:
            logger.record_tabular("SilSamples", sil_samples)

        if rank == 0:
            logger.dump_tabular()
예제 #2
0
def learn(env, policy_func, *,
        timesteps_per_batch, # what to train on
        max_kl, cg_iters,
        gamma, lam, # advantage estimation
        entcoeff=0.0,
        cg_damping=1e-2,
        vf_stepsize=3e-4,
        vf_iters=3,
        max_timesteps=0, max_episodes=0, max_iters=0,  # time constraint
        callback=None
        ):
    """Train a policy with Trust Region Policy Optimization (TRPO) over MPI.

    Each iteration samples ``timesteps_per_batch`` steps, computes the
    surrogate-gain gradient, and solves for a natural-gradient direction with
    conjugate gradient on damped Fisher-vector products. It then backtracks
    the step (at most 10 halvings) until the mean KL is within
    ``1.5 * max_kl`` and the surrogate improves. Finally it fits the value
    function with MpiAdam on 64-sample minibatches for ``vf_iters`` epochs.

    Args:
        env: environment providing ``observation_space`` / ``action_space``.
        policy_func: ``policy_func(name, ob_space, ac_space)`` -> policy;
            called once for "pi" and once for "oldpi".
        timesteps_per_batch: steps collected per iteration.
        max_kl: size of the KL trust region.
        cg_iters: conjugate-gradient iterations.
        gamma, lam: passed to ``add_vtarg_and_adv`` (advantage estimation).
        entcoeff: entropy bonus coefficient in the surrogate objective.
        cg_damping: damping added to each Fisher-vector product.
        vf_stepsize, vf_iters: value-function Adam step size and epochs.
        max_timesteps, max_episodes, max_iters: stopping criteria; exactly
            one must be positive.
        callback: optional ``callback(locals(), globals())`` invoked at the
            start of every iteration.
    """
    nworkers = MPI.COMM_WORLD.Get_size()
    rank = MPI.COMM_WORLD.Get_rank()
    np.set_printoptions(precision=3)
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_func("pi", ob_space, ac_space)
    oldpi = policy_func("oldpi", ob_space, ac_space)
    atarg = tf.placeholder(dtype=tf.float32, shape=[None]) # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None]) # Empirical return

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = U.mean(kloldnew)
    meanent = U.mean(ent)
    entbonus = entcoeff * meanent

    vferr = U.mean(tf.square(pi.vpred - ret))

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac)) # advantage * pnew / pold
    surrgain = U.mean(ratio * atarg)

    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

    dist = meankl

    # Policy variables live under ".../pol...", value-function ones under ".../vf...".
    all_var_list = pi.get_trainable_variables()
    var_list = [v for v in all_var_list if v.name.split("/")[1].startswith("pol")]
    vf_var_list = [v for v in all_var_list if v.name.split("/")[1].startswith("vf")]
    vfadam = MpiAdam(vf_var_list)

    get_flat = U.GetFlat(var_list)
    set_from_flat = U.SetFromFlat(var_list)
    klgrads = tf.gradients(dist, var_list)
    flat_tangent = tf.placeholder(dtype=tf.float32, shape=[None], name="flat_tan")
    # Split the flat tangent vector into per-variable shaped pieces.
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start+sz], shape))
        start += sz
    # gvp = <grad(KL), tangent>; its gradient is the Fisher-vector product.
    gvp = tf.add_n([U.sum(g*tangent) for (g, tangent) in zipsame(klgrads, tangents)]) #pylint: disable=E1111
    fvp = U.flatgrad(gvp, var_list)

    assign_old_eq_new = U.function([],[], updates=[tf.assign(oldv, newv)
        for (oldv, newv) in zipsame(oldpi.get_variables(), pi.get_variables())])
    compute_losses = U.function([ob, ac, atarg], losses)
    compute_lossandgrad = U.function([ob, ac, atarg], losses + [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    compute_vflossandgrad = U.function([ob, ret], U.flatgrad(vferr, vf_var_list))

    @contextmanager
    def timed(msg):
        """Print ``msg`` and the elapsed time of the block on rank 0 only."""
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(colorize("done in %.3f seconds"%(time.time() - tstart), color='magenta'))
        else:
            yield

    def allmean(x):
        """Average a numpy array element-wise across all MPI workers."""
        assert isinstance(x, np.ndarray)
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
        return out

    U.initialize()
    # Start every worker from rank 0's policy parameters.
    th_init = get_flat()
    MPI.COMM_WORLD.Bcast(th_init, root=0)
    set_from_flat(th_init)
    vfadam.sync()
    print("Init param sum", th_init.sum(), flush=True)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi, env, timesteps_per_batch, stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40) # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40) # rolling buffer for episode rewards

    assert sum([max_iters>0, max_timesteps>0, max_episodes>0])==1

    while True:
        if callback:
            callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        logger.log("********** Iteration %i ************"%iters_so_far)

        with timed("sampling"):
            seg = next(seg_gen)
        add_vtarg_and_adv(seg, gamma, lam)

        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]
        vpredbefore = seg["vpred"] # predicted value function before update
        atarg = (atarg - atarg.mean()) / atarg.std() # standardized advantage function estimate

        if hasattr(pi, "ret_rms"):
            pi.ret_rms.update(tdlamret)
        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob) # update running mean/std for policy

        args = seg["ob"], seg["ac"], atarg
        # Fisher-vector products are estimated on every 5th sample to save compute.
        fvpargs = [arr[::5] for arr in args]
        def fisher_vector_product(p):
            return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p

        assign_old_eq_new() # set old parameter values to new parameter values
        with timed("computegrad"):
            *lossbefore, g = compute_lossandgrad(*args)
        lossbefore = allmean(np.array(lossbefore))
        g = allmean(g)
        if np.allclose(g, 0):
            logger.log("Got zero gradient. not updating")
            # Parameters are unchanged, so report the pre-update losses
            # (previously meanlosses was unbound or stale here).
            meanlosses = lossbefore
        else:
            with timed("cg"):
                stepdir = cg(fisher_vector_product, g, cg_iters=cg_iters, verbose=rank==0)
            assert np.isfinite(stepdir).all()
            # Scale the step so the quadratic KL estimate equals max_kl.
            shs = .5*stepdir.dot(fisher_vector_product(stepdir))
            lm = np.sqrt(shs / max_kl)
            fullstep = stepdir / lm
            expectedimprove = g.dot(fullstep)
            surrbefore = lossbefore[0]
            stepsize = 1.0
            thbefore = get_flat()
            for _ in range(10):
                thnew = thbefore + fullstep * stepsize
                set_from_flat(thnew)
                meanlosses = surr, kl, *_ = allmean(np.array(compute_losses(*args)))
                improve = surr - surrbefore
                logger.log("Expected: %.3f Actual: %.3f"%(expectedimprove, improve))
                if not np.isfinite(meanlosses).all():
                    logger.log("Got non-finite value of losses -- bad!")
                elif kl > max_kl * 1.5:
                    logger.log("violated KL constraint. shrinking step.")
                elif improve < 0:
                    logger.log("surrogate didn't improve. shrinking step.")
                else:
                    logger.log("Stepsize OK!")
                    break
                stepsize *= .5
            else:
                logger.log("couldn't compute a good step")
                set_from_flat(thbefore)
            if nworkers > 1 and iters_so_far % 20 == 0:
                # Compare the parameters actually in effect (the last thnew
                # may have been rejected and rolled back).
                paramsums = MPI.COMM_WORLD.allgather((get_flat().sum(), vfadam.getflat().sum())) # list of tuples
                assert all(np.allclose(ps, paramsums[0]) for ps in paramsums[1:])

        for (lossname, lossval) in zip(loss_names, meanlosses):
            logger.record_tabular(lossname, lossval)

        with timed("vf"):
            for _ in range(vf_iters):
                for (mbob, mbret) in dataset.iterbatches((seg["ob"], seg["tdlamret"]),
                        include_final_partial_batch=False, batch_size=64):
                    g = allmean(compute_vflossandgrad(mbob, mbret))
                    vfadam.update(g, vf_stepsize)

        logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))

        lrlocal = (seg["ep_lens"], seg["ep_rets"]) # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal) # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)

        if rank==0:
            logger.dump_tabular()
예제 #3
0
def learn(
        env,
        policy_fn,
        *,
        timesteps_per_batch,  # what to train on
        epsilon,
        beta,
        cg_iters,
        gamma,
        lam,  # advantage estimation
        trial,
        method,
        entcoeff=0.0,
        cg_damping=1e-2,
        vf_stepsize=3e-4,
        vf_iters=3,
        max_timesteps=0,
        max_episodes=0,
        max_iters=0,  # time constraint
        callback=None,
        TRPO=False):
    """Train a policy with COPOS (default) or TRPO, averaging over MPI workers.

    Each iteration samples ``timesteps_per_batch`` steps and computes
    advantages with ``add_vtarg_and_adv``. It then takes one natural-gradient
    policy step: a backtracking line search when ``TRPO`` is true, otherwise
    the COPOS eta/omega update. Finally it fits the value function with
    MPI-averaged Adam.

    Args:
        env: environment; ``observation_space``/``action_space`` are read.
        policy_fn: ``policy_fn(name, ob_space, ac_space)`` builds a policy.
            It is called for ``"pi"`` and ``"oldpi"``.
        timesteps_per_batch: timesteps sampled per iteration.
        epsilon: KL bound (TRPO step scaling / ``kl > 1.5 * epsilon`` check,
            and forwarded to the COPOS optimizers).
        beta: forwarded to ``EtaOmegaOptimizer`` (presumably the COPOS
            entropy bound -- confirm against that class).
        cg_iters: conjugate-gradient iterations.
        gamma, lam: forwarded to ``add_vtarg_and_adv``.
        trial, method: only written to the tabular log.
        entcoeff: coefficient of the entropy bonus in the surrogate.
        cg_damping: damping added to the Fisher-vector product.
        vf_stepsize, vf_iters: value-function Adam step size and epochs.
        max_timesteps, max_episodes, max_iters: stopping criteria; exactly
            one must be positive.
        callback: called as ``callback(locals(), globals())`` each iteration.
        TRPO: use the TRPO line search instead of the COPOS update.
    """
    nworkers = MPI.COMM_WORLD.Get_size()
    rank = MPI.COMM_WORLD.Get_rank()
    np.set_printoptions(precision=3)
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_fn("pi", ob_space, ac_space)
    oldpi = policy_fn("oldpi", ob_space, ac_space)
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    old_entropy = oldpi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    entbonus = entcoeff * meanent

    vferr = tf.reduce_mean(tf.square(pi.vpred - ret))

    # Log-likelihood surrogate E[log pi(a|s) * A] instead of the importance
    # ratio pi/oldpi; both have the same gradient at pi == oldpi.
    surrgain = tf.reduce_mean(pi.pd.logp(ac) * atarg)

    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

    dist = meankl

    # Trainable variables under the "pi" scope, split into policy ("pol*")
    # and value-function ("vf*") sub-scopes.
    all_var_list = pi.get_trainable_variables()
    all_var_list = [
        v for v in all_var_list if v.name.split("/")[0].startswith("pi")
    ]
    var_list = [
        v for v in all_var_list if v.name.split("/")[1].startswith("pol")
    ]
    vf_var_list = [
        v for v in all_var_list if v.name.split("/")[1].startswith("vf")
    ]
    vfadam = MpiAdam(vf_var_list)

    get_flat = U.GetFlat(var_list)
    set_from_flat = U.SetFromFlat(var_list)
    # Fisher-vector product: gradient of <grad(meankl), tangent> with respect
    # to the policy parameters, where the tangent is fed as a flat vector.
    klgrads = tf.gradients(dist, var_list)
    flat_tangent = tf.placeholder(dtype=tf.float32,
                                  shape=[None],
                                  name="flat_tan")
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start + sz], shape))
        start += sz
    gvp = tf.add_n([
        tf.reduce_sum(g * tangent)
        for (g, tangent) in zipsame(klgrads, tangents)
    ])  #pylint: disable=E1111
    fvp = U.flatgrad(gvp, var_list)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg], losses)
    compute_lossandgrad = U.function([ob, ac, atarg], losses +
                                     [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    compute_vflossandgrad = U.function([ob, ret],
                                       U.flatgrad(vferr, vf_var_list))

    @contextmanager
    def timed(msg):
        # Only rank 0 prints timing information.
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(
                colorize("done in %.3f seconds" % (time.time() - tstart),
                         color='magenta'))
        else:
            yield

    def allmean(x):
        # Element-wise mean of ``x`` across all MPI workers.
        assert isinstance(x, np.ndarray)
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
        return out

    U.initialize()
    # Broadcast rank 0's initial policy parameters so all workers agree.
    th_init = get_flat()
    MPI.COMM_WORLD.Bcast(th_init, root=0)
    set_from_flat(th_init)
    vfadam.sync()
    print("Init param sum", th_init.sum(), flush=True)

    # Initialize eta, omega optimizer
    init_eta = 0.5
    init_omega = 2.0
    eta_omega_optimizer = EtaOmegaOptimizer(beta, epsilon, init_eta,
                                            init_omega)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_batch,
                                     stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards

    assert sum([max_iters > 0, max_timesteps > 0, max_episodes > 0]) == 1

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        logger.log("********** Iteration %i ************" % iters_so_far)

        with timed("sampling"):
            seg = next(seg_gen)
        add_vtarg_and_adv(seg, gamma, lam)

        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update
        # Standardize advantages. The epsilon keeps a zero std (e.g. constant
        # advantages) from producing NaNs that would abort the CG step.
        atarg = (atarg - atarg.mean()) / (atarg.std() + 1e-8)

        if hasattr(pi, "ret_rms"): pi.ret_rms.update(tdlamret)
        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        args = seg["ob"], seg["ac"], atarg
        # Fisher-vector products use every 5th sample to save compute.
        fvpargs = [arr[::5] for arr in args]

        def fisher_vector_product(p):
            return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p

        assign_old_eq_new()  # set old parameter values to new parameter values
        with timed("computegrad"):
            *lossbefore, g = compute_lossandgrad(*args)
        lossbefore = allmean(np.array(lossbefore))
        g = allmean(g)
        if np.allclose(g, 0):
            logger.log("Got zero gradient. not updating")
            # Parameters are unchanged, so report the pre-update losses
            # (``meanlosses`` is otherwise unset and logging would fail).
            meanlosses = lossbefore
        else:
            with timed("cg"):
                stepdir = cg(fisher_vector_product,
                             g,
                             cg_iters=cg_iters,
                             verbose=rank == 0)
            assert np.isfinite(stepdir).all()

            if TRPO:
                #
                # TRPO specific code.
                # Find correct step size using line search
                #
                shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
                lm = np.sqrt(shs / epsilon)
                fullstep = stepdir / lm
                expectedimprove = g.dot(fullstep)
                surrbefore = lossbefore[0]
                stepsize = 1.0
                thbefore = get_flat()
                for _ in range(10):
                    thnew = thbefore + fullstep * stepsize
                    set_from_flat(thnew)
                    meanlosses = surr, kl, *_ = allmean(
                        np.array(compute_losses(*args)))
                    improve = surr - surrbefore
                    logger.log("Expected: %.3f Actual: %.3f" %
                               (expectedimprove, improve))
                    if not np.isfinite(meanlosses).all():
                        logger.log("Got non-finite value of losses -- bad!")
                    elif kl > epsilon * 1.5:
                        logger.log("violated KL constraint. shrinking step.")
                    elif improve < 0:
                        logger.log("surrogate didn't improve. shrinking step.")
                    else:
                        logger.log("Stepsize OK!")
                        break
                    stepsize *= .5
                else:
                    logger.log("couldn't compute a good step")
                    set_from_flat(thbefore)
            else:
                #
                # COPOS specific implementation.
                #

                copos_update_dir = stepdir

                # Split direction into log-linear 'w_theta' and non-linear 'w_beta' parts
                w_theta, w_beta = pi.split_w(copos_update_dir)

                # q_beta(s,a) = \grad_beta \log \pi(a|s) * w_beta
                #             = features_beta(s) * K^T * Prec * a
                # q_beta = self.target.get_q_beta(features_beta, actions)

                Waa, Wsa = pi.w2W(w_theta)
                wa = pi.get_wa(ob, w_beta)

                varphis = pi.get_varphis(ob)

                # Optimize eta and omega
                tmp_ob = np.zeros(
                    (1, ) + env.observation_space.shape
                )  # We assume that entropy does not depend on the NN
                old_ent = old_entropy.eval({oldpi.ob: tmp_ob})[0]
                eta, omega = eta_omega_optimizer.optimize(
                    w_theta, Waa, Wsa, wa, varphis, pi.get_kt(),
                    pi.get_prec_matrix(), pi.is_new_policy_valid, old_ent)
                logger.log("Initial eta: " + str(eta) + " and omega: " +
                           str(omega))

                current_theta_beta = get_flat()
                prev_theta, prev_beta = pi.all_to_theta_beta(
                    current_theta_beta)

                for _ in range(2):
                    # Do a line search for both theta and beta parameters by adjusting only eta
                    eta = eta_search(w_theta, w_beta, eta, omega, allmean,
                                     compute_losses, get_flat, set_from_flat,
                                     pi, epsilon, args)
                    logger.log("Updated eta, eta: " + str(eta) +
                               " and omega: " + str(omega))

                    # Find proper omega for new eta. Use old policy parameters first.
                    set_from_flat(pi.theta_beta_to_all(prev_theta, prev_beta))
                    eta, omega = \
                        eta_omega_optimizer.optimize(w_theta, Waa, Wsa, wa, varphis, pi.get_kt(),
                                                     pi.get_prec_matrix(), pi.is_new_policy_valid, old_ent, eta)
                    logger.log("Updated omega, eta: " + str(eta) +
                               " and omega: " + str(omega))

                # Use final policy
                logger.log("Final eta: " + str(eta) + " and omega: " +
                           str(omega))
                cur_theta = (eta * prev_theta +
                             w_theta.reshape(-1, )) / (eta + omega)
                cur_beta = prev_beta + w_beta.reshape(-1, ) / eta
                set_from_flat(pi.theta_beta_to_all(cur_theta, cur_beta))

                meanlosses = surr, kl, *_ = allmean(
                    np.array(compute_losses(*args)))

            if nworkers > 1 and iters_so_far % 20 == 0:
                # Sanity check that all workers hold identical parameters.
                # Read the live parameters: ``thnew`` exists only in the TRPO
                # branch and is stale after a failed line search.
                paramsums = MPI.COMM_WORLD.allgather(
                    (get_flat().sum(), vfadam.getflat().sum()))  # list of tuples
                assert all(
                    np.allclose(ps, paramsums[0]) for ps in paramsums[1:])

        for (lossname, lossval) in zip(loss_names, meanlosses):
            logger.record_tabular(lossname, lossval)

        with timed("vf"):
            for _ in range(vf_iters):
                for (mbob, mbret) in dataset.iterbatches(
                        (seg["ob"], seg["tdlamret"]),
                        include_final_partial_batch=False,
                        batch_size=64):
                    g = allmean(compute_vflossandgrad(mbob, mbret))
                    vfadam.update(g, vf_stepsize)

        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))

        # Gather finished-episode stats from every worker.
        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        logger.record_tabular("Name", method)
        logger.record_tabular("Iteration", iters_so_far)
        logger.record_tabular("trial", trial)

        if rank == 0:
            logger.dump_tabular()
예제 #4
0
class LocalTrainer(object):
    """TRPO trainer for one local policy with a pairwise-KL term to its peers.

    Builds TRPO losses for ``policy`` (relative to ``old_policy``). It adds a
    divergence term computed against the other policies in ``pis`` and
    provides ops to copy weights between ``policy`` and ``global_policy``.
    Gradients and losses are averaged across MPI workers.
    """

    def __init__(self, id, env, runner, policy, old_policy, pis, global_policy,
                 config):
        """
        Args:
            id: index of this local policy. It selects this trainer's entry in
                ``rollouts`` and its slot in the per-policy ``obs``/``pds``
                lists.
            env: environment. Its ``unwrapped`` form is stored, and its
                ``reward_type`` names are appended to ``summary_name``.
            runner: rollout runner providing ``rollout``/``add_advantage``.
            policy, old_policy: current and previous policy networks.
            pis: all local policies (used for the divergence term).
            global_policy: network whose ``var_list`` is synced with
                ``policy.var_list``.
            config: hyper-parameter namespace.
        """
        self.id = id
        self._name = 'local_%d' % id
        self._env = env.unwrapped
        self._runner = runner
        self._config = config
        self._policy = policy
        self._old_policy = old_policy
        self._pis = pis

        self._ent_coeff = config.ent_coeff

        # set to the global network
        self._init_network = U.function(
            [],
            tf.group(*[
                tf.assign(v2, v1)
                for v1, v2 in zip(global_policy.var_list, policy.var_list)
            ]))

        # copy weights to the global network
        self._copy_network = U.function(
            [],
            tf.group(*[
                tf.assign(v1, v2)
                for v1, v2 in zip(global_policy.var_list, policy.var_list)
            ]))

        # tensorboard summary
        self._is_chef = (MPI.COMM_WORLD.Get_rank() == 0)
        self._num_workers = MPI.COMM_WORLD.Get_size()
        self.summary_name = ["reward", "length", "adv"]
        self.summary_name += env.unwrapped.reward_type

        # build loss/optimizers
        self._build_trpo()

        if self._config.is_train:
            self.summary_name = [
                '{}/{}'.format(self._name, key) for key in self.summary_name
            ]

    def init_network(self):
        """Overwrite the local policy weights with the global policy's."""
        self._init_network()

    def copy_network(self):
        """Overwrite the global policy weights with the local policy's."""
        self._copy_network()

    @contextmanager
    def timed(self, msg):
        """Print ``msg`` and elapsed time around the block (rank 0 only)."""
        if self._is_chef:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(
                colorize("done in %.3f seconds" % (time.time() - tstart),
                         color='magenta'))
        else:
            yield

    def _all_mean(self, x):
        """Element-wise mean of the array ``x`` across all MPI workers."""
        assert isinstance(x, np.ndarray)
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= self._num_workers
        return out

    def _build_trpo(self):
        """Build loss, gradient and Fisher-vector-product functions, then sync.

        The policy step maximizes ``pol_loss``: the line search in
        ``_update_policy`` moves along ``+g`` and rejects steps with
        ``improve < 0``.
        """
        pi = self._policy
        oldpi = self._old_policy
        other_pis = self._pis

        # input placeholders
        ob = pi.ob
        ac = pi.pdtype.sample_placeholder([None], name='action')
        atarg = tf.placeholder(
            dtype=tf.float32, shape=[None],
            name='advantage')  # Target advantage function (if applicable)
        ret = tf.placeholder(dtype=tf.float32, shape=[None],
                             name='return')  # Empirical return

        # policy
        all_var_list = pi.get_trainable_variables()
        pol_var_list = [
            v for v in all_var_list if v.name.split("/")[1].startswith("pol")
        ]
        vf_var_list = [
            v for v in all_var_list if v.name.split("/")[1].startswith("vf")
        ]
        self._vf_adam = MpiAdam(vf_var_list)

        kl_oldnew = oldpi.pd.kl(pi.pd)
        ent = pi.pd.entropy()
        mean_kl = tf.reduce_mean(kl_oldnew)
        mean_ent = tf.reduce_mean(ent)
        # NOTE(review): pol_loss is maximized, so this negative-entropy term
        # drives entropy DOWN when ent_coeff > 0 (sign convention looks
        # borrowed from a minimized loss) -- confirm the intended sign.
        pol_entpen = -self._config.ent_coeff * mean_ent

        vf_loss = tf.reduce_mean(tf.square(pi.vpred - ret))

        ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))
        pol_surr = tf.reduce_mean(ratio * atarg)

        # divergence
        other_obs = []  # put id-th data
        for other_pi in other_pis:
            other_obs.extend(other_pi.obs[self.id])
        my_obs_for_other = flatten_lists(pi.obs)  # put i-th data
        other_obs_for_other = []  # put i-th data
        for i, other_pi in enumerate(other_pis):
            other_obs_for_other.extend(other_pi.obs[i])

        pairwise_divergence = [tf.constant(0.)]
        for i, other_pi in enumerate(other_pis):
            if i != self.id:
                pairwise_divergence.append(
                    tf.reduce_mean(pi.pds[self.id].kl(other_pi.pds[self.id])))
                pairwise_divergence.append(
                    tf.reduce_mean(other_pi.pds[i].kl(pi.pds[i])))
        # NOTE(review): added to the maximized objective, so a positive
        # divergence_coeff increases the pairwise KL -- confirm intent.
        pol_divergence = self._config.divergence_coeff * tf.reduce_mean(
            pairwise_divergence)

        pol_loss = pol_surr + pol_entpen + pol_divergence
        pol_losses = {
            'pol_loss': pol_loss,
            'pol_surr': pol_surr,
            'pol_entpen': pol_entpen,
            'pol_divergence': pol_divergence,
            'kl': mean_kl,
            'entropy': mean_ent
        }
        self.summary_name += ['vf_loss']
        self.summary_name += pol_losses.keys()

        self._get_flat = U.GetFlat(pol_var_list)
        self._set_from_flat = U.SetFromFlat(pol_var_list)
        # Fisher-vector product: gradient of <grad(mean_kl), tangent>.
        klgrads = tf.gradients(mean_kl, pol_var_list)
        flat_tangent = tf.placeholder(dtype=tf.float32,
                                      shape=[None],
                                      name="flat_tan")
        shapes = [var.get_shape().as_list() for var in pol_var_list]
        start = 0
        tangents = []
        for shape in shapes:
            sz = U.intprod(shape)
            tangents.append(tf.reshape(flat_tangent[start:start + sz], shape))
            start += sz
        gvp = tf.add_n([
            tf.reduce_sum(g * tangent)
            for (g, tangent) in zipsame(klgrads, tangents)
        ])  #pylint: disable=E1111
        fvp = U.flatgrad(gvp, pol_var_list)

        self._update_oldpi = U.function(
            [], [],
            updates=[
                tf.assign(oldv, newv) for (
                    oldv,
                    newv) in zipsame(oldpi.get_variables(), pi.get_variables())
            ])
        obs_pairwise = other_obs + my_obs_for_other + other_obs_for_other + ob
        self._compute_losses = U.function(obs_pairwise + [ac, atarg],
                                          pol_losses)
        pol_losses = dict(pol_losses)
        pol_losses.update({'g': U.flatgrad(pol_loss, pol_var_list)})
        self._compute_lossandgrad = U.function(obs_pairwise + [ac, atarg],
                                               pol_losses)
        self._compute_fvp = U.function([flat_tangent] + obs_pairwise +
                                       [ac, atarg], fvp)
        self._compute_vflossandgrad = U.function(
            ob + [ret], U.flatgrad(vf_loss, vf_var_list))
        self._compute_vfloss = U.function(ob + [ret], vf_loss)

        # initialize and sync
        U.initialize()
        th_init = self._get_flat()
        MPI.COMM_WORLD.Bcast(th_init, root=0)
        self._set_from_flat(th_init)
        self._vf_adam.sync()
        rank = MPI.COMM_WORLD.Get_rank()

        if self._config.debug:
            logger.log(
                "[worker: {} local net: {}] Init pol param sum: {}".format(
                    rank, self.id, th_init.sum()))
            logger.log(
                "[worker: {} local net: {}] Init vf param sum: {}".format(
                    rank, self.id,
                    self._vf_adam.getflat().sum()))

    def generate_rollout(self, sess, context=None):
        """Sample a stochastic rollout into ``self.rollout`` with advantages.

        The discount (0.99) and advantage lambda (0.98) are hard-coded here.
        """
        with sess.as_default(), sess.graph.as_default():
            with self.timed("sampling"):
                rollout = self._runner.rollout(stochastic=True,
                                               context=context)
                self._runner.add_advantage(rollout, 0.99, 0.98)
            self.rollout = rollout

    def update(self, sess, rollouts, global_step):
        """Run one policy/value update on ``rollouts[self.id]``.

        Returns:
            dict mapping ``'<name>/<key>'`` to the mean of each logged value.
        """
        with sess.as_default(), sess.graph.as_default():
            # train policy
            info = self._update_policy(rollouts, global_step)

            for key, value in rollouts[self.id].items():
                if key.startswith('ep_'):
                    info[key.split('ep_')[1]] = np.mean(value)

            if self._is_chef:
                logger.log(
                    '[worker {}] iter: {}, rewards: {}, length: {}'.format(
                        self.id, global_step, np.mean(info["reward"]),
                        np.mean(info["length"])))
            info = {
                '{}/{}'.format(self._name, key): np.mean(value)
                for key, value in info.items()
            }
            return info

    def evaluate(self, ckpt_num=None, record=False, context=None):
        """Run 5 evaluation episodes, log stats and optionally record videos."""
        config = self._config

        ep_lens = []
        ep_rets = []

        if record:
            record_dir = osp.join(config.log_dir, 'video')
            os.makedirs(record_dir, exist_ok=True)

        for trial_idx in tqdm.trange(5):
            ep_traj = self._runner.rollout(True, True, context)
            ep_lens.append(ep_traj["ep_length"][0])
            ep_rets.append(ep_traj["ep_reward"][0])
            logger.log('[{}] Trial #{}: lengths {}, returns {}'.format(
                self._name, trial_idx, ep_traj["ep_length"][0],
                ep_traj["ep_reward"][0]))

            # Video recording
            if record:
                visual_obs = ep_traj["visual_ob"]
                video_name = '{}{}_{}{}.{}'.format(
                    config.video_prefix or '', self._name,
                    '' if ckpt_num is None else 'ckpt_{}_'.format(ckpt_num),
                    trial_idx, config.video_format)
                video_path = osp.join(record_dir, video_name)

                if config.video_format == 'mp4':
                    fps = 60.

                    def f(t):
                        frame_length = len(visual_obs)
                        new_fps = 1. / (1. / fps + 1. / frame_length)
                        idx = min(int(t * new_fps), frame_length - 1)
                        return visual_obs[idx]

                    video = mpy.VideoClip(f,
                                          duration=len(visual_obs) / fps + 2)
                    video.write_videofile(video_path, fps, verbose=False)
                elif config.video_format == 'gif':
                    imageio.mimsave(video_path, visual_obs, fps=100)

        logger.log('[{}] Episode Length: {}'.format(self._name,
                                                    np.mean(ep_lens)))
        logger.log('[{}] Episode Rewards: {}'.format(self._name,
                                                     np.mean(ep_rets)))

    def update_ob_rms(self, rollouts):
        """Update per-observation-type running mean/std from all rollouts."""
        assert self._config.obs_norm == 'learn'
        ob = np.concatenate([rollout['ob'] for rollout in rollouts])
        ob_dict = self._env.get_ob_dict(ob)
        for ob_name in self._policy.ob_type:
            self._policy.ob_rms[ob_name].update(ob_dict[ob_name])

    def _update_policy(self, rollouts, it):
        """One TRPO policy step plus value-function fitting.

        The reported policy losses describe the parameters actually kept:
        those of the accepted line-search step, or the pre-update losses when
        no step is taken.

        Returns:
            defaultdict mapping each logged key to its mean value.
        """
        pi = self._policy
        seg = rollouts[self.id]
        info = defaultdict(list)

        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"]
        # Log the raw advantage mean; after standardization it is always ~0.
        info['adv'] = np.mean(atarg)
        # Epsilon keeps a zero std from turning the advantages into NaNs.
        atarg = (atarg - atarg.mean()) / (atarg.std() + 1e-8)

        other_ob_list = []
        for i, other_pi in enumerate(self._pis):
            other_ob_list.extend(other_pi.get_ob_list(rollouts[i]["ob"]))

        ob_list = pi.get_ob_list(ob)
        args = ob_list * self._config.num_contexts + \
            other_ob_list * 2 + ob_list + [ac, atarg]
        # Fisher-vector products use every 5th sample to save compute.
        fvpargs = [arr[::5] for arr in args]

        def fisher_vector_product(p):
            return self._all_mean(self._compute_fvp(
                p, *fvpargs)) + self._config.cg_damping * p

        self._update_oldpi()

        with self.timed("compute gradient"):
            lossbefore = self._compute_lossandgrad(*args)
            lossbefore = {
                k: self._all_mean(np.array(lossbefore[k]))
                for k in sorted(lossbefore.keys())
            }
        g = lossbefore['g']
        final_losses = lossbefore  # reported when the policy is not updated

        if np.allclose(g, 0):
            logger.log("Got zero gradient. not updating")
        else:
            with self.timed("compute conjugate gradient"):
                stepdir = cg(fisher_vector_product,
                             g,
                             cg_iters=self._config.cg_iters,
                             verbose=False)
            assert np.isfinite(stepdir).all()
            shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
            lm = np.sqrt(shs / self._config.max_kl)
            fullstep = stepdir / lm
            expectedimprove = g.dot(fullstep)
            surrbefore = lossbefore['pol_loss']
            stepsize = 1.0
            thbefore = self._get_flat()
            for _ in range(10):
                thnew = thbefore + fullstep * stepsize
                self._set_from_flat(thnew)
                step_losses = self._compute_losses(*args)
                step_losses = {
                    k: self._all_mean(np.array(step_losses[k]))
                    for k in sorted(step_losses.keys())
                }
                surr = step_losses['pol_loss']
                kl = step_losses['kl']
                meanlosses = np.array(list(step_losses.values()))
                improve = surr - surrbefore
                logger.log("Expected: %.3f Actual: %.3f" %
                           (expectedimprove, improve))
                if not np.isfinite(meanlosses).all():
                    logger.log("Got non-finite value of losses -- bad!")
                elif kl > self._config.max_kl * 1.5:
                    logger.log("violated KL constraint. shrinking step.")
                elif improve < 0:
                    logger.log("surrogate didn't improve. shrinking step.")
                else:
                    logger.log("Stepsize OK!")
                    final_losses = step_losses
                    break
                stepsize *= .5
            else:
                logger.log("couldn't compute a good step")
                self._set_from_flat(thbefore)
            if self._num_workers > 1 and it % 20 == 0:
                # Sanity check that every worker holds identical parameters.
                paramsums = MPI.COMM_WORLD.allgather(
                    (self._get_flat().sum(),
                     self._vf_adam.getflat().sum()))  # list of tuples
                assert all(
                    np.allclose(ps, paramsums[0])
                    for ps in paramsums[1:]), paramsums

        for key, value in final_losses.items():
            if key != 'g':
                info[key].append(value)

        with self.timed("updating value function"):
            for _ in range(self._config.vf_iters):
                for (mbob, mbret) in dataset.iterbatches(
                        (ob, tdlamret),
                        include_final_partial_batch=False,
                        batch_size=self._config.vf_batch_size):
                    ob_list = pi.get_ob_list(mbob)
                    g = self._all_mean(
                        self._compute_vflossandgrad(*ob_list, mbret))
                    self._vf_adam.update(g, self._config.vf_stepsize)
                    vf_loss = self._all_mean(
                        np.array(self._compute_vfloss(*ob_list, mbret)))
                    info['vf_loss'].append(vf_loss)

        for key, value in info.items():
            info[key] = np.mean(value)
        return info
예제 #5
0
def learn(
        *,
        network,
        env,
        eval_env,
        total_timesteps,
        timesteps_per_batch=1024,  # what to train on
        max_kl=0.001,
        cg_iters=10,
        gamma=0.99,
        lam=1.0,  # advantage estimation
        seed=None,
        ent_coef=0.0,
        cg_damping=1e-2,
        vf_stepsize=3e-4,
        vf_iters=3,
        log_path=None,
        max_episodes=0,
        max_iters=0,  # time constraint
        callback=None,
        load_path=None,
        **network_kwargs):
    '''
    learn a policy function with TRPO algorithm (TF2 eager / tf.function version)

    Parameters:
    ----------

    network                 either a string resolved via get_network_builder(network)(**network_kwargs),
                            or a callable that is called with ob_space.shape and returns a network
                            object exposing `trainable_variables` (four such networks are built:
                            policy and value for both `pi` and `oldpi`)

    env                     training environment; must expose observation_space, action_space and spec.id

    eval_env                environment handed to Evaluator for periodic evaluation

    total_timesteps         only used (together with max_iters / max_episodes) to decide whether
                            any training happens at all; it does NOT bound the training loop (see Notes)

    timesteps_per_batch     timesteps per gradient estimation batch

    max_kl                  max KL divergence between old policy and new policy ( KL(pi_old || pi) )

    cg_iters                number of iterations of conjugate gradient algorithm

    gamma, lam              passed to add_vtarg_and_adv for return / advantage estimation

    seed                    passed to set_global_seeds

    ent_coef                coefficient of policy entropy term in the optimization objective

    cg_damping              conjugate gradient damping

    vf_stepsize             learning rate for adam optimizer used to optimize value function loss

    vf_iters                number of passes over the batch for value function optimization per
                            policy optimization step; also used to size the iteration budget

    log_path                required directory; evaluator output goes to <log_path>/evaluator and
                            model checkpoints to <log_path>/models

    max_episodes            max number of episodes (see Notes)

    max_iters               maximum number of policy optimization iterations (see Notes)

    callback                function to be called with (locals(), globals()) each policy optimization step

    load_path               str, path to load the model from (default: None, i.e. no model is loaded)

    **network_kwargs        keyword arguments to the policy / network builder. See baselines.common/policies.py/build_policy and arguments to a particular type of network

    Notes:
    -----

    The number of outer iterations is fixed by an internal budget of "inner" iterations
    (500000 for 'InvertedDoublePendulum-v2', otherwise 3000000), where one outer iteration
    counts as vf_iters * (timesteps_per_batch // 256) inner iterations. Evaluation and
    checkpointing happen 150 times over that budget.

    Returns:
    -------

    learnt model (the `pi` PolicyWithValue)

    Raises:
    ------

    ValueError              if log_path is None
    AssertionError          if more than one of max_iters, total_timesteps, max_episodes is positive
    '''

    # Fail fast: evaluator logs and checkpoints are written under log_path.
    if log_path is None:
        raise ValueError("log_path is required: evaluator logs and model "
                         "checkpoints are written under it")

    if MPI is not None:
        nworkers = MPI.COMM_WORLD.Get_size()
        rank = MPI.COMM_WORLD.Get_rank()
    else:
        nworkers = 1
        rank = 0

    set_global_seeds(seed)

    np.set_printoptions(precision=3)
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space

    if isinstance(network, str):
        network = get_network_builder(network)(**network_kwargs)

    # `pi` is optimized; `oldpi` is a frozen copy used as the reference
    # distribution for the KL term and the importance ratio.
    with tf.name_scope("pi"):
        pi_policy_network = network(ob_space.shape)
        pi_value_network = network(ob_space.shape)
        pi = PolicyWithValue(ac_space, pi_policy_network, pi_value_network)
    with tf.name_scope("oldpi"):
        old_pi_policy_network = network(ob_space.shape)
        old_pi_value_network = network(ob_space.shape)
        oldpi = PolicyWithValue(ac_space, old_pi_policy_network,
                                old_pi_value_network)

    pi_var_list = pi_policy_network.trainable_variables + list(
        pi.pdtype.trainable_variables)
    old_pi_var_list = old_pi_policy_network.trainable_variables + list(
        oldpi.pdtype.trainable_variables)
    vf_var_list = pi_value_network.trainable_variables + pi.value_fc.trainable_variables
    old_vf_var_list = old_pi_value_network.trainable_variables + oldpi.value_fc.trainable_variables

    if load_path is not None:
        load_path = osp.expanduser(load_path)
        ckpt = tf.train.Checkpoint(model=pi)
        manager = tf.train.CheckpointManager(ckpt, load_path, max_to_keep=None)
        ckpt.restore(manager.latest_checkpoint)

    vfadam = MpiAdam(vf_var_list)

    get_flat = U.GetFlat(pi_var_list)
    set_from_flat = U.SetFromFlat(pi_var_list)
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]
    shapes = [var.get_shape().as_list() for var in pi_var_list]

    def assign_old_eq_new():
        """Copy current policy and value variables into the `oldpi` snapshot."""
        for pi_var, old_pi_var in zip(pi_var_list, old_pi_var_list):
            old_pi_var.assign(pi_var)
        for vf_var, old_vf_var in zip(vf_var_list, old_vf_var_list):
            old_vf_var.assign(vf_var)

    @tf.function
    def compute_lossandgrad(ob, ac, atarg):
        """Return [optimgain, meankl, entbonus, surrgain, meanent, flat grad of optimgain]."""
        with tf.GradientTape() as tape:
            old_policy_latent = oldpi.policy_network(ob)
            old_pd, _ = oldpi.pdtype.pdfromlatent(old_policy_latent)
            policy_latent = pi.policy_network(ob)
            pd, _ = pi.pdtype.pdfromlatent(policy_latent)
            kloldnew = old_pd.kl(pd)
            ent = pd.entropy()
            meankl = tf.reduce_mean(kloldnew)
            meanent = tf.reduce_mean(ent)
            entbonus = ent_coef * meanent
            # pnew / pold, computed in log space
            ratio = tf.exp(pd.logp(ac) - old_pd.logp(ac))
            surrgain = tf.reduce_mean(ratio * atarg)
            optimgain = surrgain + entbonus
            losses = [optimgain, meankl, entbonus, surrgain, meanent]
        gradients = tape.gradient(optimgain, pi_var_list)
        return losses + [U.flatgrad(gradients, pi_var_list)]

    @tf.function
    def compute_losses(ob, ac, atarg):
        """Return [optimgain, meankl, entbonus, surrgain, meanent] without gradients."""
        old_policy_latent = oldpi.policy_network(ob)
        old_pd, _ = oldpi.pdtype.pdfromlatent(old_policy_latent)
        policy_latent = pi.policy_network(ob)
        pd, _ = pi.pdtype.pdfromlatent(policy_latent)
        kloldnew = old_pd.kl(pd)
        ent = pd.entropy()
        meankl = tf.reduce_mean(kloldnew)
        meanent = tf.reduce_mean(ent)
        entbonus = ent_coef * meanent
        ratio = tf.exp(pd.logp(ac) - old_pd.logp(ac))
        surrgain = tf.reduce_mean(ratio * atarg)
        optimgain = surrgain + entbonus
        losses = [optimgain, meankl, entbonus, surrgain, meanent]
        return losses

    #ob shape should be [batch_size, ob_dim], merged nenv
    #ret shape should be [batch_size]
    @tf.function
    def compute_vflossandgrad(ob, ret):
        """Return the flat gradient of the mean squared value error w.r.t. vf_var_list."""
        with tf.GradientTape() as tape:
            pi_vf = pi.value(ob)
            vferr = tf.reduce_mean(tf.square(pi_vf - ret))
        return U.flatgrad(tape.gradient(vferr, vf_var_list), vf_var_list)

    @tf.function
    def compute_fvp(flat_tangent, ob, ac, atarg):
        """Return the flat product of the mean-KL Hessian with `flat_tangent`.

        Computed as the gradient of (grad(meankl) . tangent), i.e. a
        double-backprop Hessian-vector product.
        """
        with tf.GradientTape() as outter_tape:
            with tf.GradientTape() as inner_tape:
                old_policy_latent = oldpi.policy_network(ob)
                old_pd, _ = oldpi.pdtype.pdfromlatent(old_policy_latent)
                policy_latent = pi.policy_network(ob)
                pd, _ = pi.pdtype.pdfromlatent(policy_latent)
                kloldnew = old_pd.kl(pd)
                meankl = tf.reduce_mean(kloldnew)
            klgrads = inner_tape.gradient(meankl, pi_var_list)
            # Unflatten the tangent into per-variable shapes.
            start = 0
            tangents = []
            for shape in shapes:
                sz = U.intprod(shape)
                tangents.append(
                    tf.reshape(flat_tangent[start:start + sz], shape))
                start += sz
            gvp = tf.add_n([
                tf.reduce_sum(g * tangent)
                for (g, tangent) in zipsame(klgrads, tangents)
            ])
        hessians_products = outter_tape.gradient(gvp, pi_var_list)
        fvp = U.flatgrad(hessians_products, pi_var_list)
        return fvp

    @contextmanager
    def timed(msg):
        """On rank 0, print `msg` and the wall-clock duration of the block."""
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(
                colorize("done in %.3f seconds" % (time.time() - tstart),
                         color='magenta'))
        else:
            yield

    def allmean(x):
        """Average a numpy array across MPI workers (plain copy without MPI)."""
        assert isinstance(x, np.ndarray)
        if MPI is not None:
            out = np.empty_like(x)
            MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
            out /= nworkers
        else:
            out = np.copy(x)

        return out

    # Start every worker from rank 0's policy parameters.
    th_init = get_flat()
    if MPI is not None:
        MPI.COMM_WORLD.Bcast(th_init, root=0)

    set_from_flat(th_init)
    vfadam.sync()
    print("Init param sum", th_init.sum(), flush=True)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi, env, timesteps_per_batch)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards

    logdir = osp.join(log_path, 'evaluator')
    modeldir = osp.join(log_path, 'models')
    # exist_ok=True avoids the check-then-create race when several MPI
    # workers start at the same time.
    os.makedirs(logdir, exist_ok=True)
    os.makedirs(modeldir, exist_ok=True)
    evaluator = Evaluator(env=eval_env, model=pi, logdir=logdir)

    # Training length is a fixed budget of "inner" iterations; one outer
    # iteration counts as vf_iters * (timesteps_per_batch // mb_size) of them.
    max_inner_iter = 500000 if env.spec.id == 'InvertedDoublePendulum-v2' else 3000000
    epoch = vf_iters
    batch_size = timesteps_per_batch
    mb_size = 256
    # max(1, ...) prevents ZeroDivisionError when timesteps_per_batch < mb_size
    # or vf_iters == 0.
    inner_iter_per_iter = max(1, epoch * int(batch_size / mb_size))
    max_iter = int(max_inner_iter / inner_iter_per_iter)
    eval_num = 150
    # max(1, ...) keeps the modulo checks below from dividing by zero.
    eval_interval = save_interval = max(
        1, int(int(max_inner_iter / eval_num) / inner_iter_per_iter))

    if sum([max_iters > 0, total_timesteps > 0, max_episodes > 0]) == 0:
        # noththing to be done
        return pi

    assert sum([max_iters>0, total_timesteps>0, max_episodes>0]) < 2, \
        'out of max_iters, total_timesteps, and max_episodes only one should be specified'

    # The loop intentionally runs for max_iter iterations; the
    # total_timesteps / max_episodes / max_iters stopping checks are disabled.
    for update in range(1, max_iter + 1):
        if callback: callback(locals(), globals())
        logger.log("********** Iteration %i ************" % iters_so_far)
        if (update - 1) % eval_interval == 0:
            evaluator.run_evaluation(update - 1)
        if (update - 1) % save_interval == 0:
            ckpt = tf.train.Checkpoint(model=pi)
            ckpt.save(osp.join(modeldir, 'ckpt_ite' + str(update - 1)))

        with timed("sampling"):
            seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"]
        ob = sf01(ob)
        vpredbefore = seg["vpred"]  # predicted value function before udpate
        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate

        if hasattr(pi, "ret_rms"): pi.ret_rms.update(tdlamret)
        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        args = ob, ac, atarg
        # Fisher-vector products use every 5th sample to save compute.
        fvpargs = [arr[::5] for arr in args]

        def fisher_vector_product(p):
            return allmean(compute_fvp(p, *fvpargs).numpy()) + cg_damping * p

        assign_old_eq_new()  # set old parameter values to new parameter values
        with timed("computegrad"):
            *lossbefore, g = compute_lossandgrad(*args)
        lossbefore = allmean(np.array(lossbefore))
        g = g.numpy()
        g = allmean(g)
        # Default for logging: if no step is attempted (zero gradient), the
        # parameters are unchanged, so the pre-update losses are reported.
        # Previously `meanlosses` was unbound here and raised NameError.
        meanlosses = lossbefore
        if np.allclose(g, 0):
            logger.log("Got zero gradient. not updating")
        else:
            with timed("cg"):
                stepdir = cg(fisher_vector_product,
                             g,
                             cg_iters=cg_iters,
                             verbose=rank == 0)
            assert np.isfinite(stepdir).all()
            # Scale the natural-gradient direction so the quadratic KL
            # estimate of the full step equals max_kl.
            shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
            lm = np.sqrt(shs / max_kl)
            fullstep = stepdir / lm
            expectedimprove = g.dot(fullstep)
            surrbefore = lossbefore[0]
            stepsize = 1.0
            thbefore = get_flat()
            # Backtracking line search: halve the step up to 10 times.
            for _ in range(10):
                thnew = thbefore + fullstep * stepsize
                set_from_flat(thnew)
                meanlosses = surr, kl, *_ = allmean(
                    np.array(compute_losses(*args)))
                improve = surr - surrbefore
                logger.log("Expected: %.3f Actual: %.3f" %
                           (expectedimprove, improve))
                if not np.isfinite(meanlosses).all():
                    logger.log("Got non-finite value of losses -- bad!")
                elif kl > max_kl * 1.5:
                    logger.log("violated KL constraint. shrinking step.")
                elif improve < 0:
                    logger.log("surrogate didn't improve. shrinking step.")
                else:
                    logger.log("Stepsize OK!")
                    break
                stepsize *= .5
            else:
                logger.log("couldn't compute a good step")
                set_from_flat(thbefore)
            # Periodic sanity check that all workers hold identical parameters.
            if nworkers > 1 and iters_so_far % 20 == 0:
                paramsums = MPI.COMM_WORLD.allgather(
                    (thnew.sum(), vfadam.getflat().sum()))  # list of tuples
                assert all(
                    np.allclose(ps, paramsums[0]) for ps in paramsums[1:])

        for (lossname, lossval) in zip(loss_names, meanlosses):
            logger.record_tabular(lossname, lossval)

        with timed("vf"):

            for _ in range(vf_iters):
                for (mbob, mbret) in dataset.iterbatches(
                    (seg["ob"], seg["tdlamret"]),
                        include_final_partial_batch=False,
                        batch_size=mb_size):
                    mbob = sf01(mbob)
                    g = allmean(compute_vflossandgrad(mbob, mbret).numpy())
                    vfadam.update(g, vf_stepsize)

        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))

        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        if MPI is not None:
            listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        else:
            listoflrpairs = [lrlocal]

        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)

        if rank == 0:
            logger.dump_tabular()

    return pi
예제 #6
0
def learn(
        env,
        policy_fn,
        *,
        timesteps_per_batch,  # what to train on
        max_kl,
        cg_iters,
        gamma,
        lam,  # advantage estimation
        entc=0.5,
        cg_damping=1e-2,
        vf_stepsize=3e-4,
        vf_iters=3,
        max_timesteps=0,
        max_episodes=0,
        max_iters=0,  # time constraint
        callback=None,
        i_trial):
    """Train a policy with TRPO (TF1 graph mode, parameters synchronized over MPI).

    Args:
        env: environment exposing observation_space and action_space; passed
            to traj_segment_generator for rollouts.
        policy_fn: called as policy_fn(name, ob_space, ac_space) to build the
            "pi" and "oldpi" policies. The result must provide pdtype, pd,
            vpred, get_trainable_variables() and get_variables(); ob_rms /
            ret_rms are updated if present. Trainable variables whose second
            name component starts with "pol" are optimized by TRPO, those
            starting with "vf" by Adam on the value loss.
        timesteps_per_batch: timesteps collected per iteration.
        max_kl: KL trust-region size; steps with KL > 1.5 * max_kl are shrunk.
        cg_iters: conjugate-gradient iterations.
        gamma, lam: passed to traj_segment_generator / add_vtarg_and_adv for
            return and advantage estimation.
        entc: currently unused; the entropy coefficient is fixed to 0.0
            (an annealed schedule based on it is disabled).
        cg_damping: damping added to Fisher-vector products.
        vf_stepsize: Adam learning rate for the value function.
        vf_iters: passes over the batch (minibatches of 64) for the value
            function per iteration.
        max_timesteps, max_episodes, max_iters: stopping criteria; exactly one
            must be positive.
        callback: called with (locals(), globals()) at the start of each
            iteration.
        i_trial: trial index, logged under the key 'trial'.

    Returns:
        None.

    Raises:
        AssertionError: if not exactly one stopping criterion is positive.
    """
    nworkers = MPI.COMM_WORLD.Get_size()
    rank = MPI.COMM_WORLD.Get_rank()
    np.set_printoptions(precision=3)
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_fn("pi", ob_space, ac_space)
    oldpi = policy_fn("oldpi", ob_space, ac_space)
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])
    entp = tf.placeholder(dtype=tf.float32, shape=[])  # entropy coefficient

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)

    entbonus = entp * meanent

    vferr = tf.reduce_mean(tf.square(pi.vpred - ret))

    ratio = tf.exp(pi.pd.logp(ac) -
                   oldpi.pd.logp(ac))  # advantage * pnew / pold
    surrgain = tf.reduce_mean(ratio * atarg)

    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "loss_ent"]

    dist = meankl

    all_var_list = pi.get_trainable_variables()
    var_list = [
        v for v in all_var_list if v.name.split("/")[1].startswith("pol")
    ]
    vf_var_list = [
        v for v in all_var_list if v.name.split("/")[1].startswith("vf")
    ]
    vfadam = MpiAdam(vf_var_list)

    get_flat = U.GetFlat(var_list)
    set_from_flat = U.SetFromFlat(var_list)

    # Fisher-vector product graph: gradient of (grad(meankl) . tangent),
    # where the flat tangent is unflattened into per-variable shapes.
    klgrads = tf.gradients(dist, var_list)
    flat_tangent = tf.placeholder(dtype=tf.float32,
                                  shape=[None],
                                  name="flat_tan")
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start + sz], shape))
        start += sz
    gvp = tf.add_n([
        tf.reduce_sum(g * tangent)
        for (g, tangent) in zipsame(klgrads, tangents)
    ])  #pylint: disable=E1111
    fvp = U.flatgrad(gvp, var_list)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg, entp], losses)
    compute_lossandgrad = U.function([ob, ac, atarg, entp], losses +
                                     [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    compute_vflossandgrad = U.function([ob, ret],
                                       U.flatgrad(vferr, vf_var_list))

    @contextmanager
    def timed(msg):
        """On rank 0, print `msg` and the wall-clock duration of the block."""
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(
                colorize("done in %.3f seconds" % (time.time() - tstart),
                         color='magenta'))
        else:
            yield

    def allmean(x):
        """Average a numpy array across all MPI workers."""
        assert isinstance(x, np.ndarray)
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
        return out

    # U.initialize() runs variable initialization; the previous extra
    # tf.global_variables_initializer() call only built an op that was never run.
    U.initialize()
    th_init = get_flat()
    # Start every worker from rank 0's policy parameters.
    MPI.COMM_WORLD.Bcast(th_init, root=0)
    set_from_flat(th_init)
    vfadam.sync()
    print("Init param sum", th_init.sum(), flush=True)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_batch,
                                     stochastic=True,
                                     gamma=gamma)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards
    drwdsbuffer = deque(maxlen=40)

    assert sum([max_iters > 0, max_timesteps > 0, max_episodes > 0]) == 1

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        logger.log("********** Iteration %i ************" % iters_so_far)

        with timed("sampling"):
            seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        # Entropy bonus disabled; `entc` is unused. A previous annealed
        # schedule was: max(entc - iters_so_far / max_iters, 0.01).
        entcoeff = 0.0

        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before udpate
        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate

        if hasattr(pi, "ret_rms"): pi.ret_rms.update(tdlamret)
        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        args = seg["ob"], seg["ac"], atarg
        # Fisher-vector products use every 5th sample to save compute.
        fvpargs = [arr[::5] for arr in args]

        def fisher_vector_product(p):
            return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p

        assign_old_eq_new()  # set old parameter values to new parameter values
        with timed("computegrad"):
            *lossbefore, g = compute_lossandgrad(*args, entcoeff)
        lossbefore = allmean(np.array(lossbefore))
        g = allmean(g)
        # Default for logging: with a zero gradient no step is taken, so the
        # pre-update losses are reported. Previously `meanlosses` was unbound
        # here and raised NameError.
        meanlosses = lossbefore
        if np.allclose(g, 0):
            print("Got zero gradient. not updating")
        else:
            with timed("cg"):
                stepdir = cg(fisher_vector_product,
                             g,
                             cg_iters=cg_iters,
                             verbose=rank == 0)
            assert np.isfinite(stepdir).all()
            # Scale the natural-gradient direction so the quadratic KL
            # estimate of the full step equals max_kl.
            shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
            lm = np.sqrt(shs / max_kl)
            fullstep = stepdir / lm
            expectedimprove = g.dot(fullstep)
            surrbefore = lossbefore[0]
            stepsize = 1.0
            thbefore = get_flat()
            # Backtracking line search: halve the step up to 10 times.
            for _ in range(10):
                thnew = thbefore + fullstep * stepsize
                set_from_flat(thnew)
                meanlosses = surr, kl, *_ = allmean(
                    np.array(compute_losses(*args, entcoeff)))
                improve = surr - surrbefore
                print("Expected: %.3f Actual: %.3f" %
                      (expectedimprove, improve))
                if not np.isfinite(meanlosses).all():
                    print("Got non-finite value of losses -- bad!")
                elif kl > max_kl * 1.5:
                    print("violated KL constraint. shrinking step.")
                elif improve < 0:
                    print("surrogate didn't improve. shrinking step.")
                else:
                    print("Stepsize OK!")
                    break
                stepsize *= .5
            else:
                print("couldn't compute a good step")
                set_from_flat(thbefore)
            # Periodic sanity check that all workers hold identical parameters.
            if nworkers > 1 and iters_so_far % 20 == 0:
                paramsums = MPI.COMM_WORLD.allgather(
                    (thnew.sum(), vfadam.getflat().sum()))  # list of tuples
                assert all(
                    np.allclose(ps, paramsums[0]) for ps in paramsums[1:])

        for (lossname, lossval) in zip(loss_names, meanlosses):
            logger.logkv(lossname, lossval)

        with timed("vf"):
            for _ in range(vf_iters):
                for (mbob, mbret) in dataset.iterbatches(
                    (seg["ob"], seg["tdlamret"]),
                        include_final_partial_batch=False,
                        batch_size=64):
                    g = allmean(compute_vflossandgrad(mbob, mbret))
                    vfadam.update(g, vf_stepsize)

        logger.logkv("ev_tdlam_before",
                     explained_variance(vpredbefore, tdlamret))

        lrlocal = (seg["ep_lens"], seg["ep_rets"], seg["ep_drwds"]
                   )  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews, drwds = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        drwdsbuffer.extend(drwds)

        logger.logkv("EpLenMean", np.mean(lenbuffer))
        logger.logkv("EpRewMean", np.mean(rewbuffer))
        logger.logkv("EpThisIter", len(lens))
        logger.logkv("EpDRewMean", np.mean(drwdsbuffer))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        logger.logkv("EpisodesSoFar", episodes_so_far)
        logger.logkv("TimestepsSoFar", timesteps_so_far)
        logger.logkv("TimeElapsed", time.time() - tstart)
        logger.logkv('trial', i_trial)
        logger.logkv("Iteration", iters_so_far)
        logger.logkv("Name", 'TRPO')

        if rank == 0:
            logger.dump_tabular()
예제 #7
0
def learn(
        env,
        policy_func,
        reward_giver,
        expert_dataset,
        rank,
        pretrained,
        pretrained_weight,
        *,
        #                   0
        g_step,
        d_step,
        entcoeff,
        save_per_iter,
        #                         1024
        ckpt_dir,
        log_dir,
        timesteps_per_batch,
        task_name,
        robot_name,
        gamma,
        lam,
        max_kl,
        cg_iters,
        cg_damping=1e-2,
        vf_stepsize=3e-4,
        d_stepsize=3e-4,
        vf_iters=3,
        max_timesteps=0,
        max_episodes=0,
        max_iters=0,
        callback=None):

    nworkers = MPI.COMM_WORLD.Get_size()
    rank = MPI.COMM_WORLD.Get_rank()
    np.set_printoptions(precision=3)
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_func("pi",
                     ob_space,
                     ac_space,
                     reuse=(pretrained_weight != None))
    oldpi = policy_func("oldpi", ob_space, ac_space)
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    entbonus = entcoeff * meanent

    vferr = tf.reduce_mean(tf.square(pi.vpred - ret))

    ratio = tf.exp(pi.pd.logp(ac) -
                   oldpi.pd.logp(ac))  # advantage * pnew / pold
    surrgain = tf.reduce_mean(ratio * atarg)

    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

    dist = meankl

    all_var_list = pi.get_trainable_variables()
    var_list = [
        v for v in all_var_list
        if v.name.startswith("pi/pol") or v.name.startswith("pi/logstd")
    ]
    vf_var_list = [v for v in all_var_list if v.name.startswith("pi/vff")]
    assert len(var_list) == len(vf_var_list) + 1
    d_adam = MpiAdam(reward_giver.get_trainable_variables())
    vfadam = MpiAdam(vf_var_list)

    get_flat = U.GetFlat(var_list)
    set_from_flat = U.SetFromFlat(var_list)
    klgrads = tf.gradients(dist, var_list)
    flat_tangent = tf.placeholder(dtype=tf.float32,
                                  shape=[None],
                                  name="flat_tan")
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start + sz], shape))
        start += sz
    gvp = tf.add_n([
        tf.reduce_sum(g * tangent)
        for (g, tangent) in zipsame(klgrads, tangents)
    ])  # pylint: disable=E1111
    fvp = U.flatgrad(gvp, var_list)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg], losses)
    compute_lossandgrad = U.function([ob, ac, atarg], losses +
                                     [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    compute_vflossandgrad = U.function([ob, ret],
                                       U.flatgrad(vferr, vf_var_list))

    @contextmanager
    def timed(msg):
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(
                colorize("done in %.3f seconds" % (time.time() - tstart),
                         color='magenta'))
        else:
            yield

    def allmean(x):
        assert isinstance(x, np.ndarray)
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
        return out

    U.initialize()
    th_init = get_flat()
    MPI.COMM_WORLD.Bcast(th_init, root=0)
    set_from_flat(th_init)
    d_adam.sync()
    vfadam.sync()
    if rank == 0:
        print("Init param sum", th_init.sum(), flush=True)
    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     reward_giver,
                                     timesteps_per_batch,
                                     stochastic=True)
    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards
    true_rewbuffer = deque(maxlen=40)
    assert sum([max_iters > 0, max_timesteps > 0, max_episodes > 0]) == 1

    # if provide pretrained weight
    if pretrained_weight is not None:
        U.load_state(pretrained_weight, var_list=pi.get_variables())

    if robot_name == 'scara':
        summary_writer = tf.summary.FileWriter(
            '/home/yue/gym-gazebo/Tensorboard/scara',
            graph=tf.get_default_graph())
    elif robot_name == 'mara':
        # summary_writer=tf.summary.FileWriter('/home/yue/gym-gazebo/Tensorboard/mara/down-home_position',graph=tf.get_default_graph())
        summary_writer = tf.summary.FileWriter(
            '/home/yue/gym-gazebo/Tensorboard/mara/collisions_model/',
            graph=tf.get_default_graph())

    while True:
        if callback:
            callback(locals(), globals())

        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break

        logger.log("********** Iteration %i ************" % iters_so_far)

        def fisher_vector_product(p):
            return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p

        # ------------------ Update G ------------------
        logger.log("Optimizing Policy...")
        for _ in range(g_step):
            with timed("sampling"):
                seg = seg_gen.__next__()
            add_vtarg_and_adv(seg, gamma, lam)
            # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
            ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
                "tdlamret"]
            vpredbefore = seg[
                "vpred"]  # predicted value function before udpate
            atarg = (atarg - atarg.mean()) / atarg.std(
            )  # standardized advantage function estimate

            if hasattr(pi, "ob_rms"):
                pi.ob_rms.update(ob)  # update running mean/std for policy

            args = seg["ob"], seg["ac"], atarg
            fvpargs = [arr[::5] for arr in args]

            assign_old_eq_new(
            )  # set old parameter values to new parameter values
            with timed("computegrad"):
                *lossbefore, g = compute_lossandgrad(*args)
            lossbefore = allmean(np.array(lossbefore))
            g = allmean(g)
            if np.allclose(g, 0):
                logger.log("Got zero gradient. not updating")
            else:
                with timed("cg"):
                    stepdir = cg(fisher_vector_product,
                                 g,
                                 cg_iters=cg_iters,
                                 verbose=rank == 0)
                assert np.isfinite(stepdir).all()
                shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
                lm = np.sqrt(shs / max_kl)
                # logger.log("lagrange multiplier:", lm, "gnorm:", np.linalg.norm(g))
                fullstep = stepdir / lm
                expectedimprove = g.dot(fullstep)
                surrbefore = lossbefore[0]
                stepsize = 1.0
                thbefore = get_flat()
                for _ in range(10):
                    thnew = thbefore + fullstep * stepsize
                    set_from_flat(thnew)
                    meanlosses = surr, kl, *_ = allmean(
                        np.array(compute_losses(*args)))
                    improve = surr - surrbefore
                    logger.log("Expected: %.3f Actual: %.3f" %
                               (expectedimprove, improve))
                    if not np.isfinite(meanlosses).all():
                        logger.log("Got non-finite value of losses -- bad!")
                    elif kl > max_kl * 1.5:
                        logger.log("violated KL constraint. shrinking step.")
                    elif improve < 0:
                        logger.log("surrogate didn't improve. shrinking step.")
                    else:
                        logger.log("Stepsize OK!")
                        break
                    stepsize *= .5
                else:
                    logger.log("couldn't compute a good step")
                    set_from_flat(thbefore)
                if nworkers > 1 and iters_so_far % 20 == 0:
                    paramsums = MPI.COMM_WORLD.allgather(
                        (thnew.sum(),
                         vfadam.getflat().sum()))  # list of tuples
                    assert all(
                        np.allclose(ps, paramsums[0]) for ps in paramsums[1:])
            with timed("vf"):
                for _ in range(vf_iters):
                    for (mbob, mbret) in dataset.iterbatches(
                        (seg["ob"], seg["tdlamret"]),
                            include_final_partial_batch=False,
                            batch_size=128):

                        if hasattr(pi, "ob_rms"):
                            pi.ob_rms.update(
                                mbob)  # update running mean/std for policy
                        if nworkers != 1:
                            g = allmean(compute_vflossandgrad(mbob, mbret))
                        else:
                            g = compute_vflossandgrad(mbob, mbret)

                        vfadam.update(g, vf_stepsize)

        for (lossname, lossval) in zip(loss_names, meanlosses):
            logger.record_tabular(lossname, lossval)
        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))
        # ------------------ Update D ------------------
        logger.log("Optimizing Discriminator...")
        logger.log(fmt_row(13, reward_giver.loss_name))
        ob_expert, ac_expert = expert_dataset.get_next_batch(len(ob))
        batch_size = len(ob) // d_step
        d_losses = [
        ]  # list of tuples, each of which gives the loss for a minibatch
        for ob_batch, ac_batch in dataset.iterbatches(
            (ob, ac), include_final_partial_batch=False,
                batch_size=batch_size):
            ob_expert, ac_expert = expert_dataset.get_next_batch(len(ob_batch))
            # update running mean/std for reward_giver
            if hasattr(reward_giver, "obs_rms"):
                reward_giver.obs_rms.update(
                    np.concatenate((ob_batch, ob_expert), 0))
            *newlosses, g = reward_giver.lossandgrad(ob_batch, ac_batch,
                                                     ob_expert, ac_expert)
            if nworkers != 1:
                d_adam.update(allmean(g), d_stepsize)
            else:
                d_adam.update(g, d_stepsize)

            d_losses.append(newlosses)
        logger.log(fmt_row(13, np.mean(d_losses, axis=0)))
        g_loss_summary = tf.Summary(value=[
            tf.Summary.Value(tag="g_loss",
                             simple_value=np.mean(d_losses[0][0]))
        ])
        summary_writer.add_summary(g_loss_summary, timesteps_so_far)

        d_loss_summary = tf.Summary(value=[
            tf.Summary.Value(tag="d_loss",
                             simple_value=np.mean(d_losses[0][1]))
        ])
        summary_writer.add_summary(d_loss_summary, timesteps_so_far)

        entropy_summary = tf.Summary(value=[
            tf.Summary.Value(tag="entropy",
                             simple_value=np.mean(d_losses[0][2]))
        ])
        summary_writer.add_summary(entropy_summary, timesteps_so_far)

        entropy_loss_summary = tf.Summary(value=[
            tf.Summary.Value(tag="entropy_loss",
                             simple_value=np.mean(d_losses[0][3]))
        ])
        summary_writer.add_summary(entropy_loss_summary, timesteps_so_far)

        g_acc_summary = tf.Summary(value=[
            tf.Summary.Value(tag="g_acc", simple_value=np.mean(d_losses[0][4]))
        ])
        summary_writer.add_summary(g_acc_summary, timesteps_so_far)

        expert_acc_summary = tf.Summary(value=[
            tf.Summary.Value(tag="expert_acc",
                             simple_value=np.mean(d_losses[0][5]))
        ])
        summary_writer.add_summary(expert_acc_summary, timesteps_so_far)

        lrlocal = (seg["ep_lens"], seg["ep_rets"], seg["ep_true_rets"]
                   )  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews, true_rets = map(flatten_lists, zip(*listoflrpairs))
        true_rewbuffer.extend(true_rets)
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        summary = tf.Summary(value=[
            tf.Summary.Value(tag="MeanDiscriminator",
                             simple_value=np.mean(rewbuffer))
        ])
        summary_writer.add_summary(summary, timesteps_so_far)

        truesummary = tf.Summary(value=[
            tf.Summary.Value(tag="MeanGenerator",
                             simple_value=np.mean(true_rewbuffer))
        ])
        summary_writer.add_summary(truesummary, timesteps_so_far)

        true_rets_summary = tf.Summary(value=[
            tf.Summary.Value(tag="Generator", simple_value=np.mean(true_rets))
        ])
        summary_writer.add_summary(true_rets_summary, timesteps_so_far)

        len_summary = tf.Summary(value=[
            tf.Summary.Value(tag="Length", simple_value=np.mean(lenbuffer))
        ])
        summary_writer.add_summary(len_summary, timesteps_so_far)

        optimgain_summary = tf.Summary(value=[
            tf.Summary.Value(tag="Optimgain",
                             simple_value=np.mean(meanlosses[0]))
        ])
        summary_writer.add_summary(optimgain_summary, timesteps_so_far)

        meankl_summary = tf.Summary(value=[
            tf.Summary.Value(tag="Meankl", simple_value=np.mean(meanlosses[1]))
        ])
        summary_writer.add_summary(meankl_summary, timesteps_so_far)

        entloss_summary = tf.Summary(value=[
            tf.Summary.Value(tag="Entloss",
                             simple_value=np.mean(meanlosses[2]))
        ])
        summary_writer.add_summary(entloss_summary, timesteps_so_far)

        surrgain_summary = tf.Summary(value=[
            tf.Summary.Value(tag="Surrgain",
                             simple_value=np.mean(meanlosses[3]))
        ])
        summary_writer.add_summary(surrgain_summary, timesteps_so_far)

        entropy_summary = tf.Summary(value=[
            tf.Summary.Value(tag="Entropy",
                             simple_value=np.mean(meanlosses[4]))
        ])
        summary_writer.add_summary(entropy_summary, timesteps_so_far)

        epThisIter_summary = tf.Summary(value=[
            tf.Summary.Value(tag="EpThisIter", simple_value=np.mean(len(lens)))
        ])
        summary_writer.add_summary(epThisIter_summary, timesteps_so_far)

        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("MeanDiscriminator", np.mean(rewbuffer))

        # Save model
        if robot_name == 'scara':
            if iters_so_far % save_per_iter == 0:
                if np.mean(rewbuffer) <= 200 or np.mean(
                        true_rewbuffer) >= -100:
                    task_name = str(iters_so_far)
                    fname = os.path.join(ckpt_dir, task_name)
                    os.makedirs(os.path.dirname(fname), exist_ok=True)
                    saver = tf.train.Saver()
                    saver.save(tf.get_default_session(), fname)
                    if iters_so_far == 2000:
                        break

        elif robot_name == 'mara':
            if iters_so_far % save_per_iter == 0:
                # if np.mean(rewbuffer) <= 300 or np.mean(true_rewbuffer) >= -400:
                task_name = str(iters_so_far)
                fname = os.path.join(ckpt_dir, task_name)
                os.makedirs(os.path.dirname(fname), exist_ok=True)
                saver = tf.train.Saver()
                saver.save(tf.get_default_session(), fname)
                if iters_so_far == 5000:
                    break

        logger.record_tabular("MeanGenerator", np.mean(true_rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)

        if rank == 0:
            logger.dump_tabular()
# 예제 #8 (Example #8)
# 0
def learn(env,
          policy_func,
          reward_giver,
          expert_dataset,
          rank,
          g_step,
          d_step,
          entcoeff,
          save_per_iter,
          timesteps_per_batch,
          ckpt_dir,
          log_dir,
          task_name,
          gamma,
          lam,
          max_kl,
          cg_iters,
          cg_damping=1e-2,
          vf_stepsize=3e-4,
          d_stepsize=3e-4,
          vf_iters=3,
          max_timesteps=0,
          max_episodes=0,
          max_iters=0,
          callback=None):
    """Train a policy with TRPO against a learned discriminator (GAIL-style).

    The "pi" policy variables are first restored from
    ``U_.getPath() + '/model/bc.ckpt'``. Each iteration then runs ``g_step``
    TRPO generator updates (each followed by ``vf_iters`` value-function
    epochs) and one discriminator pass over minibatches of size
    ``len(ob) // d_step``.

    Args:
        env: environment exposing ``observation_space`` / ``action_space``.
        policy_func: ``policy_func(name, ob_space, ac_space)`` builds a policy.
        reward_giver: discriminator providing ``get_trainable_variables()``,
            ``lossandgrad(...)`` and ``loss_name``.
        expert_dataset: expert data with ``get_next_batch(n)`` returning
            ``(ob, ac)``.
        rank: ignored; overwritten by the MPI COMM_WORLD rank.
        g_step: number of generator (policy) updates per iteration.
        d_step: number of discriminator minibatches per iteration.
        entcoeff: entropy bonus coefficient in the surrogate objective.
        save_per_iter: checkpoint every this many iterations (rank 0 only).
        timesteps_per_batch: timesteps requested per sampled segment.
        ckpt_dir: checkpoint directory, or ``None`` to disable saving.
        log_dir: not used by this function.
        task_name: checkpoint file name inside ``ckpt_dir``.
        gamma: discount factor passed to ``add_vtarg_and_adv``.
        lam: GAE lambda passed to ``add_vtarg_and_adv``.
        max_kl: KL trust-region size of the TRPO step.
        cg_iters: conjugate-gradient iterations.
        cg_damping: damping added to the Fisher-vector product.
        vf_stepsize: Adam step size for the value function.
        d_stepsize: Adam step size for the discriminator.
        vf_iters: value-function epochs per generator update.
        max_timesteps: stop after this many (local) sampled timesteps.
        max_episodes: stop after this many finished episodes (all workers).
        max_iters: stop after this many iterations.
            Exactly one of max_timesteps / max_episodes / max_iters must be
            positive.
        callback: optional ``callback(locals(), globals())`` invoked at the
            start of every iteration.
    """

    nworkers = MPI.COMM_WORLD.Get_size()
    rank = MPI.COMM_WORLD.Get_rank()
    np.set_printoptions(precision=3)
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_func("pi", ob_space, ac_space)
    # Warm-start only the "pi" scope from the bc.ckpt checkpoint.
    saver = tf.train.Saver(
        var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='pi'))
    saver.restore(tf.get_default_session(), U_.getPath() + '/model/bc.ckpt')

    oldpi = policy_func("oldpi", ob_space, ac_space)
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    entbonus = entcoeff * meanent

    vferr = tf.reduce_mean(tf.square(pi.vpred - ret))

    ratio = tf.exp(pi.pd.logp(ac) -
                   oldpi.pd.logp(ac))  # advantage * pnew / pold
    surrgain = tf.reduce_mean(ratio * atarg)

    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

    dist = meankl

    all_var_list = pi.get_trainable_variables()
    # Policy variables (TRPO step) vs value-function variables (Adam).
    var_list = [
        v for v in all_var_list
        if v.name.startswith("pi/pol") or v.name.startswith("pi/logstd")
    ]
    vf_var_list = [v for v in all_var_list if v.name.startswith("pi/vff")]
    assert len(var_list) == len(vf_var_list) + 1
    d_adam = MpiAdam(reward_giver.get_trainable_variables())
    vfadam = MpiAdam(vf_var_list)

    get_flat = U.GetFlat(var_list)
    set_from_flat = U.SetFromFlat(var_list)
    klgrads = tf.gradients(dist, var_list)
    flat_tangent = tf.placeholder(dtype=tf.float32,
                                  shape=[None],
                                  name="flat_tan")
    # Split the flat tangent vector back into per-variable shaped pieces.
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start + sz], shape))
        start += sz
    gvp = tf.add_n([
        tf.reduce_sum(g * tangent)
        for (g, tangent) in zipsame(klgrads, tangents)
    ])  # pylint: disable=E1111
    fvp = U.flatgrad(gvp, var_list)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg], losses)
    compute_lossandgrad = U.function([ob, ac, atarg], losses +
                                     [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    compute_vflossandgrad = U.function([ob, ret],
                                       U.flatgrad(vferr, vf_var_list))

    @contextmanager
    def timed(msg):
        # Print timing only on rank 0 to avoid duplicated output.
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(
                colorize("done in %.3f seconds" % (time.time() - tstart),
                         color='magenta'))
        else:
            yield

    def allmean(x):
        # Element-wise mean of x across all MPI workers.
        assert isinstance(x, np.ndarray)
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
        return out

    U.initialize()
    # Broadcast rank 0's policy parameters so all workers start identical.
    th_init = get_flat()
    MPI.COMM_WORLD.Bcast(th_init, root=0)
    set_from_flat(th_init)
    d_adam.sync()
    vfadam.sync()
    if rank == 0:
        print("Init param sum", th_init.sum())

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     reward_giver,
                                     timesteps_per_batch,
                                     stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards
    true_rewbuffer = deque(maxlen=40)

    assert sum([max_iters > 0, max_timesteps > 0, max_episodes > 0]) == 1

    g_loss_stats = stats(loss_names)
    d_loss_stats = stats(reward_giver.loss_name)
    ep_stats = stats(["True_rewards", "Rewards", "Episode_length"])

    # Built once on first save: each tf.train.Saver() construction adds new
    # save/restore ops to the graph, so building one per save grows the graph.
    ckpt_saver = None

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break

        # Save model
        if rank == 0 and iters_so_far % save_per_iter == 0 and ckpt_dir is not None:
            fname = os.path.join(ckpt_dir, task_name)
            print('save model as ', fname)
            ckpt_parent = os.path.dirname(fname)
            if ckpt_parent:  # '' when ckpt_dir is '' -> save in the CWD
                os.makedirs(ckpt_parent, exist_ok=True)
            if ckpt_saver is None:
                ckpt_saver = tf.train.Saver()
            ckpt_saver.save(tf.get_default_session(), fname)

        print("********** Iteration %i ************" % iters_so_far)

        def fisher_vector_product(p):
            return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p

        # ------------------ Update G ------------------
        print("Optimizing Policy...")
        iter_timesteps = 0  # local timesteps sampled over all g_step segments
        iter_episodes = 0  # local episodes finished over all g_step segments
        for _ in range(g_step):
            with timed("sampling"):
                # Python 3 generators have no .next() method; next() works.
                seg = next(seg_gen)
            iter_timesteps += len(seg["ob"])
            # Counts as 0 if the segment does not report "ep_lens".
            iter_episodes += len(seg.get("ep_lens", ()))
            add_vtarg_and_adv(seg, gamma, lam)
            ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
                "tdlamret"]
            vpredbefore = seg["vpred"]  # predicted value function before update
            # standardized advantage function estimate
            atarg = (atarg - atarg.mean()) / atarg.std()

            if hasattr(pi, "ob_rms"):
                pi.ob_rms.update(ob)  # update running mean/std for policy

            args = seg["ob"], seg["ac"], atarg
            # Every 5th sample is enough for the Fisher-vector product.
            fvpargs = [arr[::5] for arr in args]

            assign_old_eq_new()  # set old parameter values to new parameter values
            with timed("computegrad"):
                *lossbefore, g = compute_lossandgrad(*args)
            lossbefore = allmean(np.array(lossbefore))
            g = allmean(g)
            if np.allclose(g, 0):
                print("Got zero gradient. not updating")
            else:
                with timed("cg"):
                    stepdir = cg(fisher_vector_product,
                                 g,
                                 cg_iters=cg_iters,
                                 verbose=rank == 0)
                assert np.isfinite(stepdir).all()
                # Scale the natural-gradient direction so the quadratic KL
                # estimate of the full step equals max_kl.
                shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
                lm = np.sqrt(shs / max_kl)
                fullstep = stepdir / lm
                expectedimprove = g.dot(fullstep)
                surrbefore = lossbefore[0]
                stepsize = 1.0
                thbefore = get_flat()
                # Backtracking line search: halve the step up to 10 times.
                for _ in range(10):
                    thnew = thbefore + fullstep * stepsize
                    set_from_flat(thnew)
                    meanlosses = allmean(np.array(compute_losses(*args)))
                    surr = meanlosses[0]
                    kl = meanlosses[1]
                    improve = surr - surrbefore
                    print("Expected: %.3f Actual: %.3f" %
                          (expectedimprove, improve))
                    if not np.isfinite(meanlosses).all():
                        print("Got non-finite value of losses -- bad!")
                    elif kl > max_kl * 1.5:
                        print("violated KL constraint. shrinking step.")
                    elif improve < 0:
                        print("surrogate didn't improve. shrinking step.")
                    else:
                        print("Stepsize OK!")
                        break
                    stepsize *= .5
                else:
                    print("couldn't compute a good step")
                    set_from_flat(thbefore)
                if nworkers > 1 and iters_so_far % 20 == 0:
                    # Sanity check that all workers hold identical parameters.
                    paramsums = MPI.COMM_WORLD.allgather(
                        (thnew.sum(),
                         vfadam.getflat().sum()))  # list of tuples
                    assert all(
                        np.allclose(ps, paramsums[0]) for ps in paramsums[1:])
            with timed("vf"):
                for _ in range(vf_iters):
                    for (mbob, mbret) in dataset.iterbatches(
                        (seg["ob"], seg["tdlamret"]),
                            include_final_partial_batch=False,
                            batch_size=128):
                        if hasattr(pi, "ob_rms"):
                            pi.ob_rms.update(
                                mbob)  # update running mean/std for policy
                        g = allmean(compute_vflossandgrad(mbob, mbret))
                        vfadam.update(g, vf_stepsize)

        print("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))

        # ------------------ Update D ------------------
        print("Optimizing Discriminator...")
        print(fmt_row(13, reward_giver.loss_name))
        # NOTE(review): this batch is overwritten in the loop below before it
        # is used; the call is kept only in case get_next_batch advances the
        # dataset's internal cursor.
        ob_expert, ac_expert = expert_dataset.get_next_batch(len(ob))
        batch_size = len(ob) // d_step
        d_losses = []  # one loss tuple per discriminator minibatch
        for ob_batch, ac_batch in tqdm(
                dataset.iterbatches((ob, ac),
                                    include_final_partial_batch=False,
                                    batch_size=batch_size)):
            ob_expert, ac_expert = expert_dataset.get_next_batch(len(ob_batch))
            # update running mean/std for reward_giver
            if hasattr(reward_giver, "obs_rms"):
                reward_giver.obs_rms.update(
                    np.concatenate((ob_batch, ob_expert), 0))
            *newlosses, g = reward_giver.lossandgrad(ob_batch, ac_batch,
                                                     ob_expert, ac_expert)
            d_adam.update(allmean(g), d_stepsize)
            d_losses.append(newlosses)
        print(fmt_row(13, np.mean(d_losses, axis=0)))

        # Count every segment sampled this iteration, not just the last one.
        timesteps_so_far += iter_timesteps
        # Sum episode counts over all workers so every rank reaches the
        # max_episodes limit on the same iteration (a rank leaving the loop
        # early would leave the others blocked in MPI collectives).
        episodes_so_far += sum(MPI.COMM_WORLD.allgather(iter_episodes))
        iters_so_far += 1

        print("EpisodesSoFar", episodes_so_far)
        print("TimestepsSoFar", timesteps_so_far)
        print("TimeElapsed", time.time() - tstart)
# 예제 #9 (Example #9)
# 0
def learn(*,
        network,
        env,
        total_timesteps,
        timesteps_per_batch=1024, # what to train on
        max_kl=0.001,
        cg_iters=10,
        gamma=0.99,
        lam=1.0, # advantage estimation
        seed=None,
        ent_coef=0.0,
        cg_damping=1e-2,
        vf_stepsize=3e-4,
        vf_iters =3,
        max_episodes=0, max_iters=0,  # time constraint
        callback=None,
        load_path=None,
        **network_kwargs
        ):
    '''
    Learn a policy function with the TRPO algorithm.

    Parameters:
    ----------

    network                 neural network to learn. Can be either string ('mlp', 'cnn', 'lstm', 'lnlstm' for basic types)
                            or function that takes input placeholder and returns tuple (output, None) for feedforward nets
                            or (output, (state_placeholder, state_output, mask_placeholder)) for recurrent nets

    env                     environment (one of the gym environments or wrapped via baselines.common.vec_env.VecEnv-type class)

    total_timesteps         max number of timesteps (0 disables this stopping criterion)

    timesteps_per_batch     timesteps per gradient estimation batch

    max_kl                  max KL divergence between old policy and new policy ( KL(pi_old || pi) )

    cg_iters                number of iterations of conjugate gradient algorithm

    gamma                   discount factor passed to add_vtarg_and_adv

    lam                     GAE lambda passed to add_vtarg_and_adv

    seed                    random seed passed to set_global_seeds

    ent_coef                coefficient of policy entropy term in the optimization objective

    cg_damping              conjugate gradient damping

    vf_stepsize             learning rate for the Adam optimizer used to optimize the value function loss

    vf_iters                number of value function optimization epochs per policy optimization step

    max_episodes            max number of episodes

    max_iters               maximum number of policy optimization iterations

    callback                function to be called with (locals(), globals()) each policy optimization step

    load_path               str, path to load the model from (default: None, i.e. no model is loaded)

    **network_kwargs        keyword arguments to the policy / network builder. See baselines.common/policies.py/build_policy and arguments to a particular type of network

    At most one of total_timesteps, max_iters and max_episodes may be positive;
    if none is, the untrained policy is returned immediately.

    Returns:
    -------

    learnt model (the "pi" policy object)

    '''

    if MPI is not None:
        nworkers = MPI.COMM_WORLD.Get_size()
        rank = MPI.COMM_WORLD.Get_rank()
    else:
        nworkers = 1
        rank = 0

    cpus_per_worker = 1
    U.get_session(config=tf.ConfigProto(
            allow_soft_placement=True,
            inter_op_parallelism_threads=cpus_per_worker,
            intra_op_parallelism_threads=cpus_per_worker
    ))


    policy = build_policy(env, network, value_network='copy', **network_kwargs)
    set_global_seeds(seed)

    np.set_printoptions(precision=3)
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space

    ob = observation_placeholder(ob_space)
    with tf.variable_scope("pi"):
        pi = policy(observ_placeholder=ob)
    with tf.variable_scope("oldpi"):
        oldpi = policy(observ_placeholder=ob)

    atarg = tf.placeholder(dtype=tf.float32, shape=[None]) # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None]) # Empirical return

    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    entbonus = ent_coef * meanent

    vferr = tf.reduce_mean(tf.square(pi.vf - ret))

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac)) # advantage * pnew / pold
    surrgain = tf.reduce_mean(ratio * atarg)

    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

    dist = meankl

    # Policy variables get the TRPO step; value-function variables use Adam.
    var_list = get_pi_trainable_variables("pi")
    vf_var_list = get_vf_trainable_variables("pi")

    vfadam = MpiAdam(vf_var_list)

    get_flat = U.GetFlat(var_list)
    set_from_flat = U.SetFromFlat(var_list)
    klgrads = tf.gradients(dist, var_list)
    flat_tangent = tf.placeholder(dtype=tf.float32, shape=[None], name="flat_tan")
    # Split the flat tangent vector back into per-variable shaped pieces.
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start+sz], shape))
        start += sz
    gvp = tf.add_n([tf.reduce_sum(g*tangent) for (g, tangent) in zipsame(klgrads, tangents)]) #pylint: disable=E1111
    fvp = U.flatgrad(gvp, var_list)

    assign_old_eq_new = U.function([],[], updates=[tf.assign(oldv, newv)
        for (oldv, newv) in zipsame(get_variables("oldpi"), get_variables("pi"))])

    compute_losses = U.function([ob, ac, atarg], losses)
    compute_lossandgrad = U.function([ob, ac, atarg], losses + [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    compute_vflossandgrad = U.function([ob, ret], U.flatgrad(vferr, vf_var_list))

    @contextmanager
    def timed(msg):
        # Print timing only on rank 0 to avoid duplicated output.
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(colorize("done in %.3f seconds"%(time.time() - tstart), color='magenta'))
        else:
            yield

    def allmean(x):
        # Element-wise mean of x across MPI workers (a copy when MPI is absent).
        assert isinstance(x, np.ndarray)
        if MPI is not None:
            out = np.empty_like(x)
            MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
            out /= nworkers
        else:
            out = np.copy(x)

        return out

    U.initialize()
    if load_path is not None:
        pi.load(load_path)

    # Broadcast rank 0's policy parameters so all workers start identical.
    th_init = get_flat()
    if MPI is not None:
        MPI.COMM_WORLD.Bcast(th_init, root=0)

    set_from_flat(th_init)
    vfadam.sync()
    print("Init param sum", th_init.sum(), flush=True)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi, env, timesteps_per_batch, stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40) # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40) # rolling buffer for episode rewards

    if sum([max_iters>0, total_timesteps>0, max_episodes>0])==0:
        # nothing to be done
        return pi

    assert sum([max_iters>0, total_timesteps>0, max_episodes>0]) < 2, \
        'out of max_iters, total_timesteps, and max_episodes only one should be specified'

    while True:
        if callback: callback(locals(), globals())
        if total_timesteps and timesteps_so_far >= total_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        logger.log("********** Iteration %i ************"%iters_so_far)

        with timed("sampling"):
            seg = next(seg_gen)
        add_vtarg_and_adv(seg, gamma, lam)

        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]
        vpredbefore = seg["vpred"] # predicted value function before update
        atarg = (atarg - atarg.mean()) / atarg.std() # standardized advantage function estimate

        if hasattr(pi, "ret_rms"): pi.ret_rms.update(tdlamret)
        if hasattr(pi, "ob_rms"): pi.ob_rms.update(ob) # update running mean/std for policy

        args = seg["ob"], seg["ac"], atarg
        # Every 5th sample is enough for the Fisher-vector product.
        fvpargs = [arr[::5] for arr in args]
        def fisher_vector_product(p):
            return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p

        assign_old_eq_new() # set old parameter values to new parameter values
        with timed("computegrad"):
            *lossbefore, g = compute_lossandgrad(*args)
        lossbefore = allmean(np.array(lossbefore))
        g = allmean(g)
        # Losses at the current parameters; what gets logged when no step is
        # accepted. Without this default a zero gradient on the first
        # iteration leaves `meanlosses` unbound at the logging below.
        meanlosses = lossbefore
        if np.allclose(g, 0):
            logger.log("Got zero gradient. not updating")
        else:
            with timed("cg"):
                stepdir = cg(fisher_vector_product, g, cg_iters=cg_iters, verbose=rank==0)
            assert np.isfinite(stepdir).all()
            # Scale the natural-gradient direction so the quadratic KL
            # estimate of the full step equals max_kl.
            shs = .5*stepdir.dot(fisher_vector_product(stepdir))
            lm = np.sqrt(shs / max_kl)
            fullstep = stepdir / lm
            expectedimprove = g.dot(fullstep)
            surrbefore = lossbefore[0]
            stepsize = 1.0
            thbefore = get_flat()
            # Backtracking line search: halve the step up to 10 times.
            for _ in range(10):
                thnew = thbefore + fullstep * stepsize
                set_from_flat(thnew)
                meanlosses = surr, kl, *_ = allmean(np.array(compute_losses(*args)))
                improve = surr - surrbefore
                logger.log("Expected: %.3f Actual: %.3f"%(expectedimprove, improve))
                if not np.isfinite(meanlosses).all():
                    logger.log("Got non-finite value of losses -- bad!")
                elif kl > max_kl * 1.5:
                    logger.log("violated KL constraint. shrinking step.")
                elif improve < 0:
                    logger.log("surrogate didn't improve. shrinking step.")
                else:
                    logger.log("Stepsize OK!")
                    break
                stepsize *= .5
            else:
                logger.log("couldn't compute a good step")
                set_from_flat(thbefore)
                # Parameters were restored, so log their losses rather than
                # those of the last rejected step.
                meanlosses = lossbefore
            if nworkers > 1 and iters_so_far % 20 == 0:
                # Sanity check that all workers hold identical parameters.
                paramsums = MPI.COMM_WORLD.allgather((thnew.sum(), vfadam.getflat().sum())) # list of tuples
                assert all(np.allclose(ps, paramsums[0]) for ps in paramsums[1:])

        for (lossname, lossval) in zip(loss_names, meanlosses):
            logger.record_tabular(lossname, lossval)

        with timed("vf"):

            for _ in range(vf_iters):
                for (mbob, mbret) in dataset.iterbatches((seg["ob"], seg["tdlamret"]),
                include_final_partial_batch=False, batch_size=64):
                    g = allmean(compute_vflossandgrad(mbob, mbret))
                    vfadam.update(g, vf_stepsize)

        logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))

        lrlocal = (seg["ep_lens"], seg["ep_rets"]) # local values
        if MPI is not None:
            listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal) # list of tuples
        else:
            listoflrpairs = [lrlocal]

        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)

        if rank==0:
            logger.dump_tabular()

    return pi


def learn(env,
          policy_func,
          reward_giver,
          expert_dataset,
          rank,
          pretrained,
          pretrained_weight,
          *,
          g_step,
          d_step,
          entcoeff,
          save_per_iter,
          ckpt_dir,
          log_dir,
          timesteps_per_batch,
          task_name,
          gamma,
          lam,
          max_kl,
          cg_iters,
          cg_damping=1e-2,
          vf_stepsize=3e-4,
          d_stepsize=3e-4,
          vf_iters=3,
          max_timesteps=0,
          max_episodes=0,
          max_iters=0,
          vf_batchsize=128,
          callback=None,
          freeze_g=False,
          freeze_d=False,
          semi_dataset=None,
          semi_loss=False):
    """Train a TRPO generator against a learned discriminator (GAIL-style).

    Each outer iteration runs up to ``g_step`` generator updates (natural
    gradient step via conjugate gradient plus a backtracking line search,
    followed by value-function Adam updates). Unless ``freeze_d`` is set,
    it then runs one pass of discriminator updates that compare the last
    generator segment with expert samples.

    Args:
        env: environment exposing ``observation_space`` and ``action_space``.
            Rank 0 also reads ``env.env.env.__class__.__name__`` to choose the
            checkpoint naming scheme.
        policy_func: ``policy_func(name, ob_space=..., ac_space=..., reuse=...)``
            builds a policy graph under scope ``name``.
        reward_giver: discriminator exposing ``get_trainable_variables()``,
            ``get_reward(ob, ac)``, ``lossandgrad(...)``, ``loss_name`` and
            optionally ``obs_rms``.
        expert_dataset: provides ``num_traj`` and ``get_next_batch(n)``.
        rank: ignored; it is overwritten with the MPI rank.
        pretrained: unused by this function.
        pretrained_weight: checkpoint path to restore from. If the path ends
            in ``iter_<N>``, training resumes at iteration ``N + 1``.
            NOTE(review): whether ``U.load_checkpoint_variables`` accepts
            ``None`` is not visible here.
        g_step: generator updates per outer iteration.
        d_step: number of discriminator minibatches per outer iteration.
            The batch size is ``len(ob) // d_step``.
        entcoeff: entropy bonus coefficient added to the surrogate.
        save_per_iter: rank 0 saves a checkpoint every this many iterations.
        ckpt_dir: checkpoint directory. ``None`` disables all checkpointing.
        log_dir: summary-writer directory (rank 0 only).
        timesteps_per_batch: total steps per segment across all workers. It
            is divided (ceil) by the number of MPI workers.
        task_name: checkpoint file name when not saving one file per iteration.
        gamma: discount factor, passed to ``add_vtarg_and_adv``.
        lam: GAE lambda, passed to ``add_vtarg_and_adv``.
        max_kl: trust-region radius (mean KL) for the policy step.
        cg_iters: conjugate-gradient iterations.
        cg_damping: damping added to the Fisher-vector product.
        vf_stepsize: value-function Adam step size.
        d_stepsize: discriminator Adam step size.
        vf_iters: value-function epochs per generator step.
        max_timesteps: stopping criterion; exactly one of this,
            ``max_episodes`` and ``max_iters`` must be positive.
        max_episodes: stopping criterion (see ``max_timesteps``).
        max_iters: stopping criterion (see ``max_timesteps``).
        vf_batchsize: minibatch size for value-function updates and loss
            evaluation.
        callback: called as ``callback(locals(), globals())`` each iteration.
        freeze_g: if True, ``not_update`` is never decremented. A generator
            that starts frozen therefore stays frozen.
        freeze_d: if True, the discriminator is never updated.
        semi_dataset: optional dataset providing ``semi_size``, ``num_traj``
            and optionally ``info``. It changes the observation space via
            ``semi_ob_space``.
        semi_loss: forwarded to ``traj_segment_generator`` and enables L2-loss
            logging. It is only honored when ``semi_dataset`` is given.

    Returns:
        None.
    """
    semi_loss = semi_loss and semi_dataset is not None
    l2_w = 0.1  # weight of the per-step "l2_loss" subtracted from rewards

    nworkers = MPI.COMM_WORLD.Get_size()
    rank = MPI.COMM_WORLD.Get_rank()  # overrides the `rank` argument
    np.set_printoptions(precision=3)

    if rank == 0:
        writer = U.file_writer(log_dir)

        # print all the hyperparameters in the log...
        log_dict = {
            "expert trajectories": expert_dataset.num_traj,
            "algo": "trpo",
            "threads": nworkers,
            "timesteps_per_batch": timesteps_per_batch,
            "timesteps_per_thread": -(-timesteps_per_batch // nworkers),
            "entcoeff": entcoeff,
            "vf_iters": vf_iters,
            "vf_batchsize": vf_batchsize,
            "vf_stepsize": vf_stepsize,
            "d_stepsize": d_stepsize,
            "g_step": g_step,
            "d_step": d_step,
            "max_kl": max_kl,
            "gamma": gamma,
            "lam": lam,
            "l2_weight": l2_w
        }

        if semi_dataset is not None:
            log_dict["semi trajectories"] = semi_dataset.num_traj
        if hasattr(semi_dataset, 'info'):
            log_dict["semi_dataset_info"] = semi_dataset.info

        # print them all together for csv
        logger.log(",".join([str(elem) for elem in log_dict]))
        logger.log(",".join([str(elem) for elem in log_dict.values()]))

        # also print them separately for easy reading:
        for elem in log_dict:
            logger.log(str(elem) + ": " + str(log_dict[elem]))

    # divide the timesteps to the threads
    timesteps_per_batch = -(-timesteps_per_batch // nworkers
                            )  # get ceil of division

    # Setup losses and stuff
    # ----------------------------------------
    if semi_dataset:
        ob_space = semi_ob_space(env, semi_size=semi_dataset.semi_size)
    else:
        ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_func("pi",
                     ob_space=ob_space,
                     ac_space=ac_space,
                     reuse=(pretrained_weight is not None))
    oldpi = policy_func("oldpi", ob_space=ob_space, ac_space=ac_space)
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    entbonus = entcoeff * meanent

    vferr = tf.reduce_mean(tf.square(pi.vpred - ret))

    ratio = tf.exp(pi.pd.logp(ac) -
                   oldpi.pd.logp(ac))  # advantage * pnew / pold
    surrgain = tf.reduce_mean(ratio * atarg)

    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

    vf_losses = [vferr]
    vf_loss_names = ["vf_loss"]

    dist = meankl

    # Policy parameters (natural-gradient step) vs. value parameters (Adam).
    all_var_list = pi.get_trainable_variables()
    var_list = [
        v for v in all_var_list
        if v.name.startswith("pi/pol") or v.name.startswith("pi/logstd")
    ]
    vf_var_list = [v for v in all_var_list if v.name.startswith("pi/vf")]
    assert len(var_list) == len(vf_var_list) + 1

    get_flat = U.GetFlat(var_list)
    set_from_flat = U.SetFromFlat(var_list)
    klgrads = tf.gradients(dist, var_list)
    flat_tangent = tf.placeholder(dtype=tf.float32,
                                  shape=[None],
                                  name="flat_tan")
    # Split the flat tangent vector back into per-variable shapes so the
    # Fisher-vector product can be formed as grad(grad(KL) . tangent).
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start + sz], shape))
        start += sz
    gvp = tf.add_n([
        tf.reduce_sum(g * tangent)
        for (g, tangent) in zipsame(klgrads, tangents)
    ])  # pylint: disable=E1111
    fvp = U.flatgrad(gvp, var_list)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg], losses)
    compute_vf_losses = U.function([ob, ac, atarg, ret], losses + vf_losses)
    compute_lossandgrad = U.function([ob, ac, atarg], losses +
                                     [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    compute_vflossandgrad = U.function([ob, ret], vf_losses +
                                       [U.flatgrad(vferr, vf_var_list)])

    @contextmanager
    def timed(msg):
        """Print `msg` and the elapsed time of the body on rank 0 only."""
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(
                colorize("done in %.3f seconds" % (time.time() - tstart),
                         color='magenta'))
        else:
            yield

    def allmean(x):
        """Average a numpy array element-wise across all MPI workers."""
        assert isinstance(x, np.ndarray)
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
        return out

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards
    true_rewbuffer = deque(maxlen=40)
    success_buffer = deque(maxlen=40)
    l2_rewbuffer = deque(
        maxlen=40) if semi_loss and semi_dataset is not None else None
    total_rewbuffer = deque(
        maxlen=40) if semi_loss and semi_dataset is not None else None

    assert sum([max_iters > 0, max_timesteps > 0, max_episodes > 0]) == 1

    not_update = 1 if not freeze_d else 0  # do not update G before D the first time
    # if provide pretrained weight
    loaded = False
    if not U.load_checkpoint_variables(pretrained_weight):
        if U.load_checkpoint_variables(pretrained_weight,
                                       check_prefix=get_il_prefix()):
            if rank == 0:
                logger.log("loaded checkpoint variables from " +
                           pretrained_weight)
            loaded = True
    else:
        loaded = True

    if loaded:
        # Skip the "D first" phase only if discriminator weights were restored.
        not_update = 0 if any(
            [x.op.name.find("adversary") != -1
             for x in U.ALREADY_INITIALIZED]) else 1
        # Resume the iteration counter when the path ends in "iter_<N>".
        # str.rfind returns -1 when the marker is absent and 0 when the path
        # starts with it, so compare against -1 instead of using truthiness.
        iter_marker = "iter_"
        iter_pos = (pretrained_weight.rfind(iter_marker)
                    if pretrained_weight else -1)
        iter_suffix = (pretrained_weight[iter_pos + len(iter_marker):]
                       if iter_pos != -1 else "")
        if iter_suffix.isdigit():
            curr_iter = int(iter_suffix) + 1
            print("loaded checkpoint at iteration: " + str(curr_iter))
            iters_so_far = curr_iter
            timesteps_so_far = iters_so_far * timesteps_per_batch

    d_adam = MpiAdam(reward_giver.get_trainable_variables())
    vfadam = MpiAdam(vf_var_list)

    U.initialize()

    # Make every worker start from rank 0's policy parameters.
    th_init = get_flat()
    MPI.COMM_WORLD.Bcast(th_init, root=0)
    set_from_flat(th_init)
    d_adam.sync()
    vfadam.sync()
    if rank == 0:
        print("Init param sum", th_init.sum(), flush=True)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(
        pi,
        env,
        reward_giver,
        timesteps_per_batch,
        stochastic=True,
        semi_dataset=semi_dataset,
        semi_loss=semi_loss)  # ADD L2 loss to semi trajectories

    g_loss_stats = stats(loss_names + vf_loss_names)
    d_loss_stats = stats(reward_giver.loss_name)
    ep_names = ["True_rewards", "Rewards", "Episode_length", "Success"]
    if semi_loss and semi_dataset is not None:
        ep_names.append("L2_loss")
        ep_names.append("total_rewards")
    ep_stats = stats(ep_names)

    if rank == 0:
        start_time = time.time()
        ch_count = 0  # number of hourly checkpoints written so far
        env_type = env.env.env.__class__.__name__

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break

        # Save model
        if rank == 0 and iters_so_far % save_per_iter == 0 and ckpt_dir is not None:
            fname = os.path.join(ckpt_dir, task_name)
            if env_type.find(
                    "Pendulum"
            ) != -1 or save_per_iter != 1:  # allow pendulum to save all iterations
                fname = os.path.join(ckpt_dir, 'iter_' + str(iters_so_far),
                                     'iter_' + str(iters_so_far))
            os.makedirs(os.path.dirname(fname), exist_ok=True)
            saver = tf.train.Saver()
            saver.save(tf.get_default_session(), fname, write_meta_graph=False)

        # Save a separate checkpoint every hour. Like the periodic save above,
        # this is skipped when no checkpoint directory was given; otherwise
        # os.path.join(None, ...) would raise TypeError.
        if rank == 0 and ckpt_dir is not None and \
                time.time() - start_time >= 3600 * ch_count:
            fname = os.path.join(ckpt_dir, 'hour' + str(ch_count).zfill(3))
            fname = os.path.join(fname, 'iter_' + str(iters_so_far))
            os.makedirs(os.path.dirname(fname), exist_ok=True)
            saver = tf.train.Saver()
            saver.save(tf.get_default_session(), fname, write_meta_graph=False)
            ch_count += 1

        logger.log("********** Iteration %i ************" % iters_so_far)

        def fisher_vector_product(p):
            """Damped Fisher-vector product, averaged across workers."""
            return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p

        # ------------------ Update G ------------------
        logger.log("Optimizing Policy...")
        for curr_step in range(g_step):
            with timed("sampling"):
                seg = seg_gen.__next__()

            # Penalize rewards by the weighted per-step L2 loss.
            # NOTE(review): applied unconditionally; presumably "l2_loss" is
            # zero when semi_loss is off — confirm in traj_segment_generator.
            seg["rew"] = seg["rew"] - seg["l2_loss"] * l2_w

            add_vtarg_and_adv(seg, gamma, lam)
            ob, ac, atarg, tdlamret, full_ob = seg["ob"], seg["ac"], seg[
                "adv"], seg["tdlamret"], seg["full_ob"]
            vpredbefore = seg[
                "vpred"]  # predicted value function before udpate
            atarg = (atarg - atarg.mean()) / atarg.std(
            )  # standardized advantage function estimate
            d = Dataset(dict(ob=full_ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                        shuffle=True)

            if not_update:
                break  # stop G from updating

            if hasattr(pi, "ob_rms"):
                pi.ob_rms.update(full_ob)  # update running mean/std for policy

            args = seg["full_ob"], seg["ac"], atarg
            fvpargs = [arr[::5] for arr in args]  # subsample for cheaper FVPs

            assign_old_eq_new(
            )  # set old parameter values to new parameter values
            with timed("computegrad"):
                *lossbefore, g = compute_lossandgrad(*args)
            lossbefore = allmean(np.array(lossbefore))
            g = allmean(g)
            if np.allclose(g, 0):
                logger.log("Got zero gradient. not updating")
            else:
                with timed("cg"):
                    stepdir = cg(fisher_vector_product,
                                 g,
                                 cg_iters=cg_iters,
                                 verbose=rank == 0)
                assert np.isfinite(stepdir).all()
                # Scale the CG direction so the quadratic KL estimate equals max_kl.
                shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
                lm = np.sqrt(shs / max_kl)
                fullstep = stepdir / lm
                expectedimprove = g.dot(fullstep)
                surrbefore = lossbefore[0]
                stepsize = 1.0
                thbefore = get_flat()
                # Backtracking line search: halve the step until the KL
                # constraint holds and the surrogate improves.
                for _ in range(10):
                    thnew = thbefore + fullstep * stepsize
                    set_from_flat(thnew)
                    meanlosses = surr, kl, *_ = allmean(
                        np.array(compute_losses(*args)))
                    if rank == 0:
                        print("Generator entropy " + str(meanlosses[4]) +
                              ", loss " + str(meanlosses[2]))
                    improve = surr - surrbefore
                    logger.log("Expected: %.3f Actual: %.3f" %
                               (expectedimprove, improve))
                    if not np.isfinite(meanlosses).all():
                        logger.log("Got non-finite value of losses -- bad!")
                    elif kl > max_kl * 1.5:
                        logger.log("violated KL constraint. shrinking step.")
                    elif improve < 0:
                        logger.log("surrogate didn't improve. shrinking step.")
                    else:
                        logger.log("Stepsize OK!")
                        break
                    stepsize *= .5
                else:
                    logger.log("couldn't compute a good step")
                    set_from_flat(thbefore)
                # Periodically verify that all workers hold identical parameters.
                if nworkers > 1 and iters_so_far % 20 == 0:
                    paramsums = MPI.COMM_WORLD.allgather(
                        (thnew.sum(),
                         vfadam.getflat().sum()))  # list of tuples
                    assert all(
                        np.allclose(ps, paramsums[0]) for ps in paramsums[1:])
            with timed("vf"):
                logger.log(fmt_row(13, vf_loss_names))
                for _ in range(vf_iters):
                    vf_b_losses = []
                    for batch in d.iterate_once(vf_batchsize):
                        mbob = batch["ob"]
                        mbret = batch["vtarg"]

                        if hasattr(pi, "ob_rms"):
                            pi.ob_rms.update(
                                mbob)  # update running mean/std for policy
                        *newlosses, g = compute_vflossandgrad(mbob, mbret)
                        g = allmean(g)
                        newlosses = allmean(np.array(newlosses))

                        vfadam.update(g, vf_stepsize)
                        vf_b_losses.append(newlosses)
                    logger.log(fmt_row(13, np.mean(vf_b_losses, axis=0)))

            logger.log("Evaluating losses...")
            losses = []
            for batch in d.iterate_once(vf_batchsize):
                newlosses = compute_vf_losses(batch["ob"], batch["ac"],
                                              batch["atarg"], batch["vtarg"])
                losses.append(newlosses)
            meanlosses, _, _ = mpi_moments(losses, axis=0)

            #########################
            # For evaluation during training.
            # Can be commented out for faster training...
            for ob_batch, ac_batch, full_ob_batch in dataset.iterbatches(
                (ob, ac, full_ob),
                    include_final_partial_batch=False,
                    batch_size=len(ob)):
                ob_expert, ac_expert = expert_dataset.get_next_batch(
                    len(ob_batch))
                exp_rew = 0
                for obs, acs in zip(ob_expert, ac_expert):
                    exp_rew += 1 - np.exp(
                        -reward_giver.get_reward(obs, acs)[0][0])
                mean_exp_rew = exp_rew / len(ob_expert)

                gen_rew = 0
                for obs, acs, full_obs in zip(ob_batch, ac_batch,
                                              full_ob_batch):
                    gen_rew += 1 - np.exp(
                        -reward_giver.get_reward(obs, acs)[0][0])
                mean_gen_rew = gen_rew / len(ob_batch)
                if rank == 0:
                    logger.log("Generator step " + str(curr_step) +
                               ": Dicriminator reward of expert traj " +
                               str(mean_exp_rew) + " vs gen traj " +
                               str(mean_gen_rew))
            #########################

        if not not_update:
            g_losses = meanlosses
            for (lossname, lossval) in zip(loss_names + vf_loss_names,
                                           meanlosses):
                logger.record_tabular(lossname, lossval)
            logger.record_tabular("ev_tdlam_before",
                                  explained_variance(vpredbefore, tdlamret))

        # ------------------ Update D ------------------
        if not freeze_d:
            logger.log("Optimizing Discriminator...")
            batch_size = len(ob) // d_step
            d_losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for ob_batch, ac_batch, full_ob_batch in dataset.iterbatches(
                (ob, ac, full_ob),
                    include_final_partial_batch=False,
                    batch_size=batch_size):
                ob_expert, ac_expert = expert_dataset.get_next_batch(
                    len(ob_batch))
                #########################
                # For evaluation during training.
                # Can be commented out for faster training...
                exp_rew = 0
                for obs, acs in zip(ob_expert, ac_expert):
                    exp_rew += 1 - np.exp(
                        -reward_giver.get_reward(obs, acs)[0][0])
                mean_exp_rew = exp_rew / len(ob_expert)

                gen_rew = 0

                for obs, acs in zip(ob_batch, ac_batch):
                    gen_rew += 1 - np.exp(
                        -reward_giver.get_reward(obs, acs)[0][0])

                mean_gen_rew = gen_rew / len(ob_batch)
                if rank == 0:
                    logger.log("Dicriminator reward of expert traj " +
                               str(mean_exp_rew) + " vs gen traj " +
                               str(mean_gen_rew))
                #########################
                # update running mean/std for reward_giver
                if hasattr(reward_giver, "obs_rms"):
                    reward_giver.obs_rms.update(
                        np.concatenate((ob_batch, ob_expert), 0))
                loss_input = (ob_batch, ac_batch, ob_expert, ac_expert)
                *newlosses, g = reward_giver.lossandgrad(*loss_input)
                d_adam.update(allmean(g), d_stepsize)
                d_losses.append(newlosses)
            logger.log(fmt_row(13, reward_giver.loss_name))
            logger.log(fmt_row(13, np.mean(d_losses, axis=0)))

        # Gather episode statistics from all workers.
        lrlocal = (seg["ep_lens"], seg["ep_rets"], seg["ep_true_rets"],
                   seg["ep_success"], seg["ep_semi_loss"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews, true_rets, success, semi_losses = map(
            flatten_lists, zip(*listoflrpairs))

        # success
        success = [
            float(elem) for elem in success
            if isinstance(elem, (int, float, bool))
        ]  # remove potential None types
        if not success:
            success = [-1]  # set success to -1 if env has no success flag
        success_buffer.extend(success)

        true_rewbuffer.extend(true_rets)
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        if semi_loss and semi_dataset is not None:
            semi_losses = [elem * l2_w for elem in semi_losses]
            total_rewards = rews
            total_rewards = [
                re_elem - l2_elem
                for re_elem, l2_elem in zip(total_rewards, semi_losses)
            ]
            l2_rewbuffer.extend(semi_losses)
            total_rewbuffer.extend(total_rewards)

        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpTrueRewMean", np.mean(true_rewbuffer))
        logger.record_tabular("EpSuccess", np.mean(success_buffer))

        if semi_loss and semi_dataset is not None:
            logger.record_tabular("EpSemiLoss", np.mean(l2_rewbuffer))
            logger.record_tabular("EpTotalReward", np.mean(total_rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        logger.record_tabular("ItersSoFar", iters_so_far)

        if rank == 0:
            logger.dump_tabular()
            if not not_update:
                g_loss_stats.add_all_summary(writer, g_losses, iters_so_far)
            if not freeze_d:
                d_loss_stats.add_all_summary(writer, np.mean(d_losses, axis=0),
                                             iters_so_far)

            # default buffers
            ep_buffers = [
                np.mean(true_rewbuffer),
                np.mean(rewbuffer),
                np.mean(lenbuffer),
                np.mean(success_buffer)
            ]

            if semi_loss and semi_dataset is not None:
                ep_buffers.append(np.mean(l2_rewbuffer))
                ep_buffers.append(np.mean(total_rewbuffer))

            ep_stats.add_all_summary(writer, ep_buffers, iters_so_far)

        # After the first discriminator-only iteration, let G start updating.
        if not_update and not freeze_g:
            not_update -= 1
예제 #11
0
def learn(
        env,
        policy_func,
        *,
        timesteps_per_batch,  # timesteps per actor per update
        clip_param,
        entcoeff,  # clipping parameter epsilon, entropy coeff
        optim_epochs,
        optim_stepsize,
        optim_batchsize,  # optimization hypers
        gamma,
        lam,  # advantage estimation
        max_timesteps=0,
        max_episodes=0,
        max_iters=0,
        max_seconds=0,  # time constraint
        callback=None,  # you can do anything in the callback, since it takes locals(), globals()
        adam_epsilon=1e-5,
        schedule='constant',  # annealing for stepsize parameters (epsilon and adam)
        init_policy_params=None,
        policy_scope='pi'):
    """Train a policy with PPO's clipped surrogate objective and MPI Adam.

    Args:
        env: environment with ``observation_space`` and ``action_space``.
            The last ``action_space.shape[0]`` flat parameters are logged as
            ``PolVariance`` (assumes a Box action space whose log-std
            parameters come last — TODO confirm).
        policy_func: ``policy_func(scope, ob_space, ac_space)`` builds a
            policy graph.
        timesteps_per_batch: steps per actor per update, passed to
            ``traj_segment_generator``.
        clip_param: clipping epsilon. The effective epsilon is
            ``clip_param * lrmult``, so it anneals with ``schedule``.
        entcoeff: entropy penalty coefficient.
        optim_epochs: passes over each segment.
        optim_stepsize: base Adam step size (scaled by ``lrmult``).
        optim_batchsize: minibatch size. A falsy value means the full segment.
        gamma: discount factor, passed to ``add_vtarg_and_adv``.
        lam: GAE lambda, passed to ``add_vtarg_and_adv``.
        max_timesteps: stopping criterion; exactly one of this,
            ``max_episodes``, ``max_iters`` and ``max_seconds`` must be
            positive. ``schedule='linear'`` divides by ``max_timesteps``, so
            it must be positive in that case.
        max_episodes: stopping criterion (see ``max_timesteps``).
        max_iters: stopping criterion (see ``max_timesteps``).
        max_seconds: stopping criterion (see ``max_timesteps``).
        callback: called as ``callback(locals(), globals())`` each iteration.
        adam_epsilon: Adam epsilon.
        schedule: ``'constant'`` or ``'linear'``. Any other value raises
            NotImplementedError.
        init_policy_params: optional mapping from variable name to value. The
            scope prefix of its first key is replaced by the current scope,
            and matching values are assigned to both ``pi`` and ``oldpi``.
        policy_scope: variable scope of the policy (``"old" + scope`` is
            used for the old policy).

    Returns:
        Tuple ``(pi, mean_reward)``. ``mean_reward`` is the mean of the last
        (up to) 100 episode returns.
    """
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space

    pi = policy_func(policy_scope, ob_space,
                     ac_space)  # Construct network for new policy
    oldpi = policy_func("old" + policy_scope, ob_space,
                        ac_space)  # Network for old policy

    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return
    clip_tf = tf.placeholder(dtype=tf.float32)

    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)

    meanent = tf.reduce_mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    # Anneal the clipping range symmetrically. Previously only the upper bound
    # was scaled by lrmult, which skewed the trust region under a linear schedule.
    clip_annealed = clip_tf * lrmult
    surr2 = tf.clip_by_value(ratio, 1.0 - clip_annealed,
                             1.0 + clip_annealed) * atarg
    pol_surr = -tf.reduce_mean(tf.minimum(
        surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)

    if hasattr(pi, 'additional_loss'):
        pol_surr += pi.additional_loss

    vf_loss = tf.reduce_mean(tf.square(pi.vpred - ret))

    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    pol_var_list = [
        v for v in pi.get_trainable_variables()
        if 'placehold' not in v.name and 'offset' not in v.name and 'secondary'
        not in v.name and 'vf' not in v.name and 'pol' in v.name
    ]
    pol_var_size = np.sum([np.prod(v.shape) for v in pol_var_list])

    get_pol_flat = U.GetFlat(pol_var_list)
    set_pol_from_flat = U.SetFromFlat(pol_var_list)

    total_loss = pol_surr + pol_entpen + vf_loss

    lossandgrad = U.function([ob, ac, atarg, ret, lrmult, clip_tf],
                             losses + [U.flatgrad(total_loss, var_list)])
    pol_lossandgrad = U.function([ob, ac, atarg, ret, lrmult, clip_tf],
                                 losses +
                                 [U.flatgrad(total_loss, pol_var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult, clip_tf], losses)
    # The helpers below are not used in this loop. They are kept because
    # callbacks receive locals() and may rely on them.
    compute_losses_cpo = U.function(
        [ob, ac, atarg, ret, lrmult, clip_tf],
        [tf.reduce_mean(surr1), pol_entpen, vf_loss, meankl, meanent])
    compute_ratios = U.function([ob, ac], ratio)
    compute_kls = U.function([ob, ac], kloldnew)

    compute_rollout_old_prob = U.function([ob, ac],
                                          tf.reduce_mean(oldpi.pd.logp(ac)))
    compute_rollout_new_prob = U.function([ob, ac],
                                          tf.reduce_mean(pi.pd.logp(ac)))
    compute_rollout_new_prob_min = U.function([ob, ac],
                                              tf.reduce_min(pi.pd.logp(ac)))

    update_ops = {}
    update_placeholders = {}
    for v in pi.get_trainable_variables():
        update_placeholders[v.name] = tf.placeholder(v.dtype,
                                                     shape=v.get_shape())
        update_ops[v.name] = v.assign(update_placeholders[v.name])

    # compute fisher information matrix
    dims = [int(np.prod(p.shape)) for p in pol_var_list]
    logprob_grad = U.flatgrad(tf.reduce_mean(pi.pd.logp(ac)), pol_var_list)
    compute_logprob_grad = U.function([ob, ac], logprob_grad)

    U.initialize()

    adam.sync()

    if init_policy_params is not None:
        # Map variable names from the stored scope to the current one.
        cur_scope = pi.get_variables()[0].name[0:pi.get_variables()[0].name.
                                               find('/')]
        orig_scope = list(init_policy_params.keys()
                          )[0][0:list(init_policy_params.keys())[0].find('/')]
        print(cur_scope, orig_scope)
        for i in range(len(pi.get_variables())):
            if pi.get_variables()[i].name.replace(cur_scope, orig_scope,
                                                  1) in init_policy_params:
                assign_op = pi.get_variables()[i].assign(
                    init_policy_params[pi.get_variables()[i].name.replace(
                        cur_scope, orig_scope, 1)])
                tf.get_default_session().run(assign_op)
                assign_op = oldpi.get_variables()[i].assign(
                    init_policy_params[pi.get_variables()[i].name.replace(
                        cur_scope, orig_scope, 1)])
                tf.get_default_session().run(assign_op)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_batch,
                                     stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards

    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"

    prev_params = {}
    for v in var_list:
        if 'pol' in v.name:
            prev_params[v.name] = v.eval()

    optim_seg = None

    grad_scale = 1.0

    while True:
        if MPI.COMM_WORLD.Get_rank() == 0:
            print('begin')
            memory()
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % iters_so_far)

        seg = seg_gen.__next__()

        add_vtarg_and_adv(seg, gamma, lam)

        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before udpate
        unstandardized_adv = np.copy(atarg)
        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate

        args = seg["ob"], seg["ac"], atarg
        fvpargs = [arr for arr in args]

        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        assign_old_eq_new()  # set old parameter values to new parameter values
        logger.log("Optimizing...")
        logger.log(fmt_row(13, loss_names))

        cur_clip_val = clip_param

        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                    shuffle=not pi.recurrent)

        for epoch in range(optim_epochs):
            losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"],
                                            batch["atarg"], batch["vtarg"],
                                            cur_lrmult, cur_clip_val)
                adam.update(g * grad_scale, optim_stepsize * cur_lrmult)
                losses.append(newlosses)

            logger.log(fmt_row(13, np.mean(losses, axis=0)))

        logger.log("Evaluating losses...")
        losses = []
        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                    shuffle=not pi.recurrent)
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob"], batch["ac"],
                                       batch["atarg"], batch["vtarg"],
                                       cur_lrmult, clip_param)
            losses.append(newlosses)
        meanlosses, _, _ = mpi_moments(losses, axis=0)

        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.log(fmt_row(13, meanlosses))
            for (lossval, name) in zipsame(meanlosses, loss_names):
                logger.record_tabular("loss_" + name, lossval)

        # Gather episode statistics from all workers.
        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.record_tabular("EpLenMean", np.mean(lenbuffer))
            logger.record_tabular("EpRewMean", np.mean(rewbuffer))
            logger.record_tabular("EpThisIter", len(lens))
            logger.record_tabular(
                "PolVariance",
                repr(adam.getflat()[-env.action_space.shape[0]:]))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.record_tabular("EpisodesSoFar", episodes_so_far)
            logger.record_tabular("TimestepsSoFar", timesteps_so_far)
            logger.record_tabular("TimeElapsed", time.time() - tstart)
            logger.dump_tabular()
            print('end')
            memory()

    return pi, np.mean(rewbuffer)
예제 #12
0
def _split_episodes(values, lens):
    """Yield consecutive chunks of ``values`` whose sizes are given by ``lens``.

    ``values`` holds several episodes laid out back to back (e.g. a flattened
    list of per-timestep rewards); chunk ``i`` is ``values`` restricted to
    episode ``i``.
    """
    start = 0
    for ep_len in lens:
        end = start + ep_len
        yield values[start:end]
        start = end


def learn(env, policy_fn, *,
        batch_size, # what to train on
        task_horizon,
        max_kl, cg_iters,
        gamma, lam, # advantage estimation
        entcoeff=0.0,
        cg_damping=1e-2,
        vf_stepsize=3e-4,
        vf_iters =3,
        max_timesteps=0, max_episodes=0, max_iters=0,  # time constraint
        callback=None,
        weights_dir='.',
        per_decision = True,
        normalize = False,
        truncate_at = np.inf
        ):
    """Run TRPO on ``env`` and dump per-iteration diagnostics to ``weights_dir``.

    Each iteration samples ``batch_size * task_horizon`` timesteps, takes one
    KL-constrained natural-gradient step on the policy (conjugate gradient
    followed by a backtracking line search), fits the value function with
    ``vf_iters`` epochs of Adam, and saves with ``np.save``:

    * ``weights_<iter>.npy`` -- ``pi.eval_param()`` taken before the update,
    * ``iws_<iter>.npy``     -- ``pi.eval_simple_iw(...)`` with ``oldpi`` as
      the behavioral policy,
    * ``rets_<iter>.npy``    -- discounted return of every episode reported
      by the gathered rollouts.

    Training stops when the one positive limit among ``max_timesteps``,
    ``max_episodes`` and ``max_iters`` is reached (exactly one must be > 0).

    :param env: environment passed to ``traj_segment_generator``
    :param policy_fn: ``policy_fn(name, ob_space, ac_space)`` building a policy
    :param batch_size: multiplied by ``task_horizon`` to get timesteps per
        batch (presumably the number of episodes per batch -- confirm)
    :param task_horizon: horizon passed to ``traj_segment_generator``
    :param max_kl: KL threshold for the policy step
    :param cg_iters: conjugate-gradient iterations
    :param gamma: discount factor
    :param lam: GAE lambda
    :param entcoeff: entropy bonus weight
    :param cg_damping: damping added to the Fisher-vector product
    :param vf_stepsize: Adam step size of the value function
    :param vf_iters: value-function epochs per iteration
    :param callback: called as ``callback(locals(), globals())`` each iteration
    :param weights_dir: directory receiving the ``.npy`` diagnostics
    :param per_decision, normalize, truncate_at: forwarded to
        ``pi.eval_J`` / ``pi.eval_var_J``
    """
    nworkers = MPI.COMM_WORLD.Get_size()
    rank = MPI.COMM_WORLD.Get_rank()
    np.set_printoptions(precision=3)
    timesteps_per_batch = batch_size * task_horizon
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_fn("pi", ob_space, ac_space)
    oldpi = policy_fn("oldpi", ob_space, ac_space)
    atarg = tf.placeholder(dtype=tf.float32, shape=[None]) # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None]) # Empirical return

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    entbonus = entcoeff * meanent

    vferr = tf.reduce_mean(tf.square(pi.vpred - ret))

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac)) # advantage * pnew / pold
    surrgain = tf.reduce_mean(ratio * atarg)

    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

    dist = meankl

    all_var_list = pi.get_trainable_variables()
    var_list = [v for v in all_var_list if v.name.split("/")[1].startswith("pol")]
    vf_var_list = [v for v in all_var_list if v.name.split("/")[1].startswith("vf")]
    vfadam = MpiAdam(vf_var_list)

    get_flat = U.GetFlat(var_list)
    set_from_flat = U.SetFromFlat(var_list)
    klgrads = tf.gradients(dist, var_list)
    flat_tangent = tf.placeholder(dtype=tf.float32, shape=[None], name="flat_tan")
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start+sz], shape))
        start += sz
    gvp = tf.add_n([tf.reduce_sum(g*tangent) for (g, tangent) in zipsame(klgrads, tangents)]) #pylint: disable=E1111
    fvp = U.flatgrad(gvp, var_list)

    assign_old_eq_new = U.function([],[], updates=[tf.assign(oldv, newv)
        for (oldv, newv) in zipsame(oldpi.get_variables(), pi.get_variables())])
    compute_losses = U.function([ob, ac, atarg], losses)
    compute_lossandgrad = U.function([ob, ac, atarg], losses + [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    compute_vflossandgrad = U.function([ob, ret], U.flatgrad(vferr, vf_var_list))

    @contextmanager
    def timed(msg):
        # Only rank 0 prints timing information.
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(colorize("done in %.3f seconds"%(time.time() - tstart), color='magenta'))
        else:
            yield

    def allmean(x):
        # Average a numpy array over all MPI workers.
        assert isinstance(x, np.ndarray)
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
        return out

    U.initialize()
    th_init = get_flat()
    MPI.COMM_WORLD.Bcast(th_init, root=0)
    set_from_flat(th_init)
    vfadam.sync()
    print("Init param sum", th_init.sum(), flush=True)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi, env, task_horizon, timesteps_per_batch, stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40) # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40) # rolling buffer for episode rewards

    assert sum([max_iters>0, max_timesteps>0, max_episodes>0])==1

    
    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        logger.log("********** Iteration %i ************"%iters_so_far)

        with timed("sampling"):
            seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        
        #Params
        #"""
        params = pi.eval_param()
        #print(params)
        np.save(weights_dir+'/weights_'+str(iters_so_far), params)
        #"""

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]
        vpredbefore = seg["vpred"] # predicted value function before udpate
        atarg = (atarg - atarg.mean()) / atarg.std() # standardized advantage function estimate

        if hasattr(pi, "ret_rms"): pi.ret_rms.update(tdlamret)
        if hasattr(pi, "ob_rms"): pi.ob_rms.update(ob) # update running mean/std for policy

        args = seg["ob"], seg["ac"], atarg
        # The Fisher-vector product is estimated on every 5th sample to save time.
        fvpargs = [arr[::5] for arr in args]
        def fisher_vector_product(p):
            return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p

        assign_old_eq_new() # set old parameter values to new parameter values
        with timed("computegrad"):
            *lossbefore, g = compute_lossandgrad(*args)
        lossbefore = allmean(np.array(lossbefore))
        g = allmean(g)
        # Values logged when no step is applied: the policy parameters are
        # unchanged, so the pre-update losses are the current ones.
        meanlosses = lossbefore
        stepsize = 0.0
        if np.allclose(g, 0):
            logger.log("Got zero gradient. not updating")
        else:
            with timed("cg"):
                stepdir = cg(fisher_vector_product, g, cg_iters=cg_iters, verbose=rank==0)
            assert np.isfinite(stepdir).all()
            shs = .5*stepdir.dot(fisher_vector_product(stepdir))
            # abs(shs) avoids a NaN from a slightly negative curvature estimate.
            lm = np.sqrt(abs(shs) / max_kl)
            print('DOT: %s' % np.dot(stepdir, g))
            # logger.log("lagrange multiplier:", lm, "gnorm:", np.linalg.norm(g))
            fullstep = stepdir / lm
            expectedimprove = g.dot(fullstep)
            surrbefore = lossbefore[0]
            stepsize = 1.0
            thbefore = get_flat()
            # Backtracking line search: halve the step until the KL constraint
            # holds and the surrogate improves (at most 10 tries).
            for _ in range(10):
                thnew = thbefore + fullstep * stepsize
                set_from_flat(thnew)
                meanlosses = surr, kl, *_ = allmean(np.array(compute_losses(*args)))
                improve = surr - surrbefore
                logger.log("Expected: %.3f Actual: %.3f"%(expectedimprove, improve))
                if not np.isfinite(meanlosses).all():
                    logger.log("Got non-finite value of losses -- bad!")
                elif kl > max_kl * 1.5:
                    logger.log("violated KL constraint. shrinking step.")
                elif improve < 0:
                    logger.log("surrogate didn't improve. shrinking step.")
                else:
                    logger.log("Stepsize OK!")
                    break
                stepsize *= .5
            else:
                logger.log("couldn't compute a good step")
                set_from_flat(thbefore)
                # Parameters were restored: report the losses/step actually in effect.
                meanlosses = lossbefore
                stepsize = 0.0
            if nworkers > 1 and iters_so_far % 20 == 0:
                paramsums = MPI.COMM_WORLD.allgather((thnew.sum(), vfadam.getflat().sum())) # list of tuples
                assert all(np.allclose(ps, paramsums[0]) for ps in paramsums[1:])

        for (lossname, lossval) in zip(loss_names, meanlosses):
            logger.record_tabular(lossname, lossval)

        with timed("vf"):

            for _ in range(vf_iters):
                for (mbob, mbret) in dataset.iterbatches((seg["ob"], seg["tdlamret"]),
                include_final_partial_batch=False, batch_size=64):
                    g = allmean(compute_vflossandgrad(mbob, mbret))
                    vfadam.update(g, vf_stepsize)

        logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))

        lrlocal = (seg["ep_lens"], seg["ep_rets"], seg["ob"],
                   seg["ac"],seg["rew"]) # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal) # list of tuples
        lens, rews, states, actions, rewards = map(flatten_lists, zip(*listoflrpairs))
        
        # Discounted return of each episode: sum_t gamma**t * r_t.
        disc_rews = []
        for ep_len, ep_rewards in zip(lens, _split_episodes(rewards, lens)):
            disc = np.power(gamma, np.arange(ep_len))
            disc_rews.append(np.sum(np.array(ep_rewards) * disc))
            
        #Save importance weights
        simple_iw = pi.eval_simple_iw(states, 
                               actions,
                               rewards,
                               lens,
                               gamma=gamma,
                               behavioral=oldpi)
        np.save(weights_dir+'/iws_'+str(iters_so_far), simple_iw)
        #print(len(simple_iw), simple_iw)
        
        #Save returns
        ep_rets = np.array(disc_rews)
        np.save(weights_dir+'/rets_'+str(iters_so_far), ep_rets)
        #print(len(ep_rets), ep_rets)
        

        #lenbuffer.extend(lens)
        #rewbuffer.extend(rews)

        #Renyi
        """
        renyi_4 = np.mean(pi.eval_renyi(states, oldpi, 4))
        #print('Renyi:', renyi)
        #"""
        
        #Importance weights stats
        """
        avg_iw, var_iw, max_iw, ess = pi.eval_iw_stats(states, 
                               actions,
                               rewards,
                               lens,
                               gamma=gamma,
                               behavioral=oldpi,
                               per_decision=per_decision,
                               normalize=normalize,
                               truncate_at=truncate_at)
        #"""
        
        #Returns stats
        """
        avg_ret, var_ret, max_ret = pi.eval_ret_stats(states, 
                               actions,
                               rewards,
                               lens,
                               gamma=gamma,
                               behavioral=oldpi,
                               per_decision=per_decision,
                               normalize=normalize,
                               truncate_at=truncate_at)
        #"""

        #Performance
        #"""
        bound_delta = .2
        batch_size = len(lens)
        J = pi.eval_J(states,
                      actions,
                      rewards,
                      lens,
                      gamma=gamma,
                      behavioral=oldpi,
                      per_decision=per_decision,
                      normalize=normalize,
                      truncate_at=truncate_at)
        
        var_J = pi.eval_var_J(states,
                      actions,
                      rewards,
                      lens,
                      gamma=gamma,
                      behavioral=oldpi,
                      per_decision=per_decision,
                      normalize=normalize,
                      truncate_at=truncate_at)
        
        """
        bound = pi.eval_bound(states,
                      actions,
                      rewards,
                      lens,
                      gamma=gamma,
                      behavioral=oldpi,
                      per_decision=per_decision,
                      normalize=normalize,
                      truncate_at=truncate_at,
                      delta=bound_delta,
                      use_ess=True)
        #"""
        
        #Sample Renyi: sum the per-step values within each episode, then
        #average exp(.) over episodes.
        d2s = pi.eval_renyi(states, oldpi, 2)
        d2s_by_episode = [np.sum(ep_d2s) for ep_d2s in _split_episodes(d2s, lens)]
        sample_d2 = np.mean(np.exp(d2s_by_episode))
        
        """
        grad_bound = pi.eval_grad_bound(states,
                      actions,
                      rewards,
                      lens,
                      gamma=gamma,
                      behavioral=oldpi,
                      per_decision=per_decision,
                      normalize=normalize,
                      truncate_at=truncate_at,
                      delta=bound_delta,
                      use_ess=True)
        print(grad_bound)
        #print('Target performance', J, '+-', np.sqrt(var_J/len(lens)))    
        #"""
        
        #Gradients
        """
        grad_J = pi.eval_grad_J(states,
                                       actions,
                                       rewards,
                                       lens,
                                       behavioral=oldpi,
                                       per_decision=True)
        grad_var_J = pi.eval_grad_var_J(states,
                                       actions,
                                       rewards,
                                       lens,
                                       behavioral=oldpi,
                                       per_decision=True)
        print('Target performance grads', grad_J, grad_var_J)    
        #"""
    
        #Student-t bound
        """
        bound = pi.eval_bound(states,
                                 actions,
                                 rewards,
                                 lens,
                                 behavioral=oldpi,
                                 per_decision=True)
        #print('Bound comp. time', time.time() - checkpoint)
        print("StudentTBound", bound)
        #"""
    
        
        #Student-t bound grad
        """
        bound_grad = pi.eval_bound_grad(states,
                                 actions,
                                 rewards,
                                 lens,
                                 behavioral=oldpi,
                                 per_decision=True)
        print("StudentTBound grad", bound_grad)
        #"""
    
        #Fisher
        """
        checkpoint = time.time()
        fisher = oldpi.eval_fisher(states, actions, lens, behavioral=None)
        #print(fisher)
        assert np.array_equal(fisher, fisher.T)
        print('Fisher comp. time', time.time() - checkpoint)
        checkpoint = time.time()
        natural = np.linalg.solve(fisher + 1e-12*np.eye(fisher.shape[0]), grad_J)
        print(natural)
        #print('Fisher vector product time:', time.time() - checkpoint)
        #"""
        
        #Logging
        logger.record_tabular("Step_size", stepsize)
        #logger.record_tabular("Our_bound", bound)
        #logger.record_tabular("Reny_4", renyi_4)
        logger.record_tabular("SampleRenyi2", sample_d2)
        #logger.record_tabular("Max_iw", max_iw)
        #logger.record_tabular("Ess", ess)
        #logger.record_tabular("Avg_iw", avg_iw)
        #logger.record_tabular("Var_iw", var_iw)
        #logger.record_tabular("Max_ret", max_ret)
        #logger.record_tabular("Avg_ret", avg_ret)
        #logger.record_tabular("Var_ret", var_ret)
        logger.record_tabular("EpLenMean", np.mean(lens))
        logger.record_tabular("DiscEpRewMean", np.mean(disc_rews))
        logger.record_tabular("EpRewMean", np.mean(rews))
        logger.record_tabular("EpThisIter", len(lens))
        logger.record_tabular("J_hat", J)
        logger.record_tabular("Var_J", var_J)
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)

        if rank==0:
            logger.dump_tabular()
예제 #13
0
def learn(
        env,
        policy_func,
        *,
        timesteps_per_batch,
        max_kl,
        cg_iters,
        gamma,
        lam,
        entcoeff=0.0,
        cg_damping=1e-2,
        vf_stepsize=3e-4,
        vf_iters=3,
        max_timesteps=0,
        max_episodes=0,
        max_iters=0,
        callback=None,
        # GAIL Params
        pretrained_weight=None,
        reward_giver=None,
        expert_dataset=None,
        rank=0,
        save_per_iter=1,
        ckpt_dir="/tmp/gail/ckpt/",
        g_step=1,
        d_step=1,
        task_name="task_name",
        d_stepsize=3e-4,
        using_gail=True):
    """
    Learn a policy with TRPO, optionally as the generator of GAIL.

    Each iteration runs ``g_step`` TRPO policy/value updates on freshly sampled
    trajectories. When ``using_gail`` is set, it then runs one pass of
    discriminator (``reward_giver``) updates against ``expert_dataset``.
    Training stops when the one positive limit among ``max_timesteps``,
    ``max_episodes`` and ``max_iters`` is reached (exactly one must be > 0).

    :param env: (Gym Environment) the environment
    :param policy_func: (function (str, Gym Space, Gym Space, bool): MLPPolicy) policy generator
    :param timesteps_per_batch: (int) the number of timesteps to run per batch (horizon)
    :param max_kl: (float) the Kullback-Leibler divergence threshold of a policy step
    :param cg_iters: (int) the number of iterations for the conjugate gradient calculation
    :param gamma: (float) the discount value
    :param lam: (float) GAE factor
    :param entcoeff: (float) the weight for the entropy loss
    :param cg_damping: (float) the damping added to the Fisher-vector product
    :param vf_stepsize: (float) the value function stepsize
    :param vf_iters: (int) the value function's number iterations for learning
    :param max_timesteps: (int) the maximum number of timesteps before halting
    :param max_episodes: (int) the maximum number of episodes before halting
    :param max_iters: (int) the maximum number of training iterations before halting
    :param callback: (function (dict, dict)) the call back function, takes the local and global attribute dictionary
    :param pretrained_weight: (str) the save location for the pretrained weights
    :param reward_giver: (TransitionClassifier) the reward predictor from observation and action
    :param expert_dataset: (MujocoDset) the dataset manager
    :param rank: (int) ignored; the rank is always read from ``MPI.COMM_WORLD``
    :param save_per_iter: (int) the number of iterations between checkpoints
        (checkpoints are written by rank 0 only, and only when ``using_gail``)
    :param ckpt_dir: (str) the location for saving checkpoints (None disables saving)
    :param g_step: (int) number of steps to train policy in each epoch
    :param d_step: (int) number of steps to train discriminator in each epoch
    :param task_name: (str) the name of the task, used as the checkpoint file name
    :param d_stepsize: (float) the reward giver stepsize
    :param using_gail: (bool) using the GAIL model
    """

    nworkers = MPI.COMM_WORLD.Get_size()
    # The ``rank`` argument is overridden by the actual MPI rank.
    rank = MPI.COMM_WORLD.Get_rank()
    np.set_printoptions(precision=3)
    sess = tf_util.single_threaded_session()
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    policy = policy_func("pi", ob_space, ac_space, sess=sess)
    old_policy = policy_func("oldpi",
                             ob_space,
                             ac_space,
                             sess=sess,
                             placeholders={
                                 "obs": policy.obs_ph,
                                 "stochastic": policy.stochastic_ph
                             })

    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    # NOTE: ``observation``, ``action`` and ``atarg`` are graph placeholders
    # here; inside the training loop the same names are rebound to numpy
    # batches. The tf_util functions compiled below keep the placeholders.
    observation = policy.obs_ph
    action = policy.pdtype.sample_placeholder([None])

    kloldnew = old_policy.proba_distribution.kl(policy.proba_distribution)
    ent = policy.proba_distribution.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    entbonus = entcoeff * meanent

    vferr = tf.reduce_mean(tf.square(policy.vpred - ret))

    # advantage * pnew / pold
    ratio = tf.exp(
        policy.proba_distribution.logp(action) -
        old_policy.proba_distribution.logp(action))
    surrgain = tf.reduce_mean(ratio * atarg)

    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

    dist = meankl

    all_var_list = policy.get_trainable_variables()
    if using_gail:
        var_list = [
            v for v in all_var_list
            if v.name.startswith("pi/pol") or v.name.startswith("pi/logstd")
        ]
        vf_var_list = [v for v in all_var_list if v.name.startswith("pi/vff")]
        assert len(var_list) == len(vf_var_list) + 1
        # Bind to the same explicit session as vfadam: ``sess`` is never
        # installed as the default session.
        d_adam = MpiAdam(reward_giver.get_trainable_variables(), sess=sess)
    else:
        var_list = [
            v for v in all_var_list if v.name.split("/")[1].startswith("pol")
        ]
        vf_var_list = [
            v for v in all_var_list if v.name.split("/")[1].startswith("vf")
        ]

    vfadam = MpiAdam(vf_var_list, sess=sess)
    get_flat = tf_util.GetFlat(var_list, sess=sess)
    set_from_flat = tf_util.SetFromFlat(var_list, sess=sess)

    # Fisher-vector product via the gradient of (grad KL . tangent).
    klgrads = tf.gradients(dist, var_list)
    flat_tangent = tf.placeholder(dtype=tf.float32,
                                  shape=[None],
                                  name="flat_tan")
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        var_size = tf_util.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start + var_size],
                                   shape))
        start += var_size
    gvp = tf.add_n([
        tf.reduce_sum(grad * tangent)
        for (grad, tangent) in zipsame(klgrads, tangents)
    ])  # pylint: disable=E1111
    fvp = tf_util.flatgrad(gvp, var_list)

    assign_old_eq_new = tf_util.function(
        [], [],
        updates=[
            tf.assign(oldv, newv) for (oldv, newv) in zipsame(
                old_policy.get_variables(), policy.get_variables())
        ])
    compute_losses = tf_util.function([observation, action, atarg], losses)
    compute_lossandgrad = tf_util.function(
        [observation, action, atarg],
        losses + [tf_util.flatgrad(optimgain, var_list)])
    compute_fvp = tf_util.function([flat_tangent, observation, action, atarg],
                                   fvp)
    compute_vflossandgrad = tf_util.function([observation, ret],
                                             tf_util.flatgrad(
                                                 vferr, vf_var_list))

    @contextmanager
    def timed(msg):
        # Only rank 0 prints timing information.
        if rank == 0:
            print(colorize(msg, color='magenta'))
            start_time = time.time()
            yield
            print(
                colorize("done in %.3f seconds" % (time.time() - start_time),
                         color='magenta'))
        else:
            yield

    def allmean(arr):
        # Average a numpy array over all MPI workers.
        assert isinstance(arr, np.ndarray)
        out = np.empty_like(arr)
        MPI.COMM_WORLD.Allreduce(arr, out, op=MPI.SUM)
        out /= nworkers
        return out

    tf_util.initialize(sess=sess)

    th_init = get_flat()
    MPI.COMM_WORLD.Bcast(th_init, root=0)
    set_from_flat(th_init)

    if using_gail:
        d_adam.sync()
    vfadam.sync()

    if rank == 0:
        print("Init param sum", th_init.sum(), flush=True)

    # Prepare for rollouts
    # ----------------------------------------
    if using_gail:
        seg_gen = traj_segment_generator(policy,
                                         env,
                                         timesteps_per_batch,
                                         stochastic=True,
                                         reward_giver=reward_giver,
                                         gail=True)
    else:
        seg_gen = traj_segment_generator(policy,
                                         env,
                                         timesteps_per_batch,
                                         stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    t_start = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards

    assert sum([max_iters > 0, max_timesteps > 0, max_episodes > 0]) == 1

    if using_gail:
        true_rewbuffer = deque(maxlen=40)
        #  Stats not used for now
        #  g_loss_stats = Stats(loss_names)
        #  d_loss_stats = Stats(reward_giver.loss_name)
        #  ep_stats = Stats(["True_rewards", "Rewards", "Episode_length"])

        # if provide pretrained weight
        if pretrained_weight is not None:
            tf_util.load_state(pretrained_weight,
                               var_list=policy.get_variables())

    def fisher_vector_product(vec):
        # ``fvpargs`` is (re)assigned in the policy-update loop before any call.
        return allmean(compute_fvp(vec, *fvpargs,
                                   sess=sess)) + cg_damping * vec

    # Created once on the first checkpoint: every tf.train.Saver() adds new
    # save/restore ops to the graph, so building one per save grows the graph.
    saver = None

    while True:
        if callback:
            callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break

        # Save model
        if using_gail and rank == 0 and iters_so_far % save_per_iter == 0 and ckpt_dir is not None:
            fname = os.path.join(ckpt_dir, task_name)
            os.makedirs(os.path.dirname(fname), exist_ok=True)
            if saver is None:
                saver = tf.train.Saver()
            saver.save(sess, fname)

        logger.log("********** Iteration %i ************" % iters_so_far)

        # ------------------ Update G ------------------
        logger.log("Optimizing Policy...")
        # g_step = 1 when not using GAIL
        for _ in range(g_step):
            with timed("sampling"):
                seg = seg_gen.__next__()
            add_vtarg_and_adv(seg, gamma, lam)
            # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
            observation, action, atarg, tdlamret = seg["ob"], seg["ac"], seg[
                "adv"], seg["tdlamret"]
            vpredbefore = seg[
                "vpred"]  # predicted value function before udpate
            atarg = (atarg - atarg.mean()) / atarg.std(
            )  # standardized advantage function estimate

            if hasattr(policy, "ret_rms"):
                policy.ret_rms.update(tdlamret)
            if hasattr(policy, "ob_rms"):
                policy.ob_rms.update(
                    observation)  # update running mean/std for policy

            args = seg["ob"], seg["ac"], atarg
            # The Fisher-vector product is estimated on every 5th sample.
            fvpargs = [arr[::5] for arr in args]

            assign_old_eq_new(sess=sess)

            with timed("computegrad"):
                *lossbefore, grad = compute_lossandgrad(*args, sess=sess)
            lossbefore = allmean(np.array(lossbefore))
            grad = allmean(grad)
            # Logged when no step is applied: parameters are unchanged, so the
            # pre-update losses are the current ones.
            mean_losses = lossbefore
            if np.allclose(grad, 0):
                logger.log("Got zero gradient. not updating")
            else:
                with timed("cg"):
                    stepdir = conjugate_gradient(fisher_vector_product,
                                                 grad,
                                                 cg_iters=cg_iters,
                                                 verbose=rank == 0)
                assert np.isfinite(stepdir).all()
                shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
                # abs(shs) to avoid taking square root of negative values
                lagrange_multiplier = np.sqrt(abs(shs) / max_kl)
                # logger.log("lagrange multiplier:", lm, "gnorm:", np.linalg.norm(g))
                fullstep = stepdir / lagrange_multiplier
                expectedimprove = grad.dot(fullstep)
                surrbefore = lossbefore[0]
                stepsize = 1.0
                thbefore = get_flat()
                # Backtracking line search: halve the step until the KL
                # constraint holds and the surrogate improves (max 10 tries).
                for _ in range(10):
                    thnew = thbefore + fullstep * stepsize
                    set_from_flat(thnew)
                    mean_losses = surr, kl_loss, *_ = allmean(
                        np.array(compute_losses(*args, sess=sess)))
                    improve = surr - surrbefore
                    logger.log("Expected: %.3f Actual: %.3f" %
                               (expectedimprove, improve))
                    if not np.isfinite(mean_losses).all():
                        logger.log("Got non-finite value of losses -- bad!")
                    elif kl_loss > max_kl * 1.5:
                        logger.log("violated KL constraint. shrinking step.")
                    elif improve < 0:
                        logger.log("surrogate didn't improve. shrinking step.")
                    else:
                        logger.log("Stepsize OK!")
                        break
                    stepsize *= .5
                else:
                    logger.log("couldn't compute a good step")
                    set_from_flat(thbefore)
                    # Parameters were restored: report the losses in effect.
                    mean_losses = lossbefore
                if nworkers > 1 and iters_so_far % 20 == 0:
                    paramsums = MPI.COMM_WORLD.allgather(
                        (thnew.sum(),
                         vfadam.getflat().sum()))  # list of tuples
                    assert all(
                        np.allclose(ps, paramsums[0]) for ps in paramsums[1:])

            with timed("vf"):
                for _ in range(vf_iters):
                    for (mbob, mbret) in dataset.iterbatches(
                        (seg["ob"], seg["tdlamret"]),
                            include_final_partial_batch=False,
                            batch_size=128):
                        if hasattr(policy, "ob_rms"):
                            policy.ob_rms.update(
                                mbob)  # update running mean/std for policy
                        grad = allmean(
                            compute_vflossandgrad(mbob, mbret, sess=sess))
                        vfadam.update(grad, vf_stepsize)

        for (loss_name, loss_val) in zip(loss_names, mean_losses):
            logger.record_tabular(loss_name, loss_val)

        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))

        if using_gail:
            # ------------------ Update D ------------------
            logger.log("Optimizing Discriminator...")
            logger.log(fmt_row(13, reward_giver.loss_name))
            # d_step minibatches per pass; at least one sample per minibatch.
            batch_size = max(1, len(observation) // d_step)
            d_losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for ob_batch, ac_batch in dataset.iterbatches(
                (observation, action),
                    include_final_partial_batch=False,
                    batch_size=batch_size):
                # One expert minibatch of matching size per generator minibatch.
                ob_expert, ac_expert = expert_dataset.get_next_batch(
                    len(ob_batch))
                # update running mean/std for reward_giver
                if hasattr(reward_giver, "obs_rms"):
                    reward_giver.obs_rms.update(
                        np.concatenate((ob_batch, ob_expert), 0))
                *newlosses, grad = reward_giver.lossandgrad(
                    ob_batch, ac_batch, ob_expert, ac_expert)
                d_adam.update(allmean(grad), d_stepsize)
                d_losses.append(newlosses)
            logger.log(fmt_row(13, np.mean(d_losses, axis=0)))

            lrlocal = (seg["ep_lens"], seg["ep_rets"], seg["ep_true_rets"]
                       )  # local values
            listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
            lens, rews, true_rets = map(flatten_lists, zip(*listoflrpairs))
            true_rewbuffer.extend(true_rets)
        else:
            lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
            listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
            lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        if using_gail:
            logger.record_tabular("EpTrueRewMean", np.mean(true_rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - t_start)

        if rank == 0:
            logger.dump_tabular()
예제 #14
0
def learn(
        *,
        network,
        env,
        save,
        total_timesteps,
        timesteps_per_batch=1024,  # what to train on
        max_kl=0.001,
        cg_iters=10,
        gamma=0.99,
        lam=1.0,  # advantage estimation
        seed=None,
        ent_coef=0.0,
        cg_damping=1e-2,
        vf_stepsize=3e-4,
        vf_iters=3,
        max_episodes=0,
        max_iters=0,  # time constraint
        callback=None,
        load_path=None,
        **network_kwargs):
    '''
    learn a policy function with TRPO algorithm

    Parameters:
    ----------

    network                 neural network to learn. Can be either string ('mlp', 'cnn', 'lstm', 'lnlstm' for basic types)
                            or function that takes input placeholder and returns tuple (output, None) for feedforward nets
                            or (output, (state_placeholder, state_output, mask_placeholder)) for recurrent nets

    env                     environment (one of the gym environments or wrapped via baselines.common.vec_env.VecEnv-type class

    save                    currently unused; only referenced by the commented-out weight-saving code at the end

    total_timesteps         max number of timesteps; recording via VecVideoRecorder starts once fewer than
                            ~1500 timesteps of this budget remain

    timesteps_per_batch     timesteps per gradient estimation batch

    max_kl                  max KL divergence between old policy and new policy ( KL(pi_old || pi) )

    cg_iters                number of iterations of conjugate gradient algorithm

    gamma                   discount factor, passed to add_vtarg_and_adv

    lam                     advantage-estimation lambda, passed to add_vtarg_and_adv

    seed                    random seed, passed to set_global_seeds

    ent_coef                coefficient of policy entropy term in the optimization objective

    cg_damping              conjugate gradient damping

    vf_stepsize             learning rate for adam optimizer used to optimize value function loss

    vf_iters                number of iterations of value function optimization iterations per each policy optimization step

    max_episodes            max number of episodes

    max_iters               maximum number of policy optimization iterations

    callback                function to be called with (locals(), globals()) each policy optimization step

    load_path               str, path to load the model from (default: None, i.e. no model is loaded)

    **network_kwargs        keyword arguments to the policy / network builder. See baselines.common/policies.py/build_policy and arguments to a particular type of network

    Side effects (after training):
    -------
    - saves two reward plots in the current directory via plt.savefig, suffixed with the
      generation tag read from "img_rec.txt";
    - advances the generation tag stored in "img_rec.txt" to the next entry of a fixed cycle
      (raises FileNotFoundError / ValueError if the file is missing or holds an unknown tag).

    Returns:
    -------

    learnt model (the `pi` PolicyWithValue)

    '''

    if MPI is not None:
        nworkers = MPI.COMM_WORLD.Get_size()
        rank = MPI.COMM_WORLD.Get_rank()
    else:
        nworkers = 1
        rank = 0

    set_global_seeds(seed)

    np.set_printoptions(precision=3)
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space

    if isinstance(network, str):
        network, network_model = get_network_builder(network)(**network_kwargs)

    with tf.name_scope("pi"):
        pi_policy_network = network(ob_space.shape)
        pi_value_network = network(ob_space.shape)
        pi = PolicyWithValue(ac_space, pi_policy_network, pi_value_network)
    with tf.name_scope("oldpi"):
        old_pi_policy_network = network(ob_space.shape)
        old_pi_value_network = network(ob_space.shape)
        oldpi = PolicyWithValue(ac_space, old_pi_policy_network,
                                old_pi_value_network)

    pi_var_list = pi_policy_network.trainable_variables + list(
        pi.pdtype.trainable_variables)
    old_pi_var_list = old_pi_policy_network.trainable_variables + list(
        oldpi.pdtype.trainable_variables)
    vf_var_list = pi_value_network.trainable_variables + pi.value_fc.trainable_variables
    old_vf_var_list = old_pi_value_network.trainable_variables + oldpi.value_fc.trainable_variables

    if load_path is not None:
        load_path = osp.expanduser(load_path)
        ckpt = tf.train.Checkpoint(model=pi)
        manager = tf.train.CheckpointManager(ckpt, load_path, max_to_keep=None)
        ckpt.restore(manager.latest_checkpoint)

    vfadam = MpiAdam(vf_var_list)

    get_flat = U.GetFlat(pi_var_list)
    set_from_flat = U.SetFromFlat(pi_var_list)
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]
    shapes = [var.get_shape().as_list() for var in pi_var_list]

    def assign_old_eq_new():
        # Copy current policy and value parameters into the "old" networks.
        for pi_var, old_pi_var in zip(pi_var_list, old_pi_var_list):
            old_pi_var.assign(pi_var)
        for vf_var, old_vf_var in zip(vf_var_list, old_vf_var_list):
            old_vf_var.assign(vf_var)

    def surrogate_losses(ob, ac, atarg):
        # Shared loss graph; the returned order matches `loss_names`.
        old_policy_latent = oldpi.policy_network(ob)
        old_pd, _ = oldpi.pdtype.pdfromlatent(old_policy_latent)
        policy_latent = pi.policy_network(ob)
        pd, _ = pi.pdtype.pdfromlatent(policy_latent)
        kloldnew = old_pd.kl(pd)
        ent = pd.entropy()
        meankl = tf.reduce_mean(kloldnew)
        meanent = tf.reduce_mean(ent)
        entbonus = ent_coef * meanent
        ratio = tf.exp(pd.logp(ac) - old_pd.logp(ac))
        surrgain = tf.reduce_mean(ratio * atarg)
        optimgain = surrgain + entbonus
        return [optimgain, meankl, entbonus, surrgain, meanent]

    @tf.function
    def compute_lossandgrad(ob, ac, atarg):
        # Losses plus the flattened gradient of optimgain w.r.t. the policy params.
        with tf.GradientTape() as tape:
            losses = surrogate_losses(ob, ac, atarg)
        gradients = tape.gradient(losses[0], pi_var_list)
        return losses + [U.flatgrad(gradients, pi_var_list)]

    @tf.function
    def compute_losses(ob, ac, atarg):
        return surrogate_losses(ob, ac, atarg)

    #ob shape should be [batch_size, ob_dim], merged nenv
    #ret shape should be [batch_size]
    @tf.function
    def compute_vflossandgrad(ob, ret):
        # Flattened gradient of the mean-squared value error w.r.t. value params.
        with tf.GradientTape() as tape:
            pi_vf = pi.value(ob)
            vferr = tf.reduce_mean(tf.square(pi_vf - ret))
        return U.flatgrad(tape.gradient(vferr, vf_var_list), vf_var_list)

    @tf.function
    def compute_fvp(flat_tangent, ob, ac, atarg):
        # Fisher-vector product via double differentiation of the mean KL:
        # grad( grad(meankl) . tangent ).
        with tf.GradientTape() as outter_tape:
            with tf.GradientTape() as inner_tape:
                old_policy_latent = oldpi.policy_network(ob)
                old_pd, _ = oldpi.pdtype.pdfromlatent(old_policy_latent)
                policy_latent = pi.policy_network(ob)
                pd, _ = pi.pdtype.pdfromlatent(policy_latent)
                kloldnew = old_pd.kl(pd)
                meankl = tf.reduce_mean(kloldnew)
            klgrads = inner_tape.gradient(meankl, pi_var_list)
            # Split the flat tangent back into per-variable shaped pieces.
            start = 0
            tangents = []
            for shape in shapes:
                sz = U.intprod(shape)
                tangents.append(
                    tf.reshape(flat_tangent[start:start + sz], shape))
                start += sz
            gvp = tf.add_n([
                tf.reduce_sum(g * tangent)
                for (g, tangent) in zipsame(klgrads, tangents)
            ])
        hessians_products = outter_tape.gradient(gvp, pi_var_list)
        fvp = U.flatgrad(hessians_products, pi_var_list)
        return fvp

    @contextmanager
    def timed(msg):
        # Print msg and elapsed wall time on rank 0 only.
        if rank == 0:
            print(colorize(msg, color='magenta'))
            t0 = time.time()
            yield
            print(
                colorize("done in %.3f seconds" % (time.time() - t0),
                         color='magenta'))
        else:
            yield

    def allmean(x):
        # Average a numpy array across MPI workers (copy when MPI is unavailable).
        assert isinstance(x, np.ndarray)
        if MPI is not None:
            out = np.empty_like(x)
            MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
            out /= nworkers
        else:
            out = np.copy(x)

        return out

    # Make all workers start from rank 0's policy parameters.
    th_init = get_flat()
    if MPI is not None:
        MPI.COMM_WORLD.Bcast(th_init, root=0)

    set_from_flat(th_init)
    vfadam.sync()
    print("Init param sum", th_init.sum(), flush=True)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi, env, timesteps_per_batch)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards

    # Rolling buffers for the per-component episode rewards
    rewforbuffer = deque(maxlen=40)
    rewctrlbuffer = deque(maxlen=40)
    rewconbuffer = deque(maxlen=40)
    rewsurbuffer = deque(maxlen=40)

    # Per-iteration history of the rolling means above (used for plotting)
    rewformeanbuf = np.array([])
    rewctrlmeanbuf = np.array([])
    rewconmeanbuf = np.array([])
    rewsurmeanbuf = np.array([])

    if sum([max_iters > 0, total_timesteps > 0, max_episodes > 0]) == 0:
        # nothing to be done
        return pi

    assert sum([max_iters>0, total_timesteps>0, max_episodes>0]) < 2, \
        'out of max_iters, total_timesteps, and max_episodes only one should be specified'

    x_axis = 0
    x_holder = np.array([])
    rew_holder = np.array([])
    video_recording = False  # the env must be wrapped for recording only once
    while True:
        if not video_recording and timesteps_so_far > total_timesteps - 1500:
            # Start recording ~1500 timesteps before the budget is exhausted.
            env = VecVideoRecorder(env,
                                   osp.join(logger.get_dir(), "videos"),
                                   record_video_trigger=lambda x: True,
                                   video_length=200)
            seg_gen = traj_segment_generator(pi, env, timesteps_per_batch)
            video_recording = True

        if callback: callback(locals(), globals())
        if total_timesteps and timesteps_so_far >= total_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        logger.log("********** Iteration %i ************" % iters_so_far)

        with timed("sampling"):
            seg = seg_gen.__next__()

        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"]
        ob = sf01(ob)
        vpredbefore = seg["vpred"]  # predicted value function before udpate
        # Standardized advantage function estimate; skip the division when the
        # batch has zero advantage variance to avoid producing NaNs.
        atarg_std = atarg.std()
        atarg = atarg - atarg.mean()
        if atarg_std > 0:
            atarg = atarg / atarg_std

        if hasattr(pi, "ret_rms"): pi.ret_rms.update(tdlamret)
        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        args = ob, ac, atarg
        fvpargs = [arr[::5] for arr in args]  # every 5th sample for the FVP

        def fisher_vector_product(p):
            return allmean(compute_fvp(p, *fvpargs).numpy()) + cg_damping * p

        assign_old_eq_new()  # set old parameter values to new parameter values
        with timed("computegrad"):
            *lossbefore, g = compute_lossandgrad(*args)
        lossbefore = allmean(np.array(lossbefore))
        g = g.numpy()
        g = allmean(g)
        if np.allclose(g, 0):
            logger.log("Got zero gradient. not updating")
            # Parameters are unchanged, so the pre-update losses are what we log.
            meanlosses = lossbefore
        else:
            with timed("cg"):
                stepdir = cg(fisher_vector_product,
                             g,
                             cg_iters=cg_iters,
                             verbose=rank == 0)
            assert np.isfinite(stepdir).all()
            # Scale the step so the quadratic KL estimate equals max_kl.
            shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
            lm = np.sqrt(shs / max_kl)
            # logger.log("lagrange multiplier:", lm, "gnorm:", np.linalg.norm(g))
            fullstep = stepdir / lm
            expectedimprove = g.dot(fullstep)
            surrbefore = lossbefore[0]
            stepsize = 1.0
            thbefore = get_flat()
            # Backtracking line search: halve the step up to 10 times.
            for _ in range(10):
                thnew = thbefore + fullstep * stepsize
                set_from_flat(thnew)
                meanlosses = surr, kl, *_ = allmean(
                    np.array(compute_losses(*args)))
                improve = surr - surrbefore
                logger.log("Expected: %.3f Actual: %.3f" %
                           (expectedimprove, improve))
                if not np.isfinite(meanlosses).all():
                    logger.log("Got non-finite value of losses -- bad!")
                elif kl > max_kl * 1.5:
                    logger.log("violated KL constraint. shrinking step.")
                elif improve < 0:
                    logger.log("surrogate didn't improve. shrinking step.")
                else:
                    logger.log("Stepsize OK!")
                    break
                stepsize *= .5
            else:
                logger.log("couldn't compute a good step")
                set_from_flat(thbefore)
            if nworkers > 1 and iters_so_far % 20 == 0:
                # Sanity check that all workers hold identical parameters.
                paramsums = MPI.COMM_WORLD.allgather(
                    (thnew.sum(), vfadam.getflat().sum()))  # list of tuples
                assert all(
                    np.allclose(ps, paramsums[0]) for ps in paramsums[1:])

        for (lossname, lossval) in zip(loss_names, meanlosses):
            logger.record_tabular(lossname, lossval)

        with timed("vf"):

            for _ in range(vf_iters):
                for (mbob, mbret) in dataset.iterbatches(
                    (seg["ob"], seg["tdlamret"]),
                        include_final_partial_batch=False,
                        batch_size=64):
                    mbob = sf01(mbob)
                    g = allmean(compute_vflossandgrad(mbob, mbret).numpy())
                    vfadam.update(g, vf_stepsize)

        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))

        lrlocal = (seg["ep_lens"], seg["ep_rets"], seg["ep_rets_for"],
                   seg["ep_rets_ctrl"], seg["ep_rets_con"], seg["ep_rets_sur"]
                   )  # local values
        if MPI is not None:
            listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        else:
            listoflrpairs = [lrlocal]

        lens, rews, rews_for, rews_ctrl, rews_con, rews_sur = map(
            flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        rewforbuffer.extend(rews_for)
        rewctrlbuffer.extend(rews_ctrl)
        rewconbuffer.extend(rews_con)
        rewsurbuffer.extend(rews_sur)

        rewformeanbuf = np.append([rewformeanbuf], [np.mean(rewforbuffer)])
        rewctrlmeanbuf = np.append([rewctrlmeanbuf], [np.mean(rewctrlbuffer)])
        rewconmeanbuf = np.append([rewconmeanbuf], [np.mean(rewconbuffer)])
        rewsurmeanbuf = np.append([rewsurmeanbuf], [np.mean(rewsurbuffer)])

        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)

        if rank == 0:
            logger.dump_tabular()

        x_axis += 1
        x_holder = np.append([x_holder], [x_axis])
        rew_holder = np.append([rew_holder], [np.mean(rewbuffer)])

    # The current generation tag (used as a plot-file suffix) is stored in img_rec.txt.
    with open("img_rec.txt", "r") as rec:
        cur_gen = rec.read().strip()  # remove \n

    dir_of_gens = [
        '1_1', '2_1', '3_1', '1_2', '2_2', '3_2', '1_3', '2_3', '3_3', '1_4',
        '2_4', '3_4', '1_5', '2_5', '3_5', '1_6', '2_6', '3_6', '1_7', '2_7',
        '3_7', '1_8', '2_8', '3_8', '1_9', '2_9', '3_9', '1_10', '2_10',
        '3_10', '1_11', '2_11', '3_11', '1_12', '2_12', '3_12'
    ]

    from matplotlib import pyplot as plt
    # Fresh figures, closed after saving, so repeated calls don't overlay curves.
    reward_fig = plt.figure()
    plt.plot(x_holder, rew_holder)
    plt.title("Rewards for Ant v2")
    plt.grid(True)
    plt.savefig('rewards_for_antv2_{}'.format(cur_gen))
    plt.close(reward_fig)

    breakdown_fig = plt.figure()
    plt.plot(x_holder, rewformeanbuf, label='Forward Reward')
    plt.plot(x_holder, rewctrlmeanbuf, label='CTRL Cost')
    plt.plot(x_holder, rewconmeanbuf, label='Contact Cost')
    plt.plot(x_holder, rewsurmeanbuf, label='Survive Reward')
    plt.title("Reward Breakdown")
    plt.legend()
    plt.grid(True)
    plt.savefig('rewards_breakdown{}'.format(cur_gen))
    plt.close(breakdown_fig)

    # Advance the stored generation tag, wrapping around at the end of the cycle.
    elem = dir_of_gens.index(cur_gen)
    new_gen = dir_of_gens[(elem + 1) % len(dir_of_gens)]
    with open("img_rec.txt", "w") as rec:
        rec.write(new_gen)

    #----------------------------------------------------------- SAVE WEIGHTS ------------------------------------------------------------#
    # np.save('val_weights_bias_2_c',val_weights_bias_2_c) # <-------------------------------------------------------------------------------------
    # save = save.replace(save[0],'..',2)
    # os.chdir(save)
    # name = 'max_reward'
    # completeName = os.path.join(name+".txt")
    # file1 = open(completeName,"w")
    # toFile = str(np.mean(rewbuffer))
    # file1.write(toFile)
    # file1.close()
    # os.chdir('../../../baselines-tf2')

    return pi
예제 #15
0
def learn(
        *,
        network,
        env,
        total_timesteps,
        timesteps_per_batch=1024,  # what to train on
        max_kl=0.002,
        cg_iters=10,
        gamma=0.99,
        lam=1.0,  # advantage estimation
        seed=None,
        ent_coef=0.00,
        cg_damping=1e-2,
        vf_stepsize=3e-4,
        vf_iters=3,
        max_episodes=0,
        max_iters=0,  # time constraint
        callback=None,
        load_path=None,
        num_reward=1,
        **network_kwargs):
    '''
    learn a policy function with TRPO algorithm

    Parameters:
    ----------

    network                 neural network to learn. Can be either string ('mlp', 'cnn', 'lstm', 'lnlstm' for basic types)
                            or function that takes input placeholder and returns tuple (output, None) for feedforward nets
                            or (output, (state_placeholder, state_output, mask_placeholder)) for recurrent nets

    env                     environment (one of the gym environments or wrapped via baselines.common.vec_env.VecEnv-type class

    timesteps_per_batch     timesteps per gradient estimation batch

    max_kl                  max KL divergence between old policy and new policy ( KL(pi_old || pi) )

    ent_coef                coefficient of policy entropy term in the optimization objective

    cg_iters                number of iterations of conjugate gradient algorithm

    cg_damping              conjugate gradient damping

    vf_stepsize             learning rate for adam optimizer used to optimie value function loss

    vf_iters                number of iterations of value function optimization iterations per each policy optimization step

    total_timesteps           max number of timesteps

    max_episodes            max number of episodes

    max_iters               maximum number of policy optimization iterations

    callback                function to be called with (locals(), globals()) each policy optimization step

    load_path               str, path to load the model from (default: None, i.e. no model is loaded)

    **network_kwargs        keyword arguments to the policy / network builder. See baselines.common/policies.py/build_policy and arguments to a particular type of network

    Returns:
    -------

    learnt model

    '''

    if MPI is not None:
        nworkers = MPI.COMM_WORLD.Get_size()
        rank = MPI.COMM_WORLD.Get_rank()
    else:
        nworkers = 1
        rank = 0

    cpus_per_worker = 1
    U.get_session(
        config=tf.ConfigProto(allow_soft_placement=True,
                              inter_op_parallelism_threads=cpus_per_worker,
                              intra_op_parallelism_threads=cpus_per_worker))

    set_global_seeds(seed)
    # 创建policy
    policy = build_policy(env,
                          network,
                          value_network='copy',
                          num_reward=num_reward,
                          **network_kwargs)

    process_dir = logger.get_dir()
    save_dir = process_dir.split(
        'Data')[-2] + 'log/l1/seed' + process_dir[-1] + '/'
    os.makedirs(save_dir, exist_ok=True)
    coe_save = []
    impro_save = []
    grad_save = []
    adj_save = []
    coe = np.ones((num_reward)) / num_reward

    np.set_printoptions(precision=3)
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space

    #################################################################
    # ob ac ret atarg 都是 placeholder
    # ret atarg 此处应该是向量形式
    ob = observation_placeholder(ob_space)

    # 创建pi和oldpi
    with tf.variable_scope("pi"):
        pi = policy(observ_placeholder=ob)
    with tf.variable_scope("oldpi"):
        oldpi = policy(observ_placeholder=ob)

    # 每个reward都可以算一个atarg
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32,
                         shape=[None, num_reward])  # Empirical return

    ac = pi.pdtype.sample_placeholder([None])

    #此处的KL div和entropy与reward无关
    ##################################
    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    # entbonus 是entropy loss
    entbonus = ent_coef * meanent
    #################################

    ###########################################################
    # vferr 用来更新 v 网络
    vferr = tf.reduce_mean(tf.square(pi.vf - ret))
    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))
    # advantage * pnew / pold
    surrgain = tf.reduce_mean(ratio * atarg)

    # optimgain 用来更新 policy 网络, 应该每个reward有一个
    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

    ###########################################################
    dist = meankl

    # 定义要优化的变量和 V 网络 adam 优化器
    all_var_list = get_trainable_variables("pi")
    # var_list = [v for v in all_var_list if v.name.split("/")[1].startswith("pol")]
    # vf_var_list = [v for v in all_var_list if v.name.split("/")[1].startswith("vf")]
    var_list = get_pi_trainable_variables("pi")
    vf_var_list = get_vf_trainable_variables("pi")

    vfadam = MpiAdam(vf_var_list)

    # 把变量展开成一个向量的类
    get_flat = U.GetFlat(var_list)

    # 这个类可以把一个向量分片赋值给var_list里的变量
    set_from_flat = U.SetFromFlat(var_list)
    # kl散度的梯度
    klgrads = tf.gradients(dist, var_list)

    ####################################################################
    # 拉直的向量
    flat_tangent = tf.placeholder(dtype=tf.float32,
                                  shape=[None],
                                  name="flat_tan")

    # 把拉直的向量重新分成很多向量
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start + sz], shape))
        start += sz
    ####################################################################

    ####################################################################
    # 把kl散度梯度与变量乘积相加
    gvp = tf.add_n([
        tf.reduce_sum(g * tangent)
        for (g, tangent) in zipsame(klgrads, tangents)
    ])  #pylint: disable=E1111
    # 把gvp的梯度展成向量
    fvp = U.flatgrad(gvp, var_list)
    ####################################################################

    # 用学习后的策略更新old策略
    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(get_variables("oldpi"), get_variables("pi"))
        ])

    # 计算loss
    compute_losses = U.function([ob, ac, atarg], losses)
    # 计算loss和梯度
    compute_lossandgrad = U.function([ob, ac, atarg], losses +
                                     [U.flatgrad(optimgain, var_list)])
    # 计算fvp
    compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    # 计算值网络的梯度
    compute_vflossandgrad = U.function([ob, ret],
                                       U.flatgrad(vferr, vf_var_list))

    @contextmanager
    def timed(msg):
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(
                colorize("done in %.3f seconds" % (time.time() - tstart),
                         color='magenta'))
        else:
            yield

    def allmean(x):
        assert isinstance(x, np.ndarray)
        if MPI is not None:
            out = np.empty_like(x)
            MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
            out /= nworkers
        else:
            out = np.copy(x)

        return out

    # 初始化variable
    U.initialize()
    if load_path is not None:
        pi.load(load_path)

    # 得到初始化的参数向量
    th_init = get_flat()
    if MPI is not None:
        MPI.COMM_WORLD.Bcast(th_init, root=0)

    # 把向量the_init的值分片赋值给var_list
    set_from_flat(th_init)

    #同步
    vfadam.sync()
    print("Init param sum", th_init.sum(), flush=True)

    # Prepare for rollouts
    # ----------------------------------------

    # 这是一个生成数据的迭代器
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_batch,
                                     stochastic=True,
                                     num_reward=num_reward)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()

    # 双端队列
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards

    if sum([max_iters > 0, total_timesteps > 0, max_episodes > 0]) == 0:
        # noththing to be done
        return pi

    assert sum([max_iters>0, total_timesteps>0, max_episodes>0]) < 2, \
        'out of max_iters, total_timesteps, and max_episodes only one should be specified'

    while True:
        if callback: callback(locals(), globals())
        if total_timesteps and timesteps_so_far >= total_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        logger.log("********** Iteration %i ************" % iters_so_far)

        with timed("sampling"):
            seg = seg_gen.__next__()

        # 计算累积回报
        add_vtarg_and_adv(seg, gamma, lam, num_reward=num_reward)
        ###########$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$ToDo
        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))

        # ob, ac, atarg, tdlamret 的类型都是ndarray
        #ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]
        _, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"]
        #print(seg['ob'].shape,type(seg['ob']))
        #print(seg['ac'],type(seg['ac']))
        #print(seg['adv'],type(seg['adv']))
        #print(seg["tdlamret"].shape,type(seg['tdlamret']))
        vpredbefore = seg["vpred"]  # predicted value function before udpate

        # 标准化
        #print("============================== atarg =========================================================")
        #print(atarg)
        atarg = (atarg - np.mean(atarg, axis=0)) / np.std(
            atarg, axis=0)  # standardized advantage function estimate
        #atarg = (atarg) / np.max(np.abs(atarg),axis=0)
        #print('======================================= standardized atarg ====================================')
        #print(atarg)
        if hasattr(pi, "ret_rms"): pi.ret_rms.update(tdlamret)
        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        ## set old parameter values to new parameter values
        assign_old_eq_new()

        G = None
        S = None
        mr_lossbefore = np.zeros((num_reward, len(loss_names)))
        grad_norm = np.zeros((num_reward + 1))
        for i in range(num_reward):
            args = seg["ob"], seg["ac"], atarg[:, i]
            #print(atarg[:,i])
            # 算是args的一个sample,每隔5个取出一个
            fvpargs = [arr[::5] for arr in args]

            # 这个函数计算fisher matrix 与向量 p 的 乘积
            def fisher_vector_product(p):
                return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p

            with timed("computegrad of " + str(i + 1) + ".th reward"):
                *lossbefore, g = compute_lossandgrad(*args)
            lossbefore = allmean(np.array(lossbefore))
            mr_lossbefore[i] = lossbefore
            g = allmean(g)
            #print("***************************************************************")
            #print(g)
            #print('==================='+str(i+1)+"=====================",np.linalg.norm(g))
            #print(atarg[:,i])
            if isinstance(G, np.ndarray):
                G = np.vstack((G, g))
            else:
                G = g

            # g是目标函数的梯度
            # 利用共轭梯度获得更新方向
            if np.allclose(g, 0):
                logger.log("Got zero gradient. not updating")
            else:
                with timed("cg of " + str(i + 1) + ".th reward"):
                    # stepdir 是更新方向
                    stepdir = cg(fisher_vector_product,
                                 g,
                                 cg_iters=cg_iters,
                                 verbose=rank == 0)
                    shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
                    lm = np.sqrt(shs / max_kl)
                    # logger.log("lagrange multiplier:", lm, "gnorm:", np.linalg.norm(g))
                    fullstep = stepdir / lm
                    #print(np.linalg.norm(fullstep))
                    grad_norm[i] = np.linalg.norm(fullstep)
                assert np.isfinite(stepdir).all()
                if isinstance(S, np.ndarray):
                    S = np.vstack((S, stepdir))
                else:
                    S = stepdir
        #print('======================================= G ====================================')
        #print(G)
        #print('======================================= S ====================================')
        #print(S)
        new_coe = get_coefficient(G, S)
        #coe = 0.99 * coe + 0.01 * new_coe
        coe = new_coe
        coe_save.append(coe)
        #根据梯度的夹角调整参数
        try:
            GG = np.dot(S, S.T)
            D = np.sqrt(np.diag(1 / np.diag(GG)))
            GG = np.dot(np.dot(D, GG), D)
            #print('======================================= inner product ====================================')
            #print(GG)
            adj = np.sum(GG) / (num_reward**2)
        except:
            adj = 1
        #print('======================================= adj ====================================')
        #print(adj)
        try:
            adj = 1
            adj_save.append(adj)
            adj_max_kl = adj * max_kl
            #################################################################
            grad_norm = grad_norm * np.sqrt(adj)
            stepdir = np.dot(coe, S)
            g = np.dot(coe, G)
            lossbefore = np.dot(coe, mr_lossbefore)
            #################################################################

            shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
            lm = np.sqrt(shs / adj_max_kl)
            # logger.log("lagrange multiplier:", lm, "gnorm:", np.linalg.norm(g))
            fullstep = stepdir / lm
            grad_norm[num_reward] = np.linalg.norm(fullstep)
            grad_save.append(grad_norm)
            expectedimprove = g.dot(fullstep)
            surrbefore = lossbefore[0]
            stepsize = 1.0
            thbefore = get_flat()

            def compute_mr_losses():
                mr_losses = np.zeros((num_reward, len(loss_names)))
                for i in range(num_reward):
                    args = seg["ob"], seg["ac"], atarg[:, i]
                    one_reward_loss = allmean(np.array(compute_losses(*args)))
                    mr_losses[i] = one_reward_loss
                mr_loss = np.dot(coe, mr_losses)
                return mr_loss, mr_losses

            # 做10次搜索
            for _ in range(10):
                thnew = thbefore + fullstep * stepsize
                set_from_flat(thnew)
                mr_loss_new, mr_losses_new = compute_mr_losses()
                mr_impro = mr_losses_new - mr_lossbefore
                meanlosses = surr, kl, *_ = allmean(np.array(mr_loss_new))
                improve = surr - surrbefore
                logger.log("Expected: %.3f Actual: %.3f" %
                           (expectedimprove, improve))
                if not np.isfinite(meanlosses).all():
                    logger.log("Got non-finite value of losses -- bad!")
                elif kl > adj_max_kl * 1.5:
                    logger.log("violated KL constraint. shrinking step.")
                elif improve < 0:
                    logger.log("surrogate didn't improve. shrinking step.")
                else:
                    logger.log("Stepsize OK!")
                    impro_save.append(np.hstack((mr_impro[:, 0], improve)))
                    break
                stepsize *= .5
            else:
                logger.log("couldn't compute a good step")
                set_from_flat(thbefore)
            if nworkers > 1 and iters_so_far % 20 == 0:
                paramsums = MPI.COMM_WORLD.allgather(
                    (thnew.sum(), vfadam.getflat().sum()))  # list of tuples
                assert all(
                    np.allclose(ps, paramsums[0]) for ps in paramsums[1:])

            for (lossname, lossval) in zip(loss_names, meanlosses):
                logger.record_tabular(lossname, lossval)

            with timed("vf"):
                #print('======================================= tdlamret ====================================')
                #print(seg["tdlamret"])
                for _ in range(vf_iters):
                    for (mbob, mbret) in dataset.iterbatches(
                        (seg["ob"], seg["tdlamret"]),
                            include_final_partial_batch=False,
                            batch_size=64):
                        #with tf.Session() as sess:
                        #    sess.run(tf.global_variables_initializer())
                        #    aaa = sess.run(pi.vf,feed_dict={ob:mbob,ret:mbret})
                        #    print("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")
                        #    print(aaa.shape)
                        #    print(mbret.shape)
                        g = allmean(compute_vflossandgrad(mbob, mbret))
                        vfadam.update(g, vf_stepsize)
            #print(mbob,mbret)
        except:
            print('error')
        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))

        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        if MPI is not None:
            listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        else:
            listoflrpairs = [lrlocal]

        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        if rank == 0:
            logger.dump_tabular()
        #pdb.set_trace()
    np.save(save_dir + 'coe.npy', coe_save)
    np.save(save_dir + 'grad.npy', grad_save)
    np.save(save_dir + 'improve.npy', impro_save)
    np.save(save_dir + 'adj.npy', adj_save)
    return pi
예제 #16
0
def learn(env, last_ob, last_jpos, run_reach, policy_func, reward_giver, expert_dataset, rank,
          pretrained, pretrained_weight, *,
          g_step, d_step, entcoeff, save_per_iter,
          ckpt_dir, log_dir, timesteps_per_batch, task_name,
          gamma, lam,
          max_kl, cg_iters, cg_damping=1e-2,
          vf_stepsize=3e-4, d_stepsize=3e-4, vf_iters=3,
          max_timesteps=0, max_episodes=0, max_iters=0,
          callback=None
          ):
    """Train the "pi_grasp" policy GAIL-style: TRPO generator updates alternating with discriminator updates.

    Each outer iteration performs ``g_step`` generator updates. Each update:
      * samples a rollout segment;
      * takes a natural-gradient step (conjugate gradient on the damped
        Fisher-vector product, then a backtracking line search bounded by ``max_kl``);
      * runs ``vf_iters`` epochs of value-function regression with MpiAdam.
    The iteration then does one pass of discriminator (``reward_giver``) updates
    against ``expert_dataset``. Statistics are written through ``logger``. MPI
    rank 0 periodically saves TF checkpoints to ``ckpt_dir``.

    Args:
        env: environment providing ``observation_space`` / ``action_space``; it is
            also handed to ``traj_segment_generator`` for rollouts.
        last_ob, last_jpos, run_reach: forwarded unchanged to ``traj_segment_generator``.
        policy_func: ``policy_func(name, ob_space, ac_space[, reuse=...])`` policy
            builder; called for the "pi_grasp" and "oldpi" scopes.
        reward_giver: discriminator exposing ``get_trainable_variables()``,
            ``loss_name`` and ``lossandgrad(ob, ac, ob_expert, ac_expert)``, and
            optionally ``obs_rms``.
        expert_dataset: object whose ``get_next_batch(n)`` returns ``(obs, acs)``.
        rank: ignored; the rank is re-read from ``MPI.COMM_WORLD``.
        pretrained: unused.
        pretrained_weight: checkpoint loaded into ``pi``'s variables, or None.
        g_step: number of generator (TRPO) updates per iteration.
        d_step: the generator batch is cut into minibatches of
            ``len(ob) // d_step`` samples for the discriminator.
        entcoeff: coefficient of the entropy bonus in the surrogate objective.
        save_per_iter: save a checkpoint every this many iterations; a falsy
            value (e.g. 0) disables saving.
        ckpt_dir: checkpoint directory, or None to disable saving.
        log_dir: unused in this function.
        timesteps_per_batch: rollout length requested from the segment generator.
        task_name: prefix of checkpoint file names.
        gamma, lam: discount and lambda passed to ``add_vtarg_and_adv``.
        max_kl: KL trust-region radius for the policy step.
        cg_iters: conjugate-gradient iterations.
        cg_damping: damping added to the Fisher-vector product.
        vf_stepsize: Adam step size for the value function.
        d_stepsize: Adam step size for the discriminator.
        vf_iters: epochs of value-function fitting per generator update.
        max_timesteps, max_episodes, max_iters: exactly one must be positive;
            training stops when that budget is reached.
        callback: optional ``callback(locals(), globals())`` called at the start
            of every iteration.

    Returns:
        None.
    """
    nworkers = MPI.COMM_WORLD.Get_size()
    # Shadows the `rank` argument: the MPI communicator is the source of truth.
    rank = MPI.COMM_WORLD.Get_rank()
    np.set_printoptions(precision=3)

    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space

    # Variable reuse is requested exactly when pretrained weights are supplied.
    pi = policy_func("pi_grasp", ob_space, ac_space, reuse=(pretrained_weight is not None))
    oldpi = policy_func("oldpi", ob_space, ac_space)
    atarg = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None])  # Target advantage function (if applicable)
    ret = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    entbonus = entcoeff * meanent

    vferr = tf.reduce_mean(tf.square(pi.vpred - ret))

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # advantage * pnew / pold
    surrgain = tf.reduce_mean(ratio * atarg)

    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

    dist = meankl

    # Policy parameters (optimized by TRPO) vs. value-function parameters (optimized by Adam).
    all_var_list = pi.get_trainable_variables()
    var_list = [v for v in all_var_list if v.name.startswith(("pi_grasp/pol", "pi_grasp/logstd"))]
    vf_var_list = [v for v in all_var_list if v.name.startswith("pi_grasp/vff")]
    assert len(var_list) == len(vf_var_list) + 1
    d_adam = MpiAdam(reward_giver.get_trainable_variables())
    vfadam = MpiAdam(vf_var_list)

    get_flat = U.GetFlat(var_list)
    set_from_flat = U.SetFromFlat(var_list)

    # Fisher-vector product: gradient of (grad KL . tangent), with the tangent
    # supplied as one flat vector and reshaped per variable.
    klgrads = tf.gradients(dist, var_list)
    flat_tangent = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None], name="flat_tan")
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start+sz], shape))
        start += sz
    gvp = tf.add_n([tf.reduce_sum(g*tangent) for (g, tangent) in zipsame(klgrads, tangents)])  # pylint: disable=E1111
    fvp = U.flatgrad(gvp, var_list)

    assign_old_eq_new = U.function([], [], updates=[tf.compat.v1.assign(oldv, newv)
                                                    for (oldv, newv) in zipsame(oldpi.get_variables(), pi.get_variables())])
    compute_losses = U.function([ob, ac, atarg], losses)
    compute_lossandgrad = U.function([ob, ac, atarg], losses + [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    compute_vflossandgrad = U.function([ob, ret], U.flatgrad(vferr, vf_var_list))

    @contextmanager
    def timed(msg):
        """Print `msg` and the elapsed time of the wrapped section (rank 0 only)."""
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(colorize("done in %.3f seconds" % (time.time() - tstart), color='magenta'))
        else:
            yield

    def allmean(x):
        """Average a numpy array element-wise across all MPI workers."""
        assert isinstance(x, np.ndarray)
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
        return out

    # Initialize, then broadcast rank 0's policy parameters so all workers start identical.
    U.initialize()
    th_init = get_flat()
    MPI.COMM_WORLD.Bcast(th_init, root=0)
    set_from_flat(th_init)
    d_adam.sync()
    vfadam.sync()

    if rank == 0:
        print("Init param sum", th_init.sum(), flush=True)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi, last_ob, last_jpos, run_reach, policy_func, env, reward_giver, timesteps_per_batch, stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards
    true_rewbuffer = deque(maxlen=40)  # rolling buffer for true (environment) episode rewards

    assert sum([max_iters > 0, max_timesteps > 0, max_episodes > 0]) == 1

    # NOTE(review): these stats objects are not used below; kept in case their
    # construction registers graph ops that something else relies on. Verify before removing.
    g_loss_stats = stats(loss_names)
    d_loss_stats = stats(reward_giver.loss_name)
    ep_stats = stats(["True_rewards", "Rewards", "Episode_length"])
    # if provide pretrained weight
    if pretrained_weight is not None:
        U.load_state(pretrained_weight, var_list=pi.get_variables())

    def fisher_vector_product(p):
        """Damped, MPI-averaged Fisher-vector product on the subsampled batch `fvpargs`."""
        # `fvpargs` is (re)assigned before every generator update below.
        return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p

    # A single Saver, created on first use, so the graph does not gain new
    # save/restore ops on every checkpoint.
    saver = None

    while True:
        if callback:
            callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break

        # Save model (rank 0 only; a falsy save_per_iter disables saving instead of dividing by zero)
        if rank == 0 and ckpt_dir is not None and save_per_iter and iters_so_far % save_per_iter == 0:
            t_name = task_name + "_" + str(iters_so_far)
            fname = os.path.join(ckpt_dir, t_name)
            os.makedirs(os.path.dirname(fname), exist_ok=True)
            if saver is None:
                # max_to_keep=0: never delete older checkpoint files, matching the
                # previous behaviour of a fresh Saver per save.
                saver = tf.compat.v1.train.Saver(max_to_keep=0)
            saver.save(tf.compat.v1.get_default_session(), fname)

        logger.log("********** Iteration %i ************" % iters_so_far)

        # ------------------ Update G ------------------
        logger.log("Optimizing Policy...")
        for _ in range(g_step):
            with timed("sampling"):
                seg = next(seg_gen)
            add_vtarg_and_adv(seg, gamma, lam)
            # NOTE: these rebind the names of the TF placeholders above to numpy arrays;
            # the compiled U.function objects already hold the placeholders.
            ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]
            vpredbefore = seg["vpred"]  # predicted value function before update
            atarg = (atarg - atarg.mean()) / atarg.std()  # standardized advantage function estimate

            if hasattr(pi, "ob_rms"):
                pi.ob_rms.update(ob)  # update running mean/std for policy

            args = seg["ob"], seg["ac"], atarg
            fvpargs = [arr[::5] for arr in args]  # every 5th sample, to cheapen Fisher-vector products

            assign_old_eq_new()  # set old parameter values to new parameter values
            with timed("computegrad"):
                *lossbefore, g = compute_lossandgrad(*args)
            lossbefore = allmean(np.array(lossbefore))
            g = allmean(g)
            if np.allclose(g, 0):
                logger.log("Got zero gradient. not updating")
                # Parameters are unchanged, so the pre-update losses are the current
                # ones. Without this, `meanlosses` is undefined on a first-step zero gradient.
                meanlosses = lossbefore
            else:
                with timed("cg"):
                    stepdir = cg(fisher_vector_product, g, cg_iters=cg_iters, verbose=rank == 0)
                assert np.isfinite(stepdir).all()
                shs = .5*stepdir.dot(fisher_vector_product(stepdir))
                lm = np.sqrt(shs / max_kl)
                # Scale so that 0.5 * step^T F step == max_kl.
                fullstep = stepdir / lm
                expectedimprove = g.dot(fullstep)
                surrbefore = lossbefore[0]
                stepsize = 1.0
                thbefore = get_flat()
                # Backtracking line search: halve the step until losses are finite,
                # the KL is within 1.5 * max_kl and the surrogate does not decrease.
                for _ in range(10):
                    thnew = thbefore + fullstep * stepsize
                    set_from_flat(thnew)
                    meanlosses = surr, kl, *_ = allmean(np.array(compute_losses(*args)))
                    improve = surr - surrbefore
                    logger.log("Expected: %.3f Actual: %.3f" % (expectedimprove, improve))
                    if not np.isfinite(meanlosses).all():
                        logger.log("Got non-finite value of losses -- bad!")
                    elif kl > max_kl * 1.5:
                        logger.log("violated KL constraint. shrinking step.")
                    elif improve < 0:
                        logger.log("surrogate didn't improve. shrinking step.")
                    else:
                        logger.log("Stepsize OK!")
                        break
                    stepsize *= .5
                else:
                    # Parameters are restored, but `meanlosses` still holds the last rejected trial's losses.
                    logger.log("couldn't compute a good step")
                    set_from_flat(thbefore)
                if nworkers > 1 and iters_so_far % 20 == 0:
                    # Sanity check that all workers still hold identical parameters.
                    paramsums = MPI.COMM_WORLD.allgather((thnew.sum(), vfadam.getflat().sum()))  # list of tuples
                    assert all(np.allclose(ps, paramsums[0]) for ps in paramsums[1:])
            with timed("vf"):
                for _ in range(vf_iters):
                    for (mbob, mbret) in dataset.iterbatches((seg["ob"], seg["tdlamret"]),
                                                             include_final_partial_batch=False, batch_size=128):
                        if hasattr(pi, "ob_rms"):
                            pi.ob_rms.update(mbob)  # update running mean/std for policy
                        g = allmean(compute_vflossandgrad(mbob, mbret))
                        vfadam.update(g, vf_stepsize)

        for (lossname, lossval) in zip(loss_names, meanlosses):
            logger.record_tabular(lossname, lossval)
        logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))

        # ------------------ Update D ------------------
        logger.log("Optimizing Discriminator...")
        logger.log(fmt_row(13, reward_giver.loss_name))
        # NOTE(review): this batch is never used (it is overwritten inside the loop).
        # It is kept so the sequence of batches drawn from expert_dataset is unchanged.
        ob_expert, ac_expert = expert_dataset.get_next_batch(len(ob))
        batch_size = len(ob) // d_step
        d_losses = []  # list of tuples, each of which gives the loss for a minibatch
        for ob_batch, ac_batch in dataset.iterbatches((ob, ac),
                                                      include_final_partial_batch=False,
                                                      batch_size=batch_size):
            ob_expert, ac_expert = expert_dataset.get_next_batch(len(ob_batch))
            # update running mean/std for reward_giver
            if hasattr(reward_giver, "obs_rms"):
                reward_giver.obs_rms.update(np.concatenate((ob_batch, ob_expert), 0))
            *newlosses, g = reward_giver.lossandgrad(ob_batch, ac_batch, ob_expert, ac_expert)
            d_adam.update(allmean(g), d_stepsize)
            d_losses.append(newlosses)
        logger.log(fmt_row(13, np.mean(d_losses, axis=0)))

        # ------------------ Episode statistics ------------------
        # Uses the episodes completed in the last generator segment, gathered from all workers.
        lrlocal = (seg["ep_lens"], seg["ep_rets"], seg["ep_true_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews, true_rets = map(flatten_lists, zip(*listoflrpairs))
        true_rewbuffer.extend(true_rets)
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        # These means are NaN while the buffers are still empty (no episode finished yet).
        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpTrueRewMean", np.mean(true_rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        # Counted from the segment's step counter instead of completed episode lengths,
        # so progress is tracked even when no episode terminates.
        timesteps_so_far += seg["steps"]
        iters_so_far += 1

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)

        if rank == 0:
            logger.dump_tabular()
def learn(
        *,
        network,
        env,
        eval_env,
        make_eval_env,
        env_id,
        total_timesteps,
        timesteps_per_batch,
        sil_update,
        sil_loss,  # what to train on
        max_kl=0.001,
        cg_iters=10,
        gamma=0.99,
        lam=1.0,  # advantage estimation
        seed=None,
        ent_coef=0.0,
        lr=3e-4,
        cg_damping=1e-2,
        vf_stepsize=3e-4,
        vf_iters=5,
        sil_value=0.01,
        sil_alpha=0.6,
        sil_beta=0.1,
        max_episodes=0,
        max_iters=0,  # time constraint
        callback=None,
        save_interval=0,
        load_path=None,
        # MBL
        # For train mbl
        mbl_train_freq=5,

        # For eval
        num_eval_episodes=5,
        eval_freq=5,
        vis_eval=False,
        eval_targs=('mbmf', ),
        #eval_targs=('mf',),
        quant=2,

        # For mbl.step
        #num_samples=(1500,),
        num_samples=(1, ),
        horizon=(2, ),
        #horizon=(2,1),
        #num_elites=(10,),
        num_elites=(1, ),
        mbl_lamb=(1.0, ),
        mbl_gamma=0.99,
        #mbl_sh=1, # Number of step for stochastic sampling
        mbl_sh=10000,
        #vf_lookahead=-1,
        #use_max_vf=False,
        reset_per_step=(0, ),

        # For get_model
        num_fc=2,
        num_fwd_hidden=500,
        use_layer_norm=False,

        # For MBL
        num_warm_start=int(1e4),
        init_epochs=10,
        update_epochs=5,
        batch_size=512,
        update_with_validation=False,
        use_mean_elites=1,
        use_ent_adjust=0,
        adj_std_scale=0.5,

        # For data loading
        validation_set_path=None,

        # For data collect
        collect_val_data=False,

        # For traj collect
        traj_collect='mf',

        # For profile
        measure_time=True,
        eval_val_err=False,
        measure_rew=True,
        model_fn=None,
        update_fn=None,
        init_fn=None,
        mpi_rank_weight=1,
        comm=None,
        vf_coef=0.5,
        max_grad_norm=0.5,
        log_interval=1,
        nminibatches=4,
        noptepochs=4,
        cliprange=0.2,
        **network_kwargs):
    '''
    learn a policy function with TRPO algorithm

    Parameters:
    ----------

    network                 neural network to learn. Can be either string ('mlp', 'cnn', 'lstm', 'lnlstm' for basic types)
                            or function that takes input placeholder and returns tuple (output, None) for feedforward nets
                            or (output, (state_placeholder, state_output, mask_placeholder)) for recurrent nets

    env                     environment (one of the gym environments or wrapped via baselines.common.vec_env.VecEnv-type class

    timesteps_per_batch     timesteps per gradient estimation batch

    max_kl                  max KL divergence between old policy and new policy ( KL(pi_old || pi) )

    ent_coef                coefficient of policy entropy term in the optimization objective

    cg_iters                number of iterations of conjugate gradient algorithm

    cg_damping              conjugate gradient damping

    vf_stepsize             learning rate for adam optimizer used to optimie value function loss

    vf_iters                number of iterations of value function optimization iterations per each policy optimization step

    total_timesteps           max number of timesteps

    max_episodes            max number of episodes

    max_iters               maximum number of policy optimization iterations

    callback                function to be called with (locals(), globals()) each policy optimization step

    load_path               str, path to load the model from (default: None, i.e. no model is loaded)

    **network_kwargs        keyword arguments to the policy / network builder. See baselines.common/policies.py/build_policy and arguments to a particular type of network

    Returns:
    -------

    learnt model

    '''
    if not isinstance(num_samples, tuple): num_samples = (num_samples, )
    if not isinstance(horizon, tuple): horizon = (horizon, )
    if not isinstance(num_elites, tuple): num_elites = (num_elites, )
    if not isinstance(mbl_lamb, tuple): mbl_lamb = (mbl_lamb, )
    if not isinstance(reset_per_step, tuple):
        reset_per_step = (reset_per_step, )
    if validation_set_path is None:
        if collect_val_data:
            validation_set_path = os.path.join(logger.get_dir(), 'val.pkl')
        else:
            validation_set_path = os.path.join('dataset',
                                               '{}-val.pkl'.format(env_id))
    if eval_val_err:
        eval_val_err_path = os.path.join('dataset',
                                         '{}-combine-val.pkl'.format(env_id))
    logger.log(locals())
    logger.log('MBL_SH', mbl_sh)
    logger.log('Traj_collect', traj_collect)

    if MPI is not None:
        nworkers = MPI.COMM_WORLD.Get_size()
        rank = MPI.COMM_WORLD.Get_rank()
    else:
        nworkers = 1
        rank = 0
    cpus_per_worker = 1
    U.get_session(
        config=tf.ConfigProto(allow_soft_placement=True,
                              inter_op_parallelism_threads=cpus_per_worker,
                              intra_op_parallelism_threads=cpus_per_worker))

    set_global_seeds(seed)
    if isinstance(lr, float): lr = constfn(lr)
    else: assert callable(lr)
    if isinstance(cliprange, float): cliprange = constfn(cliprange)
    else: assert callable(cliprange)

    policy = build_policy(env, network, value_network='copy', **network_kwargs)
    nenvs = env.num_envs
    np.set_printoptions(precision=3)
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    nbatch = nenvs * timesteps_per_batch
    nbatch_train = nbatch // nminibatches
    is_mpi_root = (MPI is None or MPI.COMM_WORLD.Get_rank() == 0)

    ob = observation_placeholder(ob_space)
    with tf.variable_scope("pi"):
        pi = policy(observ_placeholder=ob)
        make_model = lambda: Model(
            policy=policy,
            ob_space=ob_space,
            ac_space=ac_space,
            nbatch_act=nenvs,
            nbatch_train=nbatch_train,
            nsteps=timesteps_per_batch,
            ent_coef=ent_coef,
            vf_coef=vf_coef,
            max_grad_norm=max_grad_norm,
            sil_update=sil_update,
            sil_value=sil_value,
            sil_alpha=sil_alpha,
            sil_beta=sil_beta,
            sil_loss=sil_loss,
            #                                    fn_reward=env.process_reward,
            fn_reward=None,
            #                                    fn_obs=env.process_obs,
            fn_obs=None,
            ppo=False,
            prev_pi='pi',
            silm=pi)
        model = make_model()
    with tf.variable_scope("oldpi"):
        oldpi = policy(observ_placeholder=ob)
        make_old_model = lambda: Model(
            policy=policy,
            ob_space=ob_space,
            ac_space=ac_space,
            nbatch_act=nenvs,
            nbatch_train=nbatch_train,
            nsteps=timesteps_per_batch,
            ent_coef=ent_coef,
            vf_coef=vf_coef,
            max_grad_norm=max_grad_norm,
            sil_update=sil_update,
            sil_value=sil_value,
            sil_alpha=sil_alpha,
            sil_beta=sil_beta,
            sil_loss=sil_loss,
            #                                    fn_reward=env.process_reward,
            fn_reward=None,
            #                                    fn_obs=env.process_obs,
            fn_obs=None,
            ppo=False,
            prev_pi='oldpi',
            silm=oldpi)
        old_model = make_old_model()

    # MBL
    # ---------------------------------------
    #viz = Visdom(env=env_id)
    win = None
    eval_targs = list(eval_targs)
    logger.log(eval_targs)

    make_model_f = get_make_mlp_model(num_fc=num_fc,
                                      num_fwd_hidden=num_fwd_hidden,
                                      layer_norm=use_layer_norm)
    mbl = MBL(env=eval_env,
              env_id=env_id,
              make_model=make_model_f,
              num_warm_start=num_warm_start,
              init_epochs=init_epochs,
              update_epochs=update_epochs,
              batch_size=batch_size,
              **network_kwargs)

    val_dataset = {'ob': None, 'ac': None, 'ob_next': None}
    if update_with_validation:
        logger.log('Update with validation')
        val_dataset = load_val_data(validation_set_path)
    if eval_val_err:
        logger.log('Log val error')
        eval_val_dataset = load_val_data(eval_val_err_path)
    if collect_val_data:
        logger.log('Collect validation data')
        val_dataset_collect = []

    def _mf_pi(ob, t=None):
        stochastic = True
        ac, vpred, _, _ = pi.step(ob, stochastic=stochastic)
        return ac, vpred

    def _mf_det_pi(ob, t=None):
        #ac, vpred, _, _ = pi.step(ob, stochastic=False)
        ac, vpred = pi._evaluate([pi.pd.mode(), pi.vf], ob)
        return ac, vpred

    def _mf_ent_pi(ob, t=None):
        mean, std, vpred = pi._evaluate([pi.pd.mode(), pi.pd.std, pi.vf], ob)
        ac = np.random.normal(mean, std * adj_std_scale, size=mean.shape)
        return ac, vpred
################### use_ent_adjust======> adj_std_scale????????pi action sample

    def _mbmf_inner_pi(ob, t=0):
        if use_ent_adjust:
            return _mf_ent_pi(ob)
        else:
            #return _mf_pi(ob)
            if t < mbl_sh: return _mf_pi(ob)
            else: return _mf_det_pi(ob)

    # ---------------------------------------

    # Run multiple configuration once
    all_eval_descs = []

    def make_mbmf_pi(n, h, e, l):
        def _mbmf_pi(ob):
            ac, rew = mbl.step(ob=ob,
                               pi=_mbmf_inner_pi,
                               horizon=h,
                               num_samples=n,
                               num_elites=e,
                               gamma=mbl_gamma,
                               lamb=l,
                               use_mean_elites=use_mean_elites)
            return ac[None], rew

        return Policy(step=_mbmf_pi, reset=None)

    for n in num_samples:
        for h in horizon:
            for l in mbl_lamb:
                for e in num_elites:
                    if 'mbmf' in eval_targs:
                        all_eval_descs.append(('MeanRew', 'MBL_TRPO_SIL',
                                               make_mbmf_pi(n, h, e, l)))
                    #if 'mbmf' in eval_targs: all_eval_descs.append(('MeanRew-n-{}-h-{}-e-{}-l-{}-sh-{}-me-{}'.format(n, h, e, l, mbl_sh, use_mean_elites), 'MBL_TRPO-n-{}-h-{}-e-{}-l-{}-sh-{}-me-{}'.format(n, h, e, l, mbl_sh, use_mean_elites), make_mbmf_pi(n, h, e, l)))
    if 'mf' in eval_targs:
        all_eval_descs.append(
            ('MeanRew', 'TRPO_SIL', Policy(step=_mf_pi, reset=None)))

    logger.log('List of evaluation targets')
    for it in all_eval_descs:
        logger.log(it[0])

    pool = Pool(mp.cpu_count())
    warm_start_done = False
    # ----------------------------------------

    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    entbonus = ent_coef * meanent

    vferr = tf.reduce_mean(tf.square(pi.vf - ret))

    ratio = tf.exp(pi.pd.logp(ac) -
                   oldpi.pd.logp(ac))  # advantage * pnew / pold
    surrgain = tf.reduce_mean(ratio * atarg)

    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

    dist = meankl

    all_var_list = get_trainable_variables("pi")
    # var_list = [v for v in all_var_list if v.name.split("/")[1].startswith("pol")]
    # vf_var_list = [v for v in all_var_list if v.name.split("/")[1].startswith("vf")]
    var_list = get_pi_trainable_variables("pi")
    vf_var_list = get_vf_trainable_variables("pi")

    vfadam = MpiAdam(vf_var_list)

    get_flat = U.GetFlat(var_list)
    set_from_flat = U.SetFromFlat(var_list)
    klgrads = tf.gradients(dist, var_list)
    flat_tangent = tf.placeholder(dtype=tf.float32,
                                  shape=[None],
                                  name="flat_tan")
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start + sz], shape))
        start += sz
    gvp = tf.add_n([
        tf.reduce_sum(g * tangent)
        for (g, tangent) in zipsame(klgrads, tangents)
    ])  #pylint: disable=E1111
    fvp = U.flatgrad(gvp, var_list)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(get_variables("oldpi"), get_variables("pi"))
        ])

    compute_losses = U.function([ob, ac, atarg], losses)
    compute_lossandgrad = U.function([ob, ac, atarg], losses +
                                     [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    compute_vflossandgrad = U.function([ob, ret],
                                       U.flatgrad(vferr, vf_var_list))

    @contextmanager
    def timed(msg):
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(
                colorize("done in %.3f seconds" % (time.time() - tstart),
                         color='magenta'))
        else:
            yield

    def allmean(x):
        assert isinstance(x, np.ndarray)
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
        return out

    U.initialize()
    if load_path is not None:
        pi.load(load_path)

    th_init = get_flat()
    MPI.COMM_WORLD.Bcast(th_init, root=0)
    set_from_flat(th_init)
    vfadam.sync()
    print("Init param sum", th_init.sum(), flush=True)
    # Prepare for rollouts
    # ----------------------------------------
    if traj_collect == 'mf':
        seg_gen = traj_segment_generator(env,
                                         timesteps_per_batch,
                                         model,
                                         stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards

    if sum([max_iters > 0, total_timesteps > 0, max_episodes > 0]) == 0:
        # noththing to be done
        return pi

    assert sum([max_iters>0, total_timesteps>0, max_episodes>0]) < 2, \
        'out of max_iters, total_timesteps, and max_episodes only one should be specified'

    while True:
        if callback: callback(locals(), globals())
        if total_timesteps and timesteps_so_far >= total_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        logger.log("********** Iteration %i ************" % iters_so_far)

        with timed("sampling"):
            seg = seg_gen.__next__()
            if traj_collect == 'mf-random' or traj_collect == 'mf-mb':
                seg_mbl = seg_gen_mbl.__next__()
            else:
                seg_mbl = seg
        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"]

        # Val data collection
        if collect_val_data:
            for ob_, ac_, ob_next_ in zip(ob[:-1, 0, ...], ac[:-1, ...],
                                          ob[1:, 0, ...]):
                val_dataset_collect.append(
                    (copy.copy(ob_), copy.copy(ac_), copy.copy(ob_next_)))
        # -----------------------------
        # MBL update
        else:
            ob_mbl, ac_mbl = seg_mbl["ob"], seg_mbl["ac"]

            mbl.add_data_batch(ob_mbl[:-1, 0, ...], ac_mbl[:-1, ...],
                               ob_mbl[1:, 0, ...])
            mbl.update_forward_dynamic(require_update=iters_so_far %
                                       mbl_train_freq == 0,
                                       ob_val=val_dataset['ob'],
                                       ac_val=val_dataset['ac'],
                                       ob_next_val=val_dataset['ob_next'])
        # -----------------------------

        if traj_collect == 'mf':
            #if traj_collect == 'mf' or traj_collect == 'mf-random' or traj_collect == 'mf-mb':
            vpredbefore = seg[
                "vpred"]  # predicted value function before udpate
            model = seg["model"]
            atarg = (atarg - atarg.mean()) / atarg.std(
            )  # standardized advantage function estimate

            if hasattr(pi, "ret_rms"): pi.ret_rms.update(tdlamret)
            if hasattr(pi, "rms"):
                pi.rms.update(ob)  # update running mean/std for policy

            args = seg["ob"], seg["ac"], atarg
            fvpargs = [arr[::5] for arr in args]

            def fisher_vector_product(p):
                return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p

            assign_old_eq_new(
            )  # set old parameter values to new parameter values
            with timed("computegrad"):
                *lossbefore, g = compute_lossandgrad(*args)
            lossbefore = allmean(np.array(lossbefore))
            g = allmean(g)
            if np.allclose(g, 0):
                logger.log("Got zero gradient. not updating")
            else:
                with timed("cg"):
                    stepdir = cg(fisher_vector_product,
                                 g,
                                 cg_iters=cg_iters,
                                 verbose=rank == 0)
                assert np.isfinite(stepdir).all()
                shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
                lm = np.sqrt(shs / max_kl)
                # logger.log("lagrange multiplier:", lm, "gnorm:", np.linalg.norm(g))
                fullstep = stepdir / lm
                expectedimprove = g.dot(fullstep)
                surrbefore = lossbefore[0]
                stepsize = 1.0
                thbefore = get_flat()
                for _ in range(10):
                    thnew = thbefore + fullstep * stepsize
                    set_from_flat(thnew)
                    meanlosses = surr, kl, *_ = allmean(
                        np.array(compute_losses(*args)))
                    improve = surr - surrbefore
                    logger.log("Expected: %.3f Actual: %.3f" %
                               (expectedimprove, improve))
                    if not np.isfinite(meanlosses).all():
                        logger.log("Got non-finite value of losses -- bad!")
                    elif kl > max_kl * 1.5:
                        logger.log("violated KL constraint. shrinking step.")
                    elif improve < 0:
                        logger.log("surrogate didn't improve. shrinking step.")
                    else:
                        logger.log("Stepsize OK!")
                        break
                    stepsize *= .5
                else:
                    logger.log("couldn't compute a good step")
                    set_from_flat(thbefore)
                if nworkers > 1 and iters_so_far % 20 == 0:
                    paramsums = MPI.COMM_WORLD.allgather(
                        (thnew.sum(),
                         vfadam.getflat().sum()))  # list of tuples
                    assert all(
                        np.allclose(ps, paramsums[0]) for ps in paramsums[1:])

            for (lossname, lossval) in zip(loss_names, meanlosses):
                logger.record_tabular(lossname, lossval)

            with timed("vf"):

                for _ in range(vf_iters):
                    for (mbob, mbret) in dataset.iterbatches(
                        (seg["ob"], seg["tdlamret"]),
                            include_final_partial_batch=False,
                            batch_size=64):
                        g = allmean(compute_vflossandgrad(mbob, mbret))
                        vfadam.update(g, vf_stepsize)
            with timed("SIL"):
                lrnow = lr(1.0 - timesteps_so_far / total_timesteps)
                l_loss, sil_adv, sil_samples, sil_nlogp = model.sil_train(
                    lrnow)

            logger.record_tabular("ev_tdlam_before",
                                  explained_variance(vpredbefore, tdlamret))

        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        if MPI is not None:
            listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        else:
            listoflrpairs = [lrlocal]
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        if sil_update > 0:
            logger.record_tabular("SilSamples", sil_samples)

        if rank == 0:
            # MBL evaluation
            if not collect_val_data:
                #set_global_seeds(seed)
                default_sess = tf.get_default_session()

                def multithread_eval_policy(env_, pi_, num_episodes_,
                                            vis_eval_, seed):
                    with default_sess.as_default():
                        if hasattr(env, 'ob_rms') and hasattr(env_, 'ob_rms'):
                            env_.ob_rms = env.ob_rms
                        res = eval_policy(env_, pi_, num_episodes_, vis_eval_,
                                          seed, measure_time, measure_rew)

                        try:
                            env_.close()
                        except:
                            pass
                    return res

                if mbl.is_warm_start_done() and iters_so_far % eval_freq == 0:
                    warm_start_done = mbl.is_warm_start_done()
                    if num_eval_episodes > 0:
                        targs_names = {}
                        with timed('eval'):
                            num_descs = len(all_eval_descs)
                            list_field_names = [e[0] for e in all_eval_descs]
                            list_legend_names = [e[1] for e in all_eval_descs]
                            list_pis = [e[2] for e in all_eval_descs]
                            list_eval_envs = [
                                make_eval_env() for _ in range(num_descs)
                            ]
                            list_seed = [seed for _ in range(num_descs)]
                            list_num_eval_episodes = [
                                num_eval_episodes for _ in range(num_descs)
                            ]
                            print(list_field_names)
                            print(list_legend_names)

                            list_vis_eval = [
                                vis_eval for _ in range(num_descs)
                            ]

                            for i in range(num_descs):
                                field_name, legend_name = list_field_names[
                                    i], list_legend_names[i],

                                res = multithread_eval_policy(
                                    list_eval_envs[i], list_pis[i],
                                    list_num_eval_episodes[i],
                                    list_vis_eval[i], seed)
                                #eval_results = pool.starmap(multithread_eval_policy, zip(list_eval_envs, list_pis, list_num_eval_episodes, list_vis_eval,list_seed))

                                #for field_name, legend_name, res in zip(list_field_names, list_legend_names, eval_results):
                                perf, elapsed_time, eval_rew = res
                                logger.record_tabular(field_name, perf)
                                if measure_time:
                                    logger.record_tabular(
                                        'Time-%s' % (field_name), elapsed_time)
                                if measure_rew:
                                    logger.record_tabular(
                                        'SimRew-%s' % (field_name), eval_rew)
                                targs_names[field_name] = legend_name

                    if eval_val_err:
                        fwd_dynamics_err = mbl.eval_forward_dynamic(
                            obs=eval_val_dataset['ob'],
                            acs=eval_val_dataset['ac'],
                            obs_next=eval_val_dataset['ob_next'])
                        logger.record_tabular('FwdValError', fwd_dynamics_err)

                    logger.dump_tabular()
                    #print(logger.get_dir())
                    #print(targs_names)
                    #if num_eval_episodes > 0:


#                        win = plot(viz, win, logger.get_dir(), targs_names=targs_names, quant=quant, opt='best')
# -----------
#logger.dump_tabular()
        yield pi

    if collect_val_data:
        with open(validation_set_path, 'wb') as f:
            pickle.dump(val_dataset_collect, f)
        logger.log('Save {} validation data'.format(len(val_dataset_collect)))
예제 #18
0
def learn(env,
          policy_func,
          reward_giver,
          expert_dataset,
          rank,
          pretrained,
          pretrained_weight,
          *,
          g_step,
          d_step,
          entcoeff,
          ckpt_dir,
          timesteps_per_batch,
          task_name,
          gamma,
          lam,
          max_kl,
          cg_iters,
          cg_damping=1e-2,
          vf_stepsize=3e-4,
          d_stepsize=3e-4,
          vf_iters=3,
          max_timesteps=0,
          max_episodes=0,
          max_iters=0,
          rnd_iter=200,
          callback=None,
          dyn_norm=False,
          mmd=False):
    """Train policy ``pi`` with TRPO on rewards produced by ``reward_giver``.

    The policy is updated with a natural-gradient step (conjugate gradient on
    the KL Hessian + backtracking line search) on the importance-weighted
    surrogate objective; its value function is fitted with MPI-synchronized
    Adam. When ``mmd`` is false, ``reward_giver`` is trained once up front on
    ``expert_dataset``; when ``mmd`` is true, segment rewards are recomputed
    from ``reward_giver`` after every rollout.

    Args:
        env: environment exposing ``observation_space`` / ``action_space``.
        policy_func: factory called as ``policy_func(name, ob_space, ac_space)``.
        reward_giver: reward model; ``train`` is used when ``mmd`` is false,
            ``set_b2`` / ``get_reward`` when ``mmd`` is true.
        expert_dataset: indexable expert data. It is unpacked into
            ``reward_giver.train`` and ``expert_dataset[0]`` seeds observation
            normalization (presumably the expert observations).
        rank: ignored; the rank is re-read from ``MPI.COMM_WORLD``.
        pretrained: unused.
        pretrained_weight: path of saved policy variables to load, or ``None``.
        g_step: policy updates (each on a fresh rollout) per outer iteration;
            must be >= 1 because per-iteration logging reads the last update.
        d_step, d_stepsize: unused in this function.
        entcoeff: entropy bonus coefficient.
        ckpt_dir, task_name: when ``ckpt_dir`` is not ``None`` the policy is
            saved to ``<ckpt_dir>/<task_name>_<k>`` (k cycles over 3 slots)
            whenever the rolling mean true return reaches a new best.
        timesteps_per_batch: rollout length per policy update.
        gamma, lam: discount and GAE lambda.
        max_kl, cg_iters, cg_damping: trust-region size and CG settings.
        vf_stepsize, vf_iters: value-function Adam step size and epochs.
        max_timesteps, max_episodes, max_iters: stopping criteria; exactly
            one must be positive (``iters_so_far`` counts policy updates).
        rnd_iter: iterations passed to ``reward_giver.train``.
        callback: optional ``callback(locals(), globals())`` per iteration.
        dyn_norm: if true, observation statistics are updated online from
            rollouts instead of once from the expert data.
        mmd: see above.
    """
    nworkers = MPI.COMM_WORLD.Get_size()
    # NOTE: the ``rank`` argument is deliberately overridden by the MPI rank.
    rank = MPI.COMM_WORLD.Get_rank()
    np.set_printoptions(precision=3)
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_func("pi", ob_space, ac_space)
    oldpi = policy_func("oldpi", ob_space, ac_space)
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    entbonus = entcoeff * meanent

    # Importance ratio pnew / pold; the surrogate is E[ratio * advantage].
    ratio = tf.exp(pi.pd.logp(ac) -
                   oldpi.pd.logp(ac))  # advantage * pnew / pold
    surrgain = tf.reduce_mean(ratio * atarg)

    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

    dist = meankl

    # Split trainable variables by name prefix: policy (TRPO step) versus
    # value function (Adam).
    all_var_list = pi.get_trainable_variables()
    var_list = [
        v for v in all_var_list
        if v.name.startswith("pi/pol") or v.name.startswith("pi/logstd")
    ]
    vf_var_list = [v for v in all_var_list if v.name.startswith("pi/vff")]
    vfadam = MpiAdam(vf_var_list)

    get_flat = U.GetFlat(var_list)
    set_from_flat = U.SetFromFlat(var_list)
    klgrads = tf.gradients(dist, var_list)
    flat_tangent = tf.placeholder(dtype=tf.float32,
                                  shape=[None],
                                  name="flat_tan")
    # Cut the flat tangent vector into pieces shaped like each variable.
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start + sz], shape))
        start += sz
    # d/dtheta (grad(KL) . tangent) = KL Hessian times tangent, i.e. the
    # Fisher-vector product used by conjugate gradient.
    gvp = tf.add_n([
        tf.reduce_sum(g * tangent)
        for (g, tangent) in zipsame(klgrads, tangents)
    ])  # pylint: disable=E1111
    fvp = U.flatgrad(gvp, var_list)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg], losses)
    compute_lossandgrad = U.function([ob, ac, atarg], losses +
                                     [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    compute_vflossandgrad = pi.vlossandgrad

    def allmean(x):
        """Return the element-wise mean of array ``x`` over all MPI workers."""
        assert isinstance(x, np.ndarray)
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
        return out

    def fisher_vector_product(p):
        """Damped, worker-averaged Fisher-vector product on ``fvpargs``.

        ``fvpargs`` is rebound inside the training loop before every call.
        """
        return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p

    U.initialize()
    # Start every worker from rank 0's initial policy parameters.
    th_init = get_flat()
    MPI.COMM_WORLD.Bcast(th_init, root=0)
    set_from_flat(th_init)
    vfadam.sync()
    if rank == 0:
        print("Init param sum", th_init.sum(), flush=True)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     reward_giver,
                                     timesteps_per_batch,
                                     stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards
    true_rewbuffer = deque(maxlen=40)  # rolling buffer for true returns

    # Exactly one stopping criterion must be active.
    assert sum([max_iters > 0, max_timesteps > 0, max_episodes > 0]) == 1

    ep_stats = stats(["True_rewards", "Rewards", "Episode_length"])
    # if provide pretrained weight
    if pretrained_weight is not None:
        U.load_variables(pretrained_weight, variables=pi.get_variables())
    elif not dyn_norm:
        # Seed observation normalization once from the expert data.
        pi.ob_rms.update(expert_dataset[0])

    if not mmd:
        # Fit the reward model once, before any policy update.
        reward_giver.train(*expert_dataset, iter=rnd_iter)

    best = -2000
    save_ind = 0
    max_save = 3
    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break

        logger.log("********** Iteration %i ************" % iters_so_far)

        # ------------------ Update G ------------------
        for _ in range(g_step):
            seg = next(seg_gen)

            # mmd reward: recompute rewards for this rollout
            if mmd:
                reward_giver.set_b2(seg["ob"], seg["ac"])
                seg["rew"] = reward_giver.get_reward(seg["ob"], seg["ac"])

            # report stats and save policy if any good
            lrlocal = (seg["ep_lens"], seg["ep_rets"], seg["ep_true_rets"]
                       )  # local values
            listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
            lens, rews, true_rets = map(flatten_lists, zip(*listoflrpairs))
            true_rewbuffer.extend(true_rets)
            lenbuffer.extend(lens)
            rewbuffer.extend(rews)

            true_rew_avg = np.mean(true_rewbuffer)
            logger.record_tabular("EpLenMean", np.mean(lenbuffer))
            logger.record_tabular("EpRewMean", np.mean(rewbuffer))
            logger.record_tabular("EpTrueRewMean", true_rew_avg)
            logger.record_tabular("EpThisIter", len(lens))
            episodes_so_far += len(lens)
            timesteps_so_far += sum(lens)
            iters_so_far += 1

            logger.record_tabular("EpisodesSoFar", episodes_so_far)
            logger.record_tabular("TimestepsSoFar", timesteps_so_far)
            logger.record_tabular("TimeElapsed", time.time() - tstart)
            logger.record_tabular("Best so far", best)

            # Save model into one of ``max_save`` rotating slots. This runs
            # before the update below, so it stores the parameters that were
            # used to collect this segment.
            if ckpt_dir is not None and true_rew_avg >= best:
                best = true_rew_avg
                fname = os.path.join(ckpt_dir, task_name)
                os.makedirs(os.path.dirname(fname), exist_ok=True)
                pi.save_policy(fname + "_" + str(save_ind))
                save_ind = (save_ind + 1) % max_save

            # compute gradient towards next policy
            add_vtarg_and_adv(seg, gamma, lam)
            ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
                "tdlamret"]
            vpredbefore = seg[
                "vpred"]  # predicted value function before update
            atarg = (atarg - atarg.mean()) / atarg.std(
            )  # standardized advantage function estimate

            if hasattr(pi, "ob_rms") and dyn_norm:
                pi.ob_rms.update(ob)  # update running mean/std for policy

            args = seg["ob"], seg["ac"], atarg
            # Fisher-vector products use every 5th sample to save compute.
            fvpargs = [arr[::5] for arr in args]

            assign_old_eq_new(
            )  # set old parameter values to new parameter values
            *lossbefore, g = compute_lossandgrad(*args)
            lossbefore = allmean(np.array(lossbefore))
            g = allmean(g)
            # If no step is taken the parameters are unchanged, so the
            # pre-update losses are the current ones. Without this default,
            # ``meanlosses`` is unbound (or stale) when logged below.
            meanlosses = lossbefore
            if np.allclose(g, 0):
                logger.log("Got zero gradient. not updating")
            else:
                stepdir = cg(fisher_vector_product,
                             g,
                             cg_iters=cg_iters,
                             verbose=False)
                assert np.isfinite(stepdir).all()
                # Rescale the CG direction so 0.5 * s^T F s equals max_kl.
                shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
                lm = np.sqrt(shs / max_kl)
                fullstep = stepdir / lm
                expectedimprove = g.dot(fullstep)
                surrbefore = lossbefore[0]
                stepsize = 1.0
                thbefore = get_flat()
                # Backtracking line search: halve the step up to 10 times.
                for _ in range(10):
                    thnew = thbefore + fullstep * stepsize
                    set_from_flat(thnew)
                    meanlosses = surr, kl, *_ = allmean(
                        np.array(compute_losses(*args)))
                    improve = surr - surrbefore
                    logger.log("Expected: %.3f Actual: %.3f" %
                               (expectedimprove, improve))
                    if not np.isfinite(meanlosses).all():
                        logger.log("Got non-finite value of losses -- bad!")
                    elif kl > max_kl * 1.5:
                        logger.log("violated KL constraint. shrinking step.")
                    elif improve < 0:
                        logger.log("surrogate didn't improve. shrinking step.")
                    else:
                        logger.log("Stepsize OK!")
                        break
                    stepsize *= .5
                else:
                    logger.log("couldn't compute a good step")
                    set_from_flat(thbefore)
                if nworkers > 1 and iters_so_far % 20 == 0:
                    # Sanity check: all workers must hold identical params.
                    paramsums = MPI.COMM_WORLD.allgather(
                        (thnew.sum(),
                         vfadam.getflat().sum()))  # list of tuples
                    assert all(
                        np.allclose(ps, paramsums[0]) for ps in paramsums[1:])
            if pi.use_popart:
                pi.update_popart(tdlamret)
            for _ in range(vf_iters):
                for (mbob, mbret) in dataset.iterbatches(
                    (seg["ob"], seg["tdlamret"]),
                        include_final_partial_batch=False,
                        batch_size=128):
                    if hasattr(pi, "ob_rms") and dyn_norm:
                        pi.ob_rms.update(
                            mbob)  # update running mean/std for policy
                    vfadam.update(allmean(compute_vflossandgrad(mbob, mbret)),
                                  vf_stepsize)

        # Only the last g-step's losses / explained variance are logged.
        for (lossname, lossval) in zip(loss_names, meanlosses):
            logger.record_tabular(lossname, lossval)
        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))
        if rank == 0:
            logger.dump_tabular()
예제 #19
0
class TRPO_agent_new(object):
    index2 = 0

    def __init__(self, a_name, env, policy_func, par):
        """Build the TRPO graph (losses, Fisher-vector product, updates) for one agent.

        Args:
            a_name: name passed to ``policy_func`` for the live policy; the
                frozen copy used in the KL / ratio terms is named
                ``"oldpi" + a_name``.
            env: environment exposing ``observation_space`` / ``action_space``.
            policy_func: factory called as ``policy_func(name, ob_space, ac_space)``.
            par: hyper-parameter container; this reads ``timesteps_per_batch``,
                ``max_kl``, ``cg_iters``, ``gamma``, ``lam``, ``entcoeff``,
                ``cg_damping``, ``vf_stepsize``, ``vf_iters``, ``max_timesteps``,
                ``max_episodes``, ``max_iters`` and ``callback``.

        Side effects: creates TF placeholders and ops, calls ``U.initialize()``,
        broadcasts rank 0's initial policy parameters over MPI and increments
        the class counter ``TRPO_agent_new.index2`` (used to give each agent's
        placeholders unique names).
        """

        self.env = env
        self.timesteps_per_batch = par.timesteps_per_batch
        self.max_kl = par.max_kl
        self.cg_iters = par.cg_iters
        self.gamma = par.gamma
        self.lam = par.lam  # advantage estimation
        self.entcoeff = par.entcoeff
        self.cg_damping = par.cg_damping
        self.vf_stepsize = par.vf_stepsize
        self.vf_iters = par.vf_iters
        self.max_timesteps = par.max_timesteps
        self.max_episodes = par.max_episodes
        self.max_iters = par.max_iters
        # Store the callable itself: a trailing comma here would silently wrap
        # it in a 1-tuple. The callback may receive locals(), globals().
        self.callback = par.callback

        self.nworkers = MPI.COMM_WORLD.Get_size()
        self.rank = MPI.COMM_WORLD.Get_rank()
        np.set_printoptions(precision=3)
        # Setup losses and stuff
        # ----------------------------------------
        self.ob_space = self.env.observation_space
        self.ac_space = self.env.action_space
        self.pi = policy_func(a_name, self.ob_space, self.ac_space)
        self.oldpi = policy_func("oldpi" + a_name, self.ob_space,
                                 self.ac_space)
        self.atarg = tf.placeholder(
            dtype=tf.float32,
            shape=[None])  # Target advantage function (if applicable)
        self.ret = tf.placeholder(dtype=tf.float32,
                                  shape=[None])  # Empirical return

        # NOTE(review): presumably policy_func registers an "ob<index2>"
        # placeholder for this agent — confirm against the policy class.
        self.ob = U.get_placeholder_cached(name="ob" +
                                           str(TRPO_agent_new.index2))
        self.ac = self.pi.pdtype.sample_placeholder([None])

        self.kloldnew = self.oldpi.pd.kl(self.pi.pd)
        self.ent = self.pi.pd.entropy()
        meankl = U.mean(self.kloldnew)
        meanent = U.mean(self.ent)
        entbonus = self.entcoeff * meanent

        self.vferr = U.mean(tf.square(self.pi.vpred - self.ret))

        # Importance ratio pnew / pold; surrogate is E[ratio * advantage].
        ratio = tf.exp(self.pi.pd.logp(self.ac) -
                       self.oldpi.pd.logp(self.ac))  # advantage * pnew / pold
        surrgain = U.mean(ratio * self.atarg)

        optimgain = surrgain + entbonus
        self.losses = [optimgain, meankl, entbonus, surrgain, meanent]
        self.loss_names = [
            "optimgain", "meankl", "entloss", "surrgain", "entropy"
        ]

        self.dist = meankl

        all_var_list = self.pi.get_trainable_variables()

        # Policy vs. value-function variables, chosen by the second component
        # of the variable name ("<scope>/pol..." vs "<scope>/vf...").
        var_list = [
            v for v in all_var_list if v.name.split("/")[1].startswith("pol")
        ]
        vf_var_list = [
            v for v in all_var_list if v.name.split("/")[1].startswith("vf")
        ]
        self.vfadam = MpiAdam(vf_var_list)

        self.get_flat = U.GetFlat(var_list)
        self.set_from_flat = U.SetFromFlat(var_list)
        self.klgrads = tf.gradients(self.dist, var_list)
        self.flat_tangent = tf.placeholder(dtype=tf.float32,
                                           shape=[None],
                                           name="flat_tan" +
                                           str(TRPO_agent_new.index2))

        # Cut the flat tangent vector into pieces shaped like each variable.
        shapes = [var.get_shape().as_list() for var in var_list]
        start = 0
        self.tangents = []
        for shape in shapes:
            sz = U.intprod(shape)
            self.tangents.append(
                tf.reshape(self.flat_tangent[start:start + sz], shape))
            start += sz

        # Gradient of (grad(KL) . tangent) is the KL-Hessian (Fisher) vector
        # product consumed by conjugate gradient.
        self.gvp = tf.add_n([
            U.sum(g * tangent)
            for (g, tangent) in zipsame(self.klgrads, self.tangents)
        ])  #pylint: disable=E1111
        self.fvp = U.flatgrad(self.gvp, var_list)

        self.assign_old_eq_new = U.function(
            [], [],
            updates=[
                tf.assign(oldv, newv) for (oldv, newv) in zipsame(
                    self.oldpi.get_variables(), self.pi.get_variables())
            ])

        self.compute_losses = U.function([self.ob, self.ac, self.atarg],
                                         self.losses)
        self.compute_lossandgrad = U.function(
            [self.ob, self.ac, self.atarg],
            self.losses + [U.flatgrad(optimgain, var_list)])
        self.compute_fvp = U.function(
            [self.flat_tangent, self.ob, self.ac, self.atarg], self.fvp)
        self.compute_vflossandgrad = U.function([self.ob, self.ret],
                                                U.flatgrad(
                                                    self.vferr, vf_var_list))

        TRPO_agent_new.index2 += 1
        U.initialize()
        # Start every worker from rank 0's initial policy parameters.
        self.th_init = self.get_flat()
        MPI.COMM_WORLD.Bcast(self.th_init, root=0)
        self.set_from_flat(self.th_init)
        self.vfadam.sync()
        print("Init param sum", self.th_init.sum(), flush=True)

    @contextmanager
    def timed(self, msg):
        if self.rank == 0:  ##################################
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(
                colorize("done in %.3f seconds" % (time.time() - tstart),
                         color='magenta'))
        else:
            yield

    def allmean(self, x):
        """Average a NumPy array element-wise across all MPI workers.

        The per-worker arrays are summed with ``Allreduce`` into a fresh
        buffer of the same shape/dtype, then divided in place by the number
        of workers. Non-ndarray input trips the assertion.
        """
        assert isinstance(x, np.ndarray)
        summed = np.empty_like(x)
        comm = MPI.COMM_WORLD
        comm.Allreduce(x, summed, op=MPI.SUM)
        summed /= self.nworkers
        return summed

    def fisher_vector_product(self, p):
        return self.allmean(self.compute_fvp(
            p, *self.fvpargs)) + self.cg_damping * p

    def learn(self):
        """Run the TRPO training loop until the configured budget is exhausted.

        Each iteration:

        * samples ``self.timesteps_per_batch`` steps with
          ``traj_segment_generator``;
        * fills ``seg["adv"]`` / ``seg["tdlamret"]`` via ``add_vtarg_and_adv``;
        * takes one policy step: conjugate gradient on the damped FVP,
          scaled to ``self.max_kl``, then a backtracking line search;
        * fits the value function with ``self.vfadam`` for ``self.vf_iters``
          epochs.

        Exactly one of ``self.max_iters``, ``self.max_timesteps`` or
        ``self.max_episodes`` must be positive; it decides when the loop stops.

        Side effects: overwrites ``self.ob``, ``self.ac``, ``self.atarg``,
        ``self.tdlamret``, ``self.vpredbefore`` and ``self.fvpargs`` on every
        iteration. Records tabular logs, which are dumped on rank 0 only.
        """

        # Prepare for rollouts
        # ----------------------------------------
        seg_gen = traj_segment_generator(self.pi,
                                         self.env,
                                         self.timesteps_per_batch,
                                         stochastic=True)

        episodes_so_far = 0
        timesteps_so_far = 0
        iters_so_far = 0
        tstart = time.time()
        lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
        rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards

        # Exactly one stopping criterion must be active.
        assert sum([
            self.max_iters > 0, self.max_timesteps > 0, self.max_episodes > 0
        ]) == 1

        while True:

            if self.max_timesteps and timesteps_so_far >= self.max_timesteps:
                break
            elif self.max_episodes and episodes_so_far >= self.max_episodes:
                break
            elif self.max_iters and iters_so_far >= self.max_iters:
                break
            logger.log("********** Iteration %i ************" % iters_so_far)

            with self.timed("sampling"):
                seg = next(seg_gen)

            # Adds seg["adv"] and seg["tdlamret"], both read below.
            add_vtarg_and_adv(seg, self.gamma, self.lam)

            self.ob, self.ac, self.atarg, self.tdlamret = seg["ob"], seg[
                "ac"], seg["adv"], seg["tdlamret"]
            # predicted value function before update
            self.vpredbefore = seg["vpred"]
            # standardized advantage function estimate
            self.atarg = (self.atarg - self.atarg.mean()) / self.atarg.std()

            if hasattr(self.pi, "ret_rms"):
                self.pi.ret_rms.update(self.tdlamret)
            if hasattr(self.pi, "ob_rms"):
                # update running mean/std for policy
                self.pi.ob_rms.update(self.ob)

            args = seg["ob"], seg["ac"], self.atarg
            # Fisher-vector products use every 5th sample to save compute.
            self.fvpargs = [arr[::5] for arr in args]

            # set old parameter values to new parameter values
            self.assign_old_eq_new()
            with self.timed("computegrad"):
                *lossbefore, g = self.compute_lossandgrad(*args)

            lossbefore = self.allmean(np.array(lossbefore))
            g = self.allmean(g)
            if np.allclose(g, 0):
                logger.log("Got zero gradient. not updating")
                # Parameters are unchanged, so the pre-update losses are the
                # current ones. Without this, ``meanlosses`` is undefined on
                # the first iteration (NameError) or stale afterwards.
                meanlosses = lossbefore
            else:

                with self.timed("cg"):
                    stepdir = cg(self.fisher_vector_product,
                                 g,
                                 cg_iters=self.cg_iters,
                                 verbose=self.rank == 0)

                assert np.isfinite(stepdir).all()
                # Scale the CG direction so that 0.5 * s^T F s == max_kl.
                shs = .5 * stepdir.dot(self.fisher_vector_product(stepdir))
                lm = np.sqrt(shs / self.max_kl)
                fullstep = stepdir / lm
                expectedimprove = g.dot(fullstep)
                surrbefore = lossbefore[0]
                stepsize = 1.0
                thbefore = self.get_flat()
                # Backtracking line search: halve the step up to 10 times.
                for _ in range(10):
                    thnew = thbefore + fullstep * stepsize
                    self.set_from_flat(thnew)
                    meanlosses = surr, kl, *_ = self.allmean(
                        np.array(self.compute_losses(*args)))
                    improve = surr - surrbefore
                    logger.log("Expected: %.3f Actual: %.3f" %
                               (expectedimprove, improve))
                    if not np.isfinite(meanlosses).all():
                        logger.log("Got non-finite value of losses -- bad!")
                    elif kl > self.max_kl * 1.5:
                        logger.log("violated KL constraint. shrinking step.")
                    elif improve < 0:
                        logger.log("surrogate didn't improve. shrinking step.")
                    else:
                        logger.log("Stepsize OK!")
                        break
                    stepsize *= .5
                else:
                    logger.log("couldn't compute a good step")
                    self.set_from_flat(thbefore)
                    # Parameters were rolled back; log their losses rather
                    # than those of the last rejected candidate step.
                    meanlosses = lossbefore
                if self.nworkers > 1 and iters_so_far % 20 == 0:
                    # Sanity check that all workers hold identical parameters.
                    paramsums = MPI.COMM_WORLD.allgather(
                        (thnew.sum(),
                         self.vfadam.getflat().sum()))  # list of tuples
                    assert all(
                        np.allclose(ps, paramsums[0]) for ps in paramsums[1:])

            for (lossname, lossval) in zip(self.loss_names, meanlosses):
                logger.record_tabular(lossname, lossval)

            with self.timed("vf"):

                for _ in range(self.vf_iters):
                    for (mbob, mbret) in dataset.iterbatches(
                        (seg["ob"], seg["tdlamret"]),
                            include_final_partial_batch=False,
                            batch_size=64):
                        g = self.allmean(
                            self.compute_vflossandgrad(mbob, mbret))
                        self.vfadam.update(g, self.vf_stepsize)

            logger.record_tabular(
                "ev_tdlam_before",
                explained_variance(self.vpredbefore, self.tdlamret))

            lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
            listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
            lens, rews = map(flatten_lists, zip(*listoflrpairs))
            lenbuffer.extend(lens)
            rewbuffer.extend(rews)

            logger.record_tabular("EpLenMean", np.mean(lenbuffer))
            logger.record_tabular("EpRewMean", np.mean(rewbuffer))
            logger.record_tabular("EpThisIter", len(lens))
            episodes_so_far += len(lens)
            timesteps_so_far += sum(lens)
            iters_so_far += 1

            logger.record_tabular("EpisodesSoFar", episodes_so_far)
            logger.record_tabular("TimestepsSoFar", timesteps_so_far)
            logger.record_tabular("TimeElapsed", time.time() - tstart)

            if self.rank == 0:
                logger.dump_tabular()

    def action_ev(self, obs):
        """Return the action from ``self.pi.act(False, obs)``.

        The first argument ``False`` is presumably the policy's ``stochastic``
        flag (confirm against the policy class). The value prediction returned
        alongside the action is discarded.
        """
        action, _value_prediction = self.pi.act(False, obs)
        return action

    def restore(self, folder):
        """Load previously saved state from ``folder + '/data'`` via ``U.load_state``."""
        checkpoint_path = folder + '/data'
        U.load_state(checkpoint_path)

    def save_data(self, folder):
        """Save the current session's variables to ``folder + '/data'``.

        ``folder`` (and any missing parents) is created if it does not exist.
        A fresh ``tf.train.Saver`` with no ``var_list`` is used, so it covers
        all saveable variables in the graph.

        Raises:
            OSError: if ``folder`` cannot be created, e.g. because it already
                exists as a regular file (the previous errno check silently
                accepted that case and failed later inside ``save``).
        """
        os.makedirs(folder, exist_ok=True)
        saver = U.tf.train.Saver()
        saver.save(U.get_session(), folder + '/data')
예제 #20
0
def learn(env,
          policy_func,
          rank,
          pretrained,
          pretrained_weight,
          *,
          g_step,
          d_step,
          entcoeff,
          save_per_iter,
          ckpt_dir,
          log_dir,
          timesteps_per_batch,
          task_name,
          gamma,
          lam,
          max_kl,
          cg_iters,
          cg_damping=1e-2,
          vf_stepsize=1e-4,
          vf_iters=3,
          max_timesteps=0,
          max_episodes=0,
          max_iters=0,
          callback=None):
    """Train a policy on dict observations with MPI-parallel TRPO.

    Each observation is a dict. Its "joint", "target" and "obstacle_pos1"
    entries are fed to the cached placeholders "ob", "goal" and "obs_pos".

    Each outer iteration performs ``g_step`` rounds, each of which:

    * samples a batch;
    * takes one TRPO policy step (conjugate gradient + backtracking line
      search under ``max_kl``);
    * fits the value function for ``vf_iters`` epochs with MpiAdam.

    On rank 0 the default session is saved to ``ckpt_dir/task_name`` at the
    start of any iteration whose rolling mean true return exceeds the best
    value seen so far. The threshold starts at -5 with the mean at 0, so the
    initial model is saved on the first iteration.

    Args:
        env: environment passed to ``traj_segment_generator``.
        policy_func: ``policy_func(name, ob_space, ac_space)`` builds a policy.
        rank: ignored; overwritten with the MPI rank of this process.
        pretrained, d_step, save_per_iter, log_dir: accepted but unused here.
        pretrained_weight: optional checkpoint path restored into the default
            session before training starts.
        g_step: number of policy/value update rounds per outer iteration.
        entcoeff: coefficient of the mean-entropy bonus in the objective.
        ckpt_dir, task_name: checkpoint location (rank 0 only; ``ckpt_dir``
            of None disables saving).
        timesteps_per_batch: rollout length per sampled segment.
        gamma, lam: passed to ``add_vtarg_and_adv``.
        max_kl: KL budget used to scale the step; steps whose measured KL
            exceeds ``1.5 * max_kl`` are shrunk.
        cg_iters, cg_damping: conjugate-gradient iterations and damping.
        vf_stepsize, vf_iters: MpiAdam step size and epochs for the value fn.
        max_timesteps, max_episodes, max_iters: exactly one must be positive;
            it decides when training stops.
        callback: called as ``callback(locals(), globals())`` at the start of
            every outer iteration.
    """

    nworkers = MPI.COMM_WORLD.Get_size()
    rank = MPI.COMM_WORLD.Get_rank()  # overrides the ``rank`` argument
    np.set_printoptions(precision=3)
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_func("pi", ob_space, ac_space)
    oldpi = policy_func("oldpi", ob_space, ac_space)
    # Target advantage function (if applicable)
    atarg = tf.placeholder(dtype=tf.float32, shape=[None])
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    # Observation components, via placeholders cached by the policy.
    ob_config = U.get_placeholder_cached(name="ob")
    ob_target = U.get_placeholder_cached(name="goal")
    obs_pos = U.get_placeholder_cached(name="obs_pos")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    entbonus = entcoeff * meanent

    vferr = tf.reduce_mean(tf.square(pi.vpred - ret))

    # advantage * pnew / pold
    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))
    surrgain = tf.reduce_mean(ratio * atarg)

    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

    dist = meankl

    # The policy step trains "pi/pol*", "pi/logstd*" and the shared "pi/obs*"
    # variables; the value fit trains "pi/vf*" and the same shared variables.
    all_var_list = pi.get_trainable_variables()
    var_list = [
        v for v in all_var_list if v.name.startswith("pi/pol")
        or v.name.startswith("pi/logstd") or v.name.startswith("pi/obs")
    ]
    vf_var_list = [
        v for v in all_var_list
        if v.name.startswith("pi/vf") or v.name.startswith("pi/obs")
    ]

    vfadam = MpiAdam(vf_var_list)

    get_flat = U.GetFlat(var_list)
    set_from_flat = U.SetFromFlat(var_list)
    klgrads = tf.gradients(dist, var_list)
    flat_tangent = tf.placeholder(dtype=tf.float32,
                                  shape=[None],
                                  name="flat_tan")
    # Split the flat tangent vector into per-variable shaped pieces.
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start + sz], shape))
        start += sz
    gvp = tf.add_n([
        tf.reduce_sum(g * tangent)
        for (g, tangent) in zipsame(klgrads, tangents)
    ])  # pylint: disable=E1111
    fvp = U.flatgrad(gvp, var_list)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob_config, ob_target, obs_pos, ac, atarg],
                                losses)
    compute_lossandgrad = U.function(
        [ob_config, ob_target, obs_pos, ac, atarg],
        losses + [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function(
        [flat_tangent, ob_config, ob_target, obs_pos, ac, atarg], fvp)
    compute_vflossandgrad = U.function([ob_config, ob_target, obs_pos, ret],
                                       U.flatgrad(vferr, vf_var_list))

    @contextmanager
    def timed(msg):
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(
                colorize("done in %.3f seconds" % (time.time() - tstart),
                         color='magenta'))
        else:
            yield

    def allmean(x):
        # Element-wise mean over all MPI workers (collective call).
        assert isinstance(x, np.ndarray)
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
        return out

    def fisher_vector_product(p):
        # Damped, worker-averaged FVP on the subsampled batch ``fvpargs``,
        # which is assigned inside the training loop before first use.
        v1 = allmean(compute_fvp(p, *fvpargs))
        return v1 + cg_damping * p

    U.initialize()
    # Start every worker from rank 0's parameters.
    th_init = get_flat()
    MPI.COMM_WORLD.Bcast(th_init, root=0)
    set_from_flat(th_init)
    vfadam.sync()
    if rank == 0:
        print("Init param sum", th_init.sum(), flush=True)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_batch,
                                     stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards
    true_rewbuffer = deque(maxlen=40)
    max_trm = -5
    true_reward_mean = 0
    assert sum([max_iters > 0, max_timesteps > 0, max_episodes > 0]) == 1

    g_loss_stats = stats(loss_names)
    ep_stats = stats(["True_rewards", "Rewards", "Episode_length"])
    # A single Saver is reused for restore and every checkpoint; building a
    # new one per save would add fresh save/restore ops to the graph each time.
    saver = None
    # if provide pretrained weight
    if pretrained_weight is not None:
        saver = tf.train.Saver()
        saver.restore(tf.get_default_session(), pretrained_weight)

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break

        # Save model whenever the rolling true return reaches a new best.
        if rank == 0 and ckpt_dir is not None and true_reward_mean > max_trm:
            fname = os.path.join(ckpt_dir, task_name)
            os.makedirs(os.path.dirname(fname), exist_ok=True)
            if saver is None:
                saver = tf.train.Saver()
            saver.save(tf.get_default_session(), fname)
            max_trm = true_reward_mean

        logger.log("********** Iteration %i ************" % iters_so_far)

        # ------------------ Update G ------------------
        logger.log("Optimizing Policy...")
        for _ in range(g_step):
            with timed("sampling"):
                seg = next(seg_gen)
            add_vtarg_and_adv(seg, gamma, lam)
            ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
                "tdlamret"]
            vpredbefore = seg["vpred"]  # predicted value function before update
            # standardized advantage function estimate
            atarg = (atarg - atarg.mean()) / atarg.std()

            if hasattr(pi, "ob_rms"):
                pi.ob_rms.update(ob)  # update running mean/std for policy
            # Split each dict observation into its three placeholder inputs.
            config, goal, obstacle_pos = [], [], []
            for o in seg["ob"]:
                config.append(o["joint"])
                goal.append(o["target"])
                obstacle_pos.append(o["obstacle_pos1"])
            config, goal, obstacle_pos = map(np.array,
                                             [config, goal, obstacle_pos])
            args = config, goal, obstacle_pos, seg["ac"], atarg
            # Fisher-vector products use every 5th sample to save compute.
            fvpargs = [arr[::5] for arr in args]

            assign_old_eq_new()  # set old parameter values to new ones
            with timed("computegrad"):
                *lossbefore, g = compute_lossandgrad(*args)
            lossbefore = allmean(np.array(lossbefore))
            g = allmean(g)
            if np.allclose(g, 0):
                logger.log("Got zero gradient. not updating")
                # Parameters are unchanged, so the pre-update losses are the
                # current ones; otherwise ``meanlosses`` may be undefined.
                meanlosses = lossbefore
            else:
                with timed("cg"):
                    stepdir = cg(fisher_vector_product,
                                 g,
                                 cg_iters=cg_iters,
                                 verbose=rank == 0)
                    # Residual of the system CG actually solved,
                    # (allmean(F) + cg_damping * I) x = g. This evaluates a
                    # collective Allreduce, and every rank executes it.
                    logger.log(
                        'iter:{:d}, norm of g: {:.4f}, error of cg: {:.4f}'.
                        format(
                            cg_iters, np.linalg.norm(g),
                            np.linalg.norm(g -
                                           fisher_vector_product(stepdir))))
                assert np.isfinite(stepdir).all()

                # Scale the CG direction so that 0.5 * s^T F s == max_kl.
                shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
                lm = np.sqrt(shs / max_kl)
                fullstep = stepdir / lm
                expectedimprove = g.dot(fullstep)
                surrbefore = lossbefore[0]
                stepsize = 1.0
                thbefore = get_flat()
                # Backtracking line search: halve the step up to 10 times.
                for _ in range(10):
                    thnew = thbefore + fullstep * stepsize
                    set_from_flat(thnew)
                    meanlosses = surr, kl, *_ = allmean(
                        np.array(compute_losses(*args)))
                    improve = surr - surrbefore
                    logger.log("Expected: %.3f Actual: %.3f" %
                               (expectedimprove, improve))
                    if not np.isfinite(meanlosses).all():
                        logger.log("Got non-finite value of losses -- bad!")
                    elif kl > max_kl * 1.5:
                        logger.log("violated KL constraint. shrinking step.")
                    elif improve < 0:
                        logger.log("surrogate didn't improve. shrinking step.")
                    else:
                        logger.log("Stepsize OK!")
                        break
                    stepsize *= .5
                else:
                    logger.log("couldn't compute a good step")
                    set_from_flat(thbefore)
                    # Parameters were rolled back; log their losses instead
                    # of those of the last rejected candidate step.
                    meanlosses = lossbefore
                if nworkers > 1 and iters_so_far % 20 == 0:
                    # Sanity check that all workers hold identical parameters.
                    paramsums = MPI.COMM_WORLD.allgather(
                        (thnew.sum(),
                         vfadam.getflat().sum()))  # list of tuples
                    assert all(
                        np.allclose(ps, paramsums[0]) for ps in paramsums[1:])
            with timed("vf"):
                for _ in range(vf_iters):
                    for (mbob, mbg, mbop, mbret) in dataset.iterbatches(
                        (config, goal, obstacle_pos, seg["tdlamret"]),
                            include_final_partial_batch=False,
                            batch_size=128):
                        if hasattr(pi, "ob_rms"):
                            pi.ob_rms.update(
                                mbob)  # update running mean/std for policy
                        g = allmean(
                            compute_vflossandgrad(mbob, mbg, mbop, mbret))
                        vfadam.update(g, vf_stepsize)

        g_losses = meanlosses
        for (lossname, lossval) in zip(loss_names, meanlosses):
            logger.record_tabular(lossname, lossval)
        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))

        lrlocal = (seg["ep_lens"], seg["ep_rets"], seg["ep_true_rets"]
                   )  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews, true_rets = map(flatten_lists, zip(*listoflrpairs))
        true_rewbuffer.extend(true_rets)
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpTrueRewMean", np.mean(true_rewbuffer))
        true_reward_mean = np.mean(true_rewbuffer)
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)

        if rank == 0:
            logger.dump_tabular()
예제 #21
0
def learn(
        *,
        network,
        env,
        total_timesteps,
        timesteps_per_batch=1024,  # what to train on
        max_kl=0.001,
        cg_iters=10,
        gamma=0.99,
        lam=1.0,  # advantage estimation
        seed=None,
        ent_coef=0.0,
        cg_damping=1e-2,
        vf_stepsize=3e-4,
        vf_iters=3,
        max_episodes=0,
        max_iters=0,  # time constraint
        callback=None,
        load_path=None,
        **network_kwargs):
    '''
    learn a policy function with TRPO algorithm

    Parameters:
    ----------

    network                 neural network to learn. Can be either string ('mlp', 'cnn', 'lstm', 'lnlstm' for basic types)
                            or function that takes input placeholder and returns tuple (output, None) for feedforward nets
                            or (output, (state_placeholder, state_output, mask_placeholder)) for recurrent nets

    env                     environment (one of the gym environments or wrapped via baselines.common.vec_env.VecEnv-type class)

    timesteps_per_batch     timesteps per gradient estimation batch

    max_kl                  max KL divergence between old policy and new policy ( KL(pi_old || pi) )

    gamma                   discount factor passed to add_vtarg_and_adv

    lam                     advantage-estimation lambda passed to add_vtarg_and_adv

    seed                    random seed passed to set_global_seeds (None: leave unseeded)

    ent_coef                coefficient of policy entropy term in the optimization objective

    cg_iters                number of iterations of conjugate gradient algorithm

    cg_damping              conjugate gradient damping

    vf_stepsize             learning rate for adam optimizer used to optimize value function loss

    vf_iters                number of iterations of value function optimization iterations per each policy optimization step

    total_timesteps         max number of timesteps

    max_episodes            max number of episodes

    max_iters               maximum number of policy optimization iterations

    callback                function to be called with (locals(), globals()) each policy optimization step

    load_path               str, path to load the model from (default: None, i.e. no model is loaded)

    **network_kwargs        keyword arguments to the policy / network builder. See baselines.common/policies.py/build_policy and arguments to a particular type of network

    At most one of total_timesteps, max_episodes and max_iters may be positive;
    if none is, the freshly built policy is returned without training.

    Returns:
    -------

    learnt model (the policy object ``pi``)

    '''

    if MPI is not None:
        nworkers = MPI.COMM_WORLD.Get_size()
        rank = MPI.COMM_WORLD.Get_rank()
    else:
        nworkers = 1
        rank = 0

    cpus_per_worker = 1
    U.get_session(
        config=tf.ConfigProto(allow_soft_placement=True,
                              inter_op_parallelism_threads=cpus_per_worker,
                              intra_op_parallelism_threads=cpus_per_worker))

    policy = build_policy(env, network, value_network='copy', **network_kwargs)
    set_global_seeds(seed)

    np.set_printoptions(precision=3)
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space

    ob = observation_placeholder(ob_space)
    with tf.variable_scope("pi"):
        pi = policy(observ_placeholder=ob)
    with tf.variable_scope("oldpi"):
        oldpi = policy(observ_placeholder=ob)

    # Target advantage function (if applicable)
    atarg = tf.placeholder(dtype=tf.float32, shape=[None])
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    entbonus = ent_coef * meanent

    vferr = tf.reduce_mean(tf.square(pi.vf - ret))

    # advantage * pnew / pold
    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))
    surrgain = tf.reduce_mean(ratio * atarg)

    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

    dist = meankl

    all_var_list = get_trainable_variables("pi")
    var_list = get_pi_trainable_variables("pi")
    vf_var_list = get_vf_trainable_variables("pi")

    vfadam = MpiAdam(vf_var_list)

    get_flat = U.GetFlat(var_list)
    set_from_flat = U.SetFromFlat(var_list)
    klgrads = tf.gradients(dist, var_list)
    flat_tangent = tf.placeholder(dtype=tf.float32,
                                  shape=[None],
                                  name="flat_tan")
    # Split the flat tangent vector into per-variable shaped pieces.
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start + sz], shape))
        start += sz
    gvp = tf.add_n([
        tf.reduce_sum(g * tangent)
        for (g, tangent) in zipsame(klgrads, tangents)
    ])  #pylint: disable=E1111
    fvp = U.flatgrad(gvp, var_list)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(get_variables("oldpi"), get_variables("pi"))
        ])

    compute_losses = U.function([ob, ac, atarg], losses)
    compute_lossandgrad = U.function([ob, ac, atarg], losses +
                                     [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    compute_vflossandgrad = U.function([ob, ret],
                                       U.flatgrad(vferr, vf_var_list))

    @contextmanager
    def timed(msg):
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(
                colorize("done in %.3f seconds" % (time.time() - tstart),
                         color='magenta'))
        else:
            yield

    def allmean(x):
        # Element-wise mean over all MPI workers; a plain copy without MPI.
        assert isinstance(x, np.ndarray)
        if MPI is not None:
            out = np.empty_like(x)
            MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
            out /= nworkers
        else:
            out = np.copy(x)

        return out

    U.initialize()
    if load_path is not None:
        pi.load(load_path)

    # Start every worker from rank 0's parameters.
    th_init = get_flat()
    if MPI is not None:
        MPI.COMM_WORLD.Bcast(th_init, root=0)

    set_from_flat(th_init)
    vfadam.sync()
    if rank == 0:
        print("Init param sum", th_init.sum(), flush=True)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_batch,
                                     stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards

    if sum([max_iters > 0, total_timesteps > 0, max_episodes > 0]) == 0:
        # nothing to be done
        return pi

    assert sum([max_iters>0, total_timesteps>0, max_episodes>0]) < 2, \
        'out of max_iters, total_timesteps, and max_episodes only one should be specified'

    while True:
        if callback: callback(locals(), globals())
        if total_timesteps and timesteps_so_far >= total_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        logger.log("********** Iteration %i ************" % iters_so_far)

        with timed("sampling"):
            seg = next(seg_gen)
        add_vtarg_and_adv(seg, gamma, lam)

        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update
        # standardized advantage function estimate
        atarg = (atarg - atarg.mean()) / atarg.std()

        if hasattr(pi, "ret_rms"): pi.ret_rms.update(tdlamret)
        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        args = seg["ob"], seg["ac"], atarg
        # Fisher-vector products use every 5th sample to save compute.
        fvpargs = [arr[::5] for arr in args]

        def fisher_vector_product(p):
            return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p

        assign_old_eq_new()  # set old parameter values to new parameter values
        with timed("computegrad"):
            *lossbefore, g = compute_lossandgrad(*args)
        lossbefore = allmean(np.array(lossbefore))
        g = allmean(g)
        if np.allclose(g, 0):
            logger.log("Got zero gradient. not updating")
            # Parameters are unchanged, so the pre-update losses are the
            # current ones. Without this, ``meanlosses`` is undefined on the
            # first iteration (NameError) or stale afterwards.
            meanlosses = lossbefore
        else:
            with timed("cg"):
                stepdir = cg(fisher_vector_product,
                             g,
                             cg_iters=cg_iters,
                             verbose=rank == 0)
            assert np.isfinite(stepdir).all()
            # Scale the CG direction so that 0.5 * s^T F s == max_kl.
            shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
            lm = np.sqrt(shs / max_kl)
            fullstep = stepdir / lm
            expectedimprove = g.dot(fullstep)
            surrbefore = lossbefore[0]
            stepsize = 1.0
            thbefore = get_flat()
            # Backtracking line search: halve the step up to 10 times.
            for _ in range(10):
                thnew = thbefore + fullstep * stepsize
                set_from_flat(thnew)
                meanlosses = surr, kl, *_ = allmean(
                    np.array(compute_losses(*args)))
                improve = surr - surrbefore
                logger.log("Expected: %.3f Actual: %.3f" %
                           (expectedimprove, improve))
                if not np.isfinite(meanlosses).all():
                    logger.log("Got non-finite value of losses -- bad!")
                elif kl > max_kl * 1.5:
                    logger.log("violated KL constraint. shrinking step.")
                elif improve < 0:
                    logger.log("surrogate didn't improve. shrinking step.")
                else:
                    logger.log("Stepsize OK!")
                    break
                stepsize *= .5
            else:
                logger.log("couldn't compute a good step")
                set_from_flat(thbefore)
                # Parameters were rolled back; log their losses rather than
                # those of the last rejected candidate step.
                meanlosses = lossbefore
            if nworkers > 1 and iters_so_far % 20 == 0:
                # Sanity check that all workers hold identical parameters.
                paramsums = MPI.COMM_WORLD.allgather(
                    (thnew.sum(), vfadam.getflat().sum()))  # list of tuples
                assert all(
                    np.allclose(ps, paramsums[0]) for ps in paramsums[1:])

        for (lossname, lossval) in zip(loss_names, meanlosses):
            logger.record_tabular(lossname, lossval)

        with timed("vf"):

            for _ in range(vf_iters):
                for (mbob, mbret) in dataset.iterbatches(
                    (seg["ob"], seg["tdlamret"]),
                        include_final_partial_batch=False,
                        batch_size=64):
                    g = allmean(compute_vflossandgrad(mbob, mbret))
                    vfadam.update(g, vf_stepsize)

        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))

        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        if MPI is not None:
            listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        else:
            listoflrpairs = [lrlocal]

        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)

        if rank == 0:
            logger.dump_tabular()

    return pi
예제 #22
0
class RLTrainer(object):
    def __init__(self, env, policy, old_policy, config):
        """Set up a TRPO or PPO trainer around ``policy``.

        Args:
            env: environment; on rank 0, ``env.unwrapped.reward_type`` is
                appended to the summary names. It is stored for the update
                methods, which call ``get_ob_dict`` on it.
            policy: the policy being optimized.
            old_policy: policy snapshot used for the probability ratio / KL.
            config: run configuration. ``rl_method`` selects ``'trpo'`` or
                ``'ppo'``; ``entcoeff`` and the ``optim_*`` fields are cached.

        Raises:
            NotImplementedError: if ``config.rl_method`` is neither
                ``'trpo'`` nor ``'ppo'``.
        """
        self._env = env
        self._config = config
        self.policy = policy
        self.old_policy = old_policy

        # Cache the optimization hyper-parameters read by the update methods.
        self._entcoeff = config.entcoeff
        self._optim_epochs = config.optim_epochs
        self._optim_stepsize = config.optim_stepsize
        self._optim_batchsize = config.optim_batchsize

        # Non-trainable iteration counter. It is created before the loss
        # builders, which end with U.initialize() (presumably initializing
        # every variable created so far, including this one).
        self.global_step = tf.Variable(0,
                                       name='global_step',
                                       dtype=tf.int64,
                                       trainable=False)
        self.update_global_step = tf.assign(self.global_step,
                                            self.global_step + 1)

        # Rank 0 ("chef") is the only worker that logs and writes summaries.
        world = MPI.COMM_WORLD
        self._is_chef = (world.Get_rank() == 0)
        self._num_workers = world.Get_size()
        if self._is_chef:
            names = ["reward", "length"]
            names.extend(env.unwrapped.reward_type)
            self.summary_name = names

        # Build the loss graph / optimizers for the requested algorithm.
        method = self._config.rl_method
        if method not in ('trpo', 'ppo'):
            raise NotImplementedError
        build = self._build_trpo if method == 'trpo' else self._build_ppo
        build()

        if self._is_chef and self._config.is_train:
            self.ep_stats = stats(self.summary_name)
            self.writer = U.file_writer(config.log_dir)

    @contextmanager
    def timed(self, msg):
        if self._is_chef:
            logger.info(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            logger.info(
                colorize("done in %.3f seconds" % (time.time() - tstart),
                         color='magenta'))
        else:
            yield

    def _all_mean(self, x):
        """Average the numpy array ``x`` element-wise over all MPI workers.

        The result is a new array with the same shape and dtype as ``x``: the
        MPI sum across workers divided by the worker count. Every worker must
        call this collectively with a matching buffer.
        NOTE(review): the in-place true division fails for integer arrays;
        callers in this class pass float arrays.
        """
        assert isinstance(x, np.ndarray)
        summed = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, summed, op=MPI.SUM)
        np.true_divide(summed, self._num_workers, out=summed)
        return summed

    def _build_ppo(self):
        """Build the PPO training graph and its optimizer.

        Creates the input placeholders and an MpiAdam over all trainable
        policy variables. Compiles ``self._loss``, which returns the tensors
        from ``policy_loss_ppo`` plus ``'g'``, the flat gradient of
        ``total_loss``. Compiles ``self._update_oldpi``. Finally initializes
        variables and syncs the Adam state across MPI workers.
        """
        config = self._config
        pi = self.policy
        oldpi = self.old_policy

        # input placeholders
        # ``pi.obs`` is a list: it is concatenated with the other input
        # lists when ``self._loss`` is compiled below.
        obs = pi.obs
        ac = pi.pdtype.sample_placeholder([None], name='action')
        atarg = tf.placeholder(dtype=tf.float32,
                               shape=[None],
                               name='advantage')
        ret = tf.placeholder(dtype=tf.float32, shape=[None], name='return')

        # Scalar multiplier fed on every ``self._loss`` call; here it scales
        # the PPO clip range (read by policy_loss_ppo via self._clip_param).
        lrmult = tf.placeholder(name='lrmult', dtype=tf.float32, shape=[])
        self._clip_param = config.clip_param * lrmult

        # policy
        var_list = pi.get_trainable_variables()
        self._adam = MpiAdam(var_list)

        fetch_dict = self.policy_loss_ppo(pi, oldpi, ac, atarg, ret)
        # Summary names are registered before 'g' is added, so the raw
        # gradient is not a summary; only its norms are.
        if self._is_chef:
            self.summary_name += ['ppo/' + key for key in fetch_dict.keys()]
            self.summary_name += ['ppo/grad_norm', 'ppo/grad_norm_clipped']
        fetch_dict['g'] = U.flatgrad(fetch_dict['total_loss'], var_list)
        self._loss = U.function([lrmult] + obs + [ac, atarg, ret], fetch_dict)
        # Copies each variable of ``pi`` into the matching one of ``oldpi``.
        self._update_oldpi = U.function(
            [], [],
            updates=[
                tf.assign(oldv, newv) for (
                    oldv,
                    newv) in zipsame(oldpi.get_variables(), pi.get_variables())
            ])

        # initialize and sync
        U.initialize()
        self._adam.sync()

    def _build_trpo(self):
        """Build the TRPO graph and compile its callables.

        The graph covers the surrogate objective (with entropy bonus), the
        mean KL to the old policy, a Fisher-vector product, and the value
        loss. Stores on ``self``:

        - ``_compute_losses`` / ``_compute_lossandgrad``: policy stats
          (+ flat gradient ``'g'``).
        - ``_compute_fvp``: Hessian-vector product of the mean KL.
        - ``_compute_vflossandgrad`` / ``_compute_vfloss``: value-function
          loss (+ gradient).
        - ``_update_oldpi``: copies ``pi`` variables into ``oldpi``.
        - ``_get_flat`` / ``_set_from_flat``: flat access to policy params.
        - ``_vf_adam``: MpiAdam over the value-function variables.

        Finally initializes variables and broadcasts rank 0's policy
        parameters to all workers.
        """
        pi = self.policy
        oldpi = self.old_policy

        # input placeholders
        # ``pi.obs`` is a list (it is concatenated with other input lists
        # when compiling the functions below).
        obs = pi.obs
        ac = pi.pdtype.sample_placeholder([None], name='action')
        atarg = tf.placeholder(
            dtype=tf.float32, shape=[None],
            name='advantage')  # Target advantage function (if applicable)
        ret = tf.placeholder(dtype=tf.float32, shape=[None],
                             name='return')  # Empirical return

        # policy
        # Trainable variables are split by the third '/'-separated component
        # of their name. NOTE(review): assumes names like 'x/y/pol...' and
        # 'x/y/vf...' — confirm against the policy's variable scopes.
        all_var_list = pi.get_trainable_variables()
        self.pol_var_list = [
            v for v in all_var_list if v.name.split("/")[2].startswith("pol")
        ]
        self.vf_var_list = [
            v for v in all_var_list if v.name.split("/")[2].startswith("vf")
        ]
        self._vf_adam = MpiAdam(self.vf_var_list)

        kl_oldnew = oldpi.pd.kl(pi.pd)
        ent = pi.pd.entropy()
        mean_kl = tf.reduce_mean(kl_oldnew)
        mean_ent = tf.reduce_mean(ent)
        pol_entpen = -self._config.entcoeff * mean_ent

        vf_loss = tf.reduce_mean(tf.square(pi.vpred - ret))

        ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))
        pol_surr = tf.reduce_mean(ratio * atarg)
        # pol_loss is *maximized*: _update_policy_trpo steps along +stepdir
        # (the CG solution for its gradient) and its line search only accepts
        # steps that increase pol_loss. The entropy term must therefore enter
        # as a bonus (+entcoeff * entropy). Adding pol_entpen
        # (= -entcoeff * entropy) would penalize exploration instead.
        pol_loss = pol_surr - pol_entpen

        pol_losses = {
            'pol_loss': pol_loss,
            'pol_surr': pol_surr,
            'pol_entpen': pol_entpen,
            'kl': mean_kl,
            'entropy': mean_ent
        }
        if self._is_chef:
            self.summary_name += ['trpo/vf_loss']
            self.summary_name += ['trpo/' + key for key in pol_losses.keys()]

        self._get_flat = U.GetFlat(self.pol_var_list)
        self._set_from_flat = U.SetFromFlat(self.pol_var_list)
        # Fisher-vector product via double backprop: differentiate
        # <grad(mean_kl), tangent> w.r.t. the policy variables, where the
        # flat tangent is reshaped to match each variable.
        klgrads = tf.gradients(mean_kl, self.pol_var_list)
        flat_tangent = tf.placeholder(dtype=tf.float32,
                                      shape=[None],
                                      name="flat_tan")
        shapes = [var.get_shape().as_list() for var in self.pol_var_list]
        start = 0
        tangents = []
        for shape in shapes:
            sz = U.intprod(shape)
            tangents.append(tf.reshape(flat_tangent[start:start + sz], shape))
            start += sz
        gvp = tf.add_n([
            tf.reduce_sum(g * tangent)
            for (g, tangent) in zipsame(klgrads, tangents)
        ])  # pylint: disable=E1111
        fvp = U.flatgrad(gvp, self.pol_var_list)

        self._update_oldpi = U.function(
            [], [],
            updates=[
                tf.assign(oldv, newv) for (
                    oldv,
                    newv) in zipsame(oldpi.get_variables(), pi.get_variables())
            ])
        self._compute_losses = U.function(obs + [ac, atarg], pol_losses)
        # Copy before adding 'g' so _compute_losses keeps the gradient-free dict.
        pol_losses = dict(pol_losses)
        pol_losses.update({'g': U.flatgrad(pol_loss, self.pol_var_list)})
        self._compute_lossandgrad = U.function(obs + [ac, atarg], pol_losses)
        self._compute_fvp = U.function([flat_tangent] + obs + [ac, atarg], fvp)
        self._compute_vflossandgrad = U.function(
            obs + [ret], U.flatgrad(vf_loss, self.vf_var_list))
        self._compute_vfloss = U.function(obs + [ret], vf_loss)

        # initialize and sync
        U.initialize()
        th_init = self._get_flat()
        MPI.COMM_WORLD.Bcast(th_init, root=0)
        self._set_from_flat(th_init)
        self._vf_adam.sync()
        logger.info("Init param sum", th_init.sum())

    def policy_loss_ppo(self, pi, oldpi, ac, atarg, ret):
        """Build the PPO clipped-surrogate loss tensors.

        Args:
            pi: current policy (uses ``pd`` and ``vpred``).
            oldpi: policy snapshot for the probability ratio and KL.
            ac: action placeholder.
            atarg: advantage placeholder.
            ret: value-target placeholder.

        Returns:
            Dict of tensors, in this key order: ``total_loss``,
            ``action_loss`` (per-sample ratio * advantage), ``pol_surr``,
            ``pol_entpen``, ``kl``, ``entropy``, ``vf_loss``. The clip range
            is ``self._clip_param``, set in ``_build_ppo``.
        """
        # Probability ratio pi(a|s) / pi_old(a|s), computed in log space.
        log_ratio = pi.pd.logp(ac) - oldpi.pd.logp(ac)
        ratio = tf.exp(log_ratio)
        unclipped = ratio * atarg
        clipped = U.clip(ratio, 1.0 - self._clip_param,
                         1.0 + self._clip_param) * atarg

        mean_kl = U.mean(oldpi.pd.kl(pi.pd))
        mean_ent = U.mean(pi.pd.entropy())
        pol_entpen = -self._entcoeff * mean_ent
        # Negated pessimistic surrogate, so it can be minimized as part of
        # total_loss (presumably by the MpiAdam step in _update_policy_ppo).
        pol_surr = -U.mean(tf.minimum(unclipped, clipped))
        vf_loss = U.mean(tf.square(pi.vpred - ret))

        # Key order matters: _build_ppo derives summary names from it.
        return {
            'total_loss': pol_surr + pol_entpen + vf_loss,
            'action_loss': unclipped,
            'pol_surr': pol_surr,
            'pol_entpen': pol_entpen,
            'kl': mean_kl,
            'entropy': mean_ent,
            'vf_loss': vf_loss,
        }

    def _summary(self, it):
        if self._is_chef:
            if it % self._config.ckpt_save_step == 0:
                fname = osp.join(self._config.log_dir, '%.5d' % it)
                U.save_state(fname)

    def train(self, rollout):
        """Run the training loop from the stored ``global_step`` to
        ``config.max_iters``.

        Each iteration:

        - saves a checkpoint when due (``_summary``);
        - anneals ``self._cur_lrmult`` linearly from 1 towards 0;
        - draws one batch from ``rollout`` and attaches advantages;
        - updates the policy.

        Rank 0 additionally logs episode statistics, writes summaries, and
        increments ``global_step``.

        Args:
            rollout: iterator yielding rollout dicts; ``next`` is called once
                per iteration. Keys starting with ``'ep_'`` are averaged into
                the logged info under their name with that prefix removed.
        """
        config = self._config
        sess = U.get_session()
        global_step = sess.run(self.global_step)
        t = trange(global_step,
                   config.max_iters,
                   total=config.max_iters,
                   initial=global_step)
        info = None
        ep_prefix = 'ep_'

        for step in t:
            # backup checkpoint
            self._summary(step)
            self._cur_lrmult = max(1.0 - float(step) / config.max_iters, 0)

            # rollout
            with self.timed("sampling"):
                rolls = next(rollout)
            # Hard-coded advantage parameters per method (presumably
            # gamma=0.99 and lambda=0.98 / 0.95).
            if config.rl_method == 'trpo':
                rollouts.add_advantage_rl(rolls, 0.99, 0.98)
            elif config.rl_method == 'ppo':
                rollouts.add_advantage_rl(rolls, 0.99, 0.95)

            # train policy
            info = self._update_policy(rolls, step)
            if self._is_chef:
                ep = len(rolls["ep_length"])
                reward_mean = np.mean(rolls["ep_reward"])
                reward_std = np.std(rolls["ep_reward"])
                length_mean = np.mean(rolls["ep_length"])
                length_std = np.std(rolls["ep_length"])
                desc = "ep(%d) reward(%.1f, %.1f) length(%d, %.1f)" % (
                    ep, reward_mean, reward_std, length_mean, length_std)
                for key, value in rolls.items():
                    if key.startswith(ep_prefix):
                        # Strip only the leading prefix. str.split('ep_')
                        # would also cut at any later 'ep_' in the name
                        # (e.g. 'ep_step_penalty' -> 'st').
                        info[key[len(ep_prefix):]] = np.mean(value)

                # log
                self.ep_stats.add_all_summary_dict(self.writer, info,
                                                   global_step)
                t.set_description(desc)
                global_step = sess.run(self.update_global_step)

    def evaluate(self, rollout, ckpt_num=None):
        """Run ``config.num_evaluation_run`` evaluation episodes and log the
        mean length, reward and success.

        Optional outputs:

        - ``config.record``: one mp4 per episode in ``<log_dir>/video``.
        - ``config.is_collect_state``: per-trajectory datasets in
          ``<log_dir>/state/seed_<seed>_traj_<n>.hdf5``.
        - ``config.final_eval``: one success value per line in
          ``~/iclr_eval/<prefix>.txt``.

        Args:
            rollout: iterator yielding one episode dict per ``next`` call.
                Index 0 of ``ep_length`` / ``ep_reward`` is used. The dict
                may contain ``ep_success``; ``visual_obs`` and ``obs`` are
                read when recording / collecting states.
            ckpt_num: optional checkpoint number included in video names.
        """
        config = self._config
        num_runs = config.num_evaluation_run

        ep_lens = []
        ep_rets = []
        ep_success = []
        if config.record:
            record_dir = osp.join(config.log_dir, 'video')
            os.makedirs(record_dir, exist_ok=True)
        state_file = None
        if config.is_collect_state:
            state_dir = osp.join(config.log_dir, 'state')
            os.makedirs(state_dir, exist_ok=True)
            state_file = h5py.File(
                osp.join(state_dir,
                         'seed_{}_traj_{}.hdf5'.format(config.seed, num_runs)),
                'w')

        # try/finally so the HDF5 file is closed even if an episode fails.
        try:
            for trial in range(num_runs):
                ep_traj = next(rollout)
                ep_len = ep_traj["ep_length"][0]
                ep_ret = ep_traj["ep_reward"][0]
                ep_lens.append(ep_len)
                ep_rets.append(ep_ret)
                if "ep_success" not in ep_traj:
                    ep_success.append(0)
                else:
                    ep_success.append(np.sum(ep_traj["ep_success"]))
                if config.evaluation_log:
                    logger.log('Trial #{}: lengths {}, returns {}'.format(
                        trial, ep_len, ep_ret))

                # Video recording
                if config.record:
                    visual_obs = ep_traj["visual_obs"]
                    ckpt_tag = ('' if ckpt_num is None else
                                'ckpt_{}_'.format(ckpt_num))
                    video_name = (config.video_prefix or '') + \
                        '{}{}_rew_{:.2f}_len_{}.mp4'.format(
                            ckpt_tag, trial, ep_ret, ep_len)
                    video_path = osp.join(record_dir, video_name)
                    fps = 60.
                    frame_length = len(visual_obs)
                    # Frames are indexed at a rate slightly below ``fps`` and
                    # clamped to the last frame for the extra 2 seconds.
                    new_fps = 1. / (1. / fps + 1. / frame_length)

                    def make_frame(t):
                        idx = min(int(t * new_fps), frame_length - 1)
                        return visual_obs[idx]

                    video = mpy.VideoClip(make_frame,
                                          duration=frame_length / fps + 2)
                    video.write_videofile(video_path,
                                          fps,
                                          verbose=False,
                                          progress_bar=False)

                if state_file is not None:
                    grp = state_file.create_group('traj_{}'.format(trial))
                    grp["obs"] = ep_traj["obs"]
                    grp["len"] = ep_len
                    grp["ret"] = ep_ret
                    # Success is optional: skip it when the key is missing or
                    # not indexable, but let HDF5 write errors surface.
                    try:
                        success = ep_traj["ep_success"][0]
                    except (KeyError, IndexError, TypeError):
                        pass
                    else:
                        grp["success"] = success
        finally:
            if state_file is not None:
                state_file.close()

        logger.log('Episode Length: {}'.format(sum(ep_lens) / num_runs))
        logger.log('Episode Rewards: {}'.format(sum(ep_rets) / num_runs))
        logger.log('Episode Success: {}'.format(sum(ep_success) / num_runs))

        if config.final_eval:
            eval_dir = os.path.join(os.path.expanduser('~'), 'iclr_eval')
            # Create the output directory instead of failing on first use.
            os.makedirs(eval_dir, exist_ok=True)
            path = os.path.join(eval_dir, self._config.prefix + ".txt")
            with open(path, "w") as out:
                out.writelines(str(s) + '\n' for s in ep_success)

    def _update_policy(self, seg, it):
        if self._config.rl_method == 'trpo':
            info = self._update_policy_trpo(seg, it)
        elif self._config.rl_method == 'ppo':
            info = self._update_policy_ppo(seg)
        return info

    def _update_policy_trpo(self, seg, it):
        """Perform one TRPO update on the rollout batch ``seg``.

        Policy step:

        - Solve for a natural-gradient direction with conjugate gradient on
          the Fisher-vector product (computed on every 5th sample, plus
          ``cg_damping``).
        - Scale it so that 0.5 * step^T F step equals ``config.max_kl``.
        - Run a backtracking line search (up to 10 halvings). A step is
          accepted when the losses are finite, the KL is at most
          1.5 * max_kl, and ``pol_loss`` did not decrease.

        The value function is then fitted with MpiAdam for
        ``config.vf_iters`` epochs of 64-sample minibatches.

        Args:
            seg: rollout dict with at least ``ob``, ``ac``, ``adv`` and
                ``tdlamret``.
            it: iteration index. Every 20th iteration the parameter sums are
                cross-checked between workers (when there is more than one).

        Returns:
            On rank 0, a dict mapping ``trpo/*`` keys to the mean of the
            collected values; ``None`` on other workers.
        """
        pi = self.policy
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"]
        # Standardize advantages. The std is floored (as in
        # _update_policy_ppo) so a constant or single-sample advantage batch
        # does not yield NaN through a division by zero.
        atarg = (atarg - atarg.mean()) / max(atarg.std(), 0.000001)

        if self._is_chef:
            info = defaultdict(list)

        # Update the observation running statistics with this batch.
        ob_dict = self._env.get_ob_dict(ob)
        for ob_name in pi.ob_type:
            pi.ob_rms[ob_name].update(ob_dict[ob_name])

        ob_list = pi.get_ob_list(ob_dict)
        args = ob_list + [ac, atarg]
        # Fisher-vector products use every 5th sample to save compute.
        fvpargs = [arr[::5] for arr in args]

        def fisher_vector_product(p):
            return self._all_mean(self._compute_fvp(
                p, *fvpargs)) + self._config.cg_damping * p

        self._update_oldpi()

        with self.timed("computegrad"):
            lossbefore = self._compute_lossandgrad(*args)
            lossbefore = {
                k: self._all_mean(np.array(lossbefore[k]))
                for k in sorted(lossbefore.keys())
            }
        g = lossbefore['g']

        if np.allclose(g, 0):
            logger.log("Got zero gradient. not updating")
        else:
            with self.timed("cg"):
                stepdir = cg(fisher_vector_product,
                             g,
                             cg_iters=self._config.cg_iters,
                             verbose=self._is_chef)
            assert np.isfinite(stepdir).all()
            shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
            lm = np.sqrt(shs / self._config.max_kl)
            fullstep = stepdir / lm
            expectedimprove = g.dot(fullstep)
            surrbefore = lossbefore['pol_loss']
            stepsize = 1.0
            thbefore = self._get_flat()
            for _ in range(10):
                thnew = thbefore + fullstep * stepsize
                self._set_from_flat(thnew)
                meanlosses = self._compute_losses(*args)
                meanlosses = {
                    k: self._all_mean(np.array(meanlosses[k]))
                    for k in sorted(meanlosses.keys())
                }
                # Statistics of every tried step are recorded, including
                # steps the line search goes on to reject.
                if self._is_chef:
                    for key, value in meanlosses.items():
                        if key != 'g':
                            info['trpo/' + key].append(value)
                surr = meanlosses['pol_loss']
                kl = meanlosses['kl']
                meanlosses = np.array(list(meanlosses.values()))
                improve = surr - surrbefore
                logger.log("Expected: %.3f Actual: %.3f" %
                           (expectedimprove, improve))
                if not np.isfinite(meanlosses).all():
                    logger.log("Got non-finite value of losses -- bad!")
                elif kl > self._config.max_kl * 1.5:
                    logger.log("violated KL constraint. shrinking step.")
                elif improve < 0:
                    logger.log("surrogate didn't improve. shrinking step.")
                else:
                    logger.log("Stepsize OK!")
                    break
                stepsize *= .5
            else:
                logger.log("couldn't compute a good step")
                self._set_from_flat(thbefore)
            # Sanity check that all workers hold identical parameters.
            if self._num_workers > 1 and it % 20 == 0:
                paramsums = MPI.COMM_WORLD.allgather(
                    (thnew.sum(),
                     self._vf_adam.getflat().sum()))  # list of tuples
                assert all(
                    np.allclose(ps, paramsums[0]) for ps in paramsums[1:])

        with self.timed("vf"):
            for _ in range(self._config.vf_iters):
                for (mbob, mbret) in dataset.iterbatches(
                    (ob, tdlamret),
                        include_final_partial_batch=False,
                        batch_size=64):
                    # NOTE(review): the policy step above builds ob_list from
                    # self._env.get_ob_dict(ob), but raw observation batches
                    # are passed here — confirm get_ob_list accepts both.
                    ob_list = pi.get_ob_list(mbob)
                    g = self._all_mean(
                        self._compute_vflossandgrad(*ob_list, mbret))
                    self._vf_adam.update(g, self._config.vf_stepsize)
                    vf_loss = self._all_mean(
                        np.array(self._compute_vfloss(*ob_list, mbret)))
                    if self._is_chef:
                        info['trpo/vf_loss'].append(vf_loss)

        if self._is_chef:
            for key, value in info.items():
                info[key] = np.mean(value)
            return info
        return None

    def _update_policy_ppo(self, seg):
        """Run ``optim_epochs`` epochs of shuffled minibatch PPO updates.

        Advantages are standardized (std floored at 1e-6), the observation
        running statistics are updated, and ``oldpi`` is synced to ``pi``.
        Each minibatch then takes one MpiAdam step with stepsize
        ``optim_stepsize * self._cur_lrmult``.

        Args:
            seg: rollout dict with at least ``ob``, ``ac``, ``adv`` and
                ``tdlamret``.

        Returns:
            On rank 0, a dict mapping ``'ppo/<name>'`` (plus
            ``ppo/grad_norm`` and ``ppo/grad_norm_clipped``) to the mean of
            the values fetched over all minibatches. ``None`` on other
            workers.
        """
        pi = self.policy
        obs, actions, adv, value_targets = (
            seg[name] for name in ("ob", "ac", "adv", "tdlamret"))
        adv = (adv - adv.mean()) / max(adv.std(), 0.000001)

        info = defaultdict(list) if self._is_chef else None

        batch_size = min(self._optim_batchsize, obs.shape[0])
        # prepare batches
        batches = dataset.Dataset(
            dict(ob=obs, ac=actions, atarg=adv, vtarg=value_targets),
            shuffle=True)

        ob_dict = self._env.get_ob_dict(obs)
        for name in pi.ob_type:
            pi.ob_rms[name].update(ob_dict[name])

        self._update_oldpi()

        with self.timed("update"):
            for _epoch in range(self._optim_epochs):
                for batch in batches.iterate_once(batch_size):
                    lrmult = self._cur_lrmult
                    fetched = self._loss(lrmult,
                                         *pi.get_ob_list(batch["ob"]),
                                         batch["ac"], batch["atarg"],
                                         batch["vtarg"])
                    self._adam.update(fetched['g'],
                                      self._optim_stepsize * lrmult)
                    if not self._is_chef:
                        continue
                    for key, value in fetched.items():
                        if key == 'g':
                            # Only the norm is logged; the clipped value is a
                            # statistic and is not applied to the update.
                            norm = np.linalg.norm(value)
                            info['ppo/grad_norm'].append(norm)
                            info['ppo/grad_norm_clipped'].append(
                                np.clip(norm, 0,
                                        self._config.trans_max_grad_norm))
                        elif np.isscalar(value):
                            info['ppo/' + key].append(value)
                        else:
                            # Per-sample tensors (e.g. action_loss).
                            info['ppo/' + key].extend(value)

        if not self._is_chef:
            return None
        for key in info:
            info[key] = np.mean(info[key])
        return info