def run_mini_imagenet():
    """Build a MAML learner on mini-ImageNet and evaluate it.

    Meta-training is disabled (the ``train`` call is left commented out), so
    this only calls ``evaluate`` for experiment ``'mini_imagenet_test_res'``:
    first with ``use_val_batch_statistics=True``, then with ``False``, each
    over 1000 tasks with seed 42.
    """
    train_database = MiniImagenetDatabase()
    evaluation_database = MiniImagenetDatabase()

    maml = ModelAgnosticMetaLearningModel(
        database=train_database,
        test_database=evaluation_database,
        network_cls=MiniImagenetModel,
        n=5,
        k_ml=1,
        k_val_ml=5,
        k_val=1,
        k_val_val=15,
        k_test=50,
        k_val_test=15,
        meta_batch_size=4,
        num_steps_ml=5,
        lr_inner_ml=0.05,
        num_steps_validation=5,
        save_after_iterations=15000,
        meta_learning_rate=0.001,
        report_validation_frequency=1000,
        log_train_images_after_iteration=1000,
        num_tasks_val=100,
        clip_gradients=True,
        experiment_name='mini_imagenet_test_res',
        val_seed=42,
        val_test_batch_norm_momentum=0.0,
    )

    # maml.train(iterations=60000)
    for use_val_batch_statistics in (True, False):
        maml.evaluate(
            50,
            num_tasks=1000,
            seed=42,
            use_val_batch_statistics=use_val_batch_statistics,
        )
def setUp(self):
    """Create one instance of every database wrapper used by the tests.

    ``self.parse_function`` is an identity function: it returns its argument
    unchanged and serves as a stand-in parse function.
    """
    def identity_parse(example_address):
        return example_address

    self.parse_function = identity_parse

    # Construction order matches the original so any side effects of the
    # constructors happen in the same sequence.
    self.omniglot_database = OmniglotDatabase(
        random_seed=-1,
        num_train_classes=1200,
        num_val_classes=100,
    )
    self.mini_imagenet_database = MiniImagenetDatabase()
    self.celeba_database = CelebADatabase()
    self.lfw_database = LFWDatabase()
    self.euro_sat_database = EuroSatDatabase()
    self.isic_database = ISICDatabase()
    self.chest_x_ray_database = ChestXRay8Database()
def run_transfer_meta_learning():
    """Meta-train a VGG16 transfer learner on mini-ImageNet, evaluate on EuroSAT.

    EuroSAT serves as both the validation and the target database. After 6000
    training iterations the model is evaluated twice: with and without
    validation batch statistics.
    """
    source_database = MiniImagenetDatabase()
    target_database = EuroSatDatabase()

    tml = TransferMetaLearningVGG16(
        database=source_database,
        val_database=target_database,
        target_database=target_database,
        network_cls=None,
        n=5,
        k=1,
        k_val_ml=5,
        k_val_val=15,
        k_val_test=15,
        k_test=1,
        meta_batch_size=4,
        num_steps_ml=1,
        lr_inner_ml=0.001,
        num_steps_validation=5,
        save_after_iterations=1500,
        meta_learning_rate=0.0001,
        report_validation_frequency=250,
        log_train_images_after_iteration=1000,
        number_of_tasks_val=100,
        number_of_tasks_test=100,
        clip_gradients=True,
        experiment_name='transfer_meta_learning_mini_imagenet_euro_sat',
        val_seed=42,
        val_test_batch_norm_momentum=0.0,
        random_layer_initialization_seed=42,
        num_trainable_layers=3,
    )
    tml.train(iterations=6000)

    for use_val_batch_statistics in (True, False):
        tml.evaluate(50, seed=42, use_val_batch_statistics=use_val_batch_statistics)
def run_transfer_learning():
    """Evaluate the ``'fixed_vgg_19'`` transfer-learning experiment on mini-ImageNet.

    No training is performed here. The learner's ``k_test`` is printed, then
    ``evaluate`` runs over 1000 tasks with validation batch statistics.
    """
    database = MiniImagenetDatabase()

    learner = TransferLearning(
        database=database,
        network_cls=get_network,
        n=5,
        k_ml=1,
        k_val_ml=5,
        k_val_val=15,
        k_val=1,
        k_test=50,
        k_val_test=15,
        meta_batch_size=4,
        num_steps_ml=5,
        lr_inner_ml=0.05,
        num_steps_validation=5,
        save_after_iterations=15000,
        meta_learning_rate=0.001,
        report_validation_frequency=250,
        log_train_images_after_iteration=1000,
        num_tasks_val=100,
        clip_gradients=True,
        experiment_name='fixed_vgg_19',
        val_seed=42,
        val_test_batch_norm_momentum=0.0,
    )

    print(f'k: {learner.k_test}')
    learner.evaluate(50, num_tasks=1000, seed=42, use_val_batch_statistics=True)
