Esempio n. 1
0
def _get_base_op_hander_dicts():
    """Builds the base op-type -> OpHandler mapping used by all regularizers.

    Op types without an explicit entry fall back to the `defaultdict` factory,
    `grouping_op_handler.GroupingOpHandler`.

    Returns:
      A `collections.defaultdict` keyed by op type name.
    """
    leaf = leaf_op_handler.LeafOpHandler
    non_passthrough = (
        output_non_passthrough_op_handler.OutputNonPassthroughOpHandler)

    handlers = collections.defaultdict(grouping_op_handler.GroupingOpHandler)
    # Every key gets its own, freshly constructed handler instance.
    handlers['ConcatV2'] = concat_op_handler.ConcatOpHandler()
    handlers['DepthToSpace'] = non_passthrough()
    handlers['DepthwiseConv2dNative'] = (
        depthwise_convolution_op_handler.DepthwiseConvolutionOpHandler())
    handlers['ExpandDims'] = non_passthrough()
    handlers['RandomUniform'] = leaf()
    handlers['Reshape'] = leaf()
    handlers['Shape'] = leaf()
    handlers['SpaceToDepth'] = leaf()
    handlers['StridedSlice'] = leaf()
    handlers['TensorArrayGatherV3'] = leaf()
    handlers['Transpose'] = non_passthrough()

    # Resize* ops, second input might be a tensor which will result in an
    # error, hence the explicit `[0]` argument.
    handlers.update({
        resize_op_name: grouping_op_handler.GroupingOpHandler([0])
        for resize_op_name in RESIZE_OP_NAMES
    })
    return handlers
Esempio n. 2
0
def _get_base_op_hander_dicts():
    """Returns the base op-type -> OpHandler mapping.

    Op types that are not listed explicitly are served by the `defaultdict`
    factory, `grouping_op_handler.GroupingOpHandler`.
    """
    leaf = leaf_op_handler.LeafOpHandler
    non_passthrough = (
        output_non_passthrough_op_handler.OutputNonPassthroughOpHandler)

    # (op type, handler class) pairs; each class is instantiated once per key,
    # in the order listed here.
    handler_classes = (
        ('ConcatV2', concat_op_handler.ConcatOpHandler),
        ('DepthToSpace', non_passthrough),
        ('DepthwiseConv2dNative',
         depthwise_convolution_op_handler.DepthwiseConvolutionOpHandler),
        ('ExpandDims', non_passthrough),
        ('RandomUniform', leaf),
        ('Reshape', leaf),
        ('Shape', leaf),
        ('SpaceToDepth', leaf),
        ('StridedSlice', leaf),
        ('TensorArrayGatherV3', leaf),
        ('Transpose', non_passthrough),
    )
    return collections.defaultdict(
        grouping_op_handler.GroupingOpHandler,
        {op_type: handler_cls() for op_type, handler_cls in handler_classes})
  def test_AssignGroupingOfGroupingConcatNoSlicing(self):
    # Scenario: every op below maps to exactly one OpSlice, and every neighbor
    # of the concat op (the three ReLU inputs and the batch norm output)
    # already has a group.  The handler is expected to group the concat op
    # with its neighbors without slicing anything.
    relu_ops = [self.relu1_op, self.relu2_op, self.relu3_op]
    relu_slices = [
        self.relu1_op_slice, self.relu2_op_slice, self.relu3_op_slice
    ]
    relu_groups = [
        self.relu1_op_group, self.relu2_op_group, self.relu3_op_group
    ]

    # Op -> list of OpSlice (a single slice per op).
    slices_by_op = {
        op: [op_slice] for op, op_slice in zip(relu_ops, relu_slices)
    }
    slices_by_op[self.concat_op] = [self.concat_op_slice]
    slices_by_op[self.batch_norm_op] = [self.batch_norm_op_slice]
    self.op_slice_dict = slices_by_op

    # OpSlice -> OpGroup.  The concat slice itself has no group yet.
    groups_by_slice = dict(zip(relu_slices, relu_groups))
    groups_by_slice[self.batch_norm_op_slice] = self.batch_norm_op_group
    self.op_group_dict = groups_by_slice

    # Run the handler under test.
    concat_handler = concat_op_handler.ConcatOpHandler()
    concat_handler.assign_grouping(self.concat_op, self.mock_op_reg_manager)

    # Expected sequence of OpSlice lookups performed on the manager.
    inputs_then_output = relu_ops + [self.batch_norm_op]
    expected_lookups = (
        # Checking for ops to process.
        inputs_then_output
        # Initial slice data.
        + [self.concat_op] + inputs_then_output
        # Reslicing.
        + relu_ops + [self.concat_op, self.batch_norm_op]
        # Refreshing slice data.
        + inputs_then_output
        # Group concat op.
        + [self.concat_op])
    self.mock_op_reg_manager.get_op_slices.assert_has_calls(
        [mock.call(op) for op in expected_lookups])

    # The manager must not slice any op.
    self.mock_op_reg_manager.slice_op.assert_not_called()

    # The concat slice is grouped with all of its neighbor slices in one call.
    self.mock_op_reg_manager.group_op_slices.assert_called_once_with(
        [self.concat_op_slice] + relu_slices + [self.batch_norm_op_slice])
    def testGetInputOutputOpSlices(self):
        # Each op is backed by a single OpSlice.
        relu_slices = [
            self.relu1_op_slice, self.relu2_op_slice, self.relu3_op_slice
        ]
        self.op_slice_dict = {
            self.relu1_op: [relu_slices[0]],
            self.relu2_op: [relu_slices[1]],
            self.relu3_op: [relu_slices[2]],
            self.concat_op: [self.concat_op_slice],
            self.batch_norm_op: [self.batch_norm_op_slice],
        }

        # The axis op is passed as an input but is not expected to show up in
        # the returned input slices.
        concat_inputs = [
            self.relu1_op, self.relu2_op, self.relu3_op, self.axis_op
        ]
        concat_outputs = [self.batch_norm_op]

        # Expected result: (input op slices, output op slices).
        expected = ([relu_slices], [[self.batch_norm_op_slice]])

        concat_handler = concat_op_handler.ConcatOpHandler()
        actual = concat_handler._get_input_output_op_slices(
            concat_inputs, concat_outputs, self.mock_op_reg_manager)

        self.assertEqual(expected, actual)
  def testImageIsNotZerothOutputOfOp(self):
    # Throughout the framework, we assume that the 0th output of each op is the
    # only one of interest. One exception that often happens is when the input
    # image comes from a queue or from a staging op. Then the image is one of
    # the outputs of the dequeue (or staging) op, not necessarily the 0th one.
    # Here we test that OpRegularizerManager together with CostCalculator deals
    # correctly with this case.

    # Create an input op where the image is output number 1, not 0.
    # TODO(g1) Move this mechanism to add_concat_model_stub, possibly using
    # tf.split to produce an op where the image is not the 0th output image
    # (instead of FIFOQueue).
    image = add_concat_model_stub.image_stub()
    non_image_tensor = tf.zeros(shape=(41,))
    # Two-component queue: component 0 is the dummy tensor, component 1 is the
    # image.
    queue = tf.FIFOQueue(
        capacity=1,
        dtypes=(tf.float32,) * 2,
        shapes=(non_image_tensor.shape, image.shape))

    # Pass the image (output[1]) to the network.
    with arg_scope(self._batch_norm_scope()):
      output_op = add_concat_model_stub.build_model(queue.dequeue()[1])

    # Create OpHandler dict for test.  Op types that are not listed fall back
    # to GroupingOpHandler via the defaultdict factory.
    op_handler_dict = collections.defaultdict(
        grouping_op_handler.GroupingOpHandler)
    op_handler_dict.update({
        'FusedBatchNorm':
        batch_norm_source_op_handler.BatchNormSourceOpHandler(0.1),
        'Conv2D':
        output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(),
        'ConcatV2':
        concat_op_handler.ConcatOpHandler(),
    })

    # Create OpRegularizerManager and CostCalculator for test.
    manager = orm.OpRegularizerManager([output_op], op_handler_dict)
    calculator = cost_calculator.CostCalculator(
        manager, resource_function.flop_function)

    # Calculate expected FLOP cost.
    expected_alive_conv1 = sum(add_concat_model_stub.expected_alive()['conv1'])
    conv1_op = tf.get_default_graph().get_operation_by_name('conv1/Conv2D')
    conv1_coeff = resource_function.flop_coeff(conv1_op)
    # NOTE(review): presumably the number of channels of the input image fed
    # to conv1 -- confirm against add_concat_model_stub.image_stub().
    num_channels = 3
    expected_cost = conv1_coeff * num_channels * expected_alive_conv1

    with self.session():
      tf.global_variables_initializer().run()
      # Set gamma values to replicate aliveness in add_concat_model_stub.
      name_to_var = {v.op.name: v for v in tf.global_variables()}
      gamma1 = name_to_var['conv1/BatchNorm/gamma']
      gamma1.assign([0, 1, 1, 0, 1, 0, 1]).eval()
      gamma4 = name_to_var['conv4/BatchNorm/gamma']
      gamma4.assign([0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0]).eval()

      # Enqueue one (dummy tensor, image) pair before the cost tensor is
      # evaluated.
      queue.enqueue((non_image_tensor, image)).run()
      self.assertEqual(expected_cost,
                       calculator.get_cost([conv1_op]).eval())
