def compute_search_space_size(benchmark, settings, description_handler):
    """Computes and prints the size of the search space.

    This counts the total number of expressions with weight at most
    `settings.max_weight`. The weights come from the benchmark (for constants
    and inputs) and the description handler (for determining the op weights).
    Distinct expressions are counted separately even if they evaluate to the
    same value, unlike in TF-Coder's value_search algorithm which does
    value-based pruning.

    The count is a dynamic program over weight. Take an expression of weight W
    whose outermost operation has weight w and arity k. Its k arguments have
    positive weights summing to W - w. For every such split of W - w, the
    operation contributes the product of the argument counts for that split.

    Args:
      benchmark: The Benchmark object defining the problem to analyze.
      settings: A Settings object containing settings for value search.
      description_handler: The DescriptionHandler used, which can modify weights
        of operations.

    Returns:
      Nothing. All output is printed to stdout.
    """
    max_weight = settings.max_weight
    print('Computing search space.\n'
          'Benchmark name: {}\n'
          'Description handler: {}\n'
          'Max weight: {}'.format(benchmark.name, description_handler,
                                  max_weight))

    # TODO(kshi): Update to load the tensor features model/config.
    operations = value_search.get_reweighted_operations(benchmark,
                                                        settings,
                                                        description_handler,
                                                        tensor_model=None,
                                                        tensor_config=None)

    # These loops are not the most efficient, but it doesn't really matter.
    print('\nFound {} operations.'.format(len(operations)))
    print()
    # `default=0` leaves the histogram empty instead of crashing max() when
    # there are no operations.
    max_op_weight = max((op.weight for op in operations), default=0)
    for weight in range(1, max_op_weight + 1):
        print('# operations with weight {}: {}'.format(
            weight, sum(1 for op in operations if op.weight == weight)))
    print()
    max_op_arity = max((op.num_args for op in operations), default=0)
    for arity in range(1, max_op_arity + 1):
        print('# operations with arity {}: {}'.format(
            arity, sum(1 for op in operations if op.num_args == arity)))

    output_value = value_module.OutputValue(benchmark.examples[0].output)
    values_by_weight = [
        collections.OrderedDict() for _ in range(max_weight + 1)
    ]

    # The first operation named CONSTANT_OPERATION_NAME, or None if absent.
    constant_operation = next(
        (operation for operation in operations
         if operation.name == tf_functions.CONSTANT_OPERATION_NAME), None)
    # Seed values_by_weight with the initial values; SuppressPrint hides the
    # helper's own output.
    with SuppressPrint():
        value_search._add_constants_and_inputs_and_print(  # pylint: disable=protected-access
            values_by_weight, benchmark, output_value, constant_operation,
            settings)

    # Starts as the number of seeded values per weight. The loop below extends
    # it in place into the number of all expressions of exactly that weight.
    num_expressions_with_weight = [
        len(values_with_weight) for values_with_weight in values_by_weight
    ]
    print()
    # `default=0` avoids a max() crash when no initial value was seeded.
    max_weight_with_initial_value = max(
        (w for w in range(max_weight + 1) if num_expressions_with_weight[w]),
        default=0)
    for weight in range(1, max_weight_with_initial_value + 1):
        print('# initial values with weight {}: {}'.format(
            weight, num_expressions_with_weight[weight]))

    # Increasing total_weight ensures every argument weight read below
    # (strictly smaller, since op weights are positive) is already final.
    for total_weight in range(2, max_weight + 1):
        for operation in operations:
            # All operations should have strictly positive weight and num_args.
            op_weight = operation.weight
            op_arity = operation.num_args

            if total_weight - op_weight < op_arity:
                continue

            # Partition `total_weight - op_weight` into `op_arity` positive pieces.
            # Equivalently, partition `total_weight - op_weight - op_arity` into
            # `op_arity` nonnegative pieces.
            for partition in tf_coder_utils.generate_partitions(
                    total_weight - op_weight - op_arity, op_arity):
                arg_weights = [part + 1 for part in partition]
                num_expressions_with_weight[total_weight] += functools.reduce(
                    operator.mul,
                    (num_expressions_with_weight[w] for w in arg_weights))

    print()
    for weight in range(1, max_weight + 1):
        print('# expressions with weight exactly {}: {}'.format(
            weight, num_expressions_with_weight[weight]))

    print()
    # Running prefix sum (weight 0 included, as before) instead of re-summing
    # the prefix for every weight.
    cumulative = 0
    for weight, count in enumerate(num_expressions_with_weight):
        cumulative += count
        if weight > 0:
            print('# expressions with weight up to {}: {}'.format(
                weight, cumulative))
 def test_generate_partitions_raises_on_invalid_input(
         self, num_elements, num_parts):
     """Invalid (num_elements, num_parts) arguments must raise ValueError."""
     with self.assertRaises(ValueError):
         partitions = tf_coder_utils.generate_partitions(
             num_elements, num_parts)
         # Consume the result in case the error is only raised lazily.
         list(partitions)
