Code example #1
File: train_tcGAN.py  Project: hundredball/Math24
def main(index_exp):

    # ----- Wrap up dataloader -----
    # Load data
    X, Y, _ = rdl.read_data([1, 2, 3],
                            list(range(11)),
                            channel_limit=12,
                            rm_baseline=True)

    # Remove trials
    X, Y = preprocessing.remove_trials(X, Y, threshold=60)

    # Downsample to 64 Hz
    X = decimate(X, 4, axis=2)
    print('> After downsampling, shape of X: ', X.shape)

    # Split data for cross validation
    kf = KFold(n_splits=10, shuffle=True, random_state=23)
    for i, (train_index, test_index) in enumerate(kf.split(X)):
        if i == index_exp:
            train_data, train_target = X[train_index, :], Y[train_index]
            test_data, test_target = X[test_index, :], Y[test_index]

    # (sample, channel, time) -> (sample, 1, channel, time)
    train_data, test_data = [x.reshape((x.shape[0], 1, x.shape[1], x.shape[2]))
                             for x in (train_data, test_data)]

    (train_dataTS, train_targetTS, test_dataTS, test_targetTS) = map(
        torch.from_numpy, (train_data, train_target, test_data, test_target))
    train_dataset = Data.TensorDataset(train_dataTS.float(), train_targetTS.float())
    test_dataset = Data.TensorDataset(test_dataTS.float(), test_targetTS.float())

    train_loader = Data.DataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=True)
    test_loader = Data.DataLoader(test_dataset, batch_size=args.batch_size)

    # ----- Create model -----
    gen = tcGAN.generator().to(device=device)
    dis = tcGAN.discregressor().to(device=device)

    if torch.cuda.device_count() > 1:
        gen = nn.DataParallel(gen)
        dis = nn.DataParallel(dis)

    # ----- Train model -----
    trainer = Trainer(gen, dis, train_loader, test_loader)
    trainer.train(args.num_epoch)
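
As a side note, the numpy-to-DataLoader wrapping above follows a common PyTorch pattern; a minimal self-contained sketch of the same steps, with made-up array shapes rather than the project's real data, looks like this:

import numpy as np
import torch
import torch.utils.data as Data

# Synthetic EEG-like trials: (sample, channel, time) plus one regression target per trial
X = np.random.randn(100, 12, 128).astype(np.float32)
Y = np.random.rand(100).astype(np.float32)

# Add the singleton "image channel" axis expected by 2-D conv models: (sample, 1, channel, time)
X = X.reshape((X.shape[0], 1, X.shape[1], X.shape[2]))

dataset = Data.TensorDataset(torch.from_numpy(X), torch.from_numpy(Y))
loader = Data.DataLoader(dataset, batch_size=32, shuffle=True)

for batch_data, batch_target in loader:
    print(batch_data.shape, batch_target.shape)  # e.g. [32, 1, 12, 128] and [32]
    break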
Code example #2
def main(index_exp, index_split):
    
    faulthandler.enable()
    torch.cuda.empty_cache()
    
    best_error = 100
    lr_step = [40, 70, 120]
    multiframe = ['convlstm', 'convfc']
    dirName = '%s_data%d_%s_%s_%s'%(args.model_name, args.data_cate, args.augmentation, args.loss_type, args.file_name)
    fileName = '%s_split%d_exp%d'%(dirName, index_split, index_exp)
    
    # Create folder for results of this model
    if not os.path.exists('./results/%s'%(dirName)):
        os.makedirs('./results/%s'%(dirName))
    
    # ------------- Wrap up dataloader -----------------
    if args.input_type == 'signal':
        X, Y_reg, C = raw_dataloader.read_data([1,2,3], list(range(11)), channel_limit=21, rm_baseline=True)
        num_channel = X.shape[1]
        num_feature = X.shape[2]     # Number of time sample
        
        # Remove trials
        X, Y_reg = preprocessing.remove_trials(X, Y_reg, threshold=60)
        
        # Split data for cross validation
        if args.num_fold == 1:
            train_data, test_data, train_target, test_target = train_test_split(X, Y_reg, test_size=0.1, random_state=23)
            # Random state 15: training error becomes lower, testing error becomes higher
        else:
            kf = KFold(n_splits=args.num_fold, shuffle=True, random_state=23)
            for i, (train_index, test_index) in enumerate(kf.split(X)):
                if i == index_exp:
                    train_data, train_target = X[train_index, :], Y_reg[train_index]
                    test_data, test_target = X[test_index, :], Y_reg[test_index]
                    
        # Split data for ensemble methods
        if not args.ensemble:
            if args.num_split > 1:
                data_list, target_list = preprocessing.stratified_split(train_data, train_target, n_split=args.num_split, mode=args.split_mode)
                train_data, train_target = data_list[index_split], target_list[index_split]
                '''
                kf = KFold(n_splits=args.num_split, shuffle=True, random_state=32)
                for i, (other_index, split_index) in enumerate(kf.split(train_data)):
                    if i == index_split:
                        train_data, train_target = train_data[split_index, :], train_target[split_index]
                '''
        # Normalize the data
        if args.normalize:
            train_data, test_data = preprocessing.normalize(train_data, test_data)
        
                    
        # Data augmentation
        if args.augmentation == 'overlapping':
            train_data, train_target = data_augmentation.aug(train_data, train_target, args.augmentation,
                                                             (256, 64, 128))
            test_data, test_target = data_augmentation.aug(test_data, test_target, args.augmentation,
                                                             (256, 64, 128))
        elif args.augmentation in ('add_noise', 'add_noise_minority'):
            train_data, train_target = data_augmentation.aug(train_data, train_target, args.augmentation,
                                                             (30, 1))
        elif args.augmentation == 'SMOTER':
            train_data, train_target = data_augmentation.aug(train_data, train_target, args.augmentation)
            
        # scale data
        if args.scale_data:
            train_data, test_data = train_data.reshape((train_data.shape[0],-1)), test_data.reshape((test_data.shape[0],-1))
            train_data, test_data = preprocessing.scale(train_data, test_data)
            train_data = train_data.reshape((train_data.shape[0],num_channel, -1))
            test_data = test_data.reshape((test_data.shape[0],num_channel, -1))
            
        if args.model_name in ['eegnet', 'eegnet_trans_signal']:
            # (sample, channel, time) -> (sample, channel_NN, channel_EEG, time)
            train_data, test_data = [x.reshape((x.shape[0], 1, num_channel, num_feature))
                                     for x in (train_data, test_data)]
        
        
        (train_dataTS, train_targetTS, test_dataTS, test_targetTS) = map(
                torch.from_numpy, (train_data, train_target, test_data, test_target))
        train_dataset = Data.TensorDataset(train_dataTS.float(), train_targetTS.float())
        test_dataset = Data.TensorDataset(test_dataTS.float(), test_targetTS.float())

        if not args.str_sampling:
            train_loader = Data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
        test_loader = Data.DataLoader(test_dataset, batch_size=args.batch_size)
        
        model_param = [train_data.shape]
        
    elif args.input_type == 'power':
        if args.data_cate == 1:
            ERSP_all, tmp_all, freqs = dataloader.load_data()
        elif args.data_cate == 2:
            data_file = './raw_data/ERSP_from_raw_%d_channel21.data'%(args.index_sub)
            with open(data_file, 'rb') as fp:
                dict_ERSP = pickle.load(fp)
            ERSP_all, tmp_all = dict_ERSP['ERSP'], dict_ERSP['SLs']
        num_channel = ERSP_all.shape[1]
        num_freq = ERSP_all.shape[2]
            
        # Remove trials
        ERSP_all, tmp_all = preprocessing.remove_trials(ERSP_all, tmp_all, threshold=60)
        
        # Split data for cross validation
        if args.num_fold == 1:
            train_data, test_data, train_target, test_target = train_test_split(ERSP_all, tmp_all[:,2], test_size=0.1, random_state=23)
        else:
            kf = KFold(n_splits=args.num_fold, shuffle=True, random_state=23)
            for i, (train_index, test_index) in enumerate(kf.split(ERSP_all)):
                if i == index_exp:
                    train_data, test_data = ERSP_all[train_index, :], ERSP_all[test_index, :]
                    if args.data_cate == 2:
                        train_target, test_target = tmp_all[train_index], tmp_all[test_index]
                    else:
                        train_target, test_target = tmp_all[train_index, 2], tmp_all[test_index, 2]
                        
                    if args.add_CE:
                        assert args.data_cate == 2
                        with open('./raw_data/CE_sub%d'%(args.index_sub), 'rb') as fp:
                            CE = pickle.load(fp)
                        CE_train, CE_test = CE[train_index,:], CE[test_index,:]
                        # PCA for CE
                        pca = PCA(n_components=10)
                        pca.fit(CE_train)
                        CE_train, CE_test = pca.transform(CE_train), pca.transform(CE_test)
                        
                    
        # Split data for ensemble methods
        if not args.ensemble:
            if args.num_split > 1:
                data_list, target_list = preprocessing.stratified_split(train_data, train_target, n_split=args.num_split, mode=args.split_mode)
                train_data, train_target = data_list[index_split], target_list[index_split]
                '''
                kf = KFold(n_splits=args.num_split, shuffle=True, random_state=32)
                for i, (other_index, split_index) in enumerate(kf.split(np.arange(len(train_data)))):
                    if i == index_split:
                        train_data, train_target = train_data[split_index, :], train_target[split_index]
                '''
                    
        # Concatenate train and test for standardizing
        data = np.concatenate((train_data, test_data), axis=0)
        target = np.concatenate((train_target, test_target))
                    
        # Standardize data
        num_train = len(train_data)
        data, target = preprocessing.standardize(data, target, train_indices = np.arange(num_train), threshold=0.0)
        data = data.reshape((data.shape[0], -1))
        
        # Scale target between 0 and 1
        if args.post_scale:
            print('Scale the target between 0-1')
            target = target/60
        
        # Split data
        train_data, test_data = data[:num_train, :], data[num_train:, :]
        train_target, test_target = target[:num_train], target[num_train:]
        
        # Data augmentation
        if args.augmentation == 'SMOTER':
            train_data, train_target = data_augmentation.aug(train_data, train_target, args.augmentation)
        
        # center data
        if args.center_flag:
            train_data, test_data = preprocessing.center(train_data, test_data)
            
        # scale data
        if args.scale_data:
            train_data, test_data = preprocessing.scale(train_data, test_data)
            
        # Add conditional entropy
        if args.add_CE:
            train_data = np.concatenate((train_data, CE_train), axis=1)
            test_data = np.concatenate((test_data, CE_test), axis=1)
            
        if args.model_name == 'eegnet_trans_power':
            # (sample, channel, freq) -> (sample, channel_NN, channel_EEG, freq)
            train_data, test_data = [x.reshape((x.shape[0], 1, num_channel, num_freq))
                                     for x in (train_data, test_data)]
        
        (train_dataTS, train_targetTS, test_dataTS, test_targetTS) = map(
                torch.from_numpy, (train_data, train_target, test_data, test_target))
        train_dataset = Data.TensorDataset(train_dataTS.float(), train_targetTS.float())
        test_dataset = Data.TensorDataset(test_dataTS.float(), test_targetTS.float())

        if not args.str_sampling:
            train_loader = Data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
        test_loader = Data.DataLoader(test_dataset, batch_size=args.batch_size)
        
        model_param = [train_data.shape]
        
    elif args.input_type == 'image':
        
        if args.ensemble:
            input_model_name = args.pre_model_name
        else:
            input_model_name = args.model_name
        
        assert (input_model_name in multiframe) == (args.num_time>1)
        
        # Let input size be 224x224 if the model is vgg16
        if input_model_name in ['vgg16', 'resnet50']:
            input_size = 224
        else:
            input_size = 64
            
        # Load Data
        data_transforms = {
                'train': transforms.Compose([
                        ndl.Rescale(input_size, args.num_time),
                        ndl.ToTensor(args.num_time)]), 
                'test': transforms.Compose([
                        ndl.Rescale(input_size, args.num_time),
                        ndl.ToTensor(args.num_time)])
                }

        print("Initializing Datasets and Dataloaders...")

        # Create training and testing datasets
        # image_datasets = {x: ndl.TopoplotLoader(args.image_folder, x, args.num_time, data_transforms[x],
        #                 scale=args.scale_image, index_exp=index_exp, index_split=index_split) for x in ['train', 'test']}
        [train_dataset,test_dataset] = [ndl.TopoplotLoader(args.image_folder, x, args.num_time, data_transforms[x],
                        scale=args.scale_image, index_exp=index_exp, index_split=index_split) for x in ['train', 'test']]

        # Create training and testing dataloaders
        # if not args.str_sampling:
        #     train_loader = Data.DataLoader(image_datasets['train'], batch_size=args.batch_size, shuffle=True, num_workers=4)
        # test_loader = Data.DataLoader(image_datasets['test'], batch_size=args.batch_size, shuffle=False, num_workers=4)
        if not args.str_sampling:
            train_loader = Data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)
        test_loader = Data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)
        model_param = [input_size]
        
    elif args.input_type == 'EEGLearn_img':
        
        # Load data
        with open('./EEGLearn_imgs/data1.data', 'rb') as fp:
            dict_data = pickle.load(fp)
        data, target = dict_data['data'], dict_data['target']
        input_size = data.shape[2]
        
        # Split data for cross validation
        if args.num_fold == 1:
            train_data, test_data, train_target, test_target = train_test_split(data, target, test_size=0.1, random_state=23)
            # Random state 15: training error becomes lower, testing error becomes higher
        else:
            kf = KFold(n_splits=args.num_fold, shuffle=True, random_state=23)
            for i, (train_index, test_index) in enumerate(kf.split(data)):
                if i == index_exp:
                    train_data, train_target = data[train_index, :], target[train_index]
                    test_data, test_target = data[test_index, :], target[test_index]
        
        (train_dataTS, train_targetTS, test_dataTS, test_targetTS) = map(
                torch.from_numpy, (train_data, train_target, test_data, test_target))
        train_dataset = Data.TensorDataset(train_dataTS.float(), train_targetTS.float())
        test_dataset = Data.TensorDataset(test_dataTS.float(), test_targetTS.float())

        if not args.str_sampling:
            train_loader = Data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
        test_loader = Data.DataLoader(test_dataset, batch_size=args.batch_size)
        
        
    # ------------ Create model ---------------
    if args.input_type in ['image','EEGLearn_img']:
        model_param = [input_size]
    else:
        model_param = [train_data.shape]
    
    if not args.ensemble:
        model = read_model(args.model_name, model_param)
    else:
        pre_models = []
        for i in range(args.num_split):
            pre_model = read_model(args.pre_model_name, model_param)
            pre_model.load_state_dict( torch.load('%s/last_model_exp%d_split%d.pt'%(args.ensemble, index_exp, i)) )
            set_parameter_requires_grad(pre_model, True)
            pre_models.append(pre_model)
            
        model = models.__dict__[args.model_name](pre_models)
        
    print('Use model %s'%(args.model_name))
        
    # Run on GPU
    model = model.to(device=device)
    
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'L2':
        criterion = nn.MSELoss().to(device=device)
    elif args.loss_type == 'L1':
        criterion = nn.L1Loss().to(device=device)
    elif args.loss_type == 'L4':
        criterion = L4Loss
    elif args.loss_type == 'MyLoss':
        criterion = MyLoss
    print('Use %s loss'%(args.loss_type))
    
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr_rate,momentum=0.9)
    #optimizer = torch.optim.Adam(model.parameters(), lr=args.lr_rate)
    
    # Record loss and accuracy of each epoch
    dict_error = {'train_std': list(range(args.num_epoch)), 'test_std': list(range(args.num_epoch)),
                  'train_mape': list(range(args.num_epoch)), 'test_mape': list(range(args.num_epoch))}
    
    # optionally evaluate the trained model
    if args.evaluate:
        if args.resume:
            if os.path.isfile(args.resume):
                model.load_state_dict(torch.load(args.resume))
        
        _, target, pred, _, _ = validate(test_loader, model, criterion)
        plot_scatter(target, pred, dirName, fileName)
        return 0
    
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_error = checkpoint['best_error']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            dict_error['train_std'][:args.start_epoch] = checkpoint['dict_error']['train_std']
            dict_error['test_std'][:args.start_epoch] = checkpoint['dict_error']['test_std']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    
    # ------------- Train model ------------------

    for epoch in range(args.start_epoch, args.num_epoch):
        # Create dataloader if using stratified sampler
        if args.str_sampling:
            sampler = SubsetRandomSampler(get_indices_RSS(train_target, int(0.5*len(train_target))))
            train_loader = Data.DataLoader(train_dataset, batch_size=args.batch_size, \
                                           sampler=sampler, num_workers=4)
            
        # Learning rate decay
        if epoch in lr_step:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.1
        
        # train for one epoch
        _, dict_error['train_std'][epoch], dict_error['train_mape'][epoch] = \
            train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        _, _, _, std_error, dict_error['test_mape'][epoch] = validate(test_loader, model, criterion)
        dict_error['test_std'][epoch] = std_error

        # remember best standard error and save checkpoint
        is_best = std_error < best_error
        best_error = min(std_error, best_error)
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_error': best_error,
            'optimizer': optimizer.state_dict(),
            'dict_error': dict_error
        }, is_best)
        
        # Save best model
        if is_best:
            torch.save(model.state_dict(), './results/%s/best_model_exp%d_split%d.pt'%(dirName, index_exp, index_split))
        if epoch == args.num_epoch-1:
            torch.save(model.state_dict(), './results/%s/last_model_exp%d_split%d.pt'%(dirName, index_exp, index_split))
    # Plot error curve
    plot_error(dict_error, dirName, fileName)
    
    # Plot scatter plots
    _, target, pred, _, _ = validate(test_loader, model, criterion)
    plot_scatter(target, pred, dirName, fileName)
    dict_error['target'], dict_error['pred'] = target, pred
    
    # Plot histogram
    import matplotlib.pyplot as plt
    plt.hist(target, label = 'True')
    plt.hist(pred, label = 'Pred')
    plt.legend(loc='upper right')
    plt.savefig('./results/hist.png')
    
    # Save error over epochs
    with open('./results/%s/%s.data'%(dirName, fileName), 'wb') as fp:
        pickle.dump(dict_error, fp)
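
The stratified per-epoch sampling in the training loop above depends on the project's get_indices_RSS helper, which is not shown here. A minimal sketch of the same DataLoader-with-sampler pattern, with a plain random half-subset standing in for that helper:

import numpy as np
import torch
import torch.utils.data as Data
from torch.utils.data import SubsetRandomSampler

dataset = Data.TensorDataset(torch.randn(200, 1, 21, 128), torch.rand(200))

# Stand-in for get_indices_RSS: pick half of the trials at random for this epoch
indices = np.random.choice(len(dataset), size=len(dataset) // 2, replace=False)
sampler = SubsetRandomSampler(indices)

# Note: shuffle must be left off when an explicit sampler is passed
loader = Data.DataLoader(dataset, batch_size=64, sampler=sampler)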
Code example #3
File: LSTransform.py  Project: hundredball/Math24
    axs[1].set_title('r = %.3f' % (np.corrcoef(true, pred)[0, 1]))

    std = mean_squared_error(true, pred)**0.5
    fig.suptitle('Standard error: %.3f' % (std))

    if fileName is not None:
        plt.savefig('./results/classical/%s_scatter.png' % (fileName))

    plt.close()


if __name__ == '__main__':

    # Load data
    X, Y, C, S = raw_dataloader.read_data([1, 2, 3],
                                          range(11),
                                          channel_limit=21,
                                          rm_baseline=True)

    # Leave one subject out
    dict_error = {
        'train_std': np.zeros((11, 5)),
        'test_std': np.zeros((11, 5))
    }
    for i_base in range(11):
        print('----- Subject %d -----' % (i_base))

        lst_model = LST(11, i_base)
        indices_base, indices_other = np.where(S == i_base)[0], np.where(S != i_base)[0]
        base_data, base_target, base_sub = X[indices_base, :], Y[indices_base], S[indices_base]
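
The loop above performs a leave-one-subject-out split via the subject-ID array S; a minimal sketch of that split with synthetic stand-ins for the read_data outputs:

import numpy as np

# Synthetic stand-ins: 50 trials, 21 channels, 128 time samples, subject IDs 0-10
X = np.random.randn(50, 21, 128)
Y = np.random.rand(50)
S = np.random.randint(0, 11, size=50)

i_base = 3  # subject held out as the "base" set
indices_base, indices_other = np.where(S == i_base)[0], np.where(S != i_base)[0]

base_data, base_target = X[indices_base], Y[indices_base]
other_data, other_target = X[indices_other], Y[indices_other]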
Code example #4
    global args, device

    args = parser.parse_args()

    if args.clf_model in ['pcafc', 'pcafc_sd']:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print(device)

    # Create folder for results of this model
    if not os.path.exists('./results/%s' % (args.dirName)):
        os.makedirs('./results/%s' % (args.dirName))

    # Load data
    if args.input_type == 'signal':
        data, SLs, _, S, D = raw_dataloader.read_data([1, 2, 3],
                                                      range(11),
                                                      channel_limit=21,
                                                      rm_baseline=True)
        #data = np.random.rand(data.shape[0], data.shape[1], data.shape[2])
    elif args.input_type == 'ERSP':
        with open('./raw_data/ERSP_from_raw_100_channel21.data', 'rb') as fp:
            dict_ERSP = pickle.load(fp)
        data, SLs, S, D = (dict_ERSP['ERSP'], dict_ERSP['SLs'],
                           dict_ERSP['Sub_ID'], dict_ERSP['D'])
        data, SLs = preprocessing.standardize(data, SLs, threshold=0.0)
    elif args.input_type == 'bp_ratio':
        data, SLs, _, S, D = raw_dataloader.read_data([1, 2, 3],
                                                      range(11),
                                                      channel_limit=21,
                                                      rm_baseline=True)
        low, high = [4, 7, 13], [7, 13, 30]
        data = bandpower.get_bandpower(data, low=low, high=high)
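
bandpower.get_bandpower is project code whose exact behavior is not shown here; as a rough, hypothetical sketch of a band-power feature for the same theta/alpha/beta band edges, a Welch-PSD version could look like the following (the 256 Hz sampling rate is an assumption, and any normalization into ratios is omitted):

import numpy as np
from scipy.signal import welch
from scipy.integrate import simpson

def bandpower_sketch(data, fs=256, low=(4, 7, 13), high=(7, 13, 30)):
    """data: (trial, channel, time) -> (trial, channel, band) power per frequency band."""
    freqs, psd = welch(data, fs=fs, nperseg=min(fs, data.shape[-1]), axis=-1)
    bands = []
    for lo, hi in zip(low, high):
        mask = (freqs >= lo) & (freqs < hi)
        bands.append(simpson(psd[..., mask], x=freqs[mask], axis=-1))
    return np.stack(bands, axis=-1)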
Code example #5
            'Sub_ID': subjects,
            'D': D
        }
        with open(savePath, 'wb') as fp:
            pickle.dump(dict_ERSP, fp)

    return new_f, t, Zxx


if __name__ == '__main__':

    channel_limit = 21

    # Save data for all subject
    X, Y_reg, channels, S, D = dl.read_data([1, 2, 3],
                                            list(range(11)),
                                            channel_limit=channel_limit,
                                            rm_baseline=True)
    freq, t, Zxx = STFT(
        X,
        Y_reg,
        S,
        D,
        2,
        30,
        savePath='./raw_data/ERSP_from_raw_100_channel%d_nolog.data' %
        (channel_limit))
    '''
    print('Calculate conditional entropy...')
    _ = add_features.calculate_CE(X, './raw_data/CE_sub100_channel%d.data'%(channel_limit))
    
    # Save data for each subject
Code example #6
    print(CE_X_Y)
    CE_Y_X = get_conditional_entropy(Y, Y, X, X)
    print(CE_Y_X)
    
    print('Takes %.3f seconds'%(time.time()-start_time))
    '''
    '''
    # Load preprocessed ERSP data
    ERSP_all, tmp_all, freqs = dataloader.load_data()
    
    ERSP_all, tmp_all = preprocessing.remove_trials(ERSP_all, tmp_all, 60)
    ERSP_all, _ = preprocessing.standardize(ERSP_all, tmp_all)
    
    correlation_all = get_correlations(ERSP_all)
    '''

    # Load raw data
    for i in range(11):
        X, _, Y_reg, channels = raw_dataloader.read_data([1, 2, 3],
                                                         date=[i],
                                                         pred_type='class',
                                                         rm_baseline=True)
        X, Y_reg = preprocessing.remove_trials(X, Y_reg, 60)
        _ = calculate_CE(X, './raw_data/CE_sub%d' % (i))

    X, _, Y_reg, channels = raw_dataloader.read_data([1, 2, 3],
                                                     date=range(11),
                                                     pred_type='class',
                                                     rm_baseline=True)
    X, Y_reg = preprocessing.remove_trials(X, Y_reg, 60)
    _ = calculate_CE(X, './raw_data/CE_sub100')