Example No. 1
def run_graph(graph, tf_records):

    data_graph = tf.Graph()
    with data_graph.as_default():
        features, labels = preprocessing.get_input_tensors(
            FLAGS.batch_size,
            FLAGS.input_layout,
            tf_records,
            shuffle_buffer_size=100000000,
            random_rotation=FLAGS.random_rotation,
            make_one_shot=True,
            use_bf16=False)

    infer_graph = tf.Graph()
    with infer_graph.as_default():
        tf.import_graph_def(graph, name='')

    input_tensor = dual_net.get_input_tensor(infer_graph)
    output_tensor = dual_net.get_output_tensor(infer_graph)

    config = tf.compat.v1.ConfigProto()
    data_sess = tf.compat.v1.Session(graph=data_graph, config=config)
    infer_sess = tf.compat.v1.Session(graph=infer_graph, config=config)

    elapsed = 0
    for it in range(FLAGS.num_steps):
        features_np = data_sess.run(features)
        start_time = time.time()
        infer_sess.run(output_tensor, feed_dict={input_tensor: features_np})
        elapsed += time.time() - start_time
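A minimal sketch of how run_graph might be driven, assuming a frozen GraphDef on disk; the file names below are hypothetical placeholders and not part of the original example.

import tensorflow as tf

def load_graph_def(path):
    # Read a serialized GraphDef so it can be handed to run_graph.
    graph_def = tf.compat.v1.GraphDef()
    with tf.io.gfile.GFile(path, 'rb') as f:
        graph_def.ParseFromString(f.read())
    return graph_def

graph_def = load_graph_def('frozen_model.pb')            # hypothetical path
run_graph(graph_def, ['selfplay-000.tfrecord.zz'])        # hypothetical records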
Example No. 2
 def input_fn():
     return preprocessing.get_input_tensors(
         FLAGS.train_batch_size,
         tf_records,
         filter_amount=1.0,
         shuffle_buffer_size=FLAGS.shuffle_buffer_size,
         random_rotation=True)
Example No. 3
 def train(self, tf_records, init_from=None, logdir=None, num_steps=None):
     if num_steps is None:
         num_steps = EXAMPLES_PER_GENERATION // TRAIN_BATCH_SIZE
     with self.sess.graph.as_default():
         input_tensors = preprocessing.get_input_tensors(
             TRAIN_BATCH_SIZE, tf_records)
         output_tensors = dual_net(input_tensors,
                                   TRAIN_BATCH_SIZE,
                                   train_mode=True,
                                   **self.hparams)
         train_tensors = train_ops(input_tensors, output_tensors,
                                   **self.hparams)
         weight_tensors = logging_ops()
         self.initialize_weights(init_from)
         if logdir is not None:
             training_stats = StatisticsCollector()
             logger = tf.summary.FileWriter(logdir, self.sess.graph)
         for i in tqdm(range(num_steps)):
             try:
                 tensor_values = self.sess.run(train_tensors)
             except tf.errors.OutOfRangeError:
                 break
             if logdir is not None:
                 training_stats.report(tensor_values['policy_cost'],
                                       tensor_values['value_cost'],
                                       tensor_values['l2_cost'],
                                       tensor_values['combined_cost'])
                 if i % 100 == 0 and logdir is not None:
                     accuracy_summaries = training_stats.collect()
                     weight_summaries = self.sess.run(weight_tensors)
                     global_step = tensor_values['global_step']
                     logger.add_summary(accuracy_summaries, global_step)
                     logger.add_summary(weight_summaries, global_step)
         self.save_weights()
Example No. 4
def run_graph(graph, tf_records):

  data_graph = tf.Graph()
  with data_graph.as_default():
    features, labels = preprocessing.get_input_tensors(
              FLAGS.batch_size,
              tf_records,
              shuffle_buffer_size=100000000,
              random_rotation=FLAGS.random_rotation, seed=2,
              dist_train=False, make_one_shot=True)

  infer_graph = tf.Graph()
  with infer_graph.as_default():
    tf.import_graph_def(graph, name='')

  input_tensor = dual_net.get_input_tensor(infer_graph)
  output_tensor = dual_net.get_output_tensor(infer_graph)

  config = tf.ConfigProto(
                intra_op_parallelism_threads=FLAGS.num_intra_threads,
                inter_op_parallelism_threads=FLAGS.num_inter_threads)
  data_sess = tf.Session(graph=data_graph, config=config)
  infer_sess = tf.Session(graph=infer_graph, config=config)

  elapsed = 0
  #with tf.contrib.tfprof.ProfileContext('/home/letiank/skx-8180/train_dir/minigo', trace_steps=range(70, 80), dump_steps=[110]):
  for it in range(FLAGS.num_steps):
    features_np = data_sess.run(features)
    start_time = time.time()
    infer_sess.run(output_tensor, feed_dict={input_tensor: features_np})
    elapsed += time.time() - start_time
Example No. 5
def train(working_dir, tf_records, generation_num, **hparams):
    assert generation_num > 0, "Model 0 is random weights"
    estimator = get_estimator(working_dir, **hparams)
    max_steps = generation_num * EXAMPLES_PER_GENERATION // TRAIN_BATCH_SIZE
    input_fn = lambda: preprocessing.get_input_tensors(
        TRAIN_BATCH_SIZE, tf_records)
    update_ratio_hook = UpdateRatioSessionHook(working_dir)
    estimator.train(input_fn, hooks=[update_ratio_hook], max_steps=max_steps)
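A minimal sketch of the SessionRunHook interface that hooks such as UpdateRatioSessionHook build on; this logging hook is a hypothetical stand-in, not the project's implementation.

import tensorflow as tf

class StepLoggingHook(tf.train.SessionRunHook):
    """Hypothetical hook: prints the global step every `every_n` train steps."""

    def __init__(self, every_n=100):
        self._every_n = every_n

    def begin(self):
        # Resolve the global step tensor once the training graph is built.
        self._global_step = tf.train.get_global_step()

    def before_run(self, run_context):
        # Ask the monitored session to also fetch the global step.
        return tf.train.SessionRunArgs(self._global_step)

    def after_run(self, run_context, run_values):
        step = run_values.results
        if step % self._every_n == 0:
            print('global step:', step)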
Example No. 6
def validate(working_dir, tf_records, checkpoint_name=None, **hparams):
    estimator = get_estimator(working_dir, **hparams)
    if checkpoint_name is None:
        checkpoint_name = estimator.latest_checkpoint()
    input_fn = lambda: preprocessing.get_input_tensors(
        TRAIN_BATCH_SIZE, tf_records, shuffle_buffer_size=1000,
        filter_amount=0.05)
    estimator.evaluate(input_fn, steps=1000)
Example No. 7
 def _input_fn():
     return preprocessing.get_input_tensors(
         effective_batch_size,
         tf_records,
         filter_amount=FLAGS.filter_amount,
         shuffle_buffer_size=FLAGS.shuffle_buffer_size,
         random_rotation=True, seed=FLAGS.training_seed,
         dist_train=FLAGS.dist_train)
Example No. 8
def train(working_dir, tf_records, generation_num, **hparams):
    assert generation_num > 0, "Model 0 is random weights"
    estimator = get_estimator(working_dir, **hparams)
    max_steps = generation_num * EXAMPLES_PER_GENERATION // TRAIN_BATCH_SIZE
    input_fn = lambda: preprocessing.get_input_tensors(TRAIN_BATCH_SIZE,
                                                       tf_records)
    update_ratio_hook = UpdateRatioSessionHook(working_dir)
    estimator.train(input_fn, hooks=[update_ratio_hook], max_steps=max_steps)
Example No. 9
 def _input_fn():
     return preprocessing.get_input_tensors(
         FLAGS.train_batch_size,
         FLAGS.input_layout,
         tf_records,
         filter_amount=FLAGS.filter_amount,
         shuffle_examples=FLAGS.shuffle_examples,
         shuffle_buffer_size=FLAGS.shuffle_buffer_size,
         random_rotation=True)
Example No. 10
def validate(estimator_dir, tf_records, checkpoint_path=None, **kwargs):
    model = get_estimator(estimator_dir, **kwargs)
    if checkpoint_path is None:
        checkpoint_path = model.latest_checkpoint()
    model.evaluate(input_fn=lambda: preprocessing.get_input_tensors(
        list_tf_records=tf_records,
        buffer_size=GLOBAL_PARAMETER_STORE.VALIDATION_BUFFER_SIZE),
                   steps=GLOBAL_PARAMETER_STORE.VALIDATION_NUMBER_OF_STEPS,
                   checkpoint_path=checkpoint_path)
Example No. 11
 def extract_data(self, tf_record, filter_amount=1, random_rotation=False):
     pos_tensor, label_tensors = preprocessing.get_input_tensors(
         1, [tf_record],
         num_repeats=1,
         shuffle_records=False,
         shuffle_examples=False,
         filter_amount=filter_amount,
         random_rotation=random_rotation)
     return self.get_data_tensors(pos_tensor, label_tensors)
Example No. 12
    def train(self,
              tf_records,
              init_from=None,
              num_steps=None,
              logging_freq=100,
              verbosity=1):
        logdir = os.path.join(self.logdir,
                              'train') if self.logdir is not None else None

        def should_log(i):
            return logdir is not None and i % logging_freq == 0

        if num_steps is None:
            num_steps = EXAMPLES_PER_GENERATION // TRAIN_BATCH_SIZE
        with self.sess.graph.as_default():
            input_tensors = preprocessing.get_input_tensors(
                TRAIN_BATCH_SIZE, tf_records)
            output_tensors = dual_net(input_tensors,
                                      TRAIN_BATCH_SIZE,
                                      train_mode=True,
                                      **self.hparams)
            train_tensors = train_ops(input_tensors, output_tensors,
                                      **self.hparams)
            weight_summary_op = logging_ops()
            weight_tensors = tf.trainable_variables()
            self.initialize_weights(init_from)
            if logdir is not None:
                training_stats = StatisticsCollector()
                logger = tf.summary.FileWriter(logdir, self.sess.graph)
            for i in tqdm(range(num_steps)):
                if should_log(i):
                    before_weights = self.sess.run(weight_tensors)
                try:
                    tensor_values = self.sess.run(train_tensors)
                except tf.errors.OutOfRangeError:
                    break

                if verbosity > 1 and i % logging_freq == 0:
                    print(tensor_values)
                if logdir is not None:
                    training_stats.report({
                        k: tensor_values[k]
                        for k in ('policy_cost', 'value_cost', 'l2_cost',
                                  'combined_cost')
                    })
                if should_log(i):
                    after_weights = self.sess.run(weight_tensors)
                    weight_update_summaries = compute_update_ratio(
                        weight_tensors, before_weights, after_weights)
                    accuracy_summaries = training_stats.collect()
                    weight_summaries = self.sess.run(weight_summary_op)
                    global_step = tensor_values['global_step']
                    logger.add_summary(weight_update_summaries, global_step)
                    logger.add_summary(accuracy_summaries, global_step)
                    logger.add_summary(weight_summaries, global_step)
            self.save_weights()
Example No. 13
def validate(working_dir, tf_records, checkpoint_name=None, **hparams):
    estimator = get_estimator(working_dir, **hparams)
    if checkpoint_name is None:
        checkpoint_name = estimator.latest_checkpoint()
    input_fn = lambda: preprocessing.get_input_tensors(
        TRAIN_BATCH_SIZE, tf_records, shuffle_buffer_size=1000,
        filter_amount=0.05)
    estimator.evaluate(input_fn, steps=1000)
Example No. 14
 def _input_fn():
     return preprocessing.get_input_tensors(
         effective_batch_size,
         FLAGS.input_layout,
         tf_records,
         filter_amount=FLAGS.filter_amount,
         shuffle_examples=FLAGS.shuffle_examples,
         shuffle_buffer_size=FLAGS.shuffle_buffer_size,
         random_rotation=True,
         seed=FLAGS.training_seed,
         dist_train=FLAGS.dist_train,
         use_bf16=FLAGS.use_bfloat16)
Example No. 15
    def validate(self,
                 tf_records,
                 batch_size=128,
                 init_from=None,
                 num_steps=1000):
        """Compute only the error terms for a set of tf_records, ideally a
        holdout set, and report them to a 'test' subdirectory of the logs.
        """
        cost_tensor_names = [
            'policy_cost', 'value_cost', 'l2_cost', 'combined_cost'
        ]
        if self.logdir is None:
            print("Error, trainer not initialized with a logdir.",
                  file=sys.stderr)
            return

        logdir = os.path.join(self.logdir, 'test')

        with self.sess.graph.as_default():
            input_tensors = preprocessing.get_input_tensors(
                batch_size,
                tf_records,
                shuffle_buffer_size=1000,
                filter_amount=0.05)
            output_tensors = dual_net(input_tensors,
                                      TRAIN_BATCH_SIZE,
                                      train_mode=False,
                                      **self.hparams)
            train_tensors = train_ops(input_tensors, output_tensors,
                                      **self.hparams)

            # just run our cost tensors
            validate_tensors = {k: train_tensors[k] for k in cost_tensor_names}
            self.initialize_weights(init_from)
            training_stats = StatisticsCollector()
            logger = tf.summary.FileWriter(logdir, None)  # No graph needed.

            for i in tqdm(range(num_steps)):
                try:
                    tensor_values = self.sess.run(validate_tensors)
                except tf.errors.OutOfRangeError:
                    break
                training_stats.report(tensor_values)

            accuracy_summaries = training_stats.collect()
            global_step = self.sess.run(train_tensors['global_step'])
            logger.add_summary(accuracy_summaries, global_step)
            logger.flush()
            print(accuracy_summaries)
Example No. 16
 def extract_data(self, tf_record, filter_amount=1):
     tf_example_tensor = preprocessing.get_input_tensors(
         1, [tf_record], num_repeats=1, shuffle_records=False,
         shuffle_examples=False, filter_amount=filter_amount)
     recovered_data = []
     with tf.Session() as sess:
         while True:
             try:
                 values = sess.run(tf_example_tensor)
                 recovered_data.append((
                     values['pos_tensor'],
                     values['pi_tensor'],
                     values['value_tensor']))
             except tf.errors.OutOfRangeError:
                 break
     return recovered_data
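A hedged usage sketch for the helper above; `helper` and the record path are hypothetical and stand in for whatever test fixture owns extract_data.

# Hypothetical usage: unpack one record file and inspect the recovered arrays.
examples = helper.extract_data('holdout-000.tfrecord.zz', filter_amount=1)
for pos, pi, value in examples:
    print(pos.shape, pi.shape, value)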
Example No. 17
def train(estimator_dir, tf_records, model_version, **kwargs):
    """
    Main training function for the PolicyValueNetwork
    Args:
        estimator_dir (str): Path to the estimator directory
        tf_records (list): A list of TFRecords from which we parse the training examples
        model_version (int): The version of the model
    """
    model = get_estimator(estimator_dir, **kwargs)
    logger.info("Training model version: {}".format(model_version))
    max_steps = model_version * GLOBAL_PARAMETER_STORE.EXAMPLES_PER_GENERATION // \
                GLOBAL_PARAMETER_STORE.TRAIN_BATCH_SIZE
    model.train(input_fn=lambda: preprocessing.get_input_tensors(
        list_tf_records=tf_records),
                max_steps=max_steps)
    logger.info("Trained model version: {}".format(model_version))
Example No. 18
 def extract_data(self, tf_record, filter_amount=1):
     pos_tensor, label_tensors = preprocessing.get_input_tensors(
         1, [tf_record], num_repeats=1, shuffle_records=False,
         shuffle_examples=False, filter_amount=filter_amount)
     recovered_data = []
     with tf.Session() as sess:
         while True:
             try:
                 pos_value, label_values = sess.run([pos_tensor, label_tensors])
                 recovered_data.append((
                     pos_value,
                     label_values['pi_tensor'],
                     label_values['value_tensor']))
             except tf.errors.OutOfRangeError:
                 break
     return recovered_data
Example No. 19
def train(working_dir, tf_records, generation_num, **hparams):
    assert generation_num > 0, "Model 0 is random weights"
    hparams = get_default_hyperparams(**hparams)
    model = Model(hparams).cuda()

    loader = preprocessing.get_input_tensors(TRAIN_BATCH_SIZE, tf_records)

    # boundaries = [int(1e6), int(2e6)]
    # values = [1e-2, 1e-3, 1e-4]

    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=1.5e-6,
        momentum=hparams['momentum'],
        weight_decay=hparams['l2_strength'],
    )

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

    now = datetime.datetime.now()
    model_name = now.strftime("%Y-%m-%d %H:%M:%S").split(" ")
    model_name = "-".join(model_name)+".model"
    combined_cost = None
    for epoch in range(100):
        for step, (features, pi, outcome) in enumerate(loader):
            features = features.permute(0, 3, 1, 2)
            features = Variable(features.float())
            pi = Variable(pi.float())
            outcome = Variable(outcome)

            policy_output, value_output, logits = model(features)

            loss = nn.CrossEntropyLoss()
            pi = torch.max(pi, 1)[1]
            policy_cost = torch.mean(loss(logits.float().cuda(), pi.long().cuda()))
            value_cost = torch.mean((value_output.float().cuda() - outcome.float().cuda())**2)

            combined_cost = policy_cost + value_cost
            policy_entropy = -torch.mean(torch.sum(policy_output * torch.log(policy_output), dim=0))

            optimizer.zero_grad()
            combined_cost.backward()
            optimizer.step()

        # Step the learning-rate schedule once per epoch, then checkpoint.
        scheduler.step()
        print("epoch: %s | loss: %s" % (epoch, combined_cost.item()))
        torch.save(model.state_dict(), os.path.join(working_dir, model_name))
    return model_name
Example No. 20
    def train(self, tf_records, init_from=None, logdir=None, num_steps=None,
              logging_freq=100, verbosity=1):
        def should_log(i):
            return logdir is not None and i % logging_freq == 0
        if num_steps is None:
            num_steps = EXAMPLES_PER_GENERATION // TRAIN_BATCH_SIZE
        with self.sess.graph.as_default():
            input_tensors = preprocessing.get_input_tensors(
                TRAIN_BATCH_SIZE, tf_records)
            output_tensors = dual_net(input_tensors, TRAIN_BATCH_SIZE,
                                      train_mode=True, **self.hparams)
            train_tensors = train_ops(
                input_tensors, output_tensors, **self.hparams)
            weight_summary_op = logging_ops()
            weight_tensors = tf.trainable_variables()
            self.initialize_weights(init_from)
            if logdir is not None:
                training_stats = StatisticsCollector()
                logger = tf.summary.FileWriter(logdir, self.sess.graph)
            for i in tqdm(range(num_steps)):
                if should_log(i):
                    before_weights = self.sess.run(weight_tensors)
                try:
                    tensor_values = self.sess.run(train_tensors)
                except tf.errors.OutOfRangeError:
                    break

                if verbosity > 1 and i % logging_freq == 0:
                    print(tensor_values)
                if logdir is not None:
                    training_stats.report(
                        {k: tensor_values[k] for k in (
                            'policy_cost', 'value_cost', 'l2_cost',
                            'combined_cost')})
                if should_log(i):
                    after_weights = self.sess.run(weight_tensors)
                    weight_update_summaries = compute_update_ratio(
                        weight_tensors, before_weights, after_weights)
                    accuracy_summaries = training_stats.collect()
                    weight_summaries = self.sess.run(weight_summary_op)
                    global_step = tensor_values['global_step']
                    logger.add_summary(weight_update_summaries, global_step)
                    logger.add_summary(accuracy_summaries, global_step)
                    logger.add_summary(weight_summaries, global_step)
            self.save_weights()
Example No. 21
 def extract_data(self, tf_record, filter_amount=1, random_rotation=False):
     pos_tensor, label_tensors = preprocessing.get_input_tensors(
         1, [tf_record], num_repeats=1, shuffle_records=False,
         shuffle_examples=False, filter_amount=filter_amount,
         random_rotation=random_rotation)
     recovered_data = []
     with tf.Session() as sess:
         while True:
             try:
                 pos_value, label_values = sess.run(
                     [pos_tensor, label_tensors])
                 recovered_data.append((
                     pos_value,
                     label_values['pi_tensor'],
                     label_values['value_tensor']))
             except tf.errors.OutOfRangeError:
                 break
     return recovered_data
Example No. 22
 def input_fn():
     return preprocessing.get_input_tensors(params, params.batch_size,
                                            tf_records)
Example No. 23
 def input_fn(): return preprocessing.get_input_tensors(
     TRAIN_BATCH_SIZE, tf_records, shuffle_buffer_size=1000,
     filter_amount=0.05)
 estimator.evaluate(input_fn, steps=1000)
Example No. 24
 def input_fn(): return preprocessing.get_input_tensors(
     TRAIN_BATCH_SIZE, tf_records)
 update_ratio_hook = UpdateRatioSessionHook(working_dir)
Example No. 25
 def input_fn():
     return preprocessing.get_input_tensors(TRAIN_BATCH_SIZE,
                                            tf_records,
                                            shuffle_buffer_size=1000,
                                            filter_amount=0.05)
Example No. 26
 def input_fn():
     return preprocessing.get_input_tensors(TRAIN_BATCH_SIZE, tf_records)
Example No. 27
 def input_fn(): return preprocessing.get_input_tensors(
     TRAIN_BATCH_SIZE, tf_records)
 update_ratio_hook = UpdateRatioSessionHook(working_dir)
Example No. 28
 def input_fn():
   return preprocessing.get_input_tensors(
       params, params.batch_size, tf_records)
Example No. 29
 def input_fn(): return preprocessing.get_input_tensors(
     TRAIN_BATCH_SIZE, tf_records, shuffle_buffer_size=1000,
     filter_amount=0.05)
 estimator.evaluate(input_fn, steps=1000)
Example No. 30
 def input_fn():
     return preprocessing.get_input_tensors(FLAGS.train_batch_size,
                                            tf_records,
                                            filter_amount=0.05,
                                            shuffle_buffer_size=20000)
Example No. 31
 def input_fn():
     return preprocessing.get_input_tensors(TRAIN_BATCH_SIZE,
                                            tf_records,
                                            filter_amount=1.0)
Example No. 32
 def _input_fn():
     return preprocessing.get_input_tensors(
         FLAGS.train_batch_size, tf_records, filter_amount=0.05,
         shuffle_examples=False)
Example No. 33
 def input_fn():
     return preprocessing.get_input_tensors(params,
                                            params.batch_size,
                                            tf_records,
                                            filter_amount=0.05)
Example No. 34
 def input_fn():
   return preprocessing.get_input_tensors(
       params, params.batch_size, tf_records, filter_amount=0.05)
Example No. 35
def main(unused_argv):
    in_path = FLAGS.in_path
    out_path = FLAGS.out_path

    assert tf.gfile.Exists(in_path)
    # TODO(amj): Why does ensure_dir_exists skip gs paths?
    #tf.gfile.MakeDirs(os.path.dirname(out_path))
    #assert tf.gfile.Exists(os.path.dirname(out_path))

    policy_err = []
    value_err = []

    print()
    with tf.python_io.TFRecordWriter(out_path, OPTS) as writer:
        ds_iter = preprocessing.get_input_tensors(FLAGS.batch_size, [in_path],
                                                  shuffle_examples=False,
                                                  random_rotation=False,
                                                  filter_amount=1.0)

        with tf.Session() as sess:
            features, labels = ds_iter
            p_in = labels['pi_tensor']
            v_in = labels['value_tensor']

            p_out, v_out, logits = dual_net.model_inference_fn(
                features, False, FLAGS.flag_values_dict())
            tf.train.Saver().restore(sess, FLAGS.model)

            # TODO(seth): Add policy entropy.

            p_err = tf.nn.softmax_cross_entropy_with_logits_v2(
                logits=logits, labels=tf.stop_gradient(p_in))
            v_err = tf.square(v_out - v_in)

            for _ in tqdm(itertools.count(1)):
                try:
                    # Undo cast in batch_parse_tf_example.
                    x_in = tf.cast(features, tf.int8)

                    x, pi, val, pi_err, val_err = sess.run(
                        [x_in, p_out, v_out, p_err, v_err])

                    for i, (x_i, pi_i, val_i) in enumerate(zip(x, pi, val)):
                        # NOTE: The teacher's policy has much higher entropy
                        # than the self-play policy labels, which are mostly 0;
                        # expect the resulting file to be 3-5x larger.

                        r = preprocessing.make_tf_example(x_i, pi_i, val_i)
                        serialized = r.SerializeToString()
                        writer.write(serialized)

                    policy_err.extend(pi_err)
                    value_err.extend(val_err)

                except tf.errors.OutOfRangeError:
                    print()
                    print("Breaking OutOfRangeError")
                    break

    print("Counts", len(policy_err), len(value_err))
    test()

    plt.subplot(121)
    n, bins, patches = plt.hist(policy_err, 40)
    plt.title('Policy Error histogram')

    plt.subplot(122)
    n, bins, patches = plt.hist(value_err, 40)
    plt.title('Value Error')

    plt.show()
Example No. 36
def init_train(rank, tcomm, model_dir):
    """Train on examples and export the updated model weights."""
    # init hvd
    logging.info('hvd init at rank %d', rank)
    hvd.init(tcomm)

    #
    FLAGS.export_path = model_dir

    if rank == 0:
        logging.info('[ Train flags ] freeze              = %d', FLAGS.freeze)
        logging.info('[ Train flags ] window_size         = %d',
                     FLAGS.window_size)
        logging.info('[ Train flags ] use_trt             = %d', FLAGS.use_trt)
        logging.info('[ Train flags ] trt_max_batch_size  = %d',
                     FLAGS.trt_max_batch_size)
        logging.info('[ Train flags ] trt_precision       = %s',
                     FLAGS.trt_precision)
        logging.info('[ Train flags ] shuffle_buffer_size = %d',
                     FLAGS.shuffle_buffer_size)
        logging.info('[ Train flags ] shuffle_examples    = %d',
                     FLAGS.shuffle_examples)
        logging.info('[ Train flags ] export path         = %s',
                     FLAGS.export_path)
        logging.info('[ Train flags ] num_gpus_train      = %d', hvd.size())

        # From dual_net.py
        logging.info('[ d_net flags ] work_dir            = %s',
                     FLAGS.work_dir)
        logging.info('[ d_net flags ] train_batch_size    = %d',
                     FLAGS.train_batch_size)
        logging.info('[ d_net flags ] lr_rates            = %s',
                     FLAGS.lr_rates)
        logging.info('[ d_net flags ] lr_boundaries       = %s',
                     FLAGS.lr_boundaries)
        logging.info('[ d_net flags ] l2_strength         = %s',
                     FLAGS.l2_strength)
        logging.info('[ d_net flags ] conv_width          = %d',
                     FLAGS.conv_width)
        logging.info('[ d_net flags ] fc_width            = %d',
                     FLAGS.fc_width)
        logging.info('[ d_net flags ] trunk_layers        = %d',
                     FLAGS.trunk_layers)
        logging.info('[ d_net flags ] value_cost_weight   = %s',
                     FLAGS.value_cost_weight)
        logging.info('[ d_net flags ] summary_steps       = %d',
                     FLAGS.summary_steps)
        logging.info('[ d_net flags ] bool_features       = %d',
                     FLAGS.bool_features)
        logging.info('[ d_net flags ] input_features      = %s',
                     FLAGS.input_features)
        logging.info('[ d_net flags ] input_layout        = %s',
                     FLAGS.input_layout)

    # Training
    tf_records_ph = tf.placeholder(tf.string)
    data_iter = preprocessing.get_input_tensors(
        FLAGS.train_batch_size // hvd.size(),
        FLAGS.input_layout,
        tf_records_ph,
        filter_amount=FLAGS.filter_amount,
        shuffle_examples=FLAGS.shuffle_examples,
        shuffle_buffer_size=FLAGS.shuffle_buffer_size,
        random_rotation=True)

    features, labels = data_iter.get_next()
    train_op = dual_net.model_fn(features, labels, tf.estimator.ModeKeys.TRAIN,
                                 FLAGS.flag_values_dict(), True)
    sess = dual_net._get_session()

    # restore all from a checkpoint
    tf.train.Saver().restore(sess,
                             os.path.join(FLAGS.work_dir, 'model.ckpt-5672'))

    return TrainState(sess, train_op, data_iter, tf_records_ph)
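A hedged sketch of driving the returned TrainState for one pass over a window of selfplay records; it assumes TrainState exposes its constructor arguments as attributes, that the iterator is initializable, and the arguments and record names below are hypothetical.

state = init_train(rank=0, tcomm=None, model_dir='/tmp/minigo-export')
records = ['train-000.tfrecord.zz', 'train-001.tfrecord.zz']

# Re-initialize the iterator with the record window via the string placeholder,
# then run the training op until the dataset is exhausted.
state.sess.run(state.data_iter.initializer,
               feed_dict={state.tf_records_ph: records})
while True:
    try:
        state.sess.run(state.train_op)
    except tf.errors.OutOfRangeError:
        break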