Exemple #1
0
    def setUp(self):
        """Build the database fixtures and an identity parse function.

        ``self.parse_function`` returns its single argument unchanged.
        """
        def identity(example_address):
            return example_address

        self.parse_function = identity

        self.omniglot_database = OmniglotDatabase(
            random_seed=-1, num_train_classes=1200, num_val_classes=100)
        # The remaining databases are created with their default arguments,
        # in this fixed order.
        for attribute_name, database_cls in (
                ('mini_imagenet_database', MiniImagenetDatabase),
                ('celeba_database', CelebADatabase),
                ('lfw_database', LFWDatabase),
                ('euro_sat_database', EuroSatDatabase),
                ('isic_database', ISICDatabase),
                ('chest_x_ray_database', ChestXRay8Database),
        ):
            setattr(self, attribute_name, database_cls())
def run_omniglot():
    """Build a Prototypical Networks learner on Omniglot and evaluate it.

    The training call is kept commented out, so only evaluation runs.
    """
    database = OmniglotDatabase(random_seed=47, num_train_classes=1200, num_val_classes=100)

    settings = dict(
        n=5,
        k_ml=1,
        k_val_ml=5,
        k_val=1,
        k_val_val=15,
        k_test=1,
        k_val_test=15,
        meta_batch_size=4,
        meta_learning_rate=0.001,
        save_after_iterations=1000,
        report_validation_frequency=200,
        # Use -1 to skip logging train images.
        log_train_images_after_iteration=200,
        num_tasks_val=100,
        val_seed=-1,
        experiment_name=None,
    )
    proto_net = PrototypicalNetworks(
        database=database,
        network_cls=SimpleModelProto,
        **settings
    )

    # proto_net.train(iterations=5000)
    proto_net.evaluate(-1, num_tasks=1000)
def run_transfer_learning():
    """Evaluate a transfer-learning baseline on Omniglot (no training call)."""
    database = OmniglotDatabase(
        random_seed=47,
        num_train_classes=1200,
        num_val_classes=100,
    )

    # Sampling sizes.
    sizes = dict(n=20, k_ml=1, k_val_ml=5, k_val_val=15, k_val=1, k_test=5, k_val_test=15)
    # Optimisation.
    optimisation = dict(
        meta_batch_size=4,
        num_steps_ml=5,
        lr_inner_ml=0.05,
        num_steps_validation=5,
        meta_learning_rate=0.001,
        clip_gradients=True,
        val_test_batch_norm_momentum=0.0,
    )
    # Checkpointing, logging and validation.
    bookkeeping = dict(
        save_after_iterations=15000,
        report_validation_frequency=250,
        log_train_images_after_iteration=1000,
        num_tasks_val=100,
        experiment_name='fixed_vgg_19',
        val_seed=42,
    )

    learner = TransferLearning(
        database=database,
        network_cls=get_network,
        **sizes,
        **optimisation,
        **bookkeeping
    )
    print(f'k: {learner.k_test}')
    learner.evaluate(50, num_tasks=1000, seed=42, use_val_batch_statistics=False)
def run_omniglot():
    """Evaluate MAML on Omniglot; the training call is left commented out."""
    database = OmniglotDatabase(random_seed=47, num_train_classes=1200, num_val_classes=100)

    # Sampling sizes.
    sizes = dict(n=20, k_ml=1, k_val_ml=5, k_val=1, k_val_val=15, k_test=5, k_val_test=15)
    # Optimisation.
    optimisation = dict(
        meta_batch_size=4,
        num_steps_ml=5,
        lr_inner_ml=0.4,
        num_steps_validation=5,
        meta_learning_rate=0.001,
        clip_gradients=False,
        val_test_batch_norm_momentum=0.0,
    )
    # Checkpointing, logging and validation.
    bookkeeping = dict(
        save_after_iterations=1000,
        report_validation_frequency=50,
        log_train_images_after_iteration=200,
        num_tasks_val=100,
        experiment_name='omniglot',
        val_seed=42,
    )

    maml = ModelAgnosticMetaLearningModel(
        database=database,
        network_cls=SimpleModel,
        **sizes,
        **optimisation,
        **bookkeeping
    )

    # maml.train(iterations=5000)
    maml.evaluate(iterations=50, num_tasks=1000, use_val_batch_statistics=True, seed=42)
