def learning_rate(self, step):
  """Look up the learning rate in effect at the given step.

  Args:
    step: Value passed to the schedule callable when a schedule is set.

  Returns:
    The result of `self._lr_schedule(step)`, evaluated while the NumPy
    backend is active, or, when no schedule is set, the 'learning_rate'
    entry of the optimizer's initial parameters.
  """
  if self._lr_schedule is None:
    # No schedule configured: use the value stored with the optimizer.
    # pylint: disable=protected-access
    return self._optimizer._init_opt_params['learning_rate']
  with fastmath.use_backend(fastmath.Backend.NUMPY):
    return self._lr_schedule(step)
def test_lsh_ff(self):
  """LSHFF produces an output with the same shape as its input."""
  with fastmath.use_backend(fastmath.Backend.JAX):
    ff_layer = efficient_attention.LSHFF(d_ff=1024 * 8, n_buckets=[16, 8])
    inputs = np.ones((3, 7, 1024)).astype(np.float32)
    # Unpacking checks that init returns a two-element result.
    _, _ = ff_layer.init(shapes.signature(inputs))
    outputs = ff_layer(inputs)
    self.assertEqual(outputs.shape, inputs.shape)
def test_no_int32_or_uint32_returned(self):
  """Checks that Trainer._jit_update_fn returns no int32 or uint32 arrays.

  TF pins int32/uint32 tensors to CPU, so an XLA-forced-compiled computation
  would have to copy such outputs to CPU. This test guards against that.
  """
  if xla_bridge.device_count() > 1:
    self.skipTest(
        "tf-numpy backend doesn't support multi-devices yet.")
  with fastmath.use_backend(fastmath.Backend.TFNP):
    with self.tmp_dir() as output_dir:
      num_classes = 1001
      make_model = functools.partial(
          models.Resnet50, n_output_classes=num_classes)
      data = _test_inputs(num_classes, input_shape=(224, 224, 3))
      trainer = trainer_lib.Trainer(
          model=make_model,
          loss_fn=layers.CrossEntropyLoss(),
          optimizer=trax_opt.SM3,
          lr_schedule=lr.multifactor(),
          inputs=data,
      )
      trainer.reset(output_dir)
      trainer.train_epoch(1, 0)

      # Everything below is what Trainer._jit_update_fn hands back.
      # pylint: disable=protected-access
      returned = (
          trainer._opt_state.weights,
          trainer._opt_state.slots,
          trainer._model_state,
          trainer._rngs,
      )
      forbidden_dtypes = (jnp.int32, jnp.uint32)
      for leaf in tf.nest.flatten(returned):
        if not isinstance(leaf, jnp.ndarray):
          continue
        if leaf.dtype in forbidden_dtypes:
          raise ValueError('Found an array of int32 or uint32: %s' % leaf)
def test_train_fills_in_missing_eval_metrics(self, backend):
  """Both eval tasks of the loop carry the same metrics and metric names."""
  with fastmath.use_backend(backend):
    # Model and data setup.
    num_classes = 4
    train_steps = 2
    n_eval_steps = 2
    make_model = functools.partial(
        models.MLP, layer_widths=(16, 16, num_classes))
    data = _test_inputs(num_classes)
    # The extra stream intentionally reuses the regular eval data.
    extra_stream = trainer_lib.NamedStream(
        stream=data.eval_stream(1), name='additional_eval_task')

    # Run training with evaluation after every step.
    out_dir = self.create_tempdir().full_path
    loop = trainer_lib.train(
        out_dir,
        model=make_model,
        inputs=data,
        steps=train_steps,
        eval_steps=n_eval_steps,
        eval_frequency=1,
        additional_eval_streams=[extra_stream],
    )

    self.assertLen(loop.eval_tasks, 2)
    first_task, second_task = loop.eval_tasks
    self.assertCountEqual(first_task.metrics, second_task.metrics)
    self.assertCountEqual(first_task.metric_names, second_task.metric_names)
