コード例 #1
0
 def test_apply_for_tensor(self, index, expected):
     """Applies the operation to a 2x2 float tensor with the given `index`."""
     source = value.ConstantValue(tf.constant([[1.2, 3.4], [5.6, 7.8]]))
     position = value.ConstantValue(index)
     settings = settings_module.default_settings()
     actual = self.operation.apply([source, position], settings)
     want = value.ConstantValue(tf.constant(expected))
     self.assertEqual(actual, want)
コード例 #2
0
 def setUp(self):
     """Builds a SlicingAxis1BothOperation and its three sample arguments."""
     super(SlicingAxis1BothOperationTest, self).setUp()
     self.operation = python_operations.SlicingAxis1BothOperation()
     matrix = value.InputValue([[12, 34, 56, 78], [-1, -2, -3, -4]],
                               'my_input')
     left_bound = value.ConstantValue(1)
     right_bound = value.ConstantValue(-1)
     self.arg_values = [matrix, left_bound, right_bound]
コード例 #3
0
 def setUp(self):
     """Builds a TripleCreationOperation with three integer constant args."""
     super(TripleCreationOperationTest, self).setUp()
     self.operation = python_operations.TripleCreationOperation()
     self.arg_values = [value.ConstantValue(number) for number in (12, 34, 56)]
コード例 #4
0
 def test_apply_for_list(self, index, expected):
     """Applies the operation to a Python list with the given `index`."""
     items = value.ConstantValue([1.2, 3.4, 5.6])
     position = value.ConstantValue(index)
     actual = self.operation.apply([items, position],
                                   settings_module.default_settings())
     self.assertEqual(actual, value.ConstantValue(expected))
コード例 #5
0
 def test_apply_returns_none_if_bad_value(self):
     """apply() returns None for a tf.one_hot call built from `limits` maxima.

     The indices length is MAX_DIMENSION_LENGTH and depth is
     MAX_TENSOR_ELEMENTS; presumably the result is rejected as too large.
     """
     one_hot_info = tf_functions.FunctionInfo(
         name='tf.one_hot(indices, depth)',
         filter_group=filter_group.FilterGroup.NONE,
         weight=2)
     one_hot = function_operation.FunctionOperation(one_hot_info)
     indices = value.ConstantValue(tf.ones([limits.MAX_DIMENSION_LENGTH]))
     depth = value.ConstantValue(limits.MAX_TENSOR_ELEMENTS)
     outcome = one_hot.apply([indices, depth], self.settings)
     self.assertIsNone(outcome)
コード例 #6
0
 def test_apply_returns_none_if_exception(self):
     """apply() returns None when reducing a rank-2 tensor over axis 2."""
     reduce_sum_info = tf_functions.FunctionInfo(
         name='tf.reduce_sum(input_tensor, axis)',
         filter_group=filter_group.FilterGroup.NONE,
         weight=2)
     reduce_sum = function_operation.FunctionOperation(reduce_sum_info)
     matrix = tf.constant([[1, 3, 4], [50, 20, 80]])
     arguments = [value.ConstantValue(matrix), value.ConstantValue(2)]
     outcome = reduce_sum.apply(arguments, self.settings)
     self.assertIsNone(outcome)
コード例 #7
0
    def test_reconstruct_all_expressions_with_input_names(self):
        """All merged reconstructions are reported with the inputs they use."""
        axis_input = value.InputValue(0, 'in_0')
        one = value.ConstantValue(1)
        two = value.ConstantValue(2)

        # First reconstruction: reduce over axis 1 of a 2x3 input.
        matrix_input = value.InputValue([[1, 3, 2], [-3, 0, 4]], 'my_input')
        target = self.operation.apply([matrix_input, one], self.settings)

        add_info = tf_functions.FunctionInfo(
            name='tf.add(x, y)',
            filter_group=filter_group.FilterGroup.NONE,
            weight=1)
        add_operation = function_operation.FunctionOperation(add_info)

        # Two different additions that both evaluate to [[2, 4], [3, 2]].
        in_1321 = value.InputValue([[1, 3], [2, 1]], 'in_1321')
        sum_plus_one = add_operation.apply([in_1321, one], self.settings)
        in_0210 = value.InputValue([[0, 2], [1, 0]], 'in_0210')
        sum_plus_two = add_operation.apply([in_0210, two], self.settings)
        self.assertEqual(sum_plus_one, sum_plus_two)
        sum_plus_one.merge_reconstructions(sum_plus_two)

        # Second reconstruction: reduce the merged sum over axis `in_0`.
        via_sum = self.operation.apply([sum_plus_one, axis_input],
                                       self.settings)
        self.assertEqual(target, via_sum)
        target.merge_reconstructions(via_sum)

        # Third reconstruction: reduce another input over axis `in_0`.
        in_1430 = value.InputValue([[1, 4], [3, 0]], 'in_1430')
        via_input = self.operation.apply([in_1430, axis_input],
                                         self.settings)
        self.assertEqual(target, via_input)
        target.merge_reconstructions(via_input)

        expected = [
            ('tf.reduce_max(my_input, axis=1)', {'my_input'}),
            ('tf.reduce_max(tf.add(in_1321, 1), axis=in_0)',
             {'in_1321', 'in_0'}),
            ('tf.reduce_max(tf.add(in_0210, 2), axis=in_0)',
             {'in_0210', 'in_0'}),
            ('tf.reduce_max(in_1430, axis=in_0)', {'in_1430', 'in_0'}),
        ]
        actual = target.reconstruct_all_expressions_with_input_names()
        self.assertEqual(actual, expected)
