コード例 #1
0
  def remove_nn():
    """Swap hand-built nearest-neighbor upsampling subgraphs for the TF op.

    The subgraph searched for in the default graph is
    input -> Pack -> Pack -> Reshape -> (Add|AddV2|Max|Mul). For every match
    the consumer's input that came from the Reshape is rewired to a freshly
    created `resize_nearest_neighbor` op fed by the matched input op.

    Returns:
      The number of matches that were found (and rewired).
    """
    source_pattern = graph_matcher.OpTypePattern(
        'FakeQuantWithMinMaxVars' if is_quantized else '*')
    inner_stack_pattern = graph_matcher.OpTypePattern(
        'Pack', inputs=[source_pattern, source_pattern], ordered_inputs=False)
    outer_stack_pattern = graph_matcher.OpTypePattern(
        'Pack',
        inputs=[inner_stack_pattern, inner_stack_pattern],
        ordered_inputs=False)
    upsampled_pattern = graph_matcher.OpTypePattern(
        'Reshape', inputs=[outer_stack_pattern, 'Const'], ordered_inputs=False)
    sink_pattern = graph_matcher.OpTypePattern(
        'Add|AddV2|Max|Mul',
        inputs=[upsampled_pattern, '*'],
        ordered_inputs=False)

    num_matches = 0
    sink_matcher = graph_matcher.GraphMatcher(sink_pattern)
    for found in sink_matcher.match_graph(tf.get_default_graph()):
      num_matches += 1
      source_op = found.get_op(source_pattern)
      upsampled_op = found.get_op(upsampled_pattern)
      sink_op = found.get_op(sink_pattern)

      # The replacement op lives in the same name scope as the Reshape it
      # supersedes and takes its spatial size from the Reshape's static shape.
      resize_name = (
          os.path.split(upsampled_op.name)[0] + '/resize_nearest_neighbor')
      resized = tf.image.resize_nearest_neighbor(
          source_op.outputs[0],
          upsampled_op.outputs[0].shape.dims[1:3],
          align_corners=False,
          name=resize_name)

      # Rewire only the first consumer input that came from the Reshape.
      for slot, tensor in enumerate(sink_op.inputs):
        if tensor == upsampled_op.outputs[0]:
          sink_op._update_input(slot, resized)  # pylint: disable=protected-access
          break

    tf.logging.info('Found and fixed {} matches'.format(num_matches))
    return num_matches
コード例 #2
0
    def test_multiple_outputs(self):
        """A pattern bound to one output of a multi-output op resolves to it."""
        #   -         +
        #  / \y0   y1/ \
        # x    split    z
        #       |
        #       y         (nodes are ops; edges are going up)
        graph = ops.Graph()
        with graph.as_default():
            x = array_ops.placeholder(dtypes.float32, shape=[1], name='x')
            y = array_ops.placeholder(dtypes.float32, shape=[2], name='y')
            y0, y1 = array_ops.split(y, num_or_size_splits=2, axis=0)
            z = array_ops.placeholder(dtypes.float32, shape=[1], name='z')
            math_ops.add(x, y0)
            math_ops.subtract(y1, z)

        minuend_pattern = graph_matcher.OpTypePattern('*')
        sub_pattern = graph_matcher.OpTypePattern(
            'Sub', inputs=[minuend_pattern, '*'])
        sub_matcher = graph_matcher.GraphMatcher(sub_pattern)

        found = list(sub_matcher.match_graph(graph))
        self.assertEqual(1, len(found))
        only_match = found[0]

        # Both split outputs are produced by the same op, so get_op cannot
        # tell them apart; get_tensor must return the exact output consumed.
        self.assertEqual(y0.op, y1.op)
        self.assertEqual(only_match.get_op(minuend_pattern), y1.op)
        self.assertEqual(only_match.get_tensor(minuend_pattern), y1)
コード例 #3
0
def rewrite_nn_resize_op(is_quantized=False):
    """Replaces a custom nearest-neighbor resize op with the Tensorflow version.

    Some graphs use this custom version for TPU-compatibility. The subgraph
    searched for in the default graph is
    input -> Reshape -> Mul -> [FakeQuantWithMinMaxVars] -> Reshape -> Add,
    and the Add's input coming from the second Reshape is rewired to a new
    `resize_nearest_neighbor` op.

    Args:
      is_quantized: True if the default graph is quantized.
    """
    source_pattern = graph_matcher.OpTypePattern(
        'FakeQuantWithMinMaxVars' if is_quantized else '*')
    first_reshape_pattern = graph_matcher.OpTypePattern(
        'Reshape', inputs=[source_pattern, 'Const'], ordered_inputs=False)
    scale_pattern = graph_matcher.OpTypePattern(
        'Mul', inputs=[first_reshape_pattern, 'Const'], ordered_inputs=False)
    # The quantization script may or may not insert a fake quant op after the
    # Mul. In either case, these min/max vars are not needed once replaced with
    # the TF version of NN resize.
    quantized_scale_pattern = graph_matcher.OpTypePattern(
        'FakeQuantWithMinMaxVars',
        inputs=[scale_pattern, 'Identity', 'Identity'],
        ordered_inputs=False)
    scaled_pattern = graph_matcher.OneofPattern(
        [quantized_scale_pattern, scale_pattern])
    second_reshape_pattern = graph_matcher.OpTypePattern(
        'Reshape', inputs=[scaled_pattern, 'Const'], ordered_inputs=False)

    # The op type emitted for `+` depends on the forward-compatibility window.
    sum_op_type = (
        'AddV2' if tf.compat.forward_compatible(2019, 6, 26) else 'Add')
    sum_pattern = graph_matcher.OpTypePattern(
        sum_op_type, inputs=[second_reshape_pattern, '*'],
        ordered_inputs=False)

    sum_matcher = graph_matcher.GraphMatcher(sum_pattern)
    for found in sum_matcher.match_graph(tf.get_default_graph()):
        source_op = found.get_op(source_pattern)
        upsampled_op = found.get_op(second_reshape_pattern)
        sum_op = found.get_op(sum_pattern)

        # The target spatial size is read from the Add's static output shape.
        resize_name = (
            os.path.split(upsampled_op.name)[0] + '/resize_nearest_neighbor')
        resized = tf.image.resize_nearest_neighbor(
            source_op.outputs[0],
            sum_op.outputs[0].shape.dims[1:3],
            align_corners=False,
            name=resize_name)

        # Rewire only the first Add input that came from the second Reshape.
        for slot, tensor in enumerate(sum_op.inputs):
            if tensor == upsampled_op.outputs[0]:
                sum_op._update_input(slot, resized)  # pylint: disable=protected-access
                break
コード例 #4
0
    def test_ordered_pattern(self):
        """Input order only matters when ordered_inputs is True."""
        #   +            +
        #  / \          / \
        # x   y  and   y   x  should both match when ordered inputs is False.
        # Even when x and y are different operations.
        graph = ops.Graph()
        with graph.as_default():
            x = array_ops.placeholder(dtypes.float32, shape=[], name='x')
            y = constant_op.constant(1.0, dtype=dtypes.float32)
            plus = x + y

        def matched_ops(pattern):
            """Returns the ops bound to `pattern` for every match in graph."""
            matcher = graph_matcher.GraphMatcher(pattern)
            return [
                found.get_op(pattern) for found in matcher.match_graph(graph)
            ]

        const_first_unordered = graph_matcher.OpTypePattern(
            'Add', inputs=['Const', 'Placeholder'], ordered_inputs=False)
        placeholder_first_unordered = graph_matcher.OpTypePattern(
            'Add', inputs=['Placeholder', 'Const'], ordered_inputs=False)
        const_first_ordered = graph_matcher.OpTypePattern(
            'Add', inputs=['Const', 'Placeholder'], ordered_inputs=True)

        # With ordered_inputs=False either listing of the inputs matches.
        self.assertEqual(matched_ops(const_first_unordered), [plus.op])
        self.assertEqual(matched_ops(placeholder_first_unordered), [plus.op])

        # With ordered_inputs=True the inputs must be listed in graph order,
        # so the (Const, Placeholder) listing finds nothing.
        self.assertEqual(len(matched_ops(const_first_ordered)), 0)
コード例 #5
0
def _FindRestFilters(graph, find_unreplaced=True):
    """Collects Conv2D / DepthwiseConv2dNative / MatMul layers in `graph`.

    Args:
      graph: Graph to inspect.
      find_unreplaced: When True, layers whose output tensor has no consumers
        and layers whose op name ends with "_psb" are skipped.

    Returns:
      A list of dicts, one per newly recorded layer op, with the keys
      "layer_op", "output_tensor", "input_tensor" and "weight_tensor".
    """
    data_pattern = graph_matcher.OpTypePattern('*')
    kernel_pattern = graph_matcher.OpTypePattern('*')
    layer_pattern = graph_matcher.OpTypePattern(
        'Conv2D|DepthwiseConv2dNative|MatMul',
        inputs=[data_pattern, kernel_pattern])
    matcher = graph_matcher.GraphMatcher(layer_pattern)

    # NOTE(review): `matched_layer_set` is not defined inside this function;
    # it is presumably a set shared at module level so that layers are only
    # reported once across calls -- confirm against the rest of the file.
    collected = []
    for found in matcher.match_graph(graph):
        layer_op = found.get_op(layer_pattern)
        input_tensor = found.get_tensor(data_pattern)
        weight_tensor = found.get_tensor(kernel_pattern)
        output_tensor = layer_op.outputs[0]

        assert len(layer_op.outputs) == 1

        if find_unreplaced:
            # An output without consumers is a dangling node, not a match.
            is_dangling = not output_tensor.consumers()
            if is_dangling or layer_op.name.endswith("_psb"):
                continue

        if layer_op in matched_layer_set:
            continue
        matched_layer_set.add(layer_op)
        collected.append({
            "layer_op": layer_op,
            "output_tensor": output_tensor,
            "input_tensor": input_tensor,
            "weight_tensor": weight_tensor
        })

    return collected
