# NOTE: the import paths below are assumed from the surrounding repository layout
# and may need adjusting; the original file defines/imports these names elsewhere.
from databases import (
    MiniImagenetDatabase,
    OmniglotDatabase,
    MSCOCODatabase,
    VoxCelebDatabase,
    CelebADatabase,
)
from models.maml.maml import ModelAgnosticMetaLearningModel
from networks.maml_umtra_networks import SimpleModel, MiniImagenetModel, VoxCelebModel


def run_mini_imagenet():
    # 5-way, 1-shot mini-ImageNet setup; a separate MiniImagenetDatabase
    # instance is passed as the test database.
    mini_imagenet_database = MiniImagenetDatabase()
    maml = ModelAgnosticMetaLearningModel(
        database=mini_imagenet_database,
        test_database=MiniImagenetDatabase(),
        network_cls=MiniImagenetModel,
        n=5,
        k_ml=1,
        k_val_ml=5,
        k_val=1,
        k_val_val=15,
        k_test=50,
        k_val_test=15,
        meta_batch_size=4,
        num_steps_ml=5,
        lr_inner_ml=0.05,
        num_steps_validation=5,
        save_after_iterations=15000,
        meta_learning_rate=0.001,
        report_validation_frequency=1000,
        log_train_images_after_iteration=1000,
        num_tasks_val=100,
        clip_gradients=True,
        experiment_name='mini_imagenet_test_res',
        val_seed=42,
        val_test_batch_norm_momentum=0.0,
    )

    # maml.train(iterations=60000)
    # Evaluate both with and without validation batch statistics.
    maml.evaluate(50, num_tasks=1000, seed=42, use_val_batch_statistics=True)
    maml.evaluate(50, num_tasks=1000, seed=42, use_val_batch_statistics=False)
def run_omniglot():
    omniglot_database = OmniglotDatabase(
        random_seed=47,
        num_train_classes=1200,
        num_val_classes=100,
    )
    maml = ModelAgnosticMetaLearningModel(
        database=omniglot_database,
        network_cls=SimpleModel,
        n=20,
        k_ml=1,
        k_val_ml=5,
        k_val=1,
        k_val_val=15,
        k_test=5,
        k_val_test=15,
        meta_batch_size=4,
        num_steps_ml=5,
        lr_inner_ml=0.4,
        num_steps_validation=5,
        save_after_iterations=1000,
        meta_learning_rate=0.001,
        report_validation_frequency=50,
        log_train_images_after_iteration=200,
        num_tasks_val=100,
        clip_gradients=False,
        experiment_name='omniglot',
        val_seed=42,
        val_test_batch_norm_momentum=0.0,
    )

    # maml.train(iterations=5000)
    maml.evaluate(iterations=50, num_tasks=1000, use_val_batch_statistics=True, seed=42)
def run_mscoco():
    # 5-way, 1-shot evaluation-only setup on MS-COCO (the body instantiates
    # MSCOCODatabase and names the experiment 'mscoco', so the function is
    # named accordingly).
    mscoco_database = MSCOCODatabase()
    maml = ModelAgnosticMetaLearningModel(
        database=mscoco_database,
        network_cls=MiniImagenetModel,
        n=5,
        k_ml=1,
        k_val_ml=5,
        k_val=1,
        k_val_val=15,
        k_test=1,
        k_val_test=15,
        meta_batch_size=4,
        num_steps_ml=5,
        lr_inner_ml=0.05,
        num_steps_validation=5,
        save_after_iterations=15000,
        meta_learning_rate=0.001,
        report_validation_frequency=1000,
        log_train_images_after_iteration=1000,
        num_tasks_val=100,
        clip_gradients=True,
        experiment_name='mscoco',
        val_seed=42,
        val_test_batch_norm_momentum=0.0,
    )

    # This dataset is only used for evaluation, so no train() call is made.
    maml.evaluate(50, num_tasks=1000, seed=42, use_val_batch_statistics=True)
    maml.evaluate(50, num_tasks=1000, seed=42, use_val_batch_statistics=False)
def run_voxceleb():
    voxceleb_database = VoxCelebDatabase()
    maml = ModelAgnosticMetaLearningModel(
        database=voxceleb_database,
        network_cls=VoxCelebModel,
        n=5,
        k_ml=1,
        k_val_ml=5,
        k_val=1,
        k_val_val=15,
        k_test=1,
        k_val_test=15,
        meta_batch_size=4,
        num_steps_ml=5,
        lr_inner_ml=0.05,
        num_steps_validation=5,
        save_after_iterations=15000,
        meta_learning_rate=0.001,
        report_validation_frequency=1000,
        log_train_images_after_iteration=-1,
        num_tasks_val=100,
        clip_gradients=True,
        experiment_name='voxceleb3',
        val_seed=42,
        val_test_batch_norm_momentum=0.0,
    )

    # maml.train(iterations=60040)
    maml.evaluate(50, num_tasks=1000, seed=42, use_val_batch_statistics=True)
def run_celeba():
    celeba_database = CelebADatabase()
    maml = ModelAgnosticMetaLearningModel(
        database=celeba_database,
        network_cls=MiniImagenetModel,
        n=5,
        k_ml=1,
        k_val_ml=5,
        k_val=1,
        k_val_val=15,
        k_test=1,
        k_val_test=15,
        meta_batch_size=4,
        num_steps_ml=1,
        lr_inner_ml=0.05,
        num_steps_validation=5,
        save_after_iterations=5000,
        meta_learning_rate=0.0001,
        report_validation_frequency=250,
        log_train_images_after_iteration=1000,
        num_tasks_val=100,
        clip_gradients=True,
        experiment_name='celeba',
    )

    maml.train(iterations=60000)
    maml.evaluate(50, num_tasks=1000, seed=42)
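
# A minimal entry-point sketch, assuming these run_* functions live in one
# standalone script: pick the experiment by calling the corresponding function.
# Note that most run_* functions leave maml.train(...) commented out, so they
# presumably expect an already-trained checkpoint when calling maml.evaluate(...).
if __name__ == '__main__':
    run_omniglot()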