def gather(data, indicies, axis=0):
    """Gather entries of `data` selected by `indicies`, via ``op.gather``.

    Only ``axis == 0`` is supported; any other axis is rejected up front.
    INT64 index tensors are narrowed to INT32 before being handed to
    ``op.gather``.

    Args:
        data: The tensor to gather from.
        indicies: The index tensor (the misspelled name is kept for caller
            compatibility).
        axis: Gather axis; must be 0.

    Returns:
        A 1-tuple holding the gathered tensor.

    Raises:
        NotImplementedError: If `axis` is not 0.
    """
    if axis != 0:
        raise NotImplementedError(
            'Gather with a non-zero axis is not yet implemented by the PlaidML ONNX backend'
        )
    # TODO: Long-term, it'd be a fine thing to have PlaidML accept
    # an int64 indicies input to gather().
    is_int64 = indicies.shape.dtype == plaidml.DType.INT64
    index_tensor = op.cast(indicies, plaidml.DType.INT32) if is_int64 else indicies
    return (op.gather(data, index_tensor),)
def cast(x, to):
    """Cast `x` to the PlaidML dtype named by the ONNX ``to`` attribute.

    Args:
        x: The tensor value to cast.
        to: The name of a ``TensorProto.DataType`` member (e.g. ``b'FLOAT'``).
            Accepted either as bytes (decoded as UTF-8) or as an
            already-decoded ``str``.

    Returns:
        A 1-tuple holding the cast tensor.

    Raises:
        ValueError: From protobuf's enum ``Value`` lookup if `to` does not
            name a known ``TensorProto.DataType`` member.
    """
    # The original contract takes bytes (hence the decode); a str used to
    # crash with AttributeError on .decode(), so accept both forms.
    if isinstance(to, (bytes, bytearray)):
        type_name = to.decode('utf-8')
    else:
        type_name = to
    onnx_dtype = onnx_pb2.TensorProto.DataType.Value(type_name)
    # NOTE(review): presumably fails with a lookup error if PlaidML has no
    # dtype mapped for this ONNX type -- confirm against opset_util.
    dtype = opset_util.ONNX_DTYPE_TO_PLAIDML[onnx_dtype]
    return (op.cast(x, dtype),)
def cast(unused_ctx, x, to):
    """Cast `x` to the PlaidML dtype that corresponds to ONNX dtype `to`.

    Args:
        unused_ctx: Operator context; not used here.
        x: The tensor value to cast.
        to: An ONNX dtype key into ``opset_util.ONNX_DTYPE_TO_PLAIDML``.

    Returns:
        A 1-tuple holding the cast tensor.
    """
    target_dtype = opset_util.ONNX_DTYPE_TO_PLAIDML[to]
    converted = op.cast(x, target_dtype)
    return (converted,)
def size(unused_ctx, value):
    """Return the element count of `value` as an INT64 tensor.

    The count is the product of the dimensions reported by ``op.shape_of``.

    Args:
        unused_ctx: Operator context; not used here.
        value: The tensor whose size is requested.

    Returns:
        A 1-tuple holding the INT64 element count.
    """
    dims = op.shape_of(value)
    element_count = op.prod(dims)
    return (op.cast(element_count, plaidml.DType.INT64),)
def shape(unused_ctx, data):
    """Return the shape of `data` (from ``op.shape_of``) cast to INT64.

    Args:
        unused_ctx: Operator context; not used here.
        data: The tensor whose shape is requested.

    Returns:
        A 1-tuple holding the INT64 shape tensor.
    """
    dims = op.shape_of(data)
    return (op.cast(dims, plaidml.DType.INT64),)
def gather(data, indicies):
    """Gather entries of `data` selected by `indicies`, via ``op.gather``.

    INT64 index tensors are narrowed to INT32 before being handed to
    ``op.gather``.

    Args:
        data: The tensor to gather from.
        indicies: The index tensor (the misspelled name is kept for caller
            compatibility).

    Returns:
        A 1-tuple holding the gathered tensor.
    """
    index_tensor = indicies
    if index_tensor.shape.dtype == plaidml.DType.INT64:
        # TODO: Long-term, it'd be a fine thing to have PlaidML accept
        # an int64 indicies input to gather().
        index_tensor = op.cast(index_tensor, plaidml.DType.INT32)
    return (op.gather(data, index_tensor),)