def __init__(self, iterations, train_steps=-1):
  """Build the TPU init/shutdown graph, connect to the cluster, list cores.

  Args:
    iterations: number of steps per training loop; a finite `train_steps`
      is rounded up to a multiple of it.
    train_steps: total number of steps to train. A negative value (the
      default) or None means "no limit" and is stored as None.

  Side effects:
    Opens `self.init_sess` against the TPU master and, unless the
    TPU_NO_INIT environment variable parses to a non-zero int, runs
    `tpu.initialize_system()` on it. Prints the cluster def and core count.
  """
  tf.logging.info("TrainRunner: constructor")
  self.feature_structure = None
  self.loss = None
  self.infeed_queue = []
  self.enqueue_ops = []
  self.dataset_initializer = []
  self.iterations = iterations
  self.sess = None
  self.input_sess = None
  self.infeed_thread = None
  # Test for None before comparing: `None < 0` raises TypeError on Python 3,
  # and None is an accepted "no limit" value.
  if train_steps is not None and train_steps < 0:
    train_steps = None
  if train_steps is not None and train_steps % iterations != 0:
    # Round up to a whole number of `iterations`-sized loops.
    train_steps = iterations * int(math.ceil(train_steps / iterations))
  self.train_steps = train_steps
  self.input_graph = tf.Graph()
  with tf.Graph().as_default() as self.init_graph:
    self.tpu_init = tpu.initialize_system()
    self.tpu_shutdown = tpu.shutdown_system()
  self.cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
      FLAGS.tpu or FLAGS.master,
      zone=FLAGS.tpu_zone,
      project=FLAGS.gcp_project)
  self.config = tf.ConfigProto(
      operation_timeout_in_ms=600 * 60 * 1000,
      graph_options=tf.GraphOptions(
          rewrite_options=rewriter_config_pb2.RewriterConfig(
              disable_meta_optimizer=True)),
      isolate_session_state=True)
  cluster_spec = self.cluster_resolver.cluster_spec()
  if cluster_spec:
    self.config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
  self.master = self.cluster_resolver.get_master()
  self.init_sess = tf.Session(
      self.master, graph=self.init_graph, config=self.config)
  tf.logging.info("TrainRunner: initializing TPU session...")
  # TPU_NO_INIT=1 skips initialize_system (e.g. when attaching to a TPU
  # that is already initialized -- presumably; confirm with callers).
  if not bool(int(os.environ.get('TPU_NO_INIT', '0'))):
    tflex.run(self.init_sess, self.tpu_init)
  tf.logging.info("TrainRunner: initializing TPU session (done)")
  self.devices = self.init_sess.list_devices()
  self.cores = sorted(
      [x.name for x in self.devices if ':TPU:' in x.name])
  self.num_cores = len(self.cores)
  self.tpu_cores_per_host = 8
  assert self.num_cores % self.tpu_cores_per_host == 0
  self.num_hosts = self.num_cores // self.tpu_cores_per_host
  print(self.config.cluster_def)
  print('cores: %d hosts: %d ip: %s' %
        (self.num_cores, self.num_hosts, self.master))
