def test_CreateModel(self, model_name):
  model, state = load_model.get_model(model_name, 1, 32, 10)
  self.assertIsInstance(model, flax.nn.Model)
  self.assertIsInstance(state, flax.nn.Collection)
  fake_input = np.zeros([1, 32, 32, 3])
  with flax.nn.stateful(state, mutable=False):
    logits = model(fake_input, train=False)
  self.assertEqual(logits.shape, (1, 10))
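# test_CreateModel takes a model_name argument, which suggests the test class
# is driven by absl's parameterized machinery. Below is a minimal sketch of
# that wiring, assuming the three model names checked in test_ParameterCount
# further down; the class name is hypothetical and not part of this
# repository.
from absl.testing import parameterized


class _ModelTestSketch(parameterized.TestCase):

  @parameterized.parameters(
      'WideResnet28x10', 'WideResnet28x6_ShakeShake', 'Pyramid_ShakeDrop')
  def test_CreateModel(self, model_name):
    del model_name  # Sketch only; the real test body is shown above.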
def main(_):
  # As we grid search the weight decay and the learning rate, we add them to
  # the output directory path so that each model has its own directory to
  # save its results in. We also add the `run_seed`, which is "grid searched"
  # over to replicate each experiment several times.
  output_dir_suffix = os.path.join('lr_' + str(FLAGS.learning_rate),
                                   'wd_' + str(FLAGS.weight_decay),
                                   'seed_' + str(FLAGS.run_seed))

  output_dir = os.path.join(FLAGS.output_dir, output_dir_suffix)

  if not gfile.Exists(output_dir):
    gfile.MakeDirs(output_dir)

  num_devices = jax.local_device_count()
  assert FLAGS.batch_size % num_devices == 0
  local_batch_size = FLAGS.batch_size // num_devices
  info = 'Total batch size: {} ({} x {} replicas)'.format(
      FLAGS.batch_size, local_batch_size, num_devices)
  logging.info(info)

  if FLAGS.dataset.lower() == 'cifar10':
    dataset_source = dataset_source_lib.Cifar10(
        FLAGS.batch_size, FLAGS.image_level_augmentations,
        FLAGS.batch_level_augmentations)
  elif FLAGS.dataset.lower() == 'cifar100':
    dataset_source = dataset_source_lib.Cifar100(
        FLAGS.batch_size, FLAGS.image_level_augmentations,
        FLAGS.batch_level_augmentations)
  elif FLAGS.dataset.lower() == 'fashion_mnist':
    dataset_source = dataset_source_lib.FashionMnist(
        FLAGS.batch_size, FLAGS.image_level_augmentations,
        FLAGS.batch_level_augmentations)
  elif FLAGS.dataset.lower() == 'svhn':
    dataset_source = dataset_source_lib.SVHN(
        FLAGS.batch_size, FLAGS.image_level_augmentations,
        FLAGS.batch_level_augmentations)
  else:
    raise ValueError('Available datasets: cifar10(0), fashion_mnist, svhn.')

  if 'cifar' in FLAGS.dataset.lower() or 'svhn' in FLAGS.dataset.lower():
    image_size = 32
    num_channels = 3
  else:
    image_size = 28  # For Fashion MNIST.
    num_channels = 1

  num_classes = 100 if FLAGS.dataset.lower() == 'cifar100' else 10
  model, state = load_model.get_model(FLAGS.model_name, local_batch_size,
                                      image_size, num_classes, num_channels)
  # The learning rate passed here is unused, as it will be overwritten by the
  # learning rate schedule during training; we set it to zero.
  optimizer = flax_training.create_optimizer(model, 0.0)

  flax_training.train(optimizer, state, dataset_source, output_dir,
                      FLAGS.num_epochs)
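# A minimal, self-contained sketch of the directory layout produced by the
# grid search logic at the top of main(). The flag values below are
# hypothetical examples, not defaults taken from this repository.
def _output_dir_sketch():
  import os
  learning_rate, weight_decay, run_seed = 0.1, 0.0005, 0
  # Mirrors the os.path.join construction in main(); yields
  # 'lr_0.1/wd_0.0005/seed_0' on POSIX systems.
  return os.path.join('lr_' + str(learning_rate),
                      'wd_' + str(weight_decay),
                      'seed_' + str(run_seed))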
def test_ParameterCount(self, model_name):
  # Reference parameter counts for the AutoAugment paper models (100 classes):
  reference_parameter_count = {
      'WideResnet28x10': 36278324,
      'WideResnet28x6_ShakeShake': 26227572,
      'Pyramid_ShakeDrop': 26288692,
  }
  model, _ = load_model.get_model(model_name, 1, 32, 100)
  parameter_count = sum(np.prod(e.shape) for e in jax.tree_leaves(model))
  self.assertEqual(parameter_count, reference_parameter_count[model_name])
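# A minimal, self-contained sketch of the parameter-counting idiom used in
# test_ParameterCount above. The nested dict is a hypothetical stand-in for a
# real model pytree; any pytree of arrays is counted the same way, since
# jax.tree_leaves flattens arbitrary nesting into a flat list of arrays.
def _parameter_count_sketch():
  params = {
      'conv': {'kernel': np.zeros((3, 3, 3, 16)), 'bias': np.zeros((16,))},
      'dense': {'kernel': np.zeros((16, 10)), 'bias': np.zeros((10,))},
  }
  # Flatten the pytree and sum the element counts of all leaf arrays:
  # 3*3*3*16 + 16 + 16*10 + 10 == 618.
  return sum(np.prod(leaf.shape) for leaf in jax.tree_leaves(params))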