def run_mini_imagenet():
    """Train and evaluate unsupervised ANIL on mini-ImageNet.

    All four conv and four batch-norm layers are listed as frozen. ``'dense'``
    is not in the frozen set. Trains for 60000 iterations, then evaluates on
    1000 tasks using validation batch statistics.
    """
    database = MiniImagenetDatabase()

    # {'conv1'..'conv4', 'bn1'..'bn4'}; 'dense' is intentionally not included.
    frozen_layers = {
        f'{layer_kind}{layer_index}'
        for layer_kind in ('conv', 'bn')
        for layer_index in range(1, 5)
    }

    anil = ANILUnsupervised(
        database=database,
        network_cls=MiniImagenetModel,
        n=5,
        k_ml=1,
        k_val_ml=1,
        k_val=1,
        k_val_val=15,
        k_test=1,
        k_val_test=15,
        meta_batch_size=4,
        num_steps_ml=5,
        lr_inner_ml=0.05,
        num_steps_validation=5,
        save_after_iterations=15000,
        meta_learning_rate=0.001,
        report_validation_frequency=1000,
        log_train_images_after_iteration=1000,
        num_tasks_val=100,
        clip_gradients=True,
        experiment_name='mini_imagenet_unsupervised_permutation',
        val_seed=42,
        val_test_batch_norm_momentum=0.0,
        set_of_frozen_layers=frozen_layers,
    )

    anil.train(iterations=60000)
    anil.evaluate(50, num_tasks=1000, seed=42, use_val_batch_statistics=True)
def run_mini_imagenet():
    """Train and evaluate the GAN-sampling meta-learner on mini-ImageNet.

    Trains for 60000 iterations, then evaluates the checkpoint saved at
    iteration 16000 (``iterations_to_load_from=16000``) using validation
    batch statistics.
    """
    database = MiniImagenetDatabase()

    gan_sampling = GANSampling(
        database=database,
        network_cls=MiniImagenetModel,
        n=5,
        k=1,
        k_val_ml=5,
        k_val_train=1,
        k_val_val=15,
        k_val_test=15,
        k_test=1,
        meta_batch_size=4,
        num_steps_ml=5,
        lr_inner_ml=0.05,
        num_steps_validation=5,
        save_after_iterations=1000,
        meta_learning_rate=0.001,
        report_validation_frequency=50,
        log_train_images_after_iteration=250,
        number_of_tasks_val=100,
        number_of_tasks_test=1000,
        clip_gradients=True,
        experiment_name='mini_imagenet_interpolation_std_1.2_shift_5',
        val_seed=42,
        val_test_batch_norm_momentum=0.0,
    )

    gan_sampling.train(iterations=60000)
    gan_sampling.evaluate(
        iterations=50,
        use_val_batch_statistics=True,
        seed=42,
        iterations_to_load_from=16000,
    )
def get_train_dataset(self):
    """Return the supervised meta-learning dataset over mini-ImageNet's train folders.

    A new ``MiniImagenetDatabase`` is created on every call. Task sizes are
    taken from ``self.n``, ``self.k``, ``self.k_val_ml`` and
    ``self.meta_batch_size``.
    """
    mini_imagenet = MiniImagenetDatabase()
    return self.get_supervised_meta_learning_dataset(
        mini_imagenet.train_folders,
        n=self.n,
        k=self.k,
        k_validation=self.k_val_ml,
        meta_batch_size=self.meta_batch_size,
    )
def get_train_dataset(self):
    """Return a meta-learning dataset whose two halves come from different databases.

    The first folder set is mini-ImageNet's train split and the second is
    PlantDisease's train split. Both are passed to
    ``get_separated_supervised_meta_learning_dataset`` with ``k=self.k_ml``.
    """
    source_database = MiniImagenetDatabase()
    plant_database = PlantDiseaseDatabase()
    return self.get_separated_supervised_meta_learning_dataset(
        source_database.train_folders,
        plant_database.train_folders,
        n=self.n,
        k=self.k_ml,
        k_validation=self.k_val_ml,
        meta_batch_size=self.meta_batch_size,
    )
def run_domain_attention():
    """Evaluate unsupervised domain attention with Fungi as the test domain.

    Mini-ImageNet, Omniglot, DTD and VGG-Flower are passed as
    ``train_databases``, and Fungi as the only meta-train database. Training is
    disabled; the experiment name is printed and ``evaluate`` runs over 1000
    tasks.
    """
    train_domain_databases = [
        MiniImagenetDatabase(),
        OmniglotDatabase(random_seed=47, num_train_classes=1200, num_val_classes=100),
        DTDDatabase(),
        VGGFlowerDatabase(),
    ]
    # Previously also tried: AirplaneDatabase(), CUBDatabase().
    meta_train_domain_databases = [FungiDatabase()]
    # NOTE(review): Fungi is both a meta-train domain and the test database;
    # confirm the splits used for each are disjoint.
    test_database = FungiDatabase()

    da = DomainAttentionUnsupervised(
        train_databases=train_domain_databases,
        meta_train_databases=meta_train_domain_databases,
        database=test_database,
        test_database=test_database,
        network_cls=None,
        image_shape=(84, 84, 3),
        n=5,
        k_ml=1,
        k_val_ml=5,
        k_val=1,
        k_val_val=15,
        k_test=5,
        k_val_test=15,
        meta_batch_size=4,
        num_steps_ml=1,
        lr_inner_ml=0.05,
        num_steps_validation=5,
        save_after_iterations=5000,
        meta_learning_rate=0.001,
        report_validation_frequency=1000,
        log_train_images_after_iteration=1000,
        num_tasks_val=100,
        clip_gradients=True,
        experiment_name='domain_attention_all_frozen_layers_unsupervised_fungi2',
        val_seed=42,
        val_test_batch_norm_momentum=0.0,
    )

    print(da.experiment_name)
    # da.train(iterations=5000)
    da.evaluate(iterations=50, num_tasks=1000, seed=42)