def infeed_thread_fn():
  """Build enqueue ops for hosts 1..N-1, open the infeed session and run it.

  Host index 0 is not built here (the range starts at 1); presumably the
  caller builds it beforehand.
  """
  for host_id in range(1, FLAGS.num_cores // FLAGS.tpu_cores_per_host):
    self.build_enqueue_ops(input_fn, params, host_id)
  # Infeed session lives on the dedicated input graph.
  master = self.cluster_resolver.get_master()
  self.input_sess = tf.Session(
      master, graph=self.input_graph, config=self.config)
  self.input_sess.run(self.dataset_initializer)
  tf.logging.info('Ensure infeed data has fully uploaded')
  tflex.flush(self.input_sess)
  tf.logging.info('Run infeed session.run calls')
  tflex.run(self.input_sess, [self.enqueue_ops])
def __init__(self, iterations, train_steps):
  """Set up runner state and bring up the TPU system.

  Args:
    iterations: steps per training loop; `train_steps` is rounded up to a
      multiple of this.
    train_steps: total number of steps to train.
  """
  tf.logging.info("TrainRunner: constructor")
  # Graph pieces populated later by the input / model builders.
  self.feature_structure = {}
  self.loss = None
  self.infeed_queue = []
  self.enqueue_ops = []
  self.dataset_initializer = []
  self.iterations = iterations
  # Sessions and the infeed thread are created later.
  self.sess = None
  self.input_sess = None
  self.infeed_thread = None
  remainder = train_steps % iterations
  if remainder != 0:
    loops = int(math.ceil(train_steps / iterations))
    train_steps = iterations * loops
  self.train_steps = train_steps
  self.input_graph = tf.Graph()
  self.init_graph = tf.Graph()
  with self.init_graph.as_default():
    self.tpu_init = tpu.initialize_system()
    self.tpu_shutdown = tpu.shutdown_system()
  resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
      FLAGS.tpu or FLAGS.master,
      zone=FLAGS.tpu_zone,
      project=FLAGS.gcp_project)
  self.cluster_resolver = resolver
  rewrite_options = rewriter_config_pb2.RewriterConfig(
      disable_meta_optimizer=True)
  self.config = tf.ConfigProto(
      operation_timeout_in_ms=600 * 60 * 1000,
      graph_options=tf.GraphOptions(rewrite_options=rewrite_options),
      isolate_session_state=True)
  cluster_spec = resolver.cluster_spec()
  if cluster_spec:
    self.config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
  self.init_sess = tf.Session(
      resolver.get_master(), graph=self.init_graph, config=self.config)
  tf.logging.info("TrainRunner: initializing TPU session...")
  skip_init = bool(int(os.environ.get('TPU_NO_INIT', '0')))
  if not skip_init:
    tflex.run(self.init_sess, self.tpu_init)
  tf.logging.info("TrainRunner: initializing TPU session (done)")
