def run(trainData, testData, _config):
    """ Run model on given train, test data and compute metrics """
    device = utils.getDevice()

    # Use 20% of the training data as validation
    # (for early stopping, hyperparameter tuning)
    trainIdx, validIdx = utils.getRandomSplit(trainData.shape[0], [80, 20])
    validData = SubsetStar(trainData, validIdx)
    trainData = SubsetStar(trainData, trainIdx)

    model = CGFF(**_config).to(device)
    model.train()
    model.fit(
        trainData=trainData,
        validData=validData,
        fitParams=_config,
    )

    # Reload the best checkpoint saved during fitting
    model = torch.load(_config["fileName"]).to(device)
    model.eval()

    trainResults = model.test(trainData)
    validResults = model.test(validData)
    testResults = {}
    for key, testData_ in testData.items():
        testResults[key] = model.test(testData_)
    return trainResults, validResults, testResults
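# For context, run() relies on two helpers that are not part of this excerpt.
# The sketches below are hypothetical, reconstructed only from how they are
# called above (getRandomSplit splits indices by percentage; SubsetStar views
# a dataset through a subset of indices, like torch.utils.data.Subset).
import numpy as np
from torch.utils.data import Dataset


def getRandomSplit(n, percents):
    """Shuffle range(n) and split it into index arrays by percentage."""
    idx = np.random.permutation(n)
    cuts = np.cumsum([p * n // 100 for p in percents])[:-1]
    return np.split(idx, cuts)


class SubsetStar(Dataset):
    """Expose only the given indices of the wrapped dataset."""

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = list(indices)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        return self.dataset[self.indices[i]]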
def computeLoss(self, dataInGPU, loader=None):
    if dataInGPU:
        # A single batch, already on the GPU, is given (keeps grad).
        X, y = dataInGPU
        outs = self(X)
        loss = F.nll_loss(outs, y)
        penalty = self.computePenalty()
    elif loader:
        # A loader over the full dataset is given.
        # Returns the loss as a float (no grad).
        loss = 0
        device = utils.getDevice()
        with torch.no_grad():
            for batchIdx, batchData in enumerate(loader):
                X, y = batchData[:]
                X, y = X.to(device), y.to(device)
                outs = self.forward(X)
                loss += F.nll_loss(outs, y, reduction="sum").item()
        loss /= len(loader.dataset)
        penalty = self.computePenalty().item()
    else:
        raise AttributeError("Either dataInGPU or loader must be given")
    return loss + penalty, {"loss": loss, "penalty": penalty}
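# The two branches correspond to the two ways fit() calls this method: a
# single GPU-resident batch during training (differentiable), and a full pass
# over a loader for logging (detached floats). Roughly (`model` and the
# loader names here are placeholders):
#
#     # During training: an (X, y) batch already on the device; keeps grad.
#     batchLoss, parts = model.computeLoss((X, y))
#     batchLoss.backward()
#
#     # For logging: a DataLoader over a whole split; loss is a plain float.
#     fullLoss, parts = model.computeLoss(dataInGPU=None, loader=trainLoader)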
def test(self, testData, metrics=None):
    """ Test the model and compute metrics """
    testLoader = DataLoader(testData,
                            batch_size=1000,
                            num_workers=5,
                            pin_memory=True)
    device = utils.getDevice()
    testLoss = 0

    # Stream batches to the GPU from a background thread
    queue = thqueue.Queue(10)
    dataProducer = threading.Thread(target=producer,
                                    args=(device, queue, testLoader))
    dataProducer.start()

    results = {}
    with torch.no_grad():
        while True:
            batchIdx, X, y = queue.get()
            if X is None:
                break
            outs = self.forward(X)
            testLoss += F.nll_loss(outs, y, reduction="sum").item()
            if metrics:
                for metric in metrics:
                    # Metrics are keyed by the part of their name before
                    # the first underscore
                    metricName = metric.__name__.split("_")[0]
                    if metricName in results:
                        results[metricName] += metric(y, outs, X=X, model=self)
                    else:
                        results[metricName] = metric(y, outs, X=X, model=self)

    testLoss /= len(testData)
    for metricName in results.keys():
        results[metricName] = results[metricName] / len(testData)
    results["loss"] = testLoss
    results["penalty"] = self.computePenalty().item()
    results["lossAndPenalty"] = results["loss"] + results["penalty"]
    return results
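# test() (and fit() below) depend on a producer helper that is not shown in
# this excerpt. The sketch below is an assumption reconstructed from the
# consumer loop: it must push (batchIdx, X, y) tuples already moved to the
# device and end with an (idx, None, None) sentinel, since the consumer
# breaks when X is None.
def producer(device, queue, loader):
    batchIdx = -1
    for batchIdx, batchData in enumerate(loader):
        # Batches may be (X, y) or (idx, X, y); the last two entries are
        # assumed to be inputs and targets.
        X, y = batchData[-2], batchData[-1]
        queue.put((batchIdx,
                   X.to(device, non_blocking=True),
                   y.to(device, non_blocking=True)))
    queue.put((batchIdx, None, None))


# Metrics passed to test() are called as metric(y, outs, X=X, model=self) and
# accumulated per batch. A hypothetical example, recorded under
# results["accuracy"] (test() divides by len(testData) afterwards):
def accuracy_sum(y, outs, X=None, model=None):
    return (outs.argmax(dim=1) == y).float().sum().item()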
def computeLoss(self, data, loader=None):
    if data:
        # A single batch is given (keeps grad).
        _, X, y = data[:]
        device = utils.getDevice()
        X, y = X.to(device), y.to(device)
        outs = self.forward(X)
        loss = ((outs - y) ** 2).mean()
        penalty = torch.tensor(0.0)
        if self.penaltyAlpha > 0:
            penalty = self.penaltyAlpha * self.cgLayer.penalty(
                mode=self.penaltyMode, T=self.penaltyT)
    elif loader:
        # A loader over the full dataset is given.
        # Returns the loss as a float (no grad).
        loss = 0
        device = utils.getDevice()
        with torch.no_grad():
            for batchIdx, batchData in enumerate(loader):
                _, X, y = batchData[:]
                X, y = X.to(device), y.to(device)
                outs = self.forward(X)
                loss += ((outs - y) ** 2).sum().item()
        loss /= len(loader.dataset)
        penalty = torch.tensor(0.0)
        if self.penaltyAlpha > 0:
            penalty = self.penaltyAlpha * self.cgLayer.penalty(
                mode=self.penaltyMode, T=self.penaltyT)
    else:
        raise AttributeError("Either data or loader must be given")
    return loss + penalty, {"loss": loss, "penalty": penalty}
def test(self, testData):
    """ Test the model """
    device = utils.getDevice()
    testLoader = DataLoader(testData,
                            batch_size=128,
                            num_workers=2,
                            pin_memory=True)
    testLoss = 0
    testAcc = 0
    results = {}
    with torch.no_grad():
        for batchIdx, batchData in enumerate(testLoader):
            _, X, y = batchData[:]
            X, y = X.to(device), y.to(device)
            outs = self.forward(X)
            testLoss += ((outs - y) ** 2).sum().item()
            # Exact-match accuracy after rounding the regression output
            testAcc += (y == torch.round(outs)).float().sum().item()

    testLoss /= len(testData)
    testAcc = testAcc * 100 / len(testData)
    results["loss"] = testLoss
    results["accuracy"] = testAcc

    penalty = torch.tensor(0.0)
    if self.penaltyAlpha > 0:
        penalty = self.penaltyAlpha * self.cgLayer.penalty(
            mode=self.penaltyMode, T=self.penaltyT)
    results["penalty"] = penalty.item()
    results["lossAndPenalty"] = results["loss"] + results["penalty"]
    return results
if __name__ == '__main__':
    basisDir = "data/basis"
    os.makedirs(basisDir, exist_ok=True)
    device = utils.getDevice()

    vocabularySize = 100
    nSamples = 10000
    task = SumTaskDataset
    print(task)

    trainData = task(
        nSamples=nSamples,
        sequenceLength=10,
        vocabularyRange=(1, vocabularySize),
        inputTransform=np.sort,
    )
    # Test set built without the np.sort input transform used for training
    testDataOOD = task(nSamples=nSamples,
                       sequenceLength=10,
                       vocabularyRange=(1, vocabularySize))
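# SumTaskDataset is not defined in this excerpt. The sketch below is a
# hypothetical reconstruction from its usage here and from the `_, X, y`
# unpacking elsewhere: items are assumed to be (index, sequence, sum) triples.
import numpy as np
import torch
from torch.utils.data import Dataset


class SumTaskDataset(Dataset):
    def __init__(self, nSamples, sequenceLength, vocabularyRange,
                 inputTransform=None):
        low, high = vocabularyRange
        seqs = np.random.randint(low, high + 1,
                                 size=(nSamples, sequenceLength))
        if inputTransform is not None:
            # e.g. np.sort, which sorts along the last (sequence) axis
            seqs = inputTransform(seqs)
        self.X = torch.as_tensor(seqs, dtype=torch.float32)
        self.y = self.X.sum(dim=1)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        return i, self.X[i], self.y[i]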
def fit(self, trainData, validData, fitParams):
    """ Fit the model to the training data

    Parameters
    ----------
    trainData : Train Dataset
    validData : Validation Dataset
    fitParams : Dictionary with parameters for fitting.
        (lr, momentum, weightDecay(l2), lrSchedulerStepSize, fileName,
        batchSize, lossName, patience, numEpochs)

    Returns
    -------
    None
    """
    log = utils.getLogger()
    tqdmDisable = fitParams.get("tqdmDisable", False)
    device = utils.getDevice()
    optimizer = optim.SGD(self.parameters(),
                          lr=fitParams["lr"],
                          momentum=fitParams["momentum"])
    fileName = fitParams["fileName"]

    # Pick the batch size: either derive it from the requested number of
    # minibatches (rounded down to a power of two), or use the given value;
    # fall back to full-batch training.
    N = len(trainData)
    if fitParams.get("nMinibatches"):
        batchSize = 1 << int(math.log2(N // fitParams["nMinibatches"]))
    else:
        batchSize = fitParams.get("batchSize", 0)
    if batchSize == 0 or batchSize >= len(trainData):
        batchSize = len(trainData)

    bestValidLoss = np.inf
    patience = fitParams["patience"]
    numEpochs = fitParams["numEpochs"]
    validationFrequency = 1

    trainLoader = DataLoader(
        trainData,
        batch_size=batchSize,
        shuffle=True,
        num_workers=5,
        pin_memory=True,
    )
    validLoader = DataLoader(validData,
                             batch_size=1000,
                             shuffle=True,
                             num_workers=1,
                             pin_memory=True)

    trainLoss, lossParts = self.computeLoss(dataInGPU=None, loader=trainLoader)
    trainPenalty = lossParts["penalty"]
    epochTime = 0
    for epoch in tqdm(range(1, numEpochs + 1),
                      leave=False,
                      desc="Epochs",
                      disable=tqdmDisable):
        validLoss, lossParts = self.computeLoss(dataInGPU=None,
                                                loader=validLoader)
        validPenalty = lossParts["penalty"]

        saved = ""
        if (epoch == 1 or epoch > patience or epoch >= numEpochs
                or epoch % validationFrequency == 0) and validLoss < bestValidLoss:
            # saved = "(Saved to {})".format(fileName)
            saved = "(Saved model)"
            torch.save(self, fileName)
            # Extend patience when the validation loss improves markedly
            if validLoss < 0.995 * bestValidLoss:
                patience = np.max([epoch * 2, patience])
            bestValidLoss = validLoss

        if epoch > patience:
            break

        log.info(
            f"{epoch} out of {min(patience, numEpochs)} | "
            f"Train Loss: {trainLoss:.4f} ({trainPenalty:.4f}) | "
            f"Valid Loss: {validLoss:.4f} ({validPenalty:.4f}) | {saved}"
        )

        trainLoss = 0.0
        trainPenalty = 0.0

        # Stream batches to the GPU from a background thread
        queue = thqueue.Queue(10)
        dataProducer = threading.Thread(target=producer,
                                        args=(device, queue, trainLoader))
        dataProducer.start()

        startTime = time.time()
        while True:
            batchIdx, X, y = queue.get()
            if X is None:
                break
            optimizer.zero_grad()  # zero the gradient buffers
            batchTrainLoss, lossAndPenalty = self.computeLoss((X, y))
            batchTrainLoss.backward()
            optimizer.step()
            trainLoss += float(batchTrainLoss.detach().cpu())
            trainPenalty += float(lossAndPenalty["penalty"].detach().cpu())
        epochTime = time.time() - startTime
        log.debug(f"Time taken this epoch: {epochTime}")

        trainLoss /= batchIdx + 1
        trainPenalty /= batchIdx + 1

        if self.kill_now:
            break
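# For reference, a fitParams dictionary covering the keys fit() actually
# reads might look like the following; the values are placeholders, not
# tuned settings, and the checkpoint path is hypothetical.
fitParams = {
    "lr": 0.01,
    "momentum": 0.9,
    "fileName": "checkpoints/cgff_best.pt",  # where the best model is saved
    "patience": 20,        # initial patience; extended on large improvements
    "numEpochs": 500,
    "batchSize": 128,      # or "nMinibatches": 10 to derive a power-of-two size
    "tqdmDisable": False,  # optional
}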
def __init__(
    self,
    in_channels,
    out_channels,
    invariant_transforms,
    kernel_size=3,
    stride=1,
    padding=0,
    bias=True,
    precomputed_basis_folder=".",
    penaltyAlpha=0,
):
    super().__init__()
    self.ksize = kernel_size
    kernel_size = _pair(kernel_size)
    stride = _pair(stride)
    padding = _pair(padding)
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.kernel_size = kernel_size
    self.stride = stride
    self.padding = padding

    # A list of m groups for this layer, each given in the form of its
    # respective Reynolds operator function.
    # For example: [IS.G_rotation, IS.G_flip]
    self.invariant_transforms = invariant_transforms
    self.penaltyAlpha = penaltyAlpha

    if IS.G_color_permutation in self.invariant_transforms:
        # Channels are part of the basis
        self.Wshape = (self.in_channels, *self.kernel_size)
        sameBasisAcross = 1
    else:
        # The same basis is shared across all input channels
        self.Wshape = self.kernel_size
        sameBasisAcross = self.in_channels

    # Load the basis from the specified folder, caching it across layers.
    basisFileName = IS._getBasisFileName(listT=self.invariant_transforms,
                                         Wshape=self.Wshape)
    if basisFileName not in loadedBases:
        self.invariant_transforms, basisList, self.basisConfigs = self.getBasis(
            precomputed_basis_folder)
        basisList = [torch.Tensor(basis.T) for basis in basisList]
        self.basisShapes = [basis.shape[0] for basis in basisList]
        self.basis = torch.cat(basisList).to(utils.getDevice())
        loadedBases[basisFileName] = (self.invariant_transforms,
                                      self.basisConfigs, self.basisShapes,
                                      self.basis)
    self.invariant_transforms, self.basisConfigs, self.basisShapes, self.basis = (
        loadedBases[basisFileName])

    # Weights are the parameters of the linear combination corresponding to
    # each basis vector.
    self.weights = nn.Parameter(
        torch.Tensor(out_channels, sameBasisAcross, self.basis.shape[0], 1))
    if bias:
        self.bias = nn.Parameter(torch.Tensor(out_channels))
    else:
        self.register_parameter("bias", None)
    self.reset_parameters()
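# Hypothetical usage of the layer above. Its class name is not shown in this
# excerpt, so CGConv2d is a stand-in; the transforms follow the example named
# in the comment inside __init__, and the basis folder matches basisDir from
# the __main__ block.
layer = CGConv2d(
    in_channels=3,
    out_channels=16,
    invariant_transforms=[IS.G_rotation, IS.G_flip],
    kernel_size=3,
    padding=1,
    precomputed_basis_folder="data/basis",
    penaltyAlpha=0.1,
)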