# External imports assumed by this listing: pickle, os, logging (logger),
# numpy as np, nltk, sklearn_crfsuite.CRF, and sklearn.metrics
# (accuracy_score, precision_recall_fscore_support, cohen_kappa_score).
# Component, Document, Relation, ArgumentPositionClassifier,
# generate_pdtb_features and get_features are project-internal.
class CRFArgumentExtractor(Component):
    model_name = "crf_arg_extract"
    used_features = ['ptree']

    def __init__(self):
        self.arg_pos_clf = ArgumentPositionClassifier()
        # SS: both arguments in the same sentence; PS: Arg1 in the previous sentence.
        self.ss_model = CRF(algorithm='lbfgs', c1=0.1, c2=0.1,
                            max_iterations=100, all_possible_transitions=True)
        self.ps_model = CRF(algorithm='lbfgs', c1=0.1, c2=0.1,
                            max_iterations=100, all_possible_transitions=True)

    def load(self, path):
        self.arg_pos_clf.load(path)
        self.ss_model = pickle.load(
            open(os.path.join(path, "{}.ss.p".format(self.model_name)), 'rb'))
        self.ps_model = pickle.load(
            open(os.path.join(path, "{}.ps.p".format(self.model_name)), 'rb'))

    def save(self, path):
        self.arg_pos_clf.save(path)
        pickle.dump(
            self.ss_model,
            open(os.path.join(path, "{}.ss.p".format(self.model_name)), 'wb'))
        pickle.dump(
            self.ps_model,
            open(os.path.join(path, "{}.ps.p".format(self.model_name)), 'wb'))

    def fit(self, docs_train: List[Document], docs_val: List[Document] = None):
        self.arg_pos_clf.fit(docs_train)
        (x_ss, y_ss), (x_ps, y_ps) = generate_pdtb_features(docs_train)
        self.ss_model.fit(x_ss, y_ss)
        self.ps_model.fit(x_ps, y_ps)

    def score_on_features(self, x_ss, y_ss, x_ps, y_ps):
        y_pred = np.concatenate(self.ss_model.predict(x_ss))
        y_ss = np.concatenate(y_ss)
        logger.info("Evaluation: SS Model")
        logger.info("  Acc  : {:<06.4}".format(accuracy_score(y_ss, y_pred)))
        prec, recall, f1, support = precision_recall_fscore_support(
            y_ss, y_pred, average='macro')
        logger.info("  Macro: P {:<06.4} R {:<06.4} F1 {:<06.4}".format(
            prec, recall, f1))
        logger.info("  Kappa: {:<06.4}".format(cohen_kappa_score(y_ss, y_pred)))

        y_pred = np.concatenate(self.ps_model.predict(x_ps))
        y_ps = np.concatenate(y_ps)
        logger.info("Evaluation: PS Model")
        logger.info("  Acc  : {:<06.4}".format(accuracy_score(y_ps, y_pred)))
        prec, recall, f1, support = precision_recall_fscore_support(
            y_ps, y_pred, average='macro')
        logger.info("  Macro: P {:<06.4} R {:<06.4} F1 {:<06.4}".format(
            prec, recall, f1))
        logger.info("  Kappa: {:<06.4}".format(cohen_kappa_score(y_ps, y_pred)))

    def score(self, docs: List[Document]):
        self.arg_pos_clf.score(docs)
        (x_ss, y_ss), (x_ps, y_ps) = generate_pdtb_features(docs)
        self.score_on_features(x_ss, y_ss, x_ps, y_ps)

    def extract_arguments(self, ptree: nltk.ParentedTree, relation: Relation, arg_pos: str):
        indices = [token.local_idx for token in relation.conn.tokens]
        ptree._label = 'S'
        x = get_features(ptree, indices)
        if arg_pos == 'SS':
            probs = np.array(
                [[p[c] for c in self.ss_model.classes_]
                 for p in self.ss_model.predict_marginals_single(x)])
            probs_max = probs.max(1)
            labels = np.array(self.ss_model.classes_)[probs.argmax(axis=1)]
            arg1 = np.where(labels == 'Arg1')[0]
            arg2 = np.where(labels == 'Arg2')[0]
            arg1_prob = probs_max[arg1].mean() if len(arg1) else 0.0
            arg2_prob = probs_max[arg2].mean() if len(arg2) else 0.0
            arg1, arg2 = arg1.tolist(), arg2.tolist()
            if not arg1:
                logger.warning("Empty Arg1")
            if not arg2:
                logger.warning("Empty Arg2")
        elif arg_pos == 'PS':
            probs = np.array(
                [[p[c] for c in self.ps_model.classes_]
                 for p in self.ps_model.predict_marginals_single(x)])
            probs_max = probs.max(1)
            labels = np.array(self.ps_model.classes_)[probs.argmax(axis=1)]
            # In the PS case Arg1 is the whole previous sentence, so only Arg2
            # tokens are predicted here.
            arg1 = []
            arg1_prob = 1.0
            arg2 = np.where(labels == 'Arg2')[0]
            arg2_prob = probs_max[arg2].mean() if len(arg2) else 0.0
            arg2 = arg2.tolist()
            if not arg2:
                logger.warning("Empty Arg2")
        else:
            raise NotImplementedError('Unknown argument position')
        return arg1, arg2, arg1_prob, arg2_prob

    def parse(self, doc: Document, relations: List[Relation] = None, **kwargs):
        if relations is None:
            raise ValueError('Component needs connectives already classified.')
        for relation in filter(lambda r: r.type == "Explicit", relations):
            sent_id = relation.conn.get_sentence_idxs()[0]
            sent = doc.sentences[sent_id]
            ptree = sent.get_ptree()
            if ptree is None or len(relation.conn.tokens) == 0:
                continue
            # ARGUMENT POSITION
            conn_raw = ' '.join(t.surface for t in relation.conn.tokens)
            conn_idxs = [t.local_idx for t in relation.conn.tokens]
            arg_pos, arg_pos_confidence = self.arg_pos_clf.get_argument_position(
                ptree, conn_raw, conn_idxs)
            # A PS classification is impossible for the first sentence: there is
            # no previous sentence, so skip to the next relation.
            if arg_pos == 'PS' and sent_id == 0:
                continue
            # ARGUMENT EXTRACTION
            arg1, arg2, arg1_c, arg2_c = self.extract_arguments(
                ptree, relation, arg_pos)
            if arg_pos == 'PS':
                # Arg1 is the previous sentence (the original indexed the current one).
                prev_sent = doc.sentences[sent_id - 1]
                relation.arg1.tokens = prev_sent.tokens
                relation.arg2.tokens = [sent.tokens[i] for i in arg2]
            elif arg_pos == 'SS':
                relation.arg1.tokens = [sent.tokens[i] for i in arg1]
                relation.arg2.tokens = [sent.tokens[i] for i in arg2]
            else:
                logger.error('Unknown Argument Position: ' + arg_pos)
        return relations
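# Illustrative helper (not part of the original module): the marginal-decoding
# pattern used in extract_arguments above, factored out as a minimal sketch.
# `model` is assumed to be a fitted sklearn_crfsuite.CRF and `x` a single
# sequence of feature dicts; np is numpy, as used throughout the listing.
def decode_marginals(model, x):
    # predict_marginals_single returns, per token, a dict {label: probability}.
    marginals = model.predict_marginals_single(x)
    probs = np.array([[p[c] for c in model.classes_] for p in marginals])
    labels = np.array(model.classes_)[probs.argmax(axis=1)]  # most likely label per token
    confidences = probs.max(axis=1)                          # probability of that label
    return labels, confidences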
class GoshArgumentExtractor(Component):
    model_name = 'gosh_arg_extract'
    used_features = ['ptree', 'dtree']

    def __init__(self, window_side_size=2):
        self.window_side_size = window_side_size
        self.arg1_model = CRF(algorithm='l2sgd', verbose=True, min_freq=5)
        self.arg2_model = CRF(algorithm='l2sgd', verbose=True, min_freq=5)

    def load(self, path):
        self.arg1_model = pickle.load(
            open(os.path.join(path, "{}.arg1.p".format(self.model_name)), 'rb'))
        self.arg2_model = pickle.load(
            open(os.path.join(path, "{}.arg2.p".format(self.model_name)), 'rb'))

    def save(self, path):
        pickle.dump(self.arg1_model,
                    open(os.path.join(path, "{}.arg1.p".format(self.model_name)), 'wb'))
        pickle.dump(self.arg2_model,
                    open(os.path.join(path, "{}.arg2.p".format(self.model_name)), 'wb'))

    def fit(self, docs_train: List[Document], docs_val: List[Document] = None):
        (x_arg1, y_arg1), (x_arg2, y_arg2) = generate_pdtb_features(
            docs_train, self.window_side_size)
        self.arg1_model.fit(x_arg1, y_arg1)
        self.arg2_model.fit(x_arg2, y_arg2)

    def score_on_features(self, x_arg1, y_arg1, x_arg2, y_arg2):
        y_pred = self.arg1_model.predict(x_arg1)
        y_pred = [decode_iob(s) for s in y_pred]
        y_arg1 = [decode_iob(s) for s in y_arg1]
        y_pred = np.concatenate(y_pred)
        y_arg1 = np.concatenate(y_arg1)
        logger.info("Evaluation: Arg1 Model")
        logger.info("  Acc  : {:<06.4}".format(accuracy_score(y_arg1, y_pred)))
        prec, recall, f1, support = precision_recall_fscore_support(
            y_arg1, y_pred, average='macro')
        logger.info("  Macro: P {:<06.4} R {:<06.4} F1 {:<06.4}".format(prec, recall, f1))
        logger.info("  Kappa: {:<06.4}".format(cohen_kappa_score(y_arg1, y_pred)))

        y_pred = self.arg2_model.predict(x_arg2)
        y_pred = [decode_iob(s) for s in y_pred]
        y_arg2 = [decode_iob(s) for s in y_arg2]
        y_pred = np.concatenate(y_pred)
        y_arg2 = np.concatenate(y_arg2)
        logger.info("Evaluation: Arg2 Model")
        logger.info("  Acc  : {:<06.4}".format(accuracy_score(y_arg2, y_pred)))
        prec, recall, f1, support = precision_recall_fscore_support(
            y_arg2, y_pred, average='macro')
        logger.info("  Macro: P {:<06.4} R {:<06.4} F1 {:<06.4}".format(prec, recall, f1))
        logger.info("  Kappa: {:<06.4}".format(cohen_kappa_score(y_arg2, y_pred)))

    def score(self, docs: List[Document]):
        (x_arg1, y_arg1), (x_arg2, y_arg2) = generate_pdtb_features(
            docs, self.window_side_size)
        self.score_on_features(x_arg1, y_arg1, x_arg2, y_arg2)

    def extract_arguments(self, doc: Document, relation: Relation):
        conn = [t.local_idx for t in relation.conn.tokens]
        arg2_sentence_id = relation.arg2.get_sentence_idxs()[0]
        sent_features = []
        # Collect token features from a window of sentences around the Arg2 sentence;
        # the relation's (previously classified) sense is part of the feature set.
        for i in range(-self.window_side_size, self.window_side_size + 1):
            sent_idx = arg2_sentence_id + i
            if sent_idx < 0 or sent_idx >= len(doc.sentences):
                continue
            sent_i = doc.sentences[sent_idx]
            ptree_i = sent_i.get_ptree()
            if not ptree_i:
                continue
            dtree_i = sent_i.dependencies
            sent_features.extend(
                get_features(ptree_i, dtree_i, conn, relation.senses[0],
                             sent_i.tokens[0].local_idx))
        indices = []
        for i in sent_features:
            indices.append(i['idx'])
            del i['idx']
        indices = np.array(indices)

        arg2_probs = np.array(
            [[p[c] for c in self.arg2_model.classes_]
             for p in self.arg2_model.predict_marginals_single(sent_features)])
        arg2_probs_max = arg2_probs.max(1)
        arg2_labels = np.array(self.arg2_model.classes_)[arg2_probs.argmax(axis=1)]
        arg2_labels = np.array(decode_iob(arg2_labels))
        # The Arg2 decisions are fed back in as a feature for the Arg1 model.
        for i, arg2_label in zip(sent_features, arg2_labels):
            i['is_arg2'] = (arg2_label == 'Arg2')

        arg1_probs = np.array(
            [[p[c] for c in self.arg1_model.classes_]
             for p in self.arg1_model.predict_marginals_single(sent_features)])
        arg1_probs_max = arg1_probs.max(1)
        arg1_labels = np.array(self.arg1_model.classes_)[arg1_probs.argmax(axis=1)]
        arg1_labels = np.array(decode_iob(arg1_labels))

        arg1 = indices[np.where(arg1_labels == 'Arg1')[0]]
        arg2 = indices[np.where(arg2_labels == 'Arg2')[0]]
        arg1_prob = arg1_probs_max[np.where(arg1_labels == 'Arg1')[0]].mean() if len(arg1) else 0.0
        arg2_prob = arg2_probs_max[np.where(arg2_labels == 'Arg2')[0]].mean() if len(arg2) else 0.0
        arg1, arg2 = arg1.tolist(), arg2.tolist()
        return arg1, arg2, arg1_prob, arg2_prob

    def parse(self, doc: Document, relations: List[Relation] = None, **kwargs):
        if relations is None:
            raise ValueError('Component needs connectives already classified.')
        for relation in filter(lambda r: r.type == "Explicit", relations):
            sent_id = relation.conn.get_sentence_idxs()[0]
            sent = doc.sentences[sent_id]
            ptree = sent.get_ptree()
            if ptree is None or len(relation.conn.tokens) == 0:
                continue
            # extract_arguments works on the whole document; the original passed
            # `ptree` here, which does not match the method's signature.
            arg1, arg2, arg1_c, arg2_c = self.extract_arguments(doc, relation)
            relation.arg1.tokens = [sent.tokens[i] for i in arg1]
            relation.arg2.tokens = [sent.tokens[i] for i in arg2]
        return relations
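# Usage sketch (illustrative, not part of the original code): how either
# extractor is trained, persisted, scored and applied.  `extractor` is an
# instance of CRFArgumentExtractor or GoshArgumentExtractor; `train_docs`,
# `dev_docs` and `test_docs` are assumed to be lists of the project's
# `Document` objects, and `doc.relations` is assumed to hold Explicit
# relations whose connective tokens were identified by an upstream component.
def train_and_apply(extractor, train_docs, dev_docs, test_docs, model_dir='models/'):
    extractor.fit(train_docs)    # trains the underlying CRF model(s)
    extractor.save(model_dir)    # pickles them next to the model_name prefix
    extractor.score(dev_docs)    # logs accuracy, macro P/R/F1 and Cohen's kappa
    return [extractor.parse(doc, relations=doc.relations) for doc in test_docs]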
# External imports assumed by this listing: os, joblib, pathlib.Path,
# typing (Any, Dict, Iterable, List), nltk's pos_tag, sklearn_crfsuite.CRF and
# sklearn.metrics.accuracy_score.  Entity and NLTKSplitter are project-internal.
class CrfEntityExtractor:
    """CRF-based entity extractor over BILOU-tagged tokens."""

    __DIRNAME = Path(os.path.dirname(__file__))

    __FEATURES_SET = [
        ['low', 'title', 'upper', 'pos', 'pos2'],
        ['low', 'word3', 'word2', 'upper', 'title', 'digit', 'pos', 'pos2'],
        ['low', 'title', 'upper', 'pos', 'pos2'],
    ]
    __HALF_FEATURES_SPAN = len(__FEATURES_SET) // 2

    __CONFIG = {
        'max_iterations': 40,
        'L1_c': 1e-3,
        'L2_c': 1e-2,
    }

    __FEATURES_RANGE = range(-__HALF_FEATURES_SPAN, __HALF_FEATURES_SPAN + 1)
    __PREFIXES = [str(i) for i in __FEATURES_RANGE]

    __FUNCTION_DICT = {
        'low': lambda doc: doc[0].lower(),
        'title': lambda doc: doc[0].istitle(),
        'word3': lambda doc: doc[0][-3:],
        'word2': lambda doc: doc[0][-2:],
        'word1': lambda doc: doc[0][-1:],
        'pos': lambda doc: doc[1],
        'pos2': lambda doc: doc[1][:2],
        'bias': lambda doc: 'bias',
        'upper': lambda doc: doc[0].isupper(),
        'digit': lambda doc: doc[0].isdigit(),
    }

    def __init__(self):
        self.__crf_model = None

    def fit(self, train_data: Iterable[str], labels: Iterable[Iterable[str]]):
        """Trains the CRF on raw sentences with per-token labels.

        :param train_data: training sentences (raw text)
        :param labels: per-token labels in BIO or BILOU notation
        :return: self
        """
        crf_dataset = self.__create_dataset(train_data, labels)
        features = [
            self.__convert_idata_to_features(message_data)
            for message_data in crf_dataset
        ]
        labels = [
            self.__extract_labels_from_data(message_data)
            for message_data in crf_dataset
        ]
        self.__crf_model = CRF(
            algorithm='lbfgs',
            c1=self.__CONFIG['L1_c'],
            c2=self.__CONFIG['L2_c'],
            max_iterations=self.__CONFIG['max_iterations'],
            all_possible_transitions=True,
        )
        self.__crf_model.fit(features, labels)
        return self

    def predict(self, text: str) -> List['Entity']:
        """Predicts entities in text.

        :param text: raw input text
        :return: extracted entities
        """
        tokens = self.__preprocess(text)
        intermediate_data = self.__convert_to_idata_format(tokens)
        features = self.__convert_idata_to_features(intermediate_data)
        predicts = self.__crf_model.predict_marginals_single(features)
        entities = self.__get_entities_from_predict(tokens, predicts)
        return entities

    def evaluate(self,
                 test_data: Iterable[str],
                 labels: Iterable[Iterable[str]],
                 metric: str = 'accuracy'):
        """Evaluates accuracy on test data.

        :param test_data: test sentences (raw text)
        :param labels: per-token gold labels
        :param metric: currently unused; token-level accuracy is always reported
        :return: accuracy score
        """
        labels = self.__process_test_labels(labels)
        predicted_entities = [self.predict(sentence) for sentence in test_data]
        processed_predicted_entities = [
            self.__postprocess(sent_entities, self.__preprocess(sentence))
            for (sent_entities, sentence) in zip(predicted_entities, test_data)
        ]
        all_predicted = self.__get_flatten_values(processed_predicted_entities)
        all_labels = self.__get_flatten_values(labels)
        return accuracy_score(all_labels, all_predicted)

    def load_model(self, path: Path) -> 'CrfEntityExtractor':
        """Loads saved model.

        :param path: path where model was saved
        :return: self
        """
        self.__crf_model = joblib.load(path)
        return self

    def save_model(self, path: Path) -> None:
        joblib.dump(self.__crf_model, path)

    def __create_dataset(self, sentences, labels):
        dataset_message_format = [
            self.__convert_to_idata_format(self.__preprocess(sentence), sentence_labels)
            for sentence, sentence_labels in zip(sentences, labels)
        ]
        return dataset_message_format

    def __convert_to_idata_format(self, tokens, entities=None):
        message_data_intermediate_format = []
        for i, token in enumerate(tokens):
            entity = entities[i] if (entities and len(entities) > i) else "N/A"
            tag = self.__get_tag_of_token(token.value)
            message_data_intermediate_format.append((token.value, tag, entity))
        return message_data_intermediate_format

    def __get_entities_from_predict(self, tokens, predicts):
        entities = []
        cur_token_ind: int = 0
        while cur_token_ind < len(tokens):
            end_ind, confidence, entity_label = self.__handle_bilou_label(
                cur_token_ind, predicts)
            if end_ind is not None:
                current_tokens = tokens[cur_token_ind:end_ind + 1]
                entity_value: str = ' '.join(
                    [token.value for token in current_tokens])
                entity = Entity(name=entity_label,
                                value=entity_value,
                                start_token=cur_token_ind,
                                end_token=end_ind,
                                start=current_tokens[0].start,
                                end=current_tokens[-1].end)
                entities.append(entity)
                cur_token_ind = end_ind + 1
            else:
                cur_token_ind += 1
        return entities

    def __handle_bilou_label(self, token_index, predicts):
        label, confidence = self.__get_most_likely_entity(token_index, predicts)
        entity_label = self.__convert_to_ent_name(label)
        extracted = self.__extract_bilou_prefix(label)
        if extracted == "U":
            # Unit-length entity: it starts and ends at the same token.
            return token_index, confidence, entity_label
        if extracted == "B":
            end_token_index, confidence = self.__find_bilou_end(token_index, predicts)
            return end_token_index, confidence, entity_label
        else:
            return None, None, None

    def __find_bilou_end(self, token_index: int, predicts):
        end_token_ind: int = token_index + 1
        finished: bool = False
        label, confidence = self.__get_most_likely_entity(token_index, predicts)
        entity_label: str = self.__convert_to_ent_name(label)
        while not finished:
            label, label_confidence = self.__get_most_likely_entity(
                end_token_ind, predicts)
            confidence = min(confidence, label_confidence)
            if self.__convert_to_ent_name(label) != entity_label:
                if self.__extract_bilou_prefix(label) == 'L':
                    finished = True
                if self.__extract_bilou_prefix(label) == 'I':
                    end_token_ind += 1
                else:
                    finished = True
                    end_token_ind -= 1
            else:
                end_token_ind += 1
        return end_token_ind, confidence

    def __mark_positions_by_labels(self, entities_labels, positions, name: str):
        if len(positions) == 1:
            entities_labels = self.__set_label(entities_labels, positions[0], 'U', name)
        else:
            entities_labels = self.__set_label(entities_labels, positions[0], 'B', name)
            entities_labels = self.__set_label(entities_labels, positions[-1], 'L', name)
            for ind in positions[1:-1]:
                entities_labels = self.__set_label(entities_labels, ind, 'I', name)
        return entities_labels

    def __get_example_features(self, data, example_index):
        """Extracts features from one example in the intermediate data format.

        :param data: list of examples in task-specific format
        :param example_index: index of the central example
        :return: dict of features extracted from the example and its context window
        """
        message_len = len(data)
        example_features = {}
        for feature_offset in self.__FEATURES_RANGE:
            if example_index + feature_offset >= message_len:
                example_features['EOS'] = True
            elif example_index + feature_offset < 0:
                example_features['BOS'] = True
            else:
                example = data[example_index + feature_offset]
                shifted_offset = feature_offset + self.__HALF_FEATURES_SPAN
                prefix = self.__PREFIXES[shifted_offset]
                features = self.__FEATURES_SET[shifted_offset]
                for feature in features:
                    value = self.__FUNCTION_DICT[feature](example)
                    example_features[f'{prefix}:{feature}'] = value
        return example_features

    def __convert_idata_to_features(self, data):
        """Extracts features from examples in the intermediate data format.

        :param data: list of examples in the intermediate format
        :return: list of feature dicts, one per example
        """
        features = []
        for ind, example in enumerate(data):
            example_features: Dict[str, Any] = self.__get_example_features(data, ind)
            features.append(example_features)
        return features

    def __get_most_likely_entity(self, ind: int, predicts):
        if len(predicts) > ind:
            entity_probs = predicts[ind]
        else:
            entity_probs = None
        if entity_probs:
            label: str = max(entity_probs, key=lambda key: entity_probs[key])
            # Confidence is the probability mass of the winning entity name
            # summed over all of its BILOU variants.
            confidence = sum(
                [v for k, v in entity_probs.items() if k[2:] == label[2:]])
            return label, confidence
        else:
            return '', 0.0

    def __convert_to_ent_name(self, bilou_ent_name: str) -> str:
        """Gets the entity name from its BILOU label representation.

        :param bilou_ent_name: BILOU entity name
        :return: entity name without the BILOU prefix
        """
        return bilou_ent_name[2:]

    def __extract_bilou_prefix(self, label: str):
        """Gets the BILOU prefix (the character before '-') from a label.

        Returns None if the label carries no prefix.

        :param label: BILOU entity name
        :return: BILOU prefix or None
        """
        if len(label) >= 2 and label[1] == "-":
            return label[0].upper()
        return None

    def __process_test_labels(self, test_labels):
        return [
            self.__to_dict([
                label[2:] if (label.startswith('B-') or label.startswith('I-')) else label
                for label in sent_labels
            ])
            for sent_labels in test_labels
        ]

    @staticmethod
    def __preprocess(text: str) -> List[str]:
        """Drops the trailing EOS token and splits the text into tokens.

        :param text: raw input text
        :return: tokens
        """
        if text.endswith('EOS'):
            return NLTKSplitter().process(text)[:-1]
        else:
            return NLTKSplitter().process(text)

    @staticmethod
    def __get_tag_of_token(token: str) -> str:
        """Gets the part-of-speech tag for a token.

        :param token: token surface form
        :return: POS tag
        """
        tag = pos_tag([token])[0][1]
        return tag

    @staticmethod
    def __extract_labels_from_data(data: Iterable) -> List[str]:
        return [label for _, _, label in data]

    @staticmethod
    def __set_label(entities_labels, ind: int, prefix: str, name: str):
        entities_labels[ind] = f'{prefix}-{name}'
        return entities_labels

    @staticmethod
    def __to_dict(sent_labels):
        return dict(enumerate(sent_labels))

    @staticmethod
    def __postprocess(entities, sentence):
        entities_dict = {k: 'O' for k in range(len(sentence))}
        for entity in entities:
            for key in range(entity.start_token, entity.end_token + 1):
                entities_dict[key] = entity.name
        return entities_dict

    @staticmethod
    def __get_flatten_values(dicts):
        return [word for sentence in dicts for word in sentence.values()]
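# Usage sketch (illustrative, not part of the original class): training the
# extractor on BILOU-tagged sentences and running it on new text.  Labels are
# one tag per token; the token counts below assume NLTKSplitter tokenizes the
# example sentences word by word, and 'city' is a made-up entity type.
def demo_crf_entity_extractor():
    train_sentences = [
        'book a flight to New York',
        'show me hotels in Paris',
    ]
    train_labels = [
        ['O', 'O', 'O', 'O', 'B-city', 'L-city'],
        ['O', 'O', 'O', 'O', 'U-city'],
    ]
    extractor = CrfEntityExtractor().fit(train_sentences, train_labels)
    extractor.save_model('crf_entities.joblib')
    # predict returns Entity objects with the fields used in __get_entities_from_predict.
    for entity in extractor.predict('find me a flight to Boston'):
        print(entity.name, entity.value, entity.start_token, entity.end_token)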