def add_sidebar_param_split_dir() -> SplitDir:
    """Render a sidebar text input for the Power Split Dir path, check it, and return the handle."""

    path_str = st.sidebar.text_input('Path to Power Split Directory', value='data/power/split/cde-50/')

    handle = SplitDir(Path(path_str))
    handle.check()

    return handle
def create_anyburl_dataset(args):
    """Convert the POWER train facts into an AnyBURL-compatible facts TSV.

    Reads entity/relation labels from the split directory and writes one fact
    per train triple, with each id suffixed by its identifier-safe label.
    """

    overwrite = args.overwrite

    # Check that (input) IRT Split Directory exists
    logging.info('Check that (input) IRT Split Directory exists ...')

    split_dir = SplitDir(Path(args.split_dir))
    split_dir.check()

    # Check that (output) AnyBURL Facts TSV does not exist
    logging.info('Check that (output) AnyBURL Facts TSV does not exist ...')

    facts_tsv = FactsTsv(Path(args.facts_tsv))
    if not overwrite:
        facts_tsv.check(should_exist=False)
    facts_tsv.path.parent.mkdir(parents=True, exist_ok=True)

    # Create AnyBURL Facts TSV
    logging.info('Create AnyBURL Facts TSV ...')

    ent_to_lbl = split_dir.entities_tsv.load()
    rel_to_lbl = split_dir.relations_tsv.load()

    def escape(text):
        # Replace every non-alphanumeric character so labels stay identifier-safe
        return re.sub('[^0-9a-zA-Z]', '_', text)

    def stringify_ent(ent):
        return f'{ent}_{escape(ent_to_lbl[ent])}'

    def stringify_rel(rel):
        return f'{rel}_{escape(rel_to_lbl[rel])}'

    anyburl_facts = [
        Fact(stringify_ent(head), stringify_rel(rel), stringify_ent(tail))
        for head, _, rel, _, tail, _ in split_dir.train_facts_tsv.load()
    ]

    facts_tsv.save(anyburl_facts)
