def get_criterions(
    criterions,
    criterions_stride2=None,
    criterions_stride4=None,
    criterions_stride8=None,
    criterions_stride16=None,
    ignore_index=None,
) -> Tuple[List[Callback], Dict]:
    """Build loss callbacks plus the criterion dictionary they index into.

    Each ``(loss_name, loss_weight)`` pair becomes one ``CriterionCallback``
    whose prefix and criterion key are ``"<output_key>/<loss_name>"``. Losses
    from ``criterions`` are applied to ``OUTPUT_MASK_KEY``. Losses from the
    ``criterions_strideN`` arguments are applied to the matching
    ``OUTPUT_MASK_N_KEY`` output and wrapped in
    ``ResizeTargetToPrediction2d``. A final ``MetricAggregationCallback``
    sums every registered loss into ``"loss"``.

    Args:
        criterions: Iterable of ``(loss_name, loss_weight)`` pairs for the
            main mask output.
        criterions_stride2 / 4 / 8 / 16: Optional iterables of
            ``(loss_name, loss_weight)`` pairs for the deep-supervision
            outputs; ``None`` disables that output.
        ignore_index: Forwarded to ``get_loss`` for every loss.

    Returns:
        ``(callbacks, criterions_dict)``.
    """
    criterions_dict = {}
    losses = []
    callbacks = []

    def _register(output_key, loss_name, loss_weight, wrapper):
        # Prefix and criterion key share the same "<output>/<loss>" name.
        name = f"{output_key}/" + loss_name
        # Weighted BCE additionally consumes the per-pixel weight mask.
        if loss_name == "wbce":
            target_key = [INPUT_MASK_KEY, INPUT_MASK_WEIGHT_KEY]
        else:
            target_key = INPUT_MASK_KEY

        callback = CriterionCallback(
            prefix=name,
            input_key=target_key,
            output_key=output_key,
            criterion_key=name,
            multiplier=float(loss_weight),
        )

        loss = get_loss(loss_name, ignore_index=ignore_index)
        criterions_dict[callback.criterion_key] = loss if wrapper is None else wrapper(loss)
        callbacks.append(callback)
        losses.append(callback.prefix)
        print("Using loss", loss_name, loss_weight)

    # Main output.
    for loss_name, loss_weight in criterions:
        _register(OUTPUT_MASK_KEY, loss_name, loss_weight, None)

    # Deep-supervision outputs; the target is resized to each prediction.
    supervision = (
        (criterions_stride2, OUTPUT_MASK_2_KEY),
        (criterions_stride4, OUTPUT_MASK_4_KEY),
        (criterions_stride8, OUTPUT_MASK_8_KEY),
        (criterions_stride16, OUTPUT_MASK_16_KEY),
    )
    for stride_losses, stride_output in supervision:
        if stride_losses is None:
            continue
        for loss_name, loss_weight in stride_losses:
            _register(stride_output, loss_name, loss_weight, ResizeTargetToPrediction2d)

    callbacks.append(MetricAggregationCallback(prefix="loss", metrics=losses, mode="sum"))
    return callbacks, criterions_dict
Exemple #2
0
def test_fn_ends_with_pass_on_callbacks():
    """Check which ``on_<event>`` hooks of catalyst callbacks are ``pass`` stubs.

    For each callback instance, hooks listed under ``"covered"`` must have a
    real body (``utils.fn_ends_with_pass`` returns ``False``). Hooks listed
    under ``"non-covered"`` must still end with ``pass`` (it returns ``True``).
    """
    def _assert_hooks(callback, events):
        callback_cls = callback.__class__
        for expected, group in ((False, "covered"), (True, "non-covered")):
            for event in events[group]:
                hook = getattr(callback_cls, f"on_{event}")
                assert (utils.fn_ends_with_pass(hook) is expected)

    # Base Callback: every hook is expected to be a stub.
    from catalyst.dl import Callback

    base_callback = Callback(order=1)
    all_events = [
        "stage_start", "epoch_start", "batch_start", "loader_start",
        "stage_end", "epoch_end", "batch_end", "loader_end", "exception",
    ]
    _assert_hooks(base_callback, {"covered": [], "non-covered": all_events})

    # CriterionCallback: only stage_start and batch_end are implemented.
    from catalyst.dl import CriterionCallback

    criterion_callback = CriterionCallback()
    _assert_hooks(
        criterion_callback,
        {
            "covered": ["stage_start", "batch_end"],
            "non-covered": [
                "epoch_start", "batch_start", "loader_start",
                "stage_end", "epoch_end", "loader_end", "exception",
            ],
        },
    )
Exemple #3
0
def get_criterion_callback(loss_name,
                           input_key,
                           output_key,
                           prefix=None,
                           loss_weight=1.0,
                           ignore_index=UNLABELED_SAMPLE):
    """Create one loss criterion and the ``CriterionCallback`` that applies it.

    Args:
        loss_name: Name passed to ``get_loss``.
        input_key: Batch key of the target; also the default prefix.
        output_key: Model-output key the loss is computed on.
        prefix: Namespace for the callback; defaults to ``input_key``.
        loss_weight: Multiplier applied to the loss value.
        ignore_index: Forwarded to ``get_loss``.

    Returns:
        ``(criterions_dict, criterion_callback, loss_prefix)``.
        ``criterions_dict`` maps ``"<prefix>/<loss_name>"`` to the loss
        module. ``loss_prefix`` is the callback's prefix, i.e. the metric
        name under which the loss is reported.
    """
    # Resolve the default prefix *before* building any key. Previously the
    # dict was keyed "None/<loss_name>" while the callback looked up
    # "<input_key>/<loss_name>", which fails at runtime.
    if prefix is None:
        prefix = input_key

    criterion_key = f"{prefix}/{loss_name}"
    criterions_dict = {
        criterion_key: get_loss(loss_name, ignore_index=ignore_index)
    }

    criterion_callback = CriterionCallback(
        prefix=criterion_key,
        input_key=input_key,
        output_key=output_key,
        criterion_key=criterion_key,
        multiplier=float(loss_weight),
    )

    return criterions_dict, criterion_callback, criterion_callback.prefix


def create_callbacks(args, criterion_names):
    """Assemble metric, checkpoint, early-stopping and loss callbacks.

    Every name in ``criterion_names`` gets a ``CriterionCallback`` reporting
    under ``"loss_<name>"``. These are combined with weight 1.0 each into the
    aggregated ``"loss"`` metric. Early stopping minimizes the monitored
    metric only when it is ``'loss'``.

    Args:
        args: Namespace providing ``dice_mode``, ``input_target_key``,
            ``class_names`` (comma-separated string or falsy), ``save_n_best``,
            ``patience`` and ``eval_metric``.
        criterion_names: Iterable of criterion keys.

    Returns:
        List of callbacks.
    """
    class_names = args.class_names.split(',') if args.class_names else None
    callbacks = [
        IoUMetricsCallback(
            mode=args.dice_mode,
            input_key=args.input_target_key,
            class_names=class_names,
        ),
        CheckpointCallback(save_n_best=args.save_n_best),
        EarlyStoppingCallback(
            patience=args.patience,
            metric=args.eval_metric,
            minimize=bool(args.eval_metric == 'loss'),
        ),
    ]

    metrics_weights = {}
    for name in criterion_names:
        loss_key = f"loss_{name}"
        callbacks.append(
            CriterionCallback(
                input_key=args.input_target_key,
                prefix=loss_key,
                criterion_key=name,
            ))
        metrics_weights[loss_key] = 1.0

    callbacks.append(
        MetricAggregationCallback(
            prefix="loss",
            mode="weighted_sum",
            metrics=metrics_weights,
        ))
    return callbacks
Exemple #5
0
# Mixup interpolation strength; 0.3 is used when the config omits the key.
mixup_alpha = get_dict_value_or_default(
    config,
    key='mixup_alpha',
    default_value=0.3,
)

if not mixup:
    # One criterion per head (h1/h2/h3), followed by the aggregator.
    callbacks.extend(
        [
            CriterionCallback(input_key=targets,
                              output_key=logits,
                              prefix=loss_prefix,
                              criterion_key=key)
            for targets, logits, loss_prefix, key in (
                ("h1_targets", "h1_logits", "loss_h1", "h1"),
                ("h2_targets", "h2_logits", "loss_h2", "h2"),
                ("h3_targets", "h3_logits", "loss_h3", "h3"),
            )
        ]
        + [crit_agg]
    )
else:
    # NOTE(review): in mixup mode only the h1 head gets a loss and
    # ``crit_agg`` is not added; confirm this is intended.
    callbacks.extend([
        MixupCallback(crit_key='h1',
                      input_key='h1_targets',
                      output_key='h1_logits',
                      alpha=mixup_alpha,
                      on_train_only=False),
    ])

