def assign_grouping(self, op, op_reg_manager):
    """Assign grouping to the given op and update the manager.

    Input ops and passthrough output ops that have the same size as `op`
    are resliced to a common alignment together with `op`. The aligned
    slices are then handed to
    `op_handler_util.group_aligned_input_output_slices`, along with the
    neighboring ops that still need to be processed.

    Args:
      op: tf.Operation to assign grouping to.
      op_reg_manager: OpRegularizerManager to keep track of the grouping.
    """
    # Collect neighbors on both sides and note which ones have no group yet.
    producers = op_handler_util.get_input_ops(op, op_reg_manager)
    ungrouped_producers = op_handler_util.get_ops_without_groups(
        producers, op_reg_manager)
    consumers = op_handler_util.get_output_ops(op, op_reg_manager)
    ungrouped_consumers = op_handler_util.get_ops_without_groups(
        consumers, op_reg_manager)

    # Non-passthrough outputs are never grouped with this op.
    consumers = op_handler_util.remove_non_passthrough_ops(
        consumers, op_reg_manager)

    # Same-size neighbors get grouped; mismatched ones get processed instead.
    producers_to_group, mismatched_producers = (
        op_handler_util.separate_same_size_ops(op, producers))
    consumers_to_group, consumers_to_process = (
        op_handler_util.separate_same_size_ops(op, consumers))

    # Mismatched inputs that are broadcast ops are dropped.
    producers_to_process = []
    for producer in mismatched_producers:
        if self._is_broadcast(producer, op_reg_manager):
            continue
        producers_to_process.append(producer)

    # Ungrouped neighbors must be processed as well; avoid duplicates.
    for producer in ungrouped_producers:
        if producer in producers_to_process:
            continue
        producers_to_process.append(producer)
    for consumer in ungrouped_consumers:
        if consumer in consumers_to_process:
            continue
        consumers_to_process.append(consumer)

    # Compute one slice alignment shared by the op and its grouped neighbors.
    own_slices = op_reg_manager.get_op_slices(op)
    producer_slices = op_handler_util.get_op_slices(
        producers_to_group, op_reg_manager)
    consumer_slices = op_handler_util.get_op_slices(
        consumers_to_group, op_reg_manager)
    aligned_sizes = op_handler_util.get_aligned_op_slice_sizes(
        own_slices, producer_slices, consumer_slices)
    ops_to_reslice = producers_to_group + [op] + consumers_to_group
    op_handler_util.reslice_ops(ops_to_reslice, aligned_sizes, op_reg_manager)

    # Reslicing may have replaced the OpSlice objects, so fetch them again.
    producer_slices, consumer_slices = self._get_input_output_op_slices(
        producers_to_group, consumers_to_group, op_reg_manager)

    # Group the op with its aligned input and output slices.
    op_handler_util.group_aligned_input_output_slices(
        op, producers_to_process, consumers_to_process, producer_slices,
        consumer_slices, aligned_sizes, op_reg_manager)
def assign_grouping(self, op, op_reg_manager):
    """Assign grouping to the given op and update the manager.

    The op is treated as a source: every one of its OpSlice that has no
    OpGroup gets a new one. The op is then aligned and grouped with its
    same-size passthrough outputs only (no inputs are grouped), and any
    ungrouped or size-mismatched neighbors are sent back to the manager.

    Args:
      op: tf.Operation to assign grouping to.
      op_reg_manager: OpRegularizerManager to keep track of the grouping.
    """
    # Make sure each slice of this source op belongs to an OpGroup.
    op_slices = op_reg_manager.get_op_slices(op)
    for current_slice in op_slices:
        if op_reg_manager.get_op_group(current_slice) is None:
            op_reg_manager.create_op_group_for_op_slice(current_slice)

    # Find the neighbors on both sides that have not been grouped yet.
    producers = op_handler_util.get_input_ops(op, op_reg_manager)
    ungrouped_producers = op_handler_util.get_ops_without_groups(
        producers, op_reg_manager)
    consumers = op_handler_util.get_output_ops(op, op_reg_manager)
    ungrouped_consumers = op_handler_util.get_ops_without_groups(
        consumers, op_reg_manager)

    # Non-passthrough outputs are never grouped with this op.
    consumers = op_handler_util.remove_non_passthrough_ops(
        consumers, op_reg_manager)

    # Same-size outputs get grouped; mismatched ones get processed instead.
    consumers_to_group, consumers_to_process = (
        op_handler_util.separate_same_size_ops(op, consumers))

    # Ungrouped neighbors are queued for processing too.
    producers_to_process = ungrouped_producers
    consumers_to_process.extend(ungrouped_consumers)

    # Align slice sizes between the op and its outputs; inputs are ignored.
    consumer_slices = op_handler_util.get_op_slices(
        consumers_to_group, op_reg_manager)
    aligned_sizes = op_handler_util.get_aligned_op_slice_sizes(
        op_slices, [], consumer_slices)
    ops_to_reslice = [op] + consumers_to_group
    op_handler_util.reslice_ops(ops_to_reslice, aligned_sizes, op_reg_manager)

    # Reslicing may have replaced the OpSlice objects, so fetch them again.
    consumer_slices = op_handler_util.get_op_slices(
        consumers_to_group, op_reg_manager)

    # Group with outputs only; the input slice list is left empty.
    op_handler_util.group_op_with_inputs_and_outputs(
        op, [], consumer_slices, aligned_sizes, op_reg_manager)

    # Hand the remaining ops back to the manager, outputs first.
    op_reg_manager.process_ops(consumers_to_process + producers_to_process)
