def testAssignGrouping_NoDepthMultiplier(self):
        """Depthwise conv with depth_multiplier=1 is grouped with its input.

        Rebuilds the Conv2D -> SeparableConv2D -> Conv2D graph with
        depth_multiplier=1, so the depthwise op and conv1/Relu are both
        described by a 5-channel OpSlice.  Only conv1/Relu already has an
        OpGroup; the pointwise conv2 op does not, so the handler is expected
        to group the depthwise op with conv1/Relu and hand conv2 back to the
        manager via process_ops.
        """
        # Repeat setUp, but with depth_multiplier=1.  Unfortunately, this involves
        # rebuilding the graph from scratch.
        tf.reset_default_graph()

        # This tests a Conv2D -> SeparableConv2D -> Conv2D chain of ops.
        with framework.arg_scope(self._batch_norm_scope()):
            inputs = tf.zeros([2, 4, 4, 3])
            c1 = layers.conv2d(inputs,
                               num_outputs=5,
                               kernel_size=3,
                               scope='conv1')
            c2 = layers.separable_conv2d(c1,
                                         num_outputs=8,
                                         kernel_size=3,
                                         depth_multiplier=1,
                                         scope='conv2')
            layers.conv2d(c2, num_outputs=6, kernel_size=3, scope='conv3')

        g = tf.get_default_graph()

        # Declare OpSlice and OpGroup for ops of interest.
        self.dwise_conv2_op = g.get_operation_by_name(
            'conv2/separable_conv2d/depthwise')
        self.dwise_conv2_op_slice = orm.OpSlice(self.dwise_conv2_op,
                                                orm.Slice(0, 5))

        self.conv2_op = g.get_operation_by_name('conv2/separable_conv2d')
        self.conv2_op_slice = orm.OpSlice(self.conv2_op, orm.Slice(0, 8))

        self.relu1_op = g.get_operation_by_name('conv1/Relu')
        self.relu1_op_slice = orm.OpSlice(self.relu1_op, orm.Slice(0, 5))
        self.relu1_op_group = orm.OpGroup(self.relu1_op_slice)

        self.conv3_op = g.get_operation_by_name('conv3/Conv2D')
        self.conv3_op_slice = orm.OpSlice(self.conv3_op, orm.Slice(0, 6))

        # Create mock OpRegularizerManager with custom mapping of OpSlice and
        # OpGroup.
        self.mock_op_reg_manager = mock.create_autospec(
            orm.OpRegularizerManager)

        self.op_slice_dict = {
            self.dwise_conv2_op: [self.dwise_conv2_op_slice],
            self.conv2_op: [self.conv2_op_slice],
            self.relu1_op: [self.relu1_op_slice],
            self.conv3_op: [self.conv3_op_slice],
        }

        # Only the input op (conv1/Relu) is grouped.  Neither the depthwise op
        # nor its output (conv2) has an OpGroup yet.
        self.op_group_dict = {
            self.relu1_op_slice: self.relu1_op_group,
        }

        def get_op_slices(op):
            # Ops without an entry have no slices: return an empty list rather
            # than None so callers can always iterate the result.
            return self.op_slice_dict.get(op, [])

        def get_op_group(op_slice):
            return self.op_group_dict.get(op_slice)

        self.mock_op_reg_manager.get_op_slices.side_effect = get_op_slices
        self.mock_op_reg_manager.get_op_group.side_effect = get_op_group
        self.mock_op_reg_manager.is_source_op.return_value = False
        self.mock_op_reg_manager.ops = [
            self.relu1_op, self.dwise_conv2_op, self.conv2_op, self.conv3_op
        ]

        # Call handler to assign grouping.
        handler = depthwise_convolution_op_handler.DepthwiseConvolutionOpHandler(
        )
        handler.assign_grouping(self.dwise_conv2_op, self.mock_op_reg_manager)

        # Verify manager looks up OpSlice for ops of interest.
        self.mock_op_reg_manager.get_op_slices.assert_has_calls(
            # Checking for ops to process.
            [
                mock.call(self.relu1_op),
                mock.call(self.conv2_op),
                # Initial slice data.
                mock.call(self.dwise_conv2_op),
                mock.call(self.relu1_op),
                # Reslicing.
                mock.call(self.relu1_op),
                mock.call(self.dwise_conv2_op),
                # Refreshing slice data.
                mock.call(self.relu1_op),
                # Group depthwise convolution.
                mock.call(self.dwise_conv2_op)
            ])

        # Verify manager groups the depthwise convolution with its input ReLU.
        self.mock_op_reg_manager.group_op_slices.assert_called_once_with(
            [self.dwise_conv2_op_slice, self.relu1_op_slice])

        # Verify only the ungrouped output op (conv2) is queued for processing
        # and nothing is deferred to process_ops_last.
        self.mock_op_reg_manager.process_ops.assert_called_once_with(
            [self.conv2_op])
        self.mock_op_reg_manager.process_ops_last.assert_not_called()
