Beispiel #1
0
def transform_tanh_grad(g, op):
    # type: (TFGraph, TFOperation)->None
    """Expand a tanh-gradient op into square, subtract and multiply ops.

    The emitted sub-graph evaluates ``dy * (1 - y ** 2)`` where ``y, dy`` are
    the two inputs of ``op``. The last emitted op writes into ``op.outputs``
    and ``op`` itself is then removed from ``g``.
    """
    y, dy = op.inputs
    result = op.output

    def fresh():
        # Intermediate tensor with a copy of the result's shape and its dtype.
        return TFTensor(graph=g, shape=list(result.shape), dtype=result.dtype)

    one = TFTensor(graph=g, shape=[], dtype=result.dtype, data=1.0)

    y_squared = TFOperation(graph=g,
                            name="tf.square",
                            inputs=y,
                            outputs=fresh()).output
    one_minus_sq = TFOperation(graph=g,
                               name="tf.subtract",
                               inputs=(one, y_squared),
                               outputs=fresh()).output
    TFOperation(graph=g,
                name="tf.multiply",
                inputs=(dy, one_minus_sq),
                outputs=op.outputs)

    g.remove_operation(op, unlink=True)
Beispiel #2
0
def transform_reciprocal_grad(g, op):
    # type: (TFGraph, TFOperation)->None
    """Expand a reciprocal-gradient op into negative, square and multiply ops.

    The emitted sub-graph evaluates ``-dy * y ** 2`` where ``y, dy`` are the
    two inputs of ``op``. The final multiply writes into ``op.outputs`` and
    ``op`` is removed from ``g`` afterwards.
    """
    y, dy = op.inputs
    target = op.output

    def fresh():
        # Intermediate tensor shaped and typed like the op's output.
        return TFTensor(graph=g, shape=list(target.shape), dtype=target.dtype)

    negated_dy = TFOperation(graph=g,
                             name="tf.negative",
                             inputs=dy,
                             outputs=fresh()).output
    y_squared = TFOperation(graph=g,
                            name="tf.square",
                            inputs=y,
                            outputs=fresh()).output
    TFOperation(graph=g,
                name="tf.multiply",
                inputs=(negated_dy, y_squared),
                outputs=op.outputs)

    g.remove_operation(op, unlink=True)
Beispiel #3
0
def transform_sigmoid_grad(g, op):
    # type: (TFGraph, TFOperation)->None
    """Expand a sigmoid-gradient op into multiply and subtract ops.

    The emitted sub-graph evaluates ``(dy * y) * (1 - y)`` where ``y, dy`` are
    the two inputs of ``op``. The final multiply writes into ``op.outputs``
    and ``op`` is removed from ``g`` afterwards.
    """
    y, dy = op.inputs
    out_tensor = op.output

    def fresh():
        # Intermediate tensor shaped and typed like the op's output.
        return TFTensor(graph=g,
                        shape=list(out_tensor.shape),
                        dtype=out_tensor.dtype)

    one = TFTensor(graph=g, shape=[], dtype=out_tensor.dtype, data=1.0)

    dy_times_y = TFOperation(graph=g,
                             name="tf.multiply",
                             inputs=(dy, y),
                             outputs=fresh()).output
    one_minus_y = TFOperation(graph=g,
                              name="tf.subtract",
                              inputs=(one, y),
                              outputs=fresh()).output
    TFOperation(graph=g,
                name="tf.multiply",
                inputs=(dy_times_y, one_minus_y),
                outputs=op.outputs)

    g.remove_operation(op, unlink=True)
Beispiel #4
0
def transform_softplus_grad(g, op):
    # type: (TFGraph, TFOperation)->None
    """Expand a softplus-gradient op into exp, add, divide and multiply ops.

    The emitted sub-graph evaluates
    ``gradients * (exp(features) / (exp(features) + 1))`` where
    ``gradients, features`` are the two inputs of ``op``. The final multiply
    writes into ``op.outputs`` and ``op`` is removed from ``g`` afterwards.
    """
    gradients, features = op.inputs
    out_tensor = op.output

    def fresh():
        # Intermediate tensor shaped and typed like the op's output.
        return TFTensor(graph=g,
                        shape=list(out_tensor.shape),
                        dtype=out_tensor.dtype)

    one = TFTensor(graph=g, shape=[], dtype=out_tensor.dtype, data=1.0)

    exp_features = TFOperation(graph=g,
                               name="tf.exp",
                               inputs=features,
                               outputs=fresh()).output
    exp_plus_one = TFOperation(graph=g,
                               name="tf.add",
                               inputs=(exp_features, one),
                               outputs=fresh()).output
    ratio = TFOperation(graph=g,
                        name="tf.divide",
                        inputs=(exp_features, exp_plus_one),
                        outputs=fresh()).output
    TFOperation(graph=g,
                name="tf.multiply",
                inputs=(gradients, ratio),
                outputs=op.outputs)

    g.remove_operation(op, unlink=True)
