def test_map_values(self):
  """Checks that map_values adds a sibling field holding the doubled values."""

  def _double(value):
    return value * 2

  source_prensor = prensor_test_util.create_simple_prensor()
  root_expr = create_expression.create_expression_from_prensor(source_prensor)

  # Map "foo" through the doubling function into a new int64 field.
  doubled_root = map_values.map_values(
      root_expr, path.Path(["foo"]), _double, tf.int64, "foo_doubled")

  doubled_expr = doubled_root.get_descendant_or_error(
      path.Path(["foo_doubled"]))
  doubled_leaf = expression_test_util.calculate_value_slowly(doubled_expr)

  # The parent index is expected to be [0, 1, 2]; the values are doubled.
  self.assertAllEqual(doubled_leaf.parent_index, [0, 1, 2])
  self.assertAllEqual(doubled_leaf.values, [18, 16, 14])
def has(root, source_path, new_field_name):
  """Adds a boolean sibling field telling whether a field has any values.

  The size of `source_path` is obtained via `size_anonymous`, and the new
  field is true exactly where that size is greater than zero.

  Args:
    root: the original expression.
    source_path: the source path to measure. Cannot be root.
    new_field_name: the name of the sibling field.

  Returns:
    The new expression.
  """

  def _is_positive(sizes):
    # Compare against an int64 zero, as in the original comparison.
    zero = tf.constant(0, dtype=tf.int64)
    return tf.greater(sizes, zero)

  root_with_size, size_path = size_anonymous(root, source_path)
  # TODO(martinz): consider using copy_over to "remove" the size field
  # from the result.
  return map_values.map_values(
      root_with_size, size_path, _is_positive, tf.bool, new_field_name)