Example #1
0
def parse_args():
    """
    Parse python script parameters (common part).

    Returns
    -------
    argparse.Namespace
        Parsed command-line arguments.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="Evaluate a model for image classification (Gluon/FDV)")
    parser.add_argument("--dataset", type=str, default="FDV1",
                        help="dataset name. option are FDV1, FDV2")
    parser.add_argument("--work-dir", type=str,
                        default=os.path.join("..", "facedetver_data"),
                        help="path to working directory only for dataset root path preset")

    # First pass parses only --dataset/--work-dir so the dataset metainfo can
    # register its own dataset-specific options before the full parse.
    pre_args, _ = parser.parse_known_args()
    metainfo = get_dataset_metainfo(dataset_name=pre_args.dataset)
    metainfo.add_dataset_parser_arguments(parser=parser,
                                          work_dir_path=pre_args.work_dir)

    add_eval_parser_arguments(parser)

    return parser.parse_args()
Example #2
0
def parse_args():
    """
    Parse python script parameters (common part).

    Returns
    -------
    argparse.Namespace
        Parsed command-line arguments.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="Train a model for image classification (Gluon)")
    parser.add_argument(
        "--dataset", type=str, default="ImageNet1K_rec",
        help="dataset name. options are ImageNet1K, ImageNet1K_rec, CUB200_2011, CIFAR10, CIFAR100, SVHN")
    parser.add_argument(
        "--work-dir", type=str, default=os.path.join("..", "imgclsmob_data"),
        help="path to working directory only for dataset root path preset")

    # Pre-parse to discover the chosen dataset, then let its metainfo add
    # dataset-specific options to the same parser.
    pre_args, _ = parser.parse_known_args()
    metainfo = get_dataset_metainfo(dataset_name=pre_args.dataset)
    metainfo.add_dataset_parser_arguments(parser=parser,
                                          work_dir_path=pre_args.work_dir)

    add_train_cls_parser_arguments(parser)

    return parser.parse_args()
Example #3
0
def parse_args():
    """
    Parse python script parameters.

    Returns
    -------
    argparse.Namespace
        Parsed command-line arguments.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="Evaluate a model for image matching (Gluon/HPatches)")
    parser.add_argument("--dataset", type=str, default="HPatches",
                        help="dataset name")
    parser.add_argument("--work-dir", type=str,
                        default=os.path.join("..", "imgclsmob_data"),
                        help="path to working directory only for dataset root path preset")

    # Pre-parse so the dataset metainfo can register its own options.
    pre_args, _ = parser.parse_known_args()
    metainfo = get_dataset_metainfo(dataset_name=pre_args.dataset)
    metainfo.add_dataset_parser_arguments(parser=parser,
                                          work_dir_path=pre_args.work_dir)

    add_eval_parser_arguments(parser)

    return parser.parse_args()
Example #4
0
def parse_args():
    """
    Parse python script parameters.

    Returns
    -------
    argparse.Namespace
        Parsed command-line arguments.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="Evaluate a model for image classification/segmentation (Gluon)")
    parser.add_argument(
        "--dataset", type=str, default="ImageNet1K_rec",
        help="dataset name. options are ImageNet1K, ImageNet1K_rec, CUB200_2011, CIFAR10, CIFAR100, SVHN, VOC2012, "
             "ADE20K, Cityscapes, COCO")
    parser.add_argument(
        "--work-dir", type=str, default=os.path.join("..", "imgclsmob_data"),
        help="path to working directory only for dataset root path preset")

    # Pre-parse so the dataset metainfo can register dataset-specific options.
    pre_args, _ = parser.parse_known_args()
    metainfo = get_dataset_metainfo(dataset_name=pre_args.dataset)
    metainfo.add_dataset_parser_arguments(parser=parser,
                                          work_dir_path=pre_args.work_dir)

    add_eval_parser_arguments(parser)

    return parser.parse_args()
