def main(argv):
  """Train (or reuse) a teacher model, then distill it into a student.

  The teacher phase runs with ``distill_phase="train"`` unless
  ``FLAGS.skip_teacher_training`` is set. The student phase always runs with
  ``distill_phase="distill"`` and receives ``teacher_dir`` and
  ``student_dir`` as extra hparams. Unless overridden by ``FLAGS.teacher_dir``
  / ``FLAGS.student_dir``, the phases use ``teacher/`` and ``student/``
  under the original ``FLAGS.output_dir``.

  Args:
    argv: Command-line arguments; ``argv[1:]`` is forwarded to
      ``t2t_trainer.set_hparams_from_args``.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()

  if FLAGS.cloud_mlengine:
    cloud_mlengine.launch()
    return

  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])

  def _run_experiment(phase_hparams, phase_dir):
    """Switch FLAGS.output_dir to `phase_dir`, then build and run one experiment."""
    # The directory is not passed to t2t_trainer explicitly; it is presumably
    # read from this global flag by the helpers below.
    FLAGS.output_dir = phase_dir
    make_experiment = t2t_trainer.create_experiment_fn()
    phase_run_config = t2t_trainer.create_run_config(phase_hparams)
    experiment = make_experiment(phase_run_config, phase_hparams)
    if t2t_trainer.is_chief():
      t2t_trainer.save_metadata(phase_hparams)
    t2t_trainer.execute_schedule(experiment)

  root_output_dir = FLAGS.output_dir
  teacher_dir = FLAGS.teacher_dir or os.path.join(root_output_dir, "teacher")

  # Teacher phase.
  if not FLAGS.skip_teacher_training:
    teacher_hparams = t2t_trainer.create_hparams()
    teacher_hparams.distill_phase = "train"
    _run_experiment(teacher_hparams, teacher_dir)
  else:
    tf.logging.info("training teacher skipped")

  # Student phase: hparams are created before FLAGS.output_dir is switched.
  student_hparams = t2t_trainer.create_hparams()
  student_hparams.add_hparam("teacher_dir", teacher_dir)
  student_hparams.distill_phase = "distill"
  student_dir = FLAGS.student_dir or os.path.join(root_output_dir, "student")
  student_hparams.add_hparam("student_dir", student_dir)
  _run_experiment(student_hparams, student_dir)
def main(argv):
  """Train a teacher model, then distill it into a student model.

  Both phases run inside ``t2t_trainer.maybe_cloud_tpu()``. The teacher
  (``distill_phase="train"``) writes to ``teacher/`` and the student
  (``distill_phase="distill"``, with ``teacher_dir`` added to its hparams)
  writes to ``student/``, both under the original ``FLAGS.output_dir``.

  Args:
    argv: Command-line arguments; ``argv[1:]`` is forwarded to
      ``t2t_trainer.set_hparams_from_args``.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()

  if FLAGS.cloud_mlengine:
    cloud_mlengine.launch()
    return

  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])

  def _run_experiment(phase_hparams, phase_dir):
    """Switch FLAGS.output_dir to `phase_dir`, then build and run one experiment."""
    # The directory is not passed to t2t_trainer explicitly; it is presumably
    # read from this global flag by the helpers below.
    FLAGS.output_dir = phase_dir
    make_experiment = t2t_trainer.create_experiment_fn()
    phase_run_config = t2t_trainer.create_run_config(phase_hparams)
    experiment = make_experiment(phase_run_config, phase_hparams)
    if t2t_trainer.is_chief():
      t2t_trainer.save_metadata(phase_hparams)
    t2t_trainer.execute_schedule(experiment)

  with t2t_trainer.maybe_cloud_tpu():
    root_output_dir = FLAGS.output_dir

    # Teacher phase.
    teacher_hparams = t2t_trainer.create_hparams()
    teacher_hparams.distill_phase = "train"
    teacher_dir = os.path.join(root_output_dir, "teacher")
    _run_experiment(teacher_hparams, teacher_dir)

    # Student phase: hparams are created before FLAGS.output_dir is switched.
    student_hparams = t2t_trainer.create_hparams()
    student_hparams.add_hparam("teacher_dir", teacher_dir)
    student_hparams.distill_phase = "distill"
    _run_experiment(student_hparams, os.path.join(root_output_dir, "student"))
