Example no. 1
0
    def testGroupAlignedInputOutputSlices_NoGroups(self):
        """Verifies processing order when neighbor ops are not yet grouped."""
        # Map each op to its single slice, and each slice to its group.
        self.op_slice_dict = {
            self.batch_norm_op: [self.batch_norm_op_slice],
            self.conv_op: [self.conv_op_slice],
            self.relu_op: [self.relu_op_slice],
        }
        self.op_group_dict = {
            self.batch_norm_op_slice: self.batch_norm_op_group,
            self.conv_op_slice: self.conv_op_group,
            self.relu_op_slice: self.relu_op_group,
        }

        op_handler_util.group_aligned_input_output_slices(
            self.batch_norm_op, [self.conv_op], [self.relu_op],
            [[self.conv_op_slice]], [[self.relu_op_slice]], [5],
            self.mock_op_reg_manager)

        # Nothing should be grouped yet: the ungrouped neighbors are queued
        # for processing first and the batch norm op is deferred to the end.
        self.mock_op_reg_manager.group_op_slices.assert_not_called()
        self.mock_op_reg_manager.process_ops.assert_called_once_with(
            [self.relu_op, self.conv_op])
        self.mock_op_reg_manager.process_ops_last.assert_called_once_with(
            [self.batch_norm_op])
Example no. 2
0
    def testGroupAlignedInputOutputSlices_InputsOutputsGrouped(self):
        """Verifies grouping calls when all neighbor ops already have groups."""
        self.op_slice_dict = {
            self.batch_norm_op: [self.batch_norm_op_slice],
            self.conv_op: [self.conv_op_slice],
            self.relu_op: [self.relu_op_slice],
        }
        self.op_group_dict = {
            self.batch_norm_op_slice: self.batch_norm_op_group,
            self.conv_op_slice: self.conv_op_group,
            self.relu_op_slice: self.relu_op_group,
        }

        # No ops are left to process, so grouping proceeds immediately.
        op_handler_util.group_aligned_input_output_slices(
            self.batch_norm_op, [], [], [[self.conv_op_slice]],
            [[self.relu_op_slice]], [5], self.mock_op_reg_manager)

        # Batch norm is grouped with its output first, then with its input.
        self.mock_op_reg_manager.group_op_slices.assert_has_calls([
            mock.call([self.batch_norm_op_slice, self.relu_op_slice]),
            mock.call([self.batch_norm_op_slice, self.conv_op_slice],
                      omit_source_op_slices=[]),
        ])
        self.mock_op_reg_manager.process_ops.assert_not_called()
  def assign_grouping(self, op, op_reg_manager):
    """Assigns grouping to the given op and updates the manager.

    Args:
      op: tf.Operation to assign grouping to.
      op_reg_manager: OpRegularizerManager to keep track of the grouping.
    """
    # Collect neighboring ops and find which of them still lack a group.
    input_ops = op_handler_util.get_input_ops(op, op_reg_manager)
    ungrouped_inputs = op_handler_util.get_ops_without_groups(
        input_ops, op_reg_manager)

    output_ops = op_handler_util.get_output_ops(op, op_reg_manager)
    ungrouped_outputs = op_handler_util.get_ops_without_groups(
        output_ops, op_reg_manager)

    # Only passthrough outputs are candidates for grouping.
    output_ops = op_handler_util.remove_non_passthrough_ops(
        output_ops, op_reg_manager)

    # Split neighbors by size: same-size ops get grouped with this op, the
    # rest are queued for separate processing.
    inputs_to_group, inputs_to_process = (
        op_handler_util.separate_same_size_ops(op, input_ops))
    outputs_to_group, outputs_to_process = (
        op_handler_util.separate_same_size_ops(op, output_ops))

    # Broadcast inputs do not need separate processing.
    inputs_to_process = [
        candidate for candidate in inputs_to_process
        if not self._is_broadcast(candidate, op_reg_manager)
    ]

    # Queue ungrouped neighbors that are not already queued.
    for candidate in ungrouped_inputs:
      if candidate not in inputs_to_process:
        inputs_to_process.append(candidate)
    for candidate in ungrouped_outputs:
      if candidate not in outputs_to_process:
        outputs_to_process.append(candidate)

    # Compute a common slicing across this op and its same-size neighbors,
    # then reslice them all to match.
    op_slices = op_reg_manager.get_op_slices(op)
    input_op_slices = op_handler_util.get_op_slices(
        inputs_to_group, op_reg_manager)
    output_op_slices = op_handler_util.get_op_slices(
        outputs_to_group, op_reg_manager)
    aligned_sizes = op_handler_util.get_aligned_op_slice_sizes(
        op_slices, input_op_slices, output_op_slices)
    op_handler_util.reslice_ops(inputs_to_group + [op] + outputs_to_group,
                                aligned_sizes, op_reg_manager)

    # Reslicing may have invalidated the OpSlice data, so fetch it again.
    input_op_slices, output_op_slices = self._get_input_output_op_slices(
        inputs_to_group, outputs_to_group, op_reg_manager)

    # Group this op's slices with the aligned input and output slices.
    op_handler_util.group_aligned_input_output_slices(
        op, inputs_to_process, outputs_to_process, input_op_slices,
        output_op_slices, aligned_sizes, op_reg_manager)