コード例 #6
0
    def test_conv_layer(self):
        """Matches Conv2D -> FusedBatchNormV3 -> Relu built by a conv layer.

        The layer is built inside `g` and without an early `return`, so the
        pattern-matching assertions below actually execute.
        """
        with compat.forward_compatibility_horizon(2019, 11, 11):
            g = ops.Graph()
            with g.as_default():
                inputs = array_ops.placeholder(dtypes.float32,
                                               shape=[8, 5, 5, 3])

                # Build the layer inside `g` so that its variables and ops
                # land in the graph that is matched below.
                with contrib_ops.arg_scope([layers.batch_norm],
                                           fused=True,
                                           is_training=True,
                                           trainable=True):
                    layers.convolution(
                        inputs,
                        num_outputs=16,
                        kernel_size=3,
                        stride=1,
                        padding='VALID',
                        activation_fn=nn_ops.relu,
                        normalizer_fn=layers.batch_norm,
                        normalizer_params={},
                        weights_initializer=initializers.xavier_initializer(),
                        weights_regularizer=None,
                        biases_initializer=init_ops.zeros_initializer(),
                        biases_regularizer=None,
                        reuse=None,
                        trainable=True,
                        scope=None)

            inputs_pattern = graph_matcher.OpTypePattern('*', name='inputs')
            relu_pattern = graph_matcher.OpTypePattern(
                'Relu',
                name='relu',
                inputs=[
                    graph_matcher.OpTypePattern(
                        'FusedBatchNormV3',
                        inputs=[
                            graph_matcher.OpTypePattern(
                                'Conv2D', inputs=[inputs_pattern, '*']), '*',
                            '*', '*', '*'
                        ])
                ])
            matcher = graph_matcher.GraphMatcher(relu_pattern)
            match_results = list(matcher.match_graph(g))
            self.assertEqual(1, len(match_results))
            match_result = match_results[0]
            # The input placeholder is reachable both by pattern object and
            # by the pattern's name.
            self.assertEqual(match_result.get_tensor(inputs_pattern), inputs)
            self.assertEqual(match_result.get_tensor('inputs'), inputs)
コード例 #7
0
    def test_oneof_pattern(self):
        """An 'Add|Sub' op-type pattern matches both the add and the sub."""
        #   -   +
        #  / \ / \
        # x   y   z
        graph = ops.Graph()
        with graph.as_default():
            x = array_ops.placeholder(dtypes.float32, shape=[], name='x')
            y = array_ops.placeholder(dtypes.float32, shape=[], name='y')
            z = array_ops.placeholder(dtypes.float32, shape=[], name='z')
            plus = x + y
            minus = y - z

        binary_pattern = graph_matcher.OpTypePattern(
            'Add|Sub', inputs=['*', '*'])
        binary_matcher = graph_matcher.GraphMatcher(binary_pattern)

        matched_ops = [
            found.get_op(binary_pattern)
            for found in binary_matcher.match_graph(graph)
        ]
        # Matches are expected in op-creation order: the add, then the sub.
        self.assertEqual(matched_ops, [plus.op, minus.op])
コード例 #8
0
ファイル: quantization_utils.py プロジェクト: WaugZ/my_dgcnn
def _FindLayersToQuantize(graph):
  """Matches layers in graph to quantize.

  Single ops of the listed types are matched one type group at a time. An op
  is skipped when its name contains one of the excluded fragments or when its
  'T' attribute is an integer dtype. Each remaining op is reported once.

  Args:
    graph: Graph to perform match on.

  Returns:
    A list of `_EasyLayerMatch` objects, one per distinct matched op, ordered
    by type group (Add, Mul, Sub, BatchMatMul, Sum, Concat/Tile).
  """
  # Each type group is listed exactly once, so the graph is scanned once per
  # group. 'TopK|TopKV2' and 'Max' are currently not matched.
  op_type_groups = (
      'Add|AddV2',
      'Mul',
      'Sub',
      'BatchMatMul|BatchMatMulV2',
      'Sum',
      'Concat|ConcatV2|Tile',
  )
  # Ops whose name contains any of these fragments are never reported.
  excluded_name_fragments = ('quant', 'BatchNorm', 'weights', 'dropout',
                             'fold', 'kernel', 'correction', 'post_conv')
  # Ops computing in these integer dtypes are never reported.
  excluded_dtypes = (tf.int32, tf.uint8, tf.int16)

  layer_matches = []
  # We use matched_layer_set to ensure that layers aren't matched multiple
  # times.
  matched_layer_set = set()

  for op_types in op_type_groups:
    op_pattern = graph_matcher.OpTypePattern(op_types)
    op_matcher = graph_matcher.GraphMatcher(op_pattern)
    for match_result in op_matcher.match_graph(graph):
      layer_op = match_result.get_op(op_pattern)
      if any(fragment in layer_op.name
             for fragment in excluded_name_fragments):
        continue
      op_dtype = layer_op.get_attr('T')
      if any(int_type == op_dtype for int_type in excluded_dtypes):
        continue

      if layer_op not in matched_layer_set:
        matched_layer_set.add(layer_op)
        layer_matches.append(_EasyLayerMatch(layer_op))

  return layer_matches
    def replace_matches(consumer_pattern):
      """Rewires every match of `consumer_pattern` to a TF NN-resize op.

      Relies on `input_pattern` and `reshape_pattern` from the enclosing
      scope being sub-patterns of `consumer_pattern`.

      Args:
        consumer_pattern: Pattern for the op that consumes the Reshape output.

      Returns:
        The number of matches found in the default graph.
      """
      num_matches = 0
      consumer_matcher = graph_matcher.GraphMatcher(consumer_pattern)
      for found in consumer_matcher.match_graph(tf.get_default_graph()):
        num_matches += 1
        source_op = found.get_op(input_pattern)
        upsampled_op = found.get_op(reshape_pattern)
        sink_op = found.get_op(consumer_pattern)

        # The replacement takes its spatial size from the Reshape's static
        # shape and is named inside the Reshape's name scope.
        resize_name = (
            os.path.split(upsampled_op.name)[0] + '/resize_nearest_neighbor')
        resized = tf.image.resize_nearest_neighbor(
            source_op.outputs[0],
            upsampled_op.outputs[0].shape.dims[1:3],
            align_corners=False,
            name=resize_name)

        # Rewire only the first consumer input that came from the Reshape.
        for slot, tensor in enumerate(sink_op.inputs):
          if tensor == upsampled_op.outputs[0]:
            sink_op._update_input(slot, resized)  # pylint: disable=protected-access
            break

      return num_matches
コード例 #10
0
    def test_oneof_pattern(self):
        """OneofPattern matches a Transpose with or without a Slice before it."""
        reshape_pattern = graph_matcher.OpTypePattern('Reshape')
        sliced_transpose_pattern = graph_matcher.OpTypePattern(
            'Transpose',
            name='transpose',
            inputs=[
                graph_matcher.OpTypePattern(
                    'Slice', name='slice',
                    inputs=[reshape_pattern, '*', '*']), '*'
            ])
        direct_transpose_pattern = graph_matcher.OpTypePattern(
            'Transpose', name='transpose', inputs=[reshape_pattern, '*'])
        either_pattern = graph_matcher.OneofPattern(
            [sliced_transpose_pattern, direct_transpose_pattern])

        matcher = graph_matcher.GraphMatcher(either_pattern)

        # Graph without a Slice: the 'slice' name resolves to None.
        graph = ops.Graph()
        with graph.as_default():
            inputs = array_ops.placeholder(dtypes.float32, shape=[6])
            reshaped = array_ops.reshape(inputs, [2, 3])
            transposed = array_ops.transpose(reshaped)
            (found,) = matcher.match_graph(graph)
            self.assertEqual(found.get_tensor(reshape_pattern), reshaped)
            self.assertEqual(found.get_tensor('slice'), None)
            self.assertEqual(found.get_op('transpose'), transposed.op)

        # Graph with a Slice between Reshape and Transpose.
        graph = ops.Graph()
        with graph.as_default():
            inputs = array_ops.placeholder(dtypes.float32, shape=[6])
            reshaped = array_ops.reshape(inputs, [2, 3])
            sliced = array_ops.slice(reshaped, [0, 0], [-1, -1])
            transposed = array_ops.transpose(sliced)
            (found,) = matcher.match_graph(graph)
            self.assertEqual(found.get_tensor(reshape_pattern), reshaped)
            self.assertEqual(found.get_tensor('slice'), sliced)
            self.assertEqual(found.get_op('transpose'), transposed.op)
