def test_autoregressive_sample_reformer2_timing(self):
  """Times greedy token-by-token sampling from a small 2-layer Reformer2."""
  max_len = 16

  def _self_attention_fn():
    return functools.partial(
        layers.SelfAttention,
        predict_drop_len=2 * max_len,
        predict_mem_len=2 * max_len)

  def _causal_attention_fn():
    return functools.partial(
        layers.ModularCausalAttention,
        n_modules=64,
        max_inference_length=2 * max_len)

  pred_model = models.Reformer2(
      mode='predict',
      d_model=8 * 1024,
      d_ff=64 * 1024,
      dropout=0.05,
      max_len=max_len,
      n_heads=64,
      n_encoder_layers=2,
      n_decoder_layers=2,
      encoder_attention_type=_self_attention_fn(),
      encoder_decoder_attention_type=_causal_attention_fn(),
      input_vocab_size=4,
      ff_sparsity=256,
      axial_pos_shape=None,
  )
  shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
  shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32)
  pred_model.init(input_signature=(shape1l, shape11))
  inputs = np.arange(16, dtype=np.int32).reshape(1, 16)

  # This is decoding.autoregressive_sample, simplified and instrumented with
  # per-token timing.
  sampled, n_sampled, total_time = [], 0, 0.0
  tick = time.time()
  for sample in decoding.autoregressive_sample_stream(
      pred_model, inputs, temperature=0.0):  # accelerate=False):
    elapsed = time.time() - tick
    tick = time.time()
    # The first 4 tokens include compilation / warm-up cost, so only the
    # later ones count towards the measured total.
    if n_sampled > 3:
      total_time += elapsed
    sampled.append(sample[:, None])
    n_sampled += 1
    if n_sampled >= 14:
      break

  # We print 5* time for 10 tokens, @2 layers this is ~1 token @ 100 layers.
  print('\n\nTime for 5x10 tokens (~1tok @100): %.4fs\n\n\n'
        % (5 * total_time))
  self.assertLess(total_time, 20.0)  # If it's > 20s, it's some bug.
  # Check resulting shapes.
  s = np.concatenate(sampled, axis=1)
  self.assertEqual(s.shape[0], 1)
  self.assertEqual(s.shape[1], 14)
