def main(unused_argv):
  # Build data pipelines.
  print('Loading data.')
  x_train, y_train, x_test, y_test = datasets.mnist(permute_train=True)

  # Build the network
  init_fn, f = stax.serial(
      stax.Dense(2048),
      stax.Tanh,
      stax.Dense(10))

  key = random.PRNGKey(0)
  _, params = init_fn(key, (-1, 784))

  # Linearize the network about its initial parameters.
  f_lin = linearize(f, params)
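  # A sketch of what the linearization provides (assuming neural_tangents'
  # linearize): f_lin has the same signature as f and evaluates the
  # first-order Taylor expansion of f around `params`,
  #   f_lin(p, x) ≈ f(params, x) + J_f(params, x) @ (p - params),
  # so training f_lin follows the network's tangent-kernel dynamics.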

  # Create and initialize an optimizer for both f and f_lin.
  opt_init, opt_apply, get_params = optimizers.momentum(FLAGS.learning_rate,
                                                        0.9)
  opt_apply = jit(opt_apply)

  state = opt_init(params)
  state_lin = opt_init(params)

  # Create a cross-entropy loss function.
  loss = lambda fx, y_hat: -np.mean(stax.logsoftmax(fx) * y_hat)
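  # Note: y_hat is assumed to be one-hot; np.mean averages over both the
  # batch and class axes, so this equals the usual cross-entropy up to a
  # constant factor of 1/num_classes.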

  # Specialize the loss function to compute gradients for both linearized and
  # full networks.
  grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))
  grad_loss_lin = jit(grad(lambda params, x, y: loss(f_lin(params, x), y)))

  # Train the network.
  print('Training.')
  print('Epoch\tLoss\tLinearized Loss')
  print('------------------------------------------')

  epoch = 0
  steps_per_epoch = 50000 // FLAGS.batch_size

  for i, (x, y) in enumerate(datasets.minibatch(
      x_train, y_train, FLAGS.batch_size, FLAGS.train_epochs)):

    params = get_params(state)
    state = opt_apply(i, grad_loss(params, x, y), state)

    params_lin = get_params(state_lin)
    state_lin = opt_apply(i, grad_loss_lin(params_lin, x, y), state_lin)

    if i % steps_per_epoch == 0:
      print('{}\t{:.4f}\t{:.4f}'.format(
          epoch, loss(f(params, x), y), loss(f_lin(params_lin, x), y)))
      epoch += 1

  # Print out summary data comparing the linear / nonlinear model.
  x, y = x_train[:10000], y_train[:10000]
  util.print_summary('train', y, f(params, x), f_lin(params_lin, x), loss)
  util.print_summary(
      'test', y_test, f(params, x_test), f_lin(params_lin, x_test), loss)
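
# A minimal sketch of the command-line plumbing this example assumes (flag
# names inferred from the FLAGS.* uses above; defaults are illustrative,
# not the original script's values):
#
#   from absl import app, flags
#
#   flags.DEFINE_float('learning_rate', 1.0, 'SGD-with-momentum step size.')
#   flags.DEFINE_integer('batch_size', 128, 'Minibatch size.')
#   flags.DEFINE_integer('train_epochs', 10, 'Passes over the training set.')
#   FLAGS = flags.FLAGS
#
#   if __name__ == '__main__':
#       app.run(main)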
Example #2
def main():
    USE_MSE = False
    EPOCHS = 200

    optimizer = SGDOptimizer(lr=0.1, weight_decay=1e-4)

    if USE_MSE:

        @jax.jit
        def criterion(logits, targets):
            return jnp.mean(jnp.sum((logits - targets) ** 2, axis=1))

    else:

        @jax.jit
        def criterion(logits, targets):
            return -jnp.mean(jnp.sum(log_softmax(logits) * targets, axis=1), axis=0)

    init_fn, apply_fn, _ = nt.stax.serial(
        nt.stax.Dense(512, 1.0, 0.05),
        nt.stax.Erf(),
        nt.stax.Dense(512, 1.0, 0.05),
        nt.stax.Erf(),
        nt.stax.Dense(10, 1.0, 0.05),
    )

    key = random.PRNGKey(0)
    _, params = init_fn(key, (None, 784))

    # Load the dataset (1024 training / 128 test examples)
    x_train, y_train, x_test, y_test = datasets.get_dataset("fashion_mnist", 1024, 128)

    for e in range(EPOCHS):
        key, subkey = random.split(key)
        train_loader = datasets.minibatch(
            x_train, y_train, batch_size=128, train_epochs=1, key=subkey
        )
        val_loader = datasets.minibatch(
            x_test, y_test, batch_size=128, train_epochs=1, key=subkey
        )

        params, optimizer, _, _ = train(
            train_loader,
            apply_fn,
            params,
            criterion,
            optimizer,
            epoch=e,
            num_images=len(x_train),
            batch_size=128,
        )

        validate(
            val_loader,
            apply_fn,
            params,
            criterion,
            epoch=e,
            batch_size=128,
            num_images=len(x_test),
        )

    with open("fashion-mnist-mlp.pkl", "wb+") as f:
        pickle.dump(params, f)
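
    # To reuse the trained parameters later, a minimal sketch (same path as
    # saved above; apply_fn and x_test as defined in this example):
    #
    #   with open("fashion-mnist-mlp.pkl", "rb") as f:
    #       params = pickle.load(f)
    #   logits = apply_fn(params, x_test)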
Example #3
def main(unused_argv):
    # print(f'Available GPU memory: {util.get_gpu_memory()}')
    # Load and normalize data
    print('Loading data...')
    x_train, y_train, x_test, y_test = datasets.get_dataset('mnist',
                                                            n_train=60000,
                                                            n_test=10000,
                                                            permute_train=True)
    # print(f'Available GPU memory: {util.get_gpu_memory()}')

    # Reshape MNIST data to 28x28x1 images
    x_train = np.asarray(x_train.reshape(-1, 28, 28, 1))
    x_test = np.asarray(x_test.reshape(-1, 28, 28, 1))
    print('Data loaded and reshaped')
    # print(f'Available GPU memory: {util.get_gpu_memory()}')

    # Set random seed
    key = random.PRNGKey(0)

    # # Add random translation to images
    # x_train = util.add_translation(x_train, FLAGS.max_pixel)
    # x_test = util.add_translation(x_test, FLAGS.max_pixel)
    # print(f'Random translation by up to {FLAGS.max_pixel} pixels added')

    # # Add random translations with padding
    # x_train = util.add_padded_translation(x_train, 10)
    # x_test = util.add_padded_translation(x_test, 10)
    # print(f'Random translations with additional padding up to 10 pixels added')

    # Build the LeNet network with NTK parameterization
    init_fn, f, kernel_fn = util.build_le_net(FLAGS.network_width)
    print(f'Network of width x{FLAGS.network_width} built.')

    # # Construct the kernel function
    # kernel_fn = nt.batch(kernel_fn, device_count=-1, batch_size=FLAGS.batch_size_kernel)
    # print('Kernel constructed')
    # print(f'Available GPU memory: {util.get_gpu_memory()}')

    # Compute random initial parameters
    _, params = init_fn(key, (-1, 28, 28, 1))
    params_lin = params

    print('Initial parameters constructed')
    # print(f'Available GPU memory: {util.get_gpu_memory()}')

    # # Save initial parameters
    # with open('init_params.npy', 'wb') as file:
    #     np.save(file, params)

    # Linearize the network about its initial parameters.
    # Use jit for faster GPU computation (only feasible for width < 25)
    f_lin = nt.linearize(f, params)
    if FLAGS.network_width <= 10:
        f_jit = jit(f)
        f_lin_jit = jit(f_lin)
    else:
        f_jit = f
        f_lin_jit = f_lin

    # Create a callable function for the dynamic learning rate.
    # It starts at learning_rate and is divided by 10 every learning_decline epochs.
    dynamic_learning_rate = lambda iteration_step: FLAGS.learning_rate / 10**(
        (iteration_step //
         (x_train.shape[0] // FLAGS.batch_size)) // FLAGS.learning_decline)
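    # For instance (hypothetical flag values): with learning_rate=0.1 and
    # learning_decline=100, epochs 0-99 train at 0.1, epochs 100-199 at 0.01,
    # epochs 200-299 at 0.001, and so on.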

    # Create and initialize an optimizer for both f and f_lin.
    # Use momentum with coefficient 0.9 and jit
    opt_init, opt_apply, get_params = optimizers.momentum(
        dynamic_learning_rate, 0.9)
    opt_apply = jit(opt_apply)

    # Compute the initial states
    state = opt_init(params)
    state_lin = opt_init(params)

    # Define the accuracy function
    accuracy = lambda fx, y_hat: np.mean(
        np.argmax(fx, axis=1) == np.argmax(y_hat, axis=1))

    # Define mean square error loss function
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2)
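    # Note: np.mean averages over both the batch and class axes, so with the
    # 0.5 factor the gradient of this loss w.r.t. fx is (fx - y_hat) / fx.size.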

    # # Create a cross-entropy loss function.
    # loss = lambda fx, y_hat: -np.mean(logsoftmax(fx) * y_hat)

    # Specialize the loss function to compute gradients for both linearized and
    # full networks.
    grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))
    grad_loss_lin = jit(grad(lambda params, x, y: loss(f_lin(params, x), y)))

    # Train the network.
    print(
        f'Training with the learning rate divided by 10 every '
        f'{FLAGS.learning_decline} epochs...'
    )
    print(
        'Epoch\tTime\tTrain Acc.\tLin. Train Acc.\tTrain Loss\tLin. Train Loss\tTest Acc.\tLin. Test Acc.'
    )
    print(
        '----------------------------------------------------------------------------------------------------------'
    )

    # Initialize training
    epoch = 0
    steps_per_epoch = x_train.shape[0] // FLAGS.batch_size

    # Record start times (overall and per 100-epoch block)
    start = time.time()
    start_epoch = time.time()

    for i, (x, y) in enumerate(
            datasets.minibatch(x_train, y_train, FLAGS.batch_size,
                               FLAGS.train_epochs)):

        # Update the parameters
        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x, y), state)

        params_lin = get_params(state_lin)
        state_lin = opt_apply(i, grad_loss_lin(params_lin, x, y), state_lin)

        # Print information every 100 epochs
        if (i + 1) % (steps_per_epoch * 100) == 0:
            time_point = time.time() - start_epoch

            # Update epoch
            epoch += 100

            # Accuracy in batches
            f_x = util.output_in_batches(x_train, params, f_jit,
                                         FLAGS.batch_count_accuracy)
            f_x_test = util.output_in_batches(x_test, params, f_jit,
                                              FLAGS.batch_count_accuracy)
            f_x_lin = util.output_in_batches(x_train, params_lin, f_lin_jit,
                                             FLAGS.batch_count_accuracy)
            f_x_lin_test = util.output_in_batches(x_test, params_lin,
                                                  f_lin_jit,
                                                  FLAGS.batch_count_accuracy)
            # time_point = time.time() - start_epoch

            # Print information about past 100 epochs
            print(
                '{}\t{:.3f}\t{:.4f}\t\t{:.4f}\t\t{:.4f}\t{:.4f}\t\t{:.4f}\t\t{:.4f}'
                .format(epoch, time_point,
                        accuracy(f_x, y_train) * 100,
                        accuracy(f_x_lin, y_train) * 100, loss(f_x, y_train),
                        loss(f_x_lin, y_train),
                        accuracy(f_x_test, y_test) * 100,
                        accuracy(f_x_lin_test, y_test) * 100))

            # # Save params if epoch is multiple of learning decline or multiple of fixed value
            # if epoch % FLAGS.learning_decline == 0:
            #     filename = FLAGS.default_path + f'LinLeNetx{FLAGS.network_width}_pmod_{epoch}_{FLAGS.learning_decline}.npy'
            #     with open(filename, 'wb') as file:
            #         np.save(file, params)
            #     filename_lin = FLAGS.default_path + f'LinLeNetx{FLAGS.network_width}_pmod_{epoch}_{FLAGS.learning_decline}_lin.npy'
            #     with open(filename_lin, 'wb') as file_lin:
            #         np.save(file_lin, params_lin)

            # Reset timer
            start_epoch = time.time()

    duration = time.time() - start
    print(
        '----------------------------------------------------------------------------------------------------------'
    )
    print(f'Training complete in {duration} seconds.')

    # # Save final params in file
    # filename_final = FLAGS.default_path + f'LinLeNetx{FLAGS.network_width}_final_pmod_{FLAGS.train_epochs}_{FLAGS.learning_decline}.npy '
    # with open(filename_final, 'wb') as final:
    #     np.save(final, params)
    # filename_final_lin = FLAGS.default_path + f'LinLeNetx{FLAGS.network_width}_final_pmod_{FLAGS.train_epochs}_{FLAGS.learning_decline}_lin.npy'
    # with open(filename_final_lin, 'wb') as final_lin:
    #     np.save(final_lin, params_lin)

    # Compute output in batches
    f_x = util.output_in_batches(x_train, params, f_jit,
                                 FLAGS.batch_count_accuracy)
    f_x_lin = util.output_in_batches(x_train, params_lin, f_lin_jit,
                                     FLAGS.batch_count_accuracy)

    f_x_test = util.output_in_batches(x_test, params, f_jit,
                                      FLAGS.batch_count_accuracy)
    f_x_lin_test = util.output_in_batches(x_test, params_lin, f_lin_jit,
                                          FLAGS.batch_count_accuracy)

    # Print out summary data comparing the linear / nonlinear model.
    util.print_summary('train', y_train, f_x, f_x_lin, loss)
    util.print_summary('test', y_test, f_x_test, f_x_lin_test, loss)
Example #4
def train_unbalanced_descent(D, dataQ0, dataP, wP, opt):
    n_samples, n_features = dataQ0.shape
    device = dataQ0.device

    # Lagrange multiplier for Augmented Lagrangian
    lambda_aug = torch.tensor([opt.lambda_aug_init],
                              requires_grad=True,
                              device=device)

    # MMD distance
    mmd = MMD_RFF(num_features=n_features, num_outputs=300).to(device)

    # Train
    print('Start training')

    if opt.plot_online:
        fig, ax = plt.subplots()
        ax.set_xlim((-1.1, 1.1))
        ax.set_ylim((-1.1, 1.1))
        scat = ax.scatter([], [], facecolor='r')

    # Save stuff
    collQ, coll_mmd = [], []
    birth_total, death_total = 0, 0

    dataQ = dataQ0.clone()
    for t in range(opt.T + 1):
        tic = time.time()

        # Snapshot of current state
        with torch.no_grad():
            mmd_PQ = mmd(dataP,
                         dataQ,
                         weights_X=wP if wP is not None else None)
        coll_mmd.append(mmd_PQ)
        collQ.append(dataQ.detach().cpu().numpy())  # snapshot of current state

        # (1) Update D network
        optimizerD = torch.optim.Adam(D.parameters(),
                                      lr=opt.lrD,
                                      weight_decay=opt.wdecay,
                                      amsgrad=True)
        D.train()
        for i in range(opt.n_c_startup if t == 0 else opt.n_c):
            optimizerD.zero_grad()

            x_p, w_p = minibatch((dataP, wP), opt.batchSizeD)
            x_q = minibatch(dataQ, opt.batchSizeD).requires_grad_(True)

            loss, Ep_f, Eq_f, normgrad_f2_q = D_forward_weights(
                D, x_p, w_p, x_q, 1.0, lambda_aug, opt.alpha, opt.rho)
            loss.backward()
            optimizerD.step()

            manual_sgd_(lambda_aug, opt.rho)

        tocD = time.time() - tic

        # (2) Update Q distribution (with birth/death)
        D.eval()

        # compute initial m_f
        with torch.no_grad():
            x_q = minibatch(dataQ)
            m_f = D(x_q).mean()

        # Update particles positions, and compute birth-death scores
        new_x_q, b_j = [], []
        for x_q, in get_loader(dataQ, batch_size=opt.batchSizeQ):
            x_q = x_q.detach().requires_grad_(True)
            sum_f_q = D(x_q).sum()
            grad_x_q = grad(outputs=sum_f_q, inputs=x_q, create_graph=True)[0]

            with torch.no_grad():
                new_x_q.append(x_q + opt.lrQ * grad_x_q)
                f_q_new = D(new_x_q[-1])

                # birth-death score
                m_f = m_f + (1 / n_samples) * (f_q_new.sum() - sum_f_q)

                b_j.append(f_q_new.view(-1) - m_f)

        new_x_q = torch.cat(new_x_q)
        b_j = torch.cat(b_j)
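
        # Birth-death rule used below: a particle with score b_j > 0 is kept
        # and additionally cloned with probability 1 - exp(-alpha * tau * b_j);
        # a particle with b_j <= 0 is removed with probability
        # 1 - exp(-alpha * tau * |b_j|).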

        # Birth
        idx_alive = (b_j > 0).nonzero().view(-1)
        p_j = 1 - torch.exp(-opt.alpha * opt.tau * b_j[idx_alive])
        idx_birth = idx_alive[p_j > torch.rand_like(p_j)]

        # Death
        idx_neg = (b_j <= 0).nonzero().view(-1)
        p_j = 1 - torch.exp(-opt.alpha * opt.tau * torch.abs(b_j[idx_neg]))
        ix_die = p_j > torch.rand_like(p_j)  # Particles that die
        idx_dead = idx_neg[ix_die]
        idx_notdead = idx_neg[~ix_die]  # Particles that don't die

        birth_total += len(idx_birth)
        death_total += len(idx_dead)

        if not opt.keep_order:
            new_x_q.data = new_x_q.data[torch.cat(
                (idx_alive, idx_notdead, idx_birth))]

            # Resize population
            if opt.balance:
                n_l = new_x_q.shape[0]

                if n_l < n_samples:  # Randomly double particles
                    r_idx = torch.randint(n_l, (n_samples - n_l, ))
                    new_x_q = torch.cat((new_x_q, new_x_q[r_idx]))

                if n_l > n_samples:  # Randomly kill particles
                    r_idx = torch.randperm(
                        n_l)[:n_samples]  # Particles that should be kept
                    new_x_q = new_x_q[r_idx]

        else:
            # Replace dead particles with samples from the cloned ones (if any), otherwise from the surviving ones
            if len(idx_birth) > 0:
                r_idx = idx_birth[torch.randint(len(idx_birth),
                                                (len(idx_dead), ))]
            else:
                r_idx = idx_alive[torch.randint(len(idx_alive),
                                                (len(idx_dead), ))]
            new_x_q.data[idx_dead] = new_x_q.data[r_idx]

        dataQ = new_x_q.data

        # (3) print some stuff
        if t % opt.log_every == 0:
            x_p, w_p = minibatch((dataP, wP))
            x_q = minibatch(dataQ)
            loss, Ep_f, Eq_f, normgrad_f2_q = D_forward_weights(
                D, x_p, w_p, x_q, 1.0, lambda_aug, opt.alpha, opt.rho)
            with torch.no_grad():
                SobDist_lasti = Ep_f.item() - Eq_f.item()
                mmd_dist = mmd(dataP,
                               dataQ,
                               weights_X=wP if wP is not None else None)

            print('[{:5d}/{}] SobolevDist={:.4f}\t mmd={:.5f} births={} deaths={} Eq_normgrad_f2[stepQ]={:.3f} Ep_f={:.2f} Eq_f={:.2f} lambda_aug={:.4f}'.\
                format(t, opt.T, SobDist_lasti, mmd_dist, birth_total, death_total, normgrad_f2_q.mean().item(), Ep_f.item(), Eq_f.item(), lambda_aug.item()))

            if opt.plot_online:
                scat.set_offsets(dataQ.detach().cpu().numpy())
                plt.pause(0.01)

    return dataQ, collQ, coll_mmd
Example #5
def train_weighted_descent(D, dataQ0, dataP, wP, opt):
    n_samples, n_features = dataQ0.shape
    device = dataQ0.device

    # Lagrange multiplier for Augmented Lagrangian
    lambda_aug = torch.tensor([opt.lambda_aug_init],
                              requires_grad=True,
                              device=device)

    # MMD distance
    mmd = MMD_RFF(num_features=n_features, num_outputs=300).to(device)

    # Train
    print('Start training')

    if opt.plot_online:
        fig, ax = plt.subplots()
        ax.set_xlim((-1.1, 1.1))
        ax.set_ylim((-1.1, 1.1))
        scat = ax.scatter([], [], facecolor='r')

    # Save stuff
    wQ = torch.ones((len(dataQ0), 1), device=device)
    collQ, collW, coll_mmd = [], [], []

    dataQ = dataQ0.clone()
    for t in range(opt.T + 1):
        tic = time.time()

        # Snapshot of current state
        with torch.no_grad():
            mmd_PQ = mmd(dataP,
                         dataQ,
                         weights_X=wP if wP is not None else None,
                         weights_Y=wQ)

        coll_mmd.append(mmd_PQ)
        collQ.append(dataQ.detach().cpu().numpy())  # snapshot of current state
        collW.append(
            wQ.view(-1).detach().cpu().numpy())  # snapshot of current weights

        # (1) Update D network
        optimizerD = torch.optim.Adam(D.parameters(),
                                      lr=opt.lrD,
                                      weight_decay=opt.wdecay,
                                      amsgrad=True)
        D.train()
        for i in range(opt.n_c_startup if t == 0 else opt.n_c):
            optimizerD.zero_grad()

            x_p, w_p = minibatch((dataP, wP), opt.batchSizeD)
            x_q, w_q = minibatch((dataQ, wQ), opt.batchSizeD)

            loss, Ep_f, Eq_f, normgrad_f2_q = D_forward_weights(
                D, x_p, w_p, x_q, w_q, lambda_aug, opt.alpha, opt.rho)
            loss.backward()
            optimizerD.step()

            manual_sgd_(lambda_aug, opt.rho)

        tocD = time.time() - tic

        # (2) Update Q distribution (with birth/death)
        D.eval()
        with torch.no_grad():
            x_q, w_q = minibatch((dataQ, wQ))
            f_q = D(x_q)
            m_f = (w_q * f_q).mean()

        new_x_q, log_wQ = [], []
        for x_q, w_q in get_loader((dataQ, wQ), batch_size=opt.batchSizeQ):
            x_q = x_q.detach().requires_grad_(True)
            sum_f_q = D(x_q).sum()
            grad_x_q = grad(outputs=sum_f_q, inputs=x_q, create_graph=True)[0]

            # Update particles
            with torch.no_grad():
                # Move particles
                x_q.data += opt.lrQ * grad_x_q
                f_q = D(x_q)
                dw_q = f_q - m_f

                log_wQ.append((w_q / n_samples).log() + opt.tau * dw_q)
                new_x_q.append(x_q)

        # Update weights and dataQ
        wQ = F.softmax(torch.cat(log_wQ), dim=0) * n_samples
        dataQ = torch.cat(new_x_q)
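        # The weights are updated in log space and renormalized with a softmax
        # so that wQ stays positive and sums to n_samples; (wQ * f_q).mean()
        # is then a proper weighted average of f over the particles.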

        # (3) print some stuff
        if t % opt.log_every == 0:
            x_p, w_p = minibatch((dataP, wP))
            x_q, w_q = minibatch((dataQ, wQ))

            loss, Ep_f, Eq_f, normgrad_f2_q = D_forward_weights(
                D, x_p, w_p, x_q, w_q, lambda_aug, opt.alpha, opt.rho)
            with torch.no_grad():
                SobDist_lasti = Ep_f.item() - Eq_f.item()
                mmd_dist = mmd(dataP,
                               dataQ,
                               weights_X=wP if wP is not None else None,
                               weights_Y=wQ)

            print('[{:5d}/{}] SobolevDist={:.4f}\t mmd={:.5f} Eq_normgrad_f2[stepQ]={:.3f} Ep_f={:.2f} Eq_f={:.2f} lambda_aug={:.4f}'.\
                format(t, opt.T, SobDist_lasti, mmd_dist, normgrad_f2_q.mean().item(), Ep_f.item(), Eq_f.item(), lambda_aug.item()))

            if opt.plot_online:
                scat.set_offsets(dataQ.detach().cpu().numpy())
                rgba_colors = np.zeros((wQ.shape[0], 4))
                rgba_colors[:, 0] = 1.0
                rgba_colors[:, 3] = wQ.view(
                    -1).detach().cpu().numpy() / wQ.max().item()
                scat.set_color(rgba_colors)
                plt.pause(0.01)

    return dataQ, wQ, collQ, collW, coll_mmd
@jax.jit
def criterion(logits, targets):
    return jnp.mean(jnp.sum((logits - targets) ** 2, axis=1))


def model(x):
    return apply_fn(params, x)

print("Uncorrupted Test Accuracy")
print("="*80)
x_train, y_train, x_test, y_test = datasets.get_dataset("fashion_mnist", 1024, 128)

validate(
    val_loader=datasets.minibatch(
        x_test, y_test, batch_size=128, train_epochs=1, key=None
    ),
    model=apply_fn,
    params=params,
    criterion=criterion,
    epoch=20,
    batch_size=128,
    num_images=len(x_test),
)

print("Corrupted Test Accuracy")
print("="*80)
x_train, y_train, x_test, y_test = datasets.get_dataset("fashion_mnist", 1024, 128, perturb=True)

validate(
    val_loader=datasets.minibatch(