def __init__(
    self,
    output_boundary: List[tf.Operation],
    gamma_threshold,
    hardware,
    batch_size=1,
    regularizer_decorator: Type[generic_regularizers.OpRegularizer] = None,
    decorator_parameters=None,
    input_boundary: List[tf.Operation] = None,
    force_group=None,
    regularizer_blacklist=None) -> None:
  """Creates a GammaLatencyRegularizer object.

  Latency cost and regularization loss are calculated for a specified
  hardware platform.

  Args:
    output_boundary: An OpRegularizer will be created for all these
      operations, and recursively for all ops they depend on via data
      dependency that does not involve ops from input_boundary.
    gamma_threshold: A float scalar, will be used as a 'gamma_threshold' for
      all instances of GammaL1Regularizer created by this class.
    hardware: String name of hardware platform to target. Must be a key from
      resource_function.PEAK_COMPUTE.
    batch_size: Integer batch size to calculate cost/loss for.
    regularizer_decorator: A class of OpRegularizer decorator to use, or None
      to apply no decorator. When set, the batch-norm source op handler is
      wrapped in op_handler_decorator.OpHandlerDecorator together with
      `decorator_parameters`.
    decorator_parameters: A dictionary of parameters to pass to the decorator
      factory. To be used only with decorators that require parameters,
      otherwise use None.
    input_boundary: A list of ops that represent the input boundary of the
      subgraph being regularized (input boundary is not regularized).
    force_group: List of regex for ops that should be force-grouped. Each
      regex corresponds to a separate group. Use '|' operator to specify
      multiple patterns in a single regex. See op_regularizer_manager for
      more detail.
    regularizer_blacklist: List of regex for ops that should not be
      regularized. See op_regularizer_manager for more detail.
  """
  source_op_handler = batch_norm_source_op_handler.BatchNormSourceOpHandler(
      gamma_threshold)
  if regularizer_decorator:
    # Wrap the batch-norm source handler with the requested decorator.
    source_op_handler = op_handler_decorator.OpHandlerDecorator(
        source_op_handler, regularizer_decorator, decorator_parameters)

  # Start from the generic gamma handlers, then route every FusedBatchNorm
  # op version to the (possibly decorated) gamma source handler.
  op_handler_dict = op_handlers.get_gamma_op_handler_dict()
  op_handler_dict.update({
      'FusedBatchNorm': source_op_handler,
      'FusedBatchNormV2': source_op_handler,
      'FusedBatchNormV3': source_op_handler,
  })

  self._manager = orm.OpRegularizerManager(
      output_boundary,
      op_handler_dict,
      input_boundary=input_boundary,
      force_group=force_group,
      regularizer_blacklist=regularizer_blacklist)
  # Cost is computed with a latency model parameterized by the target
  # hardware and batch size.
  self._calculator = cost_calculator.CostCalculator(
      self._manager,
      resource_function.latency_function_factory(hardware, batch_size))
  self._hardware = hardware
