def run_domain_attention():
    """Train a DomainAttention model and report its few-shot test accuracy.

    Source domains (mini-ImageNet, Omniglot, DTD, VGG-Flowers) feed the
    attention mechanism; Airplane serves as both the meta-train target and
    the held-out test domain. Trains for 5000 iterations, then evaluates
    over 1000 tasks with a fixed seed.
    """
    source_databases = [
        MiniImagenetDatabase(),
        OmniglotDatabase(random_seed=47, num_train_classes=1200, num_val_classes=100),
        DTDDatabase(),
        VGGFlowerDatabase(),
    ]
    target_databases = [
        AirplaneDatabase(),
        # FungiDatabase(),
        # CUBDatabase(),
    ]
    # One Airplane instance is shared as both the training target and the test database.
    airplane_database = AirplaneDatabase()

    model = DomainAttention(
        train_databases=source_databases,
        meta_train_databases=target_databases,
        database=airplane_database,
        test_database=airplane_database,
        network_cls=None,
        image_shape=(84, 84, 3),
        n=5,
        k_ml=1,
        k_val_ml=5,
        k_val=1,
        k_val_val=15,
        k_test=1,
        k_val_test=15,
        meta_batch_size=4,
        num_steps_ml=5,
        lr_inner_ml=0.05,
        num_steps_validation=5,
        save_after_iterations=1000,
        meta_learning_rate=0.001,
        report_validation_frequency=1000,
        log_train_images_after_iteration=1000,
        num_tasks_val=100,
        clip_gradients=True,
        # experiment_name='domain_attention_no_frozen_layers_airplane',
        experiment_name='domain_attention_all_frozen_layers_airplane',
        val_seed=42,
        val_test_batch_norm_momentum=0.0,
    )
    model.train(iterations=5000)
    model.evaluate(iterations=50, num_tasks=1000, seed=14)
def run_airplane():
    """Evaluate a MAML model trained on Airplane against mini-ImageNet.

    Builds a ModelAgnosticMetaLearningModel with the Airplane database for
    (meta-)training and mini-ImageNet as the cross-domain test database,
    then runs evaluation twice — once with and once without validation
    batch statistics. The training call itself is commented out, so this
    function only evaluates a previously saved model.
    """
    maml = ModelAgnosticMetaLearningModel(
        database=AirplaneDatabase(),
        # test_database=Omniglot84x84Database(random_seed=47, num_train_classes=1200, num_val_classes=100),
        test_database=MiniImagenetDatabase(),
        network_cls=MiniImagenetModel,
        n=5,
        k_ml=1,
        k_val_ml=5,
        k_val=1,
        k_val_val=15,
        k_test=5,
        k_val_test=15,
        meta_batch_size=4,
        num_steps_ml=5,
        lr_inner_ml=0.05,
        num_steps_validation=5,
        save_after_iterations=15000,
        meta_learning_rate=0.001,
        report_validation_frequency=1000,
        log_train_images_after_iteration=1000,
        num_tasks_val=100,
        clip_gradients=True,
        experiment_name='airplane',
        val_seed=42,
        val_test_batch_norm_momentum=0.0,
    )
    # maml.train(iterations=60000)
    maml.evaluate(50, num_tasks=1000, seed=42, use_val_batch_statistics=True)
    maml.evaluate(50, num_tasks=1000, seed=42, use_val_batch_statistics=False)
def run_airplane():
    """Train and evaluate a CrossDomainAE2 autoencoder with Airplane as the test domain.

    Trains the autoencoder for 20 epochs on six source domains (Airplane
    itself excluded from the `domains` tuple), then evaluates 1-shot /
    15-query tasks.

    NOTE(review): another `run_airplane` appears earlier in this file; if
    both live in the same module, this later definition shadows the earlier
    one — confirm they belong to separate scripts.
    """
    experiment_name = 'all_domains_288'
    cdae = CrossDomainAE2(
        database=AirplaneDatabase(),
        batch_size=512,
        # domains=('fungi', ),
        # domains=('airplane', 'fungi', 'cub', 'dtd', 'miniimagenet', 'omniglot', 'vggflowers'),
        domains=('fungi', 'cub', 'dtd', 'miniimagenet', 'omniglot', 'vggflowers'),
        # domains=('cub', 'miniimagenet', 'vggflowers'),
    )
    cdae.train(epochs=20, experiment_name=experiment_name)
    cdae.evaluate(
        10,
        num_tasks=1000,
        k_test=1,
        k_val_test=15,
        inner_learning_rate=0.001,
        experiment_name=experiment_name,
        seed=42,
    )