def _terraformer_decoding_time(self, settings):
  """Measures per-token autoregressive decoding time for one configuration.

  Builds either a ConfigurableTerraformer or a ConfigurableTransformer in
  'predict' mode from `settings`, runs a 4-token warm-up decode on separate
  inputs (so compilation happens outside the measured run), then times 12
  decoded tokens.

  Args:
    settings: dict of hyperparameters. Must contain 'model' ('terraformer'
      or 'transformer'), 'attn' (a pair of attention-layer class and its
      kwargs) and the size keys read below ('d_model', 'd_ff', 'n_heads',
      'encoder_layers', 'decoder_layers', 'vocab', 'ff_sparsity',
      'ff_use_sru', 'loss_sparsity'; terraformer also uses
      'attention_layers').

  Returns:
    A tuple (model_size, elapsed_times, peak_memory): the model size as an
    int, the list of 12 per-token decode times in seconds, and the memory
    usage sampled right after decoding (via `_memory_usage`).

  Raises:
    ValueError: if settings['model'] names an unknown model type.
  """
  # Garbage collection influences the timing, so we turn it off. The
  # `finally` clause below guarantees it is re-enabled even when an
  # assertion fails mid-way, so later tests are not affected.
  gc.disable()
  try:
    max_len = 16

    def _self_attention_fn():
      return functools.partial(
          tl.SelfAttention,
          predict_drop_len=2 * max_len,
          predict_mem_len=2 * max_len)

    def _causal_attention_fn():
      attn_layer, attn_kwargs = settings['attn']
      return functools.partial(
          attn_layer, max_inference_length=2 * max_len, **attn_kwargs)

    if settings['model'] == 'terraformer':
      pred_model = models.ConfigurableTerraformer(
          mode='predict',
          d_model=settings['d_model'],
          d_ff=settings['d_ff'],
          dropout=0.1,
          max_len=max_len,
          n_heads=settings['n_heads'],
          n_encoder_layers=settings['encoder_layers'],
          n_decoder_layers=settings['decoder_layers'],
          encoder_attention_type=_self_attention_fn(),
          encoder_decoder_attention_type=_causal_attention_fn(),
          input_vocab_size=settings['vocab'],
          ff_sparsity=settings['ff_sparsity'],
          ff_use_sru=settings['ff_use_sru'],
          ff_dropout=0.1,
          # ff_chunk_size=1024,
          # attention_chunk_size=1,
          n_decoder_attention_layers=settings['attention_layers'],
          loss_sparsity=settings['loss_sparsity'],
          pos_axial_shape=None,
          use_bfloat16=True,
      )
    elif settings['model'] == 'transformer':
      pred_model = models.ConfigurableTransformer(
          mode='predict',
          d_model=settings['d_model'],
          d_ff=settings['d_ff'],
          dropout=0.1,
          max_len=max_len,
          n_heads=settings['n_heads'],
          n_encoder_layers=settings['encoder_layers'],
          n_decoder_layers=settings['decoder_layers'],
          # encoder_attention_type=_self_attention_fn(),
          encoder_decoder_attention_type=_causal_attention_fn(),
          input_vocab_size=settings['vocab'],
          ff_sparsity=settings['ff_sparsity'],
          ff_use_sru=settings['ff_use_sru'],
          # ff_dropout=0.1,
          # ff_chunk_size=1024,
          # attention_chunk_size=1,
          # n_decoder_attention_layers=settings['attention_layers'],
          loss_sparsity=settings['loss_sparsity'],
          pos_axial_shape=None,
          # enc_dec_attention_sparsity=settings['enc_dec_sparsity'],
          # use_bfloat16=True,
      )
    else:
      # A bare `assert False` here would be stripped under `python -O` and
      # gives no hint about the bad value; raise an informative error.
      raise ValueError('Unknown model type: %s' % settings['model'])

    # We put acceleration outside of autoregressive_sample_stream, because
    # we want to have a separate run (separate input) for model compilation.
    pred_model = tl.Accelerate(pred_model)

    shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
    shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32)
    pred_model.init(input_signature=(shape1l, shape11))
    original_state = copy.deepcopy(pred_model.state)

    inputs_warmup = np.zeros((1, max_len), dtype=np.int32)
    inputs = np.arange(max_len, dtype=np.int32).reshape(1, max_len)

    # This is a warm-up run, for compilation.
    result, current_time = [], time.time()
    elapsed_warmup_times = []
    for index, sample in zip(
        range(0, 4),
        decoding.autoregressive_sample_stream(
            pred_model, inputs_warmup, temperature=0.0, accelerate=False)):
      del index  # unused
      result.append(sample[:, None])  # to be sure that the result is computed
      current_time, start_time = time.time(), current_time
      elapsed_warmup_times.append(current_time - start_time)

    # This is a real decoding timing run that we measure. Restore the state
    # first so the warm-up decode does not leak into it.
    pred_model.state = original_state
    result, current_time = [], time.time()
    elapsed_times = []
    for index, sample in zip(
        range(12),
        decoding.autoregressive_sample_stream(
            pred_model, inputs, temperature=0.0, accelerate=False)):
      del index  # unused
      result.append(sample[:, None])  # to be sure that the result is computed
      current_time, start_time = time.time(), current_time
      elapsed_times.append(current_time - start_time)
    peak_memory = _memory_usage()

    if min(elapsed_times[2:]) * 2 < max(elapsed_times[2:]):
      print('WARNING! High variance found in elapsed times! Settings: {} ; '
            'elapsed times: {} ; Probably more warm-up steps should be used, '
            'or model size should be increased.'.format(settings,
                                                        elapsed_times))

    # Check resulting shapes.
    s = np.concatenate(result, axis=1)
    self.assertEqual(s.shape[0], 1)
    self.assertEqual(s.shape[1], 12)
    model_size = int(_size_of_model(pred_model))

    # We delete the model weights, because in some situations they won't be
    # deleted automatically.
    _recurrent_delete(pred_model.weights)
    return model_size, elapsed_times, peak_memory
  finally:
    gc.enable()
def test_autoregressive_sample_reformer2_timing(self):
  """Prints decoding timings for several sparse causal-attention variants."""
  max_len = 16

  def _make_settings(attn_layer, **attn_kwargs):
    # All variants share the same sparsity settings and differ only in the
    # attention layer (and its kwargs). Key order matches the original
    # literals, so the printed `settings` repr is unchanged.
    return {'attn_sparsity': 64,
            'ff_sparsity': (256, 32),
            'attn': (attn_layer, attn_kwargs)}

  all_settings = [
      _make_settings(tl.MultiplicativeCausalAttention),
      _make_settings(tl.ModularCausalAttention),
      _make_settings(tl.ConvCausalAttention),
      _make_settings(tl.MultiplicativeConvCausalAttention,
                     length_kernel_size=1),
      _make_settings(tl.MultiplicativeConvCausalAttention,
                     length_kernel_size=3),
  ]

  messages = []
  for settings in all_settings:

    def _self_attention_fn():
      return functools.partial(
          tl.SelfAttention,
          predict_drop_len=2 * max_len,
          predict_mem_len=2 * max_len)

    def _causal_attention_fn():
      attn_layer, attn_kwargs = settings['attn']  # pylint: disable=cell-var-from-loop
      return functools.partial(
          attn_layer,
          sparsity=settings['attn_sparsity'],  # pylint: disable=cell-var-from-loop
          max_inference_length=2 * max_len,
          **attn_kwargs)

    pred_model = models.Reformer2(
        mode='predict',
        d_model=96 * 96,
        d_ff=64 * 1024,
        dropout=0.05,
        max_len=max_len,
        n_heads=64,
        n_encoder_layers=2,
        n_decoder_layers=2,
        encoder_attention_type=_self_attention_fn(),
        encoder_decoder_attention_type=_causal_attention_fn(),
        input_vocab_size=4,
        ff_sparsity=settings['ff_sparsity'],
        axial_pos_shape=None,
        use_bfloat16=True,
    )
    shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
    shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32)
    pred_model.init(input_signature=(shape1l, shape11))
    inputs = np.arange(16, dtype=np.int32).reshape(1, 16)

    # This is decoding.autoregressive_sample, simplified and instrumented
    # with per-token timing.
    sampled, n_sampled, total_time = [], 0, 0.0
    tick = time.time()
    for sample in decoding.autoregressive_sample_stream(
        pred_model, inputs, temperature=0.0):  # accelerate=False):
      elapsed = time.time() - tick
      tick = time.time()
      # The first 4 tokens carry compilation / warm-up cost; skip them.
      if n_sampled > 3:
        total_time += elapsed
      sampled.append(sample[:, None])
      n_sampled += 1
      if n_sampled >= 14:
        break

    # We print 5* time for 10 tokens, @2 layers this is ~1 token @ 100 layers.
    message = (
        '\n\nSettings: %s\nTime for 5x10 tokens (~1tok @100): %.4f s\n\n\n'
        % (settings, 5 * total_time))
    messages.append(message)
    print(message)
    # self.assertLess(total_time, 20.0)  # If it's > 20s, it's some bug.
    # # Check resulting shapes.
    # s = np.concatenate(sampled, axis=1)
    # self.assertEqual(s.shape[0], 1)
    # self.assertEqual(s.shape[1], 14)

  print('Final results (recap):')
  for message in messages:
    print(message)
