Ejemplo n.º 1
0
def _build_tta_dataloaders(config):
    """Build one test dataloader per enabled test-time-augmentation variant.

    Returns:
        tuple: ``(dataloaders, flags_hflip, flags_vflip, weights)``, four
        parallel lists. Index 0 is always the plain (no-TTA) dataloader with
        weight 1.0. It is followed by one dataloader per
        ``config.TTA.RESIZE`` entry, then the optional horizontal-flip and
        vertical-flip dataloaders. Weights are NOT normalized here.

    Raises:
        ValueError: if ``config.TTA.RESIZE`` and ``config.TTA.RESIZE_WEIGHTS``
            have different lengths (``zip`` would otherwise silently drop
            the unmatched resize variants).
    """
    if len(config.TTA.RESIZE) != len(config.TTA.RESIZE_WEIGHTS):
        raise ValueError(
            'config.TTA.RESIZE and config.TTA.RESIZE_WEIGHTS must have the '
            f'same length, got {len(config.TTA.RESIZE)} and '
            f'{len(config.TTA.RESIZE_WEIGHTS)}')

    # default dataloader (w/o tta)
    dataloaders = [get_test_dataloader(config)]
    weights = [1.0]
    flags_hflip = [False]
    flags_vflip = [False]

    # dataloaders w/ tta size jittering
    for tta_resize_wh, weight in zip(config.TTA.RESIZE,
                                     config.TTA.RESIZE_WEIGHTS):
        dataloaders.append(
            get_test_dataloader(config, tta_resize_wh=tta_resize_wh))
        weights.append(weight)
        flags_hflip.append(False)
        flags_vflip.append(False)

    # dataloader w/ tta horizontal flipping
    if config.TTA.HORIZONTAL_FLIP:
        dataloaders.append(get_test_dataloader(config, tta_hflip=True))
        weights.append(config.TTA.HORIZONTAL_FLIP_WEIGHT)
        flags_hflip.append(True)
        flags_vflip.append(False)

    # dataloader w/ tta vertical flipping
    if config.TTA.VERTICAL_FLIP:
        dataloaders.append(get_test_dataloader(config, tta_vflip=True))
        weights.append(config.TTA.VERTICAL_FLIP_WEIGHT)
        flags_hflip.append(False)
        flags_vflip.append(True)

    return dataloaders, flags_hflip, flags_vflip, weights


def _align_tta_prediction(pred, test_width, test_height, hflip, vflip):
    """Bring one TTA prediction back into the plain (un-augmented) test frame.

    Args:
        pred: prediction array indexed as (channel, height, width); its
            shape is unpacked into exactly these three values.
        test_width: target width (``config.TRANSFORM.TEST_SIZE[0]``).
        test_height: target height (``config.TRANSFORM.TEST_SIZE[1]``).
        hflip: if True, undo a horizontal flip (reverse the width axis).
        vflip: if True, undo a vertical flip (reverse the height axis).

    Returns:
        Array of shape (channel, test_height, test_width). When no resize is
        needed, this is ``pred`` itself or a flipped view of it (no copy).
    """
    _, pred_height, pred_width = pred.shape

    # resize (only when resize tta or input resizing is applied)
    if (pred_width != test_width) or (pred_height != test_height):
        hwc = cv2.resize(pred.transpose(1, 2, 0),  # CHW -> HWC
                         dsize=(test_width, test_height))
        if hwc.ndim == 2:
            # cv2.resize returns a 2-D array for single-channel input; restore
            # the channel axis so the HWC -> CHW transpose below is valid
            hwc = hwc[:, :, np.newaxis]
        pred = hwc.transpose(2, 0, 1)  # HWC -> CHW

    # flip back (only when flipping tta is applied)
    if vflip:
        pred = pred[:, ::-1, :]
    if hflip:
        pred = pred[:, :, ::-1]
    return pred


def _tta_output_location(pred_root, image_path):
    """Return ``(out_dir, output_path)`` for the PNG predicted from ``image_path``.

    ``out_dir`` is ``<pred_root>/<aoi>``, with aoi given by
    ``get_aoi_from_path``. ``output_path`` is ``<out_dir>/<image stem>.png``.
    """
    stem, _ = os.path.splitext(os.path.basename(image_path))
    out_dir = os.path.join(pred_root, get_aoi_from_path(image_path))
    return out_dir, os.path.join(out_dir, f'{stem}.png')