Example no. 4
0
    def assign_grouping(self, op, op_reg_manager):
        """Assigns grouping to the given op and updates the manager.

    Args:
      op: tf.Operation to assign grouping to.
      op_reg_manager: OpRegularizerManager to keep track of the grouping.
    """
        # TODO(a1): Consider refactoring this duplicated logic.
        # Gather neighboring ops and find which of them still lack a group.
        input_ops = op_handler_util.get_input_ops(op, op_reg_manager)
        ungrouped_inputs = op_handler_util.get_ops_without_groups(
            input_ops, op_reg_manager)

        output_ops = op_handler_util.get_output_ops(op, op_reg_manager)
        ungrouped_outputs = op_handler_util.get_ops_without_groups(
            output_ops, op_reg_manager)

        # Keep only passthrough outputs as grouping candidates.
        output_ops = op_handler_util.remove_non_passthrough_ops(
            output_ops, op_reg_manager)

        # Every input is grouped with this op; outputs are grouped only when
        # their size matches, and the rest are processed separately along
        # with any ungrouped outputs.
        inputs_to_group = input_ops
        inputs_to_process = ungrouped_inputs
        outputs_to_group, outputs_to_process = (
            op_handler_util.separate_same_size_ops(op, output_ops))
        outputs_to_process.extend(ungrouped_outputs)

        # Compute a common slicing across this op, inputs, and outputs, then
        # reslice everything to match.
        concat_op_slices = op_reg_manager.get_op_slices(op)
        input_op_slices, output_op_slices = self._get_input_output_op_slices(
            inputs_to_group, outputs_to_group, op_reg_manager)
        aligned_sizes = op_handler_util.get_aligned_op_slice_sizes(
            concat_op_slices, input_op_slices, output_op_slices)
        op_handler_util.reslice_concat_ops(inputs_to_group, aligned_sizes,
                                           op_reg_manager)
        op_handler_util.reslice_ops(outputs_to_group + [op], aligned_sizes,
                                    op_reg_manager)

        # Reslicing may have invalidated the OpSlice data, so fetch it again.
        input_op_slices, output_op_slices = self._get_input_output_op_slices(
            inputs_to_group, outputs_to_group, op_reg_manager)

        # Group this op's slices with the aligned input and output slices.
        op_handler_util.group_aligned_input_output_slices(
            op, inputs_to_process, outputs_to_process, input_op_slices,
            output_op_slices, aligned_sizes, op_reg_manager)
    def assign_grouping(self, op, op_reg_manager):
        """Assigns grouping to the given op and updates the manager.

    Args:
      op: tf.Operation to assign grouping to.
      op_reg_manager: OpRegularizerManager to keep track of the grouping.
    """
        concat_axis = _get_concat_op_axis(op)
        # Rank of the first input tells us whether the axis is the last one.
        rank = len(op.inputs[0].shape)

        if concat_axis != -1 and concat_axis != rank - 1:
            # Concat along a non-last axis just groups its inputs together;
            # delegate to the generic grouping handler.
            grouping_op_handler.GroupingOpHandler().assign_grouping(
                op, op_reg_manager)
            return

        # Concat of the last dimension is a `standard` concat.
        # TODO(a1): Consider refactoring this duplicated logic.
        # Gather neighboring ops and find which of them still lack a group.
        input_ops = op_handler_util.get_input_ops(op, op_reg_manager)
        ungrouped_inputs = op_handler_util.get_ops_without_groups(
            input_ops, op_reg_manager)

        output_ops = op_handler_util.get_output_ops(op, op_reg_manager)
        ungrouped_outputs = op_handler_util.get_ops_without_groups(
            output_ops, op_reg_manager)

        # Keep only passthrough outputs as grouping candidates.
        output_ops = op_handler_util.remove_non_passthrough_ops(
            output_ops, op_reg_manager)

        # Every input is grouped with the concat; outputs are grouped only
        # when their size matches, and the rest are processed separately
        # along with any ungrouped outputs.
        inputs_to_group = input_ops
        inputs_to_process = ungrouped_inputs
        outputs_to_group, outputs_to_process = (
            op_handler_util.separate_same_size_ops(op, output_ops))
        outputs_to_process.extend(ungrouped_outputs)

        # Compute a common slicing across the concat, inputs, and outputs,
        # then reslice everything to match.
        concat_op_slices = op_reg_manager.get_op_slices(op)
        input_op_slices, output_op_slices = self._get_input_output_op_slices(
            inputs_to_group, outputs_to_group, op_reg_manager)
        aligned_sizes = op_handler_util.get_aligned_op_slice_sizes(
            concat_op_slices, input_op_slices, output_op_slices)
        op_handler_util.reslice_concat_ops(inputs_to_group, aligned_sizes,
                                           op_reg_manager)
        op_handler_util.reslice_ops(outputs_to_group + [op], aligned_sizes,
                                    op_reg_manager)

        # Reslicing may have invalidated the OpSlice data, so fetch it again.
        input_op_slices, output_op_slices = self._get_input_output_op_slices(
            inputs_to_group, outputs_to_group, op_reg_manager)

        # Group the concat's slices with the aligned input and output slices.
        op_handler_util.group_aligned_input_output_slices(
            op, inputs_to_process, outputs_to_process, input_op_slices,
            output_op_slices, aligned_sizes, op_reg_manager)
