Example #1
def global_conv_block(num_channels, grids, minval, maxval, downsample_factor):
  """Global convolution block.

  First downsample the input, then apply global conv, finally upsample and
  concatenate with the input. The input itself is one channel in the output.

  Args:
    num_channels: Integer, the number of channels.
    grids: Float numpy array with shape (num_grids,).
    minval: Float, the min value in the uniform sampling for exponential width.
    maxval: Float, the max value in the uniform sampling for exponential width.
    downsample_factor: Integer, the factor of downsampling. The grids are
        downsampled with step size 2 ** downsample_factor.

  Returns:
    (init_fn, apply_fn) pair.
  """
  layers = []
  layers.extend([linear_interpolation_transpose()] * downsample_factor)
  layers.append(exponential_global_convolution(
      num_channels=num_channels - 1,  # one channel is reserved for input.
      grids=grids,
      minval=minval,
      maxval=maxval,
      downsample_factor=downsample_factor))
  layers.extend([linear_interpolation()] * downsample_factor)
  global_conv_path = stax.serial(*layers)
  return stax.serial(
      stax.FanOut(2),
      stax.parallel(stax.Identity, global_conv_path),
      stax.FanInConcat(axis=-1),
  )
Example #2
def ConvBlock(kernel_size, filters, strides=(2, 2)):
  ks = kernel_size
  filters1, filters2, filters3 = filters
  Main = stax.serial(
      Conv(filters1, (1, 1), strides), BatchNorm(), Relu,
      Conv(filters2, (ks, ks), padding='SAME'), BatchNorm(), Relu,
      Conv(filters3, (1, 1)), BatchNorm())
  Shortcut = stax.serial(Conv(filters3, (1, 1), strides), BatchNorm())
  return stax.serial(FanOut(2), stax.parallel(Main, Shortcut), FanInSum, Relu)
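Like every stax combinator, the ConvBlock above returns an (init_fn, apply_fn) pair. Below is a minimal usage sketch, assuming ConvBlock and the stax names it uses are in scope; the 56x56x64 NHWC input shape is an arbitrary example:

import jax

# Sketch only: build the residual block, initialize parameters, then apply it.
init_fn, apply_fn = ConvBlock(3, [64, 64, 256], strides=(1, 1))
rng = jax.random.PRNGKey(0)
out_shape, params = init_fn(rng, (8, 56, 56, 64))   # NHWC input shape
x = jax.random.normal(jax.random.PRNGKey(1), (8, 56, 56, 64))
y = apply_fn(params, x)                             # shape (8, 56, 56, 256)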
Example #3
def encoder(hidden_dim, z_dim):
    return stax.serial(
        stax.Dense(hidden_dim, W_init=stax.randn()),
        stax.Softplus,
        stax.FanOut(2),
        stax.parallel(
            stax.Dense(z_dim, W_init=stax.randn()),
            stax.serial(stax.Dense(z_dim, W_init=stax.randn()), stax.Exp),
        ),
    )
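A usage sketch for the encoder above (the 784-dimensional input is an assumed example, as in a VAE over flattened MNIST images): the FanOut(2)/parallel tail yields a (mean, std) pair, with std kept positive by the final Exp.

import jax

init_fn, apply_fn = encoder(hidden_dim=512, z_dim=50)
_, params = init_fn(jax.random.PRNGKey(0), (-1, 784))   # batch dim left as -1
x = jax.random.normal(jax.random.PRNGKey(1), (4, 784))
mean, std = apply_fn(params, x)                          # each of shape (4, 50)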
Example #4
def build_unet(
    num_filters_list, core_num_filters, activation,
    num_channels=0, grids=None, minval=None, maxval=None,
    apply_negativity_transform=True):
  """Builds U-net.

  This neural network is used to parameterize a many-to-many mapping.

  Args:
    num_filters_list: List of integers, the number of filters for each
      downsampling_block. The upsampling_blocks use the same numbers of
      filters in reverse order. For example, if num_filters_list=[16, 32, 64],
      there are 3 downsampling_blocks with 16, 32 and 64 filters from left to
      right, and 3 upsampling_blocks with 64, 32 and 16 filters from left to
      right.
    core_num_filters: Integer, the number of filters for the convolution layer
      at the bottom of the U-shape structure.
    activation: String, the activation function to use in the network.
    num_channels: Integer, the number of channels.
    grids: Float numpy array with shape (num_grids,).
    minval: Float, the min value in the uniform sampling for exponential width.
    maxval: Float, the max value in the uniform sampling for exponential width.
    apply_negativity_transform: Boolean, whether to add negativity_transform at
        the end.

  Returns:
    (init_fn, apply_fn) pair.
  """
  layer = stax.serial(
      Conv1D(core_num_filters, filter_shape=(3,), padding='SAME'),
      _STAX_ACTIVATION[activation],
      Conv1D(core_num_filters, filter_shape=(3,), padding='SAME'),
      _STAX_ACTIVATION[activation])
  for num_filters in num_filters_list[::-1]:
    layer = _build_unet_shell(layer, num_filters, activation=activation)
  network = stax.serial(
      layer,
      # Use 1x1 convolution filter to aggregate channels.
      Conv1D(1, filter_shape=(1,), padding='SAME'))
  layers_before_network = []
  if num_channels > 0:
    layers_before_network.append(
        exponential_global_convolution(num_channels, grids, minval, maxval))
  if apply_negativity_transform:
    return stax.serial(*layers_before_network, network, negativity_transform())
  else:
    return stax.serial(*layers_before_network, network)
Example #5
def build_model_stax(output_size,
                     n_dense_units=300,
                     conv_depth=300,
                     n_conv_layers=2,
                     n_dense_layers=0,
                     kernel_size=5,
                     across_batch=False,
                     add_pos_encoding=False,
                     mean_over_pos=False,
                     mode="train"):
    """Build a model with convolutional layers followed by dense layers."""
    del mode
    layers = [
        cnn(conv_depth=conv_depth,
            n_conv_layers=n_conv_layers,
            kernel_size=kernel_size,
            across_batch=across_batch,
            add_pos_encoding=add_pos_encoding)
    ]
    for _ in range(n_dense_layers):
        layers.append(stax.Dense(n_dense_units))
        layers.append(stax.Relu)

    layers.append(stax.Dense(output_size))

    if mean_over_pos:
        layers.append(reduce_layer(jnp.mean, axis=1))
    init_random_params, predict = stax.serial(*layers)
    return init_random_params, predict
Example #6
def _build_unet_shell(layer, num_filters, activation):
  """Builds a shell in the U-net structure.

  *--------------*
  |              |                     *--------*     *----------------------*
  | downsampling |---------------------|        |     |                      |
  |   block      |   *------------*    | concat |-----|   upsampling block   |
  |              |---|    layer   |----|        |     |                      |
  *--------------*   *------------*    *--------*     *----------------------*

  Args:
    layer: (init_fn, apply_fn) pair in the bottom of the U-shape structure.
    num_filters: Integer, the number of filters used for downsampling and
        upsampling.
    activation: String, the activation function to use in the network.

  Returns:
    (init_fn, apply_fn) pair.
  """
  return stax.serial(
      downsampling_block(num_filters, activation=activation),
      stax.FanOut(2),
      stax.parallel(stax.Identity, layer),
      stax.FanInConcat(axis=-1),
      upsampling_block(num_filters, activation=activation)
  )
Example #7
def serial(*layers: Layer) -> InternalLayer:
    """Combinator for composing layers in serial.

    Based on `jax.example_libraries.stax.serial`.

    Args:
      *layers:
        a sequence of layers, each an `(init_fn, apply_fn, kernel_fn)` triple.

    Returns:
      A new layer, meaning an `(init_fn, apply_fn, kernel_fn)` triple,
      representing the serial composition of the given sequence of layers.
    """
    init_fns, apply_fns, kernel_fns = zip(*layers)
    init_fn, apply_fn = ostax.serial(*zip(init_fns, apply_fns))

    @requires(**_get_input_req_attr(kernel_fns, fold=op.rshift))
    def kernel_fn(k: NTTree[Kernel], **kwargs) -> NTTree[Kernel]:
        # TODO(xlc): if we drop `x1_is_x2` and use `rng` instead, need split key
        # inside kernel functions here and parallel below.
        for f in kernel_fns:
            k = f(k, **kwargs)
        return k

    return init_fn, apply_fn, kernel_fn
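Unlike the stax pairs in the other examples, this combinator composes (init_fn, apply_fn, kernel_fn) triples. A hedged sketch of how the resulting triple is commonly used with neural_tangents.stax layers (the layer constructors and the 'ntk' getter come from that library; defaults may vary between versions):

import jax
from neural_tangents import stax as nt_stax

init_fn, apply_fn, kernel_fn = nt_stax.serial(
    nt_stax.Dense(512), nt_stax.Relu(), nt_stax.Dense(1))
x1 = jax.random.normal(jax.random.PRNGKey(0), (8, 10))
x2 = jax.random.normal(jax.random.PRNGKey(1), (6, 10))
ntk = kernel_fn(x1, x2, 'ntk')   # infinite-width NTK between batches, shape (8, 6)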
Example #8
def decoder(hidden_dim, out_dim):
    return stax.serial(
        stax.Dense(hidden_dim, W_init=stax.randn()),
        stax.Softplus,
        stax.Dense(out_dim, W_init=stax.randn()),
        stax.Sigmoid,
    )
Example #9
    def test_kohn_sham_neural_xc_density_mse_converge_tolerance(
            self, density_mse_converge_tolerance, expected_converged):
        init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
            stax.serial(stax.Dense(8), stax.Elu, stax.Dense(1)))
        params_init = init_fn(rng=random.PRNGKey(0))

        states = jit_scf.kohn_sham(
            locations=self.locations,
            nuclear_charges=self.nuclear_charges,
            num_electrons=self.num_electrons,
            num_iterations=3,
            grids=self.grids,
            xc_energy_density_fn=tree_util.Partial(xc_energy_density_fn,
                                                   params=params_init),
            interaction_fn=utils.exponential_coulomb,
            initial_density=self.num_electrons *
            utils.gaussian(grids=self.grids, center=0., sigma=0.5),
            density_mse_converge_tolerance=density_mse_converge_tolerance)

        np.testing.assert_array_equal(states.converged, expected_converged)

        for single_state in scf.state_iterator(states):
            self._test_state(
                single_state,
                self._create_testing_external_potential(
                    utils.exponential_coulomb))
Example #10
 def test_create_model_from_stax(self):
     stax_init, stax_apply = stax.serial(stax.Dense(10))
     stax_model = models.create_model_from_stax(stax_init=stax_init,
                                                stax_apply=stax_apply,
                                                sample_shape=(-1, 2),
                                                train_loss=train_loss,
                                                eval_metrics=eval_metrics)
     self.check_model(stax_model)
Example #11
    def test_local_density_approximation_wrong_output_shape(self):
        init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
            stax.serial(stax.Dense(16), stax.Elu, stax.Dense(3)))
        init_params = init_fn(rng=random.PRNGKey(0))

        with self.assertRaisesRegex(
                ValueError, r'The output shape of the network '
                r'should be \(-1, 1\) but got \(11, 3\)'):
            xc_energy_density_fn(self.density, init_params)
Example #12
    def test_local_density_approximation(self):
        init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
            stax.serial(stax.Dense(16), stax.Elu, stax.Dense(1)))
        init_params = init_fn(rng=random.PRNGKey(0))
        xc_energy_density = xc_energy_density_fn(self.density, init_params)

        # The first layer of the network takes 1 feature (density).
        self.assertEqual(init_params[0][0].shape, (1, 16))
        self.assertEqual(xc_energy_density.shape, (11, ))
Example #13
def UNetBlock(filters, kernel_size, inner_block, **kwargs):
    def make_main(input_shape):
        return stax.serial(
            UnbiasedConv(filters, kernel_size, **kwargs),
            inner_block,
            UnbiasedConvTranspose(input_shape[3], kernel_size, **kwargs),
        )

    Main = stax.shape_dependent(make_main)
    return stax.serial(stax.FanOut(2), stax.parallel(Main, stax.Identity),
                       stax.FanInSum)
Example #14
def IdentityBlock(kernel_size, filters):
  ks = kernel_size
  filters1, filters2 = filters
  def make_main(input_shape):
    # the number of output channels depends on the number of input channels
    return stax.serial(
        Conv(filters1, (1, 1)), BatchNorm(), Relu,
        Conv(filters2, (ks, ks), padding='SAME'), BatchNorm(), Relu,
        Conv(input_shape[3], (1, 1)), BatchNorm())
  Main = stax.shape_dependent(make_main)
  return stax.serial(FanOut(2), stax.parallel(Main, Identity), FanInSum, Relu)
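Note the role of stax.shape_dependent here: Main is only constructed once the input shape is known at init time, so its last Conv can emit input_shape[3] channels and FanInSum can add the main and identity branches element-wise.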
Example #15
def BlockNeuralAutoregressiveNN(input_dim,
                                hidden_factors=[8, 8],
                                residual=None):
    """
    An implementation of Block Neural Autoregressive neural network.

    **References**

    1. *Block Neural Autoregressive Flow*,
       Nicola De Cao, Ivan Titov, Wilker Aziz

    :param int input_dim: The dimensionality of the input.
    :param list hidden_factors: Hidden layer i has ``hidden_factors[i]`` hidden units per
        input dimension. This corresponds to both :math:`a` and :math:`b` in reference [1].
        The elements of hidden_factors must be integers.
    :param str residual: Type of residual connections to use. One of `None`, `"normal"`, `"gated"`.
    :return: an (`init_fn`, `apply_fn`) pair.
    """
    layers = []
    in_factor = 1
    for hidden_factor in hidden_factors:
        layers.append(BlockMaskedDense(input_dim, in_factor, hidden_factor))
        layers.append(Tanh())
        in_factor = hidden_factor
    layers.append(BlockMaskedDense(input_dim, in_factor, 1))
    arn = stax.serial(*layers)
    if residual is not None:
        FanInResidual = (FanInResidualGated
                         if residual == "gated" else FanInResidualNormal)
        arn = stax.serial(stax.FanOut(2), stax.parallel(arn, stax.Identity),
                          FanInResidual())

    def init_fun(rng, input_shape):
        return arn[0](rng, input_shape)

    def apply_fun(params, inputs, **kwargs):
        out, logdet = arn[1](params, (inputs, None), **kwargs)
        return out, logdet.reshape(inputs.shape)

    return init_fun, apply_fun
Example #16
def create_stax_dense_model(only_digits: bool = False,
                            hidden_units: int = 200) -> models.Model:
    """Creates EMNIST dense net with stax."""
    num_classes = 10 if only_digits else 62
    stax_init, stax_apply = stax.serial(stax.Flatten,
                                        stax.Dense(hidden_units), stax.Relu,
                                        stax.Dense(hidden_units), stax.Relu,
                                        stax.Dense(num_classes))
    return models.create_model_from_stax(stax_init=stax_init,
                                         stax_apply=stax_apply,
                                         sample_shape=_STAX_SAMPLE_SHAPE,
                                         train_loss=_TRAIN_LOSS,
                                         eval_metrics=_EVAL_METRICS)
Example #17
def test_masked_dense(input_dim):
    hidden_dim = input_dim * 3
    output_dim_multiplier = input_dim - 4
    mask, _ = create_mask(input_dim, [hidden_dim],
                          np.random.permutation(input_dim),
                          output_dim_multiplier)
    init_random_params, masked_dense = serial(MaskedDense(mask[0]))

    rng_key = random.PRNGKey(0)
    batch_size = 4
    input_shape = (batch_size, input_dim)
    _, init_params = init_random_params(rng_key, input_shape)
    output = masked_dense(init_params, np.random.rand(*input_shape))
    assert output.shape == (batch_size, hidden_dim)
Example #18
def main(_):
  # Define the total number of training steps
  training_iters = 200

  rng = random.PRNGKey(0)

  rng, key = random.split(rng)

  init_random_params, model_apply = stax.serial(
      stax.Dense(256), stax.Relu, stax.Dense(256), stax.Relu, stax.Dense(2))

  # init the model
  _, params = init_random_params(rng, (-1, 2))

  # Create the optimizer corresponding to the 0th hyperparameter configuration
  # with the specified amount of training steps.
  # opt = optix.adam(1e-4)
  opt = jax_optix_opt_list.optimizer_for_idx(0, training_iters)

  opt_state = opt.init(params)

  @jax.jit
  def loss_fn(params, batch):
    x, y = batch
    y_hat = model_apply(params, x)
    return jnp.mean(jnp.square(y_hat - y))

  @jax.jit
  def train_step(params, opt_state, batch):
    """Train for a single step."""
    value_and_grad_fn = jax.value_and_grad(loss_fn)
    loss, grad = value_and_grad_fn(params, batch)

    # Note this is not the usual optix api as we additionally need parameter
    # values.
    # updates, opt_state = opt.update(grad, opt_state)
    updates, opt_state = opt.update_with_params(grad, params, opt_state)

    new_params = optix.apply_updates(params, updates)
    return new_params, opt_state, loss

  for _ in range(training_iters):
    # make a random batch of fake data
    rng, key = random.split(rng)
    inp = random.normal(key, [512, 2]) / 4.
    target = jnp.tanh(1 / (1e-6 + inp))

    # train the model a step
    params, opt_state, loss = train_step(params, opt_state, (inp, target))
    print(loss)
Example #19
 def test_kohn_sham_iteration_neural_xc(self, enforce_reflection_symmetry):
     init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
         stax.serial(stax.Dense(8), stax.Elu, stax.Dense(1)))
     params_init = init_fn(rng=random.PRNGKey(0))
     initial_state = self._create_testing_initial_state(
         utils.exponential_coulomb)
     next_state = jit_scf.kohn_sham_iteration(
         state=initial_state,
         num_electrons=self.num_electrons,
         xc_energy_density_fn=tree_util.Partial(xc_energy_density_fn,
                                                params=params_init),
         interaction_fn=utils.exponential_coulomb,
         enforce_reflection_symmetry=enforce_reflection_symmetry)
     self._test_state(next_state, initial_state)
Example #20
def main(_):
    # Define the total number of training steps
    training_iters = 200

    rng = random.PRNGKey(0)

    rng, key = random.split(rng)

    # Construct a model. We are using stax here.
    init_random_params, model_apply = stax.serial(stax.Dense(256), stax.Relu,
                                                  stax.Dense(256), stax.Relu,
                                                  stax.Dense(2))

    # init the model
    _, init_params = init_random_params(rng, (-1, 2))

    # Create the optimizer corresponding to the 0th hyperparameter configuration
    # with the specified amount of training steps.
    opt_init, opt_update, get_params = jax_optimizers_opt_list.optimizer_for_idx(
        0, training_iters)
    # opt_init, opt_update, get_params = optimizers.adam(1e-4)

    # Initialize the optimizer state
    opt_state = opt_init(init_params)

    @jax.jit
    def loss_fn(params, batch):
        """The loss function."""
        x, y = batch
        y_hat = model_apply(params, x)
        return jnp.mean(jnp.square(y_hat - y))

    @jax.jit
    def train_step(i, opt_state, batch):
        """Train for a single step."""
        params = get_params(opt_state)
        value_and_grad_fn = jax.value_and_grad(loss_fn)
        loss, grad = value_and_grad_fn(params, batch)
        return opt_update(i, grad, opt_state), loss

    for i in range(training_iters):
        # make a random batch of fake data
        rng, key = random.split(rng)
        inp = random.normal(key, [512, 2]) / 4.
        target = jnp.tanh(1 / (1e-6 + inp))

        # train the model a step
        opt_state, loss = train_step(i, opt_state, (inp, target))
        print(loss)
Example #21
 def test_kohn_sham_neural_xc(self, interaction_fn):
   init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
       stax.serial(stax.Dense(8), stax.Elu, stax.Dense(1)))
   params_init = init_fn(rng=random.PRNGKey(0))
   state = scf.kohn_sham(
       locations=self.locations,
       nuclear_charges=self.nuclear_charges,
       num_electrons=self.num_electrons,
       num_iterations=3,
       grids=self.grids,
       xc_energy_density_fn=tree_util.Partial(
           xc_energy_density_fn, params=params_init),
       interaction_fn=interaction_fn)
   for single_state in scf.state_iterator(state):
     self._test_state(
         single_state,
         self._create_testing_external_potential(interaction_fn))
Example #22
  def test_kohn_sham_iteration_neural_xc_density_loss_gradient_symmetry(self):
    # The network only has one layer.
    # The initial params contains weights with shape (1, 1) and bias (1,).
    init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
        stax.serial(stax.Dense(1)))
    init_params = init_fn(rng=random.PRNGKey(0))
    initial_state = self._create_testing_initial_state(
        utils.exponential_coulomb)
    target_density = (
        utils.gaussian(grids=self.grids, center=-0.5, sigma=1.)
        + utils.gaussian(grids=self.grids, center=0.5, sigma=1.))
    spec, flatten_init_params = np_utils.flatten(init_params)

    def loss(flatten_params, initial_state, target_density):
      state = scf.kohn_sham_iteration(
          state=initial_state,
          num_electrons=self.num_electrons,
          xc_energy_density_fn=tree_util.Partial(
              xc_energy_density_fn,
              params=np_utils.unflatten(spec, flatten_params)),
          interaction_fn=utils.exponential_coulomb,
          enforce_reflection_symmetry=True)
      return jnp.sum(jnp.abs(state.density - target_density)) * utils.get_dx(
          self.grids)

    grad_fn = jax.grad(loss)

    params_grad = grad_fn(
        flatten_init_params,
        initial_state=initial_state,
        target_density=target_density)

    # Check gradient values.
    np.testing.assert_allclose(params_grad, [-1.34137017, 0.], atol=5e-7)

    # Check whether the gradient values match the numerical gradient.
    np.testing.assert_allclose(
        optimize.approx_fprime(
            xk=flatten_init_params,
            f=functools.partial(
                loss,
                initial_state=initial_state,
                target_density=target_density),
            epsilon=1e-9),
        params_grad, atol=1e-3)
Example #23
def downsampling_block(num_filters, activation):
  """A downsampling block.

  Given the input feature with spatial dimension size `m`, applies a convolution
  and then reduces the size to `(m + 1) / 2`.

  Args:
    num_filters: Integer, the number of filters of the convolution layer.
    activation: String, the activation function to use in the network.

  Returns:
    (init_fn, apply_fn) pair.
  """
  return stax.serial(
      Conv1D(num_filters, filter_shape=(3,), padding='SAME'),
      linear_interpolation_transpose(),
      _STAX_ACTIVATION[activation],
  )
Example #24
def upsampling_block(num_filters, activation):
  """An upsampling block.

  Given the input feature with spatial dimension size `m`, upsamples the size to
  `2 * m - 1` and then applies a convolution.

  Args:
    num_filters: Integer, the number of filters of the convolution layer.
    activation: String, the activation function to use in the network.

  Returns:
    (init_fn, apply_fn) pair.
  """
  return stax.serial(
      linear_interpolation(),
      Conv1D(num_filters, filter_shape=(3,), padding='SAME'),
      _STAX_ACTIVATION[activation],
  )
Example #25
    def test_global_functional_with_sliding_net_wrong_output_shape(self):
        init_fn, xc_energy_density_fn = (
            neural_xc.global_functional(
                stax.serial(
                    neural_xc.build_sliding_net(window_size=3,
                                                num_filters_list=[4, 2, 2],
                                                activation='softplus'),
                    # Additional conv layer to make the output shape wrong.
                    neural_xc.Conv1D(1,
                                     filter_shape=(1, ),
                                     strides=(2, ),
                                     padding='VALID')),
                grids=self.grids))
        init_params = init_fn(rng=random.PRNGKey(0))

        with self.assertRaisesRegex(
                ValueError, r'The output shape of the network '
                r'should be \(-1, 17\) but got \(1, 9\)'):
            xc_energy_density_fn(self.density, init_params)
Example #26
def ResNet50(num_classes):
  return stax.serial(
      GeneralConv(('HWCN', 'OIHW', 'NHWC'), 64, (7, 7), (2, 2), 'SAME'),
      BatchNorm(), Relu, MaxPool((3, 3), strides=(2, 2)),
      ConvBlock(3, [64, 64, 256], strides=(1, 1)),
      IdentityBlock(3, [64, 64]),
      IdentityBlock(3, [64, 64]),
      ConvBlock(3, [128, 128, 512]),
      IdentityBlock(3, [128, 128]),
      IdentityBlock(3, [128, 128]),
      IdentityBlock(3, [128, 128]),
      ConvBlock(3, [256, 256, 1024]),
      IdentityBlock(3, [256, 256]),
      IdentityBlock(3, [256, 256]),
      IdentityBlock(3, [256, 256]),
      IdentityBlock(3, [256, 256]),
      IdentityBlock(3, [256, 256]),
      ConvBlock(3, [512, 512, 2048]),
      IdentityBlock(3, [512, 512]),
      IdentityBlock(3, [512, 512]),
      AvgPool((7, 7)), Flatten, Dense(num_classes), LogSoftmax)
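A usage sketch with assumed shapes: the leading GeneralConv declares an 'HWCN' input layout, so images are passed as (height, width, channels, batch); the final LogSoftmax makes the outputs log-probabilities.

import jax

init_fn, apply_fn = ResNet50(num_classes=1000)
rng = jax.random.PRNGKey(0)
out_shape, params = init_fn(rng, (224, 224, 3, 8))   # 'HWCN' input layout
images = jax.random.normal(jax.random.PRNGKey(1), (224, 224, 3, 8))
logits = apply_fn(params, images)                    # shape (8, 1000)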
Example #27
def wrap_network_with_self_interaction_layer(network, grids, interaction_fn):
  """Wraps a network with self-interaction layer.

  Args:
    network: an (init_fn, apply_fn) pair.
     * init_fn: The init_fn of the neural network. It takes an rng key and
         an input shape and returns an (output_shape, params) pair.
     * apply_fn: The apply_fn of the neural network. It takes params,
         inputs, and an rng key and applies the layer.
    grids: Float numpy array with shape (num_grids,).
    interaction_fn: Function that takes displacements and returns a
        float numpy array with the same shape as displacements.

  Returns:
    (init_fn, apply_fn) pair.
  """
  return stax.serial(
      stax.FanOut(2),
      stax.parallel(stax.Identity, network),
      self_interaction_layer(grids, interaction_fn),
  )
Example #28
  def test_kohn_sham_iteration_neural_xc_energy_loss_gradient(self):
    # The network only has one layer.
    # The initial params contains weights with shape (1, 1) and bias (1,).
    init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
        stax.serial(stax.Dense(1)))
    init_params = init_fn(rng=random.PRNGKey(0))
    initial_state = self._create_testing_initial_state(
        utils.exponential_coulomb)
    target_energy = 2.
    spec, flatten_init_params = np_utils.flatten(init_params)

    def loss(flatten_params, initial_state, target_energy):
      state = scf.kohn_sham_iteration(
          state=initial_state,
          num_electrons=self.num_electrons,
          xc_energy_density_fn=tree_util.Partial(
              xc_energy_density_fn,
              params=np_utils.unflatten(spec, flatten_params)),
          interaction_fn=utils.exponential_coulomb,
          enforce_reflection_symmetry=True)
      return (state.total_energy - target_energy) ** 2

    grad_fn = jax.grad(loss)

    params_grad = grad_fn(
        flatten_init_params,
        initial_state=initial_state,
        target_energy=target_energy)

    # Check gradient values.
    np.testing.assert_allclose(params_grad, [-8.549952, -14.754195])

    # Check whether the gradient values match the numerical gradient.
    np.testing.assert_allclose(
        optimize.approx_fprime(
            xk=flatten_init_params,
            f=functools.partial(
                loss, initial_state=initial_state, target_energy=target_energy),
            epsilon=1e-9),
        params_grad, atol=2e-3)
Example #29
def create_mlp(input_dim, num_channels, output_dim, omega=30):
    modules = []
    modules.append(
        layer.Dense(num_channels[0], W_init=siren_init_first(), b_init=bias_uniform())
    )
    modules.append(layer.Sine(omega))
    for nc in num_channels:
        modules.append(
            layer.Dense(nc, W_init=siren_init(omega=omega), b_init=bias_uniform())
        )
        modules.append(layer.Sine(omega))

    modules.append(
        layer.Dense(output_dim, W_init=siren_init(omega=omega), b_init=bias_uniform())
    )
    net_init_random, net_apply = stax.serial(*modules)

    in_shape = (-1, input_dim)
    rng = create_random_generator()
    out_shape, net_params = net_init_random(rng, in_shape)

    return net_params, net_apply
Example #30
def build_global_local_conv_net(
    num_global_filters, num_local_filters, num_local_conv_layers, activation,
    grids, minval, maxval, downsample_factor, apply_negativity_transform=True):
  """Builds global-local convolutional network.

  Args:
    num_global_filters: Integer, the number of global filters in one cell.
    num_local_filters: Integer, the number of local filters in one cell.
    num_local_conv_layers: Integer, the number of local convolution layers in
        one cell.
    activation: String, the activation function to use in the network.
    grids: Float numpy array with shape (num_grids,).
    minval: Float, the min value in the uniform sampling for exponential width.
    maxval: Float, the max value in the uniform sampling for exponential width.
    downsample_factor: Integer, the factor of downsampling. The grids are
        downsampled with step size 2 ** downsample_factor.
    apply_negativity_transform: Boolean, whether to add negativity_transform at
        the end.

  Returns:
    (init_fn, apply_fn) pair.
  """
  layers = []
  layers.append(
      global_conv_block(
          num_channels=num_global_filters,
          grids=grids,
          minval=minval,
          maxval=maxval,
          downsample_factor=downsample_factor))
  layers.extend([
      Conv1D(num_local_filters, filter_shape=(3,), padding='SAME'),
      _STAX_ACTIVATION[activation]] * num_local_conv_layers)
  layers.append(
      # Use unit convolution filter to aggregate channels.
      Conv1D(1, filter_shape=(1,), padding='SAME'))
  if apply_negativity_transform:
    layers.append(negativity_transform())
  return stax.serial(*layers)