def test_blocksparse_ff_predict_equals_eval(self):
  """BlockSparseFF gives matching outputs in 'eval' and 'predict' modes."""
  d_model = 1024
  shared_kwargs = {
      'd_ff': d_model * 8,
      'n_experts': 64,
      'temperature': 0.7,
  }
  with fastmath.use_backend(fastmath.Backend.JAX):
    x = np.ones((1, 1, d_model)).astype(np.float32)
    signature = shapes.signature(x)

    eval_layer = sparsity.BlockSparseFF(mode='eval', **shared_kwargs)
    weights, state = eval_layer.init(signature)
    eval_out, _ = eval_layer.pure_fn(
        x, weights, state, rng=jax.random.PRNGKey(0))

    # The predict-mode layer is initialized, but it is then run with the
    # weights and state obtained from the eval-mode layer.
    predict_layer = sparsity.BlockSparseFF(mode='predict', **shared_kwargs)
    _, _ = predict_layer.init(signature)
    pred_out, _ = predict_layer.pure_fn(
        x, weights, state, rng=jax.random.PRNGKey(0))

    self.assertEqual(eval_out.shape, x.shape)
    # The two modes are expected to agree on the output.
    np.testing.assert_array_almost_equal(eval_out[0, 0, :], pred_out[0, 0, :])
def test_autoregressive_sample_transformerlm_tfnp(self):
  """Checks autoregressive_sample output shapes under the TFNP backend."""
  with fastmath.use_backend(fastmath.Backend.TFNP):
    lm = models.TransformerLM(
        10, d_model=32, d_ff=64, n_layers=1, n_heads=2, mode='predict')

    # Case 1: batch of one, eos disabled via eos_id=-1.
    lm.init(shapes.ShapeDtype((1, 1), dtype=np.int32))
    single = decoding.autoregressive_sample(
        lm, batch_size=1, eos_id=-1, max_length=10)
    self.assertEqual(single.shape[0], 1)
    self.assertEqual(single.shape[1], 10)

    # Case 2: batch of two, re-initialized with the per-device batch size.
    per_device_batch = 2 // fastmath.device_count()
    lm.init(shapes.ShapeDtype((per_device_batch, 1), dtype=np.int32))
    pair = decoding.autoregressive_sample(lm, batch_size=2, max_length=10)
    self.assertEqual(pair.shape[0], 2)
    self.assertLess(pair.shape[1], 11)

    # Case 3: batch of one, sampling continues from a given prefix.
    lm.init(shapes.ShapeDtype((1, 1), dtype=np.int32))
    prefix = np.array([[1, 2, 3]])
    prefixed = decoding.autoregressive_sample(
        lm, prefix, eos_id=-1, max_length=10, batch_size=1)
    self.assertEqual(prefixed.shape[0], 1)
    self.assertEqual(prefixed.shape[1], 10)
def _test_train_eval_predict(self, backend):
  """Trains and evaluates a small model, then runs it with the final weights."""
  with fastmath.use_backend(backend):
    num_classes = 4
    train_steps = 2
    n_eval_steps = 2

    # Dropout and BatchNorm are included so that state handling is exercised.
    def model_fn(mode='train'):
      mlp = (tl.Dropout(mode=mode, rate=0.1),
             tl.BatchNorm(mode=mode),
             models.MLP(d_hidden=16, n_output_classes=num_classes, mode=mode))
      return tl.Serial(*mlp)

    data = _test_inputs(num_classes)

    # Train, evaluating after every step.
    out_dir = self.create_tempdir().full_path
    loop = trainer_lib.train(
        out_dir,
        model=model_fn,
        inputs=data,
        steps=train_steps,
        eval_steps=n_eval_steps,
        eval_frequency=1,
    )

    # The loop must have run for exactly the requested number of steps.
    self.assertEqual(train_steps, loop.step)

    # Run a freshly built model with the weights and state from the loop.
    stream = data.train_stream(1)
    fresh_model = model_fn()
    final_weights = loop.model.weights
    final_state = loop.model.state
    fresh_model(next(stream)[0], weights=final_weights, state=final_state)
def test_grulm_forward_shape(self, backend):
  """GRULM maps a (3, 28) int32 input to an output of shape (3, 28, 20)."""
  with fastmath.use_backend(backend):
    lm = rnn.GRULM(vocab_size=20, d_model=16)
    tokens = np.ones((3, 28)).astype(np.int32)
    # Unpacking checks that init returns a two-element result.
    _, _ = lm.init(shapes.signature(tokens))
    out = lm(tokens)
    self.assertEqual(out.shape, (3, 28, 20))
