def default_agent(obs_spec: specs.Array, action_spec: specs.DiscreteArray):
    """Initialize a DQN agent with default parameters."""
    network_init, network = stax.serial(
        stax.Flatten,
        stax.Dense(50),
        stax.Relu,
        stax.Dense(50),
        stax.Relu,
        stax.Dense(action_spec.num_values),
    )
    _, network_params = network_init(
        random.PRNGKey(seed=1), (-1,) + obs_spec.shape)
    return DQNJAX(
        action_spec=action_spec,
        network=network,
        parameters=network_params,
        batch_size=32,
        discount=0.99,
        replay_capacity=10000,
        min_replay_size=100,
        sgd_period=1,
        target_update_period=4,
        learning_rate=1e-3,
        epsilon=0.05,
        seed=42)

def __init__(self, input_dims, output_dims, scope_var: OrderedDict):
    super(NeuralODE, self).__init__(input_dims, output_dims, scope_var)
    intermediate_dims = 2 * self.output_dims * self.input_dims
    init_random_params, self.predict = stax.serial(
        stax.Flatten,
        stax.Dense(
            intermediate_dims,
            partial(nn.initializers.glorot_normal(), dtype=np.float64),
            partial(nn.initializers.normal(), dtype=np.float64),
        ),
        stax.Sigmoid,
        stax.Dense(
            intermediate_dims,
            partial(nn.initializers.glorot_normal(), dtype=np.float64),
            partial(nn.initializers.normal(), dtype=np.float64),
        ),
        stax.Sigmoid,
        stax.Dense(
            output_dims * input_dims,
            partial(nn.initializers.glorot_normal(), dtype=np.float64),
            partial(nn.initializers.normal(), dtype=np.float64),
        ),
    )
    self.key = random.PRNGKey(py_random.randrange(9999))
    _, init_params = init_random_params(self.key, (1, input_dims))
    scope_var["params"] = init_params

def main(unused_argv):
    # Load the data.
    print('Loading data.')
    x_train, y_train, x_test, y_test = datasets.mnist(permute_train=True)

    # Build the network.
    init_fn, f = stax.serial(
        stax.Dense(2048),
        stax.Tanh,
        stax.Dense(10))

    key = random.PRNGKey(0)
    _, params = init_fn(key, (-1, 784))

    # Linearize the network about its initial parameters.
    f_lin = linearize(f, params)

    # Create and initialize an optimizer for both f and f_lin.
    opt_init, opt_apply, get_params = optimizers.momentum(FLAGS.learning_rate, 0.9)
    opt_apply = jit(opt_apply)

    state = opt_init(params)
    state_lin = opt_init(params)

    # Create a cross-entropy loss function.
    loss = lambda fx, y_hat: -np.mean(stax.logsoftmax(fx) * y_hat)

    # Specialize the loss function to compute gradients for both linearized and
    # full networks.
    grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))
    grad_loss_lin = jit(grad(lambda params, x, y: loss(f_lin(params, x), y)))

    # Train the network.
    print('Training.')
    print('Epoch\tLoss\tLinearized Loss')
    print('------------------------------------------')

    epoch = 0
    steps_per_epoch = 50000 // FLAGS.batch_size

    for i, (x, y) in enumerate(datasets.minibatch(
            x_train, y_train, FLAGS.batch_size, FLAGS.train_epochs)):
        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x, y), state)

        params_lin = get_params(state_lin)
        state_lin = opt_apply(i, grad_loss_lin(params_lin, x, y), state_lin)

        if i % steps_per_epoch == 0:
            print('{}\t{:.4f}\t{:.4f}'.format(
                epoch, loss(f(params, x), y), loss(f_lin(params_lin, x), y)))
            epoch += 1

    # Print out summary data comparing the linear / nonlinear model.
    x, y = x_train[:10000], y_train[:10000]
    util.print_summary('train', y, f(params, x), f_lin(params_lin, x), loss)
    util.print_summary(
        'test', y_test, f(params, x_test), f_lin(params_lin, x_test), loss)

def build_model_stax(output_size,
                     n_dense_units=300,
                     conv_depth=300,
                     n_conv_layers=2,
                     n_dense_layers=0,
                     kernel_size=5,
                     across_batch=False,
                     add_pos_encoding=False,
                     mean_over_pos=False,
                     mode="train"):
    """Build a model with convolutional layers followed by dense layers."""
    del mode
    layers = [
        cnn(conv_depth=conv_depth,
            n_conv_layers=n_conv_layers,
            kernel_size=kernel_size,
            across_batch=across_batch,
            add_pos_encoding=add_pos_encoding)
    ]
    for _ in range(n_dense_layers):
        layers.append(stax.Dense(n_dense_units))
        layers.append(stax.Relu)
    layers.append(stax.Dense(output_size))
    if mean_over_pos:
        layers.append(reduce_layer(jnp.mean, axis=1))
    init_random_params, predict = stax.serial(*layers)
    return init_random_params, predict

def __init__(self, input_dims, output_dims, scope_var: OrderedDict):
    self.input_dims = input_dims
    self.output_dims = output_dims
    self.key = random.PRNGKey(py_random.randrange(9999))
    init_random_params, self.predict = stax.serial(
        stax.Flatten,
        stax.Dense(
            input_dims * 4,
            partial(nn.initializers.glorot_normal(), dtype=np.float64),
            partial(nn.initializers.normal(), dtype=np.float64),
        ),
        stax.Softplus,
        stax.Dense(
            input_dims * 4,
            partial(nn.initializers.glorot_normal(), dtype=np.float64),
            partial(nn.initializers.normal(), dtype=np.float64),
        ),
        stax.Softplus,
        stax.Dense(
            output_dims * 2,
            partial(nn.initializers.glorot_normal(), dtype=np.float64),
            partial(nn.initializers.normal(), dtype=np.float64),
        ),
    )
    _, init_params = init_random_params(self.key, (-1, input_dims))
    scope_var["encoder_params"] = init_params

def decoder(hidden_dim: int, out_dim: int) -> Tuple[Callable, Callable]:
    return stax.serial(
        stax.Dense(hidden_dim, W_init=stax.randn()),
        stax.Softplus,
        stax.Dense(out_dim, W_init=stax.randn()),
        stax.Sigmoid,
    )

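# A minimal usage sketch for the decoder above, assuming the usual imports
# (`from jax import random` and the stax import used throughout); the latent
# size, batch size, and seeds are illustrative, not values from the original.
decoder_init, decode = decoder(hidden_dim=50, out_dim=784)
_, decoder_params = decoder_init(random.PRNGKey(0), (-1, 20))  # (output_shape, params)
z = random.normal(random.PRNGKey(1), (32, 20))
reconstruction = decode(decoder_params, z)  # shape (32, 784), in (0, 1) via Sigmoid
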
def MultiHeadedAttention(  # pylint: disable=invalid-name
        feature_depth, num_heads=8, dropout=1.0, mode='train'):
    """Transformer-style multi-headed attention.

    Args:
      feature_depth: int: depth of embedding
      num_heads: int: number of attention heads
      dropout: float: dropout rate - keep probability
      mode: str: 'train' or 'eval'

    Returns:
      Multi-headed self-attention layer.
    """
    return stax.serial(
        stax.parallel(
            stax.Dense(feature_depth, W_init=xavier_uniform()),
            stax.Dense(feature_depth, W_init=xavier_uniform()),
            stax.Dense(feature_depth, W_init=xavier_uniform()),
            stax.Identity),
        PureMultiHeadedAttention(
            feature_depth, num_heads=num_heads, dropout=dropout, mode=mode),
        stax.Dense(feature_depth, W_init=xavier_uniform()),
    )

def test_kohn_sham_neural_xc_density_mse_converge_tolerance(
        self, density_mse_converge_tolerance, expected_converged):
    init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
        stax.serial(stax.Dense(8), stax.Elu, stax.Dense(1)))
    params_init = init_fn(rng=random.PRNGKey(0))
    states = jit_scf.kohn_sham(
        locations=self.locations,
        nuclear_charges=self.nuclear_charges,
        num_electrons=self.num_electrons,
        num_iterations=3,
        grids=self.grids,
        xc_energy_density_fn=tree_util.Partial(
            xc_energy_density_fn, params=params_init),
        interaction_fn=utils.exponential_coulomb,
        initial_density=self.num_electrons * utils.gaussian(
            grids=self.grids, center=0., sigma=0.5),
        density_mse_converge_tolerance=density_mse_converge_tolerance)
    np.testing.assert_array_equal(states.converged, expected_converged)
    for single_state in scf.state_iterator(states):
        self._test_state(
            single_state,
            self._create_testing_external_potential(utils.exponential_coulomb))

def make_stax_model(self):
    act = getattr(jstax, self._activation)
    layers = []
    for h in self._hiddens:
        layers.extend([jstax.Dense(h), act])
    layers.extend([jstax.Dense(1), _StaxSqueeze()])
    return jstax.serial(*layers)

def main():
    net_init, net_apply = stax.serial(
        stax.Dense(128),
        stax.Softplus,
        stax.Dense(128),
        stax.Softplus,
        stax.Dense(2),
    )
    opt_init, opt_update, get_params = optimizers.adam(1e-3)

    out_shape, net_params = net_init(
        jax.random.PRNGKey(seed=42), input_shape=(-1, 2))
    opt_state = opt_init(net_params)

    loss_history = []
    print("Training...")
    train_step = get_train_step(opt_update, get_params, net_apply)
    for i in range(2000):
        x = sample_batch(size=128)
        loss, opt_state = train_step(i, opt_state, x)
        loss_history.append(loss.item())
    print("Training Finished...")

    plot_gradients(loss_history, opt_state, get_params, net_params, net_apply)

def decoder(hidden_dim, out_dim):
    return stax.serial(
        stax.Dense(hidden_dim, W_init=stax.randn()),
        stax.Softplus,
        stax.Dense(out_dim, W_init=stax.randn()),
        stax.Sigmoid,
    )

def gen_cnn_conv4(output_units=10,
                  W_initializers_str='glorot_normal()',
                  b_initializers_str='normal()'):
    # This is an up-scaled version of the CNN in the Keras tutorial:
    # https://keras.io/examples/cifar10_cnn/
    return stax.serial(
        stax.Conv(out_chan=64, filter_shape=(3, 3),
                  W_init=eval(W_initializers_str), b_init=eval(b_initializers_str)),
        stax.Relu,
        stax.Conv(out_chan=64, filter_shape=(3, 3),
                  W_init=eval(W_initializers_str), b_init=eval(b_initializers_str)),
        stax.Relu,
        stax.MaxPool((2, 2), strides=(2, 2)),
        stax.Conv(out_chan=128, filter_shape=(3, 3),
                  W_init=eval(W_initializers_str), b_init=eval(b_initializers_str)),
        stax.Relu,
        stax.Conv(out_chan=128, filter_shape=(3, 3),
                  W_init=eval(W_initializers_str), b_init=eval(b_initializers_str)),
        stax.Relu,
        stax.MaxPool((2, 2), strides=(2, 2)),
        stax.Flatten,
        stax.Dense(512, W_init=eval(W_initializers_str), b_init=eval(b_initializers_str)),
        stax.Relu,
        stax.Dense(output_units, W_init=eval(W_initializers_str), b_init=eval(b_initializers_str)))

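# A hedged usage sketch for the CNN above; the CIFAR-10-style input shape and
# seed are illustrative assumptions. stax.Conv defaults to NHWC layout, so a
# batch of CIFAR-10 images has shape (batch, 32, 32, 3).
import jax.numpy as jnp
from jax import random

cnn_init, cnn_apply = gen_cnn_conv4(output_units=10)
_, cnn_params = cnn_init(random.PRNGKey(0), (8, 32, 32, 3))
logits = cnn_apply(cnn_params, jnp.zeros((8, 32, 32, 3)))  # shape (8, 10)
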
def encoder(hidden_dim, z_dim):
    return stax.serial(
        stax.Dense(hidden_dim, W_init=stax.randn()),
        stax.Softplus,
        stax.FanOut(2),
        stax.parallel(
            stax.Dense(z_dim, W_init=stax.randn()),
            stax.serial(stax.Dense(z_dim, W_init=stax.randn()), stax.Exp)),
    )

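# A brief sketch of how this encoder's two-headed output is consumed (sizes
# and seeds are illustrative, not from the original code). FanOut(2) feeds the
# same hidden features to both parallel branches, yielding the latent mean
# and, via Exp, a strictly positive standard deviation.
encoder_init, encode = encoder(hidden_dim=50, z_dim=10)
_, encoder_params = encoder_init(random.PRNGKey(0), (-1, 784))
x = random.normal(random.PRNGKey(1), (32, 784))
z_mean, z_std = encode(encoder_params, x)  # each of shape (32, 10)
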
def mlp(args):
    return stax.serial(
        stax.Dense(args.hidden_dim),
        stax.Softplus,
        stax.Dense(args.hidden_dim),
        stax.Softplus,
        stax.Dense(args.output_dim),
    )

def MLP(num_hidden_layers=2,
        hidden_size=512,
        activation_fn=stax.Relu,
        num_output_classes=10):
    layers = [stax.Flatten]
    layers += [stax.Dense(hidden_size), activation_fn] * num_hidden_layers
    layers += [stax.Dense(num_output_classes), stax.LogSoftmax]
    return stax.serial(*layers)

def make_network(num_layers, num_channels):
    layers = []
    for _ in range(num_layers - 1):
        layers.append(stax.Dense(num_channels))
        layers.append(stax.Relu)
    layers.append(stax.Dense(3))
    layers.append(stax.Sigmoid)
    return stax.serial(*layers)

def test_local_density_approximation(self):
    init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
        stax.serial(stax.Dense(16), stax.Elu, stax.Dense(1)))
    init_params = init_fn(rng=random.PRNGKey(0))
    xc_energy_density = xc_energy_density_fn(self.density, init_params)

    # The first layer of the network takes 1 feature (density).
    self.assertEqual(init_params[0][0].shape, (1, 16))
    self.assertEqual(xc_energy_density.shape, (11,))

def test_local_density_approximation_wrong_output_shape(self):
    init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
        stax.serial(stax.Dense(16), stax.Elu, stax.Dense(3)))
    init_params = init_fn(rng=random.PRNGKey(0))
    with self.assertRaisesRegex(
            ValueError,
            r'The output shape of the network '
            r'should be \(-1, 1\) but got \(11, 3\)'):
        xc_energy_density_fn(self.density, init_params)

def encoder(hidden_dim, z_dim, activation='Tanh'):
    activation = getattr(stax, activation)
    encoder_init, encode = stax.serial(
        stax.Dense(hidden_dim), activation,
        # stax.Dense(hidden_dim), activation,
        stax.FanOut(2),
        stax.parallel(stax.Dense(z_dim), stax.Dense(z_dim)),
    )
    return encoder_init, encode

def create_model(nbin, nhidden, nlayer):
    layers = []
    for _ in range(nlayer):
        layers.extend([
            stax.Dense(nhidden),
            stax.LeakyRelu,
            stax.BatchNorm(axis=(0, 1)),
        ])
    layers.extend([stax.Dense(nbin), stax.Softmax])
    return stax.serial(*layers)

def decoder(hidden_dim, x_dim=2, activation='Tanh'):
    activation = getattr(stax, activation)
    decoder_init, decode = stax.serial(
        stax.Dense(hidden_dim), activation,
        # stax.Dense(hidden_dim), activation,
        stax.FanOut(2),
        stax.parallel(stax.Dense(x_dim), stax.Dense(x_dim)),
    )
    return decoder_init, decode

def transform(rng, input_dim, output_dim):
    # `hidden_dim`, `weight_initializer`, and `act` are not parameters here;
    # they are resolved from the enclosing scope.
    init_fun, apply_fun = stax.serial(
        stax.Dense(hidden_dim, weight_initializer, weight_initializer),
        act,
        stax.Dense(hidden_dim, weight_initializer, weight_initializer),
        act,
        stax.Dense(output_dim, weight_initializer, weight_initializer),
    )
    _, params = init_fun(rng, (input_dim,))
    return params, apply_fun

def jaxRbmSpinPhase(hilbert, alpha):
    return stax.serial(
        stax.FanOut(2),
        stax.parallel(
            stax.serial(stax.Dense(alpha * hilbert.size), LogCoshLayer, SumLayer),
            stax.serial(stax.Dense(alpha * hilbert.size), LogCoshLayer, SumLayer),
        ),
        FanInSum2ModPhase,
    )

def mlp(key, input_dim, output_dim, hidden_layers=(64, 32)):
    '''make multilayer perceptron'''
    layers = []
    for hl in hidden_layers:
        layers += [stax.Dense(hl), stax.Tanh]
    layers.append(stax.Dense(output_dim))
    init_fun, apply_fun = stax.serial(*layers)
    params = init_fun(key, (-1, input_dim))[1]
    return apply_fun, params

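# A short usage sketch (key, sizes, and batch are illustrative). Unlike most
# helpers here, this one returns the apply function together with
# already-initialized parameters rather than a stax (init_fun, apply_fun) pair.
from jax import random

key = random.PRNGKey(0)
apply_fun, params = mlp(key, input_dim=3, output_dim=1)
y = apply_fun(params, random.normal(key, (16, 3)))  # shape (16, 1)
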
def get_f_and_L():
    _, apply = stax.serial(
        TexVar(stax.Dense(256), 'z^1', ('W^1', 'b^1')),
        TexVar(stax.Relu, 'y^1'),
        TexVar(stax.Dense(3), 'z^2', ('W^2', 'b^2')))

    def f_(params, x):
        return apply(params, tex_var(x, 'x', True))

    def L_(params, x, y_hat):
        y_hat = tex_var(y_hat, '\\hat y', True)
        return tex_var(-np.sum(y_hat * jax.nn.log_softmax(f_(params, x))), 'L')

    return f_, L_

def gen_cnn_lenet_caffe(output_units=10,
                        W_initializers_str='glorot_normal()',
                        b_initializers_str='normal()'):
    return stax.serial(
        stax.Conv(out_chan=20, filter_shape=(5, 5),
                  W_init=eval(W_initializers_str), b_init=eval(b_initializers_str)),
        stax.Relu,
        stax.MaxPool((2, 2), strides=(2, 2)),
        stax.Conv(out_chan=50, filter_shape=(5, 5),
                  W_init=eval(W_initializers_str), b_init=eval(b_initializers_str)),
        stax.Relu,
        stax.MaxPool((2, 2), strides=(2, 2)),
        stax.Flatten,
        stax.Dense(500, W_init=eval(W_initializers_str), b_init=eval(b_initializers_str)),
        stax.Relu,
        stax.Dense(output_units, W_init=eval(W_initializers_str), b_init=eval(b_initializers_str)))

def get_f_and_nngp():
    _, apply = stax.serial(
        TexVar(stax.Dense(256), 'z^1', ('W^1', 'b^1'), True),
        TexVar(stax.Relu, 'y^1'),
        TexVar(stax.Dense(3), 'z^2', ('W^2', 'b^2')))

    def f_(params, x):
        return apply(params, tex_var(x, 'x', True))

    def nngp_(params, x1, x2):
        x1 = tex_var(x1, 'x^1', True)
        x2 = tex_var(x2, 'x^2', True)
        return tex_var(apply(params, x1) @ apply(params, x2).T, '\\mathcal K')

    return f_, nngp_

def _create_stax_model(num_classes, sample_shape):
    """Creates toy stax model."""
    stax_init_fn, stax_apply_fn = stax.serial(
        stax.Flatten,
        stax.Dense(2 * num_classes),
        stax.Dense(num_classes))
    metrics_fn_map = collections.OrderedDict(accuracy=_accuracy)
    return model.create_model_from_stax(
        stax_init_fn=stax_init_fn,
        stax_apply_fn=stax_apply_fn,
        sample_shape=sample_shape,
        loss_fn=_loss,
        metrics_fn_map=metrics_fn_map)

def fully_connected(num_classes, layers=(64, 64)):
    """Build a fully connected neural network."""
    stack = [stax.Flatten]
    # Concatenate fully connected layers.
    for num_units in layers:
        stack += [stax.Dense(num_units), stax.Relu]
    # Output layer.
    stack += [stax.Dense(num_classes), stax.LogSoftmax]
    return stax.serial(*stack)

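# A hedged training-step sketch for fully_connected; the 784-feature input,
# one-hot labels, and seed are illustrative assumptions. Because the network
# ends in LogSoftmax, cross-entropy is just the negative mean of the selected
# log-probabilities.
import jax.numpy as jnp
from jax import grad, jit, random

init_fun, apply_fun = fully_connected(num_classes=10)
_, params = init_fun(random.PRNGKey(0), (-1, 784))

def cross_entropy(params, x, y_onehot):
    log_probs = apply_fun(params, x)
    return -jnp.mean(jnp.sum(log_probs * y_onehot, axis=-1))

grad_fn = jit(grad(cross_entropy))  # gradients w.r.t. params for an SGD step
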
def encoder(hidden_dim: int, z_dim: int) -> Tuple[Callable, Callable]:
    return stax.serial(
        stax.Dense(hidden_dim, W_init=stax.randn()),
        stax.Softplus,
        stax.FanOut(2),
        stax.parallel(
            stax.Dense(z_dim, W_init=stax.randn()),
            stax.serial(
                stax.Dense(z_dim, W_init=stax.randn()),
                stax.Exp,
            ),
        ),
    )