Beispiel #5
0
def transform_rsqrt_grad(g, op):
    # type: (TFGraph, TFOperation)->None
    """Expand an rsqrt-gradient op into multiply and pow ops.

    The emitted sub-graph evaluates ``(-0.5 * dy) * y ** 3`` where ``y, dy``
    are the two inputs of ``op``. The final multiply writes into
    ``op.outputs`` and ``op`` is removed from ``g`` afterwards.
    """
    y, dy = op.inputs
    out_tensor = op.output

    def scalar(value):
        # Rank-0 constant carrying the output dtype.
        return TFTensor(graph=g, shape=[], dtype=out_tensor.dtype, data=value)

    def fresh():
        # Intermediate tensor shaped and typed like the op's output.
        return TFTensor(graph=g,
                        shape=list(out_tensor.shape),
                        dtype=out_tensor.dtype)

    minus_half = scalar(-0.5)
    three = scalar(3.0)

    scaled_dy = TFOperation(graph=g,
                            name="tf.multiply",
                            inputs=(minus_half, dy),
                            outputs=fresh()).output
    y_cubed = TFOperation(graph=g,
                          name="tf.pow",
                          inputs=(y, three),
                          outputs=fresh()).output
    TFOperation(graph=g,
                name="tf.multiply",
                inputs=(scaled_dy, y_cubed),
                outputs=op.outputs)

    g.remove_operation(op, unlink=True)
Beispiel #6
0
def transform_elu_grad(g, op):
    # type: (TFGraph, TFOperation)->None
    """Expand an elu-gradient op into greater, add, multiply and where ops.

    The emitted sub-graph evaluates
    ``where(outputs > 0, gradients, gradients * (outputs + 1))`` where
    ``gradients, outputs`` are the two inputs of ``op``. The final ``tf.where``
    writes into ``op.outputs`` and ``op`` is removed from ``g`` afterwards.
    """
    gradients, outputs = op.inputs
    out_tensor = op.output

    def scalar(value):
        # Rank-0 constant carrying the output dtype.
        return TFTensor(graph=g, shape=[], dtype=out_tensor.dtype, data=value)

    def fresh(dtype=None):
        # Intermediate tensor shaped like the op's output; dtype defaults to
        # the output dtype and is overridden for the boolean mask.
        return TFTensor(graph=g,
                        shape=list(out_tensor.shape),
                        dtype=out_tensor.dtype if dtype is None else dtype)

    zero = scalar(0.0)
    one = scalar(1.0)

    is_positive = TFOperation(graph=g,
                              name="tf.greater",
                              inputs=(outputs, zero),
                              outputs=fresh("bool")).output
    outputs_plus_one = TFOperation(graph=g,
                                   name="tf.add",
                                   inputs=(outputs, one),
                                   outputs=fresh()).output
    scaled_gradients = TFOperation(graph=g,
                                   name="tf.multiply",
                                   inputs=(gradients, outputs_plus_one),
                                   outputs=fresh()).output
    TFOperation(graph=g,
                name="tf.where",
                inputs=(is_positive, gradients, scaled_gradients),
                outputs=op.outputs)

    g.remove_operation(op, unlink=True)
Beispiel #7
0
def transform_separate_duplicated_outputs(tf_graph):
    # type: (TFGraph)->None
    """Make every graph output a distinct tensor.

    The first occurrence of a tensor among ``tf_graph.outputs`` is kept; each
    later occurrence is replaced by the output of a new ``tf.identity`` op fed
    by that tensor. The outputs are then reassigned: as an ``OrderedDict``
    keyed by ``tf_graph.output_ids`` when those are truthy, otherwise as a
    plain list.
    """

    def identity_copy(tensor):
        # New tensor mirroring shape/dtype (and a shallow copy of the data),
        # produced by an identity op on the duplicated tensor.
        duplicate = TFTensor(graph=tf_graph,
                             shape=list(tensor.shape),
                             dtype=tensor.dtype,
                             data=copy.copy(tensor.data))
        return TFOperation(graph=tf_graph,
                           name='tf.identity',
                           inputs=tensor,
                           outputs=duplicate).output

    already_emitted = set()
    separated = []
    for tensor in tf_graph.outputs:
        if tensor not in already_emitted:
            already_emitted.add(tensor)
            separated.append(tensor)
            continue
        separated.append(identity_copy(tensor))

    ids = tf_graph.output_ids
    if not ids:
        tf_graph.outputs = separated
    else:
        tf_graph.outputs = OrderedDict(zip(ids, separated))
