Example #1
File: a3c.py Project: CN-TU/remy
def create_training_thread(training, delay_delta):
    # print("Oh yeah, creating thraining thread!!!")
    logging.info(" ".join(map(str, ("training", training))))
    global global_t, global_thread_index, wall_t, sess
    if len(idle_threads) == 0:
        logging.info(" ".join(
            map(str, ("Creating new thread", global_thread_index))))
        if cooperative:
            created_thread = A3CTrainingThread(
                global_thread_index, global_network, initial_learning_rate,
                learning_rate_input, grad_applier, MAX_TIME_STEP, device,
                training, cooperative, delay_delta)
        else:
            created_thread = A3CTrainingThread(
                global_thread_index, global_network[-global_thread_index],
                initial_learning_rate, learning_rate_input, grad_applier,
                MAX_TIME_STEP, device, training, cooperative, delay_delta)
        training_threads[global_thread_index] = created_thread
        return_index = global_thread_index
        global_thread_index += 1
        initialize_uninitialized(sess)
    else:
        return_index = idle_threads.pop()
        logging.info(" ".join(map(str, ("Recycling thread", return_index))))
        created_thread = training_threads[return_index]

    # set start time
    # start_time = time.time() - wall_t
    # created_thread.set_start_time(start_time)
    created_thread.episode_count = 0

    created_thread.time_differences = []
    created_thread.windows = []
    created_thread.states = []
    created_thread.actions = []
    created_thread.ticknos = []
    created_thread.rewards = []
    created_thread.values = []
    created_thread.estimated_values = []
    created_thread.start_lstm_states = []
    created_thread.variable_snapshots = []
    created_thread.local_t = 0
    created_thread.episode_reward_throughput = 0
    created_thread.episode_reward_delay = 0
    created_thread.episode_reward_sent = 0
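    # Sync the thread's local network from the global network and reset its
    # state before training starts.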
    sess.run(created_thread.sync)
    created_thread.reset_state_and_reinitialize(sess)

    return return_index
Example #2
 def config(self):
     initial_learning_rate = FLAGS.init_lr
     learning_rate_input = tf.placeholder("float")
     grad_applier = RMSPropApplier(learning_rate=learning_rate_input,
                                   decay=FLAGS.rmsp_alpha,
                                   momentum=0.0,
                                   epsilon=FLAGS.rmsp_epsilon,
                                   clip_norm=FLAGS.grad_norm_clip,
                                   device=constants.device)
     # Network will create PNN
     self.global_network = Network(
         name="core_{}".format(constants.task_name))
     IPython.embed()
     print("CREATING AGENTS")
     for i in range(FLAGS.threads):
         training_thread = A3CTrainingThread(i,
                                             self.global_network,
                                             initial_learning_rate,
                                             learning_rate_input,
                                             grad_applier,
                                             FLAGS.global_t_max,
                                             device=constants.device,
                                             sess=self.sess,
                                             name="agent_{}_{}".format(
                                                 constants.task_name, i))
         self.training_threads.append(training_thread)
     init = tf.global_variables_initializer()
     self.sess.run(init)
Example #3
    def make_thread_obj(index):

        env = make_unity_env(i, unity_baseport=unity_baseport)

        # horrible naming here due to lack of full encapsulation of thread operations in the so-named thread class
        # TODO: when refactoring, move thread management and loop state in to the class
        obj = A3CTrainingThread(i,
                                global_network,
                                args.initial_learning_rate,
                                learning_rate_input,
                                grad_applier,
                                args.max_time_step,
                                device=device,
                                environment=env)  # gym.make(args.gym_env))

        obj.port = env.port  # slight hack: cache port value with the thread
        # TODO: include unity baseport
        # print "created training net %d" % i
        # training_thread_objs.append(obj)
        training_thread_objs[index] = obj
Example #4
    def start(self):
        self.global_t = 0

        if not os.path.exists(os.path.join(FLAGS.model_dir, "images")):
            os.makedirs(os.path.join(FLAGS.model_dir, "images"))

        constants.device = "/cpu:0"
        if FLAGS.use_gpu:
            constants.device = "/gpu:0"

        initial_learning_rate = FLAGS.init_lr

        self.stop_requested = False

        # prepare session
        self.sess = tf.Session(config=tf.ConfigProto(
            log_device_placement=False, allow_soft_placement=True))

        self.global_network = Network(name="core_%s" % constants.task_name)

        self.training_threads = []

        learning_rate_input = tf.placeholder("float")

        grad_applier = RMSPropApplier(learning_rate=learning_rate_input,
                                      decay=FLAGS.rmsp_alpha,
                                      momentum=0.0,
                                      epsilon=FLAGS.rmsp_epsilon,
                                      clip_norm=FLAGS.grad_norm_clip,
                                      device=constants.device)

        print("CREATING AGENTS")

        for i in range(FLAGS.threads):
            training_thread = A3CTrainingThread(i,
                                                self.global_network,
                                                initial_learning_rate,
                                                learning_rate_input,
                                                grad_applier,
                                                FLAGS.global_t_max,
                                                device=constants.device,
                                                sess=self.sess,
                                                name="agent_%s_%i" %
                                                (constants.task_name, i))
            self.training_threads.append(training_thread)

        init = tf.initialize_all_variables()
        self.sess.run(init)

        # summary for tensorboard
        self.score_input = tf.placeholder(tf.int32)
        tf.scalar_summary("score", self.score_input)

        self.summary_op = tf.merge_all_summaries()
        self.summary_writer = tf.train.SummaryWriter(FLAGS.model_dir,
                                                     self.sess.graph_def)

        # init or load checkpoint with saver
        self.saver = tf.train.Saver()

        checkpoint = tf.train.get_checkpoint_state(FLAGS.model_dir)
        if checkpoint and checkpoint.model_checkpoint_path:
            self.saver.restore(
                self.sess, checkpoint.model_checkpoint_path
                if FLAGS.checkpoint is None else FLAGS.checkpoint)
            print("checkpoint loaded:", checkpoint.model_checkpoint_path)
            tokens = checkpoint.model_checkpoint_path.split("-")
            # set global step
            self.global_t = int(tokens[-1])
            print(">>> global step set: ", self.global_t)
            # set wall time
            self.wall_t_fname = FLAGS.model_dir + '/' + 'wall_t.' + str(
                self.global_t)
            with open(self.wall_t_fname, 'r') as f:
                self.wall_t = float(f.read())
        else:
            print("Could not find old checkpoint")
            # set wall time
            self.wall_t = 0.0

        if FLAGS.transfer_model is not None:
            self.global_network.load(self.sess, FLAGS.transfer_model)

        train_threads = []
        for i in range(FLAGS.threads):
            train_threads.append(
                threading.Thread(target=self.train_function, args=(i, )))

        signal.signal(signal.SIGINT, self.signal_handler)

        # set start time
        self.start_time = time.time() - self.wall_t

        for t in train_threads:
            t.start()

        print('Press Ctrl+C to stop')
        signal.pause()

        print('Now saving data. Please wait')

        for t in train_threads:
            t.join()

        self.save()
Example #5
learning_rate_input = tf.placeholder("float")

grad_applier = RMSPropApplier(learning_rate=learning_rate_input,
                              decay=RMSP_ALPHA,
                              momentum=0.0,
                              epsilon=RMSP_EPSILON,
                              clip_norm=GRAD_NORM_CLIP,
                              device=device)

training_threads = []
for i in range(PARALLEL_SIZE):
    training_thread = A3CTrainingThread(i,
                                        global_network,
                                        1.0,
                                        learning_rate_input,
                                        grad_applier,
                                        8000000,
                                        device=device)
    training_threads.append(training_thread)

sess = tf.Session()
init = tf.initialize_all_variables()
sess.run(init)

saver = tf.train.Saver()
checkpoint = tf.train.get_checkpoint_state(CHECKPOINT_DIR)
if checkpoint and checkpoint.model_checkpoint_path:
    saver.restore(sess, checkpoint.model_checkpoint_path)
    print "checkpoint loaded:", checkpoint.model_checkpoint_path
else:
Example #6
def run_a3c_test(args):
    """Run A3C testing."""
    GYM_ENV_NAME = args.gym_env.replace('-', '_')

    if args.use_gpu:
        assert args.cuda_devices != ''
        os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_devices
    else:
        os.environ['CUDA_VISIBLE_DEVICES'] = ''
    import tensorflow as tf

    if not os.path.exists('results/a3c'):
        os.makedirs('results/a3c')

    if args.folder is not None:
        folder = args.folder
    else:
        folder = 'results/a3c/{}'.format(GYM_ENV_NAME)
        end_str = ''

        if args.use_mnih_2015:
            end_str += '_mnih2015'
        if args.use_lstm:
            end_str += '_lstm'
        if args.unclipped_reward:
            end_str += '_rawreward'
        elif args.log_scale_reward:
            end_str += '_logreward'
        if args.transformed_bellman:
            end_str += '_transformedbell'

        if args.use_transfer:
            end_str += '_transfer'
            if args.not_transfer_conv2:
                end_str += '_noconv2'
            elif args.not_transfer_conv3 and args.use_mnih_2015:
                end_str += '_noconv3'
            elif args.not_transfer_fc1:
                end_str += '_nofc1'
            elif args.not_transfer_fc2:
                end_str += '_nofc2'
        if args.finetune_upper_layers_only:
            end_str += '_tune_upperlayers'
        if args.train_with_demo_num_steps > 0 \
           or args.train_with_demo_num_epochs > 0:
            end_str += '_pretrain_ina3c'
        if args.use_demo_threads:
            end_str += '_demothreads'

        if args.load_pretrained_model:
            if args.use_pretrained_model_as_advice:
                end_str += '_modelasadvice'
            if args.use_pretrained_model_as_reward_shaping:
                end_str += '_modelasshaping'

        if args.padding == 'SAME':
            end_str += '_same'

        folder += end_str

    folder = pathlib.Path(folder)

    demo_memory_cam = None
    demo_cam_human = False
    if args.load_demo_cam:
        if args.demo_memory_folder is not None:
            demo_memory_folder = args.demo_memory_folder
        else:
            demo_memory_folder = 'collected_demo/{}'.format(GYM_ENV_NAME)

        demo_memory_folder = pathlib.Path(demo_memory_folder)

        if args.demo_cam_id is not None:
            demo_cam_human = True
            demo_cam, _, total_rewards_cam, _ = load_memory(
                name=None,
                demo_memory_folder=demo_memory_folder,
                demo_ids=args.demo_cam_id,
                imgs_normalized=False)

            demo_cam = demo_cam[int(args.demo_cam_id)]
            logger.info("loaded demo {} for testing CAM".format(
                args.demo_cam_id))

        else:
            demo_cam_folder = pathlib.Path(args.demo_cam_folder)
            demo_cam = ReplayMemory()
            demo_cam.load(name='test_cam', folder=demo_cam_folder)
            logger.info("loaded demo {} for testing CAM".format(
                str(demo_cam_folder / 'test_cam')))

        demo_memory_cam = np.zeros(
            (len(demo_cam),
             demo_cam.height,
             demo_cam.width,
             demo_cam.phi_length),
            dtype=np.float32)

        for i in range(len(demo_cam)):
            s0, _, _, _, _, _, t1, _ = demo_cam[i]
            demo_memory_cam[i] = np.copy(s0)

        del demo_cam

    device = "/cpu:0"
    gpu_options = None
    if args.use_gpu:
        device = "/gpu:"+os.environ["CUDA_VISIBLE_DEVICES"]
        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=args.gpu_fraction)

    initial_learning_rate = args.initial_learn_rate
    logger.info('Initial Learning Rate={}'.format(initial_learning_rate))
    time.sleep(2)

    global_t = 0
    stop_requested = False

    game_state = GameState(env_id=args.gym_env)
    action_size = game_state.env.action_space.n

    config = tf.ConfigProto(
        gpu_options=gpu_options,
        log_device_placement=False,
        allow_soft_placement=True)

    input_shape = (84, 84, 4) if args.padding == 'VALID' else (88, 88, 4)
    if args.use_lstm:
        GameACLSTMNetwork.use_mnih_2015 = args.use_mnih_2015
        global_network = GameACLSTMNetwork(action_size, -1, device)
    else:
        GameACFFNetwork.use_mnih_2015 = args.use_mnih_2015
        global_network = GameACFFNetwork(
            action_size, -1, device, padding=args.padding,
            in_shape=input_shape)

    learning_rate_input = tf.placeholder(tf.float32, shape=(), name="opt_lr")

    grad_applier = tf.train.RMSPropOptimizer(
        learning_rate=learning_rate_input,
        decay=args.rmsp_alpha,
        epsilon=args.rmsp_epsilon)

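    # Shared hyperparameters are configured as class-level attributes of
    # A3CTrainingThread before the testing thread is constructed.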
    A3CTrainingThread.log_interval = args.log_interval
    A3CTrainingThread.performance_log_interval = args.performance_log_interval
    A3CTrainingThread.local_t_max = args.local_t_max
    A3CTrainingThread.demo_t_max = args.demo_t_max
    A3CTrainingThread.use_lstm = args.use_lstm
    A3CTrainingThread.action_size = action_size
    A3CTrainingThread.entropy_beta = args.entropy_beta
    A3CTrainingThread.demo_entropy_beta = args.demo_entropy_beta
    A3CTrainingThread.gamma = args.gamma
    A3CTrainingThread.use_mnih_2015 = args.use_mnih_2015
    A3CTrainingThread.env_id = args.gym_env
    A3CTrainingThread.finetune_upper_layers_only = \
        args.finetune_upper_layers_only
    A3CTrainingThread.transformed_bellman = args.transformed_bellman
    A3CTrainingThread.clip_norm = args.grad_norm_clip
    A3CTrainingThread.use_grad_cam = args.use_grad_cam

    if args.unclipped_reward:
        A3CTrainingThread.reward_type = "RAW"
    elif args.log_scale_reward:
        A3CTrainingThread.reward_type = "LOG"
    else:
        A3CTrainingThread.reward_type = "CLIP"

    if args.use_lstm:
        local_network = GameACLSTMNetwork(action_size, 0, device)
    else:
        local_network = GameACFFNetwork(
            action_size, 0, device, padding=args.padding,
            in_shape=input_shape)

    testing_thread = A3CTrainingThread(
        0, global_network, local_network, initial_learning_rate,
        learning_rate_input,
        grad_applier, 0,
        device=device)

    # prepare session
    sess = tf.Session(config=config)

    if args.use_transfer:
        if args.transfer_folder is not None:
            transfer_folder = args.transfer_folder
        else:
            transfer_folder = 'results/pretrain_models/{}'.format(GYM_ENV_NAME)
            end_str = ''

            if args.use_mnih_2015:
                end_str += '_mnih2015'
            end_str += '_l2beta1E-04_batchprop'  # TODO: make this an argument
            transfer_folder += end_str

        transfer_folder = pathlib.Path(transfer_folder)
        transfer_folder /= 'transfer_model'

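        # Pick which layers' weights to restore from the transfer model; the
        # not_transfer_* flags exclude layers from the transferred set.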
        if args.not_transfer_conv2:
            transfer_var_list = [
                global_network.W_conv1,
                global_network.b_conv1,
                ]

        elif (args.not_transfer_conv3 and args.use_mnih_2015):
            transfer_var_list = [
                global_network.W_conv1,
                global_network.b_conv1,
                global_network.W_conv2,
                global_network.b_conv2,
                ]

        elif args.not_transfer_fc1:
            transfer_var_list = [
                global_network.W_conv1,
                global_network.b_conv1,
                global_network.W_conv2,
                global_network.b_conv2,
                ]

            if args.use_mnih_2015:
                transfer_var_list += [
                    global_network.W_conv3,
                    global_network.b_conv3,
                    ]

        elif args.not_transfer_fc2:
            transfer_var_list = [
                global_network.W_conv1,
                global_network.b_conv1,
                global_network.W_conv2,
                global_network.b_conv2,
                global_network.W_fc1,
                global_network.b_fc1,
                ]

            if args.use_mnih_2015:
                transfer_var_list += [
                    global_network.W_conv3,
                    global_network.b_conv3,
                    ]

        else:
            transfer_var_list = [
                global_network.W_conv1,
                global_network.b_conv1,
                global_network.W_conv2,
                global_network.b_conv2,
                global_network.W_fc1,
                global_network.b_fc1,
                global_network.W_fc2,
                global_network.b_fc2,
                ]

            if args.use_mnih_2015:
                transfer_var_list += [
                    global_network.W_conv3,
                    global_network.b_conv3,
                    ]

        global_network.load_transfer_model(
            sess, folder=transfer_folder,
            not_transfer_fc2=args.not_transfer_fc2,
            not_transfer_fc1=args.not_transfer_fc1,
            not_transfer_conv3=(args.not_transfer_conv3
                                and args.use_mnih_2015),
            not_transfer_conv2=args.not_transfer_conv2,
            var_list=transfer_var_list,
            )

    def initialize_uninitialized(sess):
        global_vars = tf.global_variables()
        is_not_initialized = sess.run(
            [tf.is_variable_initialized(var) for var in global_vars])
        not_initialized_vars = [
            v for (v, f) in zip(global_vars, is_not_initialized) if not f]

        if len(not_initialized_vars):
            sess.run(tf.variables_initializer(not_initialized_vars))

    if args.use_transfer:
        initialize_uninitialized(sess)
    else:
        sess.run(tf.global_variables_initializer())

    # init or load checkpoint with saver
    root_saver = tf.train.Saver(max_to_keep=1)
    checkpoint = tf.train.get_checkpoint_state(str(folder))
    if checkpoint and checkpoint.model_checkpoint_path:
        root_saver.restore(sess, checkpoint.model_checkpoint_path)
        logger.info("checkpoint loaded:{}".format(
            checkpoint.model_checkpoint_path))
        tokens = checkpoint.model_checkpoint_path.split("-")
        # set global step
        global_t = int(tokens[-1])
        logger.info(">>> global step set: {}".format(global_t))
    else:
        logger.warning("Could not find old checkpoint")

    def test_function():
        nonlocal global_t

        if args.use_transfer:
            from_folder = str(transfer_folder).split('/')[-2]
        else:
            from_folder = str(folder).split('/')[-1]

        from_folder = pathlib.Path(from_folder)
        save_folder = 'results/test_model/a3c' / from_folder
        prepare_dir(str(save_folder), empty=False)
        prepare_dir(str(save_folder / 'frames'), empty=False)

        # Evaluate model before training
        if not stop_requested:
            testing_thread.testing_model(
                sess, args.eval_max_steps, global_t, save_folder,
                demo_memory_cam=demo_memory_cam, demo_cam_human=demo_cam_human)

    def signal_handler(signal, frame):
        nonlocal stop_requested
        logger.info('You pressed Ctrl+C!')
        stop_requested = True

        if stop_requested and global_t == 0:
            sys.exit(1)

    test_thread = threading.Thread(target=test_function, args=())

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    test_thread.start()

    print('Press Ctrl+C to stop')

    test_thread.join()

    sess.close()
Example #7
import tensorflow as tf
import matplotlib.pyplot as plt
import random

from game_state import GameState
from game_ac_network import GameACNetwork
from a3c_training_thread import A3CTrainingThread
from constants import ACTION_SIZE

PARALLEL_SIZE = 8
CHECKPOINT_DIR = 'checkpoints'

global_network = GameACNetwork(ACTION_SIZE)

training_threads = []
for i in range(PARALLEL_SIZE):
    training_thread = A3CTrainingThread(i, global_network, 1.0, 8000000)
    training_threads.append(training_thread)

sess = tf.Session()
init = tf.initialize_all_variables()
sess.run(init)

saver = tf.train.Saver()
checkpoint = tf.train.get_checkpoint_state(CHECKPOINT_DIR)
if checkpoint and checkpoint.model_checkpoint_path:
    saver.restore(sess, checkpoint.model_checkpoint_path)
    print "checkpoint loaded:", checkpoint.model_checkpoint_path
else:
    print "Could not find old checkpoint"

W_conv1 = sess.run(global_network.W_conv1)
Example #8
learning_rate_input = tf.placeholder("float")

policy_applier = RMSPropApplier(learning_rate=learning_rate_input,
                                decay=0.99,
                                momentum=0.0,
                                epsilon=RMSP_EPSILON)

value_applier = RMSPropApplier(learning_rate=learning_rate_input,
                               decay=0.99,
                               momentum=0.0,
                               epsilon=RMSP_EPSILON)

training_threads = []
for i in range(PARALLEL_SIZE):
    training_thread = A3CTrainingThread(i, global_network, 1.0,
                                        learning_rate_input, policy_applier,
                                        value_applier, 8000000)
    training_threads.append(training_thread)

sess = tf.Session()
init = tf.initialize_all_variables()
sess.run(init)

saver = tf.train.Saver()
checkpoint = tf.train.get_checkpoint_state(CHECKPOINT_DIR)
if checkpoint and checkpoint.model_checkpoint_path:
    saver.restore(sess, checkpoint.model_checkpoint_path)
    print "checkpoint loaded:", checkpoint.model_checkpoint_path
else:
    print "Could not find old checkpoint"
Example #9
learning_rate_input = tf.placeholder("float")

policy_applier = RMSPropApplier(learning_rate=learning_rate_input,
                                decay=0.99,
                                momentum=0.0,
                                epsilon=RMSP_EPSILON)

value_applier = RMSPropApplier(learning_rate=learning_rate_input,
                               decay=0.99,
                               momentum=0.0,
                               epsilon=RMSP_EPSILON)

for i in range(PARALLEL_SIZE):
    training_thread = A3CTrainingThread(i, global_network,
                                        initial_learning_rate,
                                        learning_rate_input, policy_applier,
                                        value_applier, MAX_TIME_STEP)
    training_threads.append(training_thread)

# prepare session
sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))

init = tf.initialize_all_variables()
sess.run(init)

# summary for tensorboard
summary_op = tf.merge_all_summaries()
summary_writer = tf.train.SummaryWriter(LOG_FILE, sess.graph_def)

# init or load checkpoint with saver
saver = tf.train.Saver()
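
The snippet above stops right after creating the saver. In the related examples (see Example #4), the configured training threads are then each driven from a host Python thread. The sketch below illustrates that launch pattern using the same names as above; note that train_function and the thread's process() method are assumptions added here for illustration and are not shown in this snippet.

import threading

global_t = 0

def train_function(parallel_index):
    """Drive one A3CTrainingThread until the global step budget is reached."""
    global global_t
    training_thread = training_threads[parallel_index]
    while global_t < MAX_TIME_STEP:
        # process() is assumed to run one local rollout/update and return the
        # number of global steps it advanced (hypothetical signature).
        diff_global_t = training_thread.process(sess, global_t)
        global_t += diff_global_t

host_threads = [threading.Thread(target=train_function, args=(i,))
                for i in range(PARALLEL_SIZE)]
for t in host_threads:
    t.start()
for t in host_threads:
    t.join()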
Example #10
def run_a3c(args):
    """Run A3C experiment."""
    GYM_ENV_NAME = args.gym_env.replace('-', '_')
    GAME_NAME = args.gym_env.replace('NoFrameskip-v4','')

    # setup folder name and path to folder
    folder = pathlib.Path(setup_folder(args, GYM_ENV_NAME))

    # setup GPU (if applicable)
    import tensorflow as tf
    gpu_options = setup_gpu(tf, args.use_gpu, args.gpu_fraction)

    ######################################################
    # setup default device
    device = "/cpu:0"

    global_t = 0
    rewards = {'train': {}, 'eval': {}}
    best_model_reward = -(sys.maxsize)
    if args.load_pretrained_model:
        class_rewards = {'class_eval': {}}

    # setup logging info for analysis, see Section 4.2 of the paper
    sil_dict = {
                # count number of SIL updates
                "sil_ctr":{},
                # total number of buffer D samples drawn during SIL
                "sil_a3c_sampled":{},
                # total number of buffer D samples (i.e., generated by A3C workers) used during SIL (i.e., passed max op)
                "sil_a3c_used":{},
                # the return of used samples for buffer D
                "sil_a3c_used_return":{},
                # total number of buffer R sampled during SIL
                "sil_rollout_sampled":{},
                # total number of buffer R samples (i.e., generated by refresher worker) used during SIL (i.e., passed max op)
                "sil_rollout_used":{},
                # the return of used samples for buffer R
                "sil_rollout_used_return":{},
                # number of old samples still used (even after refreshing)
                "sil_old_used":{}
                }
    sil_ctr, sil_a3c_sampled, sil_a3c_used, sil_a3c_used_return = 0, 0, 0, 0
    sil_rollout_sampled, sil_rollout_used, sil_rollout_used_return = 0, 0, 0
    sil_old_used = 0


    rollout_dict = {
                    # total number of rollout performed
                    "rollout_ctr": {},
                    # total number of successful rollout (i.e., Gnew > G)
                    "rollout_added_ctr":{},
                    # the return of Gnew
                    "rollout_new_return":{},
                    # the return of G
                    "rollout_old_return":{}
                    }
    rollout_ctr, rollout_added_ctr = 0, 0
    rollout_new_return = 0 # this records the total, avg = this / rollout_added_ctr
    rollout_old_return = 0 # this records the total, avg = this / rollout_added_ctr

    # setup file names
    reward_fname = folder / '{}-a3c-rewards.pkl'.format(GYM_ENV_NAME)
    sil_fname = folder / '{}-a3c-dict-sil.pkl'.format(GYM_ENV_NAME)
    rollout_fname = folder / '{}-a3c-dict-rollout.pkl'.format(GYM_ENV_NAME)
    if args.load_pretrained_model:
        class_reward_fname = folder / '{}-class-rewards.pkl'.format(GYM_ENV_NAME)

    sharedmem_fname = folder / '{}-sharedmem.pkl'.format(GYM_ENV_NAME)
    sharedmem_params_fname = folder / '{}-sharedmem-params.pkl'.format(GYM_ENV_NAME)
    sharedmem_trees_fname = folder / '{}-sharedmem-trees.pkl'.format(GYM_ENV_NAME)

    rolloutmem_fname = folder / '{}-rolloutmem.pkl'.format(GYM_ENV_NAME)
    rolloutmem_params_fname = folder / '{}-rolloutmem-params.pkl'.format(GYM_ENV_NAME)
    rolloutmem_trees_fname = folder / '{}-rolloutmem-trees.pkl'.format(GYM_ENV_NAME)

    # for removing older ckpt, save mem space
    prev_ckpt_t = -1

    stop_req = False

    game_state = GameState(env_id=args.gym_env)
    action_size = game_state.env.action_space.n
    game_state.close()
    del game_state.env
    del game_state

    input_shape = (args.input_shape, args.input_shape, 4)
    #######################################################
    # setup global A3C
    GameACFFNetwork.use_mnih_2015 = args.use_mnih_2015
    global_network = GameACFFNetwork(
        action_size, -1, device, padding=args.padding,
        in_shape=input_shape)
    logger.info('A3C Initial Learning Rate={}'.format(args.initial_learn_rate))

    # setup pretrained model
    global_pretrained_model = None
    local_pretrained_model = None
    pretrain_graph = None

    # if use pretrained model to refresh
    # then must load pretrained model
    # otherwise, don't load model
    if args.use_lider and args.nstep_bc > 0:
        assert args.load_pretrained_model, "refreshing with other policies, must load a pre-trained model (TA or BC)"
    else:
        assert not args.load_pretrained_model, "refreshing with the current policy, don't load pre-trained models"

    if args.load_pretrained_model:
        pretrain_graph, global_pretrained_model = setup_pretrained_model(tf,
            args, action_size, input_shape,
            device="/gpu:0" if args.use_gpu else device)
        assert global_pretrained_model is not None
        assert pretrain_graph is not None

    time.sleep(2.0)

    # setup experience memory
    shared_memory = None  # => this is Buffer D
    rollout_buffer = None  # => this is Buffer R
    if args.use_sil:
        shared_memory = SILReplayMemory(
            action_size, max_len=args.memory_length, gamma=args.gamma,
            clip=False if args.unclipped_reward else True,
            height=input_shape[0], width=input_shape[1],
            phi_length=input_shape[2], priority=args.priority_memory,
            reward_constant=args.reward_constant)

        if args.use_lider and not args.onebuffer:
            rollout_buffer = SILReplayMemory(
                action_size, max_len=args.memory_length, gamma=args.gamma,
                clip=False if args.unclipped_reward else True,
                height=input_shape[0], width=input_shape[1],
                phi_length=input_shape[2], priority=args.priority_memory,
                reward_constant=args.reward_constant)

        # log memory information
        shared_memory.log()
        if args.use_lider and not args.onebuffer:
            rollout_buffer.log()

    ############## Setup Thread Workers BEGIN ################
    # 17 threads in total for all experiments
    assert args.parallel_size == 17, "use 17 workers for all experiments"

    startIndex = 0
    all_workers = []

    # a3c and sil learning rate and optimizer
    learning_rate_input = tf.placeholder(tf.float32, shape=(), name="opt_lr")
    grad_applier = tf.train.RMSPropOptimizer(
        learning_rate=learning_rate_input,
        decay=args.rmsp_alpha,
        epsilon=args.rmsp_epsilon)

    setup_common_worker(CommonWorker, args, action_size)

    # setup SIL worker
    sil_worker = None
    if args.use_sil:
        _device = "/gpu:0" if args.use_gpu else device

        sil_network = GameACFFNetwork(
            action_size, startIndex, device=_device,
            padding=args.padding, in_shape=input_shape)

        sil_worker = SILTrainingThread(startIndex, global_network, sil_network,
            args.initial_learn_rate,
            learning_rate_input,
            grad_applier, device=_device,
            batch_size=args.batch_size,
            use_rollout=args.use_lider,
            one_buffer=args.onebuffer,
            sampleR=args.sampleR)

        all_workers.append(sil_worker)
        startIndex += 1

    # setup refresh worker
    refresh_worker = None
    if args.use_lider:
        _device = "/gpu:0" if args.use_gpu else device

        refresh_network = GameACFFNetwork(
            action_size, startIndex, device=_device,
            padding=args.padding, in_shape=input_shape)

        refresh_local_pretrained_model = None
        # if refreshing with other policies
        if args.nstep_bc > 0:
            refresh_local_pretrained_model = PretrainedModelNetwork(
                pretrain_graph, action_size, startIndex,
                padding=args.padding,
                in_shape=input_shape, sae=False,
                tied_weights=False,
                use_denoising=False,
                noise_factor=0.3,
                loss_function='mse',
                use_slv=False, device=_device)

        refresh_worker = RefreshThread(
            thread_index=startIndex, action_size=action_size, env_id=args.gym_env,
            global_a3c=global_network, local_a3c=refresh_network,
            update_in_rollout=args.update_in_rollout, nstep_bc=args.nstep_bc,
            global_pretrained_model=global_pretrained_model,
            local_pretrained_model=refresh_local_pretrained_model,
            transformed_bellman=args.transformed_bellman,
            device=_device,
            entropy_beta=args.entropy_beta, clip_norm=args.grad_norm_clip,
            grad_applier=grad_applier,
            initial_learn_rate=args.initial_learn_rate,
            learning_rate_input=learning_rate_input)

        all_workers.append(refresh_worker)
        startIndex += 1

    # setup a3c workers
    setup_a3c_worker(A3CTrainingThread, args, startIndex)
    for i in range(startIndex, args.parallel_size):
        local_network = GameACFFNetwork(
            action_size, i, device="/cpu:0",
            padding=args.padding,
            in_shape=input_shape)

        a3c_worker = A3CTrainingThread(
            i, global_network, local_network,
            args.initial_learn_rate, learning_rate_input, grad_applier,
            device="/cpu:0", no_op_max=30)

        all_workers.append(a3c_worker)
    ############## Setup Thread Workers END ################

    # setup config for tensorflow
    config = tf.ConfigProto(
        gpu_options=gpu_options,
        log_device_placement=False,
        allow_soft_placement=True)

    # prepare sessions
    sess = tf.Session(config=config)
    pretrain_sess = None
    if global_pretrained_model:
        pretrain_sess = tf.Session(config=config, graph=pretrain_graph)

    # initial pretrained model
    if pretrain_sess:
        assert args.pretrained_model_folder is not None
        global_pretrained_model.load(
            pretrain_sess,
            args.pretrained_model_folder)

    sess.run(tf.global_variables_initializer())
    if global_pretrained_model:
        initialize_uninitialized(tf, pretrain_sess,
                                 global_pretrained_model)
    if local_pretrained_model:
        initialize_uninitialized(tf, pretrain_sess,
                                 local_pretrained_model)

    # summary writer for tensorboard
    summ_file = args.save_to+'log/a3c/{}/'.format(GYM_ENV_NAME) + str(folder)[58:] # str(folder)[12:]
    summary_writer = tf.summary.FileWriter(summ_file, sess.graph)

    # init or load checkpoint with saver
    root_saver = tf.train.Saver(max_to_keep=1)
    saver = tf.train.Saver(max_to_keep=1)
    best_saver = tf.train.Saver(max_to_keep=1)

    checkpoint = tf.train.get_checkpoint_state(str(folder)+'/model_checkpoints')
    if checkpoint and checkpoint.model_checkpoint_path:
        root_saver.restore(sess, checkpoint.model_checkpoint_path)
        logger.info("checkpoint loaded:{}".format(
            checkpoint.model_checkpoint_path))
        tokens = checkpoint.model_checkpoint_path.split("-")
        # set global step
        global_t = int(tokens[-1])
        logger.info(">>> global step set: {}".format(global_t))

        tmp_t = (global_t // args.eval_freq) * args.eval_freq
        logger.info(">>> tmp_t: {}".format(tmp_t))

        # set wall time
        wall_t = 0.

        # set up reward files
        best_reward_file = folder / 'model_best/best_model_reward'
        with best_reward_file.open('r') as f:
            best_model_reward = float(f.read())

        # restore rewards
        rewards = restore_dict(reward_fname, global_t)
        logger.info(">>> restored: rewards")

        # restore loggings
        sil_dict = restore_dict(sil_fname, global_t)
        sil_ctr = sil_dict['sil_ctr'][tmp_t]
        sil_a3c_sampled = sil_dict['sil_a3c_sampled'][tmp_t]
        sil_a3c_used = sil_dict['sil_a3c_used'][tmp_t]
        sil_a3c_used_return = sil_dict['sil_a3c_used_return'][tmp_t]
        sil_rollout_sampled = sil_dict['sil_rollout_sampled'][tmp_t]
        sil_rollout_used = sil_dict['sil_rollout_used'][tmp_t]
        sil_rollout_used_return = sil_dict['sil_rollout_used_return'][tmp_t]
        sil_old_used = sil_dict['sil_old_used'][tmp_t]
        logger.info(">>> restored: sil_dict")

        rollout_dict = restore_dict(rollout_fname, global_t)
        rollout_ctr = rollout_dict['rollout_ctr'][tmp_t]
        rollout_added_ctr = rollout_dict['rollout_added_ctr'][tmp_t]
        rollout_new_return = rollout_dict['rollout_new_return'][tmp_t]
        rollout_old_return = rollout_dict['rollout_old_return'][tmp_t]
        logger.info(">>> restored: rollout_dict")

        if args.load_pretrained_model:
            class_reward_file = folder / '{}-class-rewards.pkl'.format(GYM_ENV_NAME)
            class_rewards = restore_dict(class_reward_file, global_t)

        # restore replay buffers (if saved)
        if args.checkpoint_buffer:
            # restore buffer D
            if args.use_sil and args.priority_memory:
                shared_memory = restore_buffer(sharedmem_fname, shared_memory, global_t)
                shared_memory = restore_buffer_trees(sharedmem_trees_fname, shared_memory, global_t)
                shared_memory = restore_buffer_params(sharedmem_params_fname, shared_memory, global_t)
                logger.info(">>> restored: shared_memory (Buffer D)")
                shared_memory.log()
                # restore buffer R
                if args.use_lider and not args.onebuffer:
                    rollout_buffer = restore_buffer(rolloutmem_fname, rollout_buffer, global_t)
                    rollout_buffer = restore_buffer_trees(rolloutmem_trees_fname, rollout_buffer, global_t)
                    rollout_buffer = restore_buffer_params(rolloutmem_params_fname, rollout_buffer, global_t)
                    logger.info(">>> restored: rollout_buffer (Buffer R)")
                    rollout_buffer.log()

        # if all restores okay, remove old ckpt to save storage space
        prev_ckpt_t = global_t

    else:
        logger.warning("Could not find old checkpoint")
        wall_t = 0.0
        prepare_dir(folder, empty=True)
        prepare_dir(folder / 'model_checkpoints', empty=True)
        prepare_dir(folder / 'model_best', empty=True)
        prepare_dir(folder / 'frames', empty=True)

    lock = threading.Lock()

    # next saving global_t
    def next_t(current_t, freq):
        return np.ceil((current_t + 0.00001) / freq) * freq
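    # e.g. next_t(3_100_000, 1_000_000) -> 4_000_000.0; for an exact multiple,
    # next_t(4_000_000, 1_000_000) -> 5_000_000.0 (the epsilon pushes it to the
    # next interval)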

    next_global_t = next_t(global_t, args.eval_freq)
    next_save_t = next_t(
        global_t, args.eval_freq*args.checkpoint_freq)

    step_t = 0

    def train_function(parallel_idx, th_ctr, ep_queue, net_updates):
        nonlocal global_t, step_t, rewards, class_rewards, lock, \
                 next_save_t, next_global_t, prev_ckpt_t
        nonlocal shared_memory, rollout_buffer
        nonlocal sil_dict, sil_ctr, sil_a3c_sampled, sil_a3c_used, sil_a3c_used_return, \
                 sil_rollout_sampled, sil_rollout_used, sil_rollout_used_return, \
                 sil_old_used
        nonlocal rollout_dict, rollout_ctr, rollout_added_ctr, \
                 rollout_new_return, rollout_old_return

        parallel_worker = all_workers[parallel_idx]
        parallel_worker.set_summary_writer(summary_writer)

        with lock:
            # Evaluate model before training
            if not stop_req and global_t == 0 and step_t == 0:
                rewards['eval'][step_t] = parallel_worker.testing(
                    sess, args.eval_max_steps, global_t, folder,
                    worker=all_workers[-1])

                # testing pretrained TA or BC in game
                if args.load_pretrained_model:
                    assert pretrain_sess is not None
                    assert global_pretrained_model is not None
                    class_rewards['class_eval'][step_t] = \
                        parallel_worker.test_loaded_classifier(global_t=global_t,
                                                    max_eps=50, # testing 50 episodes
                                                    sess=pretrain_sess,
                                                    worker=all_workers[-1],
                                                    model=global_pretrained_model)
                    # log pretrained model performance
                    class_eval_file = pathlib.Path(args.pretrained_model_folder[:21]+\
                        str(GAME_NAME)+"/"+str(GAME_NAME)+'-model-eval.txt')
                    class_std = np.std(class_rewards['class_eval'][step_t][-1])
                    class_mean = np.mean(class_rewards['class_eval'][step_t][-1])
                    with class_eval_file.open('w') as f:
                        f.write("class_mean: \n" + str(class_mean) + "\n")
                        f.write("class_std: \n" + str(class_std) + "\n")
                        f.write("class_rewards: \n" + str(class_rewards['class_eval'][step_t][-1]) + "\n")

                checkpt_file = folder / 'model_checkpoints'
                checkpt_file /= '{}_checkpoint'.format(GYM_ENV_NAME)
                saver.save(sess, str(checkpt_file), global_step=global_t)
                save_best_model(rewards['eval'][global_t][0])

                # saving worker info to dicts for analysis
                sil_dict['sil_ctr'][step_t] = sil_ctr
                sil_dict['sil_a3c_sampled'][step_t] = sil_a3c_sampled
                sil_dict['sil_a3c_used'][step_t] = sil_a3c_used
                sil_dict['sil_a3c_used_return'][step_t] = sil_a3c_used_return
                sil_dict['sil_rollout_sampled'][step_t] = sil_rollout_sampled
                sil_dict['sil_rollout_used'][step_t] = sil_rollout_used
                sil_dict['sil_rollout_used_return'][step_t] = sil_rollout_used_return
                sil_dict['sil_old_used'][step_t] = sil_old_used

                rollout_dict['rollout_ctr'][step_t] = rollout_ctr
                rollout_dict['rollout_added_ctr'][step_t] = rollout_added_ctr
                rollout_dict['rollout_new_return'][step_t] = rollout_new_return
                rollout_dict['rollout_old_return'][step_t] = rollout_old_return

                # dump pickle
                dump_pickle([rewards, sil_dict, rollout_dict],
                            [reward_fname, sil_fname, rollout_fname],
                            global_t)
                if args.load_pretrained_model:
                    dump_pickle([class_rewards], [class_reward_fname], global_t)

                logger.info('Dump pickle at step {}'.format(global_t))

                # save replay buffer (only works under priority mem)
                if args.checkpoint_buffer:
                    if shared_memory is not None and args.priority_memory:
                        params = [shared_memory.buff._next_idx, shared_memory.buff._max_priority]
                        trees = [shared_memory.buff._it_sum._value,
                                 shared_memory.buff._it_min._value]
                        dump_pickle([shared_memory.buff._storage, params, trees],
                                    [sharedmem_fname, sharedmem_params_fname, sharedmem_trees_fname],
                                    global_t)
                        logger.info('Saving shared_memory')

                    if rollout_buffer is not None and args.priority_memory:
                        params = [rollout_buffer.buff._next_idx, rollout_buffer.buff._max_priority]
                        trees = [rollout_buffer.buff._it_sum._value,
                                 rollout_buffer.buff._it_min._value]
                        dump_pickle([rollout_buffer.buff._storage, params, trees],
                                    [rolloutmem_fname, rolloutmem_params_fname, rolloutmem_trees_fname],
                                    global_t)
                        logger.info('Saving rollout_buffer')

                prev_ckpt_t = global_t

                step_t = 1

        # set start_time
        start_time = time.time() - wall_t
        parallel_worker.set_start_time(start_time)

        if parallel_worker.is_sil_thread:
            sil_interval = 0  # bigger number => slower SIL updates
            m_repeat = 4
            min_mem = args.batch_size * m_repeat
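            # only start SIL training once buffer D holds at least one full
            # batch for every training repeat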
            sil_train_flag = len(shared_memory) >= min_mem

        while True:
            if stop_req:
                return

            if global_t >= (args.max_time_step * args.max_time_step_fraction):
                return

            if parallel_worker.is_sil_thread:
                # before sil starts, init local count
                local_sil_ctr = 0
                local_sil_a3c_sampled, local_sil_a3c_used, local_sil_a3c_used_return = 0, 0, 0
                local_sil_rollout_sampled, local_sil_rollout_used, local_sil_rollout_used_return = 0, 0, 0
                local_sil_old_used = 0

                if net_updates.qsize() >= sil_interval \
                   and len(shared_memory) >= min_mem:
                    sil_train_flag = True

                if sil_train_flag:
                    sil_train_flag = False

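                    # th_ctr acts as a counting semaphore of idle workers:
                    # take a token while training and return it when done (the
                    # evaluation path below waits until all tokens are back).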
                    th_ctr.get()

                    train_out = parallel_worker.sil_train(
                        sess, global_t, shared_memory, m_repeat,
                        rollout_buffer=rollout_buffer)

                    local_sil_ctr, local_sil_a3c_sampled, local_sil_a3c_used, \
                       local_sil_a3c_used_return, \
                       local_sil_rollout_sampled, local_sil_rollout_used, \
                       local_sil_rollout_used_return, \
                       local_sil_old_used = train_out

                    th_ctr.put(1)

                    with net_updates.mutex:
                        net_updates.queue.clear()

                    if args.use_lider:
                        parallel_worker.record_sil(sil_ctr=sil_ctr,
                                              total_used=(sil_a3c_used + sil_rollout_used),
                                              num_a3c_used=sil_a3c_used,
                                              a3c_used_return=sil_a3c_used_return/(sil_a3c_used+1),#add one in case divide by zero
                                              rollout_used=sil_rollout_used,
                                              rollout_used_return=sil_rollout_used_return/(sil_rollout_used+1),
                                              old_used=sil_old_used,
                                              global_t=global_t)

                        if sil_ctr % 200 == 0 and sil_ctr > 0:
                            rollout_buffsize = 0
                            if not args.onebuffer:
                                rollout_buffsize = len(rollout_buffer)
                            log_data = (sil_ctr, len(shared_memory),
                                        rollout_buffsize,
                                        sil_a3c_used+sil_rollout_used,
                                        args.batch_size*sil_ctr,
                                        sil_a3c_used,
                                        sil_a3c_used_return/(sil_a3c_used+1),
                                        sil_rollout_used,
                                        sil_rollout_used_return/(sil_rollout_used+1),
                                        sil_old_used)
                            logger.info("SIL: sil_ctr={0:}"
                                        " sil_memory_size={1:}"
                                        " rollout_buffer_size={2:}"
                                        " total_sample_used={3:}/{4:}"
                                        " a3c_used={5:}"
                                        " a3c_used_return_avg={6:.2f}"
                                        " rollout_used={7:}"
                                        " rollout_used_return_avg={8:.2f}"
                                        " old_used={9:}".format(*log_data))
                    else:
                        parallel_worker.record_sil(sil_ctr=sil_ctr,
                                                   total_used=(sil_a3c_used + sil_rollout_used),
                                                   num_a3c_used=sil_a3c_used,
                                                   rollout_used=sil_rollout_used,
                                                   global_t=global_t)
                        if sil_ctr % 200 == 0 and sil_ctr > 0:
                            log_data = (sil_ctr, sil_a3c_used+sil_rollout_used,
                                        args.batch_size*sil_ctr,
                                        sil_a3c_used,
                                        len(shared_memory))
                            logger.info("SIL: sil_ctr={0:}"
                                        " total_sample_used={1:}/{2:}"
                                        " a3c_used={3:}"
                                        " sil_memory_size={4:}".format(*log_data))

                # Adding episodes to SIL memory is centralized to ensure that
                # sampling and updating of priorities do not become a problem,
                # since we add new episodes to SIL all at once and SIL memory
                # is guaranteed to be untouched during SIL training.
                max = args.parallel_size
                while not ep_queue.empty():
                    data = ep_queue.get()
                    parallel_worker.episode.set_data(*data)
                    shared_memory.extend(parallel_worker.episode)
                    parallel_worker.episode.reset()
                    max -= 1
                    if max <= 0: # This ensures that SIL has a chance to train
                        break

                diff_global_t = 0

                # centralized rollout counting
                local_rollout_ctr, local_rollout_added_ctr = 0, 0
                local_rollout_new_return, local_rollout_old_return = 0, 0

            elif parallel_worker.is_refresh_thread:
                # before refresh starts, init local count
                diff_global_t = 0
                local_rollout_ctr, local_rollout_added_ctr = 0, 0
                local_rollout_new_return, local_rollout_old_return = 0, 0

                if len(shared_memory) >= 1:
                    th_ctr.get()
                    # randomly sample a state from buffer D
                    sample = shared_memory.sample_one_random()
                    # after sample, flip refreshed to True
                    # TODO: fix this so that only a *successful* refresh is flipped to True;
                    # currently *all* refreshes are counted as True
                    assert sample[-1] == True

                    train_out = parallel_worker.rollout(sess, folder, pretrain_sess,
                                                        global_t, sample,
                                                        args.addall,
                                                        args.max_ep_step,
                                                        args.nstep_bc,
                                                        args.update_in_rollout)

                    diff_global_t, episode_end, part_end, local_rollout_ctr, \
                        local_rollout_added_ctr, add, local_rollout_new_return, \
                        local_rollout_old_return = train_out

                    th_ctr.put(1)

                    if rollout_ctr % 20 == 0 and rollout_ctr > 0:
                        log_msg = "ROLLOUT: rollout_ctr={} added_rollout_ct={} worker={}".format(
                        rollout_ctr, rollout_added_ctr, parallel_worker.thread_idx)
                        logger.info(log_msg)
                        logger.info("ROLLOUT Gnew: {}, G: {}".format(local_rollout_new_return,
                                                                     local_rollout_old_return))

                    # part_end should always be True here (end of episode);
                    # only add if the new return is better (unless LiDER-AddAll)
                    if part_end and add:
                        if not args.onebuffer:
                            # directly put into Buffer R
                            rollout_buffer.extend(parallel_worker.episode)
                        else:
                            # Buffer D add sample is centralized when OneBuffer
                            ep_queue.put(parallel_worker.episode.get_data())

                    parallel_worker.episode.reset()

                # centralized SIL counting
                local_sil_ctr = 0
                local_sil_a3c_sampled, local_sil_a3c_used, local_sil_a3c_used_return = 0, 0, 0
                local_sil_rollout_sampled, local_sil_rollout_used, local_sil_rollout_used_return = 0, 0, 0
                local_sil_old_used = 0

            # a3c training thread worker
            else:
                th_ctr.get()

                train_out = parallel_worker.train(sess, global_t, rewards)
                diff_global_t, episode_end, part_end = train_out

                th_ctr.put(1)

                if args.use_sil:
                    net_updates.put(1)
                    if part_end:
                        ep_queue.put(parallel_worker.episode.get_data())
                        parallel_worker.episode.reset()

                # centralized SIL counting
                local_sil_ctr = 0
                local_sil_a3c_sampled, local_sil_a3c_used, local_sil_a3c_used_return = 0, 0, 0
                local_sil_rollout_sampled, local_sil_rollout_used, local_sil_rollout_used_return = 0, 0, 0
                local_sil_old_used = 0
                # centralized rollout counting
                local_rollout_ctr, local_rollout_added_ctr = 0, 0
                local_rollout_new_return, local_rollout_old_return = 0, 0

            # ensure only one thread is updating global_t at a time
            with lock:
                global_t += diff_global_t

                # centralize increasing count for SIL and Rollout
                sil_ctr += local_sil_ctr
                sil_a3c_sampled += local_sil_a3c_sampled
                sil_a3c_used += local_sil_a3c_used
                sil_a3c_used_return += local_sil_a3c_used_return
                sil_rollout_sampled += local_sil_rollout_sampled
                sil_rollout_used += local_sil_rollout_used
                sil_rollout_used_return += local_sil_rollout_used_return
                sil_old_used += local_sil_old_used

                rollout_ctr += local_rollout_ctr
                rollout_added_ctr += local_rollout_added_ctr
                rollout_new_return += local_rollout_new_return
                rollout_old_return += local_rollout_old_return

                # if, during a thread's update, global_t has reached an evaluation interval
                if global_t > next_global_t:
                    next_global_t = next_t(global_t, args.eval_freq)
                    step_t = int(next_global_t - args.eval_freq)

                    # wait for all threads to be done before testing
                    while not stop_req and th_ctr.qsize() < len(all_workers):
                        time.sleep(0.001)

                    step_t = int(next_global_t - args.eval_freq)

                    # Evaluate for 125,000 steps
                    rewards['eval'][step_t] = parallel_worker.testing(
                        sess, args.eval_max_steps, step_t, folder,
                        worker=all_workers[-1])
                    save_best_model(rewards['eval'][step_t][0])
                    last_reward = rewards['eval'][step_t][0]

                    # saving worker info to dicts
                    # SIL
                    sil_dict['sil_ctr'][step_t] = sil_ctr
                    sil_dict['sil_a3c_sampled'][step_t] = sil_a3c_sampled
                    sil_dict['sil_a3c_used'][step_t] = sil_a3c_used
                    sil_dict['sil_a3c_used_return'][step_t] = sil_a3c_used_return
                    sil_dict['sil_rollout_sampled'][step_t] = sil_rollout_sampled
                    sil_dict['sil_rollout_used'][step_t] = sil_rollout_used
                    sil_dict['sil_rollout_used_return'][step_t] = sil_rollout_used_return
                    sil_dict['sil_old_used'][step_t] = sil_old_used
                    # ROLLOUT
                    rollout_dict['rollout_ctr'][step_t] = rollout_ctr
                    rollout_dict['rollout_added_ctr'][step_t] = rollout_added_ctr
                    rollout_dict['rollout_new_return'][step_t] = rollout_new_return
                    rollout_dict['rollout_old_return'][step_t] = rollout_old_return

                    # save ckpt after done with eval
                    if global_t > next_save_t:
                        next_save_t = next_t(global_t, args.eval_freq*args.checkpoint_freq)

                        # dump pickle
                        dump_pickle([rewards, sil_dict, rollout_dict],
                                    [reward_fname, sil_fname, rollout_fname],
                                    global_t)
                        if args.load_pretrained_model:
                            dump_pickle([class_rewards], [class_reward_fname], global_t)
                        logger.info('Dump pickle at step {}'.format(global_t))

                        # save replay buffer (only works for priority mem for now)
                        if args.checkpoint_buffer:
                            if shared_memory is not None and args.priority_memory:
                                params = [shared_memory.buff._next_idx, shared_memory.buff._max_priority]
                                trees = [shared_memory.buff._it_sum._value,
                                         shared_memory.buff._it_min._value]
                                dump_pickle([shared_memory.buff._storage, params, trees],
                                            [sharedmem_fname, sharedmem_params_fname, sharedmem_trees_fname],
                                            global_t)
                                logger.info('Saved shared_memory')

                            if rollout_buffer is not None and args.priority_memory:
                                params = [rollout_buffer.buff._next_idx, rollout_buffer.buff._max_priority]
                                trees = [rollout_buffer.buff._it_sum._value,
                                         rollout_buffer.buff._it_min._value]
                                dump_pickle([rollout_buffer.buff._storage, params, trees],
                                            [rolloutmem_fname, rolloutmem_params_fname, rolloutmem_trees_fname],
                                            global_t)
                                logger.info('Saved rollout_buffer')

                        # save the A3C model after the buffers: if saving a buffer runs out of memory,
                        # we can still revert to the previous checkpoint
                        checkpt_file = folder / 'model_checkpoints'
                        checkpt_file /= '{}_checkpoint'.format(GYM_ENV_NAME)
                        saver.save(sess, str(checkpt_file), global_step=global_t,
                                   write_meta_graph=False)
                        logger.info('Saved model ckpt')

                        # if everything saves okay, clean up previous ckpt to save space
                        remove_pickle([reward_fname, sil_fname, rollout_fname],
                                      prev_ckpt_t)
                        if args.load_pretrained_model:
                            remove_pickle([class_reward_fname], prev_ckpt_t)

                        remove_pickle([sharedmem_fname, sharedmem_params_fname,
                                       sharedmem_trees_fname],
                                      prev_ckpt_t)
                        if rollout_buffer is not None and args.priority_memory:
                            remove_pickle([rolloutmem_fname, rolloutmem_params_fname,
                                           rolloutmem_trees_fname],
                                          prev_ckpt_t)

                        logger.info('Removed ckpt from step {}'.format(prev_ckpt_t))

                        prev_ckpt_t = global_t


    def signal_handler(signal, frame):
        nonlocal stop_req
        logger.info('You pressed Ctrl+C!')
        stop_req = True

        if stop_req and global_t == 0:
            sys.exit(1)

    def save_best_model(test_reward):
        nonlocal best_model_reward
        if test_reward > best_model_reward:
            best_model_reward = test_reward
            best_reward_file = folder / 'model_best/best_model_reward'

            with best_reward_file.open('w') as f:
                f.write(str(best_model_reward))

            best_checkpt_file = folder / 'model_best'
            best_checkpt_file /= '{}_checkpoint'.format(GYM_ENV_NAME)
            best_saver.save(sess, str(best_checkpt_file))


    train_threads = []
    th_ctr = Queue()
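    # counting "semaphore": one token per worker; the evaluation path waits until all tokens are back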
    for i in range(args.parallel_size):
        th_ctr.put(1)

    episodes_queue = None
    net_updates = None
    if args.use_sil:
        episodes_queue = Queue()
        net_updates = Queue()

    for i in range(args.parallel_size):
        worker_thread = Thread(
            target=train_function,
            args=(i, th_ctr, episodes_queue, net_updates,))
        train_threads.append(worker_thread)

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    # set start time
    start_time = time.time() - wall_t

    for t in train_threads:
        t.start()

    print('Press Ctrl+C to stop')

    for t in train_threads:
        t.join()

    logger.info('Now saving data. Please wait')

    # write wall time
    wall_t = time.time() - start_time
    wall_t_fname = folder / 'wall_t.{}'.format(global_t)
    with wall_t_fname.open('w') as f:
        f.write(str(wall_t))

    # save final model
    checkpoint_file = str(folder / '{}_checkpoint_a3c'.format(GYM_ENV_NAME))
    root_saver.save(sess, checkpoint_file, global_step=global_t)

    dump_final_pickle([rewards, sil_dict, rollout_dict],
                      [reward_fname, sil_fname, rollout_fname])

    logger.info('Data saved!')

    # if everything saves okay & is done training (not because of pressed Ctrl+C),
    # clean up previous ckpt to save space
    if global_t >= (args.max_time_step * args.max_time_step_fraction):
        remove_pickle([reward_fname, sil_fname, rollout_fname],
                      prev_ckpt_t)
        if args.load_pretrained_model:
            remove_pickle([class_reward_fname], prev_ckpt_t)

        remove_pickle([sharedmem_fname, sharedmem_params_fname, sharedmem_trees_fname],
                      prev_ckpt_t)
        if rollout_buffer is not None and args.priority_memory:
            remove_pickle([rolloutmem_fname, rolloutmem_params_fname, rolloutmem_trees_fname],
                          prev_ckpt_t)

        logger.info('Done training, removed ckpt from step {}'.format(prev_ckpt_t))


    sess.close()
    if pretrain_sess:
        pretrain_sess.close()
Exemplo n.º 11
learning_rate_input = tf.placeholder("float")

grad_applier = RMSPropApplier(learning_rate=learning_rate_input,
                              decay=RMSP_ALPHA,
                              momentum=0.0,
                              epsilon=RMSP_EPSILON,
                              clip_norm=GRAD_NORM_CLIP,
                              device=device)

for i in range(args.threads):
    training_thread = A3CTrainingThread(
        i,
        global_network,
        args.initial_learning_rate,
        learning_rate_input,
        grad_applier,
        args.max_time_step,
        device=device,
        environment=make_env(i + 1))  # previously: gym.make(args.gym_env)
    training_threads.append(training_thread)

#os._exit(0)

############################################################################
# summary for tensorboard

from summaries import setup_summaries

############################################################################
# prepare session
Exemplo n.º 12
                              device=device)
# mutex for synchronization between threads
mutex = threading.Lock()
# condition variable for synchronization
condition = threading.Condition()

# shared variable for inter-thread communication: holds jobs that have just finished the
# previous operation, so the machine handling the next operation is notified to put that
# operation into the waiting state
arrived_jobs = list()
for i in range(MACHINE_SIZE):
    arrived_jobs.append(list())
terminal_count = [0]
for i in range(PARALLEL_SIZE):
    training_thread = A3CTrainingThread(i,
                                        initial_learning_rate,
                                        learning_rate_input,
                                        grad_applier,
                                        MAX_TIME_EPISODE,
                                        device=device,
                                        arrived_jobs=arrived_jobs,
                                        condition=condition)
    training_threads.append(training_thread)

# prepare session
sess = tf.Session(config=tf.ConfigProto(log_device_placement=False,
                                        allow_soft_placement=True))

init = tf.global_variables_initializer()
sess.run(init)

# summary for tensorboard
score_input = tf.placeholder(tf.int32)
tf.summary.scalar("score", score_input)
Exemplo n.º 13

def get_game(thread_index):
    from game_state_env import GameStateGymEnv
    import gym

    env = gym.make('Catcher-v0')
    return GameStateGymEnv(env)


for i in range(PARALLEL_SIZE):
    training_thread = A3CTrainingThread(i,
                                        global_network,
                                        initial_learning_rate,
                                        learning_rate_input,
                                        grad_applier,
                                        MAX_TIME_STEP,
                                        device=device,
                                        game_function=get_game,
                                        local_network=make_network)
    training_threads.append(training_thread)

# prepare session
sess = tf.Session(config=tf.ConfigProto(log_device_placement=False,
                                        allow_soft_placement=True))

init = tf.global_variables_initializer()
sess.run(init)

# summary for tensorboard
score_input = tf.placeholder(tf.int32)
Exemplo n.º 14
def run_a3c(args):
    """
    python3 run_experiment.py --gym-env=PongNoFrameskip-v4 --parallel-size=16 --initial-learn-rate=7e-4 --use-lstm --use-mnih-2015

    python3 run_experiment.py --gym-env=PongNoFrameskip-v4 --parallel-size=16 --initial-learn-rate=7e-4 --use-lstm --use-mnih-2015 --use-transfer --not-transfer-fc2 --transfer-folder=<>

    python3 run_experiment.py --gym-env=PongNoFrameskip-v4 --parallel-size=16 --initial-learn-rate=7e-4 --use-lstm --use-mnih-2015 --use-transfer --not-transfer-fc2 --transfer-folder=<> --load-pretrained-model --onevsall-mtl --pretrained-model-folder=<> --use-pretrained-model-as-advice --use-pretrained-model-as-reward-shaping
    """
    from game_ac_network import GameACFFNetwork, GameACLSTMNetwork
    from a3c_training_thread import A3CTrainingThread
    if args.use_gpu:
        assert args.cuda_devices != ''
        os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_devices
    else:
        os.environ['CUDA_VISIBLE_DEVICES'] = ''
    import tensorflow as tf

    def log_uniform(lo, hi, rate):
        log_lo = math.log(lo)
        log_hi = math.log(hi)
        v = log_lo * (1 - rate) + log_hi * rate
        return math.exp(v)
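        # e.g. log_uniform(1e-4, 1e-2, 0.5) -> 1e-3 (the geometric midpoint of the range)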

    if not os.path.exists('results/a3c'):
        os.makedirs('results/a3c')

    if args.folder is not None:
        folder = 'results/a3c/{}_{}'.format(args.gym_env.replace('-', '_'),
                                            args.folder)
    else:
        folder = 'results/a3c/{}'.format(args.gym_env.replace('-', '_'))
        end_str = ''

        if args.use_mnih_2015:
            end_str += '_mnih2015'
        if args.use_lstm:
            end_str += '_lstm'
        if args.unclipped_reward:
            end_str += '_rawreward'
        elif args.log_scale_reward:
            end_str += '_logreward'
        if args.transformed_bellman:
            end_str += '_transformedbell'

        if args.use_transfer:
            end_str += '_transfer'
            if args.not_transfer_conv2:
                end_str += '_noconv2'
            elif args.not_transfer_conv3 and args.use_mnih_2015:
                end_str += '_noconv3'
            elif args.not_transfer_fc1:
                end_str += '_nofc1'
            elif args.not_transfer_fc2:
                end_str += '_nofc2'
        if args.finetune_upper_layers_only:
            end_str += '_tune_upperlayers'
        if args.train_with_demo_num_steps > 0 or args.train_with_demo_num_epochs > 0:
            end_str += '_pretrain_ina3c'
        if args.use_demo_threads:
            end_str += '_demothreads'

        if args.load_pretrained_model:
            if args.use_pretrained_model_as_advice:
                end_str += '_modelasadvice'
            if args.use_pretrained_model_as_reward_shaping:
                end_str += '_modelasshaping'
        folder += end_str

    if args.append_experiment_num is not None:
        folder += '_' + args.append_experiment_num

    if False:  # file-handler logging is disabled in this snippet
        from common.util import LogFormatter
        fh = logging.FileHandler('{}/a3c.log'.format(folder), mode='w')
        fh.setLevel(logging.DEBUG)
        formatter = LogFormatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        fh.setFormatter(formatter)
        logger.addHandler(fh)

    demo_memory = None
    num_demos = 0
    max_reward = 0.
    if args.load_memory or args.load_demo_cam:
        if args.demo_memory_folder is not None:
            demo_memory_folder = args.demo_memory_folder
        else:
            demo_memory_folder = 'collected_demo/{}'.format(
                args.gym_env.replace('-', '_'))

    if args.load_memory:
        # FIXME: use new load_memory function
        demo_memory, actions_ctr, max_reward = load_memory(
            args.gym_env, demo_memory_folder,
            imgs_normalized=True)  #, create_symmetry=True)
        action_freq = [
            actions_ctr[a] for a in range(demo_memory[0].num_actions)
        ]
        num_demos = len(demo_memory)

    demo_memory_cam = None
    if args.load_demo_cam:
        demo_cam, _, total_rewards_cam, _ = load_memory(
            name=None,
            demo_memory_folder=demo_memory_folder,
            demo_ids=args.demo_cam_id,
            imgs_normalized=False)

        demo_cam = demo_cam[int(args.demo_cam_id)]
        demo_memory_cam = np.zeros((len(demo_cam), demo_cam.height,
                                    demo_cam.width, demo_cam.phi_length),
                                   dtype=np.float32)
        for i in range(len(demo_cam)):
            s0 = (demo_cam[i])[0]
            demo_memory_cam[i] = np.copy(s0)
        del demo_cam
        logger.info("loaded demo {} for testing CAM".format(args.demo_cam_id))

    device = "/cpu:0"
    gpu_options = None
    if args.use_gpu:
        device = "/gpu:" + os.environ["CUDA_VISIBLE_DEVICES"]
        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=args.gpu_fraction)

    initial_learning_rate = args.initial_learn_rate
    logger.info('Initial Learning Rate={}'.format(initial_learning_rate))
    time.sleep(2)

    global_t = 0
    pretrain_global_t = 0
    pretrain_epoch = 0
    rewards = {'train': {}, 'eval': {}}
    best_model_reward = -(sys.maxsize)

    stop_requested = False

    game_state = GameState(env_id=args.gym_env)
    action_size = game_state.env.action_space.n
    game_state.close()
    del game_state.env
    del game_state

    config = tf.ConfigProto(gpu_options=gpu_options,
                            log_device_placement=False,
                            allow_soft_placement=True)

    pretrained_model = None
    pretrained_model_sess = None
    if args.load_pretrained_model:
        if args.onevsall_mtl:
            from game_class_network import MTLBinaryClassNetwork as PretrainedModelNetwork
        elif args.onevsall_mtl_linear:
            from game_class_network import MTLMultivariateNetwork as PretrainedModelNetwork
        else:
            from game_class_network import MultiClassNetwork as PretrainedModelNetwork
            logger.error("Not supported yet!")
            assert False

        if args.pretrained_model_folder is not None:
            pretrained_model_folder = args.pretrained_model_folder
        else:
            pretrained_model_folder = '{}_classifier_use_mnih_onevsall_mtl'.format(
                args.gym_env.replace('-', '_'))
        PretrainedModelNetwork.use_mnih_2015 = args.use_mnih_2015
        pretrained_model = PretrainedModelNetwork(action_size, -1, device)
        pretrained_model_sess = tf.Session(config=config,
                                           graph=pretrained_model.graph)
        pretrained_model.load(
            pretrained_model_sess,
            '{}/{}_checkpoint'.format(pretrained_model_folder,
                                      args.gym_env.replace('-', '_')))

    if args.use_lstm:
        GameACLSTMNetwork.use_mnih_2015 = args.use_mnih_2015
        global_network = GameACLSTMNetwork(action_size, -1, device)
    else:
        GameACFFNetwork.use_mnih_2015 = args.use_mnih_2015
        global_network = GameACFFNetwork(action_size, -1, device)

    training_threads = []

    learning_rate_input = tf.placeholder(tf.float32, shape=(), name="opt_lr")

    grad_applier = tf.train.RMSPropOptimizer(learning_rate=learning_rate_input,
                                             decay=args.rmsp_alpha,
                                             epsilon=args.rmsp_epsilon)

    A3CTrainingThread.log_interval = args.log_interval
    A3CTrainingThread.performance_log_interval = args.performance_log_interval
    A3CTrainingThread.local_t_max = args.local_t_max
    A3CTrainingThread.demo_t_max = args.demo_t_max
    A3CTrainingThread.use_lstm = args.use_lstm
    A3CTrainingThread.action_size = action_size
    A3CTrainingThread.entropy_beta = args.entropy_beta
    A3CTrainingThread.demo_entropy_beta = args.demo_entropy_beta
    A3CTrainingThread.gamma = args.gamma
    A3CTrainingThread.use_mnih_2015 = args.use_mnih_2015
    A3CTrainingThread.env_id = args.gym_env
    A3CTrainingThread.finetune_upper_layers_only = args.finetune_upper_layers_only
    A3CTrainingThread.transformed_bellman = args.transformed_bellman
    A3CTrainingThread.clip_norm = args.grad_norm_clip
    A3CTrainingThread.use_grad_cam = args.use_grad_cam
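    # the assignments above set class-level attributes shared by every worker thread created below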

    if args.unclipped_reward:
        A3CTrainingThread.reward_type = "RAW"
    elif args.log_scale_reward:
        A3CTrainingThread.reward_type = "LOG"
    else:
        A3CTrainingThread.reward_type = "CLIP"

    n_shapers = args.parallel_size  #int(args.parallel_size * .25)
    mod = args.parallel_size // n_shapers
    for i in range(args.parallel_size):
        is_reward_shape = False
        is_advice = False
        if i % mod == 0:
            is_reward_shape = args.use_pretrained_model_as_reward_shaping
            is_advice = args.use_pretrained_model_as_advice
        training_thread = A3CTrainingThread(
            i,
            global_network,
            initial_learning_rate,
            learning_rate_input,
            grad_applier,
            args.max_time_step,
            device=device,
            pretrained_model=pretrained_model,
            pretrained_model_sess=pretrained_model_sess,
            advice=is_advice,
            reward_shaping=is_reward_shape)
        training_threads.append(training_thread)

    # prepare session
    sess = tf.Session(config=config)

    if args.use_transfer:
        if args.transfer_folder is not None:
            transfer_folder = args.transfer_folder
        else:
            transfer_folder = 'results/pretrain_models/{}'.format(
                args.gym_env.replace('-', '_'))
            end_str = ''
            if args.use_mnih_2015:
                end_str += '_mnih2015'
            end_str += '_l2beta1E-04_batchprop'  #TODO: make this an argument
            transfer_folder += end_str

        transfer_folder += '/transfer_model'

        if args.not_transfer_conv2:
            transfer_var_list = [
                global_network.W_conv1, global_network.b_conv1
            ]
        elif (args.not_transfer_conv3 and args.use_mnih_2015):
            transfer_var_list = [
                global_network.W_conv1, global_network.b_conv1,
                global_network.W_conv2, global_network.b_conv2
            ]
        elif args.not_transfer_fc1:
            transfer_var_list = [
                global_network.W_conv1,
                global_network.b_conv1,
                global_network.W_conv2,
                global_network.b_conv2,
            ]
            if args.use_mnih_2015:
                transfer_var_list += [
                    global_network.W_conv3, global_network.b_conv3
                ]
        elif args.not_transfer_fc2:
            transfer_var_list = [
                global_network.W_conv1, global_network.b_conv1,
                global_network.W_conv2, global_network.b_conv2,
                global_network.W_fc1, global_network.b_fc1
            ]
            if args.use_mnih_2015:
                transfer_var_list += [
                    global_network.W_conv3, global_network.b_conv3
                ]
        else:
            transfer_var_list = [
                global_network.W_conv1, global_network.b_conv1,
                global_network.W_conv2, global_network.b_conv2,
                global_network.W_fc1, global_network.b_fc1,
                global_network.W_fc2, global_network.b_fc2
            ]
            if args.use_mnih_2015:
                transfer_var_list += [
                    global_network.W_conv3, global_network.b_conv3
                ]

        global_network.load_transfer_model(
            sess,
            folder=transfer_folder,
            not_transfer_fc2=args.not_transfer_fc2,
            not_transfer_fc1=args.not_transfer_fc1,
            not_transfer_conv3=(args.not_transfer_conv3
                                and args.use_mnih_2015),
            not_transfer_conv2=args.not_transfer_conv2,
            var_list=transfer_var_list)

    def initialize_uninitialized(sess):
        global_vars = tf.global_variables()
        is_not_initialized = sess.run(
            [tf.is_variable_initialized(var) for var in global_vars])
        not_initialized_vars = [
            v for (v, f) in zip(global_vars, is_not_initialized) if not f
        ]

        if len(not_initialized_vars):
            sess.run(tf.variables_initializer(not_initialized_vars))

    if args.use_transfer:
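        # initialize only the variables that load_transfer_model did not restore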
        initialize_uninitialized(sess)
    else:
        sess.run(tf.global_variables_initializer())

    # summary writer for tensorboard
    summary_op = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(
        'results/log/a3c/{}/'.format(args.gym_env.replace('-', '_')) +
        folder[12:], sess.graph)

    # init or load checkpoint with saver
    root_saver = tf.train.Saver(max_to_keep=1)
    saver = tf.train.Saver(max_to_keep=6)
    best_saver = tf.train.Saver(max_to_keep=1)
    checkpoint = tf.train.get_checkpoint_state(folder)
    if checkpoint and checkpoint.model_checkpoint_path:
        root_saver.restore(sess, checkpoint.model_checkpoint_path)
        logger.info("checkpoint loaded:{}".format(
            checkpoint.model_checkpoint_path))
        tokens = checkpoint.model_checkpoint_path.split("-")
        # set global step
        global_t = int(tokens[-1])
        logger.info(">>> global step set: {}".format(global_t))
        # set wall time
        wall_t_fname = folder + '/' + 'wall_t.' + str(global_t)
        with open(wall_t_fname, 'r') as f:
            wall_t = float(f.read())
        with open(folder + '/pretrain_global_t', 'r') as f:
            pretrain_global_t = int(f.read())
        with open(folder + '/model_best/best_model_reward',
                  'r') as f_best_model_reward:
            best_model_reward = float(f_best_model_reward.read())
        rewards = pickle.load(
            open(
                folder + '/' + args.gym_env.replace('-', '_') +
                '-a3c-rewards.pkl', 'rb'))
    else:
        logger.warning("Could not find old checkpoint")
        # set wall time
        wall_t = 0.0
        prepare_dir(folder, empty=True)
        prepare_dir(folder + '/model_checkpoints', empty=True)
        prepare_dir(folder + '/model_best', empty=True)
        prepare_dir(folder + '/frames', empty=True)

    lock = threading.Lock()
    test_lock = False
    if global_t == 0:
        test_lock = True

    last_temp_global_t = global_t
    ispretrain_markers = [False] * args.parallel_size
    num_demo_thread = 0
    ctr_demo_thread = 0

    def train_function(parallel_index):
        nonlocal global_t, pretrain_global_t, pretrain_epoch, \
            rewards, test_lock, lock, \
            last_temp_global_t, ispretrain_markers, num_demo_thread, \
            ctr_demo_thread
        training_thread = training_threads[parallel_index]

        training_thread.set_summary_writer(summary_writer)

        # set all threads as demo threads
        training_thread.is_demo_thread = args.load_memory and args.use_demo_threads
        if training_thread.is_demo_thread or args.train_with_demo_num_steps > 0 or args.train_with_demo_num_epochs:
            training_thread.pretrain_init(demo_memory)

        if global_t == 0 and (
                args.train_with_demo_num_steps > 0
                or args.train_with_demo_num_epochs > 0) and parallel_index < 2:
            ispretrain_markers[parallel_index] = True
            training_thread.replay_mem_reset()

            # Pretraining with demo memory
            logger.info("t_idx={} pretrain starting".format(parallel_index))
            while ispretrain_markers[parallel_index]:
                if stop_requested:
                    return
                if pretrain_global_t > args.train_with_demo_num_steps and pretrain_epoch > args.train_with_demo_num_epochs:
                    # At end of pretraining, reset state
                    training_thread.replay_mem_reset()
                    training_thread.episode_reward = 0
                    training_thread.local_t = 0
                    if args.use_lstm:
                        training_thread.local_network.reset_state()
                    ispretrain_markers[parallel_index] = False
                    logger.info(
                        "t_idx={} pretrain ended".format(parallel_index))
                    break

                diff_pretrain_global_t, _ = training_thread.demo_process(
                    sess, pretrain_global_t)
                for _ in range(diff_pretrain_global_t):
                    pretrain_global_t += 1
                    if pretrain_global_t % 10000 == 0:
                        logger.debug(
                            "pretrain_global_t={}".format(pretrain_global_t))

                pretrain_epoch += 1
                if pretrain_epoch % 1000 == 0:
                    logger.debug("pretrain_epoch={}".format(pretrain_epoch))

            # Waits for all threads to finish pretraining
            while not stop_requested and any(ispretrain_markers):
                time.sleep(0.01)

        # Evaluate model before training
        if not stop_requested and global_t == 0:
            with lock:
                if parallel_index == 0:
                    test_reward, test_steps, test_episodes = training_threads[
                        0].testing(sess,
                                   args.eval_max_steps,
                                   global_t,
                                   folder,
                                   demo_memory_cam=demo_memory_cam)
                    rewards['eval'][global_t] = (test_reward, test_steps,
                                                 test_episodes)
                    saver.save(
                        sess,
                        folder + '/model_checkpoints/' +
                        '{}_checkpoint'.format(args.gym_env.replace('-', '_')),
                        global_step=global_t)
                    save_best_model(test_reward)
                    test_lock = False
            # all threads wait until evaluation finishes
            while not stop_requested and test_lock:
                time.sleep(0.01)

        # set start_time
        start_time = time.time() - wall_t
        training_thread.set_start_time(start_time)
        episode_end = True
        use_demo_thread = False
        while True:
            if stop_requested:
                return
            if global_t >= (args.max_time_step * args.max_time_step_fraction):
                return

            if args.use_demo_threads and global_t < args.max_steps_threads_as_demo and episode_end and num_demo_thread < 16:
                #if num_demo_thread < 2:
                demo_rate = 1.0 * (args.max_steps_threads_as_demo -
                                   global_t) / args.max_steps_threads_as_demo
                if demo_rate < 0.0333:
                    demo_rate = 0.0333

                if np.random.random() <= demo_rate and num_demo_thread < 16:
                    ctr_demo_thread += 1
                    training_thread.replay_mem_reset(D_idx=ctr_demo_thread %
                                                     num_demos)
                    num_demo_thread += 1
                    logger.info(
                        "idx={} as demo thread started ({}/16) rate={}".format(
                            parallel_index, num_demo_thread, demo_rate))
                    use_demo_thread = True

            if use_demo_thread:
                diff_global_t, episode_end = training_thread.demo_process(
                    sess, global_t)
                if episode_end:
                    num_demo_thread -= 1
                    use_demo_thread = False
                    logger.info("idx={} demo thread concluded ({}/16)".format(
                        parallel_index, num_demo_thread))
            else:
                diff_global_t, episode_end = training_thread.process(
                    sess, global_t, rewards)

            for _ in range(diff_global_t):
                global_t += 1
                if global_t % args.eval_freq == 0:
                    temp_global_t = global_t
                    lock.acquire()
                    try:
                        # catch multiple threads getting in at the same time
                        if last_temp_global_t == temp_global_t:
                            logger.info("Threading race problem averted!")
                            continue
                        test_lock = True
                        test_reward, test_steps, n_episodes = training_thread.testing(
                            sess,
                            args.eval_max_steps,
                            temp_global_t,
                            folder,
                            demo_memory_cam=demo_memory_cam)
                        rewards['eval'][temp_global_t] = (test_reward,
                                                          test_steps,
                                                          n_episodes)
                        if temp_global_t % (
                            (args.max_time_step * args.max_time_step_fraction)
                                // 5) == 0:
                            saver.save(sess,
                                       folder + '/model_checkpoints/' +
                                       '{}_checkpoint'.format(
                                           args.gym_env.replace('-', '_')),
                                       global_step=temp_global_t,
                                       write_meta_graph=False)
                        if test_reward > best_model_reward:
                            save_best_model(test_reward)
                        test_lock = False
                        last_temp_global_t = temp_global_t
                    finally:
                        lock.release()
                if global_t % (
                    (args.max_time_step * args.max_time_step_fraction) //
                        5) == 0:
                    saver.save(
                        sess,
                        folder + '/model_checkpoints/' +
                        '{}_checkpoint'.format(args.gym_env.replace('-', '_')),
                        global_step=global_t,
                        write_meta_graph=False)
                # all threads wait until evaluation finishes
                while not stop_requested and test_lock:
                    time.sleep(0.01)

    def signal_handler(signal, frame):
        nonlocal stop_requested
        logger.info('You pressed Ctrl+C!')
        stop_requested = True

        if stop_requested and global_t == 0:
            sys.exit(1)

    def save_best_model(test_reward):
        nonlocal best_model_reward
        best_model_reward = test_reward
        with open(folder + '/model_best/best_model_reward',
                  'w') as f_best_model_reward:
            f_best_model_reward.write(str(best_model_reward))
        best_saver.save(
            sess, folder + '/model_best/' +
            '{}_checkpoint'.format(args.gym_env.replace('-', '_')))

    train_threads = []
    for i in range(args.parallel_size):
        train_threads.append(
            threading.Thread(target=train_function, args=(i, )))

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    # set start time
    start_time = time.time() - wall_t

    for t in train_threads:
        t.start()

    print('Press Ctrl+C to stop')

    for t in train_threads:
        t.join()

    logger.info('Now saving data. Please wait')

    # write wall time
    wall_t = time.time() - start_time
    wall_t_fname = folder + '/' + 'wall_t.' + str(global_t)
    with open(wall_t_fname, 'w') as f:
        f.write(str(wall_t))
    with open(folder + '/pretrain_global_t', 'w') as f:
        f.write(str(pretrain_global_t))

    root_saver.save(
        sess,
        folder + '/{}_checkpoint_a3c'.format(args.gym_env.replace('-', '_')),
        global_step=global_t)

    pickle.dump(
        rewards,
        open(
            folder + '/' + args.gym_env.replace('-', '_') + '-a3c-rewards.pkl',
            'wb'), pickle.HIGHEST_PROTOCOL)
    logger.info('Data saved!')

    sess.close()
Exemplo n.º 15
training_threads = []

learning_rate_input = tf.placeholder("float")

grad_applier = RMSPropApplier(learning_rate=learning_rate_input,
                              decay=options.rmsp_alpha,
                              momentum=0.0,
                              epsilon=options.rmsp_epsilon,
                              clip_norm=options.grad_norm_clip,
                              device=device)

for i in range(options.parallel_size):
    training_thread = A3CTrainingThread(i,
                                        global_network,
                                        initial_learning_rate,
                                        learning_rate_input,
                                        grad_applier,
                                        options.max_time_step,
                                        device=device,
                                        options=options)
    training_threads.append(training_thread)

# prepare session
sess = tf.Session(config=tf.ConfigProto(log_device_placement=False,
                                        gpu_options={'allow_growth': True},
                                        allow_soft_placement=True))

init = tf.global_variables_initializer()
sess.run(init)

# summary for tensorboard
score_input = tf.placeholder(tf.int32)
Exemplo n.º 16
def train():
    #initial learning rate
    initial_learning_rate = log_uniform(INITIAL_ALPHA_LOW, INITIAL_ALPHA_HIGH,
                                        INITIAL_ALPHA_LOG_RATE)

    # parameter server and worker information
    ps_hosts = np.zeros(FLAGS.ps_hosts_num, dtype=object)
    worker_hosts = np.zeros(FLAGS.worker_hosts_num, dtype=object)
    port_num = FLAGS.st_port_num
    for i in range(FLAGS.ps_hosts_num):
        ps_hosts[i] = str(FLAGS.hostname) + ":" + str(port_num)
        port_num += 1
    for i in range(FLAGS.worker_hosts_num):
        worker_hosts[i] = str(FLAGS.hostname) + ":" + str(port_num)
        port_num += 1
    ps_hosts = list(ps_hosts)
    worker_hosts = list(worker_hosts)
    # Create a cluster from the parameter server and worker hosts.
    cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
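    # all ps/worker tasks run on FLAGS.hostname, on consecutive ports starting at FLAGS.st_port_num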

    # Create and start a server for the local task.
    server = tf.train.Server(cluster,
                             job_name=FLAGS.job_name,
                             task_index=FLAGS.task_index)

    if FLAGS.job_name == "ps":
        server.join()
    elif FLAGS.job_name == "worker":
        device = tf.train.replica_device_setter(
            worker_device="/job:worker/task:%d" % FLAGS.task_index,
            cluster=cluster)

        learning_rate_input = tf.placeholder("float")

        grad_applier = RMSPropApplier(learning_rate=learning_rate_input,
                                      decay=RMSP_ALPHA,
                                      momentum=0.0,
                                      epsilon=RMSP_EPSILON,
                                      clip_norm=GRAD_NORM_CLIP,
                                      device=device)

        tf.set_random_seed(1)
        #There are no global network
        training_thread = A3CTrainingThread(0,
                                            "",
                                            initial_learning_rate,
                                            learning_rate_input,
                                            grad_applier,
                                            MAX_TIME_STEP,
                                            device=device,
                                            FLAGS=FLAGS,
                                            task_index=FLAGS.task_index)

        # prepare session
        with tf.device(
                tf.train.replica_device_setter(
                    worker_device="/job:worker/task:%d" % FLAGS.task_index,
                    cluster=cluster)):
            # flag for task
            flag = tf.get_variable('flag', [],
                                   initializer=tf.constant_initializer(0),
                                   trainable=False)
            flag_ph = tf.placeholder(flag.dtype, shape=flag.get_shape())
            flag_ops = flag.assign(flag_ph)
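            # 'flag' signals which task phase is active: the last worker advances it (via flag_ops)
            # and the other workers poll it before starting each task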
            # global step
            global_step = tf.get_variable(
                'global_step', [],
                initializer=tf.constant_initializer(0),
                trainable=False)
            global_step_ph = tf.placeholder(global_step.dtype,
                                            shape=global_step.get_shape())
            global_step_ops = global_step.assign(global_step_ph)
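            # workers advance the shared counter by feeding global_step_ph with its current value
            # plus diff_global_t (see the training loop below)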
            # score for tensorboard and score_set for genetic algorithm
            score = tf.get_variable('score', [],
                                    initializer=tf.constant_initializer(-21),
                                    trainable=False)
            score_ph = tf.placeholder(score.dtype, shape=score.get_shape())
            score_ops = score.assign(score_ph)
            score_set = np.zeros(FLAGS.worker_hosts_num, dtype=object)
            score_set_ph = np.zeros(FLAGS.worker_hosts_num, dtype=object)
            score_set_ops = np.zeros(FLAGS.worker_hosts_num, dtype=object)
            for i in range(FLAGS.worker_hosts_num):
                score_set[i] = tf.get_variable(
                    'score' + str(i), [],
                    initializer=tf.constant_initializer(-1000),
                    trainable=False)
                score_set_ph[i] = tf.placeholder(
                    score_set[i].dtype, shape=score_set[i].get_shape())
                score_set_ops[i] = score_set[i].assign(score_set_ph[i])
            # fixed path of earlier task
            fixed_path_tf = np.zeros((FLAGS.L, FLAGS.M), dtype=object)
            fixed_path_ph = np.zeros((FLAGS.L, FLAGS.M), dtype=object)
            fixed_path_ops = np.zeros((FLAGS.L, FLAGS.M), dtype=object)
            for i in range(FLAGS.L):
                for j in range(FLAGS.M):
                    fixed_path_tf[i, j] = tf.get_variable(
                        'fixed_path' + str(i) + "-" + str(j), [],
                        initializer=tf.constant_initializer(0),
                        trainable=False)
                    fixed_path_ph[i, j] = tf.placeholder(
                        fixed_path_tf[i, j].dtype,
                        shape=fixed_path_tf[i, j].get_shape())
                    fixed_path_ops[i, j] = fixed_path_tf[i, j].assign(
                        fixed_path_ph[i, j])
            # parameters on PathNet
            vars_ = training_thread.local_network.get_vars()
            vars_ph = np.zeros(len(vars_), dtype=object)
            vars_ops = np.zeros(len(vars_), dtype=object)
            for i in range(len(vars_)):
                vars_ph[i] = tf.placeholder(vars_[i].dtype,
                                            shape=vars_[i].get_shape())
                vars_ops[i] = vars_[i].assign(vars_ph[i])

            # initialization
            init_op = tf.global_variables_initializer()
            # summary for tensorboard
            tf.summary.scalar("score", score)
            summary_op = tf.summary.merge_all()
            saver = tf.train.Saver()

        sv = tf.train.Supervisor(is_chief=(FLAGS.task_index == 0),
                                 global_step=global_step,
                                 logdir=FLAGS.log_dir,
                                 summary_op=summary_op,
                                 saver=saver,
                                 init_op=init_op)

        with sv.managed_session(server.target) as sess:
            if (FLAGS.task_index != (FLAGS.worker_hosts_num - 1)):
                for task in range(2):
                    while True:
                        if sess.run([flag])[0] == (task + 1):
                            break
                        time.sleep(2)
                    # Set fixed_path
                    fixed_path = np.zeros((FLAGS.L, FLAGS.M), dtype=float)
                    for i in range(FLAGS.L):
                        for j in range(FLAGS.M):
                            if (sess.run([fixed_path_tf[i, j]])[0] == 1):
                                fixed_path[i, j] = 1.0
                    training_thread.local_network.set_fixed_path(fixed_path)
                    # set start_time
                    wall_t = 0.0
                    start_time = time.time() - wall_t
                    training_thread.set_start_time(start_time)
                    while True:
                        if sess.run([global_step])[0] > (MAX_TIME_STEP *
                                                         (task + 1)):
                            break
                        diff_global_t = training_thread.process(
                            sess,
                            sess.run([global_step])[0], "", summary_op, "",
                            score_ph, score_ops, "", FLAGS,
                            score_set_ph[FLAGS.task_index],
                            score_set_ops[FLAGS.task_index],
                            score_set[FLAGS.task_index])
                        sess.run(
                            global_step_ops, {
                                global_step_ph:
                                sess.run([global_step])[0] + diff_global_t
                            })
            else:
                fixed_path = np.zeros((FLAGS.L, FLAGS.M), dtype=float)
                vars_backup = np.zeros(len(vars_), dtype=object)
                vars_backup = sess.run(vars_)
                winner_idx = 0
                for task in range(2):
                    # Generating randomly geopath
                    geopath_set = np.zeros(FLAGS.worker_hosts_num - 1,
                                           dtype=object)
                    for i in range(FLAGS.worker_hosts_num - 1):
                        geopath_set[i] = pathnet.get_geopath(
                            FLAGS.L, FLAGS.M, FLAGS.N)
                        tmp = np.zeros((FLAGS.L, FLAGS.M), dtype=float)
                        for j in range(FLAGS.L):
                            for k in range(FLAGS.M):
                                if ((geopath_set[i][j, k] == 1.0)
                                        or (fixed_path[j, k] == 1.0)):
                                    tmp[j, k] = 1.0
                        pathnet.geopath_insert(
                            sess, training_thread.local_network.
                            geopath_update_placeholders_set[i],
                            training_thread.local_network.
                            geopath_update_ops_set[i], tmp, FLAGS.L, FLAGS.M)
                    print("Geopath Setting Done")
                    sess.run(flag_ops, {flag_ph: (task + 1)})
                    print("=============Task" + str(task + 1) + "============")
                    score_subset = np.zeros(FLAGS.B, dtype=float)
                    score_set_print = np.zeros(FLAGS.worker_hosts_num,
                                               dtype=float)
                    rand_idx = list(range(FLAGS.worker_hosts_num - 1))
                    np.random.shuffle(rand_idx)
                    rand_idx = rand_idx[:FLAGS.B]
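                    # tournament selection: each round the best of the B sampled workers wins;
                    # the losers copy its path and mutate it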
                    while True:
                        if sess.run([global_step])[0] > (MAX_TIME_STEP *
                                                         (task + 1)):
                            break
                        flag_sum = 0
                        for i in range(FLAGS.worker_hosts_num - 1):
                            score_set_print[i] = sess.run([score_set[i]])[0]
                        print(score_set_print)
                        for i in range(len(rand_idx)):
                            score_subset[i] = sess.run(
                                [score_set[rand_idx[i]]])[0]
                            if (score_subset[i] == -1000):
                                flag_sum = 1
                                break
                        if (flag_sum == 0):
                            winner_idx = rand_idx[np.argmax(score_subset)]
                            print(
                                str(sess.run([global_step])[0]) +
                                " Step Score: " +
                                str(sess.run([score_set[winner_idx]])[0]))
                            for i in rand_idx:
                                if (i != winner_idx):
                                    geopath_set[i] = np.copy(
                                        geopath_set[winner_idx])
                                    geopath_set[i] = pathnet.mutation(
                                        geopath_set[i], FLAGS.L, FLAGS.M,
                                        FLAGS.N)
                                    tmp = np.zeros((FLAGS.L, FLAGS.M),
                                                   dtype=float)
                                    for j in range(FLAGS.L):
                                        for k in range(FLAGS.M):
                                            if ((geopath_set[i][j, k] == 1.0)
                                                    or
                                                (fixed_path[j, k] == 1.0)):
                                                tmp[j, k] = 1.0
                                    pathnet.geopath_insert(
                                        sess, training_thread.local_network.
                                        geopath_update_placeholders_set[i],
                                        training_thread.local_network.
                                        geopath_update_ops_set[i], tmp,
                                        FLAGS.L, FLAGS.M)
                                    sess.run(score_set_ops[i],
                                             {score_set_ph[i]: -1000})
                            rand_idx = list(range(FLAGS.worker_hosts_num - 1))
                            np.random.shuffle(rand_idx)
                            rand_idx = rand_idx[:FLAGS.B]
                        else:
                            time.sleep(5)
                    # fixed_path setting
                    fixed_path = geopath_set[winner_idx]
                    for i in range(FLAGS.L):
                        for j in range(FLAGS.M):
                            if (fixed_path[i, j] == 1.0):
                                sess.run(fixed_path_ops[i, j],
                                         {fixed_path_ph[i, j]: 1})
                    training_thread.local_network.set_fixed_path(fixed_path)
                    # initialization of parameters except fixed_path
                    vars_idx = training_thread.local_network.get_vars_idx()
                    for i in range(len(vars_idx)):
                        if (vars_idx[i] == 1.0):
                            sess.run(vars_ops[i], {vars_ph[i]: vars_backup[i]})
        sv.stop()
        print("Done")
Exemplo n.º 17
def train():
    #initial learning rate
    pinitial_learning_rate = log_uniform(PINITIAL_ALPHA_LOW,
                                         PINITIAL_ALPHA_HIGH,
                                         INITIAL_ALPHA_LOG_RATE)
    vinitial_learning_rate = log_uniform(VINITIAL_ALPHA_LOW,
                                         VINITIAL_ALPHA_HIGH,
                                         INITIAL_ALPHA_LOG_RATE)

    # parameter server and worker information
    ps_hosts = np.zeros(FLAGS.ps_hosts_num, dtype=object)
    worker_hosts = np.zeros(FLAGS.worker_hosts_num, dtype=object)
    port_num = FLAGS.st_port_num
    for i in range(FLAGS.ps_hosts_num):
        ps_hosts[i] = str(FLAGS.hostname) + ":" + str(port_num)
        port_num += 1
    for i in range(FLAGS.worker_hosts_num):
        worker_hosts[i] = str(FLAGS.hostname) + ":" + str(port_num)
        port_num += 1
    ps_hosts = list(ps_hosts)
    worker_hosts = list(worker_hosts)
    # Create a cluster from the parameter server and worker hosts.
    cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})

    # Create and start a server for the local task.
    server = tf.train.Server(cluster,
                             job_name=FLAGS.job_name,
                             task_index=FLAGS.task_index)

    if FLAGS.job_name == "ps":
        server.join()
    elif FLAGS.job_name == "worker":
        device = tf.train.replica_device_setter(
            worker_device="/job:worker/task:%d" % FLAGS.task_index,
            cluster=cluster)

        plearning_rate_input = tf.placeholder("float")
        vlearning_rate_input = tf.placeholder("float")

        pgrad_applier = RMSPropApplier(learning_rate=plearning_rate_input,
                                       decay=RMSP_ALPHA,
                                       momentum=0.0,
                                       epsilon=RMSP_EPSILON,
                                       clip_norm=GRAD_NORM_CLIP,
                                       device=device)
        vgrad_applier = RMSPropApplier(learning_rate=vlearning_rate_input,
                                       decay=RMSP_ALPHA,
                                       momentum=0.0,
                                       epsilon=RMSP_EPSILON,
                                       clip_norm=GRAD_NORM_CLIP,
                                       device=device)

        tf.set_random_seed(1)
        #There are no global network
        training_thread = A3CTrainingThread(0,
                                            "",
                                            pinitial_learning_rate,
                                            plearning_rate_input,
                                            pgrad_applier,
                                            vinitial_learning_rate,
                                            vlearning_rate_input,
                                            vgrad_applier,
                                            MAX_TIME_STEP,
                                            device=device,
                                            task_index=FLAGS.task_index)

        # prepare session
        with tf.device(
                tf.train.replica_device_setter(
                    worker_device="/job:worker/task:%d" % FLAGS.task_index,
                    cluster=cluster)):
            global_step = tf.get_variable(
                'global_step', [],
                initializer=tf.constant_initializer(0),
                trainable=False)
            global_step_ph = tf.placeholder(global_step.dtype,
                                            shape=global_step.get_shape())
            global_step_ops = global_step.assign(global_step_ph)
            score = tf.get_variable('score', [],
                                    initializer=tf.constant_initializer(-21),
                                    trainable=False)
            score_ph = tf.placeholder(score.dtype, shape=score.get_shape())
            score_ops = score.assign(score_ph)
            init_op = tf.global_variables_initializer()
            # summary for tensorboard
            tf.summary.scalar("score", score)
            summary_op = tf.summary.merge_all()
            saver = tf.train.Saver()

        sv = tf.train.Supervisor(is_chief=(FLAGS.task_index == 0),
                                 global_step=global_step,
                                 logdir=LOG_FILE,
                                 summary_op=summary_op,
                                 saver=saver,
                                 init_op=init_op)

        with sv.managed_session(server.target) as sess:
            # set start_time
            wall_t = 0.0
            start_time = time.time() - wall_t
            training_thread.set_start_time(start_time)
            local_t = 0
            while True:
                if sess.run([global_step])[0] > MAX_TIME_STEP:
                    break
                diff_global_t = training_thread.process(
                    sess,
                    sess.run([global_step])[0], "", summary_op, "", score_ph,
                    score_ops)
                sess.run(global_step_ops, {
                    global_step_ph:
                    sess.run([global_step])[0] + diff_global_t
                })
                local_t += diff_global_t

        sv.stop()
        print("Done")
Exemplo n.º 18
training_threads = []

learning_rate_input = tf.placeholder("float")

grad_applier = RMSPropApplier(learning_rate=learning_rate_input,
                              decay=RMSP_ALPHA,
                              momentum=0.0,
                              epsilon=RMSP_EPSILON,
                              clip_norm=GRAD_NORM_CLIP,
                              device=device)

for i in range(PARALLEL_SIZE):
    training_thread = A3CTrainingThread(i, global_network,
                                        initial_learning_rate,
                                        learning_rate_input,
                                        grad_applier, MAX_TIME_STEP,
                                        device=device)
    training_threads.append(training_thread)

# prepare session
sess = tf.Session(config=tf.ConfigProto(log_device_placement=False,
                                        allow_soft_placement=True))

init = tf.global_variables_initializer()
sess.run(init)

# summary for tensorboard
score_input = tf.placeholder(tf.int32)
tf.summary.scalar("score", score_input)
Exemplo n.º 19
def train():
    #initial learning rate
    initial_learning_rate = log_uniform(INITIAL_ALPHA_LOW,
                                        INITIAL_ALPHA_HIGH,
                                        INITIAL_ALPHA_LOG_RATE)

    # parameter server and worker information
    ps_hosts = np.zeros(FLAGS.ps_hosts_num, dtype=object)
    worker_hosts = np.zeros(FLAGS.worker_hosts_num, dtype=object)
    port_num = FLAGS.st_port_num
    for i in range(FLAGS.ps_hosts_num):
        ps_hosts[i] = str(FLAGS.hostname) + ":" + str(port_num)
        port_num += 1
    for i in range(FLAGS.worker_hosts_num):
        worker_hosts[i] = str(FLAGS.hostname) + ":" + str(port_num)
        port_num += 1
    ps_hosts = list(ps_hosts)
    worker_hosts = list(worker_hosts)
    # Create a cluster from the parameter server and worker hosts.
    cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})

    # Create and start a server for the local task.
    server = tf.train.Server(cluster,
                             job_name=FLAGS.job_name,
                             task_index=FLAGS.task_index)


    if FLAGS.job_name == "ps":
        server.join()
    elif FLAGS.job_name == "worker":
        # gpu_assignment = FLAGS.task_index % NUM_GPUS
        # print("Assigning worker #%d to GPU #%d" % (FLAGS.task_index, gpu_assignment))
        # device = tf.train.replica_device_setter(
        #     worker_device="/job:worker/task:%d/gpu:%d" % (FLAGS.task_index, gpu_assignment),
        #     cluster=cluster)

        device = tf.train.replica_device_setter(
            worker_device="/job:worker/task:%d" % FLAGS.task_index,
            cluster=cluster)

        learning_rate_input = tf.placeholder("float")

        grad_applier = RMSPropApplier(learning_rate=learning_rate_input,
                                      decay=RMSP_ALPHA,
                                      momentum=0.0,
                                      epsilon=RMSP_EPSILON,
                                      clip_norm=GRAD_NORM_CLIP,
                                      device=device)

        tf.set_random_seed(1)
        # There is no global network; each worker trains its own local copy.

        #lock = multiprocessing.Lock()

        #wrapper = ToDiscrete('constant-7')
        #env = wrapper(gym.make('gym_doom/DoomBasic-v0'))
        #env.close()

        training_thread = A3CTrainingThread(0, "", 0, initial_learning_rate,
                                            learning_rate_input, grad_applier,
                                            MAX_TIME_STEP, device=device,
                                            FLAGS=FLAGS,
                                            task_index=FLAGS.task_index)

        # prepare session
        with tf.device(device):
            # flag for task
            flag = tf.get_variable('flag', [], initializer=tf.constant_initializer(0), trainable=False)
            flag_ph = tf.placeholder(flag.dtype, shape=flag.get_shape())
            flag_ops = flag.assign(flag_ph)
            # global step
            global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False)
            global_step_ph = tf.placeholder(global_step.dtype, shape=global_step.get_shape())
            global_step_ops = global_step.assign(global_step_ph)
            # score for tensorboard and score_set for genetic algorithm
            score = tf.get_variable('score', [], initializer=tf.constant_initializer(-21), trainable=False)
            score_ph = tf.placeholder(score.dtype, shape=score.get_shape())
            score_ops = score.assign(score_ph)
            score_set = np.zeros(FLAGS.worker_hosts_num, dtype=object)
            score_set_ph = np.zeros(FLAGS.worker_hosts_num, dtype=object)
            score_set_ops = np.zeros(FLAGS.worker_hosts_num, dtype=object)
            for i in range(FLAGS.worker_hosts_num):
                score_set[i] = tf.get_variable('score' + str(i), [], initializer=tf.constant_initializer(-1000), trainable=False)
                score_set_ph[i] = tf.placeholder(score_set[i].dtype, shape=score_set[i].get_shape())
                score_set_ops[i] = score_set[i].assign(score_set_ph[i])
            # fixed path of earlier task
            fixed_path_tf = np.zeros((FLAGS.L, FLAGS.M), dtype=object)
            fixed_path_ph = np.zeros((FLAGS.L, FLAGS.M), dtype=object)
            fixed_path_ops = np.zeros((FLAGS.L, FLAGS.M), dtype=object)
            for i in range(FLAGS.L):
                for j in range(FLAGS.M):
                    fixed_path_tf[i, j] = tf.get_variable('fixed_path' + str(i) + "-" + str(j), [], initializer=tf.constant_initializer(0), trainable=False)
                    fixed_path_ph[i, j] = tf.placeholder(fixed_path_tf[i, j].dtype, shape=fixed_path_tf[i, j].get_shape())
                    fixed_path_ops[i, j] = fixed_path_tf[i, j].assign(fixed_path_ph[i, j])
            # parameters on PathNet
            vars_ = training_thread.local_network.get_vars()
            vars_ph = np.zeros(len(vars_), dtype=object)
            vars_ops = np.zeros(len(vars_), dtype=object)
            for i in range(len(vars_)):
                vars_ph[i] = tf.placeholder(vars_[i].dtype, shape=vars_[i].get_shape())
                vars_ops[i] = vars_[i].assign(vars_ph[i])
            # initialization
            init_op = tf.global_variables_initializer()
            # summary for tensorboard
            tf.summary.scalar("score", score)
            summary_op = tf.summary.merge_all()
            saver = tf.train.Saver()

        sv = tf.train.Supervisor(is_chief=(FLAGS.task_index == 0),
                                 global_step=global_step,
                                 logdir=FLAGS.log_dir,
                                 summary_op=summary_op,
                                 saver=saver,
                                 init_op=init_op)
        try:
            os.mkdir("./data/graphs")
        except OSError:
            pass

        # config = tf.ConfigProto(
        #         device_count = {'GPU': 0}
        #     )
        # config = tf.ConfigProto()
        # config.gpu_options.allow_growth = True
        # config.gpu_options.per_process_gpu_memory_fraction = 0.1

        with sv.managed_session(server.target) as sess:
            if FLAGS.task_index != (FLAGS.worker_hosts_num - 1):
                for task in range(2):
                    training_thread.set_training_stage(task)

                    while sess.run([flag])[0] != (task + 1):
                        time.sleep(2)

                    # Set fixed_path
                    fixed_path = np.zeros((FLAGS.L, FLAGS.M), dtype=float)
                    for i in range(FLAGS.L):
                        for j in range(FLAGS.M):
                            if sess.run([fixed_path_tf[i, j]])[0] == 1:
                                fixed_path[i, j] = 1.0
                    training_thread.local_network.set_fixed_path(fixed_path)
                    # set start_time
                    wall_t = 0.0
                    start_time = time.time() - wall_t
                    training_thread.set_start_time(start_time)
                    while True:
                        if sess.run([global_step])[0] > (MAX_TIME_STEP * (task + 1)):
                            break
                        diff_global_t = training_thread.process(
                            sess, sess.run([global_step])[0], "", summary_op, "",
                            score_ph, score_ops, "", FLAGS,
                            score_set_ph[FLAGS.task_index],
                            score_set_ops[FLAGS.task_index])
                        sess.run(global_step_ops,
                                 {global_step_ph: sess.run([global_step])[0] + diff_global_t})
            else:
                fixed_path = np.zeros((FLAGS.L, FLAGS.M), dtype=float)
                vars_backup = np.zeros(len(vars_), dtype=object)
                vars_backup = sess.run(vars_)
                winner_idx = 0

                vis = visualize.GraphVisualize([FLAGS.M] * FLAGS.L, True)

                for task in range(2):
                    # Randomly generate a geopath for each training worker
                    geopath_set = np.zeros(FLAGS.worker_hosts_num - 1, dtype=object)
                    for i in range(FLAGS.worker_hosts_num - 1):
                        geopath_set[i] = pathnet.get_geopath(FLAGS.L, FLAGS.M, FLAGS.N)
                        tmp = np.zeros((FLAGS.L, FLAGS.M), dtype=float)
                        for j in range(FLAGS.L):
                            for k in range(FLAGS.M):
                                if (geopath_set[i][j, k] == 1.0) or (fixed_path[j, k] == 1.0):
                                    tmp[j, k] = 1.0
                        pathnet.geopath_insert(
                            sess,
                            training_thread.local_network.geopath_update_placeholders_set[i],
                            training_thread.local_network.geopath_update_ops_set[i],
                            tmp, FLAGS.L, FLAGS.M)
                    print("Geopath Setting Done")
                    sess.run(flag_ops, {flag_ph: (task + 1)})
                    print("=============Task " + str(task + 1) + "============")
                    score_subset = np.zeros(FLAGS.B, dtype=float)
                    score_set_print = np.zeros(FLAGS.worker_hosts_num, dtype=float)
                    rand_idx = np.arange(FLAGS.worker_hosts_num - 1)
                    np.random.shuffle(rand_idx)
                    rand_idx = rand_idx[:FLAGS.B]
                    while sess.run([global_step])[0] <= (MAX_TIME_STEP * (task + 1)):
                        # if (sess.run([global_step])[0]) % 1000 == 0:
                        #     print("Saving summary...")
                        #     tf.logging.info('Running Summary operation on the chief.')
                        #     summary_str = sess.run(summary_op)
                        #     sv.summary_computed(sess, summary_str)
                        #     tf.logging.info('Finished running Summary operation.')
                        #
                        #     # Determine the next time for running the summary.


                        decodePath = lambda p: [np.where(l == 1.0)[0] for l in p]

                        flag_sum = 0
                        for i in range(FLAGS.worker_hosts_num - 1):
                            score_set_print[i] = sess.run([score_set[i]])[0]
                        for i in range(len(rand_idx)):
                            score_subset[i] = sess.run([score_set[rand_idx[i]]])[0]
                            if score_subset[i] == -1000:
                                flag_sum = 1
                                break
                        if flag_sum == 0:
                            vispaths = [np.array(decodePath(p)) for p in geopath_set]
                            vis.show(vispaths, 'm')

                            winner_idx = rand_idx[np.argmax(score_subset)]
                            print(str(sess.run([global_step])[0]) + " Step Score: " +
                                  str(sess.run([score_set[winner_idx]])[0]))
                            for i in rand_idx:
                                if i != winner_idx:
                                    geopath_set[i] = np.copy(geopath_set[winner_idx])
                                    geopath_set[i] = pathnet.mutation(geopath_set[i], FLAGS.L, FLAGS.M, FLAGS.N)
                                    tmp = np.zeros((FLAGS.L, FLAGS.M), dtype=float)
                                    for j in range(FLAGS.L):
                                        for k in range(FLAGS.M):
                                            if (geopath_set[i][j, k] == 1.0) or (fixed_path[j, k] == 1.0):
                                                tmp[j, k] = 1.0
                                    pathnet.geopath_insert(
                                        sess,
                                        training_thread.local_network.geopath_update_placeholders_set[i],
                                        training_thread.local_network.geopath_update_ops_set[i],
                                        tmp, FLAGS.L, FLAGS.M)
                                sess.run(score_set_ops[i], {score_set_ph[i]: -1000})
                            rand_idx = np.arange(FLAGS.worker_hosts_num - 1)
                            np.random.shuffle(rand_idx)
                            rand_idx = rand_idx[:FLAGS.B]
                        else:
                            time.sleep(2)
                    # fixed_path setting
                    fixed_path = geopath_set[winner_idx]

                    vis.set_fixed(decodePath(fixed_path), 'r' if task == 0 else 'g')
                    vis.show(vispaths, 'm')
                    print('fix')
                    for i in range(FLAGS.L):
                        for j in range(FLAGS.M):
                            if fixed_path[i, j] == 1.0:
                                sess.run(fixed_path_ops[i, j], {fixed_path_ph[i, j]: 1})
                    training_thread.local_network.set_fixed_path(fixed_path)

                    # backup fixed vars
                    # FIXED_VARS_BACKUP = training_thread.local_network.get_fixed_vars();
                    # FIXED_VARS_IDX_BACKUP = training_thread.local_network.get_fixed_vars_idx();

                    # initialization of parameters except fixed_path
                    vars_idx = training_thread.local_network.get_vars_idx()
                    for i in range(len(vars_idx)):
                        if vars_idx[i] == 1.0:
                            sess.run(vars_ops[i], {vars_ph[i]: vars_backup[i]})

                vis.waitForButtonPress()
        sv.stop()
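The chief worker in this example runs PathNet's tournament selection: it samples FLAGS.B of the worker geopaths, waits until all of their scores have been reported, copies the best-scoring geopath over the losers, and mutates the copies before the next round. An illustrative NumPy-only sketch of that selection/mutation step, with pathnet.mutation replaced by a hypothetical random bit-flip so the snippet is self-contained (it is not the repository's implementation):

import numpy as np

def tournament_step(geopath_set, scores, B, flip_prob=0.05):
    # geopath_set: list of L x M binary matrices; scores: per-candidate fitness.
    scores = np.asarray(scores)
    candidates = np.random.permutation(len(geopath_set))[:B]
    winner = candidates[np.argmax(scores[candidates])]
    for idx in candidates:
        if idx == winner:
            continue
        # Losers become mutated copies of the winning path.
        mutated = np.copy(geopath_set[winner])
        flips = np.random.rand(*mutated.shape) < flip_prob
        mutated[flips] = 1.0 - mutated[flips]
        geopath_set[idx] = mutated
    return winner

In the distributed version above, the same bookkeeping is routed through the score_set variables on the parameter server, and each overwritten geopath is pushed back into the corresponding worker's network with pathnet.geopath_insert.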