def main(argv):
  """Restore a trained model and prune it with ``pruning_utils.sparsify``.

  The model is rebuilt in EVAL mode on top of a repeating EVAL input
  pipeline. Its weights are restored from the latest checkpoint in
  ``FLAGS.output_dir`` (or ``FLAGS.checkpoint_path`` when that is empty).
  ``pruning_utils.sparsify`` is then given an ``eval_model`` callback that
  returns streaming accuracy over ``FLAGS.eval_steps`` batches.

  Args:
    argv: Command-line arguments; ``argv[1:]`` is forwarded to
      ``t2t_trainer.set_hparams_from_args``.

  Raises:
    ValueError: from the ``eval_model`` callback if ``FLAGS.eval_steps`` is
      not positive (no accuracy could be measured).
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()

  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])
  hparams = t2t_trainer.create_hparams()
  trainer_lib.add_problem_hparams(hparams, FLAGS.problem)
  pruning_params = create_pruning_params()
  pruning_strategy = create_pruning_strategy(pruning_params.strategy)

  config = t2t_trainer.create_run_config(hparams)
  params = {"batch_size": hparams.batch_size}

  # The problem is looked up exactly as registered (no "_rev" variant).
  problem = registry.problem(FLAGS.problem)
  input_fn = problem.make_estimator_input_fn(tf.estimator.ModeKeys.EVAL,
                                             hparams)
  # repeat() so repeated eval_model() calls never exhaust the one-shot
  # iterator.
  dataset = input_fn(params, config).repeat()
  features, labels = dataset.make_one_shot_iterator().get_next()

  sess = tf.Session()

  model_fn = t2t_model.T2TModel.make_estimator_model_fn(
      FLAGS.model, hparams, use_tpu=FLAGS.use_tpu)
  spec = model_fn(
      features,
      labels,
      tf.estimator.ModeKeys.EVAL,
      params=hparams,
      config=config)

  # Restore weights
  saver = tf.train.Saver()
  checkpoint_path = os.path.expanduser(FLAGS.output_dir or
                                       FLAGS.checkpoint_path)
  saver.restore(sess, tf.train.latest_checkpoint(checkpoint_path))

  # Build the accuracy ops once. The predictions tensor is fixed, so
  # eval_model() only needs to reset and re-run the metric instead of adding
  # a fresh argmax/metric subgraph to the graph on every call.
  preds = tf.argmax(spec.predictions["predictions"], -1,
                    output_type=labels.dtype)
  _, acc_update_op = tf.metrics.accuracy(labels=labels, predictions=preds)

  def eval_model():
    """Return streaming accuracy over ``FLAGS.eval_steps`` EVAL batches."""
    if FLAGS.eval_steps <= 0:
      raise ValueError(
          "eval_steps must be positive to measure accuracy, got %d" %
          FLAGS.eval_steps)
    # Zero the metric's local (total/count) variables before each evaluation.
    sess.run(tf.local_variables_initializer())
    for _ in range(FLAGS.eval_steps):
      acc = sess.run(acc_update_op)
    return acc

  pruning_utils.sparsify(sess, eval_model, pruning_strategy, pruning_params)
def main(argv):
  """Restore a trained model and prune it with ``pruning_utils.sparsify``.

  The model is rebuilt in EVAL mode on top of a repeating EVAL input
  pipeline. Its weights are restored from the latest checkpoint in
  ``FLAGS.output_dir`` (or ``FLAGS.checkpoint_path`` when that is empty).
  ``pruning_utils.sparsify`` is then given an ``eval_model`` callback that
  returns streaming accuracy over ``FLAGS.eval_steps`` batches.

  Args:
    argv: Command-line arguments; ``argv[1:]`` is forwarded to
      ``t2t_trainer.set_hparams_from_args``.

  Raises:
    ValueError: from the ``eval_model`` callback if ``FLAGS.eval_steps`` is
      not positive (no accuracy could be measured).
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()

  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])
  hparams = t2t_trainer.create_hparams()
  trainer_lib.add_problem_hparams(hparams, FLAGS.problem)
  pruning_params = create_pruning_params()
  pruning_strategy = create_pruning_strategy(pruning_params.strategy)

  config = t2t_trainer.create_run_config(hparams)
  params = {"batch_size": hparams.batch_size}

  # The problem is looked up exactly as registered (no "_rev" variant).
  problem = registry.problem(FLAGS.problem)
  input_fn = problem.make_estimator_input_fn(tf.estimator.ModeKeys.EVAL,
                                             hparams)
  # repeat() so repeated eval_model() calls never exhaust the one-shot
  # iterator.
  dataset = input_fn(params, config).repeat()
  features, labels = dataset.make_one_shot_iterator().get_next()

  sess = tf.Session()

  model_fn = t2t_model.T2TModel.make_estimator_model_fn(
      FLAGS.model, hparams, use_tpu=FLAGS.use_tpu)
  spec = model_fn(
      features,
      labels,
      tf.estimator.ModeKeys.EVAL,
      params=hparams,
      config=config)

  # Restore weights
  saver = tf.train.Saver()
  checkpoint_path = os.path.expanduser(FLAGS.output_dir or
                                       FLAGS.checkpoint_path)
  saver.restore(sess, tf.train.latest_checkpoint(checkpoint_path))

  # Build the accuracy ops once. The predictions tensor is fixed, so
  # eval_model() only needs to reset and re-run the metric instead of adding
  # a fresh argmax/metric subgraph to the graph on every call.
  preds = tf.argmax(spec.predictions["predictions"], -1,
                    output_type=labels.dtype)
  _, acc_update_op = tf.metrics.accuracy(labels=labels, predictions=preds)

  def eval_model():
    """Return streaming accuracy over ``FLAGS.eval_steps`` EVAL batches."""
    if FLAGS.eval_steps <= 0:
      raise ValueError(
          "eval_steps must be positive to measure accuracy, got %d" %
          FLAGS.eval_steps)
    # Zero the metric's local (total/count) variables before each evaluation.
    sess.run(tf.local_variables_initializer())
    for _ in range(FLAGS.eval_steps):
      acc = sess.run(acc_update_op)
    return acc

  pruning_utils.sparsify(sess, eval_model, pruning_strategy, pruning_params)