コード例 #8
0
  def test_merge_reconstructions(self):
    """Merging an equal value adds its operation application to ours."""
    first = self.operation.apply(
        [value.InputValue([[1, 3, 2], [-3, 0, 4]], 'my_input'),
         value.ConstantValue(1)],
        self.settings)
    self.assertLen(first.operation_applications, 1)

    second = self.operation.apply(
        [value.InputValue([[1, 4], [3, 0]], 'my_input_2'),
         value.ConstantValue(0)],
        self.settings)
    self.assertEqual(first, second)

    first.merge_reconstructions(second)
    self.assertLen(first.operation_applications, 2)
コード例 #9
0
 def test_apply_succeeds(self):
     """tf.reduce_sum over axis 0 and axis 1 gives column and row sums."""
     reduce_sum = function_operation.FunctionOperation(
         tf_functions.FunctionInfo(
             name='tf.reduce_sum(input_tensor, axis)',
             filter_group=filter_group.FilterGroup.NONE,
             weight=2))
     matrix = value.ConstantValue(tf.constant([[1, 3, 4], [50, 20, 80]]))
     cases = [
         (value.ConstantValue(0), [51, 23, 84]),
         (value.ConstantValue(1), [8, 150]),
     ]
     for axis, sums in cases:
         self.assertEqual(reduce_sum.apply([matrix, axis], self.settings),
                          value.ConstantValue(tf.constant(sums)))
コード例 #10
0
 def setUp(self):
     """Builds an IndexingAxis1Operation with a 2x2 input and index 1."""
     super(IndexingAxis1OperationTest, self).setUp()
     self.operation = python_operations.IndexingAxis1Operation()
     matrix = value.InputValue([[12, 34], [56, 78]], 'my_input')
     self.arg_values = [matrix, value.ConstantValue(1)]
コード例 #11
0
    def test_extract_examples_from_value_without_inputs(self):
        """Expressions that use no inputs are left out of the examples."""
        one, two, three = (value_module.ConstantValue(c) for c in (1, 2, 3))

        inner_sum = self.add_operation.apply([one, two], self.settings)
        outer_sum = self.add_operation.apply([inner_sum, three],
                                             self.settings)
        extracted = collect_tensor_data.extract_examples_from_value(outer_sum)
        expressions = [example.expression for example in extracted]

        # `tf.add(tf.add(1, 2), 3)` has no inputs and is not included.
        self.assertCountEqual(expressions, ['tf.add(in1, 3)'])
コード例 #12
0
 def setUp(self):
     """Builds a SlicingAxis1RightOperation with a 1x3 input and bound -1."""
     super(SlicingAxis1RightOperationTest, self).setUp()
     self.operation = python_operations.SlicingAxis1RightOperation()
     row_matrix = value.InputValue([[12, 34, 56]], 'my_input')
     self.arg_values = [row_matrix, value.ConstantValue(-1)]
コード例 #13
0
  def test_init_with_primitive(self):
    """An int ConstantValue is primitive and carries no tensor metadata."""
    primitive = value.ConstantValue(12)
    self.assertEqual(primitive.type, int)
    self.assertTrue(primitive.is_primitive)

    # Element/tensor attributes do not apply to a plain int.
    self.assertIsNone(primitive.elem_type)
    self.assertFalse(primitive.elem_type_is_tensor)
    self.assertIsNone(primitive.dtype)
    self.assertIsNone(primitive.shape)
