def load_seealso(seealso_file,
                 subprop_file,
                 use_split=False,
                 train_pids=None,
                 dev_pids=None,
                 test_pids=None):
    """Build a ``pid -> set(neighbour pids)`` map from a see-also file.

    Args:
        seealso_file: Path to a tab-separated file with exactly three columns
            per non-empty line; the first column is the head pid, the third
            the tail pid, the middle column is ignored.
        subprop_file: Path handed to ``read_subprop_file`` to build the
            property subtrees.
        use_split: When true, split tiers of the subtrees are inspected:
            heads/tails are renamed to the split pid, and children of a split
            tier that fall in the same data split become mutual neighbours.
        train_pids, dev_pids, test_pids: Iterables of pids, one per data
            split. Only read when ``use_split`` is true.

    Returns:
        A plain ``dict`` mapping each head pid to the set of its tail pids.
        Edges are stored in the head -> tail direction only. With
        ``use_split``, entries derived from split tiers replace any entry
        with the same key read from the file.

    Raises:
        Exception: if a pid is listed more than once across the three splits.
    """
    subtrees, isolate = get_all_subtree(read_subprop_file(subprop_file))

    split_neighbours: Dict[str, set] = defaultdict(set)
    renamed: Dict[str, str] = {}

    if use_split:
        # Which split (0 = train, 1 = dev, 2 = test) each pid belongs to.
        split_of = {}
        for split_idx, split_pids in enumerate(
                [train_pids, dev_pids, test_pids]):
            for pid in split_pids:
                if pid in split_of:
                    raise Exception('duplicate pids')
                split_of[pid] = split_idx

        # special cases for split tier
        for subtree in subtrees:
            for pid, childs in subtree.traverse_each_tier():
                if not PropertySubtree.is_split(pid, childs):
                    continue
                base_pid = pid.split('_')[0]
                if base_pid != pid:
                    renamed[base_pid] = pid
                for left, right in combinations(childs, 2):
                    if left not in split_of or right not in split_of:
                        continue
                    if split_of[left] != split_of[right]:
                        continue
                    # both children live in the same split
                    split_neighbours[left].add(right)
                    split_neighbours[right].add(left)

    # load see also
    neighbours: Dict[str, set] = defaultdict(set)
    with open(seealso_file, 'r') as fin:
        for raw_line in fin:
            line = raw_line.strip()
            if not line:
                continue
            head, _, tail = line.split('\t')
            head = renamed.get(head, head)
            tail = renamed.get(tail, tail)
            neighbours[head].add(tail)

    if use_split:
        neighbours.update(split_neighbours)
    return dict(neighbours)
