コード例 #1
0
    def __init__(
        self,
        metric_key: str,
        optimizer_key: str = None,
        accumulation_steps: int = 1,
        grad_clip_fn: Union[str, Callable] = None,
        grad_clip_params: Dict = None,
    ):
        """Init.

        Args:
            metric_key: key of the metric this callback works with
                (presumably the loss to backpropagate — confirm with callers).
            optimizer_key: optional optimizer key; when set it is appended to
                the ``lr``/``momentum``/``gradient`` logging prefixes.
            accumulation_steps: stored as ``self.accumulation_steps``
                (presumably the number of batches to accumulate gradients
                over — confirm).
            grad_clip_fn: gradient-clipping callable, or its name in
                ``REGISTRY``.
            grad_clip_params: keyword arguments bound to ``grad_clip_fn`` via
                ``functools.partial``.
        """
        super().__init__()
        self.metric_key = metric_key
        self.optimizer_key = optimizer_key
        # Not set by this constructor; left as placeholders.
        self.optimizer = None
        self.criterion = None

        # Strings are looked up in the registry, anything else is used as-is.
        clip_fn = (
            REGISTRY.get(grad_clip_fn)
            if isinstance(grad_clip_fn, str)
            else grad_clip_fn
        )
        if grad_clip_params is not None:
            clip_fn = partial(clip_fn, **grad_clip_params)
        self.grad_clip_fn = clip_fn

        self.accumulation_steps: int = accumulation_steps
        self._accumulation_counter: int = 0

        # Logging prefixes: "lr" or "lr/<optimizer_key>", and likewise for
        # momentum and gradient.
        if self.optimizer_key is not None:
            self._prefix = f"{self.optimizer_key}"
            suffix = f"/{self._prefix}"
        else:
            suffix = ""
        self._prefix_lr = f"lr{suffix}"
        self._prefix_momentum = f"momentum{suffix}"
        self._prefix_gradient = f"gradient{suffix}"
コード例 #2
0
ファイル: config.py プロジェクト: ricklentz/catalyst
 def _get_optimizer_from_params(self, model: RunnerModel, stage: str,
                                **params) -> RunnerOptimizer:
     """Build the optimizer for ``stage`` from an optimizer config.

     The keys ``lr_linear_scaling``, ``layerwise_params``,
     ``no_bias_weight_decay`` and ``_model`` are consumed here; every other
     key is forwarded to ``REGISTRY.get_from_params`` together with the
     model parameters produced by ``get_model_parameters``.

     Args:
         model: model (or models) whose parameters are optimized
         stage: stage name, used to read ``batch_size`` / ``per_gpu_scaling``
             from the stage's ``loaders`` section when LR scaling is enabled
         **params: optimizer config; copied, never modified

     Returns:
         the optimizer built by the registry
     """
     # @TODO 1: refactor; this method is too long
     optimizer_params = deepcopy(params)

     # Optional learning-rate linear scaling driven by the loaders' batch size.
     lr_scaling = 1.0
     scaling_config = optimizer_params.pop("lr_linear_scaling", None)
     if scaling_config:
         stage_loaders = dict(self._stage_config[stage]["loaders"])
         scaled_lr, lr_scaling = do_lr_linear_scaling(
             lr_scaling_params=scaling_config,
             batch_size=stage_loaders.get("batch_size", 1),
             per_gpu_scaling=stage_loaders.get("per_gpu_scaling", False),
         )
         optimizer_params["lr"] = scaled_lr

     # Parameter groups (layer-wise options, bias weight-decay handling,
     # optional model-key selection).
     model_params = get_model_parameters(
         models=model,
         models_keys=optimizer_params.pop("_model", None),
         layerwise_params=optimizer_params.pop("layerwise_params",
                                               OrderedDict()),
         no_bias_weight_decay=optimizer_params.pop("no_bias_weight_decay",
                                                   True),
         lr_scaling=lr_scaling,
     )
     return REGISTRY.get_from_params(**optimizer_params, params=model_params)
コード例 #3
0
ファイル: config.py プロジェクト: DimaOrekhov/catalyst
    def get_callbacks(self, stage: str) -> "OrderedDict[str, Callback]":
        """Build the callbacks for ``stage`` and add missing default ones.

        Callbacks come from ``<stage>.callbacks`` in the stage config (``{}``
        when absent). Depending on the runner flags, Tqdm/Timer/CheckRun/
        BatchOverfit callbacks are added under ``_verbose``/``_timer``/
        ``_check``/``_overfit``, and a ``CheckpointCallback`` writing to
        ``<logdir>/checkpoints`` is added when ``self._logdir`` is set. Each
        default is skipped if a callback of that type is already present.
        """
        callbacks_params = get_by_keys(
            self._stage_config, stage, "callbacks", default={}
        )
        callbacks = OrderedDict(REGISTRY.get_from_params(**callbacks_params))

        def _has(callback_fn) -> bool:
            # Re-evaluated on each call, so defaults added earlier count too.
            return any(
                callback_isinstance(existing, callback_fn)
                for existing in callbacks.values()
            )

        flag_defaults = (
            (self._verbose, TqdmCallback, "_verbose"),
            (self._timeit, TimerCallback, "_timer"),
            (self._check, CheckRunCallback, "_check"),
            (self._overfit, BatchOverfitCallback, "_overfit"),
        )
        for enabled, callback_cls, key in flag_defaults:
            if enabled and not _has(callback_cls):
                callbacks[key] = callback_cls()

        if self._logdir is not None and not _has(ICheckpointCallback):
            callbacks["_checkpoint"] = CheckpointCallback(
                logdir=os.path.join(self._logdir, "checkpoints"), )

        return callbacks
