def btnConvert_click(self):
    """Run Deep RSA on one Task/Subject/Counter/Run slice of a MAT file.

    Every option is read from the module-level ``ui`` form.

    Steps:

    1. Validate the form.
    2. Load ``ui.txtInFile`` with ``io.loadmat``.
    3. Optionally drop samples whose label is listed in the Filter box.
    4. Keep only the samples whose Task, Subject, Counter and Run equal the
       requested values.
    5. Fit a ``DeepRSA`` model on (data, design).
    6. Write the results to ``ui.txtOutFile`` with ``io.savemat``.

    The correlation and/or covariance of the fitted betas, with summary
    statistics from ``SimilarityMatrixBetweenClass``, are added when the
    matching check boxes are set.

    Returns:
        False after showing an error dialog when an input is invalid;
        None when the analysis completes.
    """
    msgBox = QMessageBox()
    tStart = time.time()

    def _fail(text):
        # Show a modal critical dialog; callers return the False this gives
        # back so the handler stops at the first invalid input.
        msgBox.setText(text)
        msgBox.setIcon(QMessageBox.Critical)
        msgBox.setStandardButtons(QMessageBox.Ok)
        msgBox.exec_()
        return False

    Activation = ui.cbActivation.currentData()
    LossNorm = ui.cbLossNorm.currentData()
    LossType = ui.cbType.currentData()

    # ---- Network / optimiser settings ---------------------------------
    try:
        Layers = strRange(ui.txtLayers.text(), Unique=False)
    except Exception:
        Layers = None
    if Layers is None:
        return _fail("Layers is wrong!")

    try:
        Alpha = np.float32(ui.txtAlpha.text())
    except Exception:
        return _fail("Alpha is wrong!")
    if Alpha < 0:
        return _fail("Alpha is wrong!")

    try:
        Iter = np.int32(ui.txtIter.text())
    except Exception:
        return _fail("Number of iteration is wrong!")

    try:
        BatchSize = np.int32(ui.txtBatch.text())
    except Exception:
        return _fail("Number of batch is wrong!")

    try:
        ReportStep = np.int32(ui.txtReportStep.text())
    except Exception:
        return _fail("Number of Report Step is wrong!")

    try:
        LearningRate = np.float32(ui.txtRate.text())
    except Exception:
        return _fail("Learning rate is wrong!")

    if not ui.cbCov.isChecked() and not ui.cbCorr.isChecked():
        return _fail("At least, you must select one metric!")

    # Filter: optional list of class IDs to drop, e.g. "1, 2" or "[1 2]".
    try:
        Filter = ui.txtFilter.text()
        if not Filter:
            Filter = None
        else:
            Filter = np.int32(Filter.replace("\'", " ").replace(",", " ")
                              .replace("[", "").replace("]", "").split())
    except Exception:
        print("Filter is wrong!")
        return _fail("Filter is wrong!")

    # ---- Files ----------------------------------------------------------
    OutFile = ui.txtOutFile.text()
    if not OutFile:
        return _fail("Please enter out file!")
    OutData = dict()

    InFile = ui.txtInFile.text()
    if not InFile:
        return _fail("Please enter input file!")
    if not os.path.isfile(InFile):
        return _fail("Input file not found!")
    print("Loading ...")
    InData = io.loadmat(InFile)

    # ---- Data, label and design ----------------------------------------
    if not ui.txtData.currentText():
        return _fail("Please enter Input Data variable name!")
    if not ui.txtLabel.currentText():
        return _fail("Please enter Train Label variable name!")
    if not ui.txtDesign.currentText():
        return _fail("Please enter Input Design variable name!")
    try:
        Design = InData[ui.txtDesign.currentText()]
    except Exception:
        return _fail("Design value is wrong!")
    try:
        X = InData[ui.txtData.currentText()]
        L = InData[ui.txtLabel.currentText()][0]
    except Exception:
        print("Cannot load data or label")
        return _fail("Cannot load data or label")

    # ---- Task: titles become 1-based IDs in np.unique order --------------
    if not ui.txtTask.currentText():
        return _fail("Please enter Task variable name!")
    if not ui.txtTaskVal.currentText():
        return _fail("Please enter Task value!")
    TaskIDTitle = ui.txtTaskVal.currentText()
    try:
        TaskTitle = InData[ui.txtTask.currentText()][0]
    except Exception:
        return _fail("Task variable name is wrong!")
    TaskTitleUnique = np.unique(TaskTitle)
    Task = np.zeros(np.shape(TaskTitle))
    for ttinx, tt in enumerate(TaskTitle):
        for ttlinx, ttl in enumerate(TaskTitleUnique):
            if tt[0] == ttl:
                Task[ttinx] = ttlinx + 1
                break
    # The requested task must exist; otherwise TaskID would stay undefined.
    TaskID = None
    for ttlinx, ttl in enumerate(TaskTitleUnique):
        if TaskIDTitle == ttl:
            TaskID = ttlinx + 1
            break
    if TaskID is None:
        return _fail("Task value is wrong!")
    OutData["Task"] = TaskIDTitle

    # ---- Subject --------------------------------------------------------
    if not ui.txtSubject.currentText():
        return _fail("Please enter Subject variable name!")
    if not ui.txtSubjectVal.currentText():
        return _fail("Please enter Subject value!")
    try:
        SubID = np.int32(ui.txtSubjectVal.currentText())
    except Exception:
        return _fail("Subject value is wrong!")
    try:
        Sub = InData[ui.txtSubject.currentText()][0]
    except Exception:
        return _fail("Subject variable name is wrong!")
    OutData["SubjectID"] = SubID

    # ---- Run ------------------------------------------------------------
    if not ui.txtRun.currentText():
        return _fail("Please enter Run variable name!")
    if not ui.txtRunVal.currentText():
        return _fail("Please enter Run value!")
    try:
        RunID = np.int32(ui.txtRunVal.currentText())
    except Exception:
        return _fail("Run value is wrong!")
    try:
        Run = InData[ui.txtRun.currentText()][0]
    except Exception:
        return _fail("Run variable name is wrong!")
    OutData["RunID"] = RunID

    # ---- Counter --------------------------------------------------------
    if not ui.txtCounter.currentText():
        return _fail("Please enter Counter variable name!")
    if not ui.txtCounterVal.currentText():
        return _fail("Please enter Counter value!")
    try:
        ConID = np.int32(ui.txtCounterVal.currentText())
    except Exception:
        return _fail("Counter value is wrong!")
    try:
        Con = InData[ui.txtCounter.currentText()][0]
    except Exception:
        return _fail("Counter variable name is wrong!")
    OutData["CounterID"] = ConID

    # ---- Remove filtered classes ------------------------------------------
    if Filter is not None:
        for fil in Filter:
            labelIndx = np.where(L == fil)[0]
            Design = np.delete(Design, labelIndx, axis=0)
            X = np.delete(X, labelIndx, axis=0)
            L = np.delete(L, labelIndx, axis=0)
            Task = np.delete(Task, labelIndx, axis=0)
            Sub = np.delete(Sub, labelIndx, axis=0)
            Run = np.delete(Run, labelIndx, axis=0)
            Con = np.delete(Con, labelIndx, axis=0)
            print("Class ID = " + str(fil) + " is removed from data.")

    # ---- Slice Task -> Subject -> Counter -> Run ----------------------------
    # ``M[idx, :]`` with the np.where tuple nested inside the index gives a
    # leading axis of length 1; the trailing ``[0]`` removes it.
    TaskIndex = np.where(Task == TaskID)
    Design = Design[TaskIndex, :][0]
    X = X[TaskIndex, :][0]
    L = L[TaskIndex]
    Sub = Sub[TaskIndex]
    Run = Run[TaskIndex]
    Con = Con[TaskIndex]

    SubIndex = np.where(Sub == SubID)
    Design = Design[SubIndex, :][0]
    X = X[SubIndex, :][0]
    L = L[SubIndex]
    Run = Run[SubIndex]
    Con = Con[SubIndex]

    ConIndex = np.where(Con == ConID)
    Design = Design[ConIndex, :][0]
    X = X[ConIndex, :][0]
    L = L[ConIndex]
    Run = Run[ConIndex]

    RunIndex = np.where(Run == RunID)
    Design = Design[RunIndex, :][0]
    X = X[RunIndex, :][0]
    L = L[RunIndex]

    # Labels are only used for reporting and plot extents.
    LUnique = np.unique(L)
    LNum = np.shape(LUnique)[0]
    OutData["Label"] = LUnique
    OutData["ModelAnalysis"] = "Tensorflow.Session.Deep.RSA"

    if np.shape(X)[0] == 0:
        return _fail("The selected data is empty!")

    if ui.cbScale.isChecked():
        X = preprocessing.scale(X)
        print("Data is scaled to N(0,1).")

    # ---- Fit --------------------------------------------------------------
    print("Running Deep RSA ...")
    OutData['Method'] = dict()
    OutData['Method']['Layers'] = ui.txtLayers.text()
    OutData['Method']['Alpha'] = Alpha
    OutData['Method']['Activation'] = Activation
    OutData['Method']['LossNorm'] = LossNorm
    OutData['Method']['LearningRate'] = LearningRate
    OutData['Method']['NumIter'] = Iter
    OutData['Method']['BatchSize'] = BatchSize
    OutData['Method']['ReportStep'] = ReportStep
    OutData['Method']['Verbose'] = ui.cbVerbose.isChecked()

    rsa = DeepRSA(layers=Layers, n_iter=Iter, learning_rate=LearningRate,
                  loss_norm=LossNorm, activation=Activation,
                  batch_size=BatchSize, report_step=ReportStep,
                  verbose=ui.cbVerbose.isChecked(),
                  CPU=ui.cbDevice.currentData(), alpha=Alpha,
                  loss_type=LossType)
    Betas, Weights, Biases, loss_vec, MSE, Performance = rsa.fit(
        data_vals=X, design_vals=Design)

    OutData["LossVec"] = loss_vec
    OutData["MSE"] = MSE
    OutData["Performance"] = Performance
    print("MSE: %f" % (MSE))
    if ui.cbBeta.isChecked():
        OutData["Betas"] = Betas
        OutData["Weights"] = Weights
        OutData["Biases"] = Biases

    def _store_summary(name, matrix):
        # Store the matrix and min/max/std/mean of
        # SimilarityMatrixBetweenClass(matrix), then echo them.
        between = SimilarityMatrixBetweenClass(matrix)
        OutData[name] = matrix
        OutData[name + "_min"] = between.min()
        OutData[name + "_max"] = between.max()
        OutData[name + "_std"] = between.std()
        OutData[name + "_mean"] = between.mean()
        print(name + ": min: {:3.10f}, max: {:3.10f}, mean: {:3.10f}, std: {:3.10f}".format(
            between.min(), between.max(), between.mean(), between.std()))

    # ---- Similarity metrics -----------------------------------------------
    if ui.cbCorr.isChecked():
        print("Calculating Correlation ...")
        Corr = np.corrcoef(Betas)
        _store_summary("Correlation", Corr)

    if ui.cbCov.isChecked():
        print("Calculating Covariance ...")
        Cov = np.cov(Betas)
        _store_summary("Covariance", Cov)

    OutData["RunTime"] = time.time() - tStart
    print("Runtime (s): %f" % (OutData["RunTime"]))
    print("Saving results ...")
    io.savemat(OutFile, mdict=OutData, do_compression=True)
    print("Output is saved.")

    def _plot(matrix, kind, **color_range):
        # One square heat-map per metric, titled with the selected slice.
        plt.figure(num=None, figsize=(5, 5), dpi=100)
        plt.pcolor(matrix, **color_range)
        plt.xlim([0, LNum])
        plt.ylim([0, LNum])
        plt.colorbar()
        ax = plt.gca()
        ax.set_aspect(1)
        plt.title('DeepRSA: %s\nTask: %s\nSub: %d, Counter: %d, Run: %d' %
                  (kind, TaskIDTitle, SubID, ConID, RunID))
        plt.show()

    if ui.cbDiagram.isChecked():
        if ui.cbCorr.isChecked():
            _plot(Corr, 'Correlation', vmin=-0.1, vmax=1)
        if ui.cbCov.isChecked():
            _plot(Cov, 'Covariance')

    print("DONE.")
    msgBox.setText(
        "Gradient Representational Similarity Analysis (RSA) is done.")
    msgBox.setIcon(QMessageBox.Information)
    msgBox.setStandardButtons(QMessageBox.Ok)
    msgBox.exec_()
