def main(argv):
  """Run the trainer: set up, build the experiment, execute the schedule.

  MLPerf markers are printed around the run: RUN_START and RUN_FINAL for
  every schedule other than "train", and RUN_SET_RANDOM_SEED always.

  Args:
    argv: Command-line arguments. When non-empty, everything after the first
      element is forwarded to `set_hparams_from_args`.
  """
  tf.logging.set_verbosity(tf.logging.INFO)

  if FLAGS.schedule != "train":
    mlperf_log.transformer_print(key=mlperf_log.RUN_START)
  if FLAGS.schedule == "run_std_server":
    run_std_server()

  # Log the seed first, then apply it.
  seed = FLAGS.random_seed
  mlperf_log.transformer_print(
      key=mlperf_log.RUN_SET_RANDOM_SEED, value=seed)
  trainer_lib.set_random_seed(seed)

  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  maybe_log_registry_and_exit()

  # When launching on Cloud ML Engine, return right after the launch call.
  if FLAGS.cloud_mlengine:
    cloud_mlengine.launch()
    return

  if FLAGS.generate_data:
    generate_data()

  # A Cloud ML Engine job directory, when present, replaces --output_dir.
  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  if argv:
    set_hparams_from_args(argv[1:])
  hparams = create_hparams()

  build_experiment = create_experiment_fn()
  run_config = create_run_config(hparams)
  experiment = build_experiment(run_config, hparams)

  # Only the chief saves the run metadata.
  if is_chief():
    save_metadata(hparams)
  execute_schedule(experiment)

  if FLAGS.schedule != "train":
    mlperf_log.transformer_print(key=mlperf_log.RUN_FINAL)
