def instance_normalization(unused_ctx, value, scale, bias, epsilon=1e-5):
    """Normalize each channel of `value` with per-instance statistics.

    Mean and variance are reduced over every axis after the channel axis
    (axes 2..ndims, i.e. an NCHW-style layout is assumed — TODO confirm),
    and the per-channel `scale`/`bias` are reshaped to (C, 1, ..., 1) so
    they broadcast over the spatial dimensions.
    """
    ndims = value.shape.ndims
    spatial_axes = list(range(2, ndims))
    # Channel-wise parameters become (C, 1, ..., 1) for broadcasting.
    param_shape = [value.shape.dims[1]] + [1] * (ndims - 2)
    gamma = op.reshape(scale, param_shape)
    beta = op.reshape(bias, param_shape)
    mu = op.mean(value, axes=spatial_axes, keepdims=True)
    var = op.variance(value, axes=spatial_axes, keepdims=True)
    denom = op.sqrt(var + epsilon)
    return (((value - mu) * gamma / denom) + beta, )
def pow(data, exponent, axis=None, broadcast=None):
    """Element-wise power with optional ONNX-style axis broadcasting."""
    if not broadcast:
        if data.shape.dims != exponent.shape.dims:
            raise tile.LogicError('Incompatible shapes in power')
        return (op.pow(data, exponent), )
    if axis is not None:
        # Append trailing singleton dims so `exponent` lines up with
        # `data` starting at `axis`.
        pad = data.shape.ndims - exponent.shape.ndims - axis
        exponent = op.reshape(exponent, list(exponent.shape.dims) + [1] * pad)
    return (op.pow(data, exponent), )
def mul(a, b, axis=None, broadcast=None):
    """Element-wise multiplication with optional ONNX-style broadcasting."""
    if broadcast:
        if axis is not None:
            # Pad `b` with trailing 1-dims so it aligns with `a` at `axis`.
            pad = a.shape.ndims - b.shape.ndims - axis
            b = op.reshape(b, list(b.shape.dims) + [1] * pad)
    elif a.shape.dims != b.shape.dims:
        raise tile.LogicError('Incompatible shapes in multiplication')
    return (a * b, )
def equal(a, b, axis=None, broadcast=None):
    """Element-wise equality comparison with optional ONNX-style broadcasting."""
    if broadcast:
        if axis is not None:
            # Pad `b` with trailing 1-dims so it aligns with `a` at `axis`.
            pad = a.shape.ndims - b.shape.ndims - axis
            b = op.reshape(b, list(b.shape.dims) + [1] * pad)
    elif a.shape.dims != b.shape.dims:
        raise tile.LogicError('Incompatible shapes in equal')
    return (op.equal(a, b), )
def div(unused_ctx, a, b, axis=None, broadcast=None):
    """Element-wise division with optional ONNX-style broadcasting."""
    if broadcast:
        if axis is not None:
            # Pad `b` with trailing 1-dims so it aligns with `a` at `axis`.
            pad = a.shape.ndims - b.shape.ndims - axis
            b = op.reshape(b, list(b.shape.dims) + [1] * pad)
    elif a.shape.dims != b.shape.dims:
        raise tile.LogicError('Incompatible shapes in division')
    return (a / b, )
def xor(a, b, axis=None, broadcast=None):
    """Element-wise logical xor with optional ONNX-style axis broadcasting.

    Raises tile.LogicError when broadcasting is off and the operand
    shapes disagree.
    """
    if not broadcast and a.shape.dims != b.shape.dims:
        # Was 'Incompatible shapes in subtraction' — a copy-paste slip
        # from the subtraction handler.
        raise tile.LogicError('Incompatible shapes in xor')
    # `axis is not None` (not bare truthiness): the original skipped the
    # broadcast reshape whenever axis == 0, unlike every sibling op here.
    if broadcast and (axis is not None):
        b = op.reshape(
            b,
            list(b.shape.dims) + ([1] * (a.shape.ndims - b.shape.ndims - axis)))
    return (a ^ b, )
def greater(a, b, axis=None, broadcast=None):
    """Element-wise `>` comparison with optional ONNX-style broadcasting."""
    if broadcast:
        if axis is not None:
            # Pad `b` with trailing 1-dims so it aligns with `a` at `axis`.
            pad = a.shape.ndims - b.shape.ndims - axis
            b = op.reshape(b, list(b.shape.dims) + [1] * pad)
    elif a.shape.dims != b.shape.dims:
        raise tile.LogicError(
            'Incompatible shapes in logical > comparison')
    return (a > b, )
