# Example 1
def train_tpu(params, should_eval=False):
    """Run the TPU training loop described by `params`.

    Builds the training graph (model + on-device while-loop sharded over
    replicas), optionally prepares an in-training eval function, restores
    or initializes variables, then repeatedly runs `train_op` — each run
    executes `params.save_every` steps on-device — while infeed/outfeed/
    checkpoint threads run alongside, until `params.num_train_steps` is
    reached (or eval reports a weak result).

    Args:
        params: experiment config object (hyperparameters, paths, TPU
            topology); updated in place by `set_tpu_info`.
        should_eval: if True, run eval after every training chunk and
            track the best top-1 accuracy, written to `<output_dir>/acc`.
    """
    set_tpu_info(params)
    train_graph = tf.Graph()

    # Infeed ops come with their own graphs; they are executed by a
    # dedicated InfeedThread, not inside `train_graph`.
    infeed_ops, infeed_graphs = data_utils.build_train_infeeds(params)

    with train_graph.as_default():
        model = get_model_builder(params)
        # Pick the training algorithm from the dataset-name suffix:
        # Meta Pseudo Labels, UDA, or plain supervised training.
        if 'mpl' in params.dataset_name.lower():
            train_class = training_utils.MPL()
        elif 'uda' in params.dataset_name.lower():
            train_class = training_utils.UDA()
        else:
            train_class = training_utils.Supervised()

        @tpu_function.on_device_training_loop
        def train_loop():
            """On-device loop: run `params.save_every` training steps
            per host-side `sess.run(train_op)` call."""
            def _cond(step):
                return tf.less(step, tf.cast(params.save_every, step.dtype))

            def _body(step):
                run_op = train_class.step_fn(params, model)
                # Serialize: advance the counter only after the step ran.
                with tf.control_dependencies([run_op]):
                    return step + 1

            loop_inps = [tf.cast(0, tf.int32)]
            loop_outs = tf.while_loop(_cond,
                                      _body,
                                      loop_inps,
                                      parallel_iterations=1,
                                      name='train')
            train_op = loop_outs.op
            return train_op

        # Replicate the loop across all TPU cores.
        train_op = tf.tpu.shard(computation=train_loop,
                                num_shards=params.num_replicas,
                                device_assignment=params.device_assignment)
        global_step = tf.train.get_or_create_global_step()
        # Used when warm-starting from a checkpoint but restarting the
        # step counter (see `load_checkpoint_and_restart_global_step`).
        reset_global_step = tf.assign(global_step, 0)

        num_params = common_utils.count_params()
        logging.info(f'Model has {num_params} params')

        if should_eval:
            eval_fn, eval_summary_writer = prepare_eval(
                params=params,
                model=model,
                eval_logdir=f'eval_{params.image_size}',
                while_training=True)
            best_acc = -1.

        # Dump the training graph for debugging/inspection.
        tf.io.write_graph(train_graph,
                          params.output_dir,
                          'train.pbtxt',
                          as_text=True)

        # outfeed_dequeue_op: host-side ops that drain what the training
        # step enqueues (e.g. metrics), run by OutfeedThread below.
        outfeed_signature = train_class.outfeed_signature()
        outfeed_ops, outfeed_graph = common_utils.get_outfeed_ops(
            params, outfeed_signature)

        # saver: keep only the latest checkpoint when eval tracks the
        # best model; keep everything otherwise (max_to_keep=None).
        max_to_keep = 1 if should_eval else None
        saver = common_utils.get_saver(max_to_keep=max_to_keep)
        ckpt_dir = os.path.join(params.output_dir, 'ckpt_last')
        async_checkpoint = common_utils.AsyncCheckpoint(
            saver, ckpt_dir, max_to_keep=max_to_keep)

        if params.load_checkpoint is not None:
            # Partial warm-start: restore only the variables whose
            # (stripped) names exist in the given checkpoint; log the rest.
            reader = tf.train.NewCheckpointReader(params.load_checkpoint)
            var_to_shape_map = reader.get_variable_to_shape_map()
            var_list = {}
            global_variables = tf.global_variables()
            for v in global_variables:
                v_name = common_utils.strip_var_name(v.name)
                if v_name in var_to_shape_map:
                    var_list[v_name] = v
                else:
                    logging.info(
                        f'NOT FOUND: v.name={v.name:<100} v_name={v_name}')
            logging.info(
                f'Found {len(var_list)} variables (out of {len(global_variables)})'
                f' in {params.load_checkpoint}')
            restore_saver = tf.train.Saver(var_list=var_list)
        else:
            restore_saver = saver

        # actually run
        tpu_init = tf.tpu.initialize_system()
        var_init = tf.global_variables_initializer()
        tpu_shutdown = tf.tpu.shutdown_system()
        with common_utils.get_session(params) as sess:
            logging.info('Initialize TPU system')
            sess.run(tpu_init)

            run_options = tf.RunOptions(
                timeout_in_ms=1000 * 60 * 60 * 10,  # 10 hours
                report_tensor_allocations_upon_oom=True)
            sess.run(var_init, options=run_options)

            # Variable init precedence: resume from our own latest
            # checkpoint > warm-start from `load_checkpoint` > scratch.
            latest_checkpoint = common_utils.get_latest_checkpoint(ckpt_dir)
            if latest_checkpoint is not None:
                logging.info(f'Initialize vars from `{latest_checkpoint}`')
                saver.restore(sess, latest_checkpoint)
            elif params.load_checkpoint is not None:
                logging.info(
                    f'Initialize vars from `{params.load_checkpoint}`')
                restore_saver.restore(sess, params.load_checkpoint)
                if params.load_checkpoint_and_restart_global_step:
                    logging.info('Reset global_step to 0')
                    sess.run(reset_global_step)
            else:
                logging.info('Initialize vars from scratch')

            infeed_thread = common_utils.InfeedThread(
                params=params,
                infeed_ops=infeed_ops,
                infeed_graphs=infeed_graphs,
                name='train_infeed')
            outfeed_thread = common_utils.OutfeedThread(
                params, outfeed_ops, outfeed_graph, outfeed_signature)
            outfeed_thread.start()

            logging.info('Start training')
            while True:
                step = sess.run(global_step)
                if step >= params.num_train_steps:
                    break

                # One iteration == `save_every` on-device steps. The
                # infeed thread is started/joined per chunk; presumably
                # InfeedThread supports repeated start/join cycles —
                # TODO(review): confirm against common_utils.
                infeed_thread.start()
                sess.run(train_op, options=run_options)
                step = sess.run(global_step)
                async_checkpoint.save(sess, step)
                infeed_thread.join()

                if should_eval:
                    weak_result, acc = eval_fn(sess, step)
                    # Early-stop on weak results (skipped on local dev runs).
                    if weak_result and not params.running_local_dev:
                        logging.info('Weak results. Stop training')
                        break
                    if best_acc < acc:
                        best_acc = acc

            logging.info(
                'Wait for [infeed,outfeed,eval,checkpoint]_thread to stop')
            async_checkpoint.join()
            infeed_thread.stop()
            outfeed_thread.join()
            if should_eval:
                eval_summary_writer.close()
                # Persist the best accuracy seen during training.
                with gfile.GFile(os.path.join(params.output_dir, 'acc'),
                                 'w') as fout:
                    fout.write(f'{best_acc:<.10f}')

            logging.info('Shut down TPU system.')
            sess.run(tpu_shutdown)