def assign_grouping(self, op, op_reg_manager):
    """Assign grouping to the given op and update the manager.

    Every input op is grouped with `op`, while output ops are grouped only
    if they are passthrough ops of the same size. Inputs are resliced via
    `reslice_concat_ops`; the outputs and `op` itself use `reslice_ops`.

    Args:
      op: tf.Operation to assign grouping to.
      op_reg_manager: OpRegularizerManager to keep track of the grouping.
    """
    # TODO(a1): Consider refactoring this duplicated logic.

    # Collect neighbors on both sides and note which ones have no group yet.
    producers = op_handler_util.get_input_ops(op, op_reg_manager)
    ungrouped_producers = op_handler_util.get_ops_without_groups(
        producers, op_reg_manager)
    consumers = op_handler_util.get_output_ops(op, op_reg_manager)
    ungrouped_consumers = op_handler_util.get_ops_without_groups(
        consumers, op_reg_manager)

    # Non-passthrough outputs are never grouped with this op.
    consumers = op_handler_util.remove_non_passthrough_ops(
        consumers, op_reg_manager)

    # All inputs are grouped regardless of size; only ungrouped inputs are
    # processed. Outputs are split by whether their size matches the op.
    producers_to_group = producers
    producers_to_process = ungrouped_producers
    consumers_to_group, consumers_to_process = (
        op_handler_util.separate_same_size_ops(op, consumers))
    consumers_to_process.extend(ungrouped_consumers)

    # Gather the current OpSlice lists for the op and its grouped neighbors.
    own_slices = op_reg_manager.get_op_slices(op)
    producer_slices, consumer_slices = self._get_input_output_op_slices(
        producers_to_group, consumers_to_group, op_reg_manager)

    # Compute a common alignment and reslice inputs, then outputs and the op.
    aligned_sizes = op_handler_util.get_aligned_op_slice_sizes(
        own_slices, producer_slices, consumer_slices)
    op_handler_util.reslice_concat_ops(
        producers_to_group, aligned_sizes, op_reg_manager)
    op_handler_util.reslice_ops(
        consumers_to_group + [op], aligned_sizes, op_reg_manager)

    # Reslicing may have replaced the OpSlice objects, so fetch them again.
    producer_slices, consumer_slices = self._get_input_output_op_slices(
        producers_to_group, consumers_to_group, op_reg_manager)

    # Group the aligned slices.
    op_handler_util.group_aligned_input_output_slices(
        op, producers_to_process, consumers_to_process, producer_slices,
        consumer_slices, aligned_sizes, op_reg_manager)
def testResliceOps(self):
    """Checks that reslice_ops asks the manager to slice each given op."""
    new_sizes = (5, 4, 2, 2, 5)
    ops_under_test = [self.concat_op, self.unfused_batch_norm_op]

    # Each op starts out mapped to a single slice.
    self.op_slice_dict = {
        self.concat_op: [self.concat_op_slice],
        self.unfused_batch_norm_op: [self.unfused_batch_norm_op_slice],
    }

    op_handler_util.reslice_ops(
        ops_under_test, list(new_sizes), self.mock_op_reg_manager)

    # The manager must have been told to slice both ops with the new sizes.
    expected_calls = [
        mock.call(sliced_op, list(new_sizes)) for sliced_op in ops_under_test
    ]
    self.mock_op_reg_manager.slice_op.assert_has_calls(expected_calls)