Exemple #5
0
def run_domain_attention():
    """Evaluate unsupervised domain attention with Fungi as the target database.

    FungiDatabase is the only meta-train database and is also the
    evaluation database; the training call is left commented out.
    """
    train_domain_databases = [
        MiniImagenetDatabase(),
        OmniglotDatabase(random_seed=47, num_train_classes=1200, num_val_classes=100),
        DTDDatabase(),
        VGGFlowerDatabase(),
    ]
    # AirplaneDatabase() and CUBDatabase() are disabled meta-train candidates.
    # NOTE(review): Fungi is both meta-train and test domain here; confirm intended.
    meta_train_domain_databases = [FungiDatabase()]

    target_database = FungiDatabase()

    settings = dict(
        network_cls=None,
        image_shape=(84, 84, 3),
        n=5,
        k_ml=1,
        k_val_ml=5,
        k_val=1,
        k_val_val=15,
        k_test=5,
        k_val_test=15,
        meta_batch_size=4,
        num_steps_ml=1,
        lr_inner_ml=0.05,
        num_steps_validation=5,
        meta_learning_rate=0.001,
        clip_gradients=True,
        val_test_batch_norm_momentum=0.0,
        save_after_iterations=5000,
        report_validation_frequency=1000,
        log_train_images_after_iteration=1000,
        num_tasks_val=100,
        val_seed=42,
        experiment_name='domain_attention_all_frozen_layers_unsupervised_fungi2',
    )
    da = DomainAttentionUnsupervised(
        train_databases=train_domain_databases,
        meta_train_databases=meta_train_domain_databases,
        database=target_database,
        test_database=target_database,
        **settings
    )

    print(da.experiment_name)

    # da.train(iterations=5000)
    da.evaluate(iterations=50, num_tasks=1000, seed=42)
Exemple #6
0
 def __init__(self, meta_train_databases=None, *args, **kwargs):
     """Initialise the learner and choose the meta-training databases.

     Args:
         meta_train_databases: Databases to meta-train on. When None, a
             default list of seven databases is created.
         *args, **kwargs: Forwarded unchanged to the parent constructor.
     """
     super(CombinedCrossDomainMetaLearning, self).__init__(*args, **kwargs)
     if meta_train_databases is not None:
         self.meta_train_databases = meta_train_databases
         return

     self.meta_train_databases = [
         MiniImagenetDatabase(),
         AirplaneDatabase(),
         CUBDatabase(),
         OmniglotDatabase(random_seed=47, num_train_classes=1200, num_val_classes=100),
         DTDDatabase(),
         FungiDatabase(),
         VGGFlowerDatabase(),
     ]
def run_domain_attention():
    """Train element-wise domain attention and evaluate it on EuroSAT."""
    train_domain_databases = [
        MiniImagenetDatabase(),
        OmniglotDatabase(random_seed=47, num_train_classes=1200, num_val_classes=100),
        DTDDatabase(),
        VGGFlowerDatabase(),
    ]
    meta_train_domain_databases = [AirplaneDatabase(), FungiDatabase(), CUBDatabase()]

    target_database = EuroSatDatabase()

    settings = dict(
        network_cls=None,
        image_shape=(84, 84, 3),
        n=5,
        k_ml=1,
        k_val_ml=5,
        k_val=1,
        k_val_val=15,
        k_test=5,
        k_val_test=15,
        meta_batch_size=4,
        num_steps_ml=5,
        lr_inner_ml=0.05,
        num_steps_validation=5,
        meta_learning_rate=0.001,
        clip_gradients=True,
        val_test_batch_norm_momentum=0.0,
        save_after_iterations=15000,
        report_validation_frequency=1000,
        log_train_images_after_iteration=1000,
        num_tasks_val=100,
        val_seed=42,
        experiment_name='element_wise_domain_attention',
    )
    ewda = ElementWiseDomainAttention(
        train_databases=train_domain_databases,
        meta_train_databases=meta_train_domain_databases,
        database=target_database,
        test_database=target_database,
        **settings
    )
    ewda.train(iterations=60000)
    ewda.evaluate(iterations=50, num_tasks=1000, seed=14)
    def get_train_dataset(self):
        """Return the cross-domain meta-learning training dataset.

        A fixed list of seven databases is passed to
        ``get_cross_domain_meta_learning_dataset`` together with this
        learner's ``n``, ``k_ml``, ``k_val_ml`` and ``meta_batch_size``.
        """
        source_databases = [
            MiniImagenetDatabase(),
            AirplaneDatabase(),
            CUBDatabase(),
            OmniglotDatabase(random_seed=47, num_train_classes=1200, num_val_classes=100),
            DTDDatabase(),
            FungiDatabase(),
            VGGFlowerDatabase(),
        ]
        return self.get_cross_domain_meta_learning_dataset(
            databases=source_databases,
            n=self.n,
            k_ml=self.k_ml,
            k_validation=self.k_val_ml,
            meta_batch_size=self.meta_batch_size,
        )
