def test_image_decoder(self): # Make random image. image_size = 84 image = np.random.randint(low=0, high=255, size=[image_size, image_size, 3]).astype(np.ubyte) # Encode image_bytes = dataset_to_records.encode_image(image, image_format='PNG') label = np.zeros(1).astype(np.int64) image_example = dataset_to_records.make_example(image_bytes, label, input_key='image', label_key='label') # Decode image_decoder = decoder.ImageDecoder(image_size=image_size) image_decoded = image_decoder(image_example) # Assert perfect reconstruction. with self.session(use_gpu=False) as sess: image_rec_numpy = sess.run(image_decoded) self.assertAllClose(2 * (image.astype(np.float32) / 255.0 - 0.5), image_rec_numpy)
def test_trainer(self): # PrototypicalNetworkLearner is built automatically and this test does not # have the opportunity to pass values to its constructor except through gin. gin.bind_parameter('PrototypicalNetworkLearner.weight_decay', 1e-4) gin.bind_parameter( 'PrototypicalNetworkLearner.backprop_through_moments', True) gin.bind_parameter( 'PrototypicalNetworkLearner.transductive_batch_norm', False) gin.bind_parameter('PrototypicalNetworkLearner.embedding_fn', 'four_layer_convnet') # Values that can't be passed directly to EpisodeDescriptionConfig gin.bind_parameter('process_episode.support_decoder', decoder.ImageDecoder()) gin.bind_parameter('process_episode.query_decoder', decoder.ImageDecoder()) episode_config = config.EpisodeDescriptionConfig( num_ways=None, num_support=None, num_query=None, min_ways=5, max_ways_upper_bound=50, max_num_query=10, max_support_set_size=500, max_support_size_contrib_per_class=100, min_log_weight=np.log(0.5), max_log_weight=np.log(2), ignore_dag_ontology=False, ignore_bilevel_ontology=False) # Inspired from `learn/gin/default/debug_proto_mini_imagenet.gin`, but # building the objects explicitly. data_config = config.DataConfig( image_height=84, shuffle_buffer_size=20, read_buffer_size_bytes=(1024**2), num_prefetch=2, ) trainer_instance = trainer.Trainer( train_learner_class=learner_lib.PrototypicalNetworkLearner, eval_learner_class=learner_lib.PrototypicalNetworkLearner, is_training=True, train_dataset_list=['mini_imagenet'], eval_dataset_list=['mini_imagenet'], restrict_classes={}, restrict_num_per_class={}, checkpoint_dir='', summary_dir='', records_root_dir=FLAGS.records_root_dir, eval_split=trainer.VALID_SPLIT, eval_finegrainedness=False, eval_finegrainedness_split='', eval_imbalance_dataset='', omit_from_saving_and_reloading='', train_episode_config=episode_config, eval_episode_config=episode_config, data_config=data_config, num_updates=100, batch_size=8, # unused num_eval_episodes=10, checkpoint_every=10, validate_every=5, log_every=1, checkpoint_to_restore=None, learning_rate=1e-4, decay_learning_rate=True, decay_every=5000, decay_rate=0.5, experiment_name='test', pretrained_source='', ) # Get the next train / valid / test episodes. train_episode = trainer_instance.next_data[trainer.TRAIN_SPLIT] self.assertIsInstance(train_episode, providers.EpisodeDataset) # This isn't really a test. It just checks that things don't crash... print( trainer_instance.sess.run([ trainer_instance.train_op, trainer_instance.losses[trainer.TRAIN_SPLIT], trainer_instance.accuracies[trainer.TRAIN_SPLIT] ]))