Exemplo n.º 1
0
    def test(self, views, TrainShare=None, verbose=True):
        """Map test subjects into an already-learned shared space.

        For every subject a fresh MLP (layer sizes ``[NumVoxel] +
        self.net_shape``) is trained so that its output on ``views[s]``
        matches the shared space. The per-subject network outputs are then
        passed to ``GPUHA.test`` together with that shared space.

        Args:
            views: 3D array-like laid out as ``sub x time x voxel``.
            TrainShare: shared space to align against. When given it also
                replaces ``self.Share``; otherwise ``self.Share`` (as left by
                ``train``) is used. Its rows are indexed with the same
                permutation as the test time points, so it must have at least
                ``NumTime`` rows.
            verbose: print per-epoch losses and the auto-assigned feature
                count.

        Returns:
            ``self.TestFeatures``, i.e. ``GPUHA.Xtest`` for the mapped views.
            ``self.TestRuntime`` is set as a side effect.

        Raises:
            AssertionError: if ``views`` is not 3D.
            ValueError: if no shared space is available (neither
                ``TrainShare`` nor ``self.Share``).
            Exception: for an unknown ``self.loss_type`` or ``self.optim``.
        """
        tic = time.time()
        assert len(np.shape(
            views)) == 3, "Data shape must be 3D, i.e. sub x time x voxel"

        self.TestFeatures = None

        NumSub, NumTime, NumVoxel = np.shape(views)
        NumFea = self.net_shape[-1]
        if NumFea is None:
            NumFea = np.min((NumTime, NumVoxel))
            if verbose:
                print(
                    "Number of features is automatically assigned, Features: ",
                    NumFea)
            # Must happen regardless of ``verbose``: the MLPs below are built
            # from ``self.net_shape`` and cannot take ``None`` as a layer size.
            self.net_shape[-1] = NumFea

        if TrainShare is not None:
            self.Share = TrainShare
        if self.Share is None:
            raise ValueError(
                "No shared space available: pass TrainShare or call train() first"
            )
        Share = self.Share

        if self.loss_type == 'mse':
            criterion = torch.nn.MSELoss()
        elif self.loss_type == 'soft':
            criterion = torch.nn.MultiLabelSoftMarginLoss()
        elif self.loss_type == 'mean':
            criterion = torch.mean
        elif self.loss_type == 'norm':
            criterion = torch.norm
        else:
            raise Exception(
                "Loss function type is wrong! Options: \'mse\', \'soft\', \'mean\', or \'norm\'"
            )

        self.ha_loss_test_vec = list()
        self.ha_loss_test = None

        NewViews = list()
        G = torch.Tensor(Share)
        # Layer sizes are identical for every subject, so build them once.
        net_shape = np.concatenate(([NumVoxel], self.net_shape))

        for s in range(NumSub):
            net = MLP(model=net_shape,
                      activation=self.activation,
                      gpu_enable=self.gpu_enable)

            if self.optim == "adam":
                optimizer = optim.Adam(net.parameters(), lr=self.learning_rate)
            elif self.optim == "sgd":
                optimizer = optim.SGD(net.parameters(), lr=self.learning_rate)
            else:
                raise Exception(
                    "Optimization algorithm is wrong! Options: \'adam\' or \'sgd\'"
                )

            X = torch.Tensor(views[s])
            net.train()

            for j in range(self.iteration):
                for epoch in range(self.epoch):
                    perm = torch.randperm(NumTime)
                    sum_loss = 0

                    for i in range(0, NumTime, self.batch_size):
                        batch = perm[i:i + self.batch_size]
                        x = X[batch]
                        g = G[batch]

                        # Send data to GPU
                        if self.gpu_enable:
                            x = x.cuda()
                            g = g.cuda()

                        optimizer.zero_grad()
                        fx = net(x)

                        # (prediction, target) order, matching ``train``:
                        # MultiLabelSoftMarginLoss is not symmetric and the
                        # 'mean' loss depends on the sign of the difference.
                        if self.loss_type in ('mse', 'soft'):
                            loss = criterion(fx, g)
                        else:
                            loss = criterion(fx - g)

                        loss.backward()
                        optimizer.step()
                        sum_loss += loss.item()

                        # NOTE(review): for any epoch_internal_iteration > 0
                        # this breaks right after the first mini-batch (i == 0),
                        # so only one batch per epoch is used. Confirm this is
                        # intended rather than an inverted comparison.
                        if self.epoch_internal_iteration > (i / NumTime):
                            break

                    if verbose:
                        print(
                            "TEST, UPDATE NETWORK: Iteration {:6d}, Subject {:6d}, Epoch {:6d}, loss error: {}"
                            .format(j + 1, s + 1, epoch + 1, sum_loss))

            if self.gpu_enable:
                X = X.cuda()

            # Inference only: no autograd graph is needed for the mapped view.
            # NOTE(review): the net is still in train mode here; if MLP uses
            # dropout/batch-norm, net.eval() may be intended.
            with torch.no_grad():
                NewViews.append(net(X).cpu().numpy())

        ha_model = GPUHA(Dim=NumFea, regularization=self.regularization)
        ha_model.test(views=NewViews, G=Share, verbose=verbose)
        self.TestFeatures = ha_model.Xtest
        self.TestRuntime = time.time() - tic
        return self.TestFeatures
