示例#1
0
文件: train.py 项目: luckti/minigo
def train(*tf_records: "Records to train on"):
    """Run estimator training over the given tf.Example record paths.

    The input pipeline and session hooks are selected from FLAGS:
    TPU vs. CPU/GPU execution, and Bigtable vs. file-based records.
    After a successful Bigtable-backed run, the fresh-game watermark
    is advanced.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    estimator = dual_net.get_estimator()

    # Examples consumed per step: every TPU core gets its own batch.
    effective_batch_size = FLAGS.train_batch_size
    if FLAGS.use_tpu:
        effective_batch_size *= FLAGS.num_tpu_cores

    if FLAGS.use_bt:
        game_queue = bigtable_input.GameQueue(
            FLAGS.cbt_project, FLAGS.cbt_instance, FLAGS.cbt_table)
        game_queue_nr = bigtable_input.GameQueue(
            FLAGS.cbt_project, FLAGS.cbt_instance, FLAGS.cbt_table + '-nr')

    if not FLAGS.use_tpu:
        def _input_fn():
            return preprocessing.get_input_tensors(
                FLAGS.train_batch_size,
                tf_records,
                filter_amount=1.0,
                shuffle_buffer_size=FLAGS.shuffle_buffer_size,
                random_rotation=True)

        hooks = [UpdateRatioSessionHook(FLAGS.work_dir),
                 EchoStepCounterHook(output_dir=FLAGS.work_dir)]
    elif FLAGS.use_bt:
        def _input_fn(params):
            return preprocessing.get_tpu_bt_input_tensors(
                game_queue,
                game_queue_nr,
                params['batch_size'],
                number_of_games=FLAGS.window_size,
                random_rotation=True)

        # Hooks are broken with TPUestimator at the moment.
        hooks = []
    else:
        def _input_fn(params):
            return preprocessing.get_tpu_input_tensors(
                params['batch_size'],
                tf_records,
                random_rotation=True)

        # Hooks are broken with TPUestimator at the moment.
        hooks = []

    steps = FLAGS.steps_to_train
    logging.info("Training, steps = %s, batch = %s -> %s examples",
                 steps or '?', effective_batch_size,
                 (steps * effective_batch_size) if steps else '?')
    estimator.train(_input_fn, steps=steps, hooks=hooks)

    if FLAGS.use_bt:
        bigtable_input.set_fresh_watermark(game_queue, FLAGS.window_size)
示例#2
0
文件: train.py 项目: zhiwuya/minigo
def train(*tf_records: "Records to train on"):
    """Train the dual network on the given tf.Example record paths.

    The input pipeline is chosen from FLAGS (TPU vs. CPU/GPU, Bigtable
    vs. file-based records). When reading from Bigtable, a fresh-game
    watermark is set after a successful run; on any failure the
    fresh-game requirement is reset before the exception is re-raised.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    estimator = dual_net.get_estimator()

    # Examples consumed per training step (all TPU cores combined).
    effective_batch_size = FLAGS.train_batch_size
    if FLAGS.use_tpu:
        effective_batch_size *= FLAGS.num_tpu_cores

    if FLAGS.use_tpu:
        if FLAGS.use_bt:
            def _input_fn(params):
                # NOTE(review): queues are built inside the input_fn,
                # presumably so construction happens on the TPU worker.
                games = bigtable_input.GameQueue(
                    FLAGS.cbt_project, FLAGS.cbt_instance, FLAGS.cbt_table)
                games_nr = bigtable_input.GameQueue(
                    FLAGS.cbt_project, FLAGS.cbt_instance, FLAGS.cbt_table + '-nr')
                return preprocessing.get_tpu_bt_input_tensors(
                    games,
                    games_nr,
                    params['batch_size'],
                    params['input_layout'],
                    number_of_games=FLAGS.window_size,
                    random_rotation=True)
        else:
            def _input_fn(params):
                return preprocessing.get_tpu_input_tensors(
                    params['batch_size'],
                    params['input_layout'],
                    tf_records,
                    filter_amount=FLAGS.filter_amount,
                    shuffle_examples=FLAGS.shuffle_examples,
                    shuffle_buffer_size=FLAGS.shuffle_buffer_size,
                    random_rotation=True)
        # Hooks are broken with TPUestimator at the moment.
        hooks = []
    else:
        def _input_fn():
            return preprocessing.get_input_tensors(
                FLAGS.train_batch_size,
                FLAGS.input_layout,
                tf_records,
                filter_amount=FLAGS.filter_amount,
                shuffle_examples=FLAGS.shuffle_examples,
                shuffle_buffer_size=FLAGS.shuffle_buffer_size,
                random_rotation=True)

        hooks = [UpdateRatioSessionHook(FLAGS.work_dir),
                 EchoStepCounterHook(output_dir=FLAGS.work_dir)]

    steps = FLAGS.steps_to_train
    if not steps and FLAGS.num_examples:
        # effective_batch_size already includes the num_tpu_cores factor,
        # so there is no need to recompute it here; integer floor-division
        # also avoids float rounding error for very large example counts.
        steps = FLAGS.num_examples // effective_batch_size

    logging.info("Training, steps = %s, batch = %s -> %s examples",
                 steps or '?', effective_batch_size,
                 (steps * effective_batch_size) if steps else '?')

    if FLAGS.use_bt:
        games = bigtable_input.GameQueue(
            FLAGS.cbt_project, FLAGS.cbt_instance, FLAGS.cbt_table)
        if not games.read_wait_cell():
            games.require_fresh_games(20000)
        latest_game = games.latest_game_number
        index_from = max(latest_game, games.read_wait_cell())
        print("== Last game before training:", latest_game, flush=True)
        print("== Wait cell:", games.read_wait_cell(), flush=True)

    try:
        estimator.train(_input_fn, steps=steps, hooks=hooks)
        if FLAGS.use_bt:
            bigtable_input.set_fresh_watermark(games, index_from,
                                               FLAGS.window_size)
    # Explicit spelling of bare `except:` (flake8 E722). Intentionally
    # catches everything (incl. KeyboardInterrupt) to reset the
    # fresh-game requirement, then re-raises.
    except BaseException:
        if FLAGS.use_bt:
            games.require_fresh_games(0)
        raise
