예제 #1
0
def main():
    """
    Evaluate detector repeatability for the model selected on the command line.

    Parses arguments, sets up logging, builds the MXNet context, the model and the
    validation data source, and then runs ``calc_detector_repeatability``.
    """
    args = parse_args()

    # Turn off MXNet's cuDNN autotuning before the model and data are created.
    os.environ["MXNET_CUDNN_AUTOTUNE_DEFAULT"] = "0"
    # This evaluation script only supports single-sample batches.
    assert (args.batch_size == 1)

    _, _ = initialize_logging(
        logging_dir_path=args.save_dir,
        logging_file_name=args.logging_file_name,
        script_args=args,
        log_packages=args.log_packages,
        log_pip_packages=args.log_pip_packages)

    metainfo = get_dataset_metainfo(dataset_name=args.dataset)
    metainfo.update(args=args)

    # Only the context is needed; the returned batch size is ignored.
    ctx, _ = prepare_mx_context(
        num_gpus=args.num_gpus,
        batch_size=args.batch_size)

    model = prepare_model(
        model_name=args.model,
        use_pretrained=args.use_pretrained,
        pretrained_model_file_path=args.resume.strip(),
        dtype=args.dtype,
        net_extra_kwargs=metainfo.net_extra_kwargs,
        load_ignore_extra=False,
        classes=args.num_classes,
        in_channels=args.in_channels,
        do_hybridize=False,
        ctx=ctx)

    calc_detector_repeatability(
        test_data=get_val_data_source(
            ds_metainfo=metainfo,
            batch_size=args.batch_size,
            num_workers=args.num_workers),
        net=model,
        ctx=ctx)
예제 #2
0
def main():
    """
    Evaluate a pretrained or resumed model on the validation data.

    The data comes from RecordIO files when ``args.use_rec`` is set, otherwise
    from ``args.data_dir``.
    """
    args = parse_args()

    _, _ = initialize_logging(
        logging_dir_path=args.save_dir,
        logging_file_name=args.logging_file_name,
        script_args=args,
        log_packages=args.log_packages,
        log_pip_packages=args.log_pip_packages)

    ctx, batch_size = prepare_mx_context(
        num_gpus=args.num_gpus,
        batch_size=args.batch_size)

    net = prepare_model(
        model_name=args.model,
        use_pretrained=args.use_pretrained,
        pretrained_model_file_path=args.resume.strip(),
        dtype=args.dtype,
        tune_layers="",
        ctx=ctx)

    # Use the network's own input resolution when it advertises one; otherwise
    # fall back to a square image of the size given on the command line.
    if hasattr(net, 'in_size'):
        input_image_size = net.in_size
    else:
        input_image_size = (args.input_size, args.input_size)

    # Settings shared by both data-loading back ends.
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=args.num_workers,
        input_image_size=input_image_size,
        resize_inv_factor=args.resize_inv_factor)
    if args.use_rec:
        _, val_data, batch_fn = get_data_rec(
            rec_train=args.rec_train,
            rec_train_idx=args.rec_train_idx,
            rec_val=args.rec_val,
            rec_val_idx=args.rec_val_idx,
            **loader_kwargs)
    else:
        _, val_data, batch_fn = get_data_loader(
            data_dir=args.data_dir,
            **loader_kwargs)

    # Require trained weights: either the pretrained ones or a resume checkpoint.
    assert (args.use_pretrained or args.resume.strip())
    test(
        net=net,
        val_data=val_data,
        batch_fn=batch_fn,
        use_rec=args.use_rec,
        dtype=args.dtype,
        ctx=ctx,
        calc_weight_count=True,
        extended_log=True)
예제 #3
0
def main():
    """
    Evaluate the model selected on the command line on the test data of ``args.dataset``.

    Evaluation runs with one sample per batch. FLOP counting can be requested
    through ``args.calc_flops`` / ``args.calc_flops_only``.
    """
    # Disable MXNet's cuDNN autotuning before anything else runs.
    os.environ["MXNET_CUDNN_AUTOTUNE_DEFAULT"] = "0"
    args = parse_args()

    _, _ = initialize_logging(
        logging_dir_path=args.save_dir,
        logging_file_name=args.logging_file_name,
        script_args=args,
        log_packages=args.log_packages,
        log_pip_packages=args.log_pip_packages)

    # The batch size is fixed to 1 regardless of the command line.
    ctx, batch_size = prepare_mx_context(
        num_gpus=args.num_gpus,
        batch_size=1)

    net = prepare_model(
        model_name=args.model,
        use_pretrained=args.use_pretrained,
        pretrained_model_file_path=args.resume.strip(),
        dtype=args.dtype,
        net_extra_kwargs={"aux": False, "fixed_size": False},
        load_ignore_extra=True,
        classes=args.num_classes,
        in_channels=args.in_channels,
        do_hybridize=False,
        ctx=ctx)
    # Networks without an ``in_size`` attribute default to 480x480 input.
    input_image_size = getattr(net, 'in_size', (480, 480))

    test_data = get_test_data_source(
        dataset_name=args.dataset,
        dataset_dir=args.data_dir,
        batch_size=batch_size,
        num_workers=args.num_workers)

    # Weights must come from somewhere unless only FLOPs are being counted.
    assert (args.use_pretrained or args.resume.strip() or args.calc_flops_only)
    dataset_metainfo = get_metainfo(args.dataset)
    test(
        net=net,
        test_data=test_data,
        data_source_needs_reset=False,
        dtype=args.dtype,
        ctx=ctx,
        input_image_size=input_image_size,
        in_channels=args.in_channels,
        classes=args.num_classes,
        calc_weight_count=True,
        calc_flops=args.calc_flops,
        calc_flops_only=args.calc_flops_only,
        extended_log=True,
        dataset_metainfo=dataset_metainfo)
