示例#1
0
 def __init__(self, iterations, train_steps=-1):
     """Set up the TPU init graph/session and discover the TPU topology.

     Args:
       iterations: number of training steps per device loop; `train_steps`
         is rounded up to a multiple of this.
       train_steps: total number of steps to train.  A negative value or
         None means "no step limit" and is stored as None.  Otherwise the
         value is rounded up to the next multiple of `iterations`.

     Raises:
       ValueError: if the number of TPU cores found is not a multiple of
         `tpu_cores_per_host`.
     """
     tf.logging.info("TrainRunner: constructor")
     self.feature_structure = None
     self.loss = None
     self.infeed_queue = []
     self.enqueue_ops = []
     self.dataset_initializer = []
     self.iterations = iterations
     self.sess = None
     self.input_sess = None
     self.infeed_thread = None
     # Accept both the negative sentinel and an explicit None as "no limit";
     # `None < 0` would raise TypeError on Python 3.
     if train_steps is not None and train_steps < 0:
         train_steps = None
     if train_steps is not None and train_steps % iterations != 0:
         # Exact integer ceiling division: round up to whole device loops
         # without a float round-trip.
         train_steps = iterations * -(-train_steps // iterations)
     self.train_steps = train_steps
     self.input_graph = tf.Graph()
     # Separate graph that only holds the TPU system init/shutdown ops.
     with tf.Graph().as_default() as self.init_graph:
         self.tpu_init = tpu.initialize_system()
         self.tpu_shutdown = tpu.shutdown_system()
     #self.cluster_resolver = tflex.TPUClusterResolver(
     self.cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
         FLAGS.tpu or FLAGS.master,
         zone=FLAGS.tpu_zone,
         project=FLAGS.gcp_project)
     self.config = tf.ConfigProto(
         operation_timeout_in_ms=600 * 60 * 1000,  # 10 hours
         graph_options=tf.GraphOptions(
             rewrite_options=rewriter_config_pb2.RewriterConfig(
                 disable_meta_optimizer=True)),
         isolate_session_state=True)
     cluster_spec = self.cluster_resolver.cluster_spec()
     if cluster_spec:
         self.config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
     self.master = self.cluster_resolver.get_master()
     self.init_sess = tf.Session(self.master,
                                 graph=self.init_graph,
                                 config=self.config)
     tf.logging.info("TrainRunner: initializing TPU session...")
     # Setting TPU_NO_INIT to a non-zero integer skips initialize_system.
     if not bool(int(os.environ.get('TPU_NO_INIT', '0'))):
         tflex.run(self.init_sess, self.tpu_init)
     tf.logging.info("TrainRunner: initializing TPU session (done)")
     self.devices = self.init_sess.list_devices()
     self.cores = sorted(
         [x.name for x in self.devices if ':TPU:' in x.name])
     self.num_cores = len(self.cores)
     self.tpu_cores_per_host = 8
     # Explicit check rather than `assert`, which is stripped under -O.
     if self.num_cores % self.tpu_cores_per_host != 0:
         raise ValueError(
             'TrainRunner: expected a multiple of %d TPU cores, found %d' %
             (self.tpu_cores_per_host, self.num_cores))
     self.num_hosts = self.num_cores // self.tpu_cores_per_host
     print(self.config.cluster_def)
     print('cores: %d hosts: %d ip: %s' %
           (self.num_cores, self.num_hosts, self.master))
示例#2
0
 def infeed_thread_fn():
     """Finish building per-host enqueue ops, then drive the infeed.

     Builds enqueue ops for hosts 1..N-1 (host 0 is presumably built by the
     caller beforehand), opens a session on the input graph, initializes the
     datasets, flushes, and finally runs the enqueue ops.
     """
     host_count = FLAGS.num_cores // FLAGS.tpu_cores_per_host
     for host_index in range(1, host_count):
         self.build_enqueue_ops(input_fn, params, host_index)
     # Build infeed sesssion
     master = self.cluster_resolver.get_master()
     self.input_sess = tf.Session(master,
                                  graph=self.input_graph,
                                  config=self.config)
     session = self.input_sess
     session.run(self.dataset_initializer)
     tf.logging.info('Ensure infeed data has fully uploaded')
     tflex.flush(session)
     tf.logging.info('Run infeed session.run calls')
     tflex.run(session, [self.enqueue_ops])
