예제 #1
0
def run(trainData, testData, _config):
    """
    Train a CGFF model and evaluate it on train, validation and test data.

    ``trainData`` is randomly split 80/20 into a training part and a
    validation part (the latter is handed to ``fit`` for early stopping /
    model selection). After fitting, the model stored at
    ``_config["fileName"]`` is loaded back and evaluated.

    Parameters
    ----------
    trainData : dataset exposing ``.shape``; split into train/validation.
    testData : mapping of name -> test dataset (iterated with ``.items()``).
    _config : dict used both as CGFF constructor kwargs and as ``fitParams``;
        must contain "fileName".

    Returns
    -------
    (trainResults, validResults, testResults) where ``testResults`` maps each
    key of ``testData`` to ``model.test`` on that dataset.
    """
    device = utils.getDevice()

    # 80% for training, 20% held out for validation (early stopping / tuning).
    trainIdx, validIdx = utils.getRandomSplit(trainData.shape[0], [80, 20])
    validSubset = SubsetStar(trainData, validIdx)
    trainSubset = SubsetStar(trainData, trainIdx)

    net = CGFF(**_config).to(device)
    net.train()
    net.fit(trainData=trainSubset, validData=validSubset, fitParams=_config)

    # fit() is expected to have written the selected model to this file.
    best = torch.load(_config["fileName"]).to(device)
    best.eval()

    trainResults = best.test(trainSubset)
    validResults = best.test(validSubset)
    testResults = {name: best.test(dataset) for name, dataset in testData.items()}

    return trainResults, validResults, testResults
예제 #2
0
    def computeLoss(self, dataInGPU, loader=None):
        """
        Compute the negative log-likelihood loss plus the model penalty.

        Exactly one input mode is used:

        * ``dataInGPU`` is truthy: it is a batch ``(X, y)`` already on the
          device. Loss and penalty are tensors that keep the autograd graph,
          so the returned total can be back-propagated.
        * otherwise ``loader`` is truthy: a DataLoader over a full dataset.
          The summed NLL is accumulated under ``torch.no_grad`` and divided by
          ``len(loader.dataset)``; loss and penalty are returned as floats.

        Returns
        -------
        (loss + penalty, {"loss": loss, "penalty": penalty})

        Raises
        ------
        AttributeError
            If neither a batch nor a loader is supplied.
        """
        if dataInGPU:
            # Batch mode: differentiable loss for one optimisation step.
            X, y = dataInGPU
            outs = self(X)
            loss = F.nll_loss(outs, y)
            penalty = self.computePenalty()

        elif loader:
            # Full-dataset mode: per-sample mean loss as a float, no grad.
            loss = 0.0
            device = utils.getDevice()
            with torch.no_grad():
                for batchData in loader:
                    X, y = batchData[:]
                    X, y = X.to(device), y.to(device)
                    outs = self.forward(X)
                    # Sum here and divide by the dataset size below so the last,
                    # possibly smaller, batch is weighted correctly.
                    loss += F.nll_loss(outs, y, reduction="sum").item()

                loss /= len(loader.dataset)

                penalty = self.computePenalty().item()
        else:
            raise AttributeError(
                "computeLoss requires either `dataInGPU` (a batch) or `loader`")

        return loss + penalty, {"loss": loss, "penalty": penalty}
