Exemplo n.º 1
0
  def testIsBroadcast(self):
    """Checks `_is_broadcast` for the three size/grouping combinations.

    The test expects False for an op whose slice size is not 1, False for a
    size-1 op that has no OpGroup, and True for a size-1 op that is grouped.
    """
    handler = grouping_op_handler.GroupingOpHandler()
    # Start with no groups; only the last case below registers one.
    self.op_group_dict = {}

    # Size is not 1.
    self.assertFalse(handler._is_broadcast(self.batch_norm_op,
                                           self.mock_op_reg_manager))

    # Size is 1 but op is not grouped.
    ungrouped_broadcast_input = tf.zeros([2, 4, 4, 1])
    # OpSlice must wrap the tf.Operation (not its output tensor) so it matches
    # the op key used in op_slice_dict and the grouped case below.
    ungrouped_broadcast_input_slice = orm.OpSlice(
        ungrouped_broadcast_input.op, orm.Slice(0, 1))
    self.op_slice_dict[ungrouped_broadcast_input.op] = [
        ungrouped_broadcast_input_slice]
    self.assertFalse(handler._is_broadcast(ungrouped_broadcast_input.op,
                                           self.mock_op_reg_manager))

    # Size is 1 and op is grouped.
    broadcast_input = tf.zeros([2, 4, 4, 1])
    broadcast_input_slice = orm.OpSlice(broadcast_input.op, orm.Slice(0, 1))
    self.op_slice_dict[broadcast_input.op] = [broadcast_input_slice]
    broadcast_input_group = orm.OpGroup(broadcast_input_slice)
    self.op_group_dict[broadcast_input_slice] = broadcast_input_group
    self.assertTrue(handler._is_broadcast(broadcast_input.op,
                                          self.mock_op_reg_manager))
Exemplo n.º 2
0
    def setUp(self):
        super(LogisticSigmoidSourceOpHandlerTest, self).setUp()
        tf.reset_default_graph()

        # Network under test: Conv2D -> ReLU -> LogisticSigmoidGating.
        net_input = tf.zeros([2, 4, 4, 3])
        conv_out = layers.conv2d(
            net_input, num_outputs=5, kernel_size=3, scope='conv1')
        activation_gating.logistic_sigmoid_gating(
            conv_out, axis=3, is_training=True)

        op_by_name = tf.get_default_graph().get_operation_by_name

        def whole(op):
            # Fixture slice over channels [0, 5), matching conv1's 5 outputs.
            return orm.OpSlice(op, orm.Slice(0, 5))

        # Fixtures are created in chain order: conv, relu, gating, logits,
        # multiply.
        self.conv_op = op_by_name('conv1/Conv2D')
        self.conv_op_slice = whole(self.conv_op)
        self.conv_op_group = orm.OpGroup(
            self.conv_op_slice, omit_source_op_slices=[self.conv_op_slice])

        self.relu_op = op_by_name('conv1/Relu')
        self.relu_op_slice = whole(self.relu_op)
        self.relu_op_group = orm.OpGroup(self.relu_op_slice)

        self.activation_gating_op = op_by_name(
            'logistic_sigmoid_gating/LogisticSigmoidGating')
        self.activation_gating_op_slice = whole(self.activation_gating_op)
        self.activation_gating_op_group = orm.OpGroup(
            self.activation_gating_op_slice)

        self.mask_logits_op = op_by_name(
            'logistic_sigmoid_gating/mask_logits/Read/ReadVariableOp')
        self.mask_logits_op_slice = whole(self.mask_logits_op)
        self.mask_logits_op_group = orm.OpGroup(
            self.mask_logits_op_slice,
            omit_source_op_slices=[self.mask_logits_op_slice])

        self.multiply_op = op_by_name('logistic_sigmoid_gating/Identity')
        self.multiply_op_slice = whole(self.multiply_op)
        self.multiply_op_group = orm.OpGroup(
            self.multiply_op_slice,
            omit_source_op_slices=[self.multiply_op_slice])

        # Mocked manager; lookups read self.op_slice_dict / self.op_group_dict
        # lazily so each test can populate them after setUp.
        manager = mock.create_autospec(orm.OpRegularizerManager)
        manager.get_op_slices.side_effect = (
            lambda op: self.op_slice_dict.get(op, []))
        manager.get_op_group.side_effect = (
            lambda op_slice: self.op_group_dict.get(op_slice))
        manager.is_source_op.return_value = False
        manager.ops = [
            self.conv_op, self.relu_op, self.activation_gating_op,
            self.multiply_op
        ]
        self.mock_op_reg_manager = manager
Exemplo n.º 3
0
    def testGetOpsWithoutGroups(self):
        """Only ops whose slices have no OpGroup should be returned, in order."""
        ungrouped_ops = [
            self.gamma_op, self.beta_op, self.decay_op, self.epsilon_op
        ]

        # Batch norm and conv use their fixture slices; every other op gets a
        # slice with no Slice range attached.
        self.op_slice_dict = {
            self.batch_norm_op: [self.batch_norm_op_slice],
            self.conv_op: [self.conv_op_slice],
        }
        for op in ungrouped_ops:
            self.op_slice_dict[op] = [orm.OpSlice(op, None)]

        # Groups exist only for the batch norm and conv slices.
        self.op_group_dict = {
            self.batch_norm_op_slice: self.batch_norm_op_group,
            self.conv_op_slice: self.conv_op_group,
        }

        all_ops = [self.batch_norm_op, self.conv_op] + ungrouped_ops
        found = op_handler_util.get_ops_without_groups(
            all_ops, self.mock_op_reg_manager)
        self.assertEqual(ungrouped_ops, found)
Exemplo n.º 4
0
    def testGetOpSlices(self):
        """An output op split into several slices is returned as that list.

        The unfused batch norm (18 channels) is split into ranges [0, 5),
        [5, 11) and [11, 18); get_op_slices should return one inner list per
        output op containing all of those slices.
        """
        bn_slices = [
            orm.OpSlice(self.unfused_batch_norm_op, orm.Slice(start, size))
            for start, size in ((0, 5), (5, 6), (11, 7))
        ]
        self.op_slice_dict = {self.unfused_batch_norm_op: list(bn_slices)}

        # concat has a single output op (the batch norm), hence one inner list.
        output_ops = op_handler_util.get_output_ops(
            self.concat_op, self.mock_op_reg_manager)
        actual = op_handler_util.get_op_slices(
            output_ops, self.mock_op_reg_manager)
        self.assertEqual([bn_slices], actual)
    def setUp(self):
        super(GroupingOpHandlerTest, self).setUp()
        tf.reset_default_graph()

        # Network under test: Conv2D -> BatchNorm -> ReLU, 5 channels wide.
        with framework.arg_scope(self._batch_norm_scope()):
            inputs = tf.zeros([2, 4, 4, 3])
            layers.conv2d(inputs, num_outputs=5, kernel_size=3, scope='conv1')

        op_by_name = tf.get_default_graph().get_operation_by_name

        def whole(op):
            # Fixture slice over all 5 channels of `op`.
            return orm.OpSlice(op, orm.Slice(0, 5))

        self.batch_norm_op = op_by_name('conv1/BatchNorm/FusedBatchNormV3')
        self.conv_op = op_by_name('conv1/Conv2D')
        self.relu_op = op_by_name('conv1/Relu')
        self.gamma_op = op_by_name('conv1/BatchNorm/gamma/read')
        self.beta_op = op_by_name('conv1/BatchNorm/beta/read')

        self.batch_norm_op_slice = whole(self.batch_norm_op)
        self.conv_op_slice = whole(self.conv_op)
        self.relu_op_slice = whole(self.relu_op)
        self.gamma_op_slice = whole(self.gamma_op)
        self.beta_op_slice = whole(self.beta_op)

        # Groups are built in order: batch norm, conv, relu.
        self.batch_norm_op_group = orm.OpGroup(self.batch_norm_op_slice)
        self.conv_op_group = orm.OpGroup(
            self.conv_op_slice, omit_source_op_slices=[self.conv_op_slice])
        self.relu_op_group = orm.OpGroup(self.relu_op_slice)

        # Mocked manager backed by per-test dicts. Note get_op_slices yields
        # None (not []) for ops missing from op_slice_dict.
        self.mock_op_reg_manager = mock.create_autospec(
            orm.OpRegularizerManager)

        fixture_ops = [
            self.batch_norm_op, self.conv_op, self.relu_op, self.gamma_op,
            self.beta_op
        ]
        fixture_slices = [
            self.batch_norm_op_slice, self.conv_op_slice, self.relu_op_slice,
            self.gamma_op_slice, self.beta_op_slice
        ]
        self.op_slice_dict = {
            op: [op_slice] for op, op_slice in zip(fixture_ops, fixture_slices)
        }

        manager = self.mock_op_reg_manager
        manager.get_op_slices.side_effect = (
            lambda op: self.op_slice_dict.get(op))
        manager.get_op_group.side_effect = (
            lambda op_slice: self.op_group_dict.get(op_slice))
        manager.is_source_op.return_value = False
        manager.ops = fixture_ops
