def test_autoregressive_sample_reformer2_lsh_attn_quality(self):
  gin.add_config_file_search_path(_CONFIG_DIR)
  max_len = 32  # 32 is the max length we trained the checkpoint for.
  test_lengths = [8, 16, 32]
  vocab_size = 13
  # The checkpoint is correct on ~90% sequences, set random seed to deflake.
  np.random.seed(0)
  for test_len in test_lengths:
    gin.clear_config()
    gin.parse_config_file('reformer2_copy.gin')
    gin.bind_parameter('LSHSelfAttention.predict_mem_len', 2 * max_len)
    gin.bind_parameter('LSHSelfAttention.predict_drop_len', 2 * max_len)

    pred_model = models.Reformer2(mode='predict')

    shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
    shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32)

    model_path = os.path.join(_TESTDATA, 'reformer2_copy_lsh_attn.pkl.gz')
    pred_model.init_from_file(model_path, weights_only=True,
                              input_signature=(shape1l, shape11))
    initial_state = pred_model.state

    for _ in range(2):  # Set low to make the test run reasonably fast.
      # Pick a length in [1, test_len] at random.
      inp_len = np.random.randint(low=1, high=test_len + 1)
      inputs = np.random.randint(low=1, high=vocab_size - 1, size=(1, inp_len))
      inputs = np.pad(inputs, [(0, 0), (0, max_len - inp_len)],
                      mode='constant', constant_values=0)
      s = decoding.autoregressive_sample(
          pred_model, inputs=inputs, eos_id=-1, max_length=inp_len,
          temperature=0.0)
      np.testing.assert_equal(s[0], inputs[0, :inp_len])
      pred_model.state = initial_state
  gin.clear_config()  # Make sure to not affect other tests.
def test_autoregressive_sample_transformerlm(self):
  model = models.TransformerLM(10, d_model=32, d_ff=64, n_layers=1,
                               n_heads=2, mode='predict')
  model.init(shapes.ShapeDtype((1, 1), dtype=jnp.int32))
  s1 = trainer_lib.autoregressive_sample(model, batch_size=1, eos_id=-1,
                                         max_length=10)
  self.assertEqual(s1.shape[0], 1)
  self.assertEqual(s1.shape[1], 10)

  batch_per_device = 2 // fastmath.device_count()
  model.init(shapes.ShapeDtype((batch_per_device, 1), dtype=jnp.int32))
  s2 = trainer_lib.autoregressive_sample(model, batch_size=2, max_length=10)
  self.assertEqual(s2.shape[0], 2)
  self.assertLess(s2.shape[1], 11)

  model.init(shapes.ShapeDtype((1, 1), dtype=jnp.int32))
  prefix = jnp.array([[1, 2, 3]])
  s3 = trainer_lib.autoregressive_sample(model, eos_id=-1, max_length=10,
                                         batch_size=1, prefix=prefix)
  self.assertEqual(s3.shape[0], 1)
  self.assertEqual(int(s3[0][0]), 1)
  self.assertEqual(int(s3[0][1]), 2)
  self.assertEqual(int(s3[0][2]), 3)
def test_can_predict_with_trained_model(self):
  model = tl.Serial(tl.Dense(3), tl.Branch(tl.Dense(1), tl.Dense(2)))
  train_tasks, eval_tasks = [], []
  for output_dim in [1, 2]:
    # The head we select from the model: 0 for output_dim 1 and 1 for 2.
    head_index = output_dim - 1
    train_tasks.append(
        training.TrainTask(
            _very_simple_data(output_dim),
            tl.Serial(tl.Select([head_index], n_in=2), tl.L2Loss()),
            optimizers.SGD(.01)))
    eval_tasks.append(
        training.EvalTask(
            _very_simple_data(output_dim),  # deliberately re-use training data
            [tl.Serial(tl.Select([head_index], n_in=2), tl.L2Loss())]))
  tmp_dir = self.create_tempdir().full_path
  training_session = training.Loop(
      model,
      tasks=train_tasks,
      eval_tasks=eval_tasks,
      checkpoint_at=lambda step_n: step_n == 1,
      output_dir=tmp_dir,
      which_task=lambda step_n: step_n % 2,
  )
  training_session.run(n_steps=2)
  trained_model = training_session.eval_model
  inp = next(_very_simple_data())[0]
  out = trained_model(inp)
  self.assertEqual(
      shapes.signature(out),
      (shapes.ShapeDtype((8, 1)), shapes.ShapeDtype((8, 2))),
  )