예제 #4
0
def main():
    """
    Evaluate the selected model on the validation data of ``args.dataset``.

    The dataset metainfo decides whether the data source is RecordIO-based, which
    in turn selects the batch function and whether the source needs resetting.
    """
    args = parse_args()

    _, _ = initialize_logging(
        logging_dir_path=args.save_dir,
        logging_file_name=args.logging_file_name,
        script_args=args,
        log_packages=args.log_packages,
        log_pip_packages=args.log_pip_packages)

    ctx, batch_size = prepare_mx_context(
        num_gpus=args.num_gpus,
        batch_size=args.batch_size)

    net = prepare_model(
        model_name=args.model,
        use_pretrained=args.use_pretrained,
        pretrained_model_file_path=args.resume.strip(),
        dtype=args.dtype,
        classes=args.num_classes,
        in_channels=args.in_channels,
        # Hybridization is skipped when FLOPs are to be counted.
        do_hybridize=(not args.calc_flops),
        ctx=ctx)

    # The network must report its input resolution; it is passed to the data source.
    assert (hasattr(net, "in_size"))
    input_image_size = net.in_size

    ds_metainfo = get_dataset_metainfo(dataset_name=args.dataset)
    val_data = get_val_data_source(
        dataset_metainfo=ds_metainfo,
        dataset_dir=args.data_dir,
        batch_size=batch_size,
        num_workers=args.num_workers,
        input_image_size=input_image_size,
        resize_inv_factor=args.resize_inv_factor)
    uses_imgrec = ds_metainfo.use_imgrec
    batch_fn = get_batch_fn(use_imgrec=uses_imgrec)

    assert (args.use_pretrained or args.resume.strip() or args.calc_flops_only)
    val_metric = get_composite_metric(ds_metainfo.val_metric_names)
    test(
        net=net,
        val_data=val_data,
        batch_fn=batch_fn,
        data_source_needs_reset=uses_imgrec,
        val_metric=val_metric,
        dtype=args.dtype,
        ctx=ctx,
        input_image_size=input_image_size,
        in_channels=args.in_channels,
        calc_weight_count=True,
        calc_flops=args.calc_flops,
        calc_flops_only=args.calc_flops_only,
        extended_log=True)
예제 #5
0
def main():
    """
    Time forward passes of the model over every batch of the validation data.

    Only the source images of each batch go through the network. The total
    wall-clock time of the loop is logged.
    """
    args = parse_args()

    # Turn off MXNet's cuDNN autotuning before the model and data are created.
    os.environ["MXNET_CUDNN_AUTOTUNE_DEFAULT"] = "0"
    # Only single-sample batches are supported.
    assert (args.batch_size == 1)

    _, _ = initialize_logging(
        logging_dir_path=args.save_dir,
        logging_file_name=args.logging_file_name,
        script_args=args,
        log_packages=args.log_packages,
        log_pip_packages=args.log_pip_packages)

    ds_metainfo = get_dataset_metainfo(dataset_name=args.dataset)
    ds_metainfo.update(args=args)

    # Only the context is used; the returned batch size is ignored.
    ctx, _ = prepare_mx_context(
        num_gpus=args.num_gpus,
        batch_size=args.batch_size)

    net = prepare_model(
        model_name=args.model,
        use_pretrained=args.use_pretrained,
        pretrained_model_file_path=args.resume.strip(),
        dtype=args.dtype,
        net_extra_kwargs=None,
        load_ignore_extra=False,
        classes=args.num_classes,
        in_channels=args.in_channels,
        do_hybridize=False,
        ctx=ctx)

    test_data = get_val_data_source(
        ds_metainfo=ds_metainfo,
        batch_size=args.batch_size,
        num_workers=args.num_workers)

    start_time = time.time()
    for batch in test_data:
        # NOTE(review): ``batch_fn`` is not defined in this function; it presumably
        # comes from module scope. Confirm it exists, otherwise this raises NameError.
        src_images, _, _ = batch_fn(batch, ctx)
        outputs = [net(x) for x in src_images]
        assert (outputs is not None)
    # NOTE(review): MXNet executes operators asynchronously and the outputs are never
    # read here, so this may measure enqueue time rather than compute time. Consider
    # synchronizing (e.g. waiting on the outputs) before reading the clock.
    elapsed = time.time() - start_time
    logging.info("Time cost: {:.4f} sec".format(elapsed))
예제 #6
0
def main():
    """
    Evaluate the selected model on the validation data of ``args.dataset``.

    FLOP counting can be requested through ``args.calc_flops`` / ``args.calc_flops_only``.
    """
    args = parse_args()

    _, _ = initialize_logging(
        logging_dir_path=args.save_dir,
        logging_file_name=args.logging_file_name,
        script_args=args,
        log_packages=args.log_packages,
        log_pip_packages=args.log_pip_packages)

    ctx, batch_size = prepare_mx_context(
        num_gpus=args.num_gpus,
        batch_size=args.batch_size)

    net = prepare_model(
        model_name=args.model,
        use_pretrained=args.use_pretrained,
        pretrained_model_file_path=args.resume.strip(),
        dtype=args.dtype,
        classes=args.num_classes,
        in_channels=args.in_channels,
        # Hybridization is skipped when FLOPs are to be counted.
        do_hybridize=(not args.calc_flops),
        ctx=ctx)
    # Networks without an ``in_size`` attribute default to a (32, 32) input size.
    input_image_size = getattr(net, 'in_size', (32, 32))

    val_data = get_val_data_source(
        dataset_name=args.dataset,
        dataset_dir=args.data_dir,
        batch_size=batch_size,
        num_workers=args.num_workers)

    # Weights must come from somewhere unless only FLOPs are being counted.
    assert (args.use_pretrained or args.resume.strip() or args.calc_flops_only)
    test(
        net=net,
        val_data=val_data,
        data_source_needs_reset=False,
        dtype=args.dtype,
        ctx=ctx,
        input_image_size=input_image_size,
        in_channels=args.in_channels,
        calc_weight_count=True,
        calc_flops=args.calc_flops,
        calc_flops_only=args.calc_flops_only,
        extended_log=True)
