コード例 #1
0
ファイル: run.py プロジェクト: pengyanhui/LineaRE
def run():
    """Load the benchmark data, build the configured KGE model and train it.

    Reads the entity/relation dictionaries and the train/valid/test splits
    from ``config.data_path``, plus the four relation-pattern test splits
    (symmetry / inversion / composition / other), then dispatches on
    ``config.model`` to construct the model and hands everything to
    ``train``.
    """
    # load data
    ent2id = read_elements(os.path.join(config.data_path, "entities.dict"))
    rel2id = read_elements(os.path.join(config.data_path, "relations.dict"))
    ent_num = len(ent2id)
    rel_num = len(rel2id)
    train_triples = read_triples(os.path.join(config.data_path, "train.txt"),
                                 ent2id, rel2id)
    valid_triples = read_triples(os.path.join(config.data_path, "valid.txt"),
                                 ent2id, rel2id)
    test_triples = read_triples(os.path.join(config.data_path, "test.txt"),
                                ent2id, rel2id)
    # The four relation-pattern test splits share identical loading logic.
    symmetry_test, inversion_test, composition_test, others_test = (
        read_triples(os.path.join(config.data_path, fname), ent2id, rel2id)
        for fname in ("symmetry_test.txt", "inversion_test.txt",
                      "composition_test.txt", "other_test.txt"))
    # Lazy %-style args: the message is only formatted if the level is on.
    logging.info("#ent_num: %d", ent_num)
    logging.info("#rel_num: %d", rel_num)
    logging.info("#train triple num: %d", len(train_triples))
    logging.info("#valid triple num: %d", len(valid_triples))
    logging.info("#test triple num: %d", len(test_triples))
    logging.info("#Model: %s", config.model)

    # Build the model. Unknown names silently fall back to TransE, matching
    # the original if/elif chain; unlike the original, the TransE fallback is
    # only constructed when actually selected (no wasted initialization).
    # NOTE(review): "TransR" maps to SimpleTransR here — confirm intentional.
    model_classes = {
        "TransH": TransH,
        "TransR": SimpleTransR,
        "TransD": TransD,
        "STransE": STransE,
        "LineaRE": LineaRE,
        "DistMult": DistMult,
        "ComplEx": ComplEx,
        "RotatE": RotatE,
        "TransIJ": TransIJ,
    }
    kge_model = model_classes.get(config.model, TransE)(ent_num, rel_num)

    # Hard-wired to GPU 0; raises if CUDA is unavailable.
    kge_model = kge_model.cuda(torch.device("cuda:0"))
    logging.info("Model Parameter Configuration:")
    for name, param in kge_model.named_parameters():
        logging.info("Parameter %s: %s, require_grad = %s",
                     name, str(param.size()), str(param.requires_grad))

    # Train: the tuple order (train, valid, test, then the four pattern
    # splits) is the contract expected by ``train``.
    train(model=kge_model,
          triples=(train_triples, valid_triples, test_triples, symmetry_test,
                   inversion_test, composition_test, others_test),
          ent_num=ent_num)
コード例 #2
0
def run():
    """Load data from ``config.data_path``, build the selected KGE model and
    hand it to ``train``."""
    set_logger()

    # load data
    entity_dict_path = os.path.join(config.data_path, "entities.dict")
    relation_dict_path = os.path.join(config.data_path, "relations.dict")
    ent2id = read_elements(entity_dict_path)
    rel2id = read_elements(relation_dict_path)
    ent_num, rel_num = len(ent2id), len(rel2id)

    def load_split(filename):
        # Shared triple-loading logic for the three dataset splits.
        return read_triples(os.path.join(config.data_path, filename),
                            ent2id, rel2id)

    train_triples = load_split("train.txt")
    valid_triples = load_split("valid.txt")
    test_triples = load_split("test.txt")

    logging.info("#ent_num: %d" % ent_num)
    logging.info("#rel_num: %d" % rel_num)
    logging.info("#train triple num: %d" % len(train_triples))
    logging.info("#valid triple num: %d" % len(valid_triples))
    logging.info("#test triple num: %d" % len(test_triples))
    logging.info("#Model: %s" % config.model)

    # Build the model. TransE is constructed first as the fallback and is
    # replaced when config.model names one of the alternatives below —
    # exactly the behavior of the original if/elif chain.
    kge_model = TransE(ent_num, rel_num)
    constructors = {
        "TransH": TransH,
        "TransR": TransR,
        "TransD": TransD,
        "STransE": STransE,
        "LineaRE": LineaRE,
        "DistMult": DistMult,
        "ComplEx": ComplEx,
        "RotatE": RotatE,
    }
    if config.model in constructors:
        kge_model = constructors[config.model](ent_num, rel_num)

    if config.cuda:
        kge_model = kge_model.cuda()
    logging.info("Model Parameter Configuration:")
    for name, param in kge_model.named_parameters():
        logging.info("Parameter %s: %s, require_grad = %s" %
                     (name, str(param.size()), str(param.requires_grad)))

    # Train on the three splits.
    train(model=kge_model,
          triples=(train_triples, valid_triples, test_triples),
          ent_num=ent_num)