Exemple #9
0
def run_omniglot():
    """Train SML on Omniglot using a feature model trained on sampled data points."""
    omniglot_database = OmniglotDatabase(random_seed=30, num_train_classes=1200, num_val_classes=100)
    n_data_points = 10000

    # Sample points from the training folders and turn them into a dataset
    # used to fit the feature model.
    data_points, classes, non_labeled_data_points = sample_data_points(
        omniglot_database.train_folders, n_data_points)
    features_dataset, n_classes = make_features_dataset_omniglot(
        data_points, classes, non_labeled_data_points)
    print(n_classes)

    # NOTE(review): the model is built with num_classes=5 while n_classes is
    # computed above and passed to training; confirm this is intended.
    untrained_feature_model = SimpleModelFeature(num_classes=5).get_sequential_model()
    feature_model = train_the_feature_model(
        untrained_feature_model, features_dataset, n_classes, omniglot_database.input_shape)

    settings = dict(
        n=5,
        k=1,
        k_val_ml=5,
        k_val_val=15,
        k_val_test=15,
        meta_batch_size=32,
        num_steps_ml=5,
        lr_inner_ml=0.01,
        num_steps_validation=5,
        meta_learning_rate=0.001,
        n_clusters=1200,
        feature_size=256,
        input_shape=(28, 28, 1),
        save_after_epochs=5,
        log_train_images_after_iteration=10,
        report_validation_frequency=10,
        experiment_name='omniglot_vae_model_feature_10000',
    )
    sml = SML(
        database=omniglot_database,
        network_cls=SimpleModel,
        feature_model=feature_model,
        **settings
    )

    sml.train(epochs=101)
def run_omniglot():
    """Train and evaluate ANIL on Omniglot with all conv/bn layers frozen."""
    database = OmniglotDatabase(random_seed=47, num_train_classes=1200, num_val_classes=100)

    # {'conv1', ..., 'conv4', 'bn1', ..., 'bn4'}
    frozen_layers = {
        f'{prefix}{index}' for prefix in ('conv', 'bn') for index in range(1, 5)
    }

    settings = dict(
        n=5,
        k_ml=1,
        k_val_ml=5,
        k_val=1,
        k_val_val=15,
        k_test=1,
        k_val_test=15,
        meta_batch_size=4,
        num_steps_ml=5,
        lr_inner_ml=0.4,
        num_steps_validation=5,
        meta_learning_rate=0.001,
        clip_gradients=False,
        val_test_batch_norm_momentum=0.0,
        save_after_iterations=1000,
        report_validation_frequency=50,
        log_train_images_after_iteration=200,
        num_tasks_val=100,
        val_seed=42,
        experiment_name='omniglot3',
    )
    anil = ANIL(
        database=database,
        network_cls=SimpleModel,
        set_of_frozen_layers=frozen_layers,
        **settings
    )

    anil.train(iterations=5000)
    anil.evaluate(iterations=50, num_tasks=1000, use_val_batch_statistics=True, seed=42)