def test_train_eval_predict_sm3(self, backend):
  """Trains with the SM3 optimizer, then runs a model loaded from file."""
  with fastmath.use_backend(backend):
    num_classes = 4
    train_steps = 2
    n_eval_steps = 2
    make_model = functools.partial(
        models.MLP, d_hidden=16, n_output_classes=num_classes)
    data = _test_inputs(num_classes)

    # Train, evaluating after every step.
    out_dir = self.create_tempdir().full_path
    loop = trainer_lib.train(
        out_dir,
        model=make_model,
        inputs=data,
        steps=train_steps,
        eval_steps=n_eval_steps,
        eval_frequency=1,
        optimizer=trax_opt.SM3,
    )

    # The loop must have run for exactly the requested number of steps.
    self.assertEqual(train_steps, loop.step)

    # Build a fresh model and load its weights from the output directory.
    stream = data.train_stream(1)
    restored = make_model()
    restored.init_from_file(os.path.join(out_dir, 'model.pkl.gz'))
    restored(next(stream)[0])
def test_train_restart(self, backend):
  """Training twice into one output directory accumulates the step count."""
  with fastmath.use_backend(backend):
    num_classes = 4
    first_run_steps = 2
    n_eval_steps = 2
    make_model = functools.partial(
        models.MLP, d_hidden=16, n_output_classes=num_classes)
    data = _test_inputs(num_classes)
    out_dir = self.create_tempdir().full_path

    # Arguments common to the initial run and the restarted run.
    shared = dict(
        model=make_model,
        inputs=data,
        eval_steps=n_eval_steps,
        eval_frequency=1,
    )

    # Initial run.
    trainer_lib.train(out_dir, steps=first_run_steps, **shared)

    # Second run into the same directory, asking for twice as many steps.
    loop = trainer_lib.train(out_dir, steps=(2 * first_run_steps), **shared)

    # With loop we don't resume, but train for as many steps as given, so the
    # total is steps + 2*steps = 3*steps.
    self.assertEqual(loop.step, 3 * first_run_steps)
def test_sru(self, backend):
  """SRU produces an output with the same shape as its input."""
  with fastmath.use_backend(backend):
    sru = tl.SRU(7)
    inputs = np.ones((8, 9, 7), np.float32)
    # Unpacking checks that init returns a two-element result.
    _, _ = sru.init(shapes.signature(inputs))
    outputs = sru(inputs)
    self.assertEqual(outputs.shape, inputs.shape)
def test_lstm_cell(self, backend):
  """LSTMCell returns two outputs shaped (8, 9) and (8, 18)."""
  with fastmath.use_backend(backend):
    cell = tl.LSTMCell(9)
    inputs = [np.ones((8, 9)), np.ones((8, 18))]
    # Unpacking checks that init returns a two-element result.
    _, _ = cell.init(shapes.signature(inputs))
    outputs = cell(inputs)
    output_shapes = [out.shape for out in outputs]
    self.assertEqual(output_shapes, [(8, 9), (8, 18)])
def test_conv_gru_cell(self, backend):
  """ConvGRUCell produces an output with the same shape as its input."""
  with fastmath.use_backend(backend):
    cell = tl.ConvGRUCell(9, kernel_size=(3, 3))
    inputs = np.ones((8, 1, 7, 9))
    # Unpacking checks that init returns a two-element result.
    _, _ = cell.init(shapes.signature(inputs))
    outputs = cell(inputs)
    self.assertEqual(outputs.shape, inputs.shape)
def test_reformer2_predict_equals_eval(self):
  """Runs the shared eval-equals-predict check on a small Reformer2."""
  with fastmath.use_backend(fastmath.Backend.JAX):
    vocab_size = 16
    batch_size = 2
    length = 5
    model_kwargs = dict(
        d_model=8,
        d_ff=16,
        n_encoder_layers=1,
        n_decoder_layers=1,
        n_heads=2,
        dropout=0.0,
        max_len=length*2,
        pos_type=None,
        n_decoder_attention_layers=1,
        encoder_attention_type=tl.Attention,
        encoder_decoder_attention_type=tl.CausalAttention,
    )
    model_fn = functools.partial(
        reformer.Reformer2, vocab_size, **model_kwargs)

    source = np.random.randint(vocab_size, size=(batch_size, length))
    target = np.zeros((batch_size, length), dtype=np.int32)

    # TODO(jaszczur): check why init_tokens > 1 fails nondeterministically
    test_utils.test_eval_equals_predict(
        (source, target), model_fn, 1, -1, init_tokens=1)