def main(unused_argv):
  """Interactive TPU bring-up / debugging entry point.

  Parses the gin config, fills the global `params` from `options()` and
  `getval`, and connects an InteractiveSession to the TPU cluster. It then
  initializes the TPU system (unless TPU_NO_INIT is a non-zero int), builds
  `BigGAN.GAN()`, prints the trainable variables, and stops in pdb for
  manual inspection.

  Args:
    unused_argv: ignored.
  """
  logging.info("Gin config: %s\nGin bindings: %s",
               FLAGS.gin_config, FLAGS.gin_bindings)
  gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_bindings)

  global params
  params = options()
  params['use_tpu'] = getval('use_tpu', True)
  params['batch_per_core'] = getval('batch_per_core', 1)
  params['iterations'] = getval('iterations', 20)
  # Uses FLAGS.num_cores, not the core count discovered from the session below.
  params['batch_size'] = FLAGS.num_cores * params['batch_per_core']
  params['opt_name'] = getval('opt_name', 'adam')
  params['beta1'] = getval('beta1', 0.9)
  params['beta2'] = getval('beta2', 0.999)
  params['epsilon'] = getval('epsilon', 1e-9)
  params['lr'] = getval('lr', 0.00025)
  FLAGS.train_batch_size = params['batch_size']
  FLAGS.iterations_per_loop = params['iterations']
  FLAGS.train_steps = getval('train_steps', int(2e6))
  params['precision'] = getval('precision', 'float32')
  params['model'] = getval('model', 'GPT2')
  assert params['model'] in ['GPT2', 'GPT2Rev']

  graph = tf.Graph()
  with graph.as_default():
    master = FLAGS.tpu or FLAGS.master or getval('TPU_NAME', 'unknown')
    cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        master, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
    # The grappler meta-optimizer is left at its default (not disabled) here.
    config = tf.ConfigProto(
        operation_timeout_in_ms=600 * 60 * 1000,
        isolate_session_state=True)
    cluster_spec = cluster_resolver.cluster_spec()
    if cluster_spec:
      config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
    sess = tf.InteractiveSession(cluster_resolver.get_master(),
                                 graph=graph,
                                 config=config)
    devices = sess.list_devices()
    cores = sorted([x.name for x in devices if ':TPU:' in x.name])
    num_cores = len(cores)
    assert num_cores % 8 == 0
    num_hosts = num_cores // 8
    print(config.cluster_def)
    print('cores: %d hosts: %d ip: %s' % (num_cores, num_hosts, master))
    tf.logging.info("TrainRunner: initializing TPU session...")
    if not bool(int(os.environ.get('TPU_NO_INIT', '0'))):
      tflex.run(sess, tf.tpu.initialize_system())
    tf.logging.info("TrainRunner: initializing TPU session (done)")
    gan = BigGAN.GAN()  # Only used from the pdb prompt below.
    pp(tf.trainable_variables())
    # Deliberate breakpoint: inspect `gan`, `sess`, `params` interactively.
    # (The function previously ended with `dataset = dataset`, which always
    # raised UnboundLocalError; it has been removed.)
    import pdb
    pdb.set_trace()
    # Dataset exploration snippet, kept for manual use:
    # seed = 0
    # dataset = ImageNet.make_dataset(FLAGS.dataset or "gs://dota-euw4a/datasets/danbooru2019-s/danbooru2019-s-0*", 0, 1, seed=seed)
    # it = iterate_dataset(dataset)
    # def go():
    #   zz = next(it)
    #   images = [zz['image']]
    #   labels = [zz['label']]
    #   #import IPython
    #   print('label', labels[0])
    #   #print(labels[0] - 1, imagenet_label_names[labels[0] - 1])
    #   print(images[0].shape)
    #   print('embedding', zz['parsed']['image/class/embedding'].values.shape)
    #   print('filename', zz['parsed']['image/filename'])
    #   print('hash', zz['parsed']['image/hash'])
    #   op = tf.io.encode_jpeg(images[0])
    #   with open('test.png', 'wb') as f:
    #     f.write(sess.run(op))
    # go()
def shutdown(self):
  """Run the TPU shutdown op on the init session.

  The op is skipped when the TPU_NO_INIT environment variable parses to a
  non-zero int.
  """
  tf.logging.info("TrainRunner: shutting down...")
  skip_shutdown = bool(int(os.environ.get('TPU_NO_INIT', '0')))
  if not skip_shutdown:
    tflex.run(self.init_sess, self.tpu_shutdown)
  tf.logging.info("TrainRunner: shutting down (done)")
