def run_ccdml():
    """Train and then evaluate a CombinedCrossDomainMetaLearning model.

    A single EuroSatDatabase instance is passed as ``database``. The model is
    trained with ``train(iterations=60000)`` and then evaluated with
    ``evaluate(iterations=100, num_tasks=1000, seed=14)``.
    """
    test_database = EuroSatDatabase()
    ccdml = CombinedCrossDomainMetaLearning(
        database=test_database,
        network_cls=MiniImagenetModel,
        n=5,
        k_ml=1,
        k_val_ml=5,
        k_val=1,
        k_val_val=15,
        k_test=1,
        k_val_test=15,
        meta_batch_size=4,
        num_steps_ml=5,
        lr_inner_ml=0.05,
        num_steps_validation=5,
        save_after_iterations=15000,
        meta_learning_rate=0.001,
        report_validation_frequency=1000,
        log_train_images_after_iteration=1000,
        num_tasks_val=100,
        clip_gradients=True,
        experiment_name='cdml',
        val_seed=42,
        val_test_batch_norm_momentum=0.0,
    )

    ccdml.train(iterations=60000)
    # Evaluate the model that was just trained. This previously called
    # `acdml.evaluate(...)`; `acdml` is not defined in this function, so the
    # call failed with NameError only after the full training run.
    ccdml.evaluate(iterations=100, num_tasks=1000, seed=14)
def run_transfer_meta_learning():
    """Meta-train TransferMetaLearningVGG16 on Mini-ImageNet, target EuroSAT.

    One EuroSatDatabase instance serves as both ``val_database`` and
    ``target_database``. After ``train(iterations=6000)`` the model is
    evaluated twice with ``seed=42``: first with
    ``use_val_batch_statistics=True``, then with ``False``.
    """
    source_db = MiniImagenetDatabase()
    target_db = EuroSatDatabase()

    settings = dict(
        database=source_db,
        val_database=target_db,
        target_database=target_db,
        network_cls=None,
        n=5,
        k=1,
        k_val_ml=5,
        k_val_val=15,
        k_val_test=15,
        k_test=1,
        meta_batch_size=4,
        num_steps_ml=1,
        lr_inner_ml=0.001,
        num_steps_validation=5,
        save_after_iterations=1500,
        meta_learning_rate=0.0001,
        report_validation_frequency=250,
        log_train_images_after_iteration=1000,
        number_of_tasks_val=100,
        number_of_tasks_test=100,
        clip_gradients=True,
        experiment_name='transfer_meta_learning_mini_imagenet_euro_sat',
        val_seed=42,
        val_test_batch_norm_momentum=0.0,
        random_layer_initialization_seed=42,
        num_trainable_layers=3,
    )
    learner = TransferMetaLearningVGG16(**settings)

    learner.train(iterations=6000)
    # Evaluate with validation batch statistics first, then without.
    for use_val_stats in (True, False):
        learner.evaluate(50, seed=42, use_val_batch_statistics=use_val_stats)
def run_acdml():
    """Train and evaluate an AttentionCrossDomainMetaLearning model.

    No training database is passed (``database=None``). ISIC is used for
    validation and EuroSAT for testing. The model is trained with
    ``train(iterations=60000)`` and evaluated with
    ``evaluate(iterations=100, num_tasks=1000, seed=14)``.
    """
    # Construction order matches the original: ISIC first, then EuroSAT.
    validation_db = ISICDatabase()
    evaluation_db = EuroSatDatabase()

    hyperparams = dict(
        database=None,
        val_database=validation_db,
        test_database=evaluation_db,
        network_cls=get_assembled_model,
        n=5,
        k_ml=1,
        k_val_ml=5,
        k_val=1,
        k_val_val=15,
        k_test=1,
        k_val_test=15,
        meta_batch_size=4,
        num_steps_ml=5,
        lr_inner_ml=0.05,
        num_steps_validation=5,
        save_after_iterations=15000,
        meta_learning_rate=0.001,
        report_validation_frequency=1000,
        log_train_images_after_iteration=1000,
        num_tasks_val=100,
        clip_gradients=True,
        experiment_name='acdml_11',
        val_seed=42,
        val_test_batch_norm_momentum=0.0,
    )
    model = AttentionCrossDomainMetaLearning(**hyperparams)

    model.train(iterations=60000)
    model.evaluate(iterations=100, num_tasks=1000, seed=14)