Example #5
0
def main():
    """
    Main body of script: runs detector repeatability calculation on the
    validation data source.
    """
    args = parse_args()

    # Disable cuDNN autotune; this script requires batch size 1.
    os.environ["MXNET_CUDNN_AUTOTUNE_DEFAULT"] = "0"
    assert (args.batch_size == 1)

    _, log_file_exist = initialize_logging(
        logging_dir_path=args.save_dir,
        logging_file_name=args.logging_file_name,
        script_args=args,
        log_packages=args.log_packages,
        log_pip_packages=args.log_pip_packages)

    ds_metainfo = get_dataset_metainfo(dataset_name=args.dataset)
    ds_metainfo.update(args=args)

    ctx, batch_size = prepare_mx_context(
        num_gpus=args.num_gpus,
        batch_size=args.batch_size)

    model = prepare_model(
        model_name=args.model,
        use_pretrained=args.use_pretrained,
        pretrained_model_file_path=args.resume.strip(),
        dtype=args.dtype,
        net_extra_kwargs=ds_metainfo.net_extra_kwargs,
        load_ignore_extra=False,
        classes=args.num_classes,
        in_channels=args.in_channels,
        do_hybridize=False,
        ctx=ctx)

    data_source = get_val_data_source(
        ds_metainfo=ds_metainfo,
        batch_size=args.batch_size,
        num_workers=args.num_workers)

    calc_detector_repeatability(test_data=data_source, net=model, ctx=ctx)
Example #6
0
def main():
    """
    Main body of script: times a bare forward pass of the model over the whole
    test data source and logs the elapsed time.
    """
    args = parse_args()

    # Disable cuDNN autotune so timing is not skewed by autotuning runs;
    # this script requires batch size 1.
    os.environ["MXNET_CUDNN_AUTOTUNE_DEFAULT"] = "0"
    assert (args.batch_size == 1)

    _, log_file_exist = initialize_logging(
        logging_dir_path=args.save_dir,
        logging_file_name=args.logging_file_name,
        script_args=args,
        log_packages=args.log_packages,
        log_pip_packages=args.log_pip_packages)

    ds_metainfo = get_dataset_metainfo(dataset_name=args.dataset)
    ds_metainfo.update(args=args)

    ctx, batch_size = prepare_mx_context(
        num_gpus=args.num_gpus,
        batch_size=args.batch_size)

    net = prepare_model(
        model_name=args.model,
        use_pretrained=args.use_pretrained,
        pretrained_model_file_path=args.resume.strip(),
        dtype=args.dtype,
        net_extra_kwargs=None,
        load_ignore_extra=False,
        classes=args.num_classes,
        in_channels=args.in_channels,
        do_hybridize=False,
        ctx=ctx)

    test_data = get_val_data_source(
        ds_metainfo=ds_metainfo,
        batch_size=args.batch_size,
        num_workers=args.num_workers)

    tic = time.time()
    for batch in test_data:
        # Only the source images are fed through the net; the destination
        # images and labels from the batch are not needed for this timing run.
        data_src_list, _, _ = batch_fn(batch, ctx)
        outputs_src_list = [net(X) for X in data_src_list]
        # Trivially-true assert keeps a live reference to the outputs so the
        # forward pass is not flagged as unused.
        assert (outputs_src_list is not None)
    logging.info("Time cost: {:.4f} sec".format(time.time() - tic))
