Example #1
def main():
    """
    Main prediction entry point for the WikiKG90M dataset.
    """
    args = ArgParser().parse_args()
    config = load_model_config(
        os.path.join(args.model_path, 'model_config.json'))
    args = use_config_replace_args(args, config)
    dataset = get_dataset(args, args.data_path, args.dataset, args.format,
                          args.delimiter, args.data_files,
                          args.has_edge_importance)
    print("Load the dataset done.")
    eval_dataset = EvalDataset(dataset, args)

    model = BaseKEModel(
        args=args,
        n_entities=dataset.n_entities,
        n_relations=dataset.n_relations,
        model_name=args.model_name,
        hidden_size=args.hidden_dim,
        entity_feat_dim=dataset.entity_feat.shape[1],
        relation_feat_dim=dataset.relation_feat.shape[1],
        gamma=args.gamma,
        double_entity_emb=args.double_ent,
        cpu_emb=args.cpu_emb,
        relation_times=args.ote_size,
        scale_type=args.scale_type)

    print("Create the model done.")
    model.entity_feat = dataset.entity_feat
    model.relation_feat = dataset.relation_feat
    load_model_from_checkpoint(model, args.model_path)
    print("The model load the checkpoint done.")

    if args.infer_valid:
        valid_sampler_tail = eval_dataset.create_sampler(
            'valid',
            args.batch_size_eval,
            mode='tail',
            num_workers=args.num_workers,
            rank=0,
            ranks=1)
        infer(args, model, config, 0, [valid_sampler_tail], "valid")

    if args.infer_test:
        test_sampler_tail = eval_dataset.create_sampler(
            'test',
            args.batch_size_eval,
            mode='tail',
            num_workers=args.num_workers,
            rank=0,
            ranks=1)
        infer(args, model, config, 0, [test_sampler_tail], "test")
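For reference, a minimal sketch of what the config-loading helpers used above might do. load_model_config and use_config_replace_args are assumed here to read a saved JSON config and overwrite matching argparse fields; the real project code is not shown in this example, so treat these as hypothetical stand-ins.

import json


def load_model_config(config_path):
    # Hypothetical helper: read the saved model configuration as a plain dict.
    with open(config_path, "r") as f:
        return json.load(f)


def use_config_replace_args(args, config):
    # Hypothetical helper: overwrite argparse fields with values from the
    # saved config so inference matches the training-time setup.
    for key, value in config.items():
        if hasattr(args, key):
            setattr(args, key, value)
    return args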
Example #2
File: train.py  Project: ztyskyearth/dgl
def run(args, logger):
    init_time_start = time.time()
    # load dataset and samplers
    dataset = get_dataset(args.data_path, args.dataset, args.format,
                          args.data_files)

    if args.neg_sample_size_eval < 0:
        args.neg_sample_size_eval = dataset.n_entities
    args.batch_size = get_compatible_batch_size(args.batch_size,
                                                args.neg_sample_size)
    args.batch_size_eval = get_compatible_batch_size(args.batch_size_eval,
                                                     args.neg_sample_size_eval)

    args.eval_filter = not args.no_eval_filter
    if args.neg_deg_sample_eval:
        assert not args.eval_filter, "if negative sampling based on degree, we can't filter positive edges."

    train_data = TrainDataset(dataset, args, ranks=args.num_proc)
    # if there is no cross-partition relation, we fall back to strict_rel_part
    args.strict_rel_part = args.mix_cpu_gpu and not train_data.cross_part
    args.soft_rel_part = args.mix_cpu_gpu and args.soft_rel_part and train_data.cross_part
    args.num_workers = 8  # fix num_worker to 8

    if args.num_proc > 1:
        train_samplers = []
        for i in range(args.num_proc):
            train_sampler_head = train_data.create_sampler(
                args.batch_size,
                args.neg_sample_size,
                args.neg_sample_size,
                mode='head',
                num_workers=args.num_workers,
                shuffle=True,
                exclude_positive=False,
                rank=i)
            train_sampler_tail = train_data.create_sampler(
                args.batch_size,
                args.neg_sample_size,
                args.neg_sample_size,
                mode='tail',
                num_workers=args.num_workers,
                shuffle=True,
                exclude_positive=False,
                rank=i)
            train_samplers.append(
                NewBidirectionalOneShotIterator(train_sampler_head,
                                                train_sampler_tail,
                                                args.neg_sample_size,
                                                args.neg_sample_size, True,
                                                dataset.n_entities))

    else:  # This is used for debug
        train_sampler_head = train_data.create_sampler(
            args.batch_size,
            args.neg_sample_size,
            args.neg_sample_size,
            mode='head',
            num_workers=args.num_workers,
            shuffle=True,
            exclude_positive=False)
        train_sampler_tail = train_data.create_sampler(
            args.batch_size,
            args.neg_sample_size,
            args.neg_sample_size,
            mode='tail',
            num_workers=args.num_workers,
            shuffle=True,
            exclude_positive=False)
        train_sampler = NewBidirectionalOneShotIterator(
            train_sampler_head, train_sampler_tail, args.neg_sample_size,
            args.neg_sample_size, True, dataset.n_entities)

    if args.valid or args.test:
        if len(args.gpu) > 1:
            args.num_test_proc = args.num_proc if args.num_proc < len(
                args.gpu) else len(args.gpu)
        else:
            args.num_test_proc = args.num_proc
        eval_dataset = EvalDataset(dataset, args)

    if args.valid:
        if args.num_proc > 1:
            valid_sampler_heads = []
            valid_sampler_tails = []
            for i in range(args.num_proc):
                valid_sampler_head = eval_dataset.create_sampler(
                    'valid',
                    args.batch_size_eval,
                    args.neg_sample_size_eval,
                    args.neg_sample_size_eval,
                    args.eval_filter,
                    mode='chunk-head',
                    num_workers=args.num_workers,
                    rank=i,
                    ranks=args.num_proc)
                valid_sampler_tail = eval_dataset.create_sampler(
                    'valid',
                    args.batch_size_eval,
                    args.neg_sample_size_eval,
                    args.neg_sample_size_eval,
                    args.eval_filter,
                    mode='chunk-tail',
                    num_workers=args.num_workers,
                    rank=i,
                    ranks=args.num_proc)
                valid_sampler_heads.append(valid_sampler_head)
                valid_sampler_tails.append(valid_sampler_tail)
        else:  # This is used for debug
            valid_sampler_head = eval_dataset.create_sampler(
                'valid',
                args.batch_size_eval,
                args.neg_sample_size_eval,
                args.neg_sample_size_eval,
                args.eval_filter,
                mode='chunk-head',
                num_workers=args.num_workers,
                rank=0,
                ranks=1)
            valid_sampler_tail = eval_dataset.create_sampler(
                'valid',
                args.batch_size_eval,
                args.neg_sample_size_eval,
                args.neg_sample_size_eval,
                args.eval_filter,
                mode='chunk-tail',
                num_workers=args.num_workers,
                rank=0,
                ranks=1)
    if args.test:
        if args.num_test_proc > 1:
            test_sampler_tails = []
            test_sampler_heads = []
            for i in range(args.num_test_proc):
                test_sampler_head = eval_dataset.create_sampler(
                    'test',
                    args.batch_size_eval,
                    args.neg_sample_size_eval,
                    args.neg_sample_size_eval,
                    args.eval_filter,
                    mode='chunk-head',
                    num_workers=args.num_workers,
                    rank=i,
                    ranks=args.num_test_proc)
                test_sampler_tail = eval_dataset.create_sampler(
                    'test',
                    args.batch_size_eval,
                    args.neg_sample_size_eval,
                    args.neg_sample_size_eval,
                    args.eval_filter,
                    mode='chunk-tail',
                    num_workers=args.num_workers,
                    rank=i,
                    ranks=args.num_test_proc)
                test_sampler_heads.append(test_sampler_head)
                test_sampler_tails.append(test_sampler_tail)
        else:
            test_sampler_head = eval_dataset.create_sampler(
                'test',
                args.batch_size_eval,
                args.neg_sample_size_eval,
                args.neg_sample_size_eval,
                args.eval_filter,
                mode='chunk-head',
                num_workers=args.num_workers,
                rank=0,
                ranks=1)
            test_sampler_tail = eval_dataset.create_sampler(
                'test',
                args.batch_size_eval,
                args.neg_sample_size_eval,
                args.neg_sample_size_eval,
                args.eval_filter,
                mode='chunk-tail',
                num_workers=args.num_workers,
                rank=0,
                ranks=1)

    # load model
    model = load_model(logger, args, dataset.n_entities, dataset.n_relations)
    if args.num_proc > 1 or args.async_update:
        model.share_memory()

    # We need to free all memory referenced by dataset.
    eval_dataset = None
    dataset = None

    print('Total initialize time {:.3f} seconds'.format(time.time() -
                                                        init_time_start))

    # train
    start = time.time()
    rel_parts = train_data.rel_parts if args.strict_rel_part or args.soft_rel_part else None
    cross_rels = train_data.cross_rels if args.soft_rel_part else None
    if args.num_proc > 1:
        procs = []
        barrier = mp.Barrier(args.num_proc)
        for i in range(args.num_proc):
            valid_sampler = [valid_sampler_heads[i], valid_sampler_tails[i]
                             ] if args.valid else None
            proc = mp.Process(target=train_mp,
                              args=(args, model, train_samplers[i],
                                    valid_sampler, i, rel_parts, cross_rels,
                                    barrier))
            procs.append(proc)
            proc.start()
        for proc in procs:
            proc.join()
    else:
        valid_samplers = [valid_sampler_head, valid_sampler_tail
                          ] if args.valid else None
        train(args, model, train_sampler, valid_samplers, rel_parts=rel_parts)

    print('training takes {} seconds'.format(time.time() - start))

    if args.save_emb is not None:
        if not os.path.exists(args.save_emb):
            os.mkdir(args.save_emb)
        model.save_emb(args.save_emb, args.dataset)

        # We need to save the model configurations as well.
        conf_file = os.path.join(args.save_emb, 'config.json')
        with open(conf_file, 'w') as outfile:
            json.dump(
                {
                    'dataset': args.dataset,
                    'model': args.model_name,
                    'emb_size': args.hidden_dim,
                    'max_train_step': args.max_step,
                    'batch_size': args.batch_size,
                    'neg_sample_size': args.neg_sample_size,
                    'lr': args.lr,
                    'gamma': args.gamma,
                    'double_ent': args.double_ent,
                    'double_rel': args.double_rel,
                    'neg_adversarial_sampling': args.neg_adversarial_sampling,
                    'adversarial_temperature': args.adversarial_temperature,
                    'regularization_coef': args.regularization_coef,
                    'regularization_norm': args.regularization_norm
                },
                outfile,
                indent=4)

    # test
    if args.test:
        start = time.time()
        if args.num_test_proc > 1:
            queue = mp.Queue(args.num_test_proc)
            procs = []
            for i in range(args.num_test_proc):
                proc = mp.Process(target=test_mp,
                                  args=(args, model, [
                                      test_sampler_heads[i],
                                      test_sampler_tails[i]
                                  ], i, 'Test', queue))
                procs.append(proc)
                proc.start()

            total_metrics = {}
            metrics = {}
            logs = []
            for i in range(args.num_test_proc):
                log = queue.get()
                logs = logs + log

            for metric in logs[0].keys():
                metrics[metric] = sum([log[metric]
                                       for log in logs]) / len(logs)
            for k, v in metrics.items():
                print('Test average {} : {}'.format(k, v))

            for proc in procs:
                proc.join()
        else:
            test(args, model, [test_sampler_head, test_sampler_tail])
        print('testing takes {:.3f} seconds'.format(time.time() - start))
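The example above adjusts batch_size and batch_size_eval with get_compatible_batch_size. A plausible implementation is sketched below, assuming (as in DGL-KE) that it rounds the batch size up to a multiple of the negative sample size so every chunk of positive edges can share one block of negatives; the exact upstream behavior may differ.

import math


def get_compatible_batch_size(batch_size, neg_sample_size):
    # Assumed behavior: if the negative sample size does not divide the batch
    # size, round the batch size up to the nearest multiple.
    if neg_sample_size < batch_size and batch_size % neg_sample_size != 0:
        old_batch_size = batch_size
        batch_size = int(math.ceil(batch_size / neg_sample_size) * neg_sample_size)
        print('batch size ({}) is incompatible with the negative sample size ({}); '
              'changing the batch size to {}'.format(old_batch_size, neg_sample_size, batch_size))
    return batch_size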
Example #3
def main(args):
    # load dataset and samplers
    dataset = get_dataset(args.data_path, args.dataset, args.format)
    args.pickle_graph = False
    args.train = False
    args.valid = False
    args.test = True
    args.batch_size_eval = args.batch_size

    logger = get_logger(args)
    # Here we want to use the regular negative sampler because we need to ensure that
    # all positive edges are excluded.
    eval_dataset = EvalDataset(dataset, args)
    args.neg_sample_size_test = args.neg_sample_size
    if args.neg_sample_size < 0:
        args.neg_sample_size_test = args.neg_sample_size = eval_dataset.g.number_of_nodes()
    if args.num_proc > 1:
        test_sampler_tails = []
        test_sampler_heads = []
        for i in range(args.num_proc):
            test_sampler_head = eval_dataset.create_sampler(
                'test',
                args.batch_size,
                args.neg_sample_size,
                mode='PBG-head',
                num_workers=args.num_worker,
                rank=i,
                ranks=args.num_proc)
            test_sampler_tail = eval_dataset.create_sampler(
                'test',
                args.batch_size,
                args.neg_sample_size,
                mode='PBG-tail',
                num_workers=args.num_worker,
                rank=i,
                ranks=args.num_proc)
            test_sampler_heads.append(test_sampler_head)
            test_sampler_tails.append(test_sampler_tail)
    else:
        test_sampler_head = eval_dataset.create_sampler(
            'test',
            args.batch_size,
            args.neg_sample_size,
            mode='PBG-head',
            num_workers=args.num_worker,
            rank=0,
            ranks=1)
        test_sampler_tail = eval_dataset.create_sampler(
            'test',
            args.batch_size,
            args.neg_sample_size,
            mode='PBG-tail',
            num_workers=args.num_worker,
            rank=0,
            ranks=1)

    # load model
    n_entities = dataset.n_entities
    n_relations = dataset.n_relations
    ckpt_path = args.model_path
    model = load_model_from_checkpoint(logger, args, n_entities, n_relations,
                                       ckpt_path)

    if args.num_proc > 1:
        model.share_memory()
    # test
    args.step = 0
    args.max_step = 0
    if args.num_proc > 1:
        procs = []
        for i in range(args.num_proc):
            proc = mp.Process(target=test,
                              args=(args, model, [
                                  test_sampler_heads[i], test_sampler_tails[i]
                              ]))
            procs.append(proc)
            proc.start()
        for proc in procs:
            proc.join()
    else:
        test(args, model, [test_sampler_head, test_sampler_tail])
Example #4
def main(args):
    args.eval_filter = not args.no_eval_filter
    if args.neg_deg_sample:
        assert not args.eval_filter, "if negative sampling based on degree, we can't filter positive edges."

    # load dataset and samplers
    dataset = get_dataset(args.data_path, args.dataset, args.format)
    args.pickle_graph = False
    args.train = False
    args.valid = False
    args.test = True
    args.batch_size_eval = args.batch_size

    logger = get_logger(args)
    # Here we want to use the regular negative sampler because we need to ensure that
    # all positive edges are excluded.
    eval_dataset = EvalDataset(dataset, args)

    args.neg_sample_size_test = args.neg_sample_size
    args.neg_deg_sample_eval = args.neg_deg_sample
    if args.neg_sample_size < 0:
        args.neg_sample_size_test = args.neg_sample_size = eval_dataset.g.number_of_nodes()
    if args.neg_chunk_size < 0:
        args.neg_chunk_size = args.neg_sample_size

    num_workers = args.num_worker
    # for multiprocessing evaluation, we don't need to sample multiple batches at a time
    # in each process.
    if args.num_proc > 1:
        num_workers = 1
    if args.num_proc > 1:
        test_sampler_tails = []
        test_sampler_heads = []
        for i in range(args.num_proc):
            test_sampler_head = eval_dataset.create_sampler(
                'test',
                args.batch_size,
                args.neg_sample_size,
                args.neg_chunk_size,
                args.eval_filter,
                mode='chunk-head',
                num_workers=num_workers,
                rank=i,
                ranks=args.num_proc)
            test_sampler_tail = eval_dataset.create_sampler(
                'test',
                args.batch_size,
                args.neg_sample_size,
                args.neg_chunk_size,
                args.eval_filter,
                mode='chunk-tail',
                num_workers=num_workers,
                rank=i,
                ranks=args.num_proc)
            test_sampler_heads.append(test_sampler_head)
            test_sampler_tails.append(test_sampler_tail)
    else:
        test_sampler_head = eval_dataset.create_sampler(
            'test',
            args.batch_size,
            args.neg_sample_size,
            args.neg_chunk_size,
            args.eval_filter,
            mode='chunk-head',
            num_workers=num_workers,
            rank=0,
            ranks=1)
        test_sampler_tail = eval_dataset.create_sampler(
            'test',
            args.batch_size,
            args.neg_sample_size,
            args.neg_chunk_size,
            args.eval_filter,
            mode='chunk-tail',
            num_workers=num_workers,
            rank=0,
            ranks=1)

    # load model
    n_entities = dataset.n_entities
    n_relations = dataset.n_relations
    ckpt_path = args.model_path
    model = load_model_from_checkpoint(logger, args, n_entities, n_relations,
                                       ckpt_path)

    if args.num_proc > 1:
        model.share_memory()
    # test
    args.step = 0
    args.max_step = 0
    start = time.time()
    if args.num_proc > 1:
        queue = mp.Queue(args.num_proc)
        procs = []
        for i in range(args.num_proc):
            proc = mp.Process(target=test,
                              args=(args, model, [
                                  test_sampler_heads[i], test_sampler_tails[i]
                              ], 'Test', queue))
            procs.append(proc)
            proc.start()
        for proc in procs:
            proc.join()

        total_metrics = {}
        for i in range(args.num_proc):
            metrics = queue.get()
            for k, v in metrics.items():
                if i == 0:
                    total_metrics[k] = v / args.num_proc
                else:
                    total_metrics[k] += v / args.num_proc
        for k, v in total_metrics.items():
            print('Test average {} at [{}/{}]: {}'.format(
                k, args.step, args.max_step, v))
    else:
        test(args, model, [test_sampler_head, test_sampler_tail])
    print('Test takes {:.3f} seconds'.format(time.time() - start))
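The multi-process branch above averages per-process metrics that workers push onto a multiprocessing queue. A small standalone sketch of that aggregation pattern, with a dummy worker standing in for the real test() call:

import multiprocessing as mp


def fake_eval_worker(rank, queue):
    # Stand-in for test(): each process reports its own averaged metrics.
    queue.put({'MRR': 0.30 + 0.01 * rank, 'HITS@10': 0.45 + 0.01 * rank})


if __name__ == '__main__':
    num_proc = 4
    queue = mp.Queue(num_proc)
    procs = [mp.Process(target=fake_eval_worker, args=(i, queue)) for i in range(num_proc)]
    for p in procs:
        p.start()

    # Drain the queue before joining, then average each metric over all processes.
    total_metrics = {}
    for _ in range(num_proc):
        metrics = queue.get()
        for k, v in metrics.items():
            total_metrics[k] = total_metrics.get(k, 0.0) + v / num_proc
    for p in procs:
        p.join()
    for k, v in total_metrics.items():
        print('Test average {} : {}'.format(k, v))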
Example #5
def run(args, logger):
    # load dataset and samplers
    dataset = get_dataset(args.data_path, args.dataset, args.format)
    n_entities = dataset.n_entities
    n_relations = dataset.n_relations
    if args.neg_sample_size_test < 0:
        args.neg_sample_size_test = n_entities

    train_data = TrainDataset(dataset, args, ranks=args.num_proc)
    if args.num_proc > 1:
        train_samplers = []
        for i in range(args.num_proc):
            train_sampler_head = train_data.create_sampler(
                args.batch_size,
                args.neg_sample_size,
                mode='PBG-head',
                num_workers=args.num_worker,
                shuffle=True,
                exclude_positive=True,
                rank=i)
            train_sampler_tail = train_data.create_sampler(
                args.batch_size,
                args.neg_sample_size,
                mode='PBG-tail',
                num_workers=args.num_worker,
                shuffle=True,
                exclude_positive=True,
                rank=i)
            train_samplers.append(
                NewBidirectionalOneShotIterator(train_sampler_head,
                                                train_sampler_tail, True,
                                                n_entities))
    else:
        train_sampler_head = train_data.create_sampler(
            args.batch_size,
            args.neg_sample_size,
            mode='PBG-head',
            num_workers=args.num_worker,
            shuffle=True,
            exclude_positive=True)
        train_sampler_tail = train_data.create_sampler(
            args.batch_size,
            args.neg_sample_size,
            mode='PBG-tail',
            num_workers=args.num_worker,
            shuffle=True,
            exclude_positive=True)
        train_sampler = NewBidirectionalOneShotIterator(
            train_sampler_head, train_sampler_tail, True, n_entities)

    if args.valid or args.test:
        eval_dataset = EvalDataset(dataset, args)
    if args.valid:
        # Here we want to use the regular negative sampler because we need to ensure that
        # all positive edges are excluded.
        if args.num_proc > 1:
            valid_sampler_heads = []
            valid_sampler_tails = []
            for i in range(args.num_proc):
                valid_sampler_head = eval_dataset.create_sampler(
                    'valid',
                    args.batch_size_eval,
                    args.neg_sample_size_valid,
                    mode='PBG-head',
                    num_workers=args.num_worker,
                    rank=i,
                    ranks=args.num_proc)
                valid_sampler_tail = eval_dataset.create_sampler(
                    'valid',
                    args.batch_size_eval,
                    args.neg_sample_size_valid,
                    mode='PBG-tail',
                    num_workers=args.num_worker,
                    rank=i,
                    ranks=args.num_proc)
                valid_sampler_heads.append(valid_sampler_head)
                valid_sampler_tails.append(valid_sampler_tail)
        else:
            valid_sampler_head = eval_dataset.create_sampler(
                'valid',
                args.batch_size_eval,
                args.neg_sample_size_valid,
                mode='PBG-head',
                num_workers=args.num_worker,
                rank=0,
                ranks=1)
            valid_sampler_tail = eval_dataset.create_sampler(
                'valid',
                args.batch_size_eval,
                args.neg_sample_size_valid,
                mode='PBG-tail',
                num_workers=args.num_worker,
                rank=0,
                ranks=1)
    if args.test:
        # Here we want to use the regular negative sampler because we need to ensure that
        # all positive edges are excluded.
        if args.num_proc > 1:
            test_sampler_tails = []
            test_sampler_heads = []
            for i in range(args.num_proc):
                test_sampler_head = eval_dataset.create_sampler(
                    'test',
                    args.batch_size_eval,
                    args.neg_sample_size_test,
                    mode='PBG-head',
                    num_workers=args.num_worker,
                    rank=i,
                    ranks=args.num_proc)
                test_sampler_tail = eval_dataset.create_sampler(
                    'test',
                    args.batch_size_eval,
                    args.neg_sample_size_test,
                    mode='PBG-tail',
                    num_workers=args.num_worker,
                    rank=i,
                    ranks=args.num_proc)
                test_sampler_heads.append(test_sampler_head)
                test_sampler_tails.append(test_sampler_tail)
        else:
            test_sampler_head = eval_dataset.create_sampler(
                'test',
                args.batch_size_eval,
                args.neg_sample_size_test,
                mode='PBG-head',
                num_workers=args.num_worker,
                rank=0,
                ranks=1)
            test_sampler_tail = eval_dataset.create_sampler(
                'test',
                args.batch_size_eval,
                args.neg_sample_size_test,
                mode='PBG-tail',
                num_workers=args.num_worker,
                rank=0,
                ranks=1)

    # We need to free all memory referenced by dataset.
    eval_dataset = None
    dataset = None
    # load model
    model = load_model(logger, args, n_entities, n_relations)

    if args.num_proc > 1:
        model.share_memory()

    # train
    start = time.time()
    if args.num_proc > 1:
        procs = []
        for i in range(args.num_proc):
            valid_samplers = [valid_sampler_heads[i], valid_sampler_tails[i]
                              ] if args.valid else None
            proc = mp.Process(target=train,
                              args=(args, model, train_samplers[i],
                                    valid_samplers))
            procs.append(proc)
            proc.start()
        for proc in procs:
            proc.join()
    else:
        valid_samplers = [valid_sampler_head, valid_sampler_tail
                          ] if args.valid else None
        train(args, model, train_sampler, valid_samplers)
    print('training takes {} seconds'.format(time.time() - start))

    if args.save_emb is not None:
        if not os.path.exists(args.save_emb):
            os.mkdir(args.save_emb)
        model.save_emb(args.save_emb, args.dataset)

    # test
    if args.test:
        if args.num_proc > 1:
            procs = []
            for i in range(args.num_proc):
                proc = mp.Process(target=test,
                                  args=(args, model, [
                                      test_sampler_heads[i],
                                      test_sampler_tails[i]
                                  ]))
                procs.append(proc)
                proc.start()
            for proc in procs:
                proc.join()
        else:
            test(args, model, [test_sampler_head, test_sampler_tail])
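NewBidirectionalOneShotIterator is used above to merge a head-corrupting and a tail-corrupting sampler into one endless training stream. The sketch below assumes it simply alternates batches from the two samplers, restarting each when exhausted, which is the usual pattern in KGE training loops; it is not verified against the project's class, so the name and interface here are illustrative only.

class AlternatingIterator:
    # Hypothetical stand-in: yield one head-mode batch, then one tail-mode
    # batch, restarting each underlying sampler when it runs out.
    def __init__(self, head_sampler, tail_sampler):
        self.iters = [self._endless(head_sampler), self._endless(tail_sampler)]
        self.step = 0

    @staticmethod
    def _endless(sampler):
        while True:
            for batch in sampler:
                yield batch

    def __next__(self):
        it = self.iters[self.step % 2]
        self.step += 1
        return next(it)


# Usage sketch: alternate head/tail corruption batches indefinitely.
# train_iter = AlternatingIterator(train_sampler_head, train_sampler_tail)
# batch = next(train_iter)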
Example #6
File: train.py  Project: prempv/dgl_kge
def run(args, logger):
    # load dataset and samplers
    dataset = get_dataset(args.data_path, args.dataset, args.format)
    n_entities = dataset.n_entities
    n_relations = dataset.n_relations
    if args.neg_sample_size_test < 0:
        args.neg_sample_size_test = n_entities
    args.eval_filter = not args.no_eval_filter
    if args.neg_deg_sample_eval:
        assert not args.eval_filter, "if negative sampling based on degree, we can't filter positive edges."

    # When we generate a batch of negative edges from a set of positive edges,
    # we first divide the positive edges into chunks and corrupt the edges in a chunk
    # together. By default, the chunk size is equal to the negative sample size.
    # Usually, this works well. But we also allow users to specify the chunk size themselves.
    if args.neg_chunk_size < 0:
        args.neg_chunk_size = args.neg_sample_size
    if args.neg_chunk_size_valid < 0:
        args.neg_chunk_size_valid = args.neg_sample_size_valid
    if args.neg_chunk_size_test < 0:
        args.neg_chunk_size_test = args.neg_sample_size_test

    train_data = TrainDataset(dataset, args, ranks=args.num_proc)
    if args.num_proc > 1:
        train_samplers = []
        for i in range(args.num_proc):
            train_sampler_head = train_data.create_sampler(
                args.batch_size,
                args.neg_sample_size,
                args.neg_chunk_size,
                mode='chunk-head',
                num_workers=args.num_worker,
                shuffle=True,
                exclude_positive=True,
                rank=i)
            train_sampler_tail = train_data.create_sampler(
                args.batch_size,
                args.neg_sample_size,
                args.neg_chunk_size,
                mode='chunk-tail',
                num_workers=args.num_worker,
                shuffle=True,
                exclude_positive=True,
                rank=i)
            train_samplers.append(
                NewBidirectionalOneShotIterator(train_sampler_head,
                                                train_sampler_tail,
                                                args.neg_chunk_size, True,
                                                n_entities))
    else:
        train_sampler_head = train_data.create_sampler(
            args.batch_size,
            args.neg_sample_size,
            args.neg_chunk_size,
            mode='chunk-head',
            num_workers=args.num_worker,
            shuffle=True,
            exclude_positive=True)
        train_sampler_tail = train_data.create_sampler(
            args.batch_size,
            args.neg_sample_size,
            args.neg_chunk_size,
            mode='chunk-tail',
            num_workers=args.num_worker,
            shuffle=True,
            exclude_positive=True)
        train_sampler = NewBidirectionalOneShotIterator(
            train_sampler_head, train_sampler_tail, args.neg_chunk_size, True,
            n_entities)

    # for multiprocessing evaluation, we don't need to sample multiple batches at a time
    # in each process.
    num_workers = args.num_worker
    if args.num_proc > 1:
        num_workers = 1
    if args.valid or args.test:
        eval_dataset = EvalDataset(dataset, args)
    if args.valid:
        # Here we want to use the regular negative sampler because we need to ensure that
        # all positive edges are excluded.
        if args.num_proc > 1:
            valid_sampler_heads = []
            valid_sampler_tails = []
            for i in range(args.num_proc):
                valid_sampler_head = eval_dataset.create_sampler(
                    'valid',
                    args.batch_size_eval,
                    args.neg_sample_size_valid,
                    args.neg_chunk_size_valid,
                    args.eval_filter,
                    mode='chunk-head',
                    num_workers=num_workers,
                    rank=i,
                    ranks=args.num_proc)
                valid_sampler_tail = eval_dataset.create_sampler(
                    'valid',
                    args.batch_size_eval,
                    args.neg_sample_size_valid,
                    args.neg_chunk_size_valid,
                    args.eval_filter,
                    mode='chunk-tail',
                    num_workers=num_workers,
                    rank=i,
                    ranks=args.num_proc)
                valid_sampler_heads.append(valid_sampler_head)
                valid_sampler_tails.append(valid_sampler_tail)
        else:
            valid_sampler_head = eval_dataset.create_sampler(
                'valid',
                args.batch_size_eval,
                args.neg_sample_size_valid,
                args.neg_chunk_size_valid,
                args.eval_filter,
                mode='chunk-head',
                num_workers=num_workers,
                rank=0,
                ranks=1)
            valid_sampler_tail = eval_dataset.create_sampler(
                'valid',
                args.batch_size_eval,
                args.neg_sample_size_valid,
                args.neg_chunk_size_valid,
                args.eval_filter,
                mode='chunk-tail',
                num_workers=num_workers,
                rank=0,
                ranks=1)
    if args.test:
        # Here we want to use the regular negative sampler because we need to ensure that
        # all positive edges are excluded.
        if args.num_proc > 1:
            test_sampler_tails = []
            test_sampler_heads = []
            for i in range(args.num_proc):
                test_sampler_head = eval_dataset.create_sampler(
                    'test',
                    args.batch_size_eval,
                    args.neg_sample_size_test,
                    args.neg_chunk_size_test,
                    args.eval_filter,
                    mode='chunk-head',
                    num_workers=num_workers,
                    rank=i,
                    ranks=args.num_proc)
                test_sampler_tail = eval_dataset.create_sampler(
                    'test',
                    args.batch_size_eval,
                    args.neg_sample_size_test,
                    args.neg_chunk_size_test,
                    args.eval_filter,
                    mode='chunk-tail',
                    num_workers=num_workers,
                    rank=i,
                    ranks=args.num_proc)
                test_sampler_heads.append(test_sampler_head)
                test_sampler_tails.append(test_sampler_tail)
        else:
            test_sampler_head = eval_dataset.create_sampler(
                'test',
                args.batch_size_eval,
                args.neg_sample_size_test,
                args.neg_chunk_size_test,
                args.eval_filter,
                mode='chunk-head',
                num_workers=num_workers,
                rank=0,
                ranks=1)
            test_sampler_tail = eval_dataset.create_sampler(
                'test',
                args.batch_size_eval,
                args.neg_sample_size_test,
                args.neg_chunk_size_test,
                args.eval_filter,
                mode='chunk-tail',
                num_workers=num_workers,
                rank=0,
                ranks=1)

    # We need to free all memory referenced by dataset.
    eval_dataset = None
    dataset = None
    # load model
    model = load_model(logger, args, n_entities, n_relations)

    if args.num_proc > 1:
        model.share_memory()

    # train
    start = time.time()
    if args.num_proc > 1:
        procs = []
        for i in range(args.num_proc):
            rel_parts = train_data.rel_parts if args.rel_part else None
            valid_samplers = [valid_sampler_heads[i], valid_sampler_tails[i]
                              ] if args.valid else None
            proc = mp.Process(target=train,
                              args=(args, model, train_samplers[i], i,
                                    rel_parts, valid_samplers))
            procs.append(proc)
            proc.start()
        for proc in procs:
            proc.join()
    else:
        valid_samplers = [valid_sampler_head, valid_sampler_tail
                          ] if args.valid else None
        train(args, model, train_sampler, valid_samplers)
    print('training takes {} seconds'.format(time.time() - start))

    if args.save_emb is not None:
        if not os.path.exists(args.save_emb):
            os.mkdir(args.save_emb)
        model.save_emb(args.save_emb, args.dataset)

    # test
    if args.test:
        start = time.time()
        if args.num_proc > 1:
            queue = mp.Queue(args.num_proc)
            procs = []
            for i in range(args.num_proc):
                proc = mp.Process(target=test,
                                  args=(args, model, [
                                      test_sampler_heads[i],
                                      test_sampler_tails[i]
                                  ], i, 'Test', queue))
                procs.append(proc)
                proc.start()

            total_metrics = {}
            for i in range(args.num_proc):
                metrics = queue.get()
                for k, v in metrics.items():
                    if i == 0:
                        total_metrics[k] = v / args.num_proc
                    else:
                        total_metrics[k] += v / args.num_proc
            for k, v in total_metrics.items():
                print('Test average {} at [{}/{}]: {}'.format(
                    k, args.step, args.max_step, v))

            for proc in procs:
                proc.join()
        else:
            test(args, model, [test_sampler_head, test_sampler_tail])
        print('test:', time.time() - start)
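The comment near the top of the example above describes chunk-based negative sampling: the positive edges in a batch are split into chunks, and every edge in a chunk reuses the same block of corrupted entities. A small NumPy sketch of that idea; the function name, shapes, and entity count below are illustrative only.

import numpy as np


def chunked_negative_heads(pos_edges, n_entities, neg_sample_size, chunk_size, rng=None):
    # Split the positive edges into chunks; every edge in a chunk shares the
    # same neg_sample_size corrupted head entities, so negatives are sampled
    # once per chunk instead of once per edge.
    if rng is None:
        rng = np.random.default_rng(0)
    num_chunks = (len(pos_edges) + chunk_size - 1) // chunk_size
    return rng.integers(0, n_entities, size=(num_chunks, neg_sample_size))


# e.g. 1024 positive edges with chunk_size 256 -> 4 chunks, each with its own
# shared block of 256 negative heads.
edges = np.zeros((1024, 3), dtype=np.int64)
neg_heads = chunked_negative_heads(edges, n_entities=40943, neg_sample_size=256, chunk_size=256)
print(neg_heads.shape)  # (4, 256)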
Example #7
def dist_train_test(args,
                    model,
                    train_sampler,
                    entity_pb,
                    relation_pb,
                    l2g,
                    rank=0,
                    rel_parts=None,
                    cross_rels=None,
                    barrier=None):
    if args.num_proc > 1:
        th.set_num_threads(args.num_thread)

    client = connect_to_kvstore(args, entity_pb, relation_pb, l2g)
    client.barrier()
    train_time_start = time.time()
    train(args, model, train_sampler, None, rank, rel_parts, cross_rels,
          barrier, client)
    client.barrier()
    print('Total train time {:.3f} seconds'.format(time.time() -
                                                   train_time_start))

    model = None

    if client.get_id() % args.num_client == 0:  # pull full model from kvstore

        args.num_test_proc = args.num_client
        dataset_full = get_dataset(args.data_path, args.dataset, args.format)

        print('Full data n_entities: ' + str(dataset_full.n_entities))
        print("Full data n_relations: " + str(dataset_full.n_relations))

        model_test = load_model(None, args, dataset_full.n_entities,
                                dataset_full.n_relations)
        eval_dataset = EvalDataset(dataset_full, args)

        if args.test:
            model_test.share_memory()

        if args.neg_sample_size_test < 0:
            args.neg_sample_size_test = dataset_full.n_entities
        args.eval_filter = not args.no_eval_filter
        if args.neg_deg_sample_eval:
            assert not args.eval_filter, "if negative sampling based on degree, we can't filter positive edges."

        if args.neg_chunk_size_valid < 0:
            args.neg_chunk_size_valid = args.neg_sample_size_valid
        if args.neg_chunk_size_test < 0:
            args.neg_chunk_size_test = args.neg_sample_size_test

        print("Pull relation_emb ...")
        relation_id = F.arange(0, model_test.n_relations)
        relation_data = client.pull(name='relation_emb', id_tensor=relation_id)
        model_test.relation_emb.emb[relation_id] = relation_data

        print("Pull entity_emb ... ")
        # split model into 100 small parts
        start = 0
        percent = 0
        entity_id = F.arange(0, model_test.n_entities)
        count = int(model_test.n_entities / 100)
        end = start + count
        while True:
            print("Pull %d / 100 ..." % percent)
            if end >= model_test.n_entities:
                # clamp to n_entities so the final slice (and the last entity) is pulled too
                end = model_test.n_entities
            tmp_id = entity_id[start:end]
            entity_data = client.pull(name='entity_emb', id_tensor=tmp_id)
            model_test.entity_emb.emb[tmp_id] = entity_data
            if end == model_test.n_entities:
                break
            start = end
            end += count
            percent += 1

        if args.save_emb is not None:
            if not os.path.exists(args.save_emb):
                os.mkdir(args.save_emb)
            model_test.save_emb(args.save_emb, args.dataset)

        if args.test:
            args.num_thread = 1
            test_sampler_tails = []
            test_sampler_heads = []
            for i in range(args.num_test_proc):
                test_sampler_head = eval_dataset.create_sampler(
                    'test',
                    args.batch_size_eval,
                    args.neg_sample_size_test,
                    args.neg_chunk_size_test,
                    args.eval_filter,
                    mode='chunk-head',
                    num_workers=args.num_thread,
                    rank=i,
                    ranks=args.num_test_proc)
                test_sampler_tail = eval_dataset.create_sampler(
                    'test',
                    args.batch_size_eval,
                    args.neg_sample_size_test,
                    args.neg_chunk_size_test,
                    args.eval_filter,
                    mode='chunk-tail',
                    num_workers=args.num_thread,
                    rank=i,
                    ranks=args.num_test_proc)
                test_sampler_heads.append(test_sampler_head)
                test_sampler_tails.append(test_sampler_tail)

            eval_dataset = None
            dataset_full = None

            print("Run test, test processes: %d" % args.num_test_proc)

            queue = mp.Queue(args.num_test_proc)
            procs = []
            for i in range(args.num_test_proc):
                proc = mp.Process(target=test_mp,
                                  args=(args, model_test, [
                                      test_sampler_heads[i],
                                      test_sampler_tails[i]
                                  ], i, 'Test', queue))
                procs.append(proc)
                proc.start()

            total_metrics = {}
            metrics = {}
            logs = []
            for i in range(args.num_test_proc):
                log = queue.get()
                logs = logs + log

            for metric in logs[0].keys():
                metrics[metric] = sum([log[metric]
                                       for log in logs]) / len(logs)
            for k, v in metrics.items():
                print('Test average {} at [{}/{}]: {}'.format(
                    k, args.step, args.max_step, v))

            for proc in procs:
                proc.join()

        if client.get_id() == 0:
            client.shut_down()
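The entity pull loop in the example above reads the full embedding table from the kvstore in roughly 100 slices to bound peak memory. A generic sketch of that slicing pattern, written against a hypothetical pull(ids) callable rather than the project's kvstore client:

import numpy as np


def pull_in_slices(pull, n_rows, dim, num_parts=100):
    # Fetch a large embedding table slice by slice; `pull` is any callable
    # that maps a block of row ids to their vectors.
    out = np.empty((n_rows, dim), dtype=np.float32)
    step = max(1, n_rows // num_parts)
    for start in range(0, n_rows, step):
        end = min(start + step, n_rows)  # the final partial slice is included
        ids = np.arange(start, end)
        out[ids] = pull(ids)
    return out


# Usage sketch with an in-memory "server" table standing in for the kvstore:
table = np.random.rand(1003, 8).astype(np.float32)
copied = pull_in_slices(lambda ids: table[ids], n_rows=1003, dim=8)
assert np.allclose(copied, table)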
Example #8
File: main.py  Project: WenjinW/PGL
def main():
    args = ArgParser().parse_args()
    prepare_save_path(args)
    args.neg_sample_size_eval = 1000
    set_global_seed(args.seed)

    init_time_start = time.time()
    dataset = get_dataset(args, args.data_path, args.dataset, args.format,
                          args.delimiter, args.data_files,
                          args.has_edge_importance)
    args.batch_size = get_compatible_batch_size(args.batch_size,
                                                args.neg_sample_size)
    args.batch_size_eval = get_compatible_batch_size(args.batch_size_eval,
                                                     args.neg_sample_size_eval)

    #print(args)
    set_logger(args)

    print("To build training dataset")
    t1 = time.time()
    train_data = TrainDataset(dataset,
                              args,
                              has_importance=args.has_edge_importance)
    print("Training dataset built, it takes %s seconds" % (time.time() - t1))
    args.num_workers = 8  # fix num_worker to 8
    print("Building training sampler")
    t1 = time.time()
    train_sampler_head = train_data.create_sampler(
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        neg_sample_size=args.neg_sample_size,
        neg_mode='head')
    train_sampler_tail = train_data.create_sampler(
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        neg_sample_size=args.neg_sample_size,
        neg_mode='tail')
    train_sampler = NewBidirectionalOneShotIterator(train_sampler_head,
                                                    train_sampler_tail)
    print("Training sampler created, it takes %s seconds" % (time.time() - t1))

    if args.valid or args.test:
        if len(args.gpu) > 1:
            args.num_test_proc = args.num_proc if args.num_proc < len(
                args.gpu) else len(args.gpu)
        else:
            args.num_test_proc = args.num_proc
        print("To create eval_dataset")
        t1 = time.time()
        eval_dataset = EvalDataset(dataset, args)
        print("eval_dataset created, it takes %d seconds" % (time.time() - t1))

    if args.valid:
        if args.num_proc > 1:
            valid_samplers = []
            for i in range(args.num_proc):
                print("creating valid sampler for proc %d" % i)
                t1 = time.time()
                valid_sampler_tail = eval_dataset.create_sampler(
                    'valid',
                    args.batch_size_eval,
                    mode='tail',
                    num_workers=args.num_workers,
                    rank=i,
                    ranks=args.num_proc)
                valid_samplers.append(valid_sampler_tail)
                print(
                    "Valid sampler for proc %d created, it takes %s seconds" %
                    (i, time.time() - t1))
        else:
            valid_sampler_tail = eval_dataset.create_sampler(
                'valid',
                args.batch_size_eval,
                mode='tail',
                num_workers=args.num_workers,
                rank=0,
                ranks=1)
            valid_samplers = [valid_sampler_tail]

    for arg in vars(args):
        logging.info('{:20}:{}'.format(arg, getattr(args, arg)))

    print("To create model")
    t1 = time.time()
    model = BaseKEModel(args=args,
                        n_entities=dataset.n_entities,
                        n_relations=dataset.n_relations,
                        model_name=args.model_name,
                        hidden_size=args.hidden_dim,
                        entity_feat_dim=dataset.entity_feat.shape[1],
                        relation_feat_dim=dataset.relation_feat.shape[1],
                        gamma=args.gamma,
                        double_entity_emb=args.double_ent,
                        relation_times=args.ote_size,
                        scale_type=args.scale_type)

    model.entity_feat = dataset.entity_feat
    model.relation_feat = dataset.relation_feat
    print(len(model.parameters()))

    if args.cpu_emb:
        print("using cpu emb\n" * 5)
    else:
        print("using gpu emb\n" * 5)
    optimizer = paddle.optimizer.Adam(learning_rate=args.mlp_lr,
                                      parameters=model.parameters())
    lr_tensor = paddle.to_tensor(args.lr)

    global_step = 0
    tic_train = time.time()
    log = {}
    log["loss"] = 0.0
    log["regularization"] = 0.0
    for step in range(0, args.max_step):
        pos_triples, neg_triples, ids, neg_head = next(train_sampler)
        loss = model.forward(pos_triples, neg_triples, ids, neg_head)

        log["loss"] = loss.numpy()[0]
        if args.regularization_coef > 0.0 and args.regularization_norm > 0:
            coef, nm = args.regularization_coef, args.regularization_norm
            reg = coef * norm(model.entity_embedding.curr_emb, nm)
            log['regularization'] = reg.numpy()[0]
            loss = loss + reg

        loss.backward()
        optimizer.step()
        if args.cpu_emb:
            model.entity_embedding.step(lr_tensor)
        optimizer.clear_grad()
        if (step + 1) % args.log_interval == 0:
            speed = args.log_interval / (time.time() - tic_train)
            logging.info(
                "step: %d, train loss: %.5f, regularization: %.4e, speed: %.2f steps/s"
                % (step, log["loss"], log["regularization"], speed))
            log["loss"] = 0.0
            tic_train = time.time()

        if args.valid and (
                step + 1
        ) % args.eval_interval == 0 and step > 1 and valid_samplers is not None:
            print("Valid begin")
            valid_start = time.time()
            valid_input_dict = test(args,
                                    model,
                                    valid_samplers,
                                    step,
                                    rank=0,
                                    mode='Valid')
            paddle.save(
                valid_input_dict,
                os.path.join(args.save_path, "valid_{}.pkl".format(step)))
            # Save the model for the inference
        if (step + 1) % args.save_step == 0:
            print("The step:{}, save model path:{}".format(
                step + 1, args.save_path))
            model.save_model()
            print("Save model done.")
Example #9
def run(args, logger):
    train_time_start = time.time()
    # load dataset and samplers
    dataset = get_dataset(args.data_path, args.dataset, args.format,
                          args.data_files)
    n_entities = dataset.n_entities
    n_relations = dataset.n_relations
    if args.neg_sample_size_test < 0:
        args.neg_sample_size_test = n_entities
    args.eval_filter = not args.no_eval_filter
    if args.neg_deg_sample_eval:
        assert not args.eval_filter, "if negative sampling based on degree, we can't filter positive edges."

    # When we generate a batch of negative edges from a set of positive edges,
    # we first divide the positive edges into chunks and corrupt the edges in a chunk
    # together. By default, the chunk size is equal to the negative sample size.
    # Usually, this works well. But we also allow users to specify the chunk size themselves.
    if args.neg_chunk_size < 0:
        args.neg_chunk_size = args.neg_sample_size
    if args.neg_chunk_size_valid < 0:
        args.neg_chunk_size_valid = args.neg_sample_size_valid
    if args.neg_chunk_size_test < 0:
        args.neg_chunk_size_test = args.neg_sample_size_test

    num_workers = args.num_worker
    train_data = TrainDataset(dataset, args, ranks=args.num_proc)
    # if there is no cross-partition relation, we fall back to strict_rel_part
    args.strict_rel_part = args.mix_cpu_gpu and not train_data.cross_part
    args.soft_rel_part = args.mix_cpu_gpu and args.soft_rel_part and train_data.cross_part

    # Automatically set number of OMP threads for each process if it is not provided
    # The value for GPU is evaluated in AWS p3.16xlarge
    # The value for CPU is evaluated in AWS x1.32xlarge
    if args.nomp_thread_per_process == -1:
        if len(args.gpu) > 0:
            # GPU training
            args.num_thread = 4
        else:
            # CPU training
            args.num_thread = 1
    else:
        args.num_thread = args.nomp_thread_per_process

    if args.num_proc > 1:
        train_samplers = []
        for i in range(args.num_proc):
            train_sampler_head = train_data.create_sampler(
                args.batch_size,
                args.neg_sample_size,
                args.neg_chunk_size,
                mode='head',
                num_workers=num_workers,
                shuffle=True,
                exclude_positive=False,
                rank=i)
            train_sampler_tail = train_data.create_sampler(
                args.batch_size,
                args.neg_sample_size,
                args.neg_chunk_size,
                mode='tail',
                num_workers=num_workers,
                shuffle=True,
                exclude_positive=False,
                rank=i)
            train_samplers.append(
                NewBidirectionalOneShotIterator(train_sampler_head,
                                                train_sampler_tail,
                                                args.neg_chunk_size,
                                                args.neg_sample_size, True,
                                                n_entities))
    else:
        train_sampler_head = train_data.create_sampler(args.batch_size,
                                                       args.neg_sample_size,
                                                       args.neg_chunk_size,
                                                       mode='head',
                                                       num_workers=num_workers,
                                                       shuffle=True,
                                                       exclude_positive=False)
        train_sampler_tail = train_data.create_sampler(args.batch_size,
                                                       args.neg_sample_size,
                                                       args.neg_chunk_size,
                                                       mode='tail',
                                                       num_workers=num_workers,
                                                       shuffle=True,
                                                       exclude_positive=False)
        train_sampler = NewBidirectionalOneShotIterator(
            train_sampler_head, train_sampler_tail, args.neg_chunk_size,
            args.neg_sample_size, True, n_entities)

    # for multiprocessing evaluation, we don't need to sample multiple batches at a time
    # in each process.
    if args.num_proc > 1:
        num_workers = 1
    if args.valid or args.test:
        if len(args.gpu) > 1:
            args.num_test_proc = args.num_proc if args.num_proc < len(
                args.gpu) else len(args.gpu)
        else:
            args.num_test_proc = args.num_proc
        eval_dataset = EvalDataset(dataset, args)
    if args.valid:
        # Here we want to use the regular negative sampler because we need to ensure that
        # all positive edges are excluded.
        if args.num_proc > 1:
            valid_sampler_heads = []
            valid_sampler_tails = []
            for i in range(args.num_proc):
                valid_sampler_head = eval_dataset.create_sampler(
                    'valid',
                    args.batch_size_eval,
                    args.neg_sample_size_valid,
                    args.neg_chunk_size_valid,
                    args.eval_filter,
                    mode='chunk-head',
                    num_workers=num_workers,
                    rank=i,
                    ranks=args.num_proc)
                valid_sampler_tail = eval_dataset.create_sampler(
                    'valid',
                    args.batch_size_eval,
                    args.neg_sample_size_valid,
                    args.neg_chunk_size_valid,
                    args.eval_filter,
                    mode='chunk-tail',
                    num_workers=num_workers,
                    rank=i,
                    ranks=args.num_proc)
                valid_sampler_heads.append(valid_sampler_head)
                valid_sampler_tails.append(valid_sampler_tail)
        else:
            valid_sampler_head = eval_dataset.create_sampler(
                'valid',
                args.batch_size_eval,
                args.neg_sample_size_valid,
                args.neg_chunk_size_valid,
                args.eval_filter,
                mode='chunk-head',
                num_workers=num_workers,
                rank=0,
                ranks=1)
            valid_sampler_tail = eval_dataset.create_sampler(
                'valid',
                args.batch_size_eval,
                args.neg_sample_size_valid,
                args.neg_chunk_size_valid,
                args.eval_filter,
                mode='chunk-tail',
                num_workers=num_workers,
                rank=0,
                ranks=1)
    if args.test:
        # Here we want to use the regular negative sampler because we need to ensure that
        # all positive edges are excluded.
        # We use at most num_gpu processes in the test stage to save GPU memory.
        if args.num_test_proc > 1:
            test_sampler_tails = []
            test_sampler_heads = []
            for i in range(args.num_test_proc):
                test_sampler_head = eval_dataset.create_sampler(
                    'test',
                    args.batch_size_eval,
                    args.neg_sample_size_test,
                    args.neg_chunk_size_test,
                    args.eval_filter,
                    mode='chunk-head',
                    num_workers=num_workers,
                    rank=i,
                    ranks=args.num_test_proc)
                test_sampler_tail = eval_dataset.create_sampler(
                    'test',
                    args.batch_size_eval,
                    args.neg_sample_size_test,
                    args.neg_chunk_size_test,
                    args.eval_filter,
                    mode='chunk-tail',
                    num_workers=num_workers,
                    rank=i,
                    ranks=args.num_test_proc)
                test_sampler_heads.append(test_sampler_head)
                test_sampler_tails.append(test_sampler_tail)
        else:
            test_sampler_head = eval_dataset.create_sampler(
                'test',
                args.batch_size_eval,
                args.neg_sample_size_test,
                args.neg_chunk_size_test,
                args.eval_filter,
                mode='chunk-head',
                num_workers=num_workers,
                rank=0,
                ranks=1)
            test_sampler_tail = eval_dataset.create_sampler(
                'test',
                args.batch_size_eval,
                args.neg_sample_size_test,
                args.neg_chunk_size_test,
                args.eval_filter,
                mode='chunk-tail',
                num_workers=num_workers,
                rank=0,
                ranks=1)

    # Drop the references to the dataset objects so the memory they hold can be freed.
    eval_dataset = None
    dataset = None
    # load model
    model = load_model(logger, args, n_entities, n_relations)

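    # Put the embeddings in shared memory so that the training processes (or the
    # asynchronous updater) all operate on the same parameter tensors.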
    if args.num_proc > 1 or args.async_update:
        model.share_memory()

    print('Total data loading time {:.3f} seconds'.format(time.time() -
                                                          train_time_start))

    # train
    start = time.time()
    rel_parts = train_data.rel_parts if args.strict_rel_part or args.soft_rel_part else None
    cross_rels = train_data.cross_rels if args.soft_rel_part else None
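    # Multi-process training: fork one worker per train sampler and hand each a
    # shared barrier for synchronization; the single-process path calls train() directly.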
    if args.num_proc > 1:
        procs = []
        barrier = mp.Barrier(args.num_proc)
        for i in range(args.num_proc):
            valid_sampler = ([valid_sampler_heads[i], valid_sampler_tails[i]]
                             if args.valid else None)
            proc = mp.Process(target=train_mp,
                              args=(args, model, train_samplers[i],
                                    valid_sampler, i, rel_parts, cross_rels,
                                    barrier))
            procs.append(proc)
            proc.start()
        for proc in procs:
            proc.join()
    else:
        valid_samplers = ([valid_sampler_head, valid_sampler_tail]
                          if args.valid else None)
        train(args, model, train_sampler, valid_samplers, rel_parts=rel_parts)
    print('training takes {} seconds'.format(time.time() - start))

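    # Optionally persist the trained embeddings together with the hyper-parameters
    # needed to reload them later.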
    if args.save_emb is not None:
        if not os.path.exists(args.save_emb):
            os.mkdir(args.save_emb)
        model.save_emb(args.save_emb, args.dataset)

        # We need to save the model configurations as well.
        conf_file = os.path.join(args.save_emb, 'config.json')
        with open(conf_file, 'w') as outfile:
            json.dump(
                {
                    'dataset': args.dataset,
                    'model': args.model_name,
                    'emb_size': args.hidden_dim,
                    'max_train_step': args.max_step,
                    'batch_size': args.batch_size,
                    'neg_sample_size': args.neg_sample_size,
                    'lr': args.lr,
                    'gamma': args.gamma,
                    'double_ent': args.double_ent,
                    'double_rel': args.double_rel,
                    'neg_adversarial_sampling': args.neg_adversarial_sampling,
                    'adversarial_temperature': args.adversarial_temperature,
                    'regularization_coef': args.regularization_coef,
                    'regularization_norm': args.regularization_norm
                },
                outfile,
                indent=4)

    # test
    if args.test:
        start = time.time()
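        # Multi-process test: every worker pushes its list of per-sample log dicts
        # onto the queue; the parent gathers them and averages each metric.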
        if args.num_test_proc > 1:
            queue = mp.Queue(args.num_test_proc)
            procs = []
            for i in range(args.num_test_proc):
                proc = mp.Process(target=test_mp,
                                  args=(args, model, [
                                      test_sampler_heads[i],
                                      test_sampler_tails[i]
                                  ], i, 'Test', queue))
                procs.append(proc)
                proc.start()

            total_metrics = {}
            metrics = {}
            logs = []
            for i in range(args.num_test_proc):
                log = queue.get()
                logs = logs + log

            for metric in logs[0].keys():
                metrics[metric] = sum([log[metric]
                                       for log in logs]) / len(logs)
            for k, v in metrics.items():
                print('Test average {} at [{}/{}]: {}'.format(
                    k, args.step, args.max_step, v))

            for proc in procs:
                proc.join()
        else:
            test(args, model, [test_sampler_head, test_sampler_tail])
        print('Test takes {:.3f} seconds'.format(time.time() - start))
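
# The config.json dumped above can be read back when the saved embeddings are
# restored for inference. A minimal sketch, assuming the same save-directory
# layout; load_run_config is a hypothetical helper, not part of the package.
def load_run_config(save_path):
    import json
    import os
    with open(os.path.join(save_path, 'config.json')) as f:
        return json.load(f)

# Example usage (assuming embeddings were saved to './ckpts'):
#   config = load_run_config('./ckpts')
#   print(config['model'], config['emb_size'], config['gamma'])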
Example #10
def main(args):
    args.eval_filter = not args.no_eval_filter
    if args.neg_deg_sample_eval:
        assert not args.eval_filter, "if negative sampling based on degree, we can't filter positive edges."

    # load dataset and samplers
    dataset = get_dataset(args.data_path, args.dataset, args.format,
                          args.data_files)
    args.pickle_graph = False
    args.train = False
    args.valid = False
    args.test = True
    args.strict_rel_part = False
    args.soft_rel_part = False
    args.async_update = False

    logger = get_logger(args)
    # Here we want to use the regular negative sampler because we need to ensure that
    # all positive edges are excluded.
    eval_dataset = EvalDataset(dataset, args)

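    # A negative neg_sample_size_eval means "rank against every entity in the graph".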
    if args.neg_sample_size_eval < 0:
        args.neg_sample_size_eval = args.neg_sample_size = \
            eval_dataset.g.number_of_nodes()
    args.batch_size_eval = get_compatible_batch_size(args.batch_size_eval,
                                                     args.neg_sample_size_eval)

    args.num_workers = 8  # fix num_workers to 8
    if args.num_proc > 1:
        test_sampler_tails = []
        test_sampler_heads = []
        for i in range(args.num_proc):
            test_sampler_head = eval_dataset.create_sampler(
                'test',
                args.batch_size_eval,
                args.neg_sample_size_eval,
                args.neg_sample_size_eval,
                args.eval_filter,
                mode='chunk-head',
                num_workers=args.num_workers,
                rank=i,
                ranks=args.num_proc)
            test_sampler_tail = eval_dataset.create_sampler(
                'test',
                args.batch_size_eval,
                args.neg_sample_size_eval,
                args.neg_sample_size_eval,
                args.eval_filter,
                mode='chunk-tail',
                num_workers=args.num_workers,
                rank=i,
                ranks=args.num_proc)
            test_sampler_heads.append(test_sampler_head)
            test_sampler_tails.append(test_sampler_tail)
    else:
        test_sampler_head = eval_dataset.create_sampler(
            'test',
            args.batch_size_eval,
            args.neg_sample_size_eval,
            args.neg_sample_size_eval,
            args.eval_filter,
            mode='chunk-head',
            num_workers=args.num_workers,
            rank=0,
            ranks=1)
        test_sampler_tail = eval_dataset.create_sampler(
            'test',
            args.batch_size_eval,
            args.neg_sample_size_eval,
            args.neg_sample_size_eval,
            args.eval_filter,
            mode='chunk-tail',
            num_workers=args.num_workers,
            rank=0,
            ranks=1)

    # load model
    n_entities = dataset.n_entities
    n_relations = dataset.n_relations
    ckpt_path = args.model_path
    model = load_model_from_checkpoint(logger, args, n_entities, n_relations,
                                       ckpt_path)

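    # Share the loaded embeddings across the test processes.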
    if args.num_proc > 1:
        model.share_memory()
    # test
    args.step = 0
    args.max_step = 0
    start = time.time()
    if args.num_proc > 1:
        queue = mp.Queue(args.num_proc)
        procs = []
        for i in range(args.num_proc):
            proc = mp.Process(target=test_mp,
                              args=(args, model, [
                                  test_sampler_heads[i], test_sampler_tails[i]
                              ], i, 'Test', queue))
            procs.append(proc)
            proc.start()

        total_metrics = {}
        metrics = {}
        logs = []
        for i in range(args.num_proc):
            log = queue.get()
            logs = logs + log

        for metric in logs[0].keys():
            metrics[metric] = sum([log[metric] for log in logs]) / len(logs)
        for k, v in metrics.items():
            print('Test average {} at [{}/{}]: {}'.format(
                k, args.step, args.max_step, v))

        for proc in procs:
            proc.join()
    else:
        test(args, model, [test_sampler_head, test_sampler_tail])
    print('Test takes {:.3f} seconds'.format(time.time() - start))
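
# Self-contained toy sketch of the queue-based metric aggregation used in the
# multi-process test paths above: each worker puts a list of per-sample log
# dicts on a Queue and the parent averages every metric. The metric values
# below are placeholders, not real evaluation results.
import multiprocessing as mp

def _toy_worker(rank, queue):
    logs = [{'MRR': 0.50 + 0.01 * rank, 'HITS@10': 0.70} for _ in range(4)]
    queue.put(logs)

def aggregate_metrics(num_proc=2):
    queue = mp.Queue(num_proc)
    procs = [mp.Process(target=_toy_worker, args=(i, queue))
             for i in range(num_proc)]
    for p in procs:
        p.start()
    logs = []
    for _ in range(num_proc):
        logs += queue.get()
    metrics = {k: sum(log[k] for log in logs) / len(logs) for k in logs[0]}
    for p in procs:
        p.join()
    return metrics

if __name__ == '__main__':
    print(aggregate_metrics())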