コード例 #4
0
ファイル: sys.py プロジェクト: Podidiving/catalyst
def get_config_runner(expdir: Path, config: Dict):
    """
    Imports the experiment directory (if given) and creates a ConfigRunner.

    Args:
        expdir: experiment directory path; when not ``None`` it is passed to
            ``import_module`` before the runner is resolved (presumably so
            user code can register components in ``REGISTRY`` — confirm)
        config: dictionary with experiment Config. It must contain
            ``config["runner"]["_target_"]`` naming the runner in
            ``REGISTRY``. The dict is deep-copied and never modified.

    Returns:
        ConfigRunner instance, created as
        ``runner_fn(config=<copy>, **<runner section without _target_>)``

    Raises:
        ValueError: if ``config["runner"]["_target_"]`` is missing
    """
    config_copy = copy.deepcopy(config)

    if expdir is not None:
        # Imported only for its side effects; the module object is unused.
        import_module(expdir)

    runner_params = config_copy.get("runner", {})
    # Popping from the copy means the runner receives a config whose
    # "runner" section no longer contains ``_target_``.
    runner_target = runner_params.pop("_target_", None)
    if runner_target is None:
        # Explicit raise instead of ``assert``: asserts vanish under ``-O``.
        raise ValueError("You should specify the ConfigRunner.")
    runner_fn = REGISTRY.get(runner_target)

    return runner_fn(config=config_copy, **runner_params)
コード例 #5
0
ファイル: config.py プロジェクト: DimaOrekhov/catalyst
 def get_criterion(self, stage: str) -> RunnerCriterion:
     """Return the criterion configured for ``stage``.

     Reads ``<stage>.criterion`` from the stage config (``{}`` when absent)
     and builds it with ``REGISTRY.get_from_params``. A falsy result, such as
     an empty dict, is normalized to ``None``.
     """
     params = get_by_keys(self._stage_config, stage, "criterion", default={})
     criterion = REGISTRY.get_from_params(**params)
     if not criterion:
         return None
     return criterion
コード例 #6
0
    def __init__(self, in_features, activation_fn="Sigmoid"):
        """Build ``self.attn``: a bias-free 1x1 ``nn.Conv2d`` to one channel,
        followed by an activation.

        Args:
            in_features: number of input channels of the 1x1 convolution.
            activation_fn: activation class or its registry name (resolved
                with ``REGISTRY.get_if_str``); instantiated without arguments.
        """
        super().__init__()

        activation_cls = REGISTRY.get_if_str(activation_fn)
        conv = nn.Conv2d(
            in_features, 1, kernel_size=1, stride=1, padding=0, bias=False
        )
        self.attn = nn.Sequential(conv, activation_cls())
コード例 #7
0
ファイル: config.py プロジェクト: ricklentz/catalyst
 def _get_callback_from_params(**params):
     """Instantiate a callback from config ``params``, with optional wrapping.

     When ``params`` contains ``_wrapper``, the callback built from the other
     keys is injected as ``base_callback`` into the wrapper config. That
     config is then built recursively, so wrappers may themselves be wrapped.
     """
     callback_params = deepcopy(params)
     wrapper_params = callback_params.pop("_wrapper", None)
     callback = REGISTRY.get_from_params(**callback_params)
     if wrapper_params is None:
         return callback
     wrapper_params["base_callback"] = callback
     return ConfigRunner._get_callback_from_params(
         **wrapper_params)  # noqa: WPS437
コード例 #8
0
    def __init__(self,
                 vocab,
                 cleaners=None,
                 g2p=None,
                 words_separator="\t",
                 batch_size=1):
        """Processor initialization.

        Parameters
        ----------
        vocab : List[str]
            List of all tokens that will be used after text processing.
            Use a phoneme list if you want to use g2p, or graphemes
            (alphabet characters) otherwise.
        cleaners : Union[List[Callable], List[dict]], optional
            List of cleaner callables, or their config dicts (built with
            ``REGISTRY.get_from_params``). By default no cleaners.
        g2p : Union[Callable, dict], optional
            g2p callable object, or its config dict.
        words_separator : str, optional
            Token that separates words, by default "\t"
        batch_size : int, optional
            Batch size for data processing, by default 1
        """
        self.vocab = vocab
        self.words_separator = words_separator
        self.batch_size = batch_size

        # Ids start at 1: id 0 is reserved for padding.
        self.token2id = {token: i for i, token in enumerate(self.vocab, 1)}

        # ``None`` default instead of a shared mutable ``[]`` default.
        self.cleaners = [
            REGISTRY.get_from_params(**cleaner)
            if isinstance(cleaner, dict) else cleaner
            for cleaner in (cleaners or [])
        ]

        if isinstance(g2p, dict):
            g2p = REGISTRY.get_from_params(**g2p)

        self.g2p = g2p