コード例 #3
0
def main(args, model_path):
    """Build batchers and a link-prediction model, then train or evaluate it.

    Parameters
    ----------
    args : namespace with the fields used below (``preprocess``, ``data``,
        ``batch_size``, ``test_batch_size``, ``loader_threads``, ``model``,
        ``resume``, ``log_interval``, ``lr``, ``l2``, ``label_smoothing``,
        ``epochs``).
    model_path : str
        Checkpoint path, used both for resuming and for per-epoch saves.
    """
    if args.preprocess:
        preprocess(args.data, delete_data=True)
    # Keys expected in every preprocessed batch.
    input_keys = ['e1', 'rel', 'rel_eval', 'e2', 'e2_multi1', 'e2_multi2']
    p = Pipeline(args.data, keys=input_keys)
    p.load_vocabs()
    vocab = p.state['vocab']

    num_entities = vocab['e1'].num_token

    train_batcher = StreamBatcher(args.data, 'train', args.batch_size, randomize=True, keys=input_keys, loader_threads=args.loader_threads)
    dev_rank_batcher = StreamBatcher(args.data, 'dev_ranking', args.test_batch_size, randomize=False, loader_threads=args.loader_threads, keys=input_keys)
    test_rank_batcher = StreamBatcher(args.data, 'test_ranking', args.test_batch_size, randomize=False, loader_threads=args.loader_threads, keys=input_keys)

    # ConvE is also the default when no model is named (the original had two
    # identical branches for None and 'conve').
    if args.model is None or args.model == 'conve':
        model = ConvE(args, vocab['e1'].num_token, vocab['rel'].num_token)
    elif args.model == 'distmult':
        model = DistMult(args, vocab['e1'].num_token, vocab['rel'].num_token)
    elif args.model == 'complex':
        model = Complex(args, vocab['e1'].num_token, vocab['rel'].num_token)
    elif args.model == 'interacte':
        model = InteractE(args, vocab['e1'].num_token, vocab['rel'].num_token)
    else:
        log.info('Unknown model: {0}', args.model)
        raise Exception("Unknown model!")

    # Expand the e2_multi1 target-index list into a binary multi-label vector
    # per example, so the loss can treat prediction as multi-label.
    train_batcher.at_batch_prepared_observers.insert(1, TargetIdx2MultiTarget(num_entities, 'e2_multi1', 'e2_multi1_binary'))

    eta = ETAHook('train', print_every_x_batches=args.log_interval)
    train_batcher.subscribe_to_events(eta)
    train_batcher.subscribe_to_start_of_epoch_event(eta)
    train_batcher.subscribe_to_events(LossHook('train', print_every_x_batches=args.log_interval))

    model.cuda()  # hard requirement: a CUDA device must be available
    if args.resume:
        # Resume: print the checkpoint's parameter inventory, restore it,
        # and run both evaluations immediately.
        model_params = torch.load(model_path)
        print(model)
        total_param_size = []
        params = [(key, value.size(), value.numel()) for key, value in model_params.items()]
        for key, size, count in params:
            total_param_size.append(count)
            print(key, size, count)
        print(np.array(total_param_size).sum())
        model.load_state_dict(model_params)
        model.eval()
        ranking_and_hits(model, test_rank_batcher, vocab, 'test_evaluation')
        ranking_and_hits(model, dev_rank_batcher, vocab, 'dev_evaluation')
    else:
        model.init()

    # Report per-tensor and total parameter counts of the live model.
    params = [value.numel() for value in model.parameters()]
    print(params)
    print(np.sum(params))

    opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.l2)
    for epoch in range(args.epochs):
        model.train()
        for i, str2var in enumerate(train_batcher):
            opt.zero_grad()
            e1 = str2var['e1']
            rel = str2var['rel']
            e2_multi = str2var['e2_multi1_binary'].float()
            # Label smoothing: soften the 0/1 targets toward the uniform
            # distribution over entities.
            e2_multi = ((1.0-args.label_smoothing)*e2_multi) + (1.0/e2_multi.size(1))

            pred = model(e1, rel)  # __call__ runs registered hooks, unlike bare .forward
            loss = model.loss(pred, e2_multi)
            loss.backward()
            opt.step()

            train_batcher.state.loss = loss.cpu()

        print('saving to {0}'.format(model_path))
        torch.save(model.state_dict(), model_path)

        model.eval()
        with torch.no_grad():
            # Evaluate every 5th epoch, skipping epoch 0. The original
            # duplicated this exact condition for dev and test separately.
            if epoch % 5 == 0 and epoch > 0:
                ranking_and_hits(model, dev_rank_batcher, vocab, 'dev_evaluation')
                ranking_and_hits(model, test_rank_batcher, vocab, 'test_evaluation')