if obj1 in objs and obj2 in objs: found = True break assert found print('YAS') data_fields = dict(sequence=tf.FixedLenFeature([128], tf.int64), attention_mask=tf.FixedLenFeature([128 * 4], tf.int64), theorem=tf.FixedLenFeature([1], tf.int64), targets=tf.FixedLenFeature([8], tf.int64), depth=tf.FixedLenFeature([1], tf.int64)) geo_problem = t2t_trainer.registry.problem(tf.flags.FLAGS.problem) hparams = t2t_trainer.create_hparams() t2t_trainer.trainer_lib.add_problem_hparams(hparams, tf.flags.FLAGS.problem) hparams.batch_shuffle_size = 8 dataset = geo_problem.input_fn(tf.estimator.ModeKeys.TRAIN, hparams, data_dir=tf.flags.FLAGS.data_dir, params=None, config=None, force_repeat=False, prevent_repeat=False, dataset_kwargs=dict(output_buffer_size=8, shuffle_buffer_size=8)) features = dataset.make_one_shot_iterator().get_next()[0]
def main(argv):
  """Evaluate a trained model's accuracy under adversarial attack.

  For every epsilon in ``attack_params.attack_epsilons``, adversarial inputs
  are generated with the configured attack, and the model's streaming
  accuracy on them is logged. With ``FLAGS.surrogate_attack``, the attack is
  crafted against a separately trained surrogate model (built under the
  ``surrogate`` scope). Accuracy is then reported as a ``(target, surrogate)``
  pair.

  Args:
    argv: Command-line arguments; ``argv[1:]`` is forwarded to
      ``t2t_trainer.set_hparams_from_args``.

  Raises:
    ValueError: when accuracy is computed with a non-positive
      ``FLAGS.eval_steps``.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()

  if FLAGS.cloud_mlengine:
    cloud_mlengine.launch()
    return

  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])

  if FLAGS.surrogate_attack:
    tf.logging.warn("Performing surrogate model attack.")
    sur_hparams = create_surrogate_hparams()
    trainer_lib.add_problem_hparams(sur_hparams, FLAGS.problem)

  hparams = t2t_trainer.create_hparams()
  trainer_lib.add_problem_hparams(hparams, FLAGS.problem)

  attack_params = create_attack_params()
  # Placeholder value; overwritten with each epsilon in the attack loop.
  attack_params.add_hparam(attack_params.epsilon_name, 0.0)

  if FLAGS.surrogate_attack:
    sur_config = create_surrogate_run_config(sur_hparams)
  config = t2t_trainer.create_run_config(hparams)
  params = {
      "batch_size": hparams.batch_size,
      "use_tpu": FLAGS.use_tpu,
  }

  # add "_rev" as a hack to avoid image standardization
  problem = registry.problem(FLAGS.problem + "_rev")

  inputs, labels, features = prepare_data(problem, hparams, params, config)

  sess = tf.Session()

  if FLAGS.surrogate_attack:
    sur_model_fn = t2t_model.T2TModel.make_estimator_model_fn(
        FLAGS.surrogate_model, sur_hparams, use_tpu=FLAGS.use_tpu)
    sur_ch_model = adv_attack_utils.T2TAttackModel(
        sur_model_fn, features, params, sur_config, scope="surrogate")
    # Dummy call to construct graph
    sur_ch_model.get_probs(inputs)

    # Map the surrogate checkpoint's variables ("/") onto the "surrogate/"
    # scope; the global initializer run below then loads those values.
    checkpoint_path = os.path.expanduser(FLAGS.surrogate_output_dir)
    tf.train.init_from_checkpoint(
        tf.train.latest_checkpoint(checkpoint_path), {"/": "surrogate/"})
    sess.run(tf.global_variables_initializer())

  # Snapshot the variables that exist before the target model is built; only
  # variables created after this point are restored from FLAGS.output_dir.
  other_vars = set(tf.global_variables())

  model_fn = t2t_model.T2TModel.make_estimator_model_fn(
      FLAGS.model, hparams)
  ch_model = adv_attack_utils.T2TAttackModel(model_fn, features, params, config)

  acc_mask = None
  probs = ch_model.get_probs(inputs)
  if FLAGS.ignore_incorrect:
    # Weight examples by clean-input correctness so accuracy is measured only
    # over examples the target model initially classifies correctly.
    preds = tf.argmax(probs, -1, output_type=labels.dtype)
    preds = tf.reshape(preds, labels.shape)
    acc_mask = tf.to_float(tf.equal(labels, preds))
  one_hot_labels = tf.one_hot(labels, probs.shape[-1])

  # The attack targets the surrogate when one is configured, otherwise the
  # target model itself.
  attacked_model = sur_ch_model if FLAGS.surrogate_attack else ch_model
  attack = create_attack(attack_params.attack)(attacked_model, sess=sess)

  new_vars = set(tf.global_variables()) - other_vars

  # Restore weights
  saver = tf.train.Saver(new_vars)
  checkpoint_path = os.path.expanduser(FLAGS.output_dir)
  saver.restore(sess, tf.train.latest_checkpoint(checkpoint_path))

  # reuse variables
  tf.get_variable_scope().reuse_variables()

  def compute_accuracy(x, true_labels, mask):
    """Return streaming accuracy on `x` over ``FLAGS.eval_steps`` batches.

    The result is a single accuracy value, or a ``(target, surrogate)`` pair
    when ``FLAGS.surrogate_attack`` is set.
    """
    if FLAGS.eval_steps <= 0:
      raise ValueError(
          "eval_steps must be positive to measure accuracy, got %d" %
          FLAGS.eval_steps)

    def _accuracy_update_op(model):
      # Class predictions of `model` on `x`, fed into a streaming metric.
      preds = tf.squeeze(model.get_probs(x))
      preds = tf.argmax(preds, -1, output_type=true_labels.dtype)
      return tf.metrics.accuracy(true_labels, preds, weights=mask)[1]

    acc_update_op = _accuracy_update_op(ch_model)
    if FLAGS.surrogate_attack:
      acc_update_op = tf.tuple(
          (acc_update_op, _accuracy_update_op(sur_ch_model)))

    # Zero the metric counters so each call reports a fresh accuracy.
    sess.run(tf.local_variables_initializer())
    for i in range(FLAGS.eval_steps):
      tf.logging.info(
          "\tEvaluating batch [%d / %d]" % (i + 1, FLAGS.eval_steps))
      acc = sess.run(acc_update_op)
    if FLAGS.surrogate_attack:
      tf.logging.info("\tFinal acc: (%.4f, %.4f)" % (acc[0], acc[1]))
    else:
      tf.logging.info("\tFinal acc: %.4f" % acc)
    return acc

  epsilon_acc_pairs = []
  for epsilon in attack_params.attack_epsilons:
    tf.logging.info("Attacking @ eps=%.4f" % epsilon)
    attack_params.set_hparam(attack_params.epsilon_name, epsilon)
    adv_x = attack.generate(inputs, y=one_hot_labels, **attack_params.values())
    acc = compute_accuracy(adv_x, labels, acc_mask)
    epsilon_acc_pairs.append((epsilon, acc))

  for epsilon, acc in epsilon_acc_pairs:
    if FLAGS.surrogate_attack:
      tf.logging.info(
          "Accuracy @ eps=%.4f: (%.4f, %.4f)" % (epsilon, acc[0], acc[1]))
    else:
      tf.logging.info("Accuracy @ eps=%.4f: %.4f" % (epsilon, acc))
def main(argv):
  """Evaluate a trained model's accuracy under adversarial attack.

  For every epsilon in ``attack_params.attack_epsilons``, adversarial inputs
  are generated with the configured attack, and the model's streaming
  accuracy on them is logged. With ``FLAGS.surrogate_attack``, the attack is
  crafted against a separately trained surrogate model (built under the
  ``surrogate`` scope). Accuracy is then reported as a ``(target, surrogate)``
  pair.

  Args:
    argv: Command-line arguments; ``argv[1:]`` is forwarded to
      ``t2t_trainer.set_hparams_from_args``.

  Raises:
    ValueError: when accuracy is computed with a non-positive
      ``FLAGS.eval_steps``.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()

  if FLAGS.cloud_mlengine:
    cloud_mlengine.launch()
    return

  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])

  if FLAGS.surrogate_attack:
    tf.logging.warn("Performing surrogate model attack.")
    sur_hparams = create_surrogate_hparams()
    trainer_lib.add_problem_hparams(sur_hparams, FLAGS.problem)

  hparams = t2t_trainer.create_hparams()
  trainer_lib.add_problem_hparams(hparams, FLAGS.problem)

  attack_params = create_attack_params()
  # Placeholder value; overwritten with each epsilon in the attack loop.
  attack_params.add_hparam(attack_params.epsilon_name, 0.0)

  if FLAGS.surrogate_attack:
    sur_config = create_surrogate_run_config(sur_hparams)
  config = t2t_trainer.create_run_config(hparams)
  params = {
      "batch_size": hparams.batch_size,
      "use_tpu": FLAGS.use_tpu,
  }

  # add "_rev" as a hack to avoid image standardization
  problem = registry.problem(FLAGS.problem + "_rev")

  inputs, labels, features = prepare_data(problem, hparams, params, config)

  sess = tf.Session()

  if FLAGS.surrogate_attack:
    sur_model_fn = t2t_model.T2TModel.make_estimator_model_fn(
        FLAGS.surrogate_model, sur_hparams, use_tpu=FLAGS.use_tpu)
    sur_ch_model = adv_attack_utils.T2TAttackModel(
        sur_model_fn, features, params, sur_config, scope="surrogate")
    # Dummy call to construct graph
    sur_ch_model.get_probs(inputs)

    # Map the surrogate checkpoint's variables ("/") onto the "surrogate/"
    # scope; the global initializer run below then loads those values.
    # Uses the core tf.train API rather than the deprecated tf.contrib alias.
    checkpoint_path = os.path.expanduser(FLAGS.surrogate_output_dir)
    tf.train.init_from_checkpoint(
        tf.train.latest_checkpoint(checkpoint_path), {"/": "surrogate/"})
    sess.run(tf.global_variables_initializer())

  # Snapshot the variables that exist before the target model is built; only
  # variables created after this point are restored from FLAGS.output_dir.
  other_vars = set(tf.global_variables())

  model_fn = t2t_model.T2TModel.make_estimator_model_fn(
      FLAGS.model, hparams)
  ch_model = adv_attack_utils.T2TAttackModel(model_fn, features, params, config)

  acc_mask = None
  probs = ch_model.get_probs(inputs)
  if FLAGS.ignore_incorrect:
    # Weight examples by clean-input correctness so accuracy is measured only
    # over examples the target model initially classifies correctly.
    preds = tf.argmax(probs, -1, output_type=labels.dtype)
    preds = tf.reshape(preds, labels.shape)
    acc_mask = tf.to_float(tf.equal(labels, preds))
  one_hot_labels = tf.one_hot(labels, probs.shape[-1])

  # The attack targets the surrogate when one is configured, otherwise the
  # target model itself.
  attacked_model = sur_ch_model if FLAGS.surrogate_attack else ch_model
  attack = create_attack(attack_params.attack)(attacked_model, sess=sess)

  new_vars = set(tf.global_variables()) - other_vars

  # Restore weights
  saver = tf.train.Saver(new_vars)
  checkpoint_path = os.path.expanduser(FLAGS.output_dir)
  saver.restore(sess, tf.train.latest_checkpoint(checkpoint_path))

  # reuse variables
  tf.get_variable_scope().reuse_variables()

  def compute_accuracy(x, true_labels, mask):
    """Return streaming accuracy on `x` over ``FLAGS.eval_steps`` batches.

    The result is a single accuracy value, or a ``(target, surrogate)`` pair
    when ``FLAGS.surrogate_attack`` is set.
    """
    if FLAGS.eval_steps <= 0:
      raise ValueError(
          "eval_steps must be positive to measure accuracy, got %d" %
          FLAGS.eval_steps)

    def _accuracy_update_op(model):
      # Class predictions of `model` on `x`, fed into a streaming metric.
      preds = tf.squeeze(model.get_probs(x))
      preds = tf.argmax(preds, -1, output_type=true_labels.dtype)
      return tf.metrics.accuracy(true_labels, preds, weights=mask)[1]

    acc_update_op = _accuracy_update_op(ch_model)
    if FLAGS.surrogate_attack:
      acc_update_op = tf.tuple(
          (acc_update_op, _accuracy_update_op(sur_ch_model)))

    # Zero the metric counters so each call reports a fresh accuracy.
    sess.run(tf.local_variables_initializer())
    for i in range(FLAGS.eval_steps):
      tf.logging.info(
          "\tEvaluating batch [%d / %d]" % (i + 1, FLAGS.eval_steps))
      acc = sess.run(acc_update_op)
    if FLAGS.surrogate_attack:
      tf.logging.info("\tFinal acc: (%.4f, %.4f)" % (acc[0], acc[1]))
    else:
      tf.logging.info("\tFinal acc: %.4f" % acc)
    return acc

  epsilon_acc_pairs = []
  for epsilon in attack_params.attack_epsilons:
    tf.logging.info("Attacking @ eps=%.4f" % epsilon)
    attack_params.set_hparam(attack_params.epsilon_name, epsilon)
    adv_x = attack.generate(inputs, y=one_hot_labels, **attack_params.values())
    acc = compute_accuracy(adv_x, labels, acc_mask)
    epsilon_acc_pairs.append((epsilon, acc))

  for epsilon, acc in epsilon_acc_pairs:
    if FLAGS.surrogate_attack:
      tf.logging.info(
          "Accuracy @ eps=%.4f: (%.4f, %.4f)" % (epsilon, acc[0], acc[1]))
    else:
      tf.logging.info("Accuracy @ eps=%.4f: %.4f" % (epsilon, acc))
