Example 1
def get_config():
    logger.auto_set_dir()

    M = Model()
    expreplay = ExpReplay(predictor_io_names=(['state'], ['Qvalue']),
                          player=get_player(train=True),
                          batch_size=BATCH_SIZE,
                          memory_size=MEMORY_SIZE,
                          init_memory_size=INIT_MEMORY_SIZE,
                          exploration=INIT_EXPLORATION,
                          end_exploration=END_EXPLORATION,
                          exploration_epoch_anneal=EXPLORATION_EPOCH_ANNEAL,
                          update_frequency=4,
                          reward_clip=(-1, 1),
                          history_len=FRAME_HISTORY)

    return TrainConfig(
        dataflow=expreplay,
        callbacks=[
            ModelSaver(),
            ScheduledHyperParamSetter('learning_rate',
                                      [(150, 4e-4), (250, 1e-4), (350, 5e-5)]),
            RunOp(lambda: M.update_target_param()),
            expreplay,
            StartProcOrThread(expreplay.get_simulator_thread()),
            PeriodicCallback(Evaluator(EVAL_EPISODE, ['state'], ['Qvalue']),
                             3),
            # HumanHyperParamSetter('learning_rate', 'hyper.txt'),
            # HumanHyperParamSetter(ObjAttrParam(expreplay, 'exploration'), 'hyper.txt'),
        ],
        model=M,
        steps_per_epoch=STEP_PER_EPOCH,
        # run the simulator on a separate GPU if available
        predict_tower=[1] if get_nr_gpu() > 1 else [0],
    )
Example 2
def get_config(args, env_conf):
    expreplay = ExpReplay(predictor_io_names=(['state'], ['Qvalue']),
                          player=get_player(env_conf, train=True),
                          state_shape=tuple(env_conf["observation_shape"]),
                          batch_size=args.batch_size,
                          memory_size=MEMORY_SIZE,
                          init_memory_size=INIT_MEMORY_SIZE,
                          init_exploration=1.0,
                          update_frequency=UPDATE_FREQ,
                          history_len=1)

    return AutoResumeTrainConfig(
        data=QueueInput(expreplay),
        model=Model(args, env_conf),
        callbacks=[
            ModelSaver(),
            PeriodicTrigger(RunOp(DQNModel.update_target_param, verbose=True),
                            every_k_steps=TARGET_NET_UPDATE
                            ),  # update target network every 10k steps
            expreplay,
            ScheduledHyperParamSetter('learning_rate', LEARNING_RATE),
            ScheduledHyperParamSetter(
                ObjAttrParam(expreplay, 'exploration'),
                EXPLORATION,  # 1->0.1 in the first million steps
                interp='linear'),
            PeriodicTrigger(LogVisualizeEpisode(['state'], ['Qvalue'],
                                                get_player_test(env_conf)),
                            every_k_epochs=args.log_period),
            HumanHyperParamSetter('learning_rate'),
        ],
        steps_per_epoch=STEPS_PER_EPOCH,
        max_epoch=800,
    )
Example 3
def get_config(
    files_list,
    input_names=["state_1", "state_2"],
    output_names=["Qvalue_1", "Qvalue_2"],
    agents=2,
):
    """This is only used during training."""
    expreplay = ExpReplay(
        predictor_io_names=(input_names, output_names),
        player=get_player(task="train", files_list=files_list, agents=agents),
        state_shape=IMAGE_SIZE,
        batch_size=BATCH_SIZE,
        memory_size=MEMORY_SIZE,
        init_memory_size=INIT_MEMORY_SIZE,
        init_exploration=1.0,
        update_frequency=UPDATE_FREQ,
        history_len=FRAME_HISTORY,
        agents=agents,
    )

    return TrainConfig(
        # dataflow=expreplay,
        data=QueueInput(expreplay),
        model=Model(agents=agents),
        callbacks=[
            ModelSaver(),
            PeriodicTrigger(
                RunOp(DQNModel.update_target_param, verbose=True),
                # update target network every 10k steps
                every_k_steps=10000 // UPDATE_FREQ,
            ),
            expreplay,
            ScheduledHyperParamSetter("learning_rate", [(60, 4e-4),
                                                        (100, 2e-4)]),
            ScheduledHyperParamSetter(
                ObjAttrParam(expreplay, "exploration"),
                # 1->0.1 in the first million steps
                [(0, 1), (10, 0.1), (320, 0.01)],
                interp="linear",
            ),
            PeriodicTrigger(
                Evaluator(
                    nr_eval=EVAL_EPISODE,
                    input_names=input_names,
                    output_names=output_names,
                    files_list=files_list,
                    get_player_fn=get_player,
                    agents=agents,
                ),
                every_k_epochs=EPOCHS_PER_EVAL,
            ),
            HumanHyperParamSetter("learning_rate"),
        ],
        steps_per_epoch=STEPS_PER_EPOCH,
        max_epoch=1000,
    )
Example 4
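# Unlike the other examples, this one drives ExpReplay directly as a Python
# generator: a hand-written loop feeds each batch to adapter.train_step(),
# decays epsilon manually, and logs replay statistics to TensorBoard every
# 5000 steps, with no tensorpack TrainConfig involved.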
def main():
    BATCH_SIZE = 64
    IMAGE_SIZE = (84, 75)
    FRAME_HISTORY = 4
    UPDATE_FREQ = 4  # the number of new state transitions per parameter update (per training step)
    MEMORY_SIZE = 10**6
    INIT_MEMORY_SIZE = MEMORY_SIZE // 20
    NUM_PARALLEL_PLAYERS = 1
    MIN_EPSILON = 0.1
    START_EPSILON = 1.0
    STOP_EPSILON_DECAY_AT = 250000

    adapter = TfAdapter(IMAGE_SIZE, FRAME_HISTORY)
    # adapter = TorchAdapter()
    summary_writer = tf.summary.create_file_writer(
        datetime.datetime.now().strftime('logs/%d-%m-%Y_%H-%M'))
    expreplay = ExpReplay(adapter.infer,
                          get_player=get_player,
                          num_parallel_players=NUM_PARALLEL_PLAYERS,
                          state_shape=IMAGE_SIZE,
                          batch_size=BATCH_SIZE,
                          memory_size=MEMORY_SIZE,
                          init_memory_size=INIT_MEMORY_SIZE,
                          update_frequency=UPDATE_FREQ,
                          history_len=FRAME_HISTORY,
                          state_dtype=np.float32)
    expreplay._before_train()
    for step_idx, batch in enumerate(expreplay):
        adapter.train_step(batch)
        if expreplay.exploration > MIN_EPSILON:
            expreplay.exploration -= (START_EPSILON -
                                      MIN_EPSILON) / STOP_EPSILON_DECAY_AT
        if step_idx > 0 and step_idx % 5000 == 0:
            adapter.update_target()
            mean_score, max_score = expreplay.runner.reset_stats()
            with summary_writer.as_default():
                tf.summary.scalar('expreplay/mean_score', mean_score, step_idx)
                tf.summary.scalar('expreplay/max_score', max_score, step_idx)
                summary_writer.flush()
Example 5
def get_config():
    """This is only used during training."""
    expreplay = ExpReplay(predictor_io_names=(['state'], ['Qvalue']),
                          player=get_player(directory=data_dir,
                                            task='train',
                                            files_list=train_data_fpaths),
                          state_shape=OBSERVATION_DIMS,
                          batch_size=BATCH_SIZE,
                          memory_size=MEMORY_SIZE,
                          init_memory_size=INIT_MEMORY_SIZE,
                          init_exploration=1.0,
                          update_frequency=UPDATE_FREQ,
                          frame_history_len=FRAME_HISTORY)

    return TrainConfig(
        # dataflow=expreplay,
        data=QueueInput(expreplay),
        model=Model(),
        callbacks=[  # TODO: periodically save videos
            ModelSaver(checkpoint_dir="model_checkpoints",
                       keep_checkpoint_every_n_hours=0.25,
                       max_to_keep=1000),  # TODO: originally just ModelSaver()
            PeriodicTrigger(
                RunOp(DQNModel.update_target_param, verbose=True),
                # update target network every 10k/freq steps
                every_k_steps=10000 // UPDATE_FREQ),
            # expreplay,
            ScheduledHyperParamSetter('learning_rate', [(60, 4e-4),
                                                        (100, 2e-4)]),
            ScheduledHyperParamSetter(
                ObjAttrParam(expreplay, 'exploration'),
                # 1->0.1 in the first 10M steps
                [(0, 1), (100, 0.1), (120, 0.01)],
                interp='linear'),
            PeriodicTrigger(  # runs expreplay._trigger()
                expreplay, every_k_steps=5000),
            PeriodicTrigger(
                # eval_model_multithread(pred, EVAL_EPISODE, get_player)
                Evaluator(nr_eval=EVAL_EPISODE,
                          input_names=['state'],
                          output_names=['Qvalue'],
                          directory=data_dir,
                          files_list=test_data_fpaths,
                          get_player_fn=get_player),
                every_k_steps=10000 // UPDATE_FREQ),
            HumanHyperParamSetter('learning_rate'),
        ],
        steps_per_epoch=STEPS_PER_EPOCH,
        max_epoch=NUM_EPOCHS,
    )
Example 6
def get_config():
    logger.auto_set_dir()

    M = Model()
    expreplay = ExpReplay(
        predictor_io_names=(['state'], ['Qvalue']),
        player=get_player(train=True),
        state_shape=IMAGE_SIZE,
        batch_size=BATCH_SIZE,
        memory_size=MEMORY_SIZE,
        init_memory_size=INIT_MEMORY_SIZE,
        exploration=INIT_EXPLORATION,
        end_exploration=END_EXPLORATION,
        exploration_epoch_anneal=EXPLORATION_EPOCH_ANNEAL,
        update_frequency=4,
        history_len=FRAME_HISTORY
    )

    def update_target_param():
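        # Build a TF op that copies each online-network variable into its
        # 'target/...' counterpart, i.e. syncs the target network.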
        train_vars = tf.trainable_variables()
        ops = []
        G = tf.get_default_graph()
        for v in train_vars:
            target_name = v.op.name
            if target_name.startswith('target'):
                new_name = target_name.replace('target/', '')
                logger.info("{} <- {}".format(target_name, new_name))
                ops.append(v.assign(G.get_tensor_by_name(new_name + ':0')))
        return tf.group(*ops, name='update_target_network')

    return TrainConfig(
        dataflow=expreplay,
        callbacks=[
            ModelSaver(),
            ScheduledHyperParamSetter('learning_rate',
                                      [(150, 4e-4), (250, 1e-4), (350, 5e-5)]),
            RunOp(update_target_param),
            expreplay,
            PeriodicTrigger(Evaluator(
                EVAL_EPISODE, ['state'], ['Qvalue']), every_k_epochs=5),
            # HumanHyperParamSetter('learning_rate', 'hyper.txt'),
            # HumanHyperParamSetter(ObjAttrParam(expreplay, 'exploration'), 'hyper.txt'),
        ],
        model=M,
        steps_per_epoch=STEPS_PER_EPOCH,
        # run the simulator on a separate GPU if available
        predict_tower=[1] if get_nr_gpu() > 1 else [0],
    )
Example 7
def get_config(model):
    global args
    expreplay = ExpReplay(predictor_io_names=(['state'], ['Qvalue']),
                          get_player=lambda: get_player(train=True),
                          num_parallel_players=NUM_PARALLEL_PLAYERS,
                          state_shape=model.state_shape,
                          batch_size=BATCH_SIZE,
                          memory_size=MEMORY_SIZE,
                          init_memory_size=INIT_MEMORY_SIZE,
                          update_frequency=UPDATE_FREQ,
                          history_len=FRAME_HISTORY,
                          state_dtype=model.state_dtype.as_numpy_dtype)

    # Set to other values if you need a different initial exploration
    # (e.g., if you're resuming a training half-way)
    # expreplay.exploration = 1.0

    return TrainConfig(
        data=QueueInput(expreplay),
        model=model,
        callbacks=[
            ModelSaver(),
            PeriodicTrigger(
                RunOp(DQNModel.update_target_param, verbose=True),
                every_k_steps=10000 //
                UPDATE_FREQ),  # update target network every 10k steps
            expreplay,
            ScheduledHyperParamSetter('learning_rate', [(0, 1e-3), (60, 4e-4),
                                                        (100, 2e-4),
                                                        (500, 5e-5)]),
            ScheduledHyperParamSetter(
                ObjAttrParam(expreplay, 'exploration'),
                [(0, 1), (10, 0.1),
                 (320, 0.01)],  # 1->0.1 in the first million steps
                interp='linear'),
            PeriodicTrigger(Evaluator(EVAL_EPISODE, ['state'], ['Qvalue'],
                                      get_player),
                            every_k_epochs=5 if 'pong' in args.env.lower() else
                            10),  # eval more frequently for easy games
            HumanHyperParamSetter('learning_rate'),
        ],
        steps_per_epoch=STEPS_PER_EPOCH,
        max_epoch=800,
    )
Example 8
def get_config():
    """This is only used during training."""
    expreplay = ExpReplay(predictor_io_names=(['state'], ['Qvalue']),
                          player=get_player(directory=data_dir,
                                            task='train',
                                            files_list=train_list),
                          state_shape=IMAGE_SIZE,
                          batch_size=BATCH_SIZE,
                          memory_size=MEMORY_SIZE,
                          init_memory_size=INIT_MEMORY_SIZE,
                          init_exploration=1.0,
                          update_frequency=UPDATE_FREQ,
                          history_len=FRAME_HISTORY)

    return TrainConfig(
        # dataflow=expreplay,
        data=QueueInput(expreplay),
        model=Model(),
        callbacks=[
            ModelSaver(),
            PeriodicTrigger(
                RunOp(DQNModel.update_target_param, verbose=True),
                # update target network every 10k steps
                every_k_steps=10000 // UPDATE_FREQ),
            expreplay,
            ScheduledHyperParamSetter('learning_rate', [(60, 4e-4),
                                                        (100, 2e-4)]),
            ScheduledHyperParamSetter(
                ObjAttrParam(expreplay, 'exploration'),
                # 1->0.1 in the first million steps
                [(0, 1), (10, 0.1), (320, 0.01)],
                interp='linear'),
            PeriodicTrigger(Evaluator(nr_eval=EVAL_EPISODE,
                                      input_names=['state'],
                                      output_names=['Qvalue'],
                                      directory=data_dir,
                                      files_list=test_list,
                                      get_player_fn=get_player),
                            every_k_epochs=EPOCHS_PER_EVAL),
            HumanHyperParamSetter('learning_rate'),
        ],
        steps_per_epoch=STEPS_PER_EPOCH,
        max_epoch=1000,
    )
Example 9
def get_config():
    expreplay = ExpReplay(
        predictor_io_names=(['state'], ['Qvalue']),
        player=get_player(train=True),
        state_shape=IMAGE_SIZE + (IMAGE_CHANNEL,),
        batch_size=BATCH_SIZE,
        memory_size=MEMORY_SIZE,
        init_memory_size=INIT_MEMORY_SIZE,
        init_exploration=1.0,
        update_frequency=UPDATE_FREQ,
        history_len=FRAME_HISTORY
    )

    return AutoResumeTrainConfig(
        data=QueueInput(expreplay),
        model=Model(),
        callbacks=[
            ModelSaver(),
            PeriodicTrigger(
                RunOp(DQNModel.update_target_param, verbose=True),
                every_k_steps=TARGET_NET_UPDATE),    # update target network every 10k steps
            expreplay,
            ScheduledHyperParamSetter('learning_rate',
                                      LEARNING_RATE),
            ScheduledHyperParamSetter(
                ObjAttrParam(expreplay, 'exploration'),
                EXPLORATION,   # 1->0.1 in the first million steps
                interp='linear'),
            # PeriodicTrigger(Evaluator(
            #     EVAL_EPISODE, ['state'], ['Qvalue'], get_player),
            #     every_k_epochs=10),

            PeriodicTrigger(LogVisualizeEpisode(
                ['state'], ['Qvalue'], get_player),
                every_k_epochs=EPISODE_LOG_PERIOD),

            HumanHyperParamSetter('learning_rate'),
        ],
        steps_per_epoch=STEPS_PER_EPOCH,
        max_epoch=800,
    )
Example 10
def get_config():
    M = Model()
    expreplay = ExpReplay(predictor_io_names=(['state'], ['Qvalue']),
                          player=get_player(train=True),
                          state_shape=IMAGE_SIZE,
                          batch_size=BATCH_SIZE,
                          memory_size=MEMORY_SIZE,
                          init_memory_size=INIT_MEMORY_SIZE,
                          init_exploration=1.0,
                          update_frequency=UPDATE_FREQ,
                          history_len=FRAME_HISTORY)

    return TrainConfig(
        dataflow=expreplay,
        callbacks=[
            ModelSaver(),
            PeriodicTrigger(
                RunOp(DQNModel.update_target_param, verbose=True),
                every_k_steps=10000 //
                UPDATE_FREQ),  # update target network every 10k steps
            expreplay,
            ScheduledHyperParamSetter('learning_rate', [(60, 4e-4),
                                                        (100, 2e-4)]),
            ScheduledHyperParamSetter(
                ObjAttrParam(expreplay, 'exploration'),
                [(0, 1), (10, 0.1),
                 (320, 0.01)],  # 1->0.1 in the first million steps
                interp='linear'),
            PeriodicTrigger(Evaluator(EVAL_EPISODE, ['state'], ['Qvalue'],
                                      get_player),
                            every_k_epochs=10),
            HumanHyperParamSetter('learning_rate'),
        ],
        model=M,
        steps_per_epoch=STEPS_PER_EPOCH,
        max_epoch=1000,
        # run the simulator on a separate GPU if available
        predict_tower=[1] if get_nr_gpu() > 1 else [0],
    )
Example 11
def get_config():
    expreplay = ExpReplay(
        predictor_io_names=(['state', 'history'], ['stage1/Qvalue']),
        predictor_refine_io_names=(['state_refine',
                                    'history_refine'], ['stage2/Qvalue']),
        env=get_player(test=False),
        state_shape=STATE_SHAPE,
        batch_size=BATCH_SIZE,
        memory_size=MEMORY_SIZE,
        init_memory_size=INIT_MEMORY_SIZE,
        init_exploration=1.,
        update_frequency=UPDATE_FREQ)

    # ds = FakeData([(2, 2, *STATE_SHAPE), [2], [2], [2], [2]], dtype=['float32', 'int64', 'float32', 'bool', 'bool'])
    # ds = PrefetchData(ds, nr_prefetch=6, nr_proc=2)
    return AutoResumeTrainConfig(
        data=QueueInput(expreplay),
        model=Model(),
        callbacks=[
            Evaluator(EVAL_EPISODE, ['state', 'history'], ['stage1/Qvalue'],
                      ['state_refine', 'history_refine'], ['stage2/Qvalue'],
                      partial(get_player, True)),
            ModelSaver(),
            PeriodicTrigger(RunOp(DQNModel.update_target_param, verbose=True),
                            every_k_steps=STEPS_PER_EPOCH //
                            10),  # update target network every 10k steps
            expreplay,
            # ScheduledHyperParamSetter('learning_rate',
            #                           [(60, 4e-4), (100, 2e-4)]),
            ScheduledHyperParamSetter(
                ObjAttrParam(expreplay, 'exploration'),
                [(0, 1.), (9, 0.1)],  # 1->0.1 in the first million steps
                interp='linear'),
            HumanHyperParamSetter('learning_rate'),
        ],
        # session_init=SaverRestore("/home/neil/PycharmProjects/RPN-RL-master/save/resnet_v2_50.ckpt"),
        steps_per_epoch=STEPS_PER_EPOCH,
        max_epoch=1000,
    )
Example 12
def get_config():
    expreplay = ExpReplay(
        predictor_io_names=(['state'], ['Qvalue']),
        player=get_player(train=True),
        state_shape=IMAGE_SHAPE3,
        batch_size=BATCH_SIZE,
        memory_size=MEMORY_SIZE,
        init_memory_size=INIT_MEMORY_SIZE,
        init_exploration=1.0,
        update_frequency=UPDATE_FREQ,
        history_len=FRAME_HISTORY
    )
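    # NOTE: `dataflow` and `master` below are not defined in this snippet; they
    # are assumed to be created by the surrounding script (a simulator master
    # and a dataflow over its queue, as in the A3C-style train() of Example 13).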
    config = TrainConfig(
        model=Model(),
        dataflow=dataflow,
        callbacks=[
            ModelSaver(),
            ScheduledHyperParamSetter('learning_rate', [(20, 0.0003), (120, 0.0001)]),
            ScheduledHyperParamSetter('entropy_beta', [(80, 0.005)]),
            HumanHyperParamSetter('learning_rate'),
            HumanHyperParamSetter('entropy_beta'),
            master,
            StartProcOrThread(master),
            PeriodicTrigger(Evaluator(
                EVAL_EPISODE, ['state'], ['policy'], get_player),
                every_k_epochs=3),
            expreplay,
            PeriodicTrigger(LogVisualizeEpisode(
                ['state'], ['policy'], get_player),
                every_k_epochs=1),
        ],
        session_creator=sesscreate.NewSessionCreator(
            config=get_default_sess_config(0.5)),
        steps_per_epoch=STEPS_PER_EPOCH,
        session_init=get_model_loader(args.load) if args.load else None,
        max_epoch=1000,
    )
    return config
Example 13
def train():
    dirname = os.path.join('train_log', 'A3C-{}'.format(ENV_NAME))
    logger.set_logger_dir(dirname)

    # assign GPUs for training & inference
    num_gpu = get_num_gpu()
    global PREDICTOR_THREAD
    if num_gpu > 0:
        if num_gpu > 1:
            # use half gpus for inference
            predict_tower = list(range(num_gpu))[-num_gpu // 2:]
        else:
            predict_tower = [0]
        PREDICTOR_THREAD = len(predict_tower) * PREDICTOR_THREAD_PER_GPU
        train_tower = list(range(num_gpu))[:-num_gpu // 2] or [0]
        logger.info("[Batch-A3C] Train on gpu {} and infer on gpu {}".format(
            ','.join(map(str, train_tower)), ','.join(map(str, predict_tower))))
    else:
        logger.warn("Without GPU this model will never learn! CPU is only useful for debug.")
        PREDICTOR_THREAD = 1
        predict_tower, train_tower = [0], [0]

    # setup simulator processes
    name_base = str(uuid.uuid1())[:6]
    prefix = '@' if sys.platform.startswith('linux') else ''
    namec2s = 'ipc://{}sim-c2s-{}'.format(prefix, name_base)
    names2c = 'ipc://{}sim-s2c-{}'.format(prefix, name_base)
    procs = [MySimulatorWorker(k, namec2s, names2c) for k in range(SIMULATOR_PROC)]
    ensure_proc_terminate(procs)
    start_proc_mask_signal(procs)

    master = MySimulatorMaster(namec2s, names2c, predict_tower)
    dataflow = BatchData(DataFromQueue(master.queue), BATCH_SIZE)
    # config = TrainConfig(
    #     model=Model(),
    #     dataflow=dataflow,
    #     callbacks=[
    #         ModelSaver(),
    #         ScheduledHyperParamSetter('learning_rate', [(20, 0.0003), (120, 0.0001)]),
    #         ScheduledHyperParamSetter('entropy_beta', [(80, 0.005)]),
    #         HumanHyperParamSetter('learning_rate'),
    #         HumanHyperParamSetter('entropy_beta'),
    #         master,
    #         StartProcOrThread(master),
    #         PeriodicTrigger(Evaluator(
    #             EVAL_EPISODE, ['state'], ['policy'], get_player),
    #             every_k_epochs=3),
    #         PeriodicTrigger(LogVisualizeEpisode(
    #             ['state'], ['policy'], get_player),
    #             every_k_epochs=1),
    #     ],
    #     session_creator=sesscreate.NewSessionCreator(
    #         config=get_default_sess_config(0.5)),
    #     steps_per_epoch=STEPS_PER_EPOCH,
    #     session_init=get_model_loader(args.load) if args.load else None,
    #     max_epoch=1000,
    # )
    # config = get_config()
    expreplay = ExpReplay(
        predictor_io_names=(['state'], ['policy']),
        player=get_player(train=True),
        state_shape=IMAGE_SHAPE3,
        batch_size=BATCH_SIZE,
        memory_size=MEMORY_SIZE,
        init_memory_size=INIT_MEMORY_SIZE,
        init_exploration=1.0,
        update_frequency=UPDATE_FREQ,
        history_len=FRAME_HISTORY
    )
    config = TrainConfig(
        model=Model(),
        dataflow=dataflow,
        callbacks=[
            ModelSaver(),
            ScheduledHyperParamSetter('learning_rate', [(20, 0.0003), (120, 0.0001)]),
            ScheduledHyperParamSetter('entropy_beta', [(80, 0.005)]),
            HumanHyperParamSetter('learning_rate'),
            HumanHyperParamSetter('entropy_beta'),
            master,
            StartProcOrThread(master),
            PeriodicTrigger(Evaluator(
                EVAL_EPISODE, ['state'], ['policy'], get_player),
                every_k_epochs=3),
            expreplay,
            ScheduledHyperParamSetter(
                ObjAttrParam(expreplay, 'exploration'),
                [(0, 1), (10, 0.9), (50, 0.1), (320, 0.01)],   # 1->0.1 in the first million steps
                interp='linear'),
            PeriodicTrigger(LogVisualizeEpisode(
                ['state'], ['policy'], get_player),
                every_k_epochs=1),
        ],
        session_creator=sesscreate.NewSessionCreator(
            config=get_default_sess_config(0.5)),
        steps_per_epoch=STEPS_PER_EPOCH,
        session_init=get_model_loader(args.load) if args.load else None,
        max_epoch=1000,
    )
    trainer = SimpleTrainer() if config.nr_tower == 1 else AsyncMultiGPUTrainer(train_tower)
    launch_train_with_config(config, trainer)