Exemplo n.º 6
0
  def testMatMul2D(self, size):
    """MatMul with kernel W and with W^T plus transpose_b=True regularize alike.

    Both products are (13, 3); the regularization vectors built for the first
    `size` output channels of each are expected to match.
    """
    inputs = tf.zeros((13, 2))
    handler = matmul_source_op_handler.MatMulSourceOpHandler(0.1)

    plain = tf.matmul(
        inputs,
        tf.constant([[1, 2, 3], [4, 5, 6]], dtype=tf.float32),
        transpose_b=False,
        name='MatMul')
    # Same weights stored pre-transposed, undone by transpose_b=True.
    transposed = tf.matmul(
        inputs,
        tf.constant([[1, 4], [2, 5], [3, 6]], dtype=tf.float32),
        transpose_b=True,
        name='MatMulTransposedKernel')

    vectors = [
        handler.create_regularizer(
            orm.OpSlice(product.op, orm.Slice(0, size))).regularization_vector
        for product in (plain, transposed)
    ]
    self.assertAllClose(*vectors)
Exemplo n.º 7
0
    def testResliceConcatOps_NotAligned(self):
        """Concat inputs whose slices don't match the aligned sizes are resliced."""
        # relu3 is currently split as [3, 3].
        relu3_slices = [
            orm.OpSlice(self.relu3_op, orm.Slice(0, 3)),
            orm.OpSlice(self.relu3_op, orm.Slice(3, 3)),
        ]
        self.op_slice_dict = {
            self.relu2_op: [self.relu2_op_slice],
            self.relu3_op: relu3_slices,
            self.relu4_op: [self.relu4_op_slice],
        }

        # Aligned sizes partition the concat as relu2 | relu3 | relu4:
        # [5] | [4, 2] | [2, 5].
        aligned_op_slice_sizes = [5, 4, 2, 2, 5]
        concat_inputs = [self.relu2_op, self.relu3_op, self.relu4_op]
        op_handler_util.reslice_concat_ops(
            concat_inputs, aligned_op_slice_sizes, self.mock_op_reg_manager)

        # relu3 and relu4 are expected to be resliced to their aligned sizes.
        expected_calls = [
            mock.call(self.relu3_op, [4, 2]),
            mock.call(self.relu4_op, [2, 5]),
        ]
        self.mock_op_reg_manager.slice_op.assert_has_calls(expected_calls)
    def testCreateRegularizer_Sliced(self):
        """A regularizer for channels [0, 3) of conv2 has a length-3 vector."""
        handler = conv2d_source_op_handler.Conv2DSourceOpHandler(
            _DEFAULT_THRESHOLD)
        conv2_op_slice = orm.OpSlice(self.conv2_op, orm.Slice(0, 3))
        vector_shape = handler.create_regularizer(
            conv2_op_slice).regularization_vector.shape.as_list()

        # Only the shape is checked here; the regularizer's values are covered
        # in group_lasso_regularizer_test.py.
        self.assertEqual(3, vector_shape[0])
Exemplo n.º 9
0
    def setUp(self):
        tf.reset_default_graph()

        # Network under test: Identity -> DepthToSpace(block 2) -> Identity.
        # With block size 2, the 4 input channels become 1 output channel.
        first_identity = tf.identity(tf.zeros([2, 4, 4, 4]))
        tf.identity(tf.depth_to_space(first_identity, 2))

        op_by_name = tf.get_default_graph().get_operation_by_name
        self.id1_op = op_by_name('Identity')
        self.dts_op = op_by_name('DepthToSpace')
        self.id2_op = op_by_name('Identity_1')

        # id1: one slice over all 4 channels plus one slice per channel.
        self.id1_op_slice = orm.OpSlice(self.id1_op, orm.Slice(0, 4))
        self.id1_op_group = orm.OpGroup(
            self.id1_op_slice, omit_source_op_slices=[self.id1_op_slice])
        (self.id1_op_slice0, self.id1_op_slice1, self.id1_op_slice2,
         self.id1_op_slice3) = [
             orm.OpSlice(self.id1_op, orm.Slice(channel, 1))
             for channel in range(4)
         ]

        self.dts_op_slice = orm.OpSlice(self.dts_op, orm.Slice(0, 1))
        self.dts_op_group = orm.OpGroup(
            self.dts_op_slice, omit_source_op_slices=[self.dts_op_slice])

        self.id2_op_slice = orm.OpSlice(self.id2_op, orm.Slice(0, 1))
        self.id2_op_group = orm.OpGroup(
            self.id2_op_slice, omit_source_op_slices=[self.id2_op_slice])

        # Mocked manager backed by per-test dicts; get_op_slices yields None
        # for ops missing from op_slice_dict.
        manager = mock.create_autospec(orm.OpRegularizerManager)
        self.mock_op_reg_manager = manager
        self.op_slice_dict = {
            self.id1_op: [self.id1_op_slice],
            self.dts_op: [self.dts_op_slice],
            self.id2_op: [self.id2_op_slice],
        }
        manager.get_op_slices.side_effect = (
            lambda op: self.op_slice_dict.get(op))
        manager.get_op_group.side_effect = (
            lambda op_slice: self.op_group_dict.get(op_slice))
        manager.is_source_op.return_value = False
        manager.ops = [self.id1_op, self.dts_op, self.id2_op]
Exemplo n.º 10
0
    def testCreateRegularizer_Sliced(self):
        """A regularizer on channels [0, 3) wraps the first 3 gamma values."""
        handler = batch_norm_source_op_handler.BatchNormSourceOpHandler(
            _GAMMA_THRESHOLD)
        regularizer = handler.create_regularizer(
            orm.OpSlice(self.batch_norm_op, orm.Slice(0, 3)))

        with self.cached_session():
            # Re-fetch the existing gamma variable and initialize it so its
            # value can be compared with the regularizer's gamma.
            with tf.variable_scope('', reuse=tf.AUTO_REUSE):
                gamma = tf.get_variable('conv1/BatchNorm/gamma')
            tf.variables_initializer([gamma]).run()

            expected = gamma.eval()[0:3]
            self.assertAllEqual(expected, regularizer._gamma.eval())
    def testOpHandlerDecorator(self):
        """Decorated conv handler yields the DummyDecorator's vectors.

        For a 3-channel slice the test expects a regularization vector of 0.5
        everywhere and an all-ones alive vector.
        """
        image = tf.constant(0.0, shape=[1, 17, 19, 3])
        kernel = tf.ones([5, 5, 3, 3])
        conv_out = tf.nn.conv2d(
            image, kernel, strides=[1, 1, 1, 1], padding='SAME')

        base_handler = conv_source_op_handler.ConvSourceOpHandler(1e-3, 0)
        decorated = op_handler_decorator.OpHandlerDecorator(
            base_handler, DummyDecorator)
        regularizer = decorated.create_regularizer(
            orm.OpSlice(conv_out.op, orm.Slice(0, 3)))

        ones = np.ones(3)
        self.assertAllClose(0.5 * ones, regularizer.regularization_vector)
        self.assertAllClose(ones, regularizer.alive_vector)
Exemplo n.º 12
0
    def testCreateRegularizer_OnMask_Sliced(self):
        """With regularize_on_mask=True the vector is the sliced gating mask."""
        handler = ls_source_op_handler.LogisticSigmoidSourceOpHandler(
            regularize_on_mask=True)
        regularizer = handler.create_regularizer(
            orm.OpSlice(self.activation_gating_op, orm.Slice(0, 3)))

        mask_tensor = tf.get_default_graph().get_tensor_by_name(
            'logistic_sigmoid_gating/LogisticSigmoidGating:0')
        with self.cached_session() as session:
            session.run(tf.global_variables_initializer())
            fetched = session.run(
                [mask_tensor, regularizer._regularization_vector])
            mask_values, reg_vector = fetched

            # The first 3 mask entries must equal the regularization vector.
            self.assertAllEqual(mask_values[0:3], reg_vector)
Exemplo n.º 13
0
    def testCreateRegularizer_OnLogits_Sliced(self):
        """With regularize_on_mask=False the vector is sigmoid(sliced logits).

        The expected value is computed with scipy's expit, while the
        regularization vector is evaluated from the TF graph. The two are
        compared with a tolerance rather than bit-for-bit.
        """
        # Call handler to create regularizer.
        handler = ls_source_op_handler.LogisticSigmoidSourceOpHandler(
            regularize_on_mask=False)
        activation_gating_op_slice = orm.OpSlice(self.activation_gating_op,
                                                 orm.Slice(0, 3))
        regularizer = handler.create_regularizer(activation_gating_op_slice)

        g = tf.get_default_graph()
        logits_tensor = g.get_tensor_by_name(
            'logistic_sigmoid_gating/mask_logits/Read/ReadVariableOp:0')
        with self.cached_session() as sess:
            sess.run(tf.global_variables_initializer())
            logits, reg_vec = sess.run(
                [logits_tensor, regularizer._regularization_vector])

            # Verify regularizer is the sigmoid of the sliced logits. TF's
            # sigmoid and scipy's expit are separate float implementations
            # that may differ in the last ulp, so compare with a tolerance.
            self.assertAllClose(sp.special.expit(logits[0:3]), reg_vec)