def main():
    """Run test-time-augmented (TTA) inference and dump averaged predictions.

    Steps:

    1. Build one test dataloader per TTA variant: plain, one per
       ``config.TTA.RESIZE`` entry, and the optional horizontal / vertical
       flips.
    2. Run ``model.module.predict`` on every variant and map each prediction
       back to the ``config.TRANSFORM.TEST_SIZE`` frame (resize, then undo
       the flip).
    3. Average the aligned predictions with normalized weights.
    4. Crop the average to the original image size with ``crop_center`` and
       write one PNG per image under
       ``config.PREDICTION_ROOT/<experiment subdir>/<aoi>/``.

    Raises:
        FileExistsError: if the experiment's prediction directory already
            exists (``os.makedirs(..., exist_ok=False)``).
        ValueError: if the TTA resize and resize-weight lists differ in
            length.
        AssertionError: if the TTA dataloaders are not aligned, i.e. they
            yield different image paths or original sizes at the same
            position.
    """
    config = load_config()
    print('successfully loaded config:')
    print(config)

    # prepare dataloaders, flipping flags, and weights for averaging
    test_dataloaders, flags_hflip, flags_vflip, weights = \
        _build_tta_dataloaders(config)

    # normalize weights so the weighted sum below is a weighted average
    weights = np.asarray(weights, dtype=np.float64)
    weights /= weights.sum()

    # prepare model to test
    model = get_model(config)
    model.eval()

    # prepare directory to output predictions; exist_ok=False makes the run
    # fail rather than write into an existing prediction directory
    exp_subdir = experiment_subdir(config.EXP_ID)
    pred_root = os.path.join(config.PREDICTION_ROOT, exp_subdir)
    os.makedirs(pred_root, exist_ok=False)

    test_width, test_height = config.TRANSFORM.TEST_SIZE
    n_classes = len(config.INPUT.CLASSES)

    # test loop: zip() yields the i-th batch of every TTA dataloader together
    for batches in tqdm(zip(*test_dataloaders),
                        total=len(test_dataloaders[0])):
        # prepare buffers for output paths, original sizes and the average
        batch_size = len(batches[0]['image'])
        output_paths = [None] * batch_size
        orig_image_sizes = [None] * batch_size
        predictions_averaged = np.zeros(
            shape=[batch_size, n_classes, test_height, test_width])

        for dataloader_idx, batch in enumerate(batches):
            images = batch['image'].to(config.MODEL.DEVICE)
            image_paths = batch['image_path']
            original_heights, original_widths, _ = batch['original_shape']

            predictions = model.module.predict(images)
            predictions = predictions.cpu().numpy()

            for batch_idx, pred in enumerate(predictions):
                path = image_paths[batch_idx]
                orig_image_wh = (original_widths[batch_idx].item(),
                                 original_heights[batch_idx].item())

                # map back to the plain test frame and accumulate
                pred = _align_tta_prediction(
                    pred, test_width, test_height,
                    hflip=flags_hflip[dataloader_idx],
                    vflip=flags_vflip[dataloader_idx])
                predictions_averaged[batch_idx] += \
                    pred * weights[dataloader_idx]

                out_dir, output_path = _tta_output_location(pred_root, path)
                if dataloader_idx == 0:
                    # the first (plain) dataloader defines the outputs; the
                    # sub-directory only needs to be created once per image
                    os.makedirs(out_dir, exist_ok=True)
                    output_paths[batch_idx] = output_path
                    orig_image_sizes[batch_idx] = orig_image_wh
                else:
                    # every TTA dataloader must visit the same images in the
                    # same order, otherwise averaging would mix images
                    assert output_paths[batch_idx] == output_path
                    assert orig_image_sizes[batch_idx] == orig_image_wh

        for output_path, orig_image_wh, pred_averaged in zip(
                output_paths, orig_image_sizes, predictions_averaged):
            # remove padded area
            pred_averaged = crop_center(pred_averaged, crop_wh=orig_image_wh)

            # dump to .png file
            dump_prediction_to_png(output_path, pred_averaged)
Ejemplo n.º 2
0
        # handle padded area and non-ROI area
        roi_mask = roi_masks[i]
        h, w = roi_mask.shape
        pred_refined = pred_refined[:, :h, :w]
        pred_refined[:, np.logical_not(roi_mask)] = 0

        # dump
        pred_filename = os.path.basename(pred_path)
        dump_prediction_to_png(os.path.join(out_dir, pred_filename),
                               pred_refined)


if __name__ == '__main__':
    t0 = timeit.default_timer()

    config = load_config()

    assert len(config.ENSEMBLE_EXP_IDS) >= 1

    subdir = ensemble_subdir(config.ENSEMBLE_EXP_IDS)
    input_root = os.path.join(config.ENSEMBLED_PREDICTION_ROOT, subdir)
    out_root = os.path.join(config.REFINED_PREDICTION_ROOT, subdir)
    aois = get_subdirs(input_root)

    n_thread = config.REFINEMENT_NUM_THREADS
    n_thread = n_thread if n_thread > 0 else mp.cpu_count()
    print(f'N_thread for multiprocessing: {n_thread}')

    print('preparing input args...')
    input_args = []
    for aoi in aois:
