Example #1
 def test_autoregressive_sample_transformer_quality(self):
     pred_model = models.Transformer(d_model=64,
                                     d_ff=128,
                                     dropout=0.05,
                                     max_len=256,
                                     n_heads=2,
                                     n_encoder_layers=2,
                                     n_decoder_layers=2,
                                     input_vocab_size=13,
                                     mode='predict')
     shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
     model_path = os.path.join(_TESTDATA, 'transformer_copy.pkl.gz')
     pred_model.init_from_file(model_path,
                               weights_only=True,
                               input_signature=(shape11, shape11))
     inputs = np.array([[3, 7, 5, 3, 2, 4, 1, 8]], dtype=np.int32)
     s = decoding.autoregressive_sample(pred_model,
                                        inputs=inputs,
                                        eos_id=1,
                                        max_length=10,
                                        temperature=0.0)
     self.assertEqual(str(s[0]), '[3 7 5 3 2 4 1]')
Example #2
 def test_autoregressive_sample_reformerlm_quality(self):
     timebin_self_attention = self._timebin_self_attention_fn()
     pred_model = models.ReformerLM(d_model=64,
                                    d_ff=128,
                                    dropout=0.05,
                                    max_len=256,
                                    n_heads=2,
                                    attention_type=timebin_self_attention,
                                    n_layers=2,
                                    vocab_size=13,
                                    mode='predict')
     shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
     model_path = os.path.join(_TESTDATA, 'reformerlm_copy.pkl.gz')
     pred_model.init_from_file(model_path,
                               weights_only=True,
                               input_signature=(shape11, shape11))
     inputs = np.array([[0, 3, 7, 5, 3, 2, 4, 0]], dtype=np.int32)
     s = decoding.autoregressive_sample(pred_model,
                                        inputs,
                                        max_length=6,
                                        temperature=0.0)
     self.assertEqual(str(s[0]), '[3 7 5 3 2 4]')
Example #3
  def test_autoregressive_sample_reformer2_pure_lsh_attn_quality(self):
    gin.add_config_file_search_path(_CONFIG_DIR)
    max_len = 32  # 32 is the max length we trained the checkpoint for.
    test_lengths = [8, 16, 32]
    vocab_size = 13
    # The checkpoint is correct on ~90% of sequences; set the random seed to deflake.
    np.random.seed(0)
    for test_len in test_lengths:
      gin.clear_config()
      gin.parse_config_file('reformer2_purelsh_copy.gin')
      gin.bind_parameter('PureLSHSelfAttention.predict_mem_len', 2 * max_len)
      gin.bind_parameter('PureLSHSelfAttention.predict_drop_len', 2 * max_len)
      gin.bind_parameter('PureLSHSelfAttentionWrapper.bias', False)
      gin.bind_parameter('PureLSHSelfAttentionWrapper.num_weights', 2)

      pred_model = models.Reformer2(mode='predict')

      shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
      shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32)

      model_path = os.path.join(_TESTDATA, 'reformer2_purelsh_copy.pkl.gz')
      pred_model.init_from_file(model_path, weights_only=True,
                                input_signature=(shape1l, shape11))
      initial_state = pred_model.state

      for _ in range(2):  # Set low to make the test run reasonably fast.
        # Pick a length in [1, test_len] at random.
        inp_len = np.random.randint(low=1, high=test_len + 1)
        inputs = np.random.randint(low=1, high=vocab_size-1, size=(1, inp_len))
        inputs = np.pad(inputs, [(0, 0), (0, max_len - inp_len)],
                        mode='constant', constant_values=0)
        s = decoding.autoregressive_sample(
            pred_model, inputs=inputs, eos_id=-1, max_length=inp_len,
            temperature=0.0)

        np.testing.assert_equal(s[0], inputs[0, :inp_len])
        pred_model.state = initial_state
    gin.clear_config()  # Make sure to not affect other tests.
Example #4
  def test_autoregressive_sample_reformerlm_lsh(self):
    max_len = 32

    pred_model = models.ReformerLM(
        mode='predict',
        d_model=256,
        d_ff=512,
        dropout=0.05,
        max_len=2 * max_len,
        n_heads=4,
        n_layers=3,
        ff_use_sru=0,
        d_attention_key=64,
        d_attention_value=64,
        attention_type=functools.partial(tl.LSHSelfAttention,
                                         chunk_len=16,
                                         n_hashes=2,
                                         n_buckets=[32, 32],
                                         predict_drop_len=max_len,
                                         predict_mem_len=max_len,
                                         max_length_for_buckets=1024),
        vocab_size=13,
        pos_type='fixed-base',
        pos_d_axial_embs=None,
    )

    shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
    pred_model.init(shape11)

    # 0w0
    inputs = np.array([[0, 3, 7, 5, 3, 2, 0]], dtype=np.int32)
    inputs = np.pad(inputs, [(0, 0), (0, max_len - inputs.shape[1])],
                    mode='constant', constant_values=0)
    s = decoding.autoregressive_sample(
        pred_model, inputs=inputs, eos_id=-1, max_length=10, temperature=0.0)

    self.assertEqual(s.shape[0], 1)
    self.assertEqual(s.shape[1], 10)
Example #5
  def test_autoregressive_sample_reformer2_pure_lsh(self):
    max_len = 128

    pred_model = models.Reformer2(
        mode='predict',
        d_model=256,
        d_ff=512,
        dropout=0.05,
        max_len=max_len,
        n_heads=4,
        n_encoder_layers=1,
        n_decoder_layers=1,
        ff_use_sru=1,
        d_attention_key=64,
        d_attention_value=64,
        encoder_attention_type=self._pure_lsh_self_attention_fn(
            n_chunks_after=1),
        encoder_decoder_attention_type=self._pure_lsh_self_attention_fn(),
        input_vocab_size=256,
        pos_axial_shape=None,
    )

    shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
    shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32)
    pred_model.init(input_signature=(shape1l, shape11))

    # 0w0w
    inputs = np.array(
        [[0, 3, 7, 5, 3, 2, 4, 1, 8, 0, 3, 7, 5, 3, 2, 4, 1, 8]],
        dtype=np.int32)
    inputs = np.pad(inputs, [(0, 0), (0, max_len - inputs.shape[1])],
                    mode='constant', constant_values=0)
    s = decoding.autoregressive_sample(
        pred_model, inputs=inputs, eos_id=-1, max_length=10, temperature=0.0)

    self.assertEqual(s.shape[0], 1)
    self.assertEqual(s.shape[1], 10)
Example #6
def main(argv):
    if len(argv) > 1:
        raise absl_app.UsageError('Too many command-line arguments.')
    if not FLAGS.output_dir:
        raise absl_app.UsageError('--output_dir needs to be provided.')

    tf.compat.v1.enable_eager_execution()

    # Check that checkpoint_dir is correct: should contain model.pkl.gz file.
    model_file = os.path.join(FLAGS.checkpoint_dir, 'model.pkl.gz')
    _check_exists(model_file)

    gin.parse_config_file(os.path.join(FLAGS.checkpoint_dir, 'config.gin'))
    # Batching on our own because of possible repetitions of examples.
    gin.bind_parameter('data.Batch.batch_size', 1)
    if FLAGS.n_hashes is not None:
        gin.bind_parameter('LSHSelfAttention.n_hashes', FLAGS.n_hashes)
        gin.bind_parameter('ref2_encoder/LSHSelfAttention.n_hashes',
                           FLAGS.n_hashes)

    vocab, model, initial_state = prepare_model(model_file, FLAGS.batch_size)

    host_id, host_count = jax.host_id(), jax.host_count()
    print('Running on host %d out of %d.' % (host_id, host_count))

    example_count = 0
    start_time = time.time()

    # Creates all intermediate directories if they do not exist
    tf.io.gfile.makedirs(FLAGS.output_dir)

    json_to_write = os.path.join(FLAGS.output_dir, 'output%d.json' % host_id)
    all_jsons = []

    # In case of a restart we have to check how much work was already done.
    # We could check whether the processing of an example was finished, but
    # currently we only check whether it was started.
    done = FLAGS.starting_example
    reload_count = 0
    all_existing_files = tf.io.gfile.listdir(FLAGS.output_dir)
    for filename in all_existing_files:
        if 'processing' in filename:
            # The definition of digits looks for a number after the infix "processing"
            # in the file name. Example: tom_processing_532 will lead to
            # digits = "processing_532" and number equal to "532".
            digits = filename[filename.find('processing'):]
            number = ''.join(d for d in digits if d.isdigit())
            if is_number(number) and int(
                    number) < FLAGS.num_examples + FLAGS.starting_example:
                done = max(done, int(number))
    print('The done number is {}'.format(done))

    if FLAGS.use_eval_set:
        drop_gen = trax_data.CreateDropInputs(train=False)()
    else:
        drop_gen = trax_data.CreateDropInputs(train=True)()
    padding_fun = trax_data.PadToLength()

    # TODO(henrykm): improve management of the counters.
    # example_count_total - all numeric examples
    # example_count - all numeric examples above starting_example
    # reload_count - if we processed FLAGS.reload_after examples,
    #   then the checkpoint should be reloaded.
    # idx - total number of examples
    example_count_total = 0
    for idx, e in enumerate(drop_gen):
        if reload_count >= FLAGS.reload_after:
            vocab, model, initial_state = prepare_model(
                model_file, FLAGS.batch_size)
            reload_count = 0
        if example_count >= FLAGS.num_examples:
            print('Reached the example_count {} - breaking'.format(
                example_count))
            break
        if not is_number(e[1]):
            continue
        target_answer = float(e[1])

        # We count all numeric examples (even those before starting_example)
        example_count_total += 1
        if example_count_total <= FLAGS.starting_example:
            print('Skipping example_count_total {} because it is below {}'.
                  format(example_count_total, FLAGS.starting_example))
            continue

        if example_count % 10 == 0:
            elapsed_time = time.time() - start_time
            start_time = time.time()
            print('Starting inference on example %d, %.2fs since last log' %
                  (example_count, elapsed_time),
                  flush=True)

        example_count += 1
        reload_count += 1  # triggers a checkpoint reload every FLAGS.reload_after examples
        if example_count <= done - FLAGS.starting_example + 1:
            print('Skipping example_count {} because it is below {}'.format(
                example_count, done - FLAGS.starting_example))
            # We are increasing the example_count because the example
            # was processed before
            continue

        if example_count % host_count != host_id:
            continue

        # At this point we are committed to the processing of an example with
        # index example_count
        processing_file = os.path.join(FLAGS.output_dir, 'processing_')
        data_id = str(example_count + FLAGS.starting_example)
        with tf.io.gfile.GFile(processing_file + data_id, 'w') as w:
            w.write('Processing started.')
        for repetition_id, example in multiply_examples(e):
            question = example[0]
            question_text = question[question.find(':') + 2:]
            question_text = question_text.replace('-', ' - ')
            question = 'infer full calculation: ' + question_text

            list_num = [
                float(num.replace(',', '').rstrip('.')) for num in re.findall(
                    r'[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?',
                    question)
            ]
            for i in range(len(list_num)):
                question += ' n{} = {}'.format(i, list_num[i])

            # print('Question {}'.format(question))
            tokenized_question = next(
                padding_fun(
                    trax_data.tokenize([
                        question,
                    ],
                                       vocab_file=gin.query_parameter(
                                           'trax.data.Tokenize.vocab_file'))))
            state = model.state
            if FLAGS.use_beam_search:
                answer_beams = decoding.beam_search(
                    model,
                    tokenized_question[None, :],
                    n_beams=FLAGS.n_beams,
                    max_length=FLAGS.max_answer_len,
                    accelerate=False)
                model.state = state
            else:
                answer_beams = []
                # We recycle the n_beams flag to control the number
                # of autoregressive samples.
                for i in range(FLAGS.n_beams):
                    answer = decoding.autoregressive_sample(
                        model,
                        tokenized_question[None, :],
                        temperature=FLAGS.autoregressive_sample_temp,
                        max_length=FLAGS.max_answer_len,
                        accelerate=False)
                    model.state = state
                    answer_beams.append(answer)

            correct_example_index = -1

            for i in range(len(answer_beams)):
                if FLAGS.use_beam_search:
                    answer = trax_data.detokenize(
                        answer_beams[i][0][0],
                        vocab_file=gin.query_parameter(
                            'trax.data.Tokenize.vocab_file'))
                else:
                    answer = trax_data.detokenize(
                        answer_beams[i][0],
                        vocab_file=gin.query_parameter(
                            'trax.data.Tokenize.vocab_file'))
                print('Proposed computation {}'.format(answer))
                list_op = answer.split('|')
                if not list_op[-1]:
                    list_op = list_op[:-1]

                try:
                    result = trax_data.tf_inputs.compute_result(
                        list_op, list_num)
                    if target_answer in result:
                        correct_example_index = result.index(target_answer)
                        break
                # This is a temporary hack with "broad" exceptions - the computations
                # will sometimes fail because we evaluate arbitrary sequences; I am in
                # the process of checking what the possible failure modes are.
                except Exception as e:  # pylint: disable=broad-except
                    print(e)
                    try:
                        result = trax_data.tf_inputs.compute_result(
                            list_op[:-1], list_num)
                        if target_answer in result:
                            correct_example_index = result.index(target_answer)
                            break
                    except Exception as e:  # pylint: disable=broad-except
                        print(e)
                        print('Inferred an incorrect computation.')

            if correct_example_index == -1:
                continue

            json_record = {
                'question': question_text,
                'input': question,
                'calculation': '|'.join(list_op[:correct_example_index + 1]),
                'target_answer': target_answer
            }
            all_jsons.append(json.dumps(json_record) + '\n')
            # Outputting the inferred data in JSONL format.
            data_id = str(example_count + FLAGS.starting_example)
            with tf.io.gfile.GFile(json_to_write + data_id, 'w') as w:
                w.write(json.dumps(json_record) + '\n')
        with tf.io.gfile.GFile(processing_file + data_id, 'w') as w:
            w.write('Processing finished.')

    with tf.io.gfile.GFile(json_to_write + '_' + str(FLAGS.starting_example),
                           'w') as w:
        for record in all_jsons:
            w.write(record)
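
# The loop above relies on helpers such as `prepare_model`, `multiply_examples`
# and `is_number` that are defined elsewhere in this script. As a rough,
# hypothetical sketch (not the actual implementation), the numeric filter could
# be as simple as:
def is_number(s):
    """Returns True if `s` parses as a float, so the answer can be scored numerically."""
    try:
        float(s)
        return True
    except (TypeError, ValueError):
        return False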
Example #7
# an extensive example
inputs = 'question: What are some of the colours of a rose? context: A rose is a woody perennial flowering plant of the genus Rosa, in the family Rosaceae, or the flower it bears. There are over three hundred species and tens of thousands of cultivars. They form a group of plants that can be erect shrubs, climbing, or trailing, with stems that are often armed with sharp prickles. Flowers vary in size and shape and are usually large and showy, in colours ranging from white through yellows and reds. Most species are native to Asia, with smaller numbers native to Europe, North America, and northwestern Africa. Species, cultivars and hybrids are all widely grown for their beauty and often are fragrant.'

# In[9]:

# tokenizing the input so we can feed it to the decoder
print(tokenize(inputs))
test_inputs = tokenize(inputs)

# Run the cell below to decode.
#
# ### Note: This will take some time to run

# In[ ]:

# Temperature is a parameter for sampling.
#   * 0.0: same as argmax, always pick the most probable token
#   * 1.0: sampling from the distribution (can sometimes say random things)
#   * values in between trade off diversity and quality, try it out!
output = decoding.autoregressive_sample(
    model,
    inputs=np.array(test_inputs)[None, :],
    temperature=0.0,
    max_length=5)  # originally max_length=10
print(wrapper.fill(pretty_decode(output[0])))

# Note that the decoding quality is not very good here, because max_length was reduced from 10 to 5 so that this cell runs faster in this environment. The Colab version uses the original max_length, so check that one for the actual decoding. A sketch with the original max_length is shown in the next cell.

# In[ ]:
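# A minimal follow-up sketch, assuming the same `model`, `test_inputs`,
# `pretty_decode` and `wrapper` objects from the cells above: restore the
# original max_length of 10 and compare greedy decoding (temperature=0.0)
# with sampling from the distribution (temperature=1.0).
greedy = decoding.autoregressive_sample(
    model, inputs=np.array(test_inputs)[None, :],
    temperature=0.0, max_length=10)
sampled = decoding.autoregressive_sample(
    model, inputs=np.array(test_inputs)[None, :],
    temperature=1.0, max_length=10)
print(wrapper.fill(pretty_decode(greedy[0])))
print(wrapper.fill(pretty_decode(sampled[0])))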
Example #8
"""

# # using the 3rd example
# c4_input = inputs_targets_pairs[2][0]
# c4_target = inputs_targets_pairs[2][1]

# using the 1st example
c4_input = inputs_targets_pairs[0][0]
c4_target = inputs_targets_pairs[0][1]

print('pretty_decoded input: \n\n', pretty_decode(c4_input))
print('\npretty_decoded target: \n\n', pretty_decode(c4_target))
print('\nc4_input:\n\n', c4_input)
print('\nc4_target:\n\n', c4_target)
print(len(c4_target))
print(len(pretty_decode(c4_target)))

"""Run the cell below to decode"""

# Faster decoding (if it is still slow, consider lowering max_length to 20).
# Temperature is a parameter for sampling.
#   * 0.0: same as argmax, always pick the most probable token
#   * 1.0: sampling from the distribution (can sometimes say random things)
#   * values in between trade off diversity and quality, try it out!
output = decoding.autoregressive_sample(model, inputs=np.array(c4_input)[None, :],
                                        temperature=0.0, max_length=50)
print(wrapper.fill(pretty_decode(output[0])))

"""### Note: As you can see the RAM is almost full, it is because the model and the decoding is memory heavy. Running it the second time might give you an answer that makes no sense, or repetitive words. If that happens restart the runtime (see how to at the start of the notebook) and run all the cells again."""