Exemplo n.º 14
0
    def testGetOpSliceSizes(self):
        """Checks get_op_slice_sizes on single- and multi-slice ops.

        The input is a list with one entry per op, each a list of OpSlice. The
        expected output mirrors that nesting, with every OpSlice replaced by
        its size.
        """
        # relu3 (6 channels) split into two slices of size 3. These must be
        # built on relu3_op, not relu2_op, to match their names and dict key.
        relu3_op_slice_0_3 = orm.OpSlice(self.relu3_op, orm.Slice(0, 3))
        relu3_op_slice_3_6 = orm.OpSlice(self.relu3_op, orm.Slice(3, 3))

        batch_norm_op_slice_0_5 = orm.OpSlice(self.unfused_batch_norm_op,
                                              orm.Slice(0, 5))
        batch_norm_op_slice_5_8 = orm.OpSlice(self.unfused_batch_norm_op,
                                              orm.Slice(5, 3))
        batch_norm_op_slice_8_11 = orm.OpSlice(self.unfused_batch_norm_op,
                                               orm.Slice(8, 3))
        batch_norm_op_slice_11_18 = orm.OpSlice(self.unfused_batch_norm_op,
                                                orm.Slice(11, 7))

        # Map ops to slices.
        self.op_slice_dict = {
            self.relu2_op: [self.relu2_op_slice],
            self.relu3_op: [relu3_op_slice_0_3, relu3_op_slice_3_6],
            self.relu4_op: [self.relu4_op_slice],
            self.unfused_batch_norm_op: [
                batch_norm_op_slice_0_5, batch_norm_op_slice_5_8,
                batch_norm_op_slice_8_11, batch_norm_op_slice_11_18
            ],
        }

        expected_op_slice_sizes = [
            [5],  # c2 has size 5.
            [3, 3],  # c3 has size 6, but in 2 slices of size 3.
            [7],  # c4 has size 7.
            # Batch norm has size 18, sliced to match c2, c3 (2 slices), c4.
            [5, 3, 3, 7],
        ]

        self.assertEqual(
            expected_op_slice_sizes,
            op_handler_util.get_op_slice_sizes(
                [[self.relu2_op_slice],
                 [relu3_op_slice_0_3, relu3_op_slice_3_6],
                 [self.relu4_op_slice],
                 [
                     batch_norm_op_slice_0_5, batch_norm_op_slice_5_8,
                     batch_norm_op_slice_8_11, batch_norm_op_slice_11_18
                 ]]))
Exemplo n.º 15
0
    def setUp(self):
        tf.reset_default_graph()

        # Network under test: Conv2D -> BatchNorm -> ReLU, 5 channels wide.
        with tf.contrib.framework.arg_scope(self._batch_norm_scope()):
            inputs = tf.zeros([2, 4, 4, 3])
            layers.conv2d(inputs, num_outputs=5, kernel_size=3, scope='conv1')

        op_by_name = tf.get_default_graph().get_operation_by_name

        def whole(op):
            # Fixture slice over all 5 channels of `op`.
            return orm.OpSlice(op, orm.Slice(0, 5))

        def self_omitting_group(op_slice):
            # Group whose single slice is excluded from its source slices.
            return orm.OpGroup(op_slice, omit_source_op_slices=[op_slice])

        # Fixtures in order: batch norm, conv, relu, gamma, beta, mean, std.
        self.batch_norm_op = op_by_name('conv1/BatchNorm/FusedBatchNorm')
        self.batch_norm_op_slice = whole(self.batch_norm_op)
        self.batch_norm_op_group = orm.OpGroup(self.batch_norm_op_slice)

        self.conv_op = op_by_name('conv1/Conv2D')
        self.conv_op_slice = whole(self.conv_op)
        self.conv_op_group = self_omitting_group(self.conv_op_slice)

        self.relu_op = op_by_name('conv1/Relu')
        self.relu_op_slice = whole(self.relu_op)
        self.relu_op_group = orm.OpGroup(self.relu_op_slice)

        self.gamma_op = op_by_name('conv1/BatchNorm/gamma/read')
        self.gamma_op_slice = whole(self.gamma_op)
        self.gamma_op_group = self_omitting_group(self.gamma_op_slice)

        self.beta_op = op_by_name('conv1/BatchNorm/beta/read')
        self.beta_op_slice = whole(self.beta_op)
        self.beta_op_group = self_omitting_group(self.beta_op_slice)

        self.mean_op = op_by_name('conv1/BatchNorm/AssignMovingAvg/sub_1')
        self.mean_op_slice = whole(self.mean_op)
        self.mean_op_group = self_omitting_group(self.mean_op_slice)

        self.std_op = op_by_name('conv1/BatchNorm/AssignMovingAvg_1/sub_1')
        self.std_op_slice = whole(self.std_op)
        self.std_op_group = self_omitting_group(self.std_op_slice)

        # Mocked manager; lookups read per-test dicts lazily.
        manager = mock.create_autospec(orm.OpRegularizerManager)
        manager.get_op_slices.side_effect = (
            lambda op: self.op_slice_dict.get(op, []))
        manager.get_op_group.side_effect = (
            lambda op_slice: self.op_group_dict.get(op_slice))
        manager.is_source_op.return_value = False
        manager.ops = [
            self.batch_norm_op, self.conv_op, self.relu_op, self.gamma_op,
            self.beta_op, self.mean_op, self.std_op
        ]
        self.mock_op_reg_manager = manager
Exemplo n.º 16
0
    def setUp(self):
        super(OpHandlerUtilTest, self).setUp()
        tf.reset_default_graph()

        # Network 1: Conv2D -> BatchNorm -> ReLU under the batch norm arg scope.
        with framework.arg_scope(self._batch_norm_scope()):
            inputs = tf.zeros([2, 4, 4, 3])
            layers.conv2d(inputs, num_outputs=5, kernel_size=3, scope='conv1')

        # Network 2: three Conv2D ops (5, 6 and 7 channels) concatenated on
        # axis 3 before a batch normalization.
        branches = [
            layers.conv2d(inputs, num_outputs=width, kernel_size=3,
                          scope=scope)
            for width, scope in ((5, 'conv2'), (6, 'conv3'), (7, 'conv4'))
        ]
        layers.batch_norm(tf.concat(branches, axis=3))

        op_by_name = tf.get_default_graph().get_operation_by_name

        def self_omitting_group(op_slice):
            # Group whose single slice is excluded from its source slices.
            return orm.OpGroup(op_slice, omit_source_op_slices=[op_slice])

        # Network 1 fixtures; these OpSlices carry no Slice range (None).
        self.batch_norm_op = op_by_name('conv1/BatchNorm/FusedBatchNormV3')
        self.batch_norm_op_slice = orm.OpSlice(self.batch_norm_op, None)
        self.batch_norm_op_group = orm.OpGroup(self.batch_norm_op_slice)

        self.conv_op = op_by_name('conv1/Conv2D')
        self.conv_op_slice = orm.OpSlice(self.conv_op, None)
        self.conv_op_group = self_omitting_group(self.conv_op_slice)

        # Remaining batch norm ops, looked up without slices or groups.
        self.gamma_op = op_by_name('conv1/BatchNorm/gamma/read')
        self.beta_op = op_by_name('conv1/BatchNorm/beta/read')
        self.decay_op = op_by_name('conv1/BatchNorm/Const')
        self.epsilon_op = op_by_name('conv1/BatchNorm/Const_1')
        self.mean_op = op_by_name('conv1/BatchNorm/AssignMovingAvg/sub_1')
        self.std_op = op_by_name('conv1/BatchNorm/AssignMovingAvg_1/sub_1')

        self.relu_op = op_by_name('conv1/Relu')
        self.relu_op_slice = orm.OpSlice(self.relu_op, None)
        self.relu_op_group = self_omitting_group(self.relu_op_slice)

        # Network 2 fixtures; each ReLU slice spans its conv's full width.
        self.relu2_op = op_by_name('conv2/Relu')
        self.relu2_op_slice = orm.OpSlice(self.relu2_op, orm.Slice(0, 5))
        self.relu2_op_group = self_omitting_group(self.relu2_op_slice)

        self.relu3_op = op_by_name('conv3/Relu')
        self.relu3_op_slice = orm.OpSlice(self.relu3_op, orm.Slice(0, 6))
        self.relu3_op_group = self_omitting_group(self.relu3_op_slice)

        self.relu4_op = op_by_name('conv4/Relu')
        self.relu4_op_slice = orm.OpSlice(self.relu4_op, orm.Slice(0, 7))
        self.relu4_op_group = self_omitting_group(self.relu4_op_slice)

        # The concat and its batch norm are 18 channels wide (5 + 6 + 7).
        self.unfused_batch_norm_op = op_by_name('BatchNorm/FusedBatchNormV3')
        self.unfused_batch_norm_op_slice = orm.OpSlice(
            self.unfused_batch_norm_op, orm.Slice(0, 18))

        self.concat_op = op_by_name('concat')
        self.concat_op_slice = orm.OpSlice(self.concat_op, orm.Slice(0, 18))
        self.concat_op_group = self_omitting_group(self.concat_op_slice)

        # Mocked manager driven by attributes each test sets:
        # self.op_slice_dict, self.op_group_dict and self._passthrough_ops.
        manager = mock.create_autospec(orm.OpRegularizerManager)
        manager.get_op_slices.side_effect = (
            lambda op: self.op_slice_dict.get(op, []))
        manager.get_op_group.side_effect = (
            lambda op_slice: self.op_group_dict.get(op_slice))
        manager.is_passthrough.side_effect = (
            lambda op: op in self._passthrough_ops)
        manager.ops = [
            self.batch_norm_op, self.gamma_op, self.beta_op, self.decay_op,
            self.epsilon_op, self.mean_op, self.std_op, self.conv_op,
            self.relu_op, self.relu2_op, self.relu3_op, self.relu4_op,
            self.unfused_batch_norm_op, self.concat_op
        ]
        self.mock_op_reg_manager = manager
