Пример #1
0
def run_domain_attention():
    """Evaluate a ``DomainAttentionUnsupervised`` learner on Fungi.

    Instantiates the source-domain databases and the meta-train databases,
    builds the learner with a fixed set of hyperparameters, prints the
    experiment name and runs evaluation. The training call is disabled in
    this script, so only evaluation is executed.
    """
    source_domains = [
        MiniImagenetDatabase(),
        OmniglotDatabase(
            random_seed=47,
            num_train_classes=1200,
            num_val_classes=100,
        ),
        DTDDatabase(),
        VGGFlowerDatabase(),
    ]
    # Fungi is the only active meta-train domain here; Airplane and CUB
    # were switched off in this configuration.
    meta_train_domains = [FungiDatabase()]

    # The same database instance is passed as both `database` and
    # `test_database`.
    target_database = FungiDatabase()

    hyperparameters = dict(
        network_cls=None,
        image_shape=(84, 84, 3),
        n=5,
        k_ml=1,
        k_val_ml=5,
        k_val=1,
        k_val_val=15,
        k_test=5,
        k_val_test=15,
        meta_batch_size=4,
        num_steps_ml=1,
        lr_inner_ml=0.05,
        num_steps_validation=5,
        save_after_iterations=5000,
        meta_learning_rate=0.001,
        report_validation_frequency=1000,
        log_train_images_after_iteration=1000,
        num_tasks_val=100,
        clip_gradients=True,
        experiment_name='domain_attention_all_frozen_layers_unsupervised_fungi2',
        val_seed=42,
        val_test_batch_norm_momentum=0.0,
    )

    learner = DomainAttentionUnsupervised(
        train_databases=source_domains,
        meta_train_databases=meta_train_domains,
        database=target_database,
        test_database=target_database,
        **hyperparameters,
    )

    print(learner.experiment_name)

    # Training is disabled; re-enable with learner.train(iterations=5000).
    learner.evaluate(iterations=50, num_tasks=1000, seed=42)
def run_fungi():
    """Train a MAML model on Fungi and evaluate it twice.

    The model is built with ``FungiDatabase`` as ``database`` and
    ``MiniImagenetDatabase`` as ``test_database``, trained for 60000
    iterations, then evaluated once with ``use_val_batch_statistics=True``
    and once with ``use_val_batch_statistics=False``.
    """
    train_database = FungiDatabase()
    evaluation_database = MiniImagenetDatabase()

    settings = dict(
        network_cls=MiniImagenetModel,
        n=5,
        k_ml=1,
        k_val_ml=5,
        k_val=1,
        k_val_val=15,
        k_test=5,
        k_val_test=15,
        meta_batch_size=4,
        num_steps_ml=5,
        lr_inner_ml=0.05,
        num_steps_validation=5,
        save_after_iterations=15000,
        meta_learning_rate=0.001,
        report_validation_frequency=1000,
        log_train_images_after_iteration=1000,
        num_tasks_val=100,
        clip_gradients=True,
        experiment_name='fungi',
        val_seed=42,
        val_test_batch_norm_momentum=0.0,
    )
    maml = ModelAgnosticMetaLearningModel(
        database=train_database,
        test_database=evaluation_database,
        **settings,
    )

    maml.train(iterations=60000)

    # Same evaluation run twice, toggling only the batch-statistics flag.
    for use_batch_stats in (True, False):
        maml.evaluate(
            50,
            seed=42,
            num_tasks=1000,
            use_val_batch_statistics=use_batch_stats,
        )
Пример #3
0
 def __init__(self, meta_train_databases=None, *args, **kwargs):
     """Initialise the parent class and store the meta-train databases.

     Args:
         meta_train_databases: Databases to store on the instance. When
             ``None``, a default list of seven databases is created.
         *args: Forwarded unchanged to the parent constructor.
         **kwargs: Forwarded unchanged to the parent constructor.
     """
     # The parent constructor runs before the attribute is assigned.
     super(CombinedCrossDomainMetaLearning, self).__init__(*args, **kwargs)

     if meta_train_databases is not None:
         self.meta_train_databases = meta_train_databases
         return

     # No databases supplied: fall back to the built-in default set.
     self.meta_train_databases = [
         MiniImagenetDatabase(),
         AirplaneDatabase(),
         CUBDatabase(),
         OmniglotDatabase(
             random_seed=47,
             num_train_classes=1200,
             num_val_classes=100,
         ),
         DTDDatabase(),
         FungiDatabase(),
         VGGFlowerDatabase(),
     ]
