def get_gamma_op_handler_dict():
  """Creates the op handler dict used by gamma-based regularizers.

  Op types not listed explicitly fall back to a `GroupingOpHandler`
  instantiated on demand by the defaultdict.

  Returns:
    A `collections.defaultdict` mapping op type name to an OpHandler.
  """
  # Shorthand factories for the handler classes used for several op types.
  non_passthrough = (
      output_non_passthrough_op_handler.OutputNonPassthroughOpHandler)
  leaf = leaf_op_handler.LeafOpHandler

  return collections.defaultdict(
      grouping_op_handler.GroupingOpHandler,
      {
          'Conv2D': non_passthrough(),
          'ConcatV2': concat_op_handler.ConcatOpHandler(),
          'DepthToSpace': non_passthrough(),
          'DepthwiseConv2dNative':
              depthwise_convolution_op_handler.DepthwiseConvolutionOpHandler(),
          'MatMul': non_passthrough(),
          'TensorArrayGatherV3': leaf(),
          'RandomUniform': leaf(),
          'Reshape': leaf(),
          'SpaceToDepth': leaf(),
          'Shape': leaf(),
          'Transpose': non_passthrough(),
          'ExpandDims': non_passthrough(),
      })
def testAssignGrouping_NoNeighborGroups(self):
  """LeafOpHandler groups the op alone when no neighbor has a group yet."""
  # No ops have groups.
  self.op_group_dict = {}

  # Call handler to assign grouping.
  handler = leaf_op_handler.LeafOpHandler()
  handler.assign_grouping(self.batch_norm_op, self.mock_op_reg_manager)

  # Verify manager looks up OpSlice for ops of interest.  The expected call
  # sequence below mirrors the handler's internal passes over the graph.
  self.mock_op_reg_manager.get_op_slices.assert_has_calls(
      # Checking for ops to process.
      [mock.call(self.relu_op),
       # Initial slice data.
       mock.call(self.batch_norm_op),
       mock.call(self.relu_op),
       # Reslicing.
       mock.call(self.batch_norm_op),
       mock.call(self.relu_op),
       # Refreshing slice data.
       mock.call(self.relu_op)])

  # Verify manager groups leaf op (by itself, since no neighbors are grouped).
  self.mock_op_reg_manager.group_op_slices.assert_called_once_with(
      [self.batch_norm_op_slice])

  # Verify manager processes grouping for batch norm and output ops.
  self.mock_op_reg_manager.process_ops.assert_called_once_with(
      [self.relu_op])
  self.mock_op_reg_manager.process_ops_last.assert_not_called()
def testAssignGrouping_NonPassthroughOutputsSkipped(self):
  """Non-passthrough output ops are excluded from the leaf op's grouping."""
  # Designate ReLU as non-passthrough for this test to demonstrate that batch
  # norm op does not group with ReLU.
  def is_passthrough(op):
    if op == self.relu_op:
      return False
    return True

  self.mock_op_reg_manager.is_passthrough.side_effect = is_passthrough

  # All neighbor ops have groups.
  self.op_group_dict = {
      self.batch_norm_op_slice: self.batch_norm_op_group,
      self.conv_op_slice: self.conv_op_group,
      self.relu_op_slice: self.relu_op_group,
      self.gamma_op_slice: self.conv_op_group,
      self.beta_op_slice: self.conv_op_group,
      self.mean_op_slice: self.relu_op_group,
      self.std_op_slice: self.relu_op_group,
  }

  # Call handler to assign grouping.
  handler = leaf_op_handler.LeafOpHandler()
  handler.assign_grouping(self.batch_norm_op, self.mock_op_reg_manager)

  # Verify manager looks up OpSlice for ops of interest.  Note that relu_op
  # only appears in the first pass: once flagged non-passthrough it drops out
  # of the subsequent slicing/refresh passes.
  self.mock_op_reg_manager.get_op_slices.assert_has_calls(
      # Checking for ops to process.
      [mock.call(self.relu_op),
       mock.call(self.mean_op),
       mock.call(self.std_op),
       # Initial slice data.
       mock.call(self.batch_norm_op),
       mock.call(self.mean_op),
       mock.call(self.std_op),
       # Reslicing.
       mock.call(self.batch_norm_op),
       mock.call(self.mean_op),
       mock.call(self.std_op),
       # Refreshing slice data.
       mock.call(self.mean_op),
       mock.call(self.std_op),
       # Group batch norm op.
       mock.call(self.batch_norm_op)])

  # Verify manager groups batch norm and outputs.  ReLU is not part of the
  # grouping.
  self.mock_op_reg_manager.group_op_slices.assert_called_once_with(
      [self.batch_norm_op_slice, self.mean_op_slice, self.std_op_slice])

  # Verify manager does not process any additional ops.
  self.mock_op_reg_manager.process_ops.assert_not_called()
  self.mock_op_reg_manager.process_ops_last.assert_not_called()