def compute_search_space_size(num_ops, num_nodes, num_leaf_choices):
    """Computes and prints the size of the search space.

    Counts every expression tree that uses exactly a given number of
    operations and a given number of nodes. Expressions that evaluate to the
    same value are still counted separately. TF-Coder's value_search, by
    contrast, prunes by value.

    Args:
      num_ops: The target number of operations.
      num_nodes: A target number of nodes in the expression tree.
      num_leaf_choices: The number of distinct inputs and constants available
        to form the leaves of the expression tree.

    Returns:
      A float numpy array `dp` of shape (num_ops + 1, num_nodes + 1), where
      dp[i][j] is the number of expressions with exactly i ops and j nodes.
    """
    operations = all_operations.get_operations(include_sparse_operations=True)
    max_arity = max(op.num_args for op in operations)
    # arity_counts[k] is how many operations take exactly k arguments.
    arity_counts = [
        sum(1 for op in operations if op.num_args == arity)
        for arity in range(max_arity + 1)
    ]

    print('Found {} operations.'.format(len(operations)))
    for arity, count in enumerate(arity_counts):
        print('# operations with arity {}: {}'.format(arity, count))
    print('\nNum leaf nodes: {}'.format(num_leaf_choices))

    # table[i][j]: number of expressions with exactly i ops and j nodes.
    # Operation nodes and leaves both count as nodes.
    table = np.zeros((num_ops + 1, num_nodes + 1))
    # With zero ops, the only expressions are single leaves.
    table[0][1] = num_leaf_choices

    for ops in range(1, num_ops + 1):
        for nodes in range(1, num_nodes + 1):
            ways = 0
            for arity in range(1, max_arity + 1):
                # The root consumes one op and one node. The remaining ops and
                # nodes are split among its `arity` arguments in every way.
                splits = itertools.product(
                    utils.generate_partitions(num_elements=ops - 1,
                                              num_parts=arity),
                    utils.generate_partitions(num_elements=nodes - 1,
                                              num_parts=arity))
                ways_to_fill_args = sum(
                    np.prod([
                        table[op_split[k]][node_split[k]]
                        for k in range(arity)
                    ]) for op_split, node_split in splits)
                # Any operation of this arity may sit at the root.
                ways += ways_to_fill_args * arity_counts[arity]
            table[ops][nodes] = ways

    print(
        '\nThere are {} expressions using exactly {} ops and {} nodes.'.format(
            table[num_ops][num_nodes], num_ops, num_nodes))
    print('There are {} expressions using at most {} ops and {} nodes.'.format(
        np.sum(table), num_ops, num_nodes))
    return table
 def test_generate_partitions(self, num_elements, num_parts,
                              expected_result):
     """generate_partitions yields exactly expected_result, in any order."""
     self.assertCountEqual(
         list(tf_coder_utils.generate_partitions(num_elements, num_parts)),
         expected_result)
    def enumerate_values_with_weight(
        self,
        target_weight: int,
        values_by_weight: ValuesByWeightDict,
        filter_cache: filtered_values_cache.FilteredValuesCache,
        end_time: float,
        settings: settings_module.Settings,
        statistics: Optional[operation_statistics.OperationStatistics] = None
    ) -> List[value.Value]:
        """Enumerates the values of this operation that have a given weight.

        The operation's own weight is subtracted from `target_weight`. What
        remains is split among the arguments in every possible way, giving
        each argument weight at least 1. For each split, candidate arguments
        are taken from `values_by_weight`. They pass through the argument's
        value filter via `filter_cache`, except in the PLDI paper experiments
        where filtering is skipped. The candidates are then handed to
        `self._enumerate_values`.

        Args:
          target_weight: The desired weight of resulting values.
          values_by_weight: A collection of Values organized by their weight.
          filter_cache: The FilteredValuesCache object used during this search.
          end_time: A timeit.default_timer() cutoff where this should timeout.
          settings: A Settings object storing settings for this search.
          statistics: An optional OperationStatistics object to track
            statistics during this function's execution.

        Returns:
          A list of Value objects of the specified weight. The list is empty
          if the operation takes no arguments, or if `target_weight` is too
          small to give every argument weight at least 1.
        """
        arity = self.num_args
        if arity == 0:
            # A zero-argument operation has no argument weights to vary.
            return []
        # Weight left over once every argument has its minimum weight of 1.
        spare_weight = target_weight - self.weight - arity
        if spare_weight < 0:
            return []

        def gather_arg_options(value_filters, partition):
            """Builds the per-argument candidate lists for one weight split."""
            # Each nonnegative piece of the split, plus 1, is an argument weight.
            arg_weights = [piece + 1 for piece in partition]
            if (settings.paper_experiments.skip_filtering
                    and self.name not in tf_functions.REQUIRES_FILTERING):
                # Only for experiments in the PLDI paper: bypass the filters.
                return [values_by_weight[w] for w in arg_weights]
            return [
                filter_cache.filter_values(value_filters[position], w,
                                           values_by_weight[w])
                for position, w in enumerate(arg_weights)
            ]

        results = []  # type: List[value.Value]
        for value_filters in self._value_filters_list:
            assert len(value_filters) == arity
            # Argument options for every split are gathered before any values
            # are enumerated from them.
            arg_options_list = [
                gather_arg_options(value_filters, partition)
                for partition in tf_coder_utils.generate_partitions(
                    spare_weight, arity)
            ]  # type: List[ArgOptionsType]
            for arg_options in arg_options_list:
                results.extend(
                    self._enumerate_values(arg_options, end_time, settings,
                                           statistics))
        return results