def get_all_html(model_dirname, dataset, embeddings, indexer, batch_size):
    """Render highlighted-text fragments for every sample with ``z == 0``,
    grouped by confusion category.

    Loads ``checkpoint.net.pt`` from ``model_dirname``, runs batched forward
    passes through an ``Inspector`` wrapper to obtain per-ngram attribution
    weights, and renders each selected sample as a ``HighlightedLatex``
    fragment keyed by its confusion category (as decided by
    ``get_confusion_category`` at threshold 0.5).

    Args:
        model_dirname: Directory containing ``checkpoint.net.pt``.
        dataset: Object exposing ``X`` (token-index matrix), ``y`` (labels)
            and ``z`` (selection flags; only ``z == 0`` rows are rendered).
            # NOTE(review): schema inferred from usage here — confirm.
        embeddings: Embedding table handed to ``Inspector``.
        indexer: Vocabulary indexer handed to ``forward_inspect``.
        batch_size: Number of samples per forward pass.

    Returns:
        ``defaultdict(list)`` mapping confusion category -> list of
        ``(sample_xml, true_label, predicted_probability)`` tuples.
    """
    # Only rows flagged z == 0 are rendered.
    indices = [x[0] for x in np.argwhere(dataset.z == 0)]
    logger.info("consider a sample of {} for htmls".format(len(indices)))

    # Ceiling division. The original int(np.ceil(...).astype(int)) round-trip
    # through a numpy scalar computed the same value with two redundant casts.
    num_batches = (len(indices) + batch_size - 1) // batch_size

    # Checkpoint storages are remapped to CPU by map_location...
    # NOTE(review): ...but inputs below are built as torch.cuda.LongTensor,
    # so a CUDA device is still required at inference time — confirm intent.
    net = torch.load(os.path.join(model_dirname, 'checkpoint.net.pt'),
                     map_location=lambda storage, loc: storage)
    net_inspector = Inspector(net, embeddings)
    categorywise_all_html = defaultdict(list)

    for batch_id in range(num_batches):
        u.log_frequently(5, batch_id, logger.debug,
                         'processing batch {}'.format(batch_id))
        _batch_start = batch_size * batch_id
        _batch_end = batch_size * (batch_id + 1)
        batch_indices = indices[_batch_start:_batch_end]

        X0 = Variable(torch.cuda.LongTensor(dataset.X[batch_indices]))
        X5, weights, bias, ngrams_interest = net_inspector.forward_inspect(
            X0, indexer)

        # Collapse the per-sample logit column to a flat probability vector.
        yp = F.sigmoid(X5)
        yp = yp.resize(yp.size()[0])
        y_pred = yp.data.cpu().numpy()
        y_true = dataset.y[batch_indices]
        confusion_categories = get_confusion_category(y_pred, y_true, 0.5)

        # y_true already holds dataset.y[batch_indices]; no need to re-slice
        # it just to read its length (as the original loop header did).
        for idx in range(y_true.shape[0]):
            X0_numpy = X0[idx].data.cpu().numpy()
            X5_numpy = X5[idx].data.cpu().numpy()
            logit = X5_numpy[0]
            proba = y_pred[idx]
            # Map probability to red (evidence toward 1) / blue (evidence
            # toward 0) intensities; hedge clips the signed score to a band.
            proba_red = hedge(2 * proba - 1, 0, 1)
            proba_blue = -hedge(2 * proba - 1, -1, 0)

            heatmap_pos, heatmap_neg = get_heatmap(idx, weights,
                                                   ngrams_interest)
            heatmap_pos = normalize_heatmap(heatmap_pos, logit, 0, 1)
            heatmap_neg = normalize_heatmap(heatmap_neg, logit, -1, 0)

            confusion_category = confusion_categories[idx]
            true_probability = HighlightedLatex.get_highlighted_word(
                '{0:.2f}'.format(y_true[idx]), r=y_true[idx], b=0)
            predicted_probability = HighlightedLatex.get_highlighted_word(
                '{0:.2f}'.format(proba), r=proba_red, b=proba_blue)
            highlighted_text = HighlightedLatex.get_highlighted_words(
                indices2words(X0_numpy), heatmap_pos, heatmap_neg)
            sample_xml = HighlightedLatex.SAMPLE_FORMAT.format(
                confusion_category=confusion_category,
                true_probability=true_probability,
                predicted_probability=predicted_probability,
                highlighted_text=highlighted_text)
            categorywise_all_html[confusion_category].append(
                (sample_xml, y_true[idx], proba))

    return categorywise_all_html