Exemple #1
0
 def learning_rate(self, step):
     """Returns the learning rate to use at training step `step`.

     When an lr schedule is configured it is evaluated under the NumPy
     backend; otherwise the optimizer's initial 'learning_rate' parameter
     is returned.
     """
     schedule = self._lr_schedule
     if schedule is None:
         # No schedule: fall back to the optimizer's initial hyper-parameters.
         opt_params = self._optimizer._init_opt_params  # pylint: disable=protected-access
         return opt_params['learning_rate']
     with fastmath.use_backend(fastmath.Backend.NUMPY):
         return schedule(step)
Exemple #2
0
 def test_lsh_ff(self):
     """LSHFF returns an output with the same shape as a (3, 7, 1024) input."""
     with fastmath.use_backend(fastmath.Backend.JAX):
         ff_layer = efficient_attention.LSHFF(d_ff=8 * 1024, n_buckets=[16, 8])
         inputs = np.ones((3, 7, 1024)).astype(np.float32)
         _, _ = ff_layer.init(shapes.signature(inputs))
         outputs = ff_layer(inputs)
         self.assertEqual(outputs.shape, inputs.shape)
Exemple #3
0
    def test_no_int32_or_uint32_returned(self):
        """Tests that Trainer._jit_update_fn doesn't return int32 or uint32.

        TF pins int32/uint32 tensors to CPU, which will cause
        XLA-forced-compiled computation to copy int32/uint32 outputs to CPU.
        This test makes sure that won't happen.
        """
        if xla_bridge.device_count() > 1:
            self.skipTest(
                "tf-numpy backend doesn't support multi-devices yet.")
        with fastmath.use_backend(fastmath.Backend.TFNP), \
              self.tmp_dir() as output_dir:
            n_classes = 1001
            model_fn = functools.partial(models.Resnet50,
                                         n_output_classes=n_classes)
            inputs = _test_inputs(n_classes, input_shape=(224, 224, 3))
            trainer = trainer_lib.Trainer(
                model=model_fn,
                loss_fn=layers.CrossEntropyLoss(),
                optimizer=trax_opt.SM3,
                lr_schedule=lr.multifactor(),
                inputs=inputs,
            )
            trainer.reset(output_dir)
            trainer.train_epoch(1, 0)
            # Those are the things returned by Trainer._jit_update_fn
            arrays = (trainer._opt_state.weights, trainer._opt_state.slots,
                      trainer._model_state, trainer._rngs)
            arrays = tf.nest.flatten(arrays)
            # NOTE(review): only jnp.ndarray leaves are inspected; confirm the
            # TFNP backend actually yields jnp.ndarray here, otherwise this
            # check can never fire.
            for x in arrays:
                if isinstance(x, jnp.ndarray) and (x.dtype == jnp.int32
                                                   or x.dtype == jnp.uint32):
                    # self.fail reports a test failure (AssertionError) rather
                    # than an unexpected error from a stray ValueError.
                    self.fail('Found an array of int32 or uint32: %s' % x)
Exemple #4
0
    def test_train_fills_in_missing_eval_metrics(self, backend):
        """An extra eval stream yields an eval task with the same metrics."""
        with fastmath.use_backend(backend):
            n_classes = 4
            model_fn = functools.partial(
                models.MLP, layer_widths=(16, 16, n_classes))
            inputs = _test_inputs(n_classes)
            # The extra stream intentionally re-uses the eval data.
            extra_stream = trainer_lib.NamedStream(
                stream=inputs.eval_stream(1), name='additional_eval_task')

            loop = trainer_lib.train(
                self.create_tempdir().full_path,
                model=model_fn,
                inputs=inputs,
                steps=2,
                eval_steps=2,
                eval_frequency=1,
                additional_eval_streams=[extra_stream])

            self.assertLen(loop.eval_tasks, 2)
            first_task, second_task = loop.eval_tasks
            self.assertCountEqual(first_task.metrics, second_task.metrics)
            self.assertCountEqual(first_task.metric_names,
                                  second_task.metric_names)