Esempio n. 6
0
    def __init__(self,
                 ops,
                 gamma_threshold,
                 regularizer_decorator: Type[
                     generic_regularizers.OpRegularizer] = None,
                 decorator_parameters=None,
                 force_group=None,
                 regularizer_blacklist=None):
        """Creates a GammaFlopsRegularizer object.

        Args:
          ops: A list of tf.Operation. An OpRegularizer will be created for
            all the ops in `ops`, and recursively for all ops they depend on
            via data dependency. Typically `ops` would contain a single
            tf.Operation, which is the output of the network.
          gamma_threshold: A float scalar, will be used as a 'gamma_threshold'
            for all instances GammaL1Regularizer created by this class.
          regularizer_decorator: A class of OpRegularizer decorator to use.
          decorator_parameters: A dictionary of parameters to pass to the
            decorator factory. To be used only with decorators that requires
            parameters, otherwise use None.
          force_group: List of regex for ops that should be force-grouped.
            Each regex corresponds to a separate group.  Use '|' operator to
            specify multiple patterns in a single regex. See
            op_regularizer_manager for more detail.
          regularizer_blacklist: List of regex for ops that should not be
            regularized. See op_regularizer_manager for more detail.
        """
        batch_norm_handler = (
            batch_norm_source_op_handler.BatchNormSourceOpHandler(
                gamma_threshold))
        if regularizer_decorator:
            # Wrap the source handler so the decorator is applied to it.
            batch_norm_handler = op_handler_decorator.OpHandlerDecorator(
                batch_norm_handler, regularizer_decorator,
                decorator_parameters)

        leaf = leaf_op_handler.LeafOpHandler
        non_passthrough = (
            output_non_passthrough_op_handler.OutputNonPassthroughOpHandler)

        # Unlisted op types fall back to GroupingOpHandler.  Both batch norm
        # op types share the same source handler instance; every other key
        # gets its own handler instance.
        handlers = collections.defaultdict(
            grouping_op_handler.GroupingOpHandler)
        handlers['FusedBatchNorm'] = batch_norm_handler
        handlers['FusedBatchNormV2'] = batch_norm_handler
        handlers['Conv2D'] = non_passthrough()
        handlers['ConcatV2'] = concat_op_handler.ConcatOpHandler()
        handlers['DepthToSpace'] = non_passthrough()
        handlers['DepthwiseConv2dNative'] = (
            depthwise_convolution_op_handler.DepthwiseConvolutionOpHandler())
        handlers['MatMul'] = non_passthrough()
        handlers['TensorArrayGatherV3'] = leaf()
        handlers['RandomUniform'] = leaf()
        handlers['Reshape'] = leaf()
        handlers['Transpose'] = non_passthrough()
        handlers['ExpandDims'] = non_passthrough()

        self._manager = orm.OpRegularizerManager(
            ops,
            handlers,
            force_group=force_group,
            regularizer_blacklist=regularizer_blacklist)
        self._calculator = cost_calculator.CostCalculator(
            self._manager, resource_function.flop_function)
