import os

import numpy as np
import torch
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Dataset

# `LRUCache` (a file-backed cache exposing `load(path, index)`) and `time_str`
# (a timestamp formatter) are helpers defined elsewhere in this repository and
# are assumed to be in scope here.


class EEGDataset1(Dataset):
    """PyTorch dataset for the imaginary coherence data (dataset 1)."""

    all_transformations = ["one", "std"]

    def __init__(self, data_folder, file_indices, subj_data,
                 transformation="std", super_node=False):
        """
        Args:
            data_folder : str
                The root folder of the preprocessed data.
            file_indices : Dict[int->int]
                Converts linear indices useful to iterate through the dataset
                into keys for the `subj_data` structure.
            subj_data : dict
                Information about the file location of each sample.
            transformation : str
                How to aggregate the 50 frequency bins: "std" averages them
                into the five standard EEG bands, "one" averages them into a
                single band.
            super_node : bool
                Whether to append a fully connected super node to each graph.
        """
        super(EEGDataset1, self).__init__()
        if transformation not in EEGDataset1.all_transformations:
            raise ValueError(
                f"Transformation must be in {EEGDataset1.all_transformations}.")

        self.super_node = super_node
        self.transformation = transformation
        self.data_folder = data_folder
        self.num_nodes = 90
        self.xfile_cache = LRUCache(capacity=50)
        self.yfile_cache = LRUCache(capacity=500)
        self.subj_data = subj_data
        self.file_indices = file_indices

    def get_xy_files(self, idx):
        idx = self.file_indices[idx]
        sdata = self.subj_data[idx]
        x_file = os.path.join(self.data_folder, "X", sdata["file"])
        y_file = os.path.join(self.data_folder, "Y", sdata["file"])
        iif = sdata["index_in_file"]
        return x_file, y_file, iif

    def __getitem__(self, idx):
        x_file, y_file, iif = self.get_xy_files(idx)
        X = self.xfile_cache.load(x_file, iif)  # [4095, 50]
        X = self.transform(X)  # [num_freq, 90, 90]
        Y = self.yfile_cache.load(y_file, iif)
        sample = {
            "X": torch.tensor(X, dtype=torch.float32),
            "Y": torch.tensor(Y, dtype=torch.long),
        }
        return sample

    def transform(self, X):
        """
        Args:
            X : numpy array [4095, 50]
        Returns:
            X_transformed : numpy array [num_freq, 90, 90]
        """
        if self.transformation == "std":
            # Average the 50 frequency bins into the five standard EEG bands.
            X_delta = np.mean(X[:, 0:4], axis=-1)   # 1 to <4 Hz
            X_theta = np.mean(X[:, 4:8], axis=-1)   # 4 to <8 Hz
            X_alpha = np.mean(X[:, 8:13], axis=-1)  # 8 to <13 Hz
            X_beta = np.mean(X[:, 13:30], axis=-1)  # 13 to <30 Hz
            X_gamma = np.mean(X[:, 30:], axis=-1)   # >=30 Hz
            X_aggregated = np.stack(
                (X_delta, X_theta, X_alpha, X_beta, X_gamma), axis=1)
        elif self.transformation == "one":
            X_aggregated = np.expand_dims(np.mean(X, axis=-1), 1)

        As = []
        for band in range(X_aggregated.shape[1]):
            A = self.adj_from_tril(X_aggregated[:, band])  # 90 x 90
            As.append(A)
        A = np.stack(As, axis=0).astype(np.float32)  # num_freq x 90 x 90
        return A

    def __len__(self):
        return len(self.file_indices)

    @property
    def num_bands(self):
        if self.transformation == "std":
            return 5
        elif self.transformation == "one":
            return 1

    def adj_from_tril(self, one_coh_arr):
        """Builds the A_hat matrix of the paper for one sample.

        https://github.com/brainstorm-tools/brainstorm3/blob/master/toolbox/process/functions/process_compress_sym.m
        shows that the flat array contains the lower-triangular values of the
        initial symmetric matrix.
        Args:
            one_coh_arr : array [num_nodes*(num_nodes+1)/2]
                Flat lower-triangular values of the symmetric coherence matrix.
        Returns:
            A_hat : array [num_nodes, num_nodes]
                (num_nodes+1 in each dimension if `self.super_node` is set.)
        """
        # First construct the weighted adjacency matrix from the flat
        # lower-triangular values.
        A = np.zeros((self.num_nodes, self.num_nodes))
        index = np.tril_indices(self.num_nodes)
        A[index] = one_coh_arr
        A = A + A.T  # symmetrize; the doubled diagonal is halved below
        if self.super_node:
            # Adding the super node, connected to every other node.
            A = np.concatenate((A, np.ones((self.num_nodes, 1))), axis=1)
            A = np.concatenate((A, np.ones((1, self.num_nodes + 1))), axis=0)

        # A tilde from the paper.
        di = np.diag_indices(self.num_nodes)
        A[di] = A[di] / 2
        A_tilde = A + np.eye(A.shape[0])
        # D tilde to the power -0.5.
        D_tilde_inv = np.diag(np.power(np.sum(A_tilde, axis=0), -0.5))
        # Finally build A_hat.
        A_hat = np.matmul(D_tilde_inv, np.matmul(A_tilde, D_tilde_inv))
        return A_hat
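
# --- Illustrative check (not part of the dataset code) ----------------------
# A minimal, self-contained sketch of what `adj_from_tril` computes in the
# default (no super node) case: A_hat = D_tilde^-1/2 (A + I) D_tilde^-1/2,
# i.e. the GCN-style renormalized adjacency. The function name and the seed
# are hypothetical; only numpy from the imports above is needed.
def _demo_adj_from_tril(num_nodes=90, seed=0):
    rng = np.random.default_rng(seed)
    # Flat lower triangle, same layout as one band of `X_aggregated`.
    tril = rng.random(num_nodes * (num_nodes + 1) // 2)

    A = np.zeros((num_nodes, num_nodes))
    A[np.tril_indices(num_nodes)] = tril
    A = A + A.T
    A[np.diag_indices(num_nodes)] /= 2  # undo the doubled diagonal
    A_tilde = A + np.eye(num_nodes)     # add self-loops
    D_tilde_inv = np.diag(np.power(A_tilde.sum(axis=0), -0.5))
    A_hat = D_tilde_inv @ A_tilde @ D_tilde_inv

    assert A_hat.shape == (num_nodes, num_nodes)
    assert np.allclose(A_hat, A_hat.T)  # normalization preserves symmetry
    return A_hat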
class EEGDataset2(Dataset):
    """PyTorch dataset for the temporal sequence data (dataset 2).

    This dataset has 435 identified ROIs, each containing the mean activation
    of the sources in the region at every time point. The data is organized by
    sleep activity. Each file contains activations for the ROIs at 1500 time
    points (at 1 Hz?).

    Note: The dataset has very small values (in the 1e-10 range), which may
    cause precision errors when using single-precision floating point numbers.
    This class offers two normalization options:
      - standardizing each ROI to zero mean and unit variance (requires a pass
        over the whole dataset to extract global statistics);
      - scaling all values by a large constant (NORM_CONSTANT).
    """

    all_normalizations = [
        "standard",  # Standardize each ROI
        "none",      # Multiply all values by NORM_CONSTANT
        "val",       # Validation loader; normalization comes from the training loader
    ]
    NORM_CONSTANT = 1.0e10

    def __init__(self, data_folder, file_indices, subj_data, normalization="none"):
        """
        Args:
            data_folder : str
                The root folder of the preprocessed data.
            file_indices : Dict[int->int]
                Converts linear indices useful to iterate through the dataset
                into keys for the `subj_data` structure.
            subj_data : dict
                Information about the file location of each sample.
            normalization : str
                The type of normalization to use for the data: "standard",
                "none" or "val". "val" should only be used for a validation
                dataset whose statistics are extracted from the training set.
        """
        super(EEGDataset2, self).__init__()
        if normalization not in EEGDataset2.all_normalizations:
            raise ValueError(
                f"Normalization must be in {EEGDataset2.all_normalizations}.")

        self.normalization = normalization
        self.data_folder = data_folder
        self.xfile_cache = LRUCache(capacity=50)
        self.yfile_cache = LRUCache(capacity=500)
        self.subj_data = subj_data
        self.file_indices = file_indices
        self.init_normalizer()

    def get_xy_files(self, idx):
        idx = self.file_indices[idx]
        sdata = self.subj_data[idx]
        x_file = os.path.join(self.data_folder, "X", sdata["file"])
        y_file = os.path.join(self.data_folder, "Y", sdata["file"])
        iif = sdata["index_in_file"]
        return x_file, y_file, iif

    def __getitem__(self, idx):
        x_file, y_file, iif = self.get_xy_files(idx)
        X = self.xfile_cache.load(x_file, iif)  # [num_nodes, time_steps]
        X = self.normalize(X)
        Y = self.yfile_cache.load(y_file, iif)
        sample = {
            "X": torch.tensor(X, dtype=torch.float32),
            "Y": torch.tensor(Y, dtype=torch.long),
        }
        return sample

    def __len__(self):
        return len(self.file_indices)

    def init_normalizer(self):
        if self.normalization == "val":
            return
        print(f"{time_str()} Initializing normalization "
              f"({self.normalization}) statistics.")
        if self.normalization == "none":
            self.scaler = None
            return
        self.scaler = StandardScaler(copy=False, with_mean=True, with_std=True)
        # Iterate over all samples to compute the statistics.
        # TODO: This can be optimized by feeding the scaler all samples read
        # from a file, but care must be taken to only feed it samples whose id
        # is in the allowed ids.
        for i in range(len(self)):
            x_file, y_file, iif = self.get_xy_files(i)
            X = self.xfile_cache.load(x_file, iif)
            self.scaler.partial_fit(X)

    def normalize(self, data):
        """
        Args:
            data : array [num_rois, time_steps]
        Returns:
            norm_data : array [num_rois, time_steps]
        """
        if self.normalization == "val":
            raise ValueError(
                "Normalization cannot be `val`, must be set to a concrete value.")
        if self.normalization == "none":
            data = data * EEGDataset2.NORM_CONSTANT
        else:
            data = self.scaler.transform(data)
        return data.astype(np.float32)
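
# --- Hypothetical usage sketch (not part of the dataset code) ---------------
# How a training/validation pair of EEGDataset2 loaders might share the
# normalization statistics: the "val" dataset skips fitting in
# `init_normalizer`, and `normalize` raises while the mode is still "val", so
# the caller has to hand over the fitted scaler before use. The path and
# `make_split` below are placeholders standing in for however
# `file_indices`/`subj_data` are actually produced in this repo.
#
#   tr_idx, tr_subj, va_idx, va_subj = make_split(...)
#   train_ds = EEGDataset2("/path/to/data", tr_idx, tr_subj,
#                          normalization="standard")
#   val_ds = EEGDataset2("/path/to/data", va_idx, va_subj,
#                        normalization="val")
#   # Copy the training statistics, then switch to a concrete mode so that
#   # `normalize` no longer raises.
#   val_ds.scaler = train_ds.scaler
#   val_ds.normalization = "standard"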