Example #7
0
def main():
    """
    Main body of script.

    Trains a classification network, optionally distilling from one or more
    teacher models and (unless disabled) an adversarial discriminator
    (MEAL-style helpers, judging by the `Meal*` class names — TODO confirm).
    """
    args = parse_args()
    # init_rand returns the seed actually used, so it is stored back on args
    # (and thus ends up in the logged script arguments).
    args.seed = init_rand(seed=args.seed)

    _, log_file_exist = initialize_logging(
        logging_dir_path=args.save_dir,
        logging_file_name=args.logging_file_name,
        script_args=args,
        log_packages=args.log_packages,
        log_pip_packages=args.log_pip_packages)

    ctx, batch_size = prepare_mx_context(num_gpus=args.num_gpus,
                                         batch_size=args.batch_size)

    ds_metainfo = get_dataset_metainfo(dataset_name=args.dataset)
    ds_metainfo.update(args=args)

    # Distillation is active only when --teacher-models is a non-empty string.
    use_teacher = (args.teacher_models
                   is not None) and (args.teacher_models.strip() != "")

    # Student network.
    net = prepare_model(
        model_name=args.model,
        use_pretrained=args.use_pretrained,
        pretrained_model_file_path=args.resume.strip(),
        dtype=args.dtype,
        net_extra_kwargs=ds_metainfo.train_net_extra_kwargs,
        tune_layers=args.tune_layers,
        classes=args.num_classes,
        in_channels=args.in_channels,
        do_hybridize=(not args.not_hybridize),
        initializer=get_initializer(initializer_name=args.initializer),
        ctx=ctx)
    assert (hasattr(net, "classes"))
    num_classes = net.classes

    teacher_net = None
    discrim_net = None
    discrim_loss_func = None
    if use_teacher:
        # --teacher-models is a comma-separated list of pretrained model names.
        teacher_nets = []
        for teacher_model in args.teacher_models.split(","):
            teacher_net = prepare_model(
                model_name=teacher_model.strip(),
                use_pretrained=True,
                pretrained_model_file_path="",
                dtype=args.dtype,
                net_extra_kwargs=ds_metainfo.train_net_extra_kwargs,
                do_hybridize=(not args.not_hybridize),
                ctx=ctx)
            # Teachers must match the student's label space and input size.
            assert (teacher_net.classes == net.classes)
            assert (teacher_net.in_size == net.in_size)
            teacher_nets.append(teacher_net)
        if len(teacher_nets) > 0:
            # Stack all teachers into a single parallel module.
            teacher_net = Concurrent(stack=True,
                                     prefix="",
                                     branches=teacher_nets)
            # Freeze teacher weights: no gradients are computed for them.
            for k, v in teacher_net.collect_params().items():
                v.grad_req = "null"
            if not args.not_discriminator:
                discrim_net = MealDiscriminator()
                discrim_net.cast(args.dtype)
                if not args.not_hybridize:
                    discrim_net.hybridize(static_alloc=True, static_shape=True)
                discrim_net.initialize(mx.init.MSRAPrelu(), ctx=ctx)
                # Scale the discriminator's learning rate by --dlr-factor.
                for k, v in discrim_net.collect_params().items():
                    v.lr_mult = args.dlr_factor
                discrim_loss_func = MealAdvLoss()

    train_data = get_train_data_source(ds_metainfo=ds_metainfo,
                                       batch_size=batch_size,
                                       num_workers=args.num_workers)
    val_data = get_val_data_source(ds_metainfo=ds_metainfo,
                                   batch_size=batch_size,
                                   num_workers=args.num_workers)
    batch_fn = get_batch_fn(ds_metainfo=ds_metainfo)

    # ImgRec-backed datasets report their sample count via metainfo instead of
    # supporting len() on the underlying dataset.
    num_training_samples = len(
        train_data._dataset
    ) if not ds_metainfo.use_imgrec else ds_metainfo.num_training_samples
    trainer, lr_scheduler = prepare_trainer(
        net=net,
        optimizer_name=args.optimizer_name,
        wd=args.wd,
        momentum=args.momentum,
        lr_mode=args.lr_mode,
        lr=args.lr,
        lr_decay_period=args.lr_decay_period,
        lr_decay_epoch=args.lr_decay_epoch,
        lr_decay=args.lr_decay,
        target_lr=args.target_lr,
        poly_power=args.poly_power,
        warmup_epochs=args.warmup_epochs,
        warmup_lr=args.warmup_lr,
        warmup_mode=args.warmup_mode,
        batch_size=batch_size,
        num_epochs=args.num_epochs,
        num_training_samples=num_training_samples,
        dtype=args.dtype,
        gamma_wd_mult=args.gamma_wd_mult,
        beta_wd_mult=args.beta_wd_mult,
        bias_wd_mult=args.bias_wd_mult,
        state_file_path=args.resume_state)

    # Checkpointing/score logging happens only when both --save-dir and
    # --save-interval are set.
    if args.save_dir and args.save_interval:
        param_names = ds_metainfo.val_metric_capts + ds_metainfo.train_metric_capts + [
            "Train.Loss", "LR"
        ]
        lp_saver = TrainLogParamSaver(
            checkpoint_file_name_prefix="{}_{}".format(ds_metainfo.short_label,
                                                       args.model),
            last_checkpoint_file_name_suffix="last",
            best_checkpoint_file_name_suffix=None,
            last_checkpoint_dir_path=args.save_dir,
            best_checkpoint_dir_path=None,
            last_checkpoint_file_count=2,
            best_checkpoint_file_count=2,
            checkpoint_file_save_callback=save_params,
            checkpoint_file_exts=(".params", ".states"),
            save_interval=args.save_interval,
            num_epochs=args.num_epochs,
            param_names=param_names,
            acc_ind=ds_metainfo.saver_acc_ind,
            # bigger=[True],
            # mask=None,
            score_log_file_path=os.path.join(args.save_dir, "score.log"),
            score_log_attempt_value=args.attempt,
            best_map_log_file_path=os.path.join(args.save_dir, "best_map.log"))
    else:
        lp_saver = None

    val_metric = get_composite_metric(ds_metainfo.val_metric_names,
                                      ds_metainfo.val_metric_extra_kwargs)
    train_metric = get_composite_metric(ds_metainfo.train_metric_names,
                                        ds_metainfo.train_metric_extra_kwargs)
    loss_metrics = [LossValue(name="loss"), LossValue(name="dloss")]

    # Hard (sparse) labels only when neither mixup/label-smoothing nor a
    # teacher provides soft targets.
    loss_kwargs = {
        "sparse_label": (not (args.mixup or args.label_smoothing)
                         and not (use_teacher and (teacher_net is not None)))
    }
    if ds_metainfo.loss_extra_kwargs is not None:
        loss_kwargs.update(ds_metainfo.loss_extra_kwargs)
    loss_func = get_loss(ds_metainfo.loss_name, loss_kwargs)

    train_net(batch_size=batch_size,
              num_epochs=args.num_epochs,
              start_epoch1=args.start_epoch,
              train_data=train_data,
              val_data=val_data,
              batch_fn=batch_fn,
              data_source_needs_reset=ds_metainfo.use_imgrec,
              dtype=args.dtype,
              net=net,
              teacher_net=teacher_net,
              discrim_net=discrim_net,
              trainer=trainer,
              lr_scheduler=lr_scheduler,
              lp_saver=lp_saver,
              log_interval=args.log_interval,
              mixup=args.mixup,
              mixup_epoch_tail=args.mixup_epoch_tail,
              label_smoothing=args.label_smoothing,
              num_classes=num_classes,
              grad_clip_value=args.grad_clip,
              batch_size_scale=args.batch_size_scale,
              val_metric=val_metric,
              train_metric=train_metric,
              loss_metrics=loss_metrics,
              loss_func=loss_func,
              discrim_loss_func=discrim_loss_func,
              ctx=ctx)
