def test(model=None, config=None, epoch=0):
    """Evaluate a model on the test dataset and print loss/metric summaries.

    When ``model`` is ``None`` a fresh ``Model`` is created, moved to
    ``DEVICE`` and loaded from ``config.CHECKPOINT_FILE``; the epoch returned
    by ``load_checkpoint`` is then reported instead of ``epoch``.

    Args:
        model: Model to evaluate, or ``None`` to load one from the checkpoint.
            The model is switched to eval mode and left in that mode.
        config: Test configuration; parsed via ``parse_test_config()`` when
            ``None``.
        epoch: Epoch label printed with the results when ``model`` is given
            by the caller. Defaults to 0.
    """
    torch.backends.cudnn.benchmark = True

    if config is None:
        config = parse_test_config()

    # Evaluation preprocessing only (no augmentation): normalise and convert
    # to tensors. 'normal' and 'depth' are extra keys passed to the transform
    # together with the image.
    transform = A.Compose(
        [
            A.Normalize(),
            M.MyToTensorV2(),
        ],
        additional_targets={
            'normal': 'normal',
            'depth': 'depth',
        })

    # NOTE: the dataloader is rebuilt on every call, including each
    # per-epoch evaluation triggered from train().
    _, dataloader = create_dataloader(config.DATASET_ROOT,
                                      config.JSON_PATH,
                                      batch_size=config.BATCH_SIZE,
                                      transform=transform,
                                      workers=config.WORKERS,
                                      pin_memory=config.PIN_MEMORY,
                                      shuffle=config.SHUFFLE)

    # Compare against None explicitly: a Module may define __len__ and
    # evaluate as falsy even though a real model was supplied.
    if model is None:
        model = Model().to(DEVICE)
        epoch, model = load_checkpoint(model, config.CHECKPOINT_FILE, DEVICE)

    loss_fn = LossFunction()
    metric_fn = MetricFunction(config.BATCH_SIZE)

    model.eval()
    run_test(model, dataloader, loss_fn, metric_fn)
    print_single_error(epoch, loss_fn.show(), metric_fn.show())


# Example #2
# 0


def train(config=None, config_test=None):
    """Train ``Model`` and periodically evaluate/checkpoint it.

    Builds the augmentation pipeline and dataloader, initialises the model
    (resuming from ``config.CHECKPOINT_FILE`` when it is set and
    ``config.LOAD_MODEL`` is true), then runs epochs from the starting
    epoch index up to ``config.NUM_EPOCHS - 1``.

    After every epoch the model is evaluated with :func:`test` when
    ``config.TEST`` is set. It is checkpointed into a timestamped directory
    under ``config.OUT_PATH`` when ``config.SAVE_MODEL`` is set. If either
    option is disabled, a single evaluation / checkpoint is still performed
    once after the loop.

    Args:
        config: Training configuration; parsed via ``parse_train_config()``
            when ``None``.
        config_test: Configuration forwarded unchanged to :func:`test`.
    """
    import warnings

    torch.backends.cudnn.benchmark = True

    if config is None:
        config = parse_train_config()

    transform = A.Compose(
        [
            M.MyRandomResizedCrop(width=config.IMAGE_SIZE,
                                  height=config.IMAGE_SIZE),
            A.OneOf([
                A.MotionBlur(p=0.2),
                A.MedianBlur(blur_limit=3, p=0.1),
                A.Blur(blur_limit=3, p=0.1),
            ], p=0.2),
            A.OneOf([
                M.MyOpticalDistortion(p=0.3),
                M.MyGridDistortion(p=0.1),
            ], p=0.2),
            # NOTE(review): IAASharpen/IAAEmboss are deprecated/removed in
            # newer albumentations releases (A.Sharpen/A.Emboss replace
            # them); presumably the installed version still provides
            # them - confirm before upgrading.
            A.OneOf([
                A.IAASharpen(),
                A.IAAEmboss(),
                A.RandomBrightnessContrast(),
            ], p=0.3),
            A.Normalize(),
            M.MyToTensorV2(),
        ],
        # Extra keys transformed together with the main image, and the
        # target type each one is treated as.
        additional_targets={
            'right_img': 'image',
            'left_normal': 'normal',
            'right_normal': 'normal',
        })

    _, dataloader = create_dataloader(config.DATASET_ROOT,
                                      config.JSON_PATH,
                                      batch_size=config.BATCH_SIZE,
                                      transform=transform,
                                      workers=config.WORKERS,
                                      pin_memory=config.PIN_MEMORY,
                                      shuffle=config.SHUFFLE)

    model = Model()
    model.apply(init_weights)
    model = model.to(DEVICE)

    loss_fn = LossFunction()

    epoch_idx = 0
    if config.CHECKPOINT_FILE and config.LOAD_MODEL:
        # NOTE(review): only the model is passed to load_checkpoint, so the
        # optimizer (Adam moment) state is not restored on resume.
        epoch_idx, model = load_checkpoint(model, config.CHECKPOINT_FILE,
                                           DEVICE)

    # Build the optimizer only after the model is on DEVICE and the
    # checkpoint is loaded. That way it references the parameters that are
    # actually trained, even if load_checkpoint returns a different module.
    solver = torch.optim.Adam(
        (param for param in model.parameters() if param.requires_grad),
        lr=config.LEARNING_RATE,
        betas=config.BETAS,
        eps=config.EPS,
        weight_decay=config.WEIGHT_DECAY)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        solver, milestones=config.MILESTONES, gamma=config.GAMMA)

    # A fresh scheduler starts at step 0. When resuming, replay one step per
    # already-completed epoch so the LR milestones fire at the right epoch.
    # PyTorch warns about scheduler.step() preceding optimizer.step(); that
    # is expected here, so the warning is silenced for this replay only.
    if epoch_idx > 0:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", UserWarning)
            for _ in range(epoch_idx):
                lr_scheduler.step()

    # Timestamp with non-alphanumerics replaced so it is filesystem-safe.
    output_dir = os.path.join(
        config.OUT_PATH, re.sub("[^0-9a-zA-Z]+", "-",
                                dt.now().isoformat()))

    for epoch_idx in range(epoch_idx, config.NUM_EPOCHS):
        metric_fn = MetricFunction(config.BATCH_SIZE)

        # test() switches the model to eval mode, so restore train mode at
        # the start of every epoch.
        model.train()
        train_one_epoch(model, dataloader, loss_fn, metric_fn, solver,
                        epoch_idx)
        print_single_error(epoch_idx, loss_fn.show(), metric_fn.show())
        lr_scheduler.step()

        if config.TEST:
            test(model, config_test)
        if config.SAVE_MODEL:
            save_checkpoint(epoch_idx, model, output_dir)

    # When per-epoch evaluation / checkpointing is disabled, still evaluate
    # and save the final model once.
    if not config.TEST:
        test(model, config_test)
    if not config.SAVE_MODEL:
        save_checkpoint(epoch_idx, model, output_dir)