示例#3
0
 def __init__(self, iterations, train_steps):
     """Set up the TPU init graph/session and the cluster configuration.

     Args:
       iterations: number of training steps per device loop; `train_steps`
         is rounded up to a multiple of this.
       train_steps: total number of steps to train; rounded up to the next
         multiple of `iterations`.
     """
     tf.logging.info("TrainRunner: constructor")
     self.feature_structure = {}
     self.loss = None
     self.infeed_queue = []
     self.enqueue_ops = []
     self.dataset_initializer = []
     self.iterations = iterations
     self.sess = None
     self.input_sess = None
     self.infeed_thread = None
     if train_steps % iterations != 0:
         # Exact integer ceiling division: round up to whole device loops
         # without a float round-trip.
         train_steps = iterations * -(-train_steps // iterations)
     self.train_steps = train_steps
     self.input_graph = tf.Graph()
     # Separate graph that only holds the TPU system init/shutdown ops.
     with tf.Graph().as_default() as self.init_graph:
         self.tpu_init = tpu.initialize_system()
         self.tpu_shutdown = tpu.shutdown_system()
     #self.cluster_resolver = tflex.TPUClusterResolver(
     self.cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
         FLAGS.tpu or FLAGS.master,
         zone=FLAGS.tpu_zone,
         project=FLAGS.gcp_project)
     self.config = tf.ConfigProto(
         operation_timeout_in_ms=600 * 60 * 1000,  # 10 hours
         graph_options=tf.GraphOptions(
             rewrite_options=rewriter_config_pb2.RewriterConfig(
                 disable_meta_optimizer=True)),
         isolate_session_state=True)
     cluster_spec = self.cluster_resolver.cluster_spec()
     if cluster_spec:
         self.config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
     # Resolve the master address once and keep it for later sessions.
     self.master = self.cluster_resolver.get_master()
     self.init_sess = tf.Session(self.master,
                                 graph=self.init_graph,
                                 config=self.config)
     tf.logging.info("TrainRunner: initializing TPU session...")
     # Setting TPU_NO_INIT to a non-zero integer skips initialize_system.
     if not bool(int(os.environ.get('TPU_NO_INIT', '0'))):
         tflex.run(self.init_sess, self.tpu_init)
     tf.logging.info("TrainRunner: initializing TPU session (done)")