Ejemplo n.º 2
0
    def setUp(self):
        tf.reset_default_graph()

        # Build a single Conv2D -> BatchNorm -> ReLU layer under scope 'conv1'.
        with tf.contrib.framework.arg_scope(self._batch_norm_scope()):
            layers.conv2d(tf.zeros([2, 4, 4, 3]),
                          num_outputs=5,
                          kernel_size=3,
                          scope='conv1')

        graph = tf.get_default_graph()

        def lookup_full_slice(op_name):
            """Returns (op, OpSlice over channels [0, 5)) for `op_name`."""
            op = graph.get_operation_by_name(op_name)
            return op, orm.OpSlice(op, orm.Slice(0, 5))

        def self_omitting_group(op_slice):
            """Returns an OpGroup listing `op_slice` in omit_source_op_slices."""
            return orm.OpGroup(op_slice, omit_source_op_slices=[op_slice])

        # The batch norm and ReLU groups are built without
        # omit_source_op_slices; every other group omits its own slice.
        self.batch_norm_op, self.batch_norm_op_slice = lookup_full_slice(
            'conv1/BatchNorm/FusedBatchNorm')
        self.batch_norm_op_group = orm.OpGroup(self.batch_norm_op_slice)

        self.conv_op, self.conv_op_slice = lookup_full_slice('conv1/Conv2D')
        self.conv_op_group = self_omitting_group(self.conv_op_slice)

        self.relu_op, self.relu_op_slice = lookup_full_slice('conv1/Relu')
        self.relu_op_group = orm.OpGroup(self.relu_op_slice)

        self.gamma_op, self.gamma_op_slice = lookup_full_slice(
            'conv1/BatchNorm/gamma/read')
        self.gamma_op_group = self_omitting_group(self.gamma_op_slice)

        self.beta_op, self.beta_op_slice = lookup_full_slice(
            'conv1/BatchNorm/beta/read')
        self.beta_op_group = self_omitting_group(self.beta_op_slice)

        self.mean_op, self.mean_op_slice = lookup_full_slice(
            'conv1/BatchNorm/AssignMovingAvg/sub_1')
        self.mean_op_group = self_omitting_group(self.mean_op_slice)

        self.std_op, self.std_op_slice = lookup_full_slice(
            'conv1/BatchNorm/AssignMovingAvg_1/sub_1')
        self.std_op_group = self_omitting_group(self.std_op_slice)

        # Mock OpRegularizerManager whose slice/group lookups are answered from
        # self.op_slice_dict and self.op_group_dict, which each test assigns.
        manager = mock.create_autospec(orm.OpRegularizerManager)
        manager.get_op_slices.side_effect = (
            lambda op: self.op_slice_dict.get(op, []))
        manager.get_op_group.side_effect = (
            lambda op_slice: self.op_group_dict.get(op_slice))
        manager.is_source_op.return_value = False
        manager.ops = [
            self.batch_norm_op, self.conv_op, self.relu_op, self.gamma_op,
            self.beta_op, self.mean_op, self.std_op
        ]
        self.mock_op_reg_manager = manager
Ejemplo n.º 3
0
    def setUp(self):
        tf.reset_default_graph()

        # Two stacked conv layers.  conv1 is built with normalizer_fn=None, so
        # only conv2 picks up the batch norm from the arg scope.
        with tf.contrib.framework.arg_scope(self._batch_norm_scope()):
            inputs = tf.zeros([2, 4, 4, 3])
            conv1_out = layers.conv2d(inputs,
                                      num_outputs=5,
                                      kernel_size=3,
                                      scope='conv1',
                                      normalizer_fn=None)
            layers.conv2d(conv1_out, num_outputs=6, kernel_size=3,
                          scope='conv2')

        graph = tf.get_default_graph()

        def full_slice_and_group(op, size):
            """Returns an OpSlice over [0, size) and an OpGroup omitting it."""
            op_slice = orm.OpSlice(op, orm.Slice(0, size))
            op_group = orm.OpGroup(op_slice, omit_source_op_slices=[op_slice])
            return op_slice, op_group

        # OpSlice/OpGroup for the conv and ReLU op of each layer.
        self.conv1_op = graph.get_operation_by_name('conv1/Conv2D')
        self.conv1_op_slice, self.conv1_op_group = full_slice_and_group(
            self.conv1_op, 5)

        self.relu1_op = graph.get_operation_by_name('conv1/Relu')
        self.relu1_op_slice, self.relu1_op_group = full_slice_and_group(
            self.relu1_op, 5)

        self.conv2_op = graph.get_operation_by_name('conv2/Conv2D')
        self.conv2_op_slice, self.conv2_op_group = full_slice_and_group(
            self.conv2_op, 6)

        self.relu2_op = graph.get_operation_by_name('conv2/Relu')
        self.relu2_op_slice, self.relu2_op_group = full_slice_and_group(
            self.relu2_op, 6)

        # The batch norm group is built without omit_source_op_slices.
        self.batch_norm_op = graph.get_operation_by_name(
            'conv2/BatchNorm/FusedBatchNormV3')
        self.batch_norm_op_slice = orm.OpSlice(self.batch_norm_op,
                                               orm.Slice(0, 6))
        self.batch_norm_op_group = orm.OpGroup(self.batch_norm_op_slice)

        def is_passthrough(op):
            # Conv ops report whatever OutputNonPassthroughOpHandler reports;
            # the batch norm is passthrough; every other op is not.
            if op in (self.conv1_op, self.conv2_op):
                handler = (output_non_passthrough_op_handler
                           .OutputNonPassthroughOpHandler())
                return handler.is_passthrough
            return bool(op == self.batch_norm_op)

        # Mock OpRegularizerManager whose slice/group lookups are answered from
        # self.op_slice_dict and self.op_group_dict, which each test assigns.
        manager = mock.create_autospec(orm.OpRegularizerManager)
        manager.get_op_slices.side_effect = (
            lambda op: self.op_slice_dict.get(op, []))
        manager.get_op_group.side_effect = (
            lambda op_slice: self.op_group_dict.get(op_slice))
        manager.is_source_op.return_value = False
        manager.is_passthrough.side_effect = is_passthrough
        manager.ops = [
            self.conv1_op, self.relu1_op, self.conv2_op, self.relu2_op,
            self.batch_norm_op
        ]
        self.mock_op_reg_manager = manager
  def _build(self, conv_type):
    """Builds conv1 -> conv2 (no normalizer) and a mock manager around them.

    Args:
      conv_type: 'Conv2D' or 'Conv3D'.  Selects the input rank, the
        layers.conv* builder, and the op-name suffix used for the conv lookups.
    """
    assert conv_type in ['Conv2D', 'Conv3D']
    use_2d = conv_type == 'Conv2D'
    inputs = tf.zeros([2, 4, 4, 3] if use_2d else [2, 4, 4, 4, 3])
    conv_fn = layers.conv2d if use_2d else layers.conv3d

    conv1_out = conv_fn(
        inputs, num_outputs=5, kernel_size=3, scope='conv1', normalizer_fn=None)
    conv_fn(conv1_out, num_outputs=6, kernel_size=3, scope='conv2',
            normalizer_fn=None)

    graph = tf.get_default_graph()

    def slice_and_group(op, size):
      """Returns an OpSlice over [0, size) and an OpGroup omitting it."""
      op_slice = orm.OpSlice(op, orm.Slice(0, size))
      return op_slice, orm.OpGroup(op_slice, omit_source_op_slices=[op_slice])

    # OpSlice/OpGroup for the conv, ReLU and conv2 weight ops.
    self.conv1_op = graph.get_operation_by_name('conv1/' + conv_type)
    self.conv1_op_slice, self.conv1_op_group = slice_and_group(
        self.conv1_op, 5)

    self.relu1_op = graph.get_operation_by_name('conv1/Relu')
    self.relu1_op_slice, self.relu1_op_group = slice_and_group(
        self.relu1_op, 5)

    self.conv2_op = graph.get_operation_by_name('conv2/' + conv_type)
    self.conv2_op_slice, self.conv2_op_group = slice_and_group(
        self.conv2_op, 6)

    self.conv2_weights_op = graph.get_operation_by_name('conv2/weights/read')
    self.conv2_weights_op_slice, self.conv2_weights_op_group = (
        slice_and_group(self.conv2_weights_op, 6))

    self.relu2_op = graph.get_operation_by_name('conv2/Relu')
    self.relu2_op_slice, self.relu2_op_group = slice_and_group(
        self.relu2_op, 6)

    def is_passthrough(op):
      # Only the two conv ops defer to ConvSourceOpHandler; all others are not
      # passthrough.
      if op not in [self.conv1_op, self.conv2_op]:
        return False
      return conv_source_op_handler.ConvSourceOpHandler(
          _DEFAULT_THRESHOLD).is_passthrough

    # Mock OpRegularizerManager whose slice/group lookups are answered from
    # self.op_slice_dict and self.op_group_dict, which each test assigns.
    manager = mock.create_autospec(orm.OpRegularizerManager)
    manager.get_op_slices.side_effect = (
        lambda op: self.op_slice_dict.get(op, []))
    manager.get_op_group.side_effect = (
        lambda op_slice: self.op_group_dict.get(op_slice))
    manager.is_source_op.return_value = False
    manager.is_passthrough.side_effect = is_passthrough
    manager.ops = [
        self.conv1_op, self.relu1_op, self.conv2_op, self.relu2_op,
        self.conv2_weights_op]
    self.mock_op_reg_manager = manager
