def sdp(self, hidden: dict, graph=True):
    # semantic dependency parsing
    sdp_arc, sdp_label, _ = self.model.sdp_decoder(
        hidden['word_cls_input'], hidden['word_length'])
    # turn raw arc scores into probabilities in place
    sdp_arc = torch.sigmoid_(sdp_arc)

    if graph:
        # semantic dependency graph
        sdp_arc.transpose_(-1, -2)
        # force exactly one word to attach to the virtual root
        sdp_root_mask = sdp_arc[:, 0].argmax(
            dim=-1).unsqueeze_(-1).expand_as(sdp_arc[:, 0])
        sdp_arc[:, 0] = 0
        sdp_arc[:, 0].scatter_(dim=-1, index=sdp_root_mask, value=1)
        sdp_arc_T = sdp_arc.transpose(-1, -2)
        sdp_arc_fix = sdp_arc_T.argmax(
            dim=-1).unsqueeze_(-1).expand_as(sdp_arc)
        # keep arcs whose probability exceeds 0.5 and beats the reverse
        # arc, and always keep each word's highest-scoring head
        sdp_arc = ((sdp_arc_T > 0.5) & (sdp_arc_T > sdp_arc)).scatter_(
            dim=-1, index=sdp_arc_fix, value=True)
    else:
        # semantic dependency tree: decode a projective tree with Eisner
        sdp_arc_fix = eisner(
            sdp_arc, hidden['word_cls_mask']).unsqueeze_(-1).expand_as(sdp_arc)
        sdp_arc = torch.zeros_like(sdp_arc, dtype=torch.bool).scatter_(
            dim=-1, index=sdp_arc_fix, value=True)

    sdp_label = torch.argmax(sdp_label, dim=-1)

    # mask out arcs that point to padding positions
    word_cls_mask = hidden['word_cls_mask']
    word_cls_mask = word_cls_mask.unsqueeze(-1).expand(
        -1, -1, word_cls_mask.size(1))
    sdp_arc = sdp_arc & word_cls_mask
    sdp_label = self._get_graph_entities(sdp_arc, sdp_label, self.sdp_vocab)

    return sdp_label
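# A minimal, self-contained sketch of the graph-mode arc selection above, on a
# toy score matrix (the tensor names here are illustrative, not part of LTP):
# an arc dependent -> head survives if its probability exceeds 0.5 and beats
# the reverse arc, and each word additionally keeps its highest-scoring head.
import torch

scores = torch.randn(1, 4, 4)        # [batch, dependent, head] raw arc scores
probs = torch.sigmoid(scores)        # arc probabilities
best_head = probs.argmax(dim=-1)     # each word's single best head
arcs = (probs > 0.5) & (probs > probs.transpose(-1, -2))
arcs.scatter_(-1, best_head.unsqueeze(-1), True)  # guarantee at least one head
print(arcs)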
def dep(self, hidden: dict, fast=False):
    # syntactic dependency parsing (dependency tree)
    dep_arc, dep_label, word_length = self.model.dep_decoder(
        hidden['word_cls_input'], hidden['word_length'])

    if fast:
        # fast mode: greedy head selection, no tree constraint
        dep_arc_fix = dep_arc.argmax(
            dim=-1).unsqueeze_(-1).expand_as(dep_arc)
    else:
        # decode a projective tree with the Eisner algorithm
        dep_arc_fix = eisner(
            dep_arc, hidden['word_cls_mask']).unsqueeze_(-1).expand_as(dep_arc)
    # one-hot encode the chosen head for each word
    dep_arc = torch.zeros_like(dep_arc, dtype=torch.bool).scatter_(
        dim=-1, index=dep_arc_fix, value=True)

    # restrict label prediction to the first `dep_fix` labels of the vocabulary
    dep_label[:, :, :, self.dep_fix:] = float('-inf')
    dep_label = torch.argmax(dep_label, dim=-1)

    # mask out arcs that point to padding positions
    word_cls_mask = hidden['word_cls_mask']
    word_cls_mask = word_cls_mask.unsqueeze(-1).expand(
        -1, -1, word_cls_mask.size(1))
    dep_arc = dep_arc & word_cls_mask
    dep_label = self._get_graph_entities(dep_arc, dep_label, self.dep_vocab)

    return dep_label
def dep(self, hidden: dict, fast=False):
    """Syntactic dependency parsing (dependency tree).

    Args:
        hidden: the intermediate representations produced by word segmentation
        fast: in fast mode, fewer constraints are imposed on the result;
            decoding is faster, at a corresponding cost in accuracy

    Returns:
        the dependency tree
    """
    dep_arc, dep_label, word_length = self.model.dep_decoder(
        hidden['word_cls_input'], hidden['word_length'])

    if fast:
        # fast mode: greedy head selection, no tree constraint
        dep_arc_fix = dep_arc.argmax(
            dim=-1).unsqueeze_(-1).expand_as(dep_arc)
    else:
        # decode a projective tree with the Eisner algorithm
        dep_arc_fix = eisner(
            dep_arc, hidden['word_cls_mask']).unsqueeze_(-1).expand_as(dep_arc)
    # one-hot encode the chosen head for each word
    dep_arc = torch.zeros_like(dep_arc, dtype=torch.bool).scatter_(
        -1, dep_arc_fix, True)

    dep_label = torch.argmax(dep_label, dim=-1)

    # mask out arcs that point to padding positions
    word_cls_mask = hidden['word_cls_mask']
    word_cls_mask = word_cls_mask.unsqueeze(-1).expand(
        -1, -1, word_cls_mask.size(1))
    dep_arc = dep_arc & word_cls_mask
    dep_label = get_graph_entities(dep_arc, dep_label, self.dep_vocab)

    return dep_label
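# A hedged usage sketch for the method above, assuming the LTP 4 pipeline in
# which `seg` returns (words, hidden); the sentence and output are illustrative.
from ltp import LTP

ltp = LTP()  # loads the default pretrained model
words, hidden = ltp.seg(["他叫汤姆去拿外衣。"])
dep = ltp.dep(hidden)                  # Eisner decoding (default)
dep_fast = ltp.dep(hidden, fast=True)  # greedy decoding, no tree constraint
print(dep)  # e.g. [[(1, 2, 'SBV'), (2, 0, 'HED'), ...]] as (word, head, label) triples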
def step(self, y_pred: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
         y: Tuple[torch.Tensor, torch.Tensor]):
    arc_pred, label_pred, seq_len = y_pred

    # +1 accounts for the virtual root prepended to every sentence;
    # position 0 (the root itself) is excluded from evaluation
    mask = length_to_mask(seq_len + 1)
    mask[:, 0] = False

    if self._eisner:
        from ltp.utils import eisner
        arc_pred = eisner(arc_pred, mask)
    else:
        arc_pred = torch.argmax(arc_pred, dim=-1)
    label_pred = torch.argmax(label_pred, dim=-1)

    arc_real, label_real = y
    # pick the label predicted for each word's chosen head
    label_pred = label_pred.gather(-1, arc_pred.unsqueeze(-1)).squeeze(-1)

    # drop the root column so predictions align with the gold annotations
    mask = mask.narrow(-1, 1, mask.size(1) - 1)
    arc_pred = arc_pred.narrow(-1, 1, arc_pred.size(1) - 1)
    label_pred = label_pred.narrow(-1, 1, label_pred.size(1) - 1)

    head_true = (arc_pred == arc_real)[mask]
    label_true = (label_pred == label_real)[mask]

    # accumulate counts for UAS (correct heads) and LAS (correct heads + labels)
    self._head_true += torch.sum(head_true).item()
    self._label_true += torch.sum(label_true).item()
    self._union_true += torch.sum(label_true[head_true]).item()
    self._all += torch.sum(mask).item()
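# For context, a minimal length_to_mask with the semantics the call above
# expects (LTP ships its own implementation; this is only a compatible sketch),
# plus how the accumulated counters typically yield UAS/LAS.
import torch

def length_to_mask(lengths: torch.Tensor) -> torch.Tensor:
    # mask[i, j] is True for j < lengths[i]
    max_len = int(lengths.max())
    return torch.arange(max_len, device=lengths.device)[None, :] < lengths[:, None]

# UAS: fraction of words with the correct head
#   uas = self._head_true / self._all
# LAS: fraction of words with the correct head and the correct label
#   las = self._union_true / self._all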
def sdp(self, hidden: dict, graph=True):
    """Semantic dependency graph (or tree).

    Args:
        hidden: the intermediate representations produced by word segmentation
        graph: whether to produce a semantic dependency graph (True)
            or a semantic dependency tree (False)

    Returns:
        the semantic dependency graph (tree)
    """
    sdp_arc, sdp_label, _ = self.model.sdp_decoder(
        hidden['word_cls_input'], hidden['word_length'])
    # turn raw arc scores into probabilities in place
    sdp_arc = torch.sigmoid_(sdp_arc)

    # suppress self-loops: zero out the diagonal so no word heads itself
    eye = torch.arange(0, sdp_arc.size(1), device=sdp_arc.device).view(
        1, 1, -1).expand(sdp_arc.size(0), -1, -1)
    sdp_arc.scatter_(1, eye, 0)

    if graph:
        # semantic dependency graph
        sdp_arc.transpose_(-1, -2)
        # force exactly one word to attach to the virtual root
        sdp_root_mask = sdp_arc[:, 0].argmax(dim=-1).unsqueeze_(-1).expand_as(
            sdp_arc[:, 0])
        sdp_arc[:, 0] = 0
        sdp_arc[:, 0].scatter_(dim=-1, index=sdp_root_mask, value=1)
        sdp_arc_T = sdp_arc.transpose(-1, -2)
        sdp_arc_fix = sdp_arc_T.argmax(
            dim=-1).unsqueeze_(-1).expand_as(sdp_arc)
        # keep arcs whose probability exceeds 0.5 and beats the reverse
        # arc, and always keep each word's highest-scoring head
        sdp_arc = ((sdp_arc_T > 0.5) & (sdp_arc_T > sdp_arc)).scatter_(
            -1, sdp_arc_fix, True)
    else:
        # semantic dependency tree: decode a projective tree with Eisner
        sdp_arc_fix = eisner(
            sdp_arc, hidden['word_cls_mask']).unsqueeze_(-1).expand_as(sdp_arc)
        sdp_arc = torch.zeros_like(sdp_arc, dtype=torch.bool).scatter_(
            -1, sdp_arc_fix, True)

    sdp_label = torch.argmax(sdp_label, dim=-1)

    # mask out arcs between padding positions, but allow every real word
    # to attach to the virtual root at position 0
    word_cls_mask = hidden['word_cls_mask']
    word_cls_mask = word_cls_mask.unsqueeze(-1).expand(
        -1, -1, word_cls_mask.size(1))
    word_cls_mask = word_cls_mask & word_cls_mask.transpose(-1, -2)
    for si, length in enumerate(hidden['word_length']):
        word_cls_mask[si, :length, 0] = True
    sdp_arc = sdp_arc & word_cls_mask
    sdp_label = get_graph_entities(sdp_arc, sdp_label, self.sdp_vocab)

    return sdp_label
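# A hedged usage sketch mirroring the dep example above: `sdp` consumes the
# hidden state produced by `seg`; with graph=True a word may receive several
# heads, with graph=False the result is constrained to a tree.
from ltp import LTP

ltp = LTP()
words, hidden = ltp.seg(["他叫汤姆去拿外衣。"])
sdp_graph = ltp.sdp(hidden, graph=True)   # semantic dependency graph
sdp_tree = ltp.sdp(hidden, graph=False)   # semantic dependency tree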