Example #1
def parse_args():
    """
    Create python script parameters (common part).

    Returns:
    -------
    ArgumentParser
        Resulted args.
    """
    parser = argparse.ArgumentParser(
        description=
        "Evaluate a model for image classification/segmentation (TensorFlow 2.0)",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "--dataset",
        type=str,
        default="ImageNet1K",
        help=
        "dataset name. options are ImageNet1K, ImageNet1K_rec, CUB200_2011, CIFAR10, CIFAR100, SVHN, VOC2012, "
        "ADE20K, Cityscapes, COCO")
    parser.add_argument(
        "--work-dir",
        type=str,
        default=os.path.join("..", "imgclsmob_data"),
        help="path to working directory only for dataset root path preset")

    args, _ = parser.parse_known_args()
    dataset_metainfo = get_dataset_metainfo(dataset_name=args.dataset)
    dataset_metainfo.add_dataset_parser_arguments(parser=parser,
                                                  work_dir_path=args.work_dir)

    add_eval_parser_arguments(parser)

    args = parser.parse_args()
    return args
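
The two parse calls above are deliberate: a first parse_known_args() pass reads only the common options such as --dataset, so that dataset-specific options can be registered on the same parser before the final parse_args() call. Below is a minimal, self-contained sketch of that two-stage pattern; the dataset table and its options are hypothetical stand-ins for dataset_metainfo.add_dataset_parser_arguments(), not part of imgclsmob.

import argparse

# Hypothetical per-dataset options, standing in for
# dataset_metainfo.add_dataset_parser_arguments().
DATASET_OPTIONS = {
    "ImageNet1K": {"--input-size": 224},
    "CIFAR10": {"--input-size": 32},
}


def parse_args_sketch():
    parser = argparse.ArgumentParser(description="Two-stage argument parsing sketch")
    parser.add_argument("--dataset", type=str, default="ImageNet1K")

    # First pass: only discover which dataset was requested,
    # ignoring options that have not been registered yet.
    known, _ = parser.parse_known_args()

    # Second pass: register dataset-specific options, then parse fully.
    for flag, default in DATASET_OPTIONS[known.dataset].items():
        parser.add_argument(flag, type=int, default=default)
    return parser.parse_args()


if __name__ == "__main__":
    print(parse_args_sketch())
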
Example #2
def parse_args():
    """
    Parse python script parameters (common part).

    Returns
    -------
    argparse.Namespace
        Parsed arguments.
    """
    parser = argparse.ArgumentParser(
        description=
        "Train a model for image classification/segmentation (TensorFlow 2.0)",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--dataset",
                        type=str,
                        default="ImageNet1K",
                        help="dataset name. options are ImageNet1K, CIFAR10")
    parser.add_argument(
        "--work-dir",
        type=str,
        default=os.path.join("..", "imgclsmob_data"),
        help="path to working directory only for dataset root path preset")

    args, _ = parser.parse_known_args()
    dataset_metainfo = get_dataset_metainfo(dataset_name=args.dataset)
    dataset_metainfo.add_dataset_parser_arguments(parser=parser,
                                                  work_dir_path=args.work_dir)

    add_train_cls_parser_arguments(parser)

    args = parser.parse_args()
    return args
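
This trainer variant differs from Example #1 only in the description string and in attaching add_train_cls_parser_arguments(parser) instead of add_eval_parser_arguments(parser). A rough sketch of how such a helper composes with the common parser follows; the option names mirror attributes the training loop reads (args.num_epochs, args.batch_size) but are assumptions about what the real helper registers.

import argparse


def add_train_cls_parser_arguments_sketch(parser):
    # Hypothetical training-only options; the real add_train_cls_parser_arguments()
    # may register different names and many more of them.
    parser.add_argument("--num-epochs", type=int, default=10)
    parser.add_argument("--batch-size", type=int, default=128)


parser = argparse.ArgumentParser(description="Training parser sketch")
parser.add_argument("--dataset", type=str, default="CIFAR10")
add_train_cls_parser_arguments_sketch(parser)
args = parser.parse_args([])  # empty argv, so defaults are used
print(args.num_epochs, args.batch_size)
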
Example #3
def test_model(args, use_cuda, data_format):
    """
    Main test routine.

    Parameters
    ----------
    args : argparse.Namespace
        Main script arguments.
    use_cuda : bool
        Whether to use CUDA.
    data_format : str
        The ordering of the dimensions in tensors.

    Returns
    -------
    list of float
        Accuracy values (empty when only FLOPs are calculated).
    """
    ds_metainfo = get_dataset_metainfo(dataset_name=args.dataset)
    ds_metainfo.update(args=args)
    assert (ds_metainfo.ml_type != "imgseg") or (args.batch_size == 1)
    assert (ds_metainfo.ml_type != "imgseg") or args.disable_cudnn_autotune

    batch_size = args.batch_size
    net = prepare_model(model_name=args.model,
                        use_pretrained=args.use_pretrained,
                        pretrained_model_file_path=args.resume.strip(),
                        net_extra_kwargs=ds_metainfo.test_net_extra_kwargs,
                        load_ignore_extra=ds_metainfo.load_ignore_extra,
                        batch_size=batch_size,
                        use_cuda=use_cuda)
    assert hasattr(net, "in_size")

    if not args.calc_flops_only:
        tic = time.time()

        get_test_data_source_class = get_val_data_source if args.data_subset == "val" else get_test_data_source
        test_data, total_img_count = get_test_data_source_class(
            ds_metainfo=ds_metainfo,
            batch_size=args.batch_size,
            data_format=data_format)
        if args.data_subset == "val":
            test_metric = get_composite_metric(
                metric_names=ds_metainfo.val_metric_names,
                metric_extra_kwargs=ds_metainfo.val_metric_extra_kwargs)
        else:
            test_metric = get_composite_metric(
                metric_names=ds_metainfo.test_metric_names,
                metric_extra_kwargs=ds_metainfo.test_metric_extra_kwargs)

        if args.show_progress:
            from tqdm import tqdm
            test_data = tqdm(test_data)

        processed_img_count = 0
        for test_images, test_labels in test_data:
            predictions = net(test_images)
            test_metric.update(test_labels, predictions)
            processed_img_count += len(test_images)
            if processed_img_count >= total_img_count:
                break

        accuracy_msg = report_accuracy(metric=test_metric, extended_log=True)
        logging.info("Test: {}".format(accuracy_msg))
        logging.info("Time cost: {:.4f} sec".format(time.time() - tic))
        acc_values = test_metric.get()[1]
        acc_values = acc_values if isinstance(acc_values, list) else [acc_values]
    else:
        acc_values = []

    return acc_values
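
The evaluation loop stops on processed_img_count >= total_img_count rather than relying on the iterator ending, which matters when the data source repeats or pads batches. Below is a minimal runnable sketch of that bounded-loop pattern with a plain tf.keras metric; the model and dataset are placeholders, not the repo's prepare_model()/get_test_data_source() pipeline.

import tensorflow as tf

# Placeholder model and data; the real script builds these via
# prepare_model() and get_test_data_source().
net = tf.keras.Sequential([tf.keras.layers.Flatten(), tf.keras.layers.Dense(10)])
images = tf.random.uniform((100, 8, 8, 1))
labels = tf.random.uniform((100,), maxval=10, dtype=tf.int32)
test_data = tf.data.Dataset.from_tensor_slices((images, labels)).batch(32).repeat()
total_img_count = 100

metric = tf.keras.metrics.SparseCategoricalAccuracy()
processed_img_count = 0
for batch_images, batch_labels in test_data:
    predictions = net(batch_images)
    metric.update_state(batch_labels, predictions)
    processed_img_count += len(batch_images)
    if processed_img_count >= total_img_count:
        break  # .repeat() would otherwise iterate forever
print("accuracy: {:.4f}".format(metric.result().numpy()))
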
def main():
    """
    Main body of script.
    """
    args = parse_args()
    args.seed = init_rand(seed=args.seed)

    _, log_file_exist = initialize_logging(
        logging_dir_path=args.save_dir,
        logging_file_name=args.logging_file_name,
        script_args=args,
        log_packages=args.log_packages,
        log_pip_packages=args.log_pip_packages)

    data_format = "channels_last"
    tf.keras.backend.set_image_data_format(data_format)

    model = args.model
    net = get_model(model, data_format=data_format)

    loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
    optimizer = tf.keras.optimizers.Adam()
    train_loss = tf.keras.metrics.Mean(name="train_loss")
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        name="train_accuracy")
    test_loss = tf.keras.metrics.Mean(name="test_loss")
    test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        name="test_accuracy")

    @tf.function
    def train_step(images, labels):
        with tf.GradientTape() as tape:
            predictions = net(images)
            loss = loss_object(labels, predictions)
        gradients = tape.gradient(loss, net.trainable_variables)
        optimizer.apply_gradients(zip(gradients, net.trainable_variables))
        train_loss(loss)
        train_accuracy(labels, predictions)

    @tf.function
    def test_step(images, labels):
        predictions = net(images)
        t_loss = loss_object(labels, predictions)
        test_loss(t_loss)
        test_accuracy(labels, predictions)

    ds_metainfo = get_dataset_metainfo(dataset_name=args.dataset)
    ds_metainfo.update(args=args)
    assert (ds_metainfo.ml_type != "imgseg") or (args.batch_size == 1)
    assert (ds_metainfo.ml_type != "imgseg") or args.disable_cudnn_autotune

    batch_size = args.batch_size

    train_data, train_img_count = get_train_data_source(
        ds_metainfo=ds_metainfo,
        batch_size=batch_size,
        data_format=data_format)
    val_data, val_img_count = get_val_data_source(ds_metainfo=ds_metainfo,
                                                  batch_size=batch_size,
                                                  data_format=data_format)

    num_epochs = args.num_epochs
    for epoch in range(num_epochs):
        for images, labels in train_data:
            train_step(images, labels)
            # break

        for test_images, test_labels in val_data:
            test_step(test_images, test_labels)
            # break

        template = "Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}"
        logging.info(
            template.format(epoch + 1, train_loss.result(),
                            train_accuracy.result() * 100, test_loss.result(),
                            test_accuracy.result() * 100))

        train_loss.reset_states()
        train_accuracy.reset_states()
        test_loss.reset_states()
        test_accuracy.reset_states()
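
Since the Mean and SparseCategoricalAccuracy objects accumulate state across calls, the reset_states() calls at the end of each epoch are what keep the logged numbers per-epoch. The fragment below is a self-contained, runnable reduction of the same custom-loop skeleton on synthetic data; the model, data, and sizes are placeholders rather than anything from the repo, and from_logits=True is used because the placeholder head emits raw logits.

import tensorflow as tf

# Placeholder model and synthetic data; the real script obtains these via
# get_model() and get_train_data_source().
net = tf.keras.Sequential([tf.keras.layers.Flatten(), tf.keras.layers.Dense(10)])
data = tf.data.Dataset.from_tensor_slices(
    (tf.random.uniform((64, 8, 8, 1)),
     tf.random.uniform((64,), maxval=10, dtype=tf.int32))).batch(16)

loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam()
train_loss = tf.keras.metrics.Mean()
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()


@tf.function
def train_step(images, labels):
    with tf.GradientTape() as tape:
        predictions = net(images)
        loss = loss_object(labels, predictions)
    gradients = tape.gradient(loss, net.trainable_variables)
    optimizer.apply_gradients(zip(gradients, net.trainable_variables))
    train_loss(loss)
    train_accuracy(labels, predictions)


for epoch in range(2):
    for images, labels in data:
        train_step(images, labels)
    print("epoch {}: loss={:.4f} acc={:.4f}".format(
        epoch + 1, float(train_loss.result()), float(train_accuracy.result())))
    # Reset accumulators so each epoch's numbers are independent.
    train_loss.reset_states()
    train_accuracy.reset_states()
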