示例#1
0
 def mean(*tensors):
     """Return a 1-tuple holding the mean of the given tensors.

     A single tensor is passed through unchanged; two or more are
     combined via the Mean operation.

     Raises:
         tile.LogicError: if called with no tensors at all.
     """
     count = len(tensors)
     if not count:
         raise tile.LogicError(
             'Must supply at least one tensor in a Mean operation')
     if count == 1:
         return (tensors[0], )
     return (Mean.function(tensors), )
示例#2
0
 def pow(data, exponent, axis=None, broadcast=None):
     """Raise ``data`` to ``exponent``, returned as a 1-tuple.

     When broadcasting with an explicit ``axis``, ``exponent`` is padded
     with trailing singleton dimensions so that its dims line up with
     ``data``'s dims starting at ``axis``.

     Raises:
         tile.LogicError: if broadcasting is off and the shapes differ.
     """
     if not broadcast and data.shape.dims != exponent.shape.dims:
         raise tile.LogicError('Incompatible shapes in power')
     if broadcast and (axis is not None):
         trailing = data.shape.ndims - exponent.shape.ndims - axis
         padded_dims = list(exponent.shape.dims) + [1] * trailing
         exponent = op.reshape(exponent, padded_dims)
     return (op.pow(data, exponent), )
示例#3
0
 def mul(a, b, axis=None, broadcast=None):
     """Multiply ``a`` by ``b``, returned as a 1-tuple.

     When broadcasting with an explicit ``axis``, ``b`` is padded with
     trailing singleton dimensions so it aligns with ``a`` at ``axis``.

     Raises:
         tile.LogicError: if broadcasting is off and the shapes differ.
     """
     if not broadcast and a.shape.dims != b.shape.dims:
         raise tile.LogicError('Incompatible shapes in multiplication')
     if broadcast and (axis is not None):
         trailing = a.shape.ndims - b.shape.ndims - axis
         b = op.reshape(b, list(b.shape.dims) + [1] * trailing)
     return (a * b, )
示例#4
0
 def div(unused_ctx, a, b, axis=None, broadcast=None):
     """Divide ``a`` by ``b``, returned as a 1-tuple.

     ``unused_ctx`` is accepted for signature compatibility and ignored.
     When broadcasting with an explicit ``axis``, ``b`` is padded with
     trailing singleton dimensions so it aligns with ``a`` at ``axis``.

     Raises:
         tile.LogicError: if broadcasting is off and the shapes differ.
     """
     if not broadcast and a.shape.dims != b.shape.dims:
         raise tile.LogicError('Incompatible shapes in division')
     if broadcast and (axis is not None):
         trailing = a.shape.ndims - b.shape.ndims - axis
         b = op.reshape(b, list(b.shape.dims) + [1] * trailing)
     return (a / b, )
示例#5
0
 def equal(a, b, axis=None, broadcast=None):
     """Element-wise equality of ``a`` and ``b``, returned as a 1-tuple.

     When broadcasting with an explicit ``axis``, ``b`` is padded with
     trailing singleton dimensions so it aligns with ``a`` at ``axis``.

     Raises:
         tile.LogicError: if broadcasting is off and the shapes differ.
     """
     if not broadcast and a.shape.dims != b.shape.dims:
         raise tile.LogicError('Incompatible shapes in equal')
     if broadcast and (axis is not None):
         trailing = a.shape.ndims - b.shape.ndims - axis
         b = op.reshape(b, list(b.shape.dims) + [1] * trailing)
     return (op.equal(a, b), )
示例#6
0
 def __init__(self, x):
     """Build a logical-NOT operation over a boolean tensor.

     Args:
         x: Input tensor; its dtype must be plaidml.DType.BOOLEAN.

     Raises:
         tile.LogicError: if ``x`` is not a boolean tensor.
     """
     if x.shape.dtype != plaidml.DType.BOOLEAN:
         raise tile.LogicError(
             'Logical Not requires a boolean tensor input')
     # NOT is expressed in Tile as comparing the input against zero;
     # the output shape is the same as the input shape.
     super(Not, self).__init__(
         """
         function (I) -> (O) {
             O = cmp_eq(I, 0);
         } """, [('I', x)], [('O', x.shape)])
示例#7
0
 def xor(a, b, axis=None, broadcast=None):
     """Element-wise XOR of ``a`` and ``b``, returned as a 1-tuple.

     When broadcasting with an explicit ``axis``, ``b`` is padded with
     trailing singleton dimensions so it aligns with ``a`` at ``axis``.

     Args:
         a: Left-hand tensor.
         b: Right-hand tensor (reshaped when broadcasting with an axis).
         axis: Axis of ``a`` at which ``b``'s dimensions start.
         broadcast: Truthy to allow differing shapes.

     Raises:
         tile.LogicError: if broadcasting is off and the shapes differ.
     """
     if not broadcast and a.shape.dims != b.shape.dims:
         # Fixed: this message previously said 'subtraction'.
         raise tile.LogicError('Incompatible shapes in xor')
     # Fixed: test 'axis is not None' (as every sibling op does) so the
     # broadcast reshape is not silently skipped when axis == 0.
     if broadcast and (axis is not None):
         b = op.reshape(
             b,
             list(b.shape.dims) + ([1] *
                                   (a.shape.ndims - b.shape.ndims - axis)))
     return (a ^ b, )
