Code Example #1
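All five examples below are excerpted from the same evaluation module. A sketch of the imports they rely on follows; `utils` is assumed to be the repository's own helper module, which must provide `OutputAUC`, `TestMetrics`, and `MetricsICD`.

import logging
import os

import numpy as np
import pandas as pd
import torch
from tqdm import trange

import utils  # assumed: repository module with OutputAUC, TestMetrics, MetricsICD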
def evaluate(model, loss_fn, data_iterator, params, num_steps):
    """Evaluate the model on `num_steps` batches.
    Args:
        model: (torch.nn.Module) the neural network
        loss_fn: a function that takes batch_output and batch_labels and computes the loss for the batch
        data_iterator: (generator) a generator that generates batches of data and labels
        params: (Params) hyperparameters
        num_steps: (int) number of batches to evaluate on, each of size params.batch_size
    """

    # set model to evaluation mode
    model.eval()
    # compute metrics over the dataset
    running_auc = utils.OutputAUC()
    running_metrics = utils.TestMetrics()
    # Use tqdm for progress bar
    with torch.no_grad():
        t = trange(num_steps)
        for _ in t:
            # fetch the next evaluation batch
            train_batch_w2v, train_batch_sp, labels_batch, ids = next(
                data_iterator)
            if 'w2v' in params.emb:
                output_batch = model(train_batch_w2v)
            elif 'sp' in params.emb:
                output_batch = model(train_batch_sp)
            else:
                output_batch = model(train_batch_w2v, train_batch_sp)
            loss_fn(output_batch, labels_batch)  # loss is computed but unused here
            running_auc.update(labels_batch.data.cpu().numpy(),
                               output_batch.data.cpu().numpy())
            running_metrics.update(labels_batch.data.cpu().numpy(),
                                   output_batch.data.cpu().numpy())

    logging.info('AUCROC ' + str(running_auc()))
    metrics = running_metrics()
    logging.info('METRICS ' + str(metrics))
    return {'AUCROC': metrics[0], 'AUCPR': metrics[1]}
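A minimal sketch of how `evaluate` might be driven; `model`, `eval_batches`, `num_eval_examples`, and the `Params` stand-in below are illustrative assumptions, not the repository's actual objects.

# Hypothetical driver; every name below is an illustrative stand-in.
from types import SimpleNamespace

params = SimpleNamespace(emb='w2v', batch_size=32)  # stand-in for Params
loss_fn = torch.nn.BCELoss()
data_iterator = iter(eval_batches)  # yields (w2v_batch, sp_batch, labels, ids)
num_steps = num_eval_examples // params.batch_size
test_metrics = evaluate(model, loss_fn, data_iterator, params, num_steps)
logging.info('Test AUCROC %s AUCPR %s', test_metrics['AUCROC'],
             test_metrics['AUCPR'])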
Code Example #2
def evaluate_attn(model, loss_fn, data_iterator, params, num_steps,
                  data_loader, model_dir):
    """Evaluate the model on `num_steps` batches.
    Args:
        model: (torch.nn.Module) the neural network
        loss_fn: a function that takes batch_output and batch_labels and computes the loss for the batch
        data_iterator: (generator) a generator that generates batches of data and labels
        params: (Params) hyperparameters
        num_steps: (int) number of batches to evaluate on, each of size params.batch_size
        data_loader: data_loader that contains index to word mappings
        model_dir: model directory where attention rankings will be saved
    """
    # set model to evaluation mode
    model.eval()

    # rows of ranked attention weights, one per (head, example)
    master_list = []

    # compute metrics over the dataset
    running_metrics = utils.TestMetrics()
    # Use tqdm for progress bar
    with torch.no_grad():
        t = trange(num_steps)
        for _ in t:
            # fetch the next evaluation batch
            train_batch_w2v, train_batch_sp, labels_batch, ids = next(
                data_iterator)
            if 'w2v' in params.emb:
                output_batch, attn_weights_w2v = model(train_batch_w2v,
                                                       interpret=True)
                batch_word_indexes = train_batch_w2v[0].tolist()
                batch_text = []
                for word_indexes in batch_word_indexes:
                    words = [
                        data_loader.index_to_word_w2v[w] for w in word_indexes
                    ]
                    unigrams, bigrams, trigrams = [], [], []
                    for ind in range(len(words)):
                        # up to two preceding words of context (none at the
                        # start of the note)
                        pre_context = ' '.join(words[max(ind - 2, 0):ind])
                        # bracket each n-gram that fits and keep up to two
                        # following words of context
                        for n, grams in ((1, unigrams), (2, bigrams),
                                         (3, trigrams)):
                            if ind + n - 1 < len(words):
                                grams.append(
                                    pre_context + ' [' +
                                    ' '.join(words[ind:ind + n]) + '] ' +
                                    ' '.join(words[ind + n:ind + n + 2]))

                    batch_text.append(unigrams + bigrams + ['<CONV_PAD>'] +
                                      trigrams + ['<CONV_PAD>'] +
                                      ['<CONV_PAD>'])
                output_list = output_batch.tolist()
                attn_weights_list = [x.tolist() for x in attn_weights_w2v]
                labels_batch_list = labels_batch.tolist()
                assert len(ids) == len(batch_text)
                assert len(ids) == len(labels_batch_list)
                assert len(ids) == len(output_list)
                assert len(ids) == len(attn_weights_list[0])
                for head in range(len(attn_weights_list)):
                    for index in range(len(ids)):
                        temp_list = []
                        temp_list.append(ids[index])
                        temp_list.append(head)
                        temp_list.append('w2v')
                        temp_list.append(labels_batch_list[index][0])
                        temp_list.append(output_list[index][0])
                        attn_words = list(
                            zip(attn_weights_list[head][index],
                                batch_text[index]))
                        attn_words.sort(reverse=True)
                        new_attn_words = [
                            x for t in attn_words[:50] for x in t
                        ]
                        temp_list.extend(new_attn_words)
                        master_list.append(temp_list)
            elif 'sp' in params.emb:
                output_batch, attn_weights_sp = model(train_batch_sp,
                                                      interpret=True)
                batch_word_indexes = train_batch_sp[0].tolist()
                batch_text = []
                for word_indexes in batch_word_indexes:
                    words = [
                        data_loader.index_to_word_sp[w] for w in word_indexes
                    ]
                    unigrams, bigrams, trigrams = [], [], []
                    for ind in range(len(words)):
                        # up to two preceding words of context (none at the
                        # start of the note)
                        pre_context = ' '.join(words[max(ind - 2, 0):ind])
                        # bracket each n-gram that fits and keep up to two
                        # following words of context
                        for n, grams in ((1, unigrams), (2, bigrams),
                                         (3, trigrams)):
                            if ind + n - 1 < len(words):
                                grams.append(
                                    pre_context + ' [' +
                                    ' '.join(words[ind:ind + n]) + '] ' +
                                    ' '.join(words[ind + n:ind + n + 2]))

                    batch_text.append(unigrams + bigrams + trigrams)
                output_list = output_batch.tolist()
                attn_weights_list = [x.tolist() for x in attn_weights_sp]
                labels_batch_list = labels_batch.tolist()
                assert len(ids) == len(batch_text)
                assert len(ids) == len(labels_batch_list)
                assert len(ids) == len(output_list)
                assert len(ids) == len(attn_weights_list[0])
                for head in range(len(attn_weights_list)):
                    for index in range(len(ids)):
                        temp_list = []
                        temp_list.append(ids[index])
                        temp_list.append(head)
                        temp_list.append('sp300')
                        temp_list.append(labels_batch_list[index][0])
                        temp_list.append(output_list[index][0])
                        attn_words = list(
                            zip(attn_weights_list[head][index],
                                batch_text[index]))
                        attn_words.sort(reverse=True)
                        new_attn_words = [
                            x for t in attn_words[:50] for x in t
                        ]

                        temp_list.extend(new_attn_words)
                        master_list.append(temp_list)
            loss_fn(output_batch, labels_batch)  # loss is computed but unused here
            running_metrics.update(labels_batch.data.cpu().numpy(),
                                   output_batch.data.cpu().numpy())

    df_attn_scores = pd.DataFrame(
        master_list,
        columns=[
            "ICUSTAY_ID", 'head', 'embedding', params.task + "_label",
            params.task + "_prediction"
        ] + [
            'attn_' + str(i // 2) if i % 2 == 0 else 'words_' + str(i // 2)
            for i in range(100)
        ])
    print(df_attn_scores.dtypes)
    df_attn_scores.sort_values(by=[params.task + "_prediction"],
                               ascending=False,
                               inplace=True)
    print(df_attn_scores.head(5))
    dataset_path = os.path.join(model_dir, 'df_attn.csv')
    df_attn_scores.to_csv(dataset_path, index=False)
    metrics = running_metrics()
    logging.info('METRICS ' + str(metrics))
    return {'AUCROC': metrics[0], 'AUCPR': metrics[1]}
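For reference, the column comprehension above interleaves an attention-score column with its n-gram column for each of the top 50 (weight, n-gram) pairs kept per row:

# The comprehension over range(100) used above expands to interleaved names.
cols = ['attn_' + str(i // 2) if i % 2 == 0 else 'words_' + str(i // 2)
        for i in range(100)]
assert cols[:4] == ['attn_0', 'words_0', 'attn_1', 'words_1']
assert cols[-2:] == ['attn_49', 'words_49']  # matches attn_words[:50] above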
Code Example #3
def allied_final_evaluate(model,
                          loss_fn,
                          data_iterator,
                          metrics,
                          params,
                          num_steps,
                          allied=False):
    """Evaluate the model on `num_steps` batches, tracking both the main task
    and the auxiliary ICD predictions.
    Args:
        model: (torch.nn.Module) the neural network
        loss_fn: a function that takes batch_output and batch_labels and computes the loss for the batch
        data_iterator: (generator) a generator that generates batches of data and labels
        metrics: (dict) maps metric names to functions of (output_batch, labels_batch)
        params: (Params) hyperparameters
        num_steps: (int) number of batches to evaluate on, each of size params.batch_size
        allied: (bool) unused by this function
    """
    # set model to evaluation mode
    model.eval()

    # summary for current eval loop
    summ = []

    # compute metrics over the dataset
    running_auc = utils.OutputAUC()
    running_metrics = utils.TestMetrics()
    running_icd = utils.MetricsICD()
    # Use tqdm for progress bar
    with torch.no_grad():
        t = trange(num_steps)
        for _ in t:
            # fetch the next evaluation batch
            train_batch_w2v, train_batch_sp, labels_batch, icd_labels, ids = next(
                data_iterator)
            if 'w2v' in params.emb:
                output_batch, icd_batch = model(train_batch_w2v)
            elif 'sp' in params.emb:
                output_batch, icd_batch = model(train_batch_sp)
            else:
                output_batch, icd_batch = model(train_batch_w2v,
                                                train_batch_sp)
            loss = loss_fn(output_batch, labels_batch)
            # print(loss)
            running_icd.update(icd_labels.data.cpu().numpy(),
                               icd_batch.data.cpu().numpy())
            running_auc.update(labels_batch.data.cpu().numpy(),
                               output_batch.data.cpu().numpy())
            running_metrics.update(labels_batch.data.cpu().numpy(),
                                   output_batch.data.cpu().numpy())
            # extract data from torch Variable, move to cpu, convert to numpy arrays
            output_batch = output_batch.data.cpu().numpy()
            labels_batch = labels_batch.data.cpu().numpy()

            # compute all metrics on this batch
            summary_batch = {
                metric: metrics[metric](output_batch, labels_batch)
                for metric in metrics
            }
            summary_batch['loss'] = loss.data.item()
            summ.append(summary_batch)

    metrics = running_metrics()
    metrics_mean = {
        metric: np.mean([x[metric] for x in summ])
        for metric in summ[0]
    }
    metrics_string = " ; ".join("{}: {:05.3f}".format(k, v)
                                for k, v in metrics_mean.items())
    logging.info("- Eval metrics : " + metrics_string)
    logging.info('AUCROC ' + str(running_auc()))
    logging.info('AUCROC ' + str(metrics[0]))
    logging.info('AUCPR ' + str(metrics[1]))
    logging.info('MICRO AUCROC_ICD ' + str(running_icd()))
    macro_auc = running_icd.macro_auc()
    logging.info('MACRO AUCROC_ICD ' + str(macro_auc))

    return {
        'AUCROC': metrics[0],
        "AUCPR": metrics[1],
        "MICRO_AUCROC_ICD": running_icd(),
        "MACRO_AUCROC_ICD": macro_auc
    }
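Examples #3 through #5 take an extra `metrics` argument: a dict mapping a metric name to a function of `(output_batch, labels_batch)` NumPy arrays, applied per batch and averaged at the end. A minimal sketch of such a dict; the accuracy definition here is an assumption, not the repository's own.

# Hypothetical metrics dict; each entry is called per batch as
# metrics[name](output_batch, labels_batch) on NumPy arrays.
def accuracy(outputs, labels):
    # assumes sigmoid outputs and {0, 1} labels of matching shape
    return np.mean((outputs > 0.5) == (labels > 0.5))

metrics = {'accuracy': accuracy}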
Code Example #4
def evaluate(model, loss_fn, data_iterator, metrics, params, num_steps):
    """Evaluate the model on `num_steps` batches.
    Args:
        model: (torch.nn.Module) the neural network
        loss_fn: a function that takes batch_output and batch_labels and computes the loss for the batch
        data_iterator: (generator) a generator that generates batches of data and labels
        metrics: (dict) maps metric names to functions of (output_batch, labels_batch)
        params: (Params) hyperparameters
        num_steps: (int) number of batches to evaluate on, each of size params.batch_size
    """
    # set model to evaluation mode
    model.eval()

    # summary for current eval loop
    summ = []

    # compute metrics over the dataset
    running_auc = utils.OutputAUC()
    running_metrics = utils.TestMetrics()
    running_icd = utils.MetricsICD()
    # Use tqdm for progress bar
    with torch.no_grad():
        t = trange(num_steps)
        for _ in t:
            # fetch the next evaluation batch
            if 'phen' in params.model:
                if params.task == 'icd':
                    train_batch_w2v, train_batch_sp, _, labels_batch = next(
                        data_iterator)
                else:
                    train_batch_w2v, train_batch_sp, labels_batch, _ = next(
                        data_iterator)
                output_batch = model(train_batch_w2v)
                loss = loss_fn(output_batch, labels_batch)
                loss = loss / params.grad_acc  # Normalize our loss (if averaged)
                # print(loss)
            elif params.model == 'lr':
                train_batch, labels_batch = next(data_iterator)
                output_batch = model(train_batch)
                loss = loss_fn(output_batch, labels_batch)

                loss = loss / params.grad_acc  # Normalize our loss (if averaged)
                # print(loss)
            else:
                train_batch_w2v, train_batch_sp, labels_batch, _, ids = next(
                    data_iterator)
                output_batch = model(train_batch_w2v)
                loss = loss_fn(output_batch, labels_batch)

            running_auc.update(labels_batch.data.cpu().numpy(),
                               output_batch.data.cpu().numpy())
            running_metrics.update(labels_batch.data.cpu().numpy(),
                                   output_batch.data.cpu().numpy())
            if params.task == 'icd_only':
                running_icd.update(labels_batch.data.cpu().numpy(),
                                   output_batch.data.cpu().numpy())
            # extract data from torch Variable, move to cpu, convert to numpy arrays
            output_batch = output_batch.data.cpu().numpy()
            labels_batch = labels_batch.data.cpu().numpy()

            # compute all metrics on this batch
            summary_batch = {
                metric: metrics[metric](output_batch, labels_batch)
                for metric in metrics
            }
            summary_batch['loss'] = loss.data.item()
            summ.append(summary_batch)

    metrics_mean = {
        metric: np.mean([x[metric] for x in summ])
        for metric in summ[0]
    }
    metrics_string = " ; ".join("{}: {:05.3f}".format(k, v)
                                for k, v in metrics_mean.items())
    logging.info("- Eval metrics : " + metrics_string)
    logging.info('AUCROC ' + str(running_auc()))

    if params.task == 'icd_only':
        return {
            'AUCROC': running_icd(),
            'MACRO_AUCROC_ICD': running_icd.macro_auc()
        }
    else:
        metrics = running_metrics()
        logging.info('METRICS ' + str(metrics))
        return {'AUCROC': metrics[0], 'AUCPR': metrics[1]}
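Note that the shape of the returned dict depends on `params.task`, so a caller has to branch accordingly; a sketch, reusing the stand-in names from the earlier examples:

# Hypothetical caller; the result keys differ between the two task settings.
result = evaluate(model, loss_fn, data_iterator, metrics, params, num_steps)
if params.task == 'icd_only':
    micro_auc, macro_auc = result['AUCROC'], result['MACRO_AUCROC_ICD']
else:
    aucroc, aucpr = result['AUCROC'], result['AUCPR']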
Code Example #5
def evaluate_attn(model, loss_fn, data_iterator, metrics, params, num_steps,
                  data_loader, model_dir):
    """Evaluate the model on `num_steps` batches and save ranked attention
    weights to `model_dir` for interpretation.
    Args:
        model: (torch.nn.Module) the neural network
        loss_fn: a function that takes batch_output and batch_labels and computes the loss for the batch
        data_iterator: (generator) a generator that generates batches of data and labels
        metrics: (dict) maps metric names to functions of (output_batch, labels_batch)
        params: (Params) hyperparameters
        num_steps: (int) number of batches to evaluate on, each of size params.batch_size
        data_loader: data_loader that contains index to word mappings
        model_dir: model directory where attention rankings will be saved
    """
    # set model to evaluation mode
    model.eval()

    # summary for current eval loop
    summ = []

    # rows of ranked attention weights, one per (head, example)
    master_list = []

    # compute metrics over the dataset
    running_metrics = utils.TestMetrics()
    # Use tqdm for progress bar
    with torch.no_grad():
        t = trange(num_steps)
        for _ in t:
            # fetch the next evaluation batch
            train_batch_w2v, labels_batch, ids = next(data_iterator)
            output_batch, attn_weights_w2v = model(train_batch_w2v,
                                                   interpret=True)
            batch_word_indexes = train_batch_w2v[0].tolist()
            batch_text = []
            for word_indexes in batch_word_indexes:
                words = [
                    data_loader.index_to_word_w2v[w] for w in word_indexes
                ]
                unigrams, bigrams, trigrams = [], [], []
                for ind in range(len(words)):
                    # up to two preceding words of context (none at the start
                    # of the note)
                    pre_context = ' '.join(words[max(ind - 2, 0):ind])
                    # bracket each n-gram that fits and keep up to two
                    # following words of context
                    for n, grams in ((1, unigrams), (2, bigrams),
                                     (3, trigrams)):
                        if ind + n - 1 < len(words):
                            grams.append(
                                pre_context + ' [' +
                                ' '.join(words[ind:ind + n]) + '] ' +
                                ' '.join(words[ind + n:ind + n + 2]))

                batch_text.append(unigrams + bigrams + ['<CONV_PAD>'] +
                                  trigrams + ['<CONV_PAD>'] + ['<CONV_PAD>'])
            output_list = output_batch.tolist()
            attn_weights_list = [x.tolist() for x in attn_weights_w2v]
            labels_batch_list = labels_batch.tolist()
            assert len(ids) == len(batch_text)
            assert len(ids) == len(labels_batch_list)
            assert len(ids) == len(output_list)
            assert len(ids) == len(attn_weights_list[0])
            for head in range(len(attn_weights_list)):
                for index in range(len(ids)):
                    temp_list = []
                    temp_list.append(ids[index])
                    temp_list.append(head)
                    temp_list.append('w2v')
                    temp_list.append(labels_batch_list[index][0])
                    temp_list.append(output_list[index][0])
                    attn_words = list(
                        zip(attn_weights_list[head][index], batch_text[index]))
                    attn_words.sort(reverse=True)
                    new_attn_words = [x for t in attn_words[:50] for x in t]
                    temp_list.extend(new_attn_words)
                    master_list.append(temp_list)
            loss = loss_fn(output_batch, labels_batch)

            running_metrics.update(labels_batch.data.cpu().numpy(),
                                   output_batch.data.cpu().numpy())
            # extract data from torch Variable, move to cpu, convert to numpy arrays
            output_batch = output_batch.data.cpu().numpy()
            labels_batch = labels_batch.data.cpu().numpy()

            # compute all metrics on this batch
            summary_batch = {
                metric: metrics[metric](output_batch, labels_batch)
                for metric in metrics
            }
            summary_batch['loss'] = loss.data.item()
            summ.append(summary_batch)

    df_attn_scores = pd.DataFrame(
        master_list,
        columns=[
            "ICUSTAY_ID", 'head', 'embedding', params.task + "_label",
            params.task + "_prediction"
        ] + [
            'attn_' + str(i // 2) if i % 2 == 0 else 'words_' + str(i // 2)
            for i in range(100)
        ])
    print(df_attn_scores.dtypes)
    df_attn_scores.sort_values(by=[params.task + "_prediction"],
                               ascending=False,
                               inplace=True)
    print(df_attn_scores.head(5))
    dataset_path = os.path.join(model_dir, 'df_attn.csv')
    df_attn_scores.to_csv(dataset_path, index=False)

    metrics_mean = {
        metric: np.mean([x[metric] for x in summ])
        for metric in summ[0]
    }
    metrics_string = " ; ".join("{}: {:05.3f}".format(k, v)
                                for k, v in metrics_mean.items())
    logging.info("- Eval metrics : " + metrics_string)
    metrics = running_metrics()
    logging.info('METRICS ' + str(metrics))
    return {'AUCROC': metrics[0], 'AUCPR': metrics[1]}
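To make the bracketing in `evaluate_attn` concrete, the n-gram loop produces context windows like the following on a toy note (the input words are illustrative, not repository data):

# Toy illustration of the n-gram strings built above: each entry keeps up to
# two words of context on either side of the bracketed n-gram.
words = ['pt', 'denies', 'chest', 'pain']
unigrams, bigrams, trigrams = [], [], []
for ind in range(len(words)):
    pre_context = ' '.join(words[max(ind - 2, 0):ind])
    for n, grams in ((1, unigrams), (2, bigrams), (3, trigrams)):
        if ind + n - 1 < len(words):
            grams.append(pre_context + ' [' + ' '.join(words[ind:ind + n]) +
                         '] ' + ' '.join(words[ind + n:ind + n + 2]))
print(unigrams[1])  # pt [denies] chest pain
print(bigrams[0])   #  [pt denies] chest pain
print(trigrams[1])  # pt [denies chest pain]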