def test_equi_join_indices_for_broadcast(self):
  """Join shape used by broadcast: repeated keys in `a` reuse one `b` row.

  `a` holds 1 twice and `b` holds 1 once, so both positions of `a` map to
  b[1]; b[2] (key 2) has no partner and is dropped.
  """
  keys_a = tf.constant([0, 1, 1], dtype=tf.int64)
  keys_b = tf.constant([0, 1, 2], dtype=tf.int64)
  index_a, index_b = struct2tensor_ops.equi_join_indices(keys_a, keys_b)
  self.assertAllEqual(index_a, [0, 1, 2])
  self.assertAllEqual(index_b, [0, 1, 1])
def benchmark_fn(session):
  """Returns a session callable that runs equi_join_indices on fed inputs.

  The callable takes two 1-D int64 feeds (in the order `a`, `b`). It fetches
  a constant that is control-dependent on the join outputs, so each call
  executes the join without transferring its results back.
  """
  make_placeholder = tf.compat.v1.placeholder
  a = make_placeholder(dtype=tf.int64, shape=(None,))
  b = make_placeholder(dtype=tf.int64, shape=(None,))
  join_outputs = struct2tensor_ops.equi_join_indices(a, b)
  # Fetching a cheap constant gated on the join forces the join to run
  # while keeping fetch overhead out of the measurement.
  with tf.control_dependencies(join_outputs):
    sentinel = tf.constant(1)
  return session.make_callable(sentinel, feed_list=[a, b])
def test_equi_join_indices(self, a, b, expected_index_a, expected_index_b):
  """Checks equi_join_indices and equi_join_any_indices on one case.

  Both ops must produce exactly `expected_index_a` / `expected_index_b` for
  the parameterized key vectors `a` and `b`.
  """
  a_tensor = tf.constant(a, dtype=tf.int64)
  b_tensor = tf.constant(b, dtype=tf.int64)
  # Same assertions for each op, exercised in the order listed.
  for join_op in (struct2tensor_ops.equi_join_indices,
                  struct2tensor_ops.equi_join_any_indices):
    index_a, index_b = join_op(a_tensor, b_tensor)
    self.assertAllEqual(index_a, expected_index_a)
    self.assertAllEqual(index_b, expected_index_b)
def _filter_by_parent_indices_to_keep(
    node_value, parent_indices_to_keep):
  """Restricts `node_value` to the entries whose parent is being kept.

  Joins `parent_indices_to_keep` against `node_value.parent_index`. For each
  match, the first join output becomes the new parent index and the second
  selects which of this node's own entries survive.

  Args:
    node_value: a prensor.ChildNodeTensor or prensor.LeafNodeTensor.
    parent_indices_to_keep: parent positions to retain (presumably a 1-D
      int64 tensor, as equi_join_indices is used with elsewhere — confirm).

  Returns:
    A _FilterChildNodeTensor for child nodes, or a new prensor.LeafNodeTensor
    with the gathered values for leaf nodes.

  Raises:
    ValueError: if `node_value` is neither a child nor a leaf node tensor.
  """
  new_parent_index, self_indices_to_keep = (
      struct2tensor_ops.equi_join_indices(parent_indices_to_keep,
                                          node_value.parent_index))

  if isinstance(node_value, prensor.ChildNodeTensor):
    return _FilterChildNodeTensor(new_parent_index, node_value.is_repeated,
                                  self_indices_to_keep)

  if isinstance(node_value, prensor.LeafNodeTensor):
    kept_values = tf.gather(node_value.values, self_indices_to_keep)
    return prensor.LeafNodeTensor(new_parent_index, kept_values,
                                  node_value.is_repeated)

  raise ValueError("Unknown NodeValue type")