Beispiel #8
0
def transform_sqrt_grad(g, op):
    # type: (TFGraph, TFOperation)->None
    """Expand a sqrt-gradient op into a multiply followed by a divide.

    The emitted sub-graph evaluates ``(dy * 0.5) / y`` where ``y, dy`` are the
    two inputs of ``op``. The divide writes into ``op.outputs`` and ``op`` is
    removed from ``g`` afterwards.
    """
    y, dy = op.inputs
    out_tensor = op.output

    half = TFTensor(graph=g, shape=[], dtype=out_tensor.dtype, data=0.5)
    halved_dy_tensor = TFTensor(graph=g,
                                shape=list(out_tensor.shape),
                                dtype=out_tensor.dtype)

    halved_dy = TFOperation(graph=g,
                            name="tf.multiply",
                            inputs=(dy, half),
                            outputs=halved_dy_tensor).output
    TFOperation(graph=g,
                name="tf.divide",
                inputs=(halved_dy, y),
                outputs=op.outputs)

    g.remove_operation(op, unlink=True)
Beispiel #9
0
def transform_relu6_grad(g, op):
    # type: (TFGraph, TFOperation)->None
    """Expand a relu6-gradient op into comparison, logical_and and where ops.

    The emitted sub-graph evaluates
    ``where(features > 0 and features < 6, gradients, 0.0)`` where
    ``gradients, features`` are the two inputs of ``op``. The final
    ``tf.where`` writes into ``op.outputs`` and ``op`` is removed from ``g``.
    """
    gradients, features = op.inputs
    out_tensor = op.output

    def filled(value):
        # Constant with the full output shape (not a scalar) and output dtype.
        return TFTensor(graph=g,
                        shape=list(out_tensor.shape),
                        dtype=out_tensor.dtype,
                        data=value)

    def mask():
        # Boolean intermediate tensor with the output shape.
        return TFTensor(graph=g, shape=list(out_tensor.shape), dtype="bool")

    zero = filled(0.0)
    six = filled(6.0)

    above_zero = TFOperation(graph=g,
                             name="tf.greater",
                             inputs=(features, zero),
                             outputs=mask()).output
    below_six = TFOperation(graph=g,
                            name="tf.less",
                            inputs=(features, six),
                            outputs=mask()).output
    in_range = TFOperation(graph=g,
                           name="tf.logical_and",
                           inputs=(above_zero, below_six),
                           outputs=mask()).output
    TFOperation(graph=g,
                name="tf.where",
                inputs=(in_range, gradients, zero),
                outputs=op.outputs)

    g.remove_operation(op, unlink=True)
Beispiel #10
0
def transform_cast(g, op):
    # type: (TFGraph, TFOperation)->None
    """Replace a cast op by an equivalent op that needs no dtype conversion.

    * same dtype, float -> float or int -> int: ``tf.identity``
    * bool -> float: ``tf.where(input, ones, zeros)``
    * float -> bool: ``tf.not_equal(input, zeros)``

    Any other combination only prints a notice and leaves ``op`` in the graph.
    In the handled cases the replacement writes into ``op.outputs`` and ``op``
    is removed from ``g``.
    """
    src_dtype = op.input.dtype  # type: str
    dst_dtype = op.attribs["dtype"]  # type: str

    def filled(dtype, value):
        # Constant with the input's shape, the given dtype and fill value.
        return TFTensor(graph=g,
                        shape=list(op.input.shape),
                        dtype=dtype,
                        data=value)

    same_family = (
        src_dtype == dst_dtype
        or (src_dtype.startswith('float') and dst_dtype.startswith('float'))
        or (src_dtype.startswith('int') and dst_dtype.startswith('int')))

    if same_family:
        TFOperation(graph=g,
                    name="tf.identity",
                    inputs=op.input,
                    outputs=op.outputs)
    elif src_dtype == "bool" and dst_dtype.startswith("float"):
        zeros = filled(dst_dtype, 0.0)
        ones = filled(dst_dtype, 1.0)
        TFOperation(graph=g,
                    name="tf.where",
                    inputs=(op.input, ones, zeros),
                    outputs=op.outputs)
    elif src_dtype.startswith("float") and dst_dtype == "bool":
        zeros = filled(src_dtype, 0.0)
        TFOperation(graph=g,
                    name="tf.not_equal",
                    inputs=(op.input, zeros),
                    outputs=op.outputs)
    else:
        print("Possibly unsupported tf.cast: {} -> {}".format(
            src_dtype, dst_dtype))
        return

    g.remove_operation(op, unlink=True)