コード例 #11
0
def _FindFusedBatchNorms(graph):
    """Finds all ops and tensors related to found FusedBatchNorms.

    Two matchers are run over `graph`: one rooted at an Identity following
    the batch norm and one rooted at the batch norm (or the Reshape after it).
    Each layer op is reported at most once.

    Note: for batch norms with is_training set, this function modifies the
    graph (it sets the `_gradient_op_type` attr on the batch norm op and adds
    ops that undo Bessel's correction).

    Args:
      graph: Graph to inspect.

    Returns:
      A list of _BatchNormMatch objects, one per distinct layer op.
    """
    input_pattern = graph_matcher.OpTypePattern('*')
    # In practice, the weight pattern can match a Variable or a SpaceToBatchND
    # operation that follows a variable for atrous convolutions.
    weight_pattern = graph_matcher.OpTypePattern('*')
    gamma_pattern = graph_matcher.OpTypePattern('*')
    beta_pattern = graph_matcher.OpTypePattern('*')
    mean_pattern = graph_matcher.OpTypePattern('*')
    variance_pattern = graph_matcher.OpTypePattern('*')

    moving_average_pattern = graph_matcher.OpTypePattern('*')
    bn_decay_pattern = graph_matcher.OpTypePattern('*')
    layer_pattern = graph_matcher.OpTypePattern(
        'Conv2D|DepthwiseConv2dNative|MatMul',
        inputs=[input_pattern, weight_pattern])
    batch_to_space_pattern = graph_matcher.OpTypePattern(
        'BatchToSpaceND',
        inputs=[
            layer_pattern,
            graph_matcher.OpTypePattern('*'),
            graph_matcher.OpTypePattern('*')
        ])
    # Identity between conv/matmul and bn
    layer_pattern_with_identity = graph_matcher.OpTypePattern(
        'Identity',
        inputs=[
            graph_matcher.OneofPattern([batch_to_space_pattern, layer_pattern])
        ])
    layer_output_pattern = graph_matcher.OneofPattern(
        [layer_pattern_with_identity, layer_pattern, batch_to_space_pattern])

    # MatMul has a Reshape between it and FusedBatchNorm.
    matmul_reshape_pattern = graph_matcher.OpTypePattern(
        'Reshape',
        inputs=[layer_output_pattern,
                graph_matcher.OpTypePattern('*')])

    batch_norm_pattern = graph_matcher.OpTypePattern(
        'FusedBatchNorm',
        inputs=[
            graph_matcher.OneofPattern(
                [matmul_reshape_pattern, layer_output_pattern]), gamma_pattern,
            beta_pattern, mean_pattern, variance_pattern
        ])
    matmul_bn_output_reshape_pattern = graph_matcher.OpTypePattern(
        'Reshape',
        inputs=[batch_norm_pattern,
                graph_matcher.OpTypePattern('*')])

    # NOTE(review): this Identity pattern lists two inputs (the batch norm and
    # the Reshape after it) rather than a OneofPattern of the two -- confirm
    # that it can match a single-input Identity op as intended.
    batch_norm_identity_pattern = graph_matcher.OpTypePattern(
        'Identity',
        inputs=[batch_norm_pattern, matmul_bn_output_reshape_pattern])

    bn_identity_matcher = graph_matcher.GraphMatcher(
        batch_norm_identity_pattern)

    bn_matcher = graph_matcher.GraphMatcher(
        graph_matcher.OneofPattern(
            [matmul_bn_output_reshape_pattern, batch_norm_pattern]))

    # Pattern for the moving-average update: (moving_average - batch_stat)
    # multiplied by the decay term.
    moving_average_sub_pattern = graph_matcher.OpTypePattern(
        'Sub', inputs=[moving_average_pattern, batch_norm_pattern])
    moving_average_mul_pattern = graph_matcher.OpTypePattern(
        'Mul', inputs=[moving_average_sub_pattern, bn_decay_pattern])

    moving_avg_mul_matcher = graph_matcher.GraphMatcher(
        moving_average_mul_pattern)

    def _GetLayerMatch(match_result):
        """Populates a layer match object containing ops/tensors for folding BNs.

        Args:
          match_result: Matched result from graph matcher

        Returns:
          layer_op: Matching conv/fc op prior to batch norm
          BatchNormMatch: _BatchNormMatch containing all required batch norm
          parameters.
          Both values are None when the match should be skipped (a MatMul
          without the output Reshape, or an output tensor without consumers).
        """
        moving_mean_tensor = None
        moving_variance_tensor = None
        bn_decay_mean_tensor = None
        bn_decay_var_tensor = None
        batch_to_space_op = None
        layer_op = match_result.get_op(layer_pattern)
        layer_tensor = match_result.get_tensor(layer_pattern)
        bn_id_op = match_result.get_op(batch_norm_identity_pattern)
        bn_op = match_result.get_op(batch_norm_pattern)
        # NOTE(review): bn_id_op is assigned here but not read again in this
        # function.
        if bn_id_op is None:
            bn_id_op = bn_op

        batch_epsilon = bn_op.get_attr('epsilon')

        # In the MatMul case, the output of batch norm is reshaped back into a
        # 2D tensor, so the output_tensor is the output of the Reshape op.
        output_tensor = bn_op.outputs[0]
        if layer_op.type == 'MatMul':
            output_reshape_op = match_result.get_op(
                matmul_bn_output_reshape_pattern)
            # If the matcher didn't match matmul_bn_output_reshape, there will be
            # another match for this 'MatMul' later, so we can skip this one.
            if output_reshape_op is None:
                return None, None
            output_tensor = output_reshape_op.outputs[0]

        # Ensure that the output tensor has consumers, otherwise this is a dangling
        # node and not a match.
        if not output_tensor.consumers():
            return None, None

        batch_to_space_op = match_result.get_op(batch_to_space_pattern)
        input_tensor = match_result.get_tensor(input_pattern)
        weight_tensor = match_result.get_tensor(weight_pattern)
        gamma_tensor = match_result.get_tensor(gamma_pattern)
        beta_tensor = match_result.get_tensor(beta_pattern)
        # FusedBatchNorm in training is different from that in inference. It takes
        # empty 'mean' and empty 'variance', and produces the mean and the variance
        # of the batch. Therefore, when is_training is true, mean_tensor and
        # variance_tensor point to 1st and 2nd (0-based) output of bn_op,
        # respectively; when is_training is false, they point to bn_op's inputs.
        is_training = bn_op.get_attr('is_training')
        if is_training:
            # FusedBatchNormGrad doesn't compute gradients of the batch_mean and
            # batch_variance outputs, so we need to substitute our own custom
            # gradient.
            # TODO(suharshs, raghuramank): Find a way to avoid needing this hack.
            # pylint: disable=protected-access
            bn_op._set_attr(
                '_gradient_op_type',
                attr_value_pb2.AttrValue(
                    s=compat.as_bytes('FoldFusedBatchNormGrad')))
            # pylint: enable=protected-access
            mean_tensor = bn_op.outputs[1]
            # The batch variance used during forward and backward prop is biased,
            # i.e it is calculated as: V=sum(x(k)-mu)^2/N. For the moving average
            # calculation, the variance is corrected by the term N/N-1 (Bessel's
            # correction). The variance tensor read from FuseBatchNorm has Bessel's
            # correction applied, so we undo it here.
            scope, sep, _ = bn_op.name.rpartition('/')
            # NOTE(review): the new ops are created in the default graph, not
            # in the `graph` argument -- presumably callers pass the default
            # graph; confirm.
            g = ops.get_default_graph()
            with g.as_default(), g.name_scope(scope + sep):
                n = math_ops.cast(
                    array_ops.size(layer_tensor) / array_ops.size(mean_tensor),
                    dtypes.float32)
                variance_tensor = math_ops.multiply(
                    bn_op.outputs[2], (n - 1) / n,
                    name='Undo_Bessel_Correction')
            # TODO(suharshs): Find a way to get rid of this inner match.
            for mul_match_result in moving_avg_mul_matcher.match_graph(graph):
                sub_op = mul_match_result.get_op(moving_average_sub_pattern)
                if sub_op.inputs[1].name == bn_op.outputs[1].name:
                    # During training: Batch Mean is bn_op.outputs[1]
                    moving_mean_tensor = sub_op.inputs[0]
                    bn_decay_mean_tensor = mul_match_result.get_tensor(
                        bn_decay_pattern)
                if sub_op.inputs[1].name == bn_op.outputs[2].name:
                    # During training: Batch Var is bn_op.outputs[2]
                    moving_variance_tensor = sub_op.inputs[0]
                    bn_decay_var_tensor = mul_match_result.get_tensor(
                        bn_decay_pattern)
        else:
            mean_tensor = match_result.get_tensor(mean_pattern)
            variance_tensor = match_result.get_tensor(variance_pattern)

        return layer_op, _BatchNormMatch(
            layer_op=layer_op,
            bn_op=bn_op,
            output_tensor=output_tensor,
            input_tensor=input_tensor,
            weight_tensor=weight_tensor,
            gamma_tensor=gamma_tensor,
            beta_tensor=beta_tensor,
            mean_tensor=mean_tensor,
            variance_tensor=variance_tensor,
            moving_mean_tensor=moving_mean_tensor,
            moving_variance_tensor=moving_variance_tensor,
            bn_decay_mean_tensor=bn_decay_mean_tensor,
            bn_decay_var_tensor=bn_decay_var_tensor,
            batch_epsilon=batch_epsilon,
            batch_to_space_op=batch_to_space_op)

    layer_matches = []
    # We use matched_layer_set to ensure that layers aren't matched multiple
    # times.
    matched_layer_set = set()
    # First pass: matches rooted at the Identity that follows the batch norm.
    for match_result in bn_identity_matcher.match_graph(graph):
        layer_op, layer_match = _GetLayerMatch(match_result)
        if layer_op is not None:
            if layer_op not in matched_layer_set:
                matched_layer_set.add(layer_op)
                layer_matches.append(layer_match)

    # Second pass: matches rooted at the batch norm or its output Reshape;
    # layers already recorded by the first pass are not added again.
    for match_result in bn_matcher.match_graph(graph):
        layer_op, layer_match = _GetLayerMatch(match_result)
        if layer_op is not None:
            if layer_op not in matched_layer_set:
                matched_layer_set.add(layer_op)
                layer_matches.append(layer_match)

    return layer_matches