コード例 #4
0
ファイル: main.py プロジェクト: chirag-choudhary/kb-logic
def main():
    """Train a ConvE/DistMult/ComplEx link predictor with MRR early stopping.

    Every 15 epochs the dev-set MRR is measured; after ``max_count``
    consecutive non-improving measurements training stops, and a final
    test-set evaluation always runs.

    NOTE(review): this function reads ``load``, ``model_path`` and ``epochs``
    which are not defined here — presumably module-level globals; confirm
    they exist wherever this was copied from.
    """
    if Config.process: preprocess(Config.dataset, delete_data=True)
    # Keys expected in every preprocessed batch.
    input_keys = ['e1', 'rel', 'rel_eval', 'e2', 'e2_multi1', 'e2_multi2']
    p = Pipeline(Config.dataset, keys=input_keys)
    p.load_vocabs()
    vocab = p.state['vocab']

    num_entities = vocab['e1'].num_token

    train_batcher = StreamBatcher(Config.dataset,
                                  'train',
                                  Config.batch_size,
                                  randomize=True,
                                  keys=input_keys)
    dev_rank_batcher = StreamBatcher(Config.dataset,
                                     'dev_ranking',
                                     Config.batch_size,
                                     randomize=False,
                                     loader_threads=4,
                                     keys=input_keys)
    test_rank_batcher = StreamBatcher(Config.dataset,
                                      'test_ranking',
                                      Config.batch_size,
                                      randomize=False,
                                      loader_threads=4,
                                      keys=input_keys)

    # ConvE doubles as the default when no model name is configured.
    if Config.model_name is None:
        model = ConvE(vocab['e1'].num_token, vocab['rel'].num_token)
    elif Config.model_name == 'ConvE':
        model = ConvE(vocab['e1'].num_token, vocab['rel'].num_token)
    elif Config.model_name == 'DistMult':
        model = DistMult(vocab['e1'].num_token, vocab['rel'].num_token)
    elif Config.model_name == 'ComplEx':
        model = Complex(vocab['e1'].num_token, vocab['rel'].num_token)
    else:
        log.info('Unknown model: {0}', Config.model_name)
        raise Exception("Unknown model!")

    # Expand the e2_multi1 target indices into a binary multi-label vector.
    train_batcher.at_batch_prepared_observers.insert(
        1, TargetIdx2MultiTarget(num_entities, 'e2_multi1',
                                 'e2_multi1_binary'))

    eta = ETAHook('train', print_every_x_batches=100)
    train_batcher.subscribe_to_events(eta)
    train_batcher.subscribe_to_start_of_epoch_event(eta)
    train_batcher.subscribe_to_events(
        LossHook('train', print_every_x_batches=100))

    if Config.cuda:
        model.cuda()
    if load:
        # Resume path: dump the checkpoint's parameter inventory, restore the
        # weights and evaluate on both splits before training continues.
        model_params = torch.load(model_path)
        print(model)
        total_param_size = []
        params = [(key, value.size(), value.numel())
                  for key, value in model_params.items()]
        for key, size, count in params:
            total_param_size.append(count)
            print(key, size, count)
        print(np.array(total_param_size).sum())
        model.load_state_dict(model_params)
        model.eval()
        ranking_and_hits(model, test_rank_batcher, vocab, 'test_evaluation')
        ranking_and_hits(model, dev_rank_batcher, vocab, 'dev_evaluation')
    else:
        model.init()

    # Per-tensor and total parameter counts of the live model.
    total_param_size = []
    params = [value.numel() for value in model.parameters()]
    print(params)
    print(np.sum(params))

    # Early-stopping state: stop after max_count consecutive dev evaluations
    # with no MRR improvement.
    max_mrr = 0
    count = 0
    max_count = 3
    opt = torch.optim.Adam(model.parameters(),
                           lr=Config.learning_rate,
                           weight_decay=Config.L2)
    for epoch in range(1, epochs + 1):
        model.train()
        for i, str2var in enumerate(train_batcher):
            opt.zero_grad()
            e1 = str2var['e1']
            rel = str2var['rel']
            e2_multi = str2var['e2_multi1_binary'].float()
            # label smoothing
            e2_multi = ((1.0 - Config.label_smoothing_epsilon) *
                        e2_multi) + (1.0 / e2_multi.size(1))

            pred = model.forward(e1, rel)
            loss = model.loss(pred, e2_multi)
            loss.backward()
            opt.step()

            train_batcher.state.loss = loss.cpu()

        # Checkpoint once per epoch.
        print('saving to {0}'.format(model_path))
        torch.save(model.state_dict(), model_path)

        model.eval()
        with torch.no_grad():
            # ranking_and_hits(model, dev_rank_batcher, vocab, 'dev_evaluation')
            # Dev evaluation every 15 epochs drives the early stopping.
            if epoch % 15 == 0:
                mrr = ranking_and_hits(model, dev_rank_batcher, vocab,
                                       'dev_evaluation')
                if mrr <= max_mrr:
                    count += 1
                    if count > max_count:
                        break
                else:
                    count = 0
                    max_mrr = mrr
    # Final test-set evaluation regardless of how the loop ended.
    mrr_test = ranking_and_hits(model, test_rank_batcher, vocab,
                                'test_evaluation')
