def learn(
        args,
        env,
        policy_fn,
        *,
        timesteps_per_actorbatch,  # timesteps per actor per update
        clip_param,
        entcoeff,  # clipping parameter epsilon, entropy coeff
        optim_epochs,
        optim_stepsize,
        optim_batchsize,  # optimization hypers
        gamma,
        lam,  # advantage estimation
        max_timesteps=0,
        max_episodes=0,
        max_iters=0,
        max_seconds=0,  # time constraint
        callback=None,  # you can do anything in the callback, since it takes locals(), globals()
        adam_epsilon=1e-5,
        schedule='constant',  # annealing for stepsize parameters (epsilon and adam)
        writer=None):
    print("\nBeginning learning...\n")

    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_fn("pi", ob_space,
                   ac_space)  # Construct network for new policy
    oldpi = policy_fn("oldpi", ob_space, ac_space)  # Network for old policy
    atarg = tf.compat.v1.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.compat.v1.placeholder(dtype=tf.float32,
                                   shape=[None])  # Empirical return

    lrmult = tf.compat.v1.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule
    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    ob = {}
    ob['adj'] = U.get_placeholder_cached(name="adj")
    ob['node'] = U.get_placeholder_cached(name="node")

    ob_gen = {}
    ob_gen['adj'] = U.get_placeholder(
        shape=[None, ob_space['adj'].shape[0], None, None],
        dtype=tf.float32,
        name='adj_gen')
    ob_gen['node'] = U.get_placeholder(
        shape=[None, 1, None, ob_space['node'].shape[2]],
        dtype=tf.float32,
        name='node_gen')

    ob_real = {}
    ob_real['adj'] = U.get_placeholder(
        shape=[None, ob_space['adj'].shape[0], None, None],
        dtype=tf.float32,
        name='adj_real')
    ob_real['node'] = U.get_placeholder(
        shape=[None, 1, None, ob_space['node'].shape[2]],
        dtype=tf.float32,
        name='node_real')

    ac = tf.compat.v1.placeholder(dtype=tf.int64,
                                  shape=[None, 4],
                                  name='ac_real')

    ## PPO loss
    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    pol_entpen = (-entcoeff) * meanent

    pi_logp = pi.pd.logp(ac)
    oldpi_logp = oldpi.pd.logp(ac)
    ratio_log = pi_logp - oldpi_logp

    ratio = tf.exp(ratio_log)  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = tf.clip_by_value(ratio, 1.0 - clip_param,
                             1.0 + clip_param) * atarg  #
    pol_surr = -tf.reduce_mean(tf.minimum(
        surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)
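    # Taking the minimum of the unclipped and clipped surrogates gives a
    # pessimistic lower bound, so the policy gains nothing from pushing the
    # ratio outside [1 - clip_param, 1 + clip_param]; the leading minus sign
    # turns maximization of the surrogate into a loss to minimize.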
    vf_loss = tf.reduce_mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    ## Expert loss
    loss_expert = -tf.reduce_mean(pi_logp)

    ## Discriminator loss
    step_pred_real, step_logit_real = discriminator_net(ob_real,
                                                        args,
                                                        name='d_step')
    step_pred_gen, step_logit_gen = discriminator_net(ob_gen,
                                                      args,
                                                      name='d_step')
    loss_d_step_real = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            logits=step_logit_real,
            labels=tf.ones_like(step_logit_real) * 0.9))
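    # Real samples are labeled 0.9 rather than 1.0 (one-sided label smoothing),
    # a common GAN regularization that keeps the discriminator from becoming
    # overconfident on the expert data.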
    loss_d_step_gen = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            logits=step_logit_gen, labels=tf.zeros_like(step_logit_gen)))
    loss_d_step = loss_d_step_real + loss_d_step_gen
    if args.gan_type == 'normal':
        loss_g_step_gen = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=step_logit_gen, labels=tf.zeros_like(step_logit_gen)))
    elif args.gan_type == 'recommend':
        loss_g_step_gen = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=step_logit_gen,
                labels=tf.ones_like(step_logit_gen) * 0.9))
    elif args.gan_type == 'wgan':
        loss_d_step, _, _ = discriminator(ob_real, ob_gen, args, name='d_step')
        loss_d_step = loss_d_step * -1
        loss_g_step_gen, _ = discriminator_net(ob_gen, args, name='d_step')
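        # Assumption: for the 'wgan' setting, discriminator() appears to return
        # a Wasserstein-style critic objective (its exact signature is not shown
        # here); negating it yields the critic loss to minimize, and the raw
        # critic output on generated graphs is reused as the generator loss.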

    final_pred_real, final_logit_real = discriminator_net(ob_real,
                                                          args,
                                                          name='d_final')
    final_pred_gen, final_logit_gen = discriminator_net(ob_gen,
                                                        args,
                                                        name='d_final')
    loss_d_final_real = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            logits=final_logit_real,
            labels=tf.ones_like(final_logit_real) * 0.9))
    loss_d_final_gen = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            logits=final_logit_gen, labels=tf.zeros_like(final_logit_gen)))
    loss_d_final = loss_d_final_real + loss_d_final_gen
    if args.gan_type == 'normal':
        loss_g_final_gen = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=final_logit_gen, labels=tf.zeros_like(final_logit_gen)))
    elif args.gan_type == 'recommend':
        loss_g_final_gen = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=final_logit_gen,
                labels=tf.ones_like(final_logit_gen) * 0.9))
    elif args.gan_type == 'wgan':
        loss_d_final, _, _ = discriminator(ob_real,
                                           ob_gen,
                                           args,
                                           name='d_final')
        loss_d_final = loss_d_final * -1
        loss_g_final_gen, _ = discriminator_net(ob_gen, args, name='d_final')

    var_list_pi = pi.get_trainable_variables()
    var_list_pi_stop = [
        var for var in var_list_pi
        if ('emb' in var.name) or ('gcn' in var.name) or ('stop' in var.name)
    ]
    var_list_d_step = [
        var for var in tf.compat.v1.global_variables() if 'd_step' in var.name
    ]
    var_list_d_final = [
        var for var in tf.compat.v1.global_variables() if 'd_final' in var.name
    ]

    ## debug
    debug = {}

    ## loss update function
    lossandgrad_ppo = U.function([
        ob['adj'], ob['node'], ac, pi.ac_real, oldpi.ac_real, atarg, ret,
        lrmult
    ], losses + [U.flatgrad(total_loss, var_list_pi)])
    lossandgrad_expert = U.function(
        [ob['adj'], ob['node'], ac, pi.ac_real],
        [loss_expert, U.flatgrad(loss_expert, var_list_pi)])
    lossandgrad_expert_stop = U.function(
        [ob['adj'], ob['node'], ac, pi.ac_real],
        [loss_expert, U.flatgrad(loss_expert, var_list_pi_stop)])
    lossandgrad_d_step = U.function(
        [ob_real['adj'], ob_real['node'], ob_gen['adj'], ob_gen['node']],
        [loss_d_step, U.flatgrad(loss_d_step, var_list_d_step)])
    lossandgrad_d_final = U.function(
        [ob_real['adj'], ob_real['node'], ob_gen['adj'], ob_gen['node']],
        [loss_d_final,
         U.flatgrad(loss_d_final, var_list_d_final)])
    loss_g_gen_step_func = U.function([ob_gen['adj'], ob_gen['node']],
                                      loss_g_step_gen)
    loss_g_gen_final_func = U.function([ob_gen['adj'], ob_gen['node']],
                                       loss_g_final_gen)

    adam_pi = MpiAdam(var_list_pi, epsilon=adam_epsilon)
    adam_pi_stop = MpiAdam(var_list_pi_stop, epsilon=adam_epsilon)
    adam_d_step = MpiAdam(var_list_d_step, epsilon=adam_epsilon)
    adam_d_final = MpiAdam(var_list_d_final, epsilon=adam_epsilon)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.compat.v1.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
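    # assign_old_eq_new copies the current policy parameters into oldpi, so the
    # importance ratio in the next PPO update is measured against the policy
    # that actually collected the rollout data.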

    compute_losses = U.function([
        ob['adj'], ob['node'], ac, pi.ac_real, oldpi.ac_real, atarg, ret,
        lrmult
    ], losses)

    # Prepare for rollouts
    # ----------------------------------------
    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    lenbuffer_valid = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards
    rewbuffer_env = deque(maxlen=100)  # rolling buffer for episode rewards
    rewbuffer_d_step = deque(maxlen=100)  # rolling buffer for episode rewards
    rewbuffer_d_final = deque(maxlen=100)  # rolling buffer for episode rewards
    rewbuffer_final = deque(maxlen=100)  # rolling buffer for episode rewards
    rewbuffer_final_stat = deque(
        maxlen=100)  # rolling buffer for episode rewards

    seg_gen = traj_segment_generator(args, pi, env, timesteps_per_actorbatch,
                                     True, loss_g_gen_step_func,
                                     loss_g_gen_final_func)
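    # traj_segment_generator is expected to yield rollout segments of
    # timesteps_per_actorbatch steps, using the two generator-loss functions to
    # attach per-step and final adversarial rewards (exact fields depend on its
    # implementation elsewhere in this file).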

    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"
    if args.load == 1:
        try:
            fname = './ckpt/' + args.name_full_load
            sess = tf.compat.v1.get_default_session()
            # sess.run(tf.compat.v1.global_variables_initializer())
            saver = tf.compat.v1.train.Saver(var_list_pi)
            saver.restore(sess, fname)
            iters_so_far = int(fname.split('_')[-1]) + 1
            print('model restored!', fname, 'iters_so_far:', iters_so_far)
        except Exception:
            print(fname, 'ckpt not found, starting from iteration 0')

    U.initialize()
    adam_pi.sync()
    adam_pi_stop.sync()
    adam_d_step.sync()
    adam_d_final.sync()

    counter = 0
    level = 0
    ## start training
    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        # logger.log("********** Iteration %i ************"%iters_so_far)

        seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)
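        # add_vtarg_and_adv has now filled seg["adv"] and seg["tdlamret"] via
        # GAE(gamma, lam): delta_t = r_t + gamma * V(s_{t+1}) - V(s_t), with
        # advantages accumulated as discounted sums of deltas at decay gamma * lam.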
        ob_adj, ob_node, ac, atarg, tdlamret = seg["ob_adj"], seg[
            "ob_node"], seg["ac"], seg["adv"], seg["tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before udpate
        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate
        d = Dataset(dict(ob_adj=ob_adj,
                         ob_node=ob_node,
                         ac=ac,
                         atarg=atarg,
                         vtarg=tdlamret),
                    shuffle=not pi.recurrent)
        optim_batchsize = optim_batchsize or ob_adj.shape[0]

        # inner training loop, train policy
        for i_optim in range(optim_epochs):

            loss_expert = 0
            loss_expert_stop = 0
            g_expert = 0
            g_expert_stop = 0

            loss_d_step = 0
            loss_d_final = 0
            g_ppo = 0
            g_d_step = 0
            g_d_final = 0

            pretrain_shift = 5
            ## Expert
            if args.expert_start <= iters_so_far <= args.expert_end + pretrain_shift:
                ## Expert train
                # # # learn how to stop
                ob_expert, ac_expert = env.get_expert(optim_batchsize)
                loss_expert, g_expert = lossandgrad_expert(
                    ob_expert['adj'], ob_expert['node'], ac_expert, ac_expert)
                loss_expert = np.mean(loss_expert)

            ## PPO
            if args.rl_start <= iters_so_far <= args.rl_end:
                assign_old_eq_new(
                )  # set old parameter values to new parameter values
                batch = d.next_batch(optim_batchsize)
                # ppo
                if iters_so_far >= args.rl_start + pretrain_shift:  # start generator updates only after the discriminator has had a few iterations to train
                    *newlosses, g_ppo = lossandgrad_ppo(
                        batch["ob_adj"], batch["ob_node"], batch["ac"],
                        batch["ac"], batch["ac"], batch["atarg"],
                        batch["vtarg"], cur_lrmult)
                    losses_ppo = newlosses

                if args.has_d_step == 1 and i_optim >= optim_epochs // 2:
                    # update step discriminator
                    ob_expert, _ = env.get_expert(
                        optim_batchsize,
                        curriculum=args.curriculum,
                        level_total=args.curriculum_num,
                        level=level)
                    loss_d_step, g_d_step = lossandgrad_d_step(
                        ob_expert["adj"], ob_expert["node"], batch["ob_adj"],
                        batch["ob_node"])
                    adam_d_step.update(g_d_step, optim_stepsize * cur_lrmult)
                    loss_d_step = np.mean(loss_d_step)

                if args.has_d_final == 1 and i_optim >= optim_epochs // 4 * 3:
                    # update final discriminator
                    ob_expert, _ = env.get_expert(
                        optim_batchsize,
                        is_final=True,
                        curriculum=args.curriculum,
                        level_total=args.curriculum_num,
                        level=level)
                    seg_final_adj, seg_final_node = traj_final_generator(
                        pi, copy.deepcopy(env), optim_batchsize, True)
                    # update final discriminator
                    loss_d_final, g_d_final = lossandgrad_d_final(
                        ob_expert["adj"], ob_expert["node"], seg_final_adj,
                        seg_final_node)
                    adam_d_final.update(g_d_final, optim_stepsize * cur_lrmult)

            # update generator
            adam_pi.update(0.2 * g_ppo + 0.05 * g_expert,
                           optim_stepsize * cur_lrmult)
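            # One Adam step on a fixed-weight blend of the PPO gradient (0.2)
            # and the expert imitation gradient (0.05); whichever term was not
            # computed this epoch is still 0 and contributes nothing.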

        # WGAN
        # if args.has_d_step == 1:
        #     clip_D = [p.assign(tf.clip_by_value(p, -0.01, 0.01)) for p in var_list_d_step]
        # if args.has_d_final == 1:
        #     clip_D = [p.assign(tf.clip_by_value(p, -0.01, 0.01)) for p in var_list_d_final]
        #

        ## PPO val
        # if iters_so_far >= args.rl_start and iters_so_far <= args.rl_end:
        # logger.log("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob_adj"], batch["ob_node"],
                                       batch["ac"], batch["ac"], batch["ac"],
                                       batch["atarg"], batch["vtarg"],
                                       cur_lrmult)
            losses.append(newlosses)
        meanlosses, _, _ = mpi_moments(losses, axis=0)
        # logger.log(fmt_row(13, meanlosses))

        if writer is not None:
            writer.add_scalar("loss_expert", loss_expert, iters_so_far)
            writer.add_scalar("loss_expert_stop", loss_expert_stop,
                              iters_so_far)
            writer.add_scalar("loss_d_step", loss_d_step, iters_so_far)
            writer.add_scalar("loss_d_final", loss_d_final, iters_so_far)
            writer.add_scalar('grad_expert_min', np.amin(g_expert),
                              iters_so_far)
            writer.add_scalar('grad_expert_max', np.amax(g_expert),
                              iters_so_far)
            writer.add_scalar('grad_expert_norm', np.linalg.norm(g_expert),
                              iters_so_far)
            writer.add_scalar('grad_expert_stop_min', np.amin(g_expert_stop),
                              iters_so_far)
            writer.add_scalar('grad_expert_stop_max', np.amax(g_expert_stop),
                              iters_so_far)
            writer.add_scalar('grad_expert_stop_norm',
                              np.linalg.norm(g_expert_stop), iters_so_far)
            writer.add_scalar('grad_rl_min', np.amin(g_ppo), iters_so_far)
            writer.add_scalar('grad_rl_max', np.amax(g_ppo), iters_so_far)
            writer.add_scalar('grad_rl_norm', np.linalg.norm(g_ppo),
                              iters_so_far)
            writer.add_scalar('g_d_step_min', np.amin(g_d_step), iters_so_far)
            writer.add_scalar('g_d_step_max', np.amax(g_d_step), iters_so_far)
            writer.add_scalar('g_d_step_norm', np.linalg.norm(g_d_step),
                              iters_so_far)
            writer.add_scalar('g_d_final_min', np.amin(g_d_final),
                              iters_so_far)
            writer.add_scalar('g_d_final_max', np.amax(g_d_final),
                              iters_so_far)
            writer.add_scalar('g_d_final_norm', np.linalg.norm(g_d_final),
                              iters_so_far)
            writer.add_scalar('learning_rate', optim_stepsize * cur_lrmult,
                              iters_so_far)

        for (lossval, name) in zipsame(meanlosses, loss_names):
            # logger.record_tabular("loss_"+name, lossval)
            if writer is not None:
                writer.add_scalar("loss_" + name, lossval, iters_so_far)
        # logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))
        if writer is not None:
            writer.add_scalar("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret),
                              iters_so_far)
        lrlocal = (seg["ep_lens"], seg["ep_lens_valid"], seg["ep_rets"],
                   seg["ep_rets_env"], seg["ep_rets_d_step"],
                   seg["ep_rets_d_final"], seg["ep_final_rew"],
                   seg["ep_final_rew_stat"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, lens_valid, rews, rews_env, rews_d_step, rews_d_final, rews_final, rews_final_stat = map(
            flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        lenbuffer_valid.extend(lens_valid)
        rewbuffer.extend(rews)
        rewbuffer_d_step.extend(rews_d_step)
        rewbuffer_d_final.extend(rews_d_final)
        rewbuffer_env.extend(rews_env)
        rewbuffer_final.extend(rews_final)
        rewbuffer_final_stat.extend(rews_final_stat)
        # logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        # logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        # logger.record_tabular("EpThisIter", len(lens))
        if writer is not None:
            writer.add_scalar("EpLenMean", np.mean(lenbuffer), iters_so_far)
            writer.add_scalar("EpLenValidMean", np.mean(lenbuffer_valid),
                              iters_so_far)
            writer.add_scalar("EpRewMean", np.mean(rewbuffer), iters_so_far)
            writer.add_scalar("EpRewDStepMean", np.mean(rewbuffer_d_step),
                              iters_so_far)
            writer.add_scalar("EpRewDFinalMean", np.mean(rewbuffer_d_final),
                              iters_so_far)
            writer.add_scalar("EpRewEnvMean", np.mean(rewbuffer_env),
                              iters_so_far)
            writer.add_scalar("EpRewFinalMean", np.mean(rewbuffer_final),
                              iters_so_far)
            writer.add_scalar("EpRewFinalStatMean",
                              np.mean(rewbuffer_final_stat), iters_so_far)
            writer.add_scalar("EpThisIter", len(lens), iters_so_far)
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        # logger.record_tabular("EpisodesSoFar", episodes_so_far)
        # logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        # logger.record_tabular("TimeElapsed", time.time() - tstart)
        if writer is not None:
            writer.add_scalar("EpisodesSoFar", episodes_so_far, iters_so_far)
            writer.add_scalar("TimestepsSoFar", timesteps_so_far, iters_so_far)
            writer.add_scalar("TimeElapsed",
                              time.time() - tstart, iters_so_far)

        if MPI.COMM_WORLD.Get_rank() == 0:
            with open('molecule_gen/' + args.name_full + '.csv', 'a') as f:
                f.write('***** Iteration {} *****\n'.format(iters_so_far))
            # save
            if iters_so_far % args.save_every == 0:
                fname = './ckpt/' + args.name_full + '_' + str(iters_so_far)
                saver = tf.compat.v1.train.Saver(var_list_pi)
                saver.save(tf.compat.v1.get_default_session(), fname)
                print('model saved!', fname)
                # fname = os.path.join(ckpt_dir, task_name)
                # os.makedirs(os.path.dirname(fname), exist_ok=True)
                # saver = tf.train.Saver()
                # saver.save(tf.get_default_session(), fname)
            # if iters_so_far==args.load_step:
        iters_so_far += 1
        counter += 1
        if counter % args.curriculum_step == 0 and counter // args.curriculum_step < args.curriculum_num:
            level += 1
Example #2
def train_student(klts):
    env = make_mujoco_env("Reacher-v2", 0)
    with tf.Session() as sess:
        # Initialize agents
        student = StudentAgent(env, sess, False, klts)
        teacher = TeacherAgent(env, sess, True)

        # This observation placeholder is for querying teacher action
        # ob_ph = U.get_placeholder( name="ob", dtype=tf.float32,
        #       shape=[1, env.observation_space.shape[0] ] )

        ob_placeholder = U.get_placeholder(name="ob",
                                           dtype=tf.float32,
                                           shape=[TRAINING_BATCH_SIZE] +
                                           list(env.observation_space.shape))

        # get all hidden layer variables of the student pi
        student_var = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES,
            scope="s_pi_{0}".format("klts" if klts else "klst"))
        # print(student_var)
        teacher_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                        scope="t_pi")

        # KL Divergence
        if klts:
            kl_div = teacher.pi.pd.kl(student.pi.pd)
        else:
            kl_div = student.pi.pd.kl(teacher.pi.pd)
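        # pd.kl(other) computes KL(self || other), so klts minimizes
        # KL(teacher || student) (forward, covering KL) while klst minimizes
        # KL(student || teacher) (reverse, mode-seeking KL).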

        # define loss and gradient with thenos-like function
        # gradients wrt only to student variables
        lossandgrad = U.function([ob_placeholder],
                                 [kl_div] + [U.flatgrad(kl_div, student_var)])

        logstd = U.function([ob_placeholder],
                            [teacher.pi.pd.logstd, student.pi.pd.logstd])
        std = U.function([ob_placeholder],
                         [teacher.pi.pd.std, student.pi.pd.std])
        mean = U.function([ob_placeholder],
                          [teacher.pi.pd.mean, student.pi.pd.mean])

        # initialize only student variables
        U.initialize(
            tf.get_collection(
                tf.GraphKeys.GLOBAL_VARIABLES,
                scope="s_pi_{0}".format("klts" if klts else "klst")))

        # Adam optimizer
        adam = MpiAdam(student_var, epsilon=1e-3)
        adam.sync()

        ob = env.reset()

        obs = []
        losses = []
        timesteps = []
        rets = []
        ret = 0
        num_resets = 0

        saver = tf.train.Saver(var_list=tf.get_collection(
            tf.GraphKeys.GLOBAL_VARIABLES,
            scope='s_pi_{0}'.format("klts" if klts else "klst")))

        # saver.restore(sess, "/Users/winstonww/RL/reacher_v1/student_{0}.ckpt".format("klts" if klts else "klst"))

        for timestep in range(1, TOTAL_EPISODES * TIMESTEPS_PER_EPISODE):
            # sample an action
            # the single observation is broadcast against zeros of shape
            # [TRAINING_BATCH_SIZE, obs_dim] so the same batch-sized model
            # can be used both to query actions and to train
            ob = np.expand_dims(
                ob, axis=0) + np.zeros([TRAINING_BATCH_SIZE] +
                                       list(env.observation_space.shape))
            s_ac, _ = student.pi.act(False, ob)
            # print( " ob size: {0} ".format(ob.shape))
            # print( "s_ac shape" )
            # print( s_ac.shape )
            # step the environment along the student's own trajectory
            ob, reward, new, _ = env.step(s_ac)
            ret += reward
            if new:
                rets.append(ret)
                ret = 0
                ob = env.reset()
                num_resets += 1
                if num_resets > 40000:
                    break
            # env.render()
            # print( "ob to be appended: {0}".format(ob))
            obs.append(ob)

            # compute newloss and its gradient from the two actions sampled
            # if (timestep % TRAINING_BATCH_SIZE != 0 or not timestep):
            #     continue

            # accumulate more samples before starting
            if len(obs) < TRAINING_BATCH_SIZE:
                continue

            d = Dataset(dict(ob=np.array(obs)))
            batch = d.next_batch(TRAINING_BATCH_SIZE)

            newloss, g = lossandgrad(
                np.squeeze(np.stack(list(batch.values()), axis=0), axis=0))
            adam.update(g, 0.001)
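            # Each update nudges the student's action distribution toward the
            # teacher's on states visited by the student's own policy
            # (on-policy distillation); 0.001 is a fixed Adam step size.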

            # record the following data only when reset to save time
            if new:
                losses.append(sum(newloss))
                timesteps.append(timestep)

                if num_resets % 100 == 0:
                    print("********** Episode {0} ***********".format(
                        num_resets))
                    print("obs: \n{0}".format(
                        np.squeeze(np.stack(list(batch.values()), axis=0),
                                   axis=0)))
                    t_m, s_m = mean(
                        np.squeeze(np.stack(list(batch.values()), axis=0),
                                   axis=0))
                    t_std, s_std = std(
                        np.squeeze(np.stack(list(batch.values()), axis=0),
                                   axis=0))
                    print("student pd std: \n{0}".format(s_std))
                    print("teacher pd std: \n{0}".format(t_std))
                    print("student pd mean: \n{0}".format(s_m))
                    print("teacher pd mean: \n{0}".format(t_m))
                    print("KL divergence: \n{0}".format(sum(newloss)))

            if timestep % 5000 == 0:
                # save results
                np.save(
                    klts_training_loss_path
                    if klts else klst_training_loss_path, losses)
                np.save(
                    klts_training_ret_path if klts else klst_training_ret_path,
                    rets)
                # save kl
                save_path = saver.save(
                    sess,
                    "/Users/winstonww/RL/reacher_v1/student_{0}.ckpt".format(
                        "klts" if klts else "klst"))

        # save results
        np.save(klts_training_loss_path if klts else klst_training_loss_path,
                losses)
        np.save(klts_training_ret_path if klts else klst_training_ret_path,
                rets)

        save_path = saver.save(
            sess, "/Users/winstonww/RL/reacher_v1/student_{0}.ckpt".format(
                "klts" if klts else "klst"))
Example #3
def learn(
        env,
        policy_func,
        disc,
        *,
        timesteps_per_batch,  # timesteps per actor per update
        clip_param,
        entcoeff,  # clipping parameter epsilon, entropy coeff
        optim_epochs,
        optim_stepsize,
        optim_batchsize,  # optimization hypers
        gamma,
        lam,  # advantage estimation
        max_timesteps=0,
        max_episodes=0,
        max_iters=0,
        max_seconds=0,  # time constraint
        callback=None,  # you can do anything in the callback, since it takes locals(), globals()
        adam_epsilon=1e-5,
        schedule='constant',  # annealing for stepsize parameters (epsilon and adam)
        logdir=".",
        agentName="PPO-Agent",
        resume=0,
        num_parallel=0,
        num_cpu=1,
        num_extra=0,
        gan_batch_size=128,
        gan_num_epochs=5,
        gan_display_step=40,
        resume_disc=0,
        resume_non_disc=0,
        mocap_path="",
        gan_replay_buffer_size=1000000,
        gan_prob_to_put_in_replay=0.01,
        gan_reward_to_retrain_discriminator=5,
        use_distance=0,
        use_blend=0):
    # Deal with GAN
    if not use_distance:
        replay_buf = MyReplayBuffer(gan_replay_buffer_size)
    data = np.loadtxt(
        mocap_path + ".dat"
    )  #"D:/p4sw/devrel/libdev/flex/dev/rbd/data/bvh/motion_simple.dat");
    label = np.concatenate((np.ones(
        (data.shape[0], 1)), np.zeros((data.shape[0], 1))),
                           axis=1)
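    # Mocap (real) frames carry the two-class label [1, 0]; generated frames
    # are labeled [0, 1] later on, so the discriminator is trained as a binary
    # classifier over real-vs-policy data.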

    print("Real data label = " + str(label))

    mocap_set = Dataset(dict(data=data, label=label), shuffle=True)

    # Setup losses and stuff
    # ----------------------------------------
    rank = MPI.COMM_WORLD.Get_rank()
    ob_space = env.observation_space
    ac_space = env.action_space

    ob_size = ob_space.shape[0]
    ac_size = ac_space.shape[0]

    #print("rank = " + str(rank) + " ob_space = "+str(ob_space.shape) + " ac_space = "+str(ac_space.shape))
    #exit(0)
    pi = policy_func("pi", ob_space,
                     ac_space)  # Construct network for new policy
    oldpi = policy_func("oldpi", ob_space, ac_space)  # Network for old policy
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule
    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = U.mean(kloldnew)
    meanent = U.mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = U.clip(ratio, 1.0 - clip_param, 1.0 + clip_param) * atarg  #
    pol_surr = -U.mean(tf.minimum(
        surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)
    vfloss1 = tf.square(pi.vpred - ret)
    vpredclipped = oldpi.vpred + tf.clip_by_value(pi.vpred - oldpi.vpred,
                                                  -clip_param, clip_param)
    vfloss2 = tf.square(vpredclipped - ret)
    vf_loss = .5 * U.mean(
        tf.maximum(vfloss1, vfloss2)
    )  # we do the same clipping-based trust region for the value function
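    # Value-function clipping mirrors the policy clipping: vpredclipped stays
    # within clip_param of the old value estimate, and taking the max of the
    # two squared errors keeps the value update from moving too far per step.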
    #vf_loss = U.mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult],
                             losses + [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)

    U.initialize()
    adam.sync()

    # Prepare for rollouts
    # ----------------------------------------
    sess = tf.get_default_session()

    avars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    non_disc_vars = [
        a for a in avars
        if not a.name.split("/")[0].startswith("discriminator")
    ]
    disc_vars = [
        a for a in avars if a.name.split("/")[0].startswith("discriminator")
    ]
    #print(str(non_disc_names))
    #print(str(disc_names))
    #exit(0)
    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards

    disc_saver = tf.train.Saver(disc_vars, max_to_keep=None)
    non_disc_saver = tf.train.Saver(non_disc_vars, max_to_keep=None)
    saver = tf.train.Saver(max_to_keep=None)
    if resume > 0:
        saver.restore(
            tf.get_default_session(),
            os.path.join(os.path.abspath(logdir),
                         "{}-{}".format(agentName, resume)))
        if not use_distance:
            if os.path.exists(logdir + "\\" + 'replay_buf_' +
                              str(int(resume / 100) * 100) + '.pkl'):
                print("Load replay buf")
                with open(
                        logdir + "\\" + 'replay_buf_' +
                        str(int(resume / 100) * 100) + '.pkl', 'rb') as f:
                    replay_buf = pickle.load(f)
            else:
                print("Can't load replay buf " + logdir + "\\" +
                      'replay_buf_' + str(int(resume / 100) * 100) + '.pkl')
    iters_so_far = resume

    if resume_non_disc > 0:
        non_disc_saver.restore(
            tf.get_default_session(),
            os.path.join(
                os.path.abspath(logdir),
                "{}-{}".format(agentName + "_non_disc", resume_non_disc)))
        iters_so_far = resume_non_disc

    if use_distance:
        print("Use distance")
        nn = NearestNeighbors(n_neighbors=1, algorithm='ball_tree').fit(data)
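        # Assumption: this 1-nearest-neighbor index over mocap frames is
        # presumably used inside traj_segment_generator to score rollout states
        # by distance to the closest reference frame instead of a learned
        # discriminator.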
    else:
        nn = None
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     disc,
                                     timesteps_per_batch,
                                     stochastic=True,
                                     num_parallel=num_parallel,
                                     num_cpu=num_cpu,
                                     rank=rank,
                                     ob_size=ob_size,
                                     ac_size=ac_size,
                                     com=MPI.COMM_WORLD,
                                     num_extra=num_extra,
                                     iters_so_far=iters_so_far,
                                     use_distance=use_distance,
                                     nn=nn)

    if resume_disc > 0:
        disc_saver.restore(
            tf.get_default_session(),
            os.path.join(os.path.abspath(logdir),
                         "{}-{}".format(agentName + "_disc", resume_disc)))

    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"
    logF = open(logdir + "\\" + 'log.txt', 'a')
    logR = open(logdir + "\\" + 'log_rew.txt', 'a')
    logStats = open(logdir + "\\" + 'log_stats.txt', 'a')
    if os.path.exists(logdir + "\\" + 'ob_list_' + str(rank) + '.pkl'):
        with open(logdir + "\\" + 'ob_list_' + str(rank) + '.pkl', 'rb') as f:
            ob_list = pickle.load(f)
    else:
        ob_list = []

    dump_training = 0
    learn_from_training = 0
    if dump_training:
        # , "mean": pi.ob_rms.mean, "std": pi.ob_rms.std
        saverRMS = tf.train.Saver({
            "_sum": pi.ob_rms._sum,
            "_sumsq": pi.ob_rms._sumsq,
            "_count": pi.ob_rms._count
        })
        saverRMS.save(tf.get_default_session(),
                      os.path.join(os.path.abspath(logdir), "rms.tf"))

        ob_np_a = np.asarray(ob_list)
        ob_np = np.reshape(ob_np_a, (-1, ob_size))
        [vpred, pdparam] = pi._vpred_pdparam(ob_np)

        print("vpred = " + str(vpred))
        print("pd_param = " + str(pdparam))
        with open('training.pkl', 'wb') as f:
            pickle.dump(ob_np, f)
            pickle.dump(vpred, f)
            pickle.dump(pdparam, f)
        exit(0)
    if learn_from_training:
        # , "mean": pi.ob_rms.mean, "std": pi.ob_rms.std

        with open('training.pkl', 'rb') as f:
            ob_np = pickle.load(f)
            vpred = pickle.load(f)
            pdparam = pickle.load(f)
        num = ob_np.shape[0]
        for i in range(num):
            xp = ob_np[i][1]
            ob_np[i][1] = 0.0
            ob_np[i][18] -= xp
            ob_np[i][22] -= xp
            ob_np[i][24] -= xp
            ob_np[i][26] -= xp
            ob_np[i][28] -= xp
            ob_np[i][30] -= xp
            ob_np[i][32] -= xp
            ob_np[i][34] -= xp
        print("ob_np = " + str(ob_np))
        print("vpred = " + str(vpred))
        print("pdparam = " + str(pdparam))
        batch_size = 128

        y_vpred = tf.placeholder(tf.float32, [
            batch_size,
        ])
        y_pdparam = tf.placeholder(tf.float32, [batch_size, pdparam.shape[1]])

        vpred_loss = U.mean(tf.square(pi.vpred - y_vpred))
        vpdparam_loss = U.mean(tf.square(pi.pdparam - y_pdparam))

        total_train_loss = vpred_loss + vpdparam_loss
        #total_train_loss = vpdparam_loss
        #total_train_loss = vpred_loss
        #coef = 0.01
        #dense_all = U.dense_all
        #for a in dense_all:
        #   total_train_loss += coef * tf.nn.l2_loss(a)
        #total_train_loss = vpdparam_loss
        optimizer = tf.train.AdamOptimizer(
            learning_rate=0.001).minimize(total_train_loss)
        d = Dataset(dict(ob=ob_np, vpred=vpred, pdparam=pdparam),
                    shuffle=not pi.recurrent)
        sess = tf.get_default_session()
        sess.run(tf.global_variables_initializer())
        saverRMS = tf.train.Saver({
            "_sum": pi.ob_rms._sum,
            "_sumsq": pi.ob_rms._sumsq,
            "_count": pi.ob_rms._count
        })
        saverRMS.restore(tf.get_default_session(),
                         os.path.join(os.path.abspath(logdir), "rms.tf"))
        if resume > 0:
            saver.restore(
                tf.get_default_session(),
                os.path.join(os.path.abspath(logdir),
                             "{}-{}".format(agentName, resume)))

        for q in range(100):
            sumLoss = 0
            for batch in d.iterate_once(batch_size):
                tl, _ = sess.run(
                    [total_train_loss, optimizer],
                    feed_dict={
                        pi.ob: batch["ob"],
                        y_vpred: batch["vpred"],
                        y_pdparam: batch["pdparam"]
                    })
                sumLoss += tl
            print("Iteration " + str(q) + " Loss = " + str(sumLoss))
        assign_old_eq_new()  # set old parameter values to new parameter values

        # Save as frame 1
        try:
            saver.save(tf.get_default_session(),
                       os.path.join(logdir, agentName),
                       global_step=1)
        except:
            pass
        #exit(0)
    if resume > 0:
        firstTime = False
    else:
        firstTime = True

    # Check accuracy
    #amocap = sess.run([disc.accuracy],
    #                feed_dict={disc.input: data,
    #                           disc.label: label})
    #print("Mocap accuracy = " + str(amocap))
    #print("Mocap label is " + str(label))

    #adata = np.array(replay_buf._storage)
    #print("adata shape = " + str(adata.shape))
    #alabel = np.concatenate((np.zeros((adata.shape[0], 1)), np.ones((adata.shape[0], 1))), axis=1)

    #areplay = sess.run([disc.accuracy],
    #                feed_dict={disc.input: adata,
    #                           disc.label: alabel})
    #print("Replay accuracy = " + str(areplay))
    #print("Replay label is " + str(alabel))
    #exit(0)
    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % iters_so_far)

        seg = seg_gen.__next__()

        add_vtarg_and_adv(seg, gamma, lam, timesteps_per_batch, num_parallel,
                          num_cpu)
        #print(" ob= " + str(seg["ob"])+ " rew= " + str(seg["rew"])+ " vpred= " + str(seg["vpred"])+ " new= " + str(seg["new"])+ " ac= " + str(seg["ac"])+ " prevac= " + str(seg["prevac"])+ " nextvpred= " + str(seg["nextvpred"])+ " ep_rets= " + str(seg["ep_rets"])+ " ep_lens= " + str(seg["ep_lens"]))

        #exit(0)
        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret, extra = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"], seg["extra"]

        #ob_list.append(ob.tolist())
        vpredbefore = seg["vpred"]  # predicted value function before udpate
        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate
        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                    shuffle=not pi.recurrent)
        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        assign_old_eq_new()  # set old parameter values to new parameter values
        logger.log("Optimizing...")
        logger.log(fmt_row(13, loss_names))
        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"],
                                            batch["atarg"], batch["vtarg"],
                                            cur_lrmult)
                adam.update(g, optim_stepsize * cur_lrmult)
                losses.append(newlosses)
            #print(str(losses))
            logger.log(fmt_row(13, np.mean(losses, axis=0)))

        logger.log("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob"], batch["ac"],
                                       batch["atarg"], batch["vtarg"],
                                       cur_lrmult)
            losses.append(newlosses)
        meanlosses, _, _ = mpi_moments(losses, axis=0)
        logger.log(fmt_row(13, meanlosses))
        for (lossval, name) in zipsame(meanlosses, loss_names):
            logger.record_tabular("loss_" + name, lossval)
        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))
        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        rewmean = np.mean(rewbuffer)
        logger.record_tabular("EpRewMean", rewmean)
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)

        # Train discriminator
        if not use_distance:
            print("Put in replay buf " +
                  str((int)(gan_prob_to_put_in_replay * extra.shape[0] + 1)))
            replay_buf.add(extra[np.random.choice(
                extra.shape[0],
                (int)(gan_prob_to_put_in_replay * extra.shape[0] + 1),
                replace=True)])
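            # Roughly a gan_prob_to_put_in_replay fraction (at least one frame)
            # of this iteration's extra features is sampled with replacement
            # into the replay buffer, so later discriminator retraining also
            # sees samples produced by older policies.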
            #if iters_so_far == 1:
            if not use_blend:
                if firstTime:
                    firstTime = False
                    # Train with everything we got
                    lb = np.concatenate((np.zeros(
                        (extra.shape[0], 1)), np.ones((extra.shape[0], 1))),
                                        axis=1)
                    extra_set = Dataset(dict(data=extra, label=lb),
                                        shuffle=True)
                    for e in range(10):
                        i = 0
                        for mbatch in mocap_set.iterate_once(gan_batch_size):
                            batch = extra_set.next_batch(gan_batch_size)
                            _, l = sess.run(
                                [disc.optimizer_first, disc.loss],
                                feed_dict={
                                    disc.input:
                                    np.concatenate(
                                        (mbatch['data'], batch['data'])),
                                    disc.label:
                                    np.concatenate(
                                        (mbatch['label'], batch['label']))
                                })
                            i = i + 1
                            # Display logs per step
                            if i % gan_display_step == 0 or i == 1:
                                print(
                                    'discriminator epoch %i Step %i: Minibatch Loss: %f'
                                    % (e, i, l))
                        print(
                            'discriminator epoch %i Step %i: Minibatch Loss: %f'
                            % (e, i, l))
                if seg['mean_ext_rew'] > gan_reward_to_retrain_discriminator:
                    for e in range(gan_num_epochs):
                        i = 0
                        for mbatch in mocap_set.iterate_once(gan_batch_size):
                            data = replay_buf.sample(mbatch['data'].shape[0])
                            lb = np.concatenate((np.zeros(
                                (data.shape[0], 1)), np.ones(
                                    (data.shape[0], 1))),
                                                axis=1)
                            _, l = sess.run(
                                [disc.optimizer, disc.loss],
                                feed_dict={
                                    disc.input:
                                    np.concatenate((mbatch['data'], data)),
                                    disc.label:
                                    np.concatenate((mbatch['label'], lb))
                                })
                            i = i + 1
                            # Display logs per step
                            if i % gan_display_step == 0 or i == 1:
                                print(
                                    'discriminator epoch %i Step %i: Minibatch Loss: %f'
                                    % (e, i, l))
                        print(
                            'discriminator epoch %i Step %i: Minibatch Loss: %f'
                            % (e, i, l))
            else:
                if firstTime:
                    firstTime = False
                    # Train with everything we got
                    extra_set = Dataset(dict(data=extra), shuffle=True)
                    for e in range(10):
                        i = 0
                        for mbatch in mocap_set.iterate_once(gan_batch_size):
                            batch = extra_set.next_batch(gan_batch_size)
                            bf = np.random.uniform(0, 1, (gan_batch_size, 1))
                            onembf = 1 - bf
                            my_label = np.concatenate((bf, onembf), axis=1)
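                            # Blend mode: each training example is a convex mix
                            # of a mocap frame and a policy frame, and the soft
                            # label [bf, 1 - bf] records the mixing weight, so
                            # the discriminator regresses the real/fake
                            # proportion instead of a hard class.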
                            my_data = np.multiply(mbatch['data'],
                                                  bf) + np.multiply(
                                                      batch['data'], onembf)
                            _, l = sess.run([disc.optimizer_first, disc.loss],
                                            feed_dict={
                                                disc.input: my_data,
                                                disc.label: my_label
                                            })
                            i = i + 1
                            # Display logs per step
                            if i % gan_display_step == 0 or i == 1:
                                print(
                                    'discriminator epoch %i Step %i: Minibatch Loss: %f'
                                    % (e, i, l))
                        print(
                            'discriminator epoch %i Step %i: Minibatch Loss: %f'
                            % (e, i, l))
                if seg['mean_ext_rew'] > gan_reward_to_retrain_discriminator:
                    for e in range(gan_num_epochs):
                        i = 0
                        for mbatch in mocap_set.iterate_once(gan_batch_size):
                            data = replay_buf.sample(mbatch['data'].shape[0])

                            bf = np.random.uniform(0, 1, (gan_batch_size, 1))
                            onembf = 1 - bf
                            my_label = np.concatenate((bf, onembf), axis=1)
                            my_data = np.multiply(mbatch['data'],
                                                  bf) + np.multiply(
                                                      data, onembf)

                            _, l = sess.run([disc.optimizer_first, disc.loss],
                                            feed_dict={
                                                disc.input: my_data,
                                                disc.label: my_label
                                            })
                            i = i + 1
                            # Display logs per step
                            if i % gan_display_step == 0 or i == 1:
                                print(
                                    'discriminator epoch %i Step %i: Minibatch Loss: %f'
                                    % (e, i, l))
                        print(
                            'discriminator epoch %i Step %i: Minibatch Loss: %f'
                            % (e, i, l))

        # if True:
        #     lb = np.concatenate((np.zeros((extra.shape[0],1)),np.ones((extra.shape[0],1))),axis=1)
        #     extra_set = Dataset(dict(data=extra,label=lb), shuffle=True)
        #     num_r = 1
        #     if iters_so_far == 1:
        #         num_r = gan_num_epochs
        #     for e in range(num_r):
        #         i = 0
        #         for batch in extra_set.iterate_once(gan_batch_size):
        #             mbatch = mocap_set.next_batch(gan_batch_size)
        #             _, l = sess.run([disc.optimizer, disc.loss], feed_dict={disc.input: np.concatenate((mbatch['data'],batch['data'])), disc.label: np.concatenate((mbatch['label'],batch['label']))})
        #             i = i + 1
        #             # Display logs per step
        #             if i % gan_display_step == 0 or i == 1:
        #                 print('discriminator epoch %i Step %i: Minibatch Loss: %f' % (e, i, l))
        #         print('discriminator epoch %i Step %i: Minibatch Loss: %f' % (e, i, l))

        if not use_distance:
            if iters_so_far % 100 == 0:
                with open(
                        logdir + "\\" + 'replay_buf_' + str(iters_so_far) +
                        '.pkl', 'wb') as f:
                    pickle.dump(replay_buf, f)

        with open(logdir + "\\" + 'ob_list_' + str(rank) + '.pkl', 'wb') as f:
            pickle.dump(ob_list, f)
        if MPI.COMM_WORLD.Get_rank() == 0:
            logF.write(str(rewmean) + "\n")
            logR.write(str(seg['mean_ext_rew']) + "\n")
            logStats.write(logger.get_str() + "\n")
            logF.flush()
            logStats.flush()
            logR.flush()

            logger.dump_tabular()

            try:
                os.remove(logdir + "/checkpoint")
            except OSError:
                pass
            try:
                saver.save(tf.get_default_session(),
                           os.path.join(logdir, agentName),
                           global_step=iters_so_far)
            except:
                pass
            try:
                non_disc_saver.save(tf.get_default_session(),
                                    os.path.join(logdir,
                                                 agentName + "_non_disc"),
                                    global_step=iters_so_far)
            except:
                pass
            try:
                disc_saver.save(tf.get_default_session(),
                                os.path.join(logdir, agentName + "_disc"),
                                global_step=iters_so_far)
            except:
                pass
Example #4
    def replay(self, seg_list, batch_size):
        print(self.scope + " training")

        if self.schedule == 'constant':
            cur_lrmult = 1.0
        elif self.schedule == 'linear':
            cur_lrmult = max(
                1.0 - float(self.timesteps_so_far) / self.max_timesteps, 0)

        # Here we do a bunch of optimization epochs over the data
        # Batch strategy: in each epoch, compute the gradient g for every
        # collected segment, average the gradients, then apply one optimizer
        # update; repeat for several epochs.
        newlosses_list = []
        logger.log("Optimizing...")
        loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]
        logger.log(fmt_row(13, loss_names))
        for _ in range(self.optim_epochs):
            g_list = []
            for seg in seg_list:

                self.add_vtarg_and_adv(seg, self.gamma, self.lam)

                # print(seg)

                # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
                ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg[
                    "adv"], seg["tdlamret"]
                vpredbefore = seg[
                    "vpred"]  # predicted value function before update
                atarg = (atarg - atarg.mean()) / atarg.std(
                )  # standardized advantage function estimate
                d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                            shuffle=not self.pi.recurrent)

                if hasattr(self.pi, "ob_rms"):
                    self.pi.ob_rms.update(
                        ob)  # update running mean/std for policy

                self.assign_old_eq_new(
                )  # set old parameter values to new parameter values

                # take the full batch of all transitions at once
                batch = d.next_batch(d.n)
                # print("ob", batch["ob"], "ac", batch["ac"], "atarg", batch["atarg"], "vtarg", batch["vtarg"])
                *newlosses, debug_atarg, pi_ac, opi_ac, vpred, pi_pd, opi_pd, kl_oldnew, total_loss, var_list, grads, g = \
                    self.lossandgrad(batch["ob"], batch["ac"], batch["atarg"], batch["vtarg"], cur_lrmult)
                # print("debug_atarg", debug_atarg, "pi_ac", pi_ac, "opi_ac", opi_ac, "vpred", vpred, "pi_pd", pi_pd,
                #       "opi_pd", opi_pd, "kl_oldnew", kl_oldnew, "var_mean", np.mean(g), "total_loss", total_loss)
                if np.isnan(np.mean(g)):
                    print('output nan, ignore it!')
                else:
                    g_list.append(g)
                    newlosses_list.append(newlosses)

            # average the per-segment gradients, then apply a single optimizer update
            if len(g_list) > 0:
                avg_g = np.mean(g_list, axis=0)
                self.adam.update(avg_g, self.optim_stepsize * cur_lrmult)
                logger.log(fmt_row(13, np.mean(newlosses_list, axis=0)))

        logger.log("Evaluating losses...")
        losses = []
        for seg in seg_list:
            self.add_vtarg_and_adv(seg, self.gamma, self.lam)

            # print(seg)

            # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
            ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
                "tdlamret"]
            vpredbefore = seg[
                "vpred"]  # predicted value function before update
            atarg = (atarg - atarg.mean()) / atarg.std(
            )  # standardized advantage function estimate
            d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                        shuffle=not self.pi.recurrent)
            # take the full batch of all transitions at once
            batch = d.next_batch(d.n)
            newlosses = self.compute_losses(batch["ob"], batch["ac"],
                                            batch["atarg"], batch["vtarg"],
                                            cur_lrmult)
            losses.append(newlosses)
        print(losses)

        meanlosses, _, _ = mpi_moments(losses, axis=0)
        logger.log(fmt_row(13, meanlosses))
        for (lossval, name) in zipsame(meanlosses, loss_names):
            if np.isinf(lossval):
                debug = True
            logger.record_tabular("loss_" + name, lossval)
        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))
        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(self.flatten_lists, zip(*listoflrpairs))
        self.lenbuffer.extend(lens)
        self.rewbuffer.extend(rews)
        last_rew = self.rewbuffer[-1] if len(self.rewbuffer) > 0 else 0
        logger.record_tabular("LastRew", last_rew)
        logger.record_tabular(
            "LastLen", 0 if len(self.lenbuffer) <= 0 else self.lenbuffer[-1])
        logger.record_tabular("EpLenMean", np.mean(self.lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(self.rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        self.episodes_so_far += len(lens)
        self.timesteps_so_far += sum(lens)
        self.iters_so_far += 1
        logger.record_tabular("EpisodesSoFar", self.episodes_so_far)
        logger.record_tabular("TimestepsSoFar", self.timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - self.tstart)
        logger.record_tabular("IterSoFar", self.iters_so_far)
        logger.record_tabular("CalulateActions", self.act_times)
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.dump_tabular()