def _get_base_op_hander_dicts():
  """Returns the base op_hander_dict for all regularizers."""
  # Shorthand factories for the handler classes shared by several op types.
  leaf = leaf_op_handler.LeafOpHandler
  non_passthrough = (
      output_non_passthrough_op_handler.OutputNonPassthroughOpHandler)

  handler_dict = collections.defaultdict(
      grouping_op_handler.GroupingOpHandler,
      {
          'ConcatV2': concat_op_handler.ConcatOpHandler(),
          'DepthToSpace': non_passthrough(),
          'DepthwiseConv2dNative':
              depthwise_convolution_op_handler.DepthwiseConvolutionOpHandler(),
          'ExpandDims': non_passthrough(),
          'RandomUniform': leaf(),
          'Reshape': leaf(),
          'Shape': leaf(),
          'SpaceToDepth': leaf(),
          'StridedSlice': leaf(),
          'TensorArrayGatherV3': leaf(),
          'Transpose': non_passthrough(),
      })

  # Resize* ops: the second input might be a tensor, which would result in an
  # error, so only group on input 0.
  for resize_method in RESIZE_OP_NAMES:
    handler_dict[resize_method] = grouping_op_handler.GroupingOpHandler([0])
  return handler_dict
def testAssignGrouping_AllNeighborsGrouped(self):
  """Leaf op groups with all already-grouped output ops in a single pass."""
  # All neighbor ops have groups.
  self.op_group_dict = {
      self.batch_norm_op_slice: self.batch_norm_op_group,
      self.conv_op_slice: self.conv_op_group,
      self.relu_op_slice: self.relu_op_group,
      self.gamma_op_slice: self.conv_op_group,
      self.beta_op_slice: self.conv_op_group,
      self.mean_op_slice: self.relu_op_group,
      self.std_op_slice: self.relu_op_group,
  }

  # Call handler to assign grouping.
  handler = leaf_op_handler.LeafOpHandler()
  handler.assign_grouping(self.batch_norm_op, self.mock_op_reg_manager)

  # Verify manager looks up OpSlice for ops of interest.  The expected call
  # sequence below mirrors the handler's internal passes over the graph.
  self.mock_op_reg_manager.get_op_slices.assert_has_calls(
      # Checking for ops to process.
      [mock.call(self.relu_op),
       mock.call(self.mean_op),
       mock.call(self.std_op),
       # Initial slice data.
       mock.call(self.batch_norm_op),
       mock.call(self.relu_op),
       mock.call(self.mean_op),
       mock.call(self.std_op),
       # Reslicing.
       mock.call(self.batch_norm_op),
       mock.call(self.relu_op),
       mock.call(self.mean_op),
       mock.call(self.std_op),
       # Refreshing slice data.
       mock.call(self.relu_op),
       mock.call(self.mean_op),
       mock.call(self.std_op),
       # Group batch norm op.
       mock.call(self.batch_norm_op)])

  # Verify manager groups batch norm and outputs.
  self.mock_op_reg_manager.group_op_slices.assert_called_once_with([
      self.batch_norm_op_slice, self.relu_op_slice, self.mean_op_slice,
      self.std_op_slice
  ])

  # Verify manager does not process any additional ops.
  self.mock_op_reg_manager.process_ops.assert_not_called()
  self.mock_op_reg_manager.process_ops_last.assert_not_called()