예제 #7
0
def _prepare_teacher_and_discriminator(args, net, ds_metainfo, ctx):
    """
    Build the frozen teacher ensemble and, optionally, the discriminator.

    Parameters:
    ----------
    args : object
        Parsed script arguments (reads ``teacher_models``, ``dtype``,
        ``not_hybridize``, ``not_discriminator``, ``dlr_factor``).
    net : Block
        Student network; every teacher must match its ``classes`` and ``in_size``.
    ds_metainfo : object
        Dataset metainfo supplying ``train_net_extra_kwargs``.
    ctx : list
        MXNet context(s) the networks are placed on.

    Returns:
    -------
    tuple
        ``(teacher_net, discrim_net, discrim_loss_func)``. Each item is ``None``
        when it is not used.
    """
    # Skip empty entries (e.g. from a trailing comma) instead of requesting a
    # model named "".
    teacher_names = [name.strip() for name in args.teacher_models.split(",")]
    teacher_names = [name for name in teacher_names if name]

    teacher_nets = []
    for teacher_name in teacher_names:
        teacher = prepare_model(
            model_name=teacher_name,
            use_pretrained=True,
            pretrained_model_file_path="",
            dtype=args.dtype,
            net_extra_kwargs=ds_metainfo.train_net_extra_kwargs,
            do_hybridize=(not args.not_hybridize),
            ctx=ctx)
        # Every teacher must agree with the student on class count and input size.
        assert (teacher.classes == net.classes)
        assert (teacher.in_size == net.in_size)
        teacher_nets.append(teacher)

    if not teacher_nets:
        return None, None, None

    teacher_net = Concurrent(stack=True,
                             prefix="",
                             branches=teacher_nets)
    # Freeze the teachers: no gradients are requested for their parameters.
    for _, param in teacher_net.collect_params().items():
        param.grad_req = "null"

    discrim_net = None
    discrim_loss_func = None
    if not args.not_discriminator:
        discrim_net = MealDiscriminator()
        discrim_net.cast(args.dtype)
        if not args.not_hybridize:
            discrim_net.hybridize(static_alloc=True, static_shape=True)
        discrim_net.initialize(mx.init.MSRAPrelu(), ctx=ctx)
        # Scale the discriminator's learning rate by ``args.dlr_factor``.
        for _, param in discrim_net.collect_params().items():
            param.lr_mult = args.dlr_factor
        discrim_loss_func = MealAdvLoss()
    return teacher_net, discrim_net, discrim_loss_func


