def main(argv):
  """Load a dataset split, write it to disk, and write its token vocabulary.

  Loads FLAGS.dataset / FLAGS.split through
  preprocessor.get_dataset_from_tfds and writes the result under
  FLAGS.save_path. It then builds a token vocabulary from FLAGS.save_path
  and writes it back to the same directory.

  Args:
    argv: Positional command-line arguments. Only argv[0] is accepted.

  Raises:
    app.UsageError: If any additional positional arguments are given.
  """
  if argv[1:]:
    raise app.UsageError('Too many command-line arguments.')

  preprocessor.write_dataset(
      preprocessor.get_dataset_from_tfds(FLAGS.dataset, FLAGS.split),
      FLAGS.save_path)

  vocab = preprocessor.get_token_vocab(FLAGS.save_path)
  preprocessor.write_token_vocab(vocab, FLAGS.save_path)
def _maybe_preprocess():
  """Write the dataset and token vocab under FLAGS.save_path if needed.

  The step is skipped when 'vocab.cfq.tokens' already exists in
  FLAGS.save_path. Presumably that is the file written by
  preprocessor.write_token_vocab; confirm against the preprocessor module.
  """
  if os.path.exists(os.path.join(FLAGS.save_path, 'vocab.cfq.tokens')):
    print_status('Skipping preprocessing')
    return
  print_status('Running preprocessing')
  dataset = preprocessor.get_dataset_from_tfds(FLAGS.dataset, FLAGS.split)
  preprocessor.write_dataset(dataset, FLAGS.save_path)
  token_vocab = preprocessor.get_token_vocab(FLAGS.save_path)
  preprocessor.write_token_vocab(token_vocab, FLAGS.save_path)


def _run_t2t_command(binary, t2t_usr_dir, *extra_flags):
  """Run a tensor2tensor command-line tool on T2T_PROBLEM.

  The command line is built as follows: the binary name, then the flags
  shared by every t2t invocation (--t2t_usr_dir, --data_dir, --problem),
  then `extra_flags` in the order given.

  Args:
    binary: Name of the t2t executable, e.g. 't2t-trainer'.
    t2t_usr_dir: Directory passed as --t2t_usr_dir.
    *extra_flags: Additional '--flag=value' strings for this tool.

  Raises:
    subprocess.CalledProcessError: If the tool exits with a non-zero status.
  """
  command = [
      binary,
      '--t2t_usr_dir=' + t2t_usr_dir,
      '--data_dir=' + FLAGS.save_path,
      '--problem=' + T2T_PROBLEM,
  ]
  command.extend(extra_flags)
  subprocess.run(command, check=True)


