def main(argv):
    """Decode entry point that rebuilds hparams from command-line args.

    HACK: ``t2t_decoder.create_hparams`` is monkey-patched with a partial
    that closes over the updated argv, so the decoder constructs its
    hparams from the flags passed here instead of its own defaults.
    """
    argv = common_flags.update_argv(argv)
    # Redirect hparams construction before handing control to the decoder.
    t2t_decoder.create_hparams = functools.partial(create_hparams, argv)
    t2t_decoder.main(None)
def main(argv):
    """Train in increments of ``args.increment`` until ``counter`` reaches
    ``limit`` (or indefinitely with ``--no_limit``); otherwise decode once.

    NOTE(review): a disabled ``t2t_query`` dispatch (dead code inside a
    no-op string literal) was removed here.
    """
    global counter
    print(argv, '\n', '=' * 20)
    if not train_not_test:
        t2t_decoder.main(argv)
        return
    # Incremental training: raise train_steps and re-enter the trainer so
    # a checkpoint round happens every `args.increment` steps.
    while counter < limit or args.no_limit:
        target = counter + args.increment
        # Set both the default and the live value -- the trainer may
        # re-parse flags, so keep the two in sync.
        tf.flags.FLAGS.set_default('train_steps', target)
        tf.flags.FLAGS.train_steps = target
        print('flag:', tf.flags.FLAGS.get_flag_value('train_steps', 5),
              str(target))
        t2t_trainer.main(argv)
        counter += args.increment
        print('=' * 50, counter, limit, '=' * 50)
def decode(generate_data=True):
    """Configure T2T flags for grammar-error decoding, then run the decoder.

    Args:
        generate_data: unused; retained for backward compatibility with
            existing callers.
    """
    # Flag values for the fine-tuned grammar-correction transformer.
    flag_values = {
        'problem': "english_grammar_error",
        'model': "transformer",
        'hparams_set': "transformer_big_single_gpu",
        't2t_usr_dir': "src",
        'output_dir': "finetune_dir",
        'data_dir': "t2t_finetune",
        'decode_hparams': "beam_size=4,alpha=0.6",
        'decode_from_file': 'test_new.txt',
        'decode_to_file': 'output_new.txt',
    }
    for name, value in flag_values.items():
        setattr(FLAGS, name, value)
    t2t_decoder.main(None)
def main(argv):
    """Decoding-only entry point: delegate directly to the T2T decoder."""
    t2t_decoder.main(argv)