def _get_base_op_hander_dicts():
  """Builds the op-type -> OpHandler mapping shared by every regularizer.

  Op types that have no explicit entry resolve to a freshly constructed
  `grouping_op_handler.GroupingOpHandler` through the defaultdict factory.

  Returns:
    A `collections.defaultdict` keyed by op type name.
  """
  non_passthrough = (
      output_non_passthrough_op_handler.OutputNonPassthroughOpHandler)
  leaf = leaf_op_handler.LeafOpHandler
  # (op type, handler class) pairs; each pair gets its own handler instance.
  handler_factories = (
      ('ConcatV2', concat_op_handler.ConcatOpHandler),
      ('DepthToSpace', non_passthrough),
      ('DepthwiseConv2dNative',
       depthwise_convolution_op_handler.DepthwiseConvolutionOpHandler),
      ('ExpandDims', non_passthrough),
      ('RandomUniform', leaf),
      ('Reshape', leaf),
      ('Shape', leaf),
      ('SpaceToDepth', leaf),
      ('StridedSlice', leaf),
      ('TensorArrayGatherV3', leaf),
      ('Transpose', non_passthrough),
  )
  base_dict = collections.defaultdict(
      grouping_op_handler.GroupingOpHandler,
      ((op_type, factory()) for op_type, factory in handler_factories))

  # Resize* ops: the second input might be a tensor, which would result in an
  # error, so the grouping handler is built with [0] (presumably restricting
  # it to input 0 -- confirm against GroupingOpHandler).
  base_dict.update(
      (resize_op, grouping_op_handler.GroupingOpHandler([0]))
      for resize_op in RESIZE_OP_NAMES)
  return base_dict
def _get_base_op_hander_dicts():
  """Returns the default op-type -> OpHandler mapping.

  Any op type not assigned below resolves to a new
  `grouping_op_handler.GroupingOpHandler` via the defaultdict factory.
  """
  handlers = collections.defaultdict(grouping_op_handler.GroupingOpHandler)
  handlers['ConcatV2'] = concat_op_handler.ConcatOpHandler()
  handlers['DepthToSpace'] = (
      output_non_passthrough_op_handler.OutputNonPassthroughOpHandler())
  handlers['DepthwiseConv2dNative'] = (
      depthwise_convolution_op_handler.DepthwiseConvolutionOpHandler())
  handlers['ExpandDims'] = (
      output_non_passthrough_op_handler.OutputNonPassthroughOpHandler())
  handlers['RandomUniform'] = leaf_op_handler.LeafOpHandler()
  handlers['Reshape'] = leaf_op_handler.LeafOpHandler()
  handlers['Shape'] = leaf_op_handler.LeafOpHandler()
  handlers['SpaceToDepth'] = leaf_op_handler.LeafOpHandler()
  handlers['StridedSlice'] = leaf_op_handler.LeafOpHandler()
  handlers['TensorArrayGatherV3'] = leaf_op_handler.LeafOpHandler()
  handlers['Transpose'] = (
      output_non_passthrough_op_handler.OutputNonPassthroughOpHandler())
  return handlers
def test_AssignGroupingOfGroupingConcatNoSlicing(self):
  # Scenario: the output op (batch norm, size 6) is not sliced, and the input
  # ops (each of size 6) are already grouped. The handler is expected to put
  # every slice into a single group without slicing anything.
  relu_ops = [self.relu1_op, self.relu2_op, self.relu3_op]
  relu_slices = [self.relu1_op_slice, self.relu2_op_slice, self.relu3_op_slice]
  relu_groups = [self.relu1_op_group, self.relu2_op_group, self.relu3_op_group]

  # Every op is represented by exactly one OpSlice.
  self.op_slice_dict = {
      op: [op_slice] for op, op_slice in zip(relu_ops, relu_slices)
  }
  self.op_slice_dict[self.concat_op] = [self.concat_op_slice]
  self.op_slice_dict[self.batch_norm_op] = [self.batch_norm_op_slice]

  # All slices except the concat's already belong to a group.
  self.op_group_dict = dict(zip(relu_slices, relu_groups))
  self.op_group_dict[self.batch_norm_op_slice] = self.batch_norm_op_group

  concat_handler = concat_op_handler.ConcatOpHandler()
  concat_handler.assign_grouping(self.concat_op, self.mock_op_reg_manager)

  manager = self.mock_op_reg_manager
  neighbors = relu_ops + [self.batch_norm_op]
  lookup_order = (
      neighbors  # Checking for ops to process.
      + [self.concat_op] + neighbors  # Initial slice data.
      + relu_ops + [self.concat_op, self.batch_norm_op]  # Reslicing.
      + neighbors  # Refreshing slice data.
      + [self.concat_op])  # Grouping the concat op.
  manager.get_op_slices.assert_has_calls(
      [mock.call(op) for op in lookup_order])

  # The concat op must not be sliced.
  manager.slice_op.assert_not_called()
  # All slices end up in exactly one grouping call.
  manager.group_op_slices.assert_called_once_with(
      [self.concat_op_slice] + relu_slices + [self.batch_norm_op_slice])
