def plot_train_probs(subject, data_path, model_path):
    """Scatter a saved CNN's outputs on one subject's training data.

    Reads ``<model_path>/<subject>.pickle``, rebuilds the network from the
    stored parameters and weights, scales the subject's training data with
    the stored scalers, prints the ROC threshold(s) lying closest to the
    (fpr=0, tpr=1) corner and shows a jittered scatter of the network
    outputs coloured by label.
    """
    pickle_path = model_path + "/" + subject + ".pickle"
    with open(pickle_path, "rb") as handle:
        saved = cPickle.load(handle)

    net = ConvNet(saved["params"])
    net.set_weights(saved["weights"])
    scalers = saved["scalers"]

    train_set = load_train_data(data_path, subject)
    x = train_set["x"]
    y = train_set["y"]

    # Reuse the scalers stored with the model; the second return value of
    # the scaling helpers is not needed here.
    if saved["params"]["scale_time"]:
        scale = scale_across_time
    else:
        scale = scale_across_features
    x, _ = scale(x, x_test=None, scalers=scalers)

    # Batch size is set to the number of examples before scoring.
    net.batch_size.set_value(x.shape[0])
    probs = net.get_test_proba(x)

    # Distance of every ROC operating point from the corner (fpr=0, tpr=1).
    fpr, tpr, threshold = roc_curve(y, probs)
    corner_dist = np.sqrt((1 - tpr) ** 2 + fpr ** 2)
    closest = np.where(corner_dist == np.min(corner_dist))[0]
    opt_threshold = threshold[closest]
    print(opt_threshold)

    # Fixed-seed horizontal jitter so overlapping points remain visible.
    n_points = len(y)
    jitter = np.zeros(n_points, dtype="float64")
    jitter += np.random.RandomState(42).normal(0.0, 0.08, size=n_points)
    plt.scatter(jitter, probs, c=y, s=60)
    plt.title(subject)
    plt.show()
def plot_features(subject, data_path, model_path, test_labels, dataset='test'):
    """Show a 2-D t-SNE embedding of the CNN feature-extractor output.

    Parameters
    ----------
    subject : str
        Subject name; the model is read from ``<model_path>/<subject>.pickle``.
    data_path : str
        Forwarded to ``load_test_data`` / ``load_train_data``.
    model_path : str
        Directory holding the pickled model state.
    test_labels : mapping
        Only read when ``dataset == 'test'``; its ``'preictal'`` entry is
        used to colour the points.
    dataset : {'test', 'train'}
        Which split to embed.

    Raises
    ------
    ValueError
        If ``dataset`` is neither ``'test'`` nor ``'train'``.
    """
    # Fail fast -- before the model is unpickled and the network built -- and
    # say what was wrong instead of raising a bare 'dataset'.
    if dataset not in ('test', 'train'):
        raise ValueError(
            "dataset must be 'test' or 'train', got %r" % (dataset,))

    with open(model_path + '/' + subject + '.pickle', 'rb') as f:
        state_dict = cPickle.load(f)
    cnn = ConvNet(state_dict['params'])
    cnn.set_weights(state_dict['weights'])
    scalers = state_dict['scalers']

    if dataset == 'test':
        x = load_test_data(data_path, subject)['x']
        y = test_labels['preictal']
    else:
        d = load_train_data(data_path, subject)
        x, y = d['x'], d['y']

    # Apply the scalers stored with the model (no refitting).
    scale = (scale_across_time if state_dict['params']['scale_time']
             else scale_across_features)
    x, _ = scale(x, x_test=None, scalers=scalers)

    # Batch size is set to the number of examples before extracting features.
    cnn.batch_size.set_value(x.shape[0])
    get_features = theano.function(
        [cnn.x, Param(cnn.training_mode, default=0)],
        cnn.feature_extractor.output,
        allow_input_downcast=True)

    features = get_features(x)
    model = TSNE(n_components=2, random_state=0)
    z = model.fit_transform(np.float64(features))
    plt.scatter(z[:, 0], z[:, 1], s=60, c=y)
    plt.show()
Esempio n. 3
0
def plot_train_probs(subject, data_path, model_path):
    """Plot the stored CNN's outputs on a subject's training examples.

    The model state is unpickled from ``<model_path>/<subject>.pickle``, the
    training data is scaled with the scalers saved alongside the model, and
    the network outputs are drawn as a jittered scatter coloured by label.
    The ROC threshold(s) nearest the (fpr=0, tpr=1) corner are printed.
    """
    with open(model_path + '/' + subject + '.pickle', 'rb') as model_file:
        checkpoint = pickle.load(model_file)
    params = checkpoint['params']
    network = ConvNet(params)
    network.set_weights(checkpoint['weights'])
    scalers = checkpoint['scalers']

    data = load_train_data(data_path, subject)
    x = data['x']
    y = data['y']

    # Pick the scaling helper that matches how the model was trained and
    # apply the stored scalers; only the scaled data is kept.
    if params['scale_time']:
        scaled = scale_across_time(x, x_test=None, scalers=scalers)
    else:
        scaled = scale_across_features(x, x_test=None, scalers=scalers)
    x, _ = scaled

    network.batch_size.set_value(x.shape[0])
    probs = network.get_test_proba(x)

    # Euclidean distance of each ROC point from the corner (fpr=0, tpr=1).
    fpr, tpr, threshold = roc_curve(y, probs)
    dist = np.sqrt((1 - tpr)**2 + fpr**2)
    best = np.where(dist == np.min(dist))[0]
    opt_threshold = threshold[best]
    print(opt_threshold)

    # Seeded jitter on the horizontal axis keeps coincident points apart.
    count = len(y)
    x_coords = np.zeros(count, dtype='float64')
    x_coords += np.random.RandomState(42).normal(0.0, 0.08, size=count)
    plt.scatter(x_coords, probs, c=y, s=60)
    plt.title(subject)
    plt.show()