def run_emb_train(
        data_dir, emb_file, subprop_file, use_label=False, filter_leaves=False,
        only_test_on=None, epoch=10, input_size=400, batch_size=128,
        use_cuda=False, early_stop=None, num_occs=10, num_occs_label=10,
        hidden_size=128, dropout=0.0, lr=0.001, only_prop=False,
        use_tbow=0, tbow_emb_size=50, word_emb_file=None, suffix='.tbow',  # tbow 1
        use_tbow2=0, tbow_emb_size2=50, word_emb_file2=None, suffix2='.tbow',  # tbow 2
        use_sent=0, sent_emb_size=50, sent_emb_file=None, sent_suffix='.sent',  # sent
        only_tbow=False, only_sent=False, renew_word_emb=False,
        output_pred=False, use_ancestor=False, filter_labels=False, acc_topk=1,
        use_weight=False, only_one_sample_per_prop=False, optimizer='adam',
        use_gnn=None, sent_emb_method='cnn_mean', sent_hidden_size=16,
        lr_decay=0):
    """Train and evaluate an n-way property classifier on embeddings.

    Reads ``train.nway`` / ``dev.nway`` / ``test.nway`` and ``label2ind.txt``
    from ``data_dir``, optionally attaches extra per-sample feature files
    (``<split><suffix>``, ``<split><suffix2>``, ``<split><sent_suffix>``),
    optionally filters the samples, then trains either an ``EmbGnnModel``
    (when ``use_gnn`` is set) or an ``EmbModel``.

    Args:
        data_dir: Directory holding the sample, label and feature files.
        emb_file: Text embedding file passed to
            ``read_embeddings_from_text_file``.
        subprop_file: Sub-property file passed to ``read_subprop_file``.
        use_label: Also load ``label2occs.nway`` (plus ``label<suffix>``
            feature files) and pass it to every epoch.
        filter_leaves: Drop samples whose pid is returned by ``get_leaves``.
        only_test_on: If truthy, keep only test samples whose pid is in it.
        epoch: Number of training epochs.
        early_stop: If truthy, stop once the dev metric has failed to improve
            this many times in total (the counter is never reset).
        use_tbow, use_tbow2, use_sent: Switches for the three optional
            feature files; ``suffix``, ``suffix2`` and ``sent_suffix`` are
            their file suffixes, ``word_emb_file``, ``word_emb_file2`` and
            ``sent_emb_file`` their optional pretrained embeddings.
        renew_word_emb: Discard the loaded word embeddings and keep only
            the vocabulary size.
        output_pred: Write ``<split>.pred`` files into ``data_dir``.
        use_ancestor: Replace each label by ``(label, ancestor_index)``.
        filter_labels: Keep only samples whose label occurs in both train
            and test.
        only_one_sample_per_prop: Keep the first sample of every pid.
        optimizer: ``'adam'``, ``'sgd'`` or ``'rmsprop'`` (ignored in the GNN
            branch, which always uses RMSprop).
        use_gnn: If truthy, passed as ``method`` to ``EmbGnnModel``.
        lr_decay: If truthy, ``patience`` of a ``ReduceLROnPlateau``
            scheduler stepped on the dev metric.
        Remaining arguments are forwarded unchanged to the model
        constructors, ``samples2tensors`` and ``one_epoch``.

    Returns:
        With ``use_gnn``: ``(metrics, test_ranks, dev_ranks, train_ranks)``.
        Otherwise: ``(metrics, test_ranks, dev_ranks, train_ranks,
        emb_model)``. ``metrics`` holds the test metric of every epoch.

    Raises:
        NotImplementedError: for an unknown ``optimizer`` name.
    """
    subprops = read_subprop_file(subprop_file)
    # Result is not used below; the call is kept because it cannot be shown
    # from here to be free of side effects.
    pid2plabel = get_pid2plabel(subprops)
    subtrees, _ = get_all_subtree(subprops)
    is_parent = get_is_parent(subprops)
    is_ancestor = get_is_ancestor(subtrees)
    leaves = get_leaves(subtrees)
    pid2parent = {c: p for p, c in is_parent}

    emb_id2ind, emb = read_embeddings_from_text_file(
        emb_file, debug=False, emb_size=200, use_padding=True,
        first_line=False)
    train_samples = read_nway_file(os.path.join(data_dir, 'train.nway'))
    dev_samples = read_nway_file(os.path.join(data_dir, 'dev.nway'))
    test_samples = read_nway_file(os.path.join(data_dir, 'test.nway'))
    print('#train, #dev, #test {}, {}, {}'.format(
        len(train_samples), len(dev_samples), len(test_samples)))
    label2ind = load_tsv_as_dict(os.path.join(data_dir, 'label2ind.txt'),
                                 valuefunc=int)
    ind2label = {v: k for k, v in label2ind.items()}
    num_classes = len(label2ind)

    train_labels = set(s[1] for s in train_samples)
    test_labels = set(s[1] for s in test_samples)
    join_labels = train_labels & test_labels
    print('#labels in train & test {}'.format(len(join_labels)))

    # Ancestor indices are handed out on first access, in lookup order.
    anc2ind: Dict[str, int] = defaultdict(lambda: len(anc2ind))
    if use_ancestor:
        def get_anc_label(parent_label):
            parent_label = ind2label[parent_label]
            if parent_label in pid2parent:
                return anc2ind[pid2parent[parent_label]]
            return anc2ind['NO_ANC']

        def add_anc_label(samples):
            return [((pid, poccs), (plabel, get_anc_label(plabel)))
                    for (pid, poccs), plabel in samples]

        # train -> dev -> test order fixes the ancestor index assignment.
        train_samples = add_anc_label(train_samples)
        dev_samples = add_anc_label(dev_samples)
        test_samples = add_anc_label(test_samples)
        print('#ancestor {}'.format(len(anc2ind)))

    # (switch, file suffix, reader) of every optional per-sample feature
    # file. The second tbow feature uses suffix2, matching the vocabulary
    # file that is loaded for it further down.
    feature_sources = [
        (use_tbow, suffix, read_tbow_file),
        (use_tbow2, suffix2, read_tbow_file),
        (use_sent, sent_suffix, read_sent_file),
    ]

    def attach_features(prefix, samples_li):
        """Append each enabled feature file ``<prefix><suffix>``."""
        for is_use, suff, reader in feature_sources:
            if is_use:
                samples_more = reader(os.path.join(data_dir, prefix + suff))
                assert len(samples_more) == len(samples_li[0])
                samples_li.append(samples_more)

    train_samples_li = [train_samples]
    dev_samples_li = [dev_samples]
    test_samples_li = [test_samples]
    for split, sl in [('train', train_samples_li), ('dev', dev_samples_li),
                      ('test', test_samples_li)]:
        attach_features(split, sl)
    # Once any feature is attached a sample becomes
    # (((pid, poccs), label), feature, ...) instead of ((pid, poccs), label).
    has_features = len(train_samples_li) > 1
    if has_features:
        train_samples = list(zip(*train_samples_li))
        dev_samples = list(zip(*dev_samples_li))
        test_samples = list(zip(*test_samples_li))

    def load_vocab(emb_path, emb_size, suff, renew):
        """Return ``(vocab size, pretrained embedding or None)``."""
        if emb_path:
            id2ind, weights = read_embeddings_from_text_file(
                emb_path, debug=False, emb_size=emb_size, first_line=False,
                use_padding=True, split_char=' ')
            return len(id2ind), (None if renew else weights)
        vocab = load_tsv_as_dict(
            os.path.join(data_dir, '{}.vocab'.format(suff.lstrip('.'))))
        return len(vocab), None

    vocab_size, vocab_size2, sent_vocab_size = None, None, None
    word_emb, word_emb2, sent_emb = None, None, None
    if use_tbow:
        vocab_size, word_emb = load_vocab(
            word_emb_file, tbow_emb_size, suffix, renew_word_emb)
    if use_tbow2:
        vocab_size2, word_emb2 = load_vocab(
            word_emb_file2, tbow_emb_size2, suffix2, renew_word_emb)
    if use_sent:
        sent_vocab_size, sent_emb = load_vocab(
            sent_emb_file, sent_emb_size, sent_suffix, False)
    print('vocab size 1 {}'.format(vocab_size))
    print('vocab size 2 {}'.format(vocab_size2))
    print('sent vocab size {}'.format(sent_vocab_size))

    # The sample layout depends on whether ANY feature was zipped in, so the
    # accessors key off has_features rather than only use_tbow / use_sent.
    if has_features:
        def sample_pid(s):
            return s[0][0][0]

        def sample_label(s):
            return s[0][1]
    else:
        def sample_pid(s):
            return s[0][0]

        def sample_label(s):
            return s[1]

    if filter_leaves:
        print('filter leaves')
        filter_pids = set(leaves)
        train_samples = [s for s in train_samples
                         if sample_pid(s) not in filter_pids]
        dev_samples = [s for s in dev_samples
                       if sample_pid(s) not in filter_pids]
        test_samples = [s for s in test_samples
                        if sample_pid(s) not in filter_pids]

    if filter_labels:
        print('filter labels')
        train_samples = [s for s in train_samples
                         if sample_label(s) in join_labels]
        dev_samples = [s for s in dev_samples
                       if sample_label(s) in join_labels]
        test_samples = [s for s in test_samples
                        if sample_label(s) in join_labels]

    if only_one_sample_per_prop:
        def filter_first(samples, key_func):
            seen = set()
            new_samples = []
            for s in samples:
                k = key_func(s)
                if k in seen:
                    continue
                seen.add(k)
                new_samples.append(s)
            return new_samples

        train_samples = filter_first(train_samples, key_func=sample_pid)
        dev_samples = filter_first(dev_samples, key_func=sample_pid)
        test_samples = filter_first(test_samples, key_func=sample_pid)

    if only_test_on:
        test_samples = [s for s in test_samples
                        if sample_pid(s) in only_test_on]

    if use_label:
        label_samples = read_nway_file(
            os.path.join(data_dir, 'label2occs.nway'))
        label_samples_dict: Dict[int, List] = defaultdict(list)
        for (pid, occs), l in label_samples:
            occs = list(occs)
            shuffle(occs)  # TODO: better than shuffle?
            label_samples_dict[l].extend(occs)
        label_samples = [((ind2label[l], label_samples_dict[l]), l)
                         for l in sorted(label_samples_dict.keys())]
        label_samples_li = [label_samples]
        # extend label samples by features
        attach_features('label', label_samples_li)
        if len(label_samples_li) > 1:
            label_samples = list(zip(*label_samples_li))
    else:
        label_samples = None

    device = torch.device('cuda') if use_cuda else torch.device('cpu')

    print('#samples in train/dev/test: {} {} {}'.format(
        len(train_samples), len(dev_samples), len(test_samples)))

    def score(pred, topk):
        """Thin wrapper fixing the shared ``accuracy_nway`` arguments."""
        return accuracy_nway(pred, ind2label=ind2label, topk=topk,
                             num_classes=num_classes)

    report = ('train: {:>.3f}, {:>.3f} dev: {:>.3f}, {:>.3f} '
              'test: {:>.3f}, {:>.3f}')

    if use_gnn:
        graph_data, meta_data = build_graph(
            train_samples, dev_samples, test_samples, 'data/see_also.tsv',
            os.path.join(data_dir, 'subprops'), emb_id2ind, emb)
        graph_data = {k: v.to(device) for k, v in graph_data.items()}
        emb_model = EmbGnnModel(feat_size=input_size, hidden_size=hidden_size,
                                num_class=num_classes, dropout=dropout,
                                method=use_gnn)
        emb_model.to(device)
        optim = torch.optim.RMSprop(emb_model.parameters(), lr=lr)

        def gnn_pred(logits, split):
            logits = logits.detach().cpu().numpy()
            pairs = zip(meta_data[split + '_pids'],
                        meta_data[split + '_labels'])
            return [(pid, logits[i], l) for i, (pid, l) in enumerate(pairs)]

        last_metric, last_count = None, 0
        metrics = []
        for _epoch in range(epoch):
            # train
            emb_model.train()
            train_logits, _, _, train_loss, _, _ = emb_model(**graph_data)
            train_pred = gnn_pred(train_logits, 'train')
            if train_loss.requires_grad:
                optim.zero_grad()
                train_loss.backward()
                torch.nn.utils.clip_grad_norm_(emb_model.parameters(), 1.0)
                optim.step()
            # eval
            emb_model.eval()
            _, dev_logits, test_logits, _, dev_loss, test_loss = emb_model(
                **graph_data)
            dev_pred = gnn_pred(dev_logits, 'dev')
            test_pred = gnn_pred(test_logits, 'test')
            # metrics
            train_metric, train_ranks, train_pred_label, _ = score(
                train_pred, acc_topk)
            dev_metric, dev_ranks, dev_pred_label, _ = score(
                dev_pred, acc_topk)
            test_metric, test_ranks, test_pred_label, _ = score(
                test_pred, acc_topk)
            print(report.format(train_loss.item(), train_metric,
                                dev_loss.item(), dev_metric,
                                test_loss.item(), test_metric))
            # NOTE(review): a previous dev metric of exactly 0 is falsy and
            # therefore never counts towards early stopping - confirm intent.
            if early_stop and last_metric and last_metric > dev_metric:
                last_count += 1
                if last_count >= early_stop:
                    print('early stop')
                    break
            last_metric = dev_metric
            metrics.append(test_metric)
        # This branch returns a 4-tuple (no model), unlike the one below.
        return metrics, test_ranks, dev_ranks, train_ranks

    emb_model = EmbModel(
        emb, num_classes, len(anc2ind), input_size=input_size,
        hidden_size=hidden_size, padding_idx=0, dropout=dropout,
        only_prop=only_prop, use_label=use_label, vocab_size=vocab_size,
        tbow_emb_size=tbow_emb_size, word_emb=word_emb,
        vocab_size2=vocab_size2, tbow_emb_size2=tbow_emb_size2,
        word_emb2=word_emb2, sent_vocab_size=sent_vocab_size,
        sent_emb_size=sent_emb_size, sent_emb=sent_emb, only_tbow=only_tbow,
        only_sent=only_sent, use_weight=use_weight,
        sent_emb_method=sent_emb_method, sent_hidden_size=sent_hidden_size)
    emb_model.to(device)

    # Keep the `optimizer` argument (a name) separate from the optimizer
    # object built from it.
    if optimizer == 'adam':
        optim = torch.optim.Adam(emb_model.parameters(), lr=lr)
    elif optimizer == 'sgd':
        optim = torch.optim.SGD(emb_model.parameters(), lr=lr)
    elif optimizer == 'rmsprop':
        optim = torch.optim.RMSprop(emb_model.parameters(), lr=lr)
    else:
        raise NotImplementedError
    if lr_decay:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optim, 'max', patience=lr_decay)

    def to_tensors(samples, bsz, occs_per_sample):
        return samples2tensors(
            samples, bsz, emb_id2ind, occs_per_sample, device, only_prop,
            use_tbow=use_tbow, use_tbow2=use_tbow2, use_sent=use_sent,
            use_anc=use_ancestor)

    def run_epoch(samples, tensors, is_train):
        return one_epoch(
            emb_model, optim, samples, tensors, batch_size, emb_id2ind,
            num_occs, device, only_prop, is_train=is_train,
            label_samples=label_samples, label_tensors=label_tensors,
            use_tbow=use_tbow, use_sent=use_sent, use_anc=use_ancestor,
            num_occs_label=num_occs_label)

    last_metric, last_count = None, 0
    metrics = []
    train_tensors = to_tensors(train_samples, batch_size, num_occs)
    dev_tensors = to_tensors(dev_samples, batch_size, num_occs)
    test_tensors = to_tensors(test_samples, batch_size, num_occs)
    if use_label:
        # All label samples go into a single batch; keep that one batch.
        label_tensors = to_tensors(
            label_samples, len(label_samples), num_occs_label)[0]
    else:
        label_tensors = None

    for _epoch in range(epoch):
        # train
        train_loss, train_pred, _ = run_epoch(
            train_samples, train_tensors, True)
        # dev
        dev_loss, dev_pred, _ = run_epoch(dev_samples, dev_tensors, False)
        # test
        test_loss, test_pred, _ = run_epoch(test_samples, test_tensors, False)

        train_metric, train_ranks, train_pred_label, _ = score(
            train_pred, acc_topk)
        dev_metric, dev_ranks, dev_pred_label, _ = score(dev_pred, acc_topk)
        test_metric, test_ranks, test_pred_label, _ = score(
            test_pred, acc_topk)
        if lr_decay:
            scheduler.step(dev_metric)

        print(report.format(np.mean(train_loss), train_metric,
                            np.mean(dev_loss), dev_metric,
                            np.mean(test_loss), test_metric))
        metrics.append(test_metric)
        # NOTE(review): a previous dev metric of exactly 0 is falsy and
        # therefore never counts towards early stopping - confirm intent.
        if early_stop and last_metric and last_metric >= dev_metric:
            last_count += 1
            if last_count >= early_stop:
                print('early stop')
                break
        last_metric = dev_metric

    # get rank
    _, train_ranks, _, _ = score(train_pred, None)
    _, dev_ranks, _, _ = score(dev_pred, None)
    _, test_ranks, _, _ = score(test_pred, None)
    test_ranks = get_ranks(test_ranks, is_parent=is_parent,
                           is_ancestor=is_ancestor)
    dev_ranks = get_ranks(dev_ranks, is_parent=is_parent,
                          is_ancestor=is_ancestor)
    train_ranks = get_ranks(train_ranks, is_parent=is_parent,
                            is_ancestor=is_ancestor)

    if output_pred:
        # Explicit mapping instead of eval() on a formatted variable name.
        pred_labels = [('train', train_pred_label),
                       ('dev', dev_pred_label),
                       ('test', test_pred_label)]
        for fn, split_pred_label in pred_labels:
            with open(os.path.join(data_dir, fn + '.pred'), 'w') as fout:
                for pl in split_pred_label:
                    fout.write('{}\n'.format(pl))

    return metrics, test_ranks, dev_ranks, train_ranks, emb_model