def testGetInputOutputOpSlices(self):
  relu_ops = [self.relu1_op, self.relu2_op, self.relu3_op]
  relu_slices = [self.relu1_op_slice, self.relu2_op_slice, self.relu3_op_slice]

  # Each op is a single OpSlice.
  self.op_slice_dict = dict(zip(relu_ops, [[s] for s in relu_slices]))
  self.op_slice_dict.update({
      self.concat_op: [self.concat_op_slice],
      self.batch_norm_op: [self.batch_norm_op_slice],
  })

  handler = concat_op_handler.ConcatOpHandler()
  # The axis op is passed among the inputs, but the expected result below
  # contains only the ReLU slices.
  actual = handler._get_input_output_op_slices(
      relu_ops + [self.axis_op], [self.batch_norm_op],
      self.mock_op_reg_manager)

  self.assertEqual(([relu_slices], [[self.batch_norm_op_slice]]), actual)
def testImageIsNotZerothOutputOfOp(self):
  # Throughout the framework, we assume that the 0th output of each op is the
  # only one of interest. One exception that often happens is when the input
  # image comes from a queue or from a staging op. Then the image is one of
  # the outputs of the dequeue (or staging) op, not necessarily the 0th one.
  # Here we test that the OpRegularizerManager + CostCalculator pipeline
  # deals correctly with this case.

  # Create an input op where the image is output number 1, not 0.
  # TODO(g1) Move this mechanism to add_concat_model_stub, possibly using
  # tf.split to produce an op where the image is not the 0th output image
  # (instead of FIFOQueue).
  image = add_concat_model_stub.image_stub()
  non_image_tensor = tf.zeros(shape=(41,))
  # Two-component queue: component 0 is a dummy tensor, component 1 is the
  # image.
  queue = tf.FIFOQueue(
      capacity=1,
      dtypes=(tf.float32,) * 2,
      shapes=(non_image_tensor.shape, image.shape))

  # Pass the image (output[1]) to the network.
  with arg_scope(self._batch_norm_scope()):
    output_op = add_concat_model_stub.build_model(queue.dequeue()[1])

  # Create OpHandler dict for test. Op types not listed fall back to a new
  # GroupingOpHandler via the defaultdict factory.
  op_handler_dict = collections.defaultdict(
      grouping_op_handler.GroupingOpHandler)
  op_handler_dict.update({
      'FusedBatchNorm':
          batch_norm_source_op_handler.BatchNormSourceOpHandler(0.1),
      'Conv2D':
          output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(),
      'ConcatV2':
          concat_op_handler.ConcatOpHandler(),
  })

  # Create OpRegularizerManager and a FLOP CostCalculator for test.
  manager = orm.OpRegularizerManager([output_op], op_handler_dict)
  calculator = cost_calculator.CostCalculator(
      manager, resource_function.flop_function)

  # Calculate expected FLOP cost of conv1 alone:
  # coeff * input channels * alive output channels.
  expected_alive_conv1 = sum(add_concat_model_stub.expected_alive()['conv1'])
  conv1_op = tf.get_default_graph().get_operation_by_name('conv1/Conv2D')
  conv1_coeff = resource_function.flop_coeff(conv1_op)
  # NOTE(review): assumes image_stub() yields a 3-channel image -- confirm.
  num_channels = 3
  expected_cost = conv1_coeff * num_channels * expected_alive_conv1

  with self.session():
    tf.global_variables_initializer().run()
    # Set gamma values to replicate aliveness in add_concat_model_stub.
    name_to_var = {v.op.name: v for v in tf.global_variables()}
    gamma1 = name_to_var['conv1/BatchNorm/gamma']
    gamma1.assign([0, 1, 1, 0, 1, 0, 1]).eval()
    gamma4 = name_to_var['conv4/BatchNorm/gamma']
    gamma4.assign([0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0]).eval()

    # Feed the queue so the dequeue that produces the image can run during
    # cost evaluation.
    queue.enqueue((non_image_tensor, image)).run()
    self.assertEqual(expected_cost,
                     calculator.get_cost([conv1_op]).eval())
def __init__(self,
             ops,
             gamma_threshold,
             regularizer_decorator: Type[
                 generic_regularizers.OpRegularizer] = None,
             decorator_parameters=None,
             force_group=None,
             regularizer_blacklist=None):
  """Creates a GammaFlopsRegularizer object.

  Args:
    ops: A list of tf.Operation. An OpRegularizer will be created for all the
      ops in `ops`, and recursively for all ops they depend on via data
      dependency. Typically `ops` would contain a single tf.Operation, which
      is the output of the network.
    gamma_threshold: A float scalar, used as the 'gamma_threshold' for all
      instances of GammaL1Regularizer created by this class.
    regularizer_decorator: A class of OpRegularizer decorator to use, or None
      for no decoration.
    decorator_parameters: A dictionary of parameters to pass to the decorator
      factory. To be used only with decorators that require parameters,
      otherwise use None.
    force_group: List of regex for ops that should be force-grouped. Each
      regex corresponds to a separate group. Use '|' operator to specify
      multiple patterns in a single regex. See op_regularizer_manager for more
      detail.
    regularizer_blacklist: List of regex for ops that should not be
      regularized. See op_regularizer_manager for more detail.
  """
  batch_norm_handler = batch_norm_source_op_handler.BatchNormSourceOpHandler(
      gamma_threshold)
  if regularizer_decorator:
    batch_norm_handler = op_handler_decorator.OpHandlerDecorator(
        batch_norm_handler, regularizer_decorator, decorator_parameters)

  # Both fused batch norm variants share the single source handler; all
  # other op types not listed fall back to a new GroupingOpHandler.
  handlers = collections.defaultdict(
      grouping_op_handler.GroupingOpHandler, {
          'FusedBatchNorm':
              batch_norm_handler,
          'FusedBatchNormV2':
              batch_norm_handler,
          'Conv2D':
              output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(),
          'ConcatV2':
              concat_op_handler.ConcatOpHandler(),
          'DepthToSpace':
              output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(),
          'DepthwiseConv2dNative':
              depthwise_convolution_op_handler.DepthwiseConvolutionOpHandler(),
          'MatMul':
              output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(),
          'TensorArrayGatherV3':
              leaf_op_handler.LeafOpHandler(),
          'RandomUniform':
              leaf_op_handler.LeafOpHandler(),
          'Reshape':
              leaf_op_handler.LeafOpHandler(),
          'Transpose':
              output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(),
          'ExpandDims':
              output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(),
      })

  self._manager = orm.OpRegularizerManager(
      ops,
      handlers,
      force_group=force_group,
      regularizer_blacklist=regularizer_blacklist)
  self._calculator = cost_calculator.CostCalculator(
      self._manager, resource_function.flop_function)
