Example 1
def run(args, logger):
    init_time_start = time.time()
    # load dataset and samplers
    dataset = get_dataset(args.data_path, args.dataset, args.format,
                          args.data_files)

    if args.neg_sample_size_eval < 0:
        args.neg_sample_size_eval = dataset.n_entities
    args.batch_size = get_compatible_batch_size(args.batch_size,
                                                args.neg_sample_size)
    args.batch_size_eval = get_compatible_batch_size(args.batch_size_eval,
                                                     args.neg_sample_size_eval)

    args.eval_filter = not args.no_eval_filter
    if args.neg_deg_sample_eval:
        assert not args.eval_filter, "if negative sampling is based on degree, we can't filter positive edges."

    train_data = TrainDataset(dataset, args, ranks=args.num_proc)
    # if there is no cross-partition relation, we fall back to strict_rel_part
    args.strict_rel_part = args.mix_cpu_gpu and not train_data.cross_part
    args.soft_rel_part = args.mix_cpu_gpu and args.soft_rel_part and train_data.cross_part
    args.num_workers = 8  # fix the number of workers to 8

    if args.num_proc > 1:
        train_samplers = []
        for i in range(args.num_proc):
            train_sampler_head = train_data.create_sampler(
                args.batch_size,
                args.neg_sample_size,
                args.neg_sample_size,
                mode='head',
                num_workers=args.num_workers,
                shuffle=True,
                exclude_positive=False,
                rank=i)
            train_sampler_tail = train_data.create_sampler(
                args.batch_size,
                args.neg_sample_size,
                args.neg_sample_size,
                mode='tail',
                num_workers=args.num_workers,
                shuffle=True,
                exclude_positive=False,
                rank=i)
            train_samplers.append(
                NewBidirectionalOneShotIterator(train_sampler_head,
                                                train_sampler_tail,
                                                args.neg_sample_size,
                                                args.neg_sample_size, True,
                                                dataset.n_entities))

    else:  # This is used for debugging
        train_sampler_head = train_data.create_sampler(
            args.batch_size,
            args.neg_sample_size,
            args.neg_sample_size,
            mode='head',
            num_workers=args.num_workers,
            shuffle=True,
            exclude_positive=False)
        train_sampler_tail = train_data.create_sampler(
            args.batch_size,
            args.neg_sample_size,
            args.neg_sample_size,
            mode='tail',
            num_workers=args.num_workers,
            shuffle=True,
            exclude_positive=False)
        train_sampler = NewBidirectionalOneShotIterator(
            train_sampler_head, train_sampler_tail, args.neg_sample_size,
            args.neg_sample_size, True, dataset.n_entities)

    if args.valid or args.test:
        if len(args.gpu) > 1:
            args.num_test_proc = min(args.num_proc, len(args.gpu))
        else:
            args.num_test_proc = args.num_proc
        eval_dataset = EvalDataset(dataset, args)

    if args.valid:
        if args.num_proc > 1:
            valid_sampler_heads = []
            valid_sampler_tails = []
            for i in range(args.num_proc):
                valid_sampler_head = eval_dataset.create_sampler(
                    'valid',
                    args.batch_size_eval,
                    args.neg_sample_size_eval,
                    args.neg_sample_size_eval,
                    args.eval_filter,
                    mode='chunk-head',
                    num_workers=args.num_workers,
                    rank=i,
                    ranks=args.num_proc)
                valid_sampler_tail = eval_dataset.create_sampler(
                    'valid',
                    args.batch_size_eval,
                    args.neg_sample_size_eval,
                    args.neg_sample_size_eval,
                    args.eval_filter,
                    mode='chunk-tail',
                    num_workers=args.num_workers,
                    rank=i,
                    ranks=args.num_proc)
                valid_sampler_heads.append(valid_sampler_head)
                valid_sampler_tails.append(valid_sampler_tail)
        else:  # This is used for debugging
            valid_sampler_head = eval_dataset.create_sampler(
                'valid',
                args.batch_size_eval,
                args.neg_sample_size_eval,
                args.neg_sample_size_eval,
                args.eval_filter,
                mode='chunk-head',
                num_workers=args.num_workers,
                rank=0,
                ranks=1)
            valid_sampler_tail = eval_dataset.create_sampler(
                'valid',
                args.batch_size_eval,
                args.neg_sample_size_eval,
                args.neg_sample_size_eval,
                args.eval_filter,
                mode='chunk-tail',
                num_workers=args.num_workers,
                rank=0,
                ranks=1)
    if args.test:
        if args.num_test_proc > 1:
            test_sampler_tails = []
            test_sampler_heads = []
            for i in range(args.num_test_proc):
                test_sampler_head = eval_dataset.create_sampler(
                    'test',
                    args.batch_size_eval,
                    args.neg_sample_size_eval,
                    args.neg_sample_size_eval,
                    args.eval_filter,
                    mode='chunk-head',
                    num_workers=args.num_workers,
                    rank=i,
                    ranks=args.num_test_proc)
                test_sampler_tail = eval_dataset.create_sampler(
                    'test',
                    args.batch_size_eval,
                    args.neg_sample_size_eval,
                    args.neg_sample_size_eval,
                    args.eval_filter,
                    mode='chunk-tail',
                    num_workers=args.num_workers,
                    rank=i,
                    ranks=args.num_test_proc)
                test_sampler_heads.append(test_sampler_head)
                test_sampler_tails.append(test_sampler_tail)
        else:
            test_sampler_head = eval_dataset.create_sampler(
                'test',
                args.batch_size_eval,
                args.neg_sample_size_eval,
                args.neg_sample_size_eval,
                args.eval_filter,
                mode='chunk-head',
                num_workers=args.num_workers,
                rank=0,
                ranks=1)
            test_sampler_tail = eval_dataset.create_sampler(
                'test',
                args.batch_size_eval,
                args.neg_sample_size_eval,
                args.neg_sample_size_eval,
                args.eval_filter,
                mode='chunk-tail',
                num_workers=args.num_workers,
                rank=0,
                ranks=1)

    # load model
    model = load_model(logger, args, dataset.n_entities, dataset.n_relations)
    if args.num_proc > 1 or args.async_update:
        model.share_memory()

    # We need to free all memory referenced by dataset.
    eval_dataset = None
    dataset = None

    print('Total initialization time {:.3f} seconds'.format(time.time() - init_time_start))

    # train
    start = time.time()
    rel_parts = train_data.rel_parts if args.strict_rel_part or args.soft_rel_part else None
    cross_rels = train_data.cross_rels if args.soft_rel_part else None
    if args.num_proc > 1:
        procs = []
        barrier = mp.Barrier(args.num_proc)
        for i in range(args.num_proc):
            valid_sampler = [valid_sampler_heads[i], valid_sampler_tails[i]
                             ] if args.valid else None
            proc = mp.Process(target=train_mp,
                              args=(args, model, train_samplers[i],
                                    valid_sampler, i, rel_parts, cross_rels,
                                    barrier))
            procs.append(proc)
            proc.start()
        for proc in procs:
            proc.join()
    else:
        valid_samplers = [valid_sampler_head, valid_sampler_tail
                          ] if args.valid else None
        train(args, model, train_sampler, valid_samplers, rel_parts=rel_parts)

    print('training takes {} seconds'.format(time.time() - start))

    if args.save_emb is not None:
        if not os.path.exists(args.save_emb):
            os.mkdir(args.save_emb)
        model.save_emb(args.save_emb, args.dataset)

        # We need to save the model configurations as well.
        conf_file = os.path.join(args.save_emb, 'config.json')
        with open(conf_file, 'w') as outfile:
            json.dump(
                {
                    'dataset': args.dataset,
                    'model': args.model_name,
                    'emb_size': args.hidden_dim,
                    'max_train_step': args.max_step,
                    'batch_size': args.batch_size,
                    'neg_sample_size': args.neg_sample_size,
                    'lr': args.lr,
                    'gamma': args.gamma,
                    'double_ent': args.double_ent,
                    'double_rel': args.double_rel,
                    'neg_adversarial_sampling': args.neg_adversarial_sampling,
                    'adversarial_temperature': args.adversarial_temperature,
                    'regularization_coef': args.regularization_coef,
                    'regularization_norm': args.regularization_norm
                },
                outfile,
                indent=4)

    # test
    if args.test:
        start = time.time()
        if args.num_test_proc > 1:
            queue = mp.Queue(args.num_test_proc)
            procs = []
            for i in range(args.num_test_proc):
                proc = mp.Process(target=test_mp,
                                  args=(args, model, [
                                      test_sampler_heads[i],
                                      test_sampler_tails[i]
                                  ], i, 'Test', queue))
                procs.append(proc)
                proc.start()

            metrics = {}
            logs = []
            for i in range(args.num_test_proc):
                log = queue.get()
                logs.extend(log)

            for metric in logs[0].keys():
                metrics[metric] = sum([log[metric]
                                       for log in logs]) / len(logs)
            for k, v in metrics.items():
                print('Test average {} : {}'.format(k, v))

            for proc in procs:
                proc.join()
        else:
            test(args, model, [test_sampler_head, test_sampler_tail])
        print('testing takes {:.3f} seconds'.format(time.time() - start))
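
A note on the batch-size adjustment in Example 1: get_compatible_batch_size has to keep the batch size compatible with chunked negative sampling. The sketch below shows one plausible, minimal implementation, assuming the helper's only job is to round the batch size up to the next multiple of the negative sample size; it is illustrative, not necessarily dgl-ke's exact code.

import math

def get_compatible_batch_size(batch_size, neg_sample_size):
    # Chunk-based negative sampling needs batch_size to be a multiple
    # of neg_sample_size; round up to the next multiple when it is not.
    if neg_sample_size < batch_size and batch_size % neg_sample_size != 0:
        old_batch_size = batch_size
        batch_size = int(math.ceil(batch_size / neg_sample_size) * neg_sample_size)
        print('batch size ({}) is incompatible with negative sample size ({}), '
              'adjusted to {}'.format(old_batch_size, neg_sample_size, batch_size))
    return batch_size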
Example 2
def run(args, logger):
    # load dataset and samplers
    dataset = get_dataset(args.data_path, args.dataset, args.format)
    n_entities = dataset.n_entities
    n_relations = dataset.n_relations
    if args.neg_sample_size_test < 0:
        args.neg_sample_size_test = n_entities
    args.eval_filter = not args.no_eval_filter
    if args.neg_deg_sample_eval:
        assert not args.eval_filter, "if negative sampling is based on degree, we can't filter positive edges."

    # When we generate a batch of negative edges from a set of positive edges,
    # we first divide the positive edges into chunks and corrupt the edges in a chunk
    # together. By default, the chunk size is equal to the negative sample size.
    # Usually, this works well. But we also allow users to specify the chunk size themselves.
    if args.neg_chunk_size < 0:
        args.neg_chunk_size = args.neg_sample_size
    if args.neg_chunk_size_valid < 0:
        args.neg_chunk_size_valid = args.neg_sample_size_valid
    if args.neg_chunk_size_test < 0:
        args.neg_chunk_size_test = args.neg_sample_size_test

    train_data = TrainDataset(dataset, args, ranks=args.num_proc)
    if args.num_proc > 1:
        train_samplers = []
        for i in range(args.num_proc):
            train_sampler_head = train_data.create_sampler(
                args.batch_size,
                args.neg_sample_size,
                args.neg_chunk_size,
                mode='chunk-head',
                num_workers=args.num_worker,
                shuffle=True,
                exclude_positive=True,
                rank=i)
            train_sampler_tail = train_data.create_sampler(
                args.batch_size,
                args.neg_sample_size,
                args.neg_chunk_size,
                mode='chunk-tail',
                num_workers=args.num_worker,
                shuffle=True,
                exclude_positive=True,
                rank=i)
            train_samplers.append(
                NewBidirectionalOneShotIterator(train_sampler_head,
                                                train_sampler_tail,
                                                args.neg_chunk_size, True,
                                                n_entities))
    else:
        train_sampler_head = train_data.create_sampler(
            args.batch_size,
            args.neg_sample_size,
            args.neg_chunk_size,
            mode='chunk-head',
            num_workers=args.num_worker,
            shuffle=True,
            exclude_positive=True)
        train_sampler_tail = train_data.create_sampler(
            args.batch_size,
            args.neg_sample_size,
            args.neg_chunk_size,
            mode='chunk-tail',
            num_workers=args.num_worker,
            shuffle=True,
            exclude_positive=True)
        train_sampler = NewBidirectionalOneShotIterator(
            train_sampler_head, train_sampler_tail, args.neg_chunk_size, True,
            n_entities)

    # for multiprocessing evaluation, we don't need to sample multiple batches at a time
    # in each process.
    num_workers = args.num_worker
    if args.num_proc > 1:
        num_workers = 1
    if args.valid or args.test:
        eval_dataset = EvalDataset(dataset, args)
    if args.valid:
        # Here we want to use the regular negative sampler because we need to ensure that
        # all positive edges are excluded.
        if args.num_proc > 1:
            valid_sampler_heads = []
            valid_sampler_tails = []
            for i in range(args.num_proc):
                valid_sampler_head = eval_dataset.create_sampler(
                    'valid',
                    args.batch_size_eval,
                    args.neg_sample_size_valid,
                    args.neg_chunk_size_valid,
                    args.eval_filter,
                    mode='chunk-head',
                    num_workers=num_workers,
                    rank=i,
                    ranks=args.num_proc)
                valid_sampler_tail = eval_dataset.create_sampler(
                    'valid',
                    args.batch_size_eval,
                    args.neg_sample_size_valid,
                    args.neg_chunk_size_valid,
                    args.eval_filter,
                    mode='chunk-tail',
                    num_workers=num_workers,
                    rank=i,
                    ranks=args.num_proc)
                valid_sampler_heads.append(valid_sampler_head)
                valid_sampler_tails.append(valid_sampler_tail)
        else:
            valid_sampler_head = eval_dataset.create_sampler(
                'valid',
                args.batch_size_eval,
                args.neg_sample_size_valid,
                args.neg_chunk_size_valid,
                args.eval_filter,
                mode='chunk-head',
                num_workers=num_workers,
                rank=0,
                ranks=1)
            valid_sampler_tail = eval_dataset.create_sampler(
                'valid',
                args.batch_size_eval,
                args.neg_sample_size_valid,
                args.neg_chunk_size_valid,
                args.eval_filter,
                mode='chunk-tail',
                num_workers=num_workers,
                rank=0,
                ranks=1)
    if args.test:
        # Here we want to use the regular negative sampler because we need to ensure that
        # all positive edges are excluded.
        if args.num_proc > 1:
            test_sampler_tails = []
            test_sampler_heads = []
            for i in range(args.num_proc):
                test_sampler_head = eval_dataset.create_sampler(
                    'test',
                    args.batch_size_eval,
                    args.neg_sample_size_test,
                    args.neg_chunk_size_test,
                    args.eval_filter,
                    mode='chunk-head',
                    num_workers=num_workers,
                    rank=i,
                    ranks=args.num_proc)
                test_sampler_tail = eval_dataset.create_sampler(
                    'test',
                    args.batch_size_eval,
                    args.neg_sample_size_test,
                    args.neg_chunk_size_test,
                    args.eval_filter,
                    mode='chunk-tail',
                    num_workers=num_workers,
                    rank=i,
                    ranks=args.num_proc)
                test_sampler_heads.append(test_sampler_head)
                test_sampler_tails.append(test_sampler_tail)
        else:
            test_sampler_head = eval_dataset.create_sampler(
                'test',
                args.batch_size_eval,
                args.neg_sample_size_test,
                args.neg_chunk_size_test,
                args.eval_filter,
                mode='chunk-head',
                num_workers=num_workers,
                rank=0,
                ranks=1)
            test_sampler_tail = eval_dataset.create_sampler(
                'test',
                args.batch_size_eval,
                args.neg_sample_size_test,
                args.neg_chunk_size_test,
                args.eval_filter,
                mode='chunk-tail',
                num_workers=num_workers,
                rank=0,
                ranks=1)

    # We need to free all memory referenced by dataset.
    eval_dataset = None
    dataset = None
    # load model
    model = load_model(logger, args, n_entities, n_relations)

    if args.num_proc > 1:
        model.share_memory()

    # train
    start = time.time()
    if args.num_proc > 1:
        procs = []
        for i in range(args.num_proc):
            rel_parts = train_data.rel_parts if args.rel_part else None
            valid_samplers = [valid_sampler_heads[i], valid_sampler_tails[i]
                              ] if args.valid else None
            proc = mp.Process(target=train,
                              args=(args, model, train_samplers[i], i,
                                    rel_parts, valid_samplers))
            procs.append(proc)
            proc.start()
        for proc in procs:
            proc.join()
    else:
        valid_samplers = [valid_sampler_head, valid_sampler_tail
                          ] if args.valid else None
        train(args, model, train_sampler, valid_samplers)
    print('training takes {} seconds'.format(time.time() - start))

    if args.save_emb is not None:
        if not os.path.exists(args.save_emb):
            os.mkdir(args.save_emb)
        model.save_emb(args.save_emb, args.dataset)

    # test
    if args.test:
        start = time.time()
        if args.num_proc > 1:
            queue = mp.Queue(args.num_proc)
            procs = []
            for i in range(args.num_proc):
                proc = mp.Process(target=test,
                                  args=(args, model, [
                                      test_sampler_heads[i],
                                      test_sampler_tails[i]
                                  ], i, 'Test', queue))
                procs.append(proc)
                proc.start()

            total_metrics = {}
            for i in range(args.num_proc):
                metrics = queue.get()
                for k, v in metrics.items():
                    if i == 0:
                        total_metrics[k] = v / args.num_proc
                    else:
                        total_metrics[k] += v / args.num_proc
            for k, v in total_metrics.items():
                print('Test average {} at [{}/{}]: {}'.format(
                    k, args.step, args.max_step, v))

            for proc in procs:
                proc.join()
        else:
            test(args, model, [test_sampler_head, test_sampler_tail])
        print('testing takes {:.3f} seconds'.format(time.time() - start))
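
The chunk-size comment in Example 2 is the key to the sampler parameters: a batch of positive edges is split into chunks, and every edge in a chunk is scored against the same set of sampled negative entities, which amortizes the sampling cost. A minimal NumPy sketch of the idea (a hypothetical helper, not the dgl-ke sampler itself):

import numpy as np

def sample_chunk_negatives(batch_size, n_entities, chunk_size, neg_sample_size):
    # Each of the batch_size // chunk_size chunks shares one set of
    # neg_sample_size candidate entities, so a chunk yields
    # chunk_size * neg_sample_size negative edges while sampling only
    # neg_sample_size entities.
    assert batch_size % chunk_size == 0
    num_chunks = batch_size // chunk_size
    return np.random.randint(0, n_entities, size=(num_chunks, neg_sample_size))

neg = sample_chunk_negatives(batch_size=1024, n_entities=40943,
                             chunk_size=256, neg_sample_size=256)
print(neg.shape)  # (4, 256)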
Example 3
def run(args, logger):
    # load dataset and samplers
    dataset = get_dataset(args.data_path, args.dataset, args.format)
    n_entities = dataset.n_entities
    n_relations = dataset.n_relations
    if args.neg_sample_size_test < 0:
        args.neg_sample_size_test = n_entities

    train_data = TrainDataset(dataset, args, ranks=args.num_proc)
    if args.num_proc > 1:
        train_samplers = []
        for i in range(args.num_proc):
            train_sampler_head = train_data.create_sampler(
                args.batch_size,
                args.neg_sample_size,
                mode='PBG-head',
                num_workers=args.num_worker,
                shuffle=True,
                exclude_positive=True,
                rank=i)
            train_sampler_tail = train_data.create_sampler(
                args.batch_size,
                args.neg_sample_size,
                mode='PBG-tail',
                num_workers=args.num_worker,
                shuffle=True,
                exclude_positive=True,
                rank=i)
            train_samplers.append(
                NewBidirectionalOneShotIterator(train_sampler_head,
                                                train_sampler_tail, True,
                                                n_entities))
    else:
        train_sampler_head = train_data.create_sampler(
            args.batch_size,
            args.neg_sample_size,
            mode='PBG-head',
            num_workers=args.num_worker,
            shuffle=True,
            exclude_positive=True)
        train_sampler_tail = train_data.create_sampler(
            args.batch_size,
            args.neg_sample_size,
            mode='PBG-tail',
            num_workers=args.num_worker,
            shuffle=True,
            exclude_positive=True)
        train_sampler = NewBidirectionalOneShotIterator(
            train_sampler_head, train_sampler_tail, True, n_entities)

    if args.valid or args.test:
        eval_dataset = EvalDataset(dataset, args)
    if args.valid:
        # Here we want to use the regular negative sampler because we need to ensure that
        # all positive edges are excluded.
        if args.num_proc > 1:
            valid_sampler_heads = []
            valid_sampler_tails = []
            for i in range(args.num_proc):
                valid_sampler_head = eval_dataset.create_sampler(
                    'valid',
                    args.batch_size_eval,
                    args.neg_sample_size_valid,
                    mode='PBG-head',
                    num_workers=args.num_worker,
                    rank=i,
                    ranks=args.num_proc)
                valid_sampler_tail = eval_dataset.create_sampler(
                    'valid',
                    args.batch_size_eval,
                    args.neg_sample_size_valid,
                    mode='PBG-tail',
                    num_workers=args.num_worker,
                    rank=i,
                    ranks=args.num_proc)
                valid_sampler_heads.append(valid_sampler_head)
                valid_sampler_tails.append(valid_sampler_tail)
        else:
            valid_sampler_head = eval_dataset.create_sampler(
                'valid',
                args.batch_size_eval,
                args.neg_sample_size_valid,
                mode='PBG-head',
                num_workers=args.num_worker,
                rank=0,
                ranks=1)
            valid_sampler_tail = eval_dataset.create_sampler(
                'valid',
                args.batch_size_eval,
                args.neg_sample_size_valid,
                mode='PBG-tail',
                num_workers=args.num_worker,
                rank=0,
                ranks=1)
    if args.test:
        # Here we want to use the regular negative sampler because we need to ensure that
        # all positive edges are excluded.
        if args.num_proc > 1:
            test_sampler_tails = []
            test_sampler_heads = []
            for i in range(args.num_proc):
                test_sampler_head = eval_dataset.create_sampler(
                    'test',
                    args.batch_size_eval,
                    args.neg_sample_size_test,
                    mode='PBG-head',
                    num_workers=args.num_worker,
                    rank=i,
                    ranks=args.num_proc)
                test_sampler_tail = eval_dataset.create_sampler(
                    'test',
                    args.batch_size_eval,
                    args.neg_sample_size_test,
                    mode='PBG-tail',
                    num_workers=args.num_worker,
                    rank=i,
                    ranks=args.num_proc)
                test_sampler_heads.append(test_sampler_head)
                test_sampler_tails.append(test_sampler_tail)
        else:
            test_sampler_head = eval_dataset.create_sampler(
                'test',
                args.batch_size_eval,
                args.neg_sample_size_test,
                mode='PBG-head',
                num_workers=args.num_worker,
                rank=0,
                ranks=1)
            test_sampler_tail = eval_dataset.create_sampler(
                'test',
                args.batch_size_eval,
                args.neg_sample_size_test,
                mode='PBG-tail',
                num_workers=args.num_worker,
                rank=0,
                ranks=1)

    # We need to free all memory referenced by dataset.
    eval_dataset = None
    dataset = None
    # load model
    model = load_model(logger, args, n_entities, n_relations)

    if args.num_proc > 1:
        model.share_memory()

    # train
    start = time.time()
    if args.num_proc > 1:
        procs = []
        for i in range(args.num_proc):
            valid_samplers = [valid_sampler_heads[i], valid_sampler_tails[i]
                              ] if args.valid else None
            proc = mp.Process(target=train,
                              args=(args, model, train_samplers[i],
                                    valid_samplers))
            procs.append(proc)
            proc.start()
        for proc in procs:
            proc.join()
    else:
        valid_samplers = [valid_sampler_head, valid_sampler_tail
                          ] if args.valid else None
        train(args, model, train_sampler, valid_samplers)
    print('training takes {} seconds'.format(time.time() - start))

    if args.save_emb is not None:
        if not os.path.exists(args.save_emb):
            os.mkdir(args.save_emb)
        model.save_emb(args.save_emb, args.dataset)

    # test
    if args.test:
        if args.num_proc > 1:
            procs = []
            for i in range(args.num_proc):
                proc = mp.Process(target=test,
                                  args=(args, model, [
                                      test_sampler_heads[i],
                                      test_sampler_tails[i]
                                  ]))
                procs.append(proc)
                proc.start()
            for proc in procs:
                proc.join()
        else:
            test(args, model, [test_sampler_head, test_sampler_tail])
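
All of the examples feed training through NewBidirectionalOneShotIterator, which turns the two finite head/tail samplers into one endless stream that alternates corruption sides. A stripped-down sketch of that pattern, with a simplified name and constructor (a guess at the minimal behavior, not dgl-ke's actual signature):

class BidirectionalOneShotIterator:
    """Alternate forever between a head-corrupting and a tail-corrupting sampler."""

    def __init__(self, sampler_head, sampler_tail):
        self.iter_head = self._endless(sampler_head)
        self.iter_tail = self._endless(sampler_tail)
        self.step = 0

    def __next__(self):
        # Odd steps draw a head-corrupted batch, even steps a tail-corrupted one.
        self.step += 1
        if self.step % 2 == 1:
            return next(self.iter_head)
        return next(self.iter_tail)

    @staticmethod
    def _endless(sampler):
        # Restart the underlying sampler each time it is exhausted.
        while True:
            for batch in sampler:
                yield batch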
Example 4
def run(args, logger):
    train_time_start = time.time()
    # load dataset and samplers
    dataset = get_dataset(args.data_path, args.dataset, args.format,
                          args.data_files)
    n_entities = dataset.n_entities
    n_relations = dataset.n_relations
    if args.neg_sample_size_test < 0:
        args.neg_sample_size_test = n_entities
    args.eval_filter = not args.no_eval_filter
    if args.neg_deg_sample_eval:
        assert not args.eval_filter, "if negative sampling is based on degree, we can't filter positive edges."

    # When we generate a batch of negative edges from a set of positive edges,
    # we first divide the positive edges into chunks and corrupt the edges in a chunk
    # together. By default, the chunk size is equal to the negative sample size.
    # Usually, this works well. But we also allow users to specify the chunk size themselves.
    if args.neg_chunk_size < 0:
        args.neg_chunk_size = args.neg_sample_size
    if args.neg_chunk_size_valid < 0:
        args.neg_chunk_size_valid = args.neg_sample_size_valid
    if args.neg_chunk_size_test < 0:
        args.neg_chunk_size_test = args.neg_sample_size_test

    num_workers = args.num_worker
    train_data = TrainDataset(dataset, args, ranks=args.num_proc)
    # if there is no cross-partition relation, we fall back to strict_rel_part
    args.strict_rel_part = args.mix_cpu_gpu and not train_data.cross_part
    args.soft_rel_part = args.mix_cpu_gpu and args.soft_rel_part and train_data.cross_part

    # Automatically set the number of OMP threads per process if it is not provided.
    # The GPU value was evaluated on an AWS p3.16xlarge,
    # the CPU value on an AWS x1.32xlarge.
    if args.nomp_thread_per_process == -1:
        if len(args.gpu) > 0:
            # GPU training
            args.num_thread = 4
        else:
            # CPU training
            args.num_thread = 1
    else:
        args.num_thread = args.nomp_thread_per_process

    if args.num_proc > 1:
        train_samplers = []
        for i in range(args.num_proc):
            train_sampler_head = train_data.create_sampler(
                args.batch_size,
                args.neg_sample_size,
                args.neg_chunk_size,
                mode='head',
                num_workers=num_workers,
                shuffle=True,
                exclude_positive=False,
                rank=i)
            train_sampler_tail = train_data.create_sampler(
                args.batch_size,
                args.neg_sample_size,
                args.neg_chunk_size,
                mode='tail',
                num_workers=num_workers,
                shuffle=True,
                exclude_positive=False,
                rank=i)
            train_samplers.append(
                NewBidirectionalOneShotIterator(train_sampler_head,
                                                train_sampler_tail,
                                                args.neg_chunk_size,
                                                args.neg_sample_size, True,
                                                n_entities))
    else:
        train_sampler_head = train_data.create_sampler(args.batch_size,
                                                       args.neg_sample_size,
                                                       args.neg_chunk_size,
                                                       mode='head',
                                                       num_workers=num_workers,
                                                       shuffle=True,
                                                       exclude_positive=False)
        train_sampler_tail = train_data.create_sampler(args.batch_size,
                                                       args.neg_sample_size,
                                                       args.neg_chunk_size,
                                                       mode='tail',
                                                       num_workers=num_workers,
                                                       shuffle=True,
                                                       exclude_positive=False)
        train_sampler = NewBidirectionalOneShotIterator(
            train_sampler_head, train_sampler_tail, args.neg_chunk_size,
            args.neg_sample_size, True, n_entities)

    # for multiprocessing evaluation, we don't need to sample multiple batches at a time
    # in each process.
    if args.num_proc > 1:
        num_workers = 1
    if args.valid or args.test:
        if len(args.gpu) > 1:
            args.num_test_proc = min(args.num_proc, len(args.gpu))
        else:
            args.num_test_proc = args.num_proc
        eval_dataset = EvalDataset(dataset, args)
    if args.valid:
        # Here we want to use the regular negative sampler because we need to ensure that
        # all positive edges are excluded.
        if args.num_proc > 1:
            valid_sampler_heads = []
            valid_sampler_tails = []
            for i in range(args.num_proc):
                valid_sampler_head = eval_dataset.create_sampler(
                    'valid',
                    args.batch_size_eval,
                    args.neg_sample_size_valid,
                    args.neg_chunk_size_valid,
                    args.eval_filter,
                    mode='chunk-head',
                    num_workers=num_workers,
                    rank=i,
                    ranks=args.num_proc)
                valid_sampler_tail = eval_dataset.create_sampler(
                    'valid',
                    args.batch_size_eval,
                    args.neg_sample_size_valid,
                    args.neg_chunk_size_valid,
                    args.eval_filter,
                    mode='chunk-tail',
                    num_workers=num_workers,
                    rank=i,
                    ranks=args.num_proc)
                valid_sampler_heads.append(valid_sampler_head)
                valid_sampler_tails.append(valid_sampler_tail)
        else:
            valid_sampler_head = eval_dataset.create_sampler(
                'valid',
                args.batch_size_eval,
                args.neg_sample_size_valid,
                args.neg_chunk_size_valid,
                args.eval_filter,
                mode='chunk-head',
                num_workers=num_workers,
                rank=0,
                ranks=1)
            valid_sampler_tail = eval_dataset.create_sampler(
                'valid',
                args.batch_size_eval,
                args.neg_sample_size_valid,
                args.neg_chunk_size_valid,
                args.eval_filter,
                mode='chunk-tail',
                num_workers=num_workers,
                rank=0,
                ranks=1)
    if args.test:
        # Here we want to use the regular negative sampler because we need to ensure that
        # all positive edges are excluded.
        # We use at most num_gpu processes in the test stage to save GPU memory.
        if args.num_test_proc > 1:
            test_sampler_tails = []
            test_sampler_heads = []
            for i in range(args.num_test_proc):
                test_sampler_head = eval_dataset.create_sampler(
                    'test',
                    args.batch_size_eval,
                    args.neg_sample_size_test,
                    args.neg_chunk_size_test,
                    args.eval_filter,
                    mode='chunk-head',
                    num_workers=num_workers,
                    rank=i,
                    ranks=args.num_test_proc)
                test_sampler_tail = eval_dataset.create_sampler(
                    'test',
                    args.batch_size_eval,
                    args.neg_sample_size_test,
                    args.neg_chunk_size_test,
                    args.eval_filter,
                    mode='chunk-tail',
                    num_workers=num_workers,
                    rank=i,
                    ranks=args.num_test_proc)
                test_sampler_heads.append(test_sampler_head)
                test_sampler_tails.append(test_sampler_tail)
        else:
            test_sampler_head = eval_dataset.create_sampler(
                'test',
                args.batch_size_eval,
                args.neg_sample_size_test,
                args.neg_chunk_size_test,
                args.eval_filter,
                mode='chunk-head',
                num_workers=num_workers,
                rank=0,
                ranks=1)
            test_sampler_tail = eval_dataset.create_sampler(
                'test',
                args.batch_size_eval,
                args.neg_sample_size_test,
                args.neg_chunk_size_test,
                args.eval_filter,
                mode='chunk-tail',
                num_workers=num_workers,
                rank=0,
                ranks=1)

    # We need to free all memory referenced by dataset.
    eval_dataset = None
    dataset = None
    # load model
    model = load_model(logger, args, n_entities, n_relations)

    if args.num_proc > 1 or args.async_update:
        model.share_memory()

    print('Total data loading time {:.3f} seconds'.format(time.time() -
                                                          train_time_start))

    # train
    start = time.time()
    rel_parts = train_data.rel_parts if args.strict_rel_part or args.soft_rel_part else None
    cross_rels = train_data.cross_rels if args.soft_rel_part else None
    if args.num_proc > 1:
        procs = []
        barrier = mp.Barrier(args.num_proc)
        for i in range(args.num_proc):
            valid_sampler = [valid_sampler_heads[i], valid_sampler_tails[i]
                             ] if args.valid else None
            proc = mp.Process(target=train_mp,
                              args=(args, model, train_samplers[i],
                                    valid_sampler, i, rel_parts, cross_rels,
                                    barrier))
            procs.append(proc)
            proc.start()
        for proc in procs:
            proc.join()
    else:
        valid_samplers = [valid_sampler_head, valid_sampler_tail
                          ] if args.valid else None
        train(args, model, train_sampler, valid_samplers, rel_parts=rel_parts)
    print('training takes {} seconds'.format(time.time() - start))

    if args.save_emb is not None:
        if not os.path.exists(args.save_emb):
            os.mkdir(args.save_emb)
        model.save_emb(args.save_emb, args.dataset)

        # We need to save the model configurations as well.
        conf_file = os.path.join(args.save_emb, 'config.json')
        with open(conf_file, 'w') as outfile:
            json.dump(
                {
                    'dataset': args.dataset,
                    'model': args.model_name,
                    'emb_size': args.hidden_dim,
                    'max_train_step': args.max_step,
                    'batch_size': args.batch_size,
                    'neg_sample_size': args.neg_sample_size,
                    'lr': args.lr,
                    'gamma': args.gamma,
                    'double_ent': args.double_ent,
                    'double_rel': args.double_rel,
                    'neg_adversarial_sampling': args.neg_adversarial_sampling,
                    'adversarial_temperature': args.adversarial_temperature,
                    'regularization_coef': args.regularization_coef,
                    'regularization_norm': args.regularization_norm
                },
                outfile,
                indent=4)

    # test
    if args.test:
        start = time.time()
        if args.num_test_proc > 1:
            queue = mp.Queue(args.num_test_proc)
            procs = []
            for i in range(args.num_test_proc):
                proc = mp.Process(target=test_mp,
                                  args=(args, model, [
                                      test_sampler_heads[i],
                                      test_sampler_tails[i]
                                  ], i, 'Test', queue))
                procs.append(proc)
                proc.start()

            metrics = {}
            logs = []
            for i in range(args.num_test_proc):
                log = queue.get()
                logs.extend(log)

            for metric in logs[0].keys():
                metrics[metric] = sum([log[metric]
                                       for log in logs]) / len(logs)
            for k, v in metrics.items():
                print('Test average {} at [{}/{}]: {}'.format(
                    k, args.step, args.max_step, v))

            for proc in procs:
                proc.join()
        else:
            test(args, model, [test_sampler_head, test_sampler_tail])
        print('testing takes {:.3f} seconds'.format(time.time() - start))
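
A final note on how the examples reduce test metrics across processes: Examples 1 and 4 push per-triple log dicts through an mp.Queue and micro-average them, while Example 2 averages each process's already-reduced metrics, which implicitly assumes every process evaluates the same number of triples. A minimal sketch of the log-based (micro-average) reduction, assuming each worker puts a list of {metric: value} dicts on the queue:

def aggregate_test_logs(queue, num_proc):
    # Collect per-triple logs from every worker, then average per metric;
    # each triple counts equally regardless of partition size.
    logs = []
    for _ in range(num_proc):
        logs.extend(queue.get())
    return {metric: sum(log[metric] for log in logs) / len(logs)
            for metric in logs[0]}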