Exemple #1
0
def t2t_decoder(problem_name, data_dir, decode_from_file, decode_to_file,
                checkpoint_path):
    """Decode a text file with an estimator built from the current flags.

    Args:
        problem_name: problem name, used for hparams creation and forwarded
            to `decode_from_text_file`.
        data_dir: data directory; `~` is expanded before it is used.
        decode_from_file: input file forwarded to `decode_from_text_file`.
        decode_to_file: output file; stored in the decode hparams and
            forwarded to `decode_from_text_file`.
        checkpoint_path: checkpoint forwarded to `decode_from_text_file`;
            it is also written into `FLAGS.checkpoint_path`.
    """
    trainer_lib.set_random_seed(FLAGS.random_seed)

    model_hparams = trainer_lib.create_hparams(
        FLAGS.hparams_set,
        FLAGS.hparams,
        data_dir=os.path.expanduser(data_dir),
        problem_name=problem_name)

    # Decode hparams come from the flag string, then get per-run overrides.
    decoding_hparams = decoding.decode_hparams(FLAGS.decode_hparams)
    decoding_hparams.shards = FLAGS.decode_shards
    decoding_hparams.shard_id = FLAGS.worker_id
    # In-memory decoding is on if either the flag or the hparam asks for it.
    decoding_hparams.decode_in_memory = (
        FLAGS.decode_in_memory or decoding_hparams.decode_in_memory)
    decoding_hparams.decode_to_file = decode_to_file
    decoding_hparams.decode_reference = None

    # Set the flag before the estimator is created, as the original did.
    FLAGS.checkpoint_path = checkpoint_path
    estimator = trainer_lib.create_estimator(
        FLAGS.model,
        model_hparams,
        t2t_trainer.create_run_config(model_hparams),
        decode_hparams=decoding_hparams,
        use_tpu=FLAGS.use_tpu)

    decode_from_text_file(estimator,
                          problem_name,
                          decode_from_file,
                          model_hparams,
                          decoding_hparams,
                          decode_to_file,
                          checkpoint_path=checkpoint_path)
def main(argv):
  """Trainer entry point: prepare hparams and execute the experiment schedule.

  Args:
    argv: command-line arguments; everything after argv[0] is passed to
      `set_hparams_from_args`.

  Returns:
    The value of `cloud_mlengine.launch()` when `FLAGS.cloud_mlengine` is set,
    otherwise None.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  log_registry()

  # Launching on Cloud ML Engine replaces the local run entirely.
  if FLAGS.cloud_mlengine:
    return cloud_mlengine.launch()

  if FLAGS.generate_data:
    generate_data()

  # A Cloud ML Engine job directory, when present, overrides --output_dir.
  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  if argv:
    set_hparams_from_args(argv[1:])
  hparams = create_hparams()
  if is_chief():
    save_metadata(hparams)

  with maybe_cloud_tpu():
    experiment_fn = create_experiment_fn()
    run_config = create_run_config(hparams)
    experiment = experiment_fn(run_config, hparams)
    execute_schedule(experiment)
Exemple #3
0
def main(_):
    """Export the latest checkpoint under `FLAGS.output_dir`.

    With `FLAGS.export_as_tfhub` set, the export is delegated to
    `export_as_tfhub_module`; otherwise a `tf.estimator.FinalExporter` writes
    the model into the `export` sub-directory of the checkpoint directory.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

    checkpoint_dir = os.path.expanduser(FLAGS.output_dir)

    hparams = create_hparams()
    hparams.no_data_parallelism = True  # To clear the devices
    problem = hparams.problem

    if FLAGS.export_as_tfhub:
        export_as_tfhub_module(hparams, problem, checkpoint_dir)
        return

    estimator = create_estimator(
        t2t_trainer.create_run_config(hparams), hparams)

    def serving_input_receiver_fn():
        # Bound lazily so the exporter calls it when it builds the graph.
        return problem.serving_input_fn(hparams)

    exporter = tf.estimator.FinalExporter(
        "exporter", serving_input_receiver_fn, as_text=True)

    exporter.export(
        estimator,
        os.path.join(checkpoint_dir, "export"),
        checkpoint_path=tf.train.latest_checkpoint(checkpoint_dir),
        eval_result=None,
        is_the_final_export=True)
Exemple #4
0
def main(_):
    """Score a file or run decoding, depending on `FLAGS.score_file`.

    Raises:
        ValueError: if `FLAGS.score_file` is set but the file does not exist.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

    if FLAGS.score_file:
        score_path = os.path.expanduser(FLAGS.score_file)
        if not tf.gfile.Exists(score_path):
            raise ValueError(
                "The file to score doesn't exist: %s" % score_path)
        # NOTE(review): the returned scores are not written anywhere in this
        # function; the code that wrote them to --decode_to_file was
        # disabled. Presumably score_file reports its own results -- confirm.
        score_file(score_path)
        return

    model_hparams = create_hparams()
    decoding_hparams = create_decode_hparams()
    run_config = t2t_trainer.create_run_config(model_hparams)
    if FLAGS.disable_grappler_optimizations:
        rewrite_options = (
            run_config.session_config.graph_options.rewrite_options)
        rewrite_options.disable_meta_optimizer = True

    # summary-hook in tf.estimator.EstimatorSpec requires
    # hparams.model_dir to be set.
    model_hparams.add_hparam("model_dir", run_config.model_dir)

    estimator = trainer_lib.create_estimator(
        FLAGS.model,
        model_hparams,
        run_config,
        decode_hparams=decoding_hparams,
        use_tpu=FLAGS.use_tpu)

    decode(estimator, model_hparams, decoding_hparams)
Exemple #5
0
def main(argv):
    """Distillation entry point: train the teacher (optional), then the student.

    The teacher is trained into `FLAGS.teacher_dir` (default
    `<output_dir>/teacher`) unless `FLAGS.skip_teacher_training` is set. The
    student is then trained into `FLAGS.student_dir` (default
    `<output_dir>/student`) with `distill_phase` set to "distill".

    Args:
        argv: command-line arguments; everything after argv[0] is passed to
            `t2t_trainer.set_hparams_from_args`.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    t2t_trainer.maybe_log_registry_and_exit()

    if FLAGS.cloud_mlengine:
        cloud_mlengine.launch()
        return

    if FLAGS.generate_data:
        t2t_trainer.generate_data()

    if cloud_mlengine.job_dir():
        FLAGS.output_dir = cloud_mlengine.job_dir()

    if argv:
        t2t_trainer.set_hparams_from_args(argv[1:])

    # Remember the root before FLAGS.output_dir is redirected per phase.
    root_output_dir = FLAGS.output_dir
    teacher_dir = FLAGS.teacher_dir or os.path.join(root_output_dir,
                                                    "teacher")

    # ---- Teacher phase ----
    if not FLAGS.skip_teacher_training:
        teacher_hparams = t2t_trainer.create_hparams()
        teacher_hparams.distill_phase = "train"
        FLAGS.output_dir = teacher_dir

        teacher_exp_fn = t2t_trainer.create_experiment_fn()
        teacher_run_config = t2t_trainer.create_run_config(teacher_hparams)
        teacher_exp = teacher_exp_fn(teacher_run_config, teacher_hparams)
        if t2t_trainer.is_chief():
            t2t_trainer.save_metadata(teacher_hparams)
        t2t_trainer.execute_schedule(teacher_exp)
    else:
        tf.logging.info("training teacher skipped")

    # ---- Student phase ----
    student_hparams = t2t_trainer.create_hparams()
    student_hparams.add_hparam("teacher_dir", teacher_dir)
    student_hparams.distill_phase = "distill"
    student_dir = FLAGS.student_dir or os.path.join(root_output_dir,
                                                    "student")
    FLAGS.output_dir = student_dir
    student_hparams.add_hparam("student_dir", student_dir)

    student_exp_fn = t2t_trainer.create_experiment_fn()
    student_run_config = t2t_trainer.create_run_config(student_hparams)
    student_exp = student_exp_fn(student_run_config, student_hparams)

    if t2t_trainer.is_chief():
        t2t_trainer.save_metadata(student_hparams)
    t2t_trainer.execute_schedule(student_exp)
