Example #1
def main(unused_argv):
    # Build data pipelines.
    print('Loading data.')
    x_train, y_train, x_test, y_test = datasets.get_dataset('mnist',
                                                            permute_train=True)

    # Build the network
    init_fn, f, _ = stax.serial(stax.Dense(2048, 1., 0.05), stax.Erf(),
                                stax.Dense(10, 1., 0.05))

    key = random.PRNGKey(0)
    _, params = init_fn(key, (-1, 784))

    # Linearize the network about its initial parameters.
    f_lin = nt.linearize(f, params)

    # Create and initialize an optimizer for both f and f_lin.
    opt_init, opt_apply, get_params = optimizers.momentum(
        FLAGS.learning_rate, 0.9)
    opt_apply = jit(opt_apply)

    state = opt_init(params)
    state_lin = opt_init(params)

    # Create a cross-entropy loss function.
    loss = lambda fx, y_hat: -np.mean(logsoftmax(fx) * y_hat)

    # Specialize the loss function to compute gradients for both linearized and
    # full networks.
    grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))
    grad_loss_lin = jit(grad(lambda params, x, y: loss(f_lin(params, x), y)))

    # Train the network.
    print('Training.')
    print('Epoch\tLoss\tLinearized Loss')
    print('------------------------------------------')

    epoch = 0
    steps_per_epoch = 50000 // FLAGS.batch_size

    for i, (x, y) in enumerate(
            datasets.minibatch(x_train, y_train, FLAGS.batch_size,
                               FLAGS.train_epochs)):

        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x, y), state)

        params_lin = get_params(state_lin)
        state_lin = opt_apply(i, grad_loss_lin(params_lin, x, y), state_lin)

        if i % steps_per_epoch == 0:
            print('{}\t{:.4f}\t{:.4f}'.format(epoch, loss(f(params, x), y),
                                              loss(f_lin(params_lin, x), y)))
            epoch += 1

    # Print out summary data comparing the linear / nonlinear model.
    x, y = x_train[:10000], y_train[:10000]
    util.print_summary('train', y, f(params, x), f_lin(params_lin, x), loss)
    util.print_summary('test', y_test, f(params, x_test),
                       f_lin(params_lin, x_test), loss)
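
This snippet mirrors the neural_tangents example scripts and omits its imports. A rough sketch of what it appears to assume is below; exact module paths differ across JAX releases (older ones exposed optimizers and logsoftmax under jax.experimental), and datasets/util are the helper modules shipped with those examples rather than a public API.

from absl import flags
import jax.numpy as np
from jax import grad, jit, random
from jax.example_libraries import optimizers
from jax.nn import log_softmax as logsoftmax
import neural_tangents as nt
from neural_tangents import stax
# `datasets` and `util` come from the neural_tangents examples directory.

FLAGS = flags.FLAGS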
Example #2
def softmax_cross_entropy_with_logits_l2_reg(params, f, x, targets, masks=None,
                                             L2_REG_COEFF=0.0, key=None):
    """Cross-entropy loss plus weighted L2 regularization.

    Args:
        params: parameters in stax format.
        f: a function that maps network parameters and network inputs to
            network outputs.
        x: network inputs.
        targets: the target outputs.
        masks: the sparsity-inducing masks (optional).
        L2_REG_COEFF: L2 regularization coefficient.
        key: optional PRNG key passed to f.
    Returns:
        The cross-entropy loss plus the weighted L2 regularization term.
    """
    
    if masks is not None:
        masked_params = get_sparse_params_filtered_by_masks(params, masks)
    else:
        masked_params = params
    
    if key is not None:
        dense_outputs = f(masked_params, x, rng=key)
    else:
        dense_outputs = f(masked_params, x)
        
    preds = logsoftmax(dense_outputs)
    
    params_norm_squared = stax_params_l2_square(masked_params)
        
    return -np.mean(np.sum(preds * targets, axis=1)) + L2_REG_COEFF * params_norm_squared
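
A minimal usage sketch for the loss above, assuming a plain stax MLP and one-hot targets; the model below and its names are illustrative, and the helpers the loss calls (get_sparse_params_filtered_by_masks, stax_params_l2_square) still have to come from the surrounding project.

import jax.numpy as np
from jax import grad, random
from jax.example_libraries import stax

init_fn, apply_fn = stax.serial(stax.Dense(64), stax.Relu, stax.Dense(10))
_, params = init_fn(random.PRNGKey(0), (-1, 784))

x = random.normal(random.PRNGKey(1), (32, 784))
targets = np.eye(10)[random.randint(random.PRNGKey(2), (32,), 0, 10)]

# Gradient of cross entropy + L2 with respect to the (unmasked) parameters.
grads = grad(softmax_cross_entropy_with_logits_l2_reg)(
    params, apply_fn, x, targets, L2_REG_COEFF=1e-4)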
Example #3
def sparse_softmax_cross_entropy_with_logits(*, labels, logits):
    """
    https://www.tensorflow.org/api_docs/python/tf/nn/sparse_softmax_cross_entropy_with_logits
    """
    assert labels.shape == logits.shape[:-1]
    assert_array(labels, dtypes=(jp.int32,))
    assert_array(logits, dtypes=(jp.float32,))
    log_probs = logsoftmax(logits, axis=-1)
    chosen_log_probs = jp.squeeze(
        jp.take_along_axis(log_probs, jp.expand_dims(labels, axis=-1), axis=-1), axis=-1
    )
    return -chosen_log_probs
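
A small usage sketch, assuming jp is jax.numpy and that the assert_array helper used above is available: integer class labels select one log-probability per example, mirroring the TensorFlow API linked in the docstring.

import jax.numpy as jp

logits = jp.array([[2.0, 0.5, -1.0],
                   [0.1, 0.2, 3.0]], dtype=jp.float32)
labels = jp.array([0, 2], dtype=jp.int32)

# One value per example: -log softmax(logits)[i, labels[i]]
losses = sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)
# losses.shape == (2,)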
Example #4
def __init__(self, probs=None, logits=None):
    if logits is None:
        assert probs is not None
        self.log_probs = jp.log(probs)
        self._logits = self.log_probs
    else:
        assert probs is None
        self.log_probs = logsoftmax(logits, axis=-1)
        self._logits = logits
    self._probs = jp.exp(self.log_probs)
    self.batch_shape = self.log_probs.shape[:-1]
    self.event_shape = ()
    # Exclusive cumulative probabilities: entry k holds P(X < k).
    self._cdf_probs = jp.cumsum(self._probs, axis=-1) - self._probs
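
A quick consistency check of the two branches above (the logsoftmax import path is an assumption; the snippet only shows the method body): when probs already sums to one along the last axis, log-softmax of its log equals its log, so both construction paths produce the same log_probs.

import jax.numpy as jp
from jax.nn import log_softmax as logsoftmax

probs = jp.array([0.2, 0.5, 0.3])
logits = jp.log(probs)
# log_softmax(log p) == log p whenever p is already normalized.
print(jp.allclose(logsoftmax(logits, axis=-1), jp.log(probs)))  # True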
Example #5
def split_mutation_predictor_output(output):
    return (stax.logsoftmax(output[:, :, -1]),
            stax.logsoftmax(output[:, :, :-1]))
Example #7
def xentr(params, images_and_labels):
    images, labels = images_and_labels
    return -np.mean(stax.logsoftmax(f(params, images)) * labels)
def computation(params, inputs, targets):
    logits = predict(params, inputs)
    preds = stax.logsoftmax(logits)
    return -np.mean(np.sum(preds * targets, axis=1))
Example #9
def cross_entropy_loss(params, x_img, y_lbl):
    return -np.mean(stax.logsoftmax(predict(params, x_img)) * y_lbl)
Example #10
def loss(params, predict, batch):
  inputs, targets = batch
  logits = predict(params, inputs)
  logits = stax.logsoftmax(logits)
  return -np.mean(np.sum(logits * targets, axis=1))
Example #11
def loss(params, batch):
    inputs, targets = batch
    logits = f(params, inputs)
    outputs = logsoftmax(logits)
    return -np.sum(outputs * targets) / targets.shape[0]

    net_init, f = mlp(n_output=args.n_way,
                      n_hidden_layer=args.n_hidden_layer,
                      n_hidden_unit=args.n_hidden_unit,
                      bias_coef=args.bias_coef,
                      activation=args.activation,
                      norm=args.norm)
    _, params = net_init(rng=random.PRNGKey(42), input_shape=(-1, 2))

else:
    raise ValueError

# loss functions
if args.dataset == 'sinusoid':
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat) ** 2)
elif args.dataset in ['omniglot', 'circle']:
    loss = lambda fx, targets: -np.sum(logsoftmax(fx) * targets) / targets.shape[0]
    acc = lambda fx, targets: np.mean(np.argmax(logsoftmax(fx), axis=-1) == np.argmax(targets, axis=-1))
    param_acc = jit(lambda p, x, y: acc(f(p, x), y))
else:
    raise ValueError

grad_loss = jit(grad(lambda p, x, y: loss(f(p, x), y)))
param_loss = jit(lambda p, x, y: loss(f(p, x), y))

# optimizers    #TODO: separate optimizers for nonlinear and linear?
outer_opt_init, outer_opt_update, outer_get_params = select_opt(args.outer_opt_alg, args.outer_step_size)()
inner_opt_init, inner_opt_update, inner_get_params = select_opt(args.inner_opt_alg, args.inner_step_size)()

# consistent task for plotting eval
if args.dataset == 'sinusoid':
    task_eval = sinusoid_task(n_support=args.n_support, n_query=args.n_query)
def multiclass_loss(model, params, batch):
    inputs, targets = batch
    logits = model.apply(params, None, inputs)
    one_hot = jax.nn.one_hot(targets, logits.shape[-1])
    logits = stax.logsoftmax(logits)  # log normalize
    return -jnp.mean(jnp.sum(logits * one_hot, axis=-1))  # cross entropy loss
def loss_adv(image, label):
    pred = model_fn(image[None])
    loss = -np.sum(logsoftmax(pred) * label)
    if targeted:
        loss = -loss
    return loss
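
The targeted flag above flips the sign so that a gradient-based attack drives the prediction toward the supplied label instead of away from it. Below is a hedged sketch of how such a loss is typically used in a one-step FGSM-style attack; model_fn, eps, and the signed-gradient update are illustrative assumptions, not part of the original snippet.

import jax.numpy as np
from jax import grad
from jax.nn import log_softmax as logsoftmax

def fgsm_step(model_fn, image, label, eps=0.01, targeted=False):
    def loss_adv(image, label):
        pred = model_fn(image[None])
        loss = -np.sum(logsoftmax(pred) * label)
        if targeted:
            loss = -loss
        return loss

    # One signed-gradient step that increases loss_adv; with targeted=True
    # this pushes the prediction toward `label`.
    g = grad(loss_adv)(image, label)
    return image + eps * np.sign(g)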
Example #15
def _loss(params, batch):
    global _fn_holder
    data, labels = batch
    logits = _fn_holder(params, data)
    logits = stax.logsoftmax(logits)  # log normalize
    return -np.mean(np.sum(logits * labels, axis=1))  # cross entropy loss
Example #16
def loss(params, batch):
    inputs, targets = batch
    logits = predict(params, inputs)
    logits = stax.logsoftmax(logits)  # log normalize
    return -np.mean(np.sum(logits * targets, 1))  # cross entropy loss
Example #17
def loss(params, batch):
  inputs, targets = batch
  logits = predict(params, inputs)
  preds = stax.logsoftmax(logits)
  return -np.sum(targets * preds) / len(targets)
def loss(params, batch):
    inputs, targets = batch
    logits = predict(params, inputs)
    preds = stax.logsoftmax(logits)
    return -np.mean(preds * targets)
Example #19
def loss(params, batch):
    inputs, targets = batch
    preds = predict(params, inputs)
    return -np.mean(logsoftmax(preds) * targets)
Example #20
lr = args.lr

if args.optimizer == 'sgd':
    opt_init, opt_apply, get_params = myopt.sgd(lr)
elif args.optimizer == 'momentum':
    opt_init, opt_apply, get_params = myopt.momentum(
        lr, args.momentum, weight_decay=args.weight_decay)
elif args.optimizer == 'adagrad':
    opt_init, opt_apply, get_params = optimizers.adagrad(lr, args.momentum)
elif args.optimizer == 'adam':
    opt_init, opt_apply, get_params = optimizers.adam(lr)

state = opt_init(params)

if args.loss == 'logistic':
    loss = lambda fx, y: np.mean(-np.sum(logsoftmax(fx) * y, axis=1))
elif args.loss == 'squared':
    loss = lambda fx, y: np.mean(np.sum((fx - y)**2, axis=1))
value_and_grad_loss = jit(
    value_and_grad(lambda params, x, y: loss(f(params, x), y)))
loss_fn = jit(lambda params, x, y: loss(f(params, x), y))
accuracy_sum = jit(
    lambda fx, y: np.sum(np.argmax(fx, axis=1) == np.argmax(y, axis=1)))

# Create tensorboard writer
writer = SummaryWriter(logdir=args.logdir)

# Train the network
global_step, running_count = 0, 0
running_loss, running_loss_g = 0., 0.
if args.save_path is not None:
Example #21
def weight_space(train_embedding, test_embedding, data_set):
    init_fn, f, _ = stax.serial(
        stax.Dense(512, 1., 0.05),
        stax.Erf(),
        # 2 output units, one per class
        stax.Dense(2, 1., 0.05))

    key = random.PRNGKey(0)
    # (-1, 135): 135 is the feature length, here 9 * 15 = 135
    _, params = init_fn(key, (-1, 135))

    # Linearize the network about its initial parameters.
    f_lin = nt.linearize(f, params)

    # Create and initialize an optimizer for both f and f_lin.
    opt_init, opt_apply, get_params = optimizers.momentum(1.0, 0.9)
    opt_apply = jit(opt_apply)

    state = opt_init(params)
    state_lin = opt_init(params)

    # Create a cross-entropy loss function.
    loss = lambda fx, y_hat: -np.mean(logsoftmax(fx) * y_hat)

    # Specialize the loss function to compute gradients for both linearized and
    # full networks.
    grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))
    grad_loss_lin = jit(grad(lambda params, x, y: loss(f_lin(params, x), y)))

    # Train the network.
    print('Training.')
    print('Epoch\tLoss\tLinearized Loss')
    print('------------------------------------------')

    epoch = 0
    # Use whole batch
    batch_size = 64
    train_epochs = 10
    steps_per_epoch = 100

    for i, (x, y) in enumerate(
            datasets.mini_batch(train_embedding, data_set['Y_train'],
                                batch_size, train_epochs)):
        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x, y), state)

        params_lin = get_params(state_lin)
        state_lin = opt_apply(i, grad_loss_lin(params_lin, x, y), state_lin)

        if i % steps_per_epoch == 0:
            print('{}\t{:.4f}\t{:.4f}'.format(epoch, loss(f(params, x), y),
                                              loss(f_lin(params_lin, x), y)))
            epoch += 1
        if i / steps_per_epoch == train_epochs:
            break

    # Print out summary data comparing the linear / nonlinear model.
    x, y = train_embedding[:10000], data_set['Y_train'][:10000]
    util.print_summary('train', y, f(params, x), f_lin(params_lin, x), loss)
    util.print_summary('test', data_set['Y_test'], f(params, test_embedding),
                       f_lin(params_lin, test_embedding), loss)