def test_autoregressive_sample_transformerlm_tfnp(self):
  with fastmath.use_backend(fastmath.Backend.TFNP):
    model = models.TransformerLM(10, d_model=32, d_ff=64, n_layers=1,
                                 n_heads=2, mode='predict')
    model.init(shapes.ShapeDtype((1, 1), dtype=np.int32))
    s1 = decoding.autoregressive_sample(model, batch_size=1, eos_id=-1,
                                        max_length=10)
    self.assertEqual(s1.shape[0], 1)
    self.assertEqual(s1.shape[1], 10)

    batch_per_device = 2 // fastmath.device_count()
    model.init(shapes.ShapeDtype((batch_per_device, 1), dtype=np.int32))
    s2 = decoding.autoregressive_sample(model, batch_size=2, max_length=10)
    self.assertEqual(s2.shape[0], 2)
    self.assertLess(s2.shape[1], 11)

    model.init(shapes.ShapeDtype((1, 1), dtype=np.int32))
    prefix = np.array([[1, 2, 3]])
    s3 = decoding.autoregressive_sample(model, prefix, eos_id=-1,
                                        max_length=10, batch_size=1)
    self.assertEqual(s3.shape[0], 1)
    self.assertEqual(s3.shape[1], 10)
def test_can_predict_with_trained_model(self):
  model = tl.Serial(tl.Dense(3), tl.Branch(tl.Dense(1), tl.Dense(2)))
  tasks = tuple(
      training.TrainTask(  # pylint: disable=g-complex-comprehension
          _very_simple_data(output_dim),
          tl.L2Loss(),
          optimizers.SGD(.01),
      ) for output_dim in (1, 2))
  eval_tasks = tuple([
      training.EvalTask(  # pylint: disable=g-complex-comprehension
          # deliberately re-using training data
          _very_simple_data(output_dim),
          [tl.L2Loss()],
      )
  ] for output_dim in (1, 2))
  tmp_dir = self.create_tempdir().full_path
  training_session = training.Loop(
      model,
      tasks=tasks,
      eval_tasks=eval_tasks,
      checkpoint_at=lambda step_n: step_n == 1,
      output_dir=tmp_dir,
      which_task=lambda step_n: step_n % 2,
  )
  training_session.run(n_steps=2)
  trained_model = training_session.eval_model
  inp = next(_very_simple_data())[0]
  out = trained_model(inp)
  self.assertEqual(
      shapes.signature(out),
      (shapes.ShapeDtype((8, 1)), shapes.ShapeDtype((8, 2))),
  )
def test_extract_inner_model(self):
  vocab_size = 3
  inner_model = transformer.TransformerLM(
      vocab_size=vocab_size, d_model=2, d_ff=2, n_layers=0)
  obs_serializer = space_serializer.create(
      gym.spaces.Discrete(2), vocab_size=vocab_size)
  act_serializer = space_serializer.create(
      gym.spaces.Discrete(2), vocab_size=vocab_size)
  serialized_model = serialization_utils.SerializedModel(
      inner_model,
      observation_serializer=obs_serializer,
      action_serializer=act_serializer,
      significance_decay=0.9,
  )

  obs_sig = shapes.ShapeDtype((1, 2))
  act_sig = shapes.ShapeDtype((1, 1))
  (weights, state) = serialized_model.init(
      input_signature=(obs_sig, act_sig, obs_sig, obs_sig),
  )
  (inner_weights, inner_state) = map(
      serialization_utils.extract_inner_model, (weights, state))
  inner_model(np.array([[0]]), weights=inner_weights, state=inner_state)
def test_serialized_model_extracts_seq_model_weights_and_state(self):
  vocab_size = 3
  seq_model_fn = functools.partial(
      transformer.TransformerLM,
      vocab_size=vocab_size,
      d_model=2,
      d_ff=2,
      n_layers=0,
  )
  seq_model = seq_model_fn(mode='eval')
  obs_serializer = space_serializer.create(
      gym.spaces.Discrete(2), vocab_size=vocab_size)
  act_serializer = space_serializer.create(
      gym.spaces.Discrete(2), vocab_size=vocab_size)
  serialized_model = serialization_utils.SerializedModel(
      seq_model_fn,
      observation_serializer=obs_serializer,
      action_serializer=act_serializer,
      significance_decay=0.9,
  )

  obs_sig = shapes.ShapeDtype((1, 2))
  act_sig = shapes.ShapeDtype((1, 1))
  serialized_model.init(input_signature=(obs_sig, act_sig, obs_sig, obs_sig))

  seq_model.weights = serialized_model.seq_model_weights
  seq_model.state = serialized_model.seq_model_state
  # Run the model to check if the weights and state have correct structure.
  seq_model(jnp.array([[0]]))
