Example #1
0
 def _get_seed(self):
     """Pick the next seed, set it as the global seed, and return it.

     When a fixed seed list is configured, cycles through it by
     trajectory index; otherwise draws a fresh seed from the seeder.
     """
     if self.seeds is None:
         chosen = self._seeder()[0]
     else:
         chosen = self.seeds[self.trajectory_index % len(self.seeds)]
     set_global_seed(chosen)
     return chosen
Example #2
0
    def _run_stage(self, stage: str) -> None:
        """
        Inner method to run stage on Runner,
        with stage callbacks events.

        Runs the epoch loop for the stage: each iteration reseeds the
        global RNGs, fires the epoch events, runs the epoch, then
        advances the epoch counters unless early stopping was requested.

        Args:
            stage: stage name of interest,
                like "pretrain" / "train" / "finetune" / etc

        """
        self._prepare_for_stage(stage)

        self._run_event("on_stage_start")
        # epochs are 1-based: runs while epoch <= num_epochs
        while self.epoch < self.num_epochs + 1:
            # reseed every epoch; the seed depends on global_epoch,
            # presumably so each epoch is reproducible across runs/resumes
            set_global_seed(self.experiment.initial_seed + self.global_epoch +
                            1)
            self._run_event("on_epoch_start")
            self._run_epoch(stage=stage, epoch=self.epoch)
            self._run_event("on_epoch_end")

            if self.need_early_stop:
                # one-shot flag: reset it before leaving the loop so the
                # next stage starts clean
                self.need_early_stop = False
                break

            self.global_epoch += 1
            self.epoch += 1
        self._run_event("on_stage_end")
Example #3
0
    def predict_loader(
        self,
        *,
        loader: DataLoader,
        model: Model = None,
        resume: str = None,
        fp16: Union[Dict, bool] = None,
        initial_seed: int = 42,
    ) -> Generator:
        """
        Runs model inference on PyTorch Dataloader and returns
        python generator with model predictions from `runner.predict_batch`.
        Cleans up the experiment info to avoid possible collisions.
        Sets `is_train_loader` and `is_valid_loader` to `False` while
        keeping `is_infer_loader` as True. Moves model to evaluation mode.

        Args:
            loader: loader to predict
            model: model to use for prediction
            resume: path to checkpoint to resume
            fp16 (Union[Dict, bool]): fp16 usage flag
            initial_seed: seed to use before prediction

        Yields:
            batches with model predictions
        """
        # bare ``fp16=True`` is shorthand for Apex-style opt level "O1"
        if isinstance(fp16, bool) and fp16:
            fp16 = {"opt_level": "O1"}

        if model is not None:
            self.model = model
        assert self.model is not None

        # restore model weights from the checkpoint before predicting
        if resume is not None:
            checkpoint = load_checkpoint(resume)
            unpack_checkpoint(checkpoint, model=self.model)

        # drop any previous experiment to avoid state collisions
        self.experiment = None
        set_global_seed(initial_seed)
        (model, _, _, _, device) = process_components(  # noqa: WPS122
            model=self.model,
            distributed_params=fp16,
            device=self.device,
        )
        self._prepare_inner_state(
            stage="infer",
            model=model,
            device=device,
            is_train_loader=False,
            is_valid_loader=False,
            is_infer_loader=True,
        )
        # switch the model (recursively, for model containers) to eval mode
        maybe_recursive_call(self.model, "train", mode=False)

        # seed is set again right before iterating the loader, presumably
        # so dataloader-side randomness is reproducible — TODO confirm
        set_global_seed(initial_seed)
        for batch in loader:
            yield self.predict_batch(batch)