Beispiel #11
0
def transform_separate_inputs_and_outputs(tf_graph):
    # type: (TFGraph)->None
    """Ensure no tensor is both a graph input and a graph output.

    For every tensor found in both ``tf_graph.inputs`` and
    ``tf_graph.outputs`` a ``tf.identity`` op is inserted and its output
    tensor takes the original tensor's place among the graph outputs.
    """
    # Iterate over a snapshot: new tensors are created inside the loop.
    for tensor in list(tf_graph.tensors):
        if tensor not in tf_graph.inputs or tensor not in tf_graph.outputs:
            continue

        separated = TFTensor(graph=tf_graph,
                             name=None,
                             shape=list(tensor.shape),
                             dtype=tensor.dtype,
                             data=copy.copy(tensor.data))
        TFOperation(graph=tf_graph,
                    name="tf.identity",
                    inputs=tensor,
                    outputs=separated)
        _replace_tensor_in_outputs(tf_graph, tensor, separated)
Beispiel #12
0
def transform_relu_grad(g, op):
    # type: (TFGraph, TFOperation)->None
    """Expand a relu-gradient op into a greater and a where op.

    The emitted sub-graph evaluates ``where(features > 0, gradients, 0.0)``
    where ``gradients, features`` are the two inputs of ``op``. The
    ``tf.where`` writes into ``op.outputs`` and ``op`` is removed from ``g``.
    """
    gradients, features = op.inputs
    out_tensor = op.output

    # The zero constant has the full output shape and serves both as the
    # comparison operand and as the "else" branch of the where.
    zero = TFTensor(graph=g,
                    shape=list(out_tensor.shape),
                    dtype=out_tensor.dtype,
                    data=0.0)
    positive_mask = TFTensor(graph=g,
                             name=None,
                             shape=list(out_tensor.shape),
                             dtype="bool")

    is_positive = TFOperation(graph=g,
                              name="tf.greater",
                              inputs=(features, zero),
                              outputs=positive_mask).output
    TFOperation(graph=g,
                name="tf.where",
                inputs=(is_positive, gradients, zero),
                outputs=op.outputs)

    g.remove_operation(op, unlink=True)
Beispiel #13
0
def transform_strided_slice(g, op):
    # type: (TFGraph, TFOperation)->None
    """Replace a strided-slice op by ``tf.slice`` plus an optional reshape.

    Only unit strides are accepted (asserted). The begin/end/mask attributes
    are decomposed via ``shape_inference.decompose_strided_slice``; when the
    returned reshape shape differs from the slice shape a ``tf.reshape`` is
    appended after the slice. The last emitted op writes into ``op.outputs``
    and ``op`` is removed from ``g``.
    """
    attribs = op.attribs
    strides = attribs["strides"]

    assert strides is None or all(s == 1 for s in strides), \
        "Only strides=1 is supported for tf.strided_slice, got: {}".format(strides)

    if strides is None:
        # No strides given: use unit stride for each begin entry.
        strides = [1] * len(attribs["begin"])

    ssl_begin, ssl_end, ssl_stride, ssl_shape, reshape_shape = shape_inference.decompose_strided_slice(
        input=op.input.shape,
        begin=attribs['begin'],
        end=attribs['end'],
        stride=strides,
        ellipsis_mask=attribs['ellipsis_mask'],
        new_axis_mask=attribs['new_axis_mask'],
        shrink_axis_mask=attribs['shrink_axis_mask'],
        begin_mask=attribs['begin_mask'],
        end_mask=attribs['end_mask'])

    assert all(stride == 1 for stride in ssl_stride)

    slice_size = [end - begin for begin, end in zip(ssl_begin, ssl_end)]
    needs_reshape = reshape_shape != ssl_shape

    # With a trailing reshape the slice lands in an intermediate tensor,
    # otherwise it writes straight into the original outputs.
    if needs_reshape:
        slice_target = TFTensor(graph=g,
                                shape=ssl_shape,
                                dtype=op.output.dtype)
    else:
        slice_target = op.outputs

    TFOperation(graph=g,
                name="tf.slice",
                inputs=op.input,
                attribs=dict(begin=ssl_begin, size=slice_size),
                outputs=slice_target)

    if needs_reshape:
        TFOperation(graph=g,
                    name="tf.reshape",
                    inputs=slice_target,
                    attribs=dict(shape=reshape_shape),
                    outputs=op.outputs)

    g.remove_operation(op, unlink=True)