コード例 #5
0
ファイル: main.py プロジェクト: yangji9181/HNE
def main(args):
    """Either preprocess the dataset or train a DistMult model on it.

    With ``args.preprocess`` set only the preprocessing step runs; otherwise
    vocabularies are loaded, the model trains for ``args.epochs`` epochs and
    the learned entity embeddings are written out via ``output``.
    """

    def stamp(message):
        # Timestamped, flushed progress line (same format the original used).
        print(time.strftime("%a, %d %b %Y %H:%M:%S +0000", time.localtime()) +
              message,
              flush=True)

    if args.preprocess:
        print('start preprocessing', flush=True)
        preprocess(args, delete_data=True)
        print('finish preprocessing', flush=True)
        return

    input_keys = ['e1', 'rel', 'rel_eval', 'e2', 'e2_multi1', 'e2_multi2']
    pipeline = Pipeline(args.data, keys=input_keys)
    stamp(': start loading vocabs')
    pipeline.load_vocabs()
    stamp(': finish loading vocabs')
    vocab = pipeline.state['vocab']
    num_entities = vocab['e1'].num_token

    train_batcher = StreamBatcher(args.data,
                                  'train',
                                  args.batch_size,
                                  randomize=True,
                                  keys=input_keys,
                                  loader_threads=args.loader_threads)
    model = DistMult(args, vocab['e1'].num_token, vocab['rel'].num_token)
    # Expand the e2_multi1 target indices into a binary multi-label vector.
    train_batcher.at_batch_prepared_observers.insert(
        1,
        TargetIdx2MultiTarget(num_entities, 'e2_multi1', 'e2_multi1_binary'))

    model.cuda()
    model.init()

    # Per-tensor and total parameter counts.
    param_counts = [value.numel() for value in model.parameters()]
    print(param_counts, flush=True)
    print(np.sum(param_counts), flush=True)

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.l2)
    stamp(f': start training with epochs = {args.epochs}')
    for epoch in range(args.epochs):
        model.train()
        for str2var in train_batcher:
            optimizer.zero_grad()
            e1 = str2var['e1']
            rel = str2var['rel']
            targets = str2var['e2_multi1_binary'].float()
            # Label smoothing of the binary multi-label targets.
            targets = ((1.0 - args.label_smoothing) *
                       targets) + (1.0 / targets.size(1))

            pred = model.forward(e1, rel)
            loss = model.loss(pred, targets)
            loss.backward()
            optimizer.step()

        stamp(f': finish training epoch {epoch}')

    model.eval()
    # Persist the learned entity embedding matrix.
    output(args, vocab['e1'], model.emb_e.weight.detach().cpu().numpy())