def prepare_ruler(args):
    """Ground AnyBURL rules against the Neo4j graph and persist a POWER Ruler.

    Loads the mined rules, keeps confident single-body rules, queries the
    graph for facts matched by each rule body, derives the facts predicted by
    the rule head, and stores the resulting
    ``head entity -> (rel, tail) -> [rules]`` mapping as a Ruler PKL.
    """

    rules_tsv_path = args.rules_tsv
    url = args.url
    username = args.username
    password = args.password
    split_dir_path = args.split_dir
    ruler_pkl_path = args.ruler_pkl

    min_conf = args.min_conf
    min_supp = args.min_supp
    overwrite = args.overwrite

    #
    # Check that (input) POWER Rules TSV exists
    #

    logging.info('Check that (input) POWER Rules TSV exists ...')

    rules_tsv = RulesTsv(Path(rules_tsv_path))
    rules_tsv.check()

    #
    # Check that (input) POWER Split Directory exists
    #

    # Fixed typo: message previously said 'POWERT Split Directory'
    logging.info('Check that (input) POWER Split Directory exists ...')

    split_dir = SplitDir(Path(split_dir_path))
    split_dir.check()

    #
    # Check that (output) POWER Ruler PKL does not exist
    #

    logging.info('Check that (output) POWER Ruler PKL does not exist ...')

    ruler_pkl = RulerPkl(Path(ruler_pkl_path))

    # Fixed: was check(should_exist=overwrite), which required the output file
    # to already exist whenever --overwrite was given. --overwrite now simply
    # skips the "must not exist" check, consistent with create_anyburl_dataset.
    if not overwrite:
        ruler_pkl.check(should_exist=False)

    #
    # Read rules
    #

    logging.info('Read rules ...')

    ent_to_lbl = split_dir.entities_tsv.load()
    rel_to_lbl = split_dir.relations_tsv.load()

    anyburl_rules = rules_tsv.load()
    rules = [Rule.from_anyburl(rule, ent_to_lbl, rel_to_lbl) for rule in anyburl_rules]

    # Keep confident, well-supported rules, most confident first
    good_rules = [rule for rule in rules if rule.conf >= min_conf and rule.fires >= min_supp]
    good_rules.sort(key=lambda rule: rule.conf, reverse=True)

    # Only rules with a single body atom are supported below
    short_rules = [rule for rule in good_rules if len(rule.body) == 1]

    log_rules('Rules', short_rules)

    #
    # Load train facts
    #

    logging.info('Load train facts ...')

    train_triples = split_dir.train_facts_tsv.load()
    # NOTE(review): only referenced by the commented-out filter further down;
    # kept so that filter can be re-enabled without re-deriving it.
    train_facts = {Fact.from_ints(head, rel, tail, ent_to_lbl, rel_to_lbl)
                   for head, _, rel, _, tail, _ in train_triples}

    #
    # Process rules
    #

    logging.info('Process rules ...')

    driver = GraphDatabase.driver(url, auth=(username, password))

    unsupported_rules = 0
    pred = defaultdict(get_defaultdict)

    with driver.session() as session:

        # A progress bar would clutter DEBUG output
        if logging.getLogger().level == logging.DEBUG:
            iter_short_rules = short_rules
        else:
            iter_short_rules = tqdm(short_rules)

        for rule in iter_short_rules:
            logging.debug(f'Process rule {rule}')

            #
            # Process rule body: fetch the entities bound by the body's variable
            #

            body_fact = rule.body[0]

            if type(body_fact.head) is Var and type(body_fact.tail) is Ent:
                records = session.write_transaction(query_facts_by_rel_tail, rel=body_fact.rel, tail=body_fact.tail)
                ents = [Ent(head['id'], ent_to_lbl[head['id']]) for head, _, _ in records]
            elif type(body_fact.head) is Ent and type(body_fact.tail) is Var:
                records = session.write_transaction(query_facts_by_head_rel, head=body_fact.head, rel=body_fact.rel)
                ents = [Ent(tail['id'], ent_to_lbl[tail['id']]) for _, _, tail in records]
            else:
                logging.debug(f'Unsupported rule body in rule {rule}. Skipping.')
                unsupported_rules += 1
                continue

            #
            # Process rule head: substitute the bound entities into the head
            #

            head_fact = rule.head

            if type(head_fact.head) is Var and type(head_fact.tail) is Ent:
                pred_facts = [Fact(ent, head_fact.rel, head_fact.tail) for ent in ents]
            elif type(head_fact.head) is Ent and type(head_fact.tail) is Var:
                pred_facts = [Fact(head_fact.head, head_fact.rel, ent) for ent in ents]
            else:
                logging.debug(f'Unsupported rule head in rule {rule}. Skipping.')
                unsupported_rules += 1
                continue

            #
            # Save predicted facts (train facts deliberately NOT filtered out)
            #

            for fact in pred_facts:
                # if fact not in train_facts:
                pred[fact.head][(fact.rel, fact.tail)].append(rule)

    driver.close()

    #
    # Persist ruler
    #

    logging.info('Persist ruler ...')

    ruler = Ruler()
    ruler.pred = pred

    ruler_pkl.save(ruler)
