def testIsBroadcast(self):
  # Exercises GroupingOpHandler._is_broadcast across three cases: size != 1,
  # size == 1 but ungrouped, and size == 1 and grouped.
  handler = grouping_op_handler.GroupingOpHandler()
  self.op_group_dict = {}

  # Size is not 1.
  self.assertFalse(
      handler._is_broadcast(self.batch_norm_op, self.mock_op_reg_manager))

  # Size is 1 but op is not grouped.
  ungrouped_broadcast_input = tf.zeros([2, 4, 4, 1])
  # OpSlice wraps the tf.Operation (matching the grouped case below and the
  # op_slice_dict key), not the output tensor.
  ungrouped_broadcast_input_slice = orm.OpSlice(
      ungrouped_broadcast_input.op, orm.Slice(0, 1))
  self.op_slice_dict[ungrouped_broadcast_input.op] = [
      ungrouped_broadcast_input_slice]
  self.assertFalse(
      handler._is_broadcast(ungrouped_broadcast_input.op,
                            self.mock_op_reg_manager))

  # Size is 1 and op is grouped.
  broadcast_input = tf.zeros([2, 4, 4, 1])
  broadcast_input_slice = orm.OpSlice(broadcast_input.op, orm.Slice(0, 1))
  self.op_slice_dict[broadcast_input.op] = [broadcast_input_slice]
  broadcast_input_group = orm.OpGroup(broadcast_input_slice)
  self.op_group_dict[broadcast_input_slice] = broadcast_input_group
  self.assertTrue(
      handler._is_broadcast(broadcast_input.op, self.mock_op_reg_manager))
def testAssignGrouping_AllInputsGrouped(self):
  # Every input op (conv, gamma, beta) already has a group. Among the outputs
  # only ReLU is grouped; mean_op and std_op are not.
  self.op_group_dict = {
      self.batch_norm_op_slice: self.batch_norm_op_group,
      self.conv_op_slice: self.conv_op_group,
      self.relu_op_slice: self.relu_op_group,
      self.gamma_op_slice: self.conv_op_group,
      self.beta_op_slice: self.conv_op_group,
  }

  grouping_op_handler.GroupingOpHandler().assign_grouping(
      self.batch_norm_op, self.mock_op_reg_manager)

  # Expected get_op_slices lookups, listed phase by phase.
  input_ops = [self.conv_op, self.gamma_op, self.beta_op]
  output_ops = [self.relu_op, self.mean_op, self.std_op]
  neighbor_ops = input_ops + output_ops
  phases = [
      neighbor_ops,  # Checking for ops to process.
      [self.batch_norm_op] + neighbor_ops,  # Initial slice data.
      input_ops + [self.batch_norm_op] + output_ops,  # Reslicing.
      neighbor_ops,  # Refreshing slice data.
      [self.batch_norm_op],  # Grouping the batch norm op.
  ]
  self.mock_op_reg_manager.get_op_slices.assert_has_calls(
      [mock.call(op) for phase in phases for op in phase])

  # Batch norm is grouped exactly once, together with its input slices.
  self.mock_op_reg_manager.group_op_slices.assert_called_once_with(
      [self.batch_norm_op_slice, self.conv_op_slice, self.gamma_op_slice,
       self.beta_op_slice])

  # The ungrouped outputs (mean_op, std_op) are queued for processing.
  self.mock_op_reg_manager.process_ops.assert_called_once_with(
      [self.mean_op, self.std_op])
  self.mock_op_reg_manager.process_ops_last.assert_not_called()
