Example #1
0
def eval_tpu(params):
    """Evaluate every checkpoint in `ckpt_last` and export the best one.

    Restores, in list order, each checkpoint recorded in the checkpoint state
    of `<output_dir>/ckpt_last`, evaluates it with the function returned by
    `prepare_eval`, and tracks the checkpoint with the highest accuracy. The
    best accuracy is written to `<output_dir>/ckpt_best/best_acc` and the
    corresponding weights are re-saved under `<output_dir>/ckpt_best/ckpt`.

    Args:
      params: hyper-parameter container. Reads `output_dir`, `image_size`
        and, when present, `eval_image_size` (the larger of the two image
        sizes names the eval log directory).

    Raises:
      ValueError: if `<output_dir>/ckpt_last` lists no checkpoints, or if no
        checkpoint produced an accuracy greater than -1 (e.g. all were NaN).
    """
    set_tpu_info(params)
    tf.reset_default_graph()

    # Look up the checkpoints first. With none there is nothing to evaluate,
    # and failing here avoids creating summary writers or touching the TPU.
    ckpt_dir = os.path.join(params.output_dir, 'ckpt_last')
    ckpt_state = tf.train.get_checkpoint_state(ckpt_dir)
    if ckpt_state is None or not ckpt_state.all_model_checkpoint_paths:
        raise ValueError(f'No checkpoints found in `{ckpt_dir}`')
    all_checkpoints = list(ckpt_state.all_model_checkpoint_paths)

    model = get_model_builder(params)
    if 'eval_image_size' in params:
        image_size = max(params.image_size, params.eval_image_size)
    else:
        image_size = params.image_size
    eval_fn, eval_summary_writer = prepare_eval(
        params=params,
        model=model,
        eval_logdir=f'eval_{image_size}_all',
        while_training=False)
    global_step = tf.train.get_or_create_global_step()

    # NOTE(review): restore_ema=True presumably restores EMA shadow values
    # into the model variables -- confirm against common_utils.get_saver.
    saver = common_utils.get_saver(max_to_keep=None, restore_ema=True)

    # Best-checkpoint bookkeeping.
    best_acc = -1.
    best_acc_ckpt_name = None
    best_acc_path = os.path.join(params.output_dir, 'ckpt_best')
    best_acc_file = os.path.join(best_acc_path, 'best_acc')
    if not gfile.IsDirectory(best_acc_path):
        gfile.MakeDirs(best_acc_path)

    # actually run
    tpu_init = tf.tpu.initialize_system()
    tpu_shutdown = tf.tpu.shutdown_system()
    with common_utils.get_session(params) as sess:
        logging.info('Initialize TPU system')
        sess.run(tpu_init)
        try:
            logging.info('Start eval')
            for ckpt_name in all_checkpoints:
                saver.restore(sess, ckpt_name)
                step = sess.run(global_step)
                _, curr_acc = eval_fn(sess, step)
                # Strict `<` keeps the first best checkpoint in list order on
                # ties, and never selects a NaN accuracy.
                if best_acc < curr_acc:
                    best_acc = curr_acc
                    best_acc_ckpt_name = ckpt_name

            if best_acc_ckpt_name is None:
                raise ValueError(
                    f'No checkpoint in `{ckpt_dir}` produced a valid accuracy')

            with gfile.GFile(best_acc_file, 'w') as fout:
                fout.write(f'{best_acc:<6.4f}')

            # Export the best weights while the TPU system is still
            # initialized; shutdown happens last, in `finally`.
            saver.restore(sess, best_acc_ckpt_name)
            best_ckpt_path = saver.save(sess,
                                        save_path=os.path.join(
                                            best_acc_path, 'ckpt'),
                                        write_meta_graph=False,
                                        write_state=False)
            logging.info(f'Saved best_ckpt `{best_ckpt_path}`')
        finally:
            # Always release the TPU system and flush eval summaries, even
            # when evaluation or export fails part-way.
            logging.info('Shut down TPU system.')
            sess.run(tpu_shutdown)
            eval_summary_writer.close()