def run_omniglot():
    """Train and evaluate SML on Omniglot using ImageNet-pretrained VGG19 features.

    Features come from VGG19's ``fc2`` layer, whose 4096-dimensional output
    matches ``feature_size=4096``. Images are fed at 224x224x3 and passed
    through VGG19's ``preprocess_input``.
    """
    omniglot_database = OmniglotDatabase(random_seed=47,
                                         num_train_classes=1200,
                                         num_val_classes=100)
    base_model = tf.keras.applications.VGG19(weights='imagenet')
    # Select the feature layer by name instead of by position (the original
    # used layers[24]), so the choice is explicit and independent of the
    # layer ordering of the Keras VGG19 implementation.
    feature_model = tf.keras.models.Model(inputs=base_model.input,
                                          outputs=base_model.get_layer('fc2').output)

    sml = SML(
        database=omniglot_database,
        network_cls=SimpleModel,
        n=5,
        k=1,
        k_val_ml=5,
        k_val_val=15,
        k_val_test=15,
        k_test=1,
        meta_batch_size=4,
        num_steps_ml=5,
        lr_inner_ml=0.05,
        num_steps_validation=5,
        save_after_iterations=15000,
        meta_learning_rate=0.001,
        n_clusters=5000,
        feature_model=feature_model,
        # feature_size=288,
        feature_size=4096,
        input_shape=(224, 224, 3),
        preprocess_function=tf.keras.applications.vgg19.preprocess_input,
        log_train_images_after_iteration=1000,
        number_of_tasks_val=100,
        number_of_tasks_test=1000,
        clip_gradients=True,
        report_validation_frequency=250,
        experiment_name='omniglot_imagenet_features')
    sml.train(iterations=60000)
    sml.evaluate(iterations=50, seed=42)
def run_omniglot():
    """Train and evaluate a MAMLAbstractLearner on Omniglot, epoch-based."""
    database = OmniglotDatabase(random_seed=47, num_train_classes=1200, num_val_classes=100)

    settings = dict(
        n=5,
        k=1,
        meta_batch_size=32,
        num_steps_ml=1,
        lr_inner_ml=0.4,
        num_steps_validation=10,
        meta_learning_rate=0.001,
        save_after_epochs=500,
        report_validation_frequency=10,
        log_train_images_after_iteration=-1,
    )
    maml = MAMLAbstractLearner(database=database, network_cls=SimpleModel, **settings)

    maml.train(epochs=4000)
    maml.evaluate(iterations=50)
def run_omniglot():
    """Train and evaluate GANSampling on Omniglot."""
    database = OmniglotDatabase(random_seed=47, num_train_classes=1200, num_val_classes=100)

    settings = dict(
        n=5,
        k=1,
        k_val_ml=5,
        k_val_train=1,
        k_val_val=15,
        k_val_test=15,
        k_test=1,
        meta_batch_size=32,
        num_steps_ml=5,  # 1 for prev result
        lr_inner_ml=0.4,
        num_steps_validation=5,
        meta_learning_rate=0.001,
        clip_gradients=False,
        val_test_batch_norm_momentum=0.0,
        save_after_iterations=1000,
        report_validation_frequency=50,
        log_train_images_after_iteration=200,
        number_of_tasks_val=100,
        number_of_tasks_test=1000,
        val_seed=42,
        experiment_name='omniglot',
    )
    gan_sampling = GANSampling(database=database, network_cls=SimpleModel, **settings)

    gan_sampling.train(iterations=5000)
    gan_sampling.evaluate(iterations=50, use_val_batch_statistics=True, seed=42)
from databases import OmniglotDatabase
from models.maml_umtra.maml_umtra import MAMLUMTRA
from networks.maml_umtra_networks import SimpleModel

