# Example No. 1
# 0
def test(model, Y, epoch, data_path, fold, gpu, version, code_inds, dicts,
         samples, model_dir, testing):
    """Run one evaluation pass over the ``fold`` split and return its metrics.

    Args:
        model: Network to evaluate; switched to eval mode here.
            ``model.lmbda > 0`` enables description embeddings, and
            ``model.conv`` is read when ``samples`` is set.
        Y: Unused here; kept for call compatibility.
        epoch: Epoch number, used in the attention-sample file names.
        data_path: Path of the training split; every ``'train'`` substring
            is replaced by ``fold`` to locate the evaluation file.
        fold: Name of the split to evaluate (``'test'`` is special-cased for
            attention sampling).
        gpu: Move each batch to CUDA when true.
        version: Passed through to ``datasets.data_generator``.
        code_inds: Passed to ``unseen_code_vecs`` when description embeddings
            are enabled and it is non-empty.
        dicts: Lookup dicts; ``len(dicts['ind2c'])`` is the number of labels.
        samples: When true, write attention samples to
            ``<model_dir>/tp_<fold>_examples_<epoch>.txt`` and the matching
            ``fp_`` file.
        model_dir: Directory for the sample files; also passed to
            ``persistence.write_preds``.
        testing: Together with ``fold == 'test'``, sample attention for every
            batch instead of roughly 2% of them.

    Returns:
        dict: Metrics from ``evaluation.all_metrics`` plus ``'loss_<fold>'``,
        the mean per-batch loss.
    """
    filename = data_path.replace('train', fold)
    print('file for evaluation: %s' % filename)
    ind2c = dicts['ind2c']
    num_labels = len(ind2c)

    desc_embed = model.lmbda > 0
    if desc_embed and len(code_inds) > 0:
        unseen_code_vecs(model, code_inds, dicts, gpu)

    y, yhat, yhat_raw, hids, losses = [], [], [], [], []
    tp_file = fp_file = None
    window_size = None
    try:
        # Attention-sample files; closed in ``finally`` even if evaluation
        # raises part-way through.
        if samples:
            tp_file = open(
                '%s/tp_%s_examples_%d.txt' % (model_dir, fold, epoch), 'w')
            fp_file = open(
                '%s/fp_%s_examples_%d.txt' % (model_dir, fold, epoch), 'w')
            window_size = model.conv.weight.data.size()[2]

        model.eval()
        gen = datasets.data_generator(filename,
                                      dicts,
                                      1,
                                      num_labels,
                                      version=version,
                                      desc_embed=desc_embed)
        # Inference only: no_grad replaces the removed ``volatile=True`` flag
        # so no autograd graph is built for evaluation batches.
        with torch.no_grad():
            for data, target, hadm_ids, _, descs in tqdm(gen):
                data = torch.LongTensor(data)
                target = torch.FloatTensor(target)
                if gpu:
                    data = data.cuda()
                    target = target.cuda()
                model.zero_grad()
                desc_data = descs if desc_embed else None

                # Save an attention sample for ~2% of batches, or for every
                # batch of the test split when ``testing`` is set.
                get_attn = samples and (np.random.rand() < 0.02 or
                                        (fold == 'test' and testing))
                output, loss, alpha = model(data,
                                            target,
                                            desc_data=desc_data,
                                            get_attention=get_attn)

                output = torch.sigmoid(output).cpu().numpy()
                target_data = target.cpu().numpy()
                losses.append(loss.item())
                if get_attn:
                    interpret.save_samples(data,
                                           output,
                                           target_data,
                                           alpha,
                                           window_size,
                                           epoch,
                                           tp_file,
                                           fp_file,
                                           dicts=dicts)

                # Keep raw probabilities, targets, rounded predictions and
                # admission ids for the whole split.
                yhat_raw.append(output)
                y.append(target_data)
                yhat.append(np.round(output))
                hids.extend(hadm_ids)
    finally:
        for handle in (tp_file, fp_file):
            if handle is not None:
                handle.close()

    y = np.concatenate(y, axis=0)
    yhat = np.concatenate(yhat, axis=0)
    yhat_raw = np.concatenate(yhat_raw, axis=0)

    persistence.write_preds(yhat, model_dir, hids, fold, ind2c, yhat_raw)

    # k is forwarded to evaluation.all_metrics (presumably @k cutoffs):
    # 5 for a 50-label set, otherwise 8 and 15.
    k = 5 if num_labels == 50 else [8, 15]
    metrics = evaluation.all_metrics(yhat, y, k=k, yhat_raw=yhat_raw)
    evaluation.print_metrics(metrics)
    metrics['loss_%s' % fold] = np.mean(losses)
    return metrics