コード例 #14
0
def create_examples(io_example: IOExample,
                    max_num_inputs: int = 3,
                    permute_inputs: bool = True) -> List[Dict[Text, Any]]:
    """Creates example dicts for the I/O example.

    Each returned dict contains:
      * 'num_inputs': the number of featurized inputs, i.e.
        min(number of inputs, max_num_inputs);
      * the featurized output value;
      * the featurized input values, in some order. Unused input slots are
        padded with the features of a dummy `ConstantValue(0)`, up to
        `max_num_inputs` slots;
      * 'operations': per-operation usage counts for `io_example.operations`,
        aligned with `all_operations.get_operations(include_sparse_operations=True)`;
      * 'expression': the example's expression.

    Only the first `max_num_inputs` inputs are featurized; extra inputs are
    dropped.

    Args:
      io_example: The I/O example to featurize.
      max_num_inputs: The number of input slots in each example dict.
      permute_inputs: If true, emit one dict per permutation of the kept
        inputs. Otherwise emit a single dict with the inputs in their
        original order.

    Returns:
      A list of feature dicts. The list is empty if featurization raised a
      ValueError (which is logged).
    """
    examples = []
    operation_list = all_operations.get_operations(
        include_sparse_operations=True)
    operation_counter = collections.Counter(
        [op.name for op in io_example.operations])
    operation_counts = [operation_counter[op.name] for op in operation_list]

    num_inputs = len(io_example.input_values)
    # Only the first `max_num_inputs` inputs are featurized below, so the
    # permutations and padding must range over this many inputs. Permuting
    # over `num_inputs` instead would index past the end of `input_features`
    # (IndexError) whenever the example has more inputs than input slots.
    num_kept_inputs = min(num_inputs, max_num_inputs)

    try:
        output_features = featurize_value(io_example.output_value)
        input_features = []
        for input_value in io_example.input_values[:max_num_inputs]:
            combined_features = featurize_value(input_value)
            combined_features.update(
                featurize_input_and_output(input_value,
                                           io_example.output_value))
            input_features.append(combined_features)

        # Features used to fill the unused input slots.
        dummy_value = value_module.ConstantValue(0)
        dummy_input_features = featurize_value(dummy_value)
        dummy_input_features.update(
            featurize_input_and_output(dummy_value, dummy_value))

    except ValueError as e:
        logging.warning('%s: could not featurize IOExample %s', e, io_example)
        return []

    kept_indices = range(num_kept_inputs)
    permutations = (itertools.permutations(kept_indices)
                    if permute_inputs else [list(kept_indices)])
    for permutation in permutations:

        feature_dict = collections.defaultdict(list)
        feature_dict['num_inputs'] = [num_kept_inputs]
        # Deep-copied because input features are appended into these lists.
        feature_dict.update(copy.deepcopy(output_features))

        padded_input_features = [
            input_features[index] for index in permutation
        ]
        padded_input_features.extend(
            [dummy_input_features] * (max_num_inputs - num_kept_inputs))

        for input_features_dict in padded_input_features:
            for key, feature_values in input_features_dict.items():
                feature_dict[key].extend(feature_values)

        # Copy so the returned examples do not all share one mutable list.
        feature_dict['operations'] = list(operation_counts)
        feature_dict['expression'] = [io_example.expression]

        examples.append(feature_dict)

    return examples
コード例 #15
0
 def test_reconstruct_expression(self):
     """Reconstruction renders an input by name and a string constant quoted."""
     reduce_sum = function_operation.FunctionOperation(
         tf_functions.FunctionInfo(
             name='tf.reduce_sum(input_tensor, axis)',
             filter_group=filter_group.FilterGroup.NONE,
             weight=2))
     arguments = [
         value.InputValue([[1, 3], [50, 20]], 'my_input'),
         value.ConstantValue('tf-coder'),
     ]
     rendered = reduce_sum.reconstruct_expression(arguments)
     self.assertEqual(rendered, "tf.reduce_sum(my_input, axis='tf-coder')")
