def process_c4_with_span_corruption(spm_path=None,
                                    extra_ids=0,
                                    train=False,
                                    max_length=100,
                                    noise_density=0.15,
                                    mean_noise_span_length=3.0,
                                    seed1=None,
                                    seed2=None):
  """Builds a C4 span-corruption input pipeline for testing.

  Reads the `c4/en:2.3.0` text split from `_TESTDATA`, SentencePiece-tokenizes
  it, chunks token streams to `max_length`, applies random span-corruption
  noise, appends EOS (id 1) to both inputs and targets, pads, attaches loss
  weights masking pad id 0, and batches.

  Args:
    spm_path: Path to the SentencePiece model file.
    extra_ids: Number of sentinel ids appended to the vocabulary.
    train: Whether to read the train split of the dataset.
    max_length: Maximum chunk length; inputs are also padded to this length.
    noise_density: Fraction of tokens to corrupt.
    mean_noise_span_length: Average length of a corrupted span.
    seed1: First random seed for the noise mask.
    seed2: Second random seed for the noise mask.

  Returns:
    A `data.Serial` pipeline yielding batches of (inputs, targets, weights).
  """
  return data.Serial(
      data.TFDS(
          'c4/en:2.3.0', data_dir=_TESTDATA, keys=('text',), train=train),
      data.SentencePieceTokenize(spm_path=spm_path, extra_ids=extra_ids),
      data.generate_sequential_chunks(max_length=max_length),
      data.generate_random_noise_mask(
          noise_density=noise_density,
          mean_noise_span_length=mean_noise_span_length,
          seed1=seed1,
          seed2=seed2),
      # Base SentencePiece vocab is 32000; sentinels live above it.
      data.consume_noise_mask(vocab_size=32000 + extra_ids),
      data.FilterEmptyExamples(),
      # Append EOS (token id 1) to both inputs (0) and targets (1).
      data.AppendValue(val={0: [1], 1: [1]}),
      # Pad inputs to the chunk length (was hard-coded 100, which disagreed
      # with any non-default max_length); targets are padded to 30.
      data.PadToLength(len_map={0: max_length, 1: 30}, pad_value={0: 0, 1: 0}),
      data.AddLossWeights(id_to_mask=0),
      data.Batch(batch_size=2))
