Example #1
0
def run_mini_imagenet():
    """Evaluate a MAML learner on mini-ImageNet (5-way episodes).

    Training is left commented out; by default the script only runs
    evaluation of a saved checkpoint, once with validation batch
    statistics and once without.
    """
    config = dict(
        database=MiniImagenetDatabase(),
        test_database=MiniImagenetDatabase(),
        network_cls=MiniImagenetModel,
        # Episode composition.
        n=5,
        k_ml=1,
        k_val_ml=5,
        k_val=1,
        k_val_val=15,
        k_test=50,
        k_val_test=15,
        meta_batch_size=4,
        # Inner/outer loop optimization settings.
        num_steps_ml=5,
        lr_inner_ml=0.05,
        num_steps_validation=5,
        save_after_iterations=15000,
        meta_learning_rate=0.001,
        report_validation_frequency=1000,
        log_train_images_after_iteration=1000,
        num_tasks_val=100,
        clip_gradients=True,
        experiment_name='mini_imagenet_test_res',
        val_seed=42,
        val_test_batch_norm_momentum=0.0,
    )
    maml = ModelAgnosticMetaLearningModel(**config)

    # maml.train(iterations=60000)
    for use_val_stats in (True, False):
        maml.evaluate(50, num_tasks=1000, seed=42,
                      use_val_batch_statistics=use_val_stats)
Example #2
0
    def setUp(self):
        """Create the shared parse function and one handle per dataset."""
        def parse_function(example_address):
            # Identity parser: tests inspect raw example addresses, not
            # decoded images.
            return example_address

        self.parse_function = parse_function

        # One database fixture per supported dataset.
        self.omniglot_database = OmniglotDatabase(
            random_seed=-1, num_train_classes=1200, num_val_classes=100)
        self.mini_imagenet_database = MiniImagenetDatabase()
        self.celeba_database = CelebADatabase()
        self.lfw_database = LFWDatabase()
        self.euro_sat_database = EuroSatDatabase()
        self.isic_database = ISICDatabase()
        self.chest_x_ray_database = ChestXRay8Database()
Example #3
0
def run_transfer_meta_learning():
    """Transfer meta-learning with a VGG16 backbone: mini-ImageNet -> EuroSAT."""
    source_database = MiniImagenetDatabase()
    # EuroSAT serves as both the validation and the target domain.
    target_database = EuroSatDatabase()

    tml = TransferMetaLearningVGG16(
        database=source_database,
        val_database=target_database,
        target_database=target_database,
        network_cls=None,
        n=5,
        k=1,
        k_val_ml=5,
        k_val_val=15,
        k_val_test=15,
        k_test=1,
        meta_batch_size=4,
        num_steps_ml=1,
        lr_inner_ml=0.001,
        num_steps_validation=5,
        save_after_iterations=1500,
        meta_learning_rate=0.0001,
        report_validation_frequency=250,
        log_train_images_after_iteration=1000,
        number_of_tasks_val=100,
        number_of_tasks_test=100,
        clip_gradients=True,
        experiment_name='transfer_meta_learning_mini_imagenet_euro_sat',
        val_seed=42,
        val_test_batch_norm_momentum=0.0,
        random_layer_initialization_seed=42,
        num_trainable_layers=3,
    )

    tml.train(iterations=6000)
    # Evaluate once with validation batch statistics and once without.
    for use_val_stats in (True, False):
        tml.evaluate(50, seed=42, use_val_batch_statistics=use_val_stats)
def run_transfer_learning():
    """Evaluate transfer learning with a fixed VGG-19 feature extractor."""
    transfer_learning = TransferLearning(
        database=MiniImagenetDatabase(),
        network_cls=get_network,
        # Episode composition.
        n=5,
        k_ml=1,
        k_val_ml=5,
        k_val_val=15,
        k_val=1,
        k_test=50,
        k_val_test=15,
        meta_batch_size=4,
        # Optimization settings.
        num_steps_ml=5,
        lr_inner_ml=0.05,
        num_steps_validation=5,
        save_after_iterations=15000,
        meta_learning_rate=0.001,
        report_validation_frequency=250,
        log_train_images_after_iteration=1000,
        num_tasks_val=100,
        clip_gradients=True,
        experiment_name='fixed_vgg_19',
        val_seed=42,
        val_test_batch_norm_momentum=0.0,
    )
    print(f'k: {transfer_learning.k_test}')
    transfer_learning.evaluate(
        50, num_tasks=1000, seed=42, use_val_batch_statistics=True)
Example #5
0
def run_mini_imagenet():
    """Run unsupervised ANIL on mini-ImageNet with the body network frozen."""
    # Only the head stays trainable; every conv/bn layer is frozen.
    frozen_layers = {
        'conv1', 'conv2', 'conv3', 'conv4', 'bn1', 'bn2', 'bn3', 'bn4'
    }

    anil = ANILUnsupervised(
        database=MiniImagenetDatabase(),
        network_cls=MiniImagenetModel,
        n=5,
        k_ml=1,
        k_val_ml=1,
        k_val=1,
        k_val_val=15,
        k_test=1,
        k_val_test=15,
        meta_batch_size=4,
        num_steps_ml=5,
        lr_inner_ml=0.05,
        num_steps_validation=5,
        save_after_iterations=15000,
        meta_learning_rate=0.001,
        report_validation_frequency=1000,
        log_train_images_after_iteration=1000,
        num_tasks_val=100,
        clip_gradients=True,
        experiment_name='mini_imagenet_unsupervised_permutation',
        val_seed=42,
        val_test_batch_norm_momentum=0.0,
        set_of_frozen_layers=frozen_layers,
    )

    anil.train(iterations=60000)
    anil.evaluate(50, num_tasks=1000, seed=42, use_val_batch_statistics=True)