Example #8
0
File: eval_gl.py  Project: siddie/imgclsmob
def test_model(args):
    """
    Main test routine.

    Parameters
    ----------
    args : ArgumentParser
        Main script arguments.

    Returns
    -------
    float
        Main accuracy value.
    """
    ds_metainfo = get_dataset_metainfo(dataset_name=args.dataset)
    ds_metainfo.update(args=args)
    # Image segmentation on the "test" subset only supports batch size 1 and
    # requires cuDNN autotune to be disabled.
    assert (ds_metainfo.ml_type !=
            "imgseg") or (args.data_subset != "test") or (args.batch_size == 1)
    assert (ds_metainfo.ml_type != "imgseg") or args.disable_cudnn_autotune

    ctx, batch_size = prepare_mx_context(
        num_gpus=args.num_gpus,
        batch_size=args.batch_size)

    net = prepare_model(
        model_name=args.model,
        use_pretrained=args.use_pretrained,
        pretrained_model_file_path=args.resume.strip(),
        dtype=args.dtype,
        net_extra_kwargs=ds_metainfo.test_net_extra_kwargs,
        load_ignore_extra=ds_metainfo.load_ignore_extra,
        classes=(args.num_classes if ds_metainfo.ml_type != "hpe" else None),
        in_channels=args.in_channels,
        do_hybridize=(ds_metainfo.allow_hybridize and (not args.calc_flops)),
        ctx=ctx)
    assert (hasattr(net, "in_size"))
    input_image_size = net.in_size

    if args.data_subset == "val":
        get_test_data_source_class = get_val_data_source
    else:
        get_test_data_source_class = get_test_data_source
    test_data = get_test_data_source_class(
        ds_metainfo=ds_metainfo,
        batch_size=args.batch_size,
        num_workers=args.num_workers)
    batch_fn = get_batch_fn(ds_metainfo=ds_metainfo)
    # Pick the metric set matching the evaluated subset.
    if args.data_subset == "val":
        metric_names = ds_metainfo.val_metric_names
        metric_extra_kwargs = ds_metainfo.val_metric_extra_kwargs
    else:
        metric_names = ds_metainfo.test_metric_names
        metric_extra_kwargs = ds_metainfo.test_metric_extra_kwargs
    test_metric = get_composite_metric(
        metric_names=metric_names,
        metric_extra_kwargs=metric_extra_kwargs)

    if not args.not_show_progress:
        # Imported lazily so tqdm is only required when progress is shown.
        from tqdm import tqdm
        test_data = tqdm(test_data)

    assert (args.use_pretrained or args.resume.strip() or args.calc_flops_only)
    acc_values = calc_model_accuracy(
        net=net,
        test_data=test_data,
        batch_fn=batch_fn,
        data_source_needs_reset=ds_metainfo.use_imgrec,
        metric=test_metric,
        dtype=args.dtype,
        ctx=ctx,
        input_image_size=input_image_size,
        in_channels=args.in_channels,
        # calc_weight_count=(not log_file_exist),
        calc_weight_count=True,
        calc_flops=args.calc_flops,
        calc_flops_only=args.calc_flops_only,
        extended_log=True)
    if len(acc_values) > 0:
        return acc_values[ds_metainfo.saver_acc_ind]
    return None