def btnConvert_click(self):
    """Run group-level RSA with an ordinary least-squares fit per fold.

    Every option is read from the module-level ``ui`` form.

    Steps:

    1. Load ``ui.txtInFile`` with ``io.loadmat`` and validate the variable
       names.
    2. Optionally drop filtered classes.
    3. Build "levels": the distinct combinations of Subject (optionally
       with Run), Task and Counter chosen by the ``cbF*`` boxes.
       ``Unit`` consecutive level IDs are merged into one fold. With no
       box checked, the whole data set is one fold.
    4. For every fold, regress the data on the design (plus an intercept
       column) with ``np.linalg.lstsq`` and compute correlation/covariance
       of the betas.
    5. Combine the per-fold matrices by average, minimum or maximum
       (radio buttons).
    6. Save everything to ``ui.txtOutFile``.

    Returns:
        False after showing an error dialog when an input is invalid;
        None when the analysis completes.
    """
    msgBox = QMessageBox()
    tStart = time.time()

    def _fail(text):
        # Show a modal critical dialog; callers return the False this gives
        # back so the handler stops at the first invalid input.
        msgBox.setText(text)
        msgBox.setIcon(QMessageBox.Critical)
        msgBox.setStandardButtons(QMessageBox.Ok)
        msgBox.exec_()
        return False

    if not ui.cbCov.isChecked() and not ui.cbCorr.isChecked():
        return _fail("At least, you must select one metric!")

    # Filter: optional list of class IDs to drop, e.g. "1, 2" or "[1 2]".
    try:
        Filter = ui.txtFilter.text()
        if not Filter:
            Filter = None
        else:
            Filter = np.int32(Filter.replace("\'", " ").replace(",", " ")
                              .replace("[", "").replace("]", "").split())
    except Exception:
        print("Filter is wrong!")
        return _fail("Filter is wrong!")

    # ---- Files ----------------------------------------------------------
    OutFile = ui.txtOutFile.text()
    if not OutFile:
        return _fail("Please enter out file!")

    InFile = ui.txtInFile.text()
    if not InFile:
        return _fail("Please enter input file!")
    if not os.path.isfile(InFile):
        return _fail("Input file not found!")
    print("Loading ...")
    InData = io.loadmat(InFile)

    # ---- Data, label and design ----------------------------------------
    if not ui.txtData.currentText():
        return _fail("Please enter Input Data variable name!")
    if not ui.txtLabel.currentText():
        return _fail("Please enter Train Label variable name!")
    if not ui.txtDesign.currentText():
        return _fail("Please enter Input Design variable name!")
    try:
        Design = InData[ui.txtDesign.currentText()]
    except Exception:
        return _fail("Design value is wrong!")
    try:
        X = InData[ui.txtData.currentText()]
        L = InData[ui.txtLabel.currentText()][0]
        # Global scaling; per-fold scaling (rbScale) happens in the loop.
        if ui.cbScale.isChecked() and not ui.rbScale.isChecked():
            X = preprocessing.scale(X)
            print("Whole of data is scaled X~N(0,1).")
    except Exception:
        print("Cannot load data or label")
        return _fail("Cannot load data or label")

    # ---- Task: titles become 1-based IDs in np.unique order --------------
    if not ui.txtTask.currentText():
        return _fail("Please enter Task variable name!")
    try:
        TaskTitle = InData[ui.txtTask.currentText()][0]
    except Exception:
        return _fail("Task variable name is wrong!")
    TaskTitleUnique = np.unique(TaskTitle)
    Task = np.zeros(np.shape(TaskTitle))
    for ttinx, tt in enumerate(TaskTitle):
        for ttlinx, ttl in enumerate(TaskTitleUnique):
            if tt[0] == ttl:
                Task[ttinx] = ttlinx + 1
                break

    # ---- Subject / Run / Counter --------------------------------------------
    if not ui.txtSubject.currentText():
        return _fail("Please enter Subject variable name!")
    try:
        Sub = InData[ui.txtSubject.currentText()][0]
    except Exception:
        return _fail("Subject variable name is wrong!")

    if not ui.txtRun.currentText():
        return _fail("Please enter Run variable name!")
    try:
        Run = InData[ui.txtRun.currentText()][0]
    except Exception:
        return _fail("Run variable name is wrong!")

    if not ui.txtCounter.currentText():
        return _fail("Please enter Counter variable name!")
    try:
        Con = InData[ui.txtCounter.currentText()][0]
    except Exception:
        return _fail("Counter variable name is wrong!")

    # ---- Remove filtered classes ------------------------------------------
    if Filter is not None:
        for fil in Filter:
            labelIndx = np.where(L == fil)[0]
            Design = np.delete(Design, labelIndx, axis=0)
            X = np.delete(X, labelIndx, axis=0)
            L = np.delete(L, labelIndx, axis=0)
            Task = np.delete(Task, labelIndx, axis=0)
            Sub = np.delete(Sub, labelIndx, axis=0)
            Run = np.delete(Run, labelIndx, axis=0)
            Con = np.delete(Con, labelIndx, axis=0)
            print("Class ID = " + str(fil) + " is removed from data.")

    try:
        Unit = np.int32(ui.txtUnit.text())
    except Exception:
        return _fail("Unit for the test set must be a number!")
    if Unit < 1:
        return _fail("Unit for the test set must be greater than zero!")

    # ---- Build levels and folds -------------------------------------------
    print("Calculating Levels ...")
    GroupFold = None
    FoldStr = ""
    if ui.cbFSubject.isChecked():
        if not ui.rbFRun.isChecked():
            GroupFold = [Sub]
            FoldStr = "Subject"
        else:
            GroupFold = np.concatenate(([Sub], [Run]))
            FoldStr = "Subject+Run"
    if ui.cbFTask.isChecked():
        GroupFold = np.concatenate(
            (GroupFold, [Task])) if GroupFold is not None else [Task]
        FoldStr = FoldStr + "+Task"
    if ui.cbFCounter.isChecked():
        GroupFold = np.concatenate(
            (GroupFold, [Con])) if GroupFold is not None else [Con]
        FoldStr = FoldStr + "+Counter"

    if FoldStr == "":
        FoldStr = "Whole-Data"
        GUFold = [1]
        ListFold = [1]
        UniqFold = [1]
        GroupFold = [1]
        UnitFold = np.ones((1, np.shape(X)[0]))
    else:
        # One row per sample; each distinct row is a level.
        GroupFold = np.transpose(GroupFold)
        group_rows = [tuple(row) for row in GroupFold.tolist()]
        unique_rows = list(set(group_rows))
        UniqFold = np.array(unique_rows)
        FoldIDs = np.arange(len(UniqFold)) + 1
        if len(UniqFold) <= Unit:
            return _fail(
                "Unit must be smaller than all possible levels! Number of all levels is: " + str(len(UniqFold)))
        if np.mod(len(UniqFold), Unit):
            return _fail(
                "Unit must evenly divide the number of all possible levels! Number of all levels is: " + str(len(UniqFold)))
        # Dict lookup instead of scanning UniqFold for every sample.
        fold_id_of = dict(zip(unique_rows, FoldIDs))
        ListFold = np.int32([fold_id_of[row] for row in group_rows])
        if Unit == 1:
            UnitFold = np.int32(ListFold)
        else:
            # Level IDs 1..Unit -> fold 1, Unit+1..2*Unit -> fold 2, ...
            UnitFold = np.int32((ListFold - 0.1) / Unit) + 1
        GUFold = np.unique(UnitFold)

    FoldInfo = dict()
    FoldInfo["Unit"] = Unit
    FoldInfo["Group"] = GroupFold
    FoldInfo["Order"] = FoldStr
    FoldInfo["List"] = ListFold
    FoldInfo["Unique"] = UniqFold
    FoldInfo["Folds"] = UnitFold

    OutData = dict()
    OutData["FoldInfo"] = FoldInfo
    OutData["ModelAnalysis"] = "Numpy.Group.RSA"

    def _combine(acc, new):
        # Aggregate per-fold matrices: running sum (averaged after the loop),
        # element-wise minimum, or element-wise maximum.
        if acc is None:
            return new.copy()
        if ui.rbAvg.isChecked():
            return np.add(acc, new)
        if ui.rbMin.isChecked():
            return np.minimum(acc, new)
        return np.maximum(acc, new)

    print("Number of all levels is: " + str(len(UniqFold)))
    Cov = None
    Corr = None
    AMSE = list()
    for foldID, fold in enumerate(GUFold):
        print("Analyzing level " + str(foldID + 1), " of ", str(len(GUFold)), " ...")
        # 1-D row indices; ravel() also covers the (1, N) Whole-Data array.
        Index = np.where(np.ravel(UnitFold) == fold)[0]
        XLi = X[Index]
        if ui.cbScale.isChecked() and ui.rbScale.isChecked():
            XLi = preprocessing.scale(XLi)
            print("Whole of data is scaled X%d~N(0,1)." % (foldID + 1))
        # Prepend an intercept column, then drop its coefficient row.
        RegLi = np.insert(Design[Index], 0, 1, axis=1)
        BetaLi = np.linalg.lstsq(RegLi, XLi, rcond=None)[0][1:, :]
        print("Calculating MSE for level %d ..." % (foldID + 1))
        # NOTE(review): the prediction omits the fitted intercept — confirm
        # this is the intended error measure.
        MSE = mean_squared_error(XLi, np.matmul(Design[Index], BetaLi))
        print("MSE%d: %f" % (foldID + 1, MSE))
        # NOTE: 0-based key, unlike the 1-based Beta/Corr/Cov keys below;
        # kept for compatibility with existing output files.
        OutData["MSE" + str(foldID)] = MSE
        AMSE.append(MSE)
        if ui.cbBeta.isChecked():
            OutData["BetaL" + str(foldID + 1)] = BetaLi
        if ui.cbCorr.isChecked():
            print("Calculating Correlation for level %d ..." % (foldID + 1))
            CorrLi = np.corrcoef(BetaLi)
            OutData["Corr" + str(foldID + 1)] = CorrLi
            Corr = _combine(Corr, CorrLi)
        if ui.cbCov.isChecked():
            print("Calculating Covariance for level %d ..." % (foldID + 1))
            CovLi = np.cov(BetaLi)
            OutData["Cov" + str(foldID + 1)] = CovLi
            Cov = _combine(Cov, CovLi)

    def _store_summary(name, matrix):
        # Store the matrix and min/max/std/mean of
        # SimilarityMatrixBetweenClass(matrix).
        between = SimilarityMatrixBetweenClass(matrix)
        OutData[name] = matrix
        OutData[name + "_min"] = between.min()
        OutData[name + "_max"] = between.max()
        OutData[name + "_std"] = between.std()
        OutData[name + "_mean"] = between.mean()

    # Average mode summed one matrix per processed fold, so divide by the
    # number of folds.
    n_folds = len(GUFold)
    if ui.cbCov.isChecked():
        if ui.rbAvg.isChecked():
            Cov = Cov / n_folds
        _store_summary("Covariance", Cov)
    if ui.cbCorr.isChecked():
        if ui.rbAvg.isChecked():
            Corr = Corr / n_folds
        _store_summary("Correlation", Corr)

    OutData["MSE"] = np.mean(AMSE)
    OutData["MSE_std"] = np.std(AMSE)
    print("Average MSE: %f" % (OutData["MSE"]))
    OutData["RunTime"] = time.time() - tStart
    print("Runtime (s): %f" % (OutData["RunTime"]))
    print("Saving results ...")
    io.savemat(OutFile, mdict=OutData, do_compression=True)
    print("Output is saved.")

    def _plot(matrix, title, **color_range):
        # Square heat-map sized to the matrix.
        size = np.shape(matrix)[0]
        plt.figure(num=None, figsize=(5, 5), dpi=100)
        plt.pcolor(matrix, **color_range)
        plt.xlim([0, size])
        plt.ylim([0, size])
        plt.colorbar()
        ax = plt.gca()
        ax.set_aspect(1)
        plt.title(title)
        plt.show()

    if ui.cbDiagram.isChecked():
        if ui.cbCorr.isChecked():
            _plot(Corr, 'Correlation of Categories\nLevel: ' + FoldStr,
                  vmin=-0.1, vmax=1)
        if ui.cbCov.isChecked():
            _plot(Cov, 'Covariance of Categories\nLevel: ' + FoldStr)

    print("DONE.")
    msgBox.setText(
        "Group Representational Similarity Analysis (RSA) is done.")
    msgBox.setIcon(QMessageBox.Information)
    msgBox.setStandardButtons(QMessageBox.Ok)
    msgBox.exec_()
def btnConvert_click(self):
    """Run group-level RSA with a scikit-learn linear model and cluster the betas.

    Every option is read from the module-level ``ui`` form.

    Steps:

    1. Load ``ui.txtInFile`` with ``mainIO_load`` and validate the variable
       names. Condition labels come from ``ui.txtCond`` (used as plot tick
       labels).
    2. Optionally drop filtered classes.
    3. Build "levels" from Subject (optionally with Run), Task and Counter;
       ``Unit`` consecutive level IDs form one fold.
    4. For every fold, fit the selected estimator: ``ols``
       (LinearRegression), ``ridge``, ``lasso`` or ``elast`` (ElasticNet).
       The fit uses the design plus an intercept column.
    5. Combine the per-fold correlation/covariance matrices by average,
       minimum or maximum.
    6. Build a distance matrix and hierarchical linkage from the
       fold-summed betas.
    7. Save everything with ``mainIO_save``.

    Returns:
        False after showing an error dialog when an input is invalid;
        None when the analysis completes.
    """
    from scipy.spatial.distance import squareform

    msgBox = QMessageBox()
    tStart = time.time()

    def _fail(text):
        # Show a modal critical dialog; callers return the False this gives
        # back so the handler stops at the first invalid input.
        msgBox.setText(text)
        msgBox.setIcon(QMessageBox.Critical)
        msgBox.setStandardButtons(QMessageBox.Ok)
        msgBox.exec_()
        return False

    if not ui.cbCov.isChecked() and not ui.cbCorr.isChecked():
        return _fail("At least, you must select one metric!")

    # ---- Estimator settings -------------------------------------------------
    method = ui.cbMethod.currentData()
    solver = ui.cbSolver.currentText()
    selection = ui.cbSelection.currentText()
    fit = ui.cbFit.isChecked()
    normalize = ui.cbNormalize.isChecked()

    # np.float / np.int were plain aliases of the builtins and are gone
    # from NumPy >= 1.24, so use the builtins directly.
    try:
        alpha = float(ui.txtAlpha.text())
    except (TypeError, ValueError):
        return _fail("Alpha is wrong!")
    try:
        max_iter = int(ui.txtMaxIter.text())
    except (TypeError, ValueError):
        return _fail("Max Iteration is wrong!")
    try:
        tol = float(ui.txtTole.text())
    except (TypeError, ValueError):
        return _fail("Tolerance is wrong!")
    try:
        l1 = float(ui.txtL1.text())
    except (TypeError, ValueError):
        return _fail("L1 is wrong!")
    try:
        # n_jobs must be an integer (or None) for scikit-learn/joblib.
        njob = int(ui.txtJobs.text())
    except (TypeError, ValueError):
        return _fail("Number of jobs is wrong!")
    if method not in ("ols", "ridge", "lasso", "elast"):
        return _fail("Method is wrong!")

    # ``normalize`` was removed in scikit-learn 1.2. Forward it only when
    # requested, so the default path works on old and new releases.
    shared = {"fit_intercept": fit}
    if normalize:
        shared["normalize"] = normalize

    def _make_model():
        # A fresh, unfitted estimator for one fold.
        if method == "ols":
            return linmdl.LinearRegression(n_jobs=njob, **shared)
        if method == "ridge":
            return linmdl.Ridge(alpha=alpha, max_iter=max_iter, tol=tol,
                                solver=solver, **shared)
        if method == "lasso":
            return linmdl.Lasso(alpha=alpha, max_iter=max_iter, tol=tol,
                                selection=selection, **shared)
        return linmdl.ElasticNet(alpha=alpha, l1_ratio=l1, max_iter=max_iter,
                                 tol=tol, selection=selection, **shared)

    # Filter: optional list of class IDs to drop, e.g. "1, 2" or "[1 2]".
    try:
        Filter = ui.txtFilter.text()
        if not Filter:
            Filter = None
        else:
            Filter = np.int32(Filter.replace("\'", " ").replace(",", " ")
                              .replace("[", "").replace("]", "").split())
    except Exception:
        print("Filter is wrong!")
        return _fail("Filter is wrong!")

    # ---- Files ----------------------------------------------------------
    OutFile = ui.txtOutFile.text()
    if not OutFile:
        return _fail("Please enter out file!")
    OutData = dict()

    InFile = ui.txtInFile.text()
    if not InFile:
        return _fail("Please enter input file!")
    if not os.path.isfile(InFile):
        return _fail("Input file not found!")
    print("Loading ...")
    InData = mainIO_load(InFile)

    # ---- Data, label, design and condition names --------------------------
    if not ui.txtData.currentText():
        return _fail("Please enter Input Data variable name!")
    if not ui.txtLabel.currentText():
        return _fail("Please enter Train Label variable name!")
    if not ui.txtDesign.currentText():
        return _fail("Please enter Input Design variable name!")
    try:
        Design = InData[ui.txtDesign.currentText()]
    except Exception:
        return _fail("Design value is wrong!")

    if not ui.txtCond.currentText():
        return _fail("Please enter Condition variable name!")
    try:
        Cond = InData[ui.txtCond.currentText()]
        OutData[ui.txtCond.currentText()] = Cond
        # Second column of each condition row holds its display name.
        labels = [reshape_condition_cell(con[1]) for con in Cond]
    except Exception:
        return _fail("Condition value is wrong!")

    FontSize = ui.txtFontSize.value()

    try:
        X = InData[ui.txtData.currentText()]
        L = InData[ui.txtLabel.currentText()][0]
        # Global scaling; per-fold scaling (rbScale) happens in the loop.
        if ui.cbScale.isChecked() and not ui.rbScale.isChecked():
            X = preprocessing.scale(X)
            print("Whole of data is scaled X~N(0,1).")
    except Exception:
        print("Cannot load data or label")
        return _fail("Cannot load data or label")

    # ---- Task: titles become 1-based IDs in np.unique order --------------
    if not ui.txtTask.currentText():
        return _fail("Please enter Task variable name!")
    try:
        TaskTitle = np.array(InData[ui.txtTask.currentText()][0])
    except Exception:
        return _fail("Task variable name is wrong!")
    TaskTitleUnique = np.unique(TaskTitle)
    Task = np.zeros(np.shape(TaskTitle))
    for ttinx, tt in enumerate(TaskTitle):
        for ttlinx, ttl in enumerate(TaskTitleUnique):
            if tt[0] == ttl:
                Task[ttinx] = ttlinx + 1
                break

    # ---- Subject / Run / Counter --------------------------------------------
    if not ui.txtSubject.currentText():
        return _fail("Please enter Subject variable name!")
    try:
        Sub = InData[ui.txtSubject.currentText()][0]
    except Exception:
        return _fail("Subject variable name is wrong!")

    if not ui.txtRun.currentText():
        return _fail("Please enter Run variable name!")
    try:
        Run = InData[ui.txtRun.currentText()][0]
    except Exception:
        return _fail("Run variable name is wrong!")

    if not ui.txtCounter.currentText():
        return _fail("Please enter Counter variable name!")
    try:
        Con = InData[ui.txtCounter.currentText()][0]
    except Exception:
        return _fail("Counter variable name is wrong!")

    # ---- Remove filtered classes ------------------------------------------
    if Filter is not None:
        for fil in Filter:
            labelIndx = np.where(L == fil)[0]
            Design = np.delete(Design, labelIndx, axis=0)
            X = np.delete(X, labelIndx, axis=0)
            L = np.delete(L, labelIndx, axis=0)
            Task = np.delete(Task, labelIndx, axis=0)
            Sub = np.delete(Sub, labelIndx, axis=0)
            Run = np.delete(Run, labelIndx, axis=0)
            Con = np.delete(Con, labelIndx, axis=0)
            print("Class ID = " + str(fil) + " is removed from data.")

    try:
        Unit = np.int32(ui.txtUnit.text())
    except Exception:
        return _fail("Unit for the test set must be a number!")
    if Unit < 1:
        return _fail("Unit for the test set must be greater than zero!")

    # ---- Build levels and folds -------------------------------------------
    print("Calculating Levels ...")
    GroupFold = None
    FoldStr = ""
    if ui.cbFSubject.isChecked():
        if not ui.rbFRun.isChecked():
            GroupFold = [Sub]
            FoldStr = "Subject"
        else:
            GroupFold = np.concatenate(([Sub], [Run]))
            FoldStr = "Subject+Run"
    if ui.cbFTask.isChecked():
        GroupFold = np.concatenate(
            (GroupFold, [Task])) if GroupFold is not None else [Task]
        FoldStr = FoldStr + "+Task"
    if ui.cbFCounter.isChecked():
        GroupFold = np.concatenate(
            (GroupFold, [Con])) if GroupFold is not None else [Con]
        FoldStr = FoldStr + "+Counter"

    if FoldStr == "":
        FoldStr = "Whole-Data"
        GUFold = [1]
        ListFold = [1]
        UniqFold = [1]
        GroupFold = [1]
        UnitFold = np.ones((1, np.shape(X)[0]))
    else:
        # One row per sample; each distinct row is a level.
        GroupFold = np.transpose(GroupFold)
        group_rows = [tuple(row) for row in GroupFold.tolist()]
        unique_rows = list(set(group_rows))
        UniqFold = np.array(unique_rows)
        FoldIDs = np.arange(len(UniqFold)) + 1
        if len(UniqFold) <= Unit:
            return _fail(
                "Unit must be smaller than all possible levels! Number of all levels is: " + str(len(UniqFold)))
        if np.mod(len(UniqFold), Unit):
            return _fail(
                "Unit must evenly divide the number of all possible levels! Number of all levels is: " + str(len(UniqFold)))
        # Dict lookup instead of scanning UniqFold for every sample.
        fold_id_of = dict(zip(unique_rows, FoldIDs))
        ListFold = np.int32([fold_id_of[row] for row in group_rows])
        if Unit == 1:
            UnitFold = np.int32(ListFold)
        else:
            # Level IDs 1..Unit -> fold 1, Unit+1..2*Unit -> fold 2, ...
            UnitFold = np.int32((ListFold - 0.1) / Unit) + 1
        GUFold = np.unique(UnitFold)

    FoldInfo = dict()
    FoldInfo["Unit"] = Unit
    FoldInfo["Group"] = GroupFold
    FoldInfo["Order"] = FoldStr
    FoldInfo["List"] = ListFold
    FoldInfo["Unique"] = UniqFold
    FoldInfo["Folds"] = UnitFold
    OutData["FoldInfo"] = FoldInfo
    OutData["ModelAnalysis"] = "SK.Group.RSA." + ui.cbMethod.currentText()

    def _combine(acc, new):
        # Aggregate per-fold matrices: running sum (averaged after the loop),
        # element-wise minimum, or element-wise maximum.
        if acc is None:
            return new.copy()
        if ui.rbAvg.isChecked():
            return np.add(acc, new)
        if ui.rbMin.isChecked():
            return np.minimum(acc, new)
        return np.maximum(acc, new)

    print("Number of all levels is: " + str(len(UniqFold)))
    Cov = None
    Corr = None
    AMSE = list()
    Beta = None  # sum of the per-fold betas, used for the distance matrix
    for foldID, fold in enumerate(GUFold):
        print("Analyzing level " + str(foldID + 1), " of ", str(len(GUFold)), " ...")
        # 1-D row indices; ravel() also covers the (1, N) Whole-Data array.
        Index = np.where(np.ravel(UnitFold) == fold)[0]
        XLi = X[Index]
        if ui.cbScale.isChecked() and ui.rbScale.isChecked():
            XLi = preprocessing.scale(XLi)
            print("Whole of data is scaled X%d~N(0,1)." % (foldID + 1))
        # Prepend an intercept column, then drop its coefficient row.
        RegLi = np.insert(Design[Index], 0, 1, axis=1)
        model = _make_model()
        model.fit(RegLi, XLi)
        BetaLi = np.transpose(model.coef_)[1:, :]
        Beta = BetaLi if Beta is None else Beta + BetaLi
        print("Calculating MSE for level %d ..." % (foldID + 1))
        # NOTE(review): the prediction omits the fitted intercept — confirm
        # this is the intended error measure.
        MSE = mean_squared_error(XLi, np.matmul(Design[Index], BetaLi))
        print("MSE%d: %f" % (foldID + 1, MSE))
        # NOTE: 0-based key, unlike the 1-based Beta/Corr/Cov keys below;
        # kept for compatibility with existing output files.
        OutData["MSE" + str(foldID)] = MSE
        AMSE.append(MSE)
        if ui.cbBeta.isChecked():
            OutData["BetaL" + str(foldID + 1)] = BetaLi
        if ui.cbCorr.isChecked():
            print("Calculating Correlation for level %d ..." % (foldID + 1))
            CorrLi = np.corrcoef(BetaLi)
            OutData["Corr" + str(foldID + 1)] = CorrLi
            Corr = _combine(Corr, CorrLi)
        if ui.cbCov.isChecked():
            print("Calculating Covariance for level %d ..." % (foldID + 1))
            CovLi = np.cov(BetaLi)
            OutData["Cov" + str(foldID + 1)] = CovLi
            Cov = _combine(Cov, CovLi)

    def _store_summary(name, matrix):
        # Store the matrix and min/max/std/mean of
        # SimilarityMatrixBetweenClass(matrix).
        between = SimilarityMatrixBetweenClass(matrix)
        OutData[name] = matrix
        OutData[name + "_min"] = between.min()
        OutData[name + "_max"] = between.max()
        OutData[name + "_std"] = between.std()
        OutData[name + "_mean"] = between.mean()

    # Average mode summed one matrix per processed fold, so divide by the
    # number of folds.
    n_folds = len(GUFold)
    if ui.cbCov.isChecked():
        if ui.rbAvg.isChecked():
            Cov = Cov / n_folds
        _store_summary("Covariance", Cov)
    if ui.cbCorr.isChecked():
        if ui.rbAvg.isChecked():
            Corr = Corr / n_folds
        _store_summary("Correlation", Corr)

    # ---- Distance matrix and hierarchical clustering --------------------
    # dis[i, j] = 1 - <Beta_i, Beta_j> for the fold-summed betas, with a
    # zero diagonal.
    n_rows = np.shape(Beta)[0]
    dis = np.zeros((n_rows, n_rows))
    for i in range(n_rows):
        for j in range(i + 1, n_rows):
            dis[i, j] = 1 - np.dot(Beta[i, :], Beta[j, :].T)
            dis[j, i] = dis[i, j]
    OutData["DistanceMatrix"] = dis
    # linkage() reads a 2-D array as observation vectors; a precomputed
    # distance matrix must be passed in condensed form.
    Z = linkage(squareform(dis, checks=False))
    OutData["Linkage"] = Z

    OutData["MSE"] = np.mean(AMSE)
    OutData["MSE_std"] = np.std(AMSE)
    print("Average MSE: %f" % (OutData["MSE"]))
    OutData["RunTime"] = time.time() - tStart
    print("Runtime (s): %f" % (OutData["RunTime"]))
    print("Saving results ...")
    mainIO_save(OutData, OutFile)
    print("Output is saved.")

    def _plot_matrix(matrix, title):
        # Labelled heat-map, one cell per condition, with tick marks hidden.
        NumData = np.shape(matrix)[0]
        plt.figure(num=None, figsize=(NumData, NumData), dpi=100)
        plt.pcolor(matrix, vmin=np.min(matrix), vmax=np.max(matrix))
        plt.xlim([0, NumData])
        plt.ylim([0, NumData])
        cbar = plt.colorbar()
        cbar.ax.tick_params(labelsize=FontSize)
        ax = plt.gca()
        ax.invert_yaxis()
        ax.set_aspect(1)
        ax.set_yticks(np.arange(NumData) + 0.5, minor=False)
        ax.set_xticks(np.arange(NumData) + 0.5, minor=False)
        ax.set_xticklabels(labels, minor=False, fontsize=FontSize,
                           rotation=ui.txtXRotation.value())
        ax.set_yticklabels(labels, minor=False, fontsize=FontSize,
                           rotation=ui.txtYRotation.value())
        ax.grid(False)
        ax.set_frame_on(False)
        # ``Tick.tick1On`` was deprecated in Matplotlib 3.1 and later
        # removed; toggling the tick lines works on every version.
        for t in ax.xaxis.get_major_ticks() + ax.yaxis.get_major_ticks():
            t.tick1line.set_visible(False)
            t.tick2line.set_visible(False)
        plt.title(title)
        plt.show()

    if ui.cbDiagram.isChecked():
        if ui.cbCorr.isChecked():
            if len(ui.txtTitleCorr.text()):
                corr_title = ui.txtTitleCorr.text()
            else:
                corr_title = 'Group RSA: Correlation\nLevel: ' + FoldStr
            _plot_matrix(Corr, corr_title)

        if ui.cbCov.isChecked():
            if len(ui.txtTitleCov.text()):
                cov_title = ui.txtTitleCov.text()
            else:
                cov_title = 'Group RSA: Covariance\nLevel: ' + FoldStr
            _plot_matrix(Cov, cov_title)

        plt.figure(figsize=(25, 10))
        if len(ui.txtTitleDen.text()):
            plt.title(ui.txtTitleDen.text())
        else:
            plt.title(
                'Group MP Gradient RSA: Similarity Analysis\nLevel: ' + FoldStr)
        dendrogram(Z, labels=labels, leaf_font_size=FontSize,
                   color_threshold=1, leaf_rotation=ui.txtXRotation.value())
        plt.show()

    print("DONE.")
    msgBox.setText(
        "Group Representational Similarity Analysis (RSA) is done.")
    msgBox.setIcon(QMessageBox.Information)
    msgBox.setStandardButtons(QMessageBox.Ok)
    msgBox.exec_()
