def compute_loss(state):
    """Compute the PTST training loss for ``state["batch"]`` and store it in ``state``.

    ``state["batch"].to_array()`` must provide NumPy arrays under the keys
    ``"word_ids"``, ``"mask"`` and ``"ptst_mask"``. The model is switched to
    training mode and scored on the batch. Scores at positions where
    ``ptst_mask`` is False are replaced with -1e9. The loss is the summed
    log-partition of the unrestricted CRF minus the summed log-partition of the
    restricted CRF, divided by ``mask.size(0)`` (presumably the batch size).

    Writes:
        state["loss"]: the loss tensor (not detached).
        state["stats"]: ``{"ptst_loss": <loss as a Python float>}``.
        state["n_items"]: number of True entries in ``mask``.
    """
    batch_arrays = state["batch"].to_array()
    word_ids = torch.from_numpy(batch_arrays["word_ids"]).long().to(device)
    token_mask = torch.from_numpy(batch_arrays["mask"]).bool().to(device)
    allowed = torch.from_numpy(batch_arrays["ptst_mask"]).bool().to(device)

    model.train()
    scores = model(word_ids, token_mask)
    # Scores outside ptst_mask are pushed to -1e9 so the restricted CRF effectively ignores them.
    restricted_scores = scores.masked_fill(~allowed, -1e9)

    # The mask handed to LinearCRF must not include each row's last token:
    # clear the entry at index (row length - 1), then drop the final column.
    # NOTE(review): assumes every row of `mask` is a left-aligned run of True
    # and that no row is empty (an empty row would give index -1) — TODO confirm.
    last_positions = token_mask.long().sum(dim=1, keepdim=True) - 1
    crf_mask = token_mask.scatter(1, last_positions, False)[:, :-1]

    restricted_crf = LinearCRF(restricted_scores, crf_mask)
    full_crf = LinearCRF(scores, crf_mask)
    restricted_log_z = restricted_crf.log_partitions().sum()
    full_log_z = full_crf.log_partitions().sum()

    batch_size = token_mask.size(0)
    loss = (full_log_z - restricted_log_z) / batch_size

    state["loss"] = loss
    state["stats"] = {"ptst_loss": loss.item()}
    state["n_items"] = token_mask.long().sum().item()
def maybe_compute_loss(state):
    """Evaluate the PTST loss for ``state`` and record it as a Python float.

    Returns immediately (writing nothing) when ``compute_loss`` is falsy.
    Otherwise the batch arrays are taken from ``state["arr"]`` when cached, or
    from ``state["batch"].to_array()``, and are cached back into
    ``state["arr"]``. Scores are taken from ``state["scores"]`` when present.
    Otherwise the model is run in eval mode on ``word_ids`` without a mask,
    which requires ``arr["mask"]`` to be all True.

    Unlike the training loss, the CRFs here are built without a mask and the
    result is not divided by the batch size. The size is stored separately
    (presumably so the caller can average; verify against caller).

    Writes:
        state["arr"]: the batch arrays.
        state["ptst_loss"]: summed unrestricted minus restricted log-partition (float).
        state["size"]: ``arr["mask"].shape[0]``.
    """
    # NOTE(review): `compute_loss` is resolved from the enclosing scope. If it
    # refers to the function of the same name rather than a boolean config
    # flag, it is always truthy and this guard never fires — confirm.
    if not compute_loss:
        return

    arr = state["arr"] if "arr" in state else state["batch"].to_array()
    state["arr"] = arr

    # Only a scalar is kept from this computation, so don't build an autograd graph.
    with torch.no_grad():
        if "scores" in state:
            # NOTE(review): the CRFs below are unmasked, so cached scores are
            # also assumed to come from an unpadded batch — TODO confirm.
            scores = state["scores"]
        else:
            # The model is called without a mask, so the batch must have no padding.
            assert arr["mask"].all()
            words = torch.from_numpy(arr["word_ids"]).long().to(device)
            model.eval()
            scores = model(words)

        ptst_mask = torch.from_numpy(arr["ptst_mask"]).bool().to(device)
        # Scores outside ptst_mask are pushed to -1e9 so the restricted CRF effectively ignores them.
        masked_scores = scores.masked_fill(~ptst_mask, -1e9)
        crf = LinearCRF(masked_scores)
        crf_z = LinearCRF(scores)
        ptst_loss = -crf.log_partitions().sum() + crf_z.log_partitions().sum()

    state["ptst_loss"] = ptst_loss.item()
    # Read the batch size from the host array; no need to copy the mask to the device just for its shape.
    state["size"] = arr["mask"].shape[0]