Пример #1
0
def evaluate(config, restore_path):
    """Restore a checkpoint and record its metrics on the evaluation subset.

    The "test" subset is used when ``config.DATASET_CLASS`` provides one,
    otherwise "validation". A per-batch summary is written only for the last
    batch; the accumulated metrics are written once at the end. Everything is
    written to ``<TENSORBOARD_DIR>/evaluate`` at the restored global step.

    Args:
        config: Experiment configuration (``DATASET_CLASS``, ``NETWORK_CLASS``,
            ``NETWORK``, ``IS_DEBUG`` and ``BATCH_SIZE`` are read).
        restore_path: Checkpoint prefix to restore. When ``None``, a file name
            is looked up in ``environment.CHECKPOINTS_DIR`` with
            ``executor.search_restore_filename``.

    Raises:
        FileNotFoundError: If ``"<restore_path>.index"`` does not exist. This
            is a subclass of ``Exception``, so existing handlers still match.
    """
    if restore_path is None:
        restore_file = executor.search_restore_filename(
            environment.CHECKPOINTS_DIR)
        restore_path = os.path.join(environment.CHECKPOINTS_DIR, restore_file)

    # restore_path is a checkpoint prefix; its presence is checked through the
    # accompanying ".index" file.
    if not os.path.exists("{}.index".format(restore_path)):
        raise FileNotFoundError(
            "restore file {} does not exist.".format(restore_path))

    print("restore_path:", restore_path)

    DatasetClass = config.DATASET_CLASS
    ModelClass = config.NETWORK_CLASS
    network_kwargs = {key.lower(): val for key, val in config.NETWORK.items()}

    subset = "test" if "test" in DatasetClass.available_subsets else "validation"

    validation_dataset = setup_dataset(config, subset, seed=0)

    graph = tf.Graph()
    with graph.as_default():
        # Constructor arguments depend on the network family; this mirrors
        # the construction done for training.
        if ModelClass.__module__.startswith("lmnet.networks.object_detection"):
            model = ModelClass(
                classes=validation_dataset.classes,
                num_max_boxes=validation_dataset.num_max_boxes,
                is_debug=config.IS_DEBUG,
                **network_kwargs,
            )
        elif ModelClass.__module__.startswith("lmnet.networks.segmentation"):
            model = ModelClass(
                classes=validation_dataset.classes,
                label_colors=validation_dataset.label_colors,
                is_debug=config.IS_DEBUG,
                **network_kwargs,
            )
        else:
            model = ModelClass(
                classes=validation_dataset.classes,
                is_debug=config.IS_DEBUG,
                **network_kwargs,
            )

        global_step = tf.Variable(0, name="global_step", trainable=False)
        is_training = tf.constant(False, name="is_training")

        images_placeholder, labels_placeholder = model.placeholders()

        output = model.inference(images_placeholder, is_training)

        metrics_ops_dict, metrics_update_op = model.metrics(
            output, labels_placeholder)
        model.summary(output, labels_placeholder)

        summary_op = tf.summary.merge_all()

        metrics_summary_op, metrics_placeholders = executor.prepare_metrics(
            metrics_ops_dict)

        init_op = tf.global_variables_initializer()
        # Metric accumulators live in local variables.
        reset_metrics_op = tf.local_variables_initializer()
        saver = tf.train.Saver(max_to_keep=None)

    session_config = None  # tf.ConfigProto(log_device_placement=True)
    # The session and the writer are released even if evaluation fails;
    # closing the writer also flushes any buffered summaries to disk.
    with tf.Session(graph=graph, config=session_config) as sess:
        sess.run([init_op, reset_metrics_op])

        validation_writer = tf.summary.FileWriter(environment.TENSORBOARD_DIR +
                                                  "/evaluate")
        try:
            saver.restore(sess, restore_path)

            # Summaries are tagged with the step stored in the checkpoint.
            last_step = sess.run(global_step)

            test_step_size = int(
                math.ceil(validation_dataset.num_per_epoch / config.BATCH_SIZE))
            print("test_step_size", test_step_size)

            for test_step in range(test_step_size):
                print("test_step", test_step)

                images, labels = validation_dataset.feed()
                feed_dict = {
                    images_placeholder: images,
                    labels_placeholder: labels,
                }

                # Summarize at only last step.
                if test_step == test_step_size - 1:
                    summary, _ = sess.run([summary_op, metrics_update_op],
                                          feed_dict=feed_dict)
                    validation_writer.add_summary(summary, last_step)
                else:
                    sess.run([metrics_update_op], feed_dict=feed_dict)

            # Feed the accumulated metric values into the metrics summary op.
            metrics_values = sess.run(list(metrics_ops_dict.values()))
            metrics_feed_dict = dict(zip(metrics_placeholders, metrics_values))
            metrics_summary, = sess.run(
                [metrics_summary_op],
                feed_dict=metrics_feed_dict,
            )
            validation_writer.add_summary(metrics_summary, last_step)
        finally:
            validation_writer.close()