Example #4
0
    def _run_epoch(self, stage: str, epoch: int) -> None:
        """
        Inner method to run epoch on Runner,
        with epoch callbacks events.

        Validates the loaders first (non-empty; a valid loader must exist
        for non-inference stages; no train loaders during inference), then
        runs each loader, enabling gradients only for train loaders.

        Args:
            stage: stage name of interest,
                like "pretrain" / "train" / "finetune" / etc
            epoch: epoch index
        """
        self._prepare_for_epoch(stage=stage, epoch=epoch)
        assert self.loaders is not None

        # fail fast on empty dataloaders
        for loader_name, loader in self.loaders.items():
            if len(loader) == 0:
                raise RunnerException(
                    f"DataLoader with name {loader_name} is empty.")

        self.is_infer_stage = self.stage_name.startswith("infer")
        if not self.is_infer_stage:
            # training stages require the configured validation loader
            assert self.valid_loader in self.loaders.keys(), (
                f"'{self.valid_loader}' "
                f"should be in provided loaders: {list(self.loaders.keys())}")
        else:
            assert not any(
                x.startswith(SETTINGS.loader_train_prefix)
                for x in self.loaders.keys()
            ), "for inference no train loader should be passed"

        for loader_name, loader in self.loaders.items():
            self.loader_name = loader_name
            self.loader_len = len(loader)
            # a loader's role is derived from its name prefix
            self.is_train_loader = loader_name.startswith(
                SETTINGS.loader_train_prefix)
            self.is_valid_loader = loader_name.startswith(
                SETTINGS.loader_valid_prefix)
            self.is_infer_loader = loader_name.startswith(
                SETTINGS.loader_infer_prefix)
            # train mode only for train loaders, eval mode otherwise
            maybe_recursive_call(
                self.model,
                "train",
                mode=self.is_train_loader,
            )

            # DistributedSampler needs the epoch to reshuffle its shards
            # deterministically across processes
            if (isinstance(loader.sampler, DistributedSampler)
                    and not self.is_infer_stage):
                loader.sampler.set_epoch(self.epoch)

            # reseed before each loader pass; seed depends on global_epoch
            set_global_seed(self.experiment.initial_seed + self.global_epoch +
                            1)
            self._run_event("on_loader_start")
            with torch.set_grad_enabled(self.is_train_loader):
                self._run_loader(loader)
            self._run_event("on_loader_end")
Example #5
0
def main_worker(args, unknown_args):
    """Runs main worker thread from model training.

    Parses args/config, seeds RNGs and cuDNN, injects mixed-precision
    flags into the distributed config, builds the config-API components,
    dumps the environment on the master process, and runs the experiment.
    """
    args, config = parse_args_uargs(args, unknown_args)
    set_global_seed(args.seed)
    prepare_cudnn(args.deterministic, args.benchmark)

    # propagate mixed-precision flags into the distributed config section
    distributed_params = config.setdefault("distributed_params", {})
    distributed_params["apex"] = args.apex
    distributed_params["amp"] = args.amp

    experiment, runner, config = prepare_config_api_components(
        expdir=Path(args.expdir), config=config
    )

    # only the master process (rank <= 0) dumps environment and code
    if experiment.logdir is not None and get_rank() <= 0:
        dump_environment(config, experiment.logdir, args.configs)
        dump_code(args.expdir, experiment.logdir)

    runner.run_experiment(experiment)
Example #6
0
 def _prepare_seed(self):
     """Draw the next seed from the seeder and apply it globally."""
     set_global_seed(self._seeder()[0])
