Ejemplo n.º 1
0
def mean_weights(ids: list, hidden=True, diagonal=True, save_name='default'):
    """ Makes and saves a horizontal bar plot of the mean absolute weight of trained models.

        For every model id in ids, the weight matrix W of the model stored in ../models/<id>/
        is loaded and the mean of |W| is computed. Bars are placed by batch size (y-axis) and
        coloured by params['brain_state'].

        Args:
            ids: Model ids (directory names below ../models/).
            hidden: If False, only the block W[:visible_size, :visible_size] is used.
            diagonal: If False, the diagonal of |W| is set to zero before averaging (the
                diagonal entries still count in the denominator of the mean).
            save_name: Suffix of the output file name.

        Saves:
            ../doc/figures/barplots_meanabs_<save_name>.png
    """
    mean_abs = []
    patient_id = []
    brain_state = []
    batch_size = []

    # Only the numpy weights are needed, so load state dicts directly onto the CPU.
    device = torch.device('cpu')

    for id_ in ids:
        with open('../models/' + id_ + '/params.pkl', 'rb') as f:
            params = pickle.load(f)
        # Get trained model
        model = models.GeneralRNN(params)
        model.load_state_dict(
            torch.load('../models/' + id_ + '/model.pth', map_location=device))
        W = model.W.weight.data.cpu().numpy()

        # Model ids containing 'ID11a' / 'ID11b' are labelled with that string instead of
        # params['patient_id'] (presumably to separate two ID11 recordings -- confirm).
        # A substring test is used because a fixed slice such as id_[5:7] (2 characters)
        # can never equal a 5-character label.
        # NOTE(review): assumes the label appears verbatim somewhere in the id.
        if 'ID11a' in id_:
            patient_id.append('ID11a')
        elif 'ID11b' in id_:
            patient_id.append('ID11b')
        else:
            patient_id.append(params['patient_id'])
        brain_state.append(params['brain_state'])
        batch_size.append(params['batch_size'])

        if hidden is False:
            # Restrict to visible-to-visible connections.
            ch = params['visible_size']
            W = W[:ch, :ch]
        W_abs = np.abs(W)
        if diagonal is False:
            np.fill_diagonal(W_abs, 0)
        mean_abs.append(np.mean(W_abs))

    df = pd.DataFrame()
    df['Patient ID'] = patient_id
    df['Pos. in sleep cylce'] = brain_state
    df['Mean abs. weight'] = mean_abs
    df['Batch size'] = batch_size

    with sns.color_palette('colorblind', 3):
        plt.figure(figsize=(10, 8))
        sns.set_style('whitegrid')
        ax = sns.barplot(x='Mean abs. weight',
                         y='Batch size',
                         hue='Pos. in sleep cylce',
                         data=df,
                         orient='h')
        ax.set(xlabel='Mean abs. weight', ylabel='Batch size')
        ax.set_title('Mean abs. weight')
    plt.savefig('../doc/figures/barplots_meanabs_' + save_name + '.png')
    plt.close()
Ejemplo n.º 2
0
def predict(id_: str, custom_test_set: dict=None):
    """ Tests a trained model, and returns and saves its predicted values.

        If the prediction set is not the training set, pass a custom_test_set dictionary containing:
            'time_begin', 'duration', 'batch_size'

        Args:
            id_: Model id (directory name below ../models/).
            custom_test_set: Optional description of the data to predict on. It is passed to
                pre_process, and its 'batch_size' replaces params['batch_size'].

        Returns:
            dict with keys 'prediction' and 'true'. Each holds the model outputs and the targets,
            concatenated over all batches along the first axis.

        Saves:
            ../models/<id_>/eval_prediction.pkl
    """
    # Define device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model_dir = '../models/' + id_

    # Load data and parameters
    print('Status: Load and process data for prediction.')
    with open(model_dir + '/params.pkl', 'rb') as f:
        params = pickle.load(f)
    if custom_test_set is None:
        data_pre = pre_process(params=params)
        batch_size = params['batch_size']
    else:
        data_pre = pre_process(params=params, custom_test_set=custom_test_set)
        batch_size = custom_test_set['batch_size']
    data_set = iEEG_DataSet(data_pre, params['window_size'])
    # shuffle=False keeps the outputs in data-set order.
    data_generator = torch.utils.data.DataLoader(data_set, batch_size=batch_size, shuffle=False)

    # Make model
    model = models.GeneralRNN(params)
    model.load_state_dict(torch.load(model_dir + '/model.pth', map_location=device))
    model = model.to(device)

    # Evaluate model (eval mode, no gradient tracking)
    model.eval()
    pred_all = []
    true_all = []

    print('Status: Start prediction with cuda = ' + str(torch.cuda.is_available()) + '.')
    with torch.no_grad():
        for X, y in data_generator:
            X, y = X.to(device), y.to(device)
            predictions = model(X)
            pred_all.append(predictions.cpu().numpy())
            true_all.append(y.cpu().numpy())
    pred_all, true_all = np.concatenate(pred_all), np.concatenate(true_all)

    # Save predictions to file
    print('Status: Save predictions to file.')
    eval_prediction = {'prediction': pred_all,
                       'true': true_all}
    with open(model_dir + '/eval_prediction.pkl', 'wb') as f:
        pickle.dump(eval_prediction, f)

    return eval_prediction