Exemplo n.º 17
0
    def testGroupOpWithInputsAndOutputs_MultipleSlices(self):
        """Each batch norm slice is grouped with its ReLU and Conv2D slices.

        Batch norm, conv and relu are each split into channel ranges [0, 2)
        and [2, 5). The manager is expected to group every batch norm slice
        with the matching output (ReLU) slice and input (Conv2D) slice.
        """
        # The second OpSlice argument is an orm.Slice(start_index, size);
        # nesting another OpSlice here would be the wrong type.
        batch_norm_op_slice_0_2 = orm.OpSlice(self.batch_norm_op,
                                              orm.Slice(0, 2))
        batch_norm_op_slice_2_5 = orm.OpSlice(self.batch_norm_op,
                                              orm.Slice(2, 3))
        batch_norm_op_group1 = orm.OpGroup(batch_norm_op_slice_0_2)
        batch_norm_op_group2 = orm.OpGroup(batch_norm_op_slice_2_5)

        conv_op_slice_0_2 = orm.OpSlice(self.conv_op, orm.Slice(0, 2))
        conv_op_slice_2_5 = orm.OpSlice(self.conv_op, orm.Slice(2, 3))
        conv_op_group1 = orm.OpGroup(conv_op_slice_0_2,
                                     omit_source_op_slices=[conv_op_slice_0_2])
        conv_op_group2 = orm.OpGroup(conv_op_slice_2_5,
                                     omit_source_op_slices=[conv_op_slice_2_5])

        relu_op_slice_0_2 = orm.OpSlice(self.relu_op, orm.Slice(0, 2))
        relu_op_slice_2_5 = orm.OpSlice(self.relu_op, orm.Slice(2, 3))
        relu_op_group1 = orm.OpGroup(relu_op_slice_0_2)
        relu_op_group2 = orm.OpGroup(relu_op_slice_2_5)

        aligned_op_slice_sizes = [2, 3]

        self.op_slice_dict = {
            self.batch_norm_op:
            [batch_norm_op_slice_0_2, batch_norm_op_slice_2_5],
            self.conv_op: [conv_op_slice_0_2, conv_op_slice_2_5],
            self.relu_op: [relu_op_slice_0_2, relu_op_slice_2_5],
        }

        # All ops have groups.
        self.op_group_dict = {
            batch_norm_op_slice_0_2: batch_norm_op_group1,
            batch_norm_op_slice_2_5: batch_norm_op_group2,
            conv_op_slice_0_2: conv_op_group1,
            conv_op_slice_2_5: conv_op_group2,
            relu_op_slice_0_2: relu_op_group1,
            relu_op_slice_2_5: relu_op_group2,
        }

        ops_grouped = op_handler_util.group_op_with_inputs_and_outputs(
            self.batch_norm_op, [[conv_op_slice_0_2, conv_op_slice_2_5]],
            [[relu_op_slice_0_2, relu_op_slice_2_5]], aligned_op_slice_sizes,
            self.mock_op_reg_manager)

        # Verify manager looks up op slice for ops of interest.
        self.mock_op_reg_manager.get_op_slices.assert_any_call(
            self.batch_norm_op)

        # Verify manager groups batch norm with Conv2D and ReLU ops.
        self.assertTrue(ops_grouped)
        self.mock_op_reg_manager.group_op_slices.assert_has_calls([
            mock.call([batch_norm_op_slice_0_2, relu_op_slice_0_2]),
            mock.call([batch_norm_op_slice_0_2, conv_op_slice_0_2]),
            mock.call([batch_norm_op_slice_2_5, relu_op_slice_2_5]),
            mock.call([batch_norm_op_slice_2_5, conv_op_slice_2_5])
        ])
Exemplo n.º 18
0
  def setUp(self):
    tf.reset_default_graph()

    # Two stacked Conv2D layers under the batch norm arg scope; conv1 opts out
    # via normalizer_fn=None, so only conv2 gets a batch norm.
    with tf.contrib.framework.arg_scope(self._batch_norm_scope()):
      inputs = tf.zeros([2, 4, 4, 3])
      conv1_out = layers.conv2d(inputs, num_outputs=5, kernel_size=3,
                                scope='conv1', normalizer_fn=None)
      layers.conv2d(conv1_out, num_outputs=6, kernel_size=3, scope='conv2')

    op_by_name = tf.get_default_graph().get_operation_by_name

    def whole(op, width):
      # OpSlice covering channels [0, width) of `op`.
      return orm.OpSlice(op, orm.Slice(0, width))

    def self_omitting_group(op_slice):
      # Group whose single slice is excluded from its source slices.
      return orm.OpGroup(op_slice, omit_source_op_slices=[op_slice])

    self.conv1_op = op_by_name('conv1/Conv2D')
    self.conv1_op_slice = whole(self.conv1_op, 5)
    self.conv1_op_group = self_omitting_group(self.conv1_op_slice)

    self.relu1_op = op_by_name('conv1/Relu')
    self.relu1_op_slice = whole(self.relu1_op, 5)
    self.relu1_op_group = self_omitting_group(self.relu1_op_slice)

    self.conv2_op = op_by_name('conv2/Conv2D')
    self.conv2_op_slice = whole(self.conv2_op, 6)
    self.conv2_op_group = self_omitting_group(self.conv2_op_slice)

    self.relu2_op = op_by_name('conv2/Relu')
    self.relu2_op_slice = whole(self.relu2_op, 6)
    self.relu2_op_group = self_omitting_group(self.relu2_op_slice)

    self.batch_norm_op = op_by_name('conv2/BatchNorm/FusedBatchNorm')
    self.batch_norm_op_slice = whole(self.batch_norm_op, 6)
    self.batch_norm_op_group = orm.OpGroup(self.batch_norm_op_slice)

    def is_passthrough(op):
      # Conv ops defer to OutputNonPassthroughOpHandler; of the remaining ops
      # only the batch norm counts as passthrough.
      if op in [self.conv1_op, self.conv2_op]:
        handler = (
            output_non_passthrough_op_handler.OutputNonPassthroughOpHandler())
        return handler.is_passthrough
      return bool(op == self.batch_norm_op)

    # Mocked manager; slice/group lookups read per-test dicts lazily.
    manager = mock.create_autospec(orm.OpRegularizerManager)
    manager.get_op_slices.side_effect = (
        lambda op: self.op_slice_dict.get(op, []))
    manager.get_op_group.side_effect = (
        lambda op_slice: self.op_group_dict.get(op_slice))
    manager.is_source_op.return_value = False
    manager.is_passthrough.side_effect = is_passthrough
    manager.ops = [
        self.conv1_op, self.relu1_op, self.conv2_op, self.relu2_op,
        self.batch_norm_op]
    self.mock_op_reg_manager = manager