Exemple #6
0
def create_student_experiment(run_config, hparams, argv):
    """Build the experiment for the student (distill) phase.

    Args:
        run_config: run configuration handed to the experiment function.
        hparams: hparams object; `teacher_dir`, `student_dir` and
            `distill_phase` are set on it in place.
        argv: command-line arguments; everything after argv[0] is passed to
            `t2t_trainer.set_hparams_from_args`.

    Returns:
        The experiment built by `t2t_trainer.create_experiment_fn()`, or the
        value of `cloud_mlengine.launch()` when `FLAGS.cloud_mlengine` is set.
    """
    tf.logging.info("training student")
    tf.logging.set_verbosity(tf.logging.INFO)
    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    t2t_trainer.maybe_log_registry_and_exit()

    if FLAGS.cloud_mlengine:
        return cloud_mlengine.launch()

    if FLAGS.generate_data:
        t2t_trainer.generate_data()

    if cloud_mlengine.job_dir():
        FLAGS.output_dir = cloud_mlengine.job_dir()

    if argv:
        t2t_trainer.set_hparams_from_args(argv[1:])

    # Record both directories on the hparams, then mark the phase.
    for hparam_name, flag_value in (("teacher_dir", FLAGS.teacher_dir),
                                    ("student_dir", FLAGS.student_dir)):
        hparams.add_hparam(hparam_name, flag_value)
    hparams.distill_phase = "distill"

    experiment_fn = t2t_trainer.create_experiment_fn()
    return experiment_fn(run_config, hparams)
Exemple #7
0
def create_teacher_experiment(run_config, hparams, argv):
  """Build the experiment for the teacher (train) phase.

  Args:
    run_config: run configuration handed to the experiment function.
    hparams: hparams object; its `distill_phase` is set to "train" in place.
    argv: command-line arguments; everything after argv[0] is passed to
      `t2t_trainer.set_hparams_from_args`.

  Returns:
    The experiment built by `t2t_trainer.create_experiment_fn()`, or the value
    of `cloud_mlengine.launch()` when `FLAGS.cloud_mlengine` is set.
  """
  tf.logging.info("training teacher")
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()

  if FLAGS.cloud_mlengine:
    return cloud_mlengine.launch()

  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])

  # The experiment is created inside the (optional) cloud TPU context.
  with t2t_trainer.maybe_cloud_tpu():
    hparams.distill_phase = "train"
    experiment_fn = t2t_trainer.create_experiment_fn()
    return experiment_fn(run_config, hparams)
Exemple #8
0
def main(_):
  """Score a file into --decode_to_file, or decode with a T2T estimator.

  When `FLAGS.score_file` is set, every score returned by `score_file` is
  written to `FLAGS.decode_to_file`, one "%.6f" value per line, and the
  function returns. Otherwise an estimator is built from the flags and
  `decode` is run.

  Raises:
    ValueError: if the file to score does not exist, or if scoring was
      requested without `--decode_to_file`.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

  if FLAGS.score_file:
    filename = os.path.expanduser(FLAGS.score_file)
    if not tf.gfile.Exists(filename):
      raise ValueError("The file to score doesn't exist: %s" % filename)
    # Check the output flag before scoring so a missing flag fails fast
    # instead of discarding the scoring work.
    if not FLAGS.decode_to_file:
      raise ValueError("To score a file, specify --decode_to_file for results.")
    results = score_file(filename)
    # The context manager closes the file even if a write fails.
    with tf.gfile.Open(os.path.expanduser(FLAGS.decode_to_file),
                       "w") as write_file:
      for score in results:
        write_file.write("%.6f\n" % score)
    return

  hp = create_hparams()
  decode_hp = create_decode_hparams()

  estimator = trainer_lib.create_estimator(
      FLAGS.model,
      hp,
      t2t_trainer.create_run_config(hp),
      decode_hparams=decode_hp,
      use_tpu=FLAGS.use_tpu)

  decode(estimator, hp, decode_hp)
Exemple #9
0
def main(_):
    """Export a checkpoint as a SavedModel or as a TF-Hub module.

    The checkpoint is `FLAGS.checkpoint_path` when given, otherwise the latest
    checkpoint in `FLAGS.output_dir`. The export goes to `FLAGS.export_dir`,
    defaulting to the `export` sub-directory of the checkpoint directory.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

    if not FLAGS.checkpoint_path:
        ckpt_dir = os.path.expanduser(FLAGS.output_dir)
        checkpoint_path = tf.train.latest_checkpoint(ckpt_dir)
    else:
        checkpoint_path = FLAGS.checkpoint_path
        ckpt_dir = os.path.dirname(checkpoint_path)

    hparams = create_hparams()
    hparams.no_data_parallelism = True  # To clear the devices
    problem = hparams.problem

    export_dir = FLAGS.export_dir
    if not export_dir:
        export_dir = os.path.join(ckpt_dir, "export")

    if FLAGS.export_as_tfhub:
        # NOTE(review): this always takes the latest checkpoint in ckpt_dir,
        # even when --checkpoint_path named a specific one -- confirm that
        # is intended.
        checkpoint_path = tf.train.latest_checkpoint(ckpt_dir)
        decode_hparams = decoding.decode_hparams(FLAGS.decode_hparams)
        export_as_tfhub_module(FLAGS.model, hparams, decode_hparams, problem,
                               checkpoint_path, export_dir)
        return

    estimator = create_estimator(
        t2t_trainer.create_run_config(hparams), hparams)

    def serving_input_receiver_fn():
        return problem.serving_input_fn(hparams)

    estimator.export_savedmodel(export_dir,
                                serving_input_receiver_fn,
                                as_text=False,
                                checkpoint_path=checkpoint_path)
Exemple #10
0
def main(argv):
  """Trainer entry point with an optional v2 code path.

  With `FLAGS.v2` set, the v1 flags are translated into a config string and
  training is delegated to `t2t_v2.t2t_train`. Otherwise hparams are created
  and the experiment schedule is executed.

  Args:
    argv: command-line arguments; everything after argv[0] is passed to
      `set_hparams_from_args` (v1 path only).
  """
  tf.logging.set_verbosity(tf.logging.INFO)

  if FLAGS.v2:
    tf.enable_v2_behavior()
    # Hacking main v1 flags to work with v2.
    config_strs = [
        "train_fn.train_steps=" + str(FLAGS.train_steps),
        "train_fn.eval_steps=" + str(FLAGS.eval_steps),
        "train_fn.eval_frequency=" + str(FLAGS.local_eval_frequency),
    ]
    if FLAGS.hparams:
      config_strs += str(FLAGS.hparams).split(",")
    t2t_v2.t2t_train(FLAGS.model, FLAGS.problem,
                     data_dir=os.path.expanduser(FLAGS.data_dir),
                     output_dir=os.path.expanduser(FLAGS.output_dir),
                     config_file=FLAGS.hparams_set,
                     config="\n".join(config_strs))
    return

  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

  # Printing the registry, when requested, ends the run here.
  maybe_log_registry_and_exit()

  # Build the hparams, applying command-line overrides first.
  if argv:
    set_hparams_from_args(argv[1:])
  hparams = create_hparams()

  schedule = FLAGS.schedule
  if schedule in ("train", "train_eval_and_decode"):
    mlperf_log.transformer_print(key=mlperf_log.RUN_START, hparams=hparams)
  if schedule == "run_std_server":
    run_std_server()
  mlperf_log.transformer_print(
      key=mlperf_log.RUN_SET_RANDOM_SEED,
      value=FLAGS.random_seed,
      hparams=hparams)
  trainer_lib.set_random_seed(FLAGS.random_seed)

  if FLAGS.cloud_mlengine:
    cloud_mlengine.launch()
    return

  if FLAGS.generate_data:
    generate_data()

  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  experiment_fn = create_experiment_fn()
  run_config = create_run_config(hparams)
  experiment = experiment_fn(run_config, hparams)
  if is_chief():
    save_metadata(hparams)
  execute_schedule(experiment)
  if FLAGS.schedule != "train":
    mlperf_log.transformer_print(key=mlperf_log.RUN_FINAL, hparams=hparams)