def btnConvert_click(self):
    """Run single-session RSA for the task/subject/counter/run chosen in the form.

    Every setting is read from the ``ui`` form object (looked up globally, not
    through ``self``). The input file is read with ``io.loadmat``.

    Samples whose label appears in the filter box are dropped. After that,
    only samples matching the requested Task, Subject, Counter and Run values
    are kept. The data are optionally scaled with ``preprocessing.scale`` and
    regressed on the design matrix plus an intercept column with
    ``np.linalg.lstsq``. The non-intercept betas then yield the correlation
    and/or covariance matrices.

    Results are written with ``io.savemat`` and optionally plotted.

    Returns:
        False after an input error has been reported in a message box;
        None (implicit) after a successful run.
    """
    msgBox = QMessageBox()
    tStart = time.time()

    # Input errors are caught broadly (``except Exception``) and reported in
    # the GUI rather than propagated out of this (presumably Qt-slot) handler.
    def _critical(text):
        # Show a modal critical-error box; returns False so each check can
        # simply ``return _critical(...)``.
        msgBox.setText(text)
        msgBox.setIcon(QMessageBox.Critical)
        msgBox.setStandardButtons(QMessageBox.Ok)
        msgBox.exec_()
        return False

    def _store_summary(prefix, matrix):
        # Save the matrix plus min/max/std/mean of
        # SimilarityMatrixBetweenClass(matrix) under prefix-based keys.
        between = SimilarityMatrixBetweenClass(matrix)
        OutData[prefix] = matrix
        OutData[prefix + "_min"] = between.min()
        OutData[prefix + "_max"] = between.max()
        OutData[prefix + "_std"] = between.std()
        OutData[prefix + "_mean"] = between.mean()

    def _plot(matrix, title, **color_limits):
        # Heat-map of a square matrix; the axes span the matrix itself (the
        # old code used the number of unique labels, which need not equal
        # the number of beta rows).
        size = np.shape(matrix)[0]
        plt.figure(num=None, figsize=(5, 5), dpi=100)
        plt.pcolor(matrix, **color_limits)
        plt.xlim([0, size])
        plt.ylim([0, size])
        plt.colorbar()
        plt.gca().set_aspect(1)
        plt.title(title)
        plt.show()

    if not ui.cbCov.isChecked() and not ui.cbCorr.isChecked():
        return _critical("At least, you must select one metric!")

    # Filter: optional class labels to drop, e.g. "[1, 2]" or "1 2".
    try:
        Filter = ui.txtFilter.text()
        if not Filter:
            Filter = None
        else:
            Filter = np.int32(Filter.replace("\'", " ").replace(",", " ")
                              .replace("[", "").replace("]", "").split())
    except Exception:
        print("Filter is wrong!")
        return _critical("Filter is wrong!")

    # OutFile
    OutFile = ui.txtOutFile.text()
    if not OutFile:
        return _critical("Please enter out file!")
    OutData = dict()

    # InFile
    InFile = ui.txtInFile.text()
    if not InFile:
        return _critical("Please enter input file!")
    if not os.path.isfile(InFile):
        return _critical("Input file not found!")
    print("Loading ...")
    InData = io.loadmat(InFile)

    # Data / Label / Design variable names must be filled in.
    if not ui.txtData.currentText():
        return _critical("Please enter Input Data variable name!")
    if not ui.txtLabel.currentText():
        return _critical("Please enter Train Label variable name!")
    if not ui.txtDesign.currentText():
        return _critical("Please enter Input Design variable name!")
    try:
        Design = InData[ui.txtDesign.currentText()]
    except Exception:
        return _critical("Design value is wrong!")
    try:
        X = InData[ui.txtData.currentText()]
        L = InData[ui.txtLabel.currentText()][0]
    except Exception:
        print("Cannot load data or label")
        return _critical("Cannot load data or label")

    # Task: variable name and the task title to analyse.
    if not ui.txtTask.currentText():
        return _critical("Please enter Task variable name!")
    TaskIDTitle = ui.txtTaskVal.currentText()
    if not TaskIDTitle:
        return _critical("Please enter Task value!")
    try:
        TaskTitle = InData[ui.txtTask.currentText()][0]
    except Exception:
        return _critical("Task variable name is wrong!")
    # Code each sample's task as a 1-based index into np.unique(TaskTitle);
    # samples without a match keep code 0.
    TaskTitleUnique = np.unique(TaskTitle)
    Task = np.zeros(np.shape(TaskTitle))
    for sample, title in enumerate(TaskTitle):
        for code, unique_title in enumerate(TaskTitleUnique, start=1):
            if title[0] == unique_title:
                Task[sample] = code
                break
    # Resolve the requested title to the same code. An unknown title used to
    # leave TaskID unbound and crash the task selection with NameError.
    TaskID = next((code for code, unique_title in enumerate(TaskTitleUnique, start=1)
                   if TaskIDTitle == unique_title), None)
    if TaskID is None:
        return _critical("Task value is wrong!")
    OutData["Task"] = TaskIDTitle

    # Subject
    if not ui.txtSubject.currentText():
        return _critical("Please enter Subject variable name!")
    if not ui.txtSubjectVal.currentText():
        return _critical("Please enter Subject value!")
    try:
        SubID = np.int32(ui.txtSubjectVal.currentText())
    except Exception:
        return _critical("Subject value is wrong!")
    try:
        Sub = InData[ui.txtSubject.currentText()][0]
    except Exception:
        return _critical("Subject variable name is wrong!")
    OutData["SubjectID"] = SubID

    # Run
    if not ui.txtRun.currentText():
        return _critical("Please enter Run variable name!")
    if not ui.txtRunVal.currentText():
        return _critical("Please enter Run value!")
    try:
        RunID = np.int32(ui.txtRunVal.currentText())
    except Exception:
        return _critical("Run value is wrong!")
    try:
        Run = InData[ui.txtRun.currentText()][0]
    except Exception:
        return _critical("Run variable name is wrong!")
    OutData["RunID"] = RunID

    # Counter
    if not ui.txtCounter.currentText():
        return _critical("Please enter Counter variable name!")
    if not ui.txtCounterVal.currentText():
        return _critical("Please enter Counter value!")
    try:
        ConID = np.int32(ui.txtCounterVal.currentText())
    except Exception:
        return _critical("Counter value is wrong!")
    try:
        Con = InData[ui.txtCounter.currentText()][0]
    except Exception:
        return _critical("Counter variable name is wrong!")
    OutData["CounterID"] = ConID

    if Filter is not None:
        for fil in Filter:
            # Drop every sample carrying this class label from all arrays.
            labelIndx = np.where(L == fil)[0]
            Design, X, L, Task, Sub, Run, Con = (
                np.delete(arr, labelIndx, axis=0)
                for arr in (Design, X, L, Task, Sub, Run, Con))
            print("Class ID = " + str(fil) + " is removed from data.")

    # Narrow to the requested task, subject, counter and run (in that order),
    # keeping the row indices returned by np.where.
    keep = np.where(Task == TaskID)[0]
    Design, X, L, Sub, Run, Con = (arr[keep] for arr in (Design, X, L, Sub, Run, Con))
    keep = np.where(Sub == SubID)[0]
    Design, X, L, Run, Con = (arr[keep] for arr in (Design, X, L, Run, Con))
    keep = np.where(Con == ConID)[0]
    Design, X, L, Run = (arr[keep] for arr in (Design, X, L, Run))
    keep = np.where(Run == RunID)[0]
    Design, X, L = (arr[keep] for arr in (Design, X, L))

    # Unique labels of the selected samples (only used by supervised methods).
    OutData["Label"] = np.unique(L)
    OutData["ModelAnalysis"] = "Numpy.Session.RSA"

    if np.shape(X)[0] == 0:
        return _critical("The selected data is empty!")

    if ui.cbScale.isChecked():
        X = preprocessing.scale(X)
        print("Data is scaled to N(0,1).")

    print("Running RSA ...")
    # Least-squares fit of X on [1, Design]; the first solution row belongs to
    # the intercept column and is discarded.
    Reg = np.insert(Design, 0, 1, axis=1)
    Betas = np.linalg.lstsq(Reg, X)[0][1:, :]
    print("Calculating MSE ...")
    # NOTE(review): the prediction omits the fitted intercept; kept as-is so
    # the reported MSE values do not change.
    MSE = mean_squared_error(X, np.matmul(Design, Betas))
    print("MSE: %f" % (MSE))
    OutData["MSE"] = MSE
    if ui.cbBeta.isChecked():
        OutData["Betas"] = Betas

    # Calculate Results
    Corr = Cov = None
    if ui.cbCorr.isChecked():
        print("Calculating Correlation ...")
        Corr = np.corrcoef(Betas)
        _store_summary("Correlation", Corr)
    if ui.cbCov.isChecked():
        print("Calculating Covariance ...")
        Cov = np.cov(Betas)
        _store_summary("Covariance", Cov)

    OutData["RunTime"] = time.time() - tStart
    print("Runtime (s): %f" % (OutData["RunTime"]))
    print("Saving results ...")
    io.savemat(OutFile, mdict=OutData, do_compression=True)
    print("Output is saved.")

    if ui.cbDiagram.isChecked():
        title_info = "\nTask: %s\nSub: %d, Counter: %d, Run: %d" % (
            TaskIDTitle, SubID, ConID, RunID)
        if ui.cbCorr.isChecked():
            _plot(Corr, 'Correlation of Categories' + title_info, vmin=-0.1, vmax=1)
        if ui.cbCov.isChecked():
            _plot(Cov, 'Covariance of Categories' + title_info)

    print("DONE.")
    msgBox.setText("Representational Similarity Analysis (RSA) is done.")
    msgBox.setIcon(QMessageBox.Information)
    msgBox.setStandardButtons(QMessageBox.Ok)
    msgBox.exec_()
