import logging
import os

import numpy as np
import pandas as pd
import torch
from tqdm import trange

import utils


def evaluate(model, loss_fn, data_iterator, params, num_steps):
    """Evaluate the model on `num_steps` batches.

    Args:
        model: (torch.nn.Module) the neural network
        loss_fn: a function that takes batch_output and batch_labels and
            computes the loss for the batch
        data_iterator: (generator) a generator that generates batches of data
            and labels
        params: (Params) hyperparameters
        num_steps: (int) number of batches to evaluate on, each of size
            params.batch_size
    """
    # set model to evaluation mode
    model.eval()

    # compute metrics over the dataset
    running_auc = utils.OutputAUC()
    running_metrics = utils.TestMetrics()

    # Use tqdm for progress bar
    with torch.no_grad():
        t = trange(num_steps)
        for _ in t:
            # fetch the next evaluation batch
            train_batch_w2v, train_batch_sp, labels_batch, ids = next(
                data_iterator)

            # pick the input(s) matching the configured embedding(s)
            if 'w2v' in params.emb:
                output_batch = model(train_batch_w2v)
            elif 'sp' in params.emb:
                output_batch = model(train_batch_sp)
            else:
                output_batch = model(train_batch_w2v, train_batch_sp)

            loss_fn(output_batch, labels_batch)  # loss is not tracked here

            running_auc.update(labels_batch.data.cpu().numpy(),
                               output_batch.data.cpu().numpy())
            running_metrics.update(labels_batch.data.cpu().numpy(),
                                   output_batch.data.cpu().numpy())

    logging.info('AUCROC ' + str(running_auc()))
    logging.info('METRICS ' + str(running_metrics()))
    metrics = running_metrics()
    return {'AUCROC': metrics[0], 'AUCPR': metrics[1]}
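# -- Illustrative sketch (assumption, not the project's utils.py) ------------
# `utils.OutputAUC` and `utils.TestMetrics` are project helpers defined
# elsewhere; the loops in this file only rely on the protocol below:
# accumulate batches with .update(labels, outputs) and read the final scores
# by calling the object. A minimal sklearn-based stand-in might look like
# this (the class name and internals are illustrative):
from sklearn.metrics import average_precision_score, roc_auc_score


class RunningScoresSketch:
    """Accumulates per-batch labels/outputs; calling the object returns
    (AUCROC, AUCPR), the tuple shape unpacked by the callers in this file."""

    def __init__(self):
        self.labels, self.outputs = [], []

    def update(self, labels_batch, output_batch):
        # both arguments arrive as numpy arrays, as in the loops above
        self.labels.append(labels_batch)
        self.outputs.append(output_batch)

    def __call__(self):
        labels = np.concatenate(self.labels).ravel()
        outputs = np.concatenate(self.outputs).ravel()
        return (roc_auc_score(labels, outputs),
                average_precision_score(labels, outputs))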
def build_ngram_contexts(word_indexes, index_to_word):
    """Build the bracketed n-gram strings used for attention interpretation.

    For every position, the attended unigram/bigram/trigram is wrapped in
    square brackets and shown with up to two words of context on each side,
    e.g. "w1 w2 [w3 w4] w5 w6".

    Args:
        word_indexes: (list) token indexes for one document
        index_to_word: (dict) index-to-token mapping from the data loader

    Returns:
        (unigrams, bigrams, trigrams) lists of context strings
    """
    words = [index_to_word[i] for i in word_indexes]
    unigrams, bigrams, trigrams = [], [], []
    for ind in range(len(words)):
        # up to two words of left context; none at the document start
        if ind == 0:
            pre_context = ''
        elif ind == 1:
            pre_context = words[0]
        else:
            pre_context = words[ind - 2] + ' ' + words[ind - 1]
        for n, grams in ((1, unigrams), (2, bigrams), (3, trigrams)):
            end = ind + n
            if end > len(words):
                continue  # the n-gram would run past the document
            following = words[end:end + 2]  # up to two words of right context
            grams.append(pre_context + ' [' + ' '.join(words[ind:end]) +
                         '] ' + ' '.join(following))
    return unigrams, bigrams, trigrams


def evaluate_attn(model, loss_fn, data_iterator, params, num_steps,
                  data_loader, model_dir):
    """Evaluate the model on `num_steps` batches and save attention rankings.

    Args:
        model: (torch.nn.Module) the neural network
        loss_fn: a function that takes batch_output and batch_labels and
            computes the loss for the batch
        data_iterator: (generator) a generator that generates batches of data
            and labels
        params: (Params) hyperparameters
        num_steps: (int) number of batches to evaluate on, each of size
            params.batch_size
        data_loader: data loader that contains index-to-word mappings
        model_dir: model directory where attention rankings will be saved
    """
    # set model to evaluation mode
    model.eval()

    # one row per (attention head, stay): id, head, embedding, label,
    # prediction, then the 50 highest-weighted n-grams with their weights
    master_list = []

    # compute metrics over the dataset
    running_metrics = utils.TestMetrics()

    # Use tqdm for progress bar
    with torch.no_grad():
        t = trange(num_steps)
        for _ in t:
            # fetch the next evaluation batch
            train_batch_w2v, train_batch_sp, labels_batch, ids = next(
                data_iterator)

            # attention interpretation is only supported for single-embedding
            # models ('w2v' or 'sp')
            if 'w2v' in params.emb:
                output_batch, attn_weights = model(train_batch_w2v,
                                                   interpret=True)
                batch_word_indexes = train_batch_w2v[0].tolist()
                index_to_word = data_loader.index_to_word_w2v
                emb_tag = 'w2v'
                # '<CONV_PAD>' entries align batch_text with the extra
                # positions produced by the convolutional layers
                pad = ['<CONV_PAD>']
            elif 'sp' in params.emb:
                output_batch, attn_weights = model(train_batch_sp,
                                                   interpret=True)
                batch_word_indexes = train_batch_sp[0].tolist()
                index_to_word = data_loader.index_to_word_sp
                emb_tag = 'sp300'
                pad = []

            batch_text = []
            for word_indexes in batch_word_indexes:
                unigrams, bigrams, trigrams = build_ngram_contexts(
                    word_indexes, index_to_word)
                batch_text.append(unigrams + bigrams + pad + trigrams +
                                  pad + pad)

            output_list = output_batch.tolist()
            attn_weights_list = [x.tolist() for x in attn_weights]
            labels_batch_list = labels_batch.tolist()
            assert len(ids) == len(batch_text)
            assert len(ids) == len(labels_batch_list)
            assert len(ids) == len(output_list)
            assert len(ids) == len(attn_weights_list[0])

            for head in range(len(attn_weights_list)):
                for index in range(len(ids)):
                    # rank this head's n-grams by attention weight and keep
                    # the top 50, flattened to [weight, text, weight, ...]
                    attn_words = list(
                        zip(attn_weights_list[head][index],
                            batch_text[index]))
                    attn_words.sort(reverse=True)
                    new_attn_words = [
                        x for pair in attn_words[:50] for x in pair
                    ]
                    # row order matches the DataFrame columns below
                    master_list.append(
                        [ids[index], head, emb_tag,
                         labels_batch_list[index][0],
                         output_list[index][0]] + new_attn_words)

            # score the predictions from the interpret forward pass
            loss_fn(output_batch, labels_batch)
            running_metrics.update(labels_batch.data.cpu().numpy(),
                                   output_batch.data.cpu().numpy())

    df_attn_scores = pd.DataFrame(
        master_list,
        columns=["ICUSTAY_ID", 'head', 'embedding', params.task + "_label",
                 params.task + "_prediction"] +
                ['attn_' + str(i // 2) if i % 2 == 0 else
                 'words_' + str(i // 2) for i in range(100)])
    print(df_attn_scores.dtypes)
    df_attn_scores.sort_values(by=[params.task + "_prediction"],
                               ascending=False, inplace=True)
    print(df_attn_scores.head(5))
    dataset_path = os.path.join(model_dir, 'df_attn.csv')
    df_attn_scores.to_csv(dataset_path, index=False)

    logging.info('METRICS ' + str(running_metrics()))
    metrics = running_metrics()
    return {'AUCROC': metrics[0], 'AUCPR': metrics[1]}
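# -- Illustrative usage (assumption, not part of the original pipeline) ------
# The CSV written by evaluate_attn can be inspected offline. The path and
# column names follow the to_csv call above; the helper itself and its
# `head`/`top_k` arguments are hypothetical:
def inspect_attn_rankings(model_dir, task, head=0, top_k=10):
    """Print the `top_k` highest-prediction stays for one attention head,
    with each stay's single strongest-attended n-gram."""
    df = pd.read_csv(os.path.join(model_dir, 'df_attn.csv'))
    top = df[df['head'] == head].nlargest(top_k, task + '_prediction')
    print(top[['ICUSTAY_ID', task + '_prediction', 'attn_0', 'words_0']])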
def allied_final_evaluate(model, loss_fn, data_iterator, metrics, params,
                          num_steps, allied=False):
    """Evaluate the allied (multi-task) model on `num_steps` batches.

    Args:
        model: (torch.nn.Module) the neural network
        loss_fn: a function that takes batch_output and batch_labels and
            computes the loss for the batch
        data_iterator: (generator) a generator that generates batches of data
            and labels
        metrics: (dict) a dictionary of functions that compute a metric using
            the output and labels of each batch
        params: (Params) hyperparameters
        num_steps: (int) number of batches to evaluate on, each of size
            params.batch_size
        allied: (bool) unused; kept for call-site compatibility
    """
    # set model to evaluation mode
    model.eval()

    # summary for current eval loop
    summ = []

    # compute metrics over the dataset
    running_auc = utils.OutputAUC()
    running_metrics = utils.TestMetrics()
    running_icd = utils.MetricsICD()

    # Use tqdm for progress bar
    with torch.no_grad():
        t = trange(num_steps)
        for _ in t:
            # fetch the next evaluation batch
            train_batch_w2v, train_batch_sp, labels_batch, icd_labels, ids = next(
                data_iterator)

            if 'w2v' in params.emb:
                output_batch, icd_batch = model(train_batch_w2v)
            elif 'sp' in params.emb:
                output_batch, icd_batch = model(train_batch_sp)
            else:
                output_batch, icd_batch = model(train_batch_w2v,
                                                train_batch_sp)
            loss = loss_fn(output_batch, labels_batch)

            running_icd.update(icd_labels.data.cpu().numpy(),
                               icd_batch.data.cpu().numpy())
            running_auc.update(labels_batch.data.cpu().numpy(),
                               output_batch.data.cpu().numpy())
            running_metrics.update(labels_batch.data.cpu().numpy(),
                                   output_batch.data.cpu().numpy())

            # extract data from torch Variable, move to cpu, convert to numpy arrays
            output_batch = output_batch.data.cpu().numpy()
            labels_batch = labels_batch.data.cpu().numpy()

            # compute all metrics on this batch
            summary_batch = {
                metric: metrics[metric](output_batch, labels_batch)
                for metric in metrics
            }
            summary_batch['loss'] = loss.data.item()
            summ.append(summary_batch)

    # final test-set scores from the running tracker (a new name so the
    # `metrics` argument is not shadowed)
    test_metrics = running_metrics()
    metrics_mean = {
        metric: np.mean([x[metric] for x in summ])
        for metric in summ[0]
    }
    metrics_string = " ; ".join("{}: {:05.3f}".format(k, v)
                                for k, v in metrics_mean.items())
    logging.info("- Eval metrics : " + metrics_string)
    logging.info('AUCROC ' + str(running_auc()))
    logging.info('AUCROC ' + str(test_metrics[0]))
    logging.info('AUCPR ' + str(test_metrics[1]))
    logging.info('MICRO AUCROC_ICD ' + str(running_icd()))
    macro_auc = running_icd.macro_auc()
    logging.info('MACRO AUCROC_ICD ' + str(macro_auc))
    return {
        'AUCROC': test_metrics[0],
        'AUCPR': test_metrics[1],
        'MICRO_AUCROC_ICD': running_icd(),
        'MACRO_AUCROC_ICD': macro_auc
    }
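# -- Illustrative sketch (assumption, not the project's utils.py) ------------
# `utils.MetricsICD` tracks the multi-label ICD head; the loop above relies
# on micro AUC via calling the object and macro AUC via .macro_auc(). A
# minimal stand-in for that contract (name and internals are illustrative):
from sklearn.metrics import roc_auc_score


class MetricsICDSketch:
    """Accumulates multi-label (n_samples, n_codes) label/score arrays."""

    def __init__(self):
        self.labels, self.outputs = [], []

    def update(self, labels_batch, output_batch):
        self.labels.append(labels_batch)
        self.outputs.append(output_batch)

    def __call__(self):
        # micro AUC: pool every (stay, code) pair into one binary problem
        labels = np.concatenate(self.labels)
        outputs = np.concatenate(self.outputs)
        return roc_auc_score(labels.ravel(), outputs.ravel())

    def macro_auc(self):
        # macro AUC: unweighted mean of per-code AUCs
        labels = np.concatenate(self.labels)
        outputs = np.concatenate(self.outputs)
        return roc_auc_score(labels, outputs, average='macro')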
def evaluate(model, loss_fn, data_iterator, metrics, params, num_steps):
    """Evaluate the model on `num_steps` batches.

    Args:
        model: (torch.nn.Module) the neural network
        loss_fn: a function that takes batch_output and batch_labels and
            computes the loss for the batch
        data_iterator: (generator) a generator that generates batches of data
            and labels
        metrics: (dict) a dictionary of functions that compute a metric using
            the output and labels of each batch
        params: (Params) hyperparameters
        num_steps: (int) number of batches to evaluate on, each of size
            params.batch_size
    """
    # set model to evaluation mode
    model.eval()

    # summary for current eval loop
    summ = []

    # compute metrics over the dataset
    running_auc = utils.OutputAUC()
    running_metrics = utils.TestMetrics()
    running_icd = utils.MetricsICD()

    # Use tqdm for progress bar
    with torch.no_grad():
        t = trange(num_steps)
        for _ in t:
            # fetch the next evaluation batch; the iterator's tuple layout
            # depends on the model and task
            if 'phen' in params.model:
                if params.task == 'icd':
                    train_batch_w2v, train_batch_sp, _, labels_batch = next(
                        data_iterator)
                else:
                    train_batch_w2v, train_batch_sp, labels_batch, _ = next(
                        data_iterator)
                output_batch = model(train_batch_w2v)
                loss = loss_fn(output_batch, labels_batch)
                loss = loss / params.grad_acc  # Normalize our loss (if averaged)
            elif params.model == 'lr':
                train_batch, labels_batch = next(data_iterator)
                output_batch = model(train_batch)
                loss = loss_fn(output_batch, labels_batch)
                loss = loss / params.grad_acc  # Normalize our loss (if averaged)
            else:
                train_batch_w2v, train_batch_sp, labels_batch, _, ids = next(
                    data_iterator)
                output_batch = model(train_batch_w2v)
                loss = loss_fn(output_batch, labels_batch)

            running_auc.update(labels_batch.data.cpu().numpy(),
                               output_batch.data.cpu().numpy())
            running_metrics.update(labels_batch.data.cpu().numpy(),
                                   output_batch.data.cpu().numpy())
            if params.task == 'icd_only':
                running_icd.update(labels_batch.data.cpu().numpy(),
                                   output_batch.data.cpu().numpy())

            # extract data from torch Variable, move to cpu, convert to numpy arrays
            output_batch = output_batch.data.cpu().numpy()
            labels_batch = labels_batch.data.cpu().numpy()

            # compute all metrics on this batch
            summary_batch = {
                metric: metrics[metric](output_batch, labels_batch)
                for metric in metrics
            }
            summary_batch['loss'] = loss.data.item()
            summ.append(summary_batch)

    metrics_mean = {
        metric: np.mean([x[metric] for x in summ])
        for metric in summ[0]
    }
    metrics_string = " ; ".join("{}: {:05.3f}".format(k, v)
                                for k, v in metrics_mean.items())
    logging.info("- Eval metrics : " + metrics_string)
    logging.info('AUCROC ' + str(running_auc()))
    if params.task == 'icd_only':
        return {
            'AUCROC': running_icd(),
            'MACRO_AUCROC_ICD': running_icd.macro_auc()
        }
    else:
        logging.info('METRICS ' + str(running_metrics()))
        test_metrics = running_metrics()
        return {'AUCROC': test_metrics[0], 'AUCPR': test_metrics[1]}
def evaluate_attn(model, loss_fn, data_iterator, metrics, params, num_steps,
                  data_loader, model_dir):
    """Evaluate a w2v-only model on `num_steps` batches and save attention
    rankings.

    Args:
        model: (torch.nn.Module) the neural network
        loss_fn: a function that takes batch_output and batch_labels and
            computes the loss for the batch
        data_iterator: (generator) a generator that generates batches of data
            and labels
        metrics: (dict) a dictionary of functions that compute a metric using
            the output and labels of each batch
        params: (Params) hyperparameters
        num_steps: (int) number of batches to evaluate on, each of size
            params.batch_size
        data_loader: data loader that contains index-to-word mappings
        model_dir: model directory where attention rankings will be saved
    """
    # set model to evaluation mode
    model.eval()

    # summary for current eval loop
    summ = []
    master_list = []

    # compute metrics over the dataset
    running_metrics = utils.TestMetrics()

    # Use tqdm for progress bar
    with torch.no_grad():
        t = trange(num_steps)
        for _ in t:
            # fetch the next evaluation batch
            train_batch_w2v, labels_batch, ids = next(data_iterator)
            output_batch, attn_weights_w2v = model(train_batch_w2v,
                                                   interpret=True)

            batch_word_indexes = train_batch_w2v[0].tolist()
            batch_text = []
            for word_indexes in batch_word_indexes:
                unigrams, bigrams, trigrams = build_ngram_contexts(
                    word_indexes, data_loader.index_to_word_w2v)
                # '<CONV_PAD>' entries align batch_text with the extra
                # positions produced by the convolutional layers
                batch_text.append(unigrams + bigrams + ['<CONV_PAD>'] +
                                  trigrams + ['<CONV_PAD>', '<CONV_PAD>'])

            output_list = output_batch.tolist()
            attn_weights_list = [x.tolist() for x in attn_weights_w2v]
            labels_batch_list = labels_batch.tolist()
            assert len(ids) == len(batch_text)
            assert len(ids) == len(labels_batch_list)
            assert len(ids) == len(output_list)
            assert len(ids) == len(attn_weights_list[0])

            for head in range(len(attn_weights_list)):
                for index in range(len(ids)):
                    # rank this head's n-grams by attention weight and keep
                    # the top 50, flattened to [weight, text, weight, ...]
                    attn_words = list(
                        zip(attn_weights_list[head][index],
                            batch_text[index]))
                    attn_words.sort(reverse=True)
                    new_attn_words = [
                        x for pair in attn_words[:50] for x in pair
                    ]
                    # row order matches the DataFrame columns below
                    master_list.append(
                        [ids[index], head, 'w2v',
                         labels_batch_list[index][0],
                         output_list[index][0]] + new_attn_words)

            loss = loss_fn(output_batch, labels_batch)
            running_metrics.update(labels_batch.data.cpu().numpy(),
                                   output_batch.data.cpu().numpy())

            # extract data from torch Variable, move to cpu, convert to numpy arrays
            output_batch = output_batch.data.cpu().numpy()
            labels_batch = labels_batch.data.cpu().numpy()

            # compute all metrics on this batch
            summary_batch = {
                metric: metrics[metric](output_batch, labels_batch)
                for metric in metrics
            }
            summary_batch['loss'] = loss.data.item()
            summ.append(summary_batch)

    df_attn_scores = pd.DataFrame(
        master_list,
        columns=["ICUSTAY_ID", 'head', 'embedding', params.task + "_label",
                 params.task + "_prediction"] +
                ['attn_' + str(i // 2) if i % 2 == 0 else
                 'words_' + str(i // 2) for i in range(100)])
    print(df_attn_scores.dtypes)
    df_attn_scores.sort_values(by=[params.task + "_prediction"],
                               ascending=False, inplace=True)
    print(df_attn_scores.head(5))
    dataset_path = os.path.join(model_dir, 'df_attn.csv')
    df_attn_scores.to_csv(dataset_path, index=False)

    metrics_mean = {
        metric: np.mean([x[metric] for x in summ])
        for metric in summ[0]
    }
    metrics_string = " ; ".join("{}: {:05.3f}".format(k, v)
                                for k, v in metrics_mean.items())
    logging.info("- Eval metrics : " + metrics_string)
    logging.info('METRICS ' + str(running_metrics()))
    test_metrics = running_metrics()
    return {'AUCROC': test_metrics[0], 'AUCPR': test_metrics[1]}
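# -- Illustrative call site (assumption; the real driver script is not shown)
# Evaluation is typically sized so every test example is seen once. The
# driver below is hypothetical: `test_size`/`test_iterator` are placeholders
# for whatever the project's data loader provides, and it assumes a
# `utils.save_dict_to_json` helper (swap in json.dump if absent):
def run_attention_eval(model, loss_fn, test_iterator, metrics, params,
                       test_size, data_loader, model_dir):
    """Hypothetical driver: cover the test set once, run evaluate_attn,
    and persist the summary metrics next to the attention rankings."""
    num_steps = (test_size + params.batch_size - 1) // params.batch_size
    test_metrics = evaluate_attn(model, loss_fn, test_iterator, metrics,
                                 params, num_steps, data_loader, model_dir)
    utils.save_dict_to_json(test_metrics,
                            os.path.join(model_dir, 'metrics_test.json'))
    return test_metrics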