Beispiel #14
0
def transform_min_or_max_grad(g, op):
    # type: (TFGraph, TFOperation)->None
    """Expand a reduce-min/max gradient op into primitive ops.

    Mirrors the reference formulation::

        def _MinOrMaxGrad(input, axes, y, grad):
            output_shape_kept_dims = reduced_shape(input.shape, axes)
            y = reshape(y, output_shape_kept_dims)
            grad = reshape(grad, output_shape_kept_dims)
            equal = math_ops.equal(y, input)
            indicators = cast(equal, tf.float32)
            num_selected = reshape(reduce_sum(indicators, axes),
                                   output_shape_kept_dims)
            return indicators / num_selected * grad

    The three inputs of ``op`` are ``(input, y, grad)`` and the reduction axes
    come from ``op.attribs["orig_axis"]``. The final multiply writes into
    ``op.outputs`` and ``op`` is removed from ``g``.

    Compared with the previous revision, two declared tensor shapes are fixed:
    the ``tf.equal`` output now gets its own copy of the input shape (it used
    to alias the input tensor's shape list), and the ``tf.divide`` output is
    declared with the input shape, since it divides an input-shaped tensor by
    a kept-dims-shaped one.
    """
    source, y, grad = op.inputs

    axes = _nonneg_axes(op.attribs["orig_axis"],
                        source.rank,
                        none_means_all=True)
    output_shape_kept_dims = _reduced_shape(source.shape, axes)

    def kept_dims_tensor():
        # Intermediate tensor with the kept-dims shape and the output dtype.
        return TFTensor(graph=g,
                        shape=list(output_shape_kept_dims),
                        dtype=op.output.dtype)

    reshape0 = TFOperation(graph=g,
                           name="tf.reshape",
                           inputs=y,
                           attribs=dict(shape=list(output_shape_kept_dims)),
                           outputs=kept_dims_tensor())
    reshape1 = TFOperation(graph=g,
                           name="tf.reshape",
                           inputs=grad,
                           attribs=dict(shape=list(output_shape_kept_dims)),
                           outputs=kept_dims_tensor())
    equal = TFOperation(graph=g,
                        name="tf.equal",
                        inputs=(reshape0.output, source),
                        outputs=TFTensor(graph=g,
                                         # Copy: do not share the list object
                                         # with the input tensor.
                                         shape=list(source.shape),
                                         dtype="bool"))

    # The bool -> float cast of the indicator mask is expressed as
    # where(mask, 1, 0).
    const0 = TFTensor(graph=g,
                      shape=list(equal.output.shape),
                      dtype=op.output.dtype,
                      data=0.0)
    const1 = TFTensor(graph=g,
                      shape=list(equal.output.shape),
                      dtype=op.output.dtype,
                      data=1.0)
    where = TFOperation(graph=g,
                        name="tf.where",
                        inputs=(equal.output, const1, const0),
                        outputs=TFTensor(graph=g,
                                         shape=list(equal.output.shape),
                                         dtype=op.output.dtype))

    # NOTE(review): keepdims=False while the declared output shape is the
    # kept-dims shape; the following reshape restores the kept-dims shape
    # either way. Left unchanged - confirm whether the declared shape here
    # should be the fully reduced one.
    reduce = TFOperation(graph=g,
                         name="tf.reduce_sum",
                         inputs=where.output,
                         attribs=dict(axis=axes, keepdims=False),
                         outputs=kept_dims_tensor())
    reshape2 = TFOperation(graph=g,
                           name="tf.reshape",
                           inputs=reduce.output,
                           attribs=dict(shape=list(output_shape_kept_dims)),
                           outputs=kept_dims_tensor())
    div = TFOperation(graph=g,
                      name="tf.divide",
                      inputs=(where.output, reshape2.output),
                      outputs=TFTensor(graph=g,
                                       # Input-shaped / kept-dims-shaped
                                       # broadcasts to the input shape.
                                       shape=list(where.output.shape),
                                       dtype=op.output.dtype))
    TFOperation(graph=g,
                name="tf.multiply",
                inputs=(div.output, reshape1.output),
                outputs=op.outputs)
    g.remove_operation(op, unlink=True)