Example #7
0
def get_loaders_from_params(
    batch_size: int = 1,
    num_workers: int = 0,
    drop_last: bool = False,
    per_gpu_scaling: bool = False,
    loaders_params: Dict[str, Any] = None,
    samplers_params: Dict[str, Any] = None,
    initial_seed: int = 42,
    get_datasets_fn: Callable = None,
    **data_params,
) -> "OrderedDict[str, DataLoader]":
    """
    Creates pytorch dataloaders from datasets and additional parameters.

    Args:
        batch_size: ``batch_size`` parameter
            from ``torch.utils.data.DataLoader``
        num_workers: ``num_workers`` parameter
            from ``torch.utils.data.DataLoader``
        drop_last: ``drop_last`` parameter
            from ``torch.utils.data.DataLoader``
        per_gpu_scaling: boolean flag,
            if ``True``, uses ``batch_size=batch_size*num_available_gpus``
        loaders_params (Dict[str, Any]): additional loaders parameters
        samplers_params (Dict[str, Any]): additional sampler parameters
        initial_seed: initial seed for ``torch.utils.data.DataLoader``
            workers
        get_datasets_fn(Callable): callable function to get dictionary with
            ``torch.utils.data.Datasets``
        **data_params: additional data parameters
            or dictionary with ``torch.utils.data.Datasets`` to use for
            pytorch dataloaders creation

    Returns:
        OrderedDict[str, DataLoader]: dictionary with
            ``torch.utils.data.DataLoader``

    Raises:
        NotImplementedError: if datasource is out of `Dataset` or dict
        ValueError: if batch_sampler option is mutually
            exclusive with distributed
    """
    # imported locally, presumably to avoid a circular import — TODO confirm
    from catalyst.data.sampler import DistributedSamplerWrapper

    default_batch_size = batch_size
    default_num_workers = num_workers
    loaders_params = loaders_params or {}
    assert isinstance(
        loaders_params,
        dict), f"`loaders_params` should be a Dict. " f"Got: {loaders_params}"
    samplers_params = samplers_params or {}
    assert isinstance(
        samplers_params,
        dict), f"`samplers_params` should be a Dict. Got: {samplers_params}"

    # rank > -1 signals that distributed training is active
    distributed_rank = get_rank()
    distributed = distributed_rank > -1

    if get_datasets_fn is not None:
        datasets = get_datasets_fn(**data_params)
    else:
        datasets = dict(**data_params)

    loaders = OrderedDict()
    for name, datasource in datasets.items():  # noqa: WPS426
        assert isinstance(
            datasource,
            (Dataset, dict
             )), f"{datasource} should be Dataset or Dict. Got: {datasource}"

        # NOTE(review): ``pop`` mutates the caller-supplied loaders_params /
        # samplers_params dicts — confirm callers do not reuse them
        loader_params = loaders_params.pop(name, {})
        assert isinstance(loader_params,
                          dict), f"{loader_params} should be Dict"

        # sampler may come from samplers_params or be embedded inside a
        # dict datasource; an explicit samplers_params entry wins
        sampler_params = samplers_params.pop(name, None)
        if sampler_params is None:
            if isinstance(datasource, dict) and "sampler" in datasource:
                sampler = datasource.pop("sampler", None)
            else:
                sampler = None
        else:
            sampler = SAMPLER.get_from_params(**sampler_params)
            if isinstance(datasource, dict) and "sampler" in datasource:
                datasource.pop("sampler", None)

        # per-loader overrides fall back to the function-level defaults
        batch_size = loader_params.pop("batch_size", default_batch_size)
        num_workers = loader_params.pop("num_workers", default_num_workers)

        # in non-distributed multi-GPU mode, scale batch size and workers
        # by the number of visible GPUs
        if per_gpu_scaling and not distributed:
            num_gpus = max(1, torch.cuda.device_count())
            batch_size *= num_gpus
            num_workers *= num_gpus

        loader_params = {
            "batch_size": batch_size,
            "num_workers": num_workers,
            "pin_memory": torch.cuda.is_available(),
            "drop_last": drop_last,
            **loader_params,
        }

        if isinstance(datasource, Dataset):
            loader_params["dataset"] = datasource
        elif isinstance(datasource, dict):
            assert "dataset" in datasource, "You need to specify dataset for dataloader"
            loader_params = merge_dicts(datasource, loader_params)
        else:
            raise NotImplementedError

        # distributed mode requires a DistributedSampler; wrap any custom
        # sampler so shards are split across processes
        if distributed:
            if sampler is not None:
                if not isinstance(sampler, DistributedSampler):
                    sampler = DistributedSamplerWrapper(sampler=sampler)
            else:
                sampler = DistributedSampler(dataset=loader_params["dataset"])

        # shuffle only "train*" loaders, and only when no explicit sampler
        # is set (DataLoader forbids shuffle together with a sampler)
        loader_params["shuffle"] = name.startswith("train") and sampler is None

        loader_params["sampler"] = sampler

        # batch_sampler is mutually exclusive with these DataLoader args
        if "batch_sampler" in loader_params:
            if distributed:
                raise ValueError("batch_sampler option is mutually "
                                 "exclusive with distributed")

            for k in ("batch_size", "shuffle", "sampler", "drop_last"):
                loader_params.pop(k, None)

        # give each worker process a deterministic, distinct seed
        if "worker_init_fn" not in loader_params:
            loader_params["worker_init_fn"] = lambda x: set_global_seed(
                initial_seed + x)

        loaders[name] = DataLoader(**loader_params)

    return loaders
Example #8
0
 def _get_seed(self):
     """Sample a seed, optionally remap it through ``self.seeds``,
     apply it globally, and return it."""
     drawn = self._seeder()[0]
     chosen = drawn if self.seeds is None else self.seeds[drawn]
     set_global_seed(chosen)
     return chosen