コード例 #6
0
def main():
    """Train a link predictor, or (in the ``load`` path) compute per-day MRR
    and hits@{1,3,10} over the temporal test set.

    The ``lengths`` tables give the number of test triples per day for the
    two supported datasets; each day contributes ``2 * lengths[i]`` ranks
    (presumably head and tail ranking — TODO confirm).

    NOTE(review): this function reads ``load``, ``model_path`` and ``epochs``
    which are not defined here — presumably module-level globals; confirm.
    """
    if Config.process: preprocess(Config.dataset, delete_data=True)
    # Keys expected in every preprocessed batch.
    input_keys = ['e1', 'rel', 'rel_eval', 'e2', 'e2_multi1', 'e2_multi2']
    p = Pipeline(Config.dataset, keys=input_keys)
    p.load_vocabs()
    vocab = p.state['vocab']

    num_entities = vocab['e1'].num_token

    train_batcher = StreamBatcher(Config.dataset,
                                  'train',
                                  Config.batch_size,
                                  randomize=True,
                                  keys=input_keys)
    dev_rank_batcher = StreamBatcher(Config.dataset,
                                     'dev_ranking',
                                     Config.batch_size,
                                     randomize=False,
                                     loader_threads=4,
                                     keys=input_keys)
    test_rank_batcher = StreamBatcher(Config.dataset,
                                      'test_ranking',
                                      Config.batch_size,
                                      randomize=False,
                                      loader_threads=4,
                                      keys=input_keys)

    # ConvE doubles as the default when no model name is configured.
    if Config.model_name is None:
        model = ConvE(vocab['e1'].num_token, vocab['rel'].num_token)
    elif Config.model_name == 'ConvE':
        model = ConvE(vocab['e1'].num_token, vocab['rel'].num_token)
    elif Config.model_name == 'DistMult':
        model = DistMult(vocab['e1'].num_token, vocab['rel'].num_token)
    elif Config.model_name == 'ComplEx':
        model = Complex(vocab['e1'].num_token, vocab['rel'].num_token)
    elif Config.model_name == 'RNNDist':
        model = RNNDist(vocab['e1'].num_token, vocab['rel'].num_token)
    else:
        log.info('Unknown model: {0}', Config.model_name)
        raise Exception("Unknown model!")

    # Expand the e2_multi1 target indices into a binary multi-label vector.
    train_batcher.at_batch_prepared_observers.insert(
        1, TargetIdx2MultiTarget(num_entities, 'e2_multi1',
                                 'e2_multi1_binary'))

    eta = ETAHook('train', print_every_x_batches=100)
    train_batcher.subscribe_to_events(eta)
    train_batcher.subscribe_to_start_of_epoch_event(eta)
    train_batcher.subscribe_to_events(
        LossHook('train', print_every_x_batches=100))
    # Hard-coded per-day test-triple counts for the two supported datasets.
    if Config.dataset == 'ICEWS18':
        lengths = [
            1618, 956, 815, 1461, 1634, 1596, 1754, 1494, 800, 979, 1588, 1779,
            1831, 1762, 1566, 812, 820, 1707, 1988, 1845, 1670, 1695, 956, 930,
            1641, 1813, 1759, 1664, 1616, 1021, 998, 1668, 1589, 1720
        ]
    else:
        lengths = [
            1090, 730, 646, 939, 681, 783, 546, 526, 524, 586, 656, 741, 562,
            474, 493, 487, 474, 477, 460, 532, 348, 530, 402, 493, 503, 452,
            668, 512, 406, 467, 524, 563, 524, 418, 441, 487, 515, 475, 478,
            532, 387, 479, 485, 417, 542, 496, 487, 445, 504, 350, 432, 445,
            401, 570, 554, 504, 505, 483, 587, 441, 489, 501, 487, 513, 513,
            524, 655, 545, 599, 702, 734, 519, 603, 579, 537, 635, 437, 422,
            695, 575, 553, 485, 429, 663, 475, 673, 527, 559, 540, 591, 558,
            698, 422, 1145, 969, 1074, 888, 683, 677, 910, 902, 644, 777, 695,
            571, 656, 797, 576, 468, 676, 687, 549, 482, 1007, 778, 567, 813,
            788, 879, 557, 724, 850, 809, 685, 714, 554, 799, 727, 208, 946,
            979, 892, 859, 1092, 1038, 999, 1477, 1126, 1096, 1145, 955, 100,
            1264, 1287, 962, 1031, 1603, 1662, 1179, 1064, 1179, 1105, 1465,
            1176, 1219, 1137, 1112, 791, 829, 2347, 917, 913, 1107, 960, 850,
            1005, 1045, 871, 972, 921, 1019, 984, 1033, 848, 918, 699, 1627,
            1580, 1354, 1119, 1065, 1208, 1037, 1134, 980, 1249, 1031, 908,
            787, 819, 804, 764, 959, 1057, 770, 691, 816, 620, 788, 829, 895,
            1128, 1023, 1038, 1030, 1016, 991, 866, 878, 1013, 977, 914, 976,
            717, 740, 904, 912, 1043, 1117, 930, 1116, 1028, 946, 922, 1151,
            1092, 967, 1189, 1081, 1158, 943, 981, 1212, 1104, 941, 912, 1347,
            1241, 1479, 1188, 1152, 1164, 1167, 1173, 1280, 979, 142, 1458,
            910, 1126, 1053, 1083, 897, 1021, 1075, 881, 1054, 941, 927, 860,
            1081, 876, 1952, 1576, 1560, 1599, 1226, 1083, 964, 1059, 1179,
            982, 1032, 933, 877, 1032, 957, 884, 909, 846, 850, 798, 843, 1183,
            1108, 1185, 797, 915, 952, 1181, 744, 86, 889, 1151, 925, 1119,
            1115, 1036, 772, 1052, 837, 897, 1095, 926, 1034, 1031, 995, 907,
            969, 981, 1135, 915, 1161, 100, 1269, 1244, 1331, 1124, 1074, 1162,
            1159, 1078, 1311, 1210, 1308, 945, 1183, 1580, 1406, 1417, 1173,
            1348, 1274, 1179, 893, 1107, 950, 1028, 1055, 1059, 1244, 1082,
            1179, 1011, 955, 886, 865, 857
        ]
    if Config.cuda:
        model.cuda()
    if load:
        # if True:
        model_params = torch.load(model_path)
        print(model)
        total_param_size = []
        params = [(key, value.size(), value.numel())
                  for key, value in model_params.items()]
        for key, size, count in params:
            total_param_size.append(count)
            print(key, size, count)
        print(np.array(total_param_size).sum())
        model.load_state_dict(model_params)
        model.eval()
        ranking_and_hits(model, test_rank_batcher, vocab, 'test_evaluation')
        # ranks = ranking_and_hits2(model, test_rank_batcher, vocab, 'test_evaluation')
        # NOTE(review): the line assigning `ranks` is commented out above, so
        # everything below raises NameError in this path — restore it or
        # remove this branch.
        print(len(ranks))

        # Per-day MRR: each day i consumes 2 * lengths[i] consecutive ranks.
        mrr = []
        curr_step = 0
        for i in range(len(lengths)):
            rr = np.array(ranks[curr_step:curr_step + 2 * lengths[i]])
            mrr.append(np.mean(1 / rr))

            curr_step += 2 * lengths[i]
        with open(Config.dataset + 'mrr.txt', 'w') as f:
            for i, mr in enumerate(mrr):
                print("MRR (filtered) @ {}th day: {:.6f}".format(i, mr))
                f.write(str(mr) + '\n')
        # NOTE(review): unlike the MRR loop, none of the three hits@k loops
        # below advance curr_step inside the loop, so every iteration slices
        # the same window — almost certainly a bug; confirm and fix upstream.
        h10 = []
        curr_step = 0
        for i in range(len(lengths)):
            rr = np.array(ranks[curr_step:curr_step + 2 * lengths[i]])
            h10.append(np.mean(rr <= 10))
        with open(Config.dataset + 'h10.txt', 'w') as f:
            for i, mr in enumerate(h10):
                print("h10 (filtered) @ {}th day: {:.6f}".format(i, mr))
                f.write(str(mr) + '\n')
        # hits@3 (variable reused; note the file says h3 but prints "h10").
        h10 = []
        for i in range(len(lengths)):
            rr = np.array(ranks[curr_step:curr_step + 2 * lengths[i]])
            h10.append(np.mean(rr <= 3))
        with open(Config.dataset + 'h3.txt', 'w') as f:
            for i, mr in enumerate(h10):
                print("h10 (filtered) @ {}th day: {:.6f}".format(i, mr))
                f.write(str(mr) + '\n')

        # hits@1 (same reused variable and stale curr_step as above).
        h10 = []

        for i in range(len(lengths)):
            rr = np.array(ranks[curr_step:curr_step + 2 * lengths[i]])
            h10.append(np.mean(rr <= 1))
        with open(Config.dataset + 'h1.txt', 'w') as f:
            for i, mr in enumerate(h10):
                print("h10 (filtered) @ {}th day: {:.6f}".format(i, mr))
                f.write(str(mr) + '\n')
        print("length", len(ranks))
        print("length_2", 2 * sum(lengths))

        # ranking_and_hits(model, dev_rank_batcher, vocab, 'dev_evaluation')
    else:
        model.init()

    # Per-tensor and total parameter counts of the live model.
    total_param_size = []
    params = [value.numel() for value in model.parameters()]
    print(params)
    print(np.sum(params))

    opt = torch.optim.Adam(model.parameters(),
                           lr=Config.learning_rate,
                           weight_decay=Config.L2)
    for epoch in range(epochs):
        # break
        model.train()
        for i, str2var in enumerate(train_batcher):
            opt.zero_grad()
            e1 = str2var['e1']
            rel = str2var['rel']
            e2_multi = str2var['e2_multi1_binary'].float()

            # label smoothing
            # e2_multi = ((1.0-Config.label_smoothing_epsilon)*e2_multi) + (1.0/e2_multi.size(1))
            # print("this",Config.label_smoothing_epsilon, e2_multi.size(1))

            pred = model.forward(e1, rel)
            # loss = model.loss(pred, e2_multi)
            # #
            # Sampled loss: per example, one positive target plus an equal
            # number of random negatives.
            # NOTE(review): range(128) hard-codes the batch size — fails if
            # the final batch is smaller; and nonzero(...)[0] keeps only the
            # FIRST positive per row — confirm both are intended.
            loss = torch.zeros(1).cuda()
            for j in range(128):
                position = torch.nonzero(e2_multi[j])[0].cuda()
                label = torch.cat(
                    [torch.ones(len(position)),
                     torch.zeros(len(position))]).cuda()
                neg_position = torch.randint(e2_multi.shape[1],
                                             (len(position), )).long().cuda()
                position = torch.cat([position, neg_position])
                loss += model.loss(pred[j, position], label)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           1.0)  # clip gradients
            opt.step()

            train_batcher.state.loss = loss.cpu()

        # Checkpoint once per epoch.
        print('saving to {0}'.format(model_path))
        torch.save(model.state_dict(), model_path)

        model.eval()
        with torch.no_grad():
            # ranking_and_hits(model, dev_rank_batcher, vocab, 'dev_evaluation')
            # NOTE(review): `ranks` is only assigned at epoch 50 and never
            # used afterwards in this path — confirm intent.
            if epoch == 50:
                ranks = ranking_and_hits(model, test_rank_batcher, vocab,
                                         'test_evaluation')