# NOTE(review): this fragment is garbled as pasted, so it is not valid Python
# on its own:
#   - the ``callbacks.extend([`` list opened here is never closed;
#   - an assignment (``loaders = ...``) appears inside the list;
#   - the final ``MetricAggregationCallback(`` call is cut off.
# It looks like two snippets spliced together. The second one rebinds
# ``callbacks`` to a new list of dice/iou/bce criteria plus a class-wise IoU
# metric. Recover the intended source before using it.
callbacks.extend([
    score_callback,
    # NOTE(review): ``max_lr``, ``steps_per_epoch`` and ``epochs`` look like
    # one-cycle LR-scheduler arguments rather than early-stopping options,
    # probably a paste error. Confirm against the callback's signature.
    EarlyStoppingCallback(metric='weight_recall',
                           max_lr=0.0016,
                           steps_per_epoch=1,
                           epochs=num_epochs)
    # scheduler = OneCycleLRWithWarmup(
    #     optimizer,
    #     num_steps=num_epochs,
    #     lr_range=(0.0016, 0.0000001),
    #     init_lr = learning_rate,
    #     warmup_steps=15
    # )
    loaders = get_loaders(preprocessing_fn, batch_size=8)

    callbacks = [
        # Each criterion is calculated separately.
        CriterionCallback(input_key="mask",
                          prefix="loss_dice",
                          criterion_key="dice"),
        CriterionCallback(input_key="mask",
                          prefix="loss_iou",
                          criterion_key="iou"),
        CriterionCallback(input_key="mask",
                          prefix="loss_bce",
                          criterion_key="bce"),
        # Per-class IoU metric over the keys of CLASSES.
        ClasswiseIouCallback(input_key="mask",
                             prefix='clswise_iou',
                             classes=CLASSES.keys()),

        # And only then we aggregate everything into one loss.
        MetricAggregationCallback(
            prefix="loss",
            mode="weighted_sum",  # can be "sum", "weighted_sum" or "mean"
def main():
    """Train an INRIA building-segmentation model configured from the CLI.

    Steps:
    - Parse command-line options.
    - Initialise ``torch.distributed`` when ``WORLD_SIZE`` is set in the
      environment.
    - Build the model, datasets, losses, metrics, optimizer and scheduler.
    - Train with a catalyst ``SupervisedRunner``.

    The run configuration is written to ``runs/<prefix>/<prefix>.json``.
    When ``--epochs`` > 0, training runs and, afterwards:
    - the best checkpoint is cleaned into ``runs/<prefix>/<prefix>.pth``;
    - a thresholded prediction for ``sample_color.jpg`` is written to the
      log directory.

    Raises:
        ValueError: If ``--data-dir`` is not given and ``INRIA_DATA_DIR`` is
            not set.
    """
    parser = argparse.ArgumentParser()

    ###########################################################################################
    # Distributed-training related stuff
    parser.add_argument("--local_rank", type=int, default=0)
    ###########################################################################################

    parser.add_argument("-acc",
                        "--accumulation-steps",
                        type=int,
                        default=1,
                        help="Number of batches to process")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("--fast", action="store_true")
    parser.add_argument(
        "-dd",
        "--data-dir",
        type=str,
        help="Data directory for INRIA sattelite dataset",
        default=os.environ.get("INRIA_DATA_DIR"),
    )
    parser.add_argument("-dd-xview2",
                        "--data-dir-xview2",
                        type=str,
                        required=False,
                        help="Data directory for external xView2 dataset")
    parser.add_argument("-m",
                        "--model",
                        type=str,
                        default="resnet34_fpncat128",
                        help="")
    parser.add_argument("-b",
                        "--batch-size",
                        type=int,
                        default=8,
                        help="Batch Size during training, e.g. -b 64")
    parser.add_argument("-e",
                        "--epochs",
                        type=int,
                        default=100,
                        help="Epoch to run")
    # parser.add_argument('-es', '--early-stopping', type=int, default=None, help='Maximum number of epochs without improvement')
    # parser.add_argument('-fe', '--freeze-encoder', type=int, default=0, help='Freeze encoder parameters for N epochs')
    # parser.add_argument('-ft', '--fine-tune', action='store_true')
    parser.add_argument("-lr",
                        "--learning-rate",
                        type=float,
                        default=1e-3,
                        help="Initial learning rate")
    # Each "-l NAME WEIGHT" occurrence appends one [name, weight] list.
    parser.add_argument("-l",
                        "--criterion",
                        type=str,
                        required=True,
                        action="append",
                        nargs="+",
                        help="Criterion")
    parser.add_argument("-o",
                        "--optimizer",
                        default="RAdam",
                        help="Name of the optimizer")
    parser.add_argument(
        "-c",
        "--checkpoint",
        type=str,
        default=None,
        help="Checkpoint filename to use as initial model weights")
    parser.add_argument("-w",
                        "--workers",
                        default=8,
                        type=int,
                        help="Num workers")
    parser.add_argument("-a",
                        "--augmentations",
                        default="hard",
                        type=str,
                        help="")
    parser.add_argument("-tm",
                        "--train-mode",
                        default="random",
                        type=str,
                        help="")
    parser.add_argument("--run-mode", default="fit_predict", type=str, help="")
    parser.add_argument("--transfer", default=None, type=str, help="")
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--size", default=512, type=int)
    parser.add_argument("-s",
                        "--scheduler",
                        default="multistep",
                        type=str,
                        help="")
    parser.add_argument("-x", "--experiment", default=None, type=str, help="")
    parser.add_argument("-d",
                        "--dropout",
                        default=None,
                        type=float,
                        help="Dropout before head layer")
    parser.add_argument("--opl", action="store_true")
    parser.add_argument(
        "--warmup",
        default=0,
        type=int,
        help="Number of warmup epochs with reduced LR on encoder parameters")
    parser.add_argument("-wd",
                        "--weight-decay",
                        default=0,
                        type=float,
                        help="L2 weight decay")
    parser.add_argument("--show", action="store_true")
    parser.add_argument("--dsv", action="store_true")

    args = parser.parse_args()

    args.is_master = args.local_rank == 0
    args.distributed = False
    fp16 = args.fp16

    if "WORLD_SIZE" in os.environ:
        args.distributed = int(os.environ["WORLD_SIZE"]) > 1
        args.world_size = int(os.environ["WORLD_SIZE"])
        # args.world_size = torch.distributed.get_world_size()

        print("Initializing init_process_group", args.local_rank)

        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
        print("Initialized init_process_group", args.local_rank)

    # A single-process run always acts as master.
    is_master = args.is_master or not args.distributed

    # Despite the name, this dict carries both distributed and mixed-precision
    # (apex) options; it is passed to runner.train(fp16=...) below.
    distributed_params = {}
    if args.distributed:
        distributed_params = {"rank": args.local_rank, "syncbn": True}

    if fp16:
        distributed_params["apex"] = True
        distributed_params["opt_level"] = "O1"

    # Offset the seed by local rank so processes get distinct random streams.
    set_manual_seed(args.seed + args.local_rank)
    catalyst.utils.set_global_seed(args.seed + args.local_rank)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    data_dir = args.data_dir
    if data_dir is None:
        raise ValueError("--data-dir must be set")

    num_workers = args.workers
    num_epochs = args.epochs
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    model_name = args.model
    optimizer_name = args.optimizer
    image_size = args.size, args.size
    fast = args.fast
    augmentations = args.augmentations
    train_mode = args.train_mode
    scheduler_name = args.scheduler
    experiment = args.experiment
    dropout = args.dropout
    online_pseudolabeling = args.opl
    criterions = args.criterion
    verbose = args.verbose
    show = args.show
    use_dsv = args.dsv
    accumulation_steps = args.accumulation_steps
    weight_decay = args.weight_decay
    extra_data_xview2 = args.data_dir_xview2

    run_train = num_epochs > 0
    # Weighted BCE needs the datasets to also load per-pixel weight masks.
    need_weight_mask = any(c[0] == "wbce" for c in criterions)

    custom_model_kwargs = {}
    if dropout is not None:
        custom_model_kwargs["dropout"] = float(dropout)

    model: nn.Module = get_model(model_name, **custom_model_kwargs).cuda()

    if args.transfer:
        transfer_checkpoint = fs.auto_file(args.transfer)
        print("Transfering weights from model checkpoint", transfer_checkpoint)
        checkpoint = load_checkpoint(transfer_checkpoint)
        pretrained_dict = checkpoint["model_state_dict"]

        transfer_weights(model, pretrained_dict)

    if args.checkpoint:
        checkpoint = load_checkpoint(fs.auto_file(args.checkpoint))
        unpack_checkpoint(checkpoint, model=model)

        print("Loaded model weights from:", args.checkpoint)
        report_checkpoint(checkpoint)

    main_metric = "optimized_jaccard"

    current_time = datetime.now().strftime("%y%m%d_%H_%M")
    checkpoint_prefix = f"{current_time}_{args.model}"

    if fp16:
        checkpoint_prefix += "_fp16"

    if fast:
        checkpoint_prefix += "_fast"

    if online_pseudolabeling:
        checkpoint_prefix += "_opl"

    if extra_data_xview2:
        checkpoint_prefix += "_with_xview2"

    # An explicit experiment name replaces the auto-generated prefix.
    if experiment is not None:
        checkpoint_prefix = experiment

    # Each rank gets its own log directory.
    if args.distributed:
        checkpoint_prefix += f"_local_rank_{args.local_rank}"

    log_dir = os.path.join("runs", checkpoint_prefix)
    os.makedirs(log_dir, exist_ok=False)

    config_fname = os.path.join(log_dir, f"{checkpoint_prefix}.json")
    with open(config_fname, "w") as f:
        train_session_args = vars(args)
        f.write(json.dumps(train_session_args, indent=2))

    default_callbacks = [
        PixelAccuracyCallback(input_key=INPUT_MASK_KEY,
                              output_key=OUTPUT_MASK_KEY),
        # JaccardMetricPerImage(input_key=INPUT_MASK_KEY, output_key=OUTPUT_MASK_KEY, prefix="jaccard"),
        JaccardMetricPerImageWithOptimalThreshold(input_key=INPUT_MASK_KEY,
                                                  output_key=OUTPUT_MASK_KEY,
                                                  prefix=main_metric)
    ]

    # Checkpointing and hyper-parameter logging only on the master process.
    if is_master:
        default_callbacks += [
            BestMetricCheckpointCallback(target_metric=main_metric,
                                         target_metric_minimize=False),
            HyperParametersCallback(
                hparam_dict={
                    "model": model_name,
                    "scheduler": scheduler_name,
                    "optimizer": optimizer_name,
                    "augmentations": augmentations,
                    "size": args.size,
                    "weight_decay": weight_decay,
                    "epochs": num_epochs,
                    "dropout": None if dropout is None else float(dropout)
                }),
        ]

    if show and is_master:
        visualize_inria_predictions = partial(
            draw_inria_predictions,
            image_key=INPUT_IMAGE_KEY,
            image_id_key=INPUT_IMAGE_ID_KEY,
            targets_key=INPUT_MASK_KEY,
            outputs_key=OUTPUT_MASK_KEY,
            max_images=16,
        )
        default_callbacks += [
            ShowPolarBatchesCallback(visualize_inria_predictions,
                                     metric="accuracy",
                                     minimize=False),
            ShowPolarBatchesCallback(visualize_inria_predictions,
                                     metric="loss",
                                     minimize=True),
        ]

    train_ds, valid_ds, train_sampler = get_datasets(
        data_dir=data_dir,
        image_size=image_size,
        augmentation=augmentations,
        train_mode=train_mode,
        buildings_only=(train_mode == "tiles"),
        fast=fast,
        need_weight_mask=need_weight_mask,
    )

    if extra_data_xview2 is not None:
        extra_train_ds, _ = get_xview2_extra_dataset(
            extra_data_xview2,
            image_size=image_size,
            augmentation=augmentations,
            fast=fast,
            need_weight_mask=need_weight_mask,
        )

        # train_sampler may be None (the loader construction below handles
        # that case), so fall back to the dataset size instead of crashing.
        if train_sampler is not None:
            num_samples = 2 * train_sampler.num_samples
        else:
            num_samples = 2 * len(train_ds)

        # Draw INRIA and xView2 samples with equal overall probability.
        weights = compute_sample_weight("balanced", [0] * len(train_ds) +
                                        [1] * len(extra_train_ds))
        train_sampler = WeightedRandomSampler(weights, num_samples)

        train_ds = train_ds + extra_train_ds
        print("Using extra data from xView2 with", len(extra_train_ds),
              "samples")

    if run_train:
        loaders = collections.OrderedDict()
        callbacks = default_callbacks.copy()
        criterions_dict = {}
        losses = []

        ignore_index = None
        if online_pseudolabeling:
            ignore_index = UNLABELED_SAMPLE
            unlabeled_label = get_pseudolabeling_dataset(data_dir,
                                                         include_masks=False,
                                                         augmentation=None,
                                                         image_size=image_size)

            unlabeled_train = get_pseudolabeling_dataset(
                data_dir,
                include_masks=True,
                augmentation=augmentations,
                image_size=image_size)

            loaders["label"] = DataLoader(unlabeled_label,
                                          batch_size=batch_size // 2,
                                          num_workers=num_workers,
                                          pin_memory=True)

            if train_sampler is not None:
                num_samples = 2 * train_sampler.num_samples
            else:
                num_samples = 2 * len(train_ds)
            # Weights must line up with the concatenated dataset below, i.e.
            # train_ds followed by unlabeled_train.
            weights = compute_sample_weight("balanced", [0] * len(train_ds) +
                                            [1] * len(unlabeled_train))

            train_sampler = WeightedRandomSampler(weights,
                                                  num_samples,
                                                  replacement=True)
            train_ds = train_ds + unlabeled_train

            callbacks += [
                BCEOnlinePseudolabelingCallback2d(
                    unlabeled_train,
                    pseudolabel_loader="label",
                    prob_threshold=0.7,
                    output_key=OUTPUT_MASK_KEY,
                    unlabeled_class=UNLABELED_SAMPLE,
                    label_frequency=5,
                )
            ]

            print("Using online pseudolabeling with ", len(unlabeled_label),
                  "samples")

        valid_sampler = None
        if args.distributed:
            if train_sampler is not None:
                train_sampler = DistributedSamplerWrapper(train_sampler,
                                                          args.world_size,
                                                          args.local_rank,
                                                          shuffle=True)
            else:
                train_sampler = DistributedSampler(train_ds,
                                                   args.world_size,
                                                   args.local_rank,
                                                   shuffle=True)
            valid_sampler = DistributedSampler(valid_ds,
                                               args.world_size,
                                               args.local_rank,
                                               shuffle=False)

        loaders["train"] = DataLoader(
            train_ds,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            shuffle=train_sampler is None,
            sampler=train_sampler,
        )

        loaders["valid"] = DataLoader(valid_ds,
                                      batch_size=batch_size,
                                      num_workers=num_workers,
                                      pin_memory=True,
                                      sampler=valid_sampler)

        # Create losses
        for loss_name, loss_weight in criterions:
            criterion_callback = CriterionCallback(
                prefix=f"{INPUT_MASK_KEY}/" + loss_name,
                input_key=INPUT_MASK_KEY if loss_name != "wbce" else
                [INPUT_MASK_KEY, INPUT_MASK_WEIGHT_KEY],
                output_key=OUTPUT_MASK_KEY,
                criterion_key=loss_name,
                multiplier=float(loss_weight),
            )

            criterions_dict[loss_name] = get_loss(loss_name,
                                                  ignore_index=ignore_index)
            callbacks.append(criterion_callback)
            losses.append(criterion_callback.prefix)
            print("Using loss", loss_name, loss_weight)

        if use_dsv:
            print("Using DSV")
            # Separate variable: the old code overwrote ``criterions`` here,
            # so the summary below printed "dsv" instead of the criteria.
            dsv_criterion_key = "dsv"
            dsv_loss_name = "soft_bce"

            criterions_dict[dsv_criterion_key] = AdaptiveMaskLoss2d(
                get_loss(dsv_loss_name, ignore_index=ignore_index))

            for dsv_input in [
                    OUTPUT_MASK_4_KEY, OUTPUT_MASK_8_KEY, OUTPUT_MASK_16_KEY,
                    OUTPUT_MASK_32_KEY
            ]:
                criterion_callback = CriterionCallback(
                    prefix=f"{dsv_input}/" + dsv_loss_name,
                    input_key=INPUT_MASK_KEY,
                    output_key=dsv_input,
                    criterion_key=dsv_criterion_key,
                    multiplier=1.0,
                )
                callbacks.append(criterion_callback)
                losses.append(criterion_callback.prefix)

        if isinstance(model, SupervisedHGSegmentationModel):
            print("Using Hourglass DSV")
            # Own criterion key so it cannot silently replace the --dsv
            # criterion above when both are active.
            hg_criterion_key = "hg_dsv"
            dsv_loss_name = "kl"

            criterions_dict[hg_criterion_key] = get_loss(
                dsv_loss_name, ignore_index=ignore_index)
            num_supervision_inputs = model.encoder.num_blocks - 1
            dsv_outputs = [
                OUTPUT_MASK_4_KEY + "_after_hg_" + str(i)
                for i in range(num_supervision_inputs)
            ]

            # Later hourglass outputs get linearly increasing weights.
            for i, dsv_input in enumerate(dsv_outputs):
                criterion_callback = CriterionCallback(
                    prefix="supervision/" + dsv_input,
                    input_key=INPUT_MASK_KEY,
                    output_key=dsv_input,
                    criterion_key=hg_criterion_key,
                    multiplier=(i + 1) / num_supervision_inputs,
                )
                callbacks.append(criterion_callback)
                losses.append(criterion_callback.prefix)

        callbacks += [
            MetricAggregationCallback(prefix="loss",
                                      metrics=losses,
                                      mode="sum"),
            OptimizerCallback(accumulation_steps=accumulation_steps,
                              decouple_weight_decay=False),
        ]

        optimizer = get_optimizer(optimizer_name,
                                  get_optimizable_parameters(model),
                                  learning_rate,
                                  weight_decay=weight_decay)
        scheduler = get_scheduler(scheduler_name,
                                  optimizer,
                                  lr=learning_rate,
                                  num_epochs=num_epochs,
                                  batches_in_epoch=len(loaders["train"]))
        # Cyclic / one-cycle schedules are stepped per batch, not per epoch.
        if isinstance(scheduler, (CyclicLR, OneCycleLRWithWarmup)):
            callbacks += [SchedulerCallback(mode="batch")]

        print("Train session    :", checkpoint_prefix)
        print("  FP16 mode      :", fp16)
        print("  Fast mode      :", args.fast)
        print("  Train mode     :", train_mode)
        print("  Epochs         :", num_epochs)
        print("  Workers        :", num_workers)
        print("  Data dir       :", data_dir)
        print("  Log dir        :", log_dir)
        print("  Augmentations  :", augmentations)
        print("  Train size     :", "batches", len(loaders["train"]),
              "dataset", len(train_ds))
        print("  Valid size     :", "batches", len(loaders["valid"]),
              "dataset", len(valid_ds))
        print("Model            :", model_name)
        print("  Parameters     :", count_parameters(model))
        print("  Image size     :", image_size)
        print("Optimizer        :", optimizer_name)
        print("  Learning rate  :", learning_rate)
        print("  Batch size     :", batch_size)
        print("  Criterion      :", criterions)
        print("  Use weight mask:", need_weight_mask)
        if args.distributed:
            print("Distributed")
            print("  World size     :", args.world_size)
            print("  Local rank     :", args.local_rank)
            print("  Is master      :", args.is_master)

        # model training
        runner = SupervisedRunner(input_key=INPUT_IMAGE_KEY,
                                  output_key=None,
                                  device="cuda")
        # NOTE(review): distributed/apex options travel through ``fp16=``;
        # confirm this matches the installed catalyst version's API.
        runner.train(
            fp16=distributed_params,
            model=model,
            criterion=criterions_dict,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders=loaders,
            logdir=os.path.join(log_dir, "main"),
            num_epochs=num_epochs,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=False,
            checkpoint_data={"cmd_args": vars(args)},
        )

        # Training is finished. Let's run predictions using best checkpoint weights
        best_checkpoint = os.path.join(log_dir, "main", "checkpoints",
                                       "best.pth")

        model_checkpoint = os.path.join(log_dir, f"{checkpoint_prefix}.pth")
        clean_checkpoint(best_checkpoint, model_checkpoint)

        unpack_checkpoint(torch.load(model_checkpoint), model=model)

        mask = predict(model,
                       read_inria_image("sample_color.jpg"),
                       image_size=image_size,
                       batch_size=args.batch_size)
        # Threshold raw outputs at 0 and save as an 8-bit binary image.
        mask = ((mask > 0) * 255).astype(np.uint8)
        name = os.path.join(log_dir, "sample_color.jpg")
        cv2.imwrite(name, mask)

        del optimizer, loaders
Exemple #8
0
# The model is fed the "seg_features" batch entry; its two outputs are stored
# under "cls_logits" and "seg_logits".
runner = SupervisedRunner(
    input_key=["seg_features"],
    output_key=["cls_logits", "seg_logits"],
)


def calc_metric(pred, gt, *args, **kwargs):
    """Compute ROC-AUC between ``sigmoid(pred)`` and binary targets ``gt``.

    Both tensors are detached, moved to CPU and flattened before scoring.
    Extra positional/keyword arguments are accepted for callback
    compatibility and ignored.

    Args:
        pred: Logits tensor.
        gt: Target tensor; cast to ``uint8``.

    Returns:
        A one-element list holding the ROC-AUC score. It holds ``0`` when the
        score is undefined, e.g. when ``gt`` contains a single class.
    """
    pred = torch.sigmoid(pred).detach().cpu().numpy()
    gt = gt.detach().cpu().numpy().astype(np.uint8)
    try:
        return [roc_auc_score(gt.reshape(-1), pred.reshape(-1))]
    except ValueError:
        # roc_auc_score raises ValueError when y_true holds a single class
        # (common for small batches). Only that case is mapped to 0; the old
        # bare ``except:`` also hid real bugs and KeyboardInterrupt.
        return [0]


# NOTE(review): this list literal is malformed as pasted and is a syntax
# error. The stray ``input_target_key=None)`` on the last line closes with
# ``)`` instead of ``]``. ``input_target_key`` presumably belonged to another
# call; confirm against the original source.
callbacks = [
    CriterionCallback(input_key="cls_targets",
                      output_key="cls_logits",
                      prefix="loss_cls",
                      criterion_key="cls"),
    CriterionCallback(input_key="seg_targets",
                      output_key="seg_logits",
                      prefix="loss_seg",
                      criterion_key="seg"),
    # Sum the classification and segmentation losses into "loss".
    CriterionAggregatorCallback(
        prefix="loss",
        loss_keys=["loss_cls", "loss_seg"],
        loss_aggregate_fn="sum"  # or "mean"
    ),
    # ROC-AUC of the classification head via calc_metric.
    MultiMetricCallback(metric_fn=calc_metric,
                        prefix='rocauc',
                        input_key="cls_targets",
                        output_key="cls_logits",
                        list_args=['_']),
                          input_target_key=None)

# RAdam over all model parameters with a fixed weight decay of 1e-3.
optimizer = RAdam(model.parameters(), lr=args.lr, weight_decay=0.001)

# Multiply the LR by 0.75 once a maximized metric stops improving for
# 3 scheduler steps.
scheduler = ReduceLROnPlateau(
    optimizer=optimizer,
    mode="max",
    factor=0.75,
    patience=3,
)

criterion = {'label_loss': nn.CrossEntropyLoss()}

# A single cross-entropy loss, exposed as the total "loss", plus weighted AUC.
callbacks = [
    CriterionCallback(input_key="label",
                      output_key="logit_label",
                      prefix="label_loss",
                      criterion_key="label_loss",
                      multiplier=1.0),
    MetricAggregationCallback(prefix="loss", metrics=["label_loss"]),
    WeightedAUC(input_key="label", output_key="logit_label"),
]

# NOTE(review): truncated as pasted. This ``runner.train(`` call is never
# closed (the remaining arguments are missing), so the span is not valid
# Python on its own.
runner.train(model=model,
             criterion=criterion,
             optimizer=optimizer,
def main():
    """Train an INRIA building-segmentation model from the command line.

    Steps: parse CLI arguments, build the model (optionally transferring or
    loading weights), build datasets (optionally adding external xView2 data
    and online pseudolabeling), run an optional warmup stage with a reduced
    learning rate, then run the main training stage with Catalyst's
    ``SupervisedRunner``. After training, the best checkpoint is cleaned,
    reloaded, and used to predict ``sample_color.jpg`` into the run's log dir.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-acc",
                        "--accumulation-steps",
                        type=int,
                        default=1,
                        help="Number of batches to process")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("--fast", action="store_true")
    parser.add_argument("-dd",
                        "--data-dir",
                        type=str,
                        required=True,
                        help="Data directory for INRIA sattelite dataset")
    parser.add_argument("-dd-xview2",
                        "--data-dir-xview2",
                        type=str,
                        required=False,
                        help="Data directory for external xView2 dataset")
    parser.add_argument("-m",
                        "--model",
                        type=str,
                        default="resnet34_fpncat128",
                        help="")
    parser.add_argument("-b",
                        "--batch-size",
                        type=int,
                        default=8,
                        help="Batch Size during training, e.g. -b 64")
    parser.add_argument("-e",
                        "--epochs",
                        type=int,
                        default=100,
                        help="Epoch to run")
    # parser.add_argument('-es', '--early-stopping', type=int, default=None, help='Maximum number of epochs without improvement')
    # parser.add_argument('-fe', '--freeze-encoder', type=int, default=0, help='Freeze encoder parameters for N epochs')
    # parser.add_argument('-ft', '--fine-tune', action='store_true')
    parser.add_argument("-lr",
                        "--learning-rate",
                        type=float,
                        default=1e-3,
                        help="Initial learning rate")
    parser.add_argument("-l",
                        "--criterion",
                        type=str,
                        required=True,
                        action="append",
                        nargs="+",
                        help="Criterion")
    parser.add_argument("-o",
                        "--optimizer",
                        default="RAdam",
                        help="Name of the optimizer")
    parser.add_argument(
        "-c",
        "--checkpoint",
        type=str,
        default=None,
        help="Checkpoint filename to use as initial model weights")
    parser.add_argument("-w",
                        "--workers",
                        default=8,
                        type=int,
                        help="Num workers")
    parser.add_argument("-a",
                        "--augmentations",
                        default="hard",
                        type=str,
                        help="")
    parser.add_argument("-tm",
                        "--train-mode",
                        default="random",
                        type=str,
                        help="")
    parser.add_argument("--run-mode", default="fit_predict", type=str, help="")
    parser.add_argument("--transfer", default=None, type=str, help="")
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--size", default=512, type=int)
    parser.add_argument("-s",
                        "--scheduler",
                        default="multistep",
                        type=str,
                        help="")
    parser.add_argument("-x", "--experiment", default=None, type=str, help="")
    parser.add_argument("-d",
                        "--dropout",
                        default=0.0,
                        type=float,
                        help="Dropout before head layer")
    parser.add_argument("--opl", action="store_true")
    parser.add_argument(
        "--warmup",
        default=0,
        type=int,
        help="Number of warmup epochs with reduced LR on encoder parameters")
    parser.add_argument("-wd",
                        "--weight-decay",
                        default=0,
                        type=float,
                        help="L2 weight decay")
    parser.add_argument("--show", action="store_true")
    parser.add_argument("--dsv", action="store_true")

    args = parser.parse_args()
    set_manual_seed(args.seed)

    data_dir = args.data_dir
    num_workers = args.workers
    num_epochs = args.epochs
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    model_name = args.model
    optimizer_name = args.optimizer
    image_size = args.size, args.size
    fast = args.fast
    augmentations = args.augmentations
    train_mode = args.train_mode
    fp16 = args.fp16
    scheduler_name = args.scheduler
    experiment = args.experiment
    dropout = args.dropout
    online_pseudolabeling = args.opl
    # List of [loss_name, loss_weight] lists collected from repeated "-l" flags.
    criterions = args.criterion
    verbose = args.verbose
    warmup = args.warmup
    show = args.show
    use_dsv = args.dsv
    accumulation_steps = args.accumulation_steps
    weight_decay = args.weight_decay
    extra_data_xview2 = args.data_dir_xview2

    run_train = num_epochs > 0
    # Datasets only need to provide the weight mask when a "wbce" loss
    # (which also consumes INPUT_MASK_WEIGHT_KEY) is requested.
    need_weight_mask = any(c[0] == "wbce" for c in criterions)

    model: nn.Module = get_model(model_name, dropout=dropout).cuda()

    if args.transfer:
        transfer_checkpoint = fs.auto_file(args.transfer)
        print("Transfering weights from model checkpoint", transfer_checkpoint)
        checkpoint = load_checkpoint(transfer_checkpoint)
        pretrained_dict = checkpoint["model_state_dict"]

        transfer_weights(model, pretrained_dict)

    if args.checkpoint:
        checkpoint = load_checkpoint(fs.auto_file(args.checkpoint))
        unpack_checkpoint(checkpoint, model=model)

        print("Loaded model weights from:", args.checkpoint)
        report_checkpoint(checkpoint)

    runner = SupervisedRunner(input_key=INPUT_IMAGE_KEY,
                              output_key=None,
                              device="cuda")
    main_metric = "optimized_jaccard"
    cmd_args = vars(args)

    current_time = datetime.now().strftime("%b%d_%H_%M")
    checkpoint_prefix = f"{current_time}_{args.model}"

    if fp16:
        checkpoint_prefix += "_fp16"

    if fast:
        checkpoint_prefix += "_fast"

    if online_pseudolabeling:
        checkpoint_prefix += "_opl"

    if extra_data_xview2:
        checkpoint_prefix += "_with_xview2"

    if experiment is not None:
        checkpoint_prefix = experiment

    log_dir = os.path.join("runs", checkpoint_prefix)
    os.makedirs(log_dir, exist_ok=False)

    config_fname = os.path.join(log_dir, f"{checkpoint_prefix}.json")
    with open(config_fname, "w") as f:
        train_session_args = vars(args)
        f.write(json.dumps(train_session_args, indent=2))

    default_callbacks = [
        PixelAccuracyCallback(input_key=INPUT_MASK_KEY,
                              output_key=OUTPUT_MASK_KEY),
        JaccardMetricPerImage(input_key=INPUT_MASK_KEY,
                              output_key=OUTPUT_MASK_KEY,
                              prefix="jaccard"),
        OptimalThreshold(input_key=INPUT_MASK_KEY,
                         output_key=OUTPUT_MASK_KEY,
                         prefix="optimized_jaccard"),
        # OutputDistributionCallback(output_key=OUTPUT_MASK_KEY, activation=torch.sigmoid),
    ]

    if show:
        visualize_inria_predictions = partial(
            draw_inria_predictions,
            image_key=INPUT_IMAGE_KEY,
            image_id_key=INPUT_IMAGE_ID_KEY,
            targets_key=INPUT_MASK_KEY,
            outputs_key=OUTPUT_MASK_KEY,
        )
        default_callbacks += [
            ShowPolarBatchesCallback(visualize_inria_predictions,
                                     metric="accuracy",
                                     minimize=False)
        ]

    train_ds, valid_ds, train_sampler = get_datasets(
        data_dir=data_dir,
        image_size=image_size,
        augmentation=augmentations,
        train_mode=train_mode,
        fast=fast,
        need_weight_mask=need_weight_mask,
    )

    if extra_data_xview2 is not None:
        extra_train_ds, _ = get_xview2_extra_dataset(
            extra_data_xview2,
            image_size=image_size,
            augmentation=augmentations,
            fast=fast,
            need_weight_mask=need_weight_mask,
        )

        # get_datasets() may return no sampler (the pseudolabeling branch
        # below guards for this too); fall back to the dataset length so the
        # doubled epoch size is still well-defined.
        if train_sampler is not None:
            num_samples = 2 * train_sampler.num_samples
        else:
            num_samples = 2 * len(train_ds)
        # Balance sampling between the INRIA (0) and xView2 (1) halves.
        weights = compute_sample_weight("balanced", [0] * len(train_ds) +
                                        [1] * len(extra_train_ds))
        train_sampler = WeightedRandomSampler(weights, num_samples)

        train_ds = train_ds + extra_train_ds
        print("Using extra data from xView2 with", len(extra_train_ds),
              "samples")

    # Pretrain/warmup
    if warmup:
        callbacks = default_callbacks.copy()
        criterions_dict, losses = _add_segmentation_losses(
            criterions, callbacks, ignore_index=None)

        callbacks += [
            CriterionAggregatorCallback(prefix="loss", loss_keys=losses),
            OptimizerCallback(accumulation_steps=accumulation_steps,
                              decouple_weight_decay=False),
        ]

        # Encoder parameters get a 0.1 LR factor during warmup.
        parameters = get_lr_decay_parameters(model.named_parameters(),
                                             learning_rate, {"encoder": 0.1})
        optimizer = get_optimizer("RAdam",
                                  parameters,
                                  learning_rate=learning_rate * 0.1)

        loaders = collections.OrderedDict()
        loaders["train"] = DataLoader(
            train_ds,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            shuffle=train_sampler is None,
            sampler=train_sampler,
        )

        loaders["valid"] = DataLoader(valid_ds,
                                      batch_size=batch_size,
                                      num_workers=num_workers,
                                      pin_memory=True,
                                      shuffle=False,
                                      drop_last=False)

        runner.train(
            fp16=fp16,
            model=model,
            criterion=criterions_dict,
            optimizer=optimizer,
            scheduler=None,
            callbacks=callbacks,
            loaders=loaders,
            logdir=os.path.join(log_dir, "warmup"),
            num_epochs=warmup,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=False,
            checkpoint_data={"cmd_args": cmd_args},
        )

        del optimizer, loaders

        best_checkpoint = os.path.join(log_dir, "warmup", "checkpoints",
                                       "best.pth")
        model_checkpoint = os.path.join(log_dir, "warmup", "checkpoints",
                                        f"{checkpoint_prefix}_warmup.pth")
        clean_checkpoint(best_checkpoint, model_checkpoint)

        torch.cuda.empty_cache()
        gc.collect()

    if run_train:
        loaders = collections.OrderedDict()
        callbacks = default_callbacks.copy()

        ignore_index = None
        if online_pseudolabeling:
            ignore_index = UNLABELED_SAMPLE
            unlabeled_label = get_pseudolabeling_dataset(data_dir,
                                                         include_masks=False,
                                                         augmentation=None,
                                                         image_size=image_size)

            unlabeled_train = get_pseudolabeling_dataset(
                data_dir,
                include_masks=True,
                augmentation=augmentations,
                image_size=image_size)

            loaders["label"] = DataLoader(unlabeled_label,
                                          batch_size=batch_size // 2,
                                          num_workers=num_workers,
                                          pin_memory=True)

            if train_sampler is not None:
                num_samples = 2 * train_sampler.num_samples
            else:
                num_samples = 2 * len(train_ds)
            # One weight per sample of the concatenated dataset built below
            # (train_ds + unlabeled_train), so size the second group from
            # unlabeled_train itself.
            weights = compute_sample_weight("balanced", [0] * len(train_ds) +
                                            [1] * len(unlabeled_train))

            train_sampler = WeightedRandomSampler(weights,
                                                  num_samples,
                                                  replacement=True)
            train_ds = train_ds + unlabeled_train

            callbacks += [
                BCEOnlinePseudolabelingCallback2d(
                    unlabeled_train,
                    pseudolabel_loader="label",
                    prob_threshold=0.7,
                    output_key=OUTPUT_MASK_KEY,
                    unlabeled_class=UNLABELED_SAMPLE,
                    label_frequency=5,
                )
            ]

            print("Using online pseudolabeling with ", len(unlabeled_label),
                  "samples")

        loaders["train"] = DataLoader(
            train_ds,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            shuffle=train_sampler is None,
            sampler=train_sampler,
        )

        loaders["valid"] = DataLoader(valid_ds,
                                      batch_size=batch_size,
                                      num_workers=num_workers,
                                      pin_memory=True)

        # Create losses
        criterions_dict, losses = _add_segmentation_losses(
            criterions, callbacks, ignore_index=ignore_index)

        if use_dsv:
            print("Using DSV")
            # Use a dedicated name for the criterion key so that `criterions`
            # (the CLI loss spec printed in the summary below) is not clobbered.
            dsv_criterion_key = "dsv"
            dsv_loss_name = "soft_bce"

            criterions_dict[dsv_criterion_key] = AdaptiveMaskLoss2d(
                get_loss(dsv_loss_name, ignore_index=ignore_index))

            # Intermediate outputs are supervised with the main mask output
            # as the target.
            for dsv_input in [
                    OUTPUT_MASK_4_KEY, OUTPUT_MASK_8_KEY, OUTPUT_MASK_16_KEY,
                    OUTPUT_MASK_32_KEY
            ]:
                criterion_callback = CriterionCallback(
                    prefix="seg_loss_dsv/" + dsv_input,
                    input_key=OUTPUT_MASK_KEY,
                    output_key=dsv_input,
                    criterion_key=dsv_criterion_key,
                    multiplier=1.0,
                )
                callbacks.append(criterion_callback)
                losses.append(criterion_callback.prefix)

        callbacks += [
            CriterionAggregatorCallback(prefix="loss", loss_keys=losses),
            OptimizerCallback(accumulation_steps=accumulation_steps,
                              decouple_weight_decay=False),
        ]

        optimizer = get_optimizer(optimizer_name,
                                  get_optimizable_parameters(model),
                                  learning_rate,
                                  weight_decay=weight_decay)
        scheduler = get_scheduler(scheduler_name,
                                  optimizer,
                                  lr=learning_rate,
                                  num_epochs=num_epochs,
                                  batches_in_epoch=len(loaders["train"]))
        # Per-iteration schedulers must be stepped every batch, not every epoch.
        if isinstance(scheduler, (CyclicLR, OneCycleLRWithWarmup)):
            callbacks += [SchedulerCallback(mode="batch")]

        print("Train session    :", checkpoint_prefix)
        print("\tFP16 mode      :", fp16)
        print("\tFast mode      :", fast)
        print("\tTrain mode     :", train_mode)
        print("\tEpochs         :", num_epochs)
        print("\tWorkers        :", num_workers)
        print("\tData dir       :", data_dir)
        print("\tLog dir        :", log_dir)
        print("\tAugmentations  :", augmentations)
        print("\tTrain size     :", len(loaders["train"]), len(train_ds))
        print("\tValid size     :", len(loaders["valid"]), len(valid_ds))
        print("Model            :", model_name)
        print("\tParameters     :", count_parameters(model))
        print("\tImage size     :", image_size)
        print("Optimizer        :", optimizer_name)
        print("\tLearning rate  :", learning_rate)
        print("\tBatch size     :", batch_size)
        print("\tCriterion      :", criterions)
        print("\tUse weight mask:", need_weight_mask)

        # model training
        runner.train(
            fp16=fp16,
            model=model,
            criterion=criterions_dict,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders=loaders,
            logdir=os.path.join(log_dir, "main"),
            num_epochs=num_epochs,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=False,
            checkpoint_data={"cmd_args": vars(args)},
        )

        # Training is finished. Let's run predictions using best checkpoint weights
        best_checkpoint = os.path.join(log_dir, "main", "checkpoints",
                                       "best.pth")

        model_checkpoint = os.path.join(log_dir, "main", "checkpoints",
                                        f"{checkpoint_prefix}.pth")
        clean_checkpoint(best_checkpoint, model_checkpoint)

        # Same loader as used for --transfer / --checkpoint above.
        unpack_checkpoint(load_checkpoint(model_checkpoint), model=model)

        mask = predict(model,
                       read_inria_image("sample_color.jpg"),
                       image_size=image_size,
                       batch_size=args.batch_size)
        mask = ((mask > 0) * 255).astype(np.uint8)
        name = os.path.join(log_dir, "sample_color.jpg")
        cv2.imwrite(name, mask)

        del optimizer, loaders


def _add_segmentation_losses(criterions, callbacks, ignore_index=None):
    """Create one mask CriterionCallback per ``(loss_name, loss_weight)`` pair.

    Args:
        criterions: iterable of ``(loss_name, loss_weight)`` pairs; the weight
            is converted with ``float``.
        callbacks: list the new CriterionCallbacks are appended to, in order.
        ignore_index: forwarded to ``get_loss`` for every loss.

    Returns:
        ``(criterions_dict, losses)``: the loss objects keyed by loss name
        (the callbacks' ``criterion_key``), and the callback prefixes to
        aggregate into the total loss.
    """
    criterions_dict = {}
    losses = []
    for loss_name, loss_weight in criterions:
        criterion_callback = CriterionCallback(
            prefix="seg_loss/" + loss_name,
            # "wbce" also receives the weight mask alongside the target mask.
            input_key=INPUT_MASK_KEY if loss_name != "wbce" else
            [INPUT_MASK_KEY, INPUT_MASK_WEIGHT_KEY],
            output_key=OUTPUT_MASK_KEY,
            criterion_key=loss_name,
            multiplier=float(loss_weight),
        )

        criterions_dict[loss_name] = get_loss(loss_name,
                                              ignore_index=ignore_index)
        callbacks.append(criterion_callback)
        losses.append(criterion_callback.prefix)
        print("Using loss", loss_name, loss_weight)
    return criterions_dict, losses
Exemple #11
0
def main(cfg: DictConfig):
    """Train a HuBMAP segmentation model in stages, then validate it.

    If ``cfg.resume`` is set, the model config (and, in validation-only mode,
    the scale factor) is taken from the resumed run, and its
    ``best_full.pth`` checkpoint is loaded. When ``cfg.train.num_epochs`` is
    not None, one training stage is run per ``(size, num_epochs)`` pair from
    ``cfg.data.sizes`` / ``cfg.train.num_epochs``, and the Dice-optimal
    threshold is searched on the validation loader after each stage.
    Otherwise only the threshold search is run. Finally, full-size
    predictions for every ``cfg.data.valid_ids`` image are scored with Dice,
    and the threshold is saved into ``.hydra/config.yaml``.
    """

    cwd = Path(get_original_cwd())

    # overwrite config if continue training from checkpoint
    resume_cfg = None
    if "resume" in cfg:
        cfg_path = cwd / cfg.resume / ".hydra/config.yaml"
        print(f"Continue from: {cfg.resume}")
        # Overwrite everything except device
        # TODO config merger (perhaps continue training with the same optimizer but other lrs?)
        resume_cfg = OmegaConf.load(cfg_path)
        cfg.model = resume_cfg.model
        # Validation-only runs (num_epochs is None, see the branch below; 0 or
        # empty also means "no training") reuse the resumed run's scale factor.
        if not cfg.train.num_epochs:
            cfg.data.scale_factor = resume_cfg.data.scale_factor
        OmegaConf.save(cfg, ".hydra/config.yaml")

    print(OmegaConf.to_yaml(cfg))

    device = set_device_id(cfg.device)
    set_seed(cfg.seed, device=device)

    # Augmentations
    if cfg.data.aug == "auto":
        transforms = albu.load(cwd / "autoalbument/autoconfig.json")
    else:
        transforms = D.get_training_augmentations()

    if OmegaConf.is_missing(cfg.model, "convert_bottleneck"):
        cfg.model.convert_bottleneck = (0, 0, 0)

    # Model
    print(f"Setup model {cfg.model.arch} {cfg.model.encoder_name} "
          f"convert_bn={cfg.model.convert_bn} "
          f"convert_bottleneck={cfg.model.convert_bottleneck} ")
    model = get_segmentation_model(
        arch=cfg.model.arch,
        encoder_name=cfg.model.encoder_name,
        encoder_weights=cfg.model.encoder_weights,
        classes=1,
        convert_bn=cfg.model.convert_bn,
        convert_bottleneck=cfg.model.convert_bottleneck,
        # decoder_attention_type="scse",  # TODO to config
    )
    model = model.to(device)
    model.train()
    print(model)

    # Optimization
    # Reduce LR for pretrained encoder
    layerwise_params = {
        "encoder*":
        dict(lr=cfg.optim.lr_encoder, weight_decay=cfg.optim.wd_encoder)
    }
    model_params = cutils.process_model_params(
        model, layerwise_params=layerwise_params)

    # Select optimizer
    optimizer = get_optimizer(
        name=cfg.optim.name,
        model_params=model_params,
        lr=cfg.optim.lr,
        wd=cfg.optim.wd,
        lookahead=cfg.optim.lookahead,
    )

    criterion = {
        "dice": DiceLoss(),
        # "dice": SoftDiceLoss(mode="binary", smooth=1e-7),
        "iou": IoULoss(),
        "bce": nn.BCEWithLogitsLoss(),
        "lovasz": LovaszLossBinary(),
        "focal_tversky": FocalTverskyLoss(eps=1e-7, alpha=0.7, gamma=0.75),
    }

    # Load states if resuming training
    if "resume" in cfg:
        checkpoint_path = (cwd / cfg.resume / cfg.train.logdir /
                           "checkpoints/best_full.pth")
        if checkpoint_path.exists():
            print(f"\nLoading checkpoint {str(checkpoint_path)}")
            checkpoint = cutils.load_checkpoint(checkpoint_path)
            cutils.unpack_checkpoint(
                checkpoint=checkpoint,
                model=model,
                # Optimizer state is only restored if the optimizer type matches.
                optimizer=optimizer
                if resume_cfg.optim.name == cfg.optim.name else None,
                criterion=criterion,
            )
        else:
            raise ValueError("Nothing to resume, checkpoint missing")

    # We could only want to validate resume, in this case skip training routine
    best_th = 0.5

    stats = None
    if cfg.data.stats:
        print(f"Use statistics from file: {cfg.data.stats}")
        stats = cwd / cfg.data.stats

    if cfg.train.num_epochs is not None:
        callbacks = [
            # Each criterion is calculated separately.
            CriterionCallback(input_key="mask",
                              prefix="loss_dice",
                              criterion_key="dice"),
            CriterionCallback(input_key="mask",
                              prefix="loss_iou",
                              criterion_key="iou"),
            CriterionCallback(input_key="mask",
                              prefix="loss_bce",
                              criterion_key="bce"),
            CriterionCallback(input_key="mask",
                              prefix="loss_lovasz",
                              criterion_key="lovasz"),
            CriterionCallback(
                input_key="mask",
                prefix="loss_focal_tversky",
                criterion_key="focal_tversky",
            ),
            # And only then we aggregate everything into one loss.
            MetricAggregationCallback(
                prefix="loss",
                mode="weighted_sum",  # can be "sum", "weighted_sum" or "mean"
                # because we want weighted sum, we need to add scale for each loss
                metrics={
                    "loss_dice": cfg.loss.dice,
                    "loss_iou": cfg.loss.iou,
                    "loss_bce": cfg.loss.bce,
                    "loss_lovasz": cfg.loss.lovasz,
                    "loss_focal_tversky": cfg.loss.focal_tversky,
                },
            ),
            # metrics
            DiceCallback(input_key="mask"),
            IouCallback(input_key="mask"),
            # gradient accumulation
            OptimizerCallback(accumulation_steps=cfg.optim.accumulate),
            # early stopping
            SchedulerCallback(reduced_metric="loss_dice",
                              mode=cfg.scheduler.mode),
            EarlyStoppingCallback(**cfg.scheduler.early_stopping,
                                  minimize=False),
            # TODO WandbLogger works poorly with multistage right now
            WandbLogger(project=cfg.project, config=dict(cfg)),
            # CheckpointCallback(save_n_best=cfg.checkpoint.save_n_best),
        ]

        # Training
        runner = SupervisedRunner(device=device,
                                  input_key="image",
                                  input_target_key="mask")

        # TODO Scheduler does not work now, every stage restarts from base lr
        scheduler_warm_restart = optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[1, 2],
            gamma=10,
        )

        for i, (size, num_epochs) in enumerate(
                zip(cfg.data.sizes, cfg.train.num_epochs)):
            scale = size / 1024
            print(
                f"Training stage {i}, scale {scale}, size {size}, epochs {num_epochs}"
            )

            # Datasets
            (
                train_ds,
                valid_ds,
                train_images,
                val_images,
            ) = D.get_train_valid_datasets_from_path(
                # path=(cwd / cfg.data.path),
                path=(cwd / f"data/hubmap-{size}x{size}/"),
                train_ids=cfg.data.train_ids,
                valid_ids=cfg.data.valid_ids,
                seed=cfg.seed,
                valid_split=cfg.data.valid_split,
                mean=cfg.data.mean,
                std=cfg.data.std,
                transforms=transforms,
                stats=stats,
            )

            # Batch sizes scale with the inverse tile area relative to 1024px.
            train_bs = int(cfg.loader.train_bs / (scale**2))
            valid_bs = int(cfg.loader.valid_bs / (scale**2))
            print(
                f"train: {len(train_ds)}; bs {train_bs}",
                f"valid: {len(valid_ds)}, bs {valid_bs}",
            )

            # Data loaders
            data_loaders = D.get_data_loaders(
                train_ds=train_ds,
                valid_ds=valid_ds,
                train_bs=train_bs,
                valid_bs=valid_bs,
                num_workers=cfg.loader.num_workers,
            )

            # Select scheduler
            scheduler = get_scheduler(
                name=cfg.scheduler.type,
                optimizer=optimizer,
                num_epochs=num_epochs * (len(data_loaders["train"]) if
                                         cfg.scheduler.mode == "batch" else 1),
                eta_min=scheduler_warm_restart.get_last_lr()[0] /
                cfg.scheduler.eta_min_factor,
                plateau=cfg.scheduler.plateau,
            )

            runner.train(
                model=model,
                criterion=criterion,
                optimizer=optimizer,
                scheduler=scheduler,
                callbacks=callbacks,
                logdir=cfg.train.logdir,
                loaders=data_loaders,
                num_epochs=num_epochs,
                verbose=True,
                main_metric=cfg.train.main_metric,
                load_best_on_end=True,
                minimize_metric=False,
                check=cfg.check,
                fp16=dict(amp=cfg.amp),
            )

            # Set new initial LR for optimizer after restart
            scheduler_warm_restart.step()
            print(
                f"New LR for warm restart {scheduler_warm_restart.get_last_lr()[0]}"
            )

            # Find optimal threshold for dice score
            model.eval()
            best_th, dices = find_dice_threshold(model, data_loaders["valid"])
            print("Best dice threshold", best_th, np.max(dices[1]))
            np.save(f"dices_{size}.npy", dices)
    else:
        print("Validation only")
        # Datasets
        size = cfg.data.sizes[-1]
        # The loader returns (train_ds, valid_ds, train_images, val_images),
        # as unpacked in the training branch; only the datasets are needed here.
        train_ds, valid_ds, _, _ = D.get_train_valid_datasets_from_path(
            # path=(cwd / cfg.data.path),
            path=(cwd / f"data/hubmap-{size}x{size}/"),
            train_ids=cfg.data.train_ids,
            valid_ids=cfg.data.valid_ids,
            seed=cfg.seed,
            valid_split=cfg.data.valid_split,
            mean=cfg.data.mean,
            std=cfg.data.std,
            transforms=transforms,
            stats=stats,
        )

        train_bs = int(cfg.loader.train_bs / (cfg.data.scale_factor**2))
        valid_bs = int(cfg.loader.valid_bs / (cfg.data.scale_factor**2))
        print(
            f"train: {len(train_ds)}; bs {train_bs}",
            f"valid: {len(valid_ds)}, bs {valid_bs}",
        )

        # Data loaders
        data_loaders = D.get_data_loaders(
            train_ds=train_ds,
            valid_ds=valid_ds,
            train_bs=train_bs,
            valid_bs=valid_bs,
            num_workers=cfg.loader.num_workers,
        )

        # Find optimal threshold for dice score
        model.eval()
        best_th, dices = find_dice_threshold(model, data_loaders["valid"])
        print("Best dice threshold", best_th, np.max(dices[1]))
        np.save("dices_val.npy", dices)

    # Load config for updating with threshold and metric
    # (otherwise loading do not work)
    cfg = OmegaConf.load(".hydra/config.yaml")
    cfg.threshold = float(best_th)

    # Evaluate on full-size image if valid_ids is non-empty
    df_train = pd.read_csv(cwd / "data/train.csv")
    # "records" is the valid orient name; the abbreviated "record" is
    # deprecated and rejected by recent pandas versions.
    df_train = {
        r["id"]: r["encoding"]
        for r in df_train.to_dict(orient="records")
    }
    dices = []
    # Image ids are the part of each tile file name before the first "_".
    unique_ids = sorted(
        set(
            p.name.split("_")[0]
            for p in (cwd / cfg.data.path / "train").iterdir()))
    size = cfg.data.sizes[-1]
    scale = size / 1024
    for image_id in cfg.data.valid_ids:
        image_name = unique_ids[image_id]
        print(f"\nValidate for {image_name}")

        rle_pred, shape = inference_one(
            image_path=(cwd / f"data/train/{image_name}.tiff"),
            target_path=Path("."),
            cfg=cfg,
            model=model,
            scale_factor=scale,
            tile_size=cfg.data.tile_size,
            tile_step=cfg.data.tile_step,
            threshold=best_th,
            save_raw=False,
            tta_mode=None,
            weight="pyramid",
            device=device,
            filter_crops="tissue",
            stats=stats,
        )

        print("Predict", shape)
        pred = rle_decode(rle_pred["predicted"], shape)
        mask = rle_decode(df_train[image_name], shape)
        assert pred.shape == mask.shape, f"pred {pred.shape}, mask {mask.shape}"
        assert pred.shape == shape, f"pred {pred.shape}, expected {shape}"

        dices.append(
            dice(
                torch.from_numpy(pred).type(torch.uint8),
                torch.from_numpy(mask).type(torch.uint8),
                threshold=None,
                activation="none",
            ))
    print("Full image dice:", np.mean(dices))
    OmegaConf.save(cfg, ".hydra/config.yaml")
Exemple #12
0
def train_segmentation_model(
        model: torch.nn.Module,
        logdir: str,
        num_epochs: int,
        loaders: Dict[str, DataLoader]
):
    """Train a segmentation model with a weighted Dice + IoU + BCE loss.

    The total loss is the weighted sum ``1.0 * dice + 1.0 * iou + 0.8 * bce``.
    Parameters whose names match ``encoder*`` get a lower learning rate and
    weight decay than the rest of the network. Each batch is expected to carry
    the input under ``"image"`` and the target under ``"mask"``.

    Training selects checkpoints by validation IoU (``main_metric="iou"``,
    maximized). With ``load_best_on_end=True`` the runner is expected to
    restore that best checkpoint into ``model``. The model is then saved to
    ``<logdir>/save/best_model.pth``. Tracing the model on one validation
    batch is attempted afterwards. A tracing failure is reported but not fatal.

    Args:
        model: network to train.
        logdir: directory for training logs, checkpoints and the saved model.
        num_epochs: number of training epochs.
        loaders: mapping of loader name to DataLoader; must contain ``"valid"``.
    """
    criterion = {
        "dice": DiceLoss(),
        "iou": IoULoss(),
        "bce": nn.BCEWithLogitsLoss()
    }

    learning_rate = 0.001
    encoder_learning_rate = 0.0005

    # Encoder layers (typically pretrained) are tuned more gently than the decoder.
    layerwise_params = {"encoder*": dict(lr=encoder_learning_rate, weight_decay=0.00003)}
    model_params = utils.process_model_params(model, layerwise_params=layerwise_params)
    base_optimizer = RAdam(model_params, lr=learning_rate, weight_decay=0.0003)
    optimizer = Lookahead(base_optimizer)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.25, patience=2)

    device = utils.get_device()
    runner = SupervisedRunner(device=device, input_key='image', input_target_key='mask')

    callbacks = [
        CriterionCallback(
            input_key="mask",
            prefix="loss_dice",
            criterion_key="dice"
        ),
        CriterionCallback(
            input_key="mask",
            prefix="loss_iou",
            criterion_key="iou"
        ),
        CriterionCallback(
            input_key="mask",
            prefix="loss_bce",
            criterion_key="bce"
        ),

        # Combine the individual losses into the optimized "loss" value.
        MetricAggregationCallback(
            prefix="loss",
            mode="weighted_sum",
            metrics={"loss_dice": 1.0, "loss_iou": 1.0, "loss_bce": 0.8},
        ),

        # metrics
        DiceCallback(input_key='mask'),
        IouCallback(input_key='mask'),
    ]

    runner.train(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        loaders=loaders,
        callbacks=callbacks,
        logdir=logdir,
        num_epochs=num_epochs,
        main_metric="iou",
        minimize_metric=False,
        verbose=True,
        load_best_on_end=True,
    )
    best_model_save_dir = os.path.join(logdir, 'save')
    # The second positional argument of os.makedirs is `mode`, not `exist_ok`.
    # Pass exist_ok by keyword so the directory gets default permissions and
    # re-runs do not fail.
    os.makedirs(best_model_save_dir, exist_ok=True)
    # Save the best model (selected by validation IoU).
    torch.save(model, os.path.join(best_model_save_dir, 'best_model.pth'))
    batch = next(iter(loaders["valid"]))
    try:
        # Optimized version; not all models can be traced.
        runner.trace(model=model, batch=batch, logdir=logdir, fp16=False)
    except Exception as err:
        # Tracing is best-effort: report the failure instead of hiding it.
        print(f"Model tracing failed, skipping: {err}")
Exemple #13
0
         "logit_grapheme",
     ),
     input_key=(
         "grapheme_root",
         "vowel_diacritic",
         "consonant_diacritic",
         "grapheme",
     ),
     mixuponly=False,
     alpha=0.5,
     resolution=IMG_SIZE,
 ),
 CriterionCallback(
     input_key="grapheme_root",
     output_key="logit_grapheme_root",
     prefix="grapheme_root_loss",
     criterion_key="grapheme_root_loss",
     multiplier=2.0,
 ),
 CriterionCallback(
     input_key="vowel_diacritic",
     output_key="logit_vowel_diacritic",
     prefix="vowel_diacritic_loss",
     criterion_key="vowel_diacritic_loss",
     multiplier=1.0,
 ),
 CriterionCallback(
     input_key="consonant_diacritic",
     output_key="logit_consonant_diacritic",
     prefix="consonant_diacritic_loss",
     criterion_key="consonant_diacritic_loss",
Exemple #14
0
def get_cls_callbacks(loss_name,
                      num_classes,
                      num_epochs,
                      class_names,
                      tsa=None,
                      uda=None,
                      show=False):
    """Build the loss and metric callbacks for the classification head.

    Args:
        loss_name: sequence ``(name,)`` or ``(name, weight)``. ``weight`` is
            converted with ``float`` and used as the loss multiplier; it
            defaults to 1.0 when omitted.
        num_classes: number of classes, forwarded to ``TSACriterionCallback``.
        num_epochs: number of training epochs, forwarded to
            ``TSACriterionCallback``.
        class_names: class labels for kappa, confusion-matrix and
            visualization callbacks.
        tsa: if truthy, use ``TSACriterionCallback`` instead of a plain
            ``CriterionCallback``.
        uda: if truthy, add ``UDACriterionCallback``; otherwise add the
            confusion-matrix and negative-mining callbacks.
        show: if truthy, add a batch-visualization callback.

    Returns:
        ``(callbacks, criterions)``, where ``criterions`` maps ``'cls'`` to the
        loss returned by ``get_loss`` with ``ignore_index=UNLABELED_CLASS``.

    Raises:
        ValueError: if ``loss_name`` has neither 1 nor 2 elements.
    """
    if len(loss_name) == 1:
        loss_name, multiplier = loss_name[0], 1.0
    elif len(loss_name) == 2:
        loss_name, multiplier = loss_name[0], float(loss_name[1])
    else:
        raise ValueError(
            f"Expected loss spec (name,) or (name, weight), got {loss_name!r}")

    criterions = {'cls': get_loss(loss_name, ignore_index=UNLABELED_CLASS)}
    output_key = 'logits'

    if tsa:
        crit_callback = TSACriterionCallback(prefix='cls/tsa_loss',
                                             loss_key='cls',
                                             output_key=output_key,
                                             criterion_key='cls',
                                             multiplier=multiplier,
                                             num_classes=num_classes,
                                             num_epochs=num_epochs)
    else:
        crit_callback = CriterionCallback(prefix='cls/loss',
                                          loss_key='cls',
                                          output_key=output_key,
                                          criterion_key='cls',
                                          multiplier=multiplier)

    callbacks = [
        crit_callback,
        CappaScoreCallback(prefix='cls/kappa',
                           output_key=output_key,
                           ignore_index=UNLABELED_CLASS,
                           class_names=class_names),
        # Metrics
        CustomAccuracyCallback(output_key=output_key,
                               prefix='cls/accuracy',
                               ignore_index=UNLABELED_CLASS),
        # F1 scores (beta=1)
        FScoreCallback(prefix='cls/f1_macro',
                       beta=1,
                       average='macro',
                       output_key=output_key,
                       ignore_index=UNLABELED_CLASS),
        # Previously beta=2, which reported an F2 score under the f1 name.
        FScoreCallback(prefix='cls/f1_micro',
                       beta=1,
                       average='micro',
                       output_key=output_key,
                       ignore_index=UNLABELED_CLASS),
        # F2 scores (beta=2)
        FScoreCallback(prefix='cls/f2_macro',
                       beta=2,
                       average='macro',
                       output_key=output_key,
                       ignore_index=UNLABELED_CLASS),
        FScoreCallback(prefix='cls/f2_micro',
                       beta=2,
                       average='micro',
                       output_key=output_key,
                       ignore_index=UNLABELED_CLASS)
    ]

    if uda:
        callbacks += [
            UDACriterionCallback(prefix='cls/uda',
                                 output_key=output_key,
                                 unsupervised_label=UNLABELED_CLASS)
        ]
    else:
        callbacks += [
            ConfusionMatrixCallback(prefix='cls/confusion',
                                    output_key=output_key,
                                    class_names=class_names),
            NegativeMiningCallback(ignore_index=UNLABELED_CLASS),
        ]

    if show:
        visualization_fn = partial(draw_classification_predictions,
                                   class_names=class_names)
        callbacks += [
            ShowPolarBatchesCallback(visualization_fn,
                                     metric='cls/accuracy',
                                     minimize=False)
        ]
    return callbacks, criterions
Exemple #15
0
def get_reg_callbacks(loss_name,
                      class_names,
                      prefix='reg',
                      output_key='regression',
                      uda=None,
                      show=False):
    """Build the loss and metric callbacks for a regression head.

    The regression output is scored both as a regression (RMSE) and as a
    classification: the callbacks are created with ``from_regression=True``.

    Args:
        loss_name: sequence ``(name,)`` or ``(name, weight)``. ``weight`` is
            converted with ``float`` and used as the loss multiplier; it
            defaults to 1.0 when omitted.
        class_names: class labels for kappa, confusion-matrix and
            visualization callbacks.
        prefix: prefix for every metric name; also used as the criterion key.
        output_key: key of the model output read by all callbacks.
        uda: if truthy, add ``UDARegressionCriterionCallback``.
        show: if truthy, add a batch-visualization callback.

    Returns:
        ``(callbacks, criterions)``, where ``criterions`` maps ``prefix`` to the
        loss returned by ``get_loss`` with ``ignore_index=UNLABELED_CLASS``.

    Raises:
        ValueError: if ``loss_name`` has neither 1 nor 2 elements.
    """
    if len(loss_name) == 1:
        loss_name, multiplier = loss_name[0], 1.0
    elif len(loss_name) == 2:
        loss_name, multiplier = loss_name[0], float(loss_name[1])
    else:
        raise ValueError(
            f"Expected loss spec (name,) or (name, weight), got {loss_name!r}")

    criterions = {prefix: get_loss(loss_name, ignore_index=UNLABELED_CLASS)}
    callbacks = [
        # Loss
        CriterionCallback(prefix=f'{prefix}/loss',
                          loss_key=prefix,
                          output_key=output_key,
                          criterion_key=prefix,
                          multiplier=multiplier),
        # Metrics
        RMSEMetric(prefix=f'{prefix}/rmse', output_key=output_key),
        CappaScoreCallback(prefix=f'{prefix}/kappa',
                           output_key=output_key,
                           ignore_index=UNLABELED_CLASS,
                           class_names=class_names,
                           optimize_thresholds=False,
                           from_regression=True),
        CustomAccuracyCallback(prefix=f'{prefix}/accuracy',
                               output_key=output_key,
                               from_regression=True,
                               ignore_index=UNLABELED_CLASS),
        ConfusionMatrixCallbackFromRegression(prefix=f'{prefix}/confusion',
                                              output_key=output_key,
                                              class_names=class_names,
                                              ignore_index=UNLABELED_CLASS),
        # F1 scores (beta=1)
        FScoreCallback(prefix=f'{prefix}/f1_macro',
                       beta=1,
                       average='macro',
                       output_key=output_key,
                       from_regression=True,
                       ignore_index=UNLABELED_CLASS),
        # Previously beta=2, which reported an F2 score under the f1 name.
        FScoreCallback(prefix=f'{prefix}/f1_micro',
                       beta=1,
                       average='micro',
                       output_key=output_key,
                       from_regression=True,
                       ignore_index=UNLABELED_CLASS),
        # F2 scores (beta=2)
        FScoreCallback(prefix=f'{prefix}/f2_macro',
                       beta=2,
                       average='macro',
                       output_key=output_key,
                       from_regression=True,
                       ignore_index=UNLABELED_CLASS),
        FScoreCallback(prefix=f'{prefix}/f2_micro',
                       beta=2,
                       average='micro',
                       output_key=output_key,
                       from_regression=True,
                       ignore_index=UNLABELED_CLASS)
    ]

    if uda:
        callbacks += [
            UDARegressionCriterionCallback(prefix=f'{prefix}/uda',
                                           output_key=output_key,
                                           unsupervised_label=UNLABELED_CLASS)
        ]

    if show:
        visualization_fn = partial(draw_regression_predictions,
                                   outputs_key=output_key,
                                   class_names=class_names,
                                   unsupervised_label=UNLABELED_CLASS)
        callbacks += [
            ShowPolarBatchesCallback(visualization_fn,
                                     metric=f'{prefix}/accuracy',
                                     minimize=False)
        ]

    return callbacks, criterions