def _make_schedule(
    self,
    history,
    control_configs,
    observation_metrics=(('eval', 'metrics/accuracy'),),
    action_multipliers=(1.0,),
    vocab_size=None,
):
  policy_and_value_model = functools.partial(
      transformer.TransformerDecoder,
      d_model=2,
      d_ff=2,
      n_layers=0,
      vocab_size=vocab_size,
  )
  observation_space = gym.spaces.Box(
      shape=(len(observation_metrics),), low=0.0, high=1.0)
  action_space = gym.spaces.MultiDiscrete(
      nvec=(len(action_multipliers),) * len(control_configs))
  (net, _) = policy_based_utils.policy_and_value_net(
      bottom_layers_fn=policy_and_value_model,
      observation_space=observation_space,
      action_space=action_space,
      vocab_size=vocab_size,
      two_towers=False,
  )
  input_signature = (
      shapes.ShapeDtype(
          (1, 2) + observation_space.shape, observation_space.dtype),
      shapes.ShapeDtype((1, 1) + action_space.shape, action_space.dtype),
  )
  (params, state) = net.init(input_signature)
  policy_dir = self.get_temp_dir()
  # Optimizer slots and parameters should not be used for anything.
  slots = None
  opt_params = None
  opt_state = (params, slots, opt_params)
  policy_based_utils.save_opt_state(
      policy_dir, opt_state, state, epoch=0, total_opt_step=0, history=history)
  return lr_schedules.PolicySchedule(
      history,
      observation_metrics=observation_metrics,
      include_controls_in_observation=False,
      action_multipliers=action_multipliers,
      control_configs=control_configs,
      policy_and_value_model=policy_and_value_model,
      policy_and_value_two_towers=False,
      policy_and_value_vocab_size=vocab_size,
      policy_dir=policy_dir,
  )
def _test_sparse_fast_inference(self, length):
  with fastmath.use_backend(fastmath.Backend.JAX):
    vocab_size = 16
    d_model = 4
    encoder_decoder_attention_type = functools.partial(
        tl.MultiplicativeConvCausalAttention,
        sparsity=2,
        length_kernel_size=1,
    )

    model_fn = functools.partial(
        ct.ConfigurableTransformer,
        input_vocab_size=vocab_size,
        d_model=d_model,
        d_ff=8,
        n_encoder_layers=2,
        n_decoder_layers=2,
        n_heads=2,
        loss_sparsity=2,
        ff_sparsity=2,
        encoder_decoder_attention_type=encoder_decoder_attention_type,
        # SRU currently doesn't work for second token and further.
        # ff_use_sru=(1, 4),
    )

    model_slow = model_fn(mode='eval')
    model_fast = model_fn(mode='predict')
    rng = fastmath.random.get_prng(0)
    batch_size = 2
    input_signature = (
        shapes.ShapeDtype((batch_size, length), np.int32),
        shapes.ShapeDtype((batch_size, 1), np.int32))
    model_slow.init(input_signature)
    model_fast.init(input_signature)
    model_slow.save_to_file('/tmp/unique_weights')
    model_fast.init_from_file('/tmp/unique_weights', weights_only=True,
                              input_signature=input_signature)

    inp = np.random.randint(vocab_size, size=(batch_size, length))
    buf = np.zeros((batch_size, length), dtype=np.int32)
    next_sym = np.zeros((batch_size, 1), dtype=np.int32)

    for index in range(length):
      logits_slow = model_slow((inp, buf), rng=rng)[0]
      logits_fast = model_fast((inp, next_sym), rng=rng)[0]
      np.testing.assert_array_almost_equal(
          logits_slow[:, index, :], logits_fast[:, 0, :], decimal=5,
          err_msg='Error on token {} out of {}.'.format(index + 1, length))
      next_sym = np.random.randint(vocab_size, size=(batch_size, 1))
      buf[:, index] = next_sym[:, 0]
def test_input_signatures(self):
  layer = tl.Serial(DivideBy(2.0), DivideBy(5.0))
  self.assertIsNone(layer.input_signature)

  layer._set_input_signature_recursive(shapes.ShapeDtype((3, 2)))
  self.assertEqual(layer.input_signature, shapes.ShapeDtype((3, 2)))
  self.assertLen(layer.sublayers, 2)
  for sublayer in layer.sublayers:
    self.assertEqual(sublayer.input_signature, shapes.ShapeDtype((3, 2)))
def test_autoregressive_sample_reformer2_timing(self):
  max_len = 16

  def _self_attention_fn():
    return functools.partial(
        layers.SelfAttention,
        predict_drop_len=2 * max_len,
        predict_mem_len=2 * max_len)

  def _causal_attention_fn():
    return functools.partial(
        layers.ModularCausalAttention,
        n_modules=64,
        max_inference_length=2 * max_len)

  pred_model = models.Reformer2(
      mode='predict',
      d_model=8 * 1024,
      d_ff=64 * 1024,
      dropout=0.05,
      max_len=max_len,
      n_heads=64,
      n_encoder_layers=2,
      n_decoder_layers=2,
      encoder_attention_type=_self_attention_fn(),
      encoder_decoder_attention_type=_causal_attention_fn(),
      input_vocab_size=4,
      ff_sparsity=256,
      axial_pos_shape=None,
  )

  shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
  shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32)
  pred_model.init(input_signature=(shape1l, shape11))
  inputs = np.arange(16, dtype=np.int32).reshape(1, 16)

  # This is decoding.autoregressive_sample but simplified and with timing.
  result, counter, start_time, total_time = [], 0, time.time(), 0.0
  for sample in decoding.autoregressive_sample_stream(
      pred_model, inputs, temperature=0.0):  # accelerate=False):
    elapsed_time = time.time() - start_time
    start_time = time.time()
    if counter > 3:
      total_time += elapsed_time
    result.append(sample[:, None])
    counter += 1
    if counter >= 14:
      break

  # We print 5 * time for 10 tokens; @2 layers this is ~1 token @ 100 layers.
  print('\n\nTime for 5x10 tokens (~1tok @100): %.4fs\n\n\n' % (5 * total_time))
  self.assertLess(total_time, 20.0)  # If it's > 20s, it's some bug.
  # Check resulting shapes.
  s = np.concatenate(result, axis=1)
  self.assertEqual(s.shape[0], 1)
  self.assertEqual(s.shape[1], 14)
