def __init__(self,
             output_boundary: List[tf.Operation],
             threshold,
             l1_fraction=0.0,
             regularizer_decorator: Optional[Type[
                 generic_regularizers.OpRegularizer]] = None,
             decorator_parameters=None,
             input_boundary: Optional[List[tf.Operation]] = None,
             force_group: Optional[List[Text]] = None,
             regularizer_blacklist: Optional[List[Text]] = None):
    """Creates a GroupLassoModelSizeRegularizer object.

    Builds an OpRegularizerManager over the subgraph between `input_boundary`
    and `output_boundary`, plus a CostCalculator that uses
    `resource_function.model_size_function`. The 'Conv2D', 'Conv3D',
    'Conv2DBackpropInput' and 'MatMul' op types get dedicated source-op
    handlers. These replace the matching entries in
    `op_handlers.get_group_lasso_op_handler_dict()`.

    Args:
      output_boundary: An OpRegularizer will be created for all these
        operations, and recursively for all ops they depend on via data
        dependency that does not involve ops from input_boundary.
      threshold: A float scalar, will be used as a 'threshold' for all
        regularizer instances created by this class.
      l1_fraction: A float scalar. The relative weight of L1 in L1 + L2
        regularization.
      regularizer_decorator: A class of OpRegularizer decorator to use. When
        set, every source-op handler is wrapped in an OpHandlerDecorator.
      decorator_parameters: A dictionary of parameters to pass to the
        decorator factory. To be used only with decorators that require
        parameters, otherwise use None.
      input_boundary: A list of ops that represent the input boundary of the
        subgraph being regularized (input boundary is not regularized).
      force_group: List of regex for ops that should be force-grouped. Each
        regex corresponds to a separate group. Use '|' operator to specify
        multiple patterns in a single regex. See op_regularizer_manager for
        more detail.
      regularizer_blacklist: List of regex for ops that should not be
        regularized. See op_regularizer_manager for more detail.
    """
    # One handler instance per op type. Conv2D and Conv3D intentionally get
    # distinct ConvSourceOpHandler objects.
    source_handlers = {
        'Conv2D': conv_handler.ConvSourceOpHandler(threshold, l1_fraction),
        'Conv3D': conv_handler.ConvSourceOpHandler(threshold, l1_fraction),
        'Conv2DBackpropInput': (
            conv2d_transpose_handler.Conv2DTransposeSourceOpHandler(
                threshold, l1_fraction)),
        'MatMul': matmul_handler.MatMulSourceOpHandler(threshold,
                                                       l1_fraction),
    }

    if regularizer_decorator:
        # Wrap each source handler so it applies the requested decorator,
        # keeping the same op-type keys (and their insertion order).
        source_handlers = {
            op_type: op_handler_decorator.OpHandlerDecorator(
                handler, regularizer_decorator, decorator_parameters)
            for op_type, handler in source_handlers.items()
        }

    # Start from the library's default group-lasso handlers; the source
    # handlers built above take precedence for their op types.
    handlers_by_type = op_handlers.get_group_lasso_op_handler_dict()
    handlers_by_type.update(source_handlers)

    self._manager = orm.OpRegularizerManager(
        output_boundary,
        handlers_by_type,
        input_boundary=input_boundary,
        force_group=force_group,
        regularizer_blacklist=regularizer_blacklist)
    self._calculator = cost_calculator.CostCalculator(
        self._manager, resource_function.model_size_function)
def testAssignGrouping_GroupWithOutputOnly(self, conv_type):
    """Assigning grouping for conv2 when only conv2's slice has a group.

    Expects the manager to be asked for conv2's op slices, no op to be
    sliced, and process_ops to be called exactly once with [relu1_op].
    """
    self._build(conv_type)
    manager = self.mock_op_reg_manager

    # Every op in the graph maps to its single op slice.
    self.op_slice_dict = {
        self.conv1_op: [self.conv1_op_slice],
        self.relu1_op: [self.relu1_op_slice],
        self.conv2_op: [self.conv2_op_slice],
        self.relu2_op: [self.relu2_op_slice],
    }

    # Only conv2's slice belongs to a group.
    self.op_group_dict = {
        self.conv2_op_slice: self.conv2_op_group,
    }

    source_handler = conv_source_op_handler.ConvSourceOpHandler(
        _DEFAULT_THRESHOLD)
    source_handler.assign_grouping(self.conv2_op, manager)

    # conv2's slices were looked up.
    manager.get_op_slices.assert_any_call(self.conv2_op)

    # No op was re-sliced.
    manager.slice_op.assert_not_called()

    # The input op is queued for processing exactly once.
    expected_queue = [self.relu1_op]
    manager.process_ops.assert_called_once_with(expected_queue)
def is_passthrough(op):
    """Reports whether `op` is a passthrough op for the test graph.

    For conv1/conv2 the answer comes from a fresh ConvSourceOpHandler's
    `is_passthrough` attribute (presumably a property on the handler —
    confirm); every other op is reported as not passthrough.
    """
    # `self` is captured from the enclosing test method.
    if op not in [self.conv1_op, self.conv2_op]:
        return False
    source_handler = conv_source_op_handler.ConvSourceOpHandler(
        _DEFAULT_THRESHOLD)
    return source_handler.is_passthrough
def testCreateRegularizer(self, conv_type):
    """Regularizer for conv2's slice has the kernel's last-dimension length.

    Only the length of the regularization vector is checked here. Detailed
    numerical checks live in group_lasso_regularizer_test.py.
    """
    self._build(conv_type)

    source_handler = conv_source_op_handler.ConvSourceOpHandler(
        _DEFAULT_THRESHOLD)
    regularizer = source_handler.create_regularizer(self.conv2_op_slice)

    # conv2's second input is its weight tensor; its last dimension should
    # match the length of the regularization vector.
    kernel_shape = self.conv2_op.inputs[1].shape.as_list()
    vector_shape = regularizer.regularization_vector.shape.as_list()
    self.assertEqual(kernel_shape[-1], vector_shape[0])
def testOpHandlerDecorator(self):
    """A decorated conv handler returns the decorator's vectors.

    For a 3-channel conv output, expects a regularization vector of all
    0.5 and an alive vector of all ones (presumably the values produced by
    DummyDecorator — confirm against its definition).
    """
    image = tf.constant(0.0, shape=[1, 17, 19, 3])
    kernel = tf.ones([5, 5, 3, 3])
    conv_output = tf.nn.conv2d(
        image, kernel, strides=[1, 1, 1, 1], padding='SAME')

    base_handler = conv_source_op_handler.ConvSourceOpHandler(1e-3, 0)
    decorated_op_handler = op_handler_decorator.OpHandlerDecorator(
        base_handler, DummyDecorator)

    # Slice covering all 3 output channels of the conv.
    output_slice = orm.OpSlice(conv_output.op, orm.Slice(0, 3))
    regularizer = decorated_op_handler.create_regularizer(output_slice)

    expected_regularization = 0.5 * np.ones(3)
    expected_alive = np.ones(3)
    self.assertAllClose(expected_regularization,
                        regularizer.regularization_vector)
    self.assertAllClose(expected_alive, regularizer.alive_vector)