Example #1
def test_autoregressive_sample_transformerlm_quality_beam(self):
  pred_model = models.TransformerLM(
      d_model=64, d_ff=128, dropout=0.05, max_len=256, n_heads=2,
      n_layers=2, vocab_size=13, mode='predict')
  shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
  model_path = os.path.join(_TESTDATA, 'transformerlm_copy.pkl.gz')
  pred_model.init_from_file(model_path, weights_only=True,
                            input_signature=(shape11, shape11))
  inputs = np.array([[0, 3, 7, 5, 3, 2, 4, 0]], dtype=np.int32)
  s = decoding.beam_search(pred_model, inputs, n_beams=3, max_length=6)
  self.assertEqual(len(s), 3)  # 3 beams
  self.assertEqual(str(s[0][0][0]), '[3 7 5 3 2 4]')
  self.assertEqual(str(s[1][0][0]), '[7 5 3 2 4 4]')  # different from above
  self.assertEqual(str(s[2][0][0]), '[7 5 3 2 4 3]')  # different from above
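
For context, a minimal standalone sketch of the same decoding.beam_search call outside the test harness. The placeholder checkpoint path and the unpacking of the returned beams are assumptions based on the indexing in the test above:

import numpy as np
from trax import models, shapes
from trax.supervised import decoding

# Same small TransformerLM in 'predict' mode as in the test above.
model = models.TransformerLM(
    d_model=64, d_ff=128, dropout=0.05, max_len=256, n_heads=2,
    n_layers=2, vocab_size=13, mode='predict')
shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
# 'model.pkl.gz' is a placeholder for a trained checkpoint.
model.init_from_file('model.pkl.gz', weights_only=True,
                     input_signature=(shape11, shape11))

prompt = np.array([[0, 3, 7, 5, 3, 2, 4, 0]], dtype=np.int32)
beams = decoding.beam_search(model, prompt, n_beams=3, max_length=6)
for i, beam in enumerate(beams):
    # beam[0][0] holds the decoded token ids, per the indexing in the test.
    print('beam %d: %s' % (i, beam[0][0]))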
Example #2
def main(argv):
    if len(argv) > 1:
        raise absl_app.UsageError('Too many command-line arguments.')
    if not FLAGS.output_dir:
        raise absl_app.UsageError('--output_dir needs to be provided.')

    tf.compat.v1.enable_eager_execution()

    # Check that checkpoint_dir is correct: should contain model.pkl.gz file.
    model_file = os.path.join(FLAGS.checkpoint_dir, 'model.pkl.gz')
    _check_exists(model_file)

    gin.parse_config_file(os.path.join(FLAGS.checkpoint_dir, 'config.gin'))
    # We batch on our own because examples may be repeated.
    gin.bind_parameter('data.Batch.batch_size', 1)
    if FLAGS.n_hashes is not None:
        gin.bind_parameter('LSHSelfAttention.n_hashes', FLAGS.n_hashes)
        gin.bind_parameter('ref2_encoder/LSHSelfAttention.n_hashes',
                           FLAGS.n_hashes)

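    # prepare_model is defined elsewhere in this file; it loads the checkpoint
    # and returns the vocabulary, the model, and its initial decoding state.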
    vocab, model, initial_state = prepare_model(model_file, FLAGS.batch_size)

    host_id, host_count = jax.host_id(), jax.host_count()
    print('Running on host %d out of %d.' % (host_id, host_count))

    example_count = 0
    start_time = time.time()

    # Creates all intermediate directories if they do not exist
    tf.io.gfile.makedirs(FLAGS.output_dir)

    json_to_write = os.path.join(FLAGS.output_dir, 'output%d.json' % host_id)
    all_jsons = []

    # In case of a reset we have to check how much work was already done.
    # We could check whether the processing of an example was finished, but
    # currently we only check whether it was started.
    done = FLAGS.starting_example
    reload_count = 0
    all_existing_files = tf.io.gfile.listdir(FLAGS.output_dir)
    for filename in all_existing_files:
        if 'processing' in filename:
            # Extract the number that follows the infix "processing" in the
            # file name. Example: "tom_processing_532" gives
            # digits = "processing_532" and number = "532".
            digits = filename[filename.find('processing'):]
            number = ''.join(d for d in digits if d.isdigit())
            if is_number(number) and int(
                    number) < FLAGS.num_examples + FLAGS.starting_example:
                done = max(done, int(number))
    print('The done number is {}'.format(done))

    if FLAGS.use_eval_set:
        drop_gen = trax_data.CreateDropInputs(train=False)()
    else:
        drop_gen = trax_data.CreateDropInputs(train=True)()
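    # PadToLength is called with no arguments, so its padding parameters
    # presumably come from the gin config loaded above.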
    padding_fun = trax_data.PadToLength()

    # TODO(henrykm): improve management of the counters.
    # example_count_total - all numeric examples
    # example_count - all numeric examples above starting_example
    # reload_count - examples processed since the last checkpoint reload;
    #   once it reaches FLAGS.reload_after, the checkpoint is reloaded.
    # idx - total number of examples
    example_count_total = 0
    for idx, e in enumerate(drop_gen):
        reload_count += 1
        if reload_count >= FLAGS.reload_after:
            vocab, model, initial_state = prepare_model(
                model_file, FLAGS.batch_size)
            reload_count = 0
        if example_count >= FLAGS.num_examples:
            print('Reached the example_count {} - breaking'.format(
                example_count))
            break
        if not is_number(e[1]):
            continue
        target_answer = float(e[1])

        # Count every numeric example, including those before starting_example.
        example_count_total += 1
        if example_count_total <= FLAGS.starting_example:
            print('Skipping example_count_total {} because it is below {}'.
                  format(example_count_total, FLAGS.starting_example))
            continue

        if example_count % 10 == 0:
            elapsed_time = time.time() - start_time
            start_time = time.time()
            print('Starting inference on example %d, %.2fs since last log' %
                  (example_count, elapsed_time),
                  flush=True)

        example_count += 1
        if example_count <= done - FLAGS.starting_example + 1:
            print('Skipping example_count {} because it is not above {}'.format(
                example_count, done - FLAGS.starting_example + 1))
            # example_count was already incremented above, so the example
            # still counts as processed even though we skip it here.
            continue

        if example_count % host_count != host_id:
            continue

        # At this point we are committed to processing the example with
        # index example_count.
        processing_file = os.path.join(FLAGS.output_dir, 'processing_')
        data_id = str(example_count + FLAGS.starting_example)
        with tf.io.gfile.GFile(processing_file + data_id, 'w') as w:
            w.write('Processing started.')
        for repetition_id, example in multiply_examples(e):
            question = example[0]
            question_text = question[question.find(':') + 2:]
            question_text = question_text.replace('-', ' - ')
            question = 'infer full calculation: ' + question_text

            list_num = [
                float(num.replace(',', '').rstrip('.')) for num in re.findall(
                    r'[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?',
                    question)
            ]
            for i in range(len(list_num)):
                question += ' n{} = {}'.format(i, list_num[i])

            # print('Question {}'.format(question))
            tokenized_question = next(
                padding_fun(
                    trax_data.tokenize([
                        question,
                    ],
                                       vocab_file=gin.query_parameter(
                                           'trax.data.Tokenize.vocab_file'))))
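            # Autoregressive decoding mutates model.state (the decoding
            # cache), so save it here and restore it after each decode below.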
            state = model.state
            if FLAGS.use_beam_search:
                answer_beams = decoding.beam_search(
                    model,
                    tokenized_question[None, :],
                    n_beams=FLAGS.n_beams,
                    max_length=FLAGS.max_answer_len,
                    accelerate=False)
                model.state = state
            else:
                answer_beams = []
                # We recycle the n_beams flag to control the number
                # of autoregressive samples.
                for i in range(FLAGS.n_beams):
                    answer = decoding.autoregressive_sample(
                        model,
                        tokenized_question[None, :],
                        temperature=FLAGS.autoregressive_sample_temp,
                        max_length=FLAGS.max_answer_len,
                        accelerate=False)
                    model.state = state
                    answer_beams.append(answer)

            correct_example_index = -1

            for i in range(len(answer_beams)):
                if FLAGS.use_beam_search:
                    answer = trax_data.detokenize(
                        answer_beams[i][0][0],
                        vocab_file=gin.query_parameter(
                            'trax.data.Tokenize.vocab_file'))
                else:
                    answer = trax_data.detokenize(
                        answer_beams[i][0],
                        vocab_file=gin.query_parameter(
                            'trax.data.Tokenize.vocab_file'))
                print('Proposed computation {}'.format(answer))
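                # The answer encodes a calculation as '|'-separated
                # operations; drop the trailing empty segment if the
                # answer ends with '|'.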
                list_op = answer.split('|')
                if not list_op[-1]:
                    list_op = list_op[:-1]

                try:
                    result = trax_data.tf_inputs.compute_result(
                        list_op, list_num)
                    if target_answer in result:
                        correct_example_index = result.index(target_answer)
                        break
                # This is a temporary hack with "broad" exceptions - the
                # computations are bound to fail sometimes, because we
                # evaluate arbitrary sequences; possible failure modes are
                # still being investigated.
                except Exception as e:  # pylint: disable=broad-except
                    print(e)
                    try:
                        result = trax_data.tf_inputs.compute_result(
                            list_op[:-1], list_num)
                        if target_answer in result:
                            correct_example_index = result.index(target_answer)
                            break
                    except Exception as e:  # pylint: disable=broad-except
                        print(e)
                        print('Inferred incorrect computation.')

            if correct_example_index == -1:
                continue

            json_record = {
                'question': question_text,
                'input': question,
                'calculation': '|'.join(list_op[:correct_example_index + 1]),
                'target_answer': target_answer
            }
            all_jsons.append(json.dumps(json_record) + '\n')
            # Outputting the inferred data in JSONL format.
            data_id = str(example_count + FLAGS.starting_example)
            with tf.io.gfile.GFile(json_to_write + data_id, 'w') as w:
                w.write(json.dumps(json_record) + '\n')
        with tf.io.gfile.GFile(processing_file + data_id, 'w') as w:
            w.write('Processing finished.')

    with tf.io.gfile.GFile(json_to_write + '_' + str(FLAGS.starting_example),
                           'w') as w:
        for record in all_jsons:
            w.write(record)
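
The number-extraction step above is easy to exercise in isolation. A self-contained sketch, with a made-up question string for illustration:

import re

# Mirrors the preprocessing above; the question text is invented.
question = 'infer full calculation: How many points were scored - 42 or 1,305.5?'
list_num = [
    float(num.replace(',', '').rstrip('.')) for num in re.findall(
        r'[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?', question)
]
for i, num in enumerate(list_num):
    question += ' n{} = {}'.format(i, num)

print(list_num)  # [42.0, 1305.5]
print(question)  # ends with: n0 = 42.0 n1 = 1305.5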