def _get_base_op_hander_dicts():
  """Builds the op-type -> OpHandler mapping shared by all regularizers.

  Returns:
    A collections.defaultdict keyed by op type name. Accessing (with `[]`) an
    op type that has no explicit entry inserts and returns a new
    grouping_op_handler.GroupingOpHandler(). Every explicit entry, including
    each name in RESIZE_OP_NAMES, gets its own handler instance.
  """
  make_leaf = leaf_op_handler.LeafOpHandler
  make_non_passthrough = (
      output_non_passthrough_op_handler.OutputNonPassthroughOpHandler)

  handlers = collections.defaultdict(grouping_op_handler.GroupingOpHandler)
  handlers.update([
      ('ConcatV2', concat_op_handler.ConcatOpHandler()),
      ('DepthToSpace', make_non_passthrough()),
      ('DepthwiseConv2dNative',
       depthwise_convolution_op_handler.DepthwiseConvolutionOpHandler()),
      ('ExpandDims', make_non_passthrough()),
      ('RandomUniform', make_leaf()),
      ('Reshape', make_leaf()),
      ('Shape', make_leaf()),
      ('SpaceToDepth', make_leaf()),
      ('StridedSlice', make_leaf()),
      ('TensorArrayGatherV3', make_leaf()),
      ('Transpose', make_non_passthrough()),
  ])

  # Resize* ops: the second input might be a tensor, which would result in an
  # error, so the handler is constructed with [0] (presumably restricting
  # grouping to the first input -- confirm against GroupingOpHandler).
  handlers.update(
      (resize_name, grouping_op_handler.GroupingOpHandler([0]))
      for resize_name in RESIZE_OP_NAMES)
  return handlers
def testAssignGrouping_AllNeighborsGroupedSameGroup(self):
  # Batch norm and every neighboring op already share the batch norm group.
  shared_group = self.batch_norm_op_group
  self.op_group_dict = {
      op_slice: shared_group
      for op_slice in (self.batch_norm_op_slice, self.conv_op_slice,
                       self.relu_op_slice, self.gamma_op_slice,
                       self.beta_op_slice, self.mean_op_slice,
                       self.std_op_slice)
  }

  grouping_op_handler.GroupingOpHandler().assign_grouping(
      self.batch_norm_op, self.mock_op_reg_manager)

  # Expected get_op_slices lookups, listed phase by phase.
  input_ops = [self.conv_op, self.gamma_op, self.beta_op]
  output_ops = [self.relu_op, self.mean_op, self.std_op]
  neighbor_ops = input_ops + output_ops
  phases = [
      neighbor_ops,  # Checking for ops to process.
      [self.batch_norm_op] + neighbor_ops,  # Initial slice data.
      input_ops + [self.batch_norm_op] + output_ops,  # Reslicing.
      neighbor_ops,  # Refreshing slice data.
      [self.batch_norm_op],  # Grouping the batch norm op.
  ]
  self.mock_op_reg_manager.get_op_slices.assert_has_calls(
      [mock.call(op) for phase in phases for op in phase])

  # Everything is already in one group: no new grouping and nothing queued.
  self.mock_op_reg_manager.group_op_slices.assert_not_called()
  self.mock_op_reg_manager.process_ops.assert_not_called()
  self.mock_op_reg_manager.process_ops_last.assert_not_called()
def testAssignGrouping_AllOutputsGrouped(self):
  # Every output op has a group; among the inputs, beta_op has none.
  self.op_group_dict = {
      self.batch_norm_op_slice: self.batch_norm_op_group,
      self.conv_op_slice: self.conv_op_group,
      self.relu_op_slice: self.relu_op_group,
      self.gamma_op_slice: self.conv_op_group,
      self.mean_op_slice: self.relu_op_group,
      self.std_op_slice: self.relu_op_group,
  }

  grouping_op_handler.GroupingOpHandler().assign_grouping(
      self.batch_norm_op, self.mock_op_reg_manager)

  # Expected get_op_slices lookups, listed phase by phase. No final lookup of
  # the batch norm op is expected because no grouping happens.
  input_ops = [self.conv_op, self.gamma_op, self.beta_op]
  output_ops = [self.relu_op, self.mean_op, self.std_op]
  neighbor_ops = input_ops + output_ops
  phases = [
      neighbor_ops,  # Checking for ops to process.
      [self.batch_norm_op] + neighbor_ops,  # Initial slice data.
      input_ops + [self.batch_norm_op] + output_ops,  # Reslicing.
      neighbor_ops,  # Refreshing slice data.
  ]
  self.mock_op_reg_manager.get_op_slices.assert_has_calls(
      [mock.call(op) for phase in phases for op in phase])

  # No grouping yet: the ungrouped beta_op is queued for processing and the
  # batch norm op is handed to process_ops_last.
  self.mock_op_reg_manager.group_op_slices.assert_not_called()
  self.mock_op_reg_manager.process_ops.assert_called_once_with(
      [self.beta_op])
  self.mock_op_reg_manager.process_ops_last.assert_called_once_with(
      [self.batch_norm_op])
