Example #1
    def _prepare_random_matched_spans(model, batch_instances, cuda):
        unk_idx = int(model.vectorizer.str_to_idx[SimpleInferenceVectorizer.PAD])
        Is, Cs, Os = [PaddedSequence.autopad([torch.LongTensor(inst[x]) for inst in batch_instances], batch_first=True, padding_value=unk_idx) for x in ['I', 'C', 'O']]
        target_spans = [inst['evidence_spans'] for inst in batch_instances]
        target = []
        articles = []
        for article, evidence_spans in zip((x['article'] for x in batch_instances), target_spans):
            tgt = torch.zeros(len(article))
            for start, end in evidence_spans:
                tgt[start:end] = 1
            (start, end) = random.choice(evidence_spans)
            # select a random span of the same length, capped so it fits inside the article
            random_matched_span_start = random.randint(0, max(0, len(article) - (end - start)))
            random_matched_span_end = random_matched_span_start + end - start
            tgt_pos = tgt[start:end]
            tgt_neg = tgt[random_matched_span_start:random_matched_span_end]
            article_pos = torch.LongTensor(article[start:end])
            article_neg = torch.LongTensor(article[random_matched_span_start:random_matched_span_end])
            if random.random() > 0.5:
                articles.append(torch.cat([article_pos, article_neg]))
                target.append(torch.cat([tgt_pos, tgt_neg]))
            else:
                articles.append(torch.cat([article_neg, article_pos]))
                target.append(torch.cat([tgt_neg, tgt_pos]))

        target = PaddedSequence.autopad(target, batch_first=True, padding_value=0)
        articles = PaddedSequence.autopad(articles, batch_first=True, padding_value=unk_idx)
        if cuda:
            articles, Is, Cs, Os, target = articles.cuda(), Is.cuda(), Cs.cuda(), Os.cuda(), target.cuda()
        return articles, Is, Cs, Os, target
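A minimal toy illustration of the matched positive/negative span construction above, using made-up token ids and a single evidence span (not taken from the repo):

import random

import torch

article = list(range(20))            # toy token ids
evidence_spans = [(3, 7)]            # one gold evidence span
tgt = torch.zeros(len(article))
for s, e in evidence_spans:
    tgt[s:e] = 1

start, end = random.choice(evidence_spans)
span_len = end - start
neg_start = random.randint(0, len(article) - span_len)   # random span of the same length
neg_end = neg_start + span_len

pos_tokens, neg_tokens = torch.LongTensor(article[start:end]), torch.LongTensor(article[neg_start:neg_end])
pos_target, neg_target = tgt[start:end], tgt[neg_start:neg_end]

# the pair is concatenated in a random order, exactly as in the function above
if random.random() > 0.5:
    tokens, target = torch.cat([pos_tokens, neg_tokens]), torch.cat([pos_target, neg_target])
else:
    tokens, target = torch.cat([neg_tokens, pos_tokens]), torch.cat([neg_target, pos_target])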
Example #2
    def forward(self, word_inputs: PaddedSequence, init_hidden: torch.Tensor=None, query_v_for_attention: torch.Tensor=None, normalize_attention_distribution=True) -> (torch.Tensor, torch.Tensor):
        if isinstance(word_inputs, PaddedSequence):
            embedded = self.embedding(word_inputs.data)
            as_padded = word_inputs.pack_other(embedded)
            output, hidden = self.gru(as_padded, init_hidden)
            output = PaddedSequence.from_packed_sequence(output, batch_first=True)
        else:
            raise ValueError("Unknown input type {} for word_inputs: {}, try a PaddedSequence or a Tensor".format(type(word_inputs), word_inputs))

        # concatenate the hidden representations
        if self.bidirectional:
            if self.n_layers > 1:
                raise ValueError("Implement me!")
            hidden = torch.cat([hidden[0], hidden[1]], dim=1)

        if self.use_attention:
            # note that these hidden_input_states are masked to zeros (when appropriate) already when this is called.
            hidden_input_states = output
            a = self.attention_mechanism(hidden_input_states, query_v_for_attention, normalize=normalize_attention_distribution)

            # note this is an element-wise multiplication, so each of the hidden states is weighted by the attention vector
            weighted_hidden = torch.sum(a * output.data, dim=1)
            return output, weighted_hidden, a

        return output, hidden
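A toy sketch of the bidirectional hidden-state concatenation above, assuming a single-layer GRU so that hidden is <2 x batch x hidden_size> (one slice per direction):

import torch
import torch.nn as nn

gru = nn.GRU(input_size=8, hidden_size=16, batch_first=True, bidirectional=True)
x = torch.randn(3, 7, 8)                               # <batch x tokens x input_size>
output, hidden = gru(x)                                # hidden: <2 x batch x 16>

combined = torch.cat([hidden[0], hidden[1]], dim=1)    # <batch x 32>, both directions side by side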
Example #3
    def forward(self, hidden_input_states: PaddedSequence, query_v_for_attention, normalize=True):
        if not isinstance(hidden_input_states, PaddedSequence):
            raise TypeError("Expected an input of type PaddedSequence but got {}".format(type(hidden_input_states)))
        if self.condition_attention:
            # the code below concatenates query_v_for_attention (one per unit in the batch) to each of the encoder hidden states
            # expand the query vector used for attention by making it |batch| x 1 x |query_vector_size|
            query_v_for_attention = query_v_for_attention.unsqueeze(dim=1)
            # duplicate it to be the same number of (max) tokens in the batch
            query_v_for_attention = torch.cat(hidden_input_states.data.size()[1] * [query_v_for_attention], dim=1)
            # finally, concatenate this vector to every "final" element of the input tensor
            attention_inputs = torch.cat([hidden_input_states.data, query_v_for_attention], dim=2)
        else:
            attention_inputs = hidden_input_states.data
        raw_word_scores = self.token_attention_F(attention_inputs)
        raw_word_scores = raw_word_scores * hidden_input_states.mask(on=1.0, off=0.0, size=raw_word_scores.size(), device=raw_word_scores.device)
        # TODO this should probably become a logsumexp depending on condition
        a = self.attn_sm(raw_word_scores)

        # since we need to handle masking, zero out any probability mass the softmax assigned to padding positions
        masked_attention = a * hidden_input_states.mask(on=1.0, off=0.0, size=a.size(), device=a.device)
        if normalize:
            # renormalize so the attention over the real (unmasked) tokens sums to one again;
            # this is only necessary for the tokenwise attention because masking removes part of the softmax's support
            weights = torch.sum(masked_attention, dim=1).unsqueeze(1)
            a = masked_attention / weights
        else:
            a = masked_attention

        return a
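A toy sketch of the masking-and-renormalization step above (shapes assumed to be <batch x tokens x 1>, with padded positions masked to zero):

import torch

a = torch.tensor([[[0.5], [0.3], [0.2]],
                  [[0.6], [0.4], [0.0]]])        # post-softmax attention, <2 x 3 x 1>
mask = torch.tensor([[[1.], [1.], [1.]],
                     [[1.], [1.], [0.]]])        # second sequence has one padding token

masked_attention = a * mask                                 # remove support on padding
weights = torch.sum(masked_attention, dim=1).unsqueeze(1)   # per-sequence normalizer, <2 x 1 x 1>
a = masked_attention / weights                              # each sequence's real tokens sum to 1 again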
Example #4
def prepare_article_attention_target_balanced(model, batch_instances, cuda):
    unk_idx = int(model.vectorizer.str_to_idx[SimpleInferenceVectorizer.PAD])
    Is = []
    Cs = []
    Os = []
    articles = []
    target = []
    for inst in batch_instances:
        i = torch.LongTensor(inst['I'])
        c = torch.LongTensor(inst['C'])
        o = torch.LongTensor(inst['O'])
        article = torch.LongTensor(inst['article'])
        target_spans = set([tuple(x) for x in inst['evidence_spans']])
        for start, end in target_spans:
            # positive example
            Is.append(i)
            Cs.append(c)
            Os.append(o)
            articles.append(article[start:end])
            target.append(torch.ones(end - start))

            # negative example
            neg_start, neg_end = _fetch_random_span(start, end, len(article), end - start)
            Is.append(i)
            Cs.append(c)
            Os.append(o)
            articles.append(article[neg_start:neg_end])
            target.append(torch.zeros(neg_end - neg_start))

    Is, Cs, Os, articles = [PaddedSequence.autopad(x, batch_first=True, padding_value=unk_idx) for x in [Is, Cs, Os, articles]]
    target = PaddedSequence.autopad(target, batch_first=True, padding_value=0)
    if cuda:
        articles, Is, Cs, Os, target = articles.cuda(), Is.cuda(), Cs.cuda(), Os.cuda(), target.cuda()
    return articles, Is, Cs, Os, target
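The snippet above calls a _fetch_random_span helper defined elsewhere in the codebase; a plausible minimal sketch of its behavior (an assumption, not the repo's actual definition) is:

import random

def _fetch_random_span(start, end, doc_len, span_len):
    # sample a span_len-token span, preferring one that avoids the window around the evidence span
    neg_start = random.randint(0, max(0, doc_len - span_len))
    for _ in range(100):  # bounded rejection sampling so we never loop forever
        if not (start - span_len < neg_start < end + span_len):
            break
        neg_start = random.randint(0, max(0, doc_len - span_len))
    return neg_start, neg_start + span_len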
Example #5
def split_sections(instances, inference_vectorizer, big_sections=False):
    """ Split into sections. If big_sections = False, use subsections, else use big sections. """
    unk_idx = int(
        inference_vectorizer.str_to_idx[SimpleInferenceVectorizer.PAD])
    Is, Cs, Os = [
        PaddedSequence.autopad(
            [torch.LongTensor(inst[x]) for inst in instances],
            batch_first=True,
            padding_value=unk_idx) for x in ['I', 'C', 'O']
    ]
    indices = []
    sections = []
    section_labels = []
    section_titles = []
    for i in range(len(instances)):
        info = instances[i]
        if big_sections:
            info = gen_big_sections(info)

        ss = info['section_splits']
        art = info['article']
        evidence_labels = info['evidence_spans']
        section_titles.append(info['section_titles'])
        start = 0
        new_added = 0

        for s in ss:
            tmp = art[start:start + s]
            is_evid = False
            for labels in evidence_labels:
                is_evid = is_evid or interval_overlap([start, start + s],
                                                      labels)

            if is_evid:
                section_labels.append(1)
            else:
                section_labels.append(0)

            if len(tmp) == 0:
                tmp = [unk_idx]
            sections.append(tmp)
            start += s
            new_added += 1

        indices.append(new_added)

    # cap number of sections...
    inst = [torch.LongTensor(inst) for inst in sections]
    pad_sections = PaddedSequence.autopad(inst,
                                          batch_first=True,
                                          padding_value=unk_idx)
    return pad_sections, indices, section_labels, section_titles, Is, Cs, Os
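split_sections relies on an interval_overlap helper from the surrounding codebase; a plausible minimal version (an assumption about its semantics, not the repo's definition) is:

def interval_overlap(a, b):
    # True if intervals [a0, a1] and [b0, b1] share at least one position
    return a[0] < b[1] and b[0] < a[1]

interval_overlap([0, 5], [4, 9])   # True
interval_overlap([0, 5], [7, 9])   # False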
Example #6
def prepare_article_attention_target(model, batch_instances, cuda):
    unk_idx = int(model.vectorizer.str_to_idx[SimpleInferenceVectorizer.PAD])
    articles, Is, Cs, Os = [PaddedSequence.autopad([torch.LongTensor(inst[x]) for inst in batch_instances], batch_first=True, padding_value=unk_idx) for x in ['article', 'I', 'C', 'O']]
    target_spans = [inst['evidence_spans'] for inst in batch_instances]
    target = [torch.zeros(len(x['article'])) for x in batch_instances]
    for tgt, spans in zip(target, target_spans):
        for start, end in spans:
            tgt[start:end] = 1
    target = PaddedSequence.autopad(target, batch_first=True, padding_value=0)
    if cuda:
        articles, Is, Cs, Os, target = articles.cuda(), Is.cuda(), Cs.cuda(), Os.cuda(), target.cuda()
    return articles, Is, Cs, Os, target
Example #7
 def forward(self, word_inputs: PaddedSequence, query_v_for_attention: torch.Tensor=None, normalize_attention_distribution=True):
     if isinstance(word_inputs, PaddedSequence):
         embedded = self.embedding(word_inputs.data)
         as_padded = PaddedSequence(embedded, word_inputs.batch_sizes, word_inputs.batch_first)
     else:
         raise ValueError("Got an unexpected type {} for word_inputs {}".format(type(word_inputs), word_inputs))
     if self.use_attention:
         a = self.attention_mechanism(as_padded, query_v_for_attention, normalize=normalize_attention_distribution)
         output = torch.sum(a * embedded * as_padded.mask().unsqueeze(2).cuda(), dim=1)
         return embedded, output, a
     else:
         output = torch.sum(embedded, dim=1) / word_inputs.batch_sizes.unsqueeze(-1).to(torch.float)
         return embedded, output, None
Example #8
def make_preds(nnet,
               instances,
               batch_size,
               inference_vectorizer,
               verbose_attn_to_batches=False,
               cuda=USE_CUDA):
    # TODO consider removing the inference_vectorizer since all we need is an unk_idx from it
    y_vec = torch.cat(
        [_get_y_vec(inst['y'], as_vec=False) for inst in instances]).squeeze()
    unk_idx = int(
        inference_vectorizer.str_to_idx[SimpleInferenceVectorizer.PAD])
    y_hat_vec = []
    # we batch this so the GPU doesn't run out of memory
    nnet.eval()
    for i in range(0, len(instances), batch_size):
        batch_instances = instances[i:i + batch_size]
        articles, Is, Cs, Os = [
            PaddedSequence.autopad(
                [torch.LongTensor(inst[x]) for inst in batch_instances],
                batch_first=True,
                padding_value=unk_idx) for x in ['article', 'I', 'C', 'O']
        ]
        if cuda:
            articles, Is, Cs, Os = articles.cuda(), Is.cuda(), Cs.cuda(), Os.cuda()
        verbose_attn = verbose_attn_to_batches and i in verbose_attn_to_batches
        y_hat_batch = nnet(articles,
                           Is,
                           Cs,
                           Os,
                           batch_size=len(batch_instances),
                           verbose_attn=verbose_attn)
        y_hat_vec.append(y_hat_batch)
    nnet.train()
    return y_vec, torch.cat(y_hat_vec, dim=0)
Example #9
    def forward(self,
                word_inputs: PaddedSequence,
                mask=None,
                query_v_for_attention=None,
                normalize_attention_distribution=True):

        embedded = self.embedding(word_inputs.data)
        projected = self.projection_layer(embedded)
        mask = word_inputs.mask().to("cuda")

        # now to the star transformer.
        # the model will return a tuple comprising a <batch, words, dims> tensor and a second
        # tensor (the relay nodes) of <batch, dims> -- we take the latter
        # in the case where no attention is to be used
        token_vectors, a_v = self.st(projected, mask=mask)

        if self.use_attention:
            token_vectors = PaddedSequence(token_vectors,
                                           word_inputs.batch_sizes,
                                           batch_first=True)
            a = None
            if self.concat_relay:
                ###
                # need to concatenate a_v <batch x model_d> for all articles
                ###
                token_vectors_with_relay = self._concat_relay_to_tokens_in_batches(
                    token_vectors, a_v, word_inputs.batch_sizes)

                a = self.attention_mechanism(
                    token_vectors_with_relay,
                    query_v_for_attention,
                    normalize=normalize_attention_distribution)
            else:
                a = self.attention_mechanism(
                    token_vectors,
                    query_v_for_attention,
                    normalize=normalize_attention_distribution)

            # note this is an element-wise multiplication, so each of the hidden states is weighted by the attention vector
            weighted_hidden = torch.sum(a * token_vectors.data, dim=1)

            return token_vectors, weighted_hidden, a

        return a_v
Example #10
def _prepare_random_concatenated_spans(model, batch_instances, cuda):
    unk_idx = int(model.vectorizer.str_to_idx[SimpleInferenceVectorizer.PAD])
    target_spans = [inst['evidence_spans'] for inst in batch_instances]
    Is = []
    Os = []
    Cs = []
    target = []
    articles = []
    for instance, evidence_spans in zip(batch_instances, target_spans):
        article = instance['article']
        article = torch.LongTensor(article)
        tgt = torch.zeros(len(article))
        Is.append(instance['I'])
        Os.append(instance['O'])
        Cs.append(instance['C'])
        for start, end in evidence_spans:
            tgt[start:end] = 1
        start, end = random.choice(evidence_spans)
        unacceptable_start = start - (end - start)
        unacceptable_end = end + (end - start)
        random_matched_span_start = random.randint(0, len(article))
        # rejection sample until the candidate start falls outside the window around the chosen evidence span
        while unacceptable_start - random_matched_span_start < 0 and 0 < unacceptable_end - random_matched_span_start:
            random_matched_span_start = random.randint(0, len(article))
        random_matched_span = (random_matched_span_start, random_matched_span_start + end - start)
        if random.random() > 0.5:
            tgt = torch.cat([tgt[start:end], tgt[random_matched_span[0]:random_matched_span[1]]]).contiguous()
            article = torch.cat([article[start:end], article[random_matched_span[0]:random_matched_span[1]]]).contiguous()
        else:
            tgt = torch.cat([tgt[random_matched_span[0]:random_matched_span[1]], tgt[start:end]]).contiguous()
            article = torch.cat([article[random_matched_span[0]:random_matched_span[1]], article[start:end]]).contiguous()
        tgt /= torch.sum(tgt)
        target.append(tgt)
        articles.append(article)

    Is, Cs, Os = [PaddedSequence.autopad([torch.LongTensor(elem) for elem in cond], batch_first=True, padding_value=unk_idx) for cond in [Is, Cs, Os]]
    target = PaddedSequence.autopad(target, batch_first=True, padding_value=0)
    articles = PaddedSequence.autopad(articles, batch_first=True, padding_value=unk_idx)
    if cuda:
        articles, Is, Cs, Os, target = articles.cuda(), Is.cuda(), Cs.cuda(), Os.cuda(), target.cuda()
    return articles, Is, Cs, Os, target
Example #11
 def forward(self, query: List[torch.Tensor],
             document_batch: List[torch.Tensor]):
     assert len(query) == len(document_batch)
     # note about device management:
     # since distributed training is enabled, the inputs to this module can be on *any* device (preferably cpu, since we wrap and unwrap the module)
     # we want to keep these params on the input device (assuming CPU) for as long as possible for cheap memory access
     target_device = next(self.parameters()).device
     cls_token = torch.tensor([self.cls_token_id])  # .to(device=document_batch[0].device)
     sep_token = torch.tensor([self.sep_token_id])  # .to(device=document_batch[0].device)
     input_tensors = []
     position_ids = []
     for q, d in zip(query, document_batch):
         if len(q) + len(d) + 2 > self.max_length:
             d = d[:(self.max_length - len(q) - 2)]
         input_tensors.append(
             torch.cat([cls_token, q, sep_token,
                        d.to(dtype=q.dtype)]))
         position_ids.append(
             torch.arange(0, input_tensors[-1].size().numel()))
         #position_ids.append(torch.tensor(list(range(0, len(q) + 1)) + list(range(0, len(d) + 1))))
     bert_input = PaddedSequence.autopad(input_tensors,
                                         batch_first=True,
                                         padding_value=self.pad_token_id,
                                         device=target_device)
     positions = PaddedSequence.autopad(position_ids,
                                        batch_first=True,
                                        padding_value=0,
                                        device=target_device)
     (classes, ) = self.bert(bert_input.data,
                             attention_mask=bert_input.mask(
                                 on=1.0,
                                 off=0.0,
                                 dtype=torch.float,
                                 device=target_device),
                             position_ids=positions.data)
     assert torch.all(classes == classes)  # for nans
     return classes
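A toy sketch of the truncation-and-packing arithmetic above (token ids are made up; in the real module the special-token ids come from the tokenizer):

import torch

max_length = 10
q = torch.tensor([7, 8, 9])                  # toy query token ids
d = torch.tensor(list(range(100, 120)))      # toy document token ids (20 tokens)

if len(q) + len(d) + 2 > max_length:
    d = d[: max_length - len(q) - 2]         # leave room for [CLS] and [SEP]

cls_token, sep_token = torch.tensor([1]), torch.tensor([2])   # hypothetical special ids
wordpiece_input = torch.cat([cls_token, q, sep_token, d.to(dtype=q.dtype)])
assert wordpiece_input.numel() <= max_length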
Example #12
    def _concat_relay_to_tokens_in_batches(self, article_token_batches,
                                           relay_batches, batch_sizes):
        '''
        Takes <batch x doc_len x embedding> tensor (article_token_batches) and builds and returns
        a version <batch x doc_len x [embedding + relay_embedding]> which concatenates repeated
        copies of the relay embedding associated with each batch.
        '''

        # create an empty <batch x doc_len x (token_embedding + relay_embedding)> tensor
        article_tokens_with_relays = torch.zeros(
            article_token_batches.data.shape[0],
            article_token_batches.data.shape[1],
            article_token_batches.data.shape[2] + relay_batches.shape[1])

        for b in range(article_token_batches.data.shape[0]):
            batch_relay = relay_batches[b].repeat(
                article_tokens_with_relays.shape[1], 1)
            article_tokens_with_relays[b] = torch.cat(
                (article_token_batches.data[b], batch_relay), 1)

        return PaddedSequence(article_tokens_with_relays.to("cuda"),
                              batch_sizes,
                              batch_first=True)
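A toy sketch of the shape bookkeeping in _concat_relay_to_tokens_in_batches (dimensions are arbitrary):

import torch

batch, doc_len, d_model, d_relay = 2, 5, 4, 4
tokens = torch.randn(batch, doc_len, d_model)    # per-token encodings
relay = torch.randn(batch, d_relay)              # one relay vector per article

with_relay = torch.zeros(batch, doc_len, d_model + d_relay)
for b in range(batch):
    repeated = relay[b].repeat(doc_len, 1)                     # <doc_len x d_relay>
    with_relay[b] = torch.cat((tokens[b], repeated), dim=1)    # <doc_len x (d_model + d_relay)>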
Example #13
def train(ev_inf: InferenceNet, train_Xy, val_Xy, test_Xy, inference_vectorizer, epochs=10, batch_size=16, shuffle=True):
    # we sort these so batches all have approximately the same length (ish), which decreases the
    # average amount of padding needed, and thus the amount of wasted computation during training.
    if not shuffle:
        train_Xy.sort(key=lambda x: len(x['article']))
        val_Xy.sort(key=lambda x: len(x['article']))
        test_Xy.sort(key=lambda x: len(x['article']))
    print("Using {} training examples, {} validation examples, {} testing examples".format(len(train_Xy), len(val_Xy), len(test_Xy)))
    most_common = stats.mode([_get_majority_label(inst) for inst in train_Xy])[0][0]

    best_val_model = None
    best_val_f1 = float('-inf')
    if USE_CUDA:
        ev_inf = ev_inf.cuda()

    optimizer = optim.Adam(ev_inf.parameters())
    criterion = nn.CrossEntropyLoss(reduction='sum')  # sum (not average) of the batch losses.

    # TODO add epoch timing information here
    epochs_since_improvement = 0
    val_metrics = {
        "val_acc": [],
        "val_p": [],
        "val_r": [],
        "val_f1": [],
        "val_loss": [],
        'train_loss': [],
        'val_aucs': [],
        'train_aucs': [],
        'val_entropies': [],
        'val_evidence_token_mass': [],
        'val_evidence_token_err': [],
        'train_entropies': [],
        'train_evidence_token_mass': [],
        'train_evidence_token_err': []
    }
    for epoch in range(epochs):
        if epochs_since_improvement > 10:
            print("Exiting early due to no improvement on validation after 10 epochs.")
            break
        if shuffle:
            random.shuffle(train_Xy)

        epoch_loss = 0
        for i in range(0, len(train_Xy), batch_size):
            instances = train_Xy[i:i+batch_size]
            ys = torch.cat([_get_y_vec(inst['y'], as_vec=False) for inst in instances], dim=0)
            # TODO explain the use of padding here
            unk_idx = int(inference_vectorizer.str_to_idx[SimpleInferenceVectorizer.PAD])
            articles, Is, Cs, Os = [PaddedSequence.autopad([torch.LongTensor(inst[x]) for inst in instances], batch_first=True, padding_value=unk_idx) for x in ['article', 'I', 'C', 'O']]
            optimizer.zero_grad()
            if USE_CUDA:
                articles, Is, Cs, Os = articles.cuda(), Is.cuda(), Cs.cuda(), Os.cuda()
                ys = ys.cuda()
            verbose_attn = (epoch == epochs - 1 and i == 0) or (epoch == 0 and i == 0)
            if verbose_attn:
                print("Training attentions:")
            tags = ev_inf(articles, Is, Cs, Os, batch_size=len(instances), verbose_attn=verbose_attn)
            loss = criterion(tags, ys)
            #if loss.item() != loss.item():
            #    import pdb; pdb.set_trace()
            epoch_loss += loss.item()
            loss.backward()
            optimizer.step()
        val_metrics['train_loss'].append(epoch_loss)

        with torch.no_grad():
            verbose_attn_to_batches = set([0,1,2,3,4]) if epoch == epochs - 1 or epoch == 0 else False
            if verbose_attn_to_batches:
                print("Validation attention:")
            # make_preds runs in eval mode
            val_y, val_y_hat = make_preds(ev_inf, val_Xy, batch_size, inference_vectorizer, verbose_attn_to_batches=verbose_attn_to_batches)
            val_loss = criterion(val_y_hat, val_y.squeeze())
            y_hat = to_int_preds(val_y_hat)

            if epoch == 0:
                dummy_preds = [most_common] * len(val_y)
                dummy_acc = accuracy_score(val_y.cpu(), dummy_preds)
                val_metrics["baseline_val_acc"] = dummy_acc
                p, r, f1, _ = precision_recall_fscore_support(val_y.cpu(), dummy_preds, labels=None, beta=1, average='macro', pos_label=1, warn_for=('f-score',), sample_weight=None)
                val_metrics['p_dummy'] = p
                val_metrics['r_dummy'] = r
                val_metrics['f_dummy'] = f1

                print("val dummy accuracy: {:.3f}".format(dummy_acc))
                print("classification report for dummy on val: ")
                print(classification_report(val_y.cpu(), dummy_preds))
                print("\n\n")

            acc = accuracy_score(val_y.cpu(), y_hat)
            val_metrics["val_acc"].append(acc)
            val_loss = val_loss.cpu().item()
            val_metrics["val_loss"].append(val_loss)
           
            # f1 = f1_score(val_y, y_hat, average="macro")
            p, r, f1, _ = precision_recall_fscore_support(val_y.cpu(), y_hat, labels=None, beta=1, average='macro', pos_label=1, warn_for=('f-score',), sample_weight=None)
            val_metrics["val_f1"].append(f1)
            val_metrics["val_p"].append(p)
            val_metrics["val_r"].append(r)

            if ev_inf.article_encoder.use_attention:
                train_auc, train_entropies, train_evidence_token_masses, train_evidence_token_err = evaluate_model_attention_distribution(ev_inf, train_Xy, cuda=USE_CUDA, compute_attention_diagnostics=True)
                val_auc, val_entropies, val_evidence_token_masses, val_evidence_token_err = evaluate_model_attention_distribution(ev_inf, val_Xy, cuda=USE_CUDA, compute_attention_diagnostics=True)
                print("train auc: {:.3f}, entropy: {:.3f}, evidence mass: {:.3f}, err: {:.3f}".format(train_auc, train_entropies, train_evidence_token_masses, train_evidence_token_err))
                print("val auc: {:.3f}, entropy: {:.3f}, evidence mass: {:.3f}, err: {:.3f}".format(val_auc, val_entropies, val_evidence_token_masses, val_evidence_token_err))
            else:
                train_auc, train_entropies, train_evidence_token_masses, train_evidence_token_err = "", "", "", ""
                val_auc, val_entropies, val_evidence_token_masses, val_evidence_token_err = "", "", "", ""
            val_metrics['train_aucs'].append(train_auc)
            val_metrics['train_entropies'].append(train_entropies)
            val_metrics['train_evidence_token_mass'].append(train_evidence_token_masses)
            val_metrics['train_evidence_token_err'].append(train_evidence_token_err)
            val_metrics['val_aucs'].append(val_auc)
            val_metrics['val_entropies'].append(val_entropies)
            val_metrics['val_evidence_token_mass'].append(val_evidence_token_masses)
            val_metrics['val_evidence_token_err'].append(val_evidence_token_err)
            if f1 > best_val_f1:
                print("New best model at {} with val f1 {:.3f}".format(epoch, f1))
                best_val_f1 = f1
                best_val_model = copy.deepcopy(ev_inf)
                epochs_since_improvement = 0
            else:
                epochs_since_improvement += 1

            #if val_loss != val_loss or epoch_loss != epoch_loss:
            #    import pdb; pdb.set_trace()

            print("epoch {}. train loss: {}; val loss: {}; val acc: {:.3f}".format(
                epoch, epoch_loss, val_loss, acc))
       
            print(classification_report(val_y.cpu(), y_hat))
            print("val macro f1: {0:.3f}".format(f1))
            print("\n\n")

    val_metrics['best_val_f1'] = best_val_f1
    with torch.no_grad():
        print("Test attentions:")
        verbose_attn_to_batches = set([0,1,2,3,4])
        # make_preds runs in eval mode
        test_y, test_y_hat = make_preds(best_val_model, test_Xy, batch_size, inference_vectorizer, verbose_attn_to_batches=verbose_attn_to_batches)
        test_loss = criterion(test_y_hat, test_y.squeeze())
        y_hat = to_int_preds(test_y_hat)
        final_test_preds = zip([t['a_id'] for t in test_Xy], [t['p_id'] for t in test_Xy], y_hat)

        acc = accuracy_score(test_y.cpu(), y_hat)
        val_metrics["test_acc"] = acc
        test_loss = test_loss.cpu().item()
        val_metrics["test_loss"] = test_loss

        # f1 = f1_score(test_y, y_hat, average="macro")
        p, r, f1, _ = precision_recall_fscore_support(test_y.cpu(), y_hat, labels=None, beta=1, average='macro', pos_label=1, warn_for=('f-score',), sample_weight=None)
        val_metrics["test_f1"] = f1
        val_metrics["test_p"] = p
        val_metrics["test_r"] = r
        if ev_inf.article_encoder.use_attention:
            test_auc, test_entropies, test_evidence_token_masses, test_evidence_token_err = evaluate_model_attention_distribution(best_val_model, test_Xy, cuda=USE_CUDA, compute_attention_diagnostics=True)
            print("test auc: {:.3f}, , entropy: {:.3f}, kl_to_uniform {:.3f}".format(test_auc, test_entropies, test_evidence_token_masses))
        else:
            test_auc, test_entropies, test_evidence_token_masses, test_evidence_token_err = "", "", "", ""
        val_metrics['test_auc'] = test_auc
        val_metrics['test_entropy'] = test_entropies
        val_metrics['test_evidence_token_mass'] = test_evidence_token_masses
        val_metrics['test_evidence_token_err'] = test_evidence_token_err

        print("test loss: {}; test acc: {:.3f}".format(test_loss, acc))

        print(classification_report(test_y.cpu(), y_hat))
        print("test macro f1: {}".format(f1))
        print("\n\n")

    return best_val_model, inference_vectorizer, train_Xy, val_Xy, val_metrics, final_test_preds
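A tiny illustration of the length-sorting trick mentioned at the top of train: sorting keeps each batch's articles close in length, so PaddedSequence.autopad adds little padding per batch (instances here are made up):

toy_Xy = [{'article': [0] * n} for n in (50, 3, 17, 200, 9)]
toy_Xy.sort(key=lambda x: len(x['article']))
print([len(x['article']) for x in toy_Xy])  # [3, 9, 17, 50, 200]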
Example #14
    def forward(self,
                article_tokens: PaddedSequence,
                indices,
                I_tokens: PaddedSequence,
                C_tokens: PaddedSequence,
                O_tokens: PaddedSequence,
                batch_size,
                h_dropout_rate=0.2,
                recursive_encoding={}):

        inner_batch = 1  # this is over sections!

        ### Run our encode function ###
        I_v, C_v, O_v = self._encode(I_tokens, C_tokens, O_tokens)

        query_v, old_query_v = None, None
        ### Run normal attention over the data ###
        if self.article_encoder.condition_attention:
            query_v = torch.cat([I_v, C_v, O_v], dim=1)
            old_query_v = copy.deepcopy(query_v)

        #if self.use_attention_over_article_tokens:
        cmb_hidden = []
        ### encode each section with the article encoder ###
        for i in range(0, len(article_tokens[0]), inner_batch):
            tokens = article_tokens[0][i:i + inner_batch]
            new_tkn = PaddedSequence.autopad(tokens, batch_first=True)
            if query_v is not None:
                query_v = torch.cat([
                    old_query_v for _ in range(min(len(tokens), inner_batch))
                ],
                                    dim=0)
            #_, hidden, _ = self.article_encoder(new_tkn, query_v_for_attention=query_v)

            if self.article_encoder in ("transformer", "CBoW"):
                hidden = self.article_encoder(new_tkn,
                                              query_v_for_attention=query_v)
            else:
                # assume RNN
                _, hidden = self.article_encoder(new_tkn,
                                                 query_v_for_attention=query_v)

            cmb_hidden.append(hidden)

        hidden = torch.cat(cmb_hidden, dim=0)

        #else:
        #    if self.article_encoder in ("Transformer", "CBoW"):
        #
        #        hidden = self.article_encoder(article_tokens, query_v_for_attention=query_v)
        #    else:
        # assume RNN
        #        _, hidden = self.article_encoder(article_tokens, query_v_for_attention=query_v)

        art_secs = []
        token_secs = []
        i = 0
        ### Reshape our tokens + article representations. ###
        for idx in indices:
            art_secs.append(hidden[i:i + idx])
            token_secs.append(article_tokens[i:i + idx])
            i += idx

        hidden_articles = art_secs
        batch_a_v = None
        section_weights = []

        for i in range(batch_size):
            hidden_art = hidden_articles[i]  # single hidden article
            token_art = token_secs[i]  # single article tokens

            if self.condition_attention:
                query_v = torch.cat(
                    [old_query_v for _ in range(len(hidden_art))], dim=0)

            ### Run section attention over the data for each section ###
            a = self.section_attn(token_art,
                                  hidden_input_states=hidden_art,
                                  query_v_for_attention=query_v,
                                  normalize=True)

            section_weights.append(a)

            if self.recursive_encoding:
                section_splits = recursive_encoding['section_splits']
                new_articles = []
                last = 0

                ### -> Reweight sections based on subsection:
                # [Alpha(S1.1), Alpha(S1.2)] * [Encoding of S1.1, Encoding of S1.2])
                for s in section_splits:
                    section_encoding = hidden_art[last:last + s]
                    ws = a[last:last + s]
                    new_articles.append(
                        torch.mm(torch.transpose(ws, dim0=1, dim1=0),
                                 section_encoding))

                ### -> another attention layer (share it)
                new_tokens = recursive_encoding['big_sections']
                hidden_art = torch.cat(new_articles, dim=0).unsqueeze(0)

                new_query_v = torch.cat(
                    [old_query_v for _ in range(hidden_art.shape[1])], dim=0)
                a = self.section_attn(new_tokens,
                                      hidden_input_states=hidden_art,
                                      query_v_for_attention=new_query_v,
                                      normalize=True)

            ### Combine the re-weighted sections ###
            weighted = (a * hidden_art).squeeze().unsqueeze(0)
            weighted_hidden = torch.sum(weighted, dim=1)
            article_v = torch.sum(weighted_hidden, dim=0)

            if batch_a_v is None:
                batch_a_v = article_v
            else:
                batch_a_v = torch.stack([batch_a_v, article_v])  # per batch

        ### Finish Plugging in ###
        if len(batch_a_v.shape) == 1:
            batch_a_v = batch_a_v.unsqueeze(0)

        h = torch.cat([batch_a_v, I_v, C_v, O_v], dim=1)
        raw_out = self.out(self.MLP_hidden(h))

        return F.softmax(raw_out, dim=1), section_weights
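A toy sketch of the attention-weighted section pooling used in the recursive-encoding branch above (shapes assumed; the weights would normally come from section_attn):

import torch

n_sections, d = 3, 4
section_encodings = torch.randn(n_sections, d)           # one vector per subsection
a = torch.softmax(torch.randn(n_sections, 1), dim=0)     # toy attention weights over subsections

# <1 x n_sections> @ <n_sections x d> -> <1 x d> pooled representation for the enclosing section
pooled = torch.mm(torch.transpose(a, 0, 1), section_encodings)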