コード例 #7
0
ファイル: main.py プロジェクト: zzw-x/CPL
def main():
    """Train a ConvE/DistMult/ComplEx model with a sampled per-row loss and
    periodic test-set evaluation.

    NOTE(review): reads ``load``, ``model_path`` and ``epochs`` which are not
    defined in this function — presumably module-level globals; confirm.
    """
    if Config.process:
        preprocess(Config.dataset, delete_data=True)
    # Keys expected in every preprocessed batch.
    input_keys = ['e1', 'rel', 'rel_eval', 'e2', 'e2_multi1', 'e2_multi2']
    p = Pipeline(Config.dataset, keys=input_keys)
    p.load_vocabs()
    vocab = p.state['vocab']

    num_entities = vocab['e1'].num_token

    train_batcher = StreamBatcher(Config.dataset, 'train', Config.batch_size, randomize=True, keys=input_keys)
    dev_rank_batcher = StreamBatcher(Config.dataset, 'dev_ranking', Config.batch_size, randomize=False, loader_threads=4, keys=input_keys)
    test_rank_batcher = StreamBatcher(Config.dataset, 'test_ranking', Config.batch_size, randomize=False, loader_threads=4, keys=input_keys)

    # ConvE is also the default when no model is named (the original had two
    # identical branches for None and 'ConvE').
    if Config.model_name is None or Config.model_name == 'ConvE':
        model = ConvE(vocab['e1'].num_token, vocab['rel'].num_token)
    elif Config.model_name == 'DistMult':
        model = DistMult(vocab['e1'].num_token, vocab['rel'].num_token)
    elif Config.model_name == 'ComplEx':
        model = Complex(vocab['e1'].num_token, vocab['rel'].num_token)
    else:
        log.info('Unknown model: {0}', Config.model_name)
        raise Exception("Unknown model!")

    # Expand the e2_multi1 target indices into a binary multi-label vector.
    train_batcher.at_batch_prepared_observers.insert(1, TargetIdx2MultiTarget(num_entities, 'e2_multi1', 'e2_multi1_binary'))

    eta = ETAHook('train', print_every_x_batches=100)
    train_batcher.subscribe_to_events(eta)
    train_batcher.subscribe_to_start_of_epoch_event(eta)
    train_batcher.subscribe_to_events(LossHook('train', print_every_x_batches=100))

    if Config.cuda:
        model.cuda()
    if load:
        # Resume path: dump the checkpoint's parameter inventory, restore the
        # weights and evaluate on both splits before training continues.
        model_params = torch.load(model_path)
        print(model)
        total_param_size = []
        params = [(key, value.size(), value.numel()) for key, value in model_params.items()]
        for key, size, count in params:
            total_param_size.append(count)
            print(key, size, count)
        print(np.array(total_param_size).sum())
        model.load_state_dict(model_params)
        model.eval()
        ranking_and_hits(model, test_rank_batcher, vocab, 'test_evaluation',epochs,True)
        ranking_and_hits(model, dev_rank_batcher, vocab, 'dev_evaluation',epochs,False)
    else:
        model.init()

    # Per-tensor and total parameter counts of the live model.
    params = [value.numel() for value in model.parameters()]
    print(params)
    print(np.sum(params))

    opt = torch.optim.Adam(model.parameters(), lr=Config.learning_rate, weight_decay=Config.L2)
    for epoch in range(epochs):
        model.train()
        for i, str2var in tqdm(enumerate(train_batcher)):
            opt.zero_grad()
            e1 = str2var['e1']
            rel = str2var['rel']
            e2_multi = str2var['e2_multi1_binary'].float()
            pred = model.forward(e1, rel)
            # Sampled loss: per example one positive target plus an equal
            # number of random negatives.
            # FIX: iterate over the actual batch size instead of a hard-coded
            # 128 — the original raised IndexError on a smaller final batch.
            loss = torch.zeros(1).cuda()
            for j in range(e2_multi.size(0)):
                # NOTE(review): [0] keeps only the FIRST positive target per
                # row — confirm this is intended rather than all positives.
                position = torch.nonzero(e2_multi[j])[0].cuda()
                label = torch.cat([torch.ones(len(position)), torch.zeros(len(position))]).cuda()
                neg_position = torch.randint(e2_multi.shape[1], (len(position),)).long().cuda()
                position = torch.cat([position, neg_position])
                loss += model.loss(pred[j, position], label)
            loss.backward()
            opt.step()

            train_batcher.state.loss = loss.cpu()

        # Checkpoint once per epoch.
        print('saving to {0}'.format(model_path))
        torch.save(model.state_dict(), model_path)

        model.eval()
        with torch.no_grad():
            # Periodic test evaluation every 100 epochs (skipping epoch 0),
            # plus a final evaluation on the last epoch.
            if epoch % 100 == 0:
                if epoch > 0:
                    ranking_and_hits(model, test_rank_batcher, vocab, Config.dataset + "-" + Config.model_name,epoch,False)
            if epoch + 1 == epochs:
                ranking_and_hits(model, test_rank_batcher, vocab, Config.dataset,epoch,True)