def main():
    """
    Main body of script.

    Trains ``args.model`` on ``args.dataset``. When ``args.teacher_models`` (a
    comma-separated list) is non-empty, the pretrained teachers are stacked into a
    frozen ensemble. Unless ``args.not_discriminator`` is set, a discriminator
    is also created for adversarial training.
    """
    args = parse_args()
    args.seed = init_rand(seed=args.seed)

    _, _ = initialize_logging(
        logging_dir_path=args.save_dir,
        logging_file_name=args.logging_file_name,
        script_args=args,
        log_packages=args.log_packages,
        log_pip_packages=args.log_pip_packages)

    ctx, batch_size = prepare_mx_context(num_gpus=args.num_gpus,
                                         batch_size=args.batch_size)

    ds_metainfo = get_dataset_metainfo(dataset_name=args.dataset)
    ds_metainfo.update(args=args)

    use_teacher = (args.teacher_models is not None) and (args.teacher_models.strip() != "")

    net = prepare_model(
        model_name=args.model,
        use_pretrained=args.use_pretrained,
        pretrained_model_file_path=args.resume.strip(),
        dtype=args.dtype,
        net_extra_kwargs=ds_metainfo.train_net_extra_kwargs,
        tune_layers=args.tune_layers,
        classes=args.num_classes,
        in_channels=args.in_channels,
        do_hybridize=(not args.not_hybridize),
        initializer=get_initializer(initializer_name=args.initializer),
        ctx=ctx)
    assert (hasattr(net, "classes"))
    num_classes = net.classes

    if use_teacher:
        teacher_net, discrim_net, discrim_loss_func = _prepare_teacher_and_discriminator(
            args=args,
            net=net,
            ds_metainfo=ds_metainfo,
            ctx=ctx)
    else:
        teacher_net, discrim_net, discrim_loss_func = None, None, None

    train_data = get_train_data_source(ds_metainfo=ds_metainfo,
                                       batch_size=batch_size,
                                       num_workers=args.num_workers)
    val_data = get_val_data_source(ds_metainfo=ds_metainfo,
                                   batch_size=batch_size,
                                   num_workers=args.num_workers)
    batch_fn = get_batch_fn(ds_metainfo=ds_metainfo)

    # RecordIO sources expose no dataset object, so use the count from the metainfo.
    if ds_metainfo.use_imgrec:
        num_training_samples = ds_metainfo.num_training_samples
    else:
        num_training_samples = len(train_data._dataset)

    trainer, lr_scheduler = prepare_trainer(
        net=net,
        optimizer_name=args.optimizer_name,
        wd=args.wd,
        momentum=args.momentum,
        lr_mode=args.lr_mode,
        lr=args.lr,
        lr_decay_period=args.lr_decay_period,
        lr_decay_epoch=args.lr_decay_epoch,
        lr_decay=args.lr_decay,
        target_lr=args.target_lr,
        poly_power=args.poly_power,
        warmup_epochs=args.warmup_epochs,
        warmup_lr=args.warmup_lr,
        warmup_mode=args.warmup_mode,
        batch_size=batch_size,
        num_epochs=args.num_epochs,
        num_training_samples=num_training_samples,
        dtype=args.dtype,
        gamma_wd_mult=args.gamma_wd_mult,
        beta_wd_mult=args.beta_wd_mult,
        bias_wd_mult=args.bias_wd_mult,
        state_file_path=args.resume_state)

    if args.save_dir and args.save_interval:
        param_names = ds_metainfo.val_metric_capts + ds_metainfo.train_metric_capts + [
            "Train.Loss", "LR"
        ]
        lp_saver = TrainLogParamSaver(
            checkpoint_file_name_prefix="{}_{}".format(ds_metainfo.short_label,
                                                       args.model),
            last_checkpoint_file_name_suffix="last",
            best_checkpoint_file_name_suffix=None,
            last_checkpoint_dir_path=args.save_dir,
            best_checkpoint_dir_path=None,
            last_checkpoint_file_count=2,
            best_checkpoint_file_count=2,
            checkpoint_file_save_callback=save_params,
            checkpoint_file_exts=(".params", ".states"),
            save_interval=args.save_interval,
            num_epochs=args.num_epochs,
            param_names=param_names,
            acc_ind=ds_metainfo.saver_acc_ind,
            score_log_file_path=os.path.join(args.save_dir, "score.log"),
            score_log_attempt_value=args.attempt,
            best_map_log_file_path=os.path.join(args.save_dir, "best_map.log"))
    else:
        lp_saver = None

    val_metric = get_composite_metric(ds_metainfo.val_metric_names,
                                      ds_metainfo.val_metric_extra_kwargs)
    train_metric = get_composite_metric(ds_metainfo.train_metric_names,
                                        ds_metainfo.train_metric_extra_kwargs)
    loss_metrics = [LossValue(name="loss"), LossValue(name="dloss")]

    # Sparse (integer) labels are used unless mixup, label smoothing or a teacher
    # ensemble is active.
    loss_kwargs = {
        "sparse_label": (not (args.mixup or args.label_smoothing)
                         and not (use_teacher and (teacher_net is not None)))
    }
    if ds_metainfo.loss_extra_kwargs is not None:
        loss_kwargs.update(ds_metainfo.loss_extra_kwargs)
    loss_func = get_loss(ds_metainfo.loss_name, loss_kwargs)

    train_net(batch_size=batch_size,
              num_epochs=args.num_epochs,
              start_epoch1=args.start_epoch,
              train_data=train_data,
              val_data=val_data,
              batch_fn=batch_fn,
              data_source_needs_reset=ds_metainfo.use_imgrec,
              dtype=args.dtype,
              net=net,
              teacher_net=teacher_net,
              discrim_net=discrim_net,
              trainer=trainer,
              lr_scheduler=lr_scheduler,
              lp_saver=lp_saver,
              log_interval=args.log_interval,
              mixup=args.mixup,
              mixup_epoch_tail=args.mixup_epoch_tail,
              label_smoothing=args.label_smoothing,
              num_classes=num_classes,
              grad_clip_value=args.grad_clip,
              batch_size_scale=args.batch_size_scale,
              val_metric=val_metric,
              train_metric=train_metric,
              loss_metrics=loss_metrics,
              loss_func=loss_func,
              discrim_loss_func=discrim_loss_func,
              ctx=ctx)