def test_multi_input(self, backend):
  """Scan over axis 1 with a two-input, one-carry function layer."""

  def _add_and_count(a, b, carry):
    # Outputs a + b and b; the carry is incremented once per scan step.
    return a + b, b, carry + 1

  with fastmath.use_backend(backend):
    scanned_fn = tl.Fn('MultiInputFn', _add_and_count, n_out=2)
    scan_layer = tl.Scan(scanned_fn, axis=1)
    inputs = [
        np.array([[0, 1, 2],
                  [0, 10, 20]]),
        np.array([[4, 5, 6],
                  [40, 50, 60]]),
        np.array([9000, 8000]),
    ]
    expected = [
        [[4, 6, 8],
         [40, 60, 80]],
        [[4, 5, 6],
         [40, 50, 60]],
        [9003, 8003],
    ]
    outputs = scan_layer(inputs)
    self.assertEqual(as_list(outputs), expected)
def test_sparse_ff_predict_equals_eval(self):
  """Runs the shared eval-equals-predict check on SparseFF configurations."""
  with fastmath.use_backend(fastmath.Backend.JAX):
    d_model = 64
    seq_len = 6
    inputs = np.ones((1, seq_len, d_model)).astype(np.float32)
    make_layer = functools.partial(
        sparsity.SparseFF,
        d_ff=256,
        temperature=0.7,
        n_elements_in_block=8,
    )
    # Each dict is one extra keyword configuration to be checked.
    configs = [
        {'multiply_by_controller_output': True},
        {'multiply_by_controller_output': False},
        {'ff_chunk_size': 2},
    ]
    test_utils.test_eval_equals_predict_configs(inputs, make_layer, configs)
def _test_sparse_fast_inference(self, length):
  """Runs the eval-equals-predict check on a sparse ConfigurableTransformer.

  Args:
    length: Second dimension of the randomly generated input and of the
      all-zero output array.
  """
  with fastmath.use_backend(fastmath.Backend.JAX):
    vocab_size = 16
    batch_size = 2
    enc_dec_attention = functools.partial(
        tl.MultiplicativeConvCausalAttention,
        sparsity=2,
        length_kernel_size=1,
    )
    model_kwargs = dict(
        input_vocab_size=vocab_size,
        d_model=4,
        d_ff=8,
        n_encoder_layers=2,
        n_decoder_layers=2,
        n_heads=2,
        loss_sparsity=2,
        ff_sparsity=2,
        encoder_decoder_attention_type=enc_dec_attention,
        ff_use_sru=(1, 4),
    )
    model_fn = functools.partial(ct.ConfigurableTransformer, **model_kwargs)

    source = np.random.randint(vocab_size, size=(batch_size, length))
    target = np.zeros((batch_size, length), dtype=np.int32)

    test_utils.test_eval_equals_predict(
        (source, target), model_fn, seq_tensor=1)
def test_custom_id_grad(self, backend):
  """A layer whose custom backward returns its inputs as the gradient."""

  class IdWithIdGrad(tl.Layer):
    """Identity in the forward pass; backward hands back the inputs."""

    def forward(self, x):
      return x

    @property
    def has_backward(self):
      return True

    def backward(self, inputs, output, grad, weights, state, new_state, rng):
      return (inputs, ())

  with fastmath.use_backend(backend):
    id_layer = IdWithIdGrad()
    key = fastmath.random.get_prng(0)
    signature = shapes.ShapeDtype((9, 17))
    sample = fastmath.random.uniform(
        key, signature.shape, minval=-1.0, maxval=1.0)
    id_layer.init(signature)

    def mean_of_output(x):
      return jnp.mean(id_layer(x))

    gradient = fastmath.grad(mean_of_output)(sample)
    # One gradient entry per input entry.
    self.assertEqual(gradient.shape, (9, 17))
    # The custom backward makes the gradient sum equal to the input sum.
    self.assertEqual(sum(sum(gradient)), sum(sum(sample)))