def main(argv):
  """Report model accuracy on clean inputs and under adversarial attack.

  Clean accuracy is recorded as ``eps=0.0``. Then, for every epsilon in
  ``attack_params.attack_epsilons``, adversarial inputs are generated with
  the configured attack and the model's accuracy on them is recorded. All
  ``(epsilon, accuracy)`` pairs are logged at the end.

  Args:
    argv: Command-line arguments; ``argv[1:]`` is forwarded to
      ``t2t_trainer.set_hparams_from_args``.

  Raises:
    ValueError: if ``FLAGS.eval_steps`` is not positive.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()

  if FLAGS.cloud_mlengine:
    cloud_mlengine.launch()
    return

  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])

  hparams = t2t_trainer.create_hparams()
  trainer_lib.add_problem_hparams(hparams, FLAGS.problem)

  attack_params = create_attack_params()
  # Placeholder value; overwritten with each epsilon in the attack loop.
  attack_params.add_hparam("eps", 0.0)

  config = t2t_trainer.create_run_config(hparams)
  params = {"batch_size": hparams.batch_size}

  # add "_rev" as a hack to avoid image standardization
  problem = registry.problem(FLAGS.problem + "_rev")
  input_fn = problem.make_estimator_input_fn(
      tf.estimator.ModeKeys.EVAL, hparams)
  dataset = input_fn(params, config).repeat()
  features, _ = dataset.make_one_shot_iterator().get_next()
  # With the "_rev" problem the feature roles are (presumably) swapped, so the
  # model input is read from "targets" and the class label from "inputs".
  inputs, labels = features["targets"], features["inputs"]
  inputs = tf.to_float(inputs)
  labels = tf.squeeze(labels)

  sess = tf.Session()

  model_fn = t2t_model.T2TModel.make_estimator_model_fn(
      FLAGS.model, hparams, use_tpu=FLAGS.use_tpu)
  ch_model = adv_attack_utils.T2TAttackModel(model_fn, params, config)

  acc_mask = None
  probs = ch_model.get_probs(inputs)
  if FLAGS.ignore_incorrect:
    # Only score examples the model classifies correctly on clean input.
    # tf.argmax defaults to int64; match the label dtype so tf.equal does not
    # fail on an int32/int64 mismatch.
    preds = tf.argmax(probs, -1, output_type=labels.dtype)
    preds = tf.squeeze(preds)
    acc_mask = tf.to_float(tf.equal(labels, preds))
  one_hot_labels = tf.one_hot(labels, probs.shape[-1])

  attack = create_attack(attack_params.attack)(ch_model, sess=sess)

  # Restore weights
  saver = tf.train.Saver()
  checkpoint_path = os.path.expanduser(FLAGS.output_dir or
                                       FLAGS.checkpoint_path)
  saver.restore(sess, tf.train.latest_checkpoint(checkpoint_path))

  # reuse variables
  tf.get_variable_scope().reuse_variables()

  def compute_accuracy(x, true_labels, mask):
    """Return streaming accuracy of the model on `x` over eval_steps batches."""
    if FLAGS.eval_steps <= 0:
      raise ValueError(
          "eval_steps must be positive to measure accuracy, got %d" %
          FLAGS.eval_steps)
    preds = ch_model.get_probs(x)
    preds = tf.squeeze(preds)
    preds = tf.argmax(preds, -1, output_type=true_labels.dtype)
    _, acc_update_op = tf.metrics.accuracy(
        labels=true_labels, predictions=preds, weights=mask)
    # Zero the metric counters so each call reports a fresh accuracy.
    sess.run(tf.local_variables_initializer())
    for _ in range(FLAGS.eval_steps):
      acc = sess.run(acc_update_op)
    return acc

  acc = compute_accuracy(inputs, labels, acc_mask)
  epsilon_acc_pairs = [(0.0, acc)]
  for epsilon in attack_params.attack_epsilons:
    attack_params.eps = epsilon
    adv_x = attack.generate(inputs, y=one_hot_labels, **attack_params.values())
    acc = compute_accuracy(adv_x, labels, acc_mask)
    epsilon_acc_pairs.append((epsilon, acc))

  for epsilon, acc in epsilon_acc_pairs:
    tf.logging.info("Accuracy @ eps=%f: %f" % (epsilon, acc))
def main(argv):
  """Report model accuracy on clean inputs and under adversarial attack.

  Clean accuracy is recorded as ``eps=0.0``. Then, for every epsilon in
  ``attack_params.attack_epsilons``, adversarial inputs are generated with
  the configured attack and the model's accuracy on them is recorded. All
  ``(epsilon, accuracy)`` pairs are logged at the end.

  Args:
    argv: Command-line arguments; ``argv[1:]`` is forwarded to
      ``t2t_trainer.set_hparams_from_args``.

  Raises:
    ValueError: if ``FLAGS.eval_steps`` is not positive.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()

  if FLAGS.cloud_mlengine:
    cloud_mlengine.launch()
    return

  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])

  hparams = t2t_trainer.create_hparams()
  trainer_lib.add_problem_hparams(hparams, FLAGS.problem)

  attack_params = create_attack_params()
  # Placeholder value; overwritten with each epsilon in the attack loop.
  attack_params.add_hparam("eps", 0.0)

  config = t2t_trainer.create_run_config(hparams)
  params = {"batch_size": hparams.batch_size}

  # add "_rev" as a hack to avoid image standardization
  problem = registry.problem(FLAGS.problem + "_rev")
  input_fn = problem.make_estimator_input_fn(
      tf.estimator.ModeKeys.EVAL, hparams)
  dataset = input_fn(params, config).repeat()
  features, _ = dataset.make_one_shot_iterator().get_next()
  # With the "_rev" problem the feature roles are (presumably) swapped, so the
  # model input is read from "targets" and the class label from "inputs".
  inputs, labels = features["targets"], features["inputs"]
  inputs = tf.to_float(inputs)
  labels = tf.squeeze(labels)

  sess = tf.Session()

  model_fn = t2t_model.T2TModel.make_estimator_model_fn(
      FLAGS.model, hparams, use_tpu=FLAGS.use_tpu)
  ch_model = adv_attack_utils.T2TAttackModel(model_fn, params, config)

  acc_mask = None
  probs = ch_model.get_probs(inputs)
  if FLAGS.ignore_incorrect:
    # Only score examples the model classifies correctly on clean input.
    # tf.argmax defaults to int64; match the label dtype so tf.equal does not
    # fail on an int32/int64 mismatch.
    preds = tf.argmax(probs, -1, output_type=labels.dtype)
    preds = tf.squeeze(preds)
    acc_mask = tf.to_float(tf.equal(labels, preds))
  one_hot_labels = tf.one_hot(labels, probs.shape[-1])

  attack = create_attack(attack_params.attack)(ch_model, sess=sess)

  # Restore weights
  saver = tf.train.Saver()
  checkpoint_path = os.path.expanduser(FLAGS.output_dir or
                                       FLAGS.checkpoint_path)
  saver.restore(sess, tf.train.latest_checkpoint(checkpoint_path))

  # reuse variables
  tf.get_variable_scope().reuse_variables()

  def compute_accuracy(x, true_labels, mask):
    """Return streaming accuracy of the model on `x` over eval_steps batches."""
    if FLAGS.eval_steps <= 0:
      raise ValueError(
          "eval_steps must be positive to measure accuracy, got %d" %
          FLAGS.eval_steps)
    preds = ch_model.get_probs(x)
    preds = tf.squeeze(preds)
    preds = tf.argmax(preds, -1, output_type=true_labels.dtype)
    _, acc_update_op = tf.metrics.accuracy(
        labels=true_labels, predictions=preds, weights=mask)
    # Zero the metric counters so each call reports a fresh accuracy.
    sess.run(tf.local_variables_initializer())
    for _ in range(FLAGS.eval_steps):
      acc = sess.run(acc_update_op)
    return acc

  acc = compute_accuracy(inputs, labels, acc_mask)
  epsilon_acc_pairs = [(0.0, acc)]
  for epsilon in attack_params.attack_epsilons:
    attack_params.eps = epsilon
    adv_x = attack.generate(inputs, y=one_hot_labels, **attack_params.values())
    acc = compute_accuracy(adv_x, labels, acc_mask)
    epsilon_acc_pairs.append((epsilon, acc))

  for epsilon, acc in epsilon_acc_pairs:
    tf.logging.info("Accuracy @ eps=%f: %f" % (epsilon, acc))