def __init__(self, meta_train_databases=None, *args, **kwargs):
    """Initialize the combined cross-domain meta-learner.

    Args:
        meta_train_databases: Optional list of database instances used for
            meta-training. When None, a default set of seven domains
            (mini-ImageNet, Airplane, CUB, Omniglot, DTD, Fungi,
            VGG-Flowers) is instantiated. A None default (rather than a
            list literal) avoids the shared-mutable-default pitfall.
        *args, **kwargs: Forwarded unchanged to the parent constructor.
    """
    super(CombinedCrossDomainMetaLearning, self).__init__(*args, **kwargs)
    if meta_train_databases is not None:
        self.meta_train_databases = meta_train_databases
    else:
        self.meta_train_databases = [
            MiniImagenetDatabase(),
            AirplaneDatabase(),
            CUBDatabase(),
            OmniglotDatabase(random_seed=47, num_train_classes=1200, num_val_classes=100),
            DTDDatabase(),
            FungiDatabase(),
            VGGFlowerDatabase(),
        ]
def get_train_dataset(self):
    """Build the cross-domain meta-learning training dataset.

    Instantiates all seven source domains and delegates to
    `get_cross_domain_meta_learning_dataset`, forwarding the instance's
    n/k/meta-batch configuration.

    Returns:
        The dataset produced by `get_cross_domain_meta_learning_dataset`.
    """
    return self.get_cross_domain_meta_learning_dataset(
        databases=[
            MiniImagenetDatabase(),
            AirplaneDatabase(),
            CUBDatabase(),
            OmniglotDatabase(random_seed=47, num_train_classes=1200, num_val_classes=100),
            DTDDatabase(),
            FungiDatabase(),
            VGGFlowerDatabase(),
        ],
        n=self.n,
        k_ml=self.k_ml,
        k_validation=self.k_val_ml,
        meta_batch_size=self.meta_batch_size,
    )
# NOTE(review): this chunk's formatting is mangled — multiple statements are collapsed
# onto a single line, so original indentation (and hence block structure) is lost.
# The leading plt.savefig / plt.show calls presumably form the tail of a plotting
# function defined before this chunk — confirm against the full file before reformatting.
# The __main__ guard: hides all GPUs from TensorFlow, ensures ~/datasets_visualization/
# exists (os.mkdir — assumes the parent directory exists; os.makedirs would be safer,
# TODO confirm), builds a tuple of seven database instances (six more are commented
# out), and the per-database visualization loop at the end is commented out entirely.
plt.savefig(fname=os.path.join(root_folder_to_save, 'all_domains2.pdf')) plt.show() if __name__ == '__main__': tf.config.set_visible_devices([], 'GPU') root_folder_to_save = os.path.expanduser('~/datasets_visualization/') if not os.path.exists(root_folder_to_save): os.mkdir(root_folder_to_save) databases = ( MiniImagenetDatabase(), OmniglotDatabase(random_seed=42, num_train_classes=1200, num_val_classes=100), AirplaneDatabase(), CUBDatabase(), DTDDatabase(), FungiDatabase(), VGGFlowerDatabase(), # TrafficSignDatabase(), # MSCOCODatabase(), # PlantDiseaseDatabase(), # EuroSatDatabase(), # ISICDatabase(), # ChestXRay8Database(), ) # for database in databases: # visualize_database(database, root_folder_to_save)
# NOTE(review): this chunk is truncated — the MAMLUMTRA(...) constructor call ends
# mid-argument-list (after val_seed=42,) with no closing parenthesis in view, and the
# original indentation has been collapsed onto one line. Left byte-identical; the
# remainder of the call (and any train/evaluate invocations) lies past this chunk.
# Visible behavior: imports the Airplane database, the MAMLUMTRA model, and the
# mini-ImageNet network class, then — under the __main__ guard — constructs a
# MAMLUMTRA instance configured for 5-way tasks (k_ml=2, meta_batch_size=4,
# experiment_name='airplane'). The commented-out lines would enable eager execution
# of tf.functions for debugging.
from databases import AirplaneDatabase from models.maml_umtra.maml_umtra import MAMLUMTRA from networks.maml_umtra_networks import MiniImagenetModel if __name__ == '__main__': # import tensorflow as tf # tf.config.experimental_run_functions_eagerly(True) airplane_database = AirplaneDatabase() maml_umtra = MAMLUMTRA(database=airplane_database, network_cls=MiniImagenetModel, n=5, k_ml=2, k_val_ml=5, k_val=1, k_val_val=15, k_test=5, k_val_test=15, meta_batch_size=4, num_steps_ml=5, lr_inner_ml=0.05, num_steps_validation=5, save_after_iterations=15000, meta_learning_rate=0.001, report_validation_frequency=250, log_train_images_after_iteration=1000, num_tasks_val=100, clip_gradients=True, experiment_name='airplane', val_seed=42,