def eval_tpu(params):
  """Evaluate every checkpoint in `<output_dir>/ckpt_last`; persist the best.

  Restores each checkpoint listed in the directory's checkpoint state, runs
  `eval_fn` on it, tracks the highest accuracy, then writes that accuracy to
  `<output_dir>/ckpt_best/best_acc` and re-saves the winning checkpoint under
  `<output_dir>/ckpt_best/ckpt`.

  Args:
    params: experiment configuration; attributes read here include
      image_size, (optionally) eval_image_size, and output_dir.
  """
  set_tpu_info(params)
  tf.reset_default_graph()
  model = get_model_builder(params)

  # Eval may use a larger image size than training did.
  if 'eval_image_size' in params:
    image_size = max(params.image_size, params.eval_image_size)
  else:
    image_size = params.image_size
  eval_fn, eval_summary_writer = prepare_eval(
      params=params,
      model=model,
      eval_logdir=f'eval_{image_size}_all',
      while_training=False)
  global_step = tf.train.get_or_create_global_step()

  # saver: restore_ema=True evaluates the EMA shadow weights.
  saver = common_utils.get_saver(max_to_keep=None, restore_ema=True)
  ckpt_dir = os.path.join(params.output_dir, 'ckpt_last')

  # best-checkpoint bookkeeping
  best_acc = -1.
  best_acc_ckpt_name = None
  best_acc_path = os.path.join(params.output_dir, 'ckpt_best')
  best_acc_file = os.path.join(best_acc_path, 'best_acc')
  if not gfile.IsDirectory(best_acc_path):
    gfile.MakeDirs(best_acc_path)

  # actually run
  tpu_init = tf.tpu.initialize_system()
  tpu_shutdown = tf.tpu.shutdown_system()
  with common_utils.get_session(params) as sess:
    logging.info('Initialize TPU system')
    sess.run(tpu_init)

    # BUGFIX: get_checkpoint_state returns None when no checkpoint file
    # exists in ckpt_dir; the original code crashed on the attribute access.
    ckpt_state = tf.train.get_checkpoint_state(ckpt_dir)
    if ckpt_state is None:
      all_checkpoints = []
    else:
      all_checkpoints = ckpt_state.all_model_checkpoint_paths

    logging.info('Start eval')
    for ckpt_name in all_checkpoints:
      saver.restore(sess, ckpt_name)
      step = sess.run(global_step)
      _, curr_acc = eval_fn(sess, step)
      if best_acc < curr_acc:
        best_acc = curr_acc
        best_acc_ckpt_name = ckpt_name
    logging.info('Shut down TPU system.')
    sess.run(tpu_shutdown)
    eval_summary_writer.close()

    # BUGFIX: with zero evaluated checkpoints the original wrote a bogus
    # best_acc of -1 and then crashed in saver.restore(sess, None).
    if best_acc_ckpt_name is None:
      logging.info(f'No checkpoints evaluated in `{ckpt_dir}`; '
                   'skip saving best checkpoint')
      return

    with gfile.GFile(best_acc_file, 'w') as fout:
      fout.write(f'{best_acc:<6.4f}')
    saver.restore(sess, best_acc_ckpt_name)
    best_ckpt_path = saver.save(
        sess,
        save_path=os.path.join(best_acc_path, 'ckpt'),
        write_meta_graph=False,
        write_state=False)
    logging.info(f'Saved best_ckpt `{best_ckpt_path}`')