Esempio n. 4
0
def cross_validate(subject, data_path, reg_C, random_cv=False):
    """Cross-validate a standardised logistic regression for one subject.

    Parameters
    ----------
    subject, data_path
        Forwarded to the data loaders.
    reg_C : float
        ``C`` passed to ``LogisticRegression``.
    random_cv : bool
        If true, run a stratified 10-fold split over individual training
        examples.  Otherwise examples stay in the groups returned by
        ``load_grouped_train_data`` (built from ``filenames.pickle``) and the
        number of folds equals the number of preictal groups.

    Returns
    -------
    (preictal_probs, labels) : tuple of lists
        Validation predictions and the matching labels, concatenated over
        all folds.
    """
    if random_cv:
        d = load_train_data(data_path, subject)
        # Treat every example as its own one-element group so the fold loop
        # below serves both modes.  This branch used to leave hours_data /
        # hours_labels undefined and the loop failed with a NameError.
        hours_data = [[example] for example in d['x']]
        hours_labels = np.asarray(d['y'])
        skf = StratifiedKFold(hours_labels, n_folds=10)
    else:
        # Pickles are binary: open in 'rb' and close the handle promptly.
        with open('filenames.pickle', 'rb') as f:
            filenames_grouped_by_hour = cPickle.load(f)
        data_grouped_by_hour = load_grouped_train_data(
            data_path, subject, filenames_grouped_by_hour)
        preictal_groups = data_grouped_by_hour['preictal']
        interictal_groups = data_grouped_by_hour['interictal']
        n_preictal = len(preictal_groups)
        n_interictal = len(interictal_groups)
        hours_data = preictal_groups + interictal_groups
        hours_labels = np.concatenate(
            (np.ones(n_preictal), np.zeros(n_interictal)))
        skf = StratifiedKFold(hours_labels, n_folds=n_preictal)

    def gather(indexes):
        # Flatten the selected groups into one example array plus one label
        # per example (each example inherits its group's label).
        examples, targets = [], []
        for i in indexes:
            group = hours_data[i]
            examples.extend(group)
            targets.extend(hours_labels[i] * np.ones(len(group)))
        # Stack along a new trailing axis, then roll that axis to the front.
        # Assumes every example is 3-D -- TODO confirm against the loaders.
        stacked = np.concatenate(
            [example[..., np.newaxis] for example in examples], axis=3)
        return np.rollaxis(stacked, axis=3), np.array(targets)

    preictal_probs, labels = [], []
    for train_indexes, valid_indexes in skf:
        x_train, y_train = gather(train_indexes)
        x_valid, y_valid = gather(valid_indexes)

        n_valid_examples = x_valid.shape[0]
        n_timesteps = x_valid.shape[-1]

        # The scaler is fitted on the training fold only and then reused on
        # the validation fold.
        x_train, y_train = reshape_data(x_train, y_train)
        data_scaler = StandardScaler()
        x_train = data_scaler.fit_transform(x_train)

        logreg = LogisticRegression(C=reg_C)
        logreg.fit(x_train, y_train)

        x_valid = data_scaler.transform(reshape_data(x_valid))
        p_valid = predict(logreg, x_valid, n_valid_examples, n_timesteps)

        preictal_probs.extend(p_valid)
        labels.extend(y_valid)

    return preictal_probs, labels
def train(subject, data_path, reg_C=None):
    """Fit a standardised logistic-regression model on a subject's data.

    Parameters
    ----------
    subject, data_path
        Forwarded to ``load_train_data``.
    reg_C : float or None
        ``C`` for ``LogisticRegression``.  ``None`` (the default) means "use
        the estimator's own default"; it used to be forwarded as ``C=None``,
        which is not a valid value.

    Returns
    -------
    (model, data_scaler)
        The fitted ``LogisticRegression`` and the ``StandardScaler`` that was
        fitted on the reshaped training data.
    """
    d = load_train_data(data_path, subject)
    x, y = reshape_data(d['x'], d['y'])
    data_scaler = StandardScaler()
    x = data_scaler.fit_transform(x)
    if reg_C is None:
        logreg = LogisticRegression()
    else:
        logreg = LogisticRegression(C=reg_C)
    logreg.fit(x, y)
    return logreg, data_scaler
