def setUp(self):
    """Prepare the fixtures shared by the tests of this class.

    Stores a `tf.reduce_max` FunctionOperation on `self.operation` and the
    default settings object on `self.settings`.
    """
    super(OperationValueTest, self).setUp()
    reduce_max_info = tf_functions.FunctionInfo(
        name='tf.reduce_max(input_tensor, axis)',
        filter_group=filter_group.FilterGroup.NONE,
        weight=1,
    )
    self.operation = function_operation.FunctionOperation(reduce_max_info)
    self.settings = settings_module.default_settings()
def test_apply_returns_none_if_bad_value(self):
    """apply() yields None for a tf.one_hot call built from the limit values.

    The indices tensor has `limits.MAX_DIMENSION_LENGTH` elements and the
    depth is `limits.MAX_TENSOR_ELEMENTS`.
    """
    # NOTE(review): presumably the one-hot result would exceed the allowed
    # tensor size, which is what makes it a "bad value" -- confirm in limits.
    one_hot_info = tf_functions.FunctionInfo(
        name='tf.one_hot(indices, depth)',
        filter_group=filter_group.FilterGroup.NONE,
        weight=2,
    )
    operation = function_operation.FunctionOperation(one_hot_info)

    indices = value.ConstantValue(tf.ones([limits.MAX_DIMENSION_LENGTH]))
    depth = value.ConstantValue(limits.MAX_TENSOR_ELEMENTS)

    outcome = operation.apply([indices, depth], self.settings)
    self.assertIsNone(outcome)
def test_reconstruct_expression(self):
    """reconstruct_expression() renders the call using the argument values.

    The input value appears by its name and the constant string appears as a
    quoted `axis=` keyword argument.
    """
    reduce_sum_info = tf_functions.FunctionInfo(
        name='tf.reduce_sum(input_tensor, axis)',
        filter_group=filter_group.FilterGroup.NONE,
        weight=2,
    )
    operation = function_operation.FunctionOperation(reduce_sum_info)

    arg_values = [
        value.InputValue([[1, 3], [50, 20]], 'my_input'),
        value.ConstantValue('tf-coder'),
    ]
    rendered = operation.reconstruct_expression(arg_values)

    expected = "tf.reduce_sum(my_input, axis='tf-coder')"
    self.assertEqual(rendered, expected)
def test_apply_returns_none_if_exception(self):
    """apply() yields None when tf.reduce_sum is given axis=2 on a 2-D tensor."""
    reduce_sum_info = tf_functions.FunctionInfo(
        name='tf.reduce_sum(input_tensor, axis)',
        filter_group=filter_group.FilterGroup.NONE,
        weight=2,
    )
    operation = function_operation.FunctionOperation(reduce_sum_info)

    matrix = tf.constant([[1, 3, 4], [50, 20, 80]])
    input_tensor = value.ConstantValue(matrix)
    # The literal above has two dimensions, so axis 2 does not exist.
    axis_2 = value.ConstantValue(2)

    outcome = operation.apply([input_tensor, axis_2], self.settings)
    self.assertIsNone(outcome)
def test_metadata(self):
    """The operation's metadata docstring contains the expected snippets."""
    reduce_sum_info = tf_functions.FunctionInfo(
        name='tf.reduce_sum(input_tensor, axis)',
        filter_group=filter_group.FilterGroup.NONE,
        weight=2,
    )
    operation = function_operation.FunctionOperation(reduce_sum_info)
    docstring = operation.metadata.docstring

    # Checked in this order; the first missing snippet fails the test.
    expected_snippets = (
        'Computes the sum of elements across dimensions of a tensor.',
        'tf.reduce_sum(input_tensor, axis)',
        'reduce sum',
        'input tensor',
    )
    for snippet in expected_snippets:
        self.assertIn(snippet, docstring)
def test_reconstruct_all_expressions_with_input_names_using_addition(self):
    """Counts the tf.add expressions over constants 0..9 that evaluate to 9.

    Values are enumerated bottom-up by weight (each constant and each tf.add
    has weight 1), merging reconstructions of equal values, and the number of
    reconstructions of the value 9 is checked at weights 3 and 5.
    """
    add_operation = function_operation.FunctionOperation(
        tf_functions.FunctionInfo(
            name='tf.add(x, y)',
            filter_group=filter_group.FilterGroup.NONE,
            weight=1,
        )
    )

    # values_by_weight[i] holds every unique Value of weight i, mapped to
    # itself so that the stored Value equal to a query Value can be retrieved.
    # Index 0 stays empty (nothing has weight 0); index 1 holds the constants.
    weight_one = collections.OrderedDict()
    for number in range(10):
        constant = value.ConstantValue(number)
        weight_one[constant] = constant
    values_by_weight = [collections.OrderedDict(), weight_one]

    for weight in range(2, 6):
        found = collections.OrderedDict()
        for left_weight in range(1, weight):
            # The add operation itself accounts for one unit of weight.
            right_weight = weight - left_weight - 1
            arg_pairs = itertools.product(
                values_by_weight[left_weight], values_by_weight[right_weight]
            )
            for left, right in arg_pairs:
                result = add_operation.apply([left, right], self.settings)
                if result in found:
                    found[result].merge_reconstructions(result)
                else:
                    found[result] = result
        values_by_weight.append(found)

    query = value.OutputValue(9)

    # The form must be (a + b), where there are 10 choices for a, which then
    # determines b.
    stored_weight_3 = values_by_weight[3][query]
    reconstructions = (
        stored_weight_3.reconstruct_all_expressions_with_input_names()
    )
    self.assertLen(reconstructions, 10)
    # No expression uses input values.
    self.assertTrue(all(not names for _, names in reconstructions))

    # No AST with only binary operators has weight 4.
    self.assertEmpty(values_by_weight[4])

    # The form is either (a + (b + c)) or ((a + b) + c). Each of the two forms
    # has 1 + 2 + ... + 9 = 45 options. Note that "a" in (a + (b + c)) cannot
    # be 0, or else (b + c) would have the same value as the entire
    # expression. Similarly, "c" in ((a + b) + c) cannot be 0.
    stored_weight_5 = values_by_weight[5][query]
    self.assertLen(
        stored_weight_5.reconstruct_all_expressions_with_input_names(), 90
    )