Ejemplo n.º 3
0
def train(params):
    """ Trains a models.GeneralRNN with parameters params.

        Args:
            params: Parameter dictionary. Keys read here are 'visible_size', 'window_size',
                'batch_size', 'shuffle', 'loss_function' ('mae' or 'mse'), 'lr', 'epochs',
                'add_id' and 'id_'. Keys read by pre_process / GeneralRNN come on top. If
                params['visible_size'] == 'all', it is replaced in place by the number of
                columns of the pre-processed data.

        Raises:
            ValueError: If params['loss_function'] is neither 'mae' nor 'mse', or if the
                data loader yields no batches.

        Saves (in ../models/<params['id_']>/):
            model.pth               state dict of the trained model
            params.pkl              params (with the resolved 'visible_size')
            eval_optimization.pkl   {'id_', 'loss': per-epoch per-node mean loss,
                                     'grad_norm': per-epoch total gradient norm}
            W_epoch.pkl             {'W_epoch': copies of W taken at the start of each epoch}
    """
    # Define device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Define loss function first, so an invalid choice fails before the data is prepared.
    # reduction='none' keeps the element-wise loss so that it can be reported per node.
    if params['loss_function'] == 'mae':
        criterion = nn.L1Loss(reduction='none')
    elif params['loss_function'] == 'mse':
        criterion = nn.MSELoss(reduction='none')
    else:
        print('Error: No valid loss function.')
        raise ValueError(
            f"Invalid loss function {params['loss_function']!r}; expected 'mae' or 'mse'.")

    # Load data
    print('Status: Start data preparation.')
    data_pre = pre_process(params=params)
    if params['visible_size'] == 'all':
        params['visible_size'] = data_pre.shape[1]
    data_set = iEEG_DataSet(data_pre, params['window_size'])
    data_generator = torch.utils.data.DataLoader(data_set, batch_size=params['batch_size'], shuffle=params['shuffle'])

    # Make model
    model = models.GeneralRNN(params)
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=params['lr'])

    # Training
    epoch_loss = np.zeros([params['epochs'], model.visible_size])
    epoch_grad_norm = np.zeros(params['epochs'])
    W = []  # copies of the weight matrix at the start of every epoch
    add_id = params['add_id']

    start_time = time.time()
    print('Status: Start training with cuda = ' + str(torch.cuda.is_available()) + '.')

    for epoch in range(params['epochs']):
        W.append(np.copy(model.W.weight.data.cpu().numpy()))
        loss_sum = None  # element-wise loss summed over all samples of this epoch (first axis)
        n_samples = 0
        for X, y in data_generator:
            X, y = X.to(device), y.float().to(device)
            optimizer.zero_grad()
            prediction = model(X)
            loss = criterion(prediction, y)
            loss.mean().backward()
            optimizer.step()
            batch_sum = loss.detach().sum(dim=0)
            loss_sum = batch_sum if loss_sum is None else loss_sum + batch_sum
            n_samples += loss.shape[0]
        if n_samples == 0:
            raise ValueError('Training data set is empty.')

        # Total L2 norm over the gradients of all parameters (from the last step of the epoch).
        # Parameters without a gradient are skipped.
        grad_norms = [p.grad.detach().norm(2) for p in model.parameters() if p.grad is not None]
        epoch_grad_norm[epoch] = torch.stack(grad_norms).norm(2).item() if grad_norms else 0.0
        # Mean over every sample of the epoch, not just the last (possibly partial) batch.
        epoch_loss[epoch, :] = (loss_sum / n_samples).cpu().numpy()

        print(f'{add_id} Epoch: {epoch} | Loss: {np.mean(epoch_loss[epoch, :]):.4}')

    total_time = time.time() - start_time
    print(f'Time [min]: {total_time / 60:.3}')

    # Make optimizer evaluation dictionary
    eval_optimization = {'id_': params['id_'],
                         'loss': epoch_loss,
                         'grad_norm': epoch_grad_norm}

    # Save model
    print('Status: Save trained model to file.')
    directory = '../models/' + params['id_']
    os.makedirs(directory, exist_ok=True)
    torch.save(model.state_dict(), directory + '/model.pth')
    with open(directory + '/params.pkl', 'wb') as f:
        pickle.dump(params, f)
    with open(directory + '/eval_optimization.pkl', 'wb') as f:
        pickle.dump(eval_optimization, f)

    W_epoch = {'W_epoch': W}
    with open(directory + '/W_epoch.pkl', 'wb') as f:
        pickle.dump(W_epoch, f)