Beispiel #15
0
def transform_strided_slice_grad(g, op):
    # type: (TFGraph, TFOperation)->None
    """Replace a strided-slice gradient op by an optional reshape plus a pad.

    Only unit strides are accepted (asserted). The slice described by the
    attributes is decomposed via ``shape_inference.decompose_strided_slice``
    against ``op.attribs["shape"]``; the incoming tensor is reshaped to the
    slice shape when that differs from the reshape shape, and then padded
    with zeros. The pad writes into ``op.outputs`` and ``op`` is removed
    from ``g``.
    """

    def is_compatible(s1, s2):
        # A scalar shape matches [1]; otherwise compare element-wise over the
        # common prefix (lengths are deliberately not compared).
        first = list(s1)
        second = list(s2)
        if (first == [] and second == [1]) or (second == [] and first == [1]):
            return True
        return all(not (a != b) for a, b in zip(first, second))

    attribs = op.attribs
    strides = attribs["strides"]

    assert strides is None or all(s == 1 for s in strides), \
        "Only strides=1 is supported for tf.strided_slice, got: {}".format(strides)

    input_shape = attribs["shape"]
    if strides is None:
        # No strides given: use unit stride for each begin entry.
        strides = [1] * len(attribs["begin"])

    ssl_begin, ssl_end, ssl_stride, ssl_shape, reshape_shape = shape_inference.decompose_strided_slice(
        input=input_shape,
        begin=attribs['begin'],
        end=attribs['end'],
        stride=strides,
        ellipsis_mask=attribs['ellipsis_mask'],
        new_axis_mask=attribs['new_axis_mask'],
        shrink_axis_mask=attribs['shrink_axis_mask'],
        begin_mask=attribs['begin_mask'],
        end_mask=attribs['end_mask'])

    assert all(stride == 1 for stride in ssl_stride)

    if reshape_shape != ssl_shape:
        assert is_compatible(reshape_shape, op.input.shape), \
            "Shape mismatch in strided_slice_grad {} {}".format(reshape_shape, op.input.shape)

        pad_source = TFOperation(graph=g,
                                 name="tf.reshape",
                                 inputs=op.input,
                                 attribs=dict(shape=list(ssl_shape)),
                                 outputs=TFTensor(graph=g,
                                                  shape=list(ssl_shape),
                                                  dtype=op.output.dtype)).output
    else:
        pad_source = op.input

    # Pad by the slice start in front and by the remaining extent behind.
    paddings = [[begin, size - end]
                for begin, end, size in zip(ssl_begin, ssl_end, input_shape)]
    TFOperation(graph=g,
                name="tf.pad",
                inputs=pad_source,
                attribs=dict(paddings=paddings,
                             mode="CONSTANT",
                             constant_values=0),
                outputs=op.outputs)

    g.remove_operation(op, unlink=True)