def __init__(self,
             ops,
             threshold,
             l1_fraction=0,
             regularizer_decorator: Type[
                 generic_regularizers.OpRegularizer] = None,
             decorator_parameters=None,
             force_group=None,
             regularizer_blacklist=None,
             convert_to_variable=True):
  """Creates a GroupLassoFlopsRegularizer object.

  Args:
    ops: A list of tf.Operation. An OpRegularizer will be created for all the
      ops in `ops`, and recursively for all ops they depend on via data
      dependency. Typically `ops` would contain a single tf.Operation, which
      is the output of the network.
    threshold: A float scalar, used as the 'threshold' for all regularizer
      instances created by this class.
    l1_fraction: Relative weight of L1 in L1 + L2 regularization.
    regularizer_decorator: A class of OpRegularizer decorator to use, or None
      for no decoration.
    decorator_parameters: A dictionary of parameters to pass to the decorator
      factory. To be used only with decorators that require parameters,
      otherwise use None.
    force_group: List of regex for ops that should be force-grouped. Each
      regex corresponds to a separate group. Use '|' operator to specify
      multiple patterns in a single regex. See op_regularizer_manager for more
      detail.
    regularizer_blacklist: List of regex for ops that should not be
      regularized. See op_regularizer_manager for more detail.
    convert_to_variable: If `True`, convert to variable in the
      `GroupLassoBaseOpHandler`. If your graph creates variables outside of
      `tf.get_variable`, set to `False`.
  """
  # All three source handlers take the same positional configuration.
  source_args = (threshold, l1_fraction, convert_to_variable)
  conv2d_handler = conv2d_source_op_handler.Conv2DSourceOpHandler(
      *source_args)
  conv2d_transpose_handler = (
      conv2d_transpose_source_op_handler.Conv2DTransposeSourceOpHandler(
          *source_args))
  matmul_handler = matmul_source_op_handler.MatMulSourceOpHandler(
      *source_args)

  if regularizer_decorator:
    conv2d_handler, conv2d_transpose_handler, matmul_handler = [
        op_handler_decorator.OpHandlerDecorator(
            source_handler, regularizer_decorator, decorator_parameters)
        for source_handler in (conv2d_handler, conv2d_transpose_handler,
                               matmul_handler)
    ]

  # Op types not listed fall back to a new GroupingOpHandler.
  handlers = collections.defaultdict(
      grouping_op_handler.GroupingOpHandler, {
          'Conv2D':
              conv2d_handler,
          'Conv2DBackpropInput':
              conv2d_transpose_handler,
          'ConcatV2':
              concat_op_handler.ConcatOpHandler(),
          'DepthToSpace':
              output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(),
          'DepthwiseConv2dNative':
              depthwise_convolution_op_handler.DepthwiseConvolutionOpHandler(),
          'MatMul':
              matmul_handler,
          'RandomUniform':
              leaf_op_handler.LeafOpHandler(),
          'Reshape':
              leaf_op_handler.LeafOpHandler(),
          'Shape':
              leaf_op_handler.LeafOpHandler(),
          'TensorArrayGatherV3':
              leaf_op_handler.LeafOpHandler(),
          'Transpose':
              output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(),
          'StridedSlice':
              leaf_op_handler.LeafOpHandler(),
      })

  self._manager = orm.OpRegularizerManager(
      ops,
      handlers,
      force_group=force_group,
      regularizer_blacklist=regularizer_blacklist)
  self._calculator = cost_calculator.CostCalculator(
      self._manager, resource_function.flop_function)