def main():
    """Train a segmentation model, validating and checkpointing along the way.

    Steps:

    1. Load the config and create the log and weight directories for
       ``config.EXP_ID``; the checkpoint directory is created only when
       ``config.SAVE_CHECKPOINTS`` is set.
    2. Build the dataloaders, model, optimizer and LR scheduler.
    3. Pass everything to ``load_latest_checkpoint`` to resume.
    4. Train for epochs ``[start_epoch, config.SOLVER.EPOCHS)``, logging
       train/val values to TensorBoard.

    The model's ``state_dict`` is saved to ``weight_best_filename()``
    whenever ``config.EVAL.MAIN_METRIC`` on the validation set exceeds the
    best value so far (higher is treated as better). When
    ``config.SAVE_CHECKPOINTS`` is set, a per-epoch checkpoint and a
    "latest" checkpoint are written after every epoch.

    Raises:
        AssertionError: if ``config.SOLVER.EPOCHS`` is not greater than
            ``config.EVAL.EPOCH_TO_START_VAL`` (validation would never run).
    """
    config = load_config()
    print("successfully loaded config:")
    print(config)
    print("")

    assert config.SOLVER.EPOCHS > config.EVAL.EPOCH_TO_START_VAL

    # prepare directories to output log/weight files
    exp_subdir = experiment_subdir(config.EXP_ID)
    log_dir = os.path.join(config.LOG_ROOT, exp_subdir)
    weight_dir = os.path.join(config.WEIGHT_ROOT, exp_subdir)
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(weight_dir, exist_ok=True)

    # the path is always computed because it is passed to
    # load_latest_checkpoint below; the directory is only created on demand
    checkpoint_dir = os.path.join(config.CHECKPOINT_ROOT, exp_subdir)
    if config.SAVE_CHECKPOINTS:
        os.makedirs(checkpoint_dir, exist_ok=True)

    # prepare dataloaders
    train_dataloader = get_dataloader(config, is_train=True)
    val_dataloader = get_dataloader(config, is_train=False)

    # prepare model to train
    model = get_model(config)

    # prepare optimizer with lr scheduler
    optimizer = get_optimizer(config, model)
    lr_scheduler = get_lr_scheduler(config, optimizer)

    # prepare other states (defaults used when there is nothing to resume)
    start_epoch = 0
    best_score = 0

    # load checkpoint if exists
    model, optimizer, lr_scheduler, start_epoch, best_score = load_latest_checkpoint(
        checkpoint_dir, model, optimizer, lr_scheduler, start_epoch,
        best_score)

    # prepare metrics and loss
    metrics = get_metrics(config)
    loss = get_loss(config)

    # prepare train/val epoch runners
    train_epoch = smp.utils.train.TrainEpoch(
        model,
        loss=loss,
        metrics=metrics,
        optimizer=optimizer,
        device=config.MODEL.DEVICE,
        verbose=True,
    )
    val_epoch = smp.utils.train.ValidEpoch(
        model,
        loss=loss,
        metrics=metrics,
        device=config.MODEL.DEVICE,
        verbose=True,
    )

    # prepare tensorboard; closed in `finally` so buffered events are
    # flushed even if setup or training raises
    tblogger = SummaryWriter(log_dir)
    try:
        if config.DUMP_GIT_INFO:
            # save git hash
            dump_git_info(os.path.join(log_dir, git_filename()))

        # dump config to a file
        with open(os.path.join(log_dir, config_filename()), "w",
                  encoding="utf-8") as f:
            f.write(str(config))

        # train loop
        metric_name = config.EVAL.MAIN_METRIC
        split_id = config.INPUT.TRAIN_VAL_SPLIT_ID
        start_val_epoch = config.EVAL.EPOCH_TO_START_VAL
        val_interval = config.EVAL.VAL_INTERVAL_EPOCH

        for epoch in range(start_epoch, config.SOLVER.EPOCHS):
            lr = optimizer.param_groups[0]["lr"]
            print(f"\nEpoch: {epoch}, lr: {lr}")

            # run train for 1 epoch
            train_logs = train_epoch.run(train_dataloader)

            # log lr, train losses and scores to tensorboard
            tblogger.add_scalar("lr", lr, epoch)
            for k, v in train_logs.items():
                tblogger.add_scalar(f"split_{split_id}/train/{k}", v, epoch)

            # `and` short-circuits: the modulo is only evaluated once
            # validation has started
            if epoch >= start_val_epoch and epoch % val_interval == 0:
                # run val for 1 epoch
                val_logs = val_epoch.run(val_dataloader)

                # log val losses and scores
                for k, v in val_logs.items():
                    tblogger.add_scalar(f"split_{split_id}/val/{k}", v, epoch)

                # save model weight if score updated
                if best_score < val_logs[metric_name]:
                    best_score = val_logs[metric_name]
                    torch.save(model.state_dict(),
                               os.path.join(weight_dir, weight_best_filename()))
                    print("Best val score updated!")
            elif epoch < start_val_epoch:
                print(f"Skip val until epoch {start_val_epoch}")
            else:
                # epoch >= start_val_epoch here, so the interval check failed
                print(
                    f"Skip val since val interval is set to {val_interval}"
                )

            # update lr for the next epoch
            lr_scheduler.step()

            if config.SAVE_CHECKPOINTS:
                # save a per-epoch checkpoint and overwrite the "latest" one;
                # epoch + 1 is stored as the epoch to resume from
                for ckpt_filename in (checkpoint_epoch_filename(epoch),
                                      checkpoint_latest_filename()):
                    save_checkpoint(
                        os.path.join(checkpoint_dir, ckpt_filename),
                        model,
                        optimizer,
                        lr_scheduler,
                        epoch + 1,
                        best_score,
                    )
    finally:
        tblogger.close()