Esempio n. 6
0
def curve_per_subject(subject, data_path, test_labels):
    """Sweep a decision threshold for one subject's LDA model on test data.

    An LDA is fitted on the standardised, reshaped training data.  Its
    per-row probabilities are averaged back to one score per example, a
    threshold is picked on the training ROC curve (point closest to
    fpr=0, tpr=1), and the test confusion matrix at that threshold is
    printed.  Finally sensitivity and specificity on the test set are
    computed for a grid of thresholds.

    Parameters
    ----------
    subject, data_path
        Forwarded to the data loaders.
    test_labels : array-like
        True labels of the test examples (0 or 1 -- assumed from how the
        confusion matrix is indexed).

    Returns
    -------
    (t_list, sn, sp)
        ``np.arange(0.0, 1.0, 0.01)`` and the lists of test sensitivity and
        specificity at each of those thresholds.
    """
    def rates_at(scores, cutoff):
        # Mark scores >= cutoff as positive, then read sensitivity from row 1
        # and specificity from row 0 of the confusion matrix.
        y_pred = np.zeros_like(test_labels)
        y_pred[np.where(scores >= cutoff)] = 1
        cm = confusion_matrix(test_labels, y_pred)
        sensitivity = 1.0 * cm[1, 1] / (cm[1, 1] + cm[1, 0])
        specificity = 1.0 * cm[0, 0] / (cm[0, 0] + cm[0, 1])
        return cm, sensitivity, specificity

    d = load_train_data(data_path, subject)
    x, y_10m = d['x'], d['y']
    n_train_examples = x.shape[0]
    n_timesteps = x.shape[-1]
    print('n_preictal', np.sum(y_10m))
    # Count of zero labels; the previous `y_10m - 1` reported it negated.
    print('n_inetrictal', np.sum(1 - y_10m))

    x, y = reshape_data(x, y_10m)
    data_scaler = StandardScaler()
    x = data_scaler.fit_transform(x)

    lda = LDA()
    lda.fit(x, y)

    # Per-row probabilities -> one mean score per training example.
    pred_1m = lda.predict_proba(x)[:, 1]
    pred_10m = np.mean(
        np.reshape(pred_1m, (n_train_examples, n_timesteps)), axis=1)
    fpr, tpr, threshold = roc_curve(y_10m, pred_10m)
    c = np.sqrt((1 - tpr) ** 2 + fpr ** 2)
    # Among tied closest points, keep the last one returned by roc_curve.
    opt_threshold = threshold[np.where(c == np.min(c))[0]][-1]
    print(opt_threshold)

    # ------- TEST ---------------

    x_test = load_test_data(data_path, subject)['x']
    n_test_examples = x_test.shape[0]
    n_timesteps = x_test.shape[3]
    x_test = data_scaler.transform(reshape_data(x_test))

    pred_1m = lda.predict_proba(x_test)[:, 1]
    pred_10m = np.mean(
        np.reshape(pred_1m, (n_test_examples, n_timesteps)), axis=1)

    cm, sn_opt, sp_opt = rates_at(pred_10m, opt_threshold)
    print(print_cm(cm, labels=['interictal', 'preictal']))
    print(sn_opt, sp_opt)

    sn, sp = [], []
    t_list = np.arange(0.0, 1.0, 0.01)
    for t in t_list:
        _, sn_t, sp_t = rates_at(pred_10m, t)
        sn.append(sn_t)
        sp.append(sp_t)

    return t_list, sn, sp
