Ejemplo n.º 1
0
    def test_enumerate_values_with_weight(self):
        """Enumerates values of total weight 9 and checks results and stats."""
        # Position i of the pool holds the candidate values of weight i.
        pool = [
            [],
            [_value(n) for n in (1, 4, 9, 15)],
            [_value(n) for n in (2, 6, 20, 60)],
            [_value(n) for n in (10, 12)],
        ]
        stats = operation_statistics.OperationStatistics()
        cache = filtered_values_cache.FilteredValuesCache()

        produced = self.operation.enumerate_values_with_weight(
            9, pool, cache,
            end_time=float('inf'),
            settings=self.settings,
            statistics=stats)

        # The operation itself has weight 5, so a total weight of 9 is reached
        # with argument weights (1, 3), (2, 2) or (3, 1). For each weight
        # split, two filter combinations apply:
        #   (A) X divisible by 2, Y divisible by 3, and X < Y.
        #   (B) X divisible by 4, Y divisible by 5, and X < Y.
        expected_sums = (
            # Weight 1 + weight 3.
            16,  # (A) 4 + 12.
            14,  # (B) 4 + 10.
            # Weight 2 + weight 2.
            8,   # (A) 2 + 6.
            62,  # (A) 2 + 60.
            66,  # (A) 6 + 60.
            80,  # (A) 20 + 60.
            80,  # (B) 20 + 60.
            # Weight 3 + weight 1.
            25,  # (A) 10 + 15.
            27,  # (A) 12 + 15.
            27,  # (B) 12 + 15.
        )
        wanted = [_value(total) for total in expected_sums]
        self.assertCountEqual(produced, wanted)

        # One application per expected result, all of them successful.
        self.assertEqual(stats.total_apply_count, 10)
        self.assertEqual(stats.operation_apply_successes,
                         {'strange_addition': 10})
 def test_get_total_time(self):
     """get_total_time() reports the time summed over all update() calls."""
     stats = operation_statistics.OperationStatistics()
     # (name, count, successes, seconds) for each recorded update.
     recorded = (
         ('a', 10, 1, 1.5),
         ('b', 80, 60, 40.0),
         ('a', 100, 10, 8.5),
         ('c', 0, 0, 0.0),
     )
     for name, count, successes, seconds in recorded:
         stats.update(name, count=count, successes=successes, time=seconds)
     # 1.5 + 40.0 + 8.5 + 0.0
     self.assertEqual(stats.get_total_time(), 50.0)
 def test_statistics_as_string_sorted_by_time(self):
     """Checks row order in the report with and without sort_by_time."""
     stats = operation_statistics.OperationStatistics()
     for name, seconds in (('fast_op', 1.5), ('slow_op', 2.5)):
         stats.update(name, count=10, successes=1, time=seconds)

     # Each case: (sort_by_time flag, name expected first, name expected second).
     cases = (
         (False, 'fast_op', 'slow_op'),
         (True, 'slow_op', 'fast_op'),
     )
     for flag, earlier, later in cases:
         report = stats.statistics_as_string(sort_by_time=flag)
         self.assertLess(report.index(earlier), report.index(later))
    def test_statistics_as_string(self):
        """Checks per-operation rows and summary lines of the report."""
        stats = operation_statistics.OperationStatistics()
        # (name, count, successes, seconds) for each recorded update.
        recorded = (
            ('a', 10, 1, 1.5),
            ('b', 80, 60, 40.0),
            ('a', 100, 10, 8.5),
            ('c', 0, 0, 0.0),
            ('do_not_output', 100, 10, 1.23),
        )
        for name, count, successes, seconds in recorded:
            stats.update(name, count=count, successes=successes, time=seconds)

        report = stats.statistics_as_string(
            operation_names=['a', 'b', 'c'],
            num_unique_values=77,
            elapsed_time=100.0)

        # Operation 'c' has zero executions, so its rates are expected as NaN.
        nan = float('NaN')
        expected_rows = (
            dict(name='a', eps=11.0, sps=1.1, executions=110, successes=11,
                 rate=0.1, time=10.0, time_frac=0.2),
            dict(name='b', eps=2.0, sps=1.5, executions=80, successes=60,
                 rate=0.75, time=40.0, time_frac=0.8),
            dict(name='c', eps=nan, sps=nan, executions=0, successes=0,
                 rate=nan, time=0.0, time_frac=0.0),
        )
        for fields in expected_rows:
            row_text = operation_statistics._ROW_FORMAT_STR.format(**fields)
            self.assertIn(row_text, report)

        # An operation not listed in operation_names must not get a row.
        self.assertNotIn('do_not_output', report)

        # The summary lines count every update, including 'do_not_output'.
        for summary_line in (
                'Number of evaluations: 290',
                'Number of successful evaluations: 81',
                'Total time applying operations: 51.23 sec',
                'Number of unique values: 77',
                'Executions per second: 2.9'):
            self.assertIn(summary_line, report)
 def test_update(self):
     """update() accumulates overall totals and per-operation counters."""
     stats = operation_statistics.OperationStatistics()
     # (name, count, successes, seconds); 'a' is updated twice on purpose.
     recorded = (
         ('a', 10, 1, 1.5),
         ('b', 80, 60, 40.0),
         ('a', 100, 10, 8.5),
         ('c', 0, 0, 0.0),
     )
     for name, count, successes, seconds in recorded:
         stats.update(name, count=count, successes=successes, time=seconds)

     # Totals across all operations.
     self.assertEqual(stats.total_apply_count, 190)
     self.assertEqual(stats.total_apply_successes, 71)

     # Per-operation aggregates, keyed by operation name.
     self.assertEqual(stats.operation_apply_time,
                      dict(a=10.0, b=40.0, c=0.0))
     self.assertEqual(stats.operation_apply_count,
                      dict(a=110, b=80, c=0))
     self.assertEqual(stats.operation_apply_successes,
                      dict(a=11, b=60, c=0))
     self.assertEqual(stats.all_operation_names, {'a', 'b', 'c'})
