Ejemplo n.º 1
0
  def test_autoregressive_sample_reformer2_lsh_attn_quality(self):
    """Checks that the LSH-attention Reformer2 copy checkpoint copies inputs.

    For each test length the gin config is reset and re-parsed from
    `reformer2_copy.gin`, the pretrained weights are loaded, and random inputs
    of length in [1, test_len] are greedily decoded; the decoded tokens must
    equal the (unpadded) inputs. The gin config is cleared on exit even when an
    assertion fails, so this test cannot leak bindings into other tests.
    """
    gin.add_config_file_search_path(_CONFIG_DIR)
    max_len = 32  # 32 is the max length we trained the checkpoint for.
    test_lengths = [8, 16, 32]
    vocab_size = 13
    # These do not depend on the test length, so they are built only once.
    shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
    shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32)
    model_path = os.path.join(_TESTDATA, 'reformer2_copy_lsh_attn.pkl.gz')
    # The checkpoint is correct on ~90% of sequences; set the seed to deflake.
    np.random.seed(0)
    try:
      for test_len in test_lengths:
        gin.clear_config()
        gin.parse_config_file('reformer2_copy.gin')
        gin.bind_parameter('LSHSelfAttention.predict_mem_len', 2 * max_len)
        gin.bind_parameter('LSHSelfAttention.predict_drop_len', 2 * max_len)

        pred_model = models.Reformer2(mode='predict')
        pred_model.init_from_file(model_path, weights_only=True,
                                  input_signature=(shape1l, shape11))
        initial_state = pred_model.state

        for _ in range(2):  # Set low to make the test run reasonably fast.
          # Pick a length in [1, test_len] at random.
          inp_len = np.random.randint(low=1, high=test_len + 1)
          # Tokens are drawn from [1, vocab_size - 2]; 0 is used as padding.
          inputs = np.random.randint(low=1, high=vocab_size - 1,
                                     size=(1, inp_len))
          inputs = np.pad(inputs, [(0, 0), (0, max_len - inp_len)],
                          mode='constant', constant_values=0)
          s = decoding.autoregressive_sample(
              pred_model, inputs=inputs, eos_id=-1, max_length=inp_len,
              temperature=0.0)
          np.testing.assert_equal(s[0], inputs[0, :inp_len])
          # Reset the decoding state so each sample starts from scratch.
          pred_model.state = initial_state
    finally:
      gin.clear_config()  # Make sure to not affect other tests.
Ejemplo n.º 2
0
    def test_autoregressive_sample_reformer2_timing(self):
        """Times greedy decoding with a large sparse Reformer2 in predict mode.

        14 tokens are decoded; the first 4 are treated as warm-up and excluded,
        and the time spent on the remaining 10 is printed and bounded.
        """
        max_len = 16
        n_warmup_tokens = 4
        n_total_tokens = 14

        pred_model = models.Reformer2(
            mode='predict',
            d_model=8 * 1024,
            d_ff=64 * 1024,
            dropout=0.05,
            max_len=max_len,
            n_heads=64,
            n_encoder_layers=2,
            n_decoder_layers=2,
            encoder_attention_type=functools.partial(
                layers.SelfAttention,
                predict_drop_len=2 * max_len,
                predict_mem_len=2 * max_len),
            encoder_decoder_attention_type=functools.partial(
                layers.ModularCausalAttention,
                n_modules=64,
                max_inference_length=2 * max_len),
            input_vocab_size=4,
            ff_sparsity=256,
            axial_pos_shape=None,
        )

        signature = (shapes.ShapeDtype((1, max_len), dtype=np.int32),
                     shapes.ShapeDtype((1, 1), dtype=np.int32))
        pred_model.init(input_signature=signature)
        inputs = np.arange(max_len, dtype=np.int32).reshape(1, max_len)

        # Simplified decoding.autoregressive_sample with per-token timing.
        samples = []
        total_time = 0.0
        tick = time.time()
        for step, sample in enumerate(
                decoding.autoregressive_sample_stream(
                    pred_model, inputs, temperature=0.0)):
            step_time = time.time() - tick
            tick = time.time()
            if step >= n_warmup_tokens:
                total_time += step_time
            samples.append(sample[:, None])
            if step + 1 >= n_total_tokens:
                break

        # 5x the time for 10 tokens @2 layers is roughly 1 token @100 layers.
        print('\n\nTime for 5x10 tokens (~1tok @100): %.4fs\n\n\n' %
              (5 * total_time))
        self.assertLess(total_time, 20.0)  # If it's > 20s, it's some bug.

        decoded = np.concatenate(samples, axis=1)
        self.assertEqual(decoded.shape[0], 1)
        self.assertEqual(decoded.shape[1], n_total_tokens)
