def train_tpu(params, should_eval=False):
  """Runs the TPU training loop.

  Builds the training graph (MPL / UDA / supervised, chosen from
  `params.dataset_name`), shards an on-device training loop across
  `params.num_replicas`, then repeatedly runs `params.save_every` steps per
  session call until `params.num_train_steps`, checkpointing after each chunk
  and optionally evaluating in between.

  Args:
    params: experiment config object (hyperparameters, paths, TPU topology).
    should_eval: if True, run `prepare_eval` after each training chunk, keep
      only the best checkpoint (`max_to_keep=1`), and write the best accuracy
      to `<output_dir>/acc`.
  """
  set_tpu_info(params)
  train_graph = tf.Graph()

  # Build the train_op: infeeds + the sharded on-device training loop.
  infeed_ops, infeed_graphs = data_utils.build_train_infeeds(params)
  with train_graph.as_default():
    model = get_model_builder(params)
    # Pick the training strategy from the dataset name. NOTE(review): the
    # substring match means a name containing both 'mpl' and 'uda' resolves
    # to MPL — presumably intended; confirm against dataset naming.
    if 'mpl' in params.dataset_name.lower():
      train_class = training_utils.MPL()
    elif 'uda' in params.dataset_name.lower():
      train_class = training_utils.UDA()
    else:
      train_class = training_utils.Supervised()

    @tpu_function.on_device_training_loop
    def train_loop():
      """Runs `params.save_every` training steps entirely on-device."""

      def _cond(step):
        # Loop until `save_every` steps have run; control returns to the
        # host afterwards (for checkpointing / eval).
        return tf.less(step, tf.cast(params.save_every, step.dtype))

      def _body(step):
        run_op = train_class.step_fn(params, model)
        # Serialize: the step must finish before the counter increments.
        with tf.control_dependencies([run_op]):
          return step + 1

      loop_inps = [tf.cast(0, tf.int32)]
      loop_outs = tf.while_loop(_cond, _body, loop_inps,
                                parallel_iterations=1, name='train')
      train_op = loop_outs.op
      return train_op

    train_op = tf.tpu.shard(computation=train_loop,
                            num_shards=params.num_replicas,
                            device_assignment=params.device_assignment)
    global_step = tf.train.get_or_create_global_step()
    # Used when restarting from a warm-start checkpoint (see below).
    reset_global_step = tf.assign(global_step, 0)
    num_params = common_utils.count_params()
    logging.info(f'Model has {num_params} params')
    if should_eval:
      # `best_acc` is only defined on this path; the code below only reads
      # it under the same `should_eval` flag.
      eval_fn, eval_summary_writer = prepare_eval(
          params=params,
          model=model,
          eval_logdir=f'eval_{params.image_size}',
          while_training=True)
      best_acc = -1.
    tf.io.write_graph(train_graph, params.output_dir, 'train.pbtxt',
                      as_text=True)

    # Build the outfeed_dequeue_op so device-side outfeeds are drained.
    outfeed_signature = train_class.outfeed_signature()
    outfeed_ops, outfeed_graph = common_utils.get_outfeed_ops(
        params, outfeed_signature)

    # Savers: keep only the best checkpoint when evaluating, all otherwise.
    max_to_keep = 1 if should_eval else None
    saver = common_utils.get_saver(max_to_keep=max_to_keep)
    ckpt_dir = os.path.join(params.output_dir, 'ckpt_last')
    async_checkpoint = common_utils.AsyncCheckpoint(
        saver, ckpt_dir, max_to_keep=max_to_keep)

    if params.load_checkpoint is not None:
      # Warm start: restore only the variables that exist in the given
      # checkpoint, matching by stripped name; log the ones not found.
      reader = tf.train.NewCheckpointReader(params.load_checkpoint)
      var_to_shape_map = reader.get_variable_to_shape_map()
      var_list = {}
      global_variables = tf.global_variables()
      for v in global_variables:
        v_name = common_utils.strip_var_name(v.name)
        if v_name in var_to_shape_map:
          var_list[v_name] = v
        else:
          logging.info(f'NOT FOUND: v.name={v.name:<100} v_name={v_name}')
      logging.info(
          f'Found {len(var_list)} variables (out of {len(global_variables)})'
          f' in {params.load_checkpoint}')
      restore_saver = tf.train.Saver(var_list=var_list)
    else:
      restore_saver = saver

    # Actually run: init TPU, init/restore variables, then train.
    tpu_init = tf.tpu.initialize_system()
    var_init = tf.global_variables_initializer()
    tpu_shutdown = tf.tpu.shutdown_system()
    with common_utils.get_session(params) as sess:
      logging.info('Initialize TPU system')
      sess.run(tpu_init)
      run_options = tf.RunOptions(
          timeout_in_ms=1000 * 60 * 60 * 10,  # 10 hours
          report_tensor_allocations_upon_oom=True)
      sess.run(var_init, options=run_options)
      # Restore priority: resume from our own latest checkpoint first, then
      # fall back to the warm-start checkpoint, else keep fresh init.
      latest_checkpoint = common_utils.get_latest_checkpoint(ckpt_dir)
      if latest_checkpoint is not None:
        logging.info(f'Initialize vars from `{latest_checkpoint}`')
        saver.restore(sess, latest_checkpoint)
      elif params.load_checkpoint is not None:
        logging.info(f'Initialize vars from `{params.load_checkpoint}`')
        restore_saver.restore(sess, params.load_checkpoint)
        if params.load_checkpoint_and_restart_global_step:
          logging.info('Reset global_step to 0')
          sess.run(reset_global_step)
      else:
        logging.info('Initialize vars from scratch')

      infeed_thread = common_utils.InfeedThread(
          params=params,
          infeed_ops=infeed_ops,
          infeed_graphs=infeed_graphs,
          name='train_infeed')
      outfeed_thread = common_utils.OutfeedThread(
          params, outfeed_ops, outfeed_graph, outfeed_signature)
      outfeed_thread.start()

      logging.info('Start training')
      while True:
        step = sess.run(global_step)
        if step >= params.num_train_steps:
          break
        # NOTE(review): `start()` is called once per chunk — assumes
        # InfeedThread is restartable after `join()`; confirm in
        # common_utils.
        infeed_thread.start()
        # Each run executes `save_every` on-device steps (see train_loop).
        sess.run(train_op, options=run_options)
        step = sess.run(global_step)
        async_checkpoint.save(sess, step)
        infeed_thread.join()
        if should_eval:
          weak_result, acc = eval_fn(sess, step)
          if weak_result and not params.running_local_dev:
            logging.info('Weak results. Stop training')
            break
          if best_acc < acc:
            best_acc = acc

      logging.info('Wait for [infeed,outfeed,eval,checkpoint]_thread to stop')
      async_checkpoint.join()
      infeed_thread.stop()
      outfeed_thread.join()
      if should_eval:
        eval_summary_writer.close()
        # Persist the best accuracy for downstream tooling.
        with gfile.GFile(os.path.join(params.output_dir, 'acc'), 'w') as fout:
          fout.write(f'{best_acc:<.10f}')
      logging.info('Shut down TPU system.')
      sess.run(tpu_shutdown)
