Example #1
expand_softmax rewrites a softmax over an arbitrary axis into a plain softmax over the last axis of a 2D tensor, wrapping it in transpose and reshape operations as needed.
def expand_softmax(tf_graph, tf_op):
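    # type: (TFGraph, TFOperation)->None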
    assert tf_op.input.rank != 0

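    # Normalize the axis attribute: default is the last axis, and negative
    # values are made non-negative.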
    axis = tf_op.attribs.get('axis')
    if axis is None:
        axis = -1
    if axis < 0:
        axis += tf_op.input.rank

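    # After expansion, the softmax itself always runs over the last axis.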
    tf_op.attribs['axis'] = -1

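    # A rank-2 input with the softmax axis already last needs no expansion.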
    if tf_op.input.rank == 2 and axis == 1:
        return

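    # If the softmax axis is not the last one, transpose it into the last
    # position, and transpose back to the original layout afterwards.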
    if axis != tf_op.input.rank - 1:
        perm = utils.without(range(tf_op.input.rank), axis) + [axis]
        perm_inv = utils.inverse_permutation(perm)
        transpose = TFOperation(graph=tf_graph,
                                name="tf.transpose",
                                inputs=tf_op.input,
                                attribs=dict(perm=perm),
                                outputs=TFTensor(graph=tf_graph,
                                                 name=None,
                                                 shape=infer.transpose(
                                                     input=tf_op.input.shape,
                                                     axes=perm),
                                                 dtype=tf_op.input.dtype))
        tf_op.inputs = transpose.output
        old_output = tf_op.output
        tf_op.outputs = TFTensor(graph=tf_graph,
                                 name=None,
                                 shape=list(tf_op.input.shape),
                                 dtype=tf_op.input.dtype)
        TFOperation(graph=tf_graph,
                    name="tf.transpose",
                    inputs=tf_op.output,
                    attribs=dict(perm=perm_inv),
                    outputs=old_output)

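    # If the input is not 2D, flatten all leading axes into one so the softmax
    # runs on a [-1, C] tensor, then reshape back to the original shape.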
    if tf_op.input.rank != 2:
        shape = [-1, tf_op.input.shape[-1]]
        reshape = TFOperation(graph=tf_graph,
                              name="tf.reshape",
                              inputs=tf_op.input,
                              attribs=dict(shape=shape),
                              outputs=TFTensor(graph=tf_graph,
                                               name=None,
                                               shape=infer.reshape(
                                                   input=tf_op.input.shape,
                                                   shape=shape),
                                               dtype=tf_op.input.dtype))
        tf_op.inputs = reshape.output
        old_output = tf_op.output
        tf_op.outputs = TFTensor(graph=tf_graph,
                                 name=None,
                                 shape=list(tf_op.input.shape),
                                 dtype=tf_op.input.dtype)
        TFOperation(graph=tf_graph,
                    name="tf.reshape",
                    inputs=tf_op.output,
                    attribs=dict(shape=list(old_output.shape)),
                    outputs=old_output)
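
The example relies on two helpers from the converter's utils module. Their exact implementations live in the library; a minimal sketch of the behavior the code above assumes:

def without(iterable, value):
    # Drop every occurrence of value, e.g. without(range(4), 2) -> [0, 1, 3].
    return [x for x in iterable if x != value]

def inverse_permutation(perm):
    # Invert an axis permutation, e.g. inverse_permutation([0, 2, 3, 1]) -> [0, 3, 1, 2].
    inverse = [0] * len(perm)
    for i, axis in enumerate(perm):
        inverse[axis] = i
    return inverse

With these, perm moves the softmax axis to the last position, and perm_inv restores the original layout once the softmax has run.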
Example #2
transform_fused_batch_norm expands a fused batch-normalization operation into plain tf.nn.moments and tf.nn.batch_normalization operations, covering both the training path (batch statistics are computed) and the inference path (precomputed mean and variance are reused).
def transform_fused_batch_norm(g, op):
    # type: (TFGraph, TFOperation)->None
    VARIANCE_CORRECTION_ENABLED = True  # rescale the biased variance from tf.nn.moments to the unbiased batch variance

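    # FusedBatchNorm inputs: the data tensor plus the per-channel scale (gamma)
    # and offset (beta) vectors.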
    in_input = op.inputs[0]
    in_scale = op.inputs[1]
    in_offset = op.inputs[2]

    epsilon = op.attribs["epsilon"]

    out_y = op.outputs[0]
    out_batch_mean = op.outputs[1]
    out_batch_var = op.outputs[2]

    data_format = op.attribs["data_format"].upper() if op.attribs["data_format"] else "NHWC"
    channel_dim = 1 if data_format == "NCHW" else in_input.rank - 1
    # Elements per channel: the N in the unbiased-variance correction factor N / (N - 1).
    rest_count = int(in_input.count / in_input.shape[channel_dim])
    tensors_to_remove = []

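    # Training mode: mean and variance are not inputs, so the batch statistics
    # are computed with tf.nn.moments.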
    if op.attribs["is_training"]:
        if VARIANCE_CORRECTION_ENABLED:
            biased_batch_var = TFTensor(graph=g,
                                        shape=list(out_batch_var.shape),
                                        dtype=out_batch_var.dtype)
            const = TFTensor(graph=g,
                             shape=[],
                             dtype=in_input.dtype,
                             data=float(rest_count) / max(rest_count - 1, 1))
            TFOperation(graph=g,
                        name="tf.nn.moments",
                        inputs=in_input,
                        attribs=dict(axes=utils.without(range(in_input.rank), channel_dim),
                                     keep_dims=False),
                        outputs=(out_batch_mean, biased_batch_var))
            TFOperation(graph=g,
                        name="tf.multiply",
                        inputs=(biased_batch_var, const),
                        outputs=out_batch_var)
            TFOperation(graph=g,
                        name="tf.nn.batch_normalization",
                        inputs=(in_input, out_batch_mean, out_batch_var,
                                in_offset, in_scale),
                        attribs=dict(variance_epsilon=epsilon,
                                     _data_format=data_format),
                        outputs=out_y)
            if len(op.outputs) > 3:  # This can happen in gradients
                out_saved_mean = op.outputs[3]
                out_saved_var = op.outputs[4]
                graph_utils.replace_tensor_in_consumers(
                    g, out_saved_mean, out_batch_mean)
                graph_utils.replace_tensor_in_consumers(
                    g, out_saved_var, out_batch_var)
                tensors_to_remove += [out_saved_mean, out_saved_var]
        else:  # not VARIANCE_CORRECTION_ENABLED
            TFOperation(graph=g,
                        name="tf.nn.moments",
                        inputs=in_input,
                        attribs=dict(axes=utils.without(range(in_input.rank), channel_dim),
                                     keep_dims=False),
                        outputs=(out_batch_mean, out_batch_var))
            TFOperation(graph=g,
                        name="tf.nn.batch_normalization",
                        inputs=(in_input, out_batch_mean, out_batch_var,
                                in_offset, in_scale),
                        attribs=dict(variance_epsilon=epsilon,
                                     _data_format=data_format),
                        outputs=out_y)
            if len(op.outputs) > 3:  # This can happen in gradients
                out_saved_mean = op.outputs[3]
                out_saved_var = op.outputs[4]
                graph_utils.replace_tensor_in_consumers(
                    g, out_saved_mean, out_batch_mean)
                graph_utils.replace_tensor_in_consumers(
                    g, out_saved_var, out_batch_var)
                tensors_to_remove += [out_saved_mean, out_saved_var]
    else:  # not training
        in_mean = op.inputs[3]
        in_variance = op.inputs[4]
        graph_utils.replace_tensor_in_consumers(g, out_batch_mean, in_mean)
        graph_utils.replace_tensor_in_consumers(g, out_batch_var, in_variance)
        tensors_to_remove += [out_batch_mean, out_batch_var]
        if len(op.outputs) > 3:  # This can happen in gradients
            out_saved_mean = op.outputs[3]
            out_saved_var = op.outputs[4]
            graph_utils.replace_tensor_in_consumers(g, out_saved_mean, in_mean)
            graph_utils.replace_tensor_in_consumers(g, out_saved_var,
                                                    in_variance)
            tensors_to_remove += [out_saved_mean, out_saved_var]
        TFOperation(graph=g,
                    name="tf.nn.batch_normalization",
                    inputs=(in_input, in_mean, in_variance, in_offset,
                            in_scale),
                    attribs=dict(variance_epsilon=epsilon,
                                 _data_format=data_format),
                    outputs=out_y)
    g.remove_operation(op, unlink=True)
    for t in tensors_to_remove:
        g.remove_tensor(t)
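
The multiply by rest_count / (rest_count - 1) in the corrected path converts the biased (divide-by-N) variance that a moments-style reduction produces into the unbiased (divide-by-(N-1)) batch variance that FusedBatchNorm reports. A quick NumPy check of that identity (illustrative only, not part of the converter):

import numpy as np

x = np.random.randn(2, 4, 4, 3)            # NHWC input
n = x.size // x.shape[-1]                  # rest_count: elements per channel
biased = x.var(axis=(0, 1, 2))             # divide-by-N, as a moments op computes
unbiased = x.var(axis=(0, 1, 2), ddof=1)   # divide-by-(N-1)
assert np.allclose(biased * n / (n - 1), unbiased)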