Ejemplo n.º 3
0
 def model(mode='train'):
     """Builds a tiny Reformer2 with one encoder and one decoder layer.

     Both the encoder self-attention and the encoder-decoder attention use the
     layer returned by `_mixed_lsh_self_attention_fn()`.

     Args:
       mode: mode string passed through to `models.Reformer2`.
     """
     hyperparams = dict(
         d_model=16,
         d_ff=16,
         n_heads=2,
         dropout=0.05,
         n_decoder_layers=1,
         n_encoder_layers=1,
         input_vocab_size=256,
     )
     return models.Reformer2(
         mode=mode,
         encoder_attention_type=_mixed_lsh_self_attention_fn(),
         encoder_decoder_attention_type=_mixed_lsh_self_attention_fn(),
         **hyperparams,
     )
Ejemplo n.º 4
0
    def test_autoregressive_sample_reformer2_copy_self_attn_quality(self):
        """Checks that the self-attention Reformer2 copy checkpoint copies input.

        Loads weights from `reformer2_copy_self_attn.pkl.gz`, greedily decodes
        as many tokens as the unpadded input has, and expects them to equal the
        input tokens.
        """
        max_len = 32

        def make_self_attention():
            return functools.partial(tl.SelfAttention,
                                     predict_drop_len=2 * max_len,
                                     predict_mem_len=2 * max_len)

        pred_model = models.Reformer2(
            mode='predict',
            d_model=256,
            d_ff=512,
            dropout=0.05,
            max_len=max_len,
            n_heads=4,
            n_encoder_layers=3,
            n_decoder_layers=3,
            ff_use_sru=1,
            d_attention_key=64,
            d_attention_value=64,
            encoder_attention_type=make_self_attention(),
            encoder_decoder_attention_type=make_self_attention(),
            n_decoder_attention_layers=1,
            input_vocab_size=13,
            axial_pos_shape=None,
        )

        signature = (shapes.ShapeDtype((1, max_len), dtype=np.int32),
                     shapes.ShapeDtype((1, 1), dtype=np.int32))
        checkpoint = os.path.join(_TESTDATA, 'reformer2_copy_self_attn.pkl.gz')
        pred_model.init_from_file(checkpoint, weights_only=True,
                                  input_signature=signature)

        raw_inputs = np.array([[11, 5, 1, 2, 3, 4]], dtype=np.int32)
        n_tokens = raw_inputs.shape[1]
        # Right-pad with zeros up to the model's fixed input length.
        padded = np.pad(raw_inputs, [(0, 0), (0, max_len - n_tokens)],
                        mode='constant', constant_values=0)
        decoded = decoding.autoregressive_sample(pred_model,
                                                 inputs=padded,
                                                 eos_id=-1,
                                                 max_length=n_tokens,
                                                 temperature=0.0)

        np.testing.assert_equal(decoded[0], padded[0, :n_tokens])
Ejemplo n.º 5
0
    def test_autoregressive_sample_reformer2_pure_lsh(self):
        """Smoke-tests greedy decoding of a pure-LSH-attention Reformer2.

        The model is freshly initialized (no checkpoint), so only the shape of
        the 10 decoded tokens is checked.
        """
        max_len = 128

        pred_model = models.Reformer2(
            mode='predict',
            d_model=256,
            d_ff=512,
            dropout=0.05,
            max_len=max_len,
            n_heads=4,
            n_encoder_layers=1,
            n_decoder_layers=1,
            ff_use_sru=1,
            d_attention_key=64,
            d_attention_value=64,
            encoder_attention_type=self._pure_lsh_self_attention_fn(
                n_chunks_after=1),
            encoder_decoder_attention_type=self._pure_lsh_self_attention_fn(),
            input_vocab_size=256,
            pos_axial_shape=None,
        )

        signature = (shapes.ShapeDtype((1, max_len), dtype=np.int32),
                     shapes.ShapeDtype((1, 1), dtype=np.int32))
        pred_model.init(input_signature=signature)

        # The input has the pattern "0w0w": token 0, the word w, token 0, w.
        w = [3, 7, 5, 3, 2, 4, 1, 8]
        pattern = np.array([[0] + w + [0] + w], dtype=np.int32)
        inputs = np.pad(pattern, [(0, 0), (0, max_len - pattern.shape[1])],
                        mode='constant', constant_values=0)

        n_decode = 10
        decoded = decoding.autoregressive_sample(pred_model,
                                                 inputs=inputs,
                                                 eos_id=-1,
                                                 max_length=n_decode,
                                                 temperature=0.0)

        self.assertEqual(decoded.shape[0], 1)
        self.assertEqual(decoded.shape[1], n_decode)
