def test_apply_for_tensor(self, index, expected):
  """Applying the operation to a 2x2 float tensor and `index` gives `expected`."""
  source_tensor = tf.constant([[1.2, 3.4], [5.6, 7.8]])
  operands = [value.ConstantValue(source_tensor), value.ConstantValue(index)]
  actual = self.operation.apply(operands, settings_module.default_settings())
  expected_value = value.ConstantValue(tf.constant(expected))
  self.assertEqual(actual, expected_value)
def setUp(self):
  """Builds the operation under test and its default argument values."""
  super(SlicingAxis1BothOperationTest, self).setUp()
  self.operation = python_operations.SlicingAxis1BothOperation()
  matrix_input = value.InputValue([[12, 34, 56, 78], [-1, -2, -3, -4]],
                                  'my_input')
  second_arg = value.ConstantValue(1)
  third_arg = value.ConstantValue(-1)
  self.arg_values = [matrix_input, second_arg, third_arg]
def setUp(self):
  """Builds the operation under test and three constant arguments."""
  super(TripleCreationOperationTest, self).setUp()
  self.operation = python_operations.TripleCreationOperation()
  self.arg_values = [value.ConstantValue(number) for number in (12, 34, 56)]
def test_apply_for_list(self, index, expected):
  """Applying the operation to a float list and `index` gives `expected`."""
  list_arg = value.ConstantValue([1.2, 3.4, 5.6])
  index_arg = value.ConstantValue(index)
  actual = self.operation.apply([list_arg, index_arg],
                                settings_module.default_settings())
  self.assertEqual(actual, value.ConstantValue(expected))
def test_apply_returns_none_if_bad_value(self):
  """apply() returns None when tf.one_hot would produce a rejected value."""
  one_hot_info = tf_functions.FunctionInfo(
      name='tf.one_hot(indices, depth)',
      filter_group=filter_group.FilterGroup.NONE,
      weight=2)
  operation = function_operation.FunctionOperation(one_hot_info)

  # tf.one_hot yields indices.shape + [depth] elements, i.e.
  # MAX_DIMENSION_LENGTH * MAX_TENSOR_ELEMENTS here.
  indices = value.ConstantValue(tf.ones([limits.MAX_DIMENSION_LENGTH]))
  depth = value.ConstantValue(limits.MAX_TENSOR_ELEMENTS)
  outcome = operation.apply([indices, depth], self.settings)
  self.assertIsNone(outcome)
def test_apply_returns_none_if_exception(self):
  """apply() returns None when the wrapped TF call raises (bad axis)."""
  reduce_sum_info = tf_functions.FunctionInfo(
      name='tf.reduce_sum(input_tensor, axis)',
      filter_group=filter_group.FilterGroup.NONE,
      weight=2)
  operation = function_operation.FunctionOperation(reduce_sum_info)

  matrix = value.ConstantValue(tf.constant([[1, 3, 4], [50, 20, 80]]))
  # The tensor is rank 2, so axis 2 is out of range.
  out_of_range_axis = value.ConstantValue(2)
  self.assertIsNone(
      operation.apply([matrix, out_of_range_axis], self.settings))
def test_reconstruct_all_expressions_with_input_names(self):
  """Merged reconstructions are listed in merge order with their input names."""
  in_0 = value.InputValue(0, 'in_0')
  one = value.ConstantValue(1)
  two = value.ConstantValue(2)

  # First way of computing [3, 4]: reduce_max along axis 1 of my_input.
  my_input = value.InputValue([[1, 3, 2], [-3, 0, 4]], 'my_input')
  final_value = self.operation.apply([my_input, one], self.settings)

  add_info = tf_functions.FunctionInfo(
      name='tf.add(x, y)',
      filter_group=filter_group.FilterGroup.NONE,
      weight=1)
  add_operation = function_operation.FunctionOperation(add_info)

  # Two distinct additions that both evaluate to [[2, 4], [3, 2]].
  in_1321 = value.InputValue([[1, 3], [2, 1]], 'in_1321')
  first_sum = add_operation.apply([in_1321, one], self.settings)
  in_0210 = value.InputValue([[0, 2], [1, 0]], 'in_0210')
  second_sum = add_operation.apply([in_0210, two], self.settings)
  self.assertEqual(first_sum, second_sum)
  first_sum.merge_reconstructions(second_sum)

  # reduce_max of that sum along axis in_0 (= 0) also gives [3, 4].
  alternative = self.operation.apply([first_sum, in_0], self.settings)
  self.assertEqual(final_value, alternative)
  final_value.merge_reconstructions(alternative)

  # And so does reduce_max of in_1430 along axis in_0.
  in_1430 = value.InputValue([[1, 4], [3, 0]], 'in_1430')
  alternative = self.operation.apply([in_1430, in_0], self.settings)
  self.assertEqual(final_value, alternative)
  final_value.merge_reconstructions(alternative)

  expected = [
      ('tf.reduce_max(my_input, axis=1)', {'my_input'}),
      ('tf.reduce_max(tf.add(in_1321, 1), axis=in_0)', {'in_1321', 'in_0'}),
      ('tf.reduce_max(tf.add(in_0210, 2), axis=in_0)', {'in_0210', 'in_0'}),
      ('tf.reduce_max(in_1430, axis=in_0)', {'in_1430', 'in_0'}),
  ]
  actual = final_value.reconstruct_all_expressions_with_input_names()
  self.assertEqual(actual, expected)
