def main(argv):
    """Train a teacher model, then distill it into a student model.

    Each phase builds its own hparams, repoints `FLAGS.output_dir` at the
    phase directory, and runs an experiment created through `t2t_trainer`.
    The teacher phase is skipped when `FLAGS.skip_teacher_training` is set.

    Args:
        argv: Command-line arguments; everything after the program name is
            passed to `t2t_trainer.set_hparams_from_args`.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    t2t_trainer.maybe_log_registry_and_exit()

    if FLAGS.cloud_mlengine:
        cloud_mlengine.launch()
        return

    if FLAGS.generate_data:
        t2t_trainer.generate_data()

    if cloud_mlengine.job_dir():
        FLAGS.output_dir = cloud_mlengine.job_dir()

    if argv:
        t2t_trainer.set_hparams_from_args(argv[1:])

    def run_phase(phase_hparams):
        """Create the experiment for `phase_hparams` and execute it."""
        experiment_fn = t2t_trainer.create_experiment_fn()
        phase_config = t2t_trainer.create_run_config(phase_hparams)
        experiment = experiment_fn(phase_config, phase_hparams)
        if t2t_trainer.is_chief():
            t2t_trainer.save_metadata(phase_hparams)
        t2t_trainer.execute_schedule(experiment)

    # Remember the original output dir; FLAGS.output_dir is overwritten
    # below once per phase.
    base_dir = FLAGS.output_dir
    teacher_dir = FLAGS.teacher_dir or os.path.join(base_dir, "teacher")

    # ---- Teacher phase ----
    if not FLAGS.skip_teacher_training:
        teacher_hparams = t2t_trainer.create_hparams()
        teacher_hparams.distill_phase = "train"
        FLAGS.output_dir = teacher_dir
        run_phase(teacher_hparams)
    else:
        tf.logging.info("training teacher skipped")

    # ---- Student phase ----
    student_hparams = t2t_trainer.create_hparams()
    student_hparams.add_hparam("teacher_dir", teacher_dir)
    student_hparams.distill_phase = "distill"
    student_dir = FLAGS.student_dir or os.path.join(base_dir, "student")
    FLAGS.output_dir = student_dir
    student_hparams.add_hparam("student_dir", student_dir)
    run_phase(student_hparams)
def create_teacher_experiment(run_config, hparams, argv):
    """Build the experiment used to train the teacher model.

    Args:
        run_config: Run configuration passed through to the experiment
            function.
        hparams: Hyperparameters for the run. `distill_phase` is set to
            "train" on this object before the experiment is built.
        argv: Command-line arguments; everything after the program name is
            passed to `t2t_trainer.set_hparams_from_args`.

    Returns:
        The object returned by the function from
        `t2t_trainer.create_experiment_fn()`, or the result of
        `cloud_mlengine.launch()` when `FLAGS.cloud_mlengine` is set.
    """
    # Raise verbosity before logging: the message below is emitted at INFO
    # level and would be dropped if the logger were still less verbose.
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.logging.info("training teacher")
    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    t2t_trainer.maybe_log_registry_and_exit()

    if FLAGS.cloud_mlengine:
        return cloud_mlengine.launch()

    if FLAGS.generate_data:
        t2t_trainer.generate_data()

    if cloud_mlengine.job_dir():
        FLAGS.output_dir = cloud_mlengine.job_dir()

    if argv:
        t2t_trainer.set_hparams_from_args(argv[1:])

    with t2t_trainer.maybe_cloud_tpu():
        hparams.distill_phase = "train"
        exp_fn = t2t_trainer.create_experiment_fn()
        exp = exp_fn(run_config, hparams)
    return exp
def create_student_experiment(run_config, hparams, argv):
    """Build the experiment used to distill the teacher into the student.

    Args:
        run_config: Run configuration passed through to the experiment
            function.
        hparams: Hyperparameters for the run. `teacher_dir` and
            `student_dir` are added from the corresponding flags and
            `distill_phase` is set to "distill" before the experiment is
            built.
        argv: Command-line arguments; everything after the program name is
            passed to `t2t_trainer.set_hparams_from_args`.

    Returns:
        The object returned by the function from
        `t2t_trainer.create_experiment_fn()`, or the result of
        `cloud_mlengine.launch()` when `FLAGS.cloud_mlengine` is set.
    """
    # Raise verbosity before logging: the message below is emitted at INFO
    # level and would be dropped if the logger were still less verbose.
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.logging.info("training student")
    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    t2t_trainer.maybe_log_registry_and_exit()

    if FLAGS.cloud_mlengine:
        return cloud_mlengine.launch()

    if FLAGS.generate_data:
        t2t_trainer.generate_data()

    if cloud_mlengine.job_dir():
        FLAGS.output_dir = cloud_mlengine.job_dir()

    if argv:
        t2t_trainer.set_hparams_from_args(argv[1:])

    hparams.add_hparam("teacher_dir", FLAGS.teacher_dir)
    hparams.add_hparam("student_dir", FLAGS.student_dir)
    hparams.distill_phase = "distill"
    exp_fn = t2t_trainer.create_experiment_fn()
    exp = exp_fn(run_config, hparams)
    return exp
def main(argv):
    """Restore a trained model and hand it to `pruning_utils.sparsify`.

    Builds the EVAL-mode model graph on the problem's evaluation input,
    restores the latest checkpoint from `FLAGS.output_dir` (falling back to
    `FLAGS.checkpoint_path`), and passes an accuracy-evaluation callback
    together with the pruning strategy and parameters to
    `pruning_utils.sparsify`.

    Args:
        argv: Command-line arguments; everything after the program name is
            passed to `t2t_trainer.set_hparams_from_args`.

    Raises:
        ValueError: From the evaluation callback when `FLAGS.eval_steps` is
            smaller than 1, since no accuracy value could be produced.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    t2t_trainer.maybe_log_registry_and_exit()

    if FLAGS.generate_data:
        t2t_trainer.generate_data()

    if argv:
        t2t_trainer.set_hparams_from_args(argv[1:])
    hparams = t2t_trainer.create_hparams()
    trainer_lib.add_problem_hparams(hparams, FLAGS.problem)
    pruning_params = create_pruning_params()
    pruning_strategy = create_pruning_strategy(pruning_params.strategy)

    config = t2t_trainer.create_run_config(hparams)
    params = {"batch_size": hparams.batch_size}

    problem = registry.problem(FLAGS.problem)
    input_fn = problem.make_estimator_input_fn(
        tf.estimator.ModeKeys.EVAL, hparams)
    dataset = input_fn(params, config).repeat()
    features, labels = dataset.make_one_shot_iterator().get_next()

    sess = tf.Session()

    model_fn = t2t_model.T2TModel.make_estimator_model_fn(
        FLAGS.model, hparams, use_tpu=FLAGS.use_tpu)
    spec = model_fn(
        features,
        labels,
        tf.estimator.ModeKeys.EVAL,
        params=hparams,
        config=config)

    # Restore weights
    saver = tf.train.Saver()
    checkpoint_path = os.path.expanduser(
        FLAGS.output_dir or FLAGS.checkpoint_path)
    saver.restore(sess, tf.train.latest_checkpoint(checkpoint_path))

    # Build the accuracy metric once, outside eval_model, so repeated calls
    # reuse the same ops instead of adding new ones to the graph each time.
    preds = spec.predictions["predictions"]
    preds = tf.argmax(preds, -1, output_type=labels.dtype)
    _, acc_update_op = tf.metrics.accuracy(labels=labels, predictions=preds)

    def eval_model():
        """Return the streaming accuracy after FLAGS.eval_steps batches."""
        if FLAGS.eval_steps < 1:
            raise ValueError(
                "--eval_steps must be at least 1, got %r" % FLAGS.eval_steps)
        # Re-initialising local variables resets the metric's accumulators,
        # so each call starts from a fresh accuracy.
        sess.run(tf.local_variables_initializer())
        for _ in range(FLAGS.eval_steps):
            acc = sess.run(acc_update_op)
        return acc

    pruning_utils.sparsify(sess, eval_model, pruning_strategy, pruning_params)
