Example #1
0
  def __init__(
      self,
      output_boundary: List[tf.Operation],
      gamma_threshold,
      hardware,
      batch_size=1,
      regularizer_decorator: Type[generic_regularizers.OpRegularizer] = None,
      decorator_parameters=None,
      input_boundary: List[tf.Operation] = None,
      force_group=None,
      regularizer_blacklist=None) -> None:
    """Sets up gamma-based latency regularization for a hardware platform.

    The fused batch-norm op types (FusedBatchNorm, FusedBatchNormV2 and
    FusedBatchNormV3) are registered with a shared source op handler, an
    OpRegularizerManager is built over the requested subgraph, and a
    CostCalculator is created from the latency function for `hardware` and
    `batch_size`.

    Args:
      output_boundary: Operations for which an OpRegularizer is created, along
        with (recursively) every op they depend on through data dependencies
        that do not involve ops from input_boundary.
      gamma_threshold: A float scalar used as the 'gamma_threshold' of every
        GammaL1Regularizer instance created by this class.
      hardware: String name of the hardware platform to target.  Must be a key
        from resource_function.PEAK_COMPUTE.
      batch_size: Integer batch size for which cost/loss is calculated.
      regularizer_decorator: Optional OpRegularizer decorator class.  When
        truthy, the source op handler is wrapped in an OpHandlerDecorator.
      decorator_parameters: A dictionary of parameters handed to the decorator
        factory.  Only meaningful for decorators that require parameters;
        pass None otherwise.
      input_boundary: A list of ops marking the input boundary of the
        regularized subgraph (the input boundary itself is not regularized).
      force_group: List of regex for ops that should be force-grouped.  Each
        regex defines its own group; use the '|' operator to combine several
        patterns in one regex.  See op_regularizer_manager for more detail.
      regularizer_blacklist: List of regex for ops that should not be
        regularized.  See op_regularizer_manager for more detail.
    """
    base_handler = batch_norm_source_op_handler.BatchNormSourceOpHandler(
        gamma_threshold)
    # Wrap the batch-norm handler only when a decorator was requested.
    bn_handler = (
        op_handler_decorator.OpHandlerDecorator(
            base_handler, regularizer_decorator, decorator_parameters)
        if regularizer_decorator else base_handler)

    handlers = op_handlers.get_gamma_op_handler_dict()
    # Every fused batch-norm variant shares the same source handler.
    handlers.update(dict.fromkeys(
        ('FusedBatchNorm', 'FusedBatchNormV2', 'FusedBatchNormV3'),
        bn_handler))

    self._manager = orm.OpRegularizerManager(
        output_boundary,
        handlers,
        input_boundary=input_boundary,
        force_group=force_group,
        regularizer_blacklist=regularizer_blacklist)
    latency_fn = resource_function.latency_function_factory(
        hardware, batch_size)
    self._calculator = cost_calculator.CostCalculator(
        self._manager, latency_fn)
    self._hardware = hardware
    def __init__(self,
                 output_boundary: List[tf.Operation],
                 regularize_on_mask=True,
                 alive_threshold=0.1,
                 mask_as_alive_vector=True,
                 regularizer_decorator: Type[
                     generic_regularizers.OpRegularizer] = None,
                 decorator_parameters=None,
                 input_boundary: List[tf.Operation] = None,
                 force_group=None,
                 regularizer_blacklist=None):
        """Builds a FLOPs regularizer sourced from LogisticSigmoidGating ops.

        A LogisticSigmoidSourceOpHandler is registered for the
        'LogisticSigmoidGating' op type, an OpRegularizerManager using
        ProbabilisticGroupingRegularizer for grouping is created, and the
        cost calculator is obtained from self.get_calculator().

        Args:
          output_boundary: Operations for which an OpRegularizer is created,
            along with (recursively) every op they depend on through data
            dependencies that do not involve ops from input_boundary.
          regularize_on_mask: Bool. If True the binary mask serves as the
            regularization vector; otherwise the probability vector is used.
          alive_threshold: Float. Values below this threshold are considered
            dead. With mask_as_alive_vector=True it binarizes the sampled
            values; with mask_as_alive_vector=False it is applied to the
            channel probability.
          mask_as_alive_vector: Bool. If True the thresholded sampled mask is
            the alive vector; otherwise thresholded probabilities from the
            logits are used.
          regularizer_decorator: A class of OpRegularizer decorator to use.
          decorator_parameters: A dictionary of parameters handed to the
            decorator factory. Only meaningful for decorators that require
            parameters; pass None otherwise.
          input_boundary: A list of ops marking the input boundary of the
            regularized subgraph (the input boundary is not regularized).
          force_group: List of regex for ops that should be force-grouped.
            Each regex defines its own group; use the '|' operator to combine
            several patterns in one regex. See op_regularizer_manager for
            more detail.
          regularizer_blacklist: List of regex for ops that should not be
            regularized. See op_regularizer_manager for more detail.
        """
        gating_handler = ls_handler.LogisticSigmoidSourceOpHandler(
            regularize_on_mask, alive_threshold, mask_as_alive_vector)
        if regularizer_decorator:
            # Decorate the gating handler when a decorator class is supplied.
            gating_handler = op_handler_decorator.OpHandlerDecorator(
                gating_handler, regularizer_decorator, decorator_parameters)

        handlers = op_handlers.get_gamma_op_handler_dict()
        handlers.update({'LogisticSigmoidGating': gating_handler})

        self._manager = orm.OpRegularizerManager(
            output_boundary, handlers,
            create_grouping_regularizer=pgr.ProbabilisticGroupingRegularizer,
            input_boundary=input_boundary, force_group=force_group,
            regularizer_blacklist=regularizer_blacklist)
        self._calculator = self.get_calculator()