コード例 #12
0
def _FindFusedBatchNorms(graph):
  """Finds all ops and tensors related to found FusedBatchNorms.

  Conv2D/DepthwiseConv2dNative layers feeding a FusedBatchNorm are yielded
  first, followed by MatMul layers, which have a Reshape on either side of
  the FusedBatchNorm.

  Args:
    graph: Graph to inspect.

  Yields:
    _FusedBatchNormMatches.
  """
  input_pattern = graph_matcher.OpTypePattern('*')
  weight_pattern = graph_matcher.OpTypePattern('*')
  gamma_pattern = graph_matcher.OpTypePattern('*')
  beta_pattern = graph_matcher.OpTypePattern('*')
  mean_pattern = graph_matcher.OpTypePattern('*')
  variance_pattern = graph_matcher.OpTypePattern('*')

  conv_pattern = graph_matcher.OpTypePattern(
      'Conv2D|DepthwiseConv2dNative', inputs=[input_pattern, weight_pattern])
  # MatMul has a Reshape between it and FusedBatchNorm.
  matmul_pattern = graph_matcher.OpTypePattern(
      'MatMul', inputs=[input_pattern, weight_pattern])
  matmul_reshape_pattern = graph_matcher.OpTypePattern(
      'Reshape', inputs=[matmul_pattern,
                         graph_matcher.OpTypePattern('*')])

  conv_batch_norm_pattern = graph_matcher.OpTypePattern(
      'FusedBatchNorm',
      inputs=[
          conv_pattern, gamma_pattern, beta_pattern, mean_pattern,
          variance_pattern
      ])
  matmul_batch_norm_pattern = graph_matcher.OpTypePattern(
      'FusedBatchNorm',
      inputs=[
          matmul_reshape_pattern, gamma_pattern, beta_pattern, mean_pattern,
          variance_pattern
      ])
  matmul_bn_output_reshape_pattern = graph_matcher.OpTypePattern(
      'Reshape',
      inputs=[matmul_batch_norm_pattern,
              graph_matcher.OpTypePattern('*')])

  conv_matcher = graph_matcher.GraphMatcher(conv_batch_norm_pattern)
  matmul_matcher = graph_matcher.GraphMatcher(matmul_bn_output_reshape_pattern)

  def _GetCommonTensors(match_result, bn_op):
    """Gets tensors needed for FusedBatchNormMatch from match_result.

    Args:
      match_result: Matched result from graph matcher.
      bn_op: The FusedBatchNorm op of this match. It is passed explicitly so
        the helper does not depend on a variable of the enclosing loops.

    Returns:
      Tuple (input_tensor, weight_tensor, gamma_tensor, beta_tensor,
      mean_tensor, variance_tensor).
    """
    input_tensor = match_result.get_tensor(input_pattern)
    weight_tensor = match_result.get_tensor(weight_pattern)
    gamma_tensor = match_result.get_tensor(gamma_pattern)
    beta_tensor = match_result.get_tensor(beta_pattern)
    # FusedBatchNorm in training is different from that in inference. It takes
    # empty 'mean' and empty 'variance', and produces the mean and the variance
    # of the batch. Therefore, when is_training is true, mean_tensor and
    # variance_tensor point to 1st and 2nd (0-based) output of bn_op,
    # respectively; when is_training is false, they point to bn_op's inputs.
    if bn_op.get_attr('is_training'):
      mean_tensor = bn_op.outputs[1]
      variance_tensor = bn_op.outputs[2]
    else:
      mean_tensor = match_result.get_tensor(mean_pattern)
      variance_tensor = match_result.get_tensor(variance_pattern)
    return (input_tensor, weight_tensor, gamma_tensor, beta_tensor, mean_tensor,
            variance_tensor)

  def _BuildMatch(match_result, layer_op, bn_op, output_tensor):
    """Assembles the _FusedBatchNormMatch for one matched layer."""
    (input_tensor, weight_tensor, gamma_tensor, beta_tensor, mean_tensor,
     variance_tensor) = _GetCommonTensors(match_result, bn_op)
    return _FusedBatchNormMatch(
        layer_op=layer_op,
        bn_op=bn_op,
        output_tensor=output_tensor,
        input_tensor=input_tensor,
        weight_tensor=weight_tensor,
        gamma_tensor=gamma_tensor,
        beta_tensor=beta_tensor,
        mean_tensor=mean_tensor,
        variance_tensor=variance_tensor)

  for match_result in conv_matcher.match_graph(graph):
    layer_op = match_result.get_op(conv_pattern)
    bn_op = match_result.get_op(conv_batch_norm_pattern)
    # In the case of convolution the output_tensor is the output of bn_op.
    yield _BuildMatch(match_result, layer_op, bn_op, bn_op.outputs[0])

  for match_result in matmul_matcher.match_graph(graph):
    layer_op = match_result.get_op(matmul_pattern)
    bn_op = match_result.get_op(matmul_batch_norm_pattern)
    # In the MatMul case, the output of batch norm is reshaped back into a
    # 2D tensor, so the output_tensor is the output of the Reshape op.
    output_reshape_op = match_result.get_op(matmul_bn_output_reshape_pattern)
    yield _BuildMatch(match_result, layer_op, bn_op,
                      output_reshape_op.outputs[0])
コード例 #13
0
ファイル: quantize.py プロジェクト: zmsxl625/tensorflow
def _FindLayersToQuantize(graph):
    """Matches layers in graph to quantize.

    Layers followed by an activation are yielded first. Then every layer
    directly followed by an Add/BiasAdd is yielded with that bias add treated
    as its activation.

    Args:
      graph: Graph to perform match on.

    Yields:
      _LayerMatches.
    """
    input_pattern = graph_matcher.OpTypePattern('*')
    weight_var_pattern = graph_matcher.OpTypePattern('|'.join(_WEIGHT_TYPES))
    weight_pattern = graph_matcher.OpTypePattern('Identity|ReadVariableOp',
                                                 inputs=[weight_var_pattern])

    folded_weight_pattern = graph_matcher.OpTypePattern('Mul')

    # The weights inputs to the layer operation can either be from the Variable or
    # the folded weight (Mul).
    layer_pattern = graph_matcher.OpTypePattern(
        '|'.join(_QUANTIZABLE_TYPES),
        inputs=[
            input_pattern,
            graph_matcher.OneofPattern([weight_pattern, folded_weight_pattern])
        ])

    folded_bias_mul_pattern = graph_matcher.OpTypePattern(
        'Mul', inputs=[graph_matcher.OpTypePattern('*'), layer_pattern])
    post_layer_op_correction_pattern = graph_matcher.OpTypePattern(
        'Add',
        inputs=[folded_bias_mul_pattern,
                graph_matcher.OpTypePattern('*')])
    folded_bias_add_pattern = graph_matcher.OpTypePattern(
        'Add',
        inputs=[
            post_layer_op_correction_pattern,
            graph_matcher.OpTypePattern('*')
        ])

    bias_add_pattern = graph_matcher.OpTypePattern('Add|BiasAdd',
                                                   inputs=[layer_pattern, '*'])

    # The bias can come from the bias add or the folded bias add.
    bypass_pattern_a = graph_matcher.OpTypePattern(
        'Add',
        inputs=[
            graph_matcher.OneofPattern(
                [bias_add_pattern, folded_bias_add_pattern]), '*'
        ])
    bypass_pattern_b = graph_matcher.OpTypePattern(
        'Add',
        inputs=[
            '*',
            graph_matcher.OneofPattern(
                [bias_add_pattern, folded_bias_add_pattern])
        ])

    # The input to the activation can come from bias add, fold bias add or the
    # bypasses.
    activation_pattern = graph_matcher.OpTypePattern(
        '|'.join(_ACTIVATION_TYPES),
        inputs=[
            graph_matcher.OneofPattern([
                bias_add_pattern, folded_bias_add_pattern, bypass_pattern_a,
                bypass_pattern_b
            ])
        ])

    def _GetWeightTensor(match_result):
        """Returns the layer's weight tensor, plain or folded (Mul)."""
        weight_tensor = match_result.get_tensor(weight_pattern)
        if weight_tensor is None:
            weight_tensor = match_result.get_tensor(folded_weight_pattern)
        return weight_tensor

    layer_matcher = graph_matcher.GraphMatcher(activation_pattern)
    for match_result in layer_matcher.match_graph(graph):
        layer_op = match_result.get_op(layer_pattern)
        weight_tensor = _GetWeightTensor(match_result)
        activation_op = match_result.get_op(activation_pattern)
        # Only one alternative of each OneofPattern is bound per match, so
        # fall back to the other alternative when the first lookup is None.
        bias_add_op = match_result.get_op(bias_add_pattern)
        if bias_add_op is None:
            bias_add_op = match_result.get_op(folded_bias_add_pattern)
        bypass_op = match_result.get_op(bypass_pattern_a)
        if bypass_op is None:
            bypass_op = match_result.get_op(bypass_pattern_b)
        yield _LayerMatch(layer_op, weight_tensor, activation_op, bypass_op,
                          bias_add_op)

    # Match the final layer, where there will not be an activation and instead
    # the output of the final BiasAdd must be quantized, so we treat it as the
    # 'activation_op' in the _LayerMatch.
    # TODO(suharshs): Figure out how to quantize this final layer across many
    # models.
    final_layer_matcher = graph_matcher.GraphMatcher(bias_add_pattern)
    for match_result in final_layer_matcher.match_graph(graph):
        layer_op = match_result.get_op(layer_pattern)
        # layer_pattern also accepts folded (Mul) weights, so the same
        # fallback as in the activation loop is required here; otherwise the
        # weight tensor of a folded layer would be reported as None.
        weight_tensor = _GetWeightTensor(match_result)
        activation_op = match_result.get_op(bias_add_pattern)
        yield _LayerMatch(layer_op, weight_tensor, activation_op, None, None)
