Example #1
def default_agent(obs_spec: specs.Array, action_spec: specs.DiscreteArray):
    """Initialize a DQN agent with default parameters."""
    network_init, network = stax.serial(
        stax.Flatten,
        stax.Dense(50),
        stax.Relu,
        stax.Dense(50),
        stax.Relu,
        stax.Dense(action_spec.num_values),
    )
    _, network_params = network_init(random.PRNGKey(seed=1),
                                     (-1, ) + obs_spec.shape)

    return DQNJAX(action_spec=action_spec,
                  network=network,
                  parameters=network_params,
                  batch_size=32,
                  discount=0.99,
                  replay_capacity=10000,
                  min_replay_size=100,
                  sgd_period=1,
                  target_update_period=4,
                  learning_rate=1e-3,
                  epsilon=0.05,
                  seed=42)
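A hedged usage sketch for the example above: constructing the agent from dm_env-style specs. The concrete spec shapes and sizes here are illustrative assumptions, not values from the source.

import numpy as np
from dm_env import specs

# Hypothetical environment specs (e.g. a CartPole-like task).
obs_spec = specs.Array(shape=(4,), dtype=np.float32, name='observation')
action_spec = specs.DiscreteArray(num_values=2, name='action')
agent = default_agent(obs_spec, action_spec)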
Example #2
    def __init__(self, input_dims, output_dims, scope_var: OrderedDict):
        super(NeuralODE, self).__init__(input_dims, output_dims, scope_var)

        intermediate_dims = 2 * self.output_dims * self.input_dims
        init_random_params, self.predict = stax.serial(
            stax.Flatten,
            stax.Dense(
                intermediate_dims,
                partial(nn.initializers.glorot_normal(), dtype=np.float64),
                partial(nn.initializers.normal(), dtype=np.float64),
            ),
            stax.Sigmoid,
            stax.Dense(
                intermediate_dims,
                partial(nn.initializers.glorot_normal(), dtype=np.float64),
                partial(nn.initializers.normal(), dtype=np.float64),
            ),
            stax.Sigmoid,
            stax.Dense(
                output_dims * input_dims,
                partial(nn.initializers.glorot_normal(), dtype=np.float64),
                partial(nn.initializers.normal(), dtype=np.float64),
            ),
        )

        self.key = random.PRNGKey(py_random.randrange(9999))
        _, init_params = init_random_params(self.key, (1, input_dims))
        scope_var["params"] = init_params
Example #3
def main(unused_argv):
  # Load the data.
  print('Loading data.')
  x_train, y_train, x_test, y_test = datasets.mnist(permute_train=True)

  # Build the network
  init_fn, f = stax.serial(
      stax.Dense(2048),
      stax.Tanh,
      stax.Dense(10))

  key = random.PRNGKey(0)
  _, params = init_fn(key, (-1, 784))

  # Linearize the network about its initial parameters.
  f_lin = linearize(f, params)

  # Create and initialize an optimizer for both f and f_lin.
  opt_init, opt_apply, get_params = optimizers.momentum(FLAGS.learning_rate,
                                                        0.9)
  opt_apply = jit(opt_apply)

  state = opt_init(params)
  state_lin = opt_init(params)

  # Create a cross-entropy loss function.
  loss = lambda fx, y_hat: -np.mean(stax.logsoftmax(fx) * y_hat)

  # Specialize the loss function to compute gradients for both linearized and
  # full networks.
  grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))
  grad_loss_lin = jit(grad(lambda params, x, y: loss(f_lin(params, x), y)))

  # Train the network.
  print('Training.')
  print('Epoch\tLoss\tLinearized Loss')
  print('------------------------------------------')

  epoch = 0
  steps_per_epoch = 50000 // FLAGS.batch_size

  for i, (x, y) in enumerate(datasets.minibatch(
      x_train, y_train, FLAGS.batch_size, FLAGS.train_epochs)):

    params = get_params(state)
    state = opt_apply(i, grad_loss(params, x, y), state)

    params_lin = get_params(state_lin)
    state_lin = opt_apply(i, grad_loss_lin(params_lin, x, y), state_lin)

    if i % steps_per_epoch == 0:
      print('{}\t{:.4f}\t{:.4f}'.format(
          epoch, loss(f(params, x), y), loss(f_lin(params_lin, x), y)))
      epoch += 1

  # Print out summary data comparing the linear / nonlinear model.
  x, y = x_train[:10000], y_train[:10000]
  util.print_summary('train', y, f(params, x), f_lin(params_lin, x), loss)
  util.print_summary(
      'test', y_test, f(params, x_test), f_lin(params_lin, x_test), loss)
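For reference, `linearize(f, params)` above is neural-tangents' first-order Taylor expansion of the network around its initial parameters. A minimal sketch of the same idea using plain jax.jvp (an illustration of the technique, not the library's implementation):

import jax

def linearize_sketch(f, params_0):
    """f_lin(params, x) = f(params_0, x) + J(params_0) . (params - params_0)."""
    def f_lin(params, x):
        # Direction in parameter space: (params - params_0), leaf by leaf.
        dparams = jax.tree_util.tree_map(lambda p, p0: p - p0, params, params_0)
        # jvp returns the primal output and the directional derivative.
        f_0, df = jax.jvp(lambda p: f(p, x), (params_0,), (dparams,))
        return f_0 + df
    return f_lin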
Example #4
def build_model_stax(output_size,
                     n_dense_units=300,
                     conv_depth=300,
                     n_conv_layers=2,
                     n_dense_layers=0,
                     kernel_size=5,
                     across_batch=False,
                     add_pos_encoding=False,
                     mean_over_pos=False,
                     mode="train"):
    """Build a model with convolutional layers followed by dense layers."""
    del mode
    layers = [
        cnn(conv_depth=conv_depth,
            n_conv_layers=n_conv_layers,
            kernel_size=kernel_size,
            across_batch=across_batch,
            add_pos_encoding=add_pos_encoding)
    ]
    for _ in range(n_dense_layers):
        layers.append(stax.Dense(n_dense_units))
        layers.append(stax.Relu)

    layers.append(stax.Dense(output_size))

    if mean_over_pos:
        layers.append(reduce_layer(jnp.mean, axis=1))
    init_random_params, predict = stax.serial(*layers)
    return init_random_params, predict
Example #5
    def __init__(self, input_dims, output_dims, scope_var: OrderedDict):
        self.input_dims = input_dims
        self.output_dims = output_dims
        self.key = random.PRNGKey(py_random.randrange(9999))

        init_random_params, self.predict = stax.serial(
            stax.Flatten,
            stax.Dense(
                input_dims * 4,
                partial(nn.initializers.glorot_normal(), dtype=np.float64),
                partial(nn.initializers.normal(), dtype=np.float64),
            ),
            stax.Softplus,
            stax.Dense(
                input_dims * 4,
                partial(nn.initializers.glorot_normal(), dtype=np.float64),
                partial(nn.initializers.normal(), dtype=np.float64),
            ),
            stax.Softplus,
            stax.Dense(
                output_dims * 2,
                partial(nn.initializers.glorot_normal(), dtype=np.float64),
                partial(nn.initializers.normal(), dtype=np.float64),
            ),
        )

        _, init_params = init_random_params(self.key, (-1, input_dims))

        scope_var["encoder_params"] = init_params
Example #6
def decoder(hidden_dim: int, out_dim: int) -> Tuple[Callable, Callable]:
    return stax.serial(
        stax.Dense(hidden_dim, W_init=stax.randn()),
        stax.Softplus,
        stax.Dense(out_dim, W_init=stax.randn()),
        stax.Sigmoid,
    )
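A minimal usage sketch for the decoder above (the latent and output sizes are assumptions): stax.serial returns an (init_fn, apply_fn) pair, so the decoder is initialized and applied like any other stax network.

from jax import random

init_fn, apply_fn = decoder(hidden_dim=400, out_dim=784)
_, params = init_fn(random.PRNGKey(0), (-1, 50))   # assume 50-dim latents
z = random.normal(random.PRNGKey(1), (8, 50))
x_mean = apply_fn(params, z)                       # (8, 784), in (0, 1) via Sigmoid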
Example #7
def MultiHeadedAttention(  # pylint: disable=invalid-name
        feature_depth,
        num_heads=8,
        dropout=1.0,
        mode='train'):
    """Transformer-style multi-headed attention.

  Args:
    feature_depth: int:  depth of embedding
    num_heads: int: number of attention heads
    dropout: float: dropout rate - keep probability
    mode: str: 'train' or 'eval'

  Returns:
    Multi-headed self-attention layer.
  """
    return stax.serial(
        stax.parallel(stax.Dense(feature_depth, W_init=xavier_uniform()),
                      stax.Dense(feature_depth, W_init=xavier_uniform()),
                      stax.Dense(feature_depth, W_init=xavier_uniform()),
                      stax.Identity),
        PureMultiHeadedAttention(feature_depth,
                                 num_heads=num_heads,
                                 dropout=dropout,
                                 mode=mode),
        stax.Dense(feature_depth, W_init=xavier_uniform()),
    )
Example #8
    def test_kohn_sham_neural_xc_density_mse_converge_tolerance(
            self, density_mse_converge_tolerance, expected_converged):
        init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
            stax.serial(stax.Dense(8), stax.Elu, stax.Dense(1)))
        params_init = init_fn(rng=random.PRNGKey(0))

        states = jit_scf.kohn_sham(
            locations=self.locations,
            nuclear_charges=self.nuclear_charges,
            num_electrons=self.num_electrons,
            num_iterations=3,
            grids=self.grids,
            xc_energy_density_fn=tree_util.Partial(xc_energy_density_fn,
                                                   params=params_init),
            interaction_fn=utils.exponential_coulomb,
            initial_density=self.num_electrons *
            utils.gaussian(grids=self.grids, center=0., sigma=0.5),
            density_mse_converge_tolerance=density_mse_converge_tolerance)

        np.testing.assert_array_equal(states.converged, expected_converged)

        for single_state in scf.state_iterator(states):
            self._test_state(
                single_state,
                self._create_testing_external_potential(
                    utils.exponential_coulomb))
Example #9
    def make_stax_model(self):
        act = getattr(jstax, self._activation)
        layers = []
        for h in self._hiddens:
            layers.extend([jstax.Dense(h), act])
        layers.extend([jstax.Dense(1), _StaxSqueeze()])
        return jstax.serial(*layers)
Example #10
def main():

    net_init, net_apply = stax.serial(
        stax.Dense(128),
        stax.Softplus,
        stax.Dense(128),
        stax.Softplus,
        stax.Dense(2),
    )

    opt_init, opt_update, get_params = optimizers.adam(1e-3)

    out_shape, net_params = net_init(jax.random.PRNGKey(seed=42),
                                     input_shape=(-1, 2))
    opt_state = opt_init(net_params)

    loss_history = []

    print("Training...")

    train_step = get_train_step(opt_update, get_params, net_apply)

    for i in range(2000):
        x = sample_batch(size=128)
        loss, opt_state = train_step(i, opt_state, x)
        loss_history.append(loss.item())

    print("Training Finished...")

    plot_gradients(loss_history, opt_state, get_params, net_params, net_apply)
Example #11
File: vae.py  Project: while519/numpyro
def decoder(hidden_dim, out_dim):
    return stax.serial(
        stax.Dense(hidden_dim, W_init=stax.randn()),
        stax.Softplus,
        stax.Dense(out_dim, W_init=stax.randn()),
        stax.Sigmoid,
    )
Example #12
def gen_cnn_conv4(output_units=10,
                  W_initializers_str='glorot_normal()',
                  b_initializers_str='normal()'):
    # This is an up-scaled version of the CNN in keras tutorial: https://keras.io/examples/cifar10_cnn/
    return stax.serial(
        stax.Conv(out_chan=64,
                  filter_shape=(3, 3),
                  W_init=eval(W_initializers_str),
                  b_init=eval(b_initializers_str)), stax.Relu,
        stax.Conv(out_chan=64,
                  filter_shape=(3, 3),
                  W_init=eval(W_initializers_str),
                  b_init=eval(b_initializers_str)), stax.Relu,
        stax.MaxPool((2, 2), strides=(2, 2)),
        stax.Conv(out_chan=128,
                  filter_shape=(3, 3),
                  W_init=eval(W_initializers_str),
                  b_init=eval(b_initializers_str)), stax.Relu,
        stax.Conv(out_chan=128,
                  filter_shape=(3, 3),
                  W_init=eval(W_initializers_str),
                  b_init=eval(b_initializers_str)), stax.Relu,
        stax.MaxPool((2, 2), strides=(2, 2)), stax.Flatten,
        stax.Dense(512,
                   W_init=eval(W_initializers_str),
                   b_init=eval(b_initializers_str)), stax.Relu,
        stax.Dense(output_units,
                   W_init=eval(W_initializers_str),
                   b_init=eval(b_initializers_str)))
Example #13
def encoder(hidden_dim, z_dim):
    return stax.serial(
        stax.Dense(hidden_dim, W_init=stax.randn()), stax.Softplus,
        stax.FanOut(2),
        stax.parallel(stax.Dense(z_dim, W_init=stax.randn()),
                      stax.serial(stax.Dense(z_dim, W_init=stax.randn()), stax.Exp)),
    )
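The encoder above fans its hidden features out to two heads: a mean and, via stax.Exp, a strictly positive standard deviation. A hedged sketch of how such a pair is typically consumed in a VAE reparameterization step (the shapes and batch here are illustrative assumptions, not from the source):

from jax import random

init_fn, encode = encoder(hidden_dim=400, z_dim=50)
_, params = init_fn(random.PRNGKey(0), (-1, 784))
x = random.normal(random.PRNGKey(1), (8, 784))     # stand-in input batch
z_mean, z_std = encode(params, x)                  # stax.parallel returns a tuple
eps = random.normal(random.PRNGKey(2), z_mean.shape)
z = z_mean + z_std * eps                           # reparameterization trick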
Example #14
def mlp(args):
    return stax.serial(
        stax.Dense(args.hidden_dim),
        stax.Softplus,
        stax.Dense(args.hidden_dim),
        stax.Softplus,
        stax.Dense(args.output_dim),
    )
Example #15
def MLP(num_hidden_layers=2,
        hidden_size=512,
        activation_fn=stax.Relu,
        num_output_classes=10):
    layers = [stax.Flatten]
    layers += [stax.Dense(hidden_size), activation_fn] * num_hidden_layers
    layers += [stax.Dense(num_output_classes), stax.LogSoftmax]
    return stax.serial(*layers)
Example #16
def make_network(num_layers, num_channels):
  layers = []
  for _ in range(num_layers - 1):
    layers.append(stax.Dense(num_channels))
    layers.append(stax.Relu)
  layers.append(stax.Dense(3))
  layers.append(stax.Sigmoid)
  return stax.serial(*layers)
Example #17
    def test_local_density_approximation(self):
        init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
            stax.serial(stax.Dense(16), stax.Elu, stax.Dense(1)))
        init_params = init_fn(rng=random.PRNGKey(0))
        xc_energy_density = xc_energy_density_fn(self.density, init_params)

        # The first layer of the network takes 1 feature (density).
        self.assertEqual(init_params[0][0].shape, (1, 16))
        self.assertEqual(xc_energy_density.shape, (11, ))
Example #18
    def test_local_density_approximation_wrong_output_shape(self):
        init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
            stax.serial(stax.Dense(16), stax.Elu, stax.Dense(3)))
        init_params = init_fn(rng=random.PRNGKey(0))

        with self.assertRaisesRegex(
                ValueError, r'The output shape of the network '
                r'should be \(-1, 1\) but got \(11, 3\)'):
            xc_energy_density_fn(self.density, init_params)
Example #19
def encoder(hidden_dim, z_dim, activation='Tanh'):
    activation = getattr(stax, activation)
    encoder_init, encode = stax.serial(
        stax.Dense(hidden_dim),
        activation,
        # stax.Dense(hidden_dim), activation,
        stax.FanOut(2),
        stax.parallel(stax.Dense(z_dim), stax.Dense(z_dim)),
    )
    return encoder_init, encode
Example #20
File: nnet.py  Project: dkirkby/zotbin
def create_model(nbin, nhidden, nlayer):
    layers = []
    for i in range(nlayer):
        layers.extend([
            stax.Dense(nhidden),
            stax.LeakyRelu,
            stax.BatchNorm(axis=(0, 1)),
        ])
    layers.extend([stax.Dense(nbin), stax.Softmax])
    return stax.serial(*layers)
Example #21
def decoder(hidden_dim, x_dim=2, activation='Tanh'):
    activation = getattr(stax, activation)
    decoder_init, decode = stax.serial(
        stax.Dense(hidden_dim),
        activation,
        # stax.Dense(hidden_dim), activation,
        stax.FanOut(2),
        stax.parallel(stax.Dense(x_dim), stax.Dense(x_dim)),
    )
    return decoder_init, decode
Example #22
    def transform(rng, input_dim, output_dim):
        init_fun, apply_fun = stax.serial(
            stax.Dense(hidden_dim, weight_initializer, weight_initializer),
            act,
            stax.Dense(hidden_dim, weight_initializer, weight_initializer),
            act,
            stax.Dense(output_dim, weight_initializer, weight_initializer),
        )
        _, params = init_fun(rng, (input_dim,))
        return params, apply_fun
Example #23
def jaxRbmSpinPhase(hilbert, alpha):
    return stax.serial(
        stax.FanOut(2),
        stax.parallel(
            stax.serial(stax.Dense(alpha * hilbert.size), LogCoshLayer,
                        SumLayer),
            stax.serial(stax.Dense(alpha * hilbert.size), LogCoshLayer,
                        SumLayer),
        ),
        FanInSum2ModPhase,
    )
Example #24
def mlp(key, input_dim, output_dim, hidden_layers=(64, 32)):
    ''' make multilayer perceptron '''
    layers = []
    for hl in hidden_layers:
        layers += [stax.Dense(hl), stax.Tanh]
    layers.append(stax.Dense(output_dim))

    init_fun, apply_fun = stax.serial(*layers)
    params = init_fun(key, (-1, input_dim))[1]

    return apply_fun, params
Example #25
def get_f_and_L():
  _, apply = stax.serial(TexVar(stax.Dense(256), 'z^1', ('W^1', 'b^1')),
                         TexVar(stax.Relu, 'y^1'),
                         TexVar(stax.Dense(3), 'z^2', ('W^2', 'b^2')))
  def f_(params, x):
    return apply(params, tex_var(x, 'x', True))

  def L_(params, x, y_hat):
    y_hat = tex_var(y_hat, '\\hat y', True)
    return tex_var(-np.sum(y_hat * jax.nn.log_softmax(f_(params, x))), 'L')
  return f_, L_
Example #26
def gen_cnn_lenet_caffe(output_units=10,
                        W_initializers_str='glorot_normal()',
                        b_initializers_str='normal()'):
    return stax.serial(
        stax.Conv(out_chan=20,
                  filter_shape=(5, 5),
                  W_init=eval(W_initializers_str),
                  b_init=eval(b_initializers_str)), stax.Relu,
        stax.MaxPool((2, 2), strides=(2, 2)),
        stax.Conv(out_chan=50,
                  filter_shape=(5, 5),
                  W_init=eval(W_initializers_str),
                  b_init=eval(b_initializers_str)), stax.Relu,
        stax.MaxPool((2, 2), strides=(2, 2)),
        stax.Flatten,
        stax.Dense(500,
                   W_init=eval(W_initializers_str),
                   b_init=eval(b_initializers_str)), stax.Relu,
        stax.Dense(output_units,
                   W_init=eval(W_initializers_str),
                   b_init=eval(b_initializers_str)))
Example #27
def get_f_and_nngp():
  _, apply = stax.serial(
      TexVar(stax.Dense(256), 'z^1', ('W^1', 'b^1'), True),
      TexVar(stax.Relu, 'y^1'),
      TexVar(stax.Dense(3), 'z^2', ('W^2', 'b^2')))
  def f_(params, x):
    return apply(params, tex_var(x, 'x', True))
  def nngp_(params, x1, x2):
    x1 = tex_var(x1, 'x^1', True)
    x2 = tex_var(x2, 'x^2', True)
    return tex_var(apply(params, x1) @ apply(params, x2).T, '\\mathcal K')
  return f_, nngp_
Example #28
def _create_stax_model(num_classes, sample_shape):
  """Creates toy stax model."""
  stax_init_fn, stax_apply_fn = stax.serial(stax.Flatten,
                                            stax.Dense(2 * num_classes),
                                            stax.Dense(num_classes))
  metrics_fn_map = collections.OrderedDict(accuracy=_accuracy)
  return model.create_model_from_stax(
      stax_init_fn=stax_init_fn,
      stax_apply_fn=stax_apply_fn,
      sample_shape=sample_shape,
      loss_fn=_loss,
      metrics_fn_map=metrics_fn_map)
Example #29
def fully_connected(num_classes, layers=(64, 64)):
    """Build a fully connected neural network."""
    stack = [stax.Flatten]

    # Concatenate fully connected layers.
    for num_units in layers:
        stack += [stax.Dense(num_units), stax.Relu]

    # Output layer.
    stack += [stax.Dense(num_classes), stax.LogSoftmax]

    return stax.serial(*stack)
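A minimal usage sketch for the network above (the batch data is synthetic, for illustration only): because the stack ends in stax.LogSoftmax, cross-entropy reduces to a mean over the log-probabilities selected by the one-hot labels.

import jax.numpy as jnp
from jax import random

init_fn, apply_fn = fully_connected(num_classes=10)
_, params = init_fn(random.PRNGKey(0), (-1, 784))
x_batch = random.normal(random.PRNGKey(1), (32, 784))
y_onehot = jnp.eye(10)[random.randint(random.PRNGKey(2), (32,), 0, 10)]
log_probs = apply_fn(params, x_batch)              # (32, 10) log-probabilities
loss = -jnp.mean(jnp.sum(log_probs * y_onehot, axis=-1))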
Example #30
def encoder(hidden_dim: int, z_dim: int) -> Tuple[Callable, Callable]:
    return stax.serial(
        stax.Dense(hidden_dim, W_init=stax.randn()),
        stax.Softplus,
        stax.FanOut(2),
        stax.parallel(
            stax.Dense(z_dim, W_init=stax.randn()),
            stax.serial(
                stax.Dense(z_dim, W_init=stax.randn()),
                stax.Exp,
            ),
        ),
    )