Example #3
0
 def test_dict_logic(self):
     """Checks key membership of the gamma and group-lasso handler dicts.

     'Conv2D' and 'MatMul' must be keys of the gamma dict but not of the
     group-lasso dict, and every group-lasso key must also be a gamma key.
     """
     weighted_op_types = ('Conv2D', 'MatMul')

     gamma_handlers = op_handlers.get_gamma_op_handler_dict()
     for op_type in weighted_op_types:
         self.assertIn(op_type, gamma_handlers)

     lasso_handlers = op_handlers.get_group_lasso_op_handler_dict()
     for op_type in weighted_op_types:
         self.assertNotIn(op_type, lasso_handlers)

     # The group-lasso keys must form a subset of the gamma keys.
     for op_type in lasso_handlers:
         self.assertIn(op_type, gamma_handlers)
    def __init__(self,
                 ops,
                 gamma_threshold,
                 regularizer_decorator: Type[
                     generic_regularizers.OpRegularizer] = None,
                 decorator_parameters=None,
                 input_boundary=None,
                 force_group=None,
                 regularizer_blacklist=None):
        """Creates a GammaModelSizeRegularizer object.

        Fused batch-norm ops are registered as the regularization sources of
        an OpRegularizerManager, and a CostCalculator is built on top of it
        with resource_function.model_size_function.

        Args:
          ops: A list of tf.Operation. An OpRegularizer will be created for
            all the ops in `ops`, and recursively for all ops they depend on
            via data dependency. Typically `ops` would contain a single
            tf.Operation, which is the output of the network.
          gamma_threshold: A float scalar, will be used as a 'gamma_threshold'
            for all instances GammaL1Regularizer created by this class.
          regularizer_decorator: A class of OpRegularizer decorator to use.
            When truthy, the source op handler is wrapped in an
            OpHandlerDecorator.
          decorator_parameters: A dictionary of parameters to pass to the
            decorator factory. To be used only with decorators that requires
            parameters, otherwise use None.
          input_boundary: A list of ops that represent the input boundary of
            the subgraph being regularized (input boundary is not
            regularized).
          force_group: List of regex for ops that should be force-grouped.
            Each regex corresponds to a separate group.  Use '|' operator to
            specify multiple patterns in a single regex. See
            op_regularizer_manager for more detail.
          regularizer_blacklist: List of regex for ops that should not be
            regularized. See op_regularizer_manager for more detail.
        """
        source_op_handler = batch_norm_source_op_handler.BatchNormSourceOpHandler(
            gamma_threshold)
        if regularizer_decorator:
            source_op_handler = op_handler_decorator.OpHandlerDecorator(
                source_op_handler, regularizer_decorator, decorator_parameters)
        op_handler_dict = op_handlers.get_gamma_op_handler_dict()
        # Register every fused batch-norm variant, including V3, so that V3
        # batch norms are also treated as gamma sources instead of falling
        # back to whatever handler the base dict provides for that op type.
        op_handler_dict.update({
            'FusedBatchNorm': source_op_handler,
            'FusedBatchNormV2': source_op_handler,
            'FusedBatchNormV3': source_op_handler,
        })

        self._manager = orm.OpRegularizerManager(
            ops,
            op_handler_dict,
            input_boundary=input_boundary,
            force_group=force_group,
            regularizer_blacklist=regularizer_blacklist)
        self._calculator = cost_calculator.CostCalculator(
            self._manager, resource_function.model_size_function)