def assign_grouping(self, op, op_reg_manager):
    """Assign grouping to the given op and update the manager.

    If the concat axis is not the last dimension, the work is delegated to
    a `GroupingOpHandler`. Otherwise every input op is grouped with `op`,
    and output ops are grouped only if they are passthrough ops of the same
    size.

    Args:
      op: tf.Operation to assign grouping to.
      op_reg_manager: OpRegularizerManager to keep track of the grouping.
    """
    concat_axis = _get_concat_op_axis(op)
    # The rank of the first input tells us which axis index is the last one.
    rank = len(op.inputs[0].shape)
    axis_is_last = concat_axis == -1 or concat_axis == rank - 1
    if not axis_is_last:
        # Concat along another axis is actually grouping its inputs.
        grouping_op_handler.GroupingOpHandler().assign_grouping(
            op, op_reg_manager)
        return

    # From here on this is a `standard` concat of the last dimension.
    # TODO(a1): Consider refactoring this duplicated logic.

    # Collect neighbors on both sides and note which ones have no group yet.
    producers = op_handler_util.get_input_ops(op, op_reg_manager)
    ungrouped_producers = op_handler_util.get_ops_without_groups(
        producers, op_reg_manager)
    consumers = op_handler_util.get_output_ops(op, op_reg_manager)
    ungrouped_consumers = op_handler_util.get_ops_without_groups(
        consumers, op_reg_manager)

    # Non-passthrough outputs are never grouped with this op.
    consumers = op_handler_util.remove_non_passthrough_ops(
        consumers, op_reg_manager)

    # All inputs are grouped regardless of size; only ungrouped inputs are
    # processed. Outputs are split by whether their size matches the op.
    producers_to_group = producers
    producers_to_process = ungrouped_producers
    consumers_to_group, consumers_to_process = (
        op_handler_util.separate_same_size_ops(op, consumers))
    consumers_to_process.extend(ungrouped_consumers)

    # Gather the current OpSlice lists for the op and its grouped neighbors.
    own_slices = op_reg_manager.get_op_slices(op)
    producer_slices, consumer_slices = self._get_input_output_op_slices(
        producers_to_group, consumers_to_group, op_reg_manager)

    # Compute a common alignment and reslice inputs, then outputs and the op.
    aligned_sizes = op_handler_util.get_aligned_op_slice_sizes(
        own_slices, producer_slices, consumer_slices)
    op_handler_util.reslice_concat_ops(
        producers_to_group, aligned_sizes, op_reg_manager)
    op_handler_util.reslice_ops(
        consumers_to_group + [op], aligned_sizes, op_reg_manager)

    # Reslicing may have replaced the OpSlice objects, so fetch them again.
    producer_slices, consumer_slices = self._get_input_output_op_slices(
        producers_to_group, consumers_to_group, op_reg_manager)

    # Group the aligned slices.
    op_handler_util.group_aligned_input_output_slices(
        op, producers_to_process, consumers_to_process, producer_slices,
        consumer_slices, aligned_sizes, op_reg_manager)
def assign_grouping(self, op, op_reg_manager):
    """Assign grouping to the given op and update the manager.

    If any input op has no group yet, the ungrouped inputs are sent to the
    manager and nothing else happens. Otherwise the single input and the op
    are sliced into individual channels, and each output channel is grouped
    with a block of `block_size * block_size` input channels.

    Args:
      op: tf.Operation to assign grouping to.
      op_reg_manager: OpRegularizerManager to keep track of the grouping.
    """
    # Collect neighbors on both sides and note which ones have no group yet.
    producers = op_handler_util.get_input_ops(op, op_reg_manager)
    ungrouped_producers = op_handler_util.get_ops_without_groups(
        producers, op_reg_manager)
    consumers = op_handler_util.get_output_ops(op, op_reg_manager)
    ungrouped_consumers = op_handler_util.get_ops_without_groups(
        consumers, op_reg_manager)

    # Non-passthrough outputs are never grouped with this op.
    consumers = op_handler_util.remove_non_passthrough_ops(
        consumers, op_reg_manager)

    # All inputs are grouped; outputs are grouped only when their size
    # matches, and the mismatched ones are deferred.
    producers_to_group = producers
    consumers_to_group, deferred_consumers = (
        op_handler_util.separate_same_size_ops(op, consumers))

    # Ungrouped neighbors are deferred as well; avoid duplicate outputs.
    deferred_producers = ungrouped_producers
    for consumer in ungrouped_consumers:
        if consumer in deferred_consumers:
            continue
        deferred_consumers.append(consumer)

    # Slicing and merging only make sense once every input is grouped.
    if deferred_producers:
        op_reg_manager.process_ops(deferred_producers)
        return

    block_size = op.get_attr('block_size')
    channels_per_block = block_size * block_size

    # For DepthToSpace, slice ops into individual channels before mapping.
    # For example, this op might reshape a tensor [N, H, W, 4] ->
    # [N, 2H, 2W, 1] where the 4 input channels are mapped to 1 output
    # channel. The input is therefore cut into one OpSlice per channel.
    assert len(producers_to_group) == 1
    sole_producer = producers_to_group[0]
    producer_unit_sizes = [1] * op_handler_util.get_op_size(sole_producer)
    op_handler_util.reslice_ops(
        producers, producer_unit_sizes, op_reg_manager)
    own_unit_sizes = [1] * op_handler_util.get_op_size(op)
    op_handler_util.reslice_ops(
        [op] + consumers_to_group, own_unit_sizes, op_reg_manager)

    # Fetch the OpSlice objects again now that the ops have been resliced.
    own_slices = op_reg_manager.get_op_slices(op)
    slices_per_producer = op_handler_util.get_op_slices(
        producers, op_reg_manager)

    # Group blocks of input channels with output channels. For block_size B
    # each block holds B * B channels. For example, an input tensor
    # [N, H, W, 18] with block_size 3 yields an output [N, 3H, 3W, 2], where
    # every 9 consecutive input channels are mapped to space values of one
    # output channel. See the TensorFlow documentation for details.
    for index, own_slice in enumerate(own_slices):
        block_start = index * channels_per_block
        block_stop = block_start + channels_per_block
        for producer_slices in slices_per_producer:
            op_reg_manager.group_op_slices(
                producer_slices[block_start:block_stop] + [own_slice])

    # Process deferred ops. The input list is already known to be empty here
    # because of the early return above; the check is kept as in the
    # original.
    if deferred_producers or deferred_consumers:
        op_reg_manager.process_ops(deferred_consumers + deferred_producers)
