Example #1
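The snippets below assume imports roughly like the following (a sketch based on the Trax library's module layout, not part of any original example; Example #1 refers to `layers` directly, later examples alias it as `tl`):

import copy
import functools
import gc
import time

import numpy as np

import trax
from trax import layers
from trax import layers as tl
from trax import models
from trax import shapes
from trax.supervised import decoding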
    def test_autoregressive_sample_reformer2_timing(self):
        max_len = 16

        def _self_attention_fn():
            return functools.partial(layers.SelfAttention,
                                     predict_drop_len=2 * max_len,
                                     predict_mem_len=2 * max_len)

        def _causal_attention_fn():
            return functools.partial(layers.ModularCausalAttention,
                                     n_modules=64,
                                     max_inference_length=2 * max_len)

        pred_model = models.Reformer2(
            mode='predict',
            d_model=8 * 1024,
            d_ff=64 * 1024,
            dropout=0.05,
            max_len=max_len,
            n_heads=64,
            n_encoder_layers=2,
            n_decoder_layers=2,
            encoder_attention_type=_self_attention_fn(),
            encoder_decoder_attention_type=_causal_attention_fn(),
            input_vocab_size=4,
            ff_sparsity=256,
            axial_pos_shape=None,
        )

        shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
        shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32)
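        # In 'predict' mode the encoder-decoder model is initialized with
        # (full-length encoder tokens, single decoder token) input shapes.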
        pred_model.init(input_signature=(shape1l, shape11))
        inputs = np.arange(16, dtype=np.int32).reshape(1, 16)

        # This is decoding.autoregressive_sample but simplified and with timing.
        result, counter, start_time, total_time = [], 0, time.time(), 0.0
        for sample in decoding.autoregressive_sample_stream(
                pred_model, inputs, temperature=0.0):  # accelerate=False):
            elapsed_time = time.time() - start_time
            start_time = time.time()
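            # The first few steps include compilation/warm-up, so they are
            # skipped when accumulating the measured time.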
            if counter > 3:
                total_time += elapsed_time
            result.append(sample[:, None])
            counter += 1
            if counter >= 14:
                break

        # We print 5x the time for 10 tokens; with 2 decoder layers this
        # approximates the time for 1 token at 100 layers.
        print('\n\nTime for 5x10 tokens (~1tok @100): %.4fs\n\n\n' %
              (5 * total_time))
        self.assertLess(total_time, 20.0)  # If it's > 20s, it's some bug.
        # Check resulting shapes.
        s = np.concatenate(result, axis=1)
        self.assertEqual(s.shape[0], 1)
        self.assertEqual(s.shape[1], 14)
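Distilling the loop above: `decoding.autoregressive_sample_stream` yields one sampled token per step with shape `(batch,)`, so callers add a length-1 axis and concatenate along axis 1. A minimal sketch reusing `pred_model` and `inputs` from the example above:

samples = []
for step, sample in enumerate(decoding.autoregressive_sample_stream(
        pred_model, inputs, temperature=0.0)):
    samples.append(sample[:, None])        # (batch,) -> (batch, 1)
    if step + 1 >= 14:                     # stop after 14 generated tokens
        break
decoded = np.concatenate(samples, axis=1)  # -> shape (1, 14)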
Example #2
  def _terraformer_decoding_time(self, settings):
    # Garbage collection influences the timing, so we turn it off.
    gc.disable()
    max_len = 16

    def _self_attention_fn():
      return functools.partial(
          tl.SelfAttention,
          predict_drop_len=2 * max_len,
          predict_mem_len=2 * max_len)

    def _causal_attention_fn():
      attn_layer, attn_kwargs = settings['attn']
      return functools.partial(
          attn_layer,
          max_inference_length=2 * max_len, **attn_kwargs)

    if settings['model'] == 'terraformer':
      pred_model = models.ConfigurableTerraformer(
          mode='predict',
          d_model=settings['d_model'],
          d_ff=settings['d_ff'],
          dropout=0.1,
          max_len=max_len,
          n_heads=settings['n_heads'],
          n_encoder_layers=settings['encoder_layers'],
          n_decoder_layers=settings['decoder_layers'],
          encoder_attention_type=_self_attention_fn(),
          encoder_decoder_attention_type=_causal_attention_fn(),
          input_vocab_size=settings['vocab'],
          ff_sparsity=settings['ff_sparsity'],
          ff_use_sru=settings['ff_use_sru'],
          ff_dropout=0.1,
          # ff_chunk_size=1024,
          # attention_chunk_size=1,
          n_decoder_attention_layers=settings['attention_layers'],
          loss_sparsity=settings['loss_sparsity'],
          pos_axial_shape=None,
          use_bfloat16=True,
      )
    elif settings['model'] == 'transformer':
      pred_model = models.ConfigurableTransformer(
          mode='predict',
          d_model=settings['d_model'],
          d_ff=settings['d_ff'],
          dropout=0.1,
          max_len=max_len,
          n_heads=settings['n_heads'],
          n_encoder_layers=settings['encoder_layers'],
          n_decoder_layers=settings['decoder_layers'],
          # encoder_attention_type=_self_attention_fn(),
          encoder_decoder_attention_type=_causal_attention_fn(),
          input_vocab_size=settings['vocab'],
          ff_sparsity=settings['ff_sparsity'],
          ff_use_sru=settings['ff_use_sru'],
          # ff_dropout=0.1,
          # ff_chunk_size=1024,
          # attention_chunk_size=1,
          # n_decoder_attention_layers=settings['attention_layers'],
          loss_sparsity=settings['loss_sparsity'],
          pos_axial_shape=None,
          # enc_dec_attention_sparsity=settings['enc_dec_sparsity'],
          # use_bfloat16=True,
      )
    else:
      assert False
    # We put acceleration outside of autoregressive_sample_stream, because
    # we want to have a separate run (separate input) for model compilation.
    pred_model = tl.Accelerate(pred_model)

    shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
    shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32)
    pred_model.init(input_signature=(shape1l, shape11))
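    # Decoding in 'predict' mode mutates the model's fast inference state
    # (caches), so keep a copy to restore before the measured run below.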
    original_state = copy.deepcopy(pred_model.state)

    inputs_warmup = np.zeros((1, max_len), dtype=np.int32)
    inputs = np.arange(max_len, dtype=np.int32).reshape(1, max_len)

    # This is a warm-up run, for compilation.
    result, current_time = [], time.time()
    elapsed_warmup_times = []
    for index, sample in zip(range(0, 4), decoding.autoregressive_sample_stream(
        pred_model, inputs_warmup, temperature=0.0, accelerate=False)):
      del index  # unused
      result.append(sample[:, None])  # to be sure that the result is computed

      current_time, start_time = time.time(), current_time
      elapsed_warmup_times.append(current_time - start_time)

    # This is a real decoding timing run that we measure.
    pred_model.state = original_state
    result, current_time = [], time.time()
    elapsed_times = []
    for index, sample in zip(range(12), decoding.autoregressive_sample_stream(
        pred_model, inputs, temperature=0.0, accelerate=False)):
      del index  # unused
      result.append(sample[:, None])  # to be sure that the result is computed

      current_time, start_time = time.time(), current_time
      elapsed_times.append(current_time - start_time)
    peak_memory = _memory_usage()

    if min(elapsed_times[2:]) * 2 < max(elapsed_times[2:]):
      print('WARNING! High variance found in elapsed times! Settings: {} ; '
            'elapsed times: {} ; Probably more warm-up steps should be used, '
            'or model size should be increased.'.format(settings,
                                                        elapsed_times))
    # Check resulting shapes.
    s = np.concatenate(result, axis=1)
    self.assertEqual(s.shape[0], 1)
    self.assertEqual(s.shape[1], 12)
    model_size = int(_size_of_model(pred_model))

    # We delete the model weights, because in some situations they won't be
    # deleted automatically.
    _recurrent_delete(pred_model.weights)
    gc.enable()
    return model_size, elapsed_times, peak_memory
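The method above relies on helpers defined elsewhere in its test module: `_memory_usage()`, `_size_of_model()` and `_recurrent_delete()`. A rough sketch of what such helpers could look like (an assumption, not the module's actual code; `psutil` is assumed available):

import gc
import psutil


def _memory_usage():
  # Current resident memory of this process, in bytes (assumes psutil).
  gc.collect()
  return psutil.Process().memory_info().rss


def _size_of_model(model):
  # Total number of elements across all weight arrays of a Trax model,
  # whose `weights` attribute is a nested structure of lists/tuples.
  def _size(w):
    if isinstance(w, (list, tuple)):
      return sum(_size(x) for x in w)
    return getattr(w, 'size', 0)
  return _size(model.weights)


def _recurrent_delete(w):
  # Best-effort freeing of device arrays nested in lists/tuples/dicts.
  if hasattr(w, 'delete'):      # e.g. a JAX device buffer
    w.delete()
  elif isinstance(w, dict):
    for v in w.values():
      _recurrent_delete(v)
  elif isinstance(w, (list, tuple)):
    for v in w:
      _recurrent_delete(v)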
Example #3
    def test_autoregressive_sample_reformer2_timing(self):
        max_len = 16
        all_settings = [
            {
                'attn_sparsity': 64,
                'ff_sparsity': (256, 32),
                'attn': (tl.MultiplicativeCausalAttention, {})
            },
            {
                'attn_sparsity': 64,
                'ff_sparsity': (256, 32),
                'attn': (tl.ModularCausalAttention, {})
            },
            {
                'attn_sparsity': 64,
                'ff_sparsity': (256, 32),
                'attn': (tl.ConvCausalAttention, {})
            },
            {
                'attn_sparsity': 64,
                'ff_sparsity': (256, 32),
                'attn': (tl.MultiplicativeConvCausalAttention,
                         {'length_kernel_size': 1})
            },
            {
                'attn_sparsity': 64,
                'ff_sparsity': (256, 32),
                'attn': (tl.MultiplicativeConvCausalAttention,
                         {'length_kernel_size': 3})
            },
        ]
        messages = []

        for settings in all_settings:

            def _self_attention_fn():
                return functools.partial(tl.SelfAttention,
                                         predict_drop_len=2 * max_len,
                                         predict_mem_len=2 * max_len)

            def _causal_attention_fn():
                attn_layer, attn_kwargs = settings['attn']  # pylint: disable=cell-var-from-loop
                return functools.partial(
                    attn_layer,
                    sparsity=settings['attn_sparsity'],  # pylint: disable=cell-var-from-loop
                    max_inference_length=2 * max_len,
                    **attn_kwargs)

            pred_model = models.Reformer2(
                mode='predict',
                d_model=96 * 96,
                d_ff=64 * 1024,
                dropout=0.05,
                max_len=max_len,
                n_heads=64,
                n_encoder_layers=2,
                n_decoder_layers=2,
                encoder_attention_type=_self_attention_fn(),
                encoder_decoder_attention_type=_causal_attention_fn(),
                input_vocab_size=4,
                ff_sparsity=settings['ff_sparsity'],
                axial_pos_shape=None,
                use_bfloat16=True,
            )

            shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
            shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32)
            pred_model.init(input_signature=(shape1l, shape11))
            inputs = np.arange(16, dtype=np.int32).reshape(1, 16)

            # This is decoding.autoregressive_sample but simplified and with timing.
            result, counter, start_time, total_time = [], 0, time.time(), 0.0
            for sample in decoding.autoregressive_sample_stream(
                    pred_model, inputs, temperature=0.0):  # accelerate=False):
                elapsed_time = time.time() - start_time
                start_time = time.time()
                if counter > 3:
                    total_time += elapsed_time
                result.append(sample[:, None])
                counter += 1
                if counter >= 14:
                    break

            # We print 5x the time for 10 tokens; with 2 decoder layers this
            # approximates the time for 1 token at 100 layers.
            message = (
                '\n\nSettings: %s\nTime for 5x10 tokens (~1tok @100): %.4f s\n\n\n'
                % (settings, 5 * total_time))
            messages.append(message)
            print(message)
            # self.assertLess(total_time, 20.0)  # If it's > 20s, it's some bug.
            # # Check resulting shapes.
            # s = np.concatenate(result, axis=1)
            # self.assertEqual(s.shape[0], 1)
            # self.assertEqual(s.shape[1], 14)

        print('Final results (recap):')
        for message in messages:
            print(message)
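The `# pylint: disable=cell-var-from-loop` markers above exist because `_causal_attention_fn` closes over the loop variable `settings`. Since the partial is built immediately within the same iteration this is safe here, but a common alternative (a sketch, not taken from the original) is to bind the current value as a default argument:

def _causal_attention_fn(settings=settings):  # bind the loop's current value
    attn_layer, attn_kwargs = settings['attn']
    return functools.partial(
        attn_layer,
        sparsity=settings['attn_sparsity'],
        max_inference_length=2 * max_len,
        **attn_kwargs)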
Example #4
    def _reformer2_decoding_time(self, settings):
        max_len = 16

        def _self_attention_fn():
            return functools.partial(tl.SelfAttention,
                                     predict_drop_len=2 * max_len,
                                     predict_mem_len=2 * max_len)

        def _causal_attention_fn():
            attn_layer, attn_kwargs = settings['attn']
            return functools.partial(attn_layer,
                                     max_inference_length=2 * max_len,
                                     **attn_kwargs)

        pred_model = models.Reformer2(
            mode='predict',
            d_model=settings['d_model'],
            d_ff=settings['d_ff'],
            dropout=0.1,
            max_len=max_len,
            n_heads=settings['n_heads'],
            n_encoder_layers=settings['encoder_layers'],
            n_decoder_layers=settings['decoder_layers'],
            encoder_attention_type=_self_attention_fn(),
            encoder_decoder_attention_type=_causal_attention_fn(),
            input_vocab_size=settings['vocab'],
            ff_sparsity=settings['ff_sparsity'],
            ff_use_sru=settings['ff_use_sru'],
            ff_dropout=0.1,
            ff_chunk_size=1024,
            attention_chunk_size=1,
            loss_sparsity=settings['loss_sparsity'],
            axial_pos_shape=None,
            use_bfloat16=True,
        )

        shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
        shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32)
        pred_model.init(input_signature=(shape1l, shape11))
        inputs = np.arange(max_len, dtype=np.int32).reshape(1, max_len)

        # This is decoding.autoregressive_sample but simplified and with timing.
        result, start_time = [], time.time()
        elapsed_times = []
        peak_memory = 0
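        # The first 4 steps (negative indices) act as warm-up/compilation and
        # are excluded from the timing measurements below.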
        for index, sample in zip(
                range(-4, 10),
                decoding.autoregressive_sample_stream(pred_model,
                                                      inputs,
                                                      temperature=0.0)):
            peak_memory = max(peak_memory, memory_usage())
            result.append(sample[:, None])
            elapsed_time = time.time() - start_time
            if index >= 0:
                elapsed_times.append(elapsed_time)
            start_time = time.time()

        if min(elapsed_times) * 2 < max(elapsed_times):
            print(
                'WARNING! High variance found in elapsed times! Settings: {} ; '
                'elapsed times: {} ; Probably more warm-up steps should be used, '
                'or model size should be increased.'.format(
                    settings, elapsed_times))
        # Check resulting shapes.
        s = np.concatenate(result, axis=1)
        self.assertEqual(s.shape[0], 1)
        self.assertEqual(s.shape[1], 14)
        return size_of_model(pred_model), elapsed_times, peak_memory
Example #5
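This example comes from a notebook and relies on names defined earlier there: `VOCAB_FILE`, `VOCAB_DIR`, `model`, `old_state` and `tokenized`. A hypothetical setup, for orientation only (values are placeholders, not from the source):

VOCAB_FILE = 'ende_32k.subword'     # placeholder: subword vocabulary file name
VOCAB_DIR = 'gs://trax-ml/vocabs/'  # placeholder: directory holding the vocabulary
# `model` would be a Trax model built with mode='predict' (or 'eval'),
# `old_state` a previously saved copy of model.state, and `tokenized` a 1-D
# np.int32 array of token ids for the input text.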
# (Tail of a truncated call, presumably trax.data.tokenize(...), mirroring
# the detokenize() helper below; its opening lines are missing in the source.)
                           vocab_file=VOCAB_FILE,
                           vocab_dir=VOCAB_DIR,
                           n_reserved_ids=100)))


def detokenize(x):
    return trax.data.detokenize(x,
                                vocab_file=VOCAB_FILE,
                                vocab_dir=VOCAB_DIR,
                                n_reserved_ids=100)


with trax.fastmath.use_backend(trax.fastmath.Backend.JAX):
    model.state = old_state
    counter, tokens, max_length = 0, [], 30
    for token in decoding.autoregressive_sample_stream(model,
                                                       tokenized[None, :15 * 1024],
                                                       batch_size=1,
                                                       temperature=0.0,
                                                       eval_mode=True,
                                                       eval_min_length=1024):
        print(f'Token {counter}: "{detokenize(token)}" {token}')
        tokens.append(token[:, None])
        counter += 1
        if counter > max_length:
            break
    tokens = np.concatenate(tokens, axis=1)
    print(tokens)
    print(detokenize(tokens[0]))