def main(argv):
    """Train a teacher model, then distill it into a student model.

    Both phases run inside `t2t_trainer.maybe_cloud_tpu()`. Each phase
    builds its own hparams, repoints `FLAGS.output_dir` at a "teacher" or
    "student" subdirectory of the original output dir, and runs an
    experiment created through `t2t_trainer`.

    Args:
        argv: Command-line arguments; everything after the program name is
            passed to `t2t_trainer.set_hparams_from_args`.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    t2t_trainer.maybe_log_registry_and_exit()

    if FLAGS.cloud_mlengine:
        cloud_mlengine.launch()
        return

    if FLAGS.generate_data:
        t2t_trainer.generate_data()

    if cloud_mlengine.job_dir():
        FLAGS.output_dir = cloud_mlengine.job_dir()

    if argv:
        t2t_trainer.set_hparams_from_args(argv[1:])

    def run_phase(phase_hparams):
        """Create the experiment for `phase_hparams` and execute it."""
        experiment_fn = t2t_trainer.create_experiment_fn()
        phase_config = t2t_trainer.create_run_config(phase_hparams)
        experiment = experiment_fn(phase_config, phase_hparams)
        if t2t_trainer.is_chief():
            t2t_trainer.save_metadata(phase_hparams)
        t2t_trainer.execute_schedule(experiment)

    with t2t_trainer.maybe_cloud_tpu():
        # Remember the original output dir; FLAGS.output_dir is overwritten
        # below once per phase.
        base_dir = FLAGS.output_dir

        # ---- Teacher phase ----
        teacher_hparams = t2t_trainer.create_hparams()
        teacher_hparams.distill_phase = "train"
        teacher_dir = os.path.join(base_dir, "teacher")
        FLAGS.output_dir = teacher_dir
        run_phase(teacher_hparams)

        # ---- Student phase ----
        student_hparams = t2t_trainer.create_hparams()
        student_hparams.add_hparam("teacher_dir", teacher_dir)
        student_hparams.distill_phase = "distill"
        student_dir = os.path.join(base_dir, "student")
        FLAGS.output_dir = student_dir
        run_phase(student_hparams)
def main(argv):
    """Measure model accuracy on adversarial inputs for a range of epsilons.

    For every value in `attack_params.attack_epsilons`, adversarial inputs
    are generated and accuracy is computed and logged. When
    `FLAGS.surrogate_attack` is set, the attack is built against a surrogate
    model loaded under the "surrogate" scope, and accuracy is reported as a
    pair (target model, surrogate model).

    Args:
        argv: Command-line arguments; everything after the program name is
            passed to `t2t_trainer.set_hparams_from_args`.

    Raises:
        ValueError: If `FLAGS.eval_steps` is smaller than 1 when accuracy is
            computed, since no accuracy value could be produced.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    t2t_trainer.maybe_log_registry_and_exit()

    if FLAGS.cloud_mlengine:
        cloud_mlengine.launch()
        return

    if FLAGS.generate_data:
        t2t_trainer.generate_data()

    if cloud_mlengine.job_dir():
        FLAGS.output_dir = cloud_mlengine.job_dir()

    if argv:
        t2t_trainer.set_hparams_from_args(argv[1:])

    if FLAGS.surrogate_attack:
        tf.logging.warn("Performing surrogate model attack.")
        sur_hparams = create_surrogate_hparams()
        trainer_lib.add_problem_hparams(sur_hparams, FLAGS.problem)

    hparams = t2t_trainer.create_hparams()
    trainer_lib.add_problem_hparams(hparams, FLAGS.problem)

    attack_params = create_attack_params()
    attack_params.add_hparam(attack_params.epsilon_name, 0.0)

    if FLAGS.surrogate_attack:
        sur_config = create_surrogate_run_config(sur_hparams)
    config = t2t_trainer.create_run_config(hparams)
    params = {
        "batch_size": hparams.batch_size,
        "use_tpu": FLAGS.use_tpu,
    }

    # add "_rev" as a hack to avoid image standardization
    problem = registry.problem(FLAGS.problem + "_rev")

    inputs, labels, features = prepare_data(problem, hparams, params, config)

    sess = tf.Session()

    if FLAGS.surrogate_attack:
        sur_model_fn = t2t_model.T2TModel.make_estimator_model_fn(
            FLAGS.surrogate_model, sur_hparams, use_tpu=FLAGS.use_tpu)
        sur_ch_model = adv_attack_utils.T2TAttackModel(
            sur_model_fn, features, params, sur_config, scope="surrogate")
        # Dummy call to construct graph
        sur_ch_model.get_probs(inputs)

        checkpoint_path = os.path.expanduser(FLAGS.surrogate_output_dir)
        tf.train.init_from_checkpoint(
            tf.train.latest_checkpoint(checkpoint_path), {"/": "surrogate/"})
        sess.run(tf.global_variables_initializer())

    # Snapshot the variables that exist so far; everything created after
    # this point is restored from FLAGS.output_dir below.
    other_vars = set(tf.global_variables())

    model_fn = t2t_model.T2TModel.make_estimator_model_fn(
        FLAGS.model, hparams)
    ch_model = adv_attack_utils.T2TAttackModel(
        model_fn, features, params, config)

    acc_mask = None
    probs = ch_model.get_probs(inputs)
    if FLAGS.ignore_incorrect:
        # Weight of 1.0 only where the model's prediction on the clean
        # input equals the label.
        preds = tf.argmax(probs, -1, output_type=labels.dtype)
        preds = tf.reshape(preds, labels.shape)
        acc_mask = tf.to_float(tf.equal(labels, preds))
    one_hot_labels = tf.one_hot(labels, probs.shape[-1])

    if FLAGS.surrogate_attack:
        attack = create_attack(attack_params.attack)(sur_ch_model, sess=sess)
    else:
        attack = create_attack(attack_params.attack)(ch_model, sess=sess)

    new_vars = set(tf.global_variables()) - other_vars

    # Restore weights
    saver = tf.train.Saver(new_vars)
    checkpoint_path = os.path.expanduser(FLAGS.output_dir)
    saver.restore(sess, tf.train.latest_checkpoint(checkpoint_path))

    # reuse variables
    tf.get_variable_scope().reuse_variables()

    def compute_accuracy(x, y_true, mask):
        """Return accuracy on `x` after FLAGS.eval_steps batches.

        Returns a single value, or a (model, surrogate) pair when
        FLAGS.surrogate_attack is set.
        """
        if FLAGS.eval_steps < 1:
            raise ValueError(
                "--eval_steps must be at least 1, got %r" % FLAGS.eval_steps)

        preds = ch_model.get_probs(x)
        preds = tf.squeeze(preds)
        preds = tf.argmax(preds, -1, output_type=y_true.dtype)

        _, acc_update_op = tf.metrics.accuracy(y_true, preds, weights=mask)

        if FLAGS.surrogate_attack:
            preds = sur_ch_model.get_probs(x)
            preds = tf.squeeze(preds)
            preds = tf.argmax(preds, -1, output_type=y_true.dtype)
            acc_update_op = tf.tuple(
                (acc_update_op,
                 tf.metrics.accuracy(y_true, preds, weights=mask)[1]))

        # Initialise the accumulators of the metrics created above.
        sess.run(tf.local_variables_initializer())
        for i in range(FLAGS.eval_steps):
            tf.logging.info(
                "\tEvaluating batch [%d / %d]", i + 1, FLAGS.eval_steps)
            acc = sess.run(acc_update_op)
        if FLAGS.surrogate_attack:
            tf.logging.info("\tFinal acc: (%.4f, %.4f)", acc[0], acc[1])
        else:
            tf.logging.info("\tFinal acc: %.4f", acc)
        return acc

    epsilon_acc_pairs = []
    for epsilon in attack_params.attack_epsilons:
        tf.logging.info("Attacking @ eps=%.4f", epsilon)
        attack_params.set_hparam(attack_params.epsilon_name, epsilon)
        adv_x = attack.generate(
            inputs, y=one_hot_labels, **attack_params.values())
        acc = compute_accuracy(adv_x, labels, acc_mask)
        epsilon_acc_pairs.append((epsilon, acc))

    for epsilon, acc in epsilon_acc_pairs:
        if FLAGS.surrogate_attack:
            tf.logging.info(
                "Accuracy @ eps=%.4f: (%.4f, %.4f)", epsilon, acc[0], acc[1])
        else:
            tf.logging.info("Accuracy @ eps=%.4f: %.4f", epsilon, acc)