Exemplo n.º 2
0
if __name__ == "__main__":
    # Demo: fit RDHA and a plain GPUHA baseline on the same random data and
    # print summary statistics of the RDHA shared space.
    # Arrays are laid out as subjects x time points x voxels:
    # 10 training subjects and 2 test subjects, each 10 x 10.
    data = np.random.rand(10, 10, 10)
    dat = np.random.rand(2, 10, 10)

    # NOTE(review): meaning of the two positional [None] arguments depends on
    # RDHA.__init__ (not visible here) — confirm before changing them.
    # batch_size=10 equals the number of time points, so each epoch is a
    # single mini-batch.
    model = RDHA([None], [None],
                 epoch_internal_iteration=0.5,
                 batch_size=10,
                 epoch=1,
                 norm2_enable=False)
    # Learn the shared space on the training subjects, then map test subjects.
    model.train(data)
    model.test(dat)
    X = model.TrainFeatures
    Y = model.TestFeatures
    G = model.Share

    # Baseline: plain GPUHA on the same data; `toc` times only its training.
    model1 = GPUHA()
    tic = time.time()
    model1.train(data)
    toc = time.time() - tic
    model1.test(dat)
    X2 = model1.Xtrain
    Y2 = model1.Xtest
    G2 = model1.G

    # trace(G) and trace(G^T G) of the RDHA shared space.
    print("\nRDHA, trace(G) = ", np.trace(G), " G^TG= ",
          np.trace(np.dot(np.transpose(G), G)))
    print("RDHA, Shared Space Shape: ", np.shape(G))
    print("RDHA, Features Shape: ", np.shape(X))
    # print("RDHA, Loss vec: ", model.ha_loss_vec)
    print("RDHA, Error: ", model.ha_loss, ", Runtime: ", model.TrainRuntime)
    # NOTE(review): `error`, `toc`, Y, X2, Y2 and G2 are unused in the visible
    # part of this script; presumably consumed by code that follows.
    error = 0
