def test_autoregressive_sample_reformer2_lsh_attn_quality(self):
  gin.add_config_file_search_path(_CONFIG_DIR)
  max_len = 32  # 32 is the max length we trained the checkpoint for.
  test_lengths = [8, 16, 32]
  vocab_size = 13
  # The checkpoint is correct on ~90% of sequences, so we set the random
  # seed to deflake the test.
  np.random.seed(0)
  for test_len in test_lengths:
    gin.clear_config()
    gin.parse_config_file('reformer2_copy.gin')
    gin.bind_parameter('LSHSelfAttention.predict_mem_len', 2 * max_len)
    gin.bind_parameter('LSHSelfAttention.predict_drop_len', 2 * max_len)

    pred_model = models.Reformer2(mode='predict')

    shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
    shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32)

    model_path = os.path.join(_TESTDATA, 'reformer2_copy_lsh_attn.pkl.gz')
    pred_model.init_from_file(model_path, weights_only=True,
                              input_signature=(shape1l, shape11))
    initial_state = pred_model.state

    for _ in range(2):  # Set low to make the test run reasonably fast.
      # Pick a length in [1, test_len] at random.
      inp_len = np.random.randint(low=1, high=test_len + 1)
      inputs = np.random.randint(low=1, high=vocab_size - 1,
                                 size=(1, inp_len))
      inputs = np.pad(inputs, [(0, 0), (0, max_len - inp_len)],
                      mode='constant', constant_values=0)
      s = decoding.autoregressive_sample(
          pred_model, inputs=inputs, eos_id=-1, max_length=inp_len,
          temperature=0.0)
      np.testing.assert_equal(s[0], inputs[0, :inp_len])
      pred_model.state = initial_state
  gin.clear_config()  # Make sure to not affect other tests.
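# --- Assumed module preamble (sketch) ---
# The tests and helpers in this section reference imports and constants that
# the excerpt does not include. The reconstruction below is inferred from
# usage; in particular, the _CONFIG_DIR and _TESTDATA paths are hypothetical
# placeholders, not the repository's actual values.

import copy
import functools
import gc
import os
import time

import gin
import numpy as np
import psutil

from trax import fastmath
from trax import layers  # One older timing test below uses this alias.
from trax import layers as tl
from trax import models
from trax import shapes
from trax.supervised import decoding

_CONFIG_DIR = os.path.join(os.path.dirname(__file__), 'configs/')  # Hypothetical.
_TESTDATA = os.path.join(os.path.dirname(__file__), 'testdata')  # Hypothetical.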
def test_autoregressive_sample_reformer2_timing(self):
  max_len = 16

  def _self_attention_fn():
    return functools.partial(
        layers.SelfAttention,
        predict_drop_len=2 * max_len,
        predict_mem_len=2 * max_len)

  def _causal_attention_fn():
    return functools.partial(
        layers.ModularCausalAttention,
        n_modules=64,
        max_inference_length=2 * max_len)

  pred_model = models.Reformer2(
      mode='predict',
      d_model=8 * 1024,
      d_ff=64 * 1024,
      dropout=0.05,
      max_len=max_len,
      n_heads=64,
      n_encoder_layers=2,
      n_decoder_layers=2,
      encoder_attention_type=_self_attention_fn(),
      encoder_decoder_attention_type=_causal_attention_fn(),
      input_vocab_size=4,
      ff_sparsity=256,
      axial_pos_shape=None,
  )

  shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
  shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32)
  pred_model.init(input_signature=(shape1l, shape11))
  inputs = np.arange(16, dtype=np.int32).reshape(1, 16)

  # This is decoding.autoregressive_sample but simplified and with timing.
  result, counter, start_time, total_time = [], 0, time.time(), 0.0
  for sample in decoding.autoregressive_sample_stream(
      pred_model, inputs, temperature=0.0):  # accelerate=False):
    elapsed_time = time.time() - start_time
    start_time = time.time()
    if counter > 3:  # Skip the first few steps (warm-up / compilation).
      total_time += elapsed_time
    result.append(sample[:, None])
    counter += 1
    if counter >= 14:
      break

  # We print 5x the time for 10 tokens; with 2 decoder layers this roughly
  # approximates the time for 1 token at 100 layers.
  print('\n\nTime for 5x10 tokens (~1tok @100): %.4fs\n\n\n'
        % (5 * total_time))
  self.assertLess(total_time, 20.0)  # If it takes > 20s, something is wrong.

  # Check resulting shapes.
  s = np.concatenate(result, axis=1)
  self.assertEqual(s.shape[0], 1)
  self.assertEqual(s.shape[1], 14)
def model(mode='train'):
  return models.Reformer2(
      mode=mode,
      d_model=16,
      d_ff=16,
      n_heads=2,
      dropout=0.05,
      n_decoder_layers=1,
      n_encoder_layers=1,
      input_vocab_size=256,
      encoder_attention_type=_mixed_lsh_self_attention_fn(),
      encoder_decoder_attention_type=_mixed_lsh_self_attention_fn(),
  )
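# The model() factory above depends on _mixed_lsh_self_attention_fn(), which
# is not part of this excerpt. Below is a minimal sketch of such a factory,
# assuming trax's MixedLSHSelfAttention layer; every hyperparameter value
# here is an illustrative guess, not taken from the original test.
def _mixed_lsh_self_attention_fn():
  return functools.partial(
      tl.MixedLSHSelfAttention,
      attention_dropout=0.0,
      chunk_len=16,
      n_buckets=[32, 32],
      n_chunks_after=0,
      n_chunks_before=1,
      n_hashes=2,
      n_parallel_heads=1,
      predict_drop_len=128,
      predict_mem_len=1024,
  )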
def test_autoregressive_sample_reformer2_copy_self_attn_quality(self):
  max_len = 32

  def _self_attention_fn():
    return functools.partial(
        tl.SelfAttention,
        predict_drop_len=2 * max_len,
        predict_mem_len=2 * max_len,
    )

  pred_model = models.Reformer2(
      mode='predict',
      d_model=256,
      d_ff=512,
      dropout=0.05,
      max_len=max_len,
      n_heads=4,
      n_encoder_layers=3,
      n_decoder_layers=3,
      ff_use_sru=1,
      d_attention_key=64,
      d_attention_value=64,
      encoder_attention_type=_self_attention_fn(),
      encoder_decoder_attention_type=_self_attention_fn(),
      n_decoder_attention_layers=1,
      input_vocab_size=13,
      axial_pos_shape=None,
  )

  shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
  shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32)

  model_path = os.path.join(_TESTDATA, 'reformer2_copy_self_attn.pkl.gz')
  pred_model.init_from_file(model_path, weights_only=True,
                            input_signature=(shape1l, shape11))

  inputs = np.array([[11, 5, 1, 2, 3, 4]], dtype=np.int32)
  inp_len = inputs.shape[1]
  inputs = np.pad(inputs, [(0, 0), (0, max_len - inp_len)],
                  mode='constant', constant_values=0)

  s = decoding.autoregressive_sample(
      pred_model, inputs=inputs, eos_id=-1, max_length=inp_len,
      temperature=0.0)
  np.testing.assert_equal(s[0], inputs[0, :inp_len])
def test_autoregressive_sample_reformer2_pure_lsh(self):
  max_len = 128

  pred_model = models.Reformer2(
      mode='predict',
      d_model=256,
      d_ff=512,
      dropout=0.05,
      max_len=max_len,
      n_heads=4,
      n_encoder_layers=1,
      n_decoder_layers=1,
      ff_use_sru=1,
      d_attention_key=64,
      d_attention_value=64,
      encoder_attention_type=self._pure_lsh_self_attention_fn(
          n_chunks_after=1),
      encoder_decoder_attention_type=self._pure_lsh_self_attention_fn(),
      input_vocab_size=256,
      pos_axial_shape=None,
  )

  shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
  shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32)
  pred_model.init(input_signature=(shape1l, shape11))

  # The input has the form 0w0w: a zero marker, a "word", then the marker
  # and the word repeated.
  inputs = np.array(
      [[0, 3, 7, 5, 3, 2, 4, 1, 8, 0, 3, 7, 5, 3, 2, 4, 1, 8]],
      dtype=np.int32)
  inputs = np.pad(inputs, [(0, 0), (0, max_len - inputs.shape[1])],
                  mode='constant', constant_values=0)
  s = decoding.autoregressive_sample(
      pred_model, inputs=inputs, eos_id=-1, max_length=10, temperature=0.0)

  self.assertEqual(s.shape[0], 1)
  self.assertEqual(s.shape[1], 10)
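# The test above calls self._pure_lsh_self_attention_fn(), which is defined
# elsewhere in the test class. A minimal sketch follows, assuming trax's
# PureLSHSelfAttentionWrapper around PureLSHSelfAttention; the hyperparameter
# values are illustrative guesses, not the original test's settings.
def _pure_lsh_self_attention_fn(self, n_chunks_after=0):
  return functools.partial(
      tl.PureLSHSelfAttentionWrapper,
      attention_dropout=0.0,
      chunk_len=16,
      n_buckets=[32, 32],
      n_chunks_after=n_chunks_after,
      n_chunks_before=1,
      n_hashes=2,
      n_parallel_heads=1,
      max_length_for_buckets=1024,
      predict_drop_len=256,
      predict_mem_len=256,
      pure_lsh_implementation=tl.PureLSHSelfAttention,
  )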
def test_autoregressive_sample_reformer2_timing(self):
  max_len = 16

  all_settings = [
      {'attn_sparsity': 64, 'ff_sparsity': (256, 32),
       'attn': (tl.MultiplicativeCausalAttention, {})},
      {'attn_sparsity': 64, 'ff_sparsity': (256, 32),
       'attn': (tl.ModularCausalAttention, {})},
      {'attn_sparsity': 64, 'ff_sparsity': (256, 32),
       'attn': (tl.ConvCausalAttention, {})},
      {'attn_sparsity': 64, 'ff_sparsity': (256, 32),
       'attn': (tl.MultiplicativeConvCausalAttention,
                {'length_kernel_size': 1})},
      {'attn_sparsity': 64, 'ff_sparsity': (256, 32),
       'attn': (tl.MultiplicativeConvCausalAttention,
                {'length_kernel_size': 3})},
  ]

  messages = []
  for settings in all_settings:

    def _self_attention_fn():
      return functools.partial(
          tl.SelfAttention,
          predict_drop_len=2 * max_len,
          predict_mem_len=2 * max_len)

    def _causal_attention_fn():
      attn_layer, attn_kwargs = settings['attn']  # pylint: disable=cell-var-from-loop
      return functools.partial(
          attn_layer,
          sparsity=settings['attn_sparsity'],  # pylint: disable=cell-var-from-loop
          max_inference_length=2 * max_len,
          **attn_kwargs)

    pred_model = models.Reformer2(
        mode='predict',
        d_model=96 * 96,
        d_ff=64 * 1024,
        dropout=0.05,
        max_len=max_len,
        n_heads=64,
        n_encoder_layers=2,
        n_decoder_layers=2,
        encoder_attention_type=_self_attention_fn(),
        encoder_decoder_attention_type=_causal_attention_fn(),
        input_vocab_size=4,
        ff_sparsity=settings['ff_sparsity'],
        axial_pos_shape=None,
        use_bfloat16=True,
    )

    shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
    shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32)
    pred_model.init(input_signature=(shape1l, shape11))
    inputs = np.arange(16, dtype=np.int32).reshape(1, 16)

    # This is decoding.autoregressive_sample but simplified and with timing.
    result, counter, start_time, total_time = [], 0, time.time(), 0.0
    for sample in decoding.autoregressive_sample_stream(
        pred_model, inputs, temperature=0.0):  # accelerate=False):
      elapsed_time = time.time() - start_time
      start_time = time.time()
      if counter > 3:  # Skip the first few steps (warm-up / compilation).
        total_time += elapsed_time
      result.append(sample[:, None])
      counter += 1
      if counter >= 14:
        break

    # We print 5x the time for 10 tokens; with 2 decoder layers this roughly
    # approximates the time for 1 token at 100 layers.
    message = (
        '\n\nSettings: %s\nTime for 5x10 tokens (~1tok @100): %.4f s\n\n\n'
        % (settings, 5 * total_time))
    messages.append(message)
    print(message)
    # self.assertLess(total_time, 20.0)  # If it takes > 20s, it's a bug.
    # # Check resulting shapes.
    # s = np.concatenate(result, axis=1)
    # self.assertEqual(s.shape[0], 1)
    # self.assertEqual(s.shape[1], 14)

  print('Final results (recap):')
  for message in messages:
    print(message)
def _reformer2_decoding_time(self, settings):
  max_len = 16

  def _self_attention_fn():
    return functools.partial(
        tl.SelfAttention,
        predict_drop_len=2 * max_len,
        predict_mem_len=2 * max_len)

  def _causal_attention_fn():
    attn_layer, attn_kwargs = settings['attn']
    return functools.partial(
        attn_layer,
        max_inference_length=2 * max_len,
        **attn_kwargs)

  pred_model = models.Reformer2(
      mode='predict',
      d_model=settings['d_model'],
      d_ff=settings['d_ff'],
      dropout=0.1,
      max_len=max_len,
      n_heads=settings['n_heads'],
      n_encoder_layers=settings['encoder_layers'],
      n_decoder_layers=settings['decoder_layers'],
      encoder_attention_type=_self_attention_fn(),
      encoder_decoder_attention_type=_causal_attention_fn(),
      input_vocab_size=settings['vocab'],
      ff_sparsity=settings['ff_sparsity'],
      ff_use_sru=settings['ff_use_sru'],
      ff_dropout=0.1,
      ff_chunk_size=1024,
      attention_chunk_size=1,
      loss_sparsity=settings['loss_sparsity'],
      axial_pos_shape=None,
      use_bfloat16=True,
  )

  shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
  shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32)
  pred_model.init(input_signature=(shape1l, shape11))
  inputs = np.arange(max_len, dtype=np.int32).reshape(1, max_len)

  # This is decoding.autoregressive_sample but simplified and with timing.
  result, start_time = [], time.time()
  elapsed_times = []
  peak_memory = 0
  for index, sample in zip(
      range(-4, 10),
      decoding.autoregressive_sample_stream(pred_model, inputs,
                                            temperature=0.0)):
    peak_memory = max(peak_memory, memory_usage())
    result.append(sample[:, None])
    elapsed_time = time.time() - start_time
    if index >= 0:  # The first 4 steps (negative indices) are warm-up.
      elapsed_times.append(elapsed_time)
    start_time = time.time()

  if min(elapsed_times) * 2 < max(elapsed_times):
    print('WARNING! High variance found in elapsed times! Settings: {} ; '
          'elapsed times: {} ; Probably more warm-up steps should be used, '
          'or model size should be increased.'.format(
              settings, elapsed_times))

  # Check resulting shapes.
  s = np.concatenate(result, axis=1)
  self.assertEqual(s.shape[0], 1)
  self.assertEqual(s.shape[1], 14)
  return size_of_model(pred_model), elapsed_times, peak_memory
def _reformer2_decoding_time(self, settings):
  # Garbage collection influences the timing, so we turn it off.
  gc.disable()
  max_len = 16

  def _self_attention_fn():
    return functools.partial(
        tl.SelfAttention,
        predict_drop_len=2 * max_len,
        predict_mem_len=2 * max_len)

  def _causal_attention_fn():
    attn_layer, attn_kwargs = settings['attn']
    return functools.partial(
        attn_layer,
        max_inference_length=2 * max_len,
        **attn_kwargs)

  pred_model = models.Reformer2(
      mode='predict',
      d_model=settings['d_model'],
      d_ff=settings['d_ff'],
      dropout=0.1,
      max_len=max_len,
      n_heads=settings['n_heads'],
      n_encoder_layers=settings['encoder_layers'],
      n_decoder_layers=settings['decoder_layers'],
      encoder_attention_type=_self_attention_fn(),
      encoder_decoder_attention_type=_causal_attention_fn(),
      input_vocab_size=settings['vocab'],
      ff_sparsity=settings['ff_sparsity'],
      ff_use_sru=settings['ff_use_sru'],
      ff_dropout=0.1,
      ff_chunk_size=1024,
      attention_chunk_size=1,
      n_decoder_attention_layers=settings['attention_layers'],
      loss_sparsity=settings['loss_sparsity'],
      pos_axial_shape=None,
      use_bfloat16=True,
  )

  # We put acceleration outside of autoregressive_sample_stream, because
  # we want to have a separate run (separate input) for model compilation.
  pred_model = tl.Accelerate(pred_model)

  shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
  shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32)
  pred_model.init(input_signature=(shape1l, shape11))
  original_state = copy.deepcopy(pred_model.state)

  inputs_warmup = np.zeros((1, max_len), dtype=np.int32)
  inputs = np.arange(max_len, dtype=np.int32).reshape(1, max_len)

  # This is a warm-up run, for compilation.
  result, current_time = [], time.time()
  elapsed_warmup_times = []
  for index, sample in zip(
      range(0, 4),
      decoding.autoregressive_sample_stream(pred_model, inputs_warmup,
                                            temperature=0.0,
                                            accelerate=False)):
    del index  # unused
    result.append(sample[:, None])  # To be sure that the result is computed.
    current_time, start_time = time.time(), current_time
    elapsed_warmup_times.append(current_time - start_time)

  # This is the real decoding timing run that we measure.
  pred_model.state = original_state
  result, current_time = [], time.time()
  elapsed_times = []
  for index, sample in zip(
      range(12),
      decoding.autoregressive_sample_stream(pred_model, inputs,
                                            temperature=0.0,
                                            accelerate=False)):
    del index  # unused
    result.append(sample[:, None])  # To be sure that the result is computed.
    current_time, start_time = time.time(), current_time
    elapsed_times.append(current_time - start_time)
  peak_memory = _memory_usage()

  if min(elapsed_times[2:]) * 2 < max(elapsed_times[2:]):
    print('WARNING! High variance found in elapsed times! Settings: {} ; '
          'elapsed times: {} ; Probably more warm-up steps should be used, '
          'or model size should be increased.'.format(
              settings, elapsed_times))

  # Check resulting shapes.
  s = np.concatenate(result, axis=1)
  self.assertEqual(s.shape[0], 1)
  self.assertEqual(s.shape[1], 12)

  model_size = int(_size_of_model(pred_model))

  # We delete the model weights, because in some situations they won't be
  # deleted automatically.
  _recurrent_delete(pred_model.weights)
  gc.enable()
  return model_size, elapsed_times, peak_memory
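# Both _reformer2_decoding_time variants above use measurement utilities that
# this excerpt does not define: the earlier variant calls memory_usage() and
# size_of_model(), the later one _memory_usage(), _size_of_model() and
# _recurrent_delete(). Minimal sketches follow, assuming the psutil and
# trax.fastmath imports from the preamble sketch; treat them as illustrative
# reconstructions rather than the exact originals.

def _memory_usage():
  """Returns the resident memory size of this process, in bytes."""
  gc.collect()  # Run garbage collection first so the reading is less noisy.
  return psutil.Process(os.getpid()).memory_info().rss


def _size_of_model(model):
  """Returns the total number of elements across all model weights."""
  def _size(x):
    try:
      return x.size  # Number of elements in a single weight array.
    except AttributeError:
      return 0  # Non-array leaves (e.g. empty tuples) contribute nothing.
  sizes = fastmath.nested_map(_size, model.weights)
  return sum(fastmath.tree_flatten(sizes))


def _recurrent_delete(w):
  """Frees the device memory of every array in a nested weight structure."""
  if 'delete' in dir(w):
    w.delete()  # A single (JAX) array: release its device buffer.
  else:
    for x in w:  # A nested structure: recurse into it.
      _recurrent_delete(x)


# The earlier variant's un-prefixed names can be thin aliases of the same
# helpers:
memory_usage = _memory_usage
size_of_model = _size_of_model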