# Example #1
0
def _force_make_distributed_loader(loader: DataLoader) -> DataLoader:
    """
    Transfers loader to distributed mode. Experimental feature.

    The sampler of the new loader is chosen as follows:

    * the loader has no sampler, or uses one of PyTorch's default
      ``SequentialSampler`` / ``RandomSampler``: a new ``DistributedSampler``
      over the dataset is created. It shuffles unless the original loader
      iterated sequentially.
    * the loader already uses a ``DistributedSampler``: it is reused as is.
    * any other (custom) sampler: it is wrapped in
      ``DistributedSamplerWrapper`` so its sampling strategy is preserved.

    ``batch_size``, ``num_workers``, ``collate_fn``, ``pin_memory``,
    ``drop_last`` and ``worker_init_fn`` are carried over to the new loader.
    A custom ``batch_sampler`` is NOT transferred.

    Args:
        loader (DataLoader): pytorch dataloader

    Returns:
        DataLoader: pytorch dataloader with distributed sampler.
    """
    from torch.utils.data import RandomSampler, SequentialSampler

    original_sampler = getattr(loader, "sampler", None)
    if original_sampler is None or isinstance(
            original_sampler, (SequentialSampler, RandomSampler)):
        # Default samplers only encode "shuffled or not", which
        # DistributedSampler reproduces while splitting data across ranks.
        sampler = DistributedSampler(
            dataset=loader.dataset,
            shuffle=not isinstance(original_sampler, SequentialSampler),
        )
    elif isinstance(original_sampler, DistributedSampler):
        # Already distributed: nothing to convert.
        sampler = original_sampler
    else:
        # Custom sampler (e.g. weighted): keep its strategy and only shard
        # the indices it produces between ranks.
        sampler = DistributedSamplerWrapper(sampler=original_sampler)

    loader = DataLoader(
        dataset=copy(loader.dataset),
        batch_size=loader.batch_size,
        sampler=sampler,
        num_workers=loader.num_workers,
        collate_fn=loader.collate_fn,
        pin_memory=loader.pin_memory,
        drop_last=loader.drop_last,
        worker_init_fn=loader.worker_init_fn,
    )
    return loader