def main(argv):
  """Build the experiment and run the configured schedule.

  Experiment creation and execution happen inside the `maybe_cloud_tpu()`
  context manager.

  Args:
    argv: Command-line arguments. When non-empty, everything after the first
      element is forwarded to `set_hparams_from_args`.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  maybe_log_registry_and_exit()

  # When launching on Cloud ML Engine, return right after the launch call.
  if FLAGS.cloud_mlengine:
    cloud_mlengine.launch()
    return

  if FLAGS.generate_data:
    generate_data()

  # A Cloud ML Engine job directory, when present, replaces --output_dir.
  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  if argv:
    set_hparams_from_args(argv[1:])
  hparams = create_hparams()

  with maybe_cloud_tpu():
    build_experiment = create_experiment_fn()
    run_config = create_run_config(hparams)
    experiment = build_experiment(run_config, hparams)
    # Only the chief saves the run metadata.
    if is_chief():
      save_metadata(hparams)
    execute_schedule(experiment)
def main(argv):
  """Run the trainer, dispatching to the v2 trainer when --v2 is set.

  Args:
    argv: Command-line arguments. When non-empty, everything after the first
      element is forwarded to `set_hparams_from_args` (v1 path only).
  """
  tf.logging.set_verbosity(tf.logging.INFO)

  if FLAGS.v2:
    tf.enable_v2_behavior()
    # Re-express the v1 training flags as config bindings for the v2 trainer.
    bindings = [
        "train_fn.train_steps=" + str(FLAGS.train_steps),
        "train_fn.eval_steps=" + str(FLAGS.eval_steps),
        "train_fn.eval_frequency=" + str(FLAGS.local_eval_frequency),
    ]
    if FLAGS.hparams:
      bindings += str(FLAGS.hparams).split(",")
    joined_bindings = "\n".join(bindings)
    expanded_data_dir = os.path.expanduser(FLAGS.data_dir)
    expanded_output_dir = os.path.expanduser(FLAGS.output_dir)
    t2t_v2.t2t_train(
        FLAGS.model,
        FLAGS.problem,
        data_dir=expanded_data_dir,
        output_dir=expanded_output_dir,
        config_file=FLAGS.hparams_set,
        config=joined_bindings)
    return

  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

  # Printing the registry, if requested, ends the program here.
  maybe_log_registry_and_exit()

  # Hyperparameters are needed early because the MLPerf calls take them.
  if argv:
    set_hparams_from_args(argv[1:])
  hparams = create_hparams()

  if FLAGS.schedule in ("train", "train_eval_and_decode"):
    mlperf_log.transformer_print(key=mlperf_log.RUN_START, hparams=hparams)
  if FLAGS.schedule == "run_std_server":
    run_std_server()

  seed = FLAGS.random_seed
  mlperf_log.transformer_print(
      key=mlperf_log.RUN_SET_RANDOM_SEED, value=seed, hparams=hparams)
  trainer_lib.set_random_seed(seed)

  # When launching on Cloud ML Engine, return right after the launch call.
  if FLAGS.cloud_mlengine:
    cloud_mlengine.launch()
    return

  if FLAGS.generate_data:
    generate_data()

  # A Cloud ML Engine job directory, when present, replaces --output_dir.
  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  build_experiment = create_experiment_fn()
  run_config = create_run_config(hparams)
  experiment = build_experiment(run_config, hparams)
  if is_chief():
    save_metadata(hparams)
  execute_schedule(experiment)

  if FLAGS.schedule != "train":
    mlperf_log.transformer_print(key=mlperf_log.RUN_FINAL, hparams=hparams)
def main(argv):
  """Train a teacher model (optionally skipped), then distill a student.

  The teacher is written to `FLAGS.teacher_dir` (default
  `<output_dir>/teacher`) and the student to `FLAGS.student_dir` (default
  `<output_dir>/student`). `FLAGS.output_dir` is repointed before each phase.

  Args:
    argv: Command-line arguments. When non-empty, everything after the first
      element is forwarded to `t2t_trainer.set_hparams_from_args`.
  """

  def _run_phase(phase_hparams):
    """Build the experiment for `phase_hparams` and execute its schedule."""
    build_experiment = t2t_trainer.create_experiment_fn()
    run_config = t2t_trainer.create_run_config(phase_hparams)
    experiment = build_experiment(run_config, phase_hparams)
    if t2t_trainer.is_chief():
      t2t_trainer.save_metadata(phase_hparams)
    t2t_trainer.execute_schedule(experiment)

  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()

  # When launching on Cloud ML Engine, return right after the launch call.
  if FLAGS.cloud_mlengine:
    cloud_mlengine.launch()
    return

  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  # A Cloud ML Engine job directory, when present, replaces --output_dir.
  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])

  # Remember the root: FLAGS.output_dir is overwritten for each phase below.
  root_output_dir = FLAGS.output_dir
  teacher_dir = FLAGS.teacher_dir or os.path.join(root_output_dir, "teacher")

  # ---- Teacher phase ----
  if not FLAGS.skip_teacher_training:
    hparams = t2t_trainer.create_hparams()
    hparams.distill_phase = "train"
    FLAGS.output_dir = teacher_dir
    _run_phase(hparams)
  else:
    tf.logging.info("training teacher skipped")

  # ---- Student phase ----
  hparams = t2t_trainer.create_hparams()
  hparams.add_hparam("teacher_dir", teacher_dir)
  hparams.distill_phase = "distill"
  student_dir = FLAGS.student_dir or os.path.join(root_output_dir, "student")
  FLAGS.output_dir = student_dir
  hparams.add_hparam("student_dir", student_dir)
  _run_phase(hparams)
def main(argv):
  """Train a teacher model and then distill a student from it.

  Both phases run inside `t2t_trainer.maybe_cloud_tpu()`. The teacher is
  written to `<output_dir>/teacher` and the student to `<output_dir>/student`;
  `FLAGS.output_dir` is repointed before each phase.

  Args:
    argv: Command-line arguments. When non-empty, everything after the first
      element is forwarded to `t2t_trainer.set_hparams_from_args`.
  """

  def _run_phase(phase_hparams):
    """Build the experiment for `phase_hparams` and execute its schedule."""
    build_experiment = t2t_trainer.create_experiment_fn()
    run_config = t2t_trainer.create_run_config(phase_hparams)
    experiment = build_experiment(run_config, phase_hparams)
    if t2t_trainer.is_chief():
      t2t_trainer.save_metadata(phase_hparams)
    t2t_trainer.execute_schedule(experiment)

  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()

  # When launching on Cloud ML Engine, return right after the launch call.
  if FLAGS.cloud_mlengine:
    cloud_mlengine.launch()
    return

  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  # A Cloud ML Engine job directory, when present, replaces --output_dir.
  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])

  with t2t_trainer.maybe_cloud_tpu():
    # Remember the root: FLAGS.output_dir is overwritten for each phase.
    root_output_dir = FLAGS.output_dir

    # ---- Teacher phase ----
    hparams = t2t_trainer.create_hparams()
    hparams.distill_phase = "train"
    teacher_dir = os.path.join(root_output_dir, "teacher")
    FLAGS.output_dir = teacher_dir
    _run_phase(hparams)

    # ---- Student phase ----
    hparams = t2t_trainer.create_hparams()
    hparams.add_hparam("teacher_dir", teacher_dir)
    hparams.distill_phase = "distill"
    student_dir = os.path.join(root_output_dir, "student")
    FLAGS.output_dir = student_dir
    _run_phase(hparams)
def create_student_experiment(run_config, hparams, argv):
  """Build the experiment for the student (distillation) phase.

  Adds `teacher_dir` and `student_dir` (taken from the flags of the same
  names) to `hparams` and sets `hparams.distill_phase` to "distill".

  Args:
    run_config: Run configuration handed to the experiment function.
    hparams: Hyperparameters object; mutated in place.
    argv: Command-line arguments. When non-empty, everything after the first
      element is forwarded to `t2t_trainer.set_hparams_from_args`.

  Returns:
    The experiment produced by the function from
    `t2t_trainer.create_experiment_fn()`, or the result of
    `cloud_mlengine.launch()` when `FLAGS.cloud_mlengine` is set.
  """
  tf.logging.info("training student")
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()

  if FLAGS.cloud_mlengine:
    return cloud_mlengine.launch()

  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  # A Cloud ML Engine job directory, when present, replaces --output_dir.
  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])

  hparams.add_hparam("teacher_dir", FLAGS.teacher_dir)
  hparams.add_hparam("student_dir", FLAGS.student_dir)
  hparams.distill_phase = "distill"

  build_experiment = t2t_trainer.create_experiment_fn()
  return build_experiment(run_config, hparams)
def create_teacher_experiment(run_config, hparams, argv):
  """Build the experiment for the teacher training phase.

  Sets `hparams.distill_phase` to "train" and builds the experiment inside
  `t2t_trainer.maybe_cloud_tpu()`.

  Args:
    run_config: Run configuration handed to the experiment function.
    hparams: Hyperparameters object; mutated in place.
    argv: Command-line arguments. When non-empty, everything after the first
      element is forwarded to `t2t_trainer.set_hparams_from_args`.

  Returns:
    The experiment produced by the function from
    `t2t_trainer.create_experiment_fn()`, or the result of
    `cloud_mlengine.launch()` when `FLAGS.cloud_mlengine` is set.
  """
  tf.logging.info("training teacher")
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()

  if FLAGS.cloud_mlengine:
    return cloud_mlengine.launch()

  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  # A Cloud ML Engine job directory, when present, replaces --output_dir.
  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])

  with t2t_trainer.maybe_cloud_tpu():
    hparams.distill_phase = "train"
    build_experiment = t2t_trainer.create_experiment_fn()
    return build_experiment(run_config, hparams)
def main(argv):
  """Run the trainer with optional GPU automatic mixed precision.

  Args:
    argv: Command-line arguments. When non-empty, everything after the first
      element is forwarded to `set_hparams_from_args`.
  """
  tf.logging.set_verbosity(tf.logging.INFO)

  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

  # If we just have to print the registry, do that and exit early.
  maybe_log_registry_and_exit()

  # Create HParams. They are not created for the "run_std_server" schedule,
  # so `hparams` starts as None and the mixed-precision switch is applied only
  # when an hparams object exists. Previously the unconditional setattr raised
  # UnboundLocalError for run_std_server + --gpu_automatic_mixed_precision.
  if argv:
    set_hparams_from_args(argv[1:])
  hparams = None
  if FLAGS.schedule != "run_std_server":
    hparams = create_hparams()
    if FLAGS.gpu_automatic_mixed_precision:
      setattr(hparams, "gpu_automatic_mixed_precision", True)

  if FLAGS.schedule in ("train", "train_eval_and_decode"):
    mlperf_log.transformer_print(key=mlperf_log.RUN_START, hparams=hparams)
  if FLAGS.schedule == "run_std_server":
    run_std_server()
  mlperf_log.transformer_print(
      key=mlperf_log.RUN_SET_RANDOM_SEED, value=FLAGS.random_seed,
      hparams=hparams)
  trainer_lib.set_random_seed(FLAGS.random_seed)

  # Dump every flag and its value to stdout.
  for flag, val in FLAGS.__flags.items():
    print(flag, ": ", val.value)

  # When launching on Cloud ML Engine, return right after the launch call.
  if FLAGS.cloud_mlengine:
    cloud_mlengine.launch()
    return

  if FLAGS.generate_data:
    generate_data()

  # A Cloud ML Engine job directory, when present, replaces --output_dir.
  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  exp_fn = create_experiment_fn()
  exp = exp_fn(create_run_config(hparams), hparams)
  if is_chief():
    save_metadata(hparams)
  execute_schedule(exp)
  if FLAGS.schedule != "train":
    mlperf_log.transformer_print(key=mlperf_log.RUN_FINAL, hparams=hparams)
def main(argv):
  """Train a model through an Estimator with Horovod initialised.

  Args:
    argv: Command-line arguments. When non-empty, everything after the first
      element is forwarded to `set_hparams_from_args`.

  Returns:
    The result of `cloud_mlengine.launch()` when `FLAGS.cloud_mlengine` is
    set; otherwise None.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  hvd.init()
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  log_registry()

  if FLAGS.cloud_mlengine:
    return cloud_mlengine.launch()

  if FLAGS.generate_data:
    generate_data()

  # A non-empty --job_dir flag, when defined, replaces --output_dir.
  if hasattr(FLAGS, "job_dir") and FLAGS.job_dir:
    FLAGS.output_dir = FLAGS.job_dir

  if argv:
    set_hparams_from_args(argv[1:])

  hparams = create_hparams()
  if is_chief():
    save_metadata(hparams)

  # create_run_config calls trainer_lib.create_session_config, which is where
  # the gpu_options are initialised.
  run_config = create_run_config(hparams)
  decode_hp = decoding.decode_hparams(FLAGS.decode_hparams)
  schedule = FLAGS.schedule
  estimator = create_estimator_fn(
      FLAGS.model, hparams, run_config, schedule, decode_hp)

  # Hook constructed with root rank 0.
  # NOTE(review): presumably broadcasts rank 0's initial variables to the
  # other workers - confirm against the Horovod documentation.
  broadcast_hook = hvd.BroadcastGlobalVariablesHook(0)
  estimator.train(
      input_fn=train_input_fn(hparams),
      steps=FLAGS.train_steps,
      hooks=[broadcast_hook])
