Ejemplo n.º 1
0
def run_ccdml():
    """Train CombinedCrossDomainMetaLearning on EuroSat and evaluate it.

    Builds the learner with ``MiniImagenetModel`` as the network class and an
    ``EuroSatDatabase`` as its database, trains for 60000 iterations, then
    calls ``evaluate`` on that same learner (100 iterations, 1000 tasks,
    seed 14).
    """
    test_database = EuroSatDatabase()
    ccdml = CombinedCrossDomainMetaLearning(
        database=test_database,
        network_cls=MiniImagenetModel,
        n=5,
        k_ml=1,
        k_val_ml=5,
        k_val=1,
        k_val_val=15,
        k_test=1,
        k_val_test=15,
        meta_batch_size=4,
        num_steps_ml=5,
        lr_inner_ml=0.05,
        num_steps_validation=5,
        save_after_iterations=15000,
        meta_learning_rate=0.001,
        report_validation_frequency=1000,
        log_train_images_after_iteration=1000,
        num_tasks_val=100,
        clip_gradients=True,
        experiment_name='cdml',
        val_seed=42,
        val_test_batch_norm_momentum=0.0,
    )

    ccdml.train(iterations=60000)
    # Evaluate the learner that was just trained; `ccdml` is the only learner
    # object defined in this function.
    ccdml.evaluate(iterations=100, num_tasks=1000, seed=14)
Ejemplo n.º 2
0
def run_transfer_meta_learning():
    """Train TransferMetaLearningVGG16 from Mini-ImageNet towards EuroSat.

    Mini-ImageNet is passed as ``database``; a single EuroSat instance is
    shared between ``val_database`` and ``target_database``. After 6000
    training iterations the learner is evaluated twice with the same seed,
    first with ``use_val_batch_statistics=True`` and then with ``False``.
    """
    source_db = MiniImagenetDatabase()
    target_db = EuroSatDatabase()

    hyperparameters = {
        'database': source_db,
        'val_database': target_db,
        'target_database': target_db,
        'network_cls': None,
        'n': 5,
        'k': 1,
        'k_val_ml': 5,
        'k_val_val': 15,
        'k_val_test': 15,
        'k_test': 1,
        'meta_batch_size': 4,
        'num_steps_ml': 1,
        'lr_inner_ml': 0.001,
        'num_steps_validation': 5,
        'save_after_iterations': 1500,
        'meta_learning_rate': 0.0001,
        'report_validation_frequency': 250,
        'log_train_images_after_iteration': 1000,
        'number_of_tasks_val': 100,
        'number_of_tasks_test': 100,
        'clip_gradients': True,
        'experiment_name': 'transfer_meta_learning_mini_imagenet_euro_sat',
        'val_seed': 42,
        'val_test_batch_norm_momentum': 0.0,
        'random_layer_initialization_seed': 42,
        'num_trainable_layers': 3,
    }
    learner = TransferMetaLearningVGG16(**hyperparameters)

    learner.train(iterations=6000)
    # Same evaluation run twice, toggling only the batch-statistics flag.
    for use_val_stats in (True, False):
        learner.evaluate(50, seed=42, use_val_batch_statistics=use_val_stats)
Ejemplo n.º 3
0
def run_acdml():
    """Train AttentionCrossDomainMetaLearning and evaluate it.

    No training database is supplied (``database=None``); ISIC is passed as
    the validation database and EuroSat as the test database. The learner is
    trained for 60000 iterations and then evaluated for 100 iterations over
    1000 tasks with seed 14.
    """
    validation_db = ISICDatabase()
    evaluation_db = EuroSatDatabase()

    settings = dict(
        database=None,
        val_database=validation_db,
        test_database=evaluation_db,
        network_cls=get_assembled_model,
        n=5,
        k_ml=1,
        k_val_ml=5,
        k_val=1,
        k_val_val=15,
        k_test=1,
        k_val_test=15,
        meta_batch_size=4,
        num_steps_ml=5,
        lr_inner_ml=0.05,
        num_steps_validation=5,
        save_after_iterations=15000,
        meta_learning_rate=0.001,
        report_validation_frequency=1000,
        log_train_images_after_iteration=1000,
        num_tasks_val=100,
        clip_gradients=True,
        experiment_name='acdml_11',
        val_seed=42,
        val_test_batch_norm_momentum=0.0,
    )
    learner = AttentionCrossDomainMetaLearning(**settings)

    learner.train(iterations=60000)
    learner.evaluate(iterations=100, num_tasks=1000, seed=14)
