예제 #1
0
    def get_optimizer(self, stage: str, model: nn.Module) -> _Optimizer:
        """Build the optimizer for ``stage`` from that stage's configuration.

        Reads ``optimizer_params`` from ``self.stages_config[stage]``, removes
        the keys that are consumed here (``layerwise_params``,
        ``no_bias_weight_decay`` and ``lr_linear_scaling``) and forwards the
        remaining entries to ``self._get_optimizer`` together with the
        parameters returned by ``process_model_params``.

        Args:
            stage: key into ``self.stages_config``.
            model: model passed to ``process_model_params``.

        Returns:
            The object returned by ``self._get_optimizer``.
        """
        # Work on a shallow copy: the ``pop`` calls and the ``lr`` write below
        # must not leak into the stored stage config, otherwise a second call
        # for the same stage would silently lose these settings.
        optimizer_params = dict(
            self.stages_config[stage].get("optimizer_params", {}))

        layerwise_params = \
            optimizer_params.pop("layerwise_params", OrderedDict())
        no_bias_weight_decay = \
            optimizer_params.pop("no_bias_weight_decay", True)

        # linear scaling rule from https://arxiv.org/pdf/1706.02677.pdf
        lr_scaling_params = optimizer_params.pop("lr_linear_scaling", None)
        if lr_scaling_params:
            data_params = dict(self.stages_config[stage]["data_params"])
            batch_size = data_params.get("batch_size")
            per_gpu_scaling = data_params.get("per_gpu_scaling", False)
            # A rank of -1 (the default) is treated as "not distributed".
            distributed_rank = self.distributed_params.get("rank", -1)
            distributed = distributed_rank > -1
            if per_gpu_scaling and not distributed:
                # Scale the batch size by the visible GPU count (at least 1).
                num_gpus = max(1, torch.cuda.device_count())
                batch_size *= num_gpus

            base_lr = lr_scaling_params.get("lr")
            base_batch_size = lr_scaling_params.get("base_batch_size", 256)
            lr_scaling = batch_size / base_batch_size
            optimizer_params["lr"] = base_lr * lr_scaling  # scale default lr
        else:
            lr_scaling = 1.0

        model_params = process_model_params(model, layerwise_params,
                                            no_bias_weight_decay, lr_scaling)

        optimizer = self._get_optimizer(model_params=model_params,
                                        **optimizer_params)

        return optimizer
예제 #2
0
def create_optimizer(args, model, filter_bias_and_bn=True):
    """Create an optimizer for ``model`` from the settings in ``args``.

    ``args.opt`` selects the optimizer by its last ``_``-separated token
    (``sgd``, ``nesterov``, ``momentum``, ``adam``, ``adamw`` or ``radam``).
    A ``lookahead_`` prefix wraps the result in ``Lookahead``.

    Parameters matching ``"encoder*"`` get a learning rate of
    ``args.lr * args.ev``; everything else uses ``args.lr``.

    Args:
        args: namespace providing ``opt``, ``lr``, ``ev``, ``weight_decay``,
            ``momentum`` and optionally ``opt_eps`` / ``opt_betas``.
        model: model passed to ``process_model_params``.
        filter_bias_and_bn: accepted for interface compatibility; not used
            by this function.

    Returns:
        The constructed optimizer, wrapped in ``Lookahead`` if requested.

    Raises:
        ValueError: if the optimizer name is not recognised.
    """
    opt_lower = args.opt.lower()
    weight_decay = args.weight_decay

    layerwise_params = {
        "encoder*": dict(lr=args.lr * args.ev, weight_decay=args.weight_decay)
    }
    parameters = process_model_params(model, layerwise_params=layerwise_params)

    opt_args = dict(lr=args.lr, weight_decay=weight_decay)
    if hasattr(args, 'opt_eps') and args.opt_eps is not None:
        opt_args['eps'] = args.opt_eps
    if hasattr(args, 'opt_betas') and args.opt_betas is not None:
        opt_args['betas'] = args.opt_betas

    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]
    if opt_lower in ('sgd', 'nesterov', 'momentum'):
        # SGD accepts neither ``eps`` nor ``betas``; drop both so that
        # Adam-style settings in ``args`` do not break the SGD variants.
        opt_args.pop('eps', None)
        opt_args.pop('betas', None)
        # As before, both 'sgd' and 'nesterov' enable Nesterov momentum;
        # only 'momentum' selects plain momentum.
        optimizer = SGD(parameters,
                        momentum=args.momentum,
                        nesterov=(opt_lower != 'momentum'),
                        **opt_args)
    elif opt_lower == 'adam':
        optimizer = Adam(parameters, **opt_args)
    elif opt_lower == 'adamw':
        optimizer = AdamW(parameters, **opt_args)
    elif opt_lower == 'radam':
        optimizer = RAdam(parameters, **opt_args)
    else:
        # Raise explicitly instead of asserting: asserts are stripped under
        # ``-O`` and the previous form never surfaced its message.
        raise ValueError(f"Invalid optimizer: {args.opt!r}")

    if len(opt_split) > 1:
        if opt_split[0] == 'lookahead':
            _logger.info('Using lookahead')
            optimizer = Lookahead(optimizer)

    return optimizer