Exemple #11
0
def main(_):
    """Build reward confusion matrices from model predictions and print them.

    Batches of the PREDICT split are fed through `model.infer`; the predicted
    and ground-truth `target_reward` values are accumulated into a per-frame
    and a next-frame confusion matrix.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

    # Create hparams
    hparams = create_hparams()
    hparams.force_full_predict = True
    batch_size = hparams.batch_size

    # Iterating over dev/test partition of the data.
    # Change the data partition if necessary.
    predict_mode = tf.estimator.ModeKeys.PREDICT
    problem = registry.problem(FLAGS.problem)
    dataset = problem.dataset(
        predict_mode, shuffle_files=False, hparams=hparams)
    dataset = dataset.apply(
        tf.contrib.data.batch_and_drop_remainder(batch_size))
    data = dataset.make_one_shot_iterator().get_next()
    # Only features whose key starts with "input" are given to the model.
    input_data = {
        key: data[key] for key in data.keys() if key.startswith("input")
    }

    # Create model
    model = registry.model(FLAGS.model)(hparams, predict_mode)
    prediction_ops = model.infer(input_data)

    # Confusion matrices, indexed as [ground truth, prediction].
    num_rewards = hparams.problem.num_rewards
    matrix_shape = (num_rewards, num_rewards)
    cm_per_frame = np.zeros(matrix_shape, dtype=np.uint64)
    cm_next_frame = np.zeros(matrix_shape, dtype=np.uint64)

    saver = tf.train.Saver()
    with tf.train.SingularMonitoredSession() as sess:
        # Load latest checkpoint
        checkpoint_state = tf.train.get_checkpoint_state(FLAGS.output_dir)
        saver.restore(sess.raw_session(),
                      checkpoint_state.model_checkpoint_path)

        fetches = [prediction_ops["target_reward"], data["target_reward"]]
        log_every = 1  # Progress is printed on every step.
        step = 0
        while not sess.should_stop():
            step += 1
            if step % log_every == 0:
                print(step)

            # Predict next frames
            predicted, truth = sess.run(fetches)

            for example in range(batch_size):
                cm_next_frame[truth[example, 0, 0],
                              predicted[example, 0, 0]] += 1
                for frame_truth, frame_pred in zip(truth[example],
                                                   predicted[example]):
                    cm_per_frame[frame_truth, frame_pred] += 1

    print_confusion_matrix("Per-frame Confusion Matrix", cm_per_frame)
    print_confusion_matrix("Next-frame Confusion Matrix", cm_next_frame)
def main(_):
  """Accumulate and print reward confusion matrices for a trained model.

  The model is restored from the checkpoint state in `FLAGS.output_dir` and
  run over batches of the PREDICT split. Predicted and ground-truth
  `target_reward` values are counted per frame and for the first frame.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

  # Create hparams
  hparams = create_hparams()
  hparams.force_full_predict = True
  batch_size = hparams.batch_size

  # Iterating over dev/test partition of the data.
  # Change the data partition if necessary.
  dataset = registry.problem(FLAGS.problem).dataset(
      tf.estimator.ModeKeys.PREDICT, shuffle_files=False, hparams=hparams)
  batching = tf.contrib.data.batch_and_drop_remainder(batch_size)
  data = dataset.apply(batching).make_one_shot_iterator().get_next()

  # Keep only the features whose key starts with "input".
  input_data = {}
  for feature_name in data.keys():
    if feature_name.startswith("input"):
      input_data[feature_name] = data[feature_name]

  # Create model
  model_cls = registry.model(FLAGS.model)
  prediction_ops = model_cls(
      hparams, tf.estimator.ModeKeys.PREDICT).infer(input_data)

  # Confusion matrices, indexed as [ground truth, prediction].
  nr = hparams.problem.num_rewards
  cm_per_frame = np.zeros((nr, nr), dtype=np.uint64)
  cm_next_frame = np.zeros((nr, nr), dtype=np.uint64)

  saver = tf.train.Saver()
  with tf.train.SingularMonitoredSession() as sess:
    # Load latest checkpoint
    latest = tf.train.get_checkpoint_state(
        FLAGS.output_dir).model_checkpoint_path
    saver.restore(sess.raw_session(), latest)

    print_interval = 1  # Every batch index is printed.
    batch_index = 0
    while not sess.should_stop():
      batch_index += 1
      if batch_index % print_interval == 0:
        print(batch_index)

      # Predict next frames
      rew_pred, rew_true = sess.run(
          [prediction_ops["target_reward"], data["target_reward"]])

      for row in range(batch_size):
        cm_next_frame[rew_true[row, 0, 0], rew_pred[row, 0, 0]] += 1
        for true_value, pred_value in zip(rew_true[row], rew_pred[row]):
          cm_per_frame[true_value, pred_value] += 1

  print_confusion_matrix("Per-frame Confusion Matrix", cm_per_frame)
  print_confusion_matrix("Next-frame Confusion Matrix", cm_next_frame)
def main(_):
    """Score a file, or decode one test shard unless its output exists.

    Scoring mode (`FLAGS.score_file` set): the scores returned by
    `score_file` are written to `FLAGS.decode_to_file`, one "%.6f" value per
    line.

    Decoding mode: `FLAGS.checkpoint_path` is set from `FLAGS.global_steps`
    (or to the latest checkpoint in `FLAGS.model_dir`). If the decode
    directory already holds a file whose name before the first "." equals
    "<global_steps><split><shard:03d>", the shard is skipped; otherwise the
    shard is decoded.

    Raises:
        ValueError: if the file to score does not exist, or if scoring was
            requested without `--decode_to_file`.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

    if FLAGS.score_file:
        filename = os.path.expanduser(FLAGS.score_file)
        if not tf.gfile.Exists(filename):
            raise ValueError("The file to score doesn't exist: %s" % filename)
        # Check the output flag before scoring so a missing flag fails fast
        # instead of discarding the scoring work.
        if not FLAGS.decode_to_file:
            raise ValueError(
                "To score a file, specify --decode_to_file for results.")
        results = score_file(filename)
        # The context manager closes the file even if a write fails.
        with tf.gfile.Open(os.path.expanduser(FLAGS.decode_to_file),
                           "w") as write_file:
            for score in results:
                write_file.write("%.6f\n" % score)
        return

    hp = create_hparams()

    if FLAGS.global_steps:
        FLAGS.checkpoint_path = os.path.join(
            FLAGS.model_dir, f"model.ckpt-{FLAGS.global_steps}")
    else:
        FLAGS.checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)

    # Check if already exists
    dataset_split = "test" if FLAGS.split == "test" else "dev"
    default_decode_path = os.path.join(FLAGS.model_dir,
                                       "decode_00000")  # default decoded_to_file
    decode_path = FLAGS.decode_to_file or default_decode_path
    if os.path.isdir(decode_path):
        # The expected name does not depend on the directory entry, so it is
        # built once rather than on every loop iteration.
        file_name_to_be = (
            f"{FLAGS.global_steps}{dataset_split}{FLAGS.test_shard:03d}")
        existing_names = (
            entry.split(".")[0] for entry in os.listdir(decode_path))
        if file_name_to_be in existing_names:
            print(f"Already {file_name_to_be} exists")
            return

    tf.reset_default_graph()
    decode_hp = create_decode_hparams(decode_path, FLAGS.test_shard)
    estimator = trainer_lib.create_estimator(FLAGS.model,
                                             hp,
                                             create_run_config(hp),
                                             decode_hparams=decode_hp,
                                             use_tpu=FLAGS.use_tpu)
    decode(estimator, hp, decode_hp)
    print("shard " + str(FLAGS.test_shard) + " completed")
Exemple #14
0
def main(argv):
  """Restore a trained model and sparsify it with the configured strategy.

  An EVAL-mode model graph is built on top of a repeating eval dataset,
  weights are restored from the latest checkpoint, and
  `pruning_utils.sparsify` is driven with an accuracy-evaluation callback.

  Args:
    argv: command-line arguments; everything after argv[0] is passed to
      `t2t_trainer.set_hparams_from_args`.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()

  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])
  hparams = t2t_trainer.create_hparams()
  trainer_lib.add_problem_hparams(hparams, FLAGS.problem)
  pruning_params = create_pruning_params()
  pruning_strategy = create_pruning_strategy(pruning_params.strategy)

  run_config = t2t_trainer.create_run_config(hparams)
  input_params = {"batch_size": hparams.batch_size}

  # NOTE(review): the original comment mentioned adding "_rev" to the problem
  # name as a hack to avoid image standardization; no suffix is added here.
  eval_mode = tf.estimator.ModeKeys.EVAL
  problem = registry.problem(FLAGS.problem)
  input_fn = problem.make_estimator_input_fn(eval_mode, hparams)
  eval_dataset = input_fn(input_params, run_config).repeat()
  features, labels = eval_dataset.make_one_shot_iterator().get_next()

  sess = tf.Session()

  model_fn = t2t_model.T2TModel.make_estimator_model_fn(
      FLAGS.model, hparams, use_tpu=FLAGS.use_tpu)
  spec = model_fn(
      features, labels, eval_mode, params=hparams, config=run_config)

  # Restore weights from the latest checkpoint in the given directory.
  saver = tf.train.Saver()
  checkpoint_dir = os.path.expanduser(
      FLAGS.output_dir or FLAGS.checkpoint_path)
  saver.restore(sess, tf.train.latest_checkpoint(checkpoint_dir))

  def eval_model():
    """Run FLAGS.eval_steps accuracy updates and return the last value."""
    raw_predictions = spec.predictions["predictions"]
    predicted_labels = tf.argmax(
        raw_predictions, -1, output_type=labels.dtype)
    _, acc_update_op = tf.metrics.accuracy(
        labels=labels, predictions=predicted_labels)
    # The accuracy metric keeps its counters in local variables.
    sess.run(tf.initialize_local_variables())
    for _ in range(FLAGS.eval_steps):
      acc = sess.run(acc_update_op)
    return acc

  pruning_utils.sparsify(sess, eval_model, pruning_strategy, pruning_params)