Ejemplo n.º 5
0
    def testAssignGrouping_ProcessNeighborGroupsWithSlices(self):
        """Batch norm and all its neighbors split into two grouped slices.

        Every op of the Conv2D -> BatchNorm -> ReLU layer is described by the
        same two OpSlices, Slice(0, 2) and Slice(2, 1), and every slice
        already has an OpGroup.  The handler should group the batch norm with
        its outputs and with its inputs, slice by slice, and should not ask
        the manager to process any op.
        """
        def split_into_two_slices(op, omit_source):
            """Returns ([OpSlice(0, 2), OpSlice(2, 1)], [group1, group2]).

            When `omit_source` is true, each group lists its own slice in
            omit_source_op_slices.
            """
            op_slices = [orm.OpSlice(op, orm.Slice(0, 2)),
                         orm.OpSlice(op, orm.Slice(2, 1))]
            if omit_source:
                op_groups = [orm.OpGroup(s, omit_source_op_slices=[s])
                             for s in op_slices]
            else:
                op_groups = [orm.OpGroup(s) for s in op_slices]
            return op_slices, op_groups

        # (op, omit_source) in the order the slices and groups are created.
        op_specs = [
            (self.batch_norm_op, False),
            (self.conv_op, True),
            (self.relu_op, False),
            (self.gamma_op, True),
            (self.beta_op, True),
            (self.mean_op, True),
            (self.std_op, True),
        ]

        slices_by_op = {}
        groups_by_slice = {}
        for op, omit_source in op_specs:
            op_slices, op_groups = split_into_two_slices(op, omit_source)
            slices_by_op[op] = op_slices
            groups_by_slice.update(zip(op_slices, op_groups))

        self.op_slice_dict = slices_by_op
        # All OpSlice have groups.
        self.op_group_dict = groups_by_slice

        # Expected grouping for each slice index: first the batch norm with its
        # outputs, then the batch norm with its inputs.
        output_ops = [self.batch_norm_op, self.relu_op, self.mean_op,
                      self.std_op]
        input_ops = [self.batch_norm_op, self.conv_op, self.gamma_op,
                     self.beta_op]
        expected_group_calls = []
        for index in range(2):
            expected_group_calls.append(
                mock.call([slices_by_op[op][index] for op in output_ops]))
            expected_group_calls.append(
                mock.call([slices_by_op[op][index] for op in input_ops],
                          omit_source_op_slices=[]))

        # Call handler to assign grouping.
        handler = batch_norm_source_op_handler.BatchNormSourceOpHandler(
            _GAMMA_THRESHOLD)
        handler.assign_grouping(self.batch_norm_op, self.mock_op_reg_manager)

        manager = self.mock_op_reg_manager

        # The handler must have looked up the slices of the ops of interest.
        for op in (self.batch_norm_op, self.conv_op, self.relu_op):
            manager.get_op_slices.assert_any_call(op)

        # Batch norm is grouped with outputs and inputs, slice by slice.
        manager.group_op_slices.assert_has_calls(expected_group_calls)

        # No op is (re)queued for processing.
        manager.process_ops.assert_not_called()
        manager.process_ops_last.assert_not_called()