def btnConvert_click(self):
    """Run group gradient-based RSA (``GrRSA``) over the selected levels and save it.

    All settings come from the ``ui`` form object (looked up globally, not
    through ``self``). The input file is loaded with ``io.loadmat`` and
    samples whose label is in the filter box are dropped.

    The remaining samples are split into levels by the checked
    Subject / Run / Task / Counter grouping, or kept as one "Whole-Data"
    level. Consecutive level IDs are merged ``Unit`` at a time into folds,
    and ``GrRSA.fit`` is run on every fold.

    Per-fold and aggregated (average / min / max, per the radio buttons)
    correlation and/or covariance matrices are saved with ``io.savemat``.
    So are MSE/performance summaries, a distance matrix built from the
    fold-summed betas, and its ``linkage``. Results are optionally plotted.

    Returns:
        False after an input error has been reported in a message box;
        None (implicit) after a successful run.
    """
    msgBox = QMessageBox()
    tStart = time.time()

    # Input errors are caught broadly (``except Exception``) and reported in
    # the GUI rather than propagated out of this (presumably Qt-slot) handler.
    def _critical(text):
        # Show a modal critical-error box; returns False so each check can
        # simply ``return _critical(...)``.
        msgBox.setText(text)
        msgBox.setIcon(QMessageBox.Critical)
        msgBox.setStandardButtons(QMessageBox.Ok)
        msgBox.exec_()
        return False

    def _load_row(combo, missing_msg, wrong_msg):
        # Return InData[<name typed in combo>][0]; on failure show the
        # matching error and return None.
        name = combo.currentText()
        if not name:
            _critical(missing_msg)
            return None
        try:
            return InData[name][0]
        except Exception:
            _critical(wrong_msg)
            return None

    def _aggregate(total, current):
        # Fold one level's matrix into the running result: sum for "average"
        # (divided after the loop), element-wise min, or element-wise max.
        if total is None:
            return current.copy()
        if ui.rbAvg.isChecked():
            return np.add(total, current)
        if ui.rbMin.isChecked():
            return np.minimum(total, current)
        return np.maximum(total, current)

    def _store_summary(prefix, matrix):
        # Save the matrix plus min/max/std/mean of
        # SimilarityMatrixBetweenClass(matrix) under prefix-based keys.
        between = SimilarityMatrixBetweenClass(matrix)
        OutData[prefix] = matrix
        OutData[prefix + "_min"] = between.min()
        OutData[prefix + "_max"] = between.max()
        OutData[prefix + "_std"] = between.std()
        OutData[prefix + "_mean"] = between.mean()

    def _plot_matrix(matrix, custom_title, default_title):
        # Heat-map of a square matrix with the condition names as tick labels.
        size = np.shape(matrix)[0]
        plt.figure(num=None, figsize=(size, size), dpi=100)
        plt.pcolor(matrix, vmin=np.min(matrix), vmax=np.max(matrix))
        plt.xlim([0, size])
        plt.ylim([0, size])
        cbar = plt.colorbar()
        cbar.ax.tick_params(labelsize=FontSize)
        ax = plt.gca()
        ax.invert_yaxis()
        ax.set_aspect(1)
        ax.set_yticks(np.arange(size) + 0.5, minor=False)
        ax.set_xticks(np.arange(size) + 0.5, minor=False)
        ax.set_xticklabels(labels, minor=False, fontsize=FontSize, rotation=45)
        ax.set_yticklabels(labels, minor=False, fontsize=FontSize)
        ax.grid(False)
        ax.set_frame_on(False)
        # Hide tick marks. Assigning Tick.tick1On/tick2On was deprecated in
        # matplotlib 3.1 and later silently does nothing; hide the line
        # artists instead.
        for tick in ax.xaxis.get_major_ticks() + ax.yaxis.get_major_ticks():
            tick.tick1line.set_visible(False)
            tick.tick2line.set_visible(False)
        plt.title(custom_title if custom_title else default_title)
        plt.show()

    Method = ui.cbMethod.currentData()
    LossType = ui.cbLossType.currentData()
    Optim = ui.cbOptim.currentData()

    # Numeric hyper-parameters as (widget, type, error message), parsed in order.
    numeric_fields = (
        (ui.txtIter, np.int32, "Number of iteration is wrong!"),
        (ui.txtBatch, np.int32, "Number of batch is wrong!"),
        (ui.txtReportStep, np.int32, "Number of Report Step is wrong!"),
        (ui.txtRate, np.float32, "Learning rate is wrong!"),
        (ui.txtLParam, np.float32, "Number of Lasso Parameter is wrong!"),
        (ui.txtEL1, np.float32, "Number of Elastic Lambda 1 is wrong!"),
        (ui.txtEL2, np.float32, "Number of Elastic Lambda 2 is wrong!"),
        (ui.txtRRP, np.float32, "Number of Ridge Regression Parameter is wrong!"),
    )
    parsed = []
    for widget, cast, error in numeric_fields:
        try:
            parsed.append(cast(widget.text()))
        except Exception:
            return _critical(error)
    (Epoch, BatchSize, ReportStep, LearningRate,
     LassoAlpha, ElasticLambda1, ElasticAlpha, RidgeReg) = parsed

    if not ui.cbCov.isChecked() and not ui.cbCorr.isChecked():
        return _critical("At least, you must select one metric!")

    # Filter: optional class labels to drop, e.g. "[1, 2]" or "1 2".
    try:
        Filter = ui.txtFilter.text()
        if not Filter:
            Filter = None
        else:
            Filter = np.int32(Filter.replace("\'", " ").replace(",", " ")
                              .replace("[", "").replace("]", "").split())
    except Exception:
        print("Filter is wrong!")
        return _critical("Filter is wrong!")

    # OutFile
    OutFile = ui.txtOutFile.text()
    if not OutFile:
        return _critical("Please enter out file!")
    OutData = dict()

    # InFile
    InFile = ui.txtInFile.text()
    if not InFile:
        return _critical("Please enter input file!")
    if not os.path.isfile(InFile):
        return _critical("Input file not found!")
    print("Loading ...")
    InData = io.loadmat(InFile)

    # Data / Label / Design variable names must be filled in.
    if not ui.txtData.currentText():
        return _critical("Please enter Input Data variable name!")
    if not ui.txtLabel.currentText():
        return _critical("Please enter Train Label variable name!")
    if not ui.txtDesign.currentText():
        return _critical("Please enter Input Design variable name!")
    try:
        Design = InData[ui.txtDesign.currentText()]
    except Exception:
        return _critical("Design value is wrong!")

    # Condition: entry [1][0] of each row is used as a plot / dendrogram label.
    if not ui.txtCond.currentText():
        return _critical("Please enter Condition variable name!")
    try:
        Cond = InData[ui.txtCond.currentText()]
        OutData[ui.txtCond.currentText()] = Cond
        labels = np.array([con[1][0] for con in Cond])
    except Exception:
        return _critical("Condition value is wrong!")

    FontSize = ui.txtFontSize.value()
    try:
        X = InData[ui.txtData.currentText()]
        L = InData[ui.txtLabel.currentText()][0]
        # Whole-data scaling; per-fold scaling (rbScale) happens in the loop.
        if ui.cbScale.isChecked() and not ui.rbScale.isChecked():
            X = preprocessing.scale(X)
            print("Whole of data is scaled X~N(0,1).")
    except Exception:
        print("Cannot load data or label")
        return _critical("Cannot load data or label")

    # Task: code each sample as a 1-based index into np.unique(TaskTitle);
    # samples without a match keep code 0.
    TaskTitle = _load_row(ui.txtTask, "Please enter Task variable name!",
                          "Task variable name is wrong!")
    if TaskTitle is None:
        return False
    TaskTitleUnique = np.unique(TaskTitle)
    Task = np.zeros(np.shape(TaskTitle))
    for sample, title in enumerate(TaskTitle):
        for code, unique_title in enumerate(TaskTitleUnique, start=1):
            if title[0] == unique_title:
                Task[sample] = code
                break

    Sub = _load_row(ui.txtSubject, "Please enter Subject variable name!",
                    "Subject variable name is wrong!")
    if Sub is None:
        return False
    Run = _load_row(ui.txtRun, "Please enter Run variable name!",
                    "Run variable name is wrong!")
    if Run is None:
        return False
    Con = _load_row(ui.txtCounter, "Please enter Counter variable name!",
                    "Counter variable name is wrong!")
    if Con is None:
        return False

    if Filter is not None:
        for fil in Filter:
            # Drop every sample carrying this class label from all arrays.
            labelIndx = np.where(L == fil)[0]
            Design, X, L, Task, Sub, Run, Con = (
                np.delete(arr, labelIndx, axis=0)
                for arr in (Design, X, L, Task, Sub, Run, Con))
            print("Class ID = " + str(fil) + " is removed from data.")

    try:
        Unit = np.int32(ui.txtUnit.text())
    except Exception:
        return _critical("Unit for the test set must be a number!")
    if Unit < 1:
        return _critical("Unit for the test set must be greater than zero!")

    # Without this guard an empty selection fails later inside GrRSA.fit or
    # with a misleading "Unit must be smaller ..." message.
    if np.shape(X)[0] == 0:
        return _critical("The selected data is empty!")

    print("Calculating Levels ...")
    GroupFold = None
    FoldStr = ""
    if ui.cbFSubject.isChecked():
        if not ui.rbFRun.isChecked():
            GroupFold = [Sub]
            FoldStr = "Subject"
        else:
            GroupFold = np.concatenate(([Sub], [Run]))
            FoldStr = "Subject+Run"
    if ui.cbFTask.isChecked():
        GroupFold = np.concatenate((GroupFold, [Task])) if GroupFold is not None else [Task]
        FoldStr = FoldStr + "+Task"
    if ui.cbFCounter.isChecked():
        GroupFold = np.concatenate((GroupFold, [Con])) if GroupFold is not None else [Con]
        FoldStr = FoldStr + "+Counter"

    if FoldStr == "":
        # No grouping selected: one level / one fold covering every sample.
        FoldStr = "Whole-Data"
        GUFold = [1]
        ListFold = [1]
        UniqFold = [1]
        GroupFold = [1]
        UnitFold = np.ones((1, np.shape(X)[0]))
    else:
        # One row per sample, one column per grouping variable; each distinct
        # row is a level, numbered from 1.
        GroupFold = np.transpose(GroupFold)
        UniqFold = np.array(list(set(tuple(i) for i in GroupFold.tolist())))
        FoldIDs = np.arange(len(UniqFold)) + 1
        if len(UniqFold) <= Unit:
            return _critical(
                "Unit must be smaller than all possible levels! Number of all levels is: "
                + str(len(UniqFold)))
        if np.mod(len(UniqFold), Unit):
            return _critical(
                "Unit must be a divisor of the number of all levels! Number of all levels is: "
                + str(len(UniqFold)))
        # Level ID of every sample via a dict lookup, replacing the linear scan
        # over UniqFold.
        level_of = {tuple(row): fid for row, fid in zip(UniqFold.tolist(), FoldIDs)}
        ListFold = np.int32([level_of[tuple(row)] for row in GroupFold.tolist()])
        if Unit == 1:
            UnitFold = np.int32(ListFold)
        else:
            # Merge consecutive level IDs into folds of `Unit` levels each.
            UnitFold = np.int32((ListFold - 0.1) / Unit) + 1
        GUFold = np.unique(UnitFold)

    FoldInfo = dict()
    FoldInfo["Unit"] = Unit
    FoldInfo["Group"] = GroupFold
    FoldInfo["Order"] = FoldStr
    FoldInfo["List"] = ListFold
    FoldInfo["Unique"] = UniqFold
    FoldInfo["Folds"] = UnitFold
    OutData["FoldInfo"] = FoldInfo
    OutData["ModelAnalysis"] = "PyTorch.Group.Gradient.RSA." + ui.cbMethod.currentText()
    print("Number of all levels is: " + str(len(UniqFold)))

    # RSA Method
    OutData['Method'] = {
        'Method': Method,
        'LossType': LossType,
        'Optimization': Optim,
        'LearningRate': LearningRate,
        'Epoch': Epoch,
        'BatchSize': BatchSize,
        'ReportStep': ReportStep,
        'RidgeAlpha': RidgeReg,
        'ElaticLambda1': ElasticLambda1,
        'ElaticAlpha': ElasticAlpha,
        'LassoAlpha': LassoAlpha,
        'Verbose': ui.cbVerbose.isChecked(),
    }

    Cov = None
    Corr = None
    Beta = None
    AMSE = list()
    APer = list()
    # Number of matrices actually accumulated below; with Unit > 1 this is
    # smaller than len(UniqFold).
    NumFolds = len(GUFold)
    for foldID, fold in enumerate(GUFold):
        print("Analyzing level " + str(foldID + 1), " of ", str(NumFolds), " ...")
        Index = np.where(UnitFold == fold)
        if FoldStr == "Whole-Data":
            # UnitFold is (1, n) here, so keep only the column indices. It must
            # be a tuple: NumPy >= 1.23 treats a list holding an array as one
            # fancy index and would add a leading axis to XLi.
            Index = (Index[1],)
        XLi = X[Index]
        RegLi = Design[Index]
        if ui.cbScale.isChecked() and ui.rbScale.isChecked():
            XLi = preprocessing.scale(XLi)
            print("Whole of data is scaled X%d~N(0,1)." % (foldID + 1))
        print("Running Gradient RSA ...")
        rsa = GrRSA(method=Method, loss_type=LossType, optim=Optim, learning_rate=LearningRate,
                    epoch=Epoch, batch_size=BatchSize, report_step=ReportStep,
                    ridge_param=RidgeReg, elstnet_l1_ratio=ElasticLambda1,
                    elstnet_alpha=ElasticAlpha, lasso_alpha=LassoAlpha,
                    verbose=ui.cbVerbose.isChecked(),
                    gpu_enable=ui.cbDevice.currentData(), normalization=False)
        BetaLi, EpsLi, loss_vec, MSE, Performance, _ = rsa.fit(
            data_vals=XLi, design_vals=RegLi)
        # NOTE(review): overwritten every fold, so only the last fold's loss
        # history ends up in the output file.
        OutData["LossVec"] = loss_vec
        print("Calculating MSE for level %d ..." % (foldID + 1))
        print("MSE%d: %f" % (foldID + 1, MSE))
        print("Performance%d: %f" % (foldID + 1, Performance))
        OutData["MSE" + str(foldID)] = MSE
        # Previously stored MSE under the Performance key by mistake.
        OutData["Performance" + str(foldID)] = Performance
        AMSE.append(MSE)
        APer.append(Performance)
        Beta = BetaLi if Beta is None else Beta + BetaLi
        if ui.cbBeta.isChecked():
            OutData["BetaL" + str(foldID + 1)] = BetaLi
            OutData["EpsL" + str(foldID + 1)] = EpsLi
        # Calculate Results
        if ui.cbCorr.isChecked():
            print("Calculating Correlation for level %d ..." % (foldID + 1))
            CorrLi = np.corrcoef(BetaLi)
            OutData["Corr" + str(foldID + 1)] = CorrLi
            Corr = _aggregate(Corr, CorrLi)
        if ui.cbCov.isChecked():
            print("Calculating Covariance for level %d ..." % (foldID + 1))
            CovLi = np.cov(BetaLi)
            OutData["Cov" + str(foldID + 1)] = CovLi
            Cov = _aggregate(Cov, CovLi)

    # "Average" divides the accumulated sum by the number of folds summed,
    # not by the number of raw levels.
    if ui.cbCov.isChecked():
        if ui.rbAvg.isChecked():
            Cov = Cov / NumFolds
        _store_summary("Covariance", Cov)
    if ui.cbCorr.isChecked():
        if ui.rbAvg.isChecked():
            Corr = Corr / NumFolds
        _store_summary("Correlation", Corr)

    OutData["MSE"] = np.mean(AMSE)
    OutData["MSE_std"] = np.std(AMSE)
    OutData["Performance"] = np.mean(APer)
    OutData["Performance_std"] = np.std(APer)

    # Distance matrix: 1 - <beta_i, beta_j> between rows of the fold-summed
    # betas, with zero diagonal and exact symmetry.
    B = np.asarray(Beta)
    dis = np.triu(1 - np.matmul(B, B.T), 1)
    dis = dis + dis.T
    OutData["DistanceMatrix"] = dis
    # NOTE(review): scipy's linkage() treats a 2-D array as observation
    # vectors, not as a square distance matrix (that would need
    # squareform(dis)). Left unchanged to keep saved results stable.
    Z = linkage(dis)
    OutData["Linkage"] = Z
    print("Average MSE: %f" % (OutData["MSE"]))
    OutData["RunTime"] = time.time() - tStart
    print("Runtime (s): %f" % (OutData["RunTime"]))
    print("Saving results ...")
    io.savemat(OutFile, mdict=OutData, do_compression=True)
    print("Output is saved.")

    if ui.cbDiagram.isChecked():
        if ui.cbCorr.isChecked():
            _plot_matrix(Corr, ui.txtTitleCorr.text(),
                         'Group MP Gradient RSA: Correlation\nLevel: ' + FoldStr)
        if ui.cbCov.isChecked():
            _plot_matrix(Cov, ui.txtTitleCov.text(),
                         'Group MP Gradient RSA: Covariance\nLevel: ' + FoldStr)
        plt.figure(figsize=(25, 10))
        if ui.txtTitleDen.text():
            plt.title(ui.txtTitleDen.text())
        else:
            plt.title('Group MP Gradient RSA: Similarity Analysis\nLevel: ' + FoldStr)
        dendrogram(Z, labels=labels, leaf_font_size=FontSize, color_threshold=1)
        plt.show()

    print("DONE.")
    msgBox.setText("Group Representational Similarity Analysis (RSA) is done.")
    msgBox.setIcon(QMessageBox.Information)
    msgBox.setStandardButtons(QMessageBox.Ok)
    msgBox.exec_()