コード例 #14
0
def _FindFusedBatchNorms(graph):
  """Finds all ops and tensors related to found FusedBatchNorms.

  Two layouts are matched: a Conv2D or DepthwiseConv2dNative op feeding a
  FusedBatchNorm directly, and a MatMul whose output goes through a Reshape
  before the FusedBatchNorm and through another Reshape after it. All
  convolution matches are yielded before any MatMul match.

  Note that for a FusedBatchNorm with is_training set, producing the match
  also overrides the op's '_gradient_op_type' attr and creates the ops that
  undo Bessel's correction on the batch variance.

  Args:
    graph: Graph to inspect.

  Yields:
    A _BatchNormMatch for every matched layer. The moving average and decay
    fields are None for batch norms that are not in training mode, and for
    training-mode batch norms whose moving average update could not be located
    in the graph.
  """
  input_pattern = graph_matcher.OpTypePattern('*')
  weight_pattern = graph_matcher.OpTypePattern('*')
  gamma_pattern = graph_matcher.OpTypePattern('*')
  beta_pattern = graph_matcher.OpTypePattern('*')
  mean_pattern = graph_matcher.OpTypePattern('*')
  variance_pattern = graph_matcher.OpTypePattern('*')

  moving_average_pattern = graph_matcher.OpTypePattern('*')
  bn_decay_pattern = graph_matcher.OpTypePattern('*')
  conv_pattern = graph_matcher.OpTypePattern(
      'Conv2D|DepthwiseConv2dNative', inputs=[input_pattern, weight_pattern])
  # MatMul has a Reshape between it and FusedBatchNorm.
  matmul_pattern = graph_matcher.OpTypePattern(
      'MatMul', inputs=[input_pattern, weight_pattern])
  matmul_reshape_pattern = graph_matcher.OpTypePattern(
      'Reshape', inputs=[matmul_pattern,
                         graph_matcher.OpTypePattern('*')])

  conv_batch_norm_pattern = graph_matcher.OpTypePattern(
      'FusedBatchNorm',
      inputs=[
          conv_pattern, gamma_pattern, beta_pattern, mean_pattern,
          variance_pattern
      ])
  conv_moving_average_sub_pattern = graph_matcher.OpTypePattern(
      'Sub', inputs=[moving_average_pattern, conv_batch_norm_pattern])
  # TODO(suharshs): Use a OneofPattern here when available
  conv_moving_average_mul_pattern = graph_matcher.OpTypePattern(
      'Mul', inputs=[conv_moving_average_sub_pattern, bn_decay_pattern])
  matmul_batch_norm_pattern = graph_matcher.OpTypePattern(
      'FusedBatchNorm',
      inputs=[
          matmul_reshape_pattern, gamma_pattern, beta_pattern, mean_pattern,
          variance_pattern
      ])
  matmul_bn_output_reshape_pattern = graph_matcher.OpTypePattern(
      'Reshape',
      inputs=[matmul_batch_norm_pattern,
              graph_matcher.OpTypePattern('*')])

  matmul_moving_average_sub_pattern = graph_matcher.OpTypePattern(
      'Sub', inputs=[moving_average_pattern, matmul_batch_norm_pattern])
  matmul_moving_average_mul_pattern = graph_matcher.OpTypePattern(
      'Mul', inputs=[matmul_moving_average_sub_pattern, bn_decay_pattern])

  conv_matcher = graph_matcher.GraphMatcher(conv_batch_norm_pattern)
  matmul_matcher = graph_matcher.GraphMatcher(matmul_bn_output_reshape_pattern)
  conv_moving_average_mul_matcher = graph_matcher.GraphMatcher(
      conv_moving_average_mul_pattern)
  matmul_moving_average_mul_matcher = graph_matcher.GraphMatcher(
      matmul_moving_average_mul_pattern)

  def _GetMovingAverageTensors(graph, moving_avg_mul_matcher,
                               moving_avg_sub_pattern, bn_op):
    """Gets the moving mean and variance tensors and the batch norm momentum.

    Args:
      graph: Graph to search for the moving average update ops.
      moving_avg_mul_matcher: GraphMatcher for the Mul(Sub(...), decay) update.
      moving_avg_sub_pattern: Pattern of the Sub op inside that update.
      bn_op: The FusedBatchNorm op whose updates are wanted.

    Returns:
      Tuple (moving_mean_tensor, bn_decay_mean_tensor, moving_variance_tensor,
      bn_decay_var_tensor). An entry is None when no update consuming the
      corresponding output of bn_op is found.
    """
    # Start from None so that a graph lacking one or both moving average
    # updates produces None entries instead of raising UnboundLocalError at
    # the return statement.
    moving_mean_tensor = None
    bn_decay_mean_tensor = None
    moving_variance_tensor = None
    bn_decay_var_tensor = None

    batch_mean_name = bn_op.outputs[1].name
    batch_var_name = bn_op.outputs[2].name
    for mul_match_result in moving_avg_mul_matcher.match_graph(graph):
      sub_op = mul_match_result.get_op(moving_avg_sub_pattern)
      subtrahend_name = sub_op.inputs[1].name

      if subtrahend_name == batch_mean_name:
        # During training: Batch Mean is bn_op.outputs[1]
        moving_mean_tensor = sub_op.inputs[0]
        bn_decay_mean_tensor = mul_match_result.get_tensor(bn_decay_pattern)
      elif subtrahend_name == batch_var_name:
        # During training: Batch Var is bn_op.outputs[2]
        moving_variance_tensor = sub_op.inputs[0]
        bn_decay_var_tensor = mul_match_result.get_tensor(bn_decay_pattern)
    return (moving_mean_tensor, bn_decay_mean_tensor, moving_variance_tensor,
            bn_decay_var_tensor)

  def _GetCommonTensors(match_result, bn_op, bn_input_tensor):
    """Gets tensors needed for FusedBatchNormMatch from match_result.

    Args:
      match_result: Matcher result holding the layer and batch norm patterns.
      bn_op: The matched FusedBatchNorm op.
      bn_input_tensor: Output tensor of the matched Conv2D/MatMul layer op;
        its size is used to compute the per-channel sample count.

    Returns:
      Tuple (input_tensor, weight_tensor, gamma_tensor, beta_tensor,
      mean_tensor, variance_tensor).
    """
    input_tensor = match_result.get_tensor(input_pattern)
    weight_tensor = match_result.get_tensor(weight_pattern)
    gamma_tensor = match_result.get_tensor(gamma_pattern)
    beta_tensor = match_result.get_tensor(beta_pattern)
    # FusedBatchNorm in training is different from that in inference. It takes
    # empty 'mean' and empty 'variance', and produces the mean and the variance
    # of the batch. Therefore, when is_training is true, mean_tensor and
    # variance_tensor point to 1st and 2nd (0-based) output of bn_op,
    # respectively; when is_training is false, they point to bn_op's inputs.
    is_training = bn_op.get_attr('is_training')
    if is_training:
      # FusedBatchNormGrad doesn't compute gradients of the batch_mean and
      # batch_variance outputs, so we need to substitute our own custom
      # gradient.
      # TODO(suharshs, raghuramank): Find a way to avoid needing this hack.
      # pylint: disable=protected-access
      bn_op._set_attr(
          '_gradient_op_type',
          attr_value_pb2.AttrValue(s=compat.as_bytes('FoldFusedBatchNormGrad')))
      # pylint: enable=protected-access
      mean_tensor = bn_op.outputs[1]
      # The batch variance used during forward and backward prop is biased,
      # i.e. it is calculated as: V=sum(x(k)-mu)^2/N. For the moving average
      # calculation, the variance is corrected by the term N/(N-1) (Bessel's
      # correction). The variance tensor read from FusedBatchNorm has Bessel's
      # correction applied, so we undo it here.
      scope, sep, _ = bn_op.name.rpartition('/')
      g = ops.get_default_graph()
      with g.as_default(), g.name_scope(scope + sep):
        n = math_ops.cast(
            array_ops.size(bn_input_tensor) / array_ops.size(mean_tensor),
            dtypes.float32)
        variance_tensor = math_ops.multiply(
            bn_op.outputs[2], (n - 1) / n, name='Undo_Bessel_Correction')
    else:
      mean_tensor = match_result.get_tensor(mean_pattern)
      variance_tensor = match_result.get_tensor(variance_pattern)
    return (input_tensor, weight_tensor, gamma_tensor, beta_tensor, mean_tensor,
            variance_tensor)

  def _BuildBatchNormMatch(match_result, layer_pattern, bn_pattern,
                           output_pattern, moving_avg_mul_matcher,
                           moving_avg_sub_pattern):
    """Assembles the _BatchNormMatch for one matcher result.

    Args:
      match_result: Result produced by the conv or MatMul matcher.
      layer_pattern: Pattern of the Conv2D/DepthwiseConv2dNative/MatMul op.
      bn_pattern: Pattern of the FusedBatchNorm op.
      output_pattern: Pattern of the op whose first output is the layer output.
      moving_avg_mul_matcher: Matcher used to locate moving average updates.
      moving_avg_sub_pattern: Pattern of the Sub op inside those updates.

    Returns:
      The populated _BatchNormMatch.
    """
    layer_op = match_result.get_op(layer_pattern)
    layer_tensor = match_result.get_tensor(layer_pattern)
    bn_op = match_result.get_op(bn_pattern)

    moving_mean_tensor = None
    moving_variance_tensor = None
    bn_decay_mean_tensor = None
    bn_decay_var_tensor = None
    # Moving averages are only looked up for batch norms in training mode.
    if bn_op.get_attr('is_training'):
      (moving_mean_tensor, bn_decay_mean_tensor, moving_variance_tensor,
       bn_decay_var_tensor) = _GetMovingAverageTensors(
           graph,
           moving_avg_mul_matcher=moving_avg_mul_matcher,
           moving_avg_sub_pattern=moving_avg_sub_pattern,
           bn_op=bn_op)

    output_tensor = match_result.get_op(output_pattern).outputs[0]
    batch_epsilon_tensor = bn_op.get_attr('epsilon')
    (input_tensor, weight_tensor, gamma_tensor, beta_tensor, mean_tensor,
     variance_tensor) = _GetCommonTensors(match_result, bn_op, layer_tensor)
    return _BatchNormMatch(
        layer_op=layer_op,
        bn_op=bn_op,
        output_tensor=output_tensor,
        input_tensor=input_tensor,
        weight_tensor=weight_tensor,
        gamma_tensor=gamma_tensor,
        beta_tensor=beta_tensor,
        mean_tensor=mean_tensor,
        variance_tensor=variance_tensor,
        moving_mean_tensor=moving_mean_tensor,
        moving_variance_tensor=moving_variance_tensor,
        bn_decay_mean_tensor=bn_decay_mean_tensor,
        bn_decay_var_tensor=bn_decay_var_tensor,
        batch_epsilon_tensor=batch_epsilon_tensor)

  # Convolution case: the layer output is the first output of the batch norm.
  for match_result in conv_matcher.match_graph(graph):
    yield _BuildBatchNormMatch(
        match_result,
        layer_pattern=conv_pattern,
        bn_pattern=conv_batch_norm_pattern,
        output_pattern=conv_batch_norm_pattern,
        moving_avg_mul_matcher=conv_moving_average_mul_matcher,
        moving_avg_sub_pattern=conv_moving_average_sub_pattern)

  # In the MatMul case, the output of batch norm is reshaped back into a
  # 2D tensor, so the output_tensor is the output of the Reshape op.
  for match_result in matmul_matcher.match_graph(graph):
    yield _BuildBatchNormMatch(
        match_result,
        layer_pattern=matmul_pattern,
        bn_pattern=matmul_batch_norm_pattern,
        output_pattern=matmul_bn_output_reshape_pattern,
        moving_avg_mul_matcher=matmul_moving_average_mul_matcher,
        moving_avg_sub_pattern=matmul_moving_average_sub_pattern)