def __init__(self, meta_train_databases=None, *args, **kwargs):
    """Initialise the learner and select the databases used for meta-training.

    :param meta_train_databases: list of databases to meta-train on. When
        ``None``, a default list of seven databases is built here.
    Every other positional and keyword argument is forwarded unchanged to the
    parent initialiser.
    """
    super(CombinedCrossDomainMetaLearning, self).__init__(*args, **kwargs)

    if meta_train_databases is not None:
        self.meta_train_databases = meta_train_databases
        return

    # Default meta-training domains (only built when none were supplied).
    self.meta_train_databases = [
        MiniImagenetDatabase(),
        AirplaneDatabase(),
        CUBDatabase(),
        OmniglotDatabase(random_seed=47, num_train_classes=1200, num_val_classes=100),
        DTDDatabase(),
        FungiDatabase(),
        VGGFlowerDatabase(),
    ]
def run_domain_attention():
    """Train and evaluate element-wise domain attention with EuroSAT as test domain.

    Mini-ImageNet, Omniglot, DTD and VGG-Flower are the ``train_databases``,
    and Airplane, Fungi and CUB the meta-train databases. Trains for 60000
    iterations, then evaluates on 1000 tasks with seed 14.
    """
    train_domain_databases = [
        MiniImagenetDatabase(),
        OmniglotDatabase(random_seed=47, num_train_classes=1200, num_val_classes=100),
        DTDDatabase(),
        VGGFlowerDatabase(),
    ]
    meta_train_domain_databases = [
        AirplaneDatabase(),
        FungiDatabase(),
        CUBDatabase(),
    ]
    test_database = EuroSatDatabase()

    ewda = ElementWiseDomainAttention(
        train_databases=train_domain_databases,
        meta_train_databases=meta_train_domain_databases,
        database=test_database,
        test_database=test_database,
        network_cls=None,
        image_shape=(84, 84, 3),
        n=5,
        k_ml=1,
        k_val_ml=5,
        k_val=1,
        k_val_val=15,
        k_test=5,
        k_val_test=15,
        meta_batch_size=4,
        num_steps_ml=5,
        lr_inner_ml=0.05,
        num_steps_validation=5,
        save_after_iterations=15000,
        meta_learning_rate=0.001,
        report_validation_frequency=1000,
        log_train_images_after_iteration=1000,
        num_tasks_val=100,
        clip_gradients=True,
        experiment_name='element_wise_domain_attention',
        val_seed=42,
        val_test_batch_norm_momentum=0.0,
    )

    ewda.train(iterations=60000)
    ewda.evaluate(iterations=50, num_tasks=1000, seed=14)
def get_train_dataset(self):
    """Return a cross-domain meta-learning dataset over seven source databases.

    The databases (mini-ImageNet, Airplane, CUB, Omniglot, DTD, Fungi,
    VGG-Flower) are created fresh on every call. Task sizes come from
    ``self.n``, ``self.k_ml``, ``self.k_val_ml`` and ``self.meta_batch_size``.
    """
    source_databases = [
        MiniImagenetDatabase(),
        AirplaneDatabase(),
        CUBDatabase(),
        OmniglotDatabase(random_seed=47, num_train_classes=1200, num_val_classes=100),
        DTDDatabase(),
        FungiDatabase(),
        VGGFlowerDatabase(),
    ]
    return self.get_cross_domain_meta_learning_dataset(
        databases=source_databases,
        n=self.n,
        k_ml=self.k_ml,
        k_validation=self.k_val_ml,
        meta_batch_size=self.meta_batch_size,
    )
def run_mini_imagenet():
    """Train the abstract MAML learner on mini-ImageNet for 30000 epochs, then evaluate.

    The database is built with ``random_seed=-1``.
    """
    database = MiniImagenetDatabase(random_seed=-1)

    maml = MAMLAbstractLearner(
        database=database,
        network_cls=MiniImagenetModel,
        n=5,
        k=4,
        meta_batch_size=1,
        num_steps_ml=5,
        lr_inner_ml=0.01,
        num_steps_validation=5,
        save_after_epochs=500,
        meta_learning_rate=0.001,
        report_validation_frequency=50,
        log_train_images_after_iteration=1000,
        least_number_of_tasks_val_test=50,
        clip_gradients=True,
    )

    maml.train(epochs=30000)
    maml.evaluate(50)