def main(_):
  """Score a file into --decode_to_file, or decode with a T2T estimator.

  When `FLAGS.score_file` is set, every score returned by `score_file` is
  written to `FLAGS.decode_to_file`, one "%.6f" value per line, and the
  function returns. Otherwise an estimator is built from the flags and
  `decode` is run.

  Raises:
    ValueError: if the file to score does not exist, or if scoring was
      requested without `--decode_to_file`.
  """
  # The unused `import ipdb` debugging leftover was removed: it made this
  # entry point fail wherever the debugger package is not installed.
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

  if FLAGS.score_file:
    filename = os.path.expanduser(FLAGS.score_file)
    if not tf.gfile.Exists(filename):
      raise ValueError("The file to score doesn't exist: %s" % filename)
    # Check the output flag before scoring so a missing flag fails fast
    # instead of discarding the scoring work.
    if not FLAGS.decode_to_file:
      raise ValueError("To score a file, specify --decode_to_file for results.")
    results = score_file(filename)
    # The context manager closes the file even if a write fails.
    with tf.gfile.Open(os.path.expanduser(FLAGS.decode_to_file),
                       "w") as write_file:
      for score in results:
        write_file.write("%.6f\n" % score)
    return

  hp = create_hparams()
  decode_hp = create_decode_hparams()

  run_config = t2t_trainer.create_run_config(hp)
  if FLAGS.disable_grappler_optimizations:
    run_config.session_config.graph_options.rewrite_options.disable_meta_optimizer = True

  # summary-hook in tf.estimator.EstimatorSpec requires
  # hparams.model_dir to be set.
  hp.add_hparam("model_dir", run_config.model_dir)

  estimator = trainer_lib.create_estimator(
      FLAGS.model,
      hp,
      run_config,
      decode_hparams=decode_hp,
      use_tpu=FLAGS.use_tpu)

  decode(estimator, hp, decode_hp)
Exemple #16
0
def main(argv):
  """Distillation entry point: train the teacher, then distill the student.

  Both phases run inside `t2t_trainer.maybe_cloud_tpu()`. The teacher trains
  into `<output_dir>/teacher` and the student into `<output_dir>/student`.

  Args:
    argv: command-line arguments; everything after argv[0] is passed to
      `t2t_trainer.set_hparams_from_args`.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()

  if FLAGS.cloud_mlengine:
    cloud_mlengine.launch()
    return

  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])

  with t2t_trainer.maybe_cloud_tpu():
    # Remember the root before FLAGS.output_dir is redirected per phase.
    root_output_dir = FLAGS.output_dir

    # ---- Teacher phase ----
    teacher_hparams = t2t_trainer.create_hparams()
    teacher_hparams.distill_phase = "train"
    teacher_dir = os.path.join(root_output_dir, "teacher")
    FLAGS.output_dir = teacher_dir

    teacher_exp_fn = t2t_trainer.create_experiment_fn()
    teacher_run_config = t2t_trainer.create_run_config(teacher_hparams)
    teacher_exp = teacher_exp_fn(teacher_run_config, teacher_hparams)
    if t2t_trainer.is_chief():
      t2t_trainer.save_metadata(teacher_hparams)
    t2t_trainer.execute_schedule(teacher_exp)

    # ---- Student phase ----
    student_hparams = t2t_trainer.create_hparams()
    student_hparams.add_hparam("teacher_dir", teacher_dir)
    student_hparams.distill_phase = "distill"
    FLAGS.output_dir = os.path.join(root_output_dir, "student")

    student_exp_fn = t2t_trainer.create_experiment_fn()
    student_run_config = t2t_trainer.create_run_config(student_hparams)
    student_exp = student_exp_fn(student_run_config, student_hparams)

    if t2t_trainer.is_chief():
      t2t_trainer.save_metadata(student_hparams)
    t2t_trainer.execute_schedule(student_exp)
def main(_):
    """Score a file or decode, then optionally write a prefixed copy.

    Scoring mode (`FLAGS.score_file` set): the scores returned by
    `score_file` are written to `FLAGS.decode_to_file`, one "%.6f" value per
    line, and the function returns.

    Decoding mode: an estimator is built and `decode` is run. If both
    `FLAGS.decode_to_file` and `FLAGS.output_line_prefix_tag` are set, a
    second file named "<decode_to_file>-<tag>" is written in which every
    decoded line is prefixed with "<tag> ".

    Raises:
        ValueError: if the file to score does not exist, or if scoring was
            requested without `--decode_to_file`.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

    if FLAGS.score_file:
        filename = os.path.expanduser(FLAGS.score_file)
        if not tf.gfile.Exists(filename):
            raise ValueError("The file to score doesn't exist: %s" % filename)
        # Check the output flag before scoring so a missing flag fails fast
        # instead of discarding the scoring work.
        if not FLAGS.decode_to_file:
            raise ValueError(
                "To score a file, specify --decode_to_file for results.")
        results = score_file(filename)
        # The context manager closes the file even if a write fails.
        with tf.gfile.Open(os.path.expanduser(FLAGS.decode_to_file),
                           "w") as write_file:
            for score in results:
                write_file.write("%.6f\n" % score)
        return

    hp = create_hparams()
    decode_hp = create_decode_hparams()

    estimator = trainer_lib.create_estimator(FLAGS.model,
                                             hp,
                                             t2t_trainer.create_run_config(hp),
                                             decode_hparams=decode_hp,
                                             use_tpu=FLAGS.use_tpu)

    decode(estimator, hp, decode_hp)

    # Post-process decodings (if necessary).
    if FLAGS.decode_to_file and FLAGS.output_line_prefix_tag:
        decode_filename_original = FLAGS.decode_to_file
        decode_filename_prefixed = "%s-%s" % (decode_filename_original,
                                              FLAGS.output_line_prefix_tag)
        tf.logging.info("Writing prefexed decodes into %s" %
                        decode_filename_prefixed)
        # Read original lines.
        with tf.gfile.Open(decode_filename_original, "r") as original_fp:
            original_lines = original_fp.readlines()
        # Write prefixed lines; the file is closed (and so flushed) when the
        # block exits, matching how the original file is read above.
        prefix = "<%s> " % FLAGS.output_line_prefix_tag
        with tf.gfile.Open(decode_filename_prefixed, "w") as prefixed_fp:
            for line in original_lines:
                prefixed_fp.write(prefix + line)
        tf.logging.info("Done.")
