def global_conv_block(num_channels, grids, minval, maxval, downsample_factor):
  """Global convolution block.

  First downsample the input, then apply global conv, finally upsample and
  concatenate with the input. The input itself is one channel in the output.

  Args:
    num_channels: Integer, the number of channels.
    grids: Float numpy array with shape (num_grids,).
    minval: Float, the min value in the uniform sampling for exponential width.
    maxval: Float, the max value in the uniform sampling for exponential width.
    downsample_factor: Integer, the factor of downsampling. The grids are
        downsampled with step size 2 ** downsample_factor.

  Returns:
    (init_fn, apply_fn) pair.
  """
  layers = []
  layers.extend([linear_interpolation_transpose()] * downsample_factor)
  layers.append(exponential_global_convolution(
      num_channels=num_channels - 1,  # one channel is reserved for input.
      grids=grids,
      minval=minval,
      maxval=maxval,
      downsample_factor=downsample_factor))
  layers.extend([linear_interpolation()] * downsample_factor)
  global_conv_path = stax.serial(*layers)
  return stax.serial(
      stax.FanOut(2),
      stax.parallel(stax.Identity, global_conv_path),
      stax.FanInConcat(axis=-1),
  )

def ConvBlock(kernel_size, filters, strides=(2, 2)):
  ks = kernel_size
  filters1, filters2, filters3 = filters
  Main = stax.serial(
      Conv(filters1, (1, 1), strides), BatchNorm(), Relu,
      Conv(filters2, (ks, ks), padding='SAME'), BatchNorm(), Relu,
      Conv(filters3, (1, 1)), BatchNorm())
  Shortcut = stax.serial(Conv(filters3, (1, 1), strides), BatchNorm())
  return stax.serial(FanOut(2), stax.parallel(Main, Shortcut), FanInSum, Relu)

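# Hedged usage sketch for ConvBlock above (assumes the same
# `jax.example_libraries.stax` names -- Conv, BatchNorm, Relu, FanOut,
# FanInSum -- imported as in the block itself; shapes are illustrative).
# With strides (1, 1) the spatial size is preserved and the channel count
# becomes `filters3`.
from jax import random
import jax.numpy as jnp

block_init, block_apply = ConvBlock(3, [64, 64, 256], strides=(1, 1))
out_shape, block_params = block_init(random.PRNGKey(0), (4, 56, 56, 64))  # NHWC input
y = block_apply(block_params, jnp.ones((4, 56, 56, 64)))
print(out_shape, y.shape)  # both (4, 56, 56, 256)
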
def encoder(hidden_dim, z_dim):
    return stax.serial(
        stax.Dense(hidden_dim, W_init=stax.randn()),
        stax.Softplus,
        stax.FanOut(2),
        stax.parallel(
            stax.Dense(z_dim, W_init=stax.randn()),
            stax.serial(stax.Dense(z_dim, W_init=stax.randn()), stax.Exp),
        ),
    )

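# Minimal usage sketch for the encoder above: it maps a batch of inputs to a
# (mu, sigma) pair, with sigma kept positive by the Exp layer. All dimensions
# here are illustrative.
from jax import random
import jax.numpy as jnp

enc_init, enc_apply = encoder(hidden_dim=400, z_dim=20)
_, enc_params = enc_init(random.PRNGKey(0), (-1, 784))
mu, sigma = enc_apply(enc_params, jnp.ones((5, 784)))
print(mu.shape, sigma.shape)  # (5, 20) (5, 20)
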
def build_unet(
    num_filters_list, core_num_filters, activation,
    num_channels=0, grids=None, minval=None, maxval=None,
    apply_negativity_transform=True):
  """Builds U-net.

  This neural network is used to parameterize a many-to-many mapping.

  Args:
    num_filters_list: List of integers, the number of filters for each
        downsampling_block. The number of filters for each upsampling_block is
        in reverse order. For example, if num_filters_list=[16, 32, 64], there
        are 3 downsampling_block with number of filters from left to right:
        16, 32, 64. There are 3 upsampling_block with number of filters from
        left to right: 64, 32, 16.
    core_num_filters: Integer, the number of filters for the convolution layer
        at the bottom of the U-shape structure.
    activation: String, the activation function to use in the network.
    num_channels: Integer, the number of channels.
    grids: Float numpy array with shape (num_grids,).
    minval: Float, the min value in the uniform sampling for exponential width.
    maxval: Float, the max value in the uniform sampling for exponential width.
    apply_negativity_transform: Boolean, whether to add negativity_transform
        at the end.

  Returns:
    (init_fn, apply_fn) pair.
  """
  layer = stax.serial(
      Conv1D(core_num_filters, filter_shape=(3,), padding='SAME'),
      _STAX_ACTIVATION[activation],
      Conv1D(core_num_filters, filter_shape=(3,), padding='SAME'),
      _STAX_ACTIVATION[activation])
  for num_filters in num_filters_list[::-1]:
    layer = _build_unet_shell(layer, num_filters, activation=activation)
  network = stax.serial(
      layer,
      # Use 1x1 convolution filter to aggregate channels.
      Conv1D(1, filter_shape=(1,), padding='SAME'))
  layers_before_network = []
  if num_channels > 0:
    layers_before_network.append(
        exponential_global_convolution(num_channels, grids, minval, maxval))
  if apply_negativity_transform:
    return stax.serial(
        *layers_before_network, network, negativity_transform())
  else:
    return stax.serial(*layers_before_network, network)

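# Hedged construction sketch for build_unet, mirroring the docstring's
# example arguments (assumes the surrounding helpers such as Conv1D,
# _STAX_ACTIVATION, _build_unet_shell and negativity_transform are in scope).
# With num_channels=0 the optional global-convolution front end is skipped.
unet_init_fn, unet_apply_fn = build_unet(
    num_filters_list=[16, 32, 64],
    core_num_filters=128,
    activation='softplus')
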
def build_model_stax(output_size,
                     n_dense_units=300,
                     conv_depth=300,
                     n_conv_layers=2,
                     n_dense_layers=0,
                     kernel_size=5,
                     across_batch=False,
                     add_pos_encoding=False,
                     mean_over_pos=False,
                     mode="train"):
  """Build a model with convolutional layers followed by dense layers."""
  del mode
  layers = [
      cnn(conv_depth=conv_depth,
          n_conv_layers=n_conv_layers,
          kernel_size=kernel_size,
          across_batch=across_batch,
          add_pos_encoding=add_pos_encoding)
  ]
  for _ in range(n_dense_layers):
    layers.append(stax.Dense(n_dense_units))
    layers.append(stax.Relu)
  layers.append(stax.Dense(output_size))
  if mean_over_pos:
    layers.append(reduce_layer(jnp.mean, axis=1))
  init_random_params, predict = stax.serial(*layers)
  return init_random_params, predict

def _build_unet_shell(layer, num_filters, activation):
  """Builds a shell in the U-net structure.

    *--------------*
    |              |                       *--------*     *----------------------*
    | downsampling |-----------------------|        |     |                      |
    |    block     |     *------------*    | concat |-----|   upsampling block   |
    |              |-----|   layer    |----|        |     |                      |
    *--------------*     *------------*    *--------*     *----------------------*

  Args:
    layer: (init_fn, apply_fn) pair in the bottom of the U-shape structure.
    num_filters: Integer, the number of filters used for downsampling and
        upsampling.
    activation: String, the activation function to use in the network.

  Returns:
    (init_fn, apply_fn) pair.
  """
  return stax.serial(
      downsampling_block(num_filters, activation=activation),
      stax.FanOut(2),
      stax.parallel(stax.Identity, layer),
      stax.FanInConcat(axis=-1),
      upsampling_block(num_filters, activation=activation)
  )

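# Minimal sketch of the same skip-connection pattern built from plain stax
# layers, to show how FanOut / parallel / FanInConcat wire an identity branch
# around an inner layer (the Dense layer and shapes are illustrative only).
from jax import random
import jax.numpy as jnp
from jax.example_libraries import stax

inner = stax.Dense(4)
skip_init, skip_apply = stax.serial(
    stax.FanOut(2),
    stax.parallel(stax.Identity, inner),
    stax.FanInConcat(axis=-1))
_, skip_params = skip_init(random.PRNGKey(0), (-1, 4))
out = skip_apply(skip_params, jnp.ones((2, 4)))
print(out.shape)  # (2, 8): the input concatenated with the inner branch output
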
def serial(*layers: Layer) -> InternalLayer:
  """Combinator for composing layers in serial.

  Based on `jax.example_libraries.stax.serial`.

  Args:
    *layers:
      a sequence of layers, each an `(init_fn, apply_fn, kernel_fn)` triple.

  Returns:
    A new layer, meaning an `(init_fn, apply_fn, kernel_fn)` triple,
    representing the serial composition of the given sequence of layers.
  """
  init_fns, apply_fns, kernel_fns = zip(*layers)
  init_fn, apply_fn = ostax.serial(*zip(init_fns, apply_fns))

  @requires(**_get_input_req_attr(kernel_fns, fold=op.rshift))
  def kernel_fn(k: NTTree[Kernel], **kwargs) -> NTTree[Kernel]:
    # TODO(xlc): if we drop `x1_is_x2` and use `rng` instead, need split key
    # inside kernel functions here and parallel below.
    for f in kernel_fns:
      k = f(k, **kwargs)
    return k

  return init_fn, apply_fn, kernel_fn

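# Hedged sketch of how this combinator is typically reached through the public
# `neural_tangents.stax` API: `serial` yields an (init_fn, apply_fn, kernel_fn)
# triple, and kernel_fn evaluates the infinite-width kernel. Layer widths and
# data shapes below are illustrative.
from jax import random
from neural_tangents import stax as nt_stax

nt_init_fn, nt_apply_fn, nt_kernel_fn = nt_stax.serial(
    nt_stax.Dense(512), nt_stax.Relu(), nt_stax.Dense(1))
x1 = random.normal(random.PRNGKey(1), (3, 8))
x2 = random.normal(random.PRNGKey(2), (4, 8))
nngp = nt_kernel_fn(x1, x2, 'nngp')  # (3, 4) NNGP kernel matrix
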
def decoder(hidden_dim, out_dim):
    return stax.serial(
        stax.Dense(hidden_dim, W_init=stax.randn()),
        stax.Softplus,
        stax.Dense(out_dim, W_init=stax.randn()),
        stax.Sigmoid,
    )

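# Minimal usage sketch for the decoder above, as it would be used in a VAE:
# a latent code z (here a random stand-in for a reparameterized sample from
# the encoder) is mapped back to the data space; dimensions are illustrative.
from jax import random

dec_init, dec_apply = decoder(hidden_dim=400, out_dim=784)
_, dec_params = dec_init(random.PRNGKey(1), (-1, 20))
z = random.normal(random.PRNGKey(2), (5, 20))  # stands in for mu + sigma * eps
x_recon = dec_apply(dec_params, z)  # (5, 784), values in (0, 1) via the Sigmoid
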
def test_kohn_sham_neural_xc_density_mse_converge_tolerance(
    self, density_mse_converge_tolerance, expected_converged):
  init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
      stax.serial(stax.Dense(8), stax.Elu, stax.Dense(1)))
  params_init = init_fn(rng=random.PRNGKey(0))
  states = jit_scf.kohn_sham(
      locations=self.locations,
      nuclear_charges=self.nuclear_charges,
      num_electrons=self.num_electrons,
      num_iterations=3,
      grids=self.grids,
      xc_energy_density_fn=tree_util.Partial(
          xc_energy_density_fn, params=params_init),
      interaction_fn=utils.exponential_coulomb,
      initial_density=self.num_electrons * utils.gaussian(
          grids=self.grids, center=0., sigma=0.5),
      density_mse_converge_tolerance=density_mse_converge_tolerance)

  np.testing.assert_array_equal(states.converged, expected_converged)

  for single_state in scf.state_iterator(states):
    self._test_state(
        single_state,
        self._create_testing_external_potential(utils.exponential_coulomb))

def test_create_model_from_stax(self):
  stax_init, stax_apply = stax.serial(stax.Dense(10))
  stax_model = models.create_model_from_stax(
      stax_init=stax_init,
      stax_apply=stax_apply,
      sample_shape=(-1, 2),
      train_loss=train_loss,
      eval_metrics=eval_metrics)
  self.check_model(stax_model)

def test_local_density_approximation_wrong_output_shape(self):
  init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
      stax.serial(stax.Dense(16), stax.Elu, stax.Dense(3)))
  init_params = init_fn(rng=random.PRNGKey(0))

  with self.assertRaisesRegex(
      ValueError,
      r'The output shape of the network '
      r'should be \(-1, 1\) but got \(11, 3\)'):
    xc_energy_density_fn(self.density, init_params)

def test_local_density_approximation(self):
  init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
      stax.serial(stax.Dense(16), stax.Elu, stax.Dense(1)))
  init_params = init_fn(rng=random.PRNGKey(0))
  xc_energy_density = xc_energy_density_fn(self.density, init_params)

  # The first layer of the network takes 1 feature (density).
  self.assertEqual(init_params[0][0].shape, (1, 16))
  self.assertEqual(xc_energy_density.shape, (11,))

def UNetBlock(filters, kernel_size, inner_block, **kwargs):
  def make_main(input_shape):
    return stax.serial(
        UnbiasedConv(filters, kernel_size, **kwargs),
        inner_block,
        UnbiasedConvTranspose(input_shape[3], kernel_size, **kwargs),
    )

  Main = stax.shape_dependent(make_main)
  return stax.serial(
      stax.FanOut(2), stax.parallel(Main, stax.Identity), stax.FanInSum)

def IdentityBlock(kernel_size, filters):
  ks = kernel_size
  filters1, filters2 = filters

  def make_main(input_shape):
    # the number of output channels depends on the number of input channels
    return stax.serial(
        Conv(filters1, (1, 1)), BatchNorm(), Relu,
        Conv(filters2, (ks, ks), padding='SAME'), BatchNorm(), Relu,
        Conv(input_shape[3], (1, 1)), BatchNorm())

  Main = stax.shape_dependent(make_main)
  return stax.serial(FanOut(2), stax.parallel(Main, Identity), FanInSum, Relu)

def BlockNeuralAutoregressiveNN(input_dim, hidden_factors=[8, 8], residual=None):
    """
    An implementation of Block Neural Autoregressive neural network.

    **References**

    1. *Block Neural Autoregressive Flow*,
       Nicola De Cao, Ivan Titov, Wilker Aziz

    :param int input_dim: The dimensionality of the input.
    :param list hidden_factors: Hidden layer i has ``hidden_factors[i]`` hidden units per
        input dimension. This corresponds to both :math:`a` and :math:`b` in reference [1].
        The elements of hidden_factors must be integers.
    :param str residual: Type of residual connections to use. One of `None`,
        `"normal"`, `"gated"`.
    :return: an (`init_fn`, `apply_fn`) pair.
    """
    layers = []
    in_factor = 1
    for hidden_factor in hidden_factors:
        layers.append(BlockMaskedDense(input_dim, in_factor, hidden_factor))
        layers.append(Tanh())
        in_factor = hidden_factor
    layers.append(BlockMaskedDense(input_dim, in_factor, 1))
    arn = stax.serial(*layers)
    if residual is not None:
        FanInResidual = (
            FanInResidualGated if residual == "gated" else FanInResidualNormal
        )
        arn = stax.serial(
            stax.FanOut(2), stax.parallel(arn, stax.Identity), FanInResidual()
        )

    def init_fun(rng, input_shape):
        return arn[0](rng, input_shape)

    def apply_fun(params, inputs, **kwargs):
        out, logdet = arn[1](params, (inputs, None), **kwargs)
        return out, logdet.reshape(inputs.shape)

    return init_fun, apply_fun

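# Hedged usage sketch (assumes the surrounding numpyro helpers such as
# BlockMaskedDense and Tanh are importable). The network maps a batch of
# inputs to a transformed batch plus a per-element log-determinant term of
# the same shape; the dimensions below are illustrative.
from jax import random

bnaf_init, bnaf_apply = BlockNeuralAutoregressiveNN(input_dim=5, hidden_factors=[8, 8])
_, bnaf_params = bnaf_init(random.PRNGKey(0), (10, 5))
x = random.normal(random.PRNGKey(1), (10, 5))
out, logdet = bnaf_apply(bnaf_params, x)  # out.shape == logdet.shape == (10, 5)
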
def create_stax_dense_model(only_digits: bool = False,
                            hidden_units: int = 200) -> models.Model:
  """Creates EMNIST dense net with stax."""
  num_classes = 10 if only_digits else 62
  stax_init, stax_apply = stax.serial(
      stax.Flatten,
      stax.Dense(hidden_units), stax.Relu,
      stax.Dense(hidden_units), stax.Relu,
      stax.Dense(num_classes))
  return models.create_model_from_stax(
      stax_init=stax_init,
      stax_apply=stax_apply,
      sample_shape=_STAX_SAMPLE_SHAPE,
      train_loss=_TRAIN_LOSS,
      eval_metrics=_EVAL_METRICS)

def test_masked_dense(input_dim):
    hidden_dim = input_dim * 3
    output_dim_multiplier = input_dim - 4
    mask, _ = create_mask(
        input_dim, [hidden_dim], np.random.permutation(input_dim),
        output_dim_multiplier)
    init_random_params, masked_dense = serial(MaskedDense(mask[0]))

    rng_key = random.PRNGKey(0)
    batch_size = 4
    input_shape = (batch_size, input_dim)
    _, init_params = init_random_params(rng_key, input_shape)
    output = masked_dense(init_params, np.random.rand(*input_shape))
    assert output.shape == (batch_size, hidden_dim)

def main(_):
  # Define the total number of training steps
  training_iters = 200

  rng = random.PRNGKey(0)
  rng, key = random.split(rng)
  init_random_params, model_apply = stax.serial(
      stax.Dense(256), stax.Relu, stax.Dense(256), stax.Relu, stax.Dense(2))

  # init the model
  _, params = init_random_params(rng, (-1, 2))

  # Create the optimizer corresponding to the 0th hyperparameter configuration
  # with the specified amount of training steps.
  # opt = optix.adam(1e-4)
  opt = jax_optix_opt_list.optimizer_for_idx(0, training_iters)
  opt_state = opt.init(params)

  @jax.jit
  def loss_fn(params, batch):
    x, y = batch
    y_hat = model_apply(params, x)
    return jnp.mean(jnp.square(y_hat - y))

  @jax.jit
  def train_step(params, opt_state, batch):
    """Train for a single step."""
    value_and_grad_fn = jax.value_and_grad(loss_fn)
    loss, grad = value_and_grad_fn(params, batch)
    # Note this is not the usual optix api as we additionally need parameter
    # values.
    # updates, opt_state = opt.update(grad, opt_state)
    updates, opt_state = opt.update_with_params(grad, params, opt_state)
    new_params = optix.apply_updates(params, updates)
    return new_params, opt_state, loss

  for _ in range(training_iters):
    # make a random batch of fake data
    rng, key = random.split(rng)
    inp = random.normal(key, [512, 2]) / 4.
    target = jnp.tanh(1 / (1e-6 + inp))

    # train the model a step
    params, opt_state, loss = train_step(params, opt_state, (inp, target))
    print(loss)

def test_kohn_sham_iteration_neural_xc(self, enforce_reflection_symmetry):
  init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
      stax.serial(stax.Dense(8), stax.Elu, stax.Dense(1)))
  params_init = init_fn(rng=random.PRNGKey(0))
  initial_state = self._create_testing_initial_state(
      utils.exponential_coulomb)
  next_state = jit_scf.kohn_sham_iteration(
      state=initial_state,
      num_electrons=self.num_electrons,
      xc_energy_density_fn=tree_util.Partial(
          xc_energy_density_fn, params=params_init),
      interaction_fn=utils.exponential_coulomb,
      enforce_reflection_symmetry=enforce_reflection_symmetry)
  self._test_state(next_state, initial_state)

def main(_):
  # Define the total number of training steps
  training_iters = 200

  rng = random.PRNGKey(0)
  rng, key = random.split(rng)

  # Construct a model. We are using stax here.
  init_random_params, model_apply = stax.serial(
      stax.Dense(256), stax.Relu, stax.Dense(256), stax.Relu, stax.Dense(2))

  # init the model
  _, init_params = init_random_params(rng, (-1, 2))

  # Create the optimizer corresponding to the 0th hyperparameter configuration
  # with the specified amount of training steps.
  opt_init, opt_update, get_params = jax_optimizers_opt_list.optimizer_for_idx(
      0, training_iters)
  # opt_init, opt_update, get_params = optimizers.adam(1e-4)

  # Initialize the optimizer state
  opt_state = opt_init(init_params)

  @jax.jit
  def loss_fn(params, batch):
    """The loss function."""
    x, y = batch
    y_hat = model_apply(params, x)
    return jnp.mean(jnp.square(y_hat - y))

  @jax.jit
  def train_step(i, opt_state, batch):
    """Train for a single step."""
    params = get_params(opt_state)
    value_and_grad_fn = jax.value_and_grad(loss_fn)
    loss, grad = value_and_grad_fn(params, batch)
    return opt_update(i, grad, opt_state), loss

  for i in range(training_iters):
    # make a random batch of fake data
    rng, key = random.split(rng)
    inp = random.normal(key, [512, 2]) / 4.
    target = jnp.tanh(1 / (1e-6 + inp))

    # train the model a step
    opt_state, loss = train_step(i, opt_state, (inp, target))
    print(loss)

def test_kohn_sham_neural_xc(self, interaction_fn):
  init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
      stax.serial(stax.Dense(8), stax.Elu, stax.Dense(1)))
  params_init = init_fn(rng=random.PRNGKey(0))
  state = scf.kohn_sham(
      locations=self.locations,
      nuclear_charges=self.nuclear_charges,
      num_electrons=self.num_electrons,
      num_iterations=3,
      grids=self.grids,
      xc_energy_density_fn=tree_util.Partial(
          xc_energy_density_fn, params=params_init),
      interaction_fn=interaction_fn)
  for single_state in scf.state_iterator(state):
    self._test_state(
        single_state,
        self._create_testing_external_potential(interaction_fn))

def test_kohn_sham_iteration_neural_xc_density_loss_gradient_symmetry(self):
  # The network only has one layer.
  # The initial params contains weights with shape (1, 1) and bias (1,).
  init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
      stax.serial(stax.Dense(1)))
  init_params = init_fn(rng=random.PRNGKey(0))
  initial_state = self._create_testing_initial_state(
      utils.exponential_coulomb)
  target_density = (
      utils.gaussian(grids=self.grids, center=-0.5, sigma=1.)
      + utils.gaussian(grids=self.grids, center=0.5, sigma=1.))
  spec, flatten_init_params = np_utils.flatten(init_params)

  def loss(flatten_params, initial_state, target_density):
    state = scf.kohn_sham_iteration(
        state=initial_state,
        num_electrons=self.num_electrons,
        xc_energy_density_fn=tree_util.Partial(
            xc_energy_density_fn,
            params=np_utils.unflatten(spec, flatten_params)),
        interaction_fn=utils.exponential_coulomb,
        enforce_reflection_symmetry=True)
    return jnp.sum(jnp.abs(state.density - target_density)) * utils.get_dx(
        self.grids)

  grad_fn = jax.grad(loss)

  params_grad = grad_fn(
      flatten_init_params,
      initial_state=initial_state,
      target_density=target_density)

  # Check gradient values.
  np.testing.assert_allclose(params_grad, [-1.34137017, 0.], atol=5e-7)

  # Check whether the gradient values match the numerical gradient.
  np.testing.assert_allclose(
      optimize.approx_fprime(
          xk=flatten_init_params,
          f=functools.partial(
              loss,
              initial_state=initial_state,
              target_density=target_density),
          epsilon=1e-9),
      params_grad,
      atol=1e-3)

def downsampling_block(num_filters, activation):
  """A downsampling block.

  Given the input feature with spatial dimension size `m`, applies a
  convolution and then reduces the size to `(m + 1) / 2`.

  Args:
    num_filters: Integer, the number of filters of the convolution layer.
    activation: String, the activation function to use in the network.

  Returns:
    (init_fn, apply_fn) pair.
  """
  return stax.serial(
      Conv1D(num_filters, filter_shape=(3,), padding='SAME'),
      linear_interpolation_transpose(),
      _STAX_ACTIVATION[activation],
  )

def upsampling_block(num_filters, activation):
  """An upsampling block.

  Given the input feature with spatial dimension size `m`, upsamples the size
  to `2 * m - 1` and then applies a convolution.

  Args:
    num_filters: Integer, the number of filters of the convolution layer.
    activation: String, the activation function to use in the network.

  Returns:
    (init_fn, apply_fn) pair.
  """
  return stax.serial(
      linear_interpolation(),
      Conv1D(num_filters, filter_shape=(3,), padding='SAME'),
      _STAX_ACTIVATION[activation],
  )

def test_global_functional_with_sliding_net_wrong_output_shape(self):
  init_fn, xc_energy_density_fn = (
      neural_xc.global_functional(
          stax.serial(
              neural_xc.build_sliding_net(
                  window_size=3,
                  num_filters_list=[4, 2, 2],
                  activation='softplus'),
              # Additional conv layer to make the output shape wrong.
              neural_xc.Conv1D(
                  1, filter_shape=(1,), strides=(2,), padding='VALID')),
          grids=self.grids))
  init_params = init_fn(rng=random.PRNGKey(0))

  with self.assertRaisesRegex(
      ValueError,
      r'The output shape of the network '
      r'should be \(-1, 17\) but got \(1, 9\)'):
    xc_energy_density_fn(self.density, init_params)

def ResNet50(num_classes):
  return stax.serial(
      GeneralConv(('HWCN', 'OIHW', 'NHWC'), 64, (7, 7), (2, 2), 'SAME'),
      BatchNorm(), Relu, MaxPool((3, 3), strides=(2, 2)),
      ConvBlock(3, [64, 64, 256], strides=(1, 1)),
      IdentityBlock(3, [64, 64]),
      IdentityBlock(3, [64, 64]),
      ConvBlock(3, [128, 128, 512]),
      IdentityBlock(3, [128, 128]),
      IdentityBlock(3, [128, 128]),
      IdentityBlock(3, [128, 128]),
      ConvBlock(3, [256, 256, 1024]),
      IdentityBlock(3, [256, 256]),
      IdentityBlock(3, [256, 256]),
      IdentityBlock(3, [256, 256]),
      IdentityBlock(3, [256, 256]),
      IdentityBlock(3, [256, 256]),
      ConvBlock(3, [512, 512, 2048]),
      IdentityBlock(3, [512, 512]),
      IdentityBlock(3, [512, 512]),
      AvgPool((7, 7)),
      Flatten,
      Dense(num_classes),
      LogSoftmax)

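# Hedged usage sketch for ResNet50 above: because the leading GeneralConv uses
# the 'HWCN' input spec, the batch dimension comes last in the input shape,
# while the logits come out as (batch, num_classes). Batch size 8 and 224x224
# inputs are illustrative.
from jax import random
import jax.numpy as jnp

resnet_init, resnet_apply = ResNet50(num_classes=1000)
out_shape, resnet_params = resnet_init(random.PRNGKey(0), (224, 224, 3, 8))
logits = resnet_apply(resnet_params, jnp.zeros((224, 224, 3, 8)))
print(out_shape, logits.shape)  # (8, 1000) for both
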
def wrap_network_with_self_interaction_layer(network, grids, interaction_fn):
  """Wraps a network with self-interaction layer.

  Args:
    network: an (init_fn, apply_fn) pair.
        * init_fn: The init_fn of the neural network. It takes an rng key and
              an input shape and returns an (output_shape, params) pair.
        * apply_fn: The apply_fn of the neural network. It takes params,
              inputs, and an rng key and applies the layer.
    grids: Float numpy array with shape (num_grids,).
    interaction_fn: function takes displacements and returns float numpy
        array with the same shape of displacements.

  Returns:
    (init_fn, apply_fn) pair.
  """
  return stax.serial(
      stax.FanOut(2),
      stax.parallel(stax.Identity, network),
      self_interaction_layer(grids, interaction_fn),
  )

def test_kohn_sham_iteration_neural_xc_energy_loss_gradient(self):
  # The network only has one layer.
  # The initial params contains weights with shape (1, 1) and bias (1,).
  init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
      stax.serial(stax.Dense(1)))
  init_params = init_fn(rng=random.PRNGKey(0))
  initial_state = self._create_testing_initial_state(
      utils.exponential_coulomb)
  target_energy = 2.
  spec, flatten_init_params = np_utils.flatten(init_params)

  def loss(flatten_params, initial_state, target_energy):
    state = scf.kohn_sham_iteration(
        state=initial_state,
        num_electrons=self.num_electrons,
        xc_energy_density_fn=tree_util.Partial(
            xc_energy_density_fn,
            params=np_utils.unflatten(spec, flatten_params)),
        interaction_fn=utils.exponential_coulomb,
        enforce_reflection_symmetry=True)
    return (state.total_energy - target_energy) ** 2

  grad_fn = jax.grad(loss)

  params_grad = grad_fn(
      flatten_init_params,
      initial_state=initial_state,
      target_energy=target_energy)

  # Check gradient values.
  np.testing.assert_allclose(params_grad, [-8.549952, -14.754195])

  # Check whether the gradient values match the numerical gradient.
  np.testing.assert_allclose(
      optimize.approx_fprime(
          xk=flatten_init_params,
          f=functools.partial(
              loss,
              initial_state=initial_state,
              target_energy=target_energy),
          epsilon=1e-9),
      params_grad,
      atol=2e-3)

def create_mlp(input_dim, num_channels, output_dim, omega=30):
    modules = []
    modules.append(
        layer.Dense(num_channels[0], W_init=siren_init_first(), b_init=bias_uniform())
    )
    modules.append(layer.Sine(omega))
    for nc in num_channels:
        modules.append(
            layer.Dense(nc, W_init=siren_init(omega=omega), b_init=bias_uniform())
        )
        modules.append(layer.Sine(omega))
    modules.append(
        layer.Dense(output_dim, W_init=siren_init(omega=omega), b_init=bias_uniform())
    )

    net_init_random, net_apply = stax.serial(*modules)

    in_shape = (-1, input_dim)
    rng = create_random_generator()
    out_shape, net_params = net_init_random(rng, in_shape)

    return net_params, net_apply

def build_global_local_conv_net(
    num_global_filters,
    num_local_filters,
    num_local_conv_layers,
    activation,
    grids,
    minval,
    maxval,
    downsample_factor,
    apply_negativity_transform=True):
  """Builds global-local convolutional network.

  Args:
    num_global_filters: Integer, the number of global filters in one cell.
    num_local_filters: Integer, the number of local filters in one cell.
    num_local_conv_layers: Integer, the number of local convolution layers in
        one cell.
    activation: String, the activation function to use in the network.
    grids: Float numpy array with shape (num_grids,).
    minval: Float, the min value in the uniform sampling for exponential width.
    maxval: Float, the max value in the uniform sampling for exponential width.
    downsample_factor: Integer, the factor of downsampling. The grids are
        downsampled with step size 2 ** downsample_factor.
    apply_negativity_transform: Boolean, whether to add negativity_transform
        at the end.

  Returns:
    (init_fn, apply_fn) pair.
  """
  layers = []
  layers.append(
      global_conv_block(
          num_channels=num_global_filters,
          grids=grids,
          minval=minval,
          maxval=maxval,
          downsample_factor=downsample_factor))
  layers.extend([
      Conv1D(num_local_filters, filter_shape=(3,), padding='SAME'),
      _STAX_ACTIVATION[activation]] * num_local_conv_layers)
  layers.append(
      # Use unit convolution filter to aggregate channels.
      Conv1D(1, filter_shape=(1,), padding='SAME'))
  if apply_negativity_transform:
    layers.append(negativity_transform())
  return stax.serial(*layers)

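# Hedged construction sketch for build_global_local_conv_net (assumes the
# surrounding helpers such as global_conv_block, Conv1D, _STAX_ACTIVATION and
# negativity_transform are in scope). The grid, minval/maxval and filter
# counts below are illustrative only; downsample_factor=0 keeps the global
# convolution block at full resolution.
import numpy as np

net_init_fn, net_apply_fn = build_global_local_conv_net(
    num_global_filters=8,
    num_local_filters=16,
    num_local_conv_layers=2,
    activation='softplus',
    grids=np.linspace(-5., 5., 17),
    minval=0.1,
    maxval=2.0,
    downsample_factor=0)
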