def run_omniglot():
    omniglot_database = OmniglotDatabase(
        random_seed=47,
        num_train_classes=1200,
        num_val_classes=100,
    )
    proto_net = PrototypicalNetworks(
        database=omniglot_database,
        network_cls=SimpleModelProto,
        n=5,
        k_ml=1,
        k_val_ml=5,
        k_val=1,
        k_val_val=15,
        k_test=1,
        k_val_test=15,
        meta_batch_size=4,
        save_after_iterations=1000,
        meta_learning_rate=0.001,
        report_validation_frequency=200,
        log_train_images_after_iteration=200,  # Set to -1 if you do not want to log train images.
        num_tasks_val=100,
        val_seed=-1,
        experiment_name=None
    )
    # proto_net.train(iterations=5000)
    proto_net.evaluate(-1, num_tasks=1000)
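# A minimal, assumed entry point for this runner; it mirrors the
# `if __name__ == '__main__':` pattern used by the MAML-UMTRA script below.
if __name__ == '__main__':
    run_omniglot()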
def run_transfer_learning():
    omniglot_database = OmniglotDatabase(random_seed=47, num_train_classes=1200, num_val_classes=100)
    transfer_learning = TransferLearning(
        database=omniglot_database,
        network_cls=get_network,
        n=20,
        k_ml=1,
        k_val_ml=5,
        k_val_val=15,
        k_val=1,
        k_test=5,
        k_val_test=15,
        meta_batch_size=4,
        num_steps_ml=5,
        lr_inner_ml=0.05,
        num_steps_validation=5,
        save_after_iterations=15000,
        meta_learning_rate=0.001,
        report_validation_frequency=250,
        log_train_images_after_iteration=1000,
        num_tasks_val=100,
        clip_gradients=True,
        experiment_name='fixed_vgg_19',
        val_seed=42,
        val_test_batch_norm_momentum=0.0,
    )
    print(f'k: {transfer_learning.k_test}')
    transfer_learning.evaluate(50, num_tasks=1000, seed=42, use_val_batch_statistics=False)
def run_omniglot():
    omniglot_database = OmniglotDatabase(
        random_seed=47,
        num_train_classes=1200,
        num_val_classes=100,
    )
    maml = ModelAgnosticMetaLearningModel(
        database=omniglot_database,
        network_cls=SimpleModel,
        n=20,
        k_ml=1,
        k_val_ml=5,
        k_val=1,
        k_val_val=15,
        k_test=5,
        k_val_test=15,
        meta_batch_size=4,
        num_steps_ml=5,
        lr_inner_ml=0.4,
        num_steps_validation=5,
        save_after_iterations=1000,
        meta_learning_rate=0.001,
        report_validation_frequency=50,
        log_train_images_after_iteration=200,
        num_tasks_val=100,
        clip_gradients=False,
        experiment_name='omniglot',
        val_seed=42,
        val_test_batch_norm_momentum=0.0
    )
    # maml.train(iterations=5000)
    maml.evaluate(iterations=50, num_tasks=1000, use_val_batch_statistics=True, seed=42)
def run_domain_attention():
    train_domain_databases = [
        MiniImagenetDatabase(),
        OmniglotDatabase(random_seed=47, num_train_classes=1200, num_val_classes=100),
        DTDDatabase(),
        VGGFlowerDatabase()
    ]
    meta_train_domain_databases = [
        # AirplaneDatabase(),
        FungiDatabase(),
        # CUBDatabase(),
    ]
    test_database = FungiDatabase()

    da = DomainAttentionUnsupervised(
        train_databases=train_domain_databases,
        meta_train_databases=meta_train_domain_databases,
        database=test_database,
        test_database=test_database,
        network_cls=None,
        image_shape=(84, 84, 3),
        n=5,
        k_ml=1,
        k_val_ml=5,
        k_val=1,
        k_val_val=15,
        k_test=5,
        k_val_test=15,
        meta_batch_size=4,
        num_steps_ml=1,
        lr_inner_ml=0.05,
        num_steps_validation=5,
        save_after_iterations=5000,
        meta_learning_rate=0.001,
        report_validation_frequency=1000,
        log_train_images_after_iteration=1000,
        num_tasks_val=100,
        clip_gradients=True,
        experiment_name='domain_attention_all_frozen_layers_unsupervised_fungi2',
        val_seed=42,
        val_test_batch_norm_momentum=0.0,
    )
    print(da.experiment_name)
    # da.train(iterations=5000)
    da.evaluate(iterations=50, num_tasks=1000, seed=42)
def __init__(self, meta_train_databases=None, *args, **kwargs):
    super(CombinedCrossDomainMetaLearning, self).__init__(*args, **kwargs)
    if meta_train_databases is None:
        self.meta_train_databases = [
            MiniImagenetDatabase(),
            AirplaneDatabase(),
            CUBDatabase(),
            OmniglotDatabase(random_seed=47, num_train_classes=1200, num_val_classes=100),
            DTDDatabase(),
            FungiDatabase(),
            VGGFlowerDatabase()
        ]
    else:
        self.meta_train_databases = meta_train_databases
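# Hypothetical usage sketch for CombinedCrossDomainMetaLearning. The constructor forwards
# *args/**kwargs to its parent, so the keyword arguments below are assumptions copied from
# the sibling MAML-style runners in this section, not the verified signature.
ccdml = CombinedCrossDomainMetaLearning(
    meta_train_databases=None,  # None falls back to the seven databases listed above
    database=MiniImagenetDatabase(),
    network_cls=SimpleModel,
    n=5,
    k_ml=1,
    k_val_ml=5,
    k_val=1,
    k_val_val=15,
    k_test=5,
    k_val_test=15,
    meta_batch_size=4,
    num_tasks_val=100,
    experiment_name='combined_cross_domain',  # hypothetical name
)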
def run_domain_attention():
    train_domain_databases = [
        MiniImagenetDatabase(),
        OmniglotDatabase(random_seed=47, num_train_classes=1200, num_val_classes=100),
        DTDDatabase(),
        VGGFlowerDatabase()
    ]
    meta_train_domain_databases = [
        AirplaneDatabase(),
        FungiDatabase(),
        CUBDatabase(),
    ]
    test_database = EuroSatDatabase()

    ewda = ElementWiseDomainAttention(
        train_databases=train_domain_databases,
        meta_train_databases=meta_train_domain_databases,
        database=test_database,
        test_database=test_database,
        network_cls=None,
        image_shape=(84, 84, 3),
        n=5,
        k_ml=1,
        k_val_ml=5,
        k_val=1,
        k_val_val=15,
        k_test=5,
        k_val_test=15,
        meta_batch_size=4,
        num_steps_ml=5,
        lr_inner_ml=0.05,
        num_steps_validation=5,
        save_after_iterations=15000,
        meta_learning_rate=0.001,
        report_validation_frequency=1000,
        log_train_images_after_iteration=1000,
        num_tasks_val=100,
        clip_gradients=True,
        experiment_name='element_wise_domain_attention',
        val_seed=42,
        val_test_batch_norm_momentum=0.0,
    )
    ewda.train(iterations=60000)
    ewda.evaluate(iterations=50, num_tasks=1000, seed=14)
def get_train_dataset(self):
    databases = [
        MiniImagenetDatabase(),
        AirplaneDatabase(),
        CUBDatabase(),
        OmniglotDatabase(random_seed=47, num_train_classes=1200, num_val_classes=100),
        DTDDatabase(),
        FungiDatabase(),
        VGGFlowerDatabase()
    ]
    dataset = self.get_cross_domain_meta_learning_dataset(
        databases=databases,
        n=self.n,
        k_ml=self.k_ml,
        k_validation=self.k_val_ml,
        meta_batch_size=self.meta_batch_size
    )
    return dataset
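# Consumption sketch for the returned dataset. The (support, query) nesting and shapes are
# inferred from the dataset tests at the end of this section; `model` stands for any
# instance that exposes get_train_dataset().
dataset = model.get_train_dataset()
for (x, x_val), (y, y_val) in dataset:
    # x:     (meta_batch_size, n, k_ml, ...)      support-set images per task
    # x_val: (meta_batch_size, n, k_val_ml, ...)  query-set images per task
    # y, y_val: per-task labels, flattened to (meta_batch_size, n * k, ...)
    break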
def run_omniglot():
    omniglot_database = OmniglotDatabase(
        random_seed=30,
        num_train_classes=1200,
        num_val_classes=100,
    )

    n_data_points = 10000
    data_points, classes, non_labeled_data_points = sample_data_points(
        omniglot_database.train_folders, n_data_points
    )
    features_dataset, n_classes = make_features_dataset_omniglot(
        data_points, classes, non_labeled_data_points
    )
    print(n_classes)

    feature_model = SimpleModelFeature(num_classes=5).get_sequential_model()
    feature_model = train_the_feature_model(
        feature_model, features_dataset, n_classes, omniglot_database.input_shape
    )

    sml = SML(
        database=omniglot_database,
        network_cls=SimpleModel,
        n=5,
        k=1,
        k_val_ml=5,
        k_val_val=15,
        k_val_test=15,
        meta_batch_size=32,
        num_steps_ml=5,
        lr_inner_ml=0.01,
        num_steps_validation=5,
        save_after_epochs=5,
        meta_learning_rate=0.001,
        n_clusters=1200,
        feature_model=feature_model,
        feature_size=256,
        input_shape=(28, 28, 1),
        log_train_images_after_iteration=10,
        report_validation_frequency=10,
        experiment_name='omniglot_vae_model_feature_10000'
    )
    sml.train(epochs=101)
def run_omniglot():
    omniglot_database = OmniglotDatabase(
        random_seed=47,
        num_train_classes=1200,
        num_val_classes=100,
    )
    anil = ANIL(
        database=omniglot_database,
        network_cls=SimpleModel,
        n=5,
        k_ml=1,
        k_val_ml=5,
        k_val=1,
        k_val_val=15,
        k_test=1,
        k_val_test=15,
        meta_batch_size=4,
        num_steps_ml=5,
        lr_inner_ml=0.4,
        num_steps_validation=5,
        save_after_iterations=1000,
        meta_learning_rate=0.001,
        report_validation_frequency=50,
        log_train_images_after_iteration=200,
        num_tasks_val=100,
        clip_gradients=False,
        experiment_name='omniglot3',
        val_seed=42,
        val_test_batch_norm_momentum=0.0,
        set_of_frozen_layers={'conv1', 'conv2', 'conv3', 'conv4', 'bn1', 'bn2', 'bn3', 'bn4'}
    )
    anil.train(iterations=5000)
    anil.evaluate(iterations=50, num_tasks=1000, use_val_batch_statistics=True, seed=42)
def run_omniglot():
    omniglot_database = OmniglotDatabase(random_seed=47, num_train_classes=1200, num_val_classes=100)
    base_model = tf.keras.applications.VGG19(weights='imagenet')
    feature_model = tf.keras.models.Model(inputs=base_model.input, outputs=base_model.layers[24].output)

    sml = SML(
        database=omniglot_database,
        network_cls=SimpleModel,
        n=5,
        k=1,
        k_val_ml=5,
        k_val_val=15,
        k_val_test=15,
        k_test=1,
        meta_batch_size=4,
        num_steps_ml=5,
        lr_inner_ml=0.05,
        num_steps_validation=5,
        save_after_iterations=15000,
        meta_learning_rate=0.001,
        n_clusters=5000,
        feature_model=feature_model,
        # feature_size=288,
        feature_size=4096,
        input_shape=(224, 224, 3),
        preprocess_function=tf.keras.applications.vgg19.preprocess_input,
        log_train_images_after_iteration=1000,
        number_of_tasks_val=100,
        number_of_tasks_test=1000,
        clip_gradients=True,
        report_validation_frequency=250,
        experiment_name='omniglot_imagenet_features'
    )
    sml.train(iterations=60000)
    sml.evaluate(iterations=50, seed=42)
def run_omniglot():
    omniglot_database = OmniglotDatabase(
        random_seed=47,
        num_train_classes=1200,
        num_val_classes=100,
    )
    maml = MAMLAbstractLearner(
        database=omniglot_database,
        network_cls=SimpleModel,
        n=5,
        k=1,
        meta_batch_size=32,
        num_steps_ml=1,
        lr_inner_ml=0.4,
        num_steps_validation=10,
        save_after_epochs=500,
        meta_learning_rate=0.001,
        report_validation_frequency=10,
        log_train_images_after_iteration=-1,
    )
    maml.train(epochs=4000)
    maml.evaluate(iterations=50)
def run_omniglot():
    omniglot_database = OmniglotDatabase(
        random_seed=47,
        num_train_classes=1200,
        num_val_classes=100,
    )
    gan_sampling = GANSampling(
        database=omniglot_database,
        network_cls=SimpleModel,
        n=5,
        k=1,
        k_val_ml=5,
        k_val_train=1,
        k_val_val=15,
        k_val_test=15,
        k_test=1,
        meta_batch_size=32,
        num_steps_ml=5,  # 1 for prev result
        lr_inner_ml=0.4,
        num_steps_validation=5,
        save_after_iterations=1000,
        meta_learning_rate=0.001,
        report_validation_frequency=50,
        log_train_images_after_iteration=200,
        number_of_tasks_val=100,
        number_of_tasks_test=1000,
        clip_gradients=False,
        experiment_name='omniglot',
        val_seed=42,
        val_test_batch_norm_momentum=0.0
    )
    gan_sampling.train(iterations=5000)
    gan_sampling.evaluate(iterations=50, use_val_batch_statistics=True, seed=42)
from databases import OmniglotDatabase
from models.maml_umtra.maml_umtra import MAMLUMTRA
from networks.maml_umtra_networks import SimpleModel


if __name__ == '__main__':
    # import tensorflow as tf
    # tf.config.experimental_run_functions_eagerly(True)
    omniglot_database = OmniglotDatabase(random_seed=47, num_train_classes=1200, num_val_classes=100)
    maml_umtra = MAMLUMTRA(
        database=omniglot_database,
        network_cls=SimpleModel,
        n=5,
        k=1,
        k_val_ml=5,
        k_val_val=15,
        k_val_test=15,
        k_test=5,
        meta_batch_size=4,
        num_steps_ml=5,
        lr_inner_ml=0.4,
        num_steps_validation=5,
        save_after_iterations=1000,
        meta_learning_rate=0.001,
        report_validation_frequency=200,
        log_train_images_after_iteration=200,
        number_of_tasks_val=100,
        number_of_tasks_test=1000,
        clip_gradients=False,
        # ... the remaining keyword arguments are truncated here in the source snippet
    )
    # fig.suptitle('', fontsize=12, y=1)
    plt.savefig(fname=os.path.join(root_folder_to_save, 'all_domains2.pdf'))
    plt.show()


if __name__ == '__main__':
    tf.config.set_visible_devices([], 'GPU')
    root_folder_to_save = os.path.expanduser('~/datasets_visualization/')
    if not os.path.exists(root_folder_to_save):
        os.mkdir(root_folder_to_save)

    databases = (
        MiniImagenetDatabase(),
        OmniglotDatabase(random_seed=42, num_train_classes=1200, num_val_classes=100),
        AirplaneDatabase(),
        CUBDatabase(),
        DTDDatabase(),
        FungiDatabase(),
        VGGFlowerDatabase(),
        # TrafficSignDatabase(),
        # MSCOCODatabase(),
        # PlantDiseaseDatabase(),
        # EuroSatDatabase(),
        # ISICDatabase(),
        # ChestXRay8Database(),
    )

    # for database in databases:
    prefix = sys.argv[2]
else:
    print("Error: Too many arguments")
    sys.exit(1)  # exit with a non-zero status on error

# CONFIGS
ITERATIONS = 5000
GAN_EPOCHS = 500
GAN_CHECKPOINTS = 50
N_TASK_EVAL = 1000
K = 1
TRAIN_GAN = True
LASIUM_TYPE = "p1"
GAN_N_ALT = 50  # How many times to alternate between unlabeled and labeled

omniglot_database = OmniglotDatabase(random_seed=47, num_train_classes=1200, num_val_classes=100)
shape = (28, 28, 1)
latent_dim = 128
omniglot_generator = get_generator(latent_dim)
omniglot_parser = OmniglotParser(shape=shape)

experiment_name = prefix + str(labeled_percentage)
if GAN_N_ALT > 1:
    experiment_name += "_alt" + str(GAN_N_ALT)

# Split the training classes into labeled and unlabeled subsets
train_folders = omniglot_database.train_folders
keys = list(train_folders.keys())
labeled_keys = np.random.choice(keys, int(len(train_folders.keys()) * labeled_percentage), replace=False)
train_folders_labeled = {k: v for (k, v) in train_folders.items() if k in labeled_keys}
train_folders_unlabeled = {k: v for (k, v) in train_folders.items() if k not in labeled_keys}
class TestDatabases(unittest.TestCase):
    def setUp(self):
        def parse_function(example_address):
            return example_address

        self.parse_function = parse_function
        self.omniglot_database = OmniglotDatabase(random_seed=-1, num_train_classes=1200, num_val_classes=100)
        self.mini_imagenet_database = MiniImagenetDatabase()
        self.celeba_database = CelebADatabase()
        self.lfw_database = LFWDatabase()
        self.euro_sat_database = EuroSatDatabase()
        self.isic_database = ISICDatabase()
        self.chest_x_ray_database = ChestXRay8Database()

    @property
    def databases(self):
        return (
            self.omniglot_database,
            self.mini_imagenet_database,
            self.celeba_database,
            self.lfw_database,
            self.euro_sat_database,
            self.isic_database,
            self.chest_x_ray_database,
        )

    def test_train_val_test_folders_are_separate(self):
        for database in (self.omniglot_database, self.mini_imagenet_database, self.celeba_database):
            train_folders = set(database.train_folders)
            val_folders = set(database.val_folders)
            test_folders = set(database.test_folders)

            for folder in train_folders:
                self.assertNotIn(folder, val_folders)
                self.assertNotIn(folder, test_folders)

            for folder in val_folders:
                self.assertNotIn(folder, train_folders)
                self.assertNotIn(folder, test_folders)

            for folder in test_folders:
                self.assertNotIn(folder, train_folders)
                self.assertNotIn(folder, val_folders)

    def test_train_val_test_folders_are_comprehensive(self):
        self.assertEqual(
            1623,
            len(
                list(self.omniglot_database.train_folders.keys()) +
                list(self.omniglot_database.val_folders.keys()) +
                list(self.omniglot_database.test_folders.keys())
            )
        )
        self.assertEqual(
            100,
            len(
                list(self.mini_imagenet_database.train_folders) +
                list(self.mini_imagenet_database.val_folders) +
                list(self.mini_imagenet_database.test_folders)
            )
        )
        self.assertEqual(
            10177,
            len(
                list(self.celeba_database.train_folders) +
                list(self.celeba_database.val_folders) +
                list(self.celeba_database.test_folders)
            )
        )

    @patch('tf_datasets.OmniglotDatabase._get_parse_function')
    def test_covering_all_classes_in_one_epoch(self, mocked_parse_function):
        # Make a new database so that the number of classes is divisible by meta_batch_size * n.
        mocked_parse_function.return_value = self.parse_function
        ds = self.omniglot_database.get_supervised_meta_learning_dataset(
            self.omniglot_database.train_folders,
            n=5,
            k=5,
            k_validation=5,
            meta_batch_size=4,
            dtype=tf.string
        )

        # Check that all classes are covered and there are no duplicates.
        classes = set()
        for (x, x_val), (y, y_val) in ds:
            train_ds = combine_first_two_axes(x)
            for i in range(train_ds.shape[0]):
                class_instances = train_ds[i, ...]
                class_instance_address = class_instances[0].numpy().decode('utf-8')
                class_address = os.path.split(class_instance_address)[0]
                self.assertNotIn(class_address, classes)
                classes.add(class_address)

        self.assertSetEqual(classes, set(self.omniglot_database.train_folders.keys()))

    @patch('tf_datasets.OmniglotDatabase._get_parse_function')
    def test_covering_all_classes_in_subsequent_epochs(self, mocked_parse_function):
        # This test might not pass because of the random selection of classes at the start, but the chances
        # are low, especially as the number of epochs increases: the chance of missing a class then decreases.
        # Make a new database so that the number of classes is divisible by meta_batch_size * n.
        # This test should always fail with num_epochs = 1.
        num_epochs = 3
        mocked_parse_function.return_value = self.parse_function
        ds = self.omniglot_database.get_supervised_meta_learning_dataset(
            self.omniglot_database.train_folders,
            n=7,
            k=5,
            k_validation=5,
            meta_batch_size=4,
            seed=42,
            dtype=tf.string
        )

        # Check for covering all classes.
        classes = set()
        for epoch in range(num_epochs):
            for (x, x_val), (y, y_val) in ds:
                train_ds = combine_first_two_axes(x)
                for i in range(train_ds.shape[0]):
                    class_instances = train_ds[i, ...]
                    class_instance_address = class_instances[0].numpy().decode('utf-8')
                    class_address = os.path.split(class_instance_address)[0]
                    classes.add(class_address)

        self.assertSetEqual(classes, set(self.omniglot_database.train_folders.keys()))

    @patch('tf_datasets.OmniglotDatabase._get_parse_function')
    def test_labels_are_correct_in_train_and_val_for_every_task(self, mocked_parse_function):
        mocked_parse_function.return_value = self.parse_function
        n = 7
        k = 5
        k_validation = 6
        ds = self.omniglot_database.get_supervised_meta_learning_dataset(
            self.omniglot_database.train_folders,
            n=n,
            k=k,
            k_validation=k_validation,
            meta_batch_size=4,
            one_hot_labels=False,
            dtype=tf.string
        )

        for epoch in range(2):
            for (x, x_val), (y, y_val) in ds:
                for meta_batch_index in range(x.shape[0]):
                    train_ds = x[meta_batch_index, ...]
                    val_ds = x_val[meta_batch_index, ...]
                    train_labels = y[meta_batch_index, ...]
                    val_labels = y_val[meta_batch_index, ...]

                    class_label_dict = dict()
                    for class_index in range(n):
                        for instance_index in range(k):
                            instance_name = train_ds[class_index, instance_index, ...]
                            label = train_labels[class_index * k + instance_index, ...].numpy()
                            class_name = os.path.split(instance_name.numpy().decode('utf-8'))[0]
                            if class_name in class_label_dict:
                                self.assertEqual(class_label_dict[class_name], label)
                            else:
                                class_label_dict[class_name] = label

                    for class_index in range(n):
                        for instance_index in range(k_validation):
                            instance_name = val_ds[class_index, instance_index, ...]
                            label = val_labels[class_index * k_validation + instance_index, ...].numpy()
                            class_name = os.path.split(instance_name.numpy().decode('utf-8'))[0]
                            self.assertIn(class_name, class_label_dict)
                            self.assertEqual(class_label_dict[class_name], label)

    @patch('tf_datasets.OmniglotDatabase._get_parse_function')
    def test_train_and_val_have_different_samples_in_every_task(self, mocked_parse_function):
        mocked_parse_function.return_value = self.parse_function
        n = 6
        k = 4
        k_validation = 5
        ds = self.omniglot_database.get_supervised_meta_learning_dataset(
            self.omniglot_database.train_folders,
            n=n,
            k=k,
            k_validation=k_validation,
            meta_batch_size=4,
            one_hot_labels=False,
            dtype=tf.string
        )

        for epoch in range(4):
            for (x, x_val), (y, y_val) in ds:
                for meta_batch_index in range(x.shape[0]):
                    train_ds = x[meta_batch_index, ...]
                    val_ds = x_val[meta_batch_index, ...]

                    class_instances_dict = dict()
                    for class_index in range(n):
                        for instance_index in range(k):
                            instance_name = train_ds[class_index, instance_index, ...]
                            class_name, instance_name = os.path.split(instance_name.numpy().decode('utf-8'))
                            if class_name not in class_instances_dict:
                                class_instances_dict[class_name] = set()
                            class_instances_dict[class_name].add(instance_name)

                    for class_index in range(n):
                        for instance_index in range(k_validation):
                            instance_name = val_ds[class_index, instance_index, ...]
                            class_name, instance_name = os.path.split(instance_name.numpy().decode('utf-8'))
                            self.assertIn(class_name, class_instances_dict)
                            self.assertNotIn(instance_name, class_instances_dict[class_name])

    @patch('tf_datasets.OmniglotDatabase._get_parse_function')
    def test_no_two_class_in_the_same_task(self, mocked_parse_function):
        mocked_parse_function.return_value = self.parse_function
        n = 6
        k = 4
        k_validation = 5
        ds = self.omniglot_database.get_supervised_meta_learning_dataset(
            self.omniglot_database.train_folders,
            n=n,
            k=k,
            k_validation=k_validation,
            meta_batch_size=4,
            one_hot_labels=False,
            dtype=tf.string
        )

        for epoch in range(4):
            for (x, x_val), (y, y_val) in ds:
                for base_data, k_value in zip((x, x_val), (k, k_validation)):
                    for meta_batch_index in range(x.shape[0]):
                        train_ds = base_data[meta_batch_index, ...]
                        classes = dict()
                        for class_index in range(n):
                            for instance_index in range(k_value):
                                instance_name = train_ds[class_index, instance_index, ...]
                                class_name = os.path.split(instance_name.numpy().decode('utf-8'))[0]
                                classes[class_name] = classes.get(class_name, 0) + 1

                        for class_name, num_class_instances in classes.items():
                            self.assertEqual(k_value, num_class_instances)

    @patch('tf_datasets.OmniglotDatabase._get_parse_function')
    def test_different_instances_are_selected_from_each_class_for_train_and_val_each_time(self, mocked_parse_function):
        # The random seed is chosen such that instances are not selected the same way for the whole epoch.
        # This test might fail due to a change in the random behaviour of sample selection, which would not
        # necessarily mean the code is broken: the same train and val data might happen to be selected for
        # one task at two different times.
        mocked_parse_function.return_value = self.parse_function
        n = 6
        k = 4
        ds = self.omniglot_database.get_supervised_meta_learning_dataset(
            self.omniglot_database.train_folders,
            n=n,
            k=k,
            k_validation=8,
            meta_batch_size=4,
            one_hot_labels=False,
            dtype=tf.string
        )

        class_instances = dict()
        class_instances[0] = dict()
        class_instances[1] = dict()

        for epoch in range(2):
            for (x, x_val), (y, y_val) in ds:
                for meta_batch_index in range(x.shape[0]):
                    train_ds = x[meta_batch_index, ...]
                    for class_index in range(n):
                        for instance_index in range(k):
                            instance_address = train_ds[class_index, instance_index, ...]
                            class_name, instance_name = os.path.split(instance_address.numpy().decode('utf-8'))
                            if class_name not in class_instances[epoch]:
                                class_instances[epoch][class_name] = set()
                            class_instances[epoch][class_name].add(instance_name)

        first_epoch_class_instances = class_instances[0]
        second_epoch_class_instances = class_instances[1]
        for class_name in first_epoch_class_instances.keys():
            self.assertIn(class_name, second_epoch_class_instances)
            # Compare set-difference sizes; the original compared the set itself against 0, which always passed.
            self.assertNotEqual(
                0,
                len(first_epoch_class_instances[class_name].difference(second_epoch_class_instances[class_name]))
            )

    @patch('tf_datasets.MiniImagenetDatabase._get_parse_function')
    def test_1000_tasks_are_completely_different(self, mocked_parse_function):
        # This test should run on the MAML get_validation_dataset. For now, the code is copied here to make
        # sure everything works; if that code changes, this test loses its value.
        # TODO: Make this test work with maml.get_validation_dataset and get_test_dataset functions.
        mocked_parse_function.return_value = self.parse_function

        def get_val_dataset():
            val_dataset = self.mini_imagenet_database.get_supervised_meta_learning_dataset(
                self.mini_imagenet_database.val_folders,
                n=6,
                k=4,
                k_validation=3,
                meta_batch_size=1,
                seed=42,
                dtype=tf.string
            )
            val_dataset = val_dataset.repeat(-1)
            val_dataset = val_dataset.take(1000)
            return val_dataset

        counter = 0
        xs = set()
        x_vals = set()
        xs_queue = deque()
        x_vals_queue = deque()

        for (x, x_val), (y, y_val) in get_val_dataset():
            x_string = ','.join(list(map(str, x.numpy().reshape(-1))))
            # print(x_string)
            self.assertNotIn(x_string, xs)
            xs.add(x_string)

            x_val_string = ','.join(list(map(str, x_val.numpy().reshape(-1))))
            self.assertNotIn(x_val_string, x_vals)
            x_vals.add(x_val_string)

            xs_queue.append(x_string)
            x_vals_queue.append(x_val_string)
            counter += 1

        self.assertEqual(counter, 1000)

        # The same seed should reproduce exactly the same 1000 tasks, in the same order.
        for (x, x_val), (y, y_val) in get_val_dataset():
            x_string = ','.join(list(map(str, x.numpy().reshape(-1))))
            self.assertEqual(xs_queue.popleft(), x_string)

            x_val_string = ','.join(list(map(str, x_val.numpy().reshape(-1))))
            self.assertEqual(x_vals_queue.popleft(), x_val_string)
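# Standard unittest entry point so the test file can be run directly; this guard is
# assumed, since the original snippet ends before any main block.
if __name__ == '__main__':
    unittest.main()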