def prepare_eval(params, model, eval_logdir, while_training=False):
  """Builds the eval graph and returns an `eval_fn` plus its summary writer.

  Constructs eval infeeds, a sharded on-device loop that accumulates logits,
  labels and masks across `num_eval_steps`, and a TensorBoard writer rooted
  at `<output_dir>/logs/<eval_logdir>` (any existing directory is wiped).

  Args:
    params: experiment config object (batch sizes, TPU topology, paths).
    model: model-builder callable passed through to the eval step fn.
    eval_logdir: leaf directory name for TensorBoard event files.
    while_training: if True, build the eval computation inside the 'ema'
      variable scope with `reuse=True` so it reads the EMA shadow variables
      created by the training graph.

  Returns:
    A pair `(eval_fn, summary_writer)`. `eval_fn(sess, step)` runs one full
    evaluation pass and returns `(weak_result, top_1)`.
  """
  (eval_infeed_ops, eval_infeed_graphs,
   eval_size) = data_utils.build_eval_infeeds(params)
  eval_infeed_thread = common_utils.InfeedThread(
      params=params,
      infeed_ops=eval_infeed_ops,
      infeed_graphs=eval_infeed_graphs,
      name='eval_infeed')
  # Floor division: a final partial batch (if any) is not evaluated.
  num_eval_steps = eval_size // params.eval_batch_size
  logging.info(f'Each eval will run for {num_eval_steps} steps')

  def eval_loop():
    """On-device loop collecting per-step logits/labels/masks."""

    def _cond(step, *args):  # pylint: disable=unused-argument
      return tf.less(step, tf.cast(num_eval_steps, step.dtype))

    def _body(step, *args):
      outs = training_utils.eval_step_fn(params, model)
      # Write this step's outputs into the matching TensorArrays.
      new_args = [step + 1] + [a.write(step, o)
                               for a, o in zip(args, outs)]
      return tuple(new_args)

    # Fixed: the original read `batch_size = batch_size = ...` — a
    # duplicated assignment with no effect.
    batch_size = params.eval_batch_size // params.num_replicas
    num_classes = params.num_classes
    logits = tf.TensorArray(dtype=tf.float32, size=num_eval_steps,
                            element_shape=[batch_size, num_classes])
    labels = tf.TensorArray(dtype=tf.int32, size=num_eval_steps,
                            element_shape=[batch_size, 1])
    masks = tf.TensorArray(dtype=tf.float32, size=num_eval_steps,
                           element_shape=[batch_size])
    loop_inps = [0, logits, labels, masks]
    loop_outs = tf.while_loop(_cond, _body, loop_inps,
                              parallel_iterations=1, name='eval')
    # Drop the step counter; concatenate each TensorArray along axis 0.
    return [o.concat() for o in loop_outs[1:]]

  if while_training:
    # Read EMA shadow variables created by the training graph.
    with tf.variable_scope('ema', reuse=True):
      eval_op = tf.tpu.shard(computation=eval_loop,
                             num_shards=params.num_replicas,
                             device_assignment=params.device_assignment)
  else:
    eval_op = tf.tpu.shard(computation=eval_loop,
                           num_shards=params.num_replicas,
                           device_assignment=params.device_assignment)

  eval_logdir = os.path.join(params.output_dir, 'logs', eval_logdir)
  if gfile.IsDirectory(eval_logdir):
    gfile.DeleteRecursively(eval_logdir)
  gfile.MakeDirs(eval_logdir, mode=0o777)
  summary_writer = tf.summary.FileWriter(eval_logdir)

  def eval_fn(sess, step):
    """Runs one evaluation pass; returns `(weak_result, top_1)`."""
    eval_infeed_thread.start()
    logits, labels, mask = sess.run(eval_op)
    # `mask` zeroes out padding examples; use it for both counting and
    # accuracy so padded rows never contribute.
    num_examples = np.sum(mask)
    sorted_indices = np.argsort(logits, axis=-1)

    def _top_k(k):
      # labels has shape [N, 1] and broadcasts against the top-k indices.
      in_top_k = np.any(sorted_indices[:, -k:] == labels, axis=-1)
      total = np.sum(in_top_k.astype(np.float32) * mask)
      return total / num_examples

    top_1, top_5 = _top_k(k=1), _top_k(k=5)
    # In `eval_forever` mode, compress the TensorBoard x-axis.
    tb_step = step // 1000 if params.task_mode == 'eval_forever' else step
    summary_writer.add_summary(
        tf.Summary(value=[
            tf.Summary.Value(tag='eval/top_1', simple_value=top_1),
            tf.Summary.Value(tag='eval/top_5', simple_value=top_5),
        ]), tb_step)
    summary_writer.flush()
    log_string = ' '.join([
        f'step={step:<8d}',
        f'total={int(num_examples):<6d}',
        f'top_1={top_1:<8.6f}',
        f'top_5={top_5:<8.6f}',
    ])
    logging.info(log_string)
    # NOTE(review): `weak_result` is hardwired False, so the early-stop
    # branch in the caller never fires — presumably a disabled heuristic.
    weak_result = False
    eval_infeed_thread.join()
    return weak_result, top_1

  return eval_fn, summary_writer