def eval_texter(args):
    """Evaluate a trained POWER Texter on the valid (or test) open-world entities.

    Per entity: predict facts from its sentences, compare against the ground
    truth, and log AP and precision/recall/F1/support. Afterwards log mAP plus
    macro- and micro-averaged metrics over all evaluated entities.
    """

    texter_pkl_path = args.texter_pkl
    sent_count = args.sent_count
    split_dir_path = args.split_dir
    text_dir_path = args.text_dir

    filter_known = args.filter_known  # drop facts already known at train time
    test = args.test  # evaluate on the test split instead of the valid split

    #
    # Check that (input) POWER Texter PKL exists
    #

    logging.info('Check that (input) POWER Texter PKL exists ...')

    texter_pkl = TexterPkl(Path(texter_pkl_path))
    texter_pkl.check()

    #
    # Check that (input) POWER Split Directory exists
    #

    logging.info('Check that (input) POWER Split Directory exists ...')

    split_dir = SplitDir(Path(split_dir_path))
    split_dir.check()

    #
    # Check that (input) IRT Text Directory exists
    #

    logging.info('Check that (input) IRT Text Directory exists ...')

    text_dir = TextDir(Path(text_dir_path))
    text_dir.check()

    #
    # Load texter
    #

    logging.info('Load texter ...')

    # Evaluation runs on CPU
    texter = texter_pkl.load().cpu()

    #
    # Load facts
    #

    logging.info('Load facts ...')

    ent_to_lbl = split_dir.entities_tsv.load()
    rel_to_lbl = split_dir.relations_tsv.load()

    if test:
        known_test_facts = split_dir.test_facts_known_tsv.load()
        unknown_test_facts = split_dir.test_facts_unknown_tsv.load()

        known_eval_facts = known_test_facts
        all_eval_facts = known_test_facts + unknown_test_facts
    else:
        known_valid_facts = split_dir.valid_facts_known_tsv.load()
        unknown_valid_facts = split_dir.valid_facts_unknown_tsv.load()

        known_eval_facts = known_valid_facts
        all_eval_facts = known_valid_facts + unknown_valid_facts

    known_facts = {Fact.from_ints(head, rel, tail, ent_to_lbl, rel_to_lbl)
                   for head, _, rel, _, tail, _ in known_eval_facts}

    all_eval_facts = {Fact.from_ints(head, rel, tail, ent_to_lbl, rel_to_lbl)
                      for head, _, rel, _, tail, _ in all_eval_facts}

    #
    # Load entities
    #

    logging.info('Load entities ...')

    if test:
        eval_ents = split_dir.test_entities_tsv.load()
    else:
        eval_ents = split_dir.valid_entities_tsv.load()

    eval_ents = [Ent(ent, lbl) for ent, lbl in eval_ents.items()]

    #
    # Load texts
    #

    logging.info('Load texts ...')

    if test:
        eval_ent_to_sents = text_dir.ow_test_sents_txt.load()
    else:
        eval_ent_to_sents = text_dir.ow_valid_sents_txt.load()

    #
    # Evaluate
    #

    # Pooled per-fact booleans for micro metrics
    all_gt_bools = []
    all_pred_bools = []

    # Per-entity results for macro metrics / mAP
    all_prfs = []
    all_ap = []

    for ent in eval_ents:
        logging.debug(f'Evaluate entity {ent} ...')

        #
        # Predict entity facts
        #

        # Entities with fewer than sent_count sentences are skipped entirely
        sents = list(eval_ent_to_sents[ent.id])[:sent_count]
        if len(sents) < sent_count:
            logging.warning(f'Only {len(sents)} sentences for entity "{ent.lbl}" ({ent.id}). Skipping.')
            continue

        preds: List[Pred] = texter.predict(ent, sents)
        if filter_known:
            preds = [pred for pred in preds if pred.fact not in known_facts]

        logging.debug('Predictions:')
        for pred in preds:
            logging.debug(str(pred))

        #
        # Get entity ground truth facts
        #

        gt_facts = [fact for fact in all_eval_facts if fact.head == ent]
        if filter_known:
            gt_facts = list(set(gt_facts).difference(known_facts))

        logging.debug('Ground truth:')
        for fact in gt_facts:
            logging.debug(str(fact))

        #
        # Calc entity PRFS
        #

        # Score over the union of predicted and ground-truth facts: each fact
        # becomes one binary sample (in prediction? / in ground truth?)
        pred_facts = {pred.fact for pred in preds}
        pred_and_gt_facts = list(pred_facts | set(gt_facts))

        gt_bools = [1 if fact in gt_facts else 0 for fact in pred_and_gt_facts]
        pred_bools = [1 if fact in pred_facts else 0 for fact in pred_and_gt_facts]

        prfs = precision_recall_fscore_support(gt_bools, pred_bools, labels=[1], zero_division=1)
        all_prfs.append(prfs)

        #
        # Add ent results to global results for micro metrics
        #

        all_gt_bools.extend(gt_bools)
        all_pred_bools.extend(pred_bools)

        #
        # Calc entity AP
        #

        pred_fact_conf_tuples = [(pred.fact, pred.conf) for pred in preds]
        ap = calc_ap(pred_fact_conf_tuples, gt_facts)
        all_ap.append(ap)

        #
        # Log entity metrics
        #

        # prfs layout: [0]=precision, [1]=recall, [2]=f1, [3]=support, each per label
        logging.info(f'{str(ent.id):5} {ent.lbl:40}: AP = {ap:.2f}, Prec = {prfs[0][0]:.2f}, Rec = {prfs[1][0]:.2f}, '
                     f'F1 = {prfs[2][0]:.2f}, Supp = {prfs[3][0]}')

    # Mean average precision over all evaluated (non-skipped) entities
    m_ap = sum(all_ap) / len(all_ap)
    logging.info(f'mAP = {m_ap:.4f}')

    # Macro metrics: average the per-entity PRFS arrays
    macro_prfs = np.array(all_prfs).mean(axis=0)
    logging.info(f'Macro Prec = {macro_prfs[0][0]:.4f}')
    logging.info(f'Macro Rec = {macro_prfs[1][0]:.4f}')
    logging.info(f'Macro F1 = {macro_prfs[2][0]:.4f}')
    logging.info(f'Macro Supp = {macro_prfs[3][0]:.2f}')

    # Micro metrics: single PRFS over the pooled per-fact booleans
    micro_prfs = precision_recall_fscore_support(all_gt_bools, all_pred_bools, labels=[1], zero_division=1)
    logging.info(f'Micro Prec = {micro_prfs[0][0]:.4f}')
    logging.info(f'Micro Rec = {micro_prfs[1][0]:.4f}')
    logging.info(f'Micro F1 = {micro_prfs[2][0]:.4f}')
    logging.info(f'Micro Supp = {micro_prfs[3][0]}')