Ejemplo n.º 6
0
  def setUp(self):
    tf.reset_default_graph()

    # Three parallel Conv2D layers over the same input, concatenated and then
    # batch-normalized.
    # NOTE(review): assuming NHWC inputs, axis=2 concatenates along width, so
    # the concat and batch norm keep 6 channels (matching the Slice(0, 6) used
    # below).  Confirm that channel concatenation (axis=3) was not intended.
    inputs = tf.zeros([2, 4, 4, 3])
    with tf.contrib.framework.arg_scope(self._get_scope()):
      conv_outputs = [
          layers.conv2d(inputs, num_outputs=6, kernel_size=3, scope=scope)
          for scope in ('conv1', 'conv2', 'conv3')]
      layers.batch_norm(tf.concat(conv_outputs, axis=2))

    graph = tf.get_default_graph()

    def lookup_slice_and_group(op_name):
      """Returns (op, OpSlice over [0, 6), OpGroup omitting that slice)."""
      op = graph.get_operation_by_name(op_name)
      op_slice = orm.OpSlice(op, orm.Slice(0, 6))
      op_group = orm.OpGroup(op_slice, omit_source_op_slices=[op_slice])
      return op, op_slice, op_group

    (self.concat_op, self.concat_op_slice,
     self.concat_op_group) = lookup_slice_and_group('concat')
    (self.relu1_op, self.relu1_op_slice,
     self.relu1_op_group) = lookup_slice_and_group('conv1/Relu')
    (self.relu2_op, self.relu2_op_slice,
     self.relu2_op_group) = lookup_slice_and_group('conv2/Relu')
    (self.relu3_op, self.relu3_op_slice,
     self.relu3_op_group) = lookup_slice_and_group('conv3/Relu')
    (self.batch_norm_op, self.batch_norm_op_slice,
     self.batch_norm_op_group) = lookup_slice_and_group(
         'BatchNorm/FusedBatchNormV3')

    # A parent group combining all of the per-op groups above.
    self.concat_group = orm.OpGroup(
        op_slice=None,
        op_groups=[
            self.batch_norm_op_group, self.concat_op_group, self.relu1_op_group,
            self.relu2_op_group, self.relu3_op_group
        ])

    # Mock OpRegularizerManager whose slice/group lookups are answered from
    # self.op_slice_dict and self.op_group_dict, which each test assigns.  Every
    # op is reported as passthrough.
    manager = mock.create_autospec(orm.OpRegularizerManager)
    manager.get_op_slices.side_effect = (
        lambda op: self.op_slice_dict.get(op, []))
    manager.get_op_group.side_effect = (
        lambda op_slice: self.op_group_dict.get(op_slice))
    manager.is_source_op.return_value = False
    manager.is_passthrough.return_value = True
    manager.ops = [
        self.concat_op, self.relu1_op, self.relu2_op, self.relu3_op,
        self.batch_norm_op]
    self.mock_op_reg_manager = manager
Ejemplo n.º 7
0
  def testAssignGrouping_AllNeighborsGrouped_InputSlicesNotAligned(self):
    """Concat and batch norm are resliced to follow multi-slice inputs.

    Every input (relu1, relu2, relu3) is already grouped, and relu2 is made of
    two slices of size 3.  The handler is expected to slice both the batch
    norm and the concat into sizes [5, 3, 3, 7], then group each new concat
    slice with the matching input slice and batch norm slice.
    """
    # In this test, the op c2 has size 6 but is split into 2 slices of size 3.
    # The concat op (and its output, the batch norm) both have size 18.  This
    # test verifies that the concat and batch norm are sliced according to the
    # sizes of c1, c2, and c3, and takes into account that c2 is also composed
    # of multiple slices.
    # Slices the concat op is expected to be cut into (sizes 5, 3, 3, 7).
    concat_op_slice_0_5 = orm.OpSlice(self.concat_op, orm.Slice(0, 5))
    concat_op_slice_5_8 = orm.OpSlice(self.concat_op, orm.Slice(5, 3))
    concat_op_slice_8_11 = orm.OpSlice(self.concat_op, orm.Slice(8, 3))
    concat_op_slice_11_18 = orm.OpSlice(self.concat_op, orm.Slice(11, 7))

    # relu2 starts out as two size-3 slices, each with its own group.
    relu2_op_slice_0_3 = orm.OpSlice(self.relu2_op, orm.Slice(0, 3))
    relu2_op_slice_3_6 = orm.OpSlice(self.relu2_op, orm.Slice(3, 3))
    relu2_op_group1 = orm.OpGroup(
        relu2_op_slice_0_3, omit_source_op_slices=[relu2_op_slice_0_3])
    relu2_op_group2 = orm.OpGroup(
        relu2_op_slice_3_6, omit_source_op_slices=[relu2_op_slice_3_6])

    # The batch norm starts as a single 18-wide slice.  The four smaller slices
    # after it are what slice_op (below) replaces it with; each already has a
    # group so lookups after reslicing succeed.
    batch_norm_op_slice = orm.OpSlice(self.batch_norm_op, orm.Slice(0, 18))
    batch_norm_op_group = orm.OpGroup(
        batch_norm_op_slice, omit_source_op_slices=[batch_norm_op_slice])
    batch_norm_op_slice_0_5 = orm.OpSlice(self.batch_norm_op, orm.Slice(0, 5))
    batch_norm_op_group1 = orm.OpGroup(
        batch_norm_op_slice_0_5,
        omit_source_op_slices=[batch_norm_op_slice_0_5])
    batch_norm_op_slice_5_8 = orm.OpSlice(self.batch_norm_op, orm.Slice(5, 3))
    batch_norm_op_group2 = orm.OpGroup(
        batch_norm_op_slice_5_8,
        omit_source_op_slices=[batch_norm_op_slice_5_8])
    batch_norm_op_slice_8_11 = orm.OpSlice(self.batch_norm_op, orm.Slice(8, 3))
    batch_norm_op_group3 = orm.OpGroup(
        batch_norm_op_slice_8_11,
        omit_source_op_slices=[batch_norm_op_slice_8_11])
    batch_norm_op_slice_11_18 = orm.OpSlice(
        self.batch_norm_op, orm.Slice(11, 7))
    batch_norm_op_group4 = orm.OpGroup(
        batch_norm_op_slice_11_18,
        omit_source_op_slices=[batch_norm_op_slice_11_18])

    # Map ops to slices.  The op c2 is composed of multiple slices.
    self.op_slice_dict = {
        self.relu1_op: [self.relu1_op_slice],
        self.relu2_op: [relu2_op_slice_0_3, relu2_op_slice_3_6],
        self.relu3_op: [self.relu3_op_slice],
        self.concat_op: [self.concat_op_slice],
        self.batch_norm_op: [batch_norm_op_slice],
    }

    # Map each slice to a group.  The concat slices intentionally have no
    # entry here.
    self.op_group_dict = {
        self.relu1_op_slice: self.relu1_op_group,
        relu2_op_slice_0_3: relu2_op_group1,
        relu2_op_slice_3_6: relu2_op_group2,
        self.relu3_op_slice: self.relu3_op_group,
        batch_norm_op_slice: batch_norm_op_group,
        batch_norm_op_slice_0_5: batch_norm_op_group1,
        batch_norm_op_slice_5_8: batch_norm_op_group2,
        batch_norm_op_slice_8_11: batch_norm_op_group3,
        batch_norm_op_slice_11_18: batch_norm_op_group4,
    }

    # Update op_slice_dict when an op is sliced.  The requested sizes (second
    # argument) are ignored; the mock always installs the precomputed slices.
    def slice_op(op, _):
      if op == self.batch_norm_op:
        self.op_slice_dict[self.batch_norm_op] = [
            batch_norm_op_slice_0_5,
            batch_norm_op_slice_5_8,
            batch_norm_op_slice_8_11,
            batch_norm_op_slice_11_18]
      if op == self.concat_op:
        self.op_slice_dict[self.concat_op] = [
            concat_op_slice_0_5,
            concat_op_slice_5_8,
            concat_op_slice_8_11,
            concat_op_slice_11_18]

    self.mock_op_reg_manager.slice_op.side_effect = slice_op

    # Call handler to assign grouping.
    handler = concat_op_handler.ConcatOpHandler()
    handler.assign_grouping(self.concat_op, self.mock_op_reg_manager)

    # Verify manager looks up OpSlice for ops of interest, in this order.
    self.mock_op_reg_manager.get_op_slices.assert_has_calls(
        # Checking for ops to process.
        [mock.call(self.relu1_op),
         mock.call(self.relu2_op),
         mock.call(self.relu3_op),
         mock.call(self.batch_norm_op),
         # Initial slice data.
         mock.call(self.concat_op),
         mock.call(self.relu1_op),
         mock.call(self.relu2_op),
         mock.call(self.relu3_op),
         mock.call(self.batch_norm_op),
         # Reslicing.
         mock.call(self.relu1_op),
         mock.call(self.relu2_op),
         mock.call(self.relu3_op),
         mock.call(self.batch_norm_op),
         mock.call(self.concat_op),
         # Refreshing slice data.
         mock.call(self.relu1_op),
         mock.call(self.relu2_op),
         mock.call(self.relu3_op),
         mock.call(self.batch_norm_op),
         # Group concat op.
         mock.call(self.concat_op)])

    # Verify manager slices ops that do not have aligned OpSlice sizes.  Both
    # the batch norm and the concat are cut into the input sizes [5, 3, 3, 7].
    self.mock_op_reg_manager.slice_op.assert_has_calls(
        [mock.call(self.batch_norm_op, [5, 3, 3, 7]),
         mock.call(self.concat_op, [5, 3, 3, 7])])

    # Verify manager groups the new slices: each concat slice with its input
    # slice and the matching batch norm slice.
    self.mock_op_reg_manager.group_op_slices.assert_has_calls(
        [mock.call([concat_op_slice_0_5, self.relu1_op_slice,
                    batch_norm_op_slice_0_5]),
         mock.call([concat_op_slice_5_8, relu2_op_slice_0_3,
                    batch_norm_op_slice_5_8]),
         mock.call([concat_op_slice_8_11, relu2_op_slice_3_6,
                    batch_norm_op_slice_8_11]),
         mock.call([concat_op_slice_11_18, self.relu3_op_slice,
                    batch_norm_op_slice_11_18])])
