Example 1
def model(mode='train'):
    return models.ConfigurableTerraformer(
        mode=mode,
        d_model=16,
        d_ff=16,
        n_heads=2,
        dropout=0.05,
        n_decoder_layers=1,
        n_encoder_layers=1,
        input_vocab_size=256,
        encoder_attention_type=_mixed_lsh_self_attention_fn(),
        encoder_decoder_attention_type=_mixed_lsh_self_attention_fn(),
    )
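The helper `_mixed_lsh_self_attention_fn` is defined elsewhere in the source file and is taken as given here. A minimal smoke-test sketch, assuming the standard Trax imports and that the helper is in scope; the (inputs, targets) signature mirrors the encoder-decoder examples below:

import numpy as np
from trax import shapes

# Build the tiny model above and initialize its weights from a signature.
m = model(mode='train')
sig = (shapes.ShapeDtype((1, 8), dtype=np.int32),
       shapes.ShapeDtype((1, 8), dtype=np.int32))
m.init(input_signature=sig)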
Example 2
    def test_autoregressive_sample_terraformer_pure_lsh_attn_quality(self):
        gin.add_config_file_search_path(_CONFIG_DIR)
        max_len = 32  # 32 is the max length we trained the checkpoint for.
        test_lengths = [8, 16, 32]
        vocab_size = 13
        # The checkpoint is correct on ~90% of sequences; fix the random seed to deflake.
        np.random.seed(0)
        for test_len in test_lengths:
            gin.clear_config()
            gin.parse_config_file('terraformer_purelsh_copy.gin')
            gin.bind_parameter('PureLSHSelfAttention.predict_mem_len',
                               2 * max_len)
            gin.bind_parameter('PureLSHSelfAttention.predict_drop_len',
                               2 * max_len)
            gin.bind_parameter('PureLSHSelfAttentionWrapper.bias', False)
            gin.bind_parameter('PureLSHSelfAttentionWrapper.num_weights', 2)

            pred_model = models.ConfigurableTerraformer(mode='predict')

            shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
            shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32)

            model_path = os.path.join(_TESTDATA,
                                      'terraformer_purelsh_copy.pkl.gz')
            pred_model.init_from_file(model_path,
                                      weights_only=True,
                                      input_signature=(shape1l, shape11))
            initial_state = pred_model.state

            for _ in range(2):  # Set low to make the test run reasonably fast.
                # Pick a length in [1, test_len] at random.
                inp_len = np.random.randint(low=1, high=test_len + 1)
                inputs = np.random.randint(low=1,
                                           high=vocab_size - 1,
                                           size=(1, max_len))
                # TODO(jaszczur): properly fix padding in terraformer predict mode,
                # and add a test here.
                s = decoding.autoregressive_sample(pred_model,
                                                   inputs=inputs,
                                                   eos_id=-1,
                                                   max_length=inp_len,
                                                   temperature=0.0)

                np.testing.assert_equal(s[0], inputs[0, :inp_len])
                pred_model.state = initial_state
        gin.clear_config()  # Make sure to not affect other tests.
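In predict mode, Trax's LSH attention layers keep a rolling cache of past activations; the test sizes `predict_mem_len` and `predict_drop_len` to twice the training length so the cache never overflows while decoding up to `max_len` tokens. A minimal sketch of the same overrides expressed as one `gin.parse_config` call instead of repeated `bind_parameter` calls (equivalent, assuming the config file has already been parsed):

import gin

max_len = 32
gin.parse_config([
    f'PureLSHSelfAttention.predict_mem_len = {2 * max_len}',
    f'PureLSHSelfAttention.predict_drop_len = {2 * max_len}',
    'PureLSHSelfAttentionWrapper.bias = False',
    'PureLSHSelfAttentionWrapper.num_weights = 2',
])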
Example 3
    def test_autoregressive_sample_terraformer_pure_lsh(self):
        max_len = 128

        pred_model = models.ConfigurableTerraformer(
            mode='predict',
            d_model=256,
            d_ff=512,
            dropout=0.05,
            max_len=max_len,
            n_heads=4,
            n_encoder_layers=1,
            n_decoder_layers=1,
            ff_use_sru=1,
            d_attention_key=64,
            d_attention_value=64,
            encoder_attention_type=self._pure_lsh_self_attention_fn(
                n_chunks_after=1),
            encoder_decoder_attention_type=self._pure_lsh_self_attention_fn(),
            input_vocab_size=256,
            pos_axial_shape=None,
        )

        shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
        shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32)
        pred_model.init(input_signature=(shape1l, shape11))

        # Input follows a "0w0w" pattern: 0, a word w, then 0 and w again (copy task).
        inputs = np.array(
            [[0, 3, 7, 5, 3, 2, 4, 1, 8, 0, 3, 7, 5, 3, 2, 4, 1, 8]],
            dtype=np.int32)
        inputs = np.pad(inputs, [(0, 0), (0, max_len - inputs.shape[1])],
                        mode='constant',
                        constant_values=0)
        s = decoding.autoregressive_sample(pred_model,
                                           inputs=inputs,
                                           eos_id=-1,
                                           max_length=10,
                                           temperature=0.0)

        self.assertEqual(s.shape[0], 1)
        self.assertEqual(s.shape[1], 10)
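`self._pure_lsh_self_attention_fn` is a helper defined elsewhere in the test class. A plausible sketch of its shape, assuming it partially applies `tl.PureLSHSelfAttention` with fixed LSH hyperparameters and prediction buffers sized as in Example 2; every parameter value here is an illustrative guess, not the actual test configuration:

import functools
from trax import layers as tl

def _pure_lsh_self_attention_fn(self, n_chunks_after=0):
    # Illustrative guess: fix the LSH hyperparameters and prediction buffer
    # sizes, leaving the layer class partially applied so the model can
    # finish constructing it with its own head/dimension arguments.
    return functools.partial(
        tl.PureLSHSelfAttention,
        chunk_len=16,
        n_chunks_before=1,
        n_chunks_after=n_chunks_after,
        n_hashes=2,
        n_buckets=(32, 32),
        predict_mem_len=2 * 128,   # 2 * max_len, matching Example 2's sizing
        predict_drop_len=2 * 128,
    )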
Example 4
  def _terraformer_decoding_time(self, settings):
    # Garbage collection influences the timing, so we turn it off.
    gc.disable()
    max_len = 16

    def _self_attention_fn():
      return functools.partial(
          tl.SelfAttention,
          predict_drop_len=2 * max_len,
          predict_mem_len=2 * max_len)

    def _causal_attention_fn():
      attn_layer, attn_kwargs = settings['attn']
      return functools.partial(
          attn_layer,
          max_inference_length=2 * max_len, **attn_kwargs)

    if settings['model'] == 'terraformer':
      pred_model = models.ConfigurableTerraformer(
          mode='predict',
          d_model=settings['d_model'],
          d_ff=settings['d_ff'],
          dropout=0.1,
          max_len=max_len,
          n_heads=settings['n_heads'],
          n_encoder_layers=settings['encoder_layers'],
          n_decoder_layers=settings['decoder_layers'],
          encoder_attention_type=_self_attention_fn(),
          encoder_decoder_attention_type=_causal_attention_fn(),
          input_vocab_size=settings['vocab'],
          ff_sparsity=settings['ff_sparsity'],
          ff_use_sru=settings['ff_use_sru'],
          ff_dropout=0.1,
          # ff_chunk_size=1024,
          # attention_chunk_size=1,
          n_decoder_attention_layers=settings['attention_layers'],
          loss_sparsity=settings['loss_sparsity'],
          pos_axial_shape=None,
          use_bfloat16=True,
      )
    elif settings['model'] == 'transformer':
      pred_model = models.ConfigurableTransformer(
          mode='predict',
          d_model=settings['d_model'],
          d_ff=settings['d_ff'],
          dropout=0.1,
          max_len=max_len,
          n_heads=settings['n_heads'],
          n_encoder_layers=settings['encoder_layers'],
          n_decoder_layers=settings['decoder_layers'],
          # encoder_attention_type=_self_attention_fn(),
          encoder_decoder_attention_type=_causal_attention_fn(),
          input_vocab_size=settings['vocab'],
          ff_sparsity=settings['ff_sparsity'],
          ff_use_sru=settings['ff_use_sru'],
          # ff_dropout=0.1,
          # ff_chunk_size=1024,
          # attention_chunk_size=1,
          # n_decoder_attention_layers=settings['attention_layers'],
          loss_sparsity=settings['loss_sparsity'],
          pos_axial_shape=None,
          # enc_dec_attention_sparsity=settings['enc_dec_sparsity'],
          # use_bfloat16=True,
      )
    else:
      raise ValueError(f'Unknown model type: {settings["model"]}')
    # We put acceleration outside of autoregressive_sample_stream, because
    # we want to have a separate run (separate input) for model compilation.
    pred_model = tl.Accelerate(pred_model)

    shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
    shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32)
    pred_model.init(input_signature=(shape1l, shape11))
    original_state = copy.deepcopy(pred_model.state)

    inputs_warmup = np.zeros((1, max_len), dtype=np.int32)
    inputs = np.arange(max_len, dtype=np.int32).reshape(1, max_len)

    # This is a warm-up run, for compilation.
    result, current_time = [], time.time()
    elapsed_warmup_times = []
    for index, sample in zip(range(4), decoding.autoregressive_sample_stream(
        pred_model, inputs_warmup, temperature=0.0, accelerate=False)):
      del index  # unused
      result.append(sample[:, None])  # to be sure that the result is computed

      # Rotate the clock: the fresh reading becomes current_time and the
      # previous one becomes start_time, giving this step's elapsed time.
      current_time, start_time = time.time(), current_time
      elapsed_warmup_times.append(current_time - start_time)

    # This is a real decoding timing run that we measure.
    pred_model.state = original_state
    result, current_time = [], time.time()
    elapsed_times = []
    for index, sample in zip(range(12), decoding.autoregressive_sample_stream(
        pred_model, inputs, temperature=0.0, accelerate=False)):
      del index  # unused
      result.append(sample[:, None])  # to be sure that the result is computed

      current_time, start_time = time.time(), current_time
      elapsed_times.append(current_time - start_time)
    peak_memory = _memory_usage()

    if min(elapsed_times[2:]) * 2 < max(elapsed_times[2:]):
      print('WARNING! High variance found in elapsed times! Settings: {} ; '
            'elapsed times: {} ; Probably more warm-up steps should be used, '
            'or model size should be increased.'.format(settings,
                                                        elapsed_times))
    # Check resulting shapes.
    s = np.concatenate(result, axis=1)
    self.assertEqual(s.shape[0], 1)
    self.assertEqual(s.shape[1], 12)
    model_size = int(_size_of_model(pred_model))

    # We delete the model weights, because in some situations they won't be
    # deleted automatically.
    _recurrent_delete(pred_model.weights)
    gc.enable()
    return model_size, elapsed_times, peak_memory
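The helpers `_memory_usage`, `_size_of_model`, and `_recurrent_delete` are defined elsewhere in the benchmark file. A minimal sketch of what they might look like, assuming `psutil` for process memory and JAX-style arrays whose `.delete()` method frees device buffers; names and details are inferred from the call sites above:

import os
import psutil  # assumed dependency for process memory readings

def _memory_usage():
    # Resident set size of the current process, in bytes.
    return psutil.Process(os.getpid()).memory_info().rss

def _size_of_model(model):
    # Sum the byte sizes of all arrays in the (nested) weight tree.
    def _size(w):
        if hasattr(w, 'size') and hasattr(w, 'dtype'):
            return int(w.size) * w.dtype.itemsize
        if isinstance(w, (list, tuple)):
            return sum(_size(x) for x in w)
        return 0
    return _size(model.weights)

def _recurrent_delete(w):
    # Recursively free device buffers in a nested weight structure.
    if hasattr(w, 'delete'):
        w.delete()  # JAX arrays expose .delete() to release device memory
    elif isinstance(w, (list, tuple)):
        for x in w:
            _recurrent_delete(x)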
Example 5
def model(mode):
    return models.ConfigurableTerraformer(mode=mode)
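Here every hyperparameter is left to gin. A minimal sketch of supplying them programmatically, assuming `ConfigurableTerraformer` is gin-registered as it is in Trax's bundled config files; the binding values are illustrative:

import gin
from trax import models

gin.parse_config("""
ConfigurableTerraformer.d_model = 256
ConfigurableTerraformer.d_ff = 512
ConfigurableTerraformer.n_heads = 4
ConfigurableTerraformer.n_encoder_layers = 1
ConfigurableTerraformer.n_decoder_layers = 1
ConfigurableTerraformer.input_vocab_size = 256
""")

m = model('train')  # all other constructor arguments come from the bindings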