Exemplo n.º 19
0
    def testAssignGrouping_AllNeighborsGrouped_InputSlicesNotAligned(self):
        """Concat handler re-slices its outputs to match multi-slice inputs.

        relu2 is pre-split into two 3-wide slices, while the batch norm that
        consumes the concat starts as a single 18-wide slice. The test expects
        the batch norm and the concat to be sliced as [5, 3, 3, 7] and each
        aligned (concat, input, batch norm) triple to be grouped.
        """
        # In this test, the op c2 has size 6 but is split into 2 slices of size 3.
        # The concat op (and its output, the batch norm) both have size 18.  This
        # test verifies that the concat and batch norm are sliced according to the
        # sizes of c1, c2, and c3, and takes into account that c2 is also composed
        # of multiple slices.
        # Concat sub-slices that the mocked slice_op installs below.
        concat_op_slice_0_5 = orm.OpSlice(self.concat_op, orm.Slice(0, 5))
        concat_op_slice_5_8 = orm.OpSlice(self.concat_op, orm.Slice(5, 3))
        concat_op_slice_8_11 = orm.OpSlice(self.concat_op, orm.Slice(8, 3))
        concat_op_slice_11_18 = orm.OpSlice(self.concat_op, orm.Slice(11, 7))

        # relu2 is already split into two size-3 slices, each with its own group.
        relu2_op_slice_0_3 = orm.OpSlice(self.relu2_op, orm.Slice(0, 3))
        relu2_op_slice_3_6 = orm.OpSlice(self.relu2_op, orm.Slice(3, 3))
        relu2_op_group1 = orm.OpGroup(
            relu2_op_slice_0_3, omit_source_op_slices=[relu2_op_slice_0_3])
        relu2_op_group2 = orm.OpGroup(
            relu2_op_slice_3_6, omit_source_op_slices=[relu2_op_slice_3_6])

        # Batch norm: the initial single 18-wide slice, followed by the four
        # sub-slices that the mocked slice_op swaps in.
        batch_norm_op_slice = orm.OpSlice(self.batch_norm_op, orm.Slice(0, 18))
        batch_norm_op_group = orm.OpGroup(
            batch_norm_op_slice, omit_source_op_slices=[batch_norm_op_slice])
        batch_norm_op_slice_0_5 = orm.OpSlice(self.batch_norm_op,
                                              orm.Slice(0, 5))
        batch_norm_op_group1 = orm.OpGroup(
            batch_norm_op_slice_0_5,
            omit_source_op_slices=[batch_norm_op_slice_0_5])
        batch_norm_op_slice_5_8 = orm.OpSlice(self.batch_norm_op,
                                              orm.Slice(5, 3))
        batch_norm_op_group2 = orm.OpGroup(
            batch_norm_op_slice_5_8,
            omit_source_op_slices=[batch_norm_op_slice_5_8])
        batch_norm_op_slice_8_11 = orm.OpSlice(self.batch_norm_op,
                                               orm.Slice(8, 3))
        batch_norm_op_group3 = orm.OpGroup(
            batch_norm_op_slice_8_11,
            omit_source_op_slices=[batch_norm_op_slice_8_11])
        batch_norm_op_slice_11_18 = orm.OpSlice(self.batch_norm_op,
                                                orm.Slice(11, 7))
        batch_norm_op_group4 = orm.OpGroup(
            batch_norm_op_slice_11_18,
            omit_source_op_slices=[batch_norm_op_slice_11_18])

        # Map ops to slices.  The op c2 is composed of multiple slices.
        self.op_slice_dict = {
            self.relu1_op: [self.relu1_op_slice],
            self.relu2_op: [relu2_op_slice_0_3, relu2_op_slice_3_6],
            self.relu3_op: [self.relu3_op_slice],
            self.concat_op: [self.concat_op_slice],
            self.batch_norm_op: [batch_norm_op_slice],
        }

        # Map each slice to a group.  The concat slices are intentionally left
        # ungrouped, since grouping them is the handler's job.
        self.op_group_dict = {
            self.relu1_op_slice: self.relu1_op_group,
            relu2_op_slice_0_3: relu2_op_group1,
            relu2_op_slice_3_6: relu2_op_group2,
            self.relu3_op_slice: self.relu3_op_group,
            batch_norm_op_slice: batch_norm_op_group,
            batch_norm_op_slice_0_5: batch_norm_op_group1,
            batch_norm_op_slice_5_8: batch_norm_op_group2,
            batch_norm_op_slice_8_11: batch_norm_op_group3,
            batch_norm_op_slice_11_18: batch_norm_op_group4,
        }

        # Update op_slice_dict when an op is sliced.  The requested sizes (the
        # second argument) are ignored; the pre-built sub-slices are installed.
        def slice_op(op, _):
            if op == self.batch_norm_op:
                self.op_slice_dict[self.batch_norm_op] = [
                    batch_norm_op_slice_0_5, batch_norm_op_slice_5_8,
                    batch_norm_op_slice_8_11, batch_norm_op_slice_11_18
                ]
            if op == self.concat_op:
                self.op_slice_dict[self.concat_op] = [
                    concat_op_slice_0_5, concat_op_slice_5_8,
                    concat_op_slice_8_11, concat_op_slice_11_18
                ]

        self.mock_op_reg_manager.slice_op.side_effect = slice_op

        # Call handler to assign grouping.
        handler = concat_op_handler.ConcatOpHandler()
        handler.assign_grouping(self.concat_op, self.mock_op_reg_manager)

        # Verify manager looks up OpSlice for ops of interest.  The phase
        # comments mirror the order in which the handler is expected to query.
        self.mock_op_reg_manager.get_op_slices.assert_has_calls(
            # Checking for ops to process.
            [
                mock.call(self.relu1_op),
                mock.call(self.relu2_op),
                mock.call(self.relu3_op),
                mock.call(self.batch_norm_op),
                # Initial slice data.
                mock.call(self.concat_op),
                mock.call(self.relu1_op),
                mock.call(self.relu2_op),
                mock.call(self.relu3_op),
                mock.call(self.batch_norm_op),
                # Reslicing.
                mock.call(self.relu1_op),
                mock.call(self.relu2_op),
                mock.call(self.relu3_op),
                mock.call(self.batch_norm_op),
                mock.call(self.concat_op),
                # Refreshing slice data.
                mock.call(self.relu1_op),
                mock.call(self.relu2_op),
                mock.call(self.relu3_op),
                mock.call(self.batch_norm_op),
                # Group concat op.
                mock.call(self.concat_op)
            ])

        # Verify manager slices ops that do not have aligned OpSlice sizes.
        self.mock_op_reg_manager.slice_op.assert_has_calls([
            mock.call(self.batch_norm_op, [5, 3, 3, 7]),
            mock.call(self.concat_op, [5, 3, 3, 7])
        ])

        # Verify manager groups the new slices: each concat sub-slice joins
        # the input slice it came from and the matching batch norm sub-slice.
        self.mock_op_reg_manager.group_op_slices.assert_has_calls([
            mock.call([
                concat_op_slice_0_5, self.relu1_op_slice,
                batch_norm_op_slice_0_5
            ]),
            mock.call([
                concat_op_slice_5_8, relu2_op_slice_0_3,
                batch_norm_op_slice_5_8
            ]),
            mock.call([
                concat_op_slice_8_11, relu2_op_slice_3_6,
                batch_norm_op_slice_8_11
            ]),
            mock.call([
                concat_op_slice_11_18, self.relu3_op_slice,
                batch_norm_op_slice_11_18
            ])
        ])
    def setUp(self):
        """Builds a conv1 -> conv2 graph and a mocked OpRegularizerManager.

        Both layers are created with ``normalizer_fn=None``. For every op of
        interest this stores the op, one OpSlice covering its whole output,
        and an OpGroup over that slice that omits it as a source. Each test
        must set ``self.op_slice_dict`` and ``self.op_group_dict`` before the
        mock is used; the mock's lookups read them at call time.
        """
        tf.reset_default_graph()

        # Two chained Conv2D layers, neither followed by a normalizer.
        image = tf.zeros([2, 4, 4, 3])
        first_conv = layers.conv2d(image,
                                   num_outputs=5,
                                   kernel_size=3,
                                   scope='conv1',
                                   normalizer_fn=None)
        layers.conv2d(first_conv,
                      num_outputs=6,
                      kernel_size=3,
                      scope='conv2',
                      normalizer_fn=None)

        graph = tf.get_default_graph()

        def declare(op_name, size):
            # Returns (op, whole-output OpSlice, OpGroup omitting that slice).
            op = graph.get_operation_by_name(op_name)
            whole = orm.OpSlice(op, orm.Slice(0, size))
            return op, whole, orm.OpGroup(whole, omit_source_op_slices=[whole])

        (self.conv1_op, self.conv1_op_slice,
         self.conv1_op_group) = declare('conv1/Conv2D', 5)
        (self.relu1_op, self.relu1_op_slice,
         self.relu1_op_group) = declare('conv1/Relu', 5)
        (self.conv2_op, self.conv2_op_slice,
         self.conv2_op_group) = declare('conv2/Conv2D', 6)
        (self.conv2_weights_op, self.conv2_weights_op_slice,
         self.conv2_weights_op_group) = declare('conv2/weights/read', 6)
        (self.relu2_op, self.relu2_op_slice,
         self.relu2_op_group) = declare('conv2/Relu', 6)

        # Mocked manager whose lookups are backed by the per-test dicts.
        manager = mock.create_autospec(orm.OpRegularizerManager)
        self.mock_op_reg_manager = manager

        def is_passthrough(op):
            # Only the Conv2D ops defer to the real source-op handler.
            if op not in [self.conv1_op, self.conv2_op]:
                return False
            return conv2d_source_op_handler.Conv2DSourceOpHandler(
                _DEFAULT_THRESHOLD).is_passthrough

        manager.get_op_slices.side_effect = (
            lambda op: self.op_slice_dict.get(op, []))
        manager.get_op_group.side_effect = (
            lambda op_slice: self.op_group_dict.get(op_slice))
        manager.is_source_op.return_value = False
        manager.is_passthrough.side_effect = is_passthrough
        manager.ops = [
            self.conv1_op, self.relu1_op, self.conv2_op, self.relu2_op,
            self.conv2_weights_op
        ]