def run_transfer_learning():
    """Evaluate TransferLearningVGG16 on EuroSAT without a training phase.

    ``num_trainable_layers=0`` is passed, and only
    ``evaluate(10, seed=42, use_val_batch_statistics=True)`` is called.
    """
    options = {
        'database': EuroSatDatabase(),
        'n': 5,
        'k_val_test': 15,
        'k_test': 5,
        'lr_inner_ml': 0.01,
        'number_of_tasks_test': 100,
        'val_test_batch_norm_momentum': 0.0,
        'random_layer_initialization_seed': 42,
        'num_trainable_layers': 0,
    }
    baseline = TransferLearningVGG16(**options)
    baseline.evaluate(10, seed=42, use_val_batch_statistics=True)
def setUp(self):
    """Create the fixtures shared by the tests in this case.

    Stores an identity ``parse_function`` and one instance of each dataset
    database on ``self``. The databases are built in a fixed order.
    """

    def parse_function(example_address):
        """Identity parser: return ``example_address`` unchanged."""
        return example_address

    self.parse_function = parse_function

    # Dataset databases, constructed one at a time in the original order.
    self.omniglot_database = OmniglotDatabase(
        random_seed=-1,
        num_train_classes=1200,
        num_val_classes=100,
    )
    self.mini_imagenet_database = MiniImagenetDatabase()
    self.celeba_database = CelebADatabase()
    self.lfw_database = LFWDatabase()
    self.euro_sat_database = EuroSatDatabase()
    self.isic_database = ISICDatabase()
    self.chest_x_ray_database = ChestXRay8Database()
def run_domain_attention():
    """Train and evaluate an ElementWiseDomainAttention model.

    Four databases form ``train_databases`` and three form
    ``meta_train_databases``. A single EuroSatDatabase instance is passed
    as both ``database`` and ``test_database``. The model is trained with
    ``train(iterations=60000)`` and evaluated with
    ``evaluate(iterations=50, num_tasks=1000, seed=14)``.
    """
    # Databases are constructed in the same order as before; both groups stay lists.
    domain_sources = [
        MiniImagenetDatabase(),
        OmniglotDatabase(random_seed=47, num_train_classes=1200, num_val_classes=100),
        DTDDatabase(),
        VGGFlowerDatabase(),
    ]
    meta_sources = [
        AirplaneDatabase(),
        FungiDatabase(),
        CUBDatabase(),
    ]
    eurosat = EuroSatDatabase()

    config = dict(
        train_databases=domain_sources,
        meta_train_databases=meta_sources,
        database=eurosat,
        test_database=eurosat,
        network_cls=None,
        image_shape=(84, 84, 3),
        n=5,
        k_ml=1,
        k_val_ml=5,
        k_val=1,
        k_val_val=15,
        k_test=5,
        k_val_test=15,
        meta_batch_size=4,
        num_steps_ml=5,
        lr_inner_ml=0.05,
        num_steps_validation=5,
        save_after_iterations=15000,
        meta_learning_rate=0.001,
        report_validation_frequency=1000,
        log_train_images_after_iteration=1000,
        num_tasks_val=100,
        clip_gradients=True,
        experiment_name='element_wise_domain_attention',
        val_seed=42,
        val_test_batch_norm_momentum=0.0,
    )
    attention_model = ElementWiseDomainAttention(**config)

    attention_model.train(iterations=60000)
    attention_model.evaluate(iterations=50, num_tasks=1000, seed=14)
if __name__ == '__main__':
    # Hide all GPUs from TensorFlow for this visualization script.
    tf.config.set_visible_devices([], 'GPU')

    root_folder_to_save = os.path.expanduser('~/datasets_visualization/')
    # makedirs(exist_ok=True) replaces the exists()/mkdir() pair, which was
    # racy and could not create missing parent directories.
    os.makedirs(root_folder_to_save, exist_ok=True)

    databases = (
        MiniImagenetDatabase(),
        OmniglotDatabase(random_seed=42, num_train_classes=1200, num_val_classes=100),
        AirplaneDatabase(),
        CUBDatabase(),
        DTDDatabase(),
        FungiDatabase(),
        VGGFlowerDatabase(),
        # TrafficSignDatabase(),  # NOTE(review): disabled; reason not recorded here.
        MSCOCODatabase(),
        PlantDiseaseDatabase(),
        EuroSatDatabase(),
        ISICDatabase(),
        ChestXRay8Database(),
    )

    # Alternative: visualize each database separately instead of all together.
    # for database in databases:
    #     visualize_database(database, root_folder_to_save)
    visualize_all_domains_together(databases, root_folder_to_save)