# Example No. 2
# 0
# Default vocab file and lower-casing flag for each model name that needs a
# BertTokenizer; the path is used when ``args.redefined_tokenizer`` is false.
_BERT_TOKENIZER_DEFAULTS = {
    'bert': ('./pretrained_weights/bert-base-uncased-vocab.txt', True),
    'biobert': (
        './pretrained_weights/biobert_pretrain_output_all_notes_150000/vocab.txt',
        False),
    'bert-tiny': ('./pretrained_weights/bert-tiny-uncased-vocab.txt', True),
}


def _load_bert_tokenizer(args):
    """Return the BertTokenizer for ``args.model``, or None if it needs none."""
    if args.model not in _BERT_TOKENIZER_DEFAULTS:
        return None
    default_vocab, lower_case = _BERT_TOKENIZER_DEFAULTS[args.model]
    vocab = args.tokenizer_path if args.redefined_tokenizer else default_vocab
    return BertTokenizer.from_pretrained(vocab, do_lower_case=lower_case)


def _bert_model_inputs(model_name, data):
    """Build (token_type_ids, attention_mask, position_ids) for a batch.

    Tokens with id 0 are treated as padding. Only the models listed in
    ``_BERT_TOKENIZER_DEFAULTS`` get explicit token-type and position ids;
    other models get None for both.
    """
    attention_mask = (data > 0).long()
    if model_name not in _BERT_TOKENIZER_DEFAULTS:
        return None, attention_mask, None
    token_type_ids = torch.zeros_like(data)
    # Positions 0..L-1 per row, zeroed at padding tokens.
    position_ids = torch.arange(data.size(1), device=data.device).expand(
        data.size(0), data.size(1)) * attention_mask
    return token_type_ids, attention_mask, position_ids


