def dataset_to_stream(dataset, input_name, n_chunks=0):
  """Takes a tf.Dataset and creates a numpy stream of ready batches.

  Args:
    dataset: dataset converted to numpy via `backend.dataset_as_numpy`; each
      element is indexed as `example[0][input_name]` (inputs) and
      `example[1]` (targets).
    input_name: key used to pick the model input out of `example[0]`.
    n_chunks: if > 0, split both inputs and targets into `n_chunks` equal
      pieces along axis 1 with `np.split` (which raises ValueError when
      axis 1 is not evenly divisible).

  Yields:
    `(inp, out)` pairs; each is a numpy array, or a tuple of `n_chunks`
    arrays when `n_chunks > 0`.
  """
  for example in backend.dataset_as_numpy(dataset):
    inp, out = example[0][input_name], example[1]
    # Everything below is plain numpy, which always runs on the host CPU, so
    # no tf.device scope is needed (it only places TF ops, and none are
    # created here).
    # Some accelerators don't handle uint8 well, cast to int. Compare the
    # dtype: `isinstance(arr, np.uint8)` is False for ndarrays, so a
    # type-based check would never cast a batch.
    if getattr(inp, 'dtype', None) == np.uint8:
      inp = inp.astype(np.int32)
    if getattr(out, 'dtype', None) == np.uint8:
      out = out.astype(np.int32)
    # Drop a trailing singleton target axis, e.g. (batch, 1) -> (batch,).
    if len(out.shape) > 1 and out.shape[-1] == 1:
      out = np.squeeze(out, axis=-1)
    if n_chunks > 0:
      # NOTE(review): if the squeeze above reduced `out` to rank 1, splitting
      # along axis=1 will fail; presumably targets here are at least rank 2.
      inp = tuple(np.split(inp, n_chunks, axis=1))
      out = tuple(np.split(out, n_chunks, axis=1))
    yield inp, out
def dataset_to_stream(dataset, input_name, n_chunks=0, append_targets=False):
  """Takes a tf.Dataset and creates a numpy stream of ready batches.

  Args:
    dataset: dataset converted to numpy via `backend.dataset_as_numpy`; each
      element is indexed as `example[0][input_name]` (inputs) and
      `example[1]` (targets).
    input_name: key used to pick the model input out of `example[0]`.
    n_chunks: if > 0, split both inputs and targets into `n_chunks` equal
      pieces along axis 1 with `np.split` (which raises ValueError when
      axis 1 is not evenly divisible).
    append_targets: if True, the yielded input is replaced by the pair
      `(inp, out)`, so the targets are also part of the input.

  Yields:
    `(inp, out)` pairs. `inp` and `out` are numpy arrays, or tuples of
    `n_chunks` arrays when `n_chunks > 0`; with `append_targets`, `inp` is
    the `(inp, out)` pair.
  """
  for example in backend.dataset_as_numpy(dataset):
    inp, out = example[0][input_name], example[1]
    # Some accelerators don't handle uint8 well, cast to int. Compare the
    # dtype: `isinstance(arr, np.uint8)` is False for ndarrays, so a
    # type-based check would never cast a batch.
    if getattr(inp, 'dtype', None) == np.uint8:
      inp = inp.astype(np.int32)
    if getattr(out, 'dtype', None) == np.uint8:
      out = out.astype(np.int32)
    # Drop a trailing singleton target axis, e.g. (batch, 1) -> (batch,).
    if len(out.shape) > 1 and out.shape[-1] == 1:
      out = np.squeeze(out, axis=-1)
    if n_chunks > 0:
      # NOTE(review): if the squeeze above reduced `out` to rank 1, splitting
      # along axis=1 will fail; presumably targets here are at least rank 2.
      inp = tuple(np.split(inp, n_chunks, axis=1))
      out = tuple(np.split(out, n_chunks, axis=1))
    if append_targets:
      inp = (inp, out)
    yield inp, out