Example #1
0
  def _group_with_output_slices(
      self, op, output_op_slices, op_slices, op_reg_manager):
    """Groups OpSlice of current op with output ops.

    Assuming OpSlice of op have been aligned with output, groups the
    corresponding OpSlice.

    Args:
      op: tf.Operation to determine grouping for.
      output_op_slices: List of list of OpSlice, with a list per output op.
      op_slices: List of OpSlice for current op.
      op_reg_manager: OpRegularizerManager to keep track of grouping.

    Raises:
      ValueError: If sizes for current and output op slices are not the same.
    """
    # get_op_slice_sizes returns one list of sizes per list of OpSlice, so
    # op_slice_sizes is a single-element list ([[...]]) while
    # output_op_slices_sizes has one entry per output op.
    output_op_slices_sizes = op_handler_util.get_op_slice_sizes(
        output_op_slices)
    op_slice_sizes = op_handler_util.get_op_slice_sizes([op_slices])

    # Assert that every output op is aligned with the current op.  Comparing
    # the two nested lists wholesale would spuriously fail whenever the op has
    # more than one output op, even if all of them are aligned.
    if any(output_sizes != op_slice_sizes[0]
           for output_sizes in output_op_slices_sizes):
      raise ValueError('Current op and output op have differing slice '
                       'sizes: {}, {}'.format(
                           op_slice_sizes, output_op_slices_sizes))

    # NOTE(review): op_slice_sizes is the nested [[...]] list here (unchanged
    # from the original); confirm group_op_with_inputs_and_outputs expects
    # that shape rather than a flat list of slice sizes.
    op_handler_util.group_op_with_inputs_and_outputs(
        op, [], output_op_slices, op_slice_sizes, op_reg_manager)
    def _slice_op_slice(self, op_slice, sizes, size_index, size_count,
                        new_op_slice_group):
        """Slices an OpSlice according to new sizes.

        During reslicing, any OpSlice of an op could be resliced.  The size
        entry of op_slice in its op's slice-size list is replaced by the
        size_count sizes starting at sizes[size_index], and the op is resliced
        with the resulting size list.  The k-th new OpSlice is appended to
        new_op_slice_group[k], so that matching OpSlice can be grouped
        together later.

        Args:
          op_slice: OpSlice that should be sliced.
          sizes: List of integers specifying the new slice sizes.
          size_index: Integer specifying which index in sizes corresponds to
            op_slice.
          size_count: Integer specifying how many slices op_slice will be
            sliced into.
          new_op_slice_group: List of list of new OpSlice that should be
            grouped together.  Mutated in place.
        """
        target_op = op_slice.op
        current_slices = self.get_op_slices(target_op)

        # Current slice sizes of the op, one entry per OpSlice.
        current_sizes = op_handler_util.get_op_slice_sizes(
            [current_slices])[0]

        # Position of the OpSlice that is being replaced.
        position = current_slices.index(op_slice)

        # Clear the old OpSlice to OpGroup mapping; that OpSlice is replaced.
        self._op_group_dict.pop(op_slice, None)

        # Splice the replacement sizes in place of the old slice's size and
        # flag exactly those positions as resliced.
        replacement = [sizes[size_index + k] for k in range(size_count)]
        trailing = len(current_sizes) - position - 1
        new_sizes = (current_sizes[:position] + replacement +
                     current_sizes[position + 1:])
        resliced_flags = ([False] * position + [True] * size_count +
                          [False] * trailing)

        # Find source slices and slice the op.
        source_flags = self._get_source_slices(new_sizes, current_slices)
        new_slices = self._slice_op_with_sizes(target_op, new_sizes,
                                               source_flags, resliced_flags)

        # Accumulate new OpSlice at the corresponding index.
        for k in range(size_count):
            new_op_slice_group[k].append(new_slices[position + k])