def __init__(self,
             ops,
             gamma_threshold,
             hardware,
             batch_size=1,
             regularizer_decorator: Type[
                 generic_regularizers.OpRegularizer] = None,
             decorator_parameters=None,
             force_group=None,
             regularizer_blacklist=None) -> None:
  """Creates a GammaLatencyRegularizer object.

  Latency cost and regularization loss are calculated for a specified
  hardware platform.

  Args:
    ops: A list of tf.Operation. An OpRegularizer will be created for all the
      ops in `ops`, and recursively for all ops they depend on via data
      dependency. Typically `ops` would contain a single tf.Operation, which
      is the output of the network.
    gamma_threshold: A float scalar, will be used as a 'gamma_threshold' for
      all instances of GammaL1Regularizer created by this class.
    hardware: String name of hardware platform to target. Must be a key from
      resource_function.PEAK_COMPUTE.
    batch_size: Integer batch size to calculate cost/loss for.
    regularizer_decorator: A class of OpRegularizer decorator to use, or None
      for no decoration. When set, it is passed together with
      `decorator_parameters` to op_handler_decorator.OpHandlerDecorator,
      which wraps the batch norm source op handler.
    decorator_parameters: A dictionary of parameters to pass to the decorator
      factory. To be used only with decorators that require parameters,
      otherwise use None.
    force_group: List of regex for ops that should be force-grouped. Each
      regex corresponds to a separate group. Use '|' operator to specify
      multiple patterns in a single regex. See op_regularizer_manager for more
      detail.
    regularizer_blacklist: List of regex for ops that should not be
      regularized. See op_regularizer_manager for more detail.
  """
  source_op_handler = batch_norm_source_op_handler.BatchNormSourceOpHandler(
      gamma_threshold)
  if regularizer_decorator:
    # Wrap the batch norm source handler with the requested decorator.
    source_op_handler = op_handler_decorator.OpHandlerDecorator(
        source_op_handler, regularizer_decorator, decorator_parameters)
  # Op types not listed fall back to a new GroupingOpHandler via the
  # defaultdict factory; both fused batch norm variants share one handler.
  op_handler_dict = collections.defaultdict(
      grouping_op_handler.GroupingOpHandler)
  op_handler_dict.update({
      'FusedBatchNorm':
          source_op_handler,
      'FusedBatchNormV2':
          source_op_handler,
      'Conv2D':
          output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(),
      'ConcatV2':
          concat_op_handler.ConcatOpHandler(),
      'DepthToSpace':
          output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(),
      'DepthwiseConv2dNative':
          depthwise_convolution_op_handler.DepthwiseConvolutionOpHandler(),
      'MatMul':
          output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(),
      'TensorArrayGatherV3':
          leaf_op_handler.LeafOpHandler(),
      'Transpose':
          output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(),
  })

  self._manager = orm.OpRegularizerManager(
      ops,
      op_handler_dict,
      force_group=force_group,
      regularizer_blacklist=regularizer_blacklist)
  # Cost is a latency estimate for the target hardware and batch size.
  self._calculator = cost_calculator.CostCalculator(
      self._manager,
      resource_function.latency_function_factory(hardware, batch_size))
  self._hardware = hardware
def testAssignGrouping_NoNeighborGroups(self):
  # Neither the inputs nor the outputs have groups yet. The concat and batch
  # norm ops are still sliced, but no grouping happens until the neighbors
  # themselves are grouped.
  relu_ops = [self.relu1_op, self.relu2_op, self.relu3_op]
  relu_slices = [self.relu1_op_slice, self.relu2_op_slice, self.relu3_op_slice]

  # One OpSlice per op.
  self.op_slice_dict = {
      op: [op_slice] for op, op_slice in zip(relu_ops, relu_slices)
  }
  self.op_slice_dict[self.concat_op] = [self.concat_op_slice]
  self.op_slice_dict[self.batch_norm_op] = [self.batch_norm_op_slice]

  # Only the concat slices are grouped; no neighbor slice is.
  self.op_group_dict = dict(
      zip([
          self.concat_op_slice_0_5, self.concat_op_slice_5_11,
          self.concat_op_slice_11_18
      ], [self.concat_op_group1, self.concat_op_group2, self.concat_op_group3]))

  concat_handler = concat_op_handler.ConcatOpHandler()
  concat_handler.assign_grouping(self.concat_op, self.mock_op_reg_manager)

  manager = self.mock_op_reg_manager
  neighbors = relu_ops + [self.batch_norm_op]
  lookup_order = (
      neighbors  # Checking for ops to process.
      + [self.concat_op] + neighbors  # Initial slice data.
      + neighbors + [self.concat_op]  # Reslicing.
      + neighbors)  # Refreshing slice data.
  manager.get_op_slices.assert_has_calls(
      [mock.call(op) for op in lookup_order])

  # Ops whose OpSlice sizes are not aligned get sliced.
  aligned_sizes = [5, 6, 7]
  manager.slice_op.assert_has_calls([
      mock.call(self.batch_norm_op, aligned_sizes),
      mock.call(self.concat_op, aligned_sizes),
  ])

  # Nothing is grouped yet...
  manager.group_op_slices.assert_not_called()
  # ...so the neighbors are queued for processing and the concat op is handed
  # to process_ops_last.
  manager.process_ops.assert_called_once_with([self.batch_norm_op] + relu_ops)
  manager.process_ops_last.assert_called_once_with([self.concat_op])
def testAssignGrouping_OutputsGrouped(self):
  # Only the output ops are grouped. The concat and batch norm ops will be
  # sliced according to the input sizes, but no grouping happens yet because
  # the input ReLUs are ungrouped.
  relu_ops = [self.relu1_op, self.relu2_op, self.relu3_op]
  relu_slices = [self.relu1_op_slice, self.relu2_op_slice, self.relu3_op_slice]

  # One OpSlice per op.
  self.op_slice_dict = {
      op: [op_slice] for op, op_slice in zip(relu_ops, relu_slices)
  }
  self.op_slice_dict[self.concat_op] = [self.concat_op_slice]
  self.op_slice_dict[self.batch_norm_op] = [self.batch_norm_op_slice]

  # Concat and batch norm slices are grouped; ReLU slices are not.
  self.op_group_dict = dict(
      zip([
          self.concat_op_slice_0_5, self.concat_op_slice_5_11,
          self.concat_op_slice_11_18
      ], [self.concat_op_group1, self.concat_op_group2, self.concat_op_group3]))
  self.op_group_dict[self.batch_norm_op_slice] = self.batch_norm_op_group
  self.op_group_dict.update(
      zip([
          self.batch_norm_op_slice_0_5, self.batch_norm_op_slice_5_11,
          self.batch_norm_op_slice_11_18
      ], [
          self.batch_norm_op_group1, self.batch_norm_op_group2,
          self.batch_norm_op_group3
      ]))

  concat_handler = concat_op_handler.ConcatOpHandler()
  concat_handler.assign_grouping(self.concat_op, self.mock_op_reg_manager)

  manager = self.mock_op_reg_manager
  neighbors = relu_ops + [self.batch_norm_op]
  lookup_order = (
      neighbors  # Checking for ops to process.
      + [self.concat_op] + neighbors  # Initial slice data.
      + neighbors + [self.concat_op]  # Reslicing.
      + neighbors)  # Refreshing slice data.
  manager.get_op_slices.assert_has_calls(
      [mock.call(op) for op in lookup_order])

  # Ops whose OpSlice sizes are not aligned get sliced.
  aligned_sizes = [5, 6, 7]
  manager.slice_op.assert_has_calls([
      mock.call(self.batch_norm_op, aligned_sizes),
      mock.call(self.concat_op, aligned_sizes),
  ])

  # No grouping yet; the ungrouped inputs are queued and the concat op is
  # handed to process_ops_last.
  manager.group_op_slices.assert_not_called()
  manager.process_ops.assert_called_once_with(relu_ops)
  manager.process_ops_last.assert_called_once_with([self.concat_op])