def btnConvert_click(self):
    """Run session-level Gradient RSA with the options entered in the form.

    All parameters are read from the module-level ``ui`` form. The first
    invalid or missing value is reported in a critical message box and the
    method returns ``False``. Otherwise the samples of the selected task,
    subject, counter and run are fitted with ``GrRSA``; correlation and/or
    covariance of the betas, a ``1 - <beta_i, beta_j>`` distance matrix and its
    linkage are stored in ``OutData`` and written with ``mainIO_save``. When
    "Diagram" is checked, the matrices and a dendrogram are plotted.
    """
    msgBox = QMessageBox()
    tStart = time.time()
    # Exceptions that NumPy scalar constructors raise on malformed text.
    _NUM_ERRORS = (TypeError, ValueError, OverflowError)

    def _error(text):
        # Report a validation failure. It returns False so callers can write
        # ``return _error(...)``.
        msgBox.setText(text)
        msgBox.setIcon(QMessageBox.Critical)
        msgBox.setStandardButtons(QMessageBox.Ok)
        msgBox.exec_()
        return False

    def _plot_matrix(mat, title, tick_labels, font_size):
        # Labelled heat map of a square class-by-class similarity matrix.
        NumData = np.shape(mat)[0]
        plt.figure(num=None, figsize=(NumData, NumData), dpi=100)
        plt.pcolor(mat, vmin=np.min(mat), vmax=np.max(mat))
        plt.xlim([0, NumData])
        plt.ylim([0, NumData])
        cbar = plt.colorbar()
        cbar.ax.tick_params(labelsize=font_size)
        ax = plt.gca()
        ax.invert_yaxis()
        ax.set_aspect(1)
        ax.set_yticks(np.arange(NumData) + 0.5, minor=False)
        ax.set_xticks(np.arange(NumData) + 0.5, minor=False)
        ax.set_xticklabels(tick_labels, minor=False, fontsize=font_size, rotation=45)
        ax.set_yticklabels(tick_labels, minor=False, fontsize=font_size)
        ax.grid(False)
        ax.set_frame_on(False)
        # Hide the tick marks but keep their labels. The old approach assigned
        # Tick.tick1On / tick2On, which recent matplotlib silently ignores.
        ax.tick_params(which="major", length=0)
        plt.title(title)
        plt.show()

    Method = ui.cbMethod.currentData()
    LossType = ui.cbLossType.currentData()
    Optim = ui.cbOptim.currentData()

    # ---- Numeric optimisation parameters --------------------------------
    try:
        Epoch = np.int32(ui.txtIter.text())
    except _NUM_ERRORS:
        return _error("Number of iteration is wrong!")
    try:
        BatchSize = np.int32(ui.txtBatch.text())
    except _NUM_ERRORS:
        return _error("Number of batch is wrong!")
    try:
        ReportStep = np.int32(ui.txtReportStep.text())
    except _NUM_ERRORS:
        return _error("Number of Report Step is wrong!")
    try:
        LearningRate = np.float32(ui.txtRate.text())
    except _NUM_ERRORS:
        return _error("Learning rate is wrong!")
    try:
        LassoAlpha = np.float32(ui.txtLParam.text())
    except _NUM_ERRORS:
        return _error("Number of Lasso Parameter is wrong!")
    try:
        ElasticLambda1 = np.float32(ui.txtEL1.text())
    except _NUM_ERRORS:
        return _error("Number of Elastic Lambda 1 is wrong!")
    try:
        ElasticAlpha = np.float32(ui.txtEL2.text())
    except _NUM_ERRORS:
        return _error("Number of Elastic Lambda 2 is wrong!")
    try:
        RidgeReg = np.float32(ui.txtRRP.text())
    except _NUM_ERRORS:
        return _error("Number of Ridge Regression Parameter is wrong!")

    if not ui.cbCov.isChecked() and not ui.cbCorr.isChecked():
        return _error("At least, you must select one metric!")

    # ---- Filter: optional class IDs to drop, e.g. "1, 2" or "[1 2]" -----
    Filter = ui.txtFilter.text()
    if not len(Filter):
        Filter = None
    else:
        try:
            Filter = np.int32(
                Filter.replace("'", " ").replace(",", " ")
                .replace("[", "").replace("]", "").split())
        except _NUM_ERRORS:
            return _error("Filter is wrong!")

    # ---- Files -----------------------------------------------------------
    OutFile = ui.txtOutFile.text()
    if not len(OutFile):
        return _error("Please enter out file!")
    OutData = dict()

    InFile = ui.txtInFile.text()
    if not len(InFile):
        return _error("Please enter input file!")
    if not os.path.isfile(InFile):
        return _error("Input file not found!")
    print("Loading ...")
    InData = mainIO_load(InFile)

    # ---- Data, label and design variables --------------------------------
    if not len(ui.txtData.currentText()):
        return _error("Please enter Input Data variable name!")
    if not len(ui.txtLabel.currentText()):
        return _error("Please enter Train Label variable name!")
    if not len(ui.txtDesign.currentText()):
        return _error("Please enter Input Design variable name!")
    try:
        Design = InData[ui.txtDesign.currentText()]
    except Exception:
        return _error("Design value is wrong!")

    # ---- Condition names, used as plot labels ----------------------------
    if not len(ui.txtCond.currentText()):
        return _error("Please enter Condition variable name!")
    try:
        Cond = InData[ui.txtCond.currentText()]
        OutData[ui.txtCond.currentText()] = Cond
        labels = [reshape_condition_cell(con[1]) for con in Cond]
    except Exception:
        return _error("Condition value is wrong!")

    FontSize = ui.txtFontSize.value()

    try:
        X = InData[ui.txtData.currentText()]
        L = InData[ui.txtLabel.currentText()][0]
    except Exception:
        return _error("Cannot load data or label")

    # ---- Task: encode each task title as a 1-based integer ---------------
    if not len(ui.txtTask.currentText()):
        return _error("Please enter Task variable name!")
    if not len(ui.txtTaskVal.currentText()):
        return _error("Please enter Task value!")
    TaskIDTitle = ui.txtTaskVal.currentText()
    try:
        TaskTitle = np.array(InData[ui.txtTask.currentText()][0])
    except Exception:
        return _error("Task variable name is wrong!")
    TaskTitleUnique = np.unique(TaskTitle)
    Task = np.zeros(np.shape(TaskTitle))
    for ttinx, tt in enumerate(TaskTitle):
        for ttlinx, ttl in enumerate(TaskTitleUnique):
            if tt[0] == ttl:
                Task[ttinx] = ttlinx + 1
                break
    TaskID = None
    for ttlinx, ttl in enumerate(TaskTitleUnique):
        if TaskIDTitle == ttl:
            TaskID = ttlinx + 1
            break
    if TaskID is None:
        # Without this check an unknown task left TaskID unbound, and the
        # later selection raised NameError.
        return _error("Task value is wrong!")
    OutData["Task"] = TaskIDTitle

    # ---- Subject ---------------------------------------------------------
    if not len(ui.txtSubject.currentText()):
        return _error("Please enter Subject variable name!")
    if not len(ui.txtSubjectVal.currentText()):
        return _error("Please enter Subject value!")
    try:
        SubID = np.int32(ui.txtSubjectVal.currentText())
    except _NUM_ERRORS:
        return _error("Subject value is wrong!")
    try:
        Sub = InData[ui.txtSubject.currentText()][0]
    except Exception:
        return _error("Subject variable name is wrong!")
    OutData["SubjectID"] = SubID

    # ---- Run -------------------------------------------------------------
    if not len(ui.txtRun.currentText()):
        return _error("Please enter Run variable name!")
    if not len(ui.txtRunVal.currentText()):
        return _error("Please enter Run value!")
    try:
        RunID = np.int32(ui.txtRunVal.currentText())
    except _NUM_ERRORS:
        return _error("Run value is wrong!")
    try:
        Run = InData[ui.txtRun.currentText()][0]
    except Exception:
        return _error("Run variable name is wrong!")
    OutData["RunID"] = RunID

    # ---- Counter ---------------------------------------------------------
    if not len(ui.txtCounter.currentText()):
        return _error("Please enter Counter variable name!")
    if not len(ui.txtCounterVal.currentText()):
        return _error("Please enter Counter value!")
    try:
        ConID = np.int32(ui.txtCounterVal.currentText())
    except _NUM_ERRORS:
        return _error("Counter value is wrong!")
    try:
        Con = InData[ui.txtCounter.currentText()][0]
    except Exception:
        return _error("Counter variable name is wrong!")
    OutData["CounterID"] = ConID

    # ---- Remove filtered classes from every per-sample array -------------
    if Filter is not None:
        for fil in Filter:
            labelIndx = np.where(L == fil)[0]
            Design = np.delete(Design, labelIndx, axis=0)
            X = np.delete(X, labelIndx, axis=0)
            L = np.delete(L, labelIndx, axis=0)
            Task = np.delete(Task, labelIndx, axis=0)
            Sub = np.delete(Sub, labelIndx, axis=0)
            Run = np.delete(Run, labelIndx, axis=0)
            Con = np.delete(Con, labelIndx, axis=0)
            print("Class ID = " + str(fil) + " is removed from data.")

    # ---- Keep only the selected task, subject, counter and run -----------
    Index = np.where(Task == TaskID)[0]
    Design, X, L = Design[Index], X[Index], L[Index]
    Sub, Run, Con = Sub[Index], Run[Index], Con[Index]

    Index = np.where(Sub == SubID)[0]
    Design, X, L = Design[Index], X[Index], L[Index]
    Run, Con = Run[Index], Con[Index]

    Index = np.where(Con == ConID)[0]
    Design, X, L = Design[Index], X[Index], L[Index]
    Run = Run[Index]

    Index = np.where(Run == RunID)[0]
    Design, X, L = Design[Index], X[Index], L[Index]

    # The label list is only meaningful for supervised methods.
    OutData["Label"] = np.unique(L)
    OutData["ModelAnalysis"] = "PyTorch.Session.Gradient.RSA." + ui.cbMethod.currentText()

    if np.shape(X)[0] == 0:
        return _error("The selected data is empty!")

    if ui.cbScale.isChecked():
        X = preprocessing.scale(X)
        print("Data is scaled to N(0,1).")

    print("Running Gradient RSA ...")
    OutData['Method'] = {
        'Method': Method,
        'LossType': LossType,
        'Optimization': Optim,
        'LearningRate': LearningRate,
        'Epoch': Epoch,
        'BatchSize': BatchSize,
        'ReportStep': ReportStep,
        'RidgeAlpha': RidgeReg,
        # (sic) "Elatic" spelling kept so the saved output keys stay the same.
        'ElaticLambda1': ElasticLambda1,
        'ElaticAlpha': ElasticAlpha,
        'LassoAlpha': LassoAlpha,
        'Verbose': ui.cbVerbose.isChecked(),
    }
    rsa = GrRSA(method=Method, loss_type=LossType, optim=Optim,
                learning_rate=LearningRate, epoch=Epoch,
                batch_size=BatchSize, report_step=ReportStep,
                ridge_param=RidgeReg, elstnet_l1_ratio=ElasticLambda1,
                elstnet_alpha=ElasticAlpha, lasso_alpha=LassoAlpha,
                verbose=ui.cbVerbose.isChecked(),
                gpu_enable=ui.cbDevice.currentData(), normalization=False)
    Betas, Eps, loss_vec, MSE, Performance, _ = rsa.fit(data_vals=X, design_vals=Design)
    OutData["LossVec"] = loss_vec
    OutData["MSE"] = MSE
    OutData["Performance"] = Performance
    print("MSE: ", OutData["MSE"])
    if ui.cbBeta.isChecked():
        OutData["Betas"] = Betas
        OutData["Eps"] = Eps

    # ---- Similarity metrics ----------------------------------------------
    if ui.cbCorr.isChecked():
        print("Calculating Correlation ...")
        Corr = np.corrcoef(Betas)
        corClass = SimilarityMatrixBetweenClass(Corr)
        OutData["Correlation"] = Corr
        OutData["Correlation_min"] = corClass.min()
        OutData["Correlation_max"] = corClass.max()
        OutData["Correlation_std"] = corClass.std()
        OutData["Correlation_mean"] = corClass.mean()
    if ui.cbCov.isChecked():
        print("Calculating Covariance ...")
        Cov = np.cov(Betas)
        covClass = SimilarityMatrixBetweenClass(Cov)
        OutData["Covariance"] = Cov
        OutData["Covariance_min"] = covClass.min()
        OutData["Covariance_max"] = covClass.max()
        OutData["Covariance_std"] = covClass.std()
        OutData["Covariance_mean"] = covClass.mean()

    # Symmetric 1 - <beta_i, beta_j> matrix; the diagonal stays 0.
    NumBeta = np.shape(Betas)[0]
    dis = np.zeros((NumBeta, NumBeta))
    for i in range(NumBeta):
        for j in range(i + 1, NumBeta):
            dis[i, j] = 1 - np.dot(Betas[i, :], Betas[j, :].T)
            dis[j, i] = dis[i, j]
    OutData["DistanceMatrix"] = dis
    # NOTE(review): scipy's linkage() interprets a square 2-D array as
    # observation vectors, not as a distance matrix. Distance-based clustering
    # would need squareform(dis, checks=False). Kept unchanged for now.
    Z = linkage(dis)
    OutData["Linkage"] = Z

    OutData["RunTime"] = time.time() - tStart
    print("Runtime (s): %f" % (OutData["RunTime"]))
    print("Saving results ...")
    mainIO_save(OutData, OutFile)
    print("Output is saved.")

    # ---- Optional diagrams -----------------------------------------------
    if ui.cbDiagram.isChecked():
        Subtitle = '\nTask: %s\nSub: %d, Counter: %d, Run: %d' % (TaskIDTitle, SubID, ConID, RunID)
        if ui.cbCorr.isChecked():
            Title = ui.txtTitleCorr.text()
            if not len(Title):
                Title = 'Correlation (' + ui.cbMethod.currentText() + ')' + Subtitle
            _plot_matrix(Corr, Title, labels, FontSize)
        if ui.cbCov.isChecked():
            Title = ui.txtTitleCov.text()
            if not len(Title):
                Title = 'Covariance (' + ui.cbMethod.currentText() + ')' + Subtitle
            _plot_matrix(Cov, Title, labels, FontSize)

        plt.figure(figsize=(25, 10))
        Title = ui.txtTitleDen.text()
        if not len(Title):
            Title = 'Similarity Analysis (' + ui.cbMethod.currentText() + ')' + Subtitle
        plt.title(Title)
        dendrogram(Z, labels=labels, leaf_font_size=FontSize, color_threshold=1)
        plt.show()

    print("DONE.")
    msgBox.setText("Gradient Representational Similarity Analysis (RSA) is done.")
    msgBox.setIcon(QMessageBox.Information)
    msgBox.setStandardButtons(QMessageBox.Ok)
    msgBox.exec_()
def btnConvert_click(self):
    """Run group-level single-deep-kernel RSA with the options entered in the form.

    All parameters are read from the module-level ``ui`` form. The first
    invalid or missing value is reported in a critical message box and the
    method returns ``False``. Samples are split into levels (combinations of
    subject/run/task/counter, grouped ``Unit`` at a time). The per-level data
    and design matrices are passed to ``DeepGroupRSA.fit``. The betas, the
    per-level and averaged covariance/correlation matrices and the fold
    bookkeeping are saved with ``io.savemat``. When "Diagram" is checked, the
    averaged matrices are plotted.
    """
    msgBox = QMessageBox()
    tStart = time.time()
    # Exceptions that NumPy scalar constructors raise on malformed text.
    _NUM_ERRORS = (TypeError, ValueError, OverflowError)

    def _error(text):
        # Report a validation failure. It returns False so callers can write
        # ``return _error(...)``.
        msgBox.setText(text)
        msgBox.setIcon(QMessageBox.Critical)
        msgBox.setStandardButtons(QMessageBox.Ok)
        msgBox.exec_()
        return False

    Activation = ui.cbActivation.currentData()
    LossNorm = ui.cbLossNorm.currentData()

    # ---- Network and optimisation parameters ----------------------------
    try:
        Layers = strRange(ui.txtLayers.text(), Unique=False)
    except Exception:
        Layers = None
    if Layers is None:
        return _error("Layers is wrong!")
    try:
        KIter = np.int32(ui.txtKIter.text())
    except _NUM_ERRORS:
        return _error("Number of kernel iteration is wrong!")
    try:
        RIter = np.int32(ui.txtRIter.text())
    except _NUM_ERRORS:
        return _error("Number of RSA iteration is wrong!")
    try:
        BatchSize = np.int32(ui.txtBatch.text())
    except _NUM_ERRORS:
        return _error("Number of batch is wrong!")
    try:
        ReportStep = np.int32(ui.txtReportStep.text())
    except _NUM_ERRORS:
        return _error("Number of Report Step is wrong!")
    try:
        LearningRate = np.float32(ui.txtRate.text())
    except _NUM_ERRORS:
        return _error("Learning rate is wrong!")

    if not ui.cbCov.isChecked() and not ui.cbCorr.isChecked():
        return _error("At least, you must select one metric!")

    # ---- Filter: optional class IDs to drop, e.g. "1, 2" or "[1 2]" -----
    Filter = ui.txtFilter.text()
    if not len(Filter):
        Filter = None
    else:
        try:
            Filter = np.int32(
                Filter.replace("'", " ").replace(",", " ")
                .replace("[", "").replace("]", "").split())
        except _NUM_ERRORS:
            return _error("Filter is wrong!")

    # ---- Files -----------------------------------------------------------
    OutFile = ui.txtOutFile.text()
    if not len(OutFile):
        return _error("Please enter out file!")
    OutData = dict()

    InFile = ui.txtInFile.text()
    if not len(InFile):
        return _error("Please enter input file!")
    if not os.path.isfile(InFile):
        return _error("Input file not found!")
    print("Loading ...")
    InData = io.loadmat(InFile)

    # ---- Data, label and design variables --------------------------------
    if not len(ui.txtData.currentText()):
        return _error("Please enter Input Data variable name!")
    if not len(ui.txtLabel.currentText()):
        return _error("Please enter Train Label variable name!")
    if not len(ui.txtDesign.currentText()):
        return _error("Please enter Input Design variable name!")
    try:
        Design = InData[ui.txtDesign.currentText()]
    except Exception:
        return _error("Design value is wrong!")

    try:
        X = InData[ui.txtData.currentText()]
        L = InData[ui.txtLabel.currentText()][0]
        # "Scale" without the per-level radio button scales the whole data once.
        if ui.cbScale.isChecked() and not ui.rbScale.isChecked():
            X = preprocessing.scale(X)
            print("Whole of data is scaled X~N(0,1).")
    except Exception:
        return _error("Cannot load data or label")

    # ---- Task: encode each task title as a 1-based integer ---------------
    if not len(ui.txtTask.currentText()):
        return _error("Please enter Task variable name!")
    try:
        TaskTitle = InData[ui.txtTask.currentText()][0]
    except Exception:
        return _error("Task variable name is wrong!")
    TaskTitleUnique = np.unique(TaskTitle)
    Task = np.zeros(np.shape(TaskTitle))
    for ttinx, tt in enumerate(TaskTitle):
        for ttlinx, ttl in enumerate(TaskTitleUnique):
            if tt[0] == ttl:
                Task[ttinx] = ttlinx + 1
                break

    # ---- Subject, run and counter ----------------------------------------
    if not len(ui.txtSubject.currentText()):
        return _error("Please enter Subject variable name!")
    try:
        Sub = InData[ui.txtSubject.currentText()][0]
    except Exception:
        return _error("Subject variable name is wrong!")

    if not len(ui.txtRun.currentText()):
        return _error("Please enter Run variable name!")
    try:
        Run = InData[ui.txtRun.currentText()][0]
    except Exception:
        return _error("Run variable name is wrong!")

    if not len(ui.txtCounter.currentText()):
        return _error("Please enter Counter variable name!")
    try:
        Con = InData[ui.txtCounter.currentText()][0]
    except Exception:
        return _error("Counter variable name is wrong!")

    # ---- Remove filtered classes from every per-sample array -------------
    if Filter is not None:
        for fil in Filter:
            labelIndx = np.where(L == fil)[0]
            Design = np.delete(Design, labelIndx, axis=0)
            X = np.delete(X, labelIndx, axis=0)
            L = np.delete(L, labelIndx, axis=0)
            Task = np.delete(Task, labelIndx, axis=0)
            Sub = np.delete(Sub, labelIndx, axis=0)
            Run = np.delete(Run, labelIndx, axis=0)
            Con = np.delete(Con, labelIndx, axis=0)
            print("Class ID = " + str(fil) + " is removed from data.")

    try:
        Unit = np.int32(ui.txtUnit.text())
    except _NUM_ERRORS:
        return _error("Unit for the test set must be a number!")
    if Unit < 1:
        return _error("Unit for the test set must be greater than zero!")

    # ---- Levels: one row of identifying values per sample ----------------
    print("Calculating Levels ...")
    GroupFold = None
    FoldStr = ""
    if ui.cbFSubject.isChecked():
        if not ui.rbFRun.isChecked():
            GroupFold = [Sub]
            FoldStr = "Subject"
        else:
            GroupFold = np.concatenate(([Sub], [Run]))
            FoldStr = "Subject+Run"
    if ui.cbFTask.isChecked():
        GroupFold = np.concatenate((GroupFold, [Task])) if GroupFold is not None else [Task]
        FoldStr = FoldStr + "+Task"
    if ui.cbFCounter.isChecked():
        GroupFold = np.concatenate((GroupFold, [Con])) if GroupFold is not None else [Con]
        FoldStr = FoldStr + "+Counter"

    if FoldStr == "":
        # No grouping selected: the whole data set forms a single level.
        FoldStr = "Whole-Data"
        GUFold = [1]
        ListFold = [1]
        UniqFold = [1]
        GroupFold = [1]
        UnitFold = np.ones((1, np.shape(X)[0]))
    else:
        GroupFold = np.transpose(GroupFold)
        UniqTuples = list(set(tuple(i) for i in GroupFold.tolist()))
        UniqFold = np.array(UniqTuples)
        if len(UniqFold) <= Unit:
            return _error(
                "Unit must be smaller than all possible levels! Number of all levels is: "
                + str(len(UniqFold)))
        if np.mod(len(UniqFold), Unit):
            return _error(
                "Number of all levels must be divisible by Unit! Number of all levels is: "
                + str(len(UniqFold)))
        # 1-based level ID of every sample. The dict replaces the old O(n*m)
        # scan and is keyed on the same tuples that UniqFold was built from.
        LevelID = {level: inx + 1 for inx, level in enumerate(UniqTuples)}
        ListFold = np.int32([LevelID[tuple(g)] for g in GroupFold.tolist()])
        # Merge every `Unit` consecutive level IDs into one unit fold.
        if Unit == 1:
            UnitFold = np.int32(ListFold)
        else:
            UnitFold = np.int32((ListFold - 0.1) / Unit) + 1
        GUFold = np.unique(UnitFold)

    FoldInfo = dict()
    FoldInfo["Unit"] = Unit
    FoldInfo["Group"] = GroupFold
    FoldInfo["Order"] = FoldStr
    FoldInfo["List"] = ListFold
    FoldInfo["Unique"] = UniqFold
    FoldInfo["Folds"] = UnitFold
    OutData["FoldInfo"] = FoldInfo
    OutData["ModelAnalysis"] = "Tensorflow.Group.Single-Deep-Kernel.RSA"
    print("Number of all levels is: " + str(len(UniqFold)))

    OutData['Method'] = {
        'Layers': ui.txtLayers.text(),
        'Activation': Activation,
        'LossNorm': LossNorm,
        'LearningRate': LearningRate,
        'KernelIter': KIter,
        'RSAIter': RIter,
        'BatchSize': BatchSize,
        'ReportStep': ReportStep,
        'Verbose': ui.cbVerbose.isChecked(),
    }

    # ---- Split data and design per unit fold -----------------------------
    TData = list()
    TReg = list()
    print("Reshaping Data ...")
    # Flattening lets the Whole-Data case, whose UnitFold is a (1, n) row,
    # share the 1-D path. The old code indexed with ``[Index[1]]``, a list
    # holding an array. NumPy >= 1.23 treats that as a fancy index, which
    # adds a spurious leading axis.
    FlatUnitFold = np.ravel(UnitFold)
    for foldID, fold in enumerate(GUFold):
        print("Reshaping level " + str(foldID + 1), " of ", str(len(GUFold)), " ...")
        Index = np.where(FlatUnitFold == fold)[0]
        XLi = X[Index]
        RegLi = Design[Index]
        if ui.cbScale.isChecked() and ui.rbScale.isChecked():
            XLi = preprocessing.scale(XLi)
            print("Whole of data is scaled X%d~N(0,1)." % (foldID + 1))
        TData.append(XLi)
        TReg.append(RegLi)

    print("Running Deep Group RSA ...")
    rsa = DeepGroupRSA(layers=Layers, kernel_iter=KIter, rsa_iter=RIter,
                       learning_rate=LearningRate, loss_norm=LossNorm,
                       activation=Activation, batch_size=BatchSize,
                       report_step=ReportStep, verbose=ui.cbVerbose.isChecked(),
                       NCat=np.shape(Design)[1], NVoxel=np.shape(X)[1],
                       CPU=ui.cbDevice.currentData())
    Betas, Eps, Weights, Bias, MSE, loss_mat = rsa.fit(data_vals=TData, design_vals=TReg)
    OutData["Weight"] = Weights
    OutData["Bias"] = Bias
    # (sic) "Perfromance" keys kept so the saved output keys stay the same.
    OutData["Perfromance"] = MSE
    OutData["Perfromance_Average"] = rsa.AMSE
    OutData["Perfromance_std"] = np.std(rsa.AMSE)
    OutData["LossMat"] = loss_mat
    print("Average Performance: %f" % (OutData["Perfromance"]))

    # ---- Per-level and averaged similarity matrices ----------------------
    print("Calculating cov & corr ... ")
    AvgCov = None
    AvgCorr = None
    for beta_id, beta in enumerate(Betas):
        OutData["BetaL" + str(beta_id)] = beta
        if ui.cbCov.isChecked():
            co = np.cov(beta)
            OutData["Cov" + str(beta_id)] = co
            AvgCov = co if AvgCov is None else AvgCov + co
        if ui.cbCorr.isChecked():
            cr = np.corrcoef(beta)
            OutData["Corr" + str(beta_id)] = cr
            AvgCorr = cr if AvgCorr is None else AvgCorr + cr
    for eps_id, ep in enumerate(Eps):
        OutData["EpsL" + str(eps_id)] = ep

    if ui.cbCov.isChecked():
        AvgCov = AvgCov / len(TData)
        covClass = SimilarityMatrixBetweenClass(AvgCov)
        OutData["Covariance"] = AvgCov
        OutData["Covariance_min"] = covClass.min()
        OutData["Covariance_max"] = covClass.max()
        OutData["Covariance_mean"] = covClass.mean()
        OutData["Covariance_std"] = covClass.std()
    if ui.cbCorr.isChecked():
        AvgCorr = AvgCorr / len(TData)
        corClass = SimilarityMatrixBetweenClass(AvgCorr)
        OutData["Correlation"] = AvgCorr
        OutData["Correlation_min"] = corClass.min()
        OutData["Correlation_max"] = corClass.max()
        OutData["Correlation_mean"] = corClass.mean()
        OutData["Correlation_std"] = corClass.std()

    OutData["RunTime"] = time.time() - tStart
    print("Runtime (s): %f" % (OutData["RunTime"]))
    print("Saving results ...")
    io.savemat(OutFile, mdict=OutData, do_compression=True)
    print("Output is saved.")

    # ---- Optional diagrams -----------------------------------------------
    if ui.cbDiagram.isChecked():
        if ui.cbCorr.isChecked():
            plt.figure(num=None, figsize=(5, 5), dpi=100)
            plt.pcolor(AvgCorr, vmin=-0.1, vmax=1)
            plt.xlim([0, np.shape(AvgCorr)[0]])
            plt.ylim([0, np.shape(AvgCorr)[0]])
            plt.colorbar()
            ax = plt.gca()
            ax.set_aspect(1)
            plt.title('Deep Group RSA: Correlation\nLevel: ' + FoldStr)
            plt.show()
        if ui.cbCov.isChecked():
            plt.figure(num=None, figsize=(5, 5), dpi=100)
            plt.pcolor(AvgCov)
            plt.xlim([0, np.shape(AvgCov)[0]])
            plt.ylim([0, np.shape(AvgCov)[0]])
            plt.colorbar()
            ax = plt.gca()
            ax.set_aspect(1)
            plt.title('Deep Group RSA: Covariance\nLevel: ' + FoldStr)
            plt.show()

    print("DONE.")
    msgBox.setText(
        "Group Level Single-Deep-Kernel Representational Similarity Analysis is done.")
    msgBox.setIcon(QMessageBox.Information)
    msgBox.setStandardButtons(QMessageBox.Ok)
    msgBox.exec_()
