Code Example #1
def test_model():
    """Evaluates a trained model."""
    # Setup training/testing environment
    setup_env()
    # Construct the model
    model = setup_model()

    # Load model weights
    if cfg.TEST.WEIGHTS:
        checkpoint.load_checkpoint(cfg.TEST.WEIGHTS, model)
        logger.info("Loaded model weights from: {}".format(cfg.TEST.WEIGHTS))
    elif checkpoint.has_checkpoint():
        last_checkpoint = checkpoint.get_last_checkpoint()
        checkpoint.load_checkpoint(last_checkpoint, model)
        logger.info("Loaded checkpoint from: {}".format(last_checkpoint))
    else:
        logger.error("No model weights or checkpoint found to evaluate")
        os._exit(1)
    # Create data loaders and meters
    test_loader = loader.construct_test_loader()
    test_meter = meters.TestMeter(len(test_loader))
    # Evaluate the model
    # Both branches ran the same routine, so the task check is unnecessary:
    # test_epoch_semi handles 'psd', 'fix', and the default case alike
    result, ce_error = test_epoch_semi(test_loader, model, test_meter, 0)
    # Write the top-1 error and cross-entropy errors to the output directory
    with open(os.path.join(cfg.OUT_DIR, 'result.txt'), 'w') as f:
        f.write(str(result["top1_err"]) + '\n')
        f.write(str(ce_error[0]) + '\n')
        f.write(str(ce_error[1]) + '\n')
    print(result["top1_err"], ce_error)
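
The `checkpoint` helpers used above (`has_checkpoint`, `get_last_checkpoint`, `load_checkpoint`) come from pycls. A minimal sketch of their assumed file-based semantics, with a hypothetical checkpoint directory and the usual pycls state-dict keys, might look like this:

import os
import torch

# Hypothetical checkpoint directory; pycls derives this from cfg.OUT_DIR
_CKPT_DIR = "./checkpoints"

def has_checkpoint():
    """Checks whether any saved checkpoint file exists."""
    return os.path.isdir(_CKPT_DIR) and any(
        f.endswith(".pyth") for f in os.listdir(_CKPT_DIR))

def get_last_checkpoint():
    """Returns the path of the most recent (highest-epoch) checkpoint."""
    names = sorted(f for f in os.listdir(_CKPT_DIR) if f.endswith(".pyth"))
    return os.path.join(_CKPT_DIR, names[-1])

def load_checkpoint(path, model, optimizer=None):
    """Restores model (and optionally optimizer) state; returns the saved epoch."""
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state["model_state"])
    if optimizer is not None:
        optimizer.load_state_dict(state["optimizer_state"])
    return state["epoch"]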
Code Example #2
def test_model():
    """Evaluates the model."""

    # Build the model (before the loaders to speed up debugging)
    model = model_builder.build_model()
    log_model_info(model)

    # Compute precise time
    if cfg.PREC_TIME.ENABLED:
        logger.info("Computing precise time...")
        loss_fun = losses.get_loss_fun()
        bu.compute_precise_time(model, loss_fun)
        nu.reset_bn_stats(model)

    # Load model weights
    cu.load_checkpoint(cfg.TEST.WEIGHTS, model)
    logger.info("Loaded model weights from: {}".format(cfg.TEST.WEIGHTS))

    # Create data loaders
    test_loader = loader.construct_test_loader()

    # Create meters
    test_meter = TestMeter(len(test_loader))

    # Evaluate the model
    test_epoch(test_loader, model, test_meter, 0)
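
The `nu.reset_bn_stats(model)` call is needed after the precise-time computation because that computation runs forward passes in training mode, which pollutes the BN running statistics. A minimal sketch of what such a reset amounts to, assuming standard torch.nn batch-norm layers:

import torch.nn as nn

def reset_bn_stats(model):
    """Resets the running mean/var of every BN layer to their initial values."""
    for m in model.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            m.reset_running_stats()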
Code Example #3
def train_model():
    """Trains the model."""

    # Build the model (before the loaders to speed up debugging)
    model = model_builder.build_model()
    log_model_info(model)

    # Define the loss function
    loss_fun = losses.get_loss_fun()
    # Construct the optimizer
    optimizer = optim.construct_optimizer(model)

    # Load checkpoint or initial weights
    start_epoch = 0
    if cfg.TRAIN.AUTO_RESUME and cu.has_checkpoint():
        last_checkpoint = cu.get_last_checkpoint()
        checkpoint_epoch = cu.load_checkpoint(last_checkpoint, model,
                                              optimizer)
        logger.info("Loaded checkpoint from: {}".format(last_checkpoint))
        start_epoch = checkpoint_epoch + 1
    elif cfg.TRAIN.WEIGHTS:
        cu.load_checkpoint(cfg.TRAIN.WEIGHTS, model)
        logger.info("Loaded initial weights from: {}".format(
            cfg.TRAIN.WEIGHTS))

    # Compute precise time
    if start_epoch == 0 and cfg.PREC_TIME.ENABLED:
        logger.info("Computing precise time...")
        bu.compute_precise_time(model, loss_fun)
        nu.reset_bn_stats(model)

    # Create data loaders
    train_loader = loader.construct_train_loader()
    test_loader = loader.construct_test_loader()

    # Create meters
    train_meter = TrainMeter(len(train_loader))
    test_meter = TestMeter(len(test_loader))

    # Perform the training loop
    logger.info("Start epoch: {}".format(start_epoch + 1))

    for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
        # Train for one epoch
        train_epoch(train_loader, model, loss_fun, optimizer, train_meter,
                    cur_epoch)
        # Compute precise BN stats
        if cfg.BN.USE_PRECISE_STATS:
            nu.compute_precise_bn_stats(model, train_loader)
        # Save a checkpoint
        if cu.is_checkpoint_epoch(cur_epoch):
            checkpoint_file = cu.save_checkpoint(model, optimizer, cur_epoch)
            logger.info("Wrote checkpoint to: {}".format(checkpoint_file))
        # Evaluate the model
        if is_eval_epoch(cur_epoch):
            test_epoch(test_loader, model, test_meter, cur_epoch)
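
`nu.compute_precise_bn_stats(model, train_loader)` re-estimates the BN running statistics by averaging true per-batch statistics over a number of training batches, rather than relying on the exponential moving average accumulated during training. A simplified single-GPU sketch of the idea, assuming the CUDA setup used throughout these snippets (the real pycls implementation also aggregates the stats across GPUs):

import torch
import torch.nn as nn

@torch.no_grad()
def compute_precise_bn_stats(model, loader, num_batches=200):
    """Replaces BN running stats with averages over up to num_batches batches."""
    bns = [m for m in model.modules() if isinstance(m, nn.BatchNorm2d)]
    num = min(num_batches, len(loader))
    # With momentum=1.0 each forward pass stores the exact batch statistics
    momentums = [bn.momentum for bn in bns]
    for bn in bns:
        bn.momentum = 1.0
    means = [torch.zeros_like(bn.running_mean) for bn in bns]
    variances = [torch.zeros_like(bn.running_var) for bn in bns]
    model.train()
    for i, (inputs, _) in enumerate(loader):
        if i == num:
            break
        model(inputs.cuda())
        for j, bn in enumerate(bns):
            means[j] += bn.running_mean / num
            variances[j] += bn.running_var / num
    # Write back the averaged statistics and restore the original momenta
    for j, bn in enumerate(bns):
        bn.running_mean.copy_(means[j])
        bn.running_var.copy_(variances[j])
        bn.momentum = momentums[j]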
Code Example #4
def time_model():
    """Times model and data loader."""
    # Setup training/testing environment
    setup_env()
    # Construct the model and loss_fun
    model = setup_model()
    loss_fun = builders.build_loss_fun().cuda()
    # Create data loaders
    train_loader = loader.construct_train_loader()
    test_loader = loader.construct_test_loader()
    # Compute model and loader timings
    benchmark.compute_time_full(model, loss_fun, train_loader, test_loader)
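
`benchmark.compute_time_full` measures model forward/backward times and loader iteration times. The key detail when timing GPU work is to call `torch.cuda.synchronize()` before reading the clock, since CUDA kernels run asynchronously; a minimal sketch of that pattern (not the pycls implementation):

import time
import torch

def time_fwd_bwd(model, loss_fun, inputs, labels, num_iter=30, warmup=5):
    """Returns the average forward and backward time per iteration (seconds)."""
    model.train()
    fwd = bwd = 0.0
    for i in range(warmup + num_iter):
        torch.cuda.synchronize()
        t0 = time.time()
        loss = loss_fun(model(inputs), labels)
        torch.cuda.synchronize()
        t1 = time.time()
        loss.backward()
        torch.cuda.synchronize()
        t2 = time.time()
        model.zero_grad()
        if i >= warmup:  # discard warm-up iterations (cudnn autotuning etc.)
            fwd += t1 - t0
            bwd += t2 - t1
    return fwd / num_iter, bwd / num_iter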
Code Example #5
def train_model():
    """Trains the model."""
    # Setup training/testing environment
    setup_env()
    # Construct the model, loss_fun, and optimizer
    model = setup_model()
    loss_fun = builders.build_loss_fun().cuda()
    optimizer = optim.construct_optimizer(model)
    # Load checkpoint or initial weights
    start_epoch = 0
    if cfg.TRAIN.AUTO_RESUME and checkpoint.has_checkpoint():
        last_checkpoint = checkpoint.get_last_checkpoint()
        checkpoint_epoch = checkpoint.load_checkpoint(last_checkpoint, model,
                                                      optimizer)
        logger.info("Loaded checkpoint from: {}".format(last_checkpoint))
        start_epoch = checkpoint_epoch + 1
    elif cfg.TRAIN.WEIGHTS:
        checkpoint.load_checkpoint(cfg.TRAIN.WEIGHTS, model)
        logger.info("Loaded initial weights from: {}".format(
            cfg.TRAIN.WEIGHTS))
    # Create data loaders and meters
    train_loader = loader.construct_train_loader()
    test_loader = loader.construct_test_loader()
    train_meter = meters.TrainMeter(len(train_loader))
    test_meter = meters.TestMeter(len(test_loader))
    # Compute model and loader timings
    if start_epoch == 0 and cfg.PREC_TIME.NUM_ITER > 0:
        benchmark.compute_time_full(model, loss_fun, train_loader, test_loader)
    # Perform the training loop
    logger.info("Start epoch: {}".format(start_epoch + 1))
    for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
        # Stop early once the architecture-search budget is exhausted
        if hasattr(cfg, 'search_epoch') and cur_epoch >= cfg.search_epoch:
            break
        # Train for one epoch
        train_epoch(train_loader, model, loss_fun, optimizer, train_meter,
                    cur_epoch)
        # Compute precise BN stats
        if cfg.BN.USE_PRECISE_STATS:
            net.compute_precise_bn_stats(model, train_loader)
        # Save a checkpoint
        if (cur_epoch + 1) % cfg.TRAIN.CHECKPOINT_PERIOD == 0:
            checkpoint_file = checkpoint.save_checkpoint(
                model, optimizer, cur_epoch)
            logger.info("Wrote checkpoint to: {}".format(checkpoint_file))
        # Evaluate the model
        next_epoch = cur_epoch + 1
        if next_epoch % cfg.TRAIN.EVAL_PERIOD == 0 or next_epoch == cfg.OPTIM.MAX_EPOCH:
            stats = test_epoch(test_loader, model, test_meter, cur_epoch)
            nni.report_intermediate_result(stats['top1_err'])
    nni.report_final_result(test_meter.min_top1_err)
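
This variant reports the per-epoch top-1 error to NNI so a tuner can track and early-stop trials. In a full NNI trial script the sampled hyper-parameters are usually fetched at startup as well; a minimal sketch of that handshake (parameter names and metric values here are illustrative):

import nni

def run_trial():
    # Hyper-parameters sampled by the NNI tuner for this trial
    params = nni.get_next_parameter()  # e.g. {"lr": 0.05}
    print("Sampled params:", params)
    # ... build the model with these params and train ...
    for top1_err in (30.0, 25.0, 22.5):  # dummy per-epoch metrics
        nni.report_intermediate_result(top1_err)
    # The final metric the tuner optimizes (here: the best top-1 error)
    nni.report_final_result(22.5)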
Code Example #6
def train_model():
    """Trains the model."""
    # Setup training/testing environment
    setup_env()
    # Construct the model, ema, loss_fun, and optimizer
    model = setup_model()
    ema = deepcopy(model)
    loss_fun = builders.build_loss_fun().cuda()
    optimizer = optim.construct_optimizer(model)
    # Load checkpoint or initial weights
    start_epoch = 0
    if cfg.TRAIN.AUTO_RESUME and cp.has_checkpoint():
        file = cp.get_last_checkpoint()
        epoch = cp.load_checkpoint(file, model, ema, optimizer)[0]
        logger.info("Loaded checkpoint from: {}".format(file))
        start_epoch = epoch + 1
    elif cfg.TRAIN.WEIGHTS:
        train_weights = get_weights_file(cfg.TRAIN.WEIGHTS)
        cp.load_checkpoint(train_weights, model, ema)
        logger.info("Loaded initial weights from: {}".format(train_weights))
    # Create data loaders and meters
    train_loader = data_loader.construct_train_loader()
    test_loader = data_loader.construct_test_loader()
    train_meter = meters.TrainMeter(len(train_loader))
    test_meter = meters.TestMeter(len(test_loader))
    ema_meter = meters.TestMeter(len(test_loader), "test_ema")
    # Create a GradScaler for mixed precision training
    scaler = amp.GradScaler(enabled=cfg.TRAIN.MIXED_PRECISION)
    # Compute model and loader timings
    if start_epoch == 0 and cfg.PREC_TIME.NUM_ITER > 0:
        benchmark.compute_time_full(model, loss_fun, train_loader, test_loader)
    # Perform the training loop
    logger.info("Start epoch: {}".format(start_epoch + 1))
    for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
        # Train for one epoch
        params = (train_loader, model, ema, loss_fun, optimizer, scaler,
                  train_meter)
        train_epoch(*params, cur_epoch)
        # Compute precise BN stats
        if cfg.BN.USE_PRECISE_STATS:
            net.compute_precise_bn_stats(model, train_loader)
            net.compute_precise_bn_stats(ema, train_loader)
        # Evaluate the model
        test_epoch(test_loader, model, test_meter, cur_epoch)
        test_epoch(test_loader, ema, ema_meter, cur_epoch)
        test_err = test_meter.get_epoch_stats(cur_epoch)["top1_err"]
        ema_err = ema_meter.get_epoch_stats(cur_epoch)["top1_err"]
        # Save a checkpoint
        file = cp.save_checkpoint(model, ema, optimizer, cur_epoch, test_err,
                                  ema_err)
        logger.info("Wrote checkpoint to: {}".format(file))
Code Example #7
def train_kd_model():
    """Trains the model."""
    # Setup training/testing environment
    setup_env()
    # Construct the model, loss_fun, and optimizer
    model = setup_model()
    loss_fun = builders.build_loss_fun().cuda()
    optimizer = optim.construct_optimizer(model)
    # Load checkpoint or initial weights
    start_epoch = 0
    if cfg.TRAIN.AUTO_RESUME and cp.has_checkpoint():
        file = cp.get_last_checkpoint()
        epoch = cp.load_checkpoint(file, model, optimizer)
        logger.info("Loaded checkpoint from: {}".format(file))
        start_epoch = epoch + 1
    elif cfg.TRAIN.WEIGHTS:
        cp.load_checkpoint(cfg.TRAIN.WEIGHTS, model, strict=False)
        logger.info("Loaded initial weights from: {}".format(
            cfg.TRAIN.WEIGHTS))
    # Create data loaders and meters
    train_loader = data_loader.construct_train_loader()
    test_loader = data_loader.construct_test_loader()
    train_meter = meters.TrainMeter(len(train_loader))
    test_meter = meters.TestMeter(len(test_loader))
    # Create a GradScaler for mixed precision training
    scaler = amp.GradScaler(enabled=cfg.TRAIN.MIXED_PRECISION)
    # Compute model and loader timings
    if start_epoch == 0 and cfg.PREC_TIME.NUM_ITER > 0:
        benchmark.compute_time_full(model, loss_fun, train_loader, test_loader)
    # Perform the training loop
    logger.info("Start epoch: {}".format(start_epoch + 1))
    best_err = np.inf
    # Create the teacher model
    teacher = setup_teacher_model()
    for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
        # Train for one epoch
        params = (train_loader, model, loss_fun, optimizer, scaler,
                  train_meter, teacher)
        train_kd_epoch(*params, cur_epoch)
        # Compute precise BN stats
        if cfg.BN.USE_PRECISE_STATS:
            net.compute_precise_bn_stats(model, train_loader)
        # Evaluate the model
        test_epoch(test_loader, model, test_meter, cur_epoch)
        # Check if checkpoint is best so far (note: should checkpoint meters as well)
        stats = test_meter.get_epoch_stats(cur_epoch)
        best = stats["top1_err"] <= best_err
        best_err = min(stats["top1_err"], best_err)
        # Save a checkpoint
        file = cp.save_checkpoint(model, optimizer, cur_epoch, best)
        logger.info("Wrote checkpoint to: {}".format(file))
Code Example #8
def test_model():
    """Evaluates a trained model."""
    # Setup training/testing environment
    setup_env()
    # Construct the model
    model = setup_model()
    # Load model weights
    checkpoint.load_checkpoint(cfg.TEST.WEIGHTS, model)
    logger.info("Loaded model weights from: {}".format(cfg.TEST.WEIGHTS))
    # Create data loaders and meters
    test_loader = loader.construct_test_loader()
    test_meter = meters.TestMeter(len(test_loader))
    # Evaluate the model
    test_epoch(test_loader, model, test_meter, 0)
Code Example #9
def train_model():
    """Trains the model."""
    # Setup training/testing environment
    setup_env()
    # Construct the model, loss_fun, and optimizer
    model = setup_model()
    loss_fun = builders.build_loss_fun().cuda()
    optimizer = optim.construct_optimizer(model)
    # Load checkpoint or initial weights
    start_epoch = 0
    if cfg.TRAIN.AUTO_RESUME and checkpoint.has_checkpoint():
        last_checkpoint = checkpoint.get_last_checkpoint()
        checkpoint_epoch = checkpoint.load_checkpoint(last_checkpoint, model,
                                                      optimizer)
        logger.info("Loaded checkpoint from: {}".format(last_checkpoint))
        start_epoch = checkpoint_epoch + 1
    elif cfg.TRAIN.WEIGHTS:
        checkpoint.load_checkpoint(cfg.TRAIN.WEIGHTS, model)
        logger.info("Loaded initial weights from: {}".format(
            cfg.TRAIN.WEIGHTS))
    # Compute precise time
    if start_epoch == 0 and cfg.PREC_TIME.ENABLED:
        logger.info("Computing precise time...")
        prec_time = net.compute_precise_time(model, loss_fun)
        logger.info(logging.dump_json_stats(prec_time))
        net.reset_bn_stats(model)
    # Create data loaders and meters
    train_loader = loader.construct_train_loader()
    test_loader = loader.construct_test_loader()
    train_meter = meters.TrainMeter(len(train_loader))
    test_meter = meters.TestMeter(len(test_loader))
    # Perform the training loop
    logger.info("Start epoch: {}".format(start_epoch + 1))
    for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
        # Train for one epoch
        train_epoch(train_loader, model, loss_fun, optimizer, train_meter,
                    cur_epoch)
        # Compute precise BN stats
        if cfg.BN.USE_PRECISE_STATS:
            net.compute_precise_bn_stats(model, train_loader)
        # Save a checkpoint
        if (cur_epoch + 1) % cfg.TRAIN.CHECKPOINT_PERIOD == 0:
            checkpoint_file = checkpoint.save_checkpoint(
                model, optimizer, cur_epoch)
            logger.info("Wrote checkpoint to: {}".format(checkpoint_file))
        # Evaluate the model
        next_epoch = cur_epoch + 1
        if next_epoch % cfg.TRAIN.EVAL_PERIOD == 0 or next_epoch == cfg.OPTIM.MAX_EPOCH:
            test_epoch(test_loader, model, test_meter, cur_epoch)
Code Example #10
def test_model():
    """Evaluates the model."""

    # Build the model (before the loaders to speed up debugging)
    model = model_builder.build_model()
    log_model_info(model)

    # Load model weights
    cu.load_checkpoint(cfg.TEST.WEIGHTS, model)
    logger.info("Loaded model weights from: {}".format(cfg.TEST.WEIGHTS))

    # Create data loaders
    test_loader = loader.construct_test_loader()

    # Create meters
    test_meter = TestMeter(len(test_loader))

    # Evaluate the model
    test_epoch(test_loader, model, test_meter, 0)
Code Example #11
File: test_net.py Project: acabadw22/pycls
def test_model():
    """Evaluates the model."""

    # Setup logging
    logging.setup_logging()
    # Show the config
    logger.info("Config:\n{}".format(cfg))

    # Fix the RNG seeds (see RNG comment in core/config.py for discussion)
    np.random.seed(cfg.RNG_SEED)
    torch.manual_seed(cfg.RNG_SEED)
    # Configure the CUDNN backend
    torch.backends.cudnn.benchmark = cfg.CUDNN.BENCHMARK

    # Build the model (before the loaders to speed up debugging)
    model = builders.build_model()
    logger.info("Model:\n{}".format(model))
    logger.info(logging.dump_json_stats(net.complexity(model)))

    # Compute precise time
    if cfg.PREC_TIME.ENABLED:
        logger.info("Computing precise time...")
        loss_fun = builders.build_loss_fun()
        prec_time = net.compute_precise_time(model, loss_fun)
        logger.info(logging.dump_json_stats(prec_time))
        net.reset_bn_stats(model)

    # Load model weights
    checkpoint.load_checkpoint(cfg.TEST.WEIGHTS, model)
    logger.info("Loaded model weights from: {}".format(cfg.TEST.WEIGHTS))

    # Create data loaders
    test_loader = loader.construct_test_loader()

    # Create meters
    test_meter = meters.TestMeter(len(test_loader))

    # Evaluate the model
    test_epoch(test_loader, model, test_meter, 0)
Code Example #12
def test_ftta_model(corruptions, levels):
    """Use feed back to fine-tune some part of the model. (with all kind of corruptions)"""
    all_results = []
    for corruption_level in levels:
        lvl_results = []
        for corruption_type in corruptions:
            cfg.TRAIN.CORRUPTION = corruption_type
            cfg.TRAIN.LEVEL = corruption_level
            cfg.TEST.CORRUPTION = corruption_type
            cfg.TEST.LEVEL = corruption_level

            # Setup training/testing environment
            setup_env()
            # Construct the model, loss_fun, and optimizer
            model = setup_model()
            loss_fun = builders.build_loss_fun().cuda()
            optimizer = optim.construct_optimizer(model)
            # Load checkpoint or initial weights
            start_epoch = 0
            checkpoint.load_checkpoint(cfg.TRAIN.WEIGHTS,
                                       model,
                                       strict=cfg.TRAIN.LOAD_STRICT)
            logger.info("Loaded initial weights from: {}".format(
                cfg.TRAIN.WEIGHTS))
            # Create data loaders and meters
            train_loader = loader.construct_train_loader()
            test_loader = loader.construct_test_loader()
            train_meter = meters.TrainMeter(len(train_loader))
            test_meter = meters.TestMeter(len(test_loader))
            # Compute model and loader timings
            if start_epoch == 0 and cfg.PREC_TIME.NUM_ITER > 0:
                benchmark.compute_time_full(model, loss_fun, train_loader,
                                            test_loader)

            # Perform the training loop
            logger.info("Start epoch: {}".format(start_epoch + 1))
            for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
                if cfg.TRAIN.ADAPTATION != 'test_only':
                    if cfg.TRAIN.ADAPTATION == 'update_bn':
                        bn_update(model, train_loader)
                    elif cfg.TRAIN.ADAPTATION == 'min_entropy':
                        # Train for one epoch
                        train_epoch(train_loader, model, loss_fun, optimizer,
                                    train_meter, cur_epoch)
                        bn_update(model, train_loader)

                    # Save a checkpoint
                    if (cur_epoch + 1) % cfg.TRAIN.CHECKPOINT_PERIOD == 0:
                        checkpoint_file = checkpoint.save_checkpoint(
                            model, optimizer, cur_epoch)
                        logger.info(
                            "Wrote checkpoint to: {}".format(checkpoint_file))

                # Evaluate the model
                next_epoch = cur_epoch + 1
                if next_epoch % cfg.TRAIN.EVAL_PERIOD == 0 or next_epoch == cfg.OPTIM.MAX_EPOCH:
                    top1 = test_epoch(test_loader, model, test_meter,
                                      cur_epoch)
            lvl_results.append(top1)
        all_results.append(lvl_results)

    for lvl_idx in range(len(all_results)):
        logger.info("corruption level: {}".format(levels[lvl_idx]))
        logger.info("corruption types: {}".format(corruptions))
        logger.info(all_results[lvl_idx])

    # show_parameters(model)

    return all_results
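
The 'min_entropy' adaptation mode fine-tunes the model on unlabeled corrupted data by using its own predictions as the feedback signal. The usual objective is the entropy of the softmax output, which requires no labels; a minimal sketch (a TENT-style formulation, assumed rather than taken from this project):

import torch.nn.functional as F

def entropy_loss(logits):
    """Mean Shannon entropy of the predicted class distribution."""
    log_probs = F.log_softmax(logits, dim=1)
    return -(log_probs.exp() * log_probs).sum(dim=1).mean()

# Usage inside an adaptation step (no labels required):
#   loss = entropy_loss(model(corrupted_inputs))
#   loss.backward(); optimizer.step(); optimizer.zero_grad()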
Code Example #13
def train_model():
    """Trains the model."""
    # Setup training/testing environment
    setup_env()
    # Construct the model, loss_fun, and optimizer
    model = setup_model()
    loss_fun = builders.build_loss_fun().cuda()
    if "search" in cfg.MODEL.TYPE:
        params_w = [v for k, v in model.named_parameters() if "alphas" not in k]
        params_a = [v for k, v in model.named_parameters() if "alphas" in k]
        optimizer_w = torch.optim.SGD(
            params=params_w,
            lr=cfg.OPTIM.BASE_LR,
            momentum=cfg.OPTIM.MOMENTUM,
            weight_decay=cfg.OPTIM.WEIGHT_DECAY,
            dampening=cfg.OPTIM.DAMPENING,
            nesterov=cfg.OPTIM.NESTEROV
        )
        if cfg.OPTIM.ARCH_OPTIM == "adam":
            optimizer_a = torch.optim.Adam(
                params=params_a,
                lr=cfg.OPTIM.ARCH_BASE_LR,
                betas=(0.5, 0.999),
                weight_decay=cfg.OPTIM.ARCH_WEIGHT_DECAY
            )
        elif cfg.OPTIM.ARCH_OPTIM == "sgd":
            optimizer_a = torch.optim.SGD(
                params=params_a,
                lr=cfg.OPTIM.ARCH_BASE_LR,
                momentum=cfg.OPTIM.MOMENTUM,
                weight_decay=cfg.OPTIM.ARCH_WEIGHT_DECAY,
                dampening=cfg.OPTIM.DAMPENING,
                nesterov=cfg.OPTIM.NESTEROV
            )
        optimizer = [optimizer_w, optimizer_a]
    else:
        optimizer = optim.construct_optimizer(model)
    # Load checkpoint or initial weights
    start_epoch = 0
    if cfg.TRAIN.AUTO_RESUME and checkpoint.has_checkpoint():
        last_checkpoint = checkpoint.get_last_checkpoint()
        checkpoint_epoch = checkpoint.load_checkpoint(last_checkpoint, model, optimizer)
        logger.info("Loaded checkpoint from: {}".format(last_checkpoint))
        start_epoch = checkpoint_epoch + 1
    elif cfg.TRAIN.WEIGHTS:
        checkpoint.load_checkpoint(cfg.TRAIN.WEIGHTS, model)
        logger.info("Loaded initial weights from: {}".format(cfg.TRAIN.WEIGHTS))
    # Create data loaders and meters
    if cfg.TRAIN.PORTION < 1:
        if "search" in cfg.MODEL.TYPE:
            train_loader = [loader._construct_loader(
                dataset_name=cfg.TRAIN.DATASET,
                split=cfg.TRAIN.SPLIT,
                batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS),
                shuffle=True,
                drop_last=True,
                portion=cfg.TRAIN.PORTION,
                side="l"
            ),
            loader._construct_loader(
                dataset_name=cfg.TRAIN.DATASET,
                split=cfg.TRAIN.SPLIT,
                batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS),
                shuffle=True,
                drop_last=True,
                portion=cfg.TRAIN.PORTION,
                side="r"
            )]
        else:
            train_loader = loader._construct_loader(
                dataset_name=cfg.TRAIN.DATASET,
                split=cfg.TRAIN.SPLIT,
                batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS),
                shuffle=True,
                drop_last=True,
                portion=cfg.TRAIN.PORTION,
                side="l"
            )
        test_loader = loader._construct_loader(
            dataset_name=cfg.TRAIN.DATASET,
            split=cfg.TRAIN.SPLIT,
            batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS),
            shuffle=False,
            drop_last=False,
            portion=cfg.TRAIN.PORTION,
            side="r"
        )
    else:
        train_loader = loader.construct_train_loader()
        test_loader = loader.construct_test_loader()
    train_meter_type = meters.TrainMeterIoU if cfg.TASK == "seg" else meters.TrainMeter
    test_meter_type = meters.TestMeterIoU if cfg.TASK == "seg" else meters.TestMeter
    base_loader = train_loader[0] if isinstance(train_loader, list) else train_loader
    train_meter = train_meter_type(len(base_loader))
    test_meter = test_meter_type(len(test_loader))
    # Compute model and loader timings
    if start_epoch == 0 and cfg.PREC_TIME.NUM_ITER > 0:
        base_loader = train_loader[0] if isinstance(train_loader, list) else train_loader
        benchmark.compute_time_full(model, loss_fun, base_loader, test_loader)
    # Perform the training loop
    logger.info("Start epoch: {}".format(start_epoch + 1))
    for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
        # Train for one epoch
        f = search_epoch if "search" in cfg.MODEL.TYPE else train_epoch
        f(train_loader, model, loss_fun, optimizer, train_meter, cur_epoch)
        # Compute precise BN stats
        if cfg.BN.USE_PRECISE_STATS:
            net.compute_precise_bn_stats(model, train_loader)
        # Save a checkpoint
        if (cur_epoch + 1) % cfg.TRAIN.CHECKPOINT_PERIOD == 0:
            checkpoint_file = checkpoint.save_checkpoint(model, optimizer, cur_epoch)
            logger.info("Wrote checkpoint to: {}".format(checkpoint_file))
        # Evaluate the model
        next_epoch = cur_epoch + 1
        if next_epoch % cfg.TRAIN.EVAL_PERIOD == 0 or next_epoch == cfg.OPTIM.MAX_EPOCH:
            test_epoch(test_loader, model, test_meter, cur_epoch)
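
With two optimizers and paired 'l'/'r' data splits, `search_epoch` presumably alternates between updating the network weights and the architecture parameters (`alphas`) on disjoint portions of the training set, in the style of DARTS-like search. A minimal sketch of one alternating step (function and variable names are illustrative):

def search_step(batch_w, batch_a, model, loss_fun, optimizer_w, optimizer_a):
    """One alternating update: alphas on the 'r' split, weights on the 'l' split."""
    (xw, yw), (xa, ya) = batch_w, batch_a
    # 1) Update the architecture parameters on the held-out split
    optimizer_a.zero_grad()
    loss_fun(model(xa), ya).backward()
    optimizer_a.step()
    # 2) Update the network weights on the training split
    optimizer_w.zero_grad()
    loss_fun(model(xw), yw).backward()
    optimizer_w.step()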
Code Example #14
def test():
    """Evaluates a trained model."""
    # Setup training/testing environment
    setup_env()
    # Construct the model
    model = setup_model()
    # Load model weights
    checkpoint.load_checkpoint(cfg.TEST.WEIGHTS, model)
    logger.info("Loaded model weights from: {}".format(cfg.TEST.WEIGHTS))
    # Create data loaders
    test_loader = loader.construct_test_loader()
    dataset = test_loader.dataset
    # Enable eval mode
    logs = []
    model.eval()
    for inputs, labels in test_loader:
        # Transfer the data to the current GPU device
        inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True)
        # Compute the predictions
        preds = model(inputs)
        if cfg.SOFTMAX:
            preds = F.softmax(preds, dim=1)
        else:
            preds = torch.sigmoid(preds)
        # Abnormal dataset format support
        if cfg.TRAIN.DATASET == "abnormal":
            labels = labels.argmax(dim=1)
        # Collect one row per sample: [label, score for class 0, all scores]
        for label, tail in zip(labels.tolist(), preds.tolist()):
            logs.append([label, tail[0], tail])

    imgs = [v["im_path"] for v in dataset._imdb]
    class_ids = dataset._class_ids
    assert len(imgs) == len(logs)

    lines = []
    outputs = []
    lines.append(":".join(class_ids))
    lines.append("{}".format(len(imgs)))
    lines.append("im_path,label,score,score_1_n")

    for im_path, (label, score, tail) in zip(imgs, logs):
        tail = ",".join(["{:.3f}".format(v) for v in tail])
        lines.append("{},{},{},{}".format(im_path, label, score, tail))
        outputs.append([im_path, class_ids[label], score])

    task_name = time.strftime("%m%d%H%M%S")
    os.makedirs(os.path.join(cfg.OUT_DIR, task_name))

    temp_file = "{}/threshold.png".format(task_name)
    temp_file = os.path.join(cfg.OUT_DIR, temp_file)
    score_thr = search_thr(logs, s1_thr=2, s2_thr=70, out_file=temp_file)

    temp_file = "{}/results.csv".format(task_name)
    temp_file = os.path.join(cfg.OUT_DIR, temp_file)
    with open(temp_file, "w") as f:
        f.write("\n".join(lines))
        print(temp_file)

    temp_file = "{}/results.pkl".format(task_name)
    temp_file = os.path.join(cfg.OUT_DIR, temp_file)
    with open(temp_file, "wb") as f:
        pickle.dump(outputs, f)
        print(temp_file)

    hardmini(outputs, class_ids, task_name, score_thr)
    return outputs
Code Example #15
File: trainer.py Project: zhengxiawu/pytorch_cls
def train_model():
    """Trains the model."""
    # Setup training/testing environment
    setup_env()
    # Construct the model, loss_fun, and optimizer
    model = setup_model()
    loss_fun = builders.build_loss_fun().cuda()
    optimizer = optim.construct_optimizer(model)
    # Load checkpoint or initial weights
    start_epoch = 0
    if cfg.TRAIN.AUTO_RESUME and checkpoint.has_checkpoint():
        last_checkpoint = checkpoint.get_last_checkpoint()
        checkpoint_epoch = checkpoint.load_checkpoint(last_checkpoint, model,
                                                      optimizer)
        logger.info("Loaded checkpoint from: {}".format(last_checkpoint))
        start_epoch = checkpoint_epoch + 1
    elif cfg.TRAIN.WEIGHTS:
        checkpoint.load_checkpoint(cfg.TRAIN.WEIGHTS, model)
        logger.info("Loaded initial weights from: {}".format(
            cfg.TRAIN.WEIGHTS))
    # Create data loaders and meters
    if cfg.TEST.DATASET == 'imagenet_dataset' or cfg.TRAIN.DATASET == 'imagenet_dataset':
        dataset = loader.construct_train_loader()
        train_loader = dataset.train_loader
        test_loader = dataset.val_loader
    else:
        dataset = None
        train_loader = loader.construct_train_loader()
        test_loader = loader.construct_test_loader()
    train_meter = meters.TrainMeter(len(train_loader))
    test_meter = meters.TestMeter(len(test_loader))
    # Compute model and loader timings
    if start_epoch == 0 and cfg.PREC_TIME.NUM_ITER > 0:
        benchmark.compute_time_full(model, loss_fun, train_loader, test_loader)
    # Perform the training loop
    logger.info("Start epoch: {}".format(start_epoch + 1))
    for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
        # Train for one epoch
        train_epoch(train_loader, model, loss_fun, optimizer, train_meter,
                    cur_epoch)
        # Compute precise BN stats
        if cfg.BN.USE_PRECISE_STATS:
            net.compute_precise_bn_stats(model, train_loader)
        # Save a checkpoint
        if (cur_epoch + 1) % cfg.TRAIN.CHECKPOINT_PERIOD == 0:
            checkpoint_file = checkpoint.save_checkpoint(
                model, optimizer, cur_epoch)
            logger.info("Wrote checkpoint to: {}".format(checkpoint_file))
        # Evaluate the model
        next_epoch = cur_epoch + 1
        if next_epoch % cfg.TRAIN.EVAL_PERIOD == 0 or next_epoch == cfg.OPTIM.MAX_EPOCH:
            logger.info("Start testing")
            test_epoch(test_loader, model, test_meter, cur_epoch)
        if dataset is not None:
            logger.info("Reset the dataset")
            train_loader._dali_iterator.reset()
            test_loader._dali_iterator.reset()
            # clear memory
            if torch.cuda.is_available():
                torch.cuda.synchronize()
                # https://forums.fast.ai/t/clearing-gpu-memory-pytorch/14637
                torch.cuda.empty_cache()
            gc.collect()
Code Example #16
def train_model():
    """Trains the model."""

    # Setup logging
    logging.setup_logging()
    # Show the config
    logger.info("Config:\n{}".format(cfg))

    # Fix the RNG seeds (see RNG comment in core/config.py for discussion)
    np.random.seed(cfg.RNG_SEED)
    torch.manual_seed(cfg.RNG_SEED)
    # Configure the CUDNN backend
    torch.backends.cudnn.benchmark = cfg.CUDNN.BENCHMARK

    # Build the model (before the loaders to speed up debugging)
    model = builders.build_model()
    logger.info("Model:\n{}".format(model))
    logger.info(logging.dump_json_stats(net.complexity(model)))

    # Define the loss function
    loss_fun = builders.build_loss_fun()
    # Construct the optimizer
    optimizer = optim.construct_optimizer(model)

    # Load checkpoint or initial weights
    start_epoch = 0
    if cfg.TRAIN.AUTO_RESUME and checkpoint.has_checkpoint():
        last_checkpoint = checkpoint.get_last_checkpoint()
        checkpoint_epoch = checkpoint.load_checkpoint(last_checkpoint, model,
                                                      optimizer)
        logger.info("Loaded checkpoint from: {}".format(last_checkpoint))
        start_epoch = checkpoint_epoch + 1
    elif cfg.TRAIN.WEIGHTS:
        checkpoint.load_checkpoint(cfg.TRAIN.WEIGHTS, model)
        logger.info("Loaded initial weights from: {}".format(
            cfg.TRAIN.WEIGHTS))

    # Compute precise time
    if start_epoch == 0 and cfg.PREC_TIME.ENABLED:
        logger.info("Computing precise time...")
        prec_time = net.compute_precise_time(model, loss_fun)
        logger.info(logging.dump_json_stats(prec_time))
        net.reset_bn_stats(model)

    # Create data loaders
    train_loader = loader.construct_train_loader()
    test_loader = loader.construct_test_loader()

    # Create meters
    train_meter = meters.TrainMeter(len(train_loader))
    test_meter = meters.TestMeter(len(test_loader))

    # Perform the training loop
    logger.info("Start epoch: {}".format(start_epoch + 1))

    for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
        # Train for one epoch
        train_epoch(train_loader, model, loss_fun, optimizer, train_meter,
                    cur_epoch)
        # Compute precise BN stats
        if cfg.BN.USE_PRECISE_STATS:
            net.compute_precise_bn_stats(model, train_loader)
        # Save a checkpoint
        if checkpoint.is_checkpoint_epoch(cur_epoch):
            checkpoint_file = checkpoint.save_checkpoint(
                model, optimizer, cur_epoch)
            logger.info("Wrote checkpoint to: {}".format(checkpoint_file))
        # Evaluate the model
        if is_eval_epoch(cur_epoch):
            test_epoch(test_loader, model, test_meter, cur_epoch)
Code Example #17
def train_model(writer_train=None, writer_eval=None, is_master=False):
    """Trains the model."""
    # Fit flops/params
    if cfg.TRAIN.AUTO_MATCH and cfg.RGRAPH.SEED_TRAIN == cfg.RGRAPH.SEED_TRAIN_START:
        mode = 'flops'  # flops or params
        if cfg.TRAIN.DATASET == 'cifar10':
            pre_repeat = 15
            if cfg.MODEL.TYPE == 'resnet':  # ResNet20
                stats_baseline = 40813184
            elif cfg.MODEL.TYPE == 'mlpnet':  # 5-layer MLP; cfg.MODEL.LAYERS excludes stem and head layers
                if cfg.MODEL.LAYERS == 3:
                    if cfg.RGRAPH.DIM_LIST[0] == 256:
                        stats_baseline = 985600
                    elif cfg.RGRAPH.DIM_LIST[0] == 512:
                        stats_baseline = 2364416
                    elif cfg.RGRAPH.DIM_LIST[0] == 1024:
                        stats_baseline = 6301696
            elif cfg.MODEL.TYPE == 'cnn':
                if cfg.MODEL.LAYERS == 3:
                    if cfg.RGRAPH.DIM_LIST[0] == 512:
                        stats_baseline = 806884352
                    elif cfg.RGRAPH.DIM_LIST[0] == 16:
                        stats_baseline = 1216672
                elif cfg.MODEL.LAYERS == 6:
                    if '64d' in cfg.OUT_DIR:
                        stats_baseline = 48957952
                    elif '16d' in cfg.OUT_DIR:
                        stats_baseline = 3392128
        elif cfg.TRAIN.DATASET == 'imagenet':
            pre_repeat = 9
            if cfg.MODEL.TYPE == 'resnet':
                if 'basic' in cfg.RESNET.TRANS_FUN:  # ResNet34
                    stats_baseline = 3663761408
                elif 'sep' in cfg.RESNET.TRANS_FUN:  # ResNet34-sep
                    stats_baseline = 553614592
                elif 'bottleneck' in cfg.RESNET.TRANS_FUN:  # ResNet50
                    stats_baseline = 4089184256
            elif cfg.MODEL.TYPE == 'efficientnet':  # EfficientNet
                stats_baseline = 385824092
            elif cfg.MODEL.TYPE == 'cnn':  # CNN
                if cfg.MODEL.LAYERS == 6:
                    if '64d' in cfg.OUT_DIR:
                        stats_baseline = 166438912
        cfg.defrost()
        stats = model_builder.build_model_stats(mode)
        if stats != stats_baseline:
            # 1st round: set first stage dim
            for i in range(pre_repeat):
                scale = round(math.sqrt(stats_baseline / stats), 2)
                first = cfg.RGRAPH.DIM_LIST[0]
                ratio_list = [dim / first for dim in cfg.RGRAPH.DIM_LIST]
                first = int(round(first * scale))
                cfg.RGRAPH.DIM_LIST = [
                    int(round(first * ratio)) for ratio in ratio_list
                ]
                stats = model_builder.build_model_stats(mode)
            flag_init = 1 if stats < stats_baseline else -1
            step = 1
            while True:
                first = cfg.RGRAPH.DIM_LIST[0]
                ratio_list = [dim / first for dim in cfg.RGRAPH.DIM_LIST]
                first += flag_init * step
                cfg.RGRAPH.DIM_LIST = [
                    int(round(first * ratio)) for ratio in ratio_list
                ]
                stats = model_builder.build_model_stats(mode)
                flag = 1 if stats < stats_baseline else -1
                if stats == stats_baseline:
                    break
                if flag != flag_init:
                    if not cfg.RGRAPH.UPPER:  # make sure the stats stay SMALLER than the baseline
                        if flag < 0:
                            first = cfg.RGRAPH.DIM_LIST[0]
                            ratio_list = [
                                dim / first for dim in cfg.RGRAPH.DIM_LIST
                            ]
                            first -= flag_init * step
                            cfg.RGRAPH.DIM_LIST = [
                                int(round(first * ratio))
                                for ratio in ratio_list
                            ]
                        break
                    else:
                        if flag > 0:
                            first = cfg.RGRAPH.DIM_LIST[0]
                            ratio_list = [
                                dim / first for dim in cfg.RGRAPH.DIM_LIST
                            ]
                            first -= flag_init * step
                            cfg.RGRAPH.DIM_LIST = [
                                int(round(first * ratio))
                                for ratio in ratio_list
                            ]
                        break
            # 2nd round: set other stage dim
            first = cfg.RGRAPH.DIM_LIST[0]
            ratio_list = [
                int(round(dim / first)) for dim in cfg.RGRAPH.DIM_LIST
            ]
            stats = model_builder.build_model_stats(mode)
            flag_init = 1 if stats < stats_baseline else -1
            if 'share' not in cfg.RESNET.TRANS_FUN:
                for i in range(1, len(cfg.RGRAPH.DIM_LIST)):
                    for j in range(ratio_list[i]):
                        cfg.RGRAPH.DIM_LIST[i] += flag_init
                        stats = model_builder.build_model_stats(mode)
                        flag = 1 if stats < stats_baseline else -1
                        if flag_init != flag:
                            cfg.RGRAPH.DIM_LIST[i] -= flag_init
                            break
        stats = model_builder.build_model_stats(mode)
        print('FINAL', cfg.RGRAPH.GROUP_NUM, cfg.RGRAPH.DIM_LIST, stats,
              stats_baseline, stats < stats_baseline)
    # Build the model (before the loaders to ease debugging)
    model = model_builder.build_model()
    params, flops = log_model_info(model, writer_eval)

    # Define the loss function
    loss_fun = losses.get_loss_fun()
    # Construct the optimizer
    optimizer = optim.construct_optimizer(model)

    # Load a checkpoint if applicable
    start_epoch = 0
    if cfg.TRAIN.AUTO_RESUME and cu.has_checkpoint():
        last_checkpoint = cu.get_checkpoint_last()
        checkpoint_epoch = cu.load_checkpoint(last_checkpoint, model,
                                              optimizer)
        logger.info('Loaded checkpoint from: {}'.format(last_checkpoint))
        if checkpoint_epoch == cfg.OPTIM.MAX_EPOCH:
            # Training already finished; keep start_epoch at MAX_EPOCH so
            # that only the final evaluation below runs
            start_epoch = checkpoint_epoch
        else:
            start_epoch = checkpoint_epoch + 1

    # Create data loaders
    train_loader = loader.construct_train_loader()
    test_loader = loader.construct_test_loader()

    # Create meters
    train_meter = TrainMeter(len(train_loader))
    test_meter = TestMeter(len(test_loader))

    if cfg.ONLINE_FLOPS:
        model_dummy = model_builder.build_model()

        IMAGE_SIZE = 224
        n_flops, n_params = mu.measure_model(model_dummy, IMAGE_SIZE,
                                             IMAGE_SIZE)

        logger.info('FLOPs: %.2fM, Params: %.2fM' %
                    (n_flops / 1e6, n_params / 1e6))

        del model_dummy

    # Perform the training loop
    logger.info('Start epoch: {}'.format(start_epoch + 1))

    # do eval at initialization
    eval_epoch(test_loader,
               model,
               test_meter,
               -1,
               writer_eval,
               params,
               flops,
               is_master=is_master)

    if start_epoch == cfg.OPTIM.MAX_EPOCH:
        cur_epoch = start_epoch - 1
        eval_epoch(test_loader,
                   model,
                   test_meter,
                   cur_epoch,
                   writer_eval,
                   params,
                   flops,
                   is_master=is_master)
    else:
        for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
            # Train for one epoch
            train_epoch(train_loader,
                        model,
                        loss_fun,
                        optimizer,
                        train_meter,
                        cur_epoch,
                        writer_train,
                        is_master=is_master)
            # Compute precise BN stats
            if cfg.BN.USE_PRECISE_STATS:
                nu.compute_precise_bn_stats(model, train_loader)
            # Save a checkpoint
            if cu.is_checkpoint_epoch(cur_epoch):
                checkpoint_file = cu.save_checkpoint(model, optimizer,
                                                     cur_epoch)
                logger.info('Wrote checkpoint to: {}'.format(checkpoint_file))
            # Evaluate the model
            if is_eval_epoch(cur_epoch):
                eval_epoch(test_loader,
                           model,
                           test_meter,
                           cur_epoch,
                           writer_eval,
                           params,
                           flops,
                           is_master=is_master)
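
The channel-matching logic at the top of this example repeatedly rescales `cfg.RGRAPH.DIM_LIST` until the model's FLOPs (or parameter count) match `stats_baseline`: a multiplicative first guess followed by a unit-step refinement that stops when the comparison flips sign. A condensed sketch of that search, with `stats_fn` standing in for `model_builder.build_model_stats` (names and bounds are illustrative):

import math

def match_budget(dims, stats_fn, baseline, pre_repeat=15, max_steps=10000):
    """Rescales channel dims until stats_fn(dims) reaches or crosses baseline."""
    ratios = [d / dims[0] for d in dims]

    def rebuild(first):
        return [int(round(first * r)) for r in ratios]

    # Phase 1: multiplicative guesses (cost grows roughly quadratically in width)
    for _ in range(pre_repeat):
        scale = math.sqrt(baseline / stats_fn(dims))
        dims = rebuild(int(round(dims[0] * scale)))
    if stats_fn(dims) == baseline:
        return dims
    # Phase 2: step the base width by one until the comparison flips
    sign = 1 if stats_fn(dims) < baseline else -1
    for _ in range(max_steps):
        candidate = rebuild(dims[0] + sign)
        stats = stats_fn(candidate)
        if stats == baseline:
            return candidate
        if (1 if stats < baseline else -1) != sign:
            # Crossed the baseline; keep the configuration below it
            return candidate if sign < 0 else dims
        dims = candidate
    return dims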