Esempio n. 7
0
    def __init__(self,
                 ops,
                 threshold,
                 l1_fraction=0,
                 regularizer_decorator: Type[
                     generic_regularizers.OpRegularizer] = None,
                 decorator_parameters=None,
                 force_group=None,
                 regularizer_blacklist=None,
                 convert_to_variable=True):
        """Creates a GroupLassoFlopsRegularizer object.

        Args:
          ops: A list of tf.Operation. An OpRegularizer will be created for
            all the ops in `ops`, and recursively for all ops they depend on
            via data dependency. Typically `ops` would contain a single
            tf.Operation, which is the output of the network.
          threshold: A float scalar, will be used as a 'threshold' for all
            regularizer instances created by this class.
          l1_fraction: Relative weight of L1 in L1 + L2 regularization.
          regularizer_decorator: A class of OpRegularizer decorator to use.
          decorator_parameters: A dictionary of parameters to pass to the
            decorator factory. To be used only with decorators that requires
            parameters, otherwise use None.
          force_group: List of regex for ops that should be force-grouped.
            Each regex corresponds to a separate group.  Use '|' operator to
            specify multiple patterns in a single regex. See
            op_regularizer_manager for more detail.
          regularizer_blacklist: List of regex for ops that should not be
            regularized. See op_regularizer_manager for more detail.
          convert_to_variable: If `True` convert to variable in the
            `GroupLassoBaseOpHandler`. If your graph creates variables outside
            of `tf.get_variable`, set to `False`.
        """
        # Source handlers, in the order: Conv2D, Conv2DBackpropInput, MatMul.
        source_handlers = [
            conv2d_source_op_handler.Conv2DSourceOpHandler(
                threshold, l1_fraction, convert_to_variable),
            conv2d_transpose_source_op_handler.Conv2DTransposeSourceOpHandler(
                threshold, l1_fraction, convert_to_variable),
            matmul_source_op_handler.MatMulSourceOpHandler(
                threshold, l1_fraction, convert_to_variable),
        ]
        if regularizer_decorator:
            # Wrap every source handler with the requested decorator.
            source_handlers = [
                op_handler_decorator.OpHandlerDecorator(
                    handler, regularizer_decorator, decorator_parameters)
                for handler in source_handlers
            ]
        conv2d_handler, conv2d_transpose_handler, matmul_handler = (
            source_handlers)

        leaf = leaf_op_handler.LeafOpHandler
        non_passthrough = (
            output_non_passthrough_op_handler.OutputNonPassthroughOpHandler)

        # Unlisted op types fall back to GroupingOpHandler.
        handlers = collections.defaultdict(
            grouping_op_handler.GroupingOpHandler)
        handlers['Conv2D'] = conv2d_handler
        handlers['Conv2DBackpropInput'] = conv2d_transpose_handler
        handlers['ConcatV2'] = concat_op_handler.ConcatOpHandler()
        handlers['DepthToSpace'] = non_passthrough()
        handlers['DepthwiseConv2dNative'] = (
            depthwise_convolution_op_handler.DepthwiseConvolutionOpHandler())
        handlers['MatMul'] = matmul_handler
        handlers['RandomUniform'] = leaf()
        handlers['Reshape'] = leaf()
        handlers['Shape'] = leaf()
        handlers['TensorArrayGatherV3'] = leaf()
        handlers['Transpose'] = non_passthrough()
        handlers['StridedSlice'] = leaf()

        self._manager = orm.OpRegularizerManager(
            ops,
            handlers,
            force_group=force_group,
            regularizer_blacklist=regularizer_blacklist)
        self._calculator = cost_calculator.CostCalculator(
            self._manager, resource_function.flop_function)
