def main():
    """Train the model behind ``loss`` on MNIST images and plot samples.

    Runs 100 epochs of momentum SGD with a batch size of 32. After each
    epoch the value returned as ``test_elbo`` by ``evaluate`` on the test
    images is printed together with the epoch's wall-clock time, and the
    returned ``samples`` are displayed with matplotlib.
    """
    learning_rate = 0.001
    num_epochs = 100
    batch_size = 32
    # One fixed evaluation key, to get reconstructions for a *fixed* latent
    # variable sample over time.
    test_key = PRNGKey(1)

    train_images, test_images = mnist_images()
    full_batches, remainder = divmod(train_images.shape[0], batch_size)
    # A trailing partial batch counts as one more batch.
    num_batches = full_batches + (1 if remainder else 0)
    opt = optimizers.Momentum(learning_rate, mass=0.9)

    @jit
    def binarize_batch(key, i, images):
        """Slice batch ``i`` (modulo the batch count) and sample it with ``key``."""
        batch_index = i % num_batches
        start = batch_index * batch_size
        # NOTE(review): for a partial final batch this slice would run past
        # the end; lax.dynamic_slice_in_dim presumably clamps the start
        # index so the batch keeps its static size - confirm.
        pixels = lax.dynamic_slice_in_dim(images, start, batch_size)
        return random.bernoulli(key, pixels)

    @jit
    def run_epoch(key, state):
        """Apply one optimizer update per batch and return the final state."""

        def body_fun(i, state):
            # Derive a per-step key, then split it into independent keys for
            # the loss and for the data sampling.
            step_key = random.fold_in(key, i)
            loss_key, data_key = random.split(step_key)
            batch = binarize_batch(data_key, i, train_images)
            return opt.update(loss.apply, state, batch, key=loss_key)

        return lax.fori_loop(0, num_batches, body_fun, state)

    # An example batch is only needed to fix input shapes for initialization.
    example_batch = binarize_batch(PRNGKey(0), 0, images=train_images)
    shaped_elbo = loss.shaped(example_batch)
    state = opt.init(shaped_elbo.init_parameters(key=PRNGKey(2)))

    for epoch in range(num_epochs):
        tic = time.time()
        state = run_epoch(PRNGKey(epoch), state)
        params = opt.get_parameters(state)
        test_elbo, samples = evaluate.apply_from(
            {shaped_elbo: params},
            test_images,
            key=test_key,
            jit=True,
        )
        print(f'Epoch {epoch: 3d} {test_elbo:.3f} ({time.time() - tic:.3f} sec)')

        from matplotlib import pyplot as plt
        plt.imshow(samples, cmap=plt.cm.gray)
        plt.show()
def test_mnist_classifier():
    """Smoke-test the MNIST classifier example on all-zero dummy batches.

    Checks that ten jitted optimizer updates finish in under five seconds,
    that ``accuracy`` yields a scalar, and that ``predict`` returns one row
    of ten values per input row.
    """
    from examples.mnist_classifier import predict, loss, accuracy

    def next_batch():
        # Dummy (inputs, targets) pair: 3 rows of 784 inputs and 10 targets.
        return np.zeros((3, 784)), np.zeros((3, 10))

    opt = optimizers.Momentum(0.001, mass=0.9)
    initial_params = loss.init_parameters(*next_batch(), key=PRNGKey(0))
    state = opt.init(initial_params)

    started = time.time()
    for _ in range(10):
        state = opt.update(loss.apply, state, *next_batch(), jit=True)
    elapsed = time.time() - started
    assert elapsed < 5

    params = opt.get_parameters(state)
    train_acc = accuracy.apply_from({loss: params}, *next_batch(), jit=True)
    assert train_acc.shape == ()

    # Reuse the trained loss parameters for the prediction function.
    shaped_loss = loss.shaped(*next_batch())
    predict_params = predict.parameters_from(
        {shaped_loss: params}, next_batch()[0])
    predictions = predict.apply(predict_params, next_batch()[0], jit=True)
    assert predictions.shape == (3, 10)