def test_merge_reconstructions(self):
  """Merging an equal OperationValue appends its operation application."""
  first = self.operation.apply(
      [value.InputValue([[1, 3, 2], [-3, 0, 4]], 'my_input'),
       value.ConstantValue(1)],
      self.settings)
  self.assertLen(first.operation_applications, 1)

  second = self.operation.apply(
      [value.InputValue([[1, 4], [3, 0]], 'my_input_2'),
       value.ConstantValue(0)],
      self.settings)
  self.assertEqual(first, second)

  first.merge_reconstructions(second)
  self.assertLen(first.operation_applications, 2)
def test_apply_succeeds(self):
  """tf.reduce_sum over a 2x3 matrix along each valid axis."""
  reduce_sum_info = tf_functions.FunctionInfo(
      name='tf.reduce_sum(input_tensor, axis)',
      filter_group=filter_group.FilterGroup.NONE,
      weight=2)
  operation = function_operation.FunctionOperation(reduce_sum_info)
  matrix = value.ConstantValue(tf.constant([[1, 3, 4], [50, 20, 80]]))

  for axis, expected_sums in ((0, [51, 23, 84]), (1, [8, 150])):
    actual = operation.apply([matrix, value.ConstantValue(axis)],
                             self.settings)
    self.assertEqual(actual, value.ConstantValue(tf.constant(expected_sums)))
def setUp(self):
  """Builds the operation under test and its default argument values."""
  super(IndexingAxis1OperationTest, self).setUp()
  self.operation = python_operations.IndexingAxis1Operation()
  matrix_input = value.InputValue([[12, 34], [56, 78]], 'my_input')
  self.arg_values = [matrix_input, value.ConstantValue(1)]
def test_extract_examples_from_value_without_inputs(self):
  """Only sub-expressions that end up with an input are extracted."""
  one = value_module.ConstantValue(1)
  two = value_module.ConstantValue(2)
  three = value_module.ConstantValue(3)

  inner_sum = self.add_operation.apply([one, two], self.settings)
  outer_sum = self.add_operation.apply([inner_sum, three], self.settings)

  examples = collect_tensor_data.extract_examples_from_value(outer_sum)
  expressions = [example.expression for example in examples]
  # `tf.add(tf.add(1, 2), 3)` has no inputs and is not included.
  self.assertCountEqual(expressions, ['tf.add(in1, 3)'])
def setUp(self):
  """Builds the operation under test and its default argument values."""
  super(SlicingAxis1RightOperationTest, self).setUp()
  self.operation = python_operations.SlicingAxis1RightOperation()
  row_input = value.InputValue([[12, 34, 56]], 'my_input')
  self.arg_values = [row_input, value.ConstantValue(-1)]
def test_init_with_primitive(self):
  """A ConstantValue wrapping an int is primitive and has no tensor metadata."""
  wrapped = value.ConstantValue(12)
  self.assertEqual(wrapped.type, int)
  self.assertTrue(wrapped.is_primitive)
  self.assertFalse(wrapped.elem_type_is_tensor)
  for attribute_name in ('elem_type', 'dtype', 'shape'):
    self.assertIsNone(getattr(wrapped, attribute_name))
