Example #1
def process_components(
    model: _Model,
    criterion: _Criterion = None,
    optimizer: _Optimizer = None,
    scheduler: _Scheduler = None,
    distributed_params: Dict = None
) -> Tuple[_Model, _Criterion, _Optimizer, _Scheduler, torch.device]:
    distributed_params = distributed_params or {}
    distributed_params = copy.deepcopy(distributed_params)
    device = utils.get_device()

    model = maybe_recursive_call(model, "to", device=device)

    if utils.is_wrapped_with_ddp(model):
        pass
    elif len(distributed_params) > 0:
        assert isinstance(model, nn.Module)
        utils.assert_fp16_available()
        from apex import amp
        from apex.parallel import convert_syncbn_model

        distributed_rank = distributed_params.pop("rank", -1)
        syncbn = distributed_params.pop("syncbn", False)

        if distributed_rank > -1:
            torch.cuda.set_device(distributed_rank)
            torch.distributed.init_process_group(
                backend="nccl", init_method="env://"
            )

        model, optimizer = amp.initialize(
            model, optimizer, **distributed_params
        )

        if distributed_rank > -1:
            from apex.parallel import DistributedDataParallel
            model = DistributedDataParallel(model)

            if syncbn:
                model = convert_syncbn_model(model)
        elif torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model)
    elif torch.cuda.device_count() > 1:
        if isinstance(model, nn.Module):
            model = torch.nn.DataParallel(model)
        elif isinstance(model, dict):
            model = {k: torch.nn.DataParallel(v) for k, v in model.items()}

    model = maybe_recursive_call(model, "to", device=device)

    return model, criterion, optimizer, scheduler, device
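For context, a minimal usage sketch of the function above; the model, optimizer and fp16 settings here are placeholders for illustration only, and process_components (with its helpers) is assumed to be in scope.

# Hypothetical usage of process_components from the example above.
import torch
import torch.nn as nn

net = nn.Linear(10, 2)
opt = torch.optim.SGD(net.parameters(), lr=1e-2)

# With an empty ``distributed_params`` the function only moves the model to the
# detected device (and wraps it in DataParallel on multi-GPU machines);
# passing e.g. {"opt_level": "O1"} would additionally enable apex fp16.
model, criterion, optimizer, scheduler, device = process_components(
    model=net,
    optimizer=opt,
    distributed_params={},
)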
Example #2
    def on_batch_end(self, state: State) -> None:
        """On batch end event

        Args:
            state (State): current state
        """
        if not state.is_train_loader:
            return

        loss = state.batch_metrics[self.loss_key]

        self._accumulation_counter += 1
        # make an optimizer step once every ``accumulation_steps`` batches
        need_gradient_step = (
            self._accumulation_counter % self.accumulation_steps == 0
        )

        # This is a very hacky check for whether we have an AMP optimizer,
        # and it may change in the future.
        # The alternative would be a dedicated AmpOptimizerCallback
        # or exposing another constructor argument.
        if hasattr(self._optimizer, "_amp_stash"):
            from apex import amp

            # Need to set ``delay_unscale``
            # according to
            # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
            delay_unscale = not need_gradient_step
            with amp.scale_loss(loss,
                                self._optimizer,
                                delay_unscale=delay_unscale) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        if need_gradient_step:
            self.grad_step(
                optimizer=self._optimizer,
                optimizer_wds=self._optimizer_wd,
                grad_clip_fn=self.grad_clip_fn,
            )

            # if self.save_model_grads:
            #     for tag, value in model.named_parameters():
            #         tag = tag.replace(".", "/")
            #         state.model_grads[tag] = value.grad.cpu().numpy()

            utils.maybe_recursive_call(self._optimizer, "zero_grad")

            self._accumulation_counter = 0
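The callback above interleaves gradient accumulation with apex loss scaling. The accumulation pattern itself, stripped of the AMP branch, is plain PyTorch; a minimal standalone sketch follows (the model, the toy data, and the choice to average the loss over the window are illustrative, not part of the callback):

import torch
import torch.nn as nn

model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
loss_fn = nn.MSELoss()
accumulation_steps = 4

# toy loader: a few (input, target) batches
loader = [(torch.randn(8, 10), torch.randn(8, 2)) for _ in range(8)]

for step, (x, y) in enumerate(loader, start=1):
    loss = loss_fn(model(x), y) / accumulation_steps  # average over the window
    loss.backward()                                   # gradients accumulate in .grad
    if step % accumulation_steps == 0:                # step once per N batches
        optimizer.step()
        optimizer.zero_grad()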
Example #3
    def model(self, value: Union[Model, Dict[str, Model]]):
        """
        Setter for the runner's model.
        """
        if isinstance(value, nn.Module):
            model = value
        elif isinstance(value, dict):
            values_are_models = all(
                [isinstance(v, nn.Module) for v in value.values()]
            )
            if not values_are_models:
                raise TypeError(
                    "Invalid dict value type, must be `torch.nn.Module`"
                )

            model = value

        else:
            raise TypeError(
                f"Invalid value type, "
                f"must be `torch.nn.Module` or `Dict[str, torch.nn.Module]`, "
                f"got '{type(value)}'"
            )

        if self._device is not None:
            model: Model = utils.maybe_recursive_call(
                model, "to", device=self._device
            )

        self._model = model
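The setter above (like most of the other examples) goes through utils.maybe_recursive_call, which applies a method either to a single object or to every value of a key-value collection of objects. A hypothetical sketch of such a helper, for illustration only; the library's actual implementation may differ:

from typing import Any, Dict, Union

def maybe_recursive_call(obj: Union[Any, Dict[str, Any]], method: str, **kwargs) -> Any:
    """Call ``obj.<method>(**kwargs)``; if ``obj`` is a dict, call it on every value."""
    if isinstance(obj, dict):
        return {key: getattr(value, method)(**kwargs) for key, value in obj.items()}
    return getattr(obj, method)(**kwargs)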
Example #4
    def on_batch_end(self, state: _State):
        """On batch end event"""
        if not state.need_backward:
            return

        loss = self._get_loss(state)

        self._accumulation_counter += 1
        model = state.model
        optimizer = state.get_key(key="optimizer",
                                  inner_key=self.optimizer_key)

        # make an optimizer step once every ``accumulation_steps`` batches
        need_gradient_step = \
            self._accumulation_counter % self.accumulation_steps == 0

        # This is a very hacky check for whether we have an AMP optimizer,
        # and it may change in the future.
        # The alternative would be a dedicated AmpOptimizerCallback
        # or exposing another constructor argument.
        if hasattr(optimizer, "_amp_stash"):
            from apex import amp
            # Need to set ``delay_unscale``
            # according to
            # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
            delay_unscale = not need_gradient_step
            with amp.scale_loss(loss, optimizer,
                                delay_unscale=delay_unscale) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        if need_gradient_step:
            self.grad_step(optimizer=optimizer,
                           optimizer_wds=self._optimizer_wd,
                           grad_clip_fn=self.grad_clip_fn)

            if self.save_model_grads:
                for tag, value in model.named_parameters():
                    tag = tag.replace(".", "/")
                    state.model_grads[tag] = value.grad.cpu().numpy()

            utils.maybe_recursive_call(model, "zero_grad")

            self._accumulation_counter = 0
Example #5
    def device(self, value: Device):
        """
        Setter for the runner's device.
        """
        if isinstance(value, (str, torch.device)):
            self._device = value
        else:
            raise TypeError(f"Invalid value type "
                            f"must be `str` or `torch.device` "
                            f"got '{type(value)}'")

        if self._model is not None:
            self._model = utils.maybe_recursive_call(self._model,
                                                     "to",
                                                     device=self._device)
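Taken together, the model setter from Example #3 and the device setter above are order-independent: whichever property is assigned second triggers the move of the model to the device. A hypothetical, stripped-down runner showing only that interplay (the class and its names are illustrative, not the library's API):

import torch
import torch.nn as nn

class _MiniRunner:
    def __init__(self):
        self._model, self._device = None, None

    @property
    def model(self):
        return self._model

    @model.setter
    def model(self, value: nn.Module):
        # move the model right away if the device is already known
        if self._device is not None:
            value = value.to(self._device)
        self._model = value

    @property
    def device(self):
        return self._device

    @device.setter
    def device(self, value):
        self._device = torch.device(value)
        # move the already-assigned model to the new device
        if self._model is not None:
            self._model = self._model.to(self._device)

runner = _MiniRunner()
runner.model = nn.Linear(4, 2)
runner.device = "cpu"  # moves the already-assigned model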
Example #6
def process_components(
    model: Model,
    criterion: Criterion = None,
    optimizer: Optimizer = None,
    scheduler: Scheduler = None,
    distributed_params: Dict = None,
    device: Device = None,
) -> Tuple[Model, Criterion, Optimizer, Scheduler, Device]:
    """
    Returns the processed model, criterion, optimizer, scheduler and device

    Args:
        model (Model): torch model
        criterion (Criterion): criterion function
        optimizer (Optimizer): optimizer
        scheduler (Scheduler): scheduler
        distributed_params (dict, optional): dict with the parameters
            for distributed and FP16 methods
        device (Device, optional): device
    """
    distributed_params = distributed_params or {}
    distributed_params = copy.deepcopy(distributed_params)
    if device is None:
        device = utils.get_device()

    model: Model = maybe_recursive_call(model, "to", device=device)

    if utils.is_wrapped_with_ddp(model):
        pass
    elif len(distributed_params) > 0:
        assert isinstance(model, nn.Module)
        distributed_rank = distributed_params.pop("rank", -1)
        syncbn = distributed_params.pop("syncbn", False)

        if distributed_rank > -1:
            torch.cuda.set_device(distributed_rank)
            torch.distributed.init_process_group(backend="nccl",
                                                 init_method="env://")

        if "opt_level" in distributed_params:
            utils.assert_fp16_available()
            from apex import amp

            amp_result = amp.initialize(model, optimizer, **distributed_params)
            if optimizer is not None:
                model, optimizer = amp_result
            else:
                model = amp_result

            if distributed_rank > -1:
                from apex.parallel import DistributedDataParallel
                model = DistributedDataParallel(model)

                if syncbn:
                    from apex.parallel import convert_syncbn_model
                    model = convert_syncbn_model(model)

        if distributed_rank <= -1 and torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model)
    elif torch.cuda.device_count() > 1:
        if isinstance(model, nn.Module):
            model = torch.nn.DataParallel(model)
        elif isinstance(model, dict):
            model = {k: torch.nn.DataParallel(v) for k, v in model.items()}

    model = maybe_recursive_call(model, "to", device=device)

    return model, criterion, optimizer, scheduler, device
Example #7
def process_components(
    model: Model,
    criterion: Criterion = None,
    optimizer: Optimizer = None,
    scheduler: Scheduler = None,
    distributed_params: Dict = None,
    device: Device = None,
) -> Tuple[Model, Criterion, Optimizer, Scheduler, Device]:
    """
    Returns the processed model, criterion, optimizer, scheduler and device

    Args:
        model (Model): torch model
        criterion (Criterion): criterion function
        optimizer (Optimizer): optimizer
        scheduler (Scheduler): scheduler
        distributed_params (dict, optional): dict with the parameters
            for distributed and FP16 methods
        device (Device, optional): device
    """
    distributed_params = distributed_params or {}
    distributed_params = copy.deepcopy(distributed_params)
    distributed_params.update(get_distributed_params())
    if device is None:
        device = utils.get_device()

    model: Model = utils.maybe_recursive_call(model, "to", device=device)

    if utils.is_wrapped_with_ddp(model):
        pass
    elif get_rank() >= 0:
        assert isinstance(model, nn.Module)
        local_rank = distributed_params.pop("local_rank", 0)
        device = f"cuda:{local_rank}"
        model = utils.maybe_recursive_call(model, "to", device=device)

        syncbn = distributed_params.pop("syncbn", False)
        use_apex = distributed_params.pop("apex", True) and is_apex_available()

        if use_apex:
            import apex
            amp_params = get_default_params(apex.amp.initialize,
                                            ["models", "optimizers"])
            amp_params["opt_level"] = "O0"
            for dp in distributed_params:
                if dp in amp_params:
                    amp_params[dp] = distributed_params[dp]

            amp_result = apex.amp.initialize(model, optimizer, **amp_params)
            if optimizer is not None:
                model, optimizer = amp_result
            else:
                model = amp_result

            model = apex.parallel.DistributedDataParallel(model)

            if syncbn:
                model = apex.parallel.convert_syncbn_model(model)
        else:
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[local_rank], output_device=local_rank)
    elif torch.cuda.device_count() > 1:
        if isinstance(model, nn.Module):
            model = torch.nn.DataParallel(model)
        elif isinstance(model, dict):
            model = {k: torch.nn.DataParallel(v) for k, v in model.items()}

    model: Model = utils.maybe_recursive_call(model, "to", device=device)

    return model, criterion, optimizer, scheduler, device
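Example #7 fills amp_params from the default arguments of apex.amp.initialize through a get_default_params helper. A plausible sketch of such a helper based on inspect.signature; the library's real helper may differ:

import inspect
from typing import Any, Callable, Dict, List

def get_default_params(fn: Callable, exclude: List[str]) -> Dict[str, Any]:
    """Collect ``fn``'s parameters that have default values, skipping ``exclude``."""
    signature = inspect.signature(fn)
    return {
        name: parameter.default
        for name, parameter in signature.parameters.items()
        if parameter.default is not inspect.Parameter.empty and name not in exclude
    }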
Example #8
def process_components(
    model: Model,
    criterion: Criterion = None,
    optimizer: Optimizer = None,
    scheduler: Scheduler = None,
    distributed_params: Dict = None,
    device: Device = None,
) -> Tuple[Model, Criterion, Optimizer, Scheduler, Device]:
    """
    Returns the processed model, criterion, optimizer, scheduler and device

    Args:
        model (Model): torch model
        criterion (Criterion): criterion function
        optimizer (Optimizer): optimizer
        scheduler (Scheduler): scheduler
        distributed_params (dict, optional): dict with the parameters
            for distributed and FP16 methods
        device (Device, optional): device
    """
    distributed_params = distributed_params or {}
    distributed_params = copy.deepcopy(distributed_params)
    distributed_params.update(get_distributed_params())
    if device is None:
        device = utils.get_device()

    use_apex = distributed_params.pop("apex", True) and is_apex_available()

    model: Model = utils.maybe_recursive_call(model, "to", device=device)

    if utils.is_wrapped_with_ddp(model):
        pass
    # distributed data parallel run (ddp) (with apex support)
    elif get_rank() >= 0:
        assert isinstance(model, nn.Module), \
            "No support for dixtributed KV model yet"

        local_rank = distributed_params.pop("local_rank", 0)
        device = f"cuda:{local_rank}"
        model = utils.maybe_recursive_call(model, "to", device=device)

        syncbn = distributed_params.pop("syncbn", False)

        if use_apex:
            import apex
            model, optimizer = initialize_apex(model, optimizer,
                                               **distributed_params)
            model = apex.parallel.DistributedDataParallel(model)

            if syncbn:
                model = apex.parallel.convert_syncbn_model(model)
        else:
            model = nn.parallel.DistributedDataParallel(
                model, device_ids=[local_rank], output_device=local_rank)
    # data parallel run (dp) (with apex support)
    else:
        # apex issue https://github.com/deepset-ai/FARM/issues/210
        can_use_apex = \
            (use_apex and torch.cuda.device_count() == 1) \
            or (
                    torch.cuda.device_count() > 1
                    and distributed_params.get("opt_level", "O0") == "O1"
            )

        if can_use_apex:
            assert isinstance(model, nn.Module), \
                "No support for apex KV model yet"

            model, optimizer = initialize_apex(model, optimizer,
                                               **distributed_params)

        if torch.cuda.device_count() > 1:
            if isinstance(model, nn.Module):
                model = nn.DataParallel(model)
            elif isinstance(model, dict):
                model = {k: nn.DataParallel(v) for k, v in model.items()}

    model: Model = utils.maybe_recursive_call(model, "to", device=device)

    return model, criterion, optimizer, scheduler, device
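Examples #6 and #7 spell out the pattern that initialize_apex presumably wraps: apex.amp.initialize returns only the model when no optimizer is passed, and a (model, optimizer) pair otherwise. A sketch of such a wrapper, assuming apex is installed; the library's own initialize_apex may additionally filter its keyword arguments, as Example #7 does:

def initialize_apex(model, optimizer=None, **amp_kwargs):
    import apex

    # apex returns a single model or a (model, optimizer) pair depending on
    # whether an optimizer was passed in
    amp_result = apex.amp.initialize(model, optimizer, **amp_kwargs)
    if optimizer is not None:
        model, optimizer = amp_result
    else:
        model = amp_result
    return model, optimizer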