def train(self, num_threads=1, output_summaries=True):
  """Run the Train steps on the TPU device.

  Each loop pass does the following:
    - advances `self.cur_step` by `self.iterations`;
    - calls `self.infeed_thread_fn()`;
    - fetches `self.loss`;
    - starts a background checkpoint save in a rotating slot;
    - optionally writes scalar summaries.
  The loop stops once `self.train_steps` more steps have run (never, if
  that is None) or when `tflex.should_quit()` is true.

  Args:
    num_threads: number of outstanding checkpointing threads (must be >= 1).
    output_summaries: if true, compute per-loop metrics and, when
      `self.model_dir` is set, write them to `<model_dir>/eval`.
  """
  if output_summaries and self.model_dir is not None:
    output_dir = os.path.join(self.model_dir, "eval")
    tf.gfile.MakeDirs(output_dir)
    # Summary writer writes out eval metrics.
    summary_writer = tf.compat.v1.summary.FileWriter(output_dir)
  else:
    summary_writer = None

  def checkpoint_thread_fn(saver, sess, force=False):
    """Save `sess` to <model_dir>/model.ckpt-<cur_step>.

    Skipped when model_dir is None, or (unless `force`) when the `no_save`
    option is set.
    """
    step = self.cur_step
    if self.model_dir is None:
      tf.logging.info(
          'step %d: model_dir is None; not saving checkpoint %s-%d', step,
          'model.ckpt', step)
      return
    if not force and train_flags.options().get('no_save'):
      tf.logging.info(
          'step %d: options.no_save is set; not saving checkpoint %s-%d',
          step, 'model.ckpt', step)
      return
    path = self.model_dir + "/model.ckpt"
    tf.logging.info('step %d: Saving checkpoint %s-%d...', step, path, step)
    now = time.time()
    saver.save(sess, path, write_meta_graph=False, global_step=step)
    elapsed = time.time() - now
    tf.logging.info('step %d: Saved checkpoint %s-%d in %.2fs', step, path,
                    step, elapsed)

  # Registered with tflex so a forced save can presumably be requested via
  # its command mechanism (see tflex.check_commands in the loop).
  @tflex.register_command
  def save():
    checkpoint_thread_fn(self.saver, self.sess, force=True)

  thread_id = 0
  # One slot per outstanding checkpoint thread; None means the slot is free.
  checkpoint_threads = [None] * num_threads
  need_final_checkpoint = False
  tf.logging.info("TrainRunner: step %d", self.cur_step)
  end_step = (None if self.train_steps is None else
              (self.cur_step + self.train_steps))
  while True if end_step is None else (self.cur_step < end_step):
    tflex.check_commands()
    if tflex.should_quit():
      tf.logging.info("TrainRunner: quitting")
      break
    start = time.time()
    tf.logging.info("TrainRunner: start next %d steps", self.iterations)
    self.cur_step += self.iterations
    self.infeed_thread_fn()
    loss = tflex.run(self.sess, [self.loss])
    current = checkpoint_threads[thread_id]
    if current is not None and current.is_alive():
      # The previous save in this slot is still running; skip this save and
      # make sure one happens after the loop.
      tf.logging.info(
          "TrainRunner: checkpoint thread still active; skipping")
      need_final_checkpoint = True
    else:
      tf.logging.info("TrainRunner: starting checkpoint thread...")
      if current is not None:
        current.join()
      checkpoint_threads[thread_id] = threading.Thread(
          target=checkpoint_thread_fn,
          args=(self.saver, self.sess),
          daemon=True)
      checkpoint_threads[thread_id].start()
      need_final_checkpoint = False
    thread_id = (thread_id + 1) % num_threads
    end = time.time()
    tf.logging.info("TrainRunner: fetching global_step...")
    gs = tflex.run(self.sess, self.global_step)
    step_sec = end - start
    gs_sec = self.iterations / step_sec
    ex_sec = self.iterations * self.train_batch_size / step_sec
    # Write out summary to tensorboard.
    if output_summaries:
      tf.logging.info("TrainRunner: writing summaries...")
      with tf.Graph().as_default():
        eval_results = {
            'loss': loss,
            'iterations_per_step': self.iterations,
            'seconds_per_step': step_sec,
            'global_step_per_second': gs_sec,
            'examples_per_second': ex_sec,
            'train_batch_size_per_core':
                self.train_batch_size // self.num_cores,
            'num_cores': self.num_cores,
        }
        for metric, values in eval_results.items():
          if not isinstance(values, list):
            values = [values]
          for i, value in enumerate(values):
            tag = '{}_{:02d}'.format(metric, i) if i > 0 else metric
            step = self.cur_step - len(values) + i + 1
            tf_summary = tf.Summary(
                value=[tf.Summary.Value(tag=tag, simple_value=value)])
            if summary_writer is not None:
              summary_writer.add_summary(tf_summary, step)
        tf.logging.info("TrainRunner: flushing summaries (%d)...",
                        self.cur_step)

        def thunk(cur_step):
          if summary_writer is not None:
            summary_writer.flush()
          tf.logging.info("TrainRunner: flushing summaries (%d) (done)",
                          cur_step)

        tflex.parallelize([self.cur_step], thunk)
    tf.logging.info(
        "TrainRunner: step={} global={} end={} loss={} step_time={:.2f}sec "
        "examples/sec={:.7f} global_step/sec={:.7f}".format(
            self.cur_step, gs, end_step, loss, step_sec, ex_sec, gs_sec))

  if need_final_checkpoint:
    tf.logging.info("TrainRunner: starting final checkpoint thread...")
    final_thread = threading.Thread(
        target=checkpoint_thread_fn,
        args=(self.saver, self.sess),
        daemon=True)
    checkpoint_threads.append(final_thread)
    final_thread.start()

  tf.logging.info("TrainRunner: waiting for infeed thread...")
  self.infeed_thread.join()
  tf.logging.info("TrainRunner: waiting for checkpoint threads...")
  # Join every entry, including the final checkpoint thread appended past
  # the first `num_threads` slots. These are daemon threads, so one left
  # unjoined can be stopped abruptly at interpreter exit, mid-save.
  for i, thread in enumerate(checkpoint_threads):
    if thread is not None:
      thread.join()
      checkpoint_threads[i] = None
  tf.logging.info("TrainRunner: waiting for checkpoint threads (done)")
  if output_summaries:
    tf.logging.info("TrainRunner: closing summary writer...")
    if summary_writer is not None:
      summary_writer.close()
    tf.logging.info("TrainRunner: closing summary writer (done)")
