def _pre_process_input(self, utterances):
    """Tokenize a batch of utterances and build padded encoder inputs.

    Returns a 4-tuple: (piece-id tensor, attention mask, word-start
    position tensor, per-utterance word counts).
    """
    lengths = [len(sent) for sent in utterances]
    max_len = max(lengths)

    # Word-piece tokenization; each word may expand to several pieces.
    pieces = iterative_support(self._lexical_vocab.tokenize, utterances)
    units = [flat_list(tokens) for tokens in pieces]
    # Index of the first piece of every word within the flattened sequence.
    positions = []
    for tokens in pieces:
        offsets = np.cumsum([len(piece) for piece in tokens]).tolist()
        positions.append([0] + offsets[:-1])

    sizes = [len(seq) for seq in units]
    max_size = max(sizes)

    cls_sign = self._lexical_vocab.CLS_SIGN
    sep_sign = self._lexical_vocab.SEP_SIGN
    pad_sign = self._lexical_vocab.PAD_SIGN

    # [CLS] + pieces + [SEP] + [PAD]...: every row padded to max_size + 2.
    pad_unit = [[cls_sign] + seq + [sep_sign] + [pad_sign] * (max_size - len(seq))
                for seq in units]
    # Shift starts by 1 for the leading [CLS]; padded slots point at the
    # final position (max_size + 1) of the padded row.
    starts = [[pos + 1 for pos in row] + [max_size + 1] * (max_len - len(row))
              for row in positions]

    var_unit = torch.LongTensor(
        [self._lexical_vocab.index(seq) for seq in pad_unit])
    # Mask covers the real pieces plus the two special tokens.
    attn_mask = torch.LongTensor([[1] * (size + 2) + [0] * (max_size - size)
                                  for size in sizes])
    var_start = torch.LongTensor(starts)

    if torch.cuda.is_available():
        var_unit = var_unit.cuda()
        attn_mask = attn_mask.cuda()
        var_start = var_start.cuda()
    return var_unit, attn_mask, var_start, lengths
def _pre_process_output(self, entities, lengths):
    """Build training targets: all gold spans plus sampled negative spans.

    Args:
        entities: per-utterance list of (start, end, label) triples.
        lengths:  per-utterance token counts.

    Returns:
        (positions, var_lbl) where positions is a list of
        (utterance index, span start, span end) and var_lbl is the
        LongTensor of label indices aligned with positions.
    """
    positions, labels = [], []
    batch_size = len(entities)

    # Positive examples: every annotated span, labeled with its gold tag.
    for utt_i in range(0, batch_size):
        for segment in entities[utt_i]:
            positions.append((utt_i, segment[0], segment[1]))
            labels.append(segment[2])

    # Negative sampling: random non-gold spans labeled "O".
    for utt_i in range(0, batch_size):
        # Use a set for O(1) membership tests; the original list made the
        # candidate comprehension below accidentally quadratic.
        reject_set = {(e[0], e[1]) for e in entities[utt_i]}
        s_len = lengths[utt_i]
        # At least one negative per utterance.
        neg_num = int(s_len * self._neg_rate) + 1

        candies = flat_list([[(i, j) for j in range(i, s_len)
                              if (i, j) not in reject_set]
                             for i in range(s_len)])
        if len(candies) > 0:
            sample_num = min(neg_num, len(candies))
            np.random.shuffle(candies)
            for i, j in candies[:sample_num]:
                positions.append((utt_i, i, j))
                labels.append("O")

    var_lbl = torch.LongTensor(
        iterative_support(self._label_vocab.index, labels))
    if torch.cuda.is_available():
        var_lbl = var_lbl.cuda()
    return positions, var_lbl
def inference(self, sentences):
    """Predict entity spans for a batch of raw sentences.

    Scores every span with the model, then decodes greedily: spans are
    taken in descending score order and any span that conflicts (per
    conflict_judge) with an already accepted one is dropped.
    Returns, per sentence, a list of (start, end, label) sorted by start.
    """
    var_sent, attn_mask, starts, lengths = self._pre_process_input(sentences)
    log_items = self(var_sent, mask_mat=attn_mask, starts=starts)

    score_t = torch.log_softmax(log_items, dim=-1)
    val_table, idx_table = torch.max(score_t, dim=-1)

    listing_it = idx_table.cpu().numpy().tolist()
    listing_vt = val_table.cpu().numpy().tolist()
    label_table = iterative_support(self._label_vocab.get, listing_it)

    # Collect every non-"O" span with its log-probability, per sentence.
    candidates = []
    for l_mat, v_mat, sent_l in zip(label_table, listing_vt, lengths):
        spans = []
        for i in range(sent_l):
            for j in range(i, sent_l):
                if l_mat[i][j] != "O":
                    spans.append((i, j, l_mat[i][j], v_mat[i][j]))
        candidates.append(spans)

    # Greedy decoding: best-scoring spans first, skipping conflicts.
    entities = []
    for segments in candidates:
        filter_list = []
        for elem in sorted(segments, key=lambda e: -e[-1]):
            current = (elem[0], elem[1])
            clash = any(conflict_judge(current, (kept[0], kept[1]))
                        for kept in filter_list)
            if not clash:
                filter_list.append((elem[0], elem[1], elem[2]))
        entities.append(sorted(filter_list, key=lambda e: e[0]))
    return entities