Exemplo n.º 1
0
  def __init__(
      self,
      output_boundary: List[tf.Operation],
      gamma_threshold,
      hardware,
      batch_size=1,
      regularizer_decorator: Type[generic_regularizers.OpRegularizer] = None,
      decorator_parameters=None,
      input_boundary: List[tf.Operation] = None,
      force_group=None,
      regularizer_blacklist=None) -> None:
    """Creates a GammaLatencyRegularizer object.

    Latency cost and regularization loss are calculated for a specified
    hardware platform.

    Args:
      output_boundary: An OpRegularizer will be created for all these
        operations, and recursively for all ops they depend on via data
        dependency that does not involve ops from input_boundary.
      gamma_threshold: A float scalar, will be used as a 'gamma_threshold' for
        all instances of GammaL1Regularizer created by this class.
      hardware: String name of hardware platform to target.  Must be a key from
        resource_function.PEAK_COMPUTE.
      batch_size: Integer batch size to calculate cost/loss for.
      regularizer_decorator: A class of OpRegularizer decorator to use, or None
        to leave the batch-norm source handler undecorated.
      decorator_parameters: A dictionary of parameters to pass to the decorator
        factory. To be used only with decorators that require parameters,
        otherwise use None.
      input_boundary: A list of ops that represent the input boundary of the
        subgraph being regularized (input boundary is not regularized).
      force_group: List of regex for ops that should be force-grouped.  Each
        regex corresponds to a separate group.  Use '|' operator to specify
        multiple patterns in a single regex. See op_regularizer_manager for
        more detail.
      regularizer_blacklist: List of regex for ops that should not be
        regularized. See op_regularizer_manager for more detail.
    """
    # Batch-norm gammas are the regularization source; optionally wrap the
    # source handler with the requested decorator.
    source_op_handler = batch_norm_source_op_handler.BatchNormSourceOpHandler(
        gamma_threshold)
    if regularizer_decorator:
      source_op_handler = op_handler_decorator.OpHandlerDecorator(
          source_op_handler, regularizer_decorator,
          decorator_parameters)
    op_handler_dict = op_handlers.get_gamma_op_handler_dict()
    # All fused batch-norm variants share one source handler, so whichever
    # variant is present in the graph gets regularized.
    op_handler_dict.update({
        'FusedBatchNorm': source_op_handler,
        'FusedBatchNormV2': source_op_handler,
        'FusedBatchNormV3': source_op_handler,
    })

    self._manager = orm.OpRegularizerManager(
        output_boundary, op_handler_dict, input_boundary=input_boundary,
        force_group=force_group, regularizer_blacklist=regularizer_blacklist)
    # Cost is latency on the requested hardware at the given batch size.
    self._calculator = cost_calculator.CostCalculator(
        self._manager,
        resource_function.latency_function_factory(hardware, batch_size))
    # Remember the target hardware platform.
    self._hardware = hardware