def run_omniglot():
    """Meta-train MAML on 84x84 Omniglot and evaluate on mini-ImageNet.

    After 60000 training iterations, ``evaluate`` is called twice over 1000
    tasks: first with validation batch statistics, then without.
    """
    omniglot_database = Omniglot84x84Database(
        random_seed=47,
        num_train_classes=1200,
        num_val_classes=100,
    )
    evaluation_database = MiniImagenetDatabase()

    maml = ModelAgnosticMetaLearningModel(
        database=omniglot_database,
        test_database=evaluation_database,
        network_cls=MiniImagenetModel,
        n=5,
        k_ml=1,
        k_val_ml=5,
        k_val=1,
        k_val_val=15,
        k_test=5,
        k_val_test=15,
        meta_batch_size=4,
        num_steps_ml=5,
        lr_inner_ml=0.05,
        num_steps_validation=5,
        save_after_iterations=15000,
        meta_learning_rate=0.001,
        report_validation_frequency=1000,
        log_train_images_after_iteration=1000,
        num_tasks_val=100,
        clip_gradients=True,
        experiment_name='omniglot_84x84',
        val_seed=42,
        val_test_batch_norm_momentum=0.0,
    )

    maml.train(iterations=60000)
    for use_val_batch_statistics in (True, False):
        maml.evaluate(
            iterations=50,
            num_tasks=1000,
            use_val_batch_statistics=use_val_batch_statistics,
            seed=42,
        )
def run_airplane():
    """Train the cross-domain auto-encoder on six domains and evaluate on mini-ImageNet.

    'miniimagenet' is excluded from ``domains`` because mini-ImageNet is the
    evaluation database here.
    NOTE(review): despite its name, this function does not evaluate on the
    airplane dataset; consider renaming.
    """
    test_database = MiniImagenetDatabase()
    training_domains = ('airplane', 'fungi', 'cub', 'dtd', 'omniglot', 'vggflowers')

    cdae = CrossDomainAE2(
        database=test_database,
        batch_size=512,
        domains=training_domains,
    )

    experiment_name = 'all_domains_288'
    cdae.train(epochs=20, experiment_name=experiment_name)
    cdae.evaluate(
        10,
        num_tasks=1000,
        k_test=1,
        k_val_test=15,
        inner_learning_rate=0.001,
        experiment_name=experiment_name,
        seed=42,
    )
