Exemplo n.º 1
0
def dataset_to_stream(dataset, input_name, n_chunks=0):
    """Takes a tf.Dataset and creates a numpy stream of ready batches.

    Args:
      dataset: dataset handed to ``backend.dataset_as_numpy``; each element is
        indexed as ``example[0][input_name]`` (inputs) and ``example[1]``
        (targets).
      input_name: key selecting the model input from ``example[0]``.
      n_chunks: if > 0, split both inputs and targets into this many equal
        pieces along axis 1. ``np.split`` raises if that axis is not evenly
        divisible by ``n_chunks``.

    Yields:
      ``(inp, out)`` pairs of numpy arrays, or pairs of tuples of arrays when
      ``n_chunks > 0``.
    """
    for example in backend.dataset_as_numpy(dataset):
        inp, out = example[0][input_name], example[1]
        # All input-pipeline processing should be on CPU.
        # NOTE(review): tf.device places TF ops created in this context; the
        # statements below are numpy-only, so it presumably has no effect on
        # them — confirm before relying on it.
        with tf.device('cpu:0'):
            # Some accelerators don't handle uint8 well, cast to int.
            # Compare the element dtype: isinstance(array, np.uint8) is always
            # False for an ndarray, so a type check would never cast a batch.
            if getattr(inp, 'dtype', None) == np.uint8:
                inp = inp.astype(np.int32)
            if getattr(out, 'dtype', None) == np.uint8:
                out = out.astype(np.int32)
            # Drop a trailing singleton target axis: (..., 1) -> (...).
            if len(out.shape) > 1 and out.shape[-1] == 1:
                out = np.squeeze(out, axis=-1)
            if n_chunks > 0:
                # NOTE: splitting targets on axis 1 requires them to still be
                # at least 2-D after the squeeze above.
                inp = tuple(np.split(inp, n_chunks, axis=1))
                out = tuple(np.split(out, n_chunks, axis=1))
        yield inp, out
Exemplo n.º 2
0
def dataset_to_stream(dataset, input_name, n_chunks=0, append_targets=False):
    """Takes a tf.Dataset and creates a numpy stream of ready batches.

    Args:
      dataset: dataset handed to ``backend.dataset_as_numpy``; each element is
        indexed as ``example[0][input_name]`` (inputs) and ``example[1]``
        (targets).
      input_name: key selecting the model input from ``example[0]``.
      n_chunks: if > 0, split both inputs and targets into this many equal
        pieces along axis 1. ``np.split`` raises if that axis is not evenly
        divisible by ``n_chunks``.
      append_targets: if True, the yielded input becomes the pair
        ``(inp, out)``; the targets are still yielded as the second element.

    Yields:
      ``(inp, out)`` where ``inp`` is an array (or tuple of arrays when
      ``n_chunks > 0``), wrapped as ``(inp, out)`` when ``append_targets`` is
      set.
    """
    for example in backend.dataset_as_numpy(dataset):
        inp, out = example[0][input_name], example[1]
        # Some accelerators don't handle uint8 well, cast to int.
        # Compare the element dtype: isinstance(array, np.uint8) is always
        # False for an ndarray, so a type check would never cast a batch.
        if getattr(inp, 'dtype', None) == np.uint8:
            inp = inp.astype(np.int32)
        if getattr(out, 'dtype', None) == np.uint8:
            out = out.astype(np.int32)
        # Drop a trailing singleton target axis: (..., 1) -> (...).
        if len(out.shape) > 1 and out.shape[-1] == 1:
            out = np.squeeze(out, axis=-1)
        if n_chunks > 0:
            # NOTE: splitting targets on axis 1 requires them to still be at
            # least 2-D after the squeeze above.
            inp = tuple(np.split(inp, n_chunks, axis=1))
            out = tuple(np.split(out, n_chunks, axis=1))
        if append_targets:
            # Expose targets to the model as part of its input as well.
            inp = (inp, out)
        yield inp, out