예제 #3
0
    def test(self, testData, metrics=None):
        """
        Evaluate the model on ``testData``.

        Batches come from a DataLoader through a background ``producer``
        thread and are consumed here under ``torch.no_grad``.

        Parameters
        ----------
        testData : dataset to evaluate.
        metrics : optional iterable of callables called as
            ``metric(y, outs, X=X, model=self)``. Each metric's per-batch
            results are summed and divided by ``len(testData)``. The result is
            stored under the part of the function's ``__name__`` before the
            first underscore.

        Returns
        -------
        dict with one entry per metric plus "loss" (summed NLL / number of
        samples), "penalty" and "lossAndPenalty".
        """

        testLoader = DataLoader(testData,
                                batch_size=1000,
                                num_workers=5,
                                pin_memory=True)

        device = utils.getDevice()
        testLoss = 0

        # Materialise once (a generator would be exhausted after the first
        # batch) and derive the result names once instead of per batch.
        namedMetrics = [(metric.__name__.split("_")[0], metric)
                        for metric in (list(metrics) if metrics else [])]

        batchQueue = thqueue.Queue(10)
        dataProducer = threading.Thread(target=producer,
                                        args=(device, batchQueue, testLoader))
        dataProducer.start()

        results = {}
        with torch.no_grad():
            while True:
                _, X, y = batchQueue.get()
                if X is None:
                    # End-of-data sentinel from the producer.
                    break

                outs = self.forward(X)
                testLoss += F.nll_loss(outs, y, reduction="sum").item()

                for metricName, metric in namedMetrics:
                    value = metric(y, outs, X=X, model=self)
                    if metricName in results:
                        results[metricName] += value
                    else:
                        results[metricName] = value

        # The producer has already queued its sentinel; reap the thread so it
        # does not outlive this call.
        dataProducer.join()

        testLoss /= len(testData)

        for metricName in results.keys():
            results[metricName] = results[metricName] / len(testData)

        results["loss"] = testLoss

        results["penalty"] = self.computePenalty().item()
        results["lossAndPenalty"] = results["loss"] + results["penalty"]

        return results
예제 #4
0
    def computeLoss(self, data, loader=None):
        """
        Compute the squared-error loss plus the optional CG-layer penalty.

        Exactly one input mode is used:

        * ``data`` is truthy: a batch indexable as ``(_, X, y)``. The loss is
          the mean of the squared errors and keeps the autograd graph.
        * otherwise ``loader`` is truthy: a DataLoader over a full dataset.
          Squared errors are summed under ``torch.no_grad`` and divided by
          ``len(loader.dataset)``. The penalty is also computed without
          gradient tracking. Both are returned as graph-free tensors (not
          Python floats).

        The penalty is ``penaltyAlpha * cgLayer.penalty(mode=penaltyMode,
        T=penaltyT)`` when ``penaltyAlpha > 0`` and ``torch.tensor(0)``
        otherwise.

        Returns
        -------
        (loss + penalty, {"loss": loss, "penalty": penalty})

        Raises
        ------
        AttributeError
            If neither a batch nor a loader is supplied.
        """

        def _penalty():
            # Shared by both branches so they cannot drift apart.
            if self.penaltyAlpha > 0:
                return self.penaltyAlpha * self.cgLayer.penalty(
                    mode=self.penaltyMode, T=self.penaltyT)
            return torch.tensor(0)

        if data:
            _, X, y = data[:]

            device = utils.getDevice()
            X, y = X.to(device), y.to(device)

            outs = self.forward(X)
            loss = ((outs - y)**2).mean()
            penalty = _penalty()
        elif loader:
            # Full-dataset evaluation: no autograd graph is built.
            loss = 0
            device = utils.getDevice()
            with torch.no_grad():
                for batchData in loader:
                    _, X, y = batchData[:]
                    X, y = X.to(device), y.to(device)

                    outs = self.forward(X)

                    loss += ((outs - y)**2).sum()

                loss /= len(loader.dataset)

                # Inside no_grad as well, so the returned penalty carries no
                # autograd graph either (matching the loss).
                penalty = _penalty()
        else:
            raise AttributeError(
                "computeLoss requires either `data` (a batch) or `loader`")

        return loss + penalty, {"loss": loss, "penalty": penalty}