def main():
    """Train the MNIST classifier with momentum SGD and report accuracy.

    Runs 10 epochs with a batch size of 128 over shuffled mini-batches. After
    every epoch the epoch's wall-clock time and the accuracy on the full
    training and test sets are printed.
    """
    num_epochs = 10
    batch_size = 128

    train_images, train_labels, test_images, test_labels = mnist()
    num_train = train_images.shape[0]
    num_complete_batches, leftover = divmod(num_train, batch_size)
    # A trailing partial batch counts as one more batch.
    num_batches = num_complete_batches + bool(leftover)

    def data_stream():
        """Yield (images, labels) batches forever, reshuffling on every pass."""
        rng = npr.RandomState(0)
        while True:
            perm = rng.permutation(num_train)
            for i in range(num_batches):
                batch_idx = perm[i * batch_size:(i + 1) * batch_size]
                yield train_images[batch_idx], train_labels[batch_idx]

    batches = data_stream()
    opt = optimizers.Momentum(0.001, mass=0.9)
    # The example inputs are passed positionally and the PRNG key by keyword;
    # a positional key would be taken as an extra model input.
    state = opt.init(loss.init_parameters(*next(batches), key=PRNGKey(0)))

    for epoch in range(num_epochs):
        start_time = time.time()
        for _ in range(num_batches):
            state = opt.update(loss.apply, state, *next(batches), jit=True)
        epoch_time = time.time() - start_time

        # Evaluate with the parameters extracted from the optimizer state.
        params = opt.get_parameters(state)
        train_acc = accuracy.apply_from(
            {loss: params}, train_images, train_labels, jit=True)
        test_acc = accuracy.apply_from(
            {loss: params}, test_images, test_labels, jit=True)
        print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
        print("Training set accuracy {:.4f}".format(train_acc))
        print("Test set accuracy {:.4f}".format(test_acc))
def main():
    """Run a few momentum-SGD steps of ResNet50 on synthetic data.

    Builds a 1001-class ``ResNet50``, initializes its parameters from one
    synthetic batch and performs 10 optimizer updates on further random
    batches, printing a progress line for each step.
    """
    rng_key = random.PRNGKey(0)

    batch_size = 8
    num_classes = 1001
    # The batch dimension is the last axis of the input.
    input_shape = (224, 224, 3, batch_size)
    step_size = 0.1
    num_steps = 10

    resnet = ResNet50(num_classes)

    @parametrized
    def loss(inputs, targets):
        """Return the negated sum of the network outputs at the target classes."""
        logits = resnet(inputs)
        # Negated so that minimizing the loss *raises* the target-class score,
        # in agreement with `accuracy`, which predicts the argmax class.
        return -np.sum(logits * targets)

    @parametrized
    def accuracy(inputs, targets):
        """Return the fraction of inputs whose argmax class equals the target."""
        target_class = np.argmax(targets, axis=-1)
        predicted_class = np.argmax(resnet(inputs), axis=-1)
        return np.mean(predicted_class == target_class)

    def synth_batches():
        """Yield random (images, one-hot labels) batches forever."""
        rng = npr.RandomState(0)
        while True:
            images = rng.rand(*input_shape).astype('float32')
            labels = rng.randint(num_classes, size=(batch_size, 1))
            # Broadcasting (batch, 1) against (num_classes,) gives one-hot rows.
            onehot_labels = labels == np.arange(num_classes)
            yield images, onehot_labels

    opt = optimizers.Momentum(step_size, mass=0.9)
    batches = synth_batches()

    print("\nInitializing parameters.")
    # The optimizer state is created with `opt.init`; the example inputs are
    # passed positionally and the PRNG key by keyword.
    state = opt.init(loss.init_parameters(*next(batches), key=rng_key))

    for i in range(num_steps):
        print(f'Training on batch {i}.')
        state = opt.update(loss.apply, state, *next(batches))

    # Extract the trained parameters from the final optimizer state.
    trained_params = opt.get_parameters(state)