Esempio n. 1
0
 def testGetCorrectMatchedRowIndices(self):
     """Matched columns report the row index they were matched to.

     In match_results [3, 1, -1, 0, -1, 5, -2] the non-negative entries are
     matches; their values, in column order, are the expected row indices
     [3, 1, 0, 5]. The result must be an int32 tensor.
     """
     match_results = tf.constant([3, 1, -1, 0, -1, 5, -2])
     match = matcher.Match(match_results)
     expected_row_indices = [3, 1, 0, 5]
     matched_row_indices = match.matched_row_indices()
     self.assertEqual(matched_row_indices.dtype, tf.int32)
     with self.test_session() as sess:
         matched_row_inds = sess.run(matched_row_indices)
         self.assertAllEqual(matched_row_inds, expected_row_indices)
Esempio n. 2
0
 def test_get_correct_ignored_column_indices(self):
     """Columns whose match value is -2 are reported as ignored.

     Only column 6 of match_results [3, 1, -1, 0, -1, 5, -2] holds -2, so
     the expected ignored column indices are [6], as an int32 tensor.
     """
     match_results = tf.constant([3, 1, -1, 0, -1, 5, -2])
     match = matcher.Match(match_results)
     expected_column_indices = [6]
     ignored_column_indices = match.ignored_column_indices()
     self.assertEqual(ignored_column_indices.dtype, tf.int32)
     with self.test_session() as sess:
         # Keep the fetched numpy value separate from the graph tensor.
         ignored_column_indices_out = sess.run(ignored_column_indices)
         self.assertAllEqual(ignored_column_indices_out,
                             expected_column_indices)
Esempio n. 3
0
 def testGetCorrectUnmatchedColumnIndices(self):
     """Columns whose match value is -1 are reported as unmatched.

     In match_results [3, 1, -1, 0, -1, 5, -2] the -1 entries sit at
     columns 2 and 4, so the expected unmatched column indices are [2, 4],
     as an int32 tensor.
     """
     match_results = tf.constant([3, 1, -1, 0, -1, 5, -2])
     match = matcher.Match(match_results)
     expected_column_indices = [2, 4]
     unmatched_column_indices = match.unmatched_column_indices()
     self.assertEqual(unmatched_column_indices.dtype, tf.int32)
     with self.test_session() as sess:
         # Keep the fetched numpy value separate from the graph tensor.
         unmatched_column_indices_out = sess.run(unmatched_column_indices)
         self.assertAllEqual(unmatched_column_indices_out,
                             expected_column_indices)
Esempio n. 4
0
 def test_get_correct_matched_column_indicator(self):
     """The matched-column indicator is a per-column boolean mask.

     For match_results [3, 1, -1, 0, -1, 5, -2] the mask is True exactly at
     the non-negative (matched) entries and False at the -1 and -2 entries.
     The result must be a tf.bool tensor.
     """
     match_results = tf.constant([3, 1, -1, 0, -1, 5, -2])
     match = matcher.Match(match_results)
     expected_column_indicator = [
         True, True, False, True, False, True, False
     ]
     matched_column_indicator = match.matched_column_indicator()
     self.assertEqual(matched_column_indicator.dtype, tf.bool)
     with self.test_session() as sess:
         # Keep the fetched numpy value separate from the graph tensor.
         matched_column_indicator_out = sess.run(matched_column_indicator)
         self.assertAllEqual(matched_column_indicator_out,
                             expected_column_indicator)
Esempio n. 5
0
 def test_scalar_gather_based_on_match(self):
     """gather_based_on_match on a 1-D float tensor.

     input_tensor holds 0..7, so each value equals its row index. Matched
     columns pick input_tensor[row]. Unmatched (-1) columns get
     unmatched_value (100.) and ignored (-2) columns get ignored_value
     (200.). The output dtype must be float32.
     """
     match_results = tf.constant([3, 1, -1, 0, -1, 5, -2])
     input_tensor = tf.constant([0, 1, 2, 3, 4, 5, 6, 7], dtype=tf.float32)
     expected_gathered_tensor = [3, 1, 100, 0, 100, 5, 200]
     match = matcher.Match(match_results)
     gathered_tensor = match.gather_based_on_match(input_tensor,
                                                   unmatched_value=100.,
                                                   ignored_value=200.)
     self.assertEqual(gathered_tensor.dtype, tf.float32)
     with self.test_session():
         gathered_tensor_out = gathered_tensor.eval()
     self.assertAllEqual(expected_gathered_tensor, gathered_tensor_out)