示例#8
0
 def greater(a, b, axis=None, broadcast=None):
     """Element-wise ``a > b`` comparison, returned as a 1-tuple.

     When broadcasting with an explicit ``axis``, ``b`` is padded with
     trailing singleton dimensions so it aligns with ``a`` at ``axis``.

     Raises:
         tile.LogicError: if broadcasting is off and the shapes differ.
     """
     if not broadcast and a.shape.dims != b.shape.dims:
         raise tile.LogicError(
             'Incompatible shapes in logical > comparison')
     if broadcast and (axis is not None):
         trailing = a.shape.ndims - b.shape.ndims - axis
         b = op.reshape(b, list(b.shape.dims) + [1] * trailing)
     return (a > b, )
示例#9
0
 def gemm_reshape(value):
     """Collapse a tensor to two dimensions for use as a Gemm operand.

     A rank-2 tensor is returned unchanged; higher ranks keep the first
     dimension and fold all remaining dimensions into one.

     Raises:
         tile.LogicError: if the tensor has fewer than two dimensions.
     """
     ndims = value.shape.ndims
     if ndims < 2:
         raise tile.LogicError(
             'Invalid Gemm input; two-dimensions required, got: {}'.
             format(value.shape))
     if ndims == 2:
         return value
     folded = functools.reduce(lambda x, y: x * y, value.shape.dims[1:])
     return op.reshape(value, (value.shape.dims[0], folded))
示例#10
0
 def or_op(a, b, axis=None, broadcast=None):
     """Logical OR of ``a`` and ``b``, returned as a 1-tuple.

     The result is a boolean tensor. When broadcasting with an explicit
     ``axis``, ``b`` is padded with trailing singleton dimensions so it
     aligns with ``a`` at ``axis``.

     Raises:
         tile.LogicError: if broadcasting is off and the shapes differ.
     """
     if not broadcast and a.shape.dims != b.shape.dims:
         raise tile.LogicError('Incompatible shapes in logical or')
     if broadcast and (axis is not None):
         trailing = a.shape.ndims - b.shape.ndims - axis
         b = op.reshape(b, list(b.shape.dims) + [1] * trailing)
     result = tile.binary_op(a,
                             b,
                             'L ? 1 : R',
                             dtype=plaidml.DType.BOOLEAN,
                             name='Or')
     return (result, )
示例#11
0
 def and_op(unused_ctx, a, b, axis=None, broadcast=None):
     """Logical AND of ``a`` and ``b``, returned as a 1-tuple.

     ``unused_ctx`` is accepted for signature compatibility and ignored.
     The result is a boolean tensor. When broadcasting with an explicit
     ``axis``, ``b`` is padded with trailing singleton dimensions so it
     aligns with ``a`` at ``axis``.

     Raises:
         tile.LogicError: if broadcasting is off and the shapes differ.
     """
     if not broadcast and a.shape.dims != b.shape.dims:
         raise tile.LogicError('Incompatible shapes in logical and')
     if broadcast and (axis is not None):
         trailing = a.shape.ndims - b.shape.ndims - axis
         b = op.reshape(b, list(b.shape.dims) + [1] * trailing)
     result = tile.binary_op(a,
                             b,
                             'cmp_eq(L ? R : 0, 1)',
                             dtype=plaidml.DType.BOOLEAN,
                             name='And')
     return (result, )
示例#12
0
    def pad(data, mode=None, pads=None, value=None):
        """Pad a tensor, returned as a 1-tuple.

        Args:
            data: The tensor to pad.
            mode: Padding mode name; a falsy value means 'constant'.
                Other modes are looked up in _CONV_PADDING_MODE.
            pads: Per-edge pad amounts; must contain exactly
                2 * data.shape.ndims entries.
            value: Fill value, forwarded to the padding implementation.

        Raises:
            ValueError: if ``mode`` names an unsupported padding mode.
            tile.LogicError: if ``pads`` is missing or has the wrong length.
        """
        if not mode:
            mode = 'constant'
            padding_mode = PadConstant
        else:
            try:
                padding_mode = _CONV_PADDING_MODE[mode]
            except KeyError:
                # Suppress the KeyError context; the ValueError says it all.
                six.raise_from(
                    ValueError('Unsupported padding mode: {}'.format(mode)),
                    None)
        if not pads or len(pads) != 2 * data.shape.ndims:
            # Fixed: error message previously misspelled 'Inconsistent'.
            raise tile.LogicError(
                'Inconsistent padding request; rank={}, #pads={}'.format(
                    data.shape.ndims,
                    len(pads) if pads else 0))

        return (padding_mode.function(data, pads=pads, mode=mode,
                                      value=value), )
示例#13
0
 def reshape(data, shape=None):
     """Reshape ``data`` to ``shape``, returned as a 1-tuple.

     Raises:
         tile.LogicError: if no target shape is supplied.
     """
     if shape:
         return (op.reshape(data, shape), )
     raise tile.LogicError('Reshape requires a target shape')