def test_new_rngs_deterministic(self):
  inputs1 = shapes.ShapeDtype((2, 3, 5))
  inputs2 = (shapes.ShapeDtype((2, 3, 5)), shapes.ShapeDtype((2, 3, 5)))
  layer1 = base.Layer()
  layer2 = base.Layer(n_in=2, n_out=2)
  _, _ = layer1.init(inputs1)
  _, _ = layer2.init(inputs2)
  rng1, rng2 = layer1.new_rngs(2)
  rng3, rng4 = layer2.new_rngs(2)
  self.assertEqual(rng1.tolist(), rng3.tolist())
  self.assertEqual(rng2.tolist(), rng4.tolist())
def test_output_signature_no_weights(self):
  shape_2_3_5 = shapes.ShapeDtype((2, 3, 5))
  input_signature = (shape_2_3_5, shape_2_3_5)
  layer = tl.Fn('2in1out', lambda x, y: x + y)
  output_signature = layer.output_signature(input_signature)
  self.assertEqual(output_signature, shape_2_3_5)

  shape_5_7 = shapes.ShapeDtype((5, 7))
  input_signature = shape_5_7
  layer = tl.Fn('1in3out', lambda x: (x, 2 * x, 3 * x), n_out=3)
  output_signature = layer.output_signature(input_signature)
  self.assertEqual(output_signature, (shape_5_7, shape_5_7, shape_5_7))
def test_run_reversible_large_weights(self):
  """Runs the reversible trainer with a lot of weights to test memory use."""
  # This test requires > 18GB RAM, only run on TPUs. It does pass on GPU
  # and CPU when you run it locally, but it's too big for unit-testing.
  ram_limited = True  # Set to False to run this test locally.
  if fastmath.global_device_count() == 1 and ram_limited:
    return

  # Create inputs and rngs.
  inputs_batch = np.arange(8).reshape((2, 4))
  targets_batch = inputs_batch
  labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch))
  first_layer = tl.Serial(tl.Embedding(9, 16 * 1024), tl.Dup())
  rng_init = fastmath.random.get_prng(12)
  rng_step = fastmath.random.get_prng(13)

  # Initialize layers.
  first_layer.init(labeled_batch, rng=rng_init)
  n_layers = 18  # 18 layers each 16K x 16K = 256M weights ~= 1GB, 18GB ram
  rev_layers = []
  int_shape = shapes.ShapeDtype((2, 4), dtype=np.int32)
  shape = shapes.ShapeDtype((2, 4, 16 * 1024))
  sig = (shape, shape)
  for _ in range(n_layers):
    layer = tl.ReversibleHalfResidual(tl.Dense(16 * 1024))
    layer.init(sig, rng=rng_init)
    layer.weights = tl.on_cpu(layer.weights)  # store weights in cpu memory
    rev_layers.append(layer)
    rev_layers.append(tl.ReversibleSwap())
  loss_layer = tl.Serial(tl.Concatenate(), tl.Dense(9), tl.LogSoftmax(),
                         tl.CrossEntropyLoss())
  loss_layer.init((shape, shape, int_shape, int_shape))
  optimizer_fn = optimizers.Adafactor

  # Make a step with reversible trainer.
  trainer = optimizers.ReversibleSerialTrainer(
      [(first_layer, rev_layers)], loss_layer, optimizer_fn)
  loss, _ = trainer.one_step(labeled_batch, rng_step)
  self.assertLess(float(loss.sum()), 10000.0)  # Just to get the loss.

  # Set to true to run again, e.g., for profiling.
  run_twice = False
  if run_twice:
    t = time.time()
    loss, _ = trainer.one_step(labeled_batch, rng_step)
    self.assertLess(float(loss.sum()), 10000.0)  # Just to get the loss.
    print('Took %.3f seconds to run, loss %s' % (time.time() - t, loss))
