Example #1
0
    def testGetTotalSliceSize(self):
        """get_total_slice_size sums `count` sizes starting at `start`."""
        sizes = list(range(1, 10))  # [1, 2, ..., 9]

        # Each case is (expected_total, start_index, slice_count).
        cases = (
            (15, 0, 5),  # 1 + 2 + 3 + 4 + 5
            (15, 3, 3),  # 4 + 5 + 6
            (30, 5, 4),  # 6 + 7 + 8 + 9
            (3, 2, 1),   # a single slice: 3
        )
        for expected_total, start, count in cases:
            self.assertEqual(
                expected_total,
                op_handler_util.get_total_slice_size(sizes, start, count))
    def slice_op(self, op, sizes):
        """Slice an op into specified sizes.

        Creates OpSlice objects to represent slices of op.  The op is mapped to
        its constituent OpSlice and reformed by concatenating the OpSlice.  For
        example, if op has 10 channels and sizes is [3, 7], then afterwards op
        is represented by [OpSlice(op, (0, 3)), OpSlice(op, (3, 7))].  This
        method returns None; the new slices are recorded through
        self._slice_op_slice and, for grouped slices, self.group_op_slices.

        Note that sizes must be able to be aligned with the original op slice
        sizes.  An original slice can be partitioned into smaller slices, but
        the original slice boundaries cannot be changed.  For example, if the
        original sizes are [3, 7], the op cannot be sliced into sizes [2, 8].
        However, slicing into sizes [1, 2, 3, 4] is okay because the original
        slices are being sliced (3 -> [1, 2] and 7 -> [3, 4]).

        Also note that ops that are grouped with op will also be sliced
        accordingly, with respective slices grouped.  For example, if OpA is
        grouped with OpB and OpC, and OpA is sliced into OpA1 and OpA2, then the
        result will be groups (OpA1, OpB1, OpC1) and (OpA2, OpB2, OpC2).

        Args:
          op: A tf.Operation to slice for the purpose of grouping.
          sizes: Sequence (list or tuple) of integer sizes to slice op into.
            Sizes must sum up to the number of output channels for op.

        Raises:
          ValueError: If sizes cannot be aligned with original op slice sizes.
        """
        # Normalize to a list: the equality checks below compare against list
        # values, and a tuple never compares equal to a list, so a tuple input
        # would otherwise be rejected even when it is a valid slicing.
        sizes = list(sizes)

        old_op_slices = self.get_op_slices(op)
        old_op_slice_sizes = op_handler_util.get_op_slice_sizes(
            [old_op_slices])[0]

        # If sizes already match, then nothing happens.
        if old_op_slice_sizes == sizes:
            return

        # If sizes cannot be aligned with original sizes, raise exception.
        try:
            aligned_op_slice_sizes = op_handler_util.get_aligned_sizes(
                [old_op_slice_sizes, sizes])
        except ValueError as e:
            # Format the exception itself rather than e.args[0]: an
            # argument-less ValueError would make e.args[0] raise IndexError
            # and hide the real alignment error.
            raise ValueError('Error with op: %s: %s' % (op.name, e))

        # Alignment may have split the requested sizes further (i.e. sizes
        # would move an existing slice boundary); that is not allowed.
        if sizes != aligned_op_slice_sizes:
            raise ValueError('Cannot slice op %s from sizes %s to %s' %
                             (op.name, old_op_slice_sizes, sizes))

        # Walk old and new slices in lockstep.  For each old slice, grow a
        # window of consecutive new slices (starting at new_slice_index,
        # new_slice_count long) until their total matches the old slice size.
        old_slice_index = 0
        new_slice_index = 0
        new_slice_count = 1
        while (new_slice_index + new_slice_count <= len(aligned_op_slice_sizes)
               and old_slice_index < len(old_op_slice_sizes)):
            old_size = old_op_slice_sizes[old_slice_index]
            new_size = op_handler_util.get_total_slice_size(
                sizes, new_slice_index, new_slice_count)
            if old_size == new_size:
                # A window of exactly one new slice means the old slice is kept
                # unchanged, so only windows of 2+ slices need re-slicing.
                if new_slice_count > 1:
                    # This old slice is sliced into new_slice_count smaller
                    # slices.  Every OpSlice in its group must be sliced the
                    # same way so the groups stay aligned.
                    old_op_slice = old_op_slices[old_slice_index]
                    op_group = self.get_op_group(old_op_slice)
                    # If OpSlice has no group, just use the OpSlice itself.
                    group_op_slices = (
                        op_group.op_slices if op_group else [old_op_slice])

                    # new_op_slice_group[i] collects the i-th new slice of
                    # every group member; _slice_op_slice fills it in place.
                    new_op_slice_group = [[] for _ in range(new_slice_count)]
                    for group_op_slice in group_op_slices:
                        self._slice_op_slice(group_op_slice, sizes,
                                             new_slice_index, new_slice_count,
                                             new_op_slice_group)

                    if op_group:
                        # Group all new OpSlice along each index.
                        for new_slices_at_index in new_op_slice_group:
                            self.group_op_slices(new_slices_at_index)

                # Update indices for the next slice.
                old_slice_index += 1
                new_slice_index += new_slice_count
                new_slice_count = 1
            else:
                # If sizes do not match, then more new slices are needed to
                # match the old slice.
                new_slice_count += 1