示例#1
0
    def assign_grouping(self, op, op_reg_manager):
        """Groups `op` with its neighbouring ops and records it in the manager.

        All input ops are grouped with `op`. Output ops are grouped only when
        they survive `remove_non_passthrough_ops` and `separate_same_size_ops`
        puts them in the same-size bucket. The remaining output ops, together
        with any input/output ops that have no group yet, are passed to
        `group_aligned_input_output_slices` as ops to process.

        Args:
          op: tf.Operation to assign grouping to.
          op_reg_manager: OpRegularizerManager to keep track of the grouping.
        """
        util = op_handler_util
        manager = op_reg_manager

        # TODO(a1): Consider refactoring this duplicated logic.
        # Producers of `op`, and the subset that has not been grouped yet.
        producers = util.get_input_ops(op, manager)
        ungrouped_producers = util.get_ops_without_groups(producers, manager)

        # Consumers of `op`, and the subset that has not been grouped yet.
        consumers = util.get_output_ops(op, manager)
        ungrouped_consumers = util.get_ops_without_groups(consumers, manager)

        # Keep passthrough consumers only, then split them by whether their size
        # matches `op`; mismatched ones and ungrouped ones are queued for
        # processing.
        passthrough_consumers = util.remove_non_passthrough_ops(
            consumers, manager)
        grouped_consumers, consumers_to_process = util.separate_same_size_ops(
            op, passthrough_consumers)
        consumers_to_process.extend(ungrouped_consumers)

        # Derive a common slicing for `op`, its inputs and its grouped outputs,
        # then reslice all of them to that common slicing.
        own_slices = manager.get_op_slices(op)
        producer_slices, consumer_slices = self._get_input_output_op_slices(
            producers, grouped_consumers, manager)
        aligned_sizes = util.get_aligned_op_slice_sizes(
            own_slices, producer_slices, consumer_slices)
        util.reslice_concat_ops(producers, aligned_sizes, manager)
        util.reslice_ops(grouped_consumers + [op], aligned_sizes, manager)

        # Reslicing may have replaced the OpSlices, so read them again.
        producer_slices, consumer_slices = self._get_input_output_op_slices(
            producers, grouped_consumers, manager)
        util.group_aligned_input_output_slices(
            op, ungrouped_producers, consumers_to_process, producer_slices,
            consumer_slices, aligned_sizes, manager)
示例#2
0
    def testResliceConcatOps_Aligned(self):
        """reslice_concat_ops must not slice inputs that already line up."""
        concat_inputs = [self.relu2_op, self.relu3_op, self.relu4_op]
        aligned_sizes = [5, 6, 7]

        # One OpSlice per input op; the test expects these (set up elsewhere)
        # to already match `aligned_sizes`.
        self.op_slice_dict = dict(zip(
            concat_inputs,
            [[self.relu2_op_slice], [self.relu3_op_slice],
             [self.relu4_op_slice]]))

        op_handler_util.reslice_concat_ops(
            concat_inputs, aligned_sizes, self.mock_op_reg_manager)

        # No reslicing should have been requested from the manager.
        self.mock_op_reg_manager.slice_op.assert_not_called()
示例#3
0
    def testResliceConcatOps_NotAligned(self):
        """reslice_concat_ops must reslice inputs that do not line up."""
        # relu3 currently consists of two OpSlices covering [0, 3) and [3, 6).
        relu3_head = orm.OpSlice(self.relu3_op, orm.Slice(0, 3))
        relu3_tail = orm.OpSlice(self.relu3_op, orm.Slice(3, 3))

        self.op_slice_dict = {
            self.relu2_op: [self.relu2_op_slice],
            self.relu3_op: [relu3_head, relu3_tail],
            self.relu4_op: [self.relu4_op_slice],
        }
        concat_inputs = [self.relu2_op, self.relu3_op, self.relu4_op]
        aligned_sizes = [5, 4, 2, 2, 5]

        op_handler_util.reslice_concat_ops(
            concat_inputs, aligned_sizes, self.mock_op_reg_manager)

        # relu3 should be resliced to [4, 2] and relu4 to [2, 5], in order.
        expected_calls = [
            mock.call(self.relu3_op, [4, 2]),
            mock.call(self.relu4_op, [2, 5]),
        ]
        self.mock_op_reg_manager.slice_op.assert_has_calls(expected_calls)
    def assign_grouping(self, op, op_reg_manager):
        """Assign grouping to the given op and updates the manager.

        A concat along any axis other than the last one is delegated to
        GroupingOpHandler. A concat along the last axis is grouped here: its
        input ops are grouped with it, and its same-size passthrough output ops
        are grouped with it as well.

        Args:
          op: tf.Operation to assign grouping to.
          op_reg_manager: OpRegularizerManager to keep track of the grouping.
        """
        concat_axis = _get_concat_op_axis(op)

        # An axis of -1 always denotes the last dimension, so the rank is only
        # needed for other axis values. Computing it lazily avoids `len()`
        # raising ValueError on an input of unknown rank when the axis is -1.
        if concat_axis != -1:
            rank = len(op.inputs[0].shape)  # Rank of the first input.
            if concat_axis != rank - 1:
                # Concat is actually grouping inputs!
                handler = grouping_op_handler.GroupingOpHandler()
                handler.assign_grouping(op, op_reg_manager)
                return

        # If concat is of the last dimension, this is a `standard` concat.
        # TODO(a1): Consider refactoring this duplicated logic.
        # Check if all input ops have groups, or tell the manager to process them.
        input_ops = op_handler_util.get_input_ops(op, op_reg_manager)
        input_ops_without_group = op_handler_util.get_ops_without_groups(
            input_ops, op_reg_manager)

        # Check if all output ops have groups, or tell the manager to process them.
        output_ops = op_handler_util.get_output_ops(op, op_reg_manager)
        output_ops_without_group = op_handler_util.get_ops_without_groups(
            output_ops, op_reg_manager)

        # Remove non-passthrough ops from outputs ops to group with.
        output_ops = op_handler_util.remove_non_passthrough_ops(
            output_ops, op_reg_manager)

        # Only group with output ops that have the same size.  Process the ops
        # that have mismatched size.
        input_ops_to_group = input_ops
        input_ops_to_process = input_ops_without_group
        output_ops_to_group, output_ops_to_process = (
            op_handler_util.separate_same_size_ops(op, output_ops))

        # Also process ungrouped ops.
        output_ops_to_process.extend(output_ops_without_group)

        # Populate OpSlice data for all relevant ops.
        concat_op_slices = op_reg_manager.get_op_slices(op)
        input_op_slices, output_op_slices = self._get_input_output_op_slices(
            input_ops_to_group, output_ops_to_group, op_reg_manager)

        # Align op slices sizes if needed.
        aligned_op_slice_sizes = op_handler_util.get_aligned_op_slice_sizes(
            concat_op_slices, input_op_slices, output_op_slices)
        op_handler_util.reslice_concat_ops(input_ops_to_group,
                                           aligned_op_slice_sizes,
                                           op_reg_manager)
        op_handler_util.reslice_ops(output_ops_to_group + [op],
                                    aligned_op_slice_sizes, op_reg_manager)

        # Repopulate OpSlice data, as ops may have been resliced.
        input_op_slices, output_op_slices = self._get_input_output_op_slices(
            input_ops_to_group, output_ops_to_group, op_reg_manager)

        # Group aligned OpSlice.
        op_handler_util.group_aligned_input_output_slices(
            op, input_ops_to_process, output_ops_to_process, input_op_slices,
            output_op_slices, aligned_op_slice_sizes, op_reg_manager)