def main(argv):
  """Runs calculation inference over DROP examples and writes JSONL output.

  Loads a gin-configured model checkpoint, iterates DROP examples with numeric
  answers, asks the model to infer a full calculation for each question
  (via beam search or autoregressive sampling), evaluates the proposed
  calculations, and writes every example whose evaluated result matches the
  target answer as a JSON record. Progress is tracked through `processing_*`
  marker files in `FLAGS.output_dir` so an interrupted run can resume, and
  work is sharded across JAX hosts by `example_count % host_count`.

  Args:
    argv: Command-line arguments from absl; only the program name is allowed.

  Raises:
    absl_app.UsageError: If extra positional arguments are passed or
      `--output_dir` is missing.
  """
  if len(argv) > 1:
    raise absl_app.UsageError('Too many command-line arguments.')
  if not FLAGS.output_dir:
    raise absl_app.UsageError('--output_dir needs to be provided.')
  tf.compat.v1.enable_eager_execution()

  # Check that checkpoint_dir is correct: should contain model.pkl.gz file.
  model_file = os.path.join(FLAGS.checkpoint_dir, 'model.pkl.gz')
  _check_exists(model_file)
  gin.parse_config_file(os.path.join(FLAGS.checkpoint_dir, 'config.gin'))

  # Batching on our own because of possible repetitions of examples.
  gin.bind_parameter('data.Batch.batch_size', 1)
  if FLAGS.n_hashes is not None:
    gin.bind_parameter('LSHSelfAttention.n_hashes', FLAGS.n_hashes)
    gin.bind_parameter('ref2_encoder/LSHSelfAttention.n_hashes', FLAGS.n_hashes)

  # NOTE(review): vocab and initial_state are never read below — presumably
  # prepare_model returns a triple and only model is needed here; confirm.
  vocab, model, initial_state = prepare_model(model_file, FLAGS.batch_size)
  host_id, host_count = jax.host_id(), jax.host_count()
  print('Running on host %d out of %d.' % (host_id, host_count))

  example_count = 0
  start_time = time.time()

  # Creates all intermediate directories if they do not exist.
  tf.io.gfile.makedirs(FLAGS.output_dir)
  json_to_write = os.path.join(FLAGS.output_dir, 'output%d.json' % host_id)
  all_jsons = []

  # In a case of a reset we have to check how much work was already done.
  # We can check whether the processing of an example was finished, but
  # currently we are only checking whether it was started.
  done = FLAGS.starting_example
  reload_count = 0
  all_existing_files = tf.io.gfile.listdir(FLAGS.output_dir)
  for filename in all_existing_files:
    if 'processing' in filename:
      # The definition of digits looks for a number after the infix
      # "processing" in the file name. Example: tom_processing_532 will lead
      # to digits = "processing_532" and number equal to "532".
      digits = filename[filename.find('processing'):]
      number = ''.join(d for d in digits if d.isdigit())
      if is_number(number) and int(
          number) < FLAGS.num_examples + FLAGS.starting_example:
        done = max(done, int(number))
  print('The done number is {}'.format(done))

  if FLAGS.use_eval_set:
    drop_gen = trax_data.CreateDropInputs(train=False)()
  else:
    drop_gen = trax_data.CreateDropInputs(train=True)()
  padding_fun = trax_data.PadToLength()

  # TODO(henrykm): improve managment of the counters.
  # example_count_total - all numeric examples
  # example_count - all numeric examples above starting_example
  # reload_count - if we processed FLAGS.reload_after examples,
  # then the checkpoint should be reloaded.
  # idx - total number of exaples
  example_count_total = 0
  # NOTE(review): this is the only increment of reload_count, and it sits
  # outside the loop, so the checkpoint reload below can only ever trigger
  # when FLAGS.reload_after <= 1 — confirm whether this increment was meant
  # to be the first statement inside the loop instead.
  reload_count += 1
  for idx, e in enumerate(drop_gen):
    # Periodically rebuild the model from the checkpoint (see NOTE above).
    if reload_count >= FLAGS.reload_after:
      vocab, model, initial_state = prepare_model(model_file, FLAGS.batch_size)
      reload_count = 0
    if example_count >= FLAGS.num_examples:
      print('Reached the example_count {} - breaking'.format(example_count))
      break
    # Only examples whose answer is numeric are usable for calculation
    # inference; skip the rest.
    if not is_number(e[1]):
      continue
    target_answer = float(e[1])

    # We count numeric starting examples.
    example_count_total += 1
    if example_count_total <= FLAGS.starting_example:
      print('Skipping example_count_total {} because it is below {}'.format(
          example_count_total, FLAGS.starting_example))
      continue

    # Log timing every 10 examples.
    if example_count % 10 == 0:
      elapsed_time = time.time() - start_time
      start_time = time.time()
      print(
          'Starting inference on example %d, %.2fs since last log' %
          (example_count, elapsed_time),
          flush=True)
    example_count += 1
    if example_count <= done - FLAGS.starting_example + 1:
      print('Skipping example_count {} because it is below {}'.format(
          example_count, done - FLAGS.starting_example))
      # We are increasing the example_count because the example
      # was processed before.
      continue
    # Shard work round-robin across hosts.
    if example_count % host_count != host_id:
      continue

    # At this point we are committed to the processing of an example with
    # index example_count; drop a marker file so a restarted run knows this
    # example was already started.
    processing_file = os.path.join(FLAGS.output_dir, 'processing_')
    data_id = str(example_count + FLAGS.starting_example)
    with tf.io.gfile.GFile(processing_file + data_id, 'w') as w:
      w.write('Procesing started.')

    for repetition_id, example in multiply_examples(e):
      question = example[0]
      # Strip the prefix up to and including ': ' and space out hyphens so
      # the number regex below does not glue ranges together.
      question_text = question[question.find(':') + 2:]
      question_text = question_text.replace('-', ' - ')
      question = 'infer full calculation: ' + question_text
      # Extract every number mentioned in the question (handles commas as
      # thousands separators, decimals and exponents).
      list_num = [
          float(num.replace(',', '').rstrip('.'))
          for num in re.findall(
              r'[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?', question)
      ]
      # Append the extracted numbers as named slots n0, n1, ... to the prompt.
      for i in range(len(list_num)):
        question += ' n{} = {}'.format(i, list_num[i])
      # print('Question {}'.format(question))
      tokenized_question = next(
          padding_fun(
              trax_data.tokenize([
                  question,
              ],
                                 vocab_file=gin.query_parameter(
                                     'trax.data.Tokenize.vocab_file'))))
      # Save the model state so each decode starts from the same state.
      state = model.state
      if FLAGS.use_beam_search:
        answer_beams = decoding.beam_search(
            model,
            tokenized_question[None, :],
            n_beams=FLAGS.n_beams,
            max_length=FLAGS.max_answer_len,
            accelerate=False)
        model.state = state
      else:
        answer_beams = []
        # We recycle the n_beams flag to control the number
        # of autoregressive samples.
        for i in range(FLAGS.n_beams):
          answer = decoding.autoregressive_sample(
              model,
              tokenized_question[None, :],
              temperature=FLAGS.autoregressive_sample_temp,
              max_length=FLAGS.max_answer_len,
              accelerate=False)
          model.state = state
          answer_beams.append(answer)

      # Evaluate each proposed calculation; keep the first whose result
      # matches the target answer.
      correct_example_index = -1
      for i in range(len(answer_beams)):
        # Beam search wraps each beam in an extra tuple level.
        if FLAGS.use_beam_search:
          answer = trax_data.detokenize(
              answer_beams[i][0][0],
              vocab_file=gin.query_parameter('trax.data.Tokenize.vocab_file'))
        else:
          answer = trax_data.detokenize(
              answer_beams[i][0],
              vocab_file=gin.query_parameter('trax.data.Tokenize.vocab_file'))
        print('Proposed computation {}'.format(answer))
        list_op = answer.split('|')
        # Drop a trailing empty operation caused by a trailing '|'.
        if not list_op[-1]:
          list_op = list_op[:-1]
        # This is a temporary hack with "broad" exceptions - the computations
        # must fail sometime, because we evaluate arbitrary sequences; I am in
        # the process of checking what are possible failure modes.
        # NOTE(review): 'as e' below shadows the enumerate loop variable e;
        # harmless today because e is not read again this iteration, but it
        # deserves a rename.
        try:
          result = trax_data.tf_inputs.compute_result(list_op, list_num)
          if target_answer in result:
            correct_example_index = result.index(target_answer)
            break
        except Exception as e:  # pylint: disable=broad-except
          print(e)
          # Retry with the last (possibly malformed) operation dropped.
          try:
            result = trax_data.tf_inputs.compute_result(list_op[:-1], list_num)
            if target_answer in result:
              correct_example_index = result.index(target_answer)
              break
          except Exception as e:  # pylint: disable=broad-except
            print(e)
            print('Infered incorrect computation.')

      if correct_example_index == -1:
        continue
      json_record = {
          'question': question_text,
          'input': question,
          'calculation': '|'.join(list_op[:correct_example_index + 1]),
          'target_answer': target_answer
      }
      all_jsons.append(json.dumps(json_record) + '\n')

      # Outputting the inferred data in JSONL format, one file per example,
      # and marking the example as finished.
      data_id = str(example_count + FLAGS.starting_example)
      with tf.io.gfile.GFile(json_to_write + data_id, 'w') as w:
        w.write(json.dumps(json_record) + '\n')
      with tf.io.gfile.GFile(processing_file + data_id, 'w') as w:
        w.write('Procesing finished.')

  # Final aggregate dump of every record this host inferred.
  with tf.io.gfile.GFile(json_to_write + '_' + str(FLAGS.starting_example),
                         'w') as w:
    for record in all_jsons:
      w.write(record)