Exemplo n.º 1
0
Arquivo: train.py Projeto: xvdp/captum
def sklearn_train_linear_model(
    model: LinearModel,
    dataloader: DataLoader,
    construct_kwargs: Dict[str, Any],
    sklearn_trainer: str = "Lasso",
    norm_input: bool = False,
    **fit_kwargs,
):
    r"""
    Alternative method to train with sklearn. This does introduce some slight
    overhead as we convert the tensors to numpy and then convert the resulting
    trained model to a `LinearModel` object. However, this conversion
    should be negligible.

    Please note that this assumes:

    0. You have sklearn and numpy installed
    1. The dataset can fit into memory

    Args
        model
            The model to train. Its parameters are rebuilt from the fitted
            sklearn estimator via `model._construct_model_params`.
        dataloader
            The data to use. Each batch must be a pair `(x, y)` or a triple
            `(x, y, w)` of tensors. This will be exhausted and converted to
            numpy arrays. Therefore please do not feed an infinite dataloader.
            Either every batch or no batch may carry weights.
        construct_kwargs
            Additional arguments provided to the `sklearn_trainer` constructor
        sklearn_trainer
            The sklearn estimator used to train the model, given as a dotted
            path relative to the `sklearn` package (e.g. "linear_model.Lasso",
            "svm.LinearSVR"). A name that cannot be resolved on `sklearn`
            itself (such as the default "Lasso") is looked up in
            `sklearn.linear_model`.
        norm_input
            Whether or not to normalize the input. When set, every feature is
            standardized with its mean and standard deviation over the whole
            dataset (a zero standard deviation is replaced by 1 so constant
            features stay finite), and `NormLayer(mean, std)` is assigned to
            `model.norm`.
        fit_kwargs
            Other arguments to send to `sklearn_trainer`'s `.fit` method

    Returns
        A dict `{"train_time": seconds}` covering estimator construction and
        fitting.

    Raises
        ValueError
            If numpy or sklearn is not installed, if the dataloader yields no
            batches, or if only some of the batches carry sample weights.
    """
    import re
    from functools import reduce

    try:
        import numpy as np
    except ImportError as err:
        raise ValueError("numpy is not available. Please install numpy.") from err

    try:
        import sklearn
        import sklearn.linear_model
        import sklearn.svm
    except ImportError as err:
        raise ValueError(
            "sklearn is not available. Please install sklearn >= 0.23"
        ) from err

    # Compare (major, minor) numerically: a plain string comparison would rank
    # e.g. "0.9" above "0.23.0".
    sklearn_version = tuple(
        int(part) for part in re.findall(r"\d+", sklearn.__version__)[:2]
    )
    if sklearn_version < (0, 23):
        warnings.warn(
            "Must have sklearn version 0.23.0 or higher to use "
            "sample_weight in Lasso regression."
        )

    num_batches = 0
    xs, ys, ws = [], [], []
    for data in dataloader:
        if len(data) == 3:
            batch_x, batch_y, batch_w = data
        else:
            assert len(data) == 2
            batch_x, batch_y = data
            batch_w = None

        xs.append(batch_x.cpu().numpy())
        ys.append(batch_y.cpu().numpy())
        if batch_w is not None:
            ws.append(batch_w.cpu().numpy())
        num_batches += 1

    if num_batches == 0:
        raise ValueError("The dataloader did not yield any batches to train on.")
    # Concatenating weights from only some batches would misalign them with x.
    if ws and len(ws) != num_batches:
        raise ValueError(
            "Either all batches or no batches must provide sample weights."
        )

    x = np.concatenate(xs, axis=0)
    y = np.concatenate(ys, axis=0)
    weights = np.concatenate(ws, axis=0) if ws else None

    if norm_input:
        mean, std = x.mean(0), x.std(0)
        # A constant feature has zero spread; dividing by it would produce
        # NaN/inf, which sklearn rejects. Use 1 for those features instead.
        std = np.where(std == 0, 1, std)
        # Out-of-place so integer-typed inputs are promoted instead of failing.
        x = (x - mean) / std

    t1 = time.time()
    trainer_path = sklearn_trainer.split(".")
    try:
        trainer_cls = reduce(getattr, trainer_path, sklearn)
    except AttributeError:
        # Bare estimator names (the default "Lasso" included) are not exposed
        # on the top-level `sklearn` package; they live in linear_model.
        trainer_cls = reduce(getattr, trainer_path, sklearn.linear_model)
    sklearn_model = trainer_cls(**construct_kwargs)

    if weights is None:
        # No weights to pass, so estimators lacking `sample_weight` are fine.
        sklearn_model.fit(x, y, **fit_kwargs)
    else:
        try:
            sklearn_model.fit(x, y, sample_weight=weights, **fit_kwargs)
        except TypeError:
            sklearn_model.fit(x, y, **fit_kwargs)
            warnings.warn(
                "Sample weight is not supported for the provided linear model!"
                " Trained model without weighting inputs. For Lasso, please"
                " upgrade sklearn to a version >= 0.23.0."
            )

    t2 = time.time()

    # Convert weights to pytorch. The output count comes from the fitted
    # coefficients rather than from `y`: multi-class classifiers are trained
    # on 1-D labels but learn one coefficient row per class.
    coef = np.asarray(sklearn_model.coef_)  # type: ignore
    num_outputs = coef.shape[0] if coef.ndim > 1 else 1
    weight_values = torch.FloatTensor(coef)
    bias_values = torch.FloatTensor([sklearn_model.intercept_])  # type: ignore
    model._construct_model_params(
        norm_type=None,
        weight_values=weight_values.view(num_outputs, -1),
        bias_value=bias_values.squeeze().unsqueeze(0),
    )

    if norm_input:
        model.norm = NormLayer(mean, std)

    return {"train_time": t2 - t1}