예제 #8
0
파일: eval_gl.py 프로젝트: siddie/imgclsmob
def test_model(args):
    """
    Evaluate a model on the validation or test subset and return its main score.

    Parameters:
    ----------
    args : ArgumentParser
        Main script arguments.

    Returns:
    -------
    float or None
        The metric value at index ``saver_acc_ind`` of the dataset metainfo, or
        ``None`` when no metric values were produced.
    """
    ds_metainfo = get_dataset_metainfo(dataset_name=args.dataset)
    ds_metainfo.update(args=args)

    is_imgseg = (ds_metainfo.ml_type == "imgseg")
    # Segmentation on the test subset only supports single-sample batches.
    assert (not is_imgseg) or (args.data_subset != "test") or (args.batch_size == 1)
    # Segmentation evaluation requires cuDNN autotuning to be disabled.
    assert (not is_imgseg) or args.disable_cudnn_autotune

    # Only the context is used; the returned batch size is ignored.
    ctx, _ = prepare_mx_context(num_gpus=args.num_gpus,
                                batch_size=args.batch_size)

    net = prepare_model(
        model_name=args.model,
        use_pretrained=args.use_pretrained,
        pretrained_model_file_path=args.resume.strip(),
        dtype=args.dtype,
        net_extra_kwargs=ds_metainfo.test_net_extra_kwargs,
        load_ignore_extra=ds_metainfo.load_ignore_extra,
        # "hpe" models are built without an explicit class count.
        classes=(None if ds_metainfo.ml_type == "hpe" else args.num_classes),
        in_channels=args.in_channels,
        do_hybridize=(ds_metainfo.allow_hybridize and (not args.calc_flops)),
        ctx=ctx)
    assert (hasattr(net, "in_size"))
    input_image_size = net.in_size

    use_val_subset = (args.data_subset == "val")
    if use_val_subset:
        make_test_data = get_val_data_source
    else:
        make_test_data = get_test_data_source
    test_data = make_test_data(ds_metainfo=ds_metainfo,
                               batch_size=args.batch_size,
                               num_workers=args.num_workers)
    batch_fn = get_batch_fn(ds_metainfo=ds_metainfo)

    if use_val_subset:
        metric_names = ds_metainfo.val_metric_names
        metric_extra_kwargs = ds_metainfo.val_metric_extra_kwargs
    else:
        metric_names = ds_metainfo.test_metric_names
        metric_extra_kwargs = ds_metainfo.test_metric_extra_kwargs
    test_metric = get_composite_metric(
        metric_names=metric_names,
        metric_extra_kwargs=metric_extra_kwargs)

    if not args.not_show_progress:
        from tqdm import tqdm
        test_data = tqdm(test_data)

    # Weights must come from somewhere unless only FLOPs are being counted.
    assert (args.use_pretrained or args.resume.strip() or args.calc_flops_only)
    acc_values = calc_model_accuracy(
        net=net,
        test_data=test_data,
        batch_fn=batch_fn,
        data_source_needs_reset=ds_metainfo.use_imgrec,
        metric=test_metric,
        dtype=args.dtype,
        ctx=ctx,
        input_image_size=input_image_size,
        in_channels=args.in_channels,
        calc_weight_count=True,
        calc_flops=args.calc_flops,
        calc_flops_only=args.calc_flops_only,
        extended_log=True)
    if len(acc_values) == 0:
        return None
    return acc_values[ds_metainfo.saver_acc_ind]
예제 #9
0
def main():
    """
    Main body of script.

    Evaluates ``args.model`` on ``args.dataset``. When ``args.data_subset`` is
    "val" the validation source and metrics are used; otherwise the test ones.
    """
    args = parse_args()

    _, _ = initialize_logging(
        logging_dir_path=args.save_dir,
        logging_file_name=args.logging_file_name,
        script_args=args,
        log_packages=args.log_packages,
        log_pip_packages=args.log_pip_packages)

    ds_metainfo = get_dataset_metainfo(dataset_name=args.dataset)
    ds_metainfo.update(args=args)

    # Only the context is used; the returned batch size is ignored.
    ctx, _ = prepare_mx_context(
        num_gpus=args.num_gpus,
        batch_size=args.batch_size)

    net = prepare_model(
        model_name=args.model,
        use_pretrained=args.use_pretrained,
        pretrained_model_file_path=args.resume.strip(),
        dtype=args.dtype,
        net_extra_kwargs=ds_metainfo.net_extra_kwargs,
        load_ignore_extra=ds_metainfo.load_ignore_extra,
        classes=args.num_classes,
        in_channels=args.in_channels,
        do_hybridize=(ds_metainfo.allow_hybridize and (not args.calc_flops)),
        ctx=ctx)
    assert (hasattr(net, "in_size"))
    input_image_size = net.in_size

    if args.data_subset == "val":
        data_source_factory = get_val_data_source
        metric_names = ds_metainfo.val_metric_names
        metric_extra_kwargs = ds_metainfo.val_metric_extra_kwargs
    else:
        data_source_factory = get_test_data_source
        metric_names = ds_metainfo.test_metric_names
        metric_extra_kwargs = ds_metainfo.test_metric_extra_kwargs
    test_metric = get_composite_metric(
        metric_names=metric_names,
        metric_extra_kwargs=metric_extra_kwargs)
    test_data = data_source_factory(
        ds_metainfo=ds_metainfo,
        batch_size=args.batch_size,
        num_workers=args.num_workers)
    batch_fn = get_batch_fn(use_imgrec=ds_metainfo.use_imgrec)

    if not args.not_show_progress:
        test_data = tqdm(test_data)

    # Weights must come from somewhere unless only FLOPs are being counted.
    assert (args.use_pretrained or args.resume.strip() or args.calc_flops_only)
    test(
        net=net,
        test_data=test_data,
        batch_fn=batch_fn,
        data_source_needs_reset=ds_metainfo.use_imgrec,
        metric=test_metric,
        dtype=args.dtype,
        ctx=ctx,
        input_image_size=input_image_size,
        in_channels=args.in_channels,
        calc_weight_count=True,
        calc_flops=args.calc_flops,
        calc_flops_only=args.calc_flops_only,
        extended_log=True,
        show_bad_samples=args.show_bad_samples)
