# Example #1
# 0
def main(args):
    """Train / evaluate a KGE model on the Amazon KG and annotate hidden triplets.

    Pipeline, driven by flags on ``args``:
      1. load entity/relation vocabularies from ``args.data_path``;
      2. load train / valid / test / hidden triples;
      3. optionally train (``--do_train``), restoring from ``--init_checkpoint``;
      4. optionally evaluate on the valid/test sets and dump predictions;
      5. when ``--record`` is set, score per-user candidate items and write
         annotation files into ``args.workspace_path``.

    NOTE(review): several inputs are hard-coded absolute paths
    (/common/users/yz956/...); the ``--record`` branch requires them to exist.
    """
    if args.init_checkpoint:
        # A checkpoint directory carries its own config; it overrides CLI flags.
        override_config(args)
    elif args.data_path is None:
        raise ValueError('one of init_checkpoint/data_path must be choosed.')

    if args.do_train and args.save_path is None:
        raise ValueError('Where do you want to save your trained model?')

    if args.save_path and not os.path.exists(args.save_path):
        os.makedirs(args.save_path)

    # Write logs to checkpoint and console
    set_logger(args)

    # Amazon dataset vocabularies: "<name>\t<id>" per line; malformed rows skipped.
    with open(os.path.join(args.data_path, 'entity2id.txt')) as fin:
        entity2id = dict()
        id2entity = dict()
        for line in fin:
            if len(line.strip().split('\t')) < 2:
                continue
            entity, eid = line.strip().split('\t')
            entity2id[entity] = int(eid)
            id2entity[int(eid)] = entity

    with open(os.path.join(args.data_path, 'relation2id.txt')) as fin:
        relation2id = dict()
        id2relation = dict()
        for line in fin:
            if len(line.strip().split('\t')) < 2:
                continue
            relation, rid = line.strip().split('\t')
            relation2id[relation] = int(rid)
            id2relation[int(rid)] = relation

    nentity = len(entity2id)
    nrelation = len(relation2id)

    args.nentity = nentity
    args.nrelation = nrelation

    logging.info('Model: %s' % args.model)
    logging.info('Data Path: %s' % args.data_path)
    logging.info('#entity: %d' % nentity)
    logging.info('#relation: %d' % nrelation)

    # --------------------------------------------------
    # During training, pLogicNet augments the training triplets, so we load
    # both the augmented triplets (train_kge.txt) for training and the
    # original triplets (train.txt) for evaluation filtering.  Hidden
    # triplets are loaded for annotation.
    # --------------------------------------------------
    train_triples = read_triple(os.path.join(args.workspace_path, 'train_kge.txt'), entity2id, relation2id)
    logging.info('#train: %d' % len(train_triples))
    train_original_triples = read_triple(os.path.join(args.data_path, 'train.txt'), entity2id, relation2id)
    logging.info('#train original: %d' % len(train_original_triples))
    valid_triples = read_triple(os.path.join(args.data_path, 'kg_val_triples_Cell_Phones_and_Accessories.txt'), entity2id, relation2id)
    logging.info('#valid: %d' % len(valid_triples))
    test_triples = read_triple(os.path.join(args.data_path, 'kg_test_triples_Cell_Phones_and_Accessories.txt'), entity2id, relation2id)
    logging.info('#test: %d' % len(test_triples))
    # Drop column 0 and keep the remaining candidate item ids per test row.
    # NOTE(review): presumably column 0 is the ground-truth item -- confirm
    # against the file that produced rec_test_candidate100.npz.
    test_candidates = np.load(os.path.join(args.data_path, 'rec_test_candidate100.npz'))['candidates'][:, 1:]
    hidden_triples = read_triple("/common/users/yz956/kg/code/KBRD/data/cpa/cpa/hidden_50.txt", entity2id, relation2id)
    logging.info('#hidden: %d' % len(hidden_triples))

    # All true triples, used to filter known positives during ranking.
    all_true_triples = train_original_triples + valid_triples + test_triples

    kge_model = KGEModel(
        model_name=args.model,
        nentity=nentity,
        nrelation=nrelation,
        hidden_dim=args.hidden_dim,
        gamma=args.gamma,
        double_entity_embedding=args.double_entity_embedding,
        double_relation_embedding=args.double_relation_embedding
    )

    logging.info('Model Parameter Configuration:')
    for name, param in kge_model.named_parameters():
        logging.info('Parameter %s: %s, require_grad = %s' % (name, str(param.size()), str(param.requires_grad)))

    if args.cuda:
        kge_model = kge_model.cuda()

    # Bug fix: bind current_learning_rate unconditionally.  Previously it was
    # only set under --do_train (and the checkpoint restore below is itself
    # guarded by --do_train), so the 'learning_rate' log further down raised
    # NameError in evaluate-only runs.
    current_learning_rate = args.learning_rate

    if args.do_train:
        # Set training dataloader iterators for both corruption directions.
        train_dataloader_head = DataLoader(
            TrainDataset(train_triples, nentity, nrelation, args.negative_sample_size, 'head-batch'),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=max(1, args.cpu_num//2),
            collate_fn=TrainDataset.collate_fn
        )

        train_dataloader_tail = DataLoader(
            TrainDataset(train_triples, nentity, nrelation, args.negative_sample_size, 'tail-batch'),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=max(1, args.cpu_num//2),
            collate_fn=TrainDataset.collate_fn
        )

        train_iterator = BidirectionalOneShotIterator(train_dataloader_head, train_dataloader_tail)

        # Set training configuration
        optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, kge_model.parameters()),
            lr=current_learning_rate
        )
        if args.warm_up_steps:
            warm_up_steps = args.warm_up_steps
        else:
            warm_up_steps = args.max_steps // 2

    if args.init_checkpoint:
        # Restore model from checkpoint directory
        logging.info('Loading checkpoint %s...' % args.init_checkpoint)
        checkpoint = torch.load(os.path.join(args.init_checkpoint, 'checkpoint'))
        init_step = checkpoint['step']
        kge_model.load_state_dict(checkpoint['model_state_dict'])
        if args.do_train:
            # Optimizer state only matters when we continue training.
            current_learning_rate = checkpoint['current_learning_rate']
            warm_up_steps = checkpoint['warm_up_steps']
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        logging.info('Randomly Initializing %s Model...' % args.model)
        init_step = 0

    step = init_step

    logging.info('Start Training...')
    logging.info('init_step = %d' % init_step)
    # Bug fix: %f instead of %d -- the learning rate is a float and %d
    # silently truncated it (e.g. 0.0005 logged as 0).
    logging.info('learning_rate = %f' % current_learning_rate)
    logging.info('batch_size = %d' % args.batch_size)
    logging.info('hidden_dim = %d' % args.hidden_dim)
    logging.info('gamma = %f' % args.gamma)
    logging.info('negative_adversarial_sampling = %s' % str(args.negative_adversarial_sampling))
    if args.negative_adversarial_sampling:
        logging.info('adversarial_temperature = %f' % args.adversarial_temperature)

    if args.record:
        local_path = args.workspace_path
        ensure_dir(local_path)

        # Persist the full option set for reproducibility.
        opt = vars(args)
        with open(local_path + '/opt.txt', 'w') as fo:
            for key, val in opt.items():
                fo.write('{} {}\n'.format(key, val))

    if args.do_train:
        training_logs = []

        # Training Loop
        for step in range(init_step, args.max_steps):

            log = kge_model.train_step(kge_model, optimizer, train_iterator, args)

            training_logs.append(log)

            if step >= warm_up_steps:
                # Decay LR by 10x and push the next decay point 3x further out.
                current_learning_rate = current_learning_rate / 10
                logging.info('Change learning_rate to %f at step %d' % (current_learning_rate, step))
                optimizer = torch.optim.Adam(
                    filter(lambda p: p.requires_grad, kge_model.parameters()),
                    lr=current_learning_rate
                )
                warm_up_steps = warm_up_steps * 3

            if step % args.save_checkpoint_steps == 0:
                save_variable_list = {
                    'step': step,
                    'current_learning_rate': current_learning_rate,
                    'warm_up_steps': warm_up_steps
                }
                save_model(kge_model, optimizer, save_variable_list, args)

            if step % args.log_steps == 0:
                # Average every metric over the window since the last report.
                metrics = {}
                for metric in training_logs[0].keys():
                    metrics[metric] = sum([log[metric] for log in training_logs])/len(training_logs)
                log_metrics('Training average', step, metrics)
                training_logs = []

            if args.do_valid and (step + 1) % args.valid_steps == 0:
                logging.info('Evaluating on Valid Dataset...')
                metrics, preds = kge_model.test_step(kge_model, valid_triples, all_true_triples, args)
                log_metrics('Valid', step, metrics)

        # Final checkpoint after the training loop.
        save_variable_list = {
            'step': step,
            'current_learning_rate': current_learning_rate,
            'warm_up_steps': warm_up_steps
        }
        save_model(kge_model, optimizer, save_variable_list, args)

    if args.do_valid:
        logging.info('Evaluating on Valid Dataset...')
        metrics, preds = kge_model.test_step(kge_model, valid_triples, all_true_triples, args)
        log_metrics('Valid', step, metrics)

        if args.record:
            # Save the final results
            with open(local_path + '/result_kge_valid.txt', 'w') as fo:
                for metric in metrics:
                    fo.write('{} : {}\n'.format(metric, metrics[metric]))

            # Save the predictions on validation data
            with open(local_path + '/pred_kge_valid.txt', 'w') as fo:
                for h, r, t, f, rk, l in preds:
                    fo.write('{}\t{}\t{}\t{}\t{}\n'.format(id2entity[h], id2relation[r], id2entity[t], f, rk))
                    for e, val in l:
                        fo.write('{}:{:.4f} '.format(id2entity[e], val))
                    fo.write('\n')

    if args.do_test:
        logging.info('Evaluating on Test Dataset...')
        # NOTE(review): this overload of test_step takes the candidate matrix;
        # the valid/train evaluations use the 4-argument form.
        metrics, preds = kge_model.test_step(kge_model, test_triples, test_candidates, all_true_triples, args)
        log_metrics('Test', step, metrics)

        if args.record:
            # Save the final results
            with open(local_path + '/result_kge.txt', 'w') as fo:
                for metric in metrics:
                    fo.write('{} : {}\n'.format(metric, metrics[metric]))

            # Save the predictions on test data
            with open(local_path + '/pred_kge.txt', 'w') as fo:
                for h, r, t, f, rk, l in preds:
                    fo.write('{}\t{}\t{}\t{}\t{}\n'.format(id2entity[h], id2relation[r], id2entity[t], f, rk))
                    for e, val in l:
                        fo.write('{}:{:.4f} '.format(id2entity[e], val))
                    fo.write('\n')

    # --------------------------------------------------
    # Save the annotations on hidden triplets.
    # --------------------------------------------------
    if args.record:
        # Annotate hidden triplets
        print('annotation')

        # Per-user candidate items: "<uid> <item> <item> ..." per line.
        cand = {}
        with gzip.open('/common/users/yz956/kg/code/KBRD/data/cpa/cpa/kg_test_candidates_Cell_Phones_and_Accessories.txt.gz', 'rt') as f:
            for line in f:
                cells = line.split()
                uid = int(cells[0])
                item_ids = [int(i) for i in cells[1:]]
                cand[uid] = item_ids
        ann, train = [], []
        # NOTE(review): `train` is read but no longer used below (the code
        # that filtered training items out of the top-50 was disabled).
        with open('/common/users/yz956/kg/code/KBRD/data/cpa/cpa/sample_pre.txt') as ft:
            for line in ft:
                line = line.strip().split('\t')
                train.append(line[1:])
        # NOTE(review): 61254 is a hard-coded user count -- confirm it matches
        # the candidate file.
        for u in range(61254):
            # Score every candidate item for user u under relation 0.
            hiddens = []
            for i in cand[u]:
                hiddens.append((u, 0, i))
            scores = kge_model.infer_step(kge_model, hiddens, args)
            # Rank candidates by model score, best first.
            ranked = sorted(dict(zip(cand[u], scores)).items(), key=lambda x: x[1], reverse=True)
            ann.append(ranked)
        with open(local_path + '/annotation_1000_htr.txt', 'w') as fo:
            for idx, d in enumerate(ann):
                for (t, score) in d:
                    fo.write(str(idx) + '\t' + str(t) + '\t0\t' + str(score) + '\n')

        # Score the pre-extracted hidden triples and dump them as
        # head \t tail \t relation \t score.
        scores = kge_model.infer_step(kge_model, hidden_triples, args)
        with open(local_path + '/annotation_htr.txt', 'w') as fo:
            for (h, r, t), s in zip(hidden_triples, scores):
                fo.write('{}\t{}\t{}\t{}\n'.format(str(h), str(t), str(r), s))

    if args.evaluate_train:
        logging.info('Evaluating on Training Dataset...')
        metrics, preds = kge_model.test_step(kge_model, train_triples, all_true_triples, args)
        log_metrics('Test', step, metrics)
