def store_predictions(self):
    """Dump the evaluator's predictions for ``self._dataset`` to a JSON file.

    Document ``k`` of ``self._dataset.documents`` is paired with
    ``self._pred_entities[k]`` and ``self._pred_relations[k]``. Each document
    becomes a dict with three keys:

    * ``tokens``: the ``phrase`` of every document token;
    * ``entities``: ``{type, start, end}`` records. ``start`` and ``end`` are
      token ``index`` values, with ``end`` exclusive. Records are sorted by
      ``start``;
    * ``relations``: ``{type, head, tail}`` records. ``head`` and ``tail`` are
      positions in that document's sorted ``entities`` list. Records are
      sorted by ``head``.

    The output file is ``self._predictions_path % (self._dataset_label, self._epoch)``.

    Raises:
        ValueError: if a relation endpoint matches no predicted entity
            (raised by ``list.index``).
    """

    def span_record(doc_tokens, span, type_id):
        # Resolve the span to the tokens it covers, then record first/last token indices.
        covered = util.get_span_tokens(doc_tokens, span)
        return dict(type=type_id, start=covered[0].index, end=covered[-1].index + 1)

    predictions = []
    for doc_idx, doc in enumerate(self._dataset.documents):
        doc_tokens = doc.tokens
        entities_in = self._pred_entities[doc_idx]
        relations_in = self._pred_relations[doc_idx]

        entity_records = sorted(
            (span_record(doc_tokens, ent[:2], ent[2].identifier) for ent in entities_in),
            key=lambda rec: rec['start'],
        )

        relation_records = []
        for rel in relations_in:
            head, tail = rel[:2]
            head_type, tail_type = head[2].identifier, tail[2].identifier
            head_rec = span_record(doc_tokens, head[:2], head_type)
            tail_rec = span_record(doc_tokens, tail[:2], tail_type)
            # Endpoints are referenced by position in the sorted entity list (first match).
            relation_records.append(dict(
                type=rel[2].identifier,
                head=entity_records.index(head_rec),
                tail=entity_records.index(tail_rec),
            ))
        relation_records.sort(key=lambda rec: rec['head'])

        predictions.append(dict(
            tokens=[tok.phrase for tok in doc_tokens],
            entities=entity_records,
            relations=relation_records,
        ))

    dataset_label, epoch = self._dataset_label, self._epoch
    target_path = self._predictions_path % (dataset_label, epoch)
    with open(target_path, 'w') as out_file:
        json.dump(predictions, out_file)
def _entity_span_key(tokens, entity):
    """Return ``(type, start, end)`` token-index key for one predicted entity tuple.

    ``entity[:2]`` is the span passed to ``util.get_span_tokens``, and
    ``entity[2]`` is the type object whose ``identifier`` is recorded.
    ``start`` is the ``index`` of the first covered token. ``end`` is one past
    the ``index`` of the last covered token, so it is exclusive.
    """
    span_tokens = util.get_span_tokens(tokens, entity[:2])
    return entity[2].identifier, span_tokens[0].index, span_tokens[-1].index + 1


def _convert_scored_entities(tokens, entities, entity_scores):
    """Build one document's entity records sorted by ``start``, plus their lookup keys.

    Returns ``(records, keys)``. ``records[k]`` is the output dict
    ``{type, probs, start, end}``, where ``probs`` is ``str()`` of the score at
    the entity's position. ``keys[k]`` is the matching ``(type, start, end)``
    tuple. Both lists come from a single sort, so positions resolved against
    ``keys`` always address the right entry in ``records``.
    """
    records = []
    for idx, entity in enumerate(entities):
        entity_type, start, end = _entity_span_key(tokens, entity)
        records.append(dict(type=entity_type, probs=str(entity_scores[idx]), start=start, end=end))
    # list.sort is stable, so entities sharing a start keep their prediction order.
    records.sort(key=lambda e: e['start'])
    keys = [(e['type'], e['start'], e['end']) for e in records]
    return records, keys


def _convert_scored_relations(tokens, relations, relation_scores, entity_keys):
    """Build one document's relation records ``{type, head, tail, probs}``, sorted by ``head``.

    ``head`` and ``tail`` are positions in ``entity_keys``, i.e. in the
    document's sorted entity list. The first matching entity wins. ``probs``
    is ``str()`` of the score at the relation's position.

    Raises:
        ValueError: if a relation endpoint matches no predicted entity.
    """
    records = []
    for idx, relation in enumerate(relations):
        head, tail = relation[:2]
        head_key = _entity_span_key(tokens, head)
        tail_key = _entity_span_key(tokens, tail)
        records.append(dict(type=relation[2].identifier,
                            head=entity_keys.index(head_key),
                            tail=entity_keys.index(tail_key),
                            probs=str(relation_scores[idx])))
    records.sort(key=lambda r: r['head'])
    return records


def store_predictions(documents, pred_entities, pred_relations, store_path,
                      pred_entities_scores, pred_relations_scores):
    """Write scored entity/relation predictions for ``documents`` to ``store_path`` as JSON.

    Args:
        documents: sequence of documents exposing ``tokens``. Each token has
            ``index`` and ``phrase``.
        pred_entities: per-document predicted entity tuples, in the same order
            as ``documents``.
        pred_relations: per-document predicted relation tuples
            ``(head_entity, tail_entity, relation_type, ...)``.
        store_path: path of the JSON file to (over)write.
        pred_entities_scores: per-document scores parallel to ``pred_entities``.
        pred_relations_scores: per-document scores parallel to ``pred_relations``.

    The file holds a list with one dict per document. ``tokens`` holds the
    token phrases. ``entities`` holds ``{type, probs, start, end}`` records
    sorted by ``start``. ``relations`` holds ``{type, head, tail, probs}``
    records sorted by ``head``, where ``head`` and ``tail`` index into
    ``entities``.

    Raises:
        ValueError: if a relation endpoint matches no predicted entity.
    """
    predictions = []

    for i, doc in enumerate(documents):
        tokens = doc.tokens
        entities, entity_keys = _convert_scored_entities(
            tokens, pred_entities[i], pred_entities_scores[i])
        relations = _convert_scored_relations(
            tokens, pred_relations[i], pred_relations_scores[i], entity_keys)
        predictions.append(dict(tokens=[t.phrase for t in tokens],
                                entities=entities,
                                relations=relations))

    # store as json
    with open(store_path, 'w') as predictions_file:
        print("the predict element is stored at:", store_path)
        json.dump(predictions, predictions_file)