Code Example #1
def test_CreateModel(self, model_name):
    model, state = load_model.get_model(model_name, 1, 32, 10)
    self.assertIsInstance(model, flax.nn.Model)
    self.assertIsInstance(state, flax.nn.Collection)
    fake_input = np.zeros([1, 32, 32, 3])
    with flax.nn.stateful(state, mutable=False):
        logits = model(fake_input, train=False)
    self.assertEqual(logits.shape, (1, 10))
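The test above exercises the basic inference path: load_model.get_model builds a flax.nn.Model together with a flax.nn.Collection holding its mutable state (batch-norm statistics), and the forward pass runs inside flax.nn.stateful with mutable=False so that state is read but never updated. A minimal standalone sketch of the same pattern, assuming load_model is importable and using 'WideResnet28x10' (one of the architectures listed in Code Example #3) purely as an illustration:

import flax
import numpy as np

import load_model  # model-construction helper used in the snippets above

# Build a model for 32x32 RGB inputs and 10 classes.
model, state = load_model.get_model('WideResnet28x10', 1, 32, 10)

# mutable=False keeps the batch statistics in `state` read-only, so this is a
# pure evaluation-mode forward pass.
fake_input = np.zeros([1, 32, 32, 3])
with flax.nn.stateful(state, mutable=False):
    logits = model(fake_input, train=False)
print(logits.shape)  # (1, 10)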
Code Example #2
File: train.py Project: zzhaozeng/google-research
def main(_):

    # We grid-search over the weight decay and the learning rate, so we add
    # them to the output directory path to give each model its own directory
    # to save results in. We also add `run_seed`, which is "grid-searched"
    # over as well to replicate an experiment several times.
    output_dir_suffix = os.path.join('lr_' + str(FLAGS.learning_rate),
                                     'wd_' + str(FLAGS.weight_decay),
                                     'seed_' + str(FLAGS.run_seed))

    output_dir = os.path.join(FLAGS.output_dir, output_dir_suffix)

    if not gfile.Exists(output_dir):
        gfile.MakeDirs(output_dir)

    num_devices = jax.local_device_count()
    assert FLAGS.batch_size % num_devices == 0
    local_batch_size = FLAGS.batch_size // num_devices
    info = 'Total batch size: {} ({} x {} replicas)'.format(
        FLAGS.batch_size, local_batch_size, num_devices)
    logging.info(info)

    if FLAGS.dataset.lower() == 'cifar10':
        dataset_source = dataset_source_lib.Cifar10(
            FLAGS.batch_size, FLAGS.image_level_augmentations,
            FLAGS.batch_level_augmentations)
    elif FLAGS.dataset.lower() == 'cifar100':
        dataset_source = dataset_source_lib.Cifar100(
            FLAGS.batch_size, FLAGS.image_level_augmentations,
            FLAGS.batch_level_augmentations)
    elif FLAGS.dataset.lower() == 'fashion_mnist':
        dataset_source = dataset_source_lib.FashionMnist(
            FLAGS.batch_size, FLAGS.image_level_augmentations,
            FLAGS.batch_level_augmentations)
    elif FLAGS.dataset.lower() == 'svhn':
        dataset_source = dataset_source_lib.SVHN(
            FLAGS.batch_size, FLAGS.image_level_augmentations,
            FLAGS.batch_level_augmentations)
    else:
        raise ValueError(
            'Available datasets: cifar10(0), fashion_mnist, svhn.')

    if 'cifar' in FLAGS.dataset.lower() or 'svhn' in FLAGS.dataset.lower():
        image_size = 32
        num_channels = 3
    else:
        image_size = 28  # For Fashion Mnist
        num_channels = 1

    num_classes = 100 if FLAGS.dataset.lower() == 'cifar100' else 10
    model, state = load_model.get_model(FLAGS.model_name, local_batch_size,
                                        image_size, num_classes, num_channels)
    # The learning rate will be overwritten by the LR schedule, so we set it to zero here.
    optimizer = flax_training.create_optimizer(model, 0.0)

    flax_training.train(optimizer, state, dataset_source, output_dir,
                        FLAGS.num_epochs)
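main() reads its entire configuration from absl flags. A hedged sketch of flag definitions matching the names referenced above is shown below; the flag names come from the code, but the types, defaults, and help strings are assumptions rather than the project's actual definitions:

from absl import app
from absl import flags

FLAGS = flags.FLAGS

# Assumed flag definitions; only the names are taken from main() above.
flags.DEFINE_string('output_dir', '/tmp/run', 'Base directory for results.')
flags.DEFINE_string('dataset', 'cifar10', 'cifar10, cifar100, fashion_mnist or svhn.')
flags.DEFINE_string('model_name', 'WideResnet28x10', 'Architecture to train.')
flags.DEFINE_integer('batch_size', 256, 'Global batch size, split across devices.')
flags.DEFINE_integer('num_epochs', 200, 'Number of training epochs.')
flags.DEFINE_float('learning_rate', 0.1, 'Base learning rate (grid-searched).')
flags.DEFINE_float('weight_decay', 0.0005, 'Weight decay (grid-searched).')
flags.DEFINE_integer('run_seed', 0, 'Seed used to replicate a run.')
flags.DEFINE_string('image_level_augmentations', 'basic',
                    'Per-image augmentations passed to the dataset source.')
flags.DEFINE_string('batch_level_augmentations', 'none',
                    'Per-batch augmentations passed to the dataset source.')

if __name__ == '__main__':
    app.run(main)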
Code Example #3
def test_ParameterCount(self, model_name):
    # Parameter count from the autoaugment paper models, 100 classes:
    reference_parameter_count = {
        'WideResnet28x10': 36278324,
        'WideResnet28x6_ShakeShake': 26227572,
        'Pyramid_ShakeDrop': 26288692,
    }
    model, _ = load_model.get_model(model_name, 1, 32, 100)
    parameter_count = sum(np.prod(e.shape) for e in jax.tree_leaves(model))
    self.assertEqual(parameter_count,
                     reference_parameter_count[model_name])
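The parameter count above relies on flax.nn.Model being a JAX pytree: jax.tree_leaves flattens it into its individual weight arrays, and summing np.prod(shape) over those leaves yields the total number of parameters. A small self-contained sketch of the same idiom on a hand-built parameter pytree (the nested dict is purely illustrative):

import jax
import numpy as np

# Toy parameter pytree standing in for the weights of a model.
params = {
    'conv': {'kernel': np.zeros((3, 3, 3, 16)), 'bias': np.zeros((16,))},
    'dense': {'kernel': np.zeros((16, 10)), 'bias': np.zeros((10,))},
}

# Flatten the pytree into its leaf arrays and sum their element counts.
parameter_count = sum(np.prod(leaf.shape) for leaf in jax.tree_leaves(params))
print(parameter_count)  # 432 + 16 + 160 + 10 = 618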