Ejemplo n.º 6
0
    def test_autoregressive_sample_reformer2_timing(self):
        """Prints greedy-decoding timings of Reformer2 for several attentions.

        For each entry of `all_settings`, a large bfloat16 Reformer2 is built in
        predict mode. Its encoder-decoder attention is the configured causal
        attention layer. 14 tokens are decoded; the first 4 count as warm-up,
        and the time of the remaining 10 is reported. Nothing is asserted. All
        messages are printed again as a recap at the end.
        """
        max_len = 16
        n_warmup_tokens = 4
        n_total_tokens = 14
        attention_variants = [
            (tl.MultiplicativeCausalAttention, {}),
            (tl.ModularCausalAttention, {}),
            (tl.ConvCausalAttention, {}),
            (tl.MultiplicativeConvCausalAttention, {'length_kernel_size': 1}),
            (tl.MultiplicativeConvCausalAttention, {'length_kernel_size': 3}),
        ]
        all_settings = [
            {'attn_sparsity': 64, 'ff_sparsity': (256, 32), 'attn': variant}
            for variant in attention_variants
        ]

        def make_self_attention():
            return functools.partial(tl.SelfAttention,
                                     predict_drop_len=2 * max_len,
                                     predict_mem_len=2 * max_len)

        def make_causal_attention(run_settings):
            # `run_settings` is passed explicitly instead of being captured
            # from the loop variable.
            attn_layer, attn_kwargs = run_settings['attn']
            return functools.partial(attn_layer,
                                     sparsity=run_settings['attn_sparsity'],
                                     max_inference_length=2 * max_len,
                                     **attn_kwargs)

        messages = []
        for settings in all_settings:
            pred_model = models.Reformer2(
                mode='predict',
                d_model=96 * 96,
                d_ff=64 * 1024,
                dropout=0.05,
                max_len=max_len,
                n_heads=64,
                n_encoder_layers=2,
                n_decoder_layers=2,
                encoder_attention_type=make_self_attention(),
                encoder_decoder_attention_type=make_causal_attention(settings),
                input_vocab_size=4,
                ff_sparsity=settings['ff_sparsity'],
                axial_pos_shape=None,
                use_bfloat16=True,
            )

            signature = (shapes.ShapeDtype((1, max_len), dtype=np.int32),
                         shapes.ShapeDtype((1, 1), dtype=np.int32))
            pred_model.init(input_signature=signature)
            inputs = np.arange(max_len, dtype=np.int32).reshape(1, max_len)

            # Simplified decoding.autoregressive_sample with per-token timing.
            samples = []
            total_time = 0.0
            tick = time.time()
            for step, sample in enumerate(
                    decoding.autoregressive_sample_stream(
                        pred_model, inputs, temperature=0.0)):
                step_time = time.time() - tick
                tick = time.time()
                if step >= n_warmup_tokens:
                    total_time += step_time
                samples.append(sample[:, None])
                if step + 1 >= n_total_tokens:
                    break

            # 5x the time for 10 tokens @2 layers is roughly 1 token @100 layers.
            message = (
                '\n\nSettings: %s\nTime for 5x10 tokens (~1tok @100): %.4f s\n\n\n'
                % (settings, 5 * total_time))
            messages.append(message)
            print(message)
            # NOTE(review): unlike the single-setting timing test, the time
            # bound and the shape checks on `samples` are disabled here.

        print('Final results (recap):')
        for message in messages:
            print(message)
