Example #1
0
File: sample.py  Project: ykumards/flax
def generate_sample(pcnn_module, batch_size, rng_seed=0):
    """Sample a batch of 32x32x3 images from a model restored from a checkpoint.

    A model is built from ``pcnn_module`` and its parameters are replaced with
    the second value returned by ``train.restore_checkpoint`` (named ``ema``,
    presumably exponential-moving-average parameters -- confirm against
    ``train``). Starting from all-zero images split across the local devices,
    ``sample_iteration`` is applied repeatedly until its output stops changing.

    Args:
      pcnn_module: module passed through to ``train.create_model``.
      batch_size: total number of images to generate. Must be a multiple of
        ``jax.local_device_count()`` because the batch is split evenly into one
        sub-batch per local device.
      rng_seed: integer seed for ``jax.random.PRNGKey``.

    Returns:
      An array of shape ``(batch_size, 32, 32, 3)``.

    Raises:
      ValueError: if ``batch_size`` is not a multiple of the local device count.
    """
    # Validate before building the model and restoring the checkpoint. An
    # explicit raise is used because `assert` is removed under `python -O`.
    device_count = jax.local_device_count()
    if batch_size % device_count:
        raise ValueError(
            'Sampling batch size must be a multiple of the device count, got '
            'sample_batch_size={}, device_count={}.'.format(
                batch_size, device_count))

    rng = random.PRNGKey(rng_seed)
    rng, model_rng = random.split(rng)

    # Create a model with dummy parameters and a dummy optimizer; they only
    # act as the template that the checkpoint is restored into. (The `0`
    # argument is presumably the learning rate -- TODO confirm.)
    example_images = jnp.zeros((1, 32, 32, 3))
    model = train.create_model(model_rng, example_images, pcnn_module)
    optimizer = train.create_optimizer(model, 0)

    # Load learned parameters and install them in the model.
    _, ema = train.restore_checkpoint(optimizer, model.params)
    model = model.replace(params=ema)

    # Initial all-zero batch with a leading per-device axis...
    sample_prev = jnp.zeros(
        (device_count, batch_size // device_count, 32, 32, 3))
    # ...and one rng key per device.
    sample_rng = random.split(rng, device_count)

    # Fixed-point iteration: keep refining until the sample no longer changes.
    # The same keys are passed on every step (presumably so the iteration is
    # deterministic and can reach a fixed point).
    sample = sample_iteration(sample_rng, model, sample_prev)
    while jnp.any(sample != sample_prev):
        sample_prev, sample = sample, sample_iteration(sample_rng, model,
                                                       sample)
    # Merge the per-device axis back into a single batch axis.
    return jnp.reshape(sample, (batch_size, 32, 32, 3))
Example #2
0
    def test_train_one_step(self):
        """One training step on a 128-example batch keeps metrics in bounds."""
        data = train.get_batch(128)
        key = random.PRNGKey(0)

        opt = train.create_optimizer(train.create_model(key), 0.003)
        opt, metrics = train.train_step(opt, data)

        # Loss must not exceed 5; accuracy must be non-negative.
        self.assertLessEqual(metrics['loss'], 5)
        self.assertGreaterEqual(metrics['accuracy'], 0)
Example #3
0
File: train_test.py  Project: vballoli/flax
    def test_train_one_step(self):
        """One training step inside an ``nn.stochastic`` scope stays in bounds."""
        data = train.get_batch(128)

        # Both the model init and the train step draw keys from the scope,
        # in that order.
        with nn.stochastic(random.PRNGKey(0)):
            opt = train.create_optimizer(
                train.create_model(nn.make_rng()), 0.003)
            opt, metrics = train.train_step(opt, data, nn.make_rng())

        self.assertLessEqual(metrics['loss'], 5)
        self.assertGreaterEqual(metrics['accuracy'], 0)
Example #4
0
 def test_train_one_epoch(self):
     """One epoch of training must meet the train and test metric thresholds."""
     train_ds, test_ds = train.get_datasets()
     data_rng = onp.random.RandomState(0)
     opt = train.create_optimizer(
         train.create_model(random.PRNGKey(0)), 0.1, 0.9)

     opt, epoch_metrics = train.train_epoch(
         opt, train_ds, 128, 0, data_rng)
     self.assertLessEqual(epoch_metrics['loss'], 0.27)
     self.assertGreaterEqual(epoch_metrics['accuracy'], 0.92)

     # Evaluate the trained parameters on the test split.
     test_loss, test_accuracy = train.eval_model(opt.target, test_ds)
     self.assertLessEqual(test_loss, 0.06)
     self.assertGreaterEqual(test_accuracy, 0.98)
Example #5
0
    def test_single_train_step(self):
        """Train once on the first 32 training examples, then evaluate."""
        train_ds, test_ds = train.get_datasets()
        batch_size = 32
        opt = train.create_optimizer(
            train.create_model(random.PRNGKey(0)), 0.1, 0.9)

        # Slice the first `batch_size` entries out of every training field.
        first_batch = {name: values[:batch_size]
                       for name, values in train_ds.items()}
        opt, step_metrics = train.train_step(optimizer=opt, batch=first_batch)
        self.assertLessEqual(step_metrics['loss'], 2.302)
        self.assertGreaterEqual(step_metrics['accuracy'], 0.0625)

        # Evaluate the once-updated parameters on the test split.
        test_loss, test_accuracy = train.eval_model(opt.target, test_ds)
        self.assertLess(test_loss, 2.252)
        self.assertGreater(test_accuracy, 0.2597)