Ejemplo n.º 6
0
def _find_solutions(
    benchmark: benchmark_module.Benchmark,
    operations: List[operation_base.Operation],
    start_time: float,
    settings: settings_module.Settings
) -> Tuple[List[Solution], Set[value_module.Value], ValuesByWeight,
           Optional[operation_statistics.OperationStatistics]]:
  """Helper, returning (solutions, value_set, values_by_weight, statistics).

  Searches for values weight by weight, from 1 up to `settings.max_weight`.
  For each weight, every operation enumerates values of exactly that weight.
  An enumerated value that is not already in `value_set` and equals the
  desired output is passed to `_record_solutions`; any other unseen value is
  stored for use at larger weights. After the operations have run for a
  weight, the values stored for that weight whose shape matches the output but
  whose dtype differs are additionally tried through the cast operation (when
  one is present in `operations`).

  Args:
    benchmark: The benchmark to solve. The target output is taken from
      `benchmark.examples[0].output`.
    operations: The operations to enumerate with. Operations named
      `tf_functions.CAST_OPERATION_NAME` and
      `tf_functions.CONSTANT_OPERATION_NAME` are also looked up and cached.
    start_time: The search start time. The deadline is
      `start_time + settings.timeout` and is compared against
      `timeit.default_timer()`, so this should come from the same clock.
    settings: Search settings. This function reads `timeout`,
      `only_minimal_solutions`, `max_solutions`, `max_weight`,
      `max_extra_solutions_time` and `printing.{statistics,progress,verbose}`,
      and forwards the object to the helpers and operations it calls.

  Returns:
    A tuple `(solutions, value_set, values_by_weight, statistics)`:
    the solutions recorded so far, the set of all stored values, the
    per-weight collections of values, and the statistics object, which is None
    unless `settings.printing.statistics` is truthy.
  """
  timeout_reached = False
  # Deadline for the search; it may be tightened once a solution is found.
  end_time = start_time + settings.timeout

  only_minimal_solutions = settings.only_minimal_solutions
  if settings.max_solutions == 1:
    # If we only want one solution, it will be minimal.
    only_minimal_solutions = True

  # An object to track statistics, if requested.
  statistics = (operation_statistics.OperationStatistics()
                if settings.printing.statistics
                else None)

  # A list of Solution namedtuples.
  solutions = []

  # A set of string solution expressions (don't return duplicate solutions).
  solution_expression_set = set()

  # The output value to search for.
  output_value = value_module.OutputValue(benchmark.examples[0].output)

  # A list of OrderedDicts mapping Value objects to themselves. The i-th
  # OrderedDict contains all Value objects of weight i.
  values_by_weight = [collections.OrderedDict()
                      for _ in range(settings.max_weight + 1)]

  # Find and cache the cast and constant operations for use later.
  # Either stays None if no operation with the corresponding name is present.
  cast_operation = None
  constant_operation = None
  for operation in operations:
    if operation.name == tf_functions.CAST_OPERATION_NAME:
      cast_operation = operation
    elif operation.name == tf_functions.CONSTANT_OPERATION_NAME:
      constant_operation = operation
  # Create the output dtype value for use later.
  dtype_value = value_module.ConstantValue(output_value.dtype)

  # Populate values_by_weight with inputs and constants. This also prints
  # inputs/output/constants to stdout.
  _add_constants_and_inputs_and_print(
      values_by_weight, benchmark, output_value, constant_operation, settings)

  # A set storing all values found so far.
  # Iterating an OrderedDict yields its keys, so this unions the Value keys of
  # every weight bucket populated above.
  value_set = set().union(*values_by_weight)

  filter_cache = filtered_values_cache.FilteredValuesCache()

  # Value search by weight.
  for weight in range(1, settings.max_weight + 1):
    if settings.printing.progress:
      print('Searching weight {}...'.format(weight))

    # Values with the current weight. This might already include leaf values.
    # This is the same dict object stored in values_by_weight, so additions
    # below are visible through values_by_weight as well.
    new_values = values_by_weight[weight]

    for operation in operations:
      for value in operation.enumerate_values_with_weight(
          target_weight=weight,
          values_by_weight=values_by_weight,
          filter_cache=filter_cache,
          end_time=end_time,
          settings=settings,
          statistics=statistics):

        if value not in value_set:
          # This value has never been seen before, or it's the desired output.
          if settings.printing.verbose:
            expression = value.reconstruct_expression()
            print('{} produces:\n{}'.format(expression, value))

          if value == output_value:
            # Remember whether nothing had been recorded before this call, so
            # the deadline is tightened only when the first solution appears.
            possible_first_solution = not solutions
            # Found solution(s), but some may be bad.
            _record_solutions(value, weight, start_time, solutions,
                              solution_expression_set, benchmark, settings)
            if possible_first_solution and solutions:
              # First solution recorded: allow at most
              # max_extra_solutions_time more from now, never extending the
              # existing deadline.
              end_time = min(
                  end_time,
                  timeit.default_timer() + settings.max_extra_solutions_time)
            if len(solutions) >= settings.max_solutions:
              return solutions, value_set, values_by_weight, statistics
          else:
            # Only store the value if it isn't a solution. Otherwise, we'll get
            # lots of "almost duplicate" solutions, e.g., by adding 0.
            new_values[value] = value
            # We should never add output_value (or anything equal) to value_set
            # so that we can continue finding other solutions.
            value_set.add(value)
        else:  # This value has been seen before.
          if value in new_values:
            # The value was already computed differently with this weight.
            original_value = new_values[value]
            if isinstance(original_value, value_module.OperationValue):
              # Only merge reconstructions if this was originally an
              # OperationValue. (It could be a ConstantValue instead.)
              operation_value = original_value   # type: value_module.OperationValue
              operation_value.merge_reconstructions(value)
          elif not only_minimal_solutions:
            # If we want non-minimal solutions, we need to store the value even
            # if we have already seen that value with a smaller weight.
            new_values[value] = value

      # The deadline is checked once per operation, after that operation's
      # enumeration for this weight has finished.
      if timeit.default_timer() > end_time:
        timeout_reached = True
        # Don't return immediately; still try to cast new values because this is
        # relatively quick.
        break

    # Try casting new values to the output dtype if this has a chance of being
    # a correct solution.
    for new_value in new_values:
      if (cast_operation is not None and
          new_value.shape == output_value.shape and
          new_value.dtype != output_value.dtype and
          operation_filtering.is_castable(new_value, dtype_value)):
        casted_value = cast_operation.apply([new_value, dtype_value], settings)
        if casted_value == output_value:
          # Same first-solution bookkeeping as in the enumeration loop above.
          possible_first_solution = not solutions
          # Found solution(s), but some may be bad.
          _record_solutions(casted_value, weight, start_time, solutions,
                            solution_expression_set, benchmark, settings)
          if possible_first_solution and solutions:
            end_time = min(
                end_time,
                timeit.default_timer() + settings.max_extra_solutions_time)
          if len(solutions) >= settings.max_solutions:
            return solutions, value_set, values_by_weight, statistics

    if settings.printing.progress:
      print('Found {} distinct values of weight {}, or {} total.'.format(
          len(new_values), weight, len(value_set)))
    # When only minimal solutions are wanted, stop after the first weight that
    # produced any solution instead of searching larger weights.
    if only_minimal_solutions and solutions:
      return solutions, value_set, values_by_weight, statistics
    if timeout_reached:
      break

  return solutions, value_set, values_by_weight, statistics