def load_traced_model(
    model_path: Union[str, Path],
    device: Device = "cpu",
    opt_level: str = None,
) -> ScriptModule:
    """
    Loads a traced model

    Args:
        model_path: Path to traced model
        device (str): Torch device
        opt_level (str): Apex FP16 init level, optional

    Returns:
        (ScriptModule): Traced model
    """
    # torch.jit.load does not work with pathlib.Path
    model_path = str(model_path)

    if opt_level is not None:
        device = "cuda"

    model = torch.jit.load(model_path, map_location=device)

    if opt_level is not None:
        utils.assert_fp16_available()
        from apex import amp
        model = amp.initialize(model, optimizers=None, opt_level=opt_level)

    return model
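# A minimal usage sketch for `load_traced_model`; the checkpoint path and
# input shape below are hypothetical, not taken from the code above.
import torch

traced = load_traced_model("logs/trace/traced.pth", device="cpu")
with torch.no_grad():
    # the traced module is called like a regular nn.Module
    output = traced(torch.randn(1, 3, 224, 224))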
def get_model(self, stage: str) -> _Model:
    model_params = self._config["model_params"]
    fp16 = model_params.pop("fp16", False)

    model = MODELS.get_from_params(**model_params)

    if fp16:
        utils.assert_fp16_available()
        model = Fp16Wrap(model)

    model = self._preprocess_model_for_stage(stage, model)
    model = self._postprocess_model_for_stage(stage, model)
    return model
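# A hypothetical `model_params` block that `get_model` would read from the
# experiment config; the "model" registry key and the extra kwargs are
# assumptions for illustration, not taken from the code above.
model_params = {
    "model": "SimpleNet",  # factory name assumed to be registered in MODELS
    "num_classes": 10,
    "fp16": True,          # popped by get_model -> wraps the model in Fp16Wrap
}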
def process_components(
    model: _Model,
    criterion: _Criterion = None,
    optimizer: _Optimizer = None,
    scheduler: _Scheduler = None,
    distributed_params: Dict = None
) -> Tuple[_Model, _Criterion, _Optimizer, _Scheduler, torch.device]:
    distributed_params = distributed_params or {}
    distributed_params = copy.deepcopy(distributed_params)
    device = utils.get_device()

    model = maybe_recursive_call(model, "to", device=device)

    if utils.is_wrapped_with_ddp(model):
        pass
    elif len(distributed_params) > 0:
        assert isinstance(model, nn.Module)
        utils.assert_fp16_available()
        from apex import amp
        from apex.parallel import convert_syncbn_model

        distributed_rank = distributed_params.pop("rank", -1)
        syncbn = distributed_params.pop("syncbn", False)

        if distributed_rank > -1:
            torch.cuda.set_device(distributed_rank)
            torch.distributed.init_process_group(
                backend="nccl", init_method="env://"
            )

        model, optimizer = amp.initialize(
            model, optimizer, **distributed_params
        )

        if distributed_rank > -1:
            from apex.parallel import DistributedDataParallel
            model = DistributedDataParallel(model)

            if syncbn:
                model = convert_syncbn_model(model)
        elif torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model)
    elif torch.cuda.device_count() > 1:
        if isinstance(model, nn.Module):
            model = torch.nn.DataParallel(model)
        elif isinstance(model, dict):
            model = {k: torch.nn.DataParallel(v) for k, v in model.items()}

    model = maybe_recursive_call(model, "to", device=device)

    return model, criterion, optimizer, scheduler, device
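# `process_components` above also accepts a dict of models (handled by
# `maybe_recursive_call` and the dict branch); a sketch with placeholder
# models, assuming no distributed/FP16 params:
import torch.nn as nn

models = {"generator": nn.Linear(100, 784), "discriminator": nn.Linear(784, 1)}
models, _, _, _, device = process_components(model=models)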
def process_components(
    model: _Model,
    criterion: _Criterion = None,
    optimizer: _Optimizer = None,
    scheduler: _Scheduler = None,
    distributed_params: Dict = None
) -> Tuple[_Model, _Criterion, _Optimizer, _Scheduler, torch.device]:
    distributed_params = distributed_params or {}
    distributed_params = copy.deepcopy(distributed_params)
    device = utils.get_device()

    if torch.cuda.is_available():
        benchmark = os.environ.get("CUDNN_BENCHMARK", "True") == "True"
        cudnn.benchmark = benchmark

    model = model.to(device)

    if utils.is_wrapped_with_ddp(model):
        pass
    elif len(distributed_params) > 0:
        utils.assert_fp16_available()
        from apex import amp

        distributed_rank = distributed_params.pop("rank", -1)

        if distributed_rank > -1:
            torch.cuda.set_device(distributed_rank)
            torch.distributed.init_process_group(
                backend="nccl", init_method="env://"
            )

        model, optimizer = amp.initialize(
            model, optimizer, **distributed_params
        )

        if distributed_rank > -1:
            from apex.parallel import DistributedDataParallel
            model = DistributedDataParallel(model)
        elif torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model)
    elif torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    model = model.to(device)

    return model, criterion, optimizer, scheduler, device
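# The cuDNN benchmark toggle above is driven by the CUDNN_BENCHMARK
# environment variable; a sketch of disabling it before the call:
import os

os.environ["CUDNN_BENCHMARK"] = "False"  # anything other than "True" disables benchmarking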
def process_components(
    model: Model,
    criterion: Criterion = None,
    optimizer: Optimizer = None,
    scheduler: Scheduler = None,
    distributed_params: Dict = None,
    device: Device = None,
) -> Tuple[Model, Criterion, Optimizer, Scheduler, Device]:
    """
    Returns the processed model, criterion, optimizer, scheduler and device

    Args:
        model (Model): torch model
        criterion (Criterion): criterion function
        optimizer (Optimizer): optimizer
        scheduler (Scheduler): scheduler
        distributed_params (dict, optional): dict with the parameters
            for distributed and FP16 method
        device (Device, optional): device
    """
    distributed_params = distributed_params or {}
    distributed_params = copy.deepcopy(distributed_params)
    if device is None:
        device = utils.get_device()

    model: Model = maybe_recursive_call(model, "to", device=device)

    if utils.is_wrapped_with_ddp(model):
        pass
    elif len(distributed_params) > 0:
        assert isinstance(model, nn.Module)
        distributed_rank = distributed_params.pop("rank", -1)
        syncbn = distributed_params.pop("syncbn", False)

        if distributed_rank > -1:
            torch.cuda.set_device(distributed_rank)
            torch.distributed.init_process_group(
                backend="nccl", init_method="env://"
            )

        if "opt_level" in distributed_params:
            utils.assert_fp16_available()
            from apex import amp
            amp_result = amp.initialize(model, optimizer, **distributed_params)
            if optimizer is not None:
                model, optimizer = amp_result
            else:
                model = amp_result

        if distributed_rank > -1:
            from apex.parallel import DistributedDataParallel
            model = DistributedDataParallel(model)

            if syncbn:
                from apex.parallel import convert_syncbn_model
                model = convert_syncbn_model(model)

        if distributed_rank <= -1 and torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model)
    elif torch.cuda.device_count() > 1:
        if isinstance(model, nn.Module):
            model = torch.nn.DataParallel(model)
        elif isinstance(model, dict):
            model = {k: torch.nn.DataParallel(v) for k, v in model.items()}

    model = maybe_recursive_call(model, "to", device=device)

    return model, criterion, optimizer, scheduler, device
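# A usage sketch for `process_components`; the model/criterion/optimizer are
# placeholders, and the `distributed_params` keys mirror what the function
# consumes (`rank`, `syncbn`, plus apex.amp kwargs such as `opt_level`).
import torch
import torch.nn as nn

model = nn.Linear(10, 2)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

model, criterion, optimizer, scheduler, device = process_components(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    distributed_params={"opt_level": "O1"},  # requires apex and a CUDA device
)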
def trace_model(
    model: Model,
    runner: Runner,
    batch=None,
    method_name: str = "forward",
    mode: str = "eval",
    requires_grad: bool = False,
    opt_level: str = None,
    device: Device = "cpu",
    predict_params: dict = None,
) -> ScriptModule:
    """
    Traces model using runner and batch

    Args:
        model: Model to trace
        runner: Model's native runner that was used to train the model
        batch: Batch to trace the model
        method_name (str): Model's method name that will be used
            as entrypoint during tracing
        mode (str): Mode for model to trace (``train`` or ``eval``)
        requires_grad (bool): Flag to use grads
        opt_level (str): Apex FP16 init level, optional
        device (str): Torch device
        predict_params (dict): additional parameters for model forward

    Returns:
        (ScriptModule): Traced model
    """
    if batch is None or runner is None:
        raise ValueError("Both batch and runner must be specified.")

    if mode not in ["train", "eval"]:
        raise ValueError(f"Unknown mode '{mode}'. Must be 'eval' or 'train'")

    predict_params = predict_params or {}

    tracer = _TracingModelWrapper(model, method_name)
    if opt_level is not None:
        utils.assert_fp16_available()
        # If traced in AMP we need to initialize the model before calling
        # the jit
        # https://github.com/NVIDIA/apex/issues/303#issuecomment-493142950
        from apex import amp
        model = model.to(device)
        model = amp.initialize(model, optimizers=None, opt_level=opt_level)
        # TODO: remove `check_trace=False`
        # after fixing this bug https://github.com/pytorch/pytorch/issues/23993
        params = {**predict_params, "check_trace": False}
    else:
        params = predict_params

    getattr(model, mode)()
    utils.set_requires_grad(model, requires_grad=requires_grad)

    _runner_model, _runner_device = runner.model, runner.device

    runner.model, runner.device = tracer, device
    runner.predict_batch(batch, **params)
    result: ScriptModule = tracer.tracing_result

    runner.model, runner.device = _runner_model, _runner_device
    return result
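# A tracing sketch; the `SupervisedRunner`, the nn.Linear stand-in and the
# batch layout are assumptions for illustration, not taken from the code above.
import torch
import torch.nn as nn
from catalyst.dl import SupervisedRunner

model = nn.Linear(16, 4)                        # placeholder model
runner = SupervisedRunner(input_key="features")  # runner normally used for training
batch = {"features": torch.randn(2, 16)}         # hypothetical input batch

traced: ScriptModule = trace_model(
    model=model,
    runner=runner,
    batch=batch,
    method_name="forward",
    mode="eval",
)
torch.jit.save(traced, "traced.pth")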