예제 #1
0
 def test_equi_join_indices_for_broadcast(self):
     """Check the join pattern that broadcast relies on.

     The left input repeats the value 1, and the right input holds each
     value once, so the expected pairs line up every left position with
     the single right position carrying the same value.
     """
     left = tf.constant([0, 1, 1], dtype=tf.int64)
     right = tf.constant([0, 1, 2], dtype=tf.int64)
     left_positions, right_positions = struct2tensor_ops.equi_join_indices(
         left, right)
     self.assertAllEqual(left_positions, [0, 1, 2])
     self.assertAllEqual(right_positions, [0, 1, 1])
예제 #2
0
 def benchmark_fn(session):
     """Build a callable that runs equi_join_indices on fed inputs.

     Args:
       session: object exposing `make_callable` (presumably a
         tf.compat.v1.Session -- confirm against the caller).

     Returns:
       Whatever `session.make_callable` returns for a constant fetch that
       is gated on the join outputs, fed by two int64 placeholders of
       unknown length.
     """
     lhs = tf.compat.v1.placeholder(shape=(None, ), dtype=tf.int64)
     rhs = tf.compat.v1.placeholder(shape=(None, ), dtype=tf.int64)
     join_outputs = struct2tensor_ops.equi_join_indices(lhs, rhs)
     # Fetch a trivial constant that depends on the join, so the join is
     # executed without its outputs being returned to the caller.
     with tf.control_dependencies(join_outputs):
         fetch = tf.constant(1)
     return session.make_callable(fetch, feed_list=[lhs, rhs])
    def test_equi_join_indices(self, a, b, expected_index_a, expected_index_b):
        """Verify both join ops produce the expected index pairs.

        Args:
          a: values for the first join input; converted to an int64 tensor.
          b: values for the second join input; converted to an int64 tensor.
          expected_index_a: expected first output of each op.
          expected_index_b: expected second output of each op.
        """
        lhs = tf.constant(a, dtype=tf.int64)
        rhs = tf.constant(b, dtype=tf.int64)

        # equi_join_indices must match the expectation.
        got_a, got_b = struct2tensor_ops.equi_join_indices(lhs, rhs)
        self.assertAllEqual(got_a, expected_index_a)
        self.assertAllEqual(got_b, expected_index_b)

        # equi_join_any_indices is held to the same expectation.
        got_any_a, got_any_b = struct2tensor_ops.equi_join_any_indices(
            lhs, rhs)
        self.assertAllEqual(got_any_a, expected_index_a)
        self.assertAllEqual(got_any_b, expected_index_b)
예제 #4
0
def _filter_by_parent_indices_to_keep(
    node_value,
    parent_indices_to_keep):
  """Rebuild a node from the entries whose parent index is kept.

  Args:
    node_value: a prensor.ChildNodeTensor or prensor.LeafNodeTensor.
    parent_indices_to_keep: tensor joined against node_value.parent_index
      (presumably the surviving parent positions -- confirm with caller).

  Returns:
    For a ChildNodeTensor, the result of _FilterChildNodeTensor; for a
    LeafNodeTensor, a new LeafNodeTensor whose values are gathered at the
    joined positions.

  Raises:
    ValueError: if node_value is neither of the two supported types.
  """
  # The join runs before the type dispatch, as both branches need its outputs.
  join_result = struct2tensor_ops.equi_join_indices(
      parent_indices_to_keep, node_value.parent_index)
  new_parent_index, self_indices_to_keep = join_result

  if isinstance(node_value, prensor.ChildNodeTensor):
    return _FilterChildNodeTensor(
        new_parent_index, node_value.is_repeated, self_indices_to_keep)

  if isinstance(node_value, prensor.LeafNodeTensor):
    kept_values = tf.gather(node_value.values, self_indices_to_keep)
    return prensor.LeafNodeTensor(
        new_parent_index, kept_values, node_value.is_repeated)

  raise ValueError("Unknown NodeValue type")