def curve_per_subject(subject, data_path, test_labels):
    """Sweep a decision threshold for one subject's LDA model on test data.

    An LDA is fitted on the standardised, reshaped training data.  Its
    per-row probabilities are averaged back to one score per example, a
    threshold is picked on the training ROC curve (point closest to
    fpr=0, tpr=1), and the test confusion matrix at that threshold is
    printed.  Finally sensitivity and specificity on the test set are
    computed for a grid of thresholds.

    Parameters
    ----------
    subject, data_path
        Forwarded to the data loaders.
    test_labels : array-like
        True labels of the test examples (0 or 1 -- assumed from how the
        confusion matrix is indexed).

    Returns
    -------
    (t_list, sn, sp)
        ``np.arange(0.0, 1.0, 0.01)`` and the lists of test sensitivity and
        specificity at each of those thresholds.
    """
    def rates_at(scores, cutoff):
        # Mark scores >= cutoff as positive, then read sensitivity from row 1
        # and specificity from row 0 of the confusion matrix.
        y_pred = np.zeros_like(test_labels)
        y_pred[np.where(scores >= cutoff)] = 1
        cm = confusion_matrix(test_labels, y_pred)
        sensitivity = 1.0 * cm[1, 1] / (cm[1, 1] + cm[1, 0])
        specificity = 1.0 * cm[0, 0] / (cm[0, 0] + cm[0, 1])
        return cm, sensitivity, specificity

    d = load_train_data(data_path, subject)
    x, y_10m = d['x'], d['y']
    n_train_examples = x.shape[0]
    n_timesteps = x.shape[-1]
    # Single pre-formatted strings print identically under the Python 2
    # print statement and the Python 3 print function.
    print('n_preictal %s' % np.sum(y_10m))
    # Count of zero labels; the previous `y_10m - 1` reported it negated.
    print('n_inetrictal %s' % np.sum(1 - y_10m))

    x, y = reshape_data(x, y_10m)
    data_scaler = StandardScaler()
    x = data_scaler.fit_transform(x)

    lda = LDA()
    lda.fit(x, y)

    # Per-row probabilities -> one mean score per training example.
    pred_1m = lda.predict_proba(x)[:, 1]
    pred_10m = np.mean(
        np.reshape(pred_1m, (n_train_examples, n_timesteps)), axis=1)
    fpr, tpr, threshold = roc_curve(y_10m, pred_10m)
    c = np.sqrt((1 - tpr) ** 2 + fpr ** 2)
    # Among tied closest points, keep the last one returned by roc_curve.
    opt_threshold = threshold[np.where(c == np.min(c))[0]][-1]
    print(opt_threshold)

    # ------- TEST ---------------

    x_test = load_test_data(data_path, subject)['x']
    n_test_examples = x_test.shape[0]
    n_timesteps = x_test.shape[3]
    x_test = data_scaler.transform(reshape_data(x_test))

    pred_1m = lda.predict_proba(x_test)[:, 1]
    pred_10m = np.mean(
        np.reshape(pred_1m, (n_test_examples, n_timesteps)), axis=1)

    cm, sn_opt, sp_opt = rates_at(pred_10m, opt_threshold)
    print(print_cm(cm, labels=['interictal', 'preictal']))
    print('%s %s' % (sn_opt, sp_opt))

    sn, sp = [], []
    t_list = np.arange(0.0, 1.0, 0.01)
    for t in t_list:
        _, sn_t, sp_t = rates_at(pred_10m, t)
        sn.append(sn_t)
        sp.append(sp_t)

    return t_list, sn, sp
def cross_validate(subject, data_path, reg_C, random_cv=False):
    """Cross-validate a standardised logistic regression for one subject.

    Parameters
    ----------
    subject, data_path
        Forwarded to the data loaders.
    reg_C : float
        ``C`` passed to ``LogisticRegression``.
    random_cv : bool
        If true, run a stratified 10-fold split over individual training
        examples.  Otherwise examples stay in the groups returned by
        ``load_grouped_train_data`` (built from ``filenames.pickle``) and the
        number of folds equals the number of preictal groups.

    Returns
    -------
    (preictal_probs, labels) : tuple of lists
        Validation predictions and the matching labels, concatenated over
        all folds.
    """
    if random_cv:
        d = load_train_data(data_path, subject)
        # Wrap each example in a one-element group so the fold loop below
        # handles both modes.  Before this, hours_data / hours_labels were
        # undefined on this branch and the loop raised a NameError.
        hours_data = [[example] for example in d['x']]
        hours_labels = np.asarray(d['y'])
        skf = StratifiedKFold(hours_labels, n_folds=10)
    else:
        # Read the pickle in binary mode and release the handle right away.
        with open('filenames.pickle', 'rb') as f:
            filenames_grouped_by_hour = cPickle.load(f)
        data_grouped_by_hour = load_grouped_train_data(data_path, subject, filenames_grouped_by_hour)
        n_preictal = len(data_grouped_by_hour['preictal'])
        n_interictal = len(data_grouped_by_hour['interictal'])
        hours_data = data_grouped_by_hour['preictal'] + data_grouped_by_hour['interictal']
        hours_labels = np.concatenate((np.ones(n_preictal), np.zeros(n_interictal)))
        n_folds = n_preictal
        skf = StratifiedKFold(hours_labels, n_folds=n_folds)

    def gather(indexes):
        # Collect every example of the selected groups together with one
        # label per example (the label of the group it came from).
        examples, targets = [], []
        for i in indexes:
            group = hours_data[i]
            examples.extend(group)
            targets.extend(hours_labels[i] * np.ones(len(group)))
        # New trailing axis per example, concatenate on it, move it to front.
        # Assumes every example is 3-D -- TODO confirm against the loaders.
        stacked = np.concatenate([example[..., np.newaxis] for example in examples], axis=3)
        return np.rollaxis(stacked, axis=3), np.array(targets)

    preictal_probs, labels = [], []
    for train_indexes, valid_indexes in skf:
        x_train, y_train = gather(train_indexes)
        x_valid, y_valid = gather(valid_indexes)

        n_valid_examples = x_valid.shape[0]
        n_timesteps = x_valid.shape[-1]

        # Fit the scaler on the training fold only; reuse it for validation.
        x_train, y_train = reshape_data(x_train, y_train)
        data_scaler = StandardScaler()
        x_train = data_scaler.fit_transform(x_train)

        logreg = LogisticRegression(C=reg_C)
        logreg.fit(x_train, y_train)

        x_valid = reshape_data(x_valid)
        x_valid = data_scaler.transform(x_valid)

        p_valid = predict(logreg, x_valid, n_valid_examples, n_timesteps)

        preictal_probs.extend(p_valid)
        labels.extend(y_valid)

    return preictal_probs, labels
