def t2t_decoder(problem_name, data_dir, decode_from_file, decode_to_file,
                checkpoint_path):
  trainer_lib.set_random_seed(FLAGS.random_seed)

  hp = trainer_lib.create_hparams(
      FLAGS.hparams_set,
      FLAGS.hparams,
      data_dir=os.path.expanduser(data_dir),
      problem_name=problem_name)

  decode_hp = decoding.decode_hparams(FLAGS.decode_hparams)
  decode_hp.shards = FLAGS.decode_shards
  decode_hp.shard_id = FLAGS.worker_id
  decode_in_memory = FLAGS.decode_in_memory or decode_hp.decode_in_memory
  decode_hp.decode_in_memory = decode_in_memory
  decode_hp.decode_to_file = decode_to_file
  decode_hp.decode_reference = None

  FLAGS.checkpoint_path = checkpoint_path
  estimator = trainer_lib.create_estimator(
      FLAGS.model,
      hp,
      t2t_trainer.create_run_config(hp),
      decode_hparams=decode_hp,
      use_tpu=FLAGS.use_tpu)

  decode_from_text_file(
      estimator,
      problem_name,
      decode_from_file,
      hp,
      decode_hp,
      decode_to_file,
      checkpoint_path=checkpoint_path)
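# A hypothetical call site for t2t_decoder, assuming absl flags have already
# been parsed. Every path and flag value below is illustrative, not taken
# from this code:
#
#   FLAGS.model = "transformer"
#   FLAGS.hparams_set = "transformer_base"
#   t2t_decoder(
#       problem_name="translate_ende_wmt32k",             # hypothetical problem
#       data_dir="~/t2t_data",                            # hypothetical path
#       decode_from_file="~/inputs.txt",                  # one input per line
#       decode_to_file="~/outputs.txt",                   # decodes written here
#       checkpoint_path="~/t2t_train/model.ckpt-250000")  # hypothetical ckpt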
def main(argv):
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  log_registry()

  if FLAGS.cloud_mlengine:
    return cloud_mlengine.launch()

  if FLAGS.generate_data:
    generate_data()

  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  if argv:
    set_hparams_from_args(argv[1:])
  hparams = create_hparams()

  if is_chief():
    save_metadata(hparams)

  with maybe_cloud_tpu():
    exp_fn = create_experiment_fn()
    exp = exp_fn(create_run_config(hparams), hparams)
    execute_schedule(exp)
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

  ckpt_dir = os.path.expanduser(FLAGS.output_dir)

  hparams = create_hparams()
  hparams.no_data_parallelism = True  # To clear the devices
  problem = hparams.problem

  if FLAGS.export_as_tfhub:
    export_as_tfhub_module(hparams, problem, ckpt_dir)
    return

  run_config = t2t_trainer.create_run_config(hparams)

  estimator = create_estimator(run_config, hparams)

  exporter = tf.estimator.FinalExporter(
      "exporter", lambda: problem.serving_input_fn(hparams), as_text=True)

  export_dir = os.path.join(ckpt_dir, "export")
  exporter.export(
      estimator,
      export_dir,
      checkpoint_path=tf.train.latest_checkpoint(ckpt_dir),
      eval_result=None,
      is_the_final_export=True)
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

  if FLAGS.score_file:
    filename = os.path.expanduser(FLAGS.score_file)
    if not tf.gfile.Exists(filename):
      raise ValueError("The file to score doesn't exist: %s" % filename)
    results = score_file(filename)
    # if not FLAGS.decode_to_file:
    #   raise ValueError("To score a file, specify --decode_to_file for results.")
    # write_file = tf.gfile.Open(os.path.expanduser(FLAGS.decode_to_file), "w")
    # for sentence, score in results:
    #   write_file.write(sentence + "\t" + "SCORE:" + "%.6f\n" % score)
    # write_file.close()
    return

  hp = create_hparams()
  decode_hp = create_decode_hparams()
  run_config = t2t_trainer.create_run_config(hp)
  if FLAGS.disable_grappler_optimizations:
    run_config.session_config.graph_options.rewrite_options.disable_meta_optimizer = True

  # summary-hook in tf.estimator.EstimatorSpec requires
  # hparams.model_dir to be set.
  hp.add_hparam("model_dir", run_config.model_dir)

  estimator = trainer_lib.create_estimator(
      FLAGS.model,
      hp,
      run_config,
      decode_hparams=decode_hp,
      use_tpu=FLAGS.use_tpu)

  decode(estimator, hp, decode_hp)
def main(argv):
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()

  if FLAGS.cloud_mlengine:
    cloud_mlengine.launch()
    return

  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])

  root_output_dir = FLAGS.output_dir

  if FLAGS.teacher_dir:
    teacher_dir = FLAGS.teacher_dir
  else:
    teacher_dir = os.path.join(root_output_dir, "teacher")

  # Train Teacher ============
  if FLAGS.skip_teacher_training:
    tf.logging.info("training teacher skipped")
  else:
    hparams = t2t_trainer.create_hparams()
    hparams.distill_phase = "train"
    FLAGS.output_dir = teacher_dir

    exp_fn = t2t_trainer.create_experiment_fn()
    run_config = t2t_trainer.create_run_config(hparams)
    exp = exp_fn(run_config, hparams)
    if t2t_trainer.is_chief():
      t2t_trainer.save_metadata(hparams)
    t2t_trainer.execute_schedule(exp)
  # ==========================

  # Train Student ============
  hparams = t2t_trainer.create_hparams()
  hparams.add_hparam("teacher_dir", teacher_dir)
  hparams.distill_phase = "distill"

  if FLAGS.student_dir:
    student_dir = FLAGS.student_dir
  else:
    student_dir = os.path.join(root_output_dir, "student")
  FLAGS.output_dir = student_dir
  hparams.add_hparam("student_dir", student_dir)

  exp_fn = t2t_trainer.create_experiment_fn()
  run_config = t2t_trainer.create_run_config(hparams)
  exp = exp_fn(run_config, hparams)
  if t2t_trainer.is_chief():
    t2t_trainer.save_metadata(hparams)
  t2t_trainer.execute_schedule(exp)
  # ==========================
def create_student_experiment(run_config, hparams, argv):
  """Creates experiment function."""
  tf.logging.info("training student")
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()

  if FLAGS.cloud_mlengine:
    return cloud_mlengine.launch()

  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])

  hparams.add_hparam("teacher_dir", FLAGS.teacher_dir)
  hparams.add_hparam("student_dir", FLAGS.student_dir)
  hparams.distill_phase = "distill"
  exp_fn = t2t_trainer.create_experiment_fn()
  exp = exp_fn(run_config, hparams)
  return exp
def create_teacher_experiment(run_config, hparams, argv):
  """Creates experiment function."""
  tf.logging.info("training teacher")
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()

  if FLAGS.cloud_mlengine:
    return cloud_mlengine.launch()

  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])

  with t2t_trainer.maybe_cloud_tpu():
    hparams.distill_phase = "train"
    exp_fn = t2t_trainer.create_experiment_fn()
    exp = exp_fn(run_config, hparams)
    return exp
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

  if FLAGS.score_file:
    filename = os.path.expanduser(FLAGS.score_file)
    if not tf.gfile.Exists(filename):
      raise ValueError("The file to score doesn't exist: %s" % filename)
    results = score_file(filename)
    if not FLAGS.decode_to_file:
      raise ValueError("To score a file, specify --decode_to_file for results.")
    write_file = tf.gfile.Open(os.path.expanduser(FLAGS.decode_to_file), "w")
    for score in results:
      write_file.write("%.6f\n" % score)
    write_file.close()
    return

  hp = create_hparams()
  decode_hp = create_decode_hparams()

  estimator = trainer_lib.create_estimator(
      FLAGS.model,
      hp,
      t2t_trainer.create_run_config(hp),
      decode_hparams=decode_hp,
      use_tpu=FLAGS.use_tpu)

  decode(estimator, hp, decode_hp)
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

  if FLAGS.checkpoint_path:
    checkpoint_path = FLAGS.checkpoint_path
    ckpt_dir = os.path.dirname(checkpoint_path)
  else:
    ckpt_dir = os.path.expanduser(FLAGS.output_dir)
    checkpoint_path = tf.train.latest_checkpoint(ckpt_dir)

  hparams = create_hparams()
  hparams.no_data_parallelism = True  # To clear the devices
  problem = hparams.problem

  export_dir = FLAGS.export_dir or os.path.join(ckpt_dir, "export")

  if FLAGS.export_as_tfhub:
    checkpoint_path = tf.train.latest_checkpoint(ckpt_dir)
    decode_hparams = decoding.decode_hparams(FLAGS.decode_hparams)
    export_as_tfhub_module(FLAGS.model, hparams, decode_hparams, problem,
                           checkpoint_path, export_dir)
    return

  run_config = t2t_trainer.create_run_config(hparams)

  estimator = create_estimator(run_config, hparams)

  estimator.export_savedmodel(
      export_dir,
      lambda: problem.serving_input_fn(hparams),
      as_text=False,
      checkpoint_path=checkpoint_path)
def main(argv):
  tf.logging.set_verbosity(tf.logging.INFO)

  if FLAGS.v2:
    tf.enable_v2_behavior()
    # Hacking main v1 flags to work with v2.
    config_strs = []
    config_strs.append("train_fn.train_steps=" + str(FLAGS.train_steps))
    config_strs.append("train_fn.eval_steps=" + str(FLAGS.eval_steps))
    config_strs.append(
        "train_fn.eval_frequency=" + str(FLAGS.local_eval_frequency))
    if FLAGS.hparams:
      config_strs.extend(str(FLAGS.hparams).split(","))
    config_str = "\n".join(config_strs)
    data_dir = os.path.expanduser(FLAGS.data_dir)
    output_dir = os.path.expanduser(FLAGS.output_dir)
    t2t_v2.t2t_train(
        FLAGS.model,
        FLAGS.problem,
        data_dir=data_dir,
        output_dir=output_dir,
        config_file=FLAGS.hparams_set,
        config=config_str)
    return

  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

  # If we just have to print the registry, do that and exit early.
  maybe_log_registry_and_exit()

  # Create HParams.
  if argv:
    set_hparams_from_args(argv[1:])
  hparams = create_hparams()

  if FLAGS.schedule == "train" or FLAGS.schedule == "train_eval_and_decode":
    mlperf_log.transformer_print(key=mlperf_log.RUN_START, hparams=hparams)
  if FLAGS.schedule == "run_std_server":
    run_std_server()
  mlperf_log.transformer_print(
      key=mlperf_log.RUN_SET_RANDOM_SEED, value=FLAGS.random_seed,
      hparams=hparams)
  trainer_lib.set_random_seed(FLAGS.random_seed)

  if FLAGS.cloud_mlengine:
    cloud_mlengine.launch()
    return

  if FLAGS.generate_data:
    generate_data()

  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  exp_fn = create_experiment_fn()
  exp = exp_fn(create_run_config(hparams), hparams)
  if is_chief():
    save_metadata(hparams)
  execute_schedule(exp)
  if FLAGS.schedule != "train":
    mlperf_log.transformer_print(key=mlperf_log.RUN_FINAL, hparams=hparams)
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

  # Create hparams
  hparams = create_hparams()
  hparams.force_full_predict = True
  batch_size = hparams.batch_size

  # Iterating over dev/test partition of the data.
  # Change the data partition if necessary.
  dataset = registry.problem(FLAGS.problem).dataset(
      tf.estimator.ModeKeys.PREDICT,
      shuffle_files=False,
      hparams=hparams)
  dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))
  data = dataset.make_one_shot_iterator().get_next()
  input_data = dict((k, data[k]) for k in data.keys() if k.startswith("input"))

  # Create model
  model_cls = registry.model(FLAGS.model)
  model = model_cls(hparams, tf.estimator.ModeKeys.PREDICT)
  prediction_ops = model.infer(input_data)

  # Confusion matrices of ground-truth vs. predicted rewards.
  nr = hparams.problem.num_rewards
  cm_per_frame = np.zeros((nr, nr), dtype=np.uint64)
  cm_next_frame = np.zeros((nr, nr), dtype=np.uint64)

  saver = tf.train.Saver()
  with tf.train.SingularMonitoredSession() as sess:
    # Load latest checkpoint
    ckpt = tf.train.get_checkpoint_state(FLAGS.output_dir).model_checkpoint_path
    saver.restore(sess.raw_session(), ckpt)

    counter = 0
    while not sess.should_stop():
      counter += 1
      if counter % 1 == 0:
        print(counter)

      # Predict next frames
      rew_pd, rew_gt = sess.run(
          [prediction_ops["target_reward"], data["target_reward"]])

      for i in range(batch_size):
        cm_next_frame[rew_gt[i, 0, 0], rew_pd[i, 0, 0]] += 1
        for gt, pd in zip(rew_gt[i], rew_pd[i]):
          cm_per_frame[gt, pd] += 1

  print_confusion_matrix("Per-frame Confusion Matrix", cm_per_frame)
  print_confusion_matrix("Next-frame Confusion Matrix", cm_next_frame)
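# print_confusion_matrix is referenced above but not defined in this listing.
# A minimal sketch of what such a helper might look like: the name and
# signature come from the call sites; the body is an assumption.
def print_confusion_matrix(title, cm):
  # Print raw counts row by row; rows are ground-truth rewards,
  # columns are predicted rewards.
  print("=" * 10, title, "=" * 10)
  for row in cm:
    print("  ".join("%8d" % c for c in row))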
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  # sess_dir = FLAGS.sess_dir
  # output_dir = os.path.expanduser(
  #     sess_dir + problem_name + '-' + model + '-' + hparams)
  output_dir = FLAGS.output_dir

  if FLAGS.score_file:
    filename = os.path.expanduser(FLAGS.score_file)
    if not tf.gfile.Exists(filename):
      raise ValueError("The file to score doesn't exist: %s" % filename)
    results = score_file(filename)
    if not FLAGS.decode_to_file:
      raise ValueError("To score a file, specify --decode_to_file for results.")
    write_file = tf.gfile.Open(os.path.expanduser(FLAGS.decode_to_file), "w")
    for score in results:
      write_file.write("%.6f\n" % score)
    write_file.close()
    return

  hp = create_hparams()

  if FLAGS.global_steps:
    FLAGS.checkpoint_path = os.path.join(FLAGS.model_dir,
                                         f"model.ckpt-{FLAGS.global_steps}")
  else:
    FLAGS.checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)

  # Skip decoding if this shard's output file already exists.
  dataset_split = "test" if FLAGS.split == "test" else "dev"
  decode_path = os.path.join(FLAGS.model_dir, "decode_00000")  # default decode_to_file
  decode_path = FLAGS.decode_to_file if FLAGS.decode_to_file else decode_path
  if os.path.isdir(decode_path):
    files = os.listdir(decode_path)
    for file in files:
      file_name = file.split(".")[0]
      file_name_to_be = (
          f"{FLAGS.global_steps}{dataset_split}{FLAGS.test_shard:03d}")
      if file_name == file_name_to_be:
        print(f"{file_name_to_be} already exists")
        return

  tf.reset_default_graph()
  decode_hp = create_decode_hparams(decode_path, FLAGS.test_shard)

  estimator = trainer_lib.create_estimator(
      FLAGS.model,
      hp,
      create_run_config(hp),
      decode_hparams=decode_hp,
      use_tpu=FLAGS.use_tpu)

  decode(estimator, hp, decode_hp)
  print("shard " + str(FLAGS.test_shard) + " completed")
def main(argv):
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()

  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])
  hparams = t2t_trainer.create_hparams()
  trainer_lib.add_problem_hparams(hparams, FLAGS.problem)

  pruning_params = create_pruning_params()
  pruning_strategy = create_pruning_strategy(pruning_params.strategy)

  config = t2t_trainer.create_run_config(hparams)
  params = {"batch_size": hparams.batch_size}

  problem = registry.problem(FLAGS.problem)
  input_fn = problem.make_estimator_input_fn(tf.estimator.ModeKeys.EVAL,
                                             hparams)
  dataset = input_fn(params, config).repeat()
  features, labels = dataset.make_one_shot_iterator().get_next()

  sess = tf.Session()

  model_fn = t2t_model.T2TModel.make_estimator_model_fn(
      FLAGS.model, hparams, use_tpu=FLAGS.use_tpu)
  spec = model_fn(
      features,
      labels,
      tf.estimator.ModeKeys.EVAL,
      params=hparams,
      config=config)

  # Restore weights
  saver = tf.train.Saver()
  checkpoint_path = os.path.expanduser(FLAGS.output_dir or
                                       FLAGS.checkpoint_path)
  saver.restore(sess, tf.train.latest_checkpoint(checkpoint_path))

  def eval_model():
    preds = spec.predictions["predictions"]
    preds = tf.argmax(preds, -1, output_type=labels.dtype)
    _, acc_update_op = tf.metrics.accuracy(labels=labels, predictions=preds)
    sess.run(tf.initialize_local_variables())
    for _ in range(FLAGS.eval_steps):
      acc = sess.run(acc_update_op)
    return acc

  pruning_utils.sparsify(sess, eval_model, pruning_strategy, pruning_params)
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

  if FLAGS.score_file:
    filename = os.path.expanduser(FLAGS.score_file)
    if not tf.gfile.Exists(filename):
      raise ValueError("The file to score doesn't exist: %s" % filename)
    results = score_file(filename)
    if not FLAGS.decode_to_file:
      raise ValueError("To score a file, specify --decode_to_file for results.")
    write_file = tf.gfile.Open(os.path.expanduser(FLAGS.decode_to_file), "w")
    for score in results:
      write_file.write("%.6f\n" % score)
    write_file.close()
    return

  hp = create_hparams()
  decode_hp = create_decode_hparams()

  # eval_input_fn = hp.problem.make_estimator_input_fn(
  #     tf.estimator.ModeKeys.TRAIN, hp,
  #     dataset_kwargs={"dataset_split": "eval"})
  # print(eval_input_fn)
  # for foo in eval_input_fn(None, None):
  #   print(type(foo[0]['targets']))
  #   print(foo[0]['targets'].numpy())
  # exit()

  run_config = t2t_trainer.create_run_config(hp)
  if FLAGS.disable_grappler_optimizations:
    run_config.session_config.graph_options.rewrite_options.disable_meta_optimizer = True

  # summary-hook in tf.estimator.EstimatorSpec requires
  # hparams.model_dir to be set.
  hp.add_hparam("model_dir", run_config.model_dir)

  estimator = trainer_lib.create_estimator(
      FLAGS.model,
      hp,
      run_config,
      decode_hparams=decode_hp,
      use_tpu=FLAGS.use_tpu)

  decode(estimator, hp, decode_hp)
def main(argv):
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()

  if FLAGS.cloud_mlengine:
    cloud_mlengine.launch()
    return

  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])

  with t2t_trainer.maybe_cloud_tpu():
    root_output_dir = FLAGS.output_dir

    # Train Teacher ============
    hparams = t2t_trainer.create_hparams()
    hparams.distill_phase = "train"
    teacher_dir = os.path.join(root_output_dir, "teacher")
    FLAGS.output_dir = teacher_dir

    exp_fn = t2t_trainer.create_experiment_fn()
    run_config = t2t_trainer.create_run_config(hparams)
    exp = exp_fn(run_config, hparams)
    if t2t_trainer.is_chief():
      t2t_trainer.save_metadata(hparams)
    t2t_trainer.execute_schedule(exp)
    # ==========================

    # Train Student ============
    hparams = t2t_trainer.create_hparams()
    hparams.add_hparam("teacher_dir", teacher_dir)
    hparams.distill_phase = "distill"
    student_dir = os.path.join(root_output_dir, "student")
    FLAGS.output_dir = student_dir

    exp_fn = t2t_trainer.create_experiment_fn()
    run_config = t2t_trainer.create_run_config(hparams)
    exp = exp_fn(run_config, hparams)
    if t2t_trainer.is_chief():
      t2t_trainer.save_metadata(hparams)
    t2t_trainer.execute_schedule(exp)
    # ==========================
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

  if FLAGS.score_file:
    filename = os.path.expanduser(FLAGS.score_file)
    if not tf.gfile.Exists(filename):
      raise ValueError("The file to score doesn't exist: %s" % filename)
    results = score_file(filename)
    if not FLAGS.decode_to_file:
      raise ValueError("To score a file, specify --decode_to_file for results.")
    write_file = tf.gfile.Open(os.path.expanduser(FLAGS.decode_to_file), "w")
    for score in results:
      write_file.write("%.6f\n" % score)
    write_file.close()
    return

  hp = create_hparams()
  decode_hp = create_decode_hparams()

  estimator = trainer_lib.create_estimator(
      FLAGS.model,
      hp,
      t2t_trainer.create_run_config(hp),
      decode_hparams=decode_hp,
      use_tpu=FLAGS.use_tpu)

  decode(estimator, hp, decode_hp)

  # Post-process decodings (if necessary).
  if FLAGS.decode_to_file and FLAGS.output_line_prefix_tag:
    decode_filename_original = FLAGS.decode_to_file
    decode_filename_prefixed = "%s-%s" % (decode_filename_original,
                                          FLAGS.output_line_prefix_tag)
    tf.logging.info("Writing prefixed decodes into %s"
                    % decode_filename_prefixed)

    # Read original lines.
    with tf.gfile.Open(decode_filename_original, "r") as original_fp:
      original_lines = original_fp.readlines()

    # Write prefixed lines.
    prefix = "<%s> " % FLAGS.output_line_prefix_tag
    prefixed_fp = tf.gfile.Open(decode_filename_prefixed, "w")
    for line in original_lines:
      prefixed_fp.write(prefix + line)
    prefixed_fp.flush()
    prefixed_fp.close()
    tf.logging.info("Done.")
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

  hp = t2t_decoder.create_hparams()
  decode_hp = t2t_decoder.create_decode_hparams()

  estimator = trainer_lib.create_estimator(
      FLAGS.model,
      hp,
      t2t_trainer.create_run_config(hp),
      decode_hparams=decode_hp,
      use_tpu=FLAGS.use_tpu)

  decode(estimator, hp, decode_hp)
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  # Fathom start
  checkpoint_path = fathom_t2t_model_setup()
  # Fathom end
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

  if FLAGS.score_file:
    filename = os.path.expanduser(FLAGS.score_file)
    if not tf.gfile.Exists(filename):
      raise ValueError("The file to score doesn't exist: %s" % filename)
    results = score_file(filename)
    if not FLAGS.decode_to_file:
      raise ValueError("To score a file, specify --decode_to_file for results.")
    write_file = tf.gfile.Open(os.path.expanduser(FLAGS.decode_to_file), "w")
    for score in results:
      write_file.write("%.6f\n" % score)
    write_file.close()
    return

  hp = create_hparams()
  decode_hp = create_decode_hparams()

  estimator = trainer_lib.create_estimator(
      FLAGS.model,
      hp,
      t2t_trainer.create_run_config(hp),
      decode_hparams=decode_hp,
      use_tpu=FLAGS.use_tpu)

  decode(estimator, hp, decode_hp)

  # Fathom
  # This xcom is here so that tasks after decode know the local path to the
  # downloaded model. Train does this same xcom echo.
  # Decode, predict, and evaluate code should
  # converge to use the same fathom_t2t_model_setup.
  # TODO: since the truncation-boundary xcom value should be available in
  # the hparams_set, we should probably have consumers access this via a
  # SavedModel.hparams property rather than XCOM
  echo_yaml_for_xcom_ingest({
      'output-dir': os.path.dirname(checkpoint_path),
      'output-file': FLAGS.decode_output_file,
      'truncation-boundary': hp.max_input_seq_length
  })
def main(argv):
  # Fathom
  if FLAGS.fathom:
    fathom.t2t_trainer_setup(FLAGS.problem)

  # This exits if a checkpoint is set but not found.
  # Only takes action for an eval task.
  if FLAGS.schedule == 'evaluate':
    fathom.exit_if_no_eval_checkpoint_found(FLAGS.eval_checkpoint_path)

  tf.logging.set_verbosity(tf.logging.INFO)
  if FLAGS.schedule == "run_std_server":
    run_std_server()
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  maybe_log_registry_and_exit()

  if FLAGS.cloud_mlengine:
    # Fathom
    assert False, 'No cloudml support currently'
    cloud_mlengine.launch()
    return

  if FLAGS.generate_data:
    generate_data()

  # Fathom commented out
  # if cloud_mlengine.job_dir():
  #   FLAGS.output_dir = cloud_mlengine.job_dir()

  if argv:
    set_hparams_from_args(argv[1:])
  hparams = create_hparams()

  # Fathom
  hparams = fathom.adjust_params(hparams)

  exp_fn = create_experiment_fn()
  exp = exp_fn(create_run_config(hparams), hparams)
  if is_chief():
    save_metadata(hparams)
  execute_schedule(exp)

  # Fathom
  # NOTE: this must run LAST in the process, to make sure STDOUT is
  # appropriately populated.
  fathom.t2t_trainer_cleanup()
def main(argv):
  tf.logging.set_verbosity(tf.logging.INFO)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

  # If we just have to print the registry, do that and exit early.
  maybe_log_registry_and_exit()

  # Create HParams.
  if argv:
    set_hparams_from_args(argv[1:])
  if FLAGS.schedule != "run_std_server":
    hparams = create_hparams()
    if FLAGS.gpu_automatic_mixed_precision:
      setattr(hparams, "gpu_automatic_mixed_precision", True)

  if FLAGS.schedule == "train" or FLAGS.schedule == "train_eval_and_decode":
    mlperf_log.transformer_print(key=mlperf_log.RUN_START, hparams=hparams)
  if FLAGS.schedule == "run_std_server":
    run_std_server()
  mlperf_log.transformer_print(
      key=mlperf_log.RUN_SET_RANDOM_SEED, value=FLAGS.random_seed,
      hparams=hparams)
  trainer_lib.set_random_seed(FLAGS.random_seed)

  for flag, val in FLAGS.__flags.items():
    print(flag, ": ", val.value)

  if FLAGS.cloud_mlengine:
    cloud_mlengine.launch()
    return

  if FLAGS.generate_data:
    generate_data()

  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  exp_fn = create_experiment_fn()
  exp = exp_fn(create_run_config(hparams), hparams)
  if is_chief():
    save_metadata(hparams)
  execute_schedule(exp)
  if FLAGS.schedule != "train":
    mlperf_log.transformer_print(key=mlperf_log.RUN_FINAL, hparams=hparams)
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

  ckpt_dir = os.path.expanduser(FLAGS.output_dir)

  hparams = create_hparams()
  hparams.no_data_parallelism = True  # To clear the devices

  run_config = t2t_trainer.create_run_config(hparams)

  estimator = create_estimator(run_config, hparams)

  problem = hparams.problem
  strategy = trainer_lib.create_export_strategy(problem, hparams)

  export_dir = os.path.join(ckpt_dir, "export", strategy.name)
  strategy.export(
      estimator,
      export_dir,
      checkpoint_path=tf.train.latest_checkpoint(ckpt_dir))
def main(argv):
  tf.logging.set_verbosity(tf.logging.INFO)
  hvd.init()
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  log_registry()

  if FLAGS.cloud_mlengine:
    return cloud_mlengine.launch()

  if FLAGS.generate_data:
    generate_data()

  if hasattr(FLAGS, "job_dir") and FLAGS.job_dir:
    FLAGS.output_dir = FLAGS.job_dir

  if argv:
    set_hparams_from_args(argv[1:])

  hparams = create_hparams()

  if is_chief():
    save_metadata(hparams)

  # create_run_config calls trainer_lib.create_session_config, which
  # initializes gpu_options.
  config = create_run_config(hparams)
  decode_hparams = decoding.decode_hparams(FLAGS.decode_hparams)
  schedule = FLAGS.schedule
  estimator = create_estimator_fn(FLAGS.model, hparams, config, schedule,
                                  decode_hparams)
  # logging_hook = tf.train.LoggingTensorHook({"step": "test"}, every_n_iter=5)
  bcast_hook = hvd.BroadcastGlobalVariablesHook(0)
  estimator.train(
      input_fn=train_input_fn(hparams),
      steps=FLAGS.train_steps,
      hooks=[bcast_hook])
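# train_input_fn and create_estimator_fn are not shown in this listing.
# A plausible sketch of train_input_fn, assuming it simply wraps the
# problem's TRAIN-mode estimator input function (the actual definition in
# the surrounding code may differ):
def train_input_fn(hparams):
  # make_estimator_input_fn returns a callable usable as an Estimator
  # input_fn, so this can be passed directly to estimator.train().
  problem = registry.problem(FLAGS.problem)
  return problem.make_estimator_input_fn(tf.estimator.ModeKeys.TRAIN, hparams)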
def entry(self):
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

  print("###self defined hp###")
  print(str(FLAGS.data_dir))
  print(str(FLAGS.problem))
  print(str(FLAGS.model))
  print(str(FLAGS.hparams_set))
  print(str(FLAGS.output_dir))
  print(str(FLAGS.decode_hparams))

  hp = self.create_hparams()
  decode_hp = self.create_decode_hparams()
  estimator = self.create_new_estimator(hp, decode_hp)

  output_decode = self.my_decode(estimator, hp, decode_hp)
  print('output decode-res = %s ' % str(output_decode))
  return output_decode
def entry(input_str):
  # global estimator
  # global hp
  # global decode_hp
  # flags.FLAGS(argv, known_only=True)
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

  print("###self defined hp###")
  print(str(FLAGS.data_dir))
  print(str(FLAGS.problem))
  print(str(FLAGS.model))
  print(str(FLAGS.hparams_set))
  print(str(FLAGS.output_dir))
  print(str(FLAGS.decode_hparams))

  # if hp is None:
  #   print('hp is None !')
  #   hp = create_hparams()
  # if decode_hp is None:
  #   print('decode_hp is None !')
  #   decode_hp = create_decode_hparams()
  # if estimator is None:
  #   print('estimator is None !')
  #   estimator = my_trainer_lib.create_estimator(
  #       FLAGS.model,
  #       hp,
  #       t2t_trainer.create_run_config(hp),
  #       decode_hparams=decode_hp,
  #       use_tpu=FLAGS.use_tpu)

  hp = app.config['hp']
  decode_hp = app.config['decode_hp']
  estimator = app.config['estimator']

  output_decode = my_decode(estimator, hp, decode_hp, input_str)
  print('output-decode-res = %s ' % str(output_decode))
  return output_decode
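# The entry function above reads hp, decode_hp, and the estimator from
# app.config, so they must be built once at server startup. A sketch of that
# wiring, based on the commented-out construction code above; the init_app
# name is an assumption:
def init_app(app):
  # Build hparams and the estimator once and cache them on the Flask app,
  # so each request reuses the same graph instead of rebuilding it.
  hp = create_hparams()
  decode_hp = create_decode_hparams()
  estimator = my_trainer_lib.create_estimator(
      FLAGS.model,
      hp,
      t2t_trainer.create_run_config(hp),
      decode_hparams=decode_hp,
      use_tpu=FLAGS.use_tpu)
  app.config['hp'] = hp
  app.config['decode_hp'] = decode_hp
  app.config['estimator'] = estimator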
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

  if FLAGS.checkpoint_path:
    checkpoint_path = FLAGS.checkpoint_path
    ckpt_dir = os.path.dirname(checkpoint_path)
  else:
    ckpt_dir = os.path.expanduser(FLAGS.output_dir)
    checkpoint_path = tf.train.latest_checkpoint(ckpt_dir)

  hparams = create_hparams()
  hparams.no_data_parallelism = True  # To clear the devices
  problem = hparams.problem

  export_dir = FLAGS.export_dir or os.path.join(ckpt_dir, "export")

  if FLAGS.export_as_tfhub:
    checkpoint_path = tf.train.latest_checkpoint(ckpt_dir)
    decode_hparams = decoding.decode_hparams(FLAGS.decode_hparams)
    export_as_tfhub_module(FLAGS.model, hparams, decode_hparams, problem,
                           checkpoint_path, export_dir)
    return

  run_config = t2t_trainer.create_run_config(hparams)

  estimator = create_estimator(run_config, hparams)

  exporter = tf.estimator.FinalExporter(
      "exporter", lambda: problem.serving_input_fn(hparams), as_text=True)

  exporter.export(
      estimator,
      export_dir,
      checkpoint_path=checkpoint_path,
      eval_result=None,
      is_the_final_export=True)
def main(argv):
  set_hparams_from_args(argv[1:])
  hparams = create_hparams()
  hparams.add_hparam("data_dir", FLAGS.data_dir)
  trainer_lib.set_random_seed(FLAGS.random_seed)

  hparams_lib.add_problem_hparams(hparams, FLAGS.problem)
  problem = hparams.problem

  train_input_fn = problem.make_estimator_input_fn(
      tf.estimator.ModeKeys.TRAIN, hparams)
  eval_input_fn = problem.make_estimator_input_fn(
      tf.estimator.ModeKeys.EVAL, hparams,
      dataset_kwargs={'dataset_split': None})

  features, labels, input_hooks = (
      estimator_util._get_features_and_labels_from_input_fn(
          train_input_fn, tf.estimator.ModeKeys.TRAIN, hparams,
          create_run_config(hparams)))
  print(features)
  print(labels)
  print(input_hooks)
def create_hp_and_estimator(problem_name, data_dir, checkpoint_path):
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)

  hp = trainer_lib.create_hparams(
      FLAGS.hparams_set,
      FLAGS.hparams,
      data_dir=os.path.expanduser(data_dir),
      problem_name=problem_name)

  decode_hp = decoding.decode_hparams(FLAGS.decode_hparams)
  decode_hp.shards = FLAGS.decode_shards
  decode_hp.shard_id = FLAGS.worker_id
  decode_in_memory = FLAGS.decode_in_memory or decode_hp.decode_in_memory
  decode_hp.decode_in_memory = decode_in_memory
  decode_hp.decode_to_file = None
  decode_hp.decode_reference = None

  FLAGS.checkpoint_path = checkpoint_path
  estimator = trainer_lib.create_estimator(
      FLAGS.model,
      hp,
      t2t_trainer.create_run_config(hp),
      decode_hparams=decode_hp,
      use_tpu=FLAGS.use_tpu)
  return hp, decode_hp, estimator
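# A hypothetical use of the helper above; all argument values are
# illustrative, not taken from this code:
#
#   hp, decode_hp, estimator = create_hp_and_estimator(
#       problem_name="translate_ende_wmt32k",             # hypothetical problem
#       data_dir="~/t2t_data",                            # hypothetical data dir
#       checkpoint_path="~/t2t_train/model.ckpt-250000")  # hypothetical ckpt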
def main(argv):
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()

  if FLAGS.cloud_mlengine:
    cloud_mlengine.launch()
    return

  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])
  hparams = t2t_trainer.create_hparams()
  trainer_lib.add_problem_hparams(hparams, FLAGS.problem)

  attack_params = create_attack_params()
  attack_params.add_hparam("eps", 0.0)

  config = t2t_trainer.create_run_config(hparams)
  params = {"batch_size": hparams.batch_size}

  # add "_rev" as a hack to avoid image standardization
  problem = registry.problem(FLAGS.problem + "_rev")
  input_fn = problem.make_estimator_input_fn(
      tf.estimator.ModeKeys.EVAL, hparams)
  dataset = input_fn(params, config).repeat()
  features, _ = dataset.make_one_shot_iterator().get_next()
  inputs, labels = features["targets"], features["inputs"]
  inputs = tf.to_float(inputs)
  labels = tf.squeeze(labels)

  sess = tf.Session()

  model_fn = t2t_model.T2TModel.make_estimator_model_fn(
      FLAGS.model, hparams, use_tpu=FLAGS.use_tpu)
  ch_model = adv_attack_utils.T2TAttackModel(model_fn, params, config)

  acc_mask = None
  probs = ch_model.get_probs(inputs)
  if FLAGS.ignore_incorrect:
    preds = tf.argmax(probs, -1)
    preds = tf.squeeze(preds)
    acc_mask = tf.to_float(tf.equal(labels, preds))
  one_hot_labels = tf.one_hot(labels, probs.shape[-1])

  attack = create_attack(attack_params.attack)(ch_model, sess=sess)

  # Restore weights
  saver = tf.train.Saver()
  checkpoint_path = os.path.expanduser(FLAGS.output_dir or
                                       FLAGS.checkpoint_path)
  saver.restore(sess, tf.train.latest_checkpoint(checkpoint_path))

  # reuse variables
  tf.get_variable_scope().reuse_variables()

  def compute_accuracy(x, labels, mask):
    preds = ch_model.get_probs(x)
    preds = tf.squeeze(preds)
    preds = tf.argmax(preds, -1, output_type=labels.dtype)
    _, acc_update_op = tf.metrics.accuracy(
        labels=labels, predictions=preds, weights=mask)

    sess.run(tf.initialize_local_variables())
    for _ in range(FLAGS.eval_steps):
      acc = sess.run(acc_update_op)
    return acc

  acc = compute_accuracy(inputs, labels, acc_mask)
  epsilon_acc_pairs = [(0.0, acc)]
  for epsilon in attack_params.attack_epsilons:
    attack_params.eps = epsilon
    adv_x = attack.generate(inputs, y=one_hot_labels, **attack_params.values())
    acc = compute_accuracy(adv_x, labels, acc_mask)
    epsilon_acc_pairs.append((epsilon, acc))

  for epsilon, acc in epsilon_acc_pairs:
    tf.logging.info("Accuracy @ eps=%f: %f" % (epsilon, acc))
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

  # Create hparams
  hparams = trainer_lib.create_hparams(
      FLAGS.hparams_set,
      FLAGS.hparams,
      data_dir=os.path.expanduser(FLAGS.data_dir),
      problem_name=FLAGS.problem)
  hparams.force_full_predict = True
  hparams.scheduled_sampling_k = -1

  # Params
  num_agents = 1  # TODO(mbz): fix the code for more agents
  num_steps = FLAGS.num_steps
  if hasattr(hparams.problem, "num_actions"):
    num_actions = hparams.problem.num_actions
  else:
    num_actions = None
  frame_shape = hparams.problem.frame_shape
  resized_frame = hparams.preprocess_resize_frames is not None
  if resized_frame:
    frame_shape = hparams.preprocess_resize_frames
    frame_shape += [hparams.problem.num_channels]

  dataset = registry.problem(FLAGS.problem).dataset(
      tf.estimator.ModeKeys.TRAIN,
      shuffle_files=True,
      data_dir=os.path.expanduser(FLAGS.data_dir),
      hparams=hparams)
  dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(num_agents))
  data = dataset.make_one_shot_iterator().get_next()

  # Setup input placeholders
  input_size = [num_agents, hparams.video_num_input_frames]
  if num_actions is None:
    placeholders = {
        "inputs": tf.placeholder(tf.float32, input_size + frame_shape)
    }
  else:
    placeholders = {
        "inputs": tf.placeholder(tf.float32, input_size + frame_shape),
        "input_action": tf.placeholder(tf.int64, input_size + [1]),
        "input_reward": tf.placeholder(tf.int64, input_size + [1]),
        "reset_internal_states": tf.placeholder(tf.float32, []),
    }

  # Create model.
  model_cls = registry.model(FLAGS.model)
  model = model_cls(hparams, tf.estimator.ModeKeys.PREDICT)
  prediction_ops = model.infer(placeholders)

  states_q = Queue(maxsize=hparams.video_num_input_frames)
  actions_q = Queue(maxsize=hparams.video_num_input_frames)
  rewards_q = Queue(maxsize=hparams.video_num_input_frames)
  if num_actions is not None:
    all_qs = [states_q, actions_q, rewards_q]
  else:
    all_qs = [states_q]

  writer = common_video.WholeVideoWriter(
      fps=FLAGS.fps, output_path=FLAGS.output_gif)

  saver = tf.train.Saver(tf.trainable_variables())
  with tf.train.SingularMonitoredSession() as sess:
    # Load latest checkpoint
    ckpt = tf.train.get_checkpoint_state(FLAGS.output_dir).model_checkpoint_path
    saver.restore(sess.raw_session(), ckpt)

    # Get initial frames from the dataset.
    data_np = sess.run(data)

    frames = np.split(data_np["inputs"], hparams.video_num_input_frames, 1)
    for frame in frames:
      frame = np.squeeze(frame, 1)
      states_q.put(frame)
      writer.write(frame[0].astype(np.uint8))

    if num_actions is not None:
      actions = np.split(data_np["input_action"],
                         hparams.video_num_input_frames, 1)
      for action in actions:
        actions_q.put(np.squeeze(action, 1))

      rewards = np.split(data_np["input_reward"],
                         hparams.video_num_input_frames, 1)
      for reward in rewards:
        rewards_q.put(np.squeeze(reward, 1))

    for step in range(num_steps):
      print(">>>>>>> ", step)

      if num_actions is not None:
        random_actions = np.random.randint(num_actions - 1)
        random_actions = np.expand_dims(random_actions, 0)
        random_actions = np.tile(random_actions, (num_agents, 1))

        # Shape inputs and targets
        inputs, input_action, input_reward = (
            np.stack(list(q.queue), axis=1) for q in all_qs)
      else:
        assert len(all_qs) == 1
        q = all_qs[0]
        elems = list(q.queue)
        # Need to adjust shapes sometimes.
        for i, e in enumerate(elems):
          if len(e.shape) < 4:
            elems[i] = np.expand_dims(e, axis=0)
        inputs = np.stack(elems, axis=1)

      # Predict next frames
      if num_actions is None:
        feed = {placeholders["inputs"]: inputs}
      else:
        feed = {
            placeholders["inputs"]: inputs,
            placeholders["input_action"]: input_action,
            placeholders["input_reward"]: input_reward,
            placeholders["reset_internal_states"]: float(step == 0),
        }
      predictions = sess.run(prediction_ops, feed_dict=feed)

      if num_actions is None:
        predicted_states = predictions[:, 0]
      else:
        predicted_states = predictions["targets"][:, 0]
        predicted_reward = predictions["target_reward"][:, 0]

      # Update queues. Note: new_data must be a tuple so it zips
      # element-wise with all_qs (the original had a bare parenthesized
      # value in the action-less branch, which is not a tuple).
      if num_actions is None:
        new_data = (predicted_states,)
      else:
        new_data = (predicted_states, random_actions, predicted_reward)
      for q, d in zip(all_qs, new_data):
        q.get()
        q.put(d.copy())

      writer.write(np.round(predicted_states[0]).astype(np.uint8))

    writer.finish_to_disk()
def main(argv):
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()

  if FLAGS.cloud_mlengine:
    cloud_mlengine.launch()
    return

  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])

  if FLAGS.surrogate_attack:
    tf.logging.warn("Performing surrogate model attack.")
    sur_hparams = create_surrogate_hparams()
    trainer_lib.add_problem_hparams(sur_hparams, FLAGS.problem)

  hparams = t2t_trainer.create_hparams()
  trainer_lib.add_problem_hparams(hparams, FLAGS.problem)

  attack_params = create_attack_params()
  attack_params.add_hparam(attack_params.epsilon_name, 0.0)

  if FLAGS.surrogate_attack:
    sur_config = create_surrogate_run_config(sur_hparams)
  config = t2t_trainer.create_run_config(hparams)
  params = {
      "batch_size": hparams.batch_size,
      "use_tpu": FLAGS.use_tpu,
  }

  # add "_rev" as a hack to avoid image standardization
  problem = registry.problem(FLAGS.problem + "_rev")

  inputs, labels, features = prepare_data(problem, hparams, params, config)

  sess = tf.Session()

  if FLAGS.surrogate_attack:
    sur_model_fn = t2t_model.T2TModel.make_estimator_model_fn(
        FLAGS.surrogate_model, sur_hparams, use_tpu=FLAGS.use_tpu)
    sur_ch_model = adv_attack_utils.T2TAttackModel(
        sur_model_fn, features, params, sur_config, scope="surrogate")
    # Dummy call to construct graph
    sur_ch_model.get_probs(inputs)

    checkpoint_path = os.path.expanduser(FLAGS.surrogate_output_dir)
    tf.contrib.framework.init_from_checkpoint(
        tf.train.latest_checkpoint(checkpoint_path), {"/": "surrogate/"})
    sess.run(tf.global_variables_initializer())

  other_vars = set(tf.global_variables())

  model_fn = t2t_model.T2TModel.make_estimator_model_fn(FLAGS.model, hparams)
  ch_model = adv_attack_utils.T2TAttackModel(model_fn, features, params,
                                             config)

  acc_mask = None
  probs = ch_model.get_probs(inputs)
  if FLAGS.ignore_incorrect:
    preds = tf.argmax(probs, -1, output_type=labels.dtype)
    preds = tf.reshape(preds, labels.shape)
    acc_mask = tf.to_float(tf.equal(labels, preds))
  one_hot_labels = tf.one_hot(labels, probs.shape[-1])

  if FLAGS.surrogate_attack:
    attack = create_attack(attack_params.attack)(sur_ch_model, sess=sess)
  else:
    attack = create_attack(attack_params.attack)(ch_model, sess=sess)

  new_vars = set(tf.global_variables()) - other_vars

  # Restore weights
  saver = tf.train.Saver(new_vars)
  checkpoint_path = os.path.expanduser(FLAGS.output_dir)
  saver.restore(sess, tf.train.latest_checkpoint(checkpoint_path))

  # reuse variables
  tf.get_variable_scope().reuse_variables()

  def compute_accuracy(x, l, mask):
    """Compute model accuracy."""
    preds = ch_model.get_probs(x)
    preds = tf.squeeze(preds)
    preds = tf.argmax(preds, -1, output_type=l.dtype)
    _, acc_update_op = tf.metrics.accuracy(l, preds, weights=mask)

    if FLAGS.surrogate_attack:
      preds = sur_ch_model.get_probs(x)
      preds = tf.squeeze(preds)
      preds = tf.argmax(preds, -1, output_type=l.dtype)
      acc_update_op = tf.tuple(
          (acc_update_op, tf.metrics.accuracy(l, preds, weights=mask)[1]))

    sess.run(tf.initialize_local_variables())
    for i in range(FLAGS.eval_steps):
      tf.logging.info(
          "\tEvaluating batch [%d / %d]" % (i + 1, FLAGS.eval_steps))
      acc = sess.run(acc_update_op)
    if FLAGS.surrogate_attack:
      tf.logging.info("\tFinal acc: (%.4f, %.4f)" % (acc[0], acc[1]))
    else:
      tf.logging.info("\tFinal acc: %.4f" % acc)
    return acc

  epsilon_acc_pairs = []
  for epsilon in attack_params.attack_epsilons:
    tf.logging.info("Attacking @ eps=%.4f" % epsilon)
    attack_params.set_hparam(attack_params.epsilon_name, epsilon)
    adv_x = attack.generate(inputs, y=one_hot_labels, **attack_params.values())
    acc = compute_accuracy(adv_x, labels, acc_mask)
    epsilon_acc_pairs.append((epsilon, acc))

  for epsilon, acc in epsilon_acc_pairs:
    if FLAGS.surrogate_attack:
      tf.logging.info(
          "Accuracy @ eps=%.4f: (%.4f, %.4f)" % (epsilon, acc[0], acc[1]))
    else:
      tf.logging.info("Accuracy @ eps=%.4f: %.4f" % (epsilon, acc))