コード例 #15
0
def _FindLayersToQuantize(graph):
  """Matches layers in graph to quantize.

  The chain below is searched for; nodes surrounded by [] are optional:

          weight|folded_weight
                /
         conv|fc
            |
    [post_conv_correction]
            |
     biasadd|folded_bias
            |
         [bypass]
            |
        activation
            |
   [post_activation_bypass]

  Match replacements:
    If weight|folded_weight is found, FakeQuant is added afterwards.
    If bypass is found, FakeQuant is added before and after.
    If activation is found, FakeQuant is added afterwards.
    If post_activation_bypass is found, FakeQuant is added afterwards.

  Matching runs from the largest pattern (ending at a post activation bypass)
  to the smallest (ending at the bias add). A layer op recorded by an earlier,
  larger pattern is not recorded again.

  Args:
    graph: Graph to perform match on.

  Returns:
    list of _LayerMatches.
  """
  op_type = graph_matcher.OpTypePattern
  one_of = graph_matcher.OneofPattern

  def _FirstTensor(result, *patterns):
    """Returns the tensor of the first pattern bound in result, else None."""
    for pattern in patterns:
      tensor = result.get_tensor(pattern)
      if tensor is not None:
        return tensor
    return None

  def _FirstOp(result, *patterns):
    """Returns the op of the first pattern bound in result, else None."""
    for pattern in patterns:
      op = result.get_op(pattern)
      if op is not None:
        return op
    return None

  input_pattern = op_type('*')
  weight_var_pattern = op_type('Variable|VariableV2')
  weight_identity_pattern = op_type('Identity', inputs=[weight_var_pattern])
  weight_resource_var_pattern = op_type('ReadVariableOp')
  folded_weight_pattern = op_type('Mul')

  # The weights feeding the layer op come either from a variable read
  # (Identity or ReadVariableOp) or from the folded weight (Mul).
  layer_pattern = op_type(
      '|'.join(_QUANTIZABLE_TYPES),
      inputs=[
          input_pattern,
          one_of([
              weight_identity_pattern, weight_resource_var_pattern,
              folded_weight_pattern
          ])
      ])

  # Folded bias: Mul -> Add (correction) -> Add (bias).
  folded_bias_mul_pattern = op_type(
      'Mul', inputs=[op_type('*'), layer_pattern])
  post_layer_op_correction_pattern = op_type(
      'Add', inputs=[folded_bias_mul_pattern, op_type('*')])
  folded_bias_add_pattern = op_type(
      'Add', inputs=[post_layer_op_correction_pattern, op_type('*')])

  bias_add_pattern = op_type('Add|BiasAdd', inputs=[layer_pattern, '*'])

  def _AnyBiasAdd():
    """Builds a fresh pattern accepting the plain or the folded bias add."""
    return one_of([bias_add_pattern, folded_bias_add_pattern])

  # The bypass Add may take the bias output as either of its two inputs.
  bypass_pattern_a = op_type('Add', inputs=[_AnyBiasAdd(), '*'])
  bypass_pattern_b = op_type('Add', inputs=['*', _AnyBiasAdd()])

  # The activation consumes the bias add, the folded bias add or a bypass.
  activation_pattern = op_type(
      '|'.join(_ACTIVATION_TYPES),
      inputs=[
          one_of([
              bias_add_pattern, folded_bias_add_pattern, bypass_pattern_a,
              bypass_pattern_b
          ])
      ])

  post_activation_bypass_pattern_a = op_type(
      'Add', inputs=['*', activation_pattern])
  post_activation_bypass_pattern_b = op_type(
      'Add', inputs=[activation_pattern, '*'])

  layer_matches = []
  # Layer ops already recorded; keeps each layer from being matched twice.
  matched_layer_set = set()

  def _Record(result, activation_op, bypass_op, post_activation_bypass_op,
              bias_add_op):
    """Appends a _LayerMatch for result unless its layer is already known."""
    layer_op = result.get_op(layer_pattern)
    if layer_op in matched_layer_set:
      return
    matched_layer_set.add(layer_op)
    weight_tensor = _FirstTensor(result, weight_identity_pattern,
                                 weight_resource_var_pattern,
                                 folded_weight_pattern)
    layer_matches.append(
        _LayerMatch(layer_op, weight_tensor, activation_op, bypass_op,
                    post_activation_bypass_op, bias_add_op))

  # Pass 1: layers followed by a post activation bypass. These go first so a
  # shorter pattern cannot claim the layer and drop the bypass node.
  post_activation_bypass_layer_matcher = graph_matcher.GraphMatcher(
      one_of([
          post_activation_bypass_pattern_a,
          post_activation_bypass_pattern_b,
      ]))
  for result in post_activation_bypass_layer_matcher.match_graph(graph):
    _Record(
        result,
        activation_op=result.get_op(activation_pattern),
        bypass_op=_FirstOp(result, bypass_pattern_a, bypass_pattern_b),
        post_activation_bypass_op=_FirstOp(result,
                                           post_activation_bypass_pattern_a,
                                           post_activation_bypass_pattern_b),
        bias_add_op=_FirstOp(result, bias_add_pattern,
                             folded_bias_add_pattern))

  # Pass 2: the basic layer ending at an activation. Layers found in pass 1
  # show up again here and are ignored.
  layer_matcher = graph_matcher.GraphMatcher(activation_pattern)
  for result in layer_matcher.match_graph(graph):
    _Record(
        result,
        activation_op=result.get_op(activation_pattern),
        bypass_op=_FirstOp(result, bypass_pattern_a, bypass_pattern_b),
        post_activation_bypass_op=None,
        bias_add_op=_FirstOp(result, bias_add_pattern,
                             folded_bias_add_pattern))

  # Pass 3: the final layer, which may have no activation. The bias add is
  # reported as the 'activation_op' so that its output gets quantized.
  final_layer_matcher = graph_matcher.GraphMatcher(
      one_of([bias_add_pattern, folded_bias_add_pattern]))
  for result in final_layer_matcher.match_graph(graph):
    _Record(
        result,
        activation_op=_FirstOp(result, bias_add_pattern,
                               folded_bias_add_pattern),
        bypass_op=None,
        post_activation_bypass_op=None,
        bias_add_op=None)

  return layer_matches
