def assign_grouping(self, op, op_reg_manager):
    """Groups `op` with its input ops and same-sized passthrough output ops.

    Input ops, passthrough output ops of matching size, and `op` itself are
    resliced to a common set of aligned slice sizes, then grouped slice by
    slice via the manager.

    Args:
      op: tf.Operation to assign grouping to.
      op_reg_manager: OpRegularizerManager that records slicing and grouping.
    """
    # TODO(a1): Consider refactoring this duplicated logic.
    # Inputs: all of them are grouped with; ungrouped ones are handed back to
    # the manager for processing.
    inputs = op_handler_util.get_input_ops(op, op_reg_manager)
    ungrouped_inputs = op_handler_util.get_ops_without_groups(
        inputs, op_reg_manager)

    # Outputs: record which are ungrouped *before* dropping non-passthrough
    # ops, so that those still get processed.
    outputs = op_handler_util.get_output_ops(op, op_reg_manager)
    ungrouped_outputs = op_handler_util.get_ops_without_groups(
        outputs, op_reg_manager)
    passthrough_outputs = op_handler_util.remove_non_passthrough_ops(
        outputs, op_reg_manager)

    # Only outputs whose size matches `op` are grouped with; mismatched and
    # ungrouped outputs are queued for processing instead.
    outputs_to_group, outputs_to_process = (
        op_handler_util.separate_same_size_ops(op, passthrough_outputs))
    outputs_to_process.extend(ungrouped_outputs)

    def fetch_input_output_slices():
        # Re-read on demand: the reslicing below may replace the OpSlices.
        return self._get_input_output_op_slices(
            inputs, outputs_to_group, op_reg_manager)

    own_slices = op_reg_manager.get_op_slices(op)
    input_slices, output_slices = fetch_input_output_slices()

    # Derive one common slicing and reslice every participating op to it.
    aligned_sizes = op_handler_util.get_aligned_op_slice_sizes(
        own_slices, input_slices, output_slices)
    op_handler_util.reslice_concat_ops(inputs, aligned_sizes, op_reg_manager)
    ops_to_reslice = outputs_to_group + [op]
    op_handler_util.reslice_ops(ops_to_reslice, aligned_sizes, op_reg_manager)

    # Group the now-aligned slices using freshly fetched OpSlice data.
    input_slices, output_slices = fetch_input_output_slices()
    op_handler_util.group_aligned_input_output_slices(
        op, ungrouped_inputs, outputs_to_process, input_slices,
        output_slices, aligned_sizes, op_reg_manager)
def testResliceConcatOps_Aligned(self):
    """Sizes already matching each op's slicing must not trigger reslicing."""
    concat_inputs = [self.relu2_op, self.relu3_op, self.relu4_op]
    existing_slices = [
        self.relu2_op_slice, self.relu3_op_slice, self.relu4_op_slice]

    # Each op is backed by exactly one OpSlice.
    self.op_slice_dict = {
        concat_input: [op_slice]
        for concat_input, op_slice in zip(concat_inputs, existing_slices)
    }

    op_handler_util.reslice_concat_ops(
        concat_inputs, [5, 6, 7], self.mock_op_reg_manager)

    # Nothing is out of alignment, so the manager is never asked to slice.
    self.mock_op_reg_manager.slice_op.assert_not_called()
