def run_svgd(key, lr, full_data=False, progress_bar=False):
    """Run the kernel-gradient (SVGD) particle loop and log metrics.

    Args:
        key: PRNG key. It is split once; one half seeds the prior sample,
            the other half is handed to ``models.Particles``.
        lr: step size passed to ``optax.sgd``.
        full_data: when True, train on ``(xx, yy)`` and evaluate on
            ``(x_test, y_test)``; otherwise train on
            ``(x_train, y_train)`` and evaluate on ``(x_val, y_val)``.
        progress_bar: show the tqdm progress bar when True.

    Returns:
        The ``models.Particles`` instance, after ``done()`` was called.
    """
    key, prior_key = random.split(key)
    init_particles = ravel(*sample_from_prior(prior_key, n_particles))
    optimizer = optax.sgd(lr)

    kernel_gradient = models.KernelGradient(
        get_target_logp=lambda batch: get_minibatch_logp(*batch),
        scaled=True)
    particles = models.Particles(key,
                                 kernel_gradient.gradient,
                                 init_particles,
                                 custom_optimizer=optimizer)

    # Pick the evaluation / training streams in one place.
    if full_data:
        test_batches = get_batches(x_test, y_test, 2 * NUM_VALS)
        train_batches = get_batches(xx, yy, NUM_STEPS + 1)
    else:
        test_batches = get_batches(x_val, y_val, 2 * NUM_VALS)
        train_batches = get_batches(x_train, y_train, NUM_STEPS + 1)

    steps = tqdm(enumerate(train_batches),
                 total=NUM_STEPS,
                 disable=not progress_bar)
    for step_idx, train_batch in steps:
        particles.step(train_batch)
        if step_idx % (NUM_STEPS // NUM_VALS) != 0:
            continue

        # Periodic evaluation on the next held-out batch.
        test_logp = get_minibatch_logp(*next(test_batches))
        accuracy = compute_test_accuracy(unravel(particles.particles)[0])
        stepdata = {
            "accuracy": accuracy,
            "test_logp": test_logp(particles.particles),
        }
        metrics.append_to_log(particles.rundata, stepdata)

    particles.done()
    return particles
def run_sgld(key, lr, full_data=False, progress_bar=False):
    """Run the SGLD particle loop and log accuracy / test log-prob.

    Args:
        key: PRNG key. Split twice before the remainder is handed to
            ``models.Particles``.
        lr: step size passed to ``utils.sgld``.
        full_data: when True, train on ``(xx, yy)`` and evaluate on
            ``(x_test, y_test)``; otherwise train on
            ``(x_train, y_train)`` and evaluate on ``(x_val, y_val)``.
        progress_bar: show the tqdm progress bar when True.

    Returns:
        The ``models.Particles`` instance, after ``done()`` was called.
    """
    key, prior_key = random.split(key)
    init_particles = ravel(*sample_from_prior(prior_key, n_particles))
    # The second subkey is not consumed here; the split is kept so that the
    # key forwarded to models.Particles stays the same.
    key, _unused_key = random.split(key)
    optimizer = utils.sgld(lr, 0)

    def energy_gradient(data, particles, aux=True):
        """Negated gradient of the minibatch log-prob; data = [batch_x, batch_y]."""
        batch_x, batch_y = data
        minibatch_logp = get_minibatch_logp(batch_x, batch_y)
        logprob, grads = value_and_grad(minibatch_logp)(particles)
        neg_grads = -grads
        if not aux:
            return neg_grads
        return neg_grads, {"logp": logprob}

    particles = models.Particles(key,
                                 energy_gradient,
                                 init_particles,
                                 custom_optimizer=optimizer)

    # Pick the evaluation / training streams in one place.
    if full_data:
        test_batches = get_batches(x_test, y_test, 2 * NUM_VALS)
        train_batches = get_batches(xx, yy, NUM_STEPS + 1)
    else:
        test_batches = get_batches(x_val, y_val, 2 * NUM_VALS)
        train_batches = get_batches(x_train, y_train, NUM_STEPS + 1)

    steps = tqdm(enumerate(train_batches),
                 total=NUM_STEPS,
                 disable=not progress_bar)
    for step_idx, train_batch in steps:
        particles.step(train_batch)
        if step_idx % (NUM_STEPS // NUM_VALS) != 0:
            continue

        # Periodic evaluation on the next held-out batch.
        test_logp = get_minibatch_logp(*next(test_batches))
        test_accuracy = compute_test_accuracy(
            unravel(particles.particles)[0])
        train_accuracy = compute_train_accuracy(
            unravel(particles.particles)[0])
        stepdata = {
            "accuracy": test_accuracy,
            "train_accuracy": train_accuracy,
            "test_logp": np.mean(test_logp(particles.particles)),
        }
        metrics.append_to_log(particles.rundata, stepdata)

    particles.done()
    return particles
def run_neural_svgd(key, plr, full_data=False, progress_bar=False):
    """Run neural SVGD: alternately train a Stein network and step particles.

    Args:
        key: PRNG key for prior sampling, the Stein network, the particle
            container, and the per-iteration particle splits.
        plr: particle step size passed to ``optax.sgd``.
        full_data: when True, train on ``(xx, yy)`` and evaluate on
            ``(x_test, y_test)``; otherwise train on
            ``(x_train, y_train)`` and evaluate on ``(x_val, y_val)``.
        progress_bar: show the tqdm progress bar when True.

    Returns:
        Tuple ``(particles, neural_grad)``: the ``models.Particles`` and
        ``models.SteinNetwork`` instances, both after ``done()``.
    """
    key, subkey = random.split(key)
    init_particles = ravel(*sample_from_prior(subkey, n_particles))
    nsvgd_opt = optax.sgd(plr)

    # Three-way split so `key` itself advances. Splitting into only
    # (key1, key2) left `key` unchanged, and the first `random.split(key)`
    # in the warmup loop then reproduced exactly key1/key2, reusing the
    # keys already given to the network and the particle container.
    key, key1, key2 = random.split(key, 3)
    neural_grad = models.SteinNetwork(
        target_dim=init_particles.shape[1],
        learning_rate=neural_lr,
        key=key1,
        aux=False,
        lambda_reg=lambda_reg)
    particles = models.Particles(key2,
                                 neural_grad.gradient,
                                 init_particles,
                                 custom_optimizer=nsvgd_opt)

    if full_data:
        test_batches = get_batches(x_test, y_test, 2 * NUM_VALS)
        train_batches = get_batches(xx, yy, NUM_STEPS + 1)
    else:
        test_batches = get_batches(x_val, y_val, 2 * NUM_VALS)
        train_batches = get_batches(x_train, y_train, NUM_STEPS + 1)

    @jit
    def v_dlogp(particles, batch):
        """Per-particle gradient of the minibatch log-prob on `batch`."""
        logp = get_minibatch_logp(*batch)
        return vmap(grad(logp))(particles)

    # Warmup: train the Stein network three times on one fixed batch.
    # NOTE(review): the warmup batch is always drawn from
    # (x_train, y_train), even when full_data=True; confirm intended.
    first_batch = next(get_batches(x_train, y_train, n_steps=100 + 1))
    for _ in range(3):
        key, subkey = random.split(key)
        split_particles = sample_tv(subkey)
        split_dlogp = [v_dlogp(x, first_batch) for x in split_particles]
        neural_grad.train(split_particles,
                          split_dlogp,
                          n_steps=30,
                          early_stopping=True)

    for i, data_batch in tqdm(enumerate(train_batches),
                              total=NUM_STEPS,
                              disable=not progress_bar):
        key, subkey = random.split(key)
        split_particles = particles.next_batch(subkey)
        split_dlogp = [v_dlogp(x, data_batch) for x in split_particles]

        # Fit the network on the current particles, then move the particles
        # using the freshly trained parameters.
        neural_grad.train(split_particles, split_dlogp, n_steps=10)
        particles.step(neural_grad.get_params())

        if i % (NUM_STEPS // NUM_VALS) == 0:
            test_logp = get_minibatch_logp(*next(test_batches))
            train_logp = get_minibatch_logp(*data_batch)
            stepdata = {
                "accuracy":
                compute_test_accuracy(unravel(particles.particles)[0]),
                "test_logp": test_logp(particles.particles),
                "training_logp": train_logp(particles.particles),
            }
            metrics.append_to_log(particles.rundata, stepdata)

    neural_grad.done()
    particles.done()
    return particles, neural_grad
Exemple #4
0
def train(key,
          particle_stepsize: float = 1e-3,
          evaluate_every: int = 10,
          n_iter: int = 400,
          n_samples: int = 100,
          results_file: str = cfg.results_path + 'svgd-bnn.csv',
          overwrite_file: bool = False,
          optimizer="sgd"):
    """
    Initialize particles; run kernel-gradient (SVGD) training; evaluate.

    Args:
        key: PRNG key used to initialize the particles and the particle
            container.
        particle_stepsize: learning rate of BNN
        evaluate_every: compute metrics every `evaluate_every` steps
        n_iter: number of train-update iterations
        n_samples: number of particles to initialize
        results_file: csv file that evaluation rows are appended to
        overwrite_file: if True, truncate `results_file` and rewrite the
            header even when the file already exists
        optimizer: "sgd" or "adam"

    Returns:
        The accuracy from the final evaluation (the "accuracy" entry of
        the last `evaluate` call).

    Raises:
        ValueError: if `optimizer` is neither "sgd" nor "adam".
    """
    # Row prefix. The trailing comma is required so the step size does not
    # fuse with the step counter in the written row.
    csv_string = f"{particle_stepsize},"

    # initialize particles and the dynamics model
    key, subkey = random.split(key)
    init_particles = vmap(bnn.init_flat_params)(random.split(
        subkey, n_samples))

    if optimizer == "sgd":
        opt = optax.sgd(particle_stepsize)
    elif optimizer == "adam":
        opt = optax.adam(particle_stepsize)
    else:
        raise ValueError("must be adam or sgd")

    # Only the third key is consumed; the 3-way split is kept so the key
    # given to Particles is the same as before.
    _, _, particles_key = random.split(key, 3)
    svgd_grad = models.KernelGradient(get_target_logp=bnn.get_minibatch_logp,
                                      scaled=False,
                                      lambda_reg=LAMBDA_REG)

    particles = models.Particles(key=particles_key,
                                 gradient=svgd_grad.gradient,
                                 init_samples=init_particles,
                                 custom_optimizer=opt)

    def evaluate(step_counter, ps):
        """Compute accuracy for particles `ps` and append one csv row."""
        stepdata = {
            "accuracy": bnn.compute_acc_from_flat(ps),
            "step_counter": step_counter,
        }
        with open(results_file, "a") as file:
            file.write(csv_string + f"{step_counter},{stepdata['accuracy']}\n")
        return stepdata

    if not os.path.isfile(results_file) or overwrite_file:
        with open(results_file, "w") as file:
            # Header lists exactly the columns that `evaluate` writes.
            file.write("particle_stepsize,step,accuracy\n")

    print("Training...")
    for step_counter in tqdm(range(n_iter), disable=on_cluster):
        train_batch = next(data.train_batches)
        particles.step(train_batch)

        if (step_counter + 1) % evaluate_every == 0:
            metrics.append_to_log(particles.rundata,
                                  evaluate(step_counter, particles.particles))

        if step_counter % data.steps_per_epoch == 0:
            print(f"Starting epoch {step_counter // data.steps_per_epoch + 1}")

    # final eval (written with step = -1; not appended to particles.rundata)
    final_eval = evaluate(-1, particles.particles)
    particles.done()

    return final_eval['accuracy']
Exemple #5
0
def train(
        key,
        meta_lr: float = DEFAULT_META_LR,
        particle_stepsize: float = 1e-3,
        evaluate_every: int = 10,
        n_iter: int = 200,
        n_samples: int = cfg.n_samples + 1,  # add 1 to account for dummy val set
        particle_steps_per_iter: int = 1,
        max_train_steps_per_iter: int = DEFAULT_MAX_TRAIN_STEPS,
        patience: int = DEFAULT_PATIENCE,
        dropout: bool = True,
        results_file: str = cfg.results_path + 'nvgd-bnn.csv',
        overwrite_file: bool = False,
        early_stopping: bool = True,
        optimizer: str = "sgd",
        hidden_sizes=[DEFAULT_LAYER_SIZE] * 3,
        use_hypernetwork: bool = False):
    """
    Initialize model; train Stein network and particles; evaluate.

    Args:
        key: PRNG key for particle initialization, the Stein network, the
            particle container, and the per-iteration particle splits.
        meta_lr: learning rate of Stein network
        particle_stepsize: learning rate of BNN
        evaluate_every: compute metrics every `evaluate_every` steps
        n_iter: number of train-update iterations (must be >= 1)
        n_samples: number of particles to initialize
        particle_steps_per_iter: num particle updates after training Stein network
        max_train_steps_per_iter: cutoff for Stein network training iteration
        patience: early stopping criterion
        dropout: use dropout during training of the Stein network
        results_file: csv file that evaluation rows are appended to
        overwrite_file: if True, truncate `results_file` and rewrite the
            header even when the file already exists
        early_stopping: forwarded to the Stein network training; also
            selects how many particles go to the training split and which
            particles are evaluated
        optimizer: "sgd" or "adam"
        hidden_sizes: hidden layer sizes of the Stein network; the particle
            dimension is appended as the output size. The default list is
            never mutated.
        use_hypernetwork: whether to use net-to-net hypernetwork to model
            the witness function

    Returns:
        Tuple `(accuracy, rundata)`: the final accuracy as a float rounded
        to 4 decimals, and `particles.rundata`.

    Raises:
        ValueError: if `optimizer` is neither "sgd" nor "adam", or if
            `n_iter` is smaller than 1.
    """
    if n_iter < 1:
        # The final evaluation needs log-probs from at least one iteration;
        # previously this case ended in a NameError after the loop.
        raise ValueError("n_iter must be at least 1")

    # initialize particles and the dynamics model
    key, subkey = random.split(key)
    init_particles = vmap(bnn.init_flat_params)(random.split(
        subkey, n_samples))

    if optimizer == "sgd":
        opt = optax.sgd(particle_stepsize)
    elif optimizer == "adam":
        opt = optax.adam(particle_stepsize)
    else:
        raise ValueError("optimizer must be sgd or adam")

    key, subkey1, subkey2 = random.split(key, 3)
    particle_dim = init_particles.shape[1]
    neural_grad = models.SteinNetwork(target_dim=particle_dim,
                                      learning_rate=meta_lr,
                                      key=subkey1,
                                      sizes=list(hidden_sizes) +
                                      [particle_dim],
                                      aux=False,
                                      use_hutchinson=True,
                                      lambda_reg=LAMBDA_REG,
                                      patience=patience,
                                      dropout=dropout,
                                      particle_unravel=nets.cnn_unravel,
                                      hypernet=use_hypernetwork)

    particles = models.Particles(key=subkey2,
                                 gradient=neural_grad.gradient,
                                 init_samples=init_particles,
                                 custom_optimizer=opt)

    # Value and gradient of the minibatch log-prob, mapped over particles
    # (axis 0 of the first argument); the batch is shared.
    minibatch_vdlogp = vmap(value_and_grad(bnn.minibatch_logp), (0, None))

    @jit
    def split_vdlogp(split_particles, train_batch):
        """returns tuple (split_logp, split_dlogp)"""
        train_out, val_out = [
            minibatch_vdlogp(x, train_batch) for x in split_particles
        ]
        return tuple(zip(train_out, val_out))

    def step(split_particles, split_dlogp):
        """one iteration of the particle trajectory simulation"""
        neural_grad.train(split_particles=split_particles,
                          split_dlogp=split_dlogp,
                          n_steps=max_train_steps_per_iter,
                          early_stopping=early_stopping)
        for _ in range(particle_steps_per_iter):
            particles.step(neural_grad.get_params())

    def evaluate(step_counter, ps, logp):
        """Compute accuracy of `ps` and mean of `logp`; append one csv row."""
        ll = logp.mean()
        stepdata = {
            "accuracy": bnn.compute_acc_from_flat(ps),
            "step_counter": step_counter,
            "loglikelihood": ll,
        }
        with open(results_file, "a") as file:
            file.write(f"{step_counter},{stepdata['accuracy']},{ll}\n")
        return stepdata

    if not os.path.isfile(results_file) or overwrite_file:
        with open(results_file, "w") as file:
            file.write("step_counter,accuracy,loglikelihood\n")

    # Loop-invariant size of the training split.
    n_train_particles = 3 * n_samples // 4 if early_stopping else n_samples - 1

    print("Training...")
    for step_counter in tqdm(range(n_iter), disable=on_cluster):
        key, subkey = random.split(key)
        train_batch = next(data.train_batches)
        # Use the fresh subkey: `key` is split again next iteration, so
        # passing it here would reuse the same PRNG key twice.
        split_particles = particles.next_batch(
            subkey, n_train_particles=n_train_particles)
        split_logp, split_dlogp = split_vdlogp(split_particles, train_batch)
        step(split_particles, split_dlogp)

        if (step_counter + 1) % evaluate_every == 0:
            eval_ps = particles.particles if early_stopping else split_particles[
                0]
            metrics.append_to_log(
                particles.rundata,
                evaluate(step_counter, eval_ps, split_logp[0]))

        if step_counter % data.steps_per_epoch == 0:
            print(f"Starting epoch {step_counter // data.steps_per_epoch + 1}")

    neural_grad.done()
    particles.done()

    # Final evaluation: pass the training-split log-probs from the last
    # iteration (computed before that iteration's particle update), as the
    # in-loop evaluation does. Passing the particles themselves here made
    # the logged "loglikelihood" the mean of the particle values.
    final_eval = evaluate(-1, particles.particles, split_logp[0])
    return round(float(final_eval['accuracy']), 4), particles.rundata