Example #1
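Top-level distillation driver: unless --skip_teacher_training is set it first trains a teacher model, then trains a student in the "distill" phase, writing checkpoints to teacher/ and student/ subdirectories of --output_dir (or to --teacher_dir / --student_dir when those flags are given).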
def main(argv):
    tf.logging.set_verbosity(tf.logging.INFO)
    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    t2t_trainer.maybe_log_registry_and_exit()

    if FLAGS.cloud_mlengine:
        cloud_mlengine.launch()
        return

    if FLAGS.generate_data:
        t2t_trainer.generate_data()

    if cloud_mlengine.job_dir():
        FLAGS.output_dir = cloud_mlengine.job_dir()

    if argv:
        t2t_trainer.set_hparams_from_args(argv[1:])

    root_output_dir = FLAGS.output_dir

    if FLAGS.teacher_dir:
        teacher_dir = FLAGS.teacher_dir
    else:
        teacher_dir = os.path.join(root_output_dir, "teacher")

    # Train Teacher ============
    if FLAGS.skip_teacher_training:
        tf.logging.info("training teacher skipped")
    else:
        hparams = t2t_trainer.create_hparams()
        hparams.distill_phase = "train"
        FLAGS.output_dir = teacher_dir

        exp_fn = t2t_trainer.create_experiment_fn()
        run_config = t2t_trainer.create_run_config(hparams)
        exp = exp_fn(run_config, hparams)
        if t2t_trainer.is_chief():
            t2t_trainer.save_metadata(hparams)
        t2t_trainer.execute_schedule(exp)

    # ==========================
    # Train Student ============
    hparams = t2t_trainer.create_hparams()
    hparams.add_hparam("teacher_dir", teacher_dir)
    hparams.distill_phase = "distill"
    if FLAGS.student_dir:
        student_dir = FLAGS.student_dir
    else:
        student_dir = os.path.join(root_output_dir, "student")
    FLAGS.output_dir = student_dir
    hparams.add_hparam("student_dir", student_dir)

    exp_fn = t2t_trainer.create_experiment_fn()
    run_config = t2t_trainer.create_run_config(hparams)
    exp = exp_fn(run_config, hparams)

    if t2t_trainer.is_chief():
        t2t_trainer.save_metadata(hparams)
    t2t_trainer.execute_schedule(exp)
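The driver above assumes a few flags beyond the standard t2t_trainer set. A minimal sketch of how they might be declared (flag names come from the code above; defaults and help strings are illustrative only):

import tensorflow as tf

flags = tf.flags
FLAGS = flags.FLAGS

# Optional explicit checkpoint directories; when empty, the driver falls back
# to <output_dir>/teacher and <output_dir>/student.
flags.DEFINE_string("teacher_dir", "", "Directory holding or receiving teacher checkpoints.")
flags.DEFINE_string("student_dir", "", "Directory receiving student checkpoints.")
# When set, an existing teacher checkpoint is reused instead of training one.
flags.DEFINE_bool("skip_teacher_training", False, "Skip the teacher training phase.")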
Example #2
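Helper that builds the teacher-phase experiment: it sets distill_phase to "train" and returns the experiment produced by t2t_trainer.create_experiment_fn().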
def create_teacher_experiment(run_config, hparams, argv):
  """Creates experiment function."""
  tf.logging.info("training teacher")
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()

  if FLAGS.cloud_mlengine:
    return cloud_mlengine.launch()

  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])

  with t2t_trainer.maybe_cloud_tpu():
    hparams.distill_phase = "train"
    exp_fn = t2t_trainer.create_experiment_fn()
    exp = exp_fn(run_config, hparams)
    return exp
Example #3
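Counterpart helper for the student phase: it records the teacher and student directories as hparams, sets distill_phase to "distill", and returns the experiment.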
def create_student_experiment(run_config, hparams, argv):
    """Creates experiment function."""
    tf.logging.info("training student")
    tf.logging.set_verbosity(tf.logging.INFO)
    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    t2t_trainer.maybe_log_registry_and_exit()

    if FLAGS.cloud_mlengine:
        return cloud_mlengine.launch()

    if FLAGS.generate_data:
        t2t_trainer.generate_data()

    if cloud_mlengine.job_dir():
        FLAGS.output_dir = cloud_mlengine.job_dir()

    if argv:
        t2t_trainer.set_hparams_from_args(argv[1:])

    hparams.add_hparam("teacher_dir", FLAGS.teacher_dir)
    hparams.add_hparam("student_dir", FLAGS.student_dir)
    hparams.distill_phase = "distill"
    exp_fn = t2t_trainer.create_experiment_fn()
    exp = exp_fn(run_config, hparams)
    return exp
Example #4
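Pruning evaluation script: it restores the latest checkpoint, defines an eval_model() callback that reports accuracy over --eval_steps batches, and passes it to pruning_utils.sparsify() together with the chosen pruning strategy and parameters.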
def main(argv):
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()


  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])
  hparams = t2t_trainer.create_hparams()
  trainer_lib.add_problem_hparams(hparams, FLAGS.problem)
  pruning_params = create_pruning_params()
  pruning_strategy = create_pruning_strategy(pruning_params.strategy)

  config = t2t_trainer.create_run_config(hparams)
  params = {"batch_size": hparams.batch_size}

  # add "_rev" as a hack to avoid image standardization
  problem = registry.problem(FLAGS.problem)
  input_fn = problem.make_estimator_input_fn(tf.estimator.ModeKeys.EVAL,
                                             hparams)
  dataset = input_fn(params, config).repeat()
  features, labels = dataset.make_one_shot_iterator().get_next()

  sess = tf.Session()

  model_fn = t2t_model.T2TModel.make_estimator_model_fn(
      FLAGS.model, hparams, use_tpu=FLAGS.use_tpu)
  spec = model_fn(
      features,
      labels,
      tf.estimator.ModeKeys.EVAL,
      params=hparams,
      config=config)

  # Restore weights
  saver = tf.train.Saver()
  checkpoint_path = os.path.expanduser(FLAGS.output_dir or
                                       FLAGS.checkpoint_path)
  saver.restore(sess, tf.train.latest_checkpoint(checkpoint_path))

  def eval_model():
    preds = spec.predictions["predictions"]
    preds = tf.argmax(preds, -1, output_type=labels.dtype)
    _, acc_update_op = tf.metrics.accuracy(labels=labels, predictions=preds)
    sess.run(tf.initialize_local_variables())
    for _ in range(FLAGS.eval_steps):
      acc = sess.run(acc_update_op)
    return acc

  pruning_utils.sparsify(sess, eval_model, pruning_strategy, pruning_params)
Example #5
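A simpler variant of the distillation driver from Example #1: the teacher and student are always trained in sequence under t2t_trainer.maybe_cloud_tpu(), without the skip-teacher or custom-directory options.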
def main(argv):
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()

  if FLAGS.cloud_mlengine:
    cloud_mlengine.launch()
    return

  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])

  with t2t_trainer.maybe_cloud_tpu():
    root_output_dir = FLAGS.output_dir

    # Train Teacher ============
    hparams = t2t_trainer.create_hparams()
    hparams.distill_phase = "train"
    teacher_dir = os.path.join(root_output_dir, "teacher")
    FLAGS.output_dir = teacher_dir

    exp_fn = t2t_trainer.create_experiment_fn()
    run_config = t2t_trainer.create_run_config(hparams)
    exp = exp_fn(run_config, hparams)
    if t2t_trainer.is_chief():
      t2t_trainer.save_metadata(hparams)
    t2t_trainer.execute_schedule(exp)
    # ==========================
    # Train Student ============
    hparams = t2t_trainer.create_hparams()
    hparams.add_hparam("teacher_dir", teacher_dir)
    hparams.distill_phase = "distill"
    student_dir = os.path.join(root_output_dir, "student")
    FLAGS.output_dir = student_dir

    exp_fn = t2t_trainer.create_experiment_fn()
    run_config = t2t_trainer.create_run_config(hparams)
    exp = exp_fn(run_config, hparams)

    if t2t_trainer.is_chief():
      t2t_trainer.save_metadata(hparams)
    t2t_trainer.execute_schedule(exp)
Example #6
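Adversarial-attack evaluation: it optionally builds a surrogate model to generate the attacks, restores the target model from --output_dir, and reports accuracy on adversarial examples generated at each epsilon in attack_params.attack_epsilons.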
def main(argv):
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()


  if FLAGS.cloud_mlengine:
    cloud_mlengine.launch()
    return

  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])

  if FLAGS.surrogate_attack:
    tf.logging.warn("Performing surrogate model attack.")
    sur_hparams = create_surrogate_hparams()
    trainer_lib.add_problem_hparams(sur_hparams, FLAGS.problem)

  hparams = t2t_trainer.create_hparams()
  trainer_lib.add_problem_hparams(hparams, FLAGS.problem)

  attack_params = create_attack_params()
  attack_params.add_hparam(attack_params.epsilon_name, 0.0)

  if FLAGS.surrogate_attack:
    sur_config = create_surrogate_run_config(sur_hparams)
  config = t2t_trainer.create_run_config(hparams)
  params = {
      "batch_size": hparams.batch_size,
      "use_tpu": FLAGS.use_tpu,
  }

  # add "_rev" as a hack to avoid image standardization
  problem = registry.problem(FLAGS.problem + "_rev")

  inputs, labels, features = prepare_data(problem, hparams, params, config)

  sess = tf.Session()

  if FLAGS.surrogate_attack:
    sur_model_fn = t2t_model.T2TModel.make_estimator_model_fn(
        FLAGS.surrogate_model, sur_hparams, use_tpu=FLAGS.use_tpu)
    sur_ch_model = adv_attack_utils.T2TAttackModel(
        sur_model_fn, features, params, sur_config, scope="surrogate")
    # Dummy call to construct graph
    sur_ch_model.get_probs(inputs)

    checkpoint_path = os.path.expanduser(FLAGS.surrogate_output_dir)
    tf.train.init_from_checkpoint(
        tf.train.latest_checkpoint(checkpoint_path), {"/": "surrogate/"})
    sess.run(tf.global_variables_initializer())

  other_vars = set(tf.global_variables())

  model_fn = t2t_model.T2TModel.make_estimator_model_fn(
      FLAGS.model, hparams)
  ch_model = adv_attack_utils.T2TAttackModel(model_fn, features, params, config)

  acc_mask = None
  probs = ch_model.get_probs(inputs)
  if FLAGS.ignore_incorrect:
    preds = tf.argmax(probs, -1, output_type=labels.dtype)
    preds = tf.reshape(preds, labels.shape)
    acc_mask = tf.to_float(tf.equal(labels, preds))
  one_hot_labels = tf.one_hot(labels, probs.shape[-1])

  if FLAGS.surrogate_attack:
    attack = create_attack(attack_params.attack)(sur_ch_model, sess=sess)
  else:
    attack = create_attack(attack_params.attack)(ch_model, sess=sess)

  new_vars = set(tf.global_variables()) - other_vars

  # Restore weights
  saver = tf.train.Saver(new_vars)
  checkpoint_path = os.path.expanduser(FLAGS.output_dir)
  saver.restore(sess, tf.train.latest_checkpoint(checkpoint_path))

  # reuse variables
  tf.get_variable_scope().reuse_variables()

  def compute_accuracy(x, l, mask):
    """Compute model accuracy."""
    preds = ch_model.get_probs(x)
    preds = tf.squeeze(preds)
    preds = tf.argmax(preds, -1, output_type=l.dtype)

    _, acc_update_op = tf.metrics.accuracy(l, preds, weights=mask)

    if FLAGS.surrogate_attack:
      preds = sur_ch_model.get_probs(x)
      preds = tf.squeeze(preds)
      preds = tf.argmax(preds, -1, output_type=l.dtype)
      acc_update_op = tf.tuple((acc_update_op,
                                tf.metrics.accuracy(l, preds, weights=mask)[1]))

    sess.run(tf.initialize_local_variables())
    for i in range(FLAGS.eval_steps):
      tf.logging.info(
          "\tEvaluating batch [%d / %d]" % (i + 1, FLAGS.eval_steps))
      acc = sess.run(acc_update_op)
    if FLAGS.surrogate_attack:
      tf.logging.info("\tFinal acc: (%.4f, %.4f)" % (acc[0], acc[1]))
    else:
      tf.logging.info("\tFinal acc: %.4f" % acc)
    return acc

  epsilon_acc_pairs = []
  for epsilon in attack_params.attack_epsilons:
    tf.logging.info("Attacking @ eps=%.4f" % epsilon)
    attack_params.set_hparam(attack_params.epsilon_name, epsilon)
    adv_x = attack.generate(inputs, y=one_hot_labels, **attack_params.values())
    acc = compute_accuracy(adv_x, labels, acc_mask)
    epsilon_acc_pairs.append((epsilon, acc))

  for epsilon, acc in epsilon_acc_pairs:
    if FLAGS.surrogate_attack:
      tf.logging.info(
          "Accuracy @ eps=%.4f: (%.4f, %.4f)" % (epsilon, acc[0], acc[1]))
    else:
      tf.logging.info("Accuracy @ eps=%.4f: %.4f" % (epsilon, acc))
Example #7
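Nearly identical to Example #6; the only difference is that the surrogate weights are restored with tf.contrib.framework.init_from_checkpoint instead of tf.train.init_from_checkpoint.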
def main(argv):
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()


  if FLAGS.cloud_mlengine:
    cloud_mlengine.launch()
    return

  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])

  if FLAGS.surrogate_attack:
    tf.logging.warn("Performing surrogate model attack.")
    sur_hparams = create_surrogate_hparams()
    trainer_lib.add_problem_hparams(sur_hparams, FLAGS.problem)

  hparams = t2t_trainer.create_hparams()
  trainer_lib.add_problem_hparams(hparams, FLAGS.problem)

  attack_params = create_attack_params()
  attack_params.add_hparam(attack_params.epsilon_name, 0.0)

  if FLAGS.surrogate_attack:
    sur_config = create_surrogate_run_config(sur_hparams)
  config = t2t_trainer.create_run_config(hparams)
  params = {
      "batch_size": hparams.batch_size,
      "use_tpu": FLAGS.use_tpu,
  }

  # add "_rev" as a hack to avoid image standardization
  problem = registry.problem(FLAGS.problem + "_rev")

  inputs, labels, features = prepare_data(problem, hparams, params, config)

  sess = tf.Session()

  if FLAGS.surrogate_attack:
    sur_model_fn = t2t_model.T2TModel.make_estimator_model_fn(
        FLAGS.surrogate_model, sur_hparams, use_tpu=FLAGS.use_tpu)
    sur_ch_model = adv_attack_utils.T2TAttackModel(
        sur_model_fn, features, params, sur_config, scope="surrogate")
    # Dummy call to construct graph
    sur_ch_model.get_probs(inputs)

    checkpoint_path = os.path.expanduser(FLAGS.surrogate_output_dir)
    tf.contrib.framework.init_from_checkpoint(
        tf.train.latest_checkpoint(checkpoint_path), {"/": "surrogate/"})
    sess.run(tf.global_variables_initializer())

  other_vars = set(tf.global_variables())

  model_fn = t2t_model.T2TModel.make_estimator_model_fn(
      FLAGS.model, hparams)
  ch_model = adv_attack_utils.T2TAttackModel(model_fn, features, params, config)

  acc_mask = None
  probs = ch_model.get_probs(inputs)
  if FLAGS.ignore_incorrect:
    preds = tf.argmax(probs, -1, output_type=labels.dtype)
    preds = tf.reshape(preds, labels.shape)
    acc_mask = tf.to_float(tf.equal(labels, preds))
  one_hot_labels = tf.one_hot(labels, probs.shape[-1])

  if FLAGS.surrogate_attack:
    attack = create_attack(attack_params.attack)(sur_ch_model, sess=sess)
  else:
    attack = create_attack(attack_params.attack)(ch_model, sess=sess)

  new_vars = set(tf.global_variables()) - other_vars

  # Restore weights
  saver = tf.train.Saver(new_vars)
  checkpoint_path = os.path.expanduser(FLAGS.output_dir)
  saver.restore(sess, tf.train.latest_checkpoint(checkpoint_path))

  # reuse variables
  tf.get_variable_scope().reuse_variables()

  def compute_accuracy(x, l, mask):
    """Compute model accuracy."""
    preds = ch_model.get_probs(x)
    preds = tf.squeeze(preds)
    preds = tf.argmax(preds, -1, output_type=l.dtype)

    _, acc_update_op = tf.metrics.accuracy(l, preds, weights=mask)

    if FLAGS.surrogate_attack:
      preds = sur_ch_model.get_probs(x)
      preds = tf.squeeze(preds)
      preds = tf.argmax(preds, -1, output_type=l.dtype)
      acc_update_op = tf.tuple((acc_update_op,
                                tf.metrics.accuracy(l, preds, weights=mask)[1]))

    sess.run(tf.initialize_local_variables())
    for i in range(FLAGS.eval_steps):
      tf.logging.info(
          "\tEvaluating batch [%d / %d]" % (i + 1, FLAGS.eval_steps))
      acc = sess.run(acc_update_op)
    if FLAGS.surrogate_attack:
      tf.logging.info("\tFinal acc: (%.4f, %.4f)" % (acc[0], acc[1]))
    else:
      tf.logging.info("\tFinal acc: %.4f" % acc)
    return acc

  epsilon_acc_pairs = []
  for epsilon in attack_params.attack_epsilons:
    tf.logging.info("Attacking @ eps=%.4f" % epsilon)
    attack_params.set_hparam(attack_params.epsilon_name, epsilon)
    adv_x = attack.generate(inputs, y=one_hot_labels, **attack_params.values())
    acc = compute_accuracy(adv_x, labels, acc_mask)
    epsilon_acc_pairs.append((epsilon, acc))

  for epsilon, acc in epsilon_acc_pairs:
    if FLAGS.surrogate_attack:
      tf.logging.info(
          "Accuracy @ eps=%.4f: (%.4f, %.4f)" % (epsilon, acc[0], acc[1]))
    else:
      tf.logging.info("Accuracy @ eps=%.4f: %.4f" % (epsilon, acc))
Example #8
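Weight- and feature-distribution analysis: it restores a checkpoint, builds an accuracy graph for image problems or a BLEU graph for translation problems, and calls wei_feat_distrib.wei_sparsity_acc_tradeoff() to trade off weight sparsity against model quality, logging summaries for TensorBoard.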
def main(argv):
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()


  if FLAGS.generate_data:
    t2t_trainer.generate_data()
  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])

  # hparams = t2t_trainer.create_hparams()
  # hparams.add_hparam("data_dir", FLAGS.data_dir)
  # trainer_lib.add_problem_hparams(hparams, FLAGS.problem)
  hparams_path = os.path.join(FLAGS.output_dir, "hparams.json")
  hparams = trainer_lib.create_hparams(
      FLAGS.hparams_set, FLAGS.hparams, data_dir=FLAGS.data_dir,
      problem_name=FLAGS.problem, hparams_path=hparams_path)
  hparams.add_hparam("model_dir", FLAGS.output_dir)

  config = t2t_trainer.create_run_config(hparams)
  params = {"batch_size": hparams.batch_size}

  # add "_rev" as a hack to avoid image standardization
  problem = registry.problem(FLAGS.problem)
  input_fn = problem.make_estimator_input_fn(tf.estimator.ModeKeys.EVAL,
                                             hparams)

  sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False))

  model_fn = t2t_model.T2TModel.make_estimator_model_fn(
      FLAGS.model, hparams, use_tpu=False)

  dataset = input_fn(params, config).repeat()
  dataset_iterator = dataset.make_one_shot_iterator()
  features, labels = dataset_iterator.get_next()

  spec = model_fn(
      features,
      labels,
      tf.estimator.ModeKeys.EVAL,
      params=hparams,
      config=config)


  # write the model graph to the summary log for TensorBoard
  summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)

  # Restore weights
  saver = tf.train.Saver()
  checkpoint_path = os.path.expanduser(FLAGS.output_dir or
                                       FLAGS.checkpoint_path)
  tf.logging.info("### t2t_wei_feat_distrib.py checkpoint_path %s", checkpoint_path)
  # saver.restore(sess, tf.train.latest_checkpoint(checkpoint_path))

  # Load weights from checkpoint.
  ckpts = tf.train.get_checkpoint_state(checkpoint_path)
  ckpt = ckpts.model_checkpoint_path
  saver.restore(sess, ckpt)

  # saver.restore(sess, checkpoint_path+'/model.ckpt-1421000')
  # initialize_from_ckpt(checkpoint_path)

  # get parameter
  pruning_params = create_pruning_params()
  pruning_strategy = create_pruning_strategy(pruning_params.strategy)

  # get evaluation graph
  if 'image' in FLAGS.problem:
    acc, acc_update_op = get_eval_graph_image(spec, labels)
    tf.summary.scalar('accuracy', acc)
    # define evaluation function
    def eval_model():
      sess.run(tf.initialize_local_variables())
      for _ in range(FLAGS.eval_steps):
        acc = sess.run(acc_update_op)
      return acc
  elif 'translate' in FLAGS.problem:
    bleu_op = get_eval_graph_trans(spec, labels)
    # define evaluation function: average the BLEU score over eval_steps batches
    def eval_model():
      bleu_value = 0
      for _ in range(FLAGS.eval_steps):
        bleu = sess.run(bleu_op)
        bleu_value += bleu
      bleu_value /= FLAGS.eval_steps
      return bleu_value
  else:
    raise ValueError("Unsupported problem for evaluation: %s" % FLAGS.problem)

  # get weight distribution graph
  wei_feat_distrib.get_weight_distrib_graph(pruning_params)

  # do accuracy sparsity tradeoff for model weights
  wei_feat_distrib.wei_sparsity_acc_tradeoff(sess, eval_model, pruning_strategy, pruning_params, summary_writer)

  # save the summary
  summary_writer.close()


  sess.run(tf.initialize_local_variables())
  preds = spec.predictions["predictions"]
  pred_shape = tf.shape(preds)
  labels_shape = tf.shape(labels)
Example #9
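A lighter attack evaluation without a surrogate model: it measures clean accuracy first, then accuracy for each attack epsilon, using the "_rev" version of the problem so that images arrive unstandardized.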
def main(argv):
    tf.logging.set_verbosity(tf.logging.INFO)
    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    t2t_trainer.maybe_log_registry_and_exit()

    if FLAGS.cloud_mlengine:
        cloud_mlengine.launch()
        return

    if FLAGS.generate_data:
        t2t_trainer.generate_data()

    if cloud_mlengine.job_dir():
        FLAGS.output_dir = cloud_mlengine.job_dir()

    if argv:
        t2t_trainer.set_hparams_from_args(argv[1:])
    hparams = t2t_trainer.create_hparams()
    trainer_lib.add_problem_hparams(hparams, FLAGS.problem)
    attack_params = create_attack_params()
    attack_params.add_hparam("eps", 0.0)

    config = t2t_trainer.create_run_config(hparams)
    params = {"batch_size": hparams.batch_size}

    # add "_rev" as a hack to avoid image standardization
    problem = registry.problem(FLAGS.problem + "_rev")
    input_fn = problem.make_estimator_input_fn(tf.estimator.ModeKeys.EVAL,
                                               hparams)
    dataset = input_fn(params, config).repeat()
    features, _ = dataset.make_one_shot_iterator().get_next()
    inputs, labels = features["targets"], features["inputs"]
    inputs = tf.to_float(inputs)
    labels = tf.squeeze(labels)

    sess = tf.Session()

    model_fn = t2t_model.T2TModel.make_estimator_model_fn(
        FLAGS.model, hparams, use_tpu=FLAGS.use_tpu)
    ch_model = adv_attack_utils.T2TAttackModel(model_fn, params, config)

    acc_mask = None
    probs = ch_model.get_probs(inputs)
    if FLAGS.ignore_incorrect:
        preds = tf.argmax(probs, -1)
        preds = tf.squeeze(preds)
        acc_mask = tf.to_float(tf.equal(labels, preds))
    one_hot_labels = tf.one_hot(labels, probs.shape[-1])

    attack = create_attack(attack_params.attack)(ch_model, sess=sess)

    # Restore weights
    saver = tf.train.Saver()
    checkpoint_path = os.path.expanduser(FLAGS.output_dir
                                         or FLAGS.checkpoint_path)
    saver.restore(sess, tf.train.latest_checkpoint(checkpoint_path))

    # reuse variables
    tf.get_variable_scope().reuse_variables()

    def compute_accuracy(x, labels, mask):
        preds = ch_model.get_probs(x)
        preds = tf.squeeze(preds)
        preds = tf.argmax(preds, -1, output_type=labels.dtype)
        _, acc_update_op = tf.metrics.accuracy(labels=labels,
                                               predictions=preds,
                                               weights=mask)
        sess.run(tf.initialize_local_variables())
        for _ in range(FLAGS.eval_steps):
            acc = sess.run(acc_update_op)
        return acc

    acc = compute_accuracy(inputs, labels, acc_mask)
    epsilon_acc_pairs = [(0.0, acc)]
    for epsilon in attack_params.attack_epsilons:
        attack_params.eps = epsilon
        adv_x = attack.generate(inputs,
                                y=one_hot_labels,
                                **attack_params.values())
        acc = compute_accuracy(adv_x, labels, acc_mask)
        epsilon_acc_pairs.append((epsilon, acc))

    for epsilon, acc in epsilon_acc_pairs:
        tf.logging.info("Accuracy @ eps=%f: %f" % (epsilon, acc))
Example #10
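The same attack evaluation as Example #9, formatted with two-space indentation.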
def main(argv):
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()


  if FLAGS.cloud_mlengine:
    cloud_mlengine.launch()
    return

  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])
  hparams = t2t_trainer.create_hparams()
  trainer_lib.add_problem_hparams(hparams, FLAGS.problem)
  attack_params = create_attack_params()
  attack_params.add_hparam("eps", 0.0)

  config = t2t_trainer.create_run_config(hparams)
  params = {"batch_size": hparams.batch_size}

  # add "_rev" as a hack to avoid image standardization
  problem = registry.problem(FLAGS.problem + "_rev")
  input_fn = problem.make_estimator_input_fn(
      tf.estimator.ModeKeys.EVAL, hparams)
  dataset = input_fn(params, config).repeat()
  features, _ = dataset.make_one_shot_iterator().get_next()
  inputs, labels = features["targets"], features["inputs"]
  inputs = tf.to_float(inputs)
  labels = tf.squeeze(labels)

  sess = tf.Session()

  model_fn = t2t_model.T2TModel.make_estimator_model_fn(
      FLAGS.model, hparams, use_tpu=FLAGS.use_tpu)
  ch_model = adv_attack_utils.T2TAttackModel(model_fn, params, config)

  acc_mask = None
  probs = ch_model.get_probs(inputs)
  if FLAGS.ignore_incorrect:
    preds = tf.argmax(probs, -1)
    preds = tf.squeeze(preds)
    acc_mask = tf.to_float(tf.equal(labels, preds))
  one_hot_labels = tf.one_hot(labels, probs.shape[-1])

  attack = create_attack(attack_params.attack)(ch_model, sess=sess)

  # Restore weights
  saver = tf.train.Saver()
  checkpoint_path = os.path.expanduser(FLAGS.output_dir or
                                       FLAGS.checkpoint_path)
  saver.restore(sess, tf.train.latest_checkpoint(checkpoint_path))

  # reuse variables
  tf.get_variable_scope().reuse_variables()

  def compute_accuracy(x, labels, mask):
    preds = ch_model.get_probs(x)
    preds = tf.squeeze(preds)
    preds = tf.argmax(preds, -1, output_type=labels.dtype)
    _, acc_update_op = tf.metrics.accuracy(
        labels=labels, predictions=preds, weights=mask)
    sess.run(tf.initialize_local_variables())
    for _ in range(FLAGS.eval_steps):
      acc = sess.run(acc_update_op)
    return acc

  acc = compute_accuracy(inputs, labels, acc_mask)
  epsilon_acc_pairs = [(0.0, acc)]
  for epsilon in attack_params.attack_epsilons:
    attack_params.eps = epsilon
    adv_x = attack.generate(inputs, y=one_hot_labels, **attack_params.values())
    acc = compute_accuracy(adv_x, labels, acc_mask)
    epsilon_acc_pairs.append((epsilon, acc))

  for epsilon, acc in epsilon_acc_pairs:
    tf.logging.info("Accuracy @ eps=%f: %f" % (epsilon, acc))