Пример #2
0
def start_training(config):
    """Build the training graph described by ``config`` and run training.

    Supports single-process training and Horovod/MPI data-parallel training
    (``config.IS_DISTRIBUTION``). Only rank 0 writes TensorBoard summaries,
    restores pretrained weights or the latest checkpoint, and saves
    checkpoints. The variables are then broadcast from rank 0 to the other
    ranks.

    Settings read from ``config``:

    * ``NETWORK_CLASS`` / ``NETWORK``: the model class and its keyword
      arguments (keys are lower-cased).
    * ``DATASET.TRAIN_VALIDATION_SAVING_SIZE``: when > 0, a checkpoint is
      saved only when accuracy on the "train_validation_saving" subset
      improves.
    * ``IS_PRETRAIN`` / ``PRETRAIN_VARS`` / ``PRETRAIN_DIR`` /
      ``PRETRAIN_FILE``: partial restore of the variables whose names start
      with one of ``PRETRAIN_VARS``.
    * ``MAX_EPOCHS`` (takes precedence) or ``MAX_STEPS``: training length.
    * ``SUMMARISE_STEPS``, ``SAVE_STEPS``, ``TEST_STEPS``: intervals for
      training summaries, checkpointing and validation runs.

    Args:
        config: Experiment configuration object. Both attribute access and
            ``in`` membership tests are used on it.
    """
    if config.IS_DISTRIBUTION:
        import horovod.tensorflow as hvd
        # initialize Horovod.
        hvd.init()
        num_worker = hvd.size()
        rank = hvd.rank()
        # verify that MPI multi-threading is supported.
        assert hvd.mpi_threads_supported()
        # make sure MPI is not re-initialized.
        import mpi4py.rc
        mpi4py.rc.initialize = False
        from mpi4py import MPI
        comm = MPI.COMM_WORLD
        # check that Horovod and MPI agree on size and rank.
        assert num_worker == comm.Get_size()
        assert rank == comm.Get_rank()
    else:
        num_worker = 1
        rank = 0

    ModelClass = config.NETWORK_CLASS
    network_kwargs = {key.lower(): val for key, val in config.NETWORK.items()}

    # Checkpoints are gated on a held-out "train_validation_saving" subset
    # only when the dataset config requests a non-empty one.
    if "TRAIN_VALIDATION_SAVING_SIZE" in config.DATASET.keys():
        use_train_validation_saving = config.DATASET.TRAIN_VALIDATION_SAVING_SIZE > 0
    else:
        use_train_validation_saving = False

    if use_train_validation_saving:
        top_train_validation_saving_set_accuracy = 0

    # NOTE(review): ``rank`` is passed as setup_dataset's third positional
    # argument, presumably the shuffle seed -- confirm.
    train_dataset = setup_dataset(config, "train", rank)
    print("train dataset num:", train_dataset.num_per_epoch)

    if use_train_validation_saving:
        train_validation_saving_dataset = setup_dataset(
            config, "train_validation_saving", rank)
        print("train_validation_saving dataset num:",
              train_validation_saving_dataset.num_per_epoch)

    validation_dataset = setup_dataset(config, "validation", rank)
    print("validation dataset num:", validation_dataset.num_per_epoch)

    graph = tf.Graph()
    with graph.as_default():
        # Constructor arguments depend on the network family.
        if ModelClass.__module__.startswith("lmnet.networks.object_detection"):
            model = ModelClass(
                classes=train_dataset.classes,
                num_max_boxes=train_dataset.num_max_boxes,
                is_debug=config.IS_DEBUG,
                **network_kwargs,
            )
        elif ModelClass.__module__.startswith("lmnet.networks.segmentation"):
            model = ModelClass(
                classes=train_dataset.classes,
                label_colors=train_dataset.label_colors,
                is_debug=config.IS_DEBUG,
                **network_kwargs,
            )
        else:
            model = ModelClass(
                classes=train_dataset.classes,
                is_debug=config.IS_DEBUG,
                **network_kwargs,
            )

        global_step = tf.Variable(0, name="global_step", trainable=False)
        is_training_placeholder = tf.placeholder(
            tf.bool, name="is_training_placeholder")

        # NOTE(review): other code calls ``model.placeholders()``; confirm
        # which spelling the model API of this version exposes.
        images_placeholder, labels_placeholder = model.placeholderes()

        output = model.inference(images_placeholder, is_training_placeholder)
        if ModelClass.__module__.startswith("lmnet.networks.object_detection"):
            loss = model.loss(output, labels_placeholder,
                              is_training_placeholder)
        else:
            loss = model.loss(output, labels_placeholder)
        opt = model.optimizer(global_step)
        if config.IS_DISTRIBUTION:
            # add Horovod Distributed Optimizer
            opt = hvd.DistributedOptimizer(opt)
        train_op = model.train(loss, opt, global_step)
        metrics_ops_dict, metrics_update_op = model.metrics(
            output, labels_placeholder)
        # TODO(wakisaka): Deal with many networks.
        model.summary(output, labels_placeholder)

        summary_op = tf.summary.merge_all()

        metrics_summary_op, metrics_placeholders = executor.prepare_metrics(
            metrics_ops_dict)

        init_op = tf.global_variables_initializer()
        # Metric accumulators live in local variables.
        reset_metrics_op = tf.local_variables_initializer()
        if config.IS_DISTRIBUTION:
            # add Horovod broadcasting variables from rank 0 to all
            bcast_global_variables_op = hvd.broadcast_global_variables(0)

        # With accuracy-gated saving only the best checkpoint is kept.
        if use_train_validation_saving:
            saver = tf.train.Saver(max_to_keep=1)
        else:
            saver = tf.train.Saver(max_to_keep=None)

        if config.IS_PRETRAIN:
            pretrain_var_list = [
                var for var in tf.global_variables()
                if var.name.startswith(tuple(config.PRETRAIN_VARS))
            ]
            print("pretrain_vars", [var.name for var in pretrain_var_list])
            pretrain_saver = tf.train.Saver(pretrain_var_list,
                                            name="pretrain_saver")

    if config.IS_DISTRIBUTION:
        # For distributed training: one visible GPU per local process.
        session_config = tf.ConfigProto(gpu_options=tf.GPUOptions(
            allow_growth=True, visible_device_list=str(hvd.local_rank())))
    else:
        # TODO(wakisaka): For debug, consider
        # gpu_options=tf.GPUOptions(allow_growth=True,
        #                           per_process_gpu_memory_fraction=0.1)
        session_config = tf.ConfigProto()
    # TODO(wakisaka): XLA JIT
    # session_config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1

    sess = tf.Session(graph=graph, config=session_config)

    # Writers exist only on rank 0. They stay None elsewhere, so the cleanup
    # code and the helpers below can test for them.
    train_writer = None
    train_val_saving_writer = None
    val_writer = None

    def run_metrics_summary():
        """Fetch the accumulated metrics; return (summary, {name: value})."""
        metrics_values = sess.run(list(metrics_ops_dict.values()))
        metrics_feed_dict = dict(zip(metrics_placeholders, metrics_values))
        metrics_summary, = sess.run(
            [metrics_summary_op],
            feed_dict=metrics_feed_dict,
        )
        return metrics_summary, dict(zip(metrics_ops_dict.keys(),
                                         metrics_values))

    def run_evaluation(dataset, writer, summary_step, label):
        """Evaluate one epoch of ``dataset`` in inference mode.

        Metrics are reset first. Summaries are written to ``writer`` (when it
        is not None) at ``summary_step``. Returns {metric name: value}.
        """
        sess.run(reset_metrics_op)
        eval_step_size = int(
            math.ceil(dataset.num_per_epoch / config.BATCH_SIZE))
        print(label + "_step_size", eval_step_size)

        for eval_step in range(eval_step_size):
            print(label + "_step", eval_step)

            images, labels = dataset.feed()
            feed_dict = {
                is_training_placeholder: False,
                images_placeholder: images,
                labels_placeholder: labels,
            }

            if eval_step % config.SUMMARISE_STEPS == 0:
                summary, _ = sess.run([summary_op, metrics_update_op],
                                      feed_dict=feed_dict)
                if writer is not None:
                    writer.add_summary(summary, summary_step)
            else:
                sess.run([metrics_update_op], feed_dict=feed_dict)

        metrics_summary, metrics_values = run_metrics_summary()
        if writer is not None:
            writer.add_summary(metrics_summary, summary_step)
        return metrics_values

    try:
        sess.run([init_op, reset_metrics_op])

        if rank == 0:
            train_writer = tf.summary.FileWriter(
                environment.TENSORBOARD_DIR + "/train", sess.graph)
            if use_train_validation_saving:
                train_val_saving_writer = tf.summary.FileWriter(
                    environment.TENSORBOARD_DIR + "/train_validation_saving")
            val_writer = tf.summary.FileWriter(environment.TENSORBOARD_DIR +
                                               "/validation")

            if config.IS_PRETRAIN:
                print("------- Load pretrain data ----------")
                pretrain_saver.restore(
                    sess, os.path.join(config.PRETRAIN_DIR, config.PRETRAIN_FILE))
                sess.run(tf.assign(global_step, 0))

            # for recovery
            ckpt = tf.train.get_checkpoint_state(environment.CHECKPOINTS_DIR)
            if ckpt and ckpt.model_checkpoint_path:
                print("--------- Restore last checkpoint -------------")
                saver.restore(sess, ckpt.model_checkpoint_path)
                last_step = sess.run(global_step)
                # TODO(wakisaka): tensorflow v1.3 remain previous event log in tensorboard.
                # https://github.com/tensorflow/tensorflow/blob/r1.3/tensorflow/python/training/supervisor.py#L1072
                train_writer.add_session_log(SessionLog(status=SessionLog.START),
                                             global_step=last_step + 1)
                val_writer.add_session_log(SessionLog(status=SessionLog.START),
                                           global_step=last_step + 1)
                print("recovered. last step", last_step)

        if config.IS_DISTRIBUTION:
            # broadcast variables from rank 0 to all other processes
            sess.run(bcast_global_variables_op)
            # calculate step per epoch for each nodes
            train_num_per_epoch = train_dataset.num_per_epoch
            num_per_nodes = (train_num_per_epoch + num_worker - 1) // num_worker
            # Keep at least one step per epoch. A shard smaller than one batch
            # would otherwise make ``step % step_per_epoch`` divide by zero.
            step_per_epoch = max(1, num_per_nodes // config.BATCH_SIZE)
            begin_index = (train_num_per_epoch * rank) // num_worker
            end_index = begin_index + num_per_nodes

        # Read after the broadcast so every rank starts from the same step.
        last_step = sess.run(global_step)

        # Calculate max steps. The priority of config.MAX_EPOCHS is higher than config.MAX_STEPS.
        if "MAX_EPOCHS" in config:
            max_steps = int(train_dataset.num_per_epoch / config.BATCH_SIZE *
                            config.MAX_EPOCHS)
        else:
            max_steps = config.MAX_STEPS
        print("max_steps: {}".format(max_steps))

        for step in range(last_step, max_steps):
            print("step", step)

            if config.IS_DISTRIBUTION:
                # Scatter the dataset at every epoch boundary. Also do it on
                # the first step of this run, so a job resumed mid-epoch uses
                # its own shard immediately instead of waiting for the next
                # boundary. All ranks share last_step, so all of them reach
                # this broadcast together.
                if step == last_step or step % step_per_epoch == 0:
                    indices = (train_dataset.get_shuffle_index()
                               if rank == 0 else None)
                    # broadcast shuffled indices
                    indices = comm.bcast(indices, 0)
                    feed_indices = indices[begin_index:end_index]
                    # update each dataset by splitted indices
                    train_dataset.update_dataset(feed_indices)

            images, labels = train_dataset.feed()

            feed_dict = {
                is_training_placeholder: True,
                images_placeholder: images,
                labels_placeholder: labels,
            }

            is_summary_step = (step == 0 or
                               (step + 1) % config.SUMMARISE_STEPS == 0)
            if is_summary_step and rank == 0:
                # Metrics are reset so the training summary covers this batch.
                sess.run(reset_metrics_op)
                _, summary, _ = sess.run(
                    [train_op, summary_op, metrics_update_op],
                    feed_dict=feed_dict,
                )
                train_writer.add_summary(summary, step + 1)

                metrics_summary, _ = run_metrics_summary()
                train_writer.add_summary(metrics_summary, step + 1)
            else:
                sess.run([train_op], feed_dict=feed_dict)

            to_be_saved = (step == 0 or (step + 1) == max_steps or
                           (step + 1) % config.SAVE_STEPS == 0)

            if to_be_saved and rank == 0:
                if use_train_validation_saving:
                    metrics_values = run_evaluation(
                        train_validation_saving_dataset,
                        train_val_saving_writer,
                        step + 1,
                        "train_validation_saving",
                    )
                    current_train_validation_saving_set_accuracy = metrics_values[
                        "accuracy"]

                    if current_train_validation_saving_set_accuracy > top_train_validation_saving_set_accuracy:
                        top_train_validation_saving_set_accuracy = current_train_validation_saving_set_accuracy
                        print("New top train_validation_saving accuracy is: ",
                              top_train_validation_saving_set_accuracy)

                        _save_checkpoint(saver, sess, global_step, step)
                else:
                    _save_checkpoint(saver, sess, global_step, step)

                if step == 0:
                    # check create pb on only first step.
                    minimal_graph = tf.graph_util.convert_variables_to_constants(
                        sess,
                        sess.graph.as_graph_def(add_shapes=True),
                        ["output"],
                    )
                    pb_name = "minimal_graph_with_shape_{}.pb".format(step + 1)
                    pbtxt_name = "minimal_graph_with_shape_{}.pbtxt".format(
                        step + 1)
                    tf.train.write_graph(minimal_graph,
                                         environment.CHECKPOINTS_DIR,
                                         pb_name,
                                         as_text=False)
                    tf.train.write_graph(minimal_graph,
                                         environment.CHECKPOINTS_DIR,
                                         pbtxt_name,
                                         as_text=True)

            if step == 0 or (step + 1) % config.TEST_STEPS == 0:
                # Every rank runs validation; only rank 0 has a writer.
                run_evaluation(validation_dataset, val_writer, step + 1, "test")

        # training loop end.
        print("reach max step")
    finally:
        # Closing flushes buffered events; the session releases its resources.
        for writer in (train_writer, train_val_saving_writer, val_writer):
            if writer is not None:
                writer.close()
        sess.close()