コード例 #16
0
    def test_reconstruct_all_expressions_with_input_names_using_addition(self):
        """Counts the distinct addition trees over constants 0..9 equal to 9."""
        constants = [value.ConstantValue(i) for i in range(10)]
        add_operation = function_operation.FunctionOperation(
            tf_functions.FunctionInfo(
                name='tf.add(x, y)',
                filter_group=filter_group.FilterGroup.NONE,
                weight=1))

        # Entry i maps every distinct Value of weight i to itself, so the stored
        # Value equal to some query Value can be retrieved. Weight 0 is empty;
        # weight 1 holds the constants.
        values_by_weight = [
            collections.OrderedDict(),
            collections.OrderedDict((c, c) for c in constants),
        ]

        for weight in range(2, 6):
            combined = collections.OrderedDict()
            for left_weight in range(1, weight):
                # The addition itself contributes weight 1.
                right_weight = weight - left_weight - 1
                pairs = itertools.product(values_by_weight[left_weight],
                                          values_by_weight[right_weight])
                for left, right in pairs:
                    total = add_operation.apply([left, right], self.settings)
                    if total in combined:
                        combined[total].merge_reconstructions(total)
                    else:
                        combined[total] = total
            values_by_weight.append(combined)

        target = value.OutputValue(9)

        # The form must be (a + b), where there are 10 choices for a, which then
        # determines b.
        two_term_value = values_by_weight[3][target]
        reconstructions = (
            two_term_value.reconstruct_all_expressions_with_input_names())
        self.assertLen(reconstructions, 10)
        # No expression uses input values.
        self.assertFalse(any(names for _, names in reconstructions))

        # No AST with only binary operators has weight 4.
        self.assertEmpty(values_by_weight[4])

        # The form is either (a + (b + c)) or ((a + b) + c). Each of the two forms
        # has 1 + 2 + ... + 9 = 45 options. Note that "a" in (a + (b + c)) cannot be
        # 0, or else (b + c) would have the same value as the entire expression.
        # Similarly, "c" in ((a + b) + c) cannot be 0.
        three_term_value = values_by_weight[5][target]
        self.assertLen(
            three_term_value.reconstruct_all_expressions_with_input_names(),
            90)
コード例 #17
0
 def test_run_value_handles_large_weight_constants(self):
     """The (1, 1) output-shape constant in value_set is a ConstantValue."""
     example = benchmark_module.Example(inputs=[[1], [2]], output=[[3]])
     benchmark = benchmark_module.Benchmark(examples=[example])
     results = value_search.run_value_search(benchmark=benchmark,
                                             settings=self.settings)
     self.assertNotEmpty(results.solutions)
     self.assertEqual(results.solutions[0].expression,
                      'tf.add(in1, tf.expand_dims(in2, 0))')

     shape_constant = value_module.ConstantValue((1, 1))
     self.assertIn(shape_constant, results.value_set)
     # Whatever element of value_set equals the shape constant must actually be
     # a ConstantValue, as opposed to an OperationValue.
     matches = [v for v in results.value_set if v == shape_constant]
     for match in matches:
         self.assertIsInstance(match, value_module.ConstantValue)
コード例 #18
0
    def setUp(self):
        """Builds the operations and two example values used by these tests."""
        super(CollectTensorDataTest, self).setUp()
        self.settings = settings_module.default_settings()

        operation_list = all_operations.get_operations()

        def lookup(name):
            """Finds the operation called `name` among `operation_list`."""
            return all_operations.find_operation_with_name(
                name, operation_list=operation_list)

        self.unique_with_counts_operation = lookup('tf.unique_with_counts(x)')
        self.indexing_operation = lookup('IndexingOperation')
        self.gather_operation = lookup('tf.gather(params, indices)')
        self.add_operation = lookup('tf.add(x, y)')
        settings = self.settings

        # Example with many operations.
        first_input = value_module.InputValue([1, 1, 2, 5, 6, 5], 'in1')
        second_input = value_module.InputValue(
            [0, 10, 20, 30, 40, 50, 60, 70], 'in2')
        index_one = value_module.ConstantValue(1)

        uniques = self.unique_with_counts_operation.apply([first_input],
                                                          settings)
        unique_idx = self.indexing_operation.apply([uniques, index_one],
                                                   settings)
        gathered = self.gather_operation.apply([second_input, first_input],
                                               settings)
        self.example_value_1 = self.add_operation.apply(
            [unique_idx, gathered], settings)

        self.assertEqual(
            self.example_value_1.reconstruct_expression(),
            'tf.add(tf.unique_with_counts(in1)[1], tf.gather(in2, in1))')
        self.assertEqual(self.example_value_1,
                         value_module.OutputValue([10, 10, 21, 52, 63, 52]))

        # Example with many variables and new inputs.
        third_input = value_module.InputValue([1], 'in3')
        fourth_input = value_module.InputValue([2], 'in4')
        add = self.add_operation.apply

        sum_a = add([third_input, new_input([10])], settings)
        sum_b = add([fourth_input, third_input], settings)
        sum_c = add([new_input([20]), third_input], settings)
        sum_d = add([sum_a, sum_b], settings)
        self.example_value_2 = add([sum_c, sum_d], settings)

        self.assertEqual(
            self.example_value_2.reconstruct_expression(),
            'tf.add(tf.add(NEW_INPUT, in3), '
            'tf.add(tf.add(in3, NEW_INPUT), tf.add(in4, in3)))')
        self.assertEqual(self.example_value_2, value_module.OutputValue([35]))