def __init__(self,
             output_boundary: List[tf.Operation],
             regularize_on_mask=True,
             alive_threshold=0.1,
             mask_as_alive_vector=True,
             regularizer_decorator: Type[
                 generic_regularizers.OpRegularizer] = None,
             decorator_parameters=None,
             input_boundary: List[tf.Operation] = None,
             force_group=None,
             regularizer_blacklist=None):
  """Creates a LogisticSigmoidFlopsRegularizer object.

  A LogisticSigmoidSourceOpHandler is registered for 'LogisticSigmoidGating'
  ops on top of the gamma op handler dictionary. Groups are combined with
  ProbabilisticGroupingRegularizer.

  Args:
    output_boundary: An OpRegularizer is created for each of these ops and,
      recursively, for every op they reach through data dependencies that do
      not pass through ops in input_boundary.
    regularize_on_mask: Bool. When True the binary mask serves as the
      regularization vector; otherwise the probability vector is used.
    alive_threshold: Float. Values below this threshold are treated as dead.
      With mask_as_alive_vector=True it binarizes the sampled values; with
      mask_as_alive_vector=False it thresholds the channel probability.
    mask_as_alive_vector: Bool. When True the thresholded sampled mask is the
      alive vector; otherwise thresholded probabilities from the logits are.
    regularizer_decorator: A class of OpRegularizer decorator to use, or None.
    decorator_parameters: A dictionary of parameters forwarded to the
      decorator factory. Only needed for decorators that take parameters;
      leave as None otherwise.
    input_boundary: A list of ops forming the input boundary of the
      regularized subgraph (the boundary itself is not regularized).
    force_group: List of regex for ops that must be force-grouped, one group
      per regex. Combine several patterns in one regex with '|'. See
      op_regularizer_manager for more detail.
    regularizer_blacklist: List of regex for ops that must not be
      regularized. See op_regularizer_manager for more detail.
  """
  gating_handler = ls_handler.LogisticSigmoidSourceOpHandler(
      regularize_on_mask, alive_threshold, mask_as_alive_vector)
  gating_handler = (
      op_handler_decorator.OpHandlerDecorator(
          gating_handler, regularizer_decorator, decorator_parameters)
      if regularizer_decorator else gating_handler)

  handlers_by_type = op_handlers.get_gamma_op_handler_dict()
  handlers_by_type['LogisticSigmoidGating'] = gating_handler

  self._manager = orm.OpRegularizerManager(
      output_boundary,
      handlers_by_type,
      create_grouping_regularizer=pgr.ProbabilisticGroupingRegularizer,
      input_boundary=input_boundary,
      force_group=force_group,
      regularizer_blacklist=regularizer_blacklist)
  # get_calculator is defined elsewhere on the class; it presumably reads
  # self._manager, so it must run after the manager is built -- keep order.
  self._calculator = self.get_calculator()
def test_dict_logic(self):
  """Checks the relationship between the gamma and group-lasso handler dicts.

  Conv2D and MatMul must appear in the gamma dict but not in the group-lasso
  dict. Every op type in the group-lasso dict must also be in the gamma dict.
  """
  core_op_types = ('Conv2D', 'MatMul')

  gamma_handlers = op_handlers.get_gamma_op_handler_dict()
  for op_type in core_op_types:
    self.assertIn(op_type, gamma_handlers)

  lasso_handlers = op_handlers.get_group_lasso_op_handler_dict()
  for op_type in core_op_types:
    self.assertNotIn(op_type, lasso_handlers)

  # The group-lasso handlers must be a subset of the gamma handlers.
  for lasso_op_type in lasso_handlers:
    self.assertIn(lasso_op_type, gamma_handlers)