# Example 2
def prepare_eval(params, model, eval_logdir, while_training=False):
    """Build the eval graph and return `(eval_fn, summary_writer)`.

    Creates eval infeeds, a sharded on-device loop that accumulates
    logits/labels/masks over the whole eval set, and a fresh TensorBoard
    log directory. The returned `eval_fn(sess, step)` runs one full eval
    pass, writes top-1/top-5 summaries, and returns `(weak_result, top_1)`.

    Args:
        params: experiment config (batch sizes, replica count, paths).
        model: model builder consumed by `training_utils.eval_step_fn`.
        eval_logdir: subdirectory name under `<output_dir>/logs`.
        while_training: if True, build the eval computation inside the
            reusable 'ema' variable scope (presumably to evaluate with
            EMA shadow variables — TODO confirm against training_utils).

    Returns:
        (eval_fn, summary_writer) as described above.
    """

    (eval_infeed_ops, eval_infeed_graphs,
     eval_size) = data_utils.build_eval_infeeds(params)
    eval_infeed_thread = common_utils.InfeedThread(
        params=params,
        infeed_ops=eval_infeed_ops,
        infeed_graphs=eval_infeed_graphs,
        name='eval_infeed')
    # Integer division: a final partial batch, if any, is not evaluated.
    num_eval_steps = eval_size // params.eval_batch_size

    logging.info(f'Each eval will run for {num_eval_steps} steps')

    def eval_loop():
        """On-device loop collecting per-step outputs into TensorArrays."""
        def _cond(step, *args):  # pylint: disable=unused-argument
            return tf.less(step, tf.cast(num_eval_steps, step.dtype))

        def _body(step, *args):
            outs = training_utils.eval_step_fn(params, model)
            # Write each output (logits, labels, masks) at index `step`.
            new_args = [step + 1
                        ] + [a.write(step, o) for a, o in zip(args, outs)]
            return tuple(new_args)

        # Per-replica batch size. (Fix: the original had a duplicated
        # `batch_size = batch_size = ...` assignment.)
        batch_size = params.eval_batch_size // params.num_replicas
        num_classes = params.num_classes
        logits = tf.TensorArray(dtype=tf.float32,
                                size=num_eval_steps,
                                element_shape=[batch_size, num_classes])
        labels = tf.TensorArray(dtype=tf.int32,
                                size=num_eval_steps,
                                element_shape=[batch_size, 1])
        masks = tf.TensorArray(dtype=tf.float32,
                               size=num_eval_steps,
                               element_shape=[batch_size])
        loop_inps = [0, logits, labels, masks]
        loop_outs = tf.while_loop(_cond,
                                  _body,
                                  loop_inps,
                                  parallel_iterations=1,
                                  name='eval')
        # Drop the step counter; flatten each TensorArray along step axis.
        return [o.concat() for o in loop_outs[1:]]

    if while_training:
        with tf.variable_scope('ema', reuse=True):
            eval_op = tf.tpu.shard(computation=eval_loop,
                                   num_shards=params.num_replicas,
                                   device_assignment=params.device_assignment)
    else:
        eval_op = tf.tpu.shard(computation=eval_loop,
                               num_shards=params.num_replicas,
                               device_assignment=params.device_assignment)

    eval_logdir = os.path.join(params.output_dir, 'logs', eval_logdir)
    # Start from an empty log dir. Fix: the original only called
    # MakeDirs inside the IsDirectory branch, so a brand-new logdir was
    # never explicitly created; create it unconditionally instead.
    if gfile.IsDirectory(eval_logdir):
        gfile.DeleteRecursively(eval_logdir)
    gfile.MakeDirs(eval_logdir, mode=0o777)
    summary_writer = tf.summary.FileWriter(eval_logdir)

    def eval_fn(sess, step):
        """Run one full eval pass; return `(weak_result, top_1)`."""
        eval_infeed_thread.start()
        logits, labels, mask = sess.run(eval_op)

        # `mask` zeroes out padding examples, so this is the true count.
        num_examples = np.sum(mask)
        sorted_indices = np.argsort(logits, axis=-1)

        def _top_k(k):
            # True where the label is among the k highest logits.
            in_top_k = np.any(sorted_indices[:, -k:] == labels, axis=-1)
            total = np.sum(in_top_k.astype(np.float32) * mask)
            return total / num_examples

        top_1, top_5 = _top_k(k=1), _top_k(k=5)
        # Compress the x-axis when evaluating forever on a fixed model.
        tb_step = step // 1000 if params.task_mode == 'eval_forever' else step
        summary_writer.add_summary(
            tf.Summary(value=[
                tf.Summary.Value(tag='eval/top_1', simple_value=top_1),
                tf.Summary.Value(tag='eval/top_5', simple_value=top_5),
            ]), tb_step)
        summary_writer.flush()
        log_string = ' '.join([
            f'step={step:<8d}',
            f'total={int(num_examples):<6d}',
            f'top_1={top_1:<8.6f}',
            f'top_5={top_5:<8.6f}',
        ])
        logging.info(log_string)

        # NOTE(review): weak-result detection is not implemented here;
        # callers always receive False.
        weak_result = False
        eval_infeed_thread.join()

        return weak_result, top_1

    return eval_fn, summary_writer