Exemplo n.º 2
0
  def __init__(self,
               output_boundary: List[tf.Operation],
               threshold,
               l1_fraction=0.0,
               regularizer_decorator: Optional[Type[
                   generic_regularizers.OpRegularizer]] = None,
               decorator_parameters=None,
               input_boundary: Optional[List[tf.Operation]] = None,
               force_group: Optional[List[Text]] = None,
               regularizer_blacklist: Optional[List[Text]] = None):
    """Creates a GroupLassoModelSizeRegularizer object.

    The regularizer applies group lasso to the weights of convolution,
    transposed convolution and matmul ops, and measures cost as model size.

    Args:
      output_boundary: An OpRegularizer will be created for all these
        operations, and recursively for all ops they depend on via data
        dependency that does not involve ops from input_boundary.
      threshold: A float scalar, used as the 'threshold' of every source op
        handler created here.
      l1_fraction: A float scalar.  The relative weight of L1 in L1 + L2
        regularization.
      regularizer_decorator: A class of OpRegularizer decorator to use, or None
        to leave the source handlers undecorated.
      decorator_parameters: A dictionary of parameters to pass to the decorator
        factory. Only needed by decorators that take parameters; otherwise
        None.
      input_boundary: A list of ops that represent the input boundary of the
        subgraph being regularized (input boundary is not regularized).
      force_group: List of regex for ops that should be force-grouped.  Each
        regex corresponds to a separate group.  Use '|' operator to specify
        multiple patterns in a single regex. See op_regularizer_manager for
        more detail.
      regularizer_blacklist: List of regex for ops that should not be
        regularized. See op_regularizer_manager for more detail.
    """
    # One fresh source handler per weight-carrying op type.
    source_handlers = {
        'Conv2D': conv_handler.ConvSourceOpHandler(threshold, l1_fraction),
        'Conv3D': conv_handler.ConvSourceOpHandler(threshold, l1_fraction),
        'Conv2DBackpropInput': (
            conv2d_transpose_handler.Conv2DTransposeSourceOpHandler(
                threshold, l1_fraction)),
        'MatMul': matmul_handler.MatMulSourceOpHandler(threshold, l1_fraction),
    }
    if regularizer_decorator:
      source_handlers = {
          op_type: op_handler_decorator.OpHandlerDecorator(
              handler, regularizer_decorator, decorator_parameters)
          for op_type, handler in source_handlers.items()
      }

    # Start from the default group-lasso handlers and override the sources.
    handlers = op_handlers.get_group_lasso_op_handler_dict()
    handlers.update(source_handlers)

    self._manager = orm.OpRegularizerManager(
        output_boundary,
        handlers,
        input_boundary=input_boundary,
        force_group=force_group,
        regularizer_blacklist=regularizer_blacklist)
    self._calculator = cost_calculator.CostCalculator(
        self._manager, resource_function.model_size_function)
Exemplo n.º 3
0
  def testImageIsNotZerothOutputOfOp(self):
    # The framework generally treats output 0 of an op as the only output of
    # interest. A common exception is an input image that comes out of a
    # queue or staging op, where the image may be any output of the dequeue.
    # This checks that the regularizer machinery handles that case.

    # Build a queue whose dequeue op yields the image as output 1, not 0.
    # TODO(g1) Move this mechanism to add_concat_model_stub, possibly using
    # tf.split to produce an op where the image is not the 0th output image
    # (instead of FIFOQueue).
    image = add_concat_model_stub.image_stub()
    filler = tf.zeros(shape=(41,))
    fifo = tf.FIFOQueue(
        capacity=1,
        dtypes=(tf.float32, tf.float32),
        shapes=(filler.shape, image.shape))

    # The network consumes dequeue output[1], i.e. the image.
    with arg_scope(self._batch_norm_scope()):
      dequeued = fifo.dequeue()
      output_op = add_concat_model_stub.build_model(dequeued[1])

    # Handlers for this test; unknown op types fall back to grouping.
    handlers = collections.defaultdict(
        grouping_op_handler.GroupingOpHandler,
        {
            'FusedBatchNorm':
                batch_norm_source_op_handler.BatchNormSourceOpHandler(0.1),
            'Conv2D':
                output_non_passthrough_op_handler
                .OutputNonPassthroughOpHandler(),
            'ConcatV2':
                concat_op_handler.ConcatOpHandler(),
        })

    manager = orm.OpRegularizerManager([output_op], handlers)
    calculator = cost_calculator.CostCalculator(
        manager, resource_function.flop_function)

    # Expected FLOPs of conv1: coefficient * input channels * alive outputs.
    alive_conv1 = sum(add_concat_model_stub.expected_alive()['conv1'])
    conv1_op = tf.get_default_graph().get_operation_by_name('conv1/Conv2D')
    conv1_coeff = resource_function.flop_coeff(conv1_op)
    input_channels = 3
    expected_cost = conv1_coeff * input_channels * alive_conv1

    with self.session():
      tf.global_variables_initializer().run()
      # Pin gammas so channel aliveness matches add_concat_model_stub.
      variables_by_name = {var.op.name: var for var in tf.global_variables()}
      gamma_settings = (
          ('conv1/BatchNorm/gamma', [0, 1, 1, 0, 1, 0, 1]),
          ('conv4/BatchNorm/gamma', [0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0]),
      )
      for gamma_name, gamma_value in gamma_settings:
        variables_by_name[gamma_name].assign(gamma_value).eval()

      fifo.enqueue((filler, image)).run()
      actual_cost = calculator.get_cost([conv1_op]).eval()
      self.assertEqual(expected_cost, actual_cost)