예제 #5
0
def _get_any_parsed_field(value_field, type_url_field, field_name):
    """Helper function for _get_any_parsed_fields.

    Joins the indices selected by _any_indices_with_type for this field's
    full name against value_field.index, then gathers the matching indices
    and values into a _ParsedField.

    Args:
      value_field: object with `index` and `value` tensors.
      type_url_field: passed through to _any_indices_with_type.
      field_name: the step name; also stored on the returned field.

    Returns:
      A struct2tensor_ops._ParsedField with field_descriptor set to None.
    """
    type_name = get_full_name_from_any_step(field_name)
    matching_indices = _any_indices_with_type(type_url_field, type_name)
    match_positions, value_positions = struct2tensor_ops.equi_join_indices(
        matching_indices, value_field.index)
    # TODO(martinz): make _ParsedField public.
    return struct2tensor_ops._ParsedField(  # pylint: disable=protected-access
        field_name=field_name,
        field_descriptor=None,
        index=tf.gather(matching_indices, match_positions),
        value=tf.gather(value_field.value, value_positions))
예제 #6
0
 def calculate(self, sources, destinations, options):
     """Broadcast the origin leaf's values onto the sibling's entries.

     Args:
       sources: a pair (origin_value, sibling_value); the origin must be a
         prensor.LeafNodeTensor and the sibling a prensor.ChildNodeTensor.
       destinations: not used by this method.
       options: not used by this method.

     Returns:
       A prensor.LeafNodeTensor built from the joined parent index, the
       gathered origin values, and self.is_repeated.

     Raises:
       ValueError: if either source has the wrong node type.
     """
     origin_value, sibling_value = sources
     if not isinstance(origin_value, prensor.LeafNodeTensor):
         raise ValueError("origin not a LeafNodeTensor")
     if not isinstance(sibling_value, prensor.ChildNodeTensor):
         raise ValueError("sibling value is not a ChildNodeTensor")

     # For each sibling position i, every origin position j satisfying
     #   sibling_value.parent_index[i] == origin_value.parent_index[j]
     # contributes one output entry k with
     #   new_parent_index[k] == i
     #   new_values[k] == origin_value.values[j]
     # and the ordering is preserved.
     joined = struct2tensor_ops.equi_join_indices(
         sibling_value.parent_index, origin_value.parent_index)
     new_parent_index, origin_positions = joined
     broadcast_values = tf.gather(origin_value.values, origin_positions)
     return prensor.LeafNodeTensor(
         new_parent_index, broadcast_values, self.is_repeated)
예제 #7
0
 def test_equi_join_indices_no_overlap(self):
     """Inputs that share no value must yield two empty index outputs."""
     left = tf.constant([0, 1, 1, 2], dtype=tf.int64)
     right = tf.constant([3, 4, 5], dtype=tf.int64)
     left_positions, right_positions = struct2tensor_ops.equi_join_indices(
         left, right)
     self.assertAllEqual(left_positions, [])
     self.assertAllEqual(right_positions, [])
예제 #8
0
 def test_equi_join_indices_both_empty(self):
     """Two empty inputs must yield two empty index outputs."""
     left = tf.constant([], dtype=tf.int64)
     right = tf.constant([], dtype=tf.int64)
     left_positions, right_positions = struct2tensor_ops.equi_join_indices(
         left, right)
     self.assertAllEqual(left_positions, [])
     self.assertAllEqual(right_positions, [])
예제 #9
0
 def test_equi_join_indices(self):
     """Check the join on inputs with repeated and unmatched values.

     The value 0 appears twice on each side, so the expectation contains
     all four pairings of those positions. The values 1 and 4 appear only
     on the left, so they contribute no pairs.
     """
     left = tf.constant([0, 0, 1, 1, 2, 3, 4], dtype=tf.int64)
     right = tf.constant([0, 0, 2, 2, 3], dtype=tf.int64)
     left_positions, right_positions = struct2tensor_ops.equi_join_indices(
         left, right)
     self.assertAllEqual(left_positions, [0, 0, 1, 1, 4, 4, 5])
     self.assertAllEqual(right_positions, [0, 1, 0, 1, 2, 3, 4])