def main(argv):
  """Run the CFQ pipeline: preprocess, t2t datagen/train/decode, evaluate.

  Args:
    argv: Positional command-line arguments. Only argv[0] is accepted.

  Raises:
    app.UsageError: If any additional positional arguments are given.
    subprocess.CalledProcessError: If any of the t2t tools fails.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Pass each *_path flag through update_flag_value, keeping the original
  # order, before any of them is used below.
  for flag_name in ('save_path', 'eval_results_path', 'questions_path',
                    'golden_answers_path', 'inferred_answers_path'):
    setattr(FLAGS, flag_name, update_flag_value(getattr(FLAGS, flag_name)))

  _maybe_preprocess()

  t2t_usr_dir = os.path.join(
      os.path.dirname(os.path.realpath(__file__)), 'cfq')
  output_dir = os.path.join(FLAGS.save_path, 'output')

  print_status('Running t2t-datagen')
  # NOTE: This one skips automatically if the files exist.
  # TODO(danielfurrer): Sometimes one of the steps here will crash with a
  #                     CUBLAS_STATUS_NOT_INITIALIZED error. I suspect this is
  #                     related to the subprocess calls here. Rerunning seems to
  #                     solve the problem (perhaps because t2t-datagen returns
  #                     quickly if it was completed before).
  _run_t2t_command('t2t-datagen', t2t_usr_dir, '--tmp_dir=/tmp/cfq_tmp')

  print_status('Running t2t-trainer')
  _run_t2t_command(
      't2t-trainer', t2t_usr_dir,
      '--model=' + FLAGS.model,
      '--hparams_set=' + FLAGS.hparams_set,
      '--output_dir=' + output_dir,
      '--train_steps=%s' % FLAGS.train_steps)

  print_status('Running t2t-decoder')
  # Assumes t2t-trainer saved a checkpoint named after --train_steps.
  checkpoint_path = os.path.join(
      output_dir, 'model.ckpt-%s' % FLAGS.train_steps)
  _run_t2t_command(
      't2t-decoder', t2t_usr_dir,
      '--model=' + FLAGS.model,
      '--hparams_set=' + FLAGS.hparams_set,
      '--checkpoint_path=' + checkpoint_path,
      '--decode_from_file=' + FLAGS.questions_path,
      '--decode_to_file=' + FLAGS.inferred_answers_path,
      '--output_dir=' + output_dir)

  print_status('Calculating accuracy')
  accuracy_result = evaluator.get_accuracy_result(
      FLAGS.questions_path, FLAGS.golden_answers_path,
      FLAGS.inferred_answers_path)
  evaluator.write_accuracy_result(accuracy_result, FLAGS.eval_results_path,
                                  print_output=True)
def main(argv):
  """Preprocess CFQ, then run t2t datagen, trainer and decoder, then score.

  Args:
    argv: Positional command-line arguments. Anything beyond argv[0] is
      rejected.

  Raises:
    app.UsageError: If extra positional arguments are supplied.
    subprocess.CalledProcessError: If any t2t tool exits with an error.
  """
  if argv[1:]:
    raise app.UsageError('Too many command-line arguments.')

  # Rewrite each *_path flag via update_flag_value, in a fixed order.
  path_flags = ('save_path', 'eval_results_path', 'questions_path',
                'golden_answers_path', 'inferred_answers_path')
  for name in path_flags:
    setattr(FLAGS, name, update_flag_value(getattr(FLAGS, name)))

  vocab_file = os.path.join(FLAGS.save_path, 'vocab.cfq.tokens')
  if not os.path.exists(vocab_file):
    print_status('Running preprocessing')
    raw_dataset = preprocessor.get_dataset_from_tfds(
        FLAGS.dataset, FLAGS.split)
    preprocessor.write_dataset(raw_dataset, FLAGS.save_path)
    vocab = preprocessor.get_token_vocab(FLAGS.save_path)
    preprocessor.write_token_vocab(vocab, FLAGS.save_path)
  else:
    print_status('Skipping preprocessing')

  usr_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'cfq')
  out_dir = os.path.join(FLAGS.save_path, 'output')

  # Flags every t2t tool receives right after its binary name.
  shared = [
      '--t2t_usr_dir=' + usr_dir,
      '--data_dir=' + FLAGS.save_path,
      '--problem=' + T2T_PROBLEM,
  ]
  # Model selection flags common to training and decoding.
  model_flags = ['--model=' + FLAGS.model,
                 '--hparams_set=' + FLAGS.hparams_set]

  print_status('Running t2t-datagen')
  # NOTE: This one skips automatically if the files exist.
  subprocess.run(['t2t-datagen'] + shared + ['--tmp_dir=/tmp/cfq_tmp'],
                 check=True)

  print_status('Running t2t-trainer')
  subprocess.run(
      ['t2t-trainer'] + shared + model_flags + [
          '--output_dir=' + out_dir,
          '--train_steps=%s' % FLAGS.train_steps,
      ],
      check=True)

  print_status('Running t2t-decoder')
  ckpt = os.path.join(out_dir, 'model.ckpt-%s' % FLAGS.train_steps)
  subprocess.run(
      ['t2t-decoder'] + shared + model_flags + [
          '--checkpoint_path=' + ckpt,
          '--decode_from_file=' + FLAGS.questions_path,
          '--decode_to_file=' + FLAGS.inferred_answers_path,
          '--output_dir=' + out_dir,
      ],
      check=True)

  print_status('Calculating accuracy')
  result = evaluator.get_accuracy_result(
      FLAGS.questions_path, FLAGS.golden_answers_path,
      FLAGS.inferred_answers_path)
  evaluator.write_accuracy_result(
      result, FLAGS.eval_results_path, print_output=True)