def create_examples(io_example: IOExample,
                    max_num_inputs: int = 3,
                    permute_inputs: bool = True) -> List[Dict[Text, Any]]:
  """Creates example dicts for the I/O example.

  Each example dict maps feature names to lists of values. It contains the
  output features, then the features of each input slot in a chosen order,
  then the operation counts and the expression string. Input slots beyond the
  example's inputs are filled with features of a dummy constant 0.

  Args:
    io_example: The IOExample to featurize.
    max_num_inputs: Number of input slots per example. Only the first
      `max_num_inputs` inputs of `io_example` are featurized. Any extra inputs
      are dropped, and missing ones are padded with dummy features.
    permute_inputs: If True, emit one example per permutation of the
      featurized inputs. Otherwise, emit a single example with the inputs in
      their original order.

  Returns:
    A list of feature dicts. The list is empty if featurization raised a
    ValueError.
  """
  examples = []

  operation_list = all_operations.get_operations(
      include_sparse_operations=True)
  operation_counter = collections.Counter(
      [op.name for op in io_example.operations])
  operation_counts = [operation_counter[op.name] for op in operation_list]

  num_inputs = len(io_example.input_values)
  # Only this many inputs get features below, so only this many may be
  # permuted or indexed. Using `num_inputs` here would index past
  # `input_features` whenever there are more than `max_num_inputs` inputs.
  num_inputs_feature = min(num_inputs, max_num_inputs)

  try:
    output_features = featurize_value(io_example.output_value)

    input_features = []
    for input_value in io_example.input_values[:max_num_inputs]:
      combined_features = featurize_value(input_value)
      combined_features.update(
          featurize_input_and_output(input_value, io_example.output_value))
      input_features.append(combined_features)

    # Features used for empty input slots: a constant 0 paired with itself.
    dummy_value = value_module.ConstantValue(0)
    dummy_input_features = featurize_value(dummy_value)
    dummy_input_features.update(
        featurize_input_and_output(dummy_value, dummy_value))
  except ValueError as e:
    logging.warning('%s: could not featurize IOExample %s', e, io_example)
    return []

  input_indices = range(num_inputs_feature)
  permutations = (itertools.permutations(input_indices) if permute_inputs
                  else [list(input_indices)])
  num_padding_slots = max_num_inputs - num_inputs_feature

  for permutation in permutations:
    feature_dict = collections.defaultdict(list)
    feature_dict['num_inputs'] = [num_inputs_feature]
    # Deep-copy so that extending shared keys below never mutates
    # `output_features`, which is reused for every permutation.
    feature_dict.update(copy.deepcopy(output_features))

    padded_input_features = [input_features[index] for index in permutation]
    padded_input_features.extend([dummy_input_features] * num_padding_slots)
    for input_features_dict in padded_input_features:
      for key, value in six.iteritems(input_features_dict):
        feature_dict[key].extend(value)

    # Give each example its own list so examples do not alias each other.
    feature_dict['operations'] = list(operation_counts)
    feature_dict['expression'] = [io_example.expression]
    examples.append(feature_dict)

  return examples
def test_reconstruct_expression(self):
  """Arguments are rendered into the function's template, by parameter name."""
  reduce_sum_info = tf_functions.FunctionInfo(
      name='tf.reduce_sum(input_tensor, axis)',
      filter_group=filter_group.FilterGroup.NONE,
      weight=2)
  operation = function_operation.FunctionOperation(reduce_sum_info)

  tensor_arg = value.InputValue([[1, 3], [50, 20]], 'my_input')
  string_arg = value.ConstantValue('tf-coder')
  expected = "tf.reduce_sum(my_input, axis='tf-coder')"
  self.assertEqual(operation.reconstruct_expression([tensor_arg, string_arg]),
                   expected)