コード例 #9
0
    def __init__(
        self,
        arch: str = "resnet18",
        pretrained: bool = True,
        frozen: bool = True,
        pooling: str = None,
        pooling_kwargs: dict = None,
        cut_layers: int = 2,
        state_dict: Union[dict, str, Path] = None,
    ):
        """
        Args:
            arch: Name for resnet. Have to be one of
                resnet18, resnet34, resnet50, resnet101, resnet152
            pretrained: If True, returns a model pre-trained on ImageNet
            frozen: If frozen, sets requires_grad to False on the kept
                resnet children (the pooling layer is not frozen)
            pooling: optional registry name of a pooling layer appended after
                the kept resnet children
            pooling_kwargs: params for pooling
            cut_layers: number of trailing resnet children to drop;
                ``0`` keeps all of them
            state_dict (Union[dict, str, Path]): Path to ``torch.Model``
                or a dict containing parameters and persistent buffers.
        """
        super().__init__()

        resnet = torchvision.models.__dict__[arch](pretrained=pretrained)
        if state_dict is not None:
            if isinstance(state_dict, (Path, str)):
                state_dict = torch.load(str(state_dict))
            resnet.load_state_dict(state_dict)

        children = list(resnet.children())
        # Explicit end index: ``children[:-0]`` would be empty, so the
        # original negative slice broke ``cut_layers=0``.
        modules = children[:max(len(children) - cut_layers, 0)]
        in_features = resnet.fc.in_features

        if frozen:
            for module in modules:
                utils.set_requires_grad(module, requires_grad=False)

        if pooling is not None:
            pooling_kwargs = pooling_kwargs or {}
            pooling_layer_fn = REGISTRY.get(pooling)
            # Poolings whose name contains "attn" also receive
            # ``in_features=resnet.fc.in_features``.
            if "attn" in pooling.lower():
                pooling_layer = pooling_layer_fn(in_features=in_features,
                                                 **pooling_kwargs)
            else:
                pooling_layer = pooling_layer_fn(**pooling_kwargs)
            modules += [pooling_layer]

            if hasattr(pooling_layer, "out_features"):
                out_features = pooling_layer.out_features(
                    in_features=in_features)
            else:
                out_features = None
        else:
            out_features = in_features

        modules += [Flatten()]
        self.out_features = out_features

        self.encoder = nn.Sequential(*modules)
コード例 #10
0
ファイル: lookahead.py プロジェクト: Podidiving/catalyst
    def get_from_params(cls,
                        params: Dict,
                        base_optimizer_params: Dict = None,
                        **kwargs) -> "Lookahead":
        """Create a Lookahead wrapping a base optimizer built from config.

        Args:
            params: parameters to optimize; forwarded as ``params`` to the
                base optimizer
            base_optimizer_params: registry config of the base optimizer
                (presumably including its ``_target_`` — confirm). Required
                despite the ``None`` default.
            **kwargs: extra keyword arguments passed to ``cls``

        Returns:
            Lookahead: ``cls(optimizer=<base optimizer>, **kwargs)``

        Raises:
            TypeError: if ``base_optimizer_params`` is ``None``
        """
        from catalyst.registry import REGISTRY

        if base_optimizer_params is None:
            # Previously failed with an opaque "argument after ** must be a
            # mapping" error; same exception type, clearer message.
            raise TypeError(
                "`base_optimizer_params` is required to build the base "
                "optimizer")

        base_optimizer = REGISTRY.get_from_params(params=params,
                                                  **base_optimizer_params)
        return cls(optimizer=base_optimizer, **kwargs)
コード例 #11
0
ファイル: config.py プロジェクト: DimaOrekhov/catalyst
 def get_engine(self) -> IEngine:
     """Return the engine for the run.

     Uses the ``engine`` config section when present. Otherwise the engine
     comes from ``get_available_engine`` using the runner's fp16/ddp/amp/apex
     flags.
     """
     engine_params = self._config.get("engine", None)
     if engine_params is None:
         return get_available_engine(fp16=self._fp16,
                                     ddp=self._ddp,
                                     amp=self._amp,
                                     apex=self._apex)
     return REGISTRY.get_from_params(**engine_params)
