Example #1
    def testAssignGrouping_NeighborsHaveSameGroup(self):
        self.op_slice_dict = {
            self.batch_norm_op: [self.batch_norm_op_slice],
            self.conv_op: [self.conv_op_slice],
            self.relu_op: [self.relu_op_slice],
            self.gamma_op: [self.gamma_op_slice],
            self.beta_op: [self.beta_op_slice],
        }

        # All ops have the same group.
        self.op_group_dict = {
            self.batch_norm_op_slice: self.batch_norm_op_group,
            self.conv_op_slice: self.batch_norm_op_group,
            self.relu_op_slice: self.batch_norm_op_group,
            self.gamma_op_slice: self.batch_norm_op_group,
            self.beta_op_slice: self.batch_norm_op_group,
        }

        # Call handler to assign grouping.
        handler = batch_norm_source_op_handler.BatchNormSourceOpHandler(
            _GAMMA_THRESHOLD)
        handler.assign_grouping(self.batch_norm_op, self.mock_op_reg_manager)

        # Verify manager looks up op slice for ops of interest.
        self.mock_op_reg_manager.get_op_slices.assert_any_call(
            self.batch_norm_op)
        self.mock_op_reg_manager.get_op_slices.assert_any_call(self.conv_op)
        self.mock_op_reg_manager.get_op_slices.assert_any_call(self.relu_op)

        # Verify manager does not group any ops.
        self.mock_op_reg_manager.group_op_slices.assert_not_called()

        # Verify manager does not process additional ops.
        self.mock_op_reg_manager.process_ops.assert_not_called()
        self.mock_op_reg_manager.process_ops_last.assert_not_called()
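
These handler tests drive a mocked OpRegularizerManager whose setUp wiring does not appear on this page: get_op_slices and get_op_group resolve through the op_slice_dict and op_group_dict that each test overwrites. Below is a minimal sketch of such wiring using plain unittest.mock; the helper name and exact side effects are assumptions, not the library's actual setUp.

import unittest.mock as mock

def make_mock_manager(test_case):
    # Hypothetical helper: route lookups through the dicts each test assigns.
    manager = mock.MagicMock()
    manager.get_op_slices.side_effect = (
        lambda op: test_case.op_slice_dict.get(op, []))
    manager.get_op_group.side_effect = (
        lambda op_slice: test_case.op_group_dict.get(op_slice))
    manager.is_source_op.return_value = False  # tests override via side_effect
    return manager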
Example #2
  def __init__(
      self,
      output_boundary: List[tf.Operation],
      gamma_threshold,
      hardware,
      batch_size=1,
      regularizer_decorator: Type[generic_regularizers.OpRegularizer] = None,
      decorator_parameters=None,
      input_boundary: List[tf.Operation] = None,
      force_group=None,
      regularizer_blacklist=None) -> None:
    """Creates a GammaLatencyRegularizer object.

    Latency cost and regularization loss are calculated for a specified
    hardware platform.

    Args:
      output_boundary: An OpRegularizer will be created for all these
        operations, and recursively for all ops they depend on via data
        dependency that does not involve ops from input_boundary.
      gamma_threshold: A float scalar that will be used as the
        'gamma_threshold' for all GammaL1Regularizer instances created by this
        class.
      hardware: String name of hardware platform to target.  Must be a key from
        resource_function.PEAK_COMPUTE.
      batch_size: Integer batch size to calculate cost/loss for.
      regularizer_decorator: A string, the name of the regularizer decorator
        to use. Supported decorators are listed in
        op_regularizer_decorator.SUPPORTED_DECORATORS.
      decorator_parameters: A dictionary of parameters to pass to the decorator
        factory. Use only with decorators that require parameters; otherwise
        use None.
      input_boundary: A list of ops that represent the input boundary of the
        subgraph being regularized (input boundary is not regularized).
      force_group: A list of regexes for ops that should be force-grouped.
        Each regex corresponds to a separate group. Use the '|' operator to
        specify multiple patterns in a single regex. See op_regularizer_manager
        for more detail.
      regularizer_blacklist: A list of regexes for ops that should not be
        regularized. See op_regularizer_manager for more detail.
    """
    source_op_handler = batch_norm_source_op_handler.BatchNormSourceOpHandler(
        gamma_threshold)
    if regularizer_decorator:
      source_op_handler = op_handler_decorator.OpHandlerDecorator(
          source_op_handler, regularizer_decorator,
          decorator_parameters)
    op_handler_dict = op_handlers.get_gamma_op_handler_dict()
    op_handler_dict.update({
        'FusedBatchNorm': source_op_handler,
        'FusedBatchNormV2': source_op_handler,
        'FusedBatchNormV3': source_op_handler,
    })

    self._manager = orm.OpRegularizerManager(
        output_boundary, op_handler_dict, input_boundary=input_boundary,
        force_group=force_group, regularizer_blacklist=regularizer_blacklist)
    self._calculator = cost_calculator.CostCalculator(
        self._manager,
        resource_function.latency_function_factory(hardware, batch_size))
    self._hardware = hardware
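
For orientation, here is a hedged usage sketch of the class above. It assumes the morph_net package layout (morph_net.network_regularizers.latency_regularizer), a TF1-style graph with fused batch norm, and that 'V100' is a valid resource_function.PEAK_COMPUTE key; the model itself is a stand-in.

import tensorflow.compat.v1 as tf
from morph_net.network_regularizers import latency_regularizer

tf.disable_eager_execution()
images = tf.placeholder(tf.float32, [8, 32, 32, 3])
net = tf.layers.conv2d(images, 16, 3)
net = tf.layers.batch_normalization(net, scale=True, fused=True, training=True)
logits = tf.layers.dense(tf.layers.flatten(net), 10)

# Regularize everything logits depends on, penalizing estimated latency.
network_reg = latency_regularizer.GammaLatencyRegularizer(
    [logits.op], gamma_threshold=1e-3, hardware='V100')
reg_loss = 1e-9 * network_reg.get_regularization_term()  # add to the task loss
latency = network_reg.get_cost()  # estimated latency, useful for monitoring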
Example #3
  def testImageIsNotZerothOutputOfOp(self):
    # Throughout the framework, we assume that the 0th output of each op is the
    # only one of interest. One exception that often happens is when the input
    # image comes from a queue or from a staging op. Then the image is one of
    # the outputs of the dequeue (or staging) op, not necessarily the 0th one.
    # Here we test that the BilinearNetworkRegularizer deals correctly with this
    # case.

    # Create an input op where the image is output number 1, not 0.
    # TODO(g1) Move this mechanism to add_concat_model_stub, possibly using
    # tf.split to produce an op where the image is not the 0th output
    # (instead of using FIFOQueue).
    image = add_concat_model_stub.image_stub()
    non_image_tensor = tf.zeros(shape=(41,))
    queue = tf.FIFOQueue(
        capacity=1,
        dtypes=(tf.float32,) * 2,
        shapes=(non_image_tensor.shape, image.shape))

    # Pass the image (output[1]) to the network.
    with arg_scope(self._batch_norm_scope()):
      output_op = add_concat_model_stub.build_model(queue.dequeue()[1])

    # Create OpHandler dict for test.
    op_handler_dict = collections.defaultdict(
        grouping_op_handler.GroupingOpHandler)
    op_handler_dict.update({
        'FusedBatchNorm':
        batch_norm_source_op_handler.BatchNormSourceOpHandler(0.1),
        'Conv2D':
        output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(),
        'ConcatV2':
        concat_op_handler.ConcatOpHandler(),
    })

    # Create OpRegularizerManager and NetworkRegularizer for test.
    manager = orm.OpRegularizerManager([output_op], op_handler_dict)
    calculator = cost_calculator.CostCalculator(
        manager, resource_function.flop_function)

    # Calculate expected FLOP cost.
    expected_alive_conv1 = sum(add_concat_model_stub.expected_alive()['conv1'])
    conv1_op = tf.get_default_graph().get_operation_by_name('conv1/Conv2D')
    conv1_coeff = resource_function.flop_coeff(conv1_op)
    num_channels = 3
    expected_cost = conv1_coeff * num_channels * expected_alive_conv1

    with self.session():
      tf.global_variables_initializer().run()
      # Set gamma values to replicate aliveness in add_concat_model_stub.
      name_to_var = {v.op.name: v for v in tf.global_variables()}
      gamma1 = name_to_var['conv1/BatchNorm/gamma']
      gamma1.assign([0, 1, 1, 0, 1, 0, 1]).eval()
      gamma4 = name_to_var['conv4/BatchNorm/gamma']
      gamma4.assign([0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0]).eval()

      queue.enqueue((non_image_tensor, image)).run()
      self.assertEqual(expected_cost,
                       calculator.get_cost([conv1_op]).eval())
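
The expected-cost arithmetic above treats flop_coeff(conv_op) as the FLOP count per (input channel, output channel) pair, so the conv cost is coeff * input_channels * alive_output_channels. A standalone sketch of that accounting follows, under the assumption that a multiply-add counts as 2 FLOPs (an assumption about resource_function's convention, not a quote of it).

def conv2d_flops(out_h, out_w, k_h, k_w, in_channels, alive_out_channels):
    # One k_h * k_w dot product per output position and per channel pair,
    # with each multiply-add counted as 2 FLOPs.
    coeff = 2 * out_h * out_w * k_h * k_w
    return coeff * in_channels * alive_out_channels

# E.g. a 3x3 kernel over a 32x32 output, 3 input channels, 4 alive outputs:
assert conv2d_flops(32, 32, 3, 3, 3, 4) == 2 * 32 * 32 * 9 * 12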
Example #4
    def testCreateRegularizer(self):
        # Call handler to create regularizer.
        handler = batch_norm_source_op_handler.BatchNormSourceOpHandler(
            _GAMMA_THRESHOLD)
        regularizer = handler.create_regularizer(self.batch_norm_op_slice)

        # Verify regularizer is the gamma tensor.
        g = tf.get_default_graph()
        gamma_tensor = g.get_tensor_by_name('conv1/BatchNorm/gamma/read:0')
        self.assertEqual(gamma_tensor, regularizer._gamma)
Example #5
    def __init__(self,
                 ops,
                 gamma_threshold,
                 regularizer_decorator: Type[
                     generic_regularizers.OpRegularizer] = None,
                 decorator_parameters=None,
                 input_boundary=None,
                 force_group=None,
                 regularizer_blacklist=None):
        """Creates a GammaModelSizeRegularizer object.

    Args:
      ops: A list of tf.Operation. An OpRegularizer will be created for all
        the ops in `ops`, and recursively for all ops they depend on via data
        dependency. Typically `ops` would contain a single tf.Operation, which
        is the output of the network.
      gamma_threshold: A float scalar that will be used as the
        'gamma_threshold' for all GammaL1Regularizer instances created by this
        class.
      regularizer_decorator: A string, the name of the regularizer decorator
        to use. Supported decorators are listed in
        op_regularizer_decorator.SUPPORTED_DECORATORS.
      decorator_parameters: A dictionary of parameters to pass to the decorator
        factory. Use only with decorators that require parameters; otherwise
        use None.
      input_boundary: A list of ops that represent the input boundary of the
        subgraph being regularized (input boundary is not regularized).
      force_group: A list of regexes for ops that should be force-grouped.
        Each regex corresponds to a separate group. Use the '|' operator to
        specify multiple patterns in a single regex. See op_regularizer_manager
        for more detail.
      regularizer_blacklist: A list of regexes for ops that should not be
        regularized. See op_regularizer_manager for more detail.
    """
        source_op_handler = batch_norm_source_op_handler.BatchNormSourceOpHandler(
            gamma_threshold)
        if regularizer_decorator:
            source_op_handler = op_handler_decorator.OpHandlerDecorator(
                source_op_handler, regularizer_decorator, decorator_parameters)
        op_handler_dict = op_handlers.get_gamma_op_handler_dict()
        op_handler_dict.update({
            'FusedBatchNorm': source_op_handler,
            'FusedBatchNormV2': source_op_handler,
        })

        self._manager = orm.OpRegularizerManager(
            ops,
            op_handler_dict,
            input_boundary=input_boundary,
            force_group=force_group,
            regularizer_blacklist=regularizer_blacklist)
        self._calculator = cost_calculator.CostCalculator(
            self._manager, resource_function.model_size_function)
Example #6
    def testAssignGrouping_NeighborsHaveSameGroup_ReprocessSources(self):
        source_conv_op_group = orm.OpGroup(self.conv_op_slice)

        self.op_slice_dict = {
            self.batch_norm_op: [self.batch_norm_op_slice],
            self.conv_op: [self.conv_op_slice],
            self.relu_op: [self.relu_op_slice],
            self.gamma_op: [self.gamma_op_slice],
            self.beta_op: [self.beta_op_slice],
            self.mean_op: [self.mean_op_slice],
            self.std_op: [self.std_op_slice],
        }

        self.op_group_dict = {
            self.batch_norm_op_slice: self.batch_norm_op_group,
            self.conv_op_slice: source_conv_op_group,
            self.relu_op_slice: self.batch_norm_op_group,
            self.gamma_op_slice: self.batch_norm_op_group,
            self.beta_op_slice: self.batch_norm_op_group,
        }

        source_ops = (self.conv_op, )

        def is_source_op(op):
            return op in source_ops

        self.mock_op_reg_manager.is_source_op.side_effect = is_source_op

        # Call handler to assign grouping.
        handler = batch_norm_source_op_handler.BatchNormSourceOpHandler(
            _GAMMA_THRESHOLD)
        handler.assign_grouping(self.batch_norm_op, self.mock_op_reg_manager)

        # Verify manager looks up op slice for ops of interest.
        self.mock_op_reg_manager.get_op_slices.assert_any_call(
            self.batch_norm_op)
        self.mock_op_reg_manager.get_op_slices.assert_any_call(self.conv_op)
        self.mock_op_reg_manager.get_op_slices.assert_any_call(self.relu_op)

        # Verify manager groups batch norm with inputs and overrides source.
        self.mock_op_reg_manager.group_op_slices.assert_called_once_with(
            [
                self.batch_norm_op_slice, self.conv_op_slice,
                self.gamma_op_slice, self.beta_op_slice
            ],
            omit_source_op_slices=[self.conv_op_slice])

        # Verify manager adds ungrouped output ops to queue.
        self.mock_op_reg_manager.process_ops.assert_called_once_with(
            [self.mean_op, self.std_op])
        self.mock_op_reg_manager.process_ops_last.assert_not_called()
Example #7
    def __init__(self,
                 output_boundary: List[tf.Operation],
                 gamma_threshold,
                 regularizer_decorator: Type[
                     generic_regularizers.OpRegularizer] = None,
                 decorator_parameters=None,
                 input_boundary: List[tf.Operation] = None,
                 force_group=None,
                 regularizer_blacklist=None):
        """Creates a GammaActivationRegularizer object.

    Args:
      output_boundary: An OpRegularizer will be created for all these
        operations, and recursively for all ops they depend on via data
        dependency that does not involve ops from input_boundary.
      gamma_threshold: A float scalar that will be used as the
        'gamma_threshold' for all GammaL1Regularizer instances created by this
        class.
      regularizer_decorator: A class of OpRegularizer decorator to use.
      decorator_parameters: A dictionary of parameters to pass to the decorator
        factory. Use only with decorators that require parameters; otherwise
        use None.
      input_boundary: A list of ops that represent the input boundary of the
        subgraph being regularized (input boundary is not regularized).
      force_group: A list of regexes for ops that should be force-grouped.
        Each regex corresponds to a separate group. Use the '|' operator to
        specify multiple patterns in a single regex. See op_regularizer_manager
        for more detail.
      regularizer_blacklist: A list of regexes for ops that should not be
        regularized. See op_regularizer_manager for more detail.
    """
        source_op_handler = batch_norm_source_op_handler.BatchNormSourceOpHandler(
            gamma_threshold)
        if regularizer_decorator:
            source_op_handler = op_handler_decorator.OpHandlerDecorator(
                source_op_handler, regularizer_decorator, decorator_parameters)
        op_handler_dict = op_handlers.get_gamma_op_handler_dict()
        op_handler_dict.update({
            'FusedBatchNorm': source_op_handler,
            'FusedBatchNormV2': source_op_handler,
            'FusedBatchNormV3': source_op_handler,
        })

        self._manager = orm.OpRegularizerManager(
            output_boundary,
            op_handler_dict,
            input_boundary=input_boundary,
            force_group=force_group,
            regularizer_blacklist=regularizer_blacklist)
        self._calculator = cost_calculator.CostCalculator(
            self._manager, resource_function.activation_count_function)
Example #8
    def testAssignGrouping_ProcessNeighborGroups(self):
        self.op_slice_dict = {
            self.batch_norm_op: [self.batch_norm_op_slice],
            self.conv_op: [self.conv_op_slice],
            self.relu_op: [self.relu_op_slice],
            self.gamma_op: [self.gamma_op_slice],
            self.beta_op: [self.beta_op_slice],
            self.mean_op: [self.mean_op_slice],
            self.std_op: [self.std_op_slice],
        }

        # All ops have groups.
        self.op_group_dict = {
            self.batch_norm_op_slice: self.batch_norm_op_group,
            self.conv_op_slice: self.conv_op_group,
            self.relu_op_slice: self.relu_op_group,
            self.gamma_op_slice: self.gamma_op_group,
            self.beta_op_slice: self.beta_op_group,
            self.mean_op_slice: self.mean_op_group,
            self.std_op_slice: self.std_op_group,
        }

        # Call handler to assign grouping.
        handler = batch_norm_source_op_handler.BatchNormSourceOpHandler(
            _GAMMA_THRESHOLD)
        handler.assign_grouping(self.batch_norm_op, self.mock_op_reg_manager)

        # Verify manager looks up op slice for ops of interest.
        self.mock_op_reg_manager.get_op_slices.assert_any_call(
            self.batch_norm_op)
        self.mock_op_reg_manager.get_op_slices.assert_any_call(self.conv_op)
        self.mock_op_reg_manager.get_op_slices.assert_any_call(self.relu_op)

        # Verify manager groups batch norm with outputs and inputs.
        self.mock_op_reg_manager.group_op_slices.assert_has_calls([
            mock.call([
                self.batch_norm_op_slice, self.relu_op_slice,
                self.mean_op_slice, self.std_op_slice
            ]),
            mock.call([
                self.batch_norm_op_slice, self.conv_op_slice,
                self.gamma_op_slice, self.beta_op_slice
            ],
                      omit_source_op_slices=[])
        ])

        # Verify manager does not reprocess any ops.
        self.mock_op_reg_manager.process_ops.assert_not_called()
        self.mock_op_reg_manager.process_ops_last.assert_not_called()
Example #9
    def testCreateRegularizer_Sliced(self):
        # Call handler to create regularizer.
        handler = batch_norm_source_op_handler.BatchNormSourceOpHandler(
            _GAMMA_THRESHOLD)
        batch_norm_op_slice = orm.OpSlice(self.batch_norm_op, orm.Slice(0, 3))
        regularizer = handler.create_regularizer(batch_norm_op_slice)

        # Verify regularizer is the gamma tensor.
        with self.cached_session():
            # Initialize the gamma tensor to check value equality.
            with tf.variable_scope('', reuse=tf.AUTO_REUSE):
                gamma_tensor = tf.get_variable('conv1/BatchNorm/gamma')
            init = tf.variables_initializer([gamma_tensor])
            init.run()

            # Verify regularizer is the sliced gamma tensor.
            self.assertAllEqual(gamma_tensor.eval()[0:3],
                                regularizer._gamma.eval())
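
_GAMMA_THRESHOLD feeds the GammaL1Regularizer that the handler creates. Below is a small numeric sketch of the assumed semantics (a channel is alive when |gamma| exceeds the threshold, and the L1 norm of gamma drives the regularization loss); the variable names are illustrative.

import numpy as np

_GAMMA_THRESHOLD = 0.5
gamma = np.array([0.0, 1.0, 1.0, 0.4])

alive_vector = np.abs(gamma) > _GAMMA_THRESHOLD  # [False, True, True, False]
num_alive = int(alive_vector.sum())              # 2 channels survive pruning
l1_term = float(np.abs(gamma).sum())             # 2.4 enters the reg. loss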
Example #10
    def testAssignGrouping_NoNeighborGroups(self):
        self.op_slice_dict = {
            self.batch_norm_op: [self.batch_norm_op_slice],
            self.conv_op: [self.conv_op_slice],
            self.relu_op: [self.relu_op_slice],
            self.gamma_op: [self.gamma_op_slice],
            self.beta_op: [self.beta_op_slice],
            self.mean_op: [self.mean_op_slice],
            self.std_op: [self.std_op_slice],
        }

        # No neighbor ops have groups.
        self.op_group_dict = {
            self.batch_norm_op_slice: self.batch_norm_op_group,
        }

        # Call handler to assign grouping.
        handler = batch_norm_source_op_handler.BatchNormSourceOpHandler(
            _GAMMA_THRESHOLD)
        handler.assign_grouping(self.batch_norm_op, self.mock_op_reg_manager)

        # Verify manager looks up op slice for ops of interest.
        self.mock_op_reg_manager.get_op_slices.assert_any_call(
            self.batch_norm_op)
        self.mock_op_reg_manager.get_op_slices.assert_any_call(self.conv_op)
        self.mock_op_reg_manager.get_op_slices.assert_any_call(self.relu_op)

        # Verify manager creates an OpGroup for the batch norm op.
        self.mock_op_reg_manager.create_op_group_for_op_slice.assert_called_once_with(
            self.batch_norm_op_slice)

        # Verify manager groups batch norm with outputs and inputs.
        self.mock_op_reg_manager.group_op_slices.assert_not_called()

        # Verify manager processes grouping for input and output ops.
        self.mock_op_reg_manager.process_ops.assert_called_once_with([
            self.relu_op, self.mean_op, self.std_op, self.conv_op,
            self.gamma_op, self.beta_op
        ])
        self.mock_op_reg_manager.process_ops_last.assert_called_once_with(
            [self.batch_norm_op])
Example #11
    def __init__(self,
                 ops,
                 gamma_threshold,
                 regularizer_decorator: Type[
                     generic_regularizers.OpRegularizer] = None,
                 decorator_parameters=None,
                 force_group=None,
                 regularizer_blacklist=None):
        """Creates a GammaFlopsRegularizer object.

    Args:
      ops: A list of tf.Operation. An OpRegularizer will be created for all the
        ops in `ops`, and recursively for all ops they depend on via data
        dependency. Typically `ops` would contain a single tf.Operation, which
        is the output of the network.
      gamma_threshold: A float scalar that will be used as the
        'gamma_threshold' for all GammaL1Regularizer instances created by this
        class.
      regularizer_decorator: A class of OpRegularizer decorator to use.
      decorator_parameters: A dictionary of parameters to pass to the decorator
        factory. Use only with decorators that require parameters; otherwise
        use None.
      force_group: A list of regexes for ops that should be force-grouped.
        Each regex corresponds to a separate group. Use the '|' operator to
        specify multiple patterns in a single regex. See op_regularizer_manager
        for more detail.
      regularizer_blacklist: A list of regexes for ops that should not be
        regularized. See op_regularizer_manager for more detail.
    """
        source_op_handler = batch_norm_source_op_handler.BatchNormSourceOpHandler(
            gamma_threshold)
        if regularizer_decorator:
            source_op_handler = op_handler_decorator.OpHandlerDecorator(
                source_op_handler, regularizer_decorator, decorator_parameters)
        op_handler_dict = collections.defaultdict(
            grouping_op_handler.GroupingOpHandler)
        op_handler_dict.update({
            'FusedBatchNorm':
            source_op_handler,
            'FusedBatchNormV2':
            source_op_handler,
            'Conv2D':
            output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(),
            'ConcatV2':
            concat_op_handler.ConcatOpHandler(),
            'DepthToSpace':
            output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(),
            'DepthwiseConv2dNative':
            depthwise_convolution_op_handler.DepthwiseConvolutionOpHandler(),
            'MatMul':
            output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(),
            'TensorArrayGatherV3':
            leaf_op_handler.LeafOpHandler(),
            'RandomUniform':
            leaf_op_handler.LeafOpHandler(),
            'Reshape':
            leaf_op_handler.LeafOpHandler(),
            'Transpose':
            output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(),
            'ExpandDims':
            output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(),
        })

        self._manager = orm.OpRegularizerManager(
            ops,
            op_handler_dict,
            force_group=force_group,
            regularizer_blacklist=regularizer_blacklist)
        self._calculator = cost_calculator.CostCalculator(
            self._manager, resource_function.flop_function)
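
The op_handler_dict above leans on collections.defaultdict calling its factory on first access, so any op type without an explicit entry gets a fresh GroupingOpHandler. The same pattern in isolation, with stand-in classes rather than morph_net's handlers:

import collections

class DefaultHandler:
    """Stand-in for grouping_op_handler.GroupingOpHandler."""

class SpecialHandler:
    """Stand-in for an explicitly registered handler."""

handlers = collections.defaultdict(DefaultHandler)
handlers.update({'Conv2D': SpecialHandler()})

assert isinstance(handlers['Conv2D'], SpecialHandler)     # explicit entry wins
assert isinstance(handlers['SomeNewOp'], DefaultHandler)  # created on demand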
Example #12
    def __init__(self,
                 ops,
                 gamma_threshold,
                 hardware,
                 batch_size=1,
                 regularizer_decorator: Type[
                     generic_regularizers.OpRegularizer] = None,
                 decorator_parameters=None,
                 force_group=None,
                 regularizer_blacklist=None) -> None:
        """Creates a GammaLatencyRegularizer object.

    Latency cost and regularization loss are calculated for a specified
    hardware platform.

    Args:
      ops: A list of tf.Operation. An OpRegularizer will be created for all
        the ops in `ops`, and recursively for all ops they depend on via data
        dependency. Typically `ops` would contain a single tf.Operation, which
        is the output of the network.
      gamma_threshold: A float scalar that will be used as the
        'gamma_threshold' for all GammaL1Regularizer instances created by this
        class.
      hardware: String name of hardware platform to target.  Must be a key from
        resource_function.PEAK_COMPUTE.
      batch_size: Integer batch size to calculate cost/loss for.
      regularizer_decorator: A string, the name of the regularizer decorator
        to use. Supported decorators are listed in
        op_regularizer_decorator.SUPPORTED_DECORATORS.
      decorator_parameters: A dictionary of parameters to pass to the decorator
        factory. Use only with decorators that require parameters; otherwise
        use None.
      force_group: A list of regexes for ops that should be force-grouped.
        Each regex corresponds to a separate group. Use the '|' operator to
        specify multiple patterns in a single regex. See op_regularizer_manager
        for more detail.
      regularizer_blacklist: A list of regexes for ops that should not be
        regularized. See op_regularizer_manager for more detail.
    """
        source_op_handler = batch_norm_source_op_handler.BatchNormSourceOpHandler(
            gamma_threshold)
        if regularizer_decorator:
            source_op_handler = op_handler_decorator.OpHandlerDecorator(
                source_op_handler, regularizer_decorator, decorator_parameters)
        op_handler_dict = collections.defaultdict(
            grouping_op_handler.GroupingOpHandler)
        op_handler_dict.update({
            'FusedBatchNorm':
            source_op_handler,
            'FusedBatchNormV2':
            source_op_handler,
            'Conv2D':
            output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(),
            'ConcatV2':
            concat_op_handler.ConcatOpHandler(),
            'DepthToSpace':
            output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(),
            'DepthwiseConv2dNative':
            depthwise_convolution_op_handler.DepthwiseConvolutionOpHandler(),
            'MatMul':
            output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(),
            'TensorArrayGatherV3':
            leaf_op_handler.LeafOpHandler(),
            'Transpose':
            output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(),
        })

        self._manager = orm.OpRegularizerManager(
            ops,
            op_handler_dict,
            force_group=force_group,
            regularizer_blacklist=regularizer_blacklist)
        self._calculator = cost_calculator.CostCalculator(
            self._manager,
            resource_function.latency_function_factory(hardware, batch_size))
        self._hardware = hardware
Example #13
    def testAssignGrouping_ProcessNeighborGroupsWithSlices(self):
        batch_norm_op_slice_0_2 = orm.OpSlice(self.batch_norm_op,
                                              orm.Slice(0, 2))
        batch_norm_op_slice_2_3 = orm.OpSlice(self.batch_norm_op,
                                              orm.Slice(2, 1))
        batch_norm_op_group1 = orm.OpGroup(batch_norm_op_slice_0_2)
        batch_norm_op_group2 = orm.OpGroup(batch_norm_op_slice_2_3)

        conv_op_slice_0_2 = orm.OpSlice(self.conv_op, orm.Slice(0, 2))
        conv_op_slice_2_3 = orm.OpSlice(self.conv_op, orm.Slice(2, 1))
        conv_op_group1 = orm.OpGroup(conv_op_slice_0_2,
                                     omit_source_op_slices=[conv_op_slice_0_2])
        conv_op_group2 = orm.OpGroup(conv_op_slice_2_3,
                                     omit_source_op_slices=[conv_op_slice_2_3])

        relu_op_slice_0_2 = orm.OpSlice(self.relu_op, orm.Slice(0, 2))
        relu_op_slice_2_3 = orm.OpSlice(self.relu_op, orm.Slice(2, 1))
        relu_op_group1 = orm.OpGroup(relu_op_slice_0_2)
        relu_op_group2 = orm.OpGroup(relu_op_slice_2_3)

        gamma_op_slice_0_2 = orm.OpSlice(self.gamma_op, orm.Slice(0, 2))
        gamma_op_slice_2_3 = orm.OpSlice(self.gamma_op, orm.Slice(2, 1))
        gamma_op_group1 = orm.OpGroup(
            gamma_op_slice_0_2, omit_source_op_slices=[gamma_op_slice_0_2])
        gamma_op_group2 = orm.OpGroup(
            gamma_op_slice_2_3, omit_source_op_slices=[gamma_op_slice_2_3])

        beta_op_slice_0_2 = orm.OpSlice(self.beta_op, orm.Slice(0, 2))
        beta_op_slice_2_3 = orm.OpSlice(self.beta_op, orm.Slice(2, 1))
        beta_op_group1 = orm.OpGroup(beta_op_slice_0_2,
                                     omit_source_op_slices=[beta_op_slice_0_2])
        beta_op_group2 = orm.OpGroup(beta_op_slice_2_3,
                                     omit_source_op_slices=[beta_op_slice_2_3])

        mean_op_slice_0_2 = orm.OpSlice(self.mean_op, orm.Slice(0, 2))
        mean_op_slice_2_3 = orm.OpSlice(self.mean_op, orm.Slice(2, 1))
        mean_op_group1 = orm.OpGroup(mean_op_slice_0_2,
                                     omit_source_op_slices=[mean_op_slice_0_2])
        mean_op_group2 = orm.OpGroup(mean_op_slice_2_3,
                                     omit_source_op_slices=[mean_op_slice_2_3])

        std_op_slice_0_2 = orm.OpSlice(self.std_op, orm.Slice(0, 2))
        std_op_slice_2_3 = orm.OpSlice(self.std_op, orm.Slice(2, 1))
        std_op_group1 = orm.OpGroup(std_op_slice_0_2,
                                    omit_source_op_slices=[std_op_slice_0_2])
        std_op_group2 = orm.OpGroup(std_op_slice_2_3,
                                    omit_source_op_slices=[std_op_slice_2_3])

        self.op_slice_dict = {
            self.batch_norm_op:
            [batch_norm_op_slice_0_2, batch_norm_op_slice_2_3],
            self.conv_op: [conv_op_slice_0_2, conv_op_slice_2_3],
            self.relu_op: [relu_op_slice_0_2, relu_op_slice_2_3],
            self.gamma_op: [gamma_op_slice_0_2, gamma_op_slice_2_3],
            self.beta_op: [beta_op_slice_0_2, beta_op_slice_2_3],
            self.mean_op: [mean_op_slice_0_2, mean_op_slice_2_3],
            self.std_op: [std_op_slice_0_2, std_op_slice_2_3],
        }

        # All OpSlice have groups.
        self.op_group_dict = {
            batch_norm_op_slice_0_2: batch_norm_op_group1,
            batch_norm_op_slice_2_3: batch_norm_op_group2,
            conv_op_slice_0_2: conv_op_group1,
            conv_op_slice_2_3: conv_op_group2,
            relu_op_slice_0_2: relu_op_group1,
            relu_op_slice_2_3: relu_op_group2,
            gamma_op_slice_0_2: gamma_op_group1,
            gamma_op_slice_2_3: gamma_op_group2,
            beta_op_slice_0_2: beta_op_group1,
            beta_op_slice_2_3: beta_op_group2,
            mean_op_slice_0_2: mean_op_group1,
            mean_op_slice_2_3: mean_op_group2,
            std_op_slice_0_2: std_op_group1,
            std_op_slice_2_3: std_op_group2,
        }

        # Call handler to assign grouping.
        handler = batch_norm_source_op_handler.BatchNormSourceOpHandler(
            _GAMMA_THRESHOLD)
        handler.assign_grouping(self.batch_norm_op, self.mock_op_reg_manager)

        # Verify manager looks up op slice for ops of interest.
        self.mock_op_reg_manager.get_op_slices.assert_any_call(
            self.batch_norm_op)
        self.mock_op_reg_manager.get_op_slices.assert_any_call(self.conv_op)
        self.mock_op_reg_manager.get_op_slices.assert_any_call(self.relu_op)

        # Verify manager groups batch norm with outputs and inputs by slice.
        self.mock_op_reg_manager.group_op_slices.assert_has_calls([
            mock.call([
                batch_norm_op_slice_0_2, relu_op_slice_0_2, mean_op_slice_0_2,
                std_op_slice_0_2
            ]),
            mock.call([
                batch_norm_op_slice_0_2, conv_op_slice_0_2, gamma_op_slice_0_2,
                beta_op_slice_0_2
            ],
                      omit_source_op_slices=[]),
            mock.call([
                batch_norm_op_slice_2_3, relu_op_slice_2_3, mean_op_slice_2_3,
                std_op_slice_2_3
            ]),
            mock.call([
                batch_norm_op_slice_2_3, conv_op_slice_2_3, gamma_op_slice_2_3,
                beta_op_slice_2_3
            ],
                      omit_source_op_slices=[])
        ])

        # Verify manager does not reprocess any ops.
        self.mock_op_reg_manager.process_ops.assert_not_called()
        self.mock_op_reg_manager.process_ops_last.assert_not_called()
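
The slice bookkeeping above follows the assumed orm.Slice convention of (start_index, size): a 3-channel batch norm splits into Slice(0, 2) and Slice(2, 1), covering channels [0, 2) and [2, 3). A tiny self-contained sketch of that convention, in plain Python independent of morph_net:

from collections import namedtuple

# Assumed convention mirroring op_regularizer_manager's Slice.
Slice = namedtuple('Slice', ['start_index', 'size'])

def channels(s):
    # Channel indices covered by a slice.
    return list(range(s.start_index, s.start_index + s.size))

assert channels(Slice(0, 2)) == [0, 1]
assert channels(Slice(2, 1)) == [2]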