# Example #2
0
def get_loaders_from_params(
    batch_size: int = 1,
    num_workers: int = 0,
    drop_last: bool = False,
    per_gpu_scaling: bool = False,
    loaders_params: Dict[str, Any] = None,
    samplers_params: Dict[str, Any] = None,
    initial_seed: int = 42,
    get_datasets_fn: Callable = None,
    **data_params,
) -> "OrderedDict[str, DataLoader]":
    """
    Creates pytorch dataloaders from datasets and additional parameters.

    The following are shallow-copied before any entry is popped from them:
    ``loaders_params``, ``samplers_params``, the per-loader dicts inside
    ``loaders_params``, and dict datasources. The caller's configuration is
    therefore never mutated, and repeated calls with the same arguments
    build the same loaders.

    Args:
        batch_size (int): ``batch_size`` parameter
            from ``torch.utils.data.DataLoader``
        num_workers (int): ``num_workers`` parameter
            from ``torch.utils.data.DataLoader``
        drop_last (bool): ``drop_last`` parameter
            from ``torch.utils.data.DataLoader``
        per_gpu_scaling (bool): boolean flag,
            if ``True``, uses ``batch_size=batch_size*num_available_gpus``
            (and scales ``num_workers`` the same way) in non-distributed mode
        loaders_params (Dict[str, Any]): additional loaders parameters,
            keyed by loader name
        samplers_params (Dict[str, Any]): additional sampler parameters,
            keyed by loader name
        initial_seed (int): initial seed for ``torch.utils.data.DataLoader``
            workers
        get_datasets_fn(Callable): callable function to get dictionary with
            ``torch.utils.data.Datasets``
        **data_params: additional data parameters
            or dictionary with ``torch.utils.data.Datasets`` to use for
            pytorch dataloaders creation

    Returns:
        OrderedDict[str, DataLoader]: dictionary with
            ``torch.utils.data.DataLoader``

    Raises:
        NotImplementedError: if datasource is neither a ``Dataset`` nor a dict
        ValueError: if a ``batch_sampler`` is given in distributed mode
            (the two options are mutually exclusive)
    """
    default_batch_size = batch_size
    default_num_workers = num_workers
    loaders_params = loaders_params or {}
    assert isinstance(loaders_params,
                      dict), (f"`loaders_params` should be a Dict. "
                              f"Got: {loaders_params}")
    samplers_params = samplers_params or {}
    assert isinstance(
        samplers_params,
        dict), f"`samplers_params` should be a Dict. Got: {samplers_params}"
    # Entries are popped per loader below; work on copies so the caller's
    # config is left intact.
    loaders_params = dict(loaders_params)
    samplers_params = dict(samplers_params)

    distributed_rank = get_rank()
    distributed = distributed_rank > -1

    if get_datasets_fn is not None:
        datasets = get_datasets_fn(**data_params)
    else:
        datasets = dict(**data_params)

    loaders = OrderedDict()
    for name, datasource in datasets.items():  # noqa: WPS426
        assert isinstance(
            datasource,
            (Dataset, dict
             )), f"{datasource} should be Dataset or Dict. Got: {datasource}"

        loader_params = loaders_params.pop(name, {})
        assert isinstance(loader_params,
                          dict), f"{loader_params} should be Dict"
        # "batch_size" / "num_workers" are popped from it below.
        loader_params = dict(loader_params)

        if isinstance(datasource, dict):
            # "sampler" may be popped from it below.
            datasource = dict(datasource)

        sampler_params = samplers_params.pop(name, None)
        if sampler_params is None:
            sampler = datasource.pop("sampler", None) if isinstance(
                datasource, dict) else None
        else:
            # Explicit sampler params win over a sampler set in the datasource.
            sampler = SAMPLER.get_from_params(**sampler_params)
            if isinstance(datasource, dict):
                datasource.pop("sampler", None)

        batch_size = loader_params.pop("batch_size", default_batch_size)
        num_workers = loader_params.pop("num_workers", default_num_workers)

        if per_gpu_scaling and not distributed:
            num_gpus = max(1, torch.cuda.device_count())
            batch_size *= num_gpus
            num_workers *= num_gpus

        loader_params = {
            "batch_size": batch_size,
            "num_workers": num_workers,
            "pin_memory": torch.cuda.is_available(),
            "drop_last": drop_last,
            **loader_params,
        }

        if isinstance(datasource, Dataset):
            loader_params["dataset"] = datasource
        elif isinstance(datasource, dict):
            assert (
                "dataset"
                in datasource), "You need to specify dataset for dataloader"
            loader_params = merge_dicts(datasource, loader_params)
        else:
            raise NotImplementedError

        if distributed:
            if sampler is not None:
                if not isinstance(sampler, DistributedSampler):
                    sampler = DistributedSamplerWrapper(sampler=sampler)
            else:
                sampler = DistributedSampler(dataset=loader_params["dataset"])

        # Only loaders named "train*" shuffle, and only without a sampler
        # (DataLoader rejects shuffle=True together with a sampler).
        loader_params["shuffle"] = name.startswith("train") and sampler is None

        loader_params["sampler"] = sampler

        if "batch_sampler" in loader_params:
            if distributed:
                raise ValueError("batch_sampler option is mutually "
                                 "exclusive with distributed")

            # DataLoader forbids these together with batch_sampler.
            for k in ("batch_size", "shuffle", "sampler", "drop_last"):
                loader_params.pop(k, None)

        if "worker_init_fn" not in loader_params:
            loader_params["worker_init_fn"] = lambda x: set_global_seed(
                initial_seed + x)

        loaders[name] = DataLoader(**loader_params)

    return loaders