Ejemplo n.º 7
0
    def _reformer2_decoding_time(self, settings):
        """Measures per-token greedy decoding time of a Reformer2 configuration.

        Args:
          settings: dict of hyperparameters. Keys read: 'attn' (a
            (layer_class, kwargs) pair used for encoder-decoder attention),
            'd_model', 'd_ff', 'n_heads', 'encoder_layers', 'decoder_layers',
            'vocab', 'ff_sparsity', 'ff_use_sru', 'loss_sparsity'.

        Returns:
          A tuple (model_size, elapsed_times, peak_memory):
            - model_size: the value of `size_of_model(pred_model)`.
            - elapsed_times: seconds spent on each of the 10 timed decoding
              steps. The first 4 steps are warm-up and are not recorded.
            - peak_memory: the largest `memory_usage()` value seen across all
              14 steps.
        """
        max_len = 16
        n_warmup_steps = 4
        n_timed_steps = 10

        def _self_attention_fn():
            return functools.partial(tl.SelfAttention,
                                     predict_drop_len=2 * max_len,
                                     predict_mem_len=2 * max_len)

        def _causal_attention_fn():
            attn_layer, attn_kwargs = settings['attn']
            return functools.partial(attn_layer,
                                     max_inference_length=2 * max_len,
                                     **attn_kwargs)

        pred_model = models.Reformer2(
            mode='predict',
            d_model=settings['d_model'],
            d_ff=settings['d_ff'],
            dropout=0.1,
            max_len=max_len,
            n_heads=settings['n_heads'],
            n_encoder_layers=settings['encoder_layers'],
            n_decoder_layers=settings['decoder_layers'],
            encoder_attention_type=_self_attention_fn(),
            encoder_decoder_attention_type=_causal_attention_fn(),
            input_vocab_size=settings['vocab'],
            ff_sparsity=settings['ff_sparsity'],
            ff_use_sru=settings['ff_use_sru'],
            ff_dropout=0.1,
            ff_chunk_size=1024,
            attention_chunk_size=1,
            loss_sparsity=settings['loss_sparsity'],
            axial_pos_shape=None,
            use_bfloat16=True,
        )

        shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
        shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32)
        pred_model.init(input_signature=(shape1l, shape11))
        inputs = np.arange(max_len, dtype=np.int32).reshape(1, max_len)

        # This is decoding.autoregressive_sample but simplified and with timing.
        # perf_counter is monotonic, so wall-clock adjustments cannot corrupt
        # the intervals.
        result = []
        elapsed_times = []
        peak_memory = 0
        start_time = time.perf_counter()
        for index, sample in zip(
                range(-n_warmup_steps, n_timed_steps),
                decoding.autoregressive_sample_stream(pred_model,
                                                      inputs,
                                                      temperature=0.0)):
            result.append(sample[:, None])
            elapsed_time = time.perf_counter() - start_time
            if index >= 0:  # Negative indices are warm-up steps.
                elapsed_times.append(elapsed_time)
            # Probe memory outside the timed window, so its cost does not
            # inflate the per-token decoding time.
            peak_memory = max(peak_memory, memory_usage())
            start_time = time.perf_counter()

        if elapsed_times and min(elapsed_times) * 2 < max(elapsed_times):
            print(
                'WARNING! High variance found in elapsed times! Settings: {} ; '
                'elapsed times: {} ; Probably more warm-up steps should be used, '
                'or model size should be increased.'.format(
                    settings, elapsed_times))
        # Check resulting shapes.
        s = np.concatenate(result, axis=1)
        self.assertEqual(s.shape[0], 1)
        self.assertEqual(s.shape[1], n_warmup_steps + n_timed_steps)
        return size_of_model(pred_model), elapsed_times, peak_memory