Exemple #18
0
def main(_):
    """Build an estimator from the flags and run `decode` with it."""
    tf.logging.set_verbosity(tf.logging.INFO)
    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

    # Model and decode hparams both come from the t2t_decoder helpers.
    model_hparams = t2t_decoder.create_hparams()
    decoding_hparams = t2t_decoder.create_decode_hparams()

    estimator = trainer_lib.create_estimator(
        FLAGS.model,
        model_hparams,
        t2t_trainer.create_run_config(model_hparams),
        decode_hparams=decoding_hparams,
        use_tpu=FLAGS.use_tpu)

    decode(estimator, model_hparams, decoding_hparams)
Exemple #19
0
def main(_):
    """Score a file or decode, then echo run details for XCOM ingestion.

    Scoring mode (`FLAGS.score_file` set): the scores returned by
    `score_file` are written to `FLAGS.decode_to_file`, one "%.6f" value per
    line, and the function returns without echoing anything.

    Decoding mode: an estimator is built, `decode` is run, and the checkpoint
    directory, the decode output file and the truncation boundary are passed
    to `echo_yaml_for_xcom_ingest`.

    Raises:
        ValueError: if the file to score does not exist, or if scoring was
            requested without `--decode_to_file`.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    trainer_lib.set_random_seed(FLAGS.random_seed)
    # Fathom start
    checkpoint_path = fathom_t2t_model_setup()
    # Fathom end
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

    if FLAGS.score_file:
        filename = os.path.expanduser(FLAGS.score_file)
        if not tf.gfile.Exists(filename):
            raise ValueError("The file to score doesn't exist: %s" % filename)
        # Check the output flag before scoring so a missing flag fails fast
        # instead of discarding the scoring work.
        if not FLAGS.decode_to_file:
            raise ValueError(
                "To score a file, specify --decode_to_file for results.")
        results = score_file(filename)
        # The context manager closes the file even if a write fails.
        with tf.gfile.Open(os.path.expanduser(FLAGS.decode_to_file),
                           "w") as write_file:
            for score in results:
                write_file.write("%.6f\n" % score)
        return

    hp = create_hparams()
    decode_hp = create_decode_hparams()

    estimator = trainer_lib.create_estimator(FLAGS.model,
                                             hp,
                                             t2t_trainer.create_run_config(hp),
                                             decode_hparams=decode_hp,
                                             use_tpu=FLAGS.use_tpu)

    decode(estimator, hp, decode_hp)

    # Fathom
    # This xcom is here so that tasks after decode know the local path to the
    # downloaded model. Train does this same xcom echo.
    # Decode, predict, and evaluate code should
    # converge to use the same fathom_t2t_model_setup.
    # TODO: since the truncation-boundary xcom value should be available in
    #  the hparams_set, we should probably have consumers access this via a
    #  SavedModel.hparams property rather than XCOM
    echo_yaml_for_xcom_ingest({
        'output-dir': os.path.dirname(checkpoint_path),
        'output-file': FLAGS.decode_output_file,
        'truncation-boundary': hp.max_input_seq_length
    })
Exemple #20
0
def main(argv):
    """Fathom trainer entry point: set up, run the schedule, then clean up.

    Args:
        argv: command-line arguments; everything after argv[0] is passed to
            `set_hparams_from_args`.

    Raises:
        AssertionError: if `FLAGS.cloud_mlengine` is set; Cloud ML Engine is
            not supported here.
    """
    # Fathom
    if FLAGS.fathom:
        fathom.t2t_trainer_setup(FLAGS.problem)

    # This exits if a checkpoint is set but not found.
    # Only takes action for an eval task.
    if FLAGS.schedule == 'evaluate':
        fathom.exit_if_no_eval_checkpoint_found(FLAGS.eval_checkpoint_path)

    tf.logging.set_verbosity(tf.logging.INFO)
    if FLAGS.schedule == "run_std_server":
        run_std_server()
    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    maybe_log_registry_and_exit()

    if FLAGS.cloud_mlengine:
        # Fathom
        # Raised explicitly instead of `assert False`: assert statements are
        # stripped under `python -O`, which would let the unsupported launch
        # run. The exception type is unchanged.
        raise AssertionError('No cloudml support currently')

    if FLAGS.generate_data:
        generate_data()

    # Fathom commented out
    # if cloud_mlengine.job_dir():
    #   FLAGS.output_dir = cloud_mlengine.job_dir()

    if argv:
        set_hparams_from_args(argv[1:])
    hparams = create_hparams()

    # Fathom
    hparams = fathom.adjust_params(hparams)

    exp_fn = create_experiment_fn()
    exp = exp_fn(create_run_config(hparams), hparams)
    if is_chief():
        save_metadata(hparams)
    execute_schedule(exp)

    # Fathom
    # NOTE: this must run LAST in the process, to make sure STDOUT is
    # appropriately populated.
    fathom.t2t_trainer_cleanup()
Exemple #21
0
def main(argv):
  """Trainer entry point with optional GPU automatic mixed precision.

  For the "run_std_server" schedule only the standard server is run. For all
  other schedules hparams are created, the experiment is built and the
  schedule is executed.

  Args:
    argv: command-line arguments; everything after argv[0] is passed to
      `set_hparams_from_args`.
  """
  tf.logging.set_verbosity(tf.logging.INFO)

  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

  # If we just have to print the registry, do that and exit early.
  maybe_log_registry_and_exit()

  # Create HParams.
  if argv:
    set_hparams_from_args(argv[1:])

  if FLAGS.schedule == "run_std_server":
    # No hparams are created for this schedule. Previously the code below
    # still touched `hparams` (e.g. with --gpu_automatic_mixed_precision)
    # and failed with UnboundLocalError, so this schedule is handled here.
    # NOTE(review): run_std_server presumably blocks for the lifetime of the
    # server; the return covers the case where it does come back.
    run_std_server()
    return

  hparams = create_hparams()
  if FLAGS.gpu_automatic_mixed_precision:
    setattr(hparams, "gpu_automatic_mixed_precision", True)

  if FLAGS.schedule == "train" or FLAGS.schedule == "train_eval_and_decode":
    mlperf_log.transformer_print(key=mlperf_log.RUN_START, hparams=hparams)
  mlperf_log.transformer_print(
      key=mlperf_log.RUN_SET_RANDOM_SEED, value=FLAGS.random_seed,
      hparams=hparams)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  # Dump every flag and its value.
  for flag, val in FLAGS.__flags.items():
    print(flag, ": ", val.value)

  if FLAGS.cloud_mlengine:
    cloud_mlengine.launch()
    return

  if FLAGS.generate_data:
    generate_data()

  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  exp_fn = create_experiment_fn()
  exp = exp_fn(create_run_config(hparams), hparams)
  if is_chief():
    save_metadata(hparams)
  execute_schedule(exp)
  if FLAGS.schedule != "train":
    mlperf_log.transformer_print(key=mlperf_log.RUN_FINAL,
                                 hparams=hparams)
Exemple #22
0
def main(_):
    """Export the latest checkpoint using the problem's export strategy.

    The export is written under `<output_dir>/export/<strategy name>`.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

    checkpoint_dir = os.path.expanduser(FLAGS.output_dir)

    hparams = create_hparams()
    hparams.no_data_parallelism = True  # To clear the devices
    estimator = create_estimator(
        t2t_trainer.create_run_config(hparams), hparams)

    export_strategy = trainer_lib.create_export_strategy(
        hparams.problem, hparams)

    export_strategy.export(
        estimator,
        os.path.join(checkpoint_dir, "export", export_strategy.name),
        checkpoint_path=tf.train.latest_checkpoint(checkpoint_dir))
