Example #1
    def test_model_intersection(self):
        params = {
            'cm_width': 10,
            'cm_depth': 10,
            'max_set': 100,
            'entity_emb_size': 160,
            'relation_emb_size': 768,
            'vocab_emb_size': 64,
            'train_entity_emb': True,
            'train_relation_emb': True,
        }
        loader = data_loader.DataLoader(params=params,
                                        name='intersection',
                                        root_dir=ROOT_DIR + 'WikiMovies/',
                                        kb_file='kb.txt')
        my_model = model.EmQL('intersection', params, loader)
        self.sess.run(tf.global_variables_initializer())
        candidate_set1 = np.zeros((1, loader.num_entities), dtype=np.float32)
        candidate_set1[0, 1] = 1.0
        candidate_set1[0, 2] = 1.0
        candidate_set2 = np.zeros((1, loader.num_entities), dtype=np.float32)
        candidate_set2[0, 2] = 1.0
        candidate_set2[0, 3] = 1.0
        labels = candidate_set1 * candidate_set2

        loss, tensors = my_model.model_intersection(
            (candidate_set1, candidate_set2, labels), params)
        self.assertGreater(loss.eval(session=self.sess), 0)
        logits = tensors['logits']
        logits = logits.eval(session=self.sess)
        self.assertGreater(logits[0, 2], logits[0, 1])
        self.assertGreater(logits[0, 2], logits[0, 3])
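
The labels above rely on a simple encoding trick: each set is a multi-hot vector over all entities, so element-wise multiplication keeps exactly the entities present in both sets. A minimal, self-contained NumPy illustration of this trick (not part of the test class):

import numpy as np

# Two sets over a toy vocabulary of 5 entities, encoded as multi-hot rows.
set1 = np.array([[0., 1., 1., 0., 0.]], dtype=np.float32)  # entities {1, 2}
set2 = np.array([[0., 0., 1., 1., 0.]], dtype=np.float32)  # entities {2, 3}

# Element-wise multiplication zeroes out entities missing from either set,
# leaving the intersection {2} -- the same construction as
# candidate_set1 * candidate_set2 in the test above.
labels = set1 * set2
assert labels.tolist() == [[0., 0., 1., 0., 0.]]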
Example #2
    def test_model_follow(self):
        params = {
            'cm_width': 10,
            'cm_depth': 10,
            'max_set': 100,
            'entity_emb_size': 160,
            'relation_emb_size': 768,
            'vocab_emb_size': 64,
            'train_entity_emb': True,
            'train_relation_emb': True,
        }
        loader = data_loader.DataLoader(params=params,
                                        name='set_follow',
                                        root_dir=ROOT_DIR + 'WikiMovies/',
                                        kb_file='kb.txt')
        my_model = model.EmQL('set_follow', params, loader)
        self.sess.run(tf.global_variables_initializer())

        subject_set = np.zeros((1, loader.num_facts), dtype=np.float32)
        subject_set[0, 1] = 1.0
        subject_set[0, 2] = 1.0
        relation_set = np.zeros((1, loader.num_facts), dtype=np.float32)
        relation_set[0, 2] = 1.0
        relation_set[0, 3] = 1.0
        labels = subject_set * relation_set

        loss, tensors = my_model.model_follow(
            (subject_set, relation_set, labels), params)
        self.assertGreater(loss.eval(session=self.sess), 0)
        logits = tensors['logits']
        logits = logits.eval(session=self.sess)
        self.assertGreater(logits[0, 2], logits[0, 1])
        self.assertGreater(logits[0, 2], logits[0, 3])
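
Unlike the intersection test above, these multi-hot vectors are built over loader.num_facts rather than loader.num_entities: subject_set and relation_set mark candidate knowledge-base facts by fact id, and the same element-wise product produces labels for the facts matching both.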
Example #3
    def test_model_membership(self):
        params = {
            'cm_width': 10,
            'cm_depth': 10,
            'max_set': 100,
            'entity_emb_size': 160,
            'relation_emb_size': 768,
            'vocab_emb_size': 64,
            'train_entity_emb': True,
            'train_relation_emb': True,
        }
        loader = data_loader.DataLoader(params=params,
                                        name='membership',
                                        root_dir=ROOT_DIR + 'WikiMovies/',
                                        kb_file='kb.txt')
        my_model = model.EmQL('membership', params, loader)
        self.sess.run(tf.global_variables_initializer())

        entity_ids = tf.constant([[0, 1, -1, -1], [2, -1, -1, -1]],
                                 dtype=tf.int32)
        labels = np.zeros((2, loader.num_entities), dtype=np.float32)
        labels[0, 0] = 1.0
        labels[0, 1] = 1.0
        labels[1, 2] = 1.0
        labels = tf.constant(labels, dtype=tf.float32)
        loss, tensors = my_model.model_membership((entity_ids, labels), None)
        self.assertGreater(loss.eval(session=self.sess), 0)
        logits = tensors['logits']
        logits = logits.eval(session=self.sess)
        self.assertGreater(logits[0, 0], 0)
        self.assertGreater(logits[0, 1], 0)
        self.assertGreater(logits[1, 2], 0)
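
The -1 entries in entity_ids above are padding for variable-length entity sets. Here is a hypothetical helper, ids_to_multihot, showing how such padded id rows map to the multi-hot label matrix that the test builds by hand (assuming, as in the test, that -1 always means padding):

import numpy as np

def ids_to_multihot(entity_ids, num_entities):
    # Builds one multi-hot row per padded id list, skipping -1 padding.
    labels = np.zeros((len(entity_ids), num_entities), dtype=np.float32)
    for row, ids in enumerate(entity_ids):
        for eid in ids:
            if eid >= 0:
                labels[row, eid] = 1.0
    return labels

# Reproduces the labels constructed manually above (toy num_entities=4):
# row 0 marks entities {0, 1}, row 1 marks entity {2}.
labels = ids_to_multihot([[0, 1, -1, -1], [2, -1, -1, -1]], num_entities=4)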
Example #4
    def test_model_query2box(self):
        params = {
            'cm_width': 10,
            'cm_depth': 10,
            'max_set': 100,
            'entity_emb_size': 64,
            'relation_emb_size': 64,
            'vocab_emb_size': 64,
            'train_entity_emb': False,
            'train_relation_emb': False,
            'intermediate_top_k': 5,
            'use_cm_sketch': True
        }
        loader = data_loader.DataLoader(params=params,
                                        name='query2box_uc',
                                        root_dir=ROOT_DIR +
                                        'Query2Box/FB15k-237/',
                                        kb_file='kb.txt')
        my_model = model.EmQL('query2box_uc', params, loader)
        self.sess.run(tf.global_variables_initializer())

        ent1 = np.array([1, 2], dtype=np.int32)
        rel1 = np.array([11, 12], dtype=np.int32)
        ent2 = np.array([3, 4], dtype=np.int32)
        rel2 = np.array([13, 14], dtype=np.int32)
        rel3 = np.array([15, 16], dtype=np.int32)
        loss, tensors = my_model.model_query2box(
            'query2box_uc', (ent1, rel1, ent2, rel2, rel3), params)

        # Query2Box will not be trained, so the loss should always be 0.
        self.assertEqual(loss.eval(session=self.sess), 0.0)
        answer_ids = tensors['answer_ids']
        answer_ids = answer_ids.eval(session=self.sess)
        self.assertEqual(answer_ids.shape[1], params['intermediate_top_k'])
Example #5
    def test_sketch(self):
        params = {
            'cm_width': 100,
            'cm_depth': 10,
            'max_set': 100,
            'entity_emb_size': 160,
            'relation_emb_size': 768,
            'vocab_emb_size': 64,
            'train_entity_emb': True,
            'train_relation_emb': True,
            'intermediate_top_k': 10,
            'use_cm_sketch': True
        }
        loader = data_loader.DataLoader(params=params,
                                        name='metaqa2',
                                        root_dir=ROOT_DIR + 'MetaQA/2hop/',
                                        kb_file='kb.txt',
                                        vocab_file='vocab.json')
        my_model = model.EmQL('metaqa2', params, loader)
        self.sess.run(tf.global_variables_initializer())

        entity_weights_np = np.array([[0.1, 0.9, 0.3]], dtype=np.float32)
        entity_ids = tf.constant([[0, 1, 2]], dtype=tf.int32)
        entity_weights = tf.constant(entity_weights_np, dtype=tf.float32)
        sketch = module.create_cm_sketch(entity_ids, entity_weights,
                                         my_model.all_entity_sketches,
                                         params['cm_width'])
        _, entity_weights_from_sketch = module.check_topk_fact_eligible(
            entity_ids, sketch, my_model.all_entity_sketches, params)
        entity_weights_from_sketch_np = entity_weights_from_sketch.eval(
            session=self.sess).astype(np.float32)
        self.assertAllClose(entity_weights_from_sketch_np, entity_weights_np)
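
For intuition, here is a minimal count-min sketch in plain NumPy. It is only a conceptual stand-in: module.create_cm_sketch and module.check_topk_fact_eligible may differ internally, but the round trip below mirrors what the test checks, namely that weights written into a sufficiently wide sketch are recovered almost exactly.

import numpy as np

depth, width = 10, 100
rng = np.random.RandomState(0)
# One bucket per (row, id) pair, standing in for all_entity_sketches.
buckets = rng.randint(0, width, size=(depth, 3))

# Write: add each id's weight into its bucket in every row.
sketch = np.zeros((depth, width), dtype=np.float32)
for i, w in zip([0, 1, 2], [0.1, 0.9, 0.3]):
    sketch[np.arange(depth), buckets[:, i]] += w

# Read: the row-wise minimum upper-bounds the true weight; with a wide
# sketch collisions are unlikely, so the recovered weights match.
recovered = [sketch[np.arange(depth), buckets[:, i]].min() for i in [0, 1, 2]]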
Example #6
    def test_loader_membership(self):
        params = {
            'cm_width': 10,
            'cm_depth': 10,
            'max_set': 100,
        }
        loader = data_loader.DataLoader(params=params,
                                        name='membership',
                                        root_dir=ROOT_DIR + 'WikiMovies/',
                                        kb_file='kb.txt')
        for one_data in loader.train_data_membership + loader.test_data_membership:
            self.assertGreater(len(one_data), 1)
Example #7
    def test_loader_webqsp(self):
        params = {
            'cm_width': 10,
            'cm_depth': 10,
            'max_set': 100,
        }
        loader = data_loader.DataLoader(params=params,
                                        name='webqsp',
                                        root_dir=ROOT_DIR + 'WebQSP/',
                                        kb_file='kb_webqsp_constraint2.txt')
        self.assertGreater(len(loader.train_data_webqsp), 2500)
        self.assertGreater(len(loader.test_data_webqsp), 1500)
        self.assertGreater(len(loader.fact2id), 1000000)
Example #8
    def test_loader_follow(self):
        params = {
            'cm_width': 10,
            'cm_depth': 10,
            'max_set': 100,
        }
        loader = data_loader.DataLoader(params=params,
                                        name='set_follow',
                                        root_dir=ROOT_DIR + 'WikiMovies/',
                                        kb_file='kb.txt')
        for subj_factids, rel_factids in (loader.train_data_follow +
                                          loader.test_data_follow):
            self.assertGreaterEqual(len(subj_factids & rel_factids), 1)
Example #9
    def test_loader_intersection_union(self):
        params = {
            'cm_width': 10,
            'cm_depth': 10,
            'max_set': 100,
        }
        loader = data_loader.DataLoader(params=params,
                                        name='intersection',
                                        root_dir=ROOT_DIR + 'WikiMovies/',
                                        kb_file='kb.txt')
        for set1, set2 in (loader.train_data_set_pair +
                           loader.test_data_set_pair):
            self.assertGreaterEqual(len(set1 & set2), 1)
Example #10
    def test_loader_query2box(self):
        params = {
            'cm_width': 10,
            'cm_depth': 10,
            'max_set': 100,
        }
        loader = data_loader.DataLoader(params=params,
                                        name='query2box_uc',
                                        root_dir=ROOT_DIR +
                                        'Query2Box/FB15k-237/',
                                        kb_file='kb.txt')
        self.assertIsNone(loader.train_data_query2box)
        self.assertEqual(len(loader.test_data_query2box), 5000)
        self.assertEqual(len(loader.test_data_query2box[0]), 3)
Example #11
    def test_loader_metaqa(self):
        params = {
            'cm_width': 10,
            'cm_depth': 10,
            'max_set': 100,
        }
        loader = data_loader.DataLoader(params=params,
                                        name='metaqa2',
                                        root_dir=ROOT_DIR + 'MetaQA/2hop/',
                                        kb_file='kb.txt',
                                        vocab_file='vocab.json')
        for question, _, answer_fact_ids in (loader.train_data_metaqa +
                                             loader.test_data_metaqa):
            self.assertGreater(len(question), 0)
            self.assertGreaterEqual(len(answer_fact_ids), 1)
Example #12
    def test_model_webqsp(self):
        kb_index_str = KB_INDEX_DIR
        bert_handle_str = 'https://tfhub.dev/google/bert_uncased_L-12_H-768_A-12/1'
        params = {
            'cm_width': 10,
            'cm_depth': 10,
            'max_set': 100,
            'entity_emb_size': 160,
            'relation_emb_size': 768,
            'vocab_emb_size': 64,
            'train_entity_emb': True,
            'train_relation_emb': True,
            'use_cm_sketch': True,
            'kb_index': kb_index_str,
            'bert_handle': bert_handle_str,
            'train_bert': False,
            'intermediate_top_k': 1000
        }
        loader = data_loader.DataLoader(params=params,
                                        name='webqsp',
                                        root_dir=ROOT_DIR + 'WebQSP/',
                                        kb_file='kb_webqsp_constraint2.txt')
        my_model = model.EmQL('webqsp', params, loader)
        question = tf.constant([[0, 1, 3, 0, 0]], dtype=tf.int32)
        segment_ids = tf.constant([[0, 0, 0, 0, 0]], dtype=tf.int32)
        question_mask = tf.constant([[1, 1, 1, 0, 0]], dtype=tf.int32)
        q_entity_id = 0
        question_entity_id = tf.constant([q_entity_id], dtype=tf.int32)
        constraint_entity_id = tf.constant([q_entity_id], dtype=tf.int32)
        question_entity_sketch = loader.cm_context.get_sketch(xs=[0])
        question_entity_sketch = tf.constant(question_entity_sketch,
                                             dtype=tf.float32)
        question_entity_sketch = tf.expand_dims(question_entity_sketch, axis=0)
        constraint_entity_sketch = question_entity_sketch
        answers = tf.constant([[0, 1, 3, -1, -1]], dtype=tf.int32)
        loss, _ = my_model.model_webqsp(
            (question, segment_ids, question_mask, question_entity_id,
             question_entity_sketch, constraint_entity_id,
             constraint_entity_sketch, answers),
            params,
            top_k=params['intermediate_top_k'])
        self.sess.run(tf.global_variables_initializer())
        self.sess.run(tf.local_variables_initializer())
        self.assertGreater(loss.eval(session=self.sess), 0)
Example #13
    def test_eval(self):
        root_dir = self.tmp_dir
        params = {
            'cm_width': 10,
            'cm_depth': 10,
            'max_set': 100,
            'entity_emb_size': 64,
            'relation_emb_size': 64,
            'vocab_emb_size': 64,
            'train_entity_emb': False,
            'train_relation_emb': False,
            'intermediate_top_k': 10,
            'use_cm_sketch': True
        }

        loader = data_loader.DataLoader(params=params,
                                        name='query2box_ic',
                                        root_dir=root_dir,
                                        kb_file='kb.txt')

        q2b_metrics = emql_eval.Query2BoxMetrics(self.task, root_dir, loader)
        ent1, rel1, ent2, rel2, rel3 = 0, 0, 1, 2, 4
        features = np.array([ent1, rel1, ent2, rel2, rel3])
        # Predictions that are easy answers will be skipped during
        # evaluation. Easy answers are the answers that appear in
        # all_answers but not in hard_answers. For clarity, the answer
        # sets of this query are copied below:
        #     all_answers = {3, 4, 5}
        #     easy_answers = {3, 4}
        #     hard_answers = {5}
        answer_ids = np.array([4, 2, 5])
        tf_prediction = {'query': features, 'answer_ids': answer_ids}
        q2b_metrics.eval(tf_prediction)

        # hits@1 is 0 because the first prediction after skipping easy
        # answers is 2, which is not in hard_answers. hits@3 and hits@10
        # are 1.0 because the second non-easy prediction, 5, is in
        # hard_answers, i.e. the hard answer has rank 2 (so mrr = 0.5).
        self.assertEqual(q2b_metrics.metrics['hits@1'], 0.0)
        self.assertEqual(q2b_metrics.metrics['hits@3'], 1.0)
        self.assertEqual(q2b_metrics.metrics['hits@10'], 1.0)
        self.assertEqual(q2b_metrics.metrics['mrr'], 0.5)
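
The comments above fully determine the metric: easy answers are skipped, and the rank of the first hard answer drives hits@k and MRR. A hypothetical re-implementation of that ranking rule (emql_eval.Query2BoxMetrics may compute it differently internally):

def filtered_rank(predictions, easy_answers, hard_answers):
    # Rank of the first hard answer, not counting skipped easy answers.
    rank = 0
    for pred in predictions:
        if pred in easy_answers:
            continue  # skipped: neither a hit nor a miss
        rank += 1
        if pred in hard_answers:
            return rank
    return None

# For answer_ids = [4, 2, 5]: 4 is skipped (easy), 2 misses at rank 1,
# and 5 hits at rank 2 -- so hits@1 = 0, hits@3 = hits@10 = 1, mrr = 0.5.
rank = filtered_rank([4, 2, 5], easy_answers={3, 4}, hard_answers={5})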
Example #14
    def test_model_metaqa(self):
        params = {
            'cm_width': 10,
            'cm_depth': 10,
            'max_set': 100,
            'entity_emb_size': 160,
            'relation_emb_size': 768,
            'vocab_emb_size': 64,
            'train_entity_emb': True,
            'train_relation_emb': True,
            'intermediate_top_k': 10,
            'use_cm_sketch': True
        }
        loader = data_loader.DataLoader(params=params,
                                        name='metaqa2',
                                        root_dir=ROOT_DIR + 'MetaQA/2hop/',
                                        kb_file='kb.txt',
                                        vocab_file='vocab.json')
        my_model = model.EmQL('metaqa2', params, loader)
        question = tf.constant([[0, 1, 3, -1, -1]], dtype=tf.int32)
        q_entity_id = 0
        question_entity_id = tf.constant([q_entity_id], dtype=tf.int32)
        question_entity_sketch = loader.cm_context.get_sketch(xs=[0])
        question_entity_sketch = tf.constant(question_entity_sketch,
                                             dtype=tf.float32)
        question_entity_sketch = tf.expand_dims(question_entity_sketch, axis=0)
        answers = np.zeros((1, loader.num_facts), dtype=np.float32)
        answers[0, 100] = 1
        answers = tf.constant(answers, dtype=tf.float32)
        loss, _ = my_model.model_metaqa(
            (question, question_entity_id, question_entity_sketch, answers),
            params,
            hop=2,
            top_k=params['intermediate_top_k'])
        self.sess.run(tf.global_variables_initializer())
        self.assertGreater(loss.eval(session=self.sess), 0)