Example #1
    def __call__(self, module: torch.nn.Module) -> None:
        """
        Applies an initializer to all parameters in a module that match one of the regexes we were
        given in this object's constructor.  Does nothing to parameters that do not match.

        Parameters
        ----------
        module : torch.nn.Module, required.
            The Pytorch module to apply the initializers to.
        """
        logger.info("Initializing parameters")
        unused_regexes = set([initializer[0] for initializer in self._initializers])
        uninitialized_parameters = set()
        # Track which initializer regexes get used and which parameters fall back to their default initialization.
        for name, parameter in module.named_parameters():
            for initializer_regex, initializer in self._initializers:
                allow = self._prevent_regex is None or not bool(re.search(self._prevent_regex, name))
                if allow and re.search(initializer_regex, name):
                    logger.info("Initializing %s using %s intitializer", name, initializer_regex)
                    initializer(parameter)
                    unused_regexes.discard(initializer_regex)
                    break
            else:  # no break
                uninitialized_parameters.add(name)
        for regex in unused_regexes:
            logger.warning("Did not use initialization regex that was passed: %s", regex)
        logger.info("Done initializing parameters; the following parameters are using their "
                    "default initialization from their code")
        uninitialized_parameter_list = list(uninitialized_parameters)
        uninitialized_parameter_list.sort()
        for name in uninitialized_parameter_list:
            logger.info("   %s", name)
Example #2
def get_frozen_and_tunable_parameter_names(model: torch.nn.Module) -> List:
    frozen_parameter_names = []
    tunable_parameter_names = []
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad:
            frozen_parameter_names.append(name)
        else:
            tunable_parameter_names.append(name)
    return [frozen_parameter_names, tunable_parameter_names]
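A hypothetical usage sketch for the helper above; the model and the layer being frozen are illustrative only:

import torch

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 2))
for p in model[0].parameters():  # freeze the first layer
    p.requires_grad = False

frozen, tunable = get_frozen_and_tunable_parameter_names(model)
print("frozen:", frozen)    # ['0.weight', '0.bias']
print("tunable:", tunable)  # ['1.weight', '1.bias']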
Example #3
    def __call__(self, module: torch.nn.Module) -> torch.Tensor:
        """
        Parameters
        ----------
        module : torch.nn.Module, required
            The module to regularize.
        """
        accumulator = 0.0
        # For each parameter find the first matching regex.
        for name, parameter in module.named_parameters():
            for regex, regularizer in self._regularizers:
                if re.search(regex, name):
                    penalty = regularizer(parameter)
                    accumulator = accumulator + penalty
                    break

        return accumulator
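The surrounding class is not shown here; assuming `self._regularizers` is a list of (regex, regularizer) pairs, the same pattern can be sketched standalone with a simple L2 penalty (names are illustrative):

import re
import torch

regularizers = [(r"weight$", lambda p: 0.01 * p.pow(2).sum())]  # assumed structure

model = torch.nn.Linear(4, 2)
penalty = 0.0
for name, parameter in model.named_parameters():
    for regex, regularizer in regularizers:
        if re.search(regex, name):
            penalty = penalty + regularizer(parameter)
            break
# `penalty` would then be added to the training loss before calling backward().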
Example #4
def train(
    dataset: torch.utils.data.Dataset,
    model: torch.nn.Module,
    epochs: int,
    batch_size: int,
    optimizer: torch.optim.Optimizer,
    stopping_delta: Optional[float] = None,
    collate_fn=default_collate,
    cuda: bool = True,
    sampler: Optional[torch.utils.data.sampler.Sampler] = None,
    silent: bool = False,
    update_freq: int = 10,
    evaluate_batch_size: int = 1024,
    update_callback: Optional[Callable[[float, float, float], None]] = None,
    epoch_callback: Optional[Callable[[int, torch.nn.Module], None]] = None,
) -> None:
    """
    Train the DEC model given a dataset, a model instance and various configuration parameters.

    :param dataset: instance of Dataset to use for training
    :param model: instance of DEC model to train
    :param epochs: number of training epochs
    :param batch_size: size of the batch to train with
    :param optimizer: instance of optimizer to use
    :param stopping_delta: label delta as a proportion to use for stopping, None to disable, default None
    :param collate_fn: function to merge a list of samples into mini-batch
    :param cuda: whether to use CUDA, defaults to True
    :param sampler: optional sampler to use in the DataLoader, defaults to None
    :param silent: set to True to prevent printing out summary statistics, defaults to False
    :param update_freq: frequency of batches with which to update counter, None disables, default 10
    :param evaluate_batch_size: batch size for evaluation stage, default 1024
    :param update_callback: optional function called with accuracy, loss and label delta, default None
    :param epoch_callback: optional function of epoch and model, default None
    :return: None
    """
    static_dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=collate_fn,
        pin_memory=False,
        sampler=sampler,
        shuffle=False,
    )
    train_dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=collate_fn,
        sampler=sampler,
        shuffle=True,
    )
    data_iterator = tqdm(
        static_dataloader,
        leave=True,
        unit="batch",
        postfix={
            "epo": -1,
            "acc": "%.4f" % 0.0,
            "lss": "%.8f" % 0.0,
            "dlb": "%.4f" % -1,
        },
        disable=silent,
    )
    kmeans = KMeans(n_clusters=model.cluster_number, n_init=20)
    model.train()
    features = []
    actual = []
    # form initial cluster centres
    for index, batch in enumerate(data_iterator):
        if isinstance(batch, (tuple, list)) and len(batch) == 2:
            batch, value = batch  # if we have a prediction label, separate it to actual
            actual.append(value)
        if cuda:
            batch = batch.cuda(non_blocking=True)
        features.append(model.encoder(batch).detach().cpu())
    actual = torch.cat(actual).long()
    predicted = kmeans.fit_predict(torch.cat(features).numpy())
    predicted_previous = torch.tensor(np.copy(predicted), dtype=torch.long)
    _, accuracy = cluster_accuracy(predicted, actual.cpu().numpy())
    cluster_centers = torch.tensor(
        kmeans.cluster_centers_, dtype=torch.float, requires_grad=True
    )
    if cuda:
        cluster_centers = cluster_centers.cuda(non_blocking=True)
    with torch.no_grad():
        # initialise the cluster centers
        model.state_dict()["assignment.cluster_centers"].copy_(cluster_centers)
    loss_function = nn.KLDivLoss(size_average=False)
    delta_label = None
    for epoch in range(epochs):
        features = []
        data_iterator = tqdm(
            train_dataloader,
            leave=True,
            unit="batch",
            postfix={
                "epo": epoch,
                "acc": "%.4f" % (accuracy or 0.0),
                "lss": "%.8f" % 0.0,
                "dlb": "%.4f" % (delta_label or 0.0),
            },
            disable=silent,
        )
        model.train()
        for index, batch in enumerate(data_iterator):
            if isinstance(batch, (tuple, list)) and len(batch) == 2:
                batch, _ = batch  # if we have a prediction label, strip it away
            if cuda:
                batch = batch.cuda(non_blocking=True)
            output = model(batch)
            target = target_distribution(output).detach()
            loss = loss_function(output.log(), target) / output.shape[0]
            data_iterator.set_postfix(
                epo=epoch,
                acc="%.4f" % (accuracy or 0.0),
                lss="%.8f" % float(loss.item()),
                dlb="%.4f" % (delta_label or 0.0),
            )
            optimizer.zero_grad()
            loss.backward()
            optimizer.step(closure=None)
            features.append(model.encoder(batch).detach().cpu())
            if update_freq is not None and index % update_freq == 0:
                loss_value = float(loss.item())
                data_iterator.set_postfix(
                    epo=epoch,
                    acc="%.4f" % (accuracy or 0.0),
                    lss="%.8f" % loss_value,
                    dlb="%.4f" % (delta_label or 0.0),
                )
                if update_callback is not None:
                    update_callback(accuracy, loss_value, delta_label)
        predicted, actual = predict(
            dataset,
            model,
            batch_size=evaluate_batch_size,
            collate_fn=collate_fn,
            silent=True,
            return_actual=True,
            cuda=cuda,
        )
        delta_label = (
            float((predicted != predicted_previous).float().sum().item())
            / predicted_previous.shape[0]
        )
        if stopping_delta is not None and delta_label < stopping_delta:
            print(
                'Early stopping as label delta "%1.5f" less than "%1.5f".'
                % (delta_label, stopping_delta)
            )
            break
        predicted_previous = predicted
        _, accuracy = cluster_accuracy(predicted.cpu().numpy(), actual.cpu().numpy())
        data_iterator.set_postfix(
            epo=epoch,
            acc="%.4f" % (accuracy or 0.0),
            lss="%.8f" % 0.0,
            dlb="%.4f" % (delta_label or 0.0),
        )
        if epoch_callback is not None:
            epoch_callback(epoch, model)
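The helper `target_distribution` used in the loss computation above is not shown in this snippet; a sketch consistent with the DEC paper's auxiliary target, p_ij = (q_ij^2 / f_j) / sum_j'(q_ij'^2 / f_j') with f_j = sum_i q_ij, would be:

def target_distribution(batch: torch.Tensor) -> torch.Tensor:
    # batch: soft cluster assignments q of shape [batch_size, n_clusters]
    weight = (batch ** 2) / torch.sum(batch, 0)
    return (weight.t() / torch.sum(weight, 1)).t()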
Example #5
def predict(
        input_ids: torch.Tensor,
        token_type_ids: torch.Tensor,
        attention_masks: torch.Tensor,
        model: torch.nn.Module,
        batch_size: int,
        device_: Optional[str] = None,  # if None, it automatically detects if a GPU is available, if not uses a CPU
        disable_progress_bar: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Given a trained model and unseen data, performs the predictions and returns the results.
    Unlike in the fine-tuning and training stages, during prediction there's no need to build a dataloader which
    splits the set into train and validation, and randomly shuffles the training samples. We can just pass the items
    directly one by one. As we're not training, there are no training epochs either.
    :param input_ids: torch.tensor of shape (N, max_len) representing the ids of each token of the N encoded sequence
           pairs, with padding at the end up to max_len. If decoded, the input_ids will consist of a "[CLS]" token,
           followed by the question's tokens, followed by a "[SEP]" token, followed by the context's tokens, followed
           by a "[SEP]" token, followed by "[PAD]" tokens, if relevant, up to max_len.
    :param token_type_ids: torch.tensor of shape (N, max_len) where each Nth dimension is filled with 1 for token
           positions in the context text, 0 elsewhere (i.e. in question and padding)
    :param attention_masks: torch.tensor of shape (N, max_len) where each Nth dimension is filled with 1 for
           non-"[PAD]" tokens, 0 for "[PAD]" tokens.
    :param model: the model to use (must be instance of torch.nn.Module). As we're performing predictions, this must
           be a trained model.
    :param batch_size: the batch size to use for predictions. Batching samples speeds up processing.
    :param device_: if specified, the device used for the computations. Can be one of cpu, cuda, mkldnn, opengl,
           opencl, ideep, hip, msnpu. If set to None, it will default to GPU (cuda) if one is available, else it will
           use a CPU. Default: None
    :param disable_progress_bar: bool; whether to disable the tqdm progress bar. When used in production for quickly
           returning answers to a single or small set of questions, the bar might be distracting. Default: False.
    :return: pred_start: torch.tensor of shape (N) with the predicted indices of the first answer token for each answer
             pred_end: torch.tensor of shape (N) with the predicted indices of the last answer token for each answer
    """
    assert input_ids.shape == token_type_ids.shape == attention_masks.shape, "Some input shapes are wrong"

    device = set_hardware_acceleration(default=device_)
    model = model.to(device)
    model.eval()

    pred_start = torch.tensor([], dtype=torch.long, device=device)  # initialising tensors for storing results
    pred_end = torch.tensor([], dtype=torch.long, device=device)

    t_i = time()
    # batch the samples to speed up processing. We do batching manually here to avoid using DataLoader
    for batch_i in tqdm(range(0, len(input_ids), batch_size), disable=disable_progress_bar):
        # Slice the full tensors to build each batch manually and move it to the device.
        batch_input_ids = input_ids[batch_i:batch_i + batch_size].to(device)
        batch_token_type_ids = token_type_ids[batch_i:batch_i + batch_size].to(device)
        batch_attention_masks = attention_masks[batch_i:batch_i + batch_size].to(device)
        with torch.no_grad():
            # Forward pass without answer positions; assumes a HuggingFace-style QA model
            # (e.g. BertForQuestionAnswering) that returns the start/end logits first.
            outputs = model(
                input_ids=batch_input_ids,
                token_type_ids=batch_token_type_ids,
                attention_mask=batch_attention_masks,
            )
            start_logits, end_logits = outputs[0], outputs[1]

            pred_start_positions = torch.argmax(start_logits, dim=1)
            pred_end_positions = torch.argmax(end_logits, dim=1)

            pred_start = torch.cat((pred_start, pred_start_positions))
            pred_end = torch.cat((pred_end, pred_end_positions))
        if torch.cuda.is_available():
            logger.debug("GPU memory usage: \n", gpu_memory_usage())

    logger.info(f"All predictions calculated in {format_time(time() - t_i)}.")
    if torch.cuda.is_available():
        logger.info("GPU memory usage: \n", gpu_memory_usage())

    return pred_start, pred_end
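A hypothetical end-to-end usage of this predict function, assuming a HuggingFace-style tokenizer and question-answering model (which matches the start/end-logits handling above); in practice the model would be a fine-tuned checkpoint:

import torch
from transformers import BertTokenizerFast, BertForQuestionAnswering

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
model = BertForQuestionAnswering.from_pretrained("bert-base-uncased")  # use a fine-tuned checkpoint in practice

encoded = tokenizer(
    ["Who wrote Faust?"],                 # questions
    ["Faust was written by Goethe."],     # contexts
    padding="max_length", truncation=True, max_length=64, return_tensors="pt",
)
pred_start, pred_end = predict(
    input_ids=encoded["input_ids"],
    token_type_ids=encoded["token_type_ids"],
    attention_masks=encoded["attention_mask"],
    model=model,
    batch_size=8,
)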
Example #6
def get_default_optimizer_params(
    model: torch.nn.Module,
    base_lr: Optional[float] = None,
    weight_decay: Optional[float] = None,
    weight_decay_norm: Optional[float] = None,
    bias_lr_factor: Optional[float] = 1.0,
    weight_decay_bias: Optional[float] = None,
    overrides: Optional[Dict[str, Dict[str, float]]] = None,
) -> List[Dict[str, Any]]:
    """
    Get default param list for optimizer, with support for a few types of
    overrides. If no overrides needed, this is equivalent to `model.parameters()`.

    Args:
        base_lr: lr for every group by default. Can be omitted to use the one in optimizer.
        weight_decay: weight decay for every group by default. Can be omitted to use the one
            in optimizer.
        weight_decay_norm: override weight decay for params in normalization layers
        bias_lr_factor: multiplier of lr for bias parameters.
        weight_decay_bias: override weight decay for bias parameters
        overrides: if not `None`, provides values for optimizer hyperparameters
            (LR, weight decay) for module parameters with a given name; e.g.
            ``{"embedding": {"lr": 0.01, "weight_decay": 0.1}}`` will set the LR and
            weight decay values for all module parameters named `embedding`.

    For common detection models, ``weight_decay_norm`` is the only option that
    usually needs to be set. ``bias_lr_factor`` and ``weight_decay_bias`` are legacy
    settings from Detectron1 that have not been found useful.

    Example:
    ::
        torch.optim.SGD(get_default_optimizer_params(model, weight_decay_norm=0),
                       lr=0.01, weight_decay=1e-4, momentum=0.9)
    """
    if overrides is None:
        overrides = {}
    defaults = {}
    if base_lr is not None:
        defaults["lr"] = base_lr
    if weight_decay is not None:
        defaults["weight_decay"] = weight_decay
    bias_overrides = {}
    if bias_lr_factor is not None and bias_lr_factor != 1.0:
        # NOTE: unlike Detectron v1, we now by default make bias hyperparameters
        # exactly the same as regular weights.
        if base_lr is None:
            raise ValueError("bias_lr_factor requires base_lr")
        bias_overrides["lr"] = base_lr * bias_lr_factor
    if weight_decay_bias is not None:
        bias_overrides["weight_decay"] = weight_decay_bias
    if len(bias_overrides):
        if "bias" in overrides:
            raise ValueError("Conflicting overrides for 'bias'")
        overrides["bias"] = bias_overrides

    norm_module_types = (
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.SyncBatchNorm,
        # NaiveSyncBatchNorm inherits from BatchNorm2d
        torch.nn.GroupNorm,
        torch.nn.InstanceNorm1d,
        torch.nn.InstanceNorm2d,
        torch.nn.InstanceNorm3d,
        torch.nn.LayerNorm,
        torch.nn.LocalResponseNorm,
    )
    params: List[Dict[str, Any]] = []
    memo: Set[torch.nn.parameter.Parameter] = set()
    for module in model.modules():
        for module_param_name, value in module.named_parameters(recurse=False):
            if not value.requires_grad:
                continue
            # Avoid duplicating parameters
            if value in memo:
                continue
            memo.add(value)

            hyperparams = copy.copy(defaults)
            if isinstance(module, norm_module_types) and weight_decay_norm is not None:
                hyperparams["weight_decay"] = weight_decay_norm
            hyperparams.update(overrides.get(module_param_name, {}))
            params.append({"params": [value], **hyperparams})
    return reduce_param_groups(params)
Example #7
    def plot_attention(
        cls,
        model: torch.nn.Module,
        output_dir: Optional[Path],
        summary_writer: Optional[SummaryWriter],
        iterator: Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
        reporter: SubReporter,
        options: TrainerOptions,
    ) -> None:
        assert check_argument_types()
        import matplotlib

        ngpu = options.ngpu
        no_forward_run = options.no_forward_run

        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
        from matplotlib.ticker import MaxNLocator

        model.eval()
        for ids, batch in iterator:
            assert isinstance(batch, dict), type(batch)
            assert len(next(iter(batch.values()))) == len(ids), (
                len(next(iter(batch.values()))),
                len(ids),
            )
            batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
            if no_forward_run:
                continue

            # 1. Forwarding model and gathering all attentions
            #    calculate_all_attentions() uses single gpu only.
            att_dict = calculate_all_attentions(model, batch)

            # 2. Plot attentions: This part is slow due to matplotlib
            for k, att_list in att_dict.items():
                assert len(att_list) == len(ids), (len(att_list), len(ids))
                for id_, att_w in zip(ids, att_list):

                    if isinstance(att_w, torch.Tensor):
                        att_w = att_w.detach().cpu().numpy()

                    if att_w.ndim == 2:
                        att_w = att_w[None]
                    elif att_w.ndim > 3 or att_w.ndim == 1:
                        raise RuntimeError(
                            f"Must be 2 or 3 dimensions: {att_w.ndim}")

                    w, h = plt.figaspect(1.0 / len(att_w))
                    fig = plt.Figure(figsize=(w * 1.3, h * 1.3))
                    axes = fig.subplots(1, len(att_w))
                    if len(att_w) == 1:
                        axes = [axes]

                    for ax, aw in zip(axes, att_w):
                        ax.imshow(aw.astype(np.float32), aspect="auto")
                        ax.set_title(f"{k}_{id_}")
                        ax.set_xlabel("Input")
                        ax.set_ylabel("Output")
                        ax.xaxis.set_major_locator(MaxNLocator(integer=True))
                        ax.yaxis.set_major_locator(MaxNLocator(integer=True))

                    if output_dir is not None:
                        p = output_dir / id_ / f"{k}.{reporter.get_epoch()}ep.png"
                        p.parent.mkdir(parents=True, exist_ok=True)
                        fig.savefig(p)

                    if summary_writer is not None:
                        summary_writer.add_figure(f"{k}_{id_}", fig,
                                                  reporter.get_epoch())

                    # A dummy register() call increments the reporter's counter
                    reporter.register({})
Example #8
def my_correct_bias(model: torch.nn.Module,
                    quant_params,
                    num_quant_samples: int,
                    data_loader,
                    num_bias_correct_samples: int,
                    conv_bn_dict=None,
                    perform_only_empirical_bias_corr: bool = True,
                    layers_to_ignore=None,
                    quantizer_modifications=None):

    if layers_to_ignore is None:
        layers_to_ignore = []

    # Find batch size and shape of input tensor
    batch_size, input_shape = aimet_torch.utils.get_input_shape_batch_size(
        data_loader)

    # Rounding up number of samples to batch size
    n_batches_bias_correction = int(
        np.ceil(num_bias_correct_samples / batch_size))
    n_batches_quantization = int(np.ceil(num_quant_samples / batch_size))

    data_loader_n_samples_bias_corr = aimet_torch.utils.IterFirstX(
        data_loader, n_batches_bias_correction)
    data_loader_n_samples_quant = aimet_torch.utils.IterFirstX(
        data_loader, n_batches_quantization)

    # TODO: Remove wrapper function
    # Create a wrapping function for data loader for quantization
    def pass_data_through_model(model,
                                early_stopping_iterations=None,
                                use_cuda=False):
        # pylint: disable=unused-argument
        # forward pass for given number of batches for model
        for (images_in_one_batch, _) in data_loader_n_samples_quant:
            aimet_torch.bias_correction.forward_pass(model,
                                                     images_in_one_batch)

    ordered_conv_linear_nodes = aimet_torch.utils.get_ordered_lists_of_conv_fc(
        model, input_shape)

    if conv_bn_dict is None:
        conv_bn_dict = aimet_torch.bias_correction.find_all_conv_bn_with_activation(
            model, input_shape)

    # Create a copy of the model as reference model
    model_copy = copy.deepcopy(model)

    # Add bias for all the layers whose bias is None
    for name, module in ordered_conv_linear_nodes:
        if module.bias is None:
            if isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
                output_size = module.out_channels
            elif isinstance(module, torch.nn.Linear):
                output_size = module.out_features
            module.bias = torch.nn.Parameter(torch.zeros(output_size))
            module.bias.data = module.bias.data.to(device=module.weight.device)

    # Quantize full model
    dummy_tensors = aimet_torch.utils.create_rand_tensors_given_shapes(
        input_shape)
    dummy_tensors = [
        tensor.to(aimet_torch.utils.get_device(model))
        for tensor in dummy_tensors
    ]
    q = aimet_torch.quantsim.QuantizationSimModel(
        model=model,
        quant_scheme=quant_params.quant_scheme,
        rounding_mode=quant_params.round_mode,
        default_output_bw=quant_params.act_bw,
        default_param_bw=quant_params.weight_bw,
        in_place=True,
        dummy_input=dummy_tensors,
        config_file=quant_params.config_file)

    # make sure the model was updated in-place before we use it for bias-correction updates
    assert (q.model is model)

    if quantizer_modifications is not None:
        quantizer_modifications(q)

    # Skip output activation quantization for all wrapped layers; layers_to_ignore is applied in the loop below
    for name, module in model.named_modules():
        # Skip all layer's output quantization
        if isinstance(module, QcQuantizeWrapper):
            module.output_quantizers[0].enabled = False

    q.compute_encodings(pass_data_through_model, None)

    # For first conv layer, perform analytical bc if perform_only_empirical_bias_corr is set to False
    # and layer is not marked to be ignored during bc.
    if not perform_only_empirical_bias_corr:
        module_name, module = ordered_conv_linear_nodes[0]
        if module not in layers_to_ignore:
            aimet_torch.bias_correction.logger.info(
                'Correcting layer %s using Analytical Bias Correction',
                module_name)
            quantize_layer = aimet_torch.utils.get_layer_by_name(
                model, module_name)
            aimet_torch.bias_correction.call_analytical_mo_correct_bias(
                quantize_layer, None, None)
            aimet_torch.bias_correction.logger.info(
                'Corrected bias for the layer')
            ordered_conv_linear_nodes.pop(0)

    for module_name, module in ordered_conv_linear_nodes:
        # Ignore all layers which are skipped by user
        if module in layers_to_ignore or module_name in layers_to_ignore:
            continue
        else:
            # make sure module is in the model used by qsim.
            assert (module in list(q.model.modules()))
            # Analytical Bias Correction is only done for Conv layers
            reference_layer = aimet_torch.utils.get_layer_by_name(
                model_copy, module_name)
            quantize_layer = aimet_torch.utils.get_layer_by_name(
                model, module_name)

            if module in conv_bn_dict.keys():

                bn_layer_info = conv_bn_dict[module]

                if perform_only_empirical_bias_corr or bn_layer_info is None or bn_layer_info.input_bn is None:
                    aimet_torch.bias_correction.logger.info(
                        'Correcting layer %s using Empirical Bias Correction',
                        module_name)
                    bias_correction = libpymo.BiasCorrection()

                    # Get output from quantized model and reference model

                    for images_in_one_batch, _ in data_loader_n_samples_bias_corr:
                        reference_output_batch = aimet_torch.bias_correction.get_output_data(
                            reference_layer, model_copy, images_in_one_batch)
                        quantized_model_output_batch = aimet_torch.bias_correction.get_output_data(
                            quantize_layer, model, images_in_one_batch)

                        if isinstance(reference_layer, torch.nn.Linear):
                            extended_shape = np.concatenate(
                                (reference_output_batch.shape, np.array([1,
                                                                         1])))
                            reference_output_batch = reference_output_batch.reshape(
                                extended_shape)
                            quantized_model_output_batch = quantized_model_output_batch.reshape(
                                extended_shape)

                        bias_correction.storePreActivationOutput(
                            reference_output_batch)
                        bias_correction.storeQuantizedPreActivationOutput(
                            quantized_model_output_batch)

                    aimet_torch.bias_correction.call_empirical_mo_correct_bias(
                        module, bias_correction)

                else:
                    aimet_torch.bias_correction.logger.info(
                        'Correcting layer %s using Analytical Bias Correction',
                        module_name)
                    aimet_torch.bias_correction.call_analytical_mo_correct_bias(
                        quantize_layer, bn_layer_info.input_bn,
                        bn_layer_info.in_activation_type)

                aimet_torch.bias_correction.logger.info(
                    'Corrected bias for the layer')

    aimet_torch.save_utils.SaveUtils.remove_quantization_wrappers(model)

    aimet_torch.bias_correction.logger.info('Completed bias correction')
Example #9
def _prepare_model(model: torch.nn.Module) -> torch.nn.Module:
    # Remove last layer
    layers_without_fc = list(model.children())[:-1]
    return torch.nn.Sequential(*layers_without_fc)
Example #10
def _find_module(root: torch.nn.Module, m: torch.nn.Module):
    for n, p in root.named_modules():
        if m is p:
            return n
    raise NameError('module is not installed as a submodule')
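A small illustrative usage of the lookup above:

import torch

root = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
print(_find_module(root, root[1]))  # -> "1", the submodule's registered name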
Example #11
def add_auto_convert(module: torch.nn.Module) -> torch.nn.Module:
    def convert_to_dispatch_proxy(x):
        if isinstance(x, torch.Tensor):
            return x.as_subclass(
                QuantizationConvertTensorProxy)  # type: ignore[arg-type]
        else:
            return x

    module_id_to_fqn: Dict[int, str] = {}
    # Counter for global quantizeable ops, useful for intermediate activation
    # logging.
    global_op_idx = [0]

    global_disable_torch_function_override = False

    class QuantizationConvertTensorProxy(torch.Tensor):
        """
        An override of `torch.Tensor` to enable dynamic dispatch for
        quantization inference.

        For each function with a `__torch_function__` override, this proxy does
        the following for functions which need quantization:

        1. calls `_auto_quant_state.validate_cur_op` to validate that
           the currently seen op is the same as what was recorded during tracing
        2. calls `_auto_quant_state.op_convert_before_hook`.
        3. executes the function, with target, args and kwargs possibly modified
           by (2)
        4. calls `_auto_quant_state.op_convert_after_hook`.
        5. calls `_auto_quant_state.mark_cur_op_complete` to increment
           the current op index in preparation for the next op

        Otherwise, calls the original function.
        """
        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            nonlocal global_disable_torch_function_override
            if (
                    # global override means disable the override here
                    global_disable_torch_function_override or
                    # to prevent printing things from going into an infinite loop
                    func == torch.Tensor.__repr__ or
                    # we don't need to override getters in this framework
                    func.__name__ == '__get__'):
                return super().__torch_function__(func, types, args, kwargs)

            kwargs = kwargs if kwargs else {}
            # if we are in a function, the current module is always a parent
            parent_module = cur_module
            hook_type = get_torch_function_hook_type(parent_module, func)

            if enable_logging:
                fqn_for_logging = module_id_to_fqn.get(
                    id(parent_module), 'unknown') if parent_module else None
                logger.debug(
                    f" fqn:{fqn_for_logging} _tf_ {func} " +
                    f"hook_type {hook_type} " +
                    # f"arg_types {[type(arg) for arg in args]}) " +
                    f"arg_dtypes {[arg.dtype if isinstance(arg, torch.Tensor) else None for arg in args]}"
                )

            if hook_type is HookType.OP_HOOKS:
                qstate: AutoQuantizationState = parent_module._auto_quant_state  # type: ignore[union-attr]
                # before hooks
                qstate.validate_cur_op(func)
                func, args, kwargs = qstate.op_convert_before_hook(
                    func, args, kwargs,
                    parent_module)  # type: ignore[arg-type]

                # forward
                output = super().__torch_function__(func, types, args, kwargs)
                # after hooks
                output = qstate.op_convert_after_hook(func, output,
                                                      global_op_idx)
                qstate.mark_cur_op_complete(func)

            elif hook_type is HookType.ARG_DEQUANTS:
                # TODO(future PR): handle more dtypes
                new_args = []
                for arg in args:
                    if isinstance(arg, torch.Tensor) and arg.is_quantized:
                        new_args.append(arg.dequantize())
                    else:
                        new_args.append(arg)
                args = tuple(new_args)
                output = super().__torch_function__(func, types, args, kwargs)

            else:  # HookType.NONE
                output = super().__torch_function__(func, types, args, kwargs)

            # TODO: is this right? Don't really understand this
            if output is NotImplemented:
                with torch._C.DisableTorchFunction():
                    output = func(
                        *args,
                        **kwargs).as_subclass(QuantizationConvertTensorProxy)
                assert output is not NotImplemented

            if enable_logging:
                fqn_for_logging = module_id_to_fqn.get(
                    id(parent_module), 'unknown') if parent_module else None
                out_dtype = None
                if isinstance(output, torch.Tensor):
                    out_dtype = output.dtype
                logger.debug(
                    f" fqn:{fqn_for_logging} _tf_ {func} out {out_dtype} end")

            return output

        def __repr__(self):
            return f'QuantizationConvertTensorProxy({super().__repr__()})'

    cur_module = None
    module_stack: List[torch.nn.Module] = []

    assert len(module.__class__.__bases__) == 1

    class QuantizationDispatchModule(module.__class__.__bases__[0]
                                     ):  # type: ignore[name-defined]
        """
        An override of user defined subclass of `nn.Module` to enable
        dynamic tracing for quantization, after model conversion
        to quantized domain.

        `cur_module` keeps track of the current module in the stack.

        Tensor arguments are converted to `QuantizationConvertTensorProxy`.

        We override the `__call__` function to do the following for each
        module:

        If the module is an op which needs quantization:

        1. calls `_auto_quant_state.validate_cur_op` to validate that
           the currently seen op is the same as what was recorded during tracing
        2. calls parent module's `._auto_quant_state.op_convert_before_hook`
        3. executes the original module forward
        4. calls parent module's `_auto_quant_state.op_convert_after_hook`
        5. calls `_auto_quant_state.mark_cur_op_complete` to increment
           the current op index in preparation for the next op

        If the module can contain children ops that need quantization:

        1. calls `_auto_quant_state.inputs_convert_hook` (not implemented yet)
        2. executes the original module forward
        3. calls `_auto_quant_state.outputs_convert_hook`

        Otherwise, calls the original module forward.
        """
        def __call__(self, *args, **kwargs):
            new_args = map_aggregate(args, convert_to_dispatch_proxy)
            new_kwargs = map_aggregate(kwargs, convert_to_dispatch_proxy)
            orig_module_call = torch.nn.Module.__call__
            orig_nn_sequential_forward = torch.nn.Sequential.forward

            def _patched_module_call(self, *args, **kwargs):
                nonlocal cur_module
                old_module = cur_module
                cur_module = self
                nonlocal global_disable_torch_function_override
                try:
                    parent_module = module_stack[-1] if len(
                        module_stack) else None
                    module_stack.append(self)
                    hook_type = get_module_hook_type(parent_module, cur_module)
                    if enable_logging:
                        fqn_for_logging = module_id_to_fqn.get(id(self), None)
                        logger.debug(
                            f" fqn: {fqn_for_logging} " +
                            f"_cl_ {type(self)} " +
                            f"arg_dtypes {[arg.dtype if isinstance(arg, torch.Tensor) else None for arg in args]} "
                            + f"hook_type {hook_type}")

                    if hook_type is HookType.OP_HOOKS:
                        # before hooks
                        qstate: AutoQuantizationState = \
                            parent_module._auto_quant_state  # type: ignore[union-attr, assignment]
                        qstate.validate_cur_op(cur_module)

                        # If we are in this hook, `cur_module` is a leaf module.
                        # Therefore, we do not need to override any of its
                        # children. Disabling the overrides for performance.
                        old_global_disable_torch_function_override = \
                            global_disable_torch_function_override
                        global_disable_torch_function_override = True

                        _, args, kwargs = qstate.op_convert_before_hook(
                            cur_module, args, kwargs, cur_module)
                        # forward
                        output = orig_module_call(self, *args, **kwargs)
                        # after hooks
                        output = qstate.op_convert_after_hook(
                            cur_module, output, global_op_idx)

                        # Re-enable the override.
                        global_disable_torch_function_override = \
                            old_global_disable_torch_function_override

                        qstate.mark_cur_op_complete(cur_module)

                    elif hook_type is HookType.MODULE_IO_HOOKS:
                        cur_qstate: AutoQuantizationState = cur_module._auto_quant_state

                        cur_qstate.reset_to_new_call()

                        # before hooks (TODO)
                        # forward
                        output = orig_module_call(self, *args, **kwargs)
                        # after hooks

                        # For the sake of performance, we assume no overrides
                        # are needed for quantizing/dequantizing things
                        old_global_disable_torch_function_override = \
                            global_disable_torch_function_override
                        global_disable_torch_function_override = True

                        output = cur_qstate.outputs_convert_hook(output)

                        global_disable_torch_function_override = \
                            old_global_disable_torch_function_override

                        cur_qstate.validate_is_at_last_seen_idx()

                    elif hook_type is HookType.ARG_DEQUANTS:
                        # TODO(future PR): handle more dtypes
                        new_args = []
                        for arg in args:
                            if isinstance(arg,
                                          torch.Tensor) and arg.is_quantized:
                                dequant = arg.dequantize().as_subclass(
                                    QuantizationConvertTensorProxy
                                )  # type: ignore[arg-type]
                                new_args.append(dequant)
                            else:
                                new_args.append(arg)
                        args = tuple(new_args)
                        output = orig_module_call(self, *args, **kwargs)

                    else:
                        output = orig_module_call(self, *args, **kwargs)

                    if enable_logging:
                        fqn_for_logging = module_id_to_fqn.get(id(self), None)
                        logger.debug(
                            f" fqn: {fqn_for_logging} " +
                            f"_cl_ {type(self)} " +
                            f"dtype {output.dtype if isinstance(output, torch.Tensor) else None} "
                            + "end")
                    return output
                finally:
                    module_stack.pop()
                    cur_module = old_module

            torch.nn.Module.__call__ = _patched_module_call
            torch.nn.Sequential.forward = _nn_sequential_patched_forward  # type: ignore[assignment]

            try:
                global_op_idx[0] = 0
                output = super().__call__(*new_args, **new_kwargs)

                def unwrap_proxy(a):
                    if isinstance(a, QuantizationConvertTensorProxy):
                        a.__class__ = torch.Tensor  # type: ignore[assignment]
                    return a

                output = map_aggregate(output, unwrap_proxy)
                return output
            finally:
                torch.nn.Module.__call__ = orig_module_call
                torch.nn.Sequential.forward = orig_nn_sequential_forward  # type: ignore[assignment]

        def rewrite_for_scripting(self):
            return auto_trace_rewriter.rewrite_for_scripting(self)

    pack_weights_for_functionals(module)
    attach_scale_zp_values_to_model(module)
    attach_op_convert_info_to_model(module)
    attach_output_convert_info_to_model(module)

    # Since eager mode convert could have changed the IDs of some modules,
    # populate the FQN map again
    for k, v in module.named_modules():
        module_id_to_fqn[id(v)] = k

    module.__class__ = QuantizationDispatchModule

    return module
Example #12
def add_auto_observation(
    model: torch.nn.Module,
    qconfig_dict: Dict[str, Any],
    example_inputs: Tuple[Any],
    input_dtypes: Any = (
        torch.float, ),  # must be same structure as model inputs
    output_dtypes: Any = (
        torch.float, ),  # must be same structure as model outputs
    prepare_custom_config_dict: Dict[str, Any] = None,
) -> torch.nn.Module:
    def convert_to_interception_proxy(x):
        if isinstance(x, torch.Tensor):
            return x.as_subclass(
                QuantizationPrepareTensorProxy)  # type: ignore[arg-type]
        else:
            return x

    cur_module = None
    first_call = True
    module_stack: List[torch.nn.Module] = []
    # Counter for tensor IDs, will be modified inplace by quant state.
    # This is used to track tensors from output ops to input ops. For example,
    # if op_n had a tensor output with id=1, and op_n+2 had a tensor input with
    # id=1, we know that the output of op_n is the input to op_n+2. Note,
    # this is a list because it needs to be incremented in place.
    qtensor_id = [0]
    module_id_to_fqn: Dict[int, str] = {}

    # Counter for global quantizeable ops, useful for intermediate activation
    # logging.
    global_op_idx = [0]

    global_disable_torch_function_override = False

    class QuantizationPrepareTensorProxy(torch.Tensor):
        """
        An override of `torch.Tensor` to enable dynamic tracing for
        quantization.

        For each function with a `__torch_function__` override, this proxy does
        the following for functions which need quantization:

        1. calls `_auto_quant_state.validate_cur_op` to validate that
           the currently seen op is the same as what was recorded during tracing
        2. calls `_auto_quant_state.op_prepare_before_hook`
        3. executes the original function
        4. calls `_auto_quant_state.op_prepare_after_hook`
        5. calls `_auto_quant_state.mark_cur_op_complete` to increment
           the current op index in preparation for the next op

        Otherwise, calls the original function.
        """
        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            nonlocal global_disable_torch_function_override
            if (
                    # global override means disable the override here
                    global_disable_torch_function_override or
                    # to prevent printing things from going into an infinite loop
                    func == torch.Tensor.__repr__ or
                    # we don't need to override getters in this framework
                    func.__name__ == '__get__'):
                return super().__torch_function__(func, types, args, kwargs)

            # if we are in a function, the current module is always a parent
            nonlocal cur_module
            parent_module = cur_module
            if enable_logging:
                if not is_activation_post_process(parent_module):
                    # logging for insides of obs/fq is not useful for this framework

                    # fqn map does not contain observers, which is why we
                    # cannot always assume that FQN exists
                    fqn_for_logging = module_id_to_fqn.get(
                        id(parent_module),
                        'unknown') if parent_module else None
                    logger.debug(
                        f' fqn:{fqn_for_logging} _tf_ {str(func)} len_args {len(args)}'
                    )

            nonlocal qtensor_id
            kwargs = kwargs if kwargs else {}
            hook_type = get_torch_function_hook_type(parent_module, func)

            if first_call and hook_type is not HookType.OP_HOOKS:
                qstate = getattr(parent_module, '_auto_quant_state', None)
                if qstate:
                    qstate.add_seen_op_type_without_op_hooks(func)

            if hook_type is HookType.OP_HOOKS:
                qstate = parent_module._auto_quant_state  # type: ignore[attr-defined]
                fqn = module_id_to_fqn[id(
                    parent_module)] if parent_module else None
                if not first_call:
                    qstate.validate_cur_op(func)
                # run "before" hook
                args, kwargs = qstate.op_prepare_before_hook(
                    func, args, kwargs, first_call, qtensor_id, fqn,
                    parent_module)
                # forward
                output = super().__torch_function__(func, types, args, kwargs)
                # run "after" hook
                output = qstate.op_prepare_after_hook(func, output, args,
                                                      first_call, qtensor_id,
                                                      parent_module,
                                                      global_op_idx)
                qstate.mark_cur_op_complete(func)
            else:
                output = super().__torch_function__(func, types, args, kwargs)

            # TODO: is this right? Don't really understand this
            if output is NotImplemented:
                with torch._C.DisableTorchFunction():
                    output = func(
                        *args,
                        **kwargs).as_subclass(QuantizationPrepareTensorProxy)
                assert output is not NotImplemented

            return output

        def __repr__(self):
            return f'QuantizationPrepareTensorProxy({super().__repr__()})'

        # TODO(future PR): add other math overrides

    class QuantizationInterceptionModule(type(model)):  # type: ignore[misc]
        """
        An override of user defined subclass of `nn.Module` to enable
        dynamic tracing for quantization.

        `cur_module` keeps track of the current module in the stack.

        During the first call, an `AutoQuantizationState` object is created and
        attached to each non-leaf module which we need to check for
        quantizeable operations.

        We override the `__call__` function to do the following for each
        module:

        If the module is an op which needs quantization:

        1. calls `_auto_quant_state.validate_cur_op` to validate that
           the currently seen op is the same as what was recorded during tracing
        2. calls parent module's `._auto_quant_state.op_prepare_before_hook`
        3. executes the original module forward
        4. calls parent module's `_auto_quant_state.op_prepare_after_hook`
        5. calls `_auto_quant_state.mark_cur_op_complete` to increment
           the current op index in preparation for the next op

        If the module can contain children ops that need quantization:

        1. calls `_auto_quant_state.inputs_prepare_hook` (not implemented yet)
        2. executes the original module forward
        3. calls `_auto_quant_state.outputs_prepare_hook`

        Otherwise, calls the original module forward.
        """
        def __call__(self, *args, **kwargs):
            new_args = map_aggregate(args, convert_to_interception_proxy)
            new_kwargs = map_aggregate(kwargs, convert_to_interception_proxy)
            orig_module_call = torch.nn.Module.__call__
            orig_nn_sequential_forward = torch.nn.Sequential.forward

            def _patched_module_call(self, *args, **kwargs):

                if enable_logging:
                    fqn = module_id_to_fqn.get(id(self), None)
                    logger.debug(f" fqn:{fqn} _cl_: {type(self)} start")

                nonlocal cur_module
                old_module = cur_module
                cur_module = self
                try:
                    parent_module = module_stack[-1] if len(
                        module_stack) else None
                    module_stack.append(self)
                    fqn = module_id_to_fqn.get(id(self), None)

                    hook_type = get_module_hook_type(parent_module, cur_module)

                    if first_call and hook_type is not HookType.OP_HOOKS and \
                            parent_module is not None:
                        parent_qstate_fc = getattr(parent_module,
                                                   '_auto_quant_state', None)
                        if parent_qstate_fc:
                            parent_qstate_fc.add_seen_op_type_without_op_hooks(
                                type(cur_module))

                    if hook_type is HookType.OP_HOOKS:
                        parent_qstate: AutoQuantizationState = \
                            parent_module._auto_quant_state  # type: ignore[union-attr, assignment]
                        # before hooks
                        if not first_call:
                            parent_qstate.validate_cur_op(cur_module)

                        # If we are in this hook, `cur_module` is a leaf module.
                        # Therefore, we do not need to override any of its
                        # children. Disabling the overrides for performance.
                        nonlocal global_disable_torch_function_override
                        old_global_disable_torch_function_override = \
                            global_disable_torch_function_override
                        global_disable_torch_function_override = True

                        # mypy ignore is used instead of assert because this
                        # runs on every forward and assert has a performance cost
                        args, kwargs = parent_qstate.op_prepare_before_hook(
                            cur_module, args, kwargs, first_call, qtensor_id,
                            fqn, cur_module)  # type: ignore[arg-type]

                        # original forward
                        output = orig_module_call(self, *args, **kwargs)

                        # Re-enable the overrides.
                        global_disable_torch_function_override = \
                            old_global_disable_torch_function_override

                        # after hooks
                        # TODO is it correct to call_cur_module twice here?
                        output = parent_qstate.op_prepare_after_hook(
                            cur_module, output, args, first_call, qtensor_id,
                            cur_module, global_op_idx)
                        parent_qstate.mark_cur_op_complete(cur_module)

                    elif hook_type is HookType.MODULE_IO_HOOKS:
                        # TODO(future PR): add inputs io hook

                        cur_qstate = cur_module._auto_quant_state
                        cur_qstate.reset_to_new_call()

                        # original forward
                        output = orig_module_call(self, *args, **kwargs)

                        # after hooks
                        output = cur_qstate.outputs_prepare_hook(
                            output, first_call, qtensor_id)
                        cur_qstate.validate_is_at_last_seen_idx()

                    elif hook_type is HookType.ARG_DEQUANTS:
                        output = orig_module_call(self, *args, **kwargs)
                        # if this fp32 was inplace, make sure to set the output dtype
                        # back to torch.float
                        if hasattr(output, '_qtensor_info'):
                            del output._qtensor_info

                    else:
                        output = orig_module_call(self, *args, **kwargs)

                    if enable_logging:
                        fqn = module_id_to_fqn.get(id(self), None)
                        logger.debug(f" fqn:{fqn} _cl_: {type(self)} end")

                    return output
                finally:
                    module_stack.pop()
                    cur_module = old_module

            torch.nn.Module.__call__ = _patched_module_call
            torch.nn.Sequential.forward = _nn_sequential_patched_forward  # type: ignore[assignment]
            nonlocal first_call
            try:
                if first_call:
                    # Create a list before iterating because we are adding new
                    # named modules inside the loop.
                    named_modules = list(self.named_modules())

                    # Record module instances which are leaves or children of leaves
                    leaves = set()
                    for fqn, child in named_modules:
                        if is_leaf(child, prepare_custom_config_dict):
                            for _, child_child in child.named_modules():
                                leaves.add(child_child)

                    for fqn, v in named_modules:

                        # fqn is the global FQN, i.e. 'foo.bar.baz'
                        # v is the module instance
                        #
                        # we need to associate the global FQN with SeenOp
                        # for modules, this is the module FQN
                        # for functions, this is the parent module FQN
                        module_id_to_fqn[id(v)] = fqn

                        if v in leaves:
                            continue

                        if v is self:
                            # for the top level module only, specify input
                            # and output dtypes
                            v._auto_quant_state = AutoQuantizationState(
                                qconfig_dict, fqn, input_dtypes, output_dtypes)
                            pass
                        else:
                            v._auto_quant_state = AutoQuantizationState(
                                qconfig_dict, fqn)

                global_op_idx[0] = 0

                output = super().__call__(*new_args, **new_kwargs)
                return output
            finally:
                torch.nn.Module.__call__ = orig_module_call
                torch.nn.Sequential.forward = orig_nn_sequential_forward  # type: ignore[assignment]
                first_call = False

    model.__class__ = QuantizationInterceptionModule
    # create the graph
    trace_with_inputs(model, example_inputs)
    return model
Example #13
def to_device(module: t.nn.Module, use_cuda: bool = True):
    return module.cuda() if use_cuda and t.cuda.is_available() else module
Example #14
def plot_grad_flow(
    model: torch.nn.Module,
    lines: bool = True,
    alpha: float = 0.5,
    line_width: float = 1.0,
) -> None:
    """

    Plots the gradients flowing through different layers in the net during training.
    Can be used for checking for possible gradient vanishing / exploding problems.

    Usage: After loss.backwards(), use plot_grad_flow(model) to visualize the gradient flow of model

    :param model:
    :type model:
    :param lines:
    :type lines:
    :param alpha:
    :type alpha:
    :param line_width:
    :type line_width:"""
    assert 0.0 < alpha <= 1.0
    ave_grads = []
    max_grads = []
    layers = []

    for n, p in model.named_parameters():
        if p.requires_grad and ("bias" not in n):
            layers.append(n)
            grad_abs = p.grad.abs()
            ave_grads.append(grad_abs.mean())
            max_grads.append(grad_abs.max())

    if lines:
        pyplot.plot(max_grads, alpha=alpha, linewidth=line_width, color="r")
        pyplot.plot(ave_grads, alpha=alpha, linewidth=line_width, color="g")
    else:
        pyplot.bar(
            numpy.arange(len(max_grads)),
            max_grads,
            alpha=alpha,
            linewidth=line_width,
            color="r",
        )
        pyplot.bar(
            numpy.arange(len(max_grads)),
            ave_grads,
            alpha=alpha,
            linewidth=line_width,
            color="g",
        )

    pyplot.hlines(0, 0, len(ave_grads) + 1, linewidth=1, color="k")
    pyplot.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical")
    pyplot.xlim(left=0, right=len(ave_grads))
    max_g = max(max_grads)
    margin = max_g * 1.1
    pyplot.ylim(
        bottom=max_g - margin, top=margin
    )  # show the full range [0, max] with a ~10% margin on each side
    pyplot.xlabel("Layers")
    pyplot.ylabel("Gradient Magnitude")
    pyplot.title("Gradient Flow")
    pyplot.grid(True)
    pyplot.legend(
        [
            Line2D([0], [0], color="r", lw=4),
            Line2D([0], [0], color="g", lw=4),
            Line2D([0], [0], color="k", lw=4),
        ],
        ["max-gradient", "mean-gradient", "zero-gradient"],
    )
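A minimal usage sketch for plot_grad_flow (an illustrative addition, not part of the original example; the toy model and dummy data below are placeholders):

import torch

# Hypothetical toy setup to illustrate the intended call order:
# forward pass, loss.backward(), then plot_grad_flow(model).
model = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU(), torch.nn.Linear(16, 2))
inputs = torch.randn(4, 8)
targets = torch.randint(0, 2, (4,))
loss = torch.nn.functional.cross_entropy(model(inputs), targets)
loss.backward()
plot_grad_flow(model)  # draws per-layer mean and max gradient magnitudes
# pyplot.show()        # uncomment to display the figure interactively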
Beispiel #15
0
def check_same_model_params(model_a: torch.nn.Module, model_b: torch.nn.Module, message: str = "") -> None:
    for p_a, p_b in zip(model_a.parameters(), model_b.parameters()):
        assert torch.allclose(p_a, p_b, atol=1e-3), f"Model parameters differ\n{p_a} {p_b}\n" + message

    for b_a, b_b in zip(model_a.buffers(), model_b.buffers()):
        assert torch.allclose(b_a, b_b), f"Model buffers differ {b_a} - {b_b}\n" + message
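A small illustrative sketch of how check_same_model_params might be used (not from the original source): compare a model against a deep copy, then modify one of them.

import copy
import torch

model_a = torch.nn.Linear(4, 4)
model_b = copy.deepcopy(model_a)
check_same_model_params(model_a, model_b, "fresh copies should match")  # passes

# After re-initializing only model_a, the parameter check is expected to fail:
torch.nn.init.normal_(model_a.weight)
# check_same_model_params(model_a, model_b, "should now differ")  # would raise AssertionError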
Beispiel #16
0
def generate_param_groups(
    network: torch.nn.Module,
    layer_matches: Sequence[Callable],
    match_types: Sequence[str],
    lr_values: Sequence[float],
    include_others: bool = True,
):
    """
    Utility function to generate parameter groups with different LR values for an optimizer.
    The output parameter groups have the same order as the `layer_matches` functions.

    Args:
        network: source network to generate parameter groups from.
        layer_matches: a list of callable functions used to select or filter network layer groups.
            For the "select" type, the input is the `network` and the group's parameters are
            `select_func(network).parameters()`.
            For the "filter" type, the input is every item of `network.named_parameters()`
            and the group's parameters are
            `map(lambda x: x[1], filter(filter_func, network.named_parameters()))`.
        match_types: a list of tags to identify the matching type corresponding to the `layer_matches` functions,
            can be "select" or "filter".
        lr_values: a list of LR values corresponding to the `layer_matches` functions.
        include_others: whether to include the rest layers as the last group, default to True.

    It's mainly used to set different LR values for different network elements, for example:

    .. code-block:: python

        net = Unet(spatial_dims=3, in_channels=1, out_channels=3, channels=[2, 2, 2], strides=[1, 1, 1])
        print(net)  # print out network components to select expected items
        print(net.named_parameters())  # print out all the named parameters to filter out expected items
        params = generate_param_groups(
            network=net,
            layer_matches=[lambda x: x.model[0], lambda x: "2.0.conv" in x[0]],
            match_types=["select", "filter"],
            lr_values=[1e-2, 1e-3],
        )
        # the groups will be a list of dictionaries:
        # [{'params': <generator object Module.parameters at 0x7f9090a70bf8>, 'lr': 0.01},
        #  {'params': <filter object at 0x7f9088fd0dd8>, 'lr': 0.001},
        #  {'params': <filter object at 0x7f9088fd0da0>}]
        optimizer = torch.optim.Adam(params, 1e-4)

    """
    layer_matches = ensure_tuple(layer_matches)
    match_types = ensure_tuple_rep(match_types, len(layer_matches))
    lr_values = ensure_tuple_rep(lr_values, len(layer_matches))

    def _get_select(f):
        def _select():
            return f(network).parameters()

        return _select

    def _get_filter(f):
        def _filter():
            # should eventually generate a list of network parameters
            return map(lambda x: x[1], filter(f, network.named_parameters()))

        return _filter

    params = []
    _layers = []
    for func, ty, lr in zip(layer_matches, match_types, lr_values):
        if ty.lower() == "select":
            layer_params = _get_select(func)
        elif ty.lower() == "filter":
            layer_params = _get_filter(func)
        else:
            raise ValueError(f"unsupported layer match type: {ty}.")

        params.append({"params": layer_params(), "lr": lr})
        _layers.extend(list(map(id, layer_params())))

    if include_others:
        params.append({
            "params":
            filter(lambda p: id(p) not in _layers, network.parameters())
        })

    return params
Beispiel #17
0
def train_one_epoch(args,
                    model: torch.nn.Module,
                    criterion: torch.nn.Module,
                    dataloader: Iterable,
                    optimizer: torch.optim.Optimizer,
                    device: torch.device,
                    epoch: int,
                    max_norm: float = 0):
    model.train()
    criterion.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
    metric_logger.add_meter('grad_norm', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 50

    for samples, targets, support_images, support_class_ids, support_targets in metric_logger.log_every(dataloader, print_freq, header):

        # * Sample Support Categories;
        # * Filters Targets (only keep GTs within support categories);
        # * Samples Support Images and Targets
        targets, support_images, support_class_ids, support_targets = \
            sample_support_categories(args, targets, support_images, support_class_ids, support_targets)

        samples = samples.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        support_images = support_images.to(device)
        support_class_ids = support_class_ids.to(device)
        support_targets = [{k: v.to(device) for k, v in t.items()} for t in support_targets]

        outputs = model(samples, targets=targets, supp_samples=support_images, supp_class_ids=support_class_ids, supp_targets=support_targets)
        loss_dict = criterion(outputs)
        weight_dict = criterion.weight_dict
        losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        loss_dict_reduced_unscaled = {f'{k}_unscaled': v for k, v in loss_dict_reduced.items()}
        loss_dict_reduced_scaled = {k: v * weight_dict[k] for k, v in loss_dict_reduced.items() if k in weight_dict}
        losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())

        loss_value = losses_reduced_scaled.item()

        if not math.isfinite(loss_value):
            print("Loss is NaN - {}. \nTraining terminated unexpectedly.\n".format(loss_value))
            print("loss dict:")
            print(loss_dict_reduced)
            sys.exit(1)

        optimizer.zero_grad()
        losses.backward()
        if max_norm > 0:
            grad_total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        else:
            grad_total_norm = utils.get_total_grad_norm(model.parameters(), max_norm)
        optimizer.step()

        metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled)
        metric_logger.update(class_error=loss_dict_reduced['class_error'])
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        metric_logger.update(grad_norm=grad_total_norm)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)

    del support_images
    del support_class_ids
    del support_targets
    del samples
    del targets
    del outputs
    del weight_dict
    del grad_total_norm
    del loss_value
    del losses
    del loss_dict
    del loss_dict_reduced
    del loss_dict_reduced_scaled
    del loss_dict_reduced_unscaled

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
Beispiel #18
0
def export_gradient_graph(model: torch.nn.Module,
                          loss_fn: Callable[[Any, Any], Any],
                          example_input: torch.Tensor,
                          example_labels: torch.Tensor,
                          gradient_graph_path: Union[Path, str],
                          opset_version=12) -> None:
    r"""
    Build a gradient graph for `model` so that you can output gradients in an inference session when given specific input and corresponding labels.

    Args:
        model (torch.nn.Module): A gradient graph will be built for this model.

        loss_fn (Callable[[Any, Any], Any]): A function to compute the loss given the model's output and the `example_labels`.
        Predefined loss functions such as `torch.nn.CrossEntropyLoss()` will work, but you might not be able to load the resulting graph in other environments (e.g. an InferenceSession in ONNX Runtime Web); in that case, use a custom Python function instead.

        example_input (torch.Tensor): Example input that you would give your model for inference/prediction.

        example_labels (torch.Tensor): The expected labels for `example_input`.
        This could be the output of your model for `example_input`, but it may differ if your loss function expects labels in a different format (e.g. class indices for cross entropy loss).

        gradient_graph_path (Union[Path, str]): The path to where you would like to save the gradient graph.

        opset_version (int): See `torch.onnx.export`.
    """

    # Make sure that loss nodes that expect multiple outputs are set up.
    CustomOpSymbolicRegistry.register_all()

    if not isinstance(gradient_graph_path, str):
        gradient_graph_path = str(gradient_graph_path)

    class WrapperModule(torch.nn.Module):
        def forward(self, model_input, expected_labels, *model_params):
            for param, set_param in zip(model.parameters(), model_params):
                param.data = set_param.data
            output = model(model_input)
            loss = loss_fn(output, expected_labels)
            return output, loss

    wrapped_model = WrapperModule()

    dynamic_axes = {
        'input': {
            0: 'batch_size',
        },
        'labels': {
            0: 'batch_size',
        },
        'output': {
            0: 'batch_size',
        },
    }

    args = (example_input, example_labels, *tuple(model.parameters()))
    model_param_names = tuple(name for name, _ in model.named_parameters())
    input_names = ['input', 'labels', *model_param_names]
    nodes_needing_gradients = set(name
                                  for name, param in model.named_parameters()
                                  if param.requires_grad)

    f = io.BytesIO()
    torch.onnx.export(wrapped_model,
                      args,
                      f,
                      export_params=True,
                      opset_version=opset_version,
                      do_constant_folding=False,
                      training=TrainingMode.TRAINING,
                      input_names=input_names,
                      output_names=['output', 'loss'],
                      dynamic_axes=dynamic_axes)

    exported_model = f.getvalue()
    builder = GradientGraphBuilder(exported_model, {'loss'},
                                   nodes_needing_gradients, 'loss')
    builder.build()
    builder.save(gradient_graph_path)
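A hedged usage sketch for export_gradient_graph (illustrative only: the tiny model, tensors, and output file name are placeholders, and running it requires the ONNX Runtime training utilities imported by the original module):

import torch

model = torch.nn.Linear(10, 3)
loss_fn = torch.nn.CrossEntropyLoss()
example_input = torch.randn(2, 10)
example_labels = torch.tensor([0, 2])

export_gradient_graph(model, loss_fn, example_input, example_labels,
                      "gradient_graph.onnx")  # hypothetical output path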
Beispiel #19
0
def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
                    amp: bool = True, teacher_model: torch.nn.Module = None,
                    teach_loss: torch.nn.Module = None, distill_token: bool=False, choices=None, mode='super', retrain_config=None):
    model.train()
    criterion.train()

    # set random seed
    random.seed(epoch)

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10
    if mode == 'retrain':
        config = retrain_config
        model_module = unwrap_model(model)
        print(config)
        model_module.set_sample_config(config=config)
        print(model_module.get_sampled_params_numel(config))

    for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        # sample random config
        if mode == 'super':
            config = sample_configs(choices=choices)
            model_module = unwrap_model(model)
            model_module.set_sample_config(config=config)
        elif mode == 'retrain':
            config = retrain_config
            model_module = unwrap_model(model)
            model_module.set_sample_config(config=config)
        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)
        if amp:
            with torch.cuda.amp.autocast():
                if teacher_model:
                    with torch.no_grad():
                        teach_output = teacher_model(samples)
                    _, teacher_label = teach_output.topk(1, 1, True, True)
                    if distill_token:
                        output_cls, output_dis = model(samples)
                        loss = 1/2 * criterion(output_cls, targets) + 1/2 * teach_loss(output_dis, teacher_label.squeeze())
                    else:
                        outputs = model(samples)
                        loss = 1/2 * criterion(outputs, targets) + 1/2 * teach_loss(outputs, teacher_label.squeeze())
                else:
                    outputs = model(samples)
                    loss = criterion(outputs, targets)
        else:
            outputs = model(samples)
            if teacher_model:
                with torch.no_grad():
                    teach_output = teacher_model(samples)
                _, teacher_label = teach_output.topk(1, 1, True, True)
                loss = 1 / 2 * criterion(outputs, targets) + 1 / 2 * teach_loss(outputs, teacher_label.squeeze())
            else:
                loss = criterion(outputs, targets)

        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        optimizer.zero_grad()

        # this attribute is added by timm on one optimizer (adahessian)
        if amp:
            is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
            loss_scaler(loss, optimizer, clip_grad=max_norm,
                    parameters=model.parameters(), create_graph=is_second_order)
        else:
            loss.backward()
            optimizer.step()

        torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)

        metric_logger.update(loss=loss_value)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
Beispiel #20
0
def generic_loop(epochs: int,
                 device: torch.device,
                 opt: torch.optim.Optimizer,
                 loss_fn: Callable,
                 model: torch.nn.Module,
                 train_fact_iter,
                 valid_fact_iter,
                 weight_decay: float = 0.0,
                 clip_grads_at: float = -1.0,
                 lr_schedule=None,
                 eval_fn: Callable = None) -> (list, list, list, list):
    train_loss = []
    train_acc = []
    val_acc = []
    lrs = []

    # Epoch level
    for e in range(epochs):

        per_epoch_loss = []
        per_epoch_tr_acc = []

        # Train
        with Timer() as timer:

            # Make data
            #             trn_dl, val_dl = data_fn(data['train']), data_fn(data['valid'])
            #             trn_dl = train_fact_iter
            for x, y in tqdm(train_fact_iter):

                #                 if batch_start_hook: batch_start_hook()
                opt.zero_grad()

                if lr_schedule: lrs.append(update_lr(opt, lr_schedule.get()))

                _x = torch.tensor(x, dtype=torch.long, device=device)
                _y = torch.tensor(y, dtype=torch.long, device=device)
                try:
                    y_pred = model(_x, _y)
                except Exception:
                    traceback.print_exc()
                    return (_x, _y)
                try:
                    loss = loss_fn(y_pred=y_pred, y_true=_y)
                except Exception:
                    print(y_pred.shape, _y.shape)
                    return (_x, _y)
                per_epoch_tr_acc.append(
                    eval_fn(y_pred=y_pred, y_true=_y).item())
                per_epoch_loss.append(loss.item())

                loss.backward()

                if clip_grads_at > 0.0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   clip_grads_at)
                #                 for group in opt.param_groups:
                #                     for param in group['params']:
                #                         param.data = param.data.add(-weight_decay * group['lr'], param.data)

                opt.step()

        # Val
        with torch.no_grad():

            per_epoch_vl_acc = []
            for x, y in tqdm(valid_fact_iter):
                _x = torch.tensor(x, dtype=torch.long, device=device)
                _y = torch.tensor(y, dtype=torch.long, device=device)

                model.eval()

                y_pred = model(_x, _y)
                loss = loss_fn(y_pred=y_pred, y_true=_y)
                per_epoch_vl_acc.append(
                    eval_fn(y_pred=y_pred, y_true=_y).item())
                model.train()
        #                 per_epoch_vl_acc.append(loss.item())

        # Bookkeep
        #         per_epoch_vl_acc = [0] # @TODO:Remove this once we start calculating accuracy.
        train_acc.append(np.mean(per_epoch_tr_acc))
        train_loss.append(np.mean(per_epoch_loss))
        val_acc.append(np.mean(per_epoch_vl_acc))

        print(
            "Epoch: %(epo)03d | Loss: %(loss).5f | Tr_c: %(tracc)0.5f | Vl_c: %(vlacc)0.5f | Time: %(time).3f min"
            % {
                'epo': e,
                'loss': float(np.mean(per_epoch_loss)),
                'tracc': float(np.mean(per_epoch_tr_acc)),
                'vlacc': float(np.mean(per_epoch_vl_acc)),
                'time': timer.interval / 60.0
            })

    return train_acc, train_loss, val_acc, lrs
Beispiel #21
0
def save_samples(generator: torch.nn.Module,
                 cp_name: str,
                 cuda_mode: bool,
                 prefix: str,
                 save_dir='./',
                 nc=3,
                 im_size=64,
                 fig_size=(5, 5),
                 enhance=True,
                 SNGAN=False):
    generator.eval()

    n_tests = fig_size[0] * fig_size[1]

    to_pil = transforms.ToPILImage()
    to_tensor = transforms.ToTensor()

    if SNGAN:
        noise = torch.randn(n_tests, 128).view(-1, 128, 1, 1)
    else:
        noise = torch.randn(n_tests, 100).view(-1, 100, 1, 1)

    if cuda_mode:
        noise = noise.cuda()

    gen_image = generator(noise).view(-1, nc, im_size, im_size)
    gen_image = denorm(gen_image)

    #n_rows = np.sqrt(noise.size()[0]).astype(np.int32)
    #n_cols = np.sqrt(noise.size()[0]).astype(np.int32)
    n_cols, n_rows = fig_size
    fig, axes = plt.subplots(n_cols, n_rows, figsize=(n_rows, n_cols))
    for ax, img in zip(axes.flatten(), gen_image):
        ax.axis('off')
        ax.set_adjustable('box')  # 'box-forced' was removed in newer matplotlib versions

        img = img.cpu().data

        if enhance:
            img_E = ImageEnhance.Sharpness(to_pil(img)).enhance(1.)
            img = to_tensor(img_E)

        # Scale to 0-255
        img = (((img - img.min()) * 255) /
               (img.max() - img.min())).numpy().transpose(1, 2, 0).astype(
                   np.uint8).squeeze()
        # ax.imshow(img.cpu().detach().view(image_size, image_size, 3).numpy(), cmap=None, aspect='equal')

        if nc == 1:
            ax.imshow(img, cmap="gray", aspect='equal')
        else:
            ax.imshow(img, cmap=None, aspect='equal')

    plt.subplots_adjust(wspace=0, hspace=0)
    #title = 'Samples'
    #fig.text(0.5, 0.04, title, ha='center')

    # save figure

    if not os.path.exists(save_dir):
        os.mkdir(save_dir)
    save_fn = os.path.join(save_dir, prefix + '_' + cp_name + '.pdf')
    plt.savefig(save_fn)

    plt.close()
Beispiel #22
0
def plot_predictions(model: torch.nn.Module,
                     x: tensor,
                     y: tensor,
                     num_samples: int = 1,
                     shade: bool = False,
                     x_names: List = None,
                     y_names: List = None,
                     axes=None,
                     fig=None,
                     save_filename=None):
    """
    Plot the predictions and the confidence interval
    :param x:
    :param y:
    :param: transform: Transform jointly X, Y (usually coming from a same dataset)
    :param: features_names: List of names for the features
    :return:
    """

    # Plot the prediction of the GP
    mean_cond, cov_cond = model.predict(x)

    # Sample some predictions
    if num_samples <= 0:
        samples = None
    else:
        samples = model.sample(x, num_samples=num_samples)

    if len(mean_cond.shape) == 1:
        mean_cond = mean_cond.unsqueeze(1)

    num_features = x.shape[1]
    # if num_features > 1:
    #     shade = False

    mean_add_std = mean_cond + cov_cond.mm(
        torch.ones((mean_cond.shape[0], 1)).to(DEVICE))
    mean_sub_std = mean_cond - cov_cond.mm(
        torch.ones((mean_cond.shape[0], 1)).to(DEVICE))

    x_train = model.x_train
    y_train = model.y_train

    if axes is None:
        fig, axes = plt.subplots(num_features,
                                 1,
                                 figsize=(3 * 1, 3 * num_features),
                                 squeeze=False,
                                 sharey=False,
                                 sharex=False)
    else:
        if len(axes.shape) == 1: axes = np.expand_dims(axes, 1)

    if num_features == 1:
        axes = np.asarray(axes)

    for ix in range(num_features):
        axes[ix, 0].plot(x_train[:, ix].cpu().data.numpy(),
                         y_train.cpu().data.numpy(),
                         marker='.',
                         ms=2,
                         label="Training",
                         linestyle='')
        if samples is not None:
            axes[ix, 0].plot(x[:, ix].cpu().data.numpy(),
                             samples.cpu().data.numpy(),
                             ms=4,
                             marker=".",
                             linestyle="",
                             label="Sample")

        axes[ix, 0].plot(x[:, ix].cpu().data.numpy(),
                         mean_cond.cpu().data.numpy(),
                         marker='o',
                         ms=4,
                         label="Pred.",
                         linestyle="")
        axes[ix, 0].plot(x[:, ix].cpu().data.numpy(),
                         y.cpu().data.numpy(),
                         '.',
                         ms=4,
                         label="Ground",
                         linestyle='')

        if shade:
            ind_sort = x[:, ix].sort(descending=False)[-1]
            axes[ix, 0].fill_between(
                x[ind_sort, ix].cpu().data.numpy(),
                mean_sub_std.cpu().data.numpy().squeeze()[ind_sort],
                mean_add_std.cpu().data.numpy().squeeze()[ind_sort],
                color="#dddddd")

        # axes[ix].set_title(f"Samples from the GP posterior")
        if x_names is not None:
            axes[ix, 0].set_xlabel(f"{x_names[ix]}")

        if y_names is not None:
            axes[ix, 0].set_ylabel(f"{y_names}")

    # plt.show()
    axes[-1, 0].legend()

    if fig is not None:
        fig.tight_layout(rect=[0, 0.03, 1, 0.95])

    if save_filename is not None:
        fig.savefig(save_filename, bbox_inches='tight', format='png', dpi=200)
        plt.close(fig)

    return fig, axes
Beispiel #23
0
def evaluate_cor(
    model: torch.nn.Module,
    test_file: ty.Union[str, pathlib.Path],
    span_digitizer: ty.Callable[[ty.Mapping[str, ty.Any]],
                                datatools.FeaturefulSpan],
    pair_feats_digitizer: ty.Callable[[ty.Mapping[str, str]],
                                      ty.Iterable[int]],
    loss_fun: ty.Callable = libdecofre.masked_multi_cross_entropy,
    device=None,
    num_workers: int = 0,
) -> ty.Tuple[torch.FloatTensor]:
    model = model.to(device)
    with tempfile.TemporaryDirectory(
            prefix="decofre_antecedents_") as temp_dir:
        test_set = datatools.AntecedentsDataset.from_json(
            test_file,
            span_digitizer=span_digitizer,
            pair_feats_digitizer=pair_feats_digitizer,
            cache_dir=temp_dir,
            set_name="test",
        )
        test_loader = torch.utils.data.DataLoader(
            dataset=test_set,
            sampler=torch.utils.data.BatchSampler(
                torch.utils.data.RandomSampler(test_set),
                batch_size=8,
                drop_last=False),
            collate_fn=lambda x: x[0],
            num_workers=num_workers,
        )

        logger.info("Evaluating on the test set")
        model.eval()
        evaluator = runners.Evaluator(
            model,
            loss=loss_fun,
            metrics={
                "antecedent_accuracy":
                runners.MultiLoss(
                    loss_fn=lambda x, y: _summable_antecedent_accuracy(
                        x.to(device), y.to(device)),
                    output_transform=runners.extract_output,
                    device=device,
                    loss_names=(
                        "total_accuracy",
                        "mention_new_accuracy",
                        "anaphora_accuracy",
                    ),
                ),
                "attributions":
                runners.MultiLoss(
                    loss_fn=lambda x, y: attributions(x.to(device), y.to(device
                                                                         )),
                    output_transform=runners.extract_output,
                    averaged=False,
                    device=device,
                    loss_names=(
                        "true_new",
                        "false_new",
                        "correct_link",
                        "false_link",
                        "wrong_link",
                    ),
                ),
            },
        )
        state = evaluator.run(test_loader)
    for name, value in state.metrics.items():
        logger.info(f"{name}: {value}")
    return state.metrics["loss"]
Beispiel #24
0
    def train_one_epoch(
        cls,
        model: torch.nn.Module,
        iterator: Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
        optimizers: Sequence[torch.optim.Optimizer],
        schedulers: Sequence[Optional[AbsScheduler]],
        reporter: SubReporter,
        options: TrainerOptions,
    ) -> bool:
        assert check_argument_types()

        # Note(kamo): assumes one optimizer
        assert cls.num_optimizers == 1, cls.num_optimizers
        assert len(optimizers) == 1, len(optimizers)
        optimizer = optimizers[0]
        scheduler = schedulers[0]

        grad_noise = options.grad_noise
        accum_grad = options.accum_grad
        grad_clip = options.grad_clip
        log_interval = options.log_interval
        no_forward_run = options.no_forward_run
        ngpu = options.ngpu
        distributed = isinstance(model,
                                 torch.nn.parallel.DistributedDataParallel)
        use_apex = options.train_dtype in ("O0", "O1", "O2", "O3")

        if log_interval is None:
            try:
                log_interval = max(len(iterator) // 20, 10)
            except TypeError:
                log_interval = 100

        model.train()
        all_steps_are_invalid = True
        # [For distributed] Because iteration counts are not always equals between
        # processes, send stop-flag to the other processes if iterator is finished
        iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu")

        start_time = time.perf_counter()
        for iiter, (_, batch) in enumerate(
                reporter.measure_iter_time(iterator, "iter_time"), 1):
            assert isinstance(batch, dict), type(batch)

            if distributed:
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
                if iterator_stop > 0:
                    break

            batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
            if no_forward_run:
                all_steps_are_invalid = False
                reporter.register({})
                continue

            with reporter.measure_time("forward_time"):
                loss, stats, weight = model(**batch)
            if ngpu > 1 or distributed:
                # Apply weighted averaging for loss and stats
                loss = (loss * weight.type(loss.dtype)).sum()

                # if distributed, this method can also apply all_reduce()
                stats, weight = recursive_average(stats, weight, distributed)

                # Now weight is summation over all workers
                loss /= weight
            if distributed:
                # NOTE(kamo): Multiply world_size because DistributedDataParallel
                # automatically normalizes the gradient by world_size.
                loss *= torch.distributed.get_world_size()

            reporter.register(stats, weight)

            loss /= accum_grad
            with reporter.measure_time("backward_time"):
                if use_apex:
                    try:
                        from apex import amp
                    except ImportError:
                        logging.error(
                            "You need to install apex. "
                            "See https://github.com/NVIDIA/apex#linux")
                        raise

                    with amp.scale_loss(loss, optimizers) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

            if iiter % accum_grad == 0:
                # gradient noise injection
                if grad_noise:
                    add_gradient_noise(
                        model,
                        reporter.get_total_count(),
                        duration=100,
                        eta=1.0,
                        scale_factor=0.55,
                    )

                # compute the gradient norm to check if it is normal or not
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), grad_clip)
                # PyTorch<=1.4, clip_grad_norm_ returns float value
                if not isinstance(grad_norm, torch.Tensor):
                    grad_norm = torch.tensor(grad_norm)

                if not torch.isfinite(grad_norm):
                    logging.warning(
                        f"The grad norm is {grad_norm}. Skipping updating the model."
                    )
                else:
                    all_steps_are_invalid = False
                    with reporter.measure_time("optim_step_time"):
                        optimizer.step()
                    if isinstance(scheduler, AbsBatchStepScheduler):
                        scheduler.step()
                optimizer.zero_grad()

                # Register lr and train/load time[sec/step],
                # where step refers to accum_grad * mini-batch
                reporter.register(
                    dict(
                        {
                            f"lr_{i}": pg["lr"]
                            for i, pg in enumerate(optimizer.param_groups)
                            if "lr" in pg
                        },
                        train_time=time.perf_counter() - start_time,
                    ),
                    # Suppress to increment the internal counter.
                    not_increment_count=True,
                )
                start_time = time.perf_counter()

            if iiter % log_interval == 0:
                logging.info(reporter.log_message())

        else:
            if distributed:
                iterator_stop.fill_(1)
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)

        return all_steps_are_invalid
Beispiel #25
0
def train_cor(
    model: torch.nn.Module,
    train_file: ty.Union[str, pathlib.Path],
    span_digitizer: ty.Callable[[ty.Mapping[str, ty.Any]],
                                datatools.FeaturefulSpan],
    pair_feats_digitizer: ty.Callable[[ty.Mapping[str, str]],
                                      ty.Iterable[int]],
    out_dir: ty.Union[str, pathlib.Path],
    temp_dir: ty.Union[str, pathlib.Path],
    device: torch.device,
    epochs: int,
    patience: int,
    train_batch_size: int = 8,
    dev_file: ty.Optional[ty.Union[str, pathlib.Path]] = None,
    optimizer=None,
    loss_fun: ty.Callable = libdecofre.masked_multi_cross_entropy,
    trainer_cls=runners.SinkTrainer,
    num_workers: int = 0,
    debug: bool = False,
    config: ty.Optional[ty.Dict[str, ty.Any]] = None,
    **kwargs,
) -> ty.Tuple[ignite.engine.Engine, ty.Iterable, ty.Dict[str, ty.Any]]:
    logger.info("Training antecedent scoring")
    model = model.to(device)
    train_set = datatools.AntecedentsDataset.from_json(
        train_file,
        span_digitizer=span_digitizer,
        pair_feats_digitizer=pair_feats_digitizer,
        cache_dir=temp_dir,
        set_name="train_cor",
    )
    train_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        sampler=torch.utils.data.BatchSampler(
            torch.utils.data.RandomSampler(train_set),
            batch_size=train_batch_size,
            drop_last=False,
        ),
        collate_fn=lambda x: x[0],
        num_workers=num_workers,
    )

    if dev_file is not None:
        dev_set = datatools.AntecedentsDataset.from_json(
            dev_file,
            span_digitizer=span_digitizer,
            pair_feats_digitizer=pair_feats_digitizer,
            cache_dir=temp_dir,
            set_name="dev_cor",
        )
        dev_loader = torch.utils.data.DataLoader(
            dataset=dev_set,
            sampler=torch.utils.data.BatchSampler(
                torch.utils.data.RandomSampler(dev_set),
                batch_size=train_batch_size,
                drop_last=False,
            ),
            collate_fn=lambda x: x[0],
            num_workers=num_workers,
        )  # type: ty.Optional[torch.utils.data.DataLoader]
    else:
        dev_loader = None

    cor_trainer = trainer_cls(
        model,
        checkpointed_models={"cor": model},
        loss_fun=loss_fun,
        optimizer=optimizer,
        dev_loss=loss_fun,
        dev_metrics={
            "antecedent_accuracy":
            runners.MultiLoss(
                loss_fn=lambda x, y: _summable_antecedent_accuracy(
                    x.to(device), y.to(device)),
                num_loss=3,
                output_transform=runners.extract_output,
                device=device,
                loss_names=(
                    "total_accuracy",
                    "mention_new_accuracy",
                    "anaphora_accuracy",
                ),
            ),
            "attributions":
            runners.MultiLoss(
                loss_fn=lambda x, y: attributions(x.to(device), y.to(device)),
                output_transform=runners.extract_output,
                averaged=False,
                device=device,
                loss_names=(
                    "true_new",
                    "false_new",
                    "correct_link",
                    "false_link",
                    "wrong_link",
                ),
            ),
        },
        save_path=out_dir,
        debug=debug,
        **kwargs,
    )
    if config["lr-schedule"] == "step":
        logger.debug("Training with 'step' LR schedule, using γ=0.95")
        torch_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                             len(train_loader),
                                                             gamma=0.95)
        scheduler = ignite.contrib.handlers.create_lr_scheduler_with_warmup(
            torch_lr_scheduler,
            warmup_start_value=0.0,
            warmup_end_value=optimizer.defaults["lr"],
            warmup_duration=1000,
        )
        cor_trainer.add_event_handler(ignite.engine.Events.ITERATION_STARTED,
                                      scheduler)

    return (
        cor_trainer,
        train_loader,
        {
            "max_epochs": epochs,
            "patience": patience,
            "dev_loader": dev_loader,
            "run_name": "antecedent_scoring",
        },
    )
Beispiel #26
0
def train_one_epoch(model: torch.nn.Module,
                    criterion: torch.nn.Module,
                    data_loader: Iterable,
                    optimizer: torch.optim.Optimizer,
                    device: torch.device,
                    epoch: int,
                    max_norm: float = 0):
    model.train()
    criterion.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter(
        'lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter(
        'class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    for samples, targets in metric_logger.log_every(data_loader, print_freq,
                                                    header):
        samples = samples.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        outputs = model(samples)
        loss_dict = criterion(outputs, targets)
        weight_dict = criterion.weight_dict
        losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys()
                     if k in weight_dict)

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        loss_dict_reduced_unscaled = {
            f'{k}_unscaled': v
            for k, v in loss_dict_reduced.items()
        }
        loss_dict_reduced_scaled = {
            k: v * weight_dict[k]
            for k, v in loss_dict_reduced.items() if k in weight_dict
        }
        losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())

        loss_value = losses_reduced_scaled.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)

        optimizer.zero_grad()
        losses.backward()
        if max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        optimizer.step()

        metric_logger.update(loss=loss_value,
                             **loss_dict_reduced_scaled,
                             **loss_dict_reduced_unscaled)
        metric_logger.update(class_error=loss_dict_reduced['class_error'])
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
Beispiel #27
0
def count_parameters(model: torch.nn.Module):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
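A quick illustrative check of count_parameters (not part of the original example):

import torch

model = torch.nn.Linear(128, 10)   # 128 * 10 weights + 10 biases
print(count_parameters(model))     # 1290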
Beispiel #28
0
    def __init__(self, model: torch.nn.Module, trained_param_path: Path, cuda):
        self.device = torch.device('cuda' if cuda >= 0 else 'cpu')
        self.trained_param_path = trained_param_path
        self.model = model.to(self.device)
        self._load_param()
        self.model.eval()
    def fit(self,
            model: torch.nn.Module,
            optimizer: torch.optim.Optimizer,
            criterion,
            train_dataloader: DataLoader,
            val_dataloader: DataLoader,
            epochs: int,
            metrics=None,
            device='cpu',
            **kwargs):
        self.device = device
        self.model = model.to(self.device)
        self.optimizer = optimizer
        self.criterion = criterion
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.metrics = metrics or dict()
        self.epochs = epochs

        self.config = {
            'model':
            repr(self.model),
            'optimizer':
            repr(self.optimizer),
            'criterion':
            repr(self.criterion),
            'metrics':
            ', '.join(metric_name for metric_name in self.metrics.keys()),
            'epochs':
            self.epochs,
        }

        self.train_losses = []
        self.val_losses = []

        for func in self.on_fit_start:
            func(self, self.config)
        try:
            with tqdm.tqdm(range(self.epochs),
                           desc="Training epochs",
                           unit="epoch") as epoch_progress_bar:
                for epoch in epoch_progress_bar:
                    for func in self.on_epoch_start:
                        func(self, epoch)
                    # train step
                    self.model.train()
                    with tqdm.tqdm(self.train_dataloader,
                                   desc="Train",
                                   unit="batch",
                                   leave=False) as train_progress_bar:
                        for batch_idx, (
                                batch_x,
                                batch_y) in enumerate(train_progress_bar):
                            for func in self.on_training_batch_start:
                                func(self, batch_x, batch_y)

                            self.optimizer.zero_grad()
                            loss, out = self.criterion(model,
                                                       batch_x.to(self.device),
                                                       batch_y.to(self.device))
                            loss.backward()
                            self.optimizer.step()

                            self.train_losses.append(loss.item())

                            for func in self.on_training_batch_end:
                                func(self, batch_x, batch_y, loss)
                    # validation step
                    self.model.eval()
                    metric_results = defaultdict(list)
                    with tqdm.tqdm(self.val_dataloader,
                                   desc='Validation',
                                   unit='batch',
                                   leave=False) as validation_progress_bar:
                        for batch_idx, (
                                batch_x,
                                batch_y) in enumerate(validation_progress_bar):
                            for func in self.on_validation_batch_start:
                                func(self, batch_x, batch_y)

                            with torch.no_grad():
                                loss, out = self.criterion(
                                    model, batch_x.to(self.device),
                                    batch_y.to(self.device))
                                self.val_losses.append(loss.item())

                                for metric_name, metric_func in self.metrics.items(
                                ):
                                    metric_results[metric_name].append(
                                        metric_func(batch_y, out))

                            for func in self.on_validation_batch_end:
                                func(self, batch_x, batch_y, loss,
                                     metric_results)

                    for func in self.on_epoch_end:
                        func(self, epoch)
        except KeyboardInterrupt:
            print('Keyboard Interruption, teardown')
        finally:
            for func in self.on_fit_end:
                func(self)

        return model
def train_one_epoch(model: torch.nn.Module,
                    criterion: DistillationLoss,
                    data_loader: Iterable,
                    optimizer: torch.optim.Optimizer,
                    device: torch.device,
                    epoch: int,
                    loss_scaler,
                    max_norm: float = 0,
                    model_ema: Optional[ModelEma] = None,
                    mixup_fn: Optional[Mixup] = None,
                    teacher=None,
                    set_training_mode=True):
    # TODO fix this for finetuning
    # model.train(set_training_mode)
    model.train()

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter(
        'lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 100

    for samples, targets in metric_logger.log_every(data_loader, print_freq,
                                                    header):
        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        samples, targets, mix_rate, aux_targets = two_mix(
            samples, targets, num_patch=samples.shape[-1] // 16)

        with torch.cuda.amp.autocast():
            # outputs, r_loss = model(samples)
            outputs, r_loss, s_loss, proj = model(samples, aux_targets)
            loss = torch.sum(-targets * (1e-8 + outputs.softmax(dim=-1)).log(),
                             dim=-1).mean()

            loss_value = loss.item()
            loss += 1. * (r_loss + 1. * s_loss)

        if not math.isfinite(loss.item()):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        optimizer.zero_grad()

        # this attribute is added by timm on one optimizer (adahessian)
        is_second_order = hasattr(
            optimizer, 'is_second_order') and optimizer.is_second_order
        loss_scaler(loss,
                    optimizer,
                    clip_grad=max_norm,
                    parameters=model.parameters(),
                    create_graph=is_second_order)

        torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)

        metric_logger.update(loss=loss_value)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        metric_logger.meters['r'].update(r_loss.item(), n=targets.shape[0])
        # metric_logger.meters['p'].update(proj.item(), n=targets.shape[0])
        metric_logger.meters['s'].update(s_loss.item(), n=targets.shape[0])
        # metric_logger.meters['cos'].update(cos.item(), n=targets.shape[0])
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
Beispiel #31
0
def dump_sender_receiver(game: torch.nn.Module,
                         dataset: 'torch.utils.data.DataLoader',
                         gs: bool,
                         variable_length: bool,
                         device: Optional[torch.device] = None):
    """
    A tool to dump the interaction between Sender and Receiver
    :param game: A Game instance
    :param dataset: Dataset of inputs to be used when analyzing the communication
    :param gs: whether Gumbel-Softmax relaxation was used during training
    :param variable_length: whether variable-length communication is used
    :param device: device (e.g. 'cuda') to be used
    :return:
    """
    train_state = game.training  # persist so we restore it back
    game.eval()

    device = device if device is not None else common_opts.device

    sender_inputs, messages, receiver_inputs, receiver_outputs = [], [], [], []
    labels = []

    with torch.no_grad():
        for batch in dataset:
            # by agreement, each batch is (sender_input, labels) plus optional (receiver_input)
            sender_input = move_to(batch[0], device)
            print("sender_input: " + str(sender_input))
            print("batch" + str(batch))
            print("batch length" + str(len(batch)))
            receiver_input = None if len(batch) == 2 else move_to(
                batch[2], device)
            print("receiver_input: " + str(receiver_input))
            message = game.sender(sender_input)

            # Under GS, the only output is a message; under Reinforce, two additional tensors are returned.
            # We don't need them.
            if not gs: message = message[0]

            output = game.receiver(message, receiver_input)
            if not gs: output = output[0]

            if batch[1] is not None:
                labels.extend(batch[1])

            if isinstance(sender_input, list) or isinstance(
                    sender_input, tuple):
                sender_inputs.extend(zip(*sender_input))
            else:
                sender_inputs.extend(sender_input)

            if receiver_input is not None:
                receiver_inputs.extend(receiver_input)

            if gs:
                message = message.argmax(
                    dim=-1)  # actual symbols instead of one-hot encoded

            if not variable_length:
                messages.extend(message)
                receiver_outputs.extend(output)
            else:
                # A trickier part is to handle EOS in the messages. It also might happen that not every message has EOS.
                # We cut messages at EOS if it is present or return the entire message otherwise. Note, EOS id is always
                # set to 0.

                for i in range(message.size(0)):
                    eos_positions = (message[i, :] == 0).nonzero()
                    message_end = eos_positions[0].item(
                    ) if eos_positions.size(0) > 0 else -1
                    assert message_end == -1 or message[i, message_end] == 0
                    if message_end < 0:
                        messages.append(message[i, :])
                    else:
                        messages.append(message[i, :message_end + 1])

                    if gs:
                        receiver_outputs.append(output[i, message_end, ...])
                    else:
                        receiver_outputs.append(output[i, ...])

    game.train(mode=train_state)

    return sender_inputs, messages, receiver_inputs, receiver_outputs, labels
Beispiel #32
0
def train_model(model: torch.nn.Module,
                device: torch.device,
                optimizer: AdamW,
                file_path: Text,
                train_loader: DataLoader,
                valid_loader: DataLoader,
                label_to_cellid: Dict[int, int],
                num_epochs: int,
                best_valid_loss: float = float("Inf")):
    '''Main function for training the model.'''
    # initialize running values
    running_loss = 0.0
    global_step = 0
    valid_accuracy_list, train_loss_list, valid_loss_list = [], [], []
    global_steps_list, true_points_list, pred_points_list = [], [], []

    # Training loop.
    model.train()

    for epoch in range(num_epochs):
        for batch, _ in train_loader:
            optimizer.zero_grad()
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            outputs = model(input_ids,
                            attention_mask=attention_mask,
                            labels=labels)
            loss = outputs.loss
            loss.mean().backward()
            optimizer.step()

            # Update running values.
            running_loss += loss.mean().item()
            global_step += 1

        # Evaluation step.
        valid_loss, predictions, true_vals, true_points, pred_points = evaluate(
            model, valid_loader, device, label_to_cellid)

        average_train_loss = running_loss / labels.shape[0]
        accuracy = accuracy_cells(true_vals, predictions)
        train_loss_list.append(average_train_loss)
        valid_loss_list.append(valid_loss)
        global_steps_list.append(global_step)
        valid_accuracy_list.append(accuracy)
        true_points_list.append(true_points)
        pred_points_list.append(pred_points)

        # Resetting running values.
        running_loss = 0.0

        logging.info('Epoch [{}/{}], Step [{}/{}], \
        Accuracy: {:.4f},Train Loss: {:.4f}, Valid Loss: {:.4f}'.format(
            epoch + 1, num_epochs, global_step, num_epochs * len(train_loader),
            accuracy, average_train_loss, valid_loss))

        # Save model and results in checkpoint.
        if best_valid_loss > valid_loss:
            best_valid_loss = valid_loss
            util.save_checkpoint(file_path + '/' + 'model.pt', model,
                                 best_valid_loss)
            util.save_metrics(file_path + '/' + 'metrics.pt', train_loss_list,
                              valid_loss_list, global_steps_list,
                              valid_accuracy_list, true_points_list,
                              pred_points_list)

            model.train()

    logging.info('Finished Training.')
Beispiel #33
0
def train(classifier: torch.nn.Module,
          x: torch.Tensor,
          y: torch.Tensor,
          test_x: torch.Tensor = torch.Tensor(),
          test_y: torch.Tensor = torch.Tensor(),
          batch_size: int = 16,
          num_epochs: int = 2,
          run_device: str = "cpu",
          learning_rate: float = 0.001,
          beta_1: float = 0.9,
          beta_2: float = 0.999,
          random_state: torch.ByteTensor = torch.get_rng_state().clone(),
          verbose: bool = False) -> Tuple[torch.nn.Module, torch.ByteTensor]:
    """
    Function to train classifiers and save the trained classifiers.

    Parameters
    ----------
    classifier: torch.nn.Module
    x: torch.Tensor
    y: torch.Tensor
    test_x: torch.Tensor
    test_y: torch.Tensor
    batch_size: int
    num_epochs: int
    run_device: str
    learning_rate: float
    beta_1: float
    beta_2: float
    random_state: torch.ByteTensor
    verbose: bool

    Returns
    -------
    Tuple[torch.nn.Module, torch.ByteTensor]

    """
    assert isinstance(classifier, torch.nn.Module)
    assert isinstance(x, torch.Tensor)
    assert isinstance(y, torch.Tensor)
    assert isinstance(test_x, torch.Tensor)
    assert isinstance(test_y, torch.Tensor)
    assert isinstance(batch_size, int) and (batch_size > 0)
    assert isinstance(num_epochs, int) and (num_epochs > 0)
    assert isinstance(run_device, str) and (run_device.lower()
                                            in ["cpu", "cuda"])
    assert isinstance(learning_rate, float) and (learning_rate > 0.0)
    assert isinstance(beta_1, float) and (0.0 <= beta_1 < 1.0)
    assert isinstance(beta_2, float) and (0.0 <= beta_2 < 1.0)
    assert isinstance(random_state, torch.ByteTensor)
    assert isinstance(verbose, bool)

    # Set the seed for generating random numbers.
    random_state_previous: torch.ByteTensor = torch.get_rng_state().clone()
    torch.set_rng_state(random_state)

    # Set the classifier.
    classifier_train: torch.nn.Module = copy.deepcopy(classifier.cpu())
    run_device_train: str = run_device.lower()
    if run_device_train == "cuda":
        assert torch.cuda.is_available()
        classifier_train = classifier_train.cuda()
        if torch.cuda.device_count() > 1:
            num_gpus: int = torch.cuda.device_count()
            classifier_train = torch.nn.DataParallel(classifier_train,
                                                     device_ids=list(
                                                         range(0, num_gpus)))

    # Set a criterion and optimizer.
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(params=classifier_train.parameters(),
                                 lr=learning_rate,
                                 betas=(beta_1, beta_2))

    # Convert the PyTorch tensors into a TensorDataset and DataLoader.
    x_train, y_train = x.clone(), y.clone()
    dataset_train = torch.utils.data.TensorDataset(x_train, y_train)
    dataloader_train = torch.utils.data.DataLoader(dataset_train,
                                                   batch_size=batch_size,
                                                   num_workers=0,
                                                   shuffle=True)

    has_test: bool = False
    if (test_x.size(0) > 0) and (test_y.size(0) > 0):
        x_test, y_test = test_x.clone(), test_y.clone()
        dataset_test = torch.utils.data.TensorDataset(x_test, y_test)
        dataloader_test = torch.utils.data.DataLoader(dataset_test,
                                                      batch_size=batch_size,
                                                      num_workers=0,
                                                      shuffle=False)
        has_test = True

    # Initialize the early_stopping object.
    early_stopping: _EarlyStopping = _EarlyStopping(patience=10,
                                                    delta=0.0,
                                                    verbose=False)

    log_template: str = "[{0}/{1}] Loss: {2:.4f}, Time: {3:.2f}s"
    log_template_test: str = "[{0}/{1}] Loss (Train): {2:.4f}, Loss (Test): {3:.4f}, Time: {4:.2f}s"

    list_loss: list = list()
    list_loss_test: list = list()

    # Train the classifiers.
    classifier_train.train()
    for epoch in range(1, num_epochs + 1):
        start_time: float = time.time()
        for (_, batch) in enumerate(dataloader_train, 0):
            batch_x, batch_y = batch
            if run_device_train == "cuda":
                batch_x, batch_y = batch_x.cuda(), batch_y.cuda()

            optimizer.zero_grad()
            output: torch.Tensor = classifier_train(batch_x)
            loss: torch.Tensor = criterion(output, batch_y)
            loss.backward()
            optimizer.step()

            # list_loss is never cleared, so np.mean(list_loss) below is a
            # running mean over every batch seen so far, not just this epoch.
            list_loss.append(loss.detach().cpu().item())
        end_time: float = time.time()

        if has_test:
            classifier_train.eval()
            # Evaluate without building the autograd graph.
            with torch.no_grad():
                for (_, batch) in enumerate(dataloader_test, 0):
                    batch_x, batch_y = batch
                    if run_device_train == "cuda":
                        batch_x, batch_y = batch_x.cuda(), batch_y.cuda()

                    output = classifier_train(batch_x)
                    loss = criterion(output, batch_y)

                    # Like list_loss, this accumulates across epochs.
                    list_loss_test.append(loss.detach().cpu().item())
            classifier_train.train()

            early_stopping(loss=np.mean(list_loss_test),
                           model=classifier_train)
        else:
            early_stopping(loss=np.mean(list_loss), model=classifier_train)

        if verbose:
            if has_test:
                print(
                    log_template_test.format(epoch, num_epochs,
                                             np.mean(list_loss),
                                             np.mean(list_loss_test),
                                             end_time - start_time))
            else:
                print(
                    log_template.format(epoch, num_epochs, np.mean(list_loss),
                                        end_time - start_time))

        if early_stopping.early_stop:
            state_dict, rng_state = early_stopping.get_best_model()
            classifier_train.load_state_dict(state_dict)
            torch.set_rng_state(rng_state)
            break

    if isinstance(classifier_train, torch.nn.DataParallel):
        classifier_train = classifier_train.module

    random_state_after: torch.ByteTensor = torch.get_rng_state().clone()
    torch.set_rng_state(random_state_previous)

    return classifier_train.cpu(), random_state_after.clone()
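# _EarlyStopping is not shown in this snippet. The class below is only an
# assumed, minimal sketch reconstructed from the call sites in train() above
# (callable with loss= and model=, an early_stop flag, and get_best_model()
# returning the best state_dict together with the RNG state recorded with it);
# the original implementation may differ.
class _EarlyStopping:
    def __init__(self, patience: int = 10, delta: float = 0.0,
                 verbose: bool = False) -> None:
        self.patience = patience   # epochs to wait after the last improvement
        self.delta = delta         # minimum decrease that counts as improvement
        self.verbose = verbose
        self.counter = 0
        self.best_loss = float("inf")
        self.early_stop = False
        self._best_state_dict = None
        self._best_rng_state = None

    def __call__(self, loss: float, model: torch.nn.Module) -> None:
        if loss < self.best_loss - self.delta:
            # New best loss: remember the weights and the current RNG state.
            self.best_loss = loss
            self.counter = 0
            self._best_state_dict = copy.deepcopy(model.state_dict())
            self._best_rng_state = torch.get_rng_state().clone()
            if self.verbose:
                print("Loss improved to {:.4f}".format(loss))
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

    def get_best_model(self):
        # Return the best state_dict and the RNG state captured alongside it.
        return self._best_state_dict, self._best_rng_state


# Hypothetical usage of train() with random data, just to illustrate the call:
#     model = torch.nn.Linear(10, 2)
#     x, y = torch.randn(100, 10), torch.randint(0, 2, (100,))
#     trained, rng_state = train(model, x, y, batch_size=16, num_epochs=2,
#                                verbose=True)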