def main(argv):
  """Evaluate model accuracy under an adversarial attack at several epsilons.

  When `FLAGS.surrogate_attack` is set, adversarial examples are generated
  against a surrogate model and accuracy is reported for both the target and
  the surrogate model as a pair.

  Args:
    argv: Command-line arguments. When non-empty, everything after the first
      element is forwarded to `t2t_trainer.set_hparams_from_args`.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()

  if FLAGS.cloud_mlengine:
    cloud_mlengine.launch()
    return

  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])

  if FLAGS.surrogate_attack:
    tf.logging.warn("Performing surrogate model attack.")
    sur_hparams = create_surrogate_hparams()
    trainer_lib.add_problem_hparams(sur_hparams, FLAGS.problem)

  hparams = t2t_trainer.create_hparams()
  trainer_lib.add_problem_hparams(hparams, FLAGS.problem)

  attack_params = create_attack_params()
  # Register the epsilon hparam up front so set_hparam can update it per step.
  attack_params.add_hparam(attack_params.epsilon_name, 0.0)

  if FLAGS.surrogate_attack:
    sur_config = create_surrogate_run_config(sur_hparams)
  config = t2t_trainer.create_run_config(hparams)
  params = {
      "batch_size": hparams.batch_size,
      "use_tpu": FLAGS.use_tpu,
  }

  # add "_rev" as a hack to avoid image standardization
  problem = registry.problem(FLAGS.problem + "_rev")

  inputs, labels, features = prepare_data(problem, hparams, params, config)

  sess = tf.Session()

  if FLAGS.surrogate_attack:
    sur_model_fn = t2t_model.T2TModel.make_estimator_model_fn(
        FLAGS.surrogate_model, sur_hparams, use_tpu=FLAGS.use_tpu)
    sur_ch_model = adv_attack_utils.T2TAttackModel(
        sur_model_fn, features, params, sur_config, scope="surrogate")
    # Dummy call to construct graph
    sur_ch_model.get_probs(inputs)

    # NOTE(review): the {"/": "surrogate/"} assignment map presumably loads
    # the whole checkpoint under the "surrogate" scope - confirm.
    checkpoint_path = os.path.expanduser(FLAGS.surrogate_output_dir)
    tf.train.init_from_checkpoint(
        tf.train.latest_checkpoint(checkpoint_path), {"/": "surrogate/"})
    sess.run(tf.global_variables_initializer())

  # Snapshot of the variables that exist before the target model is built;
  # used below to isolate the variables created afterwards.
  other_vars = set(tf.global_variables())

  model_fn = t2t_model.T2TModel.make_estimator_model_fn(
      FLAGS.model, hparams)
  ch_model = adv_attack_utils.T2TAttackModel(model_fn, features, params, config)

  acc_mask = None
  probs = ch_model.get_probs(inputs)
  if FLAGS.ignore_incorrect:
    # Weight 1.0 for examples the target model classifies correctly on the
    # clean inputs, 0.0 otherwise.
    preds = tf.argmax(probs, -1, output_type=labels.dtype)
    preds = tf.reshape(preds, labels.shape)
    acc_mask = tf.to_float(tf.equal(labels, preds))
  one_hot_labels = tf.one_hot(labels, probs.shape[-1])

  if FLAGS.surrogate_attack:
    attack = create_attack(attack_params.attack)(sur_ch_model, sess=sess)
  else:
    attack = create_attack(attack_params.attack)(ch_model, sess=sess)

  # Everything created since the snapshot.
  new_vars = set(tf.global_variables()) - other_vars

  # Restore weights
  saver = tf.train.Saver(new_vars)
  checkpoint_path = os.path.expanduser(FLAGS.output_dir)
  saver.restore(sess, tf.train.latest_checkpoint(checkpoint_path))

  # reuse variables
  tf.get_variable_scope().reuse_variables()

  def compute_accuracy(x, l, mask):
    """Compute model accuracy.

    Args:
      x: Inputs passed to the model(s) via `get_probs`.
      l: Labels tensor; its dtype is used for the argmax output.
      mask: Weights passed to `tf.metrics.accuracy`, or None.

    Returns:
      The value from the last run of the accuracy update op: a single value,
      or a pair (target, surrogate) when `FLAGS.surrogate_attack` is set.
    """
    preds = ch_model.get_probs(x)
    preds = tf.squeeze(preds)
    preds = tf.argmax(preds, -1, output_type=l.dtype)

    _, acc_update_op = tf.metrics.accuracy(l, preds, weights=mask)

    if FLAGS.surrogate_attack:
      preds = sur_ch_model.get_probs(x)
      preds = tf.squeeze(preds)
      preds = tf.argmax(preds, -1, output_type=l.dtype)
      acc_update_op = tf.tuple((acc_update_op,
                                tf.metrics.accuracy(l, preds, weights=mask)[1]))

    sess.run(tf.initialize_local_variables())
    # NOTE(review): `acc` is only bound inside this loop, so
    # FLAGS.eval_steps <= 0 would fail at the logging below.
    for i in range(FLAGS.eval_steps):
      tf.logging.info(
          "\tEvaluating batch [%d / %d]" % (i + 1, FLAGS.eval_steps))
      acc = sess.run(acc_update_op)
    if FLAGS.surrogate_attack:
      tf.logging.info("\tFinal acc: (%.4f, %.4f)" % (acc[0], acc[1]))
    else:
      tf.logging.info("\tFinal acc: %.4f" % acc)
    return acc

  epsilon_acc_pairs = []
  for epsilon in attack_params.attack_epsilons:
    tf.logging.info("Attacking @ eps=%.4f" % epsilon)
    attack_params.set_hparam(attack_params.epsilon_name, epsilon)
    adv_x = attack.generate(inputs, y=one_hot_labels, **attack_params.values())
    acc = compute_accuracy(adv_x, labels, acc_mask)
    epsilon_acc_pairs.append((epsilon, acc))

  # Summary of all epsilons after the sweep has finished.
  for epsilon, acc in epsilon_acc_pairs:
    if FLAGS.surrogate_attack:
      tf.logging.info(
          "Accuracy @ eps=%.4f: (%.4f, %.4f)" % (epsilon, acc[0], acc[1]))
    else:
      tf.logging.info("Accuracy @ eps=%.4f: %.4f" % (epsilon, acc))
def main(argv):
  """Evaluate model accuracy under an adversarial attack at several epsilons.

  When `FLAGS.surrogate_attack` is set, adversarial examples are generated
  against a surrogate model and accuracy is reported for both the target and
  the surrogate model as a pair.

  Args:
    argv: Command-line arguments. When non-empty, everything after the first
      element is forwarded to `t2t_trainer.set_hparams_from_args`.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()

  if FLAGS.cloud_mlengine:
    cloud_mlengine.launch()
    return

  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])

  if FLAGS.surrogate_attack:
    tf.logging.warn("Performing surrogate model attack.")
    sur_hparams = create_surrogate_hparams()
    trainer_lib.add_problem_hparams(sur_hparams, FLAGS.problem)

  hparams = t2t_trainer.create_hparams()
  trainer_lib.add_problem_hparams(hparams, FLAGS.problem)

  attack_params = create_attack_params()
  # Register the epsilon hparam up front so set_hparam can update it per step.
  attack_params.add_hparam(attack_params.epsilon_name, 0.0)

  if FLAGS.surrogate_attack:
    sur_config = create_surrogate_run_config(sur_hparams)
  config = t2t_trainer.create_run_config(hparams)
  params = {
      "batch_size": hparams.batch_size,
      "use_tpu": FLAGS.use_tpu,
  }

  # add "_rev" as a hack to avoid image standardization
  problem = registry.problem(FLAGS.problem + "_rev")

  inputs, labels, features = prepare_data(problem, hparams, params, config)

  sess = tf.Session()

  if FLAGS.surrogate_attack:
    sur_model_fn = t2t_model.T2TModel.make_estimator_model_fn(
        FLAGS.surrogate_model, sur_hparams, use_tpu=FLAGS.use_tpu)
    sur_ch_model = adv_attack_utils.T2TAttackModel(
        sur_model_fn, features, params, sur_config, scope="surrogate")
    # Dummy call to construct graph
    sur_ch_model.get_probs(inputs)

    # NOTE(review): the {"/": "surrogate/"} assignment map presumably loads
    # the whole checkpoint under the "surrogate" scope - confirm.
    checkpoint_path = os.path.expanduser(FLAGS.surrogate_output_dir)
    tf.contrib.framework.init_from_checkpoint(
        tf.train.latest_checkpoint(checkpoint_path), {"/": "surrogate/"})
    sess.run(tf.global_variables_initializer())

  # Snapshot of the variables that exist before the target model is built;
  # used below to isolate the variables created afterwards.
  other_vars = set(tf.global_variables())

  model_fn = t2t_model.T2TModel.make_estimator_model_fn(
      FLAGS.model, hparams)
  ch_model = adv_attack_utils.T2TAttackModel(model_fn, features, params, config)

  acc_mask = None
  probs = ch_model.get_probs(inputs)
  if FLAGS.ignore_incorrect:
    # Weight 1.0 for examples the target model classifies correctly on the
    # clean inputs, 0.0 otherwise.
    preds = tf.argmax(probs, -1, output_type=labels.dtype)
    preds = tf.reshape(preds, labels.shape)
    acc_mask = tf.to_float(tf.equal(labels, preds))
  one_hot_labels = tf.one_hot(labels, probs.shape[-1])

  if FLAGS.surrogate_attack:
    attack = create_attack(attack_params.attack)(sur_ch_model, sess=sess)
  else:
    attack = create_attack(attack_params.attack)(ch_model, sess=sess)

  # Everything created since the snapshot.
  new_vars = set(tf.global_variables()) - other_vars

  # Restore weights
  saver = tf.train.Saver(new_vars)
  checkpoint_path = os.path.expanduser(FLAGS.output_dir)
  saver.restore(sess, tf.train.latest_checkpoint(checkpoint_path))

  # reuse variables
  tf.get_variable_scope().reuse_variables()

  def compute_accuracy(x, l, mask):
    """Compute model accuracy.

    Args:
      x: Inputs passed to the model(s) via `get_probs`.
      l: Labels tensor; its dtype is used for the argmax output.
      mask: Weights passed to `tf.metrics.accuracy`, or None.

    Returns:
      The value from the last run of the accuracy update op: a single value,
      or a pair (target, surrogate) when `FLAGS.surrogate_attack` is set.
    """
    preds = ch_model.get_probs(x)
    preds = tf.squeeze(preds)
    preds = tf.argmax(preds, -1, output_type=l.dtype)

    _, acc_update_op = tf.metrics.accuracy(l, preds, weights=mask)

    if FLAGS.surrogate_attack:
      preds = sur_ch_model.get_probs(x)
      preds = tf.squeeze(preds)
      preds = tf.argmax(preds, -1, output_type=l.dtype)
      acc_update_op = tf.tuple((acc_update_op,
                                tf.metrics.accuracy(l, preds, weights=mask)[1]))

    sess.run(tf.initialize_local_variables())
    # NOTE(review): `acc` is only bound inside this loop, so
    # FLAGS.eval_steps <= 0 would fail at the logging below.
    for i in range(FLAGS.eval_steps):
      tf.logging.info(
          "\tEvaluating batch [%d / %d]" % (i + 1, FLAGS.eval_steps))
      acc = sess.run(acc_update_op)
    if FLAGS.surrogate_attack:
      tf.logging.info("\tFinal acc: (%.4f, %.4f)" % (acc[0], acc[1]))
    else:
      tf.logging.info("\tFinal acc: %.4f" % acc)
    return acc

  epsilon_acc_pairs = []
  for epsilon in attack_params.attack_epsilons:
    tf.logging.info("Attacking @ eps=%.4f" % epsilon)
    attack_params.set_hparam(attack_params.epsilon_name, epsilon)
    adv_x = attack.generate(inputs, y=one_hot_labels, **attack_params.values())
    acc = compute_accuracy(adv_x, labels, acc_mask)
    epsilon_acc_pairs.append((epsilon, acc))

  # Summary of all epsilons after the sweep has finished.
  for epsilon, acc in epsilon_acc_pairs:
    if FLAGS.surrogate_attack:
      tf.logging.info(
          "Accuracy @ eps=%.4f: (%.4f, %.4f)" % (epsilon, acc[0], acc[1]))
    else:
      tf.logging.info("Accuracy @ eps=%.4f: %.4f" % (epsilon, acc))
