Example #1
def main():
    step_size = 0.001
    num_epochs = 100
    batch_size = 32
    test_key = PRNGKey(
        1
    )  # get reconstructions for a *fixed* latent variable sample over time

    train_images, test_images = mnist_images()
    num_complete_batches, leftover = divmod(train_images.shape[0], batch_size)
    num_batches = num_complete_batches + bool(leftover)
    opt = optimizers.Momentum(step_size, mass=0.9)

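    # Slice batch i out of the training images (wrapping around at the end of an
    # epoch) and binarize the grayscale pixels by sampling them as Bernoulli variables.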
    @jit
    def binarize_batch(key, i, images):
        i = i % num_batches
        batch = lax.dynamic_slice_in_dim(images, i * batch_size, batch_size)
        return random.bernoulli(key, batch)

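    # Run one full epoch inside a jit-compiled fori_loop, deriving a fresh key per
    # batch via fold_in/split so data binarization and the loss use independent randomness.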
    @jit
    def run_epoch(key, state):
        def body_fun(i, state):
            loss_key, data_key = random.split(random.fold_in(key, i))
            batch = binarize_batch(data_key, i, train_images)
            return opt.update(loss.apply, state, batch, key=loss_key)

        return lax.fori_loop(0, num_batches, body_fun, state)

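    # Fix the network's input shapes with one example batch, then initialize its
    # parameters and the optimizer state.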
    example_key = PRNGKey(0)
    example_batch = binarize_batch(example_key, 0, images=train_images)
    shaped_elbo = loss.shaped(example_batch)
    init_parameters = shaped_elbo.init_parameters(key=PRNGKey(2))
    state = opt.init(init_parameters)

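    # After each training epoch, report the test-set ELBO and plot the reconstructions
    # for the fixed test_key drawn above.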
    for epoch in range(num_epochs):
        tic = time.time()
        state = run_epoch(PRNGKey(epoch), state)
        params = opt.get_parameters(state)
        test_elbo, samples = evaluate.apply_from({shaped_elbo: params},
                                                 test_images,
                                                 key=test_key,
                                                 jit=True)
        print(
            f'Epoch {epoch: 3d} {test_elbo:.3f} ({time.time() - tic:.3f} sec)')
        from matplotlib import pyplot as plt
        plt.imshow(samples, cmap=plt.cm.gray)
        plt.show()
Example #2
def test_mnist_classifier():
    from examples.mnist_classifier import predict, loss, accuracy

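    # Tiny all-zero batches (3 flattened 28x28 images, 10 classes) keep the test fast.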
    next_batch = lambda: (np.zeros((3, 784)), np.zeros((3, 10)))
    opt = optimizers.Momentum(0.001, mass=0.9)
    state = opt.init(loss.init_parameters(*next_batch(), key=PRNGKey(0)))

    t = time.time()
    for _ in range(10):
        state = opt.update(loss.apply, state, *next_batch(), jit=True)

    elapsed = time.time() - t
    assert 5 > elapsed

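    # Reuse the trained loss parameters to evaluate accuracy and to run the bare
    # predict network on its own.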
    params = opt.get_parameters(state)
    train_acc = accuracy.apply_from({loss: params}, *next_batch(), jit=True)
    assert () == train_acc.shape

    predict_params = predict.parameters_from({loss.shaped(*next_batch()): params},
                                             next_batch()[0])
    predictions = predict.apply(predict_params, next_batch()[0], jit=True)
    assert (3, 10) == predictions.shape
Example #3
def main():
    num_epochs = 10
    batch_size = 128
    train_images, train_labels, test_images, test_labels = mnist()
    num_train = train_images.shape[0]
    num_complete_batches, leftover = divmod(num_train, batch_size)
    num_batches = num_complete_batches + bool(leftover)

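    # Infinite generator of minibatches; reshuffles the training set on each pass.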
    def data_stream():
        rng = npr.RandomState(0)
        while True:
            perm = rng.permutation(num_train)
            for i in range(num_batches):
                batch_idx = perm[i * batch_size:(i + 1) * batch_size]
                yield train_images[batch_idx], train_labels[batch_idx]

    batches = data_stream()

    opt = optimizers.Momentum(0.001, mass=0.9)
    state = opt.init(loss.init_parameters(PRNGKey(0), *next(batches)))

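    # One optimizer update per batch; after each epoch, report train and test accuracy.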
    for epoch in range(num_epochs):
        start_time = time.time()
        for _ in range(num_batches):
            state = opt.update(loss.apply, state, *next(batches), jit=True)
        epoch_time = time.time() - start_time

        params = opt.get_parameters(state)
        train_acc = accuracy.apply_from({loss: params},
                                        train_images,
                                        train_labels,
                                        jit=True)
        test_acc = accuracy.apply_from({loss: params},
                                       test_images,
                                       test_labels,
                                       jit=True)
        print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
        print("Training set accuracy {:.4f}".format(train_acc))
        print("Test set accuracy {:.4f}".format(test_acc))
Example #4
def main():
    rng_key = random.PRNGKey(0)

    batch_size = 8
    num_classes = 1001
    input_shape = (224, 224, 3, batch_size)
    step_size = 0.1
    num_steps = 10

    resnet = ResNet50(num_classes)

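    # loss and accuracy are @parametrized networks that call the same resnet
    # instance, so they share its parameters.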
    @parametrized
    def loss(inputs, targets):
        logits = resnet(inputs)
        return -np.sum(logits * targets)  # negative log-likelihood, assuming resnet outputs log-probabilities

    @parametrized
    def accuracy(inputs, targets):
        target_class = np.argmax(targets, axis=-1)
        predicted_class = np.argmax(resnet(inputs), axis=-1)
        return np.mean(predicted_class == target_class)

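    # Synthetic random images and labels stand in for a real dataset so the example
    # runs standalone.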
    def synth_batches():
        rng = npr.RandomState(0)
        while True:
            images = rng.rand(*input_shape).astype('float32')
            labels = rng.randint(num_classes, size=(batch_size, 1))
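            # Broadcasting against np.arange(num_classes) turns the integer labels
            # into boolean one-hot rows of shape (batch_size, num_classes).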
            onehot_labels = labels == np.arange(num_classes)
            yield images, onehot_labels

    opt = optimizers.Momentum(step_size, mass=0.9)
    batches = synth_batches()

    print("\nInitializing parameters.")
    state = opt.init(loss.init_parameters(rng_key, *next(batches)))
    for i in range(num_steps):
        print(f'Training on batch {i}.')
        state = opt.update(loss.apply, state, *next(batches))
    trained_params = opt.get_parameters(state)