Beispiel #16
0
def transform_lrn_grad(g, op):
    # type: (TFGraph, TFOperation)->None
    """Expand an LRN-gradient op into a chain of primitive ops.

    ``op.inputs`` are ``(input_grads, input_image, output_image)``; the
    attributes ``depth_radius``, ``bias``, ``alpha`` and ``beta`` are read
    from ``op.attribs``. The emitted chain is built from square, transpose,
    pad, ``_avg_pool``, ``_avg_pool_grad``, slice, pow, divide, multiply,
    negative and add ops; the final add writes into ``op.outputs`` and ``op``
    is removed from ``g``. ``output_image`` is unpacked but not used.
    """
    input_grads, input_image, output_image = op.inputs

    depth_radius = int(op.attribs["depth_radius"])
    bias = op.attribs["bias"]
    alpha = op.attribs["alpha"]
    beta = op.attribs["beta"]

    input_shape = input_image.shape
    input_dtype = input_image.dtype
    # Shapes after swapping the last two axes, without and with the padding
    # of depth_radius on both sides of the moved axis.
    swapped_shape = input_shape[:-2] + [input_shape[-1], input_shape[-2]]
    swapped_padded_shape = input_shape[:-2] + [
        input_shape[-1] + 2 * depth_radius, input_shape[-2]
    ]
    window = 2 * depth_radius + 1

    def scalar(value):
        # Rank-0 constant carrying the input dtype.
        return TFTensor(graph=g, shape=[], dtype=input_dtype, data=value)

    def emit(op_name, op_inputs, out_shape, **op_kwargs):
        # Create the result tensor first, then the op; return the result.
        result = TFTensor(graph=g, shape=list(out_shape), dtype=input_dtype)
        return TFOperation(graph=g,
                           name=op_name,
                           inputs=op_inputs,
                           outputs=result,
                           **op_kwargs).output

    t_depth_size = scalar(2.0 * depth_radius + 1.0)
    t_alpha = scalar(alpha)
    t_bias = scalar(bias)
    t_beta = scalar(beta)
    t_beta_minus_1 = scalar(beta - 1.0)
    two = scalar(2.0)

    # --- windowed sum of squares, then (bias + alpha * sum) ** beta ---
    squared = emit("tf.square", input_image, input_shape)
    squared_t = emit("tf.transpose", squared, swapped_shape,
                     attribs=dict(perm=[0, 1, 3, 2], conjugate=False))
    squared_padded = TFOperation(
        graph=g,
        name="tf.pad",
        inputs=squared_t,
        attribs=dict(mode="CONSTANT",
                     paddings=[(0, 0), (0, 0),
                               (depth_radius, depth_radius), (0, 0)],
                     constant_values=0),
        outputs=TFTensor(graph=g,
                         name=None,
                         shape=list(swapped_padded_shape),
                         dtype=input_dtype)).output
    window_mean = emit("_avg_pool", squared_padded, swapped_shape,
                       attribs=dict(padding="VALID",
                                    size=[1, 1, window, 1],
                                    stride=[1, 1, 1, 1],
                                    data_format="NHWC"))
    # Average times window size gives the sum over the window.
    window_sum_t = emit("tf.multiply", (t_depth_size, window_mean),
                        swapped_shape)
    window_sum = emit("tf.transpose", window_sum_t, input_shape,
                      attribs=dict(perm=[0, 1, 3, 2], conjugate=False))
    scaled_sum = emit("tf.multiply", (t_alpha, window_sum), input_shape)
    base = emit("tf.add", (t_bias, scaled_sum), input_shape)
    base_pow_beta = emit("tf.pow", (base, t_beta), input_shape)

    # --- first summand: input_grads / base ** beta ---
    direct_term = emit("tf.divide", (input_grads, base_pow_beta), input_shape)

    # --- second summand, routed back through the pooling window ---
    neg_image = emit("tf.negative", input_image, input_shape)
    neg_over_pow = emit("tf.divide", (neg_image, base_pow_beta), input_shape)
    neg_over_pow_sq = emit("tf.divide", (neg_over_pow, base_pow_beta),
                           input_shape)
    grads_scaled = emit("tf.multiply", (input_grads, neg_over_pow_sq),
                        input_shape)
    grads_beta = emit("tf.multiply", (grads_scaled, t_beta), input_shape)
    base_pow_beta_m1 = emit("tf.pow", (base, t_beta_minus_1), input_shape)
    chained = emit("tf.multiply", (grads_beta, base_pow_beta_m1), input_shape)
    chained_alpha = emit("tf.multiply", (t_alpha, chained), input_shape)
    chained_t = emit("tf.transpose", chained_alpha, swapped_shape,
                     attribs=dict(conjugate=False, perm=[0, 1, 3, 2]))
    chained_sized = emit("tf.multiply", (t_depth_size, chained_t),
                         swapped_shape)
    pooled_back = emit("_avg_pool_grad", chained_sized, swapped_padded_shape,
                       attribs=dict(
                           padding="VALID",
                           size=[1, 1, int(window), 1],
                           stride=[1, 1, 1, 1],
                           orig_input_shape=list(swapped_padded_shape),
                           data_format="NHWC"))
    # Drop the padding that was added before the pooling.
    unpadded = emit("tf.slice", pooled_back, swapped_shape,
                    attribs=dict(begin=[0, 0, depth_radius, 0],
                                 size=[-1, -1, input_shape[-1], -1]))
    unpadded_t = emit("tf.transpose", unpadded, input_shape,
                      attribs=dict(conjugate=False, perm=[0, 1, 3, 2]))
    doubled_image = emit("tf.multiply", (input_image, two), input_shape)
    window_term = emit("tf.multiply", (unpadded_t, doubled_image),
                       input_shape)

    TFOperation(graph=g,
                name="tf.add",
                inputs=(direct_term, window_term),
                outputs=op.outputs)
    g.remove_operation(op, unlink=True)
