def on_batch_end(self, runner: "IRunner") -> None:
    """On batch end event

    Args:
        runner: current runner
    """
    # Drop the cache when we exit to a nesting level
    # that's outside any instance of autocast.
    if torch.autocast_decrement_nesting() == 0:
        torch.clear_autocast_cache()
    torch.set_autocast_enabled(self.prev_autocast_state)

    if not runner.is_train_loader:
        return

    loss = runner.batch_metrics[self.metric_key]
    self._accumulation_counter += 1
    need_gradient_step = (
        self._accumulation_counter % self.accumulation_steps == 0
    )

    self.scaler.scale(loss).backward()

    if need_gradient_step:
        self.grad_step(
            optimizer=self._optimizer,
            grad_clip_fn=self.grad_clip_fn,
        )
        if not self.use_fast_zero_grad:
            maybe_recursive_call(self._optimizer, "zero_grad")
        else:
            maybe_recursive_call(self._optimizer, zero_grad)
        self._accumulation_counter = 0
def on_loader_start(self, runner: "IRunner"):
    """Event handler for loader start.

    Args:
        runner: IRunner instance.

    Raises:
        RunnerException: if current DataLoader is empty.
    """
    assert self.loader is not None
    self.loader_len = len(self.loader)
    if self.loader_len == 0:
        raise RunnerException(f"DataLoader with name {self.loader_key} is empty.")
    self.loader_batch_size = (
        self.loader.batch_sampler.batch_size
        if self.loader.batch_sampler is not None
        else self.loader.batch_size
    )
    self.loader_sample_step = 0

    self.is_train_loader = self.loader_key.startswith(SETTINGS.loader_train_prefix)
    self.is_valid_loader = self.loader_key.startswith(SETTINGS.loader_valid_prefix)
    self.is_infer_loader = self.loader_key.startswith(SETTINGS.loader_infer_prefix)
    maybe_recursive_call(self.model, "train", mode=self.is_train_loader)

    if isinstance(self.loader.sampler, DistributedSampler):
        self.loader.sampler.set_epoch(self.epoch)

    set_global_seed(self.experiment.initial_seed + self.global_epoch + 1)
def unpack_checkpoint(checkpoint, model=None, criterion=None, optimizer=None, scheduler=None):
    """@TODO: Docs. Contribution is welcome."""
    if model is not None:
        model = get_nn_from_ddp_module(model)
        maybe_recursive_call(
            model,
            "load_state_dict",
            recursive_args=checkpoint["model_state_dict"],
        )

    for dict2load, name2load in zip(
        [criterion, optimizer, scheduler],
        ["criterion", "optimizer", "scheduler"],
    ):
        if dict2load is None:
            continue

        if isinstance(dict2load, dict):
            for key, value in dict2load.items():
                if value is not None:
                    state_dict2load = f"{name2load}_{key}_state_dict"
                    value.load_state_dict(checkpoint[state_dict2load])
        else:
            name2load = f"{name2load}_state_dict"
            dict2load.load_state_dict(checkpoint[name2load])
def predict_loader(
    self,
    *,
    loader: DataLoader,
    model: Model = None,
    resume: str = None,
    fp16: Union[Dict, bool] = None,
    initial_seed: int = 42,
) -> Generator:
    """
    Runs model inference on PyTorch DataLoader and returns
    python generator with model predictions from `runner.predict_batch`.
    Cleans up the experiment info to avoid possible collisions.
    Sets `is_train_loader` and `is_valid_loader` to `False` while
    keeping `is_infer_loader` as `True`. Moves model to evaluation mode.

    Args:
        loader: loader to predict
        model: model to use for prediction
        resume: path to checkpoint to resume
        fp16 (Union[Dict, bool]): fp16 usage flag
        initial_seed: seed to use before prediction

    Yields:
        batches with model predictions
    """
    if isinstance(fp16, bool) and fp16:
        fp16 = {"opt_level": "O1"}

    if model is not None:
        self.model = model
    assert self.model is not None

    if resume is not None:
        checkpoint = load_checkpoint(resume)
        unpack_checkpoint(checkpoint, model=self.model)

    self.experiment = None
    set_global_seed(initial_seed)
    (model, _, _, _, device) = process_components(  # noqa: WPS122
        model=self.model,
        distributed_params=fp16,
        device=self.device,
    )
    self._prepare_inner_state(
        stage="infer",
        model=model,
        device=device,
        is_train_loader=False,
        is_valid_loader=False,
        is_infer_loader=True,
    )
    maybe_recursive_call(self.model, "train", mode=False)

    set_global_seed(initial_seed)
    for batch in loader:
        yield self.predict_batch(batch)
def pack_checkpoint(
    model: nn.Module = None,
    criterion: nn.Module = None,
    optimizer=None,
    scheduler=None,
    **kwargs,
):
    """
    Packs ``model``, ``criterion``, ``optimizer``, ``scheduler``
    and some extra info ``**kwargs`` to torch-based checkpoint.

    Args:
        model: torch model
        criterion: torch criterion
        optimizer: torch optimizer
        scheduler: torch scheduler
        **kwargs: some extra info to pack

    Returns:
        torch-based checkpoint with ``model_state_dict``,
        ``criterion_state_dict``, ``optimizer_state_dict``,
        ``scheduler_state_dict`` keys.
    """
    checkpoint = kwargs

    if isinstance(model, dict):
        for key, value in model.items():
            model_module = get_nn_from_ddp_module(value)
            checkpoint[f"model_{key}_state_dict"] = maybe_recursive_call(
                model_module, "state_dict"
            )
    else:
        model_module = get_nn_from_ddp_module(model)
        checkpoint["model_state_dict"] = maybe_recursive_call(model_module, "state_dict")

    for dict2save, name2save in zip(
        [criterion, optimizer, scheduler],
        ["criterion", "optimizer", "scheduler"],
    ):
        if dict2save is None:
            continue
        # @TODO refactor with maybe_recursive_call (?)
        if isinstance(dict2save, dict):
            for key, value in dict2save.items():
                if value is not None:
                    state_dict2save = name2save + "_" + str(key)
                    # checkpoint[name2save_] = value
                    state_dict2save = state_dict2save + "_state_dict"
                    checkpoint[state_dict2save] = value.state_dict()
        else:
            # checkpoint[name2save] = dict2save
            name2save = name2save + "_state_dict"
            checkpoint[name2save] = dict2save.state_dict()
    return checkpoint
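For illustration, a minimal usage sketch, assuming ``pack_checkpoint`` (and its helpers) are importable from the module above; the extra kwargs (``epoch``, ``valid_loss``) are arbitrary metadata chosen for this example, not a fixed API.

import torch
from torch import nn, optim

model = nn.Linear(10, 2)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10)

# any extra keyword arguments are stored in the checkpoint dict as-is
checkpoint = pack_checkpoint(
    model=model,
    criterion=nn.CrossEntropyLoss(),
    optimizer=optimizer,
    scheduler=scheduler,
    epoch=5,
    valid_loss=0.42,
)
torch.save(checkpoint, "checkpoint.pth")  # a plain, torch-serializable dict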
def _run_epoch(self, stage: str, epoch: int) -> None:
    """
    Inner method to run epoch on Runner,
    with epoch callbacks events.

    Args:
        stage: stage name of interest, like "pretrain" / "train" / "finetune" / etc
        epoch: epoch index
    """
    self._prepare_for_epoch(stage=stage, epoch=epoch)
    assert self.loaders is not None

    for loader_name, loader in self.loaders.items():
        if len(loader) == 0:
            raise RunnerException(f"DataLoader with name {loader_name} is empty.")

    self.is_infer_stage = self.stage_name.startswith("infer")
    if not self.is_infer_stage:
        assert self.valid_loader in self.loaders.keys(), (
            f"'{self.valid_loader}' "
            f"should be in provided loaders: {list(self.loaders.keys())}"
        )
    else:
        assert not any(
            x.startswith(SETTINGS.loader_train_prefix) for x in self.loaders.keys()
        ), "for inference no train loader should be passed"

    for loader_name, loader in self.loaders.items():
        self.loader_name = loader_name
        self.loader_len = len(loader)
        self.is_train_loader = loader_name.startswith(SETTINGS.loader_train_prefix)
        self.is_valid_loader = loader_name.startswith(SETTINGS.loader_valid_prefix)
        self.is_infer_loader = loader_name.startswith(SETTINGS.loader_infer_prefix)
        maybe_recursive_call(
            self.model,
            "train",
            mode=self.is_train_loader,
        )

        if isinstance(loader.sampler, DistributedSampler) and not self.is_infer_stage:
            loader.sampler.set_epoch(self.epoch)

        set_global_seed(self.experiment.initial_seed + self.global_epoch + 1)
        self._run_event("on_loader_start")
        with torch.set_grad_enabled(self.is_train_loader):
            self._run_loader(loader)
        self._run_event("on_loader_end")
def predict_loader(
    self,
    *,
    loader: DataLoader,
    model: TorchModel = None,
    engine: Union["Engine", str] = None,
    seed: int = 42,
    # extra info
    resume: str = None,
    # engine extra params,
    cpu: bool = False,
    fp16: bool = False,
) -> Generator:
    """
    Runs model inference on PyTorch DataLoader and returns
    python generator with model predictions from `runner.predict_batch`.

    Args:
        loader: loader to predict
        model: model to use for prediction
        engine: engine to use for prediction
        seed: random seed to use before prediction
        resume: path to checkpoint for model
        cpu: boolean flag to force CPU usage
        fp16: boolean flag to use half-precision

    Yields:
        batches with model predictions

    .. note::
        Please follow the `minimal examples`_ sections for use cases.

        .. _`minimal examples`: https://github.com/catalyst-team/catalyst#minimal-examples  # noqa: E501, W505
    """
    self.engine = engine or get_available_engine(cpu=cpu, fp16=fp16)

    if model is not None:
        self.model = model
    assert self.model is not None

    if resume is not None:
        self.engine.wait_for_everyone()
        unwrapped_model = self.engine.unwrap_model(self.model)
        unwrapped_model.load_state_dict(load_checkpoint(resume))

    self.model = self.engine.prepare(self.model)
    maybe_recursive_call(self.model, "train", mode=False)
    loader = self.engine.prepare(loader)

    set_global_seed(seed)
    for batch in loader:
        yield self.predict_batch(batch)
def on_batch_end(self, runner: "IRunner") -> None:
    """On batch end event

    Args:
        runner: current runner
    """
    if self.use_amp:
        # Drop the cache when we exit to a nesting level
        # that's outside any instance of autocast.
        if torch.autocast_decrement_nesting() == 0:
            torch.clear_autocast_cache()
        torch.set_autocast_enabled(self.prev_autocast_state)

    if not runner.is_train_loader:
        return

    loss = runner.batch_metrics[self.metric_key]
    self._accumulation_counter += 1
    need_gradient_step = self._accumulation_counter % self.accumulation_steps == 0

    # @TODO: speedup with re-definition ``on_stage_start``
    if self.use_apex:
        from apex import amp

        # Need to set ``delay_unscale``
        # according to
        # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
        delay_unscale = not need_gradient_step
        with amp.scale_loss(loss, self._optimizer, delay_unscale=delay_unscale) as scaled_loss:
            scaled_loss.backward()
    elif self.use_amp:
        self.scaler.scale(loss).backward()
    else:
        loss.backward()

    if need_gradient_step:
        self.grad_step(
            optimizer=self._optimizer,
            grad_clip_fn=self.grad_clip_fn,
        )
        if not self.use_fast_zero_grad:
            maybe_recursive_call(self._optimizer, "zero_grad")
        else:
            maybe_recursive_call(self._optimizer, zero_grad)
        self._accumulation_counter = 0
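For context, a minimal plain-PyTorch sketch of the same accumulation pattern the callback implements (accumulate gradients over several micro-batches, step once); it is not the callback itself and assumes a CUDA device for the AMP branch.

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

model = nn.Linear(8, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
dataset = TensorDataset(torch.randn(64, 8), torch.randint(0, 2, (64,)))
loader = DataLoader(dataset, batch_size=8)

scaler = torch.cuda.amp.GradScaler()
accumulation_steps = 4

for step, (x, y) in enumerate(loader, start=1):
    x, y = x.cuda(), y.cuda()
    with torch.cuda.amp.autocast():
        loss = nn.functional.cross_entropy(model(x), y)
    scaler.scale(loss).backward()        # gradients accumulate across micro-batches
    if step % accumulation_steps == 0:   # the callback's ``need_gradient_step``
        scaler.step(optimizer)           # optimizer step only every N batches
        scaler.update()
        optimizer.zero_grad()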
def predict_loader(
    self,
    *,
    loader: DataLoader,
    model: Model = None,
    engine: Union["IEngine", str] = None,
    seed: int = 42,
    # engine extra params,
    fp16: bool = False,
    amp: bool = False,
    apex: bool = False,
    ddp: bool = False,
) -> Generator:
    """
    Runs model inference on PyTorch DataLoader and returns
    python generator with model predictions from `runner.predict_batch`.

    Args:
        loader: loader to predict
        model: model to use for prediction
        engine: engine to use for prediction
        seed: random seed to use before prediction
        fp16: boolean flag to use half-precision training (AMP > APEX)
        amp: boolean flag to use amp half-precision
        apex: boolean flag to use apex half-precision
        ddp: if `True` will start training in distributed mode.
            Note: Works only with python scripts. No jupyter support.

    Yields:
        batches with model predictions
    """
    self._engine = engine or get_available_engine(fp16=fp16, ddp=ddp, amp=amp, apex=apex)

    if model is not None:
        self.model = model
    assert self.model is not None

    # if resume is not None:
    #     checkpoint = load_checkpoint(resume)
    #     unpack_checkpoint(checkpoint, model=self.model)

    self.model = self.engine.sync_device(self.model)
    maybe_recursive_call(self.model, "train", mode=False)

    set_global_seed(seed)
    for batch in loader:
        yield self.predict_batch(batch)
def pack_checkpoint(model=None, criterion=None, optimizer=None, scheduler=None, **kwargs):
    """@TODO: Docs. Contribution is welcome."""
    checkpoint = kwargs

    if isinstance(model, OrderedDict):
        raise NotImplementedError()
    else:
        model_module = get_nn_from_ddp_module(model)
        checkpoint["model_state_dict"] = maybe_recursive_call(model_module, "state_dict")

    for dict2save, name2save in zip(
        [criterion, optimizer, scheduler],
        ["criterion", "optimizer", "scheduler"],
    ):
        if dict2save is None:
            continue
        # @TODO refactor with maybe_recursive_call
        if isinstance(dict2save, dict):
            for key, value in dict2save.items():
                if value is not None:
                    state_dict2save = name2save + "_" + str(key)
                    # checkpoint[name2save_] = value
                    state_dict2save = state_dict2save + "_state_dict"
                    checkpoint[state_dict2save] = value.state_dict()
        else:
            # checkpoint[name2save] = dict2save
            name2save = name2save + "_state_dict"
            checkpoint[name2save] = dict2save.state_dict()
    return checkpoint
def device(self, value: Device):
    """
    Setter for the runner's device.

    Args:
        value: new torch device.

    Raises:
        TypeError: if `value` is not a `torch.device`, `str`, or `None`
    """
    if isinstance(value, torch.device):
        self._device = value
    elif isinstance(value, str):
        self._device = torch.device(value)
    elif isinstance(value, type(None)):
        self._device = None
    else:
        raise TypeError(
            f"Invalid value type, "
            f"must be `str` or `torch.device`, "
            f"got '{type(value)}'"
        )
    if self._model is not None:
        self._model = maybe_recursive_call(self._model, "to", device=self._device or "cpu")
def model(self, value: Union[Model, Dict[str, Model]]):
    """
    Setter for the runner's model, useful for experiment tracing.

    Args:
        value (Union[Model, Dict[str, Model]]): new model.

    Raises:
        TypeError: if `value` is neither a `torch.nn.Module`
            nor a `Dict[str, torch.nn.Module]`
    """
    if isinstance(value, nn.Module):
        model = value
    elif isinstance(value, dict):
        values_are_models = all(isinstance(v, nn.Module) for v in value.values())
        if not values_are_models:
            raise TypeError("Invalid dict value type, must be `torch.nn.Module`")
        model = value
    elif isinstance(value, type(None)):
        model = None
    else:
        raise TypeError(
            f"Invalid value type, "
            f"must be `torch.nn.Module` or `Dict[str, torch.nn.Module]`, "
            f"got '{type(value)}'"
        )

    if model is not None and self._device is not None:
        model: Model = maybe_recursive_call(model, "to", device=self._device)

    self._model = model
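A minimal sketch of how these setters behave, assuming ``runner`` is an existing runner instance exposing the properties above; both a single module and a key-value dict of modules are accepted, and assigning a device afterwards also moves the current model.

import torch
from torch import nn

runner.model = nn.Linear(10, 2)  # plain nn.Module is accepted
runner.model = {
    "generator": nn.Linear(10, 2),
    "discriminator": nn.Linear(2, 1),
}  # Dict[str, nn.Module] is accepted too
runner.device = "cuda:0" if torch.cuda.is_available() else "cpu"
# runner.model = 42  # would raise TypeError, as documented above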
def on_batch_end(self, runner: "IRunner") -> None:
    """On batch end event

    Args:
        runner: current runner
    """
    if not runner.is_train_loader:
        return

    loss = runner.batch_metrics[self.metric_key]
    self._accumulation_counter += 1
    need_gradient_step = (
        self._accumulation_counter % self.accumulation_steps == 0
    )

    # This is a very hacky check for an Apex AMP optimizer and may change in the future.
    # An alternative solution is to have an AmpOptimizerCallback
    # or expose another constructor argument.
    # @TODO: speedup with re-definition ``on_stage_start``
    if hasattr(self._optimizer, "_amp_stash"):
        from apex import amp

        # Need to set ``delay_unscale``
        # according to
        # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
        delay_unscale = not need_gradient_step
        with amp.scale_loss(
            loss, self._optimizer, delay_unscale=delay_unscale
        ) as scaled_loss:
            scaled_loss.backward()
    else:
        loss.backward()

    if need_gradient_step:
        self.grad_step(
            optimizer=self._optimizer,
            optimizer_wds=self._optimizer_wd,
            grad_clip_fn=self.grad_clip_fn,
        )
        if not self.use_fast_zero_grad:
            maybe_recursive_call(self._optimizer, "zero_grad")
        else:
            maybe_recursive_call(self._optimizer, zero_grad)
        self._accumulation_counter = 0
def _process_trial_config(trial, config: Dict) -> Tuple[optuna.Trial, Dict]:
    def _eval_trial_suggestions(x):
        nonlocal trial
        if isinstance(x, str) and "trial.suggest_" in x:
            x = eval(x)
        return x

    config = maybe_recursive_call(config, _eval_trial_suggestions)
    return trial, config
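A minimal sketch of the config convention this helper supports, assuming ``_process_trial_config`` is importable from the module above: string values containing ``trial.suggest_`` are evaluated against the current trial, everything else passes through unchanged.

import optuna

def objective(trial: optuna.Trial) -> float:
    config = {
        "lr": "trial.suggest_float('lr', 1e-5, 1e-1, log=True)",
        "num_hidden": "trial.suggest_int('num_hidden', 32, 128)",
        "activation": "relu",  # plain values are returned as-is
    }
    trial, config = _process_trial_config(trial, config)
    # config["lr"] and config["num_hidden"] now hold concrete sampled values
    return float(config["lr"])  # placeholder objective for the sketch

study = optuna.create_study()
study.optimize(objective, n_trials=3)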
def unpack_checkpoint(
    checkpoint: Dict,
    model: RunnerModel = None,
    criterion: RunnerCriterion = None,
    optimizer: RunnerOptimizer = None,
    scheduler: RunnerScheduler = None,
) -> None:
    """Load checkpoint from file and unpack the content to a model
    (if not None), criterion (if not None), optimizer (if not None),
    scheduler (if not None).

    Args:
        checkpoint: checkpoint to load
        model: model whose state should be updated
        criterion: criterion whose state should be updated
        optimizer: optimizer whose state should be updated
        scheduler: scheduler whose state should be updated
    """
    if model is not None:
        model = get_nn_from_ddp_module(model)
        maybe_recursive_call(
            model,
            "load_state_dict",
            recursive_args=checkpoint["model_state_dict"],
        )

    for dict2load, name2load in zip(
        [criterion, optimizer, scheduler],
        ["criterion", "optimizer", "scheduler"],
    ):
        if dict2load is None:
            continue

        if isinstance(dict2load, dict):
            for key, value in dict2load.items():
                if value is not None:
                    state_dict2load = f"{name2load}_{key}_state_dict"
                    value.load_state_dict(checkpoint[state_dict2load])
        else:
            name2load = f"{name2load}_state_dict"
            dict2load.load_state_dict(checkpoint[name2load])
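A minimal usage sketch of the inverse operation, assuming ``unpack_checkpoint`` is importable from the module above and ``checkpoint.pth`` was produced by the matching ``pack_checkpoint``.

import torch
from torch import nn, optim

model = nn.Linear(10, 2)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

checkpoint = torch.load("checkpoint.pth", map_location="cpu")
unpack_checkpoint(checkpoint, model=model, optimizer=optimizer)
# extra metadata packed earlier is still available on the dict itself
epoch = checkpoint.get("epoch")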
def predict_loader(
    self,
    *,
    loader: DataLoader,
    model: Model = None,
    engine: Union["IEngine", str] = None,
    seed: int = 42,
) -> Generator:
    """
    Runs model inference on PyTorch DataLoader and returns
    python generator with model predictions from `runner.predict_batch`.

    Args:
        loader: loader to predict
        model: model to use for prediction
        engine: engine to use for prediction
        seed: random seed to use before prediction

    Yields:
        batches with model predictions
    """
    if engine is not None:
        self.engine = engine
    if self.engine is None:
        self.engine = get_available_engine()

    if model is not None:
        self.model = model
    assert self.model is not None

    # if resume is not None:
    #     checkpoint = load_checkpoint(resume)
    #     unpack_checkpoint(checkpoint, model=self.model)

    self.model = self.engine.sync_device(self.model)
    maybe_recursive_call(self.model, "train", mode=False)

    set_global_seed(seed)
    for batch in loader:
        yield self.predict_batch(batch)
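A minimal usage sketch, assuming a custom Runner with ``predict_batch`` (mirroring the example in the fuller docstring further below); the engine default that gets picked depends on the installed Catalyst version and available hardware.

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from catalyst import dl

class InferenceRunner(dl.Runner):
    def handle_batch(self, batch):  # required override, unused during inference
        pass

    def predict_batch(self, batch):
        return self.model(batch[0].to(self.device))

model = nn.Linear(8, 2)
loader = DataLoader(TensorDataset(torch.randn(32, 8)), batch_size=8)

runner = InferenceRunner()
for logits in runner.predict_loader(loader=loader, model=model, seed=42):
    assert logits.shape[-1] == 2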
def on_loader_start(self, runner: "IRunner"):
    """Event handler."""
    assert self.loader is not None
    self.is_train_loader: bool = self.loader_key.startswith("train")
    self.is_valid_loader: bool = self.loader_key.startswith("valid")
    self.is_infer_loader: bool = self.loader_key.startswith("infer")
    assert self.is_train_loader or self.is_valid_loader or self.is_infer_loader
    self.loader_batch_size: int = _get_batch_size(self.loader)
    self.loader_batch_len: int = len(self.loader)
    self.loader_sample_len: int = len(self.loader.dataset)
    self.loader_batch_step: int = 0
    self.loader_sample_step: int = 0
    self.loader_metrics: Dict = defaultdict(None)

    if self.loader_batch_len == 0:
        raise NotImplementedError(f"DataLoader with name {self.loader_key} is empty.")
    set_global_seed(self.seed + self.engine.rank + self.global_epoch_step)

    maybe_recursive_call(self.model, "train", mode=self.is_train_loader)
    if isinstance(self.loader.sampler, DistributedSampler):
        self.loader.sampler.set_epoch(self.stage_epoch_step)

    self.loader = self.engine.autocast_loader(self.loader)
def on_loader_start(self, runner: "IRunner"):
    """Event handler."""
    assert self.loader is not None
    self.is_train_loader: bool = self.loader_key.startswith("train")
    self.is_valid_loader: bool = self.loader_key.startswith("valid")
    self.is_infer_loader: bool = self.loader_key.startswith("infer")
    assert self.is_train_loader or self.is_valid_loader or self.is_infer_loader
    self.loader_batch_size: int = get_loader_batch_size(self.loader)
    self.loader_batch_len: int = len(self.loader)
    self.loader_sample_len: int = get_loader_num_samples(self.loader)
    self.loader_batch_step: int = 0
    self.loader_sample_step: int = 0
    self.loader_metrics: Dict = defaultdict(None)

    if self.loader_batch_len == 0:
        raise IRunnerError(f"DataLoader with name {self.loader_key} is empty.")
    set_global_seed(self.seed + max(0, self.engine.process_index) + self.epoch_step)

    maybe_recursive_call(self.model, "train", mode=self.is_train_loader)
    if isinstance(self.loader.sampler, DistributedSampler):
        self.loader.sampler.set_epoch(self.epoch_step)
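For context, a minimal sketch of the loader-key convention this handler relies on: dictionary keys starting with "train", "valid", or "infer" toggle the corresponding flags. It mirrors Catalyst's minimal examples; exact ``train()`` arguments may differ between versions.

import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset
from catalyst import dl

num_samples, num_features = 1000, 10
X, y = torch.rand(num_samples, num_features), torch.rand(num_samples, 1)
loader = DataLoader(TensorDataset(X, y), batch_size=32)
# key prefixes drive is_train_loader / is_valid_loader above
loaders = {"train": loader, "valid": loader}

model = nn.Linear(num_features, 1)
runner = dl.SupervisedRunner()
runner.train(
    model=model,
    criterion=nn.MSELoss(),
    optimizer=optim.Adam(model.parameters()),
    loaders=loaders,
    num_epochs=1,
)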
def process_components(
    model: Model,
    criterion: Criterion = None,
    optimizer: Optimizer = None,
    scheduler: Scheduler = None,
    distributed_params: Dict = None,
    device: Device = None,
) -> Tuple[Model, Criterion, Optimizer, Scheduler, Device]:
    """
    Returns the processed model, criterion, optimizer, scheduler and device.

    Args:
        model (Model): torch model
        criterion (Criterion): criterion function
        optimizer (Optimizer): optimizer
        scheduler (Scheduler): scheduler
        distributed_params (dict, optional): dict with the parameters
            for distributed and FP16 method
        device (Device, optional): device

    Returns:
        tuple with processed model, criterion, optimizer, scheduler and device.

    Raises:
        NotImplementedError: if model is not nn.Module or dict for multi-gpu,
            nn.ModuleDict for DataParallel not implemented yet
    """
    distributed_params = distributed_params or {}
    distributed_params = copy.deepcopy(distributed_params)
    distributed_params.update(get_distributed_params())

    if device is None:
        device = get_device()
    elif isinstance(device, str):
        device = torch.device(device)

    is_apex_available = distributed_params.pop("apex", True) and check_apex_available()

    model: Model = maybe_recursive_call(model, "to", device=device)

    if check_ddp_wrapped(model):
        pass
    # distributed data parallel run (ddp) (with apex support)
    elif get_rank() >= 0:
        assert isinstance(model, nn.Module), "Distributed training is not available for KV model"

        local_rank = distributed_params.pop("local_rank", 0) or 0
        device = f"cuda:{local_rank}"
        model = maybe_recursive_call(model, "to", device=device)

        syncbn = distributed_params.pop("syncbn", False)

        if is_apex_available:
            import apex

            model, optimizer = initialize_apex(model, optimizer, **distributed_params)
            model = apex.parallel.DistributedDataParallel(model)

            if syncbn:
                model = apex.parallel.convert_syncbn_model(model)
        else:
            model = nn.parallel.DistributedDataParallel(
                model, device_ids=[local_rank], output_device=local_rank
            )
    # data parallel run (dp) (with apex support)
    else:
        # apex issue https://github.com/deepset-ai/FARM/issues/210
        use_apex = (is_apex_available and torch.cuda.device_count() == 1) or (
            is_apex_available
            and torch.cuda.device_count() > 1
            and distributed_params.get("opt_level", "O0") == "O1"
        )

        if use_apex:
            assert isinstance(model, nn.Module), "Apex training is not available for KV model"
            model, optimizer = initialize_apex(model, optimizer, **distributed_params)

        if torch.cuda.device_count() > 1 and device.type != "cpu" and device.index is None:
            if isinstance(model, nn.Module):
                model = nn.DataParallel(model)
            elif isinstance(model, dict):
                model = {k: nn.DataParallel(v) for k, v in model.items()}
            else:
                raise NotImplementedError()

    model: Model = maybe_recursive_call(model, "to", device=device)

    return model, criterion, optimizer, scheduler, device
def predict_loader(
    self,
    *,
    loader: DataLoader,
    model: Model = None,
    engine: Union["IEngine", str] = None,
    seed: int = 42,
    # engine extra params,
    fp16: bool = False,
    amp: bool = False,
    apex: bool = False,
    ddp: bool = False,
) -> Generator:
    """
    Runs model inference on PyTorch DataLoader and returns
    python generator with model predictions from `runner.predict_batch`.

    Args:
        loader: loader to predict
        model: model to use for prediction
        engine: engine to use for prediction
        seed: random seed to use before prediction
        fp16: boolean flag to use half-precision training (AMP > APEX)
        amp: boolean flag to use amp half-precision
        apex: boolean flag to use apex half-precision
        ddp: if `True` will start training in distributed mode.
            Note: Works only with python scripts. No jupyter support.

    Yields:
        batches with model predictions

    .. note::
        Please follow the `minimal examples`_ sections for use cases.

        .. _`minimal examples`: https://github.com/catalyst-team/catalyst#minimal-examples

    Examples:

    .. code-block:: python

        import os
        from torch import nn, optim
        from torch.nn import functional as F
        from torch.utils.data import DataLoader
        from catalyst import dl, metrics
        from catalyst.data.transforms import ToTensor
        from catalyst.contrib.datasets import MNIST

        model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
        optimizer = optim.Adam(model.parameters(), lr=0.02)

        loaders = {
            "train": DataLoader(
                MNIST(os.getcwd(), train=True, download=True, transform=ToTensor()),
                batch_size=32
            ),
            "valid": DataLoader(
                MNIST(os.getcwd(), train=False, download=True, transform=ToTensor()),
                batch_size=32
            ),
        }

        class CustomRunner(dl.Runner):
            def predict_batch(self, batch):
                # model inference step
                return self.model(batch[0].to(self.device))

            def on_loader_start(self, runner):
                super().on_loader_start(runner)
                self.meters = {
                    key: metrics.AdditiveValueMetric(compute_on_call=False)
                    for key in ["loss", "accuracy01", "accuracy03"]
                }

            def handle_batch(self, batch):
                # model train/valid step
                # unpack the batch
                x, y = batch
                # run model forward pass
                logits = self.model(x)
                # compute the loss
                loss = F.cross_entropy(logits, y)
                # compute other metrics of interest
                accuracy01, accuracy03 = metrics.accuracy(logits, y, topk=(1, 3))
                # log metrics
                self.batch_metrics.update(
                    {"loss": loss, "accuracy01": accuracy01, "accuracy03": accuracy03}
                )
                for key in ["loss", "accuracy01", "accuracy03"]:
                    self.meters[key].update(
                        self.batch_metrics[key].item(), self.batch_size
                    )
                # run model backward pass
                if self.is_train_loader:
                    loss.backward()
                    self.optimizer.step()
                    self.optimizer.zero_grad()

            def on_loader_end(self, runner):
                for key in ["loss", "accuracy01", "accuracy03"]:
                    self.loader_metrics[key] = self.meters[key].compute()[0]
                super().on_loader_end(runner)

        runner = CustomRunner()
        # model training
        runner.train(
            model=model,
            optimizer=optimizer,
            loaders=loaders,
            logdir="./logs",
            num_epochs=5,
            verbose=True,
            valid_loader="valid",
            valid_metric="loss",
            minimize_valid_metric=True,
        )
        # model inference
        for logits in runner.predict_loader(loader=loaders["valid"]):
            assert logits.detach().cpu().numpy().shape[-1] == 10
    """
    self._engine = engine or get_available_engine(fp16=fp16, ddp=ddp, amp=amp, apex=apex)

    if model is not None:
        self.model = model
    assert self.model is not None

    # if resume is not None:
    #     checkpoint = load_checkpoint(resume)
    #     unpack_checkpoint(checkpoint, model=self.model)

    self.model = self.engine.sync_device(self.model)
    maybe_recursive_call(self.model, "train", mode=False)

    set_global_seed(seed)
    for batch in loader:
        yield self.predict_batch(batch)
def process_components(
    model: RunnerModel,
    criterion: Criterion = None,
    optimizer: Optimizer = None,
    scheduler: Scheduler = None,
    distributed_params: Dict = None,
    device: Device = None,
) -> Tuple[RunnerModel, Criterion, Optimizer, Scheduler, Device]:
    """
    Returns the processed model, criterion, optimizer, scheduler and device.

    Args:
        model: torch model
        criterion: criterion function
        optimizer: optimizer
        scheduler: scheduler
        distributed_params (dict, optional): dict with the parameters
            for distributed and FP16 method
        device (Device, optional): device

    Returns:
        tuple with processed model, criterion, optimizer, scheduler and device.

    Raises:
        ValueError: if device is None and a TPU is available; to use a TPU you
            need to manually move the model/optimizer/scheduler to the TPU
            device and pass that device to the function.
        NotImplementedError: if model is not nn.Module or dict for multi-gpu,
            nn.ModuleDict for DataParallel not implemented yet
    """
    distributed_params = distributed_params or {}
    distributed_params = copy.deepcopy(distributed_params)
    distributed_params.update(get_distributed_params())

    if device is None and IS_XLA_AVAILABLE:
        raise ValueError(
            "TPU device is available. "
            "Please move model, optimizer and scheduler (if present) "
            "to the TPU device manually and specify a device, or "
            "use the CPU device."
        )

    if device is None:
        device = get_device()
    elif isinstance(device, str):
        device = torch.device(device)

    is_apex_enabled = distributed_params.get("apex", False) and check_apex_available()
    is_amp_enabled = distributed_params.get("amp", False) and check_amp_available()

    if is_apex_enabled and is_amp_enabled:
        raise ValueError(
            "Both NVidia Apex and Torch.Amp are enabled. "
            "You must choose only one mixed precision backend"
        )

    model: Model = maybe_recursive_call(model, "to", device=device)

    if check_ddp_wrapped(model):
        pass
    # distributed data parallel run (ddp) (with apex support)
    elif get_rank() >= 0:
        assert isinstance(model, nn.Module), "Distributed training is not available for KV model"

        local_rank = distributed_params.pop("local_rank", 0) or 0
        device = f"cuda:{local_rank}"
        model = maybe_recursive_call(model, "to", device=device)

        syncbn = distributed_params.pop("syncbn", False)

        if is_apex_enabled:
            import apex

            if syncbn:
                model = apex.parallel.convert_syncbn_model(model)

            model, optimizer = initialize_apex(model, optimizer, **distributed_params)
            model = apex.parallel.DistributedDataParallel(model)
        else:
            if syncbn:
                model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

            model = nn.parallel.DistributedDataParallel(
                model, device_ids=[local_rank], output_device=local_rank
            )
    # data parallel run (dp) (with apex support)
    else:
        is_data_parallel = (
            torch.cuda.device_count() > 1
            and device.type != "cpu"
            and device.index is None
        )

        if is_apex_enabled and not is_data_parallel:
            model, optimizer = initialize_apex(model, optimizer, **distributed_params)

        elif not is_apex_enabled and is_data_parallel:
            if isinstance(model, nn.Module):
                model = nn.DataParallel(model)
            elif isinstance(model, dict):
                model = {k: nn.DataParallel(v) for k, v in model.items()}
            else:
                raise NotImplementedError()

        elif is_apex_enabled and is_data_parallel:
            model, optimizer = _wrap_into_data_parallel_with_apex(
                model, optimizer, distributed_params
            )

    model: Model = maybe_recursive_call(model, "to", device=device)

    return model, criterion, optimizer, scheduler, device
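A minimal single-process usage sketch, assuming ``process_components`` is importable from the module above; with no distributed params it essentially resolves the device and moves the model to it.

import torch
from torch import nn, optim

model = nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=1e-2)

model, criterion, optimizer, scheduler, device = process_components(
    model=model,
    criterion=nn.CrossEntropyLoss(),
    optimizer=optimizer,
    device="cuda:0" if torch.cuda.is_available() else "cpu",
)
print(device)  # the torch.device the model was prepared for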