Exemple #23
0
def main(argv):
    """Train a model with Horovod.

    Initialises Horovod, performs the usual trainer setup (seed, user dir,
    registry logging, optional cloud launch / data generation), builds the
    estimator and trains it for ``FLAGS.train_steps`` steps with a
    ``BroadcastGlobalVariablesHook(0)`` attached.

    Args:
        argv: Command-line arguments; everything after the first element is
            passed to ``set_hparams_from_args``.

    Returns:
        The result of ``cloud_mlengine.launch()`` when ``FLAGS.cloud_mlengine``
        is set, otherwise ``None``.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    hvd.init()

    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    log_registry()

    if FLAGS.cloud_mlengine:
        return cloud_mlengine.launch()

    if FLAGS.generate_data:
        generate_data()

    # A non-empty job_dir flag, when defined, overrides the output directory.
    if hasattr(FLAGS, "job_dir") and FLAGS.job_dir:
        FLAGS.output_dir = FLAGS.job_dir

    if argv:
        set_hparams_from_args(argv[1:])

    hp = create_hparams()

    if is_chief():
        save_metadata(hp)

    # Original author's note: create_run_config calls
    # trainer_lib.create_session_config, which includes the gpu_options
    # initialisation.
    run_config = create_run_config(hp)
    decode_hp = decoding.decode_hparams(FLAGS.decode_hparams)
    chosen_schedule = FLAGS.schedule

    model_estimator = create_estimator_fn(
        FLAGS.model, hp, run_config, chosen_schedule, decode_hp)

    broadcast_hook = hvd.BroadcastGlobalVariablesHook(0)

    model_estimator.train(
        input_fn=train_input_fn(hp),
        steps=FLAGS.train_steps,
        hooks=[broadcast_hook],
    )
Exemple #24
0
    def entry(self):
        """Run one decode pass and return its result.

        Performs the common setup (log verbosity, random seed, user dir),
        echoes a fixed set of flags to stdout, then builds hparams, decode
        hparams and an estimator through this object's helpers and decodes
        with them.

        Returns:
            Whatever ``self.my_decode`` returns.
        """
        tf.logging.set_verbosity(tf.logging.INFO)
        trainer_lib.set_random_seed(FLAGS.random_seed)
        usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

        print("###self defined hp###")
        # Echo the flags one by one, in the same order as before.
        for flag_name in ("data_dir", "problem", "model", "hparams_set",
                          "output_dir", "decode_hparams"):
            print(str(getattr(FLAGS, flag_name)))

        model_hp = self.create_hparams()
        decode_settings = self.create_decode_hparams()
        model_estimator = self.create_new_estimator(model_hp, decode_settings)

        result = self.my_decode(model_estimator, model_hp, decode_settings)
        print('output decode-res  = %s ' % str(result))
        return result
Exemple #25
0
def entry(input_str):
  """Decode ``input_str`` with the objects cached in ``app.config``.

  Performs the common setup (log verbosity, random seed, user dir), echoes a
  fixed set of flags to stdout, then reads the hparams, decode hparams and
  estimator stored under the ``'hp'``, ``'decode_hp'`` and ``'estimator'``
  keys of ``app.config`` and hands them to ``my_decode``.

  Args:
    input_str: Input forwarded unchanged to ``my_decode``.

  Returns:
    Whatever ``my_decode`` returns.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

  print("###self defined hp###")
  for flag_name in ("data_dir", "problem", "model", "hparams_set",
                    "output_dir", "decode_hparams"):
    print(str(getattr(FLAGS, flag_name)))

  # The model objects are not built here; they are looked up from app.config
  # on every call (an earlier lazily-initialised-globals variant was removed).
  cached_hp = app.config['hp']
  cached_decode_hp = app.config['decode_hp']
  cached_estimator = app.config['estimator']

  result = my_decode(cached_estimator, cached_hp, cached_decode_hp, input_str)
  print('output-decode-res  = %s ' % str(result))
  return result
Exemple #26
0
def main(_):
  """Export a trained model as a SavedModel or as a TF-Hub module.

  The checkpoint to export is ``FLAGS.checkpoint_path`` when given, otherwise
  the latest checkpoint in ``FLAGS.output_dir``. The export directory is
  ``FLAGS.export_dir`` or, if unset, ``<checkpoint dir>/export``.

  With ``FLAGS.export_as_tfhub`` the model is exported through
  ``export_as_tfhub_module`` and the function returns; otherwise an estimator
  is built and exported with ``tf.estimator.FinalExporter``.

  Args:
    _: Unused positional argument.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

  if FLAGS.checkpoint_path:
    checkpoint_path = FLAGS.checkpoint_path
    ckpt_dir = os.path.dirname(checkpoint_path)
  else:
    ckpt_dir = os.path.expanduser(FLAGS.output_dir)
    checkpoint_path = tf.train.latest_checkpoint(ckpt_dir)

  hparams = create_hparams()
  hparams.no_data_parallelism = True  # To clear the devices
  problem = hparams.problem

  export_dir = FLAGS.export_dir or os.path.join(ckpt_dir, "export")

  if FLAGS.export_as_tfhub:
    # Use the checkpoint resolved above. Re-resolving the latest checkpoint
    # here would silently discard an explicitly requested
    # FLAGS.checkpoint_path, and is redundant when the flag is unset.
    decode_hparams = decoding.decode_hparams(FLAGS.decode_hparams)
    export_as_tfhub_module(FLAGS.model, hparams, decode_hparams, problem,
                           checkpoint_path, export_dir)
    return

  run_config = t2t_trainer.create_run_config(hparams)

  estimator = create_estimator(run_config, hparams)

  exporter = tf.estimator.FinalExporter(
      "exporter", lambda: problem.serving_input_fn(hparams), as_text=True)

  exporter.export(
      estimator,
      export_dir,
      checkpoint_path=checkpoint_path,
      eval_result=None,
      is_the_final_export=True)
Exemple #27
0
def main(argv):
    """Build the training input pipeline once and print what it yields.

    Creates hparams (adding ``data_dir`` and the problem hparams), builds the
    TRAIN and EVAL estimator input functions, extracts features, labels and
    input hooks from the TRAIN input function and prints each of them.

    Args:
        argv: Command-line arguments; everything after the first element is
            passed to ``set_hparams_from_args``.
    """
    set_hparams_from_args(argv[1:])
    hp = create_hparams()
    hp.add_hparam("data_dir", FLAGS.data_dir)
    trainer_lib.set_random_seed(FLAGS.random_seed)

    hparams_lib.add_problem_hparams(hp, FLAGS.problem)
    task = hp.problem
    training_input_fn = task.make_estimator_input_fn(
        tf.estimator.ModeKeys.TRAIN, hp)

    # Built for parity with the original script; it is not consumed below.
    evaluation_input_fn = task.make_estimator_input_fn(
        tf.estimator.ModeKeys.EVAL,
        hp,
        dataset_kwargs={'dataset_split': None})

    pipeline_outputs = estimator_util._get_features_and_labels_from_input_fn(
        training_input_fn,
        tf.estimator.ModeKeys.TRAIN,
        hp,
        create_run_config(hp))
    features, labels, input_hooks = pipeline_outputs

    for item in (features, labels, input_hooks):
        print(item)
Exemple #28
0
def create_hp_and_estimator(problem_name, data_dir, checkpoint_path):
    """Create model hparams, decode hparams and an estimator for decoding.

    Side effect: ``FLAGS.checkpoint_path`` is overwritten with
    ``checkpoint_path`` before the estimator is created.

    Args:
        problem_name: Problem name passed to ``trainer_lib.create_hparams``.
        data_dir: Data directory; ``~`` is expanded.
        checkpoint_path: Value assigned to ``FLAGS.checkpoint_path``.

    Returns:
        A ``(hparams, decode_hparams, estimator)`` tuple.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    trainer_lib.set_random_seed(FLAGS.random_seed)

    model_hp = trainer_lib.create_hparams(
        FLAGS.hparams_set,
        FLAGS.hparams,
        data_dir=os.path.expanduser(data_dir),
        problem_name=problem_name,
    )

    dec_hp = decoding.decode_hparams(FLAGS.decode_hparams)
    dec_hp.shards = FLAGS.decode_shards
    dec_hp.shard_id = FLAGS.worker_id
    # The command-line flag wins; otherwise keep the value from decode hparams.
    dec_hp.decode_in_memory = FLAGS.decode_in_memory or dec_hp.decode_in_memory
    dec_hp.decode_to_file = None
    dec_hp.decode_reference = None

    FLAGS.checkpoint_path = checkpoint_path
    model_estimator = trainer_lib.create_estimator(
        FLAGS.model,
        model_hp,
        t2t_trainer.create_run_config(model_hp),
        decode_hparams=dec_hp,
        use_tpu=FLAGS.use_tpu,
    )
    return model_hp, dec_hp, model_estimator