def run_transfer_learning():
    """Evaluate TransferLearningVGG16 on EuroSat without a training step.

    The learner is built with ``num_trainable_layers=0`` and only
    ``evaluate`` is called (10 iterations, seed 42, using validation batch
    statistics).
    """
    options = {
        'database': EuroSatDatabase(),
        'n': 5,
        'k_val_test': 15,
        'k_test': 5,
        'lr_inner_ml': 0.01,
        'number_of_tasks_test': 100,
        'val_test_batch_norm_momentum': 0.0,
        'random_layer_initialization_seed': 42,
        'num_trainable_layers': 0,
    }
    learner = TransferLearningVGG16(**options)
    learner.evaluate(10, seed=42, use_val_batch_statistics=True)
Ejemplo n.º 5
0
    def setUp(self):
        """Create the parse function and database fixtures used by the tests."""
        # Identity parser: hands back the example address it was given.
        def parse_function(example_address):
            return example_address

        self.parse_function = parse_function

        omniglot_options = {
            'random_seed': -1,
            'num_train_classes': 1200,
            'num_val_classes': 100,
        }
        self.omniglot_database = OmniglotDatabase(**omniglot_options)

        # The remaining databases take no constructor arguments; build them
        # in a fixed order and attach each one under its attribute name.
        default_constructed = (
            ('mini_imagenet_database', MiniImagenetDatabase),
            ('celeba_database', CelebADatabase),
            ('lfw_database', LFWDatabase),
            ('euro_sat_database', EuroSatDatabase),
            ('isic_database', ISICDatabase),
            ('chest_x_ray_database', ChestXRay8Database),
        )
        for attribute_name, database_cls in default_constructed:
            setattr(self, attribute_name, database_cls())
def run_domain_attention():
    """Train ElementWiseDomainAttention and evaluate it on EuroSat.

    Four databases are passed as ``train_databases`` and three as
    ``meta_train_databases``; one EuroSat instance is used for both
    ``database`` and ``test_database``. The learner is trained for 60000
    iterations and evaluated for 50 iterations over 1000 tasks with seed 14.
    """
    omniglot = dict(random_seed=47, num_train_classes=1200, num_val_classes=100)
    source_domains = [
        MiniImagenetDatabase(),
        OmniglotDatabase(**omniglot),
        DTDDatabase(),
        VGGFlowerDatabase(),
    ]
    meta_source_domains = [AirplaneDatabase(), FungiDatabase(), CUBDatabase()]
    euro_sat = EuroSatDatabase()

    config = {
        'train_databases': source_domains,
        'meta_train_databases': meta_source_domains,
        'database': euro_sat,
        'test_database': euro_sat,
        'network_cls': None,
        'image_shape': (84, 84, 3),
        'n': 5,
        'k_ml': 1,
        'k_val_ml': 5,
        'k_val': 1,
        'k_val_val': 15,
        'k_test': 5,
        'k_val_test': 15,
        'meta_batch_size': 4,
        'num_steps_ml': 5,
        'lr_inner_ml': 0.05,
        'num_steps_validation': 5,
        'save_after_iterations': 15000,
        'meta_learning_rate': 0.001,
        'report_validation_frequency': 1000,
        'log_train_images_after_iteration': 1000,
        'num_tasks_val': 100,
        'clip_gradients': True,
        'experiment_name': 'element_wise_domain_attention',
        'val_seed': 42,
        'val_test_batch_norm_momentum': 0.0,
    }
    model = ElementWiseDomainAttention(**config)

    model.train(iterations=60000)
    model.evaluate(iterations=50, num_tasks=1000, seed=14)
Ejemplo n.º 7
0

if __name__ == '__main__':
    # Hide every GPU from TensorFlow for this script.
    tf.config.set_visible_devices([], 'GPU')

    # Output directory under the user's home. makedirs(exist_ok=True) replaces
    # the separate exists-check + mkdir: there is no window between check and
    # creation, and missing parent directories are created as well.
    root_folder_to_save = os.path.expanduser('~/datasets_visualization/')
    os.makedirs(root_folder_to_save, exist_ok=True)

    # Databases handed to the combined visualization, in this order.
    databases = (
        MiniImagenetDatabase(),
        OmniglotDatabase(random_seed=42,
                         num_train_classes=1200,
                         num_val_classes=100),
        AirplaneDatabase(),
        CUBDatabase(),
        DTDDatabase(),
        FungiDatabase(),
        VGGFlowerDatabase(),
        # TrafficSignDatabase(),
        MSCOCODatabase(),
        PlantDiseaseDatabase(),
        EuroSatDatabase(),
        ISICDatabase(),
        ChestXRay8Database(),
    )

    # Per-database visualization, currently disabled:
    # for database in databases:
    #     visualize_database(database, root_folder_to_save)

    visualize_all_domains_together(databases, root_folder_to_save)