def _test_fast_inference(self, length):
  """Compares 'eval' and 'predict' ConfigurableTransformerLM step by step.

  Args:
    length: Number of positions to decode and compare.
  """
  with fastmath.use_backend(fastmath.Backend.JAX):
    vocab_size = 16
    batch_size = 2
    make_model = functools.partial(
        ct.ConfigurableTransformerLM,
        vocab_size=vocab_size,
        d_model=4,
        d_ff=8,
        n_layers=2,
        n_heads=2,
    )
    slow_model = make_model(mode='eval')
    fast_model = make_model(mode='predict')
    key = fastmath.random.get_prng(0)
    signature = shapes.ShapeDtype((batch_size, 1), np.int32)
    # Given the same rng, both models initialize with the same parameters.
    slow_model.init(signature, key)
    fast_model.init(signature, key)

    # The eval model is fed the whole buffer; the predict model is fed only
    # the latest symbol.
    history = np.zeros((batch_size, length), dtype=np.int32)
    latest = np.zeros((batch_size, 1), dtype=np.int32)
    for position in range(length):
      slow_logits = slow_model(history, rng=key)
      fast_logits = fast_model(latest, rng=key)
      np.testing.assert_array_almost_equal(
          slow_logits[:, position, :],
          fast_logits[:, 0, :],
          decimal=5,
      )
      latest = np.random.randint(vocab_size, size=(batch_size, 1))
      history[:, position] = latest[:, 0]
def test_forward_dtype(self, backend, dtype):
  """BatchNorm output has the same dtype as the input it was given."""
  with fastmath.use_backend(backend):
    norm = tl.BatchNorm()
    inputs = np.ones((3, 2, 7)).astype(dtype)
    # Unpacking checks that init returns a two-element result.
    _, _ = norm.init(shapes.signature(inputs))
    outputs = norm(inputs)
    self.assertEqual(outputs.dtype, dtype)
def test_batching_lsh_self_attention(self):
  """Checks LSHSelfAttention batching variants against the reference code."""
  with fastmath.use_backend(fastmath.Backend.JAX):
    shared_kwargs = dict(
        n_heads=6,
        d_qk=7,
        d_v=17,
        causal=True,
        chunk_len=5,
        n_chunks_before=1,
        n_chunks_after=0,
        n_hashes=2,
        n_buckets=4,
        attention_dropout=0.2,
        output_dropout=0.1,
        mode='train',
    )
    # One variant per (n_parallel_heads, use_python_loop) combination.
    variants = [
        dict(n_parallel_heads=heads, use_python_loop=python_loop)
        for heads in [1, 3, 6, 12]
        for python_loop in [True, False]
    ]
    inputs = jax.random.uniform(
        jax.random.PRNGKey(0), (2, 10, 13), dtype=jnp.float32)
    signature = shapes.signature(inputs)
    self._test_equivalence_to_reference_code(
        efficient_attention.LSHSelfAttention,
        inputs, signature,
        shared_kwargs, *variants)
def test_train_restart(self, backend):
  """Restarting training in one directory ends at the larger step target."""
  if xla_bridge.device_count() > 1 and backend == fastmath.Backend.TFNP:
    self.skipTest(
        "tf-numpy backend doesn't support multi-devices yet.")
  with fastmath.use_backend(backend):
    with self.tmp_dir() as output_dir:
      num_classes = 4
      first_run_steps = 2
      n_eval_steps = 2
      make_model = functools.partial(
          models.MLP, d_hidden=16, n_output_classes=num_classes)
      data = _test_inputs(num_classes)

      # Arguments common to the initial run and the restarted run.
      shared = dict(model=make_model, inputs=data, eval_steps=n_eval_steps)

      # Initial run.
      trainer_lib.train(output_dir, steps=first_run_steps, **shared)

      # Restarted run into the same directory with twice the step target.
      state = trainer_lib.train(
          output_dir, steps=(2 * first_run_steps), **shared)

      # The final step count equals the second run's target.
      self.assertEqual(state.step, 2 * first_run_steps)
def test_fast_inference_self_attention(self):
  """Runs the fast-inference check on SelfAttention batching variants."""
  with fastmath.use_backend(fastmath.Backend.JAX):
    shared_kwargs = dict(
        n_heads=6,
        d_qk=7,
        d_v=17,
        share_qk=False,
        causal=True,
        chunk_len=5,
        n_chunks_before=1,
        n_chunks_after=0,
        attention_dropout=0.0,
        output_dropout=0.0,
    )
    # One variant per (n_parallel_heads, use_python_loop) combination.
    variants = [
        dict(n_parallel_heads=heads, use_python_loop=python_loop)
        for heads in [1, 3, 6, 12]
        for python_loop in [True, False]
    ]
    inputs = jax.random.uniform(
        jax.random.PRNGKey(0), (2, 10, 13), dtype=jnp.float32)
    signature = shapes.signature(inputs)
    self._test_fast_inference(
        efficient_attention.SelfAttention,
        inputs, signature,
        shared_kwargs, *variants)
