def make_zip_dataset_npy(A_npy_paths, B_npy_paths, batch_size):
    """Build a dataset of paired batches drawn from two ``.npy`` arrays.

    Each file is loaded with ``np.load``, cast to ``tf.float32`` and wrapped
    with ``tl.memory_data_batch_dataset`` for a single pass (``repeat=1``).
    The two batched streams are then zipped into ``(A_batch, B_batch)`` pairs.

    Args:
        A_npy_paths: Path of the ``.npy`` file for domain A. Despite the plural
            name, this is passed directly to ``np.load``, so it is one file.
        B_npy_paths: Path of the ``.npy`` file for domain B.
        batch_size: Number of samples per batch in each stream.

    Returns:
        ``(A_B_dataset, len_dataset)``. ``A_B_dataset`` yields
        ``(A_batch, B_batch)`` tuples. ``len_dataset`` is the number of pairs
        produced per pass, bounded by the smaller of the two arrays.
        NOTE(review): the count assumes ``memory_data_batch_dataset`` drops
        incomplete batches by default; if it does not, one extra partial pair
        may be produced. TODO confirm.
    """
    A_npy = tf.cast(np.load(A_npy_paths), tf.float32)
    B_npy = tf.cast(np.load(B_npy_paths), tf.float32)

    A_dataset = tl.memory_data_batch_dataset(A_npy, batch_size, repeat=1)
    B_dataset = tl.memory_data_batch_dataset(B_npy, batch_size, repeat=1)
    # Dataset.zip ends as soon as either input is exhausted. With both inputs
    # single-pass, the number of pairs is limited by the *smaller* array.
    A_B_dataset = tf.data.Dataset.zip((A_dataset, B_dataset))

    # Must use min, not max, so the reported length matches what zip yields.
    len_dataset = min(A_npy.shape[0], B_npy.shape[0]) // batch_size
    return A_B_dataset, len_dataset
def make_32x32_dataset(dataset, batch_size, drop_remainder=True, shuffle=True, repeat=1):
    """Build a batched dataset of 32x32 training images scaled to [-1, 1].

    Only the training images of the chosen Keras dataset are used; labels and
    the test split are discarded.

    Args:
        dataset: One of ``'mnist'``, ``'fashion_mnist'`` or ``'cifar10'``.
        batch_size: Number of images per batch.
        drop_remainder: If True, the final incomplete batch is dropped.
        shuffle: Forwarded to ``tl.memory_data_batch_dataset``.
        repeat: Forwarded to ``tl.memory_data_batch_dataset``.

    Returns:
        ``(dataset, img_shape, len_dataset)``:
        - ``dataset``: the batched ``tf.data`` pipeline.
        - ``img_shape``: ``(32, 32, channels)``.
        - ``len_dataset``: the number of batches in one pass over the
          training images, counting a trailing partial batch when
          ``drop_remainder`` is False.

    Raises:
        NotImplementedError: If ``dataset`` is not a supported name.
    """
    loaders = {
        'mnist': tf.keras.datasets.mnist,
        'fashion_mnist': tf.keras.datasets.fashion_mnist,
        'cifar10': tf.keras.datasets.cifar10,
    }
    if dataset not in loaders:
        raise NotImplementedError(f'Unsupported dataset: {dataset!r}')

    (train_images, _), (_, _) = loaders[dataset].load_data()
    if dataset in ('mnist', 'fashion_mnist'):
        # These sets come without a channel axis. Append one so every image
        # is HxWxC. reshape is the supported alternative to assigning .shape.
        train_images = train_images.reshape(train_images.shape + (1,))

    @tf.function
    def _map_fn(img):
        img = tf.image.resize(img, [32, 32])
        # Interpolation can overshoot the valid pixel range; clamp it first.
        img = tf.clip_by_value(img, 0, 255)
        img = img / 127.5 - 1
        return img

    dataset = tl.memory_data_batch_dataset(train_images,
                                           batch_size,
                                           drop_remainder=drop_remainder,
                                           map_fn=_map_fn,
                                           shuffle=shuffle,
                                           repeat=repeat)
    img_shape = (32, 32, train_images.shape[-1])

    n_images = len(train_images)
    if drop_remainder:
        len_dataset = n_images // batch_size
    else:
        # The trailing partial batch is still emitted, so round up.
        len_dataset = -(-n_images // batch_size)
    return dataset, img_shape, len_dataset
def make_dataset_npy(npy_paths, batch_size):
    """Load one ``.npy`` array and return a single-pass batched dataset of it.

    Args:
        npy_paths: Path passed directly to ``np.load`` (a single file).
        batch_size: Number of samples per batch.

    Returns:
        The dataset built by ``tl.memory_data_batch_dataset`` from the
        array cast to ``tf.float32``, with ``repeat=1``.
    """
    raw_array = np.load(npy_paths)
    float_tensor = tf.cast(raw_array, tf.float32)
    return tl.memory_data_batch_dataset(float_tensor, batch_size, repeat=1)