def fit(self, **kwargs: xr.DataArray):
    """
    Train the wrapped PyTorch model on the given data.

    The keyword arguments are split into model inputs and targets by
    ``split_kwargs``. Training runs on CUDA when it is available and on the
    CPU otherwise. It uses ``self.optimizer`` and ``self.loss_fn``. The
    number of epochs and the batch size come from
    ``self.fit_kwargs['epochs']`` and ``self.fit_kwargs['batch_size']``, and
    both are converted with ``int()``. Each epoch iterates over a shuffled
    ``DataLoader`` built from a ``TimeSeriesDataset``. After every epoch the
    learning-rate scheduler is stepped.

    The model is always moved back to the CPU afterwards, even if training
    raises, so a failed fit does not keep GPU memory occupied.
    ``self.is_fitted`` is set to True only when training completes.

    :param kwargs: The input and target data arrays, passed to ``split_kwargs``.
    :raises KeyError: If ``self.fit_kwargs`` has no ``'batch_size'`` or
        ``'epochs'`` entry (assuming ``fit_kwargs`` is a dict).
    """
    x, y = split_kwargs(kwargs)

    # Use the GPU when one is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    batch_size = int(self.fit_kwargs['batch_size'])
    epochs = int(self.fit_kwargs['epochs'])

    dataset = TimeSeriesDataset(xarray_to_numpy(x), xarray_to_numpy(y))
    train_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

    # NOTE(review): StepLR with step_size=1 and PyTorch's default gamma (0.1)
    # divides the learning rate by 10 after every epoch. Confirm that this
    # aggressive decay is intended. A new scheduler is also attached to the
    # same optimizer on every call to fit().
    scheduler = StepLR(self.optimizer, step_size=1)

    try:
        self.model.to(device)
        for _ in range(epochs):
            self.model.train()
            for data, target in train_loader:
                data, target = data.to(device), target.to(device)
                # Gradients accumulate across backward() calls by default
                # (see torch.autograd.backward), so clear the ones left over
                # from the previous batch before computing new ones.
                self.optimizer.zero_grad()
                y_pred = self.model(data)
                loss = self.loss_fn(y_pred, target)
                loss.backward()
                self.optimizer.step()
            # Placeholder for a per-epoch evaluation routine. The model is
            # switched back to train mode at the start of the next epoch.
            self.model.eval()
            scheduler.step()
    finally:
        # Release GPU memory even when training fails.
        self.model.to("cpu")
    self.is_fitted = True
def transform(self, **kwargs: xr.DataArray) -> xr.DataArray:
    """
    Predict with the wrapped PyTorch model.

    The keyword arguments are converted to a NumPy array with
    ``xarray_to_numpy`` and then to a float32 tensor. The tensor is passed
    through the model in a single forward pass. The model runs in eval mode
    without gradient tracking, on CUDA when available and on the CPU
    otherwise. The model is moved back to the CPU afterwards, even if the
    forward pass raises.

    :param kwargs: The input data arrays for which a prediction is performed.
    :return: The prediction, converted back to xarray by ``numpy_to_xarray``.
        That function is called with the first input array and ``self.name``
        (presumably to supply coordinates and a name — confirm against
        ``numpy_to_xarray``).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    x_tensor = torch.from_numpy(xarray_to_numpy(kwargs)).float()

    self.model.eval()
    try:
        self.model.to(device)
        # NOTE(review): the whole input goes through the model at once; very
        # large inputs may not fit in device memory.
        with torch.no_grad():
            output = self.model(x_tensor.to(device))
        pred = output.to("cpu").numpy()
    finally:
        # Release GPU memory even when prediction fails.
        self.model.to("cpu")

    return numpy_to_xarray(pred, list(kwargs.values())[0], self.name)