Exemplo n.º 4
0
  def __init__(
      self,
      ops,
      threshold,
      l1_fraction=0,
      regularizer_decorator: Type[generic_regularizers.OpRegularizer] = None,
      decorator_parameters=None,
      force_group=None,
      regularizer_blacklist=None):
    """Creates a GroupLassoActivationRegularizer object.

    Group lasso is applied to Conv2D, Conv2DBackpropInput and MatMul weights,
    and cost is measured as activation count.

    Args:
      ops: A list of tf.Operation. An OpRegularizer will be created for all the
        ops in `ops`, and recursively for all ops they depend on via data
        dependency. Typically `ops` would contain a single tf.Operation, which
        is the output of the network.
      threshold: A float scalar, used as the 'threshold' of every source op
        handler created here.
      l1_fraction: Relative weight of L1 in L1 + L2 regularization.
      regularizer_decorator: A class of OpRegularizer decorator to use, or None
        to leave the source handlers undecorated.
      decorator_parameters: A dictionary of parameters to pass to the decorator
        factory. Only needed by decorators that take parameters; otherwise
        None.
      force_group: List of regex for ops that should be force-grouped.  Each
        regex corresponds to a separate group.  Use '|' operator to specify
        multiple patterns in a single regex. See op_regularizer_manager for
        more detail.
      regularizer_blacklist: List of regex for ops that should not be
        regularized. See op_regularizer_manager for more detail.
    """
    source_handlers = {
        'Conv2D':
            conv2d_source_op_handler.Conv2DSourceOpHandler(
                threshold, l1_fraction),
        'Conv2DBackpropInput':
            conv2d_transpose_source_op_handler.Conv2DTransposeSourceOpHandler(
                threshold, l1_fraction),
        'MatMul':
            matmul_source_op_handler.MatMulSourceOpHandler(
                threshold, l1_fraction),
    }
    if regularizer_decorator:
      source_handlers = {
          op_type: op_handler_decorator.OpHandlerDecorator(
              handler, regularizer_decorator, decorator_parameters)
          for op_type, handler in source_handlers.items()
      }

    # Default group-lasso handlers, with the source op types overridden.
    handlers = op_handlers.get_group_lasso_op_handler_dict()
    handlers.update(source_handlers)

    self._manager = orm.OpRegularizerManager(
        ops,
        handlers,
        force_group=force_group,
        regularizer_blacklist=regularizer_blacklist)
    self._calculator = cost_calculator.CostCalculator(
        self._manager, resource_function.activation_count_function)
    def __init__(self,
                 output_boundary: List[tf.Operation],
                 regularize_on_mask=True,
                 alive_threshold=0.1,
                 mask_as_alive_vector=True,
                 regularizer_decorator: Type[
                     generic_regularizers.OpRegularizer] = None,
                 decorator_parameters=None,
                 input_boundary: List[tf.Operation] = None,
                 force_group=None,
                 regularizer_blacklist=None):
        """Creates a LogisticSigmoidFlopsRegularizer object.

        Args:
          output_boundary: An OpRegularizer will be created for all these
            operations, and recursively for all ops they depend on via data
            dependency that does not involve ops from input_boundary.
          regularize_on_mask: Bool. If True uses the binary mask as the
            regularization vector. Else uses the probability vector.
          alive_threshold: Float. Threshold below which values are considered
            dead. When mask_as_alive_vector is True it binarizes the sampled
            values; when it is False it thresholds the channel probability.
          mask_as_alive_vector: Bool. If True use the thresholded sampled mask
            as the alive vector. Else, use thresholded probabilities from the
            logits.
          regularizer_decorator: A class of OpRegularizer decorator to use, or
            None to leave the gating source handler undecorated.
          decorator_parameters: A dictionary of parameters to pass to the
            decorator factory. Only needed by decorators that take parameters;
            otherwise None.
          input_boundary: A list of ops that represent the input boundary of
            the subgraph being regularized (input boundary is not regularized).
          force_group: List of regex for ops that should be force-grouped.
            Each regex corresponds to a separate group.  Use '|' operator to
            specify multiple patterns in a single regex. See
            op_regularizer_manager for more detail.
          regularizer_blacklist: List of regex for ops that should not be
            regularized. See op_regularizer_manager for more detail.
        """
        gating_handler = ls_handler.LogisticSigmoidSourceOpHandler(
            regularize_on_mask, alive_threshold, mask_as_alive_vector)
        if regularizer_decorator:
            gating_handler = op_handler_decorator.OpHandlerDecorator(
                gating_handler, regularizer_decorator, decorator_parameters)

        # Gamma defaults, with LogisticSigmoidGating ops as the sources.
        handler_by_op_type = op_handlers.get_gamma_op_handler_dict()
        handler_by_op_type['LogisticSigmoidGating'] = gating_handler

        self._manager = orm.OpRegularizerManager(
            output_boundary,
            handler_by_op_type,
            create_grouping_regularizer=pgr.ProbabilisticGroupingRegularizer,
            input_boundary=input_boundary,
            force_group=force_group,
            regularizer_blacklist=regularizer_blacklist)
        self._calculator = self.get_calculator()
