Example #1
def getAllBasis(folder, listT, Wshape, powerSetMethod=None, force=False):
    """
    Loads the basis from file, creating all subspaces first if the file does not exist.
    :param folder: Folder to load from or save to.
    :param listT: List of Reynolds operators (functions)
    :param Wshape: Shape of input/weights
    :param powerSetMethod: Method to iterate over the power set of listT (i.e., the different subsets of listT).
    :param force: Force re-create all subspaces.
    :return: Same as `loadAllBasis` function.
             listT: Lexicographically sorted list of Reynolds operators (functions)
             allSubspaces: Lattice of subspaces
             powerSetConfigs: The subsets of listT (encoded by indices) that have nonempty subspaces.
    """
    _log = utils.getLogger()
    try:
        assert not force  # when force is set, skip loading and re-create below
        basis = loadAllBasis(folder, listT, Wshape)
        return basis
    except Exception:
        _log.info(
            f"Failed to load basis file. Creating basis file at {folder}.")
        allSubspaces, powerSetConfigs = findAllSubspacesFromTransforms(
            listT=listT,
            Wshape=Wshape,
            powerSetMethod=powerSetMethod,
            show=False)
        fileName = saveAllBasis(folder=folder,
                                listT=listT,
                                Wshape=Wshape,
                                allSubspaces=allSubspaces,
                                powerSetConfigs=powerSetConfigs)

        _log.info(f"Saved basis file: {os.path.join(folder, fileName)}.")

    return loadAllBasis(folder, listT, Wshape)
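
A usage sketch under assumptions: IS.G_rotation and IS.G_flip are Reynolds operator functions (as in the later examples), and "./basisCache" is a hypothetical cache directory.

# Loads cached subspaces if a matching basis file exists; otherwise builds
# the lattice via findAllSubspacesFromTransforms and saves it to folder.
listT, allSubspaces, powerSetConfigs = getAllBasis(
    folder="./basisCache",
    listT=[IS.G_rotation, IS.G_flip],
    Wshape=(3, 3, 3))  # (channels, kernel_size, kernel_size)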
Example #2
    def __init__(self, **modelParams):
        super().__init__()

        self.imageSize = modelParams.get("imageSize", 28)
        self.inputChannels = modelParams.get("inputChannels", 3)
        self.kernelSize = modelParams.get("kernelSize", 3)
        self.stride = modelParams.get("stride", 1)
        self.padding = modelParams.get("padding", 1)
        self.nOutputs = modelParams.get("nOutputs", 10)

        # Folder name where the subspaces (the basis vectors) are stored.
        self.precomputedBasisFolder = modelParams.get("precomputedBasisFolder",
                                                      ".")
        _log = utils.getLogger()

        # For each layer, get the list of group names (e.g., ["rotation", "flip"])
        listInvariantTransforms = modelParams.get("listInvariantTransforms",
                                                  [None] * 4)
        listInvariantTransforms = [
            ["trivial"] if it is None or it == [] else it
            for it in listInvariantTransforms
        ]

        _log.info(
            f"Invariant transforms for layers: {listInvariantTransforms}")

        # For each layer and for each group, get the respective Reynolds operator function
        # Example: "rotation" -> IS.G_rotation
        listInvariantTransforms = [
            IS.getTransformationsFromNames(it)
            for it in listInvariantTransforms
        ]
        self.listInvariantTransforms = listInvariantTransforms

        os.makedirs(self.precomputedBasisFolder, exist_ok=True)

        # Strength and temperature for the penalty.
        self.penaltyAlpha = modelParams.get("penaltyAlpha", 0)
        self.penaltyMode = modelParams.get("penaltyMode", "simple")
        self.penaltyT = modelParams.get("penaltyT", 1)

        # Different architectures
        architecture = modelParams.get("architecture", "simple")
        if architecture == "simple":
            architecture = [10, 10, 'M', 20, 20, 'M']
        elif self.imageSize >= 32:
            architecture = [
                64, 'M', 128, 'M', 128, 128, 'M', 128, 128, 'M', 128, 128, 'M'
            ]
        else:
            architecture = [64, 'M', 128, 'M', 128, 128, 'M', 128, 128, 'M']

        self.convLayers, linearSize = self._make_layers(architecture)
        self.fc1 = nn.Linear(linearSize, 50)
        self.fc2 = nn.Linear(50, self.nOutputs)

        # Signals to be handled by the GracefulKiller parent class.
        # On receiving them, exit_gracefully sets self.kill_now to True.
        signal.signal(signal.SIGINT, self.exit_gracefully)
        signal.signal(signal.SIGTERM, self.exit_gracefully)
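
A sketch of how this constructor might be invoked. The class name Model is a placeholder, and the transform names are assumed to be ones that IS.getTransformationsFromNames accepts.

model = Model(
    imageSize=32,
    inputChannels=3,
    nOutputs=10,
    precomputedBasisFolder="./basisCache",
    # One entry per invariant layer; None (or []) falls back to the trivial group.
    listInvariantTransforms=[["rotation", "flip"], ["rotation"], None, None],
    penaltyAlpha=0.1,
    penaltyMode="simple",
    architecture="simple")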
Example #3
def findAllSubspacesFromTransforms(listT,
                                   Wshape,
                                   powerSetMethod=None,
                                   method=2,
                                   decimals=6,
                                   show=False):
    """
    Given a list of Reynolds operators of various groups, computes their respective 1-eigenspaces
    and calls `findAllSubspaces` function to find the full lattice of subspaces.

    :param listT: List of Reynolds operator transformations for various groups (e.g., [G_rotation, G_flip]).
    :param Wshape: Shape of input/weights.
    :param powerSetMethod: Method to iterate over the power set of listT (i.e., the different subsets of listT).
    :param method: Selects among alternate methods for computing subspace intersections.
    :param decimals: Number of decimal places used for numerical rounding.
    :param show: Show the subspaces for debugging purposes.
    :return: allSubspaces: All the nonempty subspaces arranged in non-increasing order of invariance.
             powerSetConfigs: All the subsets of listT (encoded by indices) with nonempty subspace in the lattice.
    """

    _log = utils.getLogger()
    listT = _sortTransforms(listT)

    # For every i, construct the 1-eigenspace of the Reynolds operator listT[i].
    listV = []
    for transform in listT:
        A, Ac, S = getInvariantSubspace(Wshape, transform)
        listV.append(A)
        if show:
            _log.debug(
                f"Invariant subspace of transform {transform.__name__}: {A.shape}"
            )
            showSubspace(A, Wshape, 1, channels=True)

    allSubspaces, powerSetConfigs = findAllSubspaces(
        listV=listV,
        listT=listT,
        powerSetMethod=powerSetMethod,
        method=method,
        decimals=decimals)

    if show:
        for subspace, config in zip(allSubspaces, powerSetConfigs):
            showSubspace(subspace, Wshape, 1, channels=True)

    return allSubspaces, powerSetConfigs
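
A hypothetical call, again assuming IS.G_rotation and IS.G_flip are the Reynolds operator functions used throughout these examples:

allSubspaces, powerSetConfigs = findAllSubspacesFromTransforms(
    listT=[IS.G_rotation, IS.G_flip],
    Wshape=(3, 3, 3))
# powerSetConfigs[i] encodes (by indices into listT) the subset of operators
# whose 1-eigenspaces intersect to give allSubspaces[i].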
Example #4
def getInvariantSubspace(Wshape, transformation):
    """
    Given the function form of the Reynolds operator of a group, e.g., G_rotation, obtain its left 1-eigenspace.
    :param Wshape: Shape of the input/weights.
                   For instance, for images Wshape=(channels, kernel_size, kernel_size);
                   for sequences, Wshape=(sequence_length, dimension).
    :param transformation: The function form of Reynolds operator of the group, e.g., G_rotation.
    :return: Returns the left 1-eigenspace of shape (np.prod(Wshape), numEigenvectors),
                its complement and the eigenvalues.
    """

    Wsize = np.prod(Wshape)
    max_samples = Wsize

    try:
        ncpus = int(os.environ["SLURM_JOB_CPUS_PER_NODE"])
    except KeyError:
        ncpus = tmp.cpu_count()

    _log = utils.getLogger()
    _log.debug(f"Using {ncpus} CPUs to get {transformation.__name__}")

    pool = tmp.Pool(ncpus)

    ans = pool.map(
        partial(constructTbar_onehot,
                transformation=transformation,
                Wshape=Wshape), list(range(max_samples)))
    pool.close()
    pool.join()

    # Reynolds operator for the transformation
    Tbar = np.stack(ans, axis=1)

    # Tbar is symmetric, so use eigh. Eigenvalues in S come out in ascending
    # order; the corresponding eigenvectors are the columns of Vh, so the
    # 1-eigenvectors sit in the last columns.
    S, Vh = np.linalg.eigh(Tbar)

    # The Reynolds operator is idempotent (a projection), so its eigenvalues
    # are 0 or 1, and the rank of diag(S) counts the 1-eigenvectors.
    rank = np.linalg.matrix_rank(np.diag(S), hermitian=True)

    # V: eigenvectors associated with eigenvalue 1; Vc: the complement space.
    V = Vh[:, -rank:]
    Vc = Vh[:, :-rank]

    return V, Vc, S
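
A self-contained numpy sketch of the same computation for a toy flip group acting on a length-3 weight vector (no multiprocessing; flip here is an illustrative Reynolds operator, not one from the library):

import numpy as np

def flip(w):
    # Reynolds operator of the 2-element flip group: average over {identity, reversal}.
    return 0.5 * (w + w[::-1])

Wsize = 3
# Build Tbar column by column by applying the operator to one-hot vectors,
# mirroring what constructTbar_onehot does in parallel above.
Tbar = np.stack([flip(np.eye(Wsize)[i]) for i in range(Wsize)], axis=1)

S, Vh = np.linalg.eigh(Tbar)  # eigenvalues ascending: [0, 1, 1]
rank = np.linalg.matrix_rank(np.diag(S), hermitian=True)
V = Vh[:, -rank:]  # the 1-eigenspace
print(V.shape)  # (3, 2): flip-invariant vectors satisfy w[0] == w[2]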
Example #5
    def fit(self, trainData, validData, fitParams):
        """
        Fit the model to the training data
        Parameters
        ----------
        trainData : Train Dataset
        validData : Validation Dataset
        fitParams : Dictionary with parameters for fitting.
                    (lr, weightDecay (L2), lrSchedulerStepSize, fileName, batchSize, lossName, patience, numEpochs)
        Returns
        -------
        None
        """
        log = utils.getLogger()
        tqdmDisable = fitParams.get("tqdmDisable", False)
        optimizer = optim.Adam(self.parameters(),
                               lr=fitParams["lr"],
                               weight_decay=fitParams["weightDecay"])
        fileName = fitParams["fileName"]

        N = len(trainData)
        if fitParams.get("nMinibatches"):
            # Round N // nMinibatches down to the nearest power of two.
            batchSize = 1 << int(math.log2(N // fitParams["nMinibatches"]))
        else:
            batchSize = fitParams.get("batchSize", 0)

        if batchSize == 0 or batchSize >= len(trainData):
            batchSize = len(trainData)

        bestValidLoss = np.inf
        patience = fitParams["patience"]
        numEpochs = fitParams["numEpochs"]
        validationFrequency = 1

        trainLoader = DataLoader(trainData, batch_size=batchSize)

        trainLoss, trainDict = self.computeLoss(data=None, loader=trainLoader)
        trainLoss = trainLoss.item()
        trainPenalty = trainDict["penalty"].item()

        for epoch in tqdm(range(1, numEpochs + 1),
                          leave=False,
                          desc="Epochs",
                          disable=tqdmDisable):
            validLoss, validDict = self.computeLoss(validData)
            validLoss = validLoss.item()

            saved = ""
            if (epoch == 1 or epoch > patience or epoch >= numEpochs or epoch %
                    validationFrequency == 0) and validLoss < bestValidLoss:
                saved = "(Saved model)"
                torch.save(self, fileName)
                # Extend patience when the improvement is significant (>0.5%).
                if validLoss < 0.995 * bestValidLoss:
                    patience = max(epoch * 2, patience)
                bestValidLoss = validLoss

            if epoch > patience:
                break

            log.info(
                f"{epoch} out of {min(patience, numEpochs)} | "
                f"Train Loss: {trainLoss:.4f} (Penalty: {trainPenalty:.4f}) | "
                f"Valid Loss: {validLoss:.4f} (Penalty: {validDict['penalty'].item():.4f}) | "
                f"{saved}\r")

            trainLoss = 0.0
            trainPenalty = 0.0
            batchIdx = 0
            for batchIdx, batchTrainData in enumerate(
                    tqdm(trainLoader,
                         leave=False,
                         desc="Minibatches",
                         disable=tqdmDisable)):
                optimizer.zero_grad()  # zero the gradient buffer
                batchTrainLoss, batchTrainDict = self.computeLoss(
                    batchTrainData)
                trainLoss += batchTrainLoss.item()
                trainPenalty += batchTrainDict["penalty"].item()
                batchTrainLoss.backward()
                optimizer.step()

            trainLoss /= batchIdx + 1
            trainPenalty /= batchIdx + 1

            if self.kill_now:
                break
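
A sketch of the fit call with hypothetical values for the keys this method actually reads (trainData and validData are assumed to be PyTorch Datasets):

fitParams = {
    "lr": 1e-3,
    "weightDecay": 1e-4,        # L2 regularization passed to Adam
    "fileName": "best_model.pt",
    "batchSize": 128,           # alternatively, set "nMinibatches"
    "patience": 20,
    "numEpochs": 200,
}
model.fit(trainData, validData, fitParams)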
Example #6
    def fit(self, trainData, validData, fitParams):
        """
        Fit the model to the training data
        Parameters
        ----------
        trainData : Train Dataset
        validData : Validation Dataset
        fitParams : Dictionary with parameters for fitting.
                    (lr, weightDecay (L2), lrSchedulerStepSize, fileName, batchSize, lossName, patience, numEpochs)
        Returns
        -------
        None
        """
        log = utils.getLogger()
        tqdmDisable = fitParams.get("tqdmDisable", False)
        device = utils.getDevice()
        optimizer = optim.SGD(self.parameters(),
                              lr=fitParams["lr"],
                              momentum=fitParams["momentum"])

        fileName = fitParams["fileName"]

        N = len(trainData)
        if fitParams.get("nMinibatches"):
            # Round N // nMinibatches down to the nearest power of two.
            batchSize = 1 << int(math.log2(N // fitParams["nMinibatches"]))
        else:
            batchSize = fitParams.get("batchSize", 0)

        if batchSize == 0 or batchSize >= len(trainData):
            batchSize = len(trainData)

        bestValidLoss = np.inf
        patience = fitParams["patience"]
        numEpochs = fitParams["numEpochs"]
        validationFrequency = 1

        trainLoader = DataLoader(
            trainData,
            batch_size=batchSize,
            shuffle=True,
            num_workers=5,
            pin_memory=True,
        )
        validLoader = DataLoader(validData,
                                 batch_size=1000,
                                 shuffle=True,
                                 num_workers=1,
                                 pin_memory=True)

        trainLoss, trainDict = self.computeLoss(dataInGPU=None,
                                                loader=trainLoader)
        trainPenalty = trainDict["penalty"]
        epochTime = 0

        for epoch in tqdm(range(1, numEpochs + 1),
                          leave=False,
                          desc="Epochs",
                          disable=tqdmDisable):
            validLoss, validDict = self.computeLoss(dataInGPU=None,
                                                    loader=validLoader)
            validPenalty = validDict["penalty"]

            saved = ""
            if (epoch == 1 or epoch > patience or epoch >= numEpochs or epoch %
                    validationFrequency == 0) and validLoss < bestValidLoss:
                saved = "(Saved model)"
                torch.save(self, fileName)
                # Extend patience when the improvement is significant (>0.5%).
                if validLoss < 0.995 * bestValidLoss:
                    patience = max(epoch * 2, patience)
                bestValidLoss = validLoss

            if epoch > patience:
                break

            log.info(
                f"{epoch} out of {min(patience, numEpochs)} | "
                f"Train Loss: {trainLoss:.4f} (Penalty: {trainPenalty:.4f}) | "
                f"Valid Loss: {validLoss:.4f} (Penalty: {validPenalty:.4f}) | "
                f"{saved}")

            trainLoss = 0.0
            trainPenalty = 0.0

            # Bounded queue; a separate thread prefetches batches to the device.
            queue = thqueue.Queue(10)
            dataProducer = threading.Thread(target=producer,
                                            args=(device, queue, trainLoader))
            dataProducer.start()

            startTime = time.time()
            while True:
                batchIdx, X, y = queue.get()
                if X is None:  # end-of-epoch sentinel from the producer
                    break
                optimizer.zero_grad()  # zero the gradient buffer
                batchTrainLoss, lossAndPenalty = self.computeLoss((X, y))
                batchTrainLoss.backward()
                optimizer.step()
                trainLoss += float(batchTrainLoss.detach().cpu())
                trainPenalty += float(lossAndPenalty["penalty"].detach().cpu())

            epochTime = time.time() - startTime
            log.debug(f"Time taken this epoch: {epochTime}")
            trainLoss /= batchIdx + 1
            trainPenalty /= batchIdx + 1

            if self.kill_now:
                break
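
The producer helper used above is not shown in these snippets. A minimal sketch consistent with how it is consumed (it must emit (batchIdx, X, y) tuples and end with an X-is-None sentinel that preserves the last batch index) might be:

def producer(device, queue, loader):
    # Prefetch batches onto the device and feed the training loop.
    batchIdx = -1
    for batchIdx, (X, y) in enumerate(loader):
        queue.put((batchIdx, X.to(device, non_blocking=True),
                   y.to(device, non_blocking=True)))
    # Sentinel: re-send the final batchIdx so the consumer can average by it.
    queue.put((batchIdx, None, None))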
Example #7
    def __init__(
        self,
        root,
        name,
        groups,
        mode,
        train=True,
        fold=None,
        pSamples=None,
        seed=42,
        transform=None,
        target_transform=None,
        normalize_transform=None,
        download=False,
        force=False,
    ):
        super().__init__(root,
                         transform=transform,
                         target_transform=target_transform)
        # :mode: Encodes which groups the label depends on (G_D) and which it is invariant to (G_I).
        # :fold: (i,k) to indicate the fold i of a k-fold cross validation; None for no CV.
        # :pSamples: Percentage to sample. Used to obtain the effect of training dataset size. Use None for no subsampling.
        # :root, train, transform, target_transform, download: Standard parameters
        # :normalize_transform: Do not pass Normalize() in transform.
        #        Pass True for trainData (automatically computed). Use trainData's normalizeTransform for testData.
        # :force: Force re-create dataset.

        self.root = root
        self.name = name.upper() + "Xtra"  # Name of the dataset, e.g., MNISTXtra
        self.train = train
        self.normalize_transform = normalize_transform
        self.force = force

        # Groups used to transform the data.
        # For example: [G_rotation_create, G_flip_create]
        self.groups = groups

        # Codes which of the groups are in G_I and which are in G_D
        self.mode = mode

        if download or force:
            self.download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found." +
                               " You can use download=True to download it")

        if self.train:
            self.data_file = self.train_file
            splitPrefix = "train"
        else:
            self.data_file = self.test_file
            splitPrefix = "test"

        self.data, self.targets = torch.load(
            os.path.join(self.folder, self.data_file))

        _log = utils.getLogger()

        # Cross validation folds
        if fold is not None:
            cvIt = fold[0]
            cvFolds = fold[1]

            self.split_file = f"{splitPrefix}_cvFolds={cvFolds}_seed={seed}.pt"
            if not self._check_exists(self.split_file):
                _log.info("Split file does not exist. Creating one.")
                splits = randomKFoldSplit(len(self.data), cvFolds, seed)
                torch.save(splits, os.path.join(self.folder, self.split_file))

            splits = torch.load(os.path.join(self.folder, self.split_file))

            allExceptFold = np.concatenate(splits[:cvIt] + splits[cvIt + 1:])

            self.data = self.data[allExceptFold]
            self.targets = self.targets[allExceptFold]

        classes = self.targets.unique().sort().values
        self.class_to_idx = {
            class_name.item(): i
            for i, class_name in enumerate(classes)
        }

        if pSamples is not None:
            # Subsample pSamples percent of the data (to study dataset-size effects).
            N = self.data.shape[0]
            sampleIdx = np.random.choice(N,
                                         int(N * pSamples / 100),
                                         replace=False)
            self.data = self.data[sampleIdx]
            self.targets = self.targets[sampleIdx]

        # Compute per-channel input normalization from the training data.
        # (Data is assumed to be (N, H, W, C) uint8, hence the /255.)
        if self.train and self.normalize_transform:
            mean = self.data.float().mean(dim=(0, 1, 2)) / 255
            std = self.data.float().std(dim=(0, 1, 2)) / 255
            std[std == 0] = 1  # guard against zero-variance channels
            self.normalize_transform = transforms.Normalize(mean, std)
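
A hypothetical instantiation. XtraDataset, G_rotation_create, and G_flip_create are assumed names based on the surrounding examples; the mode string follows the "<tag>_<config>_<labelType>" format parsed in download().

trainData = XtraDataset(root="./data", name="mnist",
                        groups=[G_rotation_create, G_flip_create],
                        mode="h1_10_replace", train=True,
                        normalize_transform=True, download=True)
testData = XtraDataset(root="./data", name="mnist",
                       groups=[G_rotation_create, G_flip_create],
                       mode="h1_10_replace", train=False,
                       normalize_transform=trainData.normalize_transform)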
Example #8
    def download(self):
        # Actually download MNIST and create the individual datasets.
        _log = utils.getLogger()
        if not self.force and self._check_exists():
            return

        os.makedirs(self.folder, exist_ok=True)

        if self.name.lower() == "mnistxtra":
            nChannels = 1
            subsetLabels = [3, 4]
            datasetClass = MNIST
        elif self.name.lower() == "mnistfullxtra":
            nChannels = 1
            subsetLabels = None
            datasetClass = MNIST
        else:
            raise NotImplementedError

        trainData = datasetClass(self.root, train=True, download=True)

        if not isinstance(trainData.data, torch.Tensor):
            trainData.data = torch.tensor(trainData.data)

        if not isinstance(trainData.targets, torch.Tensor):
            trainData.targets = torch.tensor(trainData.targets)

        testData = datasetClass(self.root, train=False)
        if not isinstance(testData.data, torch.Tensor):
            testData.data = torch.tensor(testData.data)

        if not isinstance(testData.targets, torch.Tensor):
            testData.targets = torch.tensor(testData.targets)

        _log.info(
            f"Processing {datasetClass.__name__} dataset to create {self.name} [{self.mode}] dataset."
        )
        np.random.seed(42)
        # Modes
        modeList = self.mode.split("_")

        color = (255, 0, 0)
        try:
            labelType = modeList[-1]
        except IndexError:
            labelType = "replace"

        # Use only a subset of digits
        if subsetLabels:
            subset(trainData, subsetLabels)
            subset(testData, subsetLabels)

        # MNIST Grayscale to RGB (color all digits red)
        if nChannels == 1:
            makeThreeChannels(trainData, color=color)
            makeThreeChannels(testData, color=color)

        # Hypothesis 1: the groups affecting the label can be chosen selectively.
        # config is a bit-string (e.g., "10"): "1" puts a group in G_D, "0" in G_I.
        config = modeList[1]

        # Groups the label depends on (given by `mode`)
        G_D = [
            self.groups[i].__name__ for i in range(len(config))
            if config[i] == "1"
        ]

        # Groups the label is invariant to (given by `mode`)
        G_I = [
            self.groups[i].__name__ for i in range(len(config))
            if config[i] == "0"
        ]
        print(f"G_D: {G_D}, G_I: {G_I}")
        for i, g_create in enumerate(self.groups):
            if config[i] == "1":
                # config[i] = 1: The group affects the label.
                # Apply the group transformations to both training and test data, and change labels accordingly.
                randomGroupTransformation(trainData,
                                          g_create,
                                          changeTarget=labelType)
                randomGroupTransformation(testData,
                                          g_create,
                                          changeTarget=labelType)
                if labelType == "replace":
                    labelType = "add"  # "add" the second label
            else:
                # config[i] = 0: The group DOES NOT affect the label.
                # Apply the group transformations only to the test data.
                # Training data need not see these transformations (economic sampling).
                randomGroupTransformation(testData, g_create)

        torch.save(
            [trainData.data, trainData.targets],
            os.path.join(self.folder, self.train_file),
        )
        torch.save([testData.data, testData.targets],
                   os.path.join(self.folder, self.test_file))

        _log.info("Done creating dataset.")