def main(argv):
  """Run the trainer, dispatching to trax when --jax is set.

  With `FLAGS.jax`, the v1 flags are translated into gin bindings, trax
  training is run, and the function returns. Otherwise the regular
  experiment-based training path is executed.

  Args:
    argv: Command-line arguments. When non-empty, everything after the first
      element is forwarded to `set_hparams_from_args` (non-jax path only).
  """
  tf.logging.set_verbosity(tf.logging.INFO)

  if FLAGS.jax:
    # Setup trax FLAGS
    dataset = FLAGS.problem
    model = FLAGS.model
    data_dir = FLAGS.data_dir
    output_dir = FLAGS.output_dir
    config_file = [FLAGS.hparams_set]
    config = [
        "train.train_steps=%d" % FLAGS.train_steps,
        "train.eval_steps=%d" % FLAGS.eval_steps,
        "train.eval_frequency=%d" % FLAGS.local_eval_frequency,
    ]
    # Only forward --hparams when it is set, so no empty (or "None") binding
    # reaches gin.
    if FLAGS.hparams:
      config.extend(str(FLAGS.hparams).split(","))

    # Copied _setup_gin exactly from trax/trainer.py and removed "FLAGS."
    def _setup_gin():
      """Setup gin configuration."""
      # Imports for configurables
      # pylint: disable=g-import-not-at-top,unused-import,g-bad-import-order,reimported,unused-variable
      from tensor2tensor.trax import inputs as _trax_inputs
      from tensor2tensor.trax import models as _trax_models
      from tensor2tensor.trax import optimizers as _trax_opt
      # pylint: enable=g-import-not-at-top,unused-import,g-bad-import-order,reimported,unused-variable

      configs = config or []
      # Override with --dataset and --model
      if dataset:
        configs.append("inputs.dataset_name='%s'" % dataset)
        configs.append("inputs.data_dir='%s'" % data_dir)
        # Bind the trax input pipeline as the `inputs` argument of train().
        configs.append("train.inputs=@trax.inputs.inputs")
      if model:
        # Bind the requested trax model as the `model` argument of train().
        configs.append("train.model=@trax.models.%s" % model)
      gin.parse_config_files_and_bindings(config_file, configs)

    _setup_gin()
    trax.train(output_dir=output_dir)
    # The trax run is complete; do not fall through into the TF path below.
    return

  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

  # If we just have to print the registry, do that and exit early.
  maybe_log_registry_and_exit()

  # Create HParams.
  if argv:
    set_hparams_from_args(argv[1:])
  hparams = create_hparams()

  if FLAGS.schedule in ("train", "train_eval_and_decode"):
    mlperf_log.transformer_print(key=mlperf_log.RUN_START, hparams=hparams)
  if FLAGS.schedule == "run_std_server":
    run_std_server()
  mlperf_log.transformer_print(
      key=mlperf_log.RUN_SET_RANDOM_SEED, value=FLAGS.random_seed,
      hparams=hparams)
  trainer_lib.set_random_seed(FLAGS.random_seed)

  # When launching on Cloud ML Engine, return right after the launch call.
  if FLAGS.cloud_mlengine:
    cloud_mlengine.launch()
    return

  if FLAGS.generate_data:
    generate_data()

  # A Cloud ML Engine job directory, when present, replaces --output_dir.
  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  exp_fn = create_experiment_fn()
  exp = exp_fn(create_run_config(hparams), hparams)
  if is_chief():
    save_metadata(hparams)
  execute_schedule(exp)
  if FLAGS.schedule != "train":
    mlperf_log.transformer_print(key=mlperf_log.RUN_FINAL, hparams=hparams)