def _reformer2_decoding_time(self, settings):
  """Times Reformer2 token-by-token decoding under the given `settings`.

  Decodes 14 tokens; the first 4 serve as warm-up and are excluded from the
  reported timings.

  Returns:
    A tuple (model_size, elapsed_times, peak_memory): the value of
    `size_of_model(...)` for the built model, the list of 10 timed per-token
    decode times in seconds, and the maximum `memory_usage()` reading seen
    during decoding.
  """
  max_len = 16

  def _self_attention_fn():
    return functools.partial(
        tl.SelfAttention,
        predict_drop_len=2 * max_len,
        predict_mem_len=2 * max_len)

  def _causal_attention_fn():
    attn_layer, attn_kwargs = settings['attn']
    return functools.partial(
        attn_layer, max_inference_length=2 * max_len, **attn_kwargs)

  pred_model = models.Reformer2(
      mode='predict',
      d_model=settings['d_model'],
      d_ff=settings['d_ff'],
      dropout=0.1,
      max_len=max_len,
      n_heads=settings['n_heads'],
      n_encoder_layers=settings['encoder_layers'],
      n_decoder_layers=settings['decoder_layers'],
      encoder_attention_type=_self_attention_fn(),
      encoder_decoder_attention_type=_causal_attention_fn(),
      input_vocab_size=settings['vocab'],
      ff_sparsity=settings['ff_sparsity'],
      ff_use_sru=settings['ff_use_sru'],
      ff_dropout=0.1,
      ff_chunk_size=1024,
      attention_chunk_size=1,
      loss_sparsity=settings['loss_sparsity'],
      axial_pos_shape=None,
      use_bfloat16=True,
  )
  shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
  shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32)
  pred_model.init(input_signature=(shape1l, shape11))
  inputs = np.arange(max_len, dtype=np.int32).reshape(1, max_len)

  # This is decoding.autoregressive_sample, simplified and instrumented with
  # per-token timing. Steps -4..-1 are warm-up and are not recorded; the
  # range also caps decoding at 14 tokens total.
  sampled, elapsed_times = [], []
  peak_memory = 0
  tick = time.time()
  token_stream = decoding.autoregressive_sample_stream(
      pred_model, inputs, temperature=0.0)
  for step, sample in zip(range(-4, 10), token_stream):
    peak_memory = max(peak_memory, memory_usage())
    sampled.append(sample[:, None])
    elapsed = time.time() - tick
    if step >= 0:
      elapsed_times.append(elapsed)
    tick = time.time()

  if min(elapsed_times) * 2 < max(elapsed_times):
    print('WARNING! High variance found in elapsed times! Settings: {} ; '
          'elapsed times: {} ; Probably more warm-up steps should be used, '
          'or model size should be increased.'.format(settings,
                                                      elapsed_times))

  # Check resulting shapes.
  stacked = np.concatenate(sampled, axis=1)
  self.assertEqual(stacked.shape[0], 1)
  self.assertEqual(stacked.shape[1], 14)
  return size_of_model(pred_model), elapsed_times, peak_memory
vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR, n_reserved_ids=100))) def detokenize(x): return trax.data.detokenize(x, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR, n_reserved_ids=100) with trax.fastmath.use_backend(trax.fastmath.Backend.JAX): model.state = old_state counter, tokens, max_length = 0, [], 30 for token in decoding.autoregressive_sample_stream(model, tokenized[None, :15 * 1024], batch_size=1, temperature=0.0, eval_mode=True, eval_min_length=1024): print(f'Token {counter}: "{detokenize(token)}" {token}') tokens.append(token[:, None]) counter += 1 if counter > max_length: break tokens = np.concatenate(tokens, axis=1) print(tokens) print(detokenize(tokens[0]))