示例#3
0
def train(*tf_records: "Records to train on"):
    """Train the dual network on the given tf.Example record paths.

    The input pipeline is chosen from FLAGS (TPU vs. CPU/GPU, Bigtable
    vs. file-based records, optional Horovod distributed training).
    When reading from Bigtable, a fresh-game watermark is set after a
    successful run; on any failure the fresh-game requirement is reset
    before the exception is re-raised.
    """

    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
    estimator = dual_net.get_estimator(FLAGS.num_intra_threads,
                                       FLAGS.num_inter_threads)

    if FLAGS.dist_train:
        # Per-worker batch; the global batch is spread across hvd ranks.
        effective_batch_size = int(FLAGS.train_batch_size / hvd.size())
        global_batch_size = effective_batch_size * hvd.size()
        mllogger = mllog.get_mllogger()
        mllogger.event(key=mllog.constants.GLOBAL_BATCH_SIZE,
                       value=global_batch_size)
    else:
        effective_batch_size = FLAGS.train_batch_size
        global_batch_size = FLAGS.train_batch_size

    logging.info("Real global batch size = {}, local batch size = {}.".format(
        global_batch_size, effective_batch_size))

    # Examples consumed per training step (all TPU cores combined).
    if FLAGS.use_tpu:
        effective_batch_size *= FLAGS.num_tpu_cores

    if FLAGS.use_tpu:
        if FLAGS.use_bt:

            def _input_fn(params):
                # NOTE(review): queues are built inside the input_fn,
                # presumably so construction happens on the TPU worker.
                games = bigtable_input.GameQueue(FLAGS.cbt_project,
                                                 FLAGS.cbt_instance,
                                                 FLAGS.cbt_table)
                games_nr = bigtable_input.GameQueue(FLAGS.cbt_project,
                                                    FLAGS.cbt_instance,
                                                    FLAGS.cbt_table + '-nr')
                return preprocessing.get_tpu_bt_input_tensors(
                    games,
                    games_nr,
                    params['batch_size'],
                    params['input_layout'],
                    number_of_games=FLAGS.window_size,
                    random_rotation=True)
        else:

            def _input_fn(params):
                return preprocessing.get_tpu_input_tensors(
                    params['batch_size'],
                    params['input_layout'],
                    tf_records,
                    filter_amount=FLAGS.filter_amount,
                    shuffle_examples=FLAGS.shuffle_examples,
                    shuffle_buffer_size=FLAGS.shuffle_buffer_size,
                    random_rotation=True)

        # Hooks are broken with TPUestimator at the moment.
        hooks = []
    else:

        def _input_fn():
            return preprocessing.get_input_tensors(
                effective_batch_size,
                FLAGS.input_layout,
                tf_records,
                filter_amount=FLAGS.filter_amount,
                shuffle_examples=FLAGS.shuffle_examples,
                shuffle_buffer_size=FLAGS.shuffle_buffer_size,
                random_rotation=True,
                seed=FLAGS.training_seed,
                dist_train=FLAGS.dist_train,
                use_bf16=FLAGS.use_bfloat16)

        hooks = [
            UpdateRatioSessionHook(FLAGS.work_dir),
            EchoStepCounterHook(output_dir=FLAGS.work_dir)
        ]
        if FLAGS.dist_train:
            # Rank 0's initial variables are broadcast to all workers.
            hooks.append(hvd.BroadcastGlobalVariablesHook(0))

    steps = FLAGS.steps_to_train
    if not steps and FLAGS.num_examples:
        # BUG FIX: effective_batch_size already absorbed the
        # num_tpu_cores factor above; multiplying by it a second time
        # double-counted the cores and undershot the step count.
        # Integer floor-division also avoids float rounding error.
        steps = FLAGS.num_examples // effective_batch_size

    logging.info("Training, steps = %s, batch = %s -> %s examples", steps
                 or '?', effective_batch_size,
                 (steps * effective_batch_size) if steps else '?')

    if FLAGS.use_bt:
        games = bigtable_input.GameQueue(FLAGS.cbt_project, FLAGS.cbt_instance,
                                         FLAGS.cbt_table)
        if not games.read_wait_cell():
            games.require_fresh_games(20000)
        latest_game = games.latest_game_number
        index_from = max(latest_game, games.read_wait_cell())
        print("== Last game before training:", latest_game, flush=True)
        print("== Wait cell:", games.read_wait_cell(), flush=True)

    try:
        estimator.train(_input_fn, steps=steps, hooks=hooks)
        if FLAGS.use_bt:
            bigtable_input.set_fresh_watermark(games, index_from,
                                               FLAGS.window_size)
    # Explicit spelling of bare `except:` (flake8 E722). Intentionally
    # catches everything (incl. KeyboardInterrupt) to reset the
    # fresh-game requirement, then re-raises.
    except BaseException:
        if FLAGS.use_bt:
            games.require_fresh_games(0)
        raise
