Example #1
def load_seealso(seealso_file,
                 subprop_file,
                 use_split=False,
                 train_pids=None,
                 dev_pids=None,
                 test_pids=None):
    # load subprop
    subprops = read_subprop_file(subprop_file)
    subtrees, isolate = get_all_subtree(subprops)

    pid2nei_split: Dict[str, set] = defaultdict(set)
    pid2newpid: Dict[str, str] = {}
    if use_split:
        pid2split = {}
        for split, pids in enumerate([train_pids, dev_pids, test_pids]):
            for pid in pids:
                if pid in pid2split:
                    raise ValueError('duplicate pid: {}'.format(pid))
                pid2split[pid] = split
        # special cases for split tier
        for subtree in subtrees:
            for pid, childs in subtree.traverse_each_tier():
                if not PropertySubtree.is_split(pid, childs):
                    continue
                rawpid = pid.split('_')[0]
                if pid != rawpid:
                    pid2newpid[rawpid] = pid
                for c1, c2 in combinations(childs, 2):
                    if c1 in pid2split and c2 in pid2split \
                            and pid2split[c1] == pid2split[c2]:
                        # in the same split
                        pid2nei_split[c1].add(c2)
                        pid2nei_split[c2].add(c1)

    # load see also
    pid2nei: Dict[str, set] = defaultdict(set)
    with open(seealso_file, 'r') as fin:
        for l in fin:
            l = l.strip()
            if l == '':
                continue
            h, _, t = l.split('\t')
            if h in pid2newpid:
                h = pid2newpid[h]
            if t in pid2newpid:
                t = pid2newpid[t]
            pid2nei[h].add(t)

    if use_split:
        pid2nei.update(pid2nei_split)

    return dict(pid2nei)
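
A minimal usage sketch, assuming load_seealso and its helpers are importable from this module; the paths and property ids below are placeholders (the training code elsewhere in this file reads a data/see_also.tsv, but your layout may differ):

# Hypothetical paths and PIDs, purely illustrative.
pid2nei = load_seealso('data/see_also.tsv', 'data/subprops')
print('{} properties have "see also" neighbours'.format(len(pid2nei)))

# With use_split=True the three PID sets must be disjoint,
# otherwise the duplicate-pid check raises.
pid2nei = load_seealso('data/see_also.tsv', 'data/subprops',
                       use_split=True,
                       train_pids={'P17', 'P27'},
                       dev_pids={'P131'},
                       test_pids={'P495'})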
Example #2
def run_emb_train(
        data_dir,
        emb_file,
        subprop_file,
        use_label=False,
        filter_leaves=False,
        only_test_on=None,
        epoch=10,
        input_size=400,
        batch_size=128,
        use_cuda=False,
        early_stop=None,
        num_occs=10,
        num_occs_label=10,
        hidden_size=128,
        dropout=0.0,
        lr=0.001,
        only_prop=False,
        use_tbow=0,
        tbow_emb_size=50,
        word_emb_file=None,
        suffix='.tbow',  # tbow 1
        use_tbow2=0,
        tbow_emb_size2=50,
        word_emb_file2=None,
        suffix2='.tbow',  # tbow 2
        use_sent=0,
        sent_emb_size=50,
        sent_emb_file=None,
        sent_suffix='.sent',  # sent
        only_tbow=False,
        only_sent=False,
        renew_word_emb=False,
        output_pred=False,
        use_ancestor=False,
        filter_labels=False,
        acc_topk=1,
        use_weight=False,
        only_one_sample_per_prop=False,
        optimizer='adam',
        use_gnn=None,
        sent_emb_method='cnn_mean',
        sent_hidden_size=16,
        lr_decay=0):
    subprops = read_subprop_file(subprop_file)
    pid2plabel = get_pid2plabel(subprops)
    subtrees, _ = get_all_subtree(subprops)
    is_parent = get_is_parent(subprops)
    is_ancestor = get_is_ancestor(subtrees)
    leaves = get_leaves(subtrees)
    pid2parent = dict([(c, p) for p, c in is_parent])

    emb_id2ind, emb = read_embeddings_from_text_file(emb_file,
                                                     debug=False,
                                                     emb_size=200,
                                                     use_padding=True,
                                                     first_line=False)

    train_samples = read_nway_file(os.path.join(data_dir, 'train.nway'))
    dev_samples = read_nway_file(os.path.join(data_dir, 'dev.nway'))
    test_samples = read_nway_file(os.path.join(data_dir, 'test.nway'))
    print('#train, #dev, #test {}, {}, {}'.format(len(train_samples),
                                                  len(dev_samples),
                                                  len(test_samples)))

    label2ind = load_tsv_as_dict(os.path.join(data_dir, 'label2ind.txt'),
                                 valuefunc=int)
    ind2label = dict((v, k) for k, v in label2ind.items())

    train_labels = set(s[1] for s in train_samples)
    test_labels = set(s[1] for s in test_samples)
    join_labels = train_labels & test_labels
    print('#labels in train & test {}'.format(len(join_labels)))

    anc2ind: Dict[str, int] = defaultdict(lambda: len(anc2ind))
    if use_ancestor:

        def get_anc_label(parent_label):
            parent_label = ind2label[parent_label]
            if parent_label in pid2parent:
                return anc2ind[pid2parent[parent_label]]
            else:
                return anc2ind['NO_ANC']

        train_samples = [((pid, poccs), (plabel, get_anc_label(plabel)))
                         for (pid, poccs), plabel in train_samples]
        dev_samples = [((pid, poccs), (plabel, get_anc_label(plabel)))
                       for (pid, poccs), plabel in dev_samples]
        test_samples = [((pid, poccs), (plabel, get_anc_label(plabel)))
                        for (pid, poccs), plabel in test_samples]
    print('#ancestor {}'.format(len(anc2ind)))

    train_samples_li, dev_samples_li, test_samples_li = \
        [train_samples], [dev_samples], [test_samples]
    for split, sl in [('train', train_samples_li), ('dev', dev_samples_li),
                      ('test', test_samples_li)]:
        for is_use, suff in [(use_tbow, suffix), (use_tbow2, suffix2)]:
            if is_use:
                samples_more = read_tbow_file(
                    os.path.join(data_dir, split + suff))
                assert len(samples_more) == len(sl[0])
                sl.append(samples_more)
        for is_use, suff in [(use_sent, sent_suffix)]:
            if is_use:
                samples_more = read_sent_file(
                    os.path.join(data_dir, split + suff))
                assert len(samples_more) == len(sl[0])
                sl.append(samples_more)

    if len(train_samples_li) > 1:
        train_samples = list(zip(*train_samples_li))
        dev_samples = list(zip(*dev_samples_li))
        test_samples = list(zip(*test_samples_li))

    vocab_size, vocab_size2, sent_vocab_size = None, None, None
    word_emb, word_emb2, sent_emb = None, None, None
    if use_tbow:
        if word_emb_file:
            word_emb_id2ind, word_emb = read_embeddings_from_text_file(
                word_emb_file,
                debug=False,
                emb_size=tbow_emb_size,
                first_line=False,
                use_padding=True,
                split_char=' ')
            vocab_size = len(word_emb_id2ind)
            if renew_word_emb:
                word_emb = None
        else:
            vocab_size = len(
                load_tsv_as_dict(
                    os.path.join(data_dir,
                                 '{}.vocab'.format(suffix.lstrip('.')))))

    if use_tbow2:
        if word_emb_file2:
            word_emb_id2ind2, word_emb2 = read_embeddings_from_text_file(
                word_emb_file2,
                debug=False,
                emb_size=tbow_emb_size2,
                first_line=False,
                use_padding=True,
                split_char=' ')
            vocab_size2 = len(word_emb_id2ind2)
            if renew_word_emb:
                word_emb2 = None
        else:
            vocab_size2 = len(
                load_tsv_as_dict(
                    os.path.join(data_dir,
                                 '{}.vocab'.format(suffix2.lstrip('.')))))

    if use_sent:
        if sent_emb_file:
            sent_emb_id2ind, sent_emb = read_embeddings_from_text_file(
                sent_emb_file,
                debug=False,
                emb_size=sent_emb_size,
                first_line=False,
                use_padding=True,
                split_char=' ')
            sent_vocab_size = len(sent_emb_id2ind)
        else:
            sent_vocab_size = len(
                load_tsv_as_dict(
                    os.path.join(data_dir,
                                 '{}.vocab'.format(sent_suffix.lstrip('.')))))

    print('vocab size 1 {}'.format(vocab_size))
    print('vocab size 2 {}'.format(vocab_size2))
    print('sent vocab size {}'.format(sent_vocab_size))

    if filter_leaves:
        print('filter leaves')
        filter_pids = set(leaves)
        if use_tbow or use_sent:
            train_samples = [
                s for s in train_samples if s[0][0][0] not in filter_pids
            ]
            dev_samples = [
                s for s in dev_samples if s[0][0][0] not in filter_pids
            ]
            test_samples = [
                s for s in test_samples if s[0][0][0] not in filter_pids
            ]
        else:
            train_samples = [
                s for s in train_samples if s[0][0] not in filter_pids
            ]
            dev_samples = [
                s for s in dev_samples if s[0][0] not in filter_pids
            ]
            test_samples = [
                s for s in test_samples if s[0][0] not in filter_pids
            ]

    if filter_labels:
        print('filter labels')
        if use_tbow or use_sent:
            train_samples = [
                s for s in train_samples if s[0][1] in join_labels
            ]
            dev_samples = [s for s in dev_samples if s[0][1] in join_labels]
            test_samples = [s for s in test_samples if s[0][1] in join_labels]
        else:
            train_samples = [s for s in train_samples if s[1] in join_labels]
            dev_samples = [s for s in dev_samples if s[1] in join_labels]
            test_samples = [s for s in test_samples if s[1] in join_labels]

    if only_one_sample_per_prop:

        def filter_first(samples, key_func):
            dup = set()
            new_samples = []
            for s in samples:
                k = key_func(s)
                if k in dup:
                    continue
                dup.add(k)
                new_samples.append(s)
            return new_samples

        if use_tbow or use_sent:
            key_func = lambda s: s[0][0][0]
        else:
            key_func = lambda s: s[0][0]
        train_samples = filter_first(train_samples, key_func=key_func)
        dev_samples = filter_first(dev_samples, key_func=key_func)
        test_samples = filter_first(test_samples, key_func=key_func)

    if only_test_on:
        if use_tbow or use_sent:
            test_samples = [
                s for s in test_samples if s[0][0][0] in only_test_on
            ]
        else:
            test_samples = [s for s in test_samples if s[0][0] in only_test_on]

    if use_label:
        label_samples = read_nway_file(
            os.path.join(data_dir, 'label2occs.nway'))
        label_samples_dict: Dict[int, List] = defaultdict(list)
        for (pid, occs), l in label_samples:
            occs = list(occs)
            shuffle(occs)  # TODO: better than shuffle?
            label_samples_dict[l].extend(occs)
        label_samples = [((ind2label[l], label_samples_dict[l]), l)
                         for l in sorted(label_samples_dict.keys())]
        label_samples_li = [label_samples]
        # extend label samples by features
        for is_use, suff in [(use_tbow, suffix), (use_tbow2, suffix2)]:
            if is_use:
                samples_more = read_tbow_file(
                    os.path.join(data_dir, 'label' + suff))
                assert len(samples_more) == len(label_samples_li[0])
                label_samples_li.append(samples_more)
        for is_use, suff in [(use_sent, sent_suffix)]:
            if is_use:
                samples_more = read_sent_file(
                    os.path.join(data_dir, 'label' + suff))
                assert len(samples_more) == len(label_samples_li[0])
                label_samples_li.append(samples_more)
        if len(label_samples_li) > 1:
            label_samples = list(zip(*label_samples_li))
    else:
        label_samples = None

    if use_cuda:
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    print('#samples in train/dev/test: {} {} {}'.format(
        len(train_samples), len(dev_samples), len(test_samples)))

    if use_gnn:
        graph_data, meta_data = build_graph(train_samples, dev_samples,
                                            test_samples, 'data/see_also.tsv',
                                            os.path.join(data_dir, 'subprops'),
                                            emb_id2ind, emb)
        graph_data = dict((k, v.to(device)) for k, v in graph_data.items())

        emb_model = EmbGnnModel(feat_size=input_size,
                                hidden_size=hidden_size,
                                num_class=len(label2ind),
                                dropout=dropout,
                                method=use_gnn)
        emb_model.to(device)

        optimizer = torch.optim.RMSprop(emb_model.parameters(), lr=lr)

        last_metric, last_count = None, 0
        metrics = []
        for e in range(epoch):
            # train
            emb_model.train()
            train_logits, _, _, train_loss, _, _ = emb_model(**graph_data)
            train_logits = train_logits.detach().cpu().numpy()
            train_pred = [
                (pid, train_logits[i], l) for i, (pid, l) in enumerate(
                    zip(meta_data['train_pids'], meta_data['train_labels']))
            ]

            if train_loss.requires_grad:
                optimizer.zero_grad()
                train_loss.backward()
                torch.nn.utils.clip_grad_norm_(emb_model.parameters(), 1.0)
                optimizer.step()

            # eval
            emb_model.eval()
            _, dev_logits, test_logits, _, dev_loss, test_loss = emb_model(
                **graph_data)
            dev_logits = dev_logits.detach().cpu().numpy()
            test_logits = test_logits.detach().cpu().numpy()

            dev_pred = [(pid, dev_logits[i], l) for i, (pid, l) in enumerate(
                zip(meta_data['dev_pids'], meta_data['dev_labels']))]
            test_pred = [(pid, test_logits[i], l) for i, (pid, l) in enumerate(
                zip(meta_data['test_pids'], meta_data['test_labels']))]

            # metrics
            train_metric, train_ranks, train_pred_label, _ = accuracy_nway(
                train_pred,
                ind2label=ind2label,
                topk=acc_topk,
                num_classes=len(label2ind))
            dev_metric, dev_ranks, dev_pred_label, _ = accuracy_nway(
                dev_pred,
                ind2label=ind2label,
                topk=acc_topk,
                num_classes=len(label2ind))
            test_metric, test_ranks, test_pred_label, _ = accuracy_nway(
                test_pred,
                ind2label=ind2label,
                topk=acc_topk,
                num_classes=len(label2ind))

            print(
                'train: {:>.3f}, {:>.3f} dev: {:>.3f}, {:>.3f} test: {:>.3f}, {:>.3f}'
                .format(train_loss.item(), train_metric, dev_loss.item(),
                        dev_metric, test_loss.item(), test_metric))

            if early_stop and last_metric and last_metric > dev_metric:
                last_count += 1
                if last_count >= early_stop:
                    print('early stop')
                    break
            last_metric = dev_metric
            metrics.append(test_metric)

        return metrics, test_ranks, dev_ranks, train_ranks

    emb_model = EmbModel(emb,
                         len(label2ind),
                         len(anc2ind),
                         input_size=input_size,
                         hidden_size=hidden_size,
                         padding_idx=0,
                         dropout=dropout,
                         only_prop=only_prop,
                         use_label=use_label,
                         vocab_size=vocab_size,
                         tbow_emb_size=tbow_emb_size,
                         word_emb=word_emb,
                         vocab_size2=vocab_size2,
                         tbow_emb_size2=tbow_emb_size2,
                         word_emb2=word_emb2,
                         sent_vocab_size=sent_vocab_size,
                         sent_emb_size=sent_emb_size,
                         sent_emb=sent_emb,
                         only_tbow=only_tbow,
                         only_sent=only_sent,
                         use_weight=use_weight,
                         sent_emb_method=sent_emb_method,
                         sent_hidden_size=sent_hidden_size)
    emb_model.to(device)
    if optimizer == 'adam':
        optimizer = torch.optim.Adam(emb_model.parameters(), lr=lr)
    elif optimizer == 'sgd':
        optimizer = torch.optim.SGD(emb_model.parameters(), lr=lr)
    elif optimizer == 'rmsprop':
        optimizer = torch.optim.RMSprop(emb_model.parameters(), lr=lr)
    else:
        raise NotImplementedError
    if lr_decay:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, 'max', patience=lr_decay)

    last_metric, last_count = None, 0
    metrics = []
    train_tensors = samples2tensors(train_samples,
                                    batch_size,
                                    emb_id2ind,
                                    num_occs,
                                    device,
                                    only_prop,
                                    use_tbow=use_tbow,
                                    use_tbow2=use_tbow2,
                                    use_sent=use_sent,
                                    use_anc=use_ancestor)
    dev_tensors = samples2tensors(dev_samples,
                                  batch_size,
                                  emb_id2ind,
                                  num_occs,
                                  device,
                                  only_prop,
                                  use_tbow=use_tbow,
                                  use_tbow2=use_tbow2,
                                  use_sent=use_sent,
                                  use_anc=use_ancestor)
    test_tensors = samples2tensors(test_samples,
                                   batch_size,
                                   emb_id2ind,
                                   num_occs,
                                   device,
                                   only_prop,
                                   use_tbow=use_tbow,
                                   use_tbow2=use_tbow2,
                                   use_sent=use_sent,
                                   use_anc=use_ancestor)
    if use_label:
        label_tensors = samples2tensors(label_samples,
                                        len(label_samples),
                                        emb_id2ind,
                                        num_occs_label,
                                        device,
                                        only_prop,
                                        use_tbow=use_tbow,
                                        use_tbow2=use_tbow2,
                                        use_sent=use_sent,
                                        use_anc=use_ancestor)[0]
    else:
        label_tensors = None
    for e in range(epoch):
        # train
        train_loss, train_pred, _ = one_epoch(emb_model,
                                              optimizer,
                                              train_samples,
                                              train_tensors,
                                              batch_size,
                                              emb_id2ind,
                                              num_occs,
                                              device,
                                              only_prop,
                                              is_train=True,
                                              label_samples=label_samples,
                                              label_tensors=label_tensors,
                                              use_tbow=use_tbow,
                                              use_sent=use_sent,
                                              use_anc=use_ancestor,
                                              num_occs_label=num_occs_label)
        # dev
        dev_loss, dev_pred, _ = one_epoch(emb_model,
                                          optimizer,
                                          dev_samples,
                                          dev_tensors,
                                          batch_size,
                                          emb_id2ind,
                                          num_occs,
                                          device,
                                          only_prop,
                                          is_train=False,
                                          label_samples=label_samples,
                                          label_tensors=label_tensors,
                                          use_tbow=use_tbow,
                                          use_sent=use_sent,
                                          use_anc=use_ancestor,
                                          num_occs_label=num_occs_label)
        # test
        test_loss, test_pred, _ = one_epoch(emb_model,
                                            optimizer,
                                            test_samples,
                                            test_tensors,
                                            batch_size,
                                            emb_id2ind,
                                            num_occs,
                                            device,
                                            only_prop,
                                            is_train=False,
                                            label_samples=label_samples,
                                            label_tensors=label_tensors,
                                            use_tbow=use_tbow,
                                            use_sent=use_sent,
                                            use_anc=use_ancestor,
                                            num_occs_label=num_occs_label)

        train_metric, train_ranks, train_pred_label, _ = accuracy_nway(
            train_pred,
            ind2label=ind2label,
            topk=acc_topk,
            num_classes=len(label2ind))
        dev_metric, dev_ranks, dev_pred_label, _ = accuracy_nway(
            dev_pred,
            ind2label=ind2label,
            topk=acc_topk,
            num_classes=len(label2ind))
        test_metric, test_ranks, test_pred_label, _ = accuracy_nway(
            test_pred,
            ind2label=ind2label,
            topk=acc_topk,
            num_classes=len(label2ind))

        if lr_decay:
            scheduler.step(dev_metric)

        print(
            'train: {:>.3f}, {:>.3f} dev: {:>.3f}, {:>.3f} test: {:>.3f}, {:>.3f}'
            .format(np.mean(train_loss), train_metric, np.mean(dev_loss),
                    dev_metric, np.mean(test_loss), test_metric))

        metrics.append(test_metric)

        if early_stop and last_metric and last_metric >= dev_metric:
            last_count += 1
            if last_count >= early_stop:
                print('early stop')
                break
        last_metric = dev_metric

    # get rank
    _, train_ranks, _, _ = accuracy_nway(train_pred,
                                         ind2label=ind2label,
                                         topk=None,
                                         num_classes=len(label2ind))
    _, dev_ranks, _, _ = accuracy_nway(dev_pred,
                                       ind2label=ind2label,
                                       topk=None,
                                       num_classes=len(label2ind))
    _, test_ranks, _, _ = accuracy_nway(test_pred,
                                        ind2label=ind2label,
                                        topk=None,
                                        num_classes=len(label2ind))

    test_ranks = get_ranks(test_ranks,
                           is_parent=is_parent,
                           is_ancestor=is_ancestor)
    dev_ranks = get_ranks(dev_ranks,
                          is_parent=is_parent,
                          is_ancestor=is_ancestor)
    train_ranks = get_ranks(train_ranks,
                            is_parent=is_parent,
                            is_ancestor=is_ancestor)

    if output_pred:
        # write out the per-split predicted labels (avoids the eval lookup)
        pred_labels = {'train': train_pred_label, 'dev': dev_pred_label, 'test': test_pred_label}
        for fn in ['train', 'dev', 'test']:
            with open(os.path.join(data_dir, fn + '.pred'), 'w') as fout:
                for pl in pred_labels[fn]:
                    fout.write('{}\n'.format(pl))
    return metrics, test_ranks, dev_ranks, train_ranks, emb_model
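
A hedged invocation sketch for run_emb_train; the directory and file names are assumptions drawn from the paths the function itself reads (train.nway, dev.nway, test.nway, and label2ind.txt under data_dir, plus a 200-dimensional embedding text file):

# Hypothetical paths; point these at your own prepared data.
metrics, test_ranks, dev_ranks, train_ranks, model = run_emb_train(
    data_dir='data/split/by_entail',   # assumed; must contain the .nway splits and label2ind.txt
    emb_file='data/emb/transe.txt',    # assumed; 200-d embeddings read by read_embeddings_from_text_file
    subprop_file='data/subprops',
    epoch=50,
    batch_size=128,
    hidden_size=128,
    dropout=0.2,
    lr=1e-3,
    use_cuda=torch.cuda.is_available(),
    early_stop=5,
    acc_topk=1)
print('best test accuracy: {:.3f}'.format(max(metrics)))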
Example #3
def compute_overlap(data_dir,
                    split,
                    poccs,
                    subprops,
                    emb,
                    emb_id2ind,
                    top=100,
                    method='cosine',
                    only_prop_emb=False,
                    detect_cheat=True,
                    use_minus=False,
                    filter_num_poccs=None,
                    filter_pids=None,
                    num_workers=1,
                    skip_split=False,
                    ori_subprops=None,
                    debug=False,
                    use_norm=False,
                    **kwargs):
    emb_dim = emb.shape[1]
    print('#words in emb: {}'.format(len(emb_id2ind)))

    # load poccs
    poccs = os.path.join(data_dir, poccs)
    with open(poccs, 'rb') as fin:
        poccs = pickle.load(fin)

    # sample occurrences for each property
    all_ids = set()
    sampled_poccs = {}
    for k in poccs:
        if top:
            # shuffle all occurrences and keep at most `top` of them
            sampled_poccs[k] = random.sample(poccs[k], len(poccs[k]))[:top]
        else:
            sampled_poccs[k] = poccs[k]
        if len(sampled_poccs[k]) > 0:
            heads, tails = zip(*sampled_poccs[k])
            all_ids.update(heads)
            all_ids.update(tails)
    all_ids_emb = all_ids & set(emb_id2ind.keys())
    print('#entities in property occs: {}, {} have embs'.format(
        len(all_ids), len(all_ids_emb)))

    # load properties for test
    if type(split) is str:
        split = [split]
    props = []
    for sp in split:
        props += list(load_tsv_as_dict(os.path.join(data_dir, sp)).keys())

    # load labels/parents
    labels = list(
        load_tsv_as_dict(os.path.join(data_dir, 'label2ind.txt')).keys())

    # load subprops
    subprops = read_subprop_file(subprops)
    is_parent = get_is_parent(subprops)
    subtrees, _ = get_all_subtree(subprops)
    is_ancestor = get_is_ancestor(subtrees)
    pid2plabel = get_pid2plabel(subprops)
    if skip_split:
        ori_subprops = read_subprop_file(ori_subprops)
        ori_parents = set([p for p, c in get_is_parent(ori_subprops)])

    print('#property: {} #label: {}'.format(len(props), len(labels)))
    if debug:
        get_label_dist(props, labels, is_parent, pid2plabel)

    # collect embs for labels
    label_embs = []
    for l in labels:
        if l not in sampled_poccs or len(sampled_poccs[l]) <= 0:
            continue
        if skip_split and l not in ori_parents:
            continue
        ls = sampled_poccs[l]
        if filter_num_poccs and len(poccs[l]) < filter_num_poccs:
            continue
        if only_prop_emb:
            ls_emb = np.array([emb[emb_id2ind[l]]])
        else:
            ls_emb = [
                emb[emb_id2ind[e]] for o in ls for e in o
                if o[0] in emb_id2ind and o[1] in emb_id2ind
            ]
            ls_emb = np.array(ls_emb).reshape(-1, 2 * emb_dim)
        if use_minus:
            ls_emb = ls_emb[:, :emb_dim] - ls_emb[:, emb_dim:]
        if ls_emb.shape[0] <= 0:
            continue
        if use_norm:
            ls_emb = norm(ls_emb)
        if method == 'kde':
            kde = KernelDensity(kernel='gaussian',
                                bandwidth=kwargs['sigma']).fit(ls_emb)
            kde_p = kde.score_samples(ls_emb)
        else:
            kde = None
            kde_p = None
        label_embs.append((l, ls, ls_emb, kde, kde_p))

    # collect embs for properties
    prop_embs = []
    for tp in props:
        if filter_pids and tp not in filter_pids:
            continue
        has_parent = False
        for l, _, _, _, _ in label_embs:
            if (l, tp) in is_parent:
                has_parent = True
        if not has_parent:
            continue
        tps = sampled_poccs[tp]
        if filter_num_poccs and len(poccs[tp]) < filter_num_poccs:
            continue
        if only_prop_emb:
            tps_emb = np.array([emb[emb_id2ind[tp]]])
        else:
            tps_emb = [
                emb[emb_id2ind[e]] for o in tps for e in o
                if o[0] in emb_id2ind and o[1] in emb_id2ind
            ]
            tps_emb = np.array(tps_emb).reshape(-1, 2 * emb_dim)
        if use_minus:
            tps_emb = tps_emb[:, :emb_dim] - tps_emb[:, emb_dim:]
        if len(tps_emb) == 0:
            continue
        if use_norm:
            tps_emb = norm(tps_emb)
        if method == 'kde':
            kde = KernelDensity(kernel='gaussian',
                                bandwidth=kwargs['sigma']).fit(tps_emb)
            kde_p = kde.score_samples(tps_emb)
        else:
            kde = None
            kde_p = None
        prop_embs.append((tp, tps, tps_emb, kde, kde_p))

    # iterate over properties
    start_time = time.time()
    with multiprocessing.Pool(num_workers) as pool:  # ensure workers are cleaned up
        results = pool.map(
            partial(compute_one,
                    label_embs=label_embs,
                    detect_cheat=detect_cheat,
                    is_parent=is_parent,
                    is_ancestor=is_ancestor,
                    method=method,
                    **kwargs), prop_embs)
    print('use {} secs'.format(time.time() - start_time))

    # collect results
    total, correct, mrr = 0, 0, 0
    correct_li = []
    rank_dict: Dict[str, List] = {}
    kde_dict: Dict[str, Any] = {}
    for result in results:
        if result is None:
            continue
        tp, rank = result
        rank_dict[tp] = rank
        if (rank[0][0], tp) in is_parent:
            correct += 1
            correct_li.append(tp)
        rr = 0
        for i in range(len(rank)):
            if (rank[i][0], tp) in is_parent:
                rr = 1 / (i + 1)
                break
        mrr += rr
        total += 1

    print('acc {}, mrr {}, total {} properties and {} labels'.format(
        correct / total, mrr / total, total, len(label_embs)))
    if debug:
        get_label_dist(correct_li, labels, is_parent, pid2plabel)
        print('correct list:')
        print(correct_li)

    return rank_dict
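
A hedged usage sketch for compute_overlap; the embedding loader call mirrors the one in the training code above, and the split/poccs file names are placeholders for files laid out under data_dir as the function expects (a TSV of test property ids and a pickled dict of {pid: [(head, tail), ...]} occurrences):

# Hypothetical setup; swap in your own file names.
emb_id2ind, emb = read_embeddings_from_text_file('data/emb/transe.txt',
                                                 debug=False,
                                                 emb_size=200,
                                                 use_padding=True,
                                                 first_line=False)
rank_dict = compute_overlap(data_dir='data/split/by_entail',  # assumed
                            split='test_pids.tsv',            # assumed TSV of test property ids
                            poccs='poccs.pickle',             # assumed pickle of property occurrences
                            subprops='data/subprops',
                            emb=emb,
                            emb_id2ind=emb_id2ind,
                            top=100,
                            method='cosine',   # method='kde' would also need a sigma=... kwarg
                            num_workers=4)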
Example #4
        action='store_true',
        help='whether an empty split is allowed; used in within_tree, by_entail, and by_entail-n_way')
    parser.add_argument('--filter_test',
                        action='store_true',
                        help='whether to remove test pid before population')
    parser.add_argument('--remove_common_child',
                        action='store_true',
                        help='keep the deepest child')
    args = parser.parse_args()

    random.seed(2019)
    np.random.seed(2019)

    subprops = read_subprop_file(args.prop_file)
    if args.entityid2name_file:
        with open(args.entityid2name_file, 'rb') as fin:
            entityid2name = pickle.load(fin)
    else:
        entityid2name = None
    pid2plabel = get_pid2plabel(subprops, entityid2name=entityid2name)
    subtrees, isolate = get_all_subtree(subprops)

    if args.remove_common_child:
        print('remove common child')
        remove_common_child(subtrees)
    save_all_subtree(subtrees + isolate, subprops,
                     os.path.join(args.out_dir, 'subprops'))

    subtree_pids = set()  # only consider properties in subtrees