Example #5
0
    def __init__(self,
                 output_boundary: List[tf.Operation],
                 gamma_threshold,
                 regularizer_decorator: Type[
                     generic_regularizers.OpRegularizer] = None,
                 decorator_parameters=None,
                 input_boundary: List[tf.Operation] = None,
                 force_group=None,
                 regularizer_blacklist=None):
        """Builds an activation-count regularizer driven by batch-norm gammas.

        The fused batch-norm op types (FusedBatchNorm, FusedBatchNormV2 and
        FusedBatchNormV3) share one source op handler; an
        OpRegularizerManager is built over the requested subgraph and a
        CostCalculator is created with
        resource_function.activation_count_function.

        Args:
          output_boundary: Operations for which an OpRegularizer is created,
            along with (recursively) every op they depend on through data
            dependencies that do not involve ops from input_boundary.
          gamma_threshold: A float scalar used as the 'gamma_threshold' of
            every GammaL1Regularizer instance created by this class.
          regularizer_decorator: A class of OpRegularizer decorator to use.
          decorator_parameters: A dictionary of parameters handed to the
            decorator factory. Only meaningful for decorators that require
            parameters; pass None otherwise.
          input_boundary: A list of ops marking the input boundary of the
            regularized subgraph (the input boundary is not regularized).
          force_group: List of regex for ops that should be force-grouped.
            Each regex defines its own group; use the '|' operator to combine
            several patterns in one regex. See op_regularizer_manager for
            more detail.
          regularizer_blacklist: List of regex for ops that should not be
            regularized. See op_regularizer_manager for more detail.
        """
        bn_handler = batch_norm_source_op_handler.BatchNormSourceOpHandler(
            gamma_threshold)
        if regularizer_decorator:
            # Decorate the batch-norm handler when a decorator is supplied.
            bn_handler = op_handler_decorator.OpHandlerDecorator(
                bn_handler, regularizer_decorator, decorator_parameters)

        handlers = op_handlers.get_gamma_op_handler_dict()
        bn_op_types = ('FusedBatchNorm', 'FusedBatchNormV2', 'FusedBatchNormV3')
        handlers.update({op_type: bn_handler for op_type in bn_op_types})

        self._manager = orm.OpRegularizerManager(
            output_boundary, handlers,
            input_boundary=input_boundary, force_group=force_group,
            regularizer_blacklist=regularizer_blacklist)
        self._calculator = cost_calculator.CostCalculator(
            self._manager, resource_function.activation_count_function)