コード例 #16
0
ファイル: quantize.py プロジェクト: wxfwzt/tensorflow-1
def _FindLayersToQuantize(graph):
    """Matches layers in graph to quantize.

    The following patterns get matched. Nodes surrounded by [] will be
    optionally matched:

            weight|folded_weight
                  /
           conv|fc
              |
        [batch_to_space_nd]
              |
      [post_conv_correction]
              |
       biasadd|folded_bias
              |
           [bypass]
              |
          activation
              |
     [post_activation_bypass]

    Match replacements:
      If weight|folded_weight is found, FakeQuant is added afterwards.
      If bypass is found, FakeQuant is added before and after.
      If activation is found, FakeQuant is added afterwards.
      If post_activation_bypass is found, FakeQuant is added afterwards.

    In addition, a quantizable layer feeding directly into another quantizable
    layer (separable convolution) is matched last; for it the layer op itself
    is reported as the 'activation_op'.

    Args:
      graph: Graph to perform match on.

    Returns:
      list of _LayerMatches. Each layer op appears at most once; the largest
      pattern that contains it wins.
    """
    input_pattern = graph_matcher.OpTypePattern('*')
    weight_var_pattern = graph_matcher.OpTypePattern('Variable|VariableV2')
    weight_partition_identity_pattern = graph_matcher.OpTypePattern(
        'Identity', inputs=[weight_var_pattern])
    weight_partition_concat_pattern = graph_matcher.OpTypePattern(
        'ConcatV2', inputs=[weight_partition_identity_pattern, '*', '*'])
    weight_identity_pattern = graph_matcher.OpTypePattern(
        'Identity',
        inputs=[
            graph_matcher.OneofPattern([
                weight_partition_identity_pattern,
                weight_partition_concat_pattern,
                weight_var_pattern,
            ])
        ])
    weight_resource_var_pattern = graph_matcher.OpTypePattern('ReadVariableOp')
    folded_weight_pattern = graph_matcher.OpTypePattern('Mul')

    # The weights inputs to the layer operation can either be from the Variable
    # or the folded weight (Mul).
    layer_pattern = graph_matcher.OpTypePattern(
        '|'.join(_QUANTIZABLE_TYPES),
        inputs=[
            input_pattern,
            graph_matcher.OneofPattern([
                weight_identity_pattern, weight_resource_var_pattern,
                folded_weight_pattern
            ])
        ],
        ordered_inputs=False)

    # For atrous convolutions a BatchToSpaceND will occur after the depthwise
    # convolution.
    batch_to_space_pattern = graph_matcher.OpTypePattern(
        'BatchToSpaceND',
        inputs=[
            layer_pattern,
            graph_matcher.OpTypePattern('*'),
            graph_matcher.OpTypePattern('*')
        ])

    layer_output_pattern = graph_matcher.OneofPattern(
        [batch_to_space_pattern, layer_pattern])

    # For separable convolutions, we are looking for a conv, followed by a conv
    # with no activations between the two.
    sep_conv_pattern = graph_matcher.OpTypePattern(
        '|'.join(_QUANTIZABLE_TYPES),
        inputs=[
            graph_matcher.OneofPattern([layer_output_pattern]),
            graph_matcher.OpTypePattern('*')
        ],
        ordered_inputs=False)
    folded_bias_mul_pattern = graph_matcher.OpTypePattern(
        'Mul',
        inputs=[graph_matcher.OpTypePattern('*'), layer_output_pattern],
        ordered_inputs=False)
    post_layer_op_correction_pattern = graph_matcher.OpTypePattern(
        'Add',
        inputs=[folded_bias_mul_pattern,
                graph_matcher.OpTypePattern('*')],
        ordered_inputs=False)
    folded_bias_add_pattern = graph_matcher.OpTypePattern(
        'Add',
        inputs=[
            post_layer_op_correction_pattern,
            graph_matcher.OpTypePattern('*')
        ],
        ordered_inputs=False)

    # batch_norms with forced updates have an Identity operation at the end.
    # TODO(suharshs): Find a way to easily skip extra Identity operations. The
    # current issue is that doing so can often match patterns across many
    # layers incorrectly.
    batch_norm_identity = graph_matcher.OpTypePattern(
        'Identity', inputs=[folded_bias_add_pattern])

    bias_add_pattern = graph_matcher.OpTypePattern(
        'Add|BiasAdd',
        inputs=[layer_output_pattern, '*'],
        ordered_inputs=False)

    # The bias can come from the bias add or the folded bias add.
    bypass_pattern = graph_matcher.OpTypePattern(
        'Add',
        inputs=[
            graph_matcher.OneofPattern([
                bias_add_pattern, folded_bias_add_pattern, batch_norm_identity
            ]), '*'
        ],
        ordered_inputs=False)

    # The input to the activation can come from bias add, fold bias add, the
    # bypasses.
    # TODO(suharshs): We should ideally skip Identity operations instead of
    # treating them as activations.
    activation_pattern = graph_matcher.OpTypePattern(
        '|'.join(_ACTIVATION_TYPES) + '|Identity',
        inputs=[
            graph_matcher.OneofPattern([
                bias_add_pattern,
                folded_bias_add_pattern,
                batch_norm_identity,
                bypass_pattern,
            ])
        ])

    post_activation_bypass_pattern = graph_matcher.OpTypePattern(
        'Add', inputs=['*', activation_pattern], ordered_inputs=False)

    def _GetWeightTensor(match_result):
        """Returns the matched weight tensor, whichever form the weights take.

        The lookup order is: Identity read of the variable, ReadVariableOp,
        folded weight (Mul). Returns None only if none of them was matched.
        """
        for pattern in (weight_identity_pattern, weight_resource_var_pattern,
                        folded_weight_pattern):
            weight_tensor = match_result.get_tensor(pattern)
            if weight_tensor is not None:
                return weight_tensor
        return None

    def _GetBiasAddOp(match_result):
        """Returns the matched plain bias add op, else the folded one."""
        bias_add_op = match_result.get_op(bias_add_pattern)
        if bias_add_op is None:
            bias_add_op = match_result.get_op(folded_bias_add_pattern)
        return bias_add_op

    # The order of the following matching blocks is very important. Since
    # matches aren't guaranteed to be disjoint, we structure matches from
    # largest to smallest to guarantee that the largest match always wins.
    # Additionally, we ensure that we don't match layers multiple times.

    layer_matches = []
    # We use matched_layer_set to ensure that layers aren't matched multiple
    # times.
    matched_layer_set = set()

    def _AddLayerMatch(match_result, activation_op, bypass_op,
                       post_activation_bypass_op, bias_add_op):
        """Records a _LayerMatch unless the layer op was recorded before."""
        layer_op = match_result.get_op(layer_pattern)
        if layer_op in matched_layer_set:
            return
        matched_layer_set.add(layer_op)
        layer_matches.append(
            _LayerMatch(layer_op, _GetWeightTensor(match_result),
                        activation_op, bypass_op, post_activation_bypass_op,
                        bias_add_op))

    # First, we match layers that have a post activation bypass. We do this
    # first to ensure we don't match only the first part of this layer, missing
    # the post activation bypass node.
    post_activation_bypass_layer_matcher = graph_matcher.GraphMatcher(
        post_activation_bypass_pattern)
    for match_result in post_activation_bypass_layer_matcher.match_graph(
            graph):
        _AddLayerMatch(
            match_result,
            activation_op=match_result.get_op(activation_pattern),
            bypass_op=match_result.get_op(bypass_pattern),
            post_activation_bypass_op=match_result.get_op(
                post_activation_bypass_pattern),
            bias_add_op=_GetBiasAddOp(match_result))

    # Now, we match the basic layer ending at an activation. We may get
    # duplicate matches from above, but we don't add them to layer_matches.
    layer_matcher = graph_matcher.GraphMatcher(activation_pattern)
    for match_result in layer_matcher.match_graph(graph):
        _AddLayerMatch(
            match_result,
            activation_op=match_result.get_op(activation_pattern),
            bypass_op=match_result.get_op(bypass_pattern),
            post_activation_bypass_op=None,
            bias_add_op=_GetBiasAddOp(match_result))

    # Match the final layer, where there may not be an activation and instead
    # the output of the final BiasAdd must be quantized. So we treat the
    # BiasAdd as the 'activation_op' in the _LayerMatch, to ensure that its
    # output is quantized.
    final_layer_matcher = graph_matcher.GraphMatcher(
        graph_matcher.OneofPattern([bias_add_pattern,
                                    folded_bias_add_pattern]))
    for match_result in final_layer_matcher.match_graph(graph):
        _AddLayerMatch(
            match_result,
            activation_op=_GetBiasAddOp(match_result),
            bypass_op=None,
            post_activation_bypass_op=None,
            bias_add_op=None)

    # Look for separable convolutions here. The first conv has no bias add or
    # activation of its own, so the layer op doubles as the 'activation_op'.
    # The weight lookup uses the same fallback chain as the matches above, so
    # layers with ReadVariableOp or folded weights do not end up with a None
    # weight tensor.
    sep_conv_matcher = graph_matcher.GraphMatcher(sep_conv_pattern)
    for match_result in sep_conv_matcher.match_graph(graph):
        _AddLayerMatch(
            match_result,
            activation_op=match_result.get_op(layer_pattern),
            bypass_op=None,
            post_activation_bypass_op=None,
            bias_add_op=None)

    return layer_matches
コード例 #17
0
ファイル: quantize.py プロジェクト: zqli-90s/tensorflow
def _FindLayersToQuantize(graph):
    """Matches layers in graph to quantize.

    A match is a quantizable layer op whose weights come from an Identity read
    of a weight variable or from a folded weight (Mul), followed by a plain or
    folded bias add, an optional bypass Add, and an activation.

    Args:
      graph: Graph to perform match on.

    Yields:
      _LayerMatches.
    """
    op_type = graph_matcher.OpTypePattern
    one_of = graph_matcher.OneofPattern

    def _first_op(result, primary, fallback):
        """Returns the op bound to primary, or the one bound to fallback."""
        op = result.get_op(primary)
        return result.get_op(fallback) if op is None else op

    input_pattern = op_type('*')
    weight_var_pattern = op_type('|'.join(_WEIGHT_TYPES))
    weight_pattern = op_type('Identity', inputs=[weight_var_pattern])
    folded_weight_pattern = op_type('Mul')

    # The layer's weight input is either the variable read or the folded
    # weight (Mul).
    layer_pattern = op_type(
        '|'.join(_QUANTIZABLE_TYPES),
        inputs=[
            input_pattern,
            one_of([weight_pattern, folded_weight_pattern])
        ])

    # Folded bias: Mul -> Add (correction) -> Add (bias).
    folded_bias_mul_pattern = op_type(
        'Mul', inputs=[op_type('*'), layer_pattern])
    post_layer_op_correction_pattern = op_type(
        'Add', inputs=[folded_bias_mul_pattern, op_type('*')])
    folded_bias_add_pattern = op_type(
        'Add', inputs=[post_layer_op_correction_pattern, op_type('*')])

    bias_add_pattern = op_type('Add|BiasAdd', inputs=[layer_pattern, '*'])

    # The bypass Add may take the bias output as its first (a) or second (b)
    # input.
    bypass_pattern_a = op_type(
        'Add',
        inputs=[one_of([bias_add_pattern, folded_bias_add_pattern]), '*'])
    bypass_pattern_b = op_type(
        'Add',
        inputs=['*', one_of([bias_add_pattern, folded_bias_add_pattern])])

    # The activation consumes the bias add, the folded bias add or a bypass.
    activation_pattern = op_type(
        '|'.join(_ACTIVATION_TYPES),
        inputs=[
            one_of([
                bias_add_pattern, folded_bias_add_pattern, bypass_pattern_a,
                bypass_pattern_b
            ])
        ])

    layer_matcher = graph_matcher.GraphMatcher(activation_pattern)
    for result in layer_matcher.match_graph(graph):
        weight_tensor = result.get_tensor(weight_pattern)
        if weight_tensor is None:
            weight_tensor = result.get_tensor(folded_weight_pattern)
        yield _LayerMatch(
            result.get_op(layer_pattern),
            weight_tensor,
            result.get_op(activation_pattern),
            _first_op(result, bypass_pattern_a, bypass_pattern_b),
            _first_op(result, bias_add_pattern, folded_bias_add_pattern))