def run_domain_attention():
    """Train and evaluate ``ElementWiseDomainAttention`` on EuroSat.

    Instantiates four source-domain databases and three meta-train
    databases, builds the learner with ``EuroSatDatabase`` as both
    ``database`` and ``test_database``, trains for 60000 iterations and
    then evaluates with seed 14.
    """
    source_domains = [
        MiniImagenetDatabase(),
        OmniglotDatabase(
            random_seed=47,
            num_train_classes=1200,
            num_val_classes=100,
        ),
        DTDDatabase(),
        VGGFlowerDatabase(),
    ]
    meta_train_domains = [
        AirplaneDatabase(),
        FungiDatabase(),
        CUBDatabase(),
    ]

    target_database = EuroSatDatabase()

    hyperparameters = dict(
        network_cls=None,
        image_shape=(84, 84, 3),
        n=5,
        k_ml=1,
        k_val_ml=5,
        k_val=1,
        k_val_val=15,
        k_test=5,
        k_val_test=15,
        meta_batch_size=4,
        num_steps_ml=5,
        lr_inner_ml=0.05,
        num_steps_validation=5,
        save_after_iterations=15000,
        meta_learning_rate=0.001,
        report_validation_frequency=1000,
        log_train_images_after_iteration=1000,
        num_tasks_val=100,
        clip_gradients=True,
        experiment_name='element_wise_domain_attention',
        val_seed=42,
        val_test_batch_norm_momentum=0.0,
    )

    learner = ElementWiseDomainAttention(
        train_databases=source_domains,
        meta_train_databases=meta_train_domains,
        database=target_database,
        test_database=target_database,
        **hyperparameters,
    )
    learner.train(iterations=60000)
    learner.evaluate(iterations=50, num_tasks=1000, seed=14)
def run_airplane():
    """Evaluate a ``CrossDomainAE`` model using the Fungi database.

    NOTE(review): despite the function name, the database passed in is
    ``FungiDatabase`` and 'airplane' is one of the listed domains; confirm
    this is the intended setup.
    """
    source_domain_names = (
        'airplane',
        'cub',
        'dtd',
        'miniimagenet',
        'omniglot',
        'vggflowers',
    )
    autoencoder = CrossDomainAE(
        database=FungiDatabase(),
        batch_size=512,
        domains=source_domain_names,
    )

    # Training is disabled; re-enable with autoencoder.train(epochs=20).
    evaluation_settings = dict(
        num_tasks=1000,
        k_test=5,
        k_val_test=15,
        inner_learning_rate=0.001,
        seed=42,
    )
    autoencoder.evaluate(10, **evaluation_settings)
    def get_train_dataset(self):
        """Build the cross-domain meta-learning training dataset.

        Instantiates a fixed list of seven databases and passes it, with
        this object's ``n``, ``k_ml``, ``k_val_ml`` and ``meta_batch_size``,
        to ``get_cross_domain_meta_learning_dataset``.

        Returns:
            Whatever ``get_cross_domain_meta_learning_dataset`` returns.
        """
        source_databases = [
            MiniImagenetDatabase(),
            AirplaneDatabase(),
            CUBDatabase(),
            OmniglotDatabase(
                random_seed=47,
                num_train_classes=1200,
                num_val_classes=100,
            ),
            DTDDatabase(),
            FungiDatabase(),
            VGGFlowerDatabase(),
        ]

        return self.get_cross_domain_meta_learning_dataset(
            databases=source_databases,
            n=self.n,
            k_ml=self.k_ml,
            k_validation=self.k_val_ml,
            meta_batch_size=self.meta_batch_size,
        )

if __name__ == '__main__':
    # Make no GPU visible to TensorFlow for this script.
    tf.config.set_visible_devices([], 'GPU')

    # Output directory for the generated visualizations. makedirs with
    # exist_ok=True replaces the exists()/mkdir() pair: it has no
    # check-then-create race and also creates missing parent directories.
    root_folder_to_save = os.path.expanduser('~/datasets_visualization/')
    os.makedirs(root_folder_to_save, exist_ok=True)

    databases = (
        MiniImagenetDatabase(),
        OmniglotDatabase(random_seed=42,
                         num_train_classes=1200,
                         num_val_classes=100),
        AirplaneDatabase(),
        CUBDatabase(),
        DTDDatabase(),
        FungiDatabase(),
        VGGFlowerDatabase(),
        # Disabled in this run:
        # TrafficSignDatabase(),
        # MSCOCODatabase(),
        # PlantDiseaseDatabase(),
        # EuroSatDatabase(),
        # ISICDatabase(),
        # ChestXRay8Database(),
    )

    # Per-database visualization is disabled; to enable it, call
    # visualize_database(database, root_folder_to_save) for each database.

    visualize_all_domains_together(databases, root_folder_to_save)
                    z[0, ...] + (z[(i + 1) % self.n, ...] - z[0, ...]) * 0.3,
                    z[1, ...] + (z[(i + 2) % self.n, ...] - z[1, ...]) * 0.3,
                    z[2, ...] + (z[(i + 3) % self.n, ...] - z[2, ...]) * 0.3,
                    z[3, ...] + (z[(i + 4) % self.n, ...] - z[3, ...]) * 0.3,
                    z[4, ...] + (z[(i + 0) % self.n, ...] - z[4, ...]) * 0.3,
                ],
                                 axis=0)
                vectors.append(new_z)

        return vectors


if __name__ == '__main__':
    # tf.config.experimental_run_functions_eagerly(True)

    fungi_database = FungiDatabase()
    shape = (84, 84, 3)
    latent_dim = 512
    gan = StyleGAN(lr=0.0001, silent=False, training=False)
    gan.load(23)
    gan.trainable = False
    setattr(gan, 'parser', MiniImagenetParser(shape=shape))

    maml_gan = MiniImageNetMAMLStyleGan2(
        gan=gan,
        latent_dim=latent_dim,
        generated_image_shape=shape,
        database=fungi_database,
        network_cls=MiniImagenetModel,
        n=5,
        k_ml=1,