コード例 #1
0
    def argmax(self) -> Tuple[LongTensor, LongTensor]:
        """Find the highest-scoring labeled dependency tree.

        Returns:
            - Tensor of shape (B, N) holding the predicted head position
              of each token.
            - Tensor of shape (B, N) holding the dependency type of each
              corresponding head-dependent arc.
        """
        assert self.mask is not None

        # Reduce over the label dimension first; both results are (bsz, slen, slen).
        arc_scores, arc_types = self.scores.max(dim=3)
        seq_lengths = self.mask.long().sum(dim=1)

        if not self.proj:
            if not self.multiroot:
                warnings.warn(
                    "argmax for non-projective is still multiroot although multiroot=False"
                )
            # shape: (bsz, slen)
            pred_heads = find_mst(arc_scores, seq_lengths.tolist())
        else:
            crf = DependencyCRF(_unconvert(arc_scores),
                                seq_lengths - 1,
                                multiroot=self.multiroot)
            # shape: (bsz, slen)
            _, pred_heads = _convert(crf.argmax).max(dim=1)
            pred_heads[:, self.ROOT] = self.ROOT

        # Select the best type for each chosen head; shape: (bsz, slen)
        pred_types = arc_types.gather(1, pred_heads.unsqueeze(1)).squeeze(1)

        return pred_heads, pred_types  # type: ignore
コード例 #2
0
    def marginals(self) -> Tensor:
        """Compute the arc marginal probabilities.

        Returns:
            Tensor of shape (B, N, N, L) holding the marginal probability
            of every labeled arc.
        """
        assert self.mask is not None

        if not self.proj:
            return compute_marginals(self.scores, self.mask, self.multiroot)

        seq_lengths = self.mask.long().sum(dim=1)
        crf = DependencyCRF(_unconvert(self.scores),
                            seq_lengths - 1,
                            multiroot=self.multiroot)
        margs = _convert(crf.marginals)

        # No arc may point into the root position.
        margs[:, :, self.ROOT] = 0
        # Self-loops carry zero probability as well.
        eye = torch.eye(margs.size(1)).to(margs.device)
        margs = margs.masked_fill(eye.unsqueeze(2).bool(), 0)

        return margs
コード例 #3
0
def train(train_iter, val_iter, model):
    """Train `model` for one pass over `train_iter`, printing the running
    loss every 50 steps and running validation every 600 steps."""
    optimizer = AdamW(model.parameters(), lr=1e-4, eps=1e-8)
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=20, t_total=2500)
    model.train()
    running_losses = []
    for step, ex in enumerate(train_iter):
        optimizer.zero_grad()
        words, mapper, _ = ex.word
        label, lengths = ex.head
        batch_size = label.shape[0]

        # Forward pass; zero out scores beyond each sentence's length.
        final = model(words.cuda(), mapper)
        for b in range(batch_size):
            final[b, lengths[b] - 1:, :] = 0
            final[b, :, lengths[b] - 1:] = 0

        # Skip batches whose sentences exceed the score matrix.
        if lengths.max() > final.shape[1] + 1:
            print("fail")
            continue
        dist = DependencyCRF(final, lengths=lengths)

        # Gold trees in part form, scored under the CRF.
        gold_parts = dist.struct.to_parts(label, lengths=lengths).type_as(final)
        log_prob = dist.log_prob(gold_parts)

        loss = log_prob.sum()
        # Minimize the negative log-likelihood.
        (-loss).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()
        scheduler.step()
        running_losses.append(loss.detach())
        if step % 50 == 1:
            print(-torch.tensor(running_losses).mean(), words.shape)
            running_losses = []
        if step % 600 == 500:
            validate(val_iter)
コード例 #4
0
    def log_partitions(self) -> Tensor:
        """Compute the log partition function.

        Returns:
            1-D tensor of length B with one log partition value per sample.
        """
        assert self.mask is not None

        if not self.proj:
            return compute_log_partitions(self.scores, self.mask, self.multiroot)

        seq_lengths = self.mask.long().sum(dim=1)
        crf = DependencyCRF(_unconvert(self.scores),
                            seq_lengths - 1,
                            multiroot=self.multiroot)
        return crf.partition