def gemm_reshape(value):
    """Collapse `value` to 2-D for Gemm by flattening every trailing dim.

    Raises tile.LogicError for inputs with fewer than two dimensions;
    2-D inputs pass through untouched.
    """
    ndims = value.shape.ndims
    if ndims < 2:
        raise tile.LogicError(
            'Invalid Gemm input; two-dimensions required, got: {}'.format(value.shape))
    if ndims == 2:
        # Already the shape Gemm expects.
        return value
    flattened = functools.reduce(lambda acc, dim: acc * dim, value.shape.dims[1:])
    return op.reshape(value, (value.shape.dims[0], flattened))
def batch_normalization(value, scale, bias, mean, variance, epsilon=1e-5, is_test=0, momentum=.9, spatial=1, consumed_inputs=None):
    """Inference-mode batch normalization over the channel axis.

    `scale`, `bias`, `mean`, and `variance` are per-channel tensors,
    reshaped to (C, 1, ..., 1) so they broadcast across spatial dims.
    `momentum`, `spatial`, and `consumed_inputs` are accepted for ONNX
    attribute compatibility but unused here.
    """
    if not is_test:
        # Training mode (running-statistics updates) is not implemented.
        raise NotImplementedError()
    param_shape = [value.shape.dims[1]] + [1] * (value.shape.ndims - 2)
    gamma = op.reshape(scale, param_shape)
    beta = op.reshape(bias, param_shape)
    mu = op.reshape(mean, param_shape)
    var = op.reshape(variance, param_shape)
    denom = op.sqrt(var + epsilon)
    return (((value - mu) * gamma / denom) + beta, )
def or_op(a, b, axis=None, broadcast=None):
    """Element-wise logical or with optional ONNX-style broadcasting."""
    if broadcast:
        if axis is not None:
            # Pad `b` with trailing 1-dims so it aligns with `a` at `axis`.
            pad = a.shape.ndims - b.shape.ndims - axis
            b = op.reshape(b, list(b.shape.dims) + [1] * pad)
    elif a.shape.dims != b.shape.dims:
        raise tile.LogicError('Incompatible shapes in logical or')
    result = tile.binary_op(a, b, 'L ? 1 : R', dtype=plaidml.DType.BOOLEAN, name='Or')
    return (result, )
def and_op(unused_ctx, a, b, axis=None, broadcast=None):
    """Element-wise logical and with optional ONNX-style broadcasting."""
    if broadcast:
        if axis is not None:
            # Pad `b` with trailing 1-dims so it aligns with `a` at `axis`.
            pad = a.shape.ndims - b.shape.ndims - axis
            b = op.reshape(b, list(b.shape.dims) + [1] * pad)
    elif a.shape.dims != b.shape.dims:
        raise tile.LogicError('Incompatible shapes in logical and')
    result = tile.binary_op(a, b, 'cmp_eq(L ? R : 0, 1)', dtype=plaidml.DType.BOOLEAN, name='And')
    return (result, )
def convolution(data, kernel, bias=None, auto_pad=None, dilations=None, group=1, kernel_shape=None, pads=None, strides=None):
    """ONNX Conv: run the convolution, then add an optional per-channel bias.

    The attribute arguments are forwarded untouched to Convolution.function.
    When `bias` is supplied it is reshaped to (C, 1, ..., 1) so it
    broadcasts across the output's spatial dimensions.
    """
    result = Convolution.function(
        data,
        kernel,
        auto_pad=auto_pad,
        dilations=dilations,
        group=group,
        kernel_shape=kernel_shape,
        pads=pads,
        strides=strides)
    # Explicit None check: the original `if bias:` relied on the truthiness
    # of a symbolic tile tensor, which is fragile for an optional argument
    # whose default is None.
    if bias is not None:
        bias = op.reshape(
            bias, [result.shape.dims[1]] + ([1] * (result.shape.ndims - 2)))
        result += bias
    return (result, )
def reshape(data, shape=None):
    """Reshape `data` to `shape`; a missing or empty shape is an error."""
    if shape:
        return (op.reshape(data, shape), )
    raise tile.LogicError('Reshape requires a target shape')
def prelu(x, slope):
    """PReLU: relu of `x` with a learned negative-slope tensor `slope`."""
    # A 1-D per-channel slope needs trailing singleton dims to broadcast
    # over the spatial axes of a >2-D input.
    if slope.shape.ndims == 1 and x.shape.ndims > 2:
        trailing = x.shape.ndims - 2
        slope = op.reshape(slope, [slope.shape.dims[0]] + [1] * trailing)
    return (op.relu(x, alpha=slope), )