コード例 #18
0
def _FindLayersToQuantize(graph):
    """Finds the layers in `graph` that are candidates for quantization.

    The matched structure is shown below; nodes in [] are optional:

            weight|folded_weight
                  /
           conv|fc
              |
      [post_conv_correction]
              |
       biasadd|folded_bias
              |
           [bypass]
              |
          activation
              |
     [post_activation_bypass]

    Intended replacements for each match (this function only reports the
    matches; it does not modify the graph):
      If weight|folded_weight is found, FakeQuant is added afterwards.
      If bypass is found, FakeQuant is added before and after.
      If activation is found, FakeQuant is added afterwards.
      If post_activation_bypass is found, FakeQuant is added afterwards.

    Two passes are made over the graph. The first yields matches rooted at an
    activation (or at the Add following it). The second yields one match per
    op matching the bias-add pattern, with that op reported as the
    `activation_op`.

    Args:
      graph: Graph to perform match on.

    Yields:
      _LayerMatches.
    """
    op_pattern = graph_matcher.OpTypePattern
    one_of = graph_matcher.OneofPattern

    def _FirstBound(lookup, *candidates):
        """Returns lookup(candidate) for the first candidate that is bound."""
        for candidate in candidates:
            found = lookup(candidate)
            if found is not None:
                return found
        return None

    # --- Weights feeding the layer -----------------------------------------
    input_pattern = op_pattern('*')
    weight_var_pattern = op_pattern('|'.join(_WEIGHT_TYPES))
    weight_pattern = op_pattern(
        'Identity|ReadVariableOp', inputs=[weight_var_pattern])
    folded_weight_pattern = op_pattern('Mul')

    # --- The layer itself ---------------------------------------------------
    # Its weight input is either the variable read or the folded weight (Mul).
    layer_pattern = op_pattern(
        '|'.join(_QUANTIZABLE_TYPES),
        inputs=[
            input_pattern,
            one_of([weight_pattern, folded_weight_pattern]),
        ])

    # --- Bias: folded chain (Mul -> Add -> Add) or a plain Add/BiasAdd ------
    folded_bias_mul_pattern = op_pattern(
        'Mul', inputs=[op_pattern('*'), layer_pattern])
    post_layer_op_correction_pattern = op_pattern(
        'Add', inputs=[folded_bias_mul_pattern, op_pattern('*')])
    folded_bias_add_pattern = op_pattern(
        'Add', inputs=[post_layer_op_correction_pattern, op_pattern('*')])
    bias_add_pattern = op_pattern('Add|BiasAdd', inputs=[layer_pattern, '*'])

    # --- Bypass Add; the bias may arrive on either input --------------------
    bypass_pattern_a = op_pattern(
        'Add',
        inputs=[one_of([bias_add_pattern, folded_bias_add_pattern]), '*'])
    bypass_pattern_b = op_pattern(
        'Add',
        inputs=['*', one_of([bias_add_pattern, folded_bias_add_pattern])])

    # --- Activation, fed by a bias add, a folded bias add or a bypass -------
    activation_pattern = op_pattern(
        '|'.join(_ACTIVATION_TYPES),
        inputs=[
            one_of([
                bias_add_pattern,
                folded_bias_add_pattern,
                bypass_pattern_a,
                bypass_pattern_b,
            ])
        ])

    # --- Optional Add consuming the activation on either input --------------
    post_activation_bypass_pattern_a = op_pattern(
        'Add', inputs=['*', activation_pattern])
    post_activation_bypass_pattern_b = op_pattern(
        'Add', inputs=[activation_pattern, '*'])

    layer_matcher = graph_matcher.GraphMatcher(
        one_of([
            post_activation_bypass_pattern_a,
            post_activation_bypass_pattern_b,
            activation_pattern,
        ]))
    for result in layer_matcher.match_graph(graph):
        layer_op = result.get_op(layer_pattern)
        weight_tensor = _FirstBound(
            result.get_tensor, weight_pattern, folded_weight_pattern)
        activation_op = result.get_op(activation_pattern)
        bias_add_op = _FirstBound(
            result.get_op, bias_add_pattern, folded_bias_add_pattern)
        bypass_op = _FirstBound(
            result.get_op, bypass_pattern_a, bypass_pattern_b)
        post_activation_bypass_op = _FirstBound(
            result.get_op, post_activation_bypass_pattern_a,
            post_activation_bypass_pattern_b)

        # An activation that is followed by a bypass will also show up in a
        # second match that does include the post-activation bypass op, so the
        # match without it is dropped here.
        if (post_activation_bypass_op is None and
                _HasPostActivationBypass(activation_op)):
            continue

        yield _LayerMatch(layer_op, weight_tensor, activation_op, bypass_op,
                          post_activation_bypass_op, bias_add_op)

    # Final layer: there is no activation, and the output of the last BiasAdd
    # has to be quantized instead, so the bias add is reported in the
    # 'activation_op' slot of the _LayerMatch.
    # TODO(suharshs): Figure out how to quantize this final layer across many
    # models.
    final_layer_matcher = graph_matcher.GraphMatcher(bias_add_pattern)
    for result in final_layer_matcher.match_graph(graph):
        layer_op = result.get_op(layer_pattern)
        weight_tensor = _FirstBound(
            result.get_tensor, weight_pattern, folded_weight_pattern)
        activation_op = _FirstBound(
            result.get_op, bias_add_pattern, folded_bias_add_pattern)
        yield _LayerMatch(layer_op, weight_tensor, activation_op, None, None,
                          None)
コード例 #19
0
def _FindFusedBatchNorms(graph):
  """Locates FusedBatchNorm ops fed by a conv or MatMul layer.

  Two shapes are matched: a Conv2D/DepthwiseConv2dNative feeding
  FusedBatchNorm directly, and a MatMul whose output goes through a Reshape
  into FusedBatchNorm and is then reshaped again. All convolution matches are
  yielded before any MatMul match.

  For batch norms with is_training set, the op's `_gradient_op_type` attr is
  overwritten and new ops are added to compute the variance tensor.

  Args:
    graph: Graph to inspect.

  Yields:
    _FusedBatchNormMatches.
  """

  def _Wildcard():
    """Returns a fresh pattern that matches any op type."""
    return graph_matcher.OpTypePattern('*')

  input_pattern = _Wildcard()
  weight_pattern = _Wildcard()
  gamma_pattern = _Wildcard()
  beta_pattern = _Wildcard()
  mean_pattern = _Wildcard()
  variance_pattern = _Wildcard()

  layer_inputs = [input_pattern, weight_pattern]
  conv_pattern = graph_matcher.OpTypePattern(
      'Conv2D|DepthwiseConv2dNative', inputs=layer_inputs)
  # For MatMul, a Reshape sits between the layer and FusedBatchNorm.
  matmul_pattern = graph_matcher.OpTypePattern('MatMul', inputs=layer_inputs)
  matmul_reshape_pattern = graph_matcher.OpTypePattern(
      'Reshape', inputs=[matmul_pattern, _Wildcard()])

  conv_batch_norm_pattern = graph_matcher.OpTypePattern(
      'FusedBatchNorm',
      inputs=[
          conv_pattern, gamma_pattern, beta_pattern, mean_pattern,
          variance_pattern
      ])
  matmul_batch_norm_pattern = graph_matcher.OpTypePattern(
      'FusedBatchNorm',
      inputs=[
          matmul_reshape_pattern, gamma_pattern, beta_pattern, mean_pattern,
          variance_pattern
      ])
  matmul_bn_output_reshape_pattern = graph_matcher.OpTypePattern(
      'Reshape', inputs=[matmul_batch_norm_pattern, _Wildcard()])

  def _MeanAndVariance(match_result, bn_op, bn_input_tensor):
    """Picks the mean and variance tensors that belong to `bn_op`.

    FusedBatchNorm behaves differently in training and in inference. In
    training it takes empty 'mean' and 'variance' inputs and produces the
    batch mean and variance as outputs 1 and 2 (0-based); in inference the
    statistics are the op's matched inputs.
    """
    if not bn_op.get_attr('is_training'):
      return (match_result.get_tensor(mean_pattern),
              match_result.get_tensor(variance_pattern))

    # FusedBatchNormGrad doesn't compute gradients of the batch_mean and
    # batch_variance outputs, so a custom gradient is substituted.
    # TODO(suharshs, raghuramank): Find a way to avoid needing this hack.
    # pylint: disable=protected-access
    bn_op._set_attr(
        '_gradient_op_type',
        attr_value_pb2.AttrValue(s=compat.as_bytes('FoldFusedBatchNormGrad')))
    # pylint: enable=protected-access
    batch_mean = bn_op.outputs[1]
    # The batch variance used in forward and backward prop is biased, i.e.
    # V = sum((x(k) - mu)^2) / N. For the moving average the variance is
    # corrected by N / (N - 1) (Bessel's correction). The variance read from
    # FusedBatchNorm has that correction applied, so it is undone here.
    n = math_ops.cast(
        array_ops.size(bn_input_tensor) / array_ops.size(batch_mean),
        dtypes.float32)
    return batch_mean, bn_op.outputs[2] * (n - 1) / n

  # One row per matched shape:
  #   (pattern to match, layer pattern, batch norm pattern, output pattern).
  # For convolutions the output tensor is the batch norm output itself; for
  # MatMul the batch norm output is reshaped back, so the output tensor is
  # that of the trailing Reshape.
  match_specs = (
      (conv_batch_norm_pattern, conv_pattern, conv_batch_norm_pattern,
       conv_batch_norm_pattern),
      (matmul_bn_output_reshape_pattern, matmul_pattern,
       matmul_batch_norm_pattern, matmul_bn_output_reshape_pattern),
  )
  matchers = [graph_matcher.GraphMatcher(spec[0]) for spec in match_specs]

  for matcher, spec in zip(matchers, match_specs):
    _, layer_pattern, bn_pattern, output_pattern = spec
    for match_result in matcher.match_graph(graph):
      bn_op = match_result.get_op(bn_pattern)
      mean_tensor, variance_tensor = _MeanAndVariance(
          match_result, bn_op, match_result.get_tensor(layer_pattern))
      yield _FusedBatchNormMatch(
          layer_op=match_result.get_op(layer_pattern),
          bn_op=bn_op,
          output_tensor=match_result.get_op(output_pattern).outputs[0],
          input_tensor=match_result.get_tensor(input_pattern),
          weight_tensor=match_result.get_tensor(weight_pattern),
          gamma_tensor=match_result.get_tensor(gamma_pattern),
          beta_tensor=match_result.get_tensor(beta_pattern),
          mean_tensor=mean_tensor,
          variance_tensor=variance_tensor)