def testGetInputOutputOpSlices(self):
  # _get_input_output_op_slices returns, per op and in the given order, the
  # list of OpSlices the manager reports for that op.
  input_ops = [self.conv_op, self.gamma_op, self.beta_op]
  output_ops = [self.mean_op, self.std_op, self.relu_op]
  expected = (
      [[self.conv_op_slice], [self.gamma_op_slice], [self.beta_op_slice]],
      [[self.mean_op_slice], [self.std_op_slice], [self.relu_op_slice]],
  )

  handler = grouping_op_handler.GroupingOpHandler()
  actual = handler._get_input_output_op_slices(
      input_ops, output_ops, self.mock_op_reg_manager)
  self.assertEqual(expected, actual)
def testAssignGrouping_NoNeighborGroups(self):
  # Start from a state in which no op belongs to any group.
  self.op_group_dict = {}

  grouping_op_handler.GroupingOpHandler().assign_grouping(
      self.batch_norm_op, self.mock_op_reg_manager)

  # Expected get_op_slices lookups, listed phase by phase.
  input_ops = [self.conv_op, self.gamma_op, self.beta_op]
  output_ops = [self.relu_op]
  phases = [
      input_ops + output_ops,  # Checking for ops to process.
      [self.batch_norm_op] + input_ops + output_ops,  # Initial slice data.
      input_ops + [self.batch_norm_op] + output_ops,  # Reslicing.
      input_ops + output_ops,  # Refreshing slice data.
  ]
  self.mock_op_reg_manager.get_op_slices.assert_has_calls(
      [mock.call(op) for phase in phases for op in phase])

  # Nothing is grouped. The ReLU, Conv2D, gamma and beta ops are queued for
  # processing, and the batch norm op is passed to process_ops_last.
  self.mock_op_reg_manager.group_op_slices.assert_not_called()
  self.mock_op_reg_manager.process_ops.assert_called_once_with(
      [self.relu_op, self.conv_op, self.gamma_op, self.beta_op])
  self.mock_op_reg_manager.process_ops_last.assert_called_once_with(
      [self.batch_norm_op])
def assign_grouping(self, op, op_reg_manager):
  """Assigns grouping to the given concat op and updates the manager.

  If the concat axis is neither -1 nor the last dimension of the first input,
  the work is delegated to a plain GroupingOpHandler. Otherwise the op is
  treated as a last-dimension concat:
  1. Its inputs and same-size passthrough outputs are collected.
  2. Their OpSlices are aligned with the concat op's slices, and those ops
     are resliced as needed.
  3. The aligned slices are grouped.
  Ungrouped or size-mismatched neighbors are passed on for processing.

  Args:
    op: tf.Operation to assign grouping to.
    op_reg_manager: OpRegularizerManager to keep track of the grouping.
  """
  concat_axis = _get_concat_op_axis(op)
  # Need to figure out the rank to know if axis is last.
  rank = len(op.inputs[0].shape)  # Rank of the first input.

  if concat_axis != -1 and concat_axis != rank - 1:
    # Concat is actually grouping inputs! Presumably a concat along a
    # non-last axis leaves the last dimension unchanged, so the generic
    # grouping handler applies -- confirm against _get_concat_op_axis.
    handler = grouping_op_handler.GroupingOpHandler()
    handler.assign_grouping(op, op_reg_manager)
    return

  # If concat is of the last dimension, this is a `standard` concat.
  # TODO(a1): Consider refactoring this duplicated logic.
  # Check if all input ops have groups, or tell the manager to process them.
  input_ops = op_handler_util.get_input_ops(op, op_reg_manager)
  input_ops_without_group = op_handler_util.get_ops_without_groups(
      input_ops, op_reg_manager)

  # Check if all output ops have groups, or tell the manager to process them.
  # Note: this is computed before non-passthrough ops are filtered out below.
  output_ops = op_handler_util.get_output_ops(op, op_reg_manager)
  output_ops_without_group = op_handler_util.get_ops_without_groups(
      output_ops, op_reg_manager)

  # Remove non-passthrough ops from outputs ops to group with.
  output_ops = op_handler_util.remove_non_passthrough_ops(
      output_ops, op_reg_manager)

  # Only group with output ops that have the same size. Process the ops that
  # have mismatched size. Inputs are not size-filtered here; they are resliced
  # with the concat-specific helper further below.
  input_ops_to_group = input_ops
  input_ops_to_process = input_ops_without_group
  output_ops_to_group, output_ops_to_process = (
      op_handler_util.separate_same_size_ops(op, output_ops))

  # Also process ungrouped ops.
  output_ops_to_process.extend(output_ops_without_group)

  # Populate OpSlice data for all relevant ops.
  concat_op_slices = op_reg_manager.get_op_slices(op)
  input_op_slices, output_op_slices = self._get_input_output_op_slices(
      input_ops_to_group, output_ops_to_group, op_reg_manager)

  # Align op slices sizes if needed.
  aligned_op_slice_sizes = op_handler_util.get_aligned_op_slice_sizes(
      concat_op_slices, input_op_slices, output_op_slices)
  # Inputs use the concat-specific reslicing helper; the outputs and the
  # concat op itself use the generic one.
  op_handler_util.reslice_concat_ops(input_ops_to_group,
                                     aligned_op_slice_sizes, op_reg_manager)
  op_handler_util.reslice_ops(output_ops_to_group + [op],
                              aligned_op_slice_sizes, op_reg_manager)

  # Repopulate OpSlice data, as ops may have been resliced.
  input_op_slices, output_op_slices = self._get_input_output_op_slices(
      input_ops_to_group, output_ops_to_group, op_reg_manager)

  # Group aligned OpSlice.
  op_handler_util.group_aligned_input_output_slices(
      op, input_ops_to_process, output_ops_to_process, input_op_slices,
      output_op_slices, aligned_op_slice_sizes, op_reg_manager)
