Example #1
import time
import cPickle

import numpy as np

import environment
import pg_network
import slow_down_cdf
import Testing

# NOTE: Python 2 code (xrange, cPickle, print statements). The project-local
# imports (environment, pg_network, slow_down_cdf, Testing) assume sibling
# modules with these names; helpers used below (get_traj, discount,
# concatenate_all_ob, concatenate_all_ob_across_examples, process_all_info)
# are assumed to be defined elsewhere in the same module.


def launch(pa, pg_resume=None, render=False, repre='image', end='no_new_job'):
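    """
    Train the policy-gradient network over pa.num_epochs iterations.

    pa:        experiment parameters (num_epochs, num_ex, num_seq_per_batch,
               episode_max_length, discount, output_freq, output_filename, ...)
    pg_resume: optional path to a pickled network-parameter file to resume from
    render:    unused here (the Env call that consumed it is commented out below)
    repre:     state representation passed to environment.Env
    end:       episode termination mode, forwarded to slow_down_cdf.launch
    """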

    # env = environment.Env(pa, render=render, repre=repre, end=end)
    # print(repre)
    env = environment.Env(pa, repre=repre)
    test_env = environment.Env(pa, repre=repre)

    pg_learner = pg_network.PGLearner(pa)
    testing = Testing.TestingTrainParameter(pa, test_env, pg_learner)
    if pg_resume is not None:
        net_handle = open(pg_resume, 'rb')
        net_params = cPickle.load(net_handle)
        pg_learner.set_net_params(net_params)

    # ----------------------------
    print("Preparing for data...")
    # ----------------------------

    ref_discount_rews, ref_slow_down = slow_down_cdf.launch(pa,
                                                            pg_resume=None,
                                                            render=False,
                                                            plot=False,
                                                            repre=repre,
                                                            end=end)
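    # Reference discounted rewards and slowdowns from slow_down_cdf.launch (e.g.
    # for plotting comparison curves); they are not used further in this snippet.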

    mean_rew_lr_curve = []
    max_rew_lr_curve = []
    slow_down_lr_curve = []

    timer_start = time.time()

    for iteration in xrange(pa.num_epochs):

        all_ob = []
        all_action = []
        all_adv = []
        all_eprews = []
        all_eplens = []
        all_slowdown = []
        all_entropy = []

        # go through all examples
        for ex in xrange(pa.num_ex):

            # Collect pa.num_seq_per_batch trajectories for this job-sequence example
            trajs = []
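            # get_traj is assumed to roll out one episode (at most
            # pa.episode_max_length steps) under the current policy and return a
            # dict with per-step "action", "reward" and "entropy" entries, plus
            # the observations consumed by concatenate_all_ob.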

            for i in xrange(pa.num_seq_per_batch):
                traj = get_traj(pg_learner, env, pa.episode_max_length)
                trajs.append(traj)

            # roll to next example
            # env.seq_no = (env.seq_no + 1) % env.pa.num_ex
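            # NOTE: the step that would advance env.seq_no to the next example is
            # commented out above; the environment is assumed to manage its own
            # sequence index.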

            all_ob.append(concatenate_all_ob(trajs, pa))

            # Compute discounted sums of rewards
            rets = [discount(traj["reward"], pa.discount) for traj in trajs]
            maxlen = max(len(ret) for ret in rets)
            padded_rets = [
                np.concatenate([ret, np.zeros(maxlen - len(ret))])
                for ret in rets
            ]

            # Compute time-dependent baseline
            baseline = np.mean(padded_rets, axis=0)
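            # baseline[t] is the mean discounted return-to-go at timestep t across
            # the batch of trajectories; subtracting it from each return below
            # reduces the variance of the policy-gradient estimate.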

            # Compute advantage function
            advs = [ret - baseline[:len(ret)] for ret in rets]
            all_action.append(
                np.concatenate([traj["action"] for traj in trajs]))
            all_adv.append(np.concatenate(advs))

            all_eprews.append(
                np.array([ret[0] for ret in rets]))  # total discounted reward per episode
            all_eplens.append(np.array([len(traj["reward"])
                                        for traj in trajs]))  # episode lengths

            # Job statistics across all trajectories
            enter_time, finish_time, job_len = process_all_info(trajs)
            finished_idx = (finish_time >= 0)
            all_slowdown.append(
                (finish_time[finished_idx] - enter_time[finished_idx]) /
                job_len[finished_idx])
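            # Slowdown is (finish_time - enter_time) / job_len, computed only over
            # jobs that actually finished (finish_time >= 0).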

            # Action prob entropy
            all_entropy.append(
                np.concatenate([traj["entropy"] for traj in trajs]))

        all_ob = concatenate_all_ob_across_examples(all_ob, pa)
        all_action = np.concatenate(all_action)
        all_adv = np.concatenate(all_adv)

        # Do policy gradient update step
        loss = pg_learner.train(all_ob, all_action, all_adv)
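        # The learner is assumed to take a single gradient step from the
        # concatenated observations, actions and advantages and to return the
        # training loss, which is logged below.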
        eprews = np.concatenate(all_eprews)  # total discounted reward per episode
        eplens = np.concatenate(all_eplens)  # episode lengths

        all_slowdown = np.concatenate(all_slowdown)

        all_entropy = np.concatenate(all_entropy)

        timer_end = time.time()

        print "-----------------"
        print "Iteration: \t %i" % iteration
        print "NumTrajs: \t %i" % len(eprews)
        print "NumTimesteps: \t %i" % np.sum(eplens)
        print "Loss:     \t %s" % loss
        print "MaxRew: \t %s" % np.average([np.max(rew) for rew in all_eprews])
        print "MeanRew: \t %s +- %s" % (eprews.mean(), eprews.std())
        # print "MeanSlowdown: \t %s" % np.mean(all_slowdown)
        print "MeanLen: \t %s +- %s" % (eplens.mean(), eplens.std())
        # print "MeanEntropy \t %s" % (np.mean(all_entropy))
        print "Elapsed time\t %s" % (timer_end - timer_start), "seconds"
        if iteration % 10 == 0:
            testing.start()
        print "-----------------"

        timer_start = time.time()

        max_rew_lr_curve.append(np.average([np.max(rew)
                                            for rew in all_eprews]))
        mean_rew_lr_curve.append(eprews.mean())
        slow_down_lr_curve.append(np.mean(all_slowdown))

        if iteration % pa.output_freq == 0:
            param_file = open(
                pa.output_filename + '_' + str(iteration) + '.pkl', 'wb')
            cPickle.dump(pg_learner.get_params(), param_file, -1)
            param_file.close()
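
For reference, the discount helper is not shown in this example. A minimal sketch, assuming it computes discounted cumulative sums so that the first entry is the episode's total discounted return (which is how all_eprews uses it above):

import numpy as np

def discount(x, gamma):
    # out[i] = x[i] + gamma * out[i + 1], so out[0] is the episode's total
    # discounted return.
    out = np.zeros(len(x))
    running = 0.0
    for i in reversed(xrange(len(x))):
        running = x[i] + gamma * running
        out[i] = running
    return out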