def sklearn_train_linear_model(
    model: LinearModel,
    dataloader: DataLoader,
    construct_kwargs: Dict[str, Any],
    sklearn_trainer: str = "linear_model.Lasso",
    norm_input: bool = False,
    **fit_kwargs,
):
    r"""
    Alternative method to train with sklearn. This does introduce some slight
    overhead as we convert the tensors to numpy arrays and then copy the
    trained sklearn parameters back into the `LinearModel` object. However,
    this conversion should be negligible.

    Please note that this assumes:

    0. You have sklearn and numpy installed
    1. The dataset can fit into memory

    Args:
        model
            The model to train.
        dataloader
            The data to use. This will be exhausted and converted to numpy
            arrays. Therefore please do not feed an infinite dataloader.
        construct_kwargs
            Additional arguments provided to the `sklearn_trainer` constructor.
        sklearn_trainer
            The sklearn model to use to train the model, given as a dotted
            attribute path relative to the top-level sklearn package (e.g.
            "linear_model.Lasso"). Please refer to sklearn.linear_model for a
            list of modules to use.
        norm_input
            Whether or not to normalize the input.
        fit_kwargs
            Other arguments to send to `sklearn_trainer`'s `.fit` method.

    Returns:
        A dictionary containing the wall-clock training time.
    """
    from functools import reduce

    try:
        import numpy as np
    except ImportError:
        raise ValueError("numpy is not available. Please install numpy.")

    try:
        import sklearn
        import sklearn.linear_model
        import sklearn.svm  # imported so that paths such as "svm.LinearSVC" resolve
    except ImportError:
        raise ValueError("sklearn is not available. Please install sklearn >= 0.23")

    if not sklearn.__version__ >= "0.23.0":
        warnings.warn(
            "Must have sklearn version 0.23.0 or higher to use "
            "sample_weight in Lasso regression."
        )

    # Exhaust the dataloader and collect the (optionally weighted) dataset
    # as numpy arrays.
    num_batches = 0
    xs, ys, ws = [], [], []
    for data in dataloader:
        if len(data) == 3:
            x, y, w = data
        else:
            assert len(data) == 2
            x, y = data
            w = None

        xs.append(x.cpu().numpy())
        ys.append(y.cpu().numpy())
        if w is not None:
            ws.append(w.cpu().numpy())

        num_batches += 1

    x = np.concatenate(xs, axis=0)
    y = np.concatenate(ys, axis=0)
    if len(ws) > 0:
        w = np.concatenate(ws, axis=0)
    else:
        w = None

    if norm_input:
        mean, std = x.mean(0), x.std(0)
        x -= mean
        x /= std

    t1 = time.time()
    # Resolve the estimator class by walking the dotted attribute path
    # relative to the sklearn package, then construct it.
    sklearn_model = reduce(
        lambda val, el: getattr(val, el), [sklearn] + sklearn_trainer.split(".")
    )(**construct_kwargs)
    try:
        sklearn_model.fit(x, y, sample_weight=w, **fit_kwargs)
    except TypeError:
        sklearn_model.fit(x, y, **fit_kwargs)
        warnings.warn(
            "Sample weight is not supported for the provided linear model!"
            " Trained model without weighting inputs. For Lasso, please"
            " upgrade sklearn to a version >= 0.23.0."
        )
    t2 = time.time()

    # Convert the learned weights to pytorch and load them into `model`.
    num_outputs = 1 if len(y.shape) == 1 else y.shape[1]
    weight_values = torch.FloatTensor(sklearn_model.coef_)  # type: ignore
    bias_values = torch.FloatTensor([sklearn_model.intercept_])  # type: ignore
    model._construct_model_params(
        norm_type=None,
        weight_values=weight_values.view(num_outputs, -1),
        bias_value=bias_values.squeeze().unsqueeze(0),
    )

    if norm_input:
        model.norm = NormLayer(mean, std)

    return {"train_time": t2 - t1}
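

# Illustrative usage sketch, not part of the library API: builds a small
# synthetic regression dataset and hands it to `sklearn_train_linear_model`.
# The helper name `_example_sklearn_usage`, the synthetic data, and the Lasso
# `alpha` value are assumptions for demonstration only; the caller supplies
# whatever `LinearModel` instance this package provides.
def _example_sklearn_usage(lin_model: LinearModel) -> Dict[str, float]:
    from torch.utils.data import TensorDataset

    # y = X @ b + noise, so a (sparse) linear model should fit it well.
    xs = torch.randn(256, 8)
    ys = xs @ torch.randn(8) + 0.05 * torch.randn(256)
    loader = DataLoader(TensorDataset(xs, ys), batch_size=64)

    # "linear_model.Lasso" is resolved relative to the sklearn package and
    # `alpha` is forwarded to the Lasso constructor; pass norm_input=True to
    # z-score the features before fitting.
    return sklearn_train_linear_model(
        lin_model,
        loader,
        construct_kwargs={"alpha": 0.01},
        sklearn_trainer="linear_model.Lasso",
    )

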
def sgd_train_linear_model(
    model: LinearModel,
    dataloader: DataLoader,
    construct_kwargs: Dict[str, Any],
    max_epoch: int = 100,
    reduce_lr: bool = True,
    initial_lr: float = 0.01,
    alpha: float = 1.0,
    loss_fn: Callable = l2_loss,
    reg_term: Optional[int] = 1,
    patience: int = 10,
    threshold: float = 1e-4,
    running_loss_window: Optional[int] = None,
    device: Optional[str] = None,
    init_scheme: str = "zeros",
    debug: bool = False,
) -> Dict[str, float]:
    r"""
    Trains a linear model with SGD. This will continue to iterate your
    dataloader until it converges to a solution or until `max_epoch` epochs
    have been exhausted. Convergence is defined as the loss changing by less
    than `threshold` for `patience` consecutive iterations.

    Args:
        model
            The model to train.
        dataloader
            The data to train it with. We will assume the dataloader produces
            either pairs or triples of the form (x, y) or (x, y, w), where x
            and y are typical pairs for supervised learning and w is a weight
            vector.

            We will call `model._construct_model_params` with `construct_kwargs`
            and the input features set to `x.shape[1]` (`x.shape[0]` corresponds
            to the batch size). We assume that `len(x.shape) == 2`, i.e. the
            tensor is flat. The number of output features will be set to
            `y.shape[1]` or 1 (if `len(y.shape) == 1`); we require
            `len(y.shape) <= 2`.
        construct_kwargs
            Additional arguments forwarded to `model._construct_model_params`.
        max_epoch
            The maximum number of epochs to exhaust.
        reduce_lr
            Whether or not to reduce the learning rate as iterations progress.
            Halves the learning rate when the training loss does not move. This
            uses torch.optim.lr_scheduler.ReduceLROnPlateau and uses the
            parameters `patience` and `threshold`.
        initial_lr
            The initial learning rate to use.
        alpha
            A constant for the regularization term.
        loss_fn
            The loss to optimise for. This must accept three parameters: the
            targets (y), the predictions, and a weight vector (which may be
            None).
        reg_term
            Regularization is defined by the `reg_term` norm of the weights.
            Please use `None` if you do not wish to use regularization.
        patience
            Defines the number of iterations in a row the loss must remain
            within `threshold` in order to be classified as converged.
        threshold
            Threshold for convergence detection.
        running_loss_window
            Used to report the training loss once we have finished training and
            to determine when we have converged (along with reducing the
            learning rate). The reported training loss will take the last
            `running_loss_window` iterations and average them.

            If `None` we will approximate this to be the number of examples in
            an epoch.
        device
            The device to send the model and data to. If None then no `.to`
            call will be used.
        init_scheme
            Initialization to use prior to training the linear model, either
            "zeros" or "xavier".
        debug
            Whether to print the loss and learning rate at each iteration.

    Returns:
        A dictionary containing the training time, the final training loss
        (averaged over the last `running_loss_window` iterations), the number
        of iterations, and the number of epochs.
    """
    loss_window: List[torch.Tensor] = []
    min_avg_loss = None
    convergence_counter = 0
    converged = False

    def get_point(datapoint):
        if len(datapoint) == 2:
            x, y = datapoint
            w = None
        else:
            x, y, w = datapoint

        if device is not None:
            x = x.to(device)
            y = y.to(device)
            if w is not None:
                w = w.to(device)

        return x, y, w

    # get a point and construct the model
    data_iter = iter(dataloader)
    x, y, w = get_point(next(data_iter))

    model._construct_model_params(
        in_features=x.shape[1],
        out_features=y.shape[1] if len(y.shape) == 2 else 1,
        **construct_kwargs,
    )
    model.train()

    assert model.linear is not None

    if init_scheme is not None:
        assert init_scheme in ["xavier", "zeros"]

        with torch.no_grad():
            if init_scheme == "xavier":
                torch.nn.init.xavier_uniform_(model.linear.weight)
            else:
                model.linear.weight.zero_()

            if model.linear.bias is not None:
                model.linear.bias.zero_()

    optim = torch.optim.SGD(model.parameters(), lr=initial_lr)
    if reduce_lr:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optim, factor=0.5, patience=patience, threshold=threshold
        )

    t1 = time.time()
    epoch = 0
    i = 0
    while epoch < max_epoch:
        while True:  # for x, y, w in dataloader
            # Default the running loss window to roughly one epoch's worth of
            # examples, estimated from the first batch.
            if running_loss_window is None:
                running_loss_window = x.shape[0] * len(dataloader)

            y = y.view(x.shape[0], -1)
            if w is not None:
                w = w.view(x.shape[0], -1)

            i += 1

            out = model(x)

            loss = loss_fn(y, out, w)
            if reg_term is not None:
                reg = torch.norm(model.linear.weight, p=reg_term)
                loss += reg.sum() * alpha

            # Maintain a sliding window of recent losses for convergence
            # detection and reporting.
            if len(loss_window) >= running_loss_window:
                loss_window = loss_window[1:]
            loss_window.append(loss.clone().detach())
            assert len(loss_window) <= running_loss_window

            average_loss = torch.mean(torch.stack(loss_window))
            if min_avg_loss is not None:
                # if we haven't improved by at least `threshold`
                if average_loss > min_avg_loss or torch.isclose(
                    min_avg_loss, average_loss, atol=threshold
                ):
                    convergence_counter += 1
                    if convergence_counter >= patience:
                        converged = True
                        break
                else:
                    convergence_counter = 0
            if min_avg_loss is None or min_avg_loss >= average_loss:
                min_avg_loss = average_loss.clone()

            if debug:
                print(
                    f"lr={optim.param_groups[0]['lr']}, Loss={loss},"
                    + f" Aloss={average_loss}, min_avg_loss={min_avg_loss}"
                )

            loss.backward()

            optim.step()
            model.zero_grad()
            if reduce_lr:
                scheduler.step(average_loss)

            temp = next(data_iter, None)
            if temp is None:
                break
            x, y, w = get_point(temp)

        if converged:
            break

        epoch += 1
        data_iter = iter(dataloader)
        x, y, w = get_point(next(data_iter))

    t2 = time.time()
    return {
        "train_time": t2 - t1,
        "train_loss": torch.mean(torch.stack(loss_window)).item(),
        "train_iter": i,
        "train_epoch": epoch,
    }
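

# Illustrative usage sketch, not part of the library API: the same kind of
# synthetic regression data as above, trained with `sgd_train_linear_model`
# instead of sklearn. The helper name `_example_sgd_usage` and the
# hyperparameter values are assumptions chosen for demonstration only; the
# caller supplies the `LinearModel` instance.
def _example_sgd_usage(lin_model: LinearModel) -> Dict[str, float]:
    from torch.utils.data import TensorDataset

    xs = torch.randn(256, 8)
    ys = xs @ torch.randn(8) + 0.05 * torch.randn(256)
    loader = DataLoader(TensorDataset(xs, ys), batch_size=64)

    # reg_term=1 adds an L1 penalty on the weights scaled by `alpha`;
    # `construct_kwargs` is forwarded to `model._construct_model_params`.
    return sgd_train_linear_model(
        lin_model,
        loader,
        construct_kwargs={"norm_type": None},
        max_epoch=50,
        initial_lr=0.01,
        alpha=1e-3,
        reg_term=1,
        patience=10,
        threshold=1e-4,
    )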