Exemple #5
0
 def test_blocksparse_ff_predict_equals_eval(self):
   """BlockSparseFF gives the same output in 'eval' and 'predict' modes."""
   d_model = 1024
   x_shape = (1, 1, d_model)
   shared_kwargs = {
       'd_ff': d_model * 8,
       'n_experts': 64,
       'temperature': 0.7,
   }
   with fastmath.use_backend(fastmath.Backend.JAX):
     x = np.ones(x_shape).astype(np.float32)
     signature = shapes.signature(x)

     eval_layer = sparsity.BlockSparseFF(mode='eval', **shared_kwargs)
     weights, state = eval_layer.init(signature)
     eval_out, _ = eval_layer.pure_fn(
         x, weights, state, rng=jax.random.PRNGKey(0))

     predict_layer = sparsity.BlockSparseFF(mode='predict', **shared_kwargs)
     # The predict layer's own init result is discarded; it is run with the
     # eval layer's weights and state so both use identical parameters.
     _, _ = predict_layer.init(signature)
     pred_out, _ = predict_layer.pure_fn(
         x, weights, state, rng=jax.random.PRNGKey(0))

     self.assertEqual(eval_out.shape, x.shape)
     np.testing.assert_array_almost_equal(eval_out[0, 0, :], pred_out[0, 0, :])