Esempio n. 6
0
 def test_multidimensional_gather_based_on_match(self):
     """gather_based_on_match gathers whole rows of a 2-D tensor.

     With match_results [1, -1, -2]:
       - column 0 receives row 1 of input_tensor;
       - the unmatched column receives unmatched_value (a zero vector);
       - the ignored column receives ignored_value (a zero vector).
     The output dtype must be float32.
     """
     match_results = tf.constant([1, -1, -2])
     input_tensor = tf.constant([[0, 0.5, 0, 0.5], [0, 0, 0.5, 0.5]],
                                dtype=tf.float32)
     expected_gathered_tensor = [[0, 0, 0.5, 0.5], [0, 0, 0, 0],
                                 [0, 0, 0, 0]]
     match = matcher.Match(match_results)
     gathered_tensor = match.gather_based_on_match(
         input_tensor,
         unmatched_value=tf.zeros(4),
         ignored_value=tf.zeros(4))
     self.assertEqual(gathered_tensor.dtype, tf.float32)
     with self.test_session():
         gathered_tensor_out = gathered_tensor.eval()
     self.assertAllEqual(expected_gathered_tensor, gathered_tensor_out)
Esempio n. 7
0
 def test_multidimensional_gather_based_on_match_with_matmul_gather_op(
         self):
     """Row gather with use_matmul_gather=True builds no gather ops.

     The expected output is the same row gather as the plain
     multidimensional case. In addition, the test asserts that the session
     graph contains no op of type 'Gather' or 'GatherV2'.
     """
     match_results = tf.constant([1, -1, -2])
     input_tensor = tf.constant([[0, 0.5, 0, 0.5], [0, 0, 0.5, 0.5]],
                                dtype=tf.float32)
     expected_gathered_tensor = [[0, 0, 0.5, 0.5], [0, 0, 0, 0],
                                 [0, 0, 0, 0]]
     match = matcher.Match(match_results, use_matmul_gather=True)
     gathered_tensor = match.gather_based_on_match(
         input_tensor,
         unmatched_value=tf.zeros(4),
         ignored_value=tf.zeros(4))
     self.assertEqual(gathered_tensor.dtype, tf.float32)
     with self.test_session() as sess:
         # Compare op *types*, not names. Names carry scope prefixes and
         # uniquifying suffixes. A string comparison must use ==/in rather
         # than `is`: identity against a literal never detects anything.
         # tf.gather may produce 'Gather' or 'GatherV2' depending on the TF
         # version, so both types are rejected.
         gather_op_names = [
             op.name for op in sess.graph.get_operations()
             if op.type in ('Gather', 'GatherV2')
         ]
         self.assertEqual(gather_op_names, [])
         gathered_tensor_out = gathered_tensor.eval()
     self.assertAllEqual(expected_gathered_tensor, gathered_tensor_out)
Esempio n. 8
0
 def test_all_columns_accounted_for(self):
     """Every column lands in exactly one of matched/unmatched/ignored.

     match_results is random (int32, minval=-2, maxval=5), so the mix of
     categories varies between runs. The column count is deliberately small,
     so not every category is guaranteed to appear. Sorting the concatenation
     of the three index lists must reproduce 0..num_matches-1, meaning no
     column is missing and none is reported twice.
     """
     num_matches = 10
     match = matcher.Match(
         tf.random_uniform([num_matches],
                           minval=-2,
                           maxval=5,
                           dtype=tf.int32))
     category_tensors = [
         match.matched_column_indices(),
         match.unmatched_column_indices(),
         match.ignored_column_indices(),
     ]
     with self.test_session() as sess:
         per_category = sess.run(category_tensors)
         every_column = np.sort(np.hstack(tuple(per_category)))
         self.assertAllEqual(every_column,
                             np.arange(num_matches, dtype=np.int32))
Esempio n. 9
0
 def test_get_correct_counts(self):
     """Column counts per category are correct and int32.

     match_results [3, 1, -1, 0, -1, 5, -2] has four non-negative (matched)
     entries, two -1 (unmatched) entries and one -2 (ignored) entry.
     """
     match_results = tf.constant([3, 1, -1, 0, -1, 5, -2])
     match = matcher.Match(match_results)
     exp_num_matched_columns = 4
     exp_num_unmatched_columns = 2
     exp_num_ignored_columns = 1
     num_matched_columns = match.num_matched_columns()
     num_unmatched_columns = match.num_unmatched_columns()
     num_ignored_columns = match.num_ignored_columns()
     self.assertEqual(num_matched_columns.dtype, tf.int32)
     self.assertEqual(num_unmatched_columns.dtype, tf.int32)
     self.assertEqual(num_ignored_columns.dtype, tf.int32)
     with self.test_session() as sess:
         (num_matched_columns_out, num_unmatched_columns_out,
          num_ignored_columns_out) = sess.run([
              num_matched_columns, num_unmatched_columns,
              num_ignored_columns
          ])
         self.assertAllEqual(num_matched_columns_out,
                             exp_num_matched_columns)
         self.assertAllEqual(num_unmatched_columns_out,
                             exp_num_unmatched_columns)
         self.assertAllEqual(num_ignored_columns_out,
                             exp_num_ignored_columns)