コード例 #5
0
def validate(val_iter):
    """Evaluate the module-level `model` on `val_iter` and print edge counts.

    Prints the total number of gold edges and the number of incorrectly
    predicted edges, then restores the model to training mode.
    """
    incorrect_edges = 0
    total_edges = 0
    model.eval()
    # Inference only: wrap in no_grad so autograd graphs are not built,
    # saving memory and time (the original forward tracked gradients).
    with torch.no_grad():
        for i, ex in enumerate(val_iter):
            words, mapper, _ = ex.word
            label, lengths = ex.head
            batch, _ = label.shape

            final = model(words.cuda(), mapper)
            # Zero out scores beyond each sentence's true length.
            for b in range(batch):
                final[b, lengths[b] - 1:, :] = 0
                final[b, :, lengths[b] - 1:] = 0
            dist = DependencyCRF(final, lengths=lengths)
            gold = dist.struct.to_parts(label, lengths=lengths).type_as(dist.argmax)
            # Each wrong head yields two differing part entries, hence / 2.
            incorrect_edges += (dist.argmax.cpu() - gold.cpu()).abs().sum() / 2.0
            total_edges += gold.sum()

    print(total_edges, incorrect_edges)
    model.train()
コード例 #6
0
ファイル: main.py プロジェクト: shivaat/MTLB-STRUCT
            # End of the single-task branch: count the step, clip, update.
            nb_tr_steps += 1
            torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), 
                                       max_norm=max_grad_norm)     
            optimizer.step()
            scheduler.step()

        else:
            # Multi-task branch: also trains a dependency-parsing head.
            # Keep only labeled (unmasked) tag positions per sentence, then
            # re-pad so gold heads line up with the CRF's token positions.
            b_tags = [tag[mask] for mask, tag in zip(b_label_masks, b_tags)]
            b_tags = pad_sequence(b_tags, batch_first=True, padding_value=0)

            loss_main, logits, labels, final = model(b_input_ids, b_tags, labels=b_labels, label_masks=b_label_masks)

            # Skip the dependency loss when a sentence is longer than the
            # score matrix (presumably due to subword truncation — TODO confirm).
            if not lengths.max() <= final.shape[1]:
                dep_loss = 0
            else:
                dist = DependencyCRF(final, lengths=lengths)
                # Gold trees in part (adjacency) form.
                dep_labels = dist.struct.to_parts(b_tags, lengths=lengths).type_as(final)   # [BATCH_SIZE, lengths, lengths]
                log_prob = dist.log_prob(dep_labels)

                dep_loss = log_prob.mean() #sum()

            # A defined dep_loss is a mean log-probability and thus negative;
            # subtracting it adds the scaled dependency NLL to the main loss.
            if dep_loss < 0 :
                loss = loss_main -dep_loss/dep_loss_factor  
            else:
                loss = loss_main

            loss.backward()
            tr_loss += loss.item()
            nb_tr_steps += 1
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
コード例 #7
0
ファイル: berteval.py プロジェクト: shivaat/MTLB-STRUCT
def eval(iter_data, model, tags2idx, device_name, mt=False):
    """Evaluate `model` on `iter_data` and report tagging (and, when `mt`
    is True, dependency-parsing) metrics.

    Args:
        iter_data: iterable of pre-collated batch tuples of tensors.
        model: the model to evaluate.
        tags2idx: mapping from tag string to index.
        device_name: torch device the batch tensors are moved to.
        mt: when True, the model additionally returns dependency scores
            (`final`) that are decoded with a DependencyCRF.

    Returns:
        (final_labels, probs, fscore): predicted tag strings per sentence,
        per-token class probabilities, and the seqeval F1 score.
    """
    device = device_name
    logger.info("starting to evaluate")
    model = model.eval()
    eval_loss, eval_accuracy = 0, 0
    nb_eval_steps = 0
    predictions, true_labels, data_instances, probs = [], [], [], []
    dep_predictions, dep_gold_labels = [], []
    total_edges = 0
    incorrect_edges = 0
    for batch in tqdm(iter_data):
        batch = tuple(t.to(device) for t in batch)

        b_input_ids, b_pos_ids, b_tag_ids, b_deptype_ids, b_labels, b_input_mask,\
                b_token_type_ids, b_label_masks, lengths = batch
        lengths = torch.flatten(lengths)
        batch_size, _ = b_tag_ids.shape

        with torch.no_grad():
            if not mt:
                tmp_eval_loss, logits, reduced_labels = model(
                    b_input_ids,
                    b_tag_ids,
                    token_type_ids=b_token_type_ids,
                    attention_mask=b_input_mask,
                    labels=b_labels,
                    label_masks=b_label_masks)

            else:
                tmp_eval_loss, logits, reduced_labels, final = model(
                    b_input_ids,
                    b_tag_ids,
                    labels=b_labels,
                    label_masks=b_label_masks)

                if not lengths.max() <= final.shape[1]:
                    # Sentence longer than the score matrix: cannot decode.
                    # BUG FIX: the original read `b_tags`, which is only
                    # assigned in the success branch below, so the first
                    # failing batch raised NameError. `out` is unused after
                    # this branch, so a placeholder of the tag-id shape is
                    # sufficient.
                    out = torch.zeros(b_tag_ids.size())
                else:
                    dist = DependencyCRF(final, lengths=lengths)

                    out = dist.argmax
                    dep_predictions.append(out)

                    # Keep only labeled (unmasked) positions, then re-pad so
                    # gold heads line up with the CRF's token positions.
                    b_tags = [
                        tag[mask]
                        for mask, tag in zip(b_label_masks, b_tag_ids)
                    ]
                    b_tags = pad_sequence(b_tags,
                                          batch_first=True,
                                          padding_value=0)
                    dep_gold = dist.struct.to_parts(
                        b_tags, lengths=lengths).type_as(out)
                    dep_gold_labels.append(dep_gold)

                    # Each wrong head yields two differing part entries.
                    incorrect_edges += (out.cpu() -
                                        dep_gold.cpu()).abs().sum() / 2.0
                    total_edges += dep_gold.sum()

        # Per-token tag probabilities and argmax predictions.
        tags_idx = [tags2idx[t] for t in tags2idx]
        logits_probs = F.softmax(logits, dim=2)[:, :, tags_idx]
        batch_preds = torch.argmax(F.log_softmax(logits, dim=2), dim=2)
        logits_probs = logits_probs.detach().cpu().numpy()
        batch_preds = batch_preds.detach().cpu().numpy()
        reduced_labels = reduced_labels.to('cpu').numpy()

        labels_to_append = []
        predictions_to_append = []
        logits_to_append = []

        # Renamed inner accumulators (the original reused `preds`, shadowing
        # the batch-level prediction array above).
        for prediction, r_label, logit in zip(batch_preds, reduced_labels,
                                              logits_probs):
            sent_preds = []
            sent_labels = []
            sent_logs = []
            for pred, lab, log in zip(prediction, r_label, logit):
                # -1 marks a masked label: do not collect this position.
                if lab.item() == -1:
                    continue
                sent_preds.append(pred)
                sent_labels.append(lab)
                sent_logs.append(log)
            predictions_to_append.append(sent_preds)
            labels_to_append.append(sent_labels)
            logits_to_append.append(sent_logs)

        predictions.extend(predictions_to_append)
        true_labels.extend(labels_to_append)
        data_instances.extend(b_input_ids)
        probs.extend(logits_to_append)

        eval_loss += tmp_eval_loss.mean().item()

        nb_eval_steps += 1

    if mt:
        print('num of edges', total_edges, 'incorrect_edges:', incorrect_edges)
        # Fixed typo in the original message ('aacuracy').
        print('accuracy', (total_edges - incorrect_edges) / total_edges)

    eval_loss = eval_loss / nb_eval_steps
    logger.info("eval loss (only main): {}".format(eval_loss))
    idx2tags = {tags2idx[t]: t for t in tags2idx}
    pred_tags = [[idx2tags[p_i] for p_i in p] for p in predictions]
    valid_tags = [[idx2tags[l_i] for l_i in l] for l in true_labels]
    logger.info("Seqeval accuracy: {}".format(
        accuracy_score(valid_tags, pred_tags)))
    fscore = f1_score(valid_tags, pred_tags)
    logger.info("Seqeval F1-Score: {}".format(fscore))
    logger.info("Seqeval Classification report: -- ")
    logger.info(classification_report(valid_tags, pred_tags))

    # Identical to `pred_tags` (the original recomputed the same list
    # comprehension); kept as a separate name to preserve the return contract.
    final_labels = pred_tags
    return final_labels, probs, fscore