Exemple #6
0
 def test_autoregressive_sample_transformerlm_tfnp(self):
     """Samples from a predict-mode TransformerLM under the TFNP backend.

     Covers a single sequence, a batch of two, and sampling from a prefix.
     """
     with fastmath.use_backend(fastmath.Backend.TFNP):
         model = models.TransformerLM(
             10, d_model=32, d_ff=64, n_layers=1, n_heads=2, mode='predict')

         def init_for_batch(batch):
             # (Re)initialize the model for an int32 input of shape (batch, 1).
             model.init(shapes.ShapeDtype((batch, 1), dtype=np.int32))

         # With eos_id=-1 the exact length is checked below.
         init_for_batch(1)
         single = decoding.autoregressive_sample(
             model, batch_size=1, eos_id=-1, max_length=10)
         self.assertEqual(single.shape[0], 1)
         self.assertEqual(single.shape[1], 10)

         # Default eos_id: only an upper bound on the length is checked.
         init_for_batch(2 // fastmath.device_count())
         batched = decoding.autoregressive_sample(
             model, batch_size=2, max_length=10)
         self.assertEqual(batched.shape[0], 2)
         self.assertLess(batched.shape[1], 11)

         init_for_batch(1)
         prefix = np.array([[1, 2, 3]])
         prefixed = decoding.autoregressive_sample(
             model, prefix, eos_id=-1, max_length=10, batch_size=1)
         self.assertEqual(prefixed.shape[0], 1)
         self.assertEqual(prefixed.shape[1], 10)
Exemple #7
0
    def _test_train_eval_predict(self, backend):
        """Trains a stateful model, then runs it with the final weights/state."""
        with fastmath.use_backend(backend):
            n_classes = 4
            steps = 2

            def model_fn(mode='train'):
                # Dropout and BatchNorm are included to exercise state handling.
                stack = [
                    tl.Dropout(mode=mode, rate=0.1),
                    tl.BatchNorm(mode=mode),
                    models.MLP(d_hidden=16,
                               n_output_classes=n_classes,
                               mode=mode),
                ]
                return tl.Serial(*stack)

            inputs = _test_inputs(n_classes)

            loop = trainer_lib.train(
                self.create_tempdir().full_path,
                model=model_fn,
                inputs=inputs,
                steps=steps,
                eval_steps=2,
                eval_frequency=1)  # Evaluate after every step.
            self.assertEqual(steps, loop.step)

            # Run a freshly built model on one training batch using the
            # weights and state the training loop ended with.
            batch_stream = inputs.train_stream(1)
            fresh_model = model_fn()
            final_weights = loop.model.weights
            final_state = loop.model.state
            fresh_model(next(batch_stream)[0],
                        weights=final_weights,
                        state=final_state)
Exemple #8
0
 def test_grulm_forward_shape(self, backend):
     """GRULM maps a (3, 28) int32 input to a (3, 28, vocab_size) output."""
     vocab_size = 20
     with fastmath.use_backend(backend):
         lm = rnn.GRULM(vocab_size=vocab_size, d_model=16)
         tokens = np.ones((3, 28)).astype(np.int32)
         _, _ = lm.init(shapes.signature(tokens))
         outputs = lm(tokens)
         self.assertEqual(outputs.shape, (3, 28, vocab_size))
Exemple #9
0
    def test_train_eval_predict_sm3(self, backend):
        """Trains an MLP with SM3, then reloads saved weights and predicts."""
        with fastmath.use_backend(backend):
            n_classes = 4
            n_train_steps = 2
            model_fn = functools.partial(
                models.MLP, d_hidden=16, n_output_classes=n_classes)
            inputs = _test_inputs(n_classes)

            output_dir = self.create_tempdir().full_path
            loop = trainer_lib.train(
                output_dir,
                model=model_fn,
                inputs=inputs,
                steps=n_train_steps,
                eval_steps=2,
                eval_frequency=1,  # Evaluate after every step.
                optimizer=trax_opt.SM3)
            self.assertEqual(n_train_steps, loop.step)

            # Load the weights file written by training into a new model and
            # run it on one training batch.
            weights_file = os.path.join(output_dir, 'model.pkl.gz')
            batch_stream = inputs.train_stream(1)
            model = model_fn()
            model.init_from_file(weights_file)
            model(next(batch_stream)[0])
Exemple #10
0
    def test_train_restart(self, backend):
        """A second train() call in the same dir builds on the first run."""
        with fastmath.use_backend(backend):
            n_classes = 4
            steps = 2
            model_fn = functools.partial(
                models.MLP, d_hidden=16, n_output_classes=n_classes)
            inputs = _test_inputs(n_classes)
            output_dir = self.create_tempdir().full_path

            def run_training(n_steps):
                return trainer_lib.train(output_dir,
                                         model=model_fn,
                                         inputs=inputs,
                                         steps=n_steps,
                                         eval_steps=2,
                                         eval_frequency=1)

            run_training(steps)  # Initial run.
            loop = run_training(2 * steps)  # Restart in the same directory.

            # The restarted loop trains for as many steps as it is given, on
            # top of the `steps` already done: steps + 2 * steps = 3 * steps.
            self.assertEqual(loop.step, 3 * steps)
Exemple #11
0
 def test_sru(self, backend):
     """SRU(7) returns an output shaped like its (8, 9, 7) float32 input."""
     with fastmath.use_backend(backend):
         sru = tl.SRU(7)
         inputs = np.ones((8, 9, 7), np.float32)
         _, _ = sru.init(shapes.signature(inputs))
         outputs = sru(inputs)
         self.assertEqual(outputs.shape, inputs.shape)
Exemple #12
0
 def test_lstm_cell(self, backend):
     """LSTMCell(9) returns two outputs shaped like its two inputs."""
     with fastmath.use_backend(backend):
         cell = tl.LSTMCell(9)
         inputs = [np.ones((8, 9)), np.ones((8, 18))]
         _, _ = cell.init(shapes.signature(inputs))
         outputs = cell(inputs)
         output_shapes = [out.shape for out in outputs]
         self.assertEqual(output_shapes, [(8, 9), (8, 18)])
Exemple #13
0
 def test_conv_gru_cell(self, backend):
     """ConvGRUCell(9, (3, 3)) keeps the shape of an (8, 1, 7, 9) input."""
     with fastmath.use_backend(backend):
         cell = tl.ConvGRUCell(9, kernel_size=(3, 3))
         inputs = np.ones((8, 1, 7, 9))
         _, _ = cell.init(shapes.signature(inputs))
         outputs = cell(inputs)
         self.assertEqual(outputs.shape, inputs.shape)
Exemple #14
0
  def test_reformer2_predict_equals_eval(self):
    """Runs test_utils.test_eval_equals_predict on a tiny Reformer2."""
    with fastmath.use_backend(fastmath.Backend.JAX):
      vocab_size, d_model = 16, 8
      batch_size, length = 2, 5

      model_fn = functools.partial(
          reformer.Reformer2,
          vocab_size,
          d_model=d_model,
          d_ff=16,
          n_encoder_layers=1,
          n_decoder_layers=1,
          n_heads=2,
          dropout=0.0,
          max_len=2 * length,
          pos_type=None,
          n_decoder_attention_layers=1,
          encoder_attention_type=tl.Attention,
          encoder_decoder_attention_type=tl.CausalAttention,
      )

      io_shape = (batch_size, length)
      inp = np.random.randint(vocab_size, size=io_shape)
      out = np.zeros(io_shape, dtype=np.int32)

      # TODO(jaszczur): check why init_tokens > 1 fails nondeterministically
      test_utils.test_eval_equals_predict(
          (inp, out), model_fn, 1, -1, init_tokens=1)
Exemple #15
0
  def test_multi_input(self, backend):
    """Scan over axis 1 with a three-input cell whose carry counts steps."""

    def _MultiInputFn():  # pylint: disable=invalid-name
      def f(a, b, carry):
        return a + b, b, carry + 1
      return tl.Fn('MultiInputFn', f, n_out=2)

    expected = [
        [[4, 6, 8], [40, 60, 80]],  # a + b at every step
        [[4, 5, 6], [40, 50, 60]],  # b passed through unchanged
        [9003, 8003],               # carry after three increments
    ]
    with fastmath.use_backend(backend):
      layer = tl.Scan(_MultiInputFn(), axis=1)
      a = np.array([[0, 1, 2], [0, 10, 20]])
      b = np.array([[4, 5, 6], [40, 50, 60]])
      carry = np.array([9000, 8000])
      ys = layer([a, b, carry])
      self.assertEqual(as_list(ys), expected)
Exemple #16
0
    def test_sparse_ff_predict_equals_eval(self):
        """Eval/predict agreement for SparseFF across several variants."""
        with fastmath.use_backend(fastmath.Backend.JAX):
            d_model, seq_len = 64, 6
            inp = np.ones((1, seq_len, d_model)).astype(np.float32)

            model_fn = functools.partial(
                sparsity.SparseFF,
                d_ff=256,
                temperature=0.7,
                n_elements_in_block=8,
            )

            # Variants to check; each dict presumably supplies extra keyword
            # arguments to model_fn.
            configs = [
                {'multiply_by_controller_output': flag}
                for flag in (True, False)
            ]
            configs.append({'ff_chunk_size': 2})

            test_utils.test_eval_equals_predict_configs(inp, model_fn, configs)
    def _test_sparse_fast_inference(self, length):
        """Eval/predict agreement for a sparse ConfigurableTransformer.

        Args:
          length: second dimension of the generated input/target arrays.
        """
        with fastmath.use_backend(fastmath.Backend.JAX):
            vocab_size = 16
            batch_size = 2

            model_fn = functools.partial(
                ct.ConfigurableTransformer,
                input_vocab_size=vocab_size,
                d_model=4,
                d_ff=8,
                n_encoder_layers=2,
                n_decoder_layers=2,
                n_heads=2,
                loss_sparsity=2,
                ff_sparsity=2,
                encoder_decoder_attention_type=functools.partial(
                    tl.MultiplicativeConvCausalAttention,
                    sparsity=2,
                    length_kernel_size=1,
                ),
                ff_use_sru=(1, 4),
            )

            inp = np.random.randint(vocab_size, size=(batch_size, length))
            out = np.zeros((batch_size, length), dtype=np.int32)
            test_utils.test_eval_equals_predict(
                (inp, out), model_fn, seq_tensor=1)
Exemple #18
0
    def test_custom_id_grad(self, backend):
        """A layer's custom backward() determines the computed gradient."""

        class IdWithIdGrad(tl.Layer):
            """Identity layer whose backward() returns its inputs as the grad."""

            def forward(self, x):
                return x

            @property
            def has_backward(self):
                return True

            def backward(self, inputs, output, grad, weights, state, new_state,
                         rng):
                return (inputs, ())

        with fastmath.use_backend(backend):
            layer = IdWithIdGrad()
            rng = fastmath.random.get_prng(0)
            input_signature = shapes.ShapeDtype((9, 17))
            x = fastmath.random.uniform(
                rng, input_signature.shape, minval=-1.0, maxval=1.0)
            layer.init(input_signature)

            def mean_output(inputs):
                return jnp.mean(layer(inputs))

            grad = fastmath.grad(mean_output)(x)
            # One gradient entry per input element.
            self.assertEqual(grad.shape, (9, 17))
            # backward() ignores `grad` and hands back the inputs themselves.
            self.assertEqual(sum(sum(grad)), sum(sum(x)))
Exemple #19
0
    def _test_fast_inference(self, length):
        """Checks predict-mode decoding matches eval mode position by position.

        Args:
          length: number of decoding positions to compare.
        """
        with fastmath.use_backend(fastmath.Backend.JAX):
            vocab_size = 16
            model_fn = functools.partial(
                ct.ConfigurableTransformerLM,
                vocab_size=vocab_size,
                d_model=4,
                d_ff=8,
                n_layers=2,
                n_heads=2,
            )
            model_slow = model_fn(mode='eval')
            model_fast = model_fn(mode='predict')
            rng = fastmath.random.get_prng(0)
            batch_size = 2
            input_signature = shapes.ShapeDtype((batch_size, 1), np.int32)
            # Given the same rng, both models initialize with the same parameters.
            model_slow.init(input_signature, rng)
            model_fast.init(input_signature, rng)

            # buf accumulates every token chosen so far (fed to the eval
            # model); next_sym holds only the newest token (fed to predict).
            buf = np.zeros((batch_size, length), dtype=np.int32)
            next_sym = np.zeros((batch_size, 1), dtype=np.int32)

            for index in range(length):
                logits_slow = model_slow(buf, rng=rng)
                logits_fast = model_fast(next_sym, rng=rng)
                np.testing.assert_array_almost_equal(
                    logits_slow[:, index, :],
                    logits_fast[:, 0, :],
                    decimal=5,
                )
                # Keep tokens int32 to match input_signature and the initial
                # next_sym; randint's default dtype is the platform int
                # (commonly int64).
                next_sym = np.random.randint(
                    vocab_size, size=(batch_size, 1), dtype=np.int32)
                buf[:, index] = next_sym[:, 0]
Exemple #20
0
 def test_forward_dtype(self, backend, dtype):
     """BatchNorm's output has the same dtype as its input."""
     with fastmath.use_backend(backend):
         batch_norm = tl.BatchNorm()
         inputs = np.ones((3, 2, 7)).astype(dtype)
         _, _ = batch_norm.init(shapes.signature(inputs))
         outputs = batch_norm(inputs)
         self.assertEqual(outputs.dtype, dtype)
Exemple #21
0
    def test_batching_lsh_self_attention(self):
        """Reference-code equivalence of LSHSelfAttention per batching variant."""
        with fastmath.use_backend(fastmath.Backend.JAX):
            common_kwargs = {
                'n_heads': 6,
                'd_qk': 7,
                'd_v': 17,
                'causal': True,
                'chunk_len': 5,
                'n_chunks_before': 1,
                'n_chunks_after': 0,
                'n_hashes': 2,
                'n_buckets': 4,
                'attention_dropout': 0.2,
                'output_dropout': 0.1,
                'mode': 'train',
            }
            # Every (n_parallel_heads, use_python_loop) combination.
            test_kwargs = [
                dict(n_parallel_heads=heads, use_python_loop=python_loop)
                for heads in [1, 3, 6, 12]
                for python_loop in [True, False]
            ]

            x = jax.random.uniform(jax.random.PRNGKey(0), (2, 10, 13),
                                   dtype=jnp.float32)
            self._test_equivalence_to_reference_code(
                efficient_attention.LSHSelfAttention, x, shapes.signature(x),
                common_kwargs, *test_kwargs)
Exemple #22
0
    def test_train_restart(self, backend):
        """Restarting train() in the same dir stops at the requested step."""
        if xla_bridge.device_count() > 1 and backend == fastmath.Backend.TFNP:
            self.skipTest(
                "tf-numpy backend doesn't support multi-devices yet.")
        with fastmath.use_backend(backend), self.tmp_dir() as output_dir:
            n_classes = 4
            steps = 2
            model_fn = functools.partial(
                models.MLP, d_hidden=16, n_output_classes=n_classes)
            inputs = _test_inputs(n_classes)
            shared = dict(model=model_fn, inputs=inputs, eval_steps=2)

            trainer_lib.train(output_dir, steps=steps, **shared)
            # Restart in the same output directory with a larger step budget.
            state = trainer_lib.train(output_dir, steps=2 * steps, **shared)

            # Training ends at the requested total (2 * steps).
            self.assertEqual(state.step, 2 * steps)
Exemple #23
0
    def test_fast_inference_self_attention(self):
        """Fast-inference check for SelfAttention over each batching variant."""
        with fastmath.use_backend(fastmath.Backend.JAX):
            common_kwargs = {
                'n_heads': 6,
                'd_qk': 7,
                'd_v': 17,
                'share_qk': False,
                'causal': True,
                'chunk_len': 5,
                'n_chunks_before': 1,
                'n_chunks_after': 0,
                'attention_dropout': 0.0,
                'output_dropout': 0.0,
            }
            # Every (n_parallel_heads, use_python_loop) combination.
            test_kwargs = [
                dict(n_parallel_heads=heads, use_python_loop=python_loop)
                for heads in [1, 3, 6, 12]
                for python_loop in [True, False]
            ]

            x = jax.random.uniform(jax.random.PRNGKey(0), (2, 10, 13),
                                   dtype=jnp.float32)
            self._test_fast_inference(efficient_attention.SelfAttention, x,
                                      shapes.signature(x), common_kwargs,
                                      *test_kwargs)
Exemple #24
0
    def test_train_restart_with_same_steps(self, backend):
        """Restarting with an unchanged step budget leaves loop.step there."""
        with fastmath.use_backend(backend):
            n_classes = 4
            steps = 2
            model_fn = functools.partial(
                models.MLP, layer_widths=(16, 16, n_classes))
            inputs = _test_inputs(n_classes)
            output_dir = self.create_tempdir().full_path
            train_kwargs = {
                'model': model_fn,
                'inputs': inputs,
                'steps': steps,
                'eval_steps': 2,
                'eval_frequency': 1,
            }

            trainer_lib.train(output_dir, **train_kwargs)
            # Restart in the same directory with identical settings.
            loop = trainer_lib.train(output_dir, **train_kwargs)

            self.assertEqual(loop.step, steps)
Exemple #25
0
    def _test_lsh_self_attention_deterministic_given_seed(self, causal=False):
        """With init and call seeds pinned, repeated runs give equal outputs.

        Args:
          causal: passed through as LSHSelfAttention's `causal` argument.
        """
        with fastmath.use_backend(fastmath.Backend.JAX):
            layer = efficient_attention.LSHSelfAttention(
                n_heads=5,
                d_qk=7,
                d_v=17,
                causal=causal,
                chunk_len=8,
                n_chunks_before=1,
                n_chunks_after=0,
                n_hashes=2,
                n_buckets=4,
                use_reference_code=True,
                attention_dropout=0.0,
                mode='train')
            x = np.ones((3, 32, 8)).astype(np.float32)

            def run_once():
                # Re-initialize with seed 0, then call with seed 1.
                _, _ = layer.init(shapes.signature(x), jax.random.PRNGKey(0))
                return layer(x, rng=jax.random.PRNGKey(1))

            reference, *others = [run_once() for _ in range(10)]

            self.assertEqual(reference.shape, x.shape)
            for other in others:
                np.testing.assert_array_almost_equal(reference, other,
                                                     decimal=6)
Exemple #26
0
    def test_no_int32_or_uint32_returned(self):
        """Tests that Trainer._jit_update_fn doesn't return int32 or uint32.

        TF pins int32/uint32 tensors to CPU, which will cause
        XLA-forced-compiled computation to copy int32/uint32 outputs to CPU.
        This test makes sure that won't happen.
        """
        with fastmath.use_backend(fastmath.Backend.TFNP):
            n_classes = 1001
            model_fn = functools.partial(models.Resnet50,
                                         n_output_classes=n_classes)
            inputs = _test_inputs(n_classes, input_shape=(224, 224, 3))
            trainer = trainer_lib.Trainer(
                model=model_fn,
                loss_fn=tl.WeightedCategoryCrossEntropy(),
                optimizer=trax_opt.SM3,
                lr_schedule=lr.multifactor(),
                inputs=inputs,
            )
            output_dir = self.create_tempdir().full_path
            trainer.reset(output_dir)
            trainer.train_epoch(1, 0)
            # Those are the things returned by Trainer._jit_update_fn
            arrays = (trainer._opt_state.weights, trainer._opt_state.slots,
                      trainer._model_state, trainer._rngs)
            arrays = tf.nest.flatten(arrays)
            # NOTE(review): only jnp.ndarray leaves are inspected; confirm the
            # TFNP backend actually yields jnp.ndarray here, otherwise this
            # check can never fire.
            for x in arrays:
                if isinstance(x, jnp.ndarray) and (x.dtype == jnp.int32
                                                   or x.dtype == jnp.uint32):
                    # self.fail reports a test failure (AssertionError) rather
                    # than an unexpected error from a stray ValueError.
                    self.fail('Found an array of int32 or uint32: %s' % x)
Exemple #27
0
    def test_predict_equals_eval(self):
        """Eval/predict agreement for MultiplicativeConvCausalAttention variants."""
        with fastmath.use_backend(fastmath.Backend.JAX):
            d_model = 32
            seq_len = 5
            inp = np.ones((1, seq_len, d_model)).astype(np.float32)

            model_fn = functools.partial(
                sparsity.MultiplicativeConvCausalAttention,
                d_feature=d_model,
                n_heads=4,
                sparsity=4,
            )

            # Full cross product of the three options (share_qk outermost).
            list_kwargs = [
                {
                    'share_qk': share_qk,
                    'output_layer_type': output_type,
                    'v_concat_type': concat_type,
                }
                for share_qk in [True, False]
                for output_type in ['none', 'mult', 'conv', 'multconv']
                for concat_type in ['original', 'fixed', 'none']
            ]

            test_utils.test_eval_equals_predict_configs(
                inp, model_fn, list_kwargs)
Exemple #28
0
  def test_reformer2_deterministic_eval(self):
    """Runs test_utils.test_eval_is_deterministic on a decoder-only Reformer2."""
    with fastmath.use_backend(fastmath.Backend.JAX):
      vocab_size, d_model = 16, 4
      batch_size, length = 2, 5

      model_fn = functools.partial(
          reformer.Reformer2,
          vocab_size,
          d_model=d_model,
          d_ff=16,
          n_encoder_layers=0,
          n_decoder_layers=1,
          n_heads=2,
          dropout=0.0,
          max_len=2 * length,
          pos_type=None,
          encoder_attention_type=tl.Attention,
          encoder_decoder_attention_type=tl.CausalAttention,
      )

      io_shape = (batch_size, length)
      inp = np.random.randint(vocab_size, size=io_shape)
      out = np.zeros(io_shape, dtype=np.int32)

      test_utils.test_eval_is_deterministic((inp, out), model_fn)
    def test_lsh_self_attention_masked_non_causal(self):
        """Masked-out inputs must not affect the un-masked attention outputs.

        When only the masked region of the input changes, the outputs for the
        un-masked positions must stay the same, while the masked region's
        outputs are expected to change.
        """
        with fastmath.use_backend(fastmath.Backend.JAX):
            layer = efficient_attention.LSHSelfAttention(
                n_heads=5,
                d_qk=7,
                d_v=17,
                causal=False,
                masked=True,
                chunk_len=8,
                n_chunks_before=1,
                n_chunks_after=0,
                n_hashes=2,
                n_buckets=4,
                use_reference_code=True,
                attention_dropout=0.0,
                mode='train')

            batch = 5
            max_len = 32
            hidden = 8

            x = np.random.uniform(size=(batch, max_len, hidden))
            # `np.bool` was deprecated in NumPy 1.20 and removed in 1.24; the
            # builtin `bool` selects the same dtype on every NumPy version.
            mask = np.ones((batch, max_len)).astype(bool)
            # Per-example cut points, materialized as a host NumPy array so
            # they index and size NumPy arrays as plain integers.
            rngs = np.asarray(
                jax.random.randint(jax.random.PRNGKey(0), (batch, ),
                                   minval=1,
                                   maxval=max_len - 1))

            # Set some suffix of each mask[b] to 0.
            for i in range(batch):
                mask[i, rngs[i]:] = 0

            # Fix rngs and get the output for the LSH layer.
            def get_output(x, mask):
                xs = [x, mask]
                _, _ = layer.init(shapes.signature(xs), jax.random.PRNGKey(0))
                return layer(xs, rng=jax.random.PRNGKey(1))

            # Get the attention output for masked x.
            y = get_output(x, mask)

            # Change x, but only in the masked regions.
            for i in range(batch):
                x[i, rngs[i]:] = np.random.uniform(size=(max_len - rngs[i],
                                                         hidden))

            y2 = get_output(x, mask)

            for i in range(batch):
                # y and y2 should be identical in the non-masked part.
                np.testing.assert_array_almost_equal(y[i, :rngs[i]],
                                                     y2[i, :rngs[i]],
                                                     decimal=6)

                # In the masked out part, they should be different.
                self.assertGreater(
                    np.mean(np.abs(y[i, rngs[i]:] - y2[i, rngs[i]:])), 1e-5)
Exemple #30
0
 def test_names(self, backend):
   """str() of LSTM/GRU/SRU layers is '<ClassName>_<constructor arg>'."""
   cases = (
       (tl.LSTM, 3, 'LSTM_3'),
       (tl.GRU, 5, 'GRU_5'),
       (tl.SRU, 7, 'SRU_7'),
   )
   with fastmath.use_backend(backend):
     for layer_cls, size, expected in cases:
       layer = layer_cls(size)
       self.assertEqual(expected, str(layer))