Exemplo n.º 1
0
    def testDeepMatcDocUsrProjection(self):
        """Tests DeepMatch with doc/usr fields projection.

        Enables both doc and user field projection and checks:
        1. the static ``num_sim_ftrs`` reported by ``dm.deep_ftr_model``;
        2. the runtime shape of ``dm.deep_ftr_model.sim_ftrs``, which must be
           ``self.group_size + [num_sim_ftrs]``.
        """
        hparams = copy.copy(self.hparams)
        # ftr_ext is not overridden here; whatever self.hparams configures is used.
        with tf.Graph().as_default():
            query = tf.constant(self.query, dtype=tf.int32)
            doc_fields = self._get_constant_doc_fields()
            usr_fields = self._get_constant_usr_fields()
            wide_ftrs = tf.constant(self.wide_ftrs, dtype=tf.float32)
            hparams.use_doc_projection = True
            hparams.use_usr_projection = True
            hparams.num_usr_fields = len(usr_fields)
            hparams.explicit_empty = False
            # num_sim_ftrs should be doc projection size * (query + user
            # projection size). NOTE(review): with projection enabled each
            # side presumably collapses to a single vector, hence 1 * (1 + 1).
            expected_num_sim_ftrs = 1 * (1 + 1)
            dm = deep_match.DeepMatch(query,
                                      wide_ftrs,
                                      doc_fields,
                                      hparams,
                                      tf.estimator.ModeKeys.EVAL,
                                      usr_fields=usr_fields)

            self.assertAllEqual(dm.deep_ftr_model.num_sim_ftrs,
                                expected_num_sim_ftrs)
            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                self.assertAllEqual(dm.deep_ftr_model.sim_ftrs.eval().shape,
                                    self.group_size + [expected_num_sim_ftrs])
Exemplo n.º 2
0
    def testDeepMatchWithUsrField(self):
        """Tests that DeepMatch includes user-field features in all_ftrs.

        With ``emb_sim_func = ['concat']``, the last dimension of
        ``dm.all_ftrs`` must equal one ``ftr_size`` block (plus 1 when
        ``explicit_empty`` is set) per text field, plus the wide features.
        The text fields are the doc fields, the user fields and one extra
        field (presumably the query).
        """
        hparams = copy.copy(self.hparams)
        # ftr_ext is not overridden here; whatever self.hparams configures is used.
        with tf.Graph().as_default():
            query = tf.constant(self.query, dtype=tf.int32)
            wide_ftrs = tf.constant(self.wide_ftrs, dtype=tf.float32)
            doc_fields = self._get_constant_doc_fields()
            usr_fields = self._get_constant_usr_fields()
            hparams.emb_sim_func = ['concat']
            hparams.num_usr_fields = len(usr_fields)

            # Query, doc fields and user fields are all supplied.
            dm = deep_match.DeepMatch(query,
                                      wide_ftrs,
                                      doc_fields,
                                      hparams,
                                      tf.estimator.ModeKeys.EVAL,
                                      usr_fields=usr_fields)

            with self.test_session() as sess:
                sess.run(tf.global_variables_initializer())
                wide_ftrs_val = wide_ftrs.eval()

                # all_ftrs should contain information about wide_ftrs,
                # usr_fields and doc_fields. NOTE(review): the "+ 1" field is
                # presumably the query -- confirm against DeepMatch.
                num_text_fields = len(doc_fields) + len(usr_fields) + 1
                per_field_size = (dm.deep_ftr_model.ftr_size +
                                  int(hparams.explicit_empty))
                self.assertAllEqual(
                    dm.all_ftrs.eval().shape[-1],
                    num_text_fields * per_field_size + wide_ftrs_val.shape[-1])
Exemplo n.º 3
0
    def testDeepMatchNoQuery(self):
        """Tests DeepMatch with query as None.

        Two conditions must be met when query is None:
        1. query_ftrs of dm.deep_ftr_model.text_encoding_model must be None
        2. all_ftrs should only contain information about doc_ftrs and wide_ftrs
        """
        hparams = copy.copy(self.hparams)
        with tf.Graph().as_default():
            query = None
            doc_fields = self._get_constant_doc_fields()
            wide_ftrs = tf.constant(self.wide_ftrs, dtype=tf.float32)
            # Force the CNN extractor so the expected doc feature size can be
            # derived from filter_window_sizes and num_filters below.
            hparams.ftr_ext = 'cnn'
            # Build the model without any query field.
            dm = deep_match.DeepMatch(query, wide_ftrs, doc_fields, hparams,
                                      tf.estimator.ModeKeys.EVAL)

            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                wide_ftrs_val = wide_ftrs.eval()

                # all_ftrs should only contain wide_ftrs and doc_fields.
                # Each doc field contributes
                # len(filter_window_sizes) * num_filters CNN features.
                doc_ftrs_size_after_cnn = (len(doc_fields) *
                                           len(hparams.filter_window_sizes) *
                                           hparams.num_filters)
                # NOTE(review): the trailing len(doc_fields) term is presumably
                # one extra feature per doc field (e.g. an explicit-empty
                # indicator) -- confirm against DeepMatch.
                self.assertAllEqual(
                    dm.all_ftrs.eval().shape[-1],
                    doc_ftrs_size_after_cnn + wide_ftrs_val.shape[-1] +
                    len(doc_fields))

                # query_ftrs of the text encoding model must be None.
                self.assertIsNone(
                    dm.deep_ftr_model.text_encoding_model.query_ftrs)
