def learn(env, policy_fn, *,
          timesteps_per_actorbatch,  # timesteps per actor per update
          clip_param, entcoeff,  # clipping parameter epsilon, entropy coeff
          optim_epochs, optim_stepsize, optim_batchsize,  # optimization hypers
          gamma, lam,  # advantage estimation
          # CMAES
          max_fitness,  # must be negative, since CMA-ES performs minimization
          popsize,
          gensize,
          bounds,
          sigma,
          eval_iters,
          max_timesteps = 0, max_episodes = 0, max_iters = 0, max_seconds = 0,
          # time constraint
          callback = None,
          # you can do anything in the callback, since it takes locals(), globals()
          adam_epsilon = 1e-5,
          schedule = 'constant',
          # annealing for stepsize parameters (epsilon and adam)
          seed,
          env_id
          ):
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_fn("pi", ob_space, ac_space)  # Construct network for new policy
    oldpi = policy_fn("oldpi", ob_space, ac_space)  # Network for old policy
    backup_pi = policy_fn("backup_pi", ob_space, ac_space)  # Backup network used while evaluating CMA-ES individuals
    pi_zero = policy_fn("zero_pi", ob_space, ac_space)  # Reference network pi_0 for the CMA-ES update constraints

    atarg = tf.placeholder(dtype = tf.float32, shape = [
        None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype = tf.float32, shape = [None])  # Empirical return

    reward = tf.placeholder(dtype = tf.float32, shape = [None])  # step rewards
    lrmult = tf.placeholder(name = 'lrmult', dtype = tf.float32,
                            shape = [])  # learning rate multiplier, updated with schedule

    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name = "ob")
    next_ob = U.get_placeholder_cached(name = "next_ob")  # next step observation for updating q function
    ac = U.get_placeholder_cached(name = "act")  # action placeholder for computing the q function
    mean_ac = U.get_placeholder_cached(name = "mean_act")  # mean-action placeholder for the target q function

    kloldnew = oldpi.pd.kl(pi.pd)
    klpi_pi_zero = pi.pd.kl(pi_zero.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = tf.clip_by_value(ratio, 1.0 - clip_param,
                             1.0 + clip_param) * atarg  #
    pol_surr = - tf.reduce_mean(
        tf.minimum(surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)
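    # i.e. L^CLIP(theta) = E_t[min(r_t(theta) * A_t, clip(r_t(theta), 1 - eps, 1 + eps) * A_t)]
    # with the probability ratio r_t(theta) = pi_theta(a_t|s_t) / pi_theta_old(a_t|s_t);
    # clipping removes the incentive to push the ratio outside [1 - eps, 1 + eps].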

    y = reward + gamma * pi.mean_qpred
    qf_loss = tf.reduce_mean(tf.square(y - pi.qpred))
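    # (y above is a one-step Bellman target, y = r + gamma * Q'(s', a'), where
    # pi.mean_qpred is assumed to be the target-network Q evaluated at the mean
    # action; see assign_target_q_eq_eval_q below, which syncs the two Q networks.)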
    vf_loss = tf.reduce_mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen  # v function is independently trained
    qf_losses = [qf_loss]
    vf_losses = [vf_loss]
    qv_losses = [qf_loss, vf_loss]
    losses = [pol_surr, pol_entpen, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    # print(var_list)
    if isinstance(pi, CnnPolicy):
        lin_var_list = [v for v in var_list if v.name.split("/")[1].startswith(
            "lin")]
        pol_var_list = [v for v in var_list if v.name.split("/")[1].startswith(
            "logits")]
        vf_var_list = [v for v in var_list if v.name.split("/")[1].startswith(
            "value")]
        # Policy + Value function, the final layer, all trainable variables
        # Remove vf variables
        var_list = lin_var_list + pol_var_list
    else:
        fc2_var_list = [v for v in var_list if v.name.split("/")[2].startswith(
            "fc2")]
        final_var_list = [v for v in var_list if v.name.split("/")[
            2].startswith(
            "final")]
        # var_list = vf_var_list + pol_var_list
        var_list = fc2_var_list + final_var_list
        print(var_list)
    # print(var_list)
    qf_var_list = [v for v in var_list if v.name.split("/")[1].startswith(
        "qf")]
    mean_qf_var_list = [v for v in var_list if v.name.split("/")[1].startswith(
        "meanqf")]
    vf_var_list = [v for v in var_list if v.name.split("/")[1].startswith(
        "vf")]

    # compute the advantage estimate A = Q - V for pi
    get_A_estimation = U.function([ob, next_ob, ac], [pi.qpred - pi.vpred])
    # compute the advantage estimate A = Q - V for pi_zero
    get_A_pi_zero_estimation = U.function([ob, next_ob, ac], [pi_zero.qpred - pi_zero.vpred])

    # compute the mean action for given states under pi
    mean_pi_actions = U.function([ob], [pi.pd.mode()])
    mean_pi_zero_actions = U.function([ob], [pi_zero.pd.mode()])
    # compute the mean kl
    mean_Kl = U.function([ob], [tf.reduce_mean(klpi_pi_zero)])

    qf_lossandgrad = U.function([ob, ac, next_ob, mean_ac, lrmult, reward],
                                qf_losses + [U.flatgrad(qf_loss, qf_var_list)])
    vf_lossandgrad = U.function([ob, ac, atarg, ret, lrmult],
                                vf_losses + [U.flatgrad(vf_loss, vf_var_list)])
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult],
                             losses + [U.flatgrad(total_loss, var_list)])

    qf_adam = MpiAdam(qf_var_list, epsilon = adam_epsilon)

    vf_adam = MpiAdam(vf_var_list, epsilon = adam_epsilon)

    adam = MpiAdam(var_list, epsilon = adam_epsilon)

    assign_target_q_eq_eval_q = U.function([], [], updates = [tf.assign(oldqv, newqv)
                                                      for (oldqv, newqv) in zipsame(
            mean_qf_var_list, qf_var_list)])
    assign_old_eq_new = U.function([], [], updates = [tf.assign(oldv, newv)
                                                      for (oldv, newv) in zipsame(
            oldpi.get_variables(), pi.get_variables())])

    # Assign pi to backup (only backup trainable variables)
    assign_backup_eq_new = U.function([], [], updates = [tf.assign(backup_v, newv)
                                                       for (backup_v, newv) in zipsame(
            backup_pi.get_variables(), pi.get_variables())])

    # Assign backup back to pi
    assign_new_eq_backup = U.function([], [], updates = [tf.assign(newv, backup_v)
                                                       for (newv, backup_v) in zipsame(
            pi.get_variables(), backup_pi.get_variables())])


    # Assign pi to pi0 (for parameter updating constraints)
    assign_pi_zero_eq_new = U.function([], [], updates = [tf.assign(pi_zero_v, newv)
                                                       for (pi_zero_v, newv) in zipsame(
            pi_zero.get_variables(), pi.get_variables())])


    # Compute all losses

    compute_v_losses = U.function([ob, ac, atarg, ret, lrmult], vf_losses)
    compute_q_losses = U.function([ob, ac, next_ob, lrmult, reward], qf_losses)
    compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)

    U.initialize()
    pi_set_flat = U.SetFromFlat(pi.get_trainable_variables())
    pi_get_flat = U.GetFlat(pi.get_trainable_variables())
    backup_pi_get_flat = U.GetFlat(backup_pi.get_trainable_variables())
    pi_zero_get_flat = U.GetFlat(pi_zero.get_trainable_variables())
    qf_adam.sync()
    vf_adam.sync()
    adam.sync()

    global timesteps_so_far, episodes_so_far, iters_so_far, \
        tstart, lenbuffer, rewbuffer, ppo_timesteps_so_far, best_fitness
    episodes_so_far = 0
    timesteps_so_far = 0
    ppo_timesteps_so_far = 0
    # cmaes_timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen = 100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen = 100)  # rolling buffer for episode rewards

    test_rew_buffer = []
    # Prepare for rollouts
    # ----------------------------------------
    # assign pi to eval_pi
    actors = []
    best_fitness = 0

    seg_gen = traj_segment_generator(pi, env, timesteps_per_actorbatch, stochastic = True)
    eval_gen = traj_segment_generator_eval(pi, env, 4096, stochastic = True)

    assert sum([max_iters > 0, max_timesteps > 0, max_episodes > 0,
                max_seconds > 0]) == 1, "Only one time constraint permitted"

    i = 0
    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % iters_so_far)

        # PPO Train V and Q
        seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)
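        # add_vtarg_and_adv fills seg["adv"] and seg["tdlamret"] with GAE(gamma, lam):
        #   delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
        #   seg["adv"][t] = sum_l (gamma * lam)^l * delta_{t+l},  seg["tdlamret"] = adv + vpred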

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, next_ob, ac, atarg, tdlamret, reward = seg["ob"], seg["next_ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"], seg["rew"]
        vpredbefore = seg["vpred"]  # predicted value function before update
        atarg = (atarg - atarg.mean()) / atarg.std()  # standardized advantage function estimate
        d = Dataset(dict(ob = ob, ac = ac, atarg = atarg, vtarg = tdlamret),
                    shuffle = not pi.recurrent)
        optim_batchsize = optim_batchsize or ob.shape[0]

        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
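        # (Episode lengths and returns are allgathered over MPI so the rolling
        # buffers reflect statistics from all workers, not just the local rank.)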

        # Evaluate the pre-update policy (pi_0 baseline)
        eval_seg = eval_gen.__next__()
        test_rew_buffer.append(eval_seg["ep_rets"])

        if hasattr(pi, "ob_rms"): pi.ob_rms.update(ob) # update running mean/std for policy

        # Re-train V function
        logger.log("Evaluating V Func Losses")
        for _ in range(optim_epochs):
            losses = []  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *vf_losses, g = vf_lossandgrad(batch["ob"], batch["ac"],
                                               batch["atarg"], batch["vtarg"],
                                               cur_lrmult)
                vf_adam.update(g, optim_stepsize * cur_lrmult)
                losses.append(vf_losses)
            logger.log(fmt_row(13, np.mean(losses, axis=0)))

        # Randomly select transitions to train Q
        random_idx = []
        len_repo = len(seg["ob"])
        optim_epochs_q = max(int(len_repo / optim_batchsize), optim_epochs)
        for _ in range(optim_epochs_q):
            random_idx.append(np.random.choice(range(len_repo), optim_batchsize))

        # Re-train q function
        logger.log("Evaluating Q Func Losses")
        for _ in range(optim_epochs_q):
            losses = []  # list of tuples, each of which gives the loss for a minibatch
            for idx in random_idx:
                *qf_losses, g = qf_lossandgrad(seg["ob"][idx], seg["ac"][idx], seg["next_ob"][idx], mean_pi_actions(ob)[0][idx],
                                               cur_lrmult, seg["rew"][idx])
                qf_adam.update(g, optim_stepsize * cur_lrmult)
                losses.append(qf_losses)
            logger.log(fmt_row(13, np.mean(losses, axis=0)))

        logger.log("CMAES Policy Optimization")
        # Make the two Q networks equal (target <- trained)
        assign_target_q_eq_eval_q()

        # CMAES
        assign_pi_zero_eq_new()  # memorize pi_0

        weights = pi.get_trainable_variables()
        layer_params = [v for v in weights if v.name.split("/")[1].startswith(
            "pol")]
        # if i + 1 < len(weights):
        #     layer_params = [weights[i], weights[i + 1]]
        # else:
        #     layer_params = [weights[i]]
        #     if len(layer_params) <= 1:
        #         layer_params = [weights[i - 1], weights[i]]
        layer_params_flat = get_layer_flat(layer_params)
        index, init_uniform_layer_weights = uniform_select(layer_params_flat,
                                                           1000)
        opt = cma.CMAOptions()
        opt['tolfun'] = max_fitness
        opt['popsize'] = popsize
        opt['maxiter'] = gensize
        opt['verb_disp'] = 0
        opt['verb_log'] = 0
        opt['seed'] = seed
        opt['AdaptSigma'] = True
        # opt['bounds'] = bounds
        sigma1 = sigma - 0.001 * iters_so_far
        if sigma1 < 0.0001:
            sigma1 = 0.0001
        print("Sigma=", sigma1)
        es = cma.CMAEvolutionStrategy(init_uniform_layer_weights,
                                      sigma1, opt)
        best_solution = init_uniform_layer_weights.astype(
            np.float64)
        best_fitness = np.inf
        costs = None
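        # The loop below follows cma's standard ask/tell protocol. A minimal
        # sketch of that protocol in isolation (plain cma, no policy plumbing;
        # cost() is a placeholder objective):
        #
        #   es = cma.CMAEvolutionStrategy(x0, sigma0, opts)
        #   while not es.stop():
        #       xs = es.ask()                        # sample a population
        #       es.tell(xs, [cost(x) for x in xs])   # report costs (minimized)
        #   best_x = es.result[0]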
        while True:
            if es.countiter >= opt['maxiter']:
                break
            solutions = es.ask()
            assign_backup_eq_new() #backup current policy, after Q and V have been trained
            if KL_Condition:  # flag assumed to be defined at module scope
                for id, solution in enumerate(solutions):
                    new_variable = set_uniform_weights(layer_params_flat, solution, index)
                    set_layer_flat(layer_params, new_variable)
                    i = 0
                    mean_kl_const = mean_Kl(ob)[0]
                    while mean_kl_const > 0.5:
                        i += 1
                        # solutions[id] = es.ask(number = 1, xmean = np.take(pi_zero_get_flat(), index),
                        # sigma_fac = 0.9 ** i)
                        solutions[id] = es.ask(number = 1, sigma_fac = 0.9 ** i)[0]
                        new_variable = set_uniform_weights(layer_params_flat, solutions[id], index)
                        set_layer_flat(layer_params, new_variable)
                        mean_kl_const = mean_Kl(ob)[0]
                        logger.log("Regenerate Solution for " +str(i)+ " times for ID:" + str(id) + " mean_kl:" + str(mean_kl_const))

            if mean_action_Condition:  # flag assumed to be defined at module scope
                for id, solution in enumerate(solutions):
                    new_variable = set_uniform_weights(layer_params_flat, solution, index)
                    set_layer_flat(layer_params, new_variable)
                    i = 0
                    # mean_act_dist = np.sqrt(np.dot(np.array(mean_pi_actions(ob)).flatten() - np.array(mean_pi_zero_actions(ob)).flatten(),
                    #                                np.array(mean_pi_actions(ob)).flatten() - np.array(mean_pi_zero_actions(ob)).flatten()))
                    abs_act_dist = np.mean(np.abs(np.array(mean_pi_actions(ob)).flatten() - np.array(mean_pi_zero_actions(ob)).flatten()))
                    while abs_act_dist > 0.01:
                        i += 1
                        # solutions[id] = es.ask(number = 1, xmean = np.take(pi_zero_get_flat(), index),
                        # sigma_fac = 0.9 ** i)
                        solutions[id] = es.ask(number = 1, sigma_fac = 0.999 ** i)[0]
                        new_variable = set_uniform_weights(layer_params_flat, solutions[id], index)
                        set_layer_flat(layer_params, new_variable)
                        # mean_act_dist = np.sqrt(np.dot(np.array(mean_pi_actions(ob)).flatten() - np.array(mean_pi_zero_actions(ob)).flatten(),
                        #                            np.array(mean_pi_actions(ob)).flatten() - np.array(mean_pi_zero_actions(ob)).flatten()))
                        abs_act_dist = np.mean(np.abs(np.array(mean_pi_actions(ob)).flatten() - np.array(mean_pi_zero_actions(ob)).flatten()))
                        logger.log("Regenerate Solution for " +str(i)+ " times for ID:" + str(id) + " mean_action_dist:" + str(abs_act_dist))
            assign_new_eq_backup() # Restore the backup
            segs = []
            ob_segs = None
            costs = []
            lens = []
            # Evaluation

            a_func = get_A_estimation(ob, ob, mean_pi_actions(ob)[0])
            # a_func = (a_func - np.mean(a_func)) / np.std(a_func)
            a_func_pi_zero = get_A_pi_zero_estimation(ob, ob, mean_pi_zero_actions(ob)[0])
            print("A-pi-zero:", np.mean(a_func_pi_zero))
            print("A-pi-best:", np.mean(a_func))
            print()
            for id, solution in enumerate(solutions):
                new_variable = set_uniform_weights(layer_params_flat, solution, index)
                set_layer_flat(layer_params, new_variable)
                new_a_func = get_A_estimation(
                    ob, ob,
                    np.array(mean_pi_actions(ob)).transpose().reshape((len(ob), 1)))
                # new_a_func = (new_a_func - np.mean(new_a_func)) / np.std(new_a_func)
                print("A-pi" + str(id + 1), ":", np.mean(new_a_func))
                coeff1 = 0.9
                coeff2 = 0.9
                cost = - (np.mean(new_a_func))
                # cost = - (np.mean(new_a_func) - coeff1*
                #           np.sqrt(np.dot(pi_get_flat() - pi_zero_get_flat(),
                #                          pi_get_flat() - pi_zero_get_flat()))
                #           - coeff2 * mean_Kl(ob)[0])
                # new_a_funcs =
                costs.append(cost)
                assign_new_eq_backup() # Restore the backup
            # l2_decay = compute_weight_decay(0.999, solutions).reshape((np.array(costs).shape))
            # costs += l2_decay
            # costs, real_costs = fitness_normalization(costs)
            print(costs)
            # costs, real_costs = fitness_rank(costs)
            # es.tell(solutions=solutions, function_values = costs)
            es.tell_real_seg(solutions = solutions, function_values = costs, real_f = costs, segs = None)
            # if -min(costs) >= np.mean(a_func):
            if min(costs) <= best_fitness:
                print("Update Policy by CMAES")
                # best_solution = np.copy(es.result[0])
                # best_fitness = -es.result[1]
                best_solution = solutions[np.argmin(costs)]
                best_fitness = min(costs)
                best_layer_params_flat = set_uniform_weights(layer_params_flat,
                                                             best_solution,
                                                             index)
                set_layer_flat(layer_params, best_layer_params_flat)
            # assign_pi_zero_eq_new()
            # if mean_Kl(ob)[0] > 0.05: # Check the kl diverge
            #     print("mean_kl:", mean_Kl(ob)[0])
            #     print("Cancel updating")
            #     assign_new_eq_backup()
            # else:
            #     assign_pi_zero_eq_new() #memorize the p0
            print("Generation:", es.countiter)
            print("Best Solution Fitness:", best_fitness)
            # set old parameter values to new parameter values
            # break
        # Evaluate the CMA-ES-updated policy
        eval_seg = eval_gen.__next__()
        test_rew_buffer.append(eval_seg["ep_rets"])
        assign_pi_zero_eq_new()  # update pi_0

        print("Pi 0 Performance:", test_rew_buffer[0])
        print("Pi 1 Performance:", test_rew_buffer[1])

        print("Pi 0 Mean {0} Std {1}".format(np.mean(test_rew_buffer[0]), np.std(test_rew_buffer[0])))
        print("Pi 1 Mean {0} Std {1}".format(np.mean(test_rew_buffer[1]), np.std(test_rew_buffer[1])))
        test_rew_buffer.clear()
        iters_so_far += 1
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
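
# A hypothetical invocation of the CMA-ES variant above (values are illustrative,
# not taken from this file; env and policy_fn come from the surrounding project):
#
#   learn(env, policy_fn,
#         timesteps_per_actorbatch=2048, clip_param=0.2, entcoeff=0.0,
#         optim_epochs=10, optim_stepsize=3e-4, optim_batchsize=64,
#         gamma=0.99, lam=0.95,
#         max_fitness=-1e6, popsize=8, gensize=40, bounds=None, sigma=0.3,
#         eval_iters=1, max_timesteps=1_000_000, seed=0, env_id='Hopper-v2')
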
def learn(env, policy_fn, *,
        timesteps_per_actorbatch, # timesteps per actor per update
        clip_param, entcoeff, # clipping parameter epsilon, entropy coeff
        optim_epochs, optim_stepsize, optim_batchsize,# optimization hypers
        gamma, lam, # advantage estimation
        max_timesteps=0, max_episodes=0, max_iters=0, max_seconds=0,  # time constraint
        callback=None, # you can do anything in the callback, since it takes locals(), globals()
        adam_epsilon=1e-5,
        schedule='constant' # annealing for stepsize parameters (epsilon and adam)
        ):
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_fn("pi", ob_space, ac_space) # Construct network for new policy
    oldpi = policy_fn("oldpi", ob_space, ac_space) # Network for old policy
    atarg = tf.placeholder(dtype=tf.float32, shape=[None]) # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None]) # Empirical return

    lrmult = tf.placeholder(name='lrmult', dtype=tf.float32, shape=[]) # learning rate multiplier, updated with schedule
    clip_param = clip_param * lrmult # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac)) # pnew / pold
    surr1 = ratio * atarg # surrogate from conservative policy iteration
    surr2 = tf.clip_by_value(ratio, 1.0 - clip_param, 1.0 + clip_param) * atarg #
    pol_surr = - tf.reduce_mean(tf.minimum(surr1, surr2)) # PPO's pessimistic surrogate (L^CLIP)
    vf_loss = tf.reduce_mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult], losses + [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)
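    # (MpiAdam keeps local Adam state but averages the flat gradient across MPI
    # workers on every update, so all ranks take identical optimization steps.)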

    assign_old_eq_new = U.function([],[], updates=[tf.assign(oldv, newv)
        for (oldv, newv) in zipsame(oldpi.get_variables(), pi.get_variables())])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)

    U.initialize()
    adam.sync()

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi, env, timesteps_per_actorbatch, stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100) # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100) # rolling buffer for episode rewards

    assert sum([max_iters>0, max_timesteps>0, max_episodes>0, max_seconds>0])==1, "Only one time constraint permitted"

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************"%iters_so_far)

        seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]
        vpredbefore = seg["vpred"] # predicted value function before update
        atarg = (atarg - atarg.mean()) / atarg.std() # standardized advantage function estimate
        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret), shuffle=not pi.recurrent)
        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"): pi.ob_rms.update(ob) # update running mean/std for policy

        assign_old_eq_new() # set old parameter values to new parameter values
        logger.log("Optimizing...")
        logger.log(fmt_row(13, loss_names))
        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            losses = [] # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"], batch["atarg"], batch["vtarg"], cur_lrmult)
                adam.update(g, optim_stepsize * cur_lrmult)
                losses.append(newlosses)
            logger.log(fmt_row(13, np.mean(losses, axis=0)))

        logger.log("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob"], batch["ac"], batch["atarg"], batch["vtarg"], cur_lrmult)
            losses.append(newlosses)
        meanlosses,_,_ = mpi_moments(losses, axis=0)
        logger.log(fmt_row(13, meanlosses))
        for (lossval, name) in zipsame(meanlosses, loss_names):
            logger.record_tabular("loss_"+name, lossval)
        logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))
        lrlocal = (seg["ep_lens"], seg["ep_rets"]) # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal) # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        if MPI.COMM_WORLD.Get_rank()==0:
            logger.dump_tabular()
# Example 3
def learn(
        env,
        policy_func,
        *,
        timesteps_per_batch,  # timesteps per actor per update
        clip_param,
        entcoeff,  # clipping parameter epsilon, entropy coeff
        optim_epochs,
        optim_stepsize,
        optim_batchsize,  # optimization hypers
        gamma,
        lam,  # advantage estimation
        max_timesteps=0,
        max_episodes=0,
        max_iters=0,
        max_seconds=0,  # time constraint
        callback=None,  # you can do anything in the callback, since it takes locals(), globals()
        adam_epsilon=1e-5,
        schedule='constant',  # annealing for stepsize parameters (epsilon and adam)
        sym_loss_weight=0.0,
        return_threshold=None,  # terminate learning if the return reaches return_threshold
        op_after_init=None,
        init_policy_params=None,
        policy_scope=None,
        max_threshold=None,
        positive_rew_enforce=False,
        reward_drop_bound=None,
        min_iters=0,
        ref_policy_params=None,
        rollout_length_thershold=None):

    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    if policy_scope is None:
        pi = policy_func("pi", ob_space,
                         ac_space)  # Construct network for new policy
        oldpi = policy_func("oldpi", ob_space,
                            ac_space)  # Network for old policy
    else:
        pi = policy_func(policy_scope, ob_space,
                         ac_space)  # Construct network for new policy
        oldpi = policy_func("old" + policy_scope, ob_space,
                            ac_space)  # Network for old policy

    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule
    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = U.mean(kloldnew)
    meanent = U.mean(ent)
    pol_entpen = (-entcoeff) * meanent

    sym_loss = sym_loss_weight * U.mean(
        tf.square(pi.mean - pi.mirrored_mean))  # mirror symmetric loss
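    # (The mirror-symmetry term above penalizes asymmetric behavior:
    #   sym_loss = sym_loss_weight * E_s[(mu_theta(s) - mu_theta(mirror(s)))^2]
    # where pi.mirrored_mean is assumed to expose the policy mean under mirrored
    # observations.)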
    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = U.clip(ratio, 1.0 - clip_param, 1.0 + clip_param) * atarg  #
    pol_surr = -U.mean(tf.minimum(
        surr1, surr2)) + sym_loss  # PPO's pessimistic surrogate (L^CLIP)

    vf_loss = U.mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent, sym_loss]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent", "sym_loss"]

    var_list = pi.get_trainable_variables()
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult],
                             losses + [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)

    U.initialize()

    if init_policy_params is not None:
        cur_scope = pi.get_variables()[0].name[0:pi.get_variables()[0].name.
                                               find('/')]
        orig_scope = list(init_policy_params.keys()
                          )[0][0:list(init_policy_params.keys())[0].find('/')]
        for i in range(len(pi.get_variables())):
            assign_op = pi.get_variables()[i].assign(
                init_policy_params[pi.get_variables()[i].name.replace(
                    cur_scope, orig_scope, 1)])
            U.get_session().run(assign_op)
            assign_op = oldpi.get_variables()[i].assign(
                init_policy_params[pi.get_variables()[i].name.replace(
                    cur_scope, orig_scope, 1)])
            U.get_session().run(assign_op)

    if ref_policy_params is not None:
        ref_pi = policy_func("ref_pi", ob_space, ac_space)
        cur_scope = ref_pi.get_variables()[0].name[0:ref_pi.get_variables()[0].
                                                   name.find('/')]
        orig_scope = list(ref_policy_params.keys()
                          )[0][0:list(ref_policy_params.keys())[0].find('/')]
        for i in range(len(ref_pi.get_variables())):
            assign_op = ref_pi.get_variables()[i].assign(
                ref_policy_params[ref_pi.get_variables()[i].name.replace(
                    cur_scope, orig_scope, 1)])
            U.get_session().run(assign_op)
        env.env.env.ref_policy = ref_pi

    adam.sync()

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_batch,
                                     stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards

    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"

    max_thres_satisfied = max_threshold is None
    adjust_ratio = 0.0
    prev_avg_rew = -1000000
    revert_parameters = {}
    variables = pi.get_variables()
    for i in range(len(variables)):
        cur_val = variables[i].eval()
        revert_parameters[variables[i].name] = cur_val
    revert_data = [0, 0, 0]
    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % iters_so_far)

        seg = seg_gen.__next__()

        if reward_drop_bound is not None:
            lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
            listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
            lens, rews = map(flatten_lists, zip(*listoflrpairs))
            lenbuffer.extend(lens)
            rewbuffer.extend(rews)
            revert_iteration = False
            if np.mean(
                    rewbuffer
            ) < prev_avg_rew - 50:  # detect significant drop in performance, revert to previous iteration
                print("Revert Iteration!!!!!")
                revert_iteration = True
            else:
                prev_avg_rew = np.mean(rewbuffer)
            logger.record_tabular("Revert Rew", prev_avg_rew)
            if revert_iteration:  # revert iteration
                for i in range(len(pi.get_variables())):
                    assign_op = pi.get_variables()[i].assign(
                        revert_parameters[pi.get_variables()[i].name])
                    U.get_session().run(assign_op)
                episodes_so_far = revert_data[0]
                timesteps_so_far = revert_data[1]
                iters_so_far = revert_data[2]
                continue
            else:
                variables = pi.get_variables()
                for i in range(len(variables)):
                    cur_val = variables[i].eval()
                    revert_parameters[variables[i].name] = np.copy(cur_val)
                revert_data[0] = episodes_so_far
                revert_data[1] = timesteps_so_far
                revert_data[2] = iters_so_far

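        # Positive-reward enforcement (below): when the mean sampled reward is
        # negative, the penalty component is rescaled as
        #   rew_i = pos_rews_i + adjust_ratio * neg_pens_i
        # so the average step reward stays roughly non-negative; adjust_ratio
        # never decreases across iterations.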
        if positive_rew_enforce:
            rewlocal = (seg["pos_rews"], seg["neg_pens"], seg["rew"]
                        )  # local values
            listofrews = MPI.COMM_WORLD.allgather(rewlocal)  # list of tuples
            pos_rews, neg_pens, rews = map(flatten_lists, zip(*listofrews))
            if np.mean(rews) < 0.0:
                #min_id = np.argmin(rews)
                #adjust_ratio = pos_rews[min_id]/np.abs(neg_pens[min_id])
                adjust_ratio = np.max([
                    adjust_ratio,
                    np.mean(pos_rews) / np.abs(np.mean(neg_pens))
                ])
                for i in range(len(seg["rew"])):
                    if np.abs(seg["rew"][i] - seg["pos_rews"][i] -
                              seg["neg_pens"][i]) > 1e-5:
                        print(seg["rew"][i], seg["pos_rews"][i],
                              seg["neg_pens"][i])
                        print('Reward wrong!')
                        raise RuntimeError('reward decomposition mismatch')
                    seg["rew"][i] = seg["pos_rews"][
                        i] + seg["neg_pens"][i] * adjust_ratio
        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update
        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate
        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                    shuffle=not pi.recurrent)
        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        assign_old_eq_new()  # set old parameter values to new parameter values
        logger.log("Optimizing...")
        logger.log(fmt_row(13, loss_names))
        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"],
                                            batch["atarg"], batch["vtarg"],
                                            cur_lrmult)
                adam.update(g, optim_stepsize * cur_lrmult)
                losses.append(newlosses)
            logger.log(fmt_row(13, np.mean(losses, axis=0)))
        logger.log("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob"], batch["ac"],
                                       batch["atarg"], batch["vtarg"],
                                       cur_lrmult)
            losses.append(newlosses)
        meanlosses, _, _ = mpi_moments(losses, axis=0)
        logger.log(fmt_row(13, meanlosses))
        for (lossval, name) in zipsame(meanlosses, loss_names):
            logger.record_tabular("loss_" + name, lossval)
        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))
        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        if reward_drop_bound is None:
            lenbuffer.extend(lens)
            rewbuffer.extend(rews)
        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        logger.record_tabular("Iter", iters_so_far)
        if positive_rew_enforce:
            if adjust_ratio is not None:
                logger.record_tabular("RewardAdjustRatio", adjust_ratio)
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.dump_tabular()

        if max_threshold is not None:
            print('Current max return: ', np.max(rewbuffer))
            if np.max(rewbuffer) > max_threshold:
                max_thres_satisfied = True
            else:
                max_thres_satisfied = False

        return_threshold_satisfied = True
        if return_threshold is not None:
            if not (np.mean(rewbuffer) > return_threshold
                    and iters_so_far > min_iters):
                return_threshold_satisfied = False
        rollout_length_thershold_satisfied = True
        if rollout_length_thershold is not None:
            rewlocal = (seg["avg_vels"], seg["rew"])  # local values
            listofrews = MPI.COMM_WORLD.allgather(rewlocal)  # list of tuples
            avg_vels, rews = map(flatten_lists, zip(*listofrews))
            if not (np.mean(lenbuffer) > rollout_length_thershold
                    and np.mean(avg_vels) > 0.5 * env.env.env.final_tv):
                rollout_length_thershold_satisfied = False
        if rollout_length_thershold is not None or return_threshold is not None:
            if rollout_length_thershold_satisfied and return_threshold_satisfied:
                break

    return pi, np.mean(rewbuffer)
# Example 4
def learn(
        env,
        policy_fn,
        *,
        timesteps_per_actorbatch,  # timesteps per actor per update
        clip_param,
        entcoeff,  # clipping parameter epsilon, entropy coeff
        optim_epochs,
        optim_stepsize,
        optim_batchsize,  # optimization hypers
        gamma,
        lam,  # advantage estimation
        max_timesteps=0,
        max_episodes=0,
        max_iters=0,
        max_seconds=0,  # time constraint
        callback=None,  # you can do anything in the callback, since it takes locals(), globals()
        adam_epsilon=1e-5,
        schedule='constant',  # annealing for stepsize parameters (epsilon and adam)
        gradients=True,
        hessians=False,
        model_path='model',
        output_prefix,
        sim):

    # Directory setup
    model_dir = 'models/'
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_fn("pi", ob_space,
                   ac_space)  # Construct network for new policy
    oldpi = policy_fn("oldpi", ob_space, ac_space)  # Network for old policy
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule
    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = tf.clip_by_value(ratio, 1.0 - clip_param,
                             1.0 + clip_param) * atarg  #
    pol_surr = -tf.reduce_mean(tf.minimum(
        surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)
    vf_loss = tf.reduce_mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()

    lossandgradandhessian = U.function(
        [ob, ac, atarg, ret, lrmult], losses +
        [U.flatgrad(total_loss, var_list),
         U.flathess(total_loss, var_list)])
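    # (U.flathess is assumed to mirror U.flatgrad but return the flattened Hessian
    # of total_loss; it is only consumed by return_routine for offline analysis.)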
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult],
                             losses + [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)

    U.initialize()
    # Write the graph so it can be inspected in TensorBoard
    tf.summary.FileWriter(
        '/home/aespielberg/ResearchCode/baselines/baselines/tmp/',
        graph_def=tf.get_default_session().graph_def)
    adam.sync()

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_actorbatch,
                                     stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards

    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"

    gradient_indices = get_gradient_indices(pi)

    while True:
        if callback: callback(locals(), globals())

        #ANDYTODO: add new break condition
        '''
        try:
            print(np.std(rewbuffer) / np.mean(rewbuffer))
            print(rewbuffer)
            if np.std(rewbuffer) / np.mean(rewbuffer) < 0.01: #TODO: input argument
                break
        except:
            pass #No big
        '''

        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % iters_so_far)

        seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update
        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate
        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                    shuffle=not pi.recurrent)
        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        assign_old_eq_new()  # set old parameter values to new parameter values
        logger.log("Optimizing...")
        logger.log(fmt_row(13, loss_names))
        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            gradient_set = []
            losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"],
                                            batch["atarg"], batch["vtarg"],
                                            cur_lrmult)
                gradient_set.append(g)
                if not sim:
                    adam.update(g, optim_stepsize * cur_lrmult)
                losses.append(newlosses)
            logger.log(fmt_row(13, np.mean(losses, axis=0)))
        print('objective is')
        print(np.sum(np.mean(losses, axis=0)[0:3]))
        print(get_model_vars(pi))
        if sim:
            print('return routine')
            return_routine(pi, d, batch, output_prefix, losses, cur_lrmult,
                           lossandgradandhessian, gradients, hessians,
                           gradient_set)
            return pi
        if np.mean(list(
                map(np.linalg.norm,
                    gradient_set))) < 1e-4:  #TODO: make this a variable
            #TODO: abstract all this away somehow (scope)
            print('minimized!')
            return_routine(pi, d, batch, output_prefix, losses, cur_lrmult,
                           lossandgradandhessian, gradients, hessians,
                           gradient_set)
            return pi
        print(np.mean(list(map(np.linalg.norm, np.array(gradient_set)))))
        logger.log("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob"], batch["ac"],
                                       batch["atarg"], batch["vtarg"],
                                       cur_lrmult)
            losses.append(newlosses)
        meanlosses, _, _ = mpi_moments(losses, axis=0)
        logger.log(fmt_row(13, meanlosses))
        for (lossval, name) in zipsame(meanlosses, loss_names):
            logger.record_tabular("loss_" + name, lossval)
        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))
        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.dump_tabular()
        if iters_so_far > 1:
            U.save_state(model_dir + model_path + str(iters_so_far))

    print('out of time')
    return_routine(pi, d, batch, output_prefix, losses, cur_lrmult,
                   lossandgradandhessian, gradients, hessians, gradient_set)
    return pi
# Example 5
def learn(
    env,
    policy_fn,
    reward_giver,
    expert_dataset,
    *,
    timesteps_per_actorbatch,  # timesteps per actor per update
    clip_param,
    entcoeff,  # clipping parameter epsilon, entropy coeff
    optim_epochs,
    optim_stepsize,
    optim_batchsize,  # optimization hypers
    gamma,
    lam,  # advantage estimation
    max_timesteps=0,
    max_episodes=0,
    max_iters=0,
    max_seconds=0,  # time constraint
    callback=None,  # you can do anything in the callback, since it takes locals(), globals()
    adam_epsilon=1e-5,
    schedule='constant'  # annealing for stepsize parameters (epsilon and adam)
):
    # Setup losses and stuff
    # ----------------------------------------
    d_stepsize = 3e-4
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_fn("pi", ob_space,
                   ac_space)  # Construct network for new policy
    oldpi = policy_fn("oldpi", ob_space, ac_space)  # Network for old policy
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule
    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = tf.clip_by_value(ratio, 1.0 - clip_param,
                             1.0 + clip_param) * atarg  #
    pol_surr = -tf.reduce_mean(tf.minimum(
        surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)
    vf_loss = tf.reduce_mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult],
                             losses + [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)
    d_adam = MpiAdam(reward_giver.get_trainable_variables())

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)

    U.initialize()
    adam.sync()
    d_adam.sync()

    # Prepare for rollouts
    # ----------------------------------------
    viewer = mujoco_py.MjViewer(env.sim)
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     viewer,
                                     reward_giver,
                                     timesteps_per_actorbatch,
                                     stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards
    true_rewbuffer = deque(maxlen=40)

    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % iters_so_far)

        seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update
        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate
        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                    shuffle=not pi.recurrent)
        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        assign_old_eq_new()  # set old parameter values to new parameter values
        logger.log("Optimizing...")
        logger.log(fmt_row(13, loss_names))
        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"],
                                            batch["atarg"], batch["vtarg"],
                                            cur_lrmult)
                adam.update(g, optim_stepsize * cur_lrmult)
                losses.append(newlosses)
            logger.log(fmt_row(13, np.mean(losses, axis=0)))

        # ------------------ Update D ------------------
        logger.log("Optimizing Discriminator...")
        logger.log(fmt_row(13, reward_giver.loss_name))
        ob_expert, ac_expert = expert_dataset.get_next_batch(len(ob))
        batch_size = len(ob)
        d_losses = [
        ]  # list of tuples, each of which gives the loss for a minibatch
        for ob_batch, ac_batch in dataset.iterbatches(
            (ob, ac), include_final_partial_batch=False,
                batch_size=batch_size):
            ob_expert, ac_expert = expert_dataset.get_next_batch(len(ob_batch))
            # update running mean/std for reward_giver
            if hasattr(reward_giver, "obs_rms"):
                reward_giver.obs_rms.update(
                    np.concatenate((ob_batch, ob_expert), 0))
            *newlosses, g = reward_giver.lossandgrad(ob_batch, ac_batch,
                                                     ob_expert, ac_expert)
            d_adam.update(g, d_stepsize)
            d_losses.append(newlosses)
        logger.log("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob"], batch["ac"],
                                       batch["atarg"], batch["vtarg"],
                                       cur_lrmult)
            losses.append(newlosses)
        meanlosses, _, _ = mpi_moments(losses, axis=0)
        logger.log(fmt_row(13, meanlosses))
        for (lossval, name) in zipsame(meanlosses, loss_names):
            logger.record_tabular("loss_" + name, lossval)
        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))
        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.dump_tabular()

    return pi
Example #6
def learn(env, policy_fn, *,
          timesteps_per_actorbatch,  # timesteps per actor per update
          clip_param, entcoeff,  # clipping parameter epsilon, entropy coeff
          optim_epochs, optim_stepsize, optim_batchsize,  # optimization hypers
          gamma, lam,  # advantage estimation
          max_timesteps=0, max_episodes=0, max_iters=0, max_seconds=0,  # time constraint
          callback=None,  # you can do anything in the callback, since it takes locals(), globals()
          adam_epsilon=1e-5,
          schedule='constant',  # annealing for stepsize parameters (epsilon and adam)
          **kwargs,
          ):
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space

    pi = policy_fn("pi", ob_space, ac_space)  # Construct network for new policy

    oldpi = policy_fn("oldpi", ob_space, ac_space)  # Network for old policy

    atarg = tf.placeholder(dtype=tf.float32, shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    atarg_novel = tf.placeholder(dtype=tf.float32,
                                 shape=[None])  # Target advantage function for the novelty reward term
    ret_novel = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return for the novelty reward term

    lrmult = tf.placeholder(name='lrmult', dtype=tf.float32,
                            shape=[])  # learning rate multiplier, updated with schedule

    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()

    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold

    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = tf.clip_by_value(ratio, 1.0 - clip_param, 1.0 + clip_param) * atarg  #

    surr1_novel = ratio * atarg_novel  # surrogate loss of the novelty term
    surr2_novel = tf.clip_by_value(ratio, 1.0 - clip_param,
                                   1.0 + clip_param) * atarg_novel  # surrogate loss of the novelty term

    pol_surr = - tf.reduce_mean(tf.minimum(surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)
    pol_surr_novel = -tf.reduce_mean(tf.minimum(surr1_novel, surr2_novel))  # PPO's surrogate for the novelty part
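    # Both objectives above are instances of PPO's pessimistic surrogate
    #   L^CLIP = E_t[ min(r_t * A_t, clip(r_t, 1 - eps, 1 + eps) * A_t) ]
    # with r_t = pi(a_t|s_t) / pi_old(a_t|s_t); taking the elementwise minimum
    # means the policy earns no extra credit for pushing the ratio outside
    # [1 - eps, 1 + eps].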

    vf_loss = tf.reduce_mean(tf.square(pi.vpred - ret))
    vf_loss_novel = tf.reduce_mean(tf.square(pi.vpred_novel - ret_novel))

    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]

    total_loss_novel = pol_surr_novel + pol_entpen + vf_loss_novel
    losses_novel = [pol_surr_novel, pol_entpen, vf_loss_novel, meankl, meanent]

    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    policy_var_list = pi.get_trainable_variables(scope='pi/pol')

    policy_var_count = 0
    for vars in policy_var_list:
        count_in_var = 1
        for dim in vars.shape._dims:
            count_in_var *= dim
        policy_var_count += count_in_var
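    # Equivalent one-liner for the count above (a sketch, using the numpy
    # import already present in this module):
    # policy_var_count = sum(int(np.prod(v.shape)) for v in policy_var_list)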

    noise_count = pi.get_trainable_variables(scope='pi/pol/logstd')[0].shape._dims[1]

    var_list = pi.get_trainable_variables(scope='pi/pol') + pi.get_trainable_variables(scope='pi/vf/')
    var_list_novel = pi.get_trainable_variables(scope='pi/pol') + pi.get_trainable_variables(scope='pi/vf_novel/')

    lossandgrad = U.function([ob, ac, atarg, ret, lrmult], losses + [U.flatgrad(total_loss, var_list)])

    lossandgrad_novel = U.function([ob, ac, atarg_novel, ret_novel, lrmult],
                                   losses_novel + [U.flatgrad(total_loss_novel, var_list_novel)])

    adam = MpiAdam(var_list, epsilon=adam_epsilon)
    adam_novel = MpiAdam(var_list_novel, epsilon=adam_epsilon)

    assign_old_eq_new = U.function([], [], updates=[tf.assign(oldv, newv)
                                                    for (oldv, newv) in
                                                    zipsame(oldpi.get_variables(), pi.get_variables())])
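    # Copying pi -> oldpi before each optimization round freezes the behavior
    # policy, so the ratio pi/oldpi measures exactly how far this iteration's
    # gradient steps have moved the policy.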

    compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)
    compute_losses_novel = U.function([ob, ac, atarg_novel, ret_novel, lrmult], losses_novel)

    U.initialize()
    adam.sync()
    adam_novel.sync()

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi, env, timesteps_per_actorbatch, stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0

    novelty_update_iter_cycle = 10
    novelty_start_iter = 50
    novelty_update = True

    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards
    rewnovelbuffer = deque(maxlen=100)  # rolling buffer for episode novelty rewards

    assert sum([max_iters > 0, max_timesteps > 0, max_episodes > 0,
                max_seconds > 0]) == 1, "Only one time constraint permitted"

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % iters_so_far)

        seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, atarg_novel, tdlamret, tdlamret_novel = seg["ob"], seg["ac"], seg["adv"], seg["adv_novel"], seg[
            "tdlamret"], seg["tdlamret_novel"]

        vpredbefore = seg["vpred"]  # predicted value function before update
        vprednovelbefore = seg['vpred_novel']  # predicted novelty value function before update

        atarg = (atarg - atarg.mean()) / atarg.std()  # standardized advantage function estimate
        atarg_novel = (atarg_novel - atarg_novel.mean()) / atarg_novel.std()  # standardized novelty advantage function estimate

        d = Dataset(
            dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret, atarg_novel=atarg_novel, vtarg_novel=tdlamret_novel),
            shuffle=not pi.recurrent)

        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"): pi.ob_rms.update(ob)  # update running mean/std for policy

        assign_old_eq_new()  # set old parameter values to new parameter values
        logger.log("Optimizing...")
        logger.log(fmt_row(13, loss_names))
        task_gradient_mag = [0]

        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            losses = []  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"], batch["atarg"], batch["vtarg"], cur_lrmult)

                adam.update(g, optim_stepsize * cur_lrmult)

                # adam_novel.update(g_novel, optim_stepsize * cur_lrmult)

                losses.append(newlosses)
            logger.log(fmt_row(13, np.mean(losses, axis=0)))

        logger.log("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob"], batch["ac"], batch["atarg"], batch["vtarg"], cur_lrmult)
            # newlosses_novel = compute_losses_novel(batch["ob"], batch["ac"], batch["atarg_novel"], batch["vtarg_novel"],
            #                                        cur_lrmult)
            losses.append(newlosses)
        meanlosses, _, _ = mpi_moments(losses, axis=0)
        logger.log(fmt_row(13, meanlosses))
        for (lossval, name) in zipsame(meanlosses, loss_names):
            logger.record_tabular("loss_" + name, lossval)
        logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))
        lrlocal = (seg["ep_lens"], seg["ep_rets"], seg['ep_rets_novel'])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews, rews_novel = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        rewnovelbuffer.extend(rews_novel)
        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpRNoveltyRewMean", np.mean(rewnovelbuffer))

        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1
        if iters_so_far >= novelty_start_iter and iters_so_far % novelty_update_iter_cycle == 0:
            novelty_update = not novelty_update
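        # Once past novelty_start_iter (50), this flag flips every
        # novelty_update_iter_cycle (10) iterations; presumably it gates
        # alternating task-reward / novelty-reward phases elsewhere (e.g. via
        # the locals()-based callback), since it is not read inside this loop.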

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        logger.record_tabular("TaskGradMag", np.array(task_gradient_mag).mean())
        # logger.record_tabular("NoveltyUpdate", novelty_update)
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.dump_tabular()

    return pi
Example #7
def learn(
    env,
    test_env,
    policy_func,
    *,
    timesteps_per_batch,  # timesteps per actor per update
    clip_param,
    entcoeff,  # clipping parameter epsilon, entropy coeff
    optim_epochs,
    optim_stepsize,
    optim_batchsize,  # optimization hypers
    gamma,
    lam,  # advantage estimation
    max_timesteps=0,
    max_episodes=0,
    max_iters=0,
    max_seconds=0,  # time constraint
    callback=None,  # you can do anything in the callback, since it takes locals(), globals()
    adam_epsilon=1e-5,
    schedule='constant'  # annealing for stepsize parameters (epsilon and adam)
):
    # Setup losses and stuff
    # ----------------------------------------
    rew_mean = []

    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_func("pi", ob_space,
                     ac_space)  # Construct network for new policy
    oldpi = policy_func("oldpi", ob_space, ac_space)  # Network for old policy
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule
    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = U.mean(kloldnew)
    meanent = U.mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = U.clip(ratio, 1.0 - clip_param, 1.0 + clip_param) * atarg  #
    pol_surr = -U.mean(tf.minimum(
        surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)
    vf_loss = U.mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult],
                             losses + [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)

    U.initialize()
    adam.sync()

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_batch,
                                     stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards

    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % iters_so_far)

        seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update
        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate
        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                    shuffle=not pi.recurrent)
        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        assign_old_eq_new()  # set old parameter values to new parameter values

        for _ in range(optim_epochs):
            losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"],
                                            batch["atarg"], batch["vtarg"],
                                            cur_lrmult)
                adam.update(g, optim_stepsize * cur_lrmult)
                losses.append(newlosses)

        curr_rew = evaluate(pi, test_env)
        rew_mean.append(curr_rew)
        print(curr_rew)

        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        episodes_so_far += len(lens)
        if len(lens) != 0:
            rew_mean.append(np.mean(rewbuffer))
        timesteps_so_far += sum(lens)
        iters_so_far += 1

    return rew_mean
def learn(env, genv, i_trial, policy_fn, *,
        timesteps_per_actorbatch, # timesteps per actor per update
        clip_param, entp, # clipping parameter epsilon, entropy coeff
        optim_epochs, optim_stepsize, optim_batchsize,# optimization hypers
        gamma, lam, # advantage estimation
        max_timesteps=0, max_episodes=0, max_iters=0, max_seconds=0,  # time constraint
        callback=None, # you can do anything in the callback, since it takes locals(), globals()
        adam_epsilon=1e-5,
        schedule='constant', # annealing for stepsize parameters (epsilon and adam)
        useentr=False,
        retrace=False
        ):
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_fn("pi", ob_space, ac_space) # Construct network for new policy
    oldpi = policy_fn("oldpi", ob_space, ac_space) # Network for old policy
    atarg = tf.placeholder(dtype=tf.float32, shape=[None]) # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None]) # Empirical return

    gpi = policy_fn("gpi", ob_space, ac_space) # Construct network for new policy
    goldpi = policy_fn("goldpi", ob_space, ac_space) # Network for old policy
    gatarg = tf.placeholder(dtype=tf.float32, shape=[None]) # Target advantage function (if applicable)
    gret = tf.placeholder(dtype=tf.float32, shape=[None]) # Empirical return

    lrmult = tf.placeholder(name='lrmult', dtype=tf.float32, shape=[]) # learning rate multiplier, updated with schedule
    entcoeff = tf.placeholder(dtype=tf.float32, shape=[])
    clip_param = clip_param * lrmult # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])
    gac = gpi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    pol_entpen = (-entcoeff) * meanent

    gkloldnew = goldpi.pd.kl(gpi.pd)
    gent = gpi.pd.entropy()
    gmeankl = tf.reduce_mean(gkloldnew)
    gmeanent = tf.reduce_mean(gent)
    gpol_entpen = (-entcoeff) * gmeanent


    ratio = tf.exp(pi.pd.logp(gac) - goldpi.pd.logp(gac)) # pnew / pold
    compute_ratio = U.function([ob, gac], ratio)
    surr1 = ratio * gatarg # surrogate from conservative policy iteration
    surr2 = tf.clip_by_value(ratio, 1.0 - clip_param, 1.0 + clip_param) * gatarg #
    pol_surr = - tf.reduce_mean(tf.minimum(surr1, surr2)) # PPO's pessimistic surrogate (L^CLIP)
    vf_loss = tf.reduce_mean(tf.square(gpi.vpred - gret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    gratio = tf.exp(gpi.pd.logp(ac) - oldpi.pd.logp(ac))
    compute_gratio = U.function([ob, ac], gratio)
    gsurr1 = gratio * atarg
    gsurr2 = tf.clip_by_value(gratio, 1.0 - clip_param, 1.0 + clip_param) * atarg
    gpol_surr = - tf.reduce_mean(tf.minimum(gsurr1, gsurr2))
    gvf_loss = tf.reduce_mean(tf.square(pi.vpred - ret))
    gtotal_loss = gpol_surr + gpol_entpen + gvf_loss
    glosses = [gpol_surr, gpol_entpen, gvf_loss, gmeankl, gmeanent]
    gloss_names = ["gpol_surr", "gpol_entpen", "gvf_loss", "gkl", "gent"]
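    # Note the crossed wiring: `total_loss` scores pi on the guided policy's
    # batch (gac, gatarg, gret) and is differentiated w.r.t. pi's variables,
    # while `gtotal_loss` scores gpi on pi's batch (ac, atarg, ret); in the
    # update loops below each policy is therefore trained on the other's rollouts.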


    var_list = pi.get_trainable_variables()
    lossandgrad = U.function([ob, gac, gatarg, gret, lrmult, entcoeff], losses + [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    gvar_list = gpi.get_trainable_variables()
    glossandgrad = U.function([ob, ac, atarg, ret, lrmult, entcoeff], glosses + [U.flatgrad(gtotal_loss, gvar_list)])
    gadam = MpiAdam(gvar_list, epsilon=adam_epsilon)

    assign_old_eq_new = U.function([],[], updates=[tf.assign(oldv, newv)
        for (oldv, newv) in zipsame(oldpi.get_variables(), pi.get_variables())])

    gassign_old_eq_new = U.function([], [], updates=[tf.assign(oldv, newv)
        for (oldv, newv) in zipsame(goldpi.get_variables(), gpi.get_variables())])


    compute_losses = U.function([ob, gac, gatarg, gret, lrmult, entcoeff], losses)
    gcompute_losses = U.function([ob, ac, atarg, ret, lrmult, entcoeff], glosses)


    U.initialize()
    adam.sync()
    gadam.sync()

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi, gpi, env, timesteps_per_actorbatch, stochastic=True, gamma=gamma)
    gseg_gen = traj_segment_generator(gpi, pi, genv, timesteps_per_actorbatch, stochastic=True, gamma=gamma)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()

    # lenbuffer = deque(maxlen=100) # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100) # rolling buffer for episode rewards
    grewbuffer = deque(maxlen=100)
    drwdsbuffer = deque(maxlen=100)
    gdrwdsbuffer = deque(maxlen=100)


    assert sum([max_iters>0, max_timesteps>0, max_episodes>0, max_seconds>0])==1, "Only one time constraint permitted"

    def standardize(value):
        return (value - value.mean()) / value.std()

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult =  max(1.0 - float(iters_so_far) / float(max_iters), 0)
        else:
            raise NotImplementedError

        if useentr:
            entcoeff = max(entp - float(iters_so_far) / float(max_iters), 0.01)
            # entcoeff = max(entp - float(iters_so_far) / float(max_iters), 0)
        else:
            entcoeff = 0.0


        print("********** Iteration %i ************"%iters_so_far)

        print("********** Guided Policy ************")

        gseg = gseg_gen.__next__()
        add_vtarg_and_adv(gseg, gamma, lam)

        seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)




        gob, gac, gatarg, gtdlamret, gvpredbefore = gseg["ob"], gseg["ac"], \
            gseg["adv"], gseg["tdlamret"], gseg["vpred"]

        ob, ac, atarg, tdlamret, vpredbefore = seg["ob"], seg["ac"], \
            seg["adv"], seg["tdlamret"], seg["vpred"]

        # use retrace clip advantage into new range
        if retrace:
            gprob_r = compute_gratio(gob, gac)
            gatarg = gatarg * np.minimum(1., gprob_r)
            prob_r = compute_ratio(ob, ac)
            atarg = atarg * np.minimum(1., prob_r)
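        # Truncating the importance weight at 1 (rho = min(1, pi/mu)) is the
        # Retrace-style correction: it bounds the variance of off-policy
        # advantage estimates when behavior and target policies diverge.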

        # assign the results: the original code discarded the return values,
        # so the advantages were never actually standardized
        gatarg = standardize(gatarg)
        atarg = standardize(atarg)

        gd = Dataset(dict(gob=gob, gac=gac, gatarg=gatarg, gvtarg=gtdlamret),
                     shuffle=not gpi.recurrent)

        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                    shuffle=not pi.recurrent)

        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(gpi, "ob_rms"): gpi.ob_rms.update(ob)
        if hasattr(pi, "ob_rms"): pi.ob_rms.update(gob) # update running mean/std for policy

        gassign_old_eq_new()
        print("Optimizing...Guided Policy")
        # print(fmt_row(13, gloss_names))

        # Here we do a bunch of optimization epochs over the data

        for _ in range(optim_epochs):
            glosses = []  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = glossandgrad(batch["ob"], batch["ac"], batch["atarg"], batch["vtarg"], cur_lrmult, entcoeff)
                gadam.update(g, optim_stepsize * cur_lrmult)
                glosses.append(newlosses)
            # print(fmt_row(13, np.mean(glosses, axis=0)))

        # print("Evaluating losses...")
        glosses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = gcompute_losses(batch["ob"], batch["ac"], batch["atarg"], batch["vtarg"], cur_lrmult, entcoeff)
            glosses.append(newlosses)
        gmeanlosses, _, _ = mpi_moments(glosses, axis=0)
        # print(fmt_row(13, gmeanlosses))

        for (lossval, name) in zipsame(gmeanlosses, gloss_names):
            logger.record_tabular("gloss_" + name, lossval)
        # logger.record_tabular("gev_tdlam_before", explained_variance(vpredbefore, tdlamret))


        assign_old_eq_new() # set old parameter values to new parameter values
        # print("Optimizing...")
        # print(fmt_row(13, loss_names))
        # Here we do a bunch of optimization epochs over the data

        optim_batchsize = optim_batchsize or ob.shape[0]


        for _ in range(optim_epochs):
            losses = [] # list of tuples, each of which gives the loss for a minibatch
            for batch in gd.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["gob"], batch["gac"], batch["gatarg"], batch["gvtarg"], cur_lrmult, entcoeff)
                adam.update(g, optim_stepsize * cur_lrmult)
                losses.append(newlosses)
            # print(fmt_row(13, np.mean(losses, axis=0)))

        # print("Evaluating losses...")
        losses = []
        for batch in gd.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["gob"], batch["gac"], batch["gatarg"], batch["gvtarg"], cur_lrmult, entcoeff)
            losses.append(newlosses)
        meanlosses,_,_ = mpi_moments(losses, axis=0)
        # print(fmt_row(13, meanlosses))

        for (lossval, name) in zipsame(meanlosses, loss_names):
            logger.record_tabular("loss_"+name, lossval)
        # logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))



        lrlocal = (seg["ep_lens"], seg["ep_rets"], seg["ep_drwds"]) # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal) # list of tuples
        lens, rews, drwds = map(flatten_lists, zip(*listoflrpairs))

        glrlocal = (gseg["ep_lens"], gseg["ep_rets"], gseg["ep_drwds"]) # local values
        glistoflrpairs = MPI.COMM_WORLD.allgather(glrlocal) # list of tuples
        glens, grews, gdrwds = map(flatten_lists, zip(*glistoflrpairs))

        # lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        grewbuffer.extend(grews)
        drwdsbuffer.extend(drwds)
        gdrwdsbuffer.extend(gdrwds)

        # logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpDRewMean", np.mean(drwdsbuffer))
        logger.record_tabular("GEpRewMean", np.mean(grewbuffer))
        logger.record_tabular("GEpDRewMean", np.mean(gdrwdsbuffer))

        # logger.record_tabular("EpThisIter", len(lens))



        # episodes_so_far += len(lens)
        # timesteps_so_far += sum(lens)
        iters_so_far += 1
        # logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)

        logger.logkv('trial', i_trial)
        logger.logkv("Iteration", iters_so_far)
        logger.logkv("Name", 'PPOguided')

        if MPI.COMM_WORLD.Get_rank()==0:
            logger.dump_tabular()
Example #9
def learn(env, policy_func, *,
        timesteps_per_batch, # timesteps per actor per update
        clip_param, entcoeff, # clipping parameter epsilon, entropy coeff
        optim_epochs, optim_stepsize, optim_batchsize,# optimization hypers
        gamma, lam, # advantage estimation
        max_timesteps=0, max_episodes=0, max_iters=0, max_seconds=0,  # time constraint
        callback=None, # you can do anything in the callback, since it takes locals(), globals()
        adam_epsilon=1e-5,
        schedule='constant', # annealing for stepsize parameters (epsilon and adam)
        num_options=1,
        app='',
        saves=False,
        wsaves=False,
        epoch=-1,
        seed=1,
        dc=0
        ):


    optim_batchsize_ideal = optim_batchsize 
    np.random.seed(seed)
    tf.set_random_seed(seed)
    env.seed(seed)

    ### Book-keeping
    gamename = env.spec.id[:-3].lower()
    gamename += 'seed' + str(seed)
    gamename += app
    # version_name defines the name of this training run
    version_name = 'NORM-ACT-LOWER-LR-len-400-wNoise-update1-ppo-ESCH-1-own-impl-both-equal' 

    dirname = '{}_{}_{}opts_saves/'.format(version_name,gamename,num_options)
    print (dirname)

    # retrieve everything using relative paths. Create a train_results folder where the repo has been cloned
    dirname_rel = os.path.dirname(__file__)
    splitted = dirname_rel.split("/")
    dirname_rel = "/".join(splitted[:-3]) + "/"
    dirname = dirname_rel + "train_results/" + dirname

    # if saving -> create the necessary directories
    if wsaves:
        first=True
        if not os.path.exists(dirname):
            os.makedirs(dirname)
            first = False

        # copy also the original files into the folder where the training results are stored

        files = ['pposgd_simple.py','mlp_policy.py','run_mujoco.py']
        first = True
        for i in range(len(files)):
            src = os.path.join(dirname_rel,'baselines/baselines/ppo1/') + files[i]
            print (src)
            #dest = os.path.join('/home/nfunk/results_NEW/ppo1/') + dirname
            dest = dirname + "src_code/"
            if (first):
                os.makedirs(dest)
                first = False
            print (dest)
            shutil.copy2(src,dest)
        # brute force copy normal env file at end of copying process:
        src = os.path.join(dirname_rel,'nfunk/envs_nf/pendulum_nf.py')
        shutil.copy2(src,dest)
        os.makedirs(dest+"assets/")
        src = os.path.join(dirname_rel,'nfunk/envs_nf/assets/clockwise.png')
        shutil.copy2(src,dest+"assets/")
    ###


    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    max_action = env.action_space.high

    # extend the observation space to include the action dimensions
    ob_space.shape =((ob_space.shape[0] + ac_space.shape[0]),)
    print (ob_space.shape)
    print (ac_space.shape)

    pi = policy_func("pi", ob_space, ac_space) # Construct network for new policy
    oldpi = policy_func("oldpi", ob_space, ac_space) # Network for old policy
    atarg = tf.placeholder(dtype=tf.float32, shape=[None]) # Target advantage function 
    ret = tf.placeholder(dtype=tf.float32, shape=[None]) # Empirical return
    pol_ov_op_ent = tf.placeholder(dtype=tf.float32, shape=None) # Entropy coefficient for policy over options


    lrmult = tf.placeholder(name='lrmult', dtype=tf.float32, shape=[]) # learning rate multiplier, updated with schedule
    clip_param = clip_param * lrmult # Annealed clipping parameter epsilon for PPO


    # setup observation, option and terminal advantage
    ob = U.get_placeholder_cached(name="ob")
    option = U.get_placeholder_cached(name="option")
    term_adv = U.get_placeholder(name='term_adv', dtype=tf.float32, shape=[None])

    # create variable for action
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = U.mean(kloldnew)
    meanent = U.mean(ent)
    pol_entpen = (-entcoeff) * meanent

    # probability of choosing the action under the new policy vs the old policy (PPO)
    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac)) 
    # advantage of choosing the action
    atarg_clip = atarg
    # surrogate 1:
    surr1 = ratio * atarg_clip #atarg # surrogate from conservative policy iteration
    # surrogate 2:
    surr2 = U.clip(ratio, 1.0 - clip_param, 1.0 + clip_param) * atarg_clip 
    # PPO's pessimistic surrogate (L^CLIP)
    pol_surr = - U.mean(tf.minimum(surr1, surr2)) 

    # Loss on the Q-function
    vf_loss = U.mean(tf.square(pi.vpred - ret))
    # calculate the total loss
    total_loss = pol_surr + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    # calculate logarithm of probability of policy over options
    log_pi = tf.log(tf.clip_by_value(pi.op_pi, 1e-5, 1.0))
    # calculate logarithm of probability of policy over options under the old parameters
    old_log_pi = tf.log(tf.clip_by_value(oldpi.op_pi, 1e-5, 1.0))
    # calculate entropy of policy over options
    entropy = -tf.reduce_sum(pi.op_pi * log_pi, reduction_indices=1)

    # calculate the ppo update for the policy over options:
    ratio_pol_ov_op = tf.exp(tf.transpose(log_pi)[option[0]] - tf.transpose(old_log_pi)[option[0]]) # pnew / pold
    term_adv_clip = term_adv 
    surr1_pol_ov_op = ratio_pol_ov_op * term_adv_clip # surrogate from conservative policy iteration
    surr2_pol_ov_op = U.clip(ratio_pol_ov_op, 1.0 - clip_param, 1.0 + clip_param) * term_adv_clip #
    pol_surr_pol_ov_op = - U.mean(tf.minimum(surr1_pol_ov_op, surr2_pol_ov_op)) # PPO's pessimistic surrogate (L^CLIP)
    
    op_loss = pol_surr_pol_ov_op - pol_ov_op_ent*tf.reduce_sum(entropy)
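    # op_loss is the clipped surrogate for the policy over options minus an
    # entropy bonus weighted by pol_ov_op_ent, i.e. the same PPO update
    # applied one level up in the option hierarchy.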

    # add loss of policy over options to total loss
    total_loss += op_loss
    
    var_list = pi.get_trainable_variables()
    term_list = var_list[6:8]

    # define function that we will later do gradient descent on
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult,option, term_adv,pol_ov_op_ent], losses + [U.flatgrad(total_loss, var_list)])
    
    # define adam optimizer
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    # define function that will assign the current parameters to the old policy
    assign_old_eq_new = U.function([],[], updates=[tf.assign(oldv, newv)
        for (oldv, newv) in zipsame(oldpi.get_variables(), pi.get_variables())])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult, option], losses)


    U.initialize()
    adam.sync()


    # NOW: all the stuff for training was defined, from here on we start with the execution:

    # initialize "savers" which will store the results
    saver = tf.train.Saver(max_to_keep=10000)
    saver_best = tf.train.Saver(max_to_keep=1)


    ### Define the names of the .csv files that are going to be stored
    results=[]
    if saves:
        results = open(dirname + version_name + '_' + gamename +'_'+str(num_options)+'opts_'+'_results.csv','w')
        results_best_model = open(dirname + version_name + '_' + gamename +'_'+str(num_options)+'opts_'+'_bestmodel.csv','w')


        out = 'epoch,avg_reward'

        for opt in range(num_options): out += ',option {} dur'.format(opt)
        for opt in range(num_options): out += ',option {} std'.format(opt)
        for opt in range(num_options): out += ',option {} term'.format(opt)
        for opt in range(num_options): out += ',option {} adv'.format(opt)
        out+='\n'
        results.write(out)
        # results.write('epoch,avg_reward,option 1 dur, option 2 dur, option 1 term, option 2 term\n')
        results.flush()

    # special case: if training is run with an epoch argument, load the corresponding model
    if epoch >= 0:
        
        dirname = '{}_{}opts_saves/'.format(gamename,num_options)
        print("Loading weights from iteration: " + str(epoch))

        filename = dirname + '{}_epoch_{}.ckpt'.format(gamename,epoch)
        saver.restore(U.get_session(),filename)
    ###    


    # start training
    episodes_so_far = 0
    timesteps_so_far = 0
    global iters_so_far
    iters_so_far = 0
    des_pol_op_ent = 0.1    # define policy over options entropy scheduling
    max_val = -100000       # define max_val, this will be updated to always store the best model
    tstart = time.time()
    lenbuffer = deque(maxlen=100) # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100) # rolling buffer for episode rewards

    assert sum([max_iters>0, max_timesteps>0, max_episodes>0, max_seconds>0])==1, "Only one time constraint permitted"

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi, env, timesteps_per_batch, stochastic=True, num_options=num_options,saves=saves,results=results,rewbuffer=rewbuffer,dc=dc)

    datas = [0 for _ in range(num_options)]

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult =  max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************"%iters_so_far)

        # Sample (s,a)-Transitions
        seg = seg_gen.__next__()
        # Calculate A(s,a,o) using GAE
        add_vtarg_and_adv(seg, gamma, lam)


        # calculate information for logging
        opt_d = []
        for i in range(num_options):
            dur = np.mean(seg['opt_dur'][i]) if len(seg['opt_dur'][i]) > 0 else 0.
            opt_d.append(dur)

        std = []
        for i in range(num_options):
            logstd = np.mean(seg['logstds'][i]) if len(seg['logstds'][i]) > 0 else 0.
            std.append(np.exp(logstd))
        print("mean opt dur:", opt_d)             
        print("mean op pol:", np.mean(np.array(seg['optpol_p']),axis=0))         
        print("mean term p:", np.mean(np.array(seg['term_p']),axis=0))
        print("mean value val:", np.mean(np.array(seg['value_val']),axis=0))
       

        ob, ac, opts, atarg, tdlamret = seg["ob"], seg["ac"], seg["opts"], seg["adv"], seg["tdlamret"]
        vpredbefore = seg["vpred"] # predicted value function before update
        atarg = (atarg - atarg.mean()) / atarg.std() # standardized advantage function estimate

        if hasattr(pi, "ob_rms"): pi.ob_rms.update(ob) # update running mean/std for policy
        if hasattr(pi, "ob_rms_only"): pi.ob_rms_only.update(ob[:,:-ac_space.shape[0]]) # update running mean/std for policy
        assign_old_eq_new() # set old parameter values to new parameter values

        # every 1000 iterations, decay the entropy scheduling coefficient by a factor of 10
        if (iters_so_far+1)%1000 == 0:
            des_pol_op_ent = des_pol_op_ent/10

        # every 50 iterations, save a periodic checkpoint (the best model is saved separately below)
        if iters_so_far % 50 == 0 and wsaves:
            print("weights are saved...")
            filename = dirname + '{}_epoch_{}.ckpt'.format(gamename,iters_so_far)
            save_path = saver.save(U.get_session(),filename)

        # adaptively save best model -> if current reward is highest, save the model
        if (np.mean(rewbuffer)>max_val) and wsaves:
            max_val = np.mean(rewbuffer)
            results_best_model.write('epoch: '+str(iters_so_far) + 'rew: ' + str(np.mean(rewbuffer)) + '\n')
            results_best_model.flush()
            filename = dirname + 'best.ckpt'
            save_path = saver_best.save(U.get_session(),filename)



        # minimum batch size:
        min_batch=160 
        t_advs = [[] for _ in range(num_options)]
        
        # select all the samples concerning one of the options
        # Note: the update currently processes all samples from option 0 first, then all samples from option 1, and so on
        for opt in range(num_options):
            indices = np.where(opts==opt)[0]
            print("batch size:",indices.size)
            opt_d[opt] = indices.size
            if not indices.size:
                t_advs[opt].append(0.)
                continue


            ### This part is only necessary when we use options. These checks ensure that no collected trajectories are discarded.
            if datas[opt] != 0:
                if (indices.size < min_batch and datas[opt].n > min_batch):
                    datas[opt] = Dataset(dict(ob=ob[indices], ac=ac[indices], atarg=atarg[indices], vtarg=tdlamret[indices]), shuffle=not pi.recurrent)
                    t_advs[opt].append(0.)
                    continue

                elif indices.size + datas[opt].n < min_batch:
                    # pdb.set_trace()
                    oldmap = datas[opt].data_map

                    cat_ob = np.concatenate((oldmap['ob'],ob[indices]))
                    cat_ac = np.concatenate((oldmap['ac'],ac[indices]))
                    cat_atarg = np.concatenate((oldmap['atarg'],atarg[indices]))
                    cat_vtarg = np.concatenate((oldmap['vtarg'],tdlamret[indices]))
                    datas[opt] = Dataset(dict(ob=cat_ob, ac=cat_ac, atarg=cat_atarg, vtarg=cat_vtarg), shuffle=not pi.recurrent)
                    t_advs[opt].append(0.)
                    continue

                elif (indices.size + datas[opt].n > min_batch and datas[opt].n < min_batch) or (indices.size > min_batch and datas[opt].n < min_batch):

                    oldmap = datas[opt].data_map
                    cat_ob = np.concatenate((oldmap['ob'],ob[indices]))
                    cat_ac = np.concatenate((oldmap['ac'],ac[indices]))
                    cat_atarg = np.concatenate((oldmap['atarg'],atarg[indices]))
                    cat_vtarg = np.concatenate((oldmap['vtarg'],tdlamret[indices]))
                    datas[opt] = d = Dataset(dict(ob=cat_ob, ac=cat_ac, atarg=cat_atarg, vtarg=cat_vtarg), shuffle=not pi.recurrent)

                if (indices.size > min_batch and datas[opt].n > min_batch):
                    datas[opt] = d = Dataset(dict(ob=ob[indices], ac=ac[indices], atarg=atarg[indices], vtarg=tdlamret[indices]), shuffle=not pi.recurrent)

            elif datas[opt] == 0:
                datas[opt] = d = Dataset(dict(ob=ob[indices], ac=ac[indices], atarg=atarg[indices], vtarg=tdlamret[indices]), shuffle=not pi.recurrent)
            ###
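            # Roughly: per-option batches smaller than min_batch are pooled in
            # datas[opt] across iterations; only once an option's pool exceeds
            # min_batch does it feed the gradient updates below.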


            # define the batchsize of the optimizer:
            optim_batchsize = optim_batchsize or ob.shape[0]
            print("optim epochs:", optim_epochs)
            logger.log("Optimizing...")


            # Here we do a bunch of optimization epochs over the data
            for _ in range(optim_epochs):
                losses = [] # list of tuples, each of which gives the loss for a minibatch
                for batch in d.iterate_once(optim_batchsize):

                    # Calculate advantage for using specific option here
                    tadv,nodc_adv = pi.get_opt_adv(batch["ob"],[opt])
                    tadv = tadv if num_options > 1 else np.zeros_like(tadv)
                    t_advs[opt].append(nodc_adv)

                    # calculate the gradient
                    *newlosses, grads = lossandgrad(batch["ob"], batch["ac"], batch["atarg"], batch["vtarg"], cur_lrmult, [opt], tadv,des_pol_op_ent)

                    # perform gradient update
                    adam.update(grads, optim_stepsize * cur_lrmult) 
                    losses.append(newlosses)


        # do logging:
        lrlocal = (seg["ep_lens"], seg["ep_rets"]) # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal) # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        if MPI.COMM_WORLD.Get_rank()==0:
            logger.dump_tabular()

        ### Book keeping
        if saves:
            out = "{},{}"
            for _ in range(num_options): out+=",{},{},{},{}"
            out+="\n"
            

            info = [iters_so_far, np.mean(rewbuffer)]
            for i in range(num_options): info.append(opt_d[i])
            for i in range(num_options): info.append(std[i])
            for i in range(num_options): info.append(np.mean(np.array(seg['term_p']),axis=0)[i])
            for i in range(num_options): 
                info.append(np.mean(t_advs[i]))

            results.write(out.format(*info))
            results.flush()
Example #10
def learn(env, policy_fn, *,
        timesteps_per_actorbatch, # timesteps per actor per update
        clip_param, entcoeff, # clipping parameter epsilon, entropy coeff
        optim_epochs, optim_stepsize, optim_batchsize,# optimization hypers
        gamma, lam, # advantage estimation
        max_timesteps=0, max_episodes=0, max_iters=0, max_seconds=0,  # time constraint
        callback=None, # you can do anything in the callback, since it takes locals(), globals()
        adam_epsilon=1e-5,
        schedule='constant', # annealing for stepsize parameters (epsilon and adam)
        continue_from=None):
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_fn("pi", ob_space, ac_space) # Construct network for new policy
    oldpi = policy_fn("oldpi", ob_space, ac_space) # Network for old policy
    atarg = tf.placeholder(dtype=tf.float32, shape=[None]) # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None]) # Empirical return

    lrmult = tf.placeholder(name='lrmult', dtype=tf.float32, shape=[]) # learning rate multiplier, updated with schedule

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac)) # pnew / pold
    surr1 = ratio * atarg # surrogate from conservative policy iteration
    surr2 = tf.clip_by_value(ratio, 1.0 - clip_param, 1.0 + clip_param) * atarg #
    pol_surr = - tf.reduce_mean(tf.minimum(surr1, surr2)) # PPO's pessimistic surrogate (L^CLIP)

    # count clipped samples
    clip_frac = tf.reduce_mean(tf.cast(tf.greater(tf.abs(1.0 - ratio), clip_param), tf.float32))
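    # clip_frac is the fraction of samples whose ratio left the clipping
    # interval (|ratio - 1| > clip_param), a standard PPO diagnostic: values
    # near 0 suggest timid updates, large values an overly aggressive policy step.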

    # vf_loss = tf.reduce_mean(tf.square(pi.vpred - ret))
    vf_loss = tf.reduce_mean(tf.square(pi.vpred - (1 - gamma) * ret))
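    # The value target is scaled by (1 - gamma), apparently to keep vpred on a
    # per-step reward scale rather than a discounted-sum scale (the unscaled
    # variant is kept commented out above).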
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult], losses + [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    assign_old_eq_new = U.function([],[], updates=[tf.assign(oldv, newv)
        for (oldv, newv) in zipsame(oldpi.get_variables(), pi.get_variables())])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult], [losses, clip_frac])

    U.initialize()
    adam.sync()
    saver = tf.train.Saver(max_to_keep=30)

    # continue training from saved models
    if continue_from:
        latest_model = tf.train.latest_checkpoint("../tf_models/" + continue_from)
        saver.restore(tf.get_default_session(), latest_model)
        logger.log("Loaded model from {}".format(continue_from))

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi, gamma, env, timesteps_per_actorbatch, stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100) # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100) # rolling buffer for episode rewards
    rewbuffer_comf = deque(maxlen=100)
    rewbuffer_effi = deque(maxlen=100)
    rewbuffer_time = deque(maxlen=100)
    rewbuffer_speed = deque(maxlen=100)
    rewbuffer_safety = deque(maxlen=100)
    num_danger_buffer = deque(maxlen=100)
    num_crash_buffer = deque(maxlen=100)
    is_success_buffer = deque(maxlen=100)
    is_collision_buffer = deque(maxlen=100)

    assert sum([max_iters>0, max_timesteps>0, max_episodes>0, max_seconds>0])==1, "Only one time constraint permitted"

    while max_iters != 1:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult =  max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************"%iters_so_far)

        seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]
        vpredbefore = seg["vpred"] # predicted value function before update
        atarg = (atarg - atarg.mean()) / atarg.std() # standardized advantage function estimate
        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret), deterministic=pi.recurrent)
        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"): pi.ob_rms.update(ob) # update running mean/std for policy

        assign_old_eq_new() # set old parameter values to new parameter values
        logger.log("Optimizing...")
        logger.log(fmt_row(13, loss_names))
        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            losses = [] # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"], batch["atarg"], batch["vtarg"], cur_lrmult)
                adam.update(g, optim_stepsize * cur_lrmult)
                losses.append(newlosses)
            logger.log(fmt_row(13, np.mean(losses, axis=0)))

        logger.log("Evaluating losses...")
        losses = []
        clip_fracs = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses, newclipfrac = compute_losses(batch["ob"], batch["ac"], batch["atarg"], batch["vtarg"], cur_lrmult)
            losses.append(newlosses)
            clip_fracs.append(newclipfrac)
        meanlosses, _, _ = mpi_moments(losses, axis=0)
        meanclipfracs, _, _ = mpi_moments(clip_fracs, axis=0)

        logger.log(fmt_row(13, meanlosses))
        for (lossval, name) in zipsame(meanlosses, loss_names):
            logger.record_tabular("loss/"+name, lossval)
        logger.record_tabular("misc/ev_tdlam_before", explained_variance(vpredbefore, tdlamret))
        lrlocal = (seg["ep_lens"], seg["ep_rets"], seg["ep_num_danger"], seg["ep_num_crash"], seg["ep_is_success"],
                   seg["ep_is_collision"],
                   seg["ep_rets_detail"][:, 0],
                   seg["ep_rets_detail"][:, 1],
                   seg["ep_rets_detail"][:, 2],
                   seg["ep_rets_detail"][:, 3],
                   seg["ep_rets_detail"][:, 4]) # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal) # list of tuples
        lens, rews, num_danger, num_crash, ep_is_success, ep_is_collision, \
        rews_comf, rews_effi, rews_time, rews_speed, rews_safety = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        num_danger_buffer.extend(num_danger)
        num_crash_buffer.extend(num_crash)
        is_success_buffer.extend(ep_is_success)
        is_collision_buffer.extend(ep_is_collision)

        rewbuffer.extend(rews)
        rewbuffer_comf.extend(rews_comf)
        rewbuffer_effi.extend(rews_effi)
        rewbuffer_time.extend(rews_time)
        rewbuffer_speed.extend(rews_speed)
        rewbuffer_safety.extend(rews_safety)

        logger.record_tabular("evaluations/meanclipfracs", meanclipfracs)
        logger.record_tabular("evaluations/EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("evaluations/numDangerMean", np.mean(num_danger_buffer))
        logger.record_tabular("evaluations/numCrashMean", np.mean(num_crash_buffer))
        logger.record_tabular("evaluations/successMean", np.mean(is_success_buffer))
        logger.record_tabular("evaluations/collisionMean", np.mean(is_collision_buffer))

        logger.record_tabular("rewards/EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("rewards/EpRewMean_comf", np.mean(rewbuffer_comf))
        logger.record_tabular("rewards/EpRewMean_effi", np.mean(rewbuffer_effi))
        logger.record_tabular("rewards/EpRewMean_time", np.mean(rewbuffer_time))
        logger.record_tabular("rewards/EpRewMean_speed", np.mean(rewbuffer_speed))
        logger.record_tabular("rewards/EpRewMean_safety", np.mean(rewbuffer_safety))

        logger.record_tabular("misc/EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1
        logger.record_tabular("misc/EpisodesSoFar", episodes_so_far)
        logger.record_tabular("misc/TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("misc/TimeElapsed", time.time() - tstart)
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.dump_tabular()

    return pi
Example #11
def learn(
        env,
        policy_func,
        *,
        timesteps_per_batch,  # timesteps per actor per update
        clip_param,
        entcoeff,  # clipping parameter epsilon, entropy coeff
        optim_epochs,
        optim_stepsize,
        optim_batchsize,  # optimization hypers
        gamma,
        lam,  # advantage estimation
        max_timesteps=0,
        max_episodes=0,
        max_iters=0,
        max_seconds=0,  # time constraint
        callback=None,  # you can do anything in the callback, since it takes locals(), globals()
        adam_epsilon=1e-5,
        schedule='constant',  # annealing for stepsize parameters (epsilon and adam)
        init_policy_params=None,
        policy_scope='pi'):
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space

    pi = policy_func(policy_scope, ob_space,
                     ac_space)  # Construct network for new policy
    oldpi = policy_func("old" + policy_scope, ob_space,
                        ac_space)  # Network for old policy

    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return
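    # Clip range is fed through a placeholder so the caller can vary/anneal it per update.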
    clip_tf = tf.placeholder(dtype=tf.float32)

    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)

    meanent = tf.reduce_mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = tf.clip_by_value(ratio, 1.0 - clip_tf,
                             1.0 + clip_tf) * atarg  # symmetric PPO clip; range fed via clip_tf
    pol_surr = -tf.reduce_mean(tf.minimum(
        surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)

    if hasattr(pi, 'additional_loss'):
        pol_surr += pi.additional_loss

    vf_loss = tf.reduce_mean(tf.square(pi.vpred - ret))

    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    pol_var_list = [
        v for v in pi.get_trainable_variables()
        if 'placehold' not in v.name and 'offset' not in v.name and 'secondary'
        not in v.name and 'vf' not in v.name and 'pol' in v.name
    ]
    pol_var_size = np.sum([int(np.prod(v.shape.as_list())) for v in pol_var_list])

    get_pol_flat = U.GetFlat(pol_var_list)
    set_pol_from_flat = U.SetFromFlat(pol_var_list)

    total_loss = pol_surr + pol_entpen + vf_loss

    lossandgrad = U.function([ob, ac, atarg, ret, lrmult, clip_tf],
                             losses + [U.flatgrad(total_loss, var_list)])
    pol_lossandgrad = U.function([ob, ac, atarg, ret, lrmult, clip_tf],
                                 losses +
                                 [U.flatgrad(total_loss, pol_var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult, clip_tf], losses)
    compute_losses_cpo = U.function(
        [ob, ac, atarg, ret, lrmult, clip_tf],
        [tf.reduce_mean(surr1), pol_entpen, vf_loss, meankl, meanent])
    compute_ratios = U.function([ob, ac], ratio)
    compute_kls = U.function([ob, ac], kloldnew)

    compute_rollout_old_prob = U.function([ob, ac],
                                          tf.reduce_mean(oldpi.pd.logp(ac)))
    compute_rollout_new_prob = U.function([ob, ac],
                                          tf.reduce_mean(pi.pd.logp(ac)))
    compute_rollout_new_prob_min = U.function([ob, ac],
                                              tf.reduce_min(pi.pd.logp(ac)))

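    # Per-variable assign ops so parameters can be overwritten from numpy arrays
    # (feed the matching placeholder, then run the corresponding op).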
    update_ops = {}
    update_placeholders = {}
    for v in pi.get_trainable_variables():
        update_placeholders[v.name] = tf.placeholder(v.dtype,
                                                     shape=v.get_shape())
        update_ops[v.name] = v.assign(update_placeholders[v.name])

    # gradient of the mean log-probability w.r.t. the policy parameters,
    # used to estimate the Fisher information matrix
    dims = [int(np.prod(p.shape)) for p in pol_var_list]
    logprob_grad = U.flatgrad(tf.reduce_mean(pi.pd.logp(ac)), pol_var_list)
    compute_logprob_grad = U.function([ob, ac], logprob_grad)

    U.initialize()

    adam.sync()

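    # Warm start: copy saved parameters (possibly recorded under a different
    # variable scope) into both pi and oldpi.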
    if init_policy_params is not None:
        cur_scope = pi.get_variables()[0].name.split('/')[0]
        orig_scope = next(iter(init_policy_params.keys())).split('/')[0]
        print(cur_scope, orig_scope)
        for i in range(len(pi.get_variables())):
            if pi.get_variables()[i].name.replace(cur_scope, orig_scope,
                                                  1) in init_policy_params:
                assign_op = pi.get_variables()[i].assign(
                    init_policy_params[pi.get_variables()[i].name.replace(
                        cur_scope, orig_scope, 1)])
                tf.get_default_session().run(assign_op)
                assign_op = oldpi.get_variables()[i].assign(
                    init_policy_params[pi.get_variables()[i].name.replace(
                        cur_scope, orig_scope, 1)])
                tf.get_default_session().run(assign_op)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_batch,
                                     stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards

    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"

    prev_params = {}
    for v in var_list:
        if 'pol' in v.name:
            prev_params[v.name] = v.eval()

    optim_seg = None

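    # Gradient scale applied in adam.update below (1.0 = no scaling).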
    grad_scale = 1.0

    while True:
        if MPI.COMM_WORLD.Get_rank() == 0:
            print('begin')
            memory()
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % iters_so_far)

        seg = seg_gen.__next__()

        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update
        unstandardized_adv = np.copy(atarg)
        atarg = (atarg - atarg.mean()) / (
            atarg.std() + 1e-8)  # standardized advantage estimate; epsilon guards against zero std

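        # fvpargs mirrors args for Fisher-vector-product style routines; it is not
        # consumed in the optimization loop below.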
        args = seg["ob"], seg["ac"], atarg
        fvpargs = [arr for arr in args]

        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        assign_old_eq_new()  # set old parameter values to new parameter values
        logger.log("Optimizing...")
        logger.log(fmt_row(13, loss_names))

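        # Clip range held fixed this iteration; an annealed value could be fed instead.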
        cur_clip_val = clip_param

        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                    shuffle=not pi.recurrent)

        for epoch in range(optim_epochs):
            losses = []  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"],
                                            batch["atarg"], batch["vtarg"],
                                            cur_lrmult, cur_clip_val)
                adam.update(g * grad_scale, optim_stepsize * cur_lrmult)
                losses.append(newlosses)

            logger.log(fmt_row(13, np.mean(losses, axis=0)))

        logger.log("Evaluating losses...")
        losses = []
        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                    shuffle=not pi.recurrent)
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob"], batch["ac"],
                                       batch["atarg"], batch["vtarg"],
                                       cur_lrmult, clip_param)
            losses.append(newlosses)
        meanlosses, _, _ = mpi_moments(losses, axis=0)

        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.log(fmt_row(13, meanlosses))
            for (lossval, name) in zipsame(meanlosses, loss_names):
                logger.record_tabular("loss_" + name, lossval)

        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.record_tabular("EpLenMean", np.mean(lenbuffer))
            logger.record_tabular("EpRewMean", np.mean(rewbuffer))
            logger.record_tabular("EpThisIter", len(lens))
            logger.record_tabular(
                "PolVariance",
                repr(adam.getflat()[-env.action_space.shape[0]:]))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.record_tabular("EpisodesSoFar", episodes_so_far)
            logger.record_tabular("TimestepsSoFar", timesteps_so_far)
            logger.record_tabular("TimeElapsed", time.time() - tstart)

        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.dump_tabular()
        if MPI.COMM_WORLD.Get_rank() == 0:
            print('end')
            memory()

    return pi, np.mean(rewbuffer)
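
# Usage sketch (assumptions: a Gym env and a baselines-style policy factory such as
# MlpPolicy; the hyperparameter values below are illustrative, not from this file):
#
#     def policy_func(name, ob_space, ac_space):
#         return MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space,
#                          hid_size=64, num_hid_layers=2)
#
#     pi, mean_rew = learn(env, policy_func,
#                          timesteps_per_batch=2048, clip_param=0.2, entcoeff=0.0,
#                          optim_epochs=10, optim_stepsize=3e-4, optim_batchsize=64,
#                          gamma=0.99, lam=0.95,
#                          max_timesteps=int(1e6), schedule='linear')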