def run_mini_imagenet():
    """Train GAN-based task sampling on mini-ImageNet, then evaluate a checkpoint."""
    sampler = GANSampling(
        database=MiniImagenetDatabase(),
        network_cls=MiniImagenetModel,
        n=5,
        k=1,
        k_val_ml=5,
        k_val_train=1,
        k_val_val=15,
        k_val_test=15,
        k_test=1,
        meta_batch_size=4,
        num_steps_ml=5,
        lr_inner_ml=0.05,
        num_steps_validation=5,
        save_after_iterations=1000,
        meta_learning_rate=0.001,
        report_validation_frequency=50,
        log_train_images_after_iteration=250,
        number_of_tasks_val=100,
        number_of_tasks_test=1000,
        clip_gradients=True,
        experiment_name='mini_imagenet_interpolation_std_1.2_shift_5',
        val_seed=42,
        val_test_batch_norm_momentum=0.0,
    )

    sampler.train(iterations=60000)
    # Evaluate the checkpoint saved at iteration 16000.
    sampler.evaluate(
        iterations=50,
        use_val_batch_statistics=True,
        seed=42,
        iterations_to_load_from=16000,
    )
 def get_train_dataset(self):
     """Return the supervised meta-learning train dataset over mini-ImageNet."""
     return self.get_supervised_meta_learning_dataset(
         MiniImagenetDatabase().train_folders,
         n=self.n,
         k=self.k,
         k_validation=self.k_val_ml,
         meta_batch_size=self.meta_batch_size,
     )
 def get_train_dataset(self):
     """Mixed-domain train dataset: mini-ImageNet support folders with
     PlantDisease validation folders."""
     support_database = MiniImagenetDatabase()
     query_database = PlantDiseaseDatabase()
     return self.get_separated_supervised_meta_learning_dataset(
         support_database.train_folders,
         query_database.train_folders,
         n=self.n,
         k=self.k_ml,
         k_validation=self.k_val_ml,
         meta_batch_size=self.meta_batch_size,
     )
Example #9
0
def run_domain_attention():
    """Evaluate unsupervised domain attention with Fungi as the held-out domain."""
    train_domains = [
        MiniImagenetDatabase(),
        OmniglotDatabase(
            random_seed=47, num_train_classes=1200, num_val_classes=100),
        DTDDatabase(),
        VGGFlowerDatabase(),
    ]
    # Fungi is both the meta-train domain and the test domain here.
    meta_train_domains = [FungiDatabase()]
    test_database = FungiDatabase()

    da = DomainAttentionUnsupervised(
        train_databases=train_domains,
        meta_train_databases=meta_train_domains,
        database=test_database,
        test_database=test_database,
        network_cls=None,
        image_shape=(84, 84, 3),
        n=5,
        k_ml=1,
        k_val_ml=5,
        k_val=1,
        k_val_val=15,
        k_test=5,
        k_val_test=15,
        meta_batch_size=4,
        num_steps_ml=1,
        lr_inner_ml=0.05,
        num_steps_validation=5,
        save_after_iterations=5000,
        meta_learning_rate=0.001,
        report_validation_frequency=1000,
        log_train_images_after_iteration=1000,
        num_tasks_val=100,
        clip_gradients=True,
        experiment_name='domain_attention_all_frozen_layers_unsupervised_fungi2',
        val_seed=42,
        val_test_batch_norm_momentum=0.0,
    )

    print(da.experiment_name)

    # da.train(iterations=5000)
    da.evaluate(iterations=50, num_tasks=1000, seed=42)
Example #10
0
 def __init__(self, meta_train_databases=None, *args, **kwargs):
     """Forward to the parent constructor and select meta-train domains.

     When `meta_train_databases` is None, a default list of seven source
     domains is instantiated lazily (only in that case); otherwise the
     caller-supplied list is used as-is.
     """
     super(CombinedCrossDomainMetaLearning, self).__init__(*args, **kwargs)
     if meta_train_databases is not None:
         self.meta_train_databases = meta_train_databases
     else:
         self.meta_train_databases = [
             MiniImagenetDatabase(),
             AirplaneDatabase(),
             CUBDatabase(),
             OmniglotDatabase(
                 random_seed=47, num_train_classes=1200, num_val_classes=100),
             DTDDatabase(),
             FungiDatabase(),
             VGGFlowerDatabase(),
         ]