Ejemplo n.º 4
0
def plot_weights(id_: str, vmax=1, linewidth=0, absolute=False):
    """ Makes and saves a heat map of the weight matrix W of a trained model.

        If W has exactly visible_size rows, a single heat map is drawn. Otherwise, W is split
        into four panels (to visible/hidden x from visible/hidden), which share one colour bar.

        Args:
            id_: Model id (directory name below ../models/).
            vmax: Upper limit of the colour scale. The lower limit is -vmax, or 0 if absolute.
            linewidth: Width of the lines between heat-map cells.
            absolute: If True, plot |W| with the sequential 'Blues' colour map.

        Saves:
            ../doc/figures/weights_<id_>.png
    """
    with open('../models/' + id_ + '/params.pkl', 'rb') as f:
        params = pickle.load(f)
    # Get trained model
    model = models.GeneralRNN(params)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.load_state_dict(
        torch.load('../models/' + id_ + '/model.pth', map_location=device))
    W = model.W.weight.data.cpu().numpy()

    vmin = -vmax
    cmap = 'bwr'
    ch = params['visible_size']

    if absolute:
        vmin = 0
        cmap = 'Blues'
        W = np.abs(W)

    # Keyword arguments shared by every heat map in this figure.
    heatmap_kw = dict(cmap=cmap, vmin=vmin, vmax=vmax, linewidths=linewidth)

    n_nodes = W.shape[0]
    fig = plt.figure(figsize=(10, 10))
    gs = fig.add_gridspec(nrows=n_nodes, ncols=n_nodes)
    cbar_ax = fig.add_axes([.92, .11, .02, .77])  # x-pos,y-pos,width,height

    if n_nodes == ch:
        ax0 = fig.add_subplot(gs[:ch, :ch])
        sns.heatmap(W[:ch, :ch], cbar_ax=cbar_ax, ax=ax0, **heatmap_kw)
        ax0.set_ylabel('to visible nodes')
        ax0.set_xlabel('from visible nodes')
        ax0.set_title('Weight matrix of ' + params['id_'])

    else:
        ax0 = fig.add_subplot(gs[:ch, :ch])
        ax0.get_xaxis().set_visible(False)
        ax1 = fig.add_subplot(gs[:ch, ch:])
        ax1.get_xaxis().set_visible(False)
        ax1.get_yaxis().set_visible(False)
        ax2 = fig.add_subplot(gs[ch:, :ch])

        ax3 = fig.add_subplot(gs[ch:, ch:])
        ax3.get_yaxis().set_visible(False)

        # Only the last panel draws the shared colour bar.
        sns.heatmap(W[:ch, :ch], cbar=False, ax=ax0, **heatmap_kw)
        sns.heatmap(W[:ch, ch:], cbar=False, ax=ax1, **heatmap_kw)
        sns.heatmap(W[ch:, :ch], cbar=False, ax=ax2, **heatmap_kw)
        sns.heatmap(W[ch:, ch:], cbar_ax=cbar_ax, ax=ax3, **heatmap_kw)

        # Figure-coordinate positions of the block labels. The computation approximates the
        # grid as spanning 0.1-0.9 of the figure, i.e. 0.8 / n_nodes per node. y is measured
        # from the bottom, where the hidden rows are.
        unit = 0.8 / n_nodes
        hidden_size = params['hidden_size']
        pos_to_vis = unit * hidden_size + unit * (ch / 2) + 0.1
        pos_to_hid = unit * (hidden_size / 2) + 0.1
        pos_from_vis = unit * (ch / 2) + 0.1
        pos_from_hid = unit * ch + unit * (hidden_size / 2) + 0.1
        fig.text(0.08,
                 pos_to_vis,
                 'to visible node',
                 va='center',
                 ha='center',
                 rotation='vertical')
        fig.text(0.08,
                 pos_to_hid,
                 'to hidden node',
                 va='center',
                 ha='center',
                 rotation='vertical')
        fig.text(pos_from_vis,
                 0.06,
                 'from visible node',
                 va='center',
                 ha='center')
        fig.text(pos_from_hid,
                 0.06,
                 'from hidden node',
                 va='center',
                 ha='center')
        fig.subplots_adjust(hspace=0.8, wspace=0.8)

    plt.suptitle('Weight matrix of model ' + params['id_'])
    fig.savefig('../doc/figures/weights_' + id_ + '.png')
    plt.close()