Example no. 6
0
  def assign_grouping(self, op, op_reg_manager):
    """Assign grouping to the given op and updates the manager.

    For depth_multiplier == 1 the op behaves like a plain passthrough and is
    delegated to the parent handler.  Otherwise, input and output channels
    are sliced individually so that each input channel can be grouped with
    its depth_multiplier consecutive output channels.

    Args:
      op: tf.Operation to assign grouping to.
      op_reg_manager: OpRegularizerManager to keep track of the grouping.
    """
    # This handler only supports depthwise convolutions.
    assert op.type == 'DepthwiseConv2dNative'

    # Get output size.
    output_size = op_handler_util.get_op_size(op)

    # Get input size.  inputs[0] is the data tensor (inputs[1] is the
    # depthwise filter — see the hardcoded split below).
    input_size = op_handler_util.get_op_size(op.inputs[0].op)

    # Take depth_multiplier from size of weight tensor.
    depth_multiplier = op.inputs[1].shape.as_list()[-1]

    # With no channel fan-out, the op groups like any passthrough op.
    if depth_multiplier == 1:
      super(DepthwiseConvolutionOpHandler, self).assign_grouping(
          op, op_reg_manager)
      return

    # Check if all input ops have groups, or tell the manager to process them.
    input_ops = op_handler_util.get_input_ops(op, op_reg_manager)
    input_ops_without_group = op_handler_util.get_ops_without_groups(
        input_ops, op_reg_manager)

    # Check if all output ops have groups, or tell the manager to process them.
    output_ops = op_handler_util.get_output_ops(op, op_reg_manager)
    output_ops_without_group = op_handler_util.get_ops_without_groups(
        output_ops, op_reg_manager)

    # Remove non-passthrough ops from outputs ops to group with.
    output_ops = op_handler_util.remove_non_passthrough_ops(
        output_ops, op_reg_manager)

    # Only group with ops that have the same size.  Process the ops that have
    # mismatched size.  For the input, we hardcode that inputs[0] is a normal
    # input while inputs[1] is the depthwise filter.
    input_ops_to_group = [input_ops[0]]
    input_ops_to_process = input_ops_without_group
    output_ops_to_group, output_ops_to_process = (
        op_handler_util.separate_same_size_ops(op, output_ops))

    # Also process ungrouped ops.
    for output_op_without_group in output_ops_without_group:
      if output_op_without_group not in output_ops_to_process:
        output_ops_to_process.append(output_op_without_group)

    # Slice ops into individual channels.  For example, consider 3 input
    # channels and depth_multiplier = 2.  Let the input channels be [0, 1, 2]
    # and the output channels be [3, 4, 5, 6, 7, 8].  The channels should be
    # individually sliced and grouped with consecutive groups of size
    # depth_multiplier.  Thus, this would end up grouping [0, 0, 1, 1, 2, 2] and
    # [3, 4, 5, 6, 7, 8] into groups (0, 3, 4), (1, 5, 6), and (2, 7, 8).
    aligned_op_slice_sizes = [1] * output_size
    op_handler_util.reslice_ops(
        input_ops_to_group, [1] * input_size, op_reg_manager)
    op_handler_util.reslice_ops(
        [op] + output_ops_to_group, aligned_op_slice_sizes, op_reg_manager)

    # Rearrange OpSlice to align input and output.
    input_op_slices, output_op_slices = (
        self._get_depth_multiplier_input_output_op_slices(
            input_ops_to_group, input_size, output_ops_to_group,
            op_reg_manager, depth_multiplier))

    # Group with inputs and outputs.
    op_handler_util.group_aligned_input_output_slices(
        op, input_ops_to_process, output_ops_to_process, input_op_slices,
        output_op_slices, aligned_op_slice_sizes, op_reg_manager)