예제 #3
0
def smart_way():
    """Train a ``RegressionFromSegmentation`` model and save the result.

    Steps, in order:
      1. Parse CLI arguments (``seed``, ``dataset``) and build the loaders.
      2. Load a previously saved segmentation model from disk and wrap it.
      3. Train with RAdam + Lookahead and a ``ReduceLROnPlateau`` scheduler.
      4. Save the model under ``<logdir>/save/best_model.pth``.
      5. Try to trace the model; a tracing failure is reported as a warning
         and does not abort the run.
    """
    import warnings

    args = parse_arguments()
    SEED = args.seed
    ROOT = Path(args.dataset)

    img_paths, targets = retrieve_dataset(ROOT)

    train_transforms = compose(
        [resize_transforms(),
         hard_transforms(),
         post_transforms()])
    valid_transforms = compose([pre_transforms(), post_transforms()])
    loaders = get_loaders(
        img_paths=img_paths,
        targets=targets,
        random_state=SEED,
        batch_size=8,
        train_transforms_fn=train_transforms,
        valid_transforms_fn=valid_transforms,
    )

    logdir = './table_recognition/nn/regression/logs6/'

    # NOTE(review): torch.load unpickles the whole model object; only load
    # checkpoints from a trusted source.
    model = torch.load(
        './table_recognition/nn/segmentation/logs/resnet18_PSPNet/save/best_model.pth'
    )
    model: RegressionFromSegmentation = RegressionFromSegmentation(model)

    # Resolve the device once and reuse it for both the model and the runner.
    device = utils.get_device()
    model.to(device)

    learning_rate = 0.001
    encoder_learning_rate = 0.0005

    # Separate (lower) learning rate and weight decay for "encoder*" params.
    layerwise_params = {
        "encoder*": dict(lr=encoder_learning_rate, weight_decay=0.00003)
    }
    model_params = utils.process_model_params(
        model, layerwise_params=layerwise_params)
    base_optimizer = RAdam(model_params, lr=learning_rate, weight_decay=0.0003)
    optimizer = Lookahead(base_optimizer)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     factor=0.25,
                                                     patience=2)

    runner = CustomRunner2(device=device)
    runner.train(model=model,
                 optimizer=optimizer,
                 scheduler=scheduler,
                 loaders=loaders,
                 logdir=logdir,
                 num_epochs=1000,
                 verbose=True,
                 load_best_on_end=True,
                 main_metric='loss')

    best_model_save_dir = os.path.join(logdir, 'save')
    os.makedirs(best_model_save_dir, exist_ok=True)
    torch.save(model, os.path.join(
        best_model_save_dir,
        'best_model.pth'))  # save best model (by valid loss)
    batch = next(iter(loaders["valid"]))
    try:
        runner.trace(
            model=model, batch=batch, logdir=logdir,
            fp16=False)  # optimized version (not all models can be traced)
    except Exception as exc:
        # Tracing is best-effort, but report the failure instead of hiding it.
        warnings.warn(f"Model tracing failed and was skipped: {exc!r}")
