def run_domain_attention():
    """Train and then evaluate a DomainAttention model targeting the Airplane database.

    Four databases (MiniImagenet, Omniglot, DTD, VGG Flowers) are passed as
    ``train_databases``. One AirplaneDatabase instance is passed as
    ``meta_train_databases``. A second AirplaneDatabase instance is passed as
    both ``database`` and ``test_database``. The model is trained for 5000
    iterations and then evaluated with 1000 tasks and seed 14.
    """
    # Databases handed to ``train_databases``, constructed in this exact order.
    train_domains = [
        MiniImagenetDatabase(),
        OmniglotDatabase(random_seed=47, num_train_classes=1200, num_val_classes=100),
        DTDDatabase(),
        VGGFlowerDatabase(),
    ]
    # Databases handed to ``meta_train_databases``.
    # FungiDatabase() and CUBDatabase() were listed here as disabled alternatives.
    meta_train_domains = [AirplaneDatabase()]

    # A separate Airplane instance is used for both ``database`` and ``test_database``.
    target_domain = AirplaneDatabase()

    hyperparameters = dict(
        network_cls=None,
        image_shape=(84, 84, 3),
        n=5,
        k_ml=1,
        k_val_ml=5,
        k_val=1,
        k_val_val=15,
        k_test=1,
        k_val_test=15,
        meta_batch_size=4,
        num_steps_ml=5,
        lr_inner_ml=0.05,
        num_steps_validation=5,
        save_after_iterations=1000,
        meta_learning_rate=0.001,
        report_validation_frequency=1000,
        log_train_images_after_iteration=1000,
        num_tasks_val=100,
        clip_gradients=True,
        # Alternative run name: 'domain_attention_no_frozen_layers_airplane'.
        experiment_name='domain_attention_all_frozen_layers_airplane',
        val_seed=42,
        val_test_batch_norm_momentum=0.0,
    )

    model = DomainAttention(
        train_databases=train_domains,
        meta_train_databases=meta_train_domains,
        database=target_domain,
        test_database=target_domain,
        **hyperparameters,
    )

    model.train(iterations=5000)
    model.evaluate(iterations=50, num_tasks=1000, seed=14)


def run_airplane():
    """Build a MAML model on the Airplane database and run evaluation only.

    The training call is disabled. ``evaluate`` is invoked twice with identical
    arguments: first with ``use_val_batch_statistics=True``, then with
    ``use_val_batch_statistics=False``. The model is configured with a
    MiniImagenet ``test_database``.
    """
    maml = ModelAgnosticMetaLearningModel(
        database=AirplaneDatabase(),
        # Alternative test database:
        # Omniglot84x84Database(random_seed=47, num_train_classes=1200, num_val_classes=100)
        test_database=MiniImagenetDatabase(),
        network_cls=MiniImagenetModel,
        n=5,
        k_ml=1,
        k_val_ml=5,
        k_val=1,
        k_val_val=15,
        k_test=5,
        k_val_test=15,
        meta_batch_size=4,
        num_steps_ml=5,
        lr_inner_ml=0.05,
        num_steps_validation=5,
        save_after_iterations=15000,
        meta_learning_rate=0.001,
        report_validation_frequency=1000,
        log_train_images_after_iteration=1000,
        num_tasks_val=100,
        clip_gradients=True,
        experiment_name='airplane',
        val_seed=42,
        val_test_batch_norm_momentum=0.0,
    )

    # Training is intentionally disabled; re-enable with:
    # maml.train(iterations=60000)
    for use_val_stats in (True, False):
        maml.evaluate(50, num_tasks=1000, seed=42, use_val_batch_statistics=use_val_stats)
# Example #3
# 0
def run_airplane():
    """Train a CrossDomainAE2 model and evaluate it on the Airplane database.

    The model is trained for 20 epochs over six domains that exclude 'airplane'.
    It is then evaluated with 1000 tasks (k_test=1, k_val_test=15) under the
    experiment name 'all_domains_288'.
    """
    experiment_name = 'all_domains_288'

    cdae = CrossDomainAE2(
        database=AirplaneDatabase(),
        batch_size=512,
        # Alternative selections left disabled: ('fungi',),
        # ('cub', 'miniimagenet', 'vggflowers'), and the full set including 'airplane'.
        domains=('fungi', 'cub', 'dtd', 'miniimagenet', 'omniglot', 'vggflowers'),
    )

    cdae.train(epochs=20, experiment_name=experiment_name)

    evaluation_options = {
        'num_tasks': 1000,
        'k_test': 1,
        'k_val_test': 15,
        'inner_learning_rate': 0.001,
        'experiment_name': experiment_name,
        'seed': 42,
    }
    cdae.evaluate(10, **evaluation_options)