コード例 #12
0
ファイル: processor.py プロジェクト: alxmamaev/ultimate_tts
    def __init__(self,
                 batch_size=1,
                 mel_extractor=None,
                 speaker_embedding_extractor=None,
                 prossody_extractor=None,
                 wav_max_value=32768):
        """Feature processor initialization.

        Parameters
        ----------
        batch_size : int, optional
            Batch size for data processing, by default 1
        mel_extractor : optional
            Mel extractor object (kept as-is) or its config dict.
        speaker_embedding_extractor : optional
            Speaker embedding extractor object (kept as-is) or its config dict.
        prossody_extractor : optional
            Prosody extractor object (kept as-is) or its config dict. The
            parameter keeps its historical spelling.
        wav_max_value : int, optional
            Stored as ``self.wav_max_value``, by default 32768 (presumably
            the int16 full-scale value used to normalize waveforms — confirm).
        """
        def _build(extractor):
            # Config dicts go through the registry; objects and None pass.
            if isinstance(extractor, dict):
                return REGISTRY.get_from_params(**extractor)
            return extractor

        self.batch_size = batch_size
        self.wav_max_value = wav_max_value
        self.mel_extractor = _build(mel_extractor)
        self.speaker_embedding_extractor = _build(speaker_embedding_extractor)
        self.prossody_extractor = _build(prossody_extractor)
コード例 #13
0
ファイル: config.py プロジェクト: ricklentz/catalyst
    def _get_criterion_from_params(**params) -> RunnerCriterion:
        """Build a criterion, or a dict of criteria, from config ``params``.

        With a truthy ``_key_value`` flag, every remaining item is a separate
        criterion config built recursively into a dict with the same keys.
        Otherwise ``params`` go straight to ``REGISTRY.get_from_params``.
        """
        criterion_params = deepcopy(params)
        if not criterion_params.pop("_key_value", False):
            return REGISTRY.get_from_params(**criterion_params)

        criteria = {}
        for name, sub_params in criterion_params.items():
            criteria[name] = ConfigRunner._get_criterion_from_params(
                **sub_params)  # noqa: WPS437
        return criteria
コード例 #14
0
ファイル: config.py プロジェクト: DimaOrekhov/catalyst
    def get_datasets(self, stage: str) -> "OrderedDict[str, Dataset]":
        """
        Build the datasets for a given stage.

        The config is read from
        ``self._stage_config[stage]["loaders"]["datasets"]``; a missing level
        raises ``KeyError``. It is instantiated with
        ``REGISTRY.get_from_params``.

        Args:
            stage: stage name

        Returns:
            Dict: datasets objects, as an ``OrderedDict``
        """
        loaders_config = self._stage_config[stage]["loaders"]
        built = REGISTRY.get_from_params(**loaders_config["datasets"])
        return OrderedDict(built)
コード例 #15
0
ファイル: config.py プロジェクト: Podidiving/catalyst
    def _get_model_from_params(**params) -> RunnerModel:
        """Build a model, or an ``nn.ModuleDict`` of models, from ``params``.

        With a truthy ``_key_value`` flag, each remaining item is a separate
        model config. Each is built recursively, and the results are wrapped
        in ``nn.ModuleDict``. Otherwise ``params`` go straight to
        ``REGISTRY.get_from_params``.
        """
        model_params = deepcopy(params)
        if not model_params.pop("_key_value", False):
            return REGISTRY.get_from_params(**model_params)

        return nn.ModuleDict({
            key: ConfigRunner._get_model_from_params(**sub_params)
            for key, sub_params in model_params.items()
        })
コード例 #16
0
    def _get_scheduler_from_params(*, optimizer: RunnerOptimizer,
                                   **params) -> RunnerScheduler:
        """Build a scheduler, or a dict of schedulers, for ``optimizer``.

        An optional ``_optimizer`` key selects ``optimizer[<key>]`` before
        anything is built. With a truthy ``_key_value`` flag, each remaining
        item is a separate scheduler config, built recursively with the
        (possibly selected) optimizer.
        """
        scheduler_params = deepcopy(params)
        is_key_value = scheduler_params.pop("_key_value", False)
        optimizer_key = scheduler_params.pop("_optimizer", None)
        if optimizer_key:
            optimizer = optimizer[optimizer_key]

        if not is_key_value:
            return REGISTRY.get_from_params(**scheduler_params,
                                            optimizer=optimizer)

        return {
            key: ConfigRunner._get_scheduler_from_params(
                **sub_params, optimizer=optimizer)  # noqa: WPS437
            for key, sub_params in scheduler_params.items()
        }
コード例 #17
0
ファイル: config.py プロジェクト: DimaOrekhov/catalyst
    def get_samplers(self, stage: str) -> "OrderedDict[str, Sampler]":
        """
        Build the samplers for a given stage.

        Reads ``<stage>.loaders.samplers`` from the stage config (``{}`` when
        absent) and instantiates it with ``REGISTRY.get_from_params``.

        Args:
            stage: stage name

        Returns:
            Dict of samplers (an ``OrderedDict``)
        """
        config = get_by_keys(self._stage_config, stage, "loaders", "samplers",
                             default={})
        return OrderedDict(REGISTRY.get_from_params(**config))