示例#4
0
def main(unused_argv):
    """Interactive exploration entry point for a TPU-backed BigGAN graph.

    Parses the gin config, fills the global ``params`` dict (each entry can be
    overridden through ``getval``), opens an InteractiveSession against the
    TPU master, optionally initializes the TPU system, builds
    ``BigGAN.GAN()`` and then stops in ``pdb`` so the graph and session can
    be inspected by hand.

    Args:
      unused_argv: command-line arguments; ignored.
    """
    logging.info("Gin config: %s\nGin bindings: %s", FLAGS.gin_config,
                 FLAGS.gin_bindings)
    gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_bindings)
    global params
    #FLAGS.iterations_per_loop = 100
    #params = {'batch_size': FLAGS.train_batch_size}
    #params = {'batch_size': 128, 'use_tpu': True, 'precision': 'float32'}
    # with open(FLAGS.params) as f:
    #   params = json.load(f)
    params = options()
    params['use_tpu'] = getval('use_tpu', True)
    params['batch_per_core'] = getval('batch_per_core', 1)
    params['iterations'] = getval('iterations', 20)
    # NOTE(review): this uses FLAGS.num_cores, not the core count discovered
    # from the session below — confirm the flag matches the actual TPU.
    params['batch_size'] = FLAGS.num_cores * params['batch_per_core']
    params['opt_name'] = getval('opt_name', 'adam')
    params['beta1'] = getval('beta1', 0.9)
    params['beta2'] = getval('beta2', 0.999)
    params['epsilon'] = getval('epsilon', 1e-9)
    params['lr'] = getval('lr', 0.00025)
    FLAGS.train_batch_size = params['batch_size']
    FLAGS.iterations_per_loop = params['iterations']
    FLAGS.train_steps = getval('train_steps', int(2e6))
    params['precision'] = getval('precision', 'float32')
    params['model'] = getval('model', 'GPT2')
    assert params['model'] in ['GPT2', 'GPT2Rev']

    graph = tf.Graph()
    with graph.as_default():
        master = FLAGS.tpu or FLAGS.master or getval('TPU_NAME', 'unknown')
        cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            master, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
        config = tf.ConfigProto(
            operation_timeout_in_ms=600 * 60 * 1000,  # 10 hours
            # graph_options=tf.GraphOptions(
            #     rewrite_options=rewriter_config_pb2.RewriterConfig(
            #         disable_meta_optimizer=True,
            #         ),
            #     ),
            isolate_session_state=True)
        cluster_spec = cluster_resolver.cluster_spec()
        if cluster_spec:
            config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
        sess = tf.InteractiveSession(cluster_resolver.get_master(),
                                     graph=graph,
                                     config=config)
        devices = sess.list_devices()
        cores = sorted([x.name for x in devices if ':TPU:' in x.name])
        num_cores = len(cores)
        assert num_cores % 8 == 0
        num_hosts = num_cores // 8
        print(config.cluster_def)
        print('cores: %d hosts: %d ip: %s' % (num_cores, num_hosts, master))
        tf.logging.info("TrainRunner: initializing TPU session...")
        # Setting TPU_NO_INIT to a non-zero integer skips initialize_system.
        if not bool(int(os.environ.get('TPU_NO_INIT', '0'))):
            tflex.run(sess, tf.tpu.initialize_system())
        tf.logging.info("TrainRunner: initializing TPU session (done)")
        # Kept as a local so it can be inspected from the debugger below.
        gan = BigGAN.GAN()
        pp(tf.trainable_variables())
        import pdb
        pdb.set_trace()

        # seed = 0
        # dataset = ImageNet.make_dataset(FLAGS.dataset or "gs://dota-euw4a/datasets/danbooru2019-s/danbooru2019-s-0*", 0, 1, seed=seed)
        # it = iterate_dataset(dataset)

        # def go():
        #   zz = next(it)
        #   images = [zz['image']]
        #   labels = [zz['label']]

        #   #import IPython
        #   print('label', labels[0])
        #   #print(labels[0] - 1, imagenet_label_names[labels[0] - 1])
        #   print(images[0].shape)
        #   print('embedding', zz['parsed']['image/class/embedding'].values.shape)
        #   print('filename', zz['parsed']['image/filename'])
        #   print('hash', zz['parsed']['image/hash'])
        #   op = tf.io.encode_jpeg(images[0])
        #   with open('test.png', 'wb') as f:
        #     f.write(sess.run(op))
        # go()

        pdb.set_trace()
        # TODO: build the dataset here (see the commented-out exploration
        # above); a bare `dataset = dataset` would raise UnboundLocalError.
示例#5
0
 def shutdown(self):
     """Run the TPU system shutdown op on the init session.

     The shutdown op is skipped when the environment variable TPU_NO_INIT is
     set to a non-zero integer.  The completion message is logged in both
     cases, so every "shutting down..." line is paired with a "(done)" line.
     """
     tf.logging.info("TrainRunner: shutting down...")
     if not bool(int(os.environ.get('TPU_NO_INIT', '0'))):
         tflex.run(self.init_sess, self.tpu_shutdown)
     tf.logging.info("TrainRunner: shutting down (done)")