Example #9
0
def main():
    """
    Main body of script.
    """
    args = parse_args()

    _, log_file_exist = initialize_logging(
        logging_dir_path=args.save_dir,
        logging_file_name=args.logging_file_name,
        script_args=args,
        log_packages=args.log_packages,
        log_pip_packages=args.log_pip_packages)

    ds_metainfo = get_dataset_metainfo(dataset_name=args.dataset)
    ds_metainfo.update(args=args)

    ctx, batch_size = prepare_mx_context(num_gpus=args.num_gpus,
                                         batch_size=args.batch_size)

    net = prepare_model(model_name=args.model,
                        use_pretrained=args.use_pretrained,
                        pretrained_model_file_path=args.resume.strip(),
                        dtype=args.dtype,
                        net_extra_kwargs=ds_metainfo.net_extra_kwargs,
                        load_ignore_extra=ds_metainfo.load_ignore_extra,
                        classes=args.num_classes,
                        in_channels=args.in_channels,
                        do_hybridize=(ds_metainfo.allow_hybridize and (not args.calc_flops)),
                        ctx=ctx)
    assert (hasattr(net, "in_size"))
    input_image_size = net.in_size

    # Select the data source factory and metric set for the chosen subset.
    on_val = (args.data_subset == "val")
    if on_val:
        get_test_data_source_class = get_val_data_source
        test_metric = get_composite_metric(
            metric_names=ds_metainfo.val_metric_names,
            metric_extra_kwargs=ds_metainfo.val_metric_extra_kwargs)
    else:
        get_test_data_source_class = get_test_data_source
        test_metric = get_composite_metric(
            metric_names=ds_metainfo.test_metric_names,
            metric_extra_kwargs=ds_metainfo.test_metric_extra_kwargs)
    test_data = get_test_data_source_class(ds_metainfo=ds_metainfo,
                                           batch_size=args.batch_size,
                                           num_workers=args.num_workers)
    batch_fn = get_batch_fn(use_imgrec=ds_metainfo.use_imgrec)

    if not args.not_show_progress:
        test_data = tqdm(test_data)

    assert (args.use_pretrained or args.resume.strip() or args.calc_flops_only)
    test(net=net,
         test_data=test_data,
         batch_fn=batch_fn,
         data_source_needs_reset=ds_metainfo.use_imgrec,
         metric=test_metric,
         dtype=args.dtype,
         ctx=ctx,
         input_image_size=input_image_size,
         in_channels=args.in_channels,
         calc_weight_count=True,
         calc_flops=args.calc_flops,
         calc_flops_only=args.calc_flops_only,
         extended_log=True,
         show_bad_samples=args.show_bad_samples)