예제 #4
0
# Loss functions, keyed by the names used to refer to them later.
criterion = dict(
    dice=DiceLoss(),
    iou=IoULoss(),
    bce=nn.BCEWithLogitsLoss(),
)

learning_rate = 0.001
encoder_learning_rate = 0.0005

# The encoder is pre-trained, so it gets a smaller learning rate
# (and its own weight decay) than the rest of the model.
layerwise_params = {
    "encoder*": {"lr": encoder_learning_rate, "weight_decay": 0.00003},
}

# Per the original note: strips weight decay from biases and applies the
# layer-wise overrides defined above.
model_params = utils.process_model_params(
    model,
    layerwise_params=layerwise_params,
)

# RAdam wrapped in Lookahead.
base_optimizer = RAdam(model_params, lr=learning_rate, weight_decay=0.0003)
optimizer = Lookahead(base_optimizer)

scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    factor=0.25,
    patience=2,
)

num_epochs = 3
logdir = "./logs/segmentation"

device = utils.get_device()
print(f"device: {device}")
예제 #5
0
def main(cfg: DictConfig):
    """Run the multi-stage training / validation pipeline described by ``cfg``.

    Outline:
      1. If ``cfg`` has a ``resume`` entry, take the model section (and, when
         ``cfg.train.num_epochs == 0``, the data scale factor) from the
         resumed run's saved config.
      2. Build augmentations, the segmentation model, the optimizer and the
         criteria; optionally restore a checkpoint.
      3. If ``cfg.train.num_epochs`` is not ``None``: train one stage per
         ``(size, num_epochs)`` pair and search the dice threshold after each
         stage. Otherwise: only search the dice threshold on the validation
         loader.
      4. Run full-image inference for every id in ``cfg.data.valid_ids``,
         print the mean dice and store the threshold in
         ``.hydra/config.yaml``.

    Args:
        cfg: run configuration; sections read here include ``model``,
            ``data``, ``optim``, ``scheduler``, ``loader``, ``loss`` and
            ``train``.

    Raises:
        ValueError: if resuming is requested but the checkpoint is missing.
    """
    cwd = Path(get_original_cwd())

    # overwrite config if continue training from checkpoint
    resume_cfg = None
    if "resume" in cfg:
        cfg_path = cwd / cfg.resume / ".hydra/config.yaml"
        print(f"Continue from: {cfg.resume}")
        # Overwrite everything except device
        # TODO config merger (perhaps continue training with the same optimizer but other lrs?)
        resume_cfg = OmegaConf.load(cfg_path)
        cfg.model = resume_cfg.model
        if cfg.train.num_epochs == 0:
            cfg.data.scale_factor = resume_cfg.data.scale_factor
        OmegaConf.save(cfg, ".hydra/config.yaml")

    print(OmegaConf.to_yaml(cfg))

    device = set_device_id(cfg.device)
    set_seed(cfg.seed, device=device)

    # Augmentations
    if cfg.data.aug == "auto":
        transforms = albu.load(cwd / "autoalbument/autoconfig.json")
    else:
        transforms = D.get_training_augmentations()

    if OmegaConf.is_missing(cfg.model, "convert_bottleneck"):
        cfg.model.convert_bottleneck = (0, 0, 0)

    # Model
    print(f"Setup model {cfg.model.arch} {cfg.model.encoder_name} "
          f"convert_bn={cfg.model.convert_bn} "
          f"convert_bottleneck={cfg.model.convert_bottleneck} ")
    model = get_segmentation_model(
        arch=cfg.model.arch,
        encoder_name=cfg.model.encoder_name,
        encoder_weights=cfg.model.encoder_weights,
        classes=1,
        convert_bn=cfg.model.convert_bn,
        convert_bottleneck=cfg.model.convert_bottleneck,
        # decoder_attention_type="scse",  # TODO to config
    )
    model = model.to(device)
    model.train()
    print(model)

    # Optimization
    # Reduce LR for pretrained encoder
    layerwise_params = {
        "encoder*":
        dict(lr=cfg.optim.lr_encoder, weight_decay=cfg.optim.wd_encoder)
    }
    model_params = cutils.process_model_params(
        model, layerwise_params=layerwise_params)

    # Select optimizer
    optimizer = get_optimizer(
        name=cfg.optim.name,
        model_params=model_params,
        lr=cfg.optim.lr,
        wd=cfg.optim.wd,
        lookahead=cfg.optim.lookahead,
    )

    criterion = {
        "dice": DiceLoss(),
        # "dice": SoftDiceLoss(mode="binary", smooth=1e-7),
        "iou": IoULoss(),
        "bce": nn.BCEWithLogitsLoss(),
        "lovasz": LovaszLossBinary(),
        "focal_tversky": FocalTverskyLoss(eps=1e-7, alpha=0.7, gamma=0.75),
    }

    # Load states if resuming training
    if "resume" in cfg:
        checkpoint_path = (cwd / cfg.resume / cfg.train.logdir /
                           "checkpoints/best_full.pth")
        if checkpoint_path.exists():
            print(f"\nLoading checkpoint {str(checkpoint_path)}")
            checkpoint = cutils.load_checkpoint(checkpoint_path)
            # The optimizer state is restored only when the resumed run used
            # the same optimizer name as the current config.
            cutils.unpack_checkpoint(
                checkpoint=checkpoint,
                model=model,
                optimizer=optimizer
                if resume_cfg.optim.name == cfg.optim.name else None,
                criterion=criterion,
            )
        else:
            raise ValueError("Nothing to resume, checkpoint missing")

    # We could only want to validate resume, in this case skip training routine
    best_th = 0.5

    stats = None
    if cfg.data.stats:
        print(f"Use statistics from file: {cfg.data.stats}")
        stats = cwd / cfg.data.stats

    if cfg.train.num_epochs is not None:
        callbacks = [
            # Each criterion is calculated separately.
            CriterionCallback(input_key="mask",
                              prefix="loss_dice",
                              criterion_key="dice"),
            CriterionCallback(input_key="mask",
                              prefix="loss_iou",
                              criterion_key="iou"),
            CriterionCallback(input_key="mask",
                              prefix="loss_bce",
                              criterion_key="bce"),
            CriterionCallback(input_key="mask",
                              prefix="loss_lovasz",
                              criterion_key="lovasz"),
            CriterionCallback(
                input_key="mask",
                prefix="loss_focal_tversky",
                criterion_key="focal_tversky",
            ),
            # And only then we aggregate everything into one loss.
            MetricAggregationCallback(
                prefix="loss",
                mode="weighted_sum",  # can be "sum", "weighted_sum" or "mean"
                # because we want weighted sum, we need to add scale for each loss
                metrics={
                    "loss_dice": cfg.loss.dice,
                    "loss_iou": cfg.loss.iou,
                    "loss_bce": cfg.loss.bce,
                    "loss_lovasz": cfg.loss.lovasz,
                    "loss_focal_tversky": cfg.loss.focal_tversky,
                },
            ),
            # metrics
            DiceCallback(input_key="mask"),
            IouCallback(input_key="mask"),
            # gradient accumulation
            OptimizerCallback(accumulation_steps=cfg.optim.accumulate),
            # early stopping
            SchedulerCallback(reduced_metric="loss_dice",
                              mode=cfg.scheduler.mode),
            EarlyStoppingCallback(**cfg.scheduler.early_stopping,
                                  minimize=False),
            # TODO WandbLogger works poorly with multistage right now
            WandbLogger(project=cfg.project, config=dict(cfg)),
            # CheckpointCallback(save_n_best=cfg.checkpoint.save_n_best),
        ]

        # Training
        runner = SupervisedRunner(device=device,
                                  input_key="image",
                                  input_target_key="mask")

        # TODO Scheduler does not work now, every stage restarts from base lr
        scheduler_warm_restart = optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[1, 2],
            gamma=10,
        )

        # One training stage per (image size, epoch count) pair.
        for i, (size, num_epochs) in enumerate(
                zip(cfg.data.sizes, cfg.train.num_epochs)):
            scale = size / 1024
            print(
                f"Training stage {i}, scale {scale}, size {size}, epochs {num_epochs}"
            )

            # Datasets
            (
                train_ds,
                valid_ds,
                train_images,
                val_images,
            ) = D.get_train_valid_datasets_from_path(
                # path=(cwd / cfg.data.path),
                path=(cwd / f"data/hubmap-{size}x{size}/"),
                train_ids=cfg.data.train_ids,
                valid_ids=cfg.data.valid_ids,
                seed=cfg.seed,
                valid_split=cfg.data.valid_split,
                mean=cfg.data.mean,
                std=cfg.data.std,
                transforms=transforms,
                stats=stats,
            )

            # Batch size shrinks with the square of the scale.
            train_bs = int(cfg.loader.train_bs / (scale**2))
            valid_bs = int(cfg.loader.valid_bs / (scale**2))
            print(
                f"train: {len(train_ds)}; bs {train_bs}",
                f"valid: {len(valid_ds)}, bs {valid_bs}",
            )

            # Data loaders
            data_loaders = D.get_data_loaders(
                train_ds=train_ds,
                valid_ds=valid_ds,
                train_bs=train_bs,
                valid_bs=valid_bs,
                num_workers=cfg.loader.num_workers,
            )

            # Select scheduler
            scheduler = get_scheduler(
                name=cfg.scheduler.type,
                optimizer=optimizer,
                num_epochs=num_epochs * (len(data_loaders["train"]) if
                                         cfg.scheduler.mode == "batch" else 1),
                eta_min=scheduler_warm_restart.get_last_lr()[0] /
                cfg.scheduler.eta_min_factor,
                plateau=cfg.scheduler.plateau,
            )

            runner.train(
                model=model,
                criterion=criterion,
                optimizer=optimizer,
                scheduler=scheduler,
                callbacks=callbacks,
                logdir=cfg.train.logdir,
                loaders=data_loaders,
                num_epochs=num_epochs,
                verbose=True,
                main_metric=cfg.train.main_metric,
                load_best_on_end=True,
                minimize_metric=False,
                check=cfg.check,
                fp16=dict(amp=cfg.amp),
            )

            # Set new initial LR for optimizer after restart
            scheduler_warm_restart.step()
            print(
                f"New LR for warm restart {scheduler_warm_restart.get_last_lr()[0]}"
            )

            # Find optimal threshold for dice score
            model.eval()
            best_th, dices = find_dice_threshold(model, data_loaders["valid"])
            print("Best dice threshold", best_th, np.max(dices[1]))
            np.save(f"dices_{size}.npy", dices)
    else:
        print("Validation only")
        # Datasets
        size = cfg.data.sizes[-1]
        # The training branch unpacks four values from this same call, so
        # take the first two and ignore any extras instead of assuming a
        # two-element return.
        train_ds, valid_ds, *_ = D.get_train_valid_datasets_from_path(
            # path=(cwd / cfg.data.path),
            path=(cwd / f"data/hubmap-{size}x{size}/"),
            train_ids=cfg.data.train_ids,
            valid_ids=cfg.data.valid_ids,
            seed=cfg.seed,
            valid_split=cfg.data.valid_split,
            mean=cfg.data.mean,
            std=cfg.data.std,
            transforms=transforms,
            stats=stats,
        )

        train_bs = int(cfg.loader.train_bs / (cfg.data.scale_factor**2))
        valid_bs = int(cfg.loader.valid_bs / (cfg.data.scale_factor**2))
        print(
            f"train: {len(train_ds)}; bs {train_bs}",
            f"valid: {len(valid_ds)}, bs {valid_bs}",
        )

        # Data loaders
        data_loaders = D.get_data_loaders(
            train_ds=train_ds,
            valid_ds=valid_ds,
            train_bs=train_bs,
            valid_bs=valid_bs,
            num_workers=cfg.loader.num_workers,
        )

        # Find optimal threshold for dice score
        model.eval()
        best_th, dices = find_dice_threshold(model, data_loaders["valid"])
        print("Best dice threshold", best_th, np.max(dices[1]))
        np.save("dices_val.npy", dices)

    #
    # # Load best checkpoint
    # checkpoint_path = Path(cfg.train.logdir) / "checkpoints/best.pth"
    # if checkpoint_path.exists():
    #     print(f"\nLoading checkpoint {str(checkpoint_path)}")
    #     state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))[
    #         "model_state_dict"
    #     ]
    #     model.load_state_dict(state_dict)
    #     del state_dict
    # model = model.to(device)
    # Load config for updating with threshold and metric
    # (otherwise loading do not work)
    cfg = OmegaConf.load(".hydra/config.yaml")
    cfg.threshold = float(best_th)

    # Evaluate on full-size image if valid_ids is non-empty
    df_train = pd.read_csv(cwd / "data/train.csv")
    # Map image id -> RLE encoding. "records" is the full orient name; the
    # abbreviated "record" is rejected by recent pandas releases.
    df_train = {
        r["id"]: r["encoding"]
        for r in df_train.to_dict(orient="records")
    }
    dices = []
    # Image ids are the part of each file name before the first underscore.
    unique_ids = sorted(
        set(
            str(p).split("/")[-1].split("_")[0]
            for p in (cwd / cfg.data.path / "train").iterdir()))
    size = cfg.data.sizes[-1]
    scale = size / 1024
    for image_id in cfg.data.valid_ids:
        image_name = unique_ids[image_id]
        print(f"\nValidate for {image_name}")

        rle_pred, shape = inference_one(
            image_path=(cwd / f"data/train/{image_name}.tiff"),
            target_path=Path("."),
            cfg=cfg,
            model=model,
            scale_factor=scale,
            tile_size=cfg.data.tile_size,
            tile_step=cfg.data.tile_step,
            threshold=best_th,
            save_raw=False,
            tta_mode=None,
            weight="pyramid",
            device=device,
            filter_crops="tissue",
            stats=stats,
        )

        print("Predict", shape)
        pred = rle_decode(rle_pred["predicted"], shape)
        mask = rle_decode(df_train[image_name], shape)
        assert pred.shape == mask.shape, f"pred {pred.shape}, mask {mask.shape}"
        assert pred.shape == shape, f"pred {pred.shape}, expected {shape}"

        dices.append(
            dice(
                torch.from_numpy(pred).type(torch.uint8),
                torch.from_numpy(mask).type(torch.uint8),
                threshold=None,
                activation="none",
            ))
    print("Full image dice:", np.mean(dices))
    OmegaConf.save(cfg, ".hydra/config.yaml")