コード例 #19
0
  def test_copy(self):
    """copy() yields an independent OperationValue with the same expression."""
    tensor_value = value.InputValue([[1, 3, 2], [-3, 0, 4]], 'my_input')
    axis_value = value.ConstantValue(1)
    original = self.operation.apply([tensor_value, axis_value], self.settings)

    duplicate = original.copy()
    self.assertIsNot(original, duplicate)
    original_expression = original.reconstruct_expression(use_cache=False)
    duplicate_expression = duplicate.reconstruct_expression(use_cache=False)
    self.assertEqual(original_expression, 'tf.reduce_max(my_input, axis=1)')
    self.assertEqual(duplicate_expression, 'tf.reduce_max(my_input, axis=1)')

    # Renaming an argument reached through the copy must not affect the
    # original's reconstruction.
    duplicate.operation_applications[0].arg_values[0].name = 'new_name'
    self.assertEqual(original.reconstruct_expression(use_cache=False),
                     'tf.reduce_max(my_input, axis=1)')
    self.assertEqual(duplicate.reconstruct_expression(use_cache=False),
                     'tf.reduce_max(new_name, axis=1)')
コード例 #20
0
  def test_reconstruct_expression(self):
    """An OperationValue renders its expression, including from the cache."""
    tensor_value = value.InputValue([[1, 3, 2], [-3, 0, 4]], 'my_input')
    axis_value = value.ConstantValue(1)

    direct = value.OperationValue(tf.constant([3, 4]), self.operation,
                                  [tensor_value, axis_value])
    expected_expression = 'tf.reduce_max(my_input, axis=1)'
    # The second iteration exercises the cached expression.
    for _ in range(2):
      self.assertEqual(direct.reconstruct_expression(), expected_expression)

    applied = self.operation.apply([tensor_value, axis_value], self.settings)
    self.assertEqual(direct, applied)
    self.assertEqual(applied.reconstruct_expression(), expected_expression)
コード例 #21
0
 def test_get_type_filter(self):
     """The int type filter accepts an int constant and rejects a float one."""
     accepts_int = operation_filtering.get_type_filter(int)
     self.assertTrue(accepts_int(value.ConstantValue(1)))
     self.assertFalse(accepts_int(value.ConstantValue(1.0)))
コード例 #22
0
 def test_apply(self):
     """Applying the operation to self.arg_values yields the tuple (12,)."""
     settings = settings_module.default_settings()
     actual = self.operation.apply(self.arg_values, settings)
     self.assertEqual(actual, value.ConstantValue((12,)))
コード例 #23
0
def _value(wrapped_value):
  """Returns `wrapped_value` wrapped in a `value.ConstantValue`."""
  constant = value.ConstantValue(wrapped_value)
  return constant