示例#6
0
    def train(self, num_threads=1, output_summaries=True):
        """Run the Train steps on the TPU device.

        Each loop pass fetches `self.loss`, which advances `self.cur_step` by
        `self.iterations`.  It then starts a background checkpoint save in one
        of `num_threads` rotating slots; the save is skipped if that slot's
        previous save is still running.  Finally it optionally writes scalar
        summaries.  The loop ends after `self.train_steps` more steps, or
        never if that is None.  It also ends as soon as
        `tflex.should_quit()` returns True.

        Args:
          num_threads: number of outstanding checkpointing threads
          output_summaries: whether to build per-loop scalar summaries; they
            are written under `<model_dir>/eval` only when `self.model_dir`
            is set.
        """
        if output_summaries and self.model_dir is not None:
            output_dir = os.path.join(self.model_dir, "eval")
            tf.gfile.MakeDirs(output_dir)
            # Summary writer writes out eval metrics.
            summary_writer = tf.compat.v1.summary.FileWriter(output_dir)
        else:
            summary_writer = None

        def checkpoint_thread_fn(saver, sess, force=False):
            """Save a checkpoint at the current step unless saving is disabled.

            `force=True` ignores the `no_save` option but still requires a
            model_dir.
            """
            step = self.cur_step
            if self.model_dir is None:
                tf.logging.info(
                    'step %d: model_dir is None; not saving checkpoint %s-%d',
                    step, 'model.ckpt', step)
                return
            if not force and train_flags.options().get('no_save'):
                tf.logging.info(
                    'step %d: options.no_save is set; not saving checkpoint %s-%d',
                    step, 'model.ckpt', step)
                return
            path = self.model_dir + "/model.ckpt"
            tf.logging.info('step %d: Saving checkpoint %s-%d...', step, path,
                            step)
            now = time.time()
            saver.save(sess, path, write_meta_graph=False, global_step=step)
            elapsed = time.time() - now
            tf.logging.info('step %d: Saved checkpoint %s-%d in %.2fs', step,
                            path, step, elapsed)

        @tflex.register_command
        def save():
            checkpoint_thread_fn(self.saver, self.sess, force=True)

        thread_id = 0
        need_final_checkpoint = False
        tf.logging.info("TrainRunner: step %d", self.cur_step)
        #tflex.run(sess, self.global_step.initializer, dict([(self.global_step.initializer.inputs[1], self.cur_step)]))
        # One slot per concurrently running checkpoint save.
        checkpoint_threads = [None] * num_threads
        end_step = None if self.train_steps is None else (self.cur_step +
                                                          self.train_steps)
        while True if end_step is None else (self.cur_step < end_step):
            tflex.check_commands()
            if tflex.should_quit():
                tf.logging.info("TrainRunner: quitting")
                break
            start = time.time()
            tf.logging.info("TrainRunner: start next %d steps",
                            self.iterations)
            self.cur_step += self.iterations
            self.infeed_thread_fn()
            loss = tflex.run(self.sess, [self.loss])
            previous = checkpoint_threads[thread_id]
            if previous is not None and previous.is_alive():
                tf.logging.info(
                    "TrainRunner: checkpoint thread still active; skipping")
                # Remember that the latest step has no checkpoint yet.
                need_final_checkpoint = True
            else:
                tf.logging.info("TrainRunner: starting checkpoint thread...")
                if previous is not None:
                    previous.join()
                checkpoint_threads[thread_id] = threading.Thread(
                    target=checkpoint_thread_fn,
                    args=(self.saver, self.sess),
                    daemon=True)
                checkpoint_threads[thread_id].start()
                need_final_checkpoint = False
            thread_id += 1
            if thread_id >= num_threads:
                thread_id = 0
            end = time.time()
            tf.logging.info("TrainRunner: fetching global_step...")
            gs = tflex.run(self.sess, self.global_step)
            step_sec = end - start
            gs_sec = self.iterations / step_sec
            ex_sec = self.iterations * self.train_batch_size / (end - start)
            # Write out summary to tensorboard.
            if output_summaries:
                tf.logging.info("TrainRunner: writing summaries...")
                with tf.Graph().as_default():
                    eval_results = {
                        'loss':
                        loss,
                        'iterations_per_step':
                        self.iterations,
                        'seconds_per_step':
                        step_sec,
                        'global_step_per_second':
                        gs_sec,
                        'examples_per_second':
                        ex_sec,
                        'train_batch_size_per_core':
                        self.train_batch_size // self.num_cores,
                        'num_cores':
                        self.num_cores,
                    }
                    for metric, values in eval_results.items():
                        if not isinstance(values, list):
                            values = [values]
                        for i, value in enumerate(values):
                            tag = '{}_{:02d}'.format(metric,
                                                     i) if i > 0 else metric
                            step = self.cur_step - len(values) + i + 1
                            tf_summary = tf.Summary(value=[
                                tf.Summary.Value(tag=tag, simple_value=value)
                            ])
                            if summary_writer is not None:
                                summary_writer.add_summary(tf_summary, step)
                    tf.logging.info("TrainRunner: flushing summaries (%d)...",
                                    self.cur_step)

                    def thunk(cur_step):
                        if summary_writer is not None:
                            summary_writer.flush()
                        tf.logging.info(
                            "TrainRunner: flushing summaries (%d) (done)",
                            cur_step)

                    tflex.parallelize([self.cur_step], thunk)
            tf.logging.info(
                "TrainRunner: step={} global={} end={} loss={} step_time={:.2f}sec examples/sec={:.7f} global_step/sec={:.7f}"
                .format(self.cur_step, gs, end_step, loss, step_sec, ex_sec,
                        gs_sec))
        if need_final_checkpoint:
            tf.logging.info("TrainRunner: starting final checkpoint thread...")
            final_thread = threading.Thread(
                target=checkpoint_thread_fn,
                args=(self.saver, self.sess),
                daemon=True)
            checkpoint_threads.append(final_thread)
            final_thread.start()
        tf.logging.info("TrainRunner: waiting for infeed thread...")
        self.infeed_thread.join()
        tf.logging.info("TrainRunner: waiting for checkpoint threads...")
        # Join every thread, including the extra final-checkpoint one appended
        # past the rotating slots.  These are daemon threads, so an unjoined
        # save could be killed mid-write when the interpreter exits.
        for i, thread in enumerate(checkpoint_threads):
            if thread is not None:
                thread.join()
                checkpoint_threads[i] = None
        tf.logging.info("TrainRunner: waiting for checkpoint threads (done)")
        if output_summaries:
            tf.logging.info("TrainRunner: closing summary writer...")
            if summary_writer is not None:
                summary_writer.close()
            tf.logging.info("TrainRunner: closing summary writer (done)")