def infeed_thread_fn():
  """Issue one tflex.run over the already-built enqueue ops."""
  tf.logging.info('Run infeed session.run calls')
  fetches = [self.enqueue_ops]
  tflex.run(self.input_sess, fetches)
  tf.logging.info('infeed session.run finished')
def initialize(self, input_fn, model_fn, params):
  """Build graphs for the TPU device and the input pipelines.

  The method proceeds in this order:
    1. Build enqueue ops for every host and open the infeed session.
    2. Shard the `iterations`-step TPU training loop across all cores.
    3. Create the saver (when a model_dir is known) and the train session.
    4. Run the initializers, optionally restore from FLAGS.restore_dir, and
       read the starting step into `self.cur_step`.

  Args:
    input_fn: Dataset input graph generation function
    model_fn: Model definition function
    params: Parameters to input and model functions
  """
  tf.logging.info("TrainRunner: initialize method")
  with tf.device(self.device_for_host()):
    self.global_step = tflex.get_or_create_global_step()

  def infeed_thread_fn():
    """Issue one tflex.run over the already-built enqueue ops."""
    tf.logging.info('Run infeed session.run calls')
    tflex.run(self.input_sess, [self.enqueue_ops])
    tf.logging.info('infeed session.run finished')

  for i in tqdm.trange(self.num_hosts):
    self.build_enqueue_ops(input_fn, params, i)

  # Build infeed session
  self.input_sess = tf.Session(self.master,
                               graph=self.input_graph,
                               config=self.config)
  self.input_sess.run(self.dataset_initializer)
  tf.logging.info('Ensure infeed data has fully uploaded')
  tflex.flush(self.input_sess)

  def get_tpu_step(mparams):
    """Return the per-step TPU graph function, closed over `mparams`."""

    def tpu_pre(loss):
      """Dequeue one batch and build the model's EstimatorSpec."""
      del loss  # Loop-carried value from tpu.repeat; not needed here.
      values = self.infeed_queue[0].generate_dequeue_op(tpu_device=0)
      unflattened_inputs = data_nest.pack_sequence_as(
          self.feature_structure, values)
      if "features" in unflattened_inputs and "labels" in unflattened_inputs:
        features = unflattened_inputs["features"]
        labels = unflattened_inputs["labels"]
      else:
        features = unflattened_inputs
        labels = None
      return model_fn(features, labels, tf.estimator.ModeKeys.TRAIN,
                      mparams)

    def tpu_make(estimator_spec):
      """Return the loss, gated on the train op having run."""
      loss, train_op = estimator_spec.loss, estimator_spec.train_op
      with tf.device(device_for_tpu_core()):
        with tf.control_dependencies([train_op]):
          return tf.identity(loss, name="tpu_loss_op")

    def tpu_step(loss):
      return tpu_make(tpu_pre(loss))

    return tpu_step

  tpu_step = get_tpu_step(params)

  @tpu_function.on_device_training_loop
  def tpu_loop():
    return tpu.repeat(self.iterations, tpu_step, [_INITIAL_LOSS])

  (self.loss,) = tpu.shard(
      tpu_loop,
      inputs=[],
      num_shards=self.num_cores,
      outputs_from_all_shards=False,
  )

  if FLAGS.restore_trainable_variables:
    self.var_list = tf.trainable_variables()
  else:
    self.var_list = tf.global_variables()
  self.model_dir = FLAGS.model_dir or train_flags.options().get('model_dir')
  # No model_dir means no saver (checkpointing is disabled).
  self.saver = None
  if self.model_dir is not None:
    self.saver = tf.train.Saver(var_list=self.var_list,
                                keep_checkpoint_every_n_hours=0.5)

  # Build tpu train model session and initialize graph
  self.sess = tf.Session(self.master, config=self.config)
  self.initializer = [
      tf.local_variables_initializer(),
      tf.global_variables_initializer()
  ]
  # A bit of a hack: extra init ops registered elsewhere under this
  # collection name also need to run.
  self.extra_initializers = tf.get_collection('tftorch_initializers')
  tflex.run(self.sess, self.initializer)
  tflex.run(self.sess, self.extra_initializers)

  if FLAGS.restore_dir is not None:
    ckpt = tf.train.latest_checkpoint(FLAGS.restore_dir)
    # A missing checkpoint is tolerated: the freshly initialized variables
    # are kept.
    if ckpt is not None:
      step = tflex.checkpoint_step(ckpt) or 0
      saver = tf.train.Saver(var_list=self.var_list,
                             restore_sequentially=True)
      for x in self.var_list:
        tf.logging.info('\t%s', repr(x))
      tf.logging.info('Restoring %s step %d', ckpt, step)
      saver.restore(self.sess, ckpt)
      tf.logging.info('Setting step %d', step)
      self.global_step.load(step, self.sess)
      tf.logging.info('Restoring %s step %d (done)', ckpt, step)
  self.cur_step = tflex.run(self.sess, self.global_step)

  # The infeed itself is run synchronously through self.infeed_thread_fn.
  # self.infeed_thread is a no-op thread, presumably kept so code that
  # join()s it keeps working.
  def null_fn():
    pass

  self.infeed_thread = threading.Thread(target=null_fn, daemon=True)
  self.infeed_thread.start()
  self.infeed_thread_fn = infeed_thread_fn