def assign_grouping(self, op, op_reg_manager):
    """Assign grouping to the given op and update the manager.

    When the depth multiplier is 1 the parent class implementation is used.
    Otherwise the input and output are sliced into individual channels, and
    the slices are rearranged so that each input channel lines up with its
    `depth_multiplier` consecutive output channels before grouping.

    Args:
      op: tf.Operation to assign grouping to.
      op_reg_manager: OpRegularizerManager to keep track of the grouping.
    """
    assert op.type == 'DepthwiseConv2dNative'

    # Sizes of the op itself and of the op producing its first input.
    output_size = op_handler_util.get_op_size(op)
    input_size = op_handler_util.get_op_size(op.inputs[0].op)

    # The depth multiplier is the last dimension of the weight tensor.
    weight_shape = op.inputs[1].shape.as_list()
    depth_multiplier = weight_shape[-1]
    if depth_multiplier == 1:
        super(DepthwiseConvolutionOpHandler, self).assign_grouping(
            op, op_reg_manager)
        return

    # Collect neighbors on both sides and note which ones have no group yet.
    producers = op_handler_util.get_input_ops(op, op_reg_manager)
    ungrouped_producers = op_handler_util.get_ops_without_groups(
        producers, op_reg_manager)
    consumers = op_handler_util.get_output_ops(op, op_reg_manager)
    ungrouped_consumers = op_handler_util.get_ops_without_groups(
        consumers, op_reg_manager)

    # Non-passthrough outputs are never grouped with this op.
    consumers = op_handler_util.remove_non_passthrough_ops(
        consumers, op_reg_manager)

    # Only the first input is grouped: it is hardcoded that inputs[0] is the
    # normal input while inputs[1] is the depthwise filter. Outputs are
    # grouped only when their size matches; the rest get processed.
    producers_to_group = [producers[0]]
    producers_to_process = ungrouped_producers
    consumers_to_group, consumers_to_process = (
        op_handler_util.separate_same_size_ops(op, consumers))

    # Ungrouped outputs must be processed as well; avoid duplicates.
    for consumer in ungrouped_consumers:
        if consumer in consumers_to_process:
            continue
        consumers_to_process.append(consumer)

    # Slice ops into individual channels. For example, consider 3 input
    # channels and depth_multiplier = 2. Let the input channels be [0, 1, 2]
    # and the output channels be [3, 4, 5, 6, 7, 8]. The channels should be
    # individually sliced and grouped with consecutive groups of size
    # depth_multiplier. Thus, this would end up grouping [0, 0, 1, 1, 2, 2]
    # and [3, 4, 5, 6, 7, 8] into groups (0, 3, 4), (1, 5, 6), and (2, 7, 8).
    aligned_sizes = [1] * output_size
    input_unit_sizes = [1] * input_size
    op_handler_util.reslice_ops(
        producers_to_group, input_unit_sizes, op_reg_manager)
    op_handler_util.reslice_ops(
        [op] + consumers_to_group, aligned_sizes, op_reg_manager)

    # Rearrange the OpSlice lists so inputs and outputs line up.
    producer_slices, consumer_slices = (
        self._get_depth_multiplier_input_output_op_slices(
            producers_to_group, input_size, consumers_to_group,
            op_reg_manager, depth_multiplier))

    # Group the op with its aligned input and output slices.
    op_handler_util.group_aligned_input_output_slices(
        op, producers_to_process, consumers_to_process, producer_slices,
        consumer_slices, aligned_sizes, op_reg_manager)