Example #1
    def _prepare_random_matched_spans(model, batch_instances, cuda):
        unk_idx = int(model.vectorizer.str_to_idx[SimpleInferenceVectorizer.PAD])
        Is, Cs, Os = [PaddedSequence.autopad([torch.LongTensor(inst[x]) for inst in batch_instances], batch_first=True, padding_value=unk_idx) for x in ['I', 'C', 'O']]
        target_spans = [inst['evidence_spans'] for inst in batch_instances]
        target = []
        articles = []
        for article, evidence_spans in zip((x['article'] for x in batch_instances), target_spans):
            tgt = torch.zeros(len(article))
            for start, end in evidence_spans:
                tgt[start:end] = 1
            (start, end) = random.choice(evidence_spans)
            # select a random span of the same length, keeping it fully inside the
            # article so the negative span really has as many tokens as the evidence span
            random_matched_span_start = random.randint(0, max(0, len(article) - (end - start)))
            random_matched_span_end = random_matched_span_start + end - start
            tgt_pos = tgt[start:end]
            tgt_neg = tgt[random_matched_span_start:random_matched_span_end]
            article_pos = torch.LongTensor(article[start:end])
            article_neg = torch.LongTensor(article[random_matched_span_start:random_matched_span_end])
            if random.random() > 0.5:
                articles.append(torch.cat([article_pos, article_neg]))
                target.append(torch.cat([tgt_pos, tgt_neg]))
            else:
                articles.append(torch.cat([article_neg, article_pos]))
                target.append(torch.cat([tgt_neg, tgt_pos]))

        target = PaddedSequence.autopad(target, batch_first=True, padding_value=0)
        articles = PaddedSequence.autopad(articles, batch_first=True, padding_value=unk_idx)
        if cuda:
            articles, Is, Cs, Os, target = articles.cuda(), Is.cuda(), Cs.cuda(), Os.cuda(), target.cuda()
        return articles, Is, Cs, Os, target
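
All of these helpers lean on PaddedSequence.autopad to turn a list of variable-length LongTensors into one padded batch. The real class comes from the surrounding project and also tracks per-example lengths and exposes .data, .mask() and .cuda(); the stand-in below is only a minimal sketch of the padding contract, built on torch.nn.utils.rnn.pad_sequence.

import torch
from torch.nn.utils.rnn import pad_sequence

def autopad_sketch(tensors, batch_first=True, padding_value=0):
    # Pad a list of 1-D LongTensors to the length of the longest one.
    return pad_sequence(tensors, batch_first=batch_first, padding_value=padding_value)

batch = [torch.LongTensor([3, 1, 4]), torch.LongTensor([1, 5])]
print(autopad_sketch(batch, padding_value=0))
# tensor([[3, 1, 4],
#         [1, 5, 0]])
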
Example #2
def prepare_article_attention_target_balanced(model, batch_instances, cuda):
    unk_idx = int(model.vectorizer.str_to_idx[SimpleInferenceVectorizer.PAD])
    Is = []
    Cs = []
    Os = []
    articles = []
    target = []
    for inst in batch_instances:
        i = torch.LongTensor(inst['I'])
        c = torch.LongTensor(inst['C'])
        o = torch.LongTensor(inst['O'])
        article = torch.LongTensor(inst['article'])
        target_spans = set([tuple(x) for x in inst['evidence_spans']])
        for start, end in target_spans:
            # positive example
            Is.append(i)
            Cs.append(c)
            Os.append(o)
            articles.append(article[start:end])
            target.append(torch.ones(end - start))

            # negative example
            neg_start, neg_end = _fetch_random_span(start, end, len(article), end - start)
            Is.append(i)
            Cs.append(c)
            Os.append(o)
            articles.append(article[neg_start:neg_end])
            target.append(torch.zeros(neg_end - neg_start))

    Is, Cs, Os, articles = [PaddedSequence.autopad(x, batch_first=True, padding_value=unk_idx) for x in [Is, Cs, Os, articles]]
    target = PaddedSequence.autopad(target, batch_first=True, padding_value=0)
    if cuda:
        articles, Is, Cs, Os, target = articles.cuda(), Is.cuda(), Cs.cuda(), Os.cuda(), target.cuda()
    return articles, Is, Cs, Os, target
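
_fetch_random_span is called above but not shown. Judging only from the call site (positive span bounds, article length, desired span length), a plausible, purely hypothetical sketch is a rejection sampler that draws an equally long span avoiding the evidence span:

import random

def _fetch_random_span_sketch(start, end, doc_len, span_len):
    # Hypothetical: resample a start until [neg_start, neg_end) avoids [start, end).
    # (Assumes the document is long enough for such a span to exist.)
    while True:
        neg_start = random.randint(0, max(0, doc_len - span_len))
        neg_end = neg_start + span_len
        if neg_end <= start or neg_start >= end:
            return neg_start, neg_end
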
Example #3
def split_sections(instances, inference_vectorizer, big_sections=False):
    """ Split into sections. If big_sections = False, use subsections, else use big sections. """
    unk_idx = int(
        inference_vectorizer.str_to_idx[SimpleInferenceVectorizer.PAD])
    Is, Cs, Os = [
        PaddedSequence.autopad(
            [torch.LongTensor(inst[x]) for inst in instances],
            batch_first=True,
            padding_value=unk_idx) for x in ['I', 'C', 'O']
    ]
    indices = []
    sections = []
    section_labels = []  # one 0/1 evidence label per section, aligned with `sections` across all instances
    section_titles = []
    for i in range(len(instances)):
        info = instances[i]
        if big_sections:
            info = gen_big_sections(info)

        ss = info['section_splits']
        art = info['article']
        evidence_labels = info['evidence_spans']
        section_titles.append(info['section_titles'])
        start = 0
        new_added = 0

        for s in ss:
            tmp = art[start:start + s]  # tokens belonging to this section
            is_evid = False
            for labels in evidence_labels:
                is_evid = is_evid or interval_overlap([start, start + s],
                                                      labels)

            if is_evid:
                section_labels.append(1)
            else:
                section_labels.append(0)

            if len(tmp) == 0:
                tmp = [unk_idx]
            sections.append(tmp)
            start += s
            new_added += 1

        indices.append(new_added)

    # cap number of sections...
    section_tensors = [torch.LongTensor(sec) for sec in sections]
    pad_sections = PaddedSequence.autopad(section_tensors,
                                          batch_first=True,
                                          padding_value=unk_idx)
    return pad_sections, indices, section_labels, section_titles, Is, Cs, Os
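
interval_overlap is assumed here rather than shown. A minimal sketch of what such a check presumably does, testing whether two [start, end) ranges intersect:

def interval_overlap_sketch(a, b):
    # True when the half-open ranges a = [a0, a1) and b = [b0, b1) intersect.
    return a[0] < b[1] and b[0] < a[1]

print(interval_overlap_sketch([0, 5], [4, 10]))  # True
print(interval_overlap_sketch([0, 5], [5, 10]))  # False
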
Example #4
def prepare_article_attention_target(model, batch_instances, cuda):
    unk_idx = int(model.vectorizer.str_to_idx[SimpleInferenceVectorizer.PAD])
    articles, Is, Cs, Os = [PaddedSequence.autopad([torch.LongTensor(inst[x]) for inst in batch_instances], batch_first=True, padding_value=unk_idx) for x in ['article', 'I', 'C', 'O']]
    target_spans = [inst['evidence_spans'] for inst in batch_instances]
    target = [torch.zeros(len(x['article'])) for x in batch_instances]
    for tgt, spans in zip(target, target_spans):
        for start, end in spans:
            tgt[start:end] = 1
    target = PaddedSequence.autopad(target, batch_first=True, padding_value=0)
    if cuda:
        articles, Is, Cs, Os, target = articles.cuda(), Is.cuda(), Cs.cuda(), Os.cuda(), target.cuda()
    return articles, Is, Cs, Os, target
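
The target built above is just a per-token 0/1 vector with ones over the evidence spans, padded the same way as the articles. The same step in isolation, with made-up span positions:

import torch

tgt = torch.zeros(10)
for start, end in [(2, 5), (7, 9)]:
    tgt[start:end] = 1
print(tgt)  # tensor([0., 0., 1., 1., 1., 0., 0., 1., 1., 0.])
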
Example #5
def make_preds(nnet,
               instances,
               batch_size,
               inference_vectorizer,
               verbose_attn_to_batches=False,
               cuda=USE_CUDA):
    # TODO consider removing the inference_vectorizer since all we need is an unk_idx from it
    y_vec = torch.cat(
        [_get_y_vec(inst['y'], as_vec=False) for inst in instances]).squeeze()
    unk_idx = int(
        inference_vectorizer.str_to_idx[SimpleInferenceVectorizer.PAD])
    y_hat_vec = []
    # we batch this so the GPU doesn't run out of memory
    nnet.eval()
    for i in range(0, len(instances), batch_size):
        batch_instances = instances[i:i + batch_size]
        articles, Is, Cs, Os = [
            PaddedSequence.autopad(
                [torch.LongTensor(inst[x]) for inst in batch_instances],
                batch_first=True,
                padding_value=unk_idx) for x in ['article', 'I', 'C', 'O']
        ]
        if cuda:
            articles, Is, Cs, Os = articles.cuda(), Is.cuda(), Cs.cuda(), Os.cuda()
        verbose_attn = verbose_attn_to_batches and i in verbose_attn_to_batches
        y_hat_batch = nnet(articles,
                           Is,
                           Cs,
                           Os,
                           batch_size=len(batch_instances),
                           verbose_attn=verbose_attn)
        y_hat_vec.append(y_hat_batch)
    nnet.train()
    return y_vec, torch.cat(y_hat_vec, dim=0)
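
The batching in this function exists purely to keep GPU memory bounded: the dataset is walked in fixed-size slices and the per-batch predictions are concatenated at the end. The slicing idiom on its own:

def batches(items, batch_size):
    # Yield consecutive slices of at most batch_size items.
    for i in range(0, len(items), batch_size):
        yield items[i:i + batch_size]

print(list(batches(list(range(7)), 3)))  # [[0, 1, 2], [3, 4, 5], [6]]
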
Example #6
def _prepare_random_concatenated_spans(model, batch_instances, cuda):
    unk_idx = int(model.vectorizer.str_to_idx[SimpleInferenceVectorizer.PAD])
    target_spans = [inst['evidence_spans'] for inst in batch_instances]
    Is = []
    Os = []
    Cs = []
    target = []
    articles = []
    for instance, evidence_spans in zip(batch_instances, target_spans):
        article = instance['article']
        article = torch.LongTensor(article)
        tgt = torch.zeros(len(article))
        Is.append(instance['I'])
        Os.append(instance['O'])
        Cs.append(instance['C'])
        for start, end in evidence_spans:
            tgt[start:end] = 1
        start, end = random.choice(evidence_spans)
        unacceptable_start = start - (end - start)
        unacceptable_end = end + (end - start)
        random_matched_span_start = random.randint(0, len(article))
        # rejection sample: redraw while the candidate start falls strictly inside the widened
        # window (start - span_len, end + span_len), so the random span cannot overlap the
        # chosen evidence span (the start may still land near or past the end of the document,
        # in which case the resulting slice is simply shorter)
        while unacceptable_start < random_matched_span_start < unacceptable_end:
            random_matched_span_start = random.randint(0, len(article))
        random_matched_span = (random_matched_span_start, random_matched_span_start + end - start)
        if random.random() > 0.5:
            tgt = torch.cat([tgt[start:end], tgt[random_matched_span[0]:random_matched_span[1]]]).contiguous()
            article = torch.cat([article[start:end], article[random_matched_span[0]:random_matched_span[1]]]).contiguous()
        else:
            tgt = torch.cat([tgt[random_matched_span[0]:random_matched_span[1]], tgt[start:end]]).contiguous()
            article = torch.cat([article[random_matched_span[0]:random_matched_span[1]], article[start:end]]).contiguous()
        tgt /= torch.sum(tgt)
        target.append(tgt)
        articles.append(article)

    Is, Cs, Os = [PaddedSequence.autopad([torch.LongTensor(elem) for elem in cond], batch_first=True, padding_value=unk_idx) for cond in [Is, Cs, Os]]
    target = PaddedSequence.autopad(target, batch_first=True, padding_value=0)
    articles = PaddedSequence.autopad(articles, batch_first=True, padding_value=unk_idx)
    if cuda:
        articles, Is, Cs, Os, target = articles.cuda(), Is.cuda(), Cs.cuda(), Os.cuda(), target.cuda()
    return articles, Is, Cs, Os, target
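
The rejection loop above resamples while the candidate start lies strictly inside the widened window (start - L, end + L), where L = end - start; any accepted start yields a span that cannot overlap the chosen evidence span (the window is deliberately a bit wider than overlap alone requires). The same predicate restated on its own, with illustrative numbers:

def is_acceptable_start(s, start, end):
    # Accept s only if it falls outside the widened window around [start, end).
    span_len = end - start
    return s <= start - span_len or s >= end + span_len

print(is_acceptable_start(0, start=10, end=14))  # True
print(is_acceptable_start(9, start=10, end=14))  # False (inside the window (6, 18))
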
Example #7
    def forward(self, query: List[torch.tensor],
                document_batch: List[torch.tensor]):
        assert len(query) == len(document_batch)
        # note about device management:
        # since distributed training is enabled, the inputs to this module can be on *any* device
        # (preferably cpu, since we wrap and unwrap the module); we want to keep these params on
        # the input device (assuming CPU) for as long as possible for cheap memory access
        target_device = next(self.parameters()).device
        cls_token = torch.tensor([self.cls_token_id])  # .to(device=document_batch[0].device)
        sep_token = torch.tensor([self.sep_token_id])  # .to(device=document_batch[0].device)
        input_tensors = []
        position_ids = []
        for q, d in zip(query, document_batch):
            if len(q) + len(d) + 2 > self.max_length:
                d = d[:(self.max_length - len(q) - 2)]
            input_tensors.append(torch.cat([cls_token, q, sep_token, d.to(dtype=q.dtype)]))
            position_ids.append(torch.arange(0, input_tensors[-1].size().numel()))
            # position_ids.append(torch.tensor(list(range(0, len(q) + 1)) + list(range(0, len(d) + 1))))
        bert_input = PaddedSequence.autopad(input_tensors,
                                            batch_first=True,
                                            padding_value=self.pad_token_id,
                                            device=target_device)
        positions = PaddedSequence.autopad(position_ids,
                                           batch_first=True,
                                           padding_value=0,
                                           device=target_device)
        (classes,) = self.bert(bert_input.data,
                               attention_mask=bert_input.mask(on=1.0,
                                                              off=0.0,
                                                              dtype=torch.float,
                                                              device=target_device),
                               position_ids=positions.data)
        assert torch.all(classes == classes)  # check for NaNs
        return classes
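
The per-pair input built above is [CLS] + query + [SEP] + document, with the document truncated so the whole sequence fits in max_length. The same step in isolation, with made-up token ids:

import torch

def build_input(q, d, cls_id, sep_id, max_length):
    # Truncate the document, not the query, to respect the length budget.
    if len(q) + len(d) + 2 > max_length:
        d = d[:max_length - len(q) - 2]
    return torch.cat([torch.tensor([cls_id]), q, torch.tensor([sep_id]), d.to(dtype=q.dtype)])

q = torch.tensor([7, 8])
d = torch.tensor([1, 2, 3, 4, 5, 6])
print(build_input(q, d, cls_id=101, sep_id=102, max_length=8))
# tensor([101,   7,   8, 102,   1,   2,   3,   4])
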
Example #8
def train(ev_inf: InferenceNet, train_Xy, val_Xy, test_Xy, inference_vectorizer, epochs=10, batch_size=16, shuffle=True):
    # we sort these so batches all have approximately the same length (ish), which decreases the
    # average amount of padding needed, and thus the computation wasted on pad tokens during training.
    if not shuffle:
        train_Xy.sort(key=lambda x: len(x['article']))
        val_Xy.sort(key=lambda x: len(x['article']))
        test_Xy.sort(key=lambda x: len(x['article']))
    print("Using {} training examples, {} validation examples, {} testing examples".format(len(train_Xy), len(val_Xy), len(test_Xy)))
    most_common = stats.mode([_get_majority_label(inst) for inst in train_Xy])[0][0]

    best_val_model = None
    best_val_f1 = float('-inf')
    if USE_CUDA:
        ev_inf = ev_inf.cuda()

    optimizer = optim.Adam(ev_inf.parameters())
    criterion = nn.CrossEntropyLoss(reduction='sum')  # sum (not average) of the batch losses.

    # TODO add epoch timing information here
    epochs_since_improvement = 0
    val_metrics = {
        "val_acc": [],
        "val_p": [],
        "val_r": [],
        "val_f1": [],
        "val_loss": [],
        'train_loss': [],
        'val_aucs': [],
        'train_aucs': [],
        'val_entropies': [],
        'val_evidence_token_mass': [],
        'val_evidence_token_err': [],
        'train_entropies': [],
        'train_evidence_token_mass': [],
        'train_evidence_token_err': []
    }
    for epoch in range(epochs):
        if epochs_since_improvement > 10:
            print("Exiting early due to no improvement on validation after 10 epochs.")
            break
        if shuffle:
            random.shuffle(train_Xy)

        epoch_loss = 0
        for i in range(0, len(train_Xy), batch_size):
            instances = train_Xy[i:i+batch_size]
            ys = torch.cat([_get_y_vec(inst['y'], as_vec=False) for inst in instances], dim=0)
            # pad every field in the batch to its longest example, using the PAD token index as the fill value
            unk_idx = int(inference_vectorizer.str_to_idx[SimpleInferenceVectorizer.PAD])
            articles, Is, Cs, Os = [PaddedSequence.autopad([torch.LongTensor(inst[x]) for inst in instances], batch_first=True, padding_value=unk_idx) for x in ['article', 'I', 'C', 'O']]
            optimizer.zero_grad()
            if USE_CUDA:
                articles, Is, Cs, Os = articles.cuda(), Is.cuda(), Cs.cuda(), Os.cuda()
                ys = ys.cuda()
            verbose_attn = (epoch == epochs - 1 and i == 0) or (epoch == 0 and i == 0)
            if verbose_attn:
                print("Training attentions:")
            tags = ev_inf(articles, Is, Cs, Os, batch_size=len(instances), verbose_attn=verbose_attn)
            loss = criterion(tags, ys)
            #if loss.item() != loss.item():
            #    import pdb; pdb.set_trace()
            epoch_loss += loss.item()
            loss.backward()
            optimizer.step()
        val_metrics['train_loss'].append(epoch_loss)

        with torch.no_grad():
            verbose_attn_to_batches = set([0,1,2,3,4]) if epoch == epochs - 1 or epoch == 0 else False
            if verbose_attn_to_batches:
                print("Validation attention:")
            # make_preds runs in eval mode
            val_y, val_y_hat = make_preds(ev_inf, val_Xy, batch_size, inference_vectorizer, verbose_attn_to_batches=verbose_attn_to_batches)
            val_loss = criterion(val_y_hat, val_y.squeeze())
            y_hat = to_int_preds(val_y_hat)

            if epoch == 0:
                dummy_preds = [most_common] * len(val_y)
                dummy_acc = accuracy_score(val_y.cpu(), dummy_preds)
                val_metrics["baseline_val_acc"] = dummy_acc
                p, r, f1, _ = precision_recall_fscore_support(val_y.cpu(), dummy_preds, labels=None, beta=1, average='macro', pos_label=1, warn_for=('f-score',), sample_weight=None)
                val_metrics['p_dummy'] = p
                val_metrics['r_dummy'] = r
                val_metrics['f_dummy'] = f1

                print("val dummy accuracy: {:.3f}".format(dummy_acc))
                print("classification report for dummy on val: ")
                print(classification_report(val_y.cpu(), dummy_preds))
                print("\n\n")

            acc = accuracy_score(val_y.cpu(), y_hat)
            val_metrics["val_acc"].append(acc)
            val_loss = val_loss.cpu().item()
            val_metrics["val_loss"].append(val_loss)
           
            # f1 = f1_score(val_y, y_hat, average="macro")
            p, r, f1, _ = precision_recall_fscore_support(val_y.cpu(), y_hat, labels=None, beta=1, average='macro', pos_label=1, warn_for=('f-score',), sample_weight=None)
            val_metrics["val_f1"].append(f1)
            val_metrics["val_p"].append(p)
            val_metrics["val_r"].append(r)

            if ev_inf.article_encoder.use_attention:
                train_auc, train_entropies, train_evidence_token_masses, train_evidence_token_err = evaluate_model_attention_distribution(ev_inf, train_Xy, cuda=USE_CUDA, compute_attention_diagnostics=True)
                val_auc, val_entropies, val_evidence_token_masses, val_evidence_token_err = evaluate_model_attention_distribution(ev_inf, val_Xy, cuda=USE_CUDA, compute_attention_diagnostics=True)
                print("train auc: {:.3f}, entropy: {:.3f}, evidence mass: {:.3f}, err: {:.3f}".format(train_auc, train_entropies, train_evidence_token_masses, train_evidence_token_err))
                print("val auc: {:.3f}, entropy: {:.3f}, evidence mass: {:.3f}, err: {:.3f}".format(val_auc, val_entropies, val_evidence_token_masses, val_evidence_token_err))
            else:
                train_auc, train_entropies, train_evidence_token_masses, train_evidence_token_err = "", "", "", ""
                val_auc, val_entropies, val_evidence_token_masses, val_evidence_token_err = "", "", "", ""
            val_metrics['train_aucs'].append(train_auc)
            val_metrics['train_entropies'].append(train_entropies)
            val_metrics['train_evidence_token_mass'].append(train_evidence_token_masses)
            val_metrics['train_evidence_token_err'].append(train_evidence_token_err)
            val_metrics['val_aucs'].append(val_auc)
            val_metrics['val_entropies'].append(val_entropies)
            val_metrics['val_evidence_token_mass'].append(val_evidence_token_masses)
            val_metrics['val_evidence_token_err'].append(val_evidence_token_err)
            if f1 > best_val_f1:
                print("New best model at {} with val f1 {:.3f}".format(epoch, f1))
                best_val_f1 = f1
                best_val_model = copy.deepcopy(ev_inf)
                epochs_since_improvement = 0
            else:
                epochs_since_improvement += 1

            #if val_loss != val_loss or epoch_loss != epoch_loss:
            #    import pdb; pdb.set_trace()

            print("epoch {}. train loss: {}; val loss: {}; val acc: {:.3f}".format(
                epoch, epoch_loss, val_loss, acc))
       
            print(classification_report(val_y.cpu(), y_hat))
            print("val macro f1: {0:.3f}".format(f1))
            print("\n\n")

    val_metrics['best_val_f1'] = best_val_f1
    with torch.no_grad():
        print("Test attentions:")
        verbose_attn_to_batches = set([0,1,2,3,4])
        # make_preds runs in eval mode
        test_y, test_y_hat = make_preds(best_val_model, test_Xy, batch_size, inference_vectorizer, verbose_attn_to_batches=verbose_attn_to_batches)
        test_loss = criterion(test_y_hat, test_y.squeeze())
        y_hat = to_int_preds(test_y_hat)
        final_test_preds = zip([t['a_id'] for t in test_Xy], [t['p_id'] for t in test_Xy], y_hat)

        acc = accuracy_score(test_y.cpu(), y_hat)
        val_metrics["test_acc"] = acc
        test_loss = test_loss.cpu().item()
        val_metrics["test_loss"] = test_loss

        # f1 = f1_score(test_y, y_hat, average="macro")
        p, r, f1, _ = precision_recall_fscore_support(test_y.cpu(), y_hat, labels=None, beta=1, average='macro', pos_label=1, warn_for=('f-score',), sample_weight=None)
        val_metrics["test_f1"] = f1
        val_metrics["test_p"] = p
        val_metrics["test_r"] = r
        if ev_inf.article_encoder.use_attention:
            test_auc, test_entropies, test_evidence_token_masses, test_evidence_token_err = evaluate_model_attention_distribution(best_val_model, test_Xy, cuda=USE_CUDA, compute_attention_diagnostics=True)
            print("test auc: {:.3f}, , entropy: {:.3f}, kl_to_uniform {:.3f}".format(test_auc, test_entropies, test_evidence_token_masses))
        else:
            test_auc, test_entropies, test_evidence_token_masses, test_evidence_token_err = "", "", "", ""
        val_metrics['test_auc'] = test_auc
        val_metrics['test_entropy'] = test_entropies
        val_metrics['test_evidence_token_mass'] = test_evidence_token_masses
        val_metrics['test_evidence_token_err'] = test_evidence_token_err

        print("test loss: {}; test acc: {:.3f}".format(test_loss, acc))

        print(classification_report(test_y.cpu(), y_hat))
        print("test macro f1: {}".format(f1))
        print("\n\n")

    return best_val_model, inference_vectorizer, train_Xy, val_Xy, val_metrics, final_test_preds
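
to_int_preds is used above but not shown; given that its output is fed straight into sklearn's accuracy and classification-report functions, a plausible, hypothetical sketch is an argmax over the class dimension:

import torch

def to_int_preds_sketch(scores):
    # Hypothetical: collapse per-class scores to integer label predictions.
    return torch.argmax(scores, dim=1).cpu().numpy().tolist()

print(to_int_preds_sketch(torch.tensor([[0.1, 0.9], [0.8, 0.2]])))  # [1, 0]
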
Example #9
    def forward(self,
                article_tokens: PaddedSequence,
                indices,
                I_tokens: PaddedSequence,
                C_tokens: PaddedSequence,
                O_tokens: PaddedSequence,
                batch_size,
                h_dropout_rate=0.2,
                recursive_encoding={}):

        inner_batch = 1  # this is over sections!

        ### Run our encode function ###
        I_v, C_v, O_v = self._encode(I_tokens, C_tokens, O_tokens)

        query_v, old_query_v = None, None
        ### Run normal attention over the data ###
        if self.article_encoder.condition_attention:
            query_v = torch.cat([I_v, C_v, O_v], dim=1)
            old_query_v = copy.deepcopy(query_v)

        #if self.use_attention_over_article_tokens:
        cmb_hidden = []
        ### encode each section with the article encoder ###
        for i in range(0, len(article_tokens[0]), inner_batch):
            tokens = article_tokens[0][i:i + inner_batch]
            new_tkn = PaddedSequence.autopad(tokens, batch_first=True)
            if query_v is not None:
                query_v = torch.cat([
                    old_query_v for _ in range(min(len(tokens), inner_batch))
                ],
                                    dim=0)
            #_, hidden, _ = self.article_encoder(new_tkn, query_v_for_attention=query_v)

            if self.article_encoder in ("transformer", "CBoW"):
                hidden = self.article_encoder(new_tkn,
                                              query_v_for_attention=query_v)
            else:
                # assume RNN
                _, hidden = self.article_encoder(new_tkn,
                                                 query_v_for_attention=query_v)

            cmb_hidden.append(hidden)

        hidden = torch.cat(cmb_hidden, dim=0)

        #else:
        #    if self.article_encoder in ("Transformer", "CBoW"):
        #
        #        hidden = self.article_encoder(article_tokens, query_v_for_attention=query_v)
        #    else:
        # assume RNN
        #        _, hidden = self.article_encoder(article_tokens, query_v_for_attention=query_v)

        art_secs = []
        token_secs = []
        i = 0
        ### Reshape our tokens + article representations. ###
        for idx in indices:
            art_secs.append(hidden[i:i + idx])
            token_secs.append(article_tokens[i:i + idx])
            i += idx

        hidden_articles = art_secs
        batch_a_v = None
        section_weights = []

        for i in range(batch_size):
            hidden_art = hidden_articles[i]  # single hidden article
            token_art = token_secs[i]  # single article tokens

            if self.condition_attention:
                query_v = torch.cat(
                    [old_query_v for _ in range(len(hidden_art))], dim=0)

            ### Run section attention over the data for each section ###
            a = self.section_attn(token_art,
                                  hidden_input_states=hidden_art,
                                  query_v_for_attention=query_v,
                                  normalize=True)

            section_weights.append(a)

            if self.recursive_encoding:
                section_splits = recursive_encoding['section_splits']
                new_articles = []
                last = 0

                ### -> Reweight sections based on subsection:
                # [Alpha(S1.1), Alpha(S1.2)] * [Encoding of S1.1, Encoding of S1.2])
                for s in section_splits:
                    section_encoding = hidden_art[last:last + s]
                    ws = a[last:last + s]
                    new_articles.append(
                        torch.mm(torch.transpose(ws, dim0=1, dim1=0),
                                 section_encoding))

                ### -> another attention layer (share it)
                new_tokens = recursive_encoding['big_sections']
                hidden_art = torch.cat(new_articles, dim=0).unsqueeze(0)

                new_query_v = torch.cat(
                    [old_query_v for _ in range(hidden_art.shape[1])], dim=0)
                a = self.section_attn(new_tokens,
                                      hidden_input_states=hidden_art,
                                      query_v_for_attention=new_query_v,
                                      normalize=True)

            ### Combine the re-weighted sections ###
            weighted = (a * hidden_art).squeeze().unsqueeze(0)
            weighted_hidden = torch.sum(weighted, dim=1)
            article_v = torch.sum(weighted_hidden, dim=0)

            if batch_a_v is None:
                batch_a_v = article_v
            else:
                batch_a_v = torch.stack([batch_a_v, article_v])  # per batch

        ### Finish Plugging in ###
        if len(batch_a_v.shape) == 1:
            batch_a_v = batch_a_v.unsqueeze(0)

        h = torch.cat([batch_a_v, I_v, C_v, O_v], dim=1)
        raw_out = self.out(self.MLP_hidden(h))

        return F.softmax(raw_out, dim=1), section_weights
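
The section re-weighting inside the loop is a plain attention-weighted sum: an (n, 1) column of weights times an (n, d) matrix of section encodings collapses to a single (1, d) vector. The same shape arithmetic in isolation:

import torch

weights = torch.tensor([[0.2], [0.3], [0.5]])  # (n, 1) attention over 3 sections
encodings = torch.randn(3, 8)                  # (n, d) section encodings
pooled = torch.mm(weights.transpose(0, 1), encodings)  # (1, d) weighted sum
print(pooled.shape)  # torch.Size([1, 8])
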