예제 #5
0
    def test(self, testData):
        """
        Evaluate the regression model on ``testData``.

        Returns a dict with, in this order:

        * "loss": summed squared error divided by ``len(testData)``;
        * "accuracy": percentage of elements where the target equals
          ``torch.round`` of the prediction;
        * "penalty": ``penaltyAlpha * cgLayer.penalty(mode=penaltyMode,
          T=penaltyT)`` as a float, or 0.0 when ``penaltyAlpha <= 0``;
        * "lossAndPenalty": "loss" + "penalty".
        """

        device = utils.getDevice()

        loader = DataLoader(testData,
                            batch_size=128,
                            num_workers=2,
                            pin_memory=True)

        sqErrSum = 0
        nCorrect = 0

        with torch.no_grad():
            for batch in loader:
                _, inputs, targets = batch[:]
                inputs, targets = inputs.to(device), targets.to(device)

                preds = self.forward(inputs)
                # NOTE(review): assumes preds and targets have the same shape;
                # e.g. (N, 1) vs (N,) would broadcast to (N, N) — confirm.
                sqErrSum += ((preds - targets)**2).sum().item()
                nCorrect += (targets == torch.round(preds)).float().sum().item()

        nSamples = len(testData)
        meanLoss = sqErrSum / nSamples
        accuracyPct = nCorrect * 100 / nSamples

        results = {"loss": meanLoss, "accuracy": accuracyPct}

        if self.penaltyAlpha > 0:
            penalty = self.penaltyAlpha * self.cgLayer.penalty(
                mode=self.penaltyMode, T=self.penaltyT)
        else:
            penalty = torch.tensor(0.0)
        results["penalty"] = penalty.item()
        results["lossAndPenalty"] = results["loss"] + results["penalty"]

        return results
예제 #6
0
    # Evaluate the model on the training and validation data.
    trainResults = model.test(trainData)
    validResults = model.test(validData)

    # Evaluate on every named test set; results keep the keys of testData.
    testResults = {}
    for key, testData_ in testData.items():
        testResults_ = model.test(testData_)
        testResults[key] = testResults_

    return trainResults, validResults, testResults


if __name__ == '__main__':
    # Make sure the basis directory exists (no error if it already does).
    basisDir = "data/basis"
    os.makedirs(basisDir, exist_ok=True)

    device = utils.getDevice()
    vocabularySize = 100
    nSamples = 10000

    # Dataset class used for both the training and the OOD test data.
    task = SumTaskDataset
    print(task)

    # Training data: sequences of length 10 with inputTransform=np.sort
    # (presumably applied to each input sequence — confirm in SumTaskDataset).
    trainData = task(
        nSamples=nSamples,
        sequenceLength=10,
        vocabularyRange=(1, vocabularySize),
        inputTransform=np.sort,
    )
    # "Out-of-distribution" test data: same settings but without the
    # inputTransform used for training.
    testDataOOD = task(nSamples=nSamples,
                       sequenceLength=10,
                       vocabularyRange=(1, vocabularySize))
