def augment_ds(ds: Dataset, grayscale: bool) -> Dataset:
    """Applies random training-time augmentations to the images in ``ds``.

    Only the first element of each ``(x, y)`` pair is transformed; labels are
    passed through unchanged.

    Args:
        ds (Dataset): TensorFlow Dataset yielding ``(image, label)`` pairs
        grayscale (bool): Selects the augmentation pipeline. When True, the
            image is randomly cropped with ``_random_crop_mnist``. When False,
            the image first gets random hue/saturation/brightness/contrast
            jitter and is then randomly cropped with ``_random_crop_cifar``.
            In both cases a random horizontal flip is applied last.
            NOTE(review): the original indentation was lost. This reading
            applies the flip to grayscale data too; confirm that is intended.

    Returns:
        Dataset: The augmented dataset (each stage runs with
        ``num_parallel_calls=AUTOTUNE``).
    """
    if grayscale:
        ds = ds.map(
            lambda image, label: (_random_crop_mnist(image), label),
            num_parallel_calls=AUTOTUNE,
        )
    else:
        # Colour jitter only makes sense for multi-channel input, so it is
        # applied before the CIFAR-style crop and only on this branch.
        ds = ds.map(
            lambda image, label: (
                _random_hue_saturation_brightness_contrast(image),
                label,
            ),
            num_parallel_calls=AUTOTUNE,
        )
        ds = ds.map(
            lambda image, label: (_random_crop_cifar(image), label),
            num_parallel_calls=AUTOTUNE,
        )

    return ds.map(
        lambda image, label: (_random_horizontal_flip(image), label),
        num_parallel_calls=AUTOTUNE,
    )
def prepare(ds: Dataset, num_classes: int) -> Dataset:
    """Prepares dataset for training by
    - Casting the label with ``_prep_cast_label``
    - Casting color channel values to float, divide by 255
    - One-hot encode labels

    This is the single definition of ``prepare``. An earlier revision defined
    it twice, and the undocumented second copy shadowed this one.

    Args:
        ds (Dataset): TensorFlow Dataset yielding ``(x, y)`` pairs
        num_classes (int): Number of classes present in federated dataset
            partition; used as the one-hot depth

    Returns:
        Dataset
    """
    # The label is cast before one-hot encoding, so _prep_one_hot receives the
    # output of _prep_cast_label (presumably an integer index; verify helper).
    ds = ds.map(
        lambda x, y: (x, _prep_cast_label(y)),
        num_parallel_calls=AUTOTUNE,
    )
    ds = ds.map(
        lambda x, y: (_prep_cast_divide(x), y),
        num_parallel_calls=AUTOTUNE,
    )
    ds = ds.map(
        lambda x, y: (x, _prep_one_hot(y, num_classes)),
        num_parallel_calls=AUTOTUNE,
    )
    return ds