def compute_overlap(data_dir,
                    split,
                    poccs,
                    subprops,
                    emb,
                    emb_id2ind,
                    top=100,
                    method='cosine',
                    only_prop_emb=False,
                    detect_cheat=True,
                    use_minus=False,
                    filter_num_poccs=None,
                    filter_pids=None,
                    num_workers=1,
                    skip_split=False,
                    ori_subprops=None,
                    debug=False,
                    use_norm=False,
                    **kwargs):
    """Rank candidate parent labels for each property and report acc / MRR.

    For every label listed in ``label2ind.txt`` and every property listed in
    the ``split`` file(s), an embedding matrix is built from sampled
    occurrences; ``compute_one`` is then mapped over the properties in a
    worker pool and the returned rankings are scored against ``is_parent``.

    Args:
        data_dir: Directory containing ``poccs``, the split file(s) and
            ``label2ind.txt``.
        split: A file name or a list of file names (relative to
            ``data_dir``) whose keys are the properties to evaluate.
        poccs: Name of a pickle file (relative to ``data_dir``) holding, per
            property id, a collection of occurrence pairs.
        subprops: Path handed to ``read_subprop_file``.
        emb: 2-D embedding array indexed by ``emb_id2ind`` values.
        emb_id2ind: Mapping from id to row index in ``emb``.
        top: If truthy, keep at most this many randomly ordered occurrences
            per property; otherwise keep all of them.
        method: Forwarded to ``compute_one``; ``'kde'`` additionally fits a
            Gaussian ``KernelDensity`` (needs ``kwargs['sigma']``).
        only_prop_emb: Use the property's own embedding instead of its
            occurrence embeddings.
        detect_cheat: Forwarded to ``compute_one``.
        use_minus: Represent each occurrence by first-half minus second-half
            of the concatenated pair embedding.
        filter_num_poccs: If truthy, skip ids with fewer raw occurrences.
        filter_pids: If truthy, only evaluate properties contained in it.
        num_workers: Size of the multiprocessing pool.
        skip_split: Only keep labels that are parents in ``ori_subprops``.
        ori_subprops: Sub-property file read when ``skip_split`` is set.
        debug: Print extra label-distribution diagnostics.
        use_norm: Apply ``norm`` to every embedding matrix.
        **kwargs: Forwarded to ``compute_one``.

    Returns:
        Dict mapping each evaluated property id to the ranking returned by
        ``compute_one``.
    """
    emb_dim = emb.shape[1]
    print('#words in emb: {}'.format(len(emb_id2ind)))

    # load poccs
    # NOTE(review): pickle.load can execute arbitrary code; this must only be
    # pointed at trusted files.
    poccs_path = os.path.join(data_dir, poccs)
    with open(poccs_path, 'rb') as fin:
        poccs = pickle.load(fin)

    # sample occurrences for each property
    all_ids = set()
    sampled_poccs = {}
    for k in poccs:
        if top:
            # Full-length sample == shuffle; kept in this form so the random
            # stream consumed per property stays unchanged.
            sampled_poccs[k] = random.sample(poccs[k], len(poccs[k]))[:top]
        else:
            sampled_poccs[k] = poccs[k]
        if len(sampled_poccs[k]) > 0:
            heads, tails = zip(*sampled_poccs[k])
            all_ids.update(heads)
            all_ids.update(tails)
    all_ids_emb = all_ids & set(emb_id2ind.keys())
    print('#entities in property occs: {}, {} have embs'.format(
        len(all_ids), len(all_ids_emb)))

    # load properties for test
    if isinstance(split, str):
        split = [split]
    props = []
    for sp in split:
        props.extend(load_tsv_as_dict(os.path.join(data_dir, sp)).keys())

    # load labels/parents
    labels = list(
        load_tsv_as_dict(os.path.join(data_dir, 'label2ind.txt')).keys())

    # load subprops
    subprops = read_subprop_file(subprops)
    is_parent = get_is_parent(subprops)
    subtrees, _ = get_all_subtree(subprops)
    is_ancestor = get_is_ancestor(subtrees)
    pid2plabel = get_pid2plabel(subprops)
    if skip_split:
        ori_subprops = read_subprop_file(ori_subprops)
        ori_parents = set(p for p, c in get_is_parent(ori_subprops))

    print('#property: {} #label: {}'.format(len(props), len(labels)))
    if debug:
        get_label_dist(props, labels, is_parent, pid2plabel)

    def build_entry(pid, occs):
        """Return ``(pid, occs, emb matrix, kde, kde scores)`` or ``None``.

        ``None`` means no occurrence of ``pid`` has both ends embedded.
        """
        if only_prop_emb:
            vecs = np.array([emb[emb_id2ind[pid]]])
        else:
            # Keep an occurrence only when both of its ends have embeddings,
            # then lay each occurrence out as one concatenated row.
            vecs = [
                emb[emb_id2ind[e]] for o in occs for e in o
                if o[0] in emb_id2ind and o[1] in emb_id2ind
            ]
            vecs = np.array(vecs).reshape(-1, 2 * emb_dim)
            if use_minus:
                vecs = vecs[:, :emb_dim] - vecs[:, emb_dim:]
        if vecs.shape[0] <= 0:
            return None
        if use_norm:
            vecs = norm(vecs)
        if method == 'kde':
            kde = KernelDensity(kernel='gaussian',
                                bandwidth=kwargs['sigma']).fit(vecs)
            kde_p = kde.score_samples(vecs)
        else:
            kde = None
            kde_p = None
        return pid, occs, vecs, kde, kde_p

    # collect embs for labels
    label_embs = []
    for l in labels:
        if l not in sampled_poccs or len(sampled_poccs[l]) <= 0:
            continue
        if skip_split and l not in ori_parents:
            continue
        if filter_num_poccs and len(poccs[l]) < filter_num_poccs:
            continue
        entry = build_entry(l, sampled_poccs[l])
        if entry is not None:
            label_embs.append(entry)

    # collect embs for properties
    kept_labels = [entry[0] for entry in label_embs]
    prop_embs = []
    for tp in props:
        if filter_pids and tp not in filter_pids:
            continue
        # Only evaluate properties whose true parent is among the candidates.
        if not any((l, tp) in is_parent for l in kept_labels):
            continue
        tps = sampled_poccs[tp]
        if filter_num_poccs and len(poccs[tp]) < filter_num_poccs:
            continue
        entry = build_entry(tp, tps)
        if entry is not None:
            prop_embs.append(entry)

    # iterate over properties
    start_time = time.time()
    worker = partial(compute_one,
                     label_embs=label_embs,
                     detect_cheat=detect_cheat,
                     is_parent=is_parent,
                     is_ancestor=is_ancestor,
                     method=method,
                     **kwargs)
    # The context manager releases the worker processes once map() is done.
    with multiprocessing.Pool(num_workers) as pool:
        results = pool.map(worker, prop_embs)
        elapsed = time.time() - start_time
    print('use {} secs'.format(elapsed))

    # collect results
    total, correct, mrr = 0, 0, 0
    correct_li = []
    rank_dict: Dict[str, List] = {}
    for result in results:
        if result is None:
            continue
        tp, rank = result
        rank_dict[tp] = rank
        if (rank[0][0], tp) in is_parent:
            correct += 1
            correct_li.append(tp)
        # Reciprocal rank of the first true parent (0 when none is ranked).
        mrr += next((1 / (i + 1) for i, cand in enumerate(rank)
                     if (cand[0], tp) in is_parent), 0)
        total += 1

    # Nothing evaluated: report zeros instead of dividing by zero.
    acc = correct / total if total else 0.0
    mean_rr = mrr / total if total else 0.0
    print('acc {}, mrr {}, total {} properties and {} labels'.format(
        acc, mean_rr, total, len(label_embs)))
    if debug:
        get_label_dist(correct_li, labels, is_parent, pid2plabel)
        print('correct list:')
        print(correct_li)
    return rank_dict
action='store_true', help= 'whether empty split is allowed. used in within_tree and by_entail and by_entail-n_way' ) parser.add_argument('--filter_test', action='store_true', help='whether to remove test pid before population') parser.add_argument('--remove_common_child', action='store_true', help='keep the deepest child') args = parser.parse_args() random.seed(2019) np.random.seed(2019) subprops = read_subprop_file(args.prop_file) if args.entityid2name_file: with open(args.entityid2name_file, 'rb') as fin: entityid2name = pickle.load(fin) else: entityid2name = None pid2plabel = get_pid2plabel(subprops, entityid2name=entityid2name) subtrees, isolate = get_all_subtree(subprops) if args.remove_common_child: print('remove common child') remove_common_child(subtrees) save_all_subtree(subtrees + isolate, subprops, os.path.join(args.out_dir, 'subprops')) subtree_pids = set() # only consider properties in subtrees