Example #10
0
def main():
    """
    Main body of script.

    Sets up logging, data sources, model, trainer and checkpoint saver, then
    runs the training loop via `train_net`.
    """
    args = parse_args()
    # init_rand returns the seed actually used; store it back on args so it
    # appears in the logged script arguments.
    args.seed = init_rand(seed=args.seed)

    _, log_file_exist = initialize_logging(
        logging_dir_path=args.save_dir,
        logging_file_name=args.logging_file_name,
        script_args=args,
        log_packages=args.log_packages,
        log_pip_packages=args.log_pip_packages)

    ctx, batch_size = prepare_mx_context(num_gpus=args.num_gpus,
                                         batch_size=args.batch_size)

    net = prepare_model(
        model_name=args.model,
        use_pretrained=args.use_pretrained,
        pretrained_model_file_path=args.resume.strip(),
        dtype=args.dtype,
        tune_layers=args.tune_layers,
        classes=args.num_classes,
        in_channels=args.in_channels,
        do_hybridize=(not args.not_hybridize),
        initializer=get_initializer(initializer_name=args.initializer),
        ctx=ctx)
    assert (hasattr(net, "classes"))
    num_classes = net.classes

    ds_metainfo = get_dataset_metainfo(dataset_name=args.dataset)
    ds_metainfo.update(args=args)

    train_data = get_train_data_source(ds_metainfo=ds_metainfo,
                                       batch_size=batch_size,
                                       num_workers=args.num_workers)
    val_data = get_val_data_source(ds_metainfo=ds_metainfo,
                                   batch_size=batch_size,
                                   num_workers=args.num_workers)
    batch_fn = get_batch_fn(use_imgrec=ds_metainfo.use_imgrec)

    # ImgRec-backed datasets report their sample count via metainfo instead of
    # supporting len() on the underlying dataset.
    num_training_samples = len(
        train_data._dataset
    ) if not ds_metainfo.use_imgrec else ds_metainfo.num_training_samples
    trainer, lr_scheduler = prepare_trainer(
        net=net,
        optimizer_name=args.optimizer_name,
        wd=args.wd,
        momentum=args.momentum,
        lr_mode=args.lr_mode,
        lr=args.lr,
        lr_decay_period=args.lr_decay_period,
        lr_decay_epoch=args.lr_decay_epoch,
        lr_decay=args.lr_decay,
        target_lr=args.target_lr,
        poly_power=args.poly_power,
        warmup_epochs=args.warmup_epochs,
        warmup_lr=args.warmup_lr,
        warmup_mode=args.warmup_mode,
        batch_size=batch_size,
        num_epochs=args.num_epochs,
        num_training_samples=num_training_samples,
        dtype=args.dtype,
        gamma_wd_mult=args.gamma_wd_mult,
        beta_wd_mult=args.beta_wd_mult,
        bias_wd_mult=args.bias_wd_mult,
        state_file_path=args.resume_state)

    # Checkpointing/score logging happens only when both --save-dir and
    # --save-interval are set.
    if args.save_dir and args.save_interval:
        param_names = ds_metainfo.val_metric_capts + ds_metainfo.train_metric_capts + [
            "Train.Loss", "LR"
        ]
        lp_saver = TrainLogParamSaver(
            checkpoint_file_name_prefix="{}_{}".format(ds_metainfo.short_label,
                                                       args.model),
            last_checkpoint_file_name_suffix="last",
            best_checkpoint_file_name_suffix=None,
            last_checkpoint_dir_path=args.save_dir,
            best_checkpoint_dir_path=None,
            last_checkpoint_file_count=2,
            best_checkpoint_file_count=2,
            checkpoint_file_save_callback=save_params,
            checkpoint_file_exts=(".params", ".states"),
            save_interval=args.save_interval,
            num_epochs=args.num_epochs,
            param_names=param_names,
            acc_ind=ds_metainfo.saver_acc_ind,
            # bigger=[True],
            # mask=None,
            score_log_file_path=os.path.join(args.save_dir, "score.log"),
            score_log_attempt_value=args.attempt,
            best_map_log_file_path=os.path.join(args.save_dir, "best_map.log"))
    else:
        lp_saver = None

    train_net(batch_size=batch_size,
              num_epochs=args.num_epochs,
              start_epoch1=args.start_epoch,
              train_data=train_data,
              val_data=val_data,
              batch_fn=batch_fn,
              data_source_needs_reset=ds_metainfo.use_imgrec,
              dtype=args.dtype,
              net=net,
              trainer=trainer,
              lr_scheduler=lr_scheduler,
              lp_saver=lp_saver,
              log_interval=args.log_interval,
              mixup=args.mixup,
              mixup_epoch_tail=args.mixup_epoch_tail,
              label_smoothing=args.label_smoothing,
              num_classes=num_classes,
              grad_clip_value=args.grad_clip,
              batch_size_scale=args.batch_size_scale,
              val_metric=get_composite_metric(
                  ds_metainfo.val_metric_names,
                  ds_metainfo.val_metric_extra_kwargs),
              train_metric=get_composite_metric(
                  ds_metainfo.train_metric_names,
                  ds_metainfo.train_metric_extra_kwargs),
              ctx=ctx)