def main(argv):
    """Measure model accuracy on adversarial inputs for a range of epsilons.

    For every value in `attack_params.attack_epsilons`, adversarial inputs
    are generated and accuracy is computed and logged. When
    `FLAGS.surrogate_attack` is set, the attack is built against a surrogate
    model loaded under the "surrogate" scope, and accuracy is reported as a
    pair (target model, surrogate model).

    Args:
        argv: Command-line arguments; everything after the program name is
            passed to `t2t_trainer.set_hparams_from_args`.

    Raises:
        ValueError: If `FLAGS.eval_steps` is smaller than 1 when accuracy is
            computed, since no accuracy value could be produced.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    t2t_trainer.maybe_log_registry_and_exit()

    if FLAGS.cloud_mlengine:
        cloud_mlengine.launch()
        return

    if FLAGS.generate_data:
        t2t_trainer.generate_data()

    if cloud_mlengine.job_dir():
        FLAGS.output_dir = cloud_mlengine.job_dir()

    if argv:
        t2t_trainer.set_hparams_from_args(argv[1:])

    if FLAGS.surrogate_attack:
        tf.logging.warn("Performing surrogate model attack.")
        sur_hparams = create_surrogate_hparams()
        trainer_lib.add_problem_hparams(sur_hparams, FLAGS.problem)

    hparams = t2t_trainer.create_hparams()
    trainer_lib.add_problem_hparams(hparams, FLAGS.problem)

    attack_params = create_attack_params()
    attack_params.add_hparam(attack_params.epsilon_name, 0.0)

    if FLAGS.surrogate_attack:
        sur_config = create_surrogate_run_config(sur_hparams)
    config = t2t_trainer.create_run_config(hparams)
    params = {
        "batch_size": hparams.batch_size,
        "use_tpu": FLAGS.use_tpu,
    }

    # add "_rev" as a hack to avoid image standardization
    problem = registry.problem(FLAGS.problem + "_rev")

    inputs, labels, features = prepare_data(problem, hparams, params, config)

    sess = tf.Session()

    if FLAGS.surrogate_attack:
        sur_model_fn = t2t_model.T2TModel.make_estimator_model_fn(
            FLAGS.surrogate_model, sur_hparams, use_tpu=FLAGS.use_tpu)
        sur_ch_model = adv_attack_utils.T2TAttackModel(
            sur_model_fn, features, params, sur_config, scope="surrogate")
        # Dummy call to construct graph
        sur_ch_model.get_probs(inputs)

        checkpoint_path = os.path.expanduser(FLAGS.surrogate_output_dir)
        tf.contrib.framework.init_from_checkpoint(
            tf.train.latest_checkpoint(checkpoint_path), {"/": "surrogate/"})
        sess.run(tf.global_variables_initializer())

    # Snapshot the variables that exist so far; everything created after
    # this point is restored from FLAGS.output_dir below.
    other_vars = set(tf.global_variables())

    model_fn = t2t_model.T2TModel.make_estimator_model_fn(
        FLAGS.model, hparams)
    ch_model = adv_attack_utils.T2TAttackModel(
        model_fn, features, params, config)

    acc_mask = None
    probs = ch_model.get_probs(inputs)
    if FLAGS.ignore_incorrect:
        # Weight of 1.0 only where the model's prediction on the clean
        # input equals the label.
        preds = tf.argmax(probs, -1, output_type=labels.dtype)
        preds = tf.reshape(preds, labels.shape)
        acc_mask = tf.to_float(tf.equal(labels, preds))
    one_hot_labels = tf.one_hot(labels, probs.shape[-1])

    if FLAGS.surrogate_attack:
        attack = create_attack(attack_params.attack)(sur_ch_model, sess=sess)
    else:
        attack = create_attack(attack_params.attack)(ch_model, sess=sess)

    new_vars = set(tf.global_variables()) - other_vars

    # Restore weights
    saver = tf.train.Saver(new_vars)
    checkpoint_path = os.path.expanduser(FLAGS.output_dir)
    saver.restore(sess, tf.train.latest_checkpoint(checkpoint_path))

    # reuse variables
    tf.get_variable_scope().reuse_variables()

    def compute_accuracy(x, y_true, mask):
        """Return accuracy on `x` after FLAGS.eval_steps batches.

        Returns a single value, or a (model, surrogate) pair when
        FLAGS.surrogate_attack is set.
        """
        if FLAGS.eval_steps < 1:
            raise ValueError(
                "--eval_steps must be at least 1, got %r" % FLAGS.eval_steps)

        preds = ch_model.get_probs(x)
        preds = tf.squeeze(preds)
        preds = tf.argmax(preds, -1, output_type=y_true.dtype)

        _, acc_update_op = tf.metrics.accuracy(y_true, preds, weights=mask)

        if FLAGS.surrogate_attack:
            preds = sur_ch_model.get_probs(x)
            preds = tf.squeeze(preds)
            preds = tf.argmax(preds, -1, output_type=y_true.dtype)
            acc_update_op = tf.tuple(
                (acc_update_op,
                 tf.metrics.accuracy(y_true, preds, weights=mask)[1]))

        # Initialise the accumulators of the metrics created above.
        sess.run(tf.local_variables_initializer())
        for i in range(FLAGS.eval_steps):
            tf.logging.info(
                "\tEvaluating batch [%d / %d]", i + 1, FLAGS.eval_steps)
            acc = sess.run(acc_update_op)
        if FLAGS.surrogate_attack:
            tf.logging.info("\tFinal acc: (%.4f, %.4f)", acc[0], acc[1])
        else:
            tf.logging.info("\tFinal acc: %.4f", acc)
        return acc

    epsilon_acc_pairs = []
    for epsilon in attack_params.attack_epsilons:
        tf.logging.info("Attacking @ eps=%.4f", epsilon)
        attack_params.set_hparam(attack_params.epsilon_name, epsilon)
        adv_x = attack.generate(
            inputs, y=one_hot_labels, **attack_params.values())
        acc = compute_accuracy(adv_x, labels, acc_mask)
        epsilon_acc_pairs.append((epsilon, acc))

    for epsilon, acc in epsilon_acc_pairs:
        if FLAGS.surrogate_attack:
            tf.logging.info(
                "Accuracy @ eps=%.4f: (%.4f, %.4f)", epsilon, acc[0], acc[1])
        else:
            tf.logging.info("Accuracy @ eps=%.4f: %.4f", epsilon, acc)
def main(argv):
    """Restore a trained model and run the weight-sparsity/accuracy study.

    Builds the EVAL-mode model graph on the problem's evaluation input,
    restores the checkpoint recorded in the checkpoint state of
    `FLAGS.output_dir` (falling back to `FLAGS.checkpoint_path`), picks an
    evaluation callback based on `FLAGS.problem`, and passes it to
    `wei_feat_distrib.wei_sparsity_acc_tradeoff`. The graph and summaries
    are written to `FLAGS.log_dir`.

    Args:
        argv: Command-line arguments; everything after the program name is
            passed to `t2t_trainer.set_hparams_from_args`.

    Raises:
        ValueError: If no checkpoint state is found, if `FLAGS.problem`
            contains neither "image" nor "translate", or (from the
            evaluation callback) if `FLAGS.eval_steps` is smaller than 1.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    t2t_trainer.maybe_log_registry_and_exit()

    if FLAGS.generate_data:
        t2t_trainer.generate_data()

    if argv:
        t2t_trainer.set_hparams_from_args(argv[1:])

    hparams_path = os.path.join(FLAGS.output_dir, "hparams.json")
    hparams = trainer_lib.create_hparams(
        FLAGS.hparams_set,
        FLAGS.hparams,
        data_dir=FLAGS.data_dir,
        problem_name=FLAGS.problem,
        hparams_path=hparams_path)
    hparams.add_hparam("model_dir", FLAGS.output_dir)

    config = t2t_trainer.create_run_config(hparams)
    params = {"batch_size": hparams.batch_size}

    problem = registry.problem(FLAGS.problem)
    input_fn = problem.make_estimator_input_fn(
        tf.estimator.ModeKeys.EVAL, hparams)

    sess = tf.Session(config=tf.ConfigProto(
        allow_soft_placement=True, log_device_placement=False))

    model_fn = t2t_model.T2TModel.make_estimator_model_fn(
        FLAGS.model, hparams, use_tpu=False)

    dataset = input_fn(params, config).repeat()
    dataset_iterator = dataset.make_one_shot_iterator()
    features, labels = dataset_iterator.get_next()

    spec = model_fn(
        features,
        labels,
        tf.estimator.ModeKeys.EVAL,
        params=hparams,
        config=config)

    # Write the model graph so it can be inspected alongside the summaries.
    summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)

    # Restore weights
    saver = tf.train.Saver()
    checkpoint_path = os.path.expanduser(
        FLAGS.output_dir or FLAGS.checkpoint_path)
    tf.logging.info("### t2t_wei_feat_distrib.py checkpoint_path %s",
                    checkpoint_path)
    ckpts = tf.train.get_checkpoint_state(checkpoint_path)
    if ckpts is None:
        # Fail with a clear message instead of an AttributeError on None.
        raise ValueError(
            "No checkpoint state found in %r" % checkpoint_path)
    saver.restore(sess, ckpts.model_checkpoint_path)

    pruning_params = create_pruning_params()
    pruning_strategy = create_pruning_strategy(pruning_params.strategy)

    def check_eval_steps():
        """Reject step counts that cannot produce an evaluation result."""
        if FLAGS.eval_steps < 1:
            raise ValueError(
                "--eval_steps must be at least 1, got %r" % FLAGS.eval_steps)

    # Pick the evaluation graph and callback that match the problem type.
    if "image" in FLAGS.problem:
        acc, acc_update_op = get_eval_graph_image(spec, labels)
        tf.summary.scalar("accuracy", acc)

        def eval_model():
            """Return the accuracy update value after FLAGS.eval_steps."""
            check_eval_steps()
            sess.run(tf.local_variables_initializer())
            for _ in range(FLAGS.eval_steps):
                step_acc = sess.run(acc_update_op)
            return step_acc
    elif "translate" in FLAGS.problem:
        bleu_op = get_eval_graph_trans(spec, labels)

        def eval_model():
            """Return the mean of `bleu_op` over FLAGS.eval_steps batches."""
            check_eval_steps()
            bleu_total = 0
            for _ in range(FLAGS.eval_steps):
                bleu_total += sess.run(bleu_op)
            return bleu_total / FLAGS.eval_steps
    else:
        # Without this branch eval_model would be undefined and the call
        # below would fail with a NameError.
        raise ValueError(
            "Unsupported problem %r: expected a name containing 'image' or "
            "'translate'" % FLAGS.problem)

    # get weight distribution graph
    wei_feat_distrib.get_weight_distrib_graph(pruning_params)

    # do accuracy sparsity tradeoff for model weights
    wei_feat_distrib.wei_sparsity_acc_tradeoff(
        sess, eval_model, pruning_strategy, pruning_params, summary_writer)

    # save the summary
    summary_writer.close()
