# Example n. 1
# 0
def get_all_html(model_dirname, dataset, embeddings, indexer, batch_size,
                 threshold=0.5):
    """Render highlighted markup for every sample with ``dataset.z == 0``.

    Loads the trained network checkpoint, runs it batch-by-batch through an
    ``Inspector`` to recover per-ngram attribution weights, and renders each
    sample as highlighted markup, grouped by the confusion category assigned
    by ``get_confusion_category``.

    Args:
        model_dirname: directory containing ``checkpoint.net.pt``.
        dataset: object exposing ``X`` (token-index matrix), ``y`` (labels)
            and ``z`` (selection flags).
            NOTE(review): assumed to be numpy arrays — they are indexed with
            a Python list of ints; confirm against the caller.
        embeddings: embedding matrix handed to the ``Inspector``.
        indexer: vocabulary indexer used by ``Inspector.forward_inspect``.
        batch_size: number of samples per forward pass.
        threshold: decision threshold passed to ``get_confusion_category``
            (default 0.5, preserving the original hard-coded behaviour).

    Returns:
        ``defaultdict(list)`` mapping confusion category to a list of
        ``(sample_xml, true_label, predicted_probability)`` tuples.
    """
    # Only samples flagged with z == 0 are rendered.
    indices = [row[0] for row in np.argwhere(dataset.z == 0)]
    logger.info("consider a sample of {} for htmls".format(len(indices)))

    # int(np.ceil(...)) is sufficient; the extra .astype(int) was redundant.
    num_batches = int(np.ceil(len(indices) / float(batch_size)))
    # map_location keeps CPU loading possible even for GPU-trained checkpoints.
    net = torch.load(os.path.join(model_dirname, 'checkpoint.net.pt'),
                     map_location=lambda storage, loc: storage)
    net_inspector = Inspector(net, embeddings)
    categorywise_all_html = defaultdict(list)

    for batch_id in range(num_batches):
        u.log_frequently(5, batch_id, logger.debug,
                         'processing batch {}'.format(batch_id))
        batch_indices = indices[batch_size * batch_id:
                                batch_size * (batch_id + 1)]

        # NOTE(review): hard dependency on CUDA here; a device check would be
        # needed to run on CPU.
        X0 = Variable(torch.cuda.LongTensor(dataset.X[batch_indices]))

        # forward_inspect returns the pre-sigmoid logits plus per-ngram
        # attribution data used to build the heatmaps.
        X5, weights, bias, ngrams_interest = net_inspector.forward_inspect(
            X0, indexer)
        # torch.sigmoid replaces the deprecated F.sigmoid; view(-1) replaces
        # the deprecated Tensor.resize for flattening to shape (batch,).
        yp = torch.sigmoid(X5).view(-1)
        y_pred = yp.data.cpu().numpy()
        y_true = dataset.y[batch_indices]
        confusion_categories = get_confusion_category(y_pred, y_true,
                                                      threshold)

        for idx in range(y_true.shape[0]):
            X0_numpy = X0[idx].data.cpu().numpy()
            X5_numpy = X5[idx].data.cpu().numpy()

            logit = X5_numpy[0]
            proba = y_pred[idx]
            # Map probability to colour intensities: red channel activates
            # for p > 0.5, blue channel for p < 0.5 (hedge clamps to range).
            proba_red = hedge(2 * proba - 1, 0, 1)
            proba_blue = -hedge(2 * proba - 1, -1, 0)

            heatmap_pos, heatmap_neg = get_heatmap(idx, weights,
                                                   ngrams_interest)
            heatmap_pos = normalize_heatmap(heatmap_pos, logit, 0, 1)
            heatmap_neg = normalize_heatmap(heatmap_neg, logit, -1, 0)

            confusion_category = confusion_categories[idx]
            true_probability = HighlightedLatex.get_highlighted_word(
                '{0:.2f}'.format(y_true[idx]), r=y_true[idx], b=0)
            predicted_probability = HighlightedLatex.get_highlighted_word(
                '{0:.2f}'.format(proba), r=proba_red, b=proba_blue)
            highlighted_text = HighlightedLatex.get_highlighted_words(
                indices2words(X0_numpy), heatmap_pos, heatmap_neg)
            sample_xml = HighlightedLatex.SAMPLE_FORMAT.format(
                confusion_category=confusion_category,
                true_probability=true_probability,
                predicted_probability=predicted_probability,
                highlighted_text=highlighted_text)
            categorywise_all_html[confusion_category].append(
                (sample_xml, y_true[idx], proba))

    return categorywise_all_html