Exemplo n.º 3
0
    def train(self, views, verbose=True):
        """Learn a shared space by alternating per-subject MLPs and hyperalignment.

        The shared space starts as a random ``(NumTime, NumFea)`` matrix. Each
        of ``self.iteration`` rounds does two steps:

        1. Train a fresh MLP per subject to map ``views[s]`` onto the current
           shared space.
        2. Run ``GPUHA.train`` on the mapped views to obtain the next shared
           space and the features.

        Training stops early if the mean hyperalignment error is exactly 0.

        Args:
            views: 3D array-like laid out as ``sub x time x voxel``.
            verbose: print per-epoch losses, the per-round hyperalignment
                error and the auto-assigned feature count.

        Returns:
            Tuple ``(self.TrainFeatures, self.Share)``. With
            ``best_result_enable`` these come from the round with the lowest
            mean ``GPUHA.Etrain`` (later rounds win ties). Otherwise they come
            from the last completed round. ``self.ha_loss``,
            ``self.ha_loss_vec`` and ``self.TrainRuntime`` are set as side
            effects.

        Raises:
            AssertionError: if ``views`` is not 3D, or if the hyperalignment
                error is exactly zero before any result has been stored.
            Exception: for an unknown ``self.loss_type`` or ``self.optim``.
        """
        tic = time.time()
        assert len(np.shape(
            views)) == 3, "Data shape must be 3D, i.e. sub x time x voxel"

        self.Share = None
        self.TrainFeatures = None

        NumSub, NumTime, NumVoxel = np.shape(views)
        NumFea = self.net_shape[-1]
        if NumFea is None:
            NumFea = np.min((NumTime, NumVoxel))
            if verbose:
                print(
                    "Number of features is automatically assigned, Features: ",
                    NumFea)
            # Must happen regardless of ``verbose``: the MLPs below are built
            # from ``self.net_shape`` and cannot take ``None`` as a layer size.
            self.net_shape[-1] = NumFea

        # Random initial shared space; replaced by every hyperalignment round.
        Share = np.random.randn(NumTime, NumFea)

        if self.loss_type == 'mse':
            criterion = torch.nn.MSELoss()
        elif self.loss_type == 'soft':
            criterion = torch.nn.MultiLabelSoftMarginLoss()
        elif self.loss_type == 'mean':
            criterion = torch.mean
        elif self.loss_type == 'norm':
            criterion = torch.norm
        else:
            raise Exception(
                "Loss function type is wrong! Options: \'mse\', \'soft\', \'mean\', or \'norm\'"
            )

        self.ha_loss_vec = list()

        self.ha_loss = None

        # Layer sizes are identical for every subject and round.
        net_shape = np.concatenate(([NumVoxel], self.net_shape))

        for j in range(self.iteration):

            NewViews = list()
            G = torch.Tensor(Share)

            for s in range(NumSub):
                # A fresh network per subject in every round.
                net = MLP(model=net_shape,
                          activation=self.activation,
                          gpu_enable=self.gpu_enable)

                if self.optim == "adam":
                    optimizer = optim.Adam(net.parameters(),
                                           lr=self.learning_rate)
                elif self.optim == "sgd":
                    optimizer = optim.SGD(net.parameters(),
                                          lr=self.learning_rate)
                else:
                    raise Exception(
                        "Optimization algorithm is wrong! Options: \'adam\' or \'sgd\'"
                    )

                X = torch.Tensor(views[s])
                net.train()

                for epoch in range(self.epoch):
                    perm = torch.randperm(NumTime)
                    sum_loss = 0

                    for i in range(0, NumTime, self.batch_size):
                        batch = perm[i:i + self.batch_size]
                        x = X[batch]
                        g = G[batch]

                        # Send data to GPU
                        if self.gpu_enable:
                            x = x.cuda()
                            g = g.cuda()

                        optimizer.zero_grad()
                        fx = net(x)

                        if self.loss_type in ('mse', 'soft'):
                            loss = criterion(fx, g) / NumTime
                        else:
                            loss = criterion(fx - g) / NumTime

                        # Optional L1 / L2 weight penalties scaled by alpha.
                        # NOTE(review): assumes get_weights() yields items whose
                        # element [1] is a weight tensor — confirm in MLP.
                        if self.norm1_enable or self.norm2_enable:
                            for weight in net.get_weights():
                                if self.norm1_enable:
                                    loss += self.alpha * torch.mean(
                                        torch.abs(weight[1]))

                                if self.norm2_enable:
                                    loss += self.alpha * torch.mean(
                                        weight[1]**2)

                        loss.backward()
                        optimizer.step()
                        sum_loss += loss.item()

                        # NOTE(review): for any epoch_internal_iteration > 0
                        # this breaks right after the first mini-batch (i == 0),
                        # so only one batch per epoch is used. Confirm this is
                        # intended rather than an inverted comparison.
                        if self.epoch_internal_iteration > (i / NumTime):
                            break

                    if verbose:
                        print(
                            "TRAIN, UPDATE NETWORK: Iteration {:5d}, Subject {:6d}, Epoch {:6d}, loss error: {}"
                            .format(j + 1, s + 1, epoch + 1, sum_loss))

                if self.gpu_enable:
                    X = X.cuda()

                # Inference only: no autograd graph is needed for the mapped view.
                with torch.no_grad():
                    NewViews.append(net(X).cpu().numpy())

            ha_model = GPUHA(Dim=NumFea, regularization=self.regularization)
            # Hyperalignment only runs on the GPU when NumFea >= NumTime.
            ha_model.train(views=NewViews,
                           verbose=verbose,
                           gpu=self.gpu_enable if NumFea >= NumTime else False)

            Share = ha_model.G
            out_features = ha_model.Xtrain
            error = np.mean(ha_model.Etrain)

            if error == 0:
                assert self.Share is not None, "All extracted features are zero, i.e. number of features is not enough for creating a shared space"
                self.TrainRuntime = time.time() - tic
                return self.TrainFeatures, self.Share

            # Store this round's result unless best_result_enable is set and
            # an earlier round achieved a strictly lower error.
            if (not self.best_result_enable or self.ha_loss is None
                    or error <= self.ha_loss):
                self.Share = Share
                self.TrainFeatures = out_features
                self.ha_loss = error

            if verbose:
                print("Hyperalignment error: {}".format(error))

            self.ha_loss_vec.append(error)

        self.TrainRuntime = time.time() - tic
        return self.TrainFeatures, self.Share