示例#7
0
 def infeed_thread_fn():
     """Run the prebuilt enqueue ops on the input session, logging around the call."""
     tf.logging.info('Run infeed session.run calls')
     session = self.input_sess
     fetches = [self.enqueue_ops]
     tflex.run(session, fetches)
     tf.logging.info('infeed session.run finished')
示例#8
0
    def initialize(self, input_fn, model_fn, params):
        """Build graphs for the TPU device and the input pipelines.

        Builds enqueue ops for every host and opens and initializes the input
        session.  It then builds the sharded TPU training loop, creates and
        initializes the training session, and optionally restores the latest
        checkpoint from FLAGS.restore_dir.  Finally it stores the infeed
        driver in `self.infeed_thread_fn`.

        Args:
          input_fn: Dataset input graph generation function
          model_fn: Model definition function
          params:  Parameters to input and model functions
        """

        tf.logging.info("TrainRunner: initialize method")

        with tf.device(self.device_for_host()):
            self.global_step = tflex.get_or_create_global_step()

        def infeed_thread_fn():
            """Build and infeed session.run calls in a background thread."""
            tf.logging.info('Run infeed session.run calls')
            tflex.run(self.input_sess, [self.enqueue_ops])
            tf.logging.info('infeed session.run finished')

        for i in tqdm.trange(self.num_hosts):
            self.build_enqueue_ops(input_fn, params, i)

        # Build infeed sesssion
        self.input_sess = tf.Session(self.master,
                                     graph=self.input_graph,
                                     config=self.config)
        self.input_sess.run(self.dataset_initializer)
        tf.logging.info('Ensure infeed data has fully uploaded')
        tflex.flush(self.input_sess)

        def get_tpu_step(mparams):
            """Get the TPU graph generation function."""
            def tpu_pre(loss):
                """Generate the TPU graph."""
                del loss
                values = self.infeed_queue[0].generate_dequeue_op(tpu_device=0)
                unflattened_inputs = data_nest.pack_sequence_as(
                    self.feature_structure, values)
                # Inputs may be a {"features", "labels"} pair or bare features.
                if "features" in unflattened_inputs and "labels" in unflattened_inputs:
                    features = unflattened_inputs["features"]
                    labels = unflattened_inputs["labels"]
                else:
                    features = unflattened_inputs
                    labels = None
                estimator_spec = model_fn(features, labels,
                                          tf.estimator.ModeKeys.TRAIN, mparams)
                return estimator_spec

            def tpu_make(estimator_spec):
                loss, train_op = estimator_spec.loss, estimator_spec.train_op
                with tf.device(device_for_tpu_core()):
                    # Tie the returned loss to the train op so that fetching
                    # the loss also runs the update.
                    with tf.control_dependencies([train_op]):
                        return tf.identity(loss, name="tpu_loss_op")

            def tpu_step(loss):
                estimator_spec = tpu_pre(loss)
                return tpu_make(estimator_spec)

            return tpu_pre, tpu_make, tpu_step

        _, _, tpu_step = get_tpu_step(params)

        @tpu_function.on_device_training_loop
        def tpu_loop():
            return tpu.repeat(self.iterations, tpu_step, [_INITIAL_LOSS])
            #return tpu_step(_INITIAL_LOSS)

        (self.loss, ) = tpu.shard(
            tpu_loop,
            inputs=[],
            num_shards=self.num_cores,
            outputs_from_all_shards=False,
        )
        if FLAGS.restore_trainable_variables:
            self.var_list = tf.trainable_variables()
        else:
            self.var_list = tf.global_variables()
        self.model_dir = FLAGS.model_dir or train_flags.options().get(
            'model_dir')
        self.saver = None
        if self.model_dir is not None:
            self.saver = tf.train.Saver(var_list=self.var_list,
                                        keep_checkpoint_every_n_hours=0.5)
            # Why do this?
            # graph_io.write_graph(tf.Graph().as_graph_def(add_shapes=True), self.model_dir, "graph.pbtxt")

        # Build tpu train model session and initialize graph
        self.sess = tf.Session(self.master, config=self.config)
        self.initializer = [
            tf.local_variables_initializer(),
            tf.global_variables_initializer()
        ]
        self.extra_initializers = tf.get_collection(
            'tftorch_initializers'
        )  # a bit of a hack, but it gets the job done
        tflex.run(self.sess, self.initializer)
        tflex.run(self.sess, self.extra_initializers)

        if FLAGS.restore_dir is not None:
            ckpt = tf.train.latest_checkpoint(FLAGS.restore_dir)
            if ckpt is None:
                # Best-effort restore: continue from freshly initialized
                # variables, but make that visible in the logs.
                tf.logging.warning(
                    'restore_dir %s has no checkpoint; not restoring',
                    FLAGS.restore_dir)
            else:
                step = tflex.checkpoint_step(ckpt) or 0
                saver = tf.train.Saver(var_list=self.var_list,
                                       restore_sequentially=True)
                for x in self.var_list:
                    tf.logging.info('\t%s', repr(x))
                tf.logging.info('Restoring %s step %d', ckpt, step)
                saver.restore(self.sess, ckpt)
                tf.logging.info('Setting step %d', step)
                self.global_step.load(step, self.sess)
                tf.logging.info('Restoring %s step %d (done)', ckpt, step)
        self.cur_step = tflex.run(self.sess, self.global_step)

        # Complete infeed graph generation and session.run calls
        def null_fn():
            pass

        # Placeholder thread so `self.infeed_thread.join()` always has a
        # target; the real infeed runs synchronously via infeed_thread_fn.
        self.infeed_thread = threading.Thread(target=null_fn, daemon=True)
        self.infeed_thread.start()
        self.infeed_thread_fn = infeed_thread_fn