コード例 #24
0
def _find_solutions(
    benchmark: benchmark_module.Benchmark,
    operations: List[operation_base.Operation],
    start_time: float,
    settings: settings_module.Settings
) -> Tuple[List[Solution], Set[value_module.Value], ValuesByWeight,
           Optional[operation_statistics.OperationStatistics]]:
  """Helper, returning (solutions, value_set, values_by_weight, statistics).

  Enumerates values in order of increasing weight, from 1 through
  `settings.max_weight`, using each operation's
  `enumerate_values_with_weight`. Any value equal to the output of the
  benchmark's first example is recorded as a solution. Values stored at the
  current weight whose shape matches the output but whose dtype differs are
  also passed through the cast operation, when one is present in
  `operations`.

  The search stops early in three cases:
    * `settings.max_solutions` solutions have been recorded;
    * minimal solutions only are wanted and a weight level produced a
      solution;
    * the time budget runs out.
  The time budget ends `settings.timeout` seconds after `start_time`. Once
  the first solution is recorded, it is shortened to at most
  `settings.max_extra_solutions_time` more seconds.

  Args:
    benchmark: The benchmark to solve. The search target is built from
      `benchmark.examples[0].output`. The benchmark is also passed to the
      helpers that add constants/inputs and record solutions.
    operations: The operations to enumerate values with.
    start_time: The search start time. Presumably from
      `timeit.default_timer()`, since the end time derived from it is
      compared against that timer.
    settings: Search settings (weights, timeouts, solution limits, printing).

  Returns:
    A tuple (solutions, value_set, values_by_weight, statistics).
    `statistics` is None unless `settings.printing.statistics` is set.
  """
  timeout_reached = False
  end_time = start_time + settings.timeout

  only_minimal_solutions = settings.only_minimal_solutions
  if settings.max_solutions == 1:
    # If we only want one solution, it will be minimal.
    only_minimal_solutions = True

  # An object to track statistics, if requested.
  statistics = (operation_statistics.OperationStatistics()
                if settings.printing.statistics
                else None)

  # A list of Solution namedtuples.
  solutions = []

  # A set of string solution expressions (don't return duplicate solutions).
  solution_expression_set = set()

  # The output value to search for.
  output_value = value_module.OutputValue(benchmark.examples[0].output)

  # A list of OrderedDicts mapping Value objects to themselves. The i-th
  # OrderedDict contains all Value objects of weight i.
  values_by_weight = [collections.OrderedDict()
                      for _ in range(settings.max_weight + 1)]

  # Find and cache the cast and constant operations for use later.
  cast_operation = None
  constant_operation = None
  for operation in operations:
    if operation.name == tf_functions.CAST_OPERATION_NAME:
      cast_operation = operation
    elif operation.name == tf_functions.CONSTANT_OPERATION_NAME:
      constant_operation = operation
  # Create the output dtype value for use later.
  dtype_value = value_module.ConstantValue(output_value.dtype)

  # Populate values_by_weight with inputs and constants. This also prints
  # inputs/output/constants to stdout.
  _add_constants_and_inputs_and_print(
      values_by_weight, benchmark, output_value, constant_operation, settings)

  # A set storing all values found so far. The union over the OrderedDicts
  # collects their keys, i.e., every initial input/constant value.
  value_set = set().union(*values_by_weight)

  filter_cache = filtered_values_cache.FilteredValuesCache()

  # Value search by weight.
  for weight in range(1, settings.max_weight + 1):
    if settings.printing.progress:
      print('Searching weight {}...'.format(weight))

    # Values with the current weight. This might already include leaf values.
    new_values = values_by_weight[weight]

    for operation in operations:
      for value in operation.enumerate_values_with_weight(
          target_weight=weight,
          values_by_weight=values_by_weight,
          filter_cache=filter_cache,
          end_time=end_time,
          settings=settings,
          statistics=statistics):

        if value not in value_set:
          # This value has never been seen before, or it's the desired output.
          if settings.printing.verbose:
            expression = value.reconstruct_expression()
            print('{} produces:\n{}'.format(expression, value))

          if value == output_value:
            # True iff no solution had been recorded before this one.
            possible_first_solution = not solutions
            # Found solution(s), but some may be bad.
            _record_solutions(value, weight, start_time, solutions,
                              solution_expression_set, benchmark, settings)
            if possible_first_solution and solutions:
              # The first solution was just recorded: allow at most
              # `max_extra_solutions_time` more seconds of searching.
              end_time = min(
                  end_time,
                  timeit.default_timer() + settings.max_extra_solutions_time)
            if len(solutions) >= settings.max_solutions:
              return solutions, value_set, values_by_weight, statistics
          else:
            # Only store the value if it isn't a solution. Otherwise, we'll get
            # lots of "almost duplicate" solutions, e.g., by adding 0.
            new_values[value] = value
            # We should never add output_value (or anything equal) to value_set
            # so that we can continue finding other solutions.
            value_set.add(value)
        else:  # This value has been seen before.
          if value in new_values:
            # The value was already computed differently with this weight.
            original_value = new_values[value]
            if isinstance(original_value, value_module.OperationValue):
              # Only merge reconstructions if this was originally an
              # OperationValue. (It could be a ConstantValue instead.)
              operation_value = original_value   # type: value_module.OperationValue
              operation_value.merge_reconstructions(value)
          elif not only_minimal_solutions:
            # If we want non-minimal solutions, we need to store the value even
            # if we have already seen that value with a smaller weight.
            new_values[value] = value

      if timeit.default_timer() > end_time:
        timeout_reached = True
        # Don't return immediately; still try to cast new values because this is
        # relatively quick. This `break` leaves the loop over operations.
        break

    # Try casting new values to the output dtype if this has a chance of being
    # a correct solution.
    for new_value in new_values:
      if (cast_operation is not None and
          new_value.shape == output_value.shape and
          new_value.dtype != output_value.dtype and
          operation_filtering.is_castable(new_value, dtype_value)):
        casted_value = cast_operation.apply([new_value, dtype_value], settings)
        if casted_value == output_value:
          possible_first_solution = not solutions
          # Found solution(s), but some may be bad.
          # NOTE(review): the casted solution is recorded with the current
          # `weight`, which does not include the cast itself; confirm intended.
          _record_solutions(casted_value, weight, start_time, solutions,
                            solution_expression_set, benchmark, settings)
          if possible_first_solution and solutions:
            end_time = min(
                end_time,
                timeit.default_timer() + settings.max_extra_solutions_time)
          if len(solutions) >= settings.max_solutions:
            return solutions, value_set, values_by_weight, statistics

    if settings.printing.progress:
      print('Found {} distinct values of weight {}, or {} total.'.format(
          len(new_values), weight, len(value_set)))
    # Any heavier solution would not be minimal, so stop after the first
    # weight level that produced a solution.
    if only_minimal_solutions and solutions:
      return solutions, value_set, values_by_weight, statistics
    if timeout_reached:
      break

  return solutions, value_set, values_by_weight, statistics