Ejemplo n.º 5
0
def plot_weighted_prediction(id_, node_idx, max_duration=.5):
    """ Plots how the channels contribute to the prediction of node node_idx.

        Left panel: heat map of W, with row node_idx outlined. Right panel: the last
        max_duration seconds of every data channel, coloured by its weight W[node_idx, i] and
        with opacity proportional to |W[node_idx, i]|. Channel node_idx itself is replaced by
        its prediction. The predicted and true signals of node node_idx are drawn on top.

        Args:
            id_: Model id (directory name below ../models/). eval_prediction.pkl must already
                exist there (see predict).
            node_idx: Index of the node whose prediction is analysed.
            max_duration: Length of the plotted segment in seconds. It is clipped to the
                length of the stored prediction.

        Raises:
            ValueError: If max_duration corresponds to less than one sample.

        Saves:
            ../doc/figures/contribution_<id_>_node_<node_idx>.png
    """
    model_dir = '../models/' + id_

    # Get model
    with open(model_dir + '/params.pkl', 'rb') as f:
        params = pickle.load(f)
    fs = params['resample']
    model = models.GeneralRNN(params)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.load_state_dict(
        torch.load(model_dir + '/model.pth', map_location=device))
    W = model.W.weight.data.cpu().numpy()
    w = W[node_idx, :]  # incoming weights of node node_idx

    # Get prediction
    with open(model_dir + '/eval_prediction.pkl', 'rb') as f:
        eval_prediction = pickle.load(f)
    n_pred = eval_prediction['prediction'].shape[0]
    # Number of trailing samples to plot, clipped to the available prediction. It must be
    # >= 1, because a slice [-0:] would select the whole array.
    n_samples = min(int(max_duration * fs), n_pred)
    if n_samples < 1:
        raise ValueError('max_duration is shorter than one sample.')
    prediction = eval_prediction['prediction'][-n_samples:, node_idx]
    true = eval_prediction['true'][-n_samples:, node_idx]

    # Get data, aligned with the prediction at the end of the recording.
    data = utrain.pre_process(params=params).numpy()
    data = data[-n_samples:, :]
    data[:, node_idx] = prediction

    sns.set_style('white')
    fig = plt.figure(figsize=(12, 4.3))
    gs = fig.add_gridspec(1, 5)

    ax0 = fig.add_subplot(gs[:, :2])
    sns.heatmap(W, cmap='seismic', vmin=-1, vmax=1, ax=ax0)
    # Outline row node_idx over the first data.shape[1] columns.
    ax0.add_patch(
        mpl.patches.Rectangle((0, node_idx),
                              data.shape[1],
                              1,
                              fill=False,
                              edgecolor='black',
                              lw=3))
    ax0.set_xlabel('From node')
    ax0.set_ylabel('To node')
    ax0.set_title('Weight matrix')

    ax1 = fig.add_subplot(gs[:, 2:])
    # Integer-based time axis, so that len(t) == n_samples exactly (a float-step arange can
    # produce one extra element).
    t = np.arange(n_samples) / fs
    cmap = plt.get_cmap('seismic')
    w_max = np.max(np.abs(w))
    for i in range(data.shape[1]):
        # Map a weight in [-1, 1] onto [0, 1] of the diverging colour map.
        color = cmap(w[i] / 2 + 0.5)
        alpha = np.abs(w[i]) / w_max if w_max > 0 else 0.0
        ax1.plot(t, data[:, i], color=color, alpha=alpha)
    ax1.plot(t, prediction, color='black', linestyle=':', label='predicted')
    ax1.plot(t, true, color='black', label='true')
    ax1.set_xlabel('Time [s]')
    ax1.set_ylabel('Membrane potential u(t) [a.U.]')
    ax1.set_xlim(0, t[-1])
    ax1.set_title('Contribution to prediction of node ' + str(node_idx))
    ax1.legend()
    plt.tight_layout()
    plt.savefig('../doc/figures/contribution_' + id_ + '_node_' +
                str(node_idx) + '.png')
    plt.close()