Example #1
def run(args):
    job_name = args.job_name
    task_index = args.task_index
    # sys.stderr.write('Starting job %s task %d\n' % (job_name, task_index))

    ps_hosts = args.ps_hosts.split(',')
    worker_hosts = args.worker_hosts.split(',')

    cluster = tf.train.ClusterSpec({'ps': ps_hosts, 'worker': worker_hosts})
    server = tf.train.Server(cluster, job_name=job_name, task_index=task_index)

    if job_name == 'ps':
        server.join()
    elif job_name == 'worker':
        env = create_env(task_index)

        ## TODO: the GPU options could be passed in here later
        learner = A3C(
            cluster=cluster,
            server=server,  # needed when creating the session
            task_index=task_index,
            env=env,
            dagger=args.dagger)

        try:
            learner.run()
        except KeyboardInterrupt:
            pass
        finally:
            learner.cleanup()
            if args.driver is not None:
                shutdown_from_driver(args.driver)
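Example #1's run() expects an argparse namespace carrying the cluster layout and this process's role. A minimal launcher sketch under that assumption (the flag names simply mirror the attributes the snippet reads; create_env, A3C and shutdown_from_driver are whatever the project defines):

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--job-name', dest='job_name', choices=['ps', 'worker'], required=True)
    parser.add_argument('--task-index', dest='task_index', type=int, default=0)
    parser.add_argument('--ps-hosts', dest='ps_hosts', default='localhost:2222')
    parser.add_argument('--worker-hosts', dest='worker_hosts', default='localhost:2223,localhost:2224')
    parser.add_argument('--dagger', action='store_true')
    parser.add_argument('--driver', default=None)
    run(parser.parse_args())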
Example #2
def run(args, server):
    env = atari_environment.AtariEnvironment(args.game)
    trainer = A3C(env, args.task)

    # Variable names that start with "local" are not saved in checkpoints.
    variables_to_save = [
        v for v in tf.global_variables() if not v.name.startswith("local")
    ]
    init_op = tf.variables_initializer(variables_to_save)
    init_all_op = tf.global_variables_initializer()
    saver = FastSaver(variables_to_save)

    def init_fn(ses):
        logger.info("Initializing all parameters.")
        ses.run(init_all_op)

    config = tf.ConfigProto(device_filters=[
        "/job:ps", "/job:worker/task:{}/cpu:0".format(args.task)
    ])
    logdir = os.path.join(args.log_dir, 'train')
    summary_writer = tf.summary.FileWriter(logdir + "_%d" % args.task)
    logger.info("Events directory: %s_%s", logdir, args.task)

    sv = tf.train.Supervisor(
        is_chief=(args.task == 0),
        logdir=logdir,
        saver=saver,
        summary_op=None,
        init_op=init_op,
        init_fn=init_fn,
        summary_writer=summary_writer,
        ready_op=tf.report_uninitialized_variables(variables_to_save),
        global_step=trainer.global_step,
        save_model_secs=30,
        save_summaries_secs=30)

    num_global_steps = 100000000

    logger.info(
        "Starting session. If this hangs, we're mostly likely waiting to connect to the parameter server. "
        +
        "One common cause is that the parameter server DNS name isn't resolving yet, or is misspecified."
    )
    with sv.managed_session(server.target,
                            config=config) as sess, sess.as_default():
        trainer.start(sess, summary_writer)
        global_step = sess.run(trainer.global_step)
        logger.info("Starting training at step=%d", global_step)
        while not sv.should_stop() and (not num_global_steps
                                        or global_step < num_global_steps):
            trainer.process(sess)
            global_step = sess.run(trainer.global_step)

    # Ask for all the services to stop.
    sv.stop()
    logger.info('reached %s steps. worker stopped.', global_step)
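FastSaver is used throughout these examples but never defined in them. In the universe-starter-agent lineage most of these workers derive from, it is usually just a thin tf.train.Saver subclass that skips writing the meta graph so the periodic checkpoints stay cheap; a sketch under that assumption:

import tensorflow as tf

class FastSaver(tf.train.Saver):
    # Identical to tf.train.Saver except that it never writes the (large)
    # MetaGraphDef; only the variable values are checkpointed.
    def save(self, sess, save_path, global_step=None, latest_filename=None,
             meta_graph_suffix="meta", write_meta_graph=True):
        super(FastSaver, self).save(sess, save_path, global_step,
                                    latest_filename, meta_graph_suffix,
                                    write_meta_graph=False)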
Example #3
def run(args):
    env = create_env(args.env_id)
    trainer = A3C(env, None, args.visualise, args.intrinsic_type, args.bptt)

    # Variable names that start with "local" are not saved in checkpoints.
    variables_to_save = [v for v in tf.global_variables() if not v.name.startswith("local")]
    init_op = tf.variables_initializer(variables_to_save)
    init_all_op = tf.global_variables_initializer()
    saver = FastSaver(variables_to_save)

    var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name)
    logger.info('Trainable vars:')
    for v in var_list:
        logger.info('  %s %s', v.name, v.get_shape())

    def init_fn(ses):
        logger.info("Initializing all parameters.")
        ses.run(init_all_op)

    logdir = os.path.join(args.log_dir, 'train')
    summary_writer = tf.summary.FileWriter(logdir)
    logger.info("Events directory: %s", logdir)

    sv = tf.train.Supervisor(is_chief=True,
                             logdir=logdir,
                             saver=saver,
                             summary_op=None,
                             init_op=init_op,
                             init_fn=init_fn,
                             summary_writer=summary_writer,
                             ready_op=tf.report_uninitialized_variables(variables_to_save),
                             global_step=None,
                             save_model_secs=0,
                             save_summaries_secs=0)

    video_dir = os.path.join(args.log_dir, 'test_videos_' + args.intrinsic_type)
    if not os.path.exists(video_dir):
        os.makedirs(video_dir)
    video_filename = video_dir + "/%s_%02d_%d.gif"
    print("Video saved at %s" % video_dir)

    with sv.managed_session() as sess, sess.as_default():
        trainer.start(sess, summary_writer)
        rewards = []
        lengths = []
        for i in range(10):
            frames, reward, length = trainer.evaluate(sess)
            rewards.append(reward)
            lengths.append(length)
            imageio.mimsave(video_filename % (args.env_id, i, reward), frames, fps=30)

        print('Evaluation: avg. reward %.2f    avg.length %.2f' %
              (sum(rewards) / 10.0, sum(lengths) / 10.0))

    # Ask for all the services to stop.
    sv.stop()
Example #4
def run(args):
    env = create_env(args.env_id,
                     client_id=str(args.task),
                     remotes=args.remotes)
    trainer = A3C(env, args.task, args.visualise, args.num_workers,
                  args.worker_id, args.verbose_lvl)

    # Variable names that start with "local" are not saved in checkpoints.
    if use_tf12_api:
        variables_to_save = [
            v for v in tf.global_variables() if not v.name.startswith("local")
        ]
        init_op = tf.variables_initializer(variables_to_save)
        init_all_op = tf.global_variables_initializer()
    else:
        variables_to_save = [
            v for v in tf.all_variables() if not v.name.startswith("local")
        ]
        init_op = tf.initialize_variables(variables_to_save)
        init_all_op = tf.initialize_all_variables()
    print(variables_to_save)
    saver = FastSaver(variables_to_save)

    var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                 tf.get_variable_scope().name)
    logger.info('Trainable vars:')
    for v in var_list:
        logger.info('  %s %s', v.name, v.get_shape())

    def init_fn(ses):
        logger.info("Initializing all parameters.")
        ses.run(init_all_op)

    logdir = os.path.join(args.log_dir, 'train')

    if use_tf12_api:
        summary_writer = tf.summary.FileWriter(logdir + "_%d" % args.task)
    else:
        summary_writer = tf.train.SummaryWriter(logdir + "_%d" % args.task)

    logger.info("Events directory: %s_%s", logdir, args.task)

    logger.info(
        "Starting session. If this hangs, we're mostly likely waiting to connect to the parameter server. "
        +
        "One common cause is that the parameter server DNS name isn't resolving yet, or is misspecified."
    )
    with tf.Session() as sess:
        init = tf.global_variables_initializer()
        sess.run(init)

        trainer.start_listen_thread()
        trainer.sync_initial_weights(sess, var_list)
        trainer.start(sess, summary_writer)
        while True:
            trainer.process(sess)
Example #5
    def start(self):
        '''Main execution.

        Instantiate workers, accept client connections, distribute computation
        requests among workers, and route computed results back to clients.
        '''

        self.dnn = A3C()
        self.optimizers = [
            Optimizer(self.dnn) for i in range(cfg.optimizer_num)
        ]

        # Front facing socket to accept client connections.
        socket_front = self.zmq_context.socket(zmq.ROUTER)
        #socket_front.bind('tcp://127.0.0.1:5001')
        socket_front.bind('ipc:///tmp/deeplearning.zmq')

        # Backend socket to distribute work.
        socket_back = self.zmq_context.socket(zmq.DEALER)
        #socket_back = self.zmq_context.socket(zmq.REP)
        socket_back.bind('inproc://backend')

        workers = []

        # Start the worker threads.
        for i in range(0, cfg.agent_max_num - 1):
            worker = Worker(self.zmq_context, i, self.dnn)
            worker.start()
            workers.append(worker)

        self.run_optimizer()

        # Use built in queue device to distribute requests among workers.
        # What queue device does internally is,
        #   1. Read a client's socket ID and request.
        #   2. Send socket ID and request to a worker.
        #   3. Read a client's socket ID and result from a worker.
        #   4. Route result back to the client using socket ID.
        try:
            zmq.device(zmq.QUEUE, socket_front, socket_back)
        except KeyboardInterrupt:
            print('interrupted!')

        # process termination
        self.stop_optimizer()

        for worker in workers:
            worker.stop()

        socket_front.close()
        socket_back.close()
        self.zmq_context.term()

        for worker in workers:
            worker.join()

        print "\nServer(Message Queue type) stopped!!!!\n"
Example #6
def train():
    def trainEndedEvaluator(episodeList):
        if len(episodeList) < 10:
            return False
        actualList = episodeList[-10:]
        sumRew = 0
        solved = 0
        minP = 100
        maxP = -100
        velocity = 0.0
        changes = 0
        for episode in actualList:
            lastAction = None
            for item in episode:
                if lastAction != item.action:
                    lastAction = item.action
                    changes += 1
                sumRew += item.reward
                position = item.next_state[0]
                velocity += item.next_state[1]
                minP = min(minP, position)
                maxP = max(maxP, position)
                if position >= 0.5:
                    solved += 1
        avg = sumRew / len(actualList)
        avgChanges = changes / len(actualList)
        avgSpeed = velocity / (len(actualList) * 200)
        print("Avg Changes", avgChanges, " Average reward  after ",
              commons.totalEpisodes, " episodes is ", avg, "Solved = ", solved,
              " best range (", minP, ", ", maxP, ") , avg speed ", avgSpeed)
        if solved / 10.0 > 0.9:
            return True

    def remapReward(episodeData):
        x0 = -0.5
        totalReward = 0
        minPos = +5
        maxPos = -5
        for eps in episodeData:
            position = eps.next_state[0]
            minPos = min(minPos, position)
            maxPos = max(maxPos, position)
        delta_inf = x0 - minPos
        delta_sup = maxPos - x0
        rew = abs(delta_sup) - abs(delta_inf)
        for eps in episodeData:
            eps.reward = rew / len(episodeData)
        totalReward = rew
        return totalReward

    a3c = A3C('MountainCar-v0', numThreads=10)
    a3c.train(stopFunction=trainEndedEvaluator,
              remapRewardFunction=remapReward,
              epsilon=1.0)
Example #7
def run(args, server):
    env = create_env(args.env_id, client_id=str(args.task), remotes=args.remotes, num_trials=args.num_trials)

    trainer = A3C(env, args.task, args.visualise, args.meta, args.remotes, args.num_trials)

    # logs, checkpoints and tensorboard

    # (Original Comment) Variable names that start with "local" are not saved in checkpoints.
    if use_tf12_api:
        variables_to_save = [v for v in tf.global_variables() if not v.name.startswith("local")]
        init_op = tf.variables_initializer(variables_to_save)
        init_all_op = tf.global_variables_initializer()
    else:
        variables_to_save = [v for v in tf.all_variables() if not v.name.startswith("local")]
        init_op = tf.initialize_variables(variables_to_save)
        init_all_op = tf.initialize_all_variables()
    saver = FastSaver(variables_to_save)

    var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name)
    logger.info('Trainable vars:')
    for v in var_list:
        logger.info('  %s %s', v.name, v.get_shape())

    def init_fn(ses):
        logger.info("Initializing all parameters.")
        ses.run(init_all_op)

    config = tf.ConfigProto(device_filters=["/job:ps", "/job:worker/task:{}/cpu:0".format(args.task)])
    logdir = os.path.join(args.log_dir, 'train')

    if use_tf12_api:
        summary_writer = tf.summary.FileWriter(logdir + "_%d" % args.task)
    else:
        summary_writer = tf.train.SummaryWriter(logdir + "_%d" % args.task)

    logger.info("Events directory: %s_%s", logdir, args.task)
    # The tf.train.Supervisor provides a set of services that helps implement a robust training process. *(4)
    sv = tf.train.Supervisor(is_chief=(args.task == 0),
                             logdir=logdir,
                             saver=saver,
                             summary_op=None,
                             init_op=init_op,
                             init_fn=init_fn,
                             summary_writer=summary_writer,
                             ready_op=tf.report_uninitialized_variables(variables_to_save),
                             global_step=trainer.global_step,
                             save_model_secs=30,
                             save_summaries_secs=30)

    if args.test: # testing phase
        run_test(trainer, sv, config, summary_writer, server)
    else: # training phase
        run_train(trainer, sv, config, summary_writer, server)
Example #8
 def creat_agent(process_idx):
     env = build_env(args.type,
                     args,
                     max_episode_length=args.max_episode_length)
     return A3C(model,
                opt,
                env,
                args.t_max,
                0.99,
                beta=args.beta,
                process_idx=process_idx,
                phi=dqn_phi)
Example #9
def play(args, server):
    env = create_env(args.env_id,
                     client_id=str(args.task),
                     remotes=args.remotes)
    trainer = A3C(env, args.task, args.visualise)
    result = []
    """
    implement your code here 
    Condition:
        The purpose of this function is for testing
        The number of episodes is 20
        you have to return the mean value of rewards of 20 episodes
    """

    return np.mean(result)
Example #10
def eval_tic_tac_toe(value_weight,
                     num_epoch_rounds=1,
                     games=10**4,
                     rollouts=10**5,
                     advantage_lambda=0.98):
    """
  Returns the average reward over 10k games after 100k rollouts
  
  Parameters
  ----------
  value_weight: float

  Returns
  ------- 
  avg_rewards
  """
    env = TicTacToeEnvironment()
    model_dir = "/tmp/tictactoe"
    try:
        shutil.rmtree(model_dir)
    except:
        pass

    avg_rewards = []
    for j in range(num_epoch_rounds):
        print("Epoch round: %d" % j)
        a3c_engine = A3C(env,
                         entropy_weight=0.01,
                         value_weight=value_weight,
                         model_dir=model_dir,
                         advantage_lambda=advantage_lambda)
        try:
            a3c_engine.restore()
        except:
            print("unable to restore")
            pass
        a3c_engine.fit(rollouts)
        rewards = []
        for i in range(games):
            env.reset()
            reward = -float('inf')
            while not env.terminated:
                action = a3c_engine.select_action(env.state)
                reward = env.step(action)
            rewards.append(reward)
        print("Mean reward at round %d is %f" % (j + 1, np.mean(rewards)))
        avg_rewards.append({(j + 1) * rollouts: np.mean(rewards)})
    return avg_rewards
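A possible call site for the evaluator above; the value_weight passed here is just an illustrative choice, not a recommended setting:

if __name__ == '__main__':
    results = eval_tic_tac_toe(value_weight=1.0, num_epoch_rounds=2)
    print(results)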
Example #11
def train():
    def trainEndedEvaluator(episodeList):
        if len(episodeList) < 100:
            return False
        actualList = episodeList[-100:]
        solved = 0
        totalRew = 0
        for episode in actualList:
            solved = solved + 1 if len(episode[0]) >= 450 else solved
            totalRew += len(episode[0])
        avg = totalRew / len(actualList)
        print("Average reward  after ", commons.totalEpisodes, " episodes is ",
              avg, "Solved = ", solved)
        return solved / len(actualList) > 0.9

    a3c = A3C('CartPole-v1', numThreads=10)
    a3c.train(stopFunction=trainEndedEvaluator, epsilon=1.0)
def test(args, server):

    log_dir = os.path.join(args.log_dir, '{}/train'.format(args.env))
    game, parameter = new_environment(name=args.env, test=True)
    a3c = A3C(game,
              log_dir,
              parameter.get(),
              agent_index=args.task,
              callback=game.draw)

    config = tf.ConfigProto(device_filters=[
        "/job:ps", "/job:worker/task:{}/cpu:0".format(args.task)
    ])
    with tf.Session(target=server.target, config=config) as sess:
        saver = tf.train.Saver()
        a3c.load(sess, saver, model_name='best_a3c_model.ckpt')
        a3c.evaluate(sess, n_episode=10, saver=None, verbose=True)
def train(args, server):

    os.environ['OMP_NUM_THREADS'] = '1'
    set_random_seed(args.task * 17)
    log_dir = os.path.join(args.log_dir, '{}/train'.format(args.env))
    if not tf.gfile.Exists(log_dir):
        tf.gfile.MakeDirs(log_dir)

    game, parameter = new_environment(args.env)
    a3c = A3C(game,
              log_dir,
              parameter.get(),
              agent_index=args.task,
              callback=None)

    global_vars = [
        v for v in tf.global_variables() if not v.name.startswith("local")
    ]
    ready_op = tf.report_uninitialized_variables(global_vars)
    config = tf.ConfigProto(device_filters=[
        "/job:ps", "/job:worker/task:{}/cpu:0".format(args.task)
    ])

    with tf.Session(target=server.target, config=config) as sess:
        saver = tf.train.Saver()
        path = os.path.join(log_dir, 'log_%d' % args.task)
        writer = tf.summary.FileWriter(delete_dir(path), sess.graph_def)
        a3c.set_summary_writer(writer)

        if args.task == 0:
            sess.run(tf.global_variables_initializer())
        else:
            while len(sess.run(ready_op)) > 0:
                print("Waiting for task 0 initializing the global variables.")
                time.sleep(1)
        a3c.run(sess, saver)
Example #14
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Tue Nov 28 09:52:10 2017

@author: zqwu
"""

import os
os.chdir(os.environ['HOME'] + '/cs238/CS238')
from a3c import A3C, _Worker
from model_example import DuelNetwork
from sawyer import Sawyer

current_dir = os.getcwd()
urdfFile = os.path.join(current_dir,
                        "rethink/sawyer_description/urdf/sawyer_no_base.urdf")
env = Sawyer(urdfFile)
model = DuelNetwork()

alg = A3C(env, model)
Example #15
def run(args, server):
    env = create_env(args.env_id,
                     client_id=str(args.task),
                     remotes=args.remotes)
    if args.teacher:
        teacher = model.LSTMPolicy(env.observation_space.shape,
                                   env.action_space.n,
                                   name="global")
        teacher_init_op = teacher.load_model_from_checkpoint(
            args.checkpoint_path)

        trainer = A3C(env,
                      args.task,
                      args.visualise,
                      teacher=teacher,
                      name="student")

    else:
        teacher = None
        trainer = A3C(env, args.task, args.visualise, teacher=teacher)

    # Variable names that start with "local" are not saved in checkpoints.
    if use_tf12_api:
        variables_to_save = trainer.global_var_list
        all_trainable_variables = [
            v for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
            if trainer.scope in v.name
        ]
        init_op = tf.variables_initializer(variables_to_save)
        init_all_op = tf.variables_initializer(all_trainable_variables)

    else:

        variables_to_save = trainer.global_var_list
        init_op = tf.initialize_variables(variables_to_save)
        all_trainable_variables = [
            v for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
            if trainer.scope in v.name
        ]
        init_all_op = tf.variables_initializer(all_trainable_variables)

    saver = FastSaver(variables_to_save)

    logger.info('Trainable vars:')

    for v in all_trainable_variables:
        logger.info('{} {}'.format(v.name, v.get_shape()))

    def init_fn(ses):
        logger.info("Initializing all parameters.")
        ses.run([init_all_op])

    def get_init_fn():
        if args.teacher:
            return tf.contrib.framework.assign_from_checkpoint_fn(
                args.checkpoint_path,
                teacher.var_list,
                ignore_missing_vars=True)
        else:
            return lambda sess: init_fn(sess)

    config = tf.ConfigProto(device_filters=[
        "/job:ps", "/job:worker/task:{}/cpu:0".format(args.task)
    ])
    logdir = os.path.join(args.log_dir, 'train')

    if use_tf12_api:
        summary_writer = tf.summary.FileWriter(logdir +
                                               "_{}".format(args.task))
    else:
        summary_writer = tf.train.SummaryWriter(logdir +
                                                "_'{}".format(args.task))

    logger.info("Events directory: {}_{}".format(logdir, args.task))

    sv = tf.train.Supervisor(
        is_chief=(args.task == 0),
        logdir=logdir,
        saver=saver,
        summary_op=None,
        init_op=init_op,
        init_fn=get_init_fn(),
        summary_writer=summary_writer,
        ready_op=tf.report_uninitialized_variables(variables_to_save),
        global_step=trainer.global_step,
        save_model_secs=30,
        save_summaries_secs=30)

    num_global_steps = 100000000

    logger.info(
        "Starting session. If this hangs, we're mostly likely waiting to connect to the parameter server. "
        +
        "One common cause is that the parameter server DNS name isn't resolving yet, or is misspecified."
    )
    with sv.managed_session(server.target,
                            config=config) as sess, sess.as_default():
        sess.run(trainer.sync)
        trainer.start(sess, summary_writer)
        global_step = sess.run(trainer.global_step)
        logger.info("Starting training at step={}".format(global_step))
        while not sv.should_stop() and (not num_global_steps
                                        or global_step < num_global_steps):
            trainer.process(sess)
            global_step = sess.run(trainer.global_step)

    # Ask for all the services to stop.
    sv.stop()
    logger.info('reached {} steps. worker stopped.'.format(global_step))
Example #16
def run(args, server):
    # run for the workers
    # Create the game environment
    env = create_env(args.env_id,
                     client_id=str(args.task),
                     remotes=args.remotes)
    # Create a new agent : trainer
    trainer = A3C(env, args.task, args.visualise)

    # Variable names that start with "local" are not saved in checkpoints.
    variables_to_save = [
        v for v in tf.global_variables() if not v.name.startswith("local")
    ]
    init_op = tf.variables_initializer(variables_to_save)
    init_all_op = tf.global_variables_initializer()
    saver = FastSaver(variables_to_save)

    var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                 tf.get_variable_scope().name)
    logger.info('Trainable vars:')
    for v in var_list:
        logger.info('  %s %s', v.name, v.get_shape())

    def init_fn(ses):
        logger.info("Initializing all parameters.")
        ses.run(init_all_op)

    config = tf.ConfigProto(device_filters=[
        "/job:ps", "/job:worker/task:{}/cpu:0".format(args.task)
    ])
    logdir = os.path.join(args.log_dir, 'train')

    if use_tf12_api:
        summary_writer = tf.summary.FileWriter(logdir + "_%d" % args.task)
    else:
        summary_writer = tf.train.SummaryWriter(logdir + "_%d" % args.task)

    logger.info("Events directory: %s_%s", logdir, args.task)
    sv = tf.train.Supervisor(
        is_chief=(args.task == 0),
        logdir=logdir,
        saver=saver,
        summary_op=None,
        init_op=init_op,
        init_fn=init_fn,
        summary_writer=summary_writer,
        ready_op=tf.report_uninitialized_variables(variables_to_save),
        global_step=trainer.global_step,
        save_model_secs=30,
        save_summaries_secs=30)

    # Global steps !!!
    num_global_steps = 100000000

    logger.info(
        "Starting session. If this hangs, we're mostly likely waiting to connect to the parameter server. "
        +
        "One common cause is that the parameter server DNS name isn't resolving yet, or is misspecified."
    )
    with sv.managed_session(server.target,
                            config=config) as sess, sess.as_default():
        # synchronize the params from parameter server
        sess.run(trainer.sync)
        # agent / trainer begins to interact with the game environment
        trainer.start(sess, summary_writer)
        global_step = sess.run(trainer.global_step)
        logger.info("Starting training at step=%d", global_step)
        while not sv.should_stop() and (not num_global_steps
                                        or global_step < num_global_steps):
            # process the rollout information to update the parameters
            # here there is a while loop. So it is the main learning phase.
            # because of the yield syntax in the env_runner, they perform 20 steps of experiments once trainer.process() is called.
            trainer.process(sess)
            global_step = sess.run(trainer.global_step)

    # Ask for all the services to stop.
    sv.stop()
    logger.info('reached %s steps. worker stopped.', global_step)
Example #17
def run(args, server):
    env_ids = str(args.env_id).split(",")
    tasks = len(env_ids)
    original_logdir = args.log_dir
    logdir = os.path.join(args.log_dir, 'train')

    env_task = args.task % tasks
    args.env_id = env_ids[env_task]
    env = create_env(args.env_id)
    ac_spaces = [create_env(env_id).action_space.n for env_id in env_ids]
    workers_per_task = args.num_workers / tasks
    trainer = A3C(env, args.task, env_task, tasks, ac_spaces, workers_per_task)

    trainable_var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                           tf.get_variable_scope().name)

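    # A variable counts as "global" only if it lives in the global scope and is
    # either the shared global_step or one of the trainable variables; Adam slot
    # variables and everything else fall through (falsy return) and are treated
    # as local.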
    def global_var(var):
        if not var.name.startswith("global"):
            return False

        if var.name.startswith("global/global_step"):
            return True

        if "/Adam" in var.name:
            return False

        for v in trainable_var_list:
            if v.name == var.name:
                return True

    local_var_list = [v for v in tf.global_variables() if not global_var(v)]
    global_var_list = [v for v in tf.global_variables() if global_var(v)]

    logger.info('Global vars:')
    for v in global_var_list:
        logger.info('  %s %s', v.name, v.get_shape())

    logger.info('Local vars:')
    for v in local_var_list:
        logger.info('  %s %s', v.name, v.get_shape())

    logger.info('Trainable vars:')
    for v in trainable_var_list:
        logger.info('  %s %s', v.name, v.get_shape())

    local_init_op = tf.variables_initializer(local_var_list)
    global_init_op = tf.variables_initializer(global_var_list)

    def init_sync_pairs():
        pairs = []
        for v in local_var_list:
            if v.name.startswith("local"):
                global_v_name = v.name.replace('local', 'global', 1)
                for global_v in global_var_list:
                    if global_v.name == global_v_name:
                        pairs.append((v, global_v))
                        break
        return pairs

    init_sync = tf.group(*[v1.assign(v2) for v1, v2 in init_sync_pairs()])

    saver = FastSaver(global_var_list, max_to_keep=3)
    saver_path = os.path.join(logdir, "model.ckpt")
    report_uninitialized_variables = tf.report_uninitialized_variables()

    def init_fn(ses):
        logger.info("Initializing all parameters.")
        logger.info("Uninizialied Variables after init_fn: %s",
                    ses.run(report_uninitialized_variables))

    config = tf.ConfigProto(device_filters=[
        "/job:ps", "/job:worker/task:{}/cpu:0".format(args.task)
    ])

    summary_dir = logdir + "_{}/worker_{}".format(
        args.task % tasks, int((args.task - (args.task % tasks)) / tasks))
    summary_writer = tf.summary.FileWriter(summary_dir, flush_secs=30)

    logger.info("Events directory: %s", summary_dir)

    sv = tf.train.Supervisor(is_chief=(args.task == 0),
                             logdir=logdir,
                             saver=saver,
                             summary_op=None,
                             init_op=global_init_op,
                             init_fn=init_fn,
                             local_init_op=local_init_op,
                             summary_writer=summary_writer,
                             ready_op=tf.report_uninitialized_variables(
                                 tf.global_variables()),
                             global_step=trainer.global_step,
                             save_model_secs=120,
                             save_summaries_secs=30,
                             recovery_wait_secs=5)

    num_global_steps = 100000  #20000000

    logger.info(
        "Starting session. If this hangs, we're mostly likely waiting to connect to the parameter server. "
        +
        "One common cause is that the parameter server DNS name isn't resolving yet, or is misspecified."
    )
    with sv.managed_session(server.target,
                            config=config) as sess, sess.as_default():
        uninitialized_variables = sess.run(report_uninitialized_variables)
        if len(uninitialized_variables) > 0:
            logger.info("Some variables are not initialized:\n{}").format(
                uninitialized_variables)
        assert len(uninitialized_variables) == 0

        sess.run(init_sync)

        if args.task < args.num_workers:
            trainer.start(sess, summary_writer)
            global_step = sess.run(trainer.global_step)
            logger.info("Starting training at step=%d", global_step)
            while not sv.should_stop() and (not num_global_steps
                                            or global_step < num_global_steps):
                trainer.process(sess)
                global_step = sess.run(trainer.global_step)
            if args.task == 0:
                saver.save(sess, saver_path, global_step)

            while not sv.should_stop():
                time.sleep(5)

    # Ask for all the services to stop.
    sv.stop()
    logger.info('reached %s steps. worker stopped.', global_step)
    if args.task == 0:
        with open(os.path.join(original_logdir, 'done.txt'), "w") as file:
            file.write(str(global_step))
Example #18
import go_vncdriver
from a3c import A3C
from envs import create_env
import tensorflow as tf

env = create_env('flashgames.NeonRace-v0', client_id=0, remotes=1)
trainer = A3C(env, 0, True)

variables_to_save = [
    v for v in tf.global_variables() if not v.name.startswith("local")
]
print([v.name for v in variables_to_save])
Example #19
File: worker.py Project: wwxFromTju/DHP
def run(args, server):

    if config.mode in ['on_line']:
        '''on_line mode is special: log_dir is separated by game (g) and subject (s)'''
        logdir = os.path.join(
            args.log_dir,
            'train_g_' + str(args.env_id) + '_s_' + str(args.subject))
    elif config.mode in ['off_line', 'data_processor']:
        '''normal log_dir'''
        logdir = os.path.join(args.log_dir, 'train')
    '''in any case, log_dir is separated by worker (task)'''
    summary_writer = tf.summary.FileWriter(logdir + "_%d" % args.task)
    '''log final log_dir'''
    logger.info("Events directory: %s_%s", logdir, args.task)
    '''create env'''
    env = envs.PanoramicEnv(
        env_id=args.env_id,
        task=args.task,
        subject=args.subject,
        summary_writer=summary_writer,
    )
    '''create trainer'''
    trainer = A3C(env, args.env_id, args.task)
    '''Variable names that start with "local" are not saved in checkpoints.'''
    variables_to_save = [
        v for v in tf.global_variables() if not v.name.startswith("local")
    ]
    init_op = tf.variables_initializer(variables_to_save)
    init_all_op = tf.global_variables_initializer()
    saver = FastSaver(variables_to_save)

    def init_fn(ses):
        logger.info("Initializing all parameters.")
        ses.run(init_all_op)

    config_tf = tf.ConfigProto(device_filters=[
        "/job:ps", "/job:worker/task:{}/cpu:0".format(args.task)
    ])
    '''determine is_chief'''
    if config.mode in ['on_line']:
        '''on_line mode has one worker per ps, so it is always the chief'''
        is_chief = True
    elif config.mode in ['off_line']:
        '''off_line mode shares the model across all workers (videos)'''
        is_chief = (args.task == 0)

    if is_chief:
        print('>>>> this is the chief task, initializing variables')
        tf.Session(server.target, config=config_tf).run(init_all_op)
    else:
        print('>>>> this is not the chief task, wait for a while')
        time.sleep(10)

    sv = tf.train.Supervisor(
        is_chief=is_chief,
        logdir=logdir,
        saver=saver,
        summary_op=None,
        init_op=init_op,
        init_fn=init_fn,
        summary_writer=summary_writer,
        ready_op=tf.report_uninitialized_variables(variables_to_save),
        global_step=trainer.global_step,
        save_model_secs=30,
        save_summaries_secs=30)

    logger.info(
        "Starting session. If this hangs, we're mostly likely waiting to connect to the parameter server. "
        +
        "One common cause is that the parameter server DNS name isn't resolving yet, or is misspecified."
    )
    '''start run'''
    with sv.managed_session(server.target, config=config_tf) as sess:
        '''start trainer'''
        trainer.start(sess, summary_writer)
        '''log global_step so that we can see if the model is restored successfully'''
        global_step = sess.run(trainer.global_step)
        logger.info("Starting training at step=%d", global_step)
    '''keep running'''
        not_reach_train_limit = True
        while (not sv.should_stop()) and not_reach_train_limit:
            trainer.process(sess)
            global_step = sess.run(trainer.global_step)
            if config.number_trained_steps > 0:
                not_reach_train_limit = (global_step <
                                         config.number_trained_steps)
    '''Ask for all the services to stop.'''
    sv.stop()
    logger.info('reached %s steps. worker stopped.', global_step)
Example #20
File: worker.py Project: zoujun123/cdrl
def run(args, server):
    # lkx: client_id and remotes don't matter for non-VNC, non-Flash games
    # env = create_env(args.env_id, client_id=str(args.task), remotes=args.remotes)
    # trainer = A3C(env, args.task)

    target_task = 1  # int(args.target_task)
    env_names = args.env_id.split("_")
    envs = [
        create_env(env_name,
                   client_id=str(args.worker_id),
                   remotes=args.remotes) for env_name in env_names
    ]

    trainer = A3C(envs, int(args.worker_id), target_task)

    # Variable names that start with "local" are not saved in checkpoints.
    variables_to_save = [
        v for v in tf.all_variables() if not v.name.startswith("local")
    ]
    init_op = tf.initialize_variables(variables_to_save)
    init_all_op = tf.initialize_all_variables()
    saver = FastSaver(variables_to_save)

    variables_to_restore = [
        v for v in tf.all_variables()
        if v.name.startswith("global0") and "global_step" not in v.name
    ]  # Adam_2 and 3 cost by the distillation train op
    pre_train_saver = FastSaver(variables_to_restore)

    def init_fn(ses):
        logger.info("Initializing all parameters.")
        ses.run(init_all_op)
        pre_train_saver.restore(ses, "../model/model.ckpt-4986751")

    config = tf.ConfigProto(device_filters=[
        "/job:ps", "/job:worker/task:{}/cpu:0".format(args.worker_id)
    ])  # refer to worker id
    logdir = os.path.join(args.log_dir, 'train')
    summary_writer = tf.train.SummaryWriter(logdir + "_%d" % args.worker_id)
    logger.info("Events directory: %s_%s", logdir, args.worker_id)
    sv = tf.train.Supervisor(
        is_chief=(args.worker_id == 0),
        logdir=logdir,
        saver=saver,
        summary_op=None,
        init_op=init_op,  # Defaults to an Operation that initializes all variables
        init_fn=init_fn,  # Called after the optional init_op is called
        summary_writer=summary_writer,
        ready_op=tf.report_uninitialized_variables(
            variables_to_save),  # list the names of uninitialized variables.
        global_step=trainer.global_step[target_task],
        save_model_secs=30,
        save_summaries_secs=30)

    num_taskss = len(envs)

    num_global_steps = 20000000  #10000000

    logger.info(
        "Starting session. If this hangs, we're mostly likely waiting to connect to the parameter server. "
        +
        "One common cause is that the parameter server DNS name isn't resolving yet, or is misspecified."
    )
    with sv.managed_session(server.target,
                            config=config) as sess, sess.as_default():
        for ii in np.arange(num_taskss):
            sess.run(trainer.sync[ii])
        sess.run(trainer.sync_logits)
        trainer.start(sess, summary_writer)
        global_step = sess.run(trainer.global_step[target_task])
        logger.info("Starting training at step=%d", global_step)
        while not sv.should_stop() and (not num_global_steps
                                        or global_step < num_global_steps):
            # if global_step <= 1000000 and np.random.uniform(0, 1) > 0.5:   # todo annealing
            #     batch_aux = trainer.get_knowledge(sess)
            #     trainer.process(sess, batch_aux)
            trainer.process(sess)
            global_step = sess.run(trainer.global_step[target_task])

    # Ask for all the services to stop.
    sv.stop()
    logger.info('reached %s steps. worker stopped.', global_step)
Example #21
def run(args, server):
    env = create_env(client_id=str(args.task), remotes=args.remotes)
    trainer = A3C(env, args.task, args.visualise)

    # Variables whose names start with "local" (the local copies) are not saved in the checkpoint files
    variables_to_save = [
        v for v in tf.global_variables() if not v.name.startswith("local")
    ]
    init_op = tf.variables_initializer(variables_to_save)
    init_all_op = tf.global_variables_initializer()

    # Saver that writes the variables to the checkpoint files
    saver = FastSaver(variables_to_save)

    # Collect the trainable variables
    var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                 tf.get_variable_scope().name)

    logger.info('Trainable vars:')
    for v in var_list:
        logger.info('  %s %s', v.name, v.get_shape())

    def init_fn(ses):
        logger.info("初始化所有参数。")
        ses.run(init_all_op)

    config = tf.ConfigProto(device_filters=[
        "/job:ps", "/job:worker/task:{}/cpu:0".format(args.task)
    ])

    logdir = os.path.join(args.log_dir, 'train')
    # FileWriter for the TensorBoard event files
    summary_writer = tf.summary.FileWriter(logdir + "_%d" % args.task)

    logger.info("存储 TensorBoard 文件的目录: %s_%s", logdir, args.task)

    # A high-level wrapper that handles saving TensorBoard event files,
    # writing checkpoints, and similar bookkeeping
    sv = tf.train.Supervisor(
        is_chief=(args.task == 0),
        logdir=logdir,  # directory where the checkpoints are written
        saver=saver,  # Saver used to write the checkpoints
        summary_op=None,
        init_op=init_op,
        init_fn=init_fn,
        summary_writer=summary_writer,  # FileWriter for the TensorBoard event files
        ready_op=tf.report_uninitialized_variables(variables_to_save),
        global_step=trainer.global_step,
        save_model_secs=30,
        save_summaries_secs=30)

    # Total number of training steps; adjust as needed
    num_global_steps = 100000000

    logger.info("启动会话中...")
    with sv.managed_session(server.target,
                            config=config) as sess, sess.as_default():
        sess.run(trainer.sync)
        trainer.start(sess, summary_writer)
        global_step = sess.run(trainer.global_step)
        logger.info("在第 %d 步开始训练", global_step)
        while not sv.should_stop() and (not num_global_steps
                                        or global_step < num_global_steps):
            trainer.process(sess)
            global_step = sess.run(trainer.global_step)

    # Ask all the services to stop
    sv.stop()
    logger.info('reached %s steps. worker stopped.', global_step)
Example #22
def run(args, server, renderOnly=False):
    env = create_env(args.env_id, client_id=str(args.task), remotes=args.remotes, renderOnly=renderOnly)
    trainer = A3C(env, args.task, args.visualise, renderOnly=renderOnly)

    # Variable names that start with "local" are not saved in checkpoints.
    if use_tf12_api:
        variables_to_save = [v for v in tf.global_variables() if not v.name.startswith("local")]
        init_op = tf.variables_initializer(variables_to_save)
        init_all_op = tf.global_variables_initializer()
    else:
        variables_to_save = [v for v in tf.all_variables() if not v.name.startswith("local")]
        init_op = tf.initialize_variables(variables_to_save)
        init_all_op = tf.initialize_all_variables()
    saver = FastSaver(variables_to_save)


    var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name)
    logger.info('Trainable vars:')
    for v in var_list:
        logger.info('  %s %s', v.name, v.get_shape())

    def init_fn(ses):
        logger.info("Initializing all parameters.")
        ses.run(init_all_op)

    config = tf.ConfigProto(device_filters=["/job:ps", "/job:worker/task:{}/cpu:0".format(args.task)])
    logdir = os.path.join(args.log_dir, 'train')

    if use_tf12_api:
        summary_writer = tf.summary.FileWriter(logdir + "_%d" % args.task)
    else:
        summary_writer = tf.train.SummaryWriter(logdir + "_%d" % args.task)

    logger.info("Events directory: %s_%s", logdir, args.task)
    sv = tf.train.Supervisor(is_chief=(args.task == 0),
                             logdir=logdir,
                             saver=saver,
                             summary_op=None,
                             init_op=init_op,
                             init_fn=init_fn,
                             summary_writer=summary_writer,
                             ready_op=tf.report_uninitialized_variables(variables_to_save),
                             global_step=trainer.global_step,
                             save_model_secs=30,
                             save_summaries_secs=30)

    num_global_steps = 100000000

    logger.info(
        "Starting session. If this hangs, we're mostly likely waiting to connect to the parameter server. " +
        "One common cause is that the parameter server DNS name isn't resolving yet, or is misspecified.")
    with sv.managed_session(server.target, config=config) as sess, sess.as_default():
        sess.run(trainer.sync)
        trainer.start(sess, summary_writer)
        global_step = sess.run(trainer.global_step)
        logger.info("Starting training at step=%d", global_step)
        while not sv.should_stop() and (not num_global_steps or global_step < num_global_steps):
            #logger.info("About to process")
            trainer.process(sess)
            global_step = sess.run(trainer.global_step)

    # Ask for all the services to stop.
    sv.stop()
    logger.info('reached %s steps. worker stopped.', global_step)
Example #23
def run(args, server):
    #environment and trainer
    env = create_env(args.env_id,
                     client_id=str(args.task),
                     remotes=args.remotes)
    trainer = A3C(env, args.task, args.visualise)

    # Variable names that start with "local" are not saved in checkpoints.
    # global_variables are shared between distributed machines
    if use_tf12_api:
        variables_to_save = [
            v for v in tf.global_variables() if not v.name.startswith("local")
        ]
        init_op = tf.variables_initializer(variables_to_save)
        init_all_op = tf.global_variables_initializer()
    # this branch will not run since we are using a recent version of TensorFlow
    else:
        variables_to_save = [
            v for v in tf.all_variables() if not v.name.startswith("local")
        ]
        init_op = tf.initialize_variables(variables_to_save)
        init_all_op = tf.initialize_all_variables()
    #saver for saving the parameters
    saver = FastSaver(variables_to_save)

    #get trainable
    var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                 tf.get_variable_scope().name)
    logger.info('Trainable vars:')
    for v in var_list:
        logger.info('  %s %s', v.name, v.get_shape())
    #--------------------------------------------------------------------------------------------------------------------------------
    #
    def init_fn(ses):
        logger.info("Initializing all parameters.")
        ses.run(init_all_op)

    #--------------------------------------------------------------------------------------------------------------------------------
    config = tf.ConfigProto(device_filters=[
        "/job:ps", "/job:worker/task:{}/cpu:0".format(args.task)
    ])
    logdir = os.path.join(args.log_dir, 'train')

    if use_tf12_api:
        summary_writer = tf.summary.FileWriter(logdir + "_%d" % args.task)
    else:
        summary_writer = tf.train.SummaryWriter(logdir + "_%d" % args.task)

    logger.info("Events directory: %s_%s", logdir, args.task)
    #
    # https://www.tensorflow.org/api_docs/python/tf/train/Supervisor
    sv = tf.train.Supervisor(
        is_chief=(args.task == 0),
        logdir=logdir,
        saver=saver,
        summary_op=None,
        init_op=init_op,
        init_fn=init_fn,
        summary_writer=summary_writer,
        ready_op=tf.report_uninitialized_variables(variables_to_save),
        global_step=trainer.global_step,
        save_model_secs=30,
        save_summaries_secs=30)

    # maximum number of global steps for training
    num_global_steps = 100000000

    logger.info(
        "Starting session. If this hangs, we're mostly likely waiting to connect to the parameter server. "
        +
        "One common cause is that the parameter server DNS name isn't resolving yet, or is misspecified."
    )
    with sv.managed_session(server.target,
                            config=config) as sess, sess.as_default():
        sess.run(trainer.sync)
        trainer.start(sess, summary_writer)
        global_step = sess.run(trainer.global_step)
        logger.info("Starting training at step=%d", global_step)

        # This is the training loop,
        while not sv.should_stop() and (not num_global_steps
                                        or global_step < num_global_steps):
            trainer.process(sess)
            global_step = sess.run(trainer.global_step)

    # Ask for all the services to stop.
    sv.stop()
    logger.info('reached %s steps. worker stopped.', global_step)
Example #24
def run(args, server):
    env = new_env(args)
    if args.alg == 'A3C':
        trainer = A3C(env, args)
    elif args.alg == 'Q':
        trainer = Q(env, args)
    elif args.alg == 'VPN':
        env_off = new_env(args)
        env_off.verbose = 0
        env_off.reset()
        trainer = VPN(env, args, env_off=env_off)
    else:
        raise ValueError('Invalid algorithm: ' + args.alg)

    # Variable names that start with "local" are not saved in checkpoints.
    variables_to_save = [v for v in tf.global_variables() if \
                not v.name.startswith("global") and not v.name.startswith("local/target/")]
    global_variables = [
        v for v in tf.global_variables() if not v.name.startswith("local")
    ]

    init_op = tf.variables_initializer(global_variables)
    init_all_op = tf.global_variables_initializer()
    saver = FastSaver(variables_to_save, max_to_keep=0)

    var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                 tf.get_variable_scope().name)
    logger.info('Trainable vars:')
    for v in var_list:
        logger.info('  %s %s', v.name, v.get_shape())
    logger.info("Num parameters: %d", trainer.local_network.num_param)

    def init_fn(ses):
        logger.info("Initializing all parameters.")
        ses.run(init_all_op)

    device = 'gpu' if args.gpu > 0 else 'cpu'
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.15)
    config = tf.ConfigProto(device_filters=[
        "/job:ps", "/job:worker/task:{}/{}:0".format(args.task, device)
    ],
                            gpu_options=gpu_options,
                            allow_soft_placement=True)
    logdir = os.path.join(args.log, 'train')
    summary_writer = tf.summary.FileWriter(logdir + "_%d" % args.task)

    logger.info("Events directory: %s_%s", logdir, args.task)
    sv = tf.train.Supervisor(
        is_chief=(args.task == 0),
        logdir=logdir,
        saver=saver,
        summary_op=None,
        init_op=init_op,
        init_fn=init_fn,
        summary_writer=summary_writer,
        ready_op=tf.report_uninitialized_variables(global_variables),
        global_step=trainer.global_step,
        save_model_secs=0,
        save_summaries_secs=30)

    logger.info(
        "Starting session. If this hangs, we're mostly likely waiting to connect to the parameter server. "
        +
        "One common cause is that the parameter server DNS name isn't resolving yet, or is misspecified."
    )
    with sv.managed_session(server.target,
                            config=config) as sess, sess.as_default():
        sess.run(trainer.sync)
        trainer.start(sess, summary_writer)
        global_step = sess.run(trainer.global_step)
        epoch = -1
        logger.info("Starting training at step=%d", global_step)
        while not sv.should_stop() and (not args.max_step
                                        or global_step < args.max_step):
            if args.task == 0 and int(global_step / args.eval_freq) > epoch:
                epoch = int(global_step / args.eval_freq)
                filename = os.path.join(args.log, 'e%d' % (epoch))
                sv.saver.save(sess, filename)
                sv.saver.save(sess, os.path.join(args.log, 'latest'))
                print("Saved to: %s" % filename)
            trainer.process(sess)
            global_step = sess.run(trainer.global_step)

        if args.task == 0 and int(global_step / args.eval_freq) > epoch:
            epoch = int(global_step / args.eval_freq)
            filename = os.path.join(args.log, 'e%d' % (epoch))
            sv.saver.save(sess, filename)
            sv.saver.save(sess, os.path.join(args.log, 'latest'))
            print("Saved to: %s" % filename)
    # Ask for all the services to stop.
    sv.stop()
    logger.info('reached %s steps. worker stopped.', global_step)
Example #25
	def creat_agent(port, process_idx):
		env = build_env(args, port, name='Train' + str(process_idx))
		return A3C(model, opt, env, args.t_max, 0.99, beta=args.beta,process_idx=process_idx)
Example #26
File: worker.py Project: aGiant/btgym
    def run(self):
        """
        Worker runtime body.
        """

        # Define cluster:
        cluster = tf.train.ClusterSpec(self.cluster_spec).as_cluster_def()

        # Start tf.server:
        if self.job_name in 'ps':
            server = tf.train.Server(
                cluster,
                job_name=self.job_name,
                task_index=self.task,
                config=tf.ConfigProto(device_filters=["/job:ps"]))
            self.log.debug('parameters_server started.')
            # Just block here:
            server.join()

        else:
            server = tf.train.Server(cluster,
                                     job_name='worker',
                                     task_index=self.task,
                                     config=tf.ConfigProto(
                                         intra_op_parallelism_threads=1,
                                         inter_op_parallelism_threads=2))
            self.log.debug('worker_{} tf.server started.'.format(self.task))

            self.log.debug('making environment.')
            if not self.test_mode:
                # Assume BTgym env. class:
                self.log.debug('worker_{} is data_master: {}'.format(
                    self.task, self.env_config['data_master']))
                try:
                    self.env = self.env_class(**self.env_config)

                except:
                    raise SystemExit(
                        ' Worker_{} failed to make BTgym environment'.format(
                            self.task))

            else:
                # Assume atari testing:
                try:
                    self.env = create_env(self.env_config['gym_id'])

                except:
                    raise SystemExit(
                        ' Worker_{} failed to make Atari Gym environment'.
                        format(self.task))

            self.log.debug('worker_{}: environment ok.'.format(self.task))
            # Define trainer:
            trainer = A3C(env=self.env,
                          task=self.task,
                          model_class=self.model_class,
                          test_mode=self.test_mode,
                          **self.kwargs)

            self.log.debug('worker_{}:trainer ok.'.format(self.task))

            # Saver-related:
            variables_to_save = [
                v for v in tf.global_variables()
                if not v.name.startswith("local")
            ]
            init_op = tf.variables_initializer(variables_to_save)
            init_all_op = tf.global_variables_initializer()

            saver = FastSaver(variables_to_save)

            var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                         tf.get_variable_scope().name)

            #self.log.debug('worker-{}: trainable vars:'.format(self.task))
            #for v in var_list:
            #    self.log.debug('{}: {}'.format(v.name, v.get_shape()))

            def init_fn(ses):
                self.log.debug("Initializing all parameters.")
                ses.run(init_all_op)

            config = tf.ConfigProto(device_filters=[
                "/job:ps", "/job:worker/task:{}/cpu:0".format(self.task)
            ])
            logdir = os.path.join(self.log_dir, 'train')
            summary_dir = logdir + "_{}".format(self.task)

            summary_writer = tf.summary.FileWriter(summary_dir)

            sv = tf.train.Supervisor(
                is_chief=(self.task == 0),
                logdir=logdir,
                saver=saver,
                summary_op=None,
                init_op=init_op,
                init_fn=init_fn,
                #summary_writer=summary_writer,
                ready_op=tf.report_uninitialized_variables(variables_to_save),
                global_step=trainer.global_step,
                save_model_secs=300,
            )
            self.log.debug("connecting to the parameter server... ")

            with sv.managed_session(server.target,
                                    config=config) as sess, sess.as_default():
                sess.run(trainer.sync)
                trainer.start(sess, summary_writer)
                global_step = sess.run(trainer.global_step)
                self.log.info(
                    "worker_{}: starting training at step: {}".format(
                        self.task, global_step))
                while not sv.should_stop() and global_step < self.max_steps:
                    trainer.process(sess)
                    global_step = sess.run(trainer.global_step)

                # Ask for all the services to stop:
                self.env.close()
                sv.stop()
            self.log.info('worker_{}: reached {} steps, exiting.'.format(
                self.task, global_step))
Example #27
def run_tester(args, server):
    env = new_env(args)
    # env.configure()
    env.reset()
    env.max_history = args.eval_num
    if args.alg == 'A3C':
        agent = A3C(env, args)
    elif args.alg == 'Q':
        agent = Q(env, args)
    elif args.alg == 'VPN':
        agent = VPN(env, args)
    else:
        raise ValueError('Invalid algorithm: ' + args.alg)

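    # Restrict the session to the parameter server plus this task's own device, and cap
    # per-process GPU memory so several evaluator/worker processes can share one card.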
    device = 'gpu' if args.gpu > 0 else 'cpu'
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.15)
    config = tf.ConfigProto(device_filters=[
        "/job:ps", "/job:worker/task:{}/{}:0".format(args.task, device)
    ],
                            gpu_options=gpu_options,
                            allow_soft_placement=True)
    variables_to_save = [v for v in tf.global_variables() if \
                not v.name.startswith("global") and not v.name.startswith("local/target/")]
    global_variables = [
        v for v in tf.global_variables() if not v.name.startswith("local")
    ]

    init_op = tf.variables_initializer(global_variables)
    init_all_op = tf.global_variables_initializer()

    var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                 tf.get_variable_scope().name)
    logger.info('Trainable vars:')
    for v in var_list:
        logger.info('  %s %s', v.name, v.get_shape())
    logger.info("Num parameters: %d", agent.local_network.num_param)

    def init_fn(ses):
        logger.info("Initializing all parameters.")
        ses.run(init_all_op)

    saver = FastSaver(variables_to_save, max_to_keep=0)
    sv = tf.train.Supervisor(
        is_chief=False,
        global_step=agent.global_step,
        summary_op=None,
        init_op=init_op,
        init_fn=init_fn,
        ready_op=tf.report_uninitialized_variables(global_variables),
        saver=saver,
        save_model_secs=0,
        save_summaries_secs=0)

    best_reward = -10000
    with sv.managed_session(server.target,
                            config=config) as sess, sess.as_default():
        epoch = args.eval_epoch
        while args.eval_freq * epoch <= args.max_step:
            path = os.path.join(args.log, "e%d" % epoch)
            if not os.path.exists(path + ".index"):
                time.sleep(10)
                continue
            logger.info("Start evaluation (Epoch %d)", epoch)
            saver.restore(sess, path)
            np.random.seed(args.seed)
            reward = evaluate(env,
                              agent.local_network,
                              args.eval_num,
                              eps=args.eps_eval)

            print("Epoch: %d, Reward: %.2f" % (epoch, reward))
            with open(os.path.join(args.log, "eval.csv"), "a") as logfile:
                logfile.write("%d, %.3f\n" % (epoch, reward))
            if reward > best_reward:
                best_reward = reward
                sv.saver.save(sess, os.path.join(args.log, 'best'))
                print("Saved to: %s" % os.path.join(args.log, 'best'))

            epoch += 1

    logger.info('tester stopped.')
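
# The evaluate() helper called above is not defined in this example. Below is a minimal
# sketch under stated assumptions: the environment follows the classic Gym reset()/step()
# API and the network exposes an act(observation) method returning an action. Both are
# assumptions about code not shown here, not the original implementation.
import numpy as np

def evaluate(env, network, num_episodes, eps=0.0):
    total_reward = 0.0
    for _ in range(num_episodes):
        ob = env.reset()
        done = False
        while not done:
            # eps-greedy evaluation: act randomly with small probability eps
            if np.random.rand() < eps:
                action = env.action_space.sample()
            else:
                action = network.act(ob)
            ob, reward, done, _ = env.step(action)
            total_reward += reward
    # Mean episode reward over the evaluation run
    return total_reward / num_episodes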
Example #28
0
#!/usr/bin/python3
import docker
from a3c import A3C

if __name__ == '__main__':

    docker_client = docker.from_env()

    a3c = A3C(docker_client, 3105, '../models/a3c/', '../logs/a3c/')
    a3c.train(3)
Example #29
0
def run(args, server):
    env = create_env(args.env_id, client_id=str(args.task), remotes=args.remotes, envWrap=args.envWrap, designHead=args.designHead,
                        noLifeReward=args.noLifeReward)
    trainer = A3C(env, args.task, args.visualise, args.unsup, args.envWrap, args.designHead, args.noReward)

    # logging
    if args.task == 0:
        with open(args.log_dir + '/log.txt', 'w') as fid:
            for key, val in constants.items():
                fid.write('%s: %s\n'%(str(key), str(val)))
            fid.write('designHead: %s\n'%args.designHead)
            fid.write('input observation: %s\n'%str(env.observation_space.shape))
            fid.write('env name: %s\n'%str(env.spec.id))
            fid.write('unsup method type: %s\n'%str(args.unsup))

    # Variable names that start with "local" are not saved in checkpoints.
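    # use_tf12_api is assumed to be a module-level flag set elsewhere (TensorFlow >= 0.12)
    # that selects between the renamed variable/initializer APIs below.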
    if use_tf12_api:
        variables_to_save = [v for v in tf.global_variables() if not v.name.startswith("local")]
        init_op = tf.variables_initializer(variables_to_save)
        init_all_op = tf.global_variables_initializer()
    else:
        variables_to_save = [v for v in tf.all_variables() if not v.name.startswith("local")]
        init_op = tf.initialize_variables(variables_to_save)
        init_all_op = tf.initialize_all_variables()
    saver = FastSaver(variables_to_save)
    if args.pretrain is not None:
        variables_to_restore = [v for v in tf.trainable_variables() if not v.name.startswith("local")]
        pretrain_saver = FastSaver(variables_to_restore)

    var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name)
    logger.info('Trainable vars:')
    for v in var_list:
        logger.info('  %s %s', v.name, v.get_shape())

    def init_fn(ses):
        logger.info("Initializing all parameters.")
        ses.run(init_all_op)
        if args.pretrain is not None:
            pretrain = tf.train.latest_checkpoint(args.pretrain)
            logger.info("==> Restoring from given pretrained checkpoint.")
            logger.info("    Pretraining address: %s", pretrain)
            pretrain_saver.restore(ses, pretrain)
            logger.info("==> Done restoring model! Restored %d variables.", len(variables_to_restore))

    config = tf.ConfigProto(device_filters=["/job:ps", "/job:worker/task:{}/cpu:0".format(args.task)])
    logdir = os.path.join(args.log_dir, 'train')

    if use_tf12_api:
        summary_writer = tf.summary.FileWriter(logdir + "_%d" % args.task)
    else:
        summary_writer = tf.train.SummaryWriter(logdir + "_%d" % args.task)

    logger.info("Events directory: %s_%s", logdir, args.task)
    sv = tf.train.Supervisor(is_chief=(args.task == 0),
                             logdir=logdir,
                             saver=saver,
                             summary_op=None,
                             init_op=init_op,
                             init_fn=init_fn,
                             summary_writer=summary_writer,
                             ready_op=tf.report_uninitialized_variables(variables_to_save),
                             global_step=trainer.global_step,
                             save_model_secs=30,
                             save_summaries_secs=30)

    num_global_steps = constants['MAX_GLOBAL_STEPS']

    logger.info(
        "Starting session. If this hangs, we're mostly likely waiting to connect to the parameter server. " +
        "One common cause is that the parameter server DNS name isn't resolving yet, or is misspecified.")
    with sv.managed_session(server.target, config=config) as sess, sess.as_default():
        # Workaround for FailedPreconditionError
        # see: https://github.com/openai/universe-starter-agent/issues/44 and 31
        sess.run(trainer.sync)

        trainer.start(sess, summary_writer)
        global_step = sess.run(trainer.global_step)
        logger.info("Starting training at gobal_step=%d", global_step)
        while not sv.should_stop() and (not num_global_steps or global_step < num_global_steps):
            trainer.process(sess)
            global_step = sess.run(trainer.global_step)

    # Ask for all the services to stop.
    sv.stop()
    logger.info('reached %s steps. worker stopped.', global_step)
Example #30
0
def playMountain():
    a3c = A3C('MountainCar-v0')
    a3c.play()
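
# A minimal usage sketch; it assumes the A3C class in this example takes a Gym
# environment id and that play() loads a trained policy and renders episodes
# (both assumptions about code not shown here):
if __name__ == '__main__':
    playMountain()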