Example #1
def add_sidebar_param_split_dir() -> SplitDir:
    """ Add text input for path to Power Split Dir to sidebar, check Power Split Dir, and return handle """

    split_dir_path = st.sidebar.text_input('Path to Power Split Directory',
                                           value='data/power/split/cde-50/')

    split_dir = SplitDir(Path(split_dir_path))
    split_dir.check()

    return split_dir
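
A minimal usage sketch, assuming a hypothetical Streamlit page that consumes the handle returned above; load_page and the displayed texts are illustrative and not part of the original code.

import streamlit as st


def load_page() -> None:
    """Hypothetical page entry point that uses the sidebar helper above."""

    st.title('POWER Split Browser')

    # Returns a checked SplitDir handle whose TSVs can be loaded on demand
    split_dir = add_sidebar_param_split_dir()

    ent_to_lbl = split_dir.entities_tsv.load()
    st.write(f'{len(ent_to_lbl)} entities in split')
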
Example #2
def create_anyburl_dataset(args):
    split_dir_path = args.split_dir
    facts_tsv_path = args.facts_tsv

    overwrite = args.overwrite

    #
    # Check that (input) IRT Split Directory exists
    #

    logging.info('Check that (input) IRT Split Directory exists ...')

    split_dir = SplitDir(Path(split_dir_path))
    split_dir.check()

    #
    # Check that (output) AnyBURL Facts TSV does not exist
    #

    logging.info('Check that (output) AnyBURL Facts TSV does not exist ...')

    facts_tsv = FactsTsv(Path(facts_tsv_path))
    if not overwrite:
        facts_tsv.check(should_exist=False)

    facts_tsv.path.parent.mkdir(parents=True, exist_ok=True)

    #
    # Create AnyBURL Facts TSV
    #

    logging.info('Create AnyBURL Facts TSV ...')

    ent_to_lbl = split_dir.entities_tsv.load()
    rel_to_lbl = split_dir.relations_tsv.load()

    def escape(text):
        return re.sub('[^0-9a-zA-Z]', '_', text)

    def stringify_ent(ent):
        return f'{ent}_{escape(ent_to_lbl[ent])}'

    def stringify_rel(rel):
        return f'{rel}_{escape(rel_to_lbl[rel])}'

    train_facts = split_dir.train_facts_tsv.load()

    anyburl_facts = [
        Fact(stringify_ent(head), stringify_rel(rel), stringify_ent(tail))
        for head, _, rel, _, tail, _ in train_facts
    ]

    facts_tsv.save(anyburl_facts)
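
The args namespace suggests a thin argparse CLI around create_anyburl_dataset. A sketch of how the arguments might be wired up, assuming flag names that mirror the attribute names; the actual parser may differ.

from argparse import ArgumentParser


def main():
    parser = ArgumentParser(description='Create AnyBURL Facts TSV from an IRT Split Directory')

    parser.add_argument('split_dir', metavar='split-dir',
                        help='Path to (input) IRT Split Directory')
    parser.add_argument('facts_tsv', metavar='facts-tsv',
                        help='Path to (output) AnyBURL Facts TSV')
    parser.add_argument('--overwrite', action='store_true',
                        help='Overwrite an existing AnyBURL Facts TSV')

    args = parser.parse_args()
    create_anyburl_dataset(args)


if __name__ == '__main__':
    main()
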
Example #3
def prepare_ruler(args):
    rules_tsv_path = args.rules_tsv
    url = args.url
    username = args.username
    password = args.password
    split_dir_path = args.split_dir
    ruler_pkl_path = args.ruler_pkl

    min_conf = args.min_conf
    min_supp = args.min_supp
    overwrite = args.overwrite

    #
    # Check that (input) POWER Rules TSV exists
    #

    logging.info('Check that (input) POWER Rules TSV exists ...')

    rules_tsv = RulesTsv(Path(rules_tsv_path))
    rules_tsv.check()

    #
    # Check that (input) POWER Split Directory exists
    #

    logging.info('Check that (input) POWER Split Directory exists ...')

    split_dir = SplitDir(Path(split_dir_path))
    split_dir.check()

    #
    # Check that (output) POWER Ruler PKL does not exist
    #

    logging.info('Check that (output) POWER Ruler PKL does not exist ...')

    ruler_pkl = RulerPkl(Path(ruler_pkl_path))
    if not overwrite:
        ruler_pkl.check(should_exist=False)

    #
    # Read rules
    #

    logging.info('Read rules ...')

    ent_to_lbl = split_dir.entities_tsv.load()
    rel_to_lbl = split_dir.relations_tsv.load()

    anyburl_rules = rules_tsv.load()
    rules = [Rule.from_anyburl(rule, ent_to_lbl, rel_to_lbl) for rule in anyburl_rules]

    good_rules = [rule for rule in rules if rule.conf >= min_conf and rule.fires >= min_supp]
    good_rules.sort(key=lambda rule: rule.conf, reverse=True)

    short_rules = [rule for rule in good_rules if len(rule.body) == 1]
    log_rules('Rules', short_rules)

    #
    # Load train facts
    #

    logging.info('Load train facts ...')

    train_triples = split_dir.train_facts_tsv.load()
    train_facts = {Fact.from_ints(head, rel, tail, ent_to_lbl, rel_to_lbl)
                   for head, _, rel, _, tail, _ in train_triples}

    #
    # Process rules
    #

    logging.info('Process rules ...')

    driver = GraphDatabase.driver(url, auth=(username, password))
    unsupported_rules = 0

    pred = defaultdict(get_defaultdict)

    with driver.session() as session:
        if logging.getLogger().level == logging.DEBUG:
            iter_short_rules = short_rules
        else:
            iter_short_rules = tqdm(short_rules)

        for rule in iter_short_rules:
            logging.debug(f'Process rule {rule}')

            #
            # Process rule body
            #

            body_fact = rule.body[0]

            if type(body_fact.head) == Var and type(body_fact.tail) == Ent:
                records = session.write_transaction(query_facts_by_rel_tail, rel=body_fact.rel, tail=body_fact.tail)
                ents = [Ent(head['id'], ent_to_lbl[head['id']]) for head, _, _ in records]

            elif type(body_fact.head) == Ent and type(body_fact.tail) == Var:
                records = session.write_transaction(query_facts_by_head_rel, head=body_fact.head, rel=body_fact.rel)
                ents = [Ent(tail['id'], ent_to_lbl[tail['id']]) for _, _, tail in records]

            else:
                logging.debug(f'Unsupported rule body in rule {rule}. Skipping.')
                unsupported_rules += 1
                continue

            #
            # Process rule head
            #

            head_fact = rule.head

            if type(head_fact.head) == Var and type(head_fact.tail) == Ent:
                pred_facts = [Fact(ent, head_fact.rel, head_fact.tail) for ent in ents]

            elif type(head_fact.head) == Ent and type(head_fact.tail) == Var:
                pred_facts = [Fact(head_fact.head, head_fact.rel, ent) for ent in ents]

            else:
                logging.debug(f'Unsupported rule head in rule {rule}. Skipping.')
                unsupported_rules += 1
                continue

            #
            # Save predicted facts (filtering of train facts is currently disabled)
            #

            for fact in pred_facts:
                # Re-enable to skip facts already known from training:
                # if fact not in train_facts:
                pred[fact.head][(fact.rel, fact.tail)].append(rule)

    driver.close()

    logging.info(f'Skipped {unsupported_rules} rules with unsupported head or body')

    #
    # Persist ruler
    #

    logging.info('Persist ruler ...')

    ruler = Ruler()
    ruler.pred = pred

    ruler_pkl.save(ruler)
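
pred = defaultdict(get_defaultdict) relies on a module-level factory instead of a lambda, presumably so that the nested defaultdict survives pickling into the Ruler PKL. A sketch of what that helper could look like; the name appears in the original, the body is an assumption.

from collections import defaultdict


def get_defaultdict() -> defaultdict:
    """Factory for the inner mapping (rel, tail) -> [rules that predict it].

    Defined at module level because pickle cannot serialize lambdas,
    and the Ruler carrying this structure is persisted as a PKL.
    """
    return defaultdict(list)
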
Example #4
def eval_texter(args):
    texter_pkl_path = args.texter_pkl
    sent_count = args.sent_count
    split_dir_path = args.split_dir
    text_dir_path = args.text_dir

    filter_known = args.filter_known
    test = args.test

    #
    # Check that (input) POWER Texter PKL exists
    #

    logging.info('Check that (input) POWER Texter PKL exists ...')

    texter_pkl = TexterPkl(Path(texter_pkl_path))
    texter_pkl.check()

    #
    # Check that (input) POWER Split Directory exists
    #

    logging.info('Check that (input) POWER Split Directory exists ...')

    split_dir = SplitDir(Path(split_dir_path))
    split_dir.check()

    #
    # Check that (input) IRT Text Directory exists
    #

    logging.info('Check that (input) IRT Text Directory exists ...')

    text_dir = TextDir(Path(text_dir_path))
    text_dir.check()

    #
    # Load texter
    #

    logging.info('Load texter ...')

    texter = texter_pkl.load().cpu()

    #
    # Load facts
    #

    logging.info('Load facts ...')

    ent_to_lbl = split_dir.entities_tsv.load()
    rel_to_lbl = split_dir.relations_tsv.load()

    if test:
        known_test_facts = split_dir.test_facts_known_tsv.load()
        unknown_test_facts = split_dir.test_facts_unknown_tsv.load()

        known_eval_facts = known_test_facts
        all_eval_facts = known_test_facts + unknown_test_facts

    else:
        known_valid_facts = split_dir.valid_facts_known_tsv.load()
        unknown_valid_facts = split_dir.valid_facts_unknown_tsv.load()

        known_eval_facts = known_valid_facts
        all_eval_facts = known_valid_facts + unknown_valid_facts

    known_facts = {Fact.from_ints(head, rel, tail, ent_to_lbl, rel_to_lbl)
                   for head, _, rel, _, tail, _ in known_eval_facts}

    all_eval_facts = {Fact.from_ints(head, rel, tail, ent_to_lbl, rel_to_lbl)
                      for head, _, rel, _, tail, _ in all_eval_facts}

    #
    # Load entities
    #

    logging.info('Load entities ...')

    if test:
        eval_ents = split_dir.test_entities_tsv.load()
    else:
        eval_ents = split_dir.valid_entities_tsv.load()

    eval_ents = [Ent(ent, lbl) for ent, lbl in eval_ents.items()]

    #
    # Load texts
    #

    logging.info('Load texts ...')

    if test:
        eval_ent_to_sents = text_dir.ow_test_sents_txt.load()
    else:
        eval_ent_to_sents = text_dir.ow_valid_sents_txt.load()

    #
    # Evaluate
    #

    all_gt_bools = []
    all_pred_bools = []

    all_prfs = []
    all_ap = []

    for ent in eval_ents:
        logging.debug(f'Evaluate entity {ent} ...')

        #
        # Predict entity facts
        #

        sents = list(eval_ent_to_sents[ent.id])[:sent_count]
        if len(sents) < sent_count:
            logging.warning(f'Only {len(sents)} sentences for entity "{ent.lbl}" ({ent.id}). Skipping.')
            continue

        preds: List[Pred] = texter.predict(ent, sents)

        if filter_known:
            preds = [pred for pred in preds if pred.fact not in known_facts]

        logging.debug('Predictions:')
        for pred in preds:
            logging.debug(str(pred))

        #
        # Get entity ground truth facts
        #

        gt_facts = [fact for fact in all_eval_facts if fact.head == ent]

        if filter_known:
            gt_facts = list(set(gt_facts).difference(known_facts))

        logging.debug('Ground truth:')
        for fact in gt_facts:
            logging.debug(str(fact))

        #
        # Calc entity PRFS
        #

        pred_facts = {pred.fact for pred in preds}
        pred_and_gt_facts = list(pred_facts | set(gt_facts))

        gt_bools = [1 if fact in gt_facts else 0 for fact in pred_and_gt_facts]
        pred_bools = [1 if fact in pred_facts else 0 for fact in pred_and_gt_facts]

        prfs = precision_recall_fscore_support(gt_bools, pred_bools, labels=[1], zero_division=1)
        all_prfs.append(prfs)

        #
        # Add ent results to global results for micro metrics
        #

        all_gt_bools.extend(gt_bools)
        all_pred_bools.extend(pred_bools)

        #
        # Calc entity AP
        #

        pred_fact_conf_tuples = [(pred.fact, pred.conf) for pred in preds]

        ap = calc_ap(pred_fact_conf_tuples, gt_facts)
        all_ap.append(ap)

        #
        # Log entity metrics
        #

        logging.info(f'{str(ent.id):5} {ent.lbl:40}: AP = {ap:.2f}, Prec = {prfs[0][0]:.2f}, Rec = {prfs[1][0]:.2f}, '
                     f'F1 = {prfs[2][0]:.2f}, Supp = {prfs[3][0]}')

    m_ap = sum(all_ap) / len(all_ap)
    logging.info(f'mAP = {m_ap:.4f}')

    macro_prfs = np.array(all_prfs).mean(axis=0)
    logging.info(f'Macro Prec = {macro_prfs[0][0]:.4f}')
    logging.info(f'Macro Rec = {macro_prfs[1][0]:.4f}')
    logging.info(f'Macro F1 = {macro_prfs[2][0]:.4f}')
    logging.info(f'Macro Supp = {macro_prfs[3][0]:.2f}')

    micro_prfs = precision_recall_fscore_support(all_gt_bools, all_pred_bools, labels=[1], zero_division=1)
    logging.info(f'Micro Prec = {micro_prfs[0][0]:.4f}')
    logging.info(f'Micro Rec = {micro_prfs[1][0]:.4f}')
    logging.info(f'Micro F1 = {micro_prfs[2][0]:.4f}')
    logging.info(f'Micro Supp = {micro_prfs[3][0]}')
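
calc_ap is used above but not shown. A sketch of an average-precision computation consistent with how it is called (a list of (fact, confidence) tuples plus the ground-truth facts); this is an assumption, not the original implementation.

from typing import Iterable, List, Tuple


def calc_ap(pred_fact_conf_tuples: List[Tuple[Fact, float]], gt_facts: Iterable[Fact]) -> float:
    """Average precision over predictions ranked by descending confidence (hypothetical)."""

    gt_set = set(gt_facts)
    if not gt_set:
        return 1.0  # assumption: empty ground truth counts as perfect

    ranked = sorted(pred_fact_conf_tuples, key=lambda t: t[1], reverse=True)

    hits = 0
    precision_sum = 0.0

    for rank, (fact, _) in enumerate(ranked, start=1):
        if fact in gt_set:
            hits += 1
            precision_sum += hits / rank

    return precision_sum / len(gt_set)
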
Example #5
def eval_zero_rule(args):
    samples_dir_path = args.samples_dir
    class_count = args.class_count
    sent_count = args.sent_count
    split_dir_path = args.split_dir

    #
    # Check that (input) POWER Samples Directory exists
    #

    logging.info('Check that (input) POWER Samples Directory exists ...')

    samples_dir = SamplesDir(Path(samples_dir_path))
    samples_dir.check()

    #
    # Check that (input) POWER Split Directory exists
    #

    logging.info('Check that (input) POWER Split Directory exists ...')

    split_dir = SplitDir(Path(split_dir_path))
    split_dir.check()

    #
    # Load entity/relation labels
    #

    logging.info('Load entity/relation labels ...')

    ent_to_lbl = split_dir.entities_tsv.load()
    rel_to_lbl = split_dir.relations_tsv.load()

    #
    # Load datasets
    #

    logging.info('Load test dataset ...')

    test_set = samples_dir.test_samples_tsv.load(class_count, sent_count)

    #
    # Calc class frequencies
    #

    logging.info('Calc class frequencies ...')

    _, _, test_classes_stack, _ = zip(*test_set)
    test_freqs = np.array(test_classes_stack).mean(axis=0)

    #
    # Evaluate
    #

    logging.info(f'test_freqs = {test_freqs}')

    for strategy in ('uniform', 'stratified', 'most_frequent', 'constant'):
        logging.info(strategy)

        mean_metrics = []
        for gt in tqdm(np.array(test_classes_stack).T):

            if strategy == 'constant':
                # Always predict the positive class; fit on both labels so that
                # the constant class 1 is known to the classifier
                classifier = DummyClassifier(strategy='constant', constant=1)
                classifier.fit([0, 1], [0, 1])
            else:
                classifier = DummyClassifier(strategy=strategy)
                classifier.fit(gt, gt)

            metrics_list = []
            for _ in range(10):
                pred = classifier.predict(gt)

                acc = accuracy_score(gt, pred)
                prec, recall, f1, _ = precision_recall_fscore_support(
                    gt, pred, labels=[1], zero_division=1)

                metrics_list.append((acc, prec[0], recall[0], f1[0]))

            mean_metrics.append(np.mean(metrics_list, axis=0))

        logging.info(mean_metrics[0])
        logging.info(mean_metrics[-1])
        logging.info(np.mean(mean_metrics, axis=0))
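
For reference, a toy illustration of the data layout the loop above assumes: test_classes_stack holds one multi-hot class vector per sample, so the transpose yields one ground-truth vector per class and the column means are the class frequencies. The numbers below are made up.

import numpy as np

# 3 samples x 4 classes (multi-hot labels) -- illustrative values only
test_classes_stack = [
    [1, 0, 0, 1],
    [1, 1, 0, 0],
    [0, 0, 0, 1],
]

test_freqs = np.array(test_classes_stack).mean(axis=0)  # per-class frequencies, here [2/3, 1/3, 0, 2/3]
per_class_gt = np.array(test_classes_stack).T           # one row per class, fed to the DummyClassifier baselines

print(test_freqs)
print(per_class_gt.shape)  # (4, 3)
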
Example #6
def train_texter(args):
    samples_dir_path = args.samples_dir
    class_count = args.class_count
    sent_count = args.sent_count
    split_dir_path = args.split_dir
    texter_pkl_path = args.texter_pkl

    batch_size = args.batch_size
    device = args.device
    epoch_count = args.epoch_count
    log_dir = args.log_dir
    log_steps = args.log_steps
    lr = args.lr
    overwrite = args.overwrite
    sent_len = args.sent_len
    try_batch_size = args.try_batch_size

    #
    # Check that (input) POWER Samples Directory exists
    #

    logging.info('Check that (input) POWER Samples Directory exists ...')

    samples_dir = SamplesDir(Path(samples_dir_path))
    samples_dir.check()

    #
    # Check that (input) POWER Split Directory exists
    #

    logging.info('Check that (input) POWER Split Directory exists ...')

    split_dir = SplitDir(Path(split_dir_path))
    split_dir.check()

    #
    # Check that (output) POWER Texter PKL does not exist
    #

    logging.info('Check that (output) POWER Texter PKL does not exist ...')

    texter_pkl = TexterPkl(Path(texter_pkl_path))

    if not overwrite:
        texter_pkl.check(should_exist=False)

    #
    # Load entity/relation labels
    #

    logging.info('Load entity/relation labels ...')

    ent_to_lbl = split_dir.entities_tsv.load()
    rel_to_lbl = split_dir.relations_tsv.load()

    #
    # Create Texter
    #

    rel_tail_freq_lbl_tuples = samples_dir.classes_tsv.load()

    classes = [(Rel(rel, rel_to_lbl[rel]), Ent(tail, ent_to_lbl[tail]))
               for rel, tail, _, _ in rel_tail_freq_lbl_tuples]

    pre_trained = 'distilbert-base-uncased'
    texter = Texter(pre_trained, classes)

    #
    # Load datasets and create dataloaders
    #

    logging.info('Load datasets and create dataloaders ...')

    train_set = samples_dir.train_samples_tsv.load(class_count, sent_count)
    valid_set = samples_dir.valid_samples_tsv.load(class_count, sent_count)

    def generate_batch(
            batch: List[Sample]) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """
        :param    batch:            [Sample(ent, ent_lbl, [class], [sent])]

        :return:  ent_batch:        IntTensor[batch_size],
                  tok_lists_batch:  IntTensor[batch_size, sent_count, sent_len],
                  masks_batch:      IntTensor[batch_size, sent_count, sent_len],
                  classes_batch:    IntTensor[batch_size, class_count]
        """

        ent_batch, _, classes_batch, sents_batch = zip(*batch)

        for sents in sents_batch:
            shuffle(sents)

        flat_sents_batch = [sent for sents in sents_batch for sent in sents]

        encoded = texter.tokenizer(flat_sents_batch,
                                   padding=True,
                                   truncation=True,
                                   max_length=sent_len,
                                   return_tensors='pt')

        # Usually b_size == batch_size, except for the last batch of the dataset
        b_size = len(ent_batch)
        tok_lists_batch = encoded.input_ids.reshape(b_size, sent_count, -1)
        masks_batch = encoded.attention_mask.reshape(b_size, sent_count, -1)

        return tensor(ent_batch), tok_lists_batch, masks_batch, tensor(classes_batch)

    train_loader = DataLoader(train_set,
                              batch_size=batch_size,
                              collate_fn=generate_batch,
                              shuffle=True)
    valid_loader = DataLoader(valid_set,
                              batch_size=batch_size,
                              collate_fn=generate_batch)

    #
    # Calc class weights
    #

    logging.info('Calc class weights ...')

    _, _, train_classes_stack, _ = zip(*train_set)
    train_freqs = np.array(train_classes_stack).mean(axis=0)

    class_weights = tensor(1 / train_freqs)

    #
    # Prepare training
    #

    logging.info('Prepare training ...')

    texter = texter.to(device)

    criterion = BCEWithLogitsLoss(pos_weight=class_weights.to(device))

    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {
            'params': [p for n, p in texter.named_parameters()
                       if not any(nd in n for nd in no_decay)],
            'weight_decay': 0.01,
        },
        {
            'params': [p for n, p in texter.named_parameters()
                       if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0,
        },
    ]

    optimizer = AdamW(optimizer_grouped_parameters, lr=lr)

    writer = SummaryWriter(log_dir=log_dir)

    #
    # Train
    #

    logging.info('Train ...')

    best_valid_f1 = 0

    # Global progress for Tensorboard
    train_steps = 0
    valid_steps = 0

    for epoch in range(epoch_count):

        epoch_metrics = {
            'train': {
                'loss': 0.0,
                'pred_classes_stack': [],
                'gt_classes_stack': []
            },
            'valid': {
                'loss': 0.0,
                'pred_classes_stack': [],
                'gt_classes_stack': []
            }
        }

        #
        # Train
        #

        texter.train()

        for _, sents_batch, masks_batch, gt_batch in tqdm(
                train_loader, desc=f'Epoch {epoch}'):
            train_steps += len(sents_batch)

            sents_batch = sents_batch.to(device)
            masks_batch = masks_batch.to(device)
            gt_batch = gt_batch.to(device).float()

            logits_batch = texter(sents_batch, masks_batch)[0]
            loss = criterion(logits_batch, gt_batch)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            #
            # Log metrics
            #

            pred_batch = (logits_batch > 0).int()

            step_loss = loss.item()
            step_pred_batch = pred_batch.cpu().numpy().tolist()
            step_gt_batch = gt_batch.cpu().numpy().tolist()

            epoch_metrics['train']['loss'] += step_loss
            epoch_metrics['train']['pred_classes_stack'] += step_pred_batch
            epoch_metrics['train']['gt_classes_stack'] += step_gt_batch

            if log_steps:
                writer.add_scalars('loss', {'train': step_loss}, train_steps)

                step_metrics = {
                    'train': {
                        'pred_classes_stack': step_pred_batch,
                        'gt_classes_stack': step_gt_batch
                    }
                }

                log_class_metrics(step_metrics, writer, train_steps,
                                  class_count)
                log_macro_metrics(step_metrics, writer, train_steps)

            if try_batch_size:
                break

        #
        # Validate
        #

        texter.eval()

        for _, sents_batch, masks_batch, gt_batch in tqdm(
                valid_loader, desc=f'Epoch {epoch}'):
            valid_steps += len(sents_batch)

            sents_batch = sents_batch.to(device)
            masks_batch = masks_batch.to(device)
            gt_batch = gt_batch.to(device).float()

            logits_batch = texter(sents_batch, masks_batch)[0]
            loss = criterion(logits_batch, gt_batch)

            #
            # Log metrics
            #

            pred_batch = (logits_batch > 0).int()

            step_loss = loss.item()
            step_pred_batch = pred_batch.cpu().numpy().tolist()
            step_gt_batch = gt_batch.cpu().numpy().tolist()

            epoch_metrics['valid']['loss'] += step_loss
            epoch_metrics['valid']['pred_classes_stack'] += step_pred_batch
            epoch_metrics['valid']['gt_classes_stack'] += step_gt_batch

            if log_steps:
                writer.add_scalars('loss', {'valid': step_loss}, valid_steps)

                step_metrics = {
                    'valid': {
                        'pred_classes_stack': step_pred_batch,
                        'gt_classes_stack': step_gt_batch
                    }
                }

                log_class_metrics(step_metrics, writer, valid_steps,
                                  class_count)
                log_macro_metrics(step_metrics, writer, valid_steps)

            if try_batch_size:
                break

        #
        # Log loss
        #

        train_loss = epoch_metrics['train']['loss'] / len(train_loader)
        valid_loss = epoch_metrics['valid']['loss'] / len(valid_loader)

        writer.add_scalars('loss', {
            'train': train_loss,
            'valid': valid_loss
        }, epoch)

        #
        # Log metrics
        #

        log_class_metrics(epoch_metrics, writer, epoch, class_count)
        valid_f1 = log_macro_metrics(epoch_metrics, writer, epoch)

        #
        # Persist Texter
        #

        if valid_f1 > best_valid_f1:
            best_valid_f1 = valid_f1
            texter_pkl.save(texter)

        if try_batch_size:
            break
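
log_class_metrics and log_macro_metrics are assumed helpers. A sketch of what log_macro_metrics might do, given that its return value is used as the validation F1 for checkpointing; the implementation below is an assumption, not the original code.

import numpy as np
from sklearn.metrics import precision_recall_fscore_support
from torch.utils.tensorboard import SummaryWriter


def log_macro_metrics(metrics: dict, writer: SummaryWriter, step: int) -> float:
    """Log macro precision/recall/F1 per split and return the validation F1 (hypothetical)."""

    f1_by_split = {}

    for split, split_metrics in metrics.items():
        gt = np.array(split_metrics['gt_classes_stack']).astype(int)
        pred = np.array(split_metrics['pred_classes_stack']).astype(int)

        prec, rec, f1, _ = precision_recall_fscore_support(
            gt, pred, average='macro', zero_division=1)

        writer.add_scalars('macro_precision', {split: prec}, step)
        writer.add_scalars('macro_recall', {split: rec}, step)
        writer.add_scalars('macro_f1', {split: f1}, step)

        f1_by_split[split] = f1

    # The caller checkpoints on the validation F1 when it is available
    return f1_by_split.get('valid', f1)
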
Example #7
def train_aggregator(args):
    ruler_pkl_path = args.ruler_pkl
    texter_pkl_path = args.texter_pkl
    sent_count = args.sent_count
    split_dir_path = args.split_dir
    text_dir_path = args.text_dir

    epoch_count = args.epoch_count
    log_dir = args.log_dir
    log_steps = args.log_steps
    lr = args.lr
    overwrite = args.overwrite
    sent_len = args.sent_len

    #
    # Check that (input) POWER Ruler PKL exists
    #

    logging.info('Check that (input) POWER Ruler PKL exists ...')

    ruler_pkl = RulerPkl(Path(ruler_pkl_path))
    ruler_pkl.check()

    #
    # Check that (input) POWER Texter PKL exists
    #

    logging.info('Check that (input) POWER Texter PKL exists ...')

    texter_pkl = TexterPkl(Path(texter_pkl_path))
    texter_pkl.check()

    #
    # Check that (input) POWER Split Directory exists
    #

    logging.info('Check that (input) POWER Split Directory exists ...')

    split_dir = SplitDir(Path(split_dir_path))
    split_dir.check()

    #
    # Check that (input) IRT Text Directory exists
    #

    logging.info('Check that (input) IRT Text Directory exists ...')

    text_dir = TextDir(Path(text_dir_path))
    text_dir.check()

    #
    # Load ruler
    #

    logging.info('Load ruler ...')

    ruler = ruler_pkl.load()

    #
    # Load texter
    #

    logging.info('Load texter ...')

    texter = texter_pkl.load().cpu()

    #
    # Build POWER
    #

    power = Aggregator(texter, ruler)

    #
    # Load facts
    #

    logging.info('Load facts ...')

    ent_to_lbl = split_dir.entities_tsv.load()
    rel_to_lbl = split_dir.relations_tsv.load()

    train_facts = split_dir.train_facts_tsv.load()
    train_facts = {
        Fact.from_ints(head, rel, tail, ent_to_lbl, rel_to_lbl)
        for head, _, rel, _, tail, _ in train_facts
    }

    known_valid_facts = split_dir.valid_facts_known_tsv.load()
    unknown_valid_facts = split_dir.valid_facts_unknown_tsv.load()

    known_eval_facts = known_valid_facts
    all_valid_facts = known_valid_facts + unknown_valid_facts

    known_facts = {
        Fact.from_ints(head, rel, tail, ent_to_lbl, rel_to_lbl)
        for head, _, rel, _, tail, _ in known_eval_facts
    }

    all_valid_facts = {
        Fact.from_ints(head, rel, tail, ent_to_lbl, rel_to_lbl)
        for head, _, rel, _, tail, _ in all_valid_facts
    }

    #
    # Load entities
    #

    logging.info('Load entities ...')

    train_ents = split_dir.train_entities_tsv.load()
    valid_ents = split_dir.valid_entities_tsv.load()

    train_ents = [Ent(ent, lbl) for ent, lbl in train_ents.items()]
    valid_ents = [Ent(ent, lbl) for ent, lbl in valid_ents.items()]

    #
    # Load texts
    #

    logging.info('Load texts ...')

    train_ent_to_sents = text_dir.cw_train_sents_txt.load()
    valid_ent_to_sents = text_dir.ow_valid_sents_txt.load()

    #
    # Prepare training
    #

    criterion = MSELoss()

    writer = SummaryWriter(log_dir=log_dir)

    #
    # Train
    #

    logging.info('Train ...')

    texter_optimizer = SGD([power.texter_weight], lr=lr)
    ruler_optimizer = SGD([power.ruler_weight], lr=lr)

    for epoch in range(epoch_count):

        for ent in train_ents:
            logging.debug(f'texter_weight = {power.texter_weight}, ruler_weight = {power.ruler_weight}')

            #
            # Get entity ground truth facts
            #

            gt_facts = [fact for fact in train_facts if fact.head == ent]

            logging.debug('Ground truth:')
            for fact in gt_facts:
                logging.debug(str(fact))

            #
            # Train Texter Weight
            #

            sents = list(train_ent_to_sents[ent.id])[:sent_count]
            if len(sents) < sent_count:
                logging.warning(
                    f'Only {len(sents)} sentences for entity "{ent.lbl}" ({ent.id}). Skipping.'
                )
                continue

            texter_preds = texter.predict(ent, sents)

            train_confs = [pred.conf for pred in texter_preds]
            gt_confs = [
                1 if pred.fact in gt_facts else 0 for pred in texter_preds
            ]

            for train_conf, gt_conf in zip(train_confs, gt_confs):
                loss = criterion(
                    torch.tensor(train_conf) * power.texter_weight,
                    torch.tensor(gt_conf).float())
                texter_optimizer.zero_grad()
                loss.backward()
                texter_optimizer.step()

            #
            # Train Ruler Weight
            #

            ruler_preds = ruler.predict(ent)

            train_confs = [pred.conf for pred in ruler_preds]
            gt_confs = [
                1 if pred.fact in gt_facts else 0 for pred in ruler_preds
            ]

            for train_conf, gt_conf in zip(train_confs, gt_confs):
                loss = criterion(
                    torch.tensor(train_conf) * power.ruler_weight,
                    torch.tensor(gt_conf).float())
                ruler_optimizer.zero_grad()
                loss.backward()
                ruler_optimizer.step()
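
For SGD([power.texter_weight]) and SGD([power.ruler_weight]) to work, the Aggregator must expose its two weights as trainable tensors. A minimal sketch of such a class, assuming a simple weighted combination of texter and ruler confidences; the Pred(fact, conf) constructor and the merge strategy are assumptions, not the original implementation.

from typing import List

import torch


class Aggregator:
    """Combine texter and ruler predictions with two learned scalar weights."""

    def __init__(self, texter, ruler):
        self.texter = texter
        self.ruler = ruler

        # Trainable scalars, so they can be handed to an optimizer directly
        self.texter_weight = torch.tensor(1.0, requires_grad=True)
        self.ruler_weight = torch.tensor(1.0, requires_grad=True)

    def predict(self, ent, sents) -> List['Pred']:
        texter_preds = self.texter.predict(ent, sents)
        ruler_preds = self.ruler.predict(ent)

        # Scale each source's confidence by its learned weight and merge by fact,
        # keeping the higher weighted confidence when both sources predict a fact
        combined = {}
        weighted = [(p, self.texter_weight) for p in texter_preds] + \
                   [(p, self.ruler_weight) for p in ruler_preds]

        for pred, weight in weighted:
            conf = float(pred.conf) * float(weight)
            if pred.fact not in combined or conf > combined[pred.fact]:
                combined[pred.fact] = conf

        return [Pred(fact, conf) for fact, conf in combined.items()]
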