Example #1
def decoder(hidden_dim, out_dim):
    return stax.serial(
        stax.Dense(hidden_dim, W_init=stax.randn()),
        stax.Softplus,
        stax.Dense(out_dim, W_init=stax.randn()),
        stax.Sigmoid,
    )
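
A minimal usage sketch (the sizes here are illustrative, not from the source): stax.serial returns an (init_fun, apply_fun) pair, so the decoder is initialized and applied like any stax network.

from jax import random

init_fun, apply_fun = decoder(hidden_dim=512, out_dim=784)
# Assumed latent dimension of 20, for illustration only.
out_shape, params = init_fun(random.PRNGKey(0), (-1, 20))
z = random.normal(random.PRNGKey(1), (32, 20))
reconstruction = apply_fun(params, z)  # shape (32, 784), values in (0, 1) via the final Sigmoid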
Example #2
def build_model_stax(output_size,
                     n_dense_units=300,
                     conv_depth=300,
                     n_conv_layers=2,
                     n_dense_layers=0,
                     kernel_size=5,
                     across_batch=False,
                     add_pos_encoding=False,
                     mean_over_pos=False,
                     mode="train"):
    """Build a model with convolutional layers followed by dense layers."""
    del mode
    layers = [
        cnn(conv_depth=conv_depth,
            n_conv_layers=n_conv_layers,
            kernel_size=kernel_size,
            across_batch=across_batch,
            add_pos_encoding=add_pos_encoding)
    ]
    for _ in range(n_dense_layers):
        layers.append(stax.Dense(n_dense_units))
        layers.append(stax.Relu)

    layers.append(stax.Dense(output_size))

    if mean_over_pos:
        layers.append(reduce_layer(jnp.mean, axis=1))
    init_random_params, predict = stax.serial(*layers)
    return init_random_params, predict
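
cnn and reduce_layer are helpers defined elsewhere in the source project. A hypothetical stand-in for reduce_layer, consistent with how it is called above (averaging out the position axis), might look like:

def reduce_layer(reduce_fn, axis):
    """Hypothetical sketch: a stax layer applying reduce_fn (e.g. jnp.mean) along axis."""
    def init_fun(rng, input_shape):
        # The reduced axis disappears from the output shape; the layer has no params.
        return input_shape[:axis] + input_shape[axis + 1:], ()
    def apply_fun(params, inputs, **kwargs):
        return reduce_fn(inputs, axis=axis)
    return init_fun, apply_fun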
Example #3
    def test_kohn_sham_neural_xc_density_mse_converge_tolerance(
            self, density_mse_converge_tolerance, expected_converged):
        init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
            stax.serial(stax.Dense(8), stax.Elu, stax.Dense(1)))
        params_init = init_fn(rng=random.PRNGKey(0))

        states = jit_scf.kohn_sham(
            locations=self.locations,
            nuclear_charges=self.nuclear_charges,
            num_electrons=self.num_electrons,
            num_iterations=3,
            grids=self.grids,
            xc_energy_density_fn=tree_util.Partial(xc_energy_density_fn,
                                                   params=params_init),
            interaction_fn=utils.exponential_coulomb,
            initial_density=self.num_electrons *
            utils.gaussian(grids=self.grids, center=0., sigma=0.5),
            density_mse_converge_tolerance=density_mse_converge_tolerance)

        np.testing.assert_array_equal(states.converged, expected_converged)

        for single_state in scf.state_iterator(states):
            self._test_state(
                single_state,
                self._create_testing_external_potential(
                    utils.exponential_coulomb))
Example #4
    def test_local_density_approximation_wrong_output_shape(self):
        init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
            stax.serial(stax.Dense(16), stax.Elu, stax.Dense(3)))
        init_params = init_fn(rng=random.PRNGKey(0))

        with self.assertRaisesRegex(
                ValueError, r'The output shape of the network '
                r'should be \(-1, 1\) but got \(11, 3\)'):
            xc_energy_density_fn(self.density, init_params)
Example #5
    def test_local_density_approximation(self):
        init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
            stax.serial(stax.Dense(16), stax.Elu, stax.Dense(1)))
        init_params = init_fn(rng=random.PRNGKey(0))
        xc_energy_density = xc_energy_density_fn(self.density, init_params)

        # The first layer of the network takes 1 feature (density).
        self.assertEqual(init_params[0][0].shape, (1, 16))
        self.assertEqual(xc_energy_density.shape, (11, ))
Example #6
def encoder(hidden_dim, z_dim):
    return stax.serial(
        stax.Dense(hidden_dim, W_init=stax.randn()),
        stax.Softplus,
        stax.FanOut(2),
        stax.parallel(
            stax.Dense(z_dim, W_init=stax.randn()),
            stax.serial(stax.Dense(z_dim, W_init=stax.randn()), stax.Exp),
        ),
    )
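
Here stax.FanOut(2) duplicates the hidden activation and stax.parallel maps the two copies through separate heads, so the encoder outputs a (mean, std) pair; the stax.Exp on the second head keeps it positive, as in a VAE. A quick application sketch (input and latent sizes are illustrative):

from jax import random

init_fun, apply_fun = encoder(hidden_dim=512, z_dim=20)
out_shape, params = init_fun(random.PRNGKey(0), (-1, 784))
mean, std = apply_fun(params, random.normal(random.PRNGKey(1), (32, 784)))  # each (32, 20)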
Example #7
def create_stax_dense_model(only_digits: bool = False,
                            hidden_units: int = 200) -> models.Model:
    """Creates EMNIST dense net with stax."""
    num_classes = 10 if only_digits else 62
    stax_init, stax_apply = stax.serial(stax.Flatten,
                                        stax.Dense(hidden_units), stax.Relu,
                                        stax.Dense(hidden_units), stax.Relu,
                                        stax.Dense(num_classes))
    return models.create_model_from_stax(stax_init=stax_init,
                                         stax_apply=stax_apply,
                                         sample_shape=_STAX_SAMPLE_SHAPE,
                                         train_loss=_TRAIN_LOSS,
                                         eval_metrics=_EVAL_METRICS)
Example #8
def test_kohn_sham_iteration_neural_xc(self, enforce_reflection_symmetry):
    init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
        stax.serial(stax.Dense(8), stax.Elu, stax.Dense(1)))
    params_init = init_fn(rng=random.PRNGKey(0))
    initial_state = self._create_testing_initial_state(
        utils.exponential_coulomb)
    next_state = jit_scf.kohn_sham_iteration(
        state=initial_state,
        num_electrons=self.num_electrons,
        xc_energy_density_fn=tree_util.Partial(xc_energy_density_fn,
                                               params=params_init),
        interaction_fn=utils.exponential_coulomb,
        enforce_reflection_symmetry=enforce_reflection_symmetry)
    self._test_state(next_state, initial_state)
Example #9
def main(_):
  # Define the total number of training steps
  training_iters = 200

  rng = random.PRNGKey(0)

  rng, key = random.split(rng)

  init_random_params, model_apply = stax.serial(
      stax.Dense(256), stax.Relu, stax.Dense(256), stax.Relu, stax.Dense(2))

  # init the model
  _, params = init_random_params(rng, (-1, 2))

  # Create the optimizer corresponding to the 0th hyperparameter configuration
  # with the specified amount of training steps.
  # opt = optix.adam(1e-4)
  opt = jax_optix_opt_list.optimizer_for_idx(0, training_iters)

  opt_state = opt.init(params)

  @jax.jit
  def loss_fn(params, batch):
    x, y = batch
    y_hat = model_apply(params, x)
    return jnp.mean(jnp.square(y_hat - y))

  @jax.jit
  def train_step(params, opt_state, batch):
    """Train for a single step."""
    value_and_grad_fn = jax.value_and_grad(loss_fn)
    loss, grad = value_and_grad_fn(params, batch)

    # Note this is not the usual optix api as we additionally need parameter
    # values.
    # updates, opt_state = opt.update(grad, opt_state)
    updates, opt_state = opt.update_with_params(grad, params, opt_state)

    new_params = optix.apply_updates(params, updates)
    return new_params, opt_state, loss

  for _ in range(training_iters):
    # make a random batch of fake data
    rng, key = random.split(rng)
    inp = random.normal(key, [512, 2]) / 4.
    target = jnp.tanh(1 / (1e-6 + inp))

    # train the model a step
    params, opt_state, loss = train_step(params, opt_state, (inp, target))
    print(loss)
Example #10
def main(_):
    # Define the total number of training steps
    training_iters = 200

    rng = random.PRNGKey(0)

    rng, key = random.split(rng)

    # Construct a model. We are using stax here.
    init_random_params, model_apply = stax.serial(stax.Dense(256), stax.Relu,
                                                  stax.Dense(256), stax.Relu,
                                                  stax.Dense(2))

    # init the model
    _, init_params = init_random_params(rng, (-1, 2))

    # Create the optimizer corresponding to the 0th hyperparameter configuration
    # with the specified amount of training steps.
    opt_init, opt_update, get_params = jax_optimizers_opt_list.optimizer_for_idx(
        0, training_iters)
    # opt_init, opt_update, get_params = optimizers.adam(1e-4)

    # Initialize the optimizer state
    opt_state = opt_init(init_params)

    @jax.jit
    def loss_fn(params, batch):
        """The loss function."""
        x, y = batch
        y_hat = model_apply(params, x)
        return jnp.mean(jnp.square(y_hat - y))

    @jax.jit
    def train_step(i, opt_state, batch):
        """Train for a single step."""
        params = get_params(opt_state)
        value_and_grad_fn = jax.value_and_grad(loss_fn)
        loss, grad = value_and_grad_fn(params, batch)
        return opt_update(i, grad, opt_state), loss

    for i in range(training_iters):
        # make a random batch of fake data
        rng, key = random.split(rng)
        inp = random.normal(key, [512, 2]) / 4.
        target = jnp.tanh(1 / (1e-6 + inp))

        # train the model a step
        opt_state, loss = train_step(i, opt_state, (inp, target))
        print(loss)
Example #11
def test_create_model_from_stax(self):
    stax_init, stax_apply = stax.serial(stax.Dense(10))
    stax_model = models.create_model_from_stax(stax_init=stax_init,
                                               stax_apply=stax_apply,
                                               sample_shape=(-1, 2),
                                               train_loss=train_loss,
                                               eval_metrics=eval_metrics)
    self.check_model(stax_model)
Example #12
def test_kohn_sham_neural_xc(self, interaction_fn):
  init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
      stax.serial(stax.Dense(8), stax.Elu, stax.Dense(1)))
  params_init = init_fn(rng=random.PRNGKey(0))
  state = scf.kohn_sham(
      locations=self.locations,
      nuclear_charges=self.nuclear_charges,
      num_electrons=self.num_electrons,
      num_iterations=3,
      grids=self.grids,
      xc_energy_density_fn=tree_util.Partial(
          xc_energy_density_fn, params=params_init),
      interaction_fn=interaction_fn)
  for single_state in scf.state_iterator(state):
    self._test_state(
        single_state,
        self._create_testing_external_potential(interaction_fn))
Example #13
  def test_kohn_sham_iteration_neural_xc_density_loss_gradient_symmetry(self):
    # The network only has one layer.
    # The initial params contains weights with shape (1, 1) and bias (1,).
    init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
        stax.serial(stax.Dense(1)))
    init_params = init_fn(rng=random.PRNGKey(0))
    initial_state = self._create_testing_initial_state(
        utils.exponential_coulomb)
    target_density = (
        utils.gaussian(grids=self.grids, center=-0.5, sigma=1.)
        + utils.gaussian(grids=self.grids, center=0.5, sigma=1.))
    spec, flatten_init_params = np_utils.flatten(init_params)

    def loss(flatten_params, initial_state, target_density):
      state = scf.kohn_sham_iteration(
          state=initial_state,
          num_electrons=self.num_electrons,
          xc_energy_density_fn=tree_util.Partial(
              xc_energy_density_fn,
              params=np_utils.unflatten(spec, flatten_params)),
          interaction_fn=utils.exponential_coulomb,
          enforce_reflection_symmetry=True)
      return jnp.sum(jnp.abs(state.density - target_density)) * utils.get_dx(
          self.grids)

    grad_fn = jax.grad(loss)

    params_grad = grad_fn(
        flatten_init_params,
        initial_state=initial_state,
        target_density=target_density)

    # Check gradient values.
    np.testing.assert_allclose(params_grad, [-1.34137017, 0.], atol=5e-7)

    # Check whether the gradient values match the numerical gradient.
    np.testing.assert_allclose(
        optimize.approx_fprime(
            xk=flatten_init_params,
            f=functools.partial(
                loss,
                initial_state=initial_state,
                target_density=target_density),
            epsilon=1e-9),
        params_grad, atol=1e-3)
Example #14
  def test_kohn_sham_iteration_neural_xc_energy_loss_gradient(self):
    # The network only has one layer.
    # The initial params contains weights with shape (1, 1) and bias (1,).
    init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
        stax.serial(stax.Dense(1)))
    init_params = init_fn(rng=random.PRNGKey(0))
    initial_state = self._create_testing_initial_state(
        utils.exponential_coulomb)
    target_energy = 2.
    spec, flatten_init_params = np_utils.flatten(init_params)

    def loss(flatten_params, initial_state, target_energy):
      state = scf.kohn_sham_iteration(
          state=initial_state,
          num_electrons=self.num_electrons,
          xc_energy_density_fn=tree_util.Partial(
              xc_energy_density_fn,
              params=np_utils.unflatten(spec, flatten_params)),
          interaction_fn=utils.exponential_coulomb,
          enforce_reflection_symmetry=True)
      return (state.total_energy - target_energy) ** 2

    grad_fn = jax.grad(loss)

    params_grad = grad_fn(
        flatten_init_params,
        initial_state=initial_state,
        target_energy=target_energy)

    # Check gradient values.
    np.testing.assert_allclose(params_grad, [-8.549952, -14.754195])

    # Check whether the gradient values match the numerical gradient.
    np.testing.assert_allclose(
        optimize.approx_fprime(
            xk=flatten_init_params,
            f=functools.partial(
                loss, initial_state=initial_state, target_energy=target_energy),
            epsilon=1e-9),
        params_grad, atol=2e-3)
Example #15
flags.DEFINE_integer('seed', 0, 'Seed for jax PRNG')
flags.DEFINE_integer(
    'microbatches', None, 'Number of microbatches '
    '(must evenly divide batch_size)')
flags.DEFINE_string('model_dir', None, 'Model directory')


init_random_params, predict = stax.serial(
    stax.Conv(16, (8, 8), padding='SAME', strides=(2, 2)),
    stax.Relu,
    stax.MaxPool((2, 2), (1, 1)),
    stax.Conv(32, (4, 4), padding='VALID', strides=(2, 2)),
    stax.Relu,
    stax.MaxPool((2, 2), (1, 1)),
    stax.Flatten,
    stax.Dense(32),
    stax.Relu,
    stax.Dense(10),
)
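
The network above expects NHWC image input. A minimal initialization sketch, assuming 28x28 single-channel images (e.g. MNIST, which this snippet appears to target):

from jax import random

_, params = init_random_params(random.PRNGKey(0), (-1, 28, 28, 1))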


def loss(params, batch):
  inputs, targets = batch
  logits = predict(params, inputs)
  logits = stax.logsoftmax(logits)  # log normalize
  return -jnp.mean(jnp.sum(logits * targets, axis=1))  # cross entropy loss


def accuracy(params, batch):
  inputs, targets = batch
  target_class = jnp.argmax(targets, axis=1)
  predicted_class = jnp.argmax(predict(params, inputs), axis=1)
  return jnp.mean(predicted_class == target_class)
Example #16
class StaxTest(jtu.JaxTestCase):
    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": f"_shape={shape}",
            "shape": shape
        } for shape in [(2, 3), (5, )]))
    def testRandnInitShape(self, shape):
        key = random.PRNGKey(0)
        out = stax.randn()(key, shape)
        self.assertEqual(out.shape, shape)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": f"_shape={shape}",
            "shape": shape
        } for shape in [(2, 3), (2, 3, 4)]))
    def testGlorotInitShape(self, shape):
        key = random.PRNGKey(0)
        out = stax.glorot()(key, shape)
        self.assertEqual(out.shape, shape)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_channels={}_filter_shape={}_padding={}_strides={}_input_shape={}"
            .format(channels, filter_shape, padding, strides, input_shape),
            "channels": channels,
            "filter_shape": filter_shape,
            "padding": padding,
            "strides": strides,
            "input_shape": input_shape
        } for channels in [2, 3]
          for filter_shape in [(1, 1), (2, 3)]
          for padding in ["SAME", "VALID"]
          for strides in [None, (2, 1)]
          for input_shape in [(2, 10, 11, 1)]))
    def testConvShape(self, channels, filter_shape, padding, strides,
                      input_shape):
        init_fun, apply_fun = stax.Conv(channels,
                                        filter_shape,
                                        strides=strides,
                                        padding=padding)
        _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_channels={}_filter_shape={}_padding={}_strides={}_input_shape={}"
            .format(channels, filter_shape, padding, strides, input_shape),
            "channels": channels,
            "filter_shape": filter_shape,
            "padding": padding,
            "strides": strides,
            "input_shape": input_shape
        } for channels in [2, 3]
          for filter_shape in [(1, 1), (2, 3), (3, 3)]
          for padding in ["SAME", "VALID"]
          for strides in [None, (2, 1), (2, 2)]
          for input_shape in [(2, 10, 11, 1)]))
    def testConvTransposeShape(self, channels, filter_shape, padding, strides,
                               input_shape):
        init_fun, apply_fun = stax.ConvTranspose(
            channels,
            filter_shape,  # 2D
            strides=strides,
            padding=padding)
        _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_channels={}_filter_shape={}_padding={}_strides={}_input_shape={}"
            .format(channels, filter_shape, padding, strides, input_shape),
            "channels": channels,
            "filter_shape": filter_shape,
            "padding": padding,
            "strides": strides,
            "input_shape": input_shape
        } for channels in [2, 3]
          for filter_shape in [(1,), (2,), (3,)]
          for padding in ["SAME", "VALID"]
          for strides in [None, (1,), (2,)]
          for input_shape in [(2, 10, 1)]))
    def testConv1DTransposeShape(self, channels, filter_shape, padding,
                                 strides, input_shape):
        init_fun, apply_fun = stax.Conv1DTranspose(channels,
                                                   filter_shape,
                                                   strides=strides,
                                                   padding=padding)
        _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_out_dim={}_input_shape={}".format(out_dim, input_shape),
            "out_dim": out_dim,
            "input_shape": input_shape
        } for out_dim in [3, 4] for input_shape in [(2, 3), (3, 4)]))
    def testDenseShape(self, out_dim, input_shape):
        init_fun, apply_fun = stax.Dense(out_dim)
        _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_input_shape={}_nonlinear={}".format(input_shape, nonlinear),
            "input_shape": input_shape,
            "nonlinear": nonlinear
        } for input_shape in [(2, 3), (2, 3, 4)]
          for nonlinear in ["Relu", "Sigmoid", "Elu", "LeakyRelu"]))
    def testNonlinearShape(self, input_shape, nonlinear):
        init_fun, apply_fun = getattr(stax, nonlinear)
        _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_window_shape={}_padding={}_strides={}_input_shape={}"
            "_maxpool={}_spec={}".format(window_shape, padding, strides,
                                         input_shape, max_pool, spec),
            "window_shape": window_shape,
            "padding": padding,
            "strides": strides,
            "input_shape": input_shape,
            "max_pool": max_pool,
            "spec": spec
        } for window_shape in [(1, 1), (2, 3)]
          for padding in ["VALID"]
          for strides in [None, (2, 1)]
          for input_shape in [(2, 5, 6, 4)]
          for max_pool in [False, True]
          for spec in ["NHWC", "NCHW", "WHNC", "WHCN"]))
    def testPoolingShape(self, window_shape, padding, strides, input_shape,
                         max_pool, spec):
        layer = stax.MaxPool if max_pool else stax.AvgPool
        init_fun, apply_fun = layer(window_shape,
                                    padding=padding,
                                    strides=strides,
                                    spec=spec)
        _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": f"_shape={input_shape}",
            "input_shape": input_shape
        } for input_shape in [(2, 3), (2, 3, 4)]))
    def testFlattenShape(self, input_shape):
        init_fun, apply_fun = stax.Flatten
        _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": f"_input_shape={input_shape}_spec={i}",
                "input_shape": input_shape,
                "spec": spec
            } for input_shape in [(2, 5, 6, 1)]
            for i, spec in enumerate([
                [stax.Conv(3, (2, 2))],
                [stax.Conv(3, (2, 2)), stax.Flatten, stax.Dense(4)],
            ])))
    def testSerialComposeLayersShape(self, input_shape, spec):
        init_fun, apply_fun = stax.serial(*spec)
        _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": f"_input_shape={input_shape}",
            "input_shape": input_shape
        } for input_shape in [(3, 4), (2, 5, 6, 1)]))
    def testDropoutShape(self, input_shape):
        init_fun, apply_fun = stax.Dropout(0.9)
        _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": f"_input_shape={input_shape}",
            "input_shape": input_shape
        } for input_shape in [(3, 4), (2, 5, 6, 1)]))
    def testFanInSum(self, input_shape):
        init_fun, apply_fun = stax.FanInSum
        _CheckShapeAgreement(self, init_fun, apply_fun,
                             [input_shape, input_shape])

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": f"_inshapes={input_shapes}_axis={axis}",
            "input_shapes": input_shapes,
            "axis": axis
        } for input_shapes, axis in [
            ([(2, 3), (2, 1)], 1),
            ([(2, 3), (2, 1)], -1),
            ([(1, 2, 4), (1, 1, 4)], 1),
        ]))
    def testFanInConcat(self, input_shapes, axis):
        init_fun, apply_fun = stax.FanInConcat(axis)
        _CheckShapeAgreement(self, init_fun, apply_fun, input_shapes)

    def testIssue182(self):
        key = random.PRNGKey(0)
        init_fun, apply_fun = stax.Softmax
        input_shape = (10, 3)
        inputs = np.arange(30.).astype("float32").reshape(input_shape)

        out_shape, params = init_fun(key, input_shape)
        out = apply_fun(params, inputs)

        assert out_shape == out.shape
        assert np.allclose(np.sum(np.asarray(out), -1), 1.)

    def testBatchNormNoScaleOrCenter(self):
        key = random.PRNGKey(0)
        axes = (0, 1, 2)
        init_fun, apply_fun = stax.BatchNorm(axis=axes,
                                             center=False,
                                             scale=False)
        input_shape = (4, 5, 6, 7)
        inputs = random_inputs(self.rng(), input_shape)

        out_shape, params = init_fun(key, input_shape)
        out = apply_fun(params, inputs)
        means = np.mean(out, axis=(0, 1, 2))
        std_devs = np.std(out, axis=(0, 1, 2))
        assert np.allclose(means, np.zeros_like(means), atol=1e-4)
        assert np.allclose(std_devs, np.ones_like(std_devs), atol=1e-4)

    def testBatchNormShapeNHWC(self):
        key = random.PRNGKey(0)
        init_fun, apply_fun = stax.BatchNorm(axis=(0, 1, 2))
        input_shape = (4, 5, 6, 7)
        inputs = random_inputs(self.rng(), input_shape)

        out_shape, params = init_fun(key, input_shape)
        out = apply_fun(params, inputs)

        self.assertEqual(out_shape, input_shape)
        beta, gamma = params
        self.assertEqual(beta.shape, (7, ))
        self.assertEqual(gamma.shape, (7, ))
        self.assertEqual(out_shape, out.shape)

    def testBatchNormShapeNCHW(self):
        key = random.PRNGKey(0)
        # Regression test for https://github.com/google/jax/issues/461
        init_fun, apply_fun = stax.BatchNorm(axis=(0, 2, 3))
        input_shape = (4, 5, 6, 7)
        inputs = random_inputs(self.rng(), input_shape)

        out_shape, params = init_fun(key, input_shape)
        out = apply_fun(params, inputs)

        self.assertEqual(out_shape, input_shape)
        beta, gamma = params
        self.assertEqual(beta.shape, (5, ))
        self.assertEqual(gamma.shape, (5, ))
        self.assertEqual(out_shape, out.shape)
Example #17
def testDenseShape(self, out_dim, input_shape):
    init_fun, apply_fun = stax.Dense(out_dim)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)
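
_CheckShapeAgreement, used throughout Examples #16 and #17, is a helper defined elsewhere in the test file. A minimal sketch of what it must do, assuming random_inputs draws float arrays of the given shape (or list of shapes):

def _CheckShapeAgreement(test_case, init_fun, apply_fun, input_shape):
    # Hypothetical reconstruction: the output shape reported by init_fun
    # should match the shape actually produced by apply_fun.
    result_shape, params = init_fun(random.PRNGKey(0), input_shape)
    result = apply_fun(params, random_inputs(np.random.RandomState(0), input_shape))
    test_case.assertEqual(result.shape, result_shape)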