コード例 #25
0
def _add_constants_and_inputs_and_print(
    values_by_weight: ValuesByWeight,
    benchmark: benchmark_module.Benchmark,
    output_value: value_module.OutputValue,
    constant_operation: operation_base.Operation,
    settings: settings_module.Settings) -> None:
  """Adds constant/input Values to values_by_weight, and prints to stdout.

  Values are added in this order:
    1. User-provided benchmark constants.
    2. Inputs of the first example. When `constant_operation` is available,
       a primitive input is also added as a scalar-tensor value.
    3. Heuristic constants: common constants, axis constants and dimension
       lengths.
    4. DTypes used for casting.
    5. The output shape as a tuple, when the output has a non-empty shape.
  Each value is added at its configured weight.

  When `settings.paper_experiments.uniform_weights` is set, every value is
  then moved to weight 1.

  Args:
    values_by_weight: The per-weight OrderedDicts to populate in place.
    benchmark: Provides the constants, the inputs of the first example, and
      an optional description.
    output_value: The target output. Its shape contributes dimension-length
      constants and the output-shape tuple.
    constant_operation: Operation used to turn primitive inputs into scalar
      tensors. May be None, in which case that step is skipped.
    settings: Search settings.
  """
  # `constants_so_far` is a set used (via `_constant_exists`) to avoid adding
  # the same constant twice. `constants_to_print` is a separate list that keeps
  # the constants in the order they are chosen by the heuristics, so they can
  # be printed in that order.
  constants_so_far = set()
  constants_to_print = []

  # User-provided constants.
  for c in benchmark.constants:
    if not _constant_exists(c, constants_so_far):
      constant_value = value_module.ConstantValue(c)
      weight = tf_functions.PROVIDED_CONSTANT_WEIGHT
      _add_value_by_weight(values_by_weight, constant_value, weight)
      constants_so_far.add(c)
      constants_to_print.append(c)

  # Add inputs, while computing some info for extra constants later.
  max_input_tensor_rank = 0
  dimension_lengths = set()
  input_names_to_objects = _input_names_to_objects(benchmark.examples[0].inputs)
  for name, input_object in input_names_to_objects.items():
    input_value = value_module.InputValue(input_object, name)
    if input_value.is_tensor:
      max_input_tensor_rank = max(max_input_tensor_rank, len(input_value.shape))
      dimension_lengths.update(input_value.shape)
    if input_value.is_primitive and constant_operation is not None:
      scalar_tensor_value = constant_operation.apply([input_value], settings)
      _add_value_by_weight(values_by_weight, scalar_tensor_value,
                           tf_functions.PRIMITIVE_INPUT_AS_TENSOR_WEIGHT)

    _add_value_by_weight(values_by_weight, input_value,
                         tf_functions.INPUT_VARIABLE_WEIGHT)
    if input_value.is_primitive:
      # Mark primitive inputs as already seen (but do not print them as
      # constants), presumably so an equal heuristic constant is skipped.
      constants_so_far.add(input_value.value)

    print("Input '{}':\n{!s}\n".format(name, input_value.value))

  if output_value.shape is not None:
    dimension_lengths.update(output_value.shape)

  # Always include these as constants.
  common_constants = [0, 1, -1, True, False]
  # Also include 2, 3, ..., max_example_input_tensor_rank - 1 when applicable.
  axis_constants = list(range(2, max_input_tensor_rank))
  # Also include dimension lengths of input and output tensors.
  shape_constants = sorted(dimension_lengths)

  constant_weight_pairs = (
      [(c, tf_functions.COMMON_CONSTANT_WEIGHT) for c in common_constants] +
      [(c, tf_functions.AXIS_CONSTANT_WEIGHT) for c in axis_constants] +
      [(c, tf_functions.SHAPE_CONSTANT_WEIGHT) for c in shape_constants])

  for constant, weight in constant_weight_pairs:
    if not _constant_exists(constant, constants_so_far):
      constant_value = value_module.ConstantValue(constant)
      _add_value_by_weight(values_by_weight, constant_value, weight)
      constants_so_far.add(constant)
      constants_to_print.append(constant)

  # DTypes for casting.
  for dtype, weight in tf_functions.CONSTANT_DTYPES_AND_WEIGHTS.items():
    dtype_value = value_module.ConstantValue(dtype)
    _add_value_by_weight(values_by_weight, dtype_value, weight)

  if output_value.shape:
    # Add the output shape as a constant.
    shape_tuple = tuple(output_value.shape)
    shape_tuple_value = value_module.ConstantValue(shape_tuple)
    weight = tf_functions.OUTPUT_SHAPE_TUPLE_WEIGHT
    _add_value_by_weight(values_by_weight, shape_tuple_value, weight)
    # Don't add shape_tuple to constants_to_print, because printing it out could
    # be confusing to users.

  # Only for experiments in the PLDI paper.
  if settings.paper_experiments.uniform_weights:
    # Count the number of values.
    num_values = sum(len(values_with_weight)
                     for values_with_weight in values_by_weight)
    # Take all values and put them in the collection for weight 1.
    for weight in range(2, len(values_by_weight)):
      for heavy_value in values_by_weight[weight]:
        values_by_weight[1][heavy_value] = heavy_value
      values_by_weight[weight].clear()
    # Make sure we did it right.
    for weight, values_with_weight in enumerate(values_by_weight):
      assert len(values_with_weight) == (num_values if weight == 1 else 0)

  print('Output:\n{!s}\n'.format(output_value.value))
  print('Constants: {!r}\n'.format(constants_to_print))
  if benchmark.description:
    print('Description: {}\n'.format(benchmark.description))
  print('Searching...\n')
  sys.stdout.flush()  # Flush so the inputs/output appear in Colab immediately.
