def learn(
        args,
        env,
        policy_fn,
        *,
        timesteps_per_actorbatch,  # timesteps per actor per update
        clip_param,  # clipping parameter epsilon
        entcoeff,  # entropy coefficient
        optim_epochs,
        optim_stepsize,
        optim_batchsize,  # optimization hypers
        gamma,
        lam,  # advantage estimation
        max_timesteps=0,
        max_episodes=0,
        max_iters=0,
        max_seconds=0,  # time constraint
        callback=None,  # you can do anything in the callback, since it takes locals(), globals()
        adam_epsilon=1e-5,
        schedule='constant',  # annealing for stepsize parameters (epsilon and adam)
        writer=None):
    print("\nBeginning learning...\n")

    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_fn("pi", ob_space,
                   ac_space)  # Construct network for new policy
    oldpi = policy_fn("oldpi", ob_space, ac_space)  # Network for old policy
    atarg = tf.compat.v1.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.compat.v1.placeholder(dtype=tf.float32,
                                   shape=[None])  # Empirical return

    lrmult = tf.compat.v1.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule
    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    ob = {}
    ob['adj'] = U.get_placeholder_cached(name="adj")
    ob['node'] = U.get_placeholder_cached(name="node")

    ob_gen = {}
    ob_gen['adj'] = U.get_placeholder(
        shape=[None, ob_space['adj'].shape[0], None, None],
        dtype=tf.float32,
        name='adj_gen')
    ob_gen['node'] = U.get_placeholder(
        shape=[None, 1, None, ob_space['node'].shape[2]],
        dtype=tf.float32,
        name='node_gen')

    ob_real = {}
    ob_real['adj'] = U.get_placeholder(
        shape=[None, ob_space['adj'].shape[0], None, None],
        dtype=tf.float32,
        name='adj_real')
    ob_real['node'] = U.get_placeholder(
        shape=[None, 1, None, ob_space['node'].shape[2]],
        dtype=tf.float32,
        name='node_real')

    ac = tf.compat.v1.placeholder(dtype=tf.int64,
                                  shape=[None, 4],
                                  name='ac_real')

    ## PPO loss
    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    pol_entpen = (-entcoeff) * meanent

    pi_logp = pi.pd.logp(ac)
    oldpi_logp = oldpi.pd.logp(ac)
    ratio_log = pi_logp - oldpi_logp

    ratio = tf.exp(ratio_log)  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = tf.clip_by_value(ratio, 1.0 - clip_param,
                             1.0 + clip_param) * atarg  # clipped surrogate
    pol_surr = -tf.reduce_mean(tf.minimum(
        surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)
    vf_loss = tf.reduce_mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    ## Expert loss
    loss_expert = -tf.reduce_mean(pi_logp)

    ## Discriminator loss
    step_pred_real, step_logit_real = discriminator_net(ob_real,
                                                        args,
                                                        name='d_step')
    step_pred_gen, step_logit_gen = discriminator_net(ob_gen,
                                                      args,
                                                      name='d_step')
    loss_d_step_real = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            logits=step_logit_real,
            labels=tf.ones_like(step_logit_real) * 0.9))
    loss_d_step_gen = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            logits=step_logit_gen, labels=tf.zeros_like(step_logit_gen)))
    loss_d_step = loss_d_step_real + loss_d_step_gen
    if args.gan_type == 'normal':
        loss_g_step_gen = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=step_logit_gen, labels=tf.zeros_like(step_logit_gen)))
    elif args.gan_type == 'recommend':
        loss_g_step_gen = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=step_logit_gen,
                labels=tf.ones_like(step_logit_gen) * 0.9))
    elif args.gan_type == 'wgan':
        loss_d_step, _, _ = discriminator(ob_real, ob_gen, args, name='d_step')
        loss_d_step = loss_d_step * -1
        loss_g_step_gen, _ = discriminator_net(ob_gen, args, name='d_step')
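    # Note on the gan_type branches above: real samples are trained against a
    # smoothed target of 0.9 (one-sided label smoothing). Since
    # sigmoid_cross_entropy_with_logits(logits=z, labels=y) = -y*log(D) - (1-y)*log(1-D)
    # with D = sigmoid(z), 'normal' scores generated graphs with -log(1 - D(G)),
    # 'recommend' with roughly -log D(G) via the 0.9 target, and 'wgan' replaces
    # both with the critic outputs of discriminator()/discriminator_net(),
    # negating the critic objective for the discriminator update.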

    final_pred_real, final_logit_real = discriminator_net(ob_real,
                                                          args,
                                                          name='d_final')
    final_pred_gen, final_logit_gen = discriminator_net(ob_gen,
                                                        args,
                                                        name='d_final')
    loss_d_final_real = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            logits=final_logit_real,
            labels=tf.ones_like(final_logit_real) * 0.9))
    loss_d_final_gen = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            logits=final_logit_gen, labels=tf.zeros_like(final_logit_gen)))
    loss_d_final = loss_d_final_real + loss_d_final_gen
    if args.gan_type == 'normal':
        loss_g_final_gen = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=final_logit_gen, labels=tf.zeros_like(final_logit_gen)))
    elif args.gan_type == 'recommend':
        loss_g_final_gen = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=final_logit_gen,
                labels=tf.ones_like(final_logit_gen) * 0.9))
    elif args.gan_type == 'wgan':
        loss_d_final, _, _ = discriminator(ob_real,
                                           ob_gen,
                                           args,
                                           name='d_final')
        loss_d_final = loss_d_final * -1
        loss_g_final_gen, _ = discriminator_net(ob_gen, args, name='d_final')

    var_list_pi = pi.get_trainable_variables()
    var_list_pi_stop = [
        var for var in var_list_pi
        if ('emb' in var.name) or ('gcn' in var.name) or ('stop' in var.name)
    ]
    var_list_d_step = [
        var for var in tf.compat.v1.global_variables() if 'd_step' in var.name
    ]
    var_list_d_final = [
        var for var in tf.compat.v1.global_variables() if 'd_final' in var.name
    ]

    ## debug
    debug = {}

    ## loss update function
    lossandgrad_ppo = U.function([
        ob['adj'], ob['node'], ac, pi.ac_real, oldpi.ac_real, atarg, ret,
        lrmult
    ], losses + [U.flatgrad(total_loss, var_list_pi)])
    lossandgrad_expert = U.function(
        [ob['adj'], ob['node'], ac, pi.ac_real],
        [loss_expert, U.flatgrad(loss_expert, var_list_pi)])
    lossandgrad_expert_stop = U.function(
        [ob['adj'], ob['node'], ac, pi.ac_real],
        [loss_expert, U.flatgrad(loss_expert, var_list_pi_stop)])
    lossandgrad_d_step = U.function(
        [ob_real['adj'], ob_real['node'], ob_gen['adj'], ob_gen['node']],
        [loss_d_step, U.flatgrad(loss_d_step, var_list_d_step)])
    lossandgrad_d_final = U.function(
        [ob_real['adj'], ob_real['node'], ob_gen['adj'], ob_gen['node']],
        [loss_d_final,
         U.flatgrad(loss_d_final, var_list_d_final)])
    loss_g_gen_step_func = U.function([ob_gen['adj'], ob_gen['node']],
                                      loss_g_step_gen)
    loss_g_gen_final_func = U.function([ob_gen['adj'], ob_gen['node']],
                                       loss_g_final_gen)

    adam_pi = MpiAdam(var_list_pi, epsilon=adam_epsilon)
    adam_pi_stop = MpiAdam(var_list_pi_stop, epsilon=adam_epsilon)
    adam_d_step = MpiAdam(var_list_d_step, epsilon=adam_epsilon)
    adam_d_final = MpiAdam(var_list_d_final, epsilon=adam_epsilon)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.compat.v1.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])

    compute_losses = U.function([
        ob['adj'], ob['node'], ac, pi.ac_real, oldpi.ac_real, atarg, ret,
        lrmult
    ], losses)

    # Prepare for rollouts
    # ----------------------------------------
    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    lenbuffer_valid = deque(maxlen=100)  # rolling buffer for valid-episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for total episode rewards
    rewbuffer_env = deque(maxlen=100)  # rolling buffer for environment rewards
    rewbuffer_d_step = deque(maxlen=100)  # rolling buffer for step-discriminator rewards
    rewbuffer_d_final = deque(maxlen=100)  # rolling buffer for final-discriminator rewards
    rewbuffer_final = deque(maxlen=100)  # rolling buffer for final-step rewards
    rewbuffer_final_stat = deque(maxlen=100)  # rolling buffer for final reward statistics

    seg_gen = traj_segment_generator(args, pi, env, timesteps_per_actorbatch,
                                     True, loss_g_gen_step_func,
                                     loss_g_gen_final_func)

    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"
    if args.load == 1:
        try:
            fname = './ckpt/' + args.name_full_load
            sess = tf.compat.v1.get_default_session()
            # sess.run(tf.compat.v1.global_variables_initializer())
            saver = tf.compat.v1.train.Saver(var_list_pi)
            saver.restore(sess, fname)
            iters_so_far = int(fname.split('_')[-1]) + 1
            print('model restored!', fname, 'iters_so_far:', iters_so_far)
        except Exception:
            print(fname, 'ckpt not found, starting from iteration 0')

    U.initialize()
    adam_pi.sync()
    adam_pi_stop.sync()
    adam_d_step.sync()
    adam_d_final.sync()

    counter = 0
    level = 0
    ## start training
    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        # logger.log("********** Iteration %i ************"%iters_so_far)

        seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)
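        # add_vtarg_and_adv (assumed to be the usual baselines GAE helper) fills
        # seg["adv"] with GAE(lambda) advantages and seg["tdlamret"] = adv + vpred:
        #   delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_{t+1}) - V(s_t)
        #   A_t     = delta_t + gamma * lam * (1 - done_{t+1}) * A_{t+1}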
        ob_adj, ob_node, ac, atarg, tdlamret = seg["ob_adj"], seg[
            "ob_node"], seg["ac"], seg["adv"], seg["tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update
        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate
        d = Dataset(dict(ob_adj=ob_adj,
                         ob_node=ob_node,
                         ac=ac,
                         atarg=atarg,
                         vtarg=tdlamret),
                    shuffle=not pi.recurrent)
        optim_batchsize = optim_batchsize or ob_adj.shape[0]

        # inner training loop, train policy
        for i_optim in range(optim_epochs):

            loss_expert = 0
            loss_expert_stop = 0
            g_expert = 0
            g_expert_stop = 0

            loss_d_step = 0
            loss_d_final = 0
            g_ppo = 0
            g_d_step = 0
            g_d_final = 0

            pretrain_shift = 5  # extra iterations of expert pretraining / delay before the generator update starts
            ## Expert
            if args.expert_start <= iters_so_far <= args.expert_end + pretrain_shift:
                ## Expert train
                # learn how to stop
                ob_expert, ac_expert = env.get_expert(optim_batchsize)
                loss_expert, g_expert = lossandgrad_expert(
                    ob_expert['adj'], ob_expert['node'], ac_expert, ac_expert)
                loss_expert = np.mean(loss_expert)

            ## PPO
            if args.rl_start <= iters_so_far <= args.rl_end:
                assign_old_eq_new()  # set old parameter values to new parameter values
                batch = d.next_batch(optim_batchsize)
                # ppo
                if iters_so_far >= args.rl_start + pretrain_shift:  # start the generator only after the discriminator has been trained for a while
                    *newlosses, g_ppo = lossandgrad_ppo(
                        batch["ob_adj"], batch["ob_node"], batch["ac"],
                        batch["ac"], batch["ac"], batch["atarg"],
                        batch["vtarg"], cur_lrmult)
                    losses_ppo = newlosses

                if args.has_d_step == 1 and i_optim >= optim_epochs // 2:
                    # update step discriminator
                    ob_expert, _ = env.get_expert(
                        optim_batchsize,
                        curriculum=args.curriculum,
                        level_total=args.curriculum_num,
                        level=level)
                    loss_d_step, g_d_step = lossandgrad_d_step(
                        ob_expert["adj"], ob_expert["node"], batch["ob_adj"],
                        batch["ob_node"])
                    adam_d_step.update(g_d_step, optim_stepsize * cur_lrmult)
                    loss_d_step = np.mean(loss_d_step)

                if args.has_d_final == 1 and i_optim >= optim_epochs // 4 * 3:
                    # update final discriminator
                    ob_expert, _ = env.get_expert(
                        optim_batchsize,
                        is_final=True,
                        curriculum=args.curriculum,
                        level_total=args.curriculum_num,
                        level=level)
                    seg_final_adj, seg_final_node = traj_final_generator(
                        pi, copy.deepcopy(env), optim_batchsize, True)
                    # update final discriminator
                    loss_d_final, g_d_final = lossandgrad_d_final(
                        ob_expert["adj"], ob_expert["node"], seg_final_adj,
                        seg_final_node)
                    adam_d_final.update(g_d_final, optim_stepsize * cur_lrmult)

            # update generator
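            # The generator (policy) step mixes the PPO gradient and the expert
            # imitation gradient with fixed weights (0.2 and 0.05); whichever
            # phase is inactive this iteration contributes a zero gradient.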
            adam_pi.update(0.2 * g_ppo + 0.05 * g_expert,
                           optim_stepsize * cur_lrmult)

        # WGAN
        # if args.has_d_step == 1:
        #     clip_D = [p.assign(tf.clip_by_value(p, -0.01, 0.01)) for p in var_list_d_step]
        # if args.has_d_final == 1:
        #     clip_D = [p.assign(tf.clip_by_value(p, -0.01, 0.01)) for p in var_list_d_final]
        #

        ## PPO val
        # if iters_so_far >= args.rl_start and iters_so_far <= args.rl_end:
        # logger.log("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob_adj"], batch["ob_node"],
                                       batch["ac"], batch["ac"], batch["ac"],
                                       batch["atarg"], batch["vtarg"],
                                       cur_lrmult)
            losses.append(newlosses)
        meanlosses, _, _ = mpi_moments(losses, axis=0)
        # logger.log(fmt_row(13, meanlosses))

        if writer is not None:
            writer.add_scalar("loss_expert", loss_expert, iters_so_far)
            writer.add_scalar("loss_expert_stop", loss_expert_stop,
                              iters_so_far)
            writer.add_scalar("loss_d_step", loss_d_step, iters_so_far)
            writer.add_scalar("loss_d_final", loss_d_final, iters_so_far)
            writer.add_scalar('grad_expert_min', np.amin(g_expert),
                              iters_so_far)
            writer.add_scalar('grad_expert_max', np.amax(g_expert),
                              iters_so_far)
            writer.add_scalar('grad_expert_norm', np.linalg.norm(g_expert),
                              iters_so_far)
            writer.add_scalar('grad_expert_stop_min', np.amin(g_expert_stop),
                              iters_so_far)
            writer.add_scalar('grad_expert_stop_max', np.amax(g_expert_stop),
                              iters_so_far)
            writer.add_scalar('grad_expert_stop_norm',
                              np.linalg.norm(g_expert_stop), iters_so_far)
            writer.add_scalar('grad_rl_min', np.amin(g_ppo), iters_so_far)
            writer.add_scalar('grad_rl_max', np.amax(g_ppo), iters_so_far)
            writer.add_scalar('grad_rl_norm', np.linalg.norm(g_ppo),
                              iters_so_far)
            writer.add_scalar('g_d_step_min', np.amin(g_d_step), iters_so_far)
            writer.add_scalar('g_d_step_max', np.amax(g_d_step), iters_so_far)
            writer.add_scalar('g_d_step_norm', np.linalg.norm(g_d_step),
                              iters_so_far)
            writer.add_scalar('g_d_final_min', np.amin(g_d_final),
                              iters_so_far)
            writer.add_scalar('g_d_final_max', np.amax(g_d_final),
                              iters_so_far)
            writer.add_scalar('g_d_final_norm', np.linalg.norm(g_d_final),
                              iters_so_far)
            writer.add_scalar('learning_rate', optim_stepsize * cur_lrmult,
                              iters_so_far)

        for (lossval, name) in zipsame(meanlosses, loss_names):
            # logger.record_tabular("loss_"+name, lossval)
            if writer is not None:
                writer.add_scalar("loss_" + name, lossval, iters_so_far)
        # logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))
        if writer is not None:
            writer.add_scalar("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret),
                              iters_so_far)
        lrlocal = (seg["ep_lens"], seg["ep_lens_valid"], seg["ep_rets"],
                   seg["ep_rets_env"], seg["ep_rets_d_step"],
                   seg["ep_rets_d_final"], seg["ep_final_rew"],
                   seg["ep_final_rew_stat"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, lens_valid, rews, rews_env, rews_d_step, rews_d_final, rews_final, rews_final_stat = map(
            flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        lenbuffer_valid.extend(lens_valid)
        rewbuffer.extend(rews)
        rewbuffer_d_step.extend(rews_d_step)
        rewbuffer_d_final.extend(rews_d_final)
        rewbuffer_env.extend(rews_env)
        rewbuffer_final.extend(rews_final)
        rewbuffer_final_stat.extend(rews_final_stat)
        # logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        # logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        # logger.record_tabular("EpThisIter", len(lens))
        if writer is not None:
            writer.add_scalar("EpLenMean", np.mean(lenbuffer), iters_so_far)
            writer.add_scalar("EpLenValidMean", np.mean(lenbuffer_valid),
                              iters_so_far)
            writer.add_scalar("EpRewMean", np.mean(rewbuffer), iters_so_far)
            writer.add_scalar("EpRewDStepMean", np.mean(rewbuffer_d_step),
                              iters_so_far)
            writer.add_scalar("EpRewDFinalMean", np.mean(rewbuffer_d_final),
                              iters_so_far)
            writer.add_scalar("EpRewEnvMean", np.mean(rewbuffer_env),
                              iters_so_far)
            writer.add_scalar("EpRewFinalMean", np.mean(rewbuffer_final),
                              iters_so_far)
            writer.add_scalar("EpRewFinalStatMean",
                              np.mean(rewbuffer_final_stat), iters_so_far)
            writer.add_scalar("EpThisIter", len(lens), iters_so_far)
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        # logger.record_tabular("EpisodesSoFar", episodes_so_far)
        # logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        # logger.record_tabular("TimeElapsed", time.time() - tstart)
        if writer is not None:
            writer.add_scalar("EpisodesSoFar", episodes_so_far, iters_so_far)
            writer.add_scalar("TimestepsSoFar", timesteps_so_far, iters_so_far)
            writer.add_scalar("TimeElapsed",
                              time.time() - tstart, iters_so_far)

        if MPI.COMM_WORLD.Get_rank() == 0:
            with open('molecule_gen/' + args.name_full + '.csv', 'a') as f:
                f.write('***** Iteration {} *****\n'.format(iters_so_far))
            # save
            if iters_so_far % args.save_every == 0:
                fname = './ckpt/' + args.name_full + '_' + str(iters_so_far)
                saver = tf.compat.v1.train.Saver(var_list_pi)
                saver.save(tf.compat.v1.get_default_session(), fname)
                print('model saved!', fname)
                # fname = os.path.join(ckpt_dir, task_name)
                # os.makedirs(os.path.dirname(fname), exist_ok=True)
                # saver = tf.train.Saver()
                # saver.save(tf.get_default_session(), fname)
            # if iters_so_far==args.load_step:
        iters_so_far += 1
        counter += 1
        # advance the curriculum level once every curriculum_step iterations
        if counter % args.curriculum_step == 0 and counter // args.curriculum_step < args.curriculum_num:
            level += 1
Example #2
    def learn(self):
        # Prepare for rollouts
        # ----------------------------------------
        seg_gen = traj_segment_generator(self.pi,
                                         self.env,
                                         self.timesteps_per_actorbatch,
                                         stochastic=True)
        episodes_so_far = 0
        timesteps_so_far = 0
        iters_so_far = 0
        tstart = time.time()
        lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
        rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards
        assert sum([
            self.max_iters > 0, self.max_timesteps > 0, self.max_episodes > 0,
            self.max_seconds > 0
        ]) == 1, "Only one time constraint permitted"
        while True:
            if (timesteps_so_far >= self.max_timesteps) and self.max_timesteps:
                break
            elif (episodes_so_far >= self.max_episodes) and self.max_episodes:
                break
            elif (iters_so_far >= self.max_iters) and self.max_iters:
                break
            elif self.max_seconds and (time.time() - tstart >=
                                       self.max_seconds):
                break

            if self.schedule == 'constant':
                cur_lrmult = 1.0
            elif self.schedule == 'linear':
                cur_lrmult = max(
                    1.0 - float(timesteps_so_far) / self.max_timesteps, 0)
            else:
                raise NotImplementedError

            logger.log("********** Iteration %i ************" % iters_so_far)

            seg = seg_gen.__next__()
            add_vtarg_and_adv(seg, self.gamma, self.lam)

            # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
            self.ob, self.ac, self.atarg, tdlamret = seg["ob"], seg["ac"], seg[
                "adv"], seg["tdlamret"]

            vpredbefore = seg["vpred"]  # predicted value function before update
            self.atarg = (self.atarg - self.atarg.mean()) / self.atarg.std(
            )  # standardized advantage function estimate
            d = Dataset(dict(ob=self.ob,
                             ac=self.ac,
                             atarg=self.atarg,
                             vtarg=tdlamret),
                        shuffle=not self.pi.recurrent)
            self.optim_batchsize = self.optim_batchsize or self.ob.shape[0]

            if hasattr(self.pi, "ob_rms"):
                self.pi.ob_rms.update(
                    self.ob)  # update running mean/std for policy

            self.assign_old_eq_new(
            )  # set old parameter values to new parameter values
            logger.log("Optimizing...")
            logger.log(fmt_row(13, self.loss_names))
            # Here we do a bunch of optimization epochs over the data
            for _ in range(self.optim_epochs):
                losses = [
                ]  # list of tuples, each of which gives the loss for a minibatch
                for batch in d.iterate_once(self.optim_batchsize):
                    *newlosses, g = self.lossandgrad(batch["ob"], batch["ac"],
                                                     batch["atarg"],
                                                     batch["vtarg"],
                                                     cur_lrmult)
                    self.adam.update(g, self.optim_stepsize * cur_lrmult)
                    losses.append(newlosses)
                logger.log(fmt_row(13, np.mean(losses, axis=0)))

            logger.log("Evaluating losses...")
            losses = []
            for batch in d.iterate_once(self.optim_batchsize):
                newlosses = self.compute_losses(batch["ob"], batch["ac"],
                                                batch["atarg"], batch["vtarg"],
                                                cur_lrmult)
                losses.append(newlosses)
            meanlosses, _, _ = mpi_moments(losses, axis=0)
            logger.log(fmt_row(13, meanlosses))
            for (lossval, name) in zipsame(meanlosses, self.loss_names):
                logger.record_tabular("loss_" + name, lossval)
            logger.record_tabular("ev_tdlam_before",
                                  explained_variance(vpredbefore, tdlamret))
            lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
            listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
            lens, rews = map(flatten_lists, zip(*listoflrpairs))
            lenbuffer.extend(lens)
            rewbuffer.extend(rews)
            logger.record_tabular("EpLenMean", np.mean(lenbuffer))
            logger.record_tabular("EpRewMean", np.mean(rewbuffer))
            logger.record_tabular("EpThisIter", len(lens))
            episodes_so_far += len(lens)
            timesteps_so_far += sum(lens)
            iters_so_far += 1
            logger.record_tabular("EpisodesSoFar", episodes_so_far)
            logger.record_tabular("TimestepsSoFar", timesteps_so_far)
            logger.record_tabular("TimeElapsed", time.time() - tstart)
            if MPI.COMM_WORLD.Get_rank() == 0:
                logger.dump_tabular()
Example #3
    def run(self):
        # shift forward
        if (len(self.mb_stuff[2]) >= self.nsteps +
                self.num_steps_to_cut_left + self.num_steps_to_cut_right):
            self.mb_stuff = [l[self.nsteps:] for l in self.mb_stuff]

        mb_obs, mb_increase_ent, mb_rewards, mb_reward_avg, mb_actions, mb_values, mb_valids, mb_random_resets, \
            mb_dones, mb_neglogpacs, mb_states = self.mb_stuff
        epinfos = []
        while (len(mb_rewards) < self.nsteps +
               self.num_steps_to_cut_left + self.num_steps_to_cut_right):
            actions, values, states, neglogpacs = self.model.step(
                mb_obs[-1], mb_states[-1], mb_dones[-1], mb_increase_ent[-1])
            mb_actions.append(actions)
            mb_values.append(values)
            mb_states.append(states)
            mb_neglogpacs.append(neglogpacs)

            obs, rewards, dones, infos = self.env.step(actions)
            mb_obs.append(np.asarray(obs, dtype=self.model.train_model.X.dtype.name))
            mb_increase_ent.append(
                np.asarray(
                    [info.get('increase_entropy', False) for info in infos],
                    dtype=np.uint8))
            mb_rewards.append(rewards)
            mb_dones.append(dones)
            mb_valids.append([
                (not info.get('replay_reset.invalid_transition', False))
                for info in infos
            ])
            mb_random_resets.append(
                np.array([
                    info.get('replay_reset.random_reset', False)
                    for info in infos
                ]))

            for info in infos:
                maybeepinfo = info.get('episode')
                if maybeepinfo: epinfos.append(maybeepinfo)

        # GAE
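        # Masked GAE over the rolling window: mb_valids zeroes out the first
        # num_steps_to_cut_left warm-up steps, use_next stops bootstrapping
        # across episode ends (dones), and adv_mask stops advantage propagation
        # across random replay resets.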
        mb_advs = [np.zeros_like(mb_values[0])] * (len(mb_rewards) + 1)
        for t in reversed(range(len(mb_rewards))):
            if t < self.num_steps_to_cut_left:
                mb_valids[t] = np.zeros_like(mb_valids[t])
            else:
                if t == len(mb_values) - 1:
                    next_value = self.model.value(mb_obs[-1], mb_states[-1],
                                                  mb_dones[-1])
                else:
                    next_value = mb_values[t + 1]
                use_next = np.logical_not(mb_dones[t + 1])
                adv_mask = np.logical_not(mb_random_resets[t + 1])
                delta = mb_rewards[
                    t] + self.gamma * use_next * next_value - mb_values[t]
                mb_advs[t] = adv_mask * (
                    delta + self.gamma * self.lam * use_next * mb_advs[t + 1])

        # extract arrays
        end = self.nsteps + self.num_steps_to_cut_left
        ar_mb_obs = np.asarray(mb_obs[:end],
                               dtype=self.model.train_model.X.dtype.name)
        ar_mb_ent = np.stack(mb_increase_ent[:end], axis=0)
        ar_mb_valids = np.asarray(mb_valids[:end], dtype=np.float32)
        ar_mb_actions = np.asarray(mb_actions[:end])
        ar_mb_values = np.asarray(mb_values[:end], dtype=np.float32)
        ar_mb_neglogpacs = np.asarray(mb_neglogpacs[:end], dtype=np.float32)
        ar_mb_dones = np.asarray(mb_dones[:end], dtype=bool)
        ar_mb_advs = np.asarray(mb_advs[:end], dtype=np.float32)
        ar_mb_rets = ar_mb_values + ar_mb_advs

        if self.norm_adv:
            adv_mean, adv_std, _ = mpi_moments(ar_mb_advs.ravel())
            ar_mb_advs = (ar_mb_advs - adv_mean) / (adv_std + 1e-7)

        # obs, increase_ent, advantages, masks, actions, values, neglogpacs, valids, returns, states, epinfos = runner.run()
        return (*map(
            sf01,
            (ar_mb_obs, ar_mb_ent, ar_mb_advs, ar_mb_dones, ar_mb_actions,
             ar_mb_values, ar_mb_neglogpacs, ar_mb_valids, ar_mb_rets)),
                mb_states[0], epinfos)
Example #4
def learn(env_list, policy_fn, *,
          timesteps_per_actorbatch,  # timesteps per actor per update
          clip_param, entcoeff,  # clipping parameter epsilon, entropy coeff
          optim_epochs, optim_stepsize, optim_batchsize,  # optimization hypers
          gamma, lam,  # advantage estimation
          max_timesteps=0, max_episodes=0, max_iters=0, max_seconds=0,  # time constraint
          callback=None,  # you can do anything in the callback, since it takes locals(), globals()
          adam_epsilon=1e-5,
          schedule='constant',  # annealing for stepsize parameters (epsilon and adam)
          end_timesteps,
          newround
          ):

    env = env_list.popleft()
    # Open a file to record the accumulated rewards
    rewFile = open("reward/%d.txt" % (env.seed), "ab")
    resptimeFile = open("respTime/%d.txt" % (env.seed), "ab")
    pktnumFile = open("pktNum/%d.txt" % (env.seed), "ab")

    # Setup losses and stuff
    # ----------------------------------------
    vf_ob_space = env.vf_observation_space
    # ac_ob_space = env.ac_observation_space
    ac_space = env.action_space
    pi = policy_fn("pi1", vf_ob_space, ac_space)  # Construct network for new policy
    oldpi = policy_fn("oldpi", vf_ob_space, ac_space)  # Network for old policy
    atarg = tf.placeholder(name="atarg", dtype=tf.float32, shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(name="ret", dtype=tf.float32, shape=[None])  # Empirical return

    lrmult = tf.placeholder(name='lrmult', dtype=tf.float32, shape=[])  # learning rate multiplier, updated with schedule
    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    vf_ob = U.get_placeholder_cached(name="vf_ob")
    nn_in = U.get_placeholder_cached(name="nn_in")  # placeholder for nn input
    ac = pi.pdtype.sample_placeholder([None])

    # kloldnew = oldpi.pd.kl(pi.pd)
    # ent = pi.pd.entropy()
    pb_old_holder = tf.placeholder(name="pd_old", dtype=tf.float32, shape=[None, ac_space.n])
    pb_new_holder = tf.placeholder(name="pd_new", dtype=tf.float32, shape=[None, ac_space.n])
    oldpd = CategoricalPd(pb_old_holder)
    pd = CategoricalPd(pb_new_holder)
    kloldnew = oldpd.kl(pd)
    ent = pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    pol_entpen = (-entcoeff) * meanent

    # ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    ratio = tf.placeholder(dtype=tf.float32, shape=[None])
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = tf.clip_by_value(ratio, 1.0 - clip_param, 1.0 + clip_param) * atarg  # clipped surrogate
    pol_surr = - tf.reduce_mean(tf.minimum(surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)
    vf_loss = tf.reduce_mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    vf_var_list = [v for v in var_list if v.name.split("/")[1].startswith("vf")]
    pol_var_list = [v for v in var_list if v.name.split("/")[1].startswith("pol")]

    vf_grad = U.function([vf_ob, ret], U.flatgrad(vf_loss, vf_var_list))  # gradient of value function
    pol_nn_grad = U.function([nn_in], U.flatgrad(pi.nn_out, pol_var_list))
    vf_adam = MpiAdam(vf_var_list, epsilon=adam_epsilon)
    pol_adam = MpiAdam(pol_var_list, epsilon=adam_epsilon)
    clip_para = U.function([lrmult], [clip_param])

    assign_old_eq_new = U.function([], [], updates=[tf.assign(oldv, newv)
                                                    for (oldv, newv) in
                                                    zipsame(oldpi.get_variables(), pi.get_variables())])
    compute_losses = U.function([vf_ob, atarg, ret, lrmult, ratio, pb_new_holder, pb_old_holder], losses)

    U.initialize()
    vf_adam.sync()
    pol_adam.sync()

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi, env, timesteps_per_actorbatch, stochastic=True)

    end_timestep = end_timesteps.popleft()
    new = newround.popleft()
    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=10)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=10)  # rolling buffer for episode rewards
    env_so_far = 1

    assert sum([max_iters > 0, max_timesteps > 0, max_episodes > 0,
                max_seconds > 0]) == 1, "Only one time constraint permitted"

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            rewFile.close()
            resptimeFile.close()
            pktnumFile.close()
            para = {}
            for vf in range(len(vf_var_list)):
                # para[vf_var_list[vf].name] = vf_var_list[vf].eval()
                para[vf] = vf_var_list[vf].eval()
            for pol in range(len(pol_var_list)):
                # para[pol_var_list[pol].name] = pol_var_list[pol].eval()
                para[pol + len(vf_var_list)] = pol_var_list[pol].eval()
            f = open("network/%d-%d.txt" % (env.seed, timesteps_so_far), "wb")
            pickle.dump(para, f)
            f.close()
            print("============================= policy is stored =================================")
            break
        elif end_timestep and timesteps_so_far >= end_timestep:
            env = env_list.popleft()
            seg_gen = traj_segment_generator(pi, env, timesteps_per_actorbatch, stochastic=True)
            end_timestep = end_timesteps.popleft()
            new = newround.popleft()
            env_so_far += 1
            para = {}
            for vf in range(len(vf_var_list)):
                # para[vf_var_list[vf].name] = vf_var_list[vf].eval()
                para[vf] = vf_var_list[vf].eval()
            for pol in range(len(pol_var_list)):
                # para[pol_var_list[pol].name] = pol_var_list[pol].eval()
                para[pol + len(vf_var_list)] = pol_var_list[pol].eval()
            f = open("network/%d-%d.txt" % (env.seed, timesteps_so_far), "wb")
            pickle.dump(para, f)
            f.close()
            print("======================== new environment (%s network settings left) ===========================" % len(env_list))
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break
        elif timesteps_so_far == 0:
            para = {}
            for vf in range(len(vf_var_list)):
                # para[vf_var_list[vf].name] = vf_var_list[vf].eval()
                para[vf] = vf_var_list[vf].eval()
            for pol in range(len(pol_var_list)):
                # para[pol_var_list[pol].name] = pol_var_list[pol].eval()
                para[pol + len(vf_var_list)] = pol_var_list[pol].eval()
            f = open("network/%d-%d.txt" % (env.seed, timesteps_so_far), "wb")
            pickle.dump(para, f)
            f.close()

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i, Environment %i ************" % (iters_so_far, env_so_far))

        seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        # for vf in range(len(vf_var_list)):
        #     print(vf_var_list[vf].name, vf_var_list[vf].eval())
        # for pol in range(len(pol_var_list)):
        #     print(pol_var_list[pol].name, pol_var_list[pol].eval())

        record_reward(rewFile, sum(seg["rew"]))
        record_reward(resptimeFile, sum(seg["resptime"]))
        record_reward(pktnumFile, sum(seg["pktnum"]))
        print("total rewards for Iteration %s: %s" % (iters_so_far, sum(seg["rew"])))
        print("average response time: %s, num of pkts: %s" % (sum(seg["resptime"])/sum(seg["pktnum"]), sum(seg["pktnum"])))
        prob = collections.Counter(seg["ac"])  # counts how many times each controller (action) was chosen
        for key in prob:
            prob[key] = prob[key]/len(seg["ac"])
        print("percentage of choosing each controller: %s" % (prob))

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        vf_ob, ac_ob, ac, atarg, tdlamret = seg["vf_ob"], seg['ac_ob'], seg["ac"], seg["adv"], seg["tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update
        atarg = (atarg - atarg.mean()) / atarg.std()  # standardized advantage function estimate
        d = Dataset(dict(vf_ob=vf_ob, ac_ob=ac_ob, ac=ac, atarg=atarg, vtarg=tdlamret), shuffle=not pi.recurrent)
        optim_batchsize = optim_batchsize or vf_ob.shape[0]

        # if hasattr(pi, "vf_ob_rms"): pi.vf_ob_rms.update(vf_ob)  # update running mean/std for policy
        # if hasattr(pi, "nn_in_rms"):
        #     temp = ac_ob.reshape(-1,ac_ob.shape[2])
        #     pi.nn_in_rms.update(temp)

        assign_old_eq_new()  # set old parameter values to new parameter values
        logger.log("Optimizing...")
        logger.log(fmt_row(13, loss_names))
        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            losses = []  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                # calculate the value function gradient
                vf_g = vf_grad(batch["vf_ob"], batch["vtarg"])
                vf_adam.update(vf_g, optim_stepsize * cur_lrmult)

                # calculate the policy gradient
                pol_g = []
                ratios = []
                pbs_new_batch = []
                pbs_old_batch = []
                e = clip_para(cur_lrmult)[0]
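                # Per-sample manual policy gradient: form the probability ratio
                #   r = pi_new(a|s) / pi_old(a|s)
                # by querying both policies, zero the gradient whenever the ratio
                # already lies outside the clip range [1 - e, 1 + e] in the
                # direction favored by the advantage (the flat region of L^CLIP),
                # and otherwise assemble d(pi)/d(theta) with the quotient rule
                # over the normalized per-controller scores from
                # pi.calculate_ac_value().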
                for sample_id in range(optim_batchsize):
                    sample_ac_ob = batch["ac_ob"][sample_id]
                    sample_ac = batch["ac"][sample_id]
                    probs_new = pi.calculate_ac_prob(sample_ac_ob)
                    prob_new = probs_new[sample_ac]
                    probs_old = oldpi.calculate_ac_prob(sample_ac_ob)
                    prob_old = probs_old[sample_ac]
                    if prob_old == 0:
                        logger.error("pi_old = 0 at iteration %s, epoch %s, sample %s" % (iters_so_far, _, sample_id))
                    r = prob_new / prob_old
                    ratios.append(r)
                    pbs_new_batch.append(probs_new)
                    pbs_old_batch.append(probs_old)
                    if (r > 1.0 + e and batch["atarg"][sample_id] > 0) or (r < 1.0 - e and batch["atarg"][sample_id] < 0) or r == 0:
                        dnn_dtheta = pol_nn_grad(sample_ac_ob[0].reshape(1, -1))
                        pol_g.append(0.*dnn_dtheta)
                    else:
                        nn = pi.calculate_ac_value(sample_ac_ob)
                        denominator = np.power(sum(nn), 2)
                    sorted_ind = np.argsort(nn)  # indices that sort nn in ascending order
                        if len(probs_new) == 2:
                            if sample_ac == 0:
                                numerator1 = nn[1]*pol_nn_grad(sample_ac_ob[0].reshape(1,-1))
                                numerator2 = nn[0] * pol_nn_grad(sample_ac_ob[1].reshape(1, -1))
                                dpi_dtheta = -(numerator1-numerator2)/denominator
                            else:
                                numerator1 = nn[1]*pol_nn_grad(sample_ac_ob[0].reshape(1,-1))
                                numerator2 = nn[0]*pol_nn_grad(sample_ac_ob[1].reshape(1,-1))
                                dpi_dtheta = -(numerator2 - numerator1)/denominator

                            # numerator1 = nn[sorted_ind[0]]*pol_nn_grad(sample_ac_ob[sorted_ind[1]].reshape(1,-1))
                            # numerator2 = nn[sorted_ind[1]]*pol_nn_grad(sample_ac_ob[sorted_ind[0]].reshape(1,-1))
                            # dpi_dtheta = (numerator1-numerator2)/denominator

                        elif len(probs_new) == 3:
                            if sample_ac == sorted_ind[0]:
                                # the controller with the lowest probability can still be chosen, since its probability is not zero
                                dnn_dtheta = pol_nn_grad(sample_ac_ob[0].reshape(1, -1))
                                pol_g.append(0. * dnn_dtheta)
                            else:
                                numerator1 = sum(nn) * (pol_nn_grad(sample_ac_ob[sample_ac].reshape(1,-1)) + 0.5 * pol_nn_grad(
                                    sample_ac_ob[sorted_ind[0]].reshape(1, -1)))
                                numerator2 = (nn[sample_ac] + 0.5 * nn[sorted_ind[0]]) * pol_nn_grad(sample_ac_ob)
                                dpi_dtheta = -(numerator1 - numerator2) / denominator
                        else:
                            if sample_ac == sorted_ind[-1] or sample_ac == sorted_ind[-2]:
                                numerator1 = sum(nn) * (pol_nn_grad(sample_ac_ob[sample_ac].reshape(1, -1)) + 0.5 * pol_nn_grad(sample_ac_ob[sorted_ind[0:-2]]))
                                numerator2 = (nn[sample_ac]+0.5*sum(nn[sorted_ind[0:-2]])) * pol_nn_grad(sample_ac_ob)
                                dpi_dtheta = -(numerator1 - numerator2) / denominator
                            else:
                                dnn_dtheta = pol_nn_grad(sample_ac_ob[0].reshape(1, -1))
                                pol_g.append(0. * dnn_dtheta)
                        pol_g.append(batch["atarg"][sample_id] * dpi_dtheta / prob_old)

                pol_g_mean = np.mean(np.array(pol_g), axis=0)
                pol_adam.update(pol_g_mean, optim_stepsize * cur_lrmult)

                newlosses = compute_losses(batch["vf_ob"], batch["atarg"], batch["vtarg"],
                                           cur_lrmult, np.array(ratios), np.array(pbs_new_batch), np.array(pbs_old_batch))

                # adam.update(g, optim_stepsize * cur_lrmult)
                losses.append(newlosses)
            logger.log(fmt_row(13, np.mean(losses, axis=0)))

        logger.log("Evaluating losses...")
        # losses = []
        # for batch in d.iterate_once(optim_batchsize):
        #     newlosses = compute_losses(batch["vf_ob"], batch["ac_ob"], batch["ac"], batch["atarg"], batch["vtarg"],
        #                                cur_lrmult)
        #     losses.append(newlosses)
        meanlosses, _, _ = mpi_moments(losses, axis=0)
        logger.log(fmt_row(13, meanlosses))
        for (lossval, name) in zipsame(meanlosses, loss_names):
            logger.record_tabular("loss_" + name, lossval)
        logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))
        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        if len(lenbuffer) == 0:
            logger.record_tabular("EpLenMean", 0)
            logger.record_tabular("EpRewMean", 0)
        else:
            logger.record_tabular("EpLenMean", np.mean(lenbuffer))
            logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.dump_tabular()
Example #5
def learn(
    env,
    policy_fn,
    *,
    timesteps_per_actorbatch,  # timesteps per actor per update
    clip_param,  # clipping parameter epsilon
    entcoeff,  # entropy coefficient
    optim_epochs,
    optim_stepsize,
    optim_batchsize,  # optimization hypers
    gamma,
    lam,  # advantage estimation
    max_timesteps=0,
    max_episodes=0,
    max_iters=0,
    max_seconds=0,  # time constraint
    callback=None,  # you can do anything in the callback, since it takes locals(), globals()
    adam_epsilon=1e-5,
    schedule='constant',  # annealing for stepsize parameters (epsilon and adam)
    sym_loss_weight=0.0,
    **kwargs,
):
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space

    pi = policy_fn("pi", ob_space,
                   ac_space)  # Construct network for new policy

    oldpi = policy_fn("oldpi", ob_space, ac_space)  # Network for old policy

    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    atarg_novel = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function for the novelty reward term
    ret_novel = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Empirical return for the novelty reward term

    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule

    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()

    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    pol_entpen = (-entcoeff) * meanent

    sym_loss = sym_loss_weight * tf.reduce_mean(
        tf.square(pi.mean - pi.mirr_mean))
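    # sym_loss penalizes the squared difference between the policy's mean action
    # and its mirrored counterpart (pi.mirr_mean), presumably to encourage
    # left/right-symmetric behaviour; it is added to both surrogate losses below.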
    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold

    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = tf.clip_by_value(ratio, 1.0 - clip_param,
                             1.0 + clip_param) * atarg  # clipped surrogate

    surr1_novel = ratio * atarg_novel  # surrogate loss of the novelty term
    surr2_novel = tf.clip_by_value(
        ratio, 1.0 - clip_param,
        1.0 + clip_param) * atarg_novel  # surrogate loss of the novelty term

    pol_surr = -tf.reduce_mean(tf.minimum(
        surr1, surr2)) + sym_loss  # PPO's pessimistic surrogate (L^CLIP)
    pol_surr_novel = -tf.reduce_mean(tf.minimum(
        surr1_novel,
        surr2_novel)) + sym_loss  # PPO's surrogate for the novelty part

    vf_loss = tf.reduce_mean(tf.square(pi.vpred - ret))
    vf_loss_novel = tf.reduce_mean(tf.square(pi.vpred_novel - ret_novel))

    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent, sym_loss]

    total_loss_novel = pol_surr_novel + pol_entpen + vf_loss_novel
    losses_novel = [
        pol_surr_novel, pol_entpen, vf_loss_novel, meankl, meanent, sym_loss
    ]

    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent", "symm"]

    policy_var_list = pi.get_trainable_variables(scope='pi/pol')

    policy_var_count = 0
    for var in policy_var_list:
        count_in_var = 1
        for dim in var.shape.as_list():
            count_in_var *= dim
        policy_var_count += count_in_var

    var_list = pi.get_trainable_variables(
        scope='pi/pol') + pi.get_trainable_variables(scope='pi/vf/')
    var_list_novel = pi.get_trainable_variables(
        scope='pi/pol') + pi.get_trainable_variables(scope='pi/vf_novel/')
    var_list_pi = pi.get_trainable_variables(
        scope='pi/pol') + pi.get_trainable_variables(
            scope='pi/vf/') + pi.get_trainable_variables(scope='pi/vf_novel/')

    lossandgrad = U.function([ob, ac, atarg, ret, lrmult],
                             losses + [U.flatgrad(total_loss, var_list)])

    lossandgrad_novel = U.function(
        [ob, ac, atarg_novel, ret_novel, lrmult],
        losses_novel + [U.flatgrad(total_loss_novel, var_list_novel)])

    adam = MpiAdam(var_list, epsilon=adam_epsilon)
    adam_novel = MpiAdam(var_list_novel, epsilon=adam_epsilon)
    adam_all = MpiAdam(var_list_pi, epsilon=adam_epsilon)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])

    compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)
    compute_losses_novel = U.function([ob, ac, atarg_novel, ret_novel, lrmult],
                                      losses_novel)

    comp_sym_loss = U.function([], sym_loss)
    U.initialize()

    adam.sync()
    adam_novel.sync()
    adam_all.sync()

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_actorbatch,
                                     stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0

    novelty_update_iter_cycle = 10
    novelty_start_iter = 50
    novelty_update = True

    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards
    rewnovelbuffer = deque(
        maxlen=100)  # rolling buffer for episode novelty rewards

    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"

    # This for debug purpose
    # from collections import defaultdict
    # sum_batch = {}
    # sum_batch = defaultdict(lambda: 0, sum_batch)

    while True:
        # if iters_so_far == 5:
        #     print("BREAK PLACEHOLDER")

        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % iters_so_far)

        seg = seg_gen.__next__()

        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, atarg_novel, tdlamret, tdlamret_novel = seg["ob"], seg[
            "ac"], seg["adv"], seg["adv_novel"], seg["tdlamret"], seg[
                "tdlamret_novel"]

        vpredbefore = seg["vpred"]  # predicted value function before update
        vprednovelbefore = seg[
            'vpred_novel']  # predicted novelty value function before update

        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate
        atarg_novel = (atarg_novel - atarg_novel.mean()) / atarg_novel.std(
        )  # standardized novelty advantage function estimate

        d = Dataset(dict(ob=ob,
                         ac=ac,
                         atarg=atarg,
                         vtarg=tdlamret,
                         atarg_novel=atarg_novel,
                         vtarg_novel=tdlamret_novel),
                    shuffle=not pi.recurrent)

        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        assign_old_eq_new()  # set old parameter values to new parameter values
        logger.log("Optimizing...")
        logger.log(fmt_row(13, loss_names))
        same_update_direction = True
        # Here we do a bunch of optimization epochs over the data

        for _ in range(optim_epochs):
            losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch

            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"],
                                            batch["atarg"], batch["vtarg"],
                                            cur_lrmult)

                *newlosses_novel, g_novel = lossandgrad_novel(
                    batch["ob"], batch["ac"], batch["atarg_novel"],
                    batch["vtarg_novel"], cur_lrmult)

                pol_g = g[0:policy_var_count]
                pol_g_novel = g_novel[0:policy_var_count]

                comm = MPI.COMM_WORLD

                pol_g_reduced = np.zeros_like(pol_g)
                pol_g_novel_reduced = np.zeros_like(pol_g_novel)

                comm.Allreduce(pol_g, pol_g_reduced, op=MPI.SUM)
                pol_g_reduced /= comm.Get_size()
                comm.Allreduce(pol_g_novel, pol_g_novel_reduced, op=MPI.SUM)
                pol_g_novel_reduced /= comm.Get_size()

                final_gradient = np.zeros(
                    len(g) + len(g_novel) - policy_var_count)
                final_gradient[policy_var_count::] = np.concatenate(
                    (g[policy_var_count::], g_novel[policy_var_count::]))

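                # If the averaged task and novelty policy gradients agree
                # (positive dot product), follow the novelty gradient; otherwise
                # strip from the task gradient its component along the novelty
                # direction (PCGrad-style gradient surgery).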
                if (np.dot(pol_g_reduced, pol_g_novel_reduced) > 0):

                    final_gradient[0:policy_var_count] = pol_g_novel
                    # final_gradient[0:policy_var_count] = pol_g

                    adam_all.update(final_gradient,
                                    optim_stepsize * cur_lrmult)
                    same_update_direction = True
                else:

                    # projection of pol_g onto pol_g_novel (note the squared norm)
                    parallel_g = (np.dot(pol_g, pol_g_novel) /
                                  np.dot(pol_g_novel, pol_g_novel)) * pol_g_novel
                    # parallel_g = (np.dot(pol_g, pol_g_novel) / np.linalg.norm(pol_g)) * pol_g

                    # final_pol_gradient = pol_g_novel - parallel_g
                    final_pol_gradient = pol_g - parallel_g

                    final_gradient[0:policy_var_count] = final_pol_gradient

                    # adam_novel.update(final_gradient, optim_stepsize * cur_lrmult)
                    adam_all.update(final_gradient,
                                    optim_stepsize * cur_lrmult)

                    # adam.update(final_gradient, optim_stepsize * cur_lrmult)
                    same_update_direction = False

                # step = optim_stepsize * cur_lrmult

                # adam.update(g, optim_stepsize * cur_lrmult)

                # adam_novel.update(g_novel, optim_stepsize * cur_lrmult)

                losses.append(newlosses)
            logger.log(fmt_row(13, np.mean(losses, axis=0)))

        logger.log("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob"], batch["ac"],
                                       batch["atarg"], batch["vtarg"],
                                       cur_lrmult)
            # newlosses_novel = compute_losses_novel(batch["ob"], batch["ac"], batch["atarg_novel"], batch["vtarg_novel"],
            #                                        cur_lrmult)
            losses.append(newlosses)
        meanlosses, _, _ = mpi_moments(losses, axis=0)
        logger.log(fmt_row(13, meanlosses))
        for (lossval, name) in zipsame(meanlosses, loss_names):
            logger.record_tabular("loss_" + name, lossval)
        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))
        lrlocal = (seg["ep_lens"], seg["ep_rets"], seg['ep_rets_novel']
                   )  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews, rews_novel = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        rewnovelbuffer.extend(rews_novel)
        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpRNoveltyRewMean", np.mean(rewnovelbuffer))
        logger.record_tabular("MirroredLoss", comp_sym_loss())
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1
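        # After novelty_start_iter, flip the novelty_update flag every
        # novelty_update_iter_cycle iterations.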
        if iters_so_far >= novelty_start_iter and iters_so_far % novelty_update_iter_cycle == 0:
            novelty_update = not novelty_update

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        logger.record_tabular("SameUpdateDirection", same_update_direction)
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.dump_tabular()

    return pi
Example #6
    def train(self, T_max, graph_name=None):
        step = 0
        self.num_lookahead = 5

        self.reset_workers()
        self.wait_for_workers()

        stat = {
            'ploss': [],
            'vloss': [],
            'score': [],
            'int_reward': [],
            'entropy': [],
            'fwd_kl_div': [],
            'running_loss': 0
        }

        reward_tracker = RunningMeanStd()
        reward_buffer = np.empty((self.batch_size, self.num_lookahead),
                                 dtype=np.float32)
        while step < T_max:

            # these will keep tensors, which we'll use later for backpropagation
            values = []
            log_probs = []
            rewards = []
            entropies = []

            actions = []
            actions_pred = []
            features = []
            features_pred = []

            state = torch.from_numpy(self.sh_state).to(self.device)

            for i in range(self.num_lookahead):
                step += self.batch_size

                logit, value = self.model(state)
                prob = torch.softmax(logit, dim=1)
                log_prob = torch.log_softmax(logit, dim=1)
                entropy = -(prob * log_prob).sum(1, keepdim=True)

                action = prob.multinomial(1)
                sampled_lp = log_prob.gather(1, action)

                # one-hot action
                oh_action = torch.zeros(self.batch_size,
                                        self.num_actions,
                                        device=self.device).scatter_(
                                            1, action, 1)

                self.broadcast_actions(action)
                self.wait_for_workers()

                next_state = torch.from_numpy(self.sh_state).to(self.device)
                s1, s1_pred, action_pred = self.icm(state, oh_action,
                                                    next_state)

                with torch.no_grad():
                    int_reward = 0.5 * (s1 - s1_pred).pow(2).sum(dim=1,
                                                                 keepdim=True)
                reward_buffer[:, i] = int_reward.cpu().numpy().ravel()

                state = next_state

                # save variables for gradient descent
                values.append(value)
                log_probs.append(sampled_lp)
                rewards.append(int_reward)
                entropies.append(entropy)

                if not self.random:
                    actions.append(action.flatten())
                    actions_pred.append(action_pred)
                features.append(s1)
                features_pred.append(s1_pred)

                stat['entropy'].append(entropy.sum(dim=1).mean().item())
                stat['fwd_kl_div'].append(
                    torch.kl_div(s1_pred, s1).mean().item())

            # may have to update reward_buffer with gamma first
            reward_mean, reward_std, count = mpi_moments(reward_buffer.ravel())
            reward_tracker.update_from_moments(reward_mean, reward_std**2,
                                               count)
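            # Rescale intrinsic rewards by their running standard deviation
            # (moments from mpi_moments feed the RunningMeanStd tracker) so
            # their magnitude stays stable over training.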
            std = np.sqrt(reward_tracker.var)
            rewards = [rwd / std for rwd in rewards]
            for rwd in rewards:
                stat['int_reward'].append(rwd.mean().item())

            state = torch.from_numpy(self.sh_state.astype(np.float32)).to(
                self.device)
            with torch.no_grad():
                _, R = self.model(state)  # R is the estimated return

            values.append(R)

            ploss = 0
            vloss = 0
            fwd_loss = 0
            inv_loss = 0

            delta = torch.zeros((self.batch_size, 1),
                                dtype=torch.float,
                                device=self.device)
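            # Walk the lookahead window backwards: bootstrap the return R from
            # the critic, accumulate the value loss from the n-step advantage,
            # and drive the policy loss with the one-step TD error (delta) plus
            # an entropy bonus; ICM forward/inverse losses accumulate alongside.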
            for i in reversed(range(self.num_lookahead)):
                R = rewards[i] + self.gamma * R
                advantage = R - values[i]
                vloss += (0.5 * advantage.pow(2)).mean()

                delta = rewards[i] + self.gamma * values[
                    i + 1].detach() - values[i].detach()
                ploss += -(log_probs[i] * delta +
                           0.01 * entropies[i]).mean()  # beta = 0.01

                fwd_loss += 0.5 * (features[i] -
                                   features_pred[i]).pow(2).sum(dim=1).mean()
                if not self.random:
                    inv_loss += self.cross_entropy(actions_pred[i], actions[i])

            self.optim.zero_grad()

            # inv_loss is 0 if using random features
            loss = ploss + vloss + fwd_loss + inv_loss  # 2018 Large scale curiosity paper simply sums them (no lambda and beta anymore)
            loss.backward()

            torch.nn.utils.clip_grad_norm_(
                list(self.model.parameters()) + list(self.icm.parameters()),
                40)
            self.optim.step()

            while not self.channel.empty():
                score = self.channel.get()
                stat['score'].append(score)

            stat['ploss'].append(ploss.item() / self.num_lookahead)
            stat['vloss'].append(vloss.item() / self.num_lookahead)
            stat['running_loss'] = 0.99 * stat[
                'running_loss'] + 0.01 * loss.item() / self.num_lookahead

            if len(stat['score']) > 20 and step % (self.batch_size *
                                                   1000) == 0:
                now = datetime.datetime.now().strftime("%H:%M")
                print(
                    f"Step {step: <10} | Running loss: {stat['running_loss']:.4f} | Running score: {np.mean(stat['score'][-10:]):.2f} | Time: {now}"
                )
                if graph_name is not None and step % (self.batch_size *
                                                      10000) == 0:
                    plot(step,
                         stat['score'],
                         stat['int_reward'],
                         stat['ploss'],
                         stat['vloss'],
                         stat['entropy'],
                         name=graph_name)
def learn(
    env,
    policy_func,
    *,
    timesteps_per_batch,  # timesteps per actor per update
    clip_param,
    entcoeff,  # clipping parameter epsilon, entropy coeff
    optim_epochs,
    optim_stepsize,
    optim_batchsize,  # optimization hypers
    gamma,
    lam,  # advantage estimation
    max_timesteps=0,
    max_episodes=0,
    max_iters=0,
    max_seconds=0,  # time constraint
    callback=None,  # you can do anything in the callback, since it takes locals(), globals()
    adam_epsilon=1e-5,
    schedule='constant',  # annealing for stepsize parameters (epsilon and adam)
    load_model_path,
    test_only,
    stochastic,
    symmetric_training=False,
    obs_names=None,
    single_episode=False,
    horizon_hack=False,
    running_avg_len=100,
    init_three=False,
    actions=None,
    symmetric_training_trick=False,
    seeds_fn=None,
    bootstrap_seeds=False,
):
    global seeds
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_func("pi", ob_space, ac_space)  # Network for new policy
    old_pi = policy_func("old_pi", ob_space,
                         ac_space)  # Network for old policy
    adv_targ = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return
    mask = tf.placeholder(dtype=tf.bool, shape=[None])  # Mask for the trick

    lr_mult = tf.placeholder(
        name='lr_mult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule
    clip_param = clip_param * lr_mult  # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    st = U.get_placeholder_cached(name="st")
    ac = pi.pdtype.sample_placeholder([None])

    kl = old_pi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    mean_kl = U.mean(tf.boolean_mask(kl, mask))  # Mean over the batch
    mean_ent = U.mean(tf.boolean_mask(ent, mask))
    entropy_penalty = -entcoeff * mean_ent

    ratio = tf.exp(pi.pd.logp(ac) - old_pi.pd.logp(ac))  # pi_new / pi_old
    surr_1 = ratio * adv_targ  # surrogate from conservative policy iteration
    surr_2 = U.clip(ratio, 1.0 - clip_param, 1.0 + clip_param) * adv_targ  #
    surr_loss = -U.mean(tf.boolean_mask(
        tf.minimum(surr_1, surr_2),
        mask))  # PPO's pessimistic surrogate (L^CLIP), mean over the batch
    vf_loss = U.mean(tf.boolean_mask(tf.square(pi.vpred - ret), mask))
    total_loss = surr_loss + entropy_penalty + vf_loss
    losses = [surr_loss, entropy_penalty, vf_loss, mean_kl, mean_ent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    comp_loss_and_grad = U.function([ob, st, ac, adv_targ, ret, lr_mult, mask],
                                    losses +
                                    [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(old_v, new_v)
            for (old_v,
                 new_v) in zipsame(old_pi.get_variables(), pi.get_variables())
        ])
    comp_loss = U.function([ob, st, ac, adv_targ, ret, lr_mult, mask], losses)

    if init_three:
        assign_init_three_1 = U.function(
            [], [],
            updates=[
                tf.assign(new_v, old_v) for (old_v, new_v) in zipsame(
                    pi.get_orig_variables(), pi.get_part_variables(1))
            ])
        assign_init_three_2 = U.function(
            [], [],
            updates=[
                tf.assign(new_v, old_v) for (old_v, new_v) in zipsame(
                    pi.get_orig_variables(), pi.get_part_variables(2))
            ])

    U.initialize()
    if load_model_path is not None:
        U.load_state(load_model_path)
        if init_three:
            assign_init_three_1()
            assign_init_three_2()
    adam.sync()

    if seeds_fn is not None:
        with open(seeds_fn) as f:
            seeds = [int(seed) for seed in f.readlines()]
    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_batch,
                                     stochastic=stochastic,
                                     single_episode=test_only
                                     or single_episode,
                                     actions=actions,
                                     bootstrap_seeds=bootstrap_seeds)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    len_buffer = deque(
        maxlen=running_avg_len)  # rolling buffer for episode lengths
    rew_buffer = deque(
        maxlen=running_avg_len)  # rolling buffer for episode rewards
    origrew_buffer = deque(
        maxlen=running_avg_len)  # rolling buffer for original episode rewards

    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"

    while True:
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % iters_so_far)

        seg = seg_gen.__next__()

        add_vtarg_and_adv(seg, gamma, lam, horizon_hack=horizon_hack)

        # ob, ac, adv_targ, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, st, ac, adv_targ, tdlamret = seg["ob"], seg["step"], seg[
            "ac"], seg["adv"], seg["tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update

        if symmetric_training_trick:
            first_75 = st < 75
            mask = ~np.concatenate((np.zeros_like(first_75), first_75))
        else:
            mask = np.concatenate(
                (np.ones_like(st,
                              dtype=np.bool), np.ones_like(st, dtype=np.bool)))
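        # Symmetric data augmentation: mirror each observation (swap the legs)
        # and swap the first nine action dimensions with the rest, then append
        # the mirrored copies so the batch doubles in size.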
        if symmetric_training:
            sym_obss = []
            sym_acc = []
            for i in range(timesteps_per_batch):
                obs = OrderedDict(zip(obs_names, ob[i]))
                sym_obs = obs.copy()
                swap_legs(sym_obs)

                sym_ac = ac[i].copy()
                sym_ac = np.concatenate((sym_ac[9:], sym_ac[:9]))
                sym_obss.append(np.asarray(list(sym_obs.values())))
                sym_acc.append(sym_ac)
            sym_obss = np.asarray(sym_obss)
            sym_acc = np.asarray(sym_acc)

            ob = np.concatenate((ob, sym_obss))
            ac = np.concatenate((ac, sym_acc))
            adv_targ = np.concatenate((adv_targ, adv_targ))
            tdlamret = np.concatenate((tdlamret, tdlamret))
            vpredbefore = np.concatenate((vpredbefore, vpredbefore))
            st = np.concatenate((st, st))

        # Compute stats before updating
        if bootstrap_seeds:
            lrlocal = (seg["ep_lens"], seg["ep_rets"], seg["ep_orig_rets"],
                       seg["easy_seeds"], seg["hard_seeds"])  # local values
            listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
            lens, rews, orig_rews, easy_seeds, hard_seeds = map(
                flatten_lists, zip(*listoflrpairs))
            easy_seeds = [x for x in easy_seeds if x != 0]
            hard_seeds = [x for x in hard_seeds if x != 0]
            print('seeds set sizes:', len(seeds), len(easy_seeds),
                  len(hard_seeds))
            seeds = list((set(seeds) - set(easy_seeds)) | set(hard_seeds))
        else:
            lrlocal = (seg["ep_lens"], seg["ep_rets"], seg["ep_orig_rets"]
                       )  # local values
            listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
            lens, rews, orig_rews = map(flatten_lists, zip(*listoflrpairs))

        len_buffer.extend(lens)
        rew_buffer.extend(rews)
        origrew_buffer.extend(orig_rews)
        logger.record_tabular("Iter", iters_so_far)
        logger.record_tabular("EpLenMean", np.mean(len_buffer))
        logger.record_tabular("EpRewMean", np.mean(rew_buffer))
        logger.record_tabular("EpOrigRewMean", np.mean(origrew_buffer))
        logger.record_tabular("EpOrigRewStd", np.std(origrew_buffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)

        n_completed = 0
        sum_completed = 0
        for ep_len, orig_rew in zip(lens, orig_rews):
            if ep_len == 1000:
                n_completed += 1
                sum_completed += orig_rew
        avg_completed = sum_completed / n_completed if n_completed > 0 else 0
        logger.record_tabular("AvgCompleted", avg_completed)
        perc_completed = 100 * n_completed / len(lens) if len(lens) > 0 else 0
        logger.record_tabular("PercCompleted", perc_completed)

        if callback: callback(locals(), globals())

        adv_targ = (adv_targ - adv_targ.mean()) / adv_targ.std(
        )  # standardized advantage function estimate
        d = Dataset(dict(ob=ob,
                         st=st,
                         ac=ac,
                         atarg=adv_targ,
                         vtarg=tdlamret,
                         mask=mask),
                    shuffle=not pi.recurrent)
        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        assign_old_eq_new()  # set old parameter values to new parameter values
        logger.log("Optimizing...")
        if not test_only:
            logger.log(fmt_row(13, loss_names))
        # Here we do a bunch of optimization epochs over the data. I log results only for the first worker (rank=0)
        for _ in range(optim_epochs):
            losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *batch_losses, grads = comp_loss_and_grad(
                    batch["ob"], batch["st"], batch["ac"], batch["atarg"],
                    batch["vtarg"], cur_lrmult, batch["mask"])
                if not test_only:
                    adam.update(grads, optim_stepsize * cur_lrmult)
                losses.append(batch_losses)
            logger.log(fmt_row(13, np.mean(losses, axis=0)))

        logger.log("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            batch_losses = comp_loss(batch["ob"], batch["st"], batch["ac"],
                                     batch["atarg"], batch["vtarg"],
                                     cur_lrmult, batch["mask"])
            losses.append(batch_losses)
        meanlosses, _, _ = mpi_moments(losses, axis=0)
        logger.log(fmt_row(13, meanlosses))
        for (lossval, name) in zipsame(meanlosses, loss_names):
            logger.record_tabular("loss_" + name, lossval)
        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))

        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.dump_tabular()

        iters_so_far += 1
Example #8
def learn(
        env,
        policy_func,
        disc,
        *,
        timesteps_per_batch,  # timesteps per actor per update
        clip_param,
        entcoeff,  # clipping parameter epsilon, entropy coeff
        optim_epochs,
        optim_stepsize,
        optim_batchsize,  # optimization hypers
        gamma,
        lam,  # advantage estimation
        max_timesteps=0,
        max_episodes=0,
        max_iters=0,
        max_seconds=0,  # time constraint
        callback=None,  # you can do anything in the callback, since it takes locals(), globals()
        adam_epsilon=1e-5,
        schedule='constant',  # annealing for stepsize parameters (epsilon and adam)
        logdir=".",
        agentName="PPO-Agent",
        resume=0,
        num_parallel=0,
        num_cpu=1,
        num_extra=0,
        gan_batch_size=128,
        gan_num_epochs=5,
        gan_display_step=40,
        resume_disc=0,
        resume_non_disc=0,
        mocap_path="",
        gan_replay_buffer_size=1000000,
        gan_prob_to_put_in_replay=0.01,
        gan_reward_to_retrain_discriminator=5,
        use_distance=0,
        use_blend=0):
    # Deal with GAN
    if not use_distance:
        replay_buf = MyReplayBuffer(gan_replay_buffer_size)
    data = np.loadtxt(
        mocap_path + ".dat"
    )  #"D:/p4sw/devrel/libdev/flex/dev/rbd/data/bvh/motion_simple.dat");
    label = np.concatenate((np.ones(
        (data.shape[0], 1)), np.zeros((data.shape[0], 1))),
                           axis=1)
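    # Real mocap frames are labelled [1, 0]; generated frames later receive
    # [0, 1], and the discriminator is trained to separate the two.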

    print("Real data label = " + str(label))

    mocap_set = Dataset(dict(data=data, label=label), shuffle=True)

    # Setup losses and stuff
    # ----------------------------------------
    rank = MPI.COMM_WORLD.Get_rank()
    ob_space = env.observation_space
    ac_space = env.action_space

    ob_size = ob_space.shape[0]
    ac_size = ac_space.shape[0]

    #print("rank = " + str(rank) + " ob_space = "+str(ob_space.shape) + " ac_space = "+str(ac_space.shape))
    #exit(0)
    pi = policy_func("pi", ob_space,
                     ac_space)  # Construct network for new policy
    oldpi = policy_func("oldpi", ob_space, ac_space)  # Network for old policy
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule
    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = U.mean(kloldnew)
    meanent = U.mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = U.clip(ratio, 1.0 - clip_param, 1.0 + clip_param) * atarg  #
    pol_surr = -U.mean(tf.minimum(
        surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)
    vfloss1 = tf.square(pi.vpred - ret)
    vpredclipped = oldpi.vpred + tf.clip_by_value(pi.vpred - oldpi.vpred,
                                                  -clip_param, clip_param)
    vfloss2 = tf.square(vpredclipped - ret)
    vf_loss = .5 * U.mean(
        tf.maximum(vfloss1, vfloss2)
    )  # we do the same clipping-based trust region for the value function
    #vf_loss = U.mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult],
                             losses + [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)

    U.initialize()
    adam.sync()

    # Prepare for rollouts
    # ----------------------------------------
    sess = tf.get_default_session()

    avars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    non_disc_vars = [
        a for a in avars
        if not a.name.split("/")[0].startswith("discriminator")
    ]
    disc_vars = [
        a for a in avars if a.name.split("/")[0].startswith("discriminator")
    ]
    #print(str(non_disc_names))
    #print(str(disc_names))
    #exit(0)
    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards

    disc_saver = tf.train.Saver(disc_vars, max_to_keep=None)
    non_disc_saver = tf.train.Saver(non_disc_vars, max_to_keep=None)
    saver = tf.train.Saver(max_to_keep=None)
    if resume > 0:
        saver.restore(
            tf.get_default_session(),
            os.path.join(os.path.abspath(logdir),
                         "{}-{}".format(agentName, resume)))
        if not use_distance:
            if os.path.exists(logdir + "\\" + 'replay_buf_' +
                              str(int(resume / 100) * 100) + '.pkl'):
                print("Load replay buf")
                with open(
                        logdir + "\\" + 'replay_buf_' +
                        str(int(resume / 100) * 100) + '.pkl', 'rb') as f:
                    replay_buf = pickle.load(f)
            else:
                print("Can't load replay buf " + logdir + "\\" +
                      'replay_buf_' + str(int(resume / 100) * 100) + '.pkl')
    iters_so_far = resume

    if resume_non_disc > 0:
        non_disc_saver.restore(
            tf.get_default_session(),
            os.path.join(
                os.path.abspath(logdir),
                "{}-{}".format(agentName + "_non_disc", resume_non_disc)))
        iters_so_far = resume_non_disc

    if use_distance:
        print("Use distance")
        nn = NearestNeighbors(n_neighbors=1, algorithm='ball_tree').fit(data)
    else:
        nn = None
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     disc,
                                     timesteps_per_batch,
                                     stochastic=True,
                                     num_parallel=num_parallel,
                                     num_cpu=num_cpu,
                                     rank=rank,
                                     ob_size=ob_size,
                                     ac_size=ac_size,
                                     com=MPI.COMM_WORLD,
                                     num_extra=num_extra,
                                     iters_so_far=iters_so_far,
                                     use_distance=use_distance,
                                     nn=nn)

    if resume_disc > 0:
        disc_saver.restore(
            tf.get_default_session(),
            os.path.join(os.path.abspath(logdir),
                         "{}-{}".format(agentName + "_disc", resume_disc)))

    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"
    logF = open(logdir + "\\" + 'log.txt', 'a')
    logR = open(logdir + "\\" + 'log_rew.txt', 'a')
    logStats = open(logdir + "\\" + 'log_stats.txt', 'a')
    if os.path.exists(logdir + "\\" + 'ob_list_' + str(rank) + '.pkl'):
        with open(logdir + "\\" + 'ob_list_' + str(rank) + '.pkl', 'rb') as f:
            ob_list = pickle.load(f)
    else:
        ob_list = []

    dump_training = 0
    learn_from_training = 0
    if dump_training:
        # , "mean": pi.ob_rms.mean, "std": pi.ob_rms.std
        saverRMS = tf.train.Saver({
            "_sum": pi.ob_rms._sum,
            "_sumsq": pi.ob_rms._sumsq,
            "_count": pi.ob_rms._count
        })
        saverRMS.save(tf.get_default_session(),
                      os.path.join(os.path.abspath(logdir), "rms.tf"))

        ob_np_a = np.asarray(ob_list)
        ob_np = np.reshape(ob_np_a, (-1, ob_size))
        [vpred, pdparam] = pi._vpred_pdparam(ob_np)

        print("vpred = " + str(vpred))
        print("pd_param = " + str(pdparam))
        with open('training.pkl', 'wb') as f:
            pickle.dump(ob_np, f)
            pickle.dump(vpred, f)
            pickle.dump(pdparam, f)
        exit(0)
    if learn_from_training:
        # , "mean": pi.ob_rms.mean, "std": pi.ob_rms.std

        with open('training.pkl', 'rb') as f:
            ob_np = pickle.load(f)
            vpred = pickle.load(f)
            pdparam = pickle.load(f)
        num = ob_np.shape[0]
        for i in range(num):
            xp = ob_np[i][1]
            ob_np[i][1] = 0.0
            ob_np[i][18] -= xp
            ob_np[i][22] -= xp
            ob_np[i][24] -= xp
            ob_np[i][26] -= xp
            ob_np[i][28] -= xp
            ob_np[i][30] -= xp
            ob_np[i][32] -= xp
            ob_np[i][34] -= xp
        print("ob_np = " + str(ob_np))
        print("vpred = " + str(vpred))
        print("pdparam = " + str(pdparam))
        batch_size = 128

        y_vpred = tf.placeholder(tf.float32, [
            batch_size,
        ])
        y_pdparam = tf.placeholder(tf.float32, [batch_size, pdparam.shape[1]])

        vpred_loss = U.mean(tf.square(pi.vpred - y_vpred))
        vpdparam_loss = U.mean(tf.square(pi.pdparam - y_pdparam))

        total_train_loss = vpred_loss + vpdparam_loss
        #total_train_loss = vpdparam_loss
        #total_train_loss = vpred_loss
        #coef = 0.01
        #dense_all = U.dense_all
        #for a in dense_all:
        #   total_train_loss += coef * tf.nn.l2_loss(a)
        #total_train_loss = vpdparam_loss
        optimizer = tf.train.AdamOptimizer(
            learning_rate=0.001).minimize(total_train_loss)
        d = Dataset(dict(ob=ob_np, vpred=vpred, pdparam=pdparam),
                    shuffle=not pi.recurrent)
        sess = tf.get_default_session()
        sess.run(tf.global_variables_initializer())
        saverRMS = tf.train.Saver({
            "_sum": pi.ob_rms._sum,
            "_sumsq": pi.ob_rms._sumsq,
            "_count": pi.ob_rms._count
        })
        saverRMS.restore(tf.get_default_session(),
                         os.path.join(os.path.abspath(logdir), "rms.tf"))
        if resume > 0:
            saver.restore(
                tf.get_default_session(),
                os.path.join(os.path.abspath(logdir),
                             "{}-{}".format(agentName, resume)))

        for q in range(100):
            sumLoss = 0
            for batch in d.iterate_once(batch_size):
                tl, _ = sess.run(
                    [total_train_loss, optimizer],
                    feed_dict={
                        pi.ob: batch["ob"],
                        y_vpred: batch["vpred"],
                        y_pdparam: batch["pdparam"]
                    })
                sumLoss += tl
            print("Iteration " + str(q) + " Loss = " + str(sumLoss))
        assign_old_eq_new()  # set old parameter values to new parameter values

        # Save as frame 1
        try:
            saver.save(tf.get_default_session(),
                       os.path.join(logdir, agentName),
                       global_step=1)
        except:
            pass
        #exit(0)
    if resume > 0:
        firstTime = False
    else:
        firstTime = True

    # Check accuracy
    #amocap = sess.run([disc.accuracy],
    #                feed_dict={disc.input: data,
    #                           disc.label: label})
    #print("Mocap accuracy = " + str(amocap))
    #print("Mocap label is " + str(label))

    #adata = np.array(replay_buf._storage)
    #print("adata shape = " + str(adata.shape))
    #alabel = np.concatenate((np.zeros((adata.shape[0], 1)), np.ones((adata.shape[0], 1))), axis=1)

    #areplay = sess.run([disc.accuracy],
    #                feed_dict={disc.input: adata,
    #                           disc.label: alabel})
    #print("Replay accuracy = " + str(areplay))
    #print("Replay label is " + str(alabel))
    #exit(0)
    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % iters_so_far)

        seg = seg_gen.__next__()

        add_vtarg_and_adv(seg, gamma, lam, timesteps_per_batch, num_parallel,
                          num_cpu)
        #print(" ob= " + str(seg["ob"])+ " rew= " + str(seg["rew"])+ " vpred= " + str(seg["vpred"])+ " new= " + str(seg["new"])+ " ac= " + str(seg["ac"])+ " prevac= " + str(seg["prevac"])+ " nextvpred= " + str(seg["nextvpred"])+ " ep_rets= " + str(seg["ep_rets"])+ " ep_lens= " + str(seg["ep_lens"]))

        #exit(0)
        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret, extra = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"], seg["extra"]

        #ob_list.append(ob.tolist())
        vpredbefore = seg["vpred"]  # predicted value function before update
        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate
        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                    shuffle=not pi.recurrent)
        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        assign_old_eq_new()  # set old parameter values to new parameter values
        logger.log("Optimizing...")
        logger.log(fmt_row(13, loss_names))
        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"],
                                            batch["atarg"], batch["vtarg"],
                                            cur_lrmult)
                adam.update(g, optim_stepsize * cur_lrmult)
                losses.append(newlosses)
            #print(str(losses))
            logger.log(fmt_row(13, np.mean(losses, axis=0)))

        logger.log("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob"], batch["ac"],
                                       batch["atarg"], batch["vtarg"],
                                       cur_lrmult)
            losses.append(newlosses)
        meanlosses, _, _ = mpi_moments(losses, axis=0)
        logger.log(fmt_row(13, meanlosses))
        for (lossval, name) in zipsame(meanlosses, loss_names):
            logger.record_tabular("loss_" + name, lossval)
        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))
        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        rewmean = np.mean(rewbuffer)
        logger.record_tabular("EpRewMean", rewmean)
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)

        # Train discriminator
        if not use_distance:
            print("Put in replay buf " +
                  str((int)(gan_prob_to_put_in_replay * extra.shape[0] + 1)))
            replay_buf.add(extra[np.random.choice(
                extra.shape[0],
                (int)(gan_prob_to_put_in_replay * extra.shape[0] + 1),
                replace=True)])
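            # A small random fraction of freshly generated frames is added to
            # the replay buffer. The discriminator is trained on everything on
            # the first pass and retrained from the replay buffer whenever
            # seg['mean_ext_rew'] exceeds gan_reward_to_retrain_discriminator.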
            #if iters_so_far == 1:
            if not use_blend:
                if firstTime:
                    firstTime = False
                    # Train with everything we got
                    lb = np.concatenate((np.zeros(
                        (extra.shape[0], 1)), np.ones((extra.shape[0], 1))),
                                        axis=1)
                    extra_set = Dataset(dict(data=extra, label=lb),
                                        shuffle=True)
                    for e in range(10):
                        i = 0
                        for mbatch in mocap_set.iterate_once(gan_batch_size):
                            batch = extra_set.next_batch(gan_batch_size)
                            _, l = sess.run(
                                [disc.optimizer_first, disc.loss],
                                feed_dict={
                                    disc.input:
                                    np.concatenate(
                                        (mbatch['data'], batch['data'])),
                                    disc.label:
                                    np.concatenate(
                                        (mbatch['label'], batch['label']))
                                })
                            i = i + 1
                            # Display logs per step
                            if i % gan_display_step == 0 or i == 1:
                                print(
                                    'discriminator epoch %i Step %i: Minibatch Loss: %f'
                                    % (e, i, l))
                        print(
                            'discriminator epoch %i Step %i: Minibatch Loss: %f'
                            % (e, i, l))
                if seg['mean_ext_rew'] > gan_reward_to_retrain_discriminator:
                    for e in range(gan_num_epochs):
                        i = 0
                        for mbatch in mocap_set.iterate_once(gan_batch_size):
                            data = replay_buf.sample(mbatch['data'].shape[0])
                            lb = np.concatenate((np.zeros(
                                (data.shape[0], 1)), np.ones(
                                    (data.shape[0], 1))),
                                                axis=1)
                            _, l = sess.run(
                                [disc.optimizer, disc.loss],
                                feed_dict={
                                    disc.input:
                                    np.concatenate((mbatch['data'], data)),
                                    disc.label:
                                    np.concatenate((mbatch['label'], lb))
                                })
                            i = i + 1
                            # Display logs per step
                            if i % gan_display_step == 0 or i == 1:
                                print(
                                    'discriminator epoch %i Step %i: Minibatch Loss: %f'
                                    % (e, i, l))
                        print(
                            'discriminator epoch %i Step %i: Minibatch Loss: %f'
                            % (e, i, l))
            else:
                if firstTime:
                    firstTime = False
                    # Train with everything we got
                    extra_set = Dataset(dict(data=extra), shuffle=True)
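                    # Blend mode: each discriminator input is a convex mixture
                    # of a mocap frame and a generated frame, and the blend
                    # weights double as soft labels.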
                    for e in range(10):
                        i = 0
                        for mbatch in mocap_set.iterate_once(gan_batch_size):
                            batch = extra_set.next_batch(gan_batch_size)
                            bf = np.random.uniform(0, 1, (gan_batch_size, 1))
                            onembf = 1 - bf
                            my_label = np.concatenate((bf, onembf), axis=1)
                            my_data = np.multiply(mbatch['data'],
                                                  bf) + np.multiply(
                                                      batch['data'], onembf)
                            _, l = sess.run([disc.optimizer_first, disc.loss],
                                            feed_dict={
                                                disc.input: my_data,
                                                disc.label: my_label
                                            })
                            i = i + 1
                            # Display logs per step
                            if i % gan_display_step == 0 or i == 1:
                                print(
                                    'discriminator epoch %i Step %i: Minibatch Loss: %f'
                                    % (e, i, l))
                        print(
                            'discriminator epoch %i Step %i: Minibatch Loss: %f'
                            % (e, i, l))
                if seg['mean_ext_rew'] > gan_reward_to_retrain_discriminator:
                    for e in range(gan_num_epochs):
                        i = 0
                        for mbatch in mocap_set.iterate_once(gan_batch_size):
                            data = replay_buf.sample(mbatch['data'].shape[0])

                            bf = np.random.uniform(0, 1, (gan_batch_size, 1))
                            onembf = 1 - bf
                            my_label = np.concatenate((bf, onembf), axis=1)
                            my_data = np.multiply(mbatch['data'],
                                                  bf) + np.multiply(
                                                      data, onembf)

                            _, l = sess.run([disc.optimizer_first, disc.loss],
                                            feed_dict={
                                                disc.input: my_data,
                                                disc.label: my_label
                                            })
                            i = i + 1
                            # Display logs per step
                            if i % gan_display_step == 0 or i == 1:
                                print(
                                    'discriminator epoch %i Step %i: Minibatch Loss: %f'
                                    % (e, i, l))
                        print(
                            'discriminator epoch %i Step %i: Minibatch Loss: %f'
                            % (e, i, l))

        # if True:
        #     lb = np.concatenate((np.zeros((extra.shape[0],1)),np.ones((extra.shape[0],1))),axis=1)
        #     extra_set = Dataset(dict(data=extra,label=lb), shuffle=True)
        #     num_r = 1
        #     if iters_so_far == 1:
        #         num_r = gan_num_epochs
        #     for e in range(num_r):
        #         i = 0
        #         for batch in extra_set.iterate_once(gan_batch_size):
        #             mbatch = mocap_set.next_batch(gan_batch_size)
        #             _, l = sess.run([disc.optimizer, disc.loss], feed_dict={disc.input: np.concatenate((mbatch['data'],batch['data'])), disc.label: np.concatenate((mbatch['label'],batch['label']))})
        #             i = i + 1
        #             # Display logs per step
        #             if i % gan_display_step == 0 or i == 1:
        #                 print('discriminator epoch %i Step %i: Minibatch Loss: %f' % (e, i, l))
        #         print('discriminator epoch %i Step %i: Minibatch Loss: %f' % (e, i, l))

        if not use_distance:
            if iters_so_far % 100 == 0:
                with open(
                        logdir + "\\" + 'replay_buf_' + str(iters_so_far) +
                        '.pkl', 'wb') as f:
                    pickle.dump(replay_buf, f)

        with open(logdir + "\\" + 'ob_list_' + str(rank) + '.pkl', 'wb') as f:
            pickle.dump(ob_list, f)
        if MPI.COMM_WORLD.Get_rank() == 0:
            logF.write(str(rewmean) + "\n")
            logR.write(str(seg['mean_ext_rew']) + "\n")
            logStats.write(logger.get_str() + "\n")
            logF.flush()
            logStats.flush()
            logR.flush()

            logger.dump_tabular()

            try:
                os.remove(logdir + "/checkpoint")
            except OSError:
                pass
            try:
                saver.save(tf.get_default_session(),
                           os.path.join(logdir, agentName),
                           global_step=iters_so_far)
            except:
                pass
            try:
                non_disc_saver.save(tf.get_default_session(),
                                    os.path.join(logdir,
                                                 agentName + "_non_disc"),
                                    global_step=iters_so_far)
            except:
                pass
            try:
                disc_saver.save(tf.get_default_session(),
                                os.path.join(logdir, agentName + "_disc"),
                                global_step=iters_so_far)
            except:
                pass
Example #9
def main():
    args = get_args()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    args_dir, logs_dir, models_dir, samples_dir = get_all_save_paths(
        args, 'pretrain', combine_action=args.combine_action)
    eval_log_dir = logs_dir + "_eval"
    utils.cleanup_log_dir(logs_dir)
    utils.cleanup_log_dir(eval_log_dir)

    _, _, intrinsic_models_dir, _ = get_all_save_paths(args,
                                                       'learn_reward',
                                                       load_only=True)
    if args.load_iter != 'final':
        intrinsic_model_file_name = os.path.join(
            intrinsic_models_dir,
            args.env_name + '_{}.pt'.format(args.load_iter))
    else:
        intrinsic_model_file_name = os.path.join(
            intrinsic_models_dir, args.env_name + '.pt')
    intrinsic_arg_file_name = os.path.join(args_dir, 'command.txt')

    # save args to arg_file
    with open(intrinsic_arg_file_name, 'w') as f:
        json.dump(args.__dict__, f, indent=2)

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    envs = make_vec_envs(args.env_name, args.seed, args.num_processes,
                         args.gamma, logs_dir, device, False)

    actor_critic = Policy(envs.observation_space.shape,
                          envs.action_space,
                          base_kwargs={'recurrent': args.recurrent_policy})
    actor_critic.to(device)

    if args.algo == 'a2c':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               lr=args.lr,
                               eps=args.eps,
                               alpha=args.alpha,
                               max_grad_norm=args.max_grad_norm)
    elif args.algo == 'ppo':
        agent = algo.PPO(actor_critic,
                         args.clip_param,
                         args.ppo_epoch,
                         args.num_mini_batch,
                         args.value_loss_coef,
                         args.entropy_coef,
                         lr=args.lr,
                         eps=args.eps,
                         max_grad_norm=args.max_grad_norm)
    elif args.algo == 'acktr':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               acktr=True)
    else:
        raise NotImplementedError

    if args.use_intrinsic:
        obs_shape = envs.observation_space.shape
        if len(obs_shape) == 3:
            action_dim = envs.action_space.n
        elif len(obs_shape) == 1:
            action_dim = envs.action_space.shape[0]

        if 'NoFrameskip' in args.env_name:
            file_name = os.path.join(
                args.experts_dir, "trajs_ppo_{}.pt".format(
                    args.env_name.split('-')[0].replace('NoFrameskip',
                                                        '').lower()))
        else:
            file_name = os.path.join(
                args.experts_dir,
                "trajs_ppo_{}.pt".format(args.env_name.split('-')[0].lower()))

        rff = RewardForwardFilter(args.gamma)
        intrinsic_rms = RunningMeanStd(shape=())
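        # RewardForwardFilter keeps a discounted running sum of intrinsic
        # rewards; its per-env moments update intrinsic_rms, whose std is used
        # later to rescale rollouts.rewards.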

        if args.intrinsic_module == 'icm':
            print('Loading pretrained intrinsic module: %s' %
                  intrinsic_model_file_name)
            inverse_model, forward_dynamics_model, encoder = torch.load(
                intrinsic_model_file_name)
            icm =  IntrinsicCuriosityModule(envs, device, inverse_model, forward_dynamics_model, \
                                            inverse_lr=args.intrinsic_lr, forward_lr=args.intrinsic_lr,\
                                            )

        if args.intrinsic_module == 'vae':
            print('Loading pretrained intrinsic module: %s' %
                  intrinsic_model_file_name)
            vae = torch.load(intrinsic_model_file_name)
            icm =  GenerativeIntrinsicRewardModule(envs, device, \
                                                   vae, lr=args.intrinsic_lr, \
                                                   )

    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              envs.observation_space.shape, envs.action_space,
                              actor_critic.recurrent_hidden_state_size)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)

    start = time.time()
    num_updates = int(
        args.num_env_steps) // args.num_steps // args.num_processes
    for j in range(num_updates):

        if args.use_linear_lr_decay:
            # decrease learning rate linearly
            utils.update_linear_schedule(
                agent.optimizer, j, num_updates,
                agent.optimizer.lr if args.algo == "acktr" else args.lr)

        for step in range(args.num_steps):
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                    rollouts.obs[step], rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])

            obs, reward, done, infos = envs.step(action)
            next_obs = obs

            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])
            rollouts.insert(obs, next_obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks, bad_masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        if args.use_intrinsic:
            for step in range(args.num_steps):
                state = rollouts.obs[step]
                action = rollouts.actions[step]
                next_state = rollouts.next_obs[step]
                if args.intrinsic_module == 'icm':
                    state = encoder(state)
                    next_state = encoder(next_state)
                with torch.no_grad():
                    rollouts.rewards[
                        step], pred_next_state = icm.calculate_intrinsic_reward(
                            state, action, next_state, args.lambda_true_action)
            if args.standardize == 'True':
                buf_rews = rollouts.rewards.cpu().numpy()
                intrinsic_rffs = np.array(
                    [rff.update(rew) for rew in buf_rews.T])
                rffs_mean, rffs_std, rffs_count = mpi_moments(
                    intrinsic_rffs.ravel())
                intrinsic_rms.update_from_moments(rffs_mean, rffs_std**2,
                                                  rffs_count)
                mean = intrinsic_rms.mean
                std = np.asarray(np.sqrt(intrinsic_rms.var))
                rollouts.rewards = rollouts.rewards / torch.from_numpy(std).to(
                    device)

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.gae_lambda, args.use_proper_time_limits)

        value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        # save for every interval-th episode or for the last epoch
        if (j % args.save_interval == 0
                or j == num_updates - 1) and args.save_dir != "":
            save_path = os.path.join(models_dir, args.algo)
            policy_file_name = os.path.join(save_path, args.env_name + '.pt')

            try:
                os.makedirs(save_path)
            except OSError:
                pass

            torch.save([
                actor_critic,
                getattr(utils.get_vec_normalize(envs), 'ob_rms', None)
            ], policy_file_name)

        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            end = time.time()
            print(
                "{} Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}\n"
                .format(args.env_name, j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        len(episode_rewards), np.mean(episode_rewards),
                        np.median(episode_rewards), np.min(episode_rewards),
                        np.max(episode_rewards), dist_entropy, value_loss,
                        action_loss))

        if (args.eval_interval is not None and len(episode_rewards) > 1
                and j % args.eval_interval == 0):
            ob_rms = utils.get_vec_normalize(envs).ob_rms
            evaluate(actor_critic, ob_rms, args.env_name, args.seed,
                     args.num_processes, eval_log_dir, device)
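
# --- Illustrative sketch (not part of the example above) ---
# The args.standardize branch above divides the intrinsic rewards by the running
# standard deviation of a discounted "reward forward filter", in the style of RND
# reward normalization. The single-process sketch below shows the idea with
# stand-in classes (RewardForwardFilter and RunningStd are illustrative names,
# not the repo's rff / intrinsic_rms objects) and performs no MPI aggregation.
import numpy as np

class RewardForwardFilter:
    """Keeps a discounted running return per environment."""
    def __init__(self, gamma):
        self.gamma = gamma
        self.rewems = None

    def update(self, rews):
        if self.rewems is None:
            self.rewems = rews
        else:
            self.rewems = self.rewems * self.gamma + rews
        return self.rewems

class RunningStd:
    """Tracks a running mean/variance by merging batch moments."""
    def __init__(self):
        self.mean, self.var, self.count = 0.0, 1.0, 1e-4

    def update(self, x):
        b_mean, b_var, b_count = np.mean(x), np.var(x), x.size
        delta = b_mean - self.mean
        tot = self.count + b_count
        m2 = (self.var * self.count + b_var * b_count
              + delta ** 2 * self.count * b_count / tot)
        self.mean += delta * b_count / tot
        self.var = m2 / tot
        self.count = tot

rff, rms = RewardForwardFilter(gamma=0.99), RunningStd()
rewards = np.random.rand(128, 16)  # (num_steps, num_processes)
rms.update(np.array([rff.update(r) for r in rewards]).ravel())
normalized = rewards / np.sqrt(rms.var)  # same scaling as rollouts.rewards / std
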
Example #10
def mpi_std(value):
    if value == []:
        value = [0.]
    if not isinstance(value, list):
        value = [value]
    return mpi_moments(np.array(value))[1][0]
def learn(
    env,
    policy_func,
    *,
    timesteps_per_actorbatch,  # timesteps per actor per update
    clip_param,
    entcoeff,  # clipping parameter epsilon, entropy coeff
    optim_epochs,
    optim_stepsize,
    optim_batchsize,  # optimization hypers
    gamma,
    lam,  # advantage estimation
    max_timesteps=0,
    max_episodes=0,
    max_iters=0,
    max_seconds=0,  # time constraint
    callback=None,  # you can do anything in the callback, since it takes locals(), globals()
    adam_epsilon=1e-5,
    schedule='constant'  # annealing for stepsize parameters (epsilon and adam)
):
    print(
        "----------------------------------------------------------------------------Learning is here ..."
    )
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_func("pi", ob_space,
                     ac_space)  # Construct network for new policy
    oldpi = policy_func("oldpi", ob_space, ac_space)  # Network for old policy
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule
    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = tf.clip_by_value(ratio, 1.0 - clip_param,
                             1.0 + clip_param) * atarg  #
    pol_surr = -tf.reduce_mean(tf.minimum(
        surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)
    vf_loss = tf.reduce_mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult],
                             losses + [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)

    U.initialize()
    adam.sync()

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_actorbatch,
                                     stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards

    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"

    # my
    #sess = tf.get_default_session()
    #writer = tf.summary.FileWriter("/home/chris/openai_logdir/tb2")
    #placeholder_rews = tf.placeholder(tf.float32)
    #placeholder_vpreds = tf.placeholder(tf.float32)
    #placeholder_advs = tf.placeholder(tf.float32)
    #placeholder_news = tf.placeholder(tf.float32)
    #placeholder_ep_rews = tf.placeholder(tf.float32)

    #tf.summary.histogram("rews", placeholder_rews)
    #tf.summary.histogram("vpreds", placeholder_vpreds)
    #tf.summary.histogram("advs", placeholder_advs)
    #tf.summary.histogram("news", placeholder_news)
    #tf.summary.scalar("ep_rews", placeholder_ep_rews)

    #writer = SummaryWriter("/home/chris/openai_logdir/tb2")

    #placeholder_ep_rews = tf.placeholder(tf.float32)
    #placeholder_ep_vpred = tf.placeholder(tf.float32)
    #placeholder_ep_atarg = tf.placeholder(tf.float32)

    #sess = tf.get_default_session()
    #writer = tf.summary.FileWriter("/home/chris/openai_logdir/x_new/tb2")
    #tf.summary.scalar("EpRews", placeholder_ep_rews)
    #tf.summary.scalar("EpVpred", placeholder_ep_vpred)

    #tf.summary.scalar("EpRews", placeholder_ep_rews)
    #tf.summary.scalar("EpVpred", placeholder_ep_vpred)
    #tf.summary.scalar("EpAtarg", placeholder_ep_atarg)
    #summ = tf.summary.merge_all()

    time_step_idx = 0

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % iters_so_far)

        seg = seg_gen.__next__()

        # Compute target value using TD(lambda) estimator, and advantage with GAE(lambda)
        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before udpate

        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate

        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                    shuffle=not pi.recurrent)
        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        assign_old_eq_new()  # set old parameter values to new parameter values
        logger.log("Optimizing...")
        logger.log(fmt_row(13, loss_names))
        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"],
                                            batch["atarg"], batch["vtarg"],
                                            cur_lrmult)
                adam.update(g, optim_stepsize * cur_lrmult)
                losses.append(newlosses)
            logger.log(fmt_row(13, np.mean(losses, axis=0)))

        logger.log("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob"], batch["ac"],
                                       batch["atarg"], batch["vtarg"],
                                       cur_lrmult)
            losses.append(newlosses)
        meanlosses, _, _ = mpi_moments(losses, axis=0)
        logger.log(fmt_row(13, meanlosses))
        for (lossval, name) in zipsame(meanlosses, loss_names):
            logger.record_tabular("loss_" + name, lossval)
        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))

        # ep_lens and ep_rets
        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        #TODO make MPI compatible
        path_lens = seg["path_lens"]
        path_rews = seg["path_rets"]

        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        timesteps_so_far += sum(path_lens)  # my addition: also count path lens
        print("timesteps_so_far: %i" % timesteps_so_far)

        iters_so_far += 1
        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)

        #my
        logger.record_tabular("PathLenMean", np.mean(path_lens))
        logger.record_tabular("PathRewMean", np.mean(path_rews))

        # MY
        dir = "/home/chris/openai_logdir/"
        env_name = env.unwrapped.spec.id

        #plots

        plt.figure(figsize=(10, 4))
        # plt.plot(vpredbefore, animated=True, label="vpredbefore")
        #new_with_nones = seg["new"]
        #np.place(new_with_nones, new_with_nones == 0, [6])
        #plt.plot(new_with_nones, 'r.', animated=True, label="new")

        for dd in np.where(seg["new"] == 1)[0]:
            plt.axvline(x=dd, color='green', linewidth=1)
            #plt.annotate('asd', xy=(2, 1), xytext=(3, 1.5),)

        #np.place(new_episode_with_nones, new_episode_with_nones == 0, [90])
        #aaaa = np.where(seg["news_episode"] > 0)
        #for dd in aaaa[0]:
        #    plt.annotate('ne'+str(dd), xy=(0, 6), xytext=(3, 1.5),)
        #plt.axvline(x=dd, color='green', linewidth=1)

        #plt.plot(ac, animated=True, label="ac")

        plt.plot(vpredbefore, 'g', label="vpredbefore", antialiased=True)

        # plot advantage of episode time steps
        #plt.plot(seg["adv"], 'b', animated=True, label="adv")
        #plt.plot(atarg, 'r', animated=True, label="atarg")
        plt.plot(tdlamret, 'y', animated=True, label="tdlamret")

        plt.legend()
        plt.title('iters_so_far: ' + str(iters_so_far))
        plt.savefig(dir + env_name + '_plot.png', dpi=300)

        if iters_so_far % 2 == 0 or iters_so_far == 0:
            #plt.ylim(ymin=-10, ymax=100)
            #plt.ylim(ymin=-15, ymax=15)

            plt.savefig(dir + '/plotiters/' + env_name + '_plot' + '_iter' +
                        str(iters_so_far).zfill(3) + '.png',
                        dpi=300)

        plt.clf()
        plt.close()

        # 3d V obs
        #ac, vpred = pi.act(True, ob)
        # check ob dim
        freq = 1  # 5
        # <= 3?
        if env.observation_space.shape[0] <= 3 and (iters_so_far % freq == 0
                                                    or iters_so_far == 1):

            figV, axV = plt.subplots()

            # surface
            #figV = plt.figure()
            #axV = Axes3D(figV)

            obs = env.observation_space

            X = np.arange(obs.low[0], obs.high[0],
                          (obs.high[0] - obs.low[0]) / 30)
            Y = np.arange(obs.low[1], obs.high[1],
                          (obs.high[1] - obs.low[1]) / 30)
            X, Y = np.meshgrid(X, Y)

            Z = np.zeros((len(X), len(Y)))

            for x in range(len(X)):
                for y in range(len(Y)):
                    #strange datatype needed??
                    myob = np.copy(ob[0])
                    myob[0] = X[0][x]
                    myob[1] = Y[y][0]
                    stochastic = True
                    ac, vpred = pi.act(stochastic, myob)
                    Z[x][y] = vpred

            plt.xlabel('First D')
            plt.ylabel('Second D')
            #plt.clabel('Value-Function')
            plt.title(env_name + ' iteration: ' + str(iters_so_far))

            # surface
            #axV.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap=cm.coolwarm)

            # heat map
            imgplot = plt.imshow(Z, interpolation='nearest')
            imgplot.set_cmap('hot')
            plt.colorbar()

            figV.savefig(dir + env_name + '_V.png', dpi=300)

            if iters_so_far % 2 == 0 or iters_so_far == 0:
                figV.savefig(dir + '/plotiters/' + env_name + '_plot' +
                             '_iter' + str(iters_so_far).zfill(3) + '_V' +
                             '.png',
                             dpi=300)

            figV.clf()
            plt.close(figV)
        """
        # transfer timesteps of iterations into timesteps of episodes
        idx_seg = 0
        for ep in range(len(lens)) :
            all_ep = episodes_so_far - len(lens) + ep

            if all_ep%100==0:
                break

            ep_rews = seg["rew"][idx_seg:idx_seg+lens[ep]]
            ep_vpred = seg["vpred"][idx_seg:idx_seg+lens[ep]]
            ep_atarg = atarg[idx_seg:idx_seg+lens[ep]]

            idx_seg += lens[ep]
            #writer.add_histogram("ep_vpred", data, iters_so_far)
            #hist_dict[placeholder_ep_rews] = ep_rews[ep]
            #sess2 = sess.run(summ, feed_dict=hist_dict)

            #summary = tf.Summary()

            #if test_ep:
            #    break
            for a in range(len(ep_rews)):
                d = ep_rews[a]
                d2 = ep_vpred[a]
                d3 = ep_atarg[a]

                #summary.value.add(tag="EpRews/"+str(all_ep), simple_value=d)
                #writer.add_summary(summary, a)

                sess2 = sess.run(summ, feed_dict={placeholder_ep_rews: d, placeholder_ep_vpred: d2, placeholder_ep_atarg: d3})
            #writer.add_summary(sess2, global_step=a)
                writer.add_summary(sess2, global_step=time_step_idx)
                time_step_idx += 1

            writer.flush()
            #writer.close()
            #logger.record_tabular("vpred_e" + str(all_ep), ep_rews[ep])

        """
        """
        ###
        sess2 = sess.run(summ, feed_dict={placeholder_rews: seg["rew"],
                                          placeholder_vpreds: seg["vpred"],
                                          placeholder_advs: seg["adv"],
                                          placeholder_news: seg["new"]})
        writer.add_summary(sess2, global_step=episodes_so_far)
        """

        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.dump_tabular()
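
# --- Illustrative sketch (for reference; not shown in this excerpt) ---
# Every learn() variant in these examples calls add_vtarg_and_adv(seg, gamma, lam)
# before building the Dataset. The sketch below follows the standard baselines
# GAE(lambda) implementation and assumes the usual segment keys ("new", "rew",
# "vpred", "nextvpred") produced by traj_segment_generator.
import numpy as np

def add_vtarg_and_adv(seg, gamma, lam):
    """Append GAE(lambda) advantages and TD(lambda) returns to a rollout segment."""
    new = np.append(seg["new"], 0)  # the extra element is only read at t+1
    vpred = np.append(seg["vpred"], seg["nextvpred"])
    T = len(seg["rew"])
    seg["adv"] = gaelam = np.empty(T, "float32")
    rew = seg["rew"]
    lastgaelam = 0
    for t in reversed(range(T)):
        nonterminal = 1 - new[t + 1]
        delta = rew[t] + gamma * vpred[t + 1] * nonterminal - vpred[t]
        gaelam[t] = lastgaelam = delta + gamma * lam * nonterminal * lastgaelam
    seg["tdlamret"] = seg["adv"] + seg["vpred"]
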
Example #12
def compete_learn(
        env,
        policy_func,
        *,
        timesteps_per_batch,  # timesteps per actor per update
        clip_param,
        entcoeff,  # clipping parameter epsilon, entropy coeff
        optim_epochs,
        optim_stepsize,
        optim_batchsize,  # optimization hypers
        gamma,
        lam,  # advantage estimation
        max_timesteps=0,
        max_episodes=0,
        max_iters=20,
        max_seconds=0,  # time constraint
        callback=None,  # you can do anything in the callback, since it takes locals(), globals()
        adam_epsilon=1e-3,
        schedule='linear'  # annealing for stepsize parameters (epsilon and adam)
):
    # Setup losses and stuff
    # ----------------------------------------
    # at this stage, ob_space and ac_space are tuples holding the two agents' spaces
    #TODO: are these tuples right? Items inside a tuple are not mutable
    #TODO: another way to store the two agents' states is to use tf.variable_scope(scope, reuse=reuse)
    len1 = 2
    ob_space = env.observation_space.spaces
    ac_space = env.action_space.spaces
    pi = [
        policy_func("pi" + str(i),
                    ob_space[i],
                    ac_space[i],
                    placeholder_name="observation" + str(i))
        for i in range(len1)
    ]
    oldpi = [
        policy_func("oldpi" + str(i),
                    ob_space[i],
                    ac_space[i],
                    placeholder_name="observation" + str(i))
        for i in range(len1)
    ]
    atarg = [
        tf.placeholder(dtype=tf.float32, shape=[None]) for i in range(len1)
    ]
    ret = [tf.placeholder(dtype=tf.float32, shape=[None]) for i in range(len1)]
    tdlamret = [[] for i in range(len1)]
    # TODO: revise lrmult back to how it was before
    # lrmult = 1.0  # for simplicity, only use a constant learning rate multiplier
    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule
    clip_param = clip_param * lrmult

    #TODO: I do not fully understand this point; originally it was
    # ob=U.get_placeholder_cached(name="ob")
    #TODO: possible bug to fix: get_placeholder_cached is global, so an observation placeholder can only be cached once; the next lookup of the same name returns the previously cached placeholder. I do not know whether different name scopes affect this.
    # ob1 = U.get_placeholder_cached(name="observation1") # Note: I am not sure about this point
    # # ob2 = U.get_placeholder_cached(name="observation2")
    # ob1 = U.get_placeholder_cached(name="observation0")  # Note: I am not sure about this point
    # ob2 = U.get_placeholder_cached(name="observation1")
    #TODO: the only remaining question is that the pi and oldpi networks both have an ob placeholder named "observation"; even in the original baselines implementation, do pi and oldpi share the observation placeholder? I think they do not.

    ob = [
        U.get_placeholder_cached(name="observation" + str(i))
        for i in range(len1)
    ]
    # ac = tuple([pi[i].act(stochastic=True, observation=env.observation_space[i])[0]
    #      for i in range(len1)])
    # TODO: to make the policy work, I changed the observation argument passed into the pi function to s, which comes from env.reset()
    # s = env.reset()
    # ac = tuple([pi[i].act(stochastic=True, observation=s[i])[0]
    #             for i in range(len1)])

    ac = [pi[i].pdtype.sample_placeholder([None]) for i in range(len1)]
    kloldnew = [oldpi[i].pd.kl(pi[i].pd) for i in range(len1)]
    ent = [pi[i].pd.entropy() for i in range(len1)]
    print("ent1 and ent2 are {} and {}".format(ent[0], ent[1]))
    meankl = [U.mean(kloldnew[i]) for i in range(len1)]
    meanent = [U.mean(ent[i]) for i in range(len1)]

    pol_entpen = [(-entcoeff) * meanent[i] for i in range(len1)]
    ratio = [
        tf.exp(pi[i].pd.logp(ac[i]) - oldpi[i].pd.logp(ac[i]))
        for i in range(len1)
    ]
    # ratio = [tf.exp(pi[i].pd.logp(ac) - oldpi[i].pd.logp(ac[i])) for i in range(len1)] #pnew / pold
    surr1 = [ratio[i] * atarg[i] for i in range(len1)]
    # U.clip wraps tf.clip_by_value(t, clip_value_min, clip_value_max, name=None),
    # where t is a 'Tensor'
    surr2 = [
        U.clip(ratio[i], 1.0 - clip_param, 1.0 + clip_param)
        for i in range(len1)
    ]
    pol_surr = [-U.mean(tf.minimum(surr1[i], surr2[i])) for i in range(len1)]
    vf_loss = [U.mean(tf.square(pi[i].vpred - ret[i])) for i in range(len1)]
    total_loss = [
        pol_surr[i] + pol_entpen[i] + vf_loss[i] for i in range(len1)
    ]
    # I now realize that the following miscellaneous losses are just operations, not tensors,
    # so they should be made into lists containing the two agents' info
    # surr2 = U.clip(ratio[i], 1.0 - clip_param, 1.0 + clip_param)
    # pol_surr = -U.mean(tf.minimum(surr1[i], surr2[i]))
    # vf_loss = U.mean(tf.square(pi[i].vpred - ret[i]))
    # total_loss = pol_surr + pol_entpen + vf_loss

    #TODO: alternatively, I choose to revise losses as follows:
    losses = [[pol_surr[i], pol_entpen[i], vf_loss[i], meankl[i], meanent[i]]
              for i in range(len1)]
    loss_names = ["pol_sur", "pol_entpen", "vf_loss", "kl", "ent"]
    var_list = [pi[i].get_trainable_variables() for i in range(len1)]

    lossandgrad = [
        U.function([ob[i], ac[i], atarg[i], ret[i], lrmult],
                   losses[i] + [U.flatgrad(total_loss[i], var_list[i])])
        for i in range(len1)
    ]
    adam = [MpiAdam(var_list[i], epsilon=adam_epsilon) for i in range(len1)]

    #TODO: I suspect this cannot function as expected, because the result is a list of functions that will not execute automatically
    # assign_old_eq_new = [U.function([],[], updates=[tf.assign(oldv, newv)
    #     for (oldv, newv) in zipsame(oldpi[i].get_variables(), pi[i].get_variables())]) for i in range(len1)]

    # compute_losses is a function; rather than copying it wholesale, each agent's
    # tensors should be passed into its own instance
    compute_losses = [
        U.function([ob[i], ac[i], atarg[i], ret[i], lrmult], losses[i])
        for i in range(len1)
    ]
    # sess = U.get_session()
    # writer = tf.summary.FileWriter(logdir='log-mlp',graph=sess.graph)
    # now when the training iteration ends, save the trained model and test the win rate of the two.
    pi0_variables = slim.get_variables(scope="pi0")
    pi1_variables = slim.get_variables(scope="pi1")
    parameters_to_save_list0 = [v for v in pi0_variables]
    parameters_to_save_list1 = [v for v in pi1_variables]
    parameters_to_save_list = parameters_to_save_list0 + parameters_to_save_list1
    saver = tf.train.Saver(parameters_to_save_list)
    restore = tf.train.Saver(parameters_to_save_list)
    U.initialize()
    restore.restore(U.get_session(), "parameter/500/500.pkl")
    U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(pi[1].get_variables(), pi[0].get_variables())
        ])()
    # U.get_session().run  # leftover no-op: the call is never actually made
    # [adam[i].sync() for i in range(2)]
    adam[0].sync()
    adam[1].sync()
    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     horizon=timesteps_per_batch,
                                     stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()

    lenbuffer = [deque(maxlen=100)
                 for i in range(len1)]  # rolling buffer for episode lengths
    rewbuffer = [deque(maxlen=100)
                 for i in range(len1)]  # rolling buffer for episode rewards

    parameters_savers = []
    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        # saver.restore()

        logger.log("********** Iteration %i ************" % iters_so_far)

        seg = seg_gen.__next__()
        #TODO: I have to fix this function so it returns the right seg["adv"] and seg["tdlamret"]
        add_vtarg_and_adv(seg, gamma, lam)

        losses = [[] for i in range(len1)]
        meanlosses = [[] for i in range(len1)]
        for i in range(len1):
            ob[i], ac[i], atarg[i], tdlamret[i] = seg["ob"][i], seg["ac"][
                i], seg["adv"][i], seg["tdlamret"][i]
            # ob_extend = np.expand_dims(ob[i],axis=0)
            # ob[i] = ob_extend
            vpredbefore = seg["vpred"][
                i]  # predicted value function before udpate
            atarg[i] = (atarg[i] - atarg[i].mean()) / atarg[i].std(
            )  # standardized advantage function estimate
            d = Dataset(dict(ob=ob[i],
                             ac=ac[i],
                             atarg=atarg[i],
                             vtarg=tdlamret[i]),
                        shuffle=not pi[i].recurrent)
            optim_batchsize = optim_batchsize or ob[i].shape[0]

            if hasattr(pi[i], "ob_rms"):
                pi[i].ob_rms.update(
                    ob[i])  # update running mean/std for policy

            #TODO: I have to make sure how assign_old_eq_new works and whether to apply it for each agent
            # Yes, I can confirm it works now

            # save network parameters using tf.train.Saver
            #     saver_name = "saver" + str(iters_so_far)

            U.function([], [],
                       updates=[
                           tf.assign(oldv, newv) for (oldv, newv) in zipsame(
                               oldpi[i].get_variables(), pi[i].get_variables())
                       ])()
            # set old parameter values to new parameter values
            # Here we do a bunch of optimization epochs over the data
            logger.log("Optimizing the agent{}...".format(i))
            logger.log(fmt_row(13, loss_names))
            for _ in range(optim_epochs):
                losses[i] = [
                ]  # list of tuples, each of which gives the loss for a minibatch
                for batch in d.iterate_once(optim_batchsize):
                    # batch["ob"] = np.expand_dims(batch["ob"], axis=0)
                    *newlosses, g = lossandgrad[i](batch["ob"], batch["ac"],
                                                   batch["atarg"],
                                                   batch["vtarg"], cur_lrmult)
                    adam[i].update(g, optim_stepsize * cur_lrmult)
                    losses[i].append(newlosses)
                    logger.log(fmt_row(13, np.mean(losses[i], axis=0)))

            logger.log("Evaluating losses of agent{}...".format(i))
            losses[i] = []
            for batch in d.iterate_once(optim_batchsize):
                newlosses = compute_losses[i](batch["ob"], batch["ac"],
                                              batch["atarg"], batch["vtarg"],
                                              cur_lrmult)
                losses[i].append(newlosses)
            meanlosses[i], _, _ = mpi_moments(losses[i], axis=0)
            logger.log(fmt_row(13, meanlosses[i]))
            for (lossval, name) in zipsame(meanlosses[i], loss_names):
                logger.record_tabular("loss_" + name, lossval)
            logger.record_tabular("ev_tdlam_before{}".format(i),
                                  explained_variance(vpredbefore, tdlamret[i]))

            lrlocal = (seg["ep_lens"][i], seg["ep_rets"][i])  # local values
            listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
            lens, rews = map(flatten_lists, zip(*listoflrpairs))
            lenbuffer[i].extend(lens)
            rewbuffer[i].extend(rews)
            logger.record_tabular("EpLenMean {}".format(i),
                                  np.mean(lenbuffer[i]))
            logger.record_tabular("EpRewMean {}".format(i),
                                  np.mean(rewbuffer[i]))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        temp_pi = policy_func("temp_pi" + str(iters_so_far),
                              ob_space[0],
                              ac_space[0],
                              placeholder_name="temp_pi_observation" +
                              str(iters_so_far))
        U.function([], [],
                   updates=[
                       tf.assign(oldv, newv) for (oldv, newv) in zipsame(
                           temp_pi.get_variables(), pi[0].get_variables())
                   ])()
        parameters_savers.append(temp_pi)

        # every 3 iterations, assign a sampled past snapshot of agent0's parameters to agent1
        if iters_so_far % 3 == 0:
            sample_iteration = int(
                np.random.uniform(iters_so_far / 2, iters_so_far))
            print("now assign the {}th parameter of agent0 to agent1".format(
                sample_iteration))
            pi_restore = parameters_savers[sample_iteration]
            U.function([], [],
                       updates=[
                           tf.assign(oldv, newv)
                           for (oldv,
                                newv) in zipsame(pi[1].get_variables(),
                                                 pi_restore.get_variables())
                       ])()

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.dump_tabular()

    # # now when the training iteration ends, save the trained model and test the win rate of the two.
    # pi0_variables = slim.get_variables(scope = "pi0")
    # pi1_variables = slim.get_variables(scope = "pi1")
    # parameters_to_save_list0 = [v for v in pi0_variables]
    # parameters_to_save_list1 = [v for v in pi1_variables]
    # parameters_to_save_list = parameters_to_save_list0 + parameters_to_save_list1
    # saver = tf.train.Saver(parameters_to_save_list)
    # parameters_path = 'parameter/'
    # tf.train.Saver()
    save_path = saver.save(U.get_session(), "parameter/800/800.pkl")
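
# --- Illustrative helper (not part of compete_learn above) ---
# compete_learn builds the same one-off U.function of tf.assign updates in several
# places (syncing oldpi, snapshotting into temp_pi, restoring a sampled snapshot
# into pi[1]). A small helper expresses that pattern once; it assumes the same
# U / tf / zipsame imports as the surrounding code.
def make_param_copier(src_pi, dst_pi):
    # Returns a callable that copies src_pi's variables into dst_pi.
    # Build it once, outside the training loop, so the assign ops are not
    # re-created on every call.
    return U.function(
        [], [],
        updates=[tf.assign(dst_v, src_v)
                 for (dst_v, src_v) in zipsame(dst_pi.get_variables(),
                                               src_pi.get_variables())])

# usage: copy_pi0_to_pi1 = make_param_copier(pi[0], pi[1]); copy_pi0_to_pi1()
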
def learn(
        env,
        policy_fn,
        timesteps_per_actorbatch,  # timesteps per actor per update
        clip_param,
        entcoeff,  # clipping parameter epsilon, entropy coeff
        optim_epochs,
        optim_stepsize,
        optim_batchsize,  # optimization hypers
        gamma,
        lam,  # advantage estimation
        max_timesteps=0,
        max_episodes=0,
        max_iters=0,
        max_seconds=0,  # time constraint
        callback=None,  # you can do anything in the callback, since it takes locals(), globals()
        adam_epsilon=1e-5,
        schedule='constant',  # annealing for stepsize parameters (epsilon and adam)
        save_per_iter=50,
        max_sample_traj=10,
        ckpt_dir=None,
        log_dir=None,
        task_name="origin",
        sample_stochastic=True,
        load_model_path=None,
        task="train"):
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_fn("pi", ob_space,
                   ac_space)  # Construct network for new policy
    oldpi = policy_fn("oldpi", ob_space, ac_space)  # Network for old policy
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule
    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    ob_p = U.get_placeholder_cached(name="ob_physics")
    ob_f = U.get_placeholder_cached(name="ob_frames")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = tf.clip_by_value(ratio, 1.0 - clip_param,
                             1.0 + clip_param) * atarg  #
    pol_surr = -tf.reduce_mean(tf.minimum(
        surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)
    vf_loss = tf.reduce_mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    if pi.recurrent:
        state_c = U.get_placeholder_cached(name="state_c")
        state_h = U.get_placeholder_cached(name="state_h")
        lossandgrad = U.function(
            [ob_p, ob_f, state_c, state_h, ac, atarg, ret, lrmult],
            losses + [U.flatgrad(total_loss, var_list)])
        compute_losses = U.function(
            [ob_p, ob_f, state_c, state_h, ac, atarg, ret, lrmult], losses)
    else:
        lossandgrad = U.function([ob_p, ob_f, ac, atarg, ret, lrmult],
                                 losses + [U.flatgrad(total_loss, var_list)])
        compute_losses = U.function([ob_p, ob_f, ac, atarg, ret, lrmult],
                                    losses)

    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])

    writer = U.FileWriter(log_dir)
    U.initialize()
    adam.sync()

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_actorbatch,
                                     stochastic=True)
    traj_gen = traj_episode_generator(pi,
                                      env,
                                      timesteps_per_actorbatch,
                                      stochastic=sample_stochastic)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards

    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"

    loss_stats = stats(loss_names)
    ep_stats = stats(["Reward", "Episode_Length", "Episode_This_Iter"])
    if task == "sample_trajectory":
        sample_trajectory(load_model_path, max_sample_traj, traj_gen,
                          task_name, sample_stochastic)
        sys.exit()

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        # Save model
        if iters_so_far % save_per_iter == 0 and ckpt_dir is not None:
            U.save_state(os.path.join(ckpt_dir, task_name),
                         counter=iters_so_far)

        logger.log2("********** Iteration %i ************" % iters_so_far)

        seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"]
        ob_p, ob_f = zip(*ob)
        ob_p = np.array(ob_p)
        ob_f = np.array(ob_f)
        vpredbefore = seg["vpred"]  # predicted value function before udpate
        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate

        if pi.recurrent:
            sc, sh = zip(*seg["state"])
            d = Dataset(dict(ob_p=ob_p,
                             ob_f=ob_f,
                             ac=ac,
                             atarg=atarg,
                             vtarg=tdlamret,
                             state_c=np.array(sc),
                             state_h=np.array(sh)),
                        shuffle=not pi.recurrent)
        else:
            d = Dataset(dict(ob_p=ob_p,
                             ob_f=ob_f,
                             ac=ac,
                             atarg=atarg,
                             vtarg=tdlamret),
                        shuffle=not pi.recurrent)
        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob_p)  # update running mean/std for policy

        assign_old_eq_new()  # set old parameter values to new parameter values
        logger.log2("Optimizing...")
        logger.log2(fmt_row(13, loss_names))
        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                if pi.recurrent:
                    lg = lossandgrad(batch["ob_p"], batch["ob_f"],
                                     batch["state_c"], batch["state_h"],
                                     batch["ac"], batch["atarg"],
                                     batch["vtarg"], cur_lrmult)
                else:
                    lg = lossandgrad(batch["ob_p"], batch["ob_f"], batch["ac"],
                                     batch["atarg"], batch["vtarg"],
                                     cur_lrmult)
                newlosses = lg[:-1]
                g = lg[-1]
                adam.update(g, optim_stepsize * cur_lrmult)
                losses.append(newlosses)
            logger.log2(fmt_row(13, np.mean(losses, axis=0)))

        logger.log2("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            if pi.recurrent:
                newlosses = compute_losses(batch["ob_p"], batch["ob_f"],
                                           batch["state_c"], batch["state_h"],
                                           batch["ac"], batch["atarg"],
                                           batch["vtarg"], cur_lrmult)
            else:
                newlosses = compute_losses(batch["ob_p"], batch["ob_f"],
                                           batch["ac"], batch["atarg"],
                                           batch["vtarg"], cur_lrmult)
            losses.append(newlosses)
        meanlosses, _, _ = mpi_moments(losses, axis=0)
        logger.log2(fmt_row(13, meanlosses))
        for (lossval, name) in zipsame(meanlosses, loss_names):
            logger.record_tabular("loss_" + name, lossval)
        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))
        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.dump_tabular()
            loss_stats.add_all_summary(writer, meanlosses, iters_so_far)
            ep_stats.add_all_summary(
                writer, [np.mean(rewbuffer),
                         np.mean(lenbuffer),
                         len(lens)], iters_so_far)

    return pi
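
# --- Quick check of the learning rate schedule (illustrative) ---
# All of these learn() variants compute cur_lrmult the same way: with
# schedule='linear' the Adam step size (and, where it is multiplied in, the clip
# range) anneals from its initial value to zero over max_timesteps. The helper
# name below is only for illustration.
def linear_lrmult(timesteps_so_far, max_timesteps):
    # multiplier applied to optim_stepsize (and the clip range) each iteration
    return max(1.0 - float(timesteps_so_far) / max_timesteps, 0)

assert linear_lrmult(0, 1000000) == 1.0        # start of training
assert linear_lrmult(250000, 1000000) == 0.75  # a quarter of the budget used
assert linear_lrmult(1200000, 1000000) == 0.0  # clamped once past the budget
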
def learn(
    env,
    policy_fn,
    timesteps_per_actorbatch,  # timesteps per actor per update
    clip_param,
    entcoeff,  # clipping parameter epsilon, entropy coeff
    optim_epochs,
    optim_stepsize,
    optim_batchsize,  # optimization hypers
    gamma,
    lam,  # advantage estimation
    max_timesteps=0,
    max_episodes=0,
    max_iters=0,
    max_seconds=0,  # time constraint
    callback=None,  # you can do anything in the callback, since it takes locals(), globals()
    adam_epsilon=1e-5,
    schedule='constant',  # annealing for stepsize parameters (epsilon and adam)
    resume_training=False,
):
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_fn("pi", ob_space,
                   ac_space)  # Construct network for new policy
    oldpi = policy_fn("oldpi", ob_space, ac_space)  # Network for old policy
    atarg = tf.placeholder(
        name="atarg", dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(name="ret", dtype=tf.float32,
                         shape=[None])  # Empirical return

    summ_writer = tf.summary.FileWriter("/tmp/tensorboard",
                                        U.get_session().graph)
    U.launch_tensorboard_in_background("/tmp/tensorboard")

    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule

    ob = U.get_placeholder_cached(name="ob")
    ob_2 = U.get_placeholder_cached(name="ob_2")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = tf.clip_by_value(ratio, 1.0 - clip_param,
                             1.0 + clip_param) * atarg  #
    pol_surr = -tf.reduce_mean(tf.minimum(
        surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)
    vf_loss = tf.reduce_mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    check_ops = tf.add_check_numerics_ops()

    var_list = pi.get_trainable_variables()
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    lossandgrad = U.function(
        [ob, ob_2, ac, atarg, ret, lrmult],
        losses + [U.flatgrad(total_loss, var_list)],
        updates=[check_ops],
    )
    debugnan = U.function([ob, ob_2, ac, atarg, ret, lrmult],
                          losses + [ratio, surr1, surr2])
    dbgnames = loss_names + ["ratio", "surr1", "surr2"]

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ob_2, ac, atarg, ret, lrmult], losses)

    U.initialize()
    adam.sync()

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_actorbatch,
                                     stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards

    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"

    if resume_training:
        pi.load_variables("/tmp/rlnav_model")
        oldpi.load_variables("/tmp/rlnav_model")
    else:
        # clear reward history log
        with open(rew_hist_filepath, 'w') as f:
            f.write('')

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % iters_so_far)

        seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ob_2, ac, atarg, tdlamret = seg["ob"], seg["ob_2"], seg["ac"], seg[
            "adv"], seg["tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before udpate
        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate
        # flatten first two dimensions and put into batch maker
        dataset_dict = dict(ob=ob,
                            ob_2=ob_2,
                            ac=ac,
                            atarg=atarg,
                            vtarg=tdlamret)

        def flatten_horizon_and_agent_dims(array):
            """ using F order because we assume that the shape is (horizon, n_agents)
            and we want the new flattened first dimension to first run through
            horizon, then n_agents in order to not cut up the sequentiality of
            the experiences """
            new_shape = (array.shape[0] * array.shape[1], ) + array.shape[2:]
            return array.reshape(new_shape, order='F')

        for key in dataset_dict:
            dataset_dict[key] = flatten_horizon_and_agent_dims(
                dataset_dict[key])
        d = Dataset(dataset_dict, shuffle=not pi.recurrent)
        optim_batchsize = optim_batchsize or ob.shape[0]
        n_agents = ob.shape[1]

        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        #export rewards to log file
        rewplot = np.array(seg["rew"])
        with open(rew_hist_filepath, 'ab') as f:
            np.savetxt(f, rewplot, delimiter=',')

        assign_old_eq_new()  # set old parameter values to new parameter values
        logger.log("Optimizing...")
        logger.log(fmt_row(13, loss_names))
        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                landg = lossandgrad(batch["ob"], batch["ob_2"], batch["ac"],
                                    batch["atarg"], batch["vtarg"], cur_lrmult)
                newlosses, g = landg[:-1], landg[-1]
                # debug nans
                if np.any(np.isnan(newlosses)):
                    dbglosses = debugnan(batch["ob"], batch["ob_2"],
                                         batch["ac"], batch["atarg"],
                                         batch["vtarg"], cur_lrmult)
                    raise ValueError("Nan detected in losses: {} {}".format(
                        dbgnames, dbglosses))
                adam.update(g, optim_stepsize * cur_lrmult)
                losses.append(newlosses)
            logger.log(fmt_row(13, np.mean(losses, axis=0)))

        pi.save_variables("/tmp/rlnav_model")
        pi.load_variables("/tmp/rlnav_model")

        logger.log("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob"], batch["ob_2"], batch["ac"],
                                       batch["atarg"], batch["vtarg"],
                                       cur_lrmult)
            losses.append(newlosses)
        meanlosses, _, _ = mpi_moments(losses, axis=0)
        logger.log(fmt_row(13, meanlosses))
        for (lossval, name) in zipsame(meanlosses, loss_names):
            logger.record_tabular("loss_" + name, lossval)
        logger.record_tabular(
            "ev_tdlam_before",
            np.average([
                explained_variance(vpredbefore[:, i], tdlamret[:, i])
                for i in range(n_agents)
            ]))  # average of explained variance for each agent
        #         lrlocal = (seg["ep_lens"], seg["ep_rets"]) # local values
        #         listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal) # list of tuples
        #         lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lens, rews = (seg["ep_lens"], seg["ep_rets"])
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        tf.summary.scalar("EpLenMean", np.mean(lenbuffer))
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.dump_tabular()

    return pi
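
# --- Quick check of the Fortran-order flattening used above (illustrative) ---
# flatten_horizon_and_agent_dims relies on order='F' so that, after flattening the
# (horizon, n_agents) leading dimensions, each agent's trajectory stays contiguous
# instead of being interleaved step by step.
import numpy as np

x = np.arange(6).reshape(3, 2)    # horizon=3, n_agents=2; column i is agent i's sequence
flat_f = x.reshape(6, order='F')  # agent 0's steps first, then agent 1's steps
flat_c = x.reshape(6, order='C')  # would interleave the agents at every step
assert list(flat_f) == [0, 2, 4, 1, 3, 5]
assert list(flat_c) == [0, 1, 2, 3, 4, 5]
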
Example #15
    def replay(self, seg_list, batch_size):
        print(self.scope + " training")

        if self.schedule == 'constant':
            cur_lrmult = 1.0
        elif self.schedule == 'linear':
            cur_lrmult = max(
                1.0 - float(self.timesteps_so_far) / self.max_timesteps, 0)

        # Here we do a bunch of optimization epochs over the data
        # Batching idea: in each epoch, get the gradient g from every battle, average them, then optimize; repeat for several epochs
        newlosses_list = []
        logger.log("Optimizing...")
        loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]
        logger.log(fmt_row(13, loss_names))
        for _ in range(self.optim_epochs):
            g_list = []
            for seg in seg_list:

                self.add_vtarg_and_adv(seg, self.gamma, self.lam)

                # print(seg)

                # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
                ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg[
                    "adv"], seg["tdlamret"]
                vpredbefore = seg[
                    "vpred"]  # predicted value function before update
                atarg = (atarg - atarg.mean()) / atarg.std(
                )  # standardized advantage function estimate
                d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                            shuffle=not self.pi.recurrent)

                if hasattr(self.pi, "ob_rms"):
                    self.pi.ob_rms.update(
                        ob)  # update running mean/std for policy

                self.assign_old_eq_new(
                )  # set old parameter values to new parameter values

                # take all transitions as one full batch
                batch = d.next_batch(d.n)
                # print("ob", batch["ob"], "ac", batch["ac"], "atarg", batch["atarg"], "vtarg", batch["vtarg"])
                *newlosses, debug_atarg, pi_ac, opi_ac, vpred, pi_pd, opi_pd, kl_oldnew, total_loss, var_list, grads, g = \
                    self.lossandgrad(batch["ob"], batch["ac"], batch["atarg"], batch["vtarg"], cur_lrmult)
                # print("debug_atarg", debug_atarg, "pi_ac", pi_ac, "opi_ac", opi_ac, "vpred", vpred, "pi_pd", pi_pd,
                #       "opi_pd", opi_pd, "kl_oldnew", kl_oldnew, "var_mean", np.mean(g), "total_loss", total_loss)
                if np.isnan(np.mean(g)):
                    print('output nan, ignore it!')
                else:
                    g_list.append(g)
                    newlosses_list.append(newlosses)

            # after batching, average the gradients and then update the model
            if len(g_list) > 0:
                avg_g = np.mean(g_list, axis=0)
                self.adam.update(avg_g, self.optim_stepsize * cur_lrmult)
                logger.log(fmt_row(13, np.mean(newlosses_list, axis=0)))

        logger.log("Evaluating losses...")
        losses = []
        for seg in seg_list:
            self.add_vtarg_and_adv(seg, self.gamma, self.lam)

            # print(seg)

            # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
            ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
                "tdlamret"]
            vpredbefore = seg[
                "vpred"]  # predicted value function before update
            atarg = (atarg - atarg.mean()) / atarg.std(
            )  # standardized advantage function estimate
            d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                        shuffle=not self.pi.recurrent)
            # take all transitions as one full batch
            batch = d.next_batch(d.n)
            newlosses = self.compute_losses(batch["ob"], batch["ac"],
                                            batch["atarg"], batch["vtarg"],
                                            cur_lrmult)
            losses.append(newlosses)
        print(losses)

        meanlosses, _, _ = mpi_moments(losses, axis=0)
        logger.log(fmt_row(13, meanlosses))
        for (lossval, name) in zipsame(meanlosses, loss_names):
            if np.isinf(lossval):
                debug = True
            logger.record_tabular("loss_" + name, lossval)
        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))
        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(self.flatten_lists, zip(*listoflrpairs))
        self.lenbuffer.extend(lens)
        self.rewbuffer.extend(rews)
        last_rew = self.rewbuffer[-1] if len(self.rewbuffer) > 0 else 0
        logger.record_tabular("LastRew", last_rew)
        logger.record_tabular(
            "LastLen", 0 if len(self.lenbuffer) <= 0 else self.lenbuffer[-1])
        logger.record_tabular("EpLenMean", np.mean(self.lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(self.rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        self.episodes_so_far += len(lens)
        self.timesteps_so_far += sum(lens)
        self.iters_so_far += 1
        logger.record_tabular("EpisodesSoFar", self.episodes_so_far)
        logger.record_tabular("TimestepsSoFar", self.timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - self.tstart)
        logger.record_tabular("IterSoFar", self.iters_so_far)
        logger.record_tabular("CalulateActions", self.act_times)
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.dump_tabular()
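
# --- Illustrative sketch of the gradient averaging in replay() above ---
# replay() computes one flat gradient per battle segment, drops any gradient that
# contains NaNs, and applies the mean of the survivors in a single MpiAdam step per
# epoch. The stripped-down helper below shows that step; adam stands for any object
# with the same update(grad, stepsize) interface as MpiAdam.
import numpy as np

def averaged_update(adam, grads, stepsize):
    # grads: list of flat gradient vectors, one per segment
    finite = [g for g in grads if not np.isnan(np.mean(g))]
    if finite:  # skip the update entirely if every gradient was NaN
        adam.update(np.mean(finite, axis=0), stepsize)
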
def learn(env, policy_fn, *,
        timesteps_per_actorbatch, # timesteps per actor per update
        clip_param, entcoeff, # clipping parameter epsilon, entropy coeff
        optim_epochs, optim_stepsize, optim_batchsize,# optimization hypers
        gamma, lam, # advantage estimation
        max_timesteps=0, max_episodes=0, max_iters=0, max_seconds=0,  # time constraint
        callback=None, # you can do anything in the callback, since it takes locals(), globals()
        adam_epsilon=1e-5,
        schedule='constant' # annealing for stepsize parameters (epsilon and adam)
        ):
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_fn("pi", ob_space, ac_space) # Construct network for new policy
    oldpi = policy_fn("oldpi", ob_space, ac_space) # Network for old policy
    atarg = tf.placeholder(dtype=tf.float32, shape=[None]) # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None]) # Empirical return

    lrmult = tf.placeholder(name='lrmult', dtype=tf.float32, shape=[]) # learning rate multiplier, updated with schedule
    clip_param = clip_param * lrmult # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac)) # pnew / pold
    surr1 = ratio * atarg # surrogate from conservative policy iteration
    surr2 = tf.clip_by_value(ratio, 1.0 - clip_param, 1.0 + clip_param) * atarg #
    pol_surr = - tf.reduce_mean(tf.minimum(surr1, surr2)) # PPO's pessimistic surrogate (L^CLIP)
    vf_loss = tf.reduce_mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult], losses + [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    assign_old_eq_new = U.function([],[], updates=[tf.assign(oldv, newv)
        for (oldv, newv) in zipsame(oldpi.get_variables(), pi.get_variables())])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)

    U.initialize()
    adam.sync()

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi, env, timesteps_per_actorbatch, stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100) # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100) # rolling buffer for episode rewards

    assert sum([max_iters>0, max_timesteps>0, max_episodes>0, max_seconds>0])==1, "Only one time constraint permitted"

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult =  max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************"%iters_so_far)

        seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)
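        # add_vtarg_and_adv is assumed to be the usual baselines GAE helper:
        #   delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_{t+1}) - V(s_t)
        #   seg["adv"][t] = sum_k (gamma * lam)^k * delta_{t+k}
        #   seg["tdlamret"] = seg["adv"] + seg["vpred"]   (lambda-return value target)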

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]
        vpredbefore = seg["vpred"] # predicted value function before udpate
        atarg = (atarg - atarg.mean()) / atarg.std() # standardized advantage function estimate
        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret), shuffle=not pi.recurrent)
        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"): pi.ob_rms.update(ob) # update running mean/std for policy

        assign_old_eq_new() # set old parameter values to new parameter values
        logger.log("Optimizing...")
        logger.log(fmt_row(13, loss_names))
        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            losses = [] # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"], batch["atarg"], batch["vtarg"], cur_lrmult)
                adam.update(g, optim_stepsize * cur_lrmult)
                losses.append(newlosses)
            logger.log(fmt_row(13, np.mean(losses, axis=0)))

        logger.log("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob"], batch["ac"], batch["atarg"], batch["vtarg"], cur_lrmult)
            losses.append(newlosses)
        meanlosses,_,_ = mpi_moments(losses, axis=0)
        logger.log(fmt_row(13, meanlosses))
        for (lossval, name) in zipsame(meanlosses, loss_names):
            logger.record_tabular("loss_"+name, lossval)
        logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))
        lrlocal = (seg["ep_lens"], seg["ep_rets"]) # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal) # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
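        # Each MPI worker contributes its local (episode lengths, episode returns);
        # allgather + flatten gives every rank the full cross-worker lists, so the
        # rolling-buffer statistics logged below reflect all parallel actors.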
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        if MPI.COMM_WORLD.Get_rank()==0:
            logger.dump_tabular()
    def update_constraint_network(self, do_train_cnet):
        '''
        Update the constraint network and control learning rate decay / training early stop
        depending on training loss and accuracy
        '''
        if not do_train_cnet:
            activation_probability_after = None  # no CNet evaluation this iteration
            self.logger.log('Skip training CNet this iteration')
            return do_train_cnet, activation_probability_after
        best_cnet_losses = None
        best_classification_accuracy = [0., 0., 0.]
        i_epoch_best = 0
        if self.extra_args.cnet_decay_epochs > 0:
            n_epochs_without_improvement = 0
        for i_epoch in range(self.extra_args.cnet_training_epochs):
            if (i_epoch % 10) == 0:
                epoch_str = 'Iter {0}, epoch {1}/{2}'.format(
                    self.iters_so_far, i_epoch,
                    self.extra_args.cnet_training_epochs)
                epoch_str += ' ({0}: positive {1:.1f} %, negative {2:.1f} %, avg {3:.1f} % out of {4} / {5} / {6})'.format(
                    i_epoch_best, 100. * best_classification_accuracy[0],
                    100. * best_classification_accuracy[1],
                    100. * best_classification_accuracy[2],
                    self.n_positive_demonstrations_global,
                    self.n_negative_demonstrations_global,
                    self.n_total_demonstrations_global)
                self.logger.log(epoch_str)
                self.logger.log(fmt_row(13, self.cnet_loss_and_accuracy_names))

            losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            training_classification_accuracy = []
            #for batch in d.iterate_once(optim_batchsize):
            if self.do_dummy_cnet_update:
                batch_generator = self.n_batches_this_epoch * [None]
            else:
                batch_generator = self.constraint_demonstration_buffer.iterate_epoch_balanced(
                    self.extra_args.cnet_batch_size,
                    n_max=self.n_max_demonstrations_global)
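            # iterate_epoch_balanced is assumed (from its name and the n_max cap) to
            # yield minibatches with a balanced mix of positive and negative
            # demonstrations, drawing at most n_max_demonstrations_global samples.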
            for i_batch, batch_demonstrations in enumerate(batch_generator):
                if self.do_dummy_cnet_update:
                    newlosses, g = self.cnet_newlosses_dummy, self.cnet_g_dummy
                    new_training_classification_accuracy = self.cnet_scores_dummy
                else:
                    batch_cnet_observations = np.stack(
                        [_e.state for _e in batch_demonstrations])
                    batch_cnet_actions = np.stack(
                        [_e.action for _e in batch_demonstrations])
                    batch_cnet_action_indicators = np.array(
                        [_e.action_indicator for _e in batch_demonstrations])
                    *newlosses, g = self.cnet_lossandgrad(
                        batch_cnet_observations, batch_cnet_actions,
                        batch_cnet_action_indicators)
                    training_classification_scores = self.cnet_compute_scores(
                        batch_cnet_observations, batch_cnet_actions,
                        batch_cnet_action_indicators)
                    new_training_classification_accuracy = self.calc_classification_accuracy(
                        *training_classification_scores)
                self.cnet_adam.update(
                    g, self.extra_args.cnet_learning_rate *
                    self.cnet_cur_lrmult_epoch)
                losses.append(newlosses)
                training_classification_accuracy.append(
                    new_training_classification_accuracy)
                if i_batch >= (self.n_batches_this_epoch - 1):
                    break
            mean_cnet_losses, _, _ = mpi_moments(losses, axis=0)
            mean_classification_accuracy, _, _ = mpi_moments(
                training_classification_accuracy, axis=0)
            mean_losses_and_accuracy = np.concatenate(
                [mean_cnet_losses, mean_classification_accuracy])
            self.logger.log(fmt_row(13, mean_losses_and_accuracy))
            if self.test_early_stop_classification_accuracy(
                    i_epoch, mean_classification_accuracy[0],
                    mean_classification_accuracy[1]):
                break
            if (i_epoch == 0) or self.test_improvement_classification_accuracy(
                    mean_cnet_losses, best_cnet_losses,
                    mean_classification_accuracy,
                    best_classification_accuracy):
                if i_epoch == 0:
                    mean_classification_accuracy_first_epoch = mean_classification_accuracy
                i_epoch_best = i_epoch
                best_cnet_losses = mean_cnet_losses
                best_classification_accuracy = mean_classification_accuracy
                n_epochs_without_improvement = 0
            else:
                n_epochs_without_improvement += 1
            if (self.extra_args.cnet_decay_epochs > 0):
                if (n_epochs_without_improvement >=
                        self.extra_args.cnet_decay_epochs):
                    if self.cnet_cur_lrmult_epoch > self.extra_args.cnet_decay_max:
                        self.cnet_cur_lrmult_epoch *= 0.5
                        i_epoch_best = i_epoch
                        best_cnet_losses = mean_cnet_losses
                        best_classification_accuracy = mean_classification_accuracy
                        n_epochs_without_improvement = 0
                        if self.cnet_cur_lrmult_epoch >= self.extra_args.cnet_decay_max:
                            self.logger.log(
                                'Halved CNet learning rate multiplier to {0}'.
                                format(self.cnet_cur_lrmult_epoch))
                        else:
                            self.cnet_cur_lrmult_epoch = self.extra_args.cnet_decay_max
                            self.logger.log(
                                'Clamped CNet learning rate multiplier at the decay floor {0}'.
                                format(self.cnet_cur_lrmult_epoch))
                    else:
                        self.logger.log(
                            'No improvement at max decay {0} for {1} epochs'.
                            format(self.cnet_cur_lrmult_epoch,
                                   self.extra_args.cnet_decay_epochs))
                        break

        self.logger.log('Evaluating CNet losses...')
        cnet_test_losses_after, activation_probability_after = self.evaluate_cnet_losses(
            do_log=True)

        if self.do_check_interrupt_cnet_training:
            if self.cnet_interruption_type == 'accuracy':
                do_train_cnet = self.check_interruption(
                    do_train_cnet, activation_probability_after)
            else:
                assert self.cnet_interruption_type == 'prior_accuracy'

        return do_train_cnet, activation_probability_after
Example #18
def learn(env, policy_func, *,
        timesteps_per_batch, # timesteps per actor per update
        clip_param, entcoeff, # clipping parameter epsilon, entropy coeff
        optim_epochs, optim_stepsize, optim_batchsize, # optimization hypers
        gamma, lam, # advantage estimation
        max_timesteps=0, max_episodes=0, max_iters=0, max_seconds=0, # time constraint
        noisy_nets=False,
        callback=None, # you can do anything in the callback, since it takes locals(), globals()
        adam_epsilon=1e-5,
        schedule='constant', # annealing for stepsize parameters (epsilon and adam)
        desired_kl=0.02,
        logdir=".",
        agentName="PPO-Agent",
        resume = 0,
        num_parallel=1,
        num_cpu=1
        ):
    # Setup losses and stuff
    # ----------------------------------------
    rank = MPI.COMM_WORLD.Get_rank()
    ob_space = env.observation_space
    ac_space = env.action_space

    ob_size = ob_space.shape[0]
    ac_size = ac_space.shape[0]

    #print("rank = " + str(rank) + " ob_space = "+str(ob_space.shape) + " ac_space = "+str(ac_space.shape))
    #exit(0)
    pi = policy_func("pi", ob_space, ac_space, noisy_nets) # Construct network for new policy
    oldpi = policy_func("oldpi", ob_space, ac_space, noisy_nets) # Network for old policy
    atarg = tf.placeholder(dtype=tf.float32, shape=[None]) # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None]) # Empirical return

    lrmult = tf.placeholder(name='lrmult', dtype=tf.float32, shape=[]) # learning rate multiplier, updated with schedule
    clip_param = clip_param * lrmult # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = U.mean(kloldnew)
    meanent = U.mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac)) # pnew / pold
    surr1 = ratio * atarg # surrogate from conservative policy iteration
    surr2 = U.clip(ratio, 1.0 - clip_param, 1.0 + clip_param) * atarg #
    pol_surr = - U.mean(tf.minimum(surr1, surr2)) # PPO's pessimistic surrogate (L^CLIP)
    vfloss1 = tf.square(pi.vpred - ret)
    vpredclipped = oldpi.vpred + tf.clip_by_value(pi.vpred - oldpi.vpred, -clip_param, clip_param)
    vfloss2 = tf.square(vpredclipped - ret)
    vf_loss = .5 * U.mean(tf.maximum(vfloss1, vfloss2)) # we do the same clipping-based trust region for the value function
    #vf_loss = U.mean(tf.square(pi.vpred - ret))
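    # PPO2-style clipped value loss: the value prediction may only move by +/- clip_param
    # from the old prediction, and the pessimistic (max) error is used:
    #   L^VF = 0.5 * E[ max( (V - R)^2, (V_old + clip(V - V_old, -eps, eps) - R)^2 ) ]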
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult], losses + [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    assign_old_eq_new = U.function([],[], updates=[tf.assign(oldv, newv)
        for (oldv, newv) in zipsame(oldpi.get_variables(), pi.get_variables())])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)

    U.initialize()
    adam.sync()

    # Prepare for rollouts
    # ----------------------------------------
    if noisy_nets:
        stochastic = False
    else:
        stochastic = True
    seg_gen = traj_segment_generator(pi, env, timesteps_per_batch, stochastic=stochastic, num_parallel=num_parallel, num_cpu=num_cpu, rank=rank, ob_size=ob_size, ac_size=ac_size,com=MPI.COMM_WORLD)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100) # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100) # rolling buffer for episode rewards

    saver = tf.train.Saver()
    if resume > 0:
        saver.restore(tf.get_default_session(), os.path.join(os.path.abspath(logdir), "{}-{}".format(agentName, resume)))
    iters_so_far = resume
    assert sum([max_iters>0, max_timesteps>0, max_episodes>0, max_seconds>0])==1, "Only one time constraint permitted"

    logF = open(os.path.join(logdir, 'log.txt'), 'a')
    logStats = open(os.path.join(logdir, 'log_stats.txt'), 'a')

    dump_training = 0
    learn_from_training = 0
    if dump_training:
        if os.path.exists(os.path.join(logdir, 'ob_list_' + str(rank) + '.pkl')):
            with open(os.path.join(logdir, 'ob_list_' + str(rank) + '.pkl'), 'rb') as f:
                ob_list = pickle.load(f)
        else:
            ob_list = []

        # , "mean": pi.ob_rms.mean, "std": pi.ob_rms.std
        saverRMS = tf.train.Saver({"_sum": pi.ob_rms._sum, "_sumsq": pi.ob_rms._sumsq, "_count": pi.ob_rms._count})
        saverRMS.save(tf.get_default_session(), os.path.join(os.path.abspath(logdir), "rms.tf"))

        ob_np_a = np.asarray(ob_list)
        ob_np = np.reshape(ob_np_a, (-1,ob_size))
        [vpred, pdparam] = pi._vpred_pdparam(ob_np)

        print("vpred = " + str(vpred))
        print("pd_param = " + str(pdparam))
        with open('training.pkl', 'wb') as f:
            pickle.dump(ob_np, f)
            pickle.dump(vpred, f)
            pickle.dump(pdparam, f)
        exit(0)

    if learn_from_training:
        # , "mean": pi.ob_rms.mean, "std": pi.ob_rms.std
        with open('training.pkl', 'rb') as f:
            ob_np = pickle.load(f)
            vpred = pickle.load(f)
            pdparam = pickle.load(f)
        num = ob_np.shape[0]
        for i in range(num):
            xp = ob_np[i][1]
            ob_np[i][1] = 0.0
            ob_np[i][18] -= xp
            ob_np[i][22] -= xp
            ob_np[i][24] -= xp
            ob_np[i][26] -= xp
            ob_np[i][28] -= xp
            ob_np[i][30] -= xp
            ob_np[i][32] -= xp
            ob_np[i][34] -= xp
        print("ob_np = " + str(ob_np))
        print("vpred = " + str(vpred))
        print("pdparam = " + str(pdparam))
        batch_size = 128

        y_vpred = tf.placeholder(tf.float32, [batch_size, ])
        y_pdparam = tf.placeholder(tf.float32, [batch_size, pdparam.shape[1]])

        vpred_loss = U.mean(tf.square(pi.vpred - y_vpred))
        vpdparam_loss = U.mean(tf.square(pi.pdparam - y_pdparam))

        total_train_loss = vpred_loss + vpdparam_loss
        #total_train_loss = vpdparam_loss
        #total_train_loss = vpred_loss
        #coef = 0.01
        #dense_all = U.dense_all
        #for a in dense_all:
        #   total_train_loss += coef * tf.nn.l2_loss(a)
        #total_train_loss = vpdparam_loss
        optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(total_train_loss)
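        # Supervised warm-start: an independent Adam op regresses the network's value
        # prediction and distribution parameters toward targets dumped from a previous
        # run (behavior-cloning-style MSE on vpred and pdparam) before PPO training.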
        d = Dataset(dict(ob=ob_np, vpred=vpred, pdparam=pdparam), shuffle=not pi.recurrent)
        sess = tf.get_default_session()
        sess.run(tf.global_variables_initializer())
        saverRMS = tf.train.Saver({"_sum": pi.ob_rms._sum, "_sumsq": pi.ob_rms._sumsq, "_count": pi.ob_rms._count})
        saverRMS.restore(tf.get_default_session(), os.path.join(os.path.abspath(logdir), "rms.tf"))
        if resume > 0:
            saver.restore(tf.get_default_session(),
                          os.path.join(os.path.abspath(logdir), "{}-{}".format(agentName, resume)))
        for q in range(100):
            sumLoss = 0
            for batch in d.iterate_once(batch_size):
                tl, _ = sess.run([total_train_loss, optimizer], feed_dict={pi.ob: batch["ob"], y_vpred: batch["vpred"], y_pdparam:batch["pdparam"]})
                sumLoss += tl
            print("Iteration " + str(q)+ " Loss = " + str(sumLoss))
        assign_old_eq_new()  # set old parameter values to new parameter values

        # Save as frame 1
        try:
            saver.save(tf.get_default_session(), os.path.join(logdir, agentName), global_step=1)
        except Exception:
            pass
        #exit(0)

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule in ('adaptive', 'constant'):
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult =  max(1.0 - float(timesteps_so_far) / max_timesteps, 0.0)
        elif schedule == 'linear_clipped':
            cur_lrmult =  max(1.0 - float(timesteps_so_far) / max_timesteps, 0.2)
        elif schedule == 'cyclic':
        #    cur_lrmult =  max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
            raise NotImplementedError
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************"%iters_so_far)

        seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam, timesteps_per_batch, num_parallel, num_cpu)
        #print(" ob= " + str(seg["ob"])+ " rew= " + str(seg["rew"])+ " vpred= " + str(seg["vpred"])+ " new= " + str(seg["new"])+ " ac= " + str(seg["ac"])+ " prevac= " + str(seg["prevac"])+ " nextvpred= " + str(seg["nextvpred"])+ " ep_rets= " + str(seg["ep_rets"])+ " ep_lens= " + str(seg["ep_lens"]))

        #exit(0)
        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]

        if dump_training:
            ob_list.append(ob.tolist())
        vpredbefore = seg["vpred"] # predicted value function before udpate
        atarg = (atarg - atarg.mean()) / atarg.std() # standardized advantage function estimate
        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret), shuffle=not pi.recurrent)
        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"): pi.ob_rms.update(ob) # update running mean/std for policy

        assign_old_eq_new() # set old parameter values to new parameter values
        logger.log("Optimizing...")
        logger.log(fmt_row(13, loss_names))
        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            losses = [] # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"], batch["atarg"], batch["vtarg"], cur_lrmult)
                if desired_kl is not None and schedule == 'adaptive':
                    if newlosses[-2] > desired_kl * 2.0:
                        optim_stepsize = max(1e-8, optim_stepsize / 1.5)
                        print('kl divergence was too large = ', newlosses[-2])
                        print('New optim_stepsize = ', optim_stepsize)
                    elif newlosses[-2] < desired_kl / 2.0:
                        optim_stepsize = min(1e0, optim_stepsize * 1.5)
                        print('kl divergence was too small = ', newlosses[-2])
                        print('New optim_stepsize = ', optim_stepsize)
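                # Simple adaptive step-size heuristic: when the per-minibatch KL
                # (newlosses[-2], i.e. "kl" in loss_names) exceeds 2x the target,
                # shrink optim_stepsize by 1.5x; when it drops below half the target,
                # grow it by 1.5x, clamped to [1e-8, 1.0].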
                adam.update(g, optim_stepsize * cur_lrmult)
                losses.append(newlosses)
            #print(str(losses))
            logger.log(fmt_row(13, np.mean(losses, axis=0)))

        logger.log("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob"], batch["ac"], batch["atarg"], batch["vtarg"], cur_lrmult)
            losses.append(newlosses)
        meanlosses,_,_ = mpi_moments(losses, axis=0)
        logger.log(fmt_row(13, meanlosses))

        for (lossval, name) in zipsame(meanlosses, loss_names):
            logger.record_tabular("loss_"+name, lossval)
        logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))
        lrlocal = (seg["ep_lens"], seg["ep_rets"]) # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal) # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        rewmean = np.mean(rewbuffer)
        logger.record_tabular("EpRewMean", rewmean)
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)

        if dump_training:
            with open(logdir + "\\" + 'ob_list_' + str(rank) + '.pkl', 'wb') as f:
                pickle.dump(ob_list, f)

        if MPI.COMM_WORLD.Get_rank()==0:
            logF.write(str(rewmean) + "\n")
            logStats.write(logger.get_str() + "\n")
            logF.flush()
            logStats.flush()

            logger.dump_tabular()

            try:
                os.remove(logdir + "/checkpoint")
            except OSError:
                pass
            try:
                saver.save(tf.get_default_session(), os.path.join(logdir, agentName), global_step=iters_so_far)
            except Exception:
                pass
def learn(
        env,
        policy_func,
        *,
        timesteps_per_batch,  # timesteps per actor per update
        log_every=None,
        log_dir=None,
        episodes_so_far=0,
        timesteps_so_far=0,
        iters_so_far=0,
        clip_param,
        entcoeff,  # clipping parameter epsilon, entropy coeff
        optim_epochs,
        optim_stepsize,
        optim_batchsize,  # optimization hypers
        gamma,
        lam,  # advantage estimation
        max_timesteps=0,
        max_episodes=0,
        max_iters=0,
        max_seconds=0,  # time constraint
        callback=None,  # you can do anything in the callback, since it takes locals(), globals()
        adam_epsilon=1e-5,
        schedule='constant',  # annealing for stepsize parameters (epsilon and adam)
        **kwargs):
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_func("pi", ob_space,
                     ac_space)  # Construct network for new policy
    oldpi = policy_func("oldpi", ob_space, ac_space)  # Network for old policy
    # Target advantage function (if applicable)
    atarg = tf.placeholder(dtype=tf.float32, shape=[None])
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    # learning rate multiplier, updated with schedule
    lrmult = tf.placeholder(name='lrmult', dtype=tf.float32, shape=[])
    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = U.mean(kloldnew)
    meanent = U.mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = U.clip(ratio, 1.0 - clip_param, 1.0 + clip_param) * atarg
    pol_surr = -U.mean(tf.minimum(
        surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)
    vf_loss = U.mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult],
                             losses + [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)

    U.initialize()
    adam.sync()

    # Prepare for rollouts
    # ----------------------------------------
    # GRASPING
    saver = tf.train.Saver(var_list=U.ALREADY_INITIALIZED, max_to_keep=1)
    checkpoint = tf.train.latest_checkpoint(log_dir)
    if checkpoint:
        print("Restoring checkpoint: {}".format(checkpoint))
        saver.restore(U.get_session(), checkpoint)
    if hasattr(env, "set_actor"):

        def actor(obs):
            return pi.act(False, obs)[0]

        env.set_actor(actor)
    if not checkpoint and hasattr(env, "warm_init_eps"):
        pretrain(pi, env)
        saver.save(U.get_session(), osp.join(log_dir, "model"))
    # /GRASPING
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_batch,
                                     stochastic=True)

    tstart = time.time()

    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"

    while True:
        if callback:
            callback(locals(), globals())
        should_break = False
        if max_timesteps and timesteps_so_far >= max_timesteps:
            should_break = True
        elif max_episodes and episodes_so_far >= max_episodes:
            should_break = True
        elif max_iters and iters_so_far >= max_iters:
            should_break = True
        elif max_seconds and time.time() - tstart >= max_seconds:
            should_break = True

        if log_every and log_dir:
            if (iters_so_far + 1) % log_every == 0 or should_break:
                # To reduce space, don't specify global step.
                saver.save(U.get_session(), osp.join(log_dir, "model"))

            job_info = {
                'episodes_so_far': episodes_so_far,
                'iters_so_far': iters_so_far,
                'timesteps_so_far': timesteps_so_far
            }
            with open(osp.join(log_dir, "job_info_new.yaml"), 'w') as file:
                yaml.dump(job_info, file, default_flow_style=False)
                # Make sure write is instantaneous.
                file.flush()
                os.fsync(file)
            os.rename(osp.join(log_dir, "job_info_new.yaml"),
                      osp.join(log_dir, "job_info.yaml"))

        if should_break:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % iters_so_far)

        seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before udpate
        atarg = (atarg - atarg.mean()) / (
            atarg.std() + 1e-10)  # standardized advantage function estimate
        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                    shuffle=not pi.recurrent)
        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        assign_old_eq_new()  # set old parameter values to new parameter values
        # logger.log("Optimizing...")
        # logger.log(fmt_row(13, loss_names))
        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"],
                                            batch["atarg"], batch["vtarg"],
                                            cur_lrmult)
                adam.update(g, optim_stepsize * cur_lrmult)
                losses.append(newlosses)
            # logger.log(fmt_row(13, np.mean(losses, axis=0)))

        logger.log("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob"], batch["ac"],
                                       batch["atarg"], batch["vtarg"],
                                       cur_lrmult)
            losses.append(newlosses)
        meanlosses, _, _ = mpi_moments(losses, axis=0)
        logger.log(fmt_row(13, meanlosses))
        for (lossval, name) in zipsame(meanlosses, loss_names):
            logger.record_tabular("loss_" + name, lossval)
        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))
        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        logger.record_tabular("EpLenMean", np.mean(lens))
        logger.record_tabular("EpRewMean", np.mean(rews))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.dump_tabular()
Example #20
def learn(env, policy_fn, *,
          timesteps_per_actorbatch,  # timesteps per actor per update
          clip_param, entcoeff,  # clipping parameter epsilon, entropy coeff
          optim_epochs, optim_stepsize, optim_batchsize,  # optimization hypers
          gamma, lam,  # advantage estimation
          max_timesteps=0, max_episodes=0, max_iters=0, max_seconds=0,  # time constraint
          callback=None,  # you can do anything in the callback, since it takes locals(), globals()
          adam_epsilon=1e-5,
          schedule='constant',  # annealing for stepsize parameters (epsilon and adam)
          model_dir_base='../tf_models/',
          is_train=True):

    # tensorboard summary writer & model saving path
    i = 1
    while is_train:
        if not os.path.exists(model_dir_base + str(i)):
            model_dir = model_dir_base + str(i)
            os.makedirs(model_dir)
            break
        else:
            i += 1

    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_fn("pi", ob_space, ac_space)  # Construct network for new policy
    oldpi = policy_fn("oldpi", ob_space, ac_space)  # Network for old policy
    atarg = tf.placeholder(dtype=tf.float32, shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    lrmult = tf.placeholder(name='lrmult', dtype=tf.float32,
                            shape=[])  # learning rate multiplier, updated with schedule

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    # KL entropy loss
    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    pol_entpen = (-entcoeff) * meanent

    # Clip loss
    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = tf.clip_by_value(ratio, 1.0 - clip_param, 1.0 + clip_param) * atarg
    pol_surr = - tf.reduce_mean(tf.minimum(surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)

    # value function loss
    vf_loss = tf.reduce_mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    # define tensorboard summary scalars
    with tf.name_scope('loss'):
        summary_list_loss = [tf.summary.scalar('total_loss', total_loss)]
        for i, name in enumerate(loss_names):
            summary_list_loss.append(tf.summary.scalar('loss_' + name, losses[i]))
        summary_merged_loss = tf.summary.merge(summary_list_loss)
    with tf.name_scope('reward'):
        reward_total_ph = tf.placeholder(tf.float32, shape=())
        reward_comf_ph = tf.placeholder(tf.float32, shape=())
        reward_effi_ph = tf.placeholder(tf.float32, shape=())
        reward_safety_ph = tf.placeholder(tf.float32, shape=())
        summary_list_reward = [tf.summary.scalar('reward_total', reward_total_ph)]
        summary_list_reward.extend([tf.summary.scalar('reward_comf', reward_comf_ph),
                                    tf.summary.scalar('reward_effi', reward_effi_ph),
                                    tf.summary.scalar('reward_safety', reward_safety_ph)])
        summary_merged_reward = tf.summary.merge(summary_list_reward)
    with tf.name_scope('observation'):
        ego_speed_ph = tf.placeholder(tf.float32, shape=())
        ego_latPos_ph = tf.placeholder(tf.float32, shape=())
        ego_acce_ph = tf.placeholder(tf.float32, shape=())
        #dis2origLeader_ph = tf.placeholder(tf.float32, shape=())
        #dis2trgtLeader_ph = tf.placeholder(tf.float32, shape=())
        #obs_ph_list = [ego_speed_ph, ego_latPos_ph, ego_acce_ph, dis2origLeader_ph, dis2trgtLeader_ph]
        obs_ph_list = [ego_speed_ph, ego_latPos_ph, ego_acce_ph]
        #obs_name_list = ['ego_speed', 'ego_latPos', 'ego_acce', 'dis2origLeader', 'dis2trgtLeader']
        obs_name_list = ['ego_speed', 'ego_latPos', 'ego_acce']
        summary_list_obs = [tf.summary.histogram(name, ph) for name, ph in zip(obs_name_list, obs_ph_list)]
        summary_merged_obs = tf.summary.merge(summary_list_obs)
    # with tf.name_scope('action'):
    #     ac_ph = tf.placeholder(tf.int32, shape=())
    #     summary_list_ac = [tf.summary.histogram('longitudinal', tf.floordiv(ac_ph, 3)),
    #                        tf.summary.histogram('lateral', tf.floormod(ac_ph, 3))]
    #     summary_merged_acs = tf.summary.merge(summary_list_ac)

    var_list = pi.get_trainable_variables()
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult], losses + [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    assign_old_eq_new = U.function([], [], updates=[tf.assign(oldv, newv)
                                                    for (oldv, newv) in
                                                    zipsame(oldpi.get_variables(), pi.get_variables())])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)

    U.initialize()
    adam.sync()

    if is_train:
        sess = U.get_session()
        summary_writer = tf.summary.FileWriter(model_dir, sess.graph)
        saver = tf.train.Saver(max_to_keep=10)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi, env, timesteps_per_actorbatch, stochastic=True)
    seg_gen_test = traj_segment_generator(pi, env, timesteps_per_actorbatch, stochastic=False)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards

    assert sum([max_iters > 0, max_timesteps > 0, max_episodes > 0, max_seconds > 0]) == 1, \
        "Only one time constraint permitted"

    while is_train:
        if callback:
            callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        print("********** Iteration %i ************" % iters_so_far)

        seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before udpate
        atarg = (atarg - atarg.mean()) / atarg.std()  # standardized advantage function estimate
        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret), deterministic=pi.recurrent)
        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        assign_old_eq_new()  # set old parameter values to new parameter values

        print("Optimizing...")
        print(fmt_row(13, loss_names))

        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            losses_batch = []  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"], batch["atarg"], batch["vtarg"], cur_lrmult)
                adam.update(g, optim_stepsize * cur_lrmult)
                losses_batch.append(newlosses)
            print(fmt_row(13, np.mean(losses_batch, axis=0)))

        if iters_so_far % 10 == 0:
            # evaluate losses
            print("Evaluating losses...")
            losses_batch = []
            for batch in d.iterate_once(optim_batchsize):
                newlosses = compute_losses(batch["ob"], batch["ac"], batch["atarg"], batch["vtarg"], cur_lrmult)
                losses_batch.append(newlosses)
            meanlosses, _, _ = mpi_moments(losses_batch, axis=0)
            print(fmt_row(13, meanlosses))
            for (lossval, name) in zipsame(meanlosses, loss_names):
                print("loss_" + name, lossval)
            # write loss summaries
            summary_eval_loss = sess.run(summary_merged_loss, feed_dict={i: d for i, d in zip(losses, meanlosses)})
            summary_writer.add_summary(summary_eval_loss, iters_so_far)
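            # In TF1 graph mode, feed_dict may supply precomputed values for feedable
            # non-placeholder tensors, so the loss summaries here are rendered from the
            # already evaluated meanlosses instead of rerunning the loss graph.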

            # evaluate reward
            ret_eval = 0
            ret_det_eval = 0  # starts as a scalar; numpy broadcasting turns it into an array
            for i in range(5):
                seg_test = seg_gen_test.__next__()
                ret_eval += np.mean(seg_test['ep_rets'])
                ep_rets_detail_np = np.vstack([ep_ret_detail for ep_ret_detail in seg_test['ep_rets_detail']])
                ret_det_eval += np.mean(ep_rets_detail_np, axis=0)
            ret_eval /= 5.0
            ret_det_eval /= 5.0
            summary_eval_reward = sess.run(summary_merged_reward, feed_dict={reward_total_ph: ret_eval,
                                                                             reward_comf_ph: ret_det_eval[0],
                                                                             reward_effi_ph: ret_det_eval[1],
                                                                             reward_safety_ph: ret_det_eval[2]})
            summary_writer.add_summary(summary_eval_reward, iters_so_far)

            # save model
            saver.save(sess, model_dir + '/model.ckpt', global_step=iters_so_far)

        # for ep_ret, ep_ret_detail in zip(seg['ep_rets'], seg['ep_rets_detail']):
        #     summary_eval_reward = sess.run(summary_merged_reward, feed_dict={reward_total_ph: ep_ret,
        #                                                                      reward_comf_ph: ep_ret_detail[0],
        #                                                                      reward_effi_ph: ep_ret_detail[1],
        #                                                                      reward_safety_ph: ep_ret_detail[2]})
        #     summary_writer.add_summary(summary_eval_reward, episodes_so_far)
        #     episodes_so_far += 1
        # write observation and action summaries
        assert len(seg['ac']) == len(seg['ob'])
        for ac, ob in zip(seg['ac'], seg['ob']):
            summary_eval_obs = sess.run(summary_merged_obs, feed_dict={ego_speed_ph: ob[1],
                                                                       ego_latPos_ph: ob[2],
                                                                       ego_acce_ph: ob[3]})
                                                                       #dis2origLeader_ph: ob[4] - ob[0],
                                                                       #dis2trgtLeader_ph: ob[12] - ob[0]})
            summary_writer.add_summary(summary_eval_obs, timesteps_so_far)
            #summary_eval_acs = sess.run(summary_merged_acs, feed_dict={ac_ph: ac})
            #summary_writer.add_summary(summary_eval_acs, timesteps_so_far)
            timesteps_so_far += 1

        # todo: investigate MPI
        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        episodes_so_far += len(lens)
        iters_so_far += 1

        # if iters_so_far % 10 == 0:
        #     saver.save(sess, model_dir + '/model.ckpt', global_step=iters_so_far)
    return pi
Example #21
def learn(
        env,
        policy_func,
        timesteps_per_batch,  # timesteps per actor per update
        clip_param,
        entcoeff,  # clipping parameter epsilon, entropy coeff
        optim_epochs,
        optim_stepsize,
        optim_batchsize,  # optimization hypers
        gamma,
        lam,  # advantage estimation
        max_timesteps=0,
        max_episodes=0,
        max_iters=0,
        max_seconds=0,  # time constraint
        callback=None,  # you can do anything in the callback, since it takes locals(), globals()
        adam_epsilon=1e-5,
        schedule='constant',  # annealing for stepsize parameters (epsilon and adam)
        load_model=None,
        action_bias=0.4,
        action_repeat=0,
        action_repeat_rand=False,
        warmup_frames=0,
        target_kl=0.01,
        vf_loss_mult=1,
        vfloss_optim_stepsize=0.003,
        vfloss_optim_batchsize=8,
        vfloss_optim_epochs=10):
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_func("pi", ob_space,
                     ac_space)  # Construct network for new policy
    oldpi = policy_func("oldpi", ob_space, ac_space)  # Network for old policy
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule
    # Not sure why they anneal clip and learning rate with the same parameter
    #clip_param = clip_param * lrmult # Annealed cliping parameter epislon

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = U.mean(kloldnew)
    meanent = U.mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = U.clip(ratio, 1.0 - clip_param, 1.0 + clip_param) * atarg  #
    pol_surr = -U.mean(tf.minimum(
        surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)
    vf_loss = U.mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen
    losses = [pol_surr, pol_entpen, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult],
                             losses + [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)

    lossandgrad_vfloss = U.function([ob, ac, atarg, ret], [vf_loss] +
                                    [U.flatgrad(vf_loss, var_list)])
    adam_vfloss = MpiAdam(var_list, epsilon=adam_epsilon)
    compute_vfloss = U.function([ob, ac, atarg, ret], [vf_loss])

    U.initialize()
    adam.sync()
    adam_vfloss.sync()

    if load_model:
        logger.log('Loading model: %s' % load_model)
        pi.load(load_model)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_batch,
                                     stochastic=True,
                                     action_bias=action_bias,
                                     action_repeat=action_repeat,
                                     action_repeat_rand=action_repeat_rand,
                                     warmup_frames=warmup_frames)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards

    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"

    ep_rew_file = None
    if MPI.COMM_WORLD.Get_rank() == 0:
        import wandb
        ep_rew_file = open(
            os.path.join(wandb.run.dir, 'episode_rewards.jsonl'), 'w')
        checkpoint_dir = 'checkpoints-%s' % wandb.run.id
        os.mkdir(checkpoint_dir)

    cur_lrmult = 1.0
    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        elif schedule == 'target_kl':
            pass
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % iters_so_far)

        seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before udpate
        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate
        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                    shuffle=not pi.recurrent)
        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        assign_old_eq_new()  # set old parameter values to new parameter values
        logger.log("Optimizing...")
        logger.log(fmt_row(13, loss_names))
        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                result = lossandgrad(batch["ob"], batch["ac"], batch["atarg"],
                                     batch["vtarg"], cur_lrmult)
                newlosses = result[:-1]
                g = result[-1]
                adam.update(g, optim_stepsize * cur_lrmult)
                losses.append(newlosses)
            logger.log(fmt_row(13, np.mean(losses, axis=0)))

        # vfloss optimize
        logger.log("Optimizing value function...")
        logger.log(fmt_row(13, ['vf']))
        for _ in range(vfloss_optim_epochs):
            losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(vfloss_optim_batchsize):
                result = lossandgrad_vfloss(batch["ob"], batch["ac"],
                                            batch["atarg"], batch["vtarg"])
                newlosses = result[:-1]
                g = result[-1]
                adam_vfloss.update(g, vfloss_optim_stepsize)
                losses.append(newlosses)
            logger.log(fmt_row(13, np.mean(losses, axis=0)))

        logger.log("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob"], batch["ac"],
                                       batch["atarg"], batch["vtarg"],
                                       cur_lrmult)
            newlosses += compute_vfloss(batch["ob"], batch["ac"],
                                        batch["atarg"], batch["vtarg"])
            losses.append(newlosses)
        meanlosses, _, _ = mpi_moments(losses, axis=0)
        logger.log(fmt_row(13, meanlosses))
        for (lossval, name) in zipsame(meanlosses, loss_names + ['vf']):
            logger.record_tabular("loss_" + name, lossval)
        # check kl
        if schedule == 'target_kl':
            if meanlosses[2] > target_kl * 1.1:
                cur_lrmult /= 1.5
            elif meanlosses[2] < target_kl / 1.1:
                cur_lrmult *= 1.5
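        # 'target_kl' schedule: meanlosses[2] is the mean KL ("kl" in loss_names), so the
        # learning-rate multiplier is shrunk when KL overshoots 1.1 * target_kl and grown
        # when it undershoots target_kl / 1.1, keeping each update near the desired KL.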
        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))
        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        if rewbuffer:
            logger.record_tabular('CurLrMult', cur_lrmult)
            logger.record_tabular('StepSize', optim_stepsize * cur_lrmult)
            logger.record_tabular("EpLenMean", np.mean(lenbuffer))
            logger.record_tabular("EpRewMax", np.max(rewbuffer))
            logger.record_tabular("EpRewMean", np.mean(rewbuffer))
            logger.record_tabular("EpRewMin", np.min(rewbuffer))
            logger.record_tabular("EpThisIter", len(lens))
            episodes_so_far += len(lens)
            timesteps_so_far += sum(lens)
            iters_so_far += 1
            logger.record_tabular("EpisodesSoFar", episodes_so_far)
            logger.record_tabular("TimestepsSoFar", timesteps_so_far)
            time_elapsed = time.time() - tstart
            logger.record_tabular("TimeElapsed", time_elapsed)
            if MPI.COMM_WORLD.Get_rank() == 0:
                import wandb
                ep_rew_file.write('%s\n' % json.dumps({
                    'TimeElapsed': time_elapsed,
                    'Rewards': rews
                }))
                ep_rew_file.flush()
                data = logger.Logger.CURRENT.name2val
                wandb.run.history.add(data)
                summary_data = {}
                for k, v in data.items():
                    if 'Rew' in k:
                        summary_data[k] = v
                wandb.run.summary.update(summary_data)
                pi.save(
                    os.path.join(checkpoint_dir,
                                 'model-%s.ckpt' % (iters_so_far - 1)))

                logger.dump_tabular()
        else:
            logger.log('No episodes complete yet')
Example #22
    def train(self, seg, optim_batchsize, optim_epochs):
        #normalize the reward
        rffs_int = np.array(
            [self.rff_int.update(rew) for rew in seg["rew_int"]])
        self.rff_rms_int.update(rffs_int.ravel())
        seg["rew_int"] = seg["rew_int"] / np.sqrt(self.rff_rms_int.var)

        cur_lrmult = 1.0
        add_vtarg_and_adv(seg, self.gamma, self.lam)
        ob, unnorm_ac, atarg_ext, tdlamret_ext, atarg_int, tdlamret_int = seg[
            "ob"], seg["unnorm_ac"], seg["adv_ext"], seg["tdlamret_ext"], seg[
                "adv_int"], seg["tdlamret_int"]
        vpredbefore_ext, vpredbefore_int = seg["vpred_ext"], seg[
            "vpred_int"]  # predicted value function before udpate
        atarg_ext = (atarg_ext - atarg_ext.mean()) / atarg_ext.std(
        )  # standardized advantage function estimate
        atarg_int = (atarg_int - atarg_int.mean()) / atarg_int.std()
        atarg = self.int_coeff * atarg_int + self.ext_coeff * atarg_ext

        d = Dataset(dict(ob=ob,
                         ac=unnorm_ac,
                         atarg=atarg,
                         vtarg_ext=tdlamret_ext,
                         vtarg_int=tdlamret_int),
                    shuffle=not self.pi.recurrent)

        if hasattr(self.pi, "ob_rms"):
            self.pi.update_obs_rms(ob)  # update running mean/std for policy
        if hasattr(self.int_rew, "ob_rms"):
            self.int_rew.update_obs_rms(
                ob)  #update running mean/std for int_rew
        self.assign_old_eq_new(
        )  # set old parameter values to new parameter values
        logger.log2("Optimizing...")
        logger.log2(fmt_row(13, self.loss_names))
        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                lg = self.lossandgrad(batch["ac"], batch["atarg"],
                                      batch["vtarg_ext"], batch["vtarg_int"],
                                      cur_lrmult, *zip(*batch["ob"].tolist()))
                new_losses, g = lg[:-1], lg[-1]
                self.adam.update(g, self.optim_stepsize * cur_lrmult)
                losses.append(new_losses)
            logger.log2(fmt_row(13, np.mean(losses, axis=0)))

        logger.log2("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = self.compute_losses(batch["ac"], batch["atarg"],
                                            batch["vtarg_ext"],
                                            batch["vtarg_int"], cur_lrmult,
                                            *zip(*batch["ob"].tolist()))
            losses.append(newlosses)
        meanlosses, _, _ = mpi_moments(losses, axis=0)
        logger.log2(fmt_row(13, meanlosses))

        for (lossval, name) in zipsame(meanlosses, self.loss_names):
            logger.record_tabular("loss_" + name, lossval)
        logger.record_tabular(
            "ev_tdlam_ext_before",
            explained_variance(vpredbefore_ext, tdlamret_ext))
        return meanlosses
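train() above standardizes the extrinsic and intrinsic advantage streams separately before mixing them with ext_coeff / int_coeff. A small numpy sketch of that step; the eps term is an added safeguard against a zero std and is not in the original code.

import numpy as np

def mix_advantages(adv_ext, adv_int, ext_coeff, int_coeff, eps=1e-8):
    # Standardize each stream on its own, then take the weighted sum.
    adv_ext = (adv_ext - adv_ext.mean()) / (adv_ext.std() + eps)
    adv_int = (adv_int - adv_int.mean()) / (adv_int.std() + eps)
    return int_coeff * adv_int + ext_coeff * adv_ext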
Example #23
def learn(env, policy_func, reward_giver, expert_dataset, rank, 
          pretrained, pretrained_weight, *, clip_param,
          g_step, d_step, entcoeff, save_per_iter,
          optim_epochs, optim_stepsize, optim_batchsize,# optimization hypers
          ckpt_dir, log_dir, timesteps_per_batch, task_name,
          gamma, lam, d_stepsize=3e-4, adam_epsilon=1e-5,
          max_timesteps=0, max_episodes=0, max_iters=0,
          mix_reward=False, r_lambda=0.44,
          callback=None,
          schedule='constant', # annealing for stepsize parameters (epsilon and adam)
          frame_stack=1
          ):

    nworkers = MPI.COMM_WORLD.Get_size()
    rank = MPI.COMM_WORLD.Get_rank()
    np.set_printoptions(precision=3)
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ob_space.shape = (ob_space.shape[0] * frame_stack,)
    print(ob_space)
    ac_space = env.action_space
    pi = policy_func("pi", ob_space, ac_space, reuse=(pretrained_weight != None))
    oldpi = policy_func("oldpi", ob_space, ac_space)
    atarg = tf.placeholder(dtype=tf.float32, shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    lrmult = tf.placeholder(name='lrmult', dtype=tf.float32, shape=[]) # learning rate multiplier, updated with schedule
    clip_param = clip_param * lrmult # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    pol_entpen = (-entcoeff) * meanent

    # kloldnew = oldpi.pd.kl(pi.pd)
    # ent = pi.pd.entropy()
    # meankl = tf.reduce_mean(kloldnew)
    # meanent = tf.reduce_mean(ent)
    # entbonus = entcoeff * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac)) # pnew / pold
    surr1 = ratio * atarg # surrogate from conservative policy iteration
    surr2 = tf.clip_by_value(ratio, 1.0 - clip_param, 1.0 + clip_param) * atarg #
    pol_surr = - tf.reduce_mean(tf.minimum(surr1, surr2)) # PPO's pessimistic surrogate (L^CLIP)
    vf_loss = tf.reduce_mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    # vferr = tf.reduce_mean(tf.square(pi.vpred - ret))

    # ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # advantage * pnew / pold
    # surrgain = tf.reduce_mean(ratio * atarg)

    # optimgain = surrgain + entbonus
    # losses = [optimgain, meankl, entbonus, surrgain, meanent]
    # loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

    var_list = pi.get_trainable_variables()
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult], losses + [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)
    d_adam = MpiAdam(reward_giver.get_trainable_variables())

    assign_old_eq_new = U.function([],[], updates=[tf.assign(oldv, newv)
        for (oldv, newv) in zipsame(oldpi.get_variables(), pi.get_variables())])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)

    # dist = meankl

    # all_var_list = pi.get_trainable_variables()
    # var_list = [v for v in all_var_list if v.name.startswith("pi/pol") or v.name.startswith("pi/logstd")]
    # vf_var_list = [v for v in all_var_list if v.name.startswith("pi/vff")]
    # assert len(var_list) == len(vf_var_list) + 1
    # d_adam = MpiAdam(reward_giver.get_trainable_variables())
    # vfadam = MpiAdam(vf_var_list)

    # get_flat = U.GetFlat(var_list)
    # set_from_flat = U.SetFromFlat(var_list)
    # klgrads = tf.gradients(dist, var_list)
    # flat_tangent = tf.placeholder(dtype=tf.float32, shape=[None], name="flat_tan")
    # shapes = [var.get_shape().as_list() for var in var_list]
    # start = 0
    # tangents = []
    # for shape in shapes:
    #     sz = U.intprod(shape)
    #     tangents.append(tf.reshape(flat_tangent[start:start+sz], shape))
    #     start += sz
    # gvp = tf.add_n([tf.reduce_sum(g*tangent) for (g, tangent) in zipsame(klgrads, tangents)])  # pylint: disable=E1111
    # fvp = U.flatgrad(gvp, var_list)

    # assign_old_eq_new = U.function([], [], updates=[tf.assign(oldv, newv)
    #                                                 for (oldv, newv) in zipsame(oldpi.get_variables(), pi.get_variables())])
    # compute_losses = U.function([ob, ac, atarg], losses)
    # compute_lossandgrad = U.function([ob, ac, atarg], losses + [U.flatgrad(optimgain, var_list)])
    # compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    # compute_vflossandgrad = U.function([ob, ret], U.flatgrad(vferr, vf_var_list))

    if rank == 0:
        generator_loss = tf.placeholder(tf.float32, [], name='generator_loss')
        expert_loss = tf.placeholder(tf.float32, [], name='expert_loss')
        entropy = tf.placeholder(tf.float32, [], name='entropy')
        entropy_loss = tf.placeholder(tf.float32, [], name='entropy_loss')
        generator_acc = tf.placeholder(tf.float32, [], name='generator_acc')
        expert_acc = tf.placeholder(tf.float32, [], name='expert_acc')
        eplenmean = tf.placeholder(tf.int32, [], name='eplenmean')
        eprewmean = tf.placeholder(tf.float32, [], name='eprewmean')
        eptruerewmean = tf.placeholder(tf.float32, [], name='eptruerewmean')
        # _meankl = tf.placeholder(tf.float32, [], name='meankl')
        # _optimgain = tf.placeholder(tf.float32, [], name='optimgain')
        # _surrgain = tf.placeholder(tf.float32, [], name='surrgain')
        _ops_to_merge = [generator_loss, expert_loss, entropy, entropy_loss, generator_acc, expert_acc, eplenmean, eprewmean, eptruerewmean]
        ops_to_merge = [ tf.summary.scalar(op.name, op) for op in _ops_to_merge]

        merged = tf.summary.merge(ops_to_merge)

    ### TODO: report these stats ### 
    #     generator_loss = tf.placeholder(tf.float32, [], name='generator_loss')
    #     expert_loss = tf.placeholder(tf.float32, [], name='expert_loss')
    #     generator_acc = tf.placeholder(tf.float32, [], name='genrator_acc')
    #     expert_acc = tf.placeholder(tf.float32, [], name='expert_acc')
    #     eplenmean = tf.placeholder(tf.int32, [], name='eplenmean')
    #     eprewmean = tf.placeholder(tf.float32, [], name='eprewmean')
    #     eptruerewmean = tf.placeholder(tf.float32, [], name='eptruerewmean')

    @contextmanager
    def timed(msg):
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(colorize("done in %.3f seconds" % (time.time() - tstart), color='magenta'))
        else:
            yield

    def allmean(x):
        assert isinstance(x, np.ndarray)
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
        return out

    U.initialize()
    adam.sync()
    d_adam.sync()

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi, env, reward_giver, timesteps_per_batch, 
        mix_reward, r_lambda,
        stochastic=True, frame_stack=frame_stack)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards
    true_rewbuffer = deque(maxlen=100)

    assert sum([max_iters > 0, max_timesteps > 0, max_episodes > 0]) == 1

    g_loss_stats = stats(loss_names)
    d_loss_stats = stats(reward_giver.loss_name)
    ep_stats = stats(["True_rewards", "Rewards", "Episode_length"])

    # if a pretrained weight is provided
    if pretrained_weight is not None:
        U.load_state(pretrained_weight, var_list=pi.get_variables())

    if rank == 0:
        filenames = [f for f in os.listdir(log_dir) if 'logs' in f]
        writer = tf.summary.FileWriter('{}/logs-{}'.format(log_dir, len(filenames)))

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break

        # Save model
        if rank == 0 and iters_so_far % save_per_iter == 0 and ckpt_dir is not None:
            fname = os.path.join(ckpt_dir, task_name)
            os.makedirs(os.path.dirname(fname), exist_ok=True)

            from tensorflow.core.protobuf import saver_pb2
            saver = tf.train.Saver(write_version=saver_pb2.SaverDef.V1)
            saver.save(tf.get_default_session(), fname)
            # U.save_state(fname)

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult =  max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % iters_so_far)

        # ------------------ Update G ------------------
        logger.log("Optimizing Policy...")
        for _ in range(g_step):
            with timed("sampling"):
                seg = seg_gen.__next__()
            add_vtarg_and_adv(seg, gamma, lam)

            # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
            ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]
            vpredbefore = seg["vpred"] # predicted value function before udpate
            atarg = (atarg - atarg.mean()) / atarg.std() # standardized advantage function estimate
            d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret), shuffle=not pi.recurrent)
            optim_batchsize = optim_batchsize or ob.shape[0]


            # # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
            # ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]
            # vpredbefore = seg["vpred"]  # predicted value function before update
            # atarg = (atarg - atarg.mean()) / atarg.std()  # standardized advantage function estimate

            if hasattr(pi, "ob_rms"): pi.ob_rms.update(ob)  # update running mean/std for policy

            # args = seg["ob"], seg["ac"], atarg
            # fvpargs = [arr[::5] for arr in args]

            assign_old_eq_new()  # set old parameter values to new parameter values

            with timed("policy optimization"):
                logger.log("Optimizing...")
                logger.log(fmt_row(13, loss_names))
                # Here we do a bunch of optimization epochs over the data
                for _ in range(optim_epochs):
                    losses = [] # list of tuples, each of which gives the loss for a minibatch
                    for batch in d.iterate_once(optim_batchsize):
                        *newlosses, g = lossandgrad(batch["ob"], batch["ac"], batch["atarg"], batch["vtarg"], cur_lrmult)
                        adam.update(g, optim_stepsize * cur_lrmult)
                        losses.append(newlosses)
                    logger.log(fmt_row(13, np.mean(losses, axis=0)))

        logger.log("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob"], batch["ac"], batch["atarg"], batch["vtarg"], cur_lrmult)
            losses.append(newlosses)
        meanlosses,_,_ = mpi_moments(losses, axis=0)
        logger.log(fmt_row(13, meanlosses))
        for (lossval, name) in zipsame(meanlosses, loss_names):
            logger.record_tabular("loss_"+name, lossval)
        logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))


        # g_losses = meanlosses
        # for (lossname, lossval) in zip(loss_names, meanlosses):
        #     logger.record_tabular(lossname, lossval)
        # logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))


        # ------------------ Update D ------------------
        logger.log("Optimizing Discriminator...")
        logger.log(fmt_row(13, reward_giver.loss_name))
        ob_expert, ac_expert = expert_dataset.get_next_batch(len(ob))
        batch_size = len(ob) // d_step
        d_losses = []  # list of tuples, each of which gives the loss for a minibatch

        for _ in range(optim_epochs // 10):
            for ob_batch, ac_batch in dataset.iterbatches((ob, ac),
                                                          include_final_partial_batch=False,
                                                          batch_size=batch_size):
                ob_expert, ac_expert = expert_dataset.get_next_batch(len(ob_batch))
                # update running mean/std for reward_giver
                ob_batch = ob_batch[:, -ob_expert.shape[1]:][:-30]
                if hasattr(reward_giver, "obs_rms"): reward_giver.obs_rms.update(np.concatenate((ob_batch, ob_expert), 0)[:, :-30])
                # *newlosses, g = reward_giver.lossandgrad(ob_batch, ac_batch, ob_expert, ac_expert)
                *newlosses, g = reward_giver.lossandgrad(ob_batch[:, :-30], ob_expert[:, :-30])
                d_adam.update(allmean(g), d_stepsize)
                d_losses.append(newlosses)
        logger.log(fmt_row(13, np.mean(d_losses, axis=0)))


        lrlocal = (seg["ep_lens"], seg["ep_rets"], seg["ep_true_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews, true_rets = map(flatten_lists, zip(*listoflrpairs))
        true_rewbuffer.extend(true_rets)
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpTrueRewMean", np.mean(true_rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)

        if rank == 0 and iters_so_far % 10 == 0:
            disc_losses = np.mean(d_losses, axis=0)
            res = tf.get_default_session().run(merged, feed_dict={
                generator_loss: disc_losses[0],
                expert_loss: disc_losses[1],
                entropy: disc_losses[2],
                entropy_loss: disc_losses[3],
                generator_acc: disc_losses[4],
                expert_acc: disc_losses[5],
                eplenmean: np.mean(lenbuffer),
                eprewmean: np.mean(rewbuffer),
                eptruerewmean: np.mean(true_rewbuffer),
            })
            writer.add_summary(res, iters_so_far)
            writer.flush()

        if rank == 0:
            logger.dump_tabular()
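Both update loops above rely on add_vtarg_and_adv(seg, gamma, lam), whose body is not shown in this excerpt. Below is a minimal GAE(lambda) sketch in the usual baselines style, assuming seg carries 'rew', 'vpred', 'new' (episode-start flags) and 'nextvpred' (the bootstrap value); treat it as an assumption about the helper, not the exact implementation used here.

import numpy as np

def add_vtarg_and_adv(seg, gamma, lam):
    # Append the bootstrap value and a non-terminal flag for step T.
    new = np.append(seg["new"], 0)
    vpred = np.append(seg["vpred"], seg["nextvpred"])
    T = len(seg["rew"])
    seg["adv"] = gaelam = np.empty(T, "float32")
    lastgaelam = 0.0
    for t in reversed(range(T)):
        nonterminal = 1 - new[t + 1]
        delta = seg["rew"][t] + gamma * vpred[t + 1] * nonterminal - vpred[t]
        gaelam[t] = lastgaelam = delta + gamma * lam * nonterminal * lastgaelam
    seg["tdlamret"] = seg["adv"] + seg["vpred"]  # TD(lambda) return targets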
Example #24
def learn(
    env,
    policy_fn,
    *,
    timesteps_per_actorbatch,  # timesteps per actor per update
    clip_param,
    entcoeff,  # clipping parameter epsilon, entropy coeff
    optim_epochs,
    optim_stepsize,
    optim_batchsize,  # optimization hypers
    gamma,
    lam,  # advantage estimation
    max_timesteps=0,
    max_episodes=0,
    max_iters=0,
    max_seconds=0,  # time constraint
    callback=None,  # you can do anything in the callback, since it takes locals(), globals()
    adam_epsilon=1e-5,
    schedule='constant',  # annealing for stepsize parameters (epsilon and adam)
    restore_model_from_file=None,
    save_model_with_prefix,  # prefix for the saved model file. It usually encodes the target goal,
    # for example 3dof_ppo1_H, so that only the relevant networks need to be transferred to the
    # real robot instead of every file or folder. The model file name should be self-explanatory.
    job_id=None,  # identifies the Spearmint iteration number; usually set by the Spearmint iterator
    outdir="/tmp/rosrl/experiments/continuous/ppo1/"):
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_fn("pi", ob_space,
                   ac_space)  # Construct network for new policy
    oldpi = policy_fn("oldpi", ob_space, ac_space)  # Network for old policy
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = tf.clip_by_value(ratio, 1.0 - clip_param,
                             1.0 + clip_param) * atarg  #
    pol_surr = -tf.reduce_mean(tf.minimum(
        surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)
    vf_loss = tf.reduce_mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult],
                             losses + [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)

    U.initialize()
    adam.sync()
    """
    Here we add a possibility to resume from a previously saved model if a model file is provided
    """
    if restore_model_from_file:
        # saver = tf.train.Saver(tf.all_variables())
        saver = tf.train.import_meta_graph(restore_model_from_file)
        saver.restore(
            tf.get_default_session(),
            tf.train.latest_checkpoint('./'))  #restore_model_from_file)
        logger.log("Loaded model from {}".format(restore_model_from_file))

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_actorbatch,
                                     stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards

    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"

    if save_model_with_prefix:
        if job_id is not None:
            basePath = '/tmp/rosrl/' + str(
                env.__class__.__name__) + '/ppo1/' + job_id
        else:
            basePath = '/tmp/rosrl/' + str(env.__class__.__name__) + '/ppo1/'

    # Create the writer for TensorBoard logs
    summary_writer = tf.summary.FileWriter(outdir,
                                           graph=tf.get_default_graph())

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % iters_so_far)

        seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before udpate
        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate
        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                    deterministic=pi.recurrent)
        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        assign_old_eq_new()  # set old parameter values to new parameter values
        logger.log("Optimizing...")
        logger.log(fmt_row(13, loss_names))
        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"],
                                            batch["atarg"], batch["vtarg"],
                                            cur_lrmult)
                adam.update(g, optim_stepsize * cur_lrmult)
                losses.append(newlosses)
            logger.log(fmt_row(13, np.mean(losses, axis=0)))

        logger.log("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob"], batch["ac"],
                                       batch["atarg"], batch["vtarg"],
                                       cur_lrmult)
            losses.append(newlosses)
        meanlosses, _, _ = mpi_moments(losses, axis=0)
        logger.log(fmt_row(13, meanlosses))
        for (lossval, name) in zipsame(meanlosses, loss_names):
            logger.record_tabular("loss_" + name, lossval)
        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))
        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpRewSEM", np.std(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        """
        Save the model at every iteration
        """

        if save_model_with_prefix:
            # if np.mean(rewbuffer) > 10.0:
            if iters_so_far % 10 == 0 or np.mean(rewbuffer) > 10.0:
                basePath = outdir + "/models/"

                if not os.path.exists(basePath):
                    os.makedirs(basePath)
                modelF = basePath + save_model_with_prefix + "_afterIter_" + str(
                    iters_so_far) + ".model"
                U.save_state(modelF)
                logger.log("Saved model to file :{}".format(modelF))

        iters_so_far += 1
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.dump_tabular()

        summary = tf.Summary(value=[
            tf.Summary.Value(tag="EpRewMean", simple_value=np.mean(rewbuffer))
        ])
        summary_writer.add_summary(summary, timesteps_so_far)
    return pi
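The loop above re-computes cur_lrmult at every iteration from the chosen schedule and scales the Adam step size (and, in other variants, the clip parameter) with it. A minimal sketch of that schedule as a standalone helper; the function name is illustrative.

def lr_multiplier(schedule, timesteps_so_far, max_timesteps):
    # 'constant' keeps the multiplier at 1.0; 'linear' decays it to 0
    # over max_timesteps, matching the inline branches above.
    if schedule == 'constant':
        return 1.0
    if schedule == 'linear':
        return max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
    raise NotImplementedError(schedule)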
    def update(self):
        if self.normrew:         # normalize rewards using statistics gathered from the other MPI workers
            rffs = np.array([self.rff.update(rew) for rew in self.rollout.buf_rews.T])
            rffs_mean, rffs_std, rffs_count = mpi_moments(rffs.ravel())
            self.rff_rms.update_from_moments(rffs_mean, rffs_std ** 2, rffs_count)
            rews = self.rollout.buf_rews / np.sqrt(self.rff_rms.var)
        else:
            rews = np.copy(self.rollout.buf_rews)

        # call this class's helper to compute the advantage function from the reward sequence rews
        self.calculate_advantages(rews=rews, use_news=self.use_news, gamma=self.gamma, lam=self.lam)

        # record some statistics for logging
        info = dict(
            advmean=self.buf_advs.mean(),
            advstd=self.buf_advs.std(),
            retmean=self.buf_rets.mean(),
            retstd=self.buf_rets.std(),
            vpredmean=self.rollout.buf_vpreds.mean(),
            vpredstd=self.rollout.buf_vpreds.std(),
            ev=explained_variance(self.rollout.buf_vpreds.ravel(), self.buf_rets.ravel()),
            rew_mean=np.mean(self.rollout.buf_rews),
            rew_mean_norm=np.mean(rews),
            recent_best_ext_ret=self.rollout.current_max
        )
        if self.rollout.best_ext_ret is not None:
            info['best_ext_ret'] = self.rollout.best_ext_ret

        # normalize advantages: rescale the computed advantages by their mean and std
        if self.normadv:
            m, s = get_mean_and_std(self.buf_advs)
            self.buf_advs = (self.buf_advs - m) / (s + 1e-7)
        envsperbatch = (self.nenvs * self.nsegs_per_env) // self.nminibatches
        envsperbatch = max(1, envsperbatch)
        envinds = np.arange(self.nenvs * self.nsegs_per_env)

        def resh(x):
            if self.nsegs_per_env == 1:
                return x
            sh = x.shape
            return x.reshape((sh[0] * self.nsegs_per_env, self.nsteps_per_seg) + sh[2:])

        # pair the placeholders defined in this class with the numpy buffers collected by the rollout class, to be used as the feed_dict
        ph_buf = [
            (self.stochpol.ph_ac, resh(self.rollout.buf_acs)),
            (self.ph_rews, resh(self.rollout.buf_rews)),
            (self.ph_oldvpred, resh(self.rollout.buf_vpreds)),
            (self.ph_oldnlp, resh(self.rollout.buf_nlps)),
            (self.stochpol.ph_ob, resh(self.rollout.buf_obs)),   # the entries above are numpy buffers recorded by the rollout while interacting with the environment
            (self.ph_ret, resh(self.buf_rets)),                  # returns computed from the rollout records
            (self.ph_adv, resh(self.buf_advs)),                  # advantages computed from the rollout records
        ]
        ph_buf.extend([
            (self.dynamics.last_ob,
             self.rollout.buf_obs_last.reshape([self.nenvs * self.nsegs_per_env, 1, *self.ob_space.shape]))
        ])
        mblossvals = []          # record the training losses

        # optimize the agent losses
        for _ in range(self.nepochs):
            np.random.shuffle(envinds)
            for start in range(0, self.nenvs * self.nsegs_per_env, envsperbatch):
                end = start + envsperbatch
                mbenvinds = envinds[start:end]
                fd = {ph: buf[mbenvinds] for (ph, buf) in ph_buf}     # build the feed_dict
                fd.update({self.ph_lr: self.lr, self.ph_cliprange: self.cliprange})
                mblossvals.append(getsess().run(self._losses + (self._train,), fd)[:-1])    # compute the losses and apply the update in one run

        # add bai.  additionally train the DVAE on its own
        for tmp in range(self.nepochs_dvae):
            print("extra DVAE training epoch:", tmp)
            np.random.shuffle(envinds)
            for start in range(0, self.nenvs * self.nsegs_per_env, envsperbatch):     # loops 8 times
                end = start + envsperbatch
                mbenvinds = envinds[start:end]
                fd = {ph: buf[mbenvinds] for (ph, buf) in ph_buf}                       # build the feed_dict
                fd.update({self.ph_lr: self.lr, self.ph_cliprange: self.cliprange})
                d_loss, _ = getsess().run([self.dynamics_loss, self._train_dvae], fd)   # compute the DVAE loss and apply the update
                print(d_loss, end=", ")
            print("\n")

        mblossvals = [mblossvals[0]]
        info.update(zip(['opt_' + ln for ln in self.loss_names], np.mean([mblossvals[0]], axis=0)))
        info["rank"] = MPI.COMM_WORLD.Get_rank()
        self.n_updates += 1
        info["n_updates"] = self.n_updates
        info.update({dn: (np.mean(dvs) if len(dvs) > 0 else 0) for (dn, dvs) in self.rollout.statlists.items()})
        info.update(self.rollout.stats)
        if "states_visited" in info:
            info.pop("states_visited")
        tnow = time.time()
        info["ups"] = 1. / (tnow - self.t_last_update)
        info["total_secs"] = tnow - self.t_start
        info['tps'] = MPI.COMM_WORLD.Get_size() * self.rollout.nsteps * self.nenvs / (tnow - self.t_last_update)
        self.t_last_update = tnow

        return info
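The normrew branch above rescales rewards by the running standard deviation of a discounted return estimate (self.rff / self.rff_rms). Below is a minimal sketch of the two helpers this pattern typically relies on, in the style of curiosity-driven PPO codebases; treat it as an assumption about their behaviour rather than the exact classes used here.

class RewardForwardFilter:
    # Keeps a discounted running sum of the rewards seen so far.
    def __init__(self, gamma):
        self.rewems = None
        self.gamma = gamma

    def update(self, rews):
        if self.rewems is None:
            self.rewems = rews
        else:
            self.rewems = self.rewems * self.gamma + rews
        return self.rewems

class RunningVariance:
    # Tracks a running mean/variance from batch moments (parallel merge).
    def __init__(self):
        self.mean, self.var, self.count = 0.0, 1.0, 1e-4

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        delta = batch_mean - self.mean
        tot = self.count + batch_count
        new_mean = self.mean + delta * batch_count / tot
        m2 = self.var * self.count + batch_var * batch_count \
            + delta ** 2 * self.count * batch_count / tot
        self.mean, self.var, self.count = new_mean, m2 / tot, tot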
    def update(self):
        if self.normrew:
            rffs = np.array(
                [self.rff.update(rew) for rew in self.rollout.buf_rews.T])
            rffs_mean, rffs_std, rffs_count = mpi_moments(rffs.ravel())
            self.rff_rms.update_from_moments(rffs_mean, rffs_std**2,
                                             rffs_count)
            rews = self.rollout.buf_rews / np.sqrt(self.rff_rms.var)
        else:
            rews = np.copy(self.rollout.buf_rews)
        self.calculate_advantages(rews=rews,
                                  use_news=self.use_news,
                                  gamma=self.gamma,
                                  lam=self.lam)

        info = dict(advmean=self.buf_advs.mean(),
                    advstd=self.buf_advs.std(),
                    retmean=self.buf_rets.mean(),
                    retstd=self.buf_rets.std(),
                    vpredmean=self.rollout.buf_vpreds.mean(),
                    vpredstd=self.rollout.buf_vpreds.std(),
                    ev=explained_variance(self.rollout.buf_vpreds.ravel(),
                                          self.buf_rets.ravel()),
                    rew_mean=np.mean(self.rollout.buf_rews),
                    recent_best_ext_ret=self.rollout.current_max)
        if self.rollout.best_ext_ret is not None:
            info['best_ext_ret'] = self.rollout.best_ext_ret

        # normalize advantages
        if self.normadv:
            m, s = get_mean_and_std(self.buf_advs)
            self.buf_advs = (self.buf_advs - m) / (s + 1e-7)
        envsperbatch = (self.nenvs * self.nsegs_per_env) // self.nminibatches
        envsperbatch = max(1, envsperbatch)
        envinds = np.arange(self.nenvs * self.nsegs_per_env)

        def resh(x):
            if self.nsegs_per_env == 1:
                return x
            sh = x.shape
            return x.reshape((sh[0] * self.nsegs_per_env,
                              self.nsteps_per_seg) + sh[2:])

        ph_buf = [
            (self.trainpol.ph_ac, resh(self.rollout.buf_acs)),
            (self.ph_rews, resh(self.rollout.buf_rews)),
            (self.ph_oldvpred, resh(self.rollout.buf_vpreds)),
            (self.ph_oldnlp, resh(self.rollout.buf_nlps)),
            (self.trainpol.ph_ob, resh(self.rollout.buf_obs)),
            (self.ph_ret, resh(self.buf_rets)),
            (self.ph_adv, resh(self.buf_advs)),
        ]
        ph_buf.extend([(self.train_dynamics.last_ob,
                        self.rollout.buf_obs_last.reshape([
                            self.nenvs * self.nsegs_per_env, 1,
                            *self.ob_space.shape
                        ]))])
        ph_buf.extend([
            (self.trainpol.states_ph,
             resh(self.rollout.buf_states_first)),  # rnn inputs
            (self.trainpol.masks_ph, resh(self.rollout.buf_news))
        ])
        if 'err' in self.policy_mode:
            ph_buf.extend([(self.trainpol.pred_error,
                            resh(self.rollout.buf_errs))])  # New
        if 'ac' in self.policy_mode:
            ph_buf.extend([(self.trainpol.ph_ac, resh(self.rollout.buf_acs)),
                           (self.trainpol.ph_ac_first,
                            resh(self.rollout.buf_acs_first))])
        if 'pred' in self.policy_mode:
            ph_buf.extend([(self.trainpol.obs_pred,
                            resh(self.rollout.buf_obpreds))])

        # with open(os.getcwd() + "/record_instruction.txt", 'r') as rec_inst:
        #     rec_n = []
        #     rec_all_n = []
        #     while True:
        #         line = rec_inst.readline()
        #         if not line: break
        #         args = line.split()
        #         rec_n.append(int(args[0]))
        #         if len(args) > 1:
        #             rec_all_n.append(int(args[0]))
        #     if self.n_updates in rec_n and MPI.COMM_WORLD.Get_rank() == 0:
        #         print("Enter!")
        #         with open(self.logdir + '/full_log' + str(self.n_updates) + '.pk', 'wb') as full_log:
        #             import pickle
        #             debug_data = {"buf_obs" : self.rollout.buf_obs,
        #                           "buf_obs_last" : self.rollout.buf_obs_last,
        #                           "buf_acs" : self.rollout.buf_acs,
        #                           "buf_acs_first" : self.rollout.buf_acs_first,
        #                           "buf_news" : self.rollout.buf_news,
        #                           "buf_news_last" : self.rollout.buf_new_last,
        #                           "buf_rews" : self.rollout.buf_rews,
        #                           "buf_ext_rews" : self.rollout.buf_ext_rews}
        #             if self.n_updates in rec_all_n:
        #                 debug_data.update({"buf_err": self.rollout.buf_errs,
        #                                     "buf_err_last": self.rollout.buf_errs_last,
        #                                     "buf_obpreds": self.rollout.buf_obpreds,
        #                                     "buf_obpreds_last": self.rollout.buf_obpreds_last,
        #                                     "buf_vpreds": self.rollout.buf_vpreds,
        #                                     "buf_vpred_last": self.rollout.buf_vpred_last,
        #                                     "buf_states": self.rollout.buf_states,
        #                                     "buf_states_first": self.rollout.buf_states_first,
        #                                     "buf_nlps": self.rollout.buf_nlps,})
        #             pickle.dump(debug_data, full_log)

        mblossvals = []

        for _ in range(self.nepochs):
            np.random.shuffle(envinds)
            for start in range(0, self.nenvs * self.nsegs_per_env,
                               envsperbatch):
                end = start + envsperbatch
                mbenvinds = envinds[start:end]
                fd = {ph: buf[mbenvinds] for (ph, buf) in ph_buf}
                fd.update({
                    self.ph_lr: self.lr,
                    self.ph_cliprange: self.cliprange
                })
                mblossvals.append(getsess().run(self._losses + (self._train, ),
                                                fd)[:-1])

        mblossvals = [mblossvals[0]]
        info.update(
            zip(['opt_' + ln for ln in self.loss_names],
                np.mean([mblossvals[0]], axis=0)))
        info["rank"] = MPI.COMM_WORLD.Get_rank()
        self.n_updates += 1
        info["n_updates"] = self.n_updates
        info.update({
            dn: (np.mean(dvs) if len(dvs) > 0 else 0)
            for (dn, dvs) in self.rollout.statlists.items()
        })
        info.update(self.rollout.stats)
        if "states_visited" in info:
            info.pop("states_visited")
        tnow = time.time()
        info["ups"] = 1. / (tnow - self.t_last_update)
        info["total_secs"] = tnow - self.t_start
        info['tps'] = MPI.COMM_WORLD.Get_size(
        ) * self.rollout.nsteps * self.nenvs / (tnow - self.t_last_update)
        self.t_last_update = tnow

        # New
        if 'err' in self.policy_mode:
            info["error"] = np.sqrt(np.power(self.rollout.buf_errs, 2).mean())

        if self.n_updates % self.tboard_period == 0 and MPI.COMM_WORLD.Get_rank(
        ) == 0:
            if self.full_tensorboard_log:
                summary = getsess().run(self.merged_summary_op, fd)  # New
                self.summary_writer.add_summary(
                    summary, self.rollout.stats["tcount"])  # New
            for k, v in info.items():
                summary = tf.Summary(value=[
                    tf.Summary.Value(tag=k, simple_value=v),
                ])
                self.summary_writer.add_summary(summary,
                                                self.rollout.stats["tcount"])

        return info
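The info dict above reports ev, the explained variance of the value predictions against the computed returns. A minimal numpy sketch of the usual baselines-style definition (1 - Var[y - ypred] / Var[y], NaN when Var[y] is zero), included here as an assumption about the helper:

import numpy as np

def explained_variance(ypred, y):
    # 1 means perfect prediction, 0 means no better than predicting the mean,
    # negative means worse than predicting the mean.
    vary = np.var(y)
    return np.nan if vary == 0 else 1 - np.var(y - ypred) / vary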
    def update(self):
        if self.normrew:
            rffs = np.array(
                [self.rff.update(rew) for rew in self.rollout.buf_rews.T])
            rffs_mean, rffs_std, rffs_count = mpi_moments(rffs.ravel())
            self.rff_rms.update_from_moments(rffs_mean, rffs_std**2,
                                             rffs_count)
            rews = self.rollout.buf_rews / np.sqrt(self.rff_rms.var)
        else:
            rews = np.copy(self.rollout.buf_rews)
        self.calculate_advantages(rews=rews,
                                  use_news=self.use_news,
                                  gamma=self.gamma,
                                  lam=self.lam)

        info = dict(
            advmean=self.buf_advs.mean(),
            advstd=self.buf_advs.std(),
            retmean=self.buf_rets.mean(),
            retstd=self.buf_rets.std(),
            vpredmean=self.rollout.buf_vpreds.mean(),
            vpredstd=self.rollout.buf_vpreds.std(),
            ev=explained_variance(self.rollout.buf_vpreds.ravel(),
                                  self.buf_rets.ravel()),
            rew_mean=np.mean(self.rollout.buf_rews),
            recent_best_ext_ret=self.rollout.current_max,
        )
        if self.rollout.best_ext_ret is not None:
            info["best_ext_ret"] = self.rollout.best_ext_ret

        # normalize advantages
        if self.normadv:
            m, s = get_mean_and_std(self.buf_advs)
            self.buf_advs = (self.buf_advs - m) / (s + 1e-7)
        envsperbatch = (self.nenvs * self.nsegs_per_env) // self.nminibatches
        envsperbatch = max(1, envsperbatch)
        envinds = np.arange(self.nenvs * self.nsegs_per_env)

        def resh(x):
            if self.nsegs_per_env == 1:
                return x
            sh = x.shape
            return x.reshape((sh[0] * self.nsegs_per_env,
                              self.nsteps_per_seg) + sh[2:])

        ph_buf = [
            (self.stochpol.ph_ac, resh(self.rollout.buf_acs)),
            (self.ph_rews, resh(self.rollout.buf_rews)),
            (self.ph_oldvpred, resh(self.rollout.buf_vpreds)),
            (self.ph_oldnlp, resh(self.rollout.buf_nlps)),
            (self.stochpol.ph_ob, resh(self.rollout.buf_obs)),
            (self.ph_ret, resh(self.buf_rets)),
            (self.ph_adv, resh(self.buf_advs)),
        ]
        ph_buf.extend([(
            self.dynamics.last_ob,
            self.rollout.buf_obs_last.reshape(
                [self.nenvs * self.nsegs_per_env, 1, *self.ob_space.shape]),
        )])
        mblossvals = []

        for _ in range(self.nepochs):
            np.random.shuffle(envinds)
            for start in range(0, self.nenvs * self.nsegs_per_env,
                               envsperbatch):
                end = start + envsperbatch
                mbenvinds = envinds[start:end]
                fd = {ph: buf[mbenvinds] for (ph, buf) in ph_buf}
                fd.update({
                    self.ph_lr: self.lr,
                    self.ph_cliprange: self.cliprange
                })
                mblossvals.append(getsess().run(self._losses + (self._train, ),
                                                fd)[:-1])

        mblossvals = [mblossvals[0]]
        info.update(
            zip(
                ["opt_" + ln for ln in self.loss_names],
                np.mean([mblossvals[0]], axis=0),
            ))
        info["rank"] = MPI.COMM_WORLD.Get_rank()
        self.n_updates += 1
        info["n_updates"] = self.n_updates
        info.update({
            dn: (np.mean(dvs) if len(dvs) > 0 else 0)
            for (dn, dvs) in self.rollout.statlists.items()
        })
        info.update(self.rollout.stats)
        if "states_visited" in info:
            info.pop("states_visited")
        tnow = time.time()
        info["ups"] = 1.0 / (tnow - self.t_last_update)
        info["total_secs"] = tnow - self.t_start
        info["tps"] = (MPI.COMM_WORLD.Get_size() * self.rollout.nsteps *
                       self.nenvs / (tnow - self.t_last_update))
        self.t_last_update = tnow

        return info
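All three update() variants above minibatch over whole environments: the env indices are shuffled once per epoch and sliced into groups of envsperbatch, and each slice indexes every buffer in ph_buf. A small generator sketch of that pattern; the name env_minibatches is illustrative.

import numpy as np

def env_minibatches(nenvs, nsegs_per_env, nminibatches, nepochs):
    # Yields index arrays; each selects the environments for one minibatch.
    envsperbatch = max(1, (nenvs * nsegs_per_env) // nminibatches)
    envinds = np.arange(nenvs * nsegs_per_env)
    for _ in range(nepochs):
        np.random.shuffle(envinds)
        for start in range(0, nenvs * nsegs_per_env, envsperbatch):
            yield envinds[start:start + envsperbatch]

In the loops above this corresponds to building fd = {ph: buf[mbenvinds] for (ph, buf) in ph_buf} for each yielded mbenvinds.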
Example #28
def learn(
    # =========== modified part begins =========== #
    env_id,
    seed,
    # ============ modified part ends ============ #
    policy_func,
    *,
    timesteps_per_actorbatch,  # timesteps per actor per update
    clip_param,
    entcoeff,  # clipping parameter epsilon, entropy coeff
    optim_epochs,
    optim_stepsize,
    optim_batchsize,  # optimization hypers
    gamma,
    lam,  # advantage estimation
    max_timesteps=0,
    max_episodes=0,
    max_iters=0,
    max_seconds=0,  # time constraint
    callback=None,  # you can do anything in the callback, since it takes locals(), globals()
    adam_epsilon=1e-5,
    schedule='constant'  # annealing for stepsize parameters (epsilon and adam)
):

    # ================================== modification 1 ================================== #
    """
    input:  replace "env" (env class) with "env_id" (string)
            add "seed" (int)
        reason: to enable gym.make() during training
        modification detail: add the following lines to learn()
            env = gym.make(env_id)
            env = bench.Monitor(env, logger.get_dir())
            env.seed(seed)
            env.close() # added at the end of learn()
    """
    import roboschool, gym
    from baselines import bench
    env = gym.make(env_id)
    env = bench.Monitor(env, logger.get_dir())
    env.seed(seed)
    # ================================== modification 1 ================================== #

    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space

    # policy_func initializes the NN
    # NN structure:
    #   state -> (num_hid_layers) fully-connected layers with (hid_size) units -> (action, predicted value)
    #       num_hid_layers, hid_size: set in the file that calls "learn"
    pi = policy_func("pi", ob_space,
                     ac_space)  # Construct network for new policy
    oldpi = policy_func("oldpi", ob_space, ac_space)  # Network for old policy

    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule
    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    # placeholder for "ob"
    # created in mlppolicy.py
    ob = U.get_placeholder_cached(name="ob")
    # placeholder for "ac"
    # in common/distribution.py
    ac = pi.pdtype.sample_placeholder([None])

    # KL divergence and Entropy, defined in common/distribution.py
    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = U.mean(kloldnew)
    meanent = U.mean(ent)

    # pol_entpen: Entropy Bonus encourages exploration
    # entcoeff: entropy coefficient, defined in PPO page 5, Equ. (9)
    pol_entpen = (-entcoeff) * meanent

    # probability ratio, defined in PPO page 3
    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold

    # Surrogate Goal
    # defined in PPO page 3, Equ (7)
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = U.clip(ratio, 1.0 - clip_param, 1.0 + clip_param) * atarg  #
    pol_surr = -U.mean(tf.minimum(
        surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)

    # Value Function Loss: squared error loss for ||v_pred - v_target||
    vf_loss = U.mean(tf.square(pi.vpred - ret))

    # Total_loss = -(L^CLIP - Value Function Loss + Entropy Bonus),
    # i.e. the negative of the objective in PPO page 5, Equ. (9), so it can be minimized
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult],
                             losses + [U.flatgrad(total_loss, var_list)])
    # MPI-synchronized Adam optimizer
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    # oldpi = pi
    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])

    # compute_losses re-evaluates the same losses without gradients; used for the evaluation pass below
    compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)

    U.initialize()
    adam.sync()

    # ================================== modification 2 ================================== #
    for _ in range(1):
        # reinitialize env
        env = gym.make(env_id)
        env = bench.Monitor(env, logger.get_dir())
        # ================================== modification 2 ================================== #

        # Prepare for rollouts
        # ----------------------------------------
        seg_gen = traj_segment_generator(pi,
                                         env,
                                         timesteps_per_actorbatch,
                                         stochastic=True)

        episodes_so_far = 0
        timesteps_so_far = 0
        iters_so_far = 0
        tstart = time.time()
        lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
        rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards

        assert sum([
            max_iters > 0, max_timesteps > 0, max_episodes > 0, max_seconds > 0
        ]) == 1, "Only one time constraint permitted"

        while True:
            if callback: callback(locals(), globals())
            if max_timesteps and timesteps_so_far >= max_timesteps:
                break
            elif max_episodes and episodes_so_far >= max_episodes:
                break
            elif max_iters and iters_so_far >= max_iters:
                break
            elif max_seconds and time.time() - tstart >= max_seconds:
                break

            # annealing for stepsize parameters (epsilon and adam)
            if schedule == 'constant':
                cur_lrmult = 1.0
            elif schedule == 'linear':
                cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps,
                                 0)
            else:
                raise NotImplementedError

            logger.log("********** Iteration %i ************" % iters_so_far)

            seg = seg_gen.__next__()
            add_vtarg_and_adv(seg, gamma, lam)

            # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
            ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
                "tdlamret"]
            vpredbefore = seg[
                "vpred"]  # predicted value function before udpate
            atarg = (atarg - atarg.mean()) / atarg.std(
            )  # standardized advantage function estimate
            d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                        shuffle=not pi.recurrent)
            optim_batchsize = optim_batchsize or ob.shape[0]

            if hasattr(pi, "ob_rms"):
                pi.ob_rms.update(ob)  # update running mean/std for policy

            # oldpi = pi
            # set old parameter values to new parameter values
            assign_old_eq_new()
            logger.log("Optimizing...")
            logger.log(fmt_row(13, loss_names))
            # Here we do a bunch of optimization epochs over the data
            for _ in range(optim_epochs):
                losses = [
                ]  # list of tuples, each of which gives the loss for a minibatch
                for batch in d.iterate_once(optim_batchsize):
                    *newlosses, g = lossandgrad(batch["ob"], batch["ac"],
                                                batch["atarg"], batch["vtarg"],
                                                cur_lrmult)
                    adam.update(g, optim_stepsize * cur_lrmult)
                    losses.append(newlosses)
                logger.log(fmt_row(13, np.mean(losses, axis=0)))

            logger.log("Evaluating losses...")
            losses = []
            for batch in d.iterate_once(optim_batchsize):
                newlosses = compute_losses(batch["ob"], batch["ac"],
                                           batch["atarg"], batch["vtarg"],
                                           cur_lrmult)
                losses.append(newlosses)
            meanlosses, _, _ = mpi_moments(losses, axis=0)
            logger.log(fmt_row(13, meanlosses))
            for (lossval, name) in zipsame(meanlosses, loss_names):
                logger.record_tabular("loss_" + name, lossval)
            logger.record_tabular("ev_tdlam_before",
                                  explained_variance(vpredbefore, tdlamret))
            lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
            listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
            lens, rews = map(flatten_lists, zip(*listoflrpairs))
            lenbuffer.extend(lens)
            rewbuffer.extend(rews)
            logger.record_tabular("EpLenMean", np.mean(lenbuffer))
            logger.record_tabular("EpRewMean", np.mean(rewbuffer))
            logger.record_tabular("EpThisIter", len(lens))
            episodes_so_far += len(lens)
            timesteps_so_far += sum(lens)
            iters_so_far += 1
            logger.record_tabular("EpisodesSoFar", episodes_so_far)
            logger.record_tabular("TimestepsSoFar", timesteps_so_far)
            logger.record_tabular("TimeElapsed", time.time() - tstart)
            if MPI.COMM_WORLD.Get_rank() == 0:
                logger.dump_tabular()

    # ================================== modification 1 ================================== #
    env.close()
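The comments in the example above spell out the clipped surrogate (PPO Equ. (7)) and the combined objective (Equ. (9)) that the TensorFlow graph builds. For reference, a small numpy restatement of the clipped part alone, taking per-sample log-probabilities under the new and old policies and standardized advantages:

import numpy as np

def ppo_clip_loss(logp_new, logp_old, adv, clip_param):
    ratio = np.exp(logp_new - logp_old)                                # pnew / pold
    surr1 = ratio * adv
    surr2 = np.clip(ratio, 1.0 - clip_param, 1.0 + clip_param) * adv
    return -np.mean(np.minimum(surr1, surr2))                          # -L^CLIP, to be minimized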
Example #29
def learn(
    env,
    policy_func,
    *,
    timesteps_per_batch,  # timesteps per actor per update
    clip_param,
    entcoeff,  # clipping parameter epsilon, entropy coeff
    optim_epochs,
    optim_stepsize,
    optim_batchsize,  # optimization hypers
    gamma,
    lam,  # advantage estimation
    max_timesteps=0,
    max_episodes=0,
    max_iters=0,
    max_seconds=0,  # time constraint
    callback=None,  # you can do anything in the callback, since it takes locals(), globals()
    adam_epsilon=1e-5,
    schedule='constant'  # annealing for stepsize parameters (epsilon and adam)
):
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_func("pi", ob_space,
                     ac_space)  # Construct network for new policy
    oldpi = policy_func("oldpi", ob_space, ac_space)  # Network for old policy
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule
    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = U.mean(kloldnew)
    meanent = U.mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = U.clip(ratio, 1.0 - clip_param, 1.0 + clip_param) * atarg  #
    pol_surr = -U.mean(tf.minimum(
        surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)
    vf_loss = U.mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult],
                             losses + [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)

    U.initialize()
    adam.sync()

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_batch,
                                     stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    rollouts_time = 0
    optimization_time = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards

    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
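            # e.g. halfway through training (timesteps_so_far == max_timesteps / 2)
            # both the Adam stepsize and the annealed clip range are scaled by cur_lrmult = 0.5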
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % iters_so_far)
        a = time.time()

        seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update
        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate
        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                    shuffle=not pi.recurrent)
        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        assign_old_eq_new()  # set old parameter values to new parameter values
        b = time.time()
        logger.log("Optimizing...")
        logger.log(fmt_row(13, loss_names))
        grad_time = 0.0
        allreduce_time = 0.0
        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                aa = time.time()
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"],
                                            batch["atarg"], batch["vtarg"],
                                            cur_lrmult)
                bb = time.time()
                adam.update(g, optim_stepsize * cur_lrmult)
                cc = time.time()
                grad_time += bb - aa
                allreduce_time += cc - bb
                losses.append(newlosses)
            logger.log(fmt_row(13, np.mean(losses, axis=0)))

        logger.log("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob"], batch["ac"],
                                       batch["atarg"], batch["vtarg"],
                                       cur_lrmult)
            losses.append(newlosses)
        meanlosses, _, _ = mpi_moments(losses, axis=0)
        logger.log(fmt_row(13, meanlosses))
        for (lossval, name) in zipsame(meanlosses, loss_names):
            logger.record_tabular("loss_" + name, lossval)
        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))
        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        logger.record_tabular("GradTime", grad_time)
        logger.record_tabular("AllReduceTime", allreduce_time)
        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        c = time.time()
        rollouts_time += (b - a)
        optimization_time += (c - b)
        logger.record_tabular("RolloutsTime", rollouts_time)
        logger.record_tabular("OptimizationTime", optimization_time)
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.dump_tabular()
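# Illustrative sketch (not part of the original example): a NumPy restatement of
# the clipped surrogate L^CLIP assembled above with ratio/surr1/surr2, useful for
# checking the clipping behaviour by hand. `clipped_surrogate` is a hypothetical
# helper name, not something defined elsewhere in this code.
import numpy as np

def clipped_surrogate(logp_new, logp_old, adv, clip_param=0.2):
    """-E[min(r * A, clip(r, 1 - eps, 1 + eps) * A)] with r = exp(logp_new - logp_old)."""
    ratio = np.exp(logp_new - logp_old)                    # pnew / pold
    surr1 = ratio * adv                                    # unclipped surrogate
    surr2 = np.clip(ratio, 1.0 - clip_param, 1.0 + clip_param) * adv
    return -np.mean(np.minimum(surr1, surr2))              # pessimistic bound, negated into a loss

# With logp_new - logp_old = 0.3 (ratio ~ 1.35, clipped to 1.2) and adv = 2.0,
# clipped_surrogate(np.array([-0.7]), np.array([-1.0]), np.array([2.0])) is about -2.40.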
Example #30
def mpi_average(value):
    if value == []:
        value = [0.]
    if not isinstance(value, list):
        value = [value]
    return mpi_moments(np.array(value))[0]
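# Usage note (illustrative, not from the original example): on a single MPI rank
# mpi_moments reduces to a plain mean over the array, so mpi_average([1., 2., 3.])
# returns 2.0 and mpi_average([]) returns 0.0 (the empty list is replaced by [0.]).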
Example #31
def mpi_average(value):
    if not isinstance(value, list):
        value = [value]
    if not any(value):
        value = [0.]
    return mpi_moments(np.array(value))[0]
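# Note: this variant guards with `not any(value)` instead of `value == []`, so an
# all-falsy input (e.g. [0, 0] or a scalar 0) is also replaced by [0.] before the
# reduction; either way mpi_moments never receives an empty array.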
Example #32
def learn(
    env,
    policy_func,
    reward_giver,
    semi_dataset,
    rank,
    pretrained_weight,
    *,
    g_step,
    d_step,
    entcoeff,
    save_per_iter,
    ckpt_dir,
    log_dir,
    timesteps_per_batch,
    task_name,
    gamma,
    lam,
    max_kl,
    cg_iters,
    cg_damping=1e-2,
    vf_stepsize=3e-4,
    d_stepsize=3e-4,
    vf_iters=3,
    max_timesteps=0,
    max_episodes=0,
    max_iters=0,
    vf_batchsize=128,
    callback=None,
    freeze_g=False,
    freeze_d=False,
    pretrained_il=None,
    pretrained_semi=None,
    semi_loss=False,
    expert_reward_threshold=None,  # filter experts based on reward
    expert_label=get_semi_prefix(),
    sparse_reward=False  # filter experts based on success flag (sparse reward)
):

    semi_loss = semi_loss and semi_dataset is not None
    l2_w = 0.1

    nworkers = MPI.COMM_WORLD.Get_size()
    rank = MPI.COMM_WORLD.Get_rank()
    np.set_printoptions(precision=3)

    if rank == 0:
        writer = U.file_writer(log_dir)

        # print all the hyperparameters in the log...
        log_dict = {
            # "expert trajectories": expert_dataset.num_traj,
            "expert model": pretrained_semi,
            "algo": "trpo",
            "threads": nworkers,
            "timesteps_per_batch": timesteps_per_batch,
            "timesteps_per_thread": -(-timesteps_per_batch // nworkers),
            "entcoeff": entcoeff,
            "vf_iters": vf_iters,
            "vf_batchsize": vf_batchsize,
            "vf_stepsize": vf_stepsize,
            "d_stepsize": d_stepsize,
            "g_step": g_step,
            "d_step": d_step,
            "max_kl": max_kl,
            "gamma": gamma,
            "lam": lam,
        }

        if semi_dataset is not None:
            log_dict["semi trajectories"] = semi_dataset.num_traj
        if hasattr(semi_dataset, 'info'):
            log_dict["semi_dataset_info"] = semi_dataset.info
        if expert_reward_threshold is not None:
            log_dict["expert reward threshold"] = expert_reward_threshold
        log_dict["sparse reward"] = sparse_reward

        # print them all together for csv
        logger.log(",".join([str(elem) for elem in log_dict]))
        logger.log(",".join([str(elem) for elem in log_dict.values()]))

        # also print them separately for easy reading:
        for elem in log_dict:
            logger.log(str(elem) + ": " + str(log_dict[elem]))

    # divide the timesteps to the threads
    timesteps_per_batch = -(-timesteps_per_batch // nworkers
                            )  # get ceil of division
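    # e.g. timesteps_per_batch=2048 with nworkers=3 gives -(-2048 // 3) = 683 steps
    # per worker, so the workers jointly collect 3 * 683 = 2049 >= 2048 timesteps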

    # Setup losses and stuff
    # ----------------------------------------
    ob_space = OrderedDict([(label, env[label].observation_space)
                            for label in env])

    if semi_dataset and get_semi_prefix() in env:  # semi ob space is different
        semi_obs_space = semi_ob_space(env[get_semi_prefix()],
                                       semi_size=semi_dataset.semi_size)
        ob_space[get_semi_prefix()] = semi_obs_space
    else:
        print("no semi dataset")
        # raise RuntimeError

    vf_stepsize = {label: vf_stepsize for label in env}

    ac_space = {label: env[label].action_space for label in ob_space}
    pi = {
        label: policy_func("pi",
                           ob_space=ob_space[label],
                           ac_space=ac_space[label],
                           prefix=label)
        for label in ob_space
    }
    oldpi = {
        label: policy_func("oldpi",
                           ob_space=ob_space[label],
                           ac_space=ac_space[label],
                           prefix=label)
        for label in ob_space
    }
    atarg = {
        label: tf.placeholder(dtype=tf.float32, shape=[None])
        for label in ob_space
    }  # Target advantage function (if applicable)
    ret = {
        label: tf.placeholder(dtype=tf.float32, shape=[None])
        for label in ob_space
    }  # Empirical return

    ob = {
        label: U.get_placeholder_cached(name=label + "ob")
        for label in ob_space
    }
    ac = {
        label: pi[label].pdtype.sample_placeholder([None])
        for label in ob_space
    }

    kloldnew = {label: oldpi[label].pd.kl(pi[label].pd) for label in ob_space}
    ent = {label: pi[label].pd.entropy() for label in ob_space}
    meankl = {label: tf.reduce_mean(kloldnew[label]) for label in ob_space}
    meanent = {label: tf.reduce_mean(ent[label]) for label in ob_space}
    entbonus = {label: entcoeff * meanent[label] for label in ob_space}

    vferr = {
        label: tf.reduce_mean(tf.square(pi[label].vpred - ret[label]))
        for label in ob_space
    }

    ratio = {
        label:
        tf.exp(pi[label].pd.logp(ac[label]) - oldpi[label].pd.logp(ac[label]))
        for label in ob_space
    }  # advantage * pnew / pold
    surrgain = {
        label: tf.reduce_mean(ratio[label] * atarg[label])
        for label in ob_space
    }

    optimgain = {
        label: surrgain[label] + entbonus[label]
        for label in ob_space
    }
    losses = {
        label: [
            optimgain[label], meankl[label], entbonus[label], surrgain[label],
            meanent[label]
        ]
        for label in ob_space
    }
    loss_names = {
        label: [
            label + name for name in
            ["optimgain", "meankl", "entloss", "surrgain", "entropy"]
        ]
        for label in ob_space
    }

    vf_losses = {label: [vferr[label]] for label in ob_space}
    vf_loss_names = {label: [label + "vf_loss"] for label in ob_space}

    dist = {label: meankl[label] for label in ob_space}

    all_var_list = {
        label: pi[label].get_trainable_variables()
        for label in ob_space
    }
    var_list = {
        label: [
            v for v in all_var_list[label]
            if "pol" in v.name or "logstd" in v.name
        ]
        for label in ob_space
    }
    vf_var_list = {
        label: [v for v in all_var_list[label] if "vf" in v.name]
        for label in ob_space
    }
    for label in ob_space:
        assert len(var_list[label]) == len(vf_var_list[label]) + 1

    get_flat = {label: U.GetFlat(var_list[label]) for label in ob_space}
    set_from_flat = {
        label: U.SetFromFlat(var_list[label])
        for label in ob_space
    }
    klgrads = {
        label: tf.gradients(dist[label], var_list[label])
        for label in ob_space
    }
    flat_tangent = {
        label: tf.placeholder(dtype=tf.float32,
                              shape=[None],
                              name=label + "flat_tan")
        for label in ob_space
    }
    fvp = {}
    for label in ob_space:
        shapes = [var.get_shape().as_list() for var in var_list[label]]
        start = 0
        tangents = []
        for shape in shapes:
            sz = U.intprod(shape)
            tangents.append(
                tf.reshape(flat_tangent[label][start:start + sz], shape))
            start += sz
        gvp = tf.add_n([
            tf.reduce_sum(g * tangent)
            for (g, tangent) in zipsame(klgrads[label], tangents)
        ])  # pylint: disable=E1111
        fvp[label] = U.flatgrad(gvp, var_list[label])
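        # The loop above builds a Fisher-vector product without materialising the
        # Fisher matrix: gvp = sum_i (dKL/dtheta_i) * v_i is the directional
        # derivative of the mean KL along the tangent v, and differentiating gvp
        # once more (U.flatgrad) yields H v, the Hessian-vector product that the
        # conjugate-gradient solver below needs.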

    assign_old_eq_new = {
        label:
        U.function([], [],
                   updates=[
                       tf.assign(oldv, newv)
                       for (oldv,
                            newv) in zipsame(oldpi[label].get_variables(),
                                             pi[label].get_variables())
                   ])
        for label in ob_space
    }
    compute_losses = {
        label: U.function([ob[label], ac[label], atarg[label]], losses[label])
        for label in ob_space
    }

    compute_vf_losses = {
        label: U.function([ob[label], ac[label], atarg[label], ret[label]],
                          losses[label] + vf_losses[label])
        for label in ob_space
    }

    compute_lossandgrad = {
        label: U.function([ob[label], ac[label], atarg[label]], losses[label] +
                          [U.flatgrad(optimgain[label], var_list[label])])
        for label in ob_space
    }
    compute_fvp = {
        label:
        U.function([flat_tangent[label], ob[label], ac[label], atarg[label]],
                   fvp[label])
        for label in ob_space
    }

    compute_vflossandgrad = {
        label: U.function([ob[label], ret[label]], vf_losses[label] +
                          [U.flatgrad(vferr[label], vf_var_list[label])])
        for label in ob_space
    }

    @contextmanager
    def timed(msg):
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(
                colorize("done in %.3f seconds" % (time.time() - tstart),
                         color='magenta'))
        else:
            yield

    def allmean(x):
        assert isinstance(x, np.ndarray)
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
        return out

    episodes_so_far = {label: 0 for label in ob_space}
    timesteps_so_far = {label: 0 for label in ob_space}
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = {label: deque(maxlen=40)
                 for label in ob_space}  # rolling buffer for episode lengths
    rewbuffer = {label: deque(maxlen=40)
                 for label in ob_space}  # rolling buffer for episode rewards
    true_rewbuffer = {label: deque(maxlen=40) for label in ob_space}
    success_buffer = {label: deque(maxlen=40) for label in ob_space}
    # L2 only for semi network
    l2_rewbuffer = deque(
        maxlen=40) if semi_loss and semi_dataset is not None else None
    total_rewbuffer = deque(
        maxlen=40) if semi_loss and semi_dataset is not None else None

    not_update = 1 if not freeze_d else 0  # do not update G before D the first time
    loaded = False
    # load pretrained weights if provided
    if not U.load_checkpoint_variables(pretrained_weight,
                                       include_no_prefix_vars=True):
        # if no general checkpoint available, check sub-checkpoints for both networks
        if U.load_checkpoint_variables(pretrained_il,
                                       prefix=get_il_prefix(),
                                       include_no_prefix_vars=False):
            if rank == 0:
                logger.log("loaded checkpoint variables from " + pretrained_il)
            loaded = True
        elif expert_label == get_il_prefix():
            logger.log("ERROR no available cat_dauggi expert model in ",
                       pretrained_il)
            exit(1)

        if U.load_checkpoint_variables(pretrained_semi,
                                       prefix=get_semi_prefix(),
                                       include_no_prefix_vars=False):
            if rank == 0:
                logger.log("loaded checkpoint variables from " +
                           pretrained_semi)
            loaded = True
        elif expert_label == get_semi_prefix():
            if rank == 0:
                logger.log("ERROR no available semi expert model in ",
                           pretrained_semi)
            exit(1)
    else:
        loaded = True
        if rank == 0:
            logger.log("loaded checkpoint variables from " + pretrained_weight)

    if loaded:
        not_update = 0 if any(
            [x.op.name.find("adversary") != -1
             for x in U.ALREADY_INITIALIZED]) else 1
        if pretrained_weight and pretrained_weight.rfind("iter_") != -1 and \
                pretrained_weight[pretrained_weight.rfind("iter_") + len("iter_"):].isdigit():
            curr_iter = int(
                pretrained_weight[pretrained_weight.rfind("iter_") +
                                  len("iter_"):]) + 1

            if rank == 0:
                print("loaded checkpoint at iteration: " + str(curr_iter))
            iters_so_far = curr_iter
            for label in timesteps_so_far:
                timesteps_so_far[label] = iters_so_far * timesteps_per_batch

    d_adam = MpiAdam(reward_giver.get_trainable_variables())
    vfadam = {label: MpiAdam(vf_var_list[label]) for label in ob_space}

    U.initialize()
    d_adam.sync()

    for label in ob_space:
        th_init = get_flat[label]()
        MPI.COMM_WORLD.Bcast(th_init, root=0)
        set_from_flat[label](th_init)
        vfadam[label].sync()
        if rank == 0:
            print(label + "Init param sum", th_init.sum(), flush=True)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = {
        label: traj_segment_generator(
            pi[label],
            env[label],
            reward_giver,
            timesteps_per_batch,
            stochastic=True,
            semi_dataset=semi_dataset if label == get_semi_prefix() else None,
            semi_loss=semi_loss,
            reward_threshold=expert_reward_threshold
            if label == expert_label else None,
            sparse_reward=sparse_reward if label == expert_label else False)
        for label in ob_space
    }

    g_losses = {}

    assert sum([max_iters > 0, max_timesteps > 0, max_episodes > 0]) == 1

    g_loss_stats = {
        label: stats(loss_names[label] + vf_loss_names[label])
        for label in ob_space if label != expert_label
    }
    d_loss_stats = stats(reward_giver.loss_name)
    ep_names = ["True_rewards", "Rewards", "Episode_length", "Success"]

    ep_stats = {label: None for label in ob_space}
    # cat_dauggi network stats
    if get_il_prefix() in ep_stats:
        ep_stats[get_il_prefix()] = stats([name for name in ep_names])

    # semi network stats
    if get_semi_prefix() in ep_stats:
        if semi_loss and semi_dataset is not None:
            ep_names.append("L2_loss")
            ep_names.append("total_rewards")
        ep_stats[get_semi_prefix()] = stats(
            [get_semi_prefix() + name for name in ep_names])

    if rank == 0:
        start_time = time.time()
        ch_count = 0
        env_type = env[expert_label].env.env.__class__.__name__

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and any(
            [timesteps_so_far[label] >= max_timesteps for label in ob_space]):
            break
        elif max_episodes and any(
            [episodes_so_far[label] >= max_episodes for label in ob_space]):
            break
        elif max_iters and iters_so_far >= max_iters:
            break

        # Save model
        if rank == 0 and iters_so_far % save_per_iter == 0 and ckpt_dir is not None:
            fname = os.path.join(ckpt_dir, task_name)
            if env_type.find("Pendulum") != -1 or save_per_iter != 1:
                fname = os.path.join(ckpt_dir, 'iter_' + str(iters_so_far),
                                     'iter_' + str(iters_so_far))
            os.makedirs(os.path.dirname(fname), exist_ok=True)
            saver = tf.train.Saver()
            saver.save(tf.get_default_session(), fname, write_meta_graph=False)

        if rank == 0 and time.time(
        ) - start_time >= 3600 * ch_count:  # save a different checkpoint every hour
            fname = os.path.join(ckpt_dir, 'hour' + str(ch_count).zfill(3))
            fname = os.path.join(fname, 'iter_' + str(iters_so_far))
            os.makedirs(os.path.dirname(fname), exist_ok=True)
            saver = tf.train.Saver()
            saver.save(tf.get_default_session(), fname, write_meta_graph=False)
            ch_count += 1

        logger.log("********** Iteration %i ************" % iters_so_far)

        def fisher_func_builder(label):
            def fisher_vector_product(p):
                return allmean(compute_fvp[label](p, *
                                                  fvpargs)) + cg_damping * p

            return fisher_vector_product

        # ------------------ Update G ------------------
        d = {label: None for label in ob_space}
        segs = {label: None for label in ob_space}
        logger.log("Optimizing Policy...")
        for curr_step in range(g_step):
            for label in ob_space:

                if curr_step and label == expert_label:  # sample expert trajectories only on the first g_step of each iteration
                    continue

                logger.log("Optimizing Policy " + label + "...")
                with timed("sampling"):
                    segs[label] = seg = seg_gen[label].__next__()

                seg["rew"] = seg["rew"] - seg["l2_loss"] * l2_w

                add_vtarg_and_adv(seg, gamma, lam)
                ob, ac, atarg, tdlamret, full_ob = seg["ob"], seg["ac"], seg[
                    "adv"], seg["tdlamret"], seg["full_ob"]
                vpredbefore = seg[
                    "vpred"]  # predicted value function before udpate
                atarg = (atarg - atarg.mean()) / atarg.std(
                )  # standardized advantage function estimate
                d[label] = Dataset(dict(ob=ob,
                                        ac=ac,
                                        atarg=atarg,
                                        vtarg=tdlamret),
                                   shuffle=True)

                if not_update or label == expert_label:
                    continue  # stop G from updating

                if hasattr(pi[label], "ob_rms"):
                    pi[label].ob_rms.update(
                        full_ob)  # update running mean/std for policy

                args = seg["full_ob"], seg["ac"], atarg
                fvpargs = [arr[::5] for arr in args]

                assign_old_eq_new[label](
                )  # set old parameter values to new parameter values
                with timed("computegrad"):
                    *lossbefore, g = compute_lossandgrad[label](*args)
                lossbefore = allmean(np.array(lossbefore))
                g = allmean(g)
                if np.allclose(g, 0):
                    logger.log("Got zero gradient. not updating")
                else:
                    with timed("cg"):
                        stepdir = cg(fisher_func_builder(label),
                                     g,
                                     cg_iters=cg_iters,
                                     verbose=rank == 0)
                    assert np.isfinite(stepdir).all()
                    shs = .5 * stepdir.dot(fisher_func_builder(label)(stepdir))
                    lm = np.sqrt(shs / max_kl)
                    # logger.log("lagrange multiplier:", lm, "gnorm:", np.linalg.norm(g))
                    fullstep = stepdir / lm
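                    # lm rescales the CG direction so the local quadratic KL
                    # estimate hits the trust region exactly: shs = 0.5 * s^T H s,
                    # hence 0.5 * (s/lm)^T H (s/lm) = shs / lm^2 = max_kl.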
                    expectedimprove = g.dot(fullstep)
                    surrbefore = lossbefore[0]
                    stepsize = 1.0
                    thbefore = get_flat[label]()
                    for _ in range(10):
                        thnew = thbefore + fullstep * stepsize
                        set_from_flat[label](thnew)
                        meanlosses = surr, kl, *_ = allmean(
                            np.array(compute_losses[label](*args)))
                        if rank == 0:
                            print("Generator entropy " + str(meanlosses[4]) +
                                  ", loss " + str(meanlosses[2]))
                        improve = surr - surrbefore
                        logger.log("Expected: %.3f Actual: %.3f" %
                                   (expectedimprove, improve))
                        if not np.isfinite(meanlosses).all():
                            logger.log(
                                "Got non-finite value of losses -- bad!")
                        elif kl > max_kl * 1.5:
                            logger.log(
                                "violated KL constraint. shrinking step.")
                        elif improve < 0:
                            logger.log(
                                "surrogate didn't improve. shrinking step.")
                        else:
                            logger.log("Stepsize OK!")
                            break
                        stepsize *= .5
                    else:
                        logger.log("couldn't compute a good step")
                        set_from_flat[label](thbefore)
                    if nworkers > 1 and iters_so_far % 20 == 0:
                        paramsums = MPI.COMM_WORLD.allgather(
                            (thnew.sum(),
                             vfadam[label].getflat().sum()))  # list of tuples
                        assert all(
                            np.allclose(ps, paramsums[0])
                            for ps in paramsums[1:])

            expert_dataset = d[expert_label]

            if not_update:
                break

            for label in ob_space:
                if label == expert_label:
                    continue

                with timed("vf"):
                    logger.log(fmt_row(13, vf_loss_names[label]))
                    for _ in range(vf_iters):
                        vf_b_losses = []
                        for batch in d[label].iterate_once(vf_batchsize):
                            mbob = batch["ob"]
                            mbret = batch["vtarg"]
                            *newlosses, g = compute_vflossandgrad[label](mbob,
                                                                         mbret)
                            g = allmean(g)
                            newlosses = allmean(np.array(newlosses))

                            vfadam[label].update(g, vf_stepsize[label])
                            vf_b_losses.append(newlosses)
                        logger.log(fmt_row(13, np.mean(vf_b_losses, axis=0)))

                    logger.log("Evaluating losses...")
                    losses = []
                    for batch in d[label].iterate_once(vf_batchsize):
                        newlosses = compute_vf_losses[label](batch["ob"],
                                                             batch["ac"],
                                                             batch["atarg"],
                                                             batch["vtarg"])
                        losses.append(newlosses)
                    g_losses[label], _, _ = mpi_moments(losses, axis=0)

                #########################
                for ob_batch, ac_batch, full_ob_batch in dataset.iterbatches(
                    (segs[label]["ob"], segs[label]["ac"],
                     segs[label]["full_ob"]),
                        include_final_partial_batch=False,
                        batch_size=len(ob)):
                    expert_batch = expert_dataset.next_batch(len(ob))

                    ob_expert, ac_expert = expert_batch["ob"], expert_batch[
                        "ac"]

                    exp_rew = 0
                    exp_rews = None
                    for obs, acs in zip(ob_expert, ac_expert):
                        curr_rew = reward_giver.get_reward(obs, acs)[0][0] \
                                   if not hasattr(reward_giver, '_labels') else \
                                   reward_giver.get_reward(obs, acs, label)
                        if isinstance(curr_rew, tuple):
                            curr_rew, curr_rews = curr_rew
                            exp_rews = 1 - np.exp(
                                -curr_rews
                            ) if exp_rews is None else exp_rews + 1 - np.exp(
                                -curr_rews)
                        exp_rew += 1 - np.exp(-curr_rew)
                    mean_exp_rew = exp_rew / len(ob_expert)
                    mean_exp_rews = exp_rews / len(
                        ob_expert) if exp_rews is not None else None

                    gen_rew = 0
                    gen_rews = None
                    for obs, acs, full_obs in zip(ob_batch, ac_batch,
                                                  full_ob_batch):
                        curr_rew = reward_giver.get_reward(obs, acs)[0][0] \
                                   if not hasattr(reward_giver, '_labels') else \
                                   reward_giver.get_reward(obs, acs, label)
                        if isinstance(curr_rew, tuple):
                            curr_rew, curr_rews = curr_rew
                            gen_rews = 1 - np.exp(
                                -curr_rews
                            ) if gen_rews is None else gen_rews + 1 - np.exp(
                                -curr_rews)
                        gen_rew += 1 - np.exp(-curr_rew)
                    mean_gen_rew = gen_rew / len(ob_batch)
                    mean_gen_rews = gen_rews / len(
                        ob_batch) if gen_rews is not None else None
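                    # Hedged note: assuming reward_giver returns the usual
                    # GAIL-style reward -log(1 - D(s, a)), the 1 - exp(-r)
                    # transform above recovers the discriminator probability
                    # D(s, a) in [0, 1), making the expert and generator
                    # batches directly comparable in the log line below.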
                    if rank == 0:
                        msg = "Network " + label + \
                            " Generator step " + str(curr_step) + ": Dicriminator reward of expert traj " \
                            + str(mean_exp_rew) + " vs gen traj " + str(mean_gen_rew)
                        if mean_exp_rews is not None and mean_gen_rews is not None:
                            msg += "\nDiscriminator multi rewards of expert " + str(mean_exp_rews) + " vs gen " \
                                    + str(mean_gen_rews)
                        logger.log(msg)
                #########################

        if not not_update:
            for label in g_losses:
                for (lossname,
                     lossval) in zip(loss_names[label] + vf_loss_names[label],
                                     g_losses[label]):
                    logger.record_tabular(lossname, lossval)
                logger.record_tabular(
                    label + "ev_tdlam_before",
                    explained_variance(segs[label]["vpred"],
                                       segs[label]["tdlamret"]))

        # ------------------ Update D ------------------
        if not freeze_d:
            logger.log("Optimizing Discriminator...")
            batch_size = len(list(segs.values())[0]['ob']) // d_step
            expert_dataset = d[expert_label]
            batch_gen = {
                label: dataset.iterbatches(
                    (segs[label]["ob"], segs[label]["ac"]),
                    include_final_partial_batch=False,
                    batch_size=batch_size)
                for label in segs if label != expert_label
            }

            d_losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for step in range(d_step):
                g_ob = {}
                g_ac = {}
                for label in batch_gen:  # get batches for different gens
                    g_ob[label], g_ac[label] = batch_gen[label].__next__()

                expert_batch = expert_dataset.next_batch(batch_size)

                ob_expert, ac_expert = expert_batch["ob"], expert_batch["ac"]

                for label in g_ob:
                    #########################
                    exp_rew = 0
                    exp_rews = None
                    for obs, acs in zip(ob_expert, ac_expert):
                        curr_rew = reward_giver.get_reward(obs, acs)[0][0] \
                            if not hasattr(reward_giver, '_labels') else \
                            reward_giver.get_reward(obs, acs, label)
                        if isinstance(curr_rew, tuple):
                            curr_rew, curr_rews = curr_rew
                            exp_rews = 1 - np.exp(
                                -curr_rews
                            ) if exp_rews is None else exp_rews + 1 - np.exp(
                                -curr_rews)
                        exp_rew += 1 - np.exp(-curr_rew)
                    mean_exp_rew = exp_rew / len(ob_expert)
                    mean_exp_rews = exp_rews / len(
                        ob_expert) if exp_rews is not None else None

                    gen_rew = 0
                    gen_rews = None
                    for obs, acs in zip(g_ob[label], g_ac[label]):
                        curr_rew = reward_giver.get_reward(obs, acs)[0][0] \
                            if not hasattr(reward_giver, '_labels') else \
                            reward_giver.get_reward(obs, acs, label)
                        if isinstance(curr_rew, tuple):
                            curr_rew, curr_rews = curr_rew
                            gen_rews = 1 - np.exp(
                                -curr_rews
                            ) if gen_rews is None else gen_rews + 1 - np.exp(
                                -curr_rews)
                        gen_rew += 1 - np.exp(-curr_rew)
                    mean_gen_rew = gen_rew / len(g_ob[label])
                    mean_gen_rews = gen_rews / len(
                        g_ob[label]) if gen_rews is not None else None
                    if rank == 0:
                        msg = "Dicriminator reward of expert traj " + str(mean_exp_rew) + " vs " + label + \
                            "gen traj " + str(mean_gen_rew)
                        if mean_exp_rews is not None and mean_gen_rews is not None:
                            msg += "\nDiscriminator multi expert rewards " + str(mean_exp_rews) + " vs " + label + \
                                   "gen " + str(mean_gen_rews)
                        logger.log(msg)
                    #########################

                # update running mean/std for reward_giver
                if hasattr(reward_giver, "obs_rms"):
                    reward_giver.obs_rms.update(
                        np.concatenate(list(g_ob.values()) + [ob_expert], 0))
                *newlosses, g = reward_giver.lossandgrad(
                    *(list(g_ob.values()) + list(g_ac.values()) + [ob_expert] +
                      [ac_expert]))
                d_adam.update(allmean(g), d_stepsize)
                d_losses.append(newlosses)
                logger.log(fmt_row(13, reward_giver.loss_name))
                logger.log(fmt_row(13, np.mean(d_losses, axis=0)))

        for label in ob_space:
            lrlocal = (segs[label]["ep_lens"], segs[label]["ep_rets"],
                       segs[label]["ep_true_rets"], segs[label]["ep_success"],
                       segs[label]["ep_semi_loss"])  # local values

            listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
            lens, rews, true_rets, success, semi_losses = map(
                flatten_lists, zip(*listoflrpairs))

            # success
            success = [
                float(elem) for elem in success
                if isinstance(elem, (int, float, bool))
            ]  # remove potential None types
            if not success:
                success = [-1]  # set success to -1 if env has no success flag
            success_buffer[label].extend(success)

            true_rewbuffer[label].extend(true_rets)
            lenbuffer[label].extend(lens)
            rewbuffer[label].extend(rews)

            if semi_loss and semi_dataset is not None and label == get_semi_prefix(
            ):
                semi_losses = [elem * l2_w for elem in semi_losses]
                total_rewards = rews
                total_rewards = [
                    re_elem - l2_elem
                    for re_elem, l2_elem in zip(total_rewards, semi_losses)
                ]
                l2_rewbuffer.extend(semi_losses)
                total_rewbuffer.extend(total_rewards)

            logger.record_tabular(label + "EpLenMean",
                                  np.mean(lenbuffer[label]))
            logger.record_tabular(label + "EpRewMean",
                                  np.mean(rewbuffer[label]))
            logger.record_tabular(label + "EpTrueRewMean",
                                  np.mean(true_rewbuffer[label]))
            logger.record_tabular(label + "EpSuccess",
                                  np.mean(success_buffer[label]))

            if semi_loss and semi_dataset is not None and label == get_semi_prefix(
            ):
                logger.record_tabular(label + "EpSemiLoss",
                                      np.mean(l2_rewbuffer))
                logger.record_tabular(label + "EpTotalLoss",
                                      np.mean(total_rewbuffer))
            logger.record_tabular(label + "EpThisIter", len(lens))
            episodes_so_far[label] += len(lens)
            timesteps_so_far[label] += sum(lens)

            logger.record_tabular(label + "EpisodesSoFar",
                                  episodes_so_far[label])
            logger.record_tabular(label + "TimestepsSoFar",
                                  timesteps_so_far[label])
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        iters_so_far += 1
        logger.record_tabular("ItersSoFar", iters_so_far)

        if rank == 0:
            logger.dump_tabular()
            if not not_update:
                for label in g_loss_stats:
                    g_loss_stats[label].add_all_summary(
                        writer, g_losses[label], iters_so_far)
            if not freeze_d:
                d_loss_stats.add_all_summary(writer, np.mean(d_losses, axis=0),
                                             iters_so_far)

            for label in ob_space:
                # default buffers
                ep_buffers = [
                    np.mean(true_rewbuffer[label]),
                    np.mean(rewbuffer[label]),
                    np.mean(lenbuffer[label]),
                    np.mean(success_buffer[label])
                ]

                if semi_loss and semi_dataset is not None and label == get_semi_prefix(
                ):
                    ep_buffers.append(np.mean(l2_rewbuffer))
                    ep_buffers.append(np.mean(total_rewbuffer))

                ep_stats[label].add_all_summary(writer, ep_buffers,
                                                iters_so_far)

        if not_update and not freeze_g:
            not_update -= 1