Example #1
def main(session=1, mode='subject_dependent'):
    os.environ['CUDA_VISIBLE_DEVICES'] = '4'
    # prepare data
    balance = True
    shuffle = False
    modal = 'concat'
    nor_method = 1

    # reading the data in the whole dataset
    all_individual_data = []
    for i in range(1, 16):
        print("contructing dataset...")
        eeg = SEED_IV(session=session, individual=i, modal=modal, shuffle=shuffle, balance=balance,
                      normalization=nor_method)
        _train_X, _train_Y = eeg.get_train_data()
        _test_X, _test_Y = eeg.get_test_data()
        all_individual_data.append([(_train_X, _train_Y), (_test_X, _test_Y)])

    # Hyper-parameters
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    epochs = 50
    batch_size = 64
    learning_rate = 1e-3
    criterion = CrossEntropyLoss()
    for idx in range(1, 16):
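        # subject_dependent: train and test splits come from the same individual;
        # subject_independent: leave-one-subject-out, with individual idx held out as the test set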
        if mode == 'subject_dependent':
            train_X, train_Y = all_individual_data[idx - 1][0]
            test_X, test_Y = all_individual_data[idx - 1][1]
            exp_des = "%d_dependent_in_seesion_%d_%s_%s_%s_%d_%d" % (
                idx, session, 'balance' if balance else 'without_balance',
                'shuffle' if shuffle else "without_shuffle", 'seed_iv', epochs, batch_size)
            print("starting subject-dependent training experiments on individual %d in session %d"% (idx, session))
        elif mode == 'subject_independent':
            train_X = np.vstack([np.vstack((e[0][0], e[1][0])) for i, e in enumerate(all_individual_data) if i != idx-1])
            train_Y = np.hstack([np.hstack((e[0][1], e[1][1])) for i, e in enumerate(all_individual_data) if i != idx-1])
            test_X = np.vstack((all_individual_data[idx-1][0][0], all_individual_data[idx-1][1][0]))
            test_Y = np.hstack((all_individual_data[idx-1][0][1], all_individual_data[idx-1][1][1]))
            exp_des = "%d_independent_as_testset_in_seesion_%d_%s_%s_%s_%d_%d" % (
                idx, session, 'balance' if balance else 'without_balance',
                'shuffle' if shuffle else "without_shuffle", 'seed_iv', epochs, batch_size)
            print("starting subject-independent training experiments with individual %d in session %d as test set" % (idx, session))
        else:
            raise ValueError("mode must be 'subject_dependent' or 'subject_independent'")

        print("train_X shape", train_X.shape)
        print("train_Y shape", train_Y.shape)
        print("test_X shape", test_X.shape)
        print("test_Y shape", test_Y.shape)
        train_loader = DataLoader(dataset=SEED_IV_DATASET(train_X, train_Y), batch_size=batch_size, shuffle=shuffle,
                                  num_workers=4)
        test_loader = DataLoader(dataset=SEED_IV_DATASET(test_X, test_Y), batch_size=batch_size, shuffle=shuffle,
                                 num_workers=4)

        print("model construction...")
        net = SimpleDNN(num_layers=4, hidden_size=256, drop=0.5)
        net = net.to(device)
        optimization = Adam(net.parameters(), lr=learning_rate)
        save_model_path = '../../saved_models/%s/session_%d/subject_%d_as_testset' % (
            net.__class__.__name__, session,
            idx) if mode == 'subject_independent' else '../../saved_models/%s/session_%d/subject_%d' % (
            net.__class__.__name__, session, idx)
        if not os.path.exists(save_model_path):
            os.makedirs(save_model_path)

        # save model training state
        running_loss_list = []
        running_acc_list = []
        testing_loss_list = []
        testing_acc_list = []
        best_acc = -1
        print("start training...")
        for epoch in range(epochs):
            net.train()
            running_loss = 0.0
            correct = 0.0
            total = 0.0
            for i, (feature, target) in enumerate(train_loader):
                optimization.zero_grad()
                feature = feature.to(device)
                target = target.type(torch.LongTensor).to(device)
                out = net(feature)
                loss = criterion(out, target)
                loss.backward()
                optimization.step()
                running_loss += loss.item()
                _, prediction = torch.max(out.data, dim=-1)
                total += target.size(0)
                correct += prediction.eq(target.data).cpu().sum()
            cur_loss = running_loss / len(train_loader)
            cur_acc = correct / total
            if isinstance(cur_acc, torch.Tensor):
                cur_acc = cur_acc.item()
            if isinstance(cur_loss, torch.Tensor):
                cur_loss = cur_loss.item()
            print('Loss: %.10f | Acc: %.3f%% (%d/%d)' % (
                cur_loss, 100 * cur_acc, correct, total))
            running_loss_list.append(cur_loss)
            running_acc_list.append(cur_acc)

            if epoch % 1 == 0:
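                # epoch % 1 == 0 evaluates after every epoch; raise the modulus to evaluate less often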
                net.eval()
                print("start evaluating...")
                testing_loss = 0.0
                test_correct = 0.0
                test_total = 0.0
                for i, (feature, target) in enumerate(test_loader):
                    feature = feature.to(device)
                    target = target.type(torch.LongTensor).to(device)
                    with torch.no_grad():
                        out = net(feature)
                        loss = criterion(out, target)
                        testing_loss += loss.item()
                        _, prediction = torch.max(out.data, dim=-1)
                        # print(prediction)
                        test_total += target.size(0)
                        test_correct += prediction.eq(target.data).cpu().sum()
                test_acc = test_correct / test_total
                test_loss = testing_loss / len(test_loader)
                if isinstance(test_acc, torch.Tensor):
                    test_acc = test_acc.item()
                if isinstance(test_loss, torch.Tensor):
                    test_loss = test_loss.item()
                print('Testset Loss: %.10f | Acc: %.3f%% (%d/%d)' % (
                    test_loss, 100 * test_acc, test_correct, test_total))
                testing_acc_list.append(test_acc)
                testing_loss_list.append(test_loss)
                if test_acc > best_acc:
                    best_acc = test_acc
                    print("better model founded in testsets, start saving new model")
                    model_name = '%s_%s' % (net.__class__.__name__, str(best_acc)[2:6])
                    state = {
                        'net': net.state_dict(),
                        'epoch': epoch,
                        'best_acc': best_acc,
                        'current_loss': test_loss
                    }
                    torch.save(state, os.path.join(save_model_path, model_name))
        plot_acc_loss_curve({'train_loss': running_loss_list,
                            'train_acc': running_acc_list,
                            'test_loss': testing_loss_list,
                            'test_acc': testing_acc_list}, net.__class__.__name__, exp_des)
Example #2
def subject_dependent(individual=1, class_target=4):
    class_list = [
        "Valence", "Arousal", "Dominance", "Liking", "Valence-Arousal"
    ]
    class_nums = [2, 2, 2, 2, 4]
    test_loss_list = []  # record each fold's test loss
    test_acc_list = []  # record each fold's test accuracy
    # prepare data
    nor_method = 1
    label_smooth = 0.1
    shuffle = True

    # reading the data in the whole dataset
    deap = DEAP(individual=individual, normalization=nor_method)
    train_X, train_Y = deap.get_train_data()
    validate_X, validate_Y = deap.get_validate_data()
    test_X, test_Y = deap.get_test_data()

    # Hyper-parameters
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    epochs = 150
    batch_size = 512
    learning_rate = 1e-3
    criterion = LabelSmoothSoftmax(lb_smooth=label_smooth)
    # criterion_attn = CrossEntropyLoss()
    print(
        "starting subject-dependent training experiments on individual %d class %s"
        % (individual, class_list[class_target]))

    print("train_X shape", train_X.shape)
    print("train_Y shape", train_Y.shape)
    print("validate_X shape", validate_X.shape)
    print("validate_Y shape", validate_Y.shape)
    print("test_X shape", test_X.shape)
    print("test_Y shape", test_Y.shape)

    train_Y = train_Y[:, class_target].squeeze()
    validate_Y = validate_Y[:, class_target].squeeze()
    test_Y = test_Y[:, class_target].squeeze()
    train_loader = DataLoader(dataset=DEAP_DATASET(train_X, train_Y),
                              batch_size=batch_size,
                              shuffle=shuffle,
                              num_workers=0)
    validate_loader = DataLoader(dataset=DEAP_DATASET(validate_X, validate_Y),
                                 batch_size=batch_size,
                                 shuffle=shuffle,
                                 num_workers=0)
    test_loader = DataLoader(dataset=DEAP_DATASET(test_X, test_Y),
                             batch_size=batch_size,
                             shuffle=shuffle,
                             num_workers=0)

    exp_des = "%d_dependent_%s_%s_%d_%d_%s" % (
        individual, 'shuffle' if shuffle else "without_shuffle", 'deap',
        epochs, batch_size, class_list[class_target])

    print("model construction...")
    net = SimpleDNN(num_layers=3,
                    input_size=200,
                    output_size=class_nums[class_target])
    net = net.to(device)
    save_model_path = '../../saved_models/%s/deap/subject_%d/%s/' % (
        net.__class__.__name__, individual, class_list[class_target])

    if not os.path.exists(save_model_path):
        os.makedirs(save_model_path)
    optimization = Adam(net.parameters(), lr=learning_rate, weight_decay=0.001)

    # save model training state
    running_loss_list = []
    running_acc_list = []
    validate_loss_list = []
    validate_acc_list = []
    best_acc = -1
    print("start training...")
    scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer=optimization, T_max=epochs)
    scheduler_warmup = GradualWarmupScheduler(optimizer=optimization,
                                              multiplier=10,
                                              total_epoch=np.ceil(0.1 * epochs),
                                              after_scheduler=scheduler_cosine)
    for epoch in range(epochs):
        net.train()
        running_loss = 0.0
        correct = 0.0
        total = 0.0
        for i, (feature, target) in enumerate(train_loader):
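            # flatten each sample into the 200-dim vector the model expects (input_size=200)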
            feature = feature.reshape(-1, 200)
            optimization.zero_grad()
            # print("脏数据统计", torch.sum(torch.isnan(feature), dim=0))
            feature = feature.to(device)
            target = target.type(torch.LongTensor).to(device)
            out = net(feature)
            # print("训练集", out.data[:5])
            # print("训练集",eeg_attn.shape, eeg_attn.data[:5])
            # print("训练集",eye_attn.shape, eye_attn.data[:5])
            # print("batch output",out[0])
            cross_entropy_loss = criterion(out, target)
            # eeg_attn_loss = criterion_attn(eeg_attn, target)
            # eye_attn_loss = criterion_attn(eye_attn, target)
            # loss = cross_entropy_loss
            # print("交叉熵损失", cross_entropy_loss.data)
            # print("eeg注意力损失", eeg_attn_loss.data)
            # print("eye注意力损失", eeg_attn_loss.data)
            cross_entropy_loss.backward()
            clip_grad_norm_(net.parameters(), max_norm=10)
            # for name, parms in net.named_parameters():
            #     print('printing gradients')
            #     print('-->name:', name, '-->grad_requires:', parms.requires_grad,
            #           ' -->grad_value:', parms.grad)
            optimization.step()
            running_loss += cross_entropy_loss.item()
            # print("batch loss", loss.item())
            _, prediction = torch.max(out.data, dim=-1)
            # print('training set', prediction[:5])
            total += target.size(0)
            correct += prediction.eq(target.data).cpu().sum().item()
        cur_loss = running_loss / len(train_loader)  # running_loss sums per-batch means, so average over batches
        cur_acc = correct / total
        # print(cur_acc, correct, total)
        if isinstance(cur_acc, torch.Tensor):
            cur_acc = cur_acc.item()
        if isinstance(cur_loss, torch.Tensor):
            cur_loss = cur_loss.item()
        print('Training Loss: %.10f | Training Acc: %.3f%% (%d/%d)' %
              (cur_loss, 100 * cur_acc, correct, total))
        running_loss_list.append(cur_loss)
        running_acc_list.append(cur_acc)
        scheduler_warmup.step()
        if epoch % 1 == 0:
            net.eval()
            print("start evaluating...")
            validate_loss = 0.0
            validate_correct = 0.0
            validate_total = 0.0
            for i, (feature, target) in enumerate(validate_loader):
                feature = feature.reshape(-1, 200)
                feature = feature.to(device)
                target = target.type(torch.LongTensor).to(device)
                with torch.no_grad():
                    out = net(feature)
                    # print("c集", out.data[:5])
                    # print("c集", eeg_attn.data[:5])
                    # print("c集", eye_attn.data[:5])
                    loss = criterion(out, target)
                    validate_loss += loss.item()
                    _, prediction = torch.max(out.data, dim=-1)
                    # print('validation set', prediction[:5])
                    validate_total += target.size(0)
                    validate_correct += prediction.eq(
                        target.data).cpu().sum().item()
            validate_acc = validate_correct / validate_total
            validate_loss = validate_loss / len(validate_loader)  # average of per-batch mean losses
            if isinstance(validate_acc, torch.Tensor):
                validate_acc = validate_acc.item()
            if isinstance(validate_loss, torch.Tensor):
                validate_loss = validate_loss.item()
            print('Validate Loss: %.10f | Validate-Acc: %.3f%% (%d/%d)' %
                  (validate_loss, 100 * validate_acc, validate_correct,
                   validate_total))
            validate_acc_list.append(validate_acc)
            validate_loss_list.append(validate_loss)
            if validate_acc > best_acc:
                best_acc = validate_acc
                print("better model found on the validation set, saving checkpoint")
                model_name = net.__class__.__name__
                state = {
                    'net': net.state_dict(),
                    'epoch': epoch,
                    'best_acc': best_acc,
                    'current_loss': validate_loss
                }
                torch.save(state, os.path.join(save_model_path, model_name))
    # evaluate on the test set using the best validation checkpoint
    checkpoint = torch.load(
        os.path.join(save_model_path, net.__class__.__name__))
    net.load_state_dict(checkpoint['net'])
    print("start evaluating...")
    testing_loss = 0.0
    test_correct = 0.0
    test_total = 0.0
    for i, (feature, target) in enumerate(test_loader):
        feature = feature.reshape(-1, 200)
        feature = feature.to(device)
        target = target.type(torch.LongTensor).to(device)
        with torch.no_grad():
            out = net(feature)
            loss = criterion(out, target)
            testing_loss += loss.item()
            _, prediction = torch.max(out.data, dim=-1)
            # print(prediction)
            test_total += target.size(0)
            test_correct += prediction.eq(target.data).cpu().sum().item()
    test_acc = test_correct / test_total
    test_loss = testing_loss / len(test_loader)  # average of per-batch mean losses
    if isinstance(test_acc, torch.Tensor):
        test_acc = test_acc.item()
    if isinstance(test_loss, torch.Tensor):
        test_loss = test_loss.item()
    print('Test Loss: %.10f | Test Acc: %.3f%% (%d/%d)' %
          (test_loss, 100 * test_acc, test_correct, test_total))
    test_acc_list.append(test_acc)
    test_loss_list.append(test_loss)
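    # note: the plotting helper's 'test_*' series here carry the validation curves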
    plot_acc_loss_curve(
        {
            'train_loss': running_loss_list,
            'train_acc': running_acc_list,
            'test_loss': validate_loss_list,
            'test_acc': validate_acc_list
        }, net.__class__.__name__, exp_des)
    pd.DataFrame.from_dict({
        'test_loss': test_loss_list,
        'test_acc': test_acc_list
    }).to_csv('./results/deap_individual_%d_%s.csv' %
              (individual, class_list[class_target]),
              mode='w',
              index=False,
              header=True,
              encoding='utf-8')
Example #3
def subject_dependent_CV(session=1):
    # prepare data
    balance = True
    shuffle = False
    modal = 'concat'
    nor_method = 1
    label_smooth = 0.3
    best_acc_list = []
    best_precision_list = []
    best_recall_list = []
    best_f1_list = []

    result_save_path = './seed_cv_results/session{}'.format(session)
    if not os.path.exists(result_save_path):
        os.makedirs(result_save_path)

    # reading the data in the whole dataset
    for idx in range(1, 16):
        print("contructing dataset...")
        eeg = SEED_IV(session=session,
                      individual=idx,
                      modal=modal,
                      shuffle=shuffle,
                      balance=balance,
                      k_fold=3)
        k_fold_data = eeg.get_training_kfold_data()
        for fold, (train_X, train_Y, test_X, test_Y) in enumerate(k_fold_data):
            best_acc = -1
            best_precision = -1
            best_recall = -1
            best_f1 = -1
            print("train_X shape", train_X.shape)
            print("train_Y shape", train_Y.shape)
            print("test_X shape", test_X.shape)
            print("test_Y shape", test_Y.shape)
            train_X, train_Y, test_X, test_Y = seed_normalization(
                train_X, train_Y, test_X, test_Y,
                nor_method=nor_method, merge=0, column=0)
            train_X = train_X.astype(np.float32)
            test_X = test_X.astype(np.float32)
            train_Y = train_Y.astype(np.int32)
            test_Y = test_Y.astype(np.int32)
            # Hyper-parameters
            device = torch.device(
                'cuda') if torch.cuda.is_available() else torch.device('cpu')
            epochs = 500
            batch_size = 1024
            learning_rate = 1e-4
            criterion = LabelSmoothSoftmax(lb_smooth=label_smooth)

            exp_des = "%d_dependent_in_seesion_%d_fold%d_%s_%s_%s_%d_%d" % (
                idx, session, fold,
                'balance' if balance else 'without_balance', 'shuffle'
                if shuffle else "without_shuffle", 'seed', epochs, batch_size)
            print(
                "starting subject-dependent training experiments on individual %d in session %d"
                % (idx, session))

            print("train_X shape", train_X.shape)
            print("train_Y shape", train_Y.shape)
            print("test_X shape", test_X.shape)
            print("test_Y shape", test_Y.shape)

            print("model construction...")
            net = Hierarchical_ATTN_With_Senti_Map()

            net = net.to(device)
            save_model_path = '../../saved_models/%s/session_%d/subject_%d/fold_%d' % (
                net.__class__.__name__, session, idx, fold)
            if not os.path.exists(save_model_path):
                os.makedirs(save_model_path)
            optimization = RMSprop(net.parameters(),
                                   lr=learning_rate,
                                   weight_decay=0.01)

            # save model training state
            running_loss_list = []
            running_acc_list = []
            testing_loss_list = []
            testing_acc_list = []
            print("start training...")
            scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer=optimization, T_max=epochs)
            scheduler_warmup = GradualWarmupScheduler(
                optimizer=optimization,
                multiplier=10,
                total_epoch=np.ceil(0.1 * epochs),
                after_scheduler=scheduler_cosine)
            for epoch in range(epochs):
                net.train()
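                # full-batch training: each epoch performs a single gradient step on the whole training fold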
                optimization.zero_grad()
                #print("脏数据统计", torch.sum(torch.isnan(feature), dim=0))
                eeg = train_X[:, :310]
                eye = train_X[:, 310:]
                eeg = eeg.reshape(-1, 62, 5)
                eeg = torch.FloatTensor(eeg).to(device)
                eye = torch.FloatTensor(eye).to(device)
                #print("eeg type {}, eye type {}".format(type(eeg),type(eye)))
                target = torch.LongTensor(train_Y).to(device)
                out = net(eeg, eye)
                #print("batch output",out[0])
                loss = criterion(out, target)
                loss.backward()
                clip_grad_norm_(net.parameters(), max_norm=10)
                optimization.step()
                scheduler_warmup.step()
                running_loss = loss.item()
                #print("batch loss", loss.item())
                _, prediction = torch.max(out.data, dim=-1)
                total = target.size(0)
                correct = prediction.eq(target.data).cpu().sum().item()

                cur_loss = running_loss  # the criterion already returns the mean over the full batch
                cur_acc = correct / total
                if isinstance(cur_acc, torch.Tensor):
                    cur_acc = cur_acc.item()
                if isinstance(cur_loss, torch.Tensor):
                    cur_loss = cur_loss.item()
                print(
                    'Epoch %d/%d\tTraining Loss: %.10f | Acc: %.3f%% (%d/%d)' %
                    (epoch, epochs, cur_loss, 100 * cur_acc, correct, total))
                running_loss_list.append(cur_loss)
                running_acc_list.append(cur_acc)

                if epoch % 1 == 0:
                    net.eval()
                    print("start evaluating...")
                    eeg = test_X[:, :310]
                    eye = test_X[:, 310:]
                    eeg = eeg.reshape(-1, 62, 5)
                    eeg = torch.FloatTensor(eeg).to(device)
                    eye = torch.FloatTensor(eye).to(device)
                    target = torch.LongTensor(test_Y).to(device)
                    with torch.no_grad():
                        out = net(eeg, eye)
                        loss = criterion(out, target)
                        testing_loss = loss.item()
                        _, prediction = torch.max(out.data, dim=-1)
                        # print(prediction)
                        test_total = target.size(0)
                        test_correct = prediction.eq(
                            target.data).cpu().sum().item()

                        y_pre = prediction.cpu().numpy()
                        y_true = target.cpu().numpy()

                        test_acc = accuracy_score(y_true, y_pre)

                        test_loss = testing_loss  # already the mean over the full test batch
                        if isinstance(test_acc, torch.Tensor):
                            test_acc = test_acc.item()
                        if isinstance(test_loss, torch.Tensor):
                            test_loss = test_loss.item()
                        print('Testset Loss: %.10f | Acc: %.3f%% (%d/%d)' %
                              (test_loss, 100 * test_acc, test_correct,
                               test_total))
                        testing_acc_list.append(test_acc)
                        testing_loss_list.append(test_loss)
                        if test_acc > best_acc:
                            best_acc = test_acc
                            best_precision = precision_score(y_true,
                                                             y_pre,
                                                             average="macro")
                            best_recall = recall_score(y_true,
                                                       y_pre,
                                                       average="macro")
                            best_f1 = f1_score(y_true, y_pre, average="macro")
                            print("better model found on the test set, saving checkpoint")
                            model_name = net.__class__.__name__
                            state = {
                                'net': net.state_dict(),
                                'epoch': epoch,
                                'best_acc': best_acc,
                                'current_loss': test_loss
                            }
                            torch.save(
                                state, os.path.join(save_model_path,
                                                    model_name))
            best_f1_list.append(best_f1)
            best_acc_list.append(best_acc)
            best_precision_list.append(best_precision)
            best_recall_list.append(best_recall)

            plot_acc_loss_curve(
                {
                    'train_loss': running_loss_list,
                    'train_acc': running_acc_list,
                    'test_loss': testing_loss_list,
                    'test_acc': testing_acc_list
                }, net.__class__.__name__, exp_des)
    df = pd.DataFrame.from_dict({
        "acc": best_acc_list,
        "precision": best_precision_list,
        "recall": best_recall_list,
        "f1": best_f1_list
    })
    # append the per-fold mean and std as the last two rows
    # (DataFrame.append was removed in pandas 2.0, so use pd.concat)
    df = pd.concat([df, pd.DataFrame([df.mean(), df.std()])], ignore_index=True)
    df.to_csv(result_save_path + '/results.csv')
Example #4
def main(session=1, mode='subject_dependent'):
    os.environ['CUDA_VISIBLE_DEVICES'] = '2'
    # prepare data
    balance = False
    shuffle = False
    modal = 'concat'
    nor_method = 1
    label_smooth = 0.1
    fine_tuning = True

    # reading the data in the whole dataset
    all_individual_data = []
    for i in range(1, 16):
        print("contructing dataset...")
        eeg = SEED_IV(session=session,
                      individual=i,
                      modal=modal,
                      shuffle=shuffle,
                      balance=balance,
                      normalization=nor_method)
        _train_X, _train_Y = eeg.get_train_data()
        _test_X, _test_Y = eeg.get_test_data()
        all_individual_data.append([(_train_X, _train_Y), (_test_X, _test_Y)])

    # Hyper-parameters
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')
    epochs = 120
    batch_size = 128
    learning_rate = 1e-3
    criterion = LabelSmoothSoftmax(lb_smooth=label_smooth)
    for idx in range(1, 16):
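        # subject_dependent: train and test splits come from the same individual (optionally
        # fine-tuned from a subject-independent checkpoint); subject_independent: leave-one-subject-out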
        if mode == 'subject_dependent':
            train_X, train_Y = all_individual_data[idx - 1][0]
            test_X, test_Y = all_individual_data[idx - 1][1]
            exp_des = "%d_dependent_in_seesion_%d_%s_%s_%s_%d_%d" % (
                idx, session, 'balance' if balance else 'without_balance',
                'shuffle' if shuffle else "without_shuffle", 'seed', epochs,
                batch_size)
            print(
                "starting subject-dependent training experiments on individual %d in session %d"
                % (idx, session))
        elif mode == 'subject_independent':
            train_X = np.vstack([
                np.vstack((e[0][0], e[1][0]))
                for i, e in enumerate(all_individual_data) if i != idx - 1
            ])
            train_Y = np.hstack([
                np.hstack((e[0][1], e[1][1]))
                for i, e in enumerate(all_individual_data) if i != idx - 1
            ])
            test_X = np.vstack((all_individual_data[idx - 1][0][0],
                                all_individual_data[idx - 1][1][0]))
            test_Y = np.hstack((all_individual_data[idx - 1][0][1],
                                all_individual_data[idx - 1][1][1]))
            exp_des = "%d_independent_as_testset_in_seesion_%d_%s_%s_%s_%d_%d" % (
                idx, session, 'balance' if balance else 'without_balance',
                'shuffle' if shuffle else "without_shuffle", 'seed', epochs,
                batch_size)
            print(
                "starting subject-independent training experiments with individual %d in session %d as test set"
                % (idx, session))
        else:
            raise ValueError("mode must be 'subject_dependent' or 'subject_independent'")

        print("train_X shape", train_X.shape)
        print("train_Y shape", train_Y.shape)
        print("test_X shape", test_X.shape)
        print("test_Y shape", test_Y.shape)
        train_loader = DataLoader(dataset=SEED_IV_DATASET(train_X, train_Y),
                                  batch_size=batch_size,
                                  shuffle=shuffle,
                                  num_workers=4)
        test_loader = DataLoader(dataset=SEED_IV_DATASET(test_X, test_Y),
                                 batch_size=batch_size,
                                 shuffle=shuffle,
                                 num_workers=4)

        print("model construction...")
        net = Hierarchical_ATTN()
        # if fine_tuning, continue training from the pretrained subject-independent checkpoint
        if mode == 'subject_dependent' and fine_tuning:
            load_path = "../../saved_models/%s/session_%d/subject_%d_as_testset" % (
                net.__class__.__name__, session, idx)
            files = os.listdir(load_path)
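            # pick the lexicographically greatest filename; the accuracy digits embedded in the
            # name make this the best-scoring checkpoint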
            best_model = max(files)
            checkpoint = torch.load(os.path.join(load_path, best_model))
            net.load_state_dict(checkpoint['net'])
            learning_rate = 1e-5
            batch_size = train_X.shape[0]

        net = net.to(device)
        save_model_path = '../../saved_models/%s/session_%d/subject_%d_as_testset' % (
            net.__class__.__name__, session, idx
        ) if mode == 'subject_independent' else '../../saved_models/%s/session_%d/subject_%d' % (
            net.__class__.__name__, session, idx)
        if not os.path.exists(save_model_path):
            os.makedirs(save_model_path)
        optimization = Adam(net.parameters(),
                            lr=learning_rate,
                            weight_decay=0.001)

        # save model training state
        running_loss_list = []
        running_acc_list = []
        testing_loss_list = []
        testing_acc_list = []
        best_acc = -1
        print("start training...")
        scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer=optimization, T_max=epochs)
        scheduler_warmup = GradualWarmupScheduler(
            optimizer=optimization,
            multiplier=10,
            total_epoch=np.ceil(0.1 * epochs),
            after_scheduler=scheduler_cosine)
        for epoch in range(epochs):
            net.train()
            running_loss = 0.0
            correct = 0.0
            total = 0.0
            for i, (feature, target) in enumerate(train_loader):
                optimization.zero_grad()
                #print("脏数据统计", torch.sum(torch.isnan(feature), dim=0))
                eeg = feature[:, :310]
                eye = feature[:, 310:]
                eeg = eeg.reshape(-1, 62, 5)
                eeg = eeg.to(device)
                eye = eye.to(device)
                target = target.type(torch.LongTensor).to(device)
                out = net(eeg, eye)
                #print("batch output",out[0])
                cross_entropy_loss = criterion(out, target)
                cross_entropy_loss.backward()
                clip_grad_norm_(net.parameters(), max_norm=10)
                optimization.step()
                running_loss += cross_entropy_loss.item()
                #print("batch loss", loss.item())
                _, prediction = torch.max(out.data, dim=-1)
                total += target.size(0)
                correct += prediction.eq(target.data).cpu().sum()
            cur_loss = running_loss / len(train_loader)
            cur_acc = correct / total
            if isinstance(cur_acc, torch.Tensor):
                cur_acc = cur_acc.item()
            if isinstance(cur_loss, torch.Tensor):
                cur_loss = cur_loss.item()
            print('Loss: %.10f | Acc: %.3f%% (%d/%d)' %
                  (cur_loss, 100 * cur_acc, correct, total))
            running_loss_list.append(cur_loss)
            running_acc_list.append(cur_acc)
            scheduler_warmup.step()
            if epoch % 1 == 0:
                net.eval()
                print("start evaluating...")
                testing_loss = 0.0
                test_correct = 0.0
                test_total = 0.0
                for i, (feature, target) in enumerate(test_loader):
                    eeg = feature[:, :310]
                    eye = feature[:, 310:]
                    eeg = eeg.reshape(-1, 62, 5)
                    eeg = eeg.to(device)
                    eye = eye.to(device)
                    target = target.type(torch.LongTensor).to(device)
                    with torch.no_grad():
                        out = net(eeg, eye)
                        loss = criterion(out, target)
                        testing_loss += loss.item()
                        _, prediction = torch.max(out.data, dim=-1)
                        # print(prediction)
                        test_total += target.size(0)
                        test_correct += prediction.eq(target.data).cpu().sum()
                test_acc = test_correct / test_total
                test_loss = testing_loss / len(test_loader)
                if isinstance(test_acc, torch.Tensor):
                    test_acc = test_acc.item()
                if isinstance(test_loss, torch.Tensor):
                    test_loss = test_loss.item()
                print('Testset Loss: %.10f | Acc: %.3f%% (%d/%d)' %
                      (test_loss, 100 * test_acc, test_correct, test_total))
                testing_acc_list.append(test_acc)
                testing_loss_list.append(test_loss)
                if test_acc > best_acc:
                    best_acc = test_acc
                    print("better model found on the test set, saving checkpoint")
                    model_name = '%s_%s' % (net.__class__.__name__,
                                            str(best_acc)[2:6])
                    state = {
                        'net': net.state_dict(),
                        'epoch': epoch,
                        'best_acc': best_acc,
                        'current_loss': test_loss
                    }
                    torch.save(state, os.path.join(save_model_path,
                                                   model_name))
        plot_acc_loss_curve(
            {
                'train_loss': running_loss_list,
                'train_acc': running_acc_list,
                'test_loss': testing_loss_list,
                'test_acc': testing_acc_list
            }, net.__class__.__name__, exp_des)