def run_domain_attention():
    """Evaluate the unsupervised domain-attention experiment on Fungi.

    Builds four source databases, uses Fungi alone for meta-training, and
    passes a Fungi database as both ``database`` and ``test_database`` of a
    ``DomainAttentionUnsupervised`` model. It then prints the experiment
    name and runs evaluation. Training is currently disabled.
    """
    # Databases handed to ``train_databases``.
    source_domains = [
        MiniImagenetDatabase(),
        OmniglotDatabase(random_seed=47, num_train_classes=1200, num_val_classes=100),
        DTDDatabase(),
        VGGFlowerDatabase(),
    ]
    # Only Fungi is used here; AirplaneDatabase() and CUBDatabase() were
    # previously part of this list but are disabled for this experiment.
    meta_train_domains = [FungiDatabase()]
    target_domain = FungiDatabase()

    settings = {
        'network_cls': None,
        'image_shape': (84, 84, 3),
        'n': 5,
        'k_ml': 1,
        'k_val_ml': 5,
        'k_val': 1,
        'k_val_val': 15,
        'k_test': 5,
        'k_val_test': 15,
        'meta_batch_size': 4,
        'num_steps_ml': 1,
        'lr_inner_ml': 0.05,
        'num_steps_validation': 5,
        'save_after_iterations': 5000,
        'meta_learning_rate': 0.001,
        'report_validation_frequency': 1000,
        'log_train_images_after_iteration': 1000,
        'num_tasks_val': 100,
        'clip_gradients': True,
        'experiment_name': 'domain_attention_all_frozen_layers_unsupervised_fungi2',
        'val_seed': 42,
        'val_test_batch_norm_momentum': 0.0,
    }
    model = DomainAttentionUnsupervised(
        train_databases=source_domains,
        meta_train_databases=meta_train_domains,
        database=target_domain,
        test_database=target_domain,
        **settings,
    )
    print(model.experiment_name)
    # Training is disabled; re-enable with: model.train(iterations=5000)
    model.evaluate(iterations=50, num_tasks=1000, seed=42)
def run_fungi():
    """Train MAML on Fungi, then evaluate it against Mini-ImageNet.

    Trains for 60000 iterations and evaluates twice on 1000 tasks with
    seed 42. The first evaluation runs with ``use_val_batch_statistics=True``
    and the second with ``False``.
    """
    source = FungiDatabase()
    target = MiniImagenetDatabase()
    maml = ModelAgnosticMetaLearningModel(
        database=source,
        test_database=target,
        network_cls=MiniImagenetModel,
        n=5,
        k_ml=1,
        k_val_ml=5,
        k_val=1,
        k_val_val=15,
        k_test=5,
        k_val_test=15,
        meta_batch_size=4,
        num_steps_ml=5,
        lr_inner_ml=0.05,
        num_steps_validation=5,
        save_after_iterations=15000,
        meta_learning_rate=0.001,
        report_validation_frequency=1000,
        log_train_images_after_iteration=1000,
        num_tasks_val=100,
        clip_gradients=True,
        experiment_name='fungi',
        val_seed=42,
        val_test_batch_norm_momentum=0.0,
    )
    maml.train(iterations=60000)
    # Evaluate with and without validation batch statistics, in that order.
    for use_val_stats in (True, False):
        maml.evaluate(50, seed=42, num_tasks=1000, use_val_batch_statistics=use_val_stats)
def __init__(self, meta_train_databases=None, *args, **kwargs):
    """Initialise the base class and choose the meta-training databases.

    All remaining arguments are forwarded to the parent ``__init__``.

    Args:
        meta_train_databases: databases stored on
            ``self.meta_train_databases``. When ``None``, a fixed list of
            seven databases is used.

    NOTE(review): because this parameter comes before ``*args``, a caller
    passing arguments positionally has its first positional argument
    captured here instead of being forwarded. Pass it as a keyword.
    """
    super(CombinedCrossDomainMetaLearning, self).__init__(*args, **kwargs)
    if meta_train_databases is not None:
        self.meta_train_databases = meta_train_databases
        return

    self.meta_train_databases = [
        MiniImagenetDatabase(),
        AirplaneDatabase(),
        CUBDatabase(),
        OmniglotDatabase(random_seed=47, num_train_classes=1200, num_val_classes=100),
        DTDDatabase(),
        FungiDatabase(),
        VGGFlowerDatabase(),
    ]