# Example #4
# 0
 def __init__(self, meta_train_databases=None, *args, **kwargs):
     """Initialise the parent class and store the meta-training databases.

     Args:
         meta_train_databases: databases to store on
             ``self.meta_train_databases``. When ``None`` (the default), a new
             list is built containing MiniImagenet, Airplane, CUB, Omniglot
             (seed 47, 1200 train / 100 val classes), DTD, Fungi and VGG Flower
             databases, in that order.
         *args, **kwargs: forwarded unchanged to the parent ``__init__``.

     NOTE(review): because ``meta_train_databases`` precedes ``*args``, the
     first positional argument binds here rather than to the parent
     constructor; callers should pass it by keyword.
     """
     super(CombinedCrossDomainMetaLearning, self).__init__(*args, **kwargs)

     # A caller-supplied collection is stored as-is (no copy).
     if meta_train_databases is not None:
         self.meta_train_databases = meta_train_databases
         return

     # Default set: databases are only constructed when none were supplied.
     self.meta_train_databases = [
         MiniImagenetDatabase(),
         AirplaneDatabase(),
         CUBDatabase(),
         OmniglotDatabase(random_seed=47, num_train_classes=1200, num_val_classes=100),
         DTDDatabase(),
         FungiDatabase(),
         VGGFlowerDatabase(),
     ]
    def get_train_dataset(self):
        """Return a cross-domain meta-learning dataset over seven fixed databases.

        The databases are MiniImagenet, Airplane, CUB, Omniglot (seed 47, 1200
        train / 100 val classes), DTD, Fungi and VGG Flower, in that order. The
        dataset is produced by ``self.get_cross_domain_meta_learning_dataset``
        using ``self.n``, ``self.k_ml``, ``self.k_val_ml`` (as
        ``k_validation``) and ``self.meta_batch_size``.

        NOTE(review): the database list is rebuilt on every call and ignores
        any instance-level database configuration; confirm this is intended.
        """
        source_databases = [
            MiniImagenetDatabase(),
            AirplaneDatabase(),
            CUBDatabase(),
            OmniglotDatabase(random_seed=47, num_train_classes=1200, num_val_classes=100),
            DTDDatabase(),
            FungiDatabase(),
            VGGFlowerDatabase(),
        ]
        return self.get_cross_domain_meta_learning_dataset(
            databases=source_databases,
            n=self.n,
            k_ml=self.k_ml,
            k_validation=self.k_val_ml,
            meta_batch_size=self.meta_batch_size,
        )
    # NOTE(review): no enclosing function header is visible for these two lines
    # (likely a copy/scrape artifact) — confirm where they belong. They save the
    # current matplotlib figure as 'all_domains2.pdf' under root_folder_to_save,
    # which must be defined in the surrounding scope, and then display it.
    plt.savefig(fname=os.path.join(root_folder_to_save, 'all_domains2.pdf'))
    plt.show()


if __name__ == '__main__':
    # Make no GPU devices visible to TensorFlow for this script.
    tf.config.set_visible_devices([], 'GPU')

    root_folder_to_save = os.path.expanduser('~/datasets_visualization/')
    # makedirs(exist_ok=True) replaces the exists()-then-mkdir() pair: it has no
    # check-then-create race and also creates any missing parent directories.
    os.makedirs(root_folder_to_save, exist_ok=True)

    # Databases to visualize. Further databases available but disabled:
    # TrafficSignDatabase, MSCOCODatabase, PlantDiseaseDatabase,
    # EuroSatDatabase, ISICDatabase, ChestXRay8Database.
    databases = (
        MiniImagenetDatabase(),
        OmniglotDatabase(random_seed=42,
                         num_train_classes=1200,
                         num_val_classes=100),
        AirplaneDatabase(),
        CUBDatabase(),
        DTDDatabase(),
        FungiDatabase(),
        VGGFlowerDatabase(),
    )

    # NOTE(review): `databases` is currently unused because the visualization
    # loop below is disabled; re-enable it to produce the per-database output.
    # for database in databases:
    #     visualize_database(database, root_folder_to_save)
from databases import AirplaneDatabase
from models.maml_umtra.maml_umtra import MAMLUMTRA
from networks.maml_umtra_networks import MiniImagenetModel

if __name__ == '__main__':
    # import tensorflow as tf
    # tf.config.experimental_run_functions_eagerly(True)

    airplane_database = AirplaneDatabase()

    maml_umtra = MAMLUMTRA(database=airplane_database,
                           network_cls=MiniImagenetModel,
                           n=5,
                           k_ml=2,
                           k_val_ml=5,
                           k_val=1,
                           k_val_val=15,
                           k_test=5,
                           k_val_test=15,
                           meta_batch_size=4,
                           num_steps_ml=5,
                           lr_inner_ml=0.05,
                           num_steps_validation=5,
                           save_after_iterations=15000,
                           meta_learning_rate=0.001,
                           report_validation_frequency=250,
                           log_train_images_after_iteration=1000,
                           num_tasks_val=100,
                           clip_gradients=True,
                           experiment_name='airplane',
                           val_seed=42,