def test_train_restart_with_same_steps(self, backend):
  """Repeats an identical train call and checks the reported step count."""
  with fastmath.use_backend(backend):
    num_classes = 4
    train_steps = 2
    n_eval_steps = 2
    make_model = functools.partial(
        models.MLP, layer_widths=(16, 16, num_classes))
    data = _test_inputs(num_classes)
    out_dir = self.create_tempdir().full_path

    # Both runs use exactly the same arguments.
    train_kwargs = dict(
        model=make_model,
        inputs=data,
        steps=train_steps,
        eval_steps=n_eval_steps,
        eval_frequency=1,
    )

    # Initial run.
    trainer_lib.train(out_dir, **train_kwargs)

    # Second run into the same directory.
    loop = trainer_lib.train(out_dir, **train_kwargs)

    # The step count after the second run equals the requested steps.
    self.assertEqual(loop.step, train_steps)
def _test_lsh_self_attention_deterministic_given_seed(self, causal=False):
  """With init and call seeds fixed, LSHSelfAttention output is repeatable.

  Args:
    causal: Forwarded to the `causal` argument of LSHSelfAttention.
  """
  with fastmath.use_backend(fastmath.Backend.JAX):
    attention = efficient_attention.LSHSelfAttention(
        n_heads=5,
        d_qk=7,
        d_v=17,
        causal=causal,
        chunk_len=8,
        n_chunks_before=1,
        n_chunks_after=0,
        n_hashes=2,
        n_buckets=4,
        use_reference_code=True,
        attention_dropout=0.0,
        mode='train',
    )
    inputs = np.ones((3, 32, 8)).astype(np.float32)

    def run_once():
      # Re-initialize and call with the same two fixed seeds on every run.
      _, _ = attention.init(shapes.signature(inputs), jax.random.PRNGKey(0))
      return attention(inputs, rng=jax.random.PRNGKey(1))

    outputs = []
    for _ in range(10):
      outputs.append(run_once())

    reference = outputs[0]
    self.assertEqual(reference.shape, inputs.shape)
    for other in outputs[1:]:
      np.testing.assert_array_almost_equal(reference, other, decimal=6)
def test_no_int32_or_uint32_returned(self):
  """Checks that Trainer._jit_update_fn returns no int32 or uint32 arrays.

  TF pins int32/uint32 tensors to CPU, so an XLA-forced-compiled computation
  would have to copy such outputs to CPU. This test guards against that.
  """
  with fastmath.use_backend(fastmath.Backend.TFNP):
    num_classes = 1001
    make_model = functools.partial(
        models.Resnet50, n_output_classes=num_classes)
    data = _test_inputs(num_classes, input_shape=(224, 224, 3))
    trainer = trainer_lib.Trainer(
        model=make_model,
        loss_fn=tl.WeightedCategoryCrossEntropy(),
        optimizer=trax_opt.SM3,
        lr_schedule=lr.multifactor(),
        inputs=data,
    )
    out_dir = self.create_tempdir().full_path
    trainer.reset(out_dir)
    trainer.train_epoch(1, 0)

    # Everything below is what Trainer._jit_update_fn hands back.
    # pylint: disable=protected-access
    returned = (
        trainer._opt_state.weights,
        trainer._opt_state.slots,
        trainer._model_state,
        trainer._rngs,
    )
    forbidden_dtypes = (jnp.int32, jnp.uint32)
    for leaf in tf.nest.flatten(returned):
      if not isinstance(leaf, jnp.ndarray):
        continue
      if leaf.dtype in forbidden_dtypes:
        raise ValueError('Found an array of int32 or uint32: %s' % leaf)