def run_domain_attention():
    """Train element-wise domain attention and evaluate on EuroSAT."""
    train_domains = [
        MiniImagenetDatabase(),
        OmniglotDatabase(
            random_seed=47, num_train_classes=1200, num_val_classes=100),
        DTDDatabase(),
        VGGFlowerDatabase(),
    ]
    meta_train_domains = [
        AirplaneDatabase(),
        FungiDatabase(),
        CUBDatabase(),
    ]
    test_database = EuroSatDatabase()

    ewda = ElementWiseDomainAttention(
        train_databases=train_domains,
        meta_train_databases=meta_train_domains,
        database=test_database,
        test_database=test_database,
        network_cls=None,
        image_shape=(84, 84, 3),
        n=5,
        k_ml=1,
        k_val_ml=5,
        k_val=1,
        k_val_val=15,
        k_test=5,
        k_val_test=15,
        meta_batch_size=4,
        num_steps_ml=5,
        lr_inner_ml=0.05,
        num_steps_validation=5,
        save_after_iterations=15000,
        meta_learning_rate=0.001,
        report_validation_frequency=1000,
        log_train_images_after_iteration=1000,
        num_tasks_val=100,
        clip_gradients=True,
        experiment_name='element_wise_domain_attention',
        val_seed=42,
        val_test_batch_norm_momentum=0.0,
    )
    ewda.train(iterations=60000)
    ewda.evaluate(iterations=50, num_tasks=1000, seed=14)
    def get_train_dataset(self):
        """Build the cross-domain meta-learning dataset over seven source domains."""
        source_databases = [
            MiniImagenetDatabase(),
            AirplaneDatabase(),
            CUBDatabase(),
            OmniglotDatabase(
                random_seed=47, num_train_classes=1200, num_val_classes=100),
            DTDDatabase(),
            FungiDatabase(),
            VGGFlowerDatabase(),
        ]
        return self.get_cross_domain_meta_learning_dataset(
            databases=source_databases,
            n=self.n,
            k_ml=self.k_ml,
            k_validation=self.k_val_ml,
            meta_batch_size=self.meta_batch_size,
        )
def run_mini_imagenet():
    """Train and evaluate an abstract MAML learner on mini-ImageNet."""
    database = MiniImagenetDatabase(random_seed=-1)

    learner = MAMLAbstractLearner(
        database=database,
        network_cls=MiniImagenetModel,
        n=5,
        k=4,
        meta_batch_size=1,
        num_steps_ml=5,
        lr_inner_ml=0.01,
        num_steps_validation=5,
        save_after_epochs=500,
        meta_learning_rate=0.001,
        report_validation_frequency=50,
        log_train_images_after_iteration=1000,
        least_number_of_tasks_val_test=50,
        clip_gradients=True,
    )

    learner.train(epochs=30000)
    learner.evaluate(50)
def run_omniglot():
    """Train MAML on 84x84 Omniglot and test transfer to mini-ImageNet."""
    omniglot_database = Omniglot84x84Database(
        random_seed=47,
        num_train_classes=1200,
        num_val_classes=100,
    )

    maml = ModelAgnosticMetaLearningModel(
        database=omniglot_database,
        test_database=MiniImagenetDatabase(),
        network_cls=MiniImagenetModel,
        n=5,
        k_ml=1,
        k_val_ml=5,
        k_val=1,
        k_val_val=15,
        k_test=5,
        k_val_test=15,
        meta_batch_size=4,
        num_steps_ml=5,
        lr_inner_ml=0.05,
        num_steps_validation=5,
        save_after_iterations=15000,
        meta_learning_rate=0.001,
        report_validation_frequency=1000,
        log_train_images_after_iteration=1000,
        num_tasks_val=100,
        clip_gradients=True,
        experiment_name='omniglot_84x84',
        val_seed=42,
        val_test_batch_norm_momentum=0.0,
    )

    maml.train(iterations=60000)
    # Evaluate with and then without validation batch statistics.
    for use_val_stats in (True, False):
        maml.evaluate(iterations=50,
                      num_tasks=1000,
                      use_val_batch_statistics=use_val_stats,
                      seed=42)
Example #15
0
def run_airplane():
    """Train a cross-domain autoencoder over six source domains and evaluate
    it on mini-ImageNet."""
    database = MiniImagenetDatabase()

    autoencoder = CrossDomainAE2(
        database=database,
        batch_size=512,
        # Every domain except mini-ImageNet, which is the evaluation target.
        domains=('airplane', 'fungi', 'cub', 'dtd', 'omniglot', 'vggflowers'),
    )

    experiment_name = 'all_domains_288'

    autoencoder.train(epochs=20, experiment_name=experiment_name)
    autoencoder.evaluate(
        10,
        num_tasks=1000,
        k_test=1,
        k_val_test=15,
        inner_learning_rate=0.001,
        experiment_name=experiment_name,
        seed=42,
    )
Example #16
0
            else:
                new_z = tf.stack([
                    z[0, ...] + (z[(i + 1) % self.n, ...] - z[0, ...]) * 0.3,
                    z[1, ...] + (z[(i + 2) % self.n, ...] - z[1, ...]) * 0.3,
                    z[2, ...] + (z[(i + 3) % self.n, ...] - z[2, ...]) * 0.3,
                    z[3, ...] + (z[(i + 4) % self.n, ...] - z[3, ...]) * 0.3,
                    z[4, ...] + (z[(i + 0) % self.n, ...] - z[4, ...]) * 0.3,
                ],
                                 axis=0)
                vectors.append(new_z)

        return vectors