# Example #2
# 0
def main(args):
    """Train / evaluate a KGE model and annotate pLogicNet hidden triplets.

    Driven by flags on ``args``: loads entity/relation dictionaries and
    triples from ``args.data_path`` / ``args.workspace_path``, optionally
    trains (``--do_train``) with optional restore from ``--init_checkpoint``,
    evaluates on the test set (``--do_test``) and, when ``--record`` is set,
    writes metrics, predictions and hidden-triplet annotations into
    ``args.workspace_path``.
    """
    # At least one mode must be requested.
    if (not args.do_train) and (not args.do_valid) and (not args.do_test):
        raise ValueError('one of train/val/test mode must be choosed.')
    
    if args.init_checkpoint:
        # A checkpoint directory carries its own config; it overrides CLI flags.
        override_config(args)
    elif args.data_path is None:
        raise ValueError('one of init_checkpoint/data_path must be choosed.')

    if args.do_train and args.save_path is None:
        raise ValueError('Where do you want to save your trained model?')
    
    if args.save_path and not os.path.exists(args.save_path):
        os.makedirs(args.save_path)
    
    # Write logs to checkpoint and console
    set_logger(args)
    
    # Dictionary files are "<id>\t<name>" per line.
    with open(os.path.join(args.data_path, 'entities.dict')) as fin:
        entity2id = dict()
        id2entity = dict()
        for line in fin:
            eid, entity = line.strip().split('\t')
            entity2id[entity] = int(eid)
            id2entity[int(eid)] = entity

    with open(os.path.join(args.data_path, 'relations.dict')) as fin:
        relation2id = dict()
        id2relation = dict()
        for line in fin:
            rid, relation = line.strip().split('\t')
            relation2id[relation] = int(rid)
            id2relation[int(rid)] = relation
    
    # Read regions for Countries S* datasets
    if args.countries:
        regions = list()
        with open(os.path.join(args.data_path, 'regions.list')) as fin:
            for line in fin:
                region = line.strip()
                regions.append(entity2id[region])
        args.regions = regions

    nentity = len(entity2id)
    nrelation = len(relation2id)
    
    args.nentity = nentity
    args.nrelation = nrelation
    
    logging.info('Model: %s' % args.model)
    logging.info('Data Path: %s' % args.data_path)
    logging.info('#entity: %d' % nentity)
    logging.info('#relation: %d' % nrelation)
    
    # --------------------------------------------------
    # Comments by Meng:
    # During training, pLogicNet will augment the training triplets,
    # so here we load both the augmented triplets (train.txt) for training and
    # the original triplets (train_kge.txt) for evaluation.
    # Also, the hidden triplets (hidden.txt) are also loaded for annotation.
    # --------------------------------------------------
    train_triples = read_triple(os.path.join(args.workspace_path, 'train_kge.txt'), entity2id, relation2id)
    logging.info('#train: %d' % len(train_triples))
    train_original_triples = read_triple(os.path.join(args.data_path, 'train.txt'), entity2id, relation2id)
    logging.info('#train original: %d' % len(train_original_triples))
    # NOTE(review): the validation-set load is disabled, so valid_triples is
    # never defined in this function.  Running with --do_valid (or letting the
    # periodic validation inside the training loop fire) will raise NameError.
    # Restore the two lines below before enabling validation.
    #valid_triples = read_triple(os.path.join(args.data_path, 'valid.txt'), entity2id, relation2id)
    #logging.info('#valid: %d' % len(valid_triples))
    test_triples = read_triple(os.path.join(args.data_path, 'test.txt'), entity2id, relation2id)
    logging.info('#test: %d' % len(test_triples))
    hidden_triples = read_triple(os.path.join(args.workspace_path, 'hidden.txt'), entity2id, relation2id)
    logging.info('#hidden: %d' % len(hidden_triples))
    
    # All true triples (used to filter known positives during ranking).
    # Validation triples are excluded because their load is disabled above.
    all_true_triples = train_original_triples + test_triples
    
    kge_model = KGEModel(
        model_name=args.model,
        nentity=nentity,
        nrelation=nrelation,
        hidden_dim=args.hidden_dim,
        gamma=args.gamma,
        double_entity_embedding=args.double_entity_embedding,
        double_relation_embedding=args.double_relation_embedding
    )
    
    logging.info('Model Parameter Configuration:')
    for name, param in kge_model.named_parameters():
        logging.info('Parameter %s: %s, require_grad = %s' % (name, str(param.size()), str(param.requires_grad)))

    if args.cuda:
        kge_model = kge_model.cuda()
    
    if args.do_train:
        # Set training dataloader iterator (one loader per corruption side).
        train_dataloader_head = DataLoader(
            TrainDataset(train_triples, nentity, nrelation, args.negative_sample_size, 'head-batch'), 
            batch_size=args.batch_size,
            shuffle=True, 
            num_workers=max(1, args.cpu_num//2),
            collate_fn=TrainDataset.collate_fn
        )
        
        train_dataloader_tail = DataLoader(
            TrainDataset(train_triples, nentity, nrelation, args.negative_sample_size, 'tail-batch'), 
            batch_size=args.batch_size,
            shuffle=True, 
            num_workers=max(1, args.cpu_num//2),
            collate_fn=TrainDataset.collate_fn
        )
        
        train_iterator = BidirectionalOneShotIterator(train_dataloader_head, train_dataloader_tail)
        
        # Set training configuration
        current_learning_rate = args.learning_rate
        optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, kge_model.parameters()), 
            lr=current_learning_rate
        )
        if args.warm_up_steps:
            warm_up_steps = args.warm_up_steps
        else:
            # Default: decay the LR once we pass the halfway point.
            warm_up_steps = args.max_steps // 2

    if args.init_checkpoint:
        # Restore model from checkpoint directory
        logging.info('Loading checkpoint %s...' % args.init_checkpoint)
        checkpoint = torch.load(os.path.join(args.init_checkpoint, 'checkpoint'))
        init_step = checkpoint['step']
        kge_model.load_state_dict(checkpoint['model_state_dict'])
        if args.do_train:
            # Optimizer state only matters when training continues.
            current_learning_rate = checkpoint['current_learning_rate']
            warm_up_steps = checkpoint['warm_up_steps']
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        logging.info('Ramdomly Initializing %s Model...' % args.model)
        init_step = 0
    
    step = init_step
    
    # NOTE(review): current_learning_rate is bound only under --do_train, so
    # the 'learning_rate' log below raises NameError in evaluate-only runs.
    # Also, %d truncates the float learning rate in the log (use %f).
    logging.info('Start Training...')
    logging.info('init_step = %d' % init_step)
    logging.info('learning_rate = %d' % current_learning_rate)
    logging.info('batch_size = %d' % args.batch_size)
    logging.info('negative_adversarial_sampling = %d' % args.negative_adversarial_sampling)
    logging.info('hidden_dim = %d' % args.hidden_dim)
    logging.info('gamma = %f' % args.gamma)
    logging.info('negative_adversarial_sampling = %s' % str(args.negative_adversarial_sampling))
    if args.negative_adversarial_sampling:
        logging.info('adversarial_temperature = %f' % args.adversarial_temperature)

    if args.record:
        local_path = args.workspace_path
        ensure_dir(local_path)

        # Persist the full option set for reproducibility.
        opt = vars(args)
        with open(local_path + '/opt.txt', 'w') as fo:
            for key, val in opt.items():
                fo.write('{} {}\n'.format(key, val))
    
    # Set valid dataloader as it would be evaluated during training
    
    if args.do_train:
        training_logs = []
        
        #Training Loop
        for step in range(init_step, args.max_steps):
            
            log = kge_model.train_step(kge_model, optimizer, train_iterator, args)
            
            training_logs.append(log)
            
            if step >= warm_up_steps:
                # Decay LR by 10x and push the next decay point 3x further out.
                current_learning_rate = current_learning_rate / 10
                logging.info('Change learning_rate to %f at step %d' % (current_learning_rate, step))
                optimizer = torch.optim.Adam(
                    filter(lambda p: p.requires_grad, kge_model.parameters()), 
                    lr=current_learning_rate
                )
                warm_up_steps = warm_up_steps * 3
            
            if step % args.save_checkpoint_steps == 0:
                save_variable_list = {
                    'step': step, 
                    'current_learning_rate': current_learning_rate,
                    'warm_up_steps': warm_up_steps
                }
                save_model(kge_model, optimizer, save_variable_list, args)
                
            if step % args.log_steps == 0:
                # Average every metric over the window since the last report.
                metrics = {}
                for metric in training_logs[0].keys():
                    metrics[metric] = sum([log[metric] for log in training_logs])/len(training_logs)
                log_metrics('Training average', step, metrics)
                training_logs = []
                
            if args.do_valid and (step + 1) % args.valid_steps == 0:
                # NOTE(review): will raise NameError -- valid_triples is not
                # loaded (see disabled load above).
                logging.info('Evaluating on Valid Dataset...')
                metrics, preds = kge_model.test_step(kge_model, valid_triples, all_true_triples, args)
                log_metrics('Valid', step, metrics)
        
        # Final checkpoint after the training loop.
        save_variable_list = {
            'step': step, 
            'current_learning_rate': current_learning_rate,
            'warm_up_steps': warm_up_steps
        }
        save_model(kge_model, optimizer, save_variable_list, args)
        
    if args.do_valid:
        # NOTE(review): same NameError hazard -- valid_triples is undefined.
        logging.info('Evaluating on Valid Dataset...')
        metrics, preds = kge_model.test_step(kge_model, valid_triples, all_true_triples, args)
        log_metrics('Valid', step, metrics)
        
        # --------------------------------------------------
        # Comments by Meng:
        # Save the prediction results of KGE on validation set.
        # --------------------------------------------------

        if args.record:
            # Save the final results
            with open(local_path + '/result_kge_valid.txt', 'w') as fo:
                for metric in metrics:
                    fo.write('{} : {}\n'.format(metric, metrics[metric]))

            # Save the predictions on test data
            with open(local_path + '/pred_kge_valid.txt', 'w') as fo:
                for h, r, t, f, rk, l in preds:
                    fo.write('{}\t{}\t{}\t{}\t{}\n'.format(id2entity[h], id2relation[r], id2entity[t], f, rk))
                    for e, val in l:
                        fo.write('{}:{:.4f} '.format(id2entity[e], val))
                    fo.write('\n')
    
    if args.do_test:
        logging.info('Evaluating on Test Dataset...')
        metrics, preds = kge_model.test_step(kge_model, test_triples, all_true_triples, args)
        log_metrics('Test', step, metrics)
        
        # --------------------------------------------------
        # Comments by Meng:
        # Save the prediction results of KGE on test set.
        # --------------------------------------------------

        if args.record:
            # Save the final results
            with open(local_path + '/result_kge.txt', 'w') as fo:
                for metric in metrics:
                    fo.write('{} : {}\n'.format(metric, metrics[metric]))

            # Save the predictions on test data
            with open(local_path + '/pred_kge.txt', 'w') as fo:
                for h, r, t, f, rk, l in preds:
                    fo.write('{}\t{}\t{}\t{}\t{}\n'.format(id2entity[h], id2relation[r], id2entity[t], f, rk))
                    for e, val in l:
                        fo.write('{}:{:.4f} '.format(id2entity[e], val))
                    fo.write('\n')

    # --------------------------------------------------
    # Comments by Meng:
    # Save the annotations on hidden triplets.
    # --------------------------------------------------

    if args.record:
        # Annotate hidden triplets: one "h \t r \t t \t score" line each.
        scores = kge_model.infer_step(kge_model, hidden_triples, args)
        with open(local_path + '/annotation.txt', 'w') as fo:
            for (h, r, t), s in zip(hidden_triples, scores):
                fo.write('{}\t{}\t{}\t{}\n'.format(id2entity[h], id2relation[r], id2entity[t], s))
    
    if args.evaluate_train:
        logging.info('Evaluating on Training Dataset...')
        metrics, preds = kge_model.test_step(kge_model, train_triples, all_true_triples, args)
        log_metrics('Test', step, metrics)