def test_reconstruct_all_expressions_with_input_names(self):
    """Merged reconstructions are all listed, each with the inputs it uses.

    The same final value is produced in three ways with `self.operation`
    (tf.reduce_max). One of those ways uses an argument that itself has two
    merged reconstructions, so four expressions are expected in total.
    """
    axis_input = value.InputValue(0, 'in_0')
    one = value.ConstantValue(1)
    two = value.ConstantValue(2)

    # First way: reduce the named input directly along the constant axis 1.
    my_input = value.InputValue([[1, 3, 2], [-3, 0, 4]], 'my_input')
    final_value = self.operation.apply([my_input, one], self.settings)

    add_operation = function_operation.FunctionOperation(
        tf_functions.FunctionInfo(
            name='tf.add(x, y)',
            filter_group=filter_group.FilterGroup.NONE,
            weight=1,
        )
    )

    # Two different sums that yield the same matrix [[2, 4], [3, 2]].
    input_1321 = value.InputValue([[1, 3], [2, 1]], 'in_1321')
    sum_2432 = add_operation.apply([input_1321, one], self.settings)
    input_0210 = value.InputValue([[0, 2], [1, 0]], 'in_0210')
    sum_2432_again = add_operation.apply([input_0210, two], self.settings)
    self.assertEqual(sum_2432, sum_2432_again)
    sum_2432.merge_reconstructions(sum_2432_again)

    # Second way: reduce the merged sum along the axis given by an input.
    final_from_sum = self.operation.apply(
        [sum_2432, axis_input], self.settings
    )
    self.assertEqual(final_value, final_from_sum)
    final_value.merge_reconstructions(final_from_sum)

    # Third way: reduce another input along the axis given by an input.
    input_1430 = value.InputValue([[1, 4], [3, 0]], 'in_1430')
    final_from_input = self.operation.apply(
        [input_1430, axis_input], self.settings
    )
    self.assertEqual(final_value, final_from_input)
    final_value.merge_reconstructions(final_from_input)

    expected = [
        ('tf.reduce_max(my_input, axis=1)', {'my_input'}),
        ('tf.reduce_max(tf.add(in_1321, 1), axis=in_0)', {'in_1321', 'in_0'}),
        ('tf.reduce_max(tf.add(in_0210, 2), axis=in_0)', {'in_0210', 'in_0'}),
        ('tf.reduce_max(in_1430, axis=in_0)', {'in_1430', 'in_0'}),
    ]
    actual = final_value.reconstruct_all_expressions_with_input_names()
    self.assertEqual(actual, expected)
def test_apply_succeeds(self):
    """apply() returns the tf.reduce_sum result for axis 0 and for axis 1."""
    reduce_sum_info = tf_functions.FunctionInfo(
        name='tf.reduce_sum(input_tensor, axis)',
        filter_group=filter_group.FilterGroup.NONE,
        weight=2,
    )
    operation = function_operation.FunctionOperation(reduce_sum_info)

    matrix = tf.constant([[1, 3, 4], [50, 20, 80]])
    input_tensor = value.ConstantValue(matrix)
    axis_0 = value.ConstantValue(0)
    axis_1 = value.ConstantValue(1)

    # Axis 0 gives the column sums.
    column_sums = operation.apply([input_tensor, axis_0], self.settings)
    self.assertEqual(
        column_sums, value.ConstantValue(tf.constant([51, 23, 84]))
    )

    # Axis 1 gives the row sums.
    row_sums = operation.apply([input_tensor, axis_1], self.settings)
    self.assertEqual(row_sums, value.ConstantValue(tf.constant([8, 150])))
def get_sparse_operations() -> List[operation_base.Operation]:
    """Returns a list of Operation objects for sparse operations.

    One FunctionOperation is created for each entry of
    `tf_functions.SPARSE_FUNCTIONS`, in the same order.
    """
    return list(
        map(function_operation.FunctionOperation, tf_functions.SPARSE_FUNCTIONS)
    )
def get_tf_operations() -> List[operation_base.Operation]:
    """Returns a list of Operation objects for dense TensorFlow operations.

    One FunctionOperation is created for each entry of
    `tf_functions.PY_FUNCTIONS`, in the same order.
    """
    # NOTE(review): this function is named for TensorFlow operations but reads
    # `PY_FUNCTIONS` -- confirm against tf_functions that this is the intended
    # list.
    return list(
        map(function_operation.FunctionOperation, tf_functions.PY_FUNCTIONS)
    )