def mean_weights(ids: list, hidden=True, diagonal=True, save_name='default'):
    """Bar-plot the mean absolute weight of ``model.W`` for several models.

    For every model ID the trained ``GeneralRNN`` is loaded from
    ``../models/<id>/`` and ``mean(|W|)`` of ``model.W.weight`` is computed.
    The results are drawn as a horizontal bar plot of mean absolute weight
    per batch size, coloured by brain state.

    Args:
        ids: Model IDs (sub-directory names in ``../models/``).
        hidden: If False, only the block ``W[:visible_size, :visible_size]``
            is used.
        diagonal: If False, the diagonal of the (possibly cropped) ``|W|`` is
            set to zero before averaging; the zeroed entries still count in
            the mean.
        save_name: Suffix of the output file name.

    Saves:
        ../doc/figures/barplots_meanabs_<save_name>.png
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    mean_abs = []
    patient_id = []
    brain_state = []
    batch_size = []
    for id_ in ids:
        with open('../models/' + id_ + '/params.pkl', 'rb') as f:
            params = pickle.load(f)

        # Get trained model
        model = models.GeneralRNN(params)
        model.load_state_dict(
            torch.load('../models/' + id_ + '/model.pth', map_location=device))
        W = model.W.weight.data.cpu().numpy()

        # 'ID11a' / 'ID11b' are taken from the model ID rather than from
        # params['patient_id'] (presumably both share one patient ID).
        # NOTE(review): a 2-char slice [5:7] could never equal a 5-char
        # string, so [7:12] (the span the original comment pointed to) is
        # used; confirm against the model-ID naming scheme.
        sub_id = id_[7:12]
        if sub_id in ('ID11a', 'ID11b'):
            patient_id.append(sub_id)
        else:
            patient_id.append(params['patient_id'])
        brain_state.append(params['brain_state'])
        batch_size.append(params['batch_size'])

        if hidden is False:
            ch = params['visible_size']
            W_abs = np.abs(W[:ch, :ch])
        else:
            W_abs = np.abs(W)
        if diagonal is False:
            # np.abs returned a fresh array, so this does not touch W.
            np.fill_diagonal(W_abs, 0)
        mean_abs.append(np.mean(W_abs))

    df = pd.DataFrame()
    df['Patient ID'] = patient_id
    df['Pos. in sleep cylce'] = brain_state
    df['Mean abs. weight'] = mean_abs
    df['Batch size'] = batch_size

    with sns.color_palette('colorblind', 3):
        plt.figure(figsize=(10, 8))
        sns.set_style('whitegrid')
        ax = sns.barplot(x='Mean abs. weight', y='Batch size',
                         hue='Pos. in sleep cylce', data=df, orient='h')
        ax.set(xlabel='Mean abs. weight', ylabel='Batch size')
        ax.set_title('Mean abs. weight')
        plt.savefig('../doc/figures/barplots_meanabs_' + save_name + '.png')
        plt.close()
def predict(id_: str, custom_test_set: dict = None):
    """Run a trained model on a data set and return and save its predictions.

    By default the data set described by the model's own parameters
    (``params.pkl``) is used. To predict on a different data set, pass
    ``custom_test_set``, a dict containing ``'time_begin'``, ``'duration'``
    and ``'batch_size'``. It is forwarded to ``pre_process``, and its
    ``'batch_size'`` is used for the data loader.

    Args:
        id_: Model ID (sub-directory name in ``../models/``).
        custom_test_set: Optional description of a custom prediction set.

    Returns:
        dict with keys ``'prediction'`` (model outputs) and ``'true'``
        (targets), each concatenated over all batches in loader order.

    Saves:
        ../models/<id_>/eval_prediction.pkl
    """
    # Define device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model_dir = '../models/' + id_

    # Load data and parameters
    print('Status: Load and process data for prediction.')
    with open(model_dir + '/params.pkl', 'rb') as f:
        params = pickle.load(f)
    if custom_test_set is None:
        data_pre = pre_process(params=params)
        batch_size = params['batch_size']
    else:
        data_pre = pre_process(params=params, custom_test_set=custom_test_set)
        batch_size = custom_test_set['batch_size']
    data_set = iEEG_DataSet(data_pre, params['window_size'])
    # shuffle=False keeps predictions in temporal order.
    data_generator = torch.utils.data.DataLoader(data_set,
                                                 batch_size=batch_size,
                                                 shuffle=False)

    # Make model
    model = models.GeneralRNN(params)
    model.load_state_dict(torch.load(model_dir + '/model.pth',
                                     map_location=device))
    model = model.to(device)

    # Evaluate model
    model.eval()
    pred_all = []
    true_all = []
    print('Status: Start prediction with cuda = '
          + str(torch.cuda.is_available()) + '.')
    with torch.no_grad():
        for X, y in data_generator:
            X, y = X.to(device), y.to(device)
            predictions = model(X)
            pred_all.append(predictions.cpu().numpy())
            true_all.append(y.cpu().numpy())
    pred_all, true_all = np.concatenate(pred_all), np.concatenate(true_all)

    # Save predictions to file
    print('Status: Save predictions to file.')
    eval_prediction = {'prediction': pred_all, 'true': true_all}
    with open(model_dir + '/eval_prediction.pkl', 'wb') as f:
        pickle.dump(eval_prediction, f)

    return eval_prediction
def train(params):
    """Train a ``GeneralRNN`` with the hyper-parameters ``params`` and save it.

    ``params`` is updated in place and saved in its updated form: if
    ``params['visible_size']`` is ``'all'``, it is replaced by the number of
    channels of the pre-processed data.

    Per epoch this records:
        * ``model.W.weight`` at the start of the epoch,
        * the per-node loss of the last batch (mean over that batch),
        * the total L2 norm over all parameter gradients of the last batch.

    Args:
        params: Hyper-parameter dict. Keys read here include
            'loss_function' ('mae' or 'mse'), 'lr', 'epochs', 'batch_size',
            'shuffle', 'window_size', 'visible_size', 'add_id' and 'id_'.

    Raises:
        ValueError: If ``params['loss_function']`` is neither 'mae' nor 'mse',
            or if the training data loader yields no batches.

    Saves (in ``../models/<params['id_']>/``):
        model.pth, params.pkl, eval_optimization.pkl, W_epoch.pkl
    """
    # Validate the loss before the (potentially slow) data pre-processing.
    # reduction='none' keeps element-wise losses so they can be logged
    # per node.
    if params['loss_function'] == 'mae':
        criterion = nn.L1Loss(reduction='none')
    elif params['loss_function'] == 'mse':
        criterion = nn.MSELoss(reduction='none')
    else:
        raise ValueError('Error: No valid loss function: '
                         + repr(params['loss_function']))

    # Define device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load data
    print('Status: Start data preparation.')
    data_pre = pre_process(params=params)
    if params['visible_size'] == 'all':
        params['visible_size'] = data_pre.shape[1]
    data_set = iEEG_DataSet(data_pre, params['window_size'])
    data_generator = torch.utils.data.DataLoader(
        data_set, batch_size=params['batch_size'], shuffle=params['shuffle'])

    # Make model
    model = models.GeneralRNN(params).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=params['lr'])

    # Training
    epochs = params['epochs']
    add_id = params['add_id']
    epoch_loss = np.zeros([epochs, model.visible_size])
    epoch_grad_norm = np.zeros(epochs)
    W = []
    start_time = time.time()
    print('Status: Start training with cuda = '
          + str(torch.cuda.is_available()) + '.')
    for epoch in range(epochs):
        # Copy: on CPU, .numpy() shares memory with the live parameter.
        W.append(np.copy(model.W.weight.data.cpu().numpy()))
        loss = None
        for X, y in data_generator:
            X, y = X.to(device), y.float().to(device)
            optimizer.zero_grad()
            prediction = model(X)
            loss = criterion(prediction, y)
            loss.mean().backward()
            optimizer.step()
        if loss is None:
            raise ValueError('Error: Training data set yields no batches.')

        # Total L2 norm over the gradients of all parameters (last batch),
        # i.e. sqrt of the sum of the squared per-parameter norms.
        grad_norms = [p.grad.detach().norm(2).item()
                      for p in model.parameters() if p.grad is not None]
        epoch_grad_norm[epoch] = np.sqrt(np.sum(np.square(grad_norms)))
        # Per-node loss of the last batch, averaged over its samples.
        epoch_loss[epoch, :] = np.mean(loss.detach().cpu().numpy(), axis=0)
        print(f'{add_id} Epoch: {epoch} | Loss: {np.mean(epoch_loss[epoch, :]):.4}')

    total_time = time.time() - start_time
    print(f'Time [min]: {total_time / 60:.3}')

    # Make optimizer evaluation dictionary
    eval_optimization = {'id_': params['id_'], 'loss': epoch_loss,
                         'grad_norm': epoch_grad_norm}

    # Save model
    print('Status: Save trained model to file.')
    directory = '../models/' + params['id_']
    os.makedirs(directory, exist_ok=True)
    torch.save(model.state_dict(), directory + '/model.pth')
    with open(directory + '/params.pkl', 'wb') as f:
        pickle.dump(params, f)
    with open(directory + '/eval_optimization.pkl', 'wb') as f:
        pickle.dump(eval_optimization, f)
    W_epoch = {'W_epoch': W}
    with open(directory + '/W_epoch.pkl', 'wb') as f:
        pickle.dump(W_epoch, f)
def plot_weights(id_: str, vmax=1, linewidth=0, absolute=False):
    """Make and save a heat map of the weight matrix ``W`` of a trained model.

    If ``W`` has exactly ``visible_size`` rows, a single heat map is drawn.
    Otherwise ``W`` is split into its visible/hidden blocks (to-visible and
    to-hidden rows, from-visible and from-hidden columns), which are drawn as
    four heat maps sharing one colour bar.

    Args:
        id_: Model ID (sub-directory name in ``../models/``).
        vmax: Upper limit of the colour scale. The lower limit is ``-vmax``,
            or 0 when ``absolute`` is True.
        linewidth: Passed to ``sns.heatmap`` as ``linewidths``.
        absolute: If True, plot ``|W|`` with the 'Blues' colour map instead
            of ``W`` with 'bwr'.

    Saves:
        ../doc/figures/weights_<id_>.png
    """
    with open('../models/' + id_ + '/params.pkl', 'rb') as f:
        params = pickle.load(f)

    # Get trained model
    model = models.GeneralRNN(params)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.load_state_dict(
        torch.load('../models/' + id_ + '/model.pth', map_location=device))
    W = model.W.weight.data.cpu().numpy()

    vmin = -vmax
    cmap = 'bwr'
    ch = params['visible_size']
    if absolute:
        vmin = 0
        cmap = 'Blues'
        W = np.abs(W)

    fig = plt.figure(figsize=(10, 10))
    gs = fig.add_gridspec(nrows=W.shape[0], ncols=W.shape[0])
    cbar_ax = fig.add_axes([.92, .11, .02, .77])  # x-pos,y-pos,width,height

    if W.shape[0] == ch:
        # Visible nodes only: one heat map.
        ax0 = fig.add_subplot(gs[:ch, :ch])
        sns.heatmap(W[:ch, :ch], cmap=cmap, vmin=vmin, vmax=vmax,
                    cbar_ax=cbar_ax, linewidths=linewidth, ax=ax0)
        ax0.set_ylabel('to visible nodes')
        ax0.set_xlabel('from visible nodes')
        ax0.set_title('Weight matrix of ' + params['id_'])
    else:
        # Visible + hidden nodes: one heat map per block.
        ax0 = fig.add_subplot(gs[:ch, :ch])
        ax0.get_xaxis().set_visible(False)
        ax1 = fig.add_subplot(gs[:ch, ch:])
        ax1.get_xaxis().set_visible(False)
        ax1.get_yaxis().set_visible(False)
        ax2 = fig.add_subplot(gs[ch:, :ch])
        ax3 = fig.add_subplot(gs[ch:, ch:])
        ax3.get_yaxis().set_visible(False)
        sns.heatmap(W[:ch, :ch], cmap=cmap, vmin=vmin, vmax=vmax, cbar=False,
                    linewidths=linewidth, ax=ax0)
        sns.heatmap(W[:ch, ch:], cmap=cmap, vmin=vmin, vmax=vmax, cbar=False,
                    linewidths=linewidth, ax=ax1)
        sns.heatmap(W[ch:, :ch], cmap=cmap, vmin=vmin, vmax=vmax, cbar=False,
                    linewidths=linewidth, ax=ax2)
        sns.heatmap(W[ch:, ch:], cmap=cmap, vmin=vmin, vmax=vmax,
                    cbar_ax=cbar_ax, linewidths=linewidth, ax=ax3)

        # Axis-label positions as figure fractions, centred on each block.
        # NOTE(review): assumes the grid spans roughly 0.1-0.9 of the
        # figure and that W.shape[0] == visible_size + hidden_size.
        n = W.shape[0]
        hidden_size = params['hidden_size']
        pos_to_vis = 0.8 / n * hidden_size + 0.8 / n * (ch / 2) + 0.1
        pos_to_hid = 0.8 / n * (hidden_size / 2) + 0.1
        pos_from_vis = 0.8 / n * (ch / 2) + 0.1
        pos_from_hid = 0.8 / n * ch + 0.8 / n * (hidden_size / 2) + 0.1
        fig.text(0.08, pos_to_vis, 'to visible node', va='center',
                 ha='center', rotation='vertical')
        fig.text(0.08, pos_to_hid, 'to hidden node', va='center',
                 ha='center', rotation='vertical')
        fig.text(pos_from_vis, 0.06, 'from visible node', va='center',
                 ha='center')
        fig.text(pos_from_hid, 0.06, 'from hidden node', va='center',
                 ha='center')
        fig.subplots_adjust(hspace=0.8, wspace=0.8)
        plt.suptitle('Weight matrix of model ' + params['id_'])

    fig.savefig('../doc/figures/weights_' + id_ + '.png')
    plt.close()
def plot_weighted_prediction(id_, node_idx, max_duration=.5):
    """Plot how the inputs of one node contribute to its prediction.

    The left panel shows the weight matrix ``W`` with row ``node_idx``
    framed. The right panel shows the last ``max_duration`` seconds of the
    recorded signals. Each signal is coloured by its weight into
    ``node_idx``, with opacity proportional to ``|weight|``. The predicted
    (dotted) and true (solid) signal of ``node_idx`` are overlaid in black.

    Args:
        id_: Model ID (sub-directory name in ``../models/``). The model must
            already have an ``eval_prediction.pkl`` (see ``predict``).
        node_idx: Index of the node whose prediction is shown.
        max_duration: Length in seconds of the plotted tail. It is reduced
            to whole seconds if it exceeds the available prediction.

    Saves:
        ../doc/figures/contribution_<id_>_node_<node_idx>.png
    """
    # Get model
    with open('../models/' + id_ + '/params.pkl', 'rb') as f:
        params = pickle.load(f)
    fs = params['resample']
    model = models.GeneralRNN(params)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.load_state_dict(
        torch.load('../models/' + id_ + '/model.pth', map_location=device))
    W = model.W.weight.data.cpu().numpy()
    w = W[node_idx, :]  # incoming weights of node_idx

    # Get prediction
    with open('../models/' + id_ + '/eval_prediction.pkl', 'rb') as f:
        eval_prediction = pickle.load(f)
    n_total = eval_prediction['prediction'].shape[0]
    if max_duration * fs >= n_total:
        max_duration = int((n_total - 1) / fs)
    n_samples = int(max_duration * fs)
    if n_samples <= 0:
        # A tail slice [-0:] would select everything: the whole prediction
        # but also the whole (longer) recording. Use the full prediction
        # length for both so the plotted arrays line up.
        n_samples = n_total
    prediction = eval_prediction['prediction'][-n_samples:, node_idx]
    true = eval_prediction['true'][-n_samples:, node_idx]

    # Get data (assumes predictions align with the tail of the recording)
    data = utrain.pre_process(params=params).numpy()
    data = data[-n_samples:, :]
    data[:, node_idx] = prediction

    sns.set_style('white')
    fig = plt.figure(figsize=(12, 4.3))
    gs = fig.add_gridspec(1, 5)

    ax0 = fig.add_subplot(gs[:, :2])
    sns.heatmap(W, cmap='seismic', vmin=-1, vmax=1, ax=ax0)
    ax0.add_patch(
        mpl.patches.Rectangle((0, node_idx), data.shape[1], 1, fill=False,
                              edgecolor='black', lw=3))
    ax0.set_xlabel('From node')
    ax0.set_ylabel('To node')
    ax0.set_title('Weight matrix')

    ax1 = fig.add_subplot(gs[:, 2:])
    # Integer-based time axis: a float-step arange can produce one sample
    # too many and break plotting.
    t = np.arange(data.shape[0]) / fs
    cmap = plt.get_cmap('seismic')
    w_max = np.max(np.abs(w))
    # NOTE(review): the last data column is skipped, as in the original;
    # confirm whether range(data.shape[1]) was intended.
    for i in range(data.shape[1] - 1):
        color = cmap(w[i] / 2 + 0.5)  # map weight in [-1, 1] to [0, 1]
        alpha = np.abs(w[i]) / w_max if w_max > 0 else 0.0
        ax1.plot(t, data[:, i], color=color, alpha=alpha)
    ax1.plot(t, prediction, color='black', linestyle=':', label='predicted')
    ax1.plot(t, true, color='black', label='true')
    ax1.set_xlabel('Time [s]')
    ax1.set_ylabel('Membrane potential u(t) [a.U.]')
    ax1.set_xlim(0, t[-1])
    ax1.set_title('Contribution to prediction of node ' + str(node_idx))
    ax1.legend()

    plt.tight_layout()
    plt.savefig('../doc/figures/contribution_' + id_ + '_node_'
                + str(node_idx) + '.png')
    plt.close()