コード例 #18
0
ファイル: config.py プロジェクト: Podidiving/catalyst
    def get_loggers(self) -> Dict[str, ILogger]:
        """Build the run's loggers and add missing default ones.

        The ``loggers`` config section (``{}`` when absent) is built with
        ``REGISTRY.get_from_params``. A ``ConsoleLogger`` is always ensured.
        When ``self._logdir`` is set, ``CSVLogger`` and ``TensorboardLogger``
        (both with ``use_logdir_postfix=True``) are ensured as well. Each
        default is added only if no logger of that type exists yet.
        """
        loggers = REGISTRY.get_from_params(**self._config.get("loggers", {}))

        def _missing(logger_cls) -> bool:
            return not any(
                isinstance(logger, logger_cls) for logger in loggers.values())

        if _missing(ConsoleLogger):
            loggers["_console"] = ConsoleLogger()
        if self._logdir is not None:
            if _missing(CSVLogger):
                loggers["_csv"] = CSVLogger(logdir=self._logdir,
                                            use_logdir_postfix=True)
            if _missing(TensorboardLogger):
                loggers["_tensorboard"] = TensorboardLogger(
                    logdir=self._logdir, use_logdir_postfix=True)

        return loggers
コード例 #19
0
def main(args):
    """Run every configured dataset preprocessor except the ignored ones.

    Args:
        args: parsed CLI arguments. Uses ``args.config`` (path to a YAML
            config) and ``args.ignore_processors`` (comma-separated processor
            names, or ``None``).
    """
    # NOTE(review): ``yaml.Loader`` can construct arbitrary Python objects
    # from tagged YAML. Only load trusted config files here; consider
    # ``yaml.safe_load`` if the configs need no Python-specific tags.
    with open(args.config, encoding="utf-8") as f:
        config = yaml.load(f, Loader=yaml.Loader)

    preprocessing_params = config["dataset_preprocessing_params"]
    verbose = preprocessing_params["verbose"]

    # Tolerate spaces around commas and empty entries ("a, b," -> {a, b}).
    ignore_processors = set()
    if args.ignore_processors is not None:
        ignore_processors = {
            name.strip()
            for name in args.ignore_processors.split(",")
            if name.strip()
        }

    for processor_name, processing_params in preprocessing_params[
            "processors"].items():
        if processor_name in ignore_processors:
            print(f"Ignore {processor_name}")
            continue

        # The processor's constructor config is the top-level config section
        # named after the processor.
        processor = REGISTRY.get_from_params(**config[processor_name])
        processor.process_files(processing_params["inputs"],
                                processing_params["outputs"],
                                verbose=verbose)
コード例 #20
0
ファイル: config.py プロジェクト: ricklentz/catalyst
    def get_loggers(self) -> Dict[str, ILogger]:
        """Build the run's loggers and add missing default ones.

        Each entry of the ``loggers`` config section (``{}`` when absent) is
        built separately with ``REGISTRY.get_from_params``. A
        ``ConsoleLogger`` is always ensured. When ``self._logdir`` is set, a
        ``CSVLogger`` (in ``logdir``) and a ``TensorboardLogger`` (in
        ``<logdir>/tensorboard``) are ensured too. Each default is added only
        if no logger of that type exists yet.
        """
        loggers = {}
        for name, logger_params in self._config.get("loggers", {}).items():
            loggers[name] = REGISTRY.get_from_params(**logger_params)

        def _missing(logger_cls) -> bool:
            return not any(
                isinstance(logger, logger_cls) for logger in loggers.values())

        if _missing(ConsoleLogger):
            loggers["_console"] = ConsoleLogger()
        if self._logdir is not None:
            if _missing(CSVLogger):
                loggers["_csv"] = CSVLogger(logdir=self._logdir)
            if _missing(TensorboardLogger):
                loggers["_tensorboard"] = TensorboardLogger(
                    logdir=os.path.join(self._logdir, "tensorboard"))

        return loggers
コード例 #21
0
    def __init__(self, transforms, batch_size=1):
        """AudioProcessor handles audio processing.

        The processor applies its transforms, in order, to a batch of audios.

        Parameters
        ----------
        transforms : Union[List[Callable], List[Dict]]
            Transform objects, or their config dicts (built with
            ``REGISTRY.get_from_params``); objects are kept as-is.
        batch_size : int, optional
            Batch size for data processing, by default 1
        """
        self.batch_size = batch_size
        self.transforms = [
            REGISTRY.get_from_params(**item) if isinstance(item, dict) else item
            for item in transforms
        ]