def __init__(self,
             ops,
             gamma_threshold,
             regularizer_decorator: Type[
                 generic_regularizers.OpRegularizer] = None,
             decorator_parameters=None,
             force_group=None,
             regularizer_blacklist=None):
  """Creates a GammaFlopsRegularizer object.

  Args:
    ops: A list of tf.Operation. An OpRegularizer will be created for all
      the ops in `ops`, and recursively for all ops they depend on via data
      dependency. Typically `ops` would contain a single tf.Operation, which
      is the output of the network.
    gamma_threshold: A float scalar, used as the 'gamma_threshold' for every
      GammaL1Regularizer instance created by this class.
    regularizer_decorator: A class of OpRegularizer decorator to use.
    decorator_parameters: A dictionary of parameters to pass to the decorator
      factory. Only meaningful with decorators that require parameters;
      otherwise use None.
    force_group: List of regex for ops that should be force-grouped.  Each
      regex corresponds to a separate group.  Use the '|' operator to specify
      multiple patterns in a single regex. See op_regularizer_manager for
      more detail.
    regularizer_blacklist: List of regex for ops that should not be
      regularized. See op_regularizer_manager for more detail.
  """
  # Batch-norm gamma is the regularization source; optionally wrap it in the
  # caller-supplied decorator.
  bn_source_handler = batch_norm_source_op_handler.BatchNormSourceOpHandler(
      gamma_threshold)
  if regularizer_decorator:
    bn_source_handler = op_handler_decorator.OpHandlerDecorator(
        bn_source_handler, regularizer_decorator, decorator_parameters)

  # Shorthand factories for the handler classes used for several op types.
  non_passthrough = (
      output_non_passthrough_op_handler.OutputNonPassthroughOpHandler)
  leaf = leaf_op_handler.LeafOpHandler

  # Any op type not listed below falls back to a GroupingOpHandler.
  op_handler_dict = collections.defaultdict(
      grouping_op_handler.GroupingOpHandler)
  op_handler_dict.update({
      'FusedBatchNorm': bn_source_handler,
      'FusedBatchNormV2': bn_source_handler,
      'Conv2D': non_passthrough(),
      'ConcatV2': concat_op_handler.ConcatOpHandler(),
      'DepthToSpace': non_passthrough(),
      'DepthwiseConv2dNative':
          depthwise_convolution_op_handler.DepthwiseConvolutionOpHandler(),
      'MatMul': non_passthrough(),
      'TensorArrayGatherV3': leaf(),
      'RandomUniform': leaf(),
      'Reshape': leaf(),
      'Transpose': non_passthrough(),
      'ExpandDims': non_passthrough(),
  })

  self._manager = orm.OpRegularizerManager(
      ops,
      op_handler_dict,
      force_group=force_group,
      regularizer_blacklist=regularizer_blacklist)
  self._calculator = cost_calculator.CostCalculator(
      self._manager, resource_function.flop_function)
def __init__(self,
             ops,
             threshold,
             l1_fraction=0,
             regularizer_decorator: Type[
                 generic_regularizers.OpRegularizer] = None,
             decorator_parameters=None,
             force_group=None,
             regularizer_blacklist=None,
             convert_to_variable=True):
  """Creates a GroupLassoFlopsRegularizer object.

  Args:
    ops: A list of tf.Operation. An OpRegularizer will be created for all
      the ops in `ops`, and recursively for all ops they depend on via data
      dependency. Typically `ops` would contain a single tf.Operation, which
      is the output of the network.
    threshold: A float scalar, used as the 'threshold' for every regularizer
      instance created by this class.
    l1_fraction: Relative weight of L1 in L1 + L2 regularization.
    regularizer_decorator: A class of OpRegularizer decorator to use.
    decorator_parameters: A dictionary of parameters to pass to the decorator
      factory. Only meaningful with decorators that require parameters;
      otherwise use None.
    force_group: List of regex for ops that should be force-grouped.  Each
      regex corresponds to a separate group.  Use the '|' operator to specify
      multiple patterns in a single regex. See op_regularizer_manager for
      more detail.
    regularizer_blacklist: List of regex for ops that should not be
      regularized. See op_regularizer_manager for more detail.
    convert_to_variable: If `True` convert to variable in the
      `GroupLassoBaseOpHandler`. If your graph creates variables outside of
      `tf.get_variable`, set to `False`.
  """
  # Group-lasso source handlers for the three weight-bearing op types,
  # each optionally wrapped in the caller-supplied decorator.
  conv2d_handler = conv2d_source_op_handler.Conv2DSourceOpHandler(
      threshold, l1_fraction, convert_to_variable)
  conv2d_transpose_handler = (
      conv2d_transpose_source_op_handler.Conv2DTransposeSourceOpHandler(
          threshold, l1_fraction, convert_to_variable))
  matmul_handler = matmul_source_op_handler.MatMulSourceOpHandler(
      threshold, l1_fraction, convert_to_variable)
  if regularizer_decorator:
    decorate = lambda handler: op_handler_decorator.OpHandlerDecorator(
        handler, regularizer_decorator, decorator_parameters)
    conv2d_handler = decorate(conv2d_handler)
    conv2d_transpose_handler = decorate(conv2d_transpose_handler)
    matmul_handler = decorate(matmul_handler)

  # Shorthand factories for the handler classes used for several op types.
  leaf = leaf_op_handler.LeafOpHandler
  non_passthrough = (
      output_non_passthrough_op_handler.OutputNonPassthroughOpHandler)

  # Any op type not listed below falls back to a GroupingOpHandler.
  op_handler_dict = collections.defaultdict(
      grouping_op_handler.GroupingOpHandler)
  op_handler_dict.update({
      'Conv2D': conv2d_handler,
      'Conv2DBackpropInput': conv2d_transpose_handler,
      'ConcatV2': concat_op_handler.ConcatOpHandler(),
      'DepthToSpace': non_passthrough(),
      'DepthwiseConv2dNative':
          depthwise_convolution_op_handler.DepthwiseConvolutionOpHandler(),
      'MatMul': matmul_handler,
      'RandomUniform': leaf(),
      'Reshape': leaf(),
      'Shape': leaf(),
      'TensorArrayGatherV3': leaf(),
      'Transpose': non_passthrough(),
      'StridedSlice': leaf(),
  })

  self._manager = orm.OpRegularizerManager(
      ops,
      op_handler_dict,
      force_group=force_group,
      regularizer_blacklist=regularizer_blacklist)
  self._calculator = cost_calculator.CostCalculator(
      self._manager, resource_function.flop_function)