def testAssignGrouping_AllNeighborsGrouped_InputSlicesNotAligned(self):
  # In this test, the op c2 (relu2_op here) has size 6 but is split into 2
  # slices of size 3. The concat op (and its output, the batch norm) both have
  # size 18. This test verifies that the concat and batch norm are sliced
  # according to the sizes of c1, c2, and c3, and takes into account that c2
  # is also composed of multiple slices.

  # Concat slices expected after reslicing to sizes [5, 3, 3, 7].
  concat_op_slice_0_5 = orm.OpSlice(self.concat_op, orm.Slice(0, 5))
  concat_op_slice_5_8 = orm.OpSlice(self.concat_op, orm.Slice(5, 3))
  concat_op_slice_8_11 = orm.OpSlice(self.concat_op, orm.Slice(8, 3))
  concat_op_slice_11_18 = orm.OpSlice(self.concat_op, orm.Slice(11, 7))

  # relu2 is already split into two grouped slices of size 3.
  relu2_op_slice_0_3 = orm.OpSlice(self.relu2_op, orm.Slice(0, 3))
  relu2_op_slice_3_6 = orm.OpSlice(self.relu2_op, orm.Slice(3, 3))
  relu2_op_group1 = orm.OpGroup(
      relu2_op_slice_0_3, omit_source_op_slices=[relu2_op_slice_0_3])
  relu2_op_group2 = orm.OpGroup(
      relu2_op_slice_3_6, omit_source_op_slices=[relu2_op_slice_3_6])

  # Batch norm starts as one size-18 slice and is resliced to [5, 3, 3, 7].
  batch_norm_op_slice = orm.OpSlice(self.batch_norm_op, orm.Slice(0, 18))
  batch_norm_op_group = orm.OpGroup(
      batch_norm_op_slice, omit_source_op_slices=[batch_norm_op_slice])
  batch_norm_op_slice_0_5 = orm.OpSlice(self.batch_norm_op, orm.Slice(0, 5))
  batch_norm_op_group1 = orm.OpGroup(
      batch_norm_op_slice_0_5, omit_source_op_slices=[batch_norm_op_slice_0_5])
  batch_norm_op_slice_5_8 = orm.OpSlice(self.batch_norm_op, orm.Slice(5, 3))
  batch_norm_op_group2 = orm.OpGroup(
      batch_norm_op_slice_5_8, omit_source_op_slices=[batch_norm_op_slice_5_8])
  batch_norm_op_slice_8_11 = orm.OpSlice(self.batch_norm_op, orm.Slice(8, 3))
  batch_norm_op_group3 = orm.OpGroup(
      batch_norm_op_slice_8_11,
      omit_source_op_slices=[batch_norm_op_slice_8_11])
  batch_norm_op_slice_11_18 = orm.OpSlice(self.batch_norm_op,
                                          orm.Slice(11, 7))
  batch_norm_op_group4 = orm.OpGroup(
      batch_norm_op_slice_11_18,
      omit_source_op_slices=[batch_norm_op_slice_11_18])

  # Map ops to slices. The op c2 (relu2_op) is composed of multiple slices.
  self.op_slice_dict = {
      self.relu1_op: [self.relu1_op_slice],
      self.relu2_op: [relu2_op_slice_0_3, relu2_op_slice_3_6],
      self.relu3_op: [self.relu3_op_slice],
      self.concat_op: [self.concat_op_slice],
      self.batch_norm_op: [batch_norm_op_slice],
  }

  # Map each slice to a group. Entries for the post-reslicing batch norm
  # slices are present up front so they resolve once slice_op runs.
  self.op_group_dict = {
      self.relu1_op_slice: self.relu1_op_group,
      relu2_op_slice_0_3: relu2_op_group1,
      relu2_op_slice_3_6: relu2_op_group2,
      self.relu3_op_slice: self.relu3_op_group,
      batch_norm_op_slice: batch_norm_op_group,
      batch_norm_op_slice_0_5: batch_norm_op_group1,
      batch_norm_op_slice_5_8: batch_norm_op_group2,
      batch_norm_op_slice_8_11: batch_norm_op_group3,
      batch_norm_op_slice_11_18: batch_norm_op_group4,
  }

  # Update op_slice_dict when an op is sliced, so later lookups see the new
  # slices (assumes the mock's get_op_slices reads self.op_slice_dict, as set
  # up outside this view -- confirm in setUp).
  def slice_op(op, _):
    if op == self.batch_norm_op:
      self.op_slice_dict[self.batch_norm_op] = [
          batch_norm_op_slice_0_5, batch_norm_op_slice_5_8,
          batch_norm_op_slice_8_11, batch_norm_op_slice_11_18
      ]
    if op == self.concat_op:
      self.op_slice_dict[self.concat_op] = [
          concat_op_slice_0_5, concat_op_slice_5_8, concat_op_slice_8_11,
          concat_op_slice_11_18
      ]

  self.mock_op_reg_manager.slice_op.side_effect = slice_op

  # Call handler to assign grouping.
  handler = concat_op_handler.ConcatOpHandler()
  handler.assign_grouping(self.concat_op, self.mock_op_reg_manager)

  # Verify manager looks up OpSlice for ops of interest.
  self.mock_op_reg_manager.get_op_slices.assert_has_calls(
      # Checking for ops to process.
      [
          mock.call(self.relu1_op),
          mock.call(self.relu2_op),
          mock.call(self.relu3_op),
          mock.call(self.batch_norm_op),
          # Initial slice data.
          mock.call(self.concat_op),
          mock.call(self.relu1_op),
          mock.call(self.relu2_op),
          mock.call(self.relu3_op),
          mock.call(self.batch_norm_op),
          # Reslicing.
          mock.call(self.relu1_op),
          mock.call(self.relu2_op),
          mock.call(self.relu3_op),
          mock.call(self.batch_norm_op),
          mock.call(self.concat_op),
          # Refreshing slice data.
          mock.call(self.relu1_op),
          mock.call(self.relu2_op),
          mock.call(self.relu3_op),
          mock.call(self.batch_norm_op),
          # Group concat op.
          mock.call(self.concat_op)
      ])

  # Verify manager slices ops that do not have aligned OpSlice sizes. relu2 is
  # already split into [3, 3], so it is not sliced.
  self.mock_op_reg_manager.slice_op.assert_has_calls([
      mock.call(self.batch_norm_op, [5, 3, 3, 7]),
      mock.call(self.concat_op, [5, 3, 3, 7])
  ])

  # Verify manager groups the new slices: each concat slice with its aligned
  # input slice and aligned batch norm slice.
  self.mock_op_reg_manager.group_op_slices.assert_has_calls([
      mock.call([
          concat_op_slice_0_5, self.relu1_op_slice, batch_norm_op_slice_0_5
      ]),
      mock.call([
          concat_op_slice_5_8, relu2_op_slice_0_3, batch_norm_op_slice_5_8
      ]),
      mock.call([
          concat_op_slice_8_11, relu2_op_slice_3_6, batch_norm_op_slice_8_11
      ]),
      mock.call([
          concat_op_slice_11_18, self.relu3_op_slice, batch_norm_op_slice_11_18
      ])
  ])