Exemplo n.º 6
0
 def __init__(self, ops, gamma_threshold):
     """Builds a gamma-L1 FLOP regularizer over `ops`.

     Args:
       ops: Ops handed to OpRegularizerManager as the roots of the
         regularized graph.
       gamma_threshold: Threshold passed to GammaL1RegularizerFactory.
     """
     factory = gamma_l1_regularizer.GammaL1RegularizerFactory(gamma_threshold)
     # Both regular and depthwise convolutions use the gamma-L1 factory.
     regularizer_factories = {
         'Conv2D': factory.create_regularizer,
         'DepthwiseConv2dNative': factory.create_regularizer,
     }
     manager = op_regularizer_manager.OpRegularizerManager(
         ops, regularizer_factories)
     super(GammaFlopsRegularizer, self).__init__(
         manager, bilinear_cost_utils.flop_coeff)
    def __init__(self,
                 ops,
                 gamma_threshold,
                 regularizer_decorator: Type[
                     generic_regularizers.OpRegularizer] = None,
                 decorator_parameters=None,
                 input_boundary=None,
                 force_group=None,
                 regularizer_blacklist=None):
        """Creates a GammaModelSizeRegularizer object.

        Args:
          ops: A list of tf.Operation. An OpRegularizer will be created for
            all the ops in `ops`, and recursively for all ops they depend on
            via data dependency. Typically `ops` would contain a single
            tf.Operation, which is the output of the network.
          gamma_threshold: A float scalar, will be used as a 'gamma_threshold'
            for all instances of GammaL1Regularizer created by this class.
          regularizer_decorator: A class of OpRegularizer decorator to use, or
            None to leave the batch-norm source handler undecorated.
          decorator_parameters: A dictionary of parameters to pass to the
            decorator factory. To be used only with decorators that require
            parameters, otherwise use None.
          input_boundary: A list of ops that represent the input boundary of
            the subgraph being regularized (input boundary is not regularized).
          force_group: List of regex for ops that should be force-grouped.
            Each regex corresponds to a separate group.  Use '|' operator to
            specify multiple patterns in a single regex. See
            op_regularizer_manager for more detail.
          regularizer_blacklist: List of regex for ops that should not be
            regularized. See op_regularizer_manager for more detail.
        """
        source_op_handler = batch_norm_source_op_handler.BatchNormSourceOpHandler(
            gamma_threshold)
        if regularizer_decorator:
            source_op_handler = op_handler_decorator.OpHandlerDecorator(
                source_op_handler, regularizer_decorator, decorator_parameters)
        op_handler_dict = op_handlers.get_gamma_op_handler_dict()
        # Every fused batch-norm variant (including V3) shares the source
        # handler, so whichever variant is in the graph gets regularized.
        op_handler_dict.update({
            'FusedBatchNorm': source_op_handler,
            'FusedBatchNormV2': source_op_handler,
            'FusedBatchNormV3': source_op_handler,
        })

        self._manager = orm.OpRegularizerManager(
            ops,
            op_handler_dict,
            input_boundary=input_boundary,
            force_group=force_group,
            regularizer_blacklist=regularizer_blacklist)
        self._calculator = cost_calculator.CostCalculator(
            self._manager, resource_function.model_size_function)