コード例 #22
0
ファイル: backward.py プロジェクト: catalyst-team/catalyst
    def __init__(
        self,
        metric_key: str,
        grad_clip_fn: Union[str, Callable] = None,
        grad_clip_params: Dict = None,
        log_gradient: bool = False,
    ):
        """Init.

        Args:
            metric_key: key of the metric this callback works with; also used
                in the gradient logging prefix ``gradient/<metric_key>``.
            grad_clip_fn: gradient-clipping callable, or its name in
                ``REGISTRY``.
            grad_clip_params: keyword arguments bound to ``grad_clip_fn`` via
                ``functools.partial``.
            log_gradient: stored as ``self._log_gradient``.
        """
        super().__init__()
        self.metric_key = metric_key

        clip_fn = grad_clip_fn
        if isinstance(clip_fn, str):
            clip_fn = REGISTRY.get(clip_fn)
        if grad_clip_params is not None:
            clip_fn = partial(clip_fn, **grad_clip_params)
        self.grad_clip_fn = clip_fn

        self._log_gradient = log_gradient
        self._prefix_gradient = f"gradient/{metric_key}"
コード例 #23
0
    def __init__(
        self,
        metric_key: str,
        model_key: str = None,
        optimizer_key: str = None,
        accumulation_steps: int = 1,
        grad_clip_fn: Union[str, Callable] = None,
        grad_clip_params: Dict = None,
    ):
        """Init.

        Args:
            metric_key: key of the metric this callback works with
                (presumably the loss to backpropagate — confirm).
            model_key: optional model key, used in the logging prefix.
            optimizer_key: optional optimizer key, used in the logging prefix.
            accumulation_steps: stored as ``self.accumulation_steps``
                (presumably the number of batches to accumulate gradients
                over — confirm).
            grad_clip_fn: gradient-clipping callable, or its name in
                ``REGISTRY``.
            grad_clip_params: keyword arguments bound to ``grad_clip_fn`` via
                ``functools.partial``.
        """
        super().__init__(order=CallbackOrder.optimizer, node=CallbackNode.all)
        self.metric_key = metric_key
        self.model_key = model_key
        self.optimizer_key = optimizer_key
        # Not set by this constructor; left as placeholders.
        self.model = None
        self.optimizer = None
        self.criterion = None

        self.grad_clip_fn = (
            REGISTRY.get(grad_clip_fn)
            if isinstance(grad_clip_fn, str)
            else grad_clip_fn
        )

        self.accumulation_steps: int = accumulation_steps
        self._accumulation_counter: int = 0

        # Prefix is "<model>_<optimizer>", "<model>" or "<optimizer>",
        # depending on which keys are given; "lr"/"momentum" without keys.
        given_keys = [
            f"{key}" for key in (self.model_key, self.optimizer_key)
            if key is not None
        ]
        if given_keys:
            self._prefix = "_".join(given_keys)
            self._prefix_lr = f"lr/{self._prefix}"
            self._prefix_momentum = f"momentum/{self._prefix}"
        else:
            self._prefix_lr = "lr"
            self._prefix_momentum = "momentum"

        if grad_clip_params is not None:
            self.grad_clip_fn = partial(self.grad_clip_fn, **grad_clip_params)
コード例 #24
0
    def __init__(
        self,
        transform: Sequence[Union[dict, nn.Module]],
        input_key: Union[str, int] = "image",
        output_key: Optional[Union[str, int]] = None,
    ) -> None:
        """Init.

        Args:
            transform: sequence of ``nn.Module`` transforms or their registry
                config dicts. They are chained in order with ``nn.Sequential``.
            input_key: stored as ``self.input_key`` (presumably the batch key
                to read — confirm).
            output_key: stored as ``self.output_key``; a falsy value falls
                back to ``input_key``.
        """
        super().__init__(order=CallbackOrder.Internal, node=CallbackNode.all)

        self.input_key = input_key
        self.output_key = output_key or self.input_key

        transforms: Sequence[nn.Module] = []
        for item in transform:
            if not isinstance(item, nn.Module):
                item = REGISTRY.get_from_params(**item)
            transforms.append(item)
        assert all(
            isinstance(t, nn.Module) for t in
            transforms), "`nn.Module` should be a base class for transforms"

        self.transform = nn.Sequential(*transforms)