Ejemplo n.º 8
0
    def _reformer2_decoding_time(self, settings):
        """Measures per-token greedy decoding time of an accelerated Reformer2.

        A warm-up decoding run on an all-zeros input is done first. The model
        state is then reset and a second, measured run is done on a different
        input. Garbage collection is disabled while this method runs.

        Args:
          settings: dict of hyperparameters. Keys read: 'attn' (a
            (layer_class, kwargs) pair used for encoder-decoder attention),
            'd_model', 'd_ff', 'n_heads', 'encoder_layers', 'decoder_layers',
            'vocab', 'ff_sparsity', 'ff_use_sru', 'attention_layers',
            'loss_sparsity'.

        Returns:
          A tuple (model_size, elapsed_times, peak_memory):
            - model_size: `int(_size_of_model(pred_model))`.
            - elapsed_times: seconds spent on each of the 12 measured decoding
              steps.
            - peak_memory: the value of `_memory_usage()` right after the
              measured run.
        """
        max_len = 16
        n_warmup_steps = 4
        n_timed_steps = 12

        def _self_attention_fn():
            return functools.partial(tl.SelfAttention,
                                     predict_drop_len=2 * max_len,
                                     predict_mem_len=2 * max_len)

        def _causal_attention_fn():
            attn_layer, attn_kwargs = settings['attn']
            return functools.partial(attn_layer,
                                     max_inference_length=2 * max_len,
                                     **attn_kwargs)

        # Garbage collection influences the timing, so we turn it off. The
        # previous collector state is restored in `finally`, so a failing
        # assertion or an exception cannot leave GC disabled for the rest of
        # the process.
        gc_was_enabled = gc.isenabled()
        gc.disable()
        try:
            pred_model = models.Reformer2(
                mode='predict',
                d_model=settings['d_model'],
                d_ff=settings['d_ff'],
                dropout=0.1,
                max_len=max_len,
                n_heads=settings['n_heads'],
                n_encoder_layers=settings['encoder_layers'],
                n_decoder_layers=settings['decoder_layers'],
                encoder_attention_type=_self_attention_fn(),
                encoder_decoder_attention_type=_causal_attention_fn(),
                input_vocab_size=settings['vocab'],
                ff_sparsity=settings['ff_sparsity'],
                ff_use_sru=settings['ff_use_sru'],
                ff_dropout=0.1,
                ff_chunk_size=1024,
                attention_chunk_size=1,
                n_decoder_attention_layers=settings['attention_layers'],
                loss_sparsity=settings['loss_sparsity'],
                pos_axial_shape=None,
                use_bfloat16=True,
            )
            # We put acceleration outside of autoregressive_sample_stream,
            # because we want to have a separate run (separate input) for model
            # compilation.
            pred_model = tl.Accelerate(pred_model)

            shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
            shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32)
            pred_model.init(input_signature=(shape1l, shape11))
            original_state = copy.deepcopy(pred_model.state)

            inputs_warmup = np.zeros((1, max_len), dtype=np.int32)
            inputs = np.arange(max_len, dtype=np.int32).reshape(1, max_len)

            # This is a warm-up run, for compilation; it is not timed.
            warmup_result = []
            for _, sample in zip(
                    range(n_warmup_steps),
                    decoding.autoregressive_sample_stream(pred_model,
                                                          inputs_warmup,
                                                          temperature=0.0,
                                                          accelerate=False)):
                # to be sure that the result is computed
                warmup_result.append(sample[:, None])

            # This is a real decoding timing run that we measure. perf_counter
            # is monotonic, so wall-clock adjustments cannot corrupt intervals.
            pred_model.state = original_state
            result, elapsed_times = [], []
            step_start = time.perf_counter()
            for _, sample in zip(
                    range(n_timed_steps),
                    decoding.autoregressive_sample_stream(pred_model,
                                                          inputs,
                                                          temperature=0.0,
                                                          accelerate=False)):
                # to be sure that the result is computed
                result.append(sample[:, None])
                step_end = time.perf_counter()
                elapsed_times.append(step_end - step_start)
                step_start = step_end
            peak_memory = _memory_usage()

            # The first two measured steps are excluded from the variance check.
            steady_times = elapsed_times[2:]
            if steady_times and min(steady_times) * 2 < max(steady_times):
                print(
                    'WARNING! High variance found in elapsed times! Settings: {} ; '
                    'elapsed times: {} ; Probably more warm-up steps should be used, '
                    'or model size should be increased.'.format(
                        settings, elapsed_times))
            # Check resulting shapes.
            s = np.concatenate(result, axis=1)
            self.assertEqual(s.shape[0], 1)
            self.assertEqual(s.shape[1], n_timed_steps)
            model_size = int(_size_of_model(pred_model))

            # We delete the model weights, because in some situations they
            # won't be deleted automatically.
            _recurrent_delete(pred_model.weights)
            return model_size, elapsed_times, peak_memory
        finally:
            if gc_was_enabled:
                gc.enable()