def initialize(self, input_fn, model_fn, params):
  """Build graphs for the TPU device and the input pipelines.

  Host 0's enqueue ops are built synchronously. The remaining hosts'
  enqueue ops, the infeed session and the infeed run happen on a background
  daemon thread (`self.infeed_thread`), started at the end.

  Args:
    input_fn: Dataset input graph generation function
    model_fn: Model definition function
    params: Parameters to input and model functions
  """
  tf.logging.info("TrainRunner: initialize method")
  with tf.device(self.device_for_host()):
    self.global_step = tflex.get_or_create_global_step()

  def infeed_thread_fn():
    """Build and infeed session.run calls in a background thread."""
    for host_id in range(1, FLAGS.num_cores // FLAGS.tpu_cores_per_host):
      self.build_enqueue_ops(input_fn, params, host_id)
    # Build infeed session
    self.input_sess = tf.Session(self.cluster_resolver.get_master(),
                                 graph=self.input_graph,
                                 config=self.config)
    self.input_sess.run(self.dataset_initializer)
    tf.logging.info('Ensure infeed data has fully uploaded')
    tflex.flush(self.input_sess)
    tf.logging.info('Run infeed session.run calls')
    tflex.run(self.input_sess, [self.enqueue_ops])

  # tpu_pre below reads self.infeed_queue[0] and self.feature_structure,
  # which is presumably why host 0 is built here rather than on the thread.
  self.build_enqueue_ops(input_fn, params, 0)

  def get_tpu_step(mparams):
    """Return the per-step TPU graph function, closed over `mparams`."""

    def tpu_pre(loss):
      """Dequeue one batch and build the model's EstimatorSpec."""
      del loss  # Loop-carried value from tpu.repeat; not needed here.
      values = self.infeed_queue[0].generate_dequeue_op(tpu_device=0)
      unflattened_inputs = data_nest.pack_sequence_as(
          self.feature_structure, values)
      features = unflattened_inputs["features"]
      labels = unflattened_inputs["labels"]
      return model_fn(features, labels, tf.estimator.ModeKeys.TRAIN,
                      mparams)

    def tpu_make(estimator_spec):
      """Return the loss, gated on the train op having run."""
      loss, train_op = estimator_spec.loss, estimator_spec.train_op
      with tf.device(device_for_tpu_core()):
        with tf.control_dependencies([train_op]):
          return tf.identity(loss, name="tpu_loss_op")

    def tpu_step(loss):
      return tpu_make(tpu_pre(loss))

    return tpu_step

  tpu_step = get_tpu_step(params)

  @tpu_function.on_device_training_loop
  def tpu_loop():
    return tpu.repeat(self.iterations, tpu_step, [_INITIAL_LOSS])

  (self.loss,) = tpu.shard(
      tpu_loop,
      inputs=[],
      num_shards=FLAGS.num_cores,
      outputs_from_all_shards=False,
  )
  initializer = tf.global_variables_initializer()
  self.saver = tf.train.Saver(keep_checkpoint_every_n_hours=0.5)
  # NOTE(review): tf.Graph() constructs a brand-new, empty graph, so this
  # writes an empty graph.pbtxt rather than the training graph built above.
  # If the training graph was intended, tf.get_default_graph() is needed --
  # confirm before changing.
  graph_io.write_graph(tf.Graph().as_graph_def(add_shapes=True),
                       FLAGS.model_dir, "graph.pbtxt")

  # Build tpu train model session and initialize graph
  self.sess = tf.Session(self.cluster_resolver.get_master(),
                         config=self.config)
  tflex.run(self.sess, initializer)

  if FLAGS.restore_dir is not None:
    ckpt = tf.train.latest_checkpoint(FLAGS.restore_dir)
    if ckpt is not None:
      if FLAGS.restore_trainable_variables:
        var_list = tf.trainable_variables()
        # Variables whose name contains '/wpe' are skipped when n_ctx differs
        # from 1024 (presumably the position-embedding table, whose shape
        # depends on n_ctx -- confirm).
        if params['n_ctx'] != 1024:
          var_list = [x for x in var_list if '/wpe' not in x.name]
      else:
        var_list = tf.global_variables()
      saver = tf.train.Saver(var_list=var_list, restore_sequentially=True)
      tf.logging.info('Restoring %s', ckpt)
      for x in var_list:
        tf.logging.info('\t%s', repr(x))
      saver.restore(self.sess, ckpt)
      tf.logging.info('Restoring %s (done)', ckpt)
  self.cur_step = tflex.run(self.sess, self.global_step)

  # Complete infeed graph generation and session.run calls
  self.infeed_thread = threading.Thread(target=infeed_thread_fn, daemon=True)
  self.infeed_thread.start()