Exemplo n.º 21
0
    def testAssignGrouping_NoDepthMultiplier(self):
        """Depthwise conv with depth_multiplier=1 is grouped with its input.

        With a multiplier of 1 the depthwise op keeps relu1's 5 channels, so
        the test expects a single group_op_slices call joining the two slices
        and the ungrouped output conv2 to be handed to process_ops.
        """
        # Repeat setUp, but with depth_multiplier=1.  Unfortunately, this involves
        # rebuilding the graph from scratch.
        tf.reset_default_graph()

        # This tests a Conv2D -> SeparableConv2D -> Conv2D chain of ops.
        with tf.contrib.framework.arg_scope(self._batch_norm_scope()):
            inputs = tf.zeros([2, 4, 4, 3])
            c1 = layers.conv2d(inputs,
                               num_outputs=5,
                               kernel_size=3,
                               scope='conv1')
            c2 = layers.separable_conv2d(c1,
                                         num_outputs=8,
                                         kernel_size=3,
                                         depth_multiplier=1,
                                         scope='conv2')
            layers.conv2d(c2, num_outputs=6, kernel_size=3, scope='conv3')

        g = tf.get_default_graph()

        # Declare OpSlice and OpGroup for ops of interest.  The depthwise op is
        # 5 wide (5 input channels x depth_multiplier 1).
        self.dwise_conv2_op = g.get_operation_by_name(
            'conv2/separable_conv2d/depthwise')
        self.dwise_conv2_op_slice = orm.OpSlice(self.dwise_conv2_op,
                                                orm.Slice(0, 5))

        self.conv2_op = g.get_operation_by_name('conv2/separable_conv2d')
        self.conv2_op_slice = orm.OpSlice(self.conv2_op, orm.Slice(0, 8))

        self.relu1_op = g.get_operation_by_name('conv1/Relu')
        self.relu1_op_slice = orm.OpSlice(self.relu1_op, orm.Slice(0, 5))
        self.relu1_op_group = orm.OpGroup(self.relu1_op_slice)

        self.conv3_op = g.get_operation_by_name('conv3/Conv2D')
        self.conv3_op_slice = orm.OpSlice(self.conv3_op, orm.Slice(0, 6))

        # Create mock OpRegularizerManager with custom mapping of OpSlice and
        # OpGroup.
        self.mock_op_reg_manager = mock.create_autospec(
            orm.OpRegularizerManager)

        self.op_slice_dict = {
            self.dwise_conv2_op: [self.dwise_conv2_op_slice],
            self.conv2_op: [self.conv2_op_slice],
            self.relu1_op: [self.relu1_op_slice],
            self.conv3_op: [self.conv3_op_slice],
        }

        # NOTE: unlike the other setUps, unknown ops map to None here, not [].
        def get_op_slices(op):
            return self.op_slice_dict.get(op)

        def get_op_group(op_slice):
            return self.op_group_dict.get(op_slice)

        self.mock_op_reg_manager.get_op_slices.side_effect = get_op_slices
        self.mock_op_reg_manager.get_op_group.side_effect = get_op_group
        self.mock_op_reg_manager.is_source_op.return_value = False
        self.mock_op_reg_manager.ops = [
            self.relu1_op, self.dwise_conv2_op, self.conv2_op, self.conv3_op
        ]

        # Only the input op (relu1) has a group; the output op conv2 is left
        # ungrouped and is expected to be handed to process_ops below.
        self.op_group_dict = {
            self.relu1_op_slice: self.relu1_op_group,
        }

        # Call handler to assign grouping.
        handler = depthwise_convolution_op_handler.DepthwiseConvolutionOpHandler(
        )
        handler.assign_grouping(self.dwise_conv2_op, self.mock_op_reg_manager)

        # Verify manager looks up OpSlice for ops of interest.
        self.mock_op_reg_manager.get_op_slices.assert_has_calls(
            # Checking for ops to process.
            [
                mock.call(self.relu1_op),
                mock.call(self.conv2_op),
                # Initial slice data.
                mock.call(self.dwise_conv2_op),
                mock.call(self.relu1_op),
                # Reslicing.
                mock.call(self.relu1_op),
                mock.call(self.dwise_conv2_op),
                # Refreshing slice data.
                mock.call(self.relu1_op),
                # Group depthwise convolution.
                mock.call(self.dwise_conv2_op)
            ])

        # Verify manager groups the depthwise convolution with its input relu1.
        self.mock_op_reg_manager.group_op_slices.assert_called_once_with(
            [self.dwise_conv2_op_slice, self.relu1_op_slice])

        # Verify manager does not process any additional ops.
        self.mock_op_reg_manager.process_ops.assert_called_once_with(
            [self.conv2_op])
        self.mock_op_reg_manager.process_ops_last.assert_not_called()
  def setUp(self):
    """Builds three convs concatenated on axis 2 plus a mocked manager.

    conv1..conv3 (6 outputs each, under ``self._get_scope()``) are joined with
    ``tf.concat(..., axis=2)`` and fed to a batch norm. Each op of interest
    gets one 6-wide OpSlice and an OpGroup omitting that slice as a source.
    ``self.concat_group`` merges the groups of all five ops.
    """
    tf.reset_default_graph()

    # Three Conv2D ops concatenated along axis 2, then batch-normalized.
    image = tf.zeros([2, 4, 4, 3])
    with tf.contrib.framework.arg_scope(self._get_scope()):
      branches = [
          layers.conv2d(image, num_outputs=6, kernel_size=3, scope=scope)
          for scope in ('conv1', 'conv2', 'conv3')
      ]
      layers.batch_norm(tf.concat(branches, axis=2))

    graph = tf.get_default_graph()

    def declare(op_name):
      # Returns (op, 6-wide OpSlice, OpGroup omitting that slice as a source).
      op = graph.get_operation_by_name(op_name)
      whole = orm.OpSlice(op, orm.Slice(0, 6))
      return op, whole, orm.OpGroup(whole, omit_source_op_slices=[whole])

    (self.concat_op, self.concat_op_slice,
     self.concat_op_group) = declare('concat')
    (self.relu1_op, self.relu1_op_slice,
     self.relu1_op_group) = declare('conv1/Relu')
    (self.relu2_op, self.relu2_op_slice,
     self.relu2_op_group) = declare('conv2/Relu')
    (self.relu3_op, self.relu3_op_slice,
     self.relu3_op_group) = declare('conv3/Relu')
    (self.batch_norm_op, self.batch_norm_op_slice,
     self.batch_norm_op_group) = declare('BatchNorm/FusedBatchNorm')

    # A single group spanning the batch norm, the concat, and all three ReLUs.
    member_groups = [
        self.batch_norm_op_group, self.concat_op_group, self.relu1_op_group,
        self.relu2_op_group, self.relu3_op_group
    ]
    self.concat_group = orm.OpGroup(op_slice=None, op_groups=member_groups)

    # Mocked manager whose lookups are backed by the per-test dicts.
    manager = mock.create_autospec(orm.OpRegularizerManager)
    self.mock_op_reg_manager = manager
    manager.get_op_slices.side_effect = (
        lambda op: self.op_slice_dict.get(op, []))
    manager.get_op_group.side_effect = (
        lambda op_slice: self.op_group_dict.get(op_slice))
    manager.is_source_op.return_value = False
    manager.is_passthrough.return_value = True
    manager.ops = [
        self.concat_op, self.relu1_op, self.relu2_op, self.relu3_op,
        self.batch_norm_op]
Exemplo n.º 23
0
    def setUp(self):
        """Builds 5/6/7-wide convs concatenated on axis 3 plus a mocked manager.

        The concat feeds a batch norm, so both carry 18 channels. Besides the
        full 18-wide slices, this declares the 0-5 / 5-11 / 11-18 sub-slices
        (matching the conv widths) and one group per sub-slice. The mocked
        manager's ``slice_op`` installs those sub-slices into
        ``self.op_slice_dict`` for the concat and batch norm ops.
        """
        tf.reset_default_graph()

        # Three Conv2D ops concatenated along axis 3, then batch-normalized.
        image = tf.zeros([2, 4, 4, 3])
        convs = [
            layers.conv2d(image, num_outputs=width, kernel_size=3,
                          scope=scope)
            for width, scope in ((5, 'conv1'), (6, 'conv2'), (7, 'conv3'))
        ]
        layers.batch_norm(tf.concat(convs, axis=3))

        graph = tf.get_default_graph()

        def make_slice(op, start, size):
            return orm.OpSlice(op, orm.Slice(start, size))

        def make_group(op_slice):
            return orm.OpGroup(op_slice, omit_source_op_slices=[op_slice])

        # Concat: full slice, per-input sub-slices, and a group per sub-slice.
        self.concat_op = graph.get_operation_by_name('concat')
        self.concat_op_slice = make_slice(self.concat_op, 0, 18)
        self.concat_op_slice_0_5 = make_slice(self.concat_op, 0, 5)
        self.concat_op_slice_5_11 = make_slice(self.concat_op, 5, 6)
        self.concat_op_slice_11_18 = make_slice(self.concat_op, 11, 7)
        self.concat_op_group1 = make_group(self.concat_op_slice_0_5)
        self.concat_op_group2 = make_group(self.concat_op_slice_5_11)
        self.concat_op_group3 = make_group(self.concat_op_slice_11_18)

        # ReLU outputs of the three convs, each with one whole-output slice.
        self.relu1_op = graph.get_operation_by_name('conv1/Relu')
        self.relu1_op_slice = make_slice(self.relu1_op, 0, 5)
        self.relu1_op_group = make_group(self.relu1_op_slice)

        self.relu2_op = graph.get_operation_by_name('conv2/Relu')
        self.relu2_op_slice = make_slice(self.relu2_op, 0, 6)
        self.relu2_op_group = make_group(self.relu2_op_slice)

        self.relu3_op = graph.get_operation_by_name('conv3/Relu')
        self.relu3_op_slice = make_slice(self.relu3_op, 0, 7)
        self.relu3_op_group = make_group(self.relu3_op_slice)

        self.axis_op = graph.get_operation_by_name('concat/axis')

        # Batch norm: full slice and group, then the aligned sub-slices.
        self.batch_norm_op = graph.get_operation_by_name(
            'BatchNorm/FusedBatchNorm')
        self.batch_norm_op_slice = make_slice(self.batch_norm_op, 0, 18)
        self.batch_norm_op_group = make_group(self.batch_norm_op_slice)
        self.batch_norm_op_slice_0_5 = make_slice(self.batch_norm_op, 0, 5)
        self.batch_norm_op_slice_5_11 = make_slice(self.batch_norm_op, 5, 6)
        self.batch_norm_op_slice_11_18 = make_slice(self.batch_norm_op, 11, 7)
        self.batch_norm_op_group1 = make_group(self.batch_norm_op_slice_0_5)
        self.batch_norm_op_group2 = make_group(self.batch_norm_op_slice_5_11)
        self.batch_norm_op_group3 = make_group(self.batch_norm_op_slice_11_18)

        # Mocked manager whose lookups are backed by the per-test dicts.
        manager = mock.create_autospec(orm.OpRegularizerManager)
        self.mock_op_reg_manager = manager

        def slice_op(op, _):
            # Emulate slicing by installing the pre-built sub-slices; the
            # requested sizes are ignored.
            if op == self.batch_norm_op:
                self.op_slice_dict[self.batch_norm_op] = [
                    self.batch_norm_op_slice_0_5,
                    self.batch_norm_op_slice_5_11,
                    self.batch_norm_op_slice_11_18,
                ]
            if op == self.concat_op:
                self.op_slice_dict[self.concat_op] = [
                    self.concat_op_slice_0_5,
                    self.concat_op_slice_5_11,
                    self.concat_op_slice_11_18,
                ]

        manager.get_op_slices.side_effect = (
            lambda op: self.op_slice_dict.get(op, []))
        manager.get_op_group.side_effect = (
            lambda op_slice: self.op_group_dict.get(op_slice))
        manager.is_source_op.return_value = False
        manager.slice_op.side_effect = slice_op
        manager.is_passthrough.return_value = True
        manager.ops = [
            self.concat_op, self.relu1_op, self.relu2_op, self.relu3_op,
            self.batch_norm_op
        ]