def test_reconstruct_all_expressions_with_input_names_using_addition(self):
  """Counts reconstructions of 9 built only from constants 0-9 and tf.add."""
  constants = [value.ConstantValue(i) for i in range(10)]
  add_operation = function_operation.FunctionOperation(
      tf_functions.FunctionInfo(
          name='tf.add(x, y)',
          filter_group=filter_group.FilterGroup.NONE,
          weight=1))

  # values_by_weight[w] maps every unique Value of weight w to itself, so the
  # stored Value equal to some query Value can be retrieved. Weight 0 is
  # empty, and weight 1 holds the constants.
  values_by_weight = [
      collections.OrderedDict(),
      collections.OrderedDict((constant, constant) for constant in constants),
  ]

  for weight in range(2, 6):
    combined = collections.OrderedDict()
    for left_weight in range(1, weight):
      # The tf.add application itself contributes weight 1.
      right_weight = weight - left_weight - 1
      operand_pairs = itertools.product(values_by_weight[left_weight],
                                        values_by_weight[right_weight])
      for left, right in operand_pairs:
        total = add_operation.apply([left, right], self.settings)
        if total in combined:
          combined[total].merge_reconstructions(total)
        else:
          combined[total] = total
    values_by_weight.append(combined)

  query = value.OutputValue(9)

  # The form must be (a + b), where there are 10 choices for a, which then
  # determines b.
  weight_3_reconstructions = (
      values_by_weight[3][query].reconstruct_all_expressions_with_input_names())
  self.assertLen(weight_3_reconstructions, 10)
  # No expression uses input values.
  self.assertFalse(
      any(used_names for _, used_names in weight_3_reconstructions))

  # No AST with only binary operators has weight 4.
  self.assertEmpty(values_by_weight[4])

  # The form is either (a + (b + c)) or ((a + b) + c). Each of the two forms
  # has 1 + 2 + ... + 9 = 45 options. Note that "a" in (a + (b + c)) cannot be
  # 0, or else (b + c) would have the same value as the entire expression.
  # Similarly, "c" in ((a + b) + c) cannot be 0.
  weight_5_reconstructions = (
      values_by_weight[5][query].reconstruct_all_expressions_with_input_names())
  self.assertLen(weight_5_reconstructions, 90)
def test_run_value_handles_large_weight_constants(self):
  """The output-shape tuple constant stays a ConstantValue in the value set."""
  example = benchmark_module.Example(inputs=[[1], [2]], output=[[3]])
  benchmark = benchmark_module.Benchmark(examples=[example])
  results = value_search.run_value_search(benchmark=benchmark,
                                          settings=self.settings)

  self.assertNotEmpty(results.solutions)
  self.assertEqual(results.solutions[0].expression,
                   'tf.add(in1, tf.expand_dims(in2, 0))')

  output_shape_constant = value_module.ConstantValue((1, 1))
  self.assertIn(output_shape_constant, results.value_set)
  # The stored element equal to output_shape_constant must be a ConstantValue
  # itself, not an OperationValue that computes the same tuple.
  matching = [candidate for candidate in results.value_set
              if candidate == output_shape_constant]
  for candidate in matching:
    self.assertIsInstance(candidate, value_module.ConstantValue)
def setUp(self):
  """Looks up the operations under test and builds two example values."""
  super(CollectTensorDataTest, self).setUp()
  self.settings = settings_module.default_settings()

  operations = all_operations.get_operations()

  def lookup(name):
    return all_operations.find_operation_with_name(
        name, operation_list=operations)

  self.unique_with_counts_operation = lookup('tf.unique_with_counts(x)')
  self.indexing_operation = lookup('IndexingOperation')
  self.gather_operation = lookup('tf.gather(params, indices)')
  self.add_operation = lookup('tf.add(x, y)')

  def add(lhs, rhs):
    return self.add_operation.apply([lhs, rhs], self.settings)

  # Example with many operations.
  in1 = value_module.InputValue([1, 1, 2, 5, 6, 5], 'in1')
  in2 = value_module.InputValue([0, 10, 20, 30, 40, 50, 60, 70], 'in2')
  index_one = value_module.ConstantValue(1)
  unique_result = self.unique_with_counts_operation.apply([in1],
                                                          self.settings)
  counts = self.indexing_operation.apply([unique_result, index_one],
                                         self.settings)
  gathered = self.gather_operation.apply([in2, in1], self.settings)
  self.example_value_1 = add(counts, gathered)
  self.assertEqual(
      self.example_value_1.reconstruct_expression(),
      'tf.add(tf.unique_with_counts(in1)[1], tf.gather(in2, in1))')
  self.assertEqual(self.example_value_1,
                   value_module.OutputValue([10, 10, 21, 52, 63, 52]))

  # Example with many variables and new inputs. The order of the new_input()
  # calls is kept the same as in the expression below.
  in3 = value_module.InputValue([1], 'in3')
  in4 = value_module.InputValue([2], 'in4')
  left_leaf = add(in3, new_input([10]))
  middle_leaf = add(in4, in3)
  right_leaf = add(new_input([20]), in3)
  inner = add(left_leaf, middle_leaf)
  self.example_value_2 = add(right_leaf, inner)
  self.assertEqual(
      self.example_value_2.reconstruct_expression(),
      'tf.add(tf.add(NEW_INPUT, in3), '
      'tf.add(tf.add(in3, NEW_INPUT), tf.add(in4, in3)))')
  self.assertEqual(self.example_value_2, value_module.OutputValue([35]))