def test_autoregressive_sample_reformer2_copy_self_attn_quality(self):
  max_len = 32

  def _self_attention_fn():
    return functools.partial(
        tl.SelfAttention,
        predict_drop_len=2 * max_len,
        predict_mem_len=2 * max_len,
    )

  pred_model = models.Reformer2(
      mode='predict',
      d_model=256,
      d_ff=512,
      dropout=0.05,
      max_len=max_len,
      n_heads=4,
      n_encoder_layers=3,
      n_decoder_layers=3,
      ff_use_sru=1,
      d_attention_key=64,
      d_attention_value=64,
      encoder_attention_type=_self_attention_fn(),
      encoder_decoder_attention_type=_self_attention_fn(),
      n_decoder_attention_layers=1,
      input_vocab_size=13,
      axial_pos_shape=None,
  )

  shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
  shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32)

  model_path = os.path.join(_TESTDATA, 'reformer2_copy_self_attn.pkl.gz')
  pred_model.init_from_file(model_path, weights_only=True,
                            input_signature=(shape1l, shape11))

  inputs = np.array([[11, 5, 1, 2, 3, 4]], dtype=np.int32)
  inp_len = inputs.shape[1]
  inputs = np.pad(inputs, [(0, 0), (0, max_len - inp_len)],
                  mode='constant', constant_values=0)

  s = decoding.autoregressive_sample(
      pred_model, inputs=inputs, eos_id=-1, max_length=inp_len,
      temperature=0.0)
  np.testing.assert_equal(s[0], inputs[0, :inp_len])
def get_slice_for_val(x):
  # Note: `i` (the current decoding position) is captured from the
  # enclosing scope; this helper is meant to be defined inside that loop.
  if isinstance(x, shapes.ShapeDtype):
    return shapes.ShapeDtype(shape=x.shape[:1] + (1,) + x.shape[2:],
                             dtype=x.dtype)
  else:
    return x[:, i:i + 1]
def _test_fast_inference(self, length):
  with fastmath.use_backend(fastmath.Backend.JAX):
    vocab_size = 16
    model_fn = functools.partial(
        configurable_transformer.ConfigurableTransformerLM,
        vocab_size=vocab_size,
        d_model=4,
        d_ff=8,
        n_layers=2,
        n_heads=2,
    )
    model_slow = model_fn(mode='eval')
    model_fast = model_fn(mode='predict')
    rng = fastmath.random.get_prng(0)
    batch_size = 2
    input_signature = shapes.ShapeDtype((batch_size, 1), np.int32)
    # Given the same rng, both models initialize with the same parameters.
    model_slow.init(input_signature, rng)
    model_fast.init(input_signature, rng)

    buf = np.zeros((batch_size, length), dtype=np.int32)
    next_sym = np.zeros((batch_size, 1), dtype=np.int32)

    for index in range(length):
      logits_slow = model_slow(buf, rng=rng)
      logits_fast = model_fast(next_sym, rng=rng)
      np.testing.assert_array_almost_equal(
          logits_slow[:, index, :], logits_fast[:, 0, :], decimal=5)
      next_sym = np.random.randint(vocab_size, size=(batch_size, 1))
      buf[:, index] = next_sym[:, 0]
def test_custom_id_grad(self):

  class IdWithIdGrad(base.Layer):

    def forward(self, x, weights):
      return x

    @property
    def has_backward(self):
      return True

    def backward(self, inputs, output, grad, weights, state, new_state, rng):
      return (inputs, ())

  layer = IdWithIdGrad()
  rng = math.random.get_prng(0)
  input_signature = shapes.ShapeDtype((9, 17))
  random_input = math.random.uniform(rng, input_signature.shape,
                                     minval=-1.0, maxval=1.0)
  layer.init(input_signature)
  f = lambda x: jnp.mean(layer(x))
  grad = math.grad(f)(random_input)
  self.assertEqual(grad.shape, (9, 17))  # Gradient for each input.
  self.assertEqual(sum(sum(grad)), sum(sum(random_input)))  # Same as input.
def test_autoregressive_sample_reformerlm(self):
  lsh_self_attention = self._lsh_self_attention_fn()
  timebin_self_attention = self._timebin_self_attention_fn()

  model = models.ReformerLM(
      vocab_size=256,
      d_model=256,
      d_ff=512,
      d_attention_key=128,
      d_attention_value=128,
      n_layers=2,
      n_heads=2,
      dropout=0.05,
      max_len=65536,
      attention_type=[timebin_self_attention, lsh_self_attention],
      pos_axial_shape=(256, 256),
      pos_d_axial_embs=(128, 128),
      ff_activation=tl.Relu,
      ff_use_sru=0,
      mode='predict',
  )
  model.init(shapes.ShapeDtype((1, 1), dtype=np.int32))
  s1 = decoding.autoregressive_sample(
      model, batch_size=1, eos_id=-1, max_length=10)
  self.assertEqual(s1.shape[0], 1)
  self.assertEqual(s1.shape[1], 10)