Example #2
0
def train_tpu(params, should_eval=False):
    """Train a model on TPU, checkpointing (and optionally evaluating) as it goes.

    Builds a training graph whose on-device loop runs `train_class.step_fn`
    `params.save_every` times per `sess.run(train_op)`. That loop is sharded
    over `params.num_replicas` replicas and run repeatedly until
    `global_step` reaches `params.num_train_steps`. After every outer
    iteration the state is checkpointed via `common_utils.AsyncCheckpoint`
    into `<output_dir>/ckpt_last`.

    The training algorithm is selected from `params.dataset_name`:
      * names containing 'mpl' use `training_utils.MPL`;
      * names containing 'uda' use `training_utils.UDA`;
      * everything else uses `training_utils.Supervised`.

    Variables are initialized in this priority order:
      1. from the latest checkpoint in `<output_dir>/ckpt_last` (resume);
      2. otherwise from `params.load_checkpoint` (restoring only variables
         whose stripped names exist there);
      3. otherwise from scratch.

    The graph is also written to `<output_dir>/train.pbtxt`.

    Args:
      params: hyper-parameter container. Read here for `dataset_name`,
        `save_every`, `num_replicas`, `device_assignment`, `output_dir`,
        `load_checkpoint`, `load_checkpoint_and_restart_global_step`,
        `num_train_steps`, `image_size` and `running_local_dev`.
      should_eval: if True:
        * evaluate after every outer training iteration;
        * stop early when `eval_fn` reports a weak result, unless
          `params.running_local_dev` is set;
        * pass `max_to_keep=1` to the savers;
        * write the best accuracy seen to `<output_dir>/acc`.
    """
    set_tpu_info(params)
    train_graph = tf.Graph()

    # train_op
    # Infeed ops are built before entering `train_graph`, together with the
    # graphs they belong to (`infeed_graphs`).
    # NOTE(review): presumably run by InfeedThread in their own graphs;
    # confirm in data_utils.
    infeed_ops, infeed_graphs = data_utils.build_train_infeeds(params)

    with train_graph.as_default():
        model = get_model_builder(params)
        # Pick the training algorithm from a substring of the dataset name;
        # 'mpl' is checked before 'uda'.
        if 'mpl' in params.dataset_name.lower():
            train_class = training_utils.MPL()
        elif 'uda' in params.dataset_name.lower():
            train_class = training_utils.UDA()
        else:
            train_class = training_utils.Supervised()

        @tpu_function.on_device_training_loop
        def train_loop():
            """Run `train_class.step_fn` `params.save_every` times on device.

            Returns:
              The op of the `tf.while_loop`. Running it executes one full
              inner loop of `params.save_every` training steps.
            """
            def _cond(step):
                return tf.less(step, tf.cast(params.save_every, step.dtype))

            def _body(step):
                run_op = train_class.step_fn(params, model)
                # Force the training step to run before the counter advances.
                with tf.control_dependencies([run_op]):
                    return step + 1

            # Local loop counter starting at 0; separate from `global_step`.
            loop_inps = [tf.cast(0, tf.int32)]
            # parallel_iterations=1: iterations run strictly one at a time.
            loop_outs = tf.while_loop(_cond,
                                      _body,
                                      loop_inps,
                                      parallel_iterations=1,
                                      name='train')
            train_op = loop_outs.op
            return train_op

        # Replicate the on-device training loop over all replicas.
        train_op = tf.tpu.shard(computation=train_loop,
                                num_shards=params.num_replicas,
                                device_assignment=params.device_assignment)
        global_step = tf.train.get_or_create_global_step()
        # Only run when warm-starting from `params.load_checkpoint` with
        # `params.load_checkpoint_and_restart_global_step` set.
        reset_global_step = tf.assign(global_step, 0)

        num_params = common_utils.count_params()
        logging.info(f'Model has {num_params} params')

        if should_eval:
            eval_fn, eval_summary_writer = prepare_eval(
                params=params,
                model=model,
                eval_logdir=f'eval_{params.image_size}',
                while_training=True)
            # Best accuracy seen so far across evaluations.
            best_acc = -1.

        # Write the graph as built up to this point for inspection.
        tf.io.write_graph(train_graph,
                          params.output_dir,
                          'train.pbtxt',
                          as_text=True)

        # outfeed_dequeue_op
        outfeed_signature = train_class.outfeed_signature()
        outfeed_ops, outfeed_graph = common_utils.get_outfeed_ops(
            params, outfeed_signature)

        # saver
        # With in-training eval, max_to_keep=1 is passed to both the saver and
        # the async checkpointer; otherwise max_to_keep=None.
        max_to_keep = 1 if should_eval else None
        saver = common_utils.get_saver(max_to_keep=max_to_keep)
        ckpt_dir = os.path.join(params.output_dir, 'ckpt_last')
        async_checkpoint = common_utils.AsyncCheckpoint(
            saver, ckpt_dir, max_to_keep=max_to_keep)

        # Optional warm start: build a saver over only those global variables
        # whose stripped names exist in `params.load_checkpoint`. Missing ones
        # are logged and keep the values set by `var_init` below.
        if params.load_checkpoint is not None:
            reader = tf.train.NewCheckpointReader(params.load_checkpoint)
            var_to_shape_map = reader.get_variable_to_shape_map()
            var_list = {}
            global_variables = tf.global_variables()
            for v in global_variables:
                v_name = common_utils.strip_var_name(v.name)
                if v_name in var_to_shape_map:
                    var_list[v_name] = v
                else:
                    logging.info(
                        f'NOT FOUND: v.name={v.name:<100} v_name={v_name}')
            logging.info(
                f'Found {len(var_list)} variables (out of {len(global_variables)})'
                f' in {params.load_checkpoint}')
            restore_saver = tf.train.Saver(var_list=var_list)
        else:
            restore_saver = saver

        # actually run
        tpu_init = tf.tpu.initialize_system()
        var_init = tf.global_variables_initializer()
        tpu_shutdown = tf.tpu.shutdown_system()
        # NOTE(review): there is no try/finally below. If training or eval
        # raises, the worker threads are not joined/stopped and `tpu_shutdown`
        # is not run.
        with common_utils.get_session(params) as sess:
            logging.info('Initialize TPU system')
            sess.run(tpu_init)

            run_options = tf.RunOptions(
                timeout_in_ms=1000 * 60 * 60 * 10,  # 10 hours
                report_tensor_allocations_upon_oom=True)
            sess.run(var_init, options=run_options)

            # Restore priority: resume from `ckpt_last`, else warm start from
            # `params.load_checkpoint`, else keep the fresh initialization.
            latest_checkpoint = common_utils.get_latest_checkpoint(ckpt_dir)
            if latest_checkpoint is not None:
                logging.info(f'Initialize vars from `{latest_checkpoint}`')
                saver.restore(sess, latest_checkpoint)
            elif params.load_checkpoint is not None:
                logging.info(
                    f'Initialize vars from `{params.load_checkpoint}`')
                restore_saver.restore(sess, params.load_checkpoint)
                if params.load_checkpoint_and_restart_global_step:
                    logging.info('Reset global_step to 0')
                    sess.run(reset_global_step)
            else:
                logging.info('Initialize vars from scratch')

            infeed_thread = common_utils.InfeedThread(
                params=params,
                infeed_ops=infeed_ops,
                infeed_graphs=infeed_graphs,
                name='train_infeed')
            outfeed_thread = common_utils.OutfeedThread(
                params, outfeed_ops, outfeed_graph, outfeed_signature)
            outfeed_thread.start()

            logging.info('Start training')
            # Each iteration runs `train_op` once (`params.save_every` inner
            # steps), then checkpoints and optionally evaluates.
            # NOTE(review): termination relies on `train_class.step_fn`
            # advancing `global_step`; nothing in this function increments it.
            while True:
                step = sess.run(global_step)
                if step >= params.num_train_steps:
                    break

                # NOTE(review): `start()` is called on every iteration.
                # InfeedThread presumably supports being restarted (a plain
                # threading.Thread does not); confirm in common_utils.
                infeed_thread.start()
                sess.run(train_op, options=run_options)
                step = sess.run(global_step)
                async_checkpoint.save(sess, step)
                infeed_thread.join()

                if should_eval:
                    weak_result, acc = eval_fn(sess, step)
                    # Early stop on weak results, except in local dev runs.
                    if weak_result and not params.running_local_dev:
                        logging.info('Weak results. Stop training')
                        break
                    if best_acc < acc:
                        best_acc = acc

            logging.info(
                'Wait for [infeed,outfeed,eval,checkpoint]_thread to stop')
            async_checkpoint.join()
            infeed_thread.stop()
            outfeed_thread.join()
            if should_eval:
                eval_summary_writer.close()
                # Persist the best accuracy seen across all evaluations.
                with gfile.GFile(os.path.join(params.output_dir, 'acc'),
                                 'w') as fout:
                    fout.write(f'{best_acc:<.10f}')

            # Shut down only after all checkpoint and thread work is done.
            logging.info('Shut down TPU system.')
            sess.run(tpu_shutdown)