예제 #10
0
def main():
    """
    Train ``args.model`` on the dataset named by ``args.dataset``.

    Parses the command line and stores the value returned by ``init_rand`` back
    into ``args.seed``. It then prepares the model, data sources, trainer and an
    optional checkpoint saver (when both ``args.save_dir`` and
    ``args.save_interval`` are set), and runs ``train_net``.
    """
    args = parse_args()
    args.seed = init_rand(seed=args.seed)

    _, _ = initialize_logging(
        logging_dir_path=args.save_dir,
        logging_file_name=args.logging_file_name,
        script_args=args,
        log_packages=args.log_packages,
        log_pip_packages=args.log_pip_packages)

    ctx, batch_size = prepare_mx_context(num_gpus=args.num_gpus,
                                         batch_size=args.batch_size)

    net = prepare_model(model_name=args.model,
                        use_pretrained=args.use_pretrained,
                        pretrained_model_file_path=args.resume.strip(),
                        dtype=args.dtype,
                        tune_layers=args.tune_layers,
                        classes=args.num_classes,
                        in_channels=args.in_channels,
                        do_hybridize=(not args.not_hybridize),
                        ctx=ctx)

    # The class count is taken from the network itself; a network without a
    # ``classes`` attribute is not supported.
    assert (hasattr(net, 'classes'))
    num_classes = net.classes

    train_data = get_train_data_source(dataset_name=args.dataset,
                                       dataset_dir=args.data_dir,
                                       batch_size=batch_size,
                                       num_workers=args.num_workers)
    val_data = get_val_data_source(dataset_name=args.dataset,
                                   dataset_dir=args.data_dir,
                                   batch_size=batch_size,
                                   num_workers=args.num_workers)

    trainer, lr_scheduler = prepare_trainer(
        net=net,
        optimizer_name=args.optimizer_name,
        wd=args.wd,
        momentum=args.momentum,
        lr_mode=args.lr_mode,
        lr=args.lr,
        lr_decay_period=args.lr_decay_period,
        lr_decay_epoch=args.lr_decay_epoch,
        lr_decay=args.lr_decay,
        target_lr=args.target_lr,
        poly_power=args.poly_power,
        warmup_epochs=args.warmup_epochs,
        warmup_lr=args.warmup_lr,
        warmup_mode=args.warmup_mode,
        batch_size=batch_size,
        num_epochs=args.num_epochs,
        num_training_samples=get_num_training_samples(args.dataset),
        dtype=args.dtype,
        gamma_wd_mult=args.gamma_wd_mult,
        beta_wd_mult=args.beta_wd_mult,
        bias_wd_mult=args.bias_wd_mult,
        state_file_path=args.resume_state)

    if args.save_dir and args.save_interval:
        lp_saver = TrainLogParamSaver(
            checkpoint_file_name_prefix='{}_{}'.format(args.dataset.lower(),
                                                       args.model),
            last_checkpoint_file_name_suffix="last",
            best_checkpoint_file_name_suffix=None,
            last_checkpoint_dir_path=args.save_dir,
            best_checkpoint_dir_path=None,
            last_checkpoint_file_count=2,
            best_checkpoint_file_count=2,
            checkpoint_file_save_callback=save_params,
            checkpoint_file_exts=('.params', '.states'),
            save_interval=args.save_interval,
            num_epochs=args.num_epochs,
            param_names=['Val.Err', 'Train.Err', 'Train.Loss', 'LR'],
            acc_ind=0,
            score_log_file_path=os.path.join(args.save_dir, 'score.log'),
            score_log_attempt_value=args.attempt,
            best_map_log_file_path=os.path.join(args.save_dir, 'best_map.log'))
    else:
        lp_saver = None

    train_net(batch_size=batch_size,
              num_epochs=args.num_epochs,
              start_epoch1=args.start_epoch,
              train_data=train_data,
              val_data=val_data,
              data_source_needs_reset=False,
              dtype=args.dtype,
              net=net,
              trainer=trainer,
              lr_scheduler=lr_scheduler,
              lp_saver=lp_saver,
              log_interval=args.log_interval,
              mixup=args.mixup,
              mixup_epoch_tail=args.mixup_epoch_tail,
              label_smoothing=args.label_smoothing,
              num_classes=num_classes,
              grad_clip_value=args.grad_clip,
              batch_size_scale=args.batch_size_scale,
              ctx=ctx)