def test_custom_zero_grad(self, backend):

  class IdWithZeroGrad(tl.Layer):

    def forward(self, x):
      return x

    @property
    def has_backward(self):
      return True

    def backward(self, inputs, output, grad, weights, state, new_state, rng):
      return (jnp.zeros_like(grad), ())

  with fastmath.use_backend(backend):
    layer = IdWithZeroGrad()
    rng = fastmath.random.get_prng(0)
    input_signature = shapes.ShapeDtype((9, 17))
    random_input = fastmath.random.uniform(rng, input_signature.shape,
                                           minval=-1.0, maxval=1.0)
    layer.init(input_signature)
    f = lambda x: jnp.mean(layer(x))
    grad = fastmath.grad(f)(random_input)
    self.assertEqual(grad.shape, (9, 17))  # Gradient for each input.
    self.assertEqual(sum(sum(grad * grad)), 0.0)  # Each one is 0.
def _value_model_signature(self):
  obs_sig = shapes.signature(self._task.observation_space)
  # One target/mask entry per action; assumes a discrete action space,
  # hence `.n` (a gym Space object itself is not a valid shape dimension).
  target_sig = mask_sig = shapes.ShapeDtype(
      shape=(1, 1, self._task.action_space.n),
  )
  inputs_sig = obs_sig.replace(shape=(1, 1) + obs_sig.shape)
  return (inputs_sig, target_sig, mask_sig)
def test_eval_equals_predict_discrete(
    model_fn, vocab_size=10, length=5, batch_size=3
):
  """Tests the equivalence of eval and predict modes for discrete models."""
  with fastmath.use_backend(fastmath.Backend.JAX):
    model_slow = model_fn(mode='eval', vocab_size=vocab_size)
    model_fast = model_fn(mode='predict', vocab_size=vocab_size)
    rng = fastmath.random.get_prng(0)
    input_signature = shapes.ShapeDtype((batch_size, 1), np.int32)
    # Given the same rng, both models initialize with the same parameters.
    model_slow.init(input_signature, rng)
    model_fast.init(input_signature, rng)

    buf = np.zeros((batch_size, length), dtype=np.int32)
    next_sym = np.zeros((batch_size, 1), dtype=np.int32)

    for index in range(length):
      logits_slow = model_slow(buf, rng=rng)
      logits_fast = model_fast(next_sym, rng=rng)
      np.testing.assert_array_almost_equal(
          logits_slow[:, index, :], logits_fast[:, 0, :], decimal=5)
      next_sym = np.random.randint(vocab_size, size=(batch_size, 1))
      buf[:, index] = next_sym[:, 0]
def test_autoregressive_sample_terraformer_pure_lsh_attn_quality(self):
  gin.add_config_file_search_path(_CONFIG_DIR)
  max_len = 32  # 32 is the max length we trained the checkpoint for.
  test_lengths = [8, 16, 32]
  vocab_size = 13
  # The checkpoint is correct on ~90% sequences, set random seed to deflake.
  np.random.seed(0)
  for test_len in test_lengths:
    gin.clear_config()
    gin.parse_config_file('terraformer_purelsh_copy.gin')
    gin.bind_parameter('PureLSHSelfAttention.predict_mem_len', 2 * max_len)
    gin.bind_parameter('PureLSHSelfAttention.predict_drop_len', 2 * max_len)
    gin.bind_parameter('PureLSHSelfAttentionWrapper.bias', False)
    gin.bind_parameter('PureLSHSelfAttentionWrapper.num_weights', 2)

    pred_model = models.ConfigurableTerraformer(mode='predict')

    shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
    shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32)

    model_path = os.path.join(_TESTDATA, 'terraformer_purelsh_copy.pkl.gz')
    pred_model.init_from_file(model_path, weights_only=True,
                              input_signature=(shape1l, shape11))
    initial_state = pred_model.state

    for _ in range(2):  # Set low to make the test run reasonably fast.
      # Pick a length in [1, test_len] at random.
      inp_len = np.random.randint(low=1, high=test_len + 1)
      inputs = np.random.randint(low=1, high=vocab_size - 1,
                                 size=(1, max_len))
      # TODO(jaszczur): properly fix padding in terraformer predict mode,
      # and add a test here.
      s = decoding.autoregressive_sample(
          pred_model, inputs=inputs, eos_id=-1, max_length=inp_len,
          temperature=0.0)
      np.testing.assert_equal(s[0], inputs[0, :inp_len])
      pred_model.state = initial_state
  gin.clear_config()  # Make sure to not affect other tests.
