from jax.experimental import optimizers, stax
import jax.numpy as jnp

from dataset.make_data import equation_of_motion, rk4_step, solve_lagrangian
from dataset.plot import normalize_dp
from visualization import plot_loss

TRAIN_DATASET_PATH = "../data/train_data.pickle"
TEST_DATASET_PATH = "../data/test_data.pickle"
LOG_DIR = "./logs"

# Build a neural network model.
init_random_params, nn_forward_fn = stax.serial(
    stax.Dense(128),
    stax.Softplus,
    stax.Dense(128),
    stax.Softplus,
    stax.Dense(1),
)


# Replace the Lagrangian with a parametric model.
def learned_lagrangian(params):
    def lagrangian(q, q_t):
        assert q.shape == (2,)
        state = normalize_dp(jnp.concatenate([q, q_t]))
        return jnp.squeeze(nn_forward_fn(params, state), axis=-1)

    return lagrangian
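# A minimal usage sketch (hypothetical, not part of the original script; assumes
# `from jax import random`): initialize the stax parameters for the 4-dimensional
# state [q, q_t] and evaluate the learned Lagrangian at one state.
_, nn_params = init_random_params(random.PRNGKey(0), (-1, 4))
q, q_t = jnp.array([0.1, -0.2]), jnp.array([0.0, 0.3])
L_value = learned_lagrangian(nn_params)(q, q_t)  # scalar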
def repeat(layer, num_repeats):
    """Repeats layers serially num_repeats times."""
    if num_repeats < 1:
        raise ValueError('Repeat combinator num_repeats must be >= 1.')
    layers = num_repeats * (layer,)
    return stax.serial(*layers)
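# Illustrative use of the repeat combinator (layer sizes are arbitrary): three
# identical Dense + Relu blocks composed serially into one (init_fun, apply_fun) pair.
mlp_init, mlp_apply = repeat(stax.serial(stax.Dense(64), stax.Relu), 3)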
def initialize_parametric_nonlinearity(self, init_to='exponential',
                                        method=None, params_dict=None):
    if method is None:
        # If no method is specified, fall back to the object's defaults.
        # This piece of code is quite redundant and needs refactoring.
        if hasattr(self, 'nonlinearity'):
            method = self.nonlinearity
        else:
            method = self.filter_nonlinearity
    else:
        # Overwrite the default nonlinearity.
        if hasattr(self, 'nonlinearity'):
            self.nonlinearity = method
        else:
            self.filter_nonlinearity = method
        self.output_nonlinearity = method

    # Prepare data.
    if params_dict is None:
        params_dict = {}
    xrange = params_dict['xrange'] if 'xrange' in params_dict else 5
    nx = params_dict['nx'] if 'nx' in params_dict else 1000
    x0 = np.linspace(-xrange, xrange, nx)

    if init_to == 'exponential':
        y0 = np.exp(x0)
    elif init_to == 'softplus':
        y0 = softplus(x0)
    elif init_to == 'relu':
        y0 = relu(x0)
    elif init_to == 'nonparametric':
        y0 = self.fnl_nonparametric(x0)
    elif init_to == 'gaussian':
        import scipy.signal
        y0 = scipy.signal.gaussian(nx, nx / 10)

    # Fit the nonlinearity.
    if method == 'spline':
        smooth = params_dict['smooth'] if 'smooth' in params_dict else 'cr'
        df = params_dict['df'] if 'df' in params_dict else 7
        if smooth == 'cr':
            X = cr(x0, df)
        elif smooth == 'cc':
            X = cc(x0, df)
        elif smooth == 'bs':
            deg = params_dict['degree'] if 'degree' in params_dict else 3
            X = bs(x0, df, deg)

        opt_params = np.linalg.pinv(X.T @ X) @ X.T @ y0
        self.nl_basis = X

        def _nl(opt_params, x_new):
            return np.maximum(interp1d(x0, X @ opt_params)(x_new), 0)

    elif method == 'nn':

        def loss(params, data):
            x = data['x']
            y = data['y']
            yhat = _predict(params, x)
            return np.mean((y - yhat)**2)

        @jit
        def step(i, opt_state, data):
            p = get_params(opt_state)
            g = grad(loss)(p, data)
            return opt_update(i, g, opt_state)

        random_seed = params_dict['random_seed'] if 'random_seed' in params_dict else 2046
        key = random.PRNGKey(random_seed)

        step_size = params_dict['step_size'] if 'step_size' in params_dict else 0.01
        layer_sizes = params_dict['layer_sizes'] if 'layer_sizes' in params_dict else [10, 10, 1]

        layers = []
        for layer_size in layer_sizes:
            layers.append(Dense(layer_size))
            layers.append(BatchNorm(axis=(0, 1)))
            layers.append(Relu)
        else:
            # for/else: runs once the loop finishes and drops the trailing Relu.
            layers.pop(-1)
        init_random_params, _predict = stax.serial(*layers)

        num_subunits = params_dict['num_subunits'] if 'num_subunits' in params_dict else 1
        _, init_params = init_random_params(key, (-1, num_subunits))

        opt_init, opt_update, get_params = optimizers.adam(step_size)
        opt_state = opt_init(init_params)

        num_iters = params_dict['num_iters'] if 'num_iters' in params_dict else 1000

        if num_subunits == 1:
            data = {'x': x0.reshape(-1, 1), 'y': y0.reshape(-1, 1)}
        else:
            data = {
                'x': np.vstack([x0 for i in range(num_subunits)]).T,
                'y': y0.reshape(-1, 1)
            }

        for i in range(num_iters):
            opt_state = step(i, opt_state, data)
        opt_params = get_params(opt_state)

        def _nl(opt_params, x_new):
            if len(x_new.shape) == 1:
                x_new = x_new.reshape(-1, 1)
            return np.maximum(_predict(opt_params, x_new), 0)

    self.nl_xrange = x0
    self.nl_params = opt_params
    self.fnl_fitted = _nl
def main():
    total_time = 20.0
    gamma = 1.0
    x_dim = 2
    z_dim = 32
    rng = random.PRNGKey(0)

    x0 = jp.array([2.0, 1.0])

    ### Set up the problem/environment
    # xdot = Ax + Bu
    # u = - Kx
    # cost = xQx + uRu + 2xNu
    A, B, Q, R, N = fixed_env(x_dim)
    dynamics_fn = lambda x, u: A @ x + B @ u
    cost_fn = lambda x, u: x.T @ Q @ x + u.T @ R @ u + 2 * x.T @ N @ u
    policy_loss = policy_integrate_cost(x_dim, z_dim, dynamics_fn, cost_fn, gamma)

    ### Solve the Riccati equation to get the infinite-horizon optimal solution.
    K, _, _ = control.lqr(A, B, Q, R, N)
    K = jp.array(K)

    t0 = time.time()
    opt_y_fwd, opt_y_bwd = policy_loss(lambda _, x: -K @ x)(None, x0, total_time)
    opt_cost = opt_y_fwd[1, 0]
    print(f"opt_cost = {opt_cost} in {time.time() - t0}s")
    print(opt_y_fwd)
    print(opt_y_bwd)
    print(f"l2 error: {jp.sqrt(jp.sum((opt_y_fwd - opt_y_bwd)**2))}")

    ### Set up the learned policy model.
    policy_init, policy = stax.serial(
        Dense(64),
        Tanh,
        Dense(x_dim),
    )
    rng_init_params, rng = random.split(rng)
    _, init_policy_params = policy_init(rng_init_params, (x_dim, ))
    opt = make_optimizer(optimizers.adam(1e-3))(init_policy_params)

    runny_run = jit(policy_loss(policy))

    ### Main optimization loop.
    costs = []
    bwd_errors = []
    for i in range(5000):
        t0 = time.time()
        (y_fwd, y_bwd), vjp = jax.vjp(runny_run, opt.value, x0, total_time)
        cost = y_fwd[1, 0]
        y_fwd_bar = jax.ops.index_update(jp.zeros_like(y_fwd), (1, 0), 1)
        g, _, _ = vjp((y_fwd_bar, jp.zeros_like(y_bwd)))
        opt = opt.update(g)

        bwd_err = jp.sqrt(jp.sum((y_fwd - y_bwd)**2))
        bwd_errors.append(bwd_err)

        print(f"Episode {i}:")
        print(f"  excess cost = {cost - opt_cost}")
        print(f"  bwd error = {bwd_err}")
        print(f"  elapsed = {time.time() - t0}")
        costs.append(float(cost))

    print(f"Opt solution cost from starting point: {opt_cost}")

    ### Plot performance per iteration, incl. average optimal policy performance.
    _, ax1 = plt.subplots()
    ax1.set_xlabel("Iteration")
    ax1.set_ylabel("Cost", color="tab:blue")
    ax1.set_yscale("log")
    ax1.tick_params(axis="y", labelcolor="tab:blue")
    ax1.plot(costs, color="tab:blue")
    plt.axhline(opt_cost, linestyle="--", color="gray")

    ax2 = ax1.twinx()
    ax2.set_ylabel("Backward solve L2 error", color="tab:red")
    ax2.set_yscale("log")
    ax2.tick_params(axis="y", labelcolor="tab:red")
    ax2.plot(bwd_errors, color="tab:red")

    plt.title("ODE control of LQR problem")
    blt.show()
def testSerialComposeLayersShape(self, input_shape, spec):
    init_fun, apply_fun = stax.serial(*spec)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)
def loss(params, batch):
    inputs, targets = batch
    preds = predict(params, inputs)
    return -np.mean(np.sum(preds * targets, axis=1))


def accuracy(params, batch):
    inputs, targets = batch
    target_class = np.argmax(targets, axis=1)
    predicted_class = np.argmax(predict(params, inputs), axis=1)
    return np.mean(predicted_class == target_class)


# `Activator` is an activation layer (e.g. stax.Relu) defined elsewhere in the script.
init_random_params, predict = stax.serial(
    Conv(10, (5, 5), (1, 1)), Activator, MaxPool((4, 4)), Flatten,
    Dense(10), LogSoftmax)

if __name__ == "__main__":
    rng = random.PRNGKey(0)
    step_size = 0.001
    num_epochs = 10
    batch_size = 128
    momentum_mass = 0.9

    # Input shape for the CNN.
    input_shape = (-1, 28, 28, 1)

    # Training/test split.
    (train_images, train_labels), (test_images, test_labels) = \
        mnist_data.tiny_mnist(flatten=False)
import jax
import jax.experimental.stax as stax
import jax.experimental.optimizers as opt
from functools import partial
import numpy as np
import matplotlib.pyplot as plt

###############################################
########## SETTING UP NEURAL NETWORK ##########
###############################################

net_init, net_apply = stax.serial(
    stax.Dense(40), stax.Relu,
    stax.Dense(40), stax.Relu,
    stax.Dense(1)
)
input_shape = (-1, 1,)
rng = jax.random.PRNGKey(0)
output_shape, net_params = net_init(rng, input_shape)


def loss(params, inputs, targets):
    predictions = net_apply(params, inputs)
    return jax.numpy.mean((targets - predictions)**2)


xrange_inputs = jax.numpy.linspace(-5, 5, 100).reshape((100, 1))
targets = jax.numpy.sin(xrange_inputs)
predictions = jax.vmap(partial(net_apply, net_params))(xrange_inputs)
losses = jax.vmap(partial(loss, net_params))(xrange_inputs, targets)

####################################################
########## PLOTTING UNINITIALIZED NETWORK ##########
def main():
    rng = random.PRNGKey(0)

    x_dim = 2
    T = 20.0

    policy_init, policy = stax.serial(
        Dense(64),
        Tanh,
        Dense(x_dim),
    )

    x0 = jnp.ones(x_dim)
    A, B, Q, R, N = fixed_env(x_dim)

    print("System dynamics:")
    print(f"  A = {A}")
    print(f"  B = {B}")
    print(f"  Q = {Q}")
    print(f"  R = {R}")
    print(f"  N = {N}")
    print()

    dynamics_fn = lambda x, u: A @ x + B @ u
    cost_fn = lambda x, u: x.T @ Q @ x + u.T @ R @ u + 2 * x.T @ N @ u

    ### Evaluate LQR solution to get a sense of optimal cost.
    K, _, _ = control.lqr(A, B, Q, R, N)
    K = jnp.array(K)
    opt_loss, _opt_K_grad = policy_loss_and_grad(dynamics_fn, cost_fn, T,
                                                 lambda KK, x: -KK @ x)(x0, K)
    # This is true for longer time horizons, but not true for shorter time
    # horizons due to the LQR solution being an infinite-time solution.
    # assert jnp.allclose(opt_K_grad, 0)

    ### Training loop.
    rng_init_params, rng = random.split(rng)
    _, init_policy_params = policy_init(rng_init_params, (x_dim, ))
    opt = make_optimizer(optimizers.adam(1e-3))(init_policy_params)
    loss_and_grad = policy_loss_and_grad(dynamics_fn, cost_fn, T, policy)

    loss_per_iter = []
    elapsed_per_iter = []
    for iteration in range(10000):
        t0 = time.time()
        loss, g = loss_and_grad(x0, opt.value)
        opt = opt.update(g)
        elapsed = time.time() - t0

        loss_per_iter.append(loss)
        elapsed_per_iter.append(elapsed)

        print(f"Iteration {iteration}")
        print(f"  excess loss = {loss - opt_loss}")
        print(f"  elapsed = {elapsed}")

    blt.remember({
        "loss_per_iter": loss_per_iter,
        "elapsed_per_iter": elapsed_per_iter,
        "opt_loss": opt_loss
    })

    _, ax1 = plt.subplots()
    ax1.set_xlabel("Iteration")
    ax1.set_ylabel("Cost", color="tab:blue")
    ax1.set_yscale("log")
    ax1.tick_params(axis="y", labelcolor="tab:blue")
    ax1.plot(loss_per_iter, color="tab:blue", label="Total rollout cost")
    plt.axhline(opt_loss, linestyle="--", color="gray")
    ax1.legend(loc="upper left")

    plt.title("Combined fwd-bwd BVP problem")
    blt.show()
def fcnn(
    output_dim: int,
    layers=2,
    units=256,
    skips=False,
    output_activation=None,
    first_layer_factor=None,
    parallel_with=None,
):
    """
    Creates init and apply functions for a fully-connected NN with a Gaussian
    likelihood.

    :param output_dim: Number of output dimensions
    :param first_layer_factor: Scaling factor for first dense layer initialization.
    """
    first_layer_factor = 10.0 if first_layer_factor is None else first_layer_factor
    activation = stax.elementwise(np.sin)
    block = (skip_connect(
        stax.serial(Dense(units, W_init=siren_w_init()()), activation))
             if skips else stax.serial(
                 Dense(units, W_init=siren_w_init()()), activation))
    layer_list = [
        stax.serial(
            Dense(
                units,
                W_init=siren_w_init(factor=first_layer_factor)(),
                b_init=jax.nn.initializers.normal(stddev=2.0 * math.pi),
            ),
            activation,
        ),
        *([block] * (layers - 1)),
        Dense(output_dim),  # No skips on the last one!
    ]
    if output_activation is not None:
        layer_list.append(output_activation)
    if parallel_with is None:
        _init_fun, _apply_fun = stax.serial(*layer_list)
    else:
        _init_fun, _apply_fun = parallel(stax.serial(*layer_list), parallel_with)

    def init_fun(rng, input_shape):
        output_shape, net_params = _init_fun(rng, input_shape)
        params = {"net": net_params, "raw_noise": np.array(-2.0)}
        return output_shape, params

    # Conform to API:
    @t_wrapper
    def apply_fun(params, rng, inputs):
        return _apply_fun(params["net"], inputs)

    @t_wrapper
    def gaussian_fun(params, rng, inputs):
        pred_mean = apply_fun(params, None, inputs)
        pred_std = params["noise"] * np.ones(pred_mean.shape)
        return pred_mean, pred_std

    @t_wrapper
    def loss_fun(params, rng, data, batch_size=None, n=None, loss_type="nlp",
                 reduce="sum"):
        """
        :param batch_size: How large a batch to subselect from the provided data
        :param n: The total size of the dataset (to multiply batch estimate by)
        """
        assert loss_type in ("nlp", "mse")
        inputs, targets = data
        n = inputs.shape[0] if n is None else n
        if batch_size is not None:
            rng, rng_batch = random.split(rng)
            i = random.permutation(rng_batch, n)[:batch_size]
            inputs, targets = inputs[i], targets[i]
        preds = apply_fun(params, rng, inputs).squeeze()
        mean_loss = (
            -norm.logpdf(targets.squeeze(), preds, params["noise"]).mean()
            if loss_type == "nlp" else
            np.power(targets.squeeze() - preds, 2).mean())
        if reduce == "sum":
            loss = n * mean_loss
        elif reduce == "mean":
            loss = mean_loss
        return loss

    def sample_fun_fun(rng, params):
        def f(x):
            return apply_fun(params, rng, x)

        return f

    return {
        "init": init_fun,
        "apply": apply_fun,
        "gaussian": gaussian_fun,
        "loss": loss_fun,
        "sample_function": sample_fun_fun,
    }
def create_surrogate(self):
    surrogate_init, surrogate = stax.serial(Dense(200), Relu, Dense(200), Relu,
                                            Dense(200), Relu, Dense(1))
    return surrogate, surrogate_init
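# Hedged usage sketch (the `model` instance and the 3-dimensional input are
# illustrative; assumes `from jax import random` and `jax.numpy as jnp`):
surrogate, surrogate_init = model.create_surrogate()
_, surrogate_params = surrogate_init(random.PRNGKey(0), (-1, 3))
preds = surrogate(surrogate_params, jnp.zeros((16, 3)))  # shape (16, 1)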
def AutoregressiveNN(input_dim,
                     hidden_dims,
                     param_dims=[1, 1],
                     permutation=None,
                     skip_connections=False,
                     nonlinearity=stax.Relu):
    """
    An implementation of a MADE-like auto-regressive neural network.

    Similar to the purely functional layers implemented in jax.experimental.stax,
    `AutoregressiveNN` returns `init_fun` and `apply_fun` functions, where
    `init_fun` takes an rng key and an input shape and returns an
    (output_shape, params) pair, and `apply_fun` takes params and inputs and
    applies the layer.

    :param input_dim: the dimensionality of the input
    :type input_dim: int
    :param hidden_dims: the dimensionality of the hidden units per layer
    :type hidden_dims: list[int]
    :param param_dims: shape the output into parameters of dimension
        (p_n, input_dim) for p_n in param_dims when p_n > 1, and dimension
        (input_dim,) when p_n == 1. The default is [1, 1], i.e. output two
        parameters of dimension (input_dim,), which is useful for inverse
        autoregressive flow.
    :type param_dims: list[int]
    :param permutation: an optional permutation that is applied to the inputs
        and controls the order of the autoregressive factorization. In
        particular, for the identity permutation the autoregressive structure
        is such that the Jacobian is triangular. Defaults to the identity
        permutation.
    :type permutation: array of ints
    :param skip_connections: whether to add skip connections from the input to
        the output.
    :type skip_connections: bool
    :param nonlinearity: the nonlinearity to use in the feedforward network,
        such as ReLU. Note that no nonlinearity is applied to the final network
        output, so the output is an unbounded real number. Defaults to ReLU.
    :type nonlinearity: callable
    :return: a tuple (init_fun, apply_fun)

    Reference:

    MADE: Masked Autoencoder for Distribution Estimation [arXiv:1502.03509]
    Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle
    """
    output_multiplier = sum(param_dims)
    all_ones = (np.array(param_dims) == 1).all()

    # Calculate the indices on the output corresponding to each parameter.
    ends = np.cumsum(np.array(param_dims), axis=0)
    starts = np.concatenate((np.zeros(1), ends[:-1]))
    param_slices = [slice(int(s), int(e)) for s, e in zip(starts, ends)]

    # The hidden dimension must not be less than the input dimension, otherwise
    # it isn't possible to connect to the outputs correctly.
    for h in hidden_dims:
        if h < input_dim:
            raise ValueError(
                'Hidden dimension must not be less than input dimension.')

    if permutation is None:
        permutation = jnp.arange(input_dim)

    # Create masks.
    masks, mask_skip = create_mask(input_dim=input_dim,
                                   hidden_dims=hidden_dims,
                                   permutation=permutation,
                                   output_dim_multiplier=output_multiplier)

    main_layers = []
    # Create masked layers.
    for i, mask in enumerate(masks):
        main_layers.append(MaskedDense(mask))
        if i < len(masks) - 1:
            main_layers.append(nonlinearity)

    if skip_connections:
        net_init, net = stax.serial(
            stax.FanOut(2),
            stax.parallel(stax.serial(*main_layers),
                          MaskedDense(mask_skip, bias=False)),
            stax.FanInSum)
    else:
        net_init, net = stax.serial(*main_layers)

    def init_fun(rng_key, input_shape):
        """
        :param rng_key: rng_key used to initialize parameters
        :param input_shape: input shape
        """
        assert input_dim == input_shape[-1]
        return net_init(rng_key, input_shape)

    def apply_fun(params, inputs, **kwargs):
        """
        :param params: layer parameters
        :param inputs: layer inputs
        """
        out = net(params, inputs, **kwargs)

        # Reshape the output as necessary.
        out = jnp.reshape(out, inputs.shape[:-1] + (output_multiplier, input_dim))
        # Move the param dims to the first dimension.
        out = jnp.moveaxis(out, -2, 0)

        if all_ones:
            # Squeeze the extra dimension if all parameters are one-dimensional.
            out = tuple([out[i] for i in range(output_multiplier)])
        else:
            # If not all ones, then we probably don't want to squeeze a
            # single-dimension parameter.
            out = tuple([out[s] for s in param_slices])

        # If len(param_dims) == 1, return the array instead of a tuple of arrays.
        return out[0] if len(param_dims) == 1 else out

    return init_fun, apply_fun
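# Hedged usage sketch (dimensions are illustrative; assumes `from jax import
# random`): a MADE-style net for a 5-dimensional input producing two
# (input_dim,)-shaped parameters, as used by an inverse autoregressive flow.
rng_key = random.PRNGKey(0)
init_fun, apply_fun = AutoregressiveNN(5, [16, 16], param_dims=[1, 1])
_, params = init_fun(rng_key, (-1, 5))
mean, log_scale = apply_fun(params, jnp.ones((3, 5)))  # each has shape (3, 5)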
print("Starting STAX demo") # Stax version from jax.experimental import stax def const_init(params): def init(rng_key, shape): return params return init #net_init, net_apply = stax.serial(stax.Dense(1), stax.elementwise(sigmoid)) dense_layer = stax.Dense(1, W_init=const_init(np.reshape(w, (D, 1))), b_init=const_init(np.array([0.0]))) net_init, net_apply = stax.serial(dense_layer) rng = jax.random.PRNGKey(0) in_shape = (-1, D) out_shape, net_params = net_init(rng, in_shape) def NLL_model(net_params, net_apply, batch): X, y = batch logits = net_apply(net_params, X) return BCE_with_logits(logits, y) y_pred2 = net_apply(net_params, X_test) loss2 = NLL_model(net_params, net_apply, (X_test, y_test)) grad_jax2 = grad(NLL_model)(net_params, net_apply, (X_test, y_test)) grad_jax3 = grad_jax2[0][0] # layer 0, block 0 (weights not bias) grad_jax4 = grad_jax3[:, 0] # column vector assert np.allclose(grad_np, grad_jax4)
def conv_net(mode="train"): out_dim = 1 dim_nums = ("NHWC", "HWIO", "NHWC") unit_stride = (1, 1) zero_pad = ((0, 0), (0, 0)) # Primary convolutional layer. conv_channels = 32 conv_init, conv_apply = Conv(out_chan=conv_channels, filter_shape=(3, 3), strides=(1, 3), padding=zero_pad) # Group all possible pairs. pair_channels, filter_shape = 256, (1, 2) # Convolutional block with the same number of channels. block_channels = pair_channels conv_block_init, conv_block_apply = serial( Conv(block_channels, (1, 3), unit_stride, "SAME"), Relu, # One block of convolutions. Conv(block_channels, (1, 3), unit_stride, "SAME"), Relu, Conv(block_channels, (1, 3), unit_stride, "SAME")) # Forward pass. hidden_size = 2048 dropout_rate = 0.25 serial_init, serial_apply = serial( Conv(block_channels, (1, 3), (1, 3), zero_pad), Relu, # Using convolution with strides Flatten, Dense(hidden_size), # instead of pooling for downsampling. # Dropout(dropout_rate, mode), Relu, Dense(out_dim)) def init_fun(rng, input_shape): rng, conv_rng, block_rng, serial_rng = jax.random.split(rng, num=4) # Primary convolutional layer. conv_shape, conv_params = conv_init(conv_rng, (-1, ) + input_shape) # Grouping all possible pairs. kernel_shape = [ filter_shape[0], filter_shape[1], conv_channels, pair_channels ] bias_shape = [1, 1, 1, pair_channels] W_init = glorot_normal(in_axis=2, out_axis=3) b_init = normal(1e-6) k1, k2 = jax.random.split(rng) W, b = W_init(k1, kernel_shape), b_init(k2, bias_shape) pair_shape = conv_shape[:2] + (15, ) + (pair_channels, ) pair_params = (W, b) # Convolutional block. conv_block_shape, conv_block_params = conv_block_init( block_rng, pair_shape) # Forward pass. serial_shape, serial_params = serial_init(serial_rng, conv_block_shape) params = [conv_params, pair_params, conv_block_params, serial_params] return serial_shape, params def apply_fun(params, inputs): conv_params, pair_params, conv_block_params, serial_params = params # Apply the primary convolutional layer. conv_out = conv_apply(conv_params, inputs) conv_out = relu(conv_out) # Group all possible pairs. W, b = pair_params pair_1 = conv_general_dilated(conv_out, W, unit_stride, zero_pad, (1, 1), (1, 1), dim_nums) + b pair_2 = conv_general_dilated(conv_out, W, unit_stride, zero_pad, (1, 1), (1, 2), dim_nums) + b pair_3 = conv_general_dilated(conv_out, W, unit_stride, zero_pad, (1, 1), (1, 3), dim_nums) + b pair_4 = conv_general_dilated(conv_out, W, unit_stride, zero_pad, (1, 1), (1, 4), dim_nums) + b pair_5 = conv_general_dilated(conv_out, W, unit_stride, zero_pad, (1, 1), (1, 5), dim_nums) + b pair_out = jnp.dstack([pair_1, pair_2, pair_3, pair_4, pair_5]) pair_out = relu(pair_out) # Convolutional block. conv_block_out = conv_block_apply(conv_block_params, pair_out) # Residual connection. res_out = conv_block_out + pair_out res_out = relu(res_out) # Forward pass. out = serial_apply(serial_params, res_out) return out return init_fun, apply_fun
def main(_):
    rng = random.PRNGKey(0)

    # Load MNIST dataset
    train_images, train_labels, test_images, test_labels = datasets.mnist()

    batch_size = 128
    batch_shape = (-1, 28, 28, 1)
    num_train = train_images.shape[0]
    num_complete_batches, leftover = divmod(num_train, batch_size)
    num_batches = num_complete_batches + bool(leftover)

    train_images = np.reshape(train_images, batch_shape)
    test_images = np.reshape(test_images, batch_shape)

    def data_stream():
        rng = npr.RandomState(0)
        while True:
            perm = rng.permutation(num_train)
            for i in range(num_batches):
                batch_idx = perm[i * batch_size:(i + 1) * batch_size]
                yield train_images[batch_idx], train_labels[batch_idx]

    batches = data_stream()

    # Model, loss, and accuracy functions
    init_random_params, predict = stax.serial(
        stax.Conv(32, (8, 8), strides=(2, 2), padding="SAME"),
        stax.Relu,
        stax.Conv(128, (6, 6), strides=(2, 2), padding="VALID"),
        stax.Relu,
        stax.Conv(128, (5, 5), strides=(1, 1), padding="VALID"),
        stax.Flatten,
        stax.Dense(128),
        stax.Relu,
        stax.Dense(10),
    )

    def loss(params, batch):
        inputs, targets = batch
        preds = predict(params, inputs)
        return -np.mean(logsoftmax(preds) * targets)

    def accuracy(params, batch):
        inputs, targets = batch
        target_class = np.argmax(targets, axis=1)
        predicted_class = np.argmax(predict(params, inputs), axis=1)
        return np.mean(predicted_class == target_class)

    # Instantiate an optimizer
    opt_init, opt_update, get_params = optimizers.adam(0.001)

    @jit
    def update(i, opt_state, batch):
        params = get_params(opt_state)
        return opt_update(i, grad(loss)(params, batch), opt_state)

    # Initialize model
    _, init_params = init_random_params(rng, batch_shape)
    opt_state = opt_init(init_params)
    itercount = itertools.count()

    # Training loop
    print("\nStarting training...")
    for epoch in range(FLAGS.nb_epochs):
        start_time = time.time()
        for _ in range(num_batches):
            opt_state = update(next(itercount), opt_state, next(batches))
        epoch_time = time.time() - start_time

        # Evaluate model on clean data
        params = get_params(opt_state)
        train_acc = accuracy(params, (train_images, train_labels))
        test_acc = accuracy(params, (test_images, test_labels))

        # Evaluate model on adversarial data
        model_fn = lambda images: predict(params, images)
        test_images_fgm = fast_gradient_method(model_fn, test_images, FLAGS.eps,
                                               np.inf)
        test_images_pgd = projected_gradient_descent(model_fn, test_images,
                                                     FLAGS.eps, 0.01, 40, np.inf)
        test_acc_fgm = accuracy(params, (test_images_fgm, test_labels))
        test_acc_pgd = accuracy(params, (test_images_pgd, test_labels))

        print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
        print("Training set accuracy: {}".format(train_acc))
        print("Test set accuracy on clean examples: {}".format(test_acc))
        print("Test set accuracy on FGM adversarial examples: {}".format(
            test_acc_fgm))
        print("Test set accuracy on PGD adversarial examples: {}".format(
            test_acc_pgd))
d_hidden_size = hidden_size
g_out_dim = 2
d_out_dim = 1
z_size = 64
smooth = 0.0

learning_rate = 1e-4
batch_size = 256
epochs = 10001
nsave = 2000

# Define Generator
gen_init, gen_apply = stax.serial(
    Dense(g_hidden_size), Relu,
    Dense(g_hidden_size), Relu,
    Dense(g_hidden_size), Relu,
    Dense(g_hidden_size), Relu,
    Dense(g_hidden_size), Relu,
    Dense(g_hidden_size), Relu,
    Dense(g_out_dim)
)
g_in_shape = (-1, z_size)
_, gen_params = gen_init(rng, g_in_shape)

# Define Discriminator
disc_init, disc_apply = stax.serial(
    Dense(d_hidden_size), Relu,
    Dense(d_hidden_size), Relu,
    Dense(d_hidden_size), Relu,
    Dense(d_hidden_size), Relu,
    Dense(d_hidden_size), Relu,
    Dense(d_hidden_size), Relu,
def action_encoder(output_num):
    return serial(
        Dense(128),
        Tanh,
        # Adding BatchNorm makes the output collapse to a fixed value for some reason.
        Dense(output_num))
def logistic_regression(rng, dim):
    """Logistic regression."""
    init_params, forward = stax.serial(Dense(1), Sigmoid)
    temp, rng = random.split(rng)
    params = init_params(temp, (-1, dim))[1]
    return params, forward
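# Hedged usage sketch (dimensions are illustrative; assumes `from jax import
# random` and `jax.numpy as jnp` as in the surrounding code):
params, forward = logistic_regression(random.PRNGKey(0), dim=3)
probs = forward(params, jnp.ones((4, 3)))  # shape (4, 1), values in (0, 1)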
def value_decoder(output_num):
    return serial(
        Dense(128),
        Tanh,
        # Adding BatchNorm makes the output collapse to a fixed value for some reason.
        Dense(output_num),
    )
    set_step_size(lr_i)
    return opt_update(i, g, opt_state)


key = random.PRNGKey(3)
num_epochs = 60000
num_instances, num_vars = 200, 2
batch_size = num_instances
minim, maxim = -5, 5

x, y = generate_data(num_instances, 1, key)
X = np.c_[np.ones_like(x), x]
batches = data_stream(num_instances, batch_size)

init_random_params, predict = stax.serial(Dense(5), Softplus,
                                          Dense(5), Softplus,
                                          Dense(5), Softplus,
                                          Dense(5), Softplus,
                                          Dense(1))

lambd, step_size = 0.6, 1e-4
opt_init, opt_update, get_params, soft_thresholding, set_step_size = pgd(
    step_size, lambd)
_, init_params = init_random_params(key, (-1, num_vars))
opt_state = opt_init(init_params)
itercount = itertools.count()

for epoch in range(num_epochs):
    opt_state = update(next(itercount), opt_state, next(batches))

labels = {"training": "Data", "test": "Deep Neural Net"}
x_test = np.arange(minim, maxim, 1e-5)
x_test = np.c_[np.ones((x_test.shape[0], 1)), x_test]
def loss(params, batch):
    inputs, targets = batch
    preds = predict(params, inputs)
    return -np.mean(np.sum(preds * targets, axis=1))


def accuracy(params, batch):
    inputs, targets = batch
    target_class = np.argmax(targets, axis=1)
    predicted_class = np.argmax(predict(params, inputs), axis=1)
    return np.mean(predicted_class == target_class)


init_random_params, predict = stax.serial(
    Dense(300), Relu,
    Dense(10), LogSoftmax)

if __name__ == "__main__":
    rng = random.PRNGKey(0)
    step_size = 0.001
    num_epochs = 10
    batch_size = 128
    momentum_mass = 0.9

    # training/test split
    train_images, train_labels, test_images, test_labels = datasets.mnist()
    num_train = train_images.shape[0]

    # converting to batches
def loss(params, batch):
    inputs, targets = batch
    preds = predict(params, inputs)
    return -jnp.mean(jnp.sum(preds * targets, axis=1))


def accuracy(params, batch):
    inputs, targets = batch
    target_class = jnp.argmax(targets, axis=1)
    predicted_class = jnp.argmax(predict(params, inputs), axis=1)
    return jnp.mean(predicted_class == target_class)


init_random_params, predict = stax.serial(Dense(config.num_hidden), Relu,
                                          Dense(config.num_hidden), Relu,
                                          Dense(10), LogSoftmax)

if __name__ == "__main__":
    rng = random.PRNGKey(0)
    num_epochs = 10

    train_images, train_labels, test_images, test_labels = datasets.mnist()
    num_train = train_images.shape[0]
    num_complete_batches, leftover = divmod(num_train, config.batch_size)
    num_batches = num_complete_batches + bool(leftover)

    def data_stream():
        rng = npr.RandomState(0)
        while True:
def make_main(input_shape):
    # The number of output channels depends on the number of input channels.
    return stax.serial(Conv(filters1, (1, 1)), BatchNorm(), Relu,
                       Conv(filters2, (ks, ks), padding='SAME'), BatchNorm(), Relu,
                       Conv(input_shape[3], (1, 1)), BatchNorm())
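# Hedged sketch (not necessarily the surrounding code's exact wrapper): a main
# path like make_main is typically combined with an identity shortcut via
# FanOut/FanInSum, and stax.shape_dependent lets the final Conv match the input
# channel count at init time.
identity_block = stax.serial(
    stax.FanOut(2),
    stax.parallel(stax.shape_dependent(make_main), stax.Identity),
    stax.FanInSum,
    stax.Relu,
)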
plt.figure()
plt.scatter(population_samples[:, 0], population_samples[:, 1])
plt.title("Population")

plt.figure()
plt.scatter(biased_samples[:, 0], biased_samples[:, 1])
plt.title("Biased sample")

plt.show()

encoder_init, encoder = stax.serial(
    Dense(32),
    Relu,
    FanOut(2),
    stax.parallel(
        Dense(latent_dim),
        stax.serial(
            Dense(latent_dim),
            Softplus,
            Dampen(0.1, 1e-6),
        ),
    ),
    DistributionLayer(DiagMVN),
)

decoder_init, decoder = stax.serial(
    Dense(128),
    Relu,
    Dense(128),
    Relu,
    Dense(128),
    Relu,
    Dense(128),
def wide_resnet_group(n, num_channels, strides=(1, 1)):
    blocks = [wide_resnet_block(num_channels, strides, channel_mismatch=True)]
    for _ in range(1, n):
        blocks += [wide_resnet_block(num_channels, strides=(1, 1))]
    return stax.serial(*blocks)
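# Hypothetical sketch of stacking groups into a classifier (channel counts and
# pool size are illustrative, and wide_resnet_block is assumed to handle the
# channel mismatch inside each group):
wrn_init, wrn_apply = stax.serial(
    stax.Conv(16, (3, 3), padding='SAME'),
    wide_resnet_group(2, 32, strides=(1, 1)),
    wide_resnet_group(2, 64, strides=(2, 2)),
    stax.AvgPool((8, 8)),
    stax.Flatten,
    stax.Dense(10),
    stax.LogSoftmax,
)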
def GCNPredicator(hidden_feats,
                  activation=None,
                  batchnorm=None,
                  dropout=None,
                  pooling_method='mean',
                  predicator_hidden_feats=64,
                  predicator_dropout=None,
                  n_out=1,
                  bias=True,
                  normalize=True):
    r"""GCN Predicator is a wrapper function using GCN and MLP.

    Parameters
    ----------
    hidden_feats : list[int]
        List of output node features.
    activation : list[Function] or None
        ``activation[i]`` is the activation function of the i-th GCN layer.
    batchnorm : list[bool] or None
        ``batchnorm[i]`` decides if batch normalization is to be applied on the
        output of the i-th GCN layer. ``len(batchnorm)`` equals the number of
        GCN layers. By default, batch normalization is applied for all GCN layers.
    dropout : list[float] or None
        ``dropout[i]`` decides the dropout probability on the output of the i-th
        GCN layer. ``len(dropout)`` equals the number of GCN layers. By default,
        no dropout is performed for all layers.
    pooling_method : str ('max', 'min', 'mean', 'sum')
        Pooling method name.
    predicator_hidden_feats : int
        Size of hidden graph representations in the predicator, default to 64.
    predicator_dropout : float or None
        The probability for dropout in the predicator, default to None.
    n_out : int
        Number of the output size, default to 1.
    bias : bool
        Whether to add bias after affine transformation, default to be True.
    normalize : bool
        Whether to normalize the adjacency matrix or not, default to be True.

    Returns
    -------
    init_fun : Function
        Initializes the parameters of the layer.
    apply_fun : Function
        Defines the forward computation function.
    """
    gcn_init, gcn_fun = GCN(hidden_feats, activation, batchnorm, dropout, bias,
                            normalize)
    pooling_fun = pooling(method=pooling_method)
    predicator_dropout = 0.0 if predicator_dropout is None else predicator_dropout
    _, drop_fun = Dropout(predicator_dropout)
    dnn_layers = [Dense(predicator_hidden_feats), Relu, Dense(n_out)]
    dnn_init, dnn_fun = serial(*dnn_layers)

    def init_fun(rng, input_shape):
        """Initialize parameters.

        Parameters
        ----------
        rng : PRNGKey
            rng is a value for generating random values.
        input_shape : (batch_size, N, M1)
            The shape of input (input node features).
            N is the total number of nodes in the batch of graphs.
            M1 is the input node feature size.

        Returns
        -------
        output_shape : (batch_size, M4)
            The shape of output. M4 is the output size of GCNPredicator and
            equal to n_out.
        params : Tuple (gcn_param, dnn_param)
            gcn_param is all parameters of GCN.
            dnn_param is all parameters of the fully connected layer.
        """
        output_shape = input_shape
        rng, gcn_rng, dnn_rng = random.split(rng, 3)
        output_shape, gcn_param = gcn_init(gcn_rng, output_shape)
        # convert out_shape by pooling
        output_shape = (output_shape[0], output_shape[-1])
        output_shape, dnn_param = dnn_init(dnn_rng, output_shape)
        return output_shape, (gcn_param, dnn_param)

    def apply_fun(params, node_feats, adj, rng, is_train):
        """Define forward computation function.

        Parameters
        ----------
        node_feats : ndarray of shape (batch_size, N, M1)
            Batched input node features.
            N is the total number of nodes in the batch of graphs.
            M1 is the input node feature size.
        adj : ndarray of shape (batch_size, N, N)
            Batched adjacency matrix.
        rng : PRNGKey
            rng is a value for generating random values.
        is_train : bool
            Whether the model is training or not.

        Returns
        -------
        out : ndarray of shape (batch_size, M4)
            The output. M4 is the output size of GCNPredicator and equal to n_out.
        """
        gcn_param, dnn_param = params
        rng, gcn_rng, dropout_rng = random.split(rng, 3)
        node_feats = gcn_fun(gcn_param, node_feats, adj, gcn_rng, is_train)
        # pooling
        graph_feat = pooling_fun(node_feats)
        if predicator_dropout != 0.0:
            graph_feat = drop_fun(None, graph_feat, is_train, rng=dropout_rng)
        out = dnn_fun(dnn_param, graph_feat)
        return out

    return init_fun, apply_fun
        f = partial(apply_fun_scan, params)
        # We use partial to "clone" all the params for use at every timestep.
        _, out = lax.scan(f, h, inputs)
        return out

    return init_fun, apply_fun


num_dims = 10  # Number of OU timesteps
batch_size = 64  # Batch size
num_hidden_units = 12  # GRU cells in the RNN layer

# Initialize the network and perform a forward pass.
init_fun, gru_rnn = stax.serial(
    Dense(num_hidden_units), Relu,
    GRU(num_hidden_units),
    # This final Dense(1) is applied to every lax.scan output from the GRU
    # loop, which is why the predictions keep the timestep dimension.
    Dense(1)
)
_, params = init_fun(key, (batch_size, num_dims, 1))


def mse_loss(params, inputs, targets):
    """Calculate the mean squared error prediction loss."""
    preds = gru_rnn(params, inputs)
    return np.mean((preds - targets)**2)


@jit
def update(params, x, y, opt_state):
    """Perform a forward pass, calculate the MSE & perform a SGD step."""
    loss, grads = value_and_grad(mse_loss)(params, x, y)
    opt_state = opt_update(0, grads, opt_state)
    preds = predict(params, inputs)
    return -jnp.mean(jnp.sum(preds * targets, axis=1))


def accuracy(params, batch):
    inputs, targets = batch
    target_class = jnp.argmax(targets, axis=1)
    predicted_class = jnp.argmax(predict(params, inputs), axis=1)
    return jnp.mean(predicted_class == target_class)


init_random_params, predict = stax.serial(
    Dense(1024), Tanh,
    Dense(1024), Tanh,
    # Dense(1024), Tanh,
    # Dense(1024), Tanh,
    # Dense(1024), Tanh,
    # Dense(1024), Tanh,
    # Dense(1024), Tanh,
    # Dense(1024), Tanh,
    Dense(10), LogSoftmax)

if __name__ == "__main__":
    wandb.init(project="playing-the-lottery", entity="skainswo")

    rp = RngPooper(random.PRNGKey(0))

    config = wandb.config
    config.learning_rate = 0.001
    config.num_epochs = 100
    return image_grid(nrow, ncol, sampled_images, (28, 28))


def image_grid(nrow, ncol, imagevecs, imshape):
    """Reshape a stack of image vectors into an image grid for plotting."""
    images = iter(imagevecs.reshape((-1, ) + imshape))
    return jnp.vstack([
        jnp.hstack([next(images).T for _ in range(ncol)][::-1])
        for _ in range(nrow)
    ]).T


encoder_init, encode = stax.serial(
    Dense(512), Relu,
    Dense(512), Relu,
    FanOut(2),
    stax.parallel(Dense(10), stax.serial(Dense(10), Softplus)),
)

decoder_init, decode = stax.serial(
    Dense(512), Relu,
    Dense(512), Relu,
    Dense(28 * 28),
)

if __name__ == "__main__":
    step_size = 0.001
    num_epochs = 100
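# Hedged sketch (not necessarily the training code that follows): the two encoder
# heads produced by FanOut(2) are typically read as the mean and a softplus'd
# scale of a diagonal Gaussian and combined via the reparameterization trick.
def sample_latent(rng, enc_params, images):
    mu, sigma = encode(enc_params, images)
    eps = random.normal(rng, mu.shape)
    return mu + sigma * eps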
                   'Ratio of the standard deviation to the clipping norm')
flags.DEFINE_float('l2_norm_clip', 1.0, 'Clipping norm')
flags.DEFINE_integer('batch_size', 256, 'Batch size')
flags.DEFINE_integer('epochs', 60, 'Number of epochs')
flags.DEFINE_integer('seed', 0, 'Seed for jax PRNG')
flags.DEFINE_integer(
    'microbatches', None, 'Number of microbatches '
    '(must evenly divide batch_size)')
flags.DEFINE_string('model_dir', None, 'Model directory')

init_random_params, predict = stax.serial(
    stax.Conv(16, (8, 8), padding='SAME', strides=(2, 2)),
    stax.Relu,
    stax.MaxPool((2, 2), (1, 1)),
    stax.Conv(32, (4, 4), padding='VALID', strides=(2, 2)),
    stax.Relu,
    stax.MaxPool((2, 2), (1, 1)),
    stax.Flatten,
    stax.Dense(32),
    stax.Relu,
    stax.Dense(10),
)


def loss(params, batch):
    inputs, targets = batch
    logits = predict(params, inputs)
    logits = stax.logsoftmax(logits)  # log normalize
    return -jnp.mean(jnp.sum(logits * targets, axis=1))  # cross entropy loss


def accuracy(params, batch):
flags.DEFINE_string('clustering', 'KMeans',
                    'clustering method for projected embeddings')
flags.DEFINE_integer('ppc', 10, 'number of examples picked per cluster')
flags.DEFINE_integer('n_extra', 3000, 'number of extra points')
flags.DEFINE_integer('uncertain_extra', 1000, 'n_uncertain - n_extra')
flags.DEFINE_bool(
    'show_label', True,
    'visualize predicted label at top/left, true at bottom/right')

# BEGIN: define the classifier model
init_fn_0, apply_fn_0 = stax.serial(
    stax.Conv(16, (8, 8), padding='SAME', strides=(2, 2)),
    stax.Tanh,
    stax.MaxPool((2, 2), (1, 1)),
    stax.Conv(32, (4, 4), padding='VALID', strides=(2, 2)),
    stax.Tanh,
    stax.MaxPool((2, 2), (1, 1)),
    stax.Flatten,  # (-1, 800)
    stax.Dense(32),
    stax.Tanh,  # embeddings
)

init_fn_1, apply_fn_1 = stax.serial(
    stax.Dense(10),  # logits
)


def predict(params, inputs):
    params_0 = params[:-1]
    params_1 = params[-1:]
    embeddings = apply_fn_0(params_0, inputs)