Exemplo n.º 8
0
 def __init__(self, ops, threshold):
     """Builds a group-lasso FLOP regularizer over `ops`.

     Args:
       ops: Ops handed to OpRegularizerManager as the roots of the
         regularized graph.
       threshold: Threshold passed to ConvGroupLassoRegularizerFactory.
     """
     factory = conv_group_lasso_regularizer.ConvGroupLassoRegularizerFactory(
         threshold)
     # Regular and transposed 2-D convolutions share the same factory.
     factories_by_op_type = {
         'Conv2D': factory.create_regularizer,
         'Conv2DBackpropInput': factory.create_regularizer,
     }
     manager = op_regularizer_manager.OpRegularizerManager(
         ops, factories_by_op_type)
     super(GroupLassoFlopsRegularizer, self).__init__(
         manager, bilinear_cost_utils.flop_coeff)
    def testGetRegularizerForConcatWithNone(self, test_concat, depth):
        # A concat whose first input has no regularizer: those `depth`
        # channels must report alive, the conv2 channels follow the stub.
        image = tf.constant(0.0, shape=[1, 17, 19, 3])
        conv2 = layers.conv2d(image, 5, [1, 1], padding='SAME', scope='conv2')
        unregularized = tf.add(
            tf.identity(tf.constant(3.0, shape=[1, 17, 19, depth])), 3.0)
        concat = tf.concat([unregularized, conv2], 3)
        output = tf.add(concat, concat, name='output_out')
        if test_concat:
            target_op = concat.op
        else:
            target_op = output.op
        manager = orm.OpRegularizerManager(
            [output.op], op_regularizer_stub.MOCK_REG_DICT)
        expected_alive = op_regularizer_stub.expected_alive()

        with self.test_session():
            alive = manager.get_regularizer(target_op).alive_vector.eval()
            unregularized_part = alive[:depth]
            conv2_part = alive[depth:]
            self.assertAllEqual([True] * depth, unregularized_part)
            self.assertAllEqual(expected_alive['conv2'], conv2_part)
Exemplo n.º 10
0
    def __init__(self,
                 ops,
                 gamma_threshold,
                 regularizer_decorator: Type[
                     generic_regularizers.OpRegularizer] = None,
                 decorator_parameters=None,
                 force_group=None,
                 regularizer_blacklist=None):
        """Creates a GammaFlopsRegularizer object.

        Args:
          ops: A list of tf.Operation. An OpRegularizer will be created for all
            the ops in `ops`, and recursively for all ops they depend on via
            data dependency. Typically `ops` would contain a single
            tf.Operation, which is the output of the network.
          gamma_threshold: A float scalar, will be used as a 'gamma_threshold'
            for all instances of GammaL1Regularizer created by this class.
          regularizer_decorator: A class of OpRegularizer decorator to use, or
            None to leave the batch-norm source handler undecorated.
          decorator_parameters: A dictionary of parameters to pass to the
            decorator factory. To be used only with decorators that require
            parameters, otherwise use None.
          force_group: List of regex for ops that should be force-grouped.
            Each regex corresponds to a separate group.  Use '|' operator to
            specify multiple patterns in a single regex. See
            op_regularizer_manager for more detail.
          regularizer_blacklist: List of regex for ops that should not be
            regularized. See op_regularizer_manager for more detail.
        """
        source_op_handler = batch_norm_source_op_handler.BatchNormSourceOpHandler(
            gamma_threshold)
        if regularizer_decorator:
            source_op_handler = op_handler_decorator.OpHandlerDecorator(
                source_op_handler, regularizer_decorator, decorator_parameters)
        # Op types not listed below fall back to GroupingOpHandler.
        op_handler_dict = collections.defaultdict(
            grouping_op_handler.GroupingOpHandler)
        op_handler_dict.update({
            # Every fused batch-norm variant (including V3) is a source, so
            # whichever variant is in the graph gets regularized.
            'FusedBatchNorm':
            source_op_handler,
            'FusedBatchNormV2':
            source_op_handler,
            'FusedBatchNormV3':
            source_op_handler,
            'Conv2D':
            output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(),
            'ConcatV2':
            concat_op_handler.ConcatOpHandler(),
            'DepthToSpace':
            output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(),
            'DepthwiseConv2dNative':
            depthwise_convolution_op_handler.DepthwiseConvolutionOpHandler(),
            'MatMul':
            output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(),
            'TensorArrayGatherV3':
            leaf_op_handler.LeafOpHandler(),
            'RandomUniform':
            leaf_op_handler.LeafOpHandler(),
            'Reshape':
            leaf_op_handler.LeafOpHandler(),
            'Transpose':
            output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(),
            'ExpandDims':
            output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(),
        })

        self._manager = orm.OpRegularizerManager(
            ops,
            op_handler_dict,
            force_group=force_group,
            regularizer_blacklist=regularizer_blacklist)
        self._calculator = cost_calculator.CostCalculator(
            self._manager, resource_function.flop_function)