if __name__ == '__main__':
    mini_imagenet_database = MiniImagenetDatabase(input_shape=(224, 224, 3))
    shape = (224, 224, 3)
    latent_dim = 120
    import os
    os.environ['TFHUB_CACHE_DIR'] = os.path.expanduser('~/tf_hub')

    gan = hub.load("https://tfhub.dev/deepmind/bigbigan-resnet50/1",
                   tags=[]).signatures['generate']
    setattr(gan, 'parser', MiniImagenetParser(shape=shape))

    maml_gan = MiniImageNetMAMLBigGan(
        gan=gan,
        latent_dim=latent_dim,
        generated_image_shape=shape,
        database=mini_imagenet_database,
        network_cls=FiveLayerResNet,
Example #17
0
class TestDatabases(unittest.TestCase):
    """Integration tests for the dataset databases and the episodic
    (n-way, k-shot) dataset builders.

    Parse functions are patched to the identity so every test works on
    raw example file addresses instead of decoded images.
    """

    def setUp(self):
        """Create the identity parse function and one handle per dataset."""
        def parse_function(example_address):
            return example_address

        self.parse_function = parse_function

        self.omniglot_database = OmniglotDatabase(random_seed=-1,
                                                  num_train_classes=1200,
                                                  num_val_classes=100)
        self.mini_imagenet_database = MiniImagenetDatabase()
        self.celeba_database = CelebADatabase()
        self.lfw_database = LFWDatabase()
        self.euro_sat_database = EuroSatDatabase()
        self.isic_database = ISICDatabase()
        self.chest_x_ray_database = ChestXRay8Database()

    @property
    def databases(self):
        """All database fixtures, as a tuple."""
        return (
            self.omniglot_database,
            self.mini_imagenet_database,
            self.celeba_database,
            self.lfw_database,
            self.euro_sat_database,
            self.isic_database,
            self.chest_x_ray_database,
        )

    def test_train_val_test_folders_are_separate(self):
        """Train/val/test class folders must be pairwise disjoint."""
        for database in (self.omniglot_database, self.mini_imagenet_database,
                         self.celeba_database):
            train_folders = set(database.train_folders)
            val_folders = set(database.val_folders)
            test_folders = set(database.test_folders)

            for folder in train_folders:
                self.assertNotIn(folder, val_folders)
                self.assertNotIn(folder, test_folders)

            for folder in val_folders:
                self.assertNotIn(folder, train_folders)
                self.assertNotIn(folder, test_folders)

            for folder in test_folders:
                self.assertNotIn(folder, train_folders)
                self.assertNotIn(folder, val_folders)

    def test_train_val_test_folders_are_comprehensive(self):
        """The splits together must cover the known class count per dataset."""
        self.assertEqual(
            1623,
            len(
                list(self.omniglot_database.train_folders.keys()) +
                list(self.omniglot_database.val_folders.keys()) +
                list(self.omniglot_database.test_folders.keys())))
        self.assertEqual(
            100,
            len(
                list(self.mini_imagenet_database.train_folders) +
                list(self.mini_imagenet_database.val_folders) +
                list(self.mini_imagenet_database.test_folders)))
        self.assertEqual(
            10177,
            len(
                list(self.celeba_database.train_folders) +
                list(self.celeba_database.val_folders) +
                list(self.celeba_database.test_folders)))

    @patch('tf_datasets.OmniglotDatabase._get_parse_function')
    def test_covering_all_classes_in_one_epoch(self, mocked_parse_function):
        """One epoch must visit every train class exactly once."""
        # Make a new database so that the number of classes are dividable by the number of meta batches * n.
        mocked_parse_function.return_value = self.parse_function
        ds = self.omniglot_database.get_supervised_meta_learning_dataset(
            self.omniglot_database.train_folders,
            n=5,
            k=5,
            k_validation=5,
            meta_batch_size=4,
            dtype=tf.string)

        # Check for covering all classes and no duplicates
        classes = set()
        for (x, x_val), (y, y_val) in ds:
            train_ds = combine_first_two_axes(x)
            for i in range(train_ds.shape[0]):
                class_instances = train_ds[i, ...]
                class_instance_address = class_instances[0].numpy().decode(
                    'utf-8')
                class_address = os.path.split(class_instance_address)[0]
                self.assertNotIn(class_address, classes)
                classes.add(class_address)

        self.assertSetEqual(classes,
                            set(self.omniglot_database.train_folders.keys()))

    @patch('tf_datasets.OmniglotDatabase._get_parse_function')
    def test_covering_all_classes_in_subsequent_epochs(self,
                                                       mocked_parse_function):
        """Several epochs together must cover all train classes."""
        # This test might not pass because of the random selection of classes at the beginning of the class, but the
        # chances are low.  Specially if we increase the number of epochs, the chance of not covering classes will
        # decrease. Make a new database so that the number of classes are dividable by the number of meta batches * n.
        # This test should always fail with num_epochs = 1
        num_epochs = 3
        mocked_parse_function.return_value = self.parse_function
        ds = self.omniglot_database.get_supervised_meta_learning_dataset(
            self.omniglot_database.train_folders,
            n=7,
            k=5,
            k_validation=5,
            meta_batch_size=4,
            seed=42,
            dtype=tf.string)
        # Check for covering all classes
        classes = set()
        for epoch in range(num_epochs):
            for (x, x_val), (y, y_val) in ds:
                train_ds = combine_first_two_axes(x)
                for i in range(train_ds.shape[0]):
                    class_instances = train_ds[i, ...]
                    class_instance_address = class_instances[0].numpy().decode(
                        'utf-8')
                    class_address = os.path.split(class_instance_address)[0]
                    classes.add(class_address)

        self.assertSetEqual(classes,
                            set(self.omniglot_database.train_folders.keys()))

    @patch('tf_datasets.OmniglotDatabase._get_parse_function')
    def test_labels_are_correct_in_train_and_val_for_every_task(
            self, mocked_parse_function):
        """Within a task, a class must get the same label in support and query."""
        mocked_parse_function.return_value = self.parse_function
        n = 7
        k = 5
        k_validation = 6
        ds = self.omniglot_database.get_supervised_meta_learning_dataset(
            self.omniglot_database.train_folders,
            n=n,
            k=k,
            k_validation=k_validation,
            meta_batch_size=4,
            one_hot_labels=False,
            dtype=tf.string)

        for epoch in range(2):
            for (x, x_val), (y, y_val) in ds:
                for meta_batch_index in range(x.shape[0]):
                    train_ds = x[meta_batch_index, ...]
                    val_ds = x_val[meta_batch_index, ...]
                    train_labels = y[meta_batch_index, ...]
                    val_labels = y_val[meta_batch_index, ...]

                    class_label_dict = dict()

                    for class_index in range(n):
                        for instance_index in range(k):
                            instance_name = train_ds[class_index,
                                                     instance_index, ...]
                            label = train_labels[class_index * k +
                                                 instance_index, ...].numpy()

                            class_name = os.path.split(
                                instance_name.numpy().decode('utf-8'))[0]
                            if class_name in class_label_dict:
                                self.assertEqual(class_label_dict[class_name],
                                                 label)
                            else:
                                class_label_dict[class_name] = label

                    for class_index in range(n):
                        for instance_index in range(k_validation):
                            instance_name = val_ds[class_index, instance_index,
                                                   ...]
                            label = val_labels[class_index * k_validation +
                                               instance_index, ...].numpy()
                            class_name = os.path.split(
                                instance_name.numpy().decode('utf-8'))[0]
                            self.assertIn(class_name, class_label_dict)
                            self.assertEqual(class_label_dict[class_name],
                                             label)

    @patch('tf_datasets.OmniglotDatabase._get_parse_function')
    def test_train_and_val_have_different_samples_in_every_task(
            self, mocked_parse_function):
        """Support and query sets of a task must not share instances."""
        mocked_parse_function.return_value = self.parse_function
        n = 6
        k = 4
        k_validation = 5
        ds = self.omniglot_database.get_supervised_meta_learning_dataset(
            self.omniglot_database.train_folders,
            n=n,
            k=k,
            k_validation=k_validation,
            meta_batch_size=4,
            one_hot_labels=False,
            dtype=tf.string)

        for epoch in range(4):
            for (x, x_val), (y, y_val) in ds:
                for meta_batch_index in range(x.shape[0]):
                    train_ds = x[meta_batch_index, ...]
                    val_ds = x_val[meta_batch_index, ...]

                    class_instances_dict = dict()

                    for class_index in range(n):
                        for instance_index in range(k):
                            instance_name = train_ds[class_index,
                                                     instance_index, ...]

                            class_name, instance_name = os.path.split(
                                instance_name.numpy().decode('utf-8'))
                            if class_name not in class_instances_dict:
                                class_instances_dict[class_name] = set()

                            class_instances_dict[class_name].add(instance_name)

                    for class_index in range(n):
                        for instance_index in range(k_validation):
                            instance_name = val_ds[class_index, instance_index,
                                                   ...]
                            class_name, instance_name = os.path.split(
                                instance_name.numpy().decode('utf-8'))
                            self.assertIn(class_name, class_instances_dict)
                            self.assertNotIn(instance_name,
                                             class_instances_dict[class_name])

    @patch('tf_datasets.OmniglotDatabase._get_parse_function')
    def test_no_two_class_in_the_same_task(self, mocked_parse_function):
        """Each class slot of a task must contain exactly k instances of one class."""
        mocked_parse_function.return_value = self.parse_function
        n = 6
        k = 4
        k_validation = 5
        ds = self.omniglot_database.get_supervised_meta_learning_dataset(
            self.omniglot_database.train_folders,
            n=n,
            k=k,
            k_validation=k_validation,
            meta_batch_size=4,
            one_hot_labels=False,
            dtype=tf.string)

        for epoch in range(4):
            for (x, x_val), (y, y_val) in ds:
                for base_data, k_value in zip((x, x_val), (k, k_validation)):
                    for meta_batch_index in range(x.shape[0]):
                        train_ds = base_data[meta_batch_index, ...]
                        classes = dict()

                        for class_index in range(n):
                            for instance_index in range(k_value):
                                instance_name = train_ds[class_index,
                                                         instance_index, ...]
                                class_name = os.path.split(
                                    instance_name.numpy().decode('utf-8'))[0]
                                classes[class_name] = classes.get(
                                    class_name, 0) + 1

                        for class_name, num_class_instances in classes.items():
                            self.assertEqual(k_value, num_class_instances)

    @patch('tf_datasets.OmniglotDatabase._get_parse_function')
    def test_different_instances_are_selected_from_each_class_for_train_and_val_each_time(
            self, mocked_parse_function):
        """Across epochs, each class should contribute differing instance sets."""
        #  Random seed is selected such that instances are not selected the same for the whole epoch.
        #  This test might fail due to change in random or behaviour of selecting the samples and it might not mean that the code
        #  does not work properly. Maybe from one task in two different times the same train and val data will be selected
        mocked_parse_function.return_value = self.parse_function
        n = 6
        k = 4
        ds = self.omniglot_database.get_supervised_meta_learning_dataset(
            self.omniglot_database.train_folders,
            n=n,
            k=k,
            k_validation=8,
            meta_batch_size=4,
            one_hot_labels=False,
            dtype=tf.string)

        class_instances = dict()
        class_instances[0] = dict()
        class_instances[1] = dict()

        for epoch in range(2):
            for (x, x_val), (y, y_val) in ds:
                for meta_batch_index in range(x.shape[0]):
                    train_ds = x[meta_batch_index, ...]

                    for class_index in range(n):
                        for instance_index in range(k):
                            instance_address = train_ds[class_index,
                                                        instance_index, ...]
                            class_name, instance_name = os.path.split(
                                instance_address.numpy().decode('utf-8'))
                            if class_name not in class_instances[epoch]:
                                class_instances[epoch][class_name] = set()
                            class_instances[epoch][class_name].add(
                                instance_name)

        first_epoch_class_instances = class_instances[0]
        second_epoch_class_instances = class_instances[1]

        for class_name in first_epoch_class_instances.keys():
            self.assertIn(class_name, second_epoch_class_instances)
            # Bug fix: the previous assertion compared the integer 0 with a
            # set (`assertNotEqual(0, <set>)`), which is always true because
            # a set never equals an int — the check was vacuous. Compare the
            # size of the difference instead.
            self.assertNotEqual(
                0,
                len(first_epoch_class_instances[class_name].difference(
                    second_epoch_class_instances[class_name])))

    @patch('tf_datasets.MiniImagenetDatabase._get_parse_function')
    def test_1000_tasks_are_completely_different(self, mocked_parse_function):
        """A seeded 1000-task validation set is unique within a pass and
        reproducible across passes."""
        # This test should run on MAML get validation dataset.
        # For now, I just copied the code from there here to make sure that everything works fine.
        # However, if that code changes, this test does not have any value.
        # TODO Make this test work with maml.get_validation_dataset and get_test_dataset functions.
        mocked_parse_function.return_value = self.parse_function

        def get_val_dataset():
            val_dataset = self.mini_imagenet_database.get_supervised_meta_learning_dataset(
                self.mini_imagenet_database.val_folders,
                n=6,
                k=4,
                k_validation=3,
                meta_batch_size=1,
                seed=42,
                dtype=tf.string)
            val_dataset = val_dataset.repeat(-1)
            val_dataset = val_dataset.take(1000)
            return val_dataset

        counter = 0
        xs = set()
        x_vals = set()

        xs_queue = deque()
        x_vals_queue = deque()

        for (x, x_val), (y, y_val) in get_val_dataset():
            x_string = ','.join(list(map(str, x.numpy().reshape(-1))))
            # print(x_string)
            self.assertNotIn(x_string, xs)
            xs.add(x_string)

            x_val_string = ','.join(list(map(str, x_val.numpy().reshape(-1))))
            self.assertNotIn(x_val_string, x_vals)
            x_vals.add(x_val_string)

            xs_queue.append(x_string)
            x_vals_queue.append(x_val_string)
            counter += 1

        self.assertEqual(counter, 1000)
        for (x, x_val), (y, y_val) in get_val_dataset():
            x_string = ','.join(list(map(str, x.numpy().reshape(-1))))
            self.assertEqual(xs_queue.popleft(), x_string)

            x_val_string = ','.join(list(map(str, x_val.numpy().reshape(-1))))
            self.assertEqual(x_vals_queue.popleft(), x_val_string)
Example #18
0
        print("Error: Too many arguments")
        sys.exit(0)

    # CONFIGS
    ITERATIONS = 15000
    GAN_EPOCHS = 100
    N_TASK_EVAL = 100
    K = 5
    N_WAY = 5
    META_BATCH_SIZE = 1
    LASIUM_TYPE = "p1"

    print("K=",K)
    print("N_WAY=",N_WAY)

    mini_imagenet_database = MiniImagenetDatabase()
    shape = (84, 84, 3)
    latent_dim = 512
    mini_imagenet_generator = get_generator(latent_dim)
    mini_imagenet_discriminator = get_discriminator()
    mini_imagenet_parser = MiniImagenetParser(shape=shape)

    experiment_name = prefix+str(labeled_percentage)

    # for the SSGAN we need to feed the labels, L, when initializing
    gan = GAN(
        'mini_imagenet',
        image_shape=shape,
        latent_dim=latent_dim,
        database=mini_imagenet_database,
        parser=mini_imagenet_parser,
Example #19
0
    x = layers.Conv2D(64, 4, activation=None, strides=2, padding="same", use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU(alpha=0.2)(x)

    x = layers.GlobalMaxPooling2D()(x)

    discriminator_outputs = layers.Dense(1)(x)

    discriminator = keras.Model(discriminator_inputs, discriminator_outputs, name="discriminator")

    discriminator.summary()
    return discriminator


if __name__ == '__main__':
    mini_imagenet_database = MiniImagenetDatabase()
    shape = (84, 84, 3)
    latent_dim = 512
    mini_imagenet_generator = get_generator(latent_dim)
    mini_imagenet_discriminator = get_discriminator()
    mini_imagenet_parser = MiniImagenetParser(shape=shape)

    gan = GAN(
        'mini_imagenet',
        image_shape=shape,
        latent_dim=latent_dim,
        database=mini_imagenet_database,
        parser=mini_imagenet_parser,
        generator=mini_imagenet_generator,
        discriminator=mini_imagenet_discriminator,
        visualization_freq=1,
from databases import MiniImagenetDatabase

# Transfer-learning sanity check: fine-tune an ImageNet-pretrained VGG16
# on a single 5-way mini-ImageNet task and print predicted vs. true classes.
model = tf.keras.applications.VGG16(include_top=False,
                                    weights='imagenet',
                                    input_shape=(224, 224, 3))

# New 5-way classification head on top of the (trainable) VGG16 backbone.
flatten = tf.keras.layers.Flatten(name='flatten')(model.output)
fc1 = tf.keras.layers.Dense(512)(flatten)
fc2 = tf.keras.layers.Dense(512)(fc1)
fc3 = tf.keras.layers.Dense(5)(fc2)  # raw logits — no softmax activation

new_model = tf.keras.models.Model(inputs=[model.input], outputs=[fc3])
# FIX: fc3 produces unnormalized logits, so the loss must be told so.
# The original used CategoricalCrossentropy() whose default
# from_logits=False silently mis-computes the loss on logits.
new_model.compile(optimizer='adam',
                  loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True))

data_base = MiniImagenetDatabase(input_shape=(224, 224, 3))

# Presumably (folders, n=5, k=1, k_val=1, meta_batch_size=2), yielding
# 2 tasks x 5 classes x 1 shot = 10 examples per split — TODO confirm
# against get_supervised_meta_learning_dataset's signature.
dataset = data_base.get_supervised_meta_learning_dataset(
    data_base.train_folders, 5, 1, 1, 2)

for item in dataset:
    x, y = item
    x1, x2 = x  # (train, val) task splits; only the train split is used here
    y1, y2 = y
    x1 = tf.reshape(x1, (10, 224, 224, 3))
    y1 = tf.reshape(y1, (10, 5))  # one-hot: 10 examples x 5 classes
    new_model.fit(x1, y1)
    # FIX: argmax over the class axis (-1) gives the per-example class index.
    # The original used axis=0, which argmaxes over the batch dimension and
    # does not produce comparable per-example predictions/labels.
    print(tf.argmax(new_model.predict(x1), axis=-1))
    print(tf.argmax(y1, axis=-1))
    break
        # Put each domain's title along the top edge of its column.
        axes[0, i].xaxis.set_label_position('top')
        axes[0, i].set_xlabel(title_name)

    # fig.suptitle('', fontsize=12, y=1)
    # Save the combined figure as a PDF, then display it interactively.
    plt.savefig(fname=os.path.join(root_folder_to_save, 'all_domains2.pdf'))
    plt.show()


if __name__ == '__main__':
    # Run visualization on CPU only: hide all GPUs from TensorFlow.
    tf.config.set_visible_devices([], 'GPU')
    root_folder_to_save = os.path.expanduser('~/datasets_visualization/')
    if not os.path.exists(root_folder_to_save):
        os.mkdir(root_folder_to_save)

    # Databases whose sample images are visualized (consumed further below,
    # outside this fragment). Commented entries are domains excluded from
    # this run.
    databases = (
        MiniImagenetDatabase(),
        OmniglotDatabase(random_seed=42,
                         num_train_classes=1200,
                         num_val_classes=100),
        AirplaneDatabase(),
        CUBDatabase(),
        DTDDatabase(),
        FungiDatabase(),
        VGGFlowerDatabase(),
        # TrafficSignDatabase(),
        # MSCOCODatabase(),
        # PlantDiseaseDatabase(),
        # EuroSatDatabase(),
        # ISICDatabase(),
        # ChestXRay8Database(),
    )
    for style_layer_name in style_layers
]
# Feature model exposing the VGG19 style-layer activations collected above
# (the `outputs` list comprehension begins outside this fragment).
model = tf.keras.models.Model(inputs=vgg19.inputs, outputs=outputs)


def convert_to_activaitons(imgs):
    """Return the style-layer activations for each serialized image.

    Each element of *imgs* is decoded via ``convert_str_to_img``, given a
    leading batch axis, and pushed through the module-level ``model``.
    (Name typo kept intentionally: callers reference it as-is.)
    """
    return [
        model.predict(convert_str_to_img(serialized)[np.newaxis, ...])
        for serialized in imgs
    ]


# NOTE(review): instantiated only for its side effects (presumably dataset
# download/preparation); the instance itself is discarded.
DTDDatabase()
d1_imgs = get_all_instances(MiniImagenetDatabase())
d2_imgs = get_all_instances(ISICDatabase())
d3_imgs = get_all_instances(ChestXRay8Database())
d4_imgs = get_all_instances(PlantDiseaseDatabase())

# Sample 10 instances per domain without replacement.
d1_imgs = np.random.choice(d1_imgs, 10, replace=False)
d2_imgs = np.random.choice(d2_imgs, 10, replace=False)
d3_imgs = np.random.choice(d3_imgs, 10, replace=False)
d4_imgs = np.random.choice(d4_imgs, 10, replace=False)

print(d1_imgs)
print(d2_imgs)
print(d3_imgs)
print(d4_imgs)

# Extract VGG19 style-layer activations for each sampled instance
# (processing continues past this fragment).
d1_imgs_activations = convert_to_activaitons(d1_imgs)
Example #23
0
def run_mini_imagenet():
    """Train SML on mini-ImageNet using a VGG19 feature extractor.

    Loads a VGG19 classifier checkpoint (epoch 100) previously trained on
    sampled mini-ImageNet data points, truncates it at layer 24 to obtain a
    feature model, and uses that model to drive SML's task clustering.
    """
    mini_imagenet_database = MiniImagenetDatabase()
    # Number of labeled data points sampled from the train split.
    n_data_points = 38400
    data_points, classes, non_labeled_data_points = sample_data_points(
        mini_imagenet_database.train_folders, n_data_points)
    features_dataset, n_classes = make_features_dataset_mini_imagenet(
        data_points,
        classes,
        non_labeled_data_points,
        # shuffle_buffer_size=n_data_points,
        # batch_size=32,
        batch_size=16,
        shuffle_buffer_size=1000,
    )
    # VGG19 trained from scratch (weights=None) as an n_classes-way classifier.
    feature_model = tf.keras.applications.VGG19(weights=None,
                                                classes=n_classes)
    feature_model.compile(
        loss=tf.keras.losses.CategoricalCrossentropy(),
        metrics=['accuracy'],
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-6),
    )

    def save_call_back(epoch, logs):
        # Checkpoint every 100 epochs (wired up via LambdaCallback when the
        # training lines below are re-enabled).
        if epoch % 100 == 0:
            feature_model.save_weights(
                filepath=f'./feature_models/feature_model_{epoch}')

    # saver_callback = tf.keras.callbacks.LambdaCallback(on_epoch_end=save_call_back)
    # feature_model.fit(features_dataset, epochs=101, callbacks=[saver_callback])
    # Training is disabled; load the epoch-100 checkpoint instead.
    feature_model.load_weights(filepath=f'./feature_models/feature_model_100')
    # feature_model.evaluate(features_dataset)
    # exit()
    # Truncate at layer 24 — presumably a 4096-d fully connected layer of
    # VGG19, matching feature_size=4096 below; TODO confirm layer index.
    feature_model = tf.keras.models.Model(
        inputs=feature_model.input, outputs=feature_model.layers[24].output)

    # print(n_classes)
    # feature_model = VariationalAutoEncoderFeature(input_shape=(84, 84, 3), latent_dim=32, n_classes=n_classes)
    # feature_model = train_the_feature_model_with_classification(
    #     feature_model,
    #     features_dataset,
    #     n_classes,
    #     mini_imagenet_database.input_shape
    # )

    # feature_model = None

    # use imagenet
    # base_model = tf.keras.applications.VGG19(weights='imagenet')
    # feature_model = tf.keras.models.Model(inputs=base_model.input, outputs=base_model.layers[24].output)

    sml = SML(
        database=mini_imagenet_database,
        network_cls=MiniImagenetModel,
        n=5,                          # classes per task
        k=1,                          # shots per class (meta-train)
        k_val_ml=5,
        k_val_val=15,
        k_val_test=15,
        k_test=1,
        meta_batch_size=4,
        num_steps_ml=5,               # inner-loop adaptation steps
        lr_inner_ml=0.05,             # inner-loop learning rate
        num_steps_validation=5,
        save_after_iterations=15000,
        meta_learning_rate=0.001,
        n_clusters=500,               # clusters for SML task construction
        feature_model=feature_model,
        # feature_size=288,
        feature_size=4096,            # dimensionality of layer-24 features
        input_shape=(224, 224, 3),    # VGG19 input size for feature extraction
        preprocess_function=tf.keras.applications.vgg19.preprocess_input,
        log_train_images_after_iteration=1000,
        least_number_of_tasks_val_test=100,
        report_validation_frequency=250,
        experiment_name='mini_imagenet_learn_miniimagent_features')
    sml.train(iterations=60000)