def main(argv):
    """Measure model accuracy on clean and adversarial inputs.

    Accuracy is first computed on the unmodified inputs (recorded as
    epsilon 0.0), then for adversarial inputs generated at every value in
    `attack_params.attack_epsilons`. All (epsilon, accuracy) pairs are
    logged at the end.

    Args:
        argv: Command-line arguments; everything after the program name is
            passed to `t2t_trainer.set_hparams_from_args`.

    Raises:
        ValueError: If `FLAGS.eval_steps` is smaller than 1 when accuracy is
            computed, since no accuracy value could be produced.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    t2t_trainer.maybe_log_registry_and_exit()

    if FLAGS.cloud_mlengine:
        cloud_mlengine.launch()
        return

    if FLAGS.generate_data:
        t2t_trainer.generate_data()

    if cloud_mlengine.job_dir():
        FLAGS.output_dir = cloud_mlengine.job_dir()

    if argv:
        t2t_trainer.set_hparams_from_args(argv[1:])

    hparams = t2t_trainer.create_hparams()
    trainer_lib.add_problem_hparams(hparams, FLAGS.problem)
    attack_params = create_attack_params()
    attack_params.add_hparam("eps", 0.0)

    config = t2t_trainer.create_run_config(hparams)
    params = {"batch_size": hparams.batch_size}

    # add "_rev" as a hack to avoid image standardization
    problem = registry.problem(FLAGS.problem + "_rev")
    input_fn = problem.make_estimator_input_fn(
        tf.estimator.ModeKeys.EVAL, hparams)
    dataset = input_fn(params, config).repeat()
    features, _ = dataset.make_one_shot_iterator().get_next()
    # With the "_rev" problem the roles are swapped: "targets" is fed to
    # the model as input and "inputs" is used as the label.
    inputs, labels = features["targets"], features["inputs"]
    inputs = tf.to_float(inputs)
    labels = tf.squeeze(labels)

    sess = tf.Session()

    model_fn = t2t_model.T2TModel.make_estimator_model_fn(
        FLAGS.model, hparams, use_tpu=FLAGS.use_tpu)
    ch_model = adv_attack_utils.T2TAttackModel(model_fn, params, config)

    acc_mask = None
    probs = ch_model.get_probs(inputs)
    if FLAGS.ignore_incorrect:
        # Match the label dtype (as compute_accuracy does) so tf.equal does
        # not receive mismatched integer types.
        preds = tf.argmax(probs, -1, output_type=labels.dtype)
        preds = tf.squeeze(preds)
        acc_mask = tf.to_float(tf.equal(labels, preds))
    one_hot_labels = tf.one_hot(labels, probs.shape[-1])

    attack = create_attack(attack_params.attack)(ch_model, sess=sess)

    # Restore weights
    saver = tf.train.Saver()
    checkpoint_path = os.path.expanduser(
        FLAGS.output_dir or FLAGS.checkpoint_path)
    saver.restore(sess, tf.train.latest_checkpoint(checkpoint_path))

    # reuse variables
    tf.get_variable_scope().reuse_variables()

    def compute_accuracy(x, y_true, mask):
        """Return the accuracy on `x` after FLAGS.eval_steps batches."""
        if FLAGS.eval_steps < 1:
            raise ValueError(
                "--eval_steps must be at least 1, got %r" % FLAGS.eval_steps)

        preds = ch_model.get_probs(x)
        preds = tf.squeeze(preds)
        preds = tf.argmax(preds, -1, output_type=y_true.dtype)

        _, acc_update_op = tf.metrics.accuracy(
            labels=y_true, predictions=preds, weights=mask)

        # Initialise the accumulators of the metric created above.
        sess.run(tf.local_variables_initializer())
        for _ in range(FLAGS.eval_steps):
            acc = sess.run(acc_update_op)
        return acc

    acc = compute_accuracy(inputs, labels, acc_mask)
    epsilon_acc_pairs = [(0.0, acc)]
    for epsilon in attack_params.attack_epsilons:
        attack_params.eps = epsilon
        adv_x = attack.generate(
            inputs, y=one_hot_labels, **attack_params.values())
        acc = compute_accuracy(adv_x, labels, acc_mask)
        epsilon_acc_pairs.append((epsilon, acc))

    for epsilon, acc in epsilon_acc_pairs:
        tf.logging.info("Accuracy @ eps=%f: %f", epsilon, acc)
def main(argv):
    """Measure model accuracy on clean and adversarial inputs.

    Accuracy is first computed on the unmodified inputs (recorded as
    epsilon 0.0), then for adversarial inputs generated at every value in
    `attack_params.attack_epsilons`. All (epsilon, accuracy) pairs are
    logged at the end.

    Args:
        argv: Command-line arguments; everything after the program name is
            passed to `t2t_trainer.set_hparams_from_args`.

    Raises:
        ValueError: If `FLAGS.eval_steps` is smaller than 1 when accuracy is
            computed, since no accuracy value could be produced.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    t2t_trainer.maybe_log_registry_and_exit()

    if FLAGS.cloud_mlengine:
        cloud_mlengine.launch()
        return

    if FLAGS.generate_data:
        t2t_trainer.generate_data()

    if cloud_mlengine.job_dir():
        FLAGS.output_dir = cloud_mlengine.job_dir()

    if argv:
        t2t_trainer.set_hparams_from_args(argv[1:])

    hparams = t2t_trainer.create_hparams()
    trainer_lib.add_problem_hparams(hparams, FLAGS.problem)
    attack_params = create_attack_params()
    attack_params.add_hparam("eps", 0.0)

    config = t2t_trainer.create_run_config(hparams)
    params = {"batch_size": hparams.batch_size}

    # add "_rev" as a hack to avoid image standardization
    problem = registry.problem(FLAGS.problem + "_rev")
    input_fn = problem.make_estimator_input_fn(
        tf.estimator.ModeKeys.EVAL, hparams)
    dataset = input_fn(params, config).repeat()
    features, _ = dataset.make_one_shot_iterator().get_next()
    # With the "_rev" problem the roles are swapped: "targets" is fed to
    # the model as input and "inputs" is used as the label.
    inputs, labels = features["targets"], features["inputs"]
    inputs = tf.to_float(inputs)
    labels = tf.squeeze(labels)

    sess = tf.Session()

    model_fn = t2t_model.T2TModel.make_estimator_model_fn(
        FLAGS.model, hparams, use_tpu=FLAGS.use_tpu)
    ch_model = adv_attack_utils.T2TAttackModel(model_fn, params, config)

    acc_mask = None
    probs = ch_model.get_probs(inputs)
    if FLAGS.ignore_incorrect:
        # Match the label dtype (as compute_accuracy does) so tf.equal does
        # not receive mismatched integer types.
        preds = tf.argmax(probs, -1, output_type=labels.dtype)
        preds = tf.squeeze(preds)
        acc_mask = tf.to_float(tf.equal(labels, preds))
    one_hot_labels = tf.one_hot(labels, probs.shape[-1])

    attack = create_attack(attack_params.attack)(ch_model, sess=sess)

    # Restore weights
    saver = tf.train.Saver()
    checkpoint_path = os.path.expanduser(
        FLAGS.output_dir or FLAGS.checkpoint_path)
    saver.restore(sess, tf.train.latest_checkpoint(checkpoint_path))

    # reuse variables
    tf.get_variable_scope().reuse_variables()

    def compute_accuracy(x, y_true, mask):
        """Return the accuracy on `x` after FLAGS.eval_steps batches."""
        if FLAGS.eval_steps < 1:
            raise ValueError(
                "--eval_steps must be at least 1, got %r" % FLAGS.eval_steps)

        preds = ch_model.get_probs(x)
        preds = tf.squeeze(preds)
        preds = tf.argmax(preds, -1, output_type=y_true.dtype)

        _, acc_update_op = tf.metrics.accuracy(
            labels=y_true, predictions=preds, weights=mask)

        # Initialise the accumulators of the metric created above.
        sess.run(tf.local_variables_initializer())
        for _ in range(FLAGS.eval_steps):
            acc = sess.run(acc_update_op)
        return acc

    acc = compute_accuracy(inputs, labels, acc_mask)
    epsilon_acc_pairs = [(0.0, acc)]
    for epsilon in attack_params.attack_epsilons:
        attack_params.eps = epsilon
        adv_x = attack.generate(
            inputs, y=one_hot_labels, **attack_params.values())
        acc = compute_accuracy(adv_x, labels, acc_mask)
        epsilon_acc_pairs.append((epsilon, acc))

    for epsilon, acc in epsilon_acc_pairs:
        tf.logging.info("Accuracy @ eps=%f: %f", epsilon, acc)