def test_copy(self):
  """copy() deep-copies arguments: renaming in the copy leaves the original."""
  original_expression = 'tf.reduce_max(my_input, axis=1)'
  tensor_value = value.InputValue([[1, 3, 2], [-3, 0, 4]], 'my_input')
  axis_value = value.ConstantValue(1)
  original = self.operation.apply([tensor_value, axis_value], self.settings)
  duplicate = original.copy()

  self.assertIsNot(original, duplicate)
  self.assertTrue(original.reconstruct_expression(use_cache=False)  # pylint: disable=g-generic-assert
                  == duplicate.reconstruct_expression(use_cache=False)
                  == original_expression)

  duplicate.operation_applications[0].arg_values[0].name = 'new_name'
  self.assertEqual(original.reconstruct_expression(use_cache=False),
                   original_expression)
  self.assertEqual(duplicate.reconstruct_expression(use_cache=False),
                   'tf.reduce_max(new_name, axis=1)')
def test_reconstruct_expression(self):
  """Hand-built and apply()-built OperationValues reconstruct identically."""
  expected_expression = 'tf.reduce_max(my_input, axis=1)'
  tensor_value = value.InputValue([[1, 3, 2], [-3, 0, 4]], 'my_input')
  axis_value = value.ConstantValue(1)

  manual = value.OperationValue(tf.constant([3, 4]), self.operation,
                                [tensor_value, axis_value])
  # The second call is served from the cached expression.
  for _ in range(2):
    self.assertEqual(manual.reconstruct_expression(), expected_expression)

  applied = self.operation.apply([tensor_value, axis_value], self.settings)
  self.assertEqual(manual, applied)
  self.assertEqual(applied.reconstruct_expression(), expected_expression)
def test_get_type_filter(self):
  """The int type filter accepts an int constant but rejects a float one."""
  accepts_int = operation_filtering.get_type_filter(int)
  int_constant, float_constant = value.ConstantValue(1), value.ConstantValue(1.0)
  self.assertTrue(accepts_int(int_constant))
  self.assertFalse(accepts_int(float_constant))
def test_apply(self):
  """Applying the operation to self.arg_values yields the tuple (12,)."""
  default_settings = settings_module.default_settings()
  produced = self.operation.apply(self.arg_values, default_settings)
  self.assertEqual(produced, value.ConstantValue((12,)))
def _value(wrapped_value):
  """Returns `wrapped_value` wrapped in a `value.ConstantValue`."""
  constant = value.ConstantValue(wrapped_value)
  return constant
