def main(unused_argv):
    # Load data.
    print('Loading data.')
    x_train, y_train, x_test, y_test = datasets.get_dataset(
        'mnist', permute_train=True)

    # Build the network.
    init_fn, f, _ = stax.serial(
        stax.Dense(2048, 1., 0.05),
        stax.Erf(),
        stax.Dense(10, 1., 0.05))

    key = random.PRNGKey(0)
    _, params = init_fn(key, (-1, 784))

    # Linearize the network about its initial parameters.
    f_lin = nt.linearize(f, params)

    # Create and initialize an optimizer for both f and f_lin.
    opt_init, opt_apply, get_params = optimizers.momentum(
        FLAGS.learning_rate, 0.9)
    opt_apply = jit(opt_apply)

    state = opt_init(params)
    state_lin = opt_init(params)

    # Create a cross-entropy loss function.
    loss = lambda fx, y_hat: -np.mean(logsoftmax(fx) * y_hat)

    # Specialize the loss function to compute gradients for both linearized
    # and full networks.
    grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))
    grad_loss_lin = jit(grad(lambda params, x, y: loss(f_lin(params, x), y)))

    # Train the network.
    print('Training.')
    print('Epoch\tLoss\tLinearized Loss')
    print('------------------------------------------')

    epoch = 0
    steps_per_epoch = 50000 // FLAGS.batch_size

    for i, (x, y) in enumerate(datasets.minibatch(
            x_train, y_train, FLAGS.batch_size, FLAGS.train_epochs)):
        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x, y), state)

        params_lin = get_params(state_lin)
        state_lin = opt_apply(i, grad_loss_lin(params_lin, x, y), state_lin)

        if i % steps_per_epoch == 0:
            print('{}\t{:.4f}\t{:.4f}'.format(
                epoch, loss(f(params, x), y), loss(f_lin(params_lin, x), y)))
            epoch += 1

    # Print out summary data comparing the linear / nonlinear model.
    x, y = x_train[:10000], y_train[:10000]
    util.print_summary('train', y, f(params, x), f_lin(params_lin, x), loss)
    util.print_summary('test', y_test, f(params, x_test),
                       f_lin(params_lin, x_test), loss)
def softmax_cross_entropy_with_logits_l2_reg(params, f, x, targets,
                                             masks=None, L2_REG_COEFF=0.0,
                                             key=None):
    """Cross-entropy loss plus weighted l2 regularization.

    Args:
      params: parameters in a stax format.
      f: a function that maps a set of network parameters together with a set
        of network inputs to network outputs.
      x: network inputs.
      targets: the target outputs.
      masks: the sparsity-inducing mask.
      L2_REG_COEFF: l2 regularization constant.
      key: optional PRNG key forwarded to f.

    Returns:
      The cross-entropy loss plus weighted l2 regularization.
    """
    if masks is not None:
        masked_params = get_sparse_params_filtered_by_masks(params, masks)
    else:
        masked_params = params
    if key is not None:
        dense_outputs = f(masked_params, x, rng=key)
    else:
        dense_outputs = f(masked_params, x)
    preds = logsoftmax(dense_outputs)
    params_norm_squared = stax_params_l2_square(masked_params)
    return (-np.mean(np.sum(preds * targets, axis=1))
            + L2_REG_COEFF * params_norm_squared)
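# The helper stax_params_l2_square used above is not shown in this snippet.
# A minimal sketch, assuming params is a (possibly nested) pytree of JAX
# arrays as returned by a stax init_fn; the function name comes from the call
# above, the body is an assumption.
from jax import tree_util
import jax.numpy as np  # matching the `np` alias used above


def stax_params_l2_square(params):
    # Sum of squared entries over every array leaf in the params pytree.
    return sum(np.sum(leaf ** 2) for leaf in tree_util.tree_leaves(params))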
def sparse_softmax_cross_entropy_with_logits(*, labels, logits):
    """https://www.tensorflow.org/api_docs/python/tf/nn/sparse_softmax_cross_entropy_with_logits"""
    assert labels.shape == logits.shape[:-1]
    assert_array(labels, dtypes=(jp.int32,))
    assert_array(logits, dtypes=(jp.float32,))
    log_probs = logsoftmax(logits, axis=-1)
    chosen_log_probs = jp.squeeze(
        jp.take_along_axis(log_probs, jp.expand_dims(labels, axis=-1), axis=-1),
        axis=-1)
    return -chosen_log_probs
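# A minimal usage check (the 3-class logits and integer labels below are
# illustrative, not from the source), assuming `jp` is jax.numpy and
# `logsoftmax` is jax.nn.log_softmax: the gathered form above matches the
# dense one-hot cross entropy used in several of the other snippets.
import jax.numpy as jp
from jax.nn import log_softmax, one_hot

logits = jp.array([[2.0, 0.5, -1.0],
                   [0.1, 0.2, 0.3]], dtype=jp.float32)
labels = jp.array([0, 2], dtype=jp.int32)

gathered = -jp.squeeze(
    jp.take_along_axis(log_softmax(logits, axis=-1),
                       jp.expand_dims(labels, axis=-1), axis=-1),
    axis=-1)
dense = -jp.sum(one_hot(labels, 3) * log_softmax(logits, axis=-1), axis=-1)
assert jp.allclose(gathered, dense)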
def __init__(self, probs=None, logits=None):
    if logits is None:
        assert probs is not None
        self.log_probs = jp.log(probs)
        self._logits = self.log_probs
    else:
        assert probs is None
        self.log_probs = logsoftmax(logits, axis=-1)
        self._logits = logits
    self._probs = jp.exp(self.log_probs)
    self.batch_shape = self.log_probs.shape[:-1]
    self.event_shape = ()
    self._cdf_probs = jp.cumsum(self._probs, axis=-1) - self._probs
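# A hedged sketch of a matching log_prob method for this categorical
# distribution (the method name and placement are assumptions, not taken from
# the source): gather the stored log-probabilities at the integer outcomes
# along the last axis.
def log_prob(self, value):
    value = jp.expand_dims(value, axis=-1)
    return jp.squeeze(jp.take_along_axis(self.log_probs, value, axis=-1),
                      axis=-1)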
def split_mutation_predictor_output(output):
    return stax.logsoftmax(output[:, :, -1]), stax.logsoftmax(output[:, :, :-1])
    net_init, f = mlp(n_output=args.n_way,
                      n_hidden_layer=args.n_hidden_layer,
                      n_hidden_unit=args.n_hidden_unit,
                      bias_coef=args.bias_coef,
                      activation=args.activation,
                      norm=args.norm)
    _, params = net_init(rng=random.PRNGKey(42), input_shape=(-1, 2))
else:
    raise ValueError

# loss functions
if args.dataset == 'sinusoid':
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat) ** 2)
elif args.dataset in ['omniglot', 'circle']:
    loss = lambda fx, targets: -np.sum(logsoftmax(fx) * targets) / targets.shape[0]
    acc = lambda fx, targets: np.mean(
        np.argmax(logsoftmax(fx), axis=-1) == np.argmax(targets, axis=-1))
    param_acc = jit(lambda p, x, y: acc(f(p, x), y))
else:
    raise ValueError

grad_loss = jit(grad(lambda p, x, y: loss(f(p, x), y)))
param_loss = jit(lambda p, x, y: loss(f(p, x), y))

# optimizers
# TODO: separate optimizers for nonlinear and linear?
outer_opt_init, outer_opt_update, outer_get_params = select_opt(
    args.outer_opt_alg, args.outer_step_size)()
inner_opt_init, inner_opt_update, inner_get_params = select_opt(
    args.inner_opt_alg, args.inner_step_size)()
def xentr(params, images_and_labels):
    images, labels = images_and_labels
    return -np.mean(stax.logsoftmax(f(params, images)) * labels)
def computation(params, inputs, targets):
    logits = predict(params, inputs)
    preds = stax.logsoftmax(logits)
    return -np.mean(np.sum(preds * targets, axis=1))
def cross_entropy_loss(params, x_img, y_lbl):
    return -np.mean(stax.logsoftmax(predict(params, x_img)) * y_lbl)
def loss(params, predict, batch):
    inputs, targets = batch
    logits = predict(params, inputs)
    logits = stax.logsoftmax(logits)
    return -np.mean(np.sum(logits * targets, axis=1))
def loss(params, batch):
    inputs, targets = batch
    logits = f(params, inputs)
    outputs = logsoftmax(logits)
    return -np.sum(outputs * targets) / targets.shape[0]
    net_init, f = mlp(n_output=args.n_way,
                      n_hidden_layer=args.n_hidden_layer,
                      n_hidden_unit=args.n_hidden_unit,
                      bias_coef=args.bias_coef,
                      activation=args.activation,
                      norm=args.norm)
    _, params = net_init(rng=random.PRNGKey(42), input_shape=(-1, 2))
else:
    raise ValueError

# loss functions
if args.dataset == 'sinusoid':
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat) ** 2)
elif args.dataset in ['omniglot', 'circle']:
    loss = lambda fx, targets: -np.sum(logsoftmax(fx) * targets) / targets.shape[0]
    acc = lambda fx, targets: np.mean(
        np.argmax(logsoftmax(fx), axis=-1) == np.argmax(targets, axis=-1))
    param_acc = jit(lambda p, x, y: acc(f(p, x), y))
else:
    raise ValueError

grad_loss = jit(grad(lambda p, x, y: loss(f(p, x), y)))
param_loss = jit(lambda p, x, y: loss(f(p, x), y))

# optimizers
# TODO: separate optimizers for nonlinear and linear?
outer_opt_init, outer_opt_update, outer_get_params = select_opt(
    args.outer_opt_alg, args.outer_step_size)()
inner_opt_init, inner_opt_update, inner_get_params = select_opt(
    args.inner_opt_alg, args.inner_step_size)()

# consistent task for plotting eval
if args.dataset == 'sinusoid':
    task_eval = sinusoid_task(n_support=args.n_support, n_query=args.n_query)
def multiclass_loss(model, params, batch):
    inputs, targets = batch
    logits = model.apply(params, None, inputs)
    one_hot = jax.nn.one_hot(targets, logits.shape[-1])
    logits = stax.logsoftmax(logits)  # log normalize
    return -jnp.mean(jnp.sum(logits * one_hot, axis=-1))  # cross entropy loss
def loss_adv(image, label):
    pred = model_fn(image[None])
    loss = -np.sum(logsoftmax(pred) * label)
    if targeted:
        loss = -loss
    return loss
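# A hedged sketch of how a per-example loss like loss_adv is typically paired
# with jax.grad for an FGSM-style perturbation; `epsilon` and the [0, 1]
# clipping range are illustrative assumptions, not taken from the source.
from jax import grad
import jax.numpy as np

epsilon = 0.03  # illustrative step size
image_grad = grad(loss_adv)(image, label)  # gradient w.r.t. the input image
adv_image = np.clip(image + epsilon * np.sign(image_grad), 0.0, 1.0)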
def _loss(params, batch):
    global _fn_holder
    data, labels = batch
    logits = _fn_holder(params, data)
    logits = stax.logsoftmax(logits)  # log normalize
    return -np.mean(np.sum(logits * labels, axis=1))  # cross entropy loss
def loss(params, batch):
    inputs, targets = batch
    logits = predict(params, inputs)
    logits = stax.logsoftmax(logits)  # log normalize
    return -np.mean(np.sum(logits * targets, 1))  # cross entropy loss
def loss(params, batch):
    inputs, targets = batch
    logits = predict(params, inputs)
    preds = stax.logsoftmax(logits)
    return -np.sum(targets * preds) / len(targets)
def loss(params, batch):
    inputs, targets = batch
    logits = predict(params, inputs)
    preds = stax.logsoftmax(logits)
    return -np.mean(preds * targets)
def loss(params, batch):
    inputs, targets = batch
    preds = predict(params, inputs)
    return -np.mean(logsoftmax(preds) * targets)
lr = args.lr
if args.optimizer == 'sgd':
    opt_init, opt_apply, get_params = myopt.sgd(lr)
elif args.optimizer == 'momentum':
    opt_init, opt_apply, get_params = myopt.momentum(
        lr, args.momentum, weight_decay=args.weight_decay)
elif args.optimizer == 'adagrad':
    opt_init, opt_apply, get_params = optimizers.adagrad(lr, args.momentum)
elif args.optimizer == 'adam':
    opt_init, opt_apply, get_params = optimizers.adam(lr)

state = opt_init(params)

if args.loss == 'logistic':
    loss = lambda fx, y: np.mean(-np.sum(logsoftmax(fx) * y, axis=1))
elif args.loss == 'squared':
    loss = lambda fx, y: np.mean(np.sum((fx - y) ** 2, axis=1))

value_and_grad_loss = jit(
    value_and_grad(lambda params, x, y: loss(f(params, x), y)))
loss_fn = jit(lambda params, x, y: loss(f(params, x), y))
accuracy_sum = jit(
    lambda fx, y: np.sum(np.argmax(fx, axis=1) == np.argmax(y, axis=1)))

# Create tensorboard writer
writer = SummaryWriter(logdir=args.logdir)

# Train the network
global_step, running_count = 0, 0
running_loss, running_loss_g = 0., 0.

if args.save_path is not None:
def weight_space(train_embedding, test_embedding, data_set):
    init_fn, f, _ = stax.serial(
        stax.Dense(512, 1., 0.05),
        stax.Erf(),
        stax.Dense(2, 1., 0.05))  # 2 output units, one per class

    key = random.PRNGKey(0)
    # Input shape (-1, 135): 135 is the feature length, here 9 * 15 = 135.
    _, params = init_fn(key, (-1, 135))

    # Linearize the network about its initial parameters.
    f_lin = nt.linearize(f, params)

    # Create and initialize an optimizer for both f and f_lin.
    opt_init, opt_apply, get_params = optimizers.momentum(1.0, 0.9)
    opt_apply = jit(opt_apply)

    state = opt_init(params)
    state_lin = opt_init(params)

    # Create a cross-entropy loss function.
    loss = lambda fx, y_hat: -np.mean(logsoftmax(fx) * y_hat)

    # Specialize the loss function to compute gradients for both linearized
    # and full networks.
    grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))
    grad_loss_lin = jit(grad(lambda params, x, y: loss(f_lin(params, x), y)))

    # Train the network.
    print('Training.')
    print('Epoch\tLoss\tLinearized Loss')
    print('------------------------------------------')

    epoch = 0
    batch_size = 64
    train_epochs = 10
    steps_per_epoch = 100

    for i, (x, y) in enumerate(datasets.mini_batch(
            train_embedding, data_set['Y_train'], batch_size, train_epochs)):
        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x, y), state)

        params_lin = get_params(state_lin)
        state_lin = opt_apply(i, grad_loss_lin(params_lin, x, y), state_lin)

        if i % steps_per_epoch == 0:
            print('{}\t{:.4f}\t{:.4f}'.format(
                epoch, loss(f(params, x), y), loss(f_lin(params_lin, x), y)))
            epoch += 1
        if i / steps_per_epoch == train_epochs:
            break

    # Print out summary data comparing the linear / nonlinear model.
    x, y = train_embedding[:10000], data_set['Y_train'][:10000]
    util.print_summary('train', y, f(params, x), f_lin(params_lin, x), loss)
    util.print_summary('test', data_set['Y_test'], f(params, test_embedding),
                       f_lin(params_lin, test_embedding), loss)