Beispiel #17
0
def transform_fused_batch_norm(g, op):
    # type: (TFGraph, TFOperation)->None
    """Decompose a fused batch-norm operation into primitive TF operations.

    In training mode the batch statistics are computed with ``tf.nn.moments``
    over every axis except the channel axis. When variance correction is
    enabled, the variance is additionally scaled by ``n / (n - 1)``, where
    ``n`` is the number of elements reduced per channel. In inference mode
    the mean and variance inputs are used directly and the consumers of the
    statistic outputs are redirected to them.

    In both modes a ``tf.nn.batch_normalization`` operation writing to the
    original first output is created, the original operation is removed from
    the graph, and the output tensors that were redirected are removed.

    Args:
        g: Graph that owns ``op``; new tensors and operations are added to it.
        op: Operation to replace. Reads inputs 0-2 (input, scale, offset) and,
            when not training, inputs 3-4 (mean, variance); reads the
            ``epsilon``, ``data_format`` and ``is_training`` attributes;
            outputs 0-2 are required and outputs 3-4 are handled if present.
    """
    VARIANCE_CORRECTION_ENABLED = True

    in_input = op.inputs[0]
    in_scale = op.inputs[1]
    in_offset = op.inputs[2]

    epsilon = op.attribs["epsilon"]

    out_y = op.outputs[0]
    out_batch_mean = op.outputs[1]
    out_batch_var = op.outputs[2]

    # A missing/empty data format falls back to NHWC.
    data_format = (op.attribs["data_format"] or "NHWC").upper()
    channel_dim = 1 if data_format == "NCHW" else in_input.rank - 1
    tensors_to_remove = []

    if op.attribs["is_training"]:
        mean, variance = out_batch_mean, out_batch_var
        reduce_axes = utils.without(range(in_input.rank), channel_dim)

        if VARIANCE_CORRECTION_ENABLED:
            # Number of elements averaged per channel: total element count
            # divided by the SIZE of the channel axis (not its index).
            # A zero/unknown channel size falls back to 1, which makes the
            # correction factor below 1.0 instead of dividing by zero.
            channel_count = in_input.shape[channel_dim]
            rest_count = (in_input.count // channel_count
                          if channel_count else 1)

            biased_batch_var = TFTensor(graph=g,
                                        shape=list(out_batch_var.shape),
                                        dtype=out_batch_var.dtype)
            # n / (n - 1); the max() keeps the factor finite for n <= 1.
            correction = TFTensor(graph=g,
                                  shape=[],
                                  dtype=in_input.dtype,
                                  data=(float(rest_count)
                                        / max(rest_count - 1, 1)))
            TFOperation(graph=g,
                        name="tf.nn.moments",
                        inputs=in_input,
                        attribs=dict(axes=reduce_axes, keep_dims=False),
                        outputs=(out_batch_mean, biased_batch_var))
            TFOperation(graph=g,
                        name="tf.multiply",
                        inputs=(biased_batch_var, correction),
                        outputs=out_batch_var)
        else:
            TFOperation(graph=g,
                        name="tf.nn.moments",
                        inputs=in_input,
                        attribs=dict(axes=reduce_axes, keep_dims=False),
                        outputs=(out_batch_mean, out_batch_var))
    else:
        # Inference: the statistics are inputs, so the statistic outputs are
        # just aliases of them and can be dropped.
        mean = op.inputs[3]
        variance = op.inputs[4]
        graph_utils.replace_tensor_in_consumers(g, out_batch_mean, mean)
        graph_utils.replace_tensor_in_consumers(g, out_batch_var, variance)
        tensors_to_remove += [out_batch_mean, out_batch_var]

    # NOTE(review): in training mode with correction enabled, the corrected
    # variance is what gets fed to the normalization below (as before this
    # change) - confirm this matches the intended fused batch-norm semantics.
    TFOperation(graph=g,
                name="tf.nn.batch_normalization",
                inputs=(in_input, mean, variance, in_offset, in_scale),
                attribs=dict(variance_epsilon=epsilon,
                             _data_format=data_format),
                outputs=out_y)

    if len(op.outputs) > 3:  # This can happen in gradients
        out_saved_mean = op.outputs[3]
        out_saved_var = op.outputs[4]
        graph_utils.replace_tensor_in_consumers(g, out_saved_mean, mean)
        graph_utils.replace_tensor_in_consumers(g, out_saved_var, variance)
        tensors_to_remove += [out_saved_mean, out_saved_var]

    g.remove_operation(op, unlink=True)
    for t in tensors_to_remove:
        g.remove_tensor(t)
Beispiel #18
0
 def create_tensor(self, graph, name, shape, dtype):
     """Create and return a TFTensor in `graph` with the given name, shape and dtype."""
     tensor = TFTensor(graph=graph,
                       name=name,
                       shape=shape,
                       dtype=dtype)
     return tensor