예제 #11
0
def main():
    """
    Main body of script.

    Trains ``args.model`` on ``args.dataset`` using the dataset metainfo to choose
    data sources, batch function, metrics and checkpoint naming.
    """
    args = parse_args()
    args.seed = init_rand(seed=args.seed)

    _, _ = initialize_logging(
        logging_dir_path=args.save_dir,
        logging_file_name=args.logging_file_name,
        script_args=args,
        log_packages=args.log_packages,
        log_pip_packages=args.log_pip_packages)

    ctx, batch_size = prepare_mx_context(
        num_gpus=args.num_gpus,
        batch_size=args.batch_size)

    initializer = get_initializer(initializer_name=args.initializer)
    net = prepare_model(
        model_name=args.model,
        use_pretrained=args.use_pretrained,
        pretrained_model_file_path=args.resume.strip(),
        dtype=args.dtype,
        tune_layers=args.tune_layers,
        classes=args.num_classes,
        in_channels=args.in_channels,
        do_hybridize=(not args.not_hybridize),
        initializer=initializer,
        ctx=ctx)
    assert (hasattr(net, "classes"))
    num_classes = net.classes

    ds_metainfo = get_dataset_metainfo(dataset_name=args.dataset)
    ds_metainfo.update(args=args)

    loader_kwargs = dict(ds_metainfo=ds_metainfo,
                         batch_size=batch_size,
                         num_workers=args.num_workers)
    train_data = get_train_data_source(**loader_kwargs)
    val_data = get_val_data_source(**loader_kwargs)
    batch_fn = get_batch_fn(use_imgrec=ds_metainfo.use_imgrec)

    # RecordIO sources expose no dataset object, so use the count from the metainfo.
    if ds_metainfo.use_imgrec:
        num_training_samples = ds_metainfo.num_training_samples
    else:
        num_training_samples = len(train_data._dataset)

    trainer, lr_scheduler = prepare_trainer(
        net=net,
        optimizer_name=args.optimizer_name,
        wd=args.wd,
        momentum=args.momentum,
        lr_mode=args.lr_mode,
        lr=args.lr,
        lr_decay_period=args.lr_decay_period,
        lr_decay_epoch=args.lr_decay_epoch,
        lr_decay=args.lr_decay,
        target_lr=args.target_lr,
        poly_power=args.poly_power,
        warmup_epochs=args.warmup_epochs,
        warmup_lr=args.warmup_lr,
        warmup_mode=args.warmup_mode,
        batch_size=batch_size,
        num_epochs=args.num_epochs,
        num_training_samples=num_training_samples,
        dtype=args.dtype,
        gamma_wd_mult=args.gamma_wd_mult,
        beta_wd_mult=args.beta_wd_mult,
        bias_wd_mult=args.bias_wd_mult,
        state_file_path=args.resume_state)

    lp_saver = None
    if args.save_dir and args.save_interval:
        param_names = (ds_metainfo.val_metric_capts
                       + ds_metainfo.train_metric_capts
                       + ["Train.Loss", "LR"])
        checkpoint_prefix = "{}_{}".format(ds_metainfo.short_label, args.model)
        lp_saver = TrainLogParamSaver(
            checkpoint_file_name_prefix=checkpoint_prefix,
            last_checkpoint_file_name_suffix="last",
            best_checkpoint_file_name_suffix=None,
            last_checkpoint_dir_path=args.save_dir,
            best_checkpoint_dir_path=None,
            last_checkpoint_file_count=2,
            best_checkpoint_file_count=2,
            checkpoint_file_save_callback=save_params,
            checkpoint_file_exts=(".params", ".states"),
            save_interval=args.save_interval,
            num_epochs=args.num_epochs,
            param_names=param_names,
            acc_ind=ds_metainfo.saver_acc_ind,
            score_log_file_path=os.path.join(args.save_dir, "score.log"),
            score_log_attempt_value=args.attempt,
            best_map_log_file_path=os.path.join(args.save_dir, "best_map.log"))

    val_metric = get_composite_metric(
        ds_metainfo.val_metric_names,
        ds_metainfo.val_metric_extra_kwargs)
    train_metric = get_composite_metric(
        ds_metainfo.train_metric_names,
        ds_metainfo.train_metric_extra_kwargs)

    train_net(batch_size=batch_size,
              num_epochs=args.num_epochs,
              start_epoch1=args.start_epoch,
              train_data=train_data,
              val_data=val_data,
              batch_fn=batch_fn,
              data_source_needs_reset=ds_metainfo.use_imgrec,
              dtype=args.dtype,
              net=net,
              trainer=trainer,
              lr_scheduler=lr_scheduler,
              lp_saver=lp_saver,
              log_interval=args.log_interval,
              mixup=args.mixup,
              mixup_epoch_tail=args.mixup_epoch_tail,
              label_smoothing=args.label_smoothing,
              num_classes=num_classes,
              grad_clip_value=args.grad_clip,
              batch_size_scale=args.batch_size_scale,
              val_metric=val_metric,
              train_metric=train_metric,
              ctx=ctx)
예제 #12
0
def main():
    """
    Entry point: train an ImageNet classifier with Gluon.

    Parses the CLI arguments and seeds the RNGs, initializes logging, selects
    the compute context, builds the network, creates the train/val data
    pipelines (RecordIO or image folders depending on ``args.use_rec``),
    builds the trainer and LR scheduler, optionally sets up the checkpoint /
    score-log saver, and finally runs ``train_net``.
    """
    args = parse_args()
    args.seed = init_rand(seed=args.seed)

    # Neither of the two returned values is needed by the training script.
    _, _ = initialize_logging(
        logging_dir_path=args.save_dir,
        logging_file_name=args.logging_file_name,
        script_args=args,
        log_packages=args.log_packages,
        log_pip_packages=args.log_pip_packages)

    ctx, batch_size = prepare_mx_context(num_gpus=args.num_gpus, batch_size=args.batch_size)

    net = prepare_model(
        model_name=args.model,
        use_pretrained=args.use_pretrained,
        pretrained_model_file_path=args.resume.strip(),
        dtype=args.dtype,
        tune_layers=args.tune_layers,
        ctx=ctx)
    # Use the network's own metadata when it exposes it; otherwise fall back
    # to 1000 classes and 224x224 inputs.
    num_classes = getattr(net, "classes", 1000)
    input_image_size = getattr(net, "in_size", (224, 224))

    # Keyword arguments common to both data-pipeline builders.
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=args.num_workers,
        input_image_size=input_image_size)
    if not args.use_rec:
        train_data, val_data, batch_fn = get_data_loader(
            data_dir=args.data_dir,
            **loader_kwargs)
    else:
        train_data, val_data, batch_fn = get_data_rec(
            rec_train=args.rec_train,
            rec_train_idx=args.rec_train_idx,
            rec_val=args.rec_val,
            rec_val_idx=args.rec_val_idx,
            **loader_kwargs)

    # Hard-coded training-set size passed to the trainer/LR-scheduler setup
    # (the checkpoint prefix below marks this as an ImageNet script).
    imagenet_train_size = 1281167
    trainer, lr_scheduler = prepare_trainer(
        net=net,
        optimizer_name=args.optimizer_name,
        wd=args.wd,
        momentum=args.momentum,
        lr_mode=args.lr_mode,
        lr=args.lr,
        lr_decay_period=args.lr_decay_period,
        lr_decay_epoch=args.lr_decay_epoch,
        lr_decay=args.lr_decay,
        target_lr=args.target_lr,
        poly_power=args.poly_power,
        warmup_epochs=args.warmup_epochs,
        warmup_lr=args.warmup_lr,
        warmup_mode=args.warmup_mode,
        batch_size=batch_size,
        num_epochs=args.num_epochs,
        num_training_samples=imagenet_train_size,
        dtype=args.dtype,
        state_file_path=args.resume_state)

    # Checkpointing/score logging is enabled only when both a directory and
    # a save interval were given.
    lp_saver = None
    if args.save_dir and args.save_interval:
        lp_saver = TrainLogParamSaver(
            checkpoint_file_name_prefix='imagenet_{}'.format(args.model),
            last_checkpoint_file_name_suffix="last",
            best_checkpoint_file_name_suffix=None,
            last_checkpoint_dir_path=args.save_dir,
            best_checkpoint_dir_path=None,
            last_checkpoint_file_count=2,
            best_checkpoint_file_count=2,
            checkpoint_file_save_callback=save_params,
            checkpoint_file_exts=('.params', '.states'),
            save_interval=args.save_interval,
            num_epochs=args.num_epochs,
            param_names=['Val.Top1', 'Train.Top1', 'Val.Top5', 'Train.Loss', 'LR'],
            # Index into param_names above (i.e. 'Val.Top5').
            acc_ind=2,
            score_log_file_path=os.path.join(args.save_dir, 'score.log'),
            score_log_attempt_value=args.attempt,
            best_map_log_file_path=os.path.join(args.save_dir, 'best_map.log'))

    train_net(
        batch_size=batch_size,
        num_epochs=args.num_epochs,
        start_epoch1=args.start_epoch,
        train_data=train_data,
        val_data=val_data,
        batch_fn=batch_fn,
        use_rec=args.use_rec,
        dtype=args.dtype,
        net=net,
        trainer=trainer,
        lr_scheduler=lr_scheduler,
        lp_saver=lp_saver,
        log_interval=args.log_interval,
        mixup=args.mixup,
        mixup_epoch_tail=args.mixup_epoch_tail,
        num_classes=num_classes,
        ctx=ctx)