def train_tpu(params, should_eval=False):
  """Training routines.

  Builds a TPU training graph (host infeed -> sharded on-device training
  loop -> host outfeed), initializes variables by priority
  (resume from `ckpt_last` > warm-start from `params.load_checkpoint` >
  from scratch), then trains in chunks of `params.save_every` steps,
  checkpointing asynchronously after each chunk.

  Args:
    params: experiment configuration; attributes read here include
      dataset_name, save_every, num_replicas, device_assignment, image_size,
      output_dir, load_checkpoint, load_checkpoint_and_restart_global_step,
      num_train_steps, and running_local_dev.
    should_eval: if True, evaluate after every chunk, keep only the newest
      checkpoint, stop early on weak results, and write the best accuracy
      seen to `<output_dir>/acc`.
  """
  # NOTE(review): original indentation was lost in extraction; this
  # reconstruction keeps the whole build/run sequence inside
  # `train_graph.as_default()` so the saver/session see the right graph.
  set_tpu_info(params)
  train_graph = tf.Graph()

  # train_op: infeed queues are built first, outside of `train_graph`.
  infeed_ops, infeed_graphs = data_utils.build_train_infeeds(params)
  with train_graph.as_default():
    model = get_model_builder(params)
    # Pick the training objective from the dataset name.
    if 'mpl' in params.dataset_name.lower():
      train_class = training_utils.MPL()
    elif 'uda' in params.dataset_name.lower():
      train_class = training_utils.UDA()
    else:
      train_class = training_utils.Supervised()

    @tpu_function.on_device_training_loop
    def train_loop():
      """Runs `params.save_every` training steps entirely on device."""

      def _cond(step):
        return tf.less(step, tf.cast(params.save_every, step.dtype))

      def _body(step):
        run_op = train_class.step_fn(params, model)
        with tf.control_dependencies([run_op]):
          return step + 1

      loop_inps = [tf.cast(0, tf.int32)]
      # parallel_iterations=1 keeps the training steps strictly sequential.
      loop_outs = tf.while_loop(_cond, _body, loop_inps,
                                parallel_iterations=1, name='train')
      train_op = loop_outs.op
      return train_op

    train_op = tf.tpu.shard(computation=train_loop,
                            num_shards=params.num_replicas,
                            device_assignment=params.device_assignment)
    global_step = tf.train.get_or_create_global_step()
    reset_global_step = tf.assign(global_step, 0)
    num_params = common_utils.count_params()
    logging.info(f'Model has {num_params} params')

    if should_eval:
      eval_fn, eval_summary_writer = prepare_eval(
          params=params,
          model=model,
          eval_logdir=f'eval_{params.image_size}',
          while_training=True)
      best_acc = -1.

    tf.io.write_graph(train_graph, params.output_dir, 'train.pbtxt',
                      as_text=True)

    # outfeed_dequeue_op
    outfeed_signature = train_class.outfeed_signature()
    outfeed_ops, outfeed_graph = common_utils.get_outfeed_ops(
        params, outfeed_signature)

    # saver: keep only the newest checkpoint when eval tracks the best one.
    max_to_keep = 1 if should_eval else None
    saver = common_utils.get_saver(max_to_keep=max_to_keep)
    ckpt_dir = os.path.join(params.output_dir, 'ckpt_last')
    async_checkpoint = common_utils.AsyncCheckpoint(
        saver, ckpt_dir, max_to_keep=max_to_keep)

    if params.load_checkpoint is not None:
      # Warm-start saver: restores only the variables whose (stripped) names
      # actually exist in the given checkpoint; everything else is logged.
      reader = tf.train.NewCheckpointReader(params.load_checkpoint)
      var_to_shape_map = reader.get_variable_to_shape_map()
      var_list = {}
      global_variables = tf.global_variables()
      for v in global_variables:
        v_name = common_utils.strip_var_name(v.name)
        if v_name in var_to_shape_map:
          var_list[v_name] = v
        else:
          logging.info(
              f'NOT FOUND: v.name={v.name:<100} v_name={v_name}')
      logging.info(
          f'Found {len(var_list)} variables (out of {len(global_variables)})'
          f' in {params.load_checkpoint}')
      restore_saver = tf.train.Saver(var_list=var_list)
    else:
      restore_saver = saver

    # actually run
    tpu_init = tf.tpu.initialize_system()
    var_init = tf.global_variables_initializer()
    tpu_shutdown = tf.tpu.shutdown_system()
    with common_utils.get_session(params) as sess:
      logging.info('Initialize TPU system')
      sess.run(tpu_init)
      run_options = tf.RunOptions(
          timeout_in_ms=1000 * 60 * 60 * 10,  # 10 hours
          report_tensor_allocations_upon_oom=True)
      sess.run(var_init, options=run_options)

      # Initialization priority: resume > warm-start > from scratch.
      latest_checkpoint = common_utils.get_latest_checkpoint(ckpt_dir)
      if latest_checkpoint is not None:
        logging.info(f'Initialize vars from `{latest_checkpoint}`')
        saver.restore(sess, latest_checkpoint)
      elif params.load_checkpoint is not None:
        logging.info(
            f'Initialize vars from `{params.load_checkpoint}`')
        restore_saver.restore(sess, params.load_checkpoint)
        if params.load_checkpoint_and_restart_global_step:
          logging.info('Reset global_step to 0')
          sess.run(reset_global_step)
      else:
        logging.info('Initialize vars from scratch')

      infeed_thread = common_utils.InfeedThread(
          params=params,
          infeed_ops=infeed_ops,
          infeed_graphs=infeed_graphs,
          name='train_infeed')
      outfeed_thread = common_utils.OutfeedThread(
          params, outfeed_ops, outfeed_graph, outfeed_signature)
      outfeed_thread.start()

      logging.info('Start training')
      while True:
        step = sess.run(global_step)
        if step >= params.num_train_steps:
          break
        # The infeed thread feeds one chunk of data while `train_op` runs the
        # on-device loop that consumes it; join before checking eval results.
        infeed_thread.start()
        sess.run(train_op, options=run_options)
        step = sess.run(global_step)
        async_checkpoint.save(sess, step)
        infeed_thread.join()
        if should_eval:
          weak_result, acc = eval_fn(sess, step)
          if weak_result and not params.running_local_dev:
            logging.info('Weak results. Stop training')
            break
          if best_acc < acc:
            best_acc = acc

      logging.info('Wait for [infeed,outfeed,eval,checkpoint]_thread to stop')
      async_checkpoint.join()
      infeed_thread.stop()
      outfeed_thread.join()
      if should_eval:
        eval_summary_writer.close()
        with gfile.GFile(os.path.join(params.output_dir, 'acc'), 'w') as fout:
          fout.write(f'{best_acc:<.10f}')
      logging.info('Shut down TPU system.')
      sess.run(tpu_shutdown)