Example #1
0
 def gather(data, indicies, axis=0):
     """Gather from `data` using `indicies`, supporting only axis 0.

     Args:
         data: Value passed through unchanged as the first argument of
             `op.gather`.
         indicies: Index value; its `shape.dtype` is inspected and, when it
             is INT64, the value is cast to INT32 before gathering.
         axis: Axis to gather along. Only 0 is accepted.

     Returns:
         A one-element tuple holding the result of `op.gather`.

     Raises:
         NotImplementedError: If `axis` is anything other than 0.
     """
     if axis != 0:
         message = 'Gather with a non-zero axis is not yet implemented by the PlaidML ONNX backend'
         raise NotImplementedError(message)
     # TODO: Ideally PlaidML's gather() would take int64 indices directly;
     # until then they are narrowed to int32 here.
     is_int64 = indicies.shape.dtype == plaidml.DType.INT64
     narrowed = op.cast(indicies, plaidml.DType.INT32) if is_int64 else indicies
     gathered = op.gather(data, narrowed)
     return (gathered, )
Example #2
0
 def cast(x, to):
     """Cast `x` to the PlaidML dtype named by the encoded ONNX type `to`.

     Args:
         x: Value handed to `op.cast`.
         to: Encoded ONNX data-type name; it is decoded as UTF-8 and
             resolved through `onnx_pb2.TensorProto.DataType.Value`.

     Returns:
         A one-element tuple holding the result of `op.cast`.
     """
     dtype_table = opset_util.ONNX_DTYPE_TO_PLAIDML
     onnx_dtype = onnx_pb2.TensorProto.DataType.Value(to.decode('utf-8'))
     target_dtype = dtype_table[onnx_dtype]
     converted = op.cast(x, target_dtype)
     return (converted, )
Example #3
0
 def cast(unused_ctx, x, to):
     """Cast `x` to the PlaidML dtype mapped from the ONNX dtype key `to`.

     Args:
         unused_ctx: Not used by this function.
         x: Value handed to `op.cast`.
         to: Key looked up in `opset_util.ONNX_DTYPE_TO_PLAIDML`.

     Returns:
         A one-element tuple holding the result of `op.cast`.
     """
     target_dtype = opset_util.ONNX_DTYPE_TO_PLAIDML[to]
     converted = op.cast(x, target_dtype)
     return (converted, )
Example #4
0
 def size(unused_ctx, value):
     """Return the product of `value`'s shape entries, cast to INT64.

     Args:
         unused_ctx: Not used by this function.
         value: Value whose shape is taken with `op.shape_of`.

     Returns:
         A one-element tuple holding the INT64-cast product.
     """
     dims = op.shape_of(value)
     element_count = op.prod(dims)
     as_int64 = op.cast(element_count, plaidml.DType.INT64)
     return (as_int64, )
Example #5
0
 def shape(unused_ctx, data):
     """Return the shape of `data`, cast to INT64.

     Args:
         unused_ctx: Not used by this function.
         data: Value whose shape is taken with `op.shape_of`.

     Returns:
         A one-element tuple holding the INT64-cast shape.
     """
     dims = op.shape_of(data)
     as_int64 = op.cast(dims, plaidml.DType.INT64)
     return (as_int64, )
Example #6
0
 def gather(data, indicies):
     """Gather from `data` using `indicies`, narrowing INT64 indices first.

     Args:
         data: Value passed through unchanged as the first argument of
             `op.gather`.
         indicies: Index value; its `shape.dtype` is inspected and, when it
             is INT64, the value is cast to INT32 before gathering.

     Returns:
         A one-element tuple holding the result of `op.gather`.
     """
     # TODO: Ideally PlaidML's gather() would take int64 indices directly;
     # until then they are narrowed to int32 here.
     is_int64 = indicies.shape.dtype == plaidml.DType.INT64
     narrowed = op.cast(indicies, plaidml.DType.INT32) if is_int64 else indicies
     gathered = op.gather(data, narrowed)
     return (gathered, )