Example #1
 def __init__(self, env, hidden_size, entcoeff=0.001, lr_rate=1e-3, scope="adversary"):
     self.scope = scope
     self.observation_shape = env.observation_space.shape
     self.actions_shape = env.action_space.shape
     self.input_shape = tuple([o+a for o, a in zip(self.observation_shape, self.actions_shape)])
     self.num_actions = env.action_space.shape[0]
     self.hidden_size = hidden_size
     self.build_ph()
     # Build graph
     generator_logits = self.build_graph(self.generator_obs_ph, self.generator_acs_ph, reuse=False)
     expert_logits = self.build_graph(self.expert_obs_ph, self.expert_acs_ph, reuse=True)
     # Build accuracy
     generator_acc = tf.reduce_mean(tf.to_float(tf.nn.sigmoid(generator_logits) < 0.5))
     expert_acc = tf.reduce_mean(tf.to_float(tf.nn.sigmoid(expert_logits) > 0.5))
     # Build regression loss
     # let x = logits, z = targets.
     # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
     generator_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=generator_logits, labels=tf.zeros_like(generator_logits))
     generator_loss = tf.reduce_mean(generator_loss)
     expert_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=expert_logits, labels=tf.ones_like(expert_logits))
     expert_loss = tf.reduce_mean(expert_loss)
     # Build entropy loss
     logits = tf.concat([generator_logits, expert_logits], 0)
     entropy = tf.reduce_mean(logit_bernoulli_entropy(logits))
     entropy_loss = -entcoeff*entropy
     # Loss + Accuracy terms
     self.losses = [generator_loss, expert_loss, entropy, entropy_loss, generator_acc, expert_acc]
     self.loss_name = ["generator_loss", "expert_loss", "entropy", "entropy_loss", "generator_acc", "expert_acc"]
     self.total_loss = generator_loss + expert_loss + entropy_loss
     # Build Reward for policy
     self.reward_op = -tf.log(1-tf.nn.sigmoid(generator_logits)+1e-8)
     var_list = self.get_trainable_variables()
     self.lossandgrad = U.function([self.generator_obs_ph, self.generator_acs_ph, self.expert_obs_ph, self.expert_acs_ph],
                                   self.losses + [U.flatgrad(self.total_loss, var_list)])
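
This example calls logit_bernoulli_entropy, which is not shown here. A minimal sketch of such a helper, assuming it computes the entropy of a Bernoulli distribution parameterized by logits (so the entcoeff term rewards uncertain discriminator outputs):

import tensorflow as tf

def logsigmoid(x):
    # log(sigmoid(x)) written as -softplus(-x) for numerical stability
    return -tf.nn.softplus(-x)

def logit_bernoulli_entropy(logits):
    # entropy of Bernoulli(sigmoid(logits)); algebraically equal to
    # softplus(logits) - sigmoid(logits) * logits
    return (1. - tf.nn.sigmoid(logits)) * logits - logsigmoid(logits)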
Example #2
 def setup_actor_optimizer(self):
     logger.info('setting up actor optimizer')
     self.actor_loss = -tf.reduce_mean(self.critic_with_actor_tf)
     actor_shapes = [var.get_shape().as_list() for var in self.actor.trainable_vars]
     actor_nb_params = sum([reduce(lambda x, y: x * y, shape) for shape in actor_shapes])
     logger.info('  actor shapes: {}'.format(actor_shapes))
     logger.info('  actor params: {}'.format(actor_nb_params))
     self.actor_grads = U.flatgrad(self.actor_loss, self.actor.trainable_vars, clip_norm=self.clip_norm)
     self.actor_optimizer = MpiAdam(var_list=self.actor.trainable_vars,
         beta1=0.9, beta2=0.999, epsilon=1e-08)
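
Both optimizer setups rely on U.flatgrad to turn per-variable gradients into a single flat vector that MpiAdam can consume. A minimal sketch of what such a helper typically does (the real baselines.common.tf_util version may differ in details such as how it sizes the reshape or handles None gradients):

import tensorflow as tf

def flatgrad(loss, var_list, clip_norm=None):
    # gradient of the loss w.r.t. each variable
    grads = tf.gradients(loss, var_list)
    if clip_norm is not None:
        grads = [tf.clip_by_norm(g, clip_norm=clip_norm) for g in grads]
    # replace missing gradients with zeros, flatten, and concatenate
    flat = [tf.reshape(g if g is not None else tf.zeros_like(v), [-1])
            for v, g in zip(var_list, grads)]
    return tf.concat(flat, axis=0)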
Example #3
 def setup_critic_optimizer(self):
     logger.info('setting up critic optimizer')
     normalized_critic_target_tf = tf.clip_by_value(normalize(self.critic_target, self.ret_rms), self.return_range[0], self.return_range[1])
     self.critic_loss = tf.reduce_mean(tf.square(self.normalized_critic_tf - normalized_critic_target_tf))
     if self.critic_l2_reg > 0.:
         critic_reg_vars = [var for var in self.critic.trainable_vars if 'kernel' in var.name and 'output' not in var.name]
         for var in critic_reg_vars:
             logger.info('  regularizing: {}'.format(var.name))
         logger.info('  applying l2 regularization with {}'.format(self.critic_l2_reg))
         critic_reg = tc.layers.apply_regularization(
             tc.layers.l2_regularizer(self.critic_l2_reg),
             weights_list=critic_reg_vars
         )
         self.critic_loss += critic_reg
     critic_shapes = [var.get_shape().as_list() for var in self.critic.trainable_vars]
     critic_nb_params = sum([reduce(lambda x, y: x * y, shape) for shape in critic_shapes])
     logger.info('  critic shapes: {}'.format(critic_shapes))
     logger.info('  critic params: {}'.format(critic_nb_params))
     self.critic_grads = U.flatgrad(self.critic_loss, self.critic.trainable_vars, clip_norm=self.clip_norm)
     self.critic_optimizer = MpiAdam(var_list=self.critic.trainable_vars,
         beta1=0.9, beta2=0.999, epsilon=1e-08)
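
MpiAdam applies Adam to such a flat gradient after averaging it across MPI workers. The class below is a simplified, hypothetical sketch (MpiAdamSketch is not the real baselines.common.mpi_adam class, which also handles parameter syncing and gradient-scaling options); it only illustrates the allreduce-then-Adam structure:

import numpy as np
from mpi4py import MPI

class MpiAdamSketch:
    def __init__(self, size, beta1=0.9, beta2=0.999, epsilon=1e-8, comm=None):
        self.beta1, self.beta2, self.epsilon = beta1, beta2, epsilon
        self.m = np.zeros(size, 'float32')   # first moment estimate
        self.v = np.zeros(size, 'float32')   # second moment estimate
        self.t = 0
        self.comm = comm or MPI.COMM_WORLD

    def step(self, local_grad, stepsize):
        # average the flat gradient across workers
        global_grad = np.zeros_like(local_grad)
        self.comm.Allreduce(local_grad, global_grad, op=MPI.SUM)
        global_grad /= self.comm.Get_size()
        # bias-corrected Adam update, returned as a delta on the flat parameters
        self.t += 1
        a = stepsize * np.sqrt(1 - self.beta2 ** self.t) / (1 - self.beta1 ** self.t)
        self.m = self.beta1 * self.m + (1 - self.beta1) * global_grad
        self.v = self.beta2 * self.v + (1 - self.beta2) * np.square(global_grad)
        return -a * self.m / (np.sqrt(self.v) + self.epsilon)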
Example #4
def learn(env, policy_func, dataset, optim_batch_size=128, max_iters=1e4,
          adam_epsilon=1e-5, optim_stepsize=3e-4,
          ckpt_dir=None, log_dir=None, task_name=None,
          verbose=False):

    val_per_iter = int(max_iters/10)
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_func("pi", ob_space, ac_space)  # Construct network for new policy
    # placeholder
    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])
    stochastic = U.get_placeholder_cached(name="stochastic")
    loss = tf.reduce_mean(tf.square(ac-pi.ac))
    var_list = pi.get_trainable_variables()
    adam = MpiAdam(var_list, epsilon=adam_epsilon)
    lossandgrad = U.function([ob, ac, stochastic], [loss]+[U.flatgrad(loss, var_list)])

    U.initialize()
    adam.sync()
    logger.log("Pretraining with Behavior Cloning...")
    for iter_so_far in tqdm(range(int(max_iters))):
        ob_expert, ac_expert = dataset.get_next_batch(optim_batch_size, 'train')
        train_loss, g = lossandgrad(ob_expert, ac_expert, True)
        adam.update(g, optim_stepsize)
        if verbose and iter_so_far % val_per_iter == 0:
            ob_expert, ac_expert = dataset.get_next_batch(-1, 'val')
            val_loss, _ = lossandgrad(ob_expert, ac_expert, True)
            logger.log("Training loss: {}, Validation loss: {}".format(train_loss, val_loss))

    if ckpt_dir is None:
        savedir_fname = tempfile.TemporaryDirectory().name
    else:
        savedir_fname = osp.join(ckpt_dir, task_name)
    U.save_state(savedir_fname, var_list=pi.get_variables())
    return savedir_fname
Example #5
def test_MpiAdam():
    np.random.seed(0)
    tf.set_random_seed(0)

    a = tf.Variable(np.random.randn(3).astype('float32'))
    b = tf.Variable(np.random.randn(2,5).astype('float32'))
    loss = tf.reduce_sum(tf.square(a)) + tf.reduce_sum(tf.sin(b))

    stepsize = 1e-2
    update_op = tf.train.AdamOptimizer(stepsize).minimize(loss)
    do_update = U.function([], loss, updates=[update_op])

    tf.get_default_session().run(tf.global_variables_initializer())
    losslist_ref = []
    for i in range(10):
        l = do_update()
        print(i, l)
        losslist_ref.append(l)



    tf.set_random_seed(0)
    tf.get_default_session().run(tf.global_variables_initializer())

    var_list = [a,b]
    lossandgrad = U.function([], [loss, U.flatgrad(loss, var_list)])
    adam = MpiAdam(var_list)

    losslist_test = []
    for i in range(10):
        l,g = lossandgrad()
        adam.update(g, stepsize)
        print(i,l)
        losslist_test.append(l)

    np.testing.assert_allclose(np.array(losslist_ref), np.array(losslist_test), atol=1e-4)
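
test_MpiAdam, like the other examples, builds callables with U.function, a theano-style wrapper that binds placeholders to a sess.run call. A rough sketch of the idea, assuming a default TF1 session and list-valued outputs (the real baselines.common.tf_util version also accepts single tensors, givens, and more; function_sketch is just an illustrative name):

import tensorflow as tf

def function_sketch(inputs, outputs, updates=None):
    # group the update ops so they run alongside the outputs
    update_group = tf.group(*(updates or []))

    def call(*values):
        feed = dict(zip(inputs, values))
        *results, _ = tf.get_default_session().run(
            list(outputs) + [update_group], feed_dict=feed)
        return results

    return call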
Example #6
def train_student(klts):
    env = make_mujoco_env("Reacher-v2", 0)
    with tf.Session() as sess:
        # Initialize agents
        student = StudentAgent(env, sess, False, klts)
        teacher = TeacherAgent(env, sess, True)

        # This observation placeholder is for querying teacher action
        # ob_ph = U.get_placeholder( name="ob", dtype=tf.float32,
        #       shape=[1, env.observation_space.shape[0] ] )

        ob_placeholder = U.get_placeholder(name="ob",
                                           dtype=tf.float32,
                                           shape=[TRAINING_BATCH_SIZE] +
                                           list(env.observation_space.shape))

        # get all hidden layer variables of the student pi
        student_var = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES,
            scope="s_pi_{0}".format("klts" if klts else "klst"))
        # print(student_var)
        teacher_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                        scope="t_pi")

        # KL Divergence
        if klts:
            kl_div = teacher.pi.pd.kl(student.pi.pd)
        else:
            kl_div = student.pi.pd.kl(teacher.pi.pd)

        # define loss and gradient with a theano-like function
        # gradients wrt only to student variables
        lossandgrad = U.function([ob_placeholder],
                                 [kl_div] + [U.flatgrad(kl_div, student_var)])

        logstd = U.function([ob_placeholder],
                            [teacher.pi.pd.logstd, student.pi.pd.logstd])
        std = U.function([ob_placeholder],
                         [teacher.pi.pd.std, student.pi.pd.std])
        mean = U.function([ob_placeholder],
                          [teacher.pi.pd.mean, student.pi.pd.mean])

        # initialize only student variables
        U.initialize(
            tf.get_collection(
                tf.GraphKeys.GLOBAL_VARIABLES,
                scope="s_pi_{0}".format("klts" if klts else "klst")))

        # Adam optimizer
        adam = MpiAdam(student_var, epsilon=1e-3)
        adam.sync()

        ob = env.reset()

        obs = []
        losses = []
        timesteps = []
        rets = []
        ret = 0
        num_episodes_completed = 0

        saver = tf.train.Saver(var_list=tf.get_collection(
            tf.GraphKeys.GLOBAL_VARIABLES,
            scope='s_pi_{0}'.format("klts" if klts else "klst")))

        # saver.restore(sess, "/Users/winstonww/RL/reacher_v1/student_{0}.ckpt".format("klts" if klts else "klst"))

        for timestep in range(1, TOTAL_EPISODES * TIMESTEPS_PER_EPISODE):
            # sample an action: pad the single observation with zeros to shape
            # [TRAINING_BATCH_SIZE] + obs_shape, so that the same model can be
            # used to query and train at the same time
            ob = np.expand_dims(
                ob, axis=0) + np.zeros([TRAINING_BATCH_SIZE] +
                                       list(env.observation_space.shape))
            s_ac, _ = student.pi.act(False, ob)
            # print( " ob size: {0} ".format(ob.shape))
            # print( "s_ac shape" )
            # print( s_ac.shape )
            # tread along the student trajectory
            ob, reward, new, _ = env.step(s_ac)
            ret += reward
            if new:
                rets.append(ret)
                ret = 0
                ob = env.reset()
                num_episodes_completed += 1
                if num_episodes_completed > 40000:
                    break
            # env.render()
            # print( "ob to be appended: {0}".format(ob))
            obs.append(ob)

            # compute newloss and its gradient from the two actions sampled
            # if (timestep % TRAINING_BATCH_SIZE != 0 or not timestep):
            #     continue

            # accumulate more samples before starting
            if len(obs) < TRAINING_BATCH_SIZE:
                continue

            d = Dataset(dict(ob=np.array(obs)))
            batch = d.next_batch(TRAINING_BATCH_SIZE)

            newloss, g = lossandgrad(
                np.squeeze(np.stack(list(batch.values()), axis=0), axis=0))
            adam.update(g, 0.001)

            # record the following data only when reset to save time
            if new:
                losses.append(sum(newloss))
                timesteps.append(timestep)

                if num_episodes_completed % 100 == 0:
                    print("********** Episode {0} ***********".format(
                        num_episodes_completed))
                    print("obs: \n{0}".format(
                        np.squeeze(np.stack(list(batch.values()), axis=0),
                                   axis=0)))
                    t_m, s_m = mean(
                        np.squeeze(np.stack(list(batch.values()), axis=0),
                                   axis=0))
                    t_std, s_std = std(
                        np.squeeze(np.stack(list(batch.values()), axis=0),
                                   axis=0))
                    print("student pd std: \n{0}".format(s_std))
                    print("teacher pd std: \n{0}".format(t_std))
                    print("student pd mean: \n{0}".format(s_m))
                    print("teacher pd mean: \n{0}".format(t_m))
                    print("KL divergence: \n{0}".format(sum(newloss)))

            if timestep % 5000 == 0:
                # save results
                np.save(
                    klts_training_loss_path
                    if klts else klst_training_loss_path, losses)
                np.save(
                    klts_training_ret_path if klts else klst_training_ret_path,
                    rets)
                # save kl
                save_path = saver.save(
                    sess,
                    "/Users/winstonww/RL/reacher_v1/student_{0}.ckpt".format(
                        "klts" if klts else "klst"))

        # save results
        np.save(klts_training_loss_path if klts else klst_training_loss_path,
                losses)
        np.save(klts_training_ret_path if klts else klst_training_ret_path,
                rets)

        save_path = saver.save(
            sess, "/Users/winstonww/RL/reacher_v1/student_{0}.ckpt".format(
                "klts" if klts else "klst"))
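
The distillation loss above is pd.kl between the teacher and student action distributions. For the diagonal-Gaussian policies that the logstd/std/mean queries suggest, the KL divergence has a closed form; a sketch (diag_gaussian_kl is an illustrative helper, not part of the example):

import tensorflow as tf

def diag_gaussian_kl(mean_p, logstd_p, mean_q, logstd_q):
    # KL(p || q) for diagonal Gaussians, summed over action dimensions:
    # log(sigma_q / sigma_p) + (sigma_p^2 + (mu_p - mu_q)^2) / (2 sigma_q^2) - 1/2
    return tf.reduce_sum(
        logstd_q - logstd_p
        + (tf.exp(2.0 * logstd_p) + tf.square(mean_p - mean_q))
        / (2.0 * tf.exp(2.0 * logstd_q))
        - 0.5,
        axis=-1)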
Example #7
def learn(env, policy_fn, *,
        timesteps_per_actorbatch, # timesteps per actor per update
        clip_param, entcoeff, # clipping parameter epsilon, entropy coeff
        optim_epochs, optim_stepsize, optim_batchsize,# optimization hypers
        gamma, lam, # advantage estimation
        max_timesteps=0, max_episodes=0, max_iters=0, max_seconds=0,  # time constraint
        callback=None, # you can do anything in the callback, since it takes locals(), globals()
        adam_epsilon=1e-5,
        schedule='constant' # annealing for stepsize parameters (epsilon and adam)
        ):
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_fn("pi", ob_space, ac_space) # Construct network for new policy
    oldpi = policy_fn("oldpi", ob_space, ac_space) # Network for old policy
    atarg = tf.placeholder(dtype=tf.float32, shape=[None]) # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None]) # Empirical return

    lrmult = tf.placeholder(name='lrmult', dtype=tf.float32, shape=[]) # learning rate multiplier, updated with schedule
    clip_param = clip_param * lrmult # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac)) # pnew / pold
    surr1 = ratio * atarg # surrogate from conservative policy iteration
    surr2 = tf.clip_by_value(ratio, 1.0 - clip_param, 1.0 + clip_param) * atarg #
    pol_surr = - tf.reduce_mean(tf.minimum(surr1, surr2)) # PPO's pessimistic surrogate (L^CLIP)
    vf_loss = tf.reduce_mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult], losses + [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    assign_old_eq_new = U.function([],[], updates=[tf.assign(oldv, newv)
        for (oldv, newv) in zipsame(oldpi.get_variables(), pi.get_variables())])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)

    U.initialize()
    adam.sync()

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi, env, timesteps_per_actorbatch, stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100) # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100) # rolling buffer for episode rewards

    assert sum([max_iters>0, max_timesteps>0, max_episodes>0, max_seconds>0])==1, "Only one time constraint permitted"

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult =  max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************"%iters_so_far)

        seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]
        vpredbefore = seg["vpred"] # predicted value function before update
        atarg = (atarg - atarg.mean()) / atarg.std() # standardized advantage function estimate
        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret), shuffle=not pi.recurrent)
        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"): pi.ob_rms.update(ob) # update running mean/std for policy

        assign_old_eq_new() # set old parameter values to new parameter values
        logger.log("Optimizing...")
        logger.log(fmt_row(13, loss_names))
        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            losses = [] # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"], batch["atarg"], batch["vtarg"], cur_lrmult)
                adam.update(g, optim_stepsize * cur_lrmult)
                losses.append(newlosses)
            logger.log(fmt_row(13, np.mean(losses, axis=0)))

        logger.log("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob"], batch["ac"], batch["atarg"], batch["vtarg"], cur_lrmult)
            losses.append(newlosses)
        meanlosses,_,_ = mpi_moments(losses, axis=0)
        logger.log(fmt_row(13, meanlosses))
        for (lossval, name) in zipsame(meanlosses, loss_names):
            logger.record_tabular("loss_"+name, lossval)
        logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))
        lrlocal = (seg["ep_lens"], seg["ep_rets"]) # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal) # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        if MPI.COMM_WORLD.Get_rank()==0:
            logger.dump_tabular()
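
Each iteration above calls add_vtarg_and_adv before optimization. A sketch of that helper as GAE(lambda) advantage estimation, using the field names the surrounding code reads from seg and assuming the trajectory generator also provides "new", "rew", and "nextvpred":

import numpy as np

def add_vtarg_and_adv(seg, gamma, lam):
    # append the bootstrap value and episode-start flag for the step after the batch
    new = np.append(seg["new"], 0)
    vpred = np.append(seg["vpred"], seg["nextvpred"])
    T = len(seg["rew"])
    seg["adv"] = gaelam = np.empty(T, 'float32')
    lastgaelam = 0
    for t in reversed(range(T)):
        nonterminal = 1 - new[t + 1]
        delta = seg["rew"][t] + gamma * vpred[t + 1] * nonterminal - vpred[t]
        gaelam[t] = lastgaelam = delta + gamma * lam * nonterminal * lastgaelam
    # lambda-return targets for the value function
    seg["tdlamret"] = seg["adv"] + seg["vpred"]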
Example #8
def learn(
        *,
        policy_network,
        classifier_network,
        env,
        max_iters,
        timesteps_per_batch=1024,  # what to train on
        max_kl=0.001,
        cg_iters=10,
        gamma=0.99,
        lam=1.0,  # advantage estimation
        seed=None,
        entcoeff=0.0,
        cg_damping=1e-2,
        vf_stepsize=3e-4,
        vf_iters=3,
        expert_trajs_path='./expert_trajs',
        num_expert_trajs=500,
        g_step=1,
        d_step=1,
        classifier_entcoeff=1e-3,
        num_particles=5,
        d_stepsize=3e-4,
        max_episodes=0,
        total_timesteps=0,  # time constraint
        callback=None,
        load_path=None,
        save_path=None,
        render=False,
        use_classifier_logsumexp=True,
        use_reward_logsumexp=False,
        use_svgd=True,
        **policy_network_kwargs):
    '''
    learn a policy function with TRPO algorithm
    
    Parameters:
    ----------

    policy_network          neural network to learn. Can be either string ('mlp', 'cnn', 'lstm', 'lnlstm' for basic types)
                            or function that takes input placeholder and returns tuple (output, None) for feedforward nets
                            or (output, (state_placeholder, state_output, mask_placeholder)) for recurrent nets

    env                     environment (one of the gym environments or wrapped via baselines.common.vec_env.VecEnv-type class)

    timesteps_per_batch     timesteps per gradient estimation batch

    max_kl                  max KL divergence between old policy and new policy ( KL(pi_old || pi) )

    entcoeff                coefficient of policy entropy term in the optimization objective

    cg_iters                number of iterations of conjugate gradient algorithm

    cg_damping              conjugate gradient damping 

    vf_stepsize             learning rate for adam optimizer used to optimize value function loss

    vf_iters                number of value function optimization iterations per policy optimization step

    total_timesteps         max number of timesteps

    max_episodes            max number of episodes
    
    max_iters               maximum number of policy optimization iterations

    callback                function to be called with (locals(), globals()) each policy optimization step

    load_path               str, path to load the model from (default: None, i.e. no model is loaded)

    **policy_network_kwargs keyword arguments to the policy / network builder. See baselines.common/policies.py/build_policy and arguments to a particular type of network

    Returns:
    -------

    learnt model

    '''

    nworkers = MPI.COMM_WORLD.Get_size()
    if nworkers > 1:
        raise NotImplementedError
    rank = MPI.COMM_WORLD.Get_rank()

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.49)
    U.get_session(config=tf.ConfigProto(allow_soft_placement=True,
                                        gpu_options=gpu_options))

    policy = build_policy(env,
                          policy_network,
                          value_network='copy',
                          **policy_network_kwargs)
    set_global_seeds(seed)

    np.set_printoptions(precision=3)
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space

    ob = observation_placeholder(ob_space)
    with tf.variable_scope("pi"):
        pi = policy(observ_placeholder=ob)
    with tf.variable_scope("oldpi"):
        oldpi = policy(observ_placeholder=ob)

    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    entbonus = entcoeff * meanent

    vferr = tf.reduce_mean(tf.square(pi.vf - ret))

    ratio = tf.exp(pi.pd.logp(ac) -
                   oldpi.pd.logp(ac))  # advantage * pnew / pold
    surrgain = tf.reduce_mean(ratio * atarg)

    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

    dist = meankl

    all_var_list = get_trainable_variables("pi")
    # var_list = [v for v in all_var_list if v.name.split("/")[1].startswith("pol")]
    # vf_var_list = [v for v in all_var_list if v.name.split("/")[1].startswith("vf")]
    var_list = get_pi_trainable_variables("pi")
    vf_var_list = get_vf_trainable_variables("pi")

    vfadam = MpiAdam(vf_var_list)

    get_flat = U.GetFlat(var_list)
    set_from_flat = U.SetFromFlat(var_list)
    klgrads = tf.gradients(dist, var_list)
    flat_tangent = tf.placeholder(dtype=tf.float32,
                                  shape=[None],
                                  name="flat_tan")
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start + sz], shape))
        start += sz
    gvp = tf.add_n([
        tf.reduce_sum(g * tangent)
        for (g, tangent) in zipsame(klgrads, tangents)
    ])  #pylint: disable=E1111
    fvp = U.flatgrad(gvp, var_list)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(get_variables("oldpi"), get_variables("pi"))
        ])

    compute_losses = U.function([ob, ac, atarg], losses)
    compute_lossandgrad = U.function([ob, ac, atarg], losses +
                                     [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    compute_vflossandgrad = U.function([ob, ret],
                                       U.flatgrad(vferr, vf_var_list))

    D = build_classifier(env, classifier_network, num_particles,
                         classifier_entcoeff, use_classifier_logsumexp,
                         use_reward_logsumexp)
    grads_list, vars_list = D.get_grads_and_vars()

    if use_svgd:
        optimizer = SVGD(
            grads_list, vars_list,
            lambda: tf.train.AdamOptimizer(learning_rate=d_stepsize))
    else:
        optimizer = Ensemble(
            grads_list, vars_list,
            lambda: tf.train.AdamOptimizer(learning_rate=d_stepsize))

    @contextmanager
    def timed(msg):
        if rank == 0:
            print(colorize(msg, color='yellow'))
            tstart = time.time()
            yield
            print(
                colorize("done in %.3f seconds" % (time.time() - tstart),
                         color='blue'))
        else:
            yield

    def allmean(x):
        assert isinstance(x, np.ndarray)
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
        return out

    U.initialize()

    if rank == 0:
        saver = tf.train.Saver(var_list=get_variables("pi"), max_to_keep=10000)
        writer = FileWriter(os.path.join(save_path, 'logs'))
        stats = Statistics(
            scalar_keys=["average_return", "average_episode_length"])

    if load_path is not None:
        # pi.load(load_path)
        saver.restore(U.get_session(), load_path)

    th_init = get_flat()
    MPI.COMM_WORLD.Bcast(th_init, root=0)
    set_from_flat(th_init)
    vfadam.sync()
    print("Init param sum", th_init.sum(), flush=True)

    # Prepare for rollouts
    # ----------------------------------------
    if load_path is not None:
        seg_gen = traj_segment_generator(pi,
                                         env,
                                         1,
                                         stochastic=False,
                                         render=render)
    else:
        seg_gen = traj_segment_generator(pi,
                                         env,
                                         timesteps_per_batch,
                                         stochastic=True,
                                         render=render)
    seg_gen_e = expert_traj_segment_generator(env, expert_trajs_path,
                                              timesteps_per_batch,
                                              num_expert_trajs)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards

    if sum([max_iters > 0, total_timesteps > 0, max_episodes > 0]) == 0:
        # nothing to be done
        return pi

    assert sum([max_iters>0, total_timesteps>0, max_episodes>0]) < 2, \
        'out of max_iters, total_timesteps, and max_episodes only one should be specified'

    while True:
        if callback: callback(locals(), globals())
        if total_timesteps and timesteps_so_far >= total_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        logger.log("********** Iteration %i ************" % iters_so_far)

        if iters_so_far % 500 == 0 and save_path is not None and load_path is None:
            fname = os.path.join(save_path, 'checkpoints', 'checkpoint')
            save_state(fname, saver, iters_so_far)

        with timed("sampling"):
            seg = seg_gen.__next__()

        if load_path is not None:
            iters_so_far += 1
            logger.record_tabular("EpRew", int(np.mean(seg["ep_true_rets"])))
            logger.record_tabular("EpLen", int(np.mean(seg["ep_lens"])))
            logger.dump_tabular()
            continue

        seg["rew"] = D.get_reward(seg["ob"], seg["ac"])

        add_vtarg_and_adv(seg, gamma, lam)

        ob, ac, ep_lens, atarg, tdlamret = seg["ob"], seg["ac"], seg[
            "ep_lens"], seg["adv"], seg["tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update
        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate

        if hasattr(pi, "ret_rms"): pi.ret_rms.update(tdlamret)
        if hasattr(pi, "rms"):
            pi.rms.update(ob)  # update running mean/std for policy

        args = seg["ob"], seg["ac"], atarg
        fvpargs = [arr[::5] for arr in args]

        def fisher_vector_product(p):
            return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p

        assign_old_eq_new()  # set old parameter values to new parameter values
        with timed("computegrad"):
            *lossbefore, g = compute_lossandgrad(*args)
        lossbefore = allmean(np.array(lossbefore))
        g = allmean(g)
        if np.allclose(g, 0):
            logger.log("Got zero gradient. not updating")
        else:
            with timed("cg"):
                stepdir = cg(fisher_vector_product,
                             g,
                             cg_iters=cg_iters,
                             verbose=rank == 0)
            assert np.isfinite(stepdir).all()
            shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
            lm = np.sqrt(shs / max_kl)
            # logger.log("lagrange multiplier:", lm, "gnorm:", np.linalg.norm(g))
            fullstep = stepdir / lm
            expectedimprove = g.dot(fullstep)
            surrbefore = lossbefore[0]
            stepsize = 1.0
            thbefore = get_flat()
            for _ in range(10):
                thnew = thbefore + fullstep * stepsize
                set_from_flat(thnew)
                meanlosses = surr, kl, *_ = allmean(
                    np.array(compute_losses(*args)))
                improve = surr - surrbefore
                logger.log("Expected: %.3f Actual: %.3f" %
                           (expectedimprove, improve))
                if not np.isfinite(meanlosses).all():
                    logger.log("Got non-finite value of losses -- bad!")
                elif kl > max_kl * 1.5:
                    logger.log("violated KL constraint. shrinking step.")
                elif improve < 0:
                    logger.log("surrogate didn't improve. shrinking step.")
                else:
                    logger.log("Stepsize OK!")
                    break
                stepsize *= .5
            else:
                logger.log("couldn't compute a good step")
                set_from_flat(thbefore)
            if nworkers > 1 and iters_so_far % 20 == 0:
                paramsums = MPI.COMM_WORLD.allgather(
                    (thnew.sum(), vfadam.getflat().sum()))  # list of tuples
                assert all(
                    np.allclose(ps, paramsums[0]) for ps in paramsums[1:])

        for (lossname, lossval) in zip(loss_names, meanlosses):
            logger.record_tabular(lossname, lossval)

        with timed("vf"):

            for _ in range(vf_iters):
                for (mbob, mbret) in dataset.iterbatches(
                    (seg["ob"], seg["tdlamret"]),
                        include_final_partial_batch=False,
                        batch_size=1000):
                    g = allmean(compute_vflossandgrad(mbob, mbret))
                    vfadam.update(g, vf_stepsize)

        with timed("sample expert trajectories"):
            ob_a, ac_a, ep_lens_a = ob, ac, ep_lens
            seg_e = seg_gen_e.__next__()
            ob_e, ac_e, ep_lens_e = seg_e["ob"], seg_e["ac"], seg_e["ep_lens"]

        if hasattr(D, "rms"):
            obs = np.concatenate([ob_a, ob_e], axis=0)
            if isinstance(ac_space, spaces.Box):
                acs = np.concatenate([ac_a, ac_e], axis=0)
                D.rms.update(np.concatenate([obs, acs], axis=1))
            elif isinstance(ac_space, spaces.Discrete):
                D.rms.update(obs)
            else:
                raise NotImplementedError

        with timed("SVGD"):
            sess = tf.get_default_session()
            feed_dict = {
                D.Xs['a']: ob_a,
                D.As['a']: ac_a,
                D.Ls['a']: ep_lens_a,
                D.Xs['e']: ob_e,
                D.As['e']: ac_e,
                D.Ls['e']: ep_lens_e
            }
            for _ in range(d_step):
                sess.run(optimizer.update_op, feed_dict=feed_dict)

        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))

        lrlocal = (seg["ep_lens"], seg["ep_true_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)

        if rank == 0:
            logger.dump_tabular()
            stats.add_all_summary(
                writer,
                [np.mean(rewbuffer), np.mean(lenbuffer)], iters_so_far)
            rewbuffer.clear()
            lenbuffer.clear()

    return pi
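
The TRPO step above solves F x = g with cg, a plain conjugate-gradient routine over the Fisher-vector-product callable. A sketch of such a solver (the tolerance and logging details may differ from the baselines version):

import numpy as np

def cg(f_Ax, b, cg_iters=10, residual_tol=1e-10):
    # solve Ax = b for symmetric positive-definite A, given only matrix-vector products
    x = np.zeros_like(b)
    r = b.copy()          # residual b - Ax with x = 0
    p = b.copy()          # search direction
    rdotr = r.dot(r)
    for _ in range(cg_iters):
        z = f_Ax(p)
        alpha = rdotr / p.dot(z)
        x += alpha * p
        r -= alpha * z
        new_rdotr = r.dot(r)
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
        if rdotr < residual_tol:
            break
    return x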
Example #9
def learn(
        env,
        policy_func,
        *,
        timesteps_per_batch,  # timesteps per actor per update
        clip_param,
        entcoeff,  # clipping parameter epsilon, entropy coeff
        optim_epochs,
        optim_stepsize,
        optim_batchsize,  # optimization hypers
        gamma,
        lam,  # advantage estimation
        max_timesteps=0,
        max_episodes=0,
        max_iters=0,
        max_seconds=0,  # time constraint
        callback=None,  # you can do anything in the callback, since it takes locals(), globals()
        adam_epsilon=1e-5,
        schedule='constant',  # annealing for stepsize parameters (epsilon and adam)
        num_options=1,
        app='',
        saves=False,
        wsaves=False,
        epoch=-1,
        seed=1,
        dc=0):

    optim_batchsize_ideal = optim_batchsize
    np.random.seed(seed)
    tf.set_random_seed(seed)
    env.seed(seed)

    ### Book-keeping
    gamename = env.spec.id[:-3].lower()
    gamename += 'seed' + str(seed)
    gamename += app
    version_name = 'FINAL_NORM-ACT-LOWER-LR-len-400-wNoise-update1-ppo-ESCH-1-0-0-nI'

    dirname = '{}_{}_{}opts_saves/'.format(version_name, gamename, num_options)
    print(dirname)
    #input ("wait here after dirname")

    if wsaves:
        first = True
        if not os.path.exists(dirname):
            os.makedirs(dirname)
            first = False
        # while os.path.exists(dirname) and first:
        #     dirname += '0'

        files = ['pposgd_simple.py', 'mlp_policy.py', 'run_mujoco.py']
        first = True
        for i in range(len(files)):
            src = os.path.join(
                '/home/nfunk/Code_MA/ppoc_off_tryout/baselines/baselines/ppo1/'
            ) + files[i]
            print(src)
            #dest = os.path.join('/home/nfunk/results_NEW/ppo1/') + dirname
            dest = dirname + "src_code/"
            if (first):
                os.makedirs(dest)
                first = False
            print(dest)
            shutil.copy2(src, dest)
        # brute force copy normal env file at end of copying process:
        src = os.path.join(
            '/home/nfunk/Code_MA/ppoc_off_tryout/nfunk/envs_nf/pendulum_nf.py')
        shutil.copy2(src, dest)
    ###

    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    max_action = env.action_space.high

    # add the dimension in the observation space!
    ob_space.shape = ((ob_space.shape[0] + ac_space.shape[0]), )
    print(ob_space.shape)
    print(ac_space.shape)
    #input ("wait here where the spaces are printed!!!")
    pi = policy_func("pi", ob_space,
                     ac_space)  # Construct network for new policy
    oldpi = policy_func("oldpi", ob_space, ac_space)  # Network for old policy
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return
    pol_ov_op_ent = tf.placeholder(dtype=tf.float32,
                                   shape=None)  # entropy coefficient for the policy over options

    # option = tf.placeholder(dtype=tf.int32, shape=[None])

    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule
    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    # pdb.set_trace()
    ob = U.get_placeholder_cached(name="ob")
    option = U.get_placeholder_cached(name="option")
    term_adv = U.get_placeholder(name='term_adv',
                                 dtype=tf.float32,
                                 shape=[None])

    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = U.mean(kloldnew)
    meanent = U.mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    atarg_clip = atarg  #tf.clip_by_value(atarg,-10,10)
    surr1 = ratio * atarg_clip  #atarg # surrogate from conservative policy iteration
    surr2 = U.clip(ratio, 1.0 - clip_param,
                   1.0 + clip_param) * atarg_clip  #atarg #
    pol_surr = -U.mean(tf.minimum(
        surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)

    #vf_loss = U.mean(tf.square(tf.clip_by_value(pi.vpred - ret, -10.0, 10.0)))
    vf_loss = U.mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    term_loss = pi.tpred * term_adv

    force_pi_loss = U.mean(
        tf.square(
            tf.clip_by_value(pi.op_pi, 1e-5, 1.0) -
            tf.constant([[0.05, 0.95]])))

    log_pi = tf.log(tf.clip_by_value(pi.op_pi, 1e-5, 1.0))
    #log_pi = tf.Print(log_pi, [log_pi, tf.shape(tf.transpose(log_pi))])
    old_log_pi = tf.log(tf.clip_by_value(oldpi.op_pi, 1e-5, 1.0))
    entropy = -tf.reduce_sum(pi.op_pi * log_pi, reduction_indices=1)

    ratio_pol_ov_op = tf.exp(
        tf.transpose(log_pi)[option[0]] -
        tf.transpose(old_log_pi)[option[0]])  # pnew / pold
    term_adv_clip = term_adv  #tf.clip_by_value(term_adv,-10,10)
    surr1_pol_ov_op = ratio_pol_ov_op * term_adv_clip  # surrogate from conservative policy iteration
    surr2_pol_ov_op = U.clip(ratio_pol_ov_op, 1.0 - clip_param,
                             1.0 + clip_param) * term_adv_clip  #
    pol_surr_pol_ov_op = -U.mean(
        tf.minimum(surr1_pol_ov_op,
                   surr2_pol_ov_op))  # PPO's pessimistic surrogate (L^CLIP)

    op_loss = pol_surr_pol_ov_op - pol_ov_op_ent * tf.reduce_sum(entropy)
    #op_loss = pol_surr_pol_ov_op

    #total_loss += force_pi_loss
    total_loss += op_loss

    var_list = pi.get_trainable_variables()
    term_list = var_list[6:8]

    lossandgrad = U.function(
        [ob, ac, atarg, ret, lrmult, option, term_adv, pol_ov_op_ent],
        losses + [U.flatgrad(total_loss, var_list)])
    termloss = U.function([ob, option, term_adv],
                          [U.flatgrad(term_loss, var_list)
                           ])  # Since we will use a different step size.
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult, option], losses)

    U.initialize()
    adam.sync()

    saver = tf.train.Saver(max_to_keep=10000)
    saver_best = tf.train.Saver(max_to_keep=1)

    ### More book-keeping
    results = []
    if saves:
        results = open(
            version_name + '_' + gamename + '_' + str(num_options) + 'opts_' +
            '_results.csv', 'w')
        results_best_model = open(
            dirname + version_name + '_' + gamename + '_' + str(num_options) +
            'opts_' + '_bestmodel.csv', 'w')

        out = 'epoch,avg_reward'

        for opt in range(num_options):
            out += ',option {} dur'.format(opt)
        for opt in range(num_options):
            out += ',option {} std'.format(opt)
        for opt in range(num_options):
            out += ',option {} term'.format(opt)
        for opt in range(num_options):
            out += ',option {} adv'.format(opt)
        out += '\n'
        results.write(out)
        # results.write('epoch,avg_reward,option 1 dur, option 2 dur, option 1 term, option 2 term\n')
        results.flush()

    if epoch >= 0:

        dirname = '{}_{}opts_saves/'.format(gamename, num_options)
        print("Loading weights from iteration: " + str(epoch))

        filename = dirname + '{}_epoch_{}.ckpt'.format(gamename, epoch)
        saver.restore(U.get_session(), filename)
    ###

    episodes_so_far = 0
    timesteps_so_far = 0
    global iters_so_far
    iters_so_far = 0
    des_pol_op_ent = 0.1
    max_val = -100000
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards

    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_batch,
                                     stochastic=True,
                                     num_options=num_options,
                                     saves=saves,
                                     results=results,
                                     rewbuffer=rewbuffer,
                                     dc=dc)

    datas = [0 for _ in range(num_options)]

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % iters_so_far)

        seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        opt_d = []
        for i in range(num_options):
            dur = np.mean(
                seg['opt_dur'][i]) if len(seg['opt_dur'][i]) > 0 else 0.
            opt_d.append(dur)

        std = []
        for i in range(num_options):
            logstd = np.mean(
                seg['logstds'][i]) if len(seg['logstds'][i]) > 0 else 0.
            std.append(np.exp(logstd))
        print("mean opt dur:", opt_d)
        print("mean op pol:", np.mean(np.array(seg['optpol_p']), axis=0))
        print("mean term p:", np.mean(np.array(seg['term_p']), axis=0))
        print("mean value val:", np.mean(np.array(seg['value_val']), axis=0))

        ob, ac, opts, atarg, tdlamret = seg["ob"], seg["ac"], seg["opts"], seg[
            "adv"], seg["tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update
        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate

        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy
        if hasattr(pi, "ob_rms_only"):
            pi.ob_rms_only.update(ob[:, :-ac_space.shape[0]]
                                  )  # update running mean/std for policy
        assign_old_eq_new()  # set old parameter values to new parameter values

        if (iters_so_far + 1) % 1000 == 0:
            des_pol_op_ent = des_pol_op_ent / 10

        if iters_so_far % 50 == 0 and wsaves:
            print("weights are saved...")
            filename = dirname + '{}_epoch_{}.ckpt'.format(
                gamename, iters_so_far)
            save_path = saver.save(U.get_session(), filename)

        # adaptively save best run:
        if (np.mean(rewbuffer) > max_val) and wsaves:
            max_val = np.mean(rewbuffer)
            results_best_model.write('epoch: ' + str(iters_so_far) + 'rew: ' +
                                     str(np.mean(rewbuffer)) + '\n')
            results_best_model.flush()
            filename = dirname + 'best.ckpt'
            save_path = saver_best.save(U.get_session(), filename)

        min_batch = 160  # Arbitrary
        t_advs = [[] for _ in range(num_options)]
        for opt in range(num_options):
            indices = np.where(opts == opt)[0]
            print("batch size:", indices.size)
            opt_d[opt] = indices.size
            if not indices.size:
                t_advs[opt].append(0.)
                continue

            ### This part is only necessary when we use options. These checks ensure we do not discard any collected trajectories.
            if datas[opt] != 0:
                if (indices.size < min_batch and datas[opt].n > min_batch):
                    datas[opt] = Dataset(dict(ob=ob[indices],
                                              ac=ac[indices],
                                              atarg=atarg[indices],
                                              vtarg=tdlamret[indices]),
                                         shuffle=not pi.recurrent)
                    t_advs[opt].append(0.)
                    continue

                elif indices.size + datas[opt].n < min_batch:
                    # pdb.set_trace()
                    oldmap = datas[opt].data_map

                    cat_ob = np.concatenate((oldmap['ob'], ob[indices]))
                    cat_ac = np.concatenate((oldmap['ac'], ac[indices]))
                    cat_atarg = np.concatenate(
                        (oldmap['atarg'], atarg[indices]))
                    cat_vtarg = np.concatenate(
                        (oldmap['vtarg'], tdlamret[indices]))
                    datas[opt] = Dataset(dict(ob=cat_ob,
                                              ac=cat_ac,
                                              atarg=cat_atarg,
                                              vtarg=cat_vtarg),
                                         shuffle=not pi.recurrent)
                    t_advs[opt].append(0.)
                    continue

                elif (indices.size + datas[opt].n > min_batch and datas[opt].n
                      < min_batch) or (indices.size > min_batch
                                       and datas[opt].n < min_batch):

                    oldmap = datas[opt].data_map
                    cat_ob = np.concatenate((oldmap['ob'], ob[indices]))
                    cat_ac = np.concatenate((oldmap['ac'], ac[indices]))
                    cat_atarg = np.concatenate(
                        (oldmap['atarg'], atarg[indices]))
                    cat_vtarg = np.concatenate(
                        (oldmap['vtarg'], tdlamret[indices]))
                    datas[opt] = d = Dataset(dict(ob=cat_ob,
                                                  ac=cat_ac,
                                                  atarg=cat_atarg,
                                                  vtarg=cat_vtarg),
                                             shuffle=not pi.recurrent)

                if (indices.size > min_batch and datas[opt].n > min_batch):
                    datas[opt] = d = Dataset(dict(ob=ob[indices],
                                                  ac=ac[indices],
                                                  atarg=atarg[indices],
                                                  vtarg=tdlamret[indices]),
                                             shuffle=not pi.recurrent)

            elif datas[opt] == 0:
                datas[opt] = d = Dataset(dict(ob=ob[indices],
                                              ac=ac[indices],
                                              atarg=atarg[indices],
                                              vtarg=tdlamret[indices]),
                                         shuffle=not pi.recurrent)
            ###

            optim_batchsize = optim_batchsize or ob.shape[0]
            optim_epochs = np.clip(
                np.int(10 * (indices.size /
                             (timesteps_per_batch / num_options))), 10,
                10) if num_options > 1 else optim_epochs
            print("optim epochs:", optim_epochs)
            logger.log("Optimizing...")

            # Here we do a bunch of optimization epochs over the data
            for _ in range(optim_epochs):
                losses = [
                ]  # list of tuples, each of which gives the loss for a minibatch
                for batch in d.iterate_once(optim_batchsize):

                    #tadv,nodc_adv = pi.get_term_adv(batch["ob"],[opt])
                    tadv, nodc_adv = pi.get_opt_adv(batch["ob"], [opt])
                    tadv = tadv if num_options > 1 else np.zeros_like(tadv)
                    t_advs[opt].append(nodc_adv)

                    #if (opt==1):
                    #    *newlosses, grads = lossandgrad(batch["ob"], batch["ac"], batch["atarg"], batch["vtarg"], cur_lrmult, [opt], tadv)
                    #else:
                    #    *newlosses, grads = lossandgrad0(batch["ob"], batch["ac"], batch["atarg"], batch["vtarg"], cur_lrmult, [opt], tadv)
                    *newlosses, grads = lossandgrad(batch["ob"], batch["ac"],
                                                    batch["atarg"],
                                                    batch["vtarg"], cur_lrmult,
                                                    [opt], tadv,
                                                    des_pol_op_ent)
                    #*newlosses, grads = lossandgrad(batch["ob"], batch["ac"], batch["atarg"], batch["vtarg"], cur_lrmult, [opt], tadv)
                    #termg = termloss(batch["ob"], [opt], tadv)
                    #adam.update(termg[0], 5e-7 * cur_lrmult)
                    adam.update(grads, optim_stepsize * cur_lrmult)
                    losses.append(newlosses)

        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.dump_tabular()

        ### Book keeping
        if saves:
            out = "{},{}"
            for _ in range(num_options):
                out += ",{},{},{},{}"
            out += "\n"

            info = [iters_so_far, np.mean(rewbuffer)]
            for i in range(num_options):
                info.append(opt_d[i])
            for i in range(num_options):
                info.append(std[i])
            for i in range(num_options):
                info.append(np.mean(np.array(seg['term_p']), axis=0)[i])
            for i in range(num_options):
                info.append(np.mean(t_advs[i]))

            results.write(out.format(*info))
            results.flush()
Example #10
def train(num_timesteps, iters):
    from baselines.ppo1 import mlp_policy
    U.make_session(num_cpu=1).__enter__()

    def policy_fn(name, ob_space, ac_space):
        return mlp_policy.MlpPolicy(name=name,
                                    ob_space=ob_space,
                                    ac_space=ac_space,
                                    hid_size=64,
                                    num_hid_layers=2)

    env0 = TestEnv()
    # env0 = ImageEnv()
    model_0 = learn(
        env0,
        policy_fn,
        "pi0",
        max_timesteps=num_timesteps,
        timesteps_per_batch=1000,
        clip_param=0.2,
        entcoeff=0.0,
        optim_epochs=10,
        optim_stepsize=3e-4,
        optim_batchsize=64,
        gamma=0.99,
        lam=0.95,
        schedule='linear',
    )
    env0.close()

    env1 = TestEnv1()
    # env1 = ImageEnv1()
    model_1 = learn(
        env1,
        policy_fn,
        "pi1",
        max_timesteps=num_timesteps,
        timesteps_per_batch=1000,
        clip_param=0.2,
        entcoeff=0.0,
        optim_epochs=10,
        optim_stepsize=3e-4,
        optim_batchsize=64,
        gamma=0.99,
        lam=0.95,
        schedule='linear',
    )
    env1.close()

    env2 = TestEnv2()
    # env2 = ImageEnv2()
    model_2 = learn(
        env2,
        policy_fn,
        "pi2",
        max_timesteps=num_timesteps,
        timesteps_per_batch=1000,
        clip_param=0.2,
        entcoeff=0.0,
        optim_epochs=10,
        optim_stepsize=3e-4,
        optim_batchsize=64,
        gamma=0.99,
        lam=0.95,
        schedule='linear',
    )
    env2.close()

    ob_space = env0.observation_space
    ac_space = env0.action_space
    pi = policy_fn("model_d", ob_space,
                   ac_space)  # Construct network for new policy
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])
    lrmult = tf.placeholder(name='lrmult', dtype=tf.float32, shape=[])

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])
    kl = pi.pd.kl(model_0.pd) + pi.pd.kl(model_1.pd) + pi.pd.kl(model_2.pd)
    ent = model_0.pd.entropy() + model_1.pd.entropy() + model_2.pd.entropy()
    meankl = U.mean(kl)
    meanent = U.mean(ent)
    loss = -meankl  # - U.mean(tf.exp(model_0.pd.logp(ac)) * atarg) - U.mean(tf.exp(model_1.pd.logp(ac)) * atarg) - U.mean(tf.exp(model_2.pd.logp(ac)) * atarg)
    var_list = pi.get_trainable_variables()
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult],
                             [loss] + [U.flatgrad(loss, var_list)])
    adam = MpiAdam(var_list, epsilon=1e-5)
    compute_losses = U.function([ob, ac, atarg, ret, lrmult], loss)

    U.initialize()
    adam.sync()

    seg_gen0 = traj_segment_generator(model_0, env0, 1000, stochastic=True)
    seg_gen1 = traj_segment_generator(model_1, env1, 1000, stochastic=True)
    seg_gen2 = traj_segment_generator(model_2, env2, 1000, stochastic=True)

    seg_gend0 = traj_segment_generator(pi, env0, 1000, stochastic=True)
    seg_gend1 = traj_segment_generator(pi, env1, 1000, stochastic=True)
    seg_gend2 = traj_segment_generator(pi, env2, 1000, stochastic=True)

    lenbuffer0 = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer0 = deque(maxlen=100)
    lenbuffer1 = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer1 = deque(maxlen=100)
    lenbuffer2 = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer2 = deque(maxlen=100)

    rew0 = []
    rew1 = []
    rew2 = []

    # env2.close()
    # return model_0, model_1, model_2
    for i in range(iters):

        logger.log("********** Iteration %i ************" % i)
        cur_lrmult = 1.0

        seg0 = seg_gen0.__next__()
        add_vtarg_and_adv(seg0, 0.99, 0.95)

        ob, ac, atarg, tdlamret = seg0["ob"], seg0["ac"], seg0["adv"], seg0[
            "tdlamret"]
        vpredbefore = seg0["vpred"]  # predicted value function before update
        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate
        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret))
        optim_batchsize = ob.shape[0]

        for _ in range(10):
            # losses = [] # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"],
                                            batch["atarg"], batch["vtarg"],
                                            cur_lrmult)
                adam.update(g, 3e-4 * cur_lrmult)

        seg1 = seg_gen1.__next__()
        add_vtarg_and_adv(seg1, 0.99, 0.95)

        ob, ac, atarg, tdlamret = seg1["ob"], seg1["ac"], seg1["adv"], seg1[
            "tdlamret"]
        vpredbefore = seg1["vpred"]  # predicted value function before update
        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate
        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret))
        optim_batchsize = ob.shape[0]

        for _ in range(10):
            # losses = [] # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"],
                                            batch["atarg"], batch["vtarg"],
                                            cur_lrmult)
                adam.update(g, 3e-4 * cur_lrmult)

        seg2 = seg_gen2.__next__()
        add_vtarg_and_adv(seg2, 0.99, 0.95)

        ob, ac, atarg, tdlamret = seg2["ob"], seg2["ac"], seg2["adv"], seg2[
            "tdlamret"]
        vpredbefore = seg2["vpred"]  # predicted value function before update
        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate
        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret))
        optim_batchsize = ob.shape[0]

        for _ in range(10):
            # losses = [] # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"],
                                            batch["atarg"], batch["vtarg"],
                                            cur_lrmult)
                adam.update(g, 3e-4 * cur_lrmult)

        segd0 = seg_gend0.__next__()
        segd1 = seg_gend1.__next__()
        segd2 = seg_gend2.__next__()

        lrlocal0 = (segd0["ep_lens"], segd0["ep_rets"])  # local values
        listoflrpairs0 = MPI.COMM_WORLD.allgather(lrlocal0)  # list of tuples
        lens0, rews0 = map(flatten_lists, zip(*listoflrpairs0))
        lenbuffer0.extend(lens0)
        rewbuffer0.extend(rews0)
        mean_rew0 = np.mean(rewbuffer0)
        logger.record_tabular("Env0EpLenMean", np.mean(lenbuffer0))
        logger.record_tabular("Env0EpRewMean", mean_rew0)
        rew0.append(mean_rew0)

        lrlocal1 = (segd1["ep_lens"], segd1["ep_rets"])  # local values
        listoflrpairs1 = MPI.COMM_WORLD.allgather(lrlocal1)  # list of tuples
        lens1, rews1 = map(flatten_lists, zip(*listoflrpairs1))
        lenbuffer1.extend(lens1)
        rewbuffer1.extend(rews1)
        mean_rew1 = np.mean(rewbuffer1)
        logger.record_tabular("Env1EpLenMean", np.mean(lenbuffer1))
        logger.record_tabular("Env1EpRewMean", mean_rew1)
        rew1.append(mean_rew1)

        lrlocal2 = (segd2["ep_lens"], segd2["ep_rets"])  # local values
        listoflrpairs2 = MPI.COMM_WORLD.allgather(lrlocal2)  # list of tuples
        lens2, rews2 = map(flatten_lists, zip(*listoflrpairs2))
        lenbuffer2.extend(lens2)
        rewbuffer2.extend(rews2)
        mean_rew2 = np.mean(rewbuffer2)
        logger.record_tabular("Env2EpLenMean", np.mean(lenbuffer2))
        logger.record_tabular("Env2EpRewMean", mean_rew2)
        rew2.append(mean_rew2)

        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.dump_tabular()

    return model_0, model_1, model_2, pi, np.array(rew0), np.array(
        rew1), np.array(rew2)
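A minimal, hypothetical driver for the train() function above is sketched below; the entry-point wrapper and the argument values are assumptions for illustration only, not part of the original example (it presumes TestEnv*, mlp_policy, and the MPI/TF session setup are importable from the surrounding project).

if __name__ == "__main__":
    # train the three source policies and the distilled "model_d" policy,
    # then unpack the seven return values shown above
    model_0, model_1, model_2, pi, rew0, rew1, rew2 = train(
        num_timesteps=100000, iters=50)
    print("final mean rewards:", rew0[-1], rew1[-1], rew2[-1])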
Example #11
0
def learn(env, policy_fn, *,
        timesteps_per_batch, # what to train on
        max_kl, cg_iters,
        gamma, lam, # advantage estimation
        entcoeff=0.0,
        cg_damping=1e-2,
        vf_stepsize=3e-4,
        vf_iters =3,
        max_timesteps=0, max_episodes=0, max_iters=0, is_Original=0,  # time constraint
        callback=None
        ):
    nworkers = MPI.COMM_WORLD.Get_size()
    rank = MPI.COMM_WORLD.Get_rank()
    np.set_printoptions(precision=3)
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_fn("pi", ob_space, ac_space)
    oldpi = policy_fn("oldpi", ob_space, ac_space)
    atarg = tf.placeholder(dtype=tf.float32, shape=[None]) # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None]) # Empirical return

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    entbonus = entcoeff * meanent

    vferr = tf.reduce_mean(tf.square(pi.vpred - ret))

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac)) # advantage * pnew / pold

    if is_Original == 0:
        surrgain = tf.reduce_mean(ratio * atarg)
    elif is_Original == 1:
        surrgain = tf.reduce_mean(tf.log(tf.clip_by_value(ratio, 1e-10, 1e100)) * atarg)
    elif is_Original == 2:
        surrgain = tf.reduce_mean(tf.log(tf.clip_by_value(ratio, 1e-10, 1e100)) * tf.nn.relu(atarg) -
                                  (tf.nn.relu(-1.0 * atarg) * (2 * ratio - tf.log(tf.clip_by_value(ratio, 1e-10, 1e100)))))
    elif is_Original == 3:
        surrgain = tf.reduce_mean(tf.log(tf.clip_by_value(ratio, 1e-10, 1e100)) * tf.nn.relu(atarg) +
                                  tf.nn.relu(-1.0 * atarg) * tf.log(tf.clip_by_value(2 - ratio, 1e-10, 1e100)))

    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

    dist = meankl

    all_var_list = pi.get_trainable_variables()
    var_list = [v for v in all_var_list if v.name.split("/")[1].startswith("pol")]
    vf_var_list = [v for v in all_var_list if v.name.split("/")[1].startswith("vf")]
    vfadam = MpiAdam(vf_var_list)

    get_flat = U.GetFlat(var_list)
    set_from_flat = U.SetFromFlat(var_list)
    klgrads = tf.gradients(dist, var_list)
    flat_tangent = tf.placeholder(dtype=tf.float32, shape=[None], name="flat_tan")
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start+sz], shape))
        start += sz
    gvp = tf.add_n([tf.reduce_sum(g*tangent) for (g, tangent) in zipsame(klgrads, tangents)]) #pylint: disable=E1111
    fvp = U.flatgrad(gvp, var_list)
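    # gvp is the directional derivative of the KL gradient along the tangent,
    # so fvp = d(gvp)/dtheta is a Fisher(/KL-Hessian)-vector product computed
    # without ever forming the full matrix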

    assign_old_eq_new = U.function([],[], updates=[tf.assign(oldv, newv)
        for (oldv, newv) in zipsame(oldpi.get_variables(), pi.get_variables())])
    compute_losses = U.function([ob, ac, atarg], losses)
    compute_lossandgrad = U.function([ob, ac, atarg], losses + [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    compute_vflossandgrad = U.function([ob, ret], U.flatgrad(vferr, vf_var_list))

    @contextmanager
    def timed(msg):
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(colorize("done in %.3f seconds"%(time.time() - tstart), color='magenta'))
        else:
            yield

    def allmean(x):
        assert isinstance(x, np.ndarray)
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
        return out

    U.initialize()
    th_init = get_flat()
    MPI.COMM_WORLD.Bcast(th_init, root=0)
    set_from_flat(th_init)
    vfadam.sync()
    print("Init param sum", th_init.sum(), flush=True)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi, env, timesteps_per_batch, stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40) # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40) # rolling buffer for episode rewards

    assert sum([max_iters>0, max_timesteps>0, max_episodes>0])==1

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        logger.log("********** Iteration %i ************"%iters_so_far)

        with timed("sampling"):
            seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]
        vpredbefore = seg["vpred"] # predicted value function before udpate
        # atarg = (atarg - atarg.mean()) / atarg.std() # standardized advantage function estimate
        # atarg = (atarg - atarg.mean()) / atarg.std()
        if hasattr(pi, "ret_rms"): pi.ret_rms.update(tdlamret)
        if hasattr(pi, "ob_rms"): pi.ob_rms.update(ob) # update running mean/std for policy

        args = seg["ob"], seg["ac"], atarg
        fvpargs = [arr[::5] for arr in args]
        def fisher_vector_product(p):
            return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p

        assign_old_eq_new() # set old parameter values to new parameter values
        with timed("computegrad"):
            *lossbefore, g = compute_lossandgrad(*args)
        lossbefore = allmean(np.array(lossbefore))
        g = allmean(g)
        if np.allclose(g, 0):
            logger.log("Got zero gradient. not updating")
        else:
            with timed("cg"):
                stepdir = cg(fisher_vector_product, g, cg_iters=cg_iters, verbose=rank==0)
            assert np.isfinite(stepdir).all()
            shs = .5*stepdir.dot(fisher_vector_product(stepdir))
            lm = np.sqrt(shs / max_kl)
            # logger.log("lagrange multiplier:", lm, "gnorm:", np.linalg.norm(g))
            fullstep = stepdir / lm
            expectedimprove = g.dot(fullstep)
            surrbefore = lossbefore[0]
            stepsize = 1.0
            thbefore = get_flat()
            for _ in range(10):
                thnew = thbefore + fullstep * stepsize
                set_from_flat(thnew)
                meanlosses = surr, kl, *_ = allmean(np.array(compute_losses(*args)))
                improve = surr - surrbefore
                logger.log("Expected: %.3f Actual: %.3f"%(expectedimprove, improve))


                if not np.isfinite(meanlosses).all():
                    logger.log("Got non-finite value of losses -- bad!")
                elif kl > max_kl * 1.5:
                    logger.log("violated KL constraint. shrinking step.")
                elif improve < 0:
                    logger.log("surrogate didn't improve. shrinking step.")
                else:
                    logger.log("Stepsize OK!")
                    break
                stepsize *= .5
            else:
                logger.log("couldn't compute a good step")
                set_from_flat(thbefore)
            if nworkers > 1 and iters_so_far % 20 == 0:
                paramsums = MPI.COMM_WORLD.allgather((thnew.sum(), vfadam.getflat().sum())) # list of tuples
                assert all(np.allclose(ps, paramsums[0]) for ps in paramsums[1:])

        for (lossname, lossval) in zip(loss_names, meanlosses):
            logger.record_tabular(lossname, lossval)

        with timed("vf"):

            for _ in range(vf_iters):
                for (mbob, mbret) in dataset.iterbatches((seg["ob"], seg["tdlamret"]),
                include_final_partial_batch=False, batch_size=64):
                    g = allmean(compute_vflossandgrad(mbob, mbret))
                    vfadam.update(g, vf_stepsize)

        logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))

        lrlocal = (seg["ep_lens"], seg["ep_rets"]) # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal) # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)


        if rank==0:
            logger.dump_tabular()
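As a side note, the self-contained NumPy sketch below shows how the trust-region rescaling used in the line search above turns the conjugate-gradient direction into a step whose quadratic KL estimate equals max_kl; the helper name and the toy Fisher product are illustrative, not from the original code.

import numpy as np

def scale_to_trust_region(stepdir, fisher_vector_product, max_kl):
    # 0.5 * s^T H s approximates the KL change of a full step s; dividing by
    # lm = sqrt(0.5 * s^T H s / max_kl) rescales s so that estimate equals max_kl
    shs = 0.5 * stepdir.dot(fisher_vector_product(stepdir))
    lm = np.sqrt(shs / max_kl)
    return stepdir / lm

# toy check with H = 2 * I and s = [1, 0]: the scaled step is [0.1, 0], and
# 0.5 * 0.1 * (2 * 0.1) = 0.01 = max_kl
print(scale_to_trust_region(np.array([1.0, 0.0]), lambda p: 2.0 * p, max_kl=0.01))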
Example #12
0
def learn(
        env,
        policy_func_pro,
        policy_func_adv,
        *,
        timesteps_per_batch,  # timesteps per actor per update
        clip_param,
        entcoeff,  # clipping parameter epsilon, entropy coeff
        optim_epochs,
        optim_stepsize,
        optim_batchsize,  # optimization hypers
        gamma,
        lam,  # advantage estimation
        lr_l,
        lr_a,
        max_steps_episode,
        max_timesteps=0,
        max_episodes=0,
        max_iters=0,
        max_seconds=0,  # time constraint
        callback=None,  # you can do anything in the callback, since it takes locals(), globals()
        adam_epsilon=1e-5,
        schedule='constant',  # annealing for stepsize parameters (epsilon and adam)
        clip_action=False,
        restore_dir=None,
        ckpt_dir=None,
        save_timestep_period=1000):
    # Setup losses and stuff
    # ----------------------------------------
    rew_mean = []

    ob_space = env.observation_space
    pro_ac_space = env.action_space
    adv_ac_space = env.adv_action_space

    # env.render()
    pro_pi = policy_func_pro("pro_pi", ob_space,
                             pro_ac_space)  # Construct network for new policy
    pro_oldpi = policy_func_pro("pro_oldpi", ob_space,
                                pro_ac_space)  # Network for old policy

    adv_pi = policy_func_adv(
        "adv_pi", ob_space,
        adv_ac_space)  # Construct network for new adv policy
    adv_oldpi = policy_func_adv("adv_oldpi", ob_space,
                                adv_ac_space)  # Network for old adv policy

    pro_atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    adv_atarg = tf.placeholder(dtype=tf.float32, shape=[None])

    ret_pro = tf.placeholder(dtype=tf.float32,
                             shape=[None])  # Empirical return
    ret_adv = tf.placeholder(dtype=tf.float32,
                             shape=[None])  # Empirical return

    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule
    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    ob_ = U.get_placeholder_cached(name="ob_")

    adv_ob = U.get_placeholder_cached(name="ob_adv")
    adv_ob_ = U.get_placeholder_cached(name="adv_ob_")

    pro_ac = pro_pi.pdtype.sample_placeholder([None])
    adv_ac = adv_pi.pdtype.sample_placeholder([None])

    # define Lyapunov net
    s_dim = ob_space.shape[0]
    a_dim = pro_ac_space.shape[0]
    d_dim = adv_ac_space.shape[0]

    use_lyapunov = True
    approx_value = True
    finite_horizon = True
    labda_init = 1.  # weight on the Lyapunov decrease term (Lagrange multiplier)
    alpha3 = 0.5  # L2 performance coefficient; we can adjust it later to study its effect on performance
    tau = 5e-3
    ita = 1.
    lr_labda = 0.033

    LN_R = tf.placeholder(tf.float32, [None, 1], 'r')  # reward
    LN_V = tf.placeholder(tf.float32, [None, 1], 'v')  # return / value target
    LR_L = tf.placeholder(tf.float32, None, 'LR_L')  # learning rate for the Lyapunov network
    LR_A = tf.placeholder(tf.float32, None, 'LR_A')  # learning rate for the actor network
    L_terminal = tf.placeholder(tf.float32, [None, 1], 'terminal')
    # LN_S = tf.placeholder(tf.float32, [None, s_dim], 's')  # state
    # LN_a_input = tf.placeholder(tf.float32, [None, a_dim], 'a_input')  # actions from the batch
    # ob_ = tf.placeholder(tf.float32, [None, s_dim], 's_')  # next state
    LN_a_input_ = tf.placeholder(tf.float32, [None, a_dim], 'a_')  # next action
    LN_d_input = tf.placeholder(tf.float32, [None, d_dim], 'd_input')  # disturbances from the batch
    labda = tf.placeholder(tf.float32, None, 'LR_lambda')

    # log_labda = tf.get_variable('lambda', None, tf.float32, initializer=tf.log(labda_init))  # log(λ), used for the adaptive coefficient
    # labda = tf.clip_by_value(tf.exp(log_labda), *SCALE_lambda_MIN_MAX)

    l = build_l(ob, pro_ac, s_dim, a_dim)  # Lyapunov network
    l_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                 scope='Lyapunov')

    ema = tf.train.ExponentialMovingAverage(decay=1 - tau)  # soft replacement (soft target-network update)

    def ema_getter(getter, name, *args, **kwargs):
        return ema.average(getter(name, *args, **kwargs))

    target_update = [ema.apply(l_params)]

    a_ = pro_pi.ac_
    a_old_ = pro_oldpi.ac_

    # there are several possible ways to sample the next action here
    l_ = build_l(ob_, a_, s_dim, a_dim, reuse=True)

    l_d = build_l(adv_ob, pro_ac, s_dim, a_dim, reuse=True)
    l_d_ = build_l(adv_ob_, LN_a_input_, s_dim, a_dim, reuse=True)

    l_old_ = build_l(ob_,
                     a_old_,
                     s_dim,
                     a_dim,
                     reuse=True,
                     custom_getter=ema_getter)  # could a_ be used here instead of a_old_?
    l_derta = tf.reduce_mean(
        l_ - l + (alpha3 + 1) * LN_R -
        ita * tf.expand_dims(tf.norm(LN_d_input, axis=1), axis=1))
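    # l_derta approximates the Lyapunov decrease term
    # E[ L(s', a') - L(s, a) + (alpha3 + 1) * r - ita * ||d|| ];
    # it is weighted by the multiplier `labda` in the policy losses below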
    # the l term in the adversary objective may also need to change: the main policy has already been
    # updated here, so we would need to resample, but resampling does not seem quite right either since
    # R has changed; another way would be to alternate the two updates
    # yet another option is to also take ac_ from the samples here to keep things consistent, though it
    # is unclear why the earlier SAC version did not have this problem
    l_d_derta = tf.reduce_mean(
        l_d_ - l_d + (alpha3 + 1) * LN_R -
        ita * tf.expand_dims(tf.norm(adv_pi.ac, axis=1), axis=1)
    )  # possibly the source of oscillation (adv_pi.ac); perhaps the two Lyapunov nets should really be synchronized
    # labda_loss = -tf.reduce_mean(log_labda * l_derta)  # loss for updating lambda
    # lambda_train = tf.train.AdamOptimizer(LR_A).minimize(labda_loss, var_list=log_labda)  # optimizer for the adaptive coefficient

    with tf.control_dependencies(target_update):
        if approx_value:
            if finite_horizon:
                l_target = LN_V  # would it be better to approximate this here ourselves?
            else:
                l_target = LN_R + gamma * (1 - L_terminal) * tf.stop_gradient(
                    l_old_)  # Lyapunov critic - self.alpha * next_log_pis
        else:
            l_target = LN_R
        l_error = tf.losses.mean_squared_error(labels=l_target, predictions=l)
        ltrain = tf.train.AdamOptimizer(LR_L).minimize(l_error,
                                                       var_list=l_params)
        Lyapunov_train = [ltrain]
        Lyapunov_opt_input = [
            LN_R, LN_V, L_terminal, ob, ob_, pro_ac, LN_d_input, LR_A, LR_L
        ]

        Lyapunov_opt = U.function(Lyapunov_opt_input, Lyapunov_train)
        Lyapunov_opt_loss = U.function(Lyapunov_opt_input, [l_error])
    # Lyapunov function

    pro_kloldnew = pro_oldpi.pd.kl(pro_pi.pd)  # compute kl difference
    adv_kloldnew = adv_oldpi.pd.kl(adv_pi.pd)
    pro_ent = pro_pi.pd.entropy()
    adv_ent = adv_pi.pd.entropy()
    pro_meankl = tf.reduce_mean(pro_kloldnew)
    adv_meankl = tf.reduce_mean(adv_kloldnew)
    pro_meanent = tf.reduce_mean(pro_ent)
    adv_meanent = tf.reduce_mean(adv_ent)
    pro_pol_entpen = (-entcoeff) * pro_meanent
    adv_pol_entpen = (-entcoeff) * adv_meanent

    pro_ratio = tf.exp(pro_pi.pd.logp(pro_ac) -
                       pro_oldpi.pd.logp(pro_ac))  # pnew / pold
    adv_ratio = tf.exp(adv_pi.pd.logp(adv_ac) - adv_oldpi.pd.logp(adv_ac))
    pro_surr1 = pro_ratio * pro_atarg  # surrogate from conservative policy iteration
    adv_surr1 = adv_ratio * adv_atarg
    pro_surr2 = tf.clip_by_value(pro_ratio, 1.0 - clip_param,
                                 1.0 + clip_param) * pro_atarg  #
    adv_surr2 = tf.clip_by_value(adv_ratio, 1.0 - clip_param,
                                 1.0 + clip_param) * adv_atarg
    pro_pol_surr = -tf.reduce_mean(tf.minimum(
        pro_surr1, pro_surr2))  # PPO's pessimistic surrogate (L^CLIP)
    adv_pol_surr = -tf.reduce_mean(tf.minimum(adv_surr1, adv_surr2))
    pro_vf_loss = tf.reduce_mean(tf.square(pro_pi.vpred - ret_pro))
    adv_vf_loss = tf.reduce_mean(tf.square(adv_pi.vpred - ret_adv))
    pro_lyapunov_loss = tf.reduce_mean(-l_derta * labda)
    pro_total_loss = pro_pol_surr + pro_pol_entpen + pro_vf_loss + pro_lyapunov_loss

    adv_lyapunov_loss = tf.reduce_mean(-l_d_derta * labda)
    adv_total_loss = adv_pol_surr + adv_pol_entpen + adv_vf_loss + adv_lyapunov_loss

    pro_losses = [
        pro_pol_surr, pro_pol_entpen, pro_vf_loss, pro_meankl, pro_meanent,
        pro_lyapunov_loss
    ]
    pro_loss_names = [
        "pro_pol_surr", "pro_pol_entpen", "pro_vf_loss", "pro_kl", "pro_ent",
        "pro_lyapunov_loss"
    ]
    adv_losses = [
        adv_pol_surr, adv_pol_entpen, adv_vf_loss, adv_meankl, adv_meanent,
        adv_lyapunov_loss
    ]
    adv_loss_names = [
        "adv_pol_surr", "adv_pol_entpen", "adv_vf_loss", "adv_kl", "adv_ent",
        "adv_lyapunov_loss"
    ]

    pro_var_list = pro_pi.get_trainable_variables()
    adv_var_list = adv_pi.get_trainable_variables()
    pro_opt_input = [
        ob, pro_ac, pro_atarg, ret_pro, lrmult, LN_R, LN_V, ob_, LN_d_input,
        LR_A, LR_L, labda
    ]
    # pro_lossandgrad = U.function([ob, pro_ac, pro_atarg, ret, lrmult], pro_losses + [U.flatgrad(pro_total_loss, pro_var_list)])
    pro_lossandgrad = U.function(
        pro_opt_input, pro_losses + [U.flatgrad(pro_total_loss, pro_var_list)])

    # Lyapunov_grad = U.function(Lyapunov_opt_input, U.flatgrad(l_derta * labda, pro_var_list))

    adv_opt_input = [
        adv_ob, adv_ac, adv_atarg, ret_adv, lrmult, LN_R, adv_ob_, pro_ac,
        LN_a_input_, LR_A, LR_L, labda
    ]
    adv_lossandgrad = U.function(
        adv_opt_input, adv_losses + [U.flatgrad(adv_total_loss, adv_var_list)])
    pro_adam = MpiAdam(pro_var_list, epsilon=adam_epsilon)
    adv_adam = MpiAdam(adv_var_list, epsilon=adam_epsilon)

    pro_assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv) for (oldv, newv) in zipsame(
                pro_oldpi.get_variables(), pro_pi.get_variables())
        ])
    adv_assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv) for (oldv, newv) in zipsame(
                adv_oldpi.get_variables(), adv_pi.get_variables())
        ])
    pro_compute_losses = U.function(pro_opt_input, pro_losses)
    adv_compute_losses = U.function(adv_opt_input, adv_losses)

    # zp gymfc new
    saver = None
    if ckpt_dir:  # save model
        # Store for each one
        keep = int(max_timesteps / float(save_timestep_period))  # number of checkpoints to keep
        print("[INFO] Keeping ", keep, " checkpoints")
        saver = tf.train.Saver(save_relative_paths=True, max_to_keep=keep)

    print('version: use discarl v2-v5')
    print('info:', 'alpha3', alpha3, 'SCALE_lambda_MIN_MAX', lr_labda,
          'Finite_horizon', finite_horizon, 'adv_mag',
          env.adv_action_space.high, 'timesteps', max_timesteps)

    U.initialize()
    pro_adam.sync()
    adv_adam.sync()

    if restore_dir:  # restore model
        ckpt = tf.train.get_checkpoint_state(restore_dir)
        if ckpt:
            # If there is one that already exists then restore it
            print("Restoring model from ", ckpt.model_checkpoint_path)
            saver.restore(tf.get_default_session(), ckpt.model_checkpoint_path)
        else:
            print("Trying to restore model from ", restore_dir,
                  " but doesn't exist")

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pro_pi,
                                     adv_pi,
                                     env,
                                     timesteps_per_batch,
                                     max_steps_episode,
                                     stochastic=True,
                                     clip_action=clip_action)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    stop_buffer = deque(maxlen=30)
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards
    costbuffer = deque(maxlen=100)
    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"

    next_ckpt_timestep = save_timestep_period

    while True:
        if callback: callback(locals(), globals())

        end = False
        if max_timesteps and timesteps_so_far >= max_timesteps:
            end = True
        elif max_episodes and episodes_so_far >= max_episodes:
            end = True
        elif max_iters and iters_so_far >= max_iters:
            end = True
        elif max_seconds and time.time() - tstart >= max_seconds:
            end = True

        if saver and ((timesteps_so_far >= next_ckpt_timestep) or end):
            task_name = "ppo-{}-{}.ckpt".format(env.spec.id, timesteps_so_far)
            fname = os.path.join(ckpt_dir, task_name)
            os.makedirs(os.path.dirname(fname), exist_ok=True)
            saver.save(tf.get_default_session(), fname)
            next_ckpt_timestep += save_timestep_period

        if end:  #and np.mean(stop_buffer) > zp_max_step:
            break

        if end and max_timesteps < 100:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        # corresponding learning-rate decay for the Lyapunov function
        Lya_epsilon = 1.0 - (timesteps_so_far - 1.0) / max_timesteps
        if end:
            Lya_epsilon = 0.0001
        lr_a_this = lr_a * Lya_epsilon
        lr_l_this = lr_l * Lya_epsilon
        lr_labda_this = lr_labda * Lya_epsilon

        Percentage = min(timesteps_so_far / max_timesteps, 1) * 100
        logger.log("**********Iteration %i **Percentage %.2f **********" %
                   (iters_so_far, Percentage))

        seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, pro_ac, adv_ac, pro_atarg, adv_atarg, pro_tdlamret, adv_tdlamret = seg[
            "ob"], seg["pro_ac"], seg["adv_ac"], seg["pro_adv"], seg[
                "adv_adv"], seg["pro_tdlamret"], seg["adv_tdlamret"]
        rew = seg["rew"]
        ob_ = seg["ob_"]
        new = seg["new"]

        pro_vpredbefore = seg[
            "pro_vpred"]  # predicted value function before update
        adv_vpredbefore = seg["adv_vpred"]
        pro_atarg = (pro_atarg - pro_atarg.mean()) / pro_atarg.std(
        )  # standardized advantage function estimate
        adv_atarg = (adv_atarg - adv_atarg.mean()) / adv_atarg.std()

        # TODO
        # d = Dataset(dict(ob=ob, ac=pro_ac, atarg=pro_atarg, vtarg=pro_tdlamret), shuffle=not pro_pi.recurrent)
        d = Dataset(dict(ob=ob,
                         ob_=ob_,
                         rew=rew,
                         new=new,
                         ac=pro_ac,
                         adv=adv_ac,
                         atarg=pro_atarg,
                         vtarg=pro_tdlamret),
                    shuffle=not pro_pi.recurrent)  # load into the experience replay buffer
        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pro_pi, "ob_rms"):
            pro_pi.ob_rms.update(ob)  # update running mean/std for policy

        pro_assign_old_eq_new(
        )  # set old parameter values to new parameter values

        logger.log("Pro Optimizing...")
        logger.log(fmt_row(13, pro_loss_names))

        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            pro_losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                l_value = deepcopy(batch["atarg"])

                # [LN_R, LN_V, ob, ob_, pro_ac, LN_d_input, LR_A, LR_L]
                Lyapunov_opt(batch["rew"].reshape(-1,
                                                  1), l_value.reshape(-1, 1),
                             batch["new"], batch["ob"], batch["ob_"],
                             batch["ac"], batch["adv"], lr_a_this, lr_l_this)
                *newlosses, g = pro_lossandgrad(
                    batch["ob"], batch["ac"], batch["atarg"], batch["vtarg"],
                    cur_lrmult, batch["rew"].reshape(-1, 1),
                    l_value.reshape(-1, 1), batch["ob_"], batch["adv"],
                    lr_a_this, lr_l_this, lr_labda_this)

                pro_adam.update(g, optim_stepsize * cur_lrmult)
                pro_losses.append(newlosses)
            # logger.log(fmt_row(13, np.mean(pro_losses, axis=0)))

        # logger.log("Pro Evaluating losses...")
        pro_losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = pro_compute_losses(batch["ob"], batch["ac"],
                                           batch["atarg"], batch["vtarg"],
                                           cur_lrmult,
                                           batch["rew"].reshape(-1, 1),
                                           batch["atarg"].reshape(-1, 1),
                                           batch["ob_"], batch["adv"],
                                           lr_a_this, lr_l_this, lr_labda_this)
            pro_losses.append(newlosses)
        pro_meanlosses, _, _ = mpi_moments(pro_losses, axis=0)

        logger.log(fmt_row(13, pro_meanlosses))

        ac_ = sample_next_act(ob_, a_dim, pro_pi,
                              stochastic=True)  # ob, a_dim, policy, stochastic

        d = Dataset(dict(ob=ob,
                         adv_ac=adv_ac,
                         atarg=adv_atarg,
                         vtarg=adv_tdlamret,
                         ob_=ob_,
                         rew=rew,
                         ac_=ac_,
                         new=new,
                         pro_ac=pro_ac),
                    shuffle=not adv_pi.recurrent)

        if hasattr(adv_pi, "ob_rms"): adv_pi.ob_rms.update(ob)
        adv_assign_old_eq_new()

        # logger.log(fmt_row(13, adv_loss_names))
        logger.log("Adv Optimizing...")
        logger.log(fmt_row(13, adv_loss_names))
        for _ in range(optim_epochs):
            adv_losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                # adv_opt_input = [ob, adv_ac, adv_atarg, ret, lrmult, LN_R, ob_, LN_a_input_, LR_A, LR_L]
                # [ob, adv_ac, adv_atarg, ret, lrmult, LN_R, ob_, pro_ac, LN_a_input_, LR_A, LR_L]
                # ac_ = sample_next_act(batch["ob_"], a_dim, pro_pi, stochastic=True) # ob, a_dim, policy, stochastic
                *newlosses, g = adv_lossandgrad(
                    batch["ob"], batch["adv_ac"], batch["atarg"],
                    batch["vtarg"], cur_lrmult, batch["rew"].reshape(-1, 1),
                    batch["ob_"], batch["pro_ac"], batch["ac_"], lr_a_this,
                    lr_l_this, lr_labda_this)
                # *newlosses, g = adv_lossandgrad(batch["ob"], batch["adv_ac"], batch["atarg"], batch["vtarg"], cur_lrmult,
                #                                 batch["rew"].reshape(-1,1), batch["ob_"],
                #                                 batch["pro_ac"], batch["pro_ac"], lr_a_this, lr_l_this)
                adv_adam.update(g, optim_stepsize * cur_lrmult)
                adv_losses.append(newlosses)
            # logger.log(fmt_row(13, np.mean(adv_losses, axis=0)))
        # logger.log("Adv Evaluating losses...")
        adv_losses = []

        for batch in d.iterate_once(optim_batchsize):
            newlosses = adv_compute_losses(
                batch["ob"], batch["adv_ac"], batch["atarg"], batch["vtarg"],
                cur_lrmult, batch["rew"].reshape(-1, 1), batch["ob_"],
                batch["pro_ac"], batch["ac_"], lr_a_this, lr_l_this,
                lr_labda_this)
            adv_losses.append(newlosses)
        adv_meanlosses, _, _ = mpi_moments(adv_losses, axis=0)
        logger.log(fmt_row(13, adv_meanlosses))

        # curr_rew = evaluate(pro_pi, test_env)
        # rew_mean.append(curr_rew)
        # print(curr_rew)

        # logger.record_tabular("ev_tdlam_before", explained_variance(pro_vpredbefore, pro_tdlamret))
        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))

        # print(rews)

        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        cost = seg["ep_cost"]
        costbuffer.extend(cost)

        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpCostMean", np.mean(costbuffer))
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)

        stop_buffer.extend(lens)

        logger.record_tabular("stop_flag", np.mean(stop_buffer))
        logger.dump_tabular()

        # print(stop_buffer)
        print(lr_labda_this)

    print('version: use discarl v2-v5')
    print('info:', 'alpha3', alpha3, 'SCALE_lambda_MIN_MAX', lr_labda,
          'Finite_horizon', finite_horizon, 'adv_mag',
          env.adv_action_space.high, 'timesteps', max_timesteps)

    return pro_pi, np.mean(rewbuffer), timesteps_so_far, np.mean(lenbuffer)
Example #13
0
def learn(env,
          policy_func,
          reward_giver,
          expert_dataset,
          rank,
          g_step,
          d_step,
          entcoeff,
          save_per_iter,
          timesteps_per_batch,
          ckpt_dir,
          log_dir,
          task_name,
          gamma,
          lam,
          max_kl,
          cg_iters,
          cg_damping=1e-2,
          vf_stepsize=3e-4,
          d_stepsize=3e-4,
          vf_iters=3,
          max_timesteps=0,
          max_episodes=0,
          max_iters=0,
          callback=None):

    nworkers = MPI.COMM_WORLD.Get_size()
    rank = MPI.COMM_WORLD.Get_rank()
    np.set_printoptions(precision=3)
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_func("pi", ob_space, ac_space)
    saver = tf.train.Saver(
        var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='pi'))
    saver.restore(tf.get_default_session(), U_.getPath() + '/model/bc.ckpt')

    oldpi = policy_func("oldpi", ob_space, ac_space)
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    entbonus = entcoeff * meanent

    vferr = tf.reduce_mean(tf.square(pi.vpred - ret))

    ratio = tf.exp(pi.pd.logp(ac) -
                   oldpi.pd.logp(ac))  # advantage * pnew / pold
    surrgain = tf.reduce_mean(ratio * atarg)

    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

    dist = meankl

    all_var_list = pi.get_trainable_variables()
    var_list = [
        v for v in all_var_list
        if v.name.startswith("pi/pol") or v.name.startswith("pi/logstd")
    ]
    vf_var_list = [v for v in all_var_list if v.name.startswith("pi/vff")]
    assert len(var_list) == len(vf_var_list) + 1
    d_adam = MpiAdam(reward_giver.get_trainable_variables())
    vfadam = MpiAdam(vf_var_list)

    get_flat = U.GetFlat(var_list)
    set_from_flat = U.SetFromFlat(var_list)
    klgrads = tf.gradients(dist, var_list)
    flat_tangent = tf.placeholder(dtype=tf.float32,
                                  shape=[None],
                                  name="flat_tan")
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start + sz], shape))
        start += sz
    gvp = tf.add_n([
        tf.reduce_sum(g * tangent)
        for (g, tangent) in zipsame(klgrads, tangents)
    ])  # pylint: disable=E1111
    fvp = U.flatgrad(gvp, var_list)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg], losses)
    compute_lossandgrad = U.function([ob, ac, atarg], losses +
                                     [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    compute_vflossandgrad = U.function([ob, ret],
                                       U.flatgrad(vferr, vf_var_list))

    @contextmanager
    def timed(msg):
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(
                colorize("done in %.3f seconds" % (time.time() - tstart),
                         color='magenta'))
        else:
            yield

    def allmean(x):
        assert isinstance(x, np.ndarray)
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
        return out

    U.initialize()
    th_init = get_flat()
    MPI.COMM_WORLD.Bcast(th_init, root=0)
    set_from_flat(th_init)
    d_adam.sync()
    vfadam.sync()
    if rank == 0:
        print("Init param sum", th_init.sum())

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     reward_giver,
                                     timesteps_per_batch,
                                     stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards
    true_rewbuffer = deque(maxlen=40)

    assert sum([max_iters > 0, max_timesteps > 0, max_episodes > 0]) == 1

    g_loss_stats = stats(loss_names)
    d_loss_stats = stats(reward_giver.loss_name)
    ep_stats = stats(["True_rewards", "Rewards", "Episode_length"])
    # if provide pretrained weight

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break

        # Save model
        if rank == 0 and iters_so_far % save_per_iter == 0 and ckpt_dir is not None:
            fname = os.path.join(ckpt_dir, task_name)
            print('save model as ', fname)
            try:
                os.makedirs(os.path.dirname(fname))
            except OSError:
                # folder already exists
                pass
            saver = tf.train.Saver()
            saver.save(tf.get_default_session(), fname)

        print("********** Iteration %i ************" % iters_so_far)

        def fisher_vector_product(p):
            return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p

        # ------------------ Update G ------------------
        print("Optimizing Policy...")
        for _ in range(g_step):
            with timed("sampling"):
                seg = seg_gen.__next__()
            add_vtarg_and_adv(seg, gamma, lam)
            # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
            ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
                "tdlamret"]
            vpredbefore = seg[
                "vpred"]  # predicted value function before udpate
            atarg = (atarg - atarg.mean()) / atarg.std(
            )  # standardized advantage function estimate

            if hasattr(pi, "ob_rms"):
                pi.ob_rms.update(ob)  # update running mean/std for policy

            args = seg["ob"], seg["ac"], atarg
            fvpargs = [arr[::5] for arr in args]

            assign_old_eq_new(
            )  # set old parameter values to new parameter values
            with timed("computegrad"):
                tmp_result = compute_lossandgrad(seg["ob"], seg["ac"], atarg)
                lossbefore = tmp_result[:-1]
                g = tmp_result[-1]
            lossbefore = allmean(np.array(lossbefore))
            g = allmean(g)
            if np.allclose(g, 0):
                print("Got zero gradient. not updating")
            else:
                with timed("cg"):
                    stepdir = cg(fisher_vector_product,
                                 g,
                                 cg_iters=cg_iters,
                                 verbose=rank == 0)
                assert np.isfinite(stepdir).all()
                shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
                lm = np.sqrt(shs / max_kl)
                # print("lagrange multiplier:", lm, "gnorm:", np.linalg.norm(g))
                fullstep = stepdir / lm
                expectedimprove = g.dot(fullstep)
                surrbefore = lossbefore[0]
                stepsize = 1.0
                thbefore = get_flat()
                for _ in range(10):
                    thnew = thbefore + fullstep * stepsize
                    set_from_flat(thnew)
                    meanlosses = allmean(
                        np.array(compute_losses(seg["ob"], seg["ac"], atarg)))
                    surr = meanlosses[0]
                    kl = meanlosses[1]
                    improve = surr - surrbefore
                    print("Expected: %.3f Actual: %.3f" %
                          (expectedimprove, improve))
                    if not np.isfinite(meanlosses).all():
                        print("Got non-finite value of losses -- bad!")
                    elif kl > max_kl * 1.5:
                        print("violated KL constraint. shrinking step.")
                    elif improve < 0:
                        print("surrogate didn't improve. shrinking step.")
                    else:
                        print("Stepsize OK!")
                        break
                    stepsize *= .5
                else:
                    print("couldn't compute a good step")
                    set_from_flat(thbefore)
                if nworkers > 1 and iters_so_far % 20 == 0:
                    paramsums = MPI.COMM_WORLD.allgather(
                        (thnew.sum(),
                         vfadam.getflat().sum()))  # list of tuples
                    assert all(
                        np.allclose(ps, paramsums[0]) for ps in paramsums[1:])
            with timed("vf"):
                for _ in range(vf_iters):
                    for (mbob, mbret) in dataset.iterbatches(
                        (seg["ob"], seg["tdlamret"]),
                            include_final_partial_batch=False,
                            batch_size=128):
                        if hasattr(pi, "ob_rms"):
                            pi.ob_rms.update(
                                mbob)  # update running mean/std for policy
                        g = allmean(compute_vflossandgrad(mbob, mbret))
                        vfadam.update(g, vf_stepsize)

        print("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))

        # ------------------ Update D ------------------
        print("Optimizing Discriminator...")
        print(fmt_row(13, reward_giver.loss_name))
        ob_expert, ac_expert = expert_dataset.get_next_batch(len(ob))
        batch_size = len(ob) // d_step
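        # splitting one policy batch into d_step minibatches means the loop below
        # performs d_step discriminator updates per policy update, each paired with
        # a freshly sampled expert minibatch of the same size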
        d_losses = [
        ]  # list of tuples, each of which gives the loss for a minibatch
        for ob_batch, ac_batch in tqdm(
                dataset.iterbatches((ob, ac),
                                    include_final_partial_batch=False,
                                    batch_size=batch_size)):
            ob_expert, ac_expert = expert_dataset.get_next_batch(len(ob_batch))
            # update running mean/std for reward_giver
            if hasattr(reward_giver, "obs_rms"):
                reward_giver.obs_rms.update(
                    np.concatenate((ob_batch, ob_expert), 0))
            tmp_result = reward_giver.lossandgrad(ob_batch, ac_batch,
                                                  ob_expert, ac_expert)
            newlosses = tmp_result[:-1]
            g = tmp_result[-1]
            d_adam.update(allmean(g), d_stepsize)
            d_losses.append(newlosses)
        print(fmt_row(13, np.mean(d_losses, axis=0)))

        timesteps_so_far += len(seg['ob'])
        iters_so_far += 1

        print("EpisodesSoFar", episodes_so_far)
        print("TimestepsSoFar", timesteps_so_far)
        print("TimeElapsed", time.time() - tstart)
Example #14
0
pol_entpen = (-entcoeff) * meanent

ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
surr1 = ratio * atarg  # surrogate from conservative policy iteration
surr2 = U.clip(ratio, 1.0 - clip_param, 1.0 + clip_param) * atarg  #
pol_surr = -U.mean(tf.minimum(surr1,
                              surr2))  # PPO's pessimistic surrogate (L^CLIP)
vf_loss = U.mean(tf.square(pi.vpred - ret))
total_loss = pol_surr + pol_entpen + vf_loss

losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

var_list = pi.get_trainable_variables()
lossandgrad = U.function([ob, ac, atarg, ret, lrmult],
                         losses + [U.flatgrad(total_loss, var_list)])
adam = MpiAdam(var_list, epsilon=adam_epsilon)

assign_old_eq_new = U.function(
    [], [],
    updates=[
        tf.assign(oldv, newv)
        for (oldv, newv) in zipsame(oldpi.get_variables(), pi.get_variables())
    ])
compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)

U.initialize()
adam.sync()

saver = tf.train.Saver()
if args.resume > 0:
def learn(env, policy_fn, *,
          timesteps_per_actorbatch,  # timesteps per actor per update
          clip_param, entcoeff,  # clipping parameter epsilon, entropy coeff
          optim_epochs, optim_stepsize, optim_batchsize,  # optimization hypers
          gamma, lam,  # advantage estimation
          max_timesteps = 0, max_episodes = 0, max_iters = 0, max_seconds = 0,  # time constraint
          callback = None,  # you can do anything in the callback, since it takes locals(), globals()
          adam_epsilon = 1e-5,
          schedule = 'constant'  # annealing for stepsize parameters (epsilon and adam)
          ):
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_fn("pi", ob_space, ac_space)  # Construct network for new policy
    oldpi = policy_fn("oldpi", ob_space, ac_space)  # Network for old policy
    atarg = tf.placeholder(dtype = tf.float32, shape = [None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype = tf.float32, shape = [None])  # Empirical return
    G_t_inv = tf.placeholder(dtype = tf.float32, shape = [None, None])
    alpha = tf.placeholder(dtype = tf.float32, shape = [1])

    td_v_target = tf.placeholder(dtype = tf.float32, shape = [1, 1])  # V target for RAC

    lrmult = tf.placeholder(name = 'lrmult', dtype = tf.float32,
                            shape = [])  # learning rate multiplier, updated with schedule
    # adv = tf.placeholder(dtype = tf.float32, shape = [1, 1])  # Advantage function for RAC

    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name = "ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = tf.clip_by_value(ratio, 1.0 - clip_param, 1.0 + clip_param) * atarg  #
    pol_surr = - tf.reduce_mean(tf.minimum(surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)
    vf_loss = tf.reduce_mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    vf_rac_loss = tf.reduce_mean(tf.square(pi.vpred - td_v_target))
    vf_rac_losses = [vf_rac_loss]
    vf_rac_loss_names = ["vf_rac_loss"]

    pol_rac_loss_surr1 = atarg * pi.pd.neglogp(ac) * ratio
    pol_rac_loss_surr2 = tf.clip_by_value(ratio, 1.0 - clip_param, 1.0 + clip_param) * atarg * pi.pd.neglogp(ac) #
    pol_rac_loss = tf.reduce_mean(tf.minimum(pol_rac_loss_surr1, pol_rac_loss_surr2))
    pol_rac_losses = [pol_rac_loss]
    pol_rac_loss_names = ["pol_rac_loss"]

    var_list = pi.get_trainable_variables()

    vf_final_var_list = [v for v in var_list if v.name.split("/")[1].startswith(
        "vf") and v.name.split("/")[2].startswith(
        "final")]
    pol_final_var_list = [v for v in var_list if v.name.split("/")[1].startswith(
        "pol") and v.name.split("/")[2].startswith(
        "final")]

    compatible_feature = U.flatgrad(pi.pd.neglogp(ac), pol_final_var_list)
    G_t_inv_next = 1/(1-alpha) * (G_t_inv -
                                  alpha * (G_t_inv * compatible_feature)*tf.transpose(G_t_inv * compatible_feature)
                                  / (1 - alpha + alpha * tf.transpose(compatible_feature) * G_t_inv * compatible_feature))
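    # G_t_inv_next is intended as a Sherman-Morrison-style recursion for the inverse
    # of the running compatible-feature matrix G_{t+1} = (1 - alpha) * G_t + alpha * phi phi^T,
    # so the inverse can be tracked without an explicit matrix inversion at every step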

    # Train V function
    vf_lossandgrad = U.function([ob, td_v_target, lrmult],
                                vf_rac_losses + [U.flatgrad(vf_rac_loss, vf_final_var_list)])
    vf_adam = MpiAdam(vf_final_var_list, epsilon = adam_epsilon)

    # Train Policy
    pol_lossandgrad = U.function([ob, ac, atarg, lrmult],
                                 pol_rac_losses + [U.flatgrad(pol_rac_loss, pol_final_var_list)])
    pol_adam = MpiAdam(pol_final_var_list, epsilon = adam_epsilon)

    lossandgrad = U.function([ob, ac, atarg, ret, lrmult], losses + [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon = adam_epsilon)

    assign_old_eq_new = U.function([], [], updates = [tf.assign(oldv, newv)
                                                      for (oldv, newv) in
                                                      zipsame(oldpi.get_variables(), pi.get_variables())])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)

    compute_v_pred = U.function([ob], [pi.vpred])
    get_pol_weights_num = np.sum([np.prod(v.get_shape().as_list()) for v in pol_final_var_list])
    get_compatible_feature = U.function([ob, ac], [compatible_feature])
    get_G_t_inv = U.function([ob, ac, G_t_inv, alpha], [G_t_inv_next])

    U.initialize()
    adam.sync()
    pol_adam.sync()
    vf_adam.sync()

    global timesteps_so_far, episodes_so_far, iters_so_far, \
        tstart, lenbuffer, rewbuffer, best_fitness
    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen = 100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen = 100)  # rolling buffer for episode rewards

    assert sum([max_iters > 0, max_timesteps > 0, max_episodes > 0,
                max_seconds > 0]) == 1, "Only one time constraint permitted"

    seg = None
    omega_t = np.random.rand(get_pol_weights_num)
    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break
        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult =  max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % iters_so_far)

        t = 0
        ac = env.action_space.sample()  # not used, just so we have the datatype
        new = True  # marks if we're on first timestep of an episode
        ob = env.reset()

        cur_ep_ret = 0  # return in current episode
        cur_ep_len = 0  # len of current episode
        ep_rets = []  # returns of completed episodes in this segment
        ep_lens = []  # lengths of ...
        horizon = timesteps_per_actorbatch

        # Initialize history arrays
        obs = np.array([ob for _ in range(horizon)])
        rews = np.zeros(horizon, 'float32')
        vpreds = np.zeros(horizon, 'float32')
        news = np.zeros(horizon, 'int32')
        acs = np.array([ac for _ in range(horizon)])
        prevacs = acs.copy()

        rac_alpha = optim_stepsize * cur_lrmult * 0.1
        rac_beta = optim_stepsize * cur_lrmult * 0.001

        k = 1.0
        G_t_inv = [k * np.eye(get_pol_weights_num)]
        assign_old_eq_new()
        while True:
            if timesteps_so_far % 10000 == 0 and timesteps_so_far > 0:
                result_record()
            prevac = ac
            ac, vpred = pi.act(stochastic = True, ob = ob)
            # Slight weirdness here because we need value function at time T
            # before returning segment [0, T-1] so we get the correct
            # terminal value
            if t > 0 and t % horizon == 0:
                seg = {"ob": obs, "rew": rews, "vpred": vpreds, "new": news,
                       "ac": acs, "prevac": prevacs, "nextvpred": vpred * (1 - new),
                       "ep_rets": ep_rets, "ep_lens": ep_lens}
                ep_rets = []
                ep_lens = []
                break
            i = t % horizon
            obs[i] = ob
            vpreds[i] = vpred
            news[i] = new
            acs[i] = ac
            prevacs[i] = prevac
            if env.spec._env_name == "LunarLanderContinuous":
                ac = np.clip(ac, -1.0, 1.0)
            next_ob, rew, new, _ = env.step(ac)
            # Compute v target and TD
            v_target = rew + gamma * np.array(compute_v_pred(next_ob.reshape((1, ob.shape[0]))))
            adv = v_target - np.array(compute_v_pred(ob.reshape((1, ob.shape[0]))))
            G_t_inv = get_G_t_inv(ob.reshape((1, ob.shape[0])), ac.reshape((1, ac.shape[0])), G_t_inv[0], np.array([rac_alpha]))
            # Update V and Update Policy
            vf_loss, vf_g = vf_lossandgrad(ob.reshape((1, ob.shape[0])), v_target,
                                           rac_alpha)
            vf_adam.update(vf_g, rac_alpha)
            pol_loss, pol_g = pol_lossandgrad(ob.reshape((1, ob.shape[0])), ac.reshape((1, ac.shape[0])), adv.reshape(adv.shape[0], ),
                                              rac_beta)
            compatible_feature = np.array(
                get_compatible_feature(ob.reshape((1, ob.shape[0])), ac.reshape((1, ac.shape[0]))))
            compatible_feature_product = compatible_feature * compatible_feature.T
            omega_t = ((np.eye(compatible_feature_product.shape[0])
                        - 0.1 * rac_alpha * compatible_feature_product).dot(omega_t)
                       + 0.1 * rac_alpha * G_t_inv[0].dot(pol_g))

            pol_adam.update(omega_t, rac_beta)

            rews[i] = rew

            cur_ep_ret += rew
            cur_ep_len += 1
            timesteps_so_far += 1
            ob = next_ob
            if new:
                # print(
                #     "Episode {} - Total reward = {}, Total Steps = {}".format(episodes_so_far, cur_ep_ret, cur_ep_len))
                ep_rets.append(cur_ep_ret)
                ep_lens.append(cur_ep_len)
                rewbuffer.extend(ep_rets)
                lenbuffer.extend(ep_lens)
                cur_ep_ret = 0
                cur_ep_len = 0
                ob = env.reset()
                episodes_so_far += 1
            t += 1

        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]
        vpredbefore = seg["vpred"] # predicted value function before update
        atarg = (atarg - atarg.mean()) / atarg.std() # standardized advantage function estimate
        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret), shuffle=not pi.recurrent)
        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"): pi.ob_rms.update(ob) # update running mean/std for policy

        assign_old_eq_new() # set old parameter values to new parameter values
        # logger.log("Optimizing...")
        # logger.log(fmt_row(13, loss_names))
        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            losses = [] # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"], batch["atarg"], batch["vtarg"], cur_lrmult)
                adam.update(g, optim_stepsize * cur_lrmult)
                losses.append(newlosses)
        # logger.log(fmt_row(13, np.mean(losses, axis=0)))
        # logger.log("Current Iteration Training Performance:" + str(np.mean(seg["ep_rets"])))
        if iters_so_far == 0:
            result_record()
        iters_so_far += 1
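
The loop above hands its segment dict to add_vtarg_and_adv, which is not shown in this snippet. For reference, the sketch below reconstructs the usual GAE(lambda) computation such a helper performs, assuming the segment keys match the rollout buffers built above (rew, vpred, new, nextvpred); the name add_vtarg_and_adv_sketch and its exact body are illustrative, not taken from this code.

import numpy as np

def add_vtarg_and_adv_sketch(seg, gamma, lam):
    # GAE(lambda): delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_{t+1}) - V(s_t),
    # adv_t = sum_k (gamma * lam)^k * delta_{t+k}, computed backwards in one pass.
    new = np.append(seg["new"], 0)                     # episode-start flags, shifted view
    vpred = np.append(seg["vpred"], seg["nextvpred"])  # bootstrap with the extra value
    T = len(seg["rew"])
    adv = np.zeros(T, "float32")
    lastgaelam = 0.0
    for t in reversed(range(T)):
        nonterminal = 1 - new[t + 1]
        delta = seg["rew"][t] + gamma * vpred[t + 1] * nonterminal - vpred[t]
        adv[t] = lastgaelam = delta + gamma * lam * nonterminal * lastgaelam
    seg["adv"] = adv
    seg["tdlamret"] = adv + seg["vpred"]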
Example #16
def create_graph(env,
                 pi_name,
                 policy_func,
                 *,
                 clip_param,
                 entcoeff,
                 adam_epsilon=1e-5):
    # Setup losses and stuff
    # ----------------------------------------
    with tf.name_scope(pi_name):
        ob_space = env.observation_space
        ac_space = env.action_space
        pi = policy_func("pi", pi_name, ob_space,
                         ac_space)  # Construct network for new policy
        oldpi = policy_func("oldpi", pi_name, ob_space,
                            ac_space)  # Network for old policy
        atarg = tf.placeholder(
            dtype=tf.float32,
            shape=[None])  # Target advantage function (if applicable)
        ret = tf.placeholder(dtype=tf.float32,
                             shape=[None])  # Empirical return

        lrmult = tf.placeholder(
            name='lrmult', dtype=tf.float32,
            shape=[])  # learning rate multiplier, updated with schedule
        clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

        ob = U.get_placeholder_cached(name="ob" + pi_name)
        ac = pi.pdtype.sample_placeholder([None])

        kloldnew = oldpi.pd.kl(pi.pd)
        ent = pi.pd.entropy()
        meankl = U.mean(kloldnew)
        meanent = U.mean(ent)
        pol_entpen = (-entcoeff) * meanent

        ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
        surr1 = ratio * atarg  # surrogate from conservative policy iteration
        surr2 = U.clip(ratio, 1.0 - clip_param, 1.0 + clip_param) * atarg  #
        pol_surr = -U.mean(tf.minimum(
            surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)
        vf_loss = U.mean(tf.square(pi.vpred - ret))
        total_loss = pol_surr + pol_entpen + vf_loss
        losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
        loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

        var_list = pi.get_trainable_variables()
        lossandgrad = U.function([ob, ac, atarg, ret, lrmult],
                                 losses + [U.flatgrad(total_loss, var_list)])

        adam = MpiAdam(var_list, epsilon=adam_epsilon)

        assign_old_eq_new = U.function(
            [], [],
            updates=[
                tf.assign(oldv, newv) for (
                    oldv,
                    newv) in zipsame(oldpi.get_variables(), pi.get_variables())
            ])
        compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)

    return pi, oldpi, loss_names, lossandgrad, adam, assign_old_eq_new, compute_losses
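
create_graph above encodes PPO's clipped objective symbolically. As a quick numeric illustration of the same L^CLIP term, here is a small NumPy sketch; clipped_surrogate_sketch and the toy numbers are invented for this illustration only.

import numpy as np

def clipped_surrogate_sketch(logp_new, logp_old, atarg, clip_param):
    # ratio = pi_new(a|s) / pi_old(a|s); take the pessimistic minimum of the
    # unclipped and clipped surrogates, negated so it can be minimized.
    ratio = np.exp(logp_new - logp_old)
    surr1 = ratio * atarg
    surr2 = np.clip(ratio, 1.0 - clip_param, 1.0 + clip_param) * atarg
    return -np.mean(np.minimum(surr1, surr2))

# A ratio of 1.5 with a positive advantage is capped at 1 + clip_param = 1.2:
print(clipped_surrogate_sketch(np.log([1.5]), np.log([1.0]), np.array([2.0]), 0.2))  # -2.4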
Example #17
def learn(
        env,
        policy_func,
        *,
        timesteps_per_actorbatch,  # timesteps per actor per update
        clip_param,
        entcoeff,  # clipping parameter epsilon, entropy coeff
        optim_epochs,
        optim_stepsize,
        optim_batchsize,  # optimization hypers
        gamma,
        lam,  # advantage estimation
        max_timesteps=0,
        max_episodes=0,
        max_iters=0,
        max_seconds=0,  # time constraint
        callback=None,  # you can do anything in the callback, since it takes locals(), globals()
        adam_epsilon=1e-5,
        schedule='constant',  # annealing for stepsize parameters (epsilon and adam)
        seed=1
):

    # We want to log:
    num_options = 1  # Hacky solution -> lets us reuse the same logging code
    epoch = -1
    saves = True
    wsaves = True

    ### Book-keeping
    gamename = env.spec.id[:-3].lower()
    gamename += 'seed' + str(seed)
    #gamename += app
    version_name = 'officialPPO'

    dirname = '{}_{}_{}opts_saves/'.format(version_name, gamename, num_options)

    # retrieve everything using relative paths. Create a train_results folder where the repo has been cloned
    dirname_rel = os.path.dirname(__file__)
    splitted = dirname_rel.split("/")
    dirname_rel = "/".join(splitted[:len(splitted) - 3]) + "/"
    dirname = dirname_rel + "train_results/" + dirname

    # Specify the paths where results shall be written to:
    src_code_path = dirname_rel + 'baselines/baselines/ppo1/'
    results_path = dirname_rel
    envs_path = dirname_rel + 'nfunk/envs_nf/'

    print(dirname)
    #input ("wait here after dirname")

    if wsaves:
        first = True
        if not os.path.exists(dirname):
            os.makedirs(dirname)
            first = False
        # while os.path.exists(dirname) and first:
        #     dirname += '0'

        files = ['pposgd_simple.py', 'mlp_policy.py', 'run_mujoco.py']
        first = True
        for i in range(len(files)):
            src = os.path.join(src_code_path) + files[i]
            print(src)
            #dest = os.path.join('/home/nfunk/results_NEW/ppo1/') + dirname
            dest = dirname + "src_code/"
            if (first):
                os.makedirs(dest)
                first = False
            print(dest)
            shutil.copy2(src, dest)
        # brute force copy normal env file at end of copying process:
        env_files = ['pendulum_nf.py']
        for i in range(len(env_files)):
            src = os.path.join(envs_path + env_files[i])
            shutil.copy2(src, dest)
        os.makedirs(dest + "assets/")
        src = os.path.join(envs_path + "assets/clockwise.png")
        shutil.copy2(src, dest + "assets/")

    np.random.seed(seed)
    tf.set_random_seed(seed)
    env.seed(seed)

    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_func("pi", ob_space,
                     ac_space)  # Construct network for new policy
    oldpi = policy_func("oldpi", ob_space, ac_space)  # Network for old policy
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = U.mean(kloldnew)
    meanent = U.mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = U.clip(ratio, 1.0 - clip_param, 1.0 + clip_param) * atarg  #
    pol_surr = -U.mean(tf.minimum(
        surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)
    vf_loss = U.mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult],
                             losses + [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)

    U.initialize()
    adam.sync()

    saver = tf.train.Saver(max_to_keep=10000)
    saver_best = tf.train.Saver(max_to_keep=1)

    ### More book-keeping
    results = []
    if saves:
        results = open(
            dirname + version_name + '_' + gamename + '_' + str(num_options) +
            'opts_' + '_results.csv', 'w')
        results_best_model = open(
            dirname + version_name + '_' + gamename + '_' + str(num_options) +
            'opts_' + '_bestmodel.csv', 'w')

        out = 'epoch,avg_reward'

        for opt in range(num_options):
            out += ',option {} dur'.format(opt)
        for opt in range(num_options):
            out += ',option {} std'.format(opt)
        for opt in range(num_options):
            out += ',option {} term'.format(opt)
        for opt in range(num_options):
            out += ',option {} adv'.format(opt)
        out += '\n'
        results.write(out)
        # results.write('epoch,avg_reward,option 1 dur, option 2 dur, option 1 term, option 2 term\n')
        results.flush()

    if epoch >= 0:

        print("Loading weights from iteration: " + str(epoch))

        filename = dirname + '{}_epoch_{}.ckpt'.format(gamename, epoch)
        saver.restore(U.get_session(), filename)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_actorbatch,
                                     stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    max_val = -100000
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards

    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % iters_so_far)

        seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update
        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate
        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                    shuffle=not pi.recurrent)
        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        assign_old_eq_new()  # set old parameter values to new parameter values

        if iters_so_far % 50 == 0 and wsaves:
            print("weights are saved...")
            filename = dirname + '{}_epoch_{}.ckpt'.format(
                gamename, iters_so_far)
            save_path = saver.save(U.get_session(), filename)

        if (np.mean(rewbuffer) > max_val) and wsaves:
            max_val = np.mean(rewbuffer)
            results_best_model.write('epoch: ' + str(iters_so_far) + ', rew: ' +
                                     str(np.mean(rewbuffer)) + '\n')
            results_best_model.flush()
            filename = dirname + 'best.ckpt'
            save_path = saver_best.save(U.get_session(), filename)

        logger.log("Optimizing...")
        logger.log(fmt_row(13, loss_names))
        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"],
                                            batch["atarg"], batch["vtarg"],
                                            cur_lrmult)
                adam.update(g, optim_stepsize * cur_lrmult)
                losses.append(newlosses)
            logger.log(fmt_row(13, np.mean(losses, axis=0)))

        logger.log("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob"], batch["ac"],
                                       batch["atarg"], batch["vtarg"],
                                       cur_lrmult)
            losses.append(newlosses)
        meanlosses, _, _ = mpi_moments(losses, axis=0)
        logger.log(fmt_row(13, meanlosses))
        for (lossval, name) in zipsame(meanlosses, loss_names):
            logger.record_tabular("loss_" + name, lossval)
        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))
        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.dump_tabular()

        ### Book keeping
        if saves:
            out = "{},{}"
            #for _ in range(num_options): #out+=",{},{},{},{}"
            out += "\n"

            info = [iters_so_far, np.mean(rewbuffer)]

            results.write(out.format(*info))
            results.flush()
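
The schedule handling in learn() above reduces to a scalar multiplier applied to the Adam step (and, in the other examples, to the clip range). A minimal standalone sketch of that logic; the function name lrmult_sketch is invented here.

def lrmult_sketch(schedule, timesteps_so_far, max_timesteps):
    # 'constant' keeps optim_stepsize fixed; 'linear' anneals it to zero over
    # max_timesteps, exactly as the branch inside the training loop above does.
    if schedule == 'constant':
        return 1.0
    if schedule == 'linear':
        return max(1.0 - float(timesteps_so_far) / max_timesteps, 0.0)
    raise NotImplementedError(schedule)

# e.g. halfway through training with the linear schedule:
assert lrmult_sketch('linear', 500_000, 1_000_000) == 0.5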
Example #18
def enjoy(
        env,
        policy_func,
        *,
        timesteps_per_actorbatch,  # timesteps per actor per update
        clip_param,
        entcoeff,  # clipping parameter epsilon, entropy coeff
        optim_epochs,
        optim_stepsize,
        optim_batchsize,  # optimization hypers
        gamma,
        lam,  # advantage estimation
        max_timesteps=0,
        max_episodes=0,
        max_iters=0,
        max_seconds=0,  # time constraint
        callback=None,  # you can do anything in the callback, since it takes locals(), globals()
        adam_epsilon=1e-5,
        schedule='constant',  # annealing for stepsize parameters (epsilon and adam)
        save_name=None,
        save_per_acts=3,
        sensor=False,
        reload_name=None,
        target_pos=None):
    if sensor:
        ob_space = env.sensor_space
    else:
        ob_space = env.observation_space
    ac_space = env.action_space

    pi = policy_func("pi", ob_space,
                     ac_space)  # Construct network for new policy
    oldpi = policy_func("oldpi", ob_space, ac_space)  # Network for old policy
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return
    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule
    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = tf.clip_by_value(ratio, 1.0 - clip_param,
                             1.0 + clip_param) * atarg  #
    pol_surr = -tf.reduce_mean(tf.minimum(
        surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)
    vf_loss = tf.reduce_mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult],
                             losses + [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)

    U.initialize()
    adam.sync()

    if reload_name:
        saver = tf.train.Saver()
        saver.restore(tf.get_default_session(), reload_name)
        print("Loaded model successfully.")

    # from IPython import embed; embed()
    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_actorbatch,
                                     stochastic=True,
                                     sensor=sensor)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards

    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % (iters_so_far + 1))

        seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update
        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate
        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                    shuffle=not pi.recurrent)
        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        assign_old_eq_new()  # set old parameter values to new parameter values
        logger.log("Optimizing...")
        logger.log(fmt_row(13, loss_names))
        iters_so_far += 1
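
enjoy() only rolls the policy out via traj_segment_generator, whose definition is not included in these snippets. The generator below is a rough reconstruction modeled on the inline rollout loop of the first example above (same buffer names, same nextvpred bootstrapping); the sensor argument and prevac bookkeeping are omitted, and the name traj_segment_generator_sketch marks it as an assumption rather than the actual implementation.

import numpy as np

def traj_segment_generator_sketch(pi, env, horizon, stochastic=True):
    # Yield fixed-length segments of experience plus per-episode statistics.
    ob = env.reset()
    ac = env.action_space.sample()   # only used to size the action buffer
    new = True
    cur_ep_ret, cur_ep_len = 0.0, 0
    ep_rets, ep_lens = [], []
    obs = np.array([ob for _ in range(horizon)])
    rews = np.zeros(horizon, "float32")
    vpreds = np.zeros(horizon, "float32")
    news = np.zeros(horizon, "int32")
    acs = np.array([ac for _ in range(horizon)])
    t = 0
    while True:
        ac, vpred = pi.act(stochastic=stochastic, ob=ob)
        if t > 0 and t % horizon == 0:
            # bootstrap the value of the final state unless the episode just ended
            yield {"ob": obs, "rew": rews, "vpred": vpreds, "new": news, "ac": acs,
                   "nextvpred": vpred * (1 - new),
                   "ep_rets": ep_rets, "ep_lens": ep_lens}
            ep_rets, ep_lens = [], []
        i = t % horizon
        obs[i], vpreds[i], news[i], acs[i] = ob, vpred, new, ac
        ob, rew, new, _ = env.step(ac)
        rews[i] = rew
        cur_ep_ret += rew
        cur_ep_len += 1
        if new:
            ep_rets.append(cur_ep_ret)
            ep_lens.append(cur_ep_len)
            cur_ep_ret, cur_ep_len = 0.0, 0
            ob = env.reset()
        t += 1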
Example #19
def learn(
        env,
        policy_func,
        *,
        timesteps_per_actorbatch,  # timesteps per actor per update
        clip_param,
        entcoeff,  # clipping parameter epsilon, entropy coeff
        vfcoeff,
        optim_epochs,
        optim_stepsize,
        optim_batchsize,  # optimization hypers
        gamma,
        lam,  # advantage estimation
        max_timesteps=0,
        max_episodes=0,
        max_iters=0,
        max_seconds=0,  # time constraint
        callback=None,  # you can do anything in the callback, since it takes locals(), globals()
        adam_epsilon=1e-5,
        schedule='constant',  # annealing for stepsize parameters (epsilon and adam)
        sensor=False,
        save_name=None,
        save_per_acts=3,
        reload_name=None):
    # Setup losses and stuff
    # ----------------------------------------
    if sensor:
        ob_space = env.sensor_space
    else:
        ob_space = env.observation_space
    ac_space = env.action_space

    pi = policy_func("pi", ob_space,
                     ac_space)  # Construct network for new policy
    oldpi = policy_func("oldpi", ob_space, ac_space)  # Network for old policy
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return
    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule
    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = tf.clip_by_value(ratio, 1.0 - clip_param,
                             1.0 + clip_param) * atarg  #
    pol_surr = -tf.reduce_mean(tf.minimum(
        surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)
    vf_loss = vfcoeff * tf.reduce_mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult],
                             losses + [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)

    U.initialize()
    adam.sync()

    if reload_name:
        saver = tf.train.Saver()
        saver.restore(tf.get_default_session(), reload_name)
        print("Loaded model successfully.")

    # from IPython import embed; embed()
    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_actorbatch,
                                     stochastic=True,
                                     sensor=sensor)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards

    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % (iters_so_far + 1))

        seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update
        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate
        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                    shuffle=not pi.recurrent)
        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        assign_old_eq_new()  # set old parameter values to new parameter values
        logger.log("Optimizing...")
        logger.log(fmt_row(13, loss_names))
        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"],
                                            batch["atarg"], batch["vtarg"],
                                            cur_lrmult)
                adam.update(g, optim_stepsize * cur_lrmult)
                losses.append(newlosses)
            logger.log(fmt_row(13, np.mean(losses, axis=0)))

        logger.log("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob"], batch["ac"],
                                       batch["atarg"], batch["vtarg"],
                                       cur_lrmult)
            losses.append(newlosses)
        meanlosses, _, _ = mpi_moments(losses, axis=0)
        logger.log(fmt_row(13, meanlosses))
        for (lossval, name) in zipsame(meanlosses, loss_names):
            logger.record_tabular("loss_" + name, lossval)
        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))
        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        elapse = time.time() - tstart
        logger.record_tabular("TimeElapsed", elapse)

        #Iteration Recording
        record = 1
        if record:
            file_path = os.path.join(
                os.path.expanduser("~"),
                "PycharmProjects/Gibson_Exercise/gibson/utils/models/iterations"
            )
            try:
                os.mkdir(file_path)
            except OSError:
                pass

            if iters_so_far == 1:
                with open(os.path.join(
                        os.path.expanduser("~"),
                        'PycharmProjects/Gibson_Exercise/gibson/utils/models/iterations/values.csv'
                ),
                          'w',
                          newline='') as csvfile:
                    fieldnames = [
                        'Iteration', 'TimeSteps', 'Reward', 'LossEnt',
                        'LossVF', 'PolSur'
                    ]
                    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

                    writer.writeheader()
                    writer.writerow({
                        'Iteration': iters_so_far,
                        'TimeSteps': timesteps_so_far,
                        'Reward': np.mean(rews),
                        'LossEnt': meanlosses[4],
                        'LossVF': meanlosses[2],
                        'PolSur': meanlosses[1]
                    })
            else:
                with open(os.path.join(
                        os.path.expanduser("~"),
                        'PycharmProjects/Gibson_Exercise/gibson/utils/models/iterations/values.csv'
                ),
                          'a',
                          newline='') as csvfile:
                    fieldnames = [
                        'Iteration', 'TimeSteps', 'Reward', 'LossEnt',
                        'LossVF', 'PolSur'
                    ]
                    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

                    writer.writerow({
                        'Iteration': iters_so_far,
                        'TimeSteps': timesteps_so_far,
                        'Reward': np.mean(rews),
                        'LossEnt': meanlosses[4],
                        'LossVF': meanlosses[2],
                        'PolSur': meanlosses[1]
                    })

        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.dump_tabular()

        load_number = 0
        if reload_name is not None:
            load_number = int(str(reload_name.split('_')[7]).split('.')[0])

        if save_name and (iters_so_far % save_per_acts == 0):
            base_path = os.path.dirname(os.path.abspath(__file__))
            print(base_path)
            out_name = os.path.join(
                base_path, 'models',
                save_name + '_' + str(iters_so_far + load_number) + ".model")
            U.save_state(out_name)
            print("Saved model successfully.")
Example #20
def learn(env, policy_fn, *,
          timesteps_per_actorbatch,  # timesteps per actor per update
          clip_param, entcoeff,  # clipping parameter epsilon, entropy coeff
          optim_epochs, optim_stepsize, optim_batchsize,  # optimization hypers
          gamma, lam,  # advantage estimation
          max_timesteps=0, max_episodes=0, max_iters=0, max_seconds=0,  # time constraint
          callback=None,  # you can do anything in the callback, since it takes locals(), globals()
          adam_epsilon=1e-5,
          schedule='constant',  # annealing for stepsize parameters (epsilon and adam)
          gae_kstep=None,
          env_eval=None,
          saved_model=None,
          eval_at=50,
          save_at=50,
          normalize_atarg=True,
          experiment_spec=None,  # dict with: experiment_name, experiment_folder
          **extra_args
          ):
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_fn("pi", ob_space, ac_space)  # Construct network for new policy
    oldpi = policy_fn("oldpi", ob_space, ac_space)  # Network for old policy
    atarg = tf.placeholder(dtype=tf.float32, shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    lrmult = tf.placeholder(name='lrmult', dtype=tf.float32, shape=[])  # learning rate multiplier, updated with schedule
    entromult = tf.placeholder(name='entromult', dtype=tf.float32, shape=[])  # entropy penalty multiplier, updated with schedule
    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    pol_entpen = (-entcoeff) * meanent

    MPI_n_ranks = MPI.COMM_WORLD.Get_size()

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = tf.clip_by_value(ratio, 1.0 - clip_param, 1.0 + clip_param) * atarg  #
    pol_surr = - tf.reduce_mean(tf.minimum(surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)
    vf_loss = tf.reduce_mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult, entromult], losses + [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    assign_old_eq_new = U.function([], [], updates=[tf.assign(oldv, newv)
                                                    for (oldv, newv) in zipsame(oldpi.get_variables(), pi.get_variables())])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult, entromult], losses)

    acts = list()
    stats1 = list()
    stats2 = list()
    stats3 = list()
    for p in range(ac_space.shape[0]):
        acts.append(tf.placeholder(tf.float32, name="act_{}".format(p + 1)))
        stats1.append(tf.placeholder(tf.float32, name="stats1_{}".format(p + 1)))
        stats2.append(tf.placeholder(tf.float32, name="stats2_{}".format(p + 1)))
        stats3.append(tf.placeholder(tf.float32, name="stats3_{}".format(p + 1)))
        tf.summary.histogram("act_{}".format(p), acts[p])
        if pi.dist == 'gaussian':
            tf.summary.histogram("pd_mean_{}".format(p), stats1[p])
            tf.summary.histogram("pd_std_{}".format(p), stats2[p])
            tf.summary.histogram("pd_logstd_{}".format(p), stats3[p])
        else:
            tf.summary.histogram("pd_beta_{}".format(p), stats1[p])
            tf.summary.histogram("pd_alpha_{}".format(p), stats2[p])
            tf.summary.histogram("pd_alpha_beta_{}".format(p), stats3[p])

    rew = tf.placeholder(tf.float32, name="rew")
    tf.summary.histogram("rew", rew)
    summaries = tf.summary.merge_all()
    gather_summaries = U.function([ob, *acts, *stats1, *stats2, *stats3, rew], summaries)

    U.initialize()
    adam.sync()
    if saved_model is not None:
        U.load_state(saved_model)

    if (MPI.COMM_WORLD.Get_rank() == 0) and (experiment_spec is not None):
        # TensorBoard & Saver
        # ----------------------------------------
        if experiment_spec['experiment_folder'] is not None:
            path_tb = os.path.join(experiment_spec['experiment_folder'], 'tensorboard')
            path_logs = os.path.join(experiment_spec['experiment_folder'], 'logs')
            exp_name = '' if experiment_spec['experiment_name'] is None else experiment_spec['experiment_name']
            summary_file = tf.summary.FileWriter(os.path.join(path_tb, exp_name), U.get_session().graph)
            saver = tf.train.Saver(max_to_keep=None)
            logger.configure(dir=os.path.join(path_logs, exp_name))
    else:
        logger.configure(format_strs=[])

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi, env, timesteps_per_actorbatch, stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards

    assert sum([max_iters > 0, max_timesteps > 0, max_episodes > 0, max_seconds > 0]) == 1, "Only one time constraint permitted"

    while True:
        if callback:
            callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(iters_so_far) / max_iters, 0)
        elif 'exp' in schedule:
            current_lr = schedule.strip()
            _, d = current_lr.split('__')
            cur_lrmult = float(d) ** (float(iters_so_far) / max_iters)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % iters_so_far)

        seg = seg_gen.__next__()
        if gae_kstep is None:
            add_vtarg_and_adv(seg, gamma, lam)
            T = len(seg["rew"])
        else:
            calculate_advantage_and_vtarg(seg, gamma, lam, k_step=gae_kstep)
            T = len(seg["rew"]) - gae_kstep

        ob, ac, atarg, tdlamret = seg["ob"][:T], seg["ac"][:T], seg["adv"][:T], seg["tdlamret"][:T]
        vpredbefore = seg["vpred"][:T]  # predicted value function before update

        if normalize_atarg:
            eps = 1e-9
            atarg = (atarg - atarg.mean()) / (atarg.std() + eps)  # standardized advantage function estimate

        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret), shuffle=not pi.recurrent)
        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        assign_old_eq_new()  # set old parameter values to new parameter values
        logger.log("Optimizing...")
        logger.log(fmt_row(13, loss_names))
        # Here we do a bunch of optimization epochs over the data
        g_max = -np.Inf
        g_min = np.Inf
        g_mean = []
        for _ in range(optim_epochs):
            losses = []  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"], batch["atarg"], batch["vtarg"], cur_lrmult, cur_lrmult)
                g_max = g.max() if g.max() > g_max else g_max
                g_min = g.min() if g.min() < g_min else g_min
                g_mean.append(g.mean())
                if np.isnan(np.sum(g)):
                    print('NaN in Gradient, skipping this update')
                    continue
                adam.update(g, optim_stepsize * cur_lrmult)
                losses.append(newlosses)
                # logger.log(fmt_row(13, np.mean(losses, axis=0)))

        summary = tf.Summary()
        logger.log("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob"], batch["ac"], batch["atarg"], batch["vtarg"], cur_lrmult, cur_lrmult)
            losses.append(newlosses)
        meanlosses, _, _ = mpi_moments(losses, axis=0)
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.log(fmt_row(13, meanlosses))
            for (lossval, name) in zipsame(meanlosses, loss_names):
                logger.record_tabular("loss_" + name, lossval)
                summary.value.add(tag="loss_" + name, simple_value=lossval)

        logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))
        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("ItersSoFar (%)", iters_so_far / max_iters * 100)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        logger.record_tabular("TimePerIter", (time.time() - tstart) / (iters_so_far + 1))

        if MPI.COMM_WORLD.Get_rank() == 0:
            # Saves model
            if ((iters_so_far % save_at) == 0) and (iters_so_far != 0):
                if experiment_spec['experiment_folder'] is not None:
                    path_models = os.path.join(experiment_spec['experiment_folder'], 'models')
                    dir_path = os.path.join(path_models, exp_name)
                    if not os.path.exists(dir_path):
                        os.makedirs(dir_path)
                    saver.save(U.get_session(), os.path.join(dir_path, 'model'), global_step=iters_so_far)

        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.dump_tabular()

            summ = gather_summaries(ob,
                                    *np.split(ac, ac_space.shape[0], axis=1),
                                    *np.split(seg['stat1'], ac_space.shape[0], axis=1),
                                    *np.split(seg['stat2'], ac_space.shape[0], axis=1),
                                    *np.split(seg['stat3'], ac_space.shape[0], axis=1),
                                    seg['rew'])

            summary.value.add(tag="total_loss", simple_value=meanlosses[:3].sum())
            summary.value.add(tag="explained_variance", simple_value=explained_variance(vpredbefore, tdlamret))
            summary.value.add(tag='EpRewMean', simple_value=np.mean(rewbuffer))
            summary.value.add(tag='EpLenMean', simple_value=np.mean(lenbuffer))
            summary.value.add(tag='EpThisIter', simple_value=len(lens))
            summary.value.add(tag='atarg_max', simple_value=atarg.max())
            summary.value.add(tag='atarg_min', simple_value=atarg.min())
            summary.value.add(tag='atarg_mean', simple_value=atarg.mean())
            summary.value.add(tag='GMean', simple_value=np.mean(g_mean))
            summary.value.add(tag='GMax', simple_value=g_max)
            summary.value.add(tag='GMin', simple_value=g_min)
            summary.value.add(tag='learning_rate', simple_value=cur_lrmult * optim_stepsize)
            summary.value.add(tag='AcMAX', simple_value=np.mean(seg["ac"].max()))
            summary.value.add(tag='AcMIN', simple_value=np.mean(seg["ac"].min()))
            summary_file.add_summary(summary, iters_so_far)
            summary_file.add_summary(summ, iters_so_far)

        iters_so_far += 1
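
Besides 'constant' and 'linear', this variant accepts schedule strings of the form 'exp__<base>' (see the 'exp' branch above). Here is a tiny sketch of that parsing, with an invented function name:

def exp_lrmult_sketch(schedule, iters_so_far, max_iters):
    # "exp__0.5" decays the multiplier geometrically from 1.0 toward the base
    # as iters_so_far approaches max_iters, mirroring the branch above.
    _, base = schedule.strip().split('__')
    return float(base) ** (float(iters_so_far) / max_iters)

assert exp_lrmult_sketch('exp__0.5', 0, 100) == 1.0
assert abs(exp_lrmult_sketch('exp__0.5', 100, 100) - 0.5) < 1e-12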
Example #21
    def _init(self,
              ob_space,
              ac_space,
              hid_size,
              num_hid_layers,
              gaussian_fixed_var=False,
              popart=True):
        assert isinstance(ob_space, gym.spaces.Box)

        self.pdtype = pdtype = make_pdtype(ac_space)

        ob = U.get_placeholder(name="ob",
                               dtype=tf.float32,
                               shape=[None] + list(ob_space.shape))

        with tf.variable_scope("obfilter"):
            self.ob_rms = RunningMeanStd(shape=ob_space.shape)

        with tf.variable_scope("popart"):
            self.v_rms = RunningMeanStd(shape=[1])

        obz = tf.clip_by_value((ob - self.ob_rms.mean) / self.ob_rms.std, -5.0,
                               5.0)
        last_out = obz
        for i in range(num_hid_layers):
            last_out = tf.nn.tanh(
                dense(last_out,
                      hid_size,
                      "vffc%i" % (i + 1),
                      weight_init=U.normc_initializer(1.0)))
        self.norm_vpred = dense(last_out,
                                1,
                                "vffinal",
                                weight_init=U.normc_initializer(1.0))[:, 0]
        if popart:
            self.vpred = denormalize(self.norm_vpred, self.v_rms)
        else:
            self.vpred = self.norm_vpred

        last_out = obz
        for i in range(num_hid_layers):
            last_out = tf.nn.tanh(
                dense(last_out,
                      hid_size,
                      "polfc%i" % (i + 1),
                      weight_init=U.normc_initializer(1.0)))

        if gaussian_fixed_var and isinstance(ac_space, gym.spaces.Box):
            mean = dense(last_out,
                         pdtype.param_shape()[0] // 2, "polfinal",
                         U.normc_initializer(0.01))
            logstd = tf.get_variable(name="logstd",
                                     shape=[1, pdtype.param_shape()[0] // 2],
                                     initializer=tf.zeros_initializer())
            pdparam = tf.concat([mean, mean * 0.0 + logstd], axis=1)
        else:
            pdparam = dense(last_out,
                            pdtype.param_shape()[0], "polfinal",
                            U.normc_initializer(0.01))

        self.pd = pdtype.pdfromflat(pdparam)

        self.state_in = []
        self.state_out = []

        # change for BC
        stochastic = U.get_placeholder(name="stochastic",
                                       dtype=tf.bool,
                                       shape=())
        ac = U.switch(stochastic, self.pd.sample(), self.pd.mode())
        self.ac = ac
        self._act = U.function([stochastic, ob], [ac, self.vpred])

        self.use_popart = popart
        if popart:
            self.init_popart()

        ret = tf.placeholder(tf.float32, [None])
        vferr = tf.reduce_mean(tf.square(self.vpred - ret))
        self.vlossandgrad = U.function([ob, ret],
                                       U.flatgrad(vferr,
                                                  self.get_vf_variable()))
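
The popart branch above predicts a normalized value (norm_vpred) and maps it back through the running return statistics in v_rms. The helpers below sketch what normalize/denormalize typically compute in this setting; the running-statistics object is replaced by plain mean/std arguments, and the adaptive output-layer rescaling performed by init_popart is not shown.

import numpy as np

def denormalize_sketch(norm_vpred, v_mean, v_std):
    # map the network's normalized prediction back to the raw return scale
    return norm_vpred * v_std + v_mean

def normalize_sketch(vtarget, v_mean, v_std, eps=1e-8):
    # map raw value targets into the normalized space the value head is trained in
    return (vtarget - v_mean) / (v_std + eps)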
Example #22
                count_in_var = 1
                for dim in vars.shape._dims:
                    count_in_var *= dim
                trainable_var_count += count_in_var

            with tf.variable_scope('Lossandgrads'):
                # losses
                regret_loss = tf.reduce_mean(tf.square(y - y_true))
                vf_loss = tf.square(scalar * ua.lipschitz_constant -
                                    args.Lipschitz)

                # gradients
                regret_lossandgrad = U.function(
                    [x],
                    [regret_loss,
                     U.flatgrad(regret_loss, trainable_vars)])
                vf_lossandgrad = U.function(
                    [x],
                    [vf_loss, U.flatgrad(vf_loss, trainable_vars)])

            # define our train operation using Adam optimizer
            adam_all = MpiAdam(trainable_vars, epsilon=1e-3)

    # create a Saver to save UA after training
    saver = tf.train.Saver()
    with U.make_session() as sess:
        # create a SummaryWritter to save data for TensorBoard
        result_folder = dir + '/results/' + args.output_file + str(
            int(time.time()))
        sw = tf.summary.FileWriter(result_folder, sess.graph)
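
Nearly every example above builds its gradient function with U.flatgrad. The sketch below shows the kind of graph-mode helper this usually is, written against tensorflow.compat.v1 since these snippets are TF1-style; treat flatgrad_sketch as an approximation of the utility, not its exact source.

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

def flatgrad_sketch(loss, var_list, clip_norm=None):
    # Differentiate loss w.r.t. every variable, optionally clip each gradient by
    # norm, then flatten and concatenate into one vector so an optimizer such as
    # MpiAdam can average and apply a single contiguous update across workers.
    grads = tf.gradients(loss, var_list)
    if clip_norm is not None:
        grads = [None if g is None else tf.clip_by_norm(g, clip_norm) for g in grads]
    return tf.concat([
        tf.reshape(g if g is not None else tf.zeros_like(v), [-1])
        for v, g in zip(var_list, grads)
    ], axis=0)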
Example #23
def learn(make_env, make_policy, *,
          n_episodes,
          horizon,
          delta,
          gamma,
          max_iters,
          sampler=None,
          use_natural_gradient=False, #can be 'exact', 'approximate'
          fisher_reg=1e-2,
          iw_method='is',
          iw_norm='none',
          bound='J',
          line_search_type='parabola',
          save_weights=False,
          improvement_tol=0.,
          center_return=False,
          render_after=None,
          max_offline_iters=100,
          callback=None):

    np.set_printoptions(precision=3)
    max_samples = horizon * n_episodes

    if line_search_type == 'binary':
        line_search = line_search_binary
    elif line_search_type == 'parabola':
        line_search = line_search_parabola
    else:
        raise ValueError()

    # Building the environment
    env = make_env()
    ob_space = env.observation_space
    ac_space = env.action_space

    # Building the policy
    pi = make_policy('pi', ob_space, ac_space)
    oldpi = make_policy('oldpi', ob_space, ac_space)

    all_var_list = pi.get_trainable_variables()
    var_list = [v for v in all_var_list if v.name.split('/')[1].startswith('pol')]

    shapes = [U.intprod(var.get_shape().as_list()) for var in var_list]
    n_parameters = sum(shapes)

    # Placeholders
    ob_ = ob = U.get_placeholder_cached(name='ob')
    ac_ = pi.pdtype.sample_placeholder([max_samples], name='ac')
    mask_ = tf.placeholder(dtype=tf.float32, shape=(max_samples), name='mask')
    disc_rew_ = tf.placeholder(dtype=tf.float32, shape=(max_samples), name='disc_rew')
    gradient_ = tf.placeholder(dtype=tf.float32, shape=(n_parameters, 1), name='gradient')

    # Policy densities
    target_log_pdf = pi.pd.logp(ac_)
    behavioral_log_pdf = oldpi.pd.logp(ac_)
    log_ratio = target_log_pdf - behavioral_log_pdf
    
    # Split operations
    disc_rew_split = tf.stack(tf.split(disc_rew_ * mask_, n_episodes))
    log_ratio_split = tf.stack(tf.split(log_ratio * mask_, n_episodes))
    target_log_pdf_split = tf.stack(tf.split(target_log_pdf * mask_, n_episodes))
    mask_split = tf.stack(tf.split(mask_, n_episodes))
    
    # Renyi divergence
    emp_d2_split = tf.stack(tf.split(pi.pd.renyi(oldpi.pd, 2) * mask_, n_episodes))
    emp_d2_cum_split = tf.reduce_sum(emp_d2_split, axis=1)
    empirical_d2 = tf.reduce_mean(tf.exp(emp_d2_cum_split))

    # Return
    ep_return = tf.reduce_sum(mask_split * disc_rew_split, axis=1)
    if center_return:
        ep_return = ep_return - tf.reduce_mean(ep_return)

    return_mean = tf.reduce_mean(ep_return)
    return_std = U.reduce_std(ep_return)
    return_max = tf.reduce_max(ep_return)
    return_min = tf.reduce_min(ep_return)
    return_abs_max = tf.reduce_max(tf.abs(ep_return))
    
    if iw_method == 'pdis':
        raise NotImplementedError()
    elif iw_method == 'is':
        iw = tf.exp(tf.reduce_sum(log_ratio_split, axis=1))
        if iw_norm == 'none':
            iwn = iw / n_episodes
            w_return_mean = tf.reduce_sum(iwn * ep_return)
        elif iw_norm == 'sn':
            iwn = iw / tf.reduce_sum(iw)
            w_return_mean = tf.reduce_sum(iwn * ep_return)
        elif iw_norm == 'regression':
            iwn = iw / n_episodes
            mean_iw = tf.reduce_mean(iw)
            beta = tf.reduce_sum((iw - mean_iw) * ep_return * iw) / (tf.reduce_sum((iw - mean_iw) ** 2) + 1e-24)
            w_return_mean = tf.reduce_mean(iw * ep_return - beta * (iw - 1))
        else:
            raise NotImplementedError()
        
        ess_classic = tf.linalg.norm(iw, 1) ** 2 / tf.linalg.norm(iw, 2) ** 2
        sqrt_ess_classic = tf.linalg.norm(iw, 1) / tf.linalg.norm(iw, 2)
        ess_renyi = n_episodes / empirical_d2
    else:
        raise NotImplementedError()
    
    if bound == 'J':
        bound_ = w_return_mean
    elif bound == 'std-d2':
        bound_ = w_return_mean - tf.sqrt((1 - delta) / (delta * ess_renyi)) * return_std
    elif bound == 'max-d2':
        bound_ = w_return_mean - tf.sqrt((1 - delta) / (delta * ess_renyi)) * return_abs_max
    elif bound == 'max-ess':
        bound_ = w_return_mean - tf.sqrt((1 - delta) / delta) / sqrt_ess_classic * return_abs_max
    elif bound == 'std-ess':
        bound_ = w_return_mean - tf.sqrt((1 - delta) / delta) / sqrt_ess_classic * return_std
    else:
        raise NotImplementedError()

    losses = [bound_, return_mean, return_max, return_min, return_std, empirical_d2, w_return_mean,
              tf.reduce_max(iwn), tf.reduce_min(iwn), tf.reduce_mean(iwn), U.reduce_std(iwn), tf.reduce_max(iw),
              tf.reduce_min(iw), tf.reduce_mean(iw), U.reduce_std(iw), ess_classic, ess_renyi]
    loss_names = ['Bound', 'InitialReturnMean', 'InitialReturnMax', 'InitialReturnMin', 'InitialReturnStd',
                  'EmpiricalD2', 'ReturnMeanIW', 'MaxIWNorm', 'MinIWNorm', 'MeanIWNorm', 'StdIWNorm',
                  'MaxIW', 'MinIW', 'MeanIW', 'StdIW', 'ESSClassic', 'ESSRenyi']

    if use_natural_gradient:
        p = tf.placeholder(dtype=tf.float32, shape=[None])
        target_logpdf_episode = tf.reduce_sum(target_log_pdf_split * mask_split, axis=1)
        grad_logprob = U.flatgrad(tf.stop_gradient(iwn) * target_logpdf_episode, var_list)
        dot_product = tf.reduce_sum(grad_logprob * p)
        hess_logprob = U.flatgrad(dot_product, var_list)
        compute_linear_operator = U.function([p, ob_, ac_, disc_rew_, mask_], [-hess_logprob])


    assign_old_eq_new = U.function([], [], updates=[tf.assign(oldv, newv)
                for (oldv, newv) in zipsame(oldpi.get_variables(), pi.get_variables())])
    
    compute_lossandgrad = U.function([ob_, ac_, disc_rew_, mask_], losses + [U.flatgrad(bound_, var_list)])
    compute_grad = U.function([ob_, ac_, disc_rew_, mask_], [U.flatgrad(bound_, var_list)])
    compute_bound = U.function([ob_, ac_, disc_rew_, mask_], [bound_])
    compute_losses = U.function([ob_, ac_, disc_rew_, mask_], losses)

    set_parameter = U.SetFromFlat(var_list)
    get_parameter = U.GetFlat(var_list)

    if sampler is None:
        seg_gen = traj_segment_generator(pi, env, n_episodes, horizon, stochastic=True)
        sampler = type("SequentialSampler", (object,), {"collect": lambda self, _: seg_gen.__next__()})()

    U.initialize()
    
    # Start optimizing
    
    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=n_episodes)
    rewbuffer = deque(maxlen=n_episodes)
    
    while True:

        iters_so_far += 1

        if render_after is not None and iters_so_far % render_after == 0:
            if hasattr(env, 'render'):
                render(env, pi, horizon)

        if callback:
            callback(locals(), globals())

        if iters_so_far >= max_iters:
            print('Finished...')
            break

        logger.log('********** Iteration %i ************' % iters_so_far)
        
        theta = get_parameter()
        print(theta)
        with timed('sampling'):
            seg = sampler.collect(theta)
        
        add_disc_rew(seg, gamma)

        lens, rets = seg['ep_lens'], seg['ep_rets']
        lenbuffer.extend(lens)
        rewbuffer.extend(rets)
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)

        args = ob, ac, disc_rew, mask = seg['ob'], seg['ac'], seg['disc_rew'], seg['mask']

        assign_old_eq_new()

        def evaluate_loss():
            loss = compute_bound(*args)
            return loss[0]

        def evaluate_gradient():
            gradient = compute_grad(*args)
            return gradient[0]

        if use_natural_gradient:
            def evaluate_fisher_vector_prod(x):
                return compute_linear_operator(x, *args)[0] + fisher_reg * x

            def evaluate_natural_gradient(g):
                return cg(evaluate_fisher_vector_prod, g, cg_iters=10, verbose=0)
        else:
            evaluate_natural_gradient = None

        with timed('summaries before'):
            logger.record_tabular("Iteration", iters_so_far)
            logger.record_tabular("InitialBound", evaluate_loss())
            logger.record_tabular("EpLenMean", np.mean(lenbuffer))
            logger.record_tabular("EpRewMean", np.mean(rewbuffer))
            logger.record_tabular("EpThisIter", len(lens))
            logger.record_tabular("EpisodesSoFar", episodes_so_far)
            logger.record_tabular("TimestepsSoFar", timesteps_so_far)
            logger.record_tabular("TimeElapsed", time.time() - tstart)

        if save_weights:
            logger.record_tabular('Weights', str(get_parameter()))

        with timed("offline optimization"):

            theta, improvement = optimize_offline(theta,
                                                  set_parameter,
                                                  line_search,
                                                  evaluate_loss,
                                                  evaluate_gradient,
                                                  evaluate_natural_gradient,
                                                  max_offline_ite=max_offline_iters)

        set_parameter(theta)

        with timed('summaries after'):
            meanlosses = np.array(compute_losses(*args))
            for (lossname, lossval) in zip(loss_names, meanlosses):
                logger.record_tabular(lossname, lossval)

        logger.dump_tabular()

    env.close()
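
The 'is' branch above estimates the return with per-episode importance weights and reports effective-sample-size diagnostics. Below is a minimal NumPy sketch of those quantities (illustrative only; the array names are hypothetical and not taken from the function above):

import numpy as np

def is_diagnostics(target_logp, behavioral_logp, ep_return):
    # target_logp, behavioral_logp: (n_episodes, horizon) masked per-step log-pdfs
    # ep_return: (n_episodes,) episode returns
    n_episodes = ep_return.shape[0]
    # Per-episode importance weight: product of per-step likelihood ratios
    iw = np.exp((target_logp - behavioral_logp).sum(axis=1))
    # Plain and self-normalized IS estimates of the expected return
    is_estimate = np.sum(iw * ep_return) / n_episodes
    sn_estimate = np.sum((iw / iw.sum()) * ep_return)
    # Classic effective sample size ||w||_1^2 / ||w||_2^2, in [1, n_episodes]
    ess_classic = np.linalg.norm(iw, 1) ** 2 / np.linalg.norm(iw, 2) ** 2
    return is_estimate, sn_estimate, ess_classic
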
Example #24
0
def learn(
        env,
        policy_func,
        *,
        timesteps_per_batch,  # timesteps per actor per update
        clip_param,
        entcoeff,  # clipping parameter epsilon, entropy coeff
        optim_epochs,
        optim_stepsize,
        optim_batchsize,  # optimization hypers
        gamma,
        lam,  # advantage estimation
        max_timesteps=0,
        max_episodes=0,
        max_iters=0,
        max_seconds=0,  # time constraint
        callback=None,  # you can do anything in the callback, since it takes locals(), globals()
        adam_epsilon=1e-5,
        schedule='constant',  # annealing for stepsize parameters (epsilon and adam)
        init_policy_params=None,
        policy_scope='pi'):
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space

    pi = policy_func(policy_scope, ob_space,
                     ac_space)  # Construct network for new policy
    oldpi = policy_func("old" + policy_scope, ob_space,
                        ac_space)  # Network for old policy

    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return
    clip_tf = tf.placeholder(dtype=tf.float32)

    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)

    meanent = tf.reduce_mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = tf.clip_by_value(ratio, 1.0 - clip_tf,
                             1.0 + clip_tf * lrmult) * atarg
    pol_surr = -tf.reduce_mean(tf.minimum(
        surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)

    if hasattr(pi, 'additional_loss'):
        pol_surr += pi.additional_loss

    vf_loss = tf.reduce_mean(tf.square(pi.vpred - ret))

    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()
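    # Policy-only trainable variables (value-function and auxiliary variables
    # filtered out), used for the policy-only gradient and the flat get/set helpers below.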
    pol_var_list = [
        v for v in pi.get_trainable_variables()
        if 'placehold' not in v.name and 'offset' not in v.name and 'secondary'
        not in v.name and 'vf' not in v.name and 'pol' in v.name
    ]
    pol_var_size = np.sum([np.prod(v.shape) for v in pol_var_list])

    get_pol_flat = U.GetFlat(pol_var_list)
    set_pol_from_flat = U.SetFromFlat(pol_var_list)

    total_loss = pol_surr + pol_entpen + vf_loss

    lossandgrad = U.function([ob, ac, atarg, ret, lrmult, clip_tf],
                             losses + [U.flatgrad(total_loss, var_list)])
    pol_lossandgrad = U.function([ob, ac, atarg, ret, lrmult, clip_tf],
                                 losses +
                                 [U.flatgrad(total_loss, pol_var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult, clip_tf], losses)
    compute_losses_cpo = U.function(
        [ob, ac, atarg, ret, lrmult, clip_tf],
        [tf.reduce_mean(surr1), pol_entpen, vf_loss, meankl, meanent])
    compute_ratios = U.function([ob, ac], ratio)
    compute_kls = U.function([ob, ac], kloldnew)

    compute_rollout_old_prob = U.function([ob, ac],
                                          tf.reduce_mean(oldpi.pd.logp(ac)))
    compute_rollout_new_prob = U.function([ob, ac],
                                          tf.reduce_mean(pi.pd.logp(ac)))
    compute_rollout_new_prob_min = U.function([ob, ac],
                                              tf.reduce_min(pi.pd.logp(ac)))

    update_ops = {}
    update_placeholders = {}
    for v in pi.get_trainable_variables():
        update_placeholders[v.name] = tf.placeholder(v.dtype,
                                                     shape=v.get_shape())
        update_ops[v.name] = v.assign(update_placeholders[v.name])

    # building blocks for the Fisher information matrix (parameter sizes and log-prob gradient)
    dims = [int(np.prod(p.shape)) for p in pol_var_list]
    logprob_grad = U.flatgrad(tf.reduce_mean(pi.pd.logp(ac)), pol_var_list)
    compute_logprob_grad = U.function([ob, ac], logprob_grad)

    U.initialize()

    adam.sync()

    if init_policy_params is not None:
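        # Copy pretrained parameters into both pi and oldpi, remapping the scope
        # prefix of the stored variable names to the current policy scope.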
        cur_scope = pi.get_variables()[0].name[0:pi.get_variables()[0].name.
                                               find('/')]
        orig_scope = list(init_policy_params.keys()
                          )[0][0:list(init_policy_params.keys())[0].find('/')]
        print(cur_scope, orig_scope)
        for i in range(len(pi.get_variables())):
            if pi.get_variables()[i].name.replace(cur_scope, orig_scope,
                                                  1) in init_policy_params:
                assign_op = pi.get_variables()[i].assign(
                    init_policy_params[pi.get_variables()[i].name.replace(
                        cur_scope, orig_scope, 1)])
                tf.get_default_session().run(assign_op)
                assign_op = oldpi.get_variables()[i].assign(
                    init_policy_params[pi.get_variables()[i].name.replace(
                        cur_scope, orig_scope, 1)])
                tf.get_default_session().run(assign_op)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_batch,
                                     stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards

    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"

    prev_params = {}
    for v in var_list:
        if 'pol' in v.name:
            prev_params[v.name] = v.eval()

    optim_seg = None

    grad_scale = 1.0

    while True:
        if MPI.COMM_WORLD.Get_rank() == 0:
            print('begin')
            memory()
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % iters_so_far)

        seg = seg_gen.__next__()

        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update
        unstandardized_adv = np.copy(atarg)
        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate

        args = seg["ob"], seg["ac"], atarg
        fvpargs = [arr for arr in args]

        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        assign_old_eq_new()  # set old parameter values to new parameter values
        logger.log("Optimizing...")
        logger.log(fmt_row(13, loss_names))

        cur_clip_val = clip_param

        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                    shuffle=not pi.recurrent)

        for epoch in range(optim_epochs):
            losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"],
                                            batch["atarg"], batch["vtarg"],
                                            cur_lrmult, cur_clip_val)
                adam.update(g * grad_scale, optim_stepsize * cur_lrmult)
                losses.append(newlosses)

            logger.log(fmt_row(13, np.mean(losses, axis=0)))

        logger.log("Evaluating losses...")
        losses = []
        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                    shuffle=not pi.recurrent)
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob"], batch["ac"],
                                       batch["atarg"], batch["vtarg"],
                                       cur_lrmult, clip_param)
            losses.append(newlosses)
        meanlosses, _, _ = mpi_moments(losses, axis=0)

        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.log(fmt_row(13, meanlosses))
            for (lossval, name) in zipsame(meanlosses, loss_names):
                logger.record_tabular("loss_" + name, lossval)

        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.record_tabular("EpLenMean", np.mean(lenbuffer))
            logger.record_tabular("EpRewMean", np.mean(rewbuffer))
            logger.record_tabular("EpThisIter", len(lens))
            logger.record_tabular(
                "PolVariance",
                repr(adam.getflat()[-env.action_space.shape[0]:]))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.record_tabular("EpisodesSoFar", episodes_so_far)
            logger.record_tabular("TimestepsSoFar", timesteps_so_far)
            logger.record_tabular("TimeElapsed", time.time() - tstart)

        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.dump_tabular()
        if MPI.COMM_WORLD.Get_rank() == 0:
            print('end')
            memory()

    return pi, np.mean(rewbuffer)
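
The loop above relies on add_vtarg_and_adv(seg, gamma, lam) to fill seg["adv"] and seg["tdlamret"]; that helper is not shown in this listing. A minimal sketch of the GAE(lambda) computation it is assumed to perform, under the usual segment layout (seg["new"] marks episode starts, seg["nextvpred"] is the value of the state after the last step, all fields are NumPy arrays):

import numpy as np

def add_vtarg_and_adv_sketch(seg, gamma, lam):
    # Append the bootstrap value and the episode-start flag for the step after the segment
    new = np.append(seg["new"], 0)
    vpred = np.append(seg["vpred"], seg["nextvpred"])
    T = len(seg["rew"])
    adv = np.zeros(T, dtype=np.float32)
    lastgaelam = 0.0
    for t in reversed(range(T)):
        nonterminal = 1.0 - new[t + 1]
        delta = seg["rew"][t] + gamma * vpred[t + 1] * nonterminal - vpred[t]
        adv[t] = lastgaelam = delta + gamma * lam * nonterminal * lastgaelam
    seg["adv"] = adv
    seg["tdlamret"] = adv + seg["vpred"]
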
Example #25
0
def learn(
    env,
    test_env,
    policy_func,
    *,
    timesteps_per_batch,  # timesteps per actor per update
    clip_param,
    optim_epochs,
    optim_stepsize,
    optim_batchsize,  # optimization hypers
    gamma,
    lam,  # advantage estimation
    max_timesteps=0,
    max_episodes=0,
    max_iters=0,
    max_seconds=0,  # time constraint
    entcoeff=0.0,
    vf_coef=0.5,
    callback=None,  # you can do anything in the callback, since it takes locals(), globals()
    adam_epsilon=1e-5,
    schedule='constant',  # annealing for stepsize parameters (epsilon and adam)
    save_interval=50,
    #load_path = "C:\\Users\\Yangang REN\\AppData\\Local\\Temp\\openai-2019-11-21-10-40-10-039590\\checkpoints\\00351"
    load_path=None):
    """
    :param env:
    :param test_env:
    :param policy_func:
    :param timesteps_per_batch:
    :param clip_param:
    :param optim_epochs:
    :param optim_stepsize:
    :param optim_batchsize:
    :param gamma:
    :param lam:
    :param max_timesteps:
    :param max_episodes:
    :param max_iters:
    :param max_seconds:
    :param entcoeff:
    :param vf_coef: float                   value function loss coefficient in the optimization objective
    :param callback:
    :param adam_epsilon:
    :param schedule:
    :param save_interval:
    :param load_path:
    :return:
    """

    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"

    rew_mean = []

    # get state and action space
    ob_space = env.observation_space
    pro_ac_space = env.action_space
    adv_ac_space = env.adv_action_space

    # Construct network for new policy
    pro_pi = policy_func("pro_pi", ob_space, pro_ac_space)
    pro_oldpi = policy_func("pro_oldpi", ob_space, pro_ac_space)
    adv_pi = policy_func("adv_pi", ob_space, adv_ac_space)
    adv_oldpi = policy_func("adv_oldpi", ob_space, adv_ac_space)

    pro_atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    adv_atarg = tf.placeholder(dtype=tf.float32, shape=[None])
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule

    # Annealed clipping parameter epsilon
    clip_param = clip_param * lrmult

    ob = U.get_placeholder_cached(name="ob")
    pro_ac = pro_pi.pdtype.sample_placeholder([None])
    adv_ac = adv_pi.pdtype.sample_placeholder([None])

    pro_kloldnew = pro_oldpi.pd.kl(pro_pi.pd)  # compute kl difference
    adv_kloldnew = adv_oldpi.pd.kl(adv_pi.pd)
    pro_ent = pro_pi.pd.entropy()
    adv_ent = adv_pi.pd.entropy()
    pro_meankl = tf.reduce_mean(pro_kloldnew)
    adv_meankl = tf.reduce_mean(adv_kloldnew)

    pro_meanent = tf.reduce_mean(pro_ent)
    adv_meanent = tf.reduce_mean(adv_ent)
    pro_pol_entpen = (-entcoeff) * pro_meanent
    adv_pol_entpen = (-entcoeff) * adv_meanent

    pro_ratio = tf.exp(pro_pi.pd.logp(pro_ac) - pro_oldpi.pd.logp(pro_ac))
    adv_ratio = tf.exp(adv_pi.pd.logp(adv_ac) - adv_oldpi.pd.logp(adv_ac))

    pro_surr1 = pro_ratio * pro_atarg  # surrogate from conservative policy iteration
    adv_surr1 = adv_ratio * adv_atarg

    pro_surr2 = tf.clip_by_value(pro_ratio, 1.0 - clip_param,
                                 1.0 + clip_param) * pro_atarg
    adv_surr2 = tf.clip_by_value(adv_ratio, 1.0 - clip_param,
                                 1.0 + clip_param) * adv_atarg

    # TODO: check this code carefully
    pro_pol_surr = -tf.reduce_mean(tf.minimum(pro_surr1, pro_surr2))
    adv_pol_surr = tf.reduce_mean(tf.minimum(adv_surr1, adv_surr2))

    pro_vf_loss = tf.reduce_mean(tf.square(pro_pi.vpred - ret))
    adv_vf_loss = tf.reduce_mean(tf.square(adv_pi.vpred - ret))

    # FIXME: do not forget the coefficient between the different losses
    pro_total_loss = pro_pol_surr + pro_pol_entpen + vf_coef * pro_vf_loss
    adv_total_loss = adv_pol_surr + adv_pol_entpen + vf_coef * adv_vf_loss

    pro_losses = [
        pro_pol_surr, pro_pol_entpen, pro_vf_loss, pro_meankl, pro_meanent
    ]
    pro_loss_names = [
        "pro_pol_surr", "pro_pol_entpen", "pro_vf_loss", "pro_kl", "pro_ent"
    ]
    adv_losses = [
        adv_pol_surr, adv_pol_entpen, adv_vf_loss, adv_meankl, adv_meanent
    ]
    adv_loss_names = [
        "adv_pol_surr", "adv_pol_entpen", "adv_vf_loss", "adv_kl", "adv_ent"
    ]

    pro_var_list = pro_pi.get_trainable_variables()
    adv_var_list = adv_pi.get_trainable_variables()

    pro_lossandgrad = U.function([ob, pro_ac, pro_atarg, ret, lrmult],
                                 pro_losses +
                                 [U.flatgrad(pro_total_loss, pro_var_list)])
    adv_lossandgrad = U.function([ob, adv_ac, adv_atarg, ret, lrmult],
                                 adv_losses +
                                 [U.flatgrad(adv_total_loss, adv_var_list)])
    pro_adam = MpiAdam(pro_var_list, epsilon=adam_epsilon)
    adv_adam = MpiAdam(adv_var_list, epsilon=adam_epsilon)

    pro_assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv) for (oldv, newv) in zipsame(
                pro_oldpi.get_variables(), pro_pi.get_variables())
        ])
    adv_assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv) for (oldv, newv) in zipsame(
                adv_oldpi.get_variables(), adv_pi.get_variables())
        ])
    # U.function(inputs, outputs)
    pro_compute_losses = U.function([ob, pro_ac, pro_atarg, ret, lrmult],
                                    pro_losses)
    adv_compute_losses = U.function([ob, adv_ac, adv_atarg, ret, lrmult],
                                    adv_losses)

    U.initialize()
    pro_adam.sync()
    adv_adam.sync()

    save = functools.partial(save_variables, sess=get_session())
    load = functools.partial(load_variables, sess=get_session())

    # TODO: load the saved model from load_path
    if load_path is not None:
        load(load_path)
        print('Loading model and running it…')
        max_iters = 0

    # Prepare for rollouts
    seg_gen = traj_segment_generator(pro_pi,
                                     adv_pi,
                                     env,
                                     timesteps_per_batch,
                                     stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards

    # Main training loop: alternately update the protagonist and adversary policies
    for update in range(1, max_iters + 1):
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        # adjusting the learning rate
        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = 1.0 - (update - 1.0) / max_iters
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % (iters_so_far + 1))

        seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, pro_ac, adv_ac, pro_atarg, adv_atarg, pro_tdlamret, adv_tdlamret = seg[
            "ob"], seg["pro_ac"], seg["adv_ac"], seg["pro_adv"], seg[
                "adv_adv"], seg["pro_tdlamret"], seg["adv_tdlamret"]
        pro_vpredbefore = seg[
            "pro_vpred"]  # predicted value function before update
        adv_vpredbefore = seg["adv_vpred"]
        # standardized advantage function estimate
        pro_atarg = (pro_atarg - pro_atarg.mean()) / (pro_atarg.std() + 1e-8)
        adv_atarg = (adv_atarg - adv_atarg.mean()) / (adv_atarg.std() + 1e-8)

        # TODO
        d = Dataset(dict(ob=ob, ac=pro_ac, atarg=pro_atarg,
                         vtarg=pro_tdlamret),
                    shuffle=not pro_pi.recurrent)
        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pro_pi, "ob_rms"):
            pro_pi.ob_rms.update(ob)  # update running mean/std for policy

        pro_assign_old_eq_new(
        )  # set old parameter values to new parameter values

        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            pro_losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = pro_lossandgrad(batch["ob"], batch["ac"],
                                                batch["atarg"], batch["vtarg"],
                                                cur_lrmult)
                pro_adam.update(g, optim_stepsize * cur_lrmult)
                pro_losses.append(newlosses)

        pro_losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = pro_compute_losses(batch["ob"], batch["ac"],
                                           batch["atarg"], batch["vtarg"],
                                           cur_lrmult)
            pro_losses.append(newlosses)
        pro_meanlosses, _, _ = mpi_moments(pro_losses, axis=0)

        # Training the adversary agent
        d = Dataset(dict(ob=ob, ac=adv_ac, atarg=adv_atarg,
                         vtarg=adv_tdlamret),
                    shuffle=not adv_pi.recurrent)
        if hasattr(adv_pi, "ob_rms"): adv_pi.ob_rms.update(ob)
        adv_assign_old_eq_new()

        # logger.log(fmt_row(13, adv_loss_names))
        for _ in range(optim_epochs):
            adv_losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = adv_lossandgrad(batch["ob"], batch["ac"],
                                                batch["atarg"], batch["vtarg"],
                                                cur_lrmult)
                adv_adam.update(g, optim_stepsize * cur_lrmult)
                adv_losses.append(newlosses)

        adv_losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = adv_compute_losses(batch["ob"], batch["ac"],
                                           batch["atarg"], batch["vtarg"],
                                           cur_lrmult)
            adv_losses.append(newlosses)
        adv_meanlosses, _, _ = mpi_moments(adv_losses, axis=0)

        # print the results
        logger.logkv("pro_policy_vf", pro_meanlosses[2])
        logger.logkv("adv_policy_vf", adv_meanlosses[2])

        # test
        # curr_rew = evaluate(pro_pi, test_env)
        # rew_mean.append(curr_rew)
        # print(curr_rew)
        curr_rew = evaluate(pro_pi, adv_pi, test_env)
        rew_mean.append(curr_rew)
        logger.logkv("test reward", curr_rew)

        # logger.record_tabular("ev_tdlam_before", explained_variance(pro_vpredbefore, pro_tdlamret))
        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        logger.logkv('eprewmean', safemean(rewbuffer))
        logger.logkv('eplenmean', safemean(lenbuffer))
        logger.dumpkvs()

        if save_interval and (update == 1 or iters_so_far % save_interval
                              == 0) and logger.get_dir():
            checkdir = osp.join(logger.get_dir(), 'checkpoints')
            os.makedirs(checkdir, exist_ok=True)
            savepath = osp.join(checkdir, '%.5i' % update)
            print('Saving to…', savepath)
            save(savepath)
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

    # return np.array(rew_mean)
    return pro_pi, adv_pi
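
Both the protagonist and the adversary above are trained with PPO's clipped surrogate. A minimal NumPy sketch of that pessimistic objective for a batch of likelihood ratios and advantages (illustrative only, not tied to the TensorFlow graph above):

import numpy as np

def ppo_clip_loss(ratio, adv, clip_param):
    # ratio = pi(a|s) / pi_old(a|s), adv = advantage estimates
    unclipped = ratio * adv
    clipped = np.clip(ratio, 1.0 - clip_param, 1.0 + clip_param) * adv
    # Pessimistic bound L^CLIP: elementwise minimum, negated to obtain a loss
    return -np.mean(np.minimum(unclipped, clipped))
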
Example #26
0
def learn(
        env,
        test_env,
        policy_fn,
        *,
        timesteps_per_actorbatch,  # timesteps per actor per update
        clip_param,
        entcoeff,  # clipping parameter epsilon, entropy coeff
        optim_epochs,
        optim_stepsize,
        optim_batchsize,  # optimization hypers
        gamma,
        lam,  # advantage estimation
        # CMAES
        max_fitness,  # has to be negative, as CMA-ES considers minimization
        popsize,
        gensize,
        bounds,
        sigma,
        eval_iters,
        max_v_train_iter,
        max_timesteps=0,
        max_episodes=0,
        max_iters=0,
        max_seconds=0,
        # time constraint
        callback=None,
        # you can do anything in the callback, since it takes locals(), globals()
        adam_epsilon=1e-5,
        schedule='constant',
        # annealing for stepsize parameters (epsilon and adam)
        seed,
        env_id):
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_fn("pi", ob_space,
                   ac_space)  # Construct network for new policy
    oldpi = policy_fn("oldpi", ob_space, ac_space)  # Network for old policy
    backup_pi = policy_fn(
        "backup_pi", ob_space, ac_space
    )  # Construct a network for every individual to adapt during the es evolution
    pi_zero = policy_fn(
        "zero_pi", ob_space,
        ac_space)  # pi_0 will only be updated along with iterations

    reward = tf.placeholder(dtype=tf.float32, shape=[None])  # step rewards
    pi_params = tf.placeholder(dtype=tf.float32, shape=[None])
    old_pi_params = tf.placeholder(dtype=tf.float32, shape=[None])
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule

    bound_coeff = tf.placeholder(
        name='bound_coeff', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule

    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    next_ob = U.get_placeholder_cached(
        name="next_ob")  # next step observation for updating q function
    ac = U.get_placeholder_cached(
        name="act")  # action placeholder for computing q function
    mean_ac = U.get_placeholder_cached(
        name="mean_act")  # action placeholder for computing q function

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    pol_entpen = (-entcoeff) * meanent

    param_dist = tf.reduce_mean(tf.square(pi_params - old_pi_params))
    mean_action_loss = tf.cast(
        tf.reduce_mean(tf.square(1.0 - pi.pd.mode() / oldpi.pd.mode())),
        tf.float32)

    pi_adv = (pi.qpred - pi.vpred)
    adv_mean, adv_var = tf.nn.moments(pi_adv, axes=[0])
    normalized_pi_adv = (pi_adv - adv_mean) / tf.sqrt(adv_var)

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = tf.clip_by_value(ratio, 1.0 - clip_param,
                             1.0 + clip_param) * atarg  #
    pol_surr = -tf.reduce_mean(tf.minimum(
        surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)

    # qf_loss = tf.reduce_mean(tf.square(reward + gamma * pi.mean_qpred - pi.qpred))
    qf_loss = tf.reduce_mean(
        U.huber_loss(reward + gamma * pi.mean_qpred - pi.qpred))
    vf_loss = tf.reduce_mean(tf.square(pi.vpred - ret))
    qf_losses = [qf_loss]
    vf_losses = [vf_loss]
    pol_loss = pol_surr + pol_entpen
    # pol_loss = pol_surr + pol_entpen

    # Advantage function should be improved
    losses = [pol_loss, pol_entpen, meankl, meanent]
    loss_names = ["pol_surr_2", "pol_entpen", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    qf_var_list = [
        v for v in var_list if v.name.split("/")[1].startswith("qf")
    ]
    mean_qf_var_list = [
        v for v in var_list if v.name.split("/")[1].startswith("meanqf")
    ]
    vf_var_list = [
        v for v in var_list if v.name.split("/")[1].startswith("vf")
    ]
    pol_var_list = [
        v for v in var_list if v.name.split("/")[1].startswith("pol")
    ]

    vf_lossandgrad = U.function([ob, ac, atarg, ret, lrmult],
                                vf_losses + [U.flatgrad(vf_loss, vf_var_list)])

    qf_lossandgrad = U.function(
        [ob, ac, next_ob, mean_ac, lrmult, reward, atarg],
        qf_losses + [U.flatgrad(qf_loss, qf_var_list)])

    qf_adam = MpiAdam(qf_var_list, epsilon=adam_epsilon)

    vf_adam = MpiAdam(vf_var_list, epsilon=adam_epsilon)

    assign_target_q_eq_eval_q = U.function(
        [], [],
        updates=[
            tf.assign(target_q, eval_q)
            for (target_q, eval_q) in zipsame(mean_qf_var_list, qf_var_list)
        ])

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])

    assign_backup_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(backup_v, newv) for (
                backup_v,
                newv) in zipsame(backup_pi.get_variables(), pi.get_variables())
        ])
    assign_new_eq_backup = U.function(
        [], [],
        updates=[
            tf.assign(newv, backup_v)
            for (newv, backup_v
                 ) in zipsame(pi.get_variables(), backup_pi.get_variables())
        ])

    mean_pi_actions = U.function(
        [ob], [pi.pd.mode()])  # later for computing pol_loss
    # Compute all losses
    compute_pol_losses = U.function([ob, ob, ac, lrmult, atarg], [pol_loss])

    U.initialize()

    get_pi_flat_params = U.GetFlat(pol_var_list)
    set_pi_flat_params = U.SetFromFlat(pol_var_list)

    vf_adam.sync()
    qf_adam.sync()

    global timesteps_so_far, episodes_so_far, iters_so_far, \
        tstart, lenbuffer, rewbuffer, ppo_timesteps_so_far, best_fitness

    episodes_so_far = 0
    timesteps_so_far = 0
    ppo_timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards

    best_fitness = np.inf

    eval_gen = traj_segment_generator_eval(pi,
                                           test_env,
                                           timesteps_per_actorbatch,
                                           stochastic=True)  # For evaluation
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_actorbatch,
                                     stochastic=True,
                                     eval_gen=eval_gen)  # For train V Func

    # Build generator for all solutions
    actors = []
    for i in range(popsize):
        newActor = traj_segment_generator(pi,
                                          env,
                                          timesteps_per_actorbatch,
                                          stochastic=True,
                                          eval_gen=eval_gen)
        actors.append(newActor)

    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"

    while True:
        if max_timesteps and timesteps_so_far >= max_timesteps:
            print("Max time steps")
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            print("Max episodes")
            break
        elif max_iters and iters_so_far >= max_iters:
            print("Max iterations")
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            print("Max time")
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)

        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % iters_so_far)

        # Generate new samples
        # Train V func
        ob_segs = None
        for i in range(max_v_train_iter):
            logger.log("Iteration:" + str(iters_so_far) +
                       " - sub-train iter for V func:" + str(i))
            logger.log("Generate New Samples")
            seg = seg_gen.__next__()
            add_vtarg_and_adv(seg, gamma, lam)

            ob, ac, next_ob, atarg, reward, tdlamret, traj_idx = seg["ob"], seg["ac"], seg["next_ob"], seg["adv"], seg[
                "rew"], seg["tdlamret"], \
                                                                 seg["traj_index"]
            atarg = (atarg - atarg.mean()) / atarg.std(
            )  # standardized advantage function estimate
            d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                        shuffle=not pi.recurrent)
            optim_batchsize = optim_batchsize or ob.shape[0]

            if hasattr(pi, "ob_rms"):
                pi.ob_rms.update(
                    ob)  # update running mean/std for normalization

            assign_old_eq_new(
            )  # set old parameter values to new parameter values
            # Train V function
            logger.log("Training V Func and Evaluating V Func Losses")
            for _ in range(optim_epochs):
                losses = [
                ]  # list of tuples, each of which gives the loss for a minibatch
                for batch in d.iterate_once(optim_batchsize):
                    *vf_losses, g = vf_lossandgrad(batch["ob"], batch["ac"],
                                                   batch["atarg"],
                                                   batch["vtarg"], cur_lrmult)
                    vf_adam.update(g, optim_stepsize * cur_lrmult)
                    losses.append(vf_losses)
                logger.log(fmt_row(13, np.mean(losses, axis=0)))

            d_q = Dataset(dict(ob=ob,
                               ac=ac,
                               next_ob=next_ob,
                               reward=reward,
                               atarg=atarg,
                               vtarg=tdlamret),
                          shuffle=not pi.recurrent)

            # Re-train q function
            logger.log("Training Q Func and Evaluating Q Func Losses")
            for _ in range(optim_epochs):
                losses = [
                ]  # list of tuples, each of which gives the loss for a minibatch
                for batch in d_q.iterate_once(optim_batchsize):
                    *qf_losses, g = qf_lossandgrad(
                        batch["ob"], batch["ac"], batch["next_ob"],
                        mean_pi_actions(batch["ob"])[0], cur_lrmult,
                        batch["reward"], batch["atarg"])
                    qf_adam.update(g, optim_stepsize * cur_lrmult)
                    losses.append(qf_losses)
                logger.log(fmt_row(13, np.mean(losses, axis=0)))
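            # Copy the trained Q-function weights into the target (mean) Q-function
            # used in the bootstrapped Q target above.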

            assign_target_q_eq_eval_q()

        pi0_fitness = compute_pol_losses(ob, ob,
                                         mean_pi_actions(ob)[0], cur_lrmult,
                                         atarg)
        logger.log("Best fitness for Pi0:" + str(np.mean(atarg)))
        logger.log("Best fitness for Pi0:" + str(pi0_fitness))

        # CMAES Train Policy
        assign_old_eq_new()  # set old parameter values to new parameter values
        assign_backup_eq_new()  # backup current policy
        flatten_weights = get_pi_flat_params()
        opt = cma.CMAOptions()
        opt['tolfun'] = max_fitness
        opt['popsize'] = popsize
        opt['maxiter'] = gensize
        opt['verb_disp'] = 0
        opt['verb_log'] = 0
        opt['seed'] = seed
        opt['AdaptSigma'] = True
        es = cma.CMAEvolutionStrategy(flatten_weights, sigma, opt)
        while True:
            if es.countiter >= gensize:
                logger.log("Max generations for current layer")
                break
            logger.log("Iteration:" + str(iters_so_far) +
                       " - sub-train Generation for Policy:" +
                       str(es.countiter))
            logger.log("Sigma=" + str(es.sigma))
            solutions = es.ask()
            costs = []
            lens = []

            assign_backup_eq_new()  # backup current policy
            for id, solution in enumerate(solutions):
                set_pi_flat_params(solution)
                losses = []
                # cost = compute_pol_losses(ob_segs['ob'], ob_segs['ob'], mean_pi_actions(ob_segs['ob'])[0])
                cost = compute_pol_losses(ob, ob,
                                          mean_pi_actions(ob)[0], cur_lrmult,
                                          atarg)
                costs.append(cost[0])
                assign_new_eq_backup()
            # Weights decay
            l2_decay = compute_weight_decay(0.99, solutions)
            costs += l2_decay
            # costs, real_costs = fitness_normalization(costs)
            costs, real_costs = fitness_rank(costs)
            es.tell_real_seg(solutions=solutions,
                             function_values=costs,
                             real_f=real_costs,
                             segs=None)
            logger.log("best_fitness:" + str(best_fitness) +
                       " current best fitness:" + str(es.result[1]))
            best_solution = es.result[0]
            best_fitness = es.result[1]
            logger.log("Best Solution Fitness:" + str(best_fitness))
            set_pi_flat_params(best_solution)
        sigma = es.sigma

        iters_so_far += 1
        episodes_so_far += sum(lens)
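
The CMA-ES loop above appears to rely on a customized strategy object (tell_real_seg). A minimal sketch of the standard ask/tell interface of the cma package with the same options, assuming a generic fitness function f to be minimized:

import cma
import numpy as np

def cmaes_minimize(f, x0, sigma, popsize=8, maxiter=50, seed=0):
    opts = {'popsize': popsize, 'maxiter': maxiter,
            'verb_disp': 0, 'verb_log': 0, 'seed': seed}
    es = cma.CMAEvolutionStrategy(x0, sigma, opts)
    while not es.stop():
        solutions = es.ask()                        # sample a new population
        costs = [float(f(np.asarray(x))) for x in solutions]
        es.tell(solutions, costs)                   # update mean, covariance, step-size
    return es.result[0], es.result[1]               # best solution and its fitness
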
Example #27
0
def learn(make_env, make_policy, *,
          n_episodes,
          horizon,
          delta,
          gamma,
          max_iters,
          sampler=None,
          use_natural_gradient=False, #can be 'exact', 'approximate'
          fisher_reg=1e-2,
          iw_method='is',
          iw_norm='none',
          bound='J',
          line_search_type='parabola',
          save_weights=False,
          improvement_tol=0.,
          center_return=False,
          render_after=None,
          max_offline_iters=100,
          callback=None,
          clipping=False,
          entropy='none',
          positive_return=False,
          reward_clustering='none'):

    np.set_printoptions(precision=3)
    max_samples = horizon * n_episodes

    if line_search_type == 'binary':
        line_search = line_search_binary
    elif line_search_type == 'parabola':
        line_search = line_search_parabola
    else:
        raise ValueError()

    # Building the environment
    env = make_env()
    ob_space = env.observation_space
    ac_space = env.action_space

    # Building the policy
    pi = make_policy('pi', ob_space, ac_space)
    oldpi = make_policy('oldpi', ob_space, ac_space)

    all_var_list = pi.get_trainable_variables()
    var_list = [v for v in all_var_list if v.name.split('/')[1].startswith('pol')]

    shapes = [U.intprod(var.get_shape().as_list()) for var in var_list]
    n_parameters = sum(shapes)

    # Placeholders
    ob_ = ob = U.get_placeholder_cached(name='ob')
    ac_ = pi.pdtype.sample_placeholder([max_samples], name='ac')
    mask_ = tf.placeholder(dtype=tf.float32, shape=(max_samples), name='mask')
    rew_ = tf.placeholder(dtype=tf.float32, shape=(max_samples), name='rew')
    disc_rew_ = tf.placeholder(dtype=tf.float32, shape=(max_samples), name='disc_rew')
    clustered_rew_ = tf.placeholder(dtype=tf.float32, shape=(n_episodes))
    gradient_ = tf.placeholder(dtype=tf.float32, shape=(n_parameters, 1), name='gradient')
    iter_number_ = tf.placeholder(dtype=tf.int32, name='iter_number')
    losses_with_name = []

    # Policy densities
    target_log_pdf = pi.pd.logp(ac_)
    behavioral_log_pdf = oldpi.pd.logp(ac_)
    log_ratio = target_log_pdf - behavioral_log_pdf

    # Split operations
    disc_rew_split = tf.stack(tf.split(disc_rew_ * mask_, n_episodes))
    rew_split = tf.stack(tf.split(rew_ * mask_, n_episodes))
    log_ratio_split = tf.stack(tf.split(log_ratio * mask_, n_episodes))
    target_log_pdf_split = tf.stack(tf.split(target_log_pdf * mask_, n_episodes))
    behavioral_log_pdf_split = tf.stack(tf.split(behavioral_log_pdf * mask_, n_episodes))
    mask_split = tf.stack(tf.split(mask_, n_episodes))

    # Renyi divergence
    emp_d2_split = tf.stack(tf.split(pi.pd.renyi(oldpi.pd, 2) * mask_, n_episodes))
    emp_d2_cum_split = tf.reduce_sum(emp_d2_split, axis=1)
    empirical_d2 = tf.reduce_mean(tf.exp(emp_d2_cum_split))

    # Return
    ep_return = clustered_rew_ #tf.reduce_sum(mask_split * disc_rew_split, axis=1)
    if clipping:
        rew_split = tf.clip_by_value(rew_split, -1, 1)

    if center_return:
        ep_return = ep_return - tf.reduce_mean(ep_return)
        rew_split = rew_split - (tf.reduce_sum(rew_split) / (tf.reduce_sum(mask_split) + 1e-24))

    discounter = [pow(gamma, i) for i in range(0, horizon)]  # Discount factors gamma^t, t = 0..horizon-1
    discounter_tf = tf.constant(discounter)
    disc_rew_split = rew_split * discounter_tf

    #tf.add_to_collection('prints', tf.Print(ep_return, [ep_return], 'ep_return_not_clustered', summarize=20))

    # Reward clustering
    '''
    rew_clustering_options = reward_clustering.split(':')
    if reward_clustering == 'none':
        pass # Do nothing
    elif rew_clustering_options[0] == 'global':
        assert len(rew_clustering_options) == 2, "Reward clustering: Provide the correct number of parameters"
        N = int(rew_clustering_options[1])
        tf.add_to_collection('prints', tf.Print(ep_return, [ep_return], 'ep_return', summarize=20))
        global_rew_min = tf.Variable(float('+inf'), trainable=False)
        global_rew_max = tf.Variable(float('-inf'), trainable=False)
        rew_min = tf.reduce_min(ep_return)
        rew_max = tf.reduce_max(ep_return)
        global_rew_min = tf.assign(global_rew_min, tf.minimum(global_rew_min, rew_min))
        global_rew_max = tf.assign(global_rew_max, tf.maximum(global_rew_max, rew_max))
        interval_size = (global_rew_max - global_rew_min) / N
        ep_return = tf.floordiv(ep_return, interval_size) * interval_size
    elif rew_clustering_options[0] == 'batch':
        assert len(rew_clustering_options) == 2, "Reward clustering: Provide the correct number of parameters"
        N = int(rew_clustering_options[1])
        rew_min = tf.reduce_min(ep_return)
        rew_max = tf.reduce_max(ep_return)
        interval_size = (rew_max - rew_min) / N
        ep_return = tf.floordiv(ep_return, interval_size) * interval_size
    elif rew_clustering_options[0] == 'manual':
        assert len(rew_clustering_options) == 4, "Reward clustering: Provide the correct number of parameters"
        N, rew_min, rew_max = map(int, rew_clustering_options[1:])
        print("N:", N)
        print("Min reward:", rew_min)
        print("Max reward:", rew_max)
        interval_size = (rew_max - rew_min) / N
        print("Interval size:", interval_size)
        # Clip to avoid overflow and cluster
        ep_return = tf.clip_by_value(ep_return, rew_min, rew_max)
        ep_return = tf.cast(tf.floordiv(ep_return, interval_size) * interval_size, tf.float32)
        tf.add_to_collection('prints', tf.Print(ep_return, [ep_return], 'ep_return_clustered', summarize=20))
    else:
        raise Exception('Unrecognized reward clustering scheme.')
    '''

    return_mean = tf.reduce_mean(ep_return)
    return_std = U.reduce_std(ep_return)
    return_max = tf.reduce_max(ep_return)
    return_min = tf.reduce_min(ep_return)
    return_abs_max = tf.reduce_max(tf.abs(ep_return))
    return_step_max = tf.reduce_max(tf.abs(rew_split)) # Max absolute step reward
    return_step_mean = tf.abs(tf.reduce_mean(rew_split))
    positive_step_return_max = tf.maximum(0.0, tf.reduce_max(rew_split))
    negative_step_return_max = tf.maximum(0.0, tf.reduce_max(-rew_split))
    return_step_maxmin = tf.abs(positive_step_return_max - negative_step_return_max)

    losses_with_name.extend([(return_mean, 'InitialReturnMean'),
                             (return_max, 'InitialReturnMax'),
                             (return_min, 'InitialReturnMin'),
                             (return_std, 'InitialReturnStd'),
                             (empirical_d2, 'EmpiricalD2'),
                             (return_step_max, 'ReturnStepMax'),
                             (return_step_maxmin, 'ReturnStepMaxmin')])

    if iw_method == 'pdis':
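        # Per-decision IS: each step reward is weighted by the cumulative
        # importance ratio of the trajectory prefix up to that step.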
        # log_ratio_split cumulative sum
        log_ratio_cumsum = tf.cumsum(log_ratio_split, axis=1)
        # Exponentiate
        ratio_cumsum = tf.exp(log_ratio_cumsum)
        # Multiply by the step-wise reward (not episode)
        ratio_reward = ratio_cumsum * disc_rew_split
        # Average on episodes
        ratio_reward_per_episode = tf.reduce_sum(ratio_reward, axis=1)
        w_return_mean = tf.reduce_sum(ratio_reward_per_episode, axis=0) / n_episodes
        # Get d2(w0:t) with mask
        d2_w_0t = tf.exp(tf.cumsum(emp_d2_split, axis=1)) * mask_split # LEAVE THIS OUTSIDE
        # Sum d2(w0:t) over timesteps
        episode_d2_0t = tf.reduce_sum(d2_w_0t, axis=1)
        # Sample variance
        J_sample_variance = (1/(n_episodes-1)) * tf.reduce_sum(tf.square(ratio_reward_per_episode - w_return_mean))
        losses_with_name.append((J_sample_variance, 'J_sample_variance'))
        losses_with_name.extend([(tf.reduce_max(ratio_cumsum), 'MaxIW'),
                                 (tf.reduce_min(ratio_cumsum), 'MinIW'),
                                 (tf.reduce_mean(ratio_cumsum), 'MeanIW'),
                                 (U.reduce_std(ratio_cumsum), 'StdIW')])
        losses_with_name.extend([(tf.reduce_max(d2_w_0t), 'MaxD2w0t'),
                                 (tf.reduce_min(d2_w_0t), 'MinD2w0t'),
                                 (tf.reduce_mean(d2_w_0t), 'MeanD2w0t'),
                                 (U.reduce_std(d2_w_0t), 'StdD2w0t')])

    elif iw_method == 'is':
        iw = tf.exp(tf.reduce_sum(log_ratio_split, axis=1))
        if iw_norm == 'none':
            iwn = iw / n_episodes
            w_return_mean = tf.reduce_sum(iwn * ep_return)
            J_sample_variance = (1/(n_episodes-1)) * tf.reduce_sum(tf.square(iw * ep_return - w_return_mean))
            losses_with_name.append((J_sample_variance, 'J_sample_variance'))
        elif iw_norm == 'sn':
            iwn = iw / tf.reduce_sum(iw)
            w_return_mean = tf.reduce_sum(iwn * ep_return)
        elif iw_norm == 'regression':
            iwn = iw / n_episodes
            mean_iw = tf.reduce_mean(iw)
            beta = tf.reduce_sum((iw - mean_iw) * ep_return * iw) / (tf.reduce_sum((iw - mean_iw) ** 2) + 1e-24)
            w_return_mean = tf.reduce_mean(iw * ep_return - beta * (iw - 1))
        else:
            raise NotImplementedError()
        ess_classic = tf.linalg.norm(iw, 1) ** 2 / tf.linalg.norm(iw, 2) ** 2
        sqrt_ess_classic = tf.linalg.norm(iw, 1) / tf.linalg.norm(iw, 2)
        ess_renyi = n_episodes / empirical_d2
        losses_with_name.extend([(tf.reduce_max(iwn), 'MaxIWNorm'),
                                 (tf.reduce_min(iwn), 'MinIWNorm'),
                                 (tf.reduce_mean(iwn), 'MeanIWNorm'),
                                 (U.reduce_std(iwn), 'StdIWNorm'),
                                 (tf.reduce_max(iw), 'MaxIW'),
                                 (tf.reduce_min(iw), 'MinIW'),
                                 (tf.reduce_mean(iw), 'MeanIW'),
                                 (U.reduce_std(iw), 'StdIW'),
                                 (ess_classic, 'ESSClassic'),
                                 (ess_renyi, 'ESSRenyi')])
    elif iw_method == 'rbis':
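        # Reward-based IS: episodes are clustered by identical return value and the
        # importance weights are computed per cluster of trajectories rather than
        # per single trajectory.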
        # Get pdfs for episodes
        target_log_pdf_episode = tf.reduce_sum(target_log_pdf_split, axis=1)
        behavioral_log_pdf_episode = tf.reduce_sum(behavioral_log_pdf_split, axis=1)
        # Normalize log-probabilities (to avoid overflows as much as possible)
        normalization_factor = tf.reduce_mean(tf.stack([target_log_pdf_episode, behavioral_log_pdf_episode]))
        target_norm_log_pdf_episode = target_log_pdf_episode - normalization_factor
        behavioral_norm_log_pdf_episode = behavioral_log_pdf_episode - normalization_factor
        # Exponentiate
        target_pdf_episode = tf.clip_by_value(tf.cast(tf.exp(target_norm_log_pdf_episode), tf.float64), 1e-300, 1e+300)
        behavioral_pdf_episode = tf.clip_by_value(tf.cast(tf.exp(behavioral_norm_log_pdf_episode), tf.float64), 1e-300, 1e+300)
        tf.add_to_collection('asserts', tf.assert_positive(target_pdf_episode, name='target_pdf_positive'))
        tf.add_to_collection('asserts', tf.assert_positive(behavioral_pdf_episode, name='behavioral_pdf_positive'))
        # Compute the merging matrix (reward-clustering) and the number of clusters
        reward_unique, reward_indexes = tf.unique(ep_return)
        episode_clustering_matrix = tf.cast(tf.one_hot(reward_indexes, n_episodes), tf.float64)
        max_index = tf.reduce_max(reward_indexes) + 1
        trajectories_per_cluster = tf.reduce_sum(episode_clustering_matrix, axis=0)[:max_index]
        tf.add_to_collection('asserts', tf.assert_positive(tf.reduce_sum(episode_clustering_matrix, axis=0)[:max_index], name='clustering_matrix'))
        # Get the clustered pdfs
        clustered_target_pdf = tf.matmul(tf.reshape(target_pdf_episode, (1, -1)), episode_clustering_matrix)[0][:max_index]
        clustered_behavioral_pdf = tf.matmul(tf.reshape(behavioral_pdf_episode, (1, -1)), episode_clustering_matrix)[0][:max_index]
        tf.add_to_collection('asserts', tf.assert_positive(clustered_target_pdf, name='clust_target_pdf_positive'))
        tf.add_to_collection('asserts', tf.assert_positive(clustered_behavioral_pdf, name='clust_behavioral_pdf_positive'))
        # Compute the J
        ratio_clustered = clustered_target_pdf / clustered_behavioral_pdf
        #ratio_reward = tf.cast(ratio_clustered, tf.float32) * reward_unique                                                  # ---- No cluster cardinality
        ratio_reward = tf.cast(ratio_clustered, tf.float32) * reward_unique * tf.cast(trajectories_per_cluster, tf.float32)   # ---- Cluster cardinality
        #w_return_mean = tf.reduce_sum(ratio_reward) / tf.cast(max_index, tf.float32)                                         # ---- No cluster cardinality
        w_return_mean = tf.reduce_sum(ratio_reward) / tf.cast(n_episodes, tf.float32)                                         # ---- Cluster cardinality
        # Divergences
        ess_classic = tf.linalg.norm(ratio_reward, 1) ** 2 / tf.linalg.norm(ratio_reward, 2) ** 2
        sqrt_ess_classic = tf.linalg.norm(ratio_reward, 1) / tf.linalg.norm(ratio_reward, 2)
        ess_renyi = n_episodes / empirical_d2
        # Summaries
        losses_with_name.extend([(tf.reduce_max(ratio_clustered), 'MaxIW'),
                                 (tf.reduce_min(ratio_clustered), 'MinIW'),
                                 (tf.reduce_mean(ratio_clustered), 'MeanIW'),
                                 (U.reduce_std(ratio_clustered), 'StdIW'),
                                 (1-(max_index / n_episodes), 'RewardCompression'),
                                 (ess_classic, 'ESSClassic'),
                                 (ess_renyi, 'ESSRenyi')])
    else:
        raise NotImplementedError()

    if bound == 'J':
        bound_ = w_return_mean
    elif bound == 'std-d2':
        bound_ = w_return_mean - tf.sqrt((1 - delta) / (delta * ess_renyi)) * return_std
    elif bound == 'max-d2':
        var_estimate = tf.sqrt((1 - delta) / (delta * ess_renyi)) * return_abs_max
        bound_ = w_return_mean - tf.sqrt((1 - delta) / (delta * ess_renyi)) * return_abs_max
    elif bound == 'max-ess':
        bound_ = w_return_mean - tf.sqrt((1 - delta) / delta) / sqrt_ess_classic * return_abs_max
    elif bound == 'std-ess':
        bound_ = w_return_mean - tf.sqrt((1 - delta) / delta) / sqrt_ess_classic * return_std
    elif bound == 'pdis-max-d2':
        # Discount factor
        if gamma >= 1:
            discounter = [float(1+2*(horizon-t-1)) for t in range(0, horizon)]
        else:
            def f(t):
                return pow(gamma, 2*t) + (2*pow(gamma,t)*(pow(gamma, t+1) - pow(gamma, horizon))) / (1-gamma)
            discounter = [f(t) for t in range(0, horizon)]
        discounter_tf = tf.constant(discounter)
        mean_episode_d2 = tf.reduce_sum(d2_w_0t, axis=0) / (tf.reduce_sum(mask_split, axis=0) + 1e-24)
        discounted_d2 = mean_episode_d2 * discounter_tf # Discounted d2
        discounted_total_d2 = tf.reduce_sum(discounted_d2, axis=0) # Sum over time
        bound_ = w_return_mean - tf.sqrt((1-delta) * discounted_total_d2 / (delta*n_episodes)) * return_step_max
    elif bound == 'pdis-mean-d2':
        # Discount factor
        if gamma >= 1:
            discounter = [float(1+2*(horizon-t-1)) for t in range(0, horizon)]
        else:
            def f(t):
                return pow(gamma, 2*t) + (2*pow(gamma,t)*(pow(gamma, t+1) - pow(gamma, horizon))) / (1-gamma)
            discounter = [f(t) for t in range(0, horizon)]
        discounter_tf = tf.constant(discounter)
        mean_episode_d2 = tf.reduce_sum(d2_w_0t, axis=0) / (tf.reduce_sum(mask_split, axis=0) + 1e-24)
        discounted_d2 = mean_episode_d2 * discounter_tf # Discounted d2
        discounted_total_d2 = tf.reduce_sum(discounted_d2, axis=0) # Sum over time
        bound_ = w_return_mean - tf.sqrt((1-delta) * discounted_total_d2 / (delta*n_episodes)) * return_step_mean
    else:
        raise NotImplementedError()

    # Policy entropy for exploration
    ent = pi.pd.entropy()
    meanent = tf.reduce_mean(ent)
    losses_with_name.append((meanent, 'MeanEntropy'))
    # Add policy entropy bonus
    if entropy != 'none':
        scheme, v1, v2 = entropy.split(':')
        if scheme == 'step':
            entcoeff = tf.cond(iter_number_ < int(v2), lambda: float(v1), lambda: float(0.0))
            losses_with_name.append((entcoeff, 'EntropyCoefficient'))
            entbonus = entcoeff * meanent
            bound_ = bound_ + entbonus
        elif scheme == 'lin':
            ip = tf.cast(iter_number_ / max_iters, tf.float32)
            entcoeff_decay = tf.maximum(0.0, float(v2) + (float(v1) - float(v2)) * (1.0 - ip))
            losses_with_name.append((entcoeff_decay, 'EntropyCoefficient'))
            entbonus = entcoeff_decay * meanent
            bound_ = bound_ + entbonus
        elif scheme == 'exp':
            ent_f = tf.exp(-tf.abs(tf.reduce_mean(iw) - 1) * float(v2)) * float(v1)
            losses_with_name.append((ent_f, 'EntropyCoefficient'))
            bound_ = bound_ + ent_f * meanent
        else:
            raise Exception('Unrecognized entropy scheme.')

    losses_with_name.append((w_return_mean, 'ReturnMeanIW'))
    losses_with_name.append((bound_, 'Bound'))
    losses, loss_names = map(list, zip(*losses_with_name))

    if use_natural_gradient:
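        # Natural gradient machinery: grad_logprob is the flat gradient of the
        # importance-weighted log-likelihood; differentiating its dot product with the
        # probe vector p (Pearlmutter's trick) yields a Hessian-vector product, which
        # compute_linear_operator negates and exposes for conjugate gradient.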
        p = tf.placeholder(dtype=tf.float32, shape=[None])
        target_logpdf_episode = tf.reduce_sum(target_log_pdf_split * mask_split, axis=1)
        grad_logprob = U.flatgrad(tf.stop_gradient(iwn) * target_logpdf_episode, var_list)
        dot_product = tf.reduce_sum(grad_logprob * p)
        hess_logprob = U.flatgrad(dot_product, var_list)
        compute_linear_operator = U.function([p, ob_, ac_, disc_rew_, mask_], [-hess_logprob])

    assign_old_eq_new = U.function([], [], updates=[tf.assign(oldv, newv)
                for (oldv, newv) in zipsame(oldpi.get_variables(), pi.get_variables())])

    assert_ops = tf.group(*tf.get_collection('asserts'))
    print_ops = tf.group(*tf.get_collection('prints'))

    compute_lossandgrad = U.function([ob_, ac_, rew_, disc_rew_, clustered_rew_, mask_, iter_number_], losses + [U.flatgrad(bound_, var_list), assert_ops, print_ops])
    compute_grad = U.function([ob_, ac_, rew_, disc_rew_, clustered_rew_, mask_, iter_number_], [U.flatgrad(bound_, var_list), assert_ops, print_ops])
    compute_bound = U.function([ob_, ac_, rew_, disc_rew_, clustered_rew_, mask_, iter_number_], [bound_, assert_ops, print_ops])
    compute_losses = U.function([ob_, ac_, rew_, disc_rew_, clustered_rew_, mask_, iter_number_], losses)
    #compute_temp = U.function([ob_, ac_, rew_, disc_rew_, mask_], [ratio_cumsum, discounted_ratio])

    set_parameter = U.SetFromFlat(var_list)
    get_parameter = U.GetFlat(var_list)

    if sampler is None:
        seg_gen = traj_segment_generator(pi, env, n_episodes, horizon, stochastic=True)
        sampler = type("SequentialSampler", (object,), {"collect": lambda self, _: seg_gen.__next__()})()
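        # Fallback sampler: a throwaway class whose collect() just advances the
        # on-policy segment generator, ignoring the parameter vector it receives.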

    U.initialize()

    # Start optimizing

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=n_episodes)
    rewbuffer = deque(maxlen=n_episodes)

    while True:

        iters_so_far += 1

        if render_after is not None and iters_so_far % render_after == 0:
            if hasattr(env, 'render'):
                render(env, pi, horizon)

        if callback:
            callback(locals(), globals())

        if iters_so_far >= max_iters:
            print('Finished...')
            break

        logger.log('********** Iteration %i ************' % iters_so_far)

        theta = get_parameter()

        with timed('sampling'):
            seg = sampler.collect(theta)

        add_disc_rew(seg, gamma)

        lens, rets = seg['ep_lens'], seg['ep_rets']
        lenbuffer.extend(lens)
        rewbuffer.extend(rets)
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)

        # Get clustered reward
        reward_matrix = np.reshape(seg['disc_rew'] * seg['mask'], (n_episodes, horizon))
        ep_reward = np.sum(reward_matrix, axis=1)
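        # Optionally discretize episode returns so near-identical returns fall into the
        # same cluster: 'floor'/'ceil' round to integers, 'floor10'/'ceil10' to one
        # decimal place, 'floor100'/'ceil100' to two decimal places.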
        if reward_clustering == 'none':
            pass
        elif reward_clustering == 'floor':
            ep_reward = np.floor(ep_reward)
        elif reward_clustering == 'ceil':
            ep_reward = np.ceil(ep_reward)
        elif reward_clustering == 'floor10':
            ep_reward = np.floor(ep_reward * 10) / 10
        elif reward_clustering == 'ceil10':
            ep_reward = np.ceil(ep_reward * 10) / 10
        elif reward_clustering == 'floor100':
            ep_reward = np.floor(ep_reward * 100) / 100
        elif reward_clustering == 'ceil100':
            ep_reward = np.ceil(ep_reward * 100) / 100
        else:
            raise Exception('Unrecognized reward clustering scheme.')

        args = ob, ac, rew, disc_rew, clustered_rew, mask, iter_number = seg['ob'], seg['ac'], seg['rew'], seg['disc_rew'], ep_reward, seg['mask'], iters_so_far

        assign_old_eq_new()

        def evaluate_loss():
            loss = compute_bound(*args)
            return loss[0]

        def evaluate_gradient():
            gradient = compute_grad(*args)
            return gradient[0]

        if use_natural_gradient:
            def evaluate_fisher_vector_prod(x):
                return compute_linear_operator(x, *args)[0] + fisher_reg * x

            def evaluate_natural_gradient(g):
                return cg(evaluate_fisher_vector_prod, g, cg_iters=10, verbose=0)
        else:
            evaluate_natural_gradient = None

        with timed('summaries before'):
            logger.record_tabular("Iteration", iters_so_far)
            logger.record_tabular("InitialBound", evaluate_loss())
            logger.record_tabular("EpLenMean", np.mean(lenbuffer))
            logger.record_tabular("EpRewMean", np.mean(rewbuffer))
            logger.record_tabular("EpThisIter", len(lens))
            logger.record_tabular("EpisodesSoFar", episodes_so_far)
            logger.record_tabular("TimestepsSoFar", timesteps_so_far)
            logger.record_tabular("TimeElapsed", time.time() - tstart)

        if save_weights:
            logger.record_tabular('Weights', str(get_parameter()))
            import pickle
            with open('checkpoint.pkl', 'wb') as file:
                pickle.dump(theta, file)

        with timed("offline optimization"):
            theta, improvement = optimize_offline(theta,
                                                  set_parameter,
                                                  line_search,
                                                  evaluate_loss,
                                                  evaluate_gradient,
                                                  evaluate_natural_gradient,
                                                  max_offline_ite=max_offline_iters)

        set_parameter(theta)

        with timed('summaries after'):
            meanlosses = np.array(compute_losses(*args))
            for (lossname, lossval) in zip(loss_names, meanlosses):
                logger.record_tabular(lossname, lossval)

        logger.dump_tabular()

    env.close()
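
The learner above can add an entropy bonus whose coefficient is configured by a 'scheme:v1:v2' string; the following is a minimal plain-Python sketch of the three schedules built into the graph (step, lin, exp), with mean_iw standing in for the mean importance weight, provided for reference only:

import math

def entropy_coefficient(spec, iter_number, max_iters, mean_iw=1.0):
    """Sketch of the 'step', 'lin' and 'exp' entropy-coefficient schedules above."""
    scheme, v1, v2 = spec.split(':')
    v1, v2 = float(v1), float(v2)
    if scheme == 'step':
        # Constant coefficient v1 until iteration v2, then zero
        return v1 if iter_number < int(v2) else 0.0
    if scheme == 'lin':
        # Linear decay from v1 (at the first iteration) to v2 (at max_iters)
        progress = iter_number / max_iters
        return max(0.0, v2 + (v1 - v2) * (1.0 - progress))
    if scheme == 'exp':
        # Shrink v1 as the mean importance weight drifts away from 1
        return v1 * math.exp(-abs(mean_iw - 1.0) * v2)
    raise ValueError('Unrecognized entropy scheme.')
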
def learn(
        env,
        policy_fn,
        *,
        timesteps_per_batch,  # what to train on
        epsilon,
        beta,
        cg_iters,
        gamma,
        lam,  # advantage estimation
        trial,
        sess,
        method,
        entcoeff=0.0,
        cg_damping=1e-2,
        kl_target=0.01,
        vf_stepsize=3e-4,
        vf_iters=3,
        max_timesteps=0,
        max_episodes=0,
        max_iters=0,  # time constraint
        callback=None,
        TRPO=False):
    nworkers = MPI.COMM_WORLD.Get_size()
    rank = MPI.COMM_WORLD.Get_rank()
    np.set_printoptions(precision=3)
    # Setup losses and stuff
    # ----------------------------------------
    total_space = env.total_space
    ob_space = env.observation_space
    ac_space = env.action_space

    pi = policy_fn("pi",
                   ob_space,
                   ac_space,
                   ob_name="ob",
                   m_name="mask",
                   svfname="vfstate",
                   spiname="pistate")
    oldpi = policy_fn("oldpi",
                      ob_space,
                      ac_space,
                      ob_name="ob",
                      m_name="mask",
                      svfname="vfstate",
                      spiname="pistate")

    gpi = policy_fn("gpi",
                    total_space,
                    ac_space,
                    ob_name="gob",
                    m_name="gmask",
                    svfname="gvfstate",
                    spiname="gpistate")
    goldpi = policy_fn("goldpi",
                       total_space,
                       ac_space,
                       ob_name="gob",
                       m_name="gmask",
                       svfname="gvfstate",
                       spiname="gpistate")

    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    gatarg = tf.placeholder(dtype=tf.float32, shape=[None])
    gret = tf.placeholder(dtype=tf.float32, shape=[None])

    ob = U.get_placeholder_cached(name="ob")
    m = U.get_placeholder_cached(name="mask")
    svf = U.get_placeholder_cached(name="vfstate")
    spi = U.get_placeholder_cached(name="pistate")

    gob = U.get_placeholder_cached(name='gob')
    gm = U.get_placeholder_cached(name="gmask")
    gsvf = U.get_placeholder_cached(name="gvfstate")
    gspi = U.get_placeholder_cached(name="gpistate")

    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    gkloldnew = goldpi.pd.kl(gpi.pd)

    # crosskl_ob = pi.pd.kl(goldpi.pd)
    # crosskl_gob = gpi.pd.kl(oldpi.pd)
    crosskl_gob = pi.pd.kl(gpi.pd)
    crosskl_ob = gpi.pd.kl(pi.pd)

    pdmean = pi.pd.mean
    pdstd = pi.pd.std
    gpdmean = gpi.pd.mean
    gpdstd = gpi.pd.std

    ent = pi.pd.entropy()
    gent = gpi.pd.entropy()

    old_entropy = oldpi.pd.entropy()
    gold_entropy = goldpi.pd.entropy()

    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    meancrosskl = tf.reduce_mean(crosskl_ob)

    # meancrosskl = tf.maximum(tf.reduce_mean(crosskl_ob - 100), 0)

    gmeankl = tf.reduce_mean(gkloldnew)
    gmeanent = tf.reduce_mean(gent)
    gmeancrosskl = tf.reduce_mean(crosskl_gob)

    vferr = tf.reduce_mean(tf.square(pi.vpred - ret))
    gvferr = tf.reduce_mean(tf.square(gpi.vpred - gret))

    ratio = tf.exp(pi.pd.logp(ac) -
                   oldpi.pd.logp(ac))  # advantage * pnew / pold
    gratio = tf.exp(gpi.pd.logp(ac) - goldpi.pd.logp(ac))

    # Ratio objective
    # surrgain = tf.reduce_mean(ratio * atarg)
    # gsurrgain = tf.reduce_mean(gratio * gatarg)

    # Log objective
    surrgain = tf.reduce_mean(pi.pd.logp(ac) * atarg)
    gsurrgain = tf.reduce_mean(gpi.pd.logp(ac) * gatarg)

    optimgain = surrgain
    losses = [
        optimgain, meankl, meancrosskl, surrgain, meanent,
        tf.reduce_mean(ratio)
    ]
    loss_names = [
        "optimgain", "meankl", "meancrosskl", "surrgain", "entropy", "ratio"
    ]

    goptimgain = gsurrgain

    glosses = [
        goptimgain, gmeankl, gmeancrosskl, gsurrgain, gmeanent,
        tf.reduce_mean(gratio)
    ]
    gloss_names = [
        "goptimgain", "gmeankl", "gmeancrosskl", "gsurrgain", "gentropy",
        "gratio"
    ]

    dist = meankl
    gdist = gmeankl

    all_pi_var_list = pi.get_trainable_variables()
    all_var_list = [
        v for v in all_pi_var_list if v.name.split("/")[0].startswith("pi")
    ]
    var_list = [
        v for v in all_var_list if v.name.split("/")[1].startswith("pol")
    ]
    vf_var_list = [
        v for v in all_var_list if v.name.split("/")[1].startswith("vf")
    ]
    vfadam = MpiAdam(vf_var_list)
    poladam = MpiAdam(var_list)

    gall_gpi_var_list = gpi.get_trainable_variables()
    gall_var_list = [
        v for v in gall_gpi_var_list if v.name.split("/")[0].startswith("gpi")
    ]
    gvar_list = [
        v for v in gall_var_list if v.name.split("/")[1].startswith("pol")
    ]
    gvf_var_list = [
        v for v in gall_var_list if v.name.split("/")[1].startswith("vf")
    ]
    gvfadam = MpiAdam(gvf_var_list)
    # gpoladpam = MpiAdam(gvar_list)

    get_flat = U.GetFlat(var_list)
    set_from_flat = U.SetFromFlat(var_list)
    klgrads = tf.gradients(dist, var_list)
    # crossklgrads = tf.gradients(meancrosskl, var_list)

    flat_tangent = tf.placeholder(dtype=tf.float32,
                                  shape=[None],
                                  name="flat_tan")
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start + sz], shape))
        start += sz
    gvp = tf.add_n([
        tf.reduce_sum(g * tangent)
        for (g, tangent) in zipsame(klgrads, tangents)
    ])  #pylint: disable=E1111
    fvp = U.flatgrad(gvp, var_list)

    gget_flat = U.GetFlat(gvar_list)
    gset_from_flat = U.SetFromFlat(gvar_list)
    gklgrads = tf.gradients(gdist, gvar_list)
    # gcrossklgrads = tf.gradients(gmeancrosskl, gvar_list)

    gflat_tangent = tf.placeholder(dtype=tf.float32,
                                   shape=[None],
                                   name="gflat_tan")
    gshapes = [var.get_shape().as_list() for var in gvar_list]
    gstart = 0
    gtangents = []
    for shape in gshapes:
        sz = U.intprod(shape)
        gtangents.append(tf.reshape(gflat_tangent[gstart:gstart + sz], shape))
        gstart += sz
    ggvp = tf.add_n([
        tf.reduce_sum(g * tangent)
        for (g, tangent) in zipsame(gklgrads, gtangents)
    ])  #pylint: disable=E1111
    gfvp = U.flatgrad(ggvp, gvar_list)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])

    gassign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(goldpi.get_variables(), gpi.get_variables())
        ])

    compute_losses = U.function(
        [gob, gm, gsvf, gspi, ob, m, svf, spi, ac, atarg], losses)
    compute_lossandgrad = U.function(
        [gob, gm, gsvf, gspi, ob, m, svf, spi, ac, atarg],
        losses + [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function([flat_tangent, ob, m, svf, spi, ac, atarg], fvp)
    compute_vflossandgrad = U.function([ob, m, svf, spi, ret],
                                       U.flatgrad(vferr, vf_var_list))
    compute_crossklandgrad = U.function([ob, m, svf, spi, gob, gm, gsvf, gspi],
                                        U.flatgrad(meancrosskl, var_list))

    gcompute_losses = U.function(
        [ob, m, svf, spi, gob, gm, gsvf, gspi, ac, gatarg], glosses)
    gcompute_lossandgrad = U.function(
        [ob, m, svf, spi, gob, gm, gsvf, gspi, ac, gatarg],
        glosses + [U.flatgrad(goptimgain, gvar_list)])
    gcompute_fvp = U.function([gflat_tangent, gob, gm, gsvf, gspi, ac, gatarg],
                              gfvp)
    gcompute_vflossandgrad = U.function([gob, gm, gsvf, gspi, gret],
                                        U.flatgrad(gvferr, gvf_var_list))
    # compute_gcrossklandgrad = U.function([gob, ob], U.flatgrad(gmeancrosskl, gvar_list))

    saver = tf.train.Saver()

    @contextmanager
    def timed(msg):
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(
                colorize("done in %.3f seconds" % (time.time() - tstart),
                         color='magenta'))
        else:
            yield

    def allmean(x):
        assert isinstance(x, np.ndarray)
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
        return out

    U.initialize()

    # guided_initilizer(gpol=gvar_list, gvf=gvf_var_list, fpol=var_list, fvf=vf_var_list)

    th_init = get_flat()
    MPI.COMM_WORLD.Bcast(th_init, root=0)
    set_from_flat(th_init)
    vfadam.sync()
    poladam.sync()
    print("Init final policy param sum", th_init.sum(), flush=True)

    gth_init = gget_flat()
    MPI.COMM_WORLD.Bcast(gth_init, root=0)
    gset_from_flat(gth_init)
    gvfadam.sync()
    # gpoladpam.sync()
    print("Init guided policy param sum", gth_init.sum(), flush=True)

    # Initialize eta, omega optimizer
    init_eta = 0.5
    init_omega = 2.0
    eta_omega_optimizer = EtaOmegaOptimizer(beta, epsilon, init_eta,
                                            init_omega)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     gpi,
                                     env,
                                     timesteps_per_batch,
                                     stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    num_iters = max_timesteps // timesteps_per_batch
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards

    assert sum([max_iters > 0, max_timesteps > 0, max_episodes > 0]) == 1

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        logger.log("********** Iteration %i ************" % iters_so_far)

        with timed("sampling"):
            seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, new, atarg, tdlamret = seg["ob"], seg["ac"], seg["new"], seg[
            "adv"], seg["tdlamret"]
        gob, gatarg, gtdlamret = seg["gob"], seg["gadv"], seg["gtdlamret"]
        pistate, vfstate, gpistate, gvfstate = seg["pistate"], seg[
            "vfstate"], seg["gpistate"], seg["gvfstate"]

        vpredbefore = seg["vpred"]  # predicted value function before udpate
        gvpredbefore = seg["gvpred"]

        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate
        gatarg = (gatarg - gatarg.mean()) / gatarg.std()

        if hasattr(pi, "ret_rms"): pi.ret_rms.update(tdlamret)
        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        if hasattr(gpi, "ret_rms"): gpi.ret_rms.update(gtdlamret)
        if hasattr(gpi, "ob_rms"): gpi.ob_rms.update(gob)

        args = gob, new, gvfstate, gpistate, ob, new, vfstate, pistate, ac, atarg
        fvpargs = [arr[::5] for arr in args[4:]]

        gargs = ob, new, vfstate, pistate, gob, new, gvfstate, gpistate, ac, gatarg
        gfvpargs = [arr[::5] for arr in gargs[4:]]

        def fisher_vector_product(p):
            return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p

        def gfisher_vector_product(p):
            return allmean(gcompute_fvp(p, *gfvpargs)) + cg_damping * p

        assign_old_eq_new()  # set old parameter values to new parameter values
        gassign_old_eq_new()

        with timed("computegrad"):
            *lossbefore, g = compute_lossandgrad(*args)
            *glossbefore, gg = gcompute_lossandgrad(*gargs)

        lossbefore = allmean(np.array(lossbefore))
        g = allmean(g)

        glossbefore = allmean(np.array(glossbefore))
        gg = allmean(gg)

        if np.allclose(g, 0) or np.allclose(gg, 0):
            logger.log("Got zero gradient. not updating")
        else:
            with timed("cg"):
                stepdir = cg(fisher_vector_product,
                             g,
                             cg_iters=cg_iters,
                             verbose=rank == 0)
                gstepdir = cg(gfisher_vector_product,
                              gg,
                              cg_iters=cg_iters,
                              verbose=rank == 0)
            assert np.isfinite(gstepdir).all()
            assert np.isfinite(stepdir).all()

            if TRPO:
                #
                # TRPO specific code.
                # Find correct step size using line search
                #
                #TODO: also enable guided learning for TRPO
                shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
                lm = np.sqrt(shs / epsilon)
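                # shs approximates the KL of a full step (0.5 * s^T F s); dividing the
                # step by lm rescales it so that this quadratic KL estimate equals epsilon.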
                # logger.log("lagrange multiplier:", lm, "gnorm:", np.linalg.norm(g))
                fullstep = stepdir / lm
                expectedimprove = g.dot(fullstep)
                surrbefore = lossbefore[0]
                stepsize = 1.0
                thbefore = get_flat()
                for _ in range(10):
                    thnew = thbefore + fullstep * stepsize
                    set_from_flat(thnew)
                    meanlosses = surr, kl, *_ = allmean(
                        np.array(compute_losses(*args)))
                    improve = surr - surrbefore
                    logger.log("Expected: %.3f Actual: %.3f" %
                               (expectedimprove, improve))
                    if not np.isfinite(meanlosses).all():
                        logger.log("Got non-finite value of losses -- bad!")
                    elif kl > epsilon * 1.5:
                        logger.log("violated KL constraint. shrinking step.")
                    elif improve < 0:
                        logger.log("surrogate didn't improve. shrinking step.")
                    else:
                        logger.log("Stepsize OK!")
                        break
                    stepsize *= .5
                else:
                    logger.log("couldn't compute a good step")
                    set_from_flat(thbefore)
            else:
                #
                # COPOS specific implementation.
                #

                copos_update_dir = stepdir
                gcopos_update_dir = gstepdir

                # Split direction into log-linear 'w_theta' and non-linear 'w_beta' parts
                w_theta, w_beta = pi.split_w(copos_update_dir)
                gw_theta, gw_beta = gpi.split_w(gcopos_update_dir)

                # q_beta(s,a) = \grad_beta \log \pi(a|s) * w_beta
                #             = features_beta(s) * K^T * Prec * a
                # q_beta = self.target.get_q_beta(features_beta, actions)

                Waa, Wsa = pi.w2W(w_theta)
                wa = pi.get_wa(ob, w_beta)

                gWaa, gWsa = gpi.w2W(gw_theta)
                gwa = gpi.get_wa(gob, gw_beta)

                varphis = pi.get_varphis(ob)
                gvarphis = gpi.get_varphis(gob)

                # Optimize eta and omega
                tmp_ob = np.zeros(
                    (1, ) + ob_space.shape
                )  # We assume that entropy does not depend on the NN
                old_ent = old_entropy.eval({oldpi.ob: tmp_ob})[0]
                eta, omega = eta_omega_optimizer.optimize(
                    w_theta, Waa, Wsa, wa, varphis, pi.get_kt(),
                    pi.get_prec_matrix(), pi.is_new_policy_valid, old_ent)
                logger.log("Initial eta of final policy: " + str(eta) +
                           " and omega: " + str(omega))

                gtmp_ob = np.zeros((1, ) + total_space.shape)
                gold_ent = gold_entropy.eval({goldpi.ob: gtmp_ob})[0]
                geta, gomega = eta_omega_optimizer.optimize(
                    gw_theta, gWaa, gWsa, gwa, gvarphis, gpi.get_kt(),
                    gpi.get_prec_matrix(), gpi.is_new_policy_valid, gold_ent)
                logger.log("Initial eta of guided policy: " + str(geta) +
                           " and omega: " + str(gomega))

                current_theta_beta = get_flat()
                prev_theta, prev_beta = pi.all_to_theta_beta(
                    current_theta_beta)

                gcurrent_theta_beta = gget_flat()
                gprev_theta, gprev_beta = gpi.all_to_theta_beta(
                    gcurrent_theta_beta)

                for i in range(2):
                    # Do a line search for both theta and beta parameters by adjusting only eta
                    eta = eta_search(w_theta, w_beta, eta, omega, allmean,
                                     compute_losses, get_flat, set_from_flat,
                                     pi, epsilon, args)
                    logger.log("Updated eta of final policy, eta: " +
                               str(eta) + " and omega: " + str(omega))

                    # Find proper omega for new eta. Use old policy parameters first.
                    set_from_flat(pi.theta_beta_to_all(prev_theta, prev_beta))
                    eta, omega = \
                        eta_omega_optimizer.optimize(w_theta, Waa, Wsa, wa, varphis, pi.get_kt(),
                                                     pi.get_prec_matrix(), pi.is_new_policy_valid, old_ent, eta)
                    logger.log("Updated omega of final policy, eta: " +
                               str(eta) + " and omega: " + str(omega))

                    geta = eta_search(gw_theta, gw_beta, geta, gomega, allmean,
                                      gcompute_losses, gget_flat,
                                      gset_from_flat, gpi, epsilon, gargs)
                    logger.log("updated eta of guided policy, eta:" +
                               str(geta) + "and omega:" + str(gomega))

                    gset_from_flat(
                        gpi.theta_beta_to_all(gprev_theta, gprev_beta))
                    geta, gomega = eta_omega_optimizer.optimize(
                        gw_theta, gWaa, gWsa, gwa, gvarphis, gpi.get_kt(),
                        gpi.get_prec_matrix(), gpi.is_new_policy_valid,
                        gold_ent, geta)
                    logger.log("Updated omega of guided policy, eta:" +
                               str(geta) + "and omega:" + str(gomega))

                # Use final policy
                logger.log("Final eta of final policy: " + str(eta) +
                           " and omega: " + str(omega))
                logger.log("Final eta of guided policy: " + str(geta) +
                           "and omega:" + str(gomega))

                cur_theta = (eta * prev_theta +
                             w_theta.reshape(-1, )) / (eta + omega)
                cur_beta = prev_beta + w_beta.reshape(-1, ) / eta
                set_from_flat(pi.theta_beta_to_all(cur_theta, cur_beta))

                gcur_theta = (geta * gprev_theta +
                              gw_theta.reshape(-1, )) / (geta + gomega)
                gcur_beta = gprev_beta + gw_beta.reshape(-1, ) / geta
                gset_from_flat(gpi.theta_beta_to_all(gcur_theta, gcur_beta))
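                # COPOS closed-form update: the log-linear parameters move to
                # (eta * theta_old + w_theta) / (eta + omega), while the non-linear
                # parameters take a step of w_beta / eta, with eta and omega coming
                # from the dual optimization above.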

                meanlosses = surr, kl, crosskl, *_ = allmean(
                    np.array(compute_losses(*args)))
                gmeanlosses = gsurr, gkl, gcrosskl, *_ = allmean(
                    np.array(gcompute_losses(*gargs)))

                # poladam.update(allmean(compute_crossklandgrad(ob, gob)), vf_stepsize)
                # gpoladpam.update(allmean(compute_gcrossklandgrad(gob, ob)), vf_stepsize)

                for _ in range(vf_iters):
                    for (mbob, mbgob) in dataset.iterbatches(
                        (seg["ob"], seg["gob"]),
                            include_final_partial_batch=False,
                            batch_size=64):
                        g = allmean(compute_crossklandgrad(mbob, mbgob))
                        poladam.update(g, vf_stepsize)

            # if nworkers > 1 and iters_so_far % 20 == 0:
            #     paramsums = MPI.COMM_WORLD.allgather((thnew.sum(), vfadam.getflat().sum())) # list of tuples
            #     assert all(np.allclose(ps, paramsums[0]) for ps in paramsums[1:])

        for (lossname, lossval) in zip(loss_names, meanlosses):
            logger.record_tabular(lossname, lossval)

        for (lossname, lossval) in zip(gloss_names, gmeanlosses):
            logger.record_tabular(lossname, lossval)

        with timed("vf"):
            for _ in range(vf_iters):
                for (mbob, mbret) in dataset.iterbatches(
                    (seg["ob"], seg["tdlamret"]),
                        include_final_partial_batch=False,
                        batch_size=64):
                    g = allmean(compute_vflossandgrad(mbob, mbret))
                    vfadam.update(g, vf_stepsize)
                for (mbob, mbret) in dataset.iterbatches(
                    (seg["gob"], seg["gtdlamret"]),
                        include_final_partial_batch=False,
                        batch_size=64):
                    gg = allmean(gcompute_vflossandgrad(mbob, mbret))
                    gvfadam.update(gg, vf_stepsize)

        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))
        logger.record_tabular("gev_tdlam_before",
                              explained_variance(gvpredbefore, gtdlamret))

        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))

        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        logger.record_tabular("Name", method)
        logger.record_tabular("Iteration", iters_so_far)
        logger.record_tabular("trial", trial)

        if rank == 0:
            logger.dump_tabular()

        if iters_so_far % 100 == 0 or iters_so_far == 1 or iters_so_far == num_iters:
            # sess = tf.get_default_session()
            checkdir = get_dir(osp.join(logger.get_dir(), 'checkpoints'))
            savepath = osp.join(checkdir, '%.5i.ckpt' % iters_so_far)
            saver.save(sess, save_path=savepath)
            print("save model to path:", savepath)
Example #29
0
def learn(
    env,
    policy_fn,
    *,
    timesteps_per_actorbatch,  # timesteps per actor per update
    clip_param,
    entcoeff,  # clipping parameter epsilon, entropy coeff
    optim_epochs,
    optim_stepsize,
    optim_batchsize,  # optimization hypers
    gamma,
    lam,  # advantage estimation
    max_timesteps=0,
    max_episodes=0,
    max_iters=0,
    max_seconds=0,  # time constraint
    callback=None,  # you can do anything in the callback, since it takes locals(), globals()
    adam_epsilon=1e-5,
    schedule='constant'  # annealing for stepsize parameters (epsilon and adam)
):
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_fn("pi", ob_space,
                   ac_space)  # Construct network for new policy
    oldpi = policy_fn("oldpi", ob_space, ac_space)  # Network for old policy
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule
    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = tf.clip_by_value(ratio, 1.0 - clip_param,
                             1.0 + clip_param) * atarg  #
    pol_surr = -tf.reduce_mean(tf.minimum(
        surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)
    vf_loss = tf.reduce_mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult],
                             losses + [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)

    U.initialize()
    adam.sync()

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_actorbatch,
                                     stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards
    eprewards = []

    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        # logger.log("********** Iteration %i ************"%iters_so_far)

        seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before udpate
        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate
        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                    shuffle=not pi.recurrent)
        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        assign_old_eq_new()  # set old parameter values to new parameter values
        # logger.log("Optimizing...")
        # logger.log(fmt_row(13, loss_names))
        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"],
                                            batch["atarg"], batch["vtarg"],
                                            cur_lrmult)
                adam.update(g, optim_stepsize * cur_lrmult)
                losses.append(newlosses)
            # logger.log(fmt_row(13, np.mean(losses, axis=0)))

        # logger.log("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob"], batch["ac"],
                                       batch["atarg"], batch["vtarg"],
                                       cur_lrmult)
            losses.append(newlosses)
        meanlosses, _, _ = mpi_moments(losses, axis=0)
        # logger.log(fmt_row(13, meanlosses))
        for (lossval, name) in zipsame(meanlosses, loss_names):
            pass
            # logger.record_tabular("loss_"+name, lossval)
        # logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))
        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        # logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        # logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        # logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1
        # logger.record_tabular("EpisodesSoFar", episodes_so_far)
        # logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        # logger.record_tabular("TimeElapsed", time.time() - tstart)

        eprewards.append(np.mean(rewbuffer))

        if MPI.COMM_WORLD.Get_rank() == 0:
            pass
            # logger.dump_tabular()

    return np.mean(eprewards)
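
Example #29 above minimizes the negation of PPO's clipped surrogate; the following standalone NumPy sketch of that pessimistic objective is included for reference only and is not used by the example:

import numpy as np

def ppo_clipped_surrogate(logp_new, logp_old, adv, clip_param=0.2):
    """Mean of min(r * A, clip(r, 1 - eps, 1 + eps) * A) with r = exp(logp_new - logp_old)."""
    ratio = np.exp(logp_new - logp_old)
    clipped = np.clip(ratio, 1.0 - clip_param, 1.0 + clip_param)
    return np.mean(np.minimum(ratio * adv, clipped * adv))
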
Example #30
0
def run_hoof_all(
        network,
        env,
        total_timesteps,
        timesteps_per_batch,  # what to train on
        kl_range,
        gamma_range,
        lam_range,  # advantage estimation
        num_kl,
        num_gamma_lam,
        cg_iters=10,
        seed=None,
        ent_coef=0.0,
        cg_damping=1e-2,
        vf_stepsize=3e-4,
        vf_iters=3,
        max_episodes=0,
        max_iters=0,  # time constraint
        callback=None,
        load_path=None,
        **network_kwargs):
    '''
    Learn a policy with TRPO while selecting the KL bound, gamma and lambda on the fly (HOOF).
    Parameters:
    ----------
    network                 neural network to learn. Can be either string ('mlp', 'cnn', 'lstm', 'lnlstm' for basic types)
                            or function that takes input placeholder and returns tuple (output, None) for feedforward nets
                            or (output, (state_placeholder, state_output, mask_placeholder)) for recurrent nets
    env                     environment (one of the gym environments or wrapped via baselines.common.vec_env.VecEnv-type class)
    timesteps_per_batch     timesteps per gradient estimation batch
    kl_range                range of candidate KL bounds ( KL(pi_old || pi) ) sampled at each update
    gamma_range             range of candidate discount factors sampled at each update
    lam_range               range of candidate GAE lambdas sampled at each update
    num_kl                  number of KL candidates evaluated per update
    num_gamma_lam           number of (gamma, lambda) candidates evaluated per update
    ent_coef                coefficient of policy entropy term in the optimization objective
    cg_iters                number of iterations of conjugate gradient algorithm
    cg_damping              conjugate gradient damping
    vf_stepsize             learning rate for adam optimizer used to optimize value function loss
    vf_iters                number of value function optimization iterations per policy optimization step
    total_timesteps         max number of timesteps
    max_episodes            max number of episodes
    max_iters               maximum number of policy optimization iterations
    callback                function to be called with (locals(), globals()) each policy optimization step
    load_path               str, path to load the model from (default: None, i.e. no model is loaded)
    **network_kwargs        keyword arguments to the policy / network builder. See baselines.common/policies.py/build_policy and arguments to a particular type of network
    Returns:
    -------
    learnt model
    '''
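    # A hypothetical call (values are illustrative only), assuming `env` was created
    # with one of the usual Gym / VecEnv wrappers:
    #   pi = run_hoof_all('mlp', env, total_timesteps=1000000, timesteps_per_batch=2048,
    #                     kl_range=[1e-3, 1e-1], gamma_range=[0.95, 0.999],
    #                     lam_range=[0.9, 1.0], num_kl=5, num_gamma_lam=5)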

    MPI = None
    nworkers = 1
    rank = 0

    cpus_per_worker = 1
    U.get_session(
        config=tf.ConfigProto(allow_soft_placement=True,
                              inter_op_parallelism_threads=cpus_per_worker,
                              intra_op_parallelism_threads=cpus_per_worker))

    policy = build_policy(env, network, value_network='copy', **network_kwargs)
    set_global_seeds(seed)

    np.set_printoptions(precision=3)
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space

    # +2 for gamma, lambda
    ob = tf.placeholder(shape=(None, env.observation_space.shape[0] + 2),
                        dtype=env.observation_space.dtype,
                        name='Ob')
    with tf.variable_scope("pi"):
        pi = policy(observ_placeholder=ob)
    with tf.variable_scope("oldpi"):
        oldpi = policy(observ_placeholder=ob)

    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    entbonus = ent_coef * meanent

    vferr = tf.reduce_mean(tf.square(pi.vf - ret))

    ratio = tf.exp(pi.pd.logp(ac) -
                   oldpi.pd.logp(ac))  # advantage * pnew / pold
    surrgain = tf.reduce_mean(ratio * atarg)

    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

    dist = meankl

    all_var_list = get_trainable_variables("pi")
    var_list = get_pi_trainable_variables("pi")
    vf_var_list = get_vf_trainable_variables("pi")

    vfadam = MpiAdam(vf_var_list)

    get_flat = U.GetFlat(var_list)
    set_from_flat = U.SetFromFlat(var_list)
    klgrads = tf.gradients(dist, var_list)
    flat_tangent = tf.placeholder(dtype=tf.float32,
                                  shape=[None],
                                  name="flat_tan")
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start + sz], shape))
        start += sz
    gvp = tf.add_n([
        tf.reduce_sum(g * tangent)
        for (g, tangent) in zipsame(klgrads, tangents)
    ])  #pylint: disable=E1111
    fvp = U.flatgrad(gvp, var_list)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(get_variables("oldpi"), get_variables("pi"))
        ])

    compute_ratio = U.function(
        [ob, ac, atarg], ratio)  # IS ratio - used for computing IS weights

    compute_losses = U.function([ob, ac, atarg], losses)
    compute_lossandgrad = U.function([ob, ac, atarg], losses +
                                     [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    compute_vflossandgrad = U.function([ob, ret],
                                       U.flatgrad(vferr, vf_var_list))

    @contextmanager
    def timed(msg):
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(
                colorize("done in %.3f seconds" % (time.time() - tstart),
                         color='magenta'))
        else:
            yield

    def allmean(x):
        assert isinstance(x, np.ndarray)
        if MPI is not None:
            out = np.empty_like(x)
            MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
            out /= nworkers
        else:
            out = np.copy(x)

        return out

    U.initialize()
    if load_path is not None:
        pi.load(load_path)

    th_init = get_flat()
    if MPI is not None:
        MPI.COMM_WORLD.Bcast(th_init, root=0)

    set_from_flat(th_init)
    vfadam.sync()
    print("Init param sum", th_init.sum(), flush=True)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator_with_gl(pi,
                                             env,
                                             timesteps_per_batch,
                                             stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards

    if sum([max_iters > 0, total_timesteps > 0, max_episodes > 0]) == 0:
        # nothing to be done
        return pi

    assert sum([max_iters>0, total_timesteps>0, max_episodes>0]) < 2, \
        'out of max_iters, total_timesteps, and max_episodes only one should be specified'

    kl_range = np.atleast_1d(kl_range)
    gamma_range = np.atleast_1d(gamma_range)
    lam_range = np.atleast_1d(lam_range)

    while True:
        if callback: callback(locals(), globals())
        if total_timesteps and timesteps_so_far >= total_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        logger.log("********** Iteration %i ************" % iters_so_far)

        with timed("sampling"):
            seg = seg_gen.__next__()

        thbefore = get_flat()

        rand_gamma = gamma_range[0] + (
            gamma_range[-1] - gamma_range[0]) * np.random.rand(num_gamma_lam)
        rand_lam = lam_range[0] + (
            lam_range[-1] - lam_range[0]) * np.random.rand(num_gamma_lam)
        rand_kl = kl_range[0] + (kl_range[-1] -
                                 kl_range[0]) * np.random.rand(num_kl)

        opt_polval = -10**8
        est_polval = np.zeros((num_gamma_lam, num_kl))
        ob_lam_gam = []
        tdlamret = []
        vpred = []

        for gl in range(num_gamma_lam):
            oblg, vpredbefore, atarg, tdlr = add_vtarg_and_adv_with_gl(
                pi, seg, rand_gamma[gl], rand_lam[gl])
            ob_lam_gam += [oblg]
            tdlamret += [tdlr]
            vpred += [vpredbefore]
            atarg = (atarg - atarg.mean()) / atarg.std(
            )  # standardized advantage function estimate

            pol_ob = np.concatenate(
                (seg['ob'], np.zeros(seg['ob'].shape[:-1] + (2, ))), axis=-1)
            args = pol_ob, seg["ac"], atarg
            fvpargs = [arr[::5] for arr in args]

            def fisher_vector_product(p):
                return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p

            assign_old_eq_new(
            )  # set old parameter values to new parameter values
            with timed("computegrad"):
                *lossbefore, g = compute_lossandgrad(*args)
            lossbefore = allmean(np.array(lossbefore))
            g = allmean(g)
            if np.allclose(g, 0):
                logger.log("Got zero gradient. not updating")
            else:
                with timed("cg"):
                    stepdir = cg(fisher_vector_product,
                                 g,
                                 cg_iters=cg_iters,
                                 verbose=False)
                assert np.isfinite(stepdir).all()
                shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
                surrbefore = lossbefore[0]

                for m, kl in enumerate(rand_kl):
                    lm = np.sqrt(shs / kl)
                    fullstep = stepdir / lm
                    thnew = thbefore + fullstep
                    set_from_flat(thnew)

                    # compute the IS estimates
                    lik_ratio = compute_ratio(*args)
                    est_polval[gl, m] = wis_estimate(seg, lik_ratio)

                    # update best policy found so far
                    if est_polval[gl, m] > opt_polval:
                        opt_polval = est_polval[gl, m]
                        opt_th = thnew
                        opt_kl = kl
                        opt_gamma = rand_gamma[gl]
                        opt_lam = rand_lam[gl]
                        opt_vpredbefore = vpredbefore
                        opt_tdlr = tdlr
                        meanlosses = surr, kl, *_ = allmean(
                            np.array(compute_losses(*args)))
                        improve = surr - surrbefore
                        expectedimprove = g.dot(fullstep)
                    set_from_flat(thbefore)
        logger.log("Expected: %.3f Actual: %.3f" % (expectedimprove, improve))
        set_from_flat(opt_th)

        for (lossname, lossval) in zip(loss_names, meanlosses):
            logger.record_tabular(lossname, lossval)

        ob_lam_gam = np.concatenate(ob_lam_gam, axis=0)
        tdlamret = np.concatenate(tdlamret, axis=0)
        vpred = np.concatenate(vpred, axis=0)
        with timed("vf"):
            for _ in range(vf_iters):
                for (mbob, mbret) in dataset.iterbatches(
                    (ob_lam_gam, tdlamret),
                        include_final_partial_batch=False,
                        batch_size=num_gamma_lam * 64):
                    g = allmean(compute_vflossandgrad(mbob, mbret))
                    vfadam.update(g, vf_stepsize)

        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpred, tdlamret))

        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        if MPI is not None:
            listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        else:
            listoflrpairs = [lrlocal]

        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        logger.record_tabular("Opt_KL", opt_kl)
        logger.record_tabular("gamma", opt_gamma)
        logger.record_tabular("lam", opt_lam)

        if rank == 0:
            logger.dump_tabular()

    return pi
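
run_hoof_all above scores every candidate (gamma, lambda, KL) step with wis_estimate, a weighted importance sampling estimate of the candidate policy's return that is not shown in this listing; below is a self-contained sketch of such an estimator, under the assumption that it reweights per-episode returns by their (self-normalized) likelihood ratios:

import numpy as np

def wis_return_estimate(ep_returns, ep_weights):
    """Weighted (self-normalized) importance sampling estimate of the mean return.

    ep_returns: per-episode returns collected under the behavior policy.
    ep_weights: per-episode likelihood ratios of the candidate policy vs. the behavior policy.
    """
    ep_returns = np.asarray(ep_returns, dtype=np.float64)
    ep_weights = np.asarray(ep_weights, dtype=np.float64)
    return float(np.sum(ep_weights * ep_returns) / np.sum(ep_weights))
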
def learn(
        env,
        policy_fn,
        *,
        timesteps_per_batch,  # what to train on
        epsilon,
        beta,
        cg_iters,
        gamma,
        lam,  # advantage estimation
        entcoeff=0.0,
        cg_damping=1e-2,
        vf_stepsize=3e-4,
        vf_iters=3,
        max_timesteps=0,
        max_episodes=0,
        max_iters=0,  # time constraint
        callback=None,
        TRPO=False):
    nworkers = MPI.COMM_WORLD.Get_size()
    rank = MPI.COMM_WORLD.Get_rank()
    np.set_printoptions(precision=3)
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    discrete_ac_space = isinstance(ac_space, gym.spaces.Discrete)
    print("ob_space: " + str(ob_space))
    print("ac_space: " + str(ac_space))
    pi = policy_fn("pi", ob_space, ac_space)
    oldpi = policy_fn("oldpi", ob_space, ac_space)
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    old_entropy = oldpi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    entbonus = entcoeff * meanent

    vferr = tf.reduce_mean(tf.square(pi.vpred - ret))

    ratio = tf.exp(pi.pd.logp(ac) -
                   oldpi.pd.logp(ac))  # advantage * pnew / pold
    surrgain = tf.reduce_mean(ratio * atarg)

    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "Entropy"]

    dist = meankl

    all_var_list = pi.get_trainable_variables()
    all_var_list = [
        v for v in all_var_list if v.name.split("/")[0].startswith("pi")
    ]
    var_list = [
        v for v in all_var_list if v.name.split("/")[1].startswith("pol")
    ]
    vf_var_list = [
        v for v in all_var_list if v.name.split("/")[1].startswith("vf")
    ]
    vfadam = MpiAdam(vf_var_list)

    get_flat = U.GetFlat(var_list)
    set_from_flat = U.SetFromFlat(var_list)
    klgrads = tf.gradients(dist, var_list)
    flat_tangent = tf.placeholder(dtype=tf.float32,
                                  shape=[None],
                                  name="flat_tan")
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start + sz], shape))
        start += sz
    # gvp: gradient-vector product of the KL with the tangent; fvp: its gradient, the Fisher-vector product
    gvp = tf.add_n([
        tf.reduce_sum(g * tangent)
        for (g, tangent) in zipsame(klgrads, tangents)
    ])  #pylint: disable=E1111
    fvp = U.flatgrad(gvp, var_list)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg], losses)
    compute_lossandgrad = U.function([ob, ac, atarg], losses +
                                     [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    compute_vflossandgrad = U.function([ob, ret],
                                       U.flatgrad(vferr, vf_var_list))

    @contextmanager
    def timed(msg):
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(
                colorize("done in %.3f seconds" % (time.time() - tstart),
                         color='magenta'))
        else:
            yield

    def allmean(x):
        assert isinstance(x, np.ndarray)
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
        return out

    U.initialize()
    th_init = get_flat()
    MPI.COMM_WORLD.Bcast(th_init, root=0)
    set_from_flat(th_init)
    vfadam.sync()
    print("Init param sum", th_init.sum(), flush=True)

    # Initialize eta, omega optimizer
    if discrete_ac_space:
        init_eta = 1
        init_omega = 0.5
        eta_omega_optimizer = EtaOmegaOptimizerDiscrete(
            beta, epsilon, init_eta, init_omega)
    else:
        init_eta = 0.5
        init_omega = 2.0
        # Continuous-action case: optimizer for the eta and omega dual variables
        eta_omega_optimizer = EtaOmegaOptimizer(beta, epsilon, init_eta,
                                                init_omega)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_batch,
                                     stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards

    assert sum([max_iters > 0, max_timesteps > 0, max_episodes > 0]) == 1

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        logger.log("********** Iteration %i ************" % iters_so_far)

        with timed("sampling"):
            seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)
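        # add_vtarg_and_adv fills seg["adv"] with GAE(gamma, lam) advantages,
        #   A_t = sum_l (gamma*lam)^l * delta_{t+l},  delta_t = r_t + gamma*V(s_{t+1}) - V(s_t),
        # and seg["tdlamret"] = adv + vpred as the value-function regression target.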

        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update
        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate
        #print(ob[:20])
        #print(ac[:20])

        if hasattr(pi, "ret_rms"): pi.ret_rms.update(tdlamret)
        if hasattr(pi, "ob_rms"):
            print(pi.ob_rms.mean)
            pi.ob_rms.update(ob)  # update running mean/std for policy

        args = seg["ob"], seg["ac"], atarg
        fvpargs = [arr[::5] for arr in args]

        def fisher_vector_product(p):
            return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p

        assign_old_eq_new()  # set old parameter values to new parameter values
        with timed("computegrad"):
            *lossbefore, g = compute_lossandgrad(*args)
        lossbefore = allmean(np.array(lossbefore))
        g = allmean(g)
        if np.allclose(g, 0):
            logger.log("Got zero gradient. not updating")
        else:
            with timed("cg"):
                stepdir = cg(fisher_vector_product,
                             g,
                             cg_iters=cg_iters,
                             verbose=rank == 0)
            assert np.isfinite(stepdir).all()

            if TRPO:
                #
                # TRPO specific code.
                # Find correct step size using line search
                #
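                # Under the quadratic model KL(theta) ~= 0.5 * d^T H d for a step d,
                # scaling the CG direction s by 1/lm with lm = sqrt(0.5 * s^T H s / epsilon)
                # gives a full step whose modeled KL equals epsilon; the backtracking
                # loop below then halves the step until the surrogate improves and the
                # measured KL stays below 1.5 * epsilon.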
                shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
                lm = np.sqrt(shs / epsilon)
                # logger.log("lagrange multiplier:", lm, "gnorm:", np.linalg.norm(g))
                fullstep = stepdir / lm
                expectedimprove = g.dot(fullstep)
                surrbefore = lossbefore[0]
                stepsize = 1.0
                thbefore = get_flat()
                for _ in range(10):
                    thnew = thbefore + fullstep * stepsize
                    set_from_flat(thnew)
                    meanlosses = surr, kl, *_ = allmean(
                        np.array(compute_losses(*args)))
                    improve = surr - surrbefore
                    logger.log("Expected: %.3f Actual: %.3f" %
                               (expectedimprove, improve))
                    if not np.isfinite(meanlosses).all():
                        logger.log("Got non-finite value of losses -- bad!")
                    elif kl > epsilon * 1.5:
                        logger.log("violated KL constraint. shrinking step.")
                    elif improve < 0:
                        logger.log("surrogate didn't improve. shrinking step.")
                    else:
                        logger.log("Stepsize OK!")
                        break
                    stepsize *= .5
                # Fallback from the reference TRPO implementation (disabled in this listing):
                # else:
                #     logger.log("couldn't compute a good step")
                #     set_from_flat(thbefore)
                if nworkers > 1 and iters_so_far % 20 == 0:
                    paramsums = MPI.COMM_WORLD.allgather(
                        (thnew.sum(),
                         vfadam.getflat().sum()))  # list of tuples
                    assert all(
                        np.allclose(ps, paramsums[0]) for ps in paramsums[1:])

            else:
                #
                # COPOS specific implementation.
                #
                copos_update_dir = stepdir

                # Split direction into log-linear 'w_theta' and non-linear 'w_beta' parts
                w_theta, w_beta = pi.split_w(copos_update_dir)

                tmp_ob = np.zeros(
                    (1, ) + env.observation_space.shape
                )  # We assume that entropy does not depend on the NN

                # Optimize eta and omega
                if discrete_ac_space:
                    entropy = lossbefore[4]
                    #entropy = - 1/timesteps_per_batch * np.sum(np.sum(pi.get_action_prob(ob) * pi.get_log_action_prob(ob), axis=1))
                    eta, omega = eta_omega_optimizer.optimize(
                        pi.compute_F_w(ob, copos_update_dir),
                        pi.get_log_action_prob(ob), timesteps_per_batch,
                        entropy)
                else:
                    Waa, Wsa = pi.w2W(w_theta)
                    wa = pi.get_wa(ob, w_beta)

                    varphis = pi.get_varphis(ob)

                    #old_ent = old_entropy.eval({oldpi.ob: tmp_ob})[0]
                    old_ent = lossbefore[4]
                    eta, omega = eta_omega_optimizer.optimize(
                        w_theta, Waa, Wsa, wa, varphis, pi.get_kt(),
                        pi.get_prec_matrix(), pi.is_new_policy_valid, old_ent)
                logger.log("Initial eta: " + str(eta) + " and omega: " +
                           str(omega))

                current_theta_beta = get_flat()
                prev_theta, prev_beta = pi.all_to_theta_beta(
                    current_theta_beta)

                if discrete_ac_space:
                    # Do a line search for both theta and beta parameters by adjusting only eta
                    eta = eta_search(w_theta, w_beta, eta, omega, allmean,
                                     compute_losses, get_flat, set_from_flat,
                                     pi, epsilon, args, discrete_ac_space)
                    logger.log("Updated eta, eta: " + str(eta))
                    set_from_flat(pi.theta_beta_to_all(prev_theta, prev_beta))
                    # Find proper omega for new eta. Use old policy parameters first.
                    eta, omega = eta_omega_optimizer.optimize(
                        pi.compute_F_w(ob, copos_update_dir),
                        pi.get_log_action_prob(ob), timesteps_per_batch,
                        entropy, eta)
                    logger.log("Updated omega, eta: " + str(eta) +
                               " and omega: " + str(omega))

                    # do line search for ratio for non-linear "beta" parameter values
                    #ratio = beta_ratio_line_search(w_theta, w_beta, eta, omega, allmean, compute_losses, get_flat, set_from_flat, pi,
                    #                     epsilon, beta, args)
                    # set ratio to 1 if we do not use beta ratio line search
                    ratio = 1
                    #print("ratio from line search: " + str(ratio))
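                    # COPOS parameter update: the log-linear part is interpolated,
                    #   theta_new = (eta * theta_old + w_theta) / (eta + omega),
                    # while the non-linear part takes a natural-gradient step,
                    #   beta_new = beta_old + ratio * w_beta / eta.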
                    cur_theta = (eta * prev_theta +
                                 w_theta.reshape(-1, )) / (eta + omega)
                    cur_beta = prev_beta + ratio * w_beta.reshape(-1, ) / eta
                else:
                    for i in range(2):
                        # Do a line search for both theta and beta parameters by adjusting only eta
                        eta = eta_search(w_theta, w_beta, eta, omega, allmean,
                                         compute_losses, get_flat,
                                         set_from_flat, pi, epsilon, args)
                        logger.log("Updated eta, eta: " + str(eta) +
                                   " and omega: " + str(omega))

                        # Find proper omega for new eta. Use old policy parameters first.
                        set_from_flat(
                            pi.theta_beta_to_all(prev_theta, prev_beta))
                        eta, omega = \
                            eta_omega_optimizer.optimize(w_theta, Waa, Wsa, wa, varphis, pi.get_kt(),
                                                         pi.get_prec_matrix(), pi.is_new_policy_valid, old_ent, eta)
                        logger.log("Updated omega, eta: " + str(eta) +
                                   " and omega: " + str(omega))

                    # Use final policy
                    logger.log("Final eta: " + str(eta) + " and omega: " +
                               str(omega))
                    cur_theta = (eta * prev_theta +
                                 w_theta.reshape(-1, )) / (eta + omega)
                    cur_beta = prev_beta + w_beta.reshape(-1, ) / eta

                paramnew = allmean(pi.theta_beta_to_all(cur_theta, cur_beta))
                set_from_flat(paramnew)
                meanlosses = surr, kl, *_ = allmean(
                    np.array(compute_losses(*args)))
                if nworkers > 1 and iters_so_far % 20 == 0:
                    paramsums = MPI.COMM_WORLD.allgather(
                        (paramnew.sum(),
                         vfadam.getflat().sum()))  # list of tuples
                    assert all(
                        np.allclose(ps, paramsums[0]) for ps in paramsums[1:])
                # COPOS-specific update over
        # conjugate-gradient update over

        for (lossname, lossval) in zip(loss_names, meanlosses):
            logger.record_tabular(lossname, lossval)


        # policy update over
        with timed("vf"):
            for _ in range(vf_iters):
                for (mbob, mbret) in dataset.iterbatches(
                    (seg["ob"], seg["tdlamret"]),
                        include_final_partial_batch=False,
                        batch_size=64):
                    g = allmean(compute_vflossandgrad(mbob, mbret))
                    vfadam.update(g, vf_stepsize)

        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))

        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        print("Reward max: " + str(max(rewbuffer)))
        print("Reward min: " + str(min(rewbuffer)))

        logger.record_tabular(
            "EpLenMean",
            np.mean(lenbuffer) if np.sum(lenbuffer) != 0.0 else 0.0)
        logger.record_tabular(
            "EpRewMean",
            np.mean(rewbuffer) if np.sum(rewbuffer) != 0.0 else 0.0)
        logger.record_tabular(
            "AverageReturn",
            np.mean(rewbuffer) if np.sum(rewbuffer) != 0.0 else 0.0)
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)

        if rank == 0:
            logger.dump_tabular()
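The conjugate-gradient step above only touches the (damped) Fisher matrix through the fisher_vector_product callback. The following is a minimal NumPy sketch of that pattern, solving H x = g from matrix-vector products alone; it is an illustration, not the baselines cg implementation used in the listing.

import numpy as np

def cg_sketch(matvec, b, iters=10, tol=1e-10):
    # Solve H x = b using only H @ v products (H symmetric positive definite).
    x = np.zeros_like(b)
    r = b.copy()          # residual b - H x (x starts at 0)
    p = r.copy()          # search direction
    rs = r.dot(r)
    for _ in range(iters):
        Hp = matvec(p)
        alpha = rs / p.dot(Hp)
        x += alpha * p
        r -= alpha * Hp
        rs_new = r.dot(r)
        if rs_new < tol:
            break
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x

# Toy SPD matrix standing in for the damped Fisher matrix:
H = np.array([[4.0, 1.0], [1.0, 3.0]])
g = np.array([1.0, 2.0])
stepdir = cg_sketch(lambda v: H @ v, g)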
Example #32
0
def learn(
    env,
    policy_func,
    *,
    timesteps_per_batch,  # timesteps per actor per update
    clip_param,  # clipping parameter epsilon
    entcoeff,  # entropy coefficient
    optim_epochs,
    optim_stepsize,
    optim_batchsize,  # optimization hypers
    gamma,
    lam,  # advantage estimation
    max_timesteps=0,
    max_episodes=0,
    max_iters=0,
    max_seconds=0,  # time constraint
    callback=None,  # you can do anything in the callback, since it takes locals(), globals()
    adam_epsilon=1e-5,
    schedule='constant',  # annealing for stepsize parameters (epsilon and adam)
    sym_loss_weight=0.0,
    return_threshold=None,  # terminate learning if it reaches return_threshold
    op_after_init=None,
    init_policy_params=None,
    policy_scope=None,
    max_threshold=None,
    positive_rew_enforce=False,
    reward_drop_bound=True,
    min_iters=0,
    ref_policy_params=None,
    discrete_learning=None  # [obs_disc, act_disc, state_filter_fn, state_unfilter_fn, weight]
):

    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    if policy_scope is None:
        pi = policy_func("pi", ob_space,
                         ac_space)  # Construct network for new policy
        oldpi = policy_func("oldpi", ob_space,
                            ac_space)  # Network for old policy
    else:
        pi = policy_func(policy_scope, ob_space,
                         ac_space)  # Construct network for new policy
        oldpi = policy_func("old" + policy_scope, ob_space,
                            ac_space)  # Network for old policy

    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule
    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = U.mean(kloldnew)
    meanent = U.mean(ent)
    pol_entpen = (-entcoeff) * meanent

    sym_loss = sym_loss_weight * U.mean(
        tf.square(pi.mean - pi.mirrored_mean))  # mirror symmetric loss
    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = U.clip(ratio, 1.0 - clip_param, 1.0 + clip_param) * atarg  #
    pol_surr = -U.mean(tf.minimum(
        surr1, surr2)) + sym_loss  # PPO's pessimistic surrogate (L^CLIP)

    vf_loss = U.mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent, sym_loss]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent", "sym_loss"]
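    # pol_surr is PPO's clipped surrogate: with r = pi(a|s) / pi_old(a|s),
    #   L^CLIP = E[ min(r * A, clip(r, 1 - clip_param, 1 + clip_param) * A) ],
    # maximized here by minimizing its negative plus the mirror-symmetry term,
    # the entropy penalty, and the squared value-function error.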

    var_list = pi.get_trainable_variables()
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult],
                             losses + [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)

    U.initialize()

    if init_policy_params is not None:
        cur_scope = pi.get_variables()[0].name[0:pi.get_variables()[0].name.find('/')]
        orig_scope = list(init_policy_params.keys())[0][0:list(init_policy_params.keys())[0].find('/')]
        for i in range(len(pi.get_variables())):
            assign_op = pi.get_variables()[i].assign(
                init_policy_params[pi.get_variables()[i].name.replace(
                    cur_scope, orig_scope, 1)])
            U.get_session().run(assign_op)
            assign_op = oldpi.get_variables()[i].assign(
                init_policy_params[pi.get_variables()[i].name.replace(
                    cur_scope, orig_scope, 1)])
            U.get_session().run(assign_op)

    if ref_policy_params is not None:
        ref_pi = policy_func("ref_pi", ob_space, ac_space)
        cur_scope = ref_pi.get_variables()[0].name[0:ref_pi.get_variables()[0].name.find('/')]
        orig_scope = list(ref_policy_params.keys())[0][0:list(ref_policy_params.keys())[0].find('/')]
        for i in range(len(ref_pi.get_variables())):
            assign_op = ref_pi.get_variables()[i].assign(
                ref_policy_params[ref_pi.get_variables()[i].name.replace(
                    cur_scope, orig_scope, 1)])
            U.get_session().run(assign_op)
        #env.env.env.ref_policy = ref_pi

    adam.sync()

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_batch,
                                     stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards

    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"

    max_thres_satisfied = max_threshold is None
    adjust_ratio = 0.0
    prev_avg_rew = -1000000
    revert_parameters = {}
    variables = pi.get_variables()
    for i in range(len(variables)):
        cur_val = variables[i].eval()
        revert_parameters[variables[i].name] = cur_val
    revert_data = [0, 0, 0]
    all_collected_transition_data = []
    Vfunc = {}

    # Temporary: load a precomputed discrete value function and discretizers from a hard-coded path
    import joblib
    path = 'data/value_iter_cartpole_discrete'
    [Vfunc, obs_disc, act_disc, state_filter_fn,
     state_unfilter_fn] = joblib.load(path + '/ref_policy_funcs.pkl')

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % iters_so_far)

        seg = seg_gen.__next__()

        if reward_drop_bound is not None:
            lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
            listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
            lens, rews = map(flatten_lists, zip(*listoflrpairs))
            lenbuffer.extend(lens)
            rewbuffer.extend(rews)
            revert_iteration = False
            if np.mean(
                    rewbuffer
            ) < prev_avg_rew - 50:  # detect significant drop in performance, revert to previous iteration
                print("Revert Iteration!!!!!")
                revert_iteration = True
            else:
                prev_avg_rew = np.mean(rewbuffer)
            logger.record_tabular("Revert Rew", prev_avg_rew)
            if revert_iteration:  # revert iteration
                for i in range(len(pi.get_variables())):
                    assign_op = pi.get_variables()[i].assign(
                        revert_parameters[pi.get_variables()[i].name])
                    U.get_session().run(assign_op)
                episodes_so_far = revert_data[0]
                timesteps_so_far = revert_data[1]
                iters_so_far = revert_data[2]
                continue
            else:
                variables = pi.get_variables()
                for i in range(len(variables)):
                    cur_val = variables[i].eval()
                    revert_parameters[variables[i].name] = np.copy(cur_val)
                revert_data[0] = episodes_so_far
                revert_data[1] = timesteps_so_far
                revert_data[2] = iters_so_far

        if positive_rew_enforce:
            rewlocal = (seg["pos_rews"], seg["neg_pens"], seg["rew"]
                        )  # local values
            listofrews = MPI.COMM_WORLD.allgather(rewlocal)  # list of tuples
            pos_rews, neg_pens, rews = map(flatten_lists, zip(*listofrews))
            if np.mean(rews) < 0.0:
                #min_id = np.argmin(rews)
                #adjust_ratio = pos_rews[min_id]/np.abs(neg_pens[min_id])
                adjust_ratio = np.max([
                    adjust_ratio,
                    np.mean(pos_rews) / np.abs(np.mean(neg_pens))
                ])
                for i in range(len(seg["rew"])):
                    if np.abs(seg["rew"][i] - seg["pos_rews"][i] -
                              seg["neg_pens"][i]) > 1e-5:
                        print(seg["rew"][i], seg["pos_rews"][i],
                              seg["neg_pens"][i])
                        print('Reward wrong!')
                        raise RuntimeError('Reward decomposition mismatch: rew != pos_rews + neg_pens')
                    seg["rew"][i] = seg["pos_rews"][
                        i] + seg["neg_pens"][i] * adjust_ratio
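        # When a reference value function is available, the block below adds a
        # potential-based shaping term to each reward,
        #   r'_t = r_t + 0.1 * (0.99 * V(s_{t+1}) - V(s_t)),
        # using the discretized value function loaded above; transitions whose
        # discretized states are missing from Vfunc get no bonus.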
        if ref_policy_params is not None:
            rewed = 0
            for i in range(len(seg["rew"])):
                #pred_nexvf = np.max([ref_pi.act(False, seg["collected_transitions"][i][5])[1], pi.act(False, seg["collected_transitions"][i][5])[1]])
                #pred_curvf = np.max([ref_pi.act(False, seg["collected_transitions"][i][4])[1], pi.act(False, seg["collected_transitions"][i][4])[1]])

                if obs_disc(state_filter_fn(seg["collected_transitions"][i][2])) in Vfunc and \
                        obs_disc(state_filter_fn(seg["collected_transitions"][i][0])) in Vfunc:
                    pred_nexvf = Vfunc[obs_disc(
                        state_filter_fn(seg["collected_transitions"][i][2]))]
                    pred_curvf = Vfunc[obs_disc(
                        state_filter_fn(seg["collected_transitions"][i][0]))]
                    rewed += 1
                else:
                    pred_nexvf = 0
                    pred_curvf = 0

                vf_diff = 0.99 * pred_nexvf - pred_curvf
                seg["rew"][i] += vf_diff * 0.1
            print('rewarded for : ', rewed / len(seg["rew"]))
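        # With discrete_learning enabled, the block below pools transitions across
        # MPI workers, fits a tabular dynamics model once enough data has
        # accumulated (>50k transitions, capped at 500k), runs value iteration on
        # it, and shapes rewards with the resulting discrete value function scaled
        # by discrete_learning[4].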
        if discrete_learning is not None:
            rewlocal = (seg["collected_transitions"], seg["rew"]
                        )  # local values
            listofrews = MPI.COMM_WORLD.allgather(rewlocal)  # list of tuples
            collected_transitions, rews = map(flatten_lists, zip(*listofrews))
            processed_transitions = []
            for trans in collected_transitions:
                processed_transitions.append([
                    discrete_learning[2](trans[0]), trans[1],
                    discrete_learning[2](trans[2]), trans[3]
                ])
            all_collected_transition_data += processed_transitions
            if len(all_collected_transition_data) > 500000:
                all_collected_transition_data = all_collected_transition_data[
                    len(all_collected_transition_data) - 500000:]
            if len(all_collected_transition_data) > 50000:
                logger.log("Fitting discrete dynamic model...")
                dyn_model, obs_disc = fit_dyn_model(
                    discrete_learning[0], discrete_learning[1],
                    all_collected_transition_data)
                logger.log(
                    "Perform value iteration on the discrete dynamic model...")
                Vfunc, policy = optimize_policy(dyn_model, 0.99)
                discrete_learning[0] = obs_disc
                rewarded = 0
                for i in range(len(seg["rew"])):
                    vf_diff = 0.99*Vfunc[discrete_learning[0](discrete_learning[2](seg["collected_transitions"][i][2]))] - \
                        Vfunc[discrete_learning[0](discrete_learning[2](seg["collected_transitions"][i][0]))]
                    seg["rew"][i] += vf_diff * discrete_learning[4]
                    #if policy[discrete_learning[0](discrete_learning[2](seg["collected_transitions"][i][0]))] == discrete_learning[1](seg["collected_transitions"][i][1]):
                    #    seg["rew"][i] += 2.0
                    #    rewarded += 1
                #logger.log(str(rewarded*1.0/len(seg["rew"])) + ' rewarded')

        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update
        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate
        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                    shuffle=not pi.recurrent)
        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        assign_old_eq_new()  # set old parameter values to new parameter values
        logger.log("Optimizing...")
        logger.log(fmt_row(13, loss_names))
        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"],
                                            batch["atarg"], batch["vtarg"],
                                            cur_lrmult)
                adam.update(g, optim_stepsize * cur_lrmult)
                losses.append(newlosses)
            logger.log(fmt_row(13, np.mean(losses, axis=0)))
        logger.log("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob"], batch["ac"],
                                       batch["atarg"], batch["vtarg"],
                                       cur_lrmult)
            losses.append(newlosses)
        meanlosses, _, _ = mpi_moments(losses, axis=0)
        logger.log(fmt_row(13, meanlosses))
        for (lossval, name) in zipsame(meanlosses, loss_names):
            logger.record_tabular("loss_" + name, lossval)
        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))
        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        if reward_drop_bound is None:
            lenbuffer.extend(lens)
            rewbuffer.extend(rews)
        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        logger.record_tabular("Iter", iters_so_far)
        if positive_rew_enforce:
            if adjust_ratio is not None:
                logger.record_tabular("RewardAdjustRatio", adjust_ratio)
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.dump_tabular()
        if return_threshold is not None and max_thres_satisfied:
            if np.mean(
                    rewbuffer) > return_threshold and iters_so_far > min_iters:
                break
        if max_threshold is not None:
            print('Current max return: ', np.max(rewbuffer))
            if np.max(rewbuffer) > max_threshold:
                max_thres_satisfied = True
            else:
                max_thres_satisfied = False
    return pi
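As a quick numerical sanity check of the clipped surrogate defined above, here is a small self-contained NumPy sketch on toy numbers (an illustration only; the names and values are made up, not taken from the listing).

import numpy as np

def ppo_clip_surrogate(ratio, adv, clip_param=0.2):
    # Negative of PPO's pessimistic surrogate L^CLIP, averaged over samples.
    surr1 = ratio * adv
    surr2 = np.clip(ratio, 1.0 - clip_param, 1.0 + clip_param) * adv
    return -np.mean(np.minimum(surr1, surr2))

ratio = np.array([0.7, 1.0, 1.4])   # pi_new / pi_old per sample
adv = np.array([1.0, -0.5, 2.0])    # standardized advantages
print(ppo_clip_surrogate(ratio, adv))  # = -(0.7 - 0.5 + 1.2*2.0) / 3 ~= -0.867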
Example #33
0
def learn(*,
        network,
        env,
        total_timesteps,
        timesteps_per_batch=1024, # what to train on
        max_kl=0.001,
        cg_iters=10,
        gamma=0.99,
        lam=1.0, # advantage estimation
        seed=None,
        ent_coef=0.0,
        cg_damping=1e-2,
        vf_stepsize=3e-4,
        vf_iters =3,
        max_episodes=0, max_iters=0,  # time constraint
        callback=None,
        load_path=None,
        **network_kwargs
        ):
    '''
    learn a policy function with TRPO algorithm

    Parameters:
    ----------

    network                 neural network to learn. Can be either string ('mlp', 'cnn', 'lstm', 'lnlstm' for basic types)
                            or function that takes input placeholder and returns tuple (output, None) for feedforward nets
                            or (output, (state_placeholder, state_output, mask_placeholder)) for recurrent nets

    env                     environment (one of the gym environments or wrapped via baselines.common.vec_env.VecEnv-type class)

    timesteps_per_batch     timesteps per gradient estimation batch

    max_kl                  max KL divergence between old policy and new policy ( KL(pi_old || pi) )

    ent_coef                coefficient of policy entropy term in the optimization objective

    cg_iters                number of iterations of conjugate gradient algorithm

    cg_damping              conjugate gradient damping

    vf_stepsize             learning rate for the Adam optimizer used to optimize the value function loss

    vf_iters                number of value function optimization iterations per policy optimization step

    total_timesteps           max number of timesteps

    max_episodes            max number of episodes

    max_iters               maximum number of policy optimization iterations

    callback                function to be called with (locals(), globals()) each policy optimization step

    load_path               str, path to load the model from (default: None, i.e. no model is loaded)

    **network_kwargs        keyword arguments to the policy / network builder. See baselines.common/policies.py/build_policy and arguments to a particular type of network

    Returns:
    -------

    learnt model

    '''

    if MPI is not None:
        nworkers = MPI.COMM_WORLD.Get_size()
        rank = MPI.COMM_WORLD.Get_rank()
    else:
        nworkers = 1
        rank = 0

    cpus_per_worker = 1
    U.get_session(config=tf.ConfigProto(
            allow_soft_placement=True,
            inter_op_parallelism_threads=cpus_per_worker,
            intra_op_parallelism_threads=cpus_per_worker
    ))


    policy = build_policy(env, network, value_network='copy', **network_kwargs)
    set_global_seeds(seed)

    np.set_printoptions(precision=3)
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space

    ob = observation_placeholder(ob_space)
    with tf.variable_scope("pi"):
        pi = policy(observ_placeholder=ob)
    with tf.variable_scope("oldpi"):
        oldpi = policy(observ_placeholder=ob)

    atarg = tf.placeholder(dtype=tf.float32, shape=[None]) # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None]) # Empirical return

    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    entbonus = ent_coef * meanent

    vferr = tf.reduce_mean(tf.square(pi.vf - ret))

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac)) # advantage * pnew / pold
    surrgain = tf.reduce_mean(ratio * atarg)

    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

    dist = meankl

    all_var_list = get_trainable_variables("pi")
    # var_list = [v for v in all_var_list if v.name.split("/")[1].startswith("pol")]
    # vf_var_list = [v for v in all_var_list if v.name.split("/")[1].startswith("vf")]
    var_list = get_pi_trainable_variables("pi")
    vf_var_list = get_vf_trainable_variables("pi")

    vfadam = MpiAdam(vf_var_list)

    get_flat = U.GetFlat(var_list)
    set_from_flat = U.SetFromFlat(var_list)
    klgrads = tf.gradients(dist, var_list)
    flat_tangent = tf.placeholder(dtype=tf.float32, shape=[None], name="flat_tan")
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start+sz], shape))
        start += sz
    gvp = tf.add_n([tf.reduce_sum(g*tangent) for (g, tangent) in zipsame(klgrads, tangents)]) #pylint: disable=E1111
    fvp = U.flatgrad(gvp, var_list)

    assign_old_eq_new = U.function([],[], updates=[tf.assign(oldv, newv)
        for (oldv, newv) in zipsame(get_variables("oldpi"), get_variables("pi"))])

    compute_losses = U.function([ob, ac, atarg], losses)
    compute_lossandgrad = U.function([ob, ac, atarg], losses + [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    compute_vflossandgrad = U.function([ob, ret], U.flatgrad(vferr, vf_var_list))

    @contextmanager
    def timed(msg):
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(colorize("done in %.3f seconds"%(time.time() - tstart), color='magenta'))
        else:
            yield

    def allmean(x):
        assert isinstance(x, np.ndarray)
        if MPI is not None:
            out = np.empty_like(x)
            MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
            out /= nworkers
        else:
            out = np.copy(x)

        return out

    U.initialize()
    if load_path is not None:
        pi.load(load_path)

    th_init = get_flat()
    if MPI is not None:
        MPI.COMM_WORLD.Bcast(th_init, root=0)

    set_from_flat(th_init)
    vfadam.sync()
    print("Init param sum", th_init.sum(), flush=True)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi, env, timesteps_per_batch, stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40) # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40) # rolling buffer for episode rewards

    if sum([max_iters>0, total_timesteps>0, max_episodes>0])==0:
        # nothing to be done
        return pi

    assert sum([max_iters>0, total_timesteps>0, max_episodes>0]) < 2, \
        'out of max_iters, total_timesteps, and max_episodes only one should be specified'

    while True:
        if callback: callback(locals(), globals())
        if total_timesteps and timesteps_so_far >= total_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        logger.log("********** Iteration %i ************"%iters_so_far)

        with timed("sampling"):
            seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]
        vpredbefore = seg["vpred"] # predicted value function before update
        atarg = (atarg - atarg.mean()) / atarg.std() # standardized advantage function estimate

        if hasattr(pi, "ret_rms"): pi.ret_rms.update(tdlamret)
        if hasattr(pi, "ob_rms"): pi.ob_rms.update(ob) # update running mean/std for policy

        args = seg["ob"], seg["ac"], atarg
        fvpargs = [arr[::5] for arr in args]
        def fisher_vector_product(p):
            return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p

        assign_old_eq_new() # set old parameter values to new parameter values
        with timed("computegrad"):
            *lossbefore, g = compute_lossandgrad(*args)
        lossbefore = allmean(np.array(lossbefore))
        g = allmean(g)
        if np.allclose(g, 0):
            logger.log("Got zero gradient. not updating")
        else:
            with timed("cg"):
                stepdir = cg(fisher_vector_product, g, cg_iters=cg_iters, verbose=rank==0)
            assert np.isfinite(stepdir).all()
            shs = .5*stepdir.dot(fisher_vector_product(stepdir))
            lm = np.sqrt(shs / max_kl)
            # logger.log("lagrange multiplier:", lm, "gnorm:", np.linalg.norm(g))
            fullstep = stepdir / lm
            expectedimprove = g.dot(fullstep)
            surrbefore = lossbefore[0]
            stepsize = 1.0
            thbefore = get_flat()
            for _ in range(10):
                thnew = thbefore + fullstep * stepsize
                set_from_flat(thnew)
                meanlosses = surr, kl, *_ = allmean(np.array(compute_losses(*args)))
                improve = surr - surrbefore
                logger.log("Expected: %.3f Actual: %.3f"%(expectedimprove, improve))
                if not np.isfinite(meanlosses).all():
                    logger.log("Got non-finite value of losses -- bad!")
                elif kl > max_kl * 1.5:
                    logger.log("violated KL constraint. shrinking step.")
                elif improve < 0:
                    logger.log("surrogate didn't improve. shrinking step.")
                else:
                    logger.log("Stepsize OK!")
                    break
                stepsize *= .5
            else:
                logger.log("couldn't compute a good step")
                set_from_flat(thbefore)
            if nworkers > 1 and iters_so_far % 20 == 0:
                paramsums = MPI.COMM_WORLD.allgather((thnew.sum(), vfadam.getflat().sum())) # list of tuples
                assert all(np.allclose(ps, paramsums[0]) for ps in paramsums[1:])

        for (lossname, lossval) in zip(loss_names, meanlosses):
            logger.record_tabular(lossname, lossval)

        with timed("vf"):

            for _ in range(vf_iters):
                for (mbob, mbret) in dataset.iterbatches((seg["ob"], seg["tdlamret"]),
                include_final_partial_batch=False, batch_size=64):
                    g = allmean(compute_vflossandgrad(mbob, mbret))
                    vfadam.update(g, vf_stepsize)

        logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))

        lrlocal = (seg["ep_lens"], seg["ep_rets"]) # local values
        if MPI is not None:
            listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal) # list of tuples
        else:
            listoflrpairs = [lrlocal]

        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)

        if rank==0:
            logger.dump_tabular()

    return pi
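For orientation, a call to this learn() might look like the sketch below. It assumes gym is installed and that learn is importable from wherever this module lives (the listing does not show its package path), so treat it as an illustrative invocation rather than a verified one.

import gym

def train_sketch():
    # Hypothetical driver: environment name and hyperparameters chosen for illustration only.
    env = gym.make("CartPole-v1")
    pi = learn(network="mlp", env=env,
               total_timesteps=100000, timesteps_per_batch=1024,
               max_kl=0.001, cg_iters=10, gamma=0.99, lam=1.0,
               vf_iters=3, vf_stepsize=3e-4)
    return pi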
Example #34
0
def learn(
        *,
        network,
        env,
        total_timesteps,
        timesteps_per_batch=1024,  # what to train on
        max_kl=0.002,
        cg_iters=10,
        gamma=0.99,
        lam=1.0,  # advantage estimation
        seed=None,
        ent_coef=0.00,
        cg_damping=1e-2,
        vf_stepsize=3e-4,
        vf_iters=3,
        max_episodes=0,
        max_iters=0,  # time constraint
        callback=None,
        load_path=None,
        num_reward=1,
        **network_kwargs):
    '''
    learn a policy function with TRPO algorithm

    Parameters:
    ----------

    network                 neural network to learn. Can be either string ('mlp', 'cnn', 'lstm', 'lnlstm' for basic types)
                            or function that takes input placeholder and returns tuple (output, None) for feedforward nets
                            or (output, (state_placeholder, state_output, mask_placeholder)) for recurrent nets

    env                     environment (one of the gym environments or wrapped via baselines.common.vec_env.VecEnv-type class)

    timesteps_per_batch     timesteps per gradient estimation batch

    max_kl                  max KL divergence between old policy and new policy ( KL(pi_old || pi) )

    ent_coef                coefficient of policy entropy term in the optimization objective

    cg_iters                number of iterations of conjugate gradient algorithm

    cg_damping              conjugate gradient damping

    vf_stepsize             learning rate for the Adam optimizer used to optimize the value function loss

    vf_iters                number of value function optimization iterations per policy optimization step

    total_timesteps           max number of timesteps

    max_episodes            max number of episodes

    max_iters               maximum number of policy optimization iterations

    callback                function to be called with (locals(), globals()) each policy optimization step

    load_path               str, path to load the model from (default: None, i.e. no model is loaded)

    **network_kwargs        keyword arguments to the policy / network builder. See baselines.common/policies.py/build_policy and arguments to a particular type of network

    Returns:
    -------

    learnt model

    '''

    if MPI is not None:
        nworkers = MPI.COMM_WORLD.Get_size()
        rank = MPI.COMM_WORLD.Get_rank()
    else:
        nworkers = 1
        rank = 0

    cpus_per_worker = 1
    U.get_session(
        config=tf.ConfigProto(allow_soft_placement=True,
                              inter_op_parallelism_threads=cpus_per_worker,
                              intra_op_parallelism_threads=cpus_per_worker))

    set_global_seeds(seed)
    # Build the policy
    policy = build_policy(env,
                          network,
                          value_network='copy',
                          num_reward=num_reward,
                          **network_kwargs)

    process_dir = logger.get_dir()
    save_dir = process_dir.split(
        'Data')[-2] + 'log/mu/seed' + process_dir[-1] + '/'
    os.makedirs(save_dir, exist_ok=True)
    coe_save = []
    impro_save = []
    grad_save = []
    adj_save = []
    coe = np.ones((num_reward)) / num_reward

    np.set_printoptions(precision=3)
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space

    #################################################################
    # ob, ac, ret, atarg are all placeholders
    # ret and atarg should be in vector form here
    ob = observation_placeholder(ob_space)

    # Build pi and oldpi
    with tf.variable_scope("pi"):
        pi = policy(observ_placeholder=ob)
    with tf.variable_scope("oldpi"):
        oldpi = policy(observ_placeholder=ob)

    # Each reward can have its own atarg
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32,
                         shape=[None, num_reward])  # Empirical return

    ac = pi.pdtype.sample_placeholder([None])

    # The KL divergence and entropy here do not depend on the reward
    ##################################
    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    # entbonus is the entropy bonus term
    entbonus = ent_coef * meanent
    #################################

    ###########################################################
    # vferr is used to update the value network
    vferr = tf.reduce_mean(tf.square(pi.vf - ret))
    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))
    # advantage * pnew / pold
    surrgain = tf.reduce_mean(ratio * atarg)

    # optimgain is used to update the policy network; there should be one per reward
    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

    ###########################################################
    dist = meankl

    # Define the variables to optimize and the Adam optimizer for the value network
    all_var_list = get_trainable_variables("pi")
    # var_list = [v for v in all_var_list if v.name.split("/")[1].startswith("pol")]
    # vf_var_list = [v for v in all_var_list if v.name.split("/")[1].startswith("vf")]
    var_list = get_pi_trainable_variables("pi")
    vf_var_list = get_vf_trainable_variables("pi")

    vfadam = MpiAdam(vf_var_list)

    # Helper that flattens the variables into a single vector
    get_flat = U.GetFlat(var_list)

    # Helper that assigns slices of a flat vector back to the variables in var_list
    set_from_flat = U.SetFromFlat(var_list)
    # Gradient of the KL divergence
    klgrads = tf.gradients(dist, var_list)

    ####################################################################
    # Flattened tangent vector
    flat_tangent = tf.placeholder(dtype=tf.float32,
                                  shape=[None],
                                  name="flat_tan")

    # Split the flat vector back into per-variable tensors
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start + sz], shape))
        start += sz
    ####################################################################

    ####################################################################
    # Sum the inner products of the KL gradients with the tangents
    gvp = tf.add_n([
        tf.reduce_sum(g * tangent)
        for (g, tangent) in zipsame(klgrads, tangents)
    ])  #pylint: disable=E1111
    # Flatten the gradient of gvp into a vector (the Fisher-vector product)
    fvp = U.flatgrad(gvp, var_list)
    ####################################################################

    # Update the old policy with the current policy parameters
    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(get_variables("oldpi"), get_variables("pi"))
        ])

    # Compute the losses
    compute_losses = U.function([ob, ac, atarg], losses)
    # Compute the losses and the gradient
    compute_lossandgrad = U.function([ob, ac, atarg], losses +
                                     [U.flatgrad(optimgain, var_list)])
    # Compute the Fisher-vector product
    compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    # Compute the value-network loss gradient
    compute_vflossandgrad = U.function([ob, ret],
                                       U.flatgrad(vferr, vf_var_list))

    @contextmanager
    def timed(msg):
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(
                colorize("done in %.3f seconds" % (time.time() - tstart),
                         color='magenta'))
        else:
            yield

    def allmean(x):
        assert isinstance(x, np.ndarray)
        if MPI is not None:
            out = np.empty_like(x)
            MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
            out /= nworkers
        else:
            out = np.copy(x)

        return out

    # Initialize variables
    U.initialize()
    if load_path is not None:
        pi.load(load_path)

    # Get the initial flat parameter vector
    th_init = get_flat()
    if MPI is not None:
        MPI.COMM_WORLD.Bcast(th_init, root=0)

    # Assign slices of th_init back to the variables in var_list
    set_from_flat(th_init)

    # Synchronize the value-function optimizer across workers
    vfadam.sync()
    print("Init param sum", th_init.sum(), flush=True)

    # Prepare for rollouts
    # ----------------------------------------

    # Iterator that generates rollout data
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_batch,
                                     stochastic=True,
                                     num_reward=num_reward)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()

    # Double-ended queues
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards

    if sum([max_iters > 0, total_timesteps > 0, max_episodes > 0]) == 0:
        # nothing to be done
        return pi

    assert sum([max_iters>0, total_timesteps>0, max_episodes>0]) < 2, \
        'out of max_iters, total_timesteps, and max_episodes only one should be specified'

    while True:
        if callback: callback(locals(), globals())
        if total_timesteps and timesteps_so_far >= total_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        logger.log("********** Iteration %i ************" % iters_so_far)

        with timed("sampling"):
            seg = seg_gen.__next__()

        # Compute cumulative returns and advantages
        add_vtarg_and_adv(seg, gamma, lam, num_reward=num_reward)
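        # Multi-reward variant: advantages and lambda-returns are computed per
        # reward channel, so atarg below has shape [batch, num_reward] and each
        # column is standardized and given its own gradient/CG direction before
        # the directions are combined.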
        # TODO
        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))

        # ob, ac, atarg, tdlamret are all ndarrays
        #ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]
        _, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"]
        #print(seg['ob'].shape,type(seg['ob']))
        #print(seg['ac'],type(seg['ac']))
        #print(seg['adv'],type(seg['adv']))
        #print(seg["tdlamret"].shape,type(seg['tdlamret']))
        vpredbefore = seg["vpred"]  # predicted value function before update

        # Standardize (per reward column)
        #print("============================== atarg =========================================================")
        #print(atarg)
        atarg = (atarg - np.mean(atarg, axis=0)) / np.std(
            atarg, axis=0)  # standardized advantage function estimate
        #atarg = (atarg) / np.max(np.abs(atarg),axis=0)
        #print('======================================= standardized atarg ====================================')
        #print(atarg)
        if hasattr(pi, "ret_rms"): pi.ret_rms.update(tdlamret)
        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(seg["ob"])  # update running mean/std for policy

        ## set old parameter values to new parameter values
        assign_old_eq_new()

        G = None
        S = None
        mr_lossbefore = np.zeros((num_reward, len(loss_names)))
        grad_norm = np.zeros((num_reward + 1))
        for i in range(num_reward):
            args = seg["ob"], seg["ac"], atarg[:, i]
            #print(atarg[:,i])
            # Subsample args, taking every 5th element
            fvpargs = [arr[::5] for arr in args]

            # This function computes the product of the Fisher matrix with a vector p
            def fisher_vector_product(p):
                return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p

            with timed("computegrad of " + str(i + 1) + ".th reward"):
                *lossbefore, g = compute_lossandgrad(*args)
            lossbefore = allmean(np.array(lossbefore))
            mr_lossbefore[i] = lossbefore
            g = allmean(g)
            #print("***************************************************************")
            #print(g)
            if isinstance(G, np.ndarray):
                G = np.vstack((G, g))
            else:
                G = g

            # g is the gradient of the objective
            # Use conjugate gradient to obtain the update direction
            if np.allclose(g, 0):
                logger.log("Got zero gradient. not updating")
            else:
                with timed("cg of " + str(i + 1) + ".th reward"):
                    # stepdir is the update direction
                    stepdir = cg(fisher_vector_product,
                                 g,
                                 cg_iters=cg_iters,
                                 verbose=rank == 0)
                    shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
                    lm = np.sqrt(shs / max_kl)
                    # logger.log("lagrange multiplier:", lm, "gnorm:", np.linalg.norm(g))
                    fullstep = stepdir / lm
                    grad_norm[i] = np.linalg.norm(fullstep)
                assert np.isfinite(stepdir).all()
                if isinstance(S, np.ndarray):
                    S = np.vstack((S, stepdir))
                else:
                    S = stepdir
        #print('======================================= G ====================================')
        #print(G)
        #print('======================================= S ====================================')
        #print(S)
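        # G stacks the per-reward policy gradients and S the corresponding
        # natural-gradient directions. get_coefficient (defined elsewhere, not
        # shown in this listing) returns mixing weights coe; the combined update
        # below uses stepdir = coe . S and g = coe . G, followed by a single
        # TRPO-style line search on the coefficient-weighted losses.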
        try:
            new_coe = get_coefficient(G, S)
            #coe = 0.99 * coe + 0.01 * new_coe
            coe = new_coe
            coe_save.append(coe)
            # Adjust the parameters based on the angle between the gradients
            # GG = np.dot(S, S.T)
            # D = np.sqrt(np.diag(1/np.diag(GG)))
            # GG = np.dot(np.dot(D,GG),D)
            # #print('======================================= inner product ====================================')
            # #print(GG)
            # adj = np.sum(GG) / (num_reward ** 2)
            adj = 1
            #print('======================================= adj ====================================')
            #print(adj)
            adj_save.append(adj)
            adj_max_kl = adj * max_kl
            #################################################################
            grad_norm = grad_norm * np.sqrt(adj)
            stepdir = np.dot(coe, S)
            g = np.dot(coe, G)
            lossbefore = np.dot(coe, mr_lossbefore)
            #################################################################

            shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
            lm = np.sqrt(shs / adj_max_kl)
            # logger.log("lagrange multiplier:", lm, "gnorm:", np.linalg.norm(g))
            fullstep = stepdir / lm
            grad_norm[num_reward] = np.linalg.norm(fullstep)
            grad_save.append(grad_norm)
            expectedimprove = g.dot(fullstep)
            surrbefore = lossbefore[0]
            stepsize = 1.0
            thbefore = get_flat()

            def compute_mr_losses():
                mr_losses = np.zeros((num_reward, len(loss_names)))
                for i in range(num_reward):
                    args = seg["ob"], seg["ac"], atarg[:, i]
                    one_reward_loss = allmean(np.array(compute_losses(*args)))
                    mr_losses[i] = one_reward_loss
                mr_loss = np.dot(coe, mr_losses)
                return mr_loss, mr_losses

            # Do up to 10 line-search steps
            for _ in range(10):
                thnew = thbefore + fullstep * stepsize
                set_from_flat(thnew)
                mr_loss_new, mr_losses_new = compute_mr_losses()
                mr_impro = mr_losses_new - mr_lossbefore
                meanlosses = surr, kl, *_ = allmean(np.array(mr_loss_new))
                improve = surr - surrbefore
                logger.log("Expected: %.3f Actual: %.3f" %
                           (expectedimprove, improve))
                if not np.isfinite(meanlosses).all():
                    logger.log("Got non-finite value of losses -- bad!")
                elif kl > adj_max_kl * 1.5:
                    logger.log("violated KL constraint. shrinking step.")
                elif improve < 0:
                    logger.log("surrogate didn't improve. shrinking step.")
                else:
                    logger.log("Stepsize OK!")
                    impro_save.append(np.hstack((mr_impro[:, 0], improve)))
                    break
                stepsize *= .5
            else:
                logger.log("couldn't compute a good step")
                set_from_flat(thbefore)
            if nworkers > 1 and iters_so_far % 20 == 0:
                paramsums = MPI.COMM_WORLD.allgather(
                    (thnew.sum(), vfadam.getflat().sum()))  # list of tuples
                assert all(
                    np.allclose(ps, paramsums[0]) for ps in paramsums[1:])

            for (lossname, lossval) in zip(loss_names, meanlosses):
                logger.record_tabular(lossname, lossval)

            with timed("vf"):
                #print('======================================= tdlamret ====================================')
                #print(seg["tdlamret"])
                for _ in range(vf_iters):
                    for (mbob, mbret) in dataset.iterbatches(
                        (seg["ob"], seg["tdlamret"]),
                            include_final_partial_batch=False,
                            batch_size=64):
                        #with tf.Session() as sess:
                        #    sess.run(tf.global_variables_initializer())
                        #    aaa = sess.run(pi.vf,feed_dict={ob:mbob,ret:mbret})
                        #    print("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")
                        #    print(aaa.shape)
                        #    print(mbret.shape)
                        g = allmean(compute_vflossandgrad(mbob, mbret))
                        vfadam.update(g, vf_stepsize)
        except Exception as e:
            print('error during policy/value update:', e)
            #print(mbob,mbret)
        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))

        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        if MPI is not None:
            listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        else:
            listoflrpairs = [lrlocal]

        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        if rank == 0:
            logger.dump_tabular()
        #pdb.set_trace()
    np.save(save_dir + 'coe.npy', coe_save)
    np.save(save_dir + 'grad.npy', grad_save)
    np.save(save_dir + 'improve.npy', impro_save)
    np.save(save_dir + 'adj.npy', adj_save)
    return pi
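The update loop above follows the usual TRPO recipe: rescale the conjugate-gradient direction so the quadratic KL estimate matches the trust-region budget, then backtrack by halving the step until the surrogate improves and the KL constraint holds. Below is a minimal NumPy sketch of that backtracking rule; `set_params` and `eval_surr_kl` are hypothetical stand-ins for the `set_from_flat` and `compute_losses` helpers used above.

import numpy as np

def backtracking_line_search(theta_before, fullstep, surr_before, max_kl,
                             set_params, eval_surr_kl, max_backtracks=10):
    """Sketch of TRPO's backtracking rule: halve the step until the surrogate
    improves and the mean KL stays below 1.5 * max_kl, else restore theta."""
    stepsize = 1.0
    for _ in range(max_backtracks):
        theta_new = theta_before + stepsize * fullstep
        set_params(theta_new)
        surr, kl = eval_surr_kl()            # surrogate gain and mean KL at theta_new
        improve = surr - surr_before
        if np.isfinite([surr, kl]).all() and kl <= 1.5 * max_kl and improve > 0:
            return theta_new, improve        # accept this step size
        stepsize *= 0.5                      # shrink and retry
    set_params(theta_before)                 # no acceptable step found
    return theta_before, 0.0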
Example #35
0
def learn(
        make_env,
        make_policy,
        *,
        n_episodes,
        horizon,
        delta,
        gamma,
        max_iters,
        sampler=None,
        use_natural_gradient=False,  #can be 'exact', 'approximate'
        fisher_reg=1e-2,
        iw_method='is',
        iw_norm='none',
        bound='J',
        line_search_type='parabola',
        save_weights=0,
        improvement_tol=0.,
        center_return=False,
        render_after=None,
        max_offline_iters=100,
        callback=None,
        clipping=False,
        entropy='none',
        positive_return=False,
        reward_clustering='none',
        capacity=10,
        inner=10,
        penalization=True,
        learnable_variance=True,
        variance_initializer=-1,
        constant_step_size=0,
        shift_return=False,
        power=1,
        warm_start=True):

    np.set_printoptions(precision=3)
    max_samples = horizon * n_episodes

    if line_search_type == 'binary':
        line_search = line_search_binary
    elif line_search_type == 'parabola':
        line_search = line_search_parabola
    else:
        raise ValueError()

    if constant_step_size != 0:
        line_search = line_search_constant

    # Building the environment
    env = make_env()
    ob_space = env.observation_space
    ac_space = env.action_space

    # Creating the memory buffer
    memory = Memory(capacity=capacity,
                    batch_size=n_episodes,
                    horizon=horizon,
                    ob_space=ob_space,
                    ac_space=ac_space)

    # Building the target policy and saving its parameters
    pi = make_policy('pi', ob_space, ac_space)

    nu = make_policy('nu', ob_space, ac_space)

    all_var_list = nu.get_trainable_variables()
    var_list = [
        v for v in all_var_list if v.name.split('/')[1].startswith('pol')
    ]
    shapes = [U.intprod(var.get_shape().as_list()) for var in var_list]
    n_parameters = sum(shapes)

    all_var_list_pi = pi.get_trainable_variables()
    var_list_pi = [
        v for v in all_var_list_pi if v.name.split('/')[1].startswith('pol')
    ]

    # Building a set of behavioral policies
    memory.build_policies(make_policy, nu)

    # Placeholders
    ob_ = ob = U.get_placeholder_cached(name='ob')
    ac_ = pi.pdtype.sample_placeholder([None], name='ac')
    mask_ = tf.placeholder(dtype=tf.float32, shape=(None), name='mask')
    rew_ = tf.placeholder(dtype=tf.float32, shape=(None), name='rew')
    disc_rew_ = tf.placeholder(dtype=tf.float32, shape=(None), name='disc_rew')
    clustered_rew_ = tf.placeholder(dtype=tf.float32, shape=(None))
    gradient_ = tf.placeholder(dtype=tf.float32,
                               shape=(n_parameters, 1),
                               name='gradient')
    iter_number_ = tf.placeholder(dtype=tf.int32, name='iter_number')
    active_policies = tf.placeholder(dtype=tf.float32,
                                     shape=(capacity),
                                     name='active_policies')
    losses_with_name = []

    # Total number of trajectories
    N_total = tf.reduce_sum(active_policies) * n_episodes

    # Split operations
    disc_rew_split = tf.reshape(disc_rew_ * mask_, [-1, horizon])
    rew_split = tf.reshape(rew_ * mask_, [-1, horizon])
    mask_split = tf.reshape(mask_, [-1, horizon])

    # Policy densities
    target_log_pdf = pi.pd.logp(ac_) * mask_
    target_log_pdf_split = tf.reshape(target_log_pdf, [-1, horizon])
    behavioral_log_pdfs = tf.stack([
        bpi.pd.logp(ac_) * mask_ for bpi in memory.policies
    ])  # Shape is (capacity, ntraj*horizon)
    behavioral_log_pdfs_split = tf.reshape(behavioral_log_pdfs,
                                           [memory.capacity, -1, horizon])
    new_behavioural_log_pdf = nu.pd.logp(ac_) * mask_
    new_behavioural_log_pdf_split = tf.reshape(new_behavioural_log_pdf,
                                               [-1, horizon])

    divergence_split = tf.reshape(
        tf.stack([
            tf.log(pi.pd.compute_divergence(bpi.pd, nu.pd)) * mask_
            for bpi in memory.policies
        ]), [memory.capacity, -1, horizon])
    divergence_split_cum = tf.exp(tf.reduce_sum(divergence_split, axis=2))
    divergence_mean = tf.reduce_mean(divergence_split_cum, axis=1)
    divergence_harmonic = tf.reduce_sum(active_policies) / tf.reduce_sum(
        1 / divergence_mean)

    # Compute Renyi divergences and sum over time, then exponentiate
    emp_d2_split = tf.reshape(
        tf.stack([pi.pd.renyi(bpi.pd, 2) * mask_ for bpi in memory.policies]),
        [memory.capacity, -1, horizon])
    emp_d2_split_cum = tf.exp(tf.reduce_sum(emp_d2_split, axis=2))
    # Compute arithmetic and harmonic mean of emp_d2
    emp_d2_mean = tf.reduce_mean(emp_d2_split_cum, axis=1)
    emp_d2_arithmetic = tf.reduce_sum(
        emp_d2_mean * active_policies) / tf.reduce_sum(active_policies)
    emp_d2_harmonic = tf.reduce_sum(active_policies) / tf.reduce_sum(
        1 / emp_d2_mean)

    # Return processing: clipping, centering, discounting
    ep_return = clustered_rew_  #tf.reduce_sum(mask_split * disc_rew_split, axis=1)
    ep_return_optimization = (ep_return - tf.reduce_min(ep_return))**power
    if clipping:
        rew_split = tf.clip_by_value(rew_split, -1, 1)
    if center_return:
        ep_return = ep_return - tf.reduce_mean(ep_return)
        rew_split = rew_split - (tf.reduce_sum(rew_split) /
                                 (tf.reduce_sum(mask_split) + 1e-24))
    discounter = [pow(gamma, i) for i in range(0, horizon)]  # Discount factors gamma^i
    discounter_tf = tf.constant(discounter)
    disc_rew_split = rew_split * discounter_tf

    # Reward statistics
    return_mean = tf.reduce_mean(ep_return)
    optimization_return_mean = tf.reduce_mean(ep_return_optimization)
    return_std = U.reduce_std(ep_return)
    return_max = tf.reduce_max(ep_return)
    optimization_return_max = tf.reduce_max(ep_return_optimization)
    return_min = tf.reduce_min(ep_return)
    optimization_return_min = tf.reduce_min(ep_return_optimization)
    return_abs_max = tf.reduce_max(tf.abs(ep_return))
    optimization_return_abs_max = tf.reduce_max(tf.abs(ep_return_optimization))
    return_step_max = tf.reduce_max(tf.abs(rew_split))  # Max step reward
    return_step_mean = tf.abs(tf.reduce_mean(rew_split))
    positive_step_return_max = tf.maximum(0.0, tf.reduce_max(rew_split))
    negative_step_return_max = tf.maximum(0.0, tf.reduce_max(-rew_split))
    return_step_maxmin = tf.abs(positive_step_return_max -
                                negative_step_return_max)
    losses_with_name.extend([
        (return_mean, 'InitialReturnMean'), (return_max, 'InitialReturnMax'),
        (return_min, 'InitialReturnMin'),
        (optimization_return_mean, 'OptimizationReturnMean'),
        (optimization_return_max, 'OptimizationReturnMax'),
        (optimization_return_min, 'OptimizationReturnMin'),
        (return_std, 'InitialReturnStd'),
        (divergence_harmonic, 'DivergenceHarmonic'),
        (emp_d2_arithmetic, 'EmpiricalD2Arithmetic'),
        (emp_d2_harmonic, 'EmpiricalD2Harmonic'),
        (return_step_max, 'ReturnStepMax'),
        (return_step_maxmin, 'ReturnStepMaxmin')
    ])

    # Add D2 statistics for each memory cell
    for i in range(capacity):
        losses_with_name.extend([(tf.reduce_mean(emp_d2_split_cum, axis=1)[i],
                                  'MeanD2-' + str(i))])

    if iw_method == 'is':
        # Sum the log prob over time. Shapes: target(Nep, H), behav (Cap, Nep, H)
        target_log_pdf_episode = tf.reduce_sum(target_log_pdf_split, axis=1)
        behavioral_log_pdf_episode = tf.reduce_sum(behavioral_log_pdfs_split,
                                                   axis=2)
        new_behavioural_log_pdf_episode = tf.reduce_sum(
            new_behavioural_log_pdf_split, axis=1)
        # To avoid numerical instability, compute the inverse ratio
        log_inverse_ratio = behavioral_log_pdf_episode + new_behavioural_log_pdf_episode - 2 * target_log_pdf_episode
        abc = tf.exp(log_inverse_ratio) * tf.expand_dims(active_policies, -1)
        iw = 1 / tf.reduce_sum(
            tf.exp(log_inverse_ratio) * tf.expand_dims(active_policies, -1),
            axis=0)
        iwn = iw / n_episodes
        log_inverse_ratio_lb = behavioral_log_pdf_episode - target_log_pdf_episode
        iw_lb = 1 / tf.reduce_sum(
            tf.exp(log_inverse_ratio_lb) * tf.expand_dims(active_policies, -1),
            axis=0)
        iwn_lb = iw_lb / n_episodes
        w_return_mean_lb = tf.reduce_sum(ep_return**2 * iwn_lb)

        # Compute the J
        if shift_return:
            w_return_mean = tf.reduce_sum(ep_return_optimization**2 * iwn)
        else:
            w_return_mean = tf.reduce_sum(ep_return**2 * iwn)

        control_variate = tf.reduce_sum(return_min**2 * iwn)

        # Empirical D2 of the mixture and relative ESS
        ess_renyi_arithmetic = N_total / emp_d2_arithmetic
        ess_renyi_harmonic = N_total / emp_d2_harmonic
        ess_divergence_harmonic = N_total / divergence_harmonic

        # Log quantities
        losses_with_name.extend([
            (tf.reduce_max(iw), 'MaxIW'), (tf.reduce_min(iw), 'MinIW'),
            (tf.reduce_mean(iw), 'MeanIW'), (U.reduce_std(iw), 'StdIW'),
            (U.reduce_std(w_return_mean), 'StdWReturnMean'),
            (tf.reduce_min(target_log_pdf_episode), 'MinTargetPdf'),
            (tf.reduce_min(behavioral_log_pdf_episode), 'MinBehavPdf'),
            (ess_renyi_arithmetic, 'ESSRenyiArithmetic'),
            (ess_renyi_harmonic, 'ESSRenyiHarmonic')
        ])
    else:
        raise NotImplementedError()

    if bound == 'J':
        bound_ = w_return_mean
    elif bound == 'max-d2-harmonic':
        if penalization:
            if shift_return:
                bound_ = -w_return_mean - tf.sqrt(
                    (1 - delta) /
                    (delta *
                     ess_divergence_harmonic)) * optimization_return_abs_max**2
            else:
                bound_ = -w_return_mean - tf.sqrt(
                    (1 - delta) /
                    (delta * ess_divergence_harmonic)) * return_abs_max**2
        else:
            bound_ = -w_return_mean
        lower_bound = -w_return_mean_lb + tf.sqrt(
            (1 - delta) / (delta * ess_renyi_harmonic)) * return_abs_max**2
    elif bound == 'max-d2-arithmetic':
        bound_ = -w_return_mean - tf.sqrt(
            1 / (delta * ess_renyi_arithmetic)) * return_abs_max**2
    else:
        raise NotImplementedError()

    # Policy entropy for exploration
    ent = pi.pd.entropy()
    meanent = tf.reduce_mean(ent)
    losses_with_name.append((meanent, 'MeanEntropy'))
    # Add policy entropy bonus
    if entropy != 'none':
        scheme, v1, v2 = entropy.split(':')
        if scheme == 'step':
            entcoeff = tf.cond(iter_number_ < int(v2), lambda: float(v1),
                               lambda: float(0.0))
            losses_with_name.append((entcoeff, 'EntropyCoefficient'))
            entbonus = entcoeff * meanent
            bound_ = bound_ + entbonus
        elif scheme == 'lin':
            ip = tf.cast(iter_number_ / max_iters, tf.float32)
            entcoeff_decay = tf.maximum(
                0.0,
                float(v2) + (float(v1) - float(v2)) * (1.0 - ip))
            losses_with_name.append((entcoeff_decay, 'EntropyCoefficient'))
            entbonus = entcoeff_decay * meanent
            bound_ = bound_ + entbonus
        elif scheme == 'exp':
            ent_f = tf.exp(
                -tf.abs(tf.reduce_mean(iw) - 1) * float(v2)) * float(v1)
            losses_with_name.append((ent_f, 'EntropyCoefficient'))
            bound_ = bound_ + ent_f * meanent
        else:
            raise Exception('Unrecognized entropy scheme.')

    losses_with_name.append((w_return_mean, 'ReturnMeanIW'))
    losses_with_name.append((bound_, 'Bound'))
    losses, loss_names = map(list, zip(*losses_with_name))
    '''
    if use_natural_gradient:
        p = tf.placeholder(dtype=tf.float32, shape=[None])
        target_logpdf_episode = tf.reduce_sum(target_log_pdf_split * mask_split, axis=1)
        grad_logprob = U.flatgrad(tf.stop_gradient(iwn) * target_logpdf_episode, var_list)
        dot_product = tf.reduce_sum(grad_logprob * p)
        hess_logprob = U.flatgrad(dot_product, var_list)
        compute_linear_operator = U.function([p, ob_, ac_, disc_rew_, mask_], [-hess_logprob])
    '''

    assign_nu_eq_mu = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv, newv) in zipsame(nu.get_variables(), pi.get_variables())
        ])

    assign_mu_eq_nu = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv, newv) in zipsame(pi.get_variables(), nu.get_variables())
        ])

    assert_ops = tf.group(*tf.get_collection('asserts'))
    print_ops = tf.group(*tf.get_collection('prints'))

    compute_lossandgrad = U.function([
        ob_, ac_, rew_, disc_rew_, clustered_rew_, mask_, iter_number_,
        active_policies
    ], losses + [U.flatgrad(bound_, var_list), assert_ops, print_ops])
    compute_grad = U.function([
        ob_, ac_, rew_, disc_rew_, clustered_rew_, mask_, iter_number_,
        active_policies
    ], [U.flatgrad(bound_, var_list), assert_ops, print_ops])
    compute_bound = U.function([
        ob_, ac_, rew_, disc_rew_, clustered_rew_, mask_, iter_number_,
        active_policies
    ], [bound_, assert_ops, print_ops])
    compute_losses = U.function([
        ob_, ac_, rew_, disc_rew_, clustered_rew_, mask_, iter_number_,
        active_policies
    ], losses)
    compute_w_return = U.function([
        ob_, ac_, rew_, disc_rew_, clustered_rew_, mask_, iter_number_,
        active_policies
    ], [w_return_mean, assert_ops, print_ops])

    set_parameter = U.SetFromFlat(var_list)
    get_parameter = U.GetFlat(var_list)
    policy_reinit = tf.variables_initializer(var_list)

    get_parameter_pi = U.GetFlat(var_list_pi)

    if sampler is None:
        seg_gen = traj_segment_generator(pi,
                                         env,
                                         n_episodes,
                                         horizon,
                                         stochastic=True)
        sampler = type("SequentialSampler", (object, ), {
            "collect": lambda self, _: seg_gen.__next__()
        })()

    U.initialize()

    # Starting optimizing
    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=n_episodes)
    rewbuffer = deque(maxlen=n_episodes)

    while True:  #outer loop

        iters_so_far += 1  #index i

        if render_after is not None and iters_so_far % render_after == 0:
            if hasattr(env, 'render'):
                render(env, pi, horizon)

        if callback:
            callback(locals(), globals())

        if iters_so_far >= max_iters:
            print('Finished...')
            break

        logger.log('********** Iteration %i ************' % iters_so_far)

        assign_nu_eq_mu()

        #print(get_parameter(), get_parameter_pi())

        iters_so_far_inner = 0

        while True:  #inner loop

            iters_so_far_inner += 1  #index j

            if iters_so_far_inner >= inner + 1:
                print('Inner loop finished...')
                break

            logger.log('********** Inner Iteration %i ************' %
                       iters_so_far_inner)

            theta = get_parameter()

            with timed('sampling'):
                seg = sampler.collect(theta)

            add_disc_rew(seg, gamma)

            lens, rets = seg['ep_lens'], seg['ep_rets']
            lenbuffer.extend(lens)
            rewbuffer.extend(rets)
            episodes_so_far += len(lens)
            timesteps_so_far += sum(lens)

            # Adding batch of trajectories to memory
            memory.add_trajectory_batch(seg)

            # Get multiple batches from memory
            seg_with_memory = memory.get_trajectories()

            # Get clustered reward
            reward_matrix = np.reshape(
                seg_with_memory['disc_rew'] * seg_with_memory['mask'],
                (-1, horizon))
            ep_reward = np.sum(reward_matrix, axis=1)
            ep_reward = cluster_rewards(ep_reward, reward_clustering)

            args = ob, ac, rew, disc_rew, clustered_rew, mask, iter_number, active_policies = (
                seg_with_memory['ob'], seg_with_memory['ac'],
                seg_with_memory['rew'], seg_with_memory['disc_rew'], ep_reward,
                seg_with_memory['mask'], iters_so_far,
                memory.get_active_policies_mask())

            def evaluate_loss():
                loss = compute_bound(*args)
                return loss[0]

            def evaluate_gradient():
                gradient = compute_grad(*args)
                return gradient[0]

            if use_natural_gradient:

                def evaluate_fisher_vector_prod(x):
                    return compute_linear_operator(x, *
                                                   args)[0] + fisher_reg * x

                def evaluate_natural_gradient(g):
                    return cg(evaluate_fisher_vector_prod,
                              g,
                              cg_iters=10,
                              verbose=0)
            else:
                evaluate_natural_gradient = None

            with timed('summaries before'):
                logger.record_tabular("Iteration", iters_so_far)
                logger.record_tabular("Inner Iteration", iters_so_far_inner)
                logger.record_tabular("InitialBound", evaluate_loss())
                logger.record_tabular("EpLenMean", np.mean(lenbuffer))
                logger.record_tabular("EpRewMean", np.mean(rewbuffer))
                logger.record_tabular("EpThisIter", len(lens))
                logger.record_tabular("EpisodesSoFar", episodes_so_far)
                logger.record_tabular("TimestepsSoFar", timesteps_so_far)
                logger.record_tabular("TimeElapsed", time.time() - tstart)
                logger.record_tabular("WReturnMean",
                                      compute_w_return(*args)[0])
                logger.record_tabular("Penalization", penalization)
                logger.record_tabular("LearnableVariance", learnable_variance)
                logger.record_tabular("VarianceInitializer",
                                      variance_initializer)
                logger.record_tabular("Epsilon", constant_step_size)

            if save_weights > 0 and iters_so_far % save_weights == 0:
                logger.record_tabular('Weights', str(get_parameter()))
                #import pickle
                #file = open('checkpoint' + str(iters_so_far) + '.pkl', 'wb')
                #pickle.dump(theta, file)

            #print(get_parameter(), get_parameter_pi())
            #memory.print_parameters()

            #print('check ', theta, get_parameter())
            if not warm_start or memory.get_current_load() == capacity:
                # Optimize

                with timed("offline optimization"):
                    theta, improvement = optimize_offline(
                        theta,
                        set_parameter,
                        line_search,
                        evaluate_loss,
                        evaluate_gradient,
                        evaluate_natural_gradient,
                        max_offline_ite=max_offline_iters,
                        constant_step_size=constant_step_size)

                set_parameter(theta)
                #print('new theta ', theta)
                #print(get_parameter_pi())

                with timed('summaries after'):
                    meanlosses = np.array(compute_losses(*args))
                    for (lossname, lossval) in zip(loss_names, meanlosses):
                        logger.record_tabular(lossname, lossval)
            else:
                pass
                # Reinitialize the policy
                #tf.get_default_session().run(policy_reinit)

            logger.dump_tabular()

        assign_mu_eq_nu()

    env.close()
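For reference, the importance weights used above are formed per episode: log-probabilities are summed over the horizon, and the target density is divided by a mixture of the stored behavioral densities. The following is a small NumPy sketch of that idea with illustrative array names; the TF graph above additionally folds in the `nu` policy and the active-policy mask.

import numpy as np

def episode_importance_weights(target_logp, behavioral_logp, active):
    """target_logp: (n_ep, horizon) log pi(a_t | s_t) under the target policy.
    behavioral_logp: (capacity, n_ep, horizon) log q_k(a_t | s_t) per memory cell.
    active: (capacity,) 0/1 mask of memory cells currently in use."""
    log_p = target_logp.sum(axis=1)                        # (n_ep,)
    log_q = behavioral_logp.sum(axis=2)                    # (capacity, n_ep)
    # Work with the inverse ratios q_k / p for numerical stability, as above
    inverse_ratio = np.exp(log_q - log_p[None, :]) * active[:, None]
    return 1.0 / inverse_ratio.sum(axis=0)                 # (n_ep,) importance weights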
def learn(
        env,
        policy_fn,
        *,
        timesteps_per_actorbatch,  # timesteps per actor per update
        clip_param,
        entcoeff,  # clipping parameter epsilon, entropy coeff
        optim_epochs,
        optim_stepsize,
        optim_batchsize,  # optimization hypers
        gamma,
        lam,  # advantage estimation
        # CMAES
        max_fitness,  # has to be negative, as CMA-ES considers minimization
        popsize,
        gensize,
        bounds,
        sigma,
        eval_iters,
        max_v_train_iter,
        max_timesteps=0,
        max_episodes=0,
        max_iters=0,
        max_seconds=0,
        # time constraint
        callback=None,
        # you can do anything in the callback, since it takes locals(), globals()
        adam_epsilon=1e-5,
        schedule='constant',
        # annealing for stepsize parameters (epsilon and adam)
        seed,
        env_id):
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_fn("pi", ob_space,
                   ac_space)  # Construct network for new policy
    oldpi = policy_fn("oldpi", ob_space, ac_space)  # Network for old policy
    backup_pi = policy_fn(
        "backup_pi", ob_space, ac_space
    )  # Construct a network for every individual to adapt during the es evolution

    pi_params = tf.placeholder(dtype=tf.float32, shape=[None])
    old_pi_params = tf.placeholder(dtype=tf.float32, shape=[None])
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule

    layer_clip = tf.placeholder(
        name='layer_clip', dtype=tf.float32,
        shape=[])  # per-layer scaling applied to the clipping parameter

    bound_coeff = tf.placeholder(
        name='bound_coeff', dtype=tf.float32, shape=[])

    clip_param = clip_param * lrmult * layer_clip  # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - (oldpi.pd.logp(ac) + 1e-8))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = tf.clip_by_value(ratio, 1.0 - clip_param,
                             1.0 + clip_param) * atarg  #
    pol_surr = -tf.reduce_mean(tf.minimum(
        surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)
    vf_loss = tf.reduce_mean(tf.square(pi.vpred - ret))
    vf_losses = [vf_loss]
    vf_loss_names = ["vf_loss"]

    pol_loss = pol_surr + pol_entpen
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    vf_var_list = [
        v for v in var_list if v.name.split("/")[1].startswith("vf")
    ]
    pol_var_list = [
        v for v in var_list if v.name.split("/")[1].startswith("pol")
    ]

    layer_var_list = []
    for i in range(pi.num_hid_layers):
        layer_var_list.append([
            v for v in pol_var_list
            if v.name.split("/")[2].startswith('fc%i' % (i + 1))
        ])
    logstd_var_list = [
        v for v in pol_var_list if v.name.split("/")[2].startswith("logstd")
    ]
    if len(logstd_var_list) != 0:
        layer_var_list.append([
            v for v in pol_var_list if v.name.split("/")[2].startswith("final")
        ] + logstd_var_list)

    vf_lossandgrad = U.function([ob, ac, ret, lrmult],
                                vf_losses + [U.flatgrad(vf_loss, vf_var_list)])

    vf_adam = MpiAdam(vf_var_list, epsilon=adam_epsilon)
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    assign_backup_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(backup_v, newv) for (
                backup_v,
                newv) in zipsame(backup_pi.get_variables(), pi.get_variables())
        ])
    assign_new_eq_backup = U.function(
        [], [],
        updates=[
            tf.assign(newv, backup_v)
            for (newv, backup_v
                 ) in zipsame(pi.get_variables(), backup_pi.get_variables())
        ])
    # Compute all losses

    compute_pol_losses = U.function([ob, ac, atarg, ret, lrmult, layer_clip],
                                    [pol_loss, pol_surr, pol_entpen, meankl])

    compute_v_pred = U.function([ob], [pi.vpred])

    a_prob = tf.exp(pi.pd.logp(ac))
    compute_a_prob = U.function([ob, ac], [a_prob])

    U.initialize()

    layer_set_operate_list = []
    layer_get_operate_list = []
    for var in layer_var_list:
        set_pi_layer_flat_params = U.SetFromFlat(var)
        layer_set_operate_list.append(set_pi_layer_flat_params)
        get_pi_layer_flat_params = U.GetFlat(var)
        layer_get_operate_list.append(get_pi_layer_flat_params)

    # get_pi_layer_flat_params = U.GetFlat(pol_var_list)
    # set_pi_layer_flat_params = U.SetFromFlat(pol_var_list)

    vf_adam.sync()

    adam.sync()

    global timesteps_so_far, episodes_so_far, iters_so_far, \
        tstart, lenbuffer, rewbuffer, ppo_timesteps_so_far, best_fitness

    episodes_so_far = 0
    timesteps_so_far = 0
    ppo_timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards

    best_fitness = -np.inf

    eval_seq = traj_segment_generator_eval(pi,
                                           env,
                                           timesteps_per_actorbatch,
                                           stochastic=False)
    # eval_gen = traj_segment_generator_eval(pi, test_env, timesteps_per_actorbatch, stochastic = True)  # For evaluation
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_actorbatch,
                                     stochastic=True,
                                     eval_seq=eval_seq)  # For train V Func

    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"

    indices = []  # maintain all selected indices for each iteration

    opt = cma.CMAOptions()
    opt['tolfun'] = max_fitness
    opt['popsize'] = popsize
    opt['maxiter'] = gensize
    opt['verb_disp'] = 0
    opt['verb_log'] = 0
    # opt['seed'] = seed
    opt['AdaptSigma'] = True
    # opt['bounds'] = bounds
    # opt['tolstagnation'] = 20
    ess = []
    seg = None
    segs = None
    sum_vpred = []
    while True:
        if max_timesteps and timesteps_so_far >= max_timesteps:
            print("Max time steps")
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            print("Max episodes")
            break
        elif max_iters and iters_so_far >= max_iters:
            print("Max iterations")
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            print("Max time")
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / (max_timesteps),
                             0)
        else:
            raise NotImplementedError

        epsilon = max(0.5 * cur_lrmult, 0)
        # epsilon = max(0.5 - float(timesteps_so_far) / (max_timesteps), 0) * cur_lrmult
        # epsilon = 0.2
        # sigma_adapted = max(max(sigma - float(timesteps_so_far) / (5000 * max_timesteps), 0) * cur_lrmult, 1e-8)
        sigma_adapted = max(sigma * cur_lrmult, 1e-8)
        # cmean_adapted = max(1.0 - float(timesteps_so_far) / (max_timesteps), 1e-8)
        # cmean_adapted = max(0.8 - float(timesteps_so_far) / (2*max_timesteps), 1e-8)
        # if timesteps_so_far % max_timesteps == 10:
        max_v_train_iter = int(
            max(
                max_v_train_iter * (1 - timesteps_so_far /
                                    (0.5 * max_timesteps)), 1))
        logger.log("********** Iteration %i ************" % iters_so_far)
        if iters_so_far == 0:
            eval_seg = eval_seq.__next__()
            rewbuffer.extend(eval_seg["ep_rets"])
            lenbuffer.extend(eval_seg["ep_lens"])
            result_record()

        # Repository Train
        train_segs = {}
        seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)
        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(
                seg["ob"])  # update running mean/std for normalization

        # rewbuffer.extend(seg["ep_rets"])
        # lenbuffer.extend(seg["ep_lens"])
        # if iters_so_far == 0:
        #     result_record()
        # print(np.random.get_state()[1][0])

        assign_old_eq_new()  # set old parameter values to new parameter values
        if segs is None:
            segs = seg
            segs["v_target"] = np.zeros(len(seg["ob"]), 'float32')
        else:
            seg_keys = ["ob", "next_ob", "ac", "rew", "vpred", "act_props",
                        "new", "adv", "tdlamret", "ep_rets", "ep_lens",
                        "v_target"]
            if len(segs["ob"]) >= 50000:
                # Buffer is full: drop the oldest actor batch before appending
                for key in seg_keys:
                    segs[key] = np.take(segs[key],
                                        np.arange(timesteps_per_actorbatch,
                                                  len(segs[key])),
                                        axis=0)
            # Append the freshly collected batch; new v_target entries start at zero
            for key in seg_keys[:-1]:
                segs[key] = np.append(segs[key], seg[key], axis=0)
            segs["v_target"] = np.append(segs["v_target"],
                                         np.zeros(len(seg["ob"]), 'float32'),
                                         axis=0)

        if iters_so_far == 0:
            ob, ac, tdlamret = seg["ob"], seg["ac"], seg["tdlamret"]
            d = Dataset(dict(ob=ob, ac=ac, vtarg=tdlamret),
                        shuffle=not pi.recurrent)
            optim_batchsize = optim_batchsize or ob.shape[0]

            # Train V function
            # logger.log("Catchup Training V Func and Evaluating V Func Losses")
            for _ in range(max_v_train_iter):
                for batch in d.iterate_once(optim_batchsize):
                    *vf_loss, g = vf_lossandgrad(batch["ob"], batch["ac"],
                                                 batch["vtarg"], cur_lrmult)
                    vf_adam.update(g, optim_stepsize * cur_lrmult)
                # logger.log(fmt_row(13, np.mean(vf_losses, axis = 0)))
        else:
            # Update v target
            new = segs["new"]
            rew = segs["rew"]
            act_prob = np.asarray(compute_a_prob(segs["ob"], segs["ac"])).T
            importance_ratio = np.squeeze(act_prob) / (
                segs["act_props"] + np.ones(segs["act_props"].shape) * 1e-8)
            segs["v_target"] = importance_ratio * (1/np.sum(importance_ratio)) * \
                               np.squeeze(rew + np.invert(new).astype(np.float32) * gamma * compute_v_pred(segs["next_ob"]))
            # train_segs["v_target"] = rew + np.invert(new).astype(np.float32) * gamma * compute_v_pred(train_segs["next_ob"])
            if len(segs["ob"]) >= 20000:
                train_times = int(max_v_train_iter /
                                  2) if int(max_v_train_iter / 2) > 0 else 1
            else:
                train_times = 2
            for i in range(train_times):
                selected_train_index = np.random.choice(
                    range(len(segs["ob"])),
                    timesteps_per_actorbatch,
                    replace=False)
                train_segs["ob"] = np.take(segs["ob"],
                                           selected_train_index,
                                           axis=0)
                train_segs["next_ob"] = np.take(segs["next_ob"],
                                                selected_train_index,
                                                axis=0)
                train_segs["ac"] = np.take(segs["ac"],
                                           selected_train_index,
                                           axis=0)
                train_segs["rew"] = np.take(segs["rew"],
                                            selected_train_index,
                                            axis=0)
                train_segs["vpred"] = np.take(segs["vpred"],
                                              selected_train_index,
                                              axis=0)
                train_segs["new"] = np.take(segs["new"],
                                            selected_train_index,
                                            axis=0)
                train_segs["adv"] = np.take(segs["adv"],
                                            selected_train_index,
                                            axis=0)
                train_segs["tdlamret"] = np.take(segs["tdlamret"],
                                                 selected_train_index,
                                                 axis=0)
                train_segs["v_target"] = np.take(segs["v_target"],
                                                 selected_train_index,
                                                 axis=0)
                #
                ob, ac, v_target = train_segs["ob"], train_segs[
                    "ac"], train_segs["v_target"]
                d = Dataset(dict(ob=ob, ac=ac, vtarg=v_target),
                            shuffle=not pi.recurrent)
                optim_batchsize = optim_batchsize or ob.shape[0]

                # Train V function
                # logger.log("Training V Func and Evaluating V Func Losses")
                # Train V function
                # logger.log("Catchup Training V Func and Evaluating V Func Losses")
                # logger.log("Train V - "+str(_))
                for _ in range(max_v_train_iter):
                    for batch in d.iterate_once(optim_batchsize):
                        *vf_loss, g = vf_lossandgrad(batch["ob"], batch["ac"],
                                                     batch["vtarg"],
                                                     cur_lrmult)
                        vf_adam.update(g, optim_stepsize * cur_lrmult)
                    # logger.log(fmt_row(13, np.mean(vf_losses, axis = 0)))
                # seg['vpred'] = np.asarray(compute_v_pred(seg["ob"])).reshape(seg['vpred'].shape)
                # seg['nextvpred'] = seg['vpred'][-1] * (1 - seg["new"][-1])
                # add_vtarg_and_adv(seg, gamma, lam)

            # seg['vpred'] = np.asarray(compute_v_pred(seg["ob"])).reshape(seg['vpred'].shape)
            # seg['nextvpred'] = seg['vpred'][-1] * (1 - seg["new"][-1])
            # add_vtarg_and_adv(seg, gamma, lam)

        # seg['vpred'] = np.asarray(compute_v_pred(seg["ob"])).reshape(seg['vpred'].shape)
        # seg['nextvpred'] = seg['vpred'][-1] * (1 - seg["new"][-1])
        # add_vtarg_and_adv(seg, gamma, lam)

        ob_po, ac_po, atarg_po, tdlamret_po = seg["ob"], seg["ac"], seg[
            "adv"], seg["tdlamret"]
        atarg_po = (atarg_po - atarg_po.mean()) / atarg_po.std(
        )  # standardized advantage function estimate

        # opt['CMA_cmean'] = cmean_adapted
        # assign_old_eq_new()  # set old parameter values to new parameter values
        for i in range(len(layer_var_list)):
            # CMAES Train Policy
            assign_backup_eq_new()  # backup current policy
            flatten_weights = layer_get_operate_list[i]()

            if len(indices) < len(layer_var_list):
                selected_index, init_weights = uniform_select(
                    flatten_weights,
                    0.5)  # 0.5 means 50% proportion of params are selected
                indices.append(selected_index)
            else:
                rand = np.random.uniform()
                # print("Random-Number:", rand)
                # print("Epsilon:", epsilon)
                if rand < epsilon:
                    selected_index, init_weights = uniform_select(
                        flatten_weights, 0.5)
                    indices.append(selected_index)
                    # logger.log("Random: select new weights")
                else:
                    selected_index = indices[i]
                    init_weights = np.take(flatten_weights, selected_index)
            es = cma.CMAEvolutionStrategy(init_weights, sigma_adapted, opt)
            while True:
                if es.countiter >= gensize:
                    # logger.log("Max generations for current layer")
                    break
                # logger.log("Iteration:" + str(iters_so_far) + " - sub-train Generation for Policy:" + str(es.countiter))
                # logger.log("Sigma=" + str(es.sigma))
                # solutions = es.ask(sigma_fac = max(cur_lrmult, 1e-8))
                solutions = es.ask()
                # solutions = [np.clip(solution, -5.0, 5.0).tolist() for solution in solutions]
                costs = []
                lens = []

                assign_backup_eq_new()  # backup current policy

                for id, solution in enumerate(solutions):
                    np.put(flatten_weights, selected_index, solution)
                    layer_set_operate_list[i](flatten_weights)
                    cost = compute_pol_losses(ob_po, ac_po, atarg_po,
                                              tdlamret_po, cur_lrmult,
                                              1 / 3 * (i + 1))
                    # cost = compute_pol_losses(ob_po, ac_po, atarg_po, tdlamret_po, cur_lrmult, 1.0)
                    costs.append(cost[0])
                    assign_new_eq_backup()
                # Weights decay
                l2_decay = compute_weight_decay(0.01, solutions)
                costs += l2_decay
                costs, real_costs = fitness_rank(costs)
                # logger.log("real_costs:"+str(real_costs))
                # best_solution = np.copy(es.result[0])
                # best_fitness = -es.result[1]
                es.tell_real_seg(solutions=solutions,
                                 function_values=costs,
                                 real_f=real_costs,
                                 segs=None)
                # best_solution = np.copy(solutions[np.argmin(costs)])
                # best_fitness = -real_costs[np.argmin(costs)]
                best_solution = es.result[0]
                best_fitness = es.result[1]
                np.put(flatten_weights, selected_index, best_solution)
                layer_set_operate_list[i](flatten_weights)
                # logger.log("Update the layer")
                # best_solution = es.result[0]
                # best_fitness = es.result[1]
                # logger.log("Best Solution Fitness:" + str(best_fitness))
                # set_pi_flat_params(best_solution)
            import gc
            gc.collect()

        iters_so_far += 1
        episodes_so_far += sum(lens)
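The per-layer search above is driven by the `cma` package's ask/tell interface (the code calls a customized `tell_real_seg`; the stock method is `tell`). A generic, self-contained ask/tell loop looks roughly like the sketch below; the quadratic objective in the trailing comment is only a usage illustration.

import cma
import numpy as np

def cma_minimize(objective, x0, sigma0=0.5, popsize=8, max_gens=50):
    """Generic CMA-ES loop: sample candidates, evaluate, feed fitnesses back."""
    es = cma.CMAEvolutionStrategy(np.asarray(x0, dtype=float), sigma0,
                                  {'popsize': popsize, 'verb_disp': 0})
    for _ in range(max_gens):
        solutions = es.ask()                       # candidate parameter vectors
        fitnesses = [objective(s) for s in solutions]
        es.tell(solutions, fitnesses)              # CMA-ES minimizes by default
        if es.stop():
            break
    return es.result[0], es.result[1]              # best solution and its fitness

# e.g. cma_minimize(lambda w: float(np.sum((np.asarray(w) - 1.0) ** 2)), x0=np.zeros(5))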
Example #37
0
def learn(env, policy_func, *,
        timesteps_per_batch, # what to train on
        max_kl, cg_iters,
        gamma, lam, # advantage estimation
        entcoeff=0.0,
        cg_damping=1e-2,
        vf_stepsize=3e-4,
        vf_iters =3,
        max_timesteps=0, max_episodes=0, max_iters=0,  # time constraint
        callback=None
        ):
    nworkers = MPI.COMM_WORLD.Get_size()
    rank = MPI.COMM_WORLD.Get_rank()
    np.set_printoptions(precision=3)    
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_func("pi", ob_space, ac_space)
    oldpi = policy_func("oldpi", ob_space, ac_space)
    atarg = tf.placeholder(dtype=tf.float32, shape=[None]) # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None]) # Empirical return

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = U.mean(kloldnew)
    meanent = U.mean(ent)
    entbonus = entcoeff * meanent

    vferr = U.mean(tf.square(pi.vpred - ret))

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac)) # advantage * pnew / pold
    surrgain = U.mean(ratio * atarg)

    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

    dist = meankl

    all_var_list = pi.get_trainable_variables()
    var_list = [v for v in all_var_list if v.name.split("/")[1].startswith("pol")]
    vf_var_list = [v for v in all_var_list if v.name.split("/")[1].startswith("vf")]
    vfadam = MpiAdam(vf_var_list)

    get_flat = U.GetFlat(var_list)
    set_from_flat = U.SetFromFlat(var_list)
    klgrads = tf.gradients(dist, var_list)
    flat_tangent = tf.placeholder(dtype=tf.float32, shape=[None], name="flat_tan")
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start+sz], shape))
        start += sz
    gvp = tf.add_n([U.sum(g*tangent) for (g, tangent) in zipsame(klgrads, tangents)]) #pylint: disable=E1111
    fvp = U.flatgrad(gvp, var_list)

    assign_old_eq_new = U.function([],[], updates=[tf.assign(oldv, newv)
        for (oldv, newv) in zipsame(oldpi.get_variables(), pi.get_variables())])
    compute_losses = U.function([ob, ac, atarg], losses)
    compute_lossandgrad = U.function([ob, ac, atarg], losses + [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    compute_vflossandgrad = U.function([ob, ret], U.flatgrad(vferr, vf_var_list))

    @contextmanager
    def timed(msg):
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(colorize("done in %.3f seconds"%(time.time() - tstart), color='magenta'))
        else:
            yield
    
    def allmean(x):
        assert isinstance(x, np.ndarray)
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
        return out

    U.initialize()
    th_init = get_flat()
    MPI.COMM_WORLD.Bcast(th_init, root=0)
    set_from_flat(th_init)
    vfadam.sync()
    print("Init param sum", th_init.sum(), flush=True)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi, env, timesteps_per_batch, stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40) # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40) # rolling buffer for episode rewards

    assert sum([max_iters>0, max_timesteps>0, max_episodes>0])==1

    while True:        
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        logger.log("********** Iteration %i ************"%iters_so_far)

        with timed("sampling"):
            seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]
        vpredbefore = seg["vpred"] # predicted value function before update
        atarg = (atarg - atarg.mean()) / atarg.std() # standardized advantage function estimate

        if hasattr(pi, "ret_rms"): pi.ret_rms.update(tdlamret)
        if hasattr(pi, "ob_rms"): pi.ob_rms.update(ob) # update running mean/std for policy

        args = seg["ob"], seg["ac"], atarg
        fvpargs = [arr[::5] for arr in args]
        def fisher_vector_product(p):
            return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p

        assign_old_eq_new() # set old parameter values to new parameter values
        with timed("computegrad"):
            *lossbefore, g = compute_lossandgrad(*args)
        lossbefore = allmean(np.array(lossbefore))
        g = allmean(g)
        if np.allclose(g, 0):
            logger.log("Got zero gradient. not updating")
        else:
            with timed("cg"):
                stepdir = cg(fisher_vector_product, g, cg_iters=cg_iters, verbose=rank==0)
            assert np.isfinite(stepdir).all()
            shs = .5*stepdir.dot(fisher_vector_product(stepdir))
            lm = np.sqrt(shs / max_kl)
            # logger.log("lagrange multiplier:", lm, "gnorm:", np.linalg.norm(g))
            fullstep = stepdir / lm
            expectedimprove = g.dot(fullstep)
            surrbefore = lossbefore[0]
            stepsize = 1.0
            thbefore = get_flat()
            for _ in range(10):
                thnew = thbefore + fullstep * stepsize
                set_from_flat(thnew)
                meanlosses = surr, kl, *_ = allmean(np.array(compute_losses(*args)))
                improve = surr - surrbefore
                logger.log("Expected: %.3f Actual: %.3f"%(expectedimprove, improve))
                if not np.isfinite(meanlosses).all():
                    logger.log("Got non-finite value of losses -- bad!")
                elif kl > max_kl * 1.5:
                    logger.log("violated KL constraint. shrinking step.")
                elif improve < 0:
                    logger.log("surrogate didn't improve. shrinking step.")
                else:
                    logger.log("Stepsize OK!")
                    break
                stepsize *= .5
            else:
                logger.log("couldn't compute a good step")
                set_from_flat(thbefore)
            if nworkers > 1 and iters_so_far % 20 == 0:
                paramsums = MPI.COMM_WORLD.allgather((thnew.sum(), vfadam.getflat().sum())) # list of tuples
                assert all(np.allclose(ps, paramsums[0]) for ps in paramsums[1:])

        for (lossname, lossval) in zip(loss_names, meanlosses):
            logger.record_tabular(lossname, lossval)

        with timed("vf"):

            for _ in range(vf_iters):
                for (mbob, mbret) in dataset.iterbatches((seg["ob"], seg["tdlamret"]), 
                include_final_partial_batch=False, batch_size=64):
                    g = allmean(compute_vflossandgrad(mbob, mbret))
                    vfadam.update(g, vf_stepsize)

        logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))

        lrlocal = (seg["ep_lens"], seg["ep_rets"]) # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal) # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)

        if rank==0:
            logger.dump_tabular()
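The `cg` call in the example above solves H x = g for the natural-gradient direction using only Fisher-vector products. A compact NumPy version of such a solver, written here as a sketch (the examples import a `cg` helper from their own common utilities), is:

import numpy as np

def conjugate_gradient(f_Ax, b, cg_iters=10, residual_tol=1e-10):
    """Solve A x = b given only a matrix-vector product oracle f_Ax(v) = A v."""
    x = np.zeros_like(b)
    r = b.copy()                  # residual b - A x (x starts at zero)
    p = b.copy()                  # current search direction
    rdotr = r.dot(r)
    for _ in range(cg_iters):
        Ap = f_Ax(p)
        alpha = rdotr / p.dot(Ap)
        x += alpha * p
        r -= alpha * Ap
        new_rdotr = r.dot(r)
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
        if rdotr < residual_tol:
            break
    return x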
def learn(
        *,
        network,
        env,
        eval_env,
        make_eval_env,
        env_id,
        seed,
        beta,
        total_timesteps,
        sil_update,
        sil_loss,
        timesteps_per_batch,  # what to train on
        #num_samples=(1500,),
        num_samples=(1, ),
        #horizon=(5,),
        horizon=(2, ),
        #num_elites=(10,),
        num_elites=(1, ),
        max_kl=0.001,
        cg_iters=10,
        gamma=0.99,
        lam=1.0,  # advantage estimation
        ent_coef=0.0,
        lr=3e-4,
        cg_damping=1e-2,
        vf_stepsize=3e-4,
        vf_iters=5,
        sil_value=0.01,
        sil_alpha=0.6,
        sil_beta=0.1,
        max_episodes=0,
        max_iters=0,  # time constraint
        callback=None,
        save_interval=0,
        load_path=None,
        model_fn=None,
        update_fn=None,
        init_fn=None,
        mpi_rank_weight=1,
        comm=None,
        vf_coef=0.5,
        max_grad_norm=0.5,
        log_interval=1,
        nminibatches=4,
        noptepochs=4,
        cliprange=0.2,
        TRPO=False,

        # MBL
        # For train mbl
        mbl_train_freq=5,

        # For eval
        num_eval_episodes=5,
        eval_freq=5,
        vis_eval=False,
        eval_targs=('mbmf', ),
        #eval_targs=('mf',),
        quant=2,

        # For mbl.step
        mbl_lamb=(1.0, ),
        mbl_gamma=0.99,
        #mbl_sh=1, # Number of step for stochastic sampling
        mbl_sh=10000,
        #vf_lookahead=-1,
        #use_max_vf=False,
        reset_per_step=(0, ),

        # For get_model
        num_fc=2,
        num_fwd_hidden=500,
        use_layer_norm=False,

        # For MBL
        num_warm_start=int(1e4),
        init_epochs=10,
        update_epochs=5,
        batch_size=512,
        update_with_validation=False,
        use_mean_elites=1,
        use_ent_adjust=0,
        adj_std_scale=0.5,

        # For data loading
        validation_set_path=None,

        # For data collect
        collect_val_data=False,

        # For traj collect
        traj_collect='mf',

        # For profile
        measure_time=True,
        eval_val_err=False,
        measure_rew=True,
        **network_kwargs):
    '''
    Learn a policy function with the TRPO algorithm.

    Parameters:
    ----------

    network                 neural network to learn. Can be either string ('mlp', 'cnn', 'lstm', 'lnlstm' for basic types)
                            or function that takes input placeholder and returns tuple (output, None) for feedforward nets
                            or (output, (state_placeholder, state_output, mask_placeholder)) for recurrent nets

    env                     environment (one of the gym environments or wrapped via a baselines.common.vec_env.VecEnv-type class)

    timesteps_per_batch     timesteps per gradient estimation batch

    max_kl                  max KL divergence between old policy and new policy ( KL(pi_old || pi) )

    ent_coef                coefficient of policy entropy term in the optimization objective

    cg_iters                number of iterations of conjugate gradient algorithm

    cg_damping              conjugate gradient damping

    vf_stepsize             learning rate for the adam optimizer used to optimize the value function loss

    vf_iters                number of value function optimization iterations per policy optimization step

    total_timesteps         max number of timesteps

    max_episodes            max number of episodes

    max_iters               maximum number of policy optimization iterations

    callback                function to be called with (locals(), globals()) each policy optimization step

    load_path               str, path to load the model from (default: None, i.e. no model is loaded)

    **network_kwargs        keyword arguments to the policy / network builder. See baselines.common/policies.py/build_policy and arguments to a particular type of network

    Returns:
    -------

    learnt model

    '''
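    # A minimal, illustrative call (hypothetical helper `make_env` and env id;
    # they are placeholders, not part of this module). Note that this `learn`
    # is a generator and yields the current policy once per iteration:
    #
    #     env, eval_env = make_env('SomeEnv-v2'), make_env('SomeEnv-v2')
    #     for pi in learn(network='mlp', env=env, eval_env=eval_env,
    #                     make_eval_env=lambda: make_env('SomeEnv-v2'),
    #                     env_id='SomeEnv-v2', seed=0, beta=1.0,
    #                     total_timesteps=int(1e6), sil_update=10,
    #                     sil_loss=0.1, timesteps_per_batch=1024):
    #         pass  # pi is the current policy after each iteration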

    if not isinstance(num_samples, tuple): num_samples = (num_samples, )
    if not isinstance(horizon, tuple): horizon = (horizon, )
    if not isinstance(num_elites, tuple): num_elites = (num_elites, )
    if not isinstance(mbl_lamb, tuple): mbl_lamb = (mbl_lamb, )
    if not isinstance(reset_per_step, tuple):
        reset_per_step = (reset_per_step, )
    if validation_set_path is None:
        if collect_val_data:
            validation_set_path = os.path.join(logger.get_dir(), 'val.pkl')
        else:
            validation_set_path = os.path.join('dataset',
                                               '{}-val.pkl'.format(env_id))
    if eval_val_err:
        eval_val_err_path = os.path.join('dataset',
                                         '{}-combine-val.pkl'.format(env_id))
    logger.log(locals())
    logger.log('MBL_SH', mbl_sh)
    logger.log('Traj_collect', traj_collect)

    set_global_seeds(seed)
    if isinstance(lr, float): lr = constfn(lr)
    else: assert callable(lr)
    if isinstance(cliprange, float): cliprange = constfn(cliprange)
    else: assert callable(cliprange)
    if MPI is not None:
        nworkers = MPI.COMM_WORLD.Get_size()
        rank = MPI.COMM_WORLD.Get_rank()
    else:
        nworkers = 1
        rank = 0
    cpus_per_worker = 1
    U.get_session(
        config=tf.ConfigProto(allow_soft_placement=True,
                              inter_op_parallelism_threads=cpus_per_worker,
                              intra_op_parallelism_threads=cpus_per_worker))

    policy = build_policy(env,
                          network,
                          value_network='copy',
                          copos=True,
                          **network_kwargs)
    nenvs = env.num_envs
    np.set_printoptions(precision=3)

    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    nbatch = nenvs * timesteps_per_batch
    nbatch_train = nbatch // nminibatches
    is_mpi_root = (MPI is None or MPI.COMM_WORLD.Get_rank() == 0)
    if model_fn is None:
        model_fn = Model
    discrete_ac_space = isinstance(ac_space, gym.spaces.Discrete)

    ob = observation_placeholder(ob_space)
    with tf.variable_scope("pi"):
        pi = policy(observ_placeholder=ob)
        make_model = lambda: Model(
            policy=policy,
            ob_space=ob_space,
            ac_space=ac_space,
            nbatch_act=nenvs,
            nbatch_train=nbatch_train,
            nsteps=timesteps_per_batch,
            ent_coef=ent_coef,
            vf_coef=vf_coef,
            max_grad_norm=max_grad_norm,
            sil_update=sil_update,
            sil_value=sil_value,
            sil_alpha=sil_alpha,
            sil_beta=sil_beta,
            sil_loss=sil_loss,
            #                                    fn_reward=env.process_reward,
            fn_reward=None,
            #                                    fn_obs=env.process_obs,
            fn_obs=None,
            ppo=False,
            prev_pi='pi',
            silm=pi)
        model = make_model()
        if load_path is not None:
            model.load(load_path)
    with tf.variable_scope("oldpi"):
        oldpi = policy(observ_placeholder=ob)
        make_old_model = lambda: Model(
            policy=policy,
            ob_space=ob_space,
            ac_space=ac_space,
            nbatch_act=nenvs,
            nbatch_train=nbatch_train,
            nsteps=timesteps_per_batch,
            ent_coef=ent_coef,
            vf_coef=vf_coef,
            max_grad_norm=max_grad_norm,
            sil_update=sil_update,
            sil_value=sil_value,
            sil_alpha=sil_alpha,
            sil_beta=sil_beta,
            sil_loss=sil_loss,
            #                                    fn_reward=env.process_reward,
            fn_reward=None,
            #                                    fn_obs=env.process_obs,
            fn_obs=None,
            ppo=False,
            prev_pi='oldpi',
            silm=oldpi)
        old_model = make_old_model()

    # MBL
    # ---------------------------------------
    #viz = Visdom(env=env_id)
    win = None
    eval_targs = list(eval_targs)
    logger.log(eval_targs)

    make_model_f = get_make_mlp_model(num_fc=num_fc,
                                      num_fwd_hidden=num_fwd_hidden,
                                      layer_norm=use_layer_norm)
    mbl = MBL(env=eval_env,
              env_id=env_id,
              make_model=make_model_f,
              num_warm_start=num_warm_start,
              init_epochs=init_epochs,
              update_epochs=update_epochs,
              batch_size=batch_size,
              **network_kwargs)

    val_dataset = {'ob': None, 'ac': None, 'ob_next': None}
    if update_with_validation:
        logger.log('Update with validation')
        val_dataset = load_val_data(validation_set_path)
    if eval_val_err:
        logger.log('Log val error')
        eval_val_dataset = load_val_data(eval_val_err_path)
    if collect_val_data:
        logger.log('Collect validation data')
        val_dataset_collect = []

    def _mf_pi(ob, t=None):
        stochastic = True
        ac, vpred, _, _ = pi.step(ob, stochastic=stochastic)
        return ac, vpred

    def _mf_det_pi(ob, t=None):
        #ac, vpred, _, _ = pi.step(ob, stochastic=False)
        ac, vpred = pi._evaluate([pi.pd.mode(), pi.vf], ob)
        return ac, vpred

    def _mf_ent_pi(ob, t=None):
        mean, std, vpred = pi._evaluate([pi.pd.mode(), pi.pd.std, pi.vf], ob)
        ac = np.random.normal(mean, std * adj_std_scale, size=mean.shape)
        return ac, vpred
    # TODO: clarify how use_ent_adjust and adj_std_scale interact with the
    # policy's action sampling (see _mf_ent_pi above).

    def _mbmf_inner_pi(ob, t=0):
        if use_ent_adjust:
            return _mf_ent_pi(ob)
        else:
            #return _mf_pi(ob)
            # Sample stochastically for the first mbl_sh steps, then act deterministically.
            if t < mbl_sh: return _mf_pi(ob)
            else: return _mf_det_pi(ob)

    # ---------------------------------------

    # Evaluate multiple planning configurations in a single run
    all_eval_descs = []

    def make_mbmf_pi(n, h, e, l):
        def _mbmf_pi(ob):
            ac, rew = mbl.step(ob=ob,
                               pi=_mbmf_inner_pi,
                               horizon=h,
                               num_samples=n,
                               num_elites=e,
                               gamma=mbl_gamma,
                               lamb=l,
                               use_mean_elites=use_mean_elites)
            return ac[None], rew

        return Policy(step=_mbmf_pi, reset=None)
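
    # Each (num_samples, horizon, num_elites, lambda) combination below is wrapped
    # into a Policy whose step() plans with the MBL dynamics model via mbl.step,
    # using the model-free policy _mbmf_inner_pi to propose candidate actions.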

    for n in num_samples:
        for h in horizon:
            for l in mbl_lamb:
                for e in num_elites:
                    if 'mbmf' in eval_targs:
                        all_eval_descs.append(('MeanRew', 'MBL_COPOS_SIL',
                                               make_mbmf_pi(n, h, e, l)))
                    #if 'mbmf' in eval_targs: all_eval_descs.append(('MeanRew-n-{}-h-{}-e-{}-l-{}-sh-{}-me-{}'.format(n, h, e, l, mbl_sh, use_mean_elites), 'MBL_TRPO-n-{}-h-{}-e-{}-l-{}-sh-{}-me-{}'.format(n, h, e, l, mbl_sh, use_mean_elites), make_mbmf_pi(n, h, e, l)))
    if 'mf' in eval_targs:
        all_eval_descs.append(
            ('MeanRew', 'COPOS_SIL', Policy(step=_mf_pi, reset=None)))

    logger.log('List of evaluation targets')
    for it in all_eval_descs:
        logger.log(it[0])

    pool = Pool(mp.cpu_count())
    warm_start_done = False
    # ----------------------------------------

    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    entbonus = ent_coef * meanent

    vferr = tf.reduce_mean(tf.square(pi.vf - ret))

    ratio = tf.exp(pi.pd.logp(ac) -
                   oldpi.pd.logp(ac))  # advantage * pnew / pold
    surrgain = tf.reduce_mean(ratio * atarg)

    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

    dist = meankl

    all_var_list = get_trainable_variables("pi")
    # var_list = [v for v in all_var_list if v.name.split("/")[1].startswith("pol")]
    # vf_var_list = [v for v in all_var_list if v.name.split("/")[1].startswith("vf")]
    var_list = get_pi_trainable_variables("pi")
    vf_var_list = get_vf_trainable_variables("pi")

    vfadam = MpiAdam(vf_var_list)

    get_flat = U.GetFlat(var_list)
    set_from_flat = U.SetFromFlat(var_list)
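    # Fisher-vector product via double backprop: differentiate the mean KL once to
    # get klgrads, dot it with an arbitrary flat tangent vector, and differentiate
    # again. fvp therefore maps a tangent to (approximately) the Fisher matrix times
    # that tangent, which is exactly what the conjugate-gradient solver needs.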
    klgrads = tf.gradients(dist, var_list)
    flat_tangent = tf.placeholder(dtype=tf.float32,
                                  shape=[None],
                                  name="flat_tan")
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start + sz], shape))
        start += sz
    gvp = tf.add_n([
        tf.reduce_sum(g * tangent)
        for (g, tangent) in zipsame(klgrads, tangents)
    ])  #pylint: disable=E1111
    fvp = U.flatgrad(gvp, var_list)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(get_variables("oldpi"), get_variables("pi"))
        ])

    compute_losses = U.function([ob, ac, atarg], losses)
    compute_lossandgrad = U.function([ob, ac, atarg], losses +
                                     [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    compute_vflossandgrad = U.function([ob, ret],
                                       U.flatgrad(vferr, vf_var_list))

    @contextmanager
    def timed(msg):
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(
                colorize("done in %.3f seconds" % (time.time() - tstart),
                         color='magenta'))
        else:
            yield
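
    # Average an array (typically a flat gradient) across all MPI workers so that
    # every rank applies the same update.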

    def allmean(x):
        assert isinstance(x, np.ndarray)
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
        return out

    U.initialize()
    if load_path is not None:
        pi.load(load_path)

    th_init = get_flat()
    MPI.COMM_WORLD.Bcast(th_init, root=0)
    set_from_flat(th_init)
    vfadam.sync()
    print("Init param sum", th_init.sum(), flush=True)
    # Initialize eta, omega optimizer
    if discrete_ac_space:
        init_eta = 1
        init_omega = 0.5
        eta_omega_optimizer = EtaOmegaOptimizerDiscrete(
            beta, max_kl, init_eta, init_omega)
    else:
        init_eta = 0.5
        init_omega = 2.0
        # TODO: document the eta/omega optimizer initialization details.
        eta_omega_optimizer = EtaOmegaOptimizer(beta, max_kl, init_eta,
                                                init_omega)

    # Prepare for rollouts
    # ----------------------------------------
    if traj_collect == 'mf':
        seg_gen = traj_segment_generator(env,
                                         timesteps_per_batch,
                                         model,
                                         stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards

    if sum([max_iters > 0, total_timesteps > 0, max_episodes > 0]) == 0:
        # nothing to be done
        return pi

    assert sum([max_iters>0, total_timesteps>0, max_episodes>0]) < 2, \
        'out of max_iters, total_timesteps, and max_episodes only one should be specified'

    while True:
        if callback: callback(locals(), globals())
        if total_timesteps and timesteps_so_far >= total_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        logger.log("********** Iteration %i ************" % iters_so_far)

        with timed("sampling"):
            seg = seg_gen.__next__()
            if traj_collect == 'mf-random' or traj_collect == 'mf-mb':
                seg_mbl = seg_gen_mbl.__next__()
            else:
                seg_mbl = seg
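        # Compute advantage estimates (seg["adv"]) and TD(lambda) return targets
        # (seg["tdlamret"]) for the freshly collected segment via GAE(gamma, lam).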
        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"]

        # Val data collection
        if collect_val_data:
            for ob_, ac_, ob_next_ in zip(ob[:-1, 0, ...], ac[:-1, ...],
                                          ob[1:, 0, ...]):
                val_dataset_collect.append(
                    (copy.copy(ob_), copy.copy(ac_), copy.copy(ob_next_)))
        # -----------------------------
        # MBL update
        else:
            ob_mbl, ac_mbl = seg_mbl["ob"], seg_mbl["ac"]

            mbl.add_data_batch(ob_mbl[:-1, 0, ...], ac_mbl[:-1, ...],
                               ob_mbl[1:, 0, ...])
            mbl.update_forward_dynamic(require_update=iters_so_far %
                                       mbl_train_freq == 0,
                                       ob_val=val_dataset['ob'],
                                       ac_val=val_dataset['ac'],
                                       ob_next_val=val_dataset['ob_next'])
        # -----------------------------

        if traj_collect == 'mf':
            #if traj_collect == 'mf' or traj_collect == 'mf-random' or traj_collect == 'mf-mb':
            vpredbefore = seg["vpred"]  # predicted value function before update
            model = seg["model"]
            atarg = (atarg - atarg.mean()) / atarg.std()  # standardized advantage function estimate

            if hasattr(pi, "ret_rms"): pi.ret_rms.update(tdlamret)
            if hasattr(pi, "rms"):
                pi.rms.update(ob)  # update running mean/std for policy

            args = seg["ob"], seg["ac"], atarg
            fvpargs = [arr[::5] for arr in args]
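            # fvpargs subsamples every 5th transition to make Fisher-vector products
            # cheaper; fisher_vector_product below adds cg_damping * p (Tikhonov
            # damping) for numerical stability of the conjugate-gradient solve.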

            def fisher_vector_product(p):
                return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p

            assign_old_eq_new()  # set old parameter values to new parameter values
            with timed("computegrad"):
                *lossbefore, g = compute_lossandgrad(*args)
            lossbefore = allmean(np.array(lossbefore))
            g = allmean(g)
            if np.allclose(g, 0):
                logger.log("Got zero gradient. not updating")
            else:
                with timed("cg"):
                    stepdir = cg(fisher_vector_product,
                                 g,
                                 cg_iters=cg_iters,
                                 verbose=rank == 0)
                assert np.isfinite(stepdir).all()

                if TRPO:
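                    # Scale the natural-gradient direction so the quadratic KL model
                    # hits max_kl: with shs = 0.5 * stepdir^T H stepdir, dividing
                    # stepdir by sqrt(shs / max_kl) gives the full TRPO step.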
                    shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
                    lm = np.sqrt(shs / max_kl)
                    # logger.log("lagrange multiplier:", lm, "gnorm:", np.linalg.norm(g))
                    fullstep = stepdir / lm
                    expectedimprove = g.dot(fullstep)
                    surrbefore = lossbefore[0]
                    stepsize = 1.0
                    thbefore = get_flat()
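                    # Backtracking line search: halve the step size up to 10 times and
                    # accept the first step whose losses are finite, whose KL stays
                    # within 1.5 * max_kl, and whose surrogate objective improves;
                    # otherwise revert to the old parameters.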
                    for _ in range(10):
                        thnew = thbefore + fullstep * stepsize
                        set_from_flat(thnew)
                        meanlosses = surr, kl, *_ = allmean(
                            np.array(compute_losses(*args)))
                        improve = surr - surrbefore
                        logger.log("Expected: %.3f Actual: %.3f" %
                                   (expectedimprove, improve))
                        if not np.isfinite(meanlosses).all():
                            logger.log(
                                "Got non-finite value of losses -- bad!")
                        elif kl > max_kl * 1.5:
                            logger.log(
                                "violated KL constraint. shrinking step.")
                        elif improve < 0:
                            logger.log(
                                "surrogate didn't improve. shrinking step.")
                        else:
                            logger.log("Stepsize OK!")
                            break
                        stepsize *= .5
                    else:
                        logger.log("couldn't compute a good step")
                        set_from_flat(thbefore)
                else:
                    copos_update_dir = stepdir
                    # Split direction into log-linear 'w_theta' and non-linear 'w_beta' parts
                    w_theta, w_beta = pi.split_w(copos_update_dir)
                    tmp_ob = np.zeros(
                        (1, ) + env.observation_space.shape
                    )  # We assume that entropy does not depend on the NN

                    # Optimize eta and omega
                    if discrete_ac_space:
                        entropy = lossbefore[4]
                        #entropy = - 1/timesteps_per_batch * np.sum(np.sum(pi.get_action_prob(ob) * pi.get_log_action_prob(ob), axis=1))
                        eta, omega = eta_omega_optimizer.optimize(
                            pi.compute_F_w(ob, copos_update_dir),
                            pi.get_log_action_prob(ob), timesteps_per_batch,
                            entropy)
                    else:
                        Waa, Wsa = pi.w2W(w_theta)
                        wa = pi.get_wa(ob, w_beta)
                        varphis = pi.get_varphis(ob)

                        #old_ent = old_entropy.eval({oldpi.ob: tmp_ob})[0]
                        old_ent = lossbefore[4]
                        eta, omega = eta_omega_optimizer.optimize(
                            w_theta, Waa, Wsa, wa, varphis, pi.get_kt(),
                            pi.get_prec_matrix(), pi.is_new_policy_valid,
                            old_ent)
                    logger.log("Initial eta: " + str(eta) + " and omega: " +
                               str(omega))

                    current_theta_beta = get_flat()
                    prev_theta, prev_beta = pi.all_to_theta_beta(
                        current_theta_beta)

                    if discrete_ac_space:
                        # Do a line search for both theta and beta parameters by adjusting only eta
                        eta = eta_search(w_theta, w_beta, eta, omega, allmean,
                                         compute_losses, get_flat,
                                         set_from_flat, pi, max_kl, args,
                                         discrete_ac_space)
                        logger.log("Updated eta, eta: " + str(eta))
                        set_from_flat(
                            pi.theta_beta_to_all(prev_theta, prev_beta))
                        # Find proper omega for new eta. Use old policy parameters first.
                        eta, omega = eta_omega_optimizer.optimize(
                            pi.compute_F_w(ob, copos_update_dir),
                            pi.get_log_action_prob(ob), timesteps_per_batch,
                            entropy, eta)
                        logger.log("Updated omega, eta: " + str(eta) +
                                   " and omega: " + str(omega))

                        # do line search for ratio for non-linear "beta" parameter values
                        #ratio = beta_ratio_line_search(w_theta, w_beta, eta, omega, allmean, compute_losses, get_flat, set_from_flat, pi,
                        #                     max_kl, beta, args)
                        # set ratio to 1 if we do not use beta ratio line search
                        ratio = 1
                        #print("ratio from line search: " + str(ratio))
                        cur_theta = (eta * prev_theta +
                                     w_theta.reshape(-1, )) / (eta + omega)
                        cur_beta = prev_beta + ratio * w_beta.reshape(
                            -1, ) / eta
                    else:
                        for i in range(2):
                            # Do a line search for both theta and beta parameters by adjusting only eta
                            eta = eta_search(w_theta, w_beta, eta, omega,
                                             allmean, compute_losses, get_flat,
                                             set_from_flat, pi, max_kl, args)
                            logger.log("Updated eta, eta: " + str(eta) +
                                       " and omega: " + str(omega))

                            # Find proper omega for new eta. Use old policy parameters first.
                            set_from_flat(
                                pi.theta_beta_to_all(prev_theta, prev_beta))
                            eta, omega = \
                                eta_omega_optimizer.optimize(w_theta, Waa, Wsa, wa, varphis, pi.get_kt(),
                                                             pi.get_prec_matrix(), pi.is_new_policy_valid, old_ent, eta)
                            logger.log("Updated omega, eta: " + str(eta) +
                                       " and omega: " + str(omega))

                        # Use final policy
                        logger.log("Final eta: " + str(eta) + " and omega: " +
                                   str(omega))
                        cur_theta = (eta * prev_theta +
                                     w_theta.reshape(-1, )) / (eta + omega)
                        cur_beta = prev_beta + w_beta.reshape(-1, ) / eta
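
                    # Apply the combined COPOS update: the log-linear parameters are an
                    # eta/omega-weighted mix of the old theta and w_theta, while the
                    # non-linear network parameters move along w_beta scaled by 1/eta.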

                    set_from_flat(pi.theta_beta_to_all(cur_theta, cur_beta))
                    meanlosses = surr, kl, *_ = allmean(
                        np.array(compute_losses(*args)))
                # end of COPOS-specific update
                if nworkers > 1 and iters_so_far % 20 == 0:
                    paramsums = MPI.COMM_WORLD.allgather(
                        (thnew.sum(),
                         vfadam.getflat().sum()))  # list of tuples
                    assert all(
                        np.allclose(ps, paramsums[0]) for ps in paramsums[1:])
            # end of conjugate-gradient update
            for (lossname, lossval) in zip(loss_names, meanlosses):
                logger.record_tabular(lossname, lossval)
            # end of policy update
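            # Fit the value function: vf_iters epochs of minibatch (size 64) Adam steps
            # on the squared error between the value prediction and the TD(lambda)
            # return targets, with gradients averaged across MPI workers.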
            with timed("vf"):
                for _ in range(vf_iters):
                    for (mbob, mbret) in dataset.iterbatches(
                        (seg["ob"], seg["tdlamret"]),
                            include_final_partial_batch=False,
                            batch_size=64):
                        g = allmean(compute_vflossandgrad(mbob, mbret))
                        vfadam.update(g, vf_stepsize)
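            # Self-imitation learning (SIL) update with a learning rate annealed
            # linearly over total_timesteps; model.sil_train returns loss, advantage,
            # sample-count and negative-log-probability statistics.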
            with timed("SIL"):
                lrnow = lr(1.0 - timesteps_so_far / total_timesteps)
                l_loss, sil_adv, sil_samples, sil_nlogp = model.sil_train(
                    lrnow)

            logger.record_tabular("ev_tdlam_before",
                                  explained_variance(vpredbefore, tdlamret))

        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        if MPI is not None:
            listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        else:
            listoflrpairs = [lrlocal]
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        if sil_update > 0:
            logger.record_tabular("SilSamples", sil_samples)

        if rank == 0:
            # MBL evaluation
            if not collect_val_data:
                #set_global_seeds(seed)
                default_sess = tf.get_default_session()
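
                # Evaluate each target policy in a freshly created eval env, reusing
                # the training env's observation-normalization statistics (ob_rms) so
                # that evaluation sees identically scaled observations.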

                def multithread_eval_policy(env_, pi_, num_episodes_,
                                            vis_eval_, seed):
                    with default_sess.as_default():
                        if hasattr(env, 'ob_rms') and hasattr(env_, 'ob_rms'):
                            env_.ob_rms = env.ob_rms
                        res = eval_policy(env_, pi_, num_episodes_, vis_eval_,
                                          seed, measure_time, measure_rew)

                        try:
                            env_.close()
                        except Exception:
                            pass
                    return res

                if mbl.is_warm_start_done() and iters_so_far % eval_freq == 0:
                    warm_start_done = mbl.is_warm_start_done()
                    if num_eval_episodes > 0:
                        targs_names = {}
                        with timed('eval'):
                            num_descs = len(all_eval_descs)
                            list_field_names = [e[0] for e in all_eval_descs]
                            list_legend_names = [e[1] for e in all_eval_descs]
                            list_pis = [e[2] for e in all_eval_descs]
                            list_eval_envs = [
                                make_eval_env() for _ in range(num_descs)
                            ]
                            list_seed = [seed for _ in range(num_descs)]
                            list_num_eval_episodes = [
                                num_eval_episodes for _ in range(num_descs)
                            ]
                            print(list_field_names)
                            print(list_legend_names)

                            list_vis_eval = [
                                vis_eval for _ in range(num_descs)
                            ]

                            for i in range(num_descs):
                                field_name = list_field_names[i]
                                legend_name = list_legend_names[i]

                                res = multithread_eval_policy(
                                    list_eval_envs[i], list_pis[i],
                                    list_num_eval_episodes[i],
                                    list_vis_eval[i], seed)
                                #eval_results = pool.starmap(multithread_eval_policy, zip(list_eval_envs, list_pis, list_num_eval_episodes, list_vis_eval,list_seed))

                                #for field_name, legend_name, res in zip(list_field_names, list_legend_names, eval_results):
                                perf, elapsed_time, eval_rew = res
                                logger.record_tabular(field_name, perf)
                                if measure_time:
                                    logger.record_tabular(
                                        'Time-%s' % (field_name), elapsed_time)
                                if measure_rew:
                                    logger.record_tabular(
                                        'SimRew-%s' % (field_name), eval_rew)
                                targs_names[field_name] = legend_name

                    if eval_val_err:
                        fwd_dynamics_err = mbl.eval_forward_dynamic(
                            obs=eval_val_dataset['ob'],
                            acs=eval_val_dataset['ac'],
                            obs_next=eval_val_dataset['ob_next'])
                        logger.record_tabular('FwdValError', fwd_dynamics_err)

                    logger.dump_tabular()
                    #print(logger.get_dir())
                    #print(targs_names)


#                    if num_eval_episodes > 0:
#                        win = plot(viz, win, logger.get_dir(), targs_names=targs_names, quant=quant, opt='best')
# -----------
#logger.dump_tabular()
        yield pi

    if collect_val_data:
        with open(validation_set_path, 'wb') as f:
            pickle.dump(val_dataset_collect, f)
        logger.log('Saved {} validation transitions'.format(len(val_dataset_collect)))