예제 #6
0
def train_segmentation_model(
        model: torch.nn.Module,
        logdir: str,
        num_epochs: int,
        loaders: Dict[str, DataLoader]
):
    """Train ``model`` for segmentation and save it under ``logdir``.

    The loss is a weighted sum of dice, IoU and BCE-with-logits terms. The
    optimizer is RAdam wrapped in Lookahead, with a lower learning rate for
    parameters matching ``"encoder*"``. After training the model is written
    to ``<logdir>/save/best_model.pth`` and a best-effort trace is attempted.

    Args:
        model: model to train.
        logdir: directory passed to the runner and used for the saved model.
        num_epochs: number of epochs passed to the runner.
        loaders: data loaders; a ``"valid"`` entry is required because one
            batch from it is used for tracing.
    """
    import warnings

    criterion = {
        "dice": DiceLoss(),
        "iou": IoULoss(),
        "bce": nn.BCEWithLogitsLoss()
    }

    learning_rate = 0.001
    encoder_learning_rate = 0.0005

    layerwise_params = {"encoder*": dict(lr=encoder_learning_rate, weight_decay=0.00003)}
    model_params = utils.process_model_params(model, layerwise_params=layerwise_params)
    base_optimizer = RAdam(model_params, lr=learning_rate, weight_decay=0.0003)
    optimizer = Lookahead(base_optimizer)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.25, patience=2)

    device = utils.get_device()
    runner = SupervisedRunner(device=device, input_key='image', input_target_key='mask')

    callbacks = [
        # Each criterion is computed separately ...
        CriterionCallback(
            input_key="mask",
            prefix="loss_dice",
            criterion_key="dice"
        ),
        CriterionCallback(
            input_key="mask",
            prefix="loss_iou",
            criterion_key="iou"
        ),
        CriterionCallback(
            input_key="mask",
            prefix="loss_bce",
            criterion_key="bce"
        ),

        # ... and then combined into the single "loss" value.
        MetricAggregationCallback(
            prefix="loss",
            mode="weighted_sum",
            metrics={"loss_dice": 1.0, "loss_iou": 1.0, "loss_bce": 0.8},
        ),

        # metrics
        DiceCallback(input_key='mask'),
        IouCallback(input_key='mask'),
    ]

    runner.train(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        loaders=loaders,
        callbacks=callbacks,
        logdir=logdir,
        num_epochs=num_epochs,
        main_metric="iou",
        minimize_metric=False,
        verbose=True,
        load_best_on_end=True,
    )
    best_model_save_dir = os.path.join(logdir, 'save')
    # ``exist_ok`` must be passed by keyword: the second positional argument
    # of os.makedirs is ``mode``, so a positional ``True`` would create the
    # directory with mode 0o001 and still fail if it already exists.
    os.makedirs(best_model_save_dir, exist_ok=True)
    # Save the model selected by the runner (main metric is "iou").
    torch.save(model, os.path.join(best_model_save_dir, 'best_model.pth'))
    batch = next(iter(loaders["valid"]))
    try:
        runner.trace(model=model, batch=batch, logdir=logdir, fp16=False)  # optimized version (not all models can be traced)
    except Exception as exc:
        # Tracing is best-effort, but report the failure instead of hiding it.
        warnings.warn(f"Model tracing failed and was skipped: {exc!r}")
예제 #7
0
                    :shape_copy[3]] = state_dict[key][
                                      :shape_copy[0],
                                      :shape_copy[1],
                                      :shape_copy[2],
                                      :shape_copy[3]]
            else:
                model_state[key] = state_dict[key]

    model.load_state_dict(model_state, strict=False)

    model_wrapper = ModelWrapper(model)
    ##############################
    learning_rate = 0.001

    # This function removes weight_decay for biases and applies our layerwise_params
    model_params = utils.process_model_params(model_wrapper)

    # Catalyst has new SOTA optimizers out of box
    base_optimizer = RAdam(model_params, lr=learning_rate, weight_decay=0.0003)
    optimizer = Lookahead(base_optimizer, k=5, alpha=0.5)

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.25, patience=2)


    #########################

    num_epochs = 50
    logdir = "logs_pretrain/segmentation"
    resume_path = f"{logdir}/checkpoints/last_full.pth"
    if not os.path.isfile(resume_path):
        resume_path = None