示例#4
0
def train(*tf_records: "Records to train on"):
    """Train the dual network on the given tf.Example record paths.

    The input pipeline is chosen from FLAGS (TPU, IPU, or CPU/GPU;
    Bigtable vs. file-based records). Optional data-pipeline / infeed
    benchmarking runs instead of training when the module-level
    DATA_BENCHMARK / INFEED_BENCHMARK globals are set.

    Returns:
        The trained estimator.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    estimator = dual_net.get_estimator()

    # Examples consumed per training step (all accelerator cores combined).
    effective_batch_size = FLAGS.train_batch_size
    if FLAGS.use_tpu:
        effective_batch_size *= FLAGS.num_tpu_cores
    elif FLAGS.use_ipu:
        effective_batch_size *= FLAGS.num_ipu_cores

    if FLAGS.use_tpu:
        if FLAGS.use_bt:

            def _input_fn(params):
                # NOTE(review): queues are built inside the input_fn,
                # presumably so construction happens on the TPU worker.
                games = bigtable_input.GameQueue(FLAGS.cbt_project,
                                                 FLAGS.cbt_instance,
                                                 FLAGS.cbt_table)
                games_nr = bigtable_input.GameQueue(FLAGS.cbt_project,
                                                    FLAGS.cbt_instance,
                                                    FLAGS.cbt_table + '-nr')
                return preprocessing.get_tpu_bt_input_tensors(
                    games,
                    games_nr,
                    params['batch_size'],
                    number_of_games=FLAGS.window_size,
                    random_rotation=True)
        else:

            def _input_fn(params):
                return preprocessing.get_tpu_input_tensors(
                    params['batch_size'], tf_records, random_rotation=True)

        # Hooks are broken with TPUestimator at the moment.
        hooks = []
    elif FLAGS.use_ipu:

        def _input_fn():
            return preprocessing.get_ipu_input_tensors(
                FLAGS.train_batch_size,
                tf_records,
                filter_amount=FLAGS.filter_amount,
                shuffle_buffer_size=FLAGS.shuffle_buffer_size,
                shuffle_examples=False,
                random_rotation=False)

        hooks = []
    else:

        def _input_fn():
            return preprocessing.get_input_tensors(
                FLAGS.train_batch_size,
                tf_records,
                filter_amount=FLAGS.filter_amount,
                shuffle_buffer_size=FLAGS.shuffle_buffer_size,
                random_rotation=True)

        hooks = [
            UpdateRatioSessionHook(FLAGS.work_dir),
            EchoStepCounterHook(output_dir=FLAGS.work_dir)
        ]

    # Best-effort: FLAGS.PROFILING is an optional flag that may not be
    # defined. Was a bare `except: pass`, which also swallowed
    # KeyboardInterrupt/SystemExit; Exception is narrow enough here.
    # NOTE: when profiling is enabled this intentionally REPLACES any
    # hooks chosen above.
    try:
        if FLAGS.PROFILING:
            ph = ProfilerHook()
            hooks = [ph]
    except Exception:
        pass

    steps = FLAGS.steps_to_train

    # Step correction due to smaller batch size: keep the total number
    # of examples equivalent to a batch size of 4096.
    if FLAGS.use_ipu:
        steps = steps * 4096 // effective_batch_size

    logging.info("Training, steps = %s, batch = %s -> %s examples", steps
                 or '?', effective_batch_size,
                 (steps * effective_batch_size) if steps else '?')

    if FLAGS.use_bt:
        games = bigtable_input.GameQueue(FLAGS.cbt_project, FLAGS.cbt_instance,
                                         FLAGS.cbt_table)
        if not games.read_wait_cell():
            games.require_fresh_games(20000)
        latest_game = games.latest_game_number
        index_from = max(latest_game, games.read_wait_cell())
        print("== Last game before training:", latest_game, flush=True)
        print("== Wait cell:", games.read_wait_cell(), flush=True)

    if DATA_BENCHMARK:
        benchmark_op = dataset_benchmark(
            dataset=_input_fn(),
            number_of_epochs=80,
            elements_per_epochs=10000,
            print_stats=True,
            # apply_options=False
        )

        import json
        print("Benchmarking data pipeline:")
        with tf.Session() as sess:
            json_string = sess.run(benchmark_op)
            json_object = json.loads(json_string[0])
        print(json_object)
        if not INFEED_BENCHMARK:
            # Abort instead of training once the benchmark has run.
            raise NotImplementedError("Data benchmark ended.")
        else:
            print("Data benchmark ended.")

    if INFEED_BENCHMARK:
        benchmark_op = infeed_benchmark(
            infeed_queue=ipu_infeed_queue.IPUInfeedQueue(_input_fn(),
                                                         feed_name="infeed"),
            number_of_epochs=80,
            elements_per_epochs=10000,
            print_stats=True,
            # apply_options=False
        )

        import json
        print("Benchmarking data pipeline:")
        with tf.Session() as sess:
            json_string = sess.run(benchmark_op)
            json_object = json.loads(json_string[0])
        print(json_object)
        raise NotImplementedError("Infeed benchmark ended.")

    try:
        estimator.train(_input_fn, steps=steps, hooks=hooks)
        if FLAGS.use_bt:
            bigtable_input.set_fresh_watermark(games, index_from,
                                               FLAGS.window_size)
    # Explicit spelling of bare `except:` (flake8 E722). Intentionally
    # catches everything (incl. KeyboardInterrupt) to reset the
    # fresh-game requirement, then re-raises.
    except BaseException:
        if FLAGS.use_bt:
            games.require_fresh_games(0)
        raise

    return estimator