コード例 #8
0
ファイル: train.py プロジェクト: kmswin1/TransE
def main():
    """Train TransE/DistMult on 'data/modified_triples.txt' with a margin
    ranking loss, checkpointing every ``save_step`` epochs and exporting the
    final embeddings."""
    opts = get_train_args()
    print("load data ...")
    data = DataSet('data/modified_triples.txt')
    dataloader = DataLoader(data, shuffle=True, batch_size=opts.batch_size)
    print("load model ...")
    if opts.model_type == 'transe':
        model = TransE(opts, data.ent_tot, data.rel_tot)
    elif opts.model_type == "distmult":
        model = DistMult(opts, data.ent_tot, data.rel_tot)
    else:
        # FIX: an unrecognised model_type previously left `model` unbound and
        # crashed later with a confusing NameError.
        raise ValueError("unknown model_type: {}".format(opts.model_type))
    if opts.optimizer == 'Adam':
        optimizer = optim.Adam(model.parameters(), lr=opts.lr)
    elif opts.optimizer == 'SGD':
        optimizer = optim.SGD(model.parameters(), lr=opts.lr)
    else:
        # FIX: same unbound-name hazard for `optimizer`.
        raise ValueError("unknown optimizer: {}".format(opts.optimizer))
    model.cuda()
    model.relation_normalize()
    loss = torch.nn.MarginRankingLoss(margin=opts.margin)

    def save_checkpoint():
        # Normalise entity embeddings before persisting, as the original did
        # at both save sites.
        print("save model...")
        model.entity_normalize()
        torch.save(model.state_dict(), 'model.pt')

    print("start training")
    for epoch in range(1, opts.epochs + 1):
        print("epoch : " + str(epoch))
        model.train()
        epoch_start = time.time()
        epoch_loss = 0
        tot = 0  # examples seen this epoch, for the progress line
        cnt = 0  # batches seen this epoch, for the mean epoch loss
        for batch_data in dataloader:
            optimizer.zero_grad()
            batch_h, batch_r, batch_t, batch_n = batch_data
            batch_h = torch.LongTensor(batch_h).cuda()
            batch_r = torch.LongTensor(batch_r).cuda()
            batch_t = torch.LongTensor(batch_t).cuda()
            batch_n = torch.LongTensor(batch_n).cuda()
            pos_score, neg_score, dist = model.forward(batch_h, batch_r,
                                                       batch_t, batch_n)
            pos_score = pos_score.cpu()
            neg_score = neg_score.cpu()
            dist = dist.cpu()
            # Target of +1: positive triples should score above negatives.
            train_loss = loss(pos_score, neg_score,
                              torch.ones(pos_score.size(-1))) + dist
            train_loss.backward()
            optimizer.step()
            batch_loss = torch.sum(train_loss)
            epoch_loss += batch_loss
            tot += batch_h.size(0)
            cnt += 1
            print('\r{:>10} epoch {} progress {} loss: {}\n'.format(
                '', epoch, tot / len(data), train_loss),
                  end='')
        time_used = time.time() - epoch_start
        epoch_loss /= cnt
        print('one epoch time: {} minutes'.format(time_used / 60))
        print('{} epochs'.format(epoch))
        print('epoch {} loss: {}'.format(epoch, epoch_loss))

        # Periodic checkpoint.
        if epoch % opts.save_step == 0:
            save_checkpoint()

    # Final checkpoint and embedding export.
    save_checkpoint()
    print("[Saving embeddings of whole entities & relations...]")
    save_embeddings(model, opts, data.id2ent, data.id2rel)
    print("[Embedding results are saved successfully.]")