def run_domain_attention():
    """Train and evaluate element-wise domain attention with EuroSat as target.

    Four databases go to ``train_databases``, and Airplane, Fungi and CUB go
    to ``meta_train_databases``. EuroSat is both ``database`` and
    ``test_database``. Training runs for 60000 iterations, followed by
    evaluation on 1000 tasks with seed 14.
    """
    source_domains = [
        MiniImagenetDatabase(),
        OmniglotDatabase(random_seed=47, num_train_classes=1200, num_val_classes=100),
        DTDDatabase(),
        VGGFlowerDatabase(),
    ]
    meta_train_domains = [
        AirplaneDatabase(),
        FungiDatabase(),
        CUBDatabase(),
    ]
    target_domain = EuroSatDatabase()

    settings = dict(
        network_cls=None,
        image_shape=(84, 84, 3),
        n=5,
        k_ml=1,
        k_val_ml=5,
        k_val=1,
        k_val_val=15,
        k_test=5,
        k_val_test=15,
        meta_batch_size=4,
        num_steps_ml=5,
        lr_inner_ml=0.05,
        num_steps_validation=5,
        save_after_iterations=15000,
        meta_learning_rate=0.001,
        report_validation_frequency=1000,
        log_train_images_after_iteration=1000,
        num_tasks_val=100,
        clip_gradients=True,
        experiment_name='element_wise_domain_attention',
        val_seed=42,
        val_test_batch_norm_momentum=0.0,
    )
    model = ElementWiseDomainAttention(
        train_databases=source_domains,
        meta_train_databases=meta_train_domains,
        database=target_domain,
        test_database=target_domain,
        **settings,
    )
    model.train(iterations=60000)
    # NOTE(review): evaluation seed (14) differs from val_seed (42) and from
    # the seed used by sibling experiments — confirm this is intentional.
    model.evaluate(iterations=50, num_tasks=1000, seed=14)
def run_airplane():
    """Evaluate a ``CrossDomainAE`` whose ``database`` is Fungi.

    Despite the function name, the database passed in is Fungi. 'airplane'
    is only one of the six entries in ``domains``. Training is currently
    disabled; only evaluation runs.
    """
    source_domains = ('airplane', 'cub', 'dtd', 'miniimagenet', 'omniglot', 'vggflowers')
    autoencoder = CrossDomainAE(
        database=FungiDatabase(),
        batch_size=512,
        domains=source_domains,
    )
    # Training is disabled; re-enable with: autoencoder.train(epochs=20)
    autoencoder.evaluate(
        10,
        num_tasks=1000,
        k_test=5,
        k_val_test=15,
        inner_learning_rate=0.001,
        seed=42,
    )
def get_train_dataset(self):
    """Return the cross-domain meta-learning training dataset.

    Always uses the same seven databases. It does not read any per-instance
    list of databases. The episode settings ``n``, ``k_ml``, ``k_val_ml``
    (passed as ``k_validation``) and ``meta_batch_size`` come from ``self``.
    """
    all_domains = [
        MiniImagenetDatabase(),
        AirplaneDatabase(),
        CUBDatabase(),
        OmniglotDatabase(random_seed=47, num_train_classes=1200, num_val_classes=100),
        DTDDatabase(),
        FungiDatabase(),
        VGGFlowerDatabase(),
    ]
    return self.get_cross_domain_meta_learning_dataset(
        databases=all_domains,
        n=self.n,
        k_ml=self.k_ml,
        k_validation=self.k_val_ml,
        meta_batch_size=self.meta_batch_size,
    )
if __name__ == '__main__':
    # Make no GPU visible to TensorFlow, so this visualization runs on the CPU.
    tf.config.set_visible_devices([], 'GPU')

    root_folder_to_save = os.path.expanduser('~/datasets_visualization/')
    # makedirs(exist_ok=True) replaces the racy exists()+mkdir() pair and also
    # creates any missing parent directories.
    os.makedirs(root_folder_to_save, exist_ok=True)

    databases = (
        MiniImagenetDatabase(),
        OmniglotDatabase(random_seed=42, num_train_classes=1200, num_val_classes=100),
        AirplaneDatabase(),
        CUBDatabase(),
        DTDDatabase(),
        FungiDatabase(),
        VGGFlowerDatabase(),
        # Additional domains that can be enabled for visualization:
        # TrafficSignDatabase(),
        # MSCOCODatabase(),
        # PlantDiseaseDatabase(),
        # EuroSatDatabase(),
        # ISICDatabase(),
        # ChestXRay8Database(),
    )
    # Alternative: visualize each database separately with
    #     for database in databases:
    #         visualize_database(database, root_folder_to_save)
    visualize_all_domains_together(databases, root_folder_to_save)
z[0, ...] + (z[(i + 1) % self.n, ...] - z[0, ...]) * 0.3, z[1, ...] + (z[(i + 2) % self.n, ...] - z[1, ...]) * 0.3, z[2, ...] + (z[(i + 3) % self.n, ...] - z[2, ...]) * 0.3, z[3, ...] + (z[(i + 4) % self.n, ...] - z[3, ...]) * 0.3, z[4, ...] + (z[(i + 0) % self.n, ...] - z[4, ...]) * 0.3, ], axis=0) vectors.append(new_z) return vectors if __name__ == '__main__': # tf.config.experimental_run_functions_eagerly(True) fungi_database = FungiDatabase() shape = (84, 84, 3) latent_dim = 512 gan = StyleGAN(lr=0.0001, silent=False, training=False) gan.load(23) gan.trainable = False setattr(gan, 'parser', MiniImagenetParser(shape=shape)) maml_gan = MiniImageNetMAMLStyleGan2( gan=gan, latent_dim=latent_dim, generated_image_shape=shape, database=fungi_database, network_cls=MiniImagenetModel, n=5, k_ml=1,