def testAssignGrouping_NonPassthroughOutputsSkipped(self):
  # Mark ReLU as the only non-passthrough op to show that batch norm does not
  # group with it.
  def _is_passthrough(candidate):
    if candidate == self.relu_op:
      return False
    return True

  self.mock_op_reg_manager.is_passthrough.side_effect = _is_passthrough

  # Every neighbor op already has a group.
  self.op_group_dict = {
      self.batch_norm_op_slice: self.batch_norm_op_group,
      self.conv_op_slice: self.conv_op_group,
      self.relu_op_slice: self.relu_op_group,
      self.gamma_op_slice: self.conv_op_group,
      self.beta_op_slice: self.conv_op_group,
      self.mean_op_slice: self.relu_op_group,
      self.std_op_slice: self.relu_op_group,
  }

  grouping_op_handler.GroupingOpHandler().assign_grouping(
      self.batch_norm_op, self.mock_op_reg_manager)

  # Expected get_op_slices lookups, listed phase by phase. ReLU is looked up
  # only while checking for ops to process.
  input_ops = [self.conv_op, self.gamma_op, self.beta_op]
  passthrough_outputs = [self.mean_op, self.std_op]
  phases = [
      # Checking for ops to process.
      input_ops + [self.relu_op] + passthrough_outputs,
      # Initial slice data.
      [self.batch_norm_op] + input_ops + passthrough_outputs,
      # Reslicing.
      input_ops + [self.batch_norm_op] + passthrough_outputs,
      # Refreshing slice data.
      input_ops + passthrough_outputs,
      # Grouping the batch norm op.
      [self.batch_norm_op],
  ]
  self.mock_op_reg_manager.get_op_slices.assert_has_calls(
      [mock.call(op) for phase in phases for op in phase])

  # Batch norm is grouped with its passthrough outputs and with its inputs;
  # ReLU's slice is in neither grouping.
  self.mock_op_reg_manager.group_op_slices.assert_has_calls([
      mock.call([self.batch_norm_op_slice, self.mean_op_slice,
                 self.std_op_slice]),
      mock.call([self.batch_norm_op_slice, self.conv_op_slice,
                 self.gamma_op_slice, self.beta_op_slice]),
  ])

  # No further ops are queued for processing.
  self.mock_op_reg_manager.process_ops.assert_not_called()
  self.mock_op_reg_manager.process_ops_last.assert_not_called()