def _find_solutions(
    benchmark: benchmark_module.Benchmark,
    operations: List[operation_base.Operation],
    start_time: float,
    settings: settings_module.Settings
) -> Tuple[List[Solution], Set[value_module.Value], ValuesByWeight,
           Optional[operation_statistics.OperationStatistics]]:
  """Helper, returning (solutions, value_set, values_by_weight, statistics).

  Performs a bottom-up value search. For each weight from 1 to
  `settings.max_weight`, every operation enumerates the values it can produce
  with exactly that weight. New values are stored so heavier values can be
  built from them. Values equal to the benchmark's output are recorded as
  solutions and never stored.

  Args:
    benchmark: The benchmark to solve. The target output is
      `benchmark.examples[0].output`.
    operations: The operations to enumerate. The cast and constant operations
      are also located in this list by name.
    start_time: When the search started. It is added to `settings.timeout` to
      form the deadline, which is compared against `timeit.default_timer()`,
      so it is presumably a reading of that same clock.
    settings: Search settings (weights, timeouts, solution limits, printing).

  Returns:
    A tuple `(solutions, value_set, values_by_weight, statistics)`, where
    `statistics` is None unless `settings.printing.statistics` is set. The
    search returns early in three cases: once `settings.max_solutions`
    solutions are found; after the first weight that produced a solution
    when only minimal solutions are wanted; or after the weight during which
    the deadline passed.
  """
  timeout_reached = False
  end_time = start_time + settings.timeout

  only_minimal_solutions = settings.only_minimal_solutions
  if settings.max_solutions == 1:
    # If we only want one solution, it will be minimal.
    only_minimal_solutions = True

  # An object to track statistics, if requested.
  statistics = (operation_statistics.OperationStatistics()
                if settings.printing.statistics
                else None)

  # A list of Solution namedtuples.
  solutions = []

  # A set of string solution expressions (don't return duplicate solutions).
  solution_expression_set = set()

  # The output value to search for.
  output_value = value_module.OutputValue(benchmark.examples[0].output)

  # A list of OrderedDicts mapping Value objects to themselves. The i-th
  # OrderedDict contains all Value objects of weight i.
  values_by_weight = [collections.OrderedDict()
                      for _ in range(settings.max_weight + 1)]

  # Find and cache the cast and constant operations for use later.
  cast_operation = None
  constant_operation = None
  for operation in operations:
    if operation.name == tf_functions.CAST_OPERATION_NAME:
      cast_operation = operation
    elif operation.name == tf_functions.CONSTANT_OPERATION_NAME:
      constant_operation = operation
  # Create the output dtype value for use later.
  dtype_value = value_module.ConstantValue(output_value.dtype)

  # Populate values_by_weight with inputs and constants. This also prints
  # inputs/output/constants to stdout.
  _add_constants_and_inputs_and_print(
      values_by_weight, benchmark, output_value, constant_operation, settings)

  # A set storing all values found so far.
  value_set = set().union(*values_by_weight)

  filter_cache = filtered_values_cache.FilteredValuesCache()

  # Value search by weight.
  for weight in range(1, settings.max_weight + 1):
    if settings.printing.progress:
      print('Searching weight {}...'.format(weight))

    # Values with the current weight. This might already include leaf values.
    new_values = values_by_weight[weight]

    for operation in operations:
      for value in operation.enumerate_values_with_weight(
          target_weight=weight,
          values_by_weight=values_by_weight,
          filter_cache=filter_cache,
          end_time=end_time,
          settings=settings,
          statistics=statistics):

        if value not in value_set:
          # This value has never been seen before, or it's the desired output.
          if settings.printing.verbose:
            expression = value.reconstruct_expression()
            print('{} produces:\n{}'.format(expression, value))

          if value == output_value:
            # True only if no solution had been recorded before this value.
            possible_first_solution = not solutions
            # Found solution(s), but some may be bad.
            _record_solutions(value, weight, start_time, solutions,
                              solution_expression_set, benchmark, settings)
            if possible_first_solution and solutions:
              # After the first solution, shorten the deadline so that at most
              # `settings.max_extra_solutions_time` more is spent searching.
              end_time = min(
                  end_time,
                  timeit.default_timer() + settings.max_extra_solutions_time)
            if len(solutions) >= settings.max_solutions:
              return solutions, value_set, values_by_weight, statistics
          else:
            # Only store the value if it isn't a solution. Otherwise, we'll get
            # lots of "almost duplicate" solutions, e.g., by adding 0.
            new_values[value] = value
            # We should never add output_value (or anything equal) to value_set
            # so that we can continue finding other solutions.
            value_set.add(value)
        else:  # This value has been seen before.
          if value in new_values:
            # The value was already computed differently with this weight.
            original_value = new_values[value]
            if isinstance(original_value, value_module.OperationValue):
              # Only merge reconstructions if this was originally an
              # OperationValue. (It could be a ConstantValue instead.)
              operation_value = original_value  # type: value_module.OperationValue
              operation_value.merge_reconstructions(value)
          elif not only_minimal_solutions:
            # If we want non-minimal solutions, we need to store the value even
            # if we have already seen that value with a smaller weight.
            new_values[value] = value

      # NOTE(review): the collapsed source does not show whether this check
      # was inside the per-value loop or the per-operation loop. It is placed
      # in the per-operation loop here, so `break` stops trying further
      # operations at this weight.
      if timeit.default_timer() > end_time:
        timeout_reached = True
        # Don't return immediately; still try to cast new values because this is
        # relatively quick.
        break

    # Try casting new values to the output dtype if this has a chance of being
    # a correct solution.
    for new_value in new_values:
      # Only try values whose shape already matches the output, whose dtype
      # differs from the output's, and which are castable to that dtype.
      if (cast_operation is not None and
          new_value.shape == output_value.shape and
          new_value.dtype != output_value.dtype and
          operation_filtering.is_castable(new_value, dtype_value)):
        casted_value = cast_operation.apply([new_value, dtype_value], settings)
        if casted_value == output_value:
          possible_first_solution = not solutions
          # Found solution(s), but some may be bad.
          _record_solutions(casted_value, weight, start_time, solutions,
                            solution_expression_set, benchmark, settings)
          if possible_first_solution and solutions:
            end_time = min(
                end_time,
                timeit.default_timer() + settings.max_extra_solutions_time)
          if len(solutions) >= settings.max_solutions:
            return solutions, value_set, values_by_weight, statistics

    if settings.printing.progress:
      print('Found {} distinct values of weight {}, or {} total.'.format(
          len(new_values), weight, len(value_set)))
    # Any solution found at this weight is minimal, so no heavier weight is
    # needed.
    if only_minimal_solutions and solutions:
      return solutions, value_set, values_by_weight, statistics
    if timeout_reached:
      break

  return solutions, value_set, values_by_weight, statistics
