# Exemplo n.º 1
# 0
def run_vgg_flower():
    """Train a MAML model (mini-ImageNet backbone) on VGG Flowers, then evaluate.

    Runs 60 040 meta-training iterations and evaluates twice: once using the
    validation batch-norm statistics and once without them.
    """
    # 5-way / 1-shot meta-learning configuration.
    settings = dict(
        n=5,
        k_ml=1,
        k_val_ml=5,
        k_val=1,
        k_val_val=15,
        k_test=1,
        k_val_test=15,
        meta_batch_size=4,
        num_steps_ml=5,
        lr_inner_ml=0.05,
        num_steps_validation=5,
        save_after_iterations=15000,
        meta_learning_rate=0.001,
        report_validation_frequency=1000,
        log_train_images_after_iteration=1000,
        num_tasks_val=100,
        clip_gradients=True,
        # NOTE(review): experiment name says 'dtd' but the database is VGG
        # Flowers — confirm this is intentional (e.g. checkpoint reuse).
        experiment_name='dtd',
        val_seed=42,
        val_test_batch_norm_momentum=0.0,
    )
    maml = ModelAgnosticMetaLearningModel(
        database=VGGFlowerDatabase(),
        network_cls=MiniImagenetModel,
        **settings,
    )

    maml.train(iterations=60040)
    for use_val_stats in (True, False):
        maml.evaluate(
            50, num_tasks=1000, seed=42, use_val_batch_statistics=use_val_stats)
# Exemplo n.º 2
# 0
def run_domain_attention():
    """Evaluate unsupervised domain attention with Fungi as the target domain.

    Training is currently disabled; this prints the experiment name and
    evaluates a previously trained model.
    """
    source_databases = [
        MiniImagenetDatabase(),
        OmniglotDatabase(
            random_seed=47, num_train_classes=1200, num_val_classes=100),
        DTDDatabase(),
        VGGFlowerDatabase(),
    ]
    # Only Fungi is used for meta-training here; Airplane and CUB were
    # previously in this list but are disabled.
    meta_train_databases = [
        FungiDatabase(),
    ]

    fungi_test = FungiDatabase()

    da = DomainAttentionUnsupervised(
        train_databases=source_databases,
        meta_train_databases=meta_train_databases,
        database=fungi_test,
        test_database=fungi_test,
        network_cls=None,
        image_shape=(84, 84, 3),
        n=5,
        k_ml=1,
        k_val_ml=5,
        k_val=1,
        k_val_val=15,
        k_test=5,
        k_val_test=15,
        meta_batch_size=4,
        num_steps_ml=1,
        lr_inner_ml=0.05,
        num_steps_validation=5,
        save_after_iterations=5000,
        meta_learning_rate=0.001,
        report_validation_frequency=1000,
        log_train_images_after_iteration=1000,
        num_tasks_val=100,
        clip_gradients=True,
        experiment_name='domain_attention_all_frozen_layers_unsupervised_fungi2',
        val_seed=42,
        val_test_batch_norm_momentum=0.0,
    )

    print(da.experiment_name)

    # da.train(iterations=5000)  # training intentionally disabled
    da.evaluate(iterations=50, num_tasks=1000, seed=42)
# Exemplo n.º 3
# 0
 def __init__(self, meta_train_databases=None, *args, **kwargs):
     """Initialize the model and select the meta-training databases.

     Args:
         meta_train_databases: Optional list of database instances. When
             None, a default pool of seven standard domains is built.
     """
     super(CombinedCrossDomainMetaLearning, self).__init__(*args, **kwargs)
     # An explicit None check (not `or`) so an intentionally empty list is kept.
     if meta_train_databases is not None:
         self.meta_train_databases = meta_train_databases
         return
     # Default pool of meta-training domains.
     self.meta_train_databases = [
         MiniImagenetDatabase(),
         AirplaneDatabase(),
         CUBDatabase(),
         OmniglotDatabase(
             random_seed=47, num_train_classes=1200, num_val_classes=100),
         DTDDatabase(),
         FungiDatabase(),
         VGGFlowerDatabase(),
     ]