Example #3
0
    def testGetOpSliceSizes(self):
        """get_op_slice_sizes returns one list of slice sizes per op."""
        # relu3 is split into two slices of size 3 each.
        relu3_op_slice_0_3 = orm.OpSlice(self.relu3_op, orm.Slice(0, 3))
        relu3_op_slice_3_6 = orm.OpSlice(self.relu3_op, orm.Slice(3, 3))

        # Batch norm is split into slices of sizes 5, 3, 3 and 7.
        batch_norm_op_slice_0_5 = orm.OpSlice(self.unfused_batch_norm_op,
                                              orm.Slice(0, 5))
        batch_norm_op_slice_5_8 = orm.OpSlice(self.unfused_batch_norm_op,
                                              orm.Slice(5, 3))
        batch_norm_op_slice_8_11 = orm.OpSlice(self.unfused_batch_norm_op,
                                               orm.Slice(8, 3))
        batch_norm_op_slice_11_18 = orm.OpSlice(self.unfused_batch_norm_op,
                                                orm.Slice(11, 7))

        relu2_op_slices = [self.relu2_op_slice]
        relu3_op_slices = [relu3_op_slice_0_3, relu3_op_slice_3_6]
        relu4_op_slices = [self.relu4_op_slice]
        batch_norm_op_slices = [
            batch_norm_op_slice_0_5, batch_norm_op_slice_5_8,
            batch_norm_op_slice_8_11, batch_norm_op_slice_11_18
        ]

        # Map ops to slices.
        self.op_slice_dict = {
            self.relu2_op: relu2_op_slices,
            self.relu3_op: relu3_op_slices,
            self.relu4_op: relu4_op_slices,
            self.unfused_batch_norm_op: batch_norm_op_slices,
        }

        expected_op_slice_sizes = [
            [5],  # c2 has size 5.
            [3, 3],  # c3 has size 6, but in 2 slices of size 3.
            [7],  # c4 has size 7.
            # Batch norm has size 18, but slice sizes of c2, c3, c4.
            [5, 3, 3, 7],
        ]

        self.assertEqual(
            expected_op_slice_sizes,
            op_handler_util.get_op_slice_sizes([
                relu2_op_slices, relu3_op_slices, relu4_op_slices,
                batch_norm_op_slices
            ]))
    def slice_op(self, op, sizes):
        """Slice an op into specified sizes.

        Creates OpSlice objects to represent slices of op.  The op is mapped to
        its constituent OpSlice and reformed by concatenating the OpSlice.  For
        example, if op has 10 channels and sizes is [3, 7], then op is resliced
        into [OpSlice(op, (0, 3)), OpSlice(op, (3, 7))].  The reslicing happens
        in place on this manager; the method itself returns None.

        Note that sizes must be able to be aligned with the original op slice
        sizes.  An original slice can be partitioned into smaller slices, but
        the original slice boundaries cannot be changed.  For example, if the
        original sizes are [3, 7], the op cannot be sliced into sizes [2, 8].
        However, slicing into sizes [1, 2, 3, 4] is okay because the original
        slices are being sliced (3 -> [1, 2] and 7 -> [3, 4]).

        Also note that ops that are grouped with op will also be sliced
        accordingly, with respective slices grouped.  For example, if OpA is
        grouped with OpB and OpC, and OpA is sliced into OpA1 and OpA2, then
        the result will be groups (OpA1, OpB1, OpC1) and (OpA2, OpB2, OpC2).

        Args:
          op: A tf.Operation to slice for the purpose of grouping.
          sizes: Sequence of integer sizes to slice op into (a list, tuple or
            any other iterable of ints).  Sizes must sum up to the number of
            output channels for op.

        Raises:
          ValueError: If sizes cannot be aligned with original op slice sizes.
        """
        # Normalize to a list: a tuple (or array) of sizes would otherwise
        # never compare equal to the list-of-int sizes computed below and be
        # rejected with a spurious "Cannot slice op" error.
        sizes = list(sizes)

        old_op_slices = self.get_op_slices(op)
        old_op_slice_sizes = op_handler_util.get_op_slice_sizes(
            [old_op_slices])[0]

        # If sizes already match, then nothing happens.
        if old_op_slice_sizes == sizes:
            return

        # If sizes cannot be aligned with original sizes, raise exception.
        try:
            aligned_op_slice_sizes = op_handler_util.get_aligned_sizes(
                [old_op_slice_sizes, sizes])
        except ValueError as e:
            # str(e) matches e.args[0] for the usual single-argument
            # ValueError, but does not itself fail when e has no args.
            raise ValueError('Error with op: %s: %s' % (op.name, e))

        # NOTE(review): presumably get_aligned_sizes returns the common
        # refinement of both size lists, so any difference means sizes would
        # move an original slice boundary.
        if sizes != aligned_op_slice_sizes:
            raise ValueError('Cannot slice op %s from sizes %s to %s' %
                             (op.name, old_op_slice_sizes, sizes))

        # Iterate through slices to find old slices that need to be resliced.
        # Each old slice is matched by new_slice_count consecutive new sizes,
        # starting at new_slice_index, whose total equals the old size.
        old_slice_index = 0
        new_slice_index = 0
        new_slice_count = 1
        while (new_slice_index + new_slice_count <= len(aligned_op_slice_sizes)
               and old_slice_index < len(old_op_slice_sizes)):
            old_size = old_op_slice_sizes[old_slice_index]
            new_size = op_handler_util.get_total_slice_size(
                sizes, new_slice_index, new_slice_count)
            if old_size != new_size:
                # If sizes do not match, then more new slices are needed to
                # match the old slice.
                new_slice_count += 1
                continue

            if new_slice_count > 1:
                # This old slice is sliced into new_slice_count smaller slices.
                # All OpSlice in its group must be sliced the same way.
                old_op_slice = old_op_slices[old_slice_index]
                op_group = self.get_op_group(old_op_slice)
                if op_group:
                    group_op_slices = op_group.op_slices
                else:
                    # If OpSlice has no group, just use the OpSlice itself.
                    group_op_slices = [old_op_slice]

                new_op_slice_group = [[] for _ in range(new_slice_count)]
                for group_op_slice in group_op_slices:
                    self._slice_op_slice(group_op_slice, sizes,
                                         new_slice_index, new_slice_count,
                                         new_op_slice_group)

                if op_group:
                    # Group all new OpSlice along each index.
                    for index_group in new_op_slice_group:
                        self.group_op_slices(index_group)

            # Update indices for the next slice.
            old_slice_index += 1
            new_slice_index += new_slice_count
            new_slice_count = 1