def test_autoregressive_sample_transformerlm_quality_beam(self):
  pred_model = models.TransformerLM(
      d_model=64, d_ff=128, dropout=0.05, max_len=256, n_heads=2,
      n_layers=2, vocab_size=13, mode='predict')
  shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
  model_path = os.path.join(_TESTDATA, 'transformerlm_copy.pkl.gz')
  pred_model.init_from_file(model_path, weights_only=True,
                            input_signature=(shape11, shape11))
  inputs = np.array([[0, 3, 7, 5, 3, 2, 4, 0]], dtype=np.int32)
  s = decoding.beam_search(pred_model, inputs, n_beams=3, max_length=6)
  self.assertEqual(len(s), 3)  # 3 beams
  self.assertEqual(str(s[0][0][0]), '[3 7 5 3 2 4]')
  self.assertEqual(str(s[1][0][0]), '[7 5 3 2 4 4]')  # different from above
  self.assertEqual(str(s[2][0][0]), '[7 5 3 2 4 3]')  # different from above

def main(argv):
  if len(argv) > 1:
    raise absl_app.UsageError('Too many command-line arguments.')
  if not FLAGS.output_dir:
    raise absl_app.UsageError('--output_dir needs to be provided.')

  tf.compat.v1.enable_eager_execution()

  # Check that checkpoint_dir is correct: it should contain a model.pkl.gz file.
  model_file = os.path.join(FLAGS.checkpoint_dir, 'model.pkl.gz')
  _check_exists(model_file)

  gin.parse_config_file(os.path.join(FLAGS.checkpoint_dir, 'config.gin'))
  # We batch on our own because of possible repetitions of examples.
  gin.bind_parameter('data.Batch.batch_size', 1)
  if FLAGS.n_hashes is not None:
    gin.bind_parameter('LSHSelfAttention.n_hashes', FLAGS.n_hashes)
    gin.bind_parameter('ref2_encoder/LSHSelfAttention.n_hashes',
                       FLAGS.n_hashes)

  vocab, model, initial_state = prepare_model(model_file, FLAGS.batch_size)
  host_id, host_count = jax.host_id(), jax.host_count()
  print('Running on host %d out of %d.' % (host_id, host_count))

  example_count = 0
  start_time = time.time()

  # Create all intermediate directories if they do not exist.
  tf.io.gfile.makedirs(FLAGS.output_dir)
  json_to_write = os.path.join(FLAGS.output_dir, 'output%d.json' % host_id)
  all_jsons = []

  # In case of a restart we have to check how much work was already done.
  # We could check whether the processing of an example was finished, but
  # currently we only check whether it was started.
  done = FLAGS.starting_example
  reload_count = 0
  all_existing_files = tf.io.gfile.listdir(FLAGS.output_dir)
  for filename in all_existing_files:
    if 'processing' in filename:
      # `digits` is the part of the file name starting at the infix
      # "processing"; e.g. tom_processing_532 leads to
      # digits = "processing_532" and number equal to "532".
      digits = filename[filename.find('processing'):]
      number = ''.join(d for d in digits if d.isdigit())
      if is_number(number) and int(
          number) < FLAGS.num_examples + FLAGS.starting_example:
        done = max(done, int(number))
  print('The done number is {}'.format(done))

  if FLAGS.use_eval_set:
    drop_gen = trax_data.CreateDropInputs(train=False)()
  else:
    drop_gen = trax_data.CreateDropInputs(train=True)()
  padding_fun = trax_data.PadToLength()

  # TODO(henrykm): improve management of the counters.
  # example_count_total - all numeric examples
  # example_count - all numeric examples above starting_example
  # reload_count - if we processed FLAGS.reload_after examples,
  #   then the checkpoint should be reloaded
  # idx - total number of examples
  example_count_total = 0
  reload_count += 1
  for idx, e in enumerate(drop_gen):
    if reload_count >= FLAGS.reload_after:
      vocab, model, initial_state = prepare_model(model_file,
                                                  FLAGS.batch_size)
      reload_count = 0

    if example_count >= FLAGS.num_examples:
      print('Reached the example_count {} - breaking'.format(example_count))
      break
    if not is_number(e[1]):
      continue
    target_answer = float(e[1])

    # Count all numeric examples, including those before starting_example.
    example_count_total += 1
    if example_count_total <= FLAGS.starting_example:
      print('Skipping example_count_total {} because it is below {}'.format(
          example_count_total, FLAGS.starting_example))
      continue

    if example_count % 10 == 0:
      elapsed_time = time.time() - start_time
      start_time = time.time()
      print('Starting inference on example %d, %.2fs since last log' %
            (example_count, elapsed_time), flush=True)
    example_count += 1

    if example_count <= done - FLAGS.starting_example + 1:
      print('Skipping example_count {} because it is below {}'.format(
          example_count, done - FLAGS.starting_example))
      # We still increase example_count above because this example
      # was already processed before.
      continue

    if example_count % host_count != host_id:
      continue

    # At this point we are committed to processing the example with
    # index example_count.
    processing_file = os.path.join(FLAGS.output_dir, 'processing_')
    data_id = str(example_count + FLAGS.starting_example)
    with tf.io.gfile.GFile(processing_file + data_id, 'w') as w:
      w.write('Processing started.')

    for repetition_id, example in multiply_examples(e):
      question = example[0]
      question_text = question[question.find(':') + 2:]
      question_text = question_text.replace('-', ' - ')
      question = 'infer full calculation: ' + question_text
      list_num = [
          float(num.replace(',', '').rstrip('.')) for num in re.findall(
              r'[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?',
              question)
      ]
      for i in range(len(list_num)):
        question += ' n{} = {}'.format(i, list_num[i])

      # print('Question {}'.format(question))
      tokenized_question = next(
          padding_fun(
              trax_data.tokenize(
                  [question,],
                  vocab_file=gin.query_parameter(
                      'trax.data.Tokenize.vocab_file'))))
      state = model.state
      if FLAGS.use_beam_search:
        answer_beams = decoding.beam_search(
            model,
            tokenized_question[None, :],
            n_beams=FLAGS.n_beams,
            max_length=FLAGS.max_answer_len,
            accelerate=False)
        model.state = state
      else:
        answer_beams = []
        # We recycle the n_beams flag to control the number
        # of autoregressive samples.
        for i in range(FLAGS.n_beams):
          answer = decoding.autoregressive_sample(
              model,
              tokenized_question[None, :],
              temperature=FLAGS.autoregressive_sample_temp,
              max_length=FLAGS.max_answer_len,
              accelerate=False)
          model.state = state
          answer_beams.append(answer)

      correct_example_index = -1
      for i in range(len(answer_beams)):
        if FLAGS.use_beam_search:
          answer = trax_data.detokenize(
              answer_beams[i][0][0],
              vocab_file=gin.query_parameter('trax.data.Tokenize.vocab_file'))
        else:
          answer = trax_data.detokenize(
              answer_beams[i][0],
              vocab_file=gin.query_parameter('trax.data.Tokenize.vocab_file'))
        print('Proposed computation {}'.format(answer))
        list_op = answer.split('|')
        if not list_op[-1]:
          list_op = list_op[:-1]

        try:
          result = trax_data.tf_inputs.compute_result(list_op, list_num)
          if target_answer in result:
            correct_example_index = result.index(target_answer)
            break
        # This is a temporary hack with "broad" exceptions - the computations
        # must fail sometimes, because we evaluate arbitrary sequences; I am
        # in the process of checking what the possible failure modes are.
        except Exception as e:  # pylint: disable=broad-except
          print(e)
          try:
            result = trax_data.tf_inputs.compute_result(list_op[:-1],
                                                        list_num)
            if target_answer in result:
              correct_example_index = result.index(target_answer)
              break
          except Exception as e:  # pylint: disable=broad-except
            print(e)
            print('Inferred incorrect computation.')

      if correct_example_index == -1:
        continue

      json_record = {
          'question': question_text,
          'input': question,
          'calculation': '|'.join(list_op[:correct_example_index + 1]),
          'target_answer': target_answer
      }
      all_jsons.append(json.dumps(json_record) + '\n')

      # Output the inferred data in JSONL format.
      data_id = str(example_count + FLAGS.starting_example)
      with tf.io.gfile.GFile(json_to_write + data_id, 'w') as w:
        w.write(json.dumps(json_record) + '\n')
      with tf.io.gfile.GFile(processing_file + data_id, 'w') as w:
        w.write('Processing finished.')

  with tf.io.gfile.GFile(json_to_write + '_' + str(FLAGS.starting_example),
                         'w') as w:
    for record in all_jsons:
      w.write(record)
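

# A minimal entry-point sketch, assuming this script is meant to be run as an
# absl app (main() already raises absl_app.UsageError and reads FLAGS). The
# flags referenced in main() (--output_dir, --checkpoint_dir,
# --starting_example, etc.) are assumed to be defined elsewhere in this module
# via absl.flags.
if __name__ == '__main__':
  absl_app.run(main)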