Example #11
0
def main():
    """
    Main body of script.

    Evaluates a model on the chosen data subset via `test`.
    """
    args = parse_args()

    # cuDNN autotune can be disabled explicitly; the asserts below require it
    # for image segmentation.
    if args.disable_cudnn_autotune:
        os.environ["MXNET_CUDNN_AUTOTUNE_DEFAULT"] = "0"

    _, log_file_exist = initialize_logging(
        logging_dir_path=args.save_dir,
        logging_file_name=args.logging_file_name,
        script_args=args,
        log_packages=args.log_packages,
        log_pip_packages=args.log_pip_packages)

    ds_metainfo = get_dataset_metainfo(dataset_name=args.dataset)
    ds_metainfo.update(args=args)
    # Image segmentation only supports batch size 1 and needs autotune off.
    assert (ds_metainfo.ml_type != "imgseg") or (args.batch_size == 1)
    assert (ds_metainfo.ml_type != "imgseg") or args.disable_cudnn_autotune

    ctx, batch_size = prepare_mx_context(
        num_gpus=args.num_gpus,
        batch_size=args.batch_size)

    net = prepare_model(
        model_name=args.model,
        use_pretrained=args.use_pretrained,
        pretrained_model_file_path=args.resume.strip(),
        dtype=args.dtype,
        net_extra_kwargs=ds_metainfo.net_extra_kwargs,
        load_ignore_extra=ds_metainfo.load_ignore_extra,
        classes=args.num_classes,
        in_channels=args.in_channels,
        do_hybridize=(ds_metainfo.allow_hybridize and (not args.calc_flops)),
        ctx=ctx)
    assert (hasattr(net, "in_size"))
    input_image_size = net.in_size

    # Choose the data source and metric set matching the evaluated subset.
    if args.data_subset == "val":
        test_data = get_val_data_source(
            ds_metainfo=ds_metainfo,
            batch_size=batch_size,
            num_workers=args.num_workers)
        test_metric = get_composite_metric(
            metric_names=ds_metainfo.val_metric_names,
            metric_extra_kwargs=ds_metainfo.val_metric_extra_kwargs)
    else:
        test_data = get_test_data_source(
            ds_metainfo=ds_metainfo,
            batch_size=batch_size,
            num_workers=args.num_workers)
        test_metric = get_composite_metric(
            metric_names=ds_metainfo.test_metric_names,
            metric_extra_kwargs=ds_metainfo.test_metric_extra_kwargs)
    batch_fn = get_batch_fn(use_imgrec=ds_metainfo.use_imgrec)

    if args.show_progress:
        # Imported lazily so tqdm is only required when progress is shown.
        from tqdm import tqdm
        test_data = tqdm(test_data)

    # Evaluation needs weights from somewhere (pretrained or --resume), unless
    # only FLOPs are being computed.
    assert (args.use_pretrained or args.resume.strip() or args.calc_flops_only)
    test(
        net=net,
        test_data=test_data,
        batch_fn=batch_fn,
        data_source_needs_reset=ds_metainfo.use_imgrec,
        metric=test_metric,
        dtype=args.dtype,
        ctx=ctx,
        input_image_size=input_image_size,
        in_channels=args.in_channels,
        # calc_weight_count=(not log_file_exist),
        calc_weight_count=True,
        calc_flops=args.calc_flops,
        calc_flops_only=args.calc_flops_only,
        extended_log=True)