def btnConvert_click(self):
    """Run group-level representational similarity analysis with BRSA/GBRSA.

    All settings and variable names are read from the global ``ui`` form.
    The input file is loaded with ``mainIO_load``; samples are split into
    levels by the selected subject/run/task/counter variables, merged into
    folds of ``Unit`` levels, and one BRSA (``Method == "brsa"``) or GBRSA
    model is fitted per fold.  Per-fold correlation/covariance matrices are
    combined by average, minimum or maximum (radio buttons), summarised and
    written with ``mainIO_save``; optional figures show the matrices and a
    dendrogram built from the summed per-fold betas.

    Returns:
        False after showing an error dialog for invalid input or a failed
        fit; None after a successful run.
    """
    msgBox = QMessageBox()
    tStart = time.time()

    def show_error(text):
        # Show a modal critical dialog; returns False so callers can write
        # ``return show_error(...)``.
        msgBox.setText(text)
        msgBox.setIcon(QMessageBox.Critical)
        msgBox.setStandardButtons(QMessageBox.Ok)
        msgBox.exec_()
        return False

    def read_value(widget, cast, is_valid=None):
        # Convert widget.text() with ``cast``; raise ValueError when the text
        # cannot be converted or ``is_valid`` rejects the converted value.
        value = cast(widget.text())
        if is_valid is not None and not is_valid(value):
            raise ValueError("value out of range: %r" % (value,))
        return value

    if not ui.cbCov.isChecked() and not ui.cbCorr.isChecked():
        return show_error("At least, you must select one metric!")

    Method = ui.cbMethod.currentData()
    NuregMethod = ui.cbNuregMethod.currentData()
    Tau2Prior = ui.cbTau2Prior.currentData()
    SNRPrior = ui.cbSNRPrior.currentData()
    GPS = ui.cbGBS.isChecked()
    GPI = ui.cbGPI.isChecked()
    BaselineSingle = ui.cbBaselineSingle.isChecked()
    AutoNuisance = ui.cbAutoNuisance.isChecked()
    NuregZscore = ui.cbNuregZscore.isChecked()

    # Built-in int/float replace the np.int/np.float aliases, which were
    # removed in NumPy 1.24 (there, every field below was reported "wrong").
    # Range checks raise ValueError instead of relying on ``assert``, which
    # is stripped when Python runs with -O.
    try:
        miter = read_value(ui.txtMaxIter, int, lambda v: v >= 1)
    except ValueError:
        return show_error("Max Iteration is wrong!")
    try:
        iiter = read_value(ui.txtInitIter, int, lambda v: v >= 1)
    except ValueError:
        return show_error("Init Iteration is wrong!")
    try:
        Speed = read_value(ui.txtAnnealSpeed, int, lambda v: v >= 1)
    except ValueError:
        return show_error("Anneal speed is wrong!")
    try:
        # A rank of 0 is passed on to the model as None.
        rank = read_value(ui.txtRank, int, lambda v: v >= 0) or None
    except ValueError:
        return show_error("Rank is wrong!")
    try:
        # A regressor count of 0 is passed on to the model as None.
        NumReg = read_value(ui.txtNumReg, int, lambda v: v >= 0) or None
    except ValueError:
        return show_error("Number of Reg is wrong!")
    try:
        SNRBin = read_value(ui.txtSNRBins, int, lambda v: v > 0)
    except ValueError:
        return show_error("SNR bin is wrong!")
    try:
        rhoBin = read_value(ui.txtRhoBins, int, lambda v: v > 0)
    except ValueError:
        return show_error("Rho bin is wrong!")
    try:
        tol = read_value(ui.txtTole, float)
    except ValueError:
        return show_error("Tolerance is wrong!")
    try:
        eta = read_value(ui.txtEta, float)
    except ValueError:
        return show_error("Eta is wrong!")
    try:
        LogSRange = read_value(ui.txtLogSRange, float)
    except ValueError:
        return show_error("LogS range is wrong!")
    try:
        # 0 is passed on to BRSA as None.
        SpaceSmoothRange = read_value(ui.txtSpaceSmoothRange, float) or None
    except ValueError:
        return show_error("Space Smooth Range is wrong!")
    try:
        # 0 is passed on to BRSA as None.
        IntenSmoothRange = read_value(ui.txtIntenSmoothRange, float) or None
    except ValueError:
        return show_error("Inten Smooth Range is wrong!")
    try:
        Tau = read_value(ui.txtTauRange, float)
    except ValueError:
        return show_error("Tau is wrong!")

    # The optimizer settings come from the user-editable txtEvents snippet,
    # which must define ``optimizer`` and ``minimize_options``.
    # SECURITY NOTE: exec() runs arbitrary code with this function's locals
    # and the module globals; acceptable only because the text is typed into
    # the local GUI by the user running it.
    try:
        CodeText = ui.txtEvents.toPlainText()
        allvars = dict(locals(), **globals())
        exec(CodeText, allvars, allvars)
        Optimizer = allvars['optimizer']
        MinimizeOptions = allvars['minimize_options']
    except Exception as e:
        return show_error("Optimizer is wrong!\n" + str(e))

    # Filter: optional list of class IDs to drop, e.g. "1, 2" or "[1 2]".
    try:
        Filter = ui.txtFilter.text()
        if not Filter:
            Filter = None
        else:
            Filter = Filter.replace("\'", " ").replace(",", " ").replace(
                "[", "").replace("]", "").split()
            Filter = np.int32(Filter)
    except (ValueError, OverflowError):
        return show_error("Filter is wrong!")

    # OutFile
    OutFile = ui.txtOutFile.text()
    if not OutFile:
        return show_error("Please enter out file!")
    OutData = dict()

    # InFile
    InFile = ui.txtInFile.text()
    if not InFile:
        return show_error("Please enter input file!")
    if not os.path.isfile(InFile):
        return show_error("Input file not found!")
    print("Loading ...")
    InData = mainIO_load(InFile)

    if not ui.txtData.currentText():
        return show_error("Please enter Input Data variable name!")
    if not ui.txtLabel.currentText():
        return show_error("Please enter Train Label variable name!")
    if not ui.txtDesign.currentText():
        return show_error("Please enter Input Design variable name!")
    try:
        Design = InData[ui.txtDesign.currentText()]
    except Exception:
        return show_error("Design value is wrong!")

    # Condition
    if not ui.txtCond.currentText():
        return show_error("Please enter Condition variable name!")
    try:
        Cond = InData[ui.txtCond.currentText()]
        OutData[ui.txtCond.currentText()] = Cond
        labels = [reshape_condition_cell(con[1]) for con in Cond]
    except Exception:
        return show_error("Condition value is wrong!")

    FontSize = ui.txtFontSize.value()

    try:
        X = InData[ui.txtData.currentText()]
        # Column means of the raw (unscaled, unfiltered) data; passed to
        # BRSA.fit as ``inten`` below.
        Intensity = np.mean(X, axis=0)
        L = InData[ui.txtLabel.currentText()][0]
        if ui.cbScale.isChecked() and not ui.rbScale.isChecked():
            X = preprocessing.scale(X)
            print("Whole of data is scaled X~N(0,1).")
    except Exception:
        return show_error("Cannot load data or label")

    # Task
    if not ui.txtTask.currentText():
        return show_error("Please enter Task variable name!")
    try:
        TaskTitle = np.array(InData[ui.txtTask.currentText()][0])
    except Exception:
        return show_error("Task variable name is wrong!")
    # Encode each task title as its 1-based position among np.unique's
    # sorted titles (entries are expected to wrap the title at index 0).
    TaskTitleUnique = np.unique(TaskTitle)
    Task = np.zeros(np.shape(TaskTitle))
    for ttinx, tt in enumerate(TaskTitle):
        for ttlinx, ttl in enumerate(TaskTitleUnique):
            if tt[0] == ttl:
                Task[ttinx] = ttlinx + 1
                break

    # Subject
    if not ui.txtSubject.currentText():
        return show_error("Please enter Subject variable name!")
    try:
        Sub = InData[ui.txtSubject.currentText()][0]
    except Exception:
        return show_error("Subject variable name is wrong!")

    # Run
    if not ui.txtRun.currentText():
        return show_error("Please enter Run variable name!")
    try:
        Run = InData[ui.txtRun.currentText()][0]
    except Exception:
        return show_error("Run variable name is wrong!")

    # Counter
    if not ui.txtCounter.currentText():
        return show_error("Please enter Counter variable name!")
    try:
        Con = InData[ui.txtCounter.currentText()][0]
    except Exception:
        return show_error("Counter variable name is wrong!")

    # Optional coordinates (transposed before being passed to BRSA.fit).
    try:
        if ui.txtCoord.currentText():
            Coord = np.transpose(InData[ui.txtCoord.currentText()])
        else:
            Coord = None
    except Exception:
        return show_error("Coordinate variable is wrong!")

    if Filter is not None:
        for fil in Filter:
            # Drop every sample whose label equals this class ID.
            labelIndx = np.where(L == fil)[0]
            Design = np.delete(Design, labelIndx, axis=0)
            X = np.delete(X, labelIndx, axis=0)
            L = np.delete(L, labelIndx, axis=0)
            Task = np.delete(Task, labelIndx, axis=0)
            Sub = np.delete(Sub, labelIndx, axis=0)
            Run = np.delete(Run, labelIndx, axis=0)
            Con = np.delete(Con, labelIndx, axis=0)
            print("Class ID = " + str(fil) + " is removed from data.")

    try:
        Unit = np.int32(ui.txtUnit.text())
    except (ValueError, OverflowError):
        return show_error("Unit for the test set must be a number!")
    if Unit < 1:
        return show_error("Unit for the test set must be greater than zero!")

    print("Calculating Levels ...")
    GroupFold = None
    FoldStr = ""
    if ui.cbFSubject.isChecked():
        if not ui.rbFRun.isChecked():
            GroupFold = [Sub]
            FoldStr = "Subject"
        else:
            GroupFold = np.concatenate(([Sub], [Run]))
            FoldStr = "Subject+Run"
    if ui.cbFTask.isChecked():
        GroupFold = np.concatenate((GroupFold, [Task])) if GroupFold is not None else [Task]
        FoldStr = FoldStr + "+Task"
    if ui.cbFCounter.isChecked():
        GroupFold = np.concatenate((GroupFold, [Con])) if GroupFold is not None else [Con]
        FoldStr = FoldStr + "+Counter"

    if FoldStr == "":
        # No grouping variable selected: the whole data set is one level.
        FoldStr = "Whole-Data"
        GUFold = [1]
        ListFold = [1]
        UniqFold = [1]
        GroupFold = [1]
        UnitFold = np.ones((1, np.shape(X)[0]))
    else:
        # One row per sample, one column per selected grouping variable.
        GroupFold = np.transpose(GroupFold)
        UniqFold = np.array(list(set(tuple(i) for i in GroupFold.tolist())))
        FoldIDs = np.arange(len(UniqFold)) + 1
        if len(UniqFold) <= Unit:
            return show_error("Unit must be smaller than all possible levels! Number of all levels is: " + str(len(UniqFold)))
        if np.mod(len(UniqFold), Unit):
            return show_error("Unit must be divisible to all possible levels! Number of all levels is: " + str(len(UniqFold)))
        # Map each sample's group tuple to its 1-based level ID with a dict
        # lookup instead of scanning every level for every sample.
        LevelID = {tuple(level): fid for level, fid in zip(UniqFold.tolist(), FoldIDs)}
        ListFold = np.int32([LevelID[tuple(g)] for g in GroupFold.tolist()])
        if Unit == 1:
            UnitFold = np.int32(ListFold)
        else:
            # Merge each run of ``Unit`` consecutive level IDs into one fold
            # (equivalent to ceil(ID / Unit)).
            UnitFold = np.int32((ListFold - 0.1) / Unit) + 1
        GUFold = np.unique(UnitFold)

    FoldInfo = dict()
    FoldInfo["Unit"] = Unit
    FoldInfo["Group"] = GroupFold
    FoldInfo["Order"] = FoldStr
    FoldInfo["List"] = ListFold
    FoldInfo["Unique"] = UniqFold
    FoldInfo["Folds"] = UnitFold
    OutData["FoldInfo"] = FoldInfo
    OutData["ModelAnalysis"] = "SK.Group.RSA." + ui.cbMethod.currentText()

    print("Number of all levels is: " + str(len(UniqFold)))

    # How per-fold similarity matrices are merged (average is divided by the
    # fold count after the loop).
    if ui.rbAvg.isChecked():
        combine = np.add
    elif ui.rbMin.isChecked():
        combine = np.minimum
    else:
        combine = np.maximum

    Cov = None
    Corr = None
    AMSE = list()
    Beta = None
    for foldID, fold in enumerate(GUFold):
        print("Analyzing level " + str(foldID + 1), " of ", str(len(GUFold)), " ...")
        Index = np.where(UnitFold == fold)
        if FoldStr == "Whole-Data":
            # UnitFold is 2-D here; keep only the sample axis.  It must be a
            # tuple: NumPy >= 1.23 treats a *list* of arrays as one fancy
            # index and would add a leading axis to X[Index].
            Index = (Index[1],)
        XLi = X[Index]
        if ui.cbScale.isChecked() and ui.rbScale.isChecked():
            XLi = preprocessing.scale(XLi)
            print("Whole of data is scaled X%d~N(0,1)." % (foldID + 1))
        RegLi = Design[Index]
        try:
            if Method == "brsa":
                model = BRSA(n_iter=miter, rank=rank, auto_nuisance=AutoNuisance, n_nureg=NumReg,
                             nureg_zscore=NuregZscore, nureg_method=NuregMethod,
                             baseline_single=BaselineSingle, GP_space=GPS, GP_inten=GPI,
                             space_smooth_range=SpaceSmoothRange, inten_smooth_range=IntenSmoothRange,
                             tau_range=Tau, tau2_prior=Tau2Prior, eta=eta, init_iter=iiter,
                             anneal_speed=Speed, tol=tol, optimizer=Optimizer,
                             minimize_options=MinimizeOptions)
                model.fit(XLi, RegLi, coords=Coord, inten=Intensity)
            else:
                model = GBRSA(n_iter=miter, rank=rank, auto_nuisance=AutoNuisance, n_nureg=NumReg,
                              nureg_zscore=NuregZscore, nureg_method=NuregMethod,
                              baseline_single=BaselineSingle, tol=tol, anneal_speed=Speed,
                              SNR_prior=SNRPrior, logS_range=LogSRange, SNR_bins=SNRBin,
                              rho_bins=rhoBin, optimizer=Optimizer, minimize_options=MinimizeOptions)
                model.fit(XLi, RegLi)
        except Exception as e:
            print(str(e))
            return show_error(str(e))

        BetaLi = model.beta_
        Beta = BetaLi if Beta is None else Beta + BetaLi

        print("Calculating MSE for level %d ..." % (foldID + 1))
        MSE = mean_squared_error(XLi, np.matmul(RegLi, BetaLi))
        print("MSE%d: %f" % (foldID + 1, MSE))
        # NOTE: MSE keys are 0-based while BetaL/Corr/Cov keys are 1-based;
        # kept unchanged so existing readers of the output file still work.
        OutData["MSE" + str(foldID)] = MSE
        AMSE.append(MSE)
        if ui.cbBeta.isChecked():
            OutData["BetaL" + str(foldID + 1)] = BetaLi

        if ui.cbCorr.isChecked():
            print("Calculating Correlation for level %d ..." % (foldID + 1))
            CorrLi = np.corrcoef(BetaLi)
            OutData["Corr" + str(foldID + 1)] = CorrLi
            Corr = CorrLi.copy() if Corr is None else combine(Corr, CorrLi)
        if ui.cbCov.isChecked():
            print("Calculating Covariance for level %d ..." % (foldID + 1))
            CovLi = np.cov(BetaLi)
            OutData["Cov" + str(foldID + 1)] = CovLi
            Cov = CovLi.copy() if Cov is None else combine(Cov, CovLi)

    # The average divides by the number of folds actually summed above.
    # (The previous divisor, len(UniqFold) - 1, ignored Unit and was off by one.)
    NumFolds = len(GUFold)
    if ui.cbCov.isChecked():
        if ui.rbAvg.isChecked():
            Cov = Cov / NumFolds
        covClass = SimilarityMatrixBetweenClass(Cov)
        OutData["Covariance"] = Cov
        OutData["Covariance_min"] = covClass.min()
        OutData["Covariance_max"] = covClass.max()
        OutData["Covariance_std"] = covClass.std()
        OutData["Covariance_mean"] = covClass.mean()
    if ui.cbCorr.isChecked():
        if ui.rbAvg.isChecked():
            Corr = Corr / NumFolds
        corClass = SimilarityMatrixBetweenClass(Corr)
        OutData["Correlation"] = Corr
        OutData["Correlation_min"] = corClass.min()
        OutData["Correlation_max"] = corClass.max()
        OutData["Correlation_std"] = corClass.std()
        OutData["Correlation_mean"] = corClass.mean()

    # Symmetric matrix of 1 - dot(row_i, row_j) over the summed betas.
    NumRows = np.shape(Beta)[0]
    dis = np.zeros((NumRows, NumRows))
    for i in range(NumRows):
        for j in range(i + 1, NumRows):
            dis[i, j] = 1 - np.dot(Beta[i, :], Beta[j, :])
            dis[j, i] = dis[i, j]
    OutData["DistanceMatrix"] = dis
    Z = linkage(dis)
    OutData["Linkage"] = Z
    OutData["MSE"] = np.mean(AMSE)
    print("Average MSE: %f" % (OutData["MSE"]))
    OutData["RunTime"] = time.time() - tStart
    print("Runtime (s): %f" % (OutData["RunTime"]))
    print("Saving results ...")
    mainIO_save(OutFile, OutData)
    print("Output is saved.")

    if ui.cbDiagram.isChecked():

        def show_matrix(mat, title):
            # Draw ``mat`` as a heat map labelled with the condition labels.
            NumData = np.shape(mat)[0]
            plt.figure(num=None, figsize=(NumData, NumData), dpi=100)
            plt.pcolor(mat, vmin=np.min(mat), vmax=np.max(mat))
            plt.xlim([0, NumData])
            plt.ylim([0, NumData])
            cbar = plt.colorbar()
            cbar.ax.tick_params(labelsize=FontSize)
            ax = plt.gca()
            ax.invert_yaxis()
            ax.set_aspect(1)
            ax.set_yticks(np.arange(NumData) + 0.5, minor=False)
            ax.set_xticks(np.arange(NumData) + 0.5, minor=False)
            ax.set_xticklabels(labels, minor=False, fontsize=FontSize, rotation=ui.txtXRotation.value())
            ax.set_yticklabels(labels, minor=False, fontsize=FontSize, rotation=ui.txtYRotation.value())
            ax.grid(False)
            ax.set_frame_on(False)
            # Hide tick marks.  Assigning Tick.tick1On/tick2On is a silent
            # no-op in current Matplotlib, so use tick_params instead.
            ax.tick_params(axis="both", which="major", length=0)
            plt.title(title)
            plt.show()

        if ui.cbCorr.isChecked():
            show_matrix(Corr, ui.txtTitleCorr.text() or 'Group RSA: Correlation\nLevel: ' + FoldStr)
        if ui.cbCov.isChecked():
            show_matrix(Cov, ui.txtTitleCov.text() or 'Group RSA: Covariance\nLevel: ' + FoldStr)

        plt.figure(figsize=(25, 10))
        plt.title(ui.txtTitleDen.text() or 'Group MP Gradient RSA: Similarity Analysis\nLevel: ' + FoldStr)
        dendrogram(Z, labels=labels, leaf_font_size=FontSize, color_threshold=1,
                   leaf_rotation=ui.txtXRotation.value())
        plt.show()

    print("DONE.")
    msgBox.setText("Group Representational Similarity Analysis (RSA) is done.")
    msgBox.setIcon(QMessageBox.Information)
    msgBox.setStandardButtons(QMessageBox.Ok)
    msgBox.exec_()
