Code example #1
def get_config():
    M = Model()

    dataflow = data_io  # data_io: a DataFlow/callback defined at module level in the original project
    from tensorpack.callbacks.base import Callback

    class CBSyncWeight(Callback):
        def _before_run(self, ctx):
            # piggyback the predictor weight-sync op onto every 10th training step
            if self.local_step % 10 == 0:
                return [M._sync_op_pred]

    import functools
    from tensorpack.train.config import TrainConfig
    from tensorpack.callbacks.saver import ModelSaver
    from tensorpack.callbacks.graph import RunOp
    from tensorpack.callbacks.param import ScheduledHyperParamSetter, HumanHyperParamSetter, HyperParamSetterWithFunc
    from tensorpack.tfutils import sesscreate
    from tensorpack.tfutils.common import get_default_sess_config
    import tensorpack.tfutils.symbolic_functions as symbf

    sigma_beta_steering = symbf.get_scalar_var('actor/sigma_beta_steering',
                                               0.3,
                                               summary=True,
                                               trainable=False)
    sigma_beta_accel = symbf.get_scalar_var('actor/sigma_beta_accel',
                                            0.3,
                                            summary=True,
                                            trainable=False)

    return TrainConfig(
        model=M,
        data=dataflow,
        callbacks=[
            ModelSaver(),
            HyperParamSetterWithFunc(
                'learning_rate/actor',
                functools.partial(M._calc_learning_rate, 'actor')),
            HyperParamSetterWithFunc(
                'learning_rate/critic',
                functools.partial(M._calc_learning_rate, 'critic')),

            # ScheduledHyperParamSetter('learning_rate', [(20, 0.0003), (120, 0.0001)]),
            ScheduledHyperParamSetter('entropy_beta', [(80, 0.005)]),
            # HumanHyperParamSetter('learning_rate'),
            # HumanHyperParamSetter('entropy_beta'),
            ScheduledHyperParamSetter('actor/sigma_beta_accel', [(1, 0.2),
                                                                 (2, 0.01)]),
            ScheduledHyperParamSetter('actor/sigma_beta_steering',
                                      [(1, 0.1), (2, 0.01)]),
            CBSyncWeight(),
            data_io,
            # PeriodicTrigger(Evaluator(
            #     EVAL_EPISODE, ['state'], ['policy'], get_player),
            #     every_k_epochs=3),
        ] + evaluators,  # evaluators: a list of callbacks defined at module level
        session_creator=sesscreate.NewSessionCreator(
            config=get_default_sess_config(0.5)),
        steps_per_epoch=STEPS_PER_EPOCH,
        max_epoch=1000,
    )
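This snippet only builds the TrainConfig; it still has to be handed to a trainer, as code example #3 below does with launch_train_with_config. A minimal launch sketch, assuming get_config() is defined as above and a single-tower trainer suffices:

from tensorpack import launch_train_with_config, SimpleTrainer

if __name__ == '__main__':
    # Build the config above and run it with the simplest trainer;
    # see code example #3 for the same pattern in context.
    config = get_config()
    launch_train_with_config(config, SimpleTrainer())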
Code example #2
File: DQN-gym-Music.py, Project: pkumusic/tensorpack
def get_config():
    logger.auto_set_dir()
    M = Model()
    lr = tf.Variable(0.001, trainable=False, name='learning_rate')
    tf.scalar_summary('learning_rate', lr)

    dataset_train = ExpReplay()

    return TrainConfig(
        dataset=dataset_train,  # A dataflow object for training
        optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3),
        callbacks=Callbacks([
            StatPrinter(),
            ModelSaver(),
            ScheduledHyperParamSetter('learning_rate',
                                      [(80, 0.0003),
                                       (120, 0.0001)])  # No interpolation
            # TODO: Some other parameters
        ]),
        session_config=get_default_sess_config(
            0.6
        ),  # TensorFlow's default session config consumes too many resources.
        model=M,
        step_per_epoch=STEP_PER_EPOCH,
    )
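This example targets a pre-1.0 TensorFlow: tf.scalar_summary was removed in TensorFlow 1.0 in favor of tf.summary.scalar, which code example #6 already uses. A sketch of the same learning-rate setup under the newer API (the rest of the config is unchanged):

# Under TF >= 1.0 the summary op is created through the tf.summary module.
lr = tf.Variable(0.001, trainable=False, name='learning_rate')
tf.summary.scalar('learning_rate', lr)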
Code example #3
File: train.py, Project: qq456cvb/UKPGAN
def main(cfg):
    print(cfg)
    
    tf.reset_default_graph()
    
    logger.set_logger_dir('tflogs', action='d')

    copyfile(hydra.utils.to_absolute_path('model.py'), 'model.py')
    copyfile(hydra.utils.to_absolute_path('dataflow.py'), 'dataflow.py')
    
    if cfg.cat_name == 'smpl':
        train_df = SMPLDataFlow(cfg, True, 1000)
        val_df = VisSMPLDataFlow(cfg, True, 1000, port=1080)
    else:
        train_df = ShapeNetDataFlow(cfg, cfg.data.train_txt, True)
        val_df = VisDataFlow(cfg, cfg.data.val_txt, False, port=1080)
    
    config = TrainConfig(
        model=Model(cfg),
        dataflow=BatchData(PrefetchData(train_df, cpu_count() // 2, cpu_count() // 2), cfg.batch_size),
        callbacks=[
            ModelSaver(),
            SimpleMovingAverage(['recon_loss', 'GAN/loss_d', 'GAN/loss_g', 'GAN/gp_loss', 'symmetry_loss'], 100),
            PeriodicTrigger(val_df, every_k_steps=30)
        ],
        monitors=tensorpack.train.DEFAULT_MONITORS() + [ScalarPrinter(enable_step=True, enable_epoch=False)],
        max_epoch=10
    )
    launch_train_with_config(config, SimpleTrainer())
Code example #4
def get_config():
    #logger.auto_set_dir()
    #logger.set_logger_dir(os.path.join('train_log', LOG_DIR))
    logger.set_logger_dir(LOG_DIR)
    M = Model()
    #TODO: For count-based model, remove epsilon greedy exploration
    if PC_METHOD:
        global INIT_EXPLORATION, END_EXPLORATION, EXPLORATION_EPOCH_ANNEAL
        INIT_EXPLORATION = 1
        END_EXPLORATION = 0.01
        EXPLORATION_EPOCH_ANNEAL = 0.33
        logger.info("remove epsilon greedy")
    dataset_train = ExpReplay(
        predictor_io_names=(['state'], ['Qvalue']),
        player=get_player(train=True),
        batch_size=BATCH_SIZE,
        memory_size=MEMORY_SIZE,
        init_memory_size=INIT_MEMORY_SIZE,
        exploration=INIT_EXPLORATION,
        end_exploration=END_EXPLORATION,
        exploration_epoch_anneal=EXPLORATION_EPOCH_ANNEAL,
        update_frequency=4,
        reward_clip=(-1, 1),
        history_len=FRAME_HISTORY)

    lr = tf.Variable(0.001, trainable=False, name='learning_rate')
    tf.scalar_summary('learning_rate', lr)

    return TrainConfig(
        dataset=dataset_train,
        optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3),
        callbacks=Callbacks([
            StatPrinter(),
            PeriodicCallback(ModelSaver(), 5),
            ScheduledHyperParamSetter('learning_rate',
                                      [(150, 4e-4), (250, 1e-4), (350, 5e-5)]),
            RunOp(lambda: M.update_target_param()),
            dataset_train,
            PeriodicCallback(Evaluator(EVAL_EPISODE, ['state'], ['Qvalue']),
                             5),
            #HumanHyperParamSetter('learning_rate', 'hyper.txt'),
            #HumanHyperParamSetter(ObjAttrParam(dataset_train, 'exploration'), 'hyper.txt'),
        ]),
        # save memory for multiprocess evaluator
        session_config=get_default_sess_config(0.6),
        model=M,
        step_per_epoch=STEP_PER_EPOCH,
    )
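The schedules passed to ScheduledHyperParamSetter here and in example #2 are (epoch, value) pairs: the parameter is set to the given value once that epoch is reached and, as the "# No interpolation" comment in example #2 notes, stays constant in between. A small sketch, assuming the installed tensorpack supports the interp argument:

# Step schedule: the value jumps at the listed epochs and is constant otherwise.
ScheduledHyperParamSetter('learning_rate',
                          [(150, 4e-4), (250, 1e-4), (350, 5e-5)])

# With interp='linear' (assumption: supported by the installed tensorpack),
# the value is instead interpolated linearly between the listed epochs.
ScheduledHyperParamSetter('learning_rate',
                          [(150, 4e-4), (350, 5e-5)], interp='linear')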
Code example #5
def get_config():
    logger.set_logger_dir(LOG_DIR)
    M = Model()

    name_base = str(uuid.uuid1())[:6]
    PIPE_DIR = os.environ.get('TENSORPACK_PIPEDIR', '.').rstrip('/')
    namec2s = 'ipc://{}/sim-c2s-{}'.format(PIPE_DIR, name_base)
    names2c = 'ipc://{}/sim-s2c-{}'.format(PIPE_DIR, name_base)
    procs = [
        MySimulatorWorker(k, namec2s, names2c) for k in range(SIMULATOR_PROC)
    ]
    ensure_proc_terminate(procs)
    start_proc_mask_signal(procs)

    master = MySimulatorMaster(namec2s, names2c, M)
    dataflow = BatchData(DataFromQueue(master.queue), BATCH_SIZE)

    lr = tf.Variable(0.001, trainable=False, name='learning_rate')
    tf.scalar_summary('learning_rate', lr)

    return TrainConfig(
        dataset=dataflow,
        optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3),
        callbacks=Callbacks([
            StatPrinter(),
            PeriodicCallback(ModelSaver(), 5),
            ScheduledHyperParamSetter('learning_rate', [(80, 0.0003),
                                                        (120, 0.0001)]),
            ScheduledHyperParamSetter('entropy_beta', [(80, 0.005)]),
            ScheduledHyperParamSetter('explore_factor', [(80, 2), (100, 3),
                                                         (120, 4), (140, 5)]),
            HumanHyperParamSetter('learning_rate'),
            HumanHyperParamSetter('entropy_beta'),
            HumanHyperParamSetter('explore_factor'),
            master,
            PeriodicCallback(
                Evaluator(EVAL_EPISODE, ['state'], ['logits'],
                          policy_dist=POLICY_DIST), 5),
        ]),
        extra_threads_procs=[master],
        session_config=get_default_sess_config(0.5),
        model=M,
        step_per_epoch=STEP_PER_EPOCH,
        max_epoch=1000,
    )
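Examples #5 through #9 all repeat the same simulator bootstrap: build unique ipc:// endpoint names, spawn the worker processes, and register them for cleanup. A sketch that isolates the pattern; worker_cls stands in for the project's MySimulatorWorker class:

import os
import uuid

from tensorpack.utils.concurrency import (ensure_proc_terminate,
                                          start_proc_mask_signal)

def start_simulators(worker_cls, num_procs):
    # Unique per-run pipe names so several runs can share TENSORPACK_PIPEDIR.
    name_base = str(uuid.uuid1())[:6]
    pipe_dir = os.environ.get('TENSORPACK_PIPEDIR', '.').rstrip('/')
    namec2s = 'ipc://{}/sim-c2s-{}'.format(pipe_dir, name_base)
    names2c = 'ipc://{}/sim-s2c-{}'.format(pipe_dir, name_base)
    procs = [worker_cls(k, namec2s, names2c) for k in range(num_procs)]
    ensure_proc_terminate(procs)   # kill the children when the parent dies
    start_proc_mask_signal(procs)  # start them with SIGINT masked
    return namec2s, names2c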
Code example #6
def get_config(args=None,
               is_chief=True,
               task_index=0,
               chief_worker_hostname="",
               n_workers=1):
    logger.set_logger_dir(args.train_log_path +
                          datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_' +
                          str(task_index))

    # function to split model parameters between multiple parameter servers
    ps_strategy = tf.contrib.training.GreedyLoadBalancingStrategy(
        len(cluster['ps']), tf.contrib.training.byte_size_load_fn)
    device_function = tf.train.replica_device_setter(
        worker_device='/job:worker/task:{}/cpu:0'.format(task_index),
        cluster=cluster_spec,
        ps_strategy=ps_strategy)

    M = Model(device_function)

    name_base = str(uuid.uuid1()).replace('-', '')[:16]
    PIPE_DIR = os.environ.get('TENSORPACK_PIPEDIR', '.').rstrip('/')
    namec2s = 'ipc://{}/sim-c2s-{}'.format(PIPE_DIR, name_base)
    names2c = 'ipc://{}/sim-s2c-{}'.format(PIPE_DIR, name_base)
    procs = [
        MySimulatorWorker(k, namec2s, names2c)
        for k in range(args.simulator_procs)
    ]
    ensure_proc_terminate(procs)
    start_proc_mask_signal(procs)

    neptune_client = neptune_mp_server.Client(
        server_host=chief_worker_hostname, server_port=args.port)

    master = MySimulatorMaster(task_index,
                               neptune_client,
                               namec2s,
                               names2c,
                               M,
                               dummy=args.dummy,
                               predictor_threads=args.nr_predict_towers,
                               predict_batch_size=args.predict_batch_size,
                               do_train=args.do_train)

    # here's the data passed to the repeated data source
    dataflow = BatchData(DataFromQueue(master.queue), BATCH_SIZE)

    with tf.device(device_function):
        with tf.variable_scope(tf.get_variable_scope(), reuse=None):
            lr = tf.Variable(args.learning_rate,
                             trainable=False,
                             name='learning_rate')
    tf.summary.scalar('learning_rate', lr)

    intra_op_par = args.intra_op_par
    inter_op_par = args.inter_op_par

    session_config = get_default_sess_config(0.5)
    print("{} {}".format(intra_op_par, type(intra_op_par)))
    if intra_op_par is not None:
        session_config.intra_op_parallelism_threads = intra_op_par

    if inter_op_par is not None:
        session_config.inter_op_parallelism_threads = inter_op_par

    session_config.log_device_placement = False
    extra_arg = {
        'dummy_predictor': args.dummy_predictor,
        'intra_op_par': intra_op_par,
        'inter_op_par': inter_op_par,
        'max_steps': args.max_steps,
        'device_count': {
            'CPU': args.cpu_device_count
        },
        'threads_to_trace': args.threads_to_trace,
        'dummy': args.dummy,
        'cpu': args.cpu,
        'queue_size': args.queue_size,
        #'worker_host' : "grpc://localhost:{}".format(cluster['worker'][my_task_index].split(':')[1]),
        'worker_host': server.target,
        'is_chief': is_chief,
        'device_function': device_function,
        'n_workers': n_workers,
        'use_sync_opt': args.use_sync_opt,
        'port': args.port,
        'batch_size': BATCH_SIZE,
        'debug_charts': args.debug_charts,
        'adam_debug': args.adam_debug,
        'task_index': task_index,
        'lr': lr,
        'schedule_hyper': args.schedule_hyper,
        'experiment_dir': args.experiment_dir
    }

    print("\n\n worker host: {} \n\n".format(extra_arg['worker_host']))

    with tf.device(device_function):
        if args.optimizer == 'adam':
            optimizer = tf.train.AdamOptimizer(lr,
                                               epsilon=args.epsilon,
                                               beta1=args.beta1,
                                               beta2=args.beta2)
            if args.adam_debug:
                optimizer = MyAdamOptimizer(lr,
                                            epsilon=args.epsilon,
                                            beta1=args.beta1,
                                            beta2=args.beta2)
        elif args.optimizer == 'gd':
            optimizer = tf.train.GradientDescentOptimizer(lr)
        elif args.optimizer == 'adagrad':
            optimizer = tf.train.AdagradOptimizer(lr)
        elif args.optimizer == 'adadelta':
            optimizer = tf.train.AdadeltaOptimizer(lr, epsilon=1e-3)
        elif args.optimizer == 'momentum':
            optimizer = tf.train.MomentumOptimizer(lr, momentum=0.9)
        elif args.optimizer == 'rms':
            optimizer = tf.train.RMSPropOptimizer(lr)
        else:
            raise ValueError('Unknown optimizer: {}'.format(args.optimizer))

        # wrap in SyncReplicasOptimizer
        if args.use_sync_opt == 1:
            if not args.adam_debug:
                optimizer = tf.train.SyncReplicasOptimizer(
                    optimizer,
                    replicas_to_aggregate=args.num_grad,
                    total_num_replicas=n_workers)
            else:
                optimizer = MySyncReplicasOptimizer(
                    optimizer,
                    replicas_to_aggregate=args.num_grad,
                    total_num_replicas=n_workers)
            extra_arg['hooks'] = optimizer.make_session_run_hook(is_chief)

    callbacks = [
        StatPrinter(), master,
        DebugLogCallback(neptune_client,
                         worker_id=task_index,
                         nr_send=args.send_debug_every,
                         debug_charts=args.debug_charts,
                         adam_debug=args.adam_debug,
                         schedule_hyper=args.schedule_hyper)
    ]

    if args.debug_charts:
        callbacks.append(
            HeartPulseCallback('heart_pulse_{}.log'.format(
                os.environ['SLURMD_NODENAME'])))

    if args.early_stopping is not None:
        args.early_stopping = float(args.early_stopping)

        if task_index == 1 and not args.eval_node:
            # only one worker does evaluation
            callbacks.append(
                PeriodicCallback(
                    Evaluator(EVAL_EPISODE, ['state'], ['logits'],
                              neptune_client,
                              worker_id=task_index,
                              solved_score=args.early_stopping), 2))
    elif task_index == 1 and not args.eval_node:
        # only one worker does evaluation
        callbacks.append(
            PeriodicCallback(
                Evaluator(EVAL_EPISODE, ['state'], ['logits'],
                          neptune_client,
                          worker_id=task_index), 2))

    if args.save_every != 0:
        callbacks.append(
            PeriodicPerStepCallback(
                ModelSaver(var_collections=M.vars_for_save,
                           models_dir=args.models_dir), args.save_every))

    if args.schedule_hyper and task_index == 2:
        callbacks.append(
            HyperParameterScheduler('learning_rate', [(20, 0.0005),
                                                      (60, 0.0001)]))
        callbacks.append(
            HyperParameterScheduler('entropy_beta', [(40, 0.005),
                                                     (80, 0.001)]))

    return TrainConfig(dataset=dataflow,
                       optimizer=optimizer,
                       callbacks=Callbacks(callbacks),
                       extra_threads_procs=[master],
                       session_config=session_config,
                       model=M,
                       step_per_epoch=STEP_PER_EPOCH,
                       max_epoch=args.max_epoch,
                       extra_arg=extra_arg)
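This function leans on module-level names the snippet does not show: cluster (a dict of job hosts), cluster_spec, and server (whose .target is passed as worker_host). A hypothetical sketch of how these are typically built with the TF 1.x distributed API; the host lists are placeholders, not from the source:

import tensorflow as tf

# Placeholder hosts (assumption); the real lists come from the job launcher.
cluster = {
    'ps': ['ps0.example.com:2222'],
    'worker': ['worker0.example.com:2222', 'worker1.example.com:2222'],
}
cluster_spec = tf.train.ClusterSpec(cluster)
server = tf.train.Server(cluster_spec, job_name='worker', task_index=0)
# server.target is the grpc:// URL handed to extra_arg['worker_host'] above.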
Code example #7
File: train-atari.py, Project: deepsense-ai/BA3C-CPU
def get_config(args=None):
    logger.set_logger_dir(args.train_log_path)
    #logger.auto_set_dir()
    M = Model()

    name_base = str(uuid.uuid1())[:6]
    PIPE_DIR = os.environ.get('TENSORPACK_PIPEDIR', '.').rstrip('/')
    namec2s = 'ipc://{}/sim-c2s-{}'.format(PIPE_DIR, name_base)
    names2c = 'ipc://{}/sim-s2c-{}'.format(PIPE_DIR, name_base)
    procs = [MySimulatorWorker(k, namec2s, names2c) for k in range(args.simulator_procs)]
    ensure_proc_terminate(procs)
    start_proc_mask_signal(procs)

    master = MySimulatorMaster(namec2s, names2c, M, dummy=args.dummy,
                               predictor_threads=args.nr_predict_towers, predict_batch_size=args.predict_batch_size,
                               do_train=args.do_train)
    
    #here's the data passed to the repeated data source
    dataflow = BatchData(DataFromQueue(master.queue), BATCH_SIZE)
    dataflow = DelayingDataSource(dataflow, args.data_source_delay)

    lr = tf.Variable(args.learning_rate, trainable=False, name='learning_rate')
    tf.scalar_summary('learning_rate', lr)

    intra_op_par = args.intra_op_par
    inter_op_par = args.inter_op_par

    session_config = get_default_sess_config(0.5)
    if intra_op_par is not None:
        session_config.intra_op_parallelism_threads = intra_op_par

    if inter_op_par is not None:
        session_config.inter_op_parallelism_threads = inter_op_par

    session_config.log_device_placement = False
    extra_arg = {
        'dummy_predictor': args.dummy_predictor,
        'intra_op_par': intra_op_par,
        'inter_op_par': inter_op_par,
        'max_steps': args.max_steps,
        'device_count': {'CPU': args.cpu_device_count},
        'threads_to_trace': args.threads_to_trace,
        'dummy': args.dummy,
        'cpu': args.cpu,
        'queue_size': args.queue_size
    }

    return TrainConfig(
        dataset=dataflow,
        optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3),
        callbacks=Callbacks([
            StatPrinter(), ModelSaver(),

            ScheduledHyperParamSetter('learning_rate', [(80, 0.0003), (120, 0.0001)]),
            ScheduledHyperParamSetter('entropy_beta', [(80, 0.005)]),
            ScheduledHyperParamSetter('explore_factor',
                [(80, 2), (100, 3), (120, 4), (140, 5)]),

            HumanHyperParamSetter('learning_rate'),
            HumanHyperParamSetter('entropy_beta'),
            HumanHyperParamSetter('explore_factor'),
            master,
            PeriodicCallback(Evaluator(EVAL_EPISODE, ['state'], ['logits']), args.epochs_for_evaluation),
        ]),
        extra_threads_procs=[master],
        session_config=session_config,
        model=M,
        step_per_epoch=STEP_PER_EPOCH,
        max_epoch=args.max_epoch,
        extra_arg=extra_arg
    )
Code example #8
def get_config(ctx):
    """ We use addiional id to make it possible to run multiple instances of the same code
    We use the neputne id for an easy reference.
    piotr.milos@codilime
    """
    global HISTORY_LOGS, EXPERIMENT_ID  # Ugly hack, make it better at some point, maybe ;)
    id = ctx.job.id
    EXPERIMENT_ID = hash(id)

    import montezuma_env

    ctx.job.register_action(
        "Set starting point processor:", lambda str: set_motezuma_env_options(
            str, montezuma_env.STARTING_POINT_SELECTOR))
    ctx.job.register_action(
        "Set rewards:",
        lambda str: set_motezuma_env_options(str, montezuma_env.REWARDS_FILE))

    logger.auto_set_dir(suffix=id)

    # (self, parameters, number_of_actions, input_shape)

    M = EXPERIMENT_MODEL

    name_base = str(uuid.uuid1())[:6]
    PIPE_DIR = os.environ.get('TENSORPACK_PIPEDIR_{}'.format(id),
                              '.').rstrip('/')
    namec2s = 'ipc://{}/sim-c2s-{}-{}'.format(PIPE_DIR, name_base, id)
    names2c = 'ipc://{}/sim-s2c-{}-{}'.format(PIPE_DIR, name_base, id)
    procs = [
        MySimulatorWorker(k, namec2s, names2c) for k in range(SIMULATOR_PROC)
    ]
    ensure_proc_terminate(procs)
    start_proc_mask_signal(procs)

    master = MySimulatorMaster(namec2s, names2c, M)
    dataflow = BatchData(DataFromQueue(master.queue), BATCH_SIZE)

    # My stuff - PM
    neptuneLogger = NeptuneLogger.get_instance()
    lr = tf.Variable(0.001, trainable=False, name='learning_rate')
    tf.scalar_summary('learning_rate', lr)
    num_epochs = get_atribute(ctx, "num_epochs", 100)

    rewards_str = get_atribute(ctx, "rewards", "5 1 -200")
    with open(montezuma_env.REWARDS_FILE, "w") as file:
        file.write(rewards_str)

    if hasattr(ctx.params, "learning_rate_schedule"):
        schedule_str = str(ctx.params.learning_rate_schedule)
    else:  # Default value inherited from tensorpack
        schedule_str = "[[80, 0.0003], [120, 0.0001]]"
    logger.info("Setting learning rate schedule: {}".format(schedule_str))
    learning_rate_scheduler = ScheduledHyperParamSetter(
        'learning_rate', json.loads(schedule_str))

    if hasattr(ctx.params, "entropy_beta_schedule"):
        schedule_str = str(ctx.params.entropy_beta_schedule)
    else:  # Default value inherited from tensorpack
        schedule_str = "[[80, 0.0003], [120, 0.0001]]"
    logger.info("Setting entropy beta schedule: {}".format(schedule_str))
    entropy_beta_scheduler = ScheduledHyperParamSetter(
        'entropy_beta', json.loads(schedule_str))

    if hasattr(ctx.params, "explore_factor_schedule"):
        schedule_str = str(ctx.params.explore_factor_schedule)
    else:  # Default value inherited from tensorpack
        schedule_str = "[[80, 2], [100, 3], [120, 4], [140, 5]]"
    logger.info("Setting explore factor schedule: {}".format(schedule_str))
    explore_factor_scheduler = ScheduledHyperParamSetter(
        'explore_factor', json.loads(schedule_str))

    return TrainConfig(
        dataset=dataflow,
        optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3),
        callbacks=Callbacks([
            StatPrinter(),
            ModelSaver(),
            learning_rate_scheduler,
            entropy_beta_scheduler,
            explore_factor_scheduler,
            HumanHyperParamSetter('learning_rate'),
            HumanHyperParamSetter('entropy_beta'),
            HumanHyperParamSetter('explore_factor'),
            NeputneHyperParamSetter('learning_rate', ctx),
            NeputneHyperParamSetter('entropy_beta', ctx),
            NeputneHyperParamSetter('explore_factor', ctx),
            master,
            StartProcOrThread(master),
            PeriodicCallback(
                Evaluator(EVAL_EPISODE, ['state'], ['logits'], neptuneLogger,
                          HISTORY_LOGS), 1),
            neptuneLogger,
        ]),
        session_config=get_default_sess_config(0.5),
        model=M,
        step_per_epoch=STEP_PER_EPOCH,
        max_epoch=num_epochs,
    )
Code example #9
File: main.py, Project: waxz/ppo_torcs
def get_config():
    M = Model()

    name_base = str(uuid.uuid1())[:6]
    PIPE_DIR = os.environ.get('TENSORPACK_PIPEDIR',
                              '/tmp/.ipcpipe').rstrip('/')
    if not os.path.exists(PIPE_DIR):
        os.makedirs(PIPE_DIR)
    else:
        os.system('rm -f {}/sim-*'.format(PIPE_DIR))
    namec2s = 'ipc://{}/sim-c2s-{}'.format(PIPE_DIR, name_base)
    names2c = 'ipc://{}/sim-s2c-{}'.format(PIPE_DIR, name_base)
    # AgentTorcs * SIMULATOR_PROC, AgentReplay * SIMULATOR_PROC
    procs = [
        MySimulatorWorker(k, namec2s, names2c)
        for k in range(SIMULATOR_PROC * 2)
    ]
    ensure_proc_terminate(procs)
    start_proc_mask_signal(procs)

    master = MySimulatorMaster(namec2s, names2c, M)
    dataflow = BatchData(DataFromQueue(master.queue), BATCH_SIZE)

    class CBSyncWeight(Callback):
        def _after_run(self, ctx, _):
            # Note: tensorpack does not use the return value of _after_run;
            # extra fetches can only be requested from _before_run.
            if self.local_step > 1 and self.local_step % SIMULATOR_PROC == 0:
                # print("before step ",self.local_step)
                return [M._td_sync_op]

        def _before_run(self, ctx):
            if self.local_step % 10 == 0:
                return [M._sync_op, M._td_sync_op]
            if self.local_step % SIMULATOR_PROC == 0 and 0:  # 'and 0' deliberately disables this branch
                return [M._td_sync_op]

    import functools
    return TrainConfig(
        model=M,
        dataflow=dataflow,
        callbacks=[
            ModelSaver(),
            HyperParamSetterWithFunc(
                'learning_rate/actor',
                functools.partial(M._calc_learning_rate, 'actor')),
            HyperParamSetterWithFunc(
                'learning_rate/critic',
                functools.partial(M._calc_learning_rate, 'critic')),

            # ScheduledHyperParamSetter('learning_rate', [(20, 0.0003), (120, 0.0001)]),
            ScheduledHyperParamSetter('entropy_beta', [(80, 0.005)]),
            # HumanHyperParamSetter('learning_rate'),
            # HumanHyperParamSetter('entropy_beta'),
            # ScheduledHyperParamSetter('actor/sigma_beta_accel', [(1, 0.2), (2, 0.01), (3, 1e-3), (4, 1e-4)]),
            # ScheduledHyperParamSetter('actor/sigma_beta_steering', [(1, 0.1), (2, 0.01), (3, 1e-3), (4, 1e-4)]),
            master,
            StartProcOrThread(master),
            CBSyncWeight(),
            # CBTDSyncWeight()
            # PeriodicTrigger(Evaluator(
            #     EVAL_EPISODE, ['state'], ['policy'], get_player),
            #     every_k_epochs=3),
        ],
        session_creator=sesscreate.NewSessionCreator(
            config=get_default_sess_config(0.5)),
        steps_per_epoch=STEPS_PER_EPOCH,
        max_epoch=1000,
    )
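The CBSyncWeight callbacks in examples #1 and #9 rely on tensorpack running whatever fetches _before_run returns together with the training step. A minimal standalone sketch of that piggybacking pattern; sync_op stands in for M._sync_op / M._td_sync_op from the examples:

from tensorpack.callbacks.base import Callback

class SyncEveryN(Callback):
    """Run a sync op together with every n-th training step."""

    def __init__(self, sync_op, every_n=10):
        self._sync_op = sync_op
        self._every_n = every_n

    def _before_run(self, ctx):
        # Returning a list of fetches asks tensorpack to run them
        # alongside this training step's session.run call.
        if self.local_step % self._every_n == 0:
            return [self._sync_op]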