Example #1
0
 def __init__(self, filename, internal=False):
     """Build an IOTensor with one EXRPartIOTensor per part of an EXR file.

     Args:
       filename: Path of the EXR file; its bytes are read with
         `tf.io.read_file`.
       internal: Forwarded unchanged to every IOTensor constructed here and
         to the parent constructor.
     """
     with tf.name_scope("EXRIOTensor"):
         data = tf.io.read_file(filename)
         shapes, dtypes, channels = core_ops.io_decode_exr_info(data)
         parts = []
         # One row of (shape, dtypes, channels) per part; the row position is
         # the part index passed to io_decode_exr.
         for index, (shape, part_dtypes, part_channels) in enumerate(
                 zip(shapes.numpy(), dtypes.numpy(), channels.numpy())):
             # Remove trailing 0 from dtypes (presumably padding so all rows
             # share one width -- confirm against io_decode_exr_info) and the
             # matching channel entries. The rows are numpy arrays, which
             # have no pop(), so copy them into lists first. The emptiness
             # check avoids an IndexError on an all-zero row.
             part_dtypes = list(part_dtypes)
             part_channels = list(part_channels)
             while part_dtypes and part_dtypes[-1] == 0:
                 part_dtypes.pop()
                 part_channels.pop()
             spec = tuple(
                 tf.TensorSpec(tf.TensorShape(shape), dtype)
                 for dtype in part_dtypes)
             columns = [channel.decode() for channel in part_channels]
             # Each channel is decoded on its own into a TensorIOTensor.
             elements = [
                 io_tensor_ops.TensorIOTensor(
                     core_ops.io_decode_exr(data, index, channel, dtype=dtype),
                     internal=internal)
                 for channel, dtype in zip(columns, part_dtypes)
             ]
             parts.append(
                 EXRPartIOTensor(spec, columns, elements, internal=internal))
         spec = tuple(part.spec for part in parts)
         # Parts are addressed by their integer position in the file.
         columns = list(range(len(parts)))
         super(EXRIOTensor, self).__init__(
             spec, columns, parts, internal=internal)
Example #2
0
def decode_exr(contents, index, channel, dtype, name=None):
  """Decode a single channel of one part of an EXR-encoded image.

  Thin wrapper that forwards every argument to `core_ops.io_decode_exr`.

  Args:
    contents: A `Tensor` of type `string`. 0-D. The EXR-encoded image.
    index: A `Tensor` of type int64. 0-D. The 0-based index of the part
      inside the EXR-encoded image.
    channel: A `Tensor` of type string. 0-D. The name of the channel to
      decode within the selected part.
    dtype: The dtype of the values stored in `channel`; forwarded as the
      op's `dtype` argument.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `dtype` holding the decoded values of `channel`.
    NOTE(review): the output shape is determined by the op and is
    presumably the part's image dimensions -- confirm against the kernel.
  """
  return core_ops.io_decode_exr(
      contents, index=index, channel=channel, dtype=dtype, name=name)