Exemplo n.º 24
0
    def testAssignGrouping_AllNeighborsGrouped_OutputSlicesNotAligned(self):
        """Concat handler aligns everything when the output is split [9, 4, 5].

        The batch norm consuming the concat starts out sliced as [9, 4, 5],
        which does not line up with the concat inputs. The test expects relu2
        to be sliced [4, 2], relu3 to be sliced [2, 5], and the batch norm and
        concat to be sliced [5, 4, 2, 2, 5], after which each aligned
        (concat, input, batch norm) triple is grouped.
        """
        # The output (batch norm) has sizes [9, 4, 5] which are not aligned.  This
        # test verifies that the concat, batch norm, and input ops are sliced in
        # alignment.
        # Concat sub-slices that the mocked slice_op installs below.
        concat_op_slice_0_5 = orm.OpSlice(self.concat_op, orm.Slice(0, 5))
        concat_op_slice_5_9 = orm.OpSlice(self.concat_op, orm.Slice(5, 4))
        concat_op_slice_9_11 = orm.OpSlice(self.concat_op, orm.Slice(9, 2))
        concat_op_slice_11_13 = orm.OpSlice(self.concat_op, orm.Slice(11, 2))
        concat_op_slice_13_18 = orm.OpSlice(self.concat_op, orm.Slice(13, 5))

        # Sub-slices of the inputs that must be split to match the output.
        relu2_op_slice_0_4 = orm.OpSlice(self.relu2_op, orm.Slice(0, 4))
        relu2_op_slice_4_6 = orm.OpSlice(self.relu2_op, orm.Slice(4, 2))

        relu3_op_slice_0_2 = orm.OpSlice(self.relu3_op, orm.Slice(0, 2))
        relu3_op_slice_2_7 = orm.OpSlice(self.relu3_op, orm.Slice(2, 5))

        # Initial (unaligned) batch norm slices: [9, 4, 5].
        batch_norm_op_slice_0_9 = orm.OpSlice(self.batch_norm_op,
                                              orm.Slice(0, 9))
        batch_norm_op_group1 = orm.OpGroup(
            batch_norm_op_slice_0_9,
            omit_source_op_slices=[batch_norm_op_slice_0_9])
        batch_norm_op_slice_9_13 = orm.OpSlice(self.batch_norm_op,
                                               orm.Slice(9, 4))
        batch_norm_op_group2 = orm.OpGroup(
            batch_norm_op_slice_9_13,
            omit_source_op_slices=[batch_norm_op_slice_9_13])
        batch_norm_op_slice_13_18 = orm.OpSlice(self.batch_norm_op,
                                                orm.Slice(13, 5))
        batch_norm_op_group3 = orm.OpGroup(
            batch_norm_op_slice_13_18,
            omit_source_op_slices=[batch_norm_op_slice_13_18])

        # Aligned batch norm slices installed by slice_op: [5, 4, 2, 2, 5].
        batch_norm_op_slice_0_5 = orm.OpSlice(self.batch_norm_op,
                                              orm.Slice(0, 5))
        batch_norm_op_group4 = orm.OpGroup(
            batch_norm_op_slice_0_5,
            omit_source_op_slices=[batch_norm_op_slice_0_5])
        batch_norm_op_slice_5_9 = orm.OpSlice(self.batch_norm_op,
                                              orm.Slice(5, 4))
        batch_norm_op_group5 = orm.OpGroup(
            batch_norm_op_slice_5_9,
            omit_source_op_slices=[batch_norm_op_slice_5_9])
        batch_norm_op_slice_9_11 = orm.OpSlice(self.batch_norm_op,
                                               orm.Slice(9, 2))
        batch_norm_op_group6 = orm.OpGroup(
            batch_norm_op_slice_9_11,
            omit_source_op_slices=[batch_norm_op_slice_9_11])
        batch_norm_op_slice_11_13 = orm.OpSlice(self.batch_norm_op,
                                                orm.Slice(11, 2))
        batch_norm_op_group7 = orm.OpGroup(
            batch_norm_op_slice_11_13,
            omit_source_op_slices=[batch_norm_op_slice_11_13])
        # Covers the same 13-18 range as the initial slice, but is a separate
        # object.  Giving it its own name keeps the initial mappings above
        # pointing at the initial slice instead of being silently shadowed.
        batch_norm_op_slice_13_18_resliced = orm.OpSlice(self.batch_norm_op,
                                                         orm.Slice(13, 5))
        batch_norm_op_group8 = orm.OpGroup(
            batch_norm_op_slice_13_18_resliced,
            omit_source_op_slices=[batch_norm_op_slice_13_18_resliced])

        # Map ops to slices.  Batch norm op is composed of multiple slices.
        self.op_slice_dict = {
            self.relu1_op: [self.relu1_op_slice],
            self.relu2_op: [self.relu2_op_slice],
            self.relu3_op: [self.relu3_op_slice],
            self.concat_op: [self.concat_op_slice],
            self.batch_norm_op: [
                batch_norm_op_slice_0_9, batch_norm_op_slice_9_13,
                batch_norm_op_slice_13_18
            ],
        }

        # Map each slice to a group.
        self.op_group_dict = {
            self.relu1_op_slice: self.relu1_op_group,
            self.relu2_op_slice: self.relu2_op_group,
            self.relu3_op_slice: self.relu3_op_group,
            batch_norm_op_slice_0_9: batch_norm_op_group1,
            batch_norm_op_slice_9_13: batch_norm_op_group2,
            batch_norm_op_slice_13_18: batch_norm_op_group3,
            batch_norm_op_slice_0_5: batch_norm_op_group4,
            batch_norm_op_slice_5_9: batch_norm_op_group5,
            batch_norm_op_slice_9_11: batch_norm_op_group6,
            batch_norm_op_slice_11_13: batch_norm_op_group7,
            batch_norm_op_slice_13_18_resliced: batch_norm_op_group8,
        }

        # Update op_slice_dict when an op is sliced.  The requested sizes (the
        # second argument) are ignored; the pre-built sub-slices are installed.
        def slice_op(op, _):
            if op == self.batch_norm_op:
                self.op_slice_dict[self.batch_norm_op] = [
                    batch_norm_op_slice_0_5, batch_norm_op_slice_5_9,
                    batch_norm_op_slice_9_11, batch_norm_op_slice_11_13,
                    batch_norm_op_slice_13_18_resliced
                ]
            if op == self.concat_op:
                self.op_slice_dict[self.concat_op] = [
                    concat_op_slice_0_5, concat_op_slice_5_9,
                    concat_op_slice_9_11, concat_op_slice_11_13,
                    concat_op_slice_13_18
                ]
            if op == self.relu2_op:
                self.op_slice_dict[self.relu2_op] = [
                    relu2_op_slice_0_4, relu2_op_slice_4_6
                ]
            if op == self.relu3_op:
                self.op_slice_dict[self.relu3_op] = [
                    relu3_op_slice_0_2, relu3_op_slice_2_7
                ]

        self.mock_op_reg_manager.slice_op.side_effect = slice_op

        # Call handler to assign grouping.
        handler = concat_op_handler.ConcatOpHandler()
        handler.assign_grouping(self.concat_op, self.mock_op_reg_manager)

        # Verify manager looks up OpSlice for ops of interest.
        self.mock_op_reg_manager.get_op_slices.assert_has_calls(
            # Checking for ops to process.
            [
                mock.call(self.relu1_op),
                mock.call(self.relu2_op),
                mock.call(self.relu3_op),
                mock.call(self.batch_norm_op),
                # Initial slice data.
                mock.call(self.concat_op),
                mock.call(self.relu1_op),
                mock.call(self.relu2_op),
                mock.call(self.relu3_op),
                mock.call(self.batch_norm_op),
                # Reslicing.
                mock.call(self.relu1_op),
                mock.call(self.relu2_op),
                mock.call(self.relu3_op),
                mock.call(self.batch_norm_op),
                mock.call(self.concat_op),
                # Refreshing slice data.
                mock.call(self.relu1_op),
                mock.call(self.relu2_op),
                mock.call(self.relu3_op),
                mock.call(self.batch_norm_op),
                # Group concat op.
                mock.call(self.concat_op)
            ])

        # Verify manager slices ops that do not have aligned OpSlice sizes.
        self.mock_op_reg_manager.slice_op.assert_has_calls([
            mock.call(self.relu2_op, [4, 2]),
            mock.call(self.relu3_op, [2, 5]),
            mock.call(self.batch_norm_op, [5, 4, 2, 2, 5]),
            mock.call(self.concat_op, [5, 4, 2, 2, 5])
        ])

        # Verify manager groups the new slices.
        self.mock_op_reg_manager.group_op_slices.assert_has_calls([
            mock.call([
                concat_op_slice_0_5, self.relu1_op_slice,
                batch_norm_op_slice_0_5
            ]),
            mock.call([
                concat_op_slice_5_9, relu2_op_slice_0_4,
                batch_norm_op_slice_5_9
            ]),
            mock.call([
                concat_op_slice_9_11, relu2_op_slice_4_6,
                batch_norm_op_slice_9_11
            ]),
            mock.call([
                concat_op_slice_11_13, relu3_op_slice_0_2,
                batch_norm_op_slice_11_13
            ]),
            mock.call([
                concat_op_slice_13_18, relu3_op_slice_2_7,
                batch_norm_op_slice_13_18_resliced
            ])
        ])