Exemplo n.º 11
0
    def __init__(self,
                 ops,
                 threshold,
                 l1_fraction=0,
                 regularizer_decorator: Type[
                     generic_regularizers.OpRegularizer] = None,
                 decorator_parameters=None,
                 force_group=None,
                 regularizer_blacklist=None,
                 convert_to_variable=True):
        """Creates a GroupLassoFlopsRegularizer object.

        Args:
          ops: A list of tf.Operation. An OpRegularizer will be created for all
            the ops in `ops`, and recursively for all ops they depend on via
            data dependency. Typically `ops` would contain a single
            tf.Operation, which is the output of the network.
          threshold: A float scalar, will be used as a 'threshold' for all
            regularizer instances created by this class.
          l1_fraction: Relative weight of L1 in L1 + L2 regularization.
          regularizer_decorator: A class of OpRegularizer decorator to use, or
            None to leave the source handlers undecorated.
          decorator_parameters: A dictionary of parameters to pass to the
            decorator factory. To be used only with decorators that require
            parameters, otherwise use None.
          force_group: List of regex for ops that should be force-grouped.
            Each regex corresponds to a separate group.  Use '|' operator to
            specify multiple patterns in a single regex. See
            op_regularizer_manager for more detail.
          regularizer_blacklist: List of regex for ops that should not be
            regularized. See op_regularizer_manager for more detail.
          convert_to_variable: If `True` convert to variable in the
            `GroupLassoBaseOpHandler`. If your graph creates variables outside
            of `tf.get_variable`, set to `False`.
        """
        # Source handlers for the weight-carrying op types.
        conv2d_handler = conv2d_source_op_handler.Conv2DSourceOpHandler(
            threshold, l1_fraction, convert_to_variable)
        conv2d_transpose_handler = (
            conv2d_transpose_source_op_handler.Conv2DTransposeSourceOpHandler(
                threshold, l1_fraction, convert_to_variable))
        matmul_handler = matmul_source_op_handler.MatMulSourceOpHandler(
            threshold, l1_fraction, convert_to_variable)
        if regularizer_decorator:
            conv2d_handler = op_handler_decorator.OpHandlerDecorator(
                conv2d_handler, regularizer_decorator, decorator_parameters)
            conv2d_transpose_handler = op_handler_decorator.OpHandlerDecorator(
                conv2d_transpose_handler, regularizer_decorator,
                decorator_parameters)
            matmul_handler = op_handler_decorator.OpHandlerDecorator(
                matmul_handler, regularizer_decorator, decorator_parameters)

        # Op types not listed below fall back to GroupingOpHandler.
        op_handler_dict = collections.defaultdict(
            grouping_op_handler.GroupingOpHandler)
        op_handler_dict.update({
            'Conv2D':
            conv2d_handler,
            'Conv2DBackpropInput':
            conv2d_transpose_handler,
            'ConcatV2':
            concat_op_handler.ConcatOpHandler(),
            'DepthToSpace':
            output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(),
            'DepthwiseConv2dNative':
            depthwise_convolution_op_handler.DepthwiseConvolutionOpHandler(),
            'MatMul':
            matmul_handler,
            'RandomUniform':
            leaf_op_handler.LeafOpHandler(),
            'Reshape':
            leaf_op_handler.LeafOpHandler(),
            'Shape':
            leaf_op_handler.LeafOpHandler(),
            'TensorArrayGatherV3':
            leaf_op_handler.LeafOpHandler(),
            'Transpose':
            output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(),
            'StridedSlice':
            leaf_op_handler.LeafOpHandler(),
        })

        self._manager = orm.OpRegularizerManager(
            ops,
            op_handler_dict,
            force_group=force_group,
            regularizer_blacklist=regularizer_blacklist)
        self._calculator = cost_calculator.CostCalculator(
            self._manager, resource_function.flop_function)