Exemplo n.º 2
0
Arquivo: train.py Projeto: xvdp/captum
def sgd_train_linear_model(
    model: LinearModel,
    dataloader: DataLoader,
    construct_kwargs: Dict[str, Any],
    max_epoch: int = 100,
    reduce_lr: bool = True,
    initial_lr: float = 0.01,
    alpha: float = 1.0,
    loss_fn: Callable = l2_loss,
    reg_term: Optional[int] = 1,
    patience: int = 10,
    threshold: float = 1e-4,
    running_loss_window: Optional[int] = None,
    device: Optional[str] = None,
    init_scheme: str = "zeros",
    debug: bool = False,
) -> Dict[str, float]:
    r"""
    Trains a linear model with SGD. This will continue to iterate your
    dataloader until we have converged to a solution or alternatively until we
    have exhausted `max_epoch`.

    Convergence is defined by the running average loss not improving by
    `threshold` for `patience` iterations in a row.

    Args:
        model
            The model to train
        dataloader
            The data to train it with. We will assume the dataloader produces
            either pairs or triples of the form (x, y) or (x, y, w). Where x and
            y are typical pairs for supervised learning and w is a weight
            vector.

            We will call `model._construct_model_params` with construct_kwargs
            and the input features set to `x.shape[1]` (`x.shape[0]` corresponds
            to the batch size). We assume that `len(x.shape) == 2`, i.e. the
            tensor is flat. The number of output features will be set to
            y.shape[1] or 1 (if `len(y.shape) == 1`); we require `len(y.shape)
            <= 2`.
        max_epoch
            The maximum number of epochs to exhaust
        reduce_lr
            Whether or not to reduce the learning rate as iterations progress.
            Halves the learning rate when the training loss does not move. This
            uses torch.optim.lr_scheduler.ReduceLROnPlateau and uses the
            parameters `patience` and `threshold`
        initial_lr
            The initial learning rate to use.
        alpha
            A constant for the regularization term.
        loss_fn
            The loss to optimise for. It is called as
            `loss_fn(labels, predictions, weights)`, where `weights` is None
            when the dataloader yields pairs.
        reg_term
            Regularization is defined by the `reg_term` norm of the weights.
            Please use `None` if you do not wish to use regularization.
        patience
            Defines the number of iterations in a row the loss must remain
            within `threshold` in order to be classified as converged.
        threshold
            Threshold for convergence detection.
        running_loss_window
            Used to report the training loss once we have finished training and
            to determine when we have converged (along with reducing the
            learning rate).

            The reported training loss will take the last `running_loss_window`
            iterations and average them.

            If `None` we will approximate this to be the number of examples in
            an epoch (first batch size times `len(dataloader)`).
        init_scheme
            Initialization to use prior to training the linear model:
            "xavier", "zeros", or None to keep the constructed parameters.
        device
            The device to send the model and data to. If None then no `.to` call
            will be used.
        debug
            Whether to print the loss, learning rate per iteration

    Returns
        A dict with "train_time" (seconds), "train_loss" (the final training
        loss averaged over `running_loss_window`; NaN if no iteration ran),
        "train_iter" (number of iterations) and "train_epoch" (number of
        completed epochs).

    Raises
        ValueError
            If the dataloader yields no batches.
    """

    loss_window: List[torch.Tensor] = []
    min_avg_loss = None
    convergence_counter = 0
    converged = False

    def get_point(datapoint):
        if len(datapoint) == 2:
            x, y = datapoint
            w = None
        else:
            x, y, w = datapoint

        if device is not None:
            x = x.to(device)
            y = y.to(device)
            if w is not None:
                w = w.to(device)

        return x, y, w

    # get a point and construct the model
    data_iter = iter(dataloader)
    first_point = next(data_iter, None)
    if first_point is None:
        raise ValueError("The dataloader did not yield any batches to train on.")
    x, y, w = get_point(first_point)

    model._construct_model_params(
        in_features=x.shape[1],
        out_features=y.shape[1] if len(y.shape) == 2 else 1,
        **construct_kwargs,
    )
    if device is not None:
        # The data is sent to `device` in `get_point`, so the parameters must
        # be there too; do it before the optimizer captures them.
        model.to(device)
    model.train()

    assert model.linear is not None

    if init_scheme is not None:
        assert init_scheme in ["xavier", "zeros"]

        with torch.no_grad():
            if init_scheme == "xavier":
                torch.nn.init.xavier_uniform_(model.linear.weight)
            else:
                model.linear.weight.zero_()

            if model.linear.bias is not None:
                model.linear.bias.zero_()

    optim = torch.optim.SGD(model.parameters(), lr=initial_lr)
    # Must be bound even when `reduce_lr` is False; it is checked every step.
    scheduler = None
    if reduce_lr:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optim, factor=0.5, patience=patience, threshold=threshold
        )

    t1 = time.time()
    epoch = 0
    i = 0
    while epoch < max_epoch:
        # Equivalent to `for x, y, w in dataloader`, except that the first
        # batch of each epoch has already been fetched into x, y, w.
        while True:
            if running_loss_window is None:
                running_loss_window = x.shape[0] * len(dataloader)

            y = y.view(x.shape[0], -1)
            if w is not None:
                w = w.view(x.shape[0], -1)

            i += 1

            out = model(x)

            loss = loss_fn(y, out, w)
            if reg_term is not None:
                reg = torch.norm(model.linear.weight, p=reg_term)
                loss += reg.sum() * alpha

            # Keep only the most recent `running_loss_window` losses.
            if len(loss_window) >= running_loss_window:
                loss_window = loss_window[1:]
            loss_window.append(loss.clone().detach())
            assert len(loss_window) <= running_loss_window

            average_loss = torch.mean(torch.stack(loss_window))
            if min_avg_loss is not None:
                # if we haven't improved by at least `threshold`
                if average_loss > min_avg_loss or torch.isclose(
                    min_avg_loss, average_loss, atol=threshold
                ):
                    convergence_counter += 1
                    if convergence_counter >= patience:
                        converged = True
                        break
                else:
                    convergence_counter = 0
            if min_avg_loss is None or min_avg_loss >= average_loss:
                min_avg_loss = average_loss.clone()

            if debug:
                print(
                    f"lr={optim.param_groups[0]['lr']}, Loss={loss}, "
                    f"Aloss={average_loss}, min_avg_loss={min_avg_loss}"
                )

            loss.backward()

            optim.step()
            model.zero_grad()
            if scheduler is not None:
                scheduler.step(average_loss)

            temp = next(data_iter, None)
            if temp is None:
                break
            x, y, w = get_point(temp)

        if converged:
            break

        epoch += 1
        # Only restart the dataloader if another epoch will actually run.
        if epoch < max_epoch:
            data_iter = iter(dataloader)
            x, y, w = get_point(next(data_iter))

    t2 = time.time()
    # With max_epoch <= 0 no loss was recorded; torch.stack([]) would raise.
    train_loss = (
        torch.mean(torch.stack(loss_window)).item() if loss_window else float("nan")
    )
    return {
        "train_time": t2 - t1,
        "train_loss": train_loss,
        "train_iter": i,
        "train_epoch": epoch,
    }