def main(argv):
  """Evaluate model accuracy on clean and adversarially perturbed inputs.

  Accuracy is measured once on the unmodified inputs (reported as eps=0.0)
  and once per value in `attack_params.attack_epsilons`.

  Args:
    argv: Command-line arguments. When non-empty, everything after the first
      element is forwarded to `t2t_trainer.set_hparams_from_args`.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()

  # When launching on Cloud ML Engine, return right after the launch call.
  if FLAGS.cloud_mlengine:
    cloud_mlengine.launch()
    return

  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  # A Cloud ML Engine job directory, when present, replaces --output_dir.
  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])

  hparams = t2t_trainer.create_hparams()
  trainer_lib.add_problem_hparams(hparams, FLAGS.problem)
  attack_params = create_attack_params()
  # Register "eps" up front so it can be overwritten for every epsilon below.
  attack_params.add_hparam("eps", 0.0)

  config = t2t_trainer.create_run_config(hparams)
  params = {"batch_size": hparams.batch_size}

  # add "_rev" as a hack to avoid image standardization
  problem = registry.problem(FLAGS.problem + "_rev")
  input_fn = problem.make_estimator_input_fn(
      tf.estimator.ModeKeys.EVAL, hparams)
  dataset = input_fn(params, config).repeat()
  features, _ = dataset.make_one_shot_iterator().get_next()
  # The "targets" feature is used as the model input and "inputs" as the
  # label. NOTE(review): presumably because the "_rev" problem swaps the two
  # - confirm.
  inputs, labels = features["targets"], features["inputs"]
  inputs = tf.to_float(inputs)
  labels = tf.squeeze(labels)

  sess = tf.Session()

  model_fn = t2t_model.T2TModel.make_estimator_model_fn(
      FLAGS.model, hparams, use_tpu=FLAGS.use_tpu)
  ch_model = adv_attack_utils.T2TAttackModel(model_fn, params, config)

  acc_mask = None
  probs = ch_model.get_probs(inputs)
  if FLAGS.ignore_incorrect:
    # Weight 1.0 for examples classified correctly on the clean inputs, 0.0
    # otherwise. The argmax is emitted in the labels' dtype (as
    # compute_accuracy below already does): tf.argmax defaults to int64 and
    # tf.equal rejects operands of different dtypes.
    preds = tf.argmax(probs, -1, output_type=labels.dtype)
    preds = tf.squeeze(preds)
    acc_mask = tf.to_float(tf.equal(labels, preds))
  one_hot_labels = tf.one_hot(labels, probs.shape[-1])

  attack = create_attack(attack_params.attack)(ch_model, sess=sess)

  # Restore weights
  saver = tf.train.Saver()
  checkpoint_path = os.path.expanduser(
      FLAGS.output_dir or FLAGS.checkpoint_path)
  saver.restore(sess, tf.train.latest_checkpoint(checkpoint_path))

  # reuse variables
  tf.get_variable_scope().reuse_variables()

  def compute_accuracy(x, labels, mask):
    """Return the accuracy of `ch_model` on `x` after FLAGS.eval_steps runs.

    Args:
      x: Inputs passed to `ch_model.get_probs`.
      labels: Labels tensor; its dtype is used for the argmax output.
      mask: Weights passed to `tf.metrics.accuracy`, or None.

    Returns:
      The value from the last run of the accuracy update op.
    """
    preds = ch_model.get_probs(x)
    preds = tf.squeeze(preds)
    preds = tf.argmax(preds, -1, output_type=labels.dtype)
    _, acc_update_op = tf.metrics.accuracy(
        labels=labels, predictions=preds, weights=mask)
    sess.run(tf.initialize_local_variables())
    # NOTE(review): `acc` is only bound inside this loop, so
    # FLAGS.eval_steps <= 0 would fail at the return below.
    for _ in range(FLAGS.eval_steps):
      acc = sess.run(acc_update_op)
    return acc

  # Clean-input accuracy is recorded as the eps=0.0 entry.
  acc = compute_accuracy(inputs, labels, acc_mask)
  epsilon_acc_pairs = [(0.0, acc)]
  for epsilon in attack_params.attack_epsilons:
    attack_params.eps = epsilon
    adv_x = attack.generate(inputs, y=one_hot_labels, **attack_params.values())
    acc = compute_accuracy(adv_x, labels, acc_mask)
    epsilon_acc_pairs.append((epsilon, acc))

  for epsilon, acc in epsilon_acc_pairs:
    tf.logging.info("Accuracy @ eps=%f: %f" % (epsilon, acc))
def main(argv):
  """Evaluate model accuracy on clean and adversarially perturbed inputs.

  Accuracy is measured once on the unmodified inputs (reported as eps=0.0)
  and once per value in `attack_params.attack_epsilons`.

  Args:
    argv: Command-line arguments. When non-empty, everything after the first
      element is forwarded to `t2t_trainer.set_hparams_from_args`.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()

  # When launching on Cloud ML Engine, return right after the launch call.
  if FLAGS.cloud_mlengine:
    cloud_mlengine.launch()
    return

  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  # A Cloud ML Engine job directory, when present, replaces --output_dir.
  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])

  hparams = t2t_trainer.create_hparams()
  trainer_lib.add_problem_hparams(hparams, FLAGS.problem)
  attack_params = create_attack_params()
  # Register "eps" up front so it can be overwritten for every epsilon below.
  attack_params.add_hparam("eps", 0.0)

  config = t2t_trainer.create_run_config(hparams)
  params = {"batch_size": hparams.batch_size}

  # add "_rev" as a hack to avoid image standardization
  problem = registry.problem(FLAGS.problem + "_rev")
  input_fn = problem.make_estimator_input_fn(
      tf.estimator.ModeKeys.EVAL, hparams)
  dataset = input_fn(params, config).repeat()
  features, _ = dataset.make_one_shot_iterator().get_next()
  # The "targets" feature is used as the model input and "inputs" as the
  # label. NOTE(review): presumably because the "_rev" problem swaps the two
  # - confirm.
  inputs, labels = features["targets"], features["inputs"]
  inputs = tf.to_float(inputs)
  labels = tf.squeeze(labels)

  sess = tf.Session()

  model_fn = t2t_model.T2TModel.make_estimator_model_fn(
      FLAGS.model, hparams, use_tpu=FLAGS.use_tpu)
  ch_model = adv_attack_utils.T2TAttackModel(model_fn, params, config)

  acc_mask = None
  probs = ch_model.get_probs(inputs)
  if FLAGS.ignore_incorrect:
    # Weight 1.0 for examples classified correctly on the clean inputs, 0.0
    # otherwise. The argmax is emitted in the labels' dtype (as
    # compute_accuracy below already does): tf.argmax defaults to int64 and
    # tf.equal rejects operands of different dtypes.
    preds = tf.argmax(probs, -1, output_type=labels.dtype)
    preds = tf.squeeze(preds)
    acc_mask = tf.to_float(tf.equal(labels, preds))
  one_hot_labels = tf.one_hot(labels, probs.shape[-1])

  attack = create_attack(attack_params.attack)(ch_model, sess=sess)

  # Restore weights
  saver = tf.train.Saver()
  checkpoint_path = os.path.expanduser(
      FLAGS.output_dir or FLAGS.checkpoint_path)
  saver.restore(sess, tf.train.latest_checkpoint(checkpoint_path))

  # reuse variables
  tf.get_variable_scope().reuse_variables()

  def compute_accuracy(x, labels, mask):
    """Return the accuracy of `ch_model` on `x` after FLAGS.eval_steps runs.

    Args:
      x: Inputs passed to `ch_model.get_probs`.
      labels: Labels tensor; its dtype is used for the argmax output.
      mask: Weights passed to `tf.metrics.accuracy`, or None.

    Returns:
      The value from the last run of the accuracy update op.
    """
    preds = ch_model.get_probs(x)
    preds = tf.squeeze(preds)
    preds = tf.argmax(preds, -1, output_type=labels.dtype)
    _, acc_update_op = tf.metrics.accuracy(
        labels=labels, predictions=preds, weights=mask)
    sess.run(tf.initialize_local_variables())
    # NOTE(review): `acc` is only bound inside this loop, so
    # FLAGS.eval_steps <= 0 would fail at the return below.
    for _ in range(FLAGS.eval_steps):
      acc = sess.run(acc_update_op)
    return acc

  # Clean-input accuracy is recorded as the eps=0.0 entry.
  acc = compute_accuracy(inputs, labels, acc_mask)
  epsilon_acc_pairs = [(0.0, acc)]
  for epsilon in attack_params.attack_epsilons:
    attack_params.eps = epsilon
    adv_x = attack.generate(inputs, y=one_hot_labels, **attack_params.values())
    acc = compute_accuracy(adv_x, labels, acc_mask)
    epsilon_acc_pairs.append((epsilon, acc))

  for epsilon, acc in epsilon_acc_pairs:
    tf.logging.info("Accuracy @ eps=%f: %f" % (epsilon, acc))