示例#9
0
    def initialize(self, input_fn, model_fn, params):
        """Build graphs for the TPU device and the input pipelines.

        Builds host 0's enqueue ops and the sharded TPU training loop.  It
        then creates and initializes the training session, optionally
        restores the latest checkpoint from FLAGS.restore_dir, and finally
        starts a background thread.  That thread builds the remaining hosts'
        enqueue ops and runs the infeed.

        Args:
          input_fn: Dataset input graph generation function
          model_fn: Model definition function
          params:  Parameters to input and model functions
        """

        tf.logging.info("TrainRunner: initialize method")

        with tf.device(self.device_for_host()):
            self.global_step = tflex.get_or_create_global_step()

        def infeed_thread_fn():
            """Build and infeed session.run calls in a background thread."""
            # Host 0 is built synchronously below; build the remaining hosts.
            for i in range(1, FLAGS.num_cores // FLAGS.tpu_cores_per_host):
                self.build_enqueue_ops(input_fn, params, i)
            # Build infeed sesssion
            self.input_sess = tf.Session(self.cluster_resolver.get_master(),
                                         graph=self.input_graph,
                                         config=self.config)
            self.input_sess.run(self.dataset_initializer)
            tf.logging.info('Ensure infeed data has fully uploaded')
            tflex.flush(self.input_sess)
            tf.logging.info('Run infeed session.run calls')
            tflex.run(self.input_sess, [self.enqueue_ops])

        self.build_enqueue_ops(input_fn, params, 0)

        def get_tpu_step(mparams):
            """Get the TPU graph generation function."""
            def tpu_pre(loss):
                """Generate the TPU graph."""
                del loss
                values = self.infeed_queue[0].generate_dequeue_op(tpu_device=0)
                unflattened_inputs = data_nest.pack_sequence_as(
                    self.feature_structure, values)
                # Inputs may be a {"features", "labels"} pair or bare features.
                if "features" in unflattened_inputs and "labels" in unflattened_inputs:
                    features = unflattened_inputs["features"]
                    labels = unflattened_inputs["labels"]
                else:
                    features = unflattened_inputs
                    labels = None
                estimator_spec = model_fn(features, labels,
                                          tf.estimator.ModeKeys.TRAIN, mparams)
                return estimator_spec

            def tpu_make(estimator_spec):
                loss, train_op = estimator_spec.loss, estimator_spec.train_op
                with tf.device(device_for_tpu_core()):
                    # Tie the returned loss to the train op so that fetching
                    # the loss also runs the update.
                    with tf.control_dependencies([train_op]):
                        return tf.identity(loss, name="tpu_loss_op")

            def tpu_step(loss):
                estimator_spec = tpu_pre(loss)
                return tpu_make(estimator_spec)

            return tpu_pre, tpu_make, tpu_step

        _, _, tpu_step = get_tpu_step(params)

        @tpu_function.on_device_training_loop
        def tpu_loop():
            return tpu.repeat(self.iterations, tpu_step, [_INITIAL_LOSS])
            #return tpu_step(_INITIAL_LOSS)

        (self.loss, ) = tpu.shard(
            tpu_loop,
            inputs=[],
            num_shards=FLAGS.num_cores,
            outputs_from_all_shards=False,
        )
        initializer = tf.global_variables_initializer()
        self.saver = tf.train.Saver(keep_checkpoint_every_n_hours=0.5)
        if FLAGS.model_dir:
            # Serialize the default graph, which holds the model just built.
            # A fresh `tf.Graph()` here would always write an empty graph.
            graph_io.write_graph(
                tf.get_default_graph().as_graph_def(add_shapes=True),
                FLAGS.model_dir, "graph.pbtxt")

        # Build tpu train model session and initialize graph
        self.sess = tf.Session(self.cluster_resolver.get_master(),
                               config=self.config)
        tflex.run(self.sess, initializer)

        if FLAGS.restore_dir is not None:
            ckpt = tf.train.latest_checkpoint(FLAGS.restore_dir)
            if ckpt is None:
                # Best-effort restore: continue from freshly initialized
                # variables, but make that visible in the logs.
                tf.logging.warning(
                    'restore_dir %s has no checkpoint; not restoring',
                    FLAGS.restore_dir)
            else:
                if FLAGS.restore_trainable_variables:
                    var_list = tf.trainable_variables()
                    # Position embeddings ('/wpe') are skipped when the
                    # context length differs from 1024.
                    if params['n_ctx'] != 1024:
                        var_list = [
                            x for x in var_list if '/wpe' not in x.name
                        ]
                else:
                    var_list = tf.global_variables()
                saver = tf.train.Saver(var_list=var_list,
                                       restore_sequentially=True)
                tf.logging.info('Restoring %s', ckpt)
                for x in var_list:
                    tf.logging.info('\t%s', repr(x))
                saver.restore(self.sess, ckpt)
                tf.logging.info('Restoring %s (done)', ckpt)
        self.cur_step = tflex.run(self.sess, self.global_step)

        # Complete infeed graph generation and session.run calls
        self.infeed_thread = threading.Thread(target=infeed_thread_fn,
                                              daemon=True)
        self.infeed_thread.start()