def test_train_with_mixed_lsh_attention(self, backend=fastmath.Backend.JAX):
  with fastmath.use_backend(backend):
    # Prepare model and inputs.
    def model(mode='train'):
      return models.ConfigurableTerraformer(
          mode=mode,
          d_model=16,
          d_ff=16,
          n_heads=2,
          dropout=0.05,
          n_decoder_layers=1,
          n_encoder_layers=1,
          input_vocab_size=256,
          encoder_attention_type=_mixed_lsh_self_attention_fn(),
          encoder_decoder_attention_type=_mixed_lsh_self_attention_fn(),
      )

    max_len = 128
    inputs = _test_inputs_lm(vocab_size=256, seq_len=max_len)

    steps = 1
    eval_steps = 1

    # Train and evaluate.
    output_dir = self.create_tempdir().full_path
    trainer_lib.train(output_dir,
                      model=model,
                      inputs=inputs,
                      steps=steps,
                      eval_steps=eval_steps,
                      eval_frequency=1)

    # Read checkpoint.
    model_file = os.path.join(output_dir, 'model.pkl.gz')

    shape11 = trax_shapes.ShapeDtype((1, 1), dtype=jnp.int32)
    shape1l = trax_shapes.ShapeDtype((1, max_len), dtype=jnp.int32)

    model_predict = model(mode='predict')
    model_predict.init_from_file(model_file, weights_only=True,
                                 input_signature=(shape1l, shape11))
def test_run_reversible_weights_transfer_xprof(self):
  """Runs the reversible trainer and profiles weight transfer stats."""
  run_this_test = False  # We only run this test manually.
  if not run_this_test or fastmath.global_device_count() == 1:  # TPU only
    return

  # Create inputs and rngs.
  inputs_batch = np.ones((1024, 128), dtype=np.int32)
  targets_batch = inputs_batch
  labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch))
  first_layer = tl.Serial(tl.Embedding(4, 1024), tl.Dup())
  rng_init = fastmath.random.get_prng(12)
  rng_step = fastmath.random.get_prng(13)

  # Initialize layers.
  first_layer.init(labeled_batch, rng=rng_init)
  n_layers = 6
  rev_layers = []
  int_shape = shapes.ShapeDtype((1024, 128), dtype=np.int32)
  shape = shapes.ShapeDtype((1024, 128, 1024))
  sig = (shape, shape)
  for _ in range(n_layers):
    layer = tl.ReversibleHalfResidual(tl.Dense(1024))
    layer.init(sig, rng=rng_init)
    layer.weights = tl.on_cpu(layer.weights)  # store weights in cpu memory
    rev_layers.append(layer)
    rev_layers.append(tl.ReversibleSwap())
  loss_layer = tl.Serial(tl.Concatenate(), tl.Dense(9), tl.LogSoftmax(),
                         tl.CrossEntropyLoss())
  loss_layer.init((shape, shape, int_shape, int_shape))
  optimizer_fn = optimizers.SGD

  # Make a step with reversible trainer.
  trainer = optimizers.ReversibleSerialTrainer(
      [(first_layer, rev_layers)], loss_layer, optimizer_fn)
  loss, _ = trainer.one_step(labeled_batch, rng_step)
  self.assertLess(float(loss.sum()), 10000.0)  # Just to get the loss.

  # We profile here.
  t = time.time()
  loss, _ = trainer.one_step(labeled_batch, rng_step)
  self.assertLess(float(loss.sum()), 10000.0)  # Just to get the loss.
  print('Took %.3f seconds to run, loss %s' % (time.time() - t, loss))
def test_new_rng_deterministic(self):
  input_signature = shapes.ShapeDtype((2, 3, 5))
  layer1 = base.Layer()
  layer2 = base.Layer(n_in=2, n_out=2)
  _, _ = layer1.init(input_signature)
  _, _ = layer2.init(input_signature)
  rng1 = layer1.new_rng()
  rng2 = layer2.new_rng()
  self.assertEqual(rng1.tolist(), rng2.tolist())