コード例 #25
0
    def __init__(
        self,
        transform: Union[Callable, str],
        scope: str,
        input_key: Union[List[str], str] = None,
        output_key: Union[List[str], str] = None,
        transform_kwargs: Dict[str, Any] = None,
    ):
        """
        Preprocess your batch with specified function.

        Args:
            transform (Callable, str): Function to apply.
                If string will get function from registry.
            scope (str): ``"on_batch_end"`` (post-processing model output) or
                ``"on_batch_start"`` (pre-processing model input).
            input_key (Union[List[str], str], optional): Keys in batch dict to
                apply function. A single string is wrapped into a list. When
                given, ``self._handle_value`` is used as the batch handler,
                otherwise ``self._handle_key_value``.
                Defaults to ``None``.
            output_key (Union[List[str], str], optional): Keys for output.
                If None then will apply function inplace to ``input_key``.
                A single string is wrapped into a list and the transform is
                wrapped with ``_tuple_wrapper``.
                Defaults to ``None``.
            transform_kwargs (Dict[str, Any]): Kwargs for transform.

        Raises:
            TypeError: When ``input_key`` or ``output_key`` is not str or a
                list, when ``output_key`` is given without ``input_key``, or
                when ``scope`` is not in ``["on_batch_end", "on_batch_start"]``.
        """
        super().__init__(order=CallbackOrder.Internal)
        if isinstance(transform, str):
            transform = REGISTRY.get(transform)
        if transform_kwargs is not None:
            transform = partial(transform, **transform_kwargs)

        if input_key is not None:
            if not isinstance(input_key, (list, str)):
                raise TypeError("input key should be str or a list of str.")
            if isinstance(input_key, str):
                input_key = [input_key]
            self._handle_batch = self._handle_value
        else:
            self._handle_batch = self._handle_key_value

        # Without an explicit output_key, results go back to the input keys.
        output_key = output_key or input_key
        if output_key is not None:
            if input_key is None:
                raise TypeError("You should define input_key in "
                                "case if output_key is not None")
            if not isinstance(output_key, (list, str)):
                raise TypeError("output key should be str or a list of str.")
            if isinstance(output_key, str):
                output_key = [output_key]
                transform = _tuple_wrapper(transform)

        if not (isinstance(scope, str)
                and scope in ["on_batch_end", "on_batch_start"]):
            raise TypeError(
                'Expected scope to be one of the ["on_batch_end", "on_batch_start"]'
            )
        self.scope = scope
        self.input_key = input_key
        self.output_key = output_key
        self.transform = transform
コード例 #26
0
from .tts import models
from .tts.layers import losses
from . import callbacks
from . import runners
from .utils import text, audio, features, forced_aligner, durations_extractor
from .dataset import text_mel_collate_fn
from catalyst.registry import REGISTRY

# -- Register processing modules with the catalyst REGISTRY --
# Modules are registered in the order listed below.
for _module in (
    # text processing
    text.processor,
    text.cleaners,
    text.g2p,
    text.normalizers,
    # audio processing
    audio.processor,
    audio.transforms,
    # feature extraction
    features.processor,
    features.mel,
    features.xvectors,
    features.prosody,
    # forced alignment extraction
    forced_aligner.processor,
    # durations extraction
    durations_extractor.processor,
):
    REGISTRY.add_from_module(_module)
# Keep the module namespace free of the loop variable.
del _module