# Example #3
0
    def get_loaders(
        self,
        stage: str,
        epoch: int = None,
    ) -> "OrderedDict[str, DataLoader]":
        """Returns the loaders for a given stage.

        Reads ``self.stages_config[stage]["data_params"]`` and consumes these
        keys here: ``batch_size`` (default 1), ``num_workers`` (default 0),
        ``drop_last`` (default False) and ``per_gpu_scaling`` (default False).
        The remaining keys, including ``loaders_params`` and
        ``samplers_params``, are forwarded to ``self.get_datasets``.
        ``loaders_params`` and ``samplers_params`` then provide per-loader
        overrides, keyed by dataset name.

        The stage config is never mutated, so repeated calls return loaders
        built from the same configuration.

        Args:
            stage (str): stage name, a key of ``self.stages_config``
            epoch (int): not used by this implementation

        Returns:
            OrderedDict[str, DataLoader]: loaders keyed by dataset name

        Raises:
            NotImplementedError: if a dataset is neither a ``Dataset`` nor
                a dict
            ValueError: if ``batch_sampler`` is used in distributed mode
        """
        data_params = dict(self.stages_config[stage]["data_params"])

        default_batch_size = data_params.pop("batch_size", 1)
        # Default to 0 like ``get_loaders_from_params`` instead of raising
        # ``KeyError`` when the config omits ``num_workers``.
        default_num_workers = data_params.pop("num_workers", 0)
        drop_last = data_params.pop("drop_last", False)
        per_gpu_scaling = data_params.pop("per_gpu_scaling", False)
        distributed_rank = utils.get_rank()
        distributed = distributed_rank > -1

        # NOTE(review): "loaders_params"/"samplers_params" are still present in
        # ``data_params`` here, so they are forwarded to ``get_datasets`` too.
        datasets = self.get_datasets(stage=stage, **data_params)

        overridden_loaders_params = data_params.pop("loaders_params", {})
        assert isinstance(
            overridden_loaders_params,
            dict), (f"`overridden_loaders_params` should be a Dict. "
                    f"Got: {overridden_loaders_params}")
        # ``data_params`` is only a shallow copy of the stage config: without
        # copying the nested dicts, the ``pop`` calls below would delete
        # entries from ``self.stages_config``.
        overridden_loaders_params = dict(overridden_loaders_params)

        samplers_params = data_params.pop("samplers_params", {})
        assert isinstance(
            samplers_params, dict
        ), f"`samplers_params` should be a Dict. Got: {samplers_params}"
        samplers_params = dict(samplers_params)

        loaders = OrderedDict()
        for name, ds_ in datasets.items():
            assert isinstance(
                ds_, (Dataset, dict)), f"{ds_} should be Dataset or Dict"

            overridden_loader_params = overridden_loaders_params.pop(name, {})
            assert isinstance(
                overridden_loader_params,
                dict), f"{overridden_loader_params} should be Dict"
            # "batch_size" / "num_workers" are popped from it below.
            overridden_loader_params = dict(overridden_loader_params)

            if isinstance(ds_, dict):
                # "sampler" may be popped from it below.
                ds_ = dict(ds_)

            sampler_params = samplers_params.pop(name, None)
            if sampler_params is None:
                sampler = ds_.pop("sampler", None) if isinstance(
                    ds_, dict) else None
            else:
                # Explicit sampler params win over a sampler in the dataset dict.
                sampler = SAMPLERS.get_from_params(**sampler_params)
                if isinstance(ds_, dict):
                    ds_.pop("sampler", None)

            batch_size = overridden_loader_params.pop("batch_size",
                                                      default_batch_size)
            num_workers = overridden_loader_params.pop("num_workers",
                                                       default_num_workers)

            if per_gpu_scaling and not distributed:
                num_gpus = max(1, torch.cuda.device_count())
                batch_size *= num_gpus
                num_workers *= num_gpus

            loader_params = {
                "batch_size": batch_size,
                "num_workers": num_workers,
                "pin_memory": torch.cuda.is_available(),
                "drop_last": drop_last,
                **overridden_loader_params,
            }

            if isinstance(ds_, Dataset):
                loader_params["dataset"] = ds_
            elif isinstance(ds_, dict):
                assert ("dataset"
                        in ds_), "You need to specify dataset for dataloader"
                loader_params = utils.merge_dicts(ds_, loader_params)
            else:
                raise NotImplementedError

            if distributed:
                if sampler is not None:
                    if not isinstance(sampler, DistributedSampler):
                        sampler = DistributedSamplerWrapper(sampler=sampler)
                else:
                    sampler = DistributedSampler(
                        dataset=loader_params["dataset"])

            # Only "train*" loaders shuffle, and only when no sampler is set
            # (DataLoader rejects shuffle=True together with a sampler).
            loader_params["shuffle"] = (name.startswith("train")
                                        and sampler is None)

            loader_params["sampler"] = sampler

            if "batch_sampler" in loader_params:
                if distributed:
                    raise ValueError("batch_sampler option is mutually "
                                     "exclusive with distributed")

                # DataLoader forbids these together with batch_sampler.
                for k in ("batch_size", "shuffle", "sampler", "drop_last"):
                    loader_params.pop(k, None)

            if "worker_init_fn" not in loader_params:
                loader_params[
                    "worker_init_fn"] = lambda x: utils.set_global_seed(
                        self.initial_seed + x)

            loaders[name] = DataLoader(**loader_params)

        return loaders