Example #12
0
def parse_args():
    """
    Parse python script parameters.

    Returns
    -------
    argparse.Namespace
        Parsed command-line arguments.
    """
    parser = argparse.ArgumentParser(
        description='Train a model for image classification (Gluon/CIFAR)',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--dataset',
                        type=str,
                        default="CIFAR10",
                        help='dataset name. options are CIFAR10 and CIFAR100')
    parser.add_argument(
        "--work-dir",
        type=str,
        default=os.path.join("..", "imgclsmob_data"),
        help="path to working directory only for dataset root path preset")

    # Pre-parse to learn the dataset; its metainfo then registers
    # dataset-specific options on the same parser before the full parse.
    args, _ = parser.parse_known_args()
    dataset_metainfo = get_dataset_metainfo(dataset_name=args.dataset)
    dataset_metainfo.add_dataset_parser_arguments(parser=parser,
                                                  work_dir_path=args.work_dir)

    # Model selection and loading.
    parser.add_argument(
        '--model',
        type=str,
        required=True,
        help='type of model to use. see model_provider for options.')
    parser.add_argument('--use-pretrained',
                        action='store_true',
                        help='enable using pretrained model from gluon.')
    parser.add_argument('--dtype',
                        type=str,
                        default='float32',
                        help='data type for training')
    parser.add_argument('--not-hybridize',
                        action='store_true',
                        help='do not hybridize model')
    parser.add_argument(
        '--resume',
        type=str,
        default='',
        help='resume from previously saved parameters if not None')
    parser.add_argument(
        '--resume-state',
        type=str,
        default='',
        help='resume from previously saved optimizer state if not None')

    # Hardware / data-loading.
    parser.add_argument('--num-gpus',
                        type=int,
                        default=0,
                        help='number of gpus to use.')
    parser.add_argument('-j',
                        '--num-data-workers',
                        dest='num_workers',
                        default=4,
                        type=int,
                        help='number of preprocessing workers')

    # Training schedule.
    parser.add_argument('--batch-size',
                        type=int,
                        default=128,
                        help='training batch size per device (CPU/GPU).')
    parser.add_argument('--batch-size-scale',
                        type=int,
                        default=1,
                        help='manual batch-size increasing factor.')
    parser.add_argument('--num-epochs',
                        type=int,
                        default=200,
                        help='number of training epochs.')
    parser.add_argument(
        '--start-epoch',
        type=int,
        default=1,
        help='starting epoch for resuming, default is 1 for new training')
    parser.add_argument('--attempt',
                        type=int,
                        default=1,
                        help='current number of training')

    # Optimizer and learning-rate schedule.
    parser.add_argument('--optimizer-name',
                        type=str,
                        default='nag',
                        help='optimizer name')
    parser.add_argument('--lr', type=float, default=0.1, help='learning rate')
    parser.add_argument(
        '--lr-mode',
        type=str,
        default='cosine',
        help='learning rate scheduler mode. options are step, poly and cosine')
    parser.add_argument('--lr-decay',
                        type=float,
                        default=0.1,
                        help='decay rate of learning rate')
    parser.add_argument(
        '--lr-decay-period',
        type=int,
        default=0,
        help=
        'interval for periodic learning rate decays. default is 0 to disable.')
    parser.add_argument('--lr-decay-epoch',
                        type=str,
                        default='40,60',
                        help='epoches at which learning rate decays')
    parser.add_argument('--target-lr',
                        type=float,
                        default=1e-8,
                        help='ending learning rate')
    parser.add_argument('--poly-power',
                        type=float,
                        default=2,
                        help='power value for poly LR scheduler')
    parser.add_argument('--warmup-epochs',
                        type=int,
                        default=0,
                        help='number of warmup epochs.')
    parser.add_argument('--warmup-lr',
                        type=float,
                        default=1e-8,
                        help='starting warmup learning rate')
    parser.add_argument(
        '--warmup-mode',
        type=str,
        default='linear',
        help=
        'learning rate scheduler warmup mode. options are linear, poly and constant'
    )
    parser.add_argument('--momentum',
                        type=float,
                        default=0.9,
                        help='momentum value for optimizer')
    # Regularization.
    parser.add_argument('--wd',
                        type=float,
                        default=0.0001,
                        help='weight decay rate')
    parser.add_argument('--gamma-wd-mult',
                        type=float,
                        default=1.0,
                        help='weight decay multiplier for batchnorm gamma')
    parser.add_argument('--beta-wd-mult',
                        type=float,
                        default=1.0,
                        help='weight decay multiplier for batchnorm beta')
    parser.add_argument('--bias-wd-mult',
                        type=float,
                        default=1.0,
                        help='weight decay multiplier for bias')
    parser.add_argument('--grad-clip',
                        type=float,
                        default=None,
                        help='max_norm for gradient clipping')
    parser.add_argument('--label-smoothing',
                        action='store_true',
                        help='use label smoothing')

    # Mixup augmentation.
    parser.add_argument('--mixup',
                        action='store_true',
                        help='use mixup strategy')
    parser.add_argument(
        '--mixup-epoch-tail',
        type=int,
        default=20,
        help='number of epochs without mixup at the end of training')

    # Logging and checkpointing.
    parser.add_argument('--log-interval',
                        type=int,
                        default=200,
                        help='number of batches to wait before logging.')
    parser.add_argument(
        '--save-interval',
        type=int,
        default=4,
        help='saving parameters epoch interval, best model will always be saved'
    )
    parser.add_argument('--save-dir',
                        type=str,
                        default='',
                        help='directory of saved models and log-files')
    parser.add_argument('--logging-file-name',
                        type=str,
                        default='train.log',
                        help='filename of training log')

    parser.add_argument('--seed',
                        type=int,
                        default=-1,
                        help='Random seed to be fixed')
    parser.add_argument('--log-packages',
                        type=str,
                        default='mxnet',
                        help='list of python packages for logging')
    parser.add_argument('--log-pip-packages',
                        type=str,
                        default='mxnet-cu100',
                        help='list of pip packages for logging')

    # Fine-tuning.
    parser.add_argument('--tune-layers',
                        type=str,
                        default='',
                        help='Regexp for selecting layers for fine tuning')
    args = parser.parse_args()
    return args