def testAssignGrouping_AllNeighborsGrouped_OutputSlicesNotAligned(self):
  # The output (batch norm) has sizes [9, 4, 5] which are not aligned. This
  # test verifies that the concat, batch norm, and Conv2D ops are sliced in
  # alignment.

  # Concat slices expected after reslicing to sizes [5, 4, 2, 2, 5].
  concat_op_slice_0_5 = orm.OpSlice(self.concat_op, orm.Slice(0, 5))
  concat_op_slice_5_9 = orm.OpSlice(self.concat_op, orm.Slice(5, 4))
  concat_op_slice_9_11 = orm.OpSlice(self.concat_op, orm.Slice(9, 2))
  concat_op_slice_11_13 = orm.OpSlice(self.concat_op, orm.Slice(11, 2))
  concat_op_slice_13_18 = orm.OpSlice(self.concat_op, orm.Slice(13, 5))

  # relu2 (size 6) must be split into [4, 2].
  relu2_op_slice_0_4 = orm.OpSlice(self.relu2_op, orm.Slice(0, 4))
  relu2_op_slice_4_6 = orm.OpSlice(self.relu2_op, orm.Slice(4, 2))

  # relu3 (size 7) must be split into [2, 5].
  relu3_op_slice_0_2 = orm.OpSlice(self.relu3_op, orm.Slice(0, 2))
  relu3_op_slice_2_7 = orm.OpSlice(self.relu3_op, orm.Slice(2, 5))

  # Initial batch norm slices: sizes [9, 4, 5].
  batch_norm_op_slice_0_9 = orm.OpSlice(self.batch_norm_op, orm.Slice(0, 9))
  batch_norm_op_group1 = orm.OpGroup(
      batch_norm_op_slice_0_9, omit_source_op_slices=[batch_norm_op_slice_0_9])
  batch_norm_op_slice_9_13 = orm.OpSlice(self.batch_norm_op, orm.Slice(9, 4))
  batch_norm_op_group2 = orm.OpGroup(
      batch_norm_op_slice_9_13,
      omit_source_op_slices=[batch_norm_op_slice_9_13])
  batch_norm_op_slice_13_18 = orm.OpSlice(self.batch_norm_op,
                                          orm.Slice(13, 5))
  batch_norm_op_group3 = orm.OpGroup(
      batch_norm_op_slice_13_18,
      omit_source_op_slices=[batch_norm_op_slice_13_18])

  # Batch norm slices after reslicing: sizes [5, 4, 2, 2, 5].
  batch_norm_op_slice_0_5 = orm.OpSlice(self.batch_norm_op, orm.Slice(0, 5))
  batch_norm_op_group4 = orm.OpGroup(
      batch_norm_op_slice_0_5, omit_source_op_slices=[batch_norm_op_slice_0_5])
  batch_norm_op_slice_5_9 = orm.OpSlice(self.batch_norm_op, orm.Slice(5, 4))
  batch_norm_op_group5 = orm.OpGroup(
      batch_norm_op_slice_5_9, omit_source_op_slices=[batch_norm_op_slice_5_9])
  batch_norm_op_slice_9_11 = orm.OpSlice(self.batch_norm_op, orm.Slice(9, 2))
  batch_norm_op_group6 = orm.OpGroup(
      batch_norm_op_slice_9_11,
      omit_source_op_slices=[batch_norm_op_slice_9_11])
  batch_norm_op_slice_11_13 = orm.OpSlice(self.batch_norm_op,
                                          orm.Slice(11, 2))
  batch_norm_op_group7 = orm.OpGroup(
      batch_norm_op_slice_11_13,
      omit_source_op_slices=[batch_norm_op_slice_11_13])
  # The trailing 13-18 slice covers the same channels before and after
  # reslicing, so it starts at 13 (previously mis-built as Slice(11, 5)).
  batch_norm_op_slice_13_18 = orm.OpSlice(self.batch_norm_op,
                                          orm.Slice(13, 5))
  batch_norm_op_group8 = orm.OpGroup(
      batch_norm_op_slice_13_18,
      omit_source_op_slices=[batch_norm_op_slice_13_18])

  # Map ops to slices. Batch norm op is composed of multiple slices.
  self.op_slice_dict = {
      self.relu1_op: [self.relu1_op_slice],
      self.relu2_op: [self.relu2_op_slice],
      self.relu3_op: [self.relu3_op_slice],
      self.concat_op: [self.concat_op_slice],
      self.batch_norm_op: [
          batch_norm_op_slice_0_9, batch_norm_op_slice_9_13,
          batch_norm_op_slice_13_18
      ],
  }

  # Map each slice to a group. Note: batch_norm_op_slice_13_18 appears twice
  # as a key; in a dict display the later entry (group8) wins.
  self.op_group_dict = {
      self.relu1_op_slice: self.relu1_op_group,
      self.relu2_op_slice: self.relu2_op_group,
      self.relu3_op_slice: self.relu3_op_group,
      batch_norm_op_slice_0_9: batch_norm_op_group1,
      batch_norm_op_slice_9_13: batch_norm_op_group2,
      batch_norm_op_slice_13_18: batch_norm_op_group3,
      batch_norm_op_slice_0_5: batch_norm_op_group4,
      batch_norm_op_slice_5_9: batch_norm_op_group5,
      batch_norm_op_slice_9_11: batch_norm_op_group6,
      batch_norm_op_slice_11_13: batch_norm_op_group7,
      batch_norm_op_slice_13_18: batch_norm_op_group8,
  }

  # Update op_slice_dict when an op is sliced, so later lookups see the new
  # slices.
  def slice_op(op, _):
    if op == self.batch_norm_op:
      self.op_slice_dict[self.batch_norm_op] = [
          batch_norm_op_slice_0_5, batch_norm_op_slice_5_9,
          batch_norm_op_slice_9_11, batch_norm_op_slice_11_13,
          batch_norm_op_slice_13_18
      ]
    if op == self.concat_op:
      self.op_slice_dict[self.concat_op] = [
          concat_op_slice_0_5, concat_op_slice_5_9, concat_op_slice_9_11,
          concat_op_slice_11_13, concat_op_slice_13_18
      ]
    if op == self.relu2_op:
      self.op_slice_dict[self.relu2_op] = [
          relu2_op_slice_0_4, relu2_op_slice_4_6
      ]
    if op == self.relu3_op:
      self.op_slice_dict[self.relu3_op] = [
          relu3_op_slice_0_2, relu3_op_slice_2_7
      ]

  self.mock_op_reg_manager.slice_op.side_effect = slice_op

  # Call handler to assign grouping.
  handler = concat_op_handler.ConcatOpHandler()
  handler.assign_grouping(self.concat_op, self.mock_op_reg_manager)

  # Verify manager looks up OpSlice for ops of interest.
  self.mock_op_reg_manager.get_op_slices.assert_has_calls(
      # Checking for ops to process.
      [
          mock.call(self.relu1_op),
          mock.call(self.relu2_op),
          mock.call(self.relu3_op),
          mock.call(self.batch_norm_op),
          # Initial slice data.
          mock.call(self.concat_op),
          mock.call(self.relu1_op),
          mock.call(self.relu2_op),
          mock.call(self.relu3_op),
          mock.call(self.batch_norm_op),
          # Reslicing.
          mock.call(self.relu1_op),
          mock.call(self.relu2_op),
          mock.call(self.relu3_op),
          mock.call(self.batch_norm_op),
          mock.call(self.concat_op),
          # Refreshing slice data.
          mock.call(self.relu1_op),
          mock.call(self.relu2_op),
          mock.call(self.relu3_op),
          mock.call(self.batch_norm_op),
          # Group concat op.
          mock.call(self.concat_op)
      ])

  # Verify manager slices ops that do not have aligned OpSlice sizes.
  self.mock_op_reg_manager.slice_op.assert_has_calls([
      mock.call(self.relu2_op, [4, 2]),
      mock.call(self.relu3_op, [2, 5]),
      mock.call(self.batch_norm_op, [5, 4, 2, 2, 5]),
      mock.call(self.concat_op, [5, 4, 2, 2, 5])
  ])

  # Verify manager groups the new slices.
  self.mock_op_reg_manager.group_op_slices.assert_has_calls([
      mock.call([
          concat_op_slice_0_5, self.relu1_op_slice, batch_norm_op_slice_0_5
      ]),
      mock.call([
          concat_op_slice_5_9, relu2_op_slice_0_4, batch_norm_op_slice_5_9
      ]),
      mock.call([
          concat_op_slice_9_11, relu2_op_slice_4_6, batch_norm_op_slice_9_11
      ]),
      mock.call([
          concat_op_slice_11_13, relu3_op_slice_0_2, batch_norm_op_slice_11_13
      ]),
      mock.call([
          concat_op_slice_13_18, relu3_op_slice_2_7, batch_norm_op_slice_13_18
      ])
  ])