def btnConvert_click(self):
    """Run Shared Similarity Analysis (SSA) with the options in the GUI.

    Settings and variable names are read from the global ``ui`` form.  The
    MAT file is loaded with ``io.loadmat``; samples are split into views by
    the selected subject/run/task/counter variables (merged into groups of
    ``Unit`` levels), ``SSA.run`` is fitted on all views together, and the
    shared space, transform matrices, distance matrix, linkage and
    covariance/correlation summaries are written with ``io.savemat``.
    Optional figures are drawn with ``DrawRSA``.

    Returns:
        False after showing an error dialog for invalid input or a failed
        SSA run; None after a successful run.
    """
    msgBox = QMessageBox()
    tStart = time.time()

    def show_error(text):
        # Show a modal critical dialog; returns False so callers can write
        # ``return show_error(...)``.
        msgBox.setText(text)
        msgBox.setIcon(QMessageBox.Critical)
        msgBox.setStandardButtons(QMessageBox.Ok)
        msgBox.exec_()
        return False

    if not ui.cbCov.isChecked() and not ui.cbCorr.isChecked():
        return show_error("At least, you must select one metric!")

    # Method
    gpu = ui.cbGPU.currentData()
    Verbose = ui.cbVerbose.isChecked()
    # Built-in float/int replace np.float/np.int, removed in NumPy 1.24
    # (there, both fields were always reported "wrong").
    try:
        gamma = float(ui.txtGamma.text())
    except ValueError:
        return show_error("Gamma is wrong!")
    try:
        Iter = int(ui.txtIter.text())
    except ValueError:
        return show_error("Max Iteration is wrong!")
    try:
        NumFea = ui.txtNumFea.value()
    except Exception:
        return show_error("Number of features is wrong!")
    if NumFea <= 0:
        # Non-positive values are passed to SSA.run as Dim=None.
        NumFea = None

    # Filter: optional list of class IDs to drop, e.g. "1, 2" or "[1 2]".
    try:
        Filter = ui.txtFilter.text()
        if not Filter:
            Filter = None
        else:
            Filter = Filter.replace("\'", " ").replace(",", " ").replace(
                "[", "").replace("]", "").split()
            Filter = np.int32(Filter)
    except (ValueError, OverflowError):
        return show_error("Filter is wrong!")

    # OutFile
    OutFile = ui.txtOutFile.text()
    if not OutFile:
        return show_error("Please enter out file!")
    OutData = dict()

    # InFile
    InFile = ui.txtInFile.text()
    if not InFile:
        return show_error("Please enter input file!")
    if not os.path.isfile(InFile):
        return show_error("Input file not found!")

    # Names under which the SSA results are stored in the output file.
    for widget, message in (
            (ui.txtSharedSpace, "Please enter Shared Space variable name!"),
            (ui.txtSharedVoxelSpace, "Please enter Shared Voxel Space variable name!"),
            (ui.txtViewSpaces, "Please enter View Space variable name!"),
            (ui.txtTransformMats, "Please enter Transform Matrices variable name!")):
        if not widget.text():
            return show_error(message)

    print("Loading ...")
    InData = io.loadmat(InFile)
    # imgShape is required and copied to the output; a missing key used to
    # raise an unhandled KeyError inside the Qt slot.
    try:
        OutData["imgShape"] = InData["imgShape"]
    except KeyError:
        return show_error("Input file has no imgShape variable!")

    if not ui.txtData.currentText():
        return show_error("Please enter Input Data variable name!")
    if not ui.txtLabel.currentText():
        return show_error("Please enter Train Label variable name!")

    # Coordinate
    if not ui.txtCoordinate.currentText():
        return show_error("Please enter Input Coordinate variable name!")
    try:
        Coordinate = InData[ui.txtCoordinate.currentText()]
    except KeyError:
        return show_error("Coordinate value is wrong!")
    OutData["coordinate"] = Coordinate

    # Condition: the label of each condition is taken from con[1][0].
    if not ui.txtCond.currentText():
        return show_error("Please enter Condition variable name!")
    try:
        Cond = InData[ui.txtCond.currentText()]
        labels = np.array([con[1][0] for con in Cond])
    except Exception:
        return show_error("Condition value is wrong!")
    OutData["condition"] = Cond
    OutData["labels"] = labels

    FontSize = ui.txtFontSize.value()

    try:
        X = InData[ui.txtData.currentText()]
        L = InData[ui.txtLabel.currentText()][0]
        if ui.cbScale.isChecked() and not ui.rbScale.isChecked():
            X = preprocessing.scale(X)
            print("Whole of data is scaled X~N(0,1).")
    except Exception:
        return show_error("Cannot load data or label")

    # Task
    if not ui.txtTask.currentText():
        return show_error("Please enter Task variable name!")
    try:
        TaskTitle = InData[ui.txtTask.currentText()][0]
    except (KeyError, IndexError):
        return show_error("Task variable name is wrong!")
    # Encode each task title as its 1-based position among np.unique's
    # sorted titles (entries are expected to wrap the title at index 0).
    TaskTitleUnique = np.unique(TaskTitle)
    Task = np.zeros(np.shape(TaskTitle))
    for ttinx, tt in enumerate(TaskTitle):
        for ttlinx, ttl in enumerate(TaskTitleUnique):
            if tt[0] == ttl:
                Task[ttinx] = ttlinx + 1
                break

    # Subject
    if not ui.txtSubject.currentText():
        return show_error("Please enter Subject variable name!")
    try:
        Sub = InData[ui.txtSubject.currentText()][0]
    except (KeyError, IndexError):
        return show_error("Subject variable name is wrong!")

    # Run
    if not ui.txtRun.currentText():
        return show_error("Please enter Run variable name!")
    try:
        Run = InData[ui.txtRun.currentText()][0]
    except (KeyError, IndexError):
        return show_error("Run variable name is wrong!")

    # Counter
    if not ui.txtCounter.currentText():
        return show_error("Please enter Counter variable name!")
    try:
        Con = InData[ui.txtCounter.currentText()][0]
    except (KeyError, IndexError):
        return show_error("Counter variable name is wrong!")

    if Filter is not None:
        for fil in Filter:
            # Drop every sample whose label equals this class ID.
            labelIndx = np.where(L == fil)[0]
            X = np.delete(X, labelIndx, axis=0)
            L = np.delete(L, labelIndx, axis=0)
            Task = np.delete(Task, labelIndx, axis=0)
            Sub = np.delete(Sub, labelIndx, axis=0)
            Run = np.delete(Run, labelIndx, axis=0)
            Con = np.delete(Con, labelIndx, axis=0)
            print("Class ID = " + str(fil) + " is removed from data.")

    try:
        Unit = np.int32(ui.txtUnit.text())
    except (ValueError, OverflowError):
        return show_error("Unit for the test set must be a number!")
    if Unit < 1:
        return show_error("Unit for the test set must be greater than zero!")

    print("Calculating Levels ...")
    GroupFold = None
    FoldStr = ""
    if ui.cbFSubject.isChecked():
        if not ui.rbFRun.isChecked():
            GroupFold = [Sub]
            FoldStr = "Subject"
        else:
            GroupFold = np.concatenate(([Sub], [Run]))
            FoldStr = "Subject+Run"
    if ui.cbFTask.isChecked():
        GroupFold = np.concatenate((GroupFold, [Task])) if GroupFold is not None else [Task]
        FoldStr = FoldStr + "+Task"
    if ui.cbFCounter.isChecked():
        GroupFold = np.concatenate((GroupFold, [Con])) if GroupFold is not None else [Con]
        FoldStr = FoldStr + "+Counter"

    if FoldStr == "":
        # No grouping variable selected: the whole data set is one view.
        FoldStr = "Whole-Data"
        GUFold = [1]
        ListFold = [1]
        UniqFold = [1]
        GroupFold = [1]
        UnitFold = np.ones((1, np.shape(X)[0]))
    else:
        # One row per sample, one column per selected grouping variable.
        GroupFold = np.transpose(GroupFold)
        UniqFold = np.array(list(set(tuple(i) for i in GroupFold.tolist())))
        FoldIDs = np.arange(len(UniqFold)) + 1
        if len(UniqFold) <= Unit:
            return show_error(
                "Unit must be smaller than all possible levels! Number of all levels is: " + str(len(UniqFold)))
        if np.mod(len(UniqFold), Unit):
            return show_error(
                "Unit must be divisible to all possible levels! Number of all levels is: " + str(len(UniqFold)))
        # Map each sample's group tuple to its 1-based level ID with a dict
        # lookup instead of scanning every level for every sample.
        LevelID = {tuple(level): fid for level, fid in zip(UniqFold.tolist(), FoldIDs)}
        ListFold = np.int32([LevelID[tuple(g)] for g in GroupFold.tolist()])
        if Unit == 1:
            UnitFold = np.int32(ListFold)
        else:
            # Merge each run of ``Unit`` consecutive level IDs into one view
            # (equivalent to ceil(ID / Unit)).
            UnitFold = np.int32((ListFold - 0.1) / Unit) + 1
        GUFold = np.unique(UnitFold)

    FoldInfo = dict()
    FoldInfo["Unit"] = Unit
    FoldInfo["Group"] = GroupFold
    FoldInfo["Order"] = FoldStr
    FoldInfo["List"] = ListFold
    FoldInfo["Unique"] = UniqFold
    FoldInfo["Folds"] = UnitFold
    OutData["FoldInfo"] = FoldInfo
    OutData["ModelAnalysis"] = "SSA"

    print("Number of all levels is: " + str(len(UniqFold)))

    # One (data, one-hot label) pair per view; all views share the same
    # class set so their label columns line up.
    Classes = np.unique(L)
    Xi = list()
    Yi = list()
    for foldID, fold in enumerate(GUFold):
        print("Extracting view " + str(foldID + 1), " of ", str(len(GUFold)), " ...")
        Index = np.where(UnitFold == fold)
        if FoldStr == "Whole-Data":
            # UnitFold is 2-D here; keep only the sample axis.  It must be a
            # tuple: NumPy >= 1.23 treats a *list* of arrays as one fancy
            # index and would add a leading axis to X[Index].
            Index = (Index[1],)
        Xi.append(X[Index])
        # ``classes`` is keyword-only in scikit-learn >= 1.0.
        Yi.append(label_binarize(L[Index], classes=Classes))

    try:
        ssa = SSA(gamma=gamma, gpu=gpu)
        Beta = ssa.run(X=Xi, Y=Yi, Dim=NumFea, verbose=Verbose, Iteration=Iter,
                       ShowError=ui.cbError.isChecked())
        if ui.cbError.isChecked():
            if ssa.LostVec is not None:
                OutData["LossVec"] = ssa.LostVec
            OutData["Error"] = ssa.Loss
    except Exception as e:
        return show_error(str(e))

    OutData["AlgorithmRuntime"] = ssa.Runtime
    OutData[ui.txtSharedSpace.text()] = Beta
    OutData[ui.txtSharedVoxelSpace.text()] = ssa.getSharedVoxelSpace()
    OutData[ui.txtTransformMats.text()] = ssa.getTransformMats()
    if ui.cbViewSpace.isChecked():
        # Store each view's space under its own "<name>_View<k>" key.
        for viewID, view in enumerate(ssa.getSubjectSpace()):
            OutData[ui.txtViewSpaces.text() + "_View" + str(viewID + 1)] = np.transpose(view)
    else:
        OutData[ui.txtViewSpaces.text()] = ssa.getSubjectSpace()

    print("Calculating Distance Matrix ...")
    # Symmetric matrix of 1 - dot(row_i, row_j) over the shared-space rows.
    NumRows = np.shape(Beta)[0]
    dis = np.zeros((NumRows, NumRows))
    for i in range(NumRows):
        for j in range(i + 1, NumRows):
            dis[i, j] = 1 - np.dot(Beta[i, :], Beta[j, :])
            dis[j, i] = dis[i, j]
    OutData["DistanceMatrix"] = dis

    print("Applying linkage ...")
    Z = linkage(dis, method=ui.cbLMethod.currentData(), metric=ui.cbLMetric.currentData(),
                optimal_ordering=ui.cbLOrder.isChecked())
    OutData["Linkage"] = Z

    if ui.cbCov.isChecked():
        Cov = np.cov(Beta)
        covClass = SimilarityMatrixBetweenClass(Cov)
        OutData["Covariance"] = Cov
        OutData["Covariance_min"] = covClass.min()
        OutData["Covariance_max"] = covClass.max()
        OutData["Covariance_std"] = covClass.std()
        OutData["Covariance_mean"] = covClass.mean()
    if ui.cbCorr.isChecked():
        Corr = np.corrcoef(Beta)
        corClass = SimilarityMatrixBetweenClass(Corr)
        OutData["Correlation"] = Corr
        OutData["Correlation_min"] = corClass.min()
        OutData["Correlation_max"] = corClass.max()
        OutData["Correlation_std"] = corClass.std()
        OutData["Correlation_mean"] = corClass.mean()

    OutData["Runtime"] = time.time() - tStart
    print("Runtime: ", OutData["Runtime"])
    print("Algorithm Runtime: ", OutData["AlgorithmRuntime"])
    print("Saving results ...")
    io.savemat(OutFile, mdict=OutData, do_compression=True)
    print("Output is saved.")

    if ui.cbDiagram.isChecked():
        drawrsa = DrawRSA()
        if ui.cbCorr.isChecked():
            drawrsa.ShowFigure(Corr, labels, ui.txtTitleCorr.text(), FontSize,
                               ui.txtXRotation.value(), ui.txtYRotation.value())
        if ui.cbCov.isChecked():
            drawrsa.ShowFigure(Cov, labels, ui.txtTitleCov.text(), FontSize,
                               ui.txtXRotation.value(), ui.txtYRotation.value())
        drawrsa.ShowDend(Z, labels, ui.txtTitleDen.text(), FontSize, ui.txtXRotation.value())

    print("DONE.")
    msgBox.setText("Shared Similarity Analysis (SSA) is done.")
    msgBox.setIcon(QMessageBox.Information)
    msgBox.setStandardButtons(QMessageBox.Ok)
    msgBox.exec_()