else: new_z = tf.stack([ z[0, ...] + (z[(i + 1) % self.n, ...] - z[0, ...]) * 0.3, z[1, ...] + (z[(i + 2) % self.n, ...] - z[1, ...]) * 0.3, z[2, ...] + (z[(i + 3) % self.n, ...] - z[2, ...]) * 0.3, z[3, ...] + (z[(i + 4) % self.n, ...] - z[3, ...]) * 0.3, z[4, ...] + (z[(i + 0) % self.n, ...] - z[4, ...]) * 0.3, ], axis=0) vectors.append(new_z) return vectors if __name__ == '__main__': mini_imagenet_database = MiniImagenetDatabase(input_shape=(224, 224, 3)) shape = (224, 224, 3) latent_dim = 120 import os os.environ['TFHUB_CACHE_DIR'] = os.path.expanduser('~/tf_hub') gan = hub.load("https://tfhub.dev/deepmind/bigbigan-resnet50/1", tags=[]).signatures['generate'] setattr(gan, 'parser', MiniImagenetParser(shape=shape)) maml_gan = MiniImageNetMAMLBigGan( gan=gan, latent_dim=latent_dim, generated_image_shape=shape, database=mini_imagenet_database, network_cls=FiveLayerResNet,
class TestDatabases(unittest.TestCase):
    """Tests for the database wrappers and their meta-learning task sampling.

    Tests that sample tasks patch ``_get_parse_function`` so it returns
    ``self.parse_function``, an identity function. With ``dtype=tf.string``
    each sampled element is then a file address. Its directory part
    (``os.path.split(...)[0]``) identifies the class.
    """

    def setUp(self):
        def parse_function(example_address):
            # Identity: keep the raw file address instead of decoding it.
            return example_address

        self.parse_function = parse_function
        self.omniglot_database = OmniglotDatabase(random_seed=-1, num_train_classes=1200, num_val_classes=100)
        self.mini_imagenet_database = MiniImagenetDatabase()
        self.celeba_database = CelebADatabase()
        self.lfw_database = LFWDatabase()
        self.euro_sat_database = EuroSatDatabase()
        self.isic_database = ISICDatabase()
        self.chest_x_ray_database = ChestXRay8Database()

    @property
    def databases(self):
        """All database instances created in ``setUp``."""
        return (
            self.omniglot_database,
            self.mini_imagenet_database,
            self.celeba_database,
            self.lfw_database,
            self.euro_sat_database,
            self.isic_database,
            self.chest_x_ray_database,
        )

    @staticmethod
    def _split_address(address_tensor):
        """Decode a scalar string tensor and split it into (class folder, file name)."""
        return os.path.split(address_tensor.numpy().decode('utf-8'))

    @staticmethod
    def _tensor_to_string(tensor):
        """Serialise every element of ``tensor`` into one comma-separated string."""
        return ','.join(map(str, tensor.numpy().reshape(-1)))

    def test_train_val_test_folders_are_separate(self):
        for database in (self.omniglot_database, self.mini_imagenet_database, self.celeba_database):
            train_folders = set(database.train_folders)
            val_folders = set(database.val_folders)
            test_folders = set(database.test_folders)
            # Pairwise intersections cover both directions of every pair and
            # list the overlapping folders if the assertion fails.
            self.assertSetEqual(set(), train_folders & val_folders)
            self.assertSetEqual(set(), train_folders & test_folders)
            self.assertSetEqual(set(), val_folders & test_folders)

    def test_train_val_test_folders_are_comprehensive(self):
        self.assertEqual(
            1623,
            len(
                list(self.omniglot_database.train_folders.keys())
                + list(self.omniglot_database.val_folders.keys())
                + list(self.omniglot_database.test_folders.keys())
            ),
        )
        self.assertEqual(
            100,
            len(
                list(self.mini_imagenet_database.train_folders)
                + list(self.mini_imagenet_database.val_folders)
                + list(self.mini_imagenet_database.test_folders)
            ),
        )
        self.assertEqual(
            10177,
            len(
                list(self.celeba_database.train_folders)
                + list(self.celeba_database.val_folders)
                + list(self.celeba_database.test_folders)
            ),
        )

    @patch('tf_datasets.OmniglotDatabase._get_parse_function')
    def test_covering_all_classes_in_one_epoch(self, mocked_parse_function):
        # The 1200 Omniglot train classes are divisible by
        # meta_batch_size * n = 4 * 5, so one epoch can visit each class exactly once.
        mocked_parse_function.return_value = self.parse_function
        ds = self.omniglot_database.get_supervised_meta_learning_dataset(
            self.omniglot_database.train_folders,
            n=5,
            k=5,
            k_validation=5,
            meta_batch_size=4,
            dtype=tf.string,
        )

        # Every class must be covered, with no duplicates.
        classes = set()
        for (x, x_val), (y, y_val) in ds:
            train_ds = combine_first_two_axes(x)
            for i in range(train_ds.shape[0]):
                class_instances = train_ds[i, ...]
                class_address = self._split_address(class_instances[0])[0]
                self.assertNotIn(class_address, classes)
                classes.add(class_address)

        self.assertSetEqual(classes, set(self.omniglot_database.train_folders.keys()))

    @patch('tf_datasets.OmniglotDatabase._get_parse_function')
    def test_covering_all_classes_in_subsequent_epochs(self, mocked_parse_function):
        # With n=7 and meta_batch_size=4, 1200 classes are not divisible by 28,
        # so a single epoch cannot cover every class (num_epochs = 1 always fails).
        # Because classes are selected at random, this test may occasionally fail
        # even with several epochs; more epochs make that less likely.
        num_epochs = 3
        mocked_parse_function.return_value = self.parse_function
        ds = self.omniglot_database.get_supervised_meta_learning_dataset(
            self.omniglot_database.train_folders,
            n=7,
            k=5,
            k_validation=5,
            meta_batch_size=4,
            seed=42,
            dtype=tf.string,
        )

        classes = set()
        for _ in range(num_epochs):
            for (x, x_val), (y, y_val) in ds:
                train_ds = combine_first_two_axes(x)
                for i in range(train_ds.shape[0]):
                    class_instances = train_ds[i, ...]
                    classes.add(self._split_address(class_instances[0])[0])

        self.assertSetEqual(classes, set(self.omniglot_database.train_folders.keys()))

    @patch('tf_datasets.OmniglotDatabase._get_parse_function')
    def test_labels_are_correct_in_train_and_val_for_every_task(self, mocked_parse_function):
        mocked_parse_function.return_value = self.parse_function
        n = 7
        k = 5
        k_validation = 6
        ds = self.omniglot_database.get_supervised_meta_learning_dataset(
            self.omniglot_database.train_folders,
            n=n,
            k=k,
            k_validation=k_validation,
            meta_batch_size=4,
            one_hot_labels=False,
            dtype=tf.string,
        )

        for _ in range(2):
            for (x, x_val), (y, y_val) in ds:
                for meta_batch_index in range(x.shape[0]):
                    train_ds = x[meta_batch_index, ...]
                    val_ds = x_val[meta_batch_index, ...]
                    train_labels = y[meta_batch_index, ...]
                    val_labels = y_val[meta_batch_index, ...]

                    # The test indexes labels class-major: sample (c, i) -> c * k + i.
                    class_label_dict = dict()
                    for class_index in range(n):
                        for instance_index in range(k):
                            label = train_labels[class_index * k + instance_index, ...].numpy()
                            class_name = self._split_address(train_ds[class_index, instance_index, ...])[0]
                            if class_name in class_label_dict:
                                self.assertEqual(class_label_dict[class_name], label)
                            else:
                                class_label_dict[class_name] = label

                    # Validation samples must reuse the label of their train class.
                    for class_index in range(n):
                        for instance_index in range(k_validation):
                            label = val_labels[class_index * k_validation + instance_index, ...].numpy()
                            class_name = self._split_address(val_ds[class_index, instance_index, ...])[0]
                            self.assertIn(class_name, class_label_dict)
                            self.assertEqual(class_label_dict[class_name], label)

    @patch('tf_datasets.OmniglotDatabase._get_parse_function')
    def test_train_and_val_have_different_samples_in_every_task(self, mocked_parse_function):
        mocked_parse_function.return_value = self.parse_function
        n = 6
        k = 4
        k_validation = 5
        ds = self.omniglot_database.get_supervised_meta_learning_dataset(
            self.omniglot_database.train_folders,
            n=n,
            k=k,
            k_validation=k_validation,
            meta_batch_size=4,
            one_hot_labels=False,
            dtype=tf.string,
        )

        for _ in range(4):
            for (x, x_val), (y, y_val) in ds:
                for meta_batch_index in range(x.shape[0]):
                    train_ds = x[meta_batch_index, ...]
                    val_ds = x_val[meta_batch_index, ...]

                    class_instances_dict = dict()
                    for class_index in range(n):
                        for instance_index in range(k):
                            class_name, instance_name = self._split_address(
                                train_ds[class_index, instance_index, ...])
                            class_instances_dict.setdefault(class_name, set()).add(instance_name)

                    for class_index in range(n):
                        for instance_index in range(k_validation):
                            class_name, instance_name = self._split_address(
                                val_ds[class_index, instance_index, ...])
                            self.assertIn(class_name, class_instances_dict)
                            self.assertNotIn(instance_name, class_instances_dict[class_name])

    @patch('tf_datasets.OmniglotDatabase._get_parse_function')
    def test_no_two_class_in_the_same_task(self, mocked_parse_function):
        mocked_parse_function.return_value = self.parse_function
        n = 6
        k = 4
        k_validation = 5
        ds = self.omniglot_database.get_supervised_meta_learning_dataset(
            self.omniglot_database.train_folders,
            n=n,
            k=k,
            k_validation=k_validation,
            meta_batch_size=4,
            one_hot_labels=False,
            dtype=tf.string,
        )

        for _ in range(4):
            for (x, x_val), (y, y_val) in ds:
                for base_data, k_value in zip((x, x_val), (k, k_validation)):
                    for meta_batch_index in range(base_data.shape[0]):
                        task_data = base_data[meta_batch_index, ...]
                        # A class appearing in two slots would be counted 2 * k_value times.
                        class_counts = dict()
                        for class_index in range(n):
                            for instance_index in range(k_value):
                                class_name = self._split_address(task_data[class_index, instance_index, ...])[0]
                                class_counts[class_name] = class_counts.get(class_name, 0) + 1
                        for class_name, num_class_instances in class_counts.items():
                            self.assertEqual(k_value, num_class_instances)

    @patch('tf_datasets.OmniglotDatabase._get_parse_function')
    def test_different_instances_are_selected_from_each_class_for_train_and_val_each_time(
            self, mocked_parse_function):
        # The random seed is chosen so that instances are not selected identically
        # across epochs. This test may fail after changes to the random number
        # generation or to how samples are selected, without meaning the code is
        # wrong: a task might legitimately get the same train/val data twice.
        mocked_parse_function.return_value = self.parse_function
        n = 6
        k = 4
        ds = self.omniglot_database.get_supervised_meta_learning_dataset(
            self.omniglot_database.train_folders,
            n=n,
            k=k,
            k_validation=8,
            meta_batch_size=4,
            one_hot_labels=False,
            dtype=tf.string,
        )

        # epoch index -> {class folder -> set of instance file names}
        class_instances = {0: dict(), 1: dict()}
        for epoch in range(2):
            for (x, x_val), (y, y_val) in ds:
                for meta_batch_index in range(x.shape[0]):
                    train_ds = x[meta_batch_index, ...]
                    for class_index in range(n):
                        for instance_index in range(k):
                            class_name, instance_name = self._split_address(
                                train_ds[class_index, instance_index, ...])
                            class_instances[epoch].setdefault(class_name, set()).add(instance_name)

        first_epoch_class_instances = class_instances[0]
        second_epoch_class_instances = class_instances[1]
        for class_name, first_epoch_instances in first_epoch_class_instances.items():
            self.assertIn(class_name, second_epoch_class_instances)
            # Compare the *size* of the difference: assertNotEqual(0, some_set)
            # always passes because a set never equals 0.
            self.assertNotEqual(
                0,
                len(first_epoch_instances.difference(second_epoch_class_instances[class_name])),
            )

    @patch('tf_datasets.MiniImagenetDatabase._get_parse_function')
    def test_1000_tasks_are_completely_different(self, mocked_parse_function):
        # This test should run on MAML's validation dataset. For now the dataset
        # construction is copied here; if that code changes, this test loses its value.
        # TODO Make this test work with maml.get_validation_dataset and get_test_dataset functions.
        mocked_parse_function.return_value = self.parse_function

        def get_val_dataset():
            val_dataset = self.mini_imagenet_database.get_supervised_meta_learning_dataset(
                self.mini_imagenet_database.val_folders,
                n=6,
                k=4,
                k_validation=3,
                meta_batch_size=1,
                seed=42,
                dtype=tf.string,
            )
            val_dataset = val_dataset.repeat(-1)
            val_dataset = val_dataset.take(1000)
            return val_dataset

        counter = 0
        xs = set()
        x_vals = set()
        xs_queue = deque()
        x_vals_queue = deque()
        # First pass: no train set and no val set may repeat across the 1000 tasks.
        for (x, x_val), (y, y_val) in get_val_dataset():
            x_string = self._tensor_to_string(x)
            self.assertNotIn(x_string, xs)
            xs.add(x_string)

            x_val_string = self._tensor_to_string(x_val)
            self.assertNotIn(x_val_string, x_vals)
            x_vals.add(x_val_string)

            xs_queue.append(x_string)
            x_vals_queue.append(x_val_string)
            counter += 1

        self.assertEqual(counter, 1000)

        # Second pass: rebuilding the dataset with the same seed must reproduce
        # the same tasks in the same order.
        for (x, x_val), (y, y_val) in get_val_dataset():
            self.assertEqual(xs_queue.popleft(), self._tensor_to_string(x))
            self.assertEqual(x_vals_queue.popleft(), self._tensor_to_string(x_val))
print("Error: Too many arguments") sys.exit(0) # CONFIGS ITERATIONS = 15000 GAN_EPOCHS = 100 N_TASK_EVAL = 100 K = 5 N_WAY = 5 META_BATCH_SIZE = 1 LASIUM_TYPE = "p1" print("K=",K) print("N_WAY=",N_WAY) mini_imagenet_database = MiniImagenetDatabase() shape = (84, 84, 3) latent_dim = 512 mini_imagenet_generator = get_generator(latent_dim) mini_imagenet_discriminator = get_discriminator() mini_imagenet_parser = MiniImagenetParser(shape=shape) experiment_name = prefix+str(labeled_percentage) # for the SSGAN we need to feed the labels, L, when initializing gan = GAN( 'mini_imagenet', image_shape=shape, latent_dim=latent_dim, database=mini_imagenet_database, parser=mini_imagenet_parser,
x = layers.Conv2D(64, 4, activation=None, strides=2, padding="same", use_bias=False)(x) x = layers.BatchNormalization()(x) x = layers.LeakyReLU(alpha=0.2)(x) x = layers.GlobalMaxPooling2D()(x) discriminator_outputs = layers.Dense(1)(x) discriminator = keras.Model(discriminator_inputs, discriminator_outputs, name="discriminator") discriminator.summary() return discriminator if __name__ == '__main__': mini_imagenet_database = MiniImagenetDatabase() shape = (84, 84, 3) latent_dim = 512 mini_imagenet_generator = get_generator(latent_dim) mini_imagenet_discriminator = get_discriminator() mini_imagenet_parser = MiniImagenetParser(shape=shape) gan = GAN( 'mini_imagenet', image_shape=shape, latent_dim=latent_dim, database=mini_imagenet_database, parser=mini_imagenet_parser, generator=mini_imagenet_generator, discriminator=mini_imagenet_discriminator, visualization_freq=1,
from databases import MiniImagenetDatabase

# Sanity check: put a small dense head on an ImageNet-pretrained VGG16, fit it
# on the train images of the first batch of mini-ImageNet tasks, and print the
# predicted vs. true class index for each of those images.

IMAGE_SHAPE = (224, 224, 3)
N_WAY = 5

model = tf.keras.applications.VGG16(include_top=False, weights='imagenet', input_shape=IMAGE_SHAPE)
flatten = tf.keras.layers.Flatten(name='flatten')(model.output)
fc1 = tf.keras.layers.Dense(512)(flatten)
fc2 = tf.keras.layers.Dense(512)(fc1)
fc3 = tf.keras.layers.Dense(N_WAY)(fc2)
new_model = tf.keras.models.Model(inputs=[model.input], outputs=[fc3])
# fc3 has no softmax, so the model outputs logits. Without from_logits=True the
# loss would treat the raw (possibly negative) outputs as probabilities.
new_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
)

data_base = MiniImagenetDatabase(input_shape=IMAGE_SHAPE)
# Positional arguments are presumably (n, k, k_validation, meta_batch_size)
# = (5, 1, 1, 2) -- TODO confirm against the method signature.
dataset = data_base.get_supervised_meta_learning_dataset(
    data_base.train_folders, 5, 1, 1, 2)

# Only the first dataset element is used; its validation halves are ignored.
for (x1, _), (y1, _) in dataset.take(1):
    # Merge the leading task axes into one batch axis. -1 lets TF infer its
    # size (10 images for the arguments above) instead of hard-coding it.
    x1 = tf.reshape(x1, (-1, *IMAGE_SHAPE))
    y1 = tf.reshape(y1, (-1, N_WAY))
    new_model.fit(x1, y1)
    # argmax over the last (class) axis yields one class index per image.
    print(tf.argmax(new_model.predict(x1), axis=-1))
    print(tf.argmax(y1, axis=-1))
axes[0, i].xaxis.set_label_position('top') axes[0, i].set_xlabel(title_name) # fig.suptitle('', fontsize=12, y=1) plt.savefig(fname=os.path.join(root_folder_to_save, 'all_domains2.pdf')) plt.show() if __name__ == '__main__': tf.config.set_visible_devices([], 'GPU') root_folder_to_save = os.path.expanduser('~/datasets_visualization/') if not os.path.exists(root_folder_to_save): os.mkdir(root_folder_to_save) databases = ( MiniImagenetDatabase(), OmniglotDatabase(random_seed=42, num_train_classes=1200, num_val_classes=100), AirplaneDatabase(), CUBDatabase(), DTDDatabase(), FungiDatabase(), VGGFlowerDatabase(), # TrafficSignDatabase(), # MSCOCODatabase(), # PlantDiseaseDatabase(), # EuroSatDatabase(), # ISICDatabase(), # ChestXRay8Database(), )
for style_layer_name in style_layers ] model = tf.keras.models.Model(inputs=vgg19.inputs, outputs=outputs) def convert_to_activaitons(imgs): imgs_activations = [] for image in imgs: image = convert_str_to_img(image)[np.newaxis, ...] activations = model.predict(image) imgs_activations.append(activations) return imgs_activations DTDDatabase() d1_imgs = get_all_instances(MiniImagenetDatabase()) d2_imgs = get_all_instances(ISICDatabase()) d3_imgs = get_all_instances(ChestXRay8Database()) d4_imgs = get_all_instances(PlantDiseaseDatabase()) d1_imgs = np.random.choice(d1_imgs, 10, replace=False) d2_imgs = np.random.choice(d2_imgs, 10, replace=False) d3_imgs = np.random.choice(d3_imgs, 10, replace=False) d4_imgs = np.random.choice(d4_imgs, 10, replace=False) print(d1_imgs) print(d2_imgs) print(d3_imgs) print(d4_imgs) d1_imgs_activations = convert_to_activaitons(d1_imgs)
def run_mini_imagenet():
    """Meta-train SML on mini-ImageNet using a VGG19-derived feature extractor.

    The function samples data points from mini-ImageNet's train folders and
    builds a features dataset. It then compiles a VGG19 classifier
    (``weights=None``, ``classes=n_classes``) and loads its weights from
    ``./feature_models/feature_model_100``. The output of the classifier's
    layer 24 (presumably the 4096-unit ``fc2`` layer, consistent with
    ``feature_size=4096``) is the ``feature_model`` given to SML, which then
    trains for 60000 iterations.
    """
    mini_imagenet_database = MiniImagenetDatabase()

    n_data_points = 38400
    data_points, classes, non_labeled_data_points = sample_data_points(
        mini_imagenet_database.train_folders, n_data_points)
    features_dataset, n_classes = make_features_dataset_mini_imagenet(
        data_points,
        classes,
        non_labeled_data_points,
        batch_size=16,
        shuffle_buffer_size=1000,
    )

    classifier = tf.keras.applications.VGG19(weights=None, classes=n_classes)
    classifier.compile(
        loss=tf.keras.losses.CategoricalCrossentropy(),
        metrics=['accuracy'],
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-6),
    )

    def save_call_back(epoch, logs):
        # Checkpoints every 100 epochs; only used by the disabled fit call below.
        if epoch % 100 == 0:
            classifier.save_weights(filepath=f'./feature_models/feature_model_{epoch}')

    # To retrain the classifier instead of loading the checkpoint:
    #   saver_callback = tf.keras.callbacks.LambdaCallback(on_epoch_end=save_call_back)
    #   classifier.fit(features_dataset, epochs=101, callbacks=[saver_callback])
    classifier.load_weights(filepath='./feature_models/feature_model_100')

    feature_model = tf.keras.models.Model(
        inputs=classifier.input,
        outputs=classifier.layers[24].output,
    )

    sml = SML(
        database=mini_imagenet_database,
        network_cls=MiniImagenetModel,
        n=5,
        k=1,
        k_val_ml=5,
        k_val_val=15,
        k_val_test=15,
        k_test=1,
        meta_batch_size=4,
        num_steps_ml=5,
        lr_inner_ml=0.05,
        num_steps_validation=5,
        save_after_iterations=15000,
        meta_learning_rate=0.001,
        n_clusters=500,
        feature_model=feature_model,
        feature_size=4096,
        input_shape=(224, 224, 3),
        preprocess_function=tf.keras.applications.vgg19.preprocess_input,
        log_train_images_after_iteration=1000,
        least_number_of_tasks_val_test=100,
        report_validation_frequency=250,
        experiment_name='mini_imagenet_learn_miniimagent_features',
    )
    sml.train(iterations=60000)