Code Example #1
File: main.py  Project: sysu2019/Rule-based-GNN
def dump_current_scores_of_devtest(args, m, xp):
    for mode in ['dev', 'test']:
        if mode == 'dev': current_data = dev_data
        if mode == 'test': current_data = test_data

        scores, accuracy = list(), list()
        for batch in chunked(current_data, args.test_batch_size):
            with chainer.using_config('train',
                                      False), chainer.no_backprop_mode():
                current_score = m.get_scores(batch, glinks, grelations, gedges,
                                             xp, mode)
            for v, (h, r, t, l) in zip(current_score.data, batch):
                values = (h, r, t, l, v)
                values = map(str, values)
                values = ','.join(values)
                scores.append(values)
                if v < args.threshold:  # the model predicts a positive triple
                    if l == 1: accuracy.append(1.0)
                    else: accuracy.append(0.0)
                else:  # the model predicts a negative triple
                    if l == 1: accuracy.append(0.0)
                    else: accuracy.append(1.0)
            del current_score
        tool.trace('\t ', mode, sum(accuracy) / len(accuracy))
        if args.margin_file != '':
            with open(args.margin_file, 'a') as wf:
                wf.write(mode + ':' + ' '.join(scores) + '\n')
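The accuracy bookkeeping above treats a triple as positive when its score is below args.threshold and checks that prediction against the gold label l. A minimal sketch of the same decision rule on made-up scores (the threshold and all values are hypothetical):

# Toy illustration of the threshold rule used above (all values are made up).
threshold = 300.0
scored = [  # (head, relation, tail, gold label, model score)
    (0, 2, 5, 1, 120.5),  # low score, gold positive  -> counted as correct
    (3, 1, 7, 0, 480.2),  # high score, gold negative -> counted as correct
    (4, 2, 9, 1, 410.0),  # high score, gold positive -> counted as wrong
]
accuracy = [1.0 if (score < threshold) == (label == 1) else 0.0
            for _, _, _, label, score in scored]
print(sum(accuracy) / len(accuracy))  # 0.666...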
Code Example #2
File: main.py  Project: sysu2019/Rule-based-GNN
def rules_addition():
    tool.trace('load rules')
    global premise
    premise = defaultdict(set)
    for line in tool.read(args.rules_file):
        r, r_, flag = list(map(int, line.strip().split('\t')))
        premise[r].add((r_, flag))
    print(len(premise[0]))
    print(len(premise))
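rules_addition reads args.rules_file as tab-separated lines of a premise relation r, an implied relation r_, and a direction flag; judging from how premise is consumed in initilize_dataset below, flag 1 keeps the head/tail order and any other value swaps it. A hypothetical file and the resulting premise dictionary:

from collections import defaultdict

# Hypothetical rules (tab-separated): premise relation, implied relation, direction flag.
lines = ['0\t3\t1',  # (h, 0, t) implies (h, 3, t)
         '0\t4\t0',  # (h, 0, t) implies (t, 4, h)
         '2\t5\t1']  # (h, 2, t) implies (h, 5, t)
premise = defaultdict(set)
for line in lines:
    r, r_, flag = map(int, line.strip().split('\t'))
    premise[r].add((r_, flag))
print(dict(premise))  # {0: {(3, 1), (4, 0)}, 2: {(5, 1)}} (set order may vary)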
Code Example #3
File: main_CNN.py  Project: ZMLight/PAKDD-code
def main(args):
    loadflag = False
    initilize_dataset()
    args.rel_size, args.entity_size = get_sizes(args)
    print('relation size:', args.rel_size, 'entity size:', args.entity_size)

    RelCovVal = [1.0] * args.rel_size
    RelCovValTmp = [0.0] * args.rel_size
    EntCovVal = [1.0] * args.entity_size
    EntCovValTmp = [0.0] * args.entity_size

    xp = Backend(args)
    m = get_model(args)
    opt = get_opt(args)
    if (loadflag == True):
        serializers.load_npz(
            './savedWN/ModelA3CNN_OOKBtest_' + args.pooling_method + '.model',
            m)
    opt.setup(m)

    for epoch in range(args.epoch_size):
        '''
		fileTmp = open('Ent100Eval_v4.txt','a')
		for i in range(100):
			fileTmp.write(str(EntCovVal[i])+'\t')
		fileTmp.write('\n')
		fileTmp.close()
		'''

        opt.alpha = args.beta0 / (1.0 + args.beta1 * epoch)
        trLoss, Ntr = train(args, m, xp, opt, RelCovVal, RelCovValTmp,
                            EntCovVal, EntCovValTmp)
        '''
		for i in range(args.rel_size):
			RelCovVal[i] = RelCovValTmp[i]
			RelCovValTmp[i] = 0.0
		for i in range(args.entity_size):
			EntCovVal[i] = EntCovValTmp[i]
			EntCovValTmp[i] = 0.0'''
        tool.trace('epoch:', epoch, 'tr Loss:', tool.dress(trLoss), Ntr)
        if (epoch % 5 == 0):
            serializers.save_npz(
                './savedWN/ModelA3CNN_OOKBtest_' + args.pooling_method +
                '.model', m)
        dump_current_scores_of_devtest(args, m, xp, RelCovVal, EntCovVal)
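Each main() in these examples anneals Adam's learning rate with the same inverse-time schedule, alpha = beta0 / (1 + beta1 * epoch). A small sketch of how the rate decays over the first few epochs (the beta values are illustrative; the real ones come from args):

# Inverse-time decay of the Adam learning rate used in the training loops above.
beta0, beta1 = 0.001, 0.1  # hypothetical values
for epoch in range(5):
    alpha = beta0 / (1.0 + beta1 * epoch)
    print(epoch, round(alpha, 6))
# 0 0.001
# 1 0.000909
# 2 0.000833
# 3 0.000769
# 4 0.000714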
Code Example #4
File: main_ATT2.py  Project: ZMLight/PAKDD-code
def main(args):
	loadflag = False
	initilize_dataset()
	args.rel_size,args.entity_size = get_sizes(args)
	print('relation size:',args.rel_size,'entity size:',args.entity_size)

	W1lists = []
	file1 = open(args.pretrain_file + '/relation2vec.vec','r')
	for line in file1.readlines():
		line = line.strip().split('\t')
		W1lists.append(line)
	W1 = np.array(W1lists).astype(np.float32)

	W2lists = []
	file2 = open(args.pretrain_file + '/entity2vec.vec','r')
	for line in file2.readlines():
		line = line.strip().split('\t')
		W2lists.append(line)
	W2 = np.array(W2lists).astype(np.float32)

	#print W1.shape, W2.shape

	RelCovVal = [1.0] * args.rel_size
	RelCovValTmp = [0.0] * args.rel_size
	EntCovVal = [1.0] * args.entity_size
	EntCovValTmp = [0.0] * args.entity_size

	xp = Backend(args)
	m = get_model(args,W1,W2)
	opt = get_opt(args)
	if(loadflag == True):
		serializers.load_npz('./savedFB/ModelA3CNN_ATT_0817_head.model',m)
	opt.setup(m)
	normalizeR = m.Normal(xp)

	#random.shuffle(train_data)

	for epoch in range(args.epoch_size):
		opt.alpha = args.beta0/(1.0+args.beta1*epoch)
		trLoss,Ntr = train(args,m,xp,opt,RelCovVal,RelCovValTmp,EntCovVal,EntCovValTmp)
		normalizeR = m.Normal(xp)
		tool.trace('epoch:',epoch,'tr Loss:',tool.dress(trLoss),Ntr)
		if(epoch % 5 == 0):
			serializers.save_npz('./savedFB/ModelA3CNN_ATT_0817_head.model',m)
		dump_current_scores_of_devtest(args,m,xp,RelCovVal,EntCovVal)
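The two read loops above load tab-separated pretrained vectors from relation2vec.vec and entity2vec.vec into float32 matrices and leave both file handles open. Assuming the same layout (one vector per line, values separated by single tabs with no trailing tab), an equivalent sketch using numpy.loadtxt and a context manager:

import numpy as np

# Sketch: load a tab-separated embedding file into a float32 matrix,
# assuming one vector per line as in the loops above.
def load_vec(path):
    with open(path, 'r') as f:
        return np.loadtxt(f, delimiter='\t', dtype=np.float32)

# W1 = load_vec(args.pretrain_file + '/relation2vec.vec')
# W2 = load_vec(args.pretrain_file + '/entity2vec.vec')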
Code Example #5
File: main.py  Project: sysu2019/Rule-based-GNN
def main(args):
    # load the rules
    rules_addition()
    # initialize all datasets
    initilize_dataset()
    # override the default relation/entity counts based on the data files
    args.rel_size, args.entity_size = get_sizes(args)
    print('relation size:', args.rel_size, 'entity size:', args.entity_size)

    xp = Backend(args)  # returns a callable backend object
    m = get_model(args)  # return A0(args)
    # Setup an optimizer
    # pick the optimizer used for training (Adam by default)
    opt = get_opt(args)  # return optimizers.Adam()
    # setup() just gives the optimizer a link (the model) to optimize
    opt.setup(m)
    for epoch in range(args.epoch_size):
        opt.alpha = args.beta0 / (1.0 + args.beta1 * epoch)
        trLoss, Ntr = train(args, m, xp, opt)  # the train() defined in main.py
        tool.trace('epoch:', epoch, 'tr Loss:', tool.dress(trLoss), Ntr)
        dump_current_scores_of_devtest(args, m, xp)
Code Example #6
File: main.py  Project: ZMLight/PAKDD-code
def main(args):
    loadflag = False
    initilize_dataset()
    args.rel_size, args.entity_size = get_sizes(args)
    print('relation size:', args.rel_size, 'entity size:', args.entity_size)

    xp = Backend(args)
    m = get_model(args)
    opt = get_opt(args)
    if (loadflag == True):
        serializers.load_npz(
            './savedWN/ModelA0_OOKBtest_' + args.pooling_method, m)
    opt.setup(m)
    for epoch in range(args.epoch_size):
        #dump_current_scores_of_devtest(args,m,xp)
        opt.alpha = args.beta0 / (1.0 + args.beta1 * epoch)
        trLoss, Ntr = train(args, m, xp, opt)
        tool.trace('epoch:', epoch, 'tr Loss:', tool.dress(trLoss), Ntr)
        if (epoch % 10 == 0):
            serializers.save_npz(
                './savedWN/ModelA0_OOKBtest_' + args.pooling_method, m)
        dump_current_scores_of_devtest(args, m, xp)
Code Example #7
File: main.py  Project: sysu2019/Rule-based-GNN
def initilize_dataset():
    global candidate_heads, gold_heads, candidate_tails, gold_tails
    global glinks, grelations, gedges
    global train_data, dev_data, test_data, trfreq

    # get properties of knowledge graph
    # read the training set
    tool.trace('load train')
    grelations = defaultdict(set)
    glinks = defaultdict(set)
    train_data = set()
    for line in tool.read(args.train_file):
        h, r, t = list(map(int, line.strip().split('\t')))
        # train_data holds all triples
        train_data.add((
            h,
            r,
            t,
        ))
        # trfreq maps a relation ID to the number of triples with that relation
        trfreq[r] += 1
        # grelations maps (head entity, tail entity) to the set of relation IDs between them
        grelations[(h, t)].add(r)
        # store the neighbors of each entity
        glinks[t].add(h)
        glinks[h].add(t)
        # gold_* look up the gold heads/tails of each triple
        gold_heads[(r, t)].add(h)
        gold_tails[(h, r)].add(t)
        # candidate_* map a relation to the head/tail entities observed with it
        candidate_heads[r].add(h)
        candidate_tails[r].add(t)
        # stored like glinks, but glinks is keyed by every entity, while these are keyed only by heads (or tails)
        tail_per_head[h].add(t)
        head_per_tail[t].add(h)
        # apply the rules
        if r in premise:
            conclusion = premise[r]
            for r_, flag in conclusion:
                if flag == 1:
                    grelations[(h, t)].add(r_)
                    train_data.add((
                        h,
                        r_,
                        t,
                    ))
                    trfreq[r_] += 1
                    gold_heads[(r_, t)].add(h)
                    gold_tails[(h, r_)].add(t)
                    candidate_heads[r_].add(h)
                    candidate_tails[r_].add(t)
                else:
                    grelations[(t, h)].add(r_)
                    train_data.add((
                        t,
                        r_,
                        h,
                    ))
                    trfreq[r_] += 1
                    gold_heads[(r_, h)].add(t)
                    gold_tails[(t, r_)].add(h)
                    candidate_heads[r_].add(t)
                    candidate_tails[r_].add(h)
                    tail_per_head[t].add(h)
                    head_per_tail[h].add(t)

    for e in glinks:
        glinks[e] = list(glinks[e])
    for r in candidate_heads:
        candidate_heads[r] = list(candidate_heads[r])
    for r in candidate_tails:
        candidate_tails[r] = list(candidate_tails[r])
    # overwrite the sets: each key now maps to the number of entities linked to it
    for h in tail_per_head:
        tail_per_head[h] = len(tail_per_head[h]) + 0.0
    for t in head_per_tail:
        head_per_tail[t] = len(head_per_tail[t]) + 0.0

    # build the auxiliary set
    tool.trace('set axiaulity')
    # switch standard setting or OOKB setting
    # two experiments: standard triple classification and the OOKB setting
    if args.train_file == args.auxiliary_file:
        tool.trace('standard setting, use: edges=links')
        gedges = glinks
    else:
        # does the OOKB experiment only modify grelations and gedges?
        tool.trace('OOKB esetting, use: different edges')
        gedges = defaultdict(set)
        for line in tool.read(args.auxiliary_file):
            h, r, t = list(map(int, line.strip().split('\t')))
            grelations[(h, t)].add(r)
            # apply the rules
            if r in premise:
                conclusion = premise[r]
                for r_, flag in conclusion:
                    if flag == 1:
                        grelations[(h, t)].add(r_)
                    else:
                        grelations[(t, h)].add(r_)
            gedges[t].add(h)
            gedges[h].add(t)
        for e in gedges:
            gedges[e] = list(gedges[e])

    for (h, t) in grelations:
        grelations[(h, t)] = list(grelations[(h, t)])
        # if len(grelations[(h,t)]) != 1:
        # 	print("len(grelations[(h,t)])",len(grelations[(h,t)]),h,t,grelations[(h,t)])
    print("grelations", len(grelations))
    train_data = list(train_data)
    print("train_data", len(train_data))
    for r in trfreq:
        trfreq[r] = args.train_size / (float(trfreq[r]) * len(trfreq))

    # load dev
    # read the dev set; it contains both positive and negative triples
    tool.trace('load dev')
    dev_data = list()
    for line in open(args.dev_file):
        h, r, t, l = list(map(int, line.strip().split('\t')))
        # filter out dev triples that contain OOKB entities (only takes effect in the OOKB experiment?)
        if h not in glinks or t not in glinks: continue
        dev_data.append((
            h,
            r,
            t,
            l,
        ))
    print('dev size:', len(dev_data))

    # load test
    # read the test set; it contains both positive and negative triples
    tool.trace('load test')
    test_data = list()
    for line in open(args.test_file):
        h, r, t, l = list(map(int, line.strip().split('\t')))
        # filter out test triples that contain OOKB entities (shouldn't the OOKB experiment filter out the ones without them instead?)
        # no OOKB entities in the standard triple-classification experiment?
        if h not in glinks or t not in glinks: continue
        test_data.append((
            h,
            r,
            t,
            l,
        ))
    print('test size:', len(test_data))
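The comments above name several index structures. For a toy training file containing only the triples (0, 1, 2) and (3, 1, 2), the loop would fill them roughly as follows (a hand-worked sketch that ignores the rule expansion):

from collections import defaultdict

# Toy reconstruction of the main indices built above (rule expansion ignored).
triples = [(0, 1, 2), (3, 1, 2)]
grelations, glinks = defaultdict(set), defaultdict(set)
gold_heads, gold_tails = defaultdict(set), defaultdict(set)
candidate_heads, candidate_tails = defaultdict(set), defaultdict(set)
trfreq = defaultdict(int)
for h, r, t in triples:
    grelations[(h, t)].add(r)
    glinks[t].add(h)
    glinks[h].add(t)
    gold_heads[(r, t)].add(h)
    gold_tails[(h, r)].add(t)
    candidate_heads[r].add(h)
    candidate_tails[r].add(t)
    trfreq[r] += 1
print(dict(grelations))  # {(0, 2): {1}, (3, 2): {1}}
print(dict(glinks))      # {2: {0, 3}, 0: {2}, 3: {2}}
print(dict(gold_heads))  # {(1, 2): {0, 3}}
print(dict(trfreq))      # {1: 2}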
Code Example #8
def initilize_dataset():
    global candidate_heads, gold_heads, candidate_tails, gold_tails
    global glinks, grelations, gedges, grelationsT, grelationsEdges

    # get properties of knowledge graph
    tool.trace('load train')
    grelations = dict()
    grelationsT = defaultdict(set)
    glinks = defaultdict(set)
    for line in tool.read(args.train_file):
        h, r, t = list(map(int, line.strip().split('\t')))
        grelations[(h, t)] = r
        grelationsT[r].add((h, t))
        glinks[t].add(h)
        glinks[h].add(t)
        gold_heads[(r, t)].add(h)
        gold_tails[(h, r)].add(t)
        candidate_heads[r].add(h)
        candidate_tails[r].add(t)
        tail_per_head[h].add(t)
        head_per_tail[t].add(h)
    for e in glinks:
        glinks[e] = list(glinks[e])
    for r in grelationsT:
        grelationsT[r] = list(grelationsT[r])

    for r in candidate_heads:
        candidate_heads[r] = list(candidate_heads[r])
    for r in candidate_tails:
        candidate_tails[r] = list(candidate_tails[r])
    for h in tail_per_head:
        tail_per_head[h] = len(tail_per_head[h]) + 0.0
    for t in head_per_tail:
        head_per_tail[t] = len(head_per_tail[t]) + 0.0

    tool.trace('set axiaulity')
    # switch standard setting or OOKB setting
    if args.train_file == args.auxiliary_file:
        tool.trace('standard setting, use: edges=links')
        gedges = glinks
        grelationsEdges = grelationsT
    else:
        tool.trace('OOKB esetting, use: different edges')
        gedges = defaultdict(set)
        grelationsEdges = defaultdict(set)
        for line in tool.read(args.auxiliary_file):
            h, r, t = list(map(int, line.strip().split('\t')))
            grelationsEdges[r].add((h, t))
            #grelations[(h,t)]=r
            gedges[t].add(h)
            gedges[h].add(t)
        for e in gedges:
            gedges[e] = list(gedges[e])
        for r in grelationsEdges:
            grelationsEdges[r] = list(grelationsEdges[r])

    global train_data, dev_data, test_data, trfreq
    # load train
    train_data = set()
    for line in open(args.train_file):
        h, r, t = list(map(int, line.strip().split('\t')))
        train_data.add((
            h,
            r,
            t,
        ))
        trfreq[r] += 1
    train_data = list(train_data)
    for r in trfreq:
        trfreq[r] = args.train_size / (float(trfreq[r]) * len(trfreq))

    # load dev
    tool.trace('load dev')
    dev_data = list()
    for line in open(args.dev_file):
        h, r, t, l = list(map(int, line.strip().split('\t')))
        if h not in glinks or t not in glinks: continue
        dev_data.append((
            h,
            r,
            t,
            l,
        ))
    print('dev size:', len(dev_data))

    # load test
    tool.trace('load test')
    test_data = list()
    for line in open(args.test_file):
        h, r, t, l = list(map(int, line.strip().split('\t')))
        #if h not in glinks and h < args.entity_size: continue
        #if t not in glinks and t < args.entity_size: continue
        if h not in glinks or t not in glinks: continue
        test_data.append((
            h,
            r,
            t,
            l,
        ))
    print('test size:', len(test_data))
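Both versions of initilize_dataset finish by turning trfreq from a raw count into a weight, trfreq[r] = train_size / (count_r * num_relations), so rarer relations end up with a larger weight, presumably to balance relations during training. A small worked example with made-up counts:

# Worked example of trfreq[r] = train_size / (count * num_relations) with made-up counts.
train_size = 100            # hypothetical; the real value comes from args.train_size
counts = {0: 50, 1: 30, 2: 20}
num_relations = len(counts)
weights = {r: train_size / (float(c) * num_relations) for r, c in counts.items()}
print(weights)  # {0: 0.666..., 1: 1.111..., 2: 1.666...}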
Code Example #9
File: main.py  Project: majingbit/GNN-for-OOKB
def main(args):
    global candidate_heads, gold_heads, candidate_tails, gold_tails, black_set
    xp = XP(args)
    args.rel_size, args.entity_size = get_sizes(args)
    print('relation size:', args.rel_size, 'entity size:', args.entity_size)
    m = get_model(args)
    opt = get_opt(args)
    opt.setup(m)

    relations = dict()
    links = defaultdict(set)
    for line in tool.read(args.train_file):
        items = list(map(int, line.strip().split('\t')))
        if len(items) == 4:
            h, r, t, l = items
            if l == 0: continue
        else:
            h, r, t = items
        relations[(h, t)] = r
        links[t].add(h)
        links[h].add(t)
        gold_heads[(r, t)].add(h)
        gold_tails[(h, r)].add(t)
        candidate_heads[r].add(h)
        candidate_tails[r].add(t)
        tail_per_head[h].add(t)
        head_per_tail[t].add(h)
    for e in links:
        links[e] = list(links[e])

    for p in gold_heads:
        if len(candidate_heads[p[0]] - gold_heads[p]) == 0:
            p = (-p[0], p[1])
            black_set.add(p)
    for p in gold_tails:
        if len(candidate_tails[p[1]] - gold_tails[p]) == 0:
            black_set.add(p)
    print('black list size:', len(black_set))
    for r in candidate_heads:
        candidate_heads[r] = list(candidate_heads[r])
    for r in candidate_tails:
        candidate_tails[r] = list(candidate_tails[r])
    for h in tail_per_head:
        tail_per_head[h] = len(tail_per_head[h]) + 0.0
    for t in head_per_tail:
        head_per_tail[t] = len(head_per_tail[t]) + 0.0

    if args.train_file == args.auxiliary_file:
        tool.trace('use: edges=links')
        edges = links
    else:
        tool.trace('use: different edges')
        edges = defaultdict(set)
        for line in tool.read(args.auxiliary_file):
            items = list(map(int, line.strip().split('\t')))
            if len(items) == 4:
                h, r, t, l = items
                if l == 0: continue
            else:
                h, r, t = items
            relations[(h, t)] = r
            edges[t].add(h)
            edges[h].add(t)
        for e in edges:
            edges[e] = list(edges[e])

    for epoch in range(args.epoch_size):
        opt.alpha = args.beta0 / (1.0 + args.beta1 * epoch)
        trLoss, Ntr = train(args, m, xp, opt, links, relations, edges)
        evaluate(args, m, xp, links, relations, edges)
        tool.trace('epoch:', epoch, 'tr Loss:', tool.dress(trLoss), Ntr)
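The black_set above marks (relation, tail) and (head, relation) pairs for which every candidate head (or tail) of that relation is already a gold answer, so no corrupted entity could be sampled for them; head-side pairs are stored with the relation negated to keep the two kinds distinct. A toy check of that condition:

# Toy check of the black-list condition used above: a pair is blacklisted when
# removing the gold answers leaves no candidate entity to corrupt with.
candidate_heads = {1: {0, 3}}            # relation -> heads seen with it
gold_heads = {(1, 2): {0, 3},            # both candidates are gold -> blacklisted
              (1, 5): {0}}               # candidate 3 remains       -> kept
black_set = set()
for (r, t), golds in gold_heads.items():
    if len(candidate_heads[r] - golds) == 0:
        black_set.add((-r, t))           # negated relation marks a head-side pair
print(black_set)  # {(-1, 2)}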