def _add_constants_and_inputs_and_print(
    values_by_weight: ValuesByWeight,
    benchmark: benchmark_module.Benchmark,
    output_value: value_module.OutputValue,
    constant_operation: Optional[operation_base.Operation],
    settings: settings_module.Settings) -> None:
  """Adds constant/input Values to values_by_weight, and prints to stdout.

  Values are added, each with a weight taken from `tf_functions`, in this
  order:
    1. User-provided constants from `benchmark.constants`.
    2. The inputs of the first example. For primitive inputs, a tensor version
       built with `constant_operation` is also added when that operation is
       available.
    3. Common constants, axis constants, and dimension-length constants.
    4. DType constants for casting.
    5. The output shape as a tuple constant.
  User-provided, common, axis and shape constants are skipped when
  `_constant_exists` reports them already present.

  Args:
    values_by_weight: List of per-weight OrderedDicts mapping Values to
      themselves. Filled in place via `_add_value_by_weight`.
    benchmark: Supplies constants, the first example's inputs, and an
      optional description to print.
    output_value: The target output. Used for shape constants and printing.
    constant_operation: Operation used to turn primitive inputs into tensors,
      or None to skip that step.
    settings: Search settings. When `settings.paper_experiments.uniform_weights`
      is set, every value is moved into weight 1.
  """
  # `constants_so_far` is a set of every constant already added (including
  # primitive input values). It is checked via `_constant_exists` to avoid
  # adding duplicates. `constants_to_print` is a separate list that records
  # constants in the order they are chosen by the heuristics, so they can be
  # printed in that order.
  constants_so_far = set()
  constants_to_print = []

  # User-provided constants.
  for c in benchmark.constants:
    if not _constant_exists(c, constants_so_far):
      constant_value = value_module.ConstantValue(c)
      weight = tf_functions.PROVIDED_CONSTANT_WEIGHT
      _add_value_by_weight(values_by_weight, constant_value, weight)
      constants_so_far.add(c)
      constants_to_print.append(c)

  # Add inputs, while computing some info for extra constants later.
  max_input_tensor_rank = 0
  dimension_lengths = set()
  input_names_to_objects = _input_names_to_objects(benchmark.examples[0].inputs)
  for name, input_object in input_names_to_objects.items():
    input_value = value_module.InputValue(input_object, name)
    if input_value.is_tensor:
      max_input_tensor_rank = max(max_input_tensor_rank, len(input_value.shape))
      dimension_lengths.update(input_value.shape)
    if input_value.is_primitive and constant_operation is not None:
      scalar_tensor_value = constant_operation.apply([input_value], settings)
      _add_value_by_weight(values_by_weight, scalar_tensor_value,
                           tf_functions.PRIMITIVE_INPUT_AS_TENSOR_WEIGHT)

    _add_value_by_weight(values_by_weight, input_value,
                         tf_functions.INPUT_VARIABLE_WEIGHT)
    if input_value.is_primitive:
      # Prevent the same primitive from also being added as a constant below.
      constants_so_far.add(input_value.value)
    print("Input '{}':\n{!s}\n".format(name, input_value.value))

  if output_value.shape is not None:
    dimension_lengths.update(output_value.shape)

  # Always include these as constants.
  common_constants = [0, 1, -1, True, False]
  # Also include 2, 3, ..., max_input_tensor_rank - 1 when applicable.
  axis_constants = list(range(2, max_input_tensor_rank))
  # Also include dimension lengths of input and output tensors.
  shape_constants = sorted(dimension_lengths)

  constant_weight_pairs = (
      [(c, tf_functions.COMMON_CONSTANT_WEIGHT) for c in common_constants] +
      [(c, tf_functions.AXIS_CONSTANT_WEIGHT) for c in axis_constants] +
      [(c, tf_functions.SHAPE_CONSTANT_WEIGHT) for c in shape_constants])

  for constant, weight in constant_weight_pairs:
    if not _constant_exists(constant, constants_so_far):
      constant_value = value_module.ConstantValue(constant)
      _add_value_by_weight(values_by_weight, constant_value, weight)
      constants_so_far.add(constant)
      constants_to_print.append(constant)

  # DTypes for casting.
  for dtype, weight in tf_functions.CONSTANT_DTYPES_AND_WEIGHTS.items():
    dtype_value = value_module.ConstantValue(dtype)
    _add_value_by_weight(values_by_weight, dtype_value, weight)

  if output_value.shape:
    # Add the output shape as a constant.
    shape_tuple = tuple(output_value.shape)
    shape_tuple_value = value_module.ConstantValue(shape_tuple)
    weight = tf_functions.OUTPUT_SHAPE_TUPLE_WEIGHT
    _add_value_by_weight(values_by_weight, shape_tuple_value, weight)
    # Don't add shape_tuple to constants_to_print, because printing it out could
    # be confusing to users.

  # Only for experiments in the PLDI paper.
  if settings.paper_experiments.uniform_weights:
    # Count the number of values.
    num_values = sum(len(values_with_weight)
                     for values_with_weight in values_by_weight)
    # Take all values and put them in the collection for weight 1.
    for weight in range(2, len(values_by_weight)):
      for heavy_value in values_by_weight[weight]:
        values_by_weight[1][heavy_value] = heavy_value
      values_by_weight[weight].clear()
    # Make sure we did it right.
    # NOTE(review): this assumes no two equal values were stored at different
    # weights; otherwise merging them into weight 1 would shrink the count.
    for weight, values_with_weight in enumerate(values_by_weight):
      assert len(values_with_weight) == (num_values if weight == 1 else 0)

  print('Output:\n{!s}\n'.format(output_value.value))
  print('Constants: {!r}\n'.format(constants_to_print))
  if benchmark.description:
    print('Description: {}\n'.format(benchmark.description))
  print('Searching...\n')
  sys.stdout.flush()  # Flush so the inputs/output appear in Colab immediately.
