def testDeepMatcDocUsrProjection(self):
    """Tests the similarity feature count with doc/usr projection enabled.

    Builds a DeepMatch model in EVAL mode with `use_doc_projection` and
    `use_usr_projection` switched on, then checks both the reported
    `num_sim_ftrs` and the shape of the evaluated `sim_ftrs` tensor.
    """
    # Work on a shallow copy so the overrides below stay local to this test.
    projection_hparams = copy.copy(self.hparams)

    with tf.Graph().as_default():
        query = tf.constant(self.query, dtype=tf.int32)
        doc_fields = self._get_constant_doc_fields()
        usr_fields = self._get_constant_usr_fields()
        wide_ftrs = tf.constant(self.wide_ftrs, dtype=tf.float32)

        projection_hparams.use_doc_projection = True
        projection_hparams.use_usr_projection = True
        projection_hparams.num_usr_fields = len(usr_fields)
        projection_hparams.explicit_empty = False

        # num_sim_ftrs should be
        #   doc projection size * (query size + usr projection size)
        doc_projection_size = 1
        query_size = 1
        usr_projection_size = 1
        expected_num_sim_ftrs = doc_projection_size * (query_size + usr_projection_size)

        model = deep_match.DeepMatch(query,
                                     wide_ftrs,
                                     doc_fields,
                                     projection_hparams,
                                     tf.estimator.ModeKeys.EVAL,
                                     usr_fields=usr_fields)
        deep_ftr_model = model.deep_ftr_model

        # The statically reported count is checked before any session runs.
        self.assertAllEqual(deep_ftr_model.num_sim_ftrs, expected_num_sim_ftrs)

        with tf.Session() as session:
            session.run(tf.global_variables_initializer())
            sim_ftrs_value = session.run(deep_ftr_model.sim_ftrs)
            # Evaluated shape is group_size followed by the feature count.
            self.assertAllEqual(sim_ftrs_value.shape,
                                self.group_size + [expected_num_sim_ftrs])
def testDeepMatchWithUsrField(self):
    """Tests the width of `all_ftrs` when user fields are supplied.

    Uses the 'concat' embedding similarity function and checks that the last
    dimension of `all_ftrs` equals the concatenated deep features of every
    text input plus the wide feature width.
    """
    usr_hparams = copy.copy(self.hparams)

    with tf.Graph().as_default():
        query = tf.constant(self.query, dtype=tf.int32)
        wide_ftrs = tf.constant(self.wide_ftrs, dtype=tf.float32)
        doc_fields = self._get_constant_doc_fields()
        usr_fields = self._get_constant_usr_fields()

        usr_hparams.emb_sim_func = ['concat']
        usr_hparams.num_usr_fields = len(usr_fields)

        model = deep_match.DeepMatch(query,
                                     wide_ftrs,
                                     doc_fields,
                                     usr_hparams,
                                     tf.estimator.ModeKeys.EVAL,
                                     usr_fields=usr_fields)

        with self.test_session() as session:
            session.run(tf.global_variables_initializer())
            wide_ftrs_value = session.run(wide_ftrs)
            all_ftrs_value = session.run(model.all_ftrs)

            # all_ftrs should contain information about wide_ftrs, usr_fields
            # and doc_fields; the extra 1 is presumably the query -- confirm.
            num_text_inputs = len(doc_fields) + len(usr_fields) + 1
            per_input_size = (model.deep_ftr_model.ftr_size
                              + int(usr_hparams.explicit_empty))
            expected_width = (num_text_inputs * per_input_size
                              + wide_ftrs_value.shape[-1])

            self.assertAllEqual(all_ftrs_value.shape[-1], expected_width)
def testDeepMatchNoQuery(self):
    """Tests DeepMatch with query as None.

    Two conditions must be met when query is None:
        1. query_ftrs of model.text_encoding_model must be None
        2. all_ftrs should only contain information about doc_ftrs and
           wide_ftrs
    """
    # Work on a shallow copy so the ftr_ext override stays local to this test.
    hparams = copy.copy(self.hparams)

    with tf.Graph().as_default():
        query = None
        doc_fields = self._get_constant_doc_fields()
        wide_ftrs = tf.constant(self.wide_ftrs, dtype=tf.float32)

        # Exercise the no-query path with the CNN feature extractor.
        hparams.ftr_ext = 'cnn'
        dm = deep_match.DeepMatch(query,
                                  wide_ftrs,
                                  doc_fields,
                                  hparams,
                                  tf.estimator.ModeKeys.EVAL)

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            wide_ftrs_value = sess.run(wide_ftrs)
            all_ftrs_value = sess.run(dm.all_ftrs)

            # all_ftrs should only contain wide_ftrs and doc_fields.
            doc_ftrs_size_after_cnn = (len(doc_fields)
                                       * len(hparams.filter_window_sizes)
                                       * hparams.num_filters)
            # NOTE(review): the trailing len(doc_fields) term adds one extra
            # feature per doc field; where it comes from is not visible in
            # this test -- confirm against the model code.
            expected_width = (doc_ftrs_size_after_cnn
                              + wide_ftrs_value.shape[-1]
                              + len(doc_fields))
            self.assertAllEqual(all_ftrs_value.shape[-1], expected_width)

            # query_ftrs of model.text_encoding_model must be None. Use an
            # identity check instead of routing None through array comparison.
            self.assertIsNone(dm.deep_ftr_model.text_encoding_model.query_ftrs)