def test(args, model, Y, epoch, data_path, fold, gpu, version, code_inds,
         dicts, samples, model_dir, testing):
    """Run one evaluation pass over the ``fold`` split and return its metrics.

    Args:
        args: Run options; reads ``model``, ``redefined_tokenizer``,
            ``tokenizer_path`` and ``max_sequence_length``.
        model: Network to evaluate; switched to eval mode here.
            ``model.lmbda > 0`` enables description embeddings, and
            ``model.conv`` is read when ``samples`` is set.
        Y: Unused here; kept for call compatibility.
        epoch: Epoch number, used in the attention-sample file names.
        data_path: Path of the training split; every ``'train'`` substring
            is replaced by ``fold`` to locate the evaluation file.
        fold: Name of the split to evaluate (``'test'`` is special-cased for
            attention sampling).
        gpu: Move each batch to CUDA when true.
        version: Passed through to ``datasets.data_generator``.
        code_inds: Passed to ``unseen_code_vecs`` when description embeddings
            are enabled and it is non-empty.
        dicts: Lookup dicts; ``len(dicts['ind2c'])`` is the number of labels.
        samples: When true, attention samples from non-BERT models are
            written to ``<model_dir>/tp_<fold>_examples_<epoch>.txt`` and the
            matching ``fp_`` file.
        model_dir: Directory for the sample files; also passed to
            ``persistence.write_preds``.
        testing: Together with ``fold == 'test'``, sample attention for every
            batch instead of roughly 2% of them.

    Returns:
        dict: Metrics from ``evaluation.all_metrics`` plus ``'loss_<fold>'``,
        the mean per-batch loss.
    """
    filename = data_path.replace('train', fold)
    print('file for evaluation: %s' % filename)
    ind2c = dicts['ind2c']
    num_labels = len(ind2c)

    desc_embed = model.lmbda > 0
    if desc_embed and len(code_inds) > 0:
        unseen_code_vecs(model, code_inds, dicts, gpu)

    bert_tokenizer = _load_bert_tokenizer(args)

    y, yhat, yhat_raw, hids, losses = [], [], [], [], []
    tp_file = fp_file = None
    window_size = None
    try:
        # Attention-sample files; closed in ``finally`` even if evaluation
        # raises part-way through.
        if samples:
            tp_file = open(
                '%s/tp_%s_examples_%d.txt' % (model_dir, fold, epoch), 'w')
            fp_file = open(
                '%s/fp_%s_examples_%d.txt' % (model_dir, fold, epoch), 'w')
            window_size = model.conv.weight.data.size()[2]

        model.eval()
        gen = datasets.data_generator(filename,
                                      dicts,
                                      1,
                                      num_labels,
                                      version=version,
                                      desc_embed=desc_embed,
                                      bert_tokenizer=bert_tokenizer,
                                      test=True,
                                      max_seq_length=args.max_sequence_length)
        with torch.no_grad():
            for data, target, hadm_ids, _, descs in tqdm(gen):
                data = torch.LongTensor(data)
                target = torch.FloatTensor(target)
                if gpu:
                    data = data.cuda()
                    target = target.cuda()
                desc_data = descs if desc_embed else None
                # Computed before the forward pass so the attention-sample
                # writer below sees this batch's targets.
                target_data = target.cpu().numpy()

                if args.model in BERT_MODEL_LIST:
                    token_type_ids, attention_mask, position_ids = \
                        _bert_model_inputs(args.model, data)
                    output, loss = model(input_ids=data,
                                         token_type_ids=token_type_ids,
                                         attention_mask=attention_mask,
                                         position_ids=position_ids,
                                         labels=target,
                                         desc_data=desc_data,
                                         pos_labels=None)
                    output = torch.sigmoid(output).cpu().numpy()
                else:
                    # Decide on attention sampling BEFORE the forward pass,
                    # which needs it: ~2% of batches, or every batch of the
                    # test split when ``testing`` is set.
                    get_attn = samples and (np.random.rand() < 0.02 or
                                            (fold == 'test' and testing))
                    output, loss, alpha = model(data,
                                                target,
                                                desc_data=desc_data,
                                                get_attention=get_attn)
                    output = torch.sigmoid(output).cpu().numpy()
                    if get_attn:
                        interpret.save_samples(data,
                                               output,
                                               target_data,
                                               alpha,
                                               window_size,
                                               epoch,
                                               tp_file,
                                               fp_file,
                                               dicts=dicts)

                losses.append(loss.item())
                # Keep raw probabilities, targets, rounded predictions and
                # admission ids for the whole split.
                yhat_raw.append(output)
                y.append(target_data)
                yhat.append(np.round(output))
                hids.extend(hadm_ids)
    finally:
        for handle in (tp_file, fp_file):
            if handle is not None:
                handle.close()

    y = np.concatenate(y, axis=0)
    yhat = np.concatenate(yhat, axis=0)
    yhat_raw = np.concatenate(yhat_raw, axis=0)

    persistence.write_preds(yhat, model_dir, hids, fold, ind2c, yhat_raw)

    # k is forwarded to evaluation.all_metrics (presumably @k cutoffs):
    # 5 for a 50-label set, otherwise 8 and 15.
    k = 5 if num_labels == 50 else [8, 15]
    metrics = evaluation.all_metrics(yhat, y, k=k, yhat_raw=yhat_raw)
    evaluation.print_metrics(metrics)
    metrics['loss_%s' % fold] = np.mean(losses)
    return metrics