def testResliceConcatOps_NotAligned(self):
    """Ops whose slicing disagrees with the aligned sizes are resliced.

    With aligned sizes [5, 4, 2, 2, 5], relu3_op is expected to be resliced
    into [4, 2] and relu4_op into [2, 5]. relu2_op should be left alone
    (assumes setUp gives relu2_op a single size-5 slice -- TODO confirm).
    """
    relu3_op_slice_0_3 = orm.OpSlice(self.relu3_op, orm.Slice(0, 3))
    relu3_op_slice_3_6 = orm.OpSlice(self.relu3_op, orm.Slice(3, 3))

    # Map ops to slices. relu3_op is composed of two OpSlices, [0, 3) and
    # [3, 6); relu2_op and relu4_op each have a single OpSlice.
    self.op_slice_dict = {
        self.relu2_op: [self.relu2_op_slice],
        self.relu3_op: [relu3_op_slice_0_3, relu3_op_slice_3_6],
        self.relu4_op: [self.relu4_op_slice],
    }

    op_handler_util.reslice_concat_ops(
        [self.relu2_op, self.relu3_op, self.relu4_op],
        [5, 4, 2, 2, 5], self.mock_op_reg_manager)

    # Verify the manager slices exactly relu3_op and relu4_op, in order.
    # An exact comparison (unlike assert_has_calls) also fails if relu2_op,
    # which needs no reslicing, is sliced anyway.
    self.assertEqual(
        [mock.call(self.relu3_op, [4, 2]),
         mock.call(self.relu4_op, [2, 5])],
        self.mock_op_reg_manager.slice_op.call_args_list)
def assign_grouping(self, op, op_reg_manager):
    """Groups a concat op with its inputs and outputs via the manager.

    A concat along an axis other than the last one is delegated to
    GroupingOpHandler. A concat along the last axis is aligned and grouped
    here.

    Args:
      op: tf.Operation to assign grouping to.
      op_reg_manager: OpRegularizerManager that records slicing and grouping.
    """
    concat_axis = _get_concat_op_axis(op)
    # The rank of the first input decides whether `concat_axis` is the last
    # dimension.
    rank = len(op.inputs[0].shape)
    concats_last_axis = concat_axis == -1 or concat_axis == rank - 1
    if not concats_last_axis:
        # Concat is actually grouping inputs, so reuse the grouping handler.
        grouping_op_handler.GroupingOpHandler().assign_grouping(
            op, op_reg_manager)
        return

    # Standard concat along the last dimension.
    # TODO(a1): Consider refactoring this duplicated logic.
    # Inputs: all of them are grouped with; ungrouped ones are handed back to
    # the manager for processing.
    inputs = op_handler_util.get_input_ops(op, op_reg_manager)
    ungrouped_inputs = op_handler_util.get_ops_without_groups(
        inputs, op_reg_manager)

    # Outputs: record which are ungrouped *before* dropping non-passthrough
    # ops, so that those still get processed.
    outputs = op_handler_util.get_output_ops(op, op_reg_manager)
    ungrouped_outputs = op_handler_util.get_ops_without_groups(
        outputs, op_reg_manager)
    passthrough_outputs = op_handler_util.remove_non_passthrough_ops(
        outputs, op_reg_manager)

    # Only outputs whose size matches `op` are grouped with; mismatched and
    # ungrouped outputs are queued for processing instead.
    outputs_to_group, outputs_to_process = (
        op_handler_util.separate_same_size_ops(op, passthrough_outputs))
    outputs_to_process.extend(ungrouped_outputs)

    def fetch_input_output_slices():
        # Re-read on demand: the reslicing below may replace the OpSlices.
        return self._get_input_output_op_slices(
            inputs, outputs_to_group, op_reg_manager)

    concat_slices = op_reg_manager.get_op_slices(op)
    input_slices, output_slices = fetch_input_output_slices()

    # Derive one common slicing and reslice every participating op to it.
    aligned_sizes = op_handler_util.get_aligned_op_slice_sizes(
        concat_slices, input_slices, output_slices)
    op_handler_util.reslice_concat_ops(inputs, aligned_sizes, op_reg_manager)
    ops_to_reslice = outputs_to_group + [op]
    op_handler_util.reslice_ops(ops_to_reslice, aligned_sizes, op_reg_manager)

    # Group the now-aligned slices using freshly fetched OpSlice data.
    input_slices, output_slices = fetch_input_output_slices()
    op_handler_util.group_aligned_input_output_slices(
        op, ungrouped_inputs, outputs_to_process, input_slices,
        output_slices, aligned_sizes, op_reg_manager)