def _get_any_parsed_field(value_field, type_url_field, field_name):
  """Builds one _ParsedField for `field_name` from `Any`-typed values.

  Helper for _get_any_parsed_fields. Positions returned by
  _any_indices_with_type (presumably the `Any` messages whose type URL names
  `field_name`'s full name — confirm) are joined against `value_field.index`.
  Each match contributes one (index, value) pair to the result.

  Args:
    value_field: object exposing `.index` and `.value` tensors.
    type_url_field: passed through to _any_indices_with_type.
    field_name: the `Any` step name; resolved via
      get_full_name_from_any_step.

  Returns:
    struct2tensor_ops._ParsedField with field_descriptor=None.
  """
  typed_indices = _any_indices_with_type(
      type_url_field, get_full_name_from_any_step(field_name))
  to_typed_position, to_value_position = (
      struct2tensor_ops.equi_join_indices(typed_indices, value_field.index))
  # Map each joined row back to its parent index, then to its payload.
  joined_index = tf.gather(typed_indices, to_typed_position)
  joined_value = tf.gather(value_field.value, to_value_position)
  # TODO(martinz): make _ParsedField public.
  return struct2tensor_ops._ParsedField(  # pylint: disable=protected-access
      field_name=field_name,
      field_descriptor=None,
      index=joined_index,
      value=joined_value)
def calculate(self, sources, destinations, options):
  """Broadcasts a leaf's values onto each sibling that shares its parent.

  Args:
    sources: a pair (origin_value, sibling_value); origin must be a
      prensor.LeafNodeTensor and sibling a prensor.ChildNodeTensor.
    destinations: unused here.
    options: unused here.

  Returns:
    A prensor.LeafNodeTensor whose parent index points at sibling positions
    and whose values are the matching origin values.

  Raises:
    ValueError: if either source has the wrong node type.
  """
  origin_value, sibling_value = sources
  if not isinstance(origin_value, prensor.LeafNodeTensor):
    raise ValueError("origin not a LeafNodeTensor")
  if not isinstance(sibling_value, prensor.ChildNodeTensor):
    raise ValueError("sibling value is not a ChildNodeTensor")

  # For every pair (i, j) with
  #   sibling_value.parent_index[i] == origin_value.parent_index[j]
  # emit exactly one output position k with
  #   new_parent_index[k] = i  and  new_values[k] = origin_value.values[j],
  # in the order produced by equi_join_indices.
  to_sibling, to_origin = struct2tensor_ops.equi_join_indices(
      sibling_value.parent_index, origin_value.parent_index)
  broadcast_values = tf.gather(origin_value.values, to_origin)
  return prensor.LeafNodeTensor(to_sibling, broadcast_values,
                                self.is_repeated)
def test_equi_join_indices_no_overlap(self):
  """Disjoint key sets produce empty index vectors on both sides."""
  left_keys = tf.constant([0, 1, 1, 2], dtype=tf.int64)
  right_keys = tf.constant([3, 4, 5], dtype=tf.int64)
  index_a, index_b = struct2tensor_ops.equi_join_indices(left_keys,
                                                         right_keys)
  self.assertAllEqual(index_a, [])
  self.assertAllEqual(index_b, [])
def test_equi_join_indices_both_empty(self):
  """Joining two empty key vectors yields empty index vectors."""
  empty_a = tf.constant([], dtype=tf.int64)
  empty_b = tf.constant([], dtype=tf.int64)
  index_a, index_b = struct2tensor_ops.equi_join_indices(empty_a, empty_b)
  self.assertAllEqual(index_a, [])
  self.assertAllEqual(index_b, [])
def test_equi_join_indices(self):
  """Duplicate keys on both sides join as a per-key cross product."""
  left = tf.constant([0, 0, 1, 1, 2, 3, 4], dtype=tf.int64)
  right = tf.constant([0, 0, 2, 2, 3], dtype=tf.int64)
  index_a, index_b = struct2tensor_ops.equi_join_indices(left, right)
  # Matching (i, j) pairs with left[i] == right[j], in returned order:
  #   key 0 -> (0,0) (0,1) (1,0) (1,1); key 2 -> (4,2) (4,3); key 3 -> (5,4).
  # Keys 1 and 4 appear only on the left and produce nothing.
  self.assertAllEqual(index_a, [0, 0, 1, 1, 4, 4, 5])
  self.assertAllEqual(index_b, [0, 1, 0, 1, 2, 3, 4])