示例#1
0
    def compute_loss(state):
        """Compute the training loss for ``state["batch"]`` and record it in *state*.

        Two ``LinearCRF`` objects are built from the model's scores. The first
        uses the scores with every entry outside ``ptst_mask`` pushed to -1e9.
        The second uses the raw scores. The loss is the batch sum of the second
        CRF's ``log_partitions()`` minus the first's, divided by the batch size.
        (Assuming ``log_partitions()`` is a per-sequence log partition, this is
        the negative log of the probability mass the CRF puts on label choices
        allowed by ``ptst_mask`` — TODO confirm against LinearCRF.)

        Side effects on *state*:
            "loss": the loss tensor (still attached to the autograd graph).
            "stats": ``{"ptst_loss": <python float of the loss>}``.
            "n_items": number of True entries in the batch mask.
        """
        arrays = state["batch"].to_array()
        token_ids = torch.from_numpy(arrays["word_ids"]).long().to(device)
        valid = torch.from_numpy(arrays["mask"]).bool().to(device)
        allowed = torch.from_numpy(arrays["ptst_mask"]).bool().to(device)

        model.train()
        all_scores = model(token_ids, valid)
        constrained_scores = all_scores.masked_fill(allowed.logical_not(), -1e9)

        # The CRF mask must not cover each sequence's final token. Clear the
        # position at index (count of True) - 1 in every row, then drop the
        # trailing column. This assumes each mask row is a left-aligned run of
        # True values — TODO confirm padding is always on the right.
        lengths = valid.long().sum(dim=1, keepdim=True)
        crf_mask = valid.scatter(1, lengths - 1, False)[:, :-1]

        constrained_crf = LinearCRF(constrained_scores, crf_mask)
        full_crf = LinearCRF(all_scores, crf_mask)
        log_z_constrained = constrained_crf.log_partitions().sum()
        log_z_full = full_crf.log_partitions().sum()

        batch_size = valid.size(0)
        loss = (log_z_full - log_z_constrained) / batch_size

        state["loss"] = loss
        state["stats"] = {"ptst_loss": loss.item()}
        state["n_items"] = valid.long().sum().item()
示例#2
0
    def maybe_compute_loss(state):
        """Compute the evaluation loss for the current batch, if enabled.

        Returns immediately when the enclosing-scope ``compute_loss`` value is
        falsy (presumably a boolean option of the enclosing function).

        The batch is converted with ``to_array()`` at most once and cached in
        ``state["arr"]``. Scores are taken from ``state["scores"]`` when
        present. Otherwise they are computed here with ``model`` in eval mode,
        which is only valid for an unpadded batch.

        Side effects on *state*:
            "arr": the cached array form of the batch.
            "ptst_loss": a Python float, summed over the batch, not normalized.
            "size": the batch size, presumably so the caller can normalize.

        Raises:
            ValueError: if scores have to be computed here but the batch mask
                contains any False entry (padding).
        """
        if not compute_loss:
            return

        arr = state["arr"] if "arr" in state else state["batch"].to_array()
        state["arr"] = arr

        # Only the scalar ``.item()`` result is kept, so no autograd graph is
        # needed. Disabling it avoids holding activations for the eval pass.
        with torch.no_grad():
            if "scores" in state:
                scores = state["scores"]
            else:
                # ``model`` is called without a mask, so padding would be
                # scored as real tokens. Reject padded batches explicitly:
                # an ``assert`` would be stripped under ``python -O``.
                if not arr["mask"].all():
                    raise ValueError(
                        "cannot compute scores for a padded batch: "
                        "arr['mask'] must be all True"
                    )
                words = torch.from_numpy(arr["word_ids"]).long().to(device)
                model.eval()
                scores = model(words)

            ptst_mask = torch.from_numpy(arr["ptst_mask"]).bool().to(device)
            masked_scores = scores.masked_fill(~ptst_mask, -1e9)
            # NOTE(review): LinearCRF is built without a mask. That matches the
            # unpadded case enforced above. If state["scores"] came from a
            # padded batch, padding positions would be included — confirm.
            crf = LinearCRF(masked_scores)
            crf_z = LinearCRF(scores)
            ptst_loss = -crf.log_partitions().sum() + crf_z.log_partitions().sum()

        state["ptst_loss"] = ptst_loss.item()
        # Batch size read from the host-side mask. There is no need to copy the
        # whole mask to the device just to get its length.
        state["size"] = arr["mask"].shape[0]