def test_loss_layer_timing(self):
  all_settings = [
      # The first run is sometimes slower, less reliable.
      {'output': 32000, 'input': 2048, 'prob': None,
       'type': None, 'sparsity': 0, 'lowrank': 0, 'use_bias': False},
      {'output': 32000, 'input': 2048, 'prob': None,
       'type': None, 'sparsity': 0, 'lowrank': 0, 'use_bias': False},
      {'output': 32000, 'input': 2048, 'prob': None,
       'type': 'einsum', 'sparsity': 0, 'lowrank': 0, 'use_bias': False},
      {'output': 32000, 'input': 2048, 'prob': None,
       'type': 'mult', 'sparsity': 2, 'lowrank': 0, 'use_bias': False},
      {'output': 32000, 'input': 2048, 'prob': None,
       'type': None, 'sparsity': 0, 'lowrank': 0, 'use_bias': True},
      {'output': 32000, 'input': 2048, 'prob': None,
       'type': 'einsum', 'sparsity': 0, 'lowrank': 0, 'use_bias': True},
      {'output': 32000, 'input': 2048, 'prob': None,
       'type': 'mult', 'sparsity': 2, 'lowrank': 0, 'use_bias': True},
  ]

  messages = []
  for settings in all_settings:
    pred_model = tl.SparseDenseWithOptions(
        n_units=settings['output'],
        d_input=settings['input'],
        sparsity_type=settings['type'],
        sparsity=settings['sparsity'],
        d_lowrank=settings['lowrank'],
        prob_sparse=settings['prob'],
        use_bias=settings['use_bias'],
        mode='predict',
    )
    pred_model = tl.Accelerate(pred_model)

    shape1l = shapes.ShapeDtype((1, settings['input']))
    pred_model.init(input_signature=shape1l)
    inputs = np.ones((1, settings['input']))

    total_time = 0.0
    for counter in range(-50, 100):
      start_time = time.time()
      y = pred_model(inputs)
      self.assertEqual(y.shape, (1, settings['output']))
      elapsed_time = time.time() - start_time
      if counter >= 0:
        total_time += elapsed_time

    message = (
        '\n\nParams: %d Settings: %s\nTime for 100 tokens: %.4f s\n\n\n'
        % (_size_of_model(pred_model), settings, total_time))
    messages.append(message)
    print(message)

  print('Final results (recap):')
  for message in messages:
    print(message)
def test_new_rng_new_value_each_call(self):
  input_signature = shapes.ShapeDtype((2, 3, 5))
  layer = base.Layer()
  _, _ = layer.init(input_signature)
  rng1 = layer.new_rng()
  rng2 = layer.new_rng()
  rng3 = layer.new_rng()
  self.assertNotEqual(rng1.tolist(), rng2.tolist())
  self.assertNotEqual(rng2.tolist(), rng3.tolist())
def test_autoregressive_sample_transformer(self):
  model = models.Transformer(10, d_model=32, d_ff=64, n_encoder_layers=1,
                             n_decoder_layers=1, n_heads=2, mode='predict')
  inputs = np.ones((1, 3), dtype=np.int32)
  model.init((shapes.signature(inputs),
              shapes.ShapeDtype((1, 1), dtype=np.int32)))
  s = decoding.autoregressive_sample(model, inputs=inputs, eos_id=-1,
                                     max_length=10)
  self.assertEqual(s.shape[0], 1)
  self.assertEqual(s.shape[1], 10)
def test_run_reversible_same_as_default_terraformer(self):
  """Runs the reversible trainer, checking results are the same as default."""
  inputs_batch = np.arange(8).reshape((2, 4)) + 1
  targets_batch = 2 * inputs_batch
  labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch))
  int_sig = shapes.ShapeDtype((2, 4), dtype=np.int32)
  input_sig = (int_sig, int_sig, int_sig)
  # We want to test rng propagation too, so adding some dropout layers.
  model = terraformer.ConfigurableTerraformer(
      20, d_model=8, d_ff=32, n_heads=1, dropout=0.0,
      n_encoder_layers=2, n_decoder_layers=2, ff_sparsity=(4, 8, 0.0, 1.0),
      pos_type=None, reversible_encoder=True)
  loss = tl.Serial(tl.LogSoftmax(), tl.CrossEntropyLoss())
  optimizer_fn = optimizers.Adafactor
  blocks, loss_layer = optimizers.trainer.extract_reversible_blocks(
      [model, loss], loss_chunk_size=4)
  blocks_serial = [(tl.Serial(std), rev) for (std, rev) in blocks]
  model_with_loss = tl.Serial(model, loss)
  rng_init = fastmath.random.get_prng(12)
  model_with_loss.init(input_sig, rng=rng_init)

  # Make 3 steps with the original trainer.
  optimizer = optimizer_fn()
  optimizer.tree_init(model_with_loss.weights)
  trainer = optimizers.Trainer(model_with_loss, optimizer)
  rng_step1 = fastmath.random.get_prng(7)
  rng_step2 = fastmath.random.get_prng(8)
  rng_step3 = fastmath.random.get_prng(9)
  trainer.one_step(labeled_batch, rng_step1)
  trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02)
  trainer.one_step(labeled_batch, rng_step3, learning_rate=0.03)
  first_weights = blocks_serial[0][0].weights
  first_rev_weights = blocks[0][1][0].weights
  loss_weights = loss_layer.weights

  # Now make 3 steps with reversible trainer.
  model_with_loss.init(input_sig, rng=rng_init)
  trainer = optimizers.ReversibleSerialTrainer(
      blocks, loss_layer, optimizer_fn)
  trainer.one_step(labeled_batch, rng_step1)
  trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02)
  trainer.one_step(labeled_batch, rng_step3, learning_rate=0.03)

  # Check that weights end up the same.
  self._assert_all_equal(loss_weights, loss_layer.weights)
  self._assert_all_equal(first_rev_weights, blocks[0][1][0].weights)
  self._assert_all_equal(first_weights, blocks_serial[0][0].weights)