import functools
import itertools
import os
import time
from collections import defaultdict

import networkx
import torch
import torch.nn as nn
from networkx import NetworkXError, NetworkXNoPath, shortest_path_length
from torch.utils.tensorboard import SummaryWriter  # assumption: the original may instead use tensorboardX
from tqdm import tqdm, tqdm_notebook

# Project-specific helpers referenced below (CogKR, Summary, KG, Trainer,
# load_index, load_facts, translate_facts, serialize, unserialize, hitRatio,
# MAP, multi_mean_measure, inverse_relation, add_reverse_relations) are
# assumed to be provided by the surrounding package; their import paths are
# not shown in this excerpt.


class Main:
    def __init__(self,
                 args,
                 root_directory,
                 device=torch.device("cpu"),
                 comment="",
                 sparse_embed=False,
                 relation_encode=False,
                 tqdm_wrapper=tqdm_notebook):
        self.args = args
        self.root_directory = root_directory
        self.comment = comment
        self.device = device
        self.tqdm_wrapper = tqdm_wrapper
        self.config = {
            'graph': {
                'train_width': 256,
                'test_width': 2000
            },
            'model': {
                "max_nodes": 256,
                "max_neighbors": 32,
                "embed_size": 64,
                "topk": 5,
                'reward_policy': 'direct'
            },
            'optimizer': {
                'name': 'Adam',
                'summary': {
                    'lr': 1e-5
                },
                'embed': {
                    'lr': 1e-5
                },
                'agent': {
                    'lr': 1e-4
                },
                'config': {
                    'weight_decay': 1e-4
                }
            },
            'pretrain_optimizer': {
                'lr': 1e-4
            },
            'trainer': {
                'ignore_relation': True,
                'weighted_sample': True
            },
            'train': {
                'batch_size': 32,
                'log_interval': 1000,
                'evaluate_interval': 5000,
                'validate_metric': 'MAP'
            },
            'pretrain': {
                'batch_size': 64,
                'keep_embed': False
            }
        }
        self.measure_dict = {
            'Hit@1': functools.partial(hitRatio, topn=1),
            'Hit@3': functools.partial(hitRatio, topn=3),
            'Hit@5': functools.partial(hitRatio, topn=5),
            'Hit@10': functools.partial(hitRatio, topn=10),
            'hitRatio': hitRatio,
            'MAP': MAP
        }
        self.sparse_embed = sparse_embed
        self.relation_encode = relation_encode
        self.best_results = {}
        self.data_loaded = False
        self.env_built = False

    def init(self, config=None):
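        """
        Set up the full pipeline: load data (if not already loaded), build the
        graph environment, the CogKR model, the logger and the optimizer. An
        explicit ``config`` overrides the default configuration.
        """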
        if config is not None:
            self.config = config
        if not self.data_loaded:
            self.load_data()
        if not self.env_built:
            self.build_env(self.config['graph'])
        self.build_model(self.config['model'])
        self.build_logger()
        self.build_optimizer(self.config['optimizer'])

    def load_data(self):
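        """
        Load entity/relation indices, training facts, support/eval splits and
        PageRank scores from ``<root>/data``, plus optional cached artifacts
        (fact_dist, train_graphs, evaluate_graphs, rel2candidates).
        """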
        self.data_directory = os.path.join(self.root_directory, "data")
        self.entity_dict = load_index(
            os.path.join(self.data_directory, "ent2id.txt"))
        self.relation_dict = load_index(
            os.path.join(self.data_directory, "relation2id.txt"))
        self.facts_data = translate_facts(
            load_facts(os.path.join(self.data_directory, "train.txt")),
            self.entity_dict, self.relation_dict)
        self.test_support = translate_facts(
            load_facts(os.path.join(self.data_directory, "test_support.txt")),
            self.entity_dict, self.relation_dict)
        self.valid_support = translate_facts(
            load_facts(os.path.join(self.data_directory, "valid_support.txt")),
            self.entity_dict, self.relation_dict)
        self.test_eval = translate_facts(
            load_facts(os.path.join(self.data_directory, "test_eval.txt")),
            self.entity_dict, self.relation_dict)
        self.valid_eval = translate_facts(
            load_facts(os.path.join(self.data_directory, "valid_eval.txt")),
            self.entity_dict, self.relation_dict)
        # augment
        with open(os.path.join(self.data_directory, 'pagerank.txt')) as file:
            self.pagerank = list(
                map(lambda x: float(x.strip()), file.readlines()))
        if os.path.exists(os.path.join(self.data_directory, "fact_dist")):
            self.fact_dist = unserialize(
                os.path.join(self.data_directory, "fact_dist"))
        else:
            self.fact_dist = None
        if os.path.exists(os.path.join(self.data_directory, "train_graphs")):
            self.train_graphs = unserialize(
                os.path.join(self.data_directory, "train_graphs"))
        else:
            self.train_graphs = None
        if os.path.exists(os.path.join(self.data_directory,
                                       "evaluate_graphs")):
            print("Use evaluate graphs")
            self.evaluate_graphs = unserialize(
                os.path.join(self.data_directory, "evaluate_graphs"))
        else:
            print("Warning: Can't find evaluate graphs")
            self.evaluate_graphs = None
        if os.path.exists(os.path.join(self.data_directory, "rel2candidates")):
            self.rel2candidate = unserialize(
                os.path.join(self.data_directory, "rel2candidates"))
        else:
            self.rel2candidate = {}
        # self.rel2candidate = {self.relation_dict[key]: value for key, value in self.rel2candidate.items() if
        #                       key in self.relation_dict}
        self.id2entity = sorted(self.entity_dict.keys(),
                                key=self.entity_dict.get)
        self.id2relation = sorted(self.relation_dict.keys(),
                                  key=self.relation_dict.get)
        self.data_loaded = True

    def build_env(self, graph_config, build_matrix=True):
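        """
        Build the knowledge graph (``KG``) and the ``Trainer`` that holds the
        train/validate/test tasks, using the reverse-relation index derived
        from the relation dictionary.
        """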
        self.config['graph'] = graph_config
        self.reverse_relation = [
            self.relation_dict[inverse_relation(relation)]
            for relation in self.id2relation
        ]
        self.kg = KG(self.facts_data,
                     entity_num=len(self.entity_dict),
                     relation_num=len(self.relation_dict),
                     node_scores=self.pagerank,
                     build_matrix=build_matrix,
                     **graph_config)
        self.trainer = Trainer(self.kg,
                               self.facts_data,
                               reverse_relation=self.reverse_relation,
                               cutoff=3,
                               train_graphs=self.train_graphs,
                               validate_tasks=(self.valid_support,
                                               self.valid_eval),
                               test_tasks=(self.test_support, self.test_eval),
                               evaluate_graphs=self.evaluate_graphs,
                               id2entity=self.id2entity,
                               id2relation=self.id2relation,
                               rel2candidate=self.rel2candidate,
                               fact_dist=self.fact_dist,
                               **self.config.get('trainer', {}))
        self.env_built = True

    def load_model(self, path, batch_id):
        """
        Kept for backward compatibility with older checkpoints.
        """
        config_path = os.path.join(path, 'config.json')
        if os.path.exists(config_path):
            self.config = unserialize(os.path.join(path, 'config.json'))
            self.cogKR = CogKR(graph=self.kg,
                               entity_dict=self.entity_dict,
                               relation_dict=self.relation_dict,
                               device=self.device,
                               **self.config['model']).to(self.device)
            model_state = torch.load(
                os.path.join(path,
                             str(batch_id) + ".model.dict"))
            self.cogKR.load_state_dict(model_state)

    def load_state(self, path, train=True):
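        """
        Restore a checkpoint written by ``save_state``: rebuild the model and
        load its weights; if ``train`` is True, also restore the optimizer,
        logger, batch counter, best results and running loss statistics.
        """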
        state = torch.load(path)
        if 'config' in state:
            self.config = state['config']
        self.build_model(self.config['model'])
        self.cogKR.load_state_dict(state['model'])
        if train:
            self.build_optimizer(self.config['optimizer'])
            self.optimizer.load_state_dict(state['optimizer'])
            self.log_directory = os.path.dirname(path)
            if 'batch_id' in state:
                self.batch_id = state['batch_id']
            else:
                self.batch_id = int(os.path.basename(path).split('.')[0])
            self.build_logger(self.log_directory, self.batch_id)
            if 'best_results' in state:
                self.best_results = state['best_results']
            else:
                self.best_results = {}
            self.total_graph_loss = state.get('graph_loss', 0.0)
            self.total_rank_loss = state.get('rank_loss', 0.0)
            self.total_graph_size = state.get('graph_size', 0.0)
            self.total_reward = state.get('reward', 0.0)

    def load_pretrain(self, path):
        state_dict = torch.load(path)
        self.cogKR.summary.load_state_dict(state_dict)

    def build_pretrain_model(self, model_config):
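        """
        Build the standalone ``Summary`` module, with freshly initialized
        entity and relation embeddings, used for pretraining.
        """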
        self.config['model'] = model_config
        entity_embeddings = nn.Embedding(len(self.entity_dict) + 1,
                                         model_config['embed_size'],
                                         padding_idx=len(self.entity_dict))
        relation_embeddings = nn.Embedding(len(self.relation_dict) + 1,
                                           model_config['embed_size'],
                                           padding_idx=len(self.relation_dict))
        self.summary = Summary(
            model_config.get('hidden_size', model_config['embed_size']),
            graph=self.kg,
            entity_embeddings=entity_embeddings,
            relation_embeddings=relation_embeddings).to(self.device)

    def build_model(self, model_config):
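        """
        Construct the CogKR model from ``model_config`` and expose its agent,
        cognitive graph and summary sub-modules as attributes.
        """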
        self.config['model'] = model_config
        self.cogKR = CogKR(
            graph=self.kg,
            entity_dict=self.entity_dict,
            relation_dict=self.relation_dict,
            max_nodes=model_config['max_nodes'],
            max_neighbors=model_config['max_neighbors'],
            embed_size=model_config['embed_size'],
            hidden_size=model_config['hidden_size'],
            topk=model_config['topk'],
            reward_policy=model_config.get('reward_policy', 'direct'),
            baseline_lambda=model_config.get('baseline_lambda', 0.0),
            onlyS=model_config.get('onlyS', False),
            device=self.device,
            sparse_embed=self.sparse_embed,
            id2entity=self.id2entity,
            id2relation=self.id2relation,
            use_summary=self.config['trainer'].get('meta_learn',
                                                   True)).to(self.device)
        self.agent = self.cogKR.agent
        self.coggraph = self.cogKR.cog_graph
        self.summary = self.cogKR.summary

    def build_pretrain_optimizer(self, optimizer_config):
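        """
        Create the optimizer for pretraining the summary module; entity and
        relation embeddings are frozen when ``pretrain.keep_embed`` is set.
        """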
        self.config['pretrain_optimizer'] = optimizer_config
        if self.config['pretrain'].get('keep_embed', False):
            print("Keep embedding")
            self.summary.entity_embeddings.weight.requires_grad_(False)
            self.summary.relation_embeddings.weight.requires_grad_(False)
        self.parameters = list(
            filter(lambda x: x.requires_grad, self.summary.parameters()))
        optimizer_class = getattr(torch.optim, optimizer_config['name'])
        self.optimizer = optimizer_class(self.parameters,
                                         **optimizer_config['config'])
        self.loss_function = nn.CrossEntropyLoss()
        self.predict_loss = 0.0

    def build_optimizer(self, optimizer_config):
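        """
        Split parameters into embedding, summary and agent groups, each with
        its own learning rate. With ``sparse_embed``, the embeddings get a
        separate ``SparseAdam`` optimizer; the remaining groups share a single
        optimizer of the configured type.
        """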
        self.config['optimizer'] = optimizer_config
        parameter_ids = set()
        self.parameters = []
        self.optim_params = []
        sparse_ids = set()
        self.embed_parameters = list(
            self.cogKR.relation_embeddings.parameters()) + list(
                self.cogKR.entity_embeddings.parameters())
        self.parameters.extend(self.embed_parameters)
        parameter_ids.update(map(id, self.embed_parameters))
        print('Embedding parameters:',
              list(map(lambda x: x.size(), self.embed_parameters)))
        if self.sparse_embed:
            self.embed_optimizer = torch.optim.SparseAdam(
                self.embed_parameters, **optimizer_config['embed'])
        else:
            self.optim_params.append({
                'params': self.embed_parameters,
                **optimizer_config['embed']
            })
        self.summary_parameters = list(
            filter(lambda x: id(x) not in parameter_ids,
                   self.summary.parameters()))
        self.parameters.extend(self.summary_parameters)
        parameter_ids.update(map(id, self.summary_parameters))
        self.agent_parameters = list(
            filter(lambda x: id(x) not in parameter_ids,
                   self.cogKR.parameters()))
        self.parameters.extend(self.agent_parameters)
        if self.sparse_embed:
            sparse_ids.update(
                map(id, self.cogKR.entity_embeddings.parameters()))
            sparse_ids.update(
                map(id, self.cogKR.relation_embeddings.parameters()))
        self.dense_parameters = list(
            filter(lambda x: id(x) not in sparse_ids, self.parameters))
        print(list(map(lambda x: x.size(), self.dense_parameters)))
        self.optim_params.extend([{
            'params': self.summary_parameters,
            **optimizer_config['summary'],
        }, {
            'params': self.agent_parameters,
            **optimizer_config['agent']
        }])
        optimizer_class = getattr(torch.optim, optimizer_config['name'])
        self.optimizer = optimizer_class(self.optim_params,
                                         **optimizer_config['config'])
        self.total_graph_loss, self.total_rank_loss = 0.0, 0.0
        self.total_graph_size, self.total_reward = 0, 0.0

    def save_state(self, is_best=False):
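        """
        Save the config, model and optimizer states together with the current
        training counters to ``best.state`` or ``<batch_id + 1>.state``.
        """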
        if is_best:
            filename = os.path.join(self.log_directory, "best.state")
        else:
            filename = os.path.join(self.log_directory,
                                    str(self.batch_id + 1) + ".state")
        torch.save(
            {
                'config': self.config,
                'model': self.cogKR.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'batch_id': self.batch_id + 1,
                'graph_loss': self.total_graph_loss,
                'rank_loss': self.total_rank_loss,
                'reward': self.total_reward,
                'graph_size': self.total_graph_size,
                'log_file': self.log_file,
                'best_results': self.best_results
            }, filename)

    def evaluate_model(self,
                       mode='test',
                       output=None,
                       save_graph=None,
                       **kwargs):
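        """
        Run evaluation on the test or validation tasks without gradients and
        aggregate the metrics defined in ``measure_dict``.
        """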
        with torch.no_grad():
            self.cogKR.eval()
            current = time.time()
            results = multi_mean_measure(
                self.trainer.evaluate_generator(
                    self.cogKR,
                    self.trainer.evaluate(mode=mode, **kwargs),
                    save_result=output,
                    save_graph=save_graph), self.measure_dict)
            if self.args.inference_time:
                print(time.time() - current)
            self.cogKR.train()
        return results

    def build_logger(self, log_directory=None, batch_id=None):
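        """
        Create a timestamped log directory (serializing the config into it)
        unless an existing one is given, then set up the TensorBoard
        ``SummaryWriter`` and the batch counter, resuming from ``batch_id``
        when provided.
        """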
        self.log_directory = log_directory
        if self.log_directory is None:
            self.log_directory = os.path.join(
                self.root_directory, "log", "-".join(
                    (time.strftime("%m-%d-%H"), self.comment)))
            if not os.path.exists(self.log_directory):
                os.makedirs(self.log_directory)
            serialize(self.config,
                      os.path.join(self.log_directory, 'config.json'),
                      in_json=True)
        self.log_file = os.path.join(self.log_directory, "log")
        if batch_id is None:
            self.writer = SummaryWriter(self.log_file)
            self.batch_sampler = itertools.count()
        else:
            self.writer = SummaryWriter(self.log_file, purge_step=batch_id)
            self.batch_sampler = itertools.count(start=batch_id)

    def pretrain(self, single_step=False):
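        """
        Pretrain the summary module with a cross-entropy prediction objective,
        logging the loss and periodically saving its state dict.
        """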
        for batch_id in tqdm(self.batch_sampler):
            support_pairs, labels = self.trainer.predict_sample(
                self.config['pretrain']['batch_size'])
            labels = torch.tensor(labels, dtype=torch.long, device=self.device)
            scores = self.summary(support_pairs, predict=True)
            loss = self.loss_function(scores, labels)
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.parameters,
                                           0.25,
                                           norm_type='inf')
            self.optimizer.step()

            self.predict_loss += loss.item()
            log_interval = self.config['pretrain'].get('log_interval', 1000)
            if (batch_id + 1) % log_interval == 0:
                print(self.predict_loss)
                self.writer.add_scalar('predict_loss',
                                       self.predict_loss / log_interval,
                                       batch_id)
                self.predict_loss = 0.0
            if (batch_id + 1) % self.config['pretrain'].get(
                    'evaluate_interval', 10000) == 0:
                torch.save(
                    self.summary.state_dict(),
                    os.path.join(self.log_directory,
                                 'summary.' + str(batch_id + 1) + ".dict"))
            if single_step:
                break

    def log(self):
        interval = self.config['train']['log_interval']
        self.writer.add_scalar('graph_loss', self.total_graph_loss / interval,
                               self.batch_id)
        self.writer.add_scalar('rank_loss', self.total_rank_loss / interval,
                               self.batch_id)
        self.writer.add_scalar('reward', self.total_reward / interval,
                               self.batch_id)
        self.writer.add_scalar('graph_size', self.total_graph_size / interval,
                               self.batch_id)
        print(self.total_graph_loss, self.total_rank_loss)
        self.total_graph_loss, self.total_rank_loss = 0.0, 0.0
        self.total_graph_size, self.total_reward = 0, 0.0

    def train(self, single_step=False):
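        """
        Main training loop: sample batches from the trainer, optimize the
        weighted sum of graph and rank losses with gradient clipping, log
        running statistics, and periodically evaluate, saving the best state
        according to ``validate_metric``. Local variables are stashed in
        ``self.local`` for debugging.
        """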
        meta_learn = self.config.get('trainer', {}).get('meta_learn', True)
        validate_metric = self.config.get('train',
                                          {}).get('validate_metric', 'MAP')
        print('Graph loss weight:',
              self.config['train'].get('graph_weight', 1.0))
        try:
            for self.batch_id in self.tqdm_wrapper(self.batch_sampler):
                support_pairs, query_heads, query_tails, relations, graphs = self.trainer.sample(
                    self.config['train']['batch_size'])
                if meta_learn:
                    graph_loss, rank_loss = self.cogKR(
                        query_heads,
                        end_entities=query_tails,
                        support_pairs=support_pairs,
                        evaluate=False,
                        stochastic=True)
                else:
                    graph_loss, rank_loss = self.cogKR(
                        query_heads,
                        end_entities=query_tails,
                        relations=relations,
                        evaluate=False,
                        stochastic=True)
                self.optimizer.zero_grad()
                if self.sparse_embed:
                    self.embed_optimizer.zero_grad()
                (self.config['train'].get('graph_weight', 1.0) * graph_loss +
                 rank_loss).backward()
                torch.nn.utils.clip_grad_norm_(self.dense_parameters,
                                               0.25,
                                               norm_type='inf')
                self.optimizer.step()
                if self.sparse_embed:
                    self.embed_optimizer.step()
                if torch.isnan(graph_loss) or torch.isnan(rank_loss):
                    break
                else:
                    self.total_graph_loss += graph_loss.item()
                    self.total_rank_loss += rank_loss.item()
                    self.total_reward += self.cogKR.reward
                    self.total_graph_size += self.cogKR.graph_size
                if (self.batch_id +
                        1) % self.config['train']['log_interval'] == 0:
                    self.log()
                if (self.batch_id +
                        1) % self.config['train']['evaluate_interval'] == 0:
                    with torch.no_grad():
                        test_results = self.evaluate_model(mode='test')
                        validate_results = self.evaluate_model(mode='valid')
                        print("Validate results:", validate_results)
                        update = False
                        for key, value in test_results.items():
                            self.writer.add_scalar(key, value, self.batch_id)
                        if validate_metric not in self.best_results or validate_results[
                                validate_metric] >= self.best_results[
                                    validate_metric]:
                            print("Test results:", test_results)
                            self.save_state(is_best=True)
                        for key, value in validate_results.items():
                            if key not in self.best_results or value > self.best_results[
                                    key]:
                                self.best_results[key] = value
                self.local = locals()
                if single_step:
                    break
        except:
            self.local = locals()
            raise

    def get_fact_dist(self, ignore_relation=True):
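        """
        For every training fact, compute the shortest-path distance between
        head and tail after temporarily removing the fact itself (and, with
        ``ignore_relation``, all edges of the same and the reverse relation).
        Unreachable pairs get distance -1.
        """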
        graph = self.kg.to_networkx(multi=True, neighbor_limit=256)
        fact_dist = {}
        for relation, pairs in tqdm(self.trainer.train_query.items()):
            deleted_edges = []
            if ignore_relation:
                reverse_relation = self.reverse_relation[relation]
                for head, tail in itertools.chain(
                        pairs, self.trainer.train_support[relation],
                        self.trainer.train_query[reverse_relation],
                        self.trainer.train_support[reverse_relation]):
                    try:
                        graph.remove_edge(head, tail, relation)
                        deleted_edges.append((head, tail, relation))
                    except NetworkXError:
                        pass
                    try:
                        graph.remove_edge(head, tail, reverse_relation)
                        deleted_edges.append((head, tail, reverse_relation))
                    except NetworkXError:
                        pass
            for head, tail in itertools.chain(
                    self.trainer.train_query[relation],
                    self.trainer.train_support[relation]):
                delete_edge = False
                try:
                    graph.remove_edge(head, tail, relation)
                    delete_edge = True
                except NetworkXError:
                    pass
                try:
                    dist = shortest_path_length(graph, head, tail)
                except (NetworkXNoPath, KeyError):
                    dist = -1
                fact_dist[(head, relation, tail)] = dist
                if delete_edge:
                    graph.add_edge(head, tail, relation)
            graph.add_edges_from(deleted_edges)
        return fact_dist

    def get_dist_dict(self, mode='test', by_relation=True):
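        """
        Compute shortest-path distances for the test or validation ground
        pairs and report the distance distribution, optionally per relation.
        """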
        self.graph = self.kg.to_networkx(multi=False)
        global_dist_count = defaultdict(int)
        fact_dist = {}
        if mode == 'test':
            relations = self.trainer.test_relations
        elif mode == 'valid':
            relations = self.trainer.validate_relations
        else:
            raise NotImplementedError
        for relation in relations:
            dist_count = defaultdict(int)
            for head, tail in self.trainer.task_ground[relation]:
                try:
                    dist = shortest_path_length(self.graph, head, tail)
                except networkx.NetworkXNoPath:
                    dist = -1
                dist_count[dist] += 1
                global_dist_count[dist] += 1
                fact_dist[(head, relation, tail)] = dist
            if by_relation:
                print(relation, sorted(dist_count.items(), key=lambda x: x[0]))
        print(sorted(global_dist_count.items(), key=lambda x: x[0]))
        return fact_dist, global_dist_count

    def get_onehop_ratio(self):
        e1e2_rel = {}
        for relation, pairs in self.trainer.train_support.items():
            for pair in pairs:
                e1e2_rel.setdefault(pair, set())
                e1e2_rel[pair].add(relation)

        sums, num = 0, 0
        for relation, pairs in self.trainer.task_ground.items():
            print(relation)
            for head, tail in pairs:
                num += 1
                if (head, tail) in e1e2_rel:
                    sums += 1
                    print(e1e2_rel[(head, tail)])
        return sums / num

    def get_test_fact_num(self):
        sums = 0
        for task in self.trainer.test_relations:
            sums += len(self.trainer.task_ground[task])
        return sums

    def save_to_hyper(self, data_dir):
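        """
        Export the training facts (with task supports), validation and test
        triples, a graph file with reverse relations, and the entity/relation
        dictionaries as tab-separated files under ``data_dir``.
        """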
        if not os.path.exists(data_dir):
            os.mkdir(data_dir)

        def save_to_file(data, path):
            with open(path, "w") as output:
                for head, relation, tail in data:
                    output.write("{}\t{}\t{}\n".format(head, relation, tail))

        def save_dict(data, path):
            with open(path, "w") as output:
                for entry, idx in data.items():
                    output.write("{}\t{}\n".format(idx, entry))

        facts_data = list(
            filter(lambda x: not self.id2relation[x[1]].endswith("_inv"),
                   self.facts_data))
        facts_data = list(
            map(
                lambda x: (self.id2entity[x[0]], self.id2relation[x[1]],
                           self.id2entity[x[2]]), facts_data))
        supports = [(self.id2entity[head], self.id2relation[relation],
                     self.id2entity[tail])
                    for relation, (head,
                                   tail) in self.trainer.task_support.items()]
        facts_data = list(itertools.chain(facts_data, supports))
        valid_evaluate = [(self.id2entity[head], self.id2relation[relation],
                           self.id2entity[tail])
                          for relation in self.trainer.validate_relations
                          for head, tail in self.trainer.task_ground[relation]]
        test_evaluate = [(self.id2entity[head], self.id2relation[relation],
                          self.id2entity[tail])
                         for relation in self.trainer.test_relations
                         for head, tail in self.trainer.task_ground[relation]]
        save_to_file(
            itertools.chain(facts_data, *itertools.repeat(supports, 1)),
            os.path.join(data_dir, 'train.txt'))
        save_to_file(valid_evaluate, os.path.join(data_dir, "dev.txt"))
        save_to_file(test_evaluate, os.path.join(data_dir, "test.txt"))
        save_to_file(add_reverse_relations(facts_data),
                     os.path.join(data_dir, "graph.txt"))
        save_dict(self.entity_dict, os.path.join(data_dir, 'entities.dict'))
        save_dict(self.relation_dict, os.path.join(data_dir, 'relations.dict'))

    def save_to_multihop(self, data_dir):
        """
        Generate data files in the format used by https://github.com/salesforce/MultiHopKG.
        """
        if not os.path.exists(data_dir):
            os.mkdir(data_dir)

        def save_to_file(data, path):
            with open(path, "w") as output:
                for head, relation, tail in data:
                    output.write("{}\t{}\t{}\n".format(head, tail, relation))

        facts_data = list(
            filter(lambda x: not self.id2relation[x[1]].endswith("_inv"),
                   self.facts_data))
        facts_data = list(
            map(
                lambda x: (self.id2entity[x[0]], self.id2relation[x[1]],
                           self.id2entity[x[2]]), facts_data))
        supports = [(self.id2entity[head], relation, self.id2entity[tail])
                    for relation, (head,
                                   tail) in self.trainer.task_support.items()]
        valid_evaluate = [(self.id2entity[head], relation,
                           self.id2entity[tail])
                          for relation in self.trainer.validate_relations
                          for head, tail in self.trainer.task_ground[relation]]
        test_evaluate = [(self.id2entity[head], relation, self.id2entity[tail])
                         for relation in self.trainer.test_relations
                         for head, tail in self.trainer.task_ground[relation]]
        save_to_file(itertools.chain(facts_data, supports),
                     os.path.join(data_dir, 'raw.kb'))
        save_to_file(itertools.chain(facts_data, supports),
                     os.path.join(data_dir, "train.triples"))
        save_to_file(valid_evaluate, os.path.join(data_dir, "dev.triples"))
        save_to_file(test_evaluate, os.path.join(data_dir, "test.triples"))
        with open(os.path.join(data_dir, 'raw.pgrk'), "w") as output:
            for key, value in enumerate(self.pagerank):
                output.write("{}\t:{}\n".format(self.id2entity[key], value))
        with open(os.path.join(data_dir, "rel2candidates"), "w") as output:
            for relation, candidates in self.rel2candidate.items():
                for candidate in candidates:
                    output.write("{}\t{}\n".format(relation,
                                                   self.id2entity[candidate]))
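

# A minimal usage sketch, assuming the data files under ``<root>/data`` and the
# project-specific helpers imported above are available. The command-line
# arguments here are hypothetical placeholders; the project's real entry point
# may expose different options.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--root', default='.',
                        help='directory that contains the data/ folder')
    parser.add_argument('--comment', default='run')
    parser.add_argument('--inference_time', action='store_true')
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    main = Main(args,
                root_directory=args.root,
                device=device,
                comment=args.comment,
                tqdm_wrapper=tqdm)
    # build_model requires 'hidden_size', which the default config above does
    # not set, so supply one before initializing.
    main.config['model']['hidden_size'] = main.config['model']['embed_size']
    main.init()                    # load data, build graph, model, logger, optimizer
    main.train(single_step=True)   # drop single_step to train until interrupted
    print(main.evaluate_model(mode='valid'))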