Example #1
    def get_callbacks(self, stage: str) -> "OrderedDict[str, Callback]":
        """
        Override of the ``BaseExperiment.get_callbacks`` method.
        Adds several callbacks by default in case they are missing.

        Args:
            stage: name of the stage. It should start with `infer` if you
                don't need the default callbacks, as they are required only
                for training stages.

        Returns:
            OrderedDict[str, Callback]: Ordered dictionary of callbacks
                for experiment
        """
        callbacks = super().get_callbacks(stage=stage) or OrderedDict()

        # default_callbacks = [(Name, InterfaceClass, InstanceFactory)]
        default_callbacks = []

        is_amp_enabled = (self.distributed_params.get("amp", False)
                          and check_amp_available())
        optimizer_cls = (AMPOptimizerCallback
                         if is_amp_enabled else OptimizerCallback)

        if not stage.startswith("infer"):
            if self._criterion is not None and isinstance(
                    self._criterion, Criterion):
                default_callbacks.append(
                    ("_criterion", None, CriterionCallback))
            if self._optimizer is not None and isinstance(
                    self._optimizer, Optimizer):
                default_callbacks.append(
                    ("_optimizer", IOptimizerCallback, optimizer_cls))
            if self._scheduler is not None and isinstance(
                    self._scheduler, (Scheduler, ReduceLROnPlateau)):
                default_callbacks.append(
                    ("_scheduler", ISchedulerCallback, SchedulerCallback))

        for (
                callback_name,
                callback_interface,
                callback_fn,
        ) in default_callbacks:
            callback_interface = callback_interface or callback_fn
            is_already_present = any(
                check_callback_isinstance(x, callback_interface)
                for x in callbacks.values())
            if not is_already_present:
                callbacks[callback_name] = callback_fn()

        return callbacks
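
The merging rule in the final loop can be illustrated with stand-in classes. The sketch below is not Catalyst code: the Callback/IOptimizerCallback/OptimizerCallback stubs and the add_defaults helper are placeholders, and plain isinstance stands in for check_callback_isinstance.

from collections import OrderedDict


class Callback:
    """Stand-in for the Callback base class."""


class IOptimizerCallback(Callback):
    """Stand-in interface class."""


class OptimizerCallback(IOptimizerCallback):
    """Stand-in concrete callback."""


def add_defaults(callbacks, defaults):
    # Add each default callback unless a callback implementing its
    # interface is already registered (same rule as the loop above).
    for name, interface, factory in defaults:
        interface = interface or factory
        if not any(isinstance(cb, interface) for cb in callbacks.values()):
            callbacks[name] = factory()
    return callbacks


callbacks = OrderedDict(user_optimizer=OptimizerCallback())
defaults = [("_optimizer", IOptimizerCallback, OptimizerCallback)]
add_defaults(callbacks, defaults)
# "user_optimizer" already satisfies IOptimizerCallback, so no "_optimizer"
# entry is added.
assert list(callbacks) == ["user_optimizer"]
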
Example #2
    def get_callbacks(self, stage: str) -> "OrderedDict[Callback]":
        """Returns the callbacks for a given stage."""
        callbacks_params = self.stages_config[stage].get(
            "callbacks_params", {})

        callbacks = OrderedDict()
        for key, callback_params in callbacks_params.items():
            callback = self._get_callback(**callback_params)
            callbacks[key] = callback

        # default_callbacks = [(Name, InterfaceClass, InstanceFactory)]
        default_callbacks = []

        is_amp_enabled = (self.distributed_params.get("amp", False)
                          and check_amp_available())
        optimizer_cls = (AMPOptimizerCallback
                         if is_amp_enabled else OptimizerCallback)

        if self._verbose:
            default_callbacks.append(("_verbose", None, VerboseLogger))
        if self._check_time:
            default_callbacks.append(("_timer", None, TimerCallback))
        if self._check_run:
            default_callbacks.append(("_check", None, CheckRunCallback))
        if self._overfit:
            default_callbacks.append(("_overfit", None, BatchOverfitCallback))

        if not stage.startswith("infer"):
            default_callbacks.append(("_metrics", None, MetricManagerCallback))
            default_callbacks.append(
                ("_validation", None, ValidationManagerCallback))
            default_callbacks.append(("_console", None, ConsoleLogger))

            if self.logdir is not None:
                default_callbacks.append(("_saver", None, CheckpointCallback))
                default_callbacks.append(
                    ("_tensorboard", None, TensorboardLogger))

            if self.stages_config[stage].get("criterion_params", {}):
                default_callbacks.append(
                    ("_criterion", None, CriterionCallback))
            if self.stages_config[stage].get("optimizer_params", {}):
                default_callbacks.append(
                    ("_optimizer", IOptimizerCallback, optimizer_cls))
            if self.stages_config[stage].get("scheduler_params", {}):
                default_callbacks.append(
                    ("_scheduler", ISchedulerCallback, SchedulerCallback))

        default_callbacks.append(("_exception", None, ExceptionCallback))

        for (
                callback_name,
                callback_interface,
                callback_fn,
        ) in default_callbacks:
            callback_interface = callback_interface or callback_fn
            is_already_present = any(
                check_callback_isinstance(x, callback_interface)
                for x in callbacks.values())
            if not is_already_present:
                callbacks[callback_name] = callback_fn()

        # NOTE: stage should be in self.stages_config,
        #       otherwise a ValueError will be raised
        stage_index = list(self.stages_config.keys()).index(stage)
        self._process_callbacks(callbacks, stage_index)

        return callbacks
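
For context, a hypothetical stages_config fragment that this method would consume could look like the following. Only the top-level keys (callbacks_params, criterion_params, optimizer_params) are taken from the snippet above; the concrete callback and parameter names are illustrative.

# Hypothetical stages_config fragment (illustrative values only).
stages_config = {
    "stage_train": {
        "criterion_params": {"criterion": "CrossEntropyLoss"},
        "optimizer_params": {"optimizer": "Adam", "lr": 1e-3},
        "callbacks_params": {
            # user-defined callbacks built via self._get_callback(**params)
            "accuracy": {"callback": "AccuracyCallback"},
        },
    },
}

# get_callbacks("stage_train") would first build the "accuracy" callback,
# then append the missing defaults: _metrics, _validation, _console,
# _criterion, _optimizer and _exception (plus _saver/_tensorboard when
# logdir is set).
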
Example #3
def process_components(
    model: RunnerModel,
    criterion: Criterion = None,
    optimizer: Optimizer = None,
    scheduler: Scheduler = None,
    distributed_params: Dict = None,
    device: Device = None,
) -> Tuple[RunnerModel, Criterion, Optimizer, Scheduler, Device]:
    """
    Returns the processed model, criterion, optimizer, scheduler and device.

    Args:
        model: torch model
        criterion: criterion function
        optimizer: optimizer
        scheduler: scheduler
        distributed_params (dict, optional): dict with the parameters
            for distributed training and the FP16 method
        device (Device, optional): device

    Returns:
        tuple with processed model, criterion, optimizer, scheduler and device.

    Raises:
        ValueError: if device is None and a TPU is available; to use a TPU
            you need to manually move the model/optimizer/scheduler to the
            TPU device and pass that device to this function.
        NotImplementedError: if the model is neither ``nn.Module`` nor a dict
            in the multi-GPU case (``nn.ModuleDict`` support for
            ``DataParallel`` is not implemented yet)
    """
    distributed_params = distributed_params or {}
    distributed_params = copy.deepcopy(distributed_params)
    distributed_params.update(get_distributed_params())

    if device is None and IS_XLA_AVAILABLE:
        raise ValueError(
            "TPU device is available. "
            "Please move model, optimizer and scheduler (if present) "
            "to TPU device manualy and specify a device or "
            "use CPU device.")

    if device is None:
        device = get_device()
    elif isinstance(device, str):
        device = torch.device(device)

    is_apex_enabled = (distributed_params.get("apex", False)
                       and check_apex_available())

    is_amp_enabled = (distributed_params.get("amp", False)
                      and check_amp_available())

    if is_apex_enabled and is_amp_enabled:
        raise ValueError("Both NVidia Apex and Torch.Amp are enabled. "
                         "You must choose only one mixed precision backend")
    model: Model = maybe_recursive_call(model, "to", device=device)

    if check_ddp_wrapped(model):
        pass
    # distributed data parallel run (ddp) (with apex support)
    elif get_rank() >= 0:
        assert isinstance(
            model,
            nn.Module), "Distributed training is not available for KV model"

        local_rank = distributed_params.pop("local_rank", 0) or 0
        device = f"cuda:{local_rank}"
        model = maybe_recursive_call(model, "to", device=device)

        syncbn = distributed_params.pop("syncbn", False)

        if is_apex_enabled:
            import apex

            if syncbn:
                model = apex.parallel.convert_syncbn_model(model)

            model, optimizer = initialize_apex(model, optimizer,
                                               **distributed_params)
            model = apex.parallel.DistributedDataParallel(model)
        else:
            if syncbn:
                model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

            model = nn.parallel.DistributedDataParallel(
                model, device_ids=[local_rank], output_device=local_rank)
    # data parallel run (dp) (with apex support)
    else:
        is_data_parallel = (torch.cuda.device_count() > 1
                            and device.type != "cpu" and device.index is None)

        if is_apex_enabled and not is_data_parallel:
            model, optimizer = initialize_apex(model, optimizer,
                                               **distributed_params)

        elif not is_apex_enabled and is_data_parallel:
            if isinstance(model, nn.Module):
                model = nn.DataParallel(model)
            elif isinstance(model, dict):
                model = {k: nn.DataParallel(v) for k, v in model.items()}
            else:
                raise NotImplementedError()

        elif is_apex_enabled and is_data_parallel:
            model, optimizer = _wrap_into_data_parallel_with_apex(
                model, optimizer, distributed_params)

    model: Model = maybe_recursive_call(model, "to", device=device)

    return model, criterion, optimizer, scheduler, device
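
A minimal usage sketch, assuming process_components is in scope (in Catalyst it is exposed through catalyst.utils, but the exact import path depends on the version) and the code runs on a single machine:

from torch import nn, optim

model = nn.Linear(10, 2)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10)

model, criterion, optimizer, scheduler, device = process_components(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    distributed_params={"amp": True},  # request torch.cuda.amp if available
)

# On a single machine with several GPUs the model comes back wrapped in
# nn.DataParallel; inside a torch.distributed run it is wrapped in
# DistributedDataParallel and moved to "cuda:{local_rank}".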