Exemplo n.º 4
0
    def testDeepMatchNaNFtrs(self):
        """Tests that DeepMatch replaces NaN wide features with 0.

        The first entry along axis 0 of the wide features is set to NaN. The
        model's ``_wide_ftrs`` must equal the input with that entry zeroed,
        and scoring must still produce scores of shape [2, 3].
        """
        query = tf.constant(self.query, dtype=tf.int32)
        doc_fields = self._get_constant_doc_fields()

        # Build float32 copies: NaN can then be stored even if self.wide_ftrs
        # is integer-typed, and the expectation matches the tf.float32 input.
        nan_wide_ftrs = np.array(self.wide_ftrs, dtype=np.float32)
        nan_removed_wide_ftrs = np.array(self.wide_ftrs, dtype=np.float32)

        nan_wide_ftrs[0] = np.nan
        nan_removed_wide_ftrs[0] = 0

        wide_ftrs = tf.constant(nan_wide_ftrs, dtype=tf.float32)
        hparams = self.hparams

        dm = deep_match.DeepMatch(query,
                                  wide_ftrs,
                                  doc_fields,
                                  hparams,
                                  tf.estimator.ModeKeys.EVAL)

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            wide_ftrs_dm, scores = sess.run([dm._wide_ftrs, dm.scores])

            # NaN entries must have been replaced with 0.
            self.assertAllClose(wide_ftrs_dm, nan_removed_wide_ftrs)
            # Scoring still yields the expected output shape.
            self.assertAllEqual(scores.shape, [2, 3])
Exemplo n.º 5
0
    def testDeepMatchClassificationNoQueryNoWide(self):
        """Tests DeepMatch classification outputs with no query and no wide features.

        Scores must have shape [2, num_classes].
        """
        # Two doc fields of all-zero int32 token ids, shape (2, 1, 4). Casting
        # np.random.rand values (all in [0, 1)) to int32 truncates them to 0,
        # so zeros are spelled out explicitly and deterministically.
        # NOTE(review): if non-zero token ids were intended, use ids that are
        # valid for the test vocabulary.
        doc_field1 = tf.constant(np.zeros((2, 1, 4), dtype=np.int32),
                                 dtype=tf.int32)
        doc_field2 = tf.constant(np.zeros((2, 1, 4), dtype=np.int32),
                                 dtype=tf.int32)
        doc_fields = [doc_field1, doc_field2]
        hparamscp = copy.copy(self.hparams)
        hparamscp.num_classes = 7
        # query and wide_ftrs are both omitted (None).
        dm = deep_match.DeepMatch(None, None, doc_fields, hparamscp,
                                  tf.estimator.ModeKeys.EVAL)

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            scores = sess.run(dm.scores)
        # Check sizes and shapes.
        self.assertAllEqual(scores.shape, [2, hparamscp.num_classes])
Exemplo n.º 6
0
    def testDeepMatch(self):
        """Builds DeepMatch in EVAL mode and checks that scores have shape [2, 3]."""
        q = tf.constant(self.query, dtype=tf.int32)
        fields = self._get_constant_doc_fields()
        wide = tf.constant(self.wide_ftrs, dtype=tf.float32)

        model = deep_match.DeepMatch(query=q,
                                     wide_ftrs=wide,
                                     doc_fields=fields,
                                     hparams=self.hparams,
                                     mode=tf.estimator.ModeKeys.EVAL)

        with self.test_session() as session:
            session.run(tf.global_variables_initializer())
            score_values = session.run(model.scores)
            # Only the output shape is verified here.
            self.assertAllEqual(score_values.shape, [2, 3])
Exemplo n.º 7
0
    def testDeepMatchMultitaskRanking(self):
        """Checks DeepMatch scoring when hparams configure two ranking tasks."""
        q = tf.constant(self.query, dtype=tf.int32)
        fields = self._get_constant_doc_fields()
        wide = tf.constant(self.wide_ftrs, dtype=tf.float32)
        # Presumably one task id per query (2 here, matching the leading
        # score dimension) -- confirm against DeepMatch.
        task_ids = tf.constant([1, 0])

        mt_hparams = copy.copy(self.hparams)
        mt_hparams.task_ids = [0, 1]
        mt_hparams.task_weights = [0.2, 0.8]

        model = deep_match.DeepMatch(q, wide, fields, mt_hparams,
                                     tf.estimator.ModeKeys.EVAL,
                                     task_id_field=task_ids)

        with self.test_session() as session:
            session.run(tf.global_variables_initializer())
            score_values = session.run(model.scores)
            # Only the output shape is verified here.
            self.assertAllEqual(score_values.shape, [2, 3])