def generate_sample(pcnn_module, batch_size, rng_seed=0):
  rng = random.PRNGKey(rng_seed)
  rng, model_rng = random.split(rng)

  # Create a model with dummy parameters and a dummy optimizer.
  example_images = jnp.zeros((1, 32, 32, 3))
  model = train.create_model(model_rng, example_images, pcnn_module)
  optimizer = train.create_optimizer(model, 0)

  # Load learned parameters.
  _, ema = train.restore_checkpoint(optimizer, model.params)
  model = model.replace(params=ema)

  # Initialize batch of images.
  device_count = jax.local_device_count()
  assert not batch_size % device_count, (
      'Sampling batch size must be a multiple of the device count, got '
      'sample_batch_size={}, device_count={}.'.format(
          batch_size, device_count))
  sample_prev = jnp.zeros(
      (device_count, batch_size // device_count, 32, 32, 3))
  # ... and batch of rng keys.
  sample_rng = random.split(rng, device_count)

  # Generate sample using fixed-point iteration.
  sample = sample_iteration(sample_rng, model, sample_prev)
  while jnp.any(sample != sample_prev):
    sample_prev, sample = sample, sample_iteration(sample_rng, model, sample)

  return jnp.reshape(sample, (batch_size, 32, 32, 3))
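# Hypothetical usage sketch for generate_sample (not part of the original
# code): the PixelCNN++ constructor and its arguments below are assumptions
# for illustration and may differ from the real module definition.
pcnn_module = pixelcnn.PixelCNNPP.partial(depth=5, features=160)  # assumed API
samples = generate_sample(pcnn_module, batch_size=16, rng_seed=42)
# samples has shape (16, 32, 32, 3): one generated image per batch entry.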
def test_train_one_step(self):
  batch = train.get_batch(128)
  rng = random.PRNGKey(0)
  model = train.create_model(rng)
  optimizer = train.create_optimizer(model, 0.003)
  optimizer, train_metrics = train.train_step(optimizer, batch)
  self.assertLessEqual(train_metrics['loss'], 5)
  self.assertGreaterEqual(train_metrics['accuracy'], 0)
def test_train_one_step(self):
  batch = train.get_batch(128)
  rng = random.PRNGKey(0)
  with nn.stochastic(rng):
    model = train.create_model(nn.make_rng())
    optimizer = train.create_optimizer(model, 0.003)
    optimizer, train_metrics = train.train_step(
        optimizer, batch, nn.make_rng())
  self.assertLessEqual(train_metrics['loss'], 5)
  self.assertGreaterEqual(train_metrics['accuracy'], 0)
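# Minimal sketch of the stochastic-RNG pattern used in the test above. It
# relies on the deprecated flax.nn API (the same nn.stochastic / nn.make_rng
# calls the test itself makes), so it only runs against an old Flax release.
import jax
from flax import nn

with nn.stochastic(jax.random.PRNGKey(0)):
  rng_a = nn.make_rng()  # each call derives a fresh key from the scope's seed
  rng_b = nn.make_rng()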
def test_train_one_epoch(self):
  train_ds, test_ds = train.get_datasets()
  input_rng = onp.random.RandomState(0)
  model = train.create_model(random.PRNGKey(0))
  optimizer = train.create_optimizer(model, 0.1, 0.9)
  optimizer, train_metrics = train.train_epoch(
      optimizer, train_ds, 128, 0, input_rng)
  self.assertLessEqual(train_metrics['loss'], 0.27)
  self.assertGreaterEqual(train_metrics['accuracy'], 0.92)
  loss, accuracy = train.eval_model(optimizer.target, test_ds)
  self.assertLessEqual(loss, 0.06)
  self.assertGreaterEqual(accuracy, 0.98)
def test_single_train_step(self):
  train_ds, test_ds = train.get_datasets()
  batch_size = 32
  model = train.create_model(random.PRNGKey(0))
  optimizer = train.create_optimizer(model, 0.1, 0.9)
  # Test a single train step.
  optimizer, train_metrics = train.train_step(
      optimizer=optimizer,
      batch={k: v[:batch_size] for k, v in train_ds.items()})
  self.assertLessEqual(train_metrics['loss'], 2.302)
  self.assertGreaterEqual(train_metrics['accuracy'], 0.0625)
  # Run eval model.
  loss, accuracy = train.eval_model(optimizer.target, test_ds)
  self.assertLess(loss, 2.252)
  self.assertGreater(accuracy, 0.2597)
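# Sketch of an enclosing test harness (an assumption, not shown in the
# snippets above): the test_* methods are written as TestCase methods, so
# they would be collected in a class like the one below and run via main().
from absl.testing import absltest


class TrainTest(absltest.TestCase):
  pass  # the test_* methods above would be defined here


if __name__ == '__main__':
  absltest.main()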