def eval_zero_rule(args):
    """Log dummy-classifier baseline metrics on the test samples.

    For each DummyClassifier strategy, fits one classifier per class column,
    averages accuracy/precision/recall/F1 over 10 prediction runs per column,
    and logs the first, last, and mean per-column results.
    """

    class_count = args.class_count
    sent_count = args.sent_count

    # Check that (input) POWER Samples Directory exists
    logging.info('Check that (input) POWER Samples Directory exists ...')
    samples_dir = SamplesDir(Path(args.samples_dir))
    samples_dir.check()

    # Check that (input) POWER Split Directory exists
    logging.info('Check that (input) POWER Split Directory exists ...')
    split_dir = SplitDir(Path(args.split_dir))
    split_dir.check()

    # Load entity/relation labels
    logging.info('Load entity/relation labels ...')
    ent_to_lbl = split_dir.entities_tsv.load()
    rel_to_lbl = split_dir.relations_tsv.load()

    # Load test dataset
    logging.info('Load test dataset ...')
    test_set = samples_dir.test_samples_tsv.load(class_count, sent_count)

    # Calc class frequencies
    logging.info('Calc class frequencies ...')
    _, _, test_classes_stack, _ = zip(*test_set)
    test_freqs = np.array(test_classes_stack).mean(axis=0)

    # Evaluate each baseline strategy
    logging.info(f'test_freqs = {test_freqs}')
    for strategy in ('uniform', 'stratified', 'most_frequent', 'constant'):
        logging.info(strategy)

        mean_metrics = []
        for _, gt in tqdm(enumerate(np.array(test_classes_stack).T)):
            if strategy == 'constant':
                # Always predict the positive class; fit on a dummy 0/1 pair
                classifier = DummyClassifier(strategy='constant', constant=1)
                classifier.fit([0, 1], [0, 1])
            else:
                classifier = DummyClassifier(strategy=strategy)
                classifier.fit(gt, gt)

            # Average over several runs to smooth the stochastic strategies
            runs = []
            for _ in range(10):
                pred = classifier.predict(gt)
                acc = accuracy_score(gt, pred)
                prec, recall, f1, _ = precision_recall_fscore_support(gt, pred, labels=[1], zero_division=1)
                runs.append((acc, prec[0], recall[0], f1[0]))

            mean_metrics.append(np.mean(runs, axis=0))

        logging.info(mean_metrics[0])
        logging.info(mean_metrics[-1])
        logging.info(np.mean(mean_metrics, axis=0))
