def _prepare_random_matched_spans(model, batch_instances, cuda):
    unk_idx = int(model.vectorizer.str_to_idx[SimpleInferenceVectorizer.PAD])
    Is, Cs, Os = [PaddedSequence.autopad([torch.LongTensor(inst[x]) for inst in batch_instances], batch_first=True, padding_value=unk_idx) for x in ['I', 'C', 'O']]
    target_spans = [inst['evidence_spans'] for inst in batch_instances]
    target = []
    articles = []
    for article, evidence_spans in zip((x['article'] for x in batch_instances), target_spans):
        tgt = torch.zeros(len(article))
        for start, end in evidence_spans:
            tgt[start:end] = 1
        (start, end) = random.choice(evidence_spans)
        # select a random span of the same length; bound the start so the
        # matched span stays fully inside the article (randint is inclusive)
        random_matched_span_start = random.randint(0, max(0, len(article) - (end - start)))
        random_matched_span_end = random_matched_span_start + end - start
        tgt_pos = tgt[start:end]
        tgt_neg = tgt[random_matched_span_start:random_matched_span_end]
        article_pos = torch.LongTensor(article[start:end])
        article_neg = torch.LongTensor(article[random_matched_span_start:random_matched_span_end])
        # randomize the order of the positive and negative halves
        if random.random() > 0.5:
            articles.append(torch.cat([article_pos, article_neg]))
            target.append(torch.cat([tgt_pos, tgt_neg]))
        else:
            articles.append(torch.cat([article_neg, article_pos]))
            target.append(torch.cat([tgt_neg, tgt_pos]))
    target = PaddedSequence.autopad(target, batch_first=True, padding_value=0)
    articles = PaddedSequence.autopad(articles, batch_first=True, padding_value=unk_idx)
    if cuda:
        articles, Is, Cs, Os, target = articles.cuda(), Is.cuda(), Cs.cuda(), Os.cuda(), target.cuda()
    return articles, Is, Cs, Os, target
def forward(self, word_inputs: PaddedSequence, init_hidden: torch.Tensor=None, query_v_for_attention: torch.Tensor=None, normalize_attention_distribution=True):
    # returns (output, weighted_hidden, a) when attention is in use, and (output, hidden) otherwise
    if isinstance(word_inputs, PaddedSequence):
        embedded = self.embedding(word_inputs.data)
        as_padded = word_inputs.pack_other(embedded)
        output, hidden = self.gru(as_padded, init_hidden)
        output = PaddedSequence.from_packed_sequence(output, batch_first=True)
    else:
        raise ValueError("Unknown input type {} for word_inputs: {}, try a PaddedSequence".format(type(word_inputs), word_inputs))
    # concatenate the final hidden representations of the two directions
    if self.bidirectional:
        if self.n_layers > 1:
            raise ValueError("Implement me!")
        hidden = torch.cat([hidden[0], hidden[1]], dim=1)
    if self.use_attention:
        # note that these hidden input states are already masked to zeros (where appropriate) by the time this is called.
        hidden_input_states = output
        a = self.attention_mechanism(hidden_input_states, query_v_for_attention, normalize=normalize_attention_distribution)
        # element-wise multiplication: each hidden state is weighted by its attention score
        weighted_hidden = torch.sum(a * output.data, dim=1)
        return output, weighted_hidden, a
    return output, hidden
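# A minimal, self-contained sketch (not part of the original module) showing why the
# bidirectional branch above concatenates hidden[0] and hidden[1]: for a 1-layer
# bidirectional GRU, `hidden` has shape (2, batch, hidden_size) -- one slice per
# direction -- so concatenating along dim=1 yields a (batch, 2 * hidden_size) summary.
import torch
import torch.nn as nn

_gru = nn.GRU(input_size=8, hidden_size=16, num_layers=1, bidirectional=True, batch_first=True)
_x = torch.randn(4, 10, 8)           # (batch, tokens, embedding)
_, _hidden = _gru(_x)                # _hidden: (num_directions, batch, hidden_size)
_summary = torch.cat([_hidden[0], _hidden[1]], dim=1)
assert _summary.shape == (4, 32)     # (batch, 2 * hidden_size)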
def forward(self, hidden_input_states: PaddedSequence, query_v_for_attention, normalize=True):
    if not isinstance(hidden_input_states, PaddedSequence):
        raise TypeError("Expected an input of type PaddedSequence but got {}".format(type(hidden_input_states)))
    if self.condition_attention:
        # concatenate query_v_for_attention (one per unit in the batch) to each of the encoder hidden states:
        # expand the query vector used for attention to |batch| x 1 x |query_vector_size| ...
        query_v_for_attention = query_v_for_attention.unsqueeze(dim=1)
        # ... duplicate it once per (max) token in the batch ...
        query_v_for_attention = torch.cat(hidden_input_states.data.size()[1] * [query_v_for_attention], dim=1)
        # ... and finally concatenate this vector to every hidden state in the input tensor
        attention_inputs = torch.cat([hidden_input_states.data, query_v_for_attention], dim=2)
    else:
        attention_inputs = hidden_input_states.data
    raw_word_scores = self.token_attention_F(attention_inputs)
    raw_word_scores = raw_word_scores * hidden_input_states.mask(on=1.0, off=0.0, size=raw_word_scores.size(), device=raw_word_scores.device)
    # TODO this should probably become a logsumexp, depending on the conditioning
    a = self.attn_sm(raw_word_scores)
    # since we need to handle masking, we have to kill any support the softmax assigned to padding
    masked_attention = a * hidden_input_states.mask(on=1.0, off=0.0, size=a.size(), device=a.device)
    if normalize:
        # renormalize so the masked distribution sums to one; this reduces the variance of the input
        # to the next layer and is only necessary for tokenwise attention because its sum isn't constrained
        weights = torch.sum(masked_attention, dim=1).unsqueeze(1)
        a = masked_attention / weights
    else:
        a = masked_attention
    return a
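# A small, self-contained illustration (not from the original code) of the masked-softmax
# trick used above: a plain softmax over raw scores assigns probability mass to padding
# positions, so that mass is zeroed out with the mask and each row is renormalized to sum to one.
import torch

_scores = torch.tensor([[2.0, 1.0, 0.5, 0.0]])   # raw scores; the last 2 positions are padding
_mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])
_a = torch.softmax(_scores, dim=1) * _mask       # mass on padding is killed...
_a = _a / _a.sum(dim=1, keepdim=True)            # ...and the remainder renormalized
assert torch.allclose(_a.sum(dim=1), torch.ones(1))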
def prepare_article_attention_target_balanced(model, batch_instances, cuda):
    unk_idx = int(model.vectorizer.str_to_idx[SimpleInferenceVectorizer.PAD])
    Is = []
    Cs = []
    Os = []
    articles = []
    target = []
    for inst in batch_instances:
        i = torch.LongTensor(inst['I'])
        c = torch.LongTensor(inst['C'])
        o = torch.LongTensor(inst['O'])
        article = torch.LongTensor(inst['article'])
        target_spans = set([tuple(x) for x in inst['evidence_spans']])
        for start, end in target_spans:
            # positive example
            Is.append(i)
            Cs.append(c)
            Os.append(o)
            articles.append(article[start:end])
            target.append(torch.ones(end - start))
            # negative example of the same length, drawn from elsewhere in the article
            neg_start, neg_end = _fetch_random_span(start, end, len(article), end - start)
            Is.append(i)
            Cs.append(c)
            Os.append(o)
            articles.append(article[neg_start:neg_end])
            target.append(torch.zeros(neg_end - neg_start))
    Is, Cs, Os, articles = [PaddedSequence.autopad(x, batch_first=True, padding_value=unk_idx) for x in [Is, Cs, Os, articles]]
    target = PaddedSequence.autopad(target, batch_first=True, padding_value=0)
    if cuda:
        articles, Is, Cs, Os, target = articles.cuda(), Is.cuda(), Cs.cuda(), Os.cuda(), target.cuda()
    return articles, Is, Cs, Os, target
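# `_fetch_random_span` is referenced above but not shown in this section. A minimal sketch
# consistent with how it is called -- return a random span of length `length` that does not
# overlap the evidence span [start, end) -- is given below. The rejection-sampling strategy
# and the name `_fetch_random_span_sketch` are assumptions, not the original helper.
import random

def _fetch_random_span_sketch(start, end, doc_len, length):
    # sample candidate starts until the candidate span [s, s + length) avoids [start, end)
    # (note: assumes the article is long enough that such a span exists)
    while True:
        s = random.randint(0, max(0, doc_len - length))
        if s + length <= start or s >= end:
            return s, s + length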
def split_sections(instances, inference_vectorizer, big_sections=False):
    """ Split into sections. If big_sections is False, use subsections; otherwise use big sections. """
    unk_idx = int(inference_vectorizer.str_to_idx[SimpleInferenceVectorizer.PAD])
    Is, Cs, Os = [PaddedSequence.autopad([torch.LongTensor(inst[x]) for inst in instances], batch_first=True, padding_value=unk_idx) for x in ['I', 'C', 'O']]
    indices = []
    sections = []
    # accumulate labels across *all* instances, matching `sections` (previously this list
    # was reset per instance, so only the last instance's labels were returned)
    section_labels = []
    section_titles = []
    for i in range(len(instances)):
        info = instances[i]
        if big_sections:
            info = gen_big_sections(info)
        ss = info['section_splits']
        art = info['article']
        evidence_labels = info['evidence_spans']
        section_titles.append(info['section_titles'])
        start = 0
        new_added = 0
        for s in ss:
            # each split `s` is a section length; the section occupies [start, start + s)
            tmp = art[start:start + s]
            is_evid = False
            for labels in evidence_labels:
                is_evid = is_evid or interval_overlap([start, start + s], labels)
            section_labels.append(1 if is_evid else 0)
            if len(tmp) == 0:
                tmp = [unk_idx]
            sections.append(tmp)
            start += s
            new_added += 1
        indices.append(new_added)  # cap number of sections...
    inst = [torch.LongTensor(inst) for inst in sections]
    pad_sections = PaddedSequence.autopad(inst, batch_first=True, padding_value=unk_idx)
    return pad_sections, indices, section_labels, section_titles, Is, Cs, Os
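# `interval_overlap` is used above but not defined in this section. A minimal sketch of the
# presumed semantics -- do two half-open [start, end) intervals intersect? -- follows; this
# is an assumption about the original helper, not its actual implementation.
def interval_overlap_sketch(x, y):
    x_start, x_end = x
    y_start, y_end = y
    return x_start < y_end and y_start < x_end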
def prepare_article_attention_target(model, batch_instances, cuda):
    unk_idx = int(model.vectorizer.str_to_idx[SimpleInferenceVectorizer.PAD])
    articles, Is, Cs, Os = [PaddedSequence.autopad([torch.LongTensor(inst[x]) for inst in batch_instances], batch_first=True, padding_value=unk_idx) for x in ['article', 'I', 'C', 'O']]
    target_spans = [inst['evidence_spans'] for inst in batch_instances]
    target = [torch.zeros(len(x['article'])) for x in batch_instances]
    for tgt, spans in zip(target, target_spans):
        for start, end in spans:
            tgt[start:end] = 1
    target = PaddedSequence.autopad(target, batch_first=True, padding_value=0)
    if cuda:
        articles, Is, Cs, Os, target = articles.cuda(), Is.cuda(), Cs.cuda(), Os.cuda(), target.cuda()
    return articles, Is, Cs, Os, target
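# Tiny illustration (not from the original code) of the span-to-target construction used in
# the two functions above: each evidence span [start, end) marks a run of 1s in a
# token-level binary target vector.
import torch

_tgt = torch.zeros(10)
for _start, _end in [(2, 4), (7, 9)]:
    _tgt[_start:_end] = 1
assert _tgt.tolist() == [0, 0, 1, 1, 0, 0, 0, 1, 1, 0]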
def forward(self, word_inputs: PaddedSequence, query_v_for_attention: torch.Tensor=None, normalize_attention_distribution=True):
    if isinstance(word_inputs, PaddedSequence):
        embedded = self.embedding(word_inputs.data)
        as_padded = PaddedSequence(embedded, word_inputs.batch_sizes, word_inputs.batch_first)
    else:
        raise ValueError("Got an unexpected type {} for word_inputs {}".format(type(word_inputs), word_inputs))
    if self.use_attention:
        a = self.attention_mechanism(as_padded, query_v_for_attention, normalize=normalize_attention_distribution)
        # move the mask to the embeddings' device rather than hard-coding .cuda(), so CPU runs also work
        output = torch.sum(a * embedded * as_padded.mask().unsqueeze(2).to(embedded.device), dim=1)
        return embedded, output, a
    else:
        # no attention: mean-pool by summing embeddings and dividing by each sequence's true (unpadded) length
        output = torch.sum(embedded, dim=1) / word_inputs.batch_sizes.unsqueeze(-1).to(torch.float)
        return embedded, output, None
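# Self-contained illustration (not from the original code) of the pooling in the
# no-attention branch above: summing embeddings and dividing by each sequence's true
# length gives a mean over the real tokens, provided padded positions embed to zero.
import torch

_embedded = torch.tensor([[[1.0], [3.0], [0.0]],    # seq 1: 2 real tokens + 1 zero pad
                          [[2.0], [4.0], [6.0]]])   # seq 2: 3 real tokens
_lengths = torch.tensor([2.0, 3.0])
_pooled = _embedded.sum(dim=1) / _lengths.unsqueeze(-1)
assert torch.allclose(_pooled, torch.tensor([[2.0], [4.0]]))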
def make_preds(nnet, instances, batch_size, inference_vectorizer, verbose_attn_to_batches=False, cuda=USE_CUDA):
    # TODO consider removing the inference_vectorizer, since all we need from it is an unk_idx
    y_vec = torch.cat([_get_y_vec(inst['y'], as_vec=False) for inst in instances]).squeeze()
    unk_idx = int(inference_vectorizer.str_to_idx[SimpleInferenceVectorizer.PAD])
    y_hat_vec = []
    # we batch this so the GPU doesn't run out of memory
    nnet.eval()
    for i in range(0, len(instances), batch_size):
        batch_instances = instances[i:i + batch_size]
        articles, Is, Cs, Os = [PaddedSequence.autopad([torch.LongTensor(inst[x]) for inst in batch_instances], batch_first=True, padding_value=unk_idx) for x in ['article', 'I', 'C', 'O']]
        if cuda:
            articles, Is, Cs, Os = articles.cuda(), Is.cuda(), Cs.cuda(), Os.cuda()
        verbose_attn = verbose_attn_to_batches and i in verbose_attn_to_batches
        y_hat_batch = nnet(articles, Is, Cs, Os, batch_size=len(batch_instances), verbose_attn=verbose_attn)
        y_hat_vec.append(y_hat_batch)
    nnet.train()
    return y_vec, torch.cat(y_hat_vec, dim=0)
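# The distributions returned by `make_preds` are converted to hard class predictions
# downstream via `to_int_preds` (not shown in this section). A minimal sketch of the
# presumed behavior -- argmax per row, moved to the CPU -- offered as an assumption:
import torch

def to_int_preds_sketch(y_hat):
    return torch.argmax(y_hat, dim=1).cpu().numpy().tolist()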
def forward(self, word_inputs: PaddedSequence, mask=None, query_v_for_attention=None, normalize_attention_distribution=True):
    embedded = self.embedding(word_inputs.data)
    projected = self.projection_layer(embedded)
    mask = word_inputs.mask().to("cuda")  # note: hard-codes GPU execution
    # now to the star transformer. the model returns a tuple comprising a
    # <batch, words, dims> tensor and a second tensor (the relay nodes) of
    # <batch, dims> -- we take the latter in the case where no attention is to be used
    token_vectors, a_v = self.st(projected, mask=mask)
    if self.use_attention:
        token_vectors = PaddedSequence(token_vectors, word_inputs.batch_sizes, batch_first=True)
        if self.concat_relay:
            # concatenate a_v <batch x model_d> to every token vector in every article
            token_vectors_with_relay = self._concat_relay_to_tokens_in_batches(token_vectors, a_v, word_inputs.batch_sizes)
            a = self.attention_mechanism(token_vectors_with_relay, query_v_for_attention, normalize=normalize_attention_distribution)
        else:
            a = self.attention_mechanism(token_vectors, query_v_for_attention, normalize=normalize_attention_distribution)
        # element-wise multiplication: each token vector is weighted by its attention score
        weighted_hidden = torch.sum(a * token_vectors.data, dim=1)
        return token_vectors, weighted_hidden, a
    return a_v
def _prepare_random_concatenated_spans(model, batch_instances, cuda):
    unk_idx = int(model.vectorizer.str_to_idx[SimpleInferenceVectorizer.PAD])
    target_spans = [inst['evidence_spans'] for inst in batch_instances]
    Is = []
    Os = []
    Cs = []
    target = []
    articles = []
    for instance, evidence_spans in zip(batch_instances, target_spans):
        article = torch.LongTensor(instance['article'])
        tgt = torch.zeros(len(article))
        Is.append(instance['I'])
        Os.append(instance['O'])
        Cs.append(instance['C'])
        for start, end in evidence_spans:
            tgt[start:end] = 1
        start, end = random.choice(evidence_spans)
        # reject any candidate start that would overlap the chosen evidence span
        # or begin within one span-length after it
        unacceptable_start = start - (end - start)
        unacceptable_end = end + (end - start)
        # bound the start so the matched span stays fully inside the article (randint is inclusive)
        random_matched_span_start = random.randint(0, max(0, len(article) - (end - start)))
        # rejection-sample until we find an acceptable span start outside the evidence region
        while unacceptable_start < random_matched_span_start < unacceptable_end:
            random_matched_span_start = random.randint(0, max(0, len(article) - (end - start)))
        random_matched_span = (random_matched_span_start, random_matched_span_start + end - start)
        # randomize the order of the positive and negative halves
        if random.random() > 0.5:
            tgt = torch.cat([tgt[start:end], tgt[random_matched_span[0]:random_matched_span[1]]]).contiguous()
            article = torch.cat([article[start:end], article[random_matched_span[0]:random_matched_span[1]]]).contiguous()
        else:
            tgt = torch.cat([tgt[random_matched_span[0]:random_matched_span[1]], tgt[start:end]]).contiguous()
            article = torch.cat([article[random_matched_span[0]:random_matched_span[1]], article[start:end]]).contiguous()
        # normalize the target so it is a distribution over tokens
        tgt /= torch.sum(tgt)
        target.append(tgt)
        articles.append(article)
    Is, Cs, Os = [PaddedSequence.autopad([torch.LongTensor(elem) for elem in cond], batch_first=True, padding_value=unk_idx) for cond in [Is, Cs, Os]]
    target = PaddedSequence.autopad(target, batch_first=True, padding_value=0)
    articles = PaddedSequence.autopad(articles, batch_first=True, padding_value=unk_idx)
    if cuda:
        articles, Is, Cs, Os, target = articles.cuda(), Is.cuda(), Cs.cuda(), Os.cuda(), target.cuda()
    return articles, Is, Cs, Os, target
def forward(self, query: List[torch.tensor], document_batch: List[torch.tensor]):
    assert len(query) == len(document_batch)
    # note about device management:
    # since distributed training is enabled, the inputs to this module can be on *any* device
    # (preferably cpu, since we wrap and unwrap the module); we want to keep these params on
    # the input device (assuming CPU) for as long as possible for cheap memory access
    target_device = next(self.parameters()).device
    cls_token = torch.tensor([self.cls_token_id])
    sep_token = torch.tensor([self.sep_token_id])
    input_tensors = []
    position_ids = []
    for q, d in zip(query, document_batch):
        # truncate the document so [CLS] + query + [SEP] + document fits in max_length
        if len(q) + len(d) + 2 > self.max_length:
            d = d[:(self.max_length - len(q) - 2)]
        input_tensors.append(torch.cat([cls_token, q, sep_token, d.to(dtype=q.dtype)]))
        position_ids.append(torch.arange(0, input_tensors[-1].size().numel()))
    bert_input = PaddedSequence.autopad(input_tensors, batch_first=True, padding_value=self.pad_token_id, device=target_device)
    positions = PaddedSequence.autopad(position_ids, batch_first=True, padding_value=0, device=target_device)
    (classes,) = self.bert(bert_input.data,
                           attention_mask=bert_input.mask(on=1.0, off=0.0, dtype=torch.float, device=target_device),
                           position_ids=positions.data)
    assert torch.all(classes == classes)  # check for NaNs
    return classes
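# Self-contained sketch (illustration only, with made-up token ids) of the input packing
# above: the final sequence is [CLS] + query + [SEP] + document, with the document
# truncated so the whole thing fits in max_length; position ids then simply run 0..len-1.
import torch

_max_length = 8
_cls, _sep = torch.tensor([101]), torch.tensor([102])    # hypothetical special-token ids
_q, _d = torch.tensor([5, 6]), torch.tensor([7, 8, 9, 10, 11, 12])
if len(_q) + len(_d) + 2 > _max_length:
    _d = _d[:(_max_length - len(_q) - 2)]
_input = torch.cat([_cls, _q, _sep, _d])
_positions = torch.arange(0, _input.size().numel())
assert len(_input) == _max_length == len(_positions)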
def _concat_relay_to_tokens_in_batches(self, article_token_batches, relay_batches, batch_sizes):
    '''
    Takes a <batch x doc_len x embedding> tensor (article_token_batches) and builds and
    returns a <batch x doc_len x [embedding + relay_embedding]> version which concatenates
    repeated copies of the relay embedding associated with each batch.
    '''
    # create an empty <batch x doc_len x (token embedding + relay_embedding)> tensor
    article_tokens_with_relays = torch.zeros(article_token_batches.data.shape[0],
                                             article_token_batches.data.shape[1],
                                             article_token_batches.data.shape[2] + relay_batches.shape[1])
    for b in range(article_token_batches.data.shape[0]):
        batch_relay = relay_batches[b].repeat(article_tokens_with_relays.shape[1], 1)
        article_tokens_with_relays[b] = torch.cat((article_token_batches.data[b], batch_relay), 1)
    return PaddedSequence(article_tokens_with_relays.to("cuda"), batch_sizes, batch_first=True)
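# Equivalent vectorized sketch (an alternative, not the original code): the per-batch loop
# above can be replaced with a broadcasted expand + cat, which avoids pre-allocating the
# output tensor and copying one article at a time.
import torch

def _concat_relay_vectorized(tokens, relay):
    # tokens: <batch x doc_len x embedding>, relay: <batch x relay_embedding>
    expanded_relay = relay.unsqueeze(1).expand(-1, tokens.shape[1], -1)
    return torch.cat((tokens, expanded_relay), dim=2)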
def train(ev_inf: InferenceNet, train_Xy, val_Xy, test_Xy, inference_vectorizer, epochs=10, batch_size=16, shuffle=True):
    # we sort these so batches all have approximately the same length (ish), which decreases the
    # average amount of padding needed, and thus the amount of wasted computation in training.
    if not shuffle:
        train_Xy.sort(key=lambda x: len(x['article']))
        val_Xy.sort(key=lambda x: len(x['article']))
        test_Xy.sort(key=lambda x: len(x['article']))
    print("Using {} training examples, {} validation examples, {} testing examples".format(len(train_Xy), len(val_Xy), len(test_Xy)))
    most_common = stats.mode([_get_majority_label(inst) for inst in train_Xy])[0][0]

    best_val_model = None
    best_val_f1 = float('-inf')
    if USE_CUDA:
        ev_inf = ev_inf.cuda()

    optimizer = optim.Adam(ev_inf.parameters())
    criterion = nn.CrossEntropyLoss(reduction='sum')  # sum (not average) of the batch losses.

    # TODO add epoch timing information here
    epochs_since_improvement = 0
    val_metrics = {
        "val_acc": [],
        "val_p": [],
        "val_r": [],
        "val_f1": [],
        "val_loss": [],
        'train_loss': [],
        'val_aucs': [],
        'train_aucs': [],
        'val_entropies': [],
        'val_evidence_token_mass': [],
        'val_evidence_token_err': [],
        'train_entropies': [],
        'train_evidence_token_mass': [],
        'train_evidence_token_err': []
    }

    for epoch in range(epochs):
        if epochs_since_improvement > 10:
            print("Exiting early due to no improvement on validation after 10 epochs.")
            break
        if shuffle:
            random.shuffle(train_Xy)

        epoch_loss = 0
        for i in range(0, len(train_Xy), batch_size):
            instances = train_Xy[i:i+batch_size]
            ys = torch.cat([_get_y_vec(inst['y'], as_vec=False) for inst in instances], dim=0)
            # pad with the PAD token so the batch tensors are rectangular; PaddedSequence tracks the true lengths
            unk_idx = int(inference_vectorizer.str_to_idx[SimpleInferenceVectorizer.PAD])
            articles, Is, Cs, Os = [PaddedSequence.autopad([torch.LongTensor(inst[x]) for inst in instances], batch_first=True, padding_value=unk_idx) for x in ['article', 'I', 'C', 'O']]
            optimizer.zero_grad()
            if USE_CUDA:
                articles, Is, Cs, Os = articles.cuda(), Is.cuda(), Cs.cuda(), Os.cuda()
                ys = ys.cuda()
            verbose_attn = (epoch == epochs - 1 and i == 0) or (epoch == 0 and i == 0)
            if verbose_attn:
                print("Training attentions:")
            tags = ev_inf(articles, Is, Cs, Os, batch_size=len(instances), verbose_attn=verbose_attn)
            loss = criterion(tags, ys)
            epoch_loss += loss.item()
            loss.backward()
            optimizer.step()
        val_metrics['train_loss'].append(epoch_loss)

        with torch.no_grad():
            verbose_attn_to_batches = set([0, 1, 2, 3, 4]) if epoch == epochs - 1 or epoch == 0 else False
            if verbose_attn_to_batches:
                print("Validation attention:")
            # make_preds runs in eval mode
            val_y, val_y_hat = make_preds(ev_inf, val_Xy, batch_size, inference_vectorizer, verbose_attn_to_batches=verbose_attn_to_batches)
            val_loss = criterion(val_y_hat, val_y.squeeze())
            y_hat = to_int_preds(val_y_hat)

            if epoch == 0:
                dummy_preds = [most_common] * len(val_y)
                dummy_acc = accuracy_score(val_y.cpu(), dummy_preds)
                val_metrics["baseline_val_acc"] = dummy_acc
                p, r, f1, _ = precision_recall_fscore_support(val_y.cpu(), dummy_preds, labels=None, beta=1, average='macro', pos_label=1, warn_for=('f-score',), sample_weight=None)
                val_metrics['p_dummy'] = p
                val_metrics['r_dummy'] = r
                val_metrics['f_dummy'] = f1
                print("val dummy accuracy: {:.3f}".format(dummy_acc))
                print("classification report for dummy on val: ")
                print(classification_report(val_y.cpu(), dummy_preds))
                print("\n\n")

            acc = accuracy_score(val_y.cpu(), y_hat)
            val_metrics["val_acc"].append(acc)
            val_loss = val_loss.cpu().item()
            val_metrics["val_loss"].append(val_loss)
            p, r, f1, _ = precision_recall_fscore_support(val_y.cpu(), y_hat, labels=None, beta=1, average='macro', pos_label=1, warn_for=('f-score',), sample_weight=None)
            val_metrics["val_f1"].append(f1)
            val_metrics["val_p"].append(p)
            val_metrics["val_r"].append(r)

            if ev_inf.article_encoder.use_attention:
                train_auc, train_entropies, train_evidence_token_masses, train_evidence_token_err = evaluate_model_attention_distribution(ev_inf, train_Xy, cuda=USE_CUDA, compute_attention_diagnostics=True)
                val_auc, val_entropies, val_evidence_token_masses, val_evidence_token_err = evaluate_model_attention_distribution(ev_inf, val_Xy, cuda=USE_CUDA, compute_attention_diagnostics=True)
                print("train auc: {:.3f}, entropy: {:.3f}, evidence mass: {:.3f}, err: {:.3f}".format(train_auc, train_entropies, train_evidence_token_masses, train_evidence_token_err))
                print("val auc: {:.3f}, entropy: {:.3f}, evidence mass: {:.3f}, err: {:.3f}".format(val_auc, val_entropies, val_evidence_token_masses, val_evidence_token_err))
            else:
                train_auc, train_entropies, train_evidence_token_masses, train_evidence_token_err = "", "", "", ""
                val_auc, val_entropies, val_evidence_token_masses, val_evidence_token_err = "", "", "", ""
            val_metrics['train_aucs'].append(train_auc)
            val_metrics['train_entropies'].append(train_entropies)
            val_metrics['train_evidence_token_mass'].append(train_evidence_token_masses)
            val_metrics['train_evidence_token_err'].append(train_evidence_token_err)
            val_metrics['val_aucs'].append(val_auc)
            val_metrics['val_entropies'].append(val_entropies)
            val_metrics['val_evidence_token_mass'].append(val_evidence_token_masses)
            val_metrics['val_evidence_token_err'].append(val_evidence_token_err)

            if f1 > best_val_f1:
                print("New best model at {} with val f1 {:.3f}".format(epoch, f1))
                best_val_f1 = f1
                best_val_model = copy.deepcopy(ev_inf)
                epochs_since_improvement = 0
            else:
                epochs_since_improvement += 1

            print("epoch {}. train loss: {}; val loss: {}; val acc: {:.3f}".format(epoch, epoch_loss, val_loss, acc))
            print(classification_report(val_y.cpu(), y_hat))
            print("val macro f1: {0:.3f}".format(f1))
            print("\n\n")

    val_metrics['best_val_f1'] = best_val_f1
    with torch.no_grad():
        print("Test attentions:")
        verbose_attn_to_batches = set([0, 1, 2, 3, 4])
        # make_preds runs in eval mode
        test_y, test_y_hat = make_preds(best_val_model, test_Xy, batch_size, inference_vectorizer, verbose_attn_to_batches=verbose_attn_to_batches)
        test_loss = criterion(test_y_hat, test_y.squeeze())
        y_hat = to_int_preds(test_y_hat)
        final_test_preds = zip([t['a_id'] for t in test_Xy], [t['p_id'] for t in test_Xy], y_hat)

        acc = accuracy_score(test_y.cpu(), y_hat)
        val_metrics["test_acc"] = acc
        test_loss = test_loss.cpu().item()
        val_metrics["test_loss"] = test_loss
        p, r, f1, _ = precision_recall_fscore_support(test_y.cpu(), y_hat, labels=None, beta=1, average='macro', pos_label=1, warn_for=('f-score',), sample_weight=None)
        val_metrics["test_f1"] = f1
        val_metrics["test_p"] = p
        val_metrics["test_r"] = r
        if best_val_model.article_encoder.use_attention:
            test_auc, test_entropies, test_evidence_token_masses, test_evidence_token_err = evaluate_model_attention_distribution(best_val_model, test_Xy, cuda=USE_CUDA, compute_attention_diagnostics=True)
            print("test auc: {:.3f}, entropy: {:.3f}, evidence mass: {:.3f}, err: {:.3f}".format(test_auc, test_entropies, test_evidence_token_masses, test_evidence_token_err))
        else:
            test_auc, test_entropies, test_evidence_token_masses, test_evidence_token_err = "", "", "", ""
        val_metrics['test_auc'] = test_auc
        val_metrics['test_entropy'] = test_entropies
        val_metrics['test_evidence_token_mass'] = test_evidence_token_masses
        val_metrics['test_evidence_token_err'] = test_evidence_token_err

        print("test loss: {}; test acc: {:.3f}".format(test_loss, acc))
        print(classification_report(test_y.cpu(), y_hat))
        print("test macro f1: {}".format(f1))
        print("\n\n")

    return best_val_model, inference_vectorizer, train_Xy, val_Xy, val_metrics, final_test_preds
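# Self-contained illustration (not from the original code) of the majority-class baseline
# computed at the top of `train`: with the scipy versions this code appears to target,
# stats.mode returns a ModeResult whose first field is an array of modal values, hence the
# [0][0] indexing to extract the most common label.
from scipy import stats

_labels = [1, 0, 1, 2, 1]
_most_common = stats.mode(_labels)[0][0]
assert _most_common == 1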
def forward(self, article_tokens: PaddedSequence, indices, I_tokens: PaddedSequence, C_tokens: PaddedSequence, O_tokens: PaddedSequence, batch_size, h_dropout_rate=0.2, recursive_encoding={}):
    inner_batch = 1  # this batches over sections!

    ### Run our encode function ###
    I_v, C_v, O_v = self._encode(I_tokens, C_tokens, O_tokens)
    query_v, old_query_v = None, None

    ### Run normal attention over the data ###
    if self.article_encoder.condition_attention:
        query_v = torch.cat([I_v, C_v, O_v], dim=1)
        old_query_v = copy.deepcopy(query_v)

    cmb_hidden = []
    ### encode each section with the article encoder ###
    for i in range(0, len(article_tokens[0]), inner_batch):
        tokens = article_tokens[0][i:i + inner_batch]
        new_tkn = PaddedSequence.autopad(tokens, batch_first=True)
        if query_v is not None:
            query_v = torch.cat([old_query_v for _ in range(min(len(tokens), inner_batch))], dim=0)
        if self.article_encoder in ("transformer", "CBoW"):
            hidden = self.article_encoder(new_tkn, query_v_for_attention=query_v)
        else:
            # assume RNN
            _, hidden = self.article_encoder(new_tkn, query_v_for_attention=query_v)
        cmb_hidden.append(hidden)
    hidden = torch.cat(cmb_hidden, dim=0)

    art_secs = []
    token_secs = []
    i = 0
    ### Reshape our tokens + article representations. ###
    for idx in indices:
        art_secs.append(hidden[i:i + idx])
        token_secs.append(article_tokens[i:i + idx])
        i += idx
    hidden_articles = art_secs

    article_vs = []
    section_weights = []
    for i in range(batch_size):
        hidden_art = hidden_articles[i]  # hidden states for a single article
        token_art = token_secs[i]        # tokens for a single article
        if self.condition_attention:
            query_v = torch.cat([old_query_v for _ in range(len(hidden_art))], dim=0)

        ### Run section attention over the data for each section ###
        a = self.section_attn(token_art, hidden_input_states=hidden_art, query_v_for_attention=query_v, normalize=True)
        section_weights.append(a)

        if self.recursive_encoding:
            section_splits = recursive_encoding['section_splits']
            new_articles = []
            last = 0
            ### -> Reweight sections based on subsection:
            # [Alpha(S1.1), Alpha(S1.2)] * [Encoding of S1.1, Encoding of S1.2]
            for s in section_splits:
                section_encoding = hidden_art[last:last + s]
                ws = a[last:last + s]
                new_articles.append(torch.mm(torch.transpose(ws, dim0=1, dim1=0), section_encoding))
                last += s  # advance past this subsection (missing in the original)

            ### -> another attention layer (share it)
            new_tokens = recursive_encoding['big_sections']
            hidden_art = torch.cat(new_articles, dim=0).unsqueeze(0)
            new_query_v = torch.cat([old_query_v for _ in range(hidden_art.shape[1])], dim=0)
            a = self.section_attn(new_tokens, hidden_input_states=hidden_art, query_v_for_attention=new_query_v, normalize=True)

        ### Combine the re-weighted sections ###
        weighted = (a * hidden_art).squeeze().unsqueeze(0)
        weighted_hidden = torch.sum(weighted, dim=1)
        article_v = torch.sum(weighted_hidden, dim=0)
        article_vs.append(article_v)

    # stack the per-article vectors into a <batch x dim> matrix. (the original pairwise
    # torch.stack([batch_a_v, article_v]) only composed correctly for batch sizes <= 2.)
    batch_a_v = torch.stack(article_vs)

    ### Finish Plugging in ###
    h = torch.cat([batch_a_v, I_v, C_v, O_v], dim=1)
    raw_out = self.out(self.MLP_hidden(h))

    return F.softmax(raw_out, dim=1), section_weights