예제 #13
0
def main():
    """
    Entry point: evaluate a Gluon model on a dataset's validation or test split.

    Parses the CLI arguments, optionally disables MXNet cuDNN autotuning,
    initializes logging and the dataset meta-info, validates the argument
    combination, builds the network and data source, and runs ``test``. That
    call always reports the weight count and, when requested, FLOPs.

    All argument validation happens before the model and data are loaded, so
    a misconfigured run fails immediately instead of after expensive setup.
    """
    args = parse_args()

    # Turn off MXNet's cuDNN autotuning through its environment variable.
    if args.disable_cudnn_autotune:
        os.environ["MXNET_CUDNN_AUTOTUNE_DEFAULT"] = "0"

    initialize_logging(
        logging_dir_path=args.save_dir,
        logging_file_name=args.logging_file_name,
        script_args=args,
        log_packages=args.log_packages,
        log_pip_packages=args.log_pip_packages)

    ds_metainfo = get_dataset_metainfo(dataset_name=args.dataset)
    ds_metainfo.update(args=args)

    # --- Argument validation (fail fast, before model/data loading). ---
    assert (ds_metainfo.ml_type != "imgseg") or (args.batch_size == 1), \
        "imgseg evaluation requires batch_size == 1"
    assert (ds_metainfo.ml_type != "imgseg") or args.disable_cudnn_autotune, \
        "imgseg evaluation requires disable_cudnn_autotune"
    # Evaluation needs trained weights, unless only FLOPs are being counted.
    assert (args.use_pretrained or args.resume.strip() or args.calc_flops_only), \
        "either use_pretrained, resume or calc_flops_only must be set"

    ctx, batch_size = prepare_mx_context(
        num_gpus=args.num_gpus,
        batch_size=args.batch_size)

    net = prepare_model(
        model_name=args.model,
        use_pretrained=args.use_pretrained,
        pretrained_model_file_path=args.resume.strip(),
        dtype=args.dtype,
        net_extra_kwargs=ds_metainfo.net_extra_kwargs,
        load_ignore_extra=ds_metainfo.load_ignore_extra,
        classes=args.num_classes,
        in_channels=args.in_channels,
        # Hybridization is skipped when FLOPs are to be counted.
        do_hybridize=(ds_metainfo.allow_hybridize and (not args.calc_flops)),
        ctx=ctx)
    assert (hasattr(net, "in_size"))
    input_image_size = net.in_size

    if args.data_subset == "val":
        test_data = get_val_data_source(
            ds_metainfo=ds_metainfo,
            batch_size=batch_size,
            num_workers=args.num_workers)
        test_metric = get_composite_metric(
            metric_names=ds_metainfo.val_metric_names,
            metric_extra_kwargs=ds_metainfo.val_metric_extra_kwargs)
    else:
        test_data = get_test_data_source(
            ds_metainfo=ds_metainfo,
            batch_size=batch_size,
            num_workers=args.num_workers)
        test_metric = get_composite_metric(
            metric_names=ds_metainfo.test_metric_names,
            metric_extra_kwargs=ds_metainfo.test_metric_extra_kwargs)
    batch_fn = get_batch_fn(use_imgrec=ds_metainfo.use_imgrec)

    if args.show_progress:
        # Optional dependency: only imported when a progress bar is requested.
        from tqdm import tqdm
        test_data = tqdm(test_data)

    test(
        net=net,
        test_data=test_data,
        batch_fn=batch_fn,
        data_source_needs_reset=ds_metainfo.use_imgrec,
        metric=test_metric,
        dtype=args.dtype,
        ctx=ctx,
        input_image_size=input_image_size,
        in_channels=args.in_channels,
        # Always report the parameter count, even when a log file already exists.
        calc_weight_count=True,
        calc_flops=args.calc_flops,
        calc_flops_only=args.calc_flops_only,
        extended_log=True)