def testAssignGrouping_AllNeighborsGrouped_SlicesAligned_SameGroup(self):
  # Every aligned slice already shares a group with its counterparts, so the
  # handler must neither slice nor group anything.
  relu_ops = [self.relu1_op, self.relu2_op, self.relu3_op]
  relu_slices = [self.relu1_op_slice, self.relu2_op_slice, self.relu3_op_slice]
  concat_slices = [
      self.concat_op_slice_0_5, self.concat_op_slice_5_11,
      self.concat_op_slice_11_18
  ]
  batch_norm_slices = [
      self.batch_norm_op_slice_0_5, self.batch_norm_op_slice_5_11,
      self.batch_norm_op_slice_11_18
  ]
  shared_groups = [
      self.batch_norm_op_group1, self.batch_norm_op_group2,
      self.batch_norm_op_group3
  ]

  # ReLUs are single slices; concat and batch norm are already split.
  self.op_slice_dict = {
      op: [op_slice] for op, op_slice in zip(relu_ops, relu_slices)
  }
  self.op_slice_dict[self.concat_op] = list(concat_slices)
  self.op_slice_dict[self.batch_norm_op] = list(batch_norm_slices)

  # The i-th slice of every op maps to the same (batch norm) group.
  self.op_group_dict = {}
  for op_slices in (relu_slices, concat_slices, batch_norm_slices):
    self.op_group_dict.update(zip(op_slices, shared_groups))

  concat_handler = concat_op_handler.ConcatOpHandler()
  concat_handler.assign_grouping(self.concat_op, self.mock_op_reg_manager)

  manager = self.mock_op_reg_manager
  neighbors = relu_ops + [self.batch_norm_op]
  lookup_order = (
      neighbors  # Checking for ops to process.
      + [self.concat_op] + neighbors  # Initial slice data.
      + neighbors + [self.concat_op]  # Reslicing.
      + neighbors  # Refreshing slice data.
      + [self.concat_op])  # Grouping the concat op.
  manager.get_op_slices.assert_has_calls(
      [mock.call(op) for op in lookup_order])

  # Already aligned and grouped: no slicing, no grouping.
  manager.slice_op.assert_not_called()
  manager.group_op_slices.assert_not_called()
