Example #1
    def testBatchNormScope(self):
        batch_size, height, width, depth = 5, 128, 128, 3
        g = ops.Graph()
        with g.as_default():
            inputs = array_ops.zeros((batch_size, height, width, depth))
            stride = 1
            out_depth = 32
            scope = ''
            node = conv2d(inputs,
                          out_depth, [2, 2],
                          stride=stride,
                          padding='SAME',
                          weights_initializer=self._WeightInit(0.09),
                          activation_fn=None,
                          normalizer_fn=batch_norm,
                          normalizer_params=self._BatchNormParams(False),
                          scope=scope)

            node = nn_ops.relu(node, name='Relu6')
        bn_list = common.BatchNormGroups(g)
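        # Dump the text-format GraphDef to disk for manual inspection.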
        with open('/tmp/common_test.pbtxt', 'w') as f:
            f.write(str(g.as_graph_def()))

        # Exactly one batch norm layer with empty scope should be found
        self.assertEqual(len(bn_list), 1)
        self.assertEqual(bn_list[0], '')
Example #2
def _FoldUnfusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
    """Finds unfused batch norm layers and folds them into preceding layers.

  Folding only affects the following layers: Conv2D, fully connected, depthwise
  convolution.

  Args:
    graph: Graph to walk and modify.
    is_training: Bool, True if training.
    freeze_batch_norm_delay: How many steps to wait before freezing moving mean
      and variance and using them for batch normalization.

  Raises:
    ValueError: When batch norm folding fails.
  """
    input_to_ops_map = input_to_ops.InputToOps(graph)

    for bn in common.BatchNormGroups(graph):
        has_scaling = _HasScaling(graph, input_to_ops_map, bn)

        if not _IsValidUnfusedBatchNorm(graph, bn):
            continue

        # The mangling code intimately depends on BatchNorm node's internals.
        original_op, folded_op = _CreateFoldedOp(
            graph,
            bn,
            has_scaling=has_scaling,
            freeze_batch_norm_delay=freeze_batch_norm_delay,
            is_training=is_training)

        activation = common.GetEndpointActivationOp(graph, bn)
        if activation:
            nodes_modified_count = common.RerouteTensor(
                folded_op.outputs[0],
                original_op.outputs[0],
                can_modify=[activation])
            if nodes_modified_count != 1:
                raise ValueError('Unexpected inputs to op: %s' %
                                 activation.name)
            continue

        # Treat consumer ops in bypass modules differently since they have Add
        # operations instead of Relu* above.
        add_bypass_ctx = re.search(r'^(.*)/([^/]+)', bn).group(1)
        add_bypass = graph.get_operation_by_name(add_bypass_ctx + '/Add')
        nodes_modified_count = common.RerouteTensor(folded_op.outputs[0],
                                                    original_op.outputs[0],
                                                    can_modify=[add_bypass])
        if nodes_modified_count != 1:
            raise ValueError('Unexpected inputs to op: %s' % add_bypass.name)
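A quick illustration of the bypass-Add lookup above (not part of the library): the regex strips the last component of the batch norm scope, and the Add op is then fetched from that parent scope. The scope string below is hypothetical.

import re

bn_scope = 'resnet/block_1/BatchNorm'   # hypothetical batch norm scope string
add_bypass_ctx = re.search(r'^(.*)/([^/]+)', bn_scope).group(1)
print(add_bypass_ctx)            # 'resnet/block_1'
print(add_bypass_ctx + '/Add')   # name passed to graph.get_operation_by_name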
Example #3
def FoldBatchNorms(graph):
    """Finds batch norm layers in the graph, folds them into preceding layers.

  Folding only affects the following layers: Conv2D, fully connected, depthwise
  convolution.

  Args:
    graph: Graph to walk and modify.

  Raises:
    ValueError: When batch norm folding fails.
  """
    # Fail immediately when the graph contains unsupported fused batch norm ops.
    if any(op for op in graph.get_operations() if op.type == 'FusedBatchNorm'):
        raise ValueError('Fused batch norm is not supported')

    input_to_ops_map = input_to_ops.InputToOps(graph)

    for bn in common.BatchNormGroups(graph):
        has_scaling = _HasScaling(graph, input_to_ops_map, bn)

        # The mangling code intimately depends on BatchNorm node's internals.
        original_op, folded_op = _CreateFoldedOp(graph,
                                                 bn,
                                                 has_scaling=has_scaling)

        activation = common.GetEndpointActivationOp(graph, bn)
        if activation:
            nodes_modified_count = graph_editor.reroute_ts(
                [folded_op.outputs[0]], [original_op.outputs[0]],
                can_modify=[activation])
            if nodes_modified_count != 1:
                raise ValueError('Unexpected inputs to op: %s' %
                                 activation.name)
            continue

        # Treat consumer ops in bypass modules differently since they have Add
        # operations instead of Relu* above.
        add_bypass_ctx = re.search(r'^(.*)/([^/]+)', bn).group(1)
        add_bypass = graph.get_operation_by_name(add_bypass_ctx + '/Add')
        nodes_modified_count = graph_editor.reroute_ts(
            [folded_op.outputs[0]], [original_op.outputs[0]],
            can_modify=[add_bypass])
        if nodes_modified_count != 1:
            raise ValueError('Unexpected inputs to op: %s' % add_bypass.name)
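A hedged usage sketch for this older variant: FoldBatchNorms mutates the graph in place, so it is typically called once the model has been built and before any further quantization rewrites. build_model below is a hypothetical model-construction function, not part of the library.

import tensorflow as tf

g = tf.Graph()
with g.as_default():
    build_model()    # hypothetical: adds conv / batch norm layers to g
FoldBatchNorms(g)    # folds every recognized batch norm group into its producer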
Example #4
def _FoldUnfusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
    """Finds unfused batch norm layers and folds them into preceding layers.

  Folding only affects the following layers: Conv2D, fully connected, depthwise
  convolution.

  Args:
    graph: Graph to walk and modify.
    is_training: Bool, True if training.
    freeze_batch_norm_delay: How many steps to wait before freezing moving mean
      and variance and using them for batch normalization.

  Raises:
    ValueError: When batch norm folding fails.
  """
    input_to_ops_map = input_to_ops.InputToOps(graph)

    for bn in common.BatchNormGroups(graph):
        has_scaling = _HasScaling(graph, input_to_ops_map, bn)

        if not _IsValidUnfusedBatchNorm(graph, bn):
            continue

        # The mangling code intimately depends on BatchNorm node's internals.
        original_op, folded_op = _CreateFoldedOp(
            graph,
            bn,
            has_scaling=has_scaling,
            freeze_batch_norm_delay=freeze_batch_norm_delay,
            is_training=is_training)

        # TODO: generalise
        activation = input_to_ops_map.ConsumerOperations(original_op).pop()
        # assert any(activation.type == o or
        #            o.lower() in activation.name.split("/")[-1].lower()
        #            for o in (common._ACTIVATION_OP_SUFFIXES + ["Add"]))

        nodes_modified_count = common.RerouteTensor(folded_op.outputs[0],
                                                    original_op.outputs[0],
                                                    can_modify=[activation])
        if nodes_modified_count != 1:
            raise ValueError('Unexpected inputs to op: %s' % activation.name)
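This variant drops the name-based activation and bypass-Add lookup and simply reroutes into the single consumer of the original op. A minimal sketch of the assumption baked into the .pop() call, assuming input_to_ops.InputToOps behaves as used above (the names graph and original_op refer to the values in the function body):

input_to_ops_map = input_to_ops.InputToOps(graph)
consumers = input_to_ops_map.ConsumerOperations(original_op)
assert len(consumers) == 1, 'expected exactly one consumer (Relu*, Relu6 or Add)'
activation = consumers.pop()   # only safe when the batch norm output has one consumer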
Example #5
def _FoldUnfusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
    """Finds unfused batch norm layers and folds them into preceding layers.

  Folding only affects the following layers: Conv2D, fully connected, depthwise
  convolution.

  Args:
    graph: Graph to walk and modify.
    is_training: Bool, True if training.
    freeze_batch_norm_delay: How many steps to wait before freezing moving mean
      and variance and using them for batch normalization.

  Raises:
    ValueError: When batch norm folding fails.
  """
    input_to_ops_map = input_to_ops.InputToOps(graph)

    for bn in common.BatchNormGroups(graph):
        has_scaling = _HasScaling(graph, input_to_ops_map, bn)

        if not _IsValidUnfusedBatchNorm(graph, bn):
            continue

        print("found unfused batchnarm")
        raise Exception("Not Implemented")

        # The mangling code intimately depends on BatchNorm node's internals.
        original_op, folded_op = _CreateFoldedOp(
            graph,
            bn,
            has_scaling=has_scaling,
            freeze_batch_norm_delay=freeze_batch_norm_delay,
            is_training=is_training)

        activation = common.GetEndpointActivationOp(graph, bn)
        if activation:
            nodes_modified_count = common.RerouteTensor(
                folded_op.outputs[0],
                original_op.outputs[0],
                can_modify=[activation])
            if nodes_modified_count != 1:
                raise ValueError('Unexpected inputs to op: %s' %
                                 activation.name)
            continue

        # Treat consumer ops in bypass modules differently since they have Add
        # operations instead of Relu* above.
        # Select the correct scope for the bypass Add: if the batch norm scope
        # has the form 'str1/str2', the bypass Add lives in scope 'str1'; if the
        # scope is just 'str1', the bypass Add lives in the root scope ''.
        # If there is no batch norm scope at all, there is no bypass Add.
        add_bypass_ctx = ''
        if bn:
            try:
                add_bypass_ctx = re.search(r'^(.*)/([^/]+)', bn).group(1)
            except AttributeError:
                add_bypass_ctx = ''

        if add_bypass_ctx:
            add_bypass_ctx = add_bypass_ctx + '/'

        add_bypass = graph.get_operation_by_name(add_bypass_ctx + 'Add')
        nodes_modified_count = common.RerouteTensor(folded_op.outputs[0],
                                                    original_op.outputs[0],
                                                    can_modify=[add_bypass])
        if nodes_modified_count != 1:
            raise ValueError('Unexpected inputs to op: %s' % add_bypass.name)
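A minimal sketch (hypothetical helper, not part of the library) of the bypass-Add naming rule applied above: a scope of the form 'str1/str2' maps to 'str1/Add', while a single-component scope maps to 'Add' in the root scope.

import re

def _bypass_add_name(bn_scope):
    # Strip the last scope component if there is one; otherwise use the root scope.
    match = re.search(r'^(.*)/([^/]+)', bn_scope) if bn_scope else None
    ctx = match.group(1) + '/' if match else ''
    return ctx + 'Add'

print(_bypass_add_name('block_1/BatchNorm'))  # 'block_1/Add'
print(_bypass_add_name('BatchNorm'))          # 'Add'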