def train_texter(args):
    """Train the POWER Texter (multi-label classifier over entity sentence sets).

    Trains on the sampled train set with inverse-frequency class weighting,
    validates each epoch, logs losses and class/macro metrics to Tensorboard,
    and pickles the Texter whenever the validation macro F1 improves.
    """

    samples_dir_path = args.samples_dir
    class_count = args.class_count
    sent_count = args.sent_count
    split_dir_path = args.split_dir
    texter_pkl_path = args.texter_pkl

    batch_size = args.batch_size
    device = args.device
    epoch_count = args.epoch_count
    log_dir = args.log_dir
    log_steps = args.log_steps  # also log metrics per step, not only per epoch
    lr = args.lr
    overwrite = args.overwrite
    sent_len = args.sent_len
    try_batch_size = args.try_batch_size  # dry run: one batch per phase

    #
    # Check that (input) POWER Samples Directory exists
    #

    logging.info('Check that (input) POWER Samples Directory exists ...')

    samples_dir = SamplesDir(Path(samples_dir_path))
    samples_dir.check()

    #
    # Check that (input) POWER Split Directory exists
    #

    logging.info('Check that (input) POWER Split Directory exists ...')

    split_dir = SplitDir(Path(split_dir_path))
    split_dir.check()

    #
    # Check that (output) POWER Texter PKL does not exist
    #

    logging.info('Check that (output) POWER Texter PKL does not exist ...')

    texter_pkl = TexterPkl(Path(texter_pkl_path))
    if not overwrite:
        texter_pkl.check(should_exist=False)

    #
    # Load entity/relation labels
    #

    logging.info('Load entity/relation labels ...')

    ent_to_lbl = split_dir.entities_tsv.load()
    rel_to_lbl = split_dir.relations_tsv.load()

    #
    # Create Texter
    #

    rel_tail_freq_lbl_tuples = samples_dir.classes_tsv.load()
    classes = [(Rel(rel, rel_to_lbl[rel]), Ent(tail, ent_to_lbl[tail]))
               for rel, tail, _, _ in rel_tail_freq_lbl_tuples]

    pre_trained = 'distilbert-base-uncased'
    texter = Texter(pre_trained, classes)

    #
    # Load datasets and create dataloaders
    #

    logging.info('Load datasets and create dataloaders ...')

    train_set = samples_dir.train_samples_tsv.load(class_count, sent_count)
    valid_set = samples_dir.valid_samples_tsv.load(class_count, sent_count)

    def generate_batch(batch: List[Sample]) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Collate samples into tokenized, padded tensors.

        :param batch: [Sample(ent, ent_lbl, [class], [sent])]

        :return: ent_batch: IntTensor[batch_size],
                 tok_lists_batch: IntTensor[batch_size, sent_count, sent_len],
                 masks_batch: IntTensor[batch_size, sent_count, sent_len],
                 classes_batch: IntTensor[batch_size, class_count]
        """

        ent_batch, _, classes_batch, sents_batch = zip(*batch)

        # Shuffle each entity's sentences in place (order augmentation).
        # NOTE(review): this mutates the underlying dataset samples — confirm
        # that is intended.
        for sents in sents_batch:
            shuffle(sents)

        # Tokenize all sentences of the batch in one tokenizer call
        flat_sents_batch = [sent for sents in sents_batch for sent in sents]
        encoded = texter.tokenizer(flat_sents_batch, padding=True, truncation=True, max_length=sent_len,
                                   return_tensors='pt')

        b_size = len(ent_batch)  # usually b_size == batch_size, except for last batch in samples

        tok_lists_batch = encoded.input_ids.reshape(b_size, sent_count, -1)
        masks_batch = encoded.attention_mask.reshape(b_size, sent_count, -1)

        return tensor(ent_batch), tok_lists_batch, masks_batch, tensor(classes_batch)

    train_loader = DataLoader(train_set, batch_size=batch_size, collate_fn=generate_batch, shuffle=True)
    valid_loader = DataLoader(valid_set, batch_size=batch_size, collate_fn=generate_batch)

    #
    # Calc class weights
    #

    logging.info('Calc class weights ...')

    _, _, train_classes_stack, _ = zip(*train_set)
    train_freqs = np.array(train_classes_stack).mean(axis=0)

    # Weight rare classes up (inverse frequency)
    class_weights = tensor(1 / train_freqs)

    #
    # Prepare training
    #

    logging.info('Prepare training ...')

    texter = texter.to(device)

    criterion = BCEWithLogitsLoss(pos_weight=class_weights.to(device))

    # No weight decay for biases and LayerNorm weights (common BERT fine-tuning recipe)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in texter.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in texter.named_parameters() if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]

    optimizer = AdamW(optimizer_grouped_parameters, lr=lr)

    writer = SummaryWriter(log_dir=log_dir)

    #
    # Train
    #

    logging.info('Train ...')

    best_valid_f1 = 0

    # Global progress for Tensorboard
    train_steps = 0
    valid_steps = 0

    for epoch in range(epoch_count):

        epoch_metrics = {
            'train': {'loss': 0.0, 'pred_classes_stack': [], 'gt_classes_stack': []},
            'valid': {'loss': 0.0, 'pred_classes_stack': [], 'gt_classes_stack': []}
        }

        #
        # Train
        #

        texter.train()

        for _, sents_batch, masks_batch, gt_batch in tqdm(train_loader, desc=f'Epoch {epoch}'):
            train_steps += len(sents_batch)

            sents_batch = sents_batch.to(device)
            masks_batch = masks_batch.to(device)
            gt_batch = gt_batch.to(device).float()

            logits_batch = texter(sents_batch, masks_batch)[0]
            loss = criterion(logits_batch, gt_batch)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            #
            # Log metrics
            #

            # logit > 0 corresponds to sigmoid probability > 0.5
            pred_batch = (logits_batch > 0).int()

            step_loss = loss.item()
            step_pred_batch = pred_batch.cpu().numpy().tolist()
            step_gt_batch = gt_batch.cpu().numpy().tolist()

            epoch_metrics['train']['loss'] += step_loss
            epoch_metrics['train']['pred_classes_stack'] += step_pred_batch
            epoch_metrics['train']['gt_classes_stack'] += step_gt_batch

            if log_steps:
                writer.add_scalars('loss', {'train': step_loss}, train_steps)

                step_metrics = {'train': {'pred_classes_stack': step_pred_batch,
                                          'gt_classes_stack': step_gt_batch}}

                log_class_metrics(step_metrics, writer, train_steps, class_count)
                log_macro_metrics(step_metrics, writer, train_steps)

            if try_batch_size:
                break

        #
        # Validate
        #

        # NOTE(review): no torch.no_grad() here — validation builds autograd
        # graphs it never uses; confirm whether that is intended.
        texter.eval()

        for _, sents_batch, masks_batch, gt_batch in tqdm(valid_loader, desc=f'Epoch {epoch}'):
            valid_steps += len(sents_batch)

            sents_batch = sents_batch.to(device)
            masks_batch = masks_batch.to(device)
            gt_batch = gt_batch.to(device).float()

            logits_batch = texter(sents_batch, masks_batch)[0]
            loss = criterion(logits_batch, gt_batch)

            #
            # Log metrics
            #

            pred_batch = (logits_batch > 0).int()

            step_loss = loss.item()
            step_pred_batch = pred_batch.cpu().numpy().tolist()
            step_gt_batch = gt_batch.cpu().numpy().tolist()

            epoch_metrics['valid']['loss'] += step_loss
            epoch_metrics['valid']['pred_classes_stack'] += step_pred_batch
            epoch_metrics['valid']['gt_classes_stack'] += step_gt_batch

            if log_steps:
                writer.add_scalars('loss', {'valid': step_loss}, valid_steps)

                step_metrics = {'valid': {'pred_classes_stack': step_pred_batch,
                                          'gt_classes_stack': step_gt_batch}}

                log_class_metrics(step_metrics, writer, valid_steps, class_count)
                log_macro_metrics(step_metrics, writer, valid_steps)

            if try_batch_size:
                break

        #
        # Log loss
        #

        train_loss = epoch_metrics['train']['loss'] / len(train_loader)
        valid_loss = epoch_metrics['valid']['loss'] / len(valid_loader)

        writer.add_scalars('loss', {'train': train_loss, 'valid': valid_loss}, epoch)

        #
        # Log metrics
        #

        log_class_metrics(epoch_metrics, writer, epoch, class_count)
        valid_f1 = log_macro_metrics(epoch_metrics, writer, epoch)

        #
        # Persist Texter
        #

        # Keep only the best checkpoint by validation macro F1
        if valid_f1 > best_valid_f1:
            best_valid_f1 = valid_f1
            texter_pkl.save(texter)

        if try_batch_size:
            break
def train_texter(args):
    """Train the POWER Aggregator's texter/ruler confidence weights.

    Fits two scalar weights (one scaling Texter confidences, one scaling Ruler
    confidences) by MSE regression against the train ground truth, one SGD
    step per predicted fact.

    TODO(review): this redefines train_texter() and shadows the Texter
    training function of the same name in this module — it trains the
    Aggregator, not the Texter, and should be renamed (e.g. train_power).
    Renaming is deferred here because the CLI wiring is outside this file.
    """

    ruler_pkl_path = args.ruler_pkl
    texter_pkl_path = args.texter_pkl
    sent_count = args.sent_count
    split_dir_path = args.split_dir
    text_dir_path = args.text_dir

    epoch_count = args.epoch_count
    log_dir = args.log_dir
    log_steps = args.log_steps  # NOTE(review): currently unused in this function
    lr = args.lr
    overwrite = args.overwrite  # NOTE(review): currently unused in this function
    sent_len = args.sent_len  # NOTE(review): currently unused in this function

    #
    # Check that (input) POWER Ruler PKL exists
    #

    logging.info('Check that (input) POWER Ruler PKL exists ...')

    ruler_pkl = RulerPkl(Path(ruler_pkl_path))
    ruler_pkl.check()

    #
    # Check that (input) POWER Texter PKL exists
    #

    logging.info('Check that (input) POWER Texter PKL exists ...')

    texter_pkl = TexterPkl(Path(texter_pkl_path))
    texter_pkl.check()

    #
    # Check that (input) POWER Split Directory exists
    #

    logging.info('Check that (input) POWER Split Directory exists ...')

    split_dir = SplitDir(Path(split_dir_path))
    split_dir.check()

    #
    # Check that (input) IRT Text Directory exists
    #

    logging.info('Check that (input) IRT Text Directory exists ...')

    text_dir = TextDir(Path(text_dir_path))
    text_dir.check()

    #
    # Load ruler
    #

    logging.info('Load ruler ...')

    ruler = ruler_pkl.load()

    #
    # Load texter
    #

    logging.info('Load texter ...')

    texter = texter_pkl.load().cpu()

    #
    # Build POWER
    #

    power = Aggregator(texter, ruler)

    #
    # Load facts
    #

    logging.info('Load facts ...')

    ent_to_lbl = split_dir.entities_tsv.load()
    rel_to_lbl = split_dir.relations_tsv.load()

    train_facts = split_dir.train_facts_tsv.load()
    train_facts = {Fact.from_ints(head, rel, tail, ent_to_lbl, rel_to_lbl)
                   for head, _, rel, _, tail, _ in train_facts}

    known_valid_facts = split_dir.valid_facts_known_tsv.load()
    unknown_valid_facts = split_dir.valid_facts_unknown_tsv.load()

    known_eval_facts = known_valid_facts
    all_valid_facts = known_valid_facts + unknown_valid_facts

    # NOTE(review): known_facts and all_valid_facts are built but never used
    # below — confirm whether a validation pass was meant to follow.
    known_facts = {Fact.from_ints(head, rel, tail, ent_to_lbl, rel_to_lbl)
                   for head, _, rel, _, tail, _ in known_eval_facts}

    all_valid_facts = {Fact.from_ints(head, rel, tail, ent_to_lbl, rel_to_lbl)
                       for head, _, rel, _, tail, _ in all_valid_facts}

    #
    # Load entities
    #

    logging.info('Load entities ...')

    train_ents = split_dir.train_entities_tsv.load()
    valid_ents = split_dir.valid_entities_tsv.load()

    train_ents = [Ent(ent, lbl) for ent, lbl in train_ents.items()]
    # NOTE(review): valid_ents is loaded but never used below
    valid_ents = [Ent(ent, lbl) for ent, lbl in valid_ents.items()]

    #
    # Load texts
    #

    logging.info('Load texts ...')

    train_ent_to_sents = text_dir.cw_train_sents_txt.load()
    # NOTE(review): loaded but never used below
    valid_ent_to_sents = text_dir.ow_valid_sents_txt.load()

    #
    # Prepare training
    #

    criterion = MSELoss()

    writer = SummaryWriter(log_dir=log_dir)

    #
    # Train
    #

    logging.info('Train ...')

    texter_optimizer = SGD([power.texter_weight], lr=lr)
    ruler_optimizer = SGD([power.ruler_weight], lr=lr)

    for epoch in range(epoch_count):
        for ent in train_ents:

            # Fixed: replaced leftover bare print() debugging with logging.debug
            logging.debug(f'texter_weight = {power.texter_weight}, ruler_weight = {power.ruler_weight}')

            #
            # Get entity ground truth facts
            #

            gt_facts = [fact for fact in train_facts if fact.head == ent]

            logging.debug('Ground truth:')
            for fact in gt_facts:
                logging.debug(str(fact))

            #
            # Train Texter Weight
            #

            sents = list(train_ent_to_sents[ent.id])[:sent_count]
            if len(sents) < sent_count:
                logging.warning(f'Only {len(sents)} sentences for entity "{ent.lbl}" ({ent.id}). Skipping.')
                continue

            texter_preds = texter.predict(ent, sents)

            train_confs = [pred.conf for pred in texter_preds]
            gt_confs = [1 if pred.fact in gt_facts else 0 for pred in texter_preds]

            # One SGD step per predicted fact: regress weighted confidence to 0/1
            for train_conf, gt_conf in zip(train_confs, gt_confs):
                loss = criterion(torch.tensor(train_conf) * power.texter_weight,
                                 torch.tensor(gt_conf).float())

                texter_optimizer.zero_grad()
                loss.backward()
                texter_optimizer.step()

            #
            # Train Ruler Weight
            #

            ruler_preds = ruler.predict(ent)

            train_confs = [pred.conf for pred in ruler_preds]
            gt_confs = [1 if pred.fact in gt_facts else 0 for pred in ruler_preds]

            for train_conf, gt_conf in zip(train_confs, gt_confs):
                loss = criterion(torch.tensor(train_conf) * power.ruler_weight,
                                 torch.tensor(gt_conf).float())

                ruler_optimizer.zero_grad()
                loss.backward()
                ruler_optimizer.step()