def testLowLevelIndexedDatasetOps(self):
  """Materializes an identity indexed dataset and reads back element 3.

  Drives the generated ops directly:
  - builds an identity indexed dataset from the uint64 constant 16;
  - creates a materialized-dataset handle (empty container/shared_name) for a
    single scalar uint64 component;
  - materializes the dataset into that handle;
  - fetches index 3 and expects the one-element output `[3]`.
  """
  element_types = [dtypes.uint64]
  element_shapes = [[]]  # a single scalar component

  size = ops.convert_to_tensor(16, dtype=dtypes.uint64)
  source_dataset = ged_ops.experimental_identity_indexed_dataset(size)
  resource = ged_ops.experimental_materialized_index_dataset_handle(
      container="",
      shared_name="",
      output_types=element_types,
      output_shapes=element_shapes)
  materialize_op = ged_ops.experimental_indexed_dataset_materialize(
      source_dataset, resource)
  fetched = ged_ops.experimental_indexed_dataset_get(
      resource, 3, output_types=element_types, output_shapes=element_shapes)

  # Run materialization first so the handle is filled before the lookup.
  self.evaluate(materialize_op)
  expected = [3]
  self.assertEqual(expected, self.evaluate(fetched))
  def testLowLevelIndexedDatasetOps(self):
    """Round-trips index 3 through a materialized identity indexed dataset.

    The dataset is built from the uint64 constant 16. It is materialized into
    a handle describing one scalar uint64 component, and the value at index 3
    is checked to come back as `[3]`.
    """
    structure = dict(output_types=[dtypes.uint64], output_shapes=[[]])

    dataset = ged_ops.experimental_identity_indexed_dataset(
        ops.convert_to_tensor(16, dtype=dtypes.uint64))
    materialized_handle = (
        ged_ops.experimental_materialized_index_dataset_handle(
            container="", shared_name="", **structure))
    fill = ged_ops.experimental_indexed_dataset_materialize(
        dataset, materialized_handle)
    element_at_3 = ged_ops.experimental_indexed_dataset_get(
        materialized_handle, 3, **structure)

    # The handle is populated first; only then is the element read.
    self.evaluate(fill)
    self.assertEqual([3], self.evaluate(element_at_3))
Esempio n. 3
0
  def testLowLevelIndexedDatasetOps(self):
    """Reads a fed index from a materialized identity indexed dataset.

    Same flow as the constant-index variant, except that the index is a
    uint64 placeholder. The graph is run in a cached session: materialize
    first, then fetch with the placeholder fed the value 3.
    """
    uint64_types = [dtypes.uint64]
    scalar_shapes = [[]]

    dataset_size = ops.convert_to_tensor(16, dtype=dtypes.uint64)
    indexed_ds = ged_ops.experimental_identity_indexed_dataset(dataset_size)
    materialized = ged_ops.experimental_materialized_index_dataset_handle(
        container="",
        shared_name="",
        output_types=uint64_types,
        output_shapes=scalar_shapes)
    fill_op = ged_ops.experimental_indexed_dataset_materialize(
        indexed_ds, materialized)
    requested_index = array_ops.placeholder(dtypes.uint64)
    lookup = ged_ops.experimental_indexed_dataset_get(
        materialized,
        requested_index,
        output_types=uint64_types,
        output_shapes=scalar_shapes)

    with self.cached_session() as session:
      session.run(fill_op)
      result = session.run(lookup, feed_dict={requested_index: 3})
      self.assertEqual([3], result)
Esempio n. 4
0
  def get(self, index):
    """Get retrieves a value (or set of values) from the IndexedDataset.

    Args:
      index: A uint64 scalar or vector tensor with the indices to retrieve.

    Returns:
      A flat list of tensors, one per component of the element structure
      (in `nest.flatten` order), holding the values corresponding to `index`.
      The list is not yet re-packed into the original nested structure (see
      TODO below).
    """
    # Types and shapes must each be converted with their own helper.
    # `sparse.as_dense_types` substitutes a dtype for SparseTensor components,
    # which is not a valid shape, so shapes go through
    # `sparse.as_dense_shapes` instead.
    # TODO(saeta): nest.pack_sequence_as(...)
    flat_types = nest.flatten(
        sparse.as_dense_types(self._output_types, self._output_classes))
    flat_shapes = nest.flatten(
        sparse.as_dense_shapes(self._output_shapes, self._output_classes))
    return ged_ops.experimental_indexed_dataset_get(
        self._materialized_resource,
        index,
        output_types=flat_types,
        output_shapes=flat_shapes)
  def get(self, index):
    """Get retrieves a value (or set of values) from the IndexedDataset.

    Args:
      index: A uint64 scalar or vector tensor with the indices to retrieve.

    Returns:
      A flat list of tensors, one per component of the element structure
      (in `nest.flatten` order), holding the values corresponding to `index`.
      The list is not yet re-packed into the original nested structure (see
      TODO below).
    """
    # Types and shapes must each be converted with their own helper.
    # `sparse.as_dense_types` substitutes a dtype for SparseTensor components,
    # which is not a valid shape, so shapes go through
    # `sparse.as_dense_shapes` instead.
    # TODO(saeta): nest.pack_sequence_as(...)
    flat_types = nest.flatten(
        sparse.as_dense_types(self._output_types, self._output_classes))
    flat_shapes = nest.flatten(
        sparse.as_dense_shapes(self._output_shapes, self._output_classes))
    return ged_ops.experimental_indexed_dataset_get(
        self._materialized_resource,
        index,
        output_types=flat_types,
        output_shapes=flat_shapes)