def main(index_exp):
    # ----- Wrap up dataloader -----
    # Load data
    X, Y, _ = rdl.read_data([1, 2, 3], list(range(11)), channel_limit=12, rm_baseline=True)

    # Remove trials
    X, Y = preprocessing.remove_trials(X, Y, threshold=60)

    # Downsample to 64 Hz
    X = decimate(X, 4, axis=2)
    print('> After downsampling, shape of X: ', X.shape)

    # Split data for cross validation
    kf = KFold(n_splits=10, shuffle=True, random_state=23)
    for i, (train_index, test_index) in enumerate(kf.split(X)):
        if i == index_exp:
            train_data, train_target = X[train_index, :], Y[train_index]
            test_data, test_target = X[test_index, :], Y[test_index]

    # (sample, channel, time) -> (sample, 1, channel, time)
    [train_data, test_data] = [x.reshape((x.shape[0], 1, x.shape[1], x.shape[2]))
                               for x in [train_data, test_data]]

    (train_dataTS, train_targetTS, test_dataTS, test_targetTS) = map(
        torch.from_numpy, (train_data, train_target, test_data, test_target))
    [train_dataset, test_dataset] = map(
        Data.TensorDataset,
        [train_dataTS.float(), test_dataTS.float()],
        [train_targetTS.float(), test_targetTS.float()])

    train_loader = Data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    test_loader = Data.DataLoader(test_dataset, batch_size=args.batch_size)

    # ----- Create model -----
    gen = tcGAN.generator().to(device=device)
    dis = tcGAN.discregressor().to(device=device)
    if torch.cuda.device_count() > 1:
        gen = nn.DataParallel(gen)
        dis = nn.DataParallel(dis)

    # ----- Train model -----
    trainer = Trainer(gen, dis, train_loader, test_loader)
    trainer.train(args.num_epoch)
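# Illustrative aside (not part of the training script; shapes here are made up
# for demonstration): the reshape above adds a singleton "NN-channel" axis so
# 2-D CNNs (EEGNet-style) can convolve over the (channel, time) plane.
import numpy as np

demo = np.zeros((8, 12, 384))                       # (sample, channel, time)
demo4d = demo.reshape((demo.shape[0], 1, demo.shape[1], demo.shape[2]))
assert demo4d.shape == (8, 1, 12, 384)              # (sample, 1, channel, time)
assert np.array_equal(demo4d, demo[:, np.newaxis])  # equivalent indexing form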
def main(index_exp, index_split):
    faulthandler.enable()
    torch.cuda.empty_cache()

    best_error = 100
    lr_step = [40, 70, 120]
    multiframe = ['convlstm', 'convfc']
    dirName = '%s_data%d_%s_%s_%s' % (args.model_name, args.data_cate,
                                      args.augmentation, args.loss_type, args.file_name)
    fileName = '%s_split%d_exp%d' % (dirName, index_split, index_exp)

    # Create folder for results of this model
    if not os.path.exists('./results/%s' % (dirName)):
        os.makedirs('./results/%s' % (dirName))

    # ------------- Wrap up dataloader -----------------
    if args.input_type == 'signal':
        X, Y_reg, C = raw_dataloader.read_data([1, 2, 3], list(range(11)),
                                               channel_limit=21, rm_baseline=True)
        num_channel = X.shape[1]
        num_feature = X.shape[2]    # Number of time samples

        # Remove trials
        X, Y_reg = preprocessing.remove_trials(X, Y_reg, threshold=60)

        # Split data for cross validation
        if args.num_fold == 1:
            # Random state 15: training error becomes lower, testing error becomes higher
            train_data, test_data, train_target, test_target = train_test_split(
                X, Y_reg, test_size=0.1, random_state=23)
        else:
            kf = KFold(n_splits=args.num_fold, shuffle=True, random_state=23)
            for i, (train_index, test_index) in enumerate(kf.split(X)):
                if i == index_exp:
                    train_data, train_target = X[train_index, :], Y_reg[train_index]
                    test_data, test_target = X[test_index, :], Y_reg[test_index]

        # Split data for ensemble methods
        if not args.ensemble:
            if args.num_split > 1:
                data_list, target_list = preprocessing.stratified_split(
                    train_data, train_target, n_split=args.num_split, mode=args.split_mode)
                train_data, train_target = data_list[index_split], target_list[index_split]
                '''
                kf = KFold(n_splits=args.num_split, shuffle=True, random_state=32)
                for i, (other_index, split_index) in enumerate(kf.split(train_data)):
                    if i == index_split:
                        train_data, train_target = train_data[split_index, :], train_target[split_index]
                '''

        # Normalize the data
        if args.normalize:
            train_data, test_data = preprocessing.normalize(train_data, test_data)

        # Data augmentation
        if args.augmentation == 'overlapping':
            train_data, train_target = data_augmentation.aug(
                train_data, train_target, args.augmentation, (256, 64, 128))
            test_data, test_target = data_augmentation.aug(
                test_data, test_target, args.augmentation, (256, 64, 128))
        elif args.augmentation == 'add_noise':
            train_data, train_target = data_augmentation.aug(
                train_data, train_target, args.augmentation, (30, 1))
        elif args.augmentation == 'add_noise_minority':
            train_data, train_target = data_augmentation.aug(
                train_data, train_target, args.augmentation, (30, 1))
        elif args.augmentation == 'SMOTER':
            train_data, train_target = data_augmentation.aug(
                train_data, train_target, args.augmentation)

        # Scale data
        if args.scale_data:
            train_data = train_data.reshape((train_data.shape[0], -1))
            test_data = test_data.reshape((test_data.shape[0], -1))
            train_data, test_data = preprocessing.scale(train_data, test_data)
            train_data = train_data.reshape((train_data.shape[0], num_channel, -1))
            test_data = test_data.reshape((test_data.shape[0], num_channel, -1))

        if args.model_name in ['eegnet', 'eegnet_trans_signal']:
            # (sample, channel, time) -> (sample, channel_NN, channel_EEG, time)
            [train_data, test_data] = [x.reshape((x.shape[0], 1, num_channel, num_feature))
                                       for x in [train_data, test_data]]

        (train_dataTS, train_targetTS, test_dataTS, test_targetTS) = map(
            torch.from_numpy, (train_data, train_target, test_data, test_target))
        [train_dataset, test_dataset] = map(
            Data.TensorDataset,
            [train_dataTS.float(), test_dataTS.float()],
            [train_targetTS.float(), test_targetTS.float()])

        if not args.str_sampling:
            train_loader = Data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
        test_loader = Data.DataLoader(test_dataset, batch_size=args.batch_size)
        model_param = [train_data.shape]
    elif args.input_type == 'power':
        if args.data_cate == 1:
            ERSP_all, tmp_all, freqs = dataloader.load_data()
        elif args.data_cate == 2:
            data_file = './raw_data/ERSP_from_raw_%d_channel21.data' % (args.index_sub)
            with open(data_file, 'rb') as fp:
                dict_ERSP = pickle.load(fp)
            ERSP_all, tmp_all = dict_ERSP['ERSP'], dict_ERSP['SLs']
        num_channel = ERSP_all.shape[1]
        num_freq = ERSP_all.shape[2]

        # Remove trials
        ERSP_all, tmp_all = preprocessing.remove_trials(ERSP_all, tmp_all, threshold=60)

        # Split data for cross validation
        if args.num_fold == 1:
            train_data, test_data, train_target, test_target = train_test_split(
                ERSP_all, tmp_all[:, 2], test_size=0.1, random_state=23)
        else:
            kf = KFold(n_splits=args.num_fold, shuffle=True, random_state=23)
            for i, (train_index, test_index) in enumerate(kf.split(ERSP_all)):
                if i == index_exp:
                    train_data, test_data = ERSP_all[train_index, :], ERSP_all[test_index, :]
                    if args.data_cate == 2:
                        train_target, test_target = tmp_all[train_index], tmp_all[test_index]
                    else:
                        train_target, test_target = tmp_all[train_index, 2], tmp_all[test_index, 2]

        if args.add_CE:
            assert args.data_cate == 2
            with open('./raw_data/CE_sub%d' % (args.index_sub), 'rb') as fp:
                CE = pickle.load(fp)
            CE_train, CE_test = CE[train_index, :], CE[test_index, :]

            # PCA for CE
            pca = PCA(n_components=10)
            pca.fit(CE_train)
            CE_train, CE_test = pca.transform(CE_train), pca.transform(CE_test)

        # Split data for ensemble methods
        if not args.ensemble:
            if args.num_split > 1:
                data_list, target_list = preprocessing.stratified_split(
                    train_data, train_target, n_split=args.num_split, mode=args.split_mode)
                train_data, train_target = data_list[index_split], target_list[index_split]
                '''
                kf = KFold(n_splits=args.num_split, shuffle=True, random_state=32)
                for i, (other_index, split_index) in enumerate(kf.split(np.arange(len(train_data)))):
                    if i == index_split:
                        train_data, train_target = train_data[split_index, :], train_target[split_index]
                '''

        # Concatenate train and test for standardizing
        data = np.concatenate((train_data, test_data), axis=0)
        target = np.concatenate((train_target, test_target))

        # Standardize data
        num_train = len(train_data)
        data, target = preprocessing.standardize(data, target,
                                                 train_indices=np.arange(num_train), threshold=0.0)
        data = data.reshape((data.shape[0], -1))

        # Scale target between 0 and 1
        if args.post_scale:
            print('Scale the target between 0-1')
            target = target / 60

        # Split data
        train_data, test_data = data[:num_train, :], data[num_train:, :]
        train_target, test_target = target[:num_train], target[num_train:]

        # Data augmentation
        if args.augmentation == 'SMOTER':
            train_data, train_target = data_augmentation.aug(train_data, train_target, args.augmentation)

        # Center data
        if args.center_flag:
            train_data, test_data = preprocessing.center(train_data, test_data)

        # Scale data
        if args.scale_data:
            train_data, test_data = preprocessing.scale(train_data, test_data)

        # Add conditional entropy (concatenate the PCA-reduced CE features)
        if args.add_CE:
            train_data = np.concatenate((train_data, CE_train), axis=1)
            test_data = np.concatenate((test_data, CE_test), axis=1)

        if args.model_name == 'eegnet_trans_power':
            # (sample, channel, freq) -> (sample, channel_NN, channel_EEG, freq)
            [train_data, test_data] = [x.reshape((x.shape[0], 1, num_channel, num_freq))
                                       for x in [train_data, test_data]]
        (train_dataTS, train_targetTS, test_dataTS, test_targetTS) = map(
            torch.from_numpy, (train_data, train_target, test_data, test_target))
        [train_dataset, test_dataset] = map(
            Data.TensorDataset,
            [train_dataTS.float(), test_dataTS.float()],
            [train_targetTS.float(), test_targetTS.float()])

        if not args.str_sampling:
            train_loader = Data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
        test_loader = Data.DataLoader(test_dataset, batch_size=args.batch_size)
        model_param = [train_data.shape]

    elif args.input_type == 'image':
        if args.ensemble:
            input_model_name = args.pre_model_name
        else:
            input_model_name = args.model_name
        assert (input_model_name in multiframe) == (args.num_time > 1)

        # Use a 224x224 input for vgg16/resnet50, 64x64 otherwise
        if input_model_name in ['vgg16', 'resnet50']:
            input_size = 224
        else:
            input_size = 64

        # Load data
        data_transforms = {
            'train': transforms.Compose([
                ndl.Rescale(input_size, args.num_time),
                ndl.ToTensor(args.num_time)]),
            'test': transforms.Compose([
                ndl.Rescale(input_size, args.num_time),
                ndl.ToTensor(args.num_time)])
        }

        print("Initializing Datasets and Dataloaders...")

        # Create training and testing datasets
        [train_dataset, test_dataset] = [
            ndl.TopoplotLoader(args.image_folder, x, args.num_time, data_transforms[x],
                               scale=args.scale_image, index_exp=index_exp,
                               index_split=index_split)
            for x in ['train', 'test']]

        # Create training and testing dataloaders
        if not args.str_sampling:
            train_loader = Data.DataLoader(train_dataset, batch_size=args.batch_size,
                                           shuffle=True, num_workers=4)
        test_loader = Data.DataLoader(test_dataset, batch_size=args.batch_size,
                                      shuffle=False, num_workers=4)
        model_param = [input_size]

    elif args.input_type == 'EEGLearn_img':
        # Load data
        with open('./EEGLearn_imgs/data1.data', 'rb') as fp:
            dict_data = pickle.load(fp)
        data, target = dict_data['data'], dict_data['target']
        input_size = data.shape[2]

        # Split data for cross validation
        if args.num_fold == 1:
            # Random state 15: training error becomes lower, testing error becomes higher
            train_data, test_data, train_target, test_target = train_test_split(
                data, target, test_size=0.1, random_state=23)
        else:
            kf = KFold(n_splits=args.num_fold, shuffle=True, random_state=23)
            for i, (train_index, test_index) in enumerate(kf.split(data)):
                if i == index_exp:
                    train_data, train_target = data[train_index, :], target[train_index]
                    test_data, test_target = data[test_index, :], target[test_index]

        (train_dataTS, train_targetTS, test_dataTS, test_targetTS) = map(
            torch.from_numpy, (train_data, train_target, test_data, test_target))
        [train_dataset, test_dataset] = map(
            Data.TensorDataset,
            [train_dataTS.float(), test_dataTS.float()],
            [train_targetTS.float(), test_targetTS.float()])

        if not args.str_sampling:
            train_loader = Data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
        test_loader = Data.DataLoader(test_dataset, batch_size=args.batch_size)

    # ------------ Create model ---------------
    if args.input_type in ['image', 'EEGLearn_img']:
        model_param = [input_size]
    else:
        model_param = [train_data.shape]

    if not args.ensemble:
        model = read_model(args.model_name, model_param)
    else:
        pre_models = []
        for i in range(args.num_split):
            pre_model = read_model(args.pre_model_name, model_param)
            pre_model.load_state_dict(
                torch.load('%s/last_model_exp%d_split%d.pt' % (args.ensemble, index_exp, i)))
            set_parameter_requires_grad(pre_model, True)
            pre_models.append(pre_model)
        model = models.__dict__[args.model_name](pre_models)
    print('Use model %s' % (args.model_name))

    # Run on GPU
    model = model.to(device=device)
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    # Define loss function (criterion) and optimizer
    if args.loss_type == 'L2':
        criterion = nn.MSELoss().to(device=device)
    elif args.loss_type == 'L1':
        criterion = nn.L1Loss().to(device=device)
    elif args.loss_type == 'L4':
        criterion = L4Loss
    elif args.loss_type == 'MyLoss':
        criterion = MyLoss
    print('Use %s loss' % (args.loss_type))

    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr_rate, momentum=0.9)
    # optimizer = torch.optim.Adam(model.parameters(), lr=args.lr_rate)

    # Record loss and accuracy of each epoch
    dict_error = {'train_std': list(range(args.num_epoch)), 'test_std': list(range(args.num_epoch)),
                  'train_mape': list(range(args.num_epoch)), 'test_mape': list(range(args.num_epoch))}

    # Optionally evaluate the trained model
    if args.evaluate:
        if args.resume:
            if os.path.isfile(args.resume):
                model.load_state_dict(torch.load(args.resume))
        _, target, pred, _, _ = validate(test_loader, model, criterion)
        plot_scatter(target, pred, dirName, fileName)
        return 0

    # Optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_error = checkpoint['best_error']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            dict_error['train_std'][:args.start_epoch] = checkpoint['dict_error']['train_std']
            dict_error['test_std'][:args.start_epoch] = checkpoint['dict_error']['test_std']
            print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # ------------- Train model ------------------
    for epoch in range(args.start_epoch, args.num_epoch):
        # Create dataloader if using stratified sampler
        if args.str_sampling:
            sampler = SubsetRandomSampler(get_indices_RSS(train_target, int(0.5 * len(train_target))))
            train_loader = Data.DataLoader(train_dataset, batch_size=args.batch_size,
                                           sampler=sampler, num_workers=4)

        # Learning rate decay
        if epoch in lr_step:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.1

        # Train for one epoch
        _, dict_error['train_std'][epoch], dict_error['train_mape'][epoch] = \
            train(train_loader, model, criterion, optimizer, epoch)

        # Evaluate on validation set
        _, _, _, std_error, dict_error['test_mape'][epoch] = validate(test_loader, model, criterion)
        dict_error['test_std'][epoch] = std_error

        # Remember best standard error and save checkpoint
        is_best = std_error < best_error
        best_error = min(std_error, best_error)
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_error': best_error,
            'optimizer': optimizer.state_dict(),
            'dict_error': dict_error
        }, is_best)

        # Save best and last models
        if is_best:
            torch.save(model.state_dict(),
                       './results/%s/best_model_exp%d_split%d.pt' % (dirName, index_exp, index_split))
        if epoch == args.num_epoch - 1:
            torch.save(model.state_dict(),
                       './results/%s/last_model_exp%d_split%d.pt' % (dirName, index_exp, index_split))

    # Plot error curve
    plot_error(dict_error, dirName, fileName)

    # Plot scatter plots
    _, target, pred, _, _ = validate(test_loader, model, criterion)
    plot_scatter(target, pred, dirName, fileName)
    dict_error['target'], dict_error['pred'] = target, pred

    # Plot histogram of true and predicted values
    import matplotlib.pyplot as plt
    plt.hist(target, label='True')
    plt.hist(pred, label='Pred')
    plt.legend(loc='upper right')
    plt.savefig('./results/hist.png')

    # Save error over epochs
    with open('./results/%s/%s.data' % (dirName, fileName), 'wb') as fp:
        pickle.dump(dict_error, fp)
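# Hypothetical driver (an assumption, not code from the original script):
# sweep the CV folds and ensemble splits referenced above, assuming `args`
# was parsed at module level as main() requires.
if __name__ == '__main__':
    for index_exp in range(args.num_fold):
        for index_split in range(args.num_split):
            main(index_exp, index_split)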
    axs[1].set_title('r = %.3f' % (np.corrcoef(true, pred)[0, 1]))
    std = mean_squared_error(true, pred) ** 0.5
    fig.suptitle('Standard error: %.3f' % (std))

    if fileName is not None:
        plt.savefig('./results/classical/%s_scatter.png' % (fileName))
    plt.close()


if __name__ == '__main__':
    # Load data
    X, Y, C, S = raw_dataloader.read_data([1, 2, 3], range(11), channel_limit=21, rm_baseline=True)

    # Leave one subject out
    dict_error = {'train_std': np.zeros((11, 5)), 'test_std': np.zeros((11, 5))}
    for i_base in range(11):
        print('----- Subject %d -----' % (i_base))
        lst_model = LST(11, i_base)
        indices_base, indices_other = np.where(S == i_base)[0], np.where(S != i_base)[0]
        base_data, base_target, base_sub = X[indices_base, :], Y[indices_base], S[indices_base]
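# Illustrative aside (an assumption, not the original implementation): the
# same leave-one-subject-out partition can be generated with sklearn's
# LeaveOneGroupOut, using the subject labels S as groups.
from sklearn.model_selection import LeaveOneGroupOut

for indices_other, indices_base in LeaveOneGroupOut().split(X, Y, groups=S):
    # indices_base holds one held-out subject's trials, matching
    # np.where(S == i_base)[0] above; indices_other holds all other subjects.
    pass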
    global args, device
    args = parser.parse_args()

    if args.clf_model in ['pcafc', 'pcafc_sd']:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print(device)

    # Create folder for results of this model
    if not os.path.exists('./results/%s' % (args.dirName)):
        os.makedirs('./results/%s' % (args.dirName))

    # Load data
    if args.input_type == 'signal':
        data, SLs, _, S, D = raw_dataloader.read_data([1, 2, 3], range(11),
                                                      channel_limit=21, rm_baseline=True)
        # data = np.random.rand(data.shape[0], data.shape[1], data.shape[2])
    elif args.input_type == 'ERSP':
        with open('./raw_data/ERSP_from_raw_100_channel21.data', 'rb') as fp:
            dict_ERSP = pickle.load(fp)
        data, SLs, S, D = (dict_ERSP['ERSP'], dict_ERSP['SLs'],
                           dict_ERSP['Sub_ID'], dict_ERSP['D'])
        data, SLs = preprocessing.standardize(data, SLs, threshold=0.0)
    elif args.input_type == 'bp_ratio':
        data, SLs, _, S, D = raw_dataloader.read_data([1, 2, 3], range(11),
                                                      channel_limit=21, rm_baseline=True)
        low, high = [4, 7, 13], [7, 13, 30]
        data = bandpower.get_bandpower(data, low=low, high=high)
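# Minimal sketch of what a band-power step like bandpower.get_bandpower might
# do (an assumption; the project's own implementation may differ): estimate
# the PSD per channel with Welch's method, then integrate it over each
# (low, high) band. fs=256 Hz is an assumed sampling rate.
import numpy as np
from scipy.signal import welch

def bandpower_sketch(x, fs=256, low=(4, 7, 13), high=(7, 13, 30)):
    """x: (trial, channel, time) -> (trial, channel, band) band powers."""
    freqs, psd = welch(x, fs=fs, axis=-1)       # psd: (trial, channel, freq)
    bands = []
    for lo, hi in zip(low, high):
        mask = (freqs >= lo) & (freqs < hi)
        bands.append(np.trapz(psd[..., mask], freqs[mask], axis=-1))
    return np.stack(bands, axis=-1)             # theta, alpha, beta power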
        'Sub_ID': subjects,
        'D': D
    }
    with open(savePath, 'wb') as fp:
        pickle.dump(dict_ERSP, fp)

    return new_f, t, Zxx


if __name__ == '__main__':
    channel_limit = 21

    # Save data for all subjects
    X, Y_reg, channels, S, D = dl.read_data([1, 2, 3], list(range(11)),
                                            channel_limit=channel_limit, rm_baseline=True)
    freq, t, Zxx = STFT(X, Y_reg, S, D, 2, 30,
                        savePath='./raw_data/ERSP_from_raw_100_channel%d_nolog.data' % (channel_limit))

    '''
    print('Calculate conditional entropy...')
    _ = add_features.calculate_CE(X, './raw_data/CE_sub100_channel%d.data'%(channel_limit))

    # Save data for each subject
    print(CE_X_Y)
    CE_Y_X = get_conditional_entropy(Y, Y, X, X)
    print(CE_Y_X)
    print('Takes %.3f seconds' % (time.time() - start_time))
    '''

    '''
    # Load preprocessed ERSP data
    ERSP_all, tmp_all, freqs = dataloader.load_data()
    ERSP_all, tmp_all = preprocessing.remove_trials(ERSP_all, tmp_all, 60)
    ERSP_all, _ = preprocessing.standardize(ERSP_all, tmp_all)
    correlation_all = get_correlations(ERSP_all)
    '''

    # Load raw data and compute conditional entropy per subject
    for i in range(11):
        X, _, Y_reg, channels = raw_dataloader.read_data([1, 2, 3], date=[i],
                                                         pred_type='class', rm_baseline=True)
        X, Y_reg = preprocessing.remove_trials(X, Y_reg, 60)
        _ = calculate_CE(X, './raw_data/CE_sub%d' % (i))

    # ... and for all subjects pooled together
    X, _, Y_reg, channels = raw_dataloader.read_data([1, 2, 3], date=range(11),
                                                     pred_type='class', rm_baseline=True)
    X, Y_reg = preprocessing.remove_trials(X, Y_reg, 60)
    _ = calculate_CE(X, './raw_data/CE_sub100')
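# Minimal sketch of a discrete conditional-entropy estimate, the quantity the
# calculate_CE / get_conditional_entropy calls above appear to compute (an
# assumption; the project's estimator and its 4-argument signature may bin or
# smooth differently). Uses H(X|Y) = H(X, Y) - H(Y) with histogram estimates.
import numpy as np

def conditional_entropy_sketch(x, y, bins=16):
    """Estimate H(X|Y) in bits for two 1-D samples x and y."""
    joint, _, _ = np.histogram2d(x, y, bins=bins)
    p_xy = joint / joint.sum()                  # joint distribution P(X, Y)
    p_y = p_xy.sum(axis=0)                      # marginal P(Y)
    h_xy = -np.sum(p_xy[p_xy > 0] * np.log2(p_xy[p_xy > 0]))
    h_y = -np.sum(p_y[p_y > 0] * np.log2(p_y[p_y > 0]))
    return h_xy - h_y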