Esempio n. 9
0
def train(subject, data_path, plot=False):
    """Fit an LDA classifier for one subject and expose per-feature weights.

    Parameters
    ----------
    subject, data_path
        Forwarded to ``load_train_data``.
    plot : bool
        If true, show one bar chart per channel with the absolute weight of
        each frequency bin.

    Returns
    -------
    (lda, data_scaler, coefs)
        The fitted LDA, the fitted ``StandardScaler`` and the weights
        (``lda.scalings_`` times the first row of ``lda.coef_``) reshaped to
        ``(n_channels, n_fbins)``.
    """
    d = load_train_data(data_path, subject)
    x, y = d['x'], d['y']
    # Single pre-formatted strings print identically under the Python 2
    # print statement and the Python 3 print function.
    print('n_preictal %s' % np.sum(y))
    # Count of zero labels; the previous `y - 1` reported it negated.
    print('n_inetrictal %s' % np.sum(1 - y))
    n_channels = x.shape[1]
    n_fbins = x.shape[2]

    x, y = reshape_data(x, y)
    data_scaler = StandardScaler()
    x = data_scaler.fit_transform(x)

    lda = LDA()
    lda.fit(x, y)
    coef = lda.scalings_ * lda.coef_[:1].T

    if plot:
        # Flatten once and take the common y-limit outside the loop; the old
        # builtin max() over a 2-D array returned a 1-element array.
        abs_coef = np.abs(coef).ravel()
        max_y = abs_coef.max() + 0.01
        # 4x6 grid for 24 channels, otherwise 4 columns with at least 4 rows
        # (more rows are added when there are over 16 channels).
        n_cols = 6 if n_channels == 24 else 4
        n_rows = max(4, -(-n_channels // n_cols))

        fig = plt.figure()
        for i in range(n_channels):
            # Subplot indices are 1-based; passing i (0 on the first pass)
            # was the bug here.
            ax = fig.add_subplot(n_rows, n_cols, i + 1)
            ax.set_xlim([0, n_fbins])
            ax.set_xticks(np.arange(0.5, n_fbins + 0.5, 1))
            ax.set_xticklabels(np.arange(0, n_fbins))
            ax.set_ylim([0, max_y])
            ax.set_yticks(
                np.around(np.arange(0, max_y, max_y / 4.0), decimals=1))
            for label in (ax.get_xticklabels() + ax.get_yticklabels()):
                label.set_fontsize(15)
            # Weights of channel i occupy one contiguous run of n_fbins.
            ax.bar(range(0, n_fbins),
                   abs_coef[i * n_fbins:(i + 1) * n_fbins])
        fig.suptitle(subject, fontsize=20)
        plt.show()

    coefs = np.reshape(coef, (n_channels, n_fbins))
    return lda, data_scaler, coefs
def train(subject, data_path, plot=False):
    """Fit an LDA classifier for one subject and expose per-feature weights.

    Parameters
    ----------
    subject, data_path
        Forwarded to ``load_train_data``.
    plot : bool
        If true, show one bar chart per channel with the absolute weight of
        each frequency bin.

    Returns
    -------
    (lda, data_scaler, coefs)
        The fitted LDA, the fitted ``StandardScaler`` and the weights
        (``lda.scalings_`` times the first row of ``lda.coef_``) reshaped to
        ``(n_channels, n_fbins)``.
    """
    d = load_train_data(data_path, subject)
    x, y = d['x'], d['y']
    # Pre-formatted strings keep the output the same under the Python 2
    # print statement and the Python 3 print function.
    print('n_preictal %s' % np.sum(y))
    # Count of zero labels; the previous `y - 1` reported it negated.
    print('n_inetrictal %s' % np.sum(1 - y))
    n_channels = x.shape[1]
    n_fbins = x.shape[2]

    x, y = reshape_data(x, y)
    data_scaler = StandardScaler()
    x = data_scaler.fit_transform(x)

    lda = LDA()
    lda.fit(x, y)
    coef = lda.scalings_ * lda.coef_[:1].T

    if plot:
        # One flat vector of magnitudes and one shared y-limit, computed once
        # (builtin max() over the 2-D coef gave a 1-element array).
        abs_coef = np.abs(coef).ravel()
        max_y = abs_coef.max() + 0.01
        # 4x6 grid for 24 channels, otherwise 4 columns with at least 4 rows
        # (more rows are added when there are over 16 channels).
        n_cols = 6 if n_channels == 24 else 4
        n_rows = max(4, -(-n_channels // n_cols))
        y_ticks = np.around(np.arange(0, max_y, max_y / 4.0), decimals=1)

        fig = plt.figure()
        for i in range(n_channels):
            # add_subplot counts from 1; the old call passed i, i.e. 0 first.
            ax = fig.add_subplot(n_rows, n_cols, i + 1)
            ax.set_xlim([0, n_fbins])
            ax.set_xticks(np.arange(0.5, n_fbins + 0.5, 1))
            ax.set_xticklabels(np.arange(0, n_fbins))
            ax.set_ylim([0, max_y])
            ax.set_yticks(y_ticks)
            for label in (ax.get_xticklabels() + ax.get_yticklabels()):
                label.set_fontsize(15)
            # Weights of channel i occupy one contiguous run of n_fbins.
            ax.bar(range(0, n_fbins), abs_coef[i * n_fbins:(i + 1) * n_fbins])
        fig.suptitle(subject, fontsize=20)
        plt.show()

    coefs = np.reshape(coef, (n_channels, n_fbins))
    return lda, data_scaler, coefs
def plot_train_test(subject, data_path):
    """Plot a 2-D t-SNE embedding of one subject's train and test examples.

    Every example is flattened to a vector; train and test are stacked,
    standardised together, reduced to 50 principal components and embedded
    with t-SNE.  Train points are drawn as red circles, test points as blue
    triangles.

    Parameters
    ----------
    subject, data_path
        Forwarded to ``load_train_data`` / ``load_test_data``.
    """
    x_train = load_train_data(data_path, subject)['x']
    x_train = x_train.reshape(x_train.shape[0], -1)

    x_test = load_test_data(data_path, subject)['x']
    x_test = np.reshape(x_test, (x_test.shape[0], -1))

    n_train = len(x_train)
    x_all = np.vstack((np.float64(x_train), np.float64(x_test)))
    x_all = StandardScaler().fit_transform(x_all)

    # fit_transform already fits; the extra pca.fit() call before it only
    # repeated the decomposition.
    x_all = PCA(50).fit_transform(x_all)

    model = TSNE(n_components=2, perplexity=40, learning_rate=100, random_state=42)
    z = model.fit_transform(x_all)

    # One scatter call per split (instead of one per point plus two
    # re-plotted points that existed only to carry the legend labels).
    ax = plt.subplot(111)
    ax.scatter(z[:n_train, 0], z[:n_train, 1], c='r', marker='o', s=60, label='train')
    ax.scatter(z[n_train:, 0], z[n_train:, 1], c='b', marker='^', s=60, label='test')
    ax.legend(loc='upper center', bbox_to_anchor=(0.5, 1.05),
              ncol=2, fancybox=True, shadow=True)

    # Half a unit of margin around the embedded points.
    plt.xlim([z[:, 0].min() - 0.5, z[:, 0].max() + 0.5])
    plt.ylim([z[:, 1].min() - 0.5, z[:, 1].max() + 0.5])
    for label in (ax.get_xticklabels() + ax.get_yticklabels()):
        label.set_fontsize(20)
    plt.ylabel('Z_2', fontsize=20)
    plt.xlabel('Z_1', fontsize=20)
    plt.show()
Esempio n. 12
0
def train(subject, data_path, model_path, model_params, validation_params):
    """Train a ConvNet for one subject and pickle it under ``model_path``.

    Two modes, selected by ``model_params['overlap']``:

    * truthy -- build overlapped windows from the grouped training data,
      train for a fixed 175000 iterations without validation, save, return.
    * falsy  -- hold out a validation set (a stratified random 25% split if
      ``validation_params['random_split']``, otherwise the file lists from
      ``split_train_valid_filenames``), call ``cnn.validate`` and use the
      value it returns as the iteration count for a second network trained
      on the full training set, which is then saved.

    The saved pickle (``<model_path>/<subject>.pickle``) is
    ``cnn.get_state()`` plus a ``'scalers'`` entry.  ``cnn.validate`` is also
    given ``<model_path>/<subject>.txt`` as ``fname_out``.

    Note: ``model_params`` is modified in place -- ``n_channels``,
    ``n_fbins`` and ``n_timesteps`` are filled in from the data shape.

    Returns None.
    """
    def load_group_index():
        # Pickles are binary data: open in 'rb' and close the handle, rather
        # than handing a never-closed text-mode file to cPickle.load.
        with open('filenames.pickle', 'rb') as f:
            return cPickle.load(f)

    def save_model(net, fitted_scalers):
        # Persist the network state together with the scalers it needs.
        state_dict = net.get_state()
        state_dict['scalers'] = fitted_scalers
        with open(model_path + '/' + subject + '.pickle', 'wb') as f:
            cPickle.dump(state_dict, f, protocol=cPickle.HIGHEST_PROTOCOL)

    def load_test_x():
        return (load_test_data(data_path, subject)['x']
                if model_params['use_test'] else None)

    d = load_train_data(data_path, subject)
    x, y, filename_to_idx = d['x'], d['y'], d['filename_to_idx']
    x_test = load_test_x()

    # --------- add params
    model_params['n_channels'] = x.shape[1]
    model_params['n_fbins'] = x.shape[2]
    model_params['n_timesteps'] = x.shape[3]

    # Pre-formatted strings print identically under the Python 2 print
    # statement and the Python 3 print function.
    print('============ parameters')
    for key, value in model_params.items():
        print('%s : %s' % (key, value))
    print('========================')

    # Both helpers are called with the same keywords, so choose once.
    scale = (scale_across_time if model_params['scale_time']
             else scale_across_features)

    if model_params['overlap']:
        # no validation if overlap
        data_grouped_by_hour = load_grouped_train_data(
            data_path, subject, load_group_index())
        x, y = generate_overlapped_data(data_grouped_by_hour,
                                        overlap_size=model_params['overlap'],
                                        window_size=x.shape[-1],
                                        overlap_interictal=True,
                                        overlap_preictal=True)
        print(x.shape)

        x, scalers = scale(x, x_test=None)

        cnn = ConvNet(model_params)
        cnn.train(train_set=(x, y), max_iter=175000)
        save_model(cnn, scalers)
        return

    if validation_params['random_split']:
        skf = StratifiedShuffleSplit(y,
                                     n_iter=1,
                                     test_size=0.25,
                                     random_state=0)
        # n_iter=1, so there is exactly one split to take.
        train_idx, valid_idx = next(iter(skf))
    else:
        d = split_train_valid_filenames(subject, load_group_index())
        # NOTE(review): 'valid_filnames' (sic) is kept as-is -- presumably it
        # matches the key split_train_valid_filenames returns; confirm.
        train_filenames = d['train_filenames']
        valid_filenames = d['valid_filnames']
        train_idx = [filename_to_idx[i] for i in train_filenames]
        valid_idx = [filename_to_idx[i] for i in valid_filenames]
    x_train, y_train = x[train_idx], y[train_idx]
    x_valid, y_valid = x[valid_idx], y[valid_idx]

    # Scalers come from the training part and are reused for validation.
    x_train, scalers_train = scale(x=x_train, x_test=x_test)
    x_valid, _ = scale(x=x_valid, x_test=x_test, scalers=scalers_train)

    del x, x_test

    print('============ dataset')
    print('train: %s' % (x_train.shape,))
    print('n_pos: %s n_neg: %s' % (np.sum(y_train),
                                   len(y_train) - np.sum(y_train)))
    print('valid: %s' % (x_valid.shape,))
    print('n_pos: %s n_neg: %s' % (np.sum(y_valid),
                                   len(y_valid) - np.sum(y_valid)))

    # -------------- validate
    cnn = ConvNet(model_params)
    best_iter = cnn.validate(train_set=(x_train, y_train),
                             valid_set=(x_valid, y_valid),
                             valid_freq=validation_params['valid_freq'],
                             max_iter=validation_params['max_iter'],
                             fname_out=model_path + '/' + subject + '.txt')

    # ---------------- scale
    # The full data was released above, so it is loaded again here.
    d = load_train_data(data_path, subject)
    x, y = d['x'], d['y']
    x_test = load_test_x()

    x, scalers = scale(x=x, x_test=x_test)
    del x_test

    cnn = ConvNet(model_params)
    cnn.train(train_set=(x, y), max_iter=best_iter)
    save_model(cnn, scalers)
Esempio n. 13
0
""" Loop over all patients, 
	make probabilistic predictions for each method within a given patient
	combine the probabilities by either:
	1) just summing them up
	2) subtracting the mean prediction probability from each class 
	   for each method and then summing (Avoids all 0 predictions) """
# Accumulators for the per-patient loop below.  In the part of the script
# visible here only all_predictions is appended to; the other three lists
# are created but not filled.
all_predictions = []
all_predictions_ns = []
validations_true = []
validations_preds = []

# For every patient: load the preprocessed data, train/evaluate a CNN, report
# the validation ROC AUC and save a figure of the loss curves.
for patient in all_patients:

	# Load this patient's preprocessed train and test arrays.
	# NOTE(review): y and filename_to_idx are not used in the visible code.
	d = load_train_data('preprocessed/cnn/', patient)
	x, y, filename_to_idx = d['x'], d['y'], d['filename_to_idx']
	x_test = load_test_data('preprocessed/cnn/', patient)['x']

	# train_predict_test_cnn returns test predictions (two variants),
	# validation predictions with their true labels, and both loss histories.
	test_preds,test_preds_ns,val_preds, val_true,train_loss,valid_loss= train_predict_test_cnn(
		patient,CNN(patient),x,x_test,enhance_size = 1000)

	# Validation ROC AUC for this patient.
	roc_area = roc_auc_score(val_true,val_preds)
	print patient, roc_area

	# One loss figure per patient, written to ./figs/CNN<patient>train_val.png.
	plot = plt.figure()
	plot_train_val_loss(train_loss,valid_loss,patient)
	plot.savefig('./figs/CNN'+patient+'train_val.png')

	
	all_predictions.append(test_preds)
def train(subject, data_path, model_path, model_params, validation_params):
    """Train a ConvNet for one subject and pickle it under ``model_path``.

    Two modes, selected by ``model_params['overlap']``:

    * truthy -- build overlapped windows from the grouped training data,
      train for a fixed 175000 iterations without validation, save, return.
    * falsy  -- hold out a validation set (a stratified random 25% split if
      ``validation_params['random_split']``, otherwise the file lists from
      ``split_train_valid_filenames``), call ``cnn.validate`` and use the
      value it returns as the iteration count for a second network trained
      on the full training set, which is then saved.

    The saved pickle (``<model_path>/<subject>.pickle``) is
    ``cnn.get_state()`` plus a ``'scalers'`` entry.  ``cnn.validate`` is also
    given ``<model_path>/<subject>.txt`` as ``fname_out``.

    Note: ``model_params`` is modified in place -- ``n_channels``,
    ``n_fbins`` and ``n_timesteps`` are filled in from the data shape.

    Returns None.
    """
    def load_group_index():
        # Read the pickle in binary mode and close the handle; the previous
        # cPickle.load(open(...)) left a text-mode file open.
        with open('filenames.pickle', 'rb') as f:
            return cPickle.load(f)

    def save_model(net, fitted_scalers):
        # Store the network state together with the scalers it depends on.
        state_dict = net.get_state()
        state_dict['scalers'] = fitted_scalers
        with open(model_path + '/' + subject + '.pickle', 'wb') as f:
            cPickle.dump(state_dict, f, protocol=cPickle.HIGHEST_PROTOCOL)

    def load_test_x():
        return load_test_data(data_path, subject)['x'] if model_params['use_test'] else None

    d = load_train_data(data_path, subject)
    x, y, filename_to_idx = d['x'], d['y'], d['filename_to_idx']
    x_test = load_test_x()

    # --------- add params
    model_params['n_channels'] = x.shape[1]
    model_params['n_fbins'] = x.shape[2]
    model_params['n_timesteps'] = x.shape[3]

    # Pre-formatted strings give the same output under the Python 2 print
    # statement and the Python 3 print function.
    print('============ parameters')
    for key, value in model_params.items():
        print('%s : %s' % (key, value))
    print('========================')

    # Both helpers take the same keywords, so the choice is made once.
    scale = scale_across_time if model_params['scale_time'] else scale_across_features

    if model_params['overlap']:
        # no validation if overlap
        data_grouped_by_hour = load_grouped_train_data(data_path, subject, load_group_index())
        x, y = generate_overlapped_data(data_grouped_by_hour, overlap_size=model_params['overlap'],
                                        window_size=x.shape[-1],
                                        overlap_interictal=True,
                                        overlap_preictal=True)
        print(x.shape)

        x, scalers = scale(x, x_test=None)

        cnn = ConvNet(model_params)
        cnn.train(train_set=(x, y), max_iter=175000)
        save_model(cnn, scalers)
        return

    if validation_params['random_split']:
        skf = StratifiedShuffleSplit(y, n_iter=1, test_size=0.25, random_state=0)
        # n_iter=1, so the splitter yields exactly one split.
        train_idx, valid_idx = next(iter(skf))
    else:
        d = split_train_valid_filenames(subject, load_group_index())
        # NOTE(review): 'valid_filnames' (sic) is kept as-is -- presumably it
        # matches the key split_train_valid_filenames returns; confirm.
        train_filenames, valid_filenames = d['train_filenames'], d['valid_filnames']
        train_idx = [filename_to_idx[i] for i in train_filenames]
        valid_idx = [filename_to_idx[i] for i in valid_filenames]
    x_train, y_train = x[train_idx], y[train_idx]
    x_valid, y_valid = x[valid_idx], y[valid_idx]

    # Scalers come from the training part and are reused for validation.
    x_train, scalers_train = scale(x=x_train, x_test=x_test)
    x_valid, _ = scale(x=x_valid, x_test=x_test, scalers=scalers_train)

    del x, x_test

    print('============ dataset')
    print('train: %s' % (x_train.shape,))
    print('n_pos: %s n_neg: %s' % (np.sum(y_train), len(y_train) - np.sum(y_train)))
    print('valid: %s' % (x_valid.shape,))
    print('n_pos: %s n_neg: %s' % (np.sum(y_valid), len(y_valid) - np.sum(y_valid)))

    # -------------- validate
    cnn = ConvNet(model_params)
    best_iter = cnn.validate(train_set=(x_train, y_train),
                             valid_set=(x_valid, y_valid),
                             valid_freq=validation_params['valid_freq'],
                             max_iter=validation_params['max_iter'],
                             fname_out=model_path + '/' + subject + '.txt')

    # ---------------- scale
    # The full data was released above, so it is loaded again here.
    d = load_train_data(data_path, subject)
    x, y = d['x'], d['y']
    x_test = load_test_x()

    x, scalers = scale(x=x, x_test=x_test)
    del x_test

    cnn = ConvNet(model_params)
    cnn.train(train_set=(x, y), max_iter=best_iter)
    save_model(cnn, scalers)