Exemple #29
0
def main(argv):
  """Measure model accuracy on clean inputs and under adversarial attacks.

  After the usual trainer setup, builds an EVAL input pipeline for the
  "<problem>_rev" problem, wraps the model in a ``T2TAttackModel``, restores
  weights from the latest checkpoint under ``FLAGS.output_dir`` (falling back
  to ``FLAGS.checkpoint_path``), and logs the accuracy for epsilon 0.0 and for
  every value in ``attack_params.attack_epsilons``.

  Args:
    argv: Command-line arguments; everything after the first element is
      passed to ``t2t_trainer.set_hparams_from_args``.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()

  if FLAGS.cloud_mlengine:
    cloud_mlengine.launch()
    return

  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])
  model_hp = t2t_trainer.create_hparams()
  trainer_lib.add_problem_hparams(model_hp, FLAGS.problem)
  attack_hp = create_attack_params()
  attack_hp.add_hparam("eps", 0.0)

  run_config = t2t_trainer.create_run_config(model_hp)
  fn_params = {"batch_size": model_hp.batch_size}

  # add "_rev" as a hack to avoid image standardization
  rev_problem = registry.problem(FLAGS.problem + "_rev")
  eval_input_fn = rev_problem.make_estimator_input_fn(
      tf.estimator.ModeKeys.EVAL, model_hp)
  eval_data = eval_input_fn(fn_params, run_config).repeat()
  batch, _ = eval_data.make_one_shot_iterator().get_next()
  # With the "_rev" problem the model inputs are read from "targets" and the
  # labels from "inputs".
  images = tf.to_float(batch["targets"])
  true_labels = tf.squeeze(batch["inputs"])

  session = tf.Session()

  model_fn = t2t_model.T2TModel.make_estimator_model_fn(
      FLAGS.model, model_hp, use_tpu=FLAGS.use_tpu)
  attacked_model = adv_attack_utils.T2TAttackModel(
      model_fn, fn_params, run_config)

  # Optional weights that keep only examples the clean model gets right.
  correct_mask = None
  clean_probs = attacked_model.get_probs(images)
  if FLAGS.ignore_incorrect:
    clean_preds = tf.squeeze(tf.argmax(clean_probs, -1))
    correct_mask = tf.to_float(tf.equal(true_labels, clean_preds))
  one_hot_labels = tf.one_hot(true_labels, clean_probs.shape[-1])

  attack = create_attack(attack_hp.attack)(attacked_model, sess=session)

  # Restore weights
  saver = tf.train.Saver()
  ckpt_dir = os.path.expanduser(FLAGS.output_dir or FLAGS.checkpoint_path)
  saver.restore(session, tf.train.latest_checkpoint(ckpt_dir))

  # reuse variables
  tf.get_variable_scope().reuse_variables()

  def compute_accuracy(batch_x, batch_labels, weights):
    """Run FLAGS.eval_steps accuracy updates and return the last value."""
    predicted = tf.squeeze(attacked_model.get_probs(batch_x))
    predicted = tf.argmax(predicted, -1, output_type=batch_labels.dtype)
    update_op = tf.metrics.accuracy(
        labels=batch_labels, predictions=predicted, weights=weights)[1]
    session.run(tf.initialize_local_variables())
    for _ in range(FLAGS.eval_steps):
      accuracy = session.run(update_op)
    return accuracy

  # The first entry is the accuracy on unperturbed inputs.
  results = [(0.0, compute_accuracy(images, true_labels, correct_mask))]
  for eps in attack_hp.attack_epsilons:
    attack_hp.eps = eps
    adv_images = attack.generate(
        images, y=one_hot_labels, **attack_hp.values())
    results.append(
        (eps, compute_accuracy(adv_images, true_labels, correct_mask)))

  for eps, accuracy in results:
    tf.logging.info("Accuracy @ eps=%f: %f" % (eps, accuracy))
def main(_):
  """Roll a trained video model forward and write the frames to a video file.

  Seeds the rollout with the first batch of the TRAIN dataset, then for
  ``FLAGS.num_steps`` steps feeds the last ``hparams.video_num_input_frames``
  frames (plus actions and rewards when the problem defines ``num_actions``)
  to the model, appends the first predicted frame to the history and writes
  it through ``common_video.WholeVideoWriter`` to ``FLAGS.output_gif``.
  Actions, when used, are sampled uniformly at random.

  Args:
    _: Unused positional argument.

  Raises:
    ValueError: If no checkpoint state is found in ``FLAGS.output_dir``.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

  # Create hparams
  hparams = trainer_lib.create_hparams(
      FLAGS.hparams_set,
      FLAGS.hparams,
      data_dir=os.path.expanduser(FLAGS.data_dir),
      problem_name=FLAGS.problem)
  hparams.force_full_predict = True
  hparams.scheduled_sampling_k = -1

  # Params
  num_agents = 1  # TODO(mbz): fix the code for more agents
  num_steps = FLAGS.num_steps
  num_actions = getattr(hparams.problem, "num_actions", None)
  if hparams.preprocess_resize_frames is not None:
    # Build a new list: extending the hparams value with `+=` would mutate
    # hparams.preprocess_resize_frames itself, which is still read later.
    frame_shape = (list(hparams.preprocess_resize_frames) +
                   [hparams.problem.num_channels])
  else:
    frame_shape = list(hparams.problem.frame_shape)

  dataset = registry.problem(FLAGS.problem).dataset(
      tf.estimator.ModeKeys.TRAIN,
      shuffle_files=True,
      data_dir=os.path.expanduser(FLAGS.data_dir),
      hparams=hparams)

  dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(num_agents))
  data = dataset.make_one_shot_iterator().get_next()
  # Setup input placeholders
  input_size = [num_agents, hparams.video_num_input_frames]
  if num_actions is None:
    placeholders = {
        "inputs": tf.placeholder(tf.float32, input_size + frame_shape)
    }
  else:
    placeholders = {
        "inputs": tf.placeholder(tf.float32, input_size + frame_shape),
        "input_action": tf.placeholder(tf.int64, input_size + [1]),
        "input_reward": tf.placeholder(tf.int64, input_size + [1]),
        "reset_internal_states": tf.placeholder(tf.float32, []),
    }
  # Create model.
  model_cls = registry.model(FLAGS.model)
  model = model_cls(hparams, tf.estimator.ModeKeys.PREDICT)
  prediction_ops = model.infer(placeholders)

  # Fixed-size histories of the most recent frames / actions / rewards.
  states_q = Queue(maxsize=hparams.video_num_input_frames)
  actions_q = Queue(maxsize=hparams.video_num_input_frames)
  rewards_q = Queue(maxsize=hparams.video_num_input_frames)
  if num_actions is not None:
    all_qs = [states_q, actions_q, rewards_q]
  else:
    all_qs = [states_q]

  writer = common_video.WholeVideoWriter(
      fps=FLAGS.fps, output_path=FLAGS.output_gif)

  saver = tf.train.Saver(tf.trainable_variables())
  with tf.train.SingularMonitoredSession() as sess:
    # Load latest checkpoint
    ckpt_state = tf.train.get_checkpoint_state(FLAGS.output_dir)
    if ckpt_state is None:
      raise ValueError(
          "No checkpoint found in output_dir: %s" % FLAGS.output_dir)
    saver.restore(sess.raw_session(), ckpt_state.model_checkpoint_path)

    # get init frames from the dataset
    data_np = sess.run(data)

    frames = np.split(data_np["inputs"], hparams.video_num_input_frames, 1)
    for frame in frames:
      frame = np.squeeze(frame, 1)
      states_q.put(frame)
      writer.write(frame[0].astype(np.uint8))

    if num_actions is not None:
      actions = np.split(data_np["input_action"],
                         hparams.video_num_input_frames, 1)
      for action in actions:
        actions_q.put(np.squeeze(action, 1))

      rewards = np.split(data_np["input_reward"],
                         hparams.video_num_input_frames, 1)
      for reward in rewards:
        rewards_q.put(np.squeeze(reward, 1))

    for step in range(num_steps):
      print(">>>>>>> ", step)

      if num_actions is not None:
        # np.random.randint(n) draws from [0, n), so passing num_actions
        # covers every action; num_actions - 1 would never pick the last one.
        random_actions = np.random.randint(num_actions)
        random_actions = np.expand_dims(random_actions, 0)
        random_actions = np.tile(random_actions, (num_agents, 1))

        # Shape inputs and targets
        inputs, input_action, input_reward = (
            np.stack(list(q.queue), axis=1) for q in all_qs)
      else:
        assert len(all_qs) == 1
        q = all_qs[0]
        elems = list(q.queue)
        # Defensive: restore a leading batch axis on any element lacking it.
        for i, e in enumerate(elems):
          if len(e.shape) < 4:
            elems[i] = np.expand_dims(e, axis=0)
        inputs = np.stack(elems, axis=1)

      # Predict next frames
      if num_actions is None:
        feed = {placeholders["inputs"]: inputs}
      else:
        feed = {
            placeholders["inputs"]: inputs,
            placeholders["input_action"]: input_action,
            placeholders["input_reward"]: input_reward,
            placeholders["reset_internal_states"]: float(step == 0),
        }
      predictions = sess.run(prediction_ops, feed_dict=feed)

      # Only the first predicted time step is kept.
      if num_actions is None:
        predicted_states = predictions[:, 0]
      else:
        predicted_states = predictions["targets"][:, 0]
        predicted_reward = predictions["target_reward"][:, 0]

      # Update queues
      if num_actions is None:
        # One-element tuple: without the trailing comma zip() below would
        # iterate over the batch axis of predicted_states and enqueue a frame
        # that has lost its batch dimension.
        new_data = (predicted_states,)
      else:
        new_data = (predicted_states, random_actions, predicted_reward)
      for q, d in zip(all_qs, new_data):
        q.get()
        q.put(d.copy())

      writer.write(np.round(predicted_states[0]).astype(np.uint8))

    writer.finish_to_disk()