Ejemplo n.º 8
0
  def testAssignGrouping_AllNeighborsGrouped_OutputSlicesNotAligned(self):
    """Output slices [9, 4, 5] force reslicing of concat, batch norm and inputs.

    The batch norm output is split into sizes [9, 4, 5], which do not line
    up with the concat inputs.  The handler is expected to slice relu2 into
    [4, 2], relu3 into [2, 5], and both the batch norm and the concat into
    [5, 4, 2, 2, 5].  It should then group the aligned slices together.
    """
    # The output (batch norm) has sizes [9, 4, 5] which are not aligned.  This
    # test verifies that the concat, the batch norm, and the ReLU outputs of
    # conv2 and conv3 are sliced in alignment.
    concat_op_slice_0_5 = orm.OpSlice(self.concat_op, orm.Slice(0, 5))
    concat_op_slice_5_9 = orm.OpSlice(self.concat_op, orm.Slice(5, 4))
    concat_op_slice_9_11 = orm.OpSlice(self.concat_op, orm.Slice(9, 2))
    concat_op_slice_11_13 = orm.OpSlice(self.concat_op, orm.Slice(11, 2))
    concat_op_slice_13_18 = orm.OpSlice(self.concat_op, orm.Slice(13, 5))

    relu2_op_slice_0_4 = orm.OpSlice(self.relu2_op, orm.Slice(0, 4))
    relu2_op_slice_4_6 = orm.OpSlice(self.relu2_op, orm.Slice(4, 2))

    relu3_op_slice_0_2 = orm.OpSlice(self.relu3_op, orm.Slice(0, 2))
    relu3_op_slice_2_7 = orm.OpSlice(self.relu3_op, orm.Slice(2, 5))

    # Initial batch norm slices: sizes [9, 4, 5].
    batch_norm_op_slice_0_9 = orm.OpSlice(self.batch_norm_op, orm.Slice(0, 9))
    batch_norm_op_group1 = orm.OpGroup(
        batch_norm_op_slice_0_9,
        omit_source_op_slices=[batch_norm_op_slice_0_9])
    batch_norm_op_slice_9_13 = orm.OpSlice(self.batch_norm_op, orm.Slice(9, 4))
    batch_norm_op_group2 = orm.OpGroup(
        batch_norm_op_slice_9_13,
        omit_source_op_slices=[batch_norm_op_slice_9_13])
    batch_norm_op_slice_13_18 = orm.OpSlice(
        self.batch_norm_op, orm.Slice(13, 5))
    batch_norm_op_group3 = orm.OpGroup(
        batch_norm_op_slice_13_18,
        omit_source_op_slices=[batch_norm_op_slice_13_18])

    # Batch norm slices after reslicing into [5, 4, 2, 2, 5].  The final
    # slice, [13, 18), is unchanged by the reslice, so the slice (and group)
    # declared above is reused instead of redefining it.
    batch_norm_op_slice_0_5 = orm.OpSlice(self.batch_norm_op, orm.Slice(0, 5))
    batch_norm_op_group4 = orm.OpGroup(
        batch_norm_op_slice_0_5,
        omit_source_op_slices=[batch_norm_op_slice_0_5])
    batch_norm_op_slice_5_9 = orm.OpSlice(self.batch_norm_op, orm.Slice(5, 4))
    batch_norm_op_group5 = orm.OpGroup(
        batch_norm_op_slice_5_9,
        omit_source_op_slices=[batch_norm_op_slice_5_9])
    batch_norm_op_slice_9_11 = orm.OpSlice(self.batch_norm_op, orm.Slice(9, 2))
    batch_norm_op_group6 = orm.OpGroup(
        batch_norm_op_slice_9_11,
        omit_source_op_slices=[batch_norm_op_slice_9_11])
    batch_norm_op_slice_11_13 = orm.OpSlice(
        self.batch_norm_op, orm.Slice(11, 2))
    batch_norm_op_group7 = orm.OpGroup(
        batch_norm_op_slice_11_13,
        omit_source_op_slices=[batch_norm_op_slice_11_13])

    # Map ops to slices.  Batch norm op is composed of multiple slices.
    self.op_slice_dict = {
        self.relu1_op: [self.relu1_op_slice],
        self.relu2_op: [self.relu2_op_slice],
        self.relu3_op: [self.relu3_op_slice],
        self.concat_op: [self.concat_op_slice],
        self.batch_norm_op: [batch_norm_op_slice_0_9, batch_norm_op_slice_9_13,
                             batch_norm_op_slice_13_18],
    }

    # Map each slice to a group.  Each key appears exactly once.
    self.op_group_dict = {
        self.relu1_op_slice: self.relu1_op_group,
        self.relu2_op_slice: self.relu2_op_group,
        self.relu3_op_slice: self.relu3_op_group,
        batch_norm_op_slice_0_9: batch_norm_op_group1,
        batch_norm_op_slice_9_13: batch_norm_op_group2,
        batch_norm_op_slice_13_18: batch_norm_op_group3,
        batch_norm_op_slice_0_5: batch_norm_op_group4,
        batch_norm_op_slice_5_9: batch_norm_op_group5,
        batch_norm_op_slice_9_11: batch_norm_op_group6,
        batch_norm_op_slice_11_13: batch_norm_op_group7,
    }

    # Update op_slice_dict when an op is sliced.  The requested sizes (second
    # argument) are ignored; the mock always installs the precomputed slices.
    def slice_op(op, _):
      if op == self.batch_norm_op:
        self.op_slice_dict[self.batch_norm_op] = [
            batch_norm_op_slice_0_5,
            batch_norm_op_slice_5_9,
            batch_norm_op_slice_9_11,
            batch_norm_op_slice_11_13,
            batch_norm_op_slice_13_18]
      if op == self.concat_op:
        self.op_slice_dict[self.concat_op] = [
            concat_op_slice_0_5,
            concat_op_slice_5_9,
            concat_op_slice_9_11,
            concat_op_slice_11_13,
            concat_op_slice_13_18]
      if op == self.relu2_op:
        self.op_slice_dict[self.relu2_op] = [
            relu2_op_slice_0_4,
            relu2_op_slice_4_6]
      if op == self.relu3_op:
        self.op_slice_dict[self.relu3_op] = [
            relu3_op_slice_0_2,
            relu3_op_slice_2_7]

    self.mock_op_reg_manager.slice_op.side_effect = slice_op

    # Call handler to assign grouping.
    handler = concat_op_handler.ConcatOpHandler()
    handler.assign_grouping(self.concat_op, self.mock_op_reg_manager)

    # Verify manager looks up OpSlice for ops of interest.
    self.mock_op_reg_manager.get_op_slices.assert_has_calls(
        # Checking for ops to process.
        [mock.call(self.relu1_op),
         mock.call(self.relu2_op),
         mock.call(self.relu3_op),
         mock.call(self.batch_norm_op),
         # Initial slice data.
         mock.call(self.concat_op),
         mock.call(self.relu1_op),
         mock.call(self.relu2_op),
         mock.call(self.relu3_op),
         mock.call(self.batch_norm_op),
         # Reslicing.
         mock.call(self.relu1_op),
         mock.call(self.relu2_op),
         mock.call(self.relu3_op),
         mock.call(self.batch_norm_op),
         mock.call(self.concat_op),
         # Refreshing slice data.
         mock.call(self.relu1_op),
         mock.call(self.relu2_op),
         mock.call(self.relu3_op),
         mock.call(self.batch_norm_op),
         # Group concat op.
         mock.call(self.concat_op)])

    # Verify manager slices ops that do not have aligned OpSlice sizes.
    self.mock_op_reg_manager.slice_op.assert_has_calls(
        [mock.call(self.relu2_op, [4, 2]),
         mock.call(self.relu3_op, [2, 5]),
         mock.call(self.batch_norm_op, [5, 4, 2, 2, 5]),
         mock.call(self.concat_op, [5, 4, 2, 2, 5])])

    # Verify manager groups the new slices.
    self.mock_op_reg_manager.group_op_slices.assert_has_calls(
        [mock.call([concat_op_slice_0_5, self.relu1_op_slice,
                    batch_norm_op_slice_0_5]),
         mock.call([concat_op_slice_5_9, relu2_op_slice_0_4,
                    batch_norm_op_slice_5_9]),
         mock.call([concat_op_slice_9_11, relu2_op_slice_4_6,
                    batch_norm_op_slice_9_11]),
         mock.call([concat_op_slice_11_13, relu3_op_slice_0_2,
                    batch_norm_op_slice_11_13]),
         mock.call([concat_op_slice_13_18, relu3_op_slice_2_7,
                    batch_norm_op_slice_13_18])])
