Example No. 1
def train_data_iterator(entities, triples):
    entities_1, entities_2 = entities
    triples_1, triples_2 = triples
    loader_head_1 = DataLoader(TrainDataset(triples_1, entities_1,
                                            config.neg_size, "head-batch"),
                               batch_size=config.batch_size,
                               shuffle=True,
                               num_workers=max(0, config.cpu_num // 3),
                               collate_fn=TrainDataset.collate_fn)
    loader_tail_1 = DataLoader(TrainDataset(triples_1, entities_1,
                                            config.neg_size, "tail-batch"),
                               batch_size=config.batch_size,
                               shuffle=True,
                               num_workers=max(0, config.cpu_num // 3),
                               collate_fn=TrainDataset.collate_fn)
    loader_head_2 = DataLoader(TrainDataset(triples_2, entities_2,
                                            config.neg_size, "head-batch"),
                               batch_size=config.batch_size,
                               shuffle=True,
                               num_workers=max(0, config.cpu_num // 3),
                               collate_fn=TrainDataset.collate_fn)
    loader_tail_2 = DataLoader(TrainDataset(triples_2, entities_2,
                                            config.neg_size, "tail-batch"),
                               batch_size=config.batch_size,
                               shuffle=True,
                               num_workers=max(0, config.cpu_num // 3),
                               collate_fn=TrainDataset.collate_fn)
    return BidirectionalOneShotIterator(loader_head_1, loader_tail_1,
                                        loader_head_2, loader_tail_2)
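
For context, the BidirectionalOneShotIterator used throughout these examples alternates between the head-batch and tail-batch loaders, re-cycling each one indefinitely. Below is a minimal sketch of the two-argument form used from Example No. 2 onwards (the four-loader call above is a variant of the same idea); the implementation details are an assumption based on the common RotatE-style codebase, not taken from the snippet above.

class BidirectionalOneShotIterator(object):
    """Sketch: draw batches alternately from a head-batch and a tail-batch DataLoader."""

    def __init__(self, dataloader_head, dataloader_tail):
        self.iterator_head = self.one_shot_iterator(dataloader_head)
        self.iterator_tail = self.one_shot_iterator(dataloader_tail)
        self.step = 0

    def __next__(self):
        # Even steps draw from the head iterator, odd steps from the tail iterator.
        self.step += 1
        if self.step % 2 == 0:
            return next(self.iterator_head)
        return next(self.iterator_tail)

    @staticmethod
    def one_shot_iterator(dataloader):
        # Wrap a finite DataLoader into an endless generator.
        while True:
            for data in dataloader:
                yield data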
Example No. 2
    def train_NoiGAN(trainer):
        trainer.embedding_model.eval()

        st = time.time()
        trainer.positive_triples = trainer.find_positive_triples()
        et = time.time()
        print("take %d s to find positive triples" % (et - st))

        trainer.train_dataset_head = TrainDataset(
            trainer.train_triples, trainer.args.nentity,
            trainer.args.nrelation, trainer.args.negative_sample_size,
            "head-batch")
        trainer.train_dataset_head.triples = trainer.positive_triples
        trainer.train_dataset_tail = TrainDataset(
            trainer.train_triples, trainer.args.nentity,
            trainer.args.nrelation, trainer.args.negative_sample_size,
            "tail-batch")
        trainer.train_dataset_tail.triples = trainer.positive_triples
        trainer.train_dataloader_head = DataLoader(
            trainer.train_dataset_head,
            batch_size=128,
            shuffle=True,
            num_workers=5,
            collate_fn=TrainDataset.collate_fn)
        trainer.train_dataloader_tail = DataLoader(
            trainer.train_dataset_tail,
            batch_size=128,
            shuffle=True,
            num_workers=5,
            collate_fn=TrainDataset.collate_fn)
        trainer.train_iterator = BidirectionalOneShotIterator(
            trainer.train_dataloader_head, trainer.train_dataloader_tail)
        epochs = 1500
        epoch_reward, epoch_loss, avg_reward = 0, 0, 0
        for i in range(epochs):
            trainer.generator.train()
            positive_sample, negative_sample, subsampling_weight, mode = next(
                trainer.train_iterator)
            if trainer.args.cuda:
                positive_sample = positive_sample.cuda()  # [batch_size, 3]
                negative_sample = negative_sample.cuda()  # [batch_size, negative_sample_size]
            #$ embed()
            pos, neg, scores, sample_idx, row_idx = trainer.generate(
                positive_sample, negative_sample, mode)
            loss, rewards = trainer.discriminate(pos, neg, mode)
            epoch_reward += torch.sum(rewards)
            epoch_loss += loss
            rewards = rewards - avg_reward

            trainer.generator.zero_grad()
            log_probs = F.log_softmax(scores, dim=1)
            reinforce_loss = torch.sum(
                Variable(rewards) * log_probs[row_idx.cuda(), sample_idx.data])
            reinforce_loss.backward()
            trainer.gen_optimizer.step()
            trainer.generator.eval()
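
The generator update at the end of Example No. 2 is a REINFORCE-style step: the discriminator's rewards, minus a running baseline, weight the log-probabilities of the negative samples the generator drew. The following self-contained toy sketch mirrors that update with made-up shapes and a stand-in linear generator; none of these names come from the snippet above.

import torch
import torch.nn.functional as F

batch_size, num_candidates = 4, 8
generator = torch.nn.Linear(16, num_candidates)   # stand-in for the real generator
gen_optimizer = torch.optim.Adam(generator.parameters(), lr=1e-3)

features = torch.randn(batch_size, 16)
scores = generator(features)                      # [batch_size, num_candidates]

# Sample one candidate per row from the generator's distribution.
sample_idx = torch.multinomial(F.softmax(scores, dim=1), num_samples=1).squeeze(1)
row_idx = torch.arange(batch_size)

# Rewards would come from the discriminator; random placeholders here.
rewards = torch.randn(batch_size)
avg_reward = rewards.mean()                       # baseline reduces variance

gen_optimizer.zero_grad()
log_probs = F.log_softmax(scores, dim=1)
# Reward-weighted log-probability loss, mirroring the update in the example above.
reinforce_loss = torch.sum((rewards - avg_reward).detach()
                           * log_probs[row_idx, sample_idx])
reinforce_loss.backward()
gen_optimizer.step()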
Example No. 3
def train_data_iterator(train_triples, ent_num):
    modes = ["head-batch", "tail-batch"]
    datasets = [
        DataLoader(TrainDataset(train_triples, ent_num, config.neg_size, mode),
                   batch_size=config.batch_size,
                   shuffle=True,
                   num_workers=4,
                   collate_fn=TrainDataset.collate_fn) for mode in modes
    ]
    return BidirectionalOneShotIterator(datasets[0], datasets[1])
Example No. 4
def train_data_iterator(train_triples, ent_num):
    dataloader_head = DataLoader(TrainDataset(train_triples, ent_num,
                                              config.neg_size, "head-batch"),
                                 batch_size=config.batch_size,
                                 shuffle=True,
                                 num_workers=max(0, config.cpu_num // 3),
                                 collate_fn=TrainDataset.collate_fn)
    dataloader_tail = DataLoader(TrainDataset(train_triples, ent_num,
                                              config.neg_size, "tail-batch"),
                                 batch_size=config.batch_size,
                                 shuffle=True,
                                 num_workers=max(0, config.cpu_num // 3),
                                 collate_fn=TrainDataset.collate_fn)
    return BidirectionalOneShotIterator(dataloader_head, dataloader_tail)
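
Hypothetical usage sketch: the iterator returned by train_data_iterator above is consumed one batch at a time inside a training loop, yielding the same 4-tuple that Example No. 2 unpacks (positive samples, negative samples, subsampling weights, and the batch mode). The driver function and model.compute_loss below are placeholders, not APIs from these snippets.

def run_training(model, train_iterator, optimizer, max_steps):
    for step in range(max_steps):
        positive_sample, negative_sample, subsampling_weight, mode = next(train_iterator)
        # mode alternates between 'head-batch' and 'tail-batch';
        # positive_sample: [batch_size, 3], negative_sample: [batch_size, neg_size].
        optimizer.zero_grad()
        loss = model.compute_loss(positive_sample, negative_sample,
                                  subsampling_weight, mode)
        loss.backward()
        optimizer.step()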
Example No. 5
def init_dataset(self):
    train_dataloader_head = DataLoader(TrainDataset(
        self.train_triples, self.entity_count, self.attr_count,
        self.value_count, 512, 'head-batch'),
                                       batch_size=1024,
                                       shuffle=False,
                                       num_workers=4,
                                       collate_fn=TrainDataset.collate_fn)
    train_dataloader_tail = DataLoader(TrainDataset(
        self.train_triples, self.entity_count, self.attr_count,
        self.value_count, 512, 'tail-batch'),
                                       batch_size=1024,
                                       shuffle=False,
                                       num_workers=4,
                                       collate_fn=TrainDataset.collate_fn)
    self.train_iterator = BidirectionalOneShotIterator(
        train_dataloader_head, train_dataloader_tail)
Example No. 6
def construct_dataloader(args, train_triples, nentity, nrelation):
    train_dataloader_head = DataLoader(TrainDataset(train_triples, nentity,
                                                    nrelation,
                                                    args.negative_sample_size,
                                                    'head-batch'),
                                       batch_size=args.batch_size,
                                       shuffle=True,
                                       num_workers=max(1, args.cpu_num // 2),
                                       collate_fn=TrainDataset.collate_fn)

    train_dataloader_tail = DataLoader(TrainDataset(train_triples, nentity,
                                                    nrelation,
                                                    args.negative_sample_size,
                                                    'tail-batch'),
                                       batch_size=args.batch_size,
                                       shuffle=True,
                                       num_workers=max(1, args.cpu_num // 2),
                                       collate_fn=TrainDataset.collate_fn)

    train_iterator = BidirectionalOneShotIterator(train_dataloader_head,
                                                  train_dataloader_tail)

    return train_iterator
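
A hypothetical call site for construct_dataloader, assuming the function above and a TrainDataset implementation are importable; the triple ids, counts, and argument values are illustrative only.

from argparse import Namespace

args = Namespace(negative_sample_size=4, batch_size=2, cpu_num=2)
train_triples = [(0, 0, 1), (1, 1, 2), (2, 0, 3)]   # (head, relation, tail) id triples
nentity, nrelation = 4, 2
train_iterator = construct_dataloader(args, train_triples, nentity, nrelation)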
Example No. 7
def main(args):
    # if (not args.do_train) and (not args.do_valid) and (not args.do_test) and (not args.do_case) and (not args.fire_test) and (not args.rel_do_test) :
    #     raise ValueError('one of train/val/test mode must be choosed.')

    if args.init_checkpoint:
        override_config(args)
    elif args.data_path is None:
        raise ValueError('one of init_checkpoint/data_path must be chosen.')

    if args.do_train and args.save_path is None:
        raise ValueError('Where do you want to save your trained model?')

    if args.save_path and not os.path.exists(args.save_path):
        os.makedirs(args.save_path)

    # Write logs to checkpoint and console
    set_logger(args)

    with open(os.path.join(args.data_path, 'entities.dict')) as fin:
        entity2id = dict()
        for line in fin:
            eid, entity = line.strip().split('\t')
            entity2id[entity] = int(eid)

    with open(os.path.join(args.data_path, 'relations.dict')) as fin:
        relation2id = dict()
        for line in fin:
            rid, relation = line.strip().split('\t')
            relation2id[relation] = int(rid)

    # Read regions for Countries S* datasets
    if args.countries:
        regions = list()
        with open(os.path.join(args.data_path, 'regions.list')) as fin:
            for line in fin:
                region = line.strip()
                regions.append(entity2id[region])
        args.regions = regions

    nentity = len(entity2id)
    nrelation = len(relation2id)

    args.nentity = nentity
    args.nrelation = nrelation

    logging.info('Model: %s' % args.model)
    logging.info('Data Path: %s' % args.data_path)
    logging.info('#entity: %d' % nentity)
    logging.info('#relation: %d' % nrelation)

    train_triples = read_triple(os.path.join(args.data_path, 'train.txt'),
                                entity2id, relation2id)
    # train_triples = read_triple(os.path.join(args.data_path, 'train_1900.txt'), entity2id, relation2id)
    logging.info('#train: %d' % len(train_triples))
    valid_triples = read_triple(os.path.join(args.data_path, 'valid.txt'),
                                entity2id, relation2id)
    logging.info('#valid: %d' % len(valid_triples))
    test_triples = read_triple(os.path.join(args.data_path, 'test.txt'),
                               entity2id, relation2id)
    # seen_test_triples = read_triple(os.path.join(args.data_path, 'seen_test.txt'), entity2id, relation2id)
    # test_triples = read_triple(os.path.join(args.data_path, 'test_alone_triples.txt'), entity2id, relation2id)

    # def file_name(file_dir):
    #     for root, dirs, files in os.walk(file_dir):
    #         return files
    # rel_dataset = file_name("/scratch/mengyali/workspace/rotate/data/wn18rr/rel_dataset_txt/")
    # for rel in rel_dataset:
    #     test_triples = read_triple(os.path.join(args.data_path, "rel_dataset_txt/"+str(rel)), entity2id, relation2id)
    #     logging.info('#test: %d' % len(test_triples))

    #All true triples
    all_true_triples = train_triples + valid_triples + test_triples

    kge_model = KGEModel(
        model_name=args.model,
        nentity=nentity,
        nrelation=nrelation,
        hidden_dim=args.hidden_dim,
        gamma=args.gamma,
        double_entity_embedding=args.double_entity_embedding,
        double_relation_embedding=args.double_relation_embedding)

    logging.info('Model Parameter Configuration:')
    for name, param in kge_model.named_parameters():
        logging.info('Parameter %s: %s, require_grad = %s' %
                     (name, str(param.size()), str(param.requires_grad)))

    if args.cuda:
        kge_model = kge_model.cuda()

    if args.do_train:
        # Set training dataloader iterator
        train_dataloader_head = DataLoader(
            TrainDataset(train_triples, nentity, nrelation,
                         args.negative_sample_size, 'head-batch', entity2id,
                         relation2id, args.data_path, args.typecons),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=max(1, args.cpu_num // 2),
            collate_fn=TrainDataset.collate_fn)

        train_dataloader_tail = DataLoader(
            TrainDataset(train_triples, nentity, nrelation,
                         args.negative_sample_size, 'tail-batch', entity2id,
                         relation2id, args.data_path, args.typecons),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=max(1, args.cpu_num // 2),
            collate_fn=TrainDataset.collate_fn)

        train_iterator = BidirectionalOneShotIterator(train_dataloader_head,
                                                      train_dataloader_tail)

        # Set training configuration
        current_learning_rate = args.learning_rate
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                            kge_model.parameters()),
                                     lr=current_learning_rate)
        if args.warm_up_steps:
            warm_up_steps = args.warm_up_steps
        else:
            warm_up_steps = args.max_steps // 2

    if args.init_checkpoint:
        # Restore model from checkpoint directory
        logging.info('Loading checkpoint %s...' % args.init_checkpoint)
        checkpoint = torch.load(
            os.path.join(args.init_checkpoint, 'checkpoint'))
        init_step = checkpoint['step']
        kge_model.load_state_dict(checkpoint['model_state_dict'])
        if args.do_train:
            current_learning_rate = checkpoint['current_learning_rate']
            warm_up_steps = checkpoint['warm_up_steps']
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        logging.info('Randomly Initializing %s Model...' % args.model)
        init_step = 0

    step = init_step

    logging.info('Start Training...')
    logging.info('init_step = %d' % init_step)
    logging.info('batch_size = %d' % args.batch_size)
    logging.info('negative_adversarial_sampling = %d' %
                 args.negative_adversarial_sampling)
    logging.info('hidden_dim = %d' % args.hidden_dim)
    logging.info('gamma = %f' % args.gamma)
    logging.info('negative_adversarial_sampling = %s' %
                 str(args.negative_adversarial_sampling))
    if args.negative_adversarial_sampling:
        logging.info('adversarial_temperature = %f' %
                     args.adversarial_temperature)

    # Set valid dataloader as it would be evaluated during training

    if args.do_train:
        logging.info('learning_rate = %f' % current_learning_rate)

        training_logs = []

        #Training Loop
        for step in range(init_step, args.max_steps):

            log = kge_model.train_step(kge_model, optimizer, train_iterator,
                                       args)

            training_logs.append(log)

            if step >= warm_up_steps:
                current_learning_rate = current_learning_rate / 10
                logging.info('Change learning_rate to %f at step %d' %
                             (current_learning_rate, step))
                optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                                    kge_model.parameters()),
                                             lr=current_learning_rate)
                warm_up_steps = warm_up_steps * 3

            if step % args.save_checkpoint_steps == 0:
                save_variable_list = {
                    'step': step,
                    'current_learning_rate': current_learning_rate,
                    'warm_up_steps': warm_up_steps
                }
                save_model(kge_model, optimizer, save_variable_list, args)

            if step % args.log_steps == 0:
                metrics = {}
                for metric in training_logs[0].keys():
                    metrics[metric] = sum(
                        [log[metric]
                         for log in training_logs]) / len(training_logs)
                log_metrics('Training average', step, metrics)
                training_logs = []

            if args.do_valid and step % args.valid_steps == 0:
                logging.info('Evaluating on Valid Dataset...')
                metrics = kge_model.test_step(kge_model, valid_triples,
                                              all_true_triples, args)
                log_metrics('Valid', step, metrics)

        save_variable_list = {
            'step': step,
            'current_learning_rate': current_learning_rate,
            'warm_up_steps': warm_up_steps
        }
        save_model(kge_model, optimizer, save_variable_list, args)

    if args.do_valid:
        logging.info('Evaluating on Valid Dataset...')
        metrics = kge_model.test_step(kge_model, valid_triples,
                                      all_true_triples, args)
        log_metrics('Valid', step, metrics)

    if args.do_test:
        logging.info('Evaluating on Test Dataset...')
        metrics = kge_model.test_step(kge_model, test_triples,
                                      all_true_triples, args)
        # logging.info("----------------------"+str(rel)+"---------------------\n")
        log_metrics('Test', step, metrics)

    if args.get_metric:
        logging.info(
            'Evaluating on Test Dataset and Show the metric in two sides...')
        head_metrics, tail_metrics = kge_model.get_metric(
            kge_model, test_triples, all_true_triples, args)
        logging.info("--------------- Head ------------\n")
        log_metrics('Test-Head', step, head_metrics)
        logging.info("--------------- Tail ------------\n")
        log_metrics('Test-Tail', step, tail_metrics)

    if args.evaluate_train:
        logging.info('Evaluating on Training Dataset...')
        metrics = kge_model.test_step(kge_model, train_triples,
                                      all_true_triples, args)
        log_metrics('Test', step, metrics)

    # Codes about StAR
    if args.get_scores:
        for type in ['dev', 'test']:
            kge_model.get_scores(kge_model, type, valid_triples,
                                 all_true_triples, args)

    if args.get_model_dataset:
        kge_model.get_model_dataset(kge_model, 'train', valid_triples,
                                    all_true_triples, args)

    if args.get_cases:
        kge_model.get_cases(kge_model, test_triples, all_true_triples, args)

    if args.rel_do_test:
        train_rel_dict = collections.Counter([ex[1] for ex in train_triples])
        rel_dict = dict()
        logging.info('Evaluating on Each Test Dataset Divided by Relation...')
        test_ex_dict_departby_rel = dict.fromkeys(relation2id.keys(), [])
        for _rel in relation2id:
            test_ex_dict_departby_rel[_rel] = [
                _ex for _ex in test_triples if _ex[1] == relation2id[_rel]
            ]

        for _rel in test_ex_dict_departby_rel.keys():
            _rel_test_triples = test_ex_dict_departby_rel[_rel]
            _rel_data = [
                train_rel_dict[relation2id[_rel]],
                len(_rel_test_triples)
            ]
            if len(_rel_test_triples) != 0:
                metrics = kge_model.test_step(kge_model, _rel_test_triples,
                                              all_true_triples, args)
                _rel_data.extend([
                    round(metrics['HITS@1'], 3),
                    round(metrics['HITS@3'], 3),
                    round(metrics['HITS@10'], 3),
                    round(metrics['MR'], 1),
                    round(metrics['MRR'], 3)
                ])
            else:
                _rel_data.extend([0, 0, 0, 0, 0])
            rel_dict[_rel] = _rel_data

        sorted_rel = sorted(rel_dict.items(),
                            key=lambda x: x[1][0],
                            reverse=True)

        save_dir = args.init_checkpoint
        with open(join(save_dir, "rel_unbalanced.txt"), "w",
                  encoding="utf-8") as fp:
            fp.write(str(sorted_rel))
        torch.save(sorted_rel, join(save_dir, "rel_unbalanced"))

        # SaveInExcle(sorted_rel, save_dir)
        print("explore unbalanced finished")
Example No. 8
def main(args):
    if args.seed != -1:
        torch.manual_seed(args.seed)
        if args.cuda:
            torch.cuda.manual_seed(args.seed)

    if (not args.do_train) and (not args.do_valid) and (not args.do_test):
        raise ValueError('one of train/val/test mode must be chosen.')

    if args.init_checkpoint:
        override_config(args)
    elif args.data_path is None:
        raise ValueError('one of init_checkpoint/data_path must be chosen.')

    if args.do_train and args.save_path is None:
        raise ValueError('Where do you want to save your trained model?')

    if args.save_path and not os.path.exists(args.save_path):
        os.makedirs(args.save_path)
        if args.do_train and args.do_valid:
            if not os.path.exists("%s/best/" % args.save_path):
                os.makedirs("%s/best/" % args.save_path)

    # Write logs to checkpoint and console
    set_logger(args)

    with open(os.path.join(args.data_path, 'entities.dict')) as fin:
        entity2id = dict()
        for line in fin:
            eid, entity = line.strip().split('\t')
            entity2id[entity] = int(eid)

    with open(os.path.join(args.data_path, 'relations.dict')) as fin:
        relation2id = dict()
        for line in fin:
            rid, relation = line.strip().split('\t')
            relation2id[relation] = int(rid)

    # Read regions for Countries S* datasets
    if args.countries:
        regions = list()
        with open(os.path.join(args.data_path, 'regions.list')) as fin:
            for line in fin:
                region = line.strip()
                regions.append(entity2id[region])
        args.regions = regions

    nentity = len(entity2id)
    nrelation = len(relation2id)

    args.nentity = nentity
    args.nrelation = nrelation

    logging.info('Model: %s' % args.model)
    logging.info('Data Path: %s' % args.data_path)
    logging.info('#entity: %d' % nentity)
    logging.info('#relation: %d' % nrelation)

    train_triples = read_triple(os.path.join(args.data_path, 'train.txt'),
                                entity2id, relation2id)
    logging.info('#train: %d' % len(train_triples))
    valid_triples = read_triple(os.path.join(args.data_path, 'valid.txt'),
                                entity2id, relation2id)
    logging.info('#valid: %d' % len(valid_triples))
    test_triples = read_triple(os.path.join(args.data_path, 'test.txt'),
                               entity2id, relation2id)
    logging.info('#test: %d' % len(test_triples))

    train_triples_tsr = torch.LongTensor(train_triples).transpose(
        0, 1)  #idx X batch
    #All true triples
    all_true_triples = train_triples + valid_triples + test_triples
    #if args.use_gnn:
    #    assert False
    #    #kge_model = GNN_KGEModel(
    #    #    model_name=args.model,
    #    #    nentity=nentity,
    #    #    nrelation=nrelation,
    #    #    hidden_dim=args.hidden_dim,
    #    #    gamma=args.gamma,
    #    #    num_layers=args.gnn_layers,
    #    #    args = args,
    #    #    dropout=args.dropout,
    #    #    double_entity_embedding=args.double_entity_embedding,
    #    #    double_relation_embedding=args.double_relation_embedding,
    #    #)
    #else:
    kge_model = KGEModel(
        model_name=args.model,
        nentity=nentity,
        nrelation=nrelation,
        hidden_dim=args.hidden_dim,
        gamma=args.gamma,
        args=args,
        double_entity_embedding=args.double_entity_embedding,
        double_relation_embedding=args.double_relation_embedding,
    )

    logging.info('Model Configuration:')
    logging.info(str(kge_model))
    logging.info('Model Parameter Configuration:')
    for name, param in kge_model.named_parameters():
        logging.info('Parameter %s: %s, require_grad = %s' %
                     (name, str(param.size()), str(param.requires_grad)))

    if args.cuda:
        kge_model = kge_model.cuda()
        train_triples_tsr = train_triples_tsr.cuda()
    #kge_model.build_cxt_triple_map(train_triples)
    if args.do_train:
        # Set training dataloader iterator
        if args.same_head_tail:
            # shuffle train_triples once here and disable shuffling inside the dataloaders, so head and tail batches share the same indices
            shuffle(train_triples)
            train_dataloader_head = DataLoader(
                TrainDataset(train_triples, nentity, nrelation,
                             args.negative_sample_size, 'head-batch'),
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=max(1, args.cpu_num // 2),
                collate_fn=TrainDataset.collate_fn)

            train_dataloader_tail = DataLoader(
                TrainDataset(train_triples, nentity, nrelation,
                             args.negative_sample_size, 'tail-batch'),
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=max(1, args.cpu_num // 2),
                collate_fn=TrainDataset.collate_fn)
        else:
            train_dataloader_head = DataLoader(
                TrainDataset(train_triples, nentity, nrelation,
                             args.negative_sample_size, 'head-batch'),
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=max(1, args.cpu_num // 2),
                collate_fn=TrainDataset.collate_fn)

            train_dataloader_tail = DataLoader(
                TrainDataset(train_triples, nentity, nrelation,
                             args.negative_sample_size, 'tail-batch'),
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=max(1, args.cpu_num // 2),
                collate_fn=TrainDataset.collate_fn)
        train_iterator = BidirectionalOneShotIterator(train_dataloader_head,
                                                      train_dataloader_tail)
        #else:
        #    train_dataloader_rel = DataLoader(
        #        TrainDataset(train_triples, nentity, nrelation,
        #            args.negative_sample_head_size*args.negative_sample_tail_size,
        #            'rel-batch',
        #            negative_sample_head_size =args.negative_sample_head_size,
        #            negative_sample_tail_size =args.negative_sample_tail_size,
        #            half_correct=args.negative_sample_half_correct),
        #        batch_size=args.batch_size,
        #        shuffle=True,
        #        num_workers=max(1, args.cpu_num//2),
        #        collate_fn=TrainDataset.collate_fn
        #    )
        #    train_iterator = BidirectionalOneShotIterator.one_shot_iterator(train_dataloader_rel)
        #    tail_only = True

        # Set training configuration
        current_learning_rate = args.learning_rate
        optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, kge_model.parameters()),
            lr=current_learning_rate,
            weight_decay=args.weight_decay,
        )

        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    step_size=1,
                                                    gamma=0.5,
                                                    last_epoch=-1)
        #if args.warm_up_steps:
        #    warm_up_steps = args.warm_up_steps
        #else:
        #    warm_up_steps = args.max_steps // 2

    if args.init_checkpoint:
        # Restore model from checkpoint directory
        logging.info('Loading checkpoint %s...' % args.init_checkpoint)
        checkpoint = torch.load(
            os.path.join(args.init_checkpoint, 'checkpoint'))
        init_step = checkpoint['step']
        if 'score_weight' in kge_model.state_dict(
        ) and 'score_weight' not in checkpoint['model_state_dict']:
            checkpoint['model_state_dict'][
                'score_weights'] = kge_model.state_dict()['score_weights']
        kge_model.load_state_dict(checkpoint['model_state_dict'])
        if args.do_train:
            current_learning_rate = checkpoint['current_learning_rate']
            #warm_up_steps = checkpoint['warm_up_steps']
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        else:
            current_learning_rate = 0
    elif args.init_embedding:
        logging.info('Loading pretrained embedding %s ...' %
                     args.init_embedding)
        if kge_model.entity_embedding is not None:
            entity_embedding = np.load(
                os.path.join(args.init_embedding, 'entity_embedding.npy'))
            relation_embedding = np.load(
                os.path.join(args.init_embedding, 'relation_embedding.npy'))
            entity_embedding = torch.from_numpy(entity_embedding).to(
                kge_model.entity_embedding.device)
            relation_embedding = torch.from_numpy(relation_embedding).to(
                kge_model.relation_embedding.device)
            kge_model.entity_embedding.data[:entity_embedding.
                                            size(0)] = entity_embedding
            kge_model.relation_embedding.data[:relation_embedding.
                                              size(0)] = relation_embedding
        init_step = 1
        current_learning_rate = 0
    else:
        logging.info('Randomly Initializing %s Model...' % args.model)
        init_step = 1

    step = init_step

    logging.info('Start Training...')
    logging.info('init_step = %d' % init_step)
    logging.info('learning_rate = %.5f' % current_learning_rate)
    logging.info('batch_size = %d' % args.batch_size)
    logging.info('negative_adversarial_sampling = %d' %
                 args.negative_adversarial_sampling)
    logging.info('hidden_dim = %d' % args.hidden_dim)
    logging.info('gamma = %f' % args.gamma)
    logging.info('negative_adversarial_sampling = %s' %
                 str(args.negative_adversarial_sampling))
    if args.negative_adversarial_sampling:
        logging.info('adversarial_temperature = %f' %
                     args.adversarial_temperature)

    # Set valid dataloader as it would be evaluated during training

    #loss_func = nn.BCEWithLogitsLoss(reduction="none") if args.use_bceloss else nn.LogSigmoid()
    if args.use_bceloss:
        loss_func = nn.BCELoss(reduction="none")
    elif args.use_softmarginloss:
        loss_func = nn.SoftMarginLoss(reduction="none")
    else:
        loss_func = nn.LogSigmoid()
    #kge_model.cluster_relation_entity_embedding(args.context_cluster_num, args.context_cluster_scale)
    if args.do_train:
        training_logs = []
        best_metrics = None
        #Training Loop
        optimizer.zero_grad()
        for step in range(init_step, args.max_steps + 1):
            if step % args.update_freq == 1 or args.update_freq == 1:
                optimizer.zero_grad()
            log = kge_model.train_step(kge_model, train_iterator,
                                       train_triples_tsr, loss_func, args)
            if step % args.update_freq == 0:
                optimizer.step()

            training_logs.append(log)

            #if step >= warm_up_steps:
            #    current_learning_rate = current_learning_rate / 10
            #    logging.info('Change learning_rate to %f at step %d' % (current_learning_rate, step))
            #    optimizer = torch.optim.Adam(
            #        filter(lambda p: p.requires_grad, kge_model.parameters()),
            #        lr=current_learning_rate
            #    )
            #    warm_up_steps = warm_up_steps * 3
            if step % args.schedule_steps == 0:
                scheduler.step()

            if step % args.save_checkpoint_steps == 0:
                save_variable_list = {
                    'step': step,
                    'current_learning_rate': current_learning_rate,
                    #'warm_up_steps': warm_up_steps
                }
                save_model(kge_model, optimizer, save_variable_list, args)

            if step % args.log_steps == 0:
                metrics = {}
                for metric in training_logs[0].keys():
                    metrics[metric] = sum(
                        [log[metric]
                         for log in training_logs]) / len(training_logs)
                log_metrics('Training average', step, [metrics])
                training_logs = []

            if args.do_valid and step % args.valid_steps == 0:
                logging.info('Evaluating on Valid Dataset...')
                metrics = kge_model.test_step(kge_model, valid_triples,
                                              all_true_triples,
                                              train_triples_tsr, args)
                log_metrics('Valid', step, metrics)
                if is_better_metric(best_metrics, metrics):
                    save_variable_list = {
                        'step': step,
                        'current_learning_rate': current_learning_rate,
                        #'warm_up_steps': warm_up_steps
                    }
                    save_model(kge_model, optimizer, save_variable_list, args,
                               True)
                    best_metrics = metrics
                #kge_model.cluster_relation_entity_embedding(args.context_cluster_num, args.context_cluster_scale)

        save_variable_list = {
            'step': step,
            'current_learning_rate': current_learning_rate,
            #'warm_up_steps': warm_up_steps
        }
        save_model(kge_model, optimizer, save_variable_list, args)
    if args.do_valid and args.do_train:
        #load the best model
        best_checkpoint = torch.load("%s/best/checkpoint" % args.save_path)
        kge_model.load_state_dict(best_checkpoint['model_state_dict'])
        logging.info("Loading best model from step %d" %
                     best_checkpoint['step'])
        step = best_checkpoint['step']

    if args.do_valid:
        logging.info('Evaluating on Valid Dataset...')
        metrics = kge_model.test_step(kge_model, valid_triples,
                                      all_true_triples, train_triples_tsr,
                                      args)
        log_metrics('Valid', step, metrics)

    if args.do_test:
        logging.info('Evaluating on Test Dataset...')
        metrics = kge_model.test_step(kge_model, test_triples,
                                      all_true_triples, train_triples_tsr,
                                      args)
        log_metrics('Test', step, metrics)

    if args.evaluate_train:
        logging.info('Evaluating on Training Dataset...')
        metrics = kge_model.test_step(kge_model, train_triples,
                                      all_true_triples, train_triples_tsr,
                                      args)
        log_metrics('Test', step, metrics)
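
Example No. 8 accumulates gradients over args.update_freq training steps before calling optimizer.step(). A minimal, self-contained sketch of that accumulation pattern, using a dummy model and placeholder loss, might be:

import torch

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
update_freq, max_steps = 4, 12

optimizer.zero_grad()
for step in range(1, max_steps + 1):
    # Clear gradients only at the start of each accumulation window.
    if step % update_freq == 1 or update_freq == 1:
        optimizer.zero_grad()
    loss = model(torch.randn(16, 8)).pow(2).mean()   # placeholder loss
    loss.backward()                                   # gradients accumulate across steps
    # Apply the optimizer once per window.
    if step % update_freq == 0:
        optimizer.step()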
Example No. 9
def main(args):
    if (not args.do_train) and (not args.do_valid) and (not args.do_test):
        raise ValueError('one of train/val/test mode must be chosen.')

    if args.init_checkpoint:
        override_config(args)
    elif args.data_path is None:
        raise ValueError('one of init_checkpoint/data_path must be chosen.')

    if args.do_train and args.save_path is None:
        raise ValueError('Where do you want to save your trained model?')
    
    if args.save_path and not os.path.exists(args.save_path):
        os.makedirs(args.save_path)
    
    # Write logs to checkpoint and console
    set_logger(args)
    
    with open(os.path.join(args.data_path, 'entities.dict')) as fin:
        entity2id = dict()
        for line in fin:
            eid, entity = line.strip().split('\t')
            entity2id[entity] = int(eid)

    with open(os.path.join(args.data_path, 'relations.dict')) as fin:
        relation2id = dict()
        for line in fin:
            rid, relation = line.strip().split('\t')
            relation2id[relation] = int(rid)
    
    # Read regions for Countries S* datasets
    if args.countries:
        regions = list()
        with open(os.path.join(args.data_path, 'regions.list')) as fin:
            for line in fin:
                region = line.strip()
                regions.append(entity2id[region])
        args.regions = regions

    nentity = len(entity2id)
    nrelation = len(relation2id)
    
    args.nentity = nentity
    args.nrelation = nrelation
    
    logging.info('Model: %s' % args.model)
    logging.info('Data Path: %s' % args.data_path)
    logging.info('#entity: %d' % nentity)
    logging.info('#relation: %d' % nrelation)
    
    train_triples = read_triple(os.path.join(args.data_path, 'train.txt'), entity2id, relation2id)
    logging.info('#train: %d' % len(train_triples))
    valid_triples = read_triple(os.path.join(args.data_path, 'valid.txt'), entity2id, relation2id)
    logging.info('#valid: %d' % len(valid_triples))
    test_triples = read_triple(os.path.join(args.data_path, 'test.txt'), entity2id, relation2id)
    logging.info('#test: %d' % len(test_triples))
    
    #All true triples
    all_true_triples = train_triples + valid_triples + test_triples
    
    kge_model = KGEModel(
        model_name=args.model,
        nentity=nentity,
        nrelation=nrelation,
        hidden_dim=args.hidden_dim,
        type_dim = args.type_dim,
        gamma=args.gamma,
        gamma_type=args.gamma_type,
        gamma_pair=args.gamma_pair,
        double_entity_embedding=args.double_entity_embedding,
        double_relation_embedding=args.double_relation_embedding
    )
    
    logging.info('Model Parameter Configuration:')
    for name, param in kge_model.named_parameters():
        logging.info('Parameter %s: %s, require_grad = %s' % (name, str(param.size()), str(param.requires_grad)))

    if args.cuda:
        kge_model = kge_model.cuda()
    
    if args.do_train:
        # Set training dataloader iterator
        train_dataloader_head = DataLoader(
            TrainDataset(train_triples, nentity, nrelation, args.negative_sample_size, args.pair_sample_size, 'head-batch'), 
            batch_size=args.batch_size,
            shuffle=True, 
            num_workers=max(1, args.cpu_num//2),
            collate_fn=TrainDataset.collate_fn
        )
        
        train_dataloader_tail = DataLoader(
            TrainDataset(train_triples, nentity, nrelation, args.negative_sample_size, args.pair_sample_size, 'tail-batch'), 
            batch_size=args.batch_size,
            shuffle=True, 
            num_workers=max(1, args.cpu_num//2),
            collate_fn=TrainDataset.collate_fn
        )
        
        train_iterator = BidirectionalOneShotIterator(train_dataloader_head, train_dataloader_tail)
        
        # Set training configuration
        current_learning_rate = args.learning_rate
        optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, kge_model.parameters()), 
            lr=current_learning_rate
        )
        if args.warm_up_steps:
            warm_up_steps = args.warm_up_steps
        else:
            warm_up_steps = args.max_steps // 2

    if args.init_checkpoint:
        # Restore model from checkpoint directory
        logging.info('Loading checkpoint %s...' % args.init_checkpoint)
        checkpoint = torch.load(os.path.join(args.init_checkpoint, 'checkpoint'))
        init_step = checkpoint['step']
        kge_model.load_state_dict(checkpoint['model_state_dict'])
        if args.do_train:
            current_learning_rate = checkpoint['current_learning_rate']
            warm_up_steps = checkpoint['warm_up_steps']
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        logging.info('Randomly Initializing %s Model...' % args.model)
        init_step = 0
    
    step = init_step
    
    logging.info('Start Training...')
    logging.info('init_step = %d' % init_step)
    logging.info('batch_size = %d' % args.batch_size)
    logging.info('negative_adversarial_sampling = %d' % args.negative_adversarial_sampling)
    logging.info('hidden_dim = %d' % args.hidden_dim)
    logging.info('gamma = %f' % args.gamma)
    logging.info('type_dim = %d' % args.type_dim)
    logging.info('gamma_type = %f' % args.gamma_type)
    logging.info('alpha_1 = %f' % args.alpha_1)
    logging.info('gamma_pair = %f' % args.gamma_pair)
    logging.info('alpha_2 = %f' % args.alpha_2)
    logging.info('negative_adversarial_sampling = %s' % str(args.negative_adversarial_sampling))
    logging.info('pair_sample_size = %d' % args.pair_sample_size)
    if args.negative_adversarial_sampling:
        logging.info('adversarial_temperature = %f' % args.adversarial_temperature)
    
    # Set valid dataloader as it would be evaluated during training
    
    if args.do_train:
        logging.info('learning_rate = %f' % current_learning_rate)

        training_logs = []
        
        #Training Loop
        for step in range(init_step, args.max_steps):
            log = kge_model.train_step(kge_model, optimizer, train_iterator, args)
            
            training_logs.append(log)
            
            if step >= warm_up_steps:
                current_learning_rate = current_learning_rate / 10
                logging.info('Change learning_rate to %f at step %d' % (current_learning_rate, step))
                optimizer = torch.optim.Adam(
                    filter(lambda p: p.requires_grad, kge_model.parameters()), 
                    lr=current_learning_rate
                )
                warm_up_steps = warm_up_steps * 3
            
            if step % args.save_checkpoint_steps == 0:
                save_variable_list = {
                    'step': step, 
                    'current_learning_rate': current_learning_rate,
                    'warm_up_steps': warm_up_steps
                }
                save_model(kge_model, optimizer, save_variable_list, args)
                
            if step % args.log_steps == 0:
                metrics = {}
                for metric in training_logs[0].keys():
                    metrics[metric] = sum([log[metric] for log in training_logs])/len(training_logs)
                log_metrics('Training average', step, metrics)
                training_logs = []
                
            if args.do_valid and step % args.valid_steps == 0:
                logging.info('Evaluating on Valid Dataset...')
                metrics = kge_model.test_step(kge_model, valid_triples, all_true_triples, args)
                log_metrics('Valid', step, metrics)
        
        save_variable_list = {
            'step': step, 
            'current_learning_rate': current_learning_rate,
            'warm_up_steps': warm_up_steps
        }
        save_model(kge_model, optimizer, save_variable_list, args)
        
    if args.do_valid:
        logging.info('Evaluating on Valid Dataset...')
        metrics = kge_model.test_step(kge_model, valid_triples, all_true_triples, args)
        log_metrics('Valid', step, metrics)
    
    if args.do_test:
        logging.info('Evaluating on Test Dataset...')
        metrics = kge_model.test_step(kge_model, test_triples, all_true_triples, args)
        log_metrics('Test', step, metrics)
    
    if args.evaluate_train:
        logging.info('Evaluating on Training Dataset...')
        metrics = kge_model.test_step(kge_model, train_triples, all_true_triples, args)
        log_metrics('Test', step, metrics)
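
Examples No. 7, 9, and 10 decay the learning rate by a factor of 10 the first time step reaches warm_up_steps and then triple warm_up_steps, so each later decay happens after a progressively longer interval. A tiny standalone trace of that schedule makes the effect visible:

def lr_schedule(initial_lr, max_steps, warm_up_steps):
    """Yield (step, lr) pairs following the decay rule used in the examples above."""
    lr = initial_lr
    for step in range(max_steps):
        if step >= warm_up_steps:
            lr = lr / 10
            warm_up_steps = warm_up_steps * 3
        yield step, lr

# With max_steps=100000 and warm_up_steps=50000 the lr drops once, at step 50000;
# the next decay would only occur at step 150000, beyond max_steps.
decay_points = [(s, lr) for s, lr in lr_schedule(1e-4, 100000, 50000)
                if s in (0, 50000, 99999)]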
Example No. 10
def main(args):
    # if (not args.do_train) and (not args.do_valid) and (not args.do_test):
    #     raise ValueError('one of train/val/test mode must be choosed.')
    
    if args.init_checkpoint:
        override_config(args)
    elif args.data_path is None:
        raise ValueError('one of init_checkpoint/data_path must be chosen.')

    if args.do_train and args.save_path is None:
        raise ValueError('Where do you want to save your trained model?')
    
    if args.save_path and not os.path.exists(args.save_path):
        os.makedirs(args.save_path)
    
    # Write logs to checkpoint and console
    set_logger(args)
    
    # with open(os.path.join(args.data_path, 'entities.dict')) as fin:
    #     entity2id = dict()
    #     id2entity = dict()
    #     for line in fin:
    #         eid, entity = line.strip().split('\t')
    #         entity2id[entity] = int(eid)
    #         id2entity[int(eid)] = entity

    # with open(os.path.join(args.data_path, 'relations.dict')) as fin:
    #     relation2id = dict()
    #     id2relation = dict()
    #     for line in fin:
    #         rid, relation = line.strip().split('\t')
    #         relation2id[relation] = int(rid)
    #         id2relation[int(rid)] = relation
    
    # # Read regions for Countries S* datasets
    # if args.countries:
    #     regions = list()
    #     with open(os.path.join(args.data_path, 'regions.list')) as fin:
    #         for line in fin:
    #             region = line.strip()
    #             regions.append(entity2id[region])
    #     args.regions = regions

    # Amazon dataset
    with open(os.path.join(args.data_path, 'entity2id.txt')) as fin:
        entity2id = dict()
        id2entity = dict()
        for line in fin:
            if len(line.strip().split('\t')) < 2:
                continue
            entity, eid = line.strip().split('\t')
            entity2id[entity] = int(eid)
            id2entity[int(eid)] = entity

    with open(os.path.join(args.data_path, 'relation2id.txt')) as fin:
        relation2id = dict()
        id2relation = dict()
        for line in fin:
            if len(line.strip().split('\t')) < 2:
                continue
            relation, rid = line.strip().split('\t')
            relation2id[relation] = int(rid)
            id2relation[int(rid)] = relation

    nentity = len(entity2id)
    nrelation = len(relation2id)
    
    args.nentity = nentity
    args.nrelation = nrelation
    
    logging.info('Model: %s' % args.model)
    logging.info('Data Path: %s' % args.data_path)
    logging.info('#entity: %d' % nentity)
    logging.info('#relation: %d' % nrelation)
    
    # --------------------------------------------------
    # Comments by Meng:
    # During training, pLogicNet will augment the training triplets,
    # so here we load both the augmented triplets (train.txt) for training and
    # the original triplets (train_kge.txt) for evaluation.
    # Also, the hidden triplets (hidden.txt) are also loaded for annotation.
    # --------------------------------------------------
    # train_triples = read_triple(os.path.join(args.workspace_path, 'train_kge.txt'), entity2id, relation2id)
    # logging.info('#train: %d' % len(train_triples))
    # train_original_triples = read_triple(os.path.join(args.data_path, 'train.txt'), entity2id, relation2id)
    # logging.info('#train original: %d' % len(train_original_triples))
    # valid_triples = read_triple(os.path.join(args.data_path, 'valid.txt'), entity2id, relation2id)
    # logging.info('#valid: %d' % len(valid_triples))
    # test_triples = read_triple(os.path.join(args.data_path, 'test.txt'), entity2id, relation2id)
    # logging.info('#test: %d' % len(test_triples))
    # hidden_triples = read_triple(os.path.join(args.workspace_path, 'hidden.txt'), entity2id, relation2id)
    # logging.info('#hidden: %d' % len(hidden_triples))

    train_triples = read_triple(os.path.join(args.workspace_path, 'train_kge.txt'), entity2id, relation2id)
    logging.info('#train: %d' % len(train_triples))
    train_original_triples = read_triple(os.path.join(args.data_path, 'train.txt'), entity2id, relation2id)
    logging.info('#train original: %d' % len(train_original_triples))
    valid_triples = read_triple(os.path.join(args.data_path, 'kg_val_triples_Cell_Phones_and_Accessories.txt'), entity2id, relation2id)
    logging.info('#valid: %d' % len(valid_triples))
    test_triples = read_triple(os.path.join(args.data_path, 'kg_test_triples_Cell_Phones_and_Accessories.txt'), entity2id, relation2id)
    logging.info('#test: %d' % len(test_triples))
    test_candidates = np.load(os.path.join(args.data_path, 'rec_test_candidate100.npz'))['candidates'][:, 1:]
    # test_candidates = np.load('/common/users/yz956/kg/code/OpenDialKG/cand.npy')
    # hidden_triples = read_triple(os.path.join(args.workspace_path, 'hidden.txt'), entity2id, relation2id)
    hidden_triples = read_triple("/common/users/yz956/kg/code/KBRD/data/cpa/cpa/hidden_50.txt", entity2id, relation2id)
    logging.info('#hidden: %d' % len(hidden_triples))
    
    #All true triples
    all_true_triples = train_original_triples + valid_triples + test_triples
    
    kge_model = KGEModel(
        model_name=args.model,
        nentity=nentity,
        nrelation=nrelation,
        hidden_dim=args.hidden_dim,
        gamma=args.gamma,
        double_entity_embedding=args.double_entity_embedding,
        double_relation_embedding=args.double_relation_embedding
    )
    
    logging.info('Model Parameter Configuration:')
    for name, param in kge_model.named_parameters():
        logging.info('Parameter %s: %s, require_grad = %s' % (name, str(param.size()), str(param.requires_grad)))

    if args.cuda:
        kge_model = kge_model.cuda()
    
    if args.do_train:
        # Set training dataloader iterator
        train_dataloader_head = DataLoader(
            TrainDataset(train_triples, nentity, nrelation, args.negative_sample_size, 'head-batch'), 
            batch_size=args.batch_size,
            shuffle=True, 
            num_workers=max(1, args.cpu_num//2),
            collate_fn=TrainDataset.collate_fn
        )
        
        train_dataloader_tail = DataLoader(
            TrainDataset(train_triples, nentity, nrelation, args.negative_sample_size, 'tail-batch'), 
            batch_size=args.batch_size,
            shuffle=True, 
            num_workers=max(1, args.cpu_num//2),
            collate_fn=TrainDataset.collate_fn
        )
        
        train_iterator = BidirectionalOneShotIterator(train_dataloader_head, train_dataloader_tail)
        
        # Set training configuration
        current_learning_rate = args.learning_rate
        optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, kge_model.parameters()), 
            lr=current_learning_rate
        )
        if args.warm_up_steps:
            warm_up_steps = args.warm_up_steps
        else:
            warm_up_steps = args.max_steps // 2

    if args.init_checkpoint:
        # Restore model from checkpoint directory
        logging.info('Loading checkpoint %s...' % args.init_checkpoint)
        checkpoint = torch.load(os.path.join(args.init_checkpoint, 'checkpoint'))
        init_step = checkpoint['step']
        kge_model.load_state_dict(checkpoint['model_state_dict'])
        if args.do_train:
            current_learning_rate = checkpoint['current_learning_rate']
            warm_up_steps = checkpoint['warm_up_steps']
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        logging.info('Randomly Initializing %s Model...' % args.model)
        init_step = 0
    
    step = init_step
    
    logging.info('Start Training...')
    logging.info('init_step = %d' % init_step)
    logging.info('learning_rate = %f' % current_learning_rate)
    logging.info('batch_size = %d' % args.batch_size)
    logging.info('negative_adversarial_sampling = %d' % args.negative_adversarial_sampling)
    logging.info('hidden_dim = %d' % args.hidden_dim)
    logging.info('gamma = %f' % args.gamma)
    logging.info('negative_adversarial_sampling = %s' % str(args.negative_adversarial_sampling))
    if args.negative_adversarial_sampling:
        logging.info('adversarial_temperature = %f' % args.adversarial_temperature)

    if args.record:
        local_path = args.workspace_path
        ensure_dir(local_path)

        opt = vars(args)
        with open(local_path + '/opt.txt', 'w') as fo:
            for key, val in opt.items():
                fo.write('{} {}\n'.format(key, val))
    
    # Set valid dataloader as it would be evaluated during training
    
    if args.do_train:
        training_logs = []
        
        #Training Loop
        for step in range(init_step, args.max_steps):
            
            log = kge_model.train_step(kge_model, optimizer, train_iterator, args)
            
            training_logs.append(log)
            
            if step >= warm_up_steps:
                current_learning_rate = current_learning_rate / 10
                logging.info('Change learning_rate to %f at step %d' % (current_learning_rate, step))
                optimizer = torch.optim.Adam(
                    filter(lambda p: p.requires_grad, kge_model.parameters()), 
                    lr=current_learning_rate
                )
                warm_up_steps = warm_up_steps * 3
            
            if step % args.save_checkpoint_steps == 0:
                save_variable_list = {
                    'step': step, 
                    'current_learning_rate': current_learning_rate,
                    'warm_up_steps': warm_up_steps
                }
                save_model(kge_model, optimizer, save_variable_list, args)
                
            if step % args.log_steps == 0:
                metrics = {}
                for metric in training_logs[0].keys():
                    metrics[metric] = sum([log[metric] for log in training_logs])/len(training_logs)
                log_metrics('Training average', step, metrics)
                training_logs = []
                
            if args.do_valid and (step + 1) % args.valid_steps == 0:
                logging.info('Evaluating on Valid Dataset...')
                metrics, preds = kge_model.test_step(kge_model, valid_triples, all_true_triples, args)
                log_metrics('Valid', step, metrics)
        
        save_variable_list = {
            'step': step, 
            'current_learning_rate': current_learning_rate,
            'warm_up_steps': warm_up_steps
        }
        save_model(kge_model, optimizer, save_variable_list, args)
        
    if args.do_valid:
        logging.info('Evaluating on Valid Dataset...')
        metrics, preds = kge_model.test_step(kge_model, valid_triples, all_true_triples, args)
        log_metrics('Valid', step, metrics)
        
        # --------------------------------------------------
        # Comments by Meng:
        # Save the prediction results of KGE on validation set.
        # --------------------------------------------------

        if args.record:
            # Save the final results
            with open(local_path + '/result_kge_valid.txt', 'w') as fo:
                for metric in metrics:
                    fo.write('{} : {}\n'.format(metric, metrics[metric]))

            # Save the predictions on test data
            with open(local_path + '/pred_kge_valid.txt', 'w') as fo:
                for h, r, t, f, rk, l in preds:
                    fo.write('{}\t{}\t{}\t{}\t{}\n'.format(id2entity[h], id2relation[r], id2entity[t], f, rk))
                    for e, val in l:
                        fo.write('{}:{:.4f} '.format(id2entity[e], val))
                    fo.write('\n')
    
    if args.do_test:
        logging.info('Evaluating on Test Dataset...')
        # metrics, preds = kge_model.test_step(kge_model, test_triples, all_true_triples, args)
        metrics, preds = kge_model.test_step(kge_model, test_triples, test_candidates, all_true_triples, args)
        log_metrics('Test', step, metrics)
        
        # --------------------------------------------------
        # Comments by Meng:
        # Save the prediction results of KGE on test set.
        # --------------------------------------------------

        if args.record:
            # Save the final results
            with open(local_path + '/result_kge.txt', 'w') as fo:
                for metric in metrics:
                    fo.write('{} : {}\n'.format(metric, metrics[metric]))

            # Save the predictions on test data
            with open(local_path + '/pred_kge.txt', 'w') as fo:
                for h, r, t, f, rk, l in preds:
                    fo.write('{}\t{}\t{}\t{}\t{}\n'.format(id2entity[h], id2relation[r], id2entity[t], f, rk))
                    for e, val in l:
                        fo.write('{}:{:.4f} '.format(id2entity[e], val))
                    fo.write('\n')

    # --------------------------------------------------
    # Comments by Meng:
    # Save the annotations on hidden triplets.
    # --------------------------------------------------

    if args.record:
        # Annotate hidden triplets
        scores = kge_model.infer_step(kge_model, hidden_triples, args)
        # with open(local_path + '/annotation.txt', 'w') as fo:
        #     for (h, r, t), s in zip(hidden_triples, scores):
        #         fo.write('{}\t{}\t{}\t{}\n'.format(id2entity[h], id2relation[r], id2entity[t], s))

        # Annotate hidden triplets
        print('annotation')
        
        cand = {}
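        # Rank candidate items per user with the trained KGE model. Each line of
        # the (gzipped) candidate file is assumed to be "<user_id> <item_id> <item_id> ...".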
        with gzip.open('/common/users/yz956/kg/code/KBRD/data/cpa/cpa/kg_test_candidates_Cell_Phones_and_Accessories.txt.gz', 'rt') as f:
            for line in f:
                cells = line.split()
                uid = int(cells[0])
                item_ids = [int(i) for i in cells[1:]]
                cand[uid] = item_ids
        ann, train = [], []
        d = {}
        with open('/common/users/yz956/kg/code/KBRD/data/cpa/cpa/sample_pre.txt') as ft:
            for line in ft:
                line = line.strip().split('\t')
                train.append(line[1:])
        for u in range(61254):
            hiddens = []
            for i in cand[u]:
            # for i in range(61254, 108858):
                hiddens.append((u, 0, i))
            scores = kge_model.infer_step(kge_model, hiddens, args)
            score_np = np.array(scores)
            d = dict(zip(cand[u], scores))
            # d = dict(zip(range(61254, 108858), scores))
            d = sorted(d.items(), key=lambda x: x[1], reverse=True)
            
            # d_50 = d[:50]
            # for idx, t in enumerate(train[u]):
            #     for (tt, prob) in d_50:
            #         if int(t) == tt:
            #             d_50.remove((tt, prob))
            #             d_50.append(d[50 + idx])
            # assert len(d_50) == 50
            # d = {}

            d_50 = d
            ann.append(d_50)
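        # One line per (user, candidate) pair: "<user_id>\t<item_id>\t0\t<score>",
        # where 0 is the single relation id used for the user-item triples above.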
        with open(local_path + '/annotation_1000_htr.txt', 'w') as fo:
            for idx, d in enumerate(ann):
                for (t, score) in d:
                    fo.write(str(idx) + '\t' + str(t) + '\t0\t' + str(score) + '\n')

        # with open(local_path + '/hidden_50_p.txt', 'w') as fo:
        #     for idx, d in enumerate(ann):
        #         for (t, score) in d:
        #             fo.write(str(idx) + '\t' + str(t) + '\t0\n')
        
        scores = kge_model.infer_step(kge_model, hidden_triples, args)
        with open(local_path + '/annotation_htr.txt', 'w') as fo:
            for (h, r, t), s in zip(hidden_triples, scores):
                # fo.write('{}\t{}\t{}\t{}\n'.format(id2entity[h], id2relation[r], id2entity[t], s))
                fo.write('{}\t{}\t{}\t{}\n'.format(str(h), str(t), str(r), s))
    
    if args.evaluate_train:
        logging.info('Evaluating on Training Dataset...')
        metrics, preds = kge_model.test_step(kge_model, train_triples, all_true_triples, args)
        log_metrics('Test', step, metrics)
Exemplo n.º 11
0
Arquivo: run.py Projeto: cdhx/RotatE
def main(args):
    # Check which run mode was requested
    if (not args.do_train) and (not args.do_valid) and (not args.do_test):
        raise ValueError('one of train/val/test mode must be chosen.')
    # Optionally override some command-line arguments from a saved config
    if args.init_checkpoint:  # init_checkpoint is set (a path), so read overrides from its config
        override_config(args)  # override_config relies on init_checkpoint, which this branch guarantees
    elif args.data_path is None:  # without init_checkpoint there is no config to fall back on, so data_path must be given explicitly
        raise ValueError('one of init_checkpoint/data_path must be chosen.')
    # Training requested but no save path given
    if args.do_train and args.save_path is None:
        raise ValueError('Where do you want to save your trained model?')
    # Create the save directory if it does not exist yet
    if args.save_path and not os.path.exists(args.save_path):
        os.makedirs(args.save_path)

    # Write logs to checkpoint and console
    set_logger(args)
    # The dict files list every entity/relation, one "<id>\t<name>" pair per line
    # entity2id / relation2id map each entity/relation name to its integer index
    with open(os.path.join(args.data_path, 'entities.dict')) as fin:
        entity2id = dict()
        for line in fin:
            eid, entity = line.strip().split(
                '\t')  # e.g. "1\txx" -> eid="1", entity="xx"
            entity2id[entity] = int(eid)

    with open(os.path.join(args.data_path, 'relations.dict')) as fin:
        relation2id = dict()
        for line in fin:
            rid, relation = line.strip().split('\t')
            relation2id[relation] = int(rid)

    # Read regions for Countries S* datasets
    if args.countries:
        regions = list()
        # only the Countries datasets ship a regions.list file
        with open(os.path.join(args.data_path, 'regions.list')) as fin:
            for line in fin:  # only 5 lines
                region = line.strip()
                regions.append(entity2id[region])  # store the region's entity index
        args.regions = regions

    nentity = len(entity2id)
    nrelation = len(relation2id)

    args.nentity = nentity
    args.nrelation = nrelation

    logging.info('Model: %s' % args.model)
    logging.info('Data Path: %s' % args.data_path)
    logging.info('#entity: %d' % nentity)
    logging.info('#relation: %d' % nrelation)
    # Load the train/valid/test triples and log their sizes
    train_triples = read_triple(os.path.join(args.data_path, 'train.txt'),
                                entity2id, relation2id)
    logging.info('#train: %d' % len(train_triples))
    valid_triples = read_triple(os.path.join(args.data_path, 'valid.txt'),
                                entity2id, relation2id)
    logging.info('#valid: %d' % len(valid_triples))
    test_triples = read_triple(os.path.join(args.data_path, 'test.txt'),
                               entity2id, relation2id)
    logging.info('#test: %d' % len(test_triples))

    # All true triples
    all_true_triples = train_triples + valid_triples + test_triples
    # Build the KGE model
    kge_model = KGEModel(
        model_name=args.model,
        nentity=nentity,
        nrelation=nrelation,
        hidden_dim=args.hidden_dim,
        gamma=args.gamma,
        double_entity_embedding=args.double_entity_embedding,
        double_relation_embedding=args.double_relation_embedding)

    logging.info('Model Parameter Configuration:')
    for name, param in kge_model.named_parameters():
        logging.info('Parameter %s: %s, require_grad = %s' %
                     (name, str(param.size()), str(param.requires_grad)))

    if args.cuda:
        kge_model = kge_model.cuda()

    if args.do_train:
        # Set training dataloader iterator
        train_dataloader_head = DataLoader(
            TrainDataset(train_triples, nentity, nrelation,
                         args.negative_sample_size, 'head-batch'),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=max(1, args.cpu_num // 2),
            collate_fn=TrainDataset.collate_fn)

        train_dataloader_tail = DataLoader(
            TrainDataset(train_triples, nentity, nrelation,
                         args.negative_sample_size, 'tail-batch'),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=max(1, args.cpu_num // 2),
            collate_fn=TrainDataset.collate_fn)

        train_iterator = BidirectionalOneShotIterator(train_dataloader_head,
                                                      train_dataloader_tail)

        # Set training configuration
        current_learning_rate = args.learning_rate
        optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad,
                   kge_model.parameters()),  # only optimize parameters with requires_grad=True
            lr=current_learning_rate)
        if args.warm_up_steps:
            warm_up_steps = args.warm_up_steps
        else:
            warm_up_steps = args.max_steps // 2
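            # default: start decaying the learning rate halfway through training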

    if args.init_checkpoint:
        # Restore model from checkpoint directory
        logging.info('Loading checkpoint %s...' % args.init_checkpoint)
        checkpoint = torch.load(
            os.path.join(args.init_checkpoint, 'checkpoint'))
        init_step = checkpoint['step']
        kge_model.load_state_dict(checkpoint['model_state_dict'])
        if args.do_train:
            current_learning_rate = checkpoint['current_learning_rate']
            warm_up_steps = checkpoint['warm_up_steps']
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        logging.info('Randomly Initializing %s Model...' % args.model)
        init_step = 0

    step = init_step

    logging.info('Start Training...')
    logging.info('init_step = %d' % init_step)
    logging.info('learning_rate = %f' % current_learning_rate)
    logging.info('batch_size = %d' % args.batch_size)
    logging.info('negative_adversarial_sampling = %d' %
                 args.negative_adversarial_sampling)
    logging.info('hidden_dim = %d' % args.hidden_dim)
    logging.info('gamma = %f' % args.gamma)
    logging.info('negative_adversarial_sampling = %s' %
                 str(args.negative_adversarial_sampling))
    if args.negative_adversarial_sampling:
        logging.info('adversarial_temperature = %f' %
                     args.adversarial_temperature)

    # Set valid dataloader as it would be evaluated during training

    if args.do_train:
        training_logs = []

        # Training Loop
        for step in range(init_step, args.max_steps):
            #train_iterator = BidirectionalOneShotIterator(train_dataloader_head, train_dataloader_tail)
            log = kge_model.train_step(kge_model, optimizer, train_iterator,
                                       args)

            training_logs.append(log)
            # Dynamically adjust the learning rate
            if step >= warm_up_steps:  # after warm_up_steps, cut the learning rate to 1/10
                current_learning_rate = current_learning_rate / 10
                logging.info('Change learning_rate to %f at step %d' %
                             (current_learning_rate, step))
                optimizer = torch.optim.Adam(
                    filter(lambda p: p.requires_grad, kge_model.parameters()),
                    lr=current_learning_rate  # rebuild the optimizer with the new learning rate
                )
                warm_up_steps = warm_up_steps * 3  # push the next decay point further out
            # Save a checkpoint every save_checkpoint_steps steps
            if step % args.save_checkpoint_steps == 0:
                save_variable_list = {
                    'step': step,
                    'current_learning_rate': current_learning_rate,
                    'warm_up_steps': warm_up_steps
                }
                save_model(kge_model, optimizer, save_variable_list, args)

            if step % args.log_steps == 0:
                metrics = {}
                for metric in training_logs[0].keys():
                    metrics[metric] = sum(
                        [log[metric]
                         for log in training_logs]) / len(training_logs)
                log_metrics('Training average', step, metrics)
                training_logs = []

            if args.do_valid and step % args.valid_steps == 0:
                logging.info('Evaluating on Valid Dataset...')
                metrics = kge_model.test_step(kge_model, valid_triples,
                                              all_true_triples, args)
                log_metrics('Valid', step, metrics)

        save_variable_list = {
            'step': step,
            'current_learning_rate': current_learning_rate,
            'warm_up_steps': warm_up_steps
        }
        save_model(kge_model, optimizer, save_variable_list, args)

    if args.do_valid:
        logging.info('Evaluating on Valid Dataset...')
        metrics = kge_model.test_step(kge_model, valid_triples,
                                      all_true_triples, args)
        log_metrics('Valid', step, metrics)

    if args.do_test:
        logging.info('Evaluating on Test Dataset...')
        metrics = kge_model.test_step(kge_model, test_triples,
                                      all_true_triples, args)
        log_metrics('Test', step, metrics)

    if args.evaluate_train:
        logging.info('Evaluating on Training Dataset...')
        metrics = kge_model.test_step(kge_model, train_triples,
                                      all_true_triples, args)
        log_metrics('Test', step, metrics)
Exemplo n.º 12
0
def main(args):
    os.environ["CUDA_VISIBLE_DEVICES"] = args.CUDA_DEVISE

    args.data_path = os.path.join(args.datadir, args.dataset, 'onto_file')

    # if args.init_checkpoint:
    #     override_config(args)
    if args.data_path is None:
        raise ValueError('data_path and dataset must be chosen.')

    args.save_path = os.path.join(args.data_path, 'save_onto_embeds')

    # if args.do_train and args.save_path is None:
    #     raise ValueError('Where do you want to save your trained model?')

    if args.save_path and not os.path.exists(args.save_path):
        os.makedirs(args.save_path)

    with open(os.path.join(args.data_path, 'entities.dict')) as fin:
        entity2id = dict()
        for line in fin:
            eid, entity = line.strip().split('\t')
            entity2id[entity] = int(eid)

    with open(os.path.join(args.data_path, 'relations.dict')) as fin:
        relation2id = dict()
        for line in fin:
            rid, relation = line.strip().split('\t')
            relation2id[relation] = int(rid)

    nentity = len(entity2id)
    nrelation = len(relation2id)

    args.nentity = nentity
    args.nrelation = nrelation

    print('Model: %s' % args.model)
    # print('Data Path: %s' % args.data_path + "/" + args.dataset)
    print('#entity num: %d' % nentity)
    print('#relation num: %d' % nrelation)

    all_triples = read_triple(os.path.join(args.data_path, 'all_triples.txt'),
                              entity2id, relation2id)
    print('#total triples num: %d' % len(all_triples))

    # All true triples
    all_true_triples = all_triples

    kge_model = KGEModel(
        model_name=args.model,
        nentity=nentity,
        nrelation=nrelation,
        hidden_dim=args.hidden_dim,
        gamma=args.gamma,
        double_entity_embedding=args.double_entity_embedding,
        double_relation_embedding=args.double_relation_embedding)

    # logging.info('Model Parameter Configuration:')
    # for name, param in kge_model.named_parameters():
    #     logging.info('Parameter %s: %s, require_grad = %s' % (name, str(param.size()), str(param.requires_grad)))

    if args.cuda:
        kge_model = kge_model.cuda()

    if args.do_train:
        # Set training dataloader iterator
        train_dataloader_head = DataLoader(
            TrainDataset(all_triples, nentity, nrelation,
                         args.negative_sample_size, 'head-batch'),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=max(1, args.cpu_num // 2),
            collate_fn=TrainDataset.collate_fn)

        train_dataloader_tail = DataLoader(
            TrainDataset(all_triples, nentity, nrelation,
                         args.negative_sample_size, 'tail-batch'),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=max(1, args.cpu_num // 2),
            collate_fn=TrainDataset.collate_fn)

        train_iterator = BidirectionalOneShotIterator(train_dataloader_head,
                                                      train_dataloader_tail)

        # Set training configuration
        current_learning_rate = args.learning_rate

        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                            kge_model.parameters()),
                                     lr=current_learning_rate)
        if args.warm_up_steps:
            warm_up_steps = args.warm_up_steps
        else:
            warm_up_steps = args.max_steps // 2

    print('Randomly Initializing %s Model...' % args.model)

    # step = init_step

    print('------ Start Training...')
    print('batch_size = %d' % args.batch_size)
    print('negative sample size = %d' % args.negative_sample_size)
    print('hidden_dim = %d' % args.hidden_dim)
    print('gamma = %f' % args.gamma)
    print('negative_adversarial_sampling = %s' %
          str(args.negative_adversarial_sampling))

    if args.negative_adversarial_sampling:
        print('adversarial_temperature = %f' % args.adversarial_temperature)

    print("learning rate = %f" % current_learning_rate)

    # Set valid dataloader as it would be evaluated during training

    if args.do_train:

        train_losses = []

        # Training Loop
        for step in range(1, args.max_steps + 1):

            loss_values = kge_model.train_step(kge_model, optimizer,
                                               train_iterator, args)

            train_losses.append(loss_values)

            if step >= warm_up_steps:
                current_learning_rate = current_learning_rate / 10
                print('Change learning_rate to %f at step %d' %
                      (current_learning_rate, step))
                optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                                    kge_model.parameters()),
                                             lr=current_learning_rate)
                warm_up_steps = warm_up_steps * 3

            if step % args.print_steps == 0:
                pos_sample_loss = sum(
                    [losses['pos_sample_loss']
                     for losses in train_losses]) / len(train_losses)
                neg_sample_loss = sum(
                    [losses['neg_sample_loss']
                     for losses in train_losses]) / len(train_losses)
                loss1 = sum([losses['loss']
                             for losses in train_losses]) / len(train_losses)

                # log_metrics('Training average', step, metrics)
                print(
                    'Training Step: %d; average -> pos_sample_loss: %f; neg_sample_loss: %f; loss: %f'
                    % (step, pos_sample_loss, neg_sample_loss, loss1))
                train_losses = []

            if step % args.save_steps == 0:
                save_embeddings(kge_model, step, args)

            if args.evaluate_train and step % args.valid_steps == 0:
                print('------ Evaluating on Training Dataset...')
                metrics = kge_model.test_step(kge_model, all_triples,
                                              all_true_triples, args)
                log_metrics('Test', step, metrics)
Exemplo n.º 13
0
def main(args):
    if (not args.do_train) and (not args.do_valid) and (not args.do_test):
        raise ValueError('one of train/val/test mode must be chosen.')

    if args.init_checkpoint:
        override_config(args)
    elif args.data_path is None:
        raise ValueError('one of init_checkpoint/data_path must be chosen.')
    '''
    #if need to store the outputs
    elif args.storefile is None:
        raise ValueError('where do you want to store your result?')
    '''
    if args.do_train and args.save_path is None:
        raise ValueError('Where do you want to save your trained model?')

    if args.save_path and not os.path.exists(args.save_path):
        os.makedirs(args.save_path)

    # Write logs to checkpoint and console
    set_logger(args)

    with open(os.path.join(args.data_path, 'entities.txt')) as fin:
        all_entitys = list()
        for line in fin:
            entity, eid = line.strip().split('\t')
            all_entitys.append(eid)

    # define the entity file and input data
    with open(os.path.join(args.data_path, 'entities.txt')) as fin:
        entity2id = dict()
        for line in fin:
            entity, eid = line.strip().split('\t')
            entity2id[entity] = int(eid)

    nentity = len(entity2id)

    args.nentity = nentity

    logging.info('Model: %s' % args.model)
    logging.info('Data Path: %s' % args.data_path)
    logging.info('#entity: %d' % nentity)

    # load all datasets
    train_triples = read_triple(os.path.join(args.data_path, 'train.txt'),
                                entity2id)
    logging.info('#train: %d' % len(train_triples))
    valid_triples = read_triple(os.path.join(args.data_path, 'valid.txt'),
                                entity2id)
    logging.info('#valid: %d' % len(valid_triples))
    test_triples = read_triple(os.path.join(args.data_path, 'test.txt'),
                               entity2id)
    logging.info('#test: %d' % len(test_triples))
    addtrain_triples = read_triple(os.path.join(args.data_path, 'test.txt'),
                                   entity2id)
    logging.info('#add_train:%d' % len(addtrain_triples))
    #All true triples
    all_true_triples = train_triples + valid_triples + test_triples + addtrain_triples
    train_triples = train_triples  # + addtrain_triples for the augmented-training experiment; otherwise leave addtrain out
    '''
    #load output mode dataset
    train_triples = read_triple(os.path.join(args.data_path, 'train_origin.txt'), entity2id, relation2id)
    logging.info('#train: %d' % len(train_triples))
    valid_triples = read_triple(os.path.join(args.data_path, 'valid_1.txt'), entity2id, relation2id)
    logging.info('#valid: %d' % len(valid_triples))
    test_triples = read_triple(os.path.join(args.data_path, 'test_1.txt'), entity2id, relation2id)
    logging.info('#test: %d' % len(test_triples))
    #All true triples
    all_true_triples = train_triples + valid_triples + test_triples
    train_triples = train_triples + valid_triples
    '''

    bdg_model = Bridge_rules(
        nentity=nentity,
        hidden_dim=args.hidden_dim,
        gamma=args.gamma,
    )

    logging.info('Model Parameter Configuration:')
    for name, param in bdg_model.named_parameters():
        logging.info('Parameter %s: %s, require_grad = %s' %
                     (name, str(param.size()), str(param.requires_grad)))

    if args.cuda:
        bdg_model = bdg_model.cuda()

    if args.do_train:
        # Set training dataloader iterator
        train_dataloader_head = DataLoader(
            TrainDataset(train_triples, nentity, args.negative_sample_size,
                         'head-batch'),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=max(1, args.cpu_num // 2),
            collate_fn=TrainDataset.collate_fn)

        train_dataloader_tail = DataLoader(
            TrainDataset(train_triples, nentity, args.negative_sample_size,
                         'tail-batch'),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=max(1, args.cpu_num // 2),
            collate_fn=TrainDataset.collate_fn)

        train_iterator = BidirectionalOneShotIterator(train_dataloader_head,
                                                      train_dataloader_tail)

        # Set training configuration
        current_learning_rate = args.learning_rate
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                            bdg_model.parameters()),
                                     lr=current_learning_rate)
        if args.warm_up_steps:
            warm_up_steps = args.warm_up_steps
        else:
            warm_up_steps = args.max_steps // 2

    if args.init_checkpoint:
        # Restore model from checkpoint directory
        logging.info('Loading checkpoint %s...' % args.init_checkpoint)
        checkpoint = torch.load(
            os.path.join(args.init_checkpoint, 'checkpoint'))
        init_step = checkpoint['step']
        bdg_model.load_state_dict(checkpoint['model_state_dict'])
        if args.do_train:
            current_learning_rate = checkpoint['current_learning_rate']
            warm_up_steps = checkpoint['warm_up_steps']
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        logging.info('Randomly Initializing %s Model...' % args.model)
        init_step = 0

    step = init_step

    logging.info('Start Training...')
    logging.info('init_step = %d' % init_step)
    logging.info('batch_size = %d' % args.batch_size)
    logging.info('negative_adversarial_sampling = %d' %
                 args.negative_adversarial_sampling)
    logging.info('hidden_dim = %d' % args.hidden_dim)
    logging.info('gamma = %f' % args.gamma)
    logging.info('negative_adversarial_sampling = %s' %
                 str(args.negative_adversarial_sampling))
    if args.negative_adversarial_sampling:
        logging.info('adversarial_temperature = %f' %
                     args.adversarial_temperature)

    # Set valid dataloader as it would be evaluated during training

    if args.do_train:
        logging.info('learning_rate = %f' % current_learning_rate)

        training_logs = []

        #Training Loop
        for step in range(init_step, args.max_steps):

            log = bdg_model.train_step(bdg_model, optimizer, train_iterator,
                                       args)

            training_logs.append(log)

            if step >= warm_up_steps:
                current_learning_rate = current_learning_rate / 10
                logging.info('Change learning_rate to %f at step %d' %
                             (current_learning_rate, step))
                optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                                    bdg_model.parameters()),
                                             lr=current_learning_rate)
                warm_up_steps = warm_up_steps * 3

            if step % args.save_checkpoint_steps == 0:
                save_variable_list = {
                    'step': step,
                    'current_learning_rate': current_learning_rate,
                    'warm_up_steps': warm_up_steps
                }
                save_model(bdg_model, optimizer, save_variable_list, args)

            if step % args.log_steps == 0:
                metrics = {}
                for metric in training_logs[0].keys():
                    metrics[metric] = sum(
                        [log[metric]
                         for log in training_logs]) / len(training_logs)
                log_metrics('Training average', step, metrics)
                training_logs = []

            if args.do_valid and step % args.valid_steps == 0:
                logging.info('Evaluating on Valid Dataset...')
                metrics = bdg_model.test_step(bdg_model, valid_triples,
                                              all_true_triples, args)
                log_metrics('Valid', step, metrics)

        save_variable_list = {
            'step': step,
            'current_learning_rate': current_learning_rate,
            'warm_up_steps': warm_up_steps
        }
        save_model(bdg_model, optimizer, save_variable_list, args)

    if args.do_valid:
        logging.info('Evaluating on Valid Dataset...')
        metrics = bdg_model.test_step(bdg_model, valid_triples,
                                      all_true_triples, args)
        log_metrics('Valid', step, metrics)

    if args.do_test:
        logging.info('Evaluating on Test Dataset...')
        metrics = bdg_model.test_step(bdg_model, test_triples,
                                      all_true_triples, args)
        log_metrics('Test', step, metrics)

    if args.evaluate_train:
        logging.info('Evaluating on Training Dataset...')
        metrics = bdg_model.test_step(bdg_model, train_triples,
                                      all_true_triples, args)
        log_metrics('Test', step, metrics)

    if args.output:
        logging.info('Output the entities rank...')
        storefile = 'score.txt'
        bdg_model.store_score(bdg_model, test_triples, all_true_triples,
                              storefile, args)
Exemplo n.º 14
0
    def train(self, train_triples, time, entity2id, relation2id, valid_triples):
        self.nentity = len(entity2id)
        self.nrelation = len(relation2id)

        current_learning_rate = 0.0001
        if time == -1:
            # self.kge_model = KGEModel(
            #     ent_tot=self.nentity,
            #     rel_tot=self.nrelation,
            #     dim_e=50,
            #     dim_r=50
            # )
            # self.kge_model = self.kge_model.cuda()
            # self.optimizer = torch.optim.Adam(
            #     filter(lambda p: p.requires_grad, self.kge_model.parameters()),
            #     lr=current_learning_rate
            # )
            # init_step = 0
        #else:
        #    temp = time - 1
            self.kge_model, self.optimizer, step_loaded = self.load_model(time)
            init_step = step_loaded
        else:
            init_step = 0

        # Set training dataloader iterator
        train_dataloader_head = DataLoader(
            TrainDataset(train_triples, self.nentity, self.nrelation, round(len(train_triples)*0.1), 'head-batch'),
            batch_size=50,
            shuffle=False,
            num_workers=2,
            collate_fn=TrainDataset.collate_fn
        )
        train_dataloader_tail = DataLoader(
            TrainDataset(train_triples, self.nentity, self.nrelation, round(len(train_triples)*0.1), 'tail-batch'),
            batch_size=50,
            shuffle=False,
            num_workers=2,
            collate_fn=TrainDataset.collate_fn
        )
        warm_up_steps = 5000 // 2
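        # note: the learning-rate decay block inside the loop below is commented
        # out, so warm_up_steps is effectively unused here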
        train_iterator = BidirectionalOneShotIterator(train_dataloader_head, train_dataloader_tail)


        if self.isCUDA == 1:
            self.kge_model = self.kge_model.cuda()

        #start training
        print("start training:%d"%time)
        # Training Loop
        starttime = Time.time()
        if time==-1:
            steps = 3000*4
            printnum = 100
        else:
            steps = 150
            printnum = 50
        for step in range(init_step, steps):
            loss = self.kge_model.train_step(self.kge_model, self.optimizer, train_iterator,self.isCUDA)
            '''
            if step >= warm_up_steps:
                current_learning_rate = current_learning_rate / 10
                logging.info('Change learning_rate to %f at step %d' % (current_learning_rate, step))
                self.optimizer = torch.optim.Adam(
                   filter(lambda p: p.requires_grad, self.kge_model.parameters()),
                   lr=current_learning_rate
                )
                warm_up_steps = warm_up_steps * 3
            '''
            if step % printnum == 0:
                endtime = Time.time()
                print("step: %d, cost time: %s, loss is %.4f" % (step, round((endtime - starttime), 3), loss))
                '''self.save_model(self.kge_model, self.optimizer, time)
                result_head, result_tail = self.evaluate(valid_triples, train_triples + valid_triples, relation2id, entity2id, time)
                print(result_head)
                print(result_tail)
                self.kge_model, self.optimizer = self.load_model(time)'''



            self.save_model(self.kge_model, self.optimizer, time, step)
Exemplo n.º 15
0
def main(args):
    if not torch.cuda.is_available():
        args.cuda = False

    if args.ruge:
        args.loss = 'ruge'

    if (not args.do_train) and (not args.do_valid) and (not args.do_test) and (
            not args.do_experiment) and (not args.do_grid):
        raise ValueError('one of train/val/test mode must be chosen.')

    if args.init_checkpoint:
        override_config(args)

    elif args.data_path is None:
        raise ValueError('one of init_checkpoint/data_path must be chosen.')

    if args.do_train and args.save_path is None:
        raise ValueError('Where do you want to save your trained model?')

    if args.save_path and not os.path.exists(args.save_path):
        os.makedirs(args.save_path)

    # Write logs to checkpoint and console

    set_logger(args)
    if args.regularization != 0:
        print('L3 regularization with coeff - ', args.regularization)
    if args.l2_r != 0:
        print('L2 regularization with coeff - ', args.l2_r)
    if args.project != 0:
        print('projecting before training')
    #logging.info('Inverse loss = premise - concl (reverse)')
    if OPT_STOPPING:
        logging.info('Opt stopping is ON')
        print('Opt stopping is on')

    with open(os.path.join(args.data_path, 'entities.dict')) as fin:
        entity2id = dict()
        for line in fin:
            eid, entity = line.strip().split('\t')
            entity2id[entity] = int(eid)

    with open(os.path.join(args.data_path, 'relations.dict')) as fin:
        relation2id = dict()
        for line in fin:
            rid, relation = line.strip().split('\t')
            relation2id[relation] = int(rid)

    # Read regions for Countries S* datasets
    if args.countries:
        regions = list()
        with open(os.path.join(args.data_path, 'regions.list')) as fin:
            for line in fin:
                region = line.strip()
                regions.append(entity2id[region])
        args.regions = regions

    nentity = len(entity2id)
    nrelation = len(relation2id)

    args.nentity = nentity
    args.nrelation = nrelation

    if args.inject:
        logging.info('With rule injection')
    else:
        logging.info('NO INJECTION')

    logging.info('Model: %s' % args.model)
    logging.info('Data Path: %s' % args.data_path)
    logging.info('#entity: %d' % nentity)
    logging.info('#relation: %d' % nrelation)

    train_triples = read_triple(os.path.join(args.data_path, 'train.txt'),
                                entity2id, relation2id)
    logging.info('#train: %d' % len(train_triples))
    valid_triples = read_triple(os.path.join(args.data_path, 'valid.txt'),
                                entity2id, relation2id)
    logging.info('#valid: %d' % len(valid_triples))
    test_triples = read_triple(
        os.path.join(args.data_path, 'test.txt'), entity2id,
        relation2id)  # For testing on Symmetric in WordNet: Symmetric_testWN18
    logging.info('#test: %d' % len(test_triples))

    #All true triples
    all_true_triples = train_triples + valid_triples + test_triples
    train_args = {}

    # set up rule iterators
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    n_batches = len(train_triples) // args.batch_size
    if n_batches < len(train_triples) / args.batch_size: n_batches += 1
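    # i.e. n_batches = ceil(len(train_triples) / batch_size)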
    rule_iterators = {}
    rules_info = ''
    if args.inv:
        n_inverse, inverse_batchsize, rule_iterators[
            'inverse'] = setup_rule_loader(n_batches, args.batch_size,
                                           args.data_path,
                                           'groundings_inverse.txt', device,
                                           RULE_BATCH_SIZE_INV)
        rules_info += 'Inverse: batch size %d out of %d rules' % (
            inverse_batchsize, n_inverse) + '\n'
    if args.eq:
        n_eq, eq_batchsize, rule_iterators['equality'] = setup_rule_loader(
            n_batches, args.batch_size, args.data_path,
            'groundings_equality.txt', device, RULE_BATCH_SIZE_EQ)
        rules_info += 'Equality: batch size %d out of %d rules' % (
            eq_batchsize, n_eq) + '\n'
    if args.impl:
        n_impl, impl_batchsize, rule_iterators[
            'implication'] = setup_rule_loader(n_batches, args.batch_size,
                                               args.data_path,
                                               'groundings_implication.txt',
                                               device, RULE_BATCH_SIZE_IMPL)
        rules_info += 'implication: batch size %d out of %d rules\n' % (
            impl_batchsize, n_impl)
    if args.sym:
        n_symmetry, sym_batchsize, rule_iterators[
            'symmetry'] = setup_rule_loader(n_batches, args.batch_size,
                                            args.data_path,
                                            'groundings_symmetric.txt', device,
                                            RULE_BATCH_SIZE_SYM)
        rules_info += 'symmetry: batch size %d out of %d rules\n' % (
            sym_batchsize, n_symmetry)
    if args.ruge or args.ruge_inject:
        n_rules, rule_iterators['ruge'] = construct_ruge_loader(
            n_batches, args)
        rules_info += 'RUGE: Total %d rules\n' % n_rules

    if rules_info:
        logging.info(rules_info)

    # ----------- adversarial ------------------
    if args.adversarial:
        clauses_filename = os.path.join(args.data_path, 'clauses_0.9.pl')
        adv_clauses, clentity2id = dt.read_clauses(clauses_filename,
                                                   relation2id)
        n_clause_entities = len(clentity2id)
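        # mult matches the adversarial model's embedding width to the KGE entity
        # embedding: 1 for TransE/pRotatE, 4 for QuatE, 2 otherwise (presumably
        # the complex-valued / doubled models)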
        mult = 2
        if args.model in ['TransE', 'pRotatE']: mult = 1
        if 'QuatE' in args.model: mult = 4
        adv_model = ADVModel(clauses=adv_clauses,
                             n_entities=len(clentity2id),
                             dim=mult * args.hidden_dim,
                             use_cuda=args.cuda)
        if args.cuda:
            adv_model = adv_model.cuda()
    else:
        adv_model = None

    if args.do_grid:
        if rules_info:
            print(rules_info)
        run_grid(nentity, nrelation, train_triples, valid_triples,
                 test_triples, all_true_triples, args, rule_iterators,
                 adv_model)
        exit()
    ntriples = len(train_triples)
    kge_model = KGEModel(
        model_name=args.model,
        nentity=nentity,
        nrelation=nrelation,
        ntriples=ntriples,
        hidden_dim=args.hidden_dim,
        gamma=args.gamma,
    )
    kge_model.set_loss(args.loss)

    logging.info('Model Parameter Configuration:')
    for name, param in kge_model.named_parameters():
        logging.info('Parameter %s: %s, require_grad = %s' %
                     (name, str(param.size()), str(param.requires_grad)))
    logging.info('Loss function %s' % args.loss)
    if args.cuda and args.parallel:
        gpus = [0, 1]
        os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(x) for x in gpus)
        kge_model.cuda()
        kge_model = torch.nn.DataParallel(kge_model, device_ids=[0, 1])

    elif args.cuda:
        kge_model = kge_model.cuda()

    if args.do_train or args.do_experiment:
        # Set training dataloader iterator
        train_dataloader_head = DataLoader(
            TrainDataset(train_triples, nentity, nrelation,
                         args.negative_sample_size, 'head-batch'),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=max(1, args.cpu_num // 2),
            collate_fn=TrainDataset.collate_fn)

        train_dataloader_tail = DataLoader(
            TrainDataset(train_triples, nentity, nrelation,
                         args.negative_sample_size, 'tail-batch'),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=max(1, args.cpu_num // 2),
            collate_fn=TrainDataset.collate_fn)

        train_iterator = BidirectionalOneShotIterator(train_dataloader_head,
                                                      train_dataloader_tail)

        # Set training configuration
        current_learning_rate = args.learning_rate
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                            kge_model.parameters()),
                                     lr=current_learning_rate)

    if args.init_checkpoint:
        # Restore model from checkpoint directory
        logging.info('Loading checkpoint %s...' % args.init_checkpoint)
        checkpoint = torch.load(
            os.path.join(args.init_checkpoint, 'checkpoint'))
        init_step = checkpoint['step']
        kge_model.load_state_dict(checkpoint['model_state_dict'])
        if args.do_train:
            current_learning_rate = checkpoint['current_learning_rate']
            warm_up_steps = checkpoint['warm_up_steps']
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        logging.info('Randomly Initializing %s Model...' % args.model)
        init_step = 0

    step = init_step

    logging.info('Start Training...')
    logging.info('init_step = %d' % init_step)
    logging.info('learning_rate = %f' % current_learning_rate)
    logging.info('batch_size = %d' % args.batch_size)
    logging.info('negative_adversarial_sampling = %d' %
                 args.negative_adversarial_sampling)
    logging.info('hidden_dim = %d' % args.hidden_dim)
    logging.info('gamma = %f' % args.gamma)
    logging.info('negative_adversarial_sampling = %s' %
                 str(args.negative_adversarial_sampling))
    if args.negative_adversarial_sampling:
        logging.info('adversarial_temperature = %f' %
                     args.adversarial_temperature)

    # Set valid dataloader as it would be evaluated during training

    if args.do_train:
        train_model(init_step, valid_triples, all_true_triples, kge_model,
                    train_iterator, len(train_triples), args)

    if args.evaluate_train:
        logging.info('Evaluating on Training Dataset...')
        model_module = kge_model.module if args.parallel else kge_model
        metrics = model_module.test_step(kge_model, train_triples,
                                         all_true_triples, args)
        #metrics1 = model_module.getScore(kge_model, train_triples, all_true_triples, args)
        log_metrics('Test', step, metrics)

    # experiment on the updated function
    if args.do_experiment:
        logging.info('\n\nSTARTING EXPERIMENT\n')

        train_model(init_step, valid_triples, all_true_triples, kge_model,
                    train_iterator, rule_iterators, args)

    if args.do_valid:
        logging.info('Evaluating on Valid Dataset...')
        model_module = kge_model.module if args.parallel else kge_model
        metrics = model_module.test_step(kge_model, valid_triples,
                                         all_true_triples, args)
        #metrics1 = model_module.getScore(kge_model, train_triples, all_true_triples, args)
        log_metrics('Valid', step, metrics)

    if args.do_test:
        logging.info('Evaluating on Test Dataset...')
        model_module = kge_model.module if args.parallel else kge_model
        metrics = model_module.test_step(kge_model, test_triples,
                                         all_true_triples, args)
        log_metrics('Test', step, metrics)

    if args.evaluate_train:
        logging.info('Evaluating on Training Dataset...')
        model_module = kge_model.module if args.parallel else kge_model
        metrics = model_module.test_step(kge_model, train_triples,
                                         all_true_triples, args)
        log_metrics('Test', step, metrics)
Exemplo n.º 16
0
Arquivo: run.py Projeto: zyksir/NoiGAN
def main(args):
    if (not args.do_train) and (not args.do_valid) and (not args.do_test):
        raise ValueError('one of train/val/test mode must be chosen.')

    if not args.do_train and args.init_checkpoint:
        override_config(args)
    elif args.data_path is None:
        raise ValueError('one of init_checkpoint/data_path must be chosen.')

    if args.do_train and args.save_path is None:
        raise ValueError('Where do you want to save your trained model?')

    if args.save_path and not os.path.exists(args.save_path):
        os.makedirs(args.save_path)

    # Write logs to checkpoint and console
    set_logger(args)

    with open(os.path.join(args.data_path, 'entities.dict')) as fin:
        entity2id = dict()
        for line in fin:
            eid, entity = line.strip().split('\t')
            entity2id[entity] = int(eid)

    with open(os.path.join(args.data_path, 'relations.dict')) as fin:
        relation2id = dict()
        for line in fin:
            rid, relation = line.strip().split('\t')
            relation2id[relation] = int(rid)

    nentity = len(entity2id)
    nrelation = len(relation2id)

    args.nentity = nentity
    args.nrelation = nrelation

    logging.info('Model: %s' % args.model)
    logging.info('Data Path: %s' % args.data_path)
    logging.info('#entity: %d' % nentity)
    logging.info('#relation: %d' % nrelation)

    train_triples = read_triple(os.path.join(args.data_path, "train.txt"),
                                entity2id, relation2id)
    if args.self_test:
        train_triples = train_triples[len(train_triples) // 5:]
    if args.fake:
        fake_triples = pickle.load(
            open(os.path.join(args.data_path, "fake%s.pkl" % args.fake), "rb"))
        fake = torch.LongTensor(fake_triples)
        train_triples += fake_triples
    else:
        fake_triples = [(0, 0, 0)]
        fake = torch.LongTensor(fake_triples)
    if args.cuda:
        fake = fake.cuda()
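    # The fake (noisy) triples are mixed into train_triples above and also handed
    # to the trainer below, presumably so it can learn confidence weights that
    # down-weight them during KGE training.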
    logging.info('#train: %d' % len(train_triples))
    valid_triples = read_triple(os.path.join(args.data_path, 'valid.txt'),
                                entity2id, relation2id)
    logging.info('#valid: %d' % len(valid_triples))
    test_triples = read_triple(os.path.join(args.data_path, 'test.txt'),
                               entity2id, relation2id)
    logging.info('#test: %d' % len(test_triples))

    all_true_triples = train_triples + valid_triples + test_triples

    kge_model = KGEModel(
        model_name=args.model,
        nentity=nentity,
        nrelation=nrelation,
        hidden_dim=args.hidden_dim,
        gamma=args.gamma,
        double_entity_embedding=args.double_entity_embedding,
        double_relation_embedding=args.double_relation_embedding)
    trainer = None
    if args.method == "CLF":
        trainer = ClassifierTrainer(train_triples, fake_triples, args,
                                    kge_model, args.hard)
    elif args.method == "LT":
        trainer = LTTrainer(train_triples, fake_triples, args, kge_model)
    elif args.method == "NoiGAN":
        trainer = NoiGANTrainer(train_triples, fake_triples, args, kge_model,
                                args.hard)

    logging.info('Model Parameter Configuration:')
    for name, param in kge_model.named_parameters():
        logging.info('Parameter %s: %s, require_grad = %s' %
                     (name, str(param.size()), str(param.requires_grad)))

    if args.cuda:
        kge_model = kge_model.cuda()

    if args.do_train:
        # Set training dataloader iterator
        train_dataloader_head = DataLoader(
            TrainDataset(train_triples, nentity, nrelation,
                         args.negative_sample_size, 'head-batch'),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=max(1, args.cpu_num // 2),
            collate_fn=TrainDataset.collate_fn)

        train_dataloader_tail = DataLoader(
            TrainDataset(train_triples, nentity, nrelation,
                         args.negative_sample_size, 'tail-batch'),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=max(1, args.cpu_num // 2),
            collate_fn=TrainDataset.collate_fn)

        train_iterator = BidirectionalOneShotIterator(train_dataloader_head,
                                                      train_dataloader_tail)

        # Set training configuration
        current_learning_rate = args.learning_rate
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                            kge_model.parameters()),
                                     lr=current_learning_rate)
        if args.warm_up_steps:
            warm_up_steps = args.warm_up_steps
        else:
            warm_up_steps = args.max_steps
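            # note: this default is max_steps (not max_steps // 2), so the decay
            # branch in the training loop never fires unless warm_up_steps is set explicitly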

    if args.init_checkpoint:
        # Restore model from checkpoint directory
        logging.info('Loading checkpoint %s...' % args.init_checkpoint)
        checkpoint = torch.load(
            os.path.join(args.init_checkpoint, 'checkpoint'))
        init_step = 0  #checkpoint['step']
        kge_model.load_state_dict(checkpoint['model_state_dict'])
        if args.do_train:
            # current_learning_rate = checkpoint['current_learning_rate']
            # warm_up_steps = checkpoint['warm_up_steps']
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        logging.info('Randomly Initializing %s Model...' % args.model)
        init_step = 0

    step = init_step

    logging.info('Start Training...')
    logging.info('init_step = %d' % init_step)
    logging.info('batch_size = %d' % args.batch_size)
    logging.info('negative_adversarial_sampling = %d' %
                 args.negative_adversarial_sampling)
    logging.info('hidden_dim = %d' % args.hidden_dim)
    logging.info('gamma = %f' % args.gamma)
    logging.info('negative_adversarial_sampling = %s' %
                 str(args.negative_adversarial_sampling))
    if args.negative_adversarial_sampling:
        logging.info('adversarial_temperature = %f' %
                     args.adversarial_temperature)

    # Set valid dataloader as it would be evaluated during training

    if args.do_train:
        logging.info('learning_rate = %f' % current_learning_rate)

        training_logs = []

        #Training Loop
        triple2confidence_weights = None
        for step in range(init_step, args.max_steps):
            if args.method == "CLF" and step % args.classify_steps == 0:
                logging.info('Train Classifier')
                metrics = trainer.train_classifier(trainer)
                log_metrics('Classifier', step, metrics)
                metrics = trainer.test_ave_score(trainer)
                log_metrics('Classifier', step, metrics)
                trainer.cal_confidence_weight()
            elif args.method == "NoiGAN" and step % args.classify_steps == 0:
                logging.info('Train NoiGAN')
                trainer.train_NoiGAN(trainer)
                metrics = trainer.test_ave_score(trainer)
                log_metrics('Classifier', step, metrics)
                trainer.cal_confidence_weight()

            log = kge_model.train_step(kge_model,
                                       optimizer,
                                       train_iterator,
                                       args,
                                       trainer=trainer)

            training_logs.append(log)

            if step >= warm_up_steps:
                current_learning_rate = current_learning_rate / 10
                logging.info('Change learning_rate to %f at step %d' %
                             (current_learning_rate, step))
                optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                                    kge_model.parameters()),
                                             lr=current_learning_rate)
                warm_up_steps = warm_up_steps * 3

            if step % args.save_checkpoint_steps == 0:
                save_variable_list = {
                    'step': step,
                    'current_learning_rate': current_learning_rate,
                    'warm_up_steps': warm_up_steps
                }
                save_model(kge_model, optimizer, save_variable_list, args,
                           trainer)

            if step % args.log_steps == 0:
                metrics = {}
                for metric in training_logs[0].keys():
                    metrics[metric] = sum(
                        [log[metric]
                         for log in training_logs]) / len(training_logs)
                log_metrics('Training average', step, metrics)
                training_logs = []

            if args.do_valid and step % args.valid_steps == 0:
                logging.info('Evaluating on Valid Dataset...')
                metrics = kge_model.test_step(kge_model, valid_triples,
                                              all_true_triples, args)
                log_metrics('Valid', step, metrics)

        save_variable_list = {
            'step': step,
            'current_learning_rate': current_learning_rate,
            'warm_up_steps': warm_up_steps
        }
        save_model(kge_model, optimizer, save_variable_list, args, trainer)

    if trainer is not None:
        logging.info("Evaluating Classifier on Train Dataset")
        metrics = trainer.test_ave_score(trainer)
        log_metrics('Train', step, metrics)

    if args.do_valid:
        logging.info('Evaluating on Valid Dataset...')
        metrics = kge_model.test_step(kge_model, valid_triples,
                                      all_true_triples, args)
        log_metrics('Valid', step, metrics)

    if args.do_test:
        logging.info('Evaluating on Test Dataset...')
        metrics = kge_model.test_step(kge_model, test_triples,
                                      all_true_triples, args)
        log_metrics('Test', step, metrics)
        # logging.info("\t".join([metric for metric in metrics.values()]))

    if args.evaluate_train:
        logging.info('Evaluating on Training Dataset...')
        metrics = kge_model.test_step(kge_model, train_triples,
                                      all_true_triples, args)
        log_metrics('Test', step, metrics)
Exemplo n.º 17
0
def main(args):
    if (
        (not args.do_train)
        and (not args.do_valid)
        and (not args.do_test)
        and (not args.evaluate_train)
    ):
        raise ValueError("one of train/val/test mode must be choosed.")

    if args.init_checkpoint:
        override_config(args)

    args.save_path = (
        "log/%s/%s/%s-%s/%s"
        % (args.dataset, args.model, args.hidden_dim, args.gamma, time.time())
        if args.save_path is None
        else args.save_path
    )
    writer = SummaryWriter(args.save_path)

    # Write logs to checkpoint and console
    set_logger(args)

    dataset = LinkPropPredDataset(name=args.dataset)
    split_dict = dataset.get_edge_split()
    nentity = dataset.graph["num_nodes"]
    nrelation = int(max(dataset.graph["edge_reltype"])[0]) + 1
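    # relations are integer-coded edge types, so the count is the largest type id plus one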

    evaluator = Evaluator(name=args.dataset)

    args.nentity = nentity
    args.nrelation = nrelation

    logging.info("Model: %s" % args.model)
    logging.info("Dataset: %s" % args.dataset)
    logging.info("#entity: %d" % nentity)
    logging.info("#relation: %d" % nrelation)

    train_triples = split_dict["train"]
    logging.info("#train: %d" % len(train_triples["head"]))
    valid_triples = split_dict["valid"]
    logging.info("#valid: %d" % len(valid_triples["head"]))
    test_triples = split_dict["test"]
    logging.info("#test: %d" % len(test_triples["head"]))

    train_count, train_true_head, train_true_tail = (
        defaultdict(lambda: 4),
        defaultdict(list),
        defaultdict(list),
    )
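    # Count (head, relation) / (tail, inverse-relation) occurrences, starting from
    # a smoothing constant of 4, and record the true heads/tails per pair; these
    # presumably feed TrainDataset's frequency-based subsampling weights and
    # filtered negative sampling.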
    for i in tqdm(range(len(train_triples["head"]))):
        head, relation, tail = (
            train_triples["head"][i],
            train_triples["relation"][i],
            train_triples["tail"][i],
        )
        train_count[(head, relation)] += 1
        train_count[(tail, -relation - 1)] += 1
        train_true_head[(relation, tail)].append(head)
        train_true_tail[(head, relation)].append(tail)

    kge_model = KGEModel(
        model_name=args.model,
        nentity=nentity,
        nrelation=nrelation,
        hidden_dim=args.hidden_dim,
        gamma=args.gamma,
        double_entity_embedding=args.double_entity_embedding,
        double_relation_embedding=args.double_relation_embedding,
        evaluator=evaluator,
    )

    logging.info("Model Parameter Configuration:")
    for name, param in kge_model.named_parameters():
        logging.info(
            "Parameter %s: %s, require_grad = %s"
            % (name, str(param.size()), str(param.requires_grad))
        )

    if args.cuda:
        kge_model = kge_model.cuda()

    if args.do_train:
        # Set training dataloader iterator
        train_dataloader_head = DataLoader(
            TrainDataset(
                train_triples,
                nentity,
                nrelation,
                args.negative_sample_size,
                "head-batch",
                train_count,
                train_true_head,
                train_true_tail,
            ),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=max(1, args.cpu_num // 2),
            collate_fn=TrainDataset.collate_fn,
        )

        train_dataloader_tail = DataLoader(
            TrainDataset(
                train_triples,
                nentity,
                nrelation,
                args.negative_sample_size,
                "tail-batch",
                train_count,
                train_true_head,
                train_true_tail,
            ),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=max(1, args.cpu_num // 2),
            collate_fn=TrainDataset.collate_fn,
        )

        train_iterator = BidirectionalOneShotIterator(
            train_dataloader_head, train_dataloader_tail
        )

        # Set training configuration
        current_learning_rate = args.learning_rate
        optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, kge_model.parameters()),
            lr=current_learning_rate,
        )
        if args.warm_up_steps:
            warm_up_steps = args.warm_up_steps
        else:
            warm_up_steps = args.max_steps // 2

    if args.init_checkpoint:
        # Restore model from checkpoint directory
        logging.info("Loading checkpoint %s..." % args.init_checkpoint)
        checkpoint = torch.load(os.path.join(args.init_checkpoint, "checkpoint"))
        init_step = checkpoint["step"]
        kge_model.load_state_dict(checkpoint["model_state_dict"])
        if args.do_train:
            current_learning_rate = checkpoint["current_learning_rate"]
            warm_up_steps = checkpoint["warm_up_steps"]
            optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    else:
        logging.info("Ramdomly Initializing %s Model..." % args.model)
        init_step = 0

    step = init_step

    logging.info("Start Training...")
    logging.info("init_step = %d" % init_step)
    logging.info("batch_size = %d" % args.batch_size)
    logging.info(
        "negative_adversarial_sampling = %d" % args.negative_adversarial_sampling
    )
    logging.info("hidden_dim = %d" % args.hidden_dim)
    logging.info("gamma = %f" % args.gamma)
    logging.info(
        "negative_adversarial_sampling = %s" % str(args.negative_adversarial_sampling)
    )
    if args.negative_adversarial_sampling:
        logging.info("adversarial_temperature = %f" % args.adversarial_temperature)

    # Set valid dataloader as it would be evaluated during training

    if args.do_train:
        logging.info("learning_rate = %d" % current_learning_rate)

        training_logs = []

        # Training Loop
        for step in range(init_step, args.max_steps):

            log = kge_model.train_step(kge_model, optimizer, train_iterator, args)
            training_logs.append(log)

            if step >= warm_up_steps:
                current_learning_rate = current_learning_rate / 10
                logging.info(
                    "Change learning_rate to %f at step %d"
                    % (current_learning_rate, step)
                )
                optimizer = torch.optim.Adam(
                    filter(lambda p: p.requires_grad, kge_model.parameters()),
                    lr=current_learning_rate,
                )
                warm_up_steps = warm_up_steps * 3
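            # Step decay: each time `step` catches up with `warm_up_steps` the
            # learning rate is divided by 10 and the next decay point is
            # tripled, so decays happen at warm_up_steps, 3x, 9x, ... of it.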

            if (
                step % args.save_checkpoint_steps == 0 and step > 0
            ):  # ~ 41 seconds/saving
                save_variable_list = {
                    "step": step,
                    "current_learning_rate": current_learning_rate,
                    "warm_up_steps": warm_up_steps,
                }
                save_model(kge_model, optimizer, save_variable_list, args)

            if step % args.log_steps == 0:
                metrics = {}
                for metric in training_logs[0].keys():
                    metrics[metric] = sum([log[metric] for log in training_logs]) / len(
                        training_logs
                    )
                log_metrics("Train", step, metrics, writer)
                training_logs = []

            if args.do_valid and step % args.valid_steps == 0 and step > 0:
                logging.info("Evaluating on Valid Dataset...")
                metrics = kge_model.test_step(kge_model, valid_triples, args)
                log_metrics("Valid", step, metrics, writer)

        save_variable_list = {
            "step": step,
            "current_learning_rate": current_learning_rate,
            "warm_up_steps": warm_up_steps,
        }
        save_model(kge_model, optimizer, save_variable_list, args)

    if args.do_valid:
        logging.info("Evaluating on Valid Dataset...")
        metrics = kge_model.test_step(kge_model, valid_triples, args)
        log_metrics("Valid", step, metrics, writer)

    if args.do_test:
        logging.info("Evaluating on Test Dataset...")
        metrics = kge_model.test_step(kge_model, test_triples, args)
        log_metrics("Test", step, metrics, writer)

    if args.evaluate_train:
        logging.info("Evaluating on Training Dataset...")
        small_train_triples = {}
        indices = np.random.choice(
            len(train_triples["head"]), args.ntriples_eval_train, replace=False
        )
        for i in train_triples:
            small_train_triples[i] = train_triples[i][indices]
        metrics = kge_model.test_step(
            kge_model, small_train_triples, args, random_sampling=True
        )
        log_metrics("Train", step, metrics, writer)
Exemplo n.º 18
0
train_triples = read_triple("data/fr_en/att_triple_all")
train_dataloader_head = data.DataLoader(
    TrainDataset(train_triples, entity_count, attr_count, value_count, 256, 'head-batch'),
    batch_size=1024,
    shuffle=False,
    num_workers=4,
    collate_fn=TrainDataset.collate_fn
)
train_dataloader_tail = data.DataLoader(
    TrainDataset(train_triples, entity_count, attr_count, value_count, 256, 'tail-batch'),
    batch_size=1024,
    shuffle=False,
    num_workers=4,
    collate_fn=TrainDataset.collate_fn
)
train_iterator = BidirectionalOneShotIterator(train_dataloader_head, train_dataloader_tail)

model = TransE(entity_count, attr_count, value_count, device).to(device)
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

start_epoch_id = 1
step = 0
best_score = 0.0
epochs = 8000

# if checkpoint_path:
#     start_epoch_id, step, best_score = load_checkpoint(checkpoint_path, model, optimizer)

print(model)

t = Tester()
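
The snippet above stops right after building the iterator, model, and optimizer. A hypothetical continuation of the training loop, assuming the TransE forward pass here returns a scalar loss for one batch of positive and negative samples (that signature is an assumption and is not shown in the excerpt), could look like:

for epoch_id in range(start_epoch_id, epochs + 1):
    model.train()
    # One alternating head-batch / tail-batch sample per step.
    positive_sample, negative_sample, subsampling_weight, mode = next(train_iterator)
    positive_sample = positive_sample.to(device)
    negative_sample = negative_sample.to(device)
    loss = model(positive_sample, negative_sample, mode)  # assumed forward signature
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    step += 1
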
Exemplo n.º 19
0
def main(args):

    # fl == 0 : entity graph
    if args.fl == 0:

        if (not args.do_train_en) and (not args.do_valid_en) and (
                not args.do_test_en):
            raise ValueError('One of train/valid/test modes must be chosen.')

        if args.init_checkpoint_en:
            override_config(args)
        elif args.data_path is None:
            raise ValueError(
                'One of init_checkpoint/data_path must be chosen.')

        if args.do_train_en and args.save_path_en is None:
            raise ValueError('Where do you want to save your trained model?')

        if args.save_path_en and not os.path.exists(args.save_path_en):
            os.makedirs(args.save_path_en)

        set_logger(args)
        # load all entities from final_entity_order.txt
        entity2id = dict()
        with open(os.path.join(args.data_path, 'final_entity_order.txt'),
                  encoding='utf8') as fin:
            for line in fin.readlines():
                eid, entity = line.strip().split('\t')
                entity2id[entity] = int(eid)
        nentity = len(entity2id)

        # load all entity relations from ffinal_en_relation_order.txt
        entity_relation2id = dict()
        with open(os.path.join(args.data_path,
                               'ffinal_en_relation_order.txt')) as fin:
            for line in fin:
                en_reid, entity_relation = line.strip().split('\t')

                entity_relation2id[entity_relation] = int(en_reid)
        nentity_relation = len(entity_relation2id)

        args.nentity = nentity
        args.nentity_re = nentity_relation

        logging.info('entity_Model: %s' % args.entity_model)
        logging.info('Data Path: %s' % args.data_path)
        logging.info('number of entity: %d' % nentity)
        logging.info('number of entity_re: %d' % nentity_relation)

        # load the training triple set
        train_entity_triples = read_triple(
            os.path.join(args.data_path, 'train_entity_Graph.txt'), entity2id,
            entity_relation2id)
        logging.info('#train en_en triples: %s' % len(train_entity_triples))
        # load the validation triple set
        val_entity_triples = read_triple(
            os.path.join(args.data_path, 'val_entity_Graph.txt'), entity2id,
            entity_relation2id)
        logging.info('#val en_en triples: %s' % len(val_entity_triples))
        # load the test triple set
        test_entity_triples = read_triple(
            os.path.join(args.data_path, 'test_entity_Graph.txt'), entity2id,
            entity_relation2id)
        logging.info('#test en_en triples: %s' % len(test_entity_triples))
        # all_triples
        all_true_triples_entity = train_entity_triples + val_entity_triples + test_entity_triples

        # entity model
        entity_kge_model = DKGE_Model(
            model_name=args.entity_model,
            nnode=args.nentity,
            nnode_re=args.nentity_re,
            hidden_dim=args.hidden_dim_en,
            gamma=args.gamma_en,
            gamma_intra=args.gamma_intra,
            double_node_embedding=args.double_node_embedding_en,
            double_node_re_embedding=args.double_node_re_embedding_en)
        logging.info('Entity Model Parameter Configuration:')
        for name, param in entity_kge_model.named_parameters():
            logging.info('Parameter %s: %s, require_grad = %s' %
                         (name, str(param.size()), str(param.requires_grad)))
        if args.cuda:
            entity_kge_model = entity_kge_model.cuda()

        if args.do_train_en:
            train_dataloader_head_en = DataLoader(
                TrainDataset_en(train_entity_triples, nentity,
                                nentity_relation, args.negative_sample_size_en,
                                'head-batch'),
                batch_size=args.batch_size_en,
                shuffle=True,
                num_workers=int(max(1, args.cpu_num // 2)),
                collate_fn=TrainDataset_en.collate_fn)
            train_dataloader_tail_en = DataLoader(
                TrainDataset_en(train_entity_triples, nentity,
                                nentity_relation, args.negative_sample_size_en,
                                'tail-batch'),
                batch_size=args.batch_size_en,
                shuffle=True,
                num_workers=int(max(1, args.cpu_num // 2)),
                collate_fn=TrainDataset_en.collate_fn)
            train_iterator_en = BidirectionalOneShotIterator(
                train_dataloader_head_en, train_dataloader_tail_en)
            # Set training configuration
            current_learning_rate_en = args.learning_rate_en
            optimizer_en = torch.optim.Adam(filter(
                lambda p: p.requires_grad, entity_kge_model.parameters()),
                                            lr=current_learning_rate_en,
                                            amsgrad=True)
            if args.warm_up_steps_en:
                warm_up_steps_en = args.warm_up_steps_en
            else:
                warm_up_steps_en = args.max_steps_en // 2

        if args.init_checkpoint_en:
            # Restore model from checkpoint directory
            logging.info('Loading checkpoint %s...' % args.init_checkpoint_en)
            checkpoint_en = torch.load(
                os.path.join(args.init_checkpoint_en, 'checkpoint_en'))
            init_step_en = checkpoint_en['step']
            entity_kge_model.load_state_dict(checkpoint_en['model_state_dict'])
            if args.do_train_en:
                current_learning_rate_en = checkpoint_en[
                    'current_learning_rate']
                warm_up_steps_en = checkpoint_en['warm_up_steps']
                optimizer_en.load_state_dict(
                    checkpoint_en['optimizer_state_dict'])
        else:
            logging.info('Randomly Initializing %s Model...' %
                         args.entity_model)
            init_step_en = 0

        step_en = init_step_en
        logging.info('Start Training of entity graph...')
        logging.info('init_step_en = %d' % init_step_en)
        logging.info('learning_rate_en = %f' % current_learning_rate_en)
        logging.info('batch_size_en = %d' % args.batch_size_en)
        logging.info('hidden_dim_en = %d' % args.hidden_dim_en)
        logging.info('gamma_en = %f' % args.gamma_en)

        if args.do_train_en:
            training_logs_en = []
            for step_en in range(init_step_en, args.max_steps_en):
                log_en = entity_kge_model.train_step(entity_kge_model,
                                                     optimizer_en,
                                                     train_iterator_en, args)
                training_logs_en.append(log_en)

                if step_en >= warm_up_steps_en:
                    current_learning_rate_en = current_learning_rate_en / 10
                    logging.info('Changing learning_rate_en to %f at step %d' %
                                 (current_learning_rate_en, step_en))
                    optimizer_en = torch.optim.Adam(
                        filter(lambda p: p.requires_grad,
                               entity_kge_model.parameters()),
                        lr=current_learning_rate_en,
                        amsgrad=True)
                    warm_up_steps_en = warm_up_steps_en * 3

                if step_en % args.save_checkpoint_steps_en == 0:
                    save_variable_list_en = {
                        'step': step_en,
                        'current_learning_rate': current_learning_rate_en,
                        'warm_up_steps': warm_up_steps_en
                    }
                    save_model(entity_kge_model, optimizer_en,
                               save_variable_list_en, args)

                if step_en % args.log_steps_en == 0:
                    metrics_en = {}
                    for metric_en in training_logs_en[0].keys():
                        metrics_en[metric_en] = sum([
                            log_en[metric_en] for log_en in training_logs_en
                        ]) / len(training_logs_en)
                    log_metrics('Training average in entity graph', step_en,
                                metrics_en)
                    training_logs_en = []

                if args.do_valid_en and step_en % args.valid_steps_en == 0 and step_en != 0 and False:
                    logging.info(
                        'Evaluating on Valid Dataset in entity graph ...')
                    metrics_en = entity_kge_model.test_step(
                        entity_kge_model, val_entity_triples,
                        all_true_triples_entity, args)
                    log_metrics('Valid', step_en, metrics_en)

            save_variable_list_en = {
                'step': step_en,
                'current_learning_rate': current_learning_rate_en,
                'warm_up_steps': warm_up_steps_en
            }
            save_model(entity_kge_model, optimizer_en, save_variable_list_en,
                       args)

        if args.do_valid_en and False:
            logging.info('Evaluating on Valid Dataset in entity graph...')
            metrics_en = entity_kge_model.test_step(entity_kge_model,
                                                    val_entity_triples,
                                                    all_true_triples_entity,
                                                    args)
            log_metrics('Valid', step_en, metrics_en)

        if args.do_test_en:
            logging.info('Testing on Test Dataset in entity graph...')
            metrics_en = entity_kge_model.test_step(entity_kge_model,
                                                    test_entity_triples,
                                                    all_true_triples_entity,
                                                    args)
            log_metrics('Test', step_en, metrics_en)

    # type
    else:
        if (not args.do_train_ty) and (not args.do_valid_ty) and (
                not args.do_test_ty):
            raise ValueError('One of train/valid/test modes must be chosen.')

        if args.init_checkpoint_ty:
            override_config(args)
        elif args.data_path is None:
            raise ValueError(
                'One of init_checkpoint/data_path must be chosen.')

        if args.do_train_ty and args.save_path_ty is None:
            raise ValueError('Where do you want to save your trained model?')

        if args.save_path_ty and not os.path.exists(args.save_path_ty):
            os.makedirs(args.save_path_ty)

        set_logger(args)
        # load all types from final_type_order.txt
        type2id = dict()
        with open(os.path.join(args.data_path, 'final_type_order.txt')) as fin:
            for line in fin:
                tid, type = line.strip().split('\t')
                type2id[type] = int(tid)
        ntype = len(type2id)

        # load all type relations from ffinal_ty_relation_order.txt
        type_relation2id = dict()
        with open(os.path.join(args.data_path,
                               'ffinal_ty_relation_order.txt')) as fin:
            for line in fin:
                ty_reid, type_relation = line.strip().split('\t')
                type_relation2id[type_relation] = int(ty_reid)
        ntype_relation = len(type_relation2id)

        args.ntype = ntype
        args.ntype_re = ntype_relation

        logging.info('type_Model: %s' % args.type_model)
        logging.info('Data Path: %s' % args.data_path)
        logging.info('number of type: %d' % ntype)
        logging.info('number of type_re: %d' % ntype_relation)

        # load the training triple set
        train_type_triples = read_triple(
            os.path.join(args.data_path, 'train_type_Graph.txt'), type2id,
            type_relation2id)
        logging.info('#train ty_ty triples: %s' % len(train_type_triples))
        # load the validation triple set
        val_type_triples = read_triple(
            os.path.join(args.data_path, 'val_type_Graph.txt'), type2id,
            type_relation2id)
        logging.info('#val ty_ty triples: %s' % len(val_type_triples))
        # load the test triple set
        test_type_triples = read_triple(
            os.path.join(args.data_path, 'test_type_Graph.txt'), type2id,
            type_relation2id)
        logging.info('#test ty_ty triples: %s' % len(test_type_triples))
        # all_triples
        all_true_triples_type = train_type_triples + val_type_triples + test_type_triples

        # type model
        type_kge_model = DKGE_Model(
            model_name=args.type_model,
            nnode=args.ntype,
            nnode_re=args.ntype_re,
            hidden_dim=args.hidden_dim_ty,
            gamma=args.gamma_ty,
            gamma_intra=args.gamma_intra,
            double_node_embedding=args.double_node_embedding_ty,
            double_node_re_embedding=args.double_node_re_embedding_ty)
        logging.info('Type Model Parameter Configuration:')
        for name, param in type_kge_model.named_parameters():
            logging.info('Parameter %s: %s, require_grad = %s' %
                         (name, str(param.size()), str(param.requires_grad)))
        if args.cuda:
            type_kge_model = type_kge_model.cuda()

        if args.do_train_ty:
            train_dataloader_head_ty = DataLoader(
                TrainDataset_ty(train_type_triples, ntype, ntype_relation,
                                args.negative_sample_size_ty, 'head-batch'),
                batch_size=args.batch_size_ty,
                shuffle=True,
                num_workers=max(1, args.cpu_num // 2),
                collate_fn=TrainDataset_ty.collate_fn)
            train_dataloader_tail_ty = DataLoader(
                TrainDataset_ty(train_type_triples, ntype, ntype_relation,
                                args.negative_sample_size_ty, 'tail-batch'),
                batch_size=args.batch_size_ty,
                shuffle=True,
                num_workers=max(1, args.cpu_num // 2),
                collate_fn=TrainDataset_ty.collate_fn)
            train_iterator_ty = BidirectionalOneShotIterator(
                train_dataloader_head_ty, train_dataloader_tail_ty)
            # Set training configuration
            current_learning_rate_ty = args.learning_rate_ty
            optimizer_ty = torch.optim.Adam(filter(
                lambda p: p.requires_grad, type_kge_model.parameters()),
                                            lr=current_learning_rate_ty,
                                            amsgrad=True)
            if args.warm_up_steps_ty:
                warm_up_steps_ty = args.warm_up_steps_ty
            else:
                warm_up_steps_ty = args.max_steps_ty // 2

        if args.init_checkpoint_ty:
            # Restore model from checkpoint directory
            logging.info('Loading checkpoint %s...' % args.init_checkpoint_ty)
            checkpoint_ty = torch.load(
                os.path.join(args.init_checkpoint_ty, 'checkpoint_ty'))
            init_step_ty = checkpoint_ty['step']
            type_kge_model.load_state_dict(checkpoint_ty['model_state_dict'])
            if args.do_train_ty:
                current_learning_rate_ty = checkpoint_ty[
                    'current_learning_rate']
                warm_up_steps_ty = checkpoint_ty['warm_up_steps']
                optimizer_ty.load_state_dict(
                    checkpoint_ty['optimizer_state_dict'])
        else:
            logging.info('Randomly Initializing %s Model...' % args.type_model)
            init_step_ty = 0

        step_ty = init_step_ty
        logging.info('Start Training of type graph...')
        logging.info('init_step_ty = %d' % init_step_ty)
        logging.info('learning_rate_ty = %f' % current_learning_rate_ty)
        logging.info('batch_size_ty = %d' % args.batch_size_ty)
        logging.info('hidden_dim_ty = %d' % args.hidden_dim_ty)
        logging.info('gamma_ty = %f' % args.gamma_ty)

        if args.do_train_ty:
            training_logs_ty = []
            for step_ty in range(init_step_ty, args.max_steps_ty):
                log_ty = type_kge_model.train_step(type_kge_model,
                                                   optimizer_ty,
                                                   train_iterator_ty, args)
                training_logs_ty.append(log_ty)

                if step_ty >= warm_up_steps_ty:
                    current_learning_rate_ty = current_learning_rate_ty / 10
                    logging.info('Changing learning_rate_ty to %f at step %d' %
                                 (current_learning_rate_ty, step_ty))
                    optimizer_ty = torch.optim.Adam(
                        filter(lambda p: p.requires_grad,
                               type_kge_model.parameters()),
                        lr=current_learning_rate_ty,
                        amsgrad=True)
                    warm_up_steps_ty = warm_up_steps_ty * 3

                if step_ty % args.save_checkpoint_steps_ty == 0:
                    save_variable_list_ty = {
                        'step': step_ty,
                        'current_learning_rate': current_learning_rate_ty,
                        'warm_up_steps': warm_up_steps_ty
                    }
                    save_model(type_kge_model, optimizer_ty,
                               save_variable_list_ty, args)

                if step_ty % args.log_steps_ty == 0:
                    metrics_ty = {}
                    for metric_ty in training_logs_ty[0].keys():
                        metrics_ty[metric_ty] = sum([
                            log_ty[metric_ty] for log_ty in training_logs_ty
                        ]) / len(training_logs_ty)
                    log_metrics('Training average in type graph', step_ty,
                                metrics_ty)
                    training_logs_ty = []

                if False and args.do_valid_ty and step_ty % args.valid_steps_ty == 0:
                    logging.info(
                        'Evaluating on Valid Dataset in type graph ...')
                    metrics_ty = type_kge_model.test_step(
                        type_kge_model, val_type_triples,
                        all_true_triples_type, args)
                    log_metrics('Valid', step_ty, metrics_ty)

            save_variable_list_ty = {
                'step': step_ty,
                'current_learning_rate': current_learning_rate_ty,
                'warm_up_steps': warm_up_steps_ty
            }
            save_model(type_kge_model, optimizer_ty, save_variable_list_ty,
                       args)

        if args.do_valid_ty and False:
            logging.info('Evaluating on Valid Dataset in type graph...')
            metrics_ty = type_kge_model.test_step(type_kge_model,
                                                  val_type_triples,
                                                  all_true_triples_type, args)
            log_metrics('Valid', step_ty, metrics_ty)

        if args.do_test_ty:
            logging.info('Testing on Test Dataset in type graph...')
            metrics_ty = type_kge_model.test_step(type_kge_model,
                                                  test_type_triples,
                                                  all_true_triples_type, args)
            log_metrics('Test', step_ty, metrics_ty)
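
All of these examples consume training batches through a BidirectionalOneShotIterator that alternates head-batch and tail-batch negative sampling. A minimal sketch of such an iterator, consistent with the next(train_iterator) usage above (the actual class in each project may differ slightly), is:

class BidirectionalOneShotIterator(object):
    def __init__(self, dataloader_head, dataloader_tail):
        self.iterator_head = self.one_shot_iterator(dataloader_head)
        self.iterator_tail = self.one_shot_iterator(dataloader_tail)
        self.step = 0

    def __next__(self):
        # Alternate between head-batch and tail-batch corruption modes.
        self.step += 1
        if self.step % 2 == 0:
            return next(self.iterator_head)
        return next(self.iterator_tail)

    @staticmethod
    def one_shot_iterator(dataloader):
        # Turn a finite DataLoader into an endless stream of batches.
        while True:
            for data in dataloader:
                yield data
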
Exemplo n.º 20
0
def main(args):
    if (not args.do_train) and (not args.do_valid) and (not args.do_test):
        raise ValueError('one of train/val/test modes must be chosen.')

    if args.init_checkpoint:
        override_config(args)
    elif args.data_path is None:
        raise ValueError('one of init_checkpoint/data_path must be chosen.')

    if args.do_train and args.save_path is None:
        # create default save directory
        dt = datetime.datetime.now()
        args.save_path = os.path.join(
            os.environ['LOG_DIR'],
            args.data_path.split('/')[-1], args.model,
            datetime.datetime.now().strftime('%m%d%H%M%S'))
        # raise ValueError('Where do you want to save your trained model?')

    if args.save_path and not os.path.exists(args.save_path):
        os.makedirs(args.save_path)

    # Write logs to checkpoint and console
    set_logger(args)
    writer = SummaryWriter(log_dir=args.save_path)

    with open(os.path.join(args.data_path, 'entities.dict')) as fin:
        entity2id = dict()
        for line in fin:
            eid, entity = line.strip().split('\t')
            entity2id[entity] = int(eid)

    with open(os.path.join(args.data_path, 'relations.dict')) as fin:
        relation2id = dict()
        for line in fin:
            rid, relation = line.strip().split('\t')
            relation2id[relation] = int(rid)

    # Read regions for Countries S* datasets
    if args.countries:
        regions = list()
        with open(os.path.join(args.data_path, 'regions.list')) as fin:
            for line in fin:
                region = line.strip()
                regions.append(entity2id[region])
        args.regions = regions

    nentity = len(entity2id)
    nrelation = len(relation2id)

    args.nentity = nentity
    args.nrelation = nrelation

    logging.info('Model: %s' % args.model)
    logging.info('Data Path: %s' % args.data_path)
    logging.info('Save Path: {}'.format(args.save_path))
    logging.info('#entity: %d' % nentity)
    logging.info('#relation: %d' % nrelation)

    train_triples = read_triple(os.path.join(args.data_path, 'train.txt'),
                                entity2id, relation2id)
    logging.info('#train: %d' % len(train_triples))
    valid_triples = read_triple(os.path.join(args.data_path, 'valid.txt'),
                                entity2id, relation2id)
    logging.info('#valid: %d' % len(valid_triples))
    test_triples = read_triple(os.path.join(args.data_path, 'test.txt'),
                               entity2id, relation2id)
    logging.info('#test: %d' % len(test_triples))

    # All true triples
    all_true_triples = train_triples + valid_triples + test_triples

    if args.model in EUC_MODELS:
        ModelClass = EKGEModel
    elif args.model in HYP_MODELS:
        ModelClass = HKGEModel
    elif args.model in ONE_2_MANY_E_MODELS:
        ModelClass = O2MEKGEModel
    else:
        raise ValueError('model %s not supported' % args.model)

    if ModelClass != O2MEKGEModel:
        kge_model = ModelClass(
            model_name=args.model,
            nentity=nentity,
            nrelation=nrelation,
            hidden_dim=args.hidden_dim,
            gamma=args.gamma,
            p_norm=args.p_norm,
            dropout=args.dropout,
            entity_embedding_multiple=args.entity_embedding_multiple,
            relation_embedding_multiple=args.relation_embedding_multiple)
    else:
        kge_model = ModelClass(
            model_name=args.model,
            nentity=nentity,
            nrelation=nrelation,
            hidden_dim=args.hidden_dim,
            gamma=args.gamma,
            p_norm=args.p_norm,
            dropout=args.dropout,
            entity_embedding_multiple=args.entity_embedding_multiple,
            relation_embedding_multiple=args.relation_embedding_multiple,
            nsiblings=args.nsib,
            rho=args.rho)

    logging.info('Model Parameter Configuration:')
    for name, param in kge_model.named_parameters():
        logging.info('Parameter %s: %s, require_grad = %s' %
                     (name, str(param.size()), str(param.requires_grad)))

    if args.cuda:
        kge_model = kge_model.cuda()

    if args.do_train:
        # Set training dataloader iterator
        train_dataloader_head = DataLoader(
            TrainDataset(train_triples, nentity, nrelation,
                         args.negative_sample_size, 'head-batch'),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=max(1, args.cpu_num // 2),
            collate_fn=TrainDataset.collate_fn)

        train_dataloader_tail = DataLoader(
            TrainDataset(train_triples, nentity, nrelation,
                         args.negative_sample_size, 'tail-batch'),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=max(1, args.cpu_num // 2),
            collate_fn=TrainDataset.collate_fn)

        train_iterator = BidirectionalOneShotIterator(train_dataloader_head,
                                                      train_dataloader_tail)

        # Set training configuration
        current_learning_rate = args.learning_rate
        optimizer = init_optimizer(kge_model, current_learning_rate)
        if args.warm_up_steps:
            warm_up_steps = args.warm_up_steps
        else:
            warm_up_steps = args.max_steps // 2

    if args.init_checkpoint:
        # Restore model from checkpoint directory
        logging.info('Loading checkpoint %s...' % args.init_checkpoint)
        checkpoint = torch.load(
            os.path.join(args.init_checkpoint, 'checkpoint'))
        init_step = checkpoint['step']
        kge_model.load_state_dict(checkpoint['model_state_dict'])
        if args.do_train:
            current_learning_rate = checkpoint['current_learning_rate']
            warm_up_steps = checkpoint['warm_up_steps']
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        logging.info('Randomly Initializing %s Model...' % args.model)
        init_step = 0

    step = init_step

    if args.do_train:
        logging.info('Start Training...')
        logging.info('init_step = %d' % init_step)
        logging.info('hidden_dim = %d' % args.hidden_dim)
        logging.info('learning_rate = %f' % current_learning_rate)
        logging.info('batch_size = %d' % args.batch_size)
        logging.info('negative_adversarial_sampling = %d' %
                     args.negative_adversarial_sampling)

        logging.info('gamma = %f' % args.gamma)
        logging.info('dropout = %f' % args.dropout)
        if args.negative_adversarial_sampling:
            logging.info('adversarial_temperature = %f' %
                         args.adversarial_temperature)

        # Set valid dataloader as it would be evaluated during training
        training_logs = []

        # Training Loop
        for step in range(init_step, args.max_steps):

            log = kge_model.train_step(kge_model, optimizer, train_iterator,
                                       args)
            training_logs.append(log)
            write_metrics(writer, step, log, split='train')
            write_metrics(writer, step,
                          {'current_learning_rate': current_learning_rate})

            if step >= warm_up_steps:
                current_learning_rate = current_learning_rate / 10
                logging.info('Change learning_rate to %f at step %d' %
                             (current_learning_rate, step))
                optimizer = init_optimizer(kge_model, current_learning_rate)

                warm_up_steps = warm_up_steps * 3

            if step % args.save_checkpoint_steps == 0:
                save_variable_list = {
                    'step': step,
                    'current_learning_rate': current_learning_rate,
                    'warm_up_steps': warm_up_steps
                }
                save_model(kge_model, optimizer, save_variable_list, args)

            if step % args.log_steps == 0:
                metrics = {}
                for metric in training_logs[0].keys():
                    metrics[metric] = sum(
                        [log[metric]
                         for log in training_logs]) / len(training_logs)
                log_metrics('Training average', step, metrics)
                write_metrics(writer, step, metrics, split='train')
                training_logs = []

            if args.do_valid and step % args.valid_steps == 0:
                logging.info('Evaluating on Valid Dataset...')
                metrics = kge_model.test_step(kge_model, valid_triples,
                                              all_true_triples, args)
                log_metrics('Valid', step, metrics)
                write_metrics(writer, step, metrics, split='valid')

        save_variable_list = {
            'step': step,
            'current_learning_rate': current_learning_rate,
            'warm_up_steps': warm_up_steps
        }

        save_model(kge_model, optimizer, save_variable_list, args)

    if args.do_valid:
        logging.info('Evaluating on Valid Dataset...')
        metrics = kge_model.test_step(kge_model, valid_triples,
                                      all_true_triples, args)
        log_metrics('Valid', step, metrics)

    if args.do_test:
        logging.info('Evaluating on Test Dataset...')
        metrics = kge_model.test_step(kge_model, test_triples,
                                      all_true_triples, args)
        log_metrics('Test', step, metrics)

    if args.evaluate_train:
        logging.info('Evaluating on Training Dataset...')
        metrics = kge_model.test_step(kge_model, train_triples,
                                      all_true_triples, args)
        log_metrics('Test', step, metrics)
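
Most of the file-based examples above convert raw triple files into integer ids with a read_triple helper that is not reproduced here (Exemplo n.º 18 uses a one-argument variant of it). A minimal sketch matching the (path, entity2id, relation2id) call sites would be:

def read_triple(file_path, entity2id, relation2id):
    # Each line is expected to hold "head<TAB>relation<TAB>tail".
    triples = []
    with open(file_path) as fin:
        for line in fin:
            h, r, t = line.strip().split('\t')
            triples.append((entity2id[h], relation2id[r], entity2id[t]))
    return triples
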
Exemplo n.º 21
0
def main(arg):

    with open(
            r'C:\Users\pc\Desktop\编程\KnowledgeGraphEmbedding-master\data\FB15k\entities.dict'
    ) as fin:
        entity2id = dict()
        for line in fin:
            eid, entity = line.strip().split('\t')
            entity2id[entity] = int(eid)
    with open(
            r'C:\Users\pc\Desktop\编程\KnowledgeGraphEmbedding-master\data\FB15k\relations.dict'
    ) as fin:
        relation2id = dict()
        for line in fin:
            rid, relation = line.strip().split('\t')
            relation2id[relation] = int(rid)
    nentity = len(entity2id)
    nrelation = len(relation2id)

    arg.nentity = nentity
    arg.nrelation = nrelation

    logging.info('Model:%s' % arg.MODEL)
    logging.info('Data Path:%s' % arg.DATA_PATH)
    logging.info('entity:%d' % arg.nentity)
    logging.info('relation:%d' % arg.nrelation)

    #extract data from file
    train_triples = read_triple(os.path.join(arg.DATA_PATH, 'train.txt'),
                                entity2id, relation2id)
    logging.info('#train:%d' % len(train_triples))
    valid_triples = read_triple(os.path.join(arg.DATA_PATH, 'valid.txt'),
                                entity2id, relation2id)
    logging.info('#valid:%d' % len(valid_triples))
    test_triples = read_triple(os.path.join(arg.DATA_PATH, 'test.txt'),
                               entity2id, relation2id)
    logging.info('#test:%d' % len(test_triples))
    #all true triples
    all_true_triples = train_triples + valid_triples + test_triples

    #construct model
    kge_model = KGEModel(
        model_name=arg.MODEL,
        nentity=arg.nentity,
        nrelation=arg.nrelation,
        hidden_dim=arg.HIDDEN_DIM,
        gamma=arg.gamma,
        double_entity_embedding=arg.double_entity_embedding,
        double_relation_embedding=arg.double_relation_embedding)

    #print model para configuration
    logging.info('Model Parameter Configuration')
    for name, para in kge_model.named_parameters():
        #print(name, para.size(), para.requires_grad)
        logging.info('Parameter %s:%s,require_grad=%s' %
                     (name, str(para.size()), str(para.requires_grad)))

    #do train
    train_dataloader_head = DataLoader(TrainDataset(train_triples, nentity,
                                                    nrelation,
                                                    arg.negative_sample_size,
                                                    'head-batch'),
                                       batch_size=arg.BATCH_SIZE,
                                       shuffle=True,
                                       num_workers=max(1, arg.cpu_num // 2),
                                       collate_fn=TrainDataset.collate_fn)
    train_dataloader_tail = DataLoader(TrainDataset(train_triples, nentity,
                                                    nrelation,
                                                    arg.negative_sample_size,
                                                    'tail-batch'),
                                       batch_size=arg.BATCH_SIZE,
                                       shuffle=True,
                                       num_workers=max(1, arg.cpu_num // 2),
                                       collate_fn=TrainDataset.collate_fn)

    train_iterator = BidirectionalOneShotIterator(train_dataloader_head,
                                                  train_dataloader_tail)

    #set train configuration
    current_learning_rate = arg.LR
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                        kge_model.parameters()),
                                 lr=current_learning_rate)

    warm_up_steps = arg.warm_up_steps if arg.warm_up_steps else arg.max_steps // 2
    init_step = 0
    step = init_step

    logging.info('Start Training...')
    logging.info('init_step = %d' % init_step)
    logging.info('learning_rate = %f' % current_learning_rate)
    logging.info('batch_size = %d' % arg.BATCH_SIZE)
    #logging.info('negative_adversarial_sampling = %d' % arg.negative_adversarial_sampling)
    logging.info('hidden_dim = %d' % arg.HIDDEN_DIM)
    logging.info('gamma = %f' % arg.gamma)
    #logging.info('negative_adversarial_sampling = %s' % str(arg.negative_adversarial_sampling))

    #start training
    training_logs = []

    for step in range(init_step, arg.max_steps):
        log = kge_model.train_step(kge_model, optimizer, train_iterator, arg)
        training_logs.append(log)
        #update warm-up-step
        if step >= warm_up_steps:  # after warm_up_steps the learning rate drops to 1/10 of its current value
            current_learning_rate = current_learning_rate / 10
            logging.info('Change learning_rate to %f at step %d' %
                         (current_learning_rate, step))
            optimizer = torch.optim.Adam(
                filter(lambda p: p.requires_grad, kge_model.parameters()),
                lr=current_learning_rate  # update the learning rate used by the optimizer
            )
            warm_up_steps = warm_up_steps * 3  # push the next decay point further out
        #save model
        if step % arg.save_checkpoint_steps == 0:
            save_variable_list = {
                'step': step,
                'current_learning_rate': current_learning_rate,
                'warm_up_steps': warm_up_steps
            }
            save_model(kge_model, optimizer, save_variable_list, arg)
    #save after last time
    save_variable_list = {
        'step': step,
        'current_learning_rate': current_learning_rate,
        'warm_up_steps': warm_up_steps
    }
    save_model(kge_model, optimizer, save_variable_list, arg)
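
Several of these scripts call override_config(args) when restoring from a checkpoint. In codebases of this style it usually reloads the configuration saved next to the checkpoint and overwrites the architecture-level arguments; a hedged sketch, assuming a config.json written at save time, is:

import json
import os

def override_config(args):
    # Rebuild the model with the same architecture it was trained with.
    with open(os.path.join(args.init_checkpoint, 'config.json')) as fjson:
        saved = json.load(fjson)
    if getattr(args, 'data_path', None) is None:
        args.data_path = saved.get('data_path')
    args.model = saved['model']
    args.hidden_dim = saved['hidden_dim']
    args.double_entity_embedding = saved['double_entity_embedding']
    args.double_relation_embedding = saved['double_relation_embedding']
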
Exemplo n.º 22
0
Arquivo: run.py Projeto: rpatil524/ogb
def main(args):
    if (not args.do_train) and (not args.do_valid) and (not args.do_test) and (
            not args.evaluate_train):
        raise ValueError('one of train/val/test modes must be chosen.')

    if args.init_checkpoint:
        override_config(args)

    args.save_path = 'log/%s/%s/%s-%s/%s' % (
        args.dataset, args.model, args.hidden_dim, args.gamma,
        time.time()) if args.save_path is None else args.save_path
    writer = SummaryWriter(args.save_path)

    # Write logs to checkpoint and console
    set_logger(args)

    dataset = LinkPropPredDataset(name='ogbl-biokg')
    split_edge = dataset.get_edge_split()
    train_triples, valid_triples, test_triples = split_edge[
        "train"], split_edge["valid"], split_edge["test"]
    nrelation = int(max(train_triples['relation'])) + 1
    entity_dict = dict()
    cur_idx = 0
    for key in dataset[0]['num_nodes_dict']:
        entity_dict[key] = (cur_idx,
                            cur_idx + dataset[0]['num_nodes_dict'][key])
        cur_idx += dataset[0]['num_nodes_dict'][key]
    nentity = sum(dataset[0]['num_nodes_dict'].values())
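    # entity_dict maps each node type of the heterogeneous ogbl-biokg graph to
    # a half-open (start, end) range in one global index space, e.g.
    # {'disease': (0, n_disease), 'drug': (n_disease, n_disease + n_drug), ...},
    # so all entity types can share a single embedding table of size nentity.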

    evaluator = Evaluator(name=args.dataset)

    args.nentity = nentity
    args.nrelation = nrelation

    logging.info('Model: %s' % args.model)
    logging.info('Dataset: %s' % args.dataset)
    logging.info('#entity: %d' % nentity)
    logging.info('#relation: %d' % nrelation)

    # train_triples = split_dict['train']
    logging.info('#train: %d' % len(train_triples['head']))
    # valid_triples = split_dict['valid']
    logging.info('#valid: %d' % len(valid_triples['head']))
    # test_triples = split_dict['test']
    logging.info('#test: %d' % len(test_triples['head']))

    train_count, train_true_head, train_true_tail = defaultdict(
        lambda: 4), defaultdict(list), defaultdict(list)
    for i in tqdm(range(len(train_triples['head']))):
        head, relation, tail = train_triples['head'][i], train_triples[
            'relation'][i], train_triples['tail'][i]
        head_type, tail_type = train_triples['head_type'][i], train_triples[
            'tail_type'][i]
        train_count[(head, relation, head_type)] += 1
        train_count[(tail, -relation - 1, tail_type)] += 1
        train_true_head[(relation, tail)].append(head)
        train_true_tail[(head, relation)].append(tail)

    kge_model = KGEModel(
        model_name=args.model,
        nentity=nentity,
        nrelation=nrelation,
        hidden_dim=args.hidden_dim,
        gamma=args.gamma,
        double_entity_embedding=args.double_entity_embedding,
        double_relation_embedding=args.double_relation_embedding,
        evaluator=evaluator)

    logging.info('Model Parameter Configuration:')
    for name, param in kge_model.named_parameters():
        logging.info('Parameter %s: %s, require_grad = %s' %
                     (name, str(param.size()), str(param.requires_grad)))

    if args.cuda:
        kge_model = kge_model.cuda()

    if args.init_checkpoint:
        # Restore model from checkpoint directory
        logging.info('Loading checkpoint %s...' % args.init_checkpoint)
        checkpoint = torch.load(
            os.path.join(args.init_checkpoint, 'checkpoint'))
        entity_dict = checkpoint['entity_dict']

    if args.do_train:
        # Set training dataloader iterator
        train_dataloader_head = DataLoader(
            TrainDataset(train_triples, nentity, nrelation,
                         args.negative_sample_size, 'head-batch', train_count,
                         train_true_head, train_true_tail, entity_dict),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=max(1, args.cpu_num // 2),
            collate_fn=TrainDataset.collate_fn)

        train_dataloader_tail = DataLoader(
            TrainDataset(train_triples, nentity, nrelation,
                         args.negative_sample_size, 'tail-batch', train_count,
                         train_true_head, train_true_tail, entity_dict),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=max(1, args.cpu_num // 2),
            collate_fn=TrainDataset.collate_fn)

        train_iterator = BidirectionalOneShotIterator(train_dataloader_head,
                                                      train_dataloader_tail)

        # Set training configuration
        current_learning_rate = args.learning_rate
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                            kge_model.parameters()),
                                     lr=current_learning_rate)
        if args.warm_up_steps:
            warm_up_steps = args.warm_up_steps
        else:
            warm_up_steps = args.max_steps // 2

    if args.init_checkpoint:
        # Restore model from checkpoint directory
        # logging.info('Loading checkpoint %s...' % args.init_checkpoint)
        # checkpoint = torch.load(os.path.join(args.init_checkpoint, 'checkpoint'))
        init_step = checkpoint['step']
        kge_model.load_state_dict(checkpoint['model_state_dict'])
        # entity_dict = checkpoint['entity_dict']
        if args.do_train:
            current_learning_rate = checkpoint['current_learning_rate']
            warm_up_steps = checkpoint['warm_up_steps']
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        logging.info('Randomly Initializing %s Model...' % args.model)
        init_step = 0

    step = init_step

    logging.info('Start Training...')
    logging.info('init_step = %d' % init_step)
    logging.info('batch_size = %d' % args.batch_size)
    logging.info('negative_adversarial_sampling = %d' %
                 args.negative_adversarial_sampling)
    logging.info('hidden_dim = %d' % args.hidden_dim)
    logging.info('gamma = %f' % args.gamma)
    logging.info('negative_adversarial_sampling = %s' %
                 str(args.negative_adversarial_sampling))
    if args.negative_adversarial_sampling:
        logging.info('adversarial_temperature = %f' %
                     args.adversarial_temperature)

    # Set valid dataloader as it would be evaluated during training

    if args.do_train:
        logging.info('learning_rate = %f' % current_learning_rate)

        training_logs = []

        #Training Loop
        for step in range(init_step, args.max_steps):

            log = kge_model.train_step(kge_model, optimizer, train_iterator,
                                       args)
            training_logs.append(log)

            if step >= warm_up_steps:
                current_learning_rate = current_learning_rate / 10
                logging.info('Change learning_rate to %f at step %d' %
                             (current_learning_rate, step))
                optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                                    kge_model.parameters()),
                                             lr=current_learning_rate)
                warm_up_steps = warm_up_steps * 3

            if step % args.save_checkpoint_steps == 0 and step > 0:  # ~ 41 seconds/saving
                save_variable_list = {
                    'step': step,
                    'current_learning_rate': current_learning_rate,
                    'warm_up_steps': warm_up_steps,
                    'entity_dict': entity_dict
                }
                save_model(kge_model, optimizer, save_variable_list, args)

            if step % args.log_steps == 0:
                metrics = {}
                for metric in training_logs[0].keys():
                    metrics[metric] = sum(
                        [log[metric]
                         for log in training_logs]) / len(training_logs)
                log_metrics('Train', step, metrics, writer)
                training_logs = []

            if args.do_valid and step % args.valid_steps == 0 and step > 0:
                logging.info('Evaluating on Valid Dataset...')
                metrics = kge_model.test_step(kge_model, valid_triples, args,
                                              entity_dict)
                log_metrics('Valid', step, metrics, writer)

        save_variable_list = {
            'step': step,
            'current_learning_rate': current_learning_rate,
            'warm_up_steps': warm_up_steps
        }
        save_model(kge_model, optimizer, save_variable_list, args)

    if args.do_valid:
        logging.info('Evaluating on Valid Dataset...')
        metrics = kge_model.test_step(kge_model, valid_triples, args,
                                      entity_dict)
        log_metrics('Valid', step, metrics, writer)

    if args.do_test:
        logging.info('Evaluating on Test Dataset...')
        metrics = kge_model.test_step(kge_model, test_triples, args,
                                      entity_dict)
        log_metrics('Test', step, metrics, writer)

    if args.evaluate_train:
        logging.info('Evaluating on Training Dataset...')
        small_train_triples = {}
        indices = np.random.choice(len(train_triples['head']),
                                   args.ntriples_eval_train,
                                   replace=False)
        for i in train_triples:
            if 'type' in i:
                small_train_triples[i] = [train_triples[i][x] for x in indices]
            else:
                small_train_triples[i] = train_triples[i][indices]
        metrics = kge_model.test_step(kge_model,
                                      small_train_triples,
                                      args,
                                      entity_dict,
                                      random_sampling=True)
        log_metrics('Train', step, metrics, writer)
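
save_model is also used but never defined in these excerpts. A minimal sketch that writes the 'checkpoint' file with the keys the loading code above expects (step, current_learning_rate, warm_up_steps, model_state_dict, optimizer_state_dict, plus whatever else is passed in save_variable_list) might be:

import json
import os
import torch

def save_model(model, optimizer, save_variable_list, args):
    # Store the run configuration next to the checkpoint so it can be
    # reloaded later (see the hedged override_config sketch above).
    with open(os.path.join(args.save_path, 'config.json'), 'w') as fjson:
        json.dump(vars(args), fjson, default=str)  # default=str guards non-JSON types
    torch.save({
        **save_variable_list,  # e.g. step, current_learning_rate, warm_up_steps
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }, os.path.join(args.save_path, 'checkpoint'))
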
Exemplo n.º 23
0
def main(args):
    if (not args.do_train) and (not args.do_valid) and (not args.do_test):
        raise ValueError('one of train/val/test modes must be chosen.')
    
    if args.init_checkpoint:
        override_config(args)
    elif args.data_path is None:
        raise ValueError('one of init_checkpoint/data_path must be chosen.')

    if args.do_train and args.save_path is None:
        raise ValueError('Where do you want to save your trained model?')
    
    if args.save_path and not os.path.exists(args.save_path):
        os.makedirs(args.save_path)
    
    # Write logs to checkpoint and console
    set_logger(args)
    
    with open(os.path.join(args.data_path, 'entities.dict')) as fin:
        entity2id = dict()
        id2entity = dict()
        for line in fin:
            eid, entity = line.strip().split('\t')
            entity2id[entity] = int(eid)
            id2entity[int(eid)] = entity

    with open(os.path.join(args.data_path, 'relations.dict')) as fin:
        relation2id = dict()
        id2relation = dict()
        for line in fin:
            rid, relation = line.strip().split('\t')
            relation2id[relation] = int(rid)
            id2relation[int(rid)] = relation
    
    # Read regions for Countries S* datasets
    if args.countries:
        regions = list()
        with open(os.path.join(args.data_path, 'regions.list')) as fin:
            for line in fin:
                region = line.strip()
                regions.append(entity2id[region])
        args.regions = regions

    nentity = len(entity2id)
    nrelation = len(relation2id)
    
    args.nentity = nentity
    args.nrelation = nrelation
    
    logging.info('Model: %s' % args.model)
    logging.info('Data Path: %s' % args.data_path)
    logging.info('#entity: %d' % nentity)
    logging.info('#relation: %d' % nrelation)
    
    # --------------------------------------------------
    # Comments by Meng:
    # During training, pLogicNet will augment the training triplets,
    # so here we load both the augmented triplets (train.txt) for training and
    # the original triplets (train_kge.txt) for evaluation.
    # Also, the hidden triplets (hidden.txt) are also loaded for annotation.
    # --------------------------------------------------
    train_triples = read_triple(os.path.join(args.workspace_path, 'train_kge.txt'), entity2id, relation2id)
    logging.info('#train: %d' % len(train_triples))
    train_original_triples = read_triple(os.path.join(args.data_path, 'train.txt'), entity2id, relation2id)
    logging.info('#train original: %d' % len(train_original_triples))
    valid_triples = read_triple(os.path.join(args.data_path, 'valid.txt'), entity2id, relation2id)
    logging.info('#valid: %d' % len(valid_triples))
    test_triples = read_triple(os.path.join(args.data_path, 'test.txt'), entity2id, relation2id)
    logging.info('#test: %d' % len(test_triples))
    hidden_triples = read_triple(os.path.join(args.workspace_path, 'hidden.txt'), entity2id, relation2id)
    logging.info('#hidden: %d' % len(hidden_triples))
    
    #All true triples
    all_true_triples = train_original_triples + test_triples
    
    kge_model = KGEModel(
        model_name=args.model,
        nentity=nentity,
        nrelation=nrelation,
        hidden_dim=args.hidden_dim,
        gamma=args.gamma,
        double_entity_embedding=args.double_entity_embedding,
        double_relation_embedding=args.double_relation_embedding
    )
    
    logging.info('Model Parameter Configuration:')
    for name, param in kge_model.named_parameters():
        logging.info('Parameter %s: %s, require_grad = %s' % (name, str(param.size()), str(param.requires_grad)))

    if args.cuda:
        kge_model = kge_model.cuda()
    
    if args.do_train:
        # Set training dataloader iterator
        train_dataloader_head = DataLoader(
            TrainDataset(train_triples, nentity, nrelation, args.negative_sample_size, 'head-batch'), 
            batch_size=args.batch_size,
            shuffle=True, 
            num_workers=max(1, args.cpu_num//2),
            collate_fn=TrainDataset.collate_fn
        )
        
        train_dataloader_tail = DataLoader(
            TrainDataset(train_triples, nentity, nrelation, args.negative_sample_size, 'tail-batch'), 
            batch_size=args.batch_size,
            shuffle=True, 
            num_workers=max(1, args.cpu_num//2),
            collate_fn=TrainDataset.collate_fn
        )
        
        train_iterator = BidirectionalOneShotIterator(train_dataloader_head, train_dataloader_tail)
        
        # Set training configuration
        current_learning_rate = args.learning_rate
        optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, kge_model.parameters()), 
            lr=current_learning_rate
        )
        if args.warm_up_steps:
            warm_up_steps = args.warm_up_steps
        else:
            warm_up_steps = args.max_steps // 2

    if args.init_checkpoint:
        # Restore model from checkpoint directory
        logging.info('Loading checkpoint %s...' % args.init_checkpoint)
        checkpoint = torch.load(os.path.join(args.init_checkpoint, 'checkpoint'))
        init_step = checkpoint['step']
        kge_model.load_state_dict(checkpoint['model_state_dict'])
        if args.do_train:
            current_learning_rate = checkpoint['current_learning_rate']
            warm_up_steps = checkpoint['warm_up_steps']
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        logging.info('Randomly Initializing %s Model...' % args.model)
        init_step = 0
    
    step = init_step
    
    logging.info('Start Training...')
    logging.info('init_step = %d' % init_step)
    logging.info('learning_rate = %f' % current_learning_rate)
    logging.info('batch_size = %d' % args.batch_size)
    logging.info('negative_adversarial_sampling = %d' % args.negative_adversarial_sampling)
    logging.info('hidden_dim = %d' % args.hidden_dim)
    logging.info('gamma = %f' % args.gamma)
    logging.info('negative_adversarial_sampling = %s' % str(args.negative_adversarial_sampling))
    if args.negative_adversarial_sampling:
        logging.info('adversarial_temperature = %f' % args.adversarial_temperature)

    if args.record:
        local_path = args.workspace_path
        ensure_dir(local_path)

        opt = vars(args)
        with open(local_path + '/opt.txt', 'w') as fo:
            for key, val in opt.items():
                fo.write('{} {}\n'.format(key, val))
    
    # Set valid dataloader as it would be evaluated during training
    
    if args.do_train:
        training_logs = []
        
        #Training Loop
        for step in range(init_step, args.max_steps):
            
            log = kge_model.train_step(kge_model, optimizer, train_iterator, args)
            
            training_logs.append(log)
            
            if step >= warm_up_steps:
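                # Step-wise decay: once the warm-up threshold is crossed, the learning
                # rate is divided by 10 and the next decay point is pushed to 3x the
                # current threshold, so later decays happen progressively less often.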
                current_learning_rate = current_learning_rate / 10
                logging.info('Change learning_rate to %f at step %d' % (current_learning_rate, step))
                optimizer = torch.optim.Adam(
                    filter(lambda p: p.requires_grad, kge_model.parameters()), 
                    lr=current_learning_rate
                )
                warm_up_steps = warm_up_steps * 3
            
            if step % args.save_checkpoint_steps == 0:
                save_variable_list = {
                    'step': step, 
                    'current_learning_rate': current_learning_rate,
                    'warm_up_steps': warm_up_steps
                }
                save_model(kge_model, optimizer, save_variable_list, args)
                
            if step % args.log_steps == 0:
                metrics = {}
                for metric in training_logs[0].keys():
                    metrics[metric] = sum([log[metric] for log in training_logs])/len(training_logs)
                log_metrics('Training average', step, metrics)
                training_logs = []
                
            if args.do_valid and (step + 1) % args.valid_steps == 0:
                logging.info('Evaluating on Valid Dataset...')
                metrics, preds = kge_model.test_step(kge_model, valid_triples, all_true_triples, args)
                log_metrics('Valid', step, metrics)
        
        save_variable_list = {
            'step': step, 
            'current_learning_rate': current_learning_rate,
            'warm_up_steps': warm_up_steps
        }
        save_model(kge_model, optimizer, save_variable_list, args)
        
    if args.do_valid:
        logging.info('Evaluating on Valid Dataset...')
        metrics, preds = kge_model.test_step(kge_model, valid_triples, all_true_triples, args)
        log_metrics('Valid', step, metrics)
        
        # --------------------------------------------------
        # Comments by Meng:
        # Save the prediction results of KGE on validation set.
        # --------------------------------------------------

        if args.record:
            # Save the final results
            with open(local_path + '/result_kge_valid.txt', 'w') as fo:
                for metric in metrics:
                    fo.write('{} : {}\n'.format(metric, metrics[metric]))

            # Save the predictions on test data
            with open(local_path + '/pred_kge_valid.txt', 'w') as fo:
                for h, r, t, f, rk, l in preds:
                    fo.write('{}\t{}\t{}\t{}\t{}\n'.format(id2entity[h], id2relation[r], id2entity[t], f, rk))
                    for e, val in l:
                        fo.write('{}:{:.4f} '.format(id2entity[e], val))
                    fo.write('\n')
    
    if args.do_test:
        logging.info('Evaluating on Test Dataset...')
        metrics, preds = kge_model.test_step(kge_model, test_triples, all_true_triples, args)
        log_metrics('Test', step, metrics)
        
        # --------------------------------------------------
        # Comments by Meng:
        # Save the prediction results of KGE on test set.
        # --------------------------------------------------

        if args.record:
            # Save the final results
            with open(local_path + '/result_kge.txt', 'w') as fo:
                for metric in metrics:
                    fo.write('{} : {}\n'.format(metric, metrics[metric]))

            # Save the predictions on test data
            with open(local_path + '/pred_kge.txt', 'w') as fo:
                for h, r, t, f, rk, l in preds:
                    fo.write('{}\t{}\t{}\t{}\t{}\n'.format(id2entity[h], id2relation[r], id2entity[t], f, rk))
                    for e, val in l:
                        fo.write('{}:{:.4f} '.format(id2entity[e], val))
                    fo.write('\n')

    # --------------------------------------------------
    # Comments by Meng:
    # Save the annotations on hidden triplets.
    # --------------------------------------------------

    if args.record:
        # Annotate hidden triplets
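        # infer_step is assumed to return one score per hidden triple; each
        # (head, relation, tail, score) line written below can then serve as a soft
        # annotation of how plausible the trained model finds that unseen triple.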
        scores = kge_model.infer_step(kge_model, hidden_triples, args)
        with open(local_path + '/annotation.txt', 'w') as fo:
            for (h, r, t), s in zip(hidden_triples, scores):
                fo.write('{}\t{}\t{}\t{}\n'.format(id2entity[h], id2relation[r], id2entity[t], s))
    
    if args.evaluate_train:
        logging.info('Evaluating on Training Dataset...')
        metrics, preds = kge_model.test_step(kge_model, train_triples, all_true_triples, args)
        log_metrics('Train', step, metrics)
Example No. 24
0
File: run.py Project: zyksir/NoiGAN
def main(args):
    if (not args.do_train) and (not args.do_valid) and (not args.do_test):
        raise ValueError('at least one of train/valid/test modes must be chosen.')

    if args.init_checkpoint:
        override_config(args)
    elif args.data_path is None:
        raise ValueError('one of init_checkpoint/data_path must be chosen.')

    if args.do_train and args.save_path is None:
        raise ValueError('Where do you want to save your trained model?')

    if args.save_path and not os.path.exists(args.save_path):
        os.makedirs(args.save_path)

    # Write logs to checkpoint and console
    set_logger(args)

    with open(os.path.join(args.data_path, 'entities.dict')) as fin:
        entity2id = dict()
        for line in fin:
            eid, entity = line.strip().split('\t')
            entity2id[entity] = int(eid)

    with open(os.path.join(args.data_path, 'relations.dict')) as fin:
        relation2id = dict()
        for line in fin:
            rid, relation = line.strip().split('\t')
            relation2id[relation] = int(rid)

    nentity = len(entity2id)
    nrelation = len(relation2id)

    args.nentity = nentity
    args.nrelation = nrelation

    logging.info('Model: %s' % args.model)
    logging.info('Data Path: %s' % args.data_path)
    logging.info('#entity: %d' % nentity)
    logging.info('#relation: %d' % nrelation)

    train_triples = read_triple(os.path.join(args.data_path, args.train_set), entity2id, relation2id)
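    # When args.fake is set, pre-generated noisy triples (fake<N>.pkl) are appended
    # to the training set to simulate a noisy KG; otherwise a single (0, 0, 0)
    # placeholder keeps the downstream fake tensor non-empty.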
    if args.fake:
        fake_triples = pickle.load(open(os.path.join(args.data_path, "fake%s.pkl" % args.fake), "rb"))
        fake = torch.LongTensor(fake_triples)
        train_triples += fake_triples
    else:
        fake_triples = [(0, 0, 0)]
        fake = torch.LongTensor(fake_triples)
    if args.cuda:
        fake = fake.cuda()
    logging.info('#train: %d' % len(train_triples))
    valid_triples = read_triple(os.path.join(args.data_path, 'valid.txt'), entity2id, relation2id)
    logging.info('#valid: %d' % len(valid_triples))
    test_triples = read_triple(os.path.join(args.data_path, 'test.txt'), entity2id, relation2id)
    logging.info('#test: %d' % len(test_triples))

    all_true_triples = train_triples + valid_triples + test_triples

    kge_model = KGEModel(
        model_name=args.model,
        nentity=nentity,
        nrelation=nrelation,
        hidden_dim=args.hidden_dim,
        gamma=args.gamma,
        double_entity_embedding=args.double_entity_embedding,
        double_relation_embedding=args.double_relation_embedding
    )

    logging.info('Model Parameter Configuration:')
    for name, param in kge_model.named_parameters():
        logging.info('Parameter %s: %s, require_grad = %s' % (name, str(param.size()), str(param.requires_grad)))
    if args.cuda:
        kge_model = kge_model.cuda()

    # Set training dataloader iterator
    train_dataset_head = TrainDataset(train_triples, nentity, nrelation, args.negative_sample_size, 'head-batch')
    train_dataset_tail = TrainDataset(train_triples, nentity, nrelation, args.negative_sample_size, 'tail-batch')
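    # Every training triple starts with a confidence weight of 1.0; these
    # subsampling weights act as per-triple confidences, which the classifier can
    # later lower for suspected noise, and they are checkpointed as "confidence".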
    for triple in tqdm(train_dataset_head.triples, total=len(train_dataset_head.triples)):
        train_dataset_head.subsampling_weights[triple] = torch.FloatTensor([1.0])
    train_dataset_tail.subsampling_weights = train_dataset_head.subsampling_weights

    train_dataloader_head = DataLoader(
        train_dataset_head,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=max(1, args.cpu_num // 2),
        collate_fn=TrainDataset.collate_fn
    )

    train_dataloader_tail = DataLoader(
        train_dataset_tail,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=max(1, args.cpu_num // 2),
        collate_fn=TrainDataset.collate_fn
    )

    train_iterator = BidirectionalOneShotIterator(train_dataloader_head, train_dataloader_tail)
    classifier, generator = None, None
    if args.method == "clf" or args.method is None:
        args.gen_dim = args.hidden_dim
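        # The classifier and generator are trained on a random 10% subsample of the
        # training triples, reusing the true_head/true_tail caches and confidence
        # weights of the full training datasets.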
        clf_triples = random.sample(train_triples, len(train_triples)//10)
        clf_dataset_head = TrainDataset(clf_triples, nentity, nrelation,
                                        args.negative_sample_size, 'head-batch')
        clf_dataset_tail = TrainDataset(clf_triples, nentity, nrelation,
                                        args.negative_sample_size, 'tail-batch')
        clf_dataset_head.true_head, clf_dataset_head.true_tail = train_dataset_head.true_head, train_dataset_head.true_tail
        clf_dataset_tail.true_head, clf_dataset_tail.true_tail = train_dataset_tail.true_head, train_dataset_tail.true_tail
        clf_dataset_head.subsampling_weights = train_dataset_head.subsampling_weights
        clf_dataset_tail.subsampling_weights = train_dataset_head.subsampling_weights
        clf_dataloader_head = DataLoader(
            clf_dataset_head,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=max(1, args.cpu_num // 2),
            collate_fn=TrainDataset.collate_fn
        )

        clf_dataloader_tail = DataLoader(
            clf_dataset_tail,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=max(1, args.cpu_num // 2),
            collate_fn=TrainDataset.collate_fn
        )
        clf_iterator = BidirectionalOneShotIterator(clf_dataloader_head, clf_dataloader_tail)

        gen_dataset_head = TrainDataset(clf_triples, nentity, nrelation,
                                        args.negative_sample_size, 'head-batch')
        gen_dataset_tail = TrainDataset(clf_triples, nentity, nrelation,
                                        args.negative_sample_size, 'tail-batch')
        gen_dataset_head.true_head, gen_dataset_head.true_tail = train_dataset_head.true_head, train_dataset_head.true_tail
        gen_dataset_tail.true_head, gen_dataset_tail.true_tail = train_dataset_tail.true_head, train_dataset_tail.true_tail
        gen_dataset_head.subsampling_weights = train_dataset_head.subsampling_weights
        gen_dataset_tail.subsampling_weights = train_dataset_head.subsampling_weights
        gen_dataloader_head = DataLoader(
            gen_dataset_head,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=max(1, args.cpu_num // 2),
            collate_fn=TrainDataset.collate_fn
        )

        gen_dataloader_tail = DataLoader(
            gen_dataset_tail,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=max(1, args.cpu_num // 2),
            collate_fn=TrainDataset.collate_fn
        )
        gen_iterator = BidirectionalOneShotIterator(gen_dataloader_head, gen_dataloader_tail)

        # if args.double_entity_embedding:
        #     classifier = SimpleNN(input_dim=args.hidden_dim, hidden_dim=5)
        #     generator = SimpleNN(input_dim=args.hidden_dim, hidden_dim=5)
        # else:
        classifier = SimpleNN(input_dim=args.hidden_dim, hidden_dim=5)
        generator = SimpleNN(input_dim=args.hidden_dim, hidden_dim=5)

        if args.cuda:
            classifier = classifier.cuda()
            generator = generator.cuda()
        clf_lr = 0.005 # if "FB15k" in args.data_path else 0.01
        clf_opt = torch.optim.Adam(classifier.parameters(), lr=clf_lr)
        gen_opt = torch.optim.SGD(generator.parameters(), lr=0.0001)
    elif args.method == "KBGAN":
        generator = KGEModel(
            model_name=args.model,
            nentity=nentity,
            nrelation=nrelation,
            hidden_dim=args.gen_dim,
            gamma=args.gamma,
            double_entity_embedding=args.double_entity_embedding,
            double_relation_embedding=args.double_relation_embedding
        )
        if args.cuda:
            generator = generator.cuda()
        # if args.gen_init is not None:
        #     checkpoint = torch.load(os.path.join(args.gen_init, 'checkpoint'))
        #     generator.load_state_dict(checkpoint['model_state_dict'])
        gen_opt = torch.optim.Adam(generator.parameters(), lr=args.learning_rate)

    # Set training configuration
    current_learning_rate = args.learning_rate
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, kge_model.parameters()),
        lr=current_learning_rate
    )
    if args.warm_up_steps:
        warm_up_steps = args.warm_up_steps
    else:
        warm_up_steps = args.max_steps # // 2

    if args.init_checkpoint:
        # Restore model from checkpoint directory
        logging.info('Loading checkpoint %s...' % args.init_checkpoint)
        checkpoint = torch.load(os.path.join(args.init_checkpoint, 'checkpoint'))
        init_step = 0
        kge_model.load_state_dict(checkpoint['model_state_dict'])
        if args.do_train:
            warm_up_steps = checkpoint['warm_up_steps']
            logging.info("warm_up_steps = %d" % warm_up_steps)
        else:
            current_learning_rate = args.learning_rate
    else:
        logging.info('Randomly Initializing %s Model...' % args.model)
        init_step = 0

    step = init_step

    logging.info('Start Training...')
    logging.info('init_step = %d' % init_step)
    logging.info('learning_rate = %f' % current_learning_rate)
    logging.info('batch_size = %d' % args.batch_size)
    logging.info('hidden_dim = %d' % args.hidden_dim)
    logging.info('gamma = %f' % args.gamma)
    logging.info('negative_adversarial_sampling = %s' % str(args.negative_adversarial_sampling))
    if args.negative_adversarial_sampling:
        logging.info('adversarial_temperature = %f' % args.adversarial_temperature)

    # Set valid dataloader as it would be evaluated during training
    if args.do_train:
        if args.method == "clf" and args.init_checkpoint:
            # classifier.find_topK_triples(kge_model, classifier, train_iterator, clf_iterator, GAN_iterator)
            # logging.info("fake triples in classifier training %d / %d" % (
            #     len(set(fake_triples).intersection(set(clf_iterator.dataloader_head.dataset.triples))),
            #     len(clf_iterator.dataloader_head.dataset.triples)))
            for epoch in range(1200):
                log = classifier.train_classifier_step(kge_model, classifier, clf_opt, clf_iterator, args, generator=None, model_name=args.model)
                if (epoch+1) % 200 == 0:
                    logging.info(log)
                if epoch == 4000:
                    clf_opt = torch.optim.Adam(classifier.parameters(), lr=clf_lr/10)
            clf_opt = torch.optim.Adam(classifier.parameters(), lr=clf_lr)


        training_logs = []

        # Training Loop
        logging.info(optimizer)
        soft = False
        epoch_reward, epoch_loss, avg_reward, log = 0, 0, 0, {}
        for step in range(init_step, args.max_steps):
            if args.method == "clf" and step % 10001 == 0:
                if args.num == 1:
                    soft = True
                elif args.num == 1000:
                    soft = False
                else:
                    soft = not soft
                head, relation, tail = classifier.get_embedding(kge_model, fake)
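                # The classifier scores each injected fake triple on the composed
                # embedding of the underlying score function: a complex rotation for
                # RotatE, element-wise h * r * t for DistMult, h + r - t for TransE.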
                if args.model == "RotatE":
                    fake_score = classifier.forward(RotatE(head, relation, tail, "single", kge_model))
                elif args.model == "DistMult":
                    fake_score = classifier.forward(head * relation * tail)
                elif args.model == "TransE":
                    fake_score = classifier.forward(head + relation - tail)
                all_weight = classifier.find_topK_triples(kge_model, classifier, train_iterator, clf_iterator,
                                                           gen_iterator, soft=soft, model_name=args.model)
                logging.info("fake percent %f in %d" % (fake_score.sum().item() / all_weight, all_weight))
                logging.info("fake triples in classifier training %d / %d" % (
                    len(set(fake_triples).intersection(set(clf_iterator.dataloader_head.dataset.triples))),
                    len(clf_iterator.dataloader_head.dataset.triples)))

                epoch_reward, epoch_loss, avg_reward = 0, 0, 0
                for epoch in tqdm(range(200)):
                    classifier.train_GAN_step(kge_model, generator, classifier, gen_opt, clf_opt, gen_iterator, epoch_reward, epoch_loss, avg_reward, args, model_name=args.model)

                clf_train_num = 200
                for epoch in range(clf_train_num):
                    log = classifier.train_classifier_step(kge_model, classifier, clf_opt, clf_iterator, args, generator=None, model_name=args.model)
                    if epoch % 100 == 0:
                        logging.info(log)

            if step % 300 == 0 and step > 0 and args.method == "KBGAN":
                avg_reward = epoch_reward / batch_num
                epoch_reward, epoch_loss = 0, 0
                logging.info('Training average reward at step %d: %f' % (step, avg_reward))
                logging.info('Training average loss at step %d: %f' % (step, epoch_loss / batch_num))

            if args.method == "KBGAN":
                epoch_reward, epoch_loss, batch_num = kge_model.train_GAN_step(generator, kge_model, gen_opt, optimizer, train_iterator, epoch_reward, epoch_loss, avg_reward, args)
            else:
                log = kge_model.train_step(kge_model, optimizer, train_iterator, args, generator=generator)

            training_logs.append(log)

            if step >= warm_up_steps:
                current_learning_rate = current_learning_rate / 10
                logging.info('Change learning_rate to %f at step %d' % (current_learning_rate, step))
                optimizer = torch.optim.Adam(
                    filter(lambda p: p.requires_grad, kge_model.parameters()),
                    lr=current_learning_rate
                )
                warm_up_steps = warm_up_steps * 3

            if step % args.save_checkpoint_steps == 0:
                save_variable_list = {
                    'step': step,
                    'current_learning_rate': current_learning_rate,
                    'warm_up_steps': warm_up_steps
                }
                if args.method is not None:
                    save_variable_list["confidence"] = train_iterator.dataloader_head.dataset.subsampling_weights
                save_model(kge_model, optimizer, save_variable_list, args, classifier=classifier, generator=generator)

            if step % args.log_steps == 0:
                metrics = {}
                for metric in training_logs[0].keys():
                    metrics[metric] = sum([log[metric] for log in training_logs]) / len(training_logs)
                log_metrics('Training average', step, metrics)
                training_logs = []

            if args.do_valid and step % args.valid_steps == 0:
                logging.info('Evaluating on Valid Dataset...')
                metrics = kge_model.test_step(kge_model, valid_triples, all_true_triples, args)
                log_metrics('Valid', step, metrics)
        save_variable_list = {
            'step': step,
            'current_learning_rate': current_learning_rate,
            'warm_up_steps': warm_up_steps
        }
        if args.method is not None:
            save_variable_list["confidence"] = train_iterator.dataloader_head.dataset.subsampling_weights
        save_model(kge_model, optimizer, save_variable_list, args, classifier=classifier, generator=generator)

    if args.do_valid:
        logging.info('Evaluating on Valid Dataset...')
        metrics = kge_model.test_step(kge_model, valid_triples, all_true_triples, args)
        log_metrics('Valid', step, metrics)

    if args.do_test:
        logging.info('Evaluating on Test Dataset...')
        metrics = kge_model.test_step(kge_model, test_triples, all_true_triples, args)
        log_metrics('Test', step, metrics)
        if args.method is not None:
            classifier.find_topK_triples(kge_model, classifier, train_iterator, clf_iterator,
                                         gen_iterator, soft=True, model_name=args.model)
            # torch.save(train_iterator.dataloader_head.dataset.subsampling_weights,
            #            os.path.join(args.save_path, 'weight'))
            true_triples = set(train_triples) - set(fake_triples)
            scores, label = [], []
            for triple in true_triples:
                if triple != (0, 0, 0):
                    scores.append(train_iterator.dataloader_head.dataset.subsampling_weights[triple].item())
                    label.append(1)
            for triple in fake_triples:
                if triple != (0, 0, 0):
                    scores.append(train_iterator.dataloader_head.dataset.subsampling_weights[triple].item())
                    label.append(0)
        else:
            print("start to use sigmoid to translate distance to probability")
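            # Without a classifier, the raw KGE score of each training triple is
            # squashed through a sigmoid and used directly as the probability that
            # the triple is clean.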
            scores, label = [], []
            true_triples = set(train_triples) - set(fake_triples)
            i = 0
            import sys
            while i < len(train_iterator.dataloader_head.dataset.triples):
                sys.stdout.write("%d in %d\r" % (i, len(train_iterator.dataloader_head.dataset.triples)))
                sys.stdout.flush()
                j = min(i + 1024, len(train_iterator.dataloader_head.dataset.triples))
                sample = torch.LongTensor(train_iterator.dataloader_head.dataset.triples[i: j]).cuda()
                score = kge_model(sample).detach().cpu().view(-1)
                for x, triple in enumerate(train_iterator.dataloader_head.dataset.triples[i: j]):
                    if triple in true_triples:
                        label.append(1)
                        scores.append(torch.sigmoid(score[x]).item())
                    elif triple in fake_triples:
                        label.append(0)
                        scores.append(torch.sigmoid(score[x]).item())
                i = j
                del sample
                del score
        scores, label = np.array(scores), np.array(label)
        from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score
        p = precision_score(label, scores > 0.5)
        r = recall_score(label, scores > 0.5)
        f1 = f1_score(label, scores > 0.5)
        auc = roc_auc_score(label, scores)  # AUC should use the raw scores, not thresholded predictions
        logging.info(f"""
        precision = {p}
        recall = {r}
        f1 score = {f1}
        auc score = {auc}
        """)
        p = precision_score(1 - label, scores < 0.5)
        r = recall_score(1 - label, scores < 0.5)
        f1 = f1_score(1 - label, scores < 0.5)
        auc = roc_auc_score(1 - label, 1 - scores)  # flip scores so that higher means "noisy"
        logging.info(f"""
                precision = {p}
                recall = {r}
                f1 score = {f1}
                auc score = {auc}
                """)

    if args.evaluate_train:
        logging.info('Evaluating on Training Dataset...')
        metrics = kge_model.test_step(kge_model, train_triples, all_true_triples, args)
        log_metrics('Train', step, metrics)
Example No. 25
0
def main(args):
    if args.init_checkpoint:
        override_config(args)
    elif args.data_path is None:
        raise ValueError('one of init_checkpoint/data_path must be chosen.')

    if args.save_path is None:
        raise ValueError('Where do you want to save your trained model?')
    
    if args.save_path and not os.path.exists(args.save_path):
        os.makedirs(args.save_path)
    
    # Write logs to checkpoint and console
    set_logger(args)

    with open(args.data_path) as fin:
        entity2id = bidict()
        relation2id = bidict()
        train_triples = []
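        # Each input line is expected to hold at least "head, relation, tail",
        # separated by commas or tabs; entity and relation ids are assigned on
        # first occurrence.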
        for line in fin:
            _tmp = [x.strip() for x in re.split("[,\t]", line) if x.strip()][:3]
            if len(_tmp) < 3:
                continue
            e1, relation, e2 = tuple(_tmp)
            if e1 not in entity2id:
                entity2id[e1] = len(entity2id)
            if e2 not in entity2id:
                entity2id[e2] = len(entity2id)
            if relation not in relation2id:
                relation2id[relation] = len(relation2id)
            train_triples.append((entity2id[e1], relation2id[relation], entity2id[e2]))

    nentity = len(entity2id)
    nrelation = len(relation2id)
    
    args.nentity = nentity
    args.nrelation = nrelation
    
    logging.info('Model: %s' % args.model)
    logging.info('Data Path: %s' % args.data_path)
    logging.info('#entity: %d' % nentity)
    logging.info('#relation: %d' % nrelation)
    
    logging.info('#train: %d' % len(train_triples))
    
    #All true triples
    all_true_triples = train_triples
    
    kge_model = KGEModel(
        model_name=args.model,
        nentity=nentity,
        nrelation=nrelation,
        hidden_dim=args.hidden_dim,
        gamma=args.gamma,
        double_entity_embedding=args.double_entity_embedding,
        double_relation_embedding=args.double_relation_embedding
    )
    
    logging.info('Model Parameter Configuration:')
    for name, param in kge_model.named_parameters():
        logging.info('Parameter %s: %s, require_grad = %s' % (name, str(param.size()), str(param.requires_grad)))

    if args.cuda:
        kge_model = kge_model.cuda()
    
    # Set training dataloader iterator
    train_dataloader_head = DataLoader(
        TrainDataset(train_triples, nentity, nrelation, args.negative_sample_size, 'head-batch'), 
        batch_size=args.batch_size,
        shuffle=True, 
        num_workers=max(1, args.cpu_num//2),
        collate_fn=TrainDataset.collate_fn
    )
    
    train_dataloader_tail = DataLoader(
        TrainDataset(train_triples, nentity, nrelation, args.negative_sample_size, 'tail-batch'), 
        batch_size=args.batch_size,
        shuffle=True, 
        num_workers=max(1, args.cpu_num//2),
        collate_fn=TrainDataset.collate_fn
    )
    
    train_iterator = BidirectionalOneShotIterator(train_dataloader_head, train_dataloader_tail)
    
    # Set training configuration
    current_learning_rate = args.learning_rate
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, kge_model.parameters()), 
        lr=current_learning_rate
    )
    if args.warm_up_steps:
        warm_up_steps = args.warm_up_steps
    else:
        warm_up_steps = args.max_steps // 2

    if args.init_checkpoint:
        # Restore model from checkpoint directory
        logging.info('Loading checkpoint %s...' % args.init_checkpoint)
        checkpoint = torch.load(os.path.join(args.init_checkpoint, 'checkpoint'))
        init_step = checkpoint['step']
        kge_model.load_state_dict(checkpoint['model_state_dict'])
        
        current_learning_rate = checkpoint['current_learning_rate']
        warm_up_steps = checkpoint['warm_up_steps']
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        logging.info('Randomly Initializing %s Model...' % args.model)
        init_step = 0
    
    step = init_step
    
    logging.info('Start Training...')
    logging.info('init_step = %d' % init_step)
    logging.info('learning_rate = %f' % current_learning_rate)
    logging.info('batch_size = %d' % args.batch_size)
    logging.info('hidden_dim = %d' % args.hidden_dim)
    logging.info('gamma = %f' % args.gamma)
    logging.info('negative_adversarial_sampling = %s' % str(args.negative_adversarial_sampling))
    if args.negative_adversarial_sampling:
        logging.info('adversarial_temperature = %f' % args.adversarial_temperature)
    
    # Set valid dataloader as it would be evaluated during training
    
    training_logs = []
    
    #Training Loop
    for step in range(init_step, args.max_steps):
        
        log = kge_model.train_step(kge_model, optimizer, train_iterator, args)
        
        training_logs.append(log)
        
        if step >= warm_up_steps:
            current_learning_rate = current_learning_rate / 10
            logging.info('Change learning_rate to %f at step %d' % (current_learning_rate, step))
            optimizer = torch.optim.Adam(
                filter(lambda p: p.requires_grad, kge_model.parameters()), 
                lr=current_learning_rate
            )
            warm_up_steps = warm_up_steps * 3
        
        if step % args.save_checkpoint_steps == 0:
            save_variable_list = {
                'step': step, 
                'current_learning_rate': current_learning_rate,
                'warm_up_steps': warm_up_steps
            }
            save_model(kge_model, optimizer, save_variable_list, args, entity2id, relation2id)
            
        if step % args.log_steps == 0:
            metrics = {}
            for metric in training_logs[0].keys():
                metrics[metric] = sum([log[metric] for log in training_logs])/len(training_logs)
            log_metrics('Training average', step, metrics)
            training_logs = []
            
    save_variable_list = {
        'step': step, 
        'current_learning_rate': current_learning_rate,
        'warm_up_steps': warm_up_steps
    }
    save_model(kge_model, optimizer, save_variable_list, args, entity2id, relation2id)
        
    if args.evaluate_train:
        logging.info('Evaluating on Training Dataset...')
        metrics = kge_model.test_step(kge_model, train_triples, all_true_triples, args)
        log_metrics('Train', step, metrics)
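
A minimal sketch of the read_triple helper that the examples above call but do not define here. It is an assumption based on how Example No. 24 parses entities.dict and relations.dict ("id<TAB>name" lines), and it presumes the train/valid/test files hold tab-separated "head<TAB>relation<TAB>tail" lines:

def read_triple(file_path, entity2id, relation2id):
    # Hypothetical implementation: map each (head, relation, tail) line to id triples.
    triples = []
    with open(file_path) as fin:
        for line in fin:
            h, r, t = line.strip().split('\t')
            triples.append((entity2id[h], relation2id[r], entity2id[t]))
    return triples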