Esempio n. 1
0
    def instance_normalization(unused_ctx, value, scale, bias, epsilon=1e-5):
        """Implements ONNX InstanceNormalization.

        Normalizes each channel of `value` over its spatial axes, then
        applies the per-channel `scale` and `bias`.
        """
        # Per-channel broadcast shape: [C, 1, 1, ...].
        channel_shape = [value.shape.dims[1]]
        channel_shape += [1] * (value.shape.ndims - 2)
        scale = op.reshape(scale, channel_shape)
        bias = op.reshape(bias, channel_shape)

        # Statistics are taken over every axis past batch and channel.
        spatial_axes = list(range(2, value.shape.ndims))
        mean = op.mean(value, axes=spatial_axes, keepdims=True)
        variance = op.variance(value, axes=spatial_axes, keepdims=True)

        denom = op.sqrt(variance + epsilon)
        return (((value - mean) * scale / denom) + bias, )
Esempio n. 2
0
 def pow(data, exponent, axis=None, broadcast=None):
     """Implements the ONNX Pow operator (data ** exponent).

     Raises:
         tile.LogicError: if broadcasting is disabled and shapes differ.
     """
     if not broadcast and data.shape.dims != exponent.shape.dims:
         raise tile.LogicError('Incompatible shapes in power')
     if broadcast and axis is not None:
         # Append unit dims so exponent aligns with data starting at axis.
         pad = data.shape.ndims - exponent.shape.ndims - axis
         exponent = op.reshape(exponent,
                               list(exponent.shape.dims) + [1] * pad)
     return (op.pow(data, exponent), )
Esempio n. 3
0
 def mul(a, b, axis=None, broadcast=None):
     """Implements the ONNX Mul operator (element-wise product).

     Raises:
         tile.LogicError: if broadcasting is disabled and shapes differ.
     """
     if not broadcast and a.shape.dims != b.shape.dims:
         raise tile.LogicError('Incompatible shapes in multiplication')
     if broadcast and axis is not None:
         # Pad `b` with trailing unit dims so it broadcasts against `a`.
         trailing = a.shape.ndims - b.shape.ndims - axis
         b = op.reshape(b, list(b.shape.dims) + [1] * trailing)
     return (a * b, )
Esempio n. 4
0
 def equal(a, b, axis=None, broadcast=None):
     """Implements the ONNX Equal operator (element-wise comparison).

     Raises:
         tile.LogicError: if broadcasting is disabled and shapes differ.
     """
     if not broadcast and a.shape.dims != b.shape.dims:
         raise tile.LogicError('Incompatible shapes in equal')
     if broadcast and axis is not None:
         # Pad `b` with trailing unit dims so it broadcasts against `a`.
         pad_count = a.shape.ndims - b.shape.ndims - axis
         b = op.reshape(b, list(b.shape.dims) + [1] * pad_count)
     return (op.equal(a, b), )
Esempio n. 5
0
 def div(unused_ctx, a, b, axis=None, broadcast=None):
     """Implements the ONNX Div operator (element-wise quotient).

     Raises:
         tile.LogicError: if broadcasting is disabled and shapes differ.
     """
     if not broadcast and a.shape.dims != b.shape.dims:
         raise tile.LogicError('Incompatible shapes in division')
     if broadcast and axis is not None:
         # Pad `b` with trailing unit dims so it broadcasts against `a`.
         extra = a.shape.ndims - b.shape.ndims - axis
         b = op.reshape(b, list(b.shape.dims) + [1] * extra)
     return (a / b, )
Esempio n. 6
0
 def xor(a, b, axis=None, broadcast=None):
     """Implements the ONNX Xor operator (element-wise logical xor).

     Args:
         a: The left-hand tensor.
         b: The right-hand tensor.
         axis: When broadcasting, the axis of `a` that `b` aligns with.
         broadcast: Truthy to enable ONNX legacy broadcasting.

     Raises:
         tile.LogicError: if broadcasting is disabled and shapes differ.
     """
     if not broadcast and a.shape.dims != b.shape.dims:
         # Fixed copy/paste bug: message previously said 'subtraction'.
         raise tile.LogicError('Incompatible shapes in xor')
     if broadcast and (axis is not None):
         # Fixed: `axis` was tested for truthiness, so axis=0 silently
         # skipped the reshape; siblings (mul, div, ...) use `is not None`.
         b = op.reshape(
             b,
             list(b.shape.dims) + ([1] *
                                   (a.shape.ndims - b.shape.ndims - axis)))
     return (a ^ b, )
Esempio n. 7
0
 def greater(a, b, axis=None, broadcast=None):
     """Implements the ONNX Greater operator (element-wise a > b).

     Raises:
         tile.LogicError: if broadcasting is disabled and shapes differ.
     """
     if not broadcast and a.shape.dims != b.shape.dims:
         raise tile.LogicError(
             'Incompatible shapes in logical > comparison')
     if broadcast and axis is not None:
         # Pad `b` with trailing unit dims so it broadcasts against `a`.
         ones = a.shape.ndims - b.shape.ndims - axis
         b = op.reshape(b, list(b.shape.dims) + [1] * ones)
     return (a > b, )