Ejemplo n.º 9
0
  def setUp(self):
    """Builds a conv1/conv2/conv3 -> concat -> batch_norm graph and a mock.

    The three conv branches (5, 6 and 7 output channels) are concatenated on
    axis 3 into an 18-channel tensor that feeds a batch norm. Full and
    per-branch OpSlices/OpGroups are declared for the concat, the ReLUs and
    the batch norm. An autospec'd OpRegularizerManager answers slice/group
    lookups from `self.op_slice_dict` / `self.op_group_dict`, which the
    individual tests are expected to populate.
    """
    tf.reset_default_graph()

    # Ops are created in the same order as before (conv1, conv2, conv3,
    # concat, batch norm) so the auto-generated names looked up below match.
    image = tf.zeros([2, 4, 4, 3])
    branches = [
        layers.conv2d(image, num_outputs=channels, kernel_size=3, scope=scope)
        for scope, channels in (('conv1', 5), ('conv2', 6), ('conv3', 7))]
    layers.batch_norm(tf.concat(branches, axis=3))

    find_op = tf.get_default_graph().get_operation_by_name

    def new_slice(op, start, size):
      # OpSlice covering channels [start, start + size) of `op`.
      return orm.OpSlice(op, orm.Slice(start, size))

    def solo_group(source_slice):
      # Single-slice OpGroup whose slice is also listed as an omitted source.
      return orm.OpGroup(source_slice, omit_source_op_slices=[source_slice])

    # Concat: the whole 18-channel output plus one slice per input branch.
    self.concat_op = find_op('concat')
    self.concat_op_slice = new_slice(self.concat_op, 0, 18)
    self.concat_op_slice_0_5 = new_slice(self.concat_op, 0, 5)
    self.concat_op_slice_5_11 = new_slice(self.concat_op, 5, 6)
    self.concat_op_slice_11_18 = new_slice(self.concat_op, 11, 7)
    self.concat_op_group1 = solo_group(self.concat_op_slice_0_5)
    self.concat_op_group2 = solo_group(self.concat_op_slice_5_11)
    self.concat_op_group3 = solo_group(self.concat_op_slice_11_18)

    # One ReLU per branch, each sliced over its full channel count.
    self.relu1_op = find_op('conv1/Relu')
    self.relu1_op_slice = new_slice(self.relu1_op, 0, 5)
    self.relu1_op_group = solo_group(self.relu1_op_slice)

    self.relu2_op = find_op('conv2/Relu')
    self.relu2_op_slice = new_slice(self.relu2_op, 0, 6)
    self.relu2_op_group = solo_group(self.relu2_op_slice)

    self.relu3_op = find_op('conv3/Relu')
    self.relu3_op_slice = new_slice(self.relu3_op, 0, 7)
    self.relu3_op_group = solo_group(self.relu3_op_slice)

    self.axis_op = find_op('concat/axis')

    # Batch norm over the concat output: whole slice, then per-branch slices.
    self.batch_norm_op = find_op('BatchNorm/FusedBatchNormV3')
    self.batch_norm_op_slice = new_slice(self.batch_norm_op, 0, 18)
    self.batch_norm_op_group = solo_group(self.batch_norm_op_slice)
    self.batch_norm_op_slice_0_5 = new_slice(self.batch_norm_op, 0, 5)
    self.batch_norm_op_slice_5_11 = new_slice(self.batch_norm_op, 5, 6)
    self.batch_norm_op_slice_11_18 = new_slice(self.batch_norm_op, 11, 7)
    self.batch_norm_op_group1 = solo_group(self.batch_norm_op_slice_0_5)
    self.batch_norm_op_group2 = solo_group(self.batch_norm_op_slice_5_11)
    self.batch_norm_op_group3 = solo_group(self.batch_norm_op_slice_11_18)

    # Mock manager whose lookups are driven by per-test dictionaries.
    manager = mock.create_autospec(orm.OpRegularizerManager)
    self.mock_op_reg_manager = manager

    def fake_get_op_slices(op):
      return self.op_slice_dict.get(op, [])

    def fake_get_op_group(op_slice):
      return self.op_group_dict.get(op_slice)

    def fake_slice_op(op, _):
      # Simulate slicing by swapping in the pre-built per-branch slices for
      # the batch norm and concat ops; any other op is left untouched.
      replacements = (
          (self.batch_norm_op, [self.batch_norm_op_slice_0_5,
                                self.batch_norm_op_slice_5_11,
                                self.batch_norm_op_slice_11_18]),
          (self.concat_op, [self.concat_op_slice_0_5,
                            self.concat_op_slice_5_11,
                            self.concat_op_slice_11_18]),
      )
      for target_op, new_slices in replacements:
        if op == target_op:
          self.op_slice_dict[target_op] = new_slices

    manager.get_op_slices.side_effect = fake_get_op_slices
    manager.get_op_group.side_effect = fake_get_op_group
    manager.is_source_op.return_value = False
    manager.slice_op.side_effect = fake_slice_op
    manager.is_passthrough.return_value = True
    manager.ops = [
        self.concat_op, self.relu1_op, self.relu2_op, self.relu3_op,
        self.batch_norm_op]