if __name__ == '__main__':
    # import tensorflow as tf
    # tf.config.experimental_run_functions_eagerly(True)

    omniglot_database = OmniglotDatabase(random_seed=47,
                                         num_train_classes=1200,
                                         num_val_classes=100)

    maml_umtra = MAMLUMTRA(database=omniglot_database,
                           network_cls=SimpleModel,
                           n=5,
                           k=1,
                           k_val_ml=5,
                           k_val_val=15,
                           k_val_test=15,
                           k_test=5,
                           meta_batch_size=4,
                           num_steps_ml=5,
                           lr_inner_ml=0.4,
                           num_steps_validation=5,
                           save_after_iterations=1000,
                           meta_learning_rate=0.001,
                           report_validation_frequency=200,
                           log_train_images_after_iteration=200,
                           number_of_tasks_val=100,
                           number_of_tasks_test=1000,
                           clip_gradients=False,
    # fig.suptitle('', fontsize=12, y=1)
    plt.savefig(fname=os.path.join(root_folder_to_save, 'all_domains2.pdf'))
    plt.show()


if __name__ == '__main__':
    # Hide all GPUs from TensorFlow so the script runs on the CPU only.
    tf.config.set_visible_devices([], 'GPU')
    root_folder_to_save = os.path.expanduser('~/datasets_visualization/')
    # makedirs(exist_ok=True) also creates missing parent directories and
    # avoids the check-then-create race of os.path.exists() + os.mkdir().
    os.makedirs(root_folder_to_save, exist_ok=True)

    databases = (
        MiniImagenetDatabase(),
        OmniglotDatabase(random_seed=42,
                         num_train_classes=1200,
                         num_val_classes=100),
        AirplaneDatabase(),
        CUBDatabase(),
        DTDDatabase(),
        FungiDatabase(),
        VGGFlowerDatabase(),
        # TrafficSignDatabase(),
        # MSCOCODatabase(),
        # PlantDiseaseDatabase(),
        # EuroSatDatabase(),
        # ISICDatabase(),
        # ChestXRay8Database(),
    )

    # for database in databases:
Exemple #16
0
        prefix = sys.argv[2]
    else:
        print("Error: Too many arguments")
        sys.exit(0)

    # CONFIGS
    ITERATIONS = 5000
    GAN_EPOCHS = 500
    GAN_CHECKPOINTS = 50
    N_TASK_EVAL = 1000 
    K = 1
    TRAIN_GAN = True
    LASIUM_TYPE = "p1"
    GAN_N_ALT = 50 # How many times to alternate between unlabeled and labeled

    omniglot_database = OmniglotDatabase(random_seed=47, num_train_classes=1200, num_val_classes=100)
    shape = (28, 28, 1)
    latent_dim = 128
    omniglot_generator = get_generator(latent_dim)
    omniglot_parser = OmniglotParser(shape=shape)

    experiment_name = prefix+str(labeled_percentage)
    if GAN_N_ALT > 1:
        experiment_name += "_alt"+str(GAN_N_ALT)

    # Split labeled and not labeled
    train_folders = omniglot_database.train_folders
    keys = list(train_folders.keys())
    labeled_keys = np.random.choice(keys, int(len(train_folders.keys())*labeled_percentage), replace=False)
    train_folders_labeled = {k: v for (k, v) in train_folders.items() if k in labeled_keys}
    train_folders_unlabeled = {k: v for (k, v) in train_folders.items() if k not in labeled_keys}
Exemple #17
0
class TestDatabases(unittest.TestCase):
    """Checks the class splits and supervised meta-learning datasets of the databases.

    The dataset tests patch ``_get_parse_function`` to return an identity
    function and request ``dtype=tf.string``. Each element the dataset yields
    is therefore an example address, and the address's folder identifies its
    class.
    """

    def setUp(self):
        """Create the database fixtures and an identity parse function."""
        def parse_function(example_address):
            return example_address

        self.parse_function = parse_function

        self.omniglot_database = OmniglotDatabase(random_seed=-1,
                                                  num_train_classes=1200,
                                                  num_val_classes=100)
        self.mini_imagenet_database = MiniImagenetDatabase()
        self.celeba_database = CelebADatabase()
        self.lfw_database = LFWDatabase()
        self.euro_sat_database = EuroSatDatabase()
        self.isic_database = ISICDatabase()
        self.chest_x_ray_database = ChestXRay8Database()

    @property
    def databases(self):
        """All database fixtures created in ``setUp``, in a fixed order."""
        return (
            self.omniglot_database,
            self.mini_imagenet_database,
            self.celeba_database,
            self.lfw_database,
            self.euro_sat_database,
            self.isic_database,
            self.chest_x_ray_database,
        )

    @staticmethod
    def _split_address(address):
        """Decode an address tensor and split it into ``(class_folder, file_name)``.

        ``address.numpy()`` is expected to be UTF-8 encoded bytes.
        """
        return os.path.split(address.numpy().decode('utf-8'))

    def test_train_val_test_folders_are_separate(self):
        """No class folder may appear in more than one of train/val/test."""
        for database in (self.omniglot_database, self.mini_imagenet_database,
                         self.celeba_database):
            with self.subTest(database=type(database).__name__):
                train_folders = set(database.train_folders)
                val_folders = set(database.val_folders)
                test_folders = set(database.test_folders)

                # Every pairwise intersection must be empty; on failure the
                # message lists the overlapping folders.
                self.assertSetEqual(set(), train_folders & val_folders)
                self.assertSetEqual(set(), train_folders & test_folders)
                self.assertSetEqual(set(), val_folders & test_folders)

    def test_train_val_test_folders_are_comprehensive(self):
        """The three splits together contain the expected number of classes.

        Expected totals: 1623 (Omniglot), 100 (mini-ImageNet), 10177 (CelebA).
        """
        self.assertEqual(
            1623,
            len(
                list(self.omniglot_database.train_folders.keys()) +
                list(self.omniglot_database.val_folders.keys()) +
                list(self.omniglot_database.test_folders.keys())))
        self.assertEqual(
            100,
            len(
                list(self.mini_imagenet_database.train_folders) +
                list(self.mini_imagenet_database.val_folders) +
                list(self.mini_imagenet_database.test_folders)))
        self.assertEqual(
            10177,
            len(
                list(self.celeba_database.train_folders) +
                list(self.celeba_database.val_folders) +
                list(self.celeba_database.test_folders)))

    @patch('tf_datasets.OmniglotDatabase._get_parse_function')
    def test_covering_all_classes_in_one_epoch(self, mocked_parse_function):
        """A single epoch visits every training class exactly once."""
        # n * meta_batch_size = 20 divides the 1200 training classes, so one
        # epoch can cover each class exactly once.
        mocked_parse_function.return_value = self.parse_function
        ds = self.omniglot_database.get_supervised_meta_learning_dataset(
            self.omniglot_database.train_folders,
            n=5,
            k=5,
            k_validation=5,
            meta_batch_size=4,
            dtype=tf.string)

        # Check for covering all classes and no duplicates
        classes = set()
        for (x, _), _ in ds:
            train_ds = combine_first_two_axes(x)
            for i in range(train_ds.shape[0]):
                class_instances = train_ds[i, ...]
                class_address = self._split_address(class_instances[0])[0]
                self.assertNotIn(class_address, classes)
                classes.add(class_address)

        self.assertSetEqual(classes,
                            set(self.omniglot_database.train_folders.keys()))

    @patch('tf_datasets.OmniglotDatabase._get_parse_function')
    def test_covering_all_classes_in_subsequent_epochs(self,
                                                       mocked_parse_function):
        """Over several epochs every training class is sampled at least once.

        Here n * meta_batch_size = 28 does not divide the 1200 training
        classes. The test is therefore expected to fail with
        ``num_epochs = 1``. More epochs make a missed class less likely, but
        because classes are selected randomly the test can still fail
        occasionally.
        """
        num_epochs = 3
        mocked_parse_function.return_value = self.parse_function
        ds = self.omniglot_database.get_supervised_meta_learning_dataset(
            self.omniglot_database.train_folders,
            n=7,
            k=5,
            k_validation=5,
            meta_batch_size=4,
            seed=42,
            dtype=tf.string)
        # Check for covering all classes
        classes = set()
        for _ in range(num_epochs):
            for (x, _), _ in ds:
                train_ds = combine_first_two_axes(x)
                for i in range(train_ds.shape[0]):
                    class_instances = train_ds[i, ...]
                    classes.add(self._split_address(class_instances[0])[0])

        self.assertSetEqual(classes,
                            set(self.omniglot_database.train_folders.keys()))

    @patch('tf_datasets.OmniglotDatabase._get_parse_function')
    def test_labels_are_correct_in_train_and_val_for_every_task(
            self, mocked_parse_function):
        """Within a task, every instance of a class carries the same label in train and val."""
        mocked_parse_function.return_value = self.parse_function
        n = 7
        k = 5
        k_validation = 6
        ds = self.omniglot_database.get_supervised_meta_learning_dataset(
            self.omniglot_database.train_folders,
            n=n,
            k=k,
            k_validation=k_validation,
            meta_batch_size=4,
            one_hot_labels=False,
            dtype=tf.string)

        for _ in range(2):
            for (x, x_val), (y, y_val) in ds:
                for meta_batch_index in range(x.shape[0]):
                    train_ds = x[meta_batch_index, ...]
                    val_ds = x_val[meta_batch_index, ...]
                    train_labels = y[meta_batch_index, ...]
                    val_labels = y_val[meta_batch_index, ...]

                    class_label_dict = {}

                    for class_index in range(n):
                        for instance_index in range(k):
                            class_name = self._split_address(
                                train_ds[class_index, instance_index, ...])[0]
                            # Labels are flattened class-major: class_index * k + instance_index.
                            label = train_labels[class_index * k +
                                                 instance_index, ...].numpy()
                            if class_name in class_label_dict:
                                self.assertEqual(class_label_dict[class_name],
                                                 label)
                            else:
                                class_label_dict[class_name] = label

                    for class_index in range(n):
                        for instance_index in range(k_validation):
                            class_name = self._split_address(
                                val_ds[class_index, instance_index, ...])[0]
                            label = val_labels[class_index * k_validation +
                                               instance_index, ...].numpy()
                            self.assertIn(class_name, class_label_dict)
                            self.assertEqual(class_label_dict[class_name],
                                             label)

    @patch('tf_datasets.OmniglotDatabase._get_parse_function')
    def test_train_and_val_have_different_samples_in_every_task(
            self, mocked_parse_function):
        """Val instances of a task never repeat that task's train instances of the same class."""
        mocked_parse_function.return_value = self.parse_function
        n = 6
        k = 4
        k_validation = 5
        ds = self.omniglot_database.get_supervised_meta_learning_dataset(
            self.omniglot_database.train_folders,
            n=n,
            k=k,
            k_validation=k_validation,
            meta_batch_size=4,
            one_hot_labels=False,
            dtype=tf.string)

        for _ in range(4):
            for (x, x_val), _ in ds:
                for meta_batch_index in range(x.shape[0]):
                    train_ds = x[meta_batch_index, ...]
                    val_ds = x_val[meta_batch_index, ...]

                    class_instances_dict = {}

                    for class_index in range(n):
                        for instance_index in range(k):
                            class_name, file_name = self._split_address(
                                train_ds[class_index, instance_index, ...])
                            class_instances_dict.setdefault(
                                class_name, set()).add(file_name)

                    for class_index in range(n):
                        for instance_index in range(k_validation):
                            class_name, file_name = self._split_address(
                                val_ds[class_index, instance_index, ...])
                            self.assertIn(class_name, class_instances_dict)
                            self.assertNotIn(file_name,
                                             class_instances_dict[class_name])

    @patch('tf_datasets.OmniglotDatabase._get_parse_function')
    def test_no_two_class_in_the_same_task(self, mocked_parse_function):
        """Each task has n distinct classes, each with exactly k (or k_validation) instances."""
        mocked_parse_function.return_value = self.parse_function
        n = 6
        k = 4
        k_validation = 5
        ds = self.omniglot_database.get_supervised_meta_learning_dataset(
            self.omniglot_database.train_folders,
            n=n,
            k=k,
            k_validation=k_validation,
            meta_batch_size=4,
            one_hot_labels=False,
            dtype=tf.string)

        for _ in range(4):
            for (x, x_val), _ in ds:
                for base_data, k_value in zip((x, x_val), (k, k_validation)):
                    for meta_batch_index in range(base_data.shape[0]):
                        task_data = base_data[meta_batch_index, ...]
                        class_counts = {}

                        for class_index in range(n):
                            for instance_index in range(k_value):
                                class_name = self._split_address(
                                    task_data[class_index, instance_index, ...])[0]
                                class_counts[class_name] = class_counts.get(
                                    class_name, 0) + 1

                        # n distinct classes, each filling exactly one class slot.
                        self.assertEqual(n, len(class_counts))
                        for num_class_instances in class_counts.values():
                            self.assertEqual(k_value, num_class_instances)

    @patch('tf_datasets.OmniglotDatabase._get_parse_function')
    def test_different_instances_are_selected_from_each_class_for_train_and_val_each_time(
            self, mocked_parse_function):
        """Each class gets a different train-instance selection in two consecutive epochs.

        Only the train instances (``x``) are compared. No seed is passed, so an
        identical selection for some class in both epochs is possible. A
        failure here therefore does not necessarily mean the sampling code is
        broken.
        """
        mocked_parse_function.return_value = self.parse_function
        n = 6
        k = 4
        ds = self.omniglot_database.get_supervised_meta_learning_dataset(
            self.omniglot_database.train_folders,
            n=n,
            k=k,
            k_validation=8,
            meta_batch_size=4,
            one_hot_labels=False,
            dtype=tf.string)

        # class_instances[epoch][class_name] -> train file names seen in that epoch.
        class_instances = {0: {}, 1: {}}

        for epoch in range(2):
            for (x, _), _ in ds:
                for meta_batch_index in range(x.shape[0]):
                    train_ds = x[meta_batch_index, ...]

                    for class_index in range(n):
                        for instance_index in range(k):
                            class_name, instance_name = self._split_address(
                                train_ds[class_index, instance_index, ...])
                            class_instances[epoch].setdefault(
                                class_name, set()).add(instance_name)

        first_epoch_class_instances = class_instances[0]
        second_epoch_class_instances = class_instances[1]

        for class_name, first_instances in first_epoch_class_instances.items():
            self.assertIn(class_name, second_epoch_class_instances)
            # assertNotEqual(0, <set>) can never fail, because a set never
            # equals 0. Check that the difference is non-empty instead.
            self.assertTrue(
                first_instances - second_epoch_class_instances[class_name],
                f'Class {class_name} got the same train instances in both epochs.')

    @patch('tf_datasets.MiniImagenetDatabase._get_parse_function')
    def test_1000_tasks_are_completely_different(self, mocked_parse_function):
        """1000 seeded validation tasks are all distinct and reproduced in the same order."""
        # This test should run on MAML get validation dataset.
        # For now, I just copied the code from there here to make sure that everything works fine.
        # However, if that code changes, this test does not have any value.
        # TODO Make this test work with maml.get_validation_dataset and get_test_dataset functions.
        mocked_parse_function.return_value = self.parse_function

        def get_val_dataset():
            val_dataset = self.mini_imagenet_database.get_supervised_meta_learning_dataset(
                self.mini_imagenet_database.val_folders,
                n=6,
                k=4,
                k_validation=3,
                meta_batch_size=1,
                seed=42,
                dtype=tf.string)
            val_dataset = val_dataset.repeat(-1)
            val_dataset = val_dataset.take(1000)
            return val_dataset

        counter = 0
        xs = set()
        x_vals = set()

        xs_queue = deque()
        x_vals_queue = deque()

        # First pass: every task (train and val part) must be new.
        for (x, x_val), _ in get_val_dataset():
            x_string = ','.join(map(str, x.numpy().reshape(-1)))
            self.assertNotIn(x_string, xs)
            xs.add(x_string)

            x_val_string = ','.join(map(str, x_val.numpy().reshape(-1)))
            self.assertNotIn(x_val_string, x_vals)
            x_vals.add(x_val_string)

            xs_queue.append(x_string)
            x_vals_queue.append(x_val_string)
            counter += 1

        self.assertEqual(counter, 1000)
        # Second pass: with the same seed, the same tasks must come back in the same order.
        for (x, x_val), _ in get_val_dataset():
            x_string = ','.join(map(str, x.numpy().reshape(-1)))
            self.assertEqual(xs_queue.popleft(), x_string)

            x_val_string = ','.join(map(str, x_val.numpy().reshape(-1)))
            self.assertEqual(x_vals_queue.popleft(), x_val_string)