def test_init_does_not_convert_to_tensor(self):
  """Wrapping a Python int keeps it primitive rather than a tensor."""
  one = value.ConstantValue(1)
  self.assertFalse(one.is_tensor)
  self.assertTrue(one.is_primitive)
def test_copy(self):
  """copy() returns a distinct ConstantValue with the same expression."""
  source = value.ConstantValue(1)
  duplicate = source.copy()
  self.assertIsNot(source, duplicate)
  self.assertEqual(source.reconstruct_expression(),
                   duplicate.reconstruct_expression())
def test_reconstruct_expression(self):
  """A string constant reconstructs as a quoted Python literal."""
  expression = value.ConstantValue('TF-Coder').reconstruct_expression()
  self.assertEqual(expression, "'TF-Coder'")
def test_init_with_dtype(self):
  """Wrapping a TF dtype marks the ConstantValue as a dtype."""
  dtype_constant = value.ConstantValue(tf.int64)
  self.assertTrue(dtype_constant.is_dtype)
def test_get_dtype_filter(self):
  """The int32 dtype filter accepts an int32 tensor but not a float32 one."""
  accepts_int32 = operation_filtering.get_dtype_filter(tf.int32)
  int_tensor = value.ConstantValue(tf.constant(1))
  float_tensor = value.ConstantValue(tf.constant(1.0))
  self.assertTrue(accepts_int32(int_tensor))
  self.assertFalse(accepts_int32(float_tensor))