Exemple #31
0
def main(argv):
  """Evaluate model accuracy under adversarial attacks of varying strength.

  After the usual trainer setup, builds the target model (and, when
  FLAGS.surrogate_attack is set, a surrogate model under the "surrogate"
  scope), restores their weights, and for every epsilon in
  attack_params.attack_epsilons generates adversarial inputs and logs the
  resulting accuracy. With a surrogate, the attack is generated against the
  surrogate model and accuracy is reported as a (target, surrogate) pair.

  Args:
    argv: Command-line arguments; everything after the first element is
      passed to t2t_trainer.set_hparams_from_args.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()


  if FLAGS.cloud_mlengine:
    cloud_mlengine.launch()
    return

  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])

  # Surrogate hparams are created before the target model's hparams.
  if FLAGS.surrogate_attack:
    tf.logging.warn("Performing surrogate model attack.")
    sur_hparams = create_surrogate_hparams()
    trainer_lib.add_problem_hparams(sur_hparams, FLAGS.problem)

  hparams = t2t_trainer.create_hparams()
  trainer_lib.add_problem_hparams(hparams, FLAGS.problem)

  attack_params = create_attack_params()
  # Register the epsilon hparam (initially 0.0); it is overwritten with
  # set_hparam for each attack strength in the loop near the end.
  attack_params.add_hparam(attack_params.epsilon_name, 0.0)

  if FLAGS.surrogate_attack:
    sur_config = create_surrogate_run_config(sur_hparams)
  config = t2t_trainer.create_run_config(hparams)
  params = {
      "batch_size": hparams.batch_size,
      "use_tpu": FLAGS.use_tpu,
  }

  # add "_rev" as a hack to avoid image standardization
  problem = registry.problem(FLAGS.problem + "_rev")

  inputs, labels, features = prepare_data(problem, hparams, params, config)

  sess = tf.Session()

  if FLAGS.surrogate_attack:
    sur_model_fn = t2t_model.T2TModel.make_estimator_model_fn(
        FLAGS.surrogate_model, sur_hparams, use_tpu=FLAGS.use_tpu)
    sur_ch_model = adv_attack_utils.T2TAttackModel(
        sur_model_fn, features, params, sur_config, scope="surrogate")
    # Dummy call to construct graph
    sur_ch_model.get_probs(inputs)

    # Surrogate weights come from the latest checkpoint in
    # FLAGS.surrogate_output_dir, mapped into the "surrogate/" scope.
    checkpoint_path = os.path.expanduser(FLAGS.surrogate_output_dir)
    tf.contrib.framework.init_from_checkpoint(
        tf.train.latest_checkpoint(checkpoint_path), {"/": "surrogate/"})
    sess.run(tf.global_variables_initializer())

  # Snapshot of the variables that exist before the target model is built;
  # used below to isolate the variables created from here on.
  other_vars = set(tf.global_variables())

  model_fn = t2t_model.T2TModel.make_estimator_model_fn(
      FLAGS.model, hparams)
  ch_model = adv_attack_utils.T2TAttackModel(model_fn, features, params, config)

  # Optional weights that keep only examples the clean model classifies
  # correctly (enabled by FLAGS.ignore_incorrect).
  acc_mask = None
  probs = ch_model.get_probs(inputs)
  if FLAGS.ignore_incorrect:
    preds = tf.argmax(probs, -1, output_type=labels.dtype)
    preds = tf.reshape(preds, labels.shape)
    acc_mask = tf.to_float(tf.equal(labels, preds))
  one_hot_labels = tf.one_hot(labels, probs.shape[-1])

  # The attack is built against the surrogate when one is used, otherwise
  # against the target model itself.
  if FLAGS.surrogate_attack:
    attack = create_attack(attack_params.attack)(sur_ch_model, sess=sess)
  else:
    attack = create_attack(attack_params.attack)(ch_model, sess=sess)

  # Variables created since the snapshot above.
  new_vars = set(tf.global_variables()) - other_vars

  # Restore weights
  saver = tf.train.Saver(new_vars)
  checkpoint_path = os.path.expanduser(FLAGS.output_dir)
  saver.restore(sess, tf.train.latest_checkpoint(checkpoint_path))

  # reuse variables
  tf.get_variable_scope().reuse_variables()

  def compute_accuracy(x, l, mask):
    """Compute model accuracy."""
    preds = ch_model.get_probs(x)
    preds = tf.squeeze(preds)
    preds = tf.argmax(preds, -1, output_type=l.dtype)

    _, acc_update_op = tf.metrics.accuracy(l, preds, weights=mask)

    # With a surrogate, evaluate both models in one run; the result is then a
    # (target accuracy, surrogate accuracy) pair.
    if FLAGS.surrogate_attack:
      preds = sur_ch_model.get_probs(x)
      preds = tf.squeeze(preds)
      preds = tf.argmax(preds, -1, output_type=l.dtype)
      acc_update_op = tf.tuple((acc_update_op,
                                tf.metrics.accuracy(l, preds, weights=mask)[1]))

    # Initialise local variables before running the update op; the value
    # from the last of FLAGS.eval_steps runs is returned.
    sess.run(tf.initialize_local_variables())
    for i in range(FLAGS.eval_steps):
      tf.logging.info(
          "\tEvaluating batch [%d / %d]" % (i + 1, FLAGS.eval_steps))
      acc = sess.run(acc_update_op)
    if FLAGS.surrogate_attack:
      tf.logging.info("\tFinal acc: (%.4f, %.4f)" % (acc[0], acc[1]))
    else:
      tf.logging.info("\tFinal acc: %.4f" % acc)
    return acc

  # One attack + evaluation per configured epsilon.
  epsilon_acc_pairs = []
  for epsilon in attack_params.attack_epsilons:
    tf.logging.info("Attacking @ eps=%.4f" % epsilon)
    attack_params.set_hparam(attack_params.epsilon_name, epsilon)
    adv_x = attack.generate(inputs, y=one_hot_labels, **attack_params.values())
    acc = compute_accuracy(adv_x, labels, acc_mask)
    epsilon_acc_pairs.append((epsilon, acc))

  # Summary of all results once every attack has been evaluated.
  for epsilon, acc in epsilon_acc_pairs:
    if FLAGS.surrogate_attack:
      tf.logging.info(
          "Accuracy @ eps=%.4f: (%.4f, %.4f)" % (epsilon, acc[0], acc[1]))
    else:
      tf.logging.info("Accuracy @ eps=%.4f: %.4f" % (epsilon, acc))