class CRF(nn.Module): def __init__(self, config): super().__init__() self.config = config self.crf = ConditionalRandomField( num_tags=config.tag_vocab_size, include_start_end_transitions=False, ) def forward(self, inputs: torch.Tensor, tags: torch.Tensor, mask: torch.ByteTensor = None, reduction: str = 'sum'): if mask is None: mask = torch.ones(*inputs.size()[:2], dtype=torch.long).to(inputs.device) log_denominator = self.crf._input_likelihood(inputs, mask) log_numerator = self.crf._joint_likelihood(inputs, tags, mask) loglik = log_numerator - log_denominator if reduction == 'sum': loglik = loglik.sum() elif reduction == 'mean': loglik = loglik.mean() elif reduction == 'none': pass return loglik def decode(self, inputs, mask=None): if mask is None: mask = torch.ones(*inputs.shape[:2], dtype=torch.long).to(inputs.device) # preds = self.crf.viterbi_tags(inputs, mask) # preds, scores = zip(*preds) preds, scores = viterbi_decode_torch(inputs, self.crf.transitions) return list(preds)
class DTCRF(nn.Module): def __init__(self, config): super().__init__() self.config = config self.tag_form = config.tag_form self.crf = ConditionalRandomField( num_tags=config.tag_vocab_size, include_start_end_transitions=False, ) del self.crf.transitions # must del parameter before assigning a tensor self.crf.transitions = None del self.crf._constraint_mask num_tags = config.tag_vocab_size constraint_mask = torch.Tensor(num_tags + 2, num_tags + 2).fill_(1.).to(config.device) self.crf._constraint_mask = constraint_mask #torch.nn.Parameter(constraint_mask, requires_grad=False) if self.tag_form == 'iobes': M = 4 elif self.tag_form == 'iob2': M = 2 else: raise Exception(f'unsupported tag form: {self.tag_form}') N = config.tag_vocab_size E = (config.tag_vocab_size - 1) // M self.N, self.M, self.E = N, M, E self.p_in = nn.Parameter(torch.randn([M, M], dtype=torch.float32)) self.p_cross = nn.Parameter(torch.randn([M, M], dtype=torch.float32)) self.p_out = nn.Parameter(torch.randn(1, dtype=torch.float32)) self.p_to_out = nn.Parameter(torch.randn(M, dtype=torch.float32)) self.p_from_out = nn.Parameter(torch.randn(M, dtype=torch.float32)) self.need_update = True def p_to_cpu(self): if self.p_in.device.type != 'cpu': self.p_in.data = self.p_in.data.cpu() self.p_cross.data = self.p_cross.data.cpu() self.p_out.data = self.p_out.data.cpu() self.p_to_out.data = self.p_to_out.data.cpu() self.p_from_out.data = self.p_from_out.data.cpu() def update_transitions(self): ### build transition matrix (operation on cpu) M, N, E = self.M, self.N, self.E extended = torch.zeros( [N, N]) #.to(self.config.device) # extended transition matrix extended[0, 0] = self.p_out # O to O for e in range(E): extended[0, e * M + 1:e * M + 1 + M] = self.p_from_out extended[e * M + 1:e * M + 1 + M, 0] = self.p_to_out for e0 in range(E): extended[e0 * M + 1:e0 * M + 1 + M, e0 * M + 1:e0 * M + 1 + M] = self.p_in for e1 in range(e0 + 1, E): extended[e0 * M + 1:e0 * M + 1 + M, e1 * M + 1:e1 * M + 1 + M] = self.p_cross extended[e1 * M + 1:e1 * M + 1 + M, e0 * M + 1:e0 * M + 1 + M] = self.p_cross self.crf.transitions = extended.to(self.config.device) ### finish building transition matrix def forward(self, inputs: torch.Tensor, tags: torch.Tensor, mask: torch.ByteTensor = None, reduction: str = 'sum'): self.p_to_cpu() self.update_transitions() self.need_update = True if mask is None: mask = torch.ones(*inputs.size()[:2], dtype=torch.long).to(inputs.device) log_denominator = self.crf._input_likelihood(inputs, mask) log_numerator = self.crf._joint_likelihood(inputs, tags, mask) loglik = log_numerator - log_denominator if reduction == 'sum': loglik = loglik.sum() elif reduction == 'mean': loglik = loglik.mean() elif reduction == 'token_mean': loglik = loglik.mean() elif reduction == 'none': pass return loglik def decode(self, inputs, mask=None): if self.need_update: self.update_transitions() self.need_update = False if mask is None: mask = torch.ones(*inputs.shape[:2], dtype=torch.long).to(inputs.device) # preds = self.crf.viterbi_tags(inputs, mask) # preds, scores = zip(*preds) preds, scores = viterbi_decode_torch(inputs, self.crf.transitions) return list(preds)
class DCCRF(nn.Module): def __init__(self, config, input_dim=None): super().__init__() self.config = config self.tag_form = config.tag_form tag_form = config.tag_form self.crf = ConditionalRandomField( num_tags=config.tag_vocab_size, include_start_end_transitions=False, ) del self.crf.transitions # must del parameter before assigning a tensor self.crf.transitions = None del self.crf._constraint_mask num_tags = config.tag_vocab_size constraint_mask = torch.Tensor(num_tags + 2, num_tags + 2).fill_(1.).to(config.device) self.crf._constraint_mask = constraint_mask if tag_form == 'iobes': M = 4 elif tag_form == 'iob2': M = 2 else: raise Exception(f'unsupported tag form: {tag_form}') N = config.tag_vocab_size E = (config.tag_vocab_size - 1) // M A = 4 self.N, self.M, self.E, self.A = N, M, E, A self.p_in = nn.Parameter(torch.randn([A, M, M], dtype=torch.float32)) self.p_cross = nn.Parameter(torch.randn([M, M], dtype=torch.float32)) self.p_out = nn.Parameter(torch.randn(1, dtype=torch.float32)) self.p_to_out = nn.Parameter(torch.randn([M], dtype=torch.float32)) self.p_from_out = nn.Parameter(torch.randn([M], dtype=torch.float32)) if input_dim is None: input_dim = config.hidden_dim self.block_attn = nn.Linear(input_dim, A) init_linear(self.block_attn) self.dropout = nn.Dropout(0.5) def p_to_cpu(self): if self.p_out.device.type != 'cpu': # self.p_in.data = self.p_in.data.cpu() self.p_cross.data = self.p_cross.data.cpu() self.p_out.data = self.p_out.data.cpu() self.p_to_out.data = self.p_to_out.data.cpu() self.p_from_out.data = self.p_from_out.data.cpu() def update_transitions(self, hiddens, entity_mask=None): ### build transition matrix (operation on cpu) M, N, K, A = self.M, self.N, self.K, self.A ### predict block block_atten = self.block_attn(hiddens) if entity_mask is not None: block_atten -= 999 * (1. - entity_mask)[:, None] block_atten = F.softmax(block_atten, 0) block_atten = F.softmax(block_atten * 10, -1) p_in = (self.p_in[None] * block_atten[:, :, None, None] ).mean(1).cpu() # (K, A, N, N) => (K, N, N) ### build extended extended = torch.zeros( [N, N]) #.to(self.config.device) # extended transition matrix extended[0, 0] = self.p_out # O to O for e in range(E): extended[0, e * M + 1:e * M + 1 + M] = self.p_from_out extended[e * M + 1:e * M + 1 + M, 0] = self.p_to_out for e0 in range(E): extended[e0 * M + 1:e0 * M + 1 + M, e0 * M + 1:e0 * M + 1 + M] = p_in[k0] for e1 in range(e0 + 1, E): extended[e0 * M + 1:e0 * M + 1 + M, e1 * M + 1:e1 * M + 1 + M] = self.p_cross extended[e1 * M + 1:e1 * M + 1 + M, e0 * M + 1:e0 * M + 1 + M] = self.p_cross self.crf.transitions = extended.to(self.config.device) ### finish building transition matrix def forward(self, inputs: torch.Tensor, tags: torch.Tensor, hiddens: torch.Tensor, mask: torch.ByteTensor = None, entity_mask=None, reduction: str = 'sum'): self.p_to_cpu() self.update_transitions(hiddens, entity_mask) if mask is None: mask = torch.ones(*inputs.size()[:2], dtype=torch.long).to(inputs.device) log_denominator = self.crf._input_likelihood(inputs, mask) log_numerator = self.crf._joint_likelihood(inputs, tags, mask) loglik = log_numerator - log_denominator if reduction == 'sum': loglik = loglik.sum() elif reduction == 'mean': loglik = loglik.mean() elif reduction == 'token_mean': loglik = loglik.mean() elif reduction == 'none': pass return loglik def decode(self, inputs, hiddens, mask=None, entity_mask=None): self.update_transitions(hiddens, entity_mask) if mask is None: mask = torch.ones(*inputs.shape[:2], dtype=torch.long).to(inputs.device) # preds = self.crf.viterbi_tags(inputs, mask) # preds, scores = zip(*preds) preds, scores = viterbi_decode_torch(inputs, self.crf.transitions) return list(preds)
class PyramidCRF(torch.nn.Module): def __init__(self, config): super().__init__() self.config = config self.crf_T = ConditionalRandomField( num_tags=config.tag_vocab_size, include_start_end_transitions=False, ) self.crf_L = ConditionalRandomField( num_tags=config.tag_vocab_size, include_start_end_transitions=False, ) self.crf_R = ConditionalRandomField( num_tags=config.tag_vocab_size, include_start_end_transitions=False, ) def forward(self, logits_list, tags_list, mask_list): ''' O -> O -> O -> O -> O -> O \ / \ / \ / \ / \ / O -> O -> O -> O -> O \ / \ / \ / \ / O -> O -> O -> O ''' P = len(logits_list) B, T, N = logits_list[0].shape # -> direction logits_T = torch.zeros([P, B, T, N], dtype=torch.float).fill_(-1000.).to( logits_list[0].device) tags_T = torch.zeros([P, B, T], dtype=torch.long).to(tags_list[0].device) mask_T = torch.zeros([P, B, T], dtype=torch.bool).to(mask_list[0].device) for i in range(len(logits_list)): logits = logits_list[i] tags = tags_list[i] mask = mask_list[i] logits_T[i, :, :logits.shape[1], :] = logits tags_T[i, :, :tags.shape[1]] = tags mask_T[i, :, :mask.shape[1]] = mask _logits = logits_T.view(P * B, T, N) _tags = tags_T.view(P * B, T) _mask = mask_T.view(P * B, T) log_denominator = self.crf_T._input_likelihood(_logits, _mask) log_numerator = self.crf_T._joint_likelihood(_logits, _tags, _mask) loglik = log_numerator - log_denominator loss_T = -loglik.sum() # \ direction logits_L = logits_T.transpose(0, 2) # (T, B, P, N) tags_L = tags_T.transpose(0, 2) # (T, B, P) mask_L = mask_T.transpose(0, 2) # (T, B, P) _logits = logits_L.reshape(T * B, P, N) _tags = tags_L.reshape(T * B, P) _mask = mask_L.reshape(T * B, P) _mask[:, 0] = 1 # avoid all False log_denominator = self.crf_L._input_likelihood(_logits, _mask) log_numerator = self.crf_L._joint_likelihood(_logits, _tags, _mask) loglik = log_numerator - log_denominator loss_L = -loglik.sum() # / direction logits_R = torch.zeros_like(logits_T).fill_(-1000.) tags_R = torch.zeros_like(tags_T) mask_R = torch.zeros_like(mask_T) for i in range(len(logits_list)): logits = logits_list[i] tags = tags_list[i] mask = mask_list[i] logits_R[i, :, -logits.shape[1]:, :] = logits tags_R[i, :, -tags.shape[1]:] = tags mask_R[i, :, -mask.shape[1]:] = mask logits_R = logits_R.transpose(0, 2) # (T, B, P, N) tags_R = tags_R.transpose(0, 2) # (T, B, P) mask_R = mask_R.transpose(0, 2) # (T, B, P) _logits = logits_R.reshape(T * B, P, N) _tags = tags_R.reshape(T * B, P) _mask = mask_R.reshape(T * B, P) _mask[:, 0] = 1 # avoid all False log_denominator = self.crf_R._input_likelihood(_logits, _mask) log_numerator = self.crf_R._joint_likelihood(_logits, _tags, _mask) loglik = log_numerator - log_denominator loss_R = -loglik.sum() return loss_T, loss_L, loss_R def decode(self, logits_list, mask_list): P = len(logits_list) B, T, N = logits_list[0].shape # -> direction logits_T = torch.zeros([P, B, T, N], dtype=torch.float).fill_(-1000.).to( logits_list[0].device) mask_T = torch.zeros([P, B, T], dtype=torch.bool).to(mask_list[0].device) for i in range(len(logits_list)): logits = logits_list[i] mask = mask_list[i] logits_T[i, :, :logits.shape[1], :] = logits mask_T[i, :, :mask.shape[1]] = mask logits_T[:, :, :, 0] += (1. - mask_T.float()) * 1000. _logits = logits_T.view(P * B, T, N) preds_T, _ = viterbi_decode_torch(_logits, self.crf_T.transitions) preds_T = np.array(preds_T).reshape(P, B, T) # (P*B, T) # \ direction logits_L = logits_T.transpose(0, 2) # (T, B, P, N) _logits = logits_L.reshape(T * B, P, N) preds_L, _ = viterbi_decode_torch(_logits, self.crf_L.transitions) preds_L = np.array(preds_L).reshape(T, B, P).transpose(2, 1, 0) # / direction logits_R = torch.zeros_like(logits_T).fill_(-1000.) mask_R = torch.zeros_like(mask_T) for i in range(len(logits_list)): logits = logits_list[i] mask = mask_list[i] logits_R[i, :, -logits.shape[1]:, :] = logits mask_R[i, :, -mask.shape[1]:] = mask logits_R[:, :, :, 0] += (1. - mask_R.float()) * 1000. _logits = logits_R.transpose(0, 2).reshape(T * B, P, N) preds_R, _ = viterbi_decode_torch(_logits, self.crf_R.transitions) preds_R = np.array(preds_R).reshape(T, B, P) for i in range(1, P): preds_R[:-i, :, i] = preds_R[i:, :, i] preds_R[-i:, :, i] = 0 preds_R = preds_R.transpose(2, 1, 0) return preds_T, preds_L, preds_R