Exemplo n.º 4
0
    def btnConvert_click(self):
        totalTime = 0
        msgBox = QMessageBox()

        TrFoldErr = list()
        TeFoldErr = list()

        try:
            FoldFrom = np.int32(ui.txtFoldFrom.text())
            FoldTo = np.int32(ui.txtFoldTo.text())
        except:
            print("Please check fold parameters!")
            return

        if FoldTo < FoldFrom:
            print("Please check fold parameters!")
            return

        for fold_all in range(FoldFrom, FoldTo + 1):
            tic = time.time()
            # Regularization
            try:
                Regularization = np.float(ui.txtRegularization.text())
            except:
                msgBox.setText("Regularization value is wrong!")
                msgBox.setIcon(QMessageBox.Critical)
                msgBox.setStandardButtons(QMessageBox.Ok)
                msgBox.exec_()
                return False

            # OutFile
            OutFile = ui.txtOutFile.text()
            OutFile = OutFile.replace("$FOLD$", str(fold_all))
            if not len(OutFile):
                msgBox.setText("Please enter out file!")
                msgBox.setIcon(QMessageBox.Critical)
                msgBox.setStandardButtons(QMessageBox.Ok)
                msgBox.exec_()
                return False

            # InFile
            InFile = ui.txtInFile.text()
            InFile = InFile.replace("$FOLD$", str(fold_all))
            if not len(InFile):
                msgBox.setText("Please enter input file!")
                msgBox.setIcon(QMessageBox.Critical)
                msgBox.setStandardButtons(QMessageBox.Ok)
                msgBox.exec_()
                return False
            if not os.path.isfile(InFile):
                msgBox.setText("Input file not found!")
                msgBox.setIcon(QMessageBox.Critical)
                msgBox.setStandardButtons(QMessageBox.Ok)
                msgBox.exec_()
                return False

            InData = io.loadmat(InFile)
            OutData = dict()
            OutData["imgShape"] = InData["imgShape"]

            # Data
            if not len(ui.txtITrData.currentText()):
                msgBox.setText("Please enter Input Train Data variable name!")
                msgBox.setIcon(QMessageBox.Critical)
                msgBox.setStandardButtons(QMessageBox.Ok)
                msgBox.exec_()
                return False
            if not len(ui.txtITeData.currentText()):
                msgBox.setText("Please enter Input Test Data variable name!")
                msgBox.setIcon(QMessageBox.Critical)
                msgBox.setStandardButtons(QMessageBox.Ok)
                msgBox.exec_()
                return False
            if not len(ui.txtOTrData.text()):
                msgBox.setText("Please enter Output Train Data variable name!")
                msgBox.setIcon(QMessageBox.Critical)
                msgBox.setStandardButtons(QMessageBox.Ok)
                msgBox.exec_()
                return False
            if not len(ui.txtOTeData.text()):
                msgBox.setText("Please enter Output Test Data variable name!")
                msgBox.setIcon(QMessageBox.Critical)
                msgBox.setStandardButtons(QMessageBox.Ok)
                msgBox.exec_()
                return False

            try:
                XTr = InData[ui.txtITrData.currentText()]
                XTe = InData[ui.txtITeData.currentText()]

                if ui.cbScale.isChecked() and not ui.rbScale.isChecked():
                    XTr = preprocessing.scale(XTr)
                    XTe = preprocessing.scale(XTe)
                    print("Whole of data is scaled X~N(0,1).")
            except:
                print("Cannot load data")
                return

            # NComponent
            try:
                NumFea = np.int32(ui.txtNumFea.text())
            except:
                msgBox.setText("Number of features is wrong!")
                msgBox.setIcon(QMessageBox.Critical)
                msgBox.setStandardButtons(QMessageBox.Ok)
                msgBox.exec_()
                return False
            if NumFea < 0:
                msgBox.setText("Number of features must be greater than zero!")
                msgBox.setIcon(QMessageBox.Critical)
                msgBox.setStandardButtons(QMessageBox.Ok)
                msgBox.exec_()
                return False
            if NumFea > np.shape(XTr)[1]:
                msgBox.setText("Number of features is wrong!")
                msgBox.setIcon(QMessageBox.Critical)
                msgBox.setStandardButtons(QMessageBox.Ok)
                msgBox.exec_()
                return False

            if NumFea == 0:
                NumFea = None

            # Label
            if not len(ui.txtITrLabel.currentText()):
                msgBox.setText("Please enter Train Input Label variable name!")
                msgBox.setIcon(QMessageBox.Critical)
                msgBox.setStandardButtons(QMessageBox.Ok)
                msgBox.exec_()
                return False
            if not len(ui.txtITeLabel.currentText()):
                msgBox.setText("Please enter Test Input Label variable name!")
                msgBox.setIcon(QMessageBox.Critical)
                msgBox.setStandardButtons(QMessageBox.Ok)
                msgBox.exec_()
                return False
            if not len(ui.txtOTrLabel.text()):
                msgBox.setText(
                    "Please enter Train Output Label variable name!")
                msgBox.setIcon(QMessageBox.Critical)
                msgBox.setStandardButtons(QMessageBox.Ok)
                msgBox.exec_()
                return False
            if not len(ui.txtOTeLabel.text()):
                msgBox.setText("Please enter Test Output Label variable name!")
                msgBox.setIcon(QMessageBox.Critical)
                msgBox.setStandardButtons(QMessageBox.Ok)
                msgBox.exec_()
                return False
            try:
                OutData[ui.txtOTrLabel.text()] = InData[
                    ui.txtITrLabel.currentText()]
                OutData[ui.txtOTeLabel.text()] = InData[
                    ui.txtITeLabel.currentText()]
            except:
                print("Cannot load labels!")

            # Subject
            if not len(ui.txtITrSubject.currentText()):
                msgBox.setText(
                    "Please enter Train Input Subject variable name!")
                msgBox.setIcon(QMessageBox.Critical)
                msgBox.setStandardButtons(QMessageBox.Ok)
                msgBox.exec_()
                return False
            if not len(ui.txtITeSubject.currentText()):
                msgBox.setText(
                    "Please enter Test Input Subject variable name!")
                msgBox.setIcon(QMessageBox.Critical)
                msgBox.setStandardButtons(QMessageBox.Ok)
                msgBox.exec_()
                return False
            if not len(ui.txtOTrSubject.text()):
                msgBox.setText(
                    "Please enter Train Output Subject variable name!")
                msgBox.setIcon(QMessageBox.Critical)
                msgBox.setStandardButtons(QMessageBox.Ok)
                msgBox.exec_()
                return False
            if not len(ui.txtOTeSubject.text()):
                msgBox.setText(
                    "Please enter Test Output Subject variable name!")
                msgBox.setIcon(QMessageBox.Critical)
                msgBox.setStandardButtons(QMessageBox.Ok)
                msgBox.exec_()
                return False
            try:
                TrSubject = InData[ui.txtITrSubject.currentText()]
                OutData[ui.txtOTrSubject.text()] = TrSubject
                TeSubject = InData[ui.txtITeSubject.currentText()]
                OutData[ui.txtOTeSubject.text()] = TeSubject
            except:
                print("Cannot load Subject IDs")
                return

            # Task
            if ui.cbTask.isChecked():
                if not len(ui.txtITrTask.currentText()):
                    msgBox.setText(
                        "Please enter Input Train Task variable name!")
                    msgBox.setIcon(QMessageBox.Critical)
                    msgBox.setStandardButtons(QMessageBox.Ok)
                    msgBox.exec_()
                    return False
                if not len(ui.txtITeTask.currentText()):
                    msgBox.setText(
                        "Please enter Input Test Task variable name!")
                    msgBox.setIcon(QMessageBox.Critical)
                    msgBox.setStandardButtons(QMessageBox.Ok)
                    msgBox.exec_()
                    return False
                if not len(ui.txtOTrTask.text()):
                    msgBox.setText(
                        "Please enter Output Train Task variable name!")
                    msgBox.setIcon(QMessageBox.Critical)
                    msgBox.setStandardButtons(QMessageBox.Ok)
                    msgBox.exec_()
                    return False
                if not len(ui.txtOTeTask.text()):
                    msgBox.setText(
                        "Please enter Output Test Task variable name!")
                    msgBox.setIcon(QMessageBox.Critical)
                    msgBox.setStandardButtons(QMessageBox.Ok)
                    msgBox.exec_()
                    return False
                try:
                    TrTask = InData[ui.txtITrTask.currentText()]
                    OutData[ui.txtOTrTask.text()] = TrTask
                    TeTask = InData[ui.txtITeTask.currentText()]
                    OutData[ui.txtOTeTask.text()] = TeTask
                    TrTaskIndex = TrTask.copy()
                    for tasindx, tas in enumerate(np.unique(TrTask)):
                        TrTaskIndex[TrTask == tas] = tasindx + 1
                    TeTaskIndex = TeTask.copy()
                    for tasindx, tas in enumerate(np.unique(TeTask)):
                        TeTaskIndex[TeTask == tas] = tasindx + 1
                except:
                    print("Cannot load Tasks!")
                    return

            # Run
            if ui.cbRun.isChecked():
                if not len(ui.txtITrRun.currentText()):
                    msgBox.setText(
                        "Please enter Train Input Run variable name!")
                    msgBox.setIcon(QMessageBox.Critical)
                    msgBox.setStandardButtons(QMessageBox.Ok)
                    msgBox.exec_()
                    return False
                if not len(ui.txtITeRun.currentText()):
                    msgBox.setText(
                        "Please enter Test Input Run variable name!")
                    msgBox.setIcon(QMessageBox.Critical)
                    msgBox.setStandardButtons(QMessageBox.Ok)
                    msgBox.exec_()
                    return False
                if not len(ui.txtOTrRun.text()):
                    msgBox.setText(
                        "Please enter Train Output Run variable name!")
                    msgBox.setIcon(QMessageBox.Critical)
                    msgBox.setStandardButtons(QMessageBox.Ok)
                    msgBox.exec_()
                    return False
                if not len(ui.txtOTeRun.text()):
                    msgBox.setText(
                        "Please enter Test Output Run variable name!")
                    msgBox.setIcon(QMessageBox.Critical)
                    msgBox.setStandardButtons(QMessageBox.Ok)
                    msgBox.exec_()
                    return False
                try:
                    TrRun = InData[ui.txtITrRun.currentText()]
                    OutData[ui.txtOTrRun.text()] = TrRun
                    TeRun = InData[ui.txtITeRun.currentText()]
                    OutData[ui.txtOTeRun.text()] = TeRun
                except:
                    print("Cannot load Runs!")
                    return

            # Counter
            if ui.cbCounter.isChecked():
                if not len(ui.txtITrCounter.currentText()):
                    msgBox.setText(
                        "Please enter Train Input Counter variable name!")
                    msgBox.setIcon(QMessageBox.Critical)
                    msgBox.setStandardButtons(QMessageBox.Ok)
                    msgBox.exec_()
                    return False
                if not len(ui.txtITeCounter.currentText()):
                    msgBox.setText(
                        "Please enter Test Input Counter variable name!")
                    msgBox.setIcon(QMessageBox.Critical)
                    msgBox.setStandardButtons(QMessageBox.Ok)
                    msgBox.exec_()
                    return False
                if not len(ui.txtOTrCounter.text()):
                    msgBox.setText(
                        "Please enter Train Output Counter variable name!")
                    msgBox.setIcon(QMessageBox.Critical)
                    msgBox.setStandardButtons(QMessageBox.Ok)
                    msgBox.exec_()
                    return False
                if not len(ui.txtOTeCounter.text()):
                    msgBox.setText(
                        "Please enter Test Output Counter variable name!")
                    msgBox.setIcon(QMessageBox.Critical)
                    msgBox.setStandardButtons(QMessageBox.Ok)
                    msgBox.exec_()
                    return False
                try:
                    TrCounter = InData[ui.txtITrCounter.currentText()]
                    OutData[ui.txtOTrCounter.text()] = TrCounter
                    TeCounter = InData[ui.txtITeCounter.currentText()]
                    OutData[ui.txtOTeCounter.text()] = TeCounter
                except:
                    print("Cannot load Counters!")
                    return

            # Matrix Label
            if ui.cbmLabel.isChecked():
                if not len(ui.txtITrmLabel.currentText()):
                    msgBox.setText(
                        "Please enter Train Input Matrix Label variable name!")
                    msgBox.setIcon(QMessageBox.Critical)
                    msgBox.setStandardButtons(QMessageBox.Ok)
                    msgBox.exec_()
                    return False
                if not len(ui.txtITemLabel.currentText()):
                    msgBox.setText(
                        "Please enter Test Input Matrix Label variable name!")
                    msgBox.setIcon(QMessageBox.Critical)
                    msgBox.setStandardButtons(QMessageBox.Ok)
                    msgBox.exec_()
                    return False
                if not len(ui.txtOTrmLabel.text()):
                    msgBox.setText(
                        "Please enter Train Output Matrix Label variable name!"
                    )
                    msgBox.setIcon(QMessageBox.Critical)
                    msgBox.setStandardButtons(QMessageBox.Ok)
                    msgBox.exec_()
                    return False
                if not len(ui.txtOTemLabel.text()):
                    msgBox.setText(
                        "Please enter Test Output Matrix Label variable name!")
                    msgBox.setIcon(QMessageBox.Critical)
                    msgBox.setStandardButtons(QMessageBox.Ok)
                    msgBox.exec_()
                    return False
                try:
                    OutData[ui.txtOTrmLabel.text()] = InData[
                        ui.txtITrmLabel.currentText()]
                    OutData[ui.txtOTemLabel.text()] = InData[
                        ui.txtITemLabel.currentText()]
                except:
                    print("Cannot load matrix lables!")
                    return

            # Design
            if ui.cbDM.isChecked():
                if not len(ui.txtITrDM.currentText()):
                    msgBox.setText(
                        "Please enter Train Input Design Matrix variable name!"
                    )
                    msgBox.setIcon(QMessageBox.Critical)
                    msgBox.setStandardButtons(QMessageBox.Ok)
                    msgBox.exec_()
                    return False
                if not len(ui.txtITeDM.currentText()):
                    msgBox.setText(
                        "Please enter Test Input Design Matrix variable name!")
                    msgBox.setIcon(QMessageBox.Critical)
                    msgBox.setStandardButtons(QMessageBox.Ok)
                    msgBox.exec_()
                    return False
                if not len(ui.txtOTrDM.text()):
                    msgBox.setText(
                        "Please enter Train Output Design Matrix variable name!"
                    )
                    msgBox.setIcon(QMessageBox.Critical)
                    msgBox.setStandardButtons(QMessageBox.Ok)
                    msgBox.exec_()
                    return False
                if not len(ui.txtOTeDM.text()):
                    msgBox.setText(
                        "Please enter Test Output Design Matrix variable name!"
                    )
                    msgBox.setIcon(QMessageBox.Critical)
                    msgBox.setStandardButtons(QMessageBox.Ok)
                    msgBox.exec_()
                    return False
                try:
                    OutData[ui.txtOTrDM.text()] = InData[
                        ui.txtITrDM.currentText()]
                    OutData[ui.txtOTeDM.text()] = InData[
                        ui.txtITeDM.currentText()]
                except:
                    print("Cannot load design matrices!")
                    return

            # Coordinate
            if ui.cbCol.isChecked():
                if not len(ui.txtCol.currentText()):
                    msgBox.setText("Please enter Coordinator variable name!")
                    msgBox.setIcon(QMessageBox.Critical)
                    msgBox.setStandardButtons(QMessageBox.Ok)
                    msgBox.exec_()
                    return False
                if not len(ui.txtOCol.text()):
                    msgBox.setText("Please enter Coordinator variable name!")
                    msgBox.setIcon(QMessageBox.Critical)
                    msgBox.setStandardButtons(QMessageBox.Ok)
                    msgBox.exec_()
                    return False
                try:
                    OutData[ui.txtOCol.text()] = InData[
                        ui.txtCol.currentText()]
                except:
                    print("Cannot load coordinator!")
                    return

            # Condition
            if ui.cbCond.isChecked():
                if not len(ui.txtCond.currentText()):
                    msgBox.setText("Please enter Condition variable name!")
                    msgBox.setIcon(QMessageBox.Critical)
                    msgBox.setStandardButtons(QMessageBox.Ok)
                    msgBox.exec_()
                    return False
                if not len(ui.txtOCond.text()):
                    msgBox.setText("Please enter Condition variable name!")
                    msgBox.setIcon(QMessageBox.Critical)
                    msgBox.setStandardButtons(QMessageBox.Ok)
                    msgBox.exec_()
                    return False
                try:
                    OutData[ui.txtOCond.text()] = InData[
                        ui.txtCond.currentText()]
                except:
                    print("Cannot load conditions!")
                    return

            # FoldID
            if ui.cbFoldID.isChecked():
                if not len(ui.txtFoldID.currentText()):
                    msgBox.setText("Please enter FoldID variable name!")
                    msgBox.setIcon(QMessageBox.Critical)
                    msgBox.setStandardButtons(QMessageBox.Ok)
                    msgBox.exec_()
                    return False
                if not len(ui.txtOFoldID.text()):
                    msgBox.setText("Please enter FoldID variable name!")
                    msgBox.setIcon(QMessageBox.Critical)
                    msgBox.setStandardButtons(QMessageBox.Ok)
                    msgBox.exec_()
                    return False
                try:
                    OutData[ui.txtOFoldID.text()] = InData[
                        ui.txtFoldID.currentText()]
                except:
                    print("Cannot load Fold ID!")
                    return

            # FoldInfo
            if ui.cbFoldInfo.isChecked():
                if not len(ui.txtFoldInfo.currentText()):
                    msgBox.setText("Please enter FoldInfo variable name!")
                    msgBox.setIcon(QMessageBox.Critical)
                    msgBox.setStandardButtons(QMessageBox.Ok)
                    msgBox.exec_()
                    return False
                if not len(ui.txtOFoldInfo.text()):
                    msgBox.setText("Please enter FoldInfo variable name!")
                    msgBox.setIcon(QMessageBox.Critical)
                    msgBox.setStandardButtons(QMessageBox.Ok)
                    msgBox.exec_()
                    return False
                try:
                    OutData[ui.txtOFoldInfo.text()] = InData[
                        ui.txtFoldInfo.currentText()]
                except:
                    print("Cannot load Fold Info!")
                    return
                pass

            # Number of Scan
            if ui.cbNScan.isChecked():
                if not len(ui.txtITrScan.currentText()):
                    msgBox.setText(
                        "Please enter Number of Scan variable name for Input Train!"
                    )
                    msgBox.setIcon(QMessageBox.Critical)
                    msgBox.setStandardButtons(QMessageBox.Ok)
                    msgBox.exec_()
                    return False
                if not len(ui.txtITeScan.currentText()):
                    msgBox.setText(
                        "Please enter Number of Scan variable name for Input Test!"
                    )
                    msgBox.setIcon(QMessageBox.Critical)
                    msgBox.setStandardButtons(QMessageBox.Ok)
                    msgBox.exec_()
                    return False
                if not len(ui.txtOTrScan.text()):
                    msgBox.setText(
                        "Please enter Number of Scan variable name for Output Train!"
                    )
                    msgBox.setIcon(QMessageBox.Critical)
                    msgBox.setStandardButtons(QMessageBox.Ok)
                    msgBox.exec_()
                    return False
                if not len(ui.txtOTeScan.text()):
                    msgBox.setText(
                        "Please enter Number of Scan variable name for Output Test!"
                    )
                    msgBox.setIcon(QMessageBox.Critical)
                    msgBox.setStandardButtons(QMessageBox.Ok)
                    msgBox.exec_()
                    return False
                try:
                    OutData[ui.txtOTrScan.text()] = InData[
                        ui.txtITrScan.currentText()]
                    OutData[ui.txtOTeScan.text()] = InData[
                        ui.txtITeScan.currentText()]
                except:
                    print("Cannot load NScan!")
                    return

            # Train Analysis Level
            print("Calculating Analysis Level for Training Set ...")
            TrGroupFold = None
            FoldStr = ""
            if ui.cbFSubject.isChecked():
                if not ui.rbFRun.isChecked():
                    TrGroupFold = TrSubject
                    FoldStr = "Subject"
                else:
                    TrGroupFold = np.concatenate((TrSubject, TrRun))
                    FoldStr = "Subject+Run"

            if ui.cbFTask.isChecked():
                TrGroupFold = np.concatenate(
                    (TrGroupFold,
                     TrTaskIndex)) if TrGroupFold is not None else TrTaskIndex
                FoldStr = FoldStr + "+Task"

            if ui.cbFCounter.isChecked():
                TrGroupFold = np.concatenate(
                    (TrGroupFold,
                     TrCounter)) if TrGroupFold is not None else TrCounter
                FoldStr = FoldStr + "+Counter"

            TrGroupFold = np.transpose(TrGroupFold)

            TrUniqFold = np.array(
                list(set(tuple(i) for i in TrGroupFold.tolist())))

            TrFoldIDs = np.arange(len(TrUniqFold)) + 1

            TrListFold = list()
            for gfold in TrGroupFold:
                for ufoldindx, ufold in enumerate(TrUniqFold):
                    if (ufold == gfold).all():
                        currentID = TrFoldIDs[ufoldindx]
                        break
                TrListFold.append(currentID)
            TrListFold = np.int32(TrListFold)
            TrListFoldUniq = np.unique(TrListFold)

            # Test Analysis Level
            print("Calculating Analysis Level for Testing Set ...")
            TeGroupFold = None
            if ui.cbFSubject.isChecked():
                if not ui.rbFRun.isChecked():
                    TeGroupFold = TeSubject
                else:
                    TeGroupFold = np.concatenate((TeSubject, TeRun))

            if ui.cbFTask.isChecked():
                TeGroupFold = np.concatenate(
                    (TeGroupFold,
                     TeTaskIndex)) if TeGroupFold is not None else TeTaskIndex

            if ui.cbFCounter.isChecked():
                TeGroupFold = np.concatenate(
                    (TeGroupFold,
                     TeCounter)) if TeGroupFold is not None else TeCounter

            TeGroupFold = np.transpose(TeGroupFold)

            TeUniqFold = np.array(
                list(set(tuple(i) for i in TeGroupFold.tolist())))

            TeFoldIDs = np.arange(len(TeUniqFold)) + 1

            TeListFold = list()
            for gfold in TeGroupFold:
                for ufoldindx, ufold in enumerate(TeUniqFold):
                    if (ufold == gfold).all():
                        currentID = TeFoldIDs[ufoldindx]
                        break
                TeListFold.append(currentID)
            TeListFold = np.int32(TeListFold)
            TeListFoldUniq = np.unique(TeListFold)

            # Train Partition
            print("Partitioning Training Data ...")
            TrX = list()
            TrShape = None
            for foldindx, fold in enumerate(TrListFoldUniq):
                dat = XTr[np.where(TrListFold == fold)]
                if ui.cbScale.isChecked() and ui.rbScale.isChecked():
                    dat = preprocessing.scale(dat)
                    print("Data belong to View " + str(foldindx + 1) +
                          " is scaled X~N(0,1).")

                TrX.append(dat)
                if TrShape is None:
                    TrShape = np.shape(dat)
                else:
                    if not (TrShape == np.shape(dat)):
                        print("ERROR: Train, Reshape problem for Fold " +
                              str(foldindx + 1) + ", Shape: " +
                              str(np.shape(dat)))
                        return
                print("Train: View " + str(foldindx + 1) +
                      " is extracted. Shape: " + str(np.shape(dat)))

            print("Training Shape: " + str(np.shape(TrX)))

            # Test Partition
            print("Partitioning Testing Data ...")
            TeX = list()
            TeShape = None
            for foldindx, fold in enumerate(TeListFoldUniq):
                dat = XTe[np.where(TeListFold == fold)]
                if ui.cbScale.isChecked() and ui.rbScale.isChecked():
                    dat = preprocessing.scale(dat)
                    print("Data belong to View " + str(foldindx + 1) +
                          " is scaled X~N(0,1).")
                TeX.append(dat)
                if TeShape is None:
                    TeShape = np.shape(dat)
                else:
                    if not (TeShape == np.shape(dat)):
                        print("Test: Reshape problem for Fold " +
                              str(foldindx + 1))
                        return
                print("Test: View " + str(foldindx + 1) + " is extracted.")

            print("Testing Shape: " + str(np.shape(TeX)))

            model = GPUHA(Dim=NumFea, regularization=Regularization)

            print("Running Hyperalignment on Training Data ...")
            MappedXtr, G, _, _, _ = model.train(TrX)

            print("Running Hyperalignment on Testing Data ...")
            MappedXte, _, _, _ = model.test(TeX)

            # Train Dot Product
            print("Producting Training Data ...")
            TrHX = None
            TrErr = None
            for foldindx, fold in enumerate(TrListFoldUniq):
                TrErr = TrErr + (
                    G - MappedXtr[foldindx]
                ) if TrErr is not None else G - MappedXtr[foldindx]
                TrHX = np.concatenate(
                    (TrHX, MappedXtr[foldindx]
                     )) if TrHX is not None else MappedXtr[foldindx]
            OutData[ui.txtOTrData.text()] = TrHX
            foldindx = foldindx + 1
            TrErr = TrErr / foldindx
            print("Train: alignment error ", np.linalg.norm(TrErr))
            TrFoldErr.append(np.linalg.norm(TrErr))

            # Test Dot Product
            print("Producting Testing Data ...")
            TeHX = None
            TeErr = None
            for foldindx, fold in enumerate(TeListFoldUniq):
                TeErr = TeErr + (
                    G - MappedXte[foldindx]
                ) if TeErr is not None else G - MappedXte[foldindx]
                TeHX = np.concatenate(
                    (TeHX, MappedXte[foldindx]
                     )) if TeHX is not None else MappedXte[foldindx]
            OutData[ui.txtOTeData.text()] = TeHX
            foldindx = foldindx + 1
            TeErr = TeErr / foldindx
            print("Test: alignment error ", np.linalg.norm(TeErr))
            TeFoldErr.append(np.linalg.norm(TeErr))

            HAParam = dict()
            HAParam["Share"] = G
            HAParam["Level"] = FoldStr
            OutData["FunctionalAlignment"] = HAParam
            OutData["Runtime"] = time.time() - tic
            totalTime += OutData["Runtime"]
            print("Saving ...")
            io.savemat(OutFile, mdict=OutData)
            print("Fold " + str(fold_all) + " is DONE: " + OutFile)

        print("Training -> Alignment Error: mean " + str(np.mean(TrFoldErr)) +
              " std " + str(np.std(TrFoldErr)))
        print("Testing  -> Alignment Error: mean " + str(np.mean(TeFoldErr)) +
              " std " + str(np.std(TeFoldErr)))
        print("Runtime: ", totalTime)
        print("GPU Hyperalignment is done.")
        msgBox.setText("GPU Hyperalignment is done.")
        msgBox.setIcon(QMessageBox.Information)
        msgBox.setStandardButtons(QMessageBox.Ok)
        msgBox.exec_()