Esempio n. 8
0
    def __init__(self,
                 ops,
                 gamma_threshold,
                 hardware,
                 batch_size=1,
                 regularizer_decorator: Type[
                     generic_regularizers.OpRegularizer] = None,
                 decorator_parameters=None,
                 force_group=None,
                 regularizer_blacklist=None) -> None:
        """Creates a GammaLatencyRegularizer object.

        Latency cost and regularization loss is calculated for a specified
        hardware platform.

        Args:
          ops: A list of tf.Operation. An OpRegularizer will be created for
            all the ops in `ops`, and recursively for all ops they depend on
            via data dependency. Typically `ops` would contain a single
            tf.Operation, which is the output of the network.
          gamma_threshold: A float scalar, will be used as a 'gamma_threshold'
            for all instances GammaL1Regularizer created by this class.
          hardware: String name of hardware platform to target.  Must be a key
            from resource_function.PEAK_COMPUTE.
          batch_size: Integer batch size to calculate cost/loss for.
          regularizer_decorator: A class of OpRegularizer decorator to use
            (matching the `Type[generic_regularizers.OpRegularizer]`
            annotation). When set, the batch norm source handler is wrapped in
            an `OpHandlerDecorator` built with it.
          decorator_parameters: A dictionary of parameters to pass to the
            decorator factory. To be used only with decorators that requires
            parameters, otherwise use None.
          force_group: List of regex for ops that should be force-grouped.
            Each regex corresponds to a separate group.  Use '|' operator to
            specify multiple patterns in a single regex. See
            op_regularizer_manager for more detail.
          regularizer_blacklist: List of regex for ops that should not be
            regularized. See op_regularizer_manager for more detail.
        """
        source_op_handler = batch_norm_source_op_handler.BatchNormSourceOpHandler(
            gamma_threshold)
        if regularizer_decorator:
            source_op_handler = op_handler_decorator.OpHandlerDecorator(
                source_op_handler, regularizer_decorator, decorator_parameters)
        # Op types without an explicit entry fall back to GroupingOpHandler.
        op_handler_dict = collections.defaultdict(
            grouping_op_handler.GroupingOpHandler)
        op_handler_dict.update({
            'FusedBatchNorm':
            source_op_handler,
            'FusedBatchNormV2':
            source_op_handler,
            'Conv2D':
            output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(),
            'ConcatV2':
            concat_op_handler.ConcatOpHandler(),
            'DepthToSpace':
            output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(),
            'DepthwiseConv2dNative':
            depthwise_convolution_op_handler.DepthwiseConvolutionOpHandler(),
            'MatMul':
            output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(),
            'TensorArrayGatherV3':
            leaf_op_handler.LeafOpHandler(),
            'Transpose':
            output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(),
        })

        self._manager = orm.OpRegularizerManager(
            ops,
            op_handler_dict,
            force_group=force_group,
            regularizer_blacklist=regularizer_blacklist)
        # The cost function is latency for the given hardware and batch size.
        self._calculator = cost_calculator.CostCalculator(
            self._manager,
            resource_function.latency_function_factory(hardware, batch_size))
        self._hardware = hardware
    def testAssignGrouping_NoNeighborGroups(self):
        # Scenario: neither the inputs nor the output of the concat op have
        # groups.  The concat and batch norm ops are sliced, but grouping is
        # deferred until the inputs and outputs are grouped.
        relu_ops = [self.relu1_op, self.relu2_op, self.relu3_op]
        relu_slices = [
            self.relu1_op_slice, self.relu2_op_slice, self.relu3_op_slice
        ]

        # Op -> list of OpSlice (a single slice per op).
        slices_by_op = {
            op: [op_slice] for op, op_slice in zip(relu_ops, relu_slices)
        }
        slices_by_op[self.concat_op] = [self.concat_op_slice]
        slices_by_op[self.batch_norm_op] = [self.batch_norm_op_slice]
        self.op_slice_dict = slices_by_op

        # Only the sliced concat pieces have groups; no neighbor slice does.
        self.op_group_dict = dict([
            (self.concat_op_slice_0_5, self.concat_op_group1),
            (self.concat_op_slice_5_11, self.concat_op_group2),
            (self.concat_op_slice_11_18, self.concat_op_group3),
        ])

        # Run the handler under test.
        concat_handler = concat_op_handler.ConcatOpHandler()
        concat_handler.assign_grouping(self.concat_op,
                                       self.mock_op_reg_manager)

        # Expected sequence of OpSlice lookups performed on the manager.
        inputs_then_output = relu_ops + [self.batch_norm_op]
        expected_lookups = (
            # Checking for ops to process.
            inputs_then_output
            # Initial slice data.
            + [self.concat_op] + inputs_then_output
            # Reslicing.
            + inputs_then_output + [self.concat_op]
            # Refreshing slice data.
            + inputs_then_output)
        self.mock_op_reg_manager.get_op_slices.assert_has_calls(
            [mock.call(op) for op in expected_lookups])

        # Ops whose OpSlice sizes are not aligned get sliced: batch norm
        # first, then concat.
        self.mock_op_reg_manager.slice_op.assert_has_calls([
            mock.call(sliced_op, [5, 6, 7])
            for sliced_op in (self.batch_norm_op, self.concat_op)
        ])

        # Nothing is grouped at this stage.
        self.mock_op_reg_manager.group_op_slices.assert_not_called()

        # The ungrouped neighbors are queued for processing and the concat op
        # is queued to be processed last.
        self.mock_op_reg_manager.process_ops.assert_called_once_with(
            [self.batch_norm_op] + relu_ops)
        self.mock_op_reg_manager.process_ops_last.assert_called_once_with(
            [self.concat_op])
    def testAssignGrouping_OutputsGrouped(self):
        # Scenario: only the output ops are grouped.  The concat and batch
        # norm ops will be sliced according to the input sizes.
        input_ops = [self.relu1_op, self.relu2_op, self.relu3_op]
        output_op = self.batch_norm_op

        # Op -> list of OpSlice (a single slice per op).
        slices_by_op = {}
        slices_by_op[self.relu1_op] = [self.relu1_op_slice]
        slices_by_op[self.relu2_op] = [self.relu2_op_slice]
        slices_by_op[self.relu3_op] = [self.relu3_op_slice]
        slices_by_op[self.concat_op] = [self.concat_op_slice]
        slices_by_op[output_op] = [self.batch_norm_op_slice]
        self.op_slice_dict = slices_by_op

        # OpSlice -> OpGroup.  Input ops (ReLU) are deliberately left without
        # a group.
        groups_by_slice = {}
        groups_by_slice[self.concat_op_slice_0_5] = self.concat_op_group1
        groups_by_slice[self.concat_op_slice_5_11] = self.concat_op_group2
        groups_by_slice[self.concat_op_slice_11_18] = self.concat_op_group3
        groups_by_slice[self.batch_norm_op_slice] = self.batch_norm_op_group
        groups_by_slice[self.batch_norm_op_slice_0_5] = (
            self.batch_norm_op_group1)
        groups_by_slice[self.batch_norm_op_slice_5_11] = (
            self.batch_norm_op_group2)
        groups_by_slice[self.batch_norm_op_slice_11_18] = (
            self.batch_norm_op_group3)
        self.op_group_dict = groups_by_slice

        # Run the handler under test.
        concat_handler = concat_op_handler.ConcatOpHandler()
        concat_handler.assign_grouping(self.concat_op,
                                       self.mock_op_reg_manager)

        # Expected sequence of OpSlice lookups performed on the manager.
        neighbor_ops = input_ops + [output_op]
        expected_lookups = []
        # Checking for ops to process.
        expected_lookups.extend(neighbor_ops)
        # Initial slice data.
        expected_lookups.append(self.concat_op)
        expected_lookups.extend(neighbor_ops)
        # Reslicing.
        expected_lookups.extend(neighbor_ops)
        expected_lookups.append(self.concat_op)
        # Refreshing slice data.
        expected_lookups.extend(neighbor_ops)
        self.mock_op_reg_manager.get_op_slices.assert_has_calls(
            [mock.call(op) for op in expected_lookups])

        # Ops whose OpSlice sizes are not aligned get sliced: batch norm
        # first, then concat.
        expected_slicing = [
            mock.call(output_op, [5, 6, 7]),
            mock.call(self.concat_op, [5, 6, 7]),
        ]
        self.mock_op_reg_manager.slice_op.assert_has_calls(expected_slicing)

        # No grouping takes place.
        self.mock_op_reg_manager.group_op_slices.assert_not_called()

        # The ungrouped inputs are queued for processing and the concat op is
        # queued to be processed last.
        self.mock_op_reg_manager.process_ops.assert_called_once_with(
            input_ops)
        self.mock_op_reg_manager.process_ops_last.assert_called_once_with(
            [self.concat_op])
    def testAssignGrouping_AllNeighborsGrouped_InputSlicesNotAligned(self):
        # In this test, the second input (relu2_op, "c2") is split into 2
        # slices of size 3.  The concat op and its output (the batch norm) are
        # sliced as [5, 3, 3, 7] (see the slice_op assertions below).  This
        # test verifies that the concat and batch norm are sliced according to
        # the sizes of the three inputs (c1, c2, and c3), and takes into
        # account that c2 is also composed of multiple slices.
        #
        # NOTE(review): judging by the variable names, orm.Slice takes
        # (start index, size) -- confirm against op_regularizer_manager.
        concat_op_slice_0_5 = orm.OpSlice(self.concat_op, orm.Slice(0, 5))
        concat_op_slice_5_8 = orm.OpSlice(self.concat_op, orm.Slice(5, 3))
        concat_op_slice_8_11 = orm.OpSlice(self.concat_op, orm.Slice(8, 3))
        concat_op_slice_11_18 = orm.OpSlice(self.concat_op, orm.Slice(11, 7))

        # The two slices of relu2_op, each with its own group.
        relu2_op_slice_0_3 = orm.OpSlice(self.relu2_op, orm.Slice(0, 3))
        relu2_op_slice_3_6 = orm.OpSlice(self.relu2_op, orm.Slice(3, 3))
        relu2_op_group1 = orm.OpGroup(
            relu2_op_slice_0_3, omit_source_op_slices=[relu2_op_slice_0_3])
        relu2_op_group2 = orm.OpGroup(
            relu2_op_slice_3_6, omit_source_op_slices=[relu2_op_slice_3_6])

        # Batch norm: the initial unsliced OpSlice plus the four slices that
        # are installed once the op is sliced, each with its own group.
        batch_norm_op_slice = orm.OpSlice(self.batch_norm_op, orm.Slice(0, 18))
        batch_norm_op_group = orm.OpGroup(
            batch_norm_op_slice, omit_source_op_slices=[batch_norm_op_slice])
        batch_norm_op_slice_0_5 = orm.OpSlice(self.batch_norm_op,
                                              orm.Slice(0, 5))
        batch_norm_op_group1 = orm.OpGroup(
            batch_norm_op_slice_0_5,
            omit_source_op_slices=[batch_norm_op_slice_0_5])
        batch_norm_op_slice_5_8 = orm.OpSlice(self.batch_norm_op,
                                              orm.Slice(5, 3))
        batch_norm_op_group2 = orm.OpGroup(
            batch_norm_op_slice_5_8,
            omit_source_op_slices=[batch_norm_op_slice_5_8])
        batch_norm_op_slice_8_11 = orm.OpSlice(self.batch_norm_op,
                                               orm.Slice(8, 3))
        batch_norm_op_group3 = orm.OpGroup(
            batch_norm_op_slice_8_11,
            omit_source_op_slices=[batch_norm_op_slice_8_11])
        batch_norm_op_slice_11_18 = orm.OpSlice(self.batch_norm_op,
                                                orm.Slice(11, 7))
        batch_norm_op_group4 = orm.OpGroup(
            batch_norm_op_slice_11_18,
            omit_source_op_slices=[batch_norm_op_slice_11_18])

        # Map ops to slices.  The op c2 (relu2_op) is composed of multiple
        # slices; every other op starts as a single slice.
        self.op_slice_dict = {
            self.relu1_op: [self.relu1_op_slice],
            self.relu2_op: [relu2_op_slice_0_3, relu2_op_slice_3_6],
            self.relu3_op: [self.relu3_op_slice],
            self.concat_op: [self.concat_op_slice],
            self.batch_norm_op: [batch_norm_op_slice],
        }

        # Map each slice to a group.  The concat slices have no group.
        self.op_group_dict = {
            self.relu1_op_slice: self.relu1_op_group,
            relu2_op_slice_0_3: relu2_op_group1,
            relu2_op_slice_3_6: relu2_op_group2,
            self.relu3_op_slice: self.relu3_op_group,
            batch_norm_op_slice: batch_norm_op_group,
            batch_norm_op_slice_0_5: batch_norm_op_group1,
            batch_norm_op_slice_5_8: batch_norm_op_group2,
            batch_norm_op_slice_8_11: batch_norm_op_group3,
            batch_norm_op_slice_11_18: batch_norm_op_group4,
        }

        # Update op_slice_dict when an op is sliced.  The requested sizes are
        # ignored; the pre-built slices above are installed instead.
        def slice_op(op, _):
            if op == self.batch_norm_op:
                self.op_slice_dict[self.batch_norm_op] = [
                    batch_norm_op_slice_0_5, batch_norm_op_slice_5_8,
                    batch_norm_op_slice_8_11, batch_norm_op_slice_11_18
                ]
            if op == self.concat_op:
                self.op_slice_dict[self.concat_op] = [
                    concat_op_slice_0_5, concat_op_slice_5_8,
                    concat_op_slice_8_11, concat_op_slice_11_18
                ]

        self.mock_op_reg_manager.slice_op.side_effect = slice_op

        # Call handler to assign grouping.
        handler = concat_op_handler.ConcatOpHandler()
        handler.assign_grouping(self.concat_op, self.mock_op_reg_manager)

        # Verify manager looks up OpSlice for ops of interest.
        self.mock_op_reg_manager.get_op_slices.assert_has_calls(
            # Checking for ops to process.
            [
                mock.call(self.relu1_op),
                mock.call(self.relu2_op),
                mock.call(self.relu3_op),
                mock.call(self.batch_norm_op),
                # Initial slice data.
                mock.call(self.concat_op),
                mock.call(self.relu1_op),
                mock.call(self.relu2_op),
                mock.call(self.relu3_op),
                mock.call(self.batch_norm_op),
                # Reslicing.
                mock.call(self.relu1_op),
                mock.call(self.relu2_op),
                mock.call(self.relu3_op),
                mock.call(self.batch_norm_op),
                mock.call(self.concat_op),
                # Refreshing slice data.
                mock.call(self.relu1_op),
                mock.call(self.relu2_op),
                mock.call(self.relu3_op),
                mock.call(self.batch_norm_op),
                # Group concat op.
                mock.call(self.concat_op)
            ])

        # Verify manager slices ops that do not have aligned OpSlice sizes.
        self.mock_op_reg_manager.slice_op.assert_has_calls([
            mock.call(self.batch_norm_op, [5, 3, 3, 7]),
            mock.call(self.concat_op, [5, 3, 3, 7])
        ])

        # Verify manager groups the new slices: each concat slice is grouped
        # with the input slice and the batch norm slice at the same position.
        self.mock_op_reg_manager.group_op_slices.assert_has_calls([
            mock.call([
                concat_op_slice_0_5, self.relu1_op_slice,
                batch_norm_op_slice_0_5
            ]),
            mock.call([
                concat_op_slice_5_8, relu2_op_slice_0_3,
                batch_norm_op_slice_5_8
            ]),
            mock.call([
                concat_op_slice_8_11, relu2_op_slice_3_6,
                batch_norm_op_slice_8_11
            ]),
            mock.call([
                concat_op_slice_11_18, self.relu3_op_slice,
                batch_norm_op_slice_11_18
            ])
        ])
    def testAssignGrouping_AllNeighborsGrouped_OutputSlicesNotAligned(self):
        # The output (batch norm) starts out sliced into sizes [9, 4, 5], which
        # do not line up with the concat inputs.  This test verifies that the
        # handler reslices relu2, relu3, batch norm and concat so that their
        # boundaries agree (sizes [5, 4, 2, 2, 5] for batch norm and concat) and
        # then groups the corresponding slices together.

        # Concat slices expected after reslicing.
        concat_op_slice_0_5 = orm.OpSlice(self.concat_op, orm.Slice(0, 5))
        concat_op_slice_5_9 = orm.OpSlice(self.concat_op, orm.Slice(5, 4))
        concat_op_slice_9_11 = orm.OpSlice(self.concat_op, orm.Slice(9, 2))
        concat_op_slice_11_13 = orm.OpSlice(self.concat_op, orm.Slice(11, 2))
        concat_op_slice_13_18 = orm.OpSlice(self.concat_op, orm.Slice(13, 5))

        # Input slices expected after reslicing relu2 and relu3.
        relu2_op_slice_0_4 = orm.OpSlice(self.relu2_op, orm.Slice(0, 4))
        relu2_op_slice_4_6 = orm.OpSlice(self.relu2_op, orm.Slice(4, 2))

        relu3_op_slice_0_2 = orm.OpSlice(self.relu3_op, orm.Slice(0, 2))
        relu3_op_slice_2_7 = orm.OpSlice(self.relu3_op, orm.Slice(2, 5))

        # Batch norm slices before reslicing: sizes [9, 4, 5].
        batch_norm_op_slice_0_9 = orm.OpSlice(self.batch_norm_op,
                                              orm.Slice(0, 9))
        batch_norm_op_group1 = orm.OpGroup(
            batch_norm_op_slice_0_9,
            omit_source_op_slices=[batch_norm_op_slice_0_9])
        batch_norm_op_slice_9_13 = orm.OpSlice(self.batch_norm_op,
                                               orm.Slice(9, 4))
        batch_norm_op_group2 = orm.OpGroup(
            batch_norm_op_slice_9_13,
            omit_source_op_slices=[batch_norm_op_slice_9_13])
        # The last slice starts at index 13 with size 5 in both the initial
        # [9, 4, 5] slicing and the resliced [5, 4, 2, 2, 5] slicing, so this
        # single OpSlice (and its group) is used for both.
        batch_norm_op_slice_13_18 = orm.OpSlice(self.batch_norm_op,
                                                orm.Slice(13, 5))
        batch_norm_op_group3 = orm.OpGroup(
            batch_norm_op_slice_13_18,
            omit_source_op_slices=[batch_norm_op_slice_13_18])

        # Batch norm slices that only exist after reslicing.
        batch_norm_op_slice_0_5 = orm.OpSlice(self.batch_norm_op,
                                              orm.Slice(0, 5))
        batch_norm_op_group4 = orm.OpGroup(
            batch_norm_op_slice_0_5,
            omit_source_op_slices=[batch_norm_op_slice_0_5])
        batch_norm_op_slice_5_9 = orm.OpSlice(self.batch_norm_op,
                                              orm.Slice(5, 4))
        batch_norm_op_group5 = orm.OpGroup(
            batch_norm_op_slice_5_9,
            omit_source_op_slices=[batch_norm_op_slice_5_9])
        batch_norm_op_slice_9_11 = orm.OpSlice(self.batch_norm_op,
                                               orm.Slice(9, 2))
        batch_norm_op_group6 = orm.OpGroup(
            batch_norm_op_slice_9_11,
            omit_source_op_slices=[batch_norm_op_slice_9_11])
        batch_norm_op_slice_11_13 = orm.OpSlice(self.batch_norm_op,
                                                orm.Slice(11, 2))
        batch_norm_op_group7 = orm.OpGroup(
            batch_norm_op_slice_11_13,
            omit_source_op_slices=[batch_norm_op_slice_11_13])

        # Map ops to slices.  Batch norm op is composed of multiple slices.
        self.op_slice_dict = {
            self.relu1_op: [self.relu1_op_slice],
            self.relu2_op: [self.relu2_op_slice],
            self.relu3_op: [self.relu3_op_slice],
            self.concat_op: [self.concat_op_slice],
            self.batch_norm_op: [
                batch_norm_op_slice_0_9, batch_norm_op_slice_9_13,
                batch_norm_op_slice_13_18
            ],
        }

        # Map each slice to a group.  Concat slices are intentionally absent.
        self.op_group_dict = {
            self.relu1_op_slice: self.relu1_op_group,
            self.relu2_op_slice: self.relu2_op_group,
            self.relu3_op_slice: self.relu3_op_group,
            batch_norm_op_slice_0_9: batch_norm_op_group1,
            batch_norm_op_slice_9_13: batch_norm_op_group2,
            batch_norm_op_slice_13_18: batch_norm_op_group3,
            batch_norm_op_slice_0_5: batch_norm_op_group4,
            batch_norm_op_slice_5_9: batch_norm_op_group5,
            batch_norm_op_slice_9_11: batch_norm_op_group6,
            batch_norm_op_slice_11_13: batch_norm_op_group7,
        }

        # Slices each op should report once the manager has been asked to
        # slice it.  relu1 is not listed because it is never resliced.
        resliced_op_slices = {
            self.batch_norm_op: [
                batch_norm_op_slice_0_5, batch_norm_op_slice_5_9,
                batch_norm_op_slice_9_11, batch_norm_op_slice_11_13,
                batch_norm_op_slice_13_18
            ],
            self.concat_op: [
                concat_op_slice_0_5, concat_op_slice_5_9,
                concat_op_slice_9_11, concat_op_slice_11_13,
                concat_op_slice_13_18
            ],
            self.relu2_op: [relu2_op_slice_0_4, relu2_op_slice_4_6],
            self.relu3_op: [relu3_op_slice_0_2, relu3_op_slice_2_7],
        }

        # Update op_slice_dict when an op is sliced.
        def slice_op(op, _):
            if op in resliced_op_slices:
                # Hand out a fresh list so later mutation cannot leak back into
                # the lookup table.
                self.op_slice_dict[op] = list(resliced_op_slices[op])

        self.mock_op_reg_manager.slice_op.side_effect = slice_op

        # Call handler to assign grouping.
        handler = concat_op_handler.ConcatOpHandler()
        handler.assign_grouping(self.concat_op, self.mock_op_reg_manager)

        # Verify manager looks up OpSlice for ops of interest.
        self.mock_op_reg_manager.get_op_slices.assert_has_calls(
            # Checking for ops to process.
            [
                mock.call(self.relu1_op),
                mock.call(self.relu2_op),
                mock.call(self.relu3_op),
                mock.call(self.batch_norm_op),
                # Initial slice data.
                mock.call(self.concat_op),
                mock.call(self.relu1_op),
                mock.call(self.relu2_op),
                mock.call(self.relu3_op),
                mock.call(self.batch_norm_op),
                # Reslicing.
                mock.call(self.relu1_op),
                mock.call(self.relu2_op),
                mock.call(self.relu3_op),
                mock.call(self.batch_norm_op),
                mock.call(self.concat_op),
                # Refreshing slice data.
                mock.call(self.relu1_op),
                mock.call(self.relu2_op),
                mock.call(self.relu3_op),
                mock.call(self.batch_norm_op),
                # Group concat op.
                mock.call(self.concat_op)
            ])

        # Verify manager slices ops that do not have aligned OpSlice sizes.
        self.mock_op_reg_manager.slice_op.assert_has_calls([
            mock.call(self.relu2_op, [4, 2]),
            mock.call(self.relu3_op, [2, 5]),
            mock.call(self.batch_norm_op, [5, 4, 2, 2, 5]),
            mock.call(self.concat_op, [5, 4, 2, 2, 5])
        ])

        # Verify manager groups the new slices.
        self.mock_op_reg_manager.group_op_slices.assert_has_calls([
            mock.call([
                concat_op_slice_0_5, self.relu1_op_slice,
                batch_norm_op_slice_0_5
            ]),
            mock.call([
                concat_op_slice_5_9, relu2_op_slice_0_4,
                batch_norm_op_slice_5_9
            ]),
            mock.call([
                concat_op_slice_9_11, relu2_op_slice_4_6,
                batch_norm_op_slice_9_11
            ]),
            mock.call([
                concat_op_slice_11_13, relu3_op_slice_0_2,
                batch_norm_op_slice_11_13
            ]),
            mock.call([
                concat_op_slice_13_18, relu3_op_slice_2_7,
                batch_norm_op_slice_13_18
            ])
        ])
    def testAssignGrouping_AllNeighborsGrouped_SlicesAligned_SameGroup(self):
        # Concat and batch norm are already sliced identically, and every
        # concat slice already shares a group with its matching input and
        # output slice.  The handler is therefore expected to neither slice
        # nor group anything.
        manager = self.mock_op_reg_manager

        relu_slices = [
            self.relu1_op_slice,
            self.relu2_op_slice,
            self.relu3_op_slice,
        ]
        concat_slices = [
            self.concat_op_slice_0_5,
            self.concat_op_slice_5_11,
            self.concat_op_slice_11_18,
        ]
        batch_norm_slices = [
            self.batch_norm_op_slice_0_5,
            self.batch_norm_op_slice_5_11,
            self.batch_norm_op_slice_11_18,
        ]
        shared_groups = [
            self.batch_norm_op_group1,
            self.batch_norm_op_group2,
            self.batch_norm_op_group3,
        ]

        # Op -> slices.  Each relu is a single slice; concat and batch norm are
        # each made of three slices.
        self.op_slice_dict = {
            self.relu1_op: [relu_slices[0]],
            self.relu2_op: [relu_slices[1]],
            self.relu3_op: [relu_slices[2]],
            self.concat_op: list(concat_slices),
            self.batch_norm_op: list(batch_norm_slices),
        }

        # Slice -> group.  The i-th relu, concat and batch norm slices all map
        # to the i-th batch norm group.
        slice_to_group = {}
        for slice_family in (relu_slices, concat_slices, batch_norm_slices):
            for op_slice, group in zip(slice_family, shared_groups):
                slice_to_group[op_slice] = group
        self.op_group_dict = slice_to_group

        # Run the handler under test.
        concat_op_handler.ConcatOpHandler().assign_grouping(
            self.concat_op, manager)

        # Expected get_op_slices lookups, in order.
        neighbor_lookups = [
            mock.call(neighbor_op) for neighbor_op in (
                self.relu1_op, self.relu2_op, self.relu3_op,
                self.batch_norm_op)
        ]
        concat_lookup = [mock.call(self.concat_op)]
        expected_lookups = (
            neighbor_lookups  # Checking for ops to process.
            + concat_lookup + neighbor_lookups  # Initial slice data.
            + neighbor_lookups + concat_lookup  # Reslicing.
            + neighbor_lookups  # Refreshing slice data.
            + concat_lookup)  # Group concat op.
        manager.get_op_slices.assert_has_calls(expected_lookups)

        # Nothing should have been sliced.
        manager.slice_op.assert_not_called()

        # Nothing should have been grouped.
        manager.group_op_slices.assert_not_called()
    def testAssignGrouping_AllNeighborsGrouped_SlicesAligned(self):
        # The output op (batch norm) is already split into three slices while
        # the concat op is still a single slice.  The handler is expected to
        # slice only the concat op, into sizes [5, 6, 7], and then group each
        # concat slice with its matching relu and batch norm slice.
        manager = self.mock_op_reg_manager

        relu_slices = [
            self.relu1_op_slice,
            self.relu2_op_slice,
            self.relu3_op_slice,
        ]
        batch_norm_slices = [
            self.batch_norm_op_slice_0_5,
            self.batch_norm_op_slice_5_11,
            self.batch_norm_op_slice_11_18,
        ]

        # Op -> slices.  Only batch norm starts out with multiple slices.
        self.op_slice_dict = {
            self.relu1_op: [relu_slices[0]],
            self.relu2_op: [relu_slices[1]],
            self.relu3_op: [relu_slices[2]],
            self.concat_op: [self.concat_op_slice],
            self.batch_norm_op: list(batch_norm_slices),
        }

        # Slice -> group.  Every relu and batch norm slice has its own group;
        # the concat slice has no entry.
        slice_to_group = dict(zip(relu_slices, [
            self.relu1_op_group,
            self.relu2_op_group,
            self.relu3_op_group,
        ]))
        slice_to_group.update(zip(batch_norm_slices, [
            self.batch_norm_op_group1,
            self.batch_norm_op_group2,
            self.batch_norm_op_group3,
        ]))
        self.op_group_dict = slice_to_group

        # Run the handler under test.
        concat_op_handler.ConcatOpHandler().assign_grouping(
            self.concat_op, manager)

        # Expected get_op_slices lookups, in order.
        neighbor_lookups = [
            mock.call(neighbor_op) for neighbor_op in (
                self.relu1_op, self.relu2_op, self.relu3_op,
                self.batch_norm_op)
        ]
        concat_lookup = [mock.call(self.concat_op)]
        expected_lookups = (
            neighbor_lookups  # Checking for ops to process.
            + concat_lookup + neighbor_lookups  # Initial slice data.
            + neighbor_lookups + concat_lookup  # Reslicing.
            + neighbor_lookups  # Refreshing slice data.
            + concat_lookup)  # Group concat op.
        manager.get_op_slices.assert_has_calls(expected_lookups)

        # The concat op must be the one and only op that gets sliced.
        manager.slice_op.assert_called_once_with(self.concat_op, [5, 6, 7])

        # Each new concat slice is grouped with its matching relu and batch
        # norm slice, in concat order.
        concat_slices = [
            self.concat_op_slice_0_5,
            self.concat_op_slice_5_11,
            self.concat_op_slice_11_18,
        ]
        expected_groupings = [
            mock.call([concat_slice, relu_slice, batch_norm_slice])
            for concat_slice, relu_slice, batch_norm_slice in zip(
                concat_slices, relu_slices, batch_norm_slices)
        ]
        manager.group_op_slices.assert_has_calls(expected_groupings)