def __init__(self,
             ops,
             gamma_threshold,
             hardware,
             batch_size=1,
             regularizer_decorator: Type[
                 generic_regularizers.OpRegularizer] = None,
             decorator_parameters=None,
             force_group=None,
             regularizer_blacklist=None) -> None:
  """Creates a GammaLatencyRegularizer object.

  Latency cost and regularization loss is calculated for a specified
  hardware platform.

  Args:
    ops: A list of tf.Operation. An OpRegularizer will be created for all
      the ops in `ops`, and recursively for all ops they depend on via data
      dependency. Typically `ops` would contain a single tf.Operation, which
      is the output of the network.
    gamma_threshold: A float scalar, used as the 'gamma_threshold' for every
      GammaL1Regularizer instance created by this class.
    hardware: String name of hardware platform to target.  Must be a key
      from resource_function.PEAK_COMPUTE.
    batch_size: Integer batch size to calculate cost/loss for.
    regularizer_decorator: A string, the name of the regularizer decorators
      to use.  Supported decorators are listed in
      op_regularizer_decorator.SUPPORTED_DECORATORS.
    decorator_parameters: A dictionary of parameters to pass to the decorator
      factory. Only meaningful with decorators that require parameters;
      otherwise use None.
    force_group: List of regex for ops that should be force-grouped.  Each
      regex corresponds to a separate group.  Use the '|' operator to specify
      multiple patterns in a single regex. See op_regularizer_manager for
      more detail.
    regularizer_blacklist: List of regex for ops that should not be
      regularized. See op_regularizer_manager for more detail.
  """
  # Batch-norm gamma is the regularization source; optionally wrap it in the
  # caller-supplied decorator.
  bn_source_handler = batch_norm_source_op_handler.BatchNormSourceOpHandler(
      gamma_threshold)
  if regularizer_decorator:
    bn_source_handler = op_handler_decorator.OpHandlerDecorator(
        bn_source_handler, regularizer_decorator, decorator_parameters)

  # Shorthand factory for the handler class used for several op types.
  non_passthrough = (
      output_non_passthrough_op_handler.OutputNonPassthroughOpHandler)

  # Any op type not listed below falls back to a GroupingOpHandler.
  # NOTE(review): unlike the FLOPs variant, this dict omits 'RandomUniform',
  # 'Reshape' and 'ExpandDims' — confirm whether the difference is deliberate.
  op_handler_dict = collections.defaultdict(
      grouping_op_handler.GroupingOpHandler)
  op_handler_dict.update({
      'FusedBatchNorm': bn_source_handler,
      'FusedBatchNormV2': bn_source_handler,
      'Conv2D': non_passthrough(),
      'ConcatV2': concat_op_handler.ConcatOpHandler(),
      'DepthToSpace': non_passthrough(),
      'DepthwiseConv2dNative':
          depthwise_convolution_op_handler.DepthwiseConvolutionOpHandler(),
      'MatMul': non_passthrough(),
      'TensorArrayGatherV3': leaf_op_handler.LeafOpHandler(),
      'Transpose': non_passthrough(),
  })

  self._manager = orm.OpRegularizerManager(
      ops,
      op_handler_dict,
      force_group=force_group,
      regularizer_blacklist=regularizer_blacklist)
  self._calculator = cost_calculator.CostCalculator(
      self._manager,
      resource_function.latency_function_factory(hardware, batch_size))
  self._hardware = hardware