Example #9
0
    def _prepare_for_stage(self, stage: str):
        """Inner method to prepare `Runner` for the specified stage.

        Sets `Experiment` initial seed.
        Prepares experiment components with `self._get_experiment_components`.
        Prepares callbacks with `self._get_experiment_callbacks`.
        Prepares inner state with `self._prepare_inner_state`
        Additionally sets `Experiment` datasources for specified stage.

        Args:
            stage: stage name of interest,
                like "pretrain" / "train" / "finetune" / etc
        """
        # reseed before each build step so every step starts from the same
        # RNG state regardless of what the previous step consumed
        set_global_seed(self.experiment.initial_seed)
        loaders = self.experiment.get_loaders(stage=stage)
        loaders = validate_loaders(loaders)
        self.loaders = loaders

        set_global_seed(self.experiment.initial_seed)
        (
            model,
            criterion,
            optimizer,
            scheduler,
            device,
        ) = self._get_experiment_components(experiment=self.experiment,
                                            stage=stage,
                                            device=self.device)

        set_global_seed(self.experiment.initial_seed)
        callbacks = self._get_experiment_callbacks(experiment=self.experiment,
                                                   stage=stage)

        migrating_params = dict(**self.experiment.get_stage_params(stage))
        migrate_from_previous_stage = migrating_params.get(
            "migrate_from_previous_stage", True)
        # carry experiment-scoped callbacks over from the previous stage
        if (migrate_from_previous_stage
                and getattr(self, "callbacks", None) is not None):
            for key, value in self.callbacks.items():
                if value.scope == CallbackScope.experiment:
                    callbacks[key] = value

        callbacks = sort_callbacks_by_order(callbacks)

        # also carry global counters (and resume path) across stages;
        # the getattr defaults cover the very first stage, where the
        # attributes do not exist yet
        if migrate_from_previous_stage:
            migrating_params.update({
                "global_epoch":
                getattr(self, "global_epoch", 1),
                "global_batch_step":
                getattr(self, "global_batch_step", 0),
                "global_sample_step":
                getattr(self, "global_sample_step", 0),
                "resume":
                getattr(self, "resume", None),
            })

        self._prepare_inner_state(
            stage=stage,
            model=model,
            device=device,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders=getattr(self, "loaders", None),
            **migrating_params,
        )
Example #10
0
                   get_validation_augmentation)
import argparse
import pandas as pd
from catalyst.dl.runner import SupervisedRunner
from catalyst.dl import utils
from catalyst.dl.callbacks import (EarlyStoppingCallback, CriterionCallback,
                                   OptimizerCallback, DiceCallback,
                                   CheckpointCallback)
# import torch
from torch import optim
from torch.utils.data import DataLoader
from dataloader import CloudDataset
import segmentation_models_pytorch as smp
from catalyst.utils.seed import set_global_seed
from catalyst.utils.torch import prepare_cudnn
set_global_seed(2019)
prepare_cudnn(deterministic=True)


def _str2bool(value):
    """Parse a CLI boolean flag value.

    ``type=bool`` is an argparse trap: ``bool("False")`` is ``True``
    because any non-empty string is truthy. This converter accepts the
    usual spellings and keeps the flags' default values working.
    """
    if isinstance(value, bool):
        return value
    return value.strip().lower() in ("1", "true", "t", "yes", "y")


parser = argparse.ArgumentParser("PyTorch Segmentation Pipeline")
# ``add_argument`` returns an argparse.Action, not parsed args, so its
# result is intentionally not bound to a name (the old ``args =`` bindings
# were meaningless; parsed args come from ``parser.parse_args()`` later).
parser.add_argument('-E', '--epochs', default=1, type=int)
parser.add_argument('-F', '--fold', default=1, type=int)
parser.add_argument('-C', '--checkpoint', default=False, type=_str2bool)
parser.add_argument('-M', '--model', default='AlbuNet', type=str)
parser.add_argument('-A', '--encoder', default='resnet18', type=str)
parser.add_argument('-P', '--pretrained', default=True, type=_str2bool)
parser.add_argument('--lr', default=1e-4, type=float)
parser.add_argument('--lr_e', default=1e-4, type=float)
parser.add_argument('--lr_d', default=1e-4, type=float)
parser.add_argument('--bs', default=4, type=int)
parser.add_argument('--size', default=320, type=int)
parser.add_argument('--dice-weight', default=0.5, type=float)
Example #11
0
def main() -> None:
    """Load the config, build datasets/model/optimizer/scheduler, and run
    training with ``SupervisedRunner``."""
    config = load_config(CONFIG_FILE)
    # NOTE(review): ``train_config`` is assigned but never used below
    train_config = config["train"]

    # config keys intentionally contain spaces ("num epochs", ...)
    num_epochs = config.get("num epochs", 2)
    random_state = config.get("random state", 2019)
    num_workers = config.get("num workers", 6)
    batch_size = config["batch size"]

    train_dataset = get_dataset(**config["train"])
    valid_dataset = get_dataset(**config["validation"])

    # train loader shuffles; valid loader keeps deterministic order
    data_loaders = OrderedDict()
    data_loaders["train"] = DataLoader(train_dataset,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       num_workers=num_workers)
    data_loaders["valid"] = DataLoader(valid_dataset,
                                       batch_size=batch_size,
                                       shuffle=False,
                                       num_workers=num_workers)

    # seed before model creation, presumably for reproducible weight init
    set_global_seed(random_state)

    model = get_model(**config["model"])

    # optionally warm-start from a saved checkpoint's model weights
    if CHECKPOINT != "" and os.path.exists(CHECKPOINT):
        checkpoint_state = torch.load(CHECKPOINT)["model_state_dict"]
        model.load_state_dict(checkpoint_state)
        print(f"Using {CHECKPOINT} checkpoint", flush=True)

    model = model.to(DEVICE)

    model_optimizer = get_optimizer(model.parameters(), **config["optimizer"])

    loss_function = get_loss(**config["loss"])
    metric = config.get("metric", "loss")
    is_metric_minimization = config.get("minimize metric", True)

    # reduce LR on metric plateau; direction follows the main metric's mode
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        model_optimizer,
        mode="min" if is_metric_minimization else "max",
        patience=3,
        factor=0.2,
        verbose=True,
    )

    runner = SupervisedRunner(device=DEVICE)
    runner.train(
        model=model,
        criterion=loss_function,
        optimizer=model_optimizer,
        loaders=data_loaders,
        logdir=LOGDIR,
        callbacks=[
            cbks.DiceCallback(),
            cbks.IouCallback(),
            # PositiveAndNegativeDiceMetricCallback(),
            # ChannelviseDiceMetricCallback(),
            # MulticlassDiceMetricCallback(
            #     class_names=zip(range(4), list('0123')),
            #     avg_classes=list('0123')
            # ),
            cbks.CriterionCallback(),
            cbks.OptimizerCallback(
                accumulation_steps=4),  # accumulate gradients of 4 batches
            CheckpointCallback(save_n_best=3),
        ],
        scheduler=scheduler,
        verbose=True,
        minimize_metric=is_metric_minimization,
        num_epochs=num_epochs,
        main_metric=metric,
    )