Exemplo n.º 25
0
    def setUp(self):
        tf.reset_default_graph()

        # This tests a Conv2D -> SeparableConv2D -> Conv2D chain of ops.
        with tf.contrib.framework.arg_scope(self._batch_norm_scope()):
            inputs = tf.zeros([2, 4, 4, 3])
            c1 = layers.conv2d(inputs,
                               num_outputs=5,
                               kernel_size=3,
                               scope='conv1')
            c2 = layers.separable_conv2d(c1,
                                         num_outputs=8,
                                         kernel_size=3,
                                         depth_multiplier=2,
                                         scope='conv2')
            layers.conv2d(c2, num_outputs=6, kernel_size=3, scope='conv3')

        g = tf.get_default_graph()

        # Declare OpSlice and OpGroup for ops of interest.
        self.dwise_conv2_op = g.get_operation_by_name(
            'conv2/separable_conv2d/depthwise')
        self.dwise_conv2_op_slice = orm.OpSlice(self.dwise_conv2_op,
                                                orm.Slice(0, 10))
        self.dwise_conv2_op_slice_0_1 = orm.OpSlice(self.dwise_conv2_op,
                                                    orm.Slice(0, 1))
        self.dwise_conv2_op_slice_1_2 = orm.OpSlice(self.dwise_conv2_op,
                                                    orm.Slice(1, 1))
        self.dwise_conv2_op_slice_2_3 = orm.OpSlice(self.dwise_conv2_op,
                                                    orm.Slice(2, 1))
        self.dwise_conv2_op_slice_3_4 = orm.OpSlice(self.dwise_conv2_op,
                                                    orm.Slice(3, 1))
        self.dwise_conv2_op_slice_4_5 = orm.OpSlice(self.dwise_conv2_op,
                                                    orm.Slice(4, 1))
        self.dwise_conv2_op_slice_5_6 = orm.OpSlice(self.dwise_conv2_op,
                                                    orm.Slice(5, 1))
        self.dwise_conv2_op_slice_6_7 = orm.OpSlice(self.dwise_conv2_op,
                                                    orm.Slice(6, 1))
        self.dwise_conv2_op_slice_7_8 = orm.OpSlice(self.dwise_conv2_op,
                                                    orm.Slice(7, 1))
        self.dwise_conv2_op_slice_8_9 = orm.OpSlice(self.dwise_conv2_op,
                                                    orm.Slice(8, 1))
        self.dwise_conv2_op_slice_9_10 = orm.OpSlice(self.dwise_conv2_op,
                                                     orm.Slice(9, 1))

        self.conv2_op = g.get_operation_by_name('conv2/separable_conv2d')
        self.conv2_op_slice = orm.OpSlice(self.conv2_op, orm.Slice(0, 8))

        self.relu1_op = g.get_operation_by_name('conv1/Relu')
        self.relu1_op_slice = orm.OpSlice(self.relu1_op, orm.Slice(0, 5))
        self.relu1_op_slice_0_1 = orm.OpSlice(self.relu1_op, orm.Slice(0, 1))
        self.relu1_op_slice_1_2 = orm.OpSlice(self.relu1_op, orm.Slice(1, 1))
        self.relu1_op_slice_2_3 = orm.OpSlice(self.relu1_op, orm.Slice(2, 1))
        self.relu1_op_slice_3_4 = orm.OpSlice(self.relu1_op, orm.Slice(3, 1))
        self.relu1_op_slice_4_5 = orm.OpSlice(self.relu1_op, orm.Slice(4, 1))
        self.relu1_op_group = orm.OpGroup(self.relu1_op_slice)

        self.conv3_op = g.get_operation_by_name('conv3/Conv2D')
        self.conv3_op_slice = orm.OpSlice(self.conv3_op, orm.Slice(0, 6))

        # Create mock OpRegularizerManager with custom mapping of OpSlice and
        # OpGroup.
        self.mock_op_reg_manager = mock.create_autospec(
            orm.OpRegularizerManager)

        self.op_slice_dict = {
            self.dwise_conv2_op: [self.dwise_conv2_op_slice],
            self.conv2_op: [self.conv2_op_slice],
            self.relu1_op: [self.relu1_op_slice],
            self.conv3_op: [self.conv3_op_slice],
        }

        def get_op_slices(op):
            return self.op_slice_dict.get(op)

        def get_op_group(op_slice):
            return self.op_group_dict.get(op_slice)

        # Update op_slice_dict when an op is sliced.
        def slice_op(op, _):
            if op == self.dwise_conv2_op:
                self.op_slice_dict[self.dwise_conv2_op] = [
                    self.dwise_conv2_op_slice_0_1,
                    self.dwise_conv2_op_slice_1_2,
                    self.dwise_conv2_op_slice_2_3,
                    self.dwise_conv2_op_slice_3_4,
                    self.dwise_conv2_op_slice_4_5,
                    self.dwise_conv2_op_slice_5_6,
                    self.dwise_conv2_op_slice_6_7,
                    self.dwise_conv2_op_slice_7_8,
                    self.dwise_conv2_op_slice_8_9,
                    self.dwise_conv2_op_slice_9_10
                ]
            if op == self.relu1_op:
                self.op_slice_dict[self.relu1_op] = [
                    self.relu1_op_slice_0_1, self.relu1_op_slice_1_2,
                    self.relu1_op_slice_2_3, self.relu1_op_slice_3_4,
                    self.relu1_op_slice_4_5
                ]

        self.mock_op_reg_manager.get_op_slices.side_effect = get_op_slices
        self.mock_op_reg_manager.get_op_group.side_effect = get_op_group
        self.mock_op_reg_manager.is_source_op.return_value = False
        self.mock_op_reg_manager.slice_op.side_effect = slice_op
        self.mock_op_reg_manager.ops = [
            self.relu1_op, self.dwise_conv2_op, self.conv2_op, self.conv3_op
        ]
Exemplo n.º 26
0
    def testAssignGrouping_ProcessNeighborGroupsWithSlices(self):
        """Checks slice-wise grouping when every neighbor is already sliced.

        The batch norm op and each op it is grouped with (conv, relu, gamma,
        beta, mean, std) are split into OpSlices [0:2) and [2:3). Each slice
        already owns an OpGroup. The handler is expected to call
        `group_op_slices` once per side for each slice index. It must not
        ask the manager to reprocess any op.
        """
        self.op_group_dict = {}

        def sliced_and_grouped(op, omit_source):
            """Splits `op` into [0:2) and [2:3) OpSlices, each with a group.

            Args:
              op: The op to slice.
              omit_source: If True, each slice's OpGroup is created with that
                same slice passed in `omit_source_op_slices`.

            Returns:
              Tuple (slice [0:2), slice [2:3)). Both slices are registered in
              `self.op_group_dict`.
            """
            pieces = (orm.OpSlice(op, orm.Slice(0, 2)),
                      orm.OpSlice(op, orm.Slice(2, 1)))
            for piece in pieces:
                if omit_source:
                    self.op_group_dict[piece] = orm.OpGroup(
                        piece, omit_source_op_slices=[piece])
                else:
                    self.op_group_dict[piece] = orm.OpGroup(piece)
            return pieces

        # Slices and groups are created op by op, head before tail.
        batch_norm_slices = sliced_and_grouped(self.batch_norm_op,
                                               omit_source=False)
        conv_slices = sliced_and_grouped(self.conv_op, omit_source=True)
        relu_slices = sliced_and_grouped(self.relu_op, omit_source=False)
        gamma_slices = sliced_and_grouped(self.gamma_op, omit_source=True)
        beta_slices = sliced_and_grouped(self.beta_op, omit_source=True)
        mean_slices = sliced_and_grouped(self.mean_op, omit_source=True)
        std_slices = sliced_and_grouped(self.std_op, omit_source=True)

        self.op_slice_dict = {
            self.batch_norm_op: list(batch_norm_slices),
            self.conv_op: list(conv_slices),
            self.relu_op: list(relu_slices),
            self.gamma_op: list(gamma_slices),
            self.beta_op: list(beta_slices),
            self.mean_op: list(mean_slices),
            self.std_op: list(std_slices),
        }

        # Run the handler under test.
        handler = batch_norm_source_op_handler.BatchNormSourceOpHandler(
            _GAMMA_THRESHOLD)
        handler.assign_grouping(self.batch_norm_op, self.mock_op_reg_manager)

        # The manager must have been queried for these ops' slices.
        for queried_op in (self.batch_norm_op, self.conv_op, self.relu_op):
            self.mock_op_reg_manager.get_op_slices.assert_any_call(queried_op)

        # Per slice index, the batch norm is grouped first with relu/mean/std
        # (outputs), then with conv/gamma/beta (inputs). The second call
        # passes an empty omit_source_op_slices.
        expected_calls = []
        for bn, relu, mean, std, conv, gamma, beta in zip(
                batch_norm_slices, relu_slices, mean_slices, std_slices,
                conv_slices, gamma_slices, beta_slices):
            expected_calls.append(mock.call([bn, relu, mean, std]))
            expected_calls.append(
                mock.call([bn, conv, gamma, beta], omit_source_op_slices=[]))
        self.mock_op_reg_manager.group_op_slices.assert_has_calls(
            expected_calls)

        # Nothing is queued for reprocessing.
        self.mock_op_reg_manager.process_ops.assert_not_called()
        self.mock_op_reg_manager.process_ops_last.assert_not_called()