Ejemplo n.º 10
0
    def setUp(self):
        """Builds two small test networks and a mock OpRegularizerManager.

        Network 1: conv1 is built under `self._batch_norm_scope()`, producing
        the conv1/Conv2D -> conv1/BatchNorm -> conv1/Relu ops looked up below.
        Network 2: conv2, conv3 and conv4 (5, 6 and 7 output channels) are
        concatenated on axis 3 and fed to a batch norm that is built outside
        the arg_scope.

        The mock's `get_op_slices`, `get_op_group` and `is_passthrough` read
        `self.op_slice_dict`, `self.op_group_dict` and `self._passthrough_ops`.
        None of these is assigned in the lines shown here; presumably each
        test sets them before invoking a handler.
        """
        tf.reset_default_graph()

        # This tests a Conv2D -> BatchNorm -> ReLU chain of ops.
        # NOTE(review): `tf.contrib` only exists in TensorFlow 1.x; elsewhere in
        # this file `framework.arg_scope` is used instead -- confirm intent.
        with tf.contrib.framework.arg_scope(self._batch_norm_scope()):
            # `inputs` is also reused below, outside the arg_scope; a `with`
            # block does not introduce a new variable scope in Python.
            inputs = tf.zeros([2, 4, 4, 3])
            layers.conv2d(inputs, num_outputs=5, kernel_size=3, scope='conv1')

        # This tests 3 Conv2D ops being concatenated before a batch normalization.
        # These convs are built outside the arg_scope above.
        c2 = layers.conv2d(inputs, num_outputs=5, kernel_size=3, scope='conv2')
        c3 = layers.conv2d(inputs, num_outputs=6, kernel_size=3, scope='conv3')
        c4 = layers.conv2d(inputs, num_outputs=7, kernel_size=3, scope='conv4')
        # 5 + 6 + 7 = 18 channels, matching the Slice(0, 18) declarations below.
        net = tf.concat([c2, c3, c4], axis=3)
        layers.batch_norm(net)

        # Ops below are looked up by generated names (e.g. 'Const_1',
        # 'AssignMovingAvg_1', 'concat'); changing the construction order or
        # scoping above would change which op each name resolves to.
        g = tf.get_default_graph()

        # Declare OpSlice and OpGroup for ops in the first test network.
        self.batch_norm_op = g.get_operation_by_name(
            'conv1/BatchNorm/FusedBatchNormV3')
        # Network-1 slices are created with a None slice argument
        # (presumably "unspecified"; confirm against orm.OpSlice).
        self.batch_norm_op_slice = orm.OpSlice(self.batch_norm_op, None)
        # Unlike the other groups here, no omit_source_op_slices is passed.
        self.batch_norm_op_group = orm.OpGroup(self.batch_norm_op_slice)

        self.conv_op = g.get_operation_by_name('conv1/Conv2D')
        self.conv_op_slice = orm.OpSlice(self.conv_op, None)
        self.conv_op_group = orm.OpGroup(
            self.conv_op_slice, omit_source_op_slices=[self.conv_op_slice])

        # Batch-norm input/statistics ops of conv1. The decay/epsilon/mean/std
        # roles come from the attribute names; the ops themselves are only
        # identified by generated names such as 'Const' and 'Const_1'.
        self.gamma_op = g.get_operation_by_name('conv1/BatchNorm/gamma/read')
        self.beta_op = g.get_operation_by_name('conv1/BatchNorm/beta/read')
        self.decay_op = g.get_operation_by_name('conv1/BatchNorm/Const')
        self.epsilon_op = g.get_operation_by_name('conv1/BatchNorm/Const_1')
        self.mean_op = g.get_operation_by_name(
            'conv1/BatchNorm/AssignMovingAvg/sub_1')
        self.std_op = g.get_operation_by_name(
            'conv1/BatchNorm/AssignMovingAvg_1/sub_1')

        self.relu_op = g.get_operation_by_name('conv1/Relu')
        self.relu_op_slice = orm.OpSlice(self.relu_op, None)
        self.relu_op_group = orm.OpGroup(
            self.relu_op_slice, omit_source_op_slices=[self.relu_op_slice])

        # Declare OpSlice and OpGroup for ops in the second test network.
        # Each ReLU slice spans that branch's full num_outputs.
        self.relu2_op = g.get_operation_by_name('conv2/Relu')
        self.relu2_op_slice = orm.OpSlice(self.relu2_op, orm.Slice(0, 5))
        self.relu2_op_group = orm.OpGroup(
            self.relu2_op_slice, omit_source_op_slices=[self.relu2_op_slice])

        self.relu3_op = g.get_operation_by_name('conv3/Relu')
        self.relu3_op_slice = orm.OpSlice(self.relu3_op, orm.Slice(0, 6))
        self.relu3_op_group = orm.OpGroup(
            self.relu3_op_slice, omit_source_op_slices=[self.relu3_op_slice])

        self.relu4_op = g.get_operation_by_name('conv4/Relu')
        self.relu4_op_slice = orm.OpSlice(self.relu4_op, orm.Slice(0, 7))
        self.relu4_op_group = orm.OpGroup(
            self.relu4_op_slice, omit_source_op_slices=[self.relu4_op_slice])

        # NOTE(review): despite the attribute name, this resolves to a
        # 'FusedBatchNormV3' op (the batch norm built outside the arg_scope);
        # the "unfused" name is kept for callers -- confirm intent.
        self.unfused_batch_norm_op = g.get_operation_by_name(
            'BatchNorm/FusedBatchNormV3')
        self.unfused_batch_norm_op_slice = orm.OpSlice(
            self.unfused_batch_norm_op, orm.Slice(0, 18))

        self.concat_op = g.get_operation_by_name('concat')
        self.concat_op_slice = orm.OpSlice(self.concat_op, orm.Slice(0, 18))
        self.concat_op_group = orm.OpGroup(
            self.concat_op_slice, omit_source_op_slices=[self.concat_op_slice])

        # Create mock OpRegularizerManager with custom mapping of OpSlice and
        # OpGroup.
        self.mock_op_reg_manager = mock.create_autospec(
            orm.OpRegularizerManager)

        # Returns the test-provided slices for `op`, or [] if none are mapped.
        def get_op_slices(op):
            return self.op_slice_dict.get(op, [])

        # Returns the test-provided group for `op_slice`, or None if unmapped.
        def get_op_group(op_slice):
            return self.op_group_dict.get(op_slice)

        # An op counts as passthrough only if the test lists it.
        def is_passthrough(op):
            return op in self._passthrough_ops

        self.mock_op_reg_manager.get_op_slices.side_effect = get_op_slices
        self.mock_op_reg_manager.get_op_group.side_effect = get_op_group
        self.mock_op_reg_manager.is_passthrough.side_effect = is_passthrough
        # All ops from both networks that the mock manager reports.
        self.mock_op_reg_manager.ops = [
            self.batch_norm_op, self.gamma_op, self.beta_op, self.decay_op,
            self.epsilon_op, self.mean_op, self.std_op, self.conv_op,
            self.relu_op, self.relu2_op, self.relu3_op, self.relu4_op,
            self.unfused_batch_norm_op, self.concat_op
        ]