Example #12
0
def main_worker(args, unknown_args):
    """Runs main worker thread from model training.

    Wraps the config-API experiment in an Optuna study: each trial
    patches a copy of the config, runs the experiment, and returns the
    best validation metric for Optuna to optimize.
    """
    args, config = parse_args_uargs(args, unknown_args)
    set_global_seed(args.seed)
    prepare_cudnn(args.deterministic, args.benchmark)

    config.setdefault("distributed_params", {})["apex"] = args.apex
    config.setdefault("distributed_params", {})["amp"] = args.amp
    expdir = Path(args.expdir)

    # optuna objective
    def objective(trial: optuna.Trial):
        # presumably injects trial-suggested values into a config copy —
        # confirm in _process_trial_config
        trial, trial_config = _process_trial_config(trial, config.copy())
        experiment, runner, trial_config = prepare_config_api_components(
            expdir=expdir, config=trial_config
        )
        # @TODO: here we need better solution.
        experiment._trial = trial  # noqa: WPS437

        # only the master process (rank <= 0) dumps environment and code
        if experiment.logdir is not None and get_rank() <= 0:
            dump_environment(trial_config, experiment.logdir, args.configs)
            dump_code(args.expdir, experiment.logdir)

        runner.run_experiment(experiment)

        return runner.best_valid_metrics[runner.main_metric]

    # optuna direction: follows the experiment's minimize_metric setting
    direction = (
        "minimize"
        if config.get("stages", {})
        .get("stage_params", {})
        .get("minimize_metric", True)
        else "maximize"
    )

    # optuna study
    study_params = config.pop("study_params", {})

    # optuna sampler: class looked up by name in optuna.samplers
    sampler_params = study_params.pop("sampler_params", {})
    optuna_sampler_type = sampler_params.pop("sampler", None)
    optuna_sampler = (
        optuna.samplers.__dict__[optuna_sampler_type](**sampler_params)
        if optuna_sampler_type is not None
        else None
    )

    # optuna pruner: class looked up by name in optuna.pruners
    pruner_params = study_params.pop("pruner_params", {})
    optuna_pruner_type = pruner_params.pop("pruner", None)
    optuna_pruner = (
        optuna.pruners.__dict__[optuna_pruner_type](**pruner_params)
        if optuna_pruner_type is not None
        else None
    )

    # CLI flags take precedence over config-provided study settings
    study = optuna.create_study(
        direction=direction,
        storage=args.storage or study_params.pop("storage", None),
        study_name=args.study_name or study_params.pop("study_name", None),
        sampler=optuna_sampler,
        pruner=optuna_pruner,
    )
    study.optimize(
        objective,
        n_trials=args.n_trials,
        timeout=args.timeout,
        n_jobs=args.n_jobs or 1,
        gc_after_trial=args.gc_after_trial,
        show_progress_bar=args.show_progress_bar,
    )