def testDeepMatchNaNFtrs(self):
    """Tests that NaN wide features come out of DeepMatch as zeros.

    The first entry along axis 0 of the wide features is set to NaN before
    the model is built; the model's `_wide_ftrs` is then expected to match
    the same array with that entry zeroed, and scores must have shape [2, 3].
    """
    query = tf.constant(self.query, dtype=tf.int32)
    doc_fields = self._get_constant_doc_fields()

    # Input with NaNs injected, and the values expected after NaN removal.
    wide_ftrs_with_nan = np.copy(self.wide_ftrs)
    expected_wide_ftrs = np.copy(self.wide_ftrs)
    wide_ftrs_with_nan[0] = np.nan
    expected_wide_ftrs[0] = 0

    wide_ftrs = tf.constant(wide_ftrs_with_nan, dtype=tf.float32)
    model = deep_match.DeepMatch(query,
                                 wide_ftrs,
                                 doc_fields,
                                 self.hparams,
                                 tf.estimator.ModeKeys.EVAL)

    with self.test_session() as session:
        session.run(tf.global_variables_initializer())
        fetched = session.run([model._wide_ftrs, model.scores])
        model_wide_ftrs, scores = fetched

        # Check values of the cleaned wide features and the score shape.
        self.assertAllClose(model_wide_ftrs, expected_wide_ftrs)
        self.assertAllEqual(scores.shape, [2, 3])
def testDeepMatchClassificationNoQueryNoWide(self):
    """Tests classification score shape with no query and no wide features.

    Only two doc fields are passed to DeepMatch; the scores are expected to
    have shape [2, num_classes].
    """
    num_classes = 7
    field_shape = (2, 1, 4)

    # NOTE(review): np.random.rand draws floats in [0, 1) which are then
    # stored as int32, so the ids are presumably all 0 -- confirm intended.
    first_doc_field = tf.constant(np.random.rand(*field_shape), dtype=tf.int32)
    second_doc_field = tf.constant(np.random.rand(*field_shape), dtype=tf.int32)
    doc_fields = [first_doc_field, second_doc_field]

    # Shallow copy so the num_classes override stays local to this test.
    classification_hparams = copy.copy(self.hparams)
    classification_hparams.num_classes = num_classes

    model = deep_match.DeepMatch(None,
                                 None,
                                 doc_fields,
                                 classification_hparams,
                                 tf.estimator.ModeKeys.EVAL)

    with self.test_session() as session:
        session.run(tf.global_variables_initializer())
        scores = session.run(model.scores)

        # Check sizes and shapes
        expected_shape = [2, classification_hparams.num_classes]
        self.assertAllEqual(scores.shape, expected_shape)
def testDeepMatch(self):
    """Tests the shape of DeepMatch scores for query, doc and wide inputs."""
    query = tf.constant(self.query, dtype=tf.int32)
    doc_fields = self._get_constant_doc_fields()
    wide_ftrs = tf.constant(self.wide_ftrs, dtype=tf.float32)

    model = deep_match.DeepMatch(query,
                                 wide_ftrs,
                                 doc_fields,
                                 self.hparams,
                                 tf.estimator.ModeKeys.EVAL)

    with self.test_session() as session:
        session.run(tf.global_variables_initializer())
        score_values = session.run(model.scores)

        # Check sizes and shapes
        expected_shape = [2, 3]
        self.assertAllEqual(score_values.shape, expected_shape)
def testDeepMatchMultitaskRanking(self):
    """Tests the shape of DeepMatch scores in a multitask ranking setup.

    Configures two task ids with weights 0.2 and 0.8, passes a per-example
    task id tensor, and checks that the scores have shape [2, 3].
    """
    query = tf.constant(self.query, dtype=tf.int32)
    doc_fields = self._get_constant_doc_fields()
    wide_ftrs = tf.constant(self.wide_ftrs, dtype=tf.float32)
    task_id_field = tf.constant([1, 0])

    # Shallow copy so the multitask overrides stay local to this test.
    multitask_hparams = copy.copy(self.hparams)
    multitask_hparams.task_ids = [0, 1]
    multitask_hparams.task_weights = [0.2, 0.8]

    model_kwargs = dict(
        query=query,
        wide_ftrs=wide_ftrs,
        doc_fields=doc_fields,
        hparams=multitask_hparams,
        mode=tf.estimator.ModeKeys.EVAL,
        task_id_field=task_id_field,
    )
    model = deep_match.DeepMatch(**model_kwargs)

    with self.test_session() as session:
        session.run(tf.global_variables_initializer())
        score_values = session.run(model.scores)

        # Check sizes and shapes
        self.assertAllEqual(score_values.shape, [2, 3])