コード例 #26
0
 def test_init_does_not_convert_to_tensor(self):
   """A ConstantValue wrapping an int stays a primitive, not a tensor."""
   wrapped = value.ConstantValue(1)
   self.assertFalse(wrapped.is_tensor)
   self.assertTrue(wrapped.is_primitive)
コード例 #27
0
 def test_copy(self):
   """copy() returns a distinct object that reconstructs identically."""
   original = value.ConstantValue(1)
   duplicate = original.copy()
   self.assertIsNot(original, duplicate)
   original_expression = original.reconstruct_expression()
   self.assertEqual(original_expression, duplicate.reconstruct_expression())
コード例 #28
0
 def test_reconstruct_expression(self):
   """A string constant reconstructs as a single-quoted Python literal."""
   rendered = value.ConstantValue('TF-Coder').reconstruct_expression()
   self.assertEqual(rendered, "'TF-Coder'")
コード例 #29
0
 def test_init_with_dtype(self):
   """Wrapping tf.int64 produces a ConstantValue flagged as a dtype."""
   self.assertTrue(value.ConstantValue(tf.int64).is_dtype)
コード例 #30
0
 def test_get_dtype_filter(self):
     """The int32 dtype filter accepts tf.constant(1) but not tf.constant(1.0)."""
     accepts_int32 = operation_filtering.get_dtype_filter(tf.int32)
     self.assertTrue(accepts_int32(value.ConstantValue(tf.constant(1))))
     float_constant = value.ConstantValue(tf.constant(1.0))
     self.assertFalse(accepts_int32(float_constant))