def testAssignGrouping_AllNeighborsGrouped_SlicesAligned(self):
  # The output batch norm (size 18) is already split into [5, 6, 7], matching
  # the input sizes, so only the concat op needs slicing before each aligned
  # triple (concat, input, batch norm) is grouped.
  relu_ops = [self.relu1_op, self.relu2_op, self.relu3_op]
  relu_slices = [self.relu1_op_slice, self.relu2_op_slice, self.relu3_op_slice]
  relu_groups = [self.relu1_op_group, self.relu2_op_group, self.relu3_op_group]
  batch_norm_slices = [
      self.batch_norm_op_slice_0_5, self.batch_norm_op_slice_5_11,
      self.batch_norm_op_slice_11_18
  ]
  batch_norm_groups = [
      self.batch_norm_op_group1, self.batch_norm_op_group2,
      self.batch_norm_op_group3
  ]
  concat_slices = [
      self.concat_op_slice_0_5, self.concat_op_slice_5_11,
      self.concat_op_slice_11_18
  ]

  self.op_slice_dict = {
      op: [op_slice] for op, op_slice in zip(relu_ops, relu_slices)
  }
  self.op_slice_dict[self.concat_op] = [self.concat_op_slice]
  self.op_slice_dict[self.batch_norm_op] = list(batch_norm_slices)

  # Inputs and outputs are grouped; the concat slices are not.
  self.op_group_dict = dict(zip(relu_slices, relu_groups))
  self.op_group_dict.update(zip(batch_norm_slices, batch_norm_groups))

  concat_handler = concat_op_handler.ConcatOpHandler()
  concat_handler.assign_grouping(self.concat_op, self.mock_op_reg_manager)

  manager = self.mock_op_reg_manager
  neighbors = relu_ops + [self.batch_norm_op]
  lookup_order = (
      neighbors  # Checking for ops to process.
      + [self.concat_op] + neighbors  # Initial slice data.
      + neighbors + [self.concat_op]  # Reslicing.
      + neighbors  # Refreshing slice data.
      + [self.concat_op])  # Grouping the concat op.
  manager.get_op_slices.assert_has_calls(
      [mock.call(op) for op in lookup_order])

  # Only the concat op is sliced.
  manager.slice_op.assert_called_once_with(self.concat_op, [5, 6, 7])

  # Each concat slice is grouped with its aligned input and output slices.
  manager.group_op_slices.assert_has_calls([
      mock.call([concat_slice, relu_slice, batch_norm_slice])
      for concat_slice, relu_slice, batch_norm_slice in zip(
          concat_slices, relu_slices, batch_norm_slices)
  ])