Code example #1
    def test_optimized_lstm_cell_matches_regular(self):

        # Create regular LSTMCell.
        rng = random.PRNGKey(0)
        key1, key2 = random.split(rng)
        x = random.normal(key1, (2, 3))
        c0, h0 = nn.LSTMCell.initialize_carry(rng, (2, ), 4)
        self.assertEqual(c0.shape, (2, 4))
        self.assertEqual(h0.shape, (2, 4))
        (carry, y), initial_params = nn.LSTMCell.init(key2, (c0, h0), x)
        lstm = nn.Model(nn.LSTMCell, initial_params)

        # Create OptimizedLSTMCell.
        rng = random.PRNGKey(0)
        key1, key2 = random.split(rng)
        x = random.normal(key1, (2, 3))
        c0, h0 = nn.OptimizedLSTMCell.initialize_carry(rng, (2, ), 4)
        self.assertEqual(c0.shape, (2, 4))
        self.assertEqual(h0.shape, (2, 4))
        (carry, y_opt), initial_params = nn.OptimizedLSTMCell.partial(
            name='LSTMCell').init(key2, (c0, h0), x)
        lstm_opt = nn.Model(nn.OptimizedLSTMCell.partial(name='LSTMCell'),
                            initial_params)

        onp.testing.assert_allclose(y, y_opt, rtol=1e-6)
        jtu.check_eq(lstm.params, lstm_opt.params)
Code example #2
File: optim_test.py Project: yanndupis/flax
 def test_param_selection(self):
   params = {
       'x': {
           'kernel': 1,
           'bias': 2,
           'y': {
               'kernel': 3,
               'bias': 4,
           },
       },
   }
   names = []
   def filter_fn(name, _):
     names.append(name)  # track names passed to filter_fn for testing
     return 'kernel' in name
   model = nn.Model(None, params)
   traversal = optim.ModelParamTraversal(filter_fn)
   values = list(traversal.iterate(model))
   self.assertEqual(values, [1, 3])
   self.assertEqual(set(names), set([
       '/x/kernel', '/x/bias', '/x/y/kernel', '/x/y/bias']))
   new_model = traversal.update(lambda x: x + x, model)
   expected_params = {
       'x': {
           'kernel': 2,
           'bias': 2,
           'y': {
               'kernel': 6,
               'bias': 4,
           },
       },
   }
   expected_model = nn.Model(None, expected_params)
   self.assertEqual(new_model, expected_model)
Code example #3
File: nn_test.py Project: zhang-yd15/flax
 def test_nested_model(self):
   x = jnp.array([1.])
   _, inner_initial_params = DummyModule.init(random.PRNGKey(0), x)
   inner_model = nn.Model(DummyModule, inner_initial_params)
   _, initial_params = NestedModel.init(random.PRNGKey(1), x, inner_model)
   model = nn.Model(NestedModel, initial_params)
   y = model(x, inner_model)
   self.assertEqual(y, jnp.array([3.]))
Code example #4
File: nn_test.py Project: zhang-yd15/flax
 def test_nested_model_capture_outputs(self):
   x = jnp.array([1.])
   _, inner_initial_params = DummyModule.init(random.PRNGKey(0), x)
   inner_model = nn.Model(DummyModule, inner_initial_params)
   _, initial_params = NestedModel.init(random.PRNGKey(1), x, inner_model)
   model = nn.Model(NestedModel, initial_params)
   with nn.capture_module_outputs() as activations:
     model(x, inner_model)
   expected_activations = {
       '/': [x + 2],
       '/dummy_0': [x + 1],
       '/inner_model': [x + 2],
   }
   self.assertEqual(activations.as_dict(), expected_activations)
Code example #5
    def test_autoregressive_sampling_with_lstm(self):

        L = 4

        # Set up symmetry orbit
        orbit = jnp.array([
            jnp.roll(jnp.identity(L, dtype=np.int32), l, axis=1)
            for l in range(L)
        ])

        # Set up variational wave function
        rnn = nets.LSTM.partial(L=L, hiddenSize=5)
        _, params = rnn.init_by_shape(random.PRNGKey(0), [(L, )])
        rnnModel = nn.Model(rnn, params)
        rbm = nets.RBM.partial(numHidden=2, bias=False)
        _, params = rbm.init_by_shape(random.PRNGKey(0), [(L, )])
        rbmModel = nn.Model(rbm, params)

        psi = NQS((rnnModel, rbmModel))

        # Set up exact sampler
        exactSampler = sampler.ExactSampler(L)

        # Set up MCMC sampler
        mcSampler = sampler.MCMCSampler(random.PRNGKey(0),
                                        jVMC.sampler.propose_spin_flip, (L, ),
                                        numChains=777)

        # Compute exact probabilities
        _, logPsi, pex = exactSampler.sample(psi)

        numSamples = 1000000
        smc, p, _ = mcSampler.sample(psi, numSamples=numSamples)

        self.assertTrue(jnp.max(jnp.abs(jnp.real(psi(smc) - p))) < 1e-12)

        if global_defs.usePmap:
            smc = smc.reshape((smc.shape[0] * smc.shape[1], -1))

        self.assertTrue(smc.shape[0] >= numSamples)

        # Compute histogram of sampled configurations
        smcInt = jax.vmap(state_to_int)(smc)
        pmc, _ = np.histogram(smcInt, bins=np.arange(0, 17))

        self.assertTrue(
            jnp.max(
                jnp.abs(pmc / mcSampler.get_last_number_of_samples() -
                        pex.reshape((-1, ))[:16])) < 1e-3)
Code example #6
def create_model(key, flax_module, input_shape, model_kwargs):
    module = flax_module.partial(**model_kwargs)
    with nn.stochastic(key):
        _, initial_params = module.init_by_shape(key,
                                                 [(input_shape, jnp.float32)])
        model = nn.Model(module, initial_params)
    return model
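A minimal usage sketch for this helper, assuming the pre-Linen flax.nn API used throughout these examples; the module choice, shapes, and kwargs here are illustrative assumptions, not from the original source:

# Hypothetical usage: a 10-way nn.Dense head on flattened 28x28 inputs.
# `nn.Dense` and its `features` kwarg are real flax.nn API; the shapes
# and kwargs are assumptions for illustration.
key = random.PRNGKey(0)
model = create_model(key,
                     flax_module=nn.Dense,
                     input_shape=(1, 28 * 28),
                     model_kwargs={'features': 10})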
Code example #7
File: serialization_test.py Project: shafiahmed/flax
 def test_optimizer_serialization(self):
     rng = random.PRNGKey(0)
     module = nn.Dense.partial(features=1, kernel_init=nn.initializers.ones)
     _, initial_params = module.init_by_shape(rng, [((1, 1), jnp.float32)])
     model = nn.Model(module, initial_params)
     optim_def = optim.Momentum(learning_rate=1.)
     optimizer = optim_def.create(model)
     state = serialization.to_state_dict(optimizer)
     expected_state = {
         'target': {
             'params': {
                 'kernel': onp.ones((1, 1)),
                 'bias': onp.zeros((1, )),
             }
         },
         'state': {
             'step': 0,
             'param_states': {
                 'params': {
                     'kernel': {
                         'momentum': onp.zeros((1, 1))
                     },
                     'bias': {
                         'momentum': onp.zeros((1, ))
                     },
                 }
             }
         },
     }
     self.assertEqual(state, expected_state)
     state = jax.tree_map(lambda x: x + 1, expected_state)
     restored_optimizer = serialization.from_state_dict(optimizer, state)
     optimizer_plus1 = jax.tree_map(lambda x: x + 1, optimizer)
     self.assertEqual(restored_optimizer, optimizer_plus1)
Code example #8
  def test_grad_var(self):
    model_size = 10
    example_grads = [{
        'layer1': np.ones(model_size),
        'layer2': 3 * np.ones(model_size)
    }, {
        'layer1': 2 * np.ones(model_size),
        'layer2': np.ones(model_size)
    }]
    eval_config = {'ema_beta': 0.5}
    training_metrics_grabber = utils.TrainingMetricsGrabber.create(
        example_grads[0], eval_config)

    # For the purposes of this test, we create fake optimizers to satisfy
    # metrics grabber API.
    fake_model = nn.Model(None, example_grads[0])
    new_optimizer = optimizers.GradientDescent(
        learning_rate=None).create(fake_model)
    old_optimizer = optimizers.GradientDescent(
        learning_rate=None).create(fake_model)

    for grad in example_grads:
      training_metrics_grabber = training_metrics_grabber.update(
          grad, old_optimizer, new_optimizer)

    for layer in ['layer1', 'layer2']:
      expected_grad_ema = (1 / 4 * np.zeros(model_size) +
                           1 / 4 * example_grads[0][layer] +
                           1 / 2 * example_grads[1][layer])

      self.assertArraysAllClose(expected_grad_ema,
                                training_metrics_grabber.state[layer].grad_ema)
Code example #9
File: train.py Project: shafiahmed/flax
 def init(key):
     with nn.attention.Cache().mutate() as cache_def:
         _, initial_params = model_def.init_by_shape(
             key, [(input_shape, jnp.float32), (target_shape, jnp.float32)],
             cache=cache_def)
         model = nn.Model(model_def, initial_params)
     return model, cache_def
Code example #10
File: nn_test.py Project: zhang-yd15/flax
  def test_call_module_method(self):
    class MultiMethod(nn.Module):

      def apply(self, x):
        return x + self.param('bias', x.shape, initializers.ones)

      @nn.module_method
      def l2(self):
        return jnp.sum(self.get_param('bias') ** 2)

    class MultiMethodModel(nn.Module):

      def apply(self, x):
        layer = MultiMethod.shared()
        layer(x)  # init
        return layer.l2()

    self.assertEqual(
        MultiMethod.l2.__qualname__,
        MultiMethod.__qualname__ + '.l2')

    x = jnp.array([1., 2.])

    _, params = MultiMethod.init(random.PRNGKey(0), x)
    model = nn.Model(MultiMethod, params)
    self.assertEqual(model.l2(), 2.)
    
    y, _ = MultiMethodModel.init(random.PRNGKey(0), x)
    self.assertEqual(y, 2.)
Code example #11
def create_representation_model(encoder_fn,
                                encoder_fn_kwargs,
                                reduce_fn,
                                reduce_fn_kwargs,
                                num_categories,
                                output_features,
                                embed=False,
                                key=random.PRNGKey(0)):
    """Instantiates a RepresentationModel object."""

    module = RepresentationModel.partial(encoder_fn=encoder_fn,
                                         encoder_fn_kwargs=encoder_fn_kwargs,
                                         reduce_fn=reduce_fn,
                                         reduce_fn_kwargs=reduce_fn_kwargs,
                                         num_categories=num_categories,
                                         output_features=output_features,
                                         embed=embed)

    # Reuse the partially-applied module so the init kwargs cannot drift
    # from the ones baked into `module` above.
    _, initial_params = module.init_by_shape(
        key, input_specs=[((1, 1), jnp.float32)])

    model = nn.Model(module, initial_params)

    return model
Code example #12
File: nn_test.py Project: wrzadkow/flax
 def test_conv2dlstm(self):
     rng = random.PRNGKey(0)
     key1, key2 = random.split(rng)
     x = random.normal(key1, (2, 4, 4, 3))
     c0, h0 = nn.ConvLSTM.initialize_carry(rng, (2, ), (4, 4, 6))
     self.assertEqual(c0.shape, (2, 4, 4, 6))
     self.assertEqual(h0.shape, (2, 4, 4, 6))
     (carry, y), initial_params = nn.ConvLSTM.init(key2, (c0, h0),
                                                   x,
                                                   features=6,
                                                   kernel_size=(3, 3))
     lstm = nn.Model(nn.ConvLSTM, initial_params)
     self.assertEqual(carry[0].shape, (2, 4, 4, 6))
     self.assertEqual(carry[1].shape, (2, 4, 4, 6))
     onp.testing.assert_allclose(y, carry[1])
     param_shapes = jax.tree_map(onp.shape, lstm.params)
     self.assertEqual(
         param_shapes, {
             'hh': {
                 'bias': (6 * 4, ),
                 'kernel': (3, 3, 6, 6 * 4)
             },
             'ih': {
                 'bias': (6 * 4, ),
                 'kernel': (3, 3, 3, 6 * 4)
             },
         })
Code example #13
def main(argv):
    key = random.PRNGKey(0)
    train_ds = tfds.load('mnist', split=tfds.Split.TRAIN)
    train_ds = train_ds.cache().shuffle(1000).batch(FLAGS.batch_size)
    test_ds = tfds.as_numpy(
        tfds.load('mnist', split=tfds.Split.TEST, batch_size=-1))

    _, params = VAE.init_by_shape(key, [((1, 784), jnp.float32)])
    vae = nn.Model(VAE, params)

    optimizer = optim.Adam(learning_rate=FLAGS.learning_rate).create(vae)

    for epoch in range(FLAGS.num_epochs):
        for batch in tfds.as_numpy(train_ds):
            batch['image'] = batch['image'].reshape(-1, 784) / 255.0
            optimizer = train_step(optimizer, batch)

        z = np.random.normal(size=(64, 20))
        metrics, comparison, sample = eval(optimizer.target, test_ds, z)
        save_image(comparison,
                   'results/reconstruction_' + str(epoch) + '.png',
                   nrow=8)
        save_image(sample, 'results/sample_' + str(epoch) + '.png', nrow=8)

        print("eval epoch: {}, loss: {:.4f}, BCE: {:.4f}, KLD: {:.4f}".format(
            epoch + 1, metrics['loss'], metrics['bce'], metrics['kld']))
Code example #14
def create_model(key, batch_size, image_size, model_dtype, space_to_depth):
    """Initialize a ResNet-50 model."""
    if space_to_depth:
        input_shape = (batch_size, image_size // 2, image_size // 2, 3 * 2 * 2)
    else:
        input_shape = (batch_size, image_size, image_size, 3)
    model_type = models.FakeResNet if FLAGS.fake_model else models.ResNet
    batchnorm_span = FLAGS.batchnorm_span
    if batchnorm_span is None:
        batchnorm_span = max(batch_size, 64)
    if FLAGS.distributed_batchnorm and (batch_size < batchnorm_span <=
                                        batch_size * jax.device_count()):
        mllogger.event('model_bn_span', batchnorm_span)
        model_def = model_type.partial(num_classes=1000,
                                       axis_name='batch',
                                       axis_index_groups=local_replica_groups(
                                           batchnorm_span // batch_size),
                                       dtype=model_dtype,
                                       conv0_space_to_depth=space_to_depth)
    else:
        mllogger.event('model_bn_span', batch_size)
        model_def = model_type.partial(num_classes=1000,
                                       dtype=model_dtype,
                                       conv0_space_to_depth=space_to_depth)
    with nn.stateful() as init_state:
        _, params = model_def.init_by_shape(key, [(input_shape, model_dtype)])
    model = nn.Model(model_def, params)
    return model, init_state
Code example #15
File: basic_dgp.py Project: danieljtait/ladax
def create_model(key, input_shape):
    def inducing_loc_init(key, shape):
        return jnp.linspace(-1.5, 1.5,
                            FLAGS.num_inducing_points)[:, jnp.newaxis]

    kwargs = {}
    for i in range(1, FLAGS.num_layers + 1):
        kwargs['kernel_fn_{}_kwargs'.format(i)] = {
            'amplitude_init': lambda key, shape: jnp.ones(shape),
            'length_scale_init': lambda key, shape: jnp.ones(shape)
        }
        kwargs['inducing_var_{}_kwargs'.format(i)] = {
            'fixed_locations': False,
            'whiten': FLAGS.whiten,
            'inducing_locations_init': inducing_loc_init
        }

    model_def = DeepGPModel.partial(**kwargs)

    with nn.stochastic(key):
        _, params = model_def.init_by_shape(key, [
            (input_shape, jnp.float64),
        ], nn.make_rng(), **kwargs)

        return nn.Model(model_def, params)
Code example #16
def create_model(rng):
    """Creates a model."""
    vocab_size = params['vocab_length']
    _, initial_params = charRNN.init_by_shape(
        rng, [((1, params['seq_length'], vocab_size), jnp.float32)])
    model = nn.Model(charRNN, initial_params)
    return model
Code example #17
    def test_gradients_nonhermitian(self):

        dlist = [jax.devices()[0], jax.devices()]

        for ds in dlist:

            global_defs.set_pmap_devices(ds)

            net = nets.CpxRNN.partial(L=3)
            _, params1 = net.init_by_shape(random.PRNGKey(0), [(3, )])
            model = nn.Model(net, params1)

            s = jnp.zeros(get_shape((4, 3)), dtype=np.int32)
            s = jax.ops.index_update(s, jax.ops.index[..., 0, 1], 1)
            s = jax.ops.index_update(s, jax.ops.index[..., 2, 2], 1)

            psi = NQS(model)
            psi0 = psi(s)
            G = psi.gradients(s)
            delta = 1e-5
            params = psi.get_parameters()
            for j in range(G.shape[-1]):
                u = jax.ops.index_update(
                    jnp.zeros(G.shape[-1], dtype=jVMC.global_defs.tReal),
                    jax.ops.index[j], 1)
                psi.update_parameters(delta * u)
                psi1 = psi(s)
                psi.set_parameters(params)

                # Finite difference gradients
                Gfd = (psi1 - psi0) / delta

                with self.subTest(i=j):
                    self.assertTrue(jnp.max(jnp.abs(Gfd - G[..., j])) < 1e-2)
Code example #18
    def test_gradients_cpx(self):

        dlist = [jax.devices()[0], jax.devices()]

        for ds in dlist:

            global_defs.set_pmap_devices(ds)

            rbm = nets.CpxRBM.partial(numHidden=2, bias=True)
            _, params = rbm.init_by_shape(random.PRNGKey(0), [(1, 3)])
            rbmModel = nn.Model(rbm, params)
            s = jnp.zeros(get_shape((4, 3)), dtype=np.int32)
            s = jax.ops.index_update(s, jax.ops.index[..., 0, 1], 1)
            s = jax.ops.index_update(s, jax.ops.index[..., 2, 2], 1)

            psiC = NQS(rbmModel)
            psi0 = psiC(s)
            G = psiC.gradients(s)
            delta = 1e-5
            params = psiC.get_parameters()
            for j in range(G.shape[-1]):
                u = jax.ops.index_update(
                    jnp.zeros(G.shape[-1], dtype=global_defs.tReal),
                    jax.ops.index[j], 1)
                psiC.update_parameters(delta * u)
                psi1 = psiC(s)
                psiC.set_parameters(params)

                # Finite difference gradients
                Gfd = (psi1 - psi0) / delta

                with self.subTest(i=j):
                    self.assertTrue(jnp.max(jnp.abs(Gfd - G[..., j])) < 1e-2)
Code example #19
File: nn_test.py Project: zhang-yd15/flax
 def test_truncated_module(self):
   x = jnp.array([1.])
   _, initial_params = NestedModule.init(random.PRNGKey(0), x)
   model = nn.Model(NestedModule, initial_params)
   model = model.truncate_at('/dummy_0')
   y = model(x)
   self.assertEqual(y, [x + 1])
Code example #20
File: inducing_var.py Project: danieljtait/ladax
def create_loss(rng, model, train_ds):

    loss_clz = losses.VariationalGaussianLikelihoodLoss

    dist = model(train_ds['index_points'])
    _, params = loss_clz.init(rng, train_ds['y'], dist)
    return nn.Model(loss_clz, params)
Code example #21
File: train.py Project: vballoli/flax
def create_model(key, input_shape, model_kwargs):
    model_def = models.TransformerLM.partial(**model_kwargs)
    with nn.attention.Cache().mutate() as cache_def:
        _, initial_params = model_def.init_by_shape(
            key, [(input_shape, jnp.float32)], cache=cache_def)
    model = nn.Model(model_def, initial_params)
    return model, cache_def
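The returned cache_def is what seeds fast autoregressive decoding. A short sketch of the follow-up step, with assumed batch size and sequence length; `initialize_cache` is the same flax.nn.attention.Cache API exercised in code example #26 below:

# Hypothetical follow-up, reusing the objects returned above
# (batch size 8 and max length 128 are assumed shapes):
model, cache_def = create_model(key, (8, 128), model_kwargs)
cache0 = cache_def.initialize_cache((8, 128))
# Each decoding step then applies the model with `cache=...` inside a
# `cache.mutate()` block, as shown in code example #26.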
Code example #22
File: train.py Project: shafiahmed/flax
def train():
    """Run main training loop."""
    rng = random.PRNGKey(0)

    # Get Zachary's karate club graph dataset.
    node_feats, node_labels, sources, targets = get_karate_club_data()

    # Create model and optimizer.
    _, initial_params = GNN.init(rng,
                                 node_x=node_feats,
                                 edge_x=None,
                                 sources=sources,
                                 targets=targets)
    model = nn.Model(GNN, initial_params)
    optimizer = optim.Adam(learning_rate=0.01).create(model)

    # Train for 20 iterations.
    for iteration in range(20):
        optimizer, loss = train_step(optimizer, node_feats, sources, targets)

        accuracy = eval_step(  # Model is stored in `optimizer.target`.
            optimizer.target, node_feats, sources, targets, node_labels)

        print('iteration: %d, loss: %.4f, accuracy: %.2f' %
              (iteration + 1, loss, accuracy * 100))
Code example #23
    def train(
        self,
        epochs=None,
        batch_size=None,
        model_save_path=None,
        display_every=1000,
    ):
        """ Trains the model for a fixed number of epochs"""
        dim_x = self.data.geom.dim
        train_data = self.data.train_data()
        # Split columns: the first dim_x columns are input coordinates,
        # the remainder are targets (the original `[:, dim_x]` dropped a
        # colon and selected a single column).
        train_points = device_put(train_data[:, :dim_x])
        train_tag = device_put(train_data[:, dim_x:])
        print('+-+-+-+-+-+-+-')

        _, initial_params = FNN.init_by_shape(jax.random.PRNGKey(0),
                                              [((1, 1, 3), jnp.float32)])
        model = nn.Model(FNN, initial_params)

        optimizer_def = flax.optim.Adam(learning_rate=self.learning_rate)
        optimizer = optimizer_def.create(model)
        print('+++++++++++++')

        first_grad = grad(optimizer.target)(train_points)
        second_grad = jax.hessian(optimizer.target)(train_points).diagonal()

        print('------------')
        print(first_grad, second_grad)
        return first_grad, second_grad
Code example #24
def create_model(config):
    """Create a model, starting with a pre-trained checkpoint."""
    model_kwargs = dict(config=config.model, )
    model_def = modeling.BertForPreTraining.partial(**model_kwargs)
    if config.init_checkpoint:
        initial_params = import_weights.load_params(
            init_checkpoint=config.init_checkpoint,
            hidden_size=config.model.hidden_size,
            num_attention_heads=config.model.num_attention_heads,
            keep_masked_lm_head=True)
    else:
        with nn.stochastic(jax.random.PRNGKey(0)):
            _, initial_params = model_def.init_by_shape(
                jax.random.PRNGKey(0),
                [((1, config.max_seq_length), jnp.int32),
                 ((1, config.max_seq_length), jnp.int32),
                 ((1, config.max_seq_length), jnp.int32),
                 ((1, config.max_predictions_per_seq), jnp.int32)],
                deterministic=True)

            def fixup_for_tpu(x, i=[0]):
                """HACK to fix incorrect param initialization on TPU."""
                if isinstance(x, jax.ShapeDtypeStruct):
                    i[0] += 1
                    if len(x.shape) == 2:
                        return jnp.zeros(x.shape, x.dtype)
                    else:
                        return nn.linear.default_kernel_init(
                            jax.random.PRNGKey(i[0]), x.shape, x.dtype)
                else:
                    return x

            initial_params = jax.tree_map(fixup_for_tpu, initial_params)
    model = nn.Model(model_def, initial_params)
    return model
Code example #25
File: train_test.py Project: yanndupis/flax
  def test_permutation_invariance(self):

    num_nodes = 4
    num_features = 2
    rng = random.PRNGKey(0)

    # Generate random graph.
    adjacency = random.randint(rng, (num_nodes, num_nodes), 0, 2)
    node_feats = random.normal(rng, (num_nodes, num_features))
    sources, targets = jnp.where(adjacency)

    # Get permuted graph.
    perm = random.permutation(rng, jnp.arange(num_nodes))
    node_feats_perm = node_feats[perm]
    adjacency_perm = adjacency[perm]
    for j in range(len(adjacency)):
      adjacency_perm = jax.ops.index_update(
          adjacency_perm, j, adjacency_perm[j][perm])
    sources_perm, targets_perm = jnp.where(adjacency_perm)

    # Create GNN.
    _, initial_params = GNN.init(
      rng, node_x=node_feats, edge_x=None, sources=sources, targets=targets)
    model = nn.Model(GNN, initial_params)

    # Feedforward both original and permuted graph.
    logits = model(node_feats, None, sources, targets)
    logits_perm = model(node_feats_perm, None, sources_perm, targets_perm)

    self.assertAllClose(logits[perm], logits_perm, check_dtypes=False)
Code example #26
  def test_decoding(self, spatial_shape, attn_dims):
    bs = 2
    num_heads = 3
    num_features = 4
    rng = random.PRNGKey(0)
    key1, key2 = random.split(rng)
    inputs = random.normal(
        key1, (bs,) + spatial_shape + (num_heads * num_features,))
    module = nn.SelfAttention.partial(
        num_heads=num_heads,
        qkv_features=num_heads * num_features,
        attention_axis=attn_dims,
        causal_mask=True,
        precision=lax.Precision.HIGHEST)

    with nn.attention.Cache().mutate() as cache_def:
      _, initial_params = module.init_by_shape(
          key2, [(inputs.shape, inputs.dtype)], cache=cache_def)
    model = nn.Model(module, initial_params)
    y_ref = jax.jit(lambda f, x: f(x))(model, inputs)

    # feed the inputs sequentially to simulate decoding
    cache0 = cache_def.initialize_cache((bs,) + spatial_shape)
    def body_fn(cache, x):
      with cache.mutate() as new_cache:
        y = model(x, cache=new_cache)
      return new_cache, y
    # scan_in_dim supports scanning multiple dims
    _, y = jax_utils.scan_in_dim(body_fn, cache0, inputs,
                                 axis=attn_dims, keepdims=True)

    onp.testing.assert_allclose(y_ref, y, atol=1e-5)
Code example #27
def initialize(flax_module_def, initializer, loss_fn, input_shape,
               output_shape, hps, rng, metrics_logger):
    """Run the given initializer.

  We initialize in 3 phases. First we run the default initializer that is
  specified by the model constructor. Next we apply any rescaling as specified
  by hps.layer_rescale_factors. Finally we run the black box initializer
  provided by the initializer arg (the default is noop).

  Args:
    flax_module_def: An uninitialized flax module definition.
    initializer: An initializer defined in init_lib.
    loss_fn: A loss function.
    input_shape: The input shape of a single data example.
    output_shape: The output shape of a single data example.
    hps: A dictionary specifying the model and initializer hparams.
    rng: An rng key to seed the initialization.
    metrics_logger: Used for black box initializers that have learning curves.

  Returns:
    A tuple (model, batch_stats), where model is the initialized
    flax.nn.Model and batch_stats is the collection used for batch norm.
  """
    model_dtype = utils.dtype_from_str(hps.model_dtype)
    # init_by_shape should either pass in a tuple or a list of tuples.
    # For example, for vision tasks typically input_shape is (image_shape)
    # For seq2seq tasks, shape can be a list of two tuples corresponding to
    # input_sequence_shape for encoder and output_sequence_shape for decoder.
    # TODO(gilmer,ankugarg): Support initializers for list of tuples.
    if isinstance(input_shape, list):  # Typical case for seq2seq models
        input_specs = [((hps.batch_size, *x), model_dtype)
                       for x in input_shape]
    else:  # Typical case for classification models
        input_specs = [((hps.batch_size, *input_shape), model_dtype)]
    params_rng, init_rng, dropout_rng = jax.random.split(rng, num=3)

    with nn.stateful() as batch_stats:
        with nn.stochastic(dropout_rng):
            # Using flax_module_def.create can OOM for larger models, so we
            # use init_by_shape here instead.
            # TODO(gilmer) Link to flax issue when bug reporting process finalizes.
            _, params = flax_module_def.init_by_shape(params_rng,
                                                      input_specs,
                                                      train=False)
    model = nn.Model(flax_module_def, params)

    if hps.get('layer_rescale_factors'):
        model = model_utils.rescale_layers(model, hps.layer_rescale_factors)
    # We don't pass batch_stats to the initializer, the initializer will just
    # run batch_norm in train mode and does not need to maintain the batch_stats.
    # TODO(gilmer): We hardcode here weighted_cross_entropy, but this will need
    # to change for other models. Maybe have meta_loss_inner as an initializer
    # hyper_param?
    # TODO(gilmer): instead of passing in weighted_xent, pass in the model and get
    # the loss from that.
    new_model = initializer(loss_fn, model, hps, input_shape, output_shape,
                            init_rng, metrics_logger)

    return new_model, batch_stats
Code example #28
File: nn_test.py Project: zhang-yd15/flax
 def test_shared_module(self):
   rng = random.PRNGKey(0)
   x = jnp.array([1.])
   _, initial_params = LoopModule.init(rng, x)
   model = nn.Model(LoopModule, initial_params)
   y = model(x)
   self.assertEqual(y, jnp.array([3.]))
   self.assertEqual(model.params, {'dummy': {'bias': jnp.array([1.])}})
Code example #29
def create_model(key, batch_size, image_size, model_dtype):
    input_shape = (batch_size, image_size, image_size, 3)
    module = models.ResNet.partial(num_classes=1000, dtype=model_dtype)
    with nn.stateful() as init_state:
        _, initial_params = module.init_by_shape(key,
                                                 [(input_shape, model_dtype)])
        model = nn.Model(module, initial_params)
    return model, init_state
Code example #30
File: serialization_test.py Project: shafiahmed/flax
 def test_model_serialization_to_bytes(self):
     rng = random.PRNGKey(0)
     module = nn.Dense.partial(features=1, kernel_init=nn.initializers.ones)
     _, initial_params = module.init_by_shape(rng, [((1, 1), jnp.float32)])
     model = nn.Model(module, initial_params)
     serialized_bytes = serialization.to_bytes(model)
     restored_model = serialization.from_bytes(model, serialized_bytes)
     self.assertEqual(restored_model.params, model.params)
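The byte string returned by serialization.to_bytes can be written straight to disk; a minimal file round-trip sketch using the same to_bytes/from_bytes pair demonstrated above (the path is illustrative):

# Hypothetical file round trip; '/tmp/model.msgpack' is a made-up path.
with open('/tmp/model.msgpack', 'wb') as f:
    f.write(serialization.to_bytes(model))
with open('/tmp/model.msgpack', 'rb') as f:
    restored_model = serialization.from_bytes(model, f.read())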