# -- Register data modules --
コード例 #27
0
def get_loaders_from_params(
    batch_size: int = 1,
    num_workers: int = 0,
    drop_last: bool = False,
    per_gpu_scaling: bool = False,
    loaders_params: Dict[str, Any] = None,
    samplers_params: Dict[str, Any] = None,
    initial_seed: int = 42,
    datasets_fn: Callable = None,
    **data_params,
) -> "OrderedDict[str, DataLoader]":
    """
    Creates pytorch dataloaders from datasets and additional parameters.

    None of the passed-in mappings are modified: ``loaders_params``,
    ``samplers_params``, their per-loader dicts and dict datasources are
    shallow-copied before keys are popped from them. Calling the function
    twice with the same config therefore gives the same loaders.

    Args:
        batch_size: ``batch_size`` parameter
            from ``torch.utils.data.DataLoader``
        num_workers: ``num_workers`` parameter
            from ``torch.utils.data.DataLoader``
        drop_last: ``drop_last`` parameter
            from ``torch.utils.data.DataLoader``
        per_gpu_scaling: boolean flag,
            if ``True``, scales batch_size in proportion to the number of GPUs
        loaders_params (Dict[str, Any]): additional loaders parameters
        samplers_params (Dict[str, Any]): additional sampler parameters
        initial_seed: initial seed for ``torch.utils.data.DataLoader``
            workers
        datasets_fn(Callable): callable function to get dictionary with
            ``torch.utils.data.Datasets``
        **data_params: additional data parameters
            or dictionary with ``torch.utils.data.Datasets`` to use for
            pytorch dataloaders creation

    Returns:
        OrderedDict[str, DataLoader]: dictionary with
            ``torch.utils.data.DataLoader``

    Raises:
        NotImplementedError: if datasource is out of `Dataset` or dict
        ValueError: if batch_sampler option is mutually
            exclusive with distributed, or if ``batch_size`` is not divisible
            by the world size in distributed mode without ``per_gpu_scaling``
    """
    from catalyst.data.sampler import DistributedSamplerWrapper

    default_batch_size = batch_size
    default_num_workers = num_workers
    loaders_params = loaders_params or {}
    assert isinstance(loaders_params,
                      dict), (f"`loaders_params` should be a Dict. "
                              f"Got: {loaders_params}")
    samplers_params = samplers_params or {}
    assert isinstance(
        samplers_params,
        dict), f"`samplers_params` should be a Dict. Got: {samplers_params}"

    distributed_rank = get_rank()
    distributed = distributed_rank > -1

    if datasets_fn is not None:
        datasets = datasets_fn(**data_params)
    else:
        datasets = dict(**data_params)

    loaders = OrderedDict()
    for name, datasource in datasets.items():  # noqa: WPS426
        assert isinstance(
            datasource,
            (Dataset, dict
             )), f"{datasource} should be Dataset or Dict. Got: {datasource}"

        # ``.get`` + copy instead of ``.pop``: the caller's config must not
        # be consumed (the original emptied it on the first call).
        loader_params = loaders_params.get(name, {})
        assert isinstance(loader_params,
                          dict), f"{loader_params} should be Dict"
        loader_params = dict(loader_params)

        datasource_sampler = None
        if isinstance(datasource, dict):
            datasource = dict(datasource)
            datasource_sampler = datasource.pop("sampler", None)

        sampler_params = samplers_params.get(name)
        if sampler_params is None:
            sampler = datasource_sampler
        else:
            # An explicit sampler config overrides (and discards) a sampler
            # embedded in the datasource dict.
            sampler = REGISTRY.get_from_params(**sampler_params)

        batch_size = loader_params.pop("batch_size", default_batch_size)
        num_workers = loader_params.pop("num_workers", default_num_workers)

        if per_gpu_scaling and not distributed:
            num_gpus = max(1, torch.cuda.device_count())
            batch_size *= num_gpus
            num_workers *= num_gpus
        elif not per_gpu_scaling and distributed:
            world_size = get_distributed_params().pop("world_size", 1)
            if batch_size % world_size != 0:
                raise ValueError(
                    "For this distributed mode with per_gpu_scaling = False "
                    "you need to have batch_size divisible by number of GPUs")
            batch_size = int(batch_size / world_size)

        # Explicit per-loader params win over the computed defaults.
        loader_params = {
            "batch_size": batch_size,
            "num_workers": num_workers,
            "pin_memory": torch.cuda.is_available(),
            "drop_last": drop_last,
            **loader_params,
        }

        if isinstance(datasource, Dataset):
            loader_params["dataset"] = datasource
        elif isinstance(datasource, dict):
            assert "dataset" in datasource, "You need to specify dataset for dataloader"
            loader_params = merge_dicts(datasource, loader_params)
        else:
            raise NotImplementedError

        if distributed:
            if sampler is not None:
                if not isinstance(sampler, DistributedSampler):
                    sampler = DistributedSamplerWrapper(sampler=sampler)
            else:
                sampler = DistributedSampler(dataset=loader_params["dataset"])

        # Only "train*" loaders without an explicit sampler are shuffled.
        loader_params["shuffle"] = name.startswith("train") and sampler is None

        loader_params["sampler"] = sampler

        if "batch_sampler" in loader_params:
            if distributed:
                raise ValueError("batch_sampler option is mutually "
                                 "exclusive with distributed")

            # DataLoader rejects these together with a batch_sampler.
            for k in ("batch_size", "shuffle", "sampler", "drop_last"):
                loader_params.pop(k, None)

        if "worker_init_fn" not in loader_params:
            loader_params["worker_init_fn"] = partial(
                _worker_init_fn, initial_seed=initial_seed)

        loaders[name] = DataLoader(**loader_params)

    return loaders
コード例 #28
0
ファイル: tts_runner.py プロジェクト: alxmamaev/ultimate_tts
    def get_collate_fn(self):
        """Return the collate function named in the data config.

        The name is read from ``self._config["data_params"]["collate_fn"]``
        and resolved with ``REGISTRY.get``.
        """
        collate_fn_name = self._config["data_params"]["collate_fn"]
        return REGISTRY.get(collate_fn_name)
コード例 #29
0
 def get_engine(self) -> IEngine:
     """Return the engine for the run.

     The engine is built with ``REGISTRY.get_from_params`` from the
     ``engine`` config section, which this runner requires.

     Raises:
         ValueError: if the config has no ``engine`` section (or it is None)
     """
     engine_params = self._config.get("engine")
     if engine_params is None:
         # Without this check ``**None`` fails with an opaque TypeError.
         raise ValueError("`engine` section is required in the config")
     return REGISTRY.get_from_params(**engine_params)
コード例 #30
0
ファイル: config.py プロジェクト: DimaOrekhov/catalyst
 def _get_loaders_from_params(
         self, **params) -> "Optional[OrderedDict[str, DataLoader]]":
     """Creates dataloaders from ``**params`` parameters.

     The registry result is converted to a plain ``dict``. It is returned
     only if every value is a ``DataLoader`` (an empty mapping qualifies);
     otherwise ``None`` is returned.
     """
     loaders = dict(REGISTRY.get_from_params(**params))
     for loader in loaders.values():
         if not isinstance(loader, DataLoader):
             return None
     return loaders