def main():
    """Train and evaluate a channel-selection network on SEEG gesture data.

    The function runs three stages in order:

    1. Load the preprocessed recording of participant ``sid``, standardise
       the SEEG columns, epoch the recording on the ``stim1`` events and cut
       every epoch into fixed-size windows that feed the train / validation /
       test ``DataLoader`` objects.
    2. Train a ``SelectionNet`` with the selection layer unfrozen while the
       temperature and threshold follow exponential schedules. Early stopping
       is only armed once the mean selection entropy is at or below
       ``args.entropy_lim``.
    3. For each of ``num_subjects`` subjects, reload the stored model, freeze
       the selection layer, fine-tune on ``within_subject_loader_HGD`` data
       and collect train / validation / test accuracies.

    Takes no arguments and returns ``None``. Hyper-parameters are read from the
    module-level ``parser``; figures, checkpoints and models are written below
    ``dpath``; summary statistics are printed.

    NOTE: the trial split uses the ``random`` module before any seed is set in
    this function, so the split differs between runs unless ``random`` is
    seeded elsewhere.
    """
    global args, enable_cuda

    ############## seeg data ##########
    sid = 10  # 4
    fs = 1000
    class_number = 5
    Session_num, UseChn, EmgChn, TrigChn, activeChan = get_channel_setting(sid)

    loadPath = data_dir + 'preprocessing' + '/P' + str(
        sid) + '/preprocessing2.mat'
    mat = hdf5storage.loadmat(loadPath)
    data = mat['Datacell']
    # Stack the two stored sessions along the first (sample) axis.
    data = np.concatenate((data[0, 0], data[0, 1]), 0)
    del mat

    # Standardisation: the last three columns are stim0 / emg / stim1 (see
    # chn_names below) and must keep their raw values, so only the leading
    # SEEG columns are scaled.
    aux_columns = data[:, -3:]
    seeg_columns = StandardScaler().fit_transform(data[:, :-3])
    data = np.concatenate((seeg_columns, aux_columns), axis=1)

    # stim0 is trigger channel, stim1 is trigger position calculated from EMG signal.
    chn_names = np.append(["seeg"] * len(UseChn), ["stim0", "emg", "stim1"])
    chn_types = np.append(["seeg"] * len(UseChn), ["stim", "emg", "stim"])
    info = mne.create_info(ch_names=list(chn_names),
                           ch_types=list(chn_types),
                           sfreq=fs)
    raw = mne.io.RawArray(data.transpose(), info)

    # gesture/events type: 1,2,3,4,5. Only the EMG-derived stim1 events are
    # used for epoching.
    events1 = mne.find_events(raw, stim_channel='stim1')
    # events number should start from 0: 0,1,2,3,4, instead of 1,2,3,4,5
    events1 = events1 - [0, 0, 1]

    # Epoch from stim1 (tmin=0) until 4 s after it (tmax=4); no baseline
    # correction is applied.
    raw = raw.pick(["seeg"])
    epochs = mne.Epochs(raw, events1, tmin=0, tmax=4, baseline=None)

    # One (trials, channels, samples) array per gesture class, selected by the
    # zero-based event code.
    list_of_epochs = [
        epochs[str(label)].get_data() for label in range(class_number)
    ]

    # validate=test=2 trials per class; everything else is training data.
    trial_number = [list(range(epochi.shape[0])) for epochi in list_of_epochs
                    ]  # [ [0,1,2,...],[0,1,2,...],... ]
    test_trials = [random.sample(trials, 2) for trials in trial_number]
    trial_number_left = [
        np.setdiff1d(trial_number[i], test_trials[i])
        for i in range(class_number)
    ]
    val_trials = [
        random.sample(list(remaining), 2) for remaining in trial_number_left
    ]
    train_trials = [
        np.setdiff1d(trial_number_left[i], val_trials[i]).tolist()
        for i in range(class_number)
    ]

    # no missing trials
    assert [
        sorted(test_trials[i] + val_trials[i] + train_trials[i])
        for i in range(class_number)
    ] == trial_number

    test_epochs = [
        epochi[test_trials[clas], :, :]
        for clas, epochi in enumerate(list_of_epochs)
    ]
    val_epochs = [
        epochi[val_trials[clas], :, :]
        for clas, epochi in enumerate(list_of_epochs)
    ]
    train_epochs = [
        epochi[train_trials[clas], :, :]
        for clas, epochi in enumerate(list_of_epochs)
    ]

    wind = 500
    stride = 500

    def _windowed_set(epoch_groups):
        # Slide a window over every class' epochs and return the stacked
        # windows plus a (n_windows, 1) label column. Labels are concatenated
        # per class so classes may contribute different numbers of windows.
        windows = []
        labels = []
        for clas, epochi in enumerate(epoch_groups):
            Xi, y = slide_epochs(epochi, clas, wind, stride)
            assert Xi.shape[0] == len(y)
            windows.append(Xi)
            labels.append(np.asarray(y).reshape(-1))
        return (np.concatenate(windows, axis=0),
                np.concatenate(labels).reshape(-1, 1))

    X_test, y_test = _windowed_set(test_epochs)
    X_val, y_val = _windowed_set(val_epochs)
    X_train, y_train = _windowed_set(train_epochs)

    train_set = myDataset(X_train, y_train)
    val_set = myDataset(X_val, y_val)
    test_set = myDataset(X_test, y_test)

    batch_size = 32
    train_loader = DataLoader(dataset=train_set,
                              batch_size=batch_size,
                              shuffle=True,
                              pin_memory=False)
    val_loader = DataLoader(dataset=val_set,
                            batch_size=batch_size,
                            shuffle=True,
                            pin_memory=False)
    test_loader = DataLoader(dataset=test_set,
                             batch_size=batch_size,
                             shuffle=True,
                             pin_memory=False)

    ########## end  seeg #########################
    ################################################################ INIT #################################################################################

    args = parser.parse_args()

    #dpath=os.path.dirname(os.getcwd())
    dpath = '/Volumes/Samsung_T5/data/braindecode/'
    result_dir = dpath + 'result/'
    os.makedirs(result_dir, exist_ok=True)
    #Paths for data, model and checkpoint
    data_path = os.path.join(dpath, 'Data/')
    model_save_path = os.path.join(dpath, 'Models',
                                   'Model_GumbelregHighgamma_M' + str(args.M))
    checkpoint_path = os.path.join(
        dpath, 'Models', 'Checkpoint_GumbelregHighgamma_M' + str(args.M))
    os.makedirs(os.path.join(dpath, 'Models'), exist_ok=True)

    #Check if CUDA is available
    enable_cuda = torch.cuda.is_available()
    if args.verbose:
        print('GPU computing: ', enable_cuda)

    #Set random seed (0 means "draw one")
    if args.seed == 0:
        args.seed = randint(1, 99999)

    def _seed_torch():
        # Re-seed torch and force deterministic cuDNN kernels; called before
        # every model instantiation so each training run starts from the
        # same RNG state.
        torch.manual_seed(args.seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    #Initialize devices with random seed
    _seed_torch()

    training_accs = []
    val_accs = []
    test_accs = []

    #Create a vector of length epochs, decaying start_value to end_value exponentially, reaching end_value at end_epoch
    def exponential_decay_schedule(start_value, end_value, epochs, end_epoch):
        t = torch.FloatTensor(torch.arange(0.0, epochs))
        # Progress in [0, 1]; clamped so the value stays at end_value after
        # end_epoch.
        p = torch.clamp(t / end_epoch, 0, 1)
        out = start_value * torch.pow(end_value / start_value, p)

        return out

    #Network loss function: returns the supervised loss and the model's
    #regularisation term separately.
    def loss_function(output, target, model, lamba, weight_decay):
        criterion = nn.CrossEntropyLoss()
        sup_loss = criterion(output, target)
        reg = model.regularizer(lamba, weight_decay)

        return sup_loss, reg

    #Create schedule for temperature and regularization threshold
    temperature_schedule = exponential_decay_schedule(args.start_temp,
                                                      args.end_temp,
                                                      args.epochs,
                                                      int(args.epochs * 3 / 4))
    thresh_schedule = exponential_decay_schedule(10.0, 1.1, args.epochs,
                                                 args.epochs)

    #Load data
    num_subjects = 5
    # NOTE(review): the SEEG windows built above are `wind` samples long with
    # X_train.shape[1] channels, while input_dim is fixed to [44, 1125] --
    # confirm which shape SelectionNet is meant to receive in each stage.
    input_dim = [44, 1125]
    # The SEEG loaders built above are used for channel selection in place of
    # the all_subject_loader_HGD loaders.

    ################################################################ SUBJECT-INDEPENDENT CHANNEL SELECTION #################################################################################

    if args.verbose:
        print('Start training')

    _seed_torch()

    #Instantiate model
    model = SelectionNet(input_dim, args.M).float()
    if enable_cuda:
        model.cuda()
    model.set_freeze(False)

    optimizer = torch.optim.Adam(model.parameters(), args.lr)

    prev_val_loss = 100
    patience_timer = 0
    fig, ax = plt.subplots()

    for epoch in range(args.epochs):

        #Update temperature and threshold
        model.set_thresh(thresh_schedule[epoch])
        model.set_temperature(temperature_schedule[epoch])

        #Perform training step
        train(train_loader, model, loss_function, optimizer, epoch,
              args.weight_decay, args.lamba, args.gradacc, args.verbose)
        val_loss = validate(val_loader, model, loss_function, epoch,
                            args.weight_decay, args.lamba, args.verbose)
        tr_acc, val_acc, test_acc = test(train_loader, val_loader, test_loader,
                                         model, loss_function,
                                         args.weight_decay, args.verbose)

        #Extract selection neuron entropies, current selections and probability distributions
        H, sel, probas = model.monitor()
        # .cpu() is required before .numpy() when the model lives on the GPU;
        # it is a no-op for CPU tensors.
        ax.plot(probas.detach().cpu().numpy())
        fig.savefig(result_dir + 'prob_dist' + str(epoch) + '.png')
        ax.clear()

        #If selection convergence is reached, enable early stopping scheme
        if ((torch.mean(H.data) <= args.entropy_lim)
                and (val_loss > prev_val_loss - args.stop_delta)):
            patience_timer += 1
            if args.verbose:
                print('Early stopping timer ', patience_timer)
            if patience_timer == args.patience:
                break
        else:
            # Not converged yet, or validation loss improved: reset the
            # patience counter and checkpoint the current weights.
            patience_timer = 0
            H, sel, probas = model.monitor()
            torch.save(model.state_dict(), checkpoint_path)
            prev_val_loss = val_loss

    # Release the figure that was reused for the per-epoch plots.
    plt.close(fig)

    if args.verbose:
        print('Channel selection finished')

    #Store subject independent model
    model.load_state_dict(torch.load(checkpoint_path))
    pretrained_path = str(model_save_path +
                          'all_subjects_channels_selected.pt')
    torch.save(model.state_dict(), pretrained_path)

    ################################################################ SUBJECT FINETUNING  #################################################################################
    ## freeze the selection layer and train the model other part
    if args.verbose:
        print('Start subject specific training')

    for k in range(1, num_subjects + 1):

        if args.verbose:
            print('Start training for subject ' + str(k))

        _seed_torch()

        #Load subject independent model and freeze selection neurons
        model = SelectionNet(input_dim, args.M)
        model.load_state_dict(torch.load(pretrained_path))
        if enable_cuda:
            model.cuda()
        model.set_freeze(True)

        #Load subject dependent data
        train_loader, val_loader, test_loader = within_subject_loader_HGD(
            subject=k,
            batch_size=args.batch_size,
            train_split=args.train_split,
            path=data_path)

        optimizer = torch.optim.Adam(model.parameters(), args.lr)

        prev_val_loss = 100
        patience_timer = 0
        for epoch in range(args.epochs):

            #Perform train step
            train(train_loader, model, loss_function, optimizer, epoch,
                  args.weight_decay, args.lamba, args.gradacc, args.verbose)
            val_loss = validate(val_loader, model, loss_function, epoch,
                                args.weight_decay, args.lamba, args.verbose)
            tr_acc, val_acc, test_acc = test(train_loader, val_loader,
                                             test_loader, model, loss_function,
                                             args.weight_decay, args.verbose)

            #Extract selection neuron entropies, current selections and probability distributions
            H, sel, probas = model.monitor()

            #Perform early stopping
            if val_loss > prev_val_loss - args.stop_delta:
                patience_timer += 1
                if args.verbose:
                    print('Early stopping timer ', patience_timer)
                if patience_timer == args.patience:
                    break
            else:
                patience_timer = 0
                torch.save(model.state_dict(), checkpoint_path)
                prev_val_loss = val_loss

        #Store model with lowest validation loss
        model.load_state_dict(torch.load(checkpoint_path))
        path = str(model_save_path + 'finished_subject' + str(k) + '.pt')
        torch.save(model.state_dict(), path)

        #Evaluate model
        tr_acc, val_acc, test_acc = test(train_loader, val_loader, test_loader,
                                         model, loss_function,
                                         args.weight_decay, args.verbose)
        training_accs.append(tr_acc)
        val_accs.append(val_acc)
        test_accs.append(test_acc)

    ################################################################ TERMINATION  #################################################################################

    # `sel` is the selection reported by the last monitor() call of the last
    # fine-tuned subject.
    print('Selection', sel.data)
    print('Training accuracies', training_accs)
    print('Validation accuracies', val_accs)
    print('Testing accuracies', test_accs)

    tr_med = statistics.median(training_accs)
    val_med = statistics.median(val_accs)
    test_med = statistics.median(test_accs)
    tr_mean = statistics.mean(training_accs)
    val_mean = statistics.mean(val_accs)
    test_mean = statistics.mean(test_accs)

    print('Training median accuracy', tr_med)
    print('Validation median accuracy', val_med)
    print('Testing median accuracy', test_med)
    print('Training mean accuracy', tr_mean)
    print('Validation mean accuracy', val_mean)
    print('Testing mean accuracy', test_mean)
示例#2
0
# Script setup: seed the RNGs, pick the compute device, create the
# per-participant output folder and load the preprocessed recording.
seed = 20200220  # random seed to make results reproducible
set_random_seeds(seed=seed)

# Use the GPU when one is available.
cuda = torch.cuda.is_available()
device = 'cuda' if cuda else 'cpu'
if cuda:
    # NOTE(review): cuDNN benchmark mode may pick different kernels between
    # runs, which can work against the reproducibility the seed above is
    # meant to give -- confirm this trade-off is intended.
    torch.backends.cudnn.benchmark = True

import inspect as i
import sys
#sys.stdout.write(i.getsource(deepnet))

sid = 10  #4
class_number = 5
# NOTE(review): other call sites of get_channel_setting unpack five values
# (with a trailing activeChan); confirm the arity expected here.
Session_num, UseChn, EmgChn, TrigChn = get_channel_setting(sid)
#fs=[Frequencies[i,1] for i in range(Frequencies.shape[0]) if Frequencies[i,0] == sid][0]
fs = 1000

project_dir = data_dir + 'preprocessing' + '/P' + str(sid) + '/'
model_path = project_dir + 'pth' + '/'
# exist_ok avoids the check-then-create race of a separate exists() test.
os.makedirs(model_path, exist_ok=True)

#[Frequencies[i,1] for i in range(Frequencies.shape[0]) if Frequencies[i,0] == sid][0]

# Same file as data_dir + 'preprocessing/P<sid>/preprocessing2.mat'.
loadPath = project_dir + 'preprocessing2.mat'
mat = hdf5storage.loadmat(loadPath)
data = mat['Datacell']
channelNum = int(mat['channelNum'][0, 0])
# Stack the two stored sessions along the first axis.
data = np.concatenate((data[0, 0], data[0, 1]), 0)
示例#3
0
from gesture.preprocess.chn_settings import get_channel_setting

# Set TensorFlow's C++ log level before any TensorFlow import reads it.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'


import inspect as i
import sys
#sys.stdout.write(i.getsource(deepnet))

#a=torch.randn(1, 1, 208, 500)
#model = deepnet_resnet(208,5,input_window_samples=500,expand=False)
#model.train()
#b=model(a)

pn = 10  # 4  (participant number)
Session_num, UseChn, EmgChn, TrigChn, activeChan = get_channel_setting(pn)
# The sampling rate is fixed here instead of being looked up per participant:
#fs=[Frequencies[i,1] for i in range(Frequencies.shape[0]) if Frequencies[i,0] == pn][0]
fs = 1000

# A bare copy of the Frequencies lookup used to stand here as an expression
# statement; its result was discarded, so it could only raise or do nothing
# and has been dropped.

loadPath = data_dir + 'preprocessing' + '/P' + str(pn) + '/preprocessing2.mat'
mat = hdf5storage.loadmat(loadPath)
data = mat['Datacell']
channelNum = int(mat['channelNum'][0, 0])
# Stack the two stored sessions along the first axis.
data = np.concatenate((data[0, 0], data[0, 1]), 0)
del mat
# standardization
# no effect. why?
if 1==1:
    chn_data=data[:,-3:]