Esempio n. 8
0
 def gemm_reshape(value):
     """Flattens a tensor to 2-D for Gemm by collapsing trailing dims.

     The leading dimension is preserved; all remaining dimensions are
     multiplied together into a single second dimension.

     Raises:
         tile.LogicError: if the input has fewer than two dimensions.
     """
     ndims = value.shape.ndims
     if ndims < 2:
         raise tile.LogicError(
             'Invalid Gemm input; two-dimensions required, got: {}'.
             format(value.shape))
     if ndims == 2:
         # Already the right rank; nothing to do.
         return value
     trailing = functools.reduce(lambda acc, dim: acc * dim,
                                 value.shape.dims[1:])
     return op.reshape(value, (value.shape.dims[0], trailing))
Esempio n. 9
0
    def batch_normalization(value,
                            scale,
                            bias,
                            mean,
                            variance,
                            epsilon=1e-5,
                            is_test=0,
                            momentum=.9,
                            spatial=1,
                            consumed_inputs=None):
        """Implements ONNX BatchNormalization in inference mode.

        Computes (value - mean) * scale / sqrt(variance + epsilon) + bias,
        with the per-channel statistics reshaped so they broadcast over
        the spatial dimensions.

        Raises:
            NotImplementedError: when called in training mode
                (is_test falsy) — statistics updates are not supported.
        """
        if not is_test:
            raise NotImplementedError()

        # Per-channel broadcast shape: [C, 1, 1, ...].
        bcast = [value.shape.dims[1]]
        bcast += [1] * (value.shape.ndims - 2)
        scale = op.reshape(scale, bcast)
        bias = op.reshape(bias, bcast)
        mean = op.reshape(mean, bcast)
        variance = op.reshape(variance, bcast)

        denom = op.sqrt(variance + epsilon)
        return (((value - mean) * scale / denom) + bias, )
Esempio n. 10
0
 def or_op(a, b, axis=None, broadcast=None):
     """Implements the ONNX Or operator (element-wise logical or).

     Raises:
         tile.LogicError: if broadcasting is disabled and shapes differ.
     """
     if not broadcast and a.shape.dims != b.shape.dims:
         raise tile.LogicError('Incompatible shapes in logical or')
     if broadcast and axis is not None:
         # Pad `b` with trailing unit dims so it broadcasts against `a`.
         padding = a.shape.ndims - b.shape.ndims - axis
         b = op.reshape(b, list(b.shape.dims) + [1] * padding)
     # 'L ? 1 : R': true when either operand is true.
     result = tile.binary_op(a,
                             b,
                             'L ? 1 : R',
                             dtype=plaidml.DType.BOOLEAN,
                             name='Or')
     return (result, )
Esempio n. 11
0
 def and_op(unused_ctx, a, b, axis=None, broadcast=None):
     """Implements the ONNX And operator (element-wise logical and).

     Raises:
         tile.LogicError: if broadcasting is disabled and shapes differ.
     """
     if not broadcast and a.shape.dims != b.shape.dims:
         raise tile.LogicError('Incompatible shapes in logical and')
     if broadcast and axis is not None:
         # Pad `b` with trailing unit dims so it broadcasts against `a`.
         fill = a.shape.ndims - b.shape.ndims - axis
         b = op.reshape(b, list(b.shape.dims) + [1] * fill)
     # 'cmp_eq(L ? R : 0, 1)': yields 1 only when both operands are true.
     result = tile.binary_op(a,
                             b,
                             'cmp_eq(L ? R : 0, 1)',
                             dtype=plaidml.DType.BOOLEAN,
                             name='And')
     return (result, )
Esempio n. 12
0
 def convolution(data,
                 kernel,
                 bias=None,
                 auto_pad=None,
                 dilations=None,
                 group=1,
                 kernel_shape=None,
                 pads=None,
                 strides=None):
     """Implements the ONNX Conv operator.

     Runs the convolution and, when a bias tensor is supplied, reshapes
     it to [C, 1, 1, ...] so it broadcasts across the spatial dimensions
     before adding it to the result.
     """
     result = Convolution.function(data,
                                   kernel,
                                   auto_pad=auto_pad,
                                   dilations=dilations,
                                   group=group,
                                   kernel_shape=kernel_shape,
                                   pads=pads,
                                   strides=strides)
     # Fixed: test presence explicitly. `if bias:` relied on object
     # truthiness of a tensor-like value, which is fragile (and breaks
     # if the type ever defines __bool__ or __len__).
     if bias is not None:
         bias = op.reshape(bias, [result.shape.dims[1]] +
                           ([1] * (result.shape.ndims - 2)))
         result += bias
     return (result, )
Esempio n. 13
0
 def reshape(data, shape=None):
     """Implements the ONNX Reshape operator.

     Raises:
         tile.LogicError: when no target shape is supplied.
     """
     if not shape:
         raise tile.LogicError('Reshape requires a target shape')
     return (op.reshape(data, shape), )
Esempio n. 14
0
 def prelu(x, slope):
     """Implements the ONNX PRelu operator (relu with a learned slope)."""
     if slope.shape.ndims == 1 and x.shape.ndims > 2:
         # Pad the per-channel slope with unit dims so it broadcasts
         # across the spatial dimensions of `x`.
         bcast = [slope.shape.dims[0]] + [1] * (x.shape.ndims - 2)
         slope = op.reshape(slope, bcast)
     return (op.relu(x, alpha=slope), )