def main():
    """Parse command-line arguments, then train and evaluate an Inria model.

    Steps:

    1. Parse CLI arguments. When ``WORLD_SIZE`` is set in the environment,
       bind this process to ``--local_rank``'s GPU and initialize an NCCL
       process group.
    2. Build the model, optionally transferring weights (``--transfer``) or
       loading a checkpoint (``--checkpoint``).
    3. Build the datasets and data loaders. Optional extras are xView2 data
       (``--data-dir-xview2``) and online pseudolabeling (``--opl``).
    4. Train with ``SupervisedRunner`` when ``--epochs`` > 0.
    5. On the master process, copy the best checkpoint via
       ``clean_checkpoint``. Then run the model on ``sample_color.jpg`` and
       write the binarized mask into the log directory.

    Raises:
        ValueError: if ``--data-dir`` is not given and ``INRIA_DATA_DIR`` is
            not set.
    """
    parser = argparse.ArgumentParser()

    ###########################################################################################
    # Distributed-training related stuff
    parser.add_argument("--local_rank", type=int, default=0)
    ###########################################################################################

    parser.add_argument("-acc", "--accumulation-steps", type=int, default=1, help="Number of batches to process")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("--fast", action="store_true")
    parser.add_argument(
        "-dd",
        "--data-dir",
        type=str,
        help="Data directory for INRIA sattelite dataset",
        default=os.environ.get("INRIA_DATA_DIR"),
    )
    parser.add_argument(
        "-dd-xview2", "--data-dir-xview2", type=str, required=False, help="Data directory for external xView2 dataset"
    )
    parser.add_argument("-m", "--model", type=str, default="b6_unet32_s2", help="")
    parser.add_argument("-b", "--batch-size", type=int, default=8, help="Batch Size during training, e.g. -b 64")
    parser.add_argument("-e", "--epochs", type=int, default=100, help="Epoch to run")
    # parser.add_argument('-es', '--early-stopping', type=int, default=None, help='Maximum number of epochs without improvement')
    # parser.add_argument('-fe', '--freeze-encoder', type=int, default=0, help='Freeze encoder parameters for N epochs')
    # parser.add_argument('-ft', '--fine-tune', action='store_true')
    parser.add_argument("-lr", "--learning-rate", type=float, default=1e-3, help="Initial learning rate")
    parser.add_argument("-l", "--criterion", type=str, required=True, action="append", nargs="+", help="Criterion")
    parser.add_argument(
        "-l2",
        "--criterion2",
        type=str,
        required=False,
        action="append",
        nargs="+",
        help="Criterion for stride 2 mask",
    )
    parser.add_argument(
        "-l4",
        "--criterion4",
        type=str,
        required=False,
        action="append",
        nargs="+",
        help="Criterion for stride 4 mask",
    )
    parser.add_argument(
        "-l8",
        "--criterion8",
        type=str,
        required=False,
        action="append",
        nargs="+",
        help="Criterion for stride 8 mask",
    )
    parser.add_argument(
        "-l16",
        "--criterion16",
        type=str,
        required=False,
        action="append",
        nargs="+",
        help="Criterion for stride 16 mask",
    )

    parser.add_argument("-o", "--optimizer", default="RAdam", help="Name of the optimizer")
    parser.add_argument(
        "-c", "--checkpoint", type=str, default=None, help="Checkpoint filename to use as initial model weights"
    )
    parser.add_argument("-w", "--workers", default=8, type=int, help="Num workers")
    parser.add_argument("-a", "--augmentations", default="hard", type=str, help="")
    parser.add_argument("-tm", "--train-mode", default="random", type=str, help="")
    parser.add_argument("--run-mode", default="fit_predict", type=str, help="")
    parser.add_argument("--transfer", default=None, type=str, help="")
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--size", default=512, type=int)
    parser.add_argument("-s", "--scheduler", default="multistep", type=str, help="")
    parser.add_argument("-x", "--experiment", default=None, type=str, help="")
    parser.add_argument("-d", "--dropout", default=None, type=float, help="Dropout before head layer")
    parser.add_argument("--opl", action="store_true")
    parser.add_argument(
        "--warmup", default=0, type=int, help="Number of warmup epochs with reduced LR on encoder parameters"
    )
    parser.add_argument("-wd", "--weight-decay", default=0, type=float, help="L2 weight decay")
    parser.add_argument("--show", action="store_true")
    parser.add_argument("--dsv", action="store_true")

    args = parser.parse_args()

    args.is_master = args.local_rank == 0
    args.distributed = False
    fp16 = args.fp16

    if "WORLD_SIZE" in os.environ:
        args.distributed = int(os.environ["WORLD_SIZE"]) > 1
        args.world_size = int(os.environ["WORLD_SIZE"])
        # args.world_size = torch.distributed.get_world_size()

        print("Initializing init_process_group", args.local_rank)

        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
        print("Initialized init_process_group", args.local_rank)

    # A non-distributed run is always its own master.
    is_master = args.is_master or not args.distributed

    # Passed to ``runner.train(fp16=...)``. It is a dict of options when
    # distributed and/or AMP is requested, and ``False`` otherwise.
    if args.distributed:
        distributed_params = {"rank": args.local_rank, "syncbn": True}
        if args.fp16:
            distributed_params["amp"] = True
    else:
        if args.fp16:
            distributed_params = {}
            distributed_params["amp"] = True
        else:
            distributed_params = False

    # Offset the seed by rank so each process draws different random crops.
    set_manual_seed(args.seed + args.local_rank)
    catalyst.utils.set_global_seed(args.seed + args.local_rank)
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True

    data_dir = args.data_dir
    if data_dir is None:
        raise ValueError("--data-dir must be set")

    num_workers = args.workers
    num_epochs = args.epochs
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    model_name = args.model
    optimizer_name = args.optimizer
    image_size = args.size, args.size
    fast = args.fast
    augmentations = args.augmentations
    train_mode = args.train_mode
    scheduler_name = args.scheduler
    experiment = args.experiment
    dropout = args.dropout
    online_pseudolabeling = args.opl
    criterions = args.criterion
    criterions2 = args.criterion2
    criterions4 = args.criterion4
    criterions8 = args.criterion8
    criterions16 = args.criterion16

    verbose = args.verbose
    show = args.show
    accumulation_steps = args.accumulation_steps
    weight_decay = args.weight_decay
    extra_data_xview2 = args.data_dir_xview2

    run_train = num_epochs > 0
    need_weight_mask = any(c[0] == "wbce" for c in criterions)

    custom_model_kwargs = {"full_size_mask": False}
    if dropout is not None:
        custom_model_kwargs["dropout"] = float(dropout)

    # Any stride-specific criterion requires the model to emit the extra masks.
    if any([criterions2, criterions4, criterions8, criterions16]):
        custom_model_kwargs["need_supervision_masks"] = True
        print("Enabling supervision masks")

    model: nn.Module = get_model(model_name, num_classes=16, **custom_model_kwargs).cuda()

    if args.transfer:
        transfer_checkpoint = fs.auto_file(args.transfer)
        print("Transfering weights from model checkpoint", transfer_checkpoint)
        checkpoint = load_checkpoint(transfer_checkpoint)
        pretrained_dict = checkpoint["model_state_dict"]

        transfer_weights(model, pretrained_dict)

    if args.checkpoint:
        checkpoint = load_checkpoint(fs.auto_file(args.checkpoint))
        unpack_checkpoint(checkpoint, model=model)

        print("Loaded model weights from:", args.checkpoint)
        report_checkpoint(checkpoint)

    main_metric = "jaccard"

    current_time = datetime.now().strftime("%y%m%d_%H_%M")
    checkpoint_prefix = f"{current_time}_{args.model}"

    if fp16:
        checkpoint_prefix += "_fp16"

    if fast:
        checkpoint_prefix += "_fast"

    if online_pseudolabeling:
        checkpoint_prefix += "_opl"

    if extra_data_xview2:
        checkpoint_prefix += "_with_xview2"

    # An explicit experiment name replaces the generated prefix entirely.
    if experiment is not None:
        checkpoint_prefix = experiment

    default_callbacks = [
        JaccardMetricPerImage(
            input_key=INPUT_MASK_KEY,
            output_key=OUTPUT_MASK_KEY,
            prefix="jaccard",
            inputs_to_labels=depth2mask,
            outputs_to_labels=decode_depth_mask,
        ),
    ]

    if is_master:

        default_callbacks += [
            BestMetricCheckpointCallback(target_metric="jaccard", target_metric_minimize=False),
            HyperParametersCallback(
                hparam_dict={
                    "model": model_name,
                    "scheduler": scheduler_name,
                    "optimizer": optimizer_name,
                    "augmentations": augmentations,
                    "size": args.size,
                    "weight_decay": weight_decay,
                    "epochs": num_epochs,
                    "dropout": None if dropout is None else float(dropout),
                }
            ),
        ]

        if show:
            visualize_inria_predictions = partial(
                draw_inria_predictions,
                image_key=INPUT_IMAGE_KEY,
                image_id_key=INPUT_IMAGE_ID_KEY,
                targets_key=INPUT_MASK_KEY,
                outputs_key=OUTPUT_MASK_KEY,
                inputs_to_labels=depth2mask,
                outputs_to_labels=decode_depth_mask,
                max_images=16,
            )
            default_callbacks += [
                ShowPolarBatchesCallback(visualize_inria_predictions, metric="accuracy", minimize=False),
                ShowPolarBatchesCallback(visualize_inria_predictions, metric="loss", minimize=True),
            ]

    train_ds, valid_ds, train_sampler = get_datasets(
        data_dir=data_dir,
        image_size=image_size,
        augmentation=augmentations,
        train_mode=train_mode,
        buildings_only=(train_mode == "tiles"),
        fast=fast,
        need_weight_mask=need_weight_mask,
        make_mask_target_fn=mask_to_ce_target,
    )

    if extra_data_xview2 is not None:
        extra_train_ds, _ = get_xview2_extra_dataset(
            extra_data_xview2,
            image_size=image_size,
            augmentation=augmentations,
            fast=fast,
            need_weight_mask=need_weight_mask,
        )

        # ``get_datasets`` may return no sampler; fall back to the dataset
        # size the same way the pseudolabeling branch below does.
        if train_sampler is not None:
            num_samples = 2 * train_sampler.num_samples
        else:
            num_samples = 2 * len(train_ds)
        # "balanced" weights give both sources equal total sampling mass.
        weights = compute_sample_weight("balanced", [0] * len(train_ds) + [1] * len(extra_train_ds))
        train_sampler = WeightedRandomSampler(weights, num_samples)

        train_ds = train_ds + extra_train_ds
        print("Using extra data from xView2 with", len(extra_train_ds), "samples")

    if run_train:
        loaders = collections.OrderedDict()
        callbacks = default_callbacks.copy()
        criterions_dict = {}
        losses = []

        ignore_index = None
        if online_pseudolabeling:
            ignore_index = UNLABELED_SAMPLE
            # Images to predict pseudolabels for (no masks, no augmentation).
            unlabeled_label = get_pseudolabeling_dataset(
                data_dir, include_masks=False, augmentation=None, image_size=image_size
            )

            # The same images used for training with the predicted masks.
            unlabeled_train = get_pseudolabeling_dataset(
                data_dir, include_masks=True, augmentation=augmentations, image_size=image_size
            )

            if args.distributed:
                label_sampler = DistributedSampler(unlabeled_label, args.world_size, args.local_rank, shuffle=False)
            else:
                label_sampler = None

            loaders["infer"] = DataLoader(
                unlabeled_label,
                batch_size=batch_size // 2,
                num_workers=num_workers,
                pin_memory=True,
                sampler=label_sampler,
                drop_last=False,
            )

            if train_sampler is not None:
                num_samples = 2 * train_sampler.num_samples
            else:
                num_samples = 2 * len(train_ds)
            # Weights must line up with ``train_ds + unlabeled_train`` (the
            # dataset actually sampled from), hence ``len(unlabeled_train)``.
            weights = compute_sample_weight("balanced", [0] * len(train_ds) + [1] * len(unlabeled_train))

            train_sampler = WeightedRandomSampler(weights, num_samples, replacement=True)
            train_ds = train_ds + unlabeled_train

            callbacks += [
                BCEOnlinePseudolabelingCallback2d(
                    unlabeled_train,
                    pseudolabel_loader="infer",
                    prob_threshold=0.7,
                    output_key=OUTPUT_MASK_KEY,
                    unlabeled_class=UNLABELED_SAMPLE,
                    label_frequency=5,
                )
            ]

            print("Using online pseudolabeling with ", len(unlabeled_label), "samples")

        valid_sampler = None
        if args.distributed:
            # Keep a custom (weighted) sampling strategy and only shard its
            # output across ranks; otherwise shard the dataset directly.
            if train_sampler is not None:
                train_sampler = DistributedSamplerWrapper(
                    train_sampler, args.world_size, args.local_rank, shuffle=True
                )
            else:
                train_sampler = DistributedSampler(train_ds, args.world_size, args.local_rank, shuffle=True)
            valid_sampler = DistributedSampler(valid_ds, args.world_size, args.local_rank, shuffle=False)

        loaders["train"] = DataLoader(
            train_ds,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            shuffle=train_sampler is None,
            sampler=train_sampler,
        )

        loaders["valid"] = DataLoader(
            valid_ds, batch_size=batch_size, num_workers=num_workers, pin_memory=True, sampler=valid_sampler
        )

        loss_callbacks, loss_criterions = get_criterions(
            criterions, criterions2, criterions4, criterions8, criterions16
        )
        callbacks += loss_callbacks

        optimizer = get_optimizer(
            optimizer_name, get_optimizable_parameters(model), learning_rate, weight_decay=weight_decay
        )
        scheduler = get_scheduler(
            scheduler_name, optimizer, lr=learning_rate, num_epochs=num_epochs, batches_in_epoch=len(loaders["train"])
        )
        # These schedulers are stepped per batch rather than per epoch.
        if isinstance(scheduler, (CyclicLR, OneCycleLRWithWarmup)):
            callbacks += [SchedulerCallback(mode="batch")]

        log_dir = os.path.join("runs", checkpoint_prefix)

        if is_master:
            # exist_ok=False: refuse to overwrite a previous run's logs.
            os.makedirs(log_dir, exist_ok=False)
            config_fname = os.path.join(log_dir, f"{checkpoint_prefix}.json")
            with open(config_fname, "w") as f:
                train_session_args = vars(args)
                f.write(json.dumps(train_session_args, indent=2))

            print("Train session    :", checkpoint_prefix)
            print("  FP16 mode      :", fp16)
            print("  Fast mode      :", args.fast)
            print("  Train mode     :", train_mode)
            print("  Epochs         :", num_epochs)
            print("  Workers        :", num_workers)
            print("  Data dir       :", data_dir)
            print("  Log dir        :", log_dir)
            print("  Augmentations  :", augmentations)
            print("  Train size     :", "batches", len(loaders["train"]), "dataset", len(train_ds))
            print("  Valid size     :", "batches", len(loaders["valid"]), "dataset", len(valid_ds))
            print("Model            :", model_name)
            print("  Parameters     :", count_parameters(model))
            print("  Image size     :", image_size)
            print("Optimizer        :", optimizer_name)
            print("  Learning rate  :", learning_rate)
            print("  Batch size     :", batch_size)
            print("  Criterion      :", criterions)
            print("  Use weight mask:", need_weight_mask)
            if args.distributed:
                print("Distributed")
                print("  World size     :", args.world_size)
                print("  Local rank     :", args.local_rank)
                print("  Is master      :", args.is_master)

        # model training
        runner = SupervisedRunner(input_key=INPUT_IMAGE_KEY, output_key=None, device="cuda")
        runner.train(
            fp16=distributed_params,
            model=model,
            criterion=loss_criterions,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders=loaders,
            logdir=os.path.join(log_dir, "main"),
            num_epochs=num_epochs,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=False,
            checkpoint_data={"cmd_args": vars(args)},
        )

        # Training is finished. Let's run predictions using best checkpoint weights
        if is_master:
            best_checkpoint = os.path.join(log_dir, "main", "checkpoints", "best.pth")

            model_checkpoint = os.path.join(log_dir, f"{checkpoint_prefix}.pth")
            clean_checkpoint(best_checkpoint, model_checkpoint)

            unpack_checkpoint(torch.load(model_checkpoint), model=model)

            mask = predict(
                model, read_inria_image("sample_color.jpg"), image_size=image_size, batch_size=args.batch_size
            )
            mask = ((mask > 0) * 255).astype(np.uint8)
            name = os.path.join(log_dir, "sample_color.jpg")
            cv2.imwrite(name, mask)