def test_predict_equals_eval(self):
  """Runs eval-equals-predict on MultiplicativeConvCausalAttention configs."""
  with fastmath.use_backend(fastmath.Backend.JAX):
    d_model = 32
    seq_len = 5
    inputs = np.ones((1, seq_len, d_model)).astype(np.float32)
    make_layer = functools.partial(
        sparsity.MultiplicativeConvCausalAttention,
        d_feature=d_model,
        n_heads=4,
        sparsity=4,
    )
    # Every combination of share_qk, output layer type and v-concat type.
    configs = [
        {
            'share_qk': share_qk,
            'output_layer_type': output_type,
            'v_concat_type': concat_type,
        }
        for share_qk in [True, False]
        for output_type in ['none', 'mult', 'conv', 'multconv']
        for concat_type in ['original', 'fixed', 'none']
    ]
    test_utils.test_eval_equals_predict_configs(inputs, make_layer, configs)
def test_reformer2_deterministic_eval(self):
  """Runs the shared eval-is-deterministic check on a small Reformer2."""
  with fastmath.use_backend(fastmath.Backend.JAX):
    vocab_size = 16
    batch_size = 2
    length = 5
    model_kwargs = dict(
        d_model=4,
        d_ff=16,
        n_encoder_layers=0,
        n_decoder_layers=1,
        n_heads=2,
        dropout=0.0,
        max_len=length*2,
        pos_type=None,
        encoder_attention_type=tl.Attention,
        encoder_decoder_attention_type=tl.CausalAttention,
    )
    model_fn = functools.partial(
        reformer.Reformer2, vocab_size, **model_kwargs)

    source = np.random.randint(vocab_size, size=(batch_size, length))
    target = np.zeros((batch_size, length), dtype=np.int32)

    test_utils.test_eval_is_deterministic((source, target), model_fn)
def test_lsh_self_attention_masked_non_causal(self):
  """Masked inputs must not influence un-masked LSH attention outputs.

  The input is changed only inside the masked region. The outputs for the
  un-masked positions must stay the same, while the outputs in the masked
  region are expected to change.
  """
  with fastmath.use_backend(fastmath.Backend.JAX):
    layer = efficient_attention.LSHSelfAttention(
        n_heads=5, d_qk=7, d_v=17, causal=False, masked=True,
        chunk_len=8, n_chunks_before=1, n_chunks_after=0,
        n_hashes=2, n_buckets=4,
        use_reference_code=True,
        attention_dropout=0.0,
        mode='train')

    batch = 5
    max_len = 32
    hidden = 8

    x = np.random.uniform(size=(batch, max_len, hidden))
    # The `np.bool` alias was deprecated in NumPy 1.20 and removed in 1.24;
    # the builtin `bool` selects the same boolean dtype on every version.
    mask = np.ones((batch, max_len)).astype(bool)

    # One cut position per batch entry, drawn from [1, max_len - 1).
    rngs = jax.random.randint(
        jax.random.PRNGKey(0), (batch,), minval=1, maxval=max_len - 1)

    # Set some suffix of each mask[b] to 0.
    for i in range(batch):
      mask[i, rngs[i]:] = 0

    # Fix rngs and get the output for the LSH layer.
    def get_output(x, mask):
      xs = [x, mask]
      _, _ = layer.init(shapes.signature(xs), jax.random.PRNGKey(0))
      return layer(xs, rng=jax.random.PRNGKey(1))

    # Get the attention output for masked x.
    y = get_output(x, mask)

    # Change x, but only in the masked regions.
    for i in range(batch):
      x[i, rngs[i]:] = np.random.uniform(size=(max_len - rngs[i], hidden))

    y2 = get_output(x, mask)

    for i in range(batch):
      # y and y2 should be identical in the non-masked part.
      np.testing.assert_array_almost_equal(y[i, :rngs[i]], y2[i, :rngs[i]],
                                           decimal=6)

      # In the masked out part, they should be different.
      self.assertGreater(
          np.mean(np.abs(y[i, rngs[i]:] - y2[i, rngs[i]:])), 1e-5)
def test_names(self, backend):
  """str() of LSTM, GRU and SRU layers is the layer name plus its size."""
  with fastmath.use_backend(backend):
    cases = (
        (tl.LSTM, 3, 'LSTM_3'),
        (tl.GRU, 5, 'GRU_5'),
        (tl.SRU, 7, 'SRU_7'),
    )
    for make_layer, size, expected_name in cases:
      layer = make_layer(size)
      self.assertEqual(expected_name, str(layer))