def run_domain_attention():
    """Train element-wise domain attention and evaluate on EuroSAT."""
    source_databases = [
        MiniImagenetDatabase(),
        OmniglotDatabase(
            random_seed=47, num_train_classes=1200, num_val_classes=100),
        DTDDatabase(),
        VGGFlowerDatabase(),
    ]
    meta_train_databases = [
        AirplaneDatabase(),
        FungiDatabase(),
        CUBDatabase(),
    ]

    euro_sat = EuroSatDatabase()

    # 5-way / 1-shot meta-learning hyperparameters.
    hyper_params = dict(
        n=5, k_ml=1, k_val_ml=5, k_val=1, k_val_val=15, k_test=5,
        k_val_test=15, meta_batch_size=4, num_steps_ml=5, lr_inner_ml=0.05,
        num_steps_validation=5, save_after_iterations=15000,
        meta_learning_rate=0.001, report_validation_frequency=1000,
        log_train_images_after_iteration=1000, num_tasks_val=100,
        clip_gradients=True, experiment_name='element_wise_domain_attention',
        val_seed=42, val_test_batch_norm_momentum=0.0,
    )

    ewda = ElementWiseDomainAttention(
        train_databases=source_databases,
        meta_train_databases=meta_train_databases,
        database=euro_sat,
        test_database=euro_sat,
        network_cls=None,
        image_shape=(84, 84, 3),
        **hyper_params,
    )
    ewda.train(iterations=60000)
    # Evaluation seed here is 14 (the other scripts in this file use 42).
    ewda.evaluate(iterations=50, num_tasks=1000, seed=14)
    def get_train_dataset(self):
        """Build the cross-domain meta-learning training dataset.

        Tasks are drawn from all seven source domains using the instance's
        n / k_ml / k_val_ml / meta_batch_size settings.
        """
        source_databases = [
            MiniImagenetDatabase(),
            AirplaneDatabase(),
            CUBDatabase(),
            OmniglotDatabase(
                random_seed=47, num_train_classes=1200, num_val_classes=100),
            DTDDatabase(),
            FungiDatabase(),
            VGGFlowerDatabase(),
        ]
        return self.get_cross_domain_meta_learning_dataset(
            databases=source_databases,
            n=self.n,
            k_ml=self.k_ml,
            k_validation=self.k_val_ml,
            meta_batch_size=self.meta_batch_size,
        )
def run_airplane():
    """Train a cross-domain autoencoder on six domains, evaluate on VGG Flowers.

    NOTE(review): the function name says 'airplane' but the test database is
    VGG Flowers — confirm this is intentional.
    """
    experiment_name = 'all_domains_288'

    cdae = CrossDomainAE2(
        database=VGGFlowerDatabase(),
        batch_size=512,
        # Six source domains; 'vggflowers' is held out as the test domain.
        domains=('airplane', 'fungi', 'cub', 'dtd', 'miniimagenet',
                 'omniglot'),
    )

    cdae.train(epochs=20, experiment_name=experiment_name)
    cdae.evaluate(
        10,
        num_tasks=1000,
        k_test=1,
        k_val_test=15,
        inner_learning_rate=0.001,
        experiment_name=experiment_name,
        seed=42,
    )

if __name__ == '__main__':
    # Run on CPU only: hide all GPUs from TensorFlow.
    tf.config.set_visible_devices([], 'GPU')

    # Output directory for the visualizations.
    root_folder_to_save = os.path.expanduser('~/datasets_visualization/')
    # os.makedirs(..., exist_ok=True) replaces the check-then-create pair
    # (os.path.exists + os.mkdir), which is race-prone and fails if a parent
    # directory is missing.
    os.makedirs(root_folder_to_save, exist_ok=True)

    databases = (
        MiniImagenetDatabase(),
        OmniglotDatabase(random_seed=42,
                         num_train_classes=1200,
                         num_val_classes=100),
        AirplaneDatabase(),
        CUBDatabase(),
        DTDDatabase(),
        FungiDatabase(),
        VGGFlowerDatabase(),
        # Disabled domains:
        # TrafficSignDatabase(),
        # MSCOCODatabase(),
        # PlantDiseaseDatabase(),
        # EuroSatDatabase(),
        # ISICDatabase(),
        # ChestXRay8Database(),
    )

    # for database in databases:
    #     visualize_database(database, root_folder_to_save)

    visualize_all_domains_together(databases, root_folder_to_save)
from databases import VGGFlowerDatabase
from models.maml_umtra.maml_umtra import MAMLUMTRA
from networks.maml_umtra_networks import MiniImagenetModel

if __name__ == '__main__':
    # import tensorflow as tf
    # tf.config.experimental_run_functions_eagerly(True)

    vgg_flower_database = VGGFlowerDatabase()

    maml_umtra = MAMLUMTRA(
        database=vgg_flower_database,
        network_cls=MiniImagenetModel,
        n=5,
        k_ml=1,
        k_val_ml=1,
        k_val=1,
        k_val_val=15,
        k_test=5,
        k_val_test=15,
        meta_batch_size=4,
        num_steps_ml=5,
        lr_inner_ml=0.05,
        num_steps_validation=5,
        save_after_iterations=15000,
        meta_learning_rate=0.001,
        report_validation_frequency=250,
        log_train_images_after_iteration=1000,
        num_tasks_val=100,
        clip_gradients=True,
        experiment_name='vgg_flower',