def decoder(hidden_dim, out_dim):
  return stax.serial(
      stax.Dense(hidden_dim, W_init=stax.randn()),
      stax.Softplus,
      stax.Dense(out_dim, W_init=stax.randn()),
      stax.Sigmoid,
  )
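For reference, a minimal usage sketch of the decoder above (assuming stax is JAX's example stax library, e.g. from jax.experimental import stax, which provides stax.randn(); the dimensions are illustrative, not from the original):

# Usage sketch: initialize and apply the decoder (dimensions illustrative).
init_fun, apply_fun = decoder(hidden_dim=512, out_dim=784)
# init_fun takes an RNG key and an input shape (-1 marks the batch dimension)
# and returns the inferred output shape plus freshly initialized parameters.
out_shape, params = init_fun(random.PRNGKey(0), (-1, 20))
z = jnp.ones((4, 20))           # a batch of 4 latent codes
probs = apply_fun(params, z)    # Sigmoid output in (0, 1), shape (4, 784)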
def build_model_stax(output_size,
                     n_dense_units=300,
                     conv_depth=300,
                     n_conv_layers=2,
                     n_dense_layers=0,
                     kernel_size=5,
                     across_batch=False,
                     add_pos_encoding=False,
                     mean_over_pos=False,
                     mode="train"):
  """Build a model with convolutional layers followed by dense layers."""
  del mode
  layers = [
      cnn(conv_depth=conv_depth,
          n_conv_layers=n_conv_layers,
          kernel_size=kernel_size,
          across_batch=across_batch,
          add_pos_encoding=add_pos_encoding)
  ]
  for _ in range(n_dense_layers):
    layers.append(stax.Dense(n_dense_units))
    layers.append(stax.Relu)
  layers.append(stax.Dense(output_size))
  if mean_over_pos:
    layers.append(reduce_layer(jnp.mean, axis=1))
  init_random_params, predict = stax.serial(*layers)
  return init_random_params, predict
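build_model_stax relies on two helpers defined elsewhere in its module, cnn and reduce_layer. As a hedged illustration of the stax layer contract they follow, a reduce_layer could be written as an (init_fun, apply_fun) pair like this (a sketch under that assumption, not the original implementation):

def reduce_layer(reduce_fn, axis):
  """Parameter-free stax-style layer reducing inputs over one axis."""
  def init_fun(rng, input_shape):
    del rng  # No trainable parameters to initialize.
    output_shape = input_shape[:axis] + input_shape[axis + 1:]
    return output_shape, ()
  def apply_fun(params, inputs, **kwargs):
    del params  # Parameter-free.
    return reduce_fn(inputs, axis=axis)
  return init_fun, apply_fun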
def test_kohn_sham_neural_xc_density_mse_converge_tolerance(
    self, density_mse_converge_tolerance, expected_converged):
  init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
      stax.serial(stax.Dense(8), stax.Elu, stax.Dense(1)))
  params_init = init_fn(rng=random.PRNGKey(0))
  states = jit_scf.kohn_sham(
      locations=self.locations,
      nuclear_charges=self.nuclear_charges,
      num_electrons=self.num_electrons,
      num_iterations=3,
      grids=self.grids,
      xc_energy_density_fn=tree_util.Partial(
          xc_energy_density_fn, params=params_init),
      interaction_fn=utils.exponential_coulomb,
      initial_density=self.num_electrons * utils.gaussian(
          grids=self.grids, center=0., sigma=0.5),
      density_mse_converge_tolerance=density_mse_converge_tolerance)
  np.testing.assert_array_equal(states.converged, expected_converged)
  for single_state in scf.state_iterator(states):
    self._test_state(
        single_state,
        self._create_testing_external_potential(utils.exponential_coulomb))
def test_local_density_approximation_wrong_output_shape(self):
  init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
      stax.serial(stax.Dense(16), stax.Elu, stax.Dense(3)))
  init_params = init_fn(rng=random.PRNGKey(0))
  with self.assertRaisesRegex(
      ValueError,
      r'The output shape of the network '
      r'should be \(-1, 1\) but got \(11, 3\)'):
    xc_energy_density_fn(self.density, init_params)
def test_local_density_approximation(self):
  init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
      stax.serial(stax.Dense(16), stax.Elu, stax.Dense(1)))
  init_params = init_fn(rng=random.PRNGKey(0))
  xc_energy_density = xc_energy_density_fn(self.density, init_params)
  # The first layer of the network takes 1 feature (density).
  self.assertEqual(init_params[0][0].shape, (1, 16))
  self.assertEqual(xc_energy_density.shape, (11,))
def encoder(hidden_dim, z_dim):
  return stax.serial(
      stax.Dense(hidden_dim, W_init=stax.randn()),
      stax.Softplus,
      stax.FanOut(2),
      stax.parallel(
          stax.Dense(z_dim, W_init=stax.randn()),
          stax.serial(stax.Dense(z_dim, W_init=stax.randn()), stax.Exp),
      ),
  )
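The FanOut(2)/parallel combination makes the encoder return a pair of outputs, used as the mean and (via Exp, strictly positive) standard deviation of the approximate posterior in a VAE. A minimal usage sketch under the same stax assumptions as above:

# Usage sketch: the encoder maps a batch of inputs to (mean, std) pairs.
init_fun, apply_fun = encoder(hidden_dim=512, z_dim=20)
out_shape, params = init_fun(random.PRNGKey(1), (-1, 784))
x = jnp.ones((4, 784))
z_mean, z_std = apply_fun(params, x)  # each has shape (4, 20); z_std > 0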
def create_stax_dense_model(only_digits: bool = False,
                            hidden_units: int = 200) -> models.Model:
  """Creates EMNIST dense net with stax."""
  num_classes = 10 if only_digits else 62
  stax_init, stax_apply = stax.serial(
      stax.Flatten,
      stax.Dense(hidden_units), stax.Relu,
      stax.Dense(hidden_units), stax.Relu,
      stax.Dense(num_classes))
  return models.create_model_from_stax(
      stax_init=stax_init,
      stax_apply=stax_apply,
      sample_shape=_STAX_SAMPLE_SHAPE,
      train_loss=_TRAIN_LOSS,
      eval_metrics=_EVAL_METRICS)
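The stax pair can be exercised on its own before wrapping it with models.create_model_from_stax (a sketch; it assumes _STAX_SAMPLE_SHAPE is an EMNIST-like (-1, 28, 28, 1), which is not shown in this snippet):

# Smoke-test sketch for the architecture above (shapes are assumptions).
stax_init, stax_apply = stax.serial(
    stax.Flatten,
    stax.Dense(200), stax.Relu,
    stax.Dense(200), stax.Relu,
    stax.Dense(62))
out_shape, params = stax_init(random.PRNGKey(0), (-1, 28, 28, 1))
logits = stax_apply(params, jnp.ones((8, 28, 28, 1)))  # shape (8, 62)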
def test_kohn_sham_iteration_neural_xc(self, enforce_reflection_symmetry):
  init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
      stax.serial(stax.Dense(8), stax.Elu, stax.Dense(1)))
  params_init = init_fn(rng=random.PRNGKey(0))
  initial_state = self._create_testing_initial_state(
      utils.exponential_coulomb)
  next_state = jit_scf.kohn_sham_iteration(
      state=initial_state,
      num_electrons=self.num_electrons,
      xc_energy_density_fn=tree_util.Partial(
          xc_energy_density_fn, params=params_init),
      interaction_fn=utils.exponential_coulomb,
      enforce_reflection_symmetry=enforce_reflection_symmetry)
  self._test_state(next_state, initial_state)
def main(_):
  # Define the total number of training steps.
  training_iters = 200
  rng = random.PRNGKey(0)
  rng, key = random.split(rng)
  init_random_params, model_apply = stax.serial(
      stax.Dense(256), stax.Relu,
      stax.Dense(256), stax.Relu,
      stax.Dense(2))
  # Initialize the model.
  _, params = init_random_params(rng, (-1, 2))
  # Create the optimizer corresponding to the 0th hyperparameter configuration
  # with the specified amount of training steps.
  # opt = optix.adam(1e-4)
  opt = jax_optix_opt_list.optimizer_for_idx(0, training_iters)
  opt_state = opt.init(params)

  @jax.jit
  def loss_fn(params, batch):
    x, y = batch
    y_hat = model_apply(params, x)
    return jnp.mean(jnp.square(y_hat - y))

  @jax.jit
  def train_step(params, opt_state, batch):
    """Train for a single step."""
    value_and_grad_fn = jax.value_and_grad(loss_fn)
    loss, grad = value_and_grad_fn(params, batch)
    # Note this is not the usual optix API, as we additionally need the
    # parameter values.
    # updates, opt_state = opt.update(grad, opt_state)
    updates, opt_state = opt.update_with_params(grad, params, opt_state)
    new_params = optix.apply_updates(params, updates)
    return new_params, opt_state, loss

  for _ in range(training_iters):
    # Make a random batch of fake data.
    rng, key = random.split(rng)
    inp = random.normal(key, [512, 2]) / 4.
    target = jnp.tanh(1 / (1e-6 + inp))
    # Train the model for one step.
    params, opt_state, loss = train_step(params, opt_state, (inp, target))
    print(loss)
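For comparison, the usual optix pattern that the commented-out lines refer to looks like this (a sketch, given params and grad as in train_step above; unlike update_with_params, the standard update sees only gradients and optimizer state):

# Standard optix pattern (sketch): gradients in, parameter updates out.
opt = optix.adam(1e-4)
opt_state = opt.init(params)
updates, opt_state = opt.update(grad, opt_state)
params = optix.apply_updates(params, updates)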
def main(_):
  # Define the total number of training steps.
  training_iters = 200
  rng = random.PRNGKey(0)
  rng, key = random.split(rng)
  # Construct a model. We are using stax here.
  init_random_params, model_apply = stax.serial(
      stax.Dense(256), stax.Relu,
      stax.Dense(256), stax.Relu,
      stax.Dense(2))
  # Initialize the model.
  _, init_params = init_random_params(rng, (-1, 2))
  # Create the optimizer corresponding to the 0th hyperparameter configuration
  # with the specified amount of training steps.
  opt_init, opt_update, get_params = jax_optimizers_opt_list.optimizer_for_idx(
      0, training_iters)
  # opt_init, opt_update, get_params = optimizers.adam(1e-4)
  # Initialize the optimizer state.
  opt_state = opt_init(init_params)

  @jax.jit
  def loss_fn(params, batch):
    """The loss function."""
    x, y = batch
    y_hat = model_apply(params, x)
    return jnp.mean(jnp.square(y_hat - y))

  @jax.jit
  def train_step(i, opt_state, batch):
    """Train for a single step."""
    params = get_params(opt_state)
    value_and_grad_fn = jax.value_and_grad(loss_fn)
    loss, grad = value_and_grad_fn(params, batch)
    return opt_update(i, grad, opt_state), loss

  for i in range(training_iters):
    # Make a random batch of fake data.
    rng, key = random.split(rng)
    inp = random.normal(key, [512, 2]) / 4.
    target = jnp.tanh(1 / (1e-6 + inp))
    # Train the model for one step.
    opt_state, loss = train_step(i, opt_state, (inp, target))
    print(loss)
def test_create_model_from_stax(self):
  stax_init, stax_apply = stax.serial(stax.Dense(10))
  stax_model = models.create_model_from_stax(
      stax_init=stax_init,
      stax_apply=stax_apply,
      sample_shape=(-1, 2),
      train_loss=train_loss,
      eval_metrics=eval_metrics)
  self.check_model(stax_model)
def test_kohn_sham_neural_xc(self, interaction_fn):
  init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
      stax.serial(stax.Dense(8), stax.Elu, stax.Dense(1)))
  params_init = init_fn(rng=random.PRNGKey(0))
  state = scf.kohn_sham(
      locations=self.locations,
      nuclear_charges=self.nuclear_charges,
      num_electrons=self.num_electrons,
      num_iterations=3,
      grids=self.grids,
      xc_energy_density_fn=tree_util.Partial(
          xc_energy_density_fn, params=params_init),
      interaction_fn=interaction_fn)
  for single_state in scf.state_iterator(state):
    self._test_state(
        single_state,
        self._create_testing_external_potential(interaction_fn))
def test_kohn_sham_iteration_neural_xc_density_loss_gradient_symmetry(self):
  # The network only has one layer.
  # The initial params contain weights with shape (1, 1) and bias (1,).
  init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
      stax.serial(stax.Dense(1)))
  init_params = init_fn(rng=random.PRNGKey(0))
  initial_state = self._create_testing_initial_state(
      utils.exponential_coulomb)
  target_density = (
      utils.gaussian(grids=self.grids, center=-0.5, sigma=1.)
      + utils.gaussian(grids=self.grids, center=0.5, sigma=1.))
  spec, flatten_init_params = np_utils.flatten(init_params)

  def loss(flatten_params, initial_state, target_density):
    state = scf.kohn_sham_iteration(
        state=initial_state,
        num_electrons=self.num_electrons,
        xc_energy_density_fn=tree_util.Partial(
            xc_energy_density_fn,
            params=np_utils.unflatten(spec, flatten_params)),
        interaction_fn=utils.exponential_coulomb,
        enforce_reflection_symmetry=True)
    return jnp.sum(jnp.abs(state.density - target_density)) * utils.get_dx(
        self.grids)

  grad_fn = jax.grad(loss)
  params_grad = grad_fn(
      flatten_init_params,
      initial_state=initial_state,
      target_density=target_density)
  # Check gradient values.
  np.testing.assert_allclose(params_grad, [-1.34137017, 0.], atol=5e-7)
  # Check whether the gradient values match the numerical gradient.
  np.testing.assert_allclose(
      optimize.approx_fprime(
          xk=flatten_init_params,
          f=functools.partial(
              loss,
              initial_state=initial_state,
              target_density=target_density),
          epsilon=1e-9),
      params_grad,
      atol=1e-3)
def test_kohn_sham_iteration_neural_xc_energy_loss_gradient(self):
  # The network only has one layer.
  # The initial params contain weights with shape (1, 1) and bias (1,).
  init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
      stax.serial(stax.Dense(1)))
  init_params = init_fn(rng=random.PRNGKey(0))
  initial_state = self._create_testing_initial_state(
      utils.exponential_coulomb)
  target_energy = 2.
  spec, flatten_init_params = np_utils.flatten(init_params)

  def loss(flatten_params, initial_state, target_energy):
    state = scf.kohn_sham_iteration(
        state=initial_state,
        num_electrons=self.num_electrons,
        xc_energy_density_fn=tree_util.Partial(
            xc_energy_density_fn,
            params=np_utils.unflatten(spec, flatten_params)),
        interaction_fn=utils.exponential_coulomb,
        enforce_reflection_symmetry=True)
    return (state.total_energy - target_energy) ** 2

  grad_fn = jax.grad(loss)
  params_grad = grad_fn(
      flatten_init_params,
      initial_state=initial_state,
      target_energy=target_energy)
  # Check gradient values.
  np.testing.assert_allclose(params_grad, [-8.549952, -14.754195])
  # Check whether the gradient values match the numerical gradient.
  np.testing.assert_allclose(
      optimize.approx_fprime(
          xk=flatten_init_params,
          f=functools.partial(
              loss,
              initial_state=initial_state,
              target_energy=target_energy),
          epsilon=1e-9),
      params_grad,
      atol=2e-3)
flags.DEFINE_integer('seed', 0, 'Seed for jax PRNG')
flags.DEFINE_integer(
    'microbatches', None, 'Number of microbatches '
    '(must evenly divide batch_size)')
flags.DEFINE_string('model_dir', None, 'Model directory')

init_random_params, predict = stax.serial(
    stax.Conv(16, (8, 8), padding='SAME', strides=(2, 2)),
    stax.Relu,
    stax.MaxPool((2, 2), (1, 1)),
    stax.Conv(32, (4, 4), padding='VALID', strides=(2, 2)),
    stax.Relu,
    stax.MaxPool((2, 2), (1, 1)),
    stax.Flatten,
    stax.Dense(32),
    stax.Relu,
    stax.Dense(10),
)


def loss(params, batch):
  inputs, targets = batch
  logits = predict(params, inputs)
  logits = stax.logsoftmax(logits)  # log normalize
  return -jnp.mean(jnp.sum(logits * targets, axis=1))  # cross entropy loss


def accuracy(params, batch):
  inputs, targets = batch
  target_class = jnp.argmax(targets, axis=1)
  predicted_class = jnp.argmax(predict(params, inputs), axis=1)
  return jnp.mean(predicted_class == target_class)
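A quick smoke test of the model, loss, and accuracy above on dummy data (a sketch; it assumes MNIST-shaped NHWC inputs and one-hot targets, consistent with the 10-way output of the network):

# Smoke-test sketch (assumes MNIST-shaped inputs and one-hot targets).
rng = random.PRNGKey(0)
_, params = init_random_params(rng, (-1, 28, 28, 1))
inputs = jnp.ones((2, 28, 28, 1))
targets = jax.nn.one_hot(jnp.array([3, 7]), 10)
print(loss(params, (inputs, targets)))      # scalar cross-entropy loss
print(accuracy(params, (inputs, targets)))  # fraction correct, in [0, 1]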
class StaxTest(jtu.JaxTestCase):

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": f"_shape={shape}",
          "shape": shape
      } for shape in [(2, 3), (5,)]))
  def testRandnInitShape(self, shape):
    key = random.PRNGKey(0)
    out = stax.randn()(key, shape)
    self.assertEqual(out.shape, shape)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": f"_shape={shape}",
          "shape": shape
      } for shape in [(2, 3), (2, 3, 4)]))
  def testGlorotInitShape(self, shape):
    key = random.PRNGKey(0)
    out = stax.glorot()(key, shape)
    self.assertEqual(out.shape, shape)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_channels={}_filter_shape={}_padding={}_strides={}"
              "_input_shape={}".format(channels, filter_shape, padding,
                                       strides, input_shape),
          "channels": channels,
          "filter_shape": filter_shape,
          "padding": padding,
          "strides": strides,
          "input_shape": input_shape
      } for channels in [2, 3]
        for filter_shape in [(1, 1), (2, 3)]
        for padding in ["SAME", "VALID"]
        for strides in [None, (2, 1)]
        for input_shape in [(2, 10, 11, 1)]))
  def testConvShape(self, channels, filter_shape, padding, strides,
                    input_shape):
    init_fun, apply_fun = stax.Conv(
        channels, filter_shape, strides=strides, padding=padding)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_channels={}_filter_shape={}_padding={}_strides={}"
              "_input_shape={}".format(channels, filter_shape, padding,
                                       strides, input_shape),
          "channels": channels,
          "filter_shape": filter_shape,
          "padding": padding,
          "strides": strides,
          "input_shape": input_shape
      } for channels in [2, 3]
        for filter_shape in [(1, 1), (2, 3), (3, 3)]
        for padding in ["SAME", "VALID"]
        for strides in [None, (2, 1), (2, 2)]
        for input_shape in [(2, 10, 11, 1)]))
  def testConvTransposeShape(self, channels, filter_shape, padding, strides,
                             input_shape):
    init_fun, apply_fun = stax.ConvTranspose(
        channels, filter_shape,  # 2D
        strides=strides, padding=padding)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_channels={}_filter_shape={}_padding={}_strides={}"
              "_input_shape={}".format(channels, filter_shape, padding,
                                       strides, input_shape),
          "channels": channels,
          "filter_shape": filter_shape,
          "padding": padding,
          "strides": strides,
          "input_shape": input_shape
      } for channels in [2, 3]
        for filter_shape in [(1,), (2,), (3,)]
        for padding in ["SAME", "VALID"]
        for strides in [None, (1,), (2,)]
        for input_shape in [(2, 10, 1)]))
  def testConv1DTransposeShape(self, channels, filter_shape, padding, strides,
                               input_shape):
    init_fun, apply_fun = stax.Conv1DTranspose(
        channels, filter_shape, strides=strides, padding=padding)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "_out_dim={}_input_shape={}".format(
              out_dim, input_shape),
          "out_dim": out_dim,
          "input_shape": input_shape
      } for out_dim in [3, 4] for input_shape in [(2, 3), (3, 4)]))
  def testDenseShape(self, out_dim, input_shape):
    init_fun, apply_fun = stax.Dense(out_dim)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "_input_shape={}_nonlinear={}".format(
              input_shape, nonlinear),
          "input_shape": input_shape,
          "nonlinear": nonlinear
      } for input_shape in [(2, 3), (2, 3, 4)]
        for nonlinear in ["Relu", "Sigmoid", "Elu", "LeakyRelu"]))
  def testNonlinearShape(self, input_shape, nonlinear):
    init_fun, apply_fun = getattr(stax, nonlinear)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_window_shape={}_padding={}_strides={}_input_shape={}"
              "_maxpool={}_spec={}".format(window_shape, padding, strides,
                                           input_shape, max_pool, spec),
          "window_shape": window_shape,
          "padding": padding,
          "strides": strides,
          "input_shape": input_shape,
          "max_pool": max_pool,
          "spec": spec
      } for window_shape in [(1, 1), (2, 3)]
        for padding in ["VALID"]
        for strides in [None, (2, 1)]
        for input_shape in [(2, 5, 6, 4)]
        for max_pool in [False, True]
        for spec in ["NHWC", "NCHW", "WHNC", "WHCN"]))
  def testPoolingShape(self, window_shape, padding, strides, input_shape,
                       max_pool, spec):
    layer = stax.MaxPool if max_pool else stax.AvgPool
    init_fun, apply_fun = layer(
        window_shape, padding=padding, strides=strides, spec=spec)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": f"_shape={input_shape}",
          "input_shape": input_shape
      } for input_shape in [(2, 3), (2, 3, 4)]))
  def testFlattenShape(self, input_shape):
    init_fun, apply_fun = stax.Flatten
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": f"_input_shape={input_shape}_spec={i}",
          "input_shape": input_shape,
          "spec": spec
      } for input_shape in [(2, 5, 6, 1)]
        for i, spec in enumerate([
            [stax.Conv(3, (2, 2))],
            [stax.Conv(3, (2, 2)), stax.Flatten, stax.Dense(4)]])))
  def testSerialComposeLayersShape(self, input_shape, spec):
    init_fun, apply_fun = stax.serial(*spec)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": f"_input_shape={input_shape}",
          "input_shape": input_shape
      } for input_shape in [(3, 4), (2, 5, 6, 1)]))
  def testDropoutShape(self, input_shape):
    init_fun, apply_fun = stax.Dropout(0.9)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": f"_input_shape={input_shape}",
          "input_shape": input_shape
      } for input_shape in [(3, 4), (2, 5, 6, 1)]))
  def testFanInSum(self, input_shape):
    init_fun, apply_fun = stax.FanInSum
    _CheckShapeAgreement(self, init_fun, apply_fun, [input_shape, input_shape])

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": f"_inshapes={input_shapes}_axis={axis}",
          "input_shapes": input_shapes,
          "axis": axis
      } for input_shapes, axis in [
          ([(2, 3), (2, 1)], 1),
          ([(2, 3), (2, 1)], -1),
          ([(1, 2, 4), (1, 1, 4)], 1),
      ]))
  def testFanInConcat(self, input_shapes, axis):
    init_fun, apply_fun = stax.FanInConcat(axis)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shapes)

  def testIssue182(self):
    key = random.PRNGKey(0)
    init_fun, apply_fun = stax.Softmax
    input_shape = (10, 3)
    inputs = np.arange(30.).astype("float32").reshape(input_shape)
    out_shape, params = init_fun(key, input_shape)
    out = apply_fun(params, inputs)
    assert out_shape == out.shape
    assert np.allclose(np.sum(np.asarray(out), -1), 1.)
  def testBatchNormNoScaleOrCenter(self):
    key = random.PRNGKey(0)
    axes = (0, 1, 2)
    init_fun, apply_fun = stax.BatchNorm(axis=axes, center=False, scale=False)
    input_shape = (4, 5, 6, 7)
    inputs = random_inputs(self.rng(), input_shape)
    out_shape, params = init_fun(key, input_shape)
    out = apply_fun(params, inputs)
    means = np.mean(out, axis=(0, 1, 2))
    std_devs = np.std(out, axis=(0, 1, 2))
    assert np.allclose(means, np.zeros_like(means), atol=1e-4)
    assert np.allclose(std_devs, np.ones_like(std_devs), atol=1e-4)

  def testBatchNormShapeNHWC(self):
    key = random.PRNGKey(0)
    init_fun, apply_fun = stax.BatchNorm(axis=(0, 1, 2))
    input_shape = (4, 5, 6, 7)
    inputs = random_inputs(self.rng(), input_shape)
    out_shape, params = init_fun(key, input_shape)
    out = apply_fun(params, inputs)
    self.assertEqual(out_shape, input_shape)
    beta, gamma = params
    self.assertEqual(beta.shape, (7,))
    self.assertEqual(gamma.shape, (7,))
    self.assertEqual(out_shape, out.shape)

  def testBatchNormShapeNCHW(self):
    key = random.PRNGKey(0)
    # Regression test for https://github.com/google/jax/issues/461
    init_fun, apply_fun = stax.BatchNorm(axis=(0, 2, 3))
    input_shape = (4, 5, 6, 7)
    inputs = random_inputs(self.rng(), input_shape)
    out_shape, params = init_fun(key, input_shape)
    out = apply_fun(params, inputs)
    self.assertEqual(out_shape, input_shape)
    beta, gamma = params
    self.assertEqual(beta.shape, (5,))
    self.assertEqual(gamma.shape, (5,))
    self.assertEqual(out_shape, out.shape)
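_CheckShapeAgreement and random_inputs are helpers defined in the surrounding test module. A hedged sketch of what _CheckShapeAgreement verifies, assuming the standard stax contract (the shape reported by init_fun must match the shape apply_fun actually produces):

def _CheckShapeAgreement(test, init_fun, apply_fun, input_shape):
  # Sketch: initialize, apply to random inputs, and compare shapes.
  key = random.PRNGKey(0)
  result_shape, params = init_fun(key, input_shape)
  inputs = random_inputs(test.rng(), input_shape)
  result = apply_fun(params, inputs, rng=key)  # rng needed by e.g. Dropout
  test.assertEqual(result.shape, result_shape)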