示例#1
0
def main():
    """Fine-tune a GPT model using the settings parsed from ``parser``.

    Builds the training graph (loss, an optional validation loss and a
    sampling op), restores a checkpoint, then loops forever: periodically
    saving, generating samples and computing validation loss, and running
    one optimizer step per iteration.  The loop stops when
    ``tflex.should_quit()`` is true, or on Ctrl-C (optionally saving first)
    unless ``args.debug_on_ctrlc`` is set, in which case pdb is entered and
    training continues afterwards.
    """
    args = parser.parse_args()
    enc = encoder.get_encoder(args.model_name)
    hparams = model.default_hparams()
    hparams.res_dropout = args.dropout
    hparams.attn_dropout = args.dropout
    # Passed to sample.sample_sequence below (presumably as a logit-mask
    # value).  -1e10 overflows 16-bit floats, so those dtypes use -65500,
    # which is just inside the float16 range.
    epsilon = -1e10
    if args.dtype == 'float32':
        hparams.dtype = tf.float32
    elif args.dtype == 'float16':
        hparams.dtype = tf.float16
        epsilon = -65500
    elif args.dtype == 'bfloat16':
        hparams.dtype = tf.bfloat16
        epsilon = -65500
    else:
        print('Unknown dtype', args.dtype)
    # NOTE(review): args.float16 selects bfloat16, not float16.
    if args.float16:
        hparams.dtype = tf.bfloat16
        epsilon = -65500

    with open(os.path.join('models', args.model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))
    # Negative command-line values keep the value from hparams.json.
    if args.n_ctx >= 0:
        hparams.n_ctx = args.n_ctx
    if args.n_embd >= 0:
        hparams.n_embd = args.n_embd
    if args.n_head >= 0:
        hparams.n_head = args.n_head
    if args.n_layer >= 0:
        hparams.n_layer = args.n_layer

    if args.sample_length < 0:
        args.sample_length = hparams.n_ctx - 1
    if args.sample_length > hparams.n_ctx:
        raise ValueError(
            "Can't get samples longer than window size: %s" % hparams.n_ctx)
    if args.sample_ctx < 0:
        args.sample_ctx = hparams.n_ctx

    if args.model_name == '345M':
        args.memory_saving_gradients = True
        if args.optimizer == 'adam':
            args.only_train_transformer_layers = True

    config = tf.ConfigProto()
    if args.allow_growth:
        config.gpu_options.allow_growth = True
    if args.disable_layout_optimizer:
        config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF
    with tflex.Session(config=config, init_tpu=args.init_tpu) as sess:
        context = tf.placeholder(tf.int32, [args.batch_size, None])
        context_in = randomize(context, hparams, args.noise)
        output = model.model(hparams=hparams, X=context_in)
        # Shifted targets: logits at position t are scored against token t+1.
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=context[:, 1:], logits=output['logits'][:, :-1]))

        if args.val_every > 0:
            val_context = tf.placeholder(tf.int32, [args.val_batch_size, None])
            val_output = model.model(hparams=hparams, X=val_context)
            val_loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=val_context[:, 1:], logits=val_output['logits'][:, :-1]))
            val_loss_summary = tf.summary.scalar('val_loss', val_loss)

        tf_sample = sample.sample_sequence(
            hparams=hparams,
            length=args.sample_length,
            context=context,
            batch_size=args.batch_size,
            temperature=1.0,
            top_k=args.top_k,
            top_p=args.top_p,
            epsilon=epsilon)

        all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
        train_vars = [v for v in all_vars if '/h' in v.name] if args.only_train_transformer_layers else all_vars

        parameter_count = sum(np.prod(v.shape.as_list()) for v in train_vars)
        print("This model is using %d parameters (%.2fM)" % (parameter_count, parameter_count/(1024.0*1024.0)))

        with tf.variable_scope(tf.get_variable_scope().name, reuse=tf.AUTO_REUSE):
            global_step = tflex.get_variable('global_step') or tf.get_variable('global_step', shape=(), dtype=tf.int32, trainable=False)
            current_step = args.learning_rate_initial_step
            global_step.load(current_step, session=sess)
            if args.learning_rate_cos:
                lr = tflex_sgdr.sgdr_decay_with_warmup(args.learning_rate, global_step,
                    warmup_steps=args.learning_rate_warmup, initial_period_steps=args.learning_rate_period, learning_rate_min=args.learning_rate_min)
            else:
                lr = tflex.get_variable('learn_rate') or tf.get_variable('learn_rate', shape=(), dtype=tf.float32, trainable=False)
                lr.load(args.learning_rate, session=sess)

        def update_lr(rate=None, step=None):
            """Refresh the learning rate and return its current value.

            When cosine decay is enabled ``lr`` comes from
            ``tflex_sgdr.sgdr_decay_with_warmup`` and nothing is loaded.
            Otherwise ``rate`` (default ``args.learning_rate``; if callable it
            is called with the step) is loaded into the ``lr`` variable.
            """
            if not args.learning_rate_cos:
                if step is None:
                    step = global_step.eval(session=sess)
                if rate is None:
                    rate = args.learning_rate
                if callable(rate):
                    rate = rate(step)
                lr.load(rate, session=sess)
            return lr.eval(session=sess)

        @tflex.register_command
        def set_learning_rate():
            """Prompt for a new learning rate; it is applied on the next update_lr()."""
            print("Current learn rate: %0.8f" % update_lr())
            print("New learn rate?")
            rate = input('')
            if not rate:
                print("Empty input; not changing anything.")
                return
            try:
                rate = float(rate)
            except ValueError:
                print("Invalid input; must be a float")
                return
            print("Setting learn rate to %0.8f" % rate)
            args.learning_rate = rate

        if args.optimizer == 'adam':
            opt = tf.train.AdamOptimizer(learning_rate=lr)
        elif args.optimizer == 'sgd':
            opt = tf.train.GradientDescentOptimizer(learning_rate=lr)
        elif args.optimizer == 'ada':
            import tensor2tensor.utils.optimize
            from tensor2tensor.utils import hparam
            # Presumably imported for its side effect of registering the
            # 'afx_mimic_adam' hparams set used below.
            import tensor2tensor.models.research
            from tensor2tensor.utils import registry
            ada_hparams = registry.hparams('afx_mimic_adam')
            ada_hparams.optimizer_adafactor_beta1 = 0.0
            ada_hparams.optimizer_adafactor_factored = True
            opt = tensor2tensor.utils.optimize.adafactor(learning_rate=lr, hparams=ada_hparams)
        else:
            # exit() takes a single argument, so format the message first.
            exit('Bad optimizer: %s' % args.optimizer)

        #if tpu_addr:
        #    # https://pulsejet.github.io/blog/posts/tpu-without-estimator/
        #    from tensorflow.contrib.tpu.python.tpu import tpu_function
        #    tpu_function.get_tpu_context().set_number_of_shards(8)
        #    opt = tf.contrib.tpu.CrossShardOptimizer(opt)

        if args.accumulate_gradients > 1:
            if args.memory_saving_gradients:
                exit("Memory saving gradients are not implemented for gradient accumulation yet.")
            opt = AccumulatingOptimizer(
                opt=opt,
                var_list=train_vars)
            opt_reset = opt.reset()
            opt_compute = opt.compute_gradients(loss)
            opt_apply = opt.apply_gradients()
            # opt_apply is fetched as the loss value in the training loop, so
            # it is also what gets logged as 'loss' here.
            summary_loss = tf.summary.scalar('loss', opt_apply)
        else:
            if args.memory_saving_gradients:
                opt_grads = memory_saving_gradients.gradients(loss, train_vars)
            else:
                opt_grads = tf.gradients(loss, train_vars)
            opt_grads = list(zip(opt_grads, train_vars))
            opt_apply = opt.apply_gradients(opt_grads)
            summary_loss = tf.summary.scalar('loss', loss)

        summary_lr = tf.summary.scalar('learning_rate', lr)
        summaries = tf.summary.merge([summary_lr, summary_loss])

        summary_log = tf.summary.FileWriter(
            os.path.join(CHECKPOINT_DIR, args.run_name))

        if args.save_graph:
            summary_log.add_graph(tf.get_default_graph())

        saver = tflex.Saver(
            var_list=all_vars,
            max_to_keep=args.max_to_keep,
            keep_checkpoint_every_n_hours=100000,
            reshape=args.truncate_weights)
        sess.run(tf.global_variables_initializer())

        if args.restore_from == 'latest':
            ckpt = tflex.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, args.run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tflex.latest_checkpoint(
                    os.path.join('models', args.model_name))
        elif args.restore_from == 'fresh':
            ckpt = tflex.latest_checkpoint(
                os.path.join('models', args.model_name))
        else:
            ckpt = tflex.latest_checkpoint(args.restore_from)
        print('Loading snapshot %s...' % ckpt)
        t0 = time.time()
        if not args.fresh_model:
            saver.restore(sess, ckpt)
        t1 = time.time()
        print('Loaded in %f seconds' % (t1 - t0))

        def make_sampler(dataset, enc, seed, combine):
            """Return a Sampler for a directory/.npz dataset, else a TextSampler."""
            if os.path.isdir(dataset) or dataset.endswith('.npz'):
                chunks = load_dataset(enc, dataset, combine)
                data_sampler = Sampler(chunks, seed=seed)
                print('dataset has', data_sampler.total_size, 'tokens', len(chunks), 'chunks')
            else:
                data_sampler = TextSampler(dataset, enc, seed=seed)
            return data_sampler

        print('Loading dataset...')
        seed = None if args.seed < 0 else args.seed
        data_sampler = make_sampler(dataset=args.dataset, enc=enc, seed=seed, combine=args.combine)
        if args.val_every > 0:
            # Sample from validation set once with fixed seed to make
            # it deterministic during training as well as across runs.
            val_dataset = args.val_dataset if args.val_dataset else args.dataset
            val_data_sampler = make_sampler(dataset=val_dataset, enc=enc, seed=1, combine=args.combine)
            val_batches = [[val_data_sampler.sample(hparams.n_ctx) for _ in range(args.val_batch_size)]
                           for _ in range(args.val_batch_count)]

        print('Training...')
        counter = 1
        counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter')
        if os.path.exists(counter_path):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(counter_path, 'r') as fp:
                counter = int(fp.read()) + 1

        @tflex.register_command
        def get_tarfile_name(checkpoint_folder):
            """Converts a folder path into a filename for a .tar archive"""
            tarfile_name = checkpoint_folder.replace(os.path.sep, '_') + '.tar'

            return tarfile_name

        def copy_checkpoint_to_gdrive(run_name='run1', copy_folder=False):
            """Copies the checkpoint folder to a mounted Google Drive."""
            #is_mounted()

            checkpoint_folder = os.path.join('checkpoint', run_name)

            if copy_folder:
                shutil.copytree(checkpoint_folder, "/content/drive/My Drive/" + checkpoint_folder)
            else:
                file_path = get_tarfile_name(checkpoint_folder)

                # Reference: https://stackoverflow.com/a/17081026
                with tarfile.open(file_path, 'w') as tar:
                    tar.add(checkpoint_folder)

                shutil.copyfile(file_path, "/content/drive/My Drive/" + file_path)

        @tflex.register_command
        def save():
            """Save a checkpoint tagged with ``counter`` and record it in counter_path."""
            maketree(os.path.join(CHECKPOINT_DIR, args.run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, args.run_name,
                             'model-{}').format(counter))
            t0 = time.time()
            saver.save(
                sess,
                os.path.join(CHECKPOINT_DIR, args.run_name, 'model'),
                global_step=counter)
            t1 = time.time()
            print('Saved in %f seconds' % (t1 - t0))
            with open(counter_path, 'w') as fp:
                fp.write(str(counter) + '\n')
            #copy_checkpoint_to_gdrive()

        @tflex.register_command
        def generate_samples():
            """Print args.sample_num samples and write them to SAMPLE_DIR/<run_name>."""
            print('Generating samples...')
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < args.sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: args.batch_size * [context_tokens]})
                for i in range(min(args.sample_num - index, args.batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    print(text)
                    all_text.append(text)
                    index += 1
            maketree(os.path.join(SAMPLE_DIR, args.run_name))
            with open(
                    os.path.join(SAMPLE_DIR, args.run_name,
                                 'samples-{}').format(counter), 'w') as fp:
                fp.write('\n'.join(all_text))

        @tflex.register_command
        def validation():
            """Average val_loss over the fixed validation batches and log it."""
            if args.val_every <= 0:
                return
            print('Calculating validation loss...')
            losses = []
            for batch in tqdm.tqdm(val_batches):
                losses.append(sess.run(val_loss, feed_dict={val_context: batch}))
            v_val_loss = np.mean(losses)
            v_summary = sess.run(val_loss_summary, feed_dict={val_loss: v_val_loss})
            summary_log.add_summary(v_summary, counter)
            summary_log.flush()
            print(
                '{stamp} [{counter} | {time:2.4f}] validation loss = {loss:2.4f}'
                .format(
                    stamp=timestamp(),
                    counter=counter,
                    time=time.time() - start_time,
                    loss=v_val_loss))

        start_time = time.time()

        def elapsed():
            return time.time() - start_time

        def say(msg):
            print('{stamp} [{counter} | {time:2.4f}] {msg}'.format(counter=counter, time=elapsed(), msg=msg, stamp=timestamp()))

        def sample_batch():
            """Return args.batch_size items, each data_sampler.sample(args.sample_ctx)."""
            return [data_sampler.sample(args.sample_ctx) for _ in range(args.batch_size)]

        def print_variables(title, variables):
            """Print name/shape/size for each variable and the total size."""
            print(title)
            print('name/shape/parameter_count')
            param_count = 0
            for x in variables:
                shape = x.shape.as_list()
                count = np.prod(shape)
                print(x.name, shape, count)
                param_count += count
            print('Total parameters:', param_count)

        prev_time = time.time()
        # Exponentially-decayed (sum of losses, sum of weights); their ratio is
        # the running average that gets printed.
        avg_loss = (0.0, 0.0)

        if args.debug_before_training:
            import pdb
            pdb.set_trace()

        last_saved_time = elapsed()
        while True:
            try:
                elapsed_now = elapsed()
                if args.save_time > 0 and (((elapsed_now - last_saved_time) / 60.0) >= args.save_time):
                    save()
                    last_saved_time = elapsed_now
                elif args.save_every > 0 and (counter % args.save_every == 0):
                    save()
                # Guard like save_every/val_every so 0 disables sampling
                # instead of raising ZeroDivisionError.
                if args.sample_every > 0 and counter % args.sample_every == 0:
                    generate_samples()
                if args.val_every > 0 and (counter % args.val_every == 0 or counter == 1):
                    validation()

                v_rate = update_lr()

                if args.accumulate_gradients > 1:
                    sess.run(opt_reset)
                    for _ in range(args.accumulate_gradients):
                        batch = sample_batch()
                        say('Running opt_compute...')
                        sess.run(opt_compute, feed_dict={context: batch})
                    say('Running opt_apply...')
                    (v_loss, v_summary) = sess.run((opt_apply, summaries))
                else:
                    batch = sample_batch()
                    say('Running opt_apply...')
                    (_, v_loss, v_summary) = sess.run(
                        (opt_apply, loss, summaries),
                        feed_dict={context: batch})

                if args.float16:
                    # Convert on the host: calling tf.to_float(...).eval() here
                    # would add new ops to the graph on every training step.
                    v_loss = float(v_loss)

                summary_log.add_summary(v_summary, counter)
                summary_log.flush()

                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                now = time.time()
                print('{stamp} [{counter} | {time:2.4f} | {delta:2.2f}s | {ops:2.6f}tokens/s] loss={loss:2.4f} avg={avg:2.4f} rate={rate:0.7f} step={step}'
                    .format(
                        stamp=timestamp(),
                        counter=counter,
                        time=now - start_time,
                        delta=now - prev_time,
                        ops=args.sample_ctx * args.batch_size / (now - prev_time),
                        rate=v_rate,
                        loss=v_loss,
                        avg=avg_loss[0] / avg_loss[1],
                        step=current_step,
                        ))

                counter += 1
                current_step += 1
                global_step.load(current_step, session=sess)

                tflex.check_commands_with_args(
                    session=sess,
                    stamp=timestamp(),
                    counter=counter,
                    time=now - start_time,
                    delta=now - prev_time,
                    ops=args.batch_size / (now - prev_time),
                    rate=v_rate,
                    loss=v_loss,
                    avg=avg_loss[0] / avg_loss[1],
                    avg_loss=avg_loss,
                    step=current_step,
                    train_vars=train_vars,
                    all_vars=all_vars,
                    args=args,
                    data_sampler=data_sampler,
                    ckpt=ckpt,
                    saver=saver,
                    )
                if tflex.should_quit():
                    break

                prev_time = now
                if args.debug_print_all_vars:
                    print_variables('all variables:', tf.global_variables())
                    args.debug_print_all_vars = False

                if args.debug_print_trainable_vars:
                    print_variables('trainable variables:', tf.trainable_variables())
                    args.debug_print_trainable_vars = False
            except KeyboardInterrupt:
                print('interrupted')
                if args.save_on_ctrlc:
                    save()
                if args.debug_on_ctrlc:
                    import pdb
                    pdb.set_trace()
                else:
                    break
示例#2
0
def interact_model(model_name='117M',
                   restore_from=None,
                   seed=None,
                   nsamples=1,
                   step=1,
                   length=64,
                   prompt="\n",
                   clear=None,
                   maxlen=-1,
                   temperature=1,
                   top_k=0,
                   top_p=0,
                   penalize=0):
    """
    Interactively run the model, streaming generated text to stdout until
    tflex.should_quit() becomes true.
    :model_name=117M : String, which model to use
    :restore_from=None : Checkpoint directory to restore from; defaults to
     models/<model_name>
    :seed=None : Integer seed for random number generators, fix seed to reproduce
     results
    :nsamples=1 : Must be a multiple of the batch size (1). NOTE(review): it is
     not otherwise used; generation continues until quit.
    :step=1 : Number of tokens to generate at a time; must be smaller than length
    :length=64 : Window size; use 1024 for maximum size per sample
    :prompt="\\n" : Prompt to start with, or the path of a file containing it.
     Literal "\\n" and "\\t" sequences become newline/tab. An empty prompt ("")
     starts from the <|endoftext|> token.
    :clear=None : If this string is encountered, clear the context window.
    :maxlen=-1 : if this many tokens are generated without
     encountering --clear, then print it and clear the context window.
    :temperature=1 : Float value controlling randomness in boltzmann
     distribution. Lower temperature results in less random completions. As the
     temperature approaches zero, the model will become deterministic and
     repetitive. Higher temperature results in more random completions.
    :top_k=0 : Integer value controlling diversity. 1 means only 1 word is
     considered for each step (token), resulting in deterministic completions,
     while 40 means 40 words are considered at each step. 0 (default) is a
     special setting meaning no restrictions. 40 generally is a good value.
    :top_p=0.0 : Float value controlling diversity. Implements nucleus sampling,
     overriding top_k if set to a value > 0. A good setting is 0.9.
    :penalize=0.0 : Float value controlling "used" penalty. Implements repetition
     reduction (similar to CTRL) if set to a value > 0. A decent setting might be 0.85
     with temperature 0.3 and top_k 40.
    Raises ValueError if length exceeds n_ctx or step is not smaller than length.
    """
    batch_size = 1
    assert nsamples % batch_size == 0

    enc = encoder.get_encoder(model_name)
    hparams = model.default_hparams()
    with open(os.path.join('models', model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if length > hparams.n_ctx:
        raise ValueError("Length can't be larger than n_ctx: %s" %
                         hparams.n_ctx)
    if step > length:
        raise ValueError("Can't get samples longer than length: %s" % length)
    if step == length:
        # The context keeps at most length - step - 1 tokens; with
        # step == length that bound is negative and trimming never ends.
        raise ValueError("step must be smaller than length: %s" % length)

    # Longest context that still leaves room for `step` new tokens.
    max_context = length - step - 1

    def trim_context(tokens):
        """Drop the oldest tokens so that at most max_context remain."""
        excess = len(tokens) - max_context
        return tokens[excess:] if excess > 0 else tokens

    def reset_context():
        """Forget generated text; the prompt tokens are re-appended next."""
        tflex.context_text = ""
        tflex.context_count = 0
        tflex.context_tokens = []
        tflex.first = True
        tflex.tokens = tflex.prompt_tokens[:]

    with tflex.Session(graph=tf.Graph()) as sess:
        context = tf.placeholder(tf.int32, [batch_size, None])
        np.random.seed(seed)
        tf.set_random_seed(seed)
        output = sample.sample_sequence(hparams=hparams,
                                        length=step,
                                        context=context,
                                        batch_size=batch_size,
                                        temperature=temperature,
                                        top_k=top_k,
                                        top_p=top_p,
                                        penalize=penalize)

        saver = tflex.Saver(reshape=True)
        if restore_from is None:
            restore_from = os.path.join('models', model_name)
        ckpt = tflex.latest_checkpoint(restore_from)
        saver.restore(sess, ckpt)

        while True:
            tflex.check_commands()
            if tflex.should_quit():
                break
            try:
                # `prompt` may be the path of a file holding the prompt text.
                with open(prompt) as f:
                    tflex.raw_text = f.read()
            except (OSError, ValueError):
                # Not a readable text file: use the prompt string itself.
                tflex.raw_text = prompt
            else:
                if tflex.raw_text.endswith('\n'):
                    tflex.raw_text = tflex.raw_text[:-1]
                if tflex.raw_text.endswith('\r'):
                    tflex.raw_text = tflex.raw_text[:-1]
            tflex.raw_text = tflex.raw_text.replace('\\n', '\n')
            tflex.raw_text = tflex.raw_text.replace('\\t', '\t')
            # An empty prompt starts from the <|endoftext|> token.
            tflex.context_tokens = enc.encode(
                tflex.raw_text) if len(tflex.raw_text) > 0 else [50256]
            tflex.context_tokens = trim_context(tflex.context_tokens)
            tflex.prompt_tokens = tflex.context_tokens[:]
            tflex.first = True
            tflex.backlog = []
            tflex.backlog_count = 0
            tflex.context_text = ""
            tflex.context_count = 0
            # Leave the generation loop on quit so the outer loop can exit.
            while not tflex.should_quit():
                for tokens in generate_result(
                        context_tokens=tflex.context_tokens,
                        enc=enc,
                        output=output,
                        context=context,
                        nsamples=1,
                        batch_size=batch_size,
                        sess=sess):
                    tflex.tokens = tokens
                    if tflex.first:
                        sys.stdout.write(enc.decode(tflex.context_tokens))
                        sys.stdout.flush()
                        tflex.first = False
                    tflex.backlog.extend(tflex.tokens)
                    tflex.backlog_count += 1
                    # Flush when the newest token decodes to ASCII (presumably
                    # so multi-token characters print whole) or after 16 steps.
                    if is_ascii(enc.decode([tflex.backlog[-1]
                                            ])) or tflex.backlog_count > 16:
                        text = enc.decode(tflex.backlog)
                        result = text
                        if clear is not None:
                            # Print only up to the first `clear` marker.
                            result = text.split(clear)[0]
                        sys.stdout.write(result)
                        sys.stdout.flush()
                        tflex.context_text += text
                        tflex.context_count += len(tflex.backlog)

                        tflex.reset_context = reset_context
                        if ((maxlen > 0 and tflex.context_count > maxlen)
                                or (clear is not None and clear in tflex.context_text)):
                            tflex.reset_context()
                        tflex.backlog = []
                        tflex.backlog_count = 0
                    tflex.check_commands()
                    tflex.context_tokens.extend(tflex.tokens)
                    tflex.context_tokens = trim_context(tflex.context_tokens)
示例#3
0
    def train(self, num_threads=1, output_summaries=True):
        """Run the train steps on the TPU device.

        Each iteration runs ``self.infeed_thread_fn()``, fetches ``self.loss``,
        starts a checkpoint save on a background thread (skipped if the save
        previously started in that slot is still running), and optionally
        writes throughput summaries.  The loop ends after ``self.train_steps``
        more steps (never, if that is None) or when ``tflex.should_quit()``.

        Args:
          num_threads: number of outstanding checkpointing threads.
          output_summaries: write scalar summaries to ``<model_dir>/eval``
            (only when ``self.model_dir`` is set).
        """
        if output_summaries and self.model_dir is not None:
            output_dir = os.path.join(self.model_dir, "eval")
            tf.gfile.MakeDirs(output_dir)
            # Summary writer writes out eval metrics.
            summary_writer = tf.compat.v1.summary.FileWriter(output_dir)
        else:
            summary_writer = None

        def checkpoint_thread_fn(saver, sess, force=False):
            """Save a checkpoint for self.cur_step unless saving is disabled."""
            step = self.cur_step
            if self.model_dir is None:
                tf.logging.info(
                    'step %d: model_dir is None; not saving checkpoint %s-%d',
                    step, 'model.ckpt', step)
                return
            if not force:
                if train_flags.options().get('no_save'):
                    tf.logging.info(
                        'step %d: options.no_save is set; not saving checkpoint %s-%d',
                        step, 'model.ckpt', step)
                    return
            path = self.model_dir + "/model.ckpt"
            tf.logging.info('step %d: Saving checkpoint %s-%d...', step, path,
                            step)
            now = time.time()
            saver.save(sess, path, write_meta_graph=False, global_step=step)
            elapsed = time.time() - now
            tf.logging.info('step %d: Saved checkpoint %s-%d in %.2fs', step,
                            path, step, elapsed)

        @tflex.register_command
        def save():
            checkpoint_thread_fn(self.saver, self.sess, force=True)

        def start_checkpoint_thread():
            """Start and return a daemon thread running checkpoint_thread_fn."""
            thread = threading.Thread(
                target=checkpoint_thread_fn,
                args=(self.saver, self.sess),
                daemon=True)
            thread.start()
            return thread

        thread_id = 0
        # One slot per allowed concurrent checkpoint save, used round-robin.
        checkpoint_threads = [None] * num_threads
        need_final_checkpoint = False
        tf.logging.info("TrainRunner: step %d", self.cur_step)
        end_step = None if self.train_steps is None else (self.cur_step +
                                                          self.train_steps)
        while end_step is None or self.cur_step < end_step:
            tflex.check_commands()
            if tflex.should_quit():
                tf.logging.info("TrainRunner: quitting")
                break
            start = time.time()
            tf.logging.info("TrainRunner: start next %d steps",
                            self.iterations)
            self.cur_step += self.iterations
            self.infeed_thread_fn()
            loss = tflex.run(self.sess, [self.loss])
            previous = checkpoint_threads[thread_id]
            if previous is not None and previous.is_alive():
                tf.logging.info(
                    "TrainRunner: checkpoint thread still active; skipping")
                need_final_checkpoint = True
            else:
                tf.logging.info("TrainRunner: starting checkpoint thread...")
                if previous is not None:
                    previous.join()
                checkpoint_threads[thread_id] = start_checkpoint_thread()
                need_final_checkpoint = False
            thread_id = (thread_id + 1) % num_threads
            end = time.time()
            tf.logging.info("TrainRunner: fetching global_step...")
            gs = tflex.run(self.sess, self.global_step)
            step_sec = end - start
            gs_sec = self.iterations / step_sec
            ex_sec = self.iterations * self.train_batch_size / step_sec
            # Write out summary to tensorboard.
            if output_summaries:
                tf.logging.info("TrainRunner: writing summaries...")
                with tf.Graph().as_default():
                    eval_results = {
                        'loss': loss,
                        'iterations_per_step': self.iterations,
                        'seconds_per_step': step_sec,
                        'global_step_per_second': gs_sec,
                        'examples_per_second': ex_sec,
                        'train_batch_size_per_core':
                        self.train_batch_size // self.num_cores,
                        'num_cores': self.num_cores,
                    }
                    for metric, values in eval_results.items():
                        if not isinstance(values, list):
                            values = [values]
                        for i, value in enumerate(values):
                            tag = '{}_{:02d}'.format(metric,
                                                     i) if i > 0 else metric
                            step = self.cur_step - len(values) + i + 1
                            tf_summary = tf.Summary(value=[
                                tf.Summary.Value(tag=tag, simple_value=value)
                            ])
                            if summary_writer is not None:
                                summary_writer.add_summary(tf_summary, step)
                    tf.logging.info("TrainRunner: flushing summaries (%d)...",
                                    self.cur_step)

                    def thunk(cur_step):
                        if summary_writer is not None:
                            summary_writer.flush()
                        tf.logging.info(
                            "TrainRunner: flushing summaries (%d) (done)",
                            cur_step)

                    tflex.parallelize([self.cur_step], thunk)
            tf.logging.info(
                "TrainRunner: step={} global={} end={} loss={} step_time={:.2f}sec examples/sec={:.7f} global_step/sec={:.7f}"
                .format(self.cur_step, gs, end_step, loss, step_sec, ex_sec,
                        gs_sec))
        if need_final_checkpoint:
            tf.logging.info("TrainRunner: starting final checkpoint thread...")
            checkpoint_threads.append(start_checkpoint_thread())
        tf.logging.info("TrainRunner: waiting for infeed thread...")
        self.infeed_thread.join()
        tf.logging.info("TrainRunner: waiting for checkpoint threads...")
        # Join every thread, including the final one appended above: they are
        # daemon threads, so one left running could be stopped mid-save when
        # the interpreter shuts down.
        for i, thread in enumerate(checkpoint_threads):
            if thread is not None:
                thread.join()
                checkpoint_threads[i] = None
        tf.logging.info("TrainRunner: waiting for checkpoint threads (done)")
        if output_summaries:
            tf.logging.info("TrainRunner: closing summary writer...")
            if summary_writer is not None:
                summary_writer.close()
            tf.logging.info("TrainRunner: closing summary writer (done)")