Exemplo n.º 12
0
    def __init__(self,
                 ops,
                 gamma_threshold,
                 hardware,
                 batch_size=1,
                 regularizer_decorator: Type[
                     generic_regularizers.OpRegularizer] = None,
                 decorator_parameters=None,
                 force_group=None,
                 regularizer_blacklist=None) -> None:
        """Creates a GammaLatencyRegularizer object.

        Latency cost and regularization loss are calculated for a specified
        hardware platform.

        Args:
          ops: A list of tf.Operation. An OpRegularizer will be created for
            all the ops in `ops`, and recursively for all ops they depend on
            via data dependency. Typically `ops` would contain a single
            tf.Operation, which is the output of the network.
          gamma_threshold: A float scalar, will be used as a 'gamma_threshold'
            for all instances of GammaL1Regularizer created by this class.
          hardware: String name of hardware platform to target.  Must be a key
            from resource_function.PEAK_COMPUTE.
          batch_size: Integer batch size to calculate cost/loss for.
          regularizer_decorator: A class of OpRegularizer decorator to use, or
            None to leave the batch-norm source handler undecorated.
          decorator_parameters: A dictionary of parameters to pass to the
            decorator factory. To be used only with decorators that require
            parameters, otherwise use None.
          force_group: List of regex for ops that should be force-grouped.
            Each regex corresponds to a separate group.  Use '|' operator to
            specify multiple patterns in a single regex. See
            op_regularizer_manager for more detail.
          regularizer_blacklist: List of regex for ops that should not be
            regularized. See op_regularizer_manager for more detail.
        """
        source_op_handler = batch_norm_source_op_handler.BatchNormSourceOpHandler(
            gamma_threshold)
        if regularizer_decorator:
            source_op_handler = op_handler_decorator.OpHandlerDecorator(
                source_op_handler, regularizer_decorator, decorator_parameters)
        # Op types not listed below fall back to GroupingOpHandler.
        op_handler_dict = collections.defaultdict(
            grouping_op_handler.GroupingOpHandler)
        op_handler_dict.update({
            # Every fused batch-norm variant (including V3) is a source, so
            # whichever variant is in the graph gets regularized.
            'FusedBatchNorm':
            source_op_handler,
            'FusedBatchNormV2':
            source_op_handler,
            'FusedBatchNormV3':
            source_op_handler,
            'Conv2D':
            output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(),
            'ConcatV2':
            concat_op_handler.ConcatOpHandler(),
            'DepthToSpace':
            output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(),
            'DepthwiseConv2dNative':
            depthwise_convolution_op_handler.DepthwiseConvolutionOpHandler(),
            'MatMul':
            output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(),
            'TensorArrayGatherV3':
            leaf_op_handler.LeafOpHandler(),
            'Transpose':
            output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(),
        })

        self._manager = orm.OpRegularizerManager(
            ops,
            op_handler_dict,
            force_group=force_group,
            regularizer_blacklist=regularizer_blacklist)
        # Cost is latency on the requested hardware at the given batch size.
        self._calculator = cost_calculator.CostCalculator(
            self._manager,
            resource_function.latency_function_factory(hardware, batch_size))
        # Remember the target hardware platform.
        self._hardware = hardware