def btnConvert_click(self):
    """Run a session-level RSA with the options selected in the form.

    All parameters are read from the global ``ui`` form. Each one is
    validated; on the first problem a critical message box is shown and
    ``False`` is returned. The input ``.mat`` file is then loaded and the
    samples are filtered:

    * samples whose class label is listed in the Filter field are dropped;
    * only samples matching the selected task, subject, counter and run
      are kept.

    Next, the selected ``linmdl`` linear model (LinearRegression / Ridge /
    Lasso / ElasticNet) is fitted to the data, with the design matrix as
    regressors. The outputs are derived from the betas: a distance matrix,
    its linkage, and correlation/covariance summaries. These are written
    to the output file. If requested, the matrices and the dendrogram are
    drawn.

    Returns:
        False when validation fails, otherwise None (after the final
        "done" message box).
    """
    msgBox = QMessageBox()
    tStart = time.time()
    # Exceptions float()/int()/np.int32() raise on a malformed text field.
    parseErrors = (ValueError, TypeError, OverflowError)

    def _error(text):
        """Show text in a critical message box and return False."""
        msgBox.setText(text)
        msgBox.setIcon(QMessageBox.Critical)
        msgBox.setStandardButtons(QMessageBox.Ok)
        msgBox.exec_()
        return False

    def _rows(mask, *arrays):
        """Return every array restricted to the axis-0 rows where mask is True."""
        return [A[mask] for A in arrays]

    if not ui.cbCov.isChecked() and not ui.cbCorr.isChecked():
        return _error("At least, you must select one metric!")

    method = ui.cbMethod.currentData()
    methodText = ui.cbMethod.currentText()
    solver = ui.cbSolver.currentText()
    selection = ui.cbSelection.currentText()
    fit = ui.cbFit.isChecked()
    normalize = ui.cbNormalize.isChecked()

    # Builtin float/int are used because np.float/np.int were only aliases
    # of them and were removed in NumPy 1.24 (every parse then failed).
    try:
        alpha = float(ui.txtAlpha.text())
    except parseErrors:
        return _error("Alpha is wrong!")
    try:
        max_iter = int(ui.txtMaxIter.text())
    except parseErrors:
        return _error("Max Iteration is wrong!")
    try:
        tol = float(ui.txtTole.text())
    except parseErrors:
        return _error("Tolerance is wrong!")
    try:
        l1 = float(ui.txtL1.text())
    except parseErrors:
        return _error("L1 is wrong!")
    try:
        # n_jobs is a number of workers, so it must be an integer.
        njob = int(ui.txtJobs.text())
    except parseErrors:
        return _error("Number of jobs is wrong!")

    # Filter: optional list of class labels to drop, e.g. "[1, 3]" or "1 3".
    Filter = ui.txtFilter.text()
    if not Filter:
        Filter = None
    else:
        try:
            Filter = np.int32(Filter.replace("\'", " ").replace(",", " ")
                              .replace("[", "").replace("]", "").split())
        except parseErrors:
            return _error("Filter is wrong!")

    # OutFile
    OutFile = ui.txtOutFile.text()
    if not OutFile:
        return _error("Please enter out file!")
    OutData = dict()

    # InFile
    InFile = ui.txtInFile.text()
    if not InFile:
        return _error("Please enter input file!")
    if not os.path.isfile(InFile):
        return _error("Input file not found!")
    print("Loading ...")
    InData = io.loadmat(InFile)

    # Data / Label / Design variable names
    if not ui.txtData.currentText():
        return _error("Please enter Input Data variable name!")
    if not ui.txtLabel.currentText():
        return _error("Please enter Train Label variable name!")
    if not ui.txtDesign.currentText():
        return _error("Please enter Input Design variable name!")
    try:
        Design = InData[ui.txtDesign.currentText()]
    except KeyError:
        return _error("Design value is wrong!")

    # Condition: each entry's name becomes a tick/leaf label in the plots.
    if not ui.txtCond.currentText():
        return _error("Please enter Condition variable name!")
    try:
        Cond = InData[ui.txtCond.currentText()]
        OutData[ui.txtCond.currentText()] = Cond
        labels = np.array([con[1][0] for con in Cond])
    except (KeyError, IndexError, TypeError):
        return _error("Condition value is wrong!")

    try:
        X = InData[ui.txtData.currentText()]
        L = InData[ui.txtLabel.currentText()][0]
    except (KeyError, IndexError):
        return _error("Cannot load data or label")

    # Task
    if not ui.txtTask.currentText():
        return _error("Please enter Task variable name!")
    if not ui.txtTaskVal.currentText():
        return _error("Please enter Task value!")
    TaskIDTitle = ui.txtTaskVal.currentText()
    try:
        TaskTitle = InData[ui.txtTask.currentText()][0]
    except (KeyError, IndexError):
        return _error("Task variable name is wrong!")
    # Encode each sample's task title as its 1-based position in the
    # sorted unique titles (np.unique sorts).
    TaskTitleUnique = np.unique(TaskTitle)
    Task = np.zeros(np.shape(TaskTitle))
    for ttinx, tt in enumerate(TaskTitle):
        for ttlinx, ttl in enumerate(TaskTitleUnique):
            if tt[0] == ttl:
                Task[ttinx] = ttlinx + 1
                break
    TaskID = None
    for ttlinx, ttl in enumerate(TaskTitleUnique):
        if TaskIDTitle == ttl:
            TaskID = ttlinx + 1
            break
    if TaskID is None:
        # Without this check TaskID stayed unbound and a NameError followed.
        return _error("Task value is wrong!")
    OutData["Task"] = TaskIDTitle

    # Subject
    if not ui.txtSubject.currentText():
        return _error("Please enter Subject variable name!")
    if not ui.txtSubjectVal.currentText():
        return _error("Please enter Subject value!")
    try:
        SubID = np.int32(ui.txtSubjectVal.currentText())
    except parseErrors:
        return _error("Subject value is wrong!")
    try:
        Sub = InData[ui.txtSubject.currentText()][0]
    except (KeyError, IndexError):
        return _error("Subject variable name is wrong!")
    OutData["SubjectID"] = SubID

    # Run
    if not ui.txtRun.currentText():
        return _error("Please enter Run variable name!")
    if not ui.txtRunVal.currentText():
        return _error("Please enter Run value!")
    try:
        RunID = np.int32(ui.txtRunVal.currentText())
    except parseErrors:
        return _error("Run value is wrong!")
    try:
        Run = InData[ui.txtRun.currentText()][0]
    except (KeyError, IndexError):
        return _error("Run variable name is wrong!")
    OutData["RunID"] = RunID

    # Counter
    if not ui.txtCounter.currentText():
        return _error("Please enter Counter variable name!")
    if not ui.txtCounterVal.currentText():
        return _error("Please enter Counter value!")
    try:
        ConID = np.int32(ui.txtCounterVal.currentText())
    except parseErrors:
        return _error("Counter value is wrong!")
    try:
        Con = InData[ui.txtCounter.currentText()][0]
    except (KeyError, IndexError):
        return _error("Counter variable name is wrong!")
    OutData["CounterID"] = ConID

    if Filter is not None:
        for fil in Filter:
            # Drop every sample whose class label equals fil.
            Design, X, L, Task, Sub, Run, Con = _rows(
                L != fil, Design, X, L, Task, Sub, Run, Con)
            print("Class ID = " + str(fil) + " is removed from data.")

    # Narrow down to the selected task, subject, counter and run; each step
    # only carries along the arrays the later steps still need.
    Design, X, L, Sub, Run, Con = _rows(Task == TaskID, Design, X, L, Sub, Run, Con)
    Design, X, L, Run, Con = _rows(Sub == SubID, Design, X, L, Run, Con)
    Design, X, L, Run = _rows(Con == ConID, Design, X, L, Run)
    Design, X, L = _rows(Run == RunID, Design, X, L)

    OutData["Label"] = np.unique(L)
    OutData["ModelAnalysis"] = "SK.Session.RSA." + methodText

    if np.shape(X)[0] == 0:
        return _error("The selected data is empty!")

    if ui.cbScale.isChecked():
        X = preprocessing.scale(X)
        print("Data is scaled to N(0,1).")

    print("Running RSA ...")
    # Prepend a constant column to the design; its coefficient is dropped below.
    Reg = np.insert(Design, 0, 1, axis=1)
    common = dict(fit_intercept=fit)
    if normalize:
        # Only passed when requested: scikit-learn >= 1.2 removed `normalize`,
        # and omitting it is equivalent to the old default of False.
        common["normalize"] = True
    try:
        if method == "ols":
            model = linmdl.LinearRegression(n_jobs=njob, **common)
        elif method == "ridge":
            model = linmdl.Ridge(alpha=alpha, max_iter=max_iter, tol=tol,
                                 solver=solver, **common)
        elif method == "lasso":
            model = linmdl.Lasso(alpha=alpha, max_iter=max_iter, tol=tol,
                                 selection=selection, **common)
        elif method == "elast":
            model = linmdl.ElasticNet(alpha=alpha, l1_ratio=l1, max_iter=max_iter,
                                      tol=tol, selection=selection, **common)
        else:
            return _error("Method is wrong!")
    except TypeError as err:
        # Raised by the constructor for an unsupported keyword (e.g. normalize).
        return _error("Cannot create the model: " + str(err))
    model.fit(Reg, X)
    # coef_ is presumably (n_data_columns, n_regressors); after transposing,
    # row 0 is the constant column's coefficient and is discarded.
    Betas = np.transpose(model.coef_)[1:, :]

    print("Calculating MSE ...")
    # NOTE(review): the prediction is Design @ Betas only, i.e. without the
    # constant-column coefficient and model.intercept_ -- confirm intended.
    MSE = mean_squared_error(X, np.matmul(Design, Betas))
    print("MSE: %f" % (MSE))
    OutData["MSE"] = MSE

    # Distance 1 - <beta_i, beta_j> between every pair of regressor rows,
    # exactly symmetric with a zero diagonal.
    dis = np.triu(1.0 - np.matmul(Betas, Betas.T), k=1)
    dis = dis + dis.T
    OutData["DistanceMatrix"] = dis
    # NOTE(review): scipy's linkage treats a square 2-D array as observation
    # vectors, not as a distance matrix (that would need squareform(dis));
    # kept unchanged to preserve existing results.
    Z = linkage(dis)
    OutData["Linkage"] = Z
    if ui.cbBeta.isChecked():
        OutData["Betas"] = Betas

    # Calculate Results
    if ui.cbCorr.isChecked():
        print("Calculating Correlation ...")
        Corr = np.corrcoef(Betas)
        corClass = SimilarityMatrixBetweenClass(Corr)
        OutData["Correlation"] = Corr
        OutData["Correlation_min"] = corClass.min()
        OutData["Correlation_max"] = corClass.max()
        OutData["Correlation_std"] = corClass.std()
        OutData["Correlation_mean"] = corClass.mean()
    if ui.cbCov.isChecked():
        print("Calculating Covariance ...")
        Cov = np.cov(Betas)
        covClass = SimilarityMatrixBetweenClass(Cov)
        OutData["Covariance"] = Cov
        OutData["Covariance_min"] = covClass.min()
        OutData["Covariance_max"] = covClass.max()
        OutData["Covariance_std"] = covClass.std()
        OutData["Covariance_mean"] = covClass.mean()

    OutData["RunTime"] = time.time() - tStart
    print("Runtime (s): %f" % (OutData["RunTime"]))
    print("Saving results ...")
    io.savemat(OutFile, mdict=OutData, do_compression=True)
    print("Output is saved.")

    FontSize = ui.txtFontSize.value()
    if ui.cbDiagram.isChecked():
        # Shared suffix of the default plot titles.
        info = '\nTask: %s\nSub: %d, Counter: %d, Run: %d' % (
            TaskIDTitle, SubID, ConID, RunID)

        def _show_matrix(Mat, customTitle, name):
            """Draw Mat as a heat map labelled by `labels` and show it.

            customTitle is used when non-empty; otherwise the title is
            built from name, the method and the task/subject/counter/run.
            """
            NumData = np.shape(Mat)[0]
            plt.figure(num=None, figsize=(NumData, NumData), dpi=100)
            plt.pcolor(Mat, vmin=np.min(Mat), vmax=np.max(Mat))
            plt.xlim([0, NumData])
            plt.ylim([0, NumData])
            cbar = plt.colorbar()
            cbar.ax.tick_params(labelsize=FontSize)
            ax = plt.gca()
            ax.invert_yaxis()
            ax.set_aspect(1)
            ticks = np.arange(NumData) + 0.5
            ax.set_yticks(ticks, minor=False)
            ax.set_xticks(ticks, minor=False)
            ax.set_xticklabels(labels, minor=False, fontsize=FontSize,
                               rotation=ui.txtXRotation.value())
            ax.set_yticklabels(labels, minor=False, fontsize=FontSize,
                               rotation=ui.txtYRotation.value())
            ax.grid(False)
            ax.set_frame_on(False)
            # Hide the tick marks but keep the labels. The previous
            # Tick.tick1On/tick2On assignments are silent no-ops on recent
            # matplotlib releases, which no longer have those attributes.
            ax.tick_params(which="both", bottom=False, top=False,
                           left=False, right=False)
            if customTitle:
                plt.title(customTitle)
            else:
                plt.title(name + ' (' + methodText + ')' + info)
            plt.show()

        if ui.cbCorr.isChecked():
            _show_matrix(Corr, ui.txtTitleCorr.text(), 'Correlation')
        if ui.cbCov.isChecked():
            _show_matrix(Cov, ui.txtTitleCov.text(), 'Covariance')

        plt.figure(figsize=(25, 10))
        if ui.txtTitleDen.text():
            plt.title(ui.txtTitleDen.text())
        else:
            plt.title('Similarity Analysis (' + methodText + ')' + info)
        dendrogram(Z, labels=labels, leaf_font_size=FontSize, color_threshold=1,
                   leaf_rotation=ui.txtXRotation.value())
        plt.show()

    print("DONE.")
    msgBox.setText("Representational Similarity Analysis (RSA) is done.")
    msgBox.setIcon(QMessageBox.Information)
    msgBox.setStandardButtons(QMessageBox.Ok)
    msgBox.exec_()