def __init__(self,
             ops,
             gamma_threshold,
             regularizer_decorator: Type[
                 generic_regularizers.OpRegularizer] = None,
             decorator_parameters=None,
             input_boundary=None,
             force_group=None,
             regularizer_blacklist=None):
  """Creates a GammaModelSizeRegularizer object.

  Args:
    ops: A list of tf.Operation. An OpRegularizer will be created for all the
      ops in `ops`, and recursively for all ops they depend on via data
      dependency. Typically `ops` would contain a single tf.Operation, which
      is the output of the network.
    gamma_threshold: A float scalar, will be used as a 'gamma_threshold' for
      all instances of GammaL1Regularizer created by this class.
    regularizer_decorator: A class of OpRegularizer decorator to use, or None
      to apply no decorator. When set, the batch-norm source op handler is
      wrapped in op_handler_decorator.OpHandlerDecorator together with
      `decorator_parameters`.
    decorator_parameters: A dictionary of parameters to pass to the decorator
      factory. To be used only with decorators that require parameters,
      otherwise use None.
    input_boundary: A list of ops that represent the input boundary of the
      subgraph being regularized (input boundary is not regularized).
    force_group: List of regex for ops that should be force-grouped. Each
      regex corresponds to a separate group. Use '|' operator to specify
      multiple patterns in a single regex. See op_regularizer_manager for
      more detail.
    regularizer_blacklist: List of regex for ops that should not be
      regularized. See op_regularizer_manager for more detail.
  """
  source_op_handler = batch_norm_source_op_handler.BatchNormSourceOpHandler(
      gamma_threshold)
  if regularizer_decorator:
    # Wrap the batch-norm source handler with the requested decorator.
    source_op_handler = op_handler_decorator.OpHandlerDecorator(
        source_op_handler, regularizer_decorator, decorator_parameters)

  # Route every FusedBatchNorm op version to the gamma source handler.
  # FusedBatchNormV3 was previously missing here (the other gamma
  # regularizers register it), so V3 batch-norm ops were not handled by the
  # gamma source handler in this regularizer.
  op_handler_dict = op_handlers.get_gamma_op_handler_dict()
  op_handler_dict.update({
      'FusedBatchNorm': source_op_handler,
      'FusedBatchNormV2': source_op_handler,
      'FusedBatchNormV3': source_op_handler,
  })

  self._manager = orm.OpRegularizerManager(
      ops,
      op_handler_dict,
      input_boundary=input_boundary,
      force_group=force_group,
      regularizer_blacklist=regularizer_blacklist)
  # Cost is measured as model size (resource_function.model_size_function).
  self._calculator = cost_calculator.CostCalculator(
      self._manager, resource_function.model_size_function)
def __init__(self,
             output_boundary: List[tf.Operation],
             gamma_threshold,
             regularizer_decorator: Type[
                 generic_regularizers.OpRegularizer] = None,
             decorator_parameters=None,
             input_boundary: List[tf.Operation] = None,
             force_group=None,
             regularizer_blacklist=None):
  """Creates a GammaActivationRegularizer object.

  Batch-norm gammas act as the regularization source. Cost is measured with
  resource_function.activation_count_function.

  Args:
    output_boundary: An OpRegularizer is created for each of these ops and,
      recursively, for every op they reach through data dependencies that do
      not pass through ops in input_boundary.
    gamma_threshold: A float scalar used as the 'gamma_threshold' of every
      GammaL1Regularizer instance this class creates.
    regularizer_decorator: A class of OpRegularizer decorator to use, or None.
    decorator_parameters: A dictionary of parameters forwarded to the
      decorator factory. Only needed for decorators that take parameters;
      leave as None otherwise.
    input_boundary: A list of ops forming the input boundary of the
      regularized subgraph (the boundary itself is not regularized).
    force_group: List of regex for ops that must be force-grouped, one group
      per regex. Combine several patterns in one regex with '|'. See
      op_regularizer_manager for more detail.
    regularizer_blacklist: List of regex for ops that must not be
      regularized. See op_regularizer_manager for more detail.
  """
  bn_handler = batch_norm_source_op_handler.BatchNormSourceOpHandler(
      gamma_threshold)
  if regularizer_decorator:
    bn_handler = op_handler_decorator.OpHandlerDecorator(
        bn_handler, regularizer_decorator, decorator_parameters)

  # Every FusedBatchNorm op version shares the same source handler.
  handler_by_op_type = op_handlers.get_gamma_op_handler_dict()
  for bn_op_type in ('FusedBatchNorm', 'FusedBatchNormV2', 'FusedBatchNormV3'):
    handler_by_op_type[bn_op_type] = bn_handler

  manager = orm.OpRegularizerManager(
      output_boundary,
      handler_by_op_type,
      input_boundary=input_boundary,
      force_group=force_group,
      regularizer_blacklist=regularizer_blacklist)
  self._manager = manager
  self._calculator = cost_calculator.CostCalculator(
      manager, resource_function.activation_count_function)