예제 #7
0
    def fit(self, trainData, validData, fitParams):
        """
        Fit the model to the training data with SGD and early stopping.

        At the start of every epoch the current weights are scored on
        ``validData``. Whenever the validation loss improves, the whole model
        is written with ``torch.save(self, fitParams["fileName"])``. An
        improvement larger than 0.5% extends the patience to
        ``max(2 * epoch, patience)``.

        Training stops when ``epoch`` exceeds the patience, when
        ``self.kill_now`` is set, or after ``numEpochs`` epochs. In the last
        case, the weights produced by the final epoch are validated (and
        saved if better) too.

        Parameters
        ----------
        trainData : Train Dataset
        validData : Validation Dataset
        fitParams : Dictionary with parameters for fitting.
            Required keys: lr, momentum, fileName, patience, numEpochs.
            Optional keys:
              nMinibatches -- batch size becomes the largest power of two
                              <= N // nMinibatches (at least 1);
              batchSize    -- used when nMinibatches is not given; 0 or a
                              value >= N means full-batch training;
              tqdmDisable  -- disable the epoch progress bar.

        Returns
        -------
        None
        """
        log = utils.getLogger()
        tqdmDisable = fitParams.get("tqdmDisable", False)
        device = utils.getDevice()
        optimizer = optim.SGD(self.parameters(),
                              lr=fitParams["lr"],
                              momentum=fitParams["momentum"])

        fileName = fitParams["fileName"]

        N = len(trainData)
        if fitParams.get("nMinibatches"):
            # Clamp to 1 so log2 never receives 0 (math domain error) when
            # nMinibatches > N.
            batchSize = 1 << int(
                math.log2(max(1, N // fitParams["nMinibatches"])))
        else:
            batchSize = fitParams.get("batchSize", 0)

        if batchSize == 0 or batchSize >= N:
            batchSize = N

        bestValidLoss = np.inf
        patience = fitParams["patience"]
        numEpochs = fitParams["numEpochs"]
        validationFrequency = 1

        trainLoader = DataLoader(
            trainData,
            batch_size=batchSize,
            shuffle=True,
            num_workers=5,
            pin_memory=True,
        )
        validLoader = DataLoader(validData,
                                 batch_size=1000,
                                 shuffle=True,
                                 num_workers=1,
                                 pin_memory=True)

        # Loss of the initial weights, reported in the first log line.
        trainLoss, trainParts = self.computeLoss(dataInGPU=None,
                                                 loader=trainLoader)
        trainPenalty = trainParts["penalty"]

        for epoch in tqdm(range(1, numEpochs + 1),
                          leave=False,
                          desc="Epochs",
                          disable=tqdmDisable):
            # Scores the weights produced by the previous epoch.
            validLoss, validParts = self.computeLoss(dataInGPU=None,
                                                     loader=validLoader)
            validPenalty = validParts["penalty"]

            saved = ""
            if (epoch == 1 or epoch > patience or epoch >= numEpochs or epoch %
                    validationFrequency == 0) and validLoss < bestValidLoss:
                saved = "(Saved model)"
                torch.save(self, fileName)
                # A >0.5% improvement extends the early-stopping window.
                if validLoss < 0.995 * bestValidLoss:
                    patience = max(epoch * 2, patience)
                bestValidLoss = validLoss

            if epoch > patience:
                break

            log.info(
                f"{epoch} out of {min(patience, numEpochs)} | Train Loss: {trainLoss:.4f} ({trainPenalty:.4f}) | Valid Loss: {validLoss:.4f} ({validPenalty: .4f}) | {saved}"
            )

            trainLoss = 0.0
            trainPenalty = 0.0
            nBatches = 0

            batchQueue = thqueue.Queue(10)
            dataProducer = threading.Thread(target=producer,
                                            args=(device, batchQueue,
                                                  trainLoader))
            dataProducer.start()

            startTime = time.time()
            while True:
                _, X, y = batchQueue.get()
                if X is None:
                    # End-of-data sentinel from the producer.
                    break
                optimizer.zero_grad()  # zero the gradient buffer
                batchTrainLoss, lossParts = self.computeLoss((X, y))
                batchTrainLoss.backward()
                optimizer.step()
                trainLoss += float(batchTrainLoss.detach().cpu())
                trainPenalty += float(lossParts["penalty"].detach().cpu())
                nBatches += 1

            # The sentinel has been received; reap the producer thread.
            dataProducer.join()

            epochTime = time.time() - startTime
            log.debug(f"Time taken this epoch: {epochTime}")
            # Average over the batches actually consumed (independent of what
            # index the sentinel carries); max() guards an empty loader.
            trainLoss /= max(nBatches, 1)
            trainPenalty /= max(nBatches, 1)

            if self.kill_now:
                break
        else:
            # Runs only if the loop finished without `break`. Validation above
            # happens at the start of each epoch, so without this the weights
            # produced by the final epoch would never be compared against the
            # best checkpoint (and with numEpochs == 0 nothing would be saved).
            validLoss, validParts = self.computeLoss(dataInGPU=None,
                                                     loader=validLoader)
            saved = ""
            if validLoss < bestValidLoss:
                saved = "(Saved model)"
                torch.save(self, fileName)
                bestValidLoss = validLoss
            log.info(
                f"Final | Train Loss: {trainLoss:.4f} ({trainPenalty:.4f}) | Valid Loss: {validLoss:.4f} ({validParts['penalty']:.4f}) | {saved}"
            )
예제 #8
0
    def __init__(
        self,
        in_channels,
        out_channels,
        invariant_transforms,
        kernel_size=3,
        stride=1,
        padding=0,
        bias=True,
        precomputed_basis_folder=".",
        penaltyAlpha=0,
    ):
        """
        Build a convolution-style layer whose kernels are linear combinations
        of a precomputed basis for the given invariant transforms.

        Parameters
        ----------
        in_channels, out_channels : number of input / output channels.
        invariant_transforms : list of group transforms given as their
            Reynolds-operator functions, e.g. ``[IS.G_rotation, IS.G_flip]``.
        kernel_size, stride, padding : int or pair; normalised with ``_pair``.
        bias : if True, create a bias parameter of size ``out_channels``;
            otherwise register ``bias`` as None.
        precomputed_basis_folder : folder passed to ``self.getBasis`` when the
            basis is not yet in the module-level ``loadedBases`` cache.
        penaltyAlpha : stored as ``self.penaltyAlpha``.
        """
        super().__init__()

        # Keep the raw (not yet pair-normalised) kernel_size argument.
        self.ksize = kernel_size

        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

        # A list of m groups for this layer given in the form of their respective Reynolds operator function
        # For example: [IS.G_rotation, IS.G_flip]
        self.invariant_transforms = invariant_transforms
        self.penaltyAlpha = penaltyAlpha

        if IS.G_color_permutation in self.invariant_transforms:
            # Channels are part of the basis shape, so the weights get a
            # singleton channel axis (sameBasisAcross = 1).
            self.Wshape = (self.in_channels, *self.kernel_size)
            sameBasisAcross = 1
        else:
            # The basis covers a single spatial kernel; separate coefficients
            # are learned for each input channel.
            self.Wshape = self.kernel_size
            sameBasisAcross = self.in_channels

        # Load the basis from the specified folder, caching it in the
        # module-level `loadedBases` dict keyed by basis file name. Layers with
        # the same transforms and Wshape therefore share one basis tensor.
        # NOTE(review): the cache key does not include precomputed_basis_folder,
        # so the first folder used for a given key wins.
        basisFileName = IS._getBasisFileName(listT=self.invariant_transforms,
                                             Wshape=self.Wshape)
        if basisFileName not in loadedBases:
            self.invariant_transforms, basisList, self.basisConfigs = self.getBasis(
                precomputed_basis_folder)
            # Transpose each basis and stack them row-wise into one tensor;
            # basisShapes records the number of rows contributed by each one.
            basisList = [torch.Tensor(basis.T) for basis in basisList]
            self.basisShapes = [basis.shape[0] for basis in basisList]
            # NOTE(review): self.basis is a plain attribute, not a registered
            # buffer, so Module.to()/.cuda() will not move it; it is placed on
            # utils.getDevice() here.
            self.basis = torch.cat(basisList).to(utils.getDevice())
            loadedBases[basisFileName] = (self.invariant_transforms,
                                          self.basisConfigs, self.basisShapes,
                                          self.basis)

        # Always read back from the cache (a no-op right after a fresh load).
        self.invariant_transforms, self.basisConfigs, self.basisShapes, self.basis = loadedBases[
            basisFileName]

        # Weights are the parameters of the linear combination corresponding to each basis vector.
        # Allocated uninitialised; presumably initialised by
        # self.reset_parameters() below — confirm.
        self.weights = nn.Parameter(
            torch.Tensor(out_channels, sameBasisAcross, self.basis.shape[0],
                         1))

        # NOTE(review): uses `Parameter` here but `nn.Parameter` above; both
        # names must be in scope at module level.
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()