def test_imagenet_example_channels_first():
    """Check that ``imagenet_example`` honours ``data_format='channels_first'``.

    Loads the example in both layouts and verifies that the channels_first
    image is exactly the channels_last image with its last axis moved to the
    front, i.e. ``image == transpose(image2, (2, 0, 1))``.
    """
    image, label = imagenet_example(data_format='channels_first')
    image2, _ = imagenet_example(data_format='channels_last')

    assert 0 <= label < 1000
    assert isinstance(label, int)
    assert image.shape == (3, 224, 224)
    assert image.dtype == np.float32

    # Validate the reference image too, so a malformed channels_last result
    # cannot silently broadcast into a passing comparison.
    assert image2.shape == (224, 224, 3)

    # Moving the channel axis (HWC -> CHW) of the channels_last image must
    # reproduce the channels_first image element-for-element. A single
    # vectorized check replaces the per-channel loop and reports mismatches.
    np.testing.assert_array_equal(image, np.transpose(image2, (2, 0, 1)))
def test_imagenet_example():
    """Sanity-check the default (channels_last) output of ``imagenet_example``.

    The returned label must be a Python ``int`` in ``[0, 1000)``, and the
    image must be a float32 array of shape ``(224, 224, 3)``.
    """
    img, lbl = imagenet_example()

    # Label checks.
    assert 0 <= lbl < 1000
    assert isinstance(lbl, int)

    # Image layout and precision checks.
    assert img.shape == (224, 224, 3)
    assert img.dtype == np.float32