Example #1
 def test_CreateResnetModel(self):
     model, state = load_model.get_model('Resnet50', 1, 224, 1000)
     self.assertIsInstance(model, flax.nn.Model)
     self.assertIsInstance(state, flax.nn.Collection)
     fake_input = np.zeros([1, 224, 224, 3])
     # Evaluate with the batch-norm statistics frozen (mutable=False).
     with flax.nn.stateful(state, mutable=False):
         logits = model(fake_input, train=False)
     self.assertEqual(logits.shape, (1, 1000))
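
A minimal standalone sketch of the same call pattern the test exercises, using the pre-Linen flax.nn API; the import path for load_model is assumed from the project layout, everything else uses only documented jax/flax calls:

import jax
import flax
import numpy as np

from sam.sam_jax.models import load_model  # assumed import path

model, state = load_model.get_model('Resnet50', 1, 224, 1000)

@jax.jit
def predict(inputs):
    # mutable=False freezes the batch-norm statistics for evaluation.
    with flax.nn.stateful(state, mutable=False):
        return model(inputs, train=False)

probabilities = jax.nn.softmax(predict(np.zeros([1, 224, 224, 3])), axis=-1)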
Example #2
File: train.py Project: T-STAR-LTD/sam
def main(_):

    tf.enable_v2_behavior()
    # Make sure TF does not allocate GPU memory.
    tf.config.experimental.set_visible_devices([], 'GPU')

    # Switching to a hardware-backed Bernoulli yields performance gains on TPU.
    def hardware_bernoulli(rng_key, p=jax.numpy.float32(0.5), shape=None):
        # tie_in keeps the otherwise-unused rng_key in the dependency graph;
        # lax.rng_uniform samples directly from the device's hardware RNG.
        lax_key = jax.lax.tie_in(rng_key, 0.0)
        return jax.lax.rng_uniform(lax_key, 1.0, shape) < p

    def set_hardware_bernoulli():
        jax.random.bernoulli = hardware_bernoulli

    set_hardware_bernoulli()
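    # From here on, any code that calls jax.random.bernoulli (dropout, for
    # example) goes through the hardware RNG path above. Illustrative call:
    #   mask = jax.random.bernoulli(jax.random.PRNGKey(0), 0.9, shape=(8, 8))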

    # Since we grid-search the weight decay and the learning rate, we add them
    # to the output directory path so that each model has its own directory to
    # save its results in. We also add `run_seed`, which is "grid-searched"
    # over to replicate an experiment several times.
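    # An illustrative resulting suffix: lr_0.1/wd_0.0005/rho_0.05/seed_0
    # (placeholder values; the real ones come from the flags below).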
    output_dir_suffix = os.path.join('lr_' + str(FLAGS.learning_rate),
                                     'wd_' + str(FLAGS.weight_decay),
                                     'rho_' + str(FLAGS.sam_rho),
                                     'seed_' + str(FLAGS.run_seed))

    output_dir = os.path.join(FLAGS.output_dir, output_dir_suffix)

    if not gfile.exists(output_dir):
        gfile.makedirs(output_dir)

    num_devices = jax.local_device_count() * jax.host_count()
    assert FLAGS.batch_size % num_devices == 0
    local_batch_size = FLAGS.batch_size // num_devices
    info = 'Total batch size: {} ({} x {} replicas)'.format(
        FLAGS.batch_size, local_batch_size, num_devices)
    logging.info(info)
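    # A worked example of the arithmetic above: batch_size=1024 on 2 hosts
    # with 4 local devices each gives num_devices=8 and local_batch_size=128
    # (illustrative numbers, not project defaults).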

    if FLAGS.dataset == 'cifar10':
        if FLAGS.from_pretrained_checkpoint:
            image_size = efficientnet.name_to_image_size(FLAGS.model_name)
        else:
            image_size = None
        dataset_source = dataset_source_lib.Cifar10(
            FLAGS.batch_size // jax.host_count(),
            FLAGS.image_level_augmentations,
            FLAGS.batch_level_augmentations,
            image_size=image_size)
    elif FLAGS.dataset == 'cifar100':
        if FLAGS.from_pretrained_checkpoint:
            image_size = efficientnet.name_to_image_size(FLAGS.model_name)
        else:
            image_size = None
        dataset_source = dataset_source_lib.Cifar100(
            FLAGS.batch_size // jax.host_count(),
            FLAGS.image_level_augmentations,
            FLAGS.batch_level_augmentations,
            image_size=image_size)

    elif FLAGS.dataset == 'fashion_mnist':
        dataset_source = dataset_source_lib.FashionMnist(
            FLAGS.batch_size, FLAGS.image_level_augmentations,
            FLAGS.batch_level_augmentations)
    elif FLAGS.dataset == 'svhn':
        dataset_source = dataset_source_lib.SVHN(
            FLAGS.batch_size, FLAGS.image_level_augmentations,
            FLAGS.batch_level_augmentations)
    elif FLAGS.dataset == 'imagenet':
        imagenet_image_size = efficientnet.name_to_image_size(FLAGS.model_name)
        dataset_source = dataset_source_imagenet.Imagenet(
            FLAGS.batch_size // jax.host_count(), imagenet_image_size,
            FLAGS.image_level_augmentations)
    else:
        raise ValueError('Dataset not recognized.')

    if 'cifar' in FLAGS.dataset or 'svhn' in FLAGS.dataset:
        # Check 'svhn' first: the SVHN branch above never binds image_size,
        # so evaluating `image_size is None` first would raise a NameError.
        if 'svhn' in FLAGS.dataset or image_size is None:
            image_size = 32
        num_channels = 3
        num_classes = 100 if FLAGS.dataset == 'cifar100' else 10
    elif FLAGS.dataset == 'fashion_mnist':
        image_size = 28  # For Fashion Mnist
        num_channels = 1
        num_classes = 10
    elif FLAGS.dataset == 'imagenet':
        image_size = imagenet_image_size
        num_channels = 3
        num_classes = 1000
    else:
        raise ValueError('Dataset not recognized.')

    # Try the ImageNet model zoo first; if the name is unknown there, fall
    # back to the generic model loader.
    try:
        model, state = load_imagenet_model.get_model(FLAGS.model_name,
                                                     local_batch_size,
                                                     image_size, num_classes)
    except load_imagenet_model.ModelNameError:
        model, state = load_model.get_model(FLAGS.model_name, local_batch_size,
                                            image_size, num_classes,
                                            num_channels)

    # The learning rate will be overwritten by the LR schedule, so we set it
    # to zero here.
    optimizer = flax_training.create_optimizer(model, 0.0)

    flax_training.train(optimizer, state, dataset_source, output_dir,
                        FLAGS.num_epochs)
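
For completeness, the standard absl entry point that would invoke this main (the FLAGS usage above implies absl.flags; the lines below are the conventional pattern, not copied from the file):

from absl import app

if __name__ == '__main__':
    app.run(main)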