Example #1
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)


  if FLAGS.score_file:
    filename = os.path.expanduser(FLAGS.score_file)
    if not tf.gfile.Exists(filename):
      raise ValueError("The file to score doesn't exist: %s" % filename)
    results = score_file(filename)
    if not FLAGS.decode_to_file:
      raise ValueError("To score a file, specify --decode_to_file for results.")
    write_file = tf.gfile.Open(os.path.expanduser(FLAGS.decode_to_file), "w")
    for score in results:
      write_file.write("%.6f\n" % score)
    write_file.close()
    return

  hp = create_hparams()
  decode_hp = create_decode_hparams()

  estimator = trainer_lib.create_estimator(
      FLAGS.model,
      hp,
      t2t_trainer.create_run_config(hp),
      decode_hparams=decode_hp,
      use_tpu=FLAGS.use_tpu)

  decode(estimator, hp, decode_hp)
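Most of these entry points rely on usr_dir.import_usr_dir(FLAGS.t2t_usr_dir) to pull user code into the registry before anything is built. A minimal sketch of what such a --t2t_usr_dir package might contain; the problem class and its contents are hypothetical, only the registration pattern is the point:

# my_usr_dir/__init__.py -- hypothetical user dir loaded by usr_dir.import_usr_dir().
from tensor2tensor.data_generators import text_problems
from tensor2tensor.utils import registry


@registry.register_problem
class MyTextProblem(text_problems.Text2TextProblem):
  """Registered on import; afterwards reachable as --problem=my_text_problem."""

  @property
  def is_generate_per_split(self):
    return False

  def generate_samples(self, data_dir, tmp_dir, dataset_split):
    # Hard-coded toy pairs; a real problem would read source data from tmp_dir.
    yield {"inputs": "hello", "targets": "world"}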
Example #2
def main(argv):
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  log_registry()

  if FLAGS.cloud_mlengine:
    return cloud_mlengine.launch()

  if FLAGS.generate_data:
    generate_data()

  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  if argv:
    set_hparams_from_args(argv[1:])
  hparams = create_hparams()
  if is_chief():
    save_metadata(hparams)

  with maybe_cloud_tpu():
    exp_fn = create_experiment_fn()
    exp = exp_fn(create_run_config(hparams), hparams)
    execute_schedule(exp)
Example #3
    def __init__(self, translate_host, translate_port, source_lang, target_lang,
                 model_name, problem, t2t_usr_dir, data_dir,
                 preprocess_cmd, postprocess_cmd):
        """Initialize a TransformerTranslator object according to the given
        configuration settings.

        @param translate_host: the host on which the translation server runs
        @param translate_port: the port at which the translation server listens
        @param source_lang: source language (ISO-639-1 ID)
        @param target_lang: target language (ISO-639-1 ID)
        @param preprocess_cmd: bash command for text preprocessing
        @param postprocess_cmd: bash command for text postprocessing
        """
        # precompile Tensorflow server addresses
        self.server = translate_host + ":" + translate_port

        # initialize text processing tools (can be shared among threads)
        self.tokenizer = Tokenizer({'lowercase': True,
                                    'moses_escape': True})
        self.preprocess = preprocess_cmd
        self.postprocess = postprocess_cmd
        usr_dir.import_usr_dir(t2t_usr_dir)
        self.problem = registry.problem(problem)
        hparams = tf.contrib.training.HParams(
            data_dir=os.path.expanduser(data_dir))
        self.problem.get_hparams(hparams)
        self.request_fn = serving_utils.make_grpc_request_fn(
            servable_name=model_name,
            server=self.server,
            timeout_secs=30)
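The request_fn built here is what serving_utils.predict consumes, as the later serving examples show. A minimal usage sketch, assuming a reachable TensorFlow Serving instance and that predict returns (output, score) pairs as in the other examples; the helper name is hypothetical:

    def translate_sketch(self, text):
        """Hypothetical helper: run one sentence through the gRPC request_fn."""
        outputs = serving_utils.predict([text], self.problem, self.request_fn)
        translation, score = outputs[0]
        return translation, score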
Example #4
def _initialize_t2t(t2t_usr_dir):
    global T2T_INITIALIZED
    if not T2T_INITIALIZED:
        logging.info("Setting up tensor2tensor library...")
        tf.logging.set_verbosity(tf.logging.INFO)
        usr_dir.import_usr_dir(t2t_usr_dir)
        T2T_INITIALIZED = True
Example #5
def create_teacher_experiment(run_config, hparams, argv):
  """Creates experiment function."""
  tf.logging.info("training teacher")
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()

  if FLAGS.cloud_mlengine:
    return cloud_mlengine.launch()

  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])

  with t2t_trainer.maybe_cloud_tpu():
    hparams.distill_phase = "train"
    exp_fn = t2t_trainer.create_experiment_fn()
    exp = exp_fn(run_config, hparams)
    return exp
Example #6
def init():
  # global input_encoder, output_decoder, fname, problem
  global problem
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.logging.info("importing ghsumm/trainer from {}".format(t2t_usr_dir))
  usr_dir.import_usr_dir(t2t_usr_dir)
  print(t2t_usr_dir)
  problem = registry.problem(problem_name)
  hparams = tf.contrib.training.HParams(data_dir=os.path.expanduser(data_dir))
  problem.get_hparams(hparams)
Example #7
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

  # Create hparams
  hparams = create_hparams()
  hparams.force_full_predict = True
  batch_size = hparams.batch_size

  # Iterating over dev/test partition of the data.
  # Change the data partition if necessary.
  dataset = registry.problem(FLAGS.problem).dataset(
      tf.estimator.ModeKeys.PREDICT,
      shuffle_files=False,
      hparams=hparams)

  dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))
  data = dataset.make_one_shot_iterator().get_next()
  input_data = dict((k, data[k]) for k in data.keys() if k.startswith("input"))

  # Create model
  model_cls = registry.model(FLAGS.model)
  model = model_cls(hparams, tf.estimator.ModeKeys.PREDICT)
  prediction_ops = model.infer(input_data)

  # Confusion Matrix
  nr = hparams.problem.num_rewards
  cm_per_frame = np.zeros((nr, nr), dtype=np.uint64)
  cm_next_frame = np.zeros((nr, nr), dtype=np.uint64)

  saver = tf.train.Saver()
  with tf.train.SingularMonitoredSession() as sess:
    # Load latest checkpoint
    ckpt = tf.train.get_checkpoint_state(FLAGS.output_dir).model_checkpoint_path
    saver.restore(sess.raw_session(), ckpt)

    counter = 0
    while not sess.should_stop():
      counter += 1
      if counter % 1 == 0:
        print(counter)

      # Predict next frames
      rew_pd, rew_gt = sess.run(
          [prediction_ops["target_reward"], data["target_reward"]])

      for i in range(batch_size):
        cm_next_frame[rew_gt[i, 0, 0], rew_pd[i, 0, 0]] += 1
        for gt, pd in zip(rew_gt[i], rew_pd[i]):
          cm_per_frame[gt, pd] += 1

  print_confusion_matrix("Per-frame Confusion Matrix", cm_per_frame)
  print_confusion_matrix("Next-frame Confusion Matrix", cm_next_frame)
Example #8
def main(_):
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

  # Calculate the list of problems to generate.
  problems = sorted(
      list(_SUPPORTED_PROBLEM_GENERATORS) + registry.list_problems())
  for exclude in FLAGS.exclude_problems.split(","):
    if exclude:
      problems = [p for p in problems if exclude not in p]
  if FLAGS.problem and FLAGS.problem[-1] == "*":
    problems = [p for p in problems if p.startswith(FLAGS.problem[:-1])]
  elif FLAGS.problem and "," in FLAGS.problem:
    problems = [p for p in problems if p in FLAGS.problem.split(",")]
  elif FLAGS.problem:
    problems = [p for p in problems if p == FLAGS.problem]
  else:
    problems = []

  # Remove TIMIT if paths are not given.
  if not getattr(FLAGS, "timit_paths", None):
    problems = [p for p in problems if "timit" not in p]
  # Remove parsing if paths are not given.
  if not getattr(FLAGS, "parsing_path", None):
    problems = [p for p in problems if "parsing_english_ptb" not in p]

  if not problems:
    problems_str = "\n  * ".join(
        sorted(list(_SUPPORTED_PROBLEM_GENERATORS) + registry.list_problems()))
    error_msg = ("You must specify one of the supported problems to "
                 "generate data for:\n  * " + problems_str + "\n")
    error_msg += ("TIMIT and parsing need data_sets specified with "
                  "--timit_paths and --parsing_path.")
    raise ValueError(error_msg)

  if not FLAGS.data_dir:
    FLAGS.data_dir = tempfile.gettempdir()
    tf.logging.warning("It is strongly recommended to specify --data_dir. "
                       "Data will be written to default data_dir=%s.",
                       FLAGS.data_dir)
  FLAGS.data_dir = os.path.expanduser(FLAGS.data_dir)
  tf.gfile.MakeDirs(FLAGS.data_dir)

  tf.logging.info("Generating problems:\n%s"
                  % registry.display_list_by_prefix(problems,
                                                    starting_spaces=4))
  if FLAGS.only_list:
    return
  for problem in problems:
    set_random_seed()

    if problem in _SUPPORTED_PROBLEM_GENERATORS:
      generate_data_for_problem(problem)
    else:
      generate_data_for_registered_problem(problem)
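generate_data_for_registered_problem presumably delegates to the registered Problem object; a minimal sketch under that assumption (the tmp_dir flag is assumed to exist, as in t2t-datagen):

def generate_data_for_registered_problem(problem_name):
  # Sketch only: resolve the problem from the registry and let it write its data.
  problem = registry.problem(problem_name)
  data_dir = os.path.expanduser(FLAGS.data_dir)
  tmp_dir = os.path.expanduser(FLAGS.tmp_dir)
  tf.gfile.MakeDirs(tmp_dir)
  problem.generate_data(data_dir, tmp_dir)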
Example #9
def main(argv):
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()


  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])
  hparams = t2t_trainer.create_hparams()
  trainer_lib.add_problem_hparams(hparams, FLAGS.problem)
  pruning_params = create_pruning_params()
  pruning_strategy = create_pruning_strategy(pruning_params.strategy)

  config = t2t_trainer.create_run_config(hparams)
  params = {"batch_size": hparams.batch_size}

  # add "_rev" as a hack to avoid image standardization
  problem = registry.problem(FLAGS.problem)
  input_fn = problem.make_estimator_input_fn(tf.estimator.ModeKeys.EVAL,
                                             hparams)
  dataset = input_fn(params, config).repeat()
  features, labels = dataset.make_one_shot_iterator().get_next()

  sess = tf.Session()

  model_fn = t2t_model.T2TModel.make_estimator_model_fn(
      FLAGS.model, hparams, use_tpu=FLAGS.use_tpu)
  spec = model_fn(
      features,
      labels,
      tf.estimator.ModeKeys.EVAL,
      params=hparams,
      config=config)

  # Restore weights
  saver = tf.train.Saver()
  checkpoint_path = os.path.expanduser(FLAGS.output_dir or
                                       FLAGS.checkpoint_path)
  saver.restore(sess, tf.train.latest_checkpoint(checkpoint_path))

  def eval_model():
    preds = spec.predictions["predictions"]
    preds = tf.argmax(preds, -1, output_type=labels.dtype)
    _, acc_update_op = tf.metrics.accuracy(labels=labels, predictions=preds)
    sess.run(tf.initialize_local_variables())
    for _ in range(FLAGS.eval_steps):
      acc = sess.run(acc_update_op)
    return acc

  pruning_utils.sparsify(sess, eval_model, pruning_strategy, pruning_params)
Example #10
def main(argv):
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()

  if FLAGS.cloud_mlengine:
    cloud_mlengine.launch()
    return

  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])

  with t2t_trainer.maybe_cloud_tpu():
    root_output_dir = FLAGS.output_dir

    # Train Teacher ============
    hparams = t2t_trainer.create_hparams()
    hparams.distill_phase = "train"
    teacher_dir = os.path.join(root_output_dir, "teacher")
    FLAGS.output_dir = teacher_dir

    exp_fn = t2t_trainer.create_experiment_fn()
    run_config = t2t_trainer.create_run_config(hparams)
    exp = exp_fn(run_config, hparams)
    if t2t_trainer.is_chief():
      t2t_trainer.save_metadata(hparams)
    t2t_trainer.execute_schedule(exp)
    # ==========================
    # Train Student ============
    hparams = t2t_trainer.create_hparams()
    hparams.add_hparam("teacher_dir", teacher_dir)
    hparams.distill_phase = "distill"
    student_dir = os.path.join(root_output_dir, "student")
    FLAGS.output_dir = student_dir

    exp_fn = t2t_trainer.create_experiment_fn()
    run_config = t2t_trainer.create_run_config(hparams)
    exp = exp_fn(run_config, hparams)

    if t2t_trainer.is_chief():
      t2t_trainer.save_metadata(hparams)
    t2t_trainer.execute_schedule(exp)
Example #11
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)
  tpu_trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  log_registry()

  if FLAGS.generate_data:
    generate_data()

  hparams = create_hparams()
  run_config = create_run_config(hparams)

  exp_fn = create_experiment_fn()
  exp = exp_fn(run_config, hparams)
  execute_schedule(exp)
Example #12
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  FLAGS.use_tpu = False  # decoding not supported on TPU

  hp = create_hparams()
  decode_hp = create_decode_hparams()

  estimator = trainer_lib.create_estimator(
      FLAGS.model,
      hp,
      t2t_trainer.create_run_config(hp),
      decode_hparams=decode_hp,
      use_tpu=False)

  decode(estimator, hp, decode_hp)
Example #13
    def __init__(self, str_tokens, eval_tokens=None, batch_size=1000):
        """
        Args:
            batch_size: used for encoding
            str_tokens: the original token inputs, as the format of ['t1', 't2'...]. The items within should be strings
            eval_tokens: if not None, then should be the same length as tokens, for similarity comparisons.
        """
        assert type(str_tokens) is list
        assert len(str_tokens) > 0
        assert type(str_tokens[0]) is str
        self.str_tokens = str_tokens
        if eval_tokens is not None:
            assert (len(eval_tokens) == len(str_tokens)
                    and type(eval_tokens[0]) is str)
        self.eval_tokens = eval_tokens
        tf.logging.set_verbosity(tf.logging.INFO)
        tf.logging.info('tf logging set to INFO by: %s' %
                        self.__class__.__name__)

        usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
        trainer_utils.log_registry()
        trainer_utils.validate_flags()
        assert FLAGS.schedule == "train_and_evaluate"
        data_dir = os.path.expanduser(FLAGS.data_dir)
        out_dir = os.path.expanduser(FLAGS.output_dir)

        hparams = trainer_utils.create_hparams(FLAGS.hparams_set,
                                               data_dir,
                                               passed_hparams=FLAGS.hparams)

        trainer_utils.add_problem_hparams(hparams, FLAGS.problems)
        # print(hparams)
        hparams.eval_use_test_set = True

        self.estimator, _ = trainer_utils.create_experiment_components(
            data_dir=data_dir,
            model_name=FLAGS.model,
            hparams=hparams,
            run_config=trainer_utils.create_run_config(out_dir))

        decode_hp = decoding.decode_hparams(FLAGS.decode_hparams)
        decode_hp.add_hparam("shards", FLAGS.decode_shards)
        decode_hp.batch_size = batch_size
        self.decode_hp = decode_hp
        self.arr_results = None
        self._encoding_len = 1
Example #14
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)
  if FLAGS.verbose:
    tf.logging.set_verbosity(tf.logging.DEBUG)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  rescorer = Rescorer()
  output_handlers = create_output_handlers()
  for ids, src_sentences, trg_sentences in batched_iter(
      FLAGS.src_test, FLAGS.trg_test, FLAGS.batch_size):
    res = sampler.sample(src_sentences, trg_sentences)
    for idx, samples in zip(ids, res):
      samples.sort(reverse=True, key=operator.itemgetter(0))
      for ohandler in output_handlers:
        ohandler.write(idx, samples)
    break
  for ohandler in output_handlers:
    ohandler.finish()
Example #15
  def _init_env(self):
    tf.logging.info("Import usr dir from %s", self._usr_dir)
    if self._usr_dir is not None:
      usr_dir.import_usr_dir(self._usr_dir)
    tf.logging.info("Start to create hparams for %s of %s",
                    self._problem, self._hparams_set)
    self._hparams = trainer_utils.create_hparams(self._hparams_set, self._data_dir)
    trainer_utils.add_problem_hparams(self._hparams, self._problem)
    tf.logging.info("Build the model_fn of %s with %s",
                    self._model_name, self._hparams)
    #self._model_fn = model_builder.build_model_fn(self._model_name,self._hparams)
    #self._model_fn = model_builder.build_model_fn(self._model_name)
    self._inputs_ph = tf.placeholder(dtype=tf.int32)  # shape not specified, any shape

    batch_inputs = tf.reshape(self._inputs_ph, [self._batch_size, -1, 1, 1])
    #batch_inputs = tf.reshape(self._inputs_ph, [-1, -1, 1, 1])

    targets_ph = tf.placeholder(dtype=tf.int32)
    batch_targets = tf.reshape(targets_ph, [1, -1, 1, 1])
    features = {"inputs": batch_inputs,
                "problem_choice": 0,  # We run on the first problem here.
                "input_space_id": self._hparams.problems[0].input_space_id,
                "target_space_id": self._hparams.problems[0].target_space_id}
    mode = tf.estimator.ModeKeys.PREDICT
    estimator_spec = model_builder.model_fn(
        self._model_name, features, mode, self._hparams,
        problem_names=[self._problem], decode_hparams=self._hparams_dc)
    predictions_dict = estimator_spec.predictions
    self._predictions = predictions_dict["outputs"]
    #self._scores = predictions_dict['scores']  # not returned when using greedy search
    tf.logging.info("Start to init tf session")
    if self._isGpu:
      print('Using GPU in Decoder')
      gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=self._fraction)
      self._sess = tf.Session(config=tf.ConfigProto(
          allow_soft_placement=True, log_device_placement=False,
          gpu_options=gpu_options))
    else:
      print('Using CPU in Decoder')
      gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0)
      config = tf.ConfigProto(gpu_options=gpu_options)
      config.allow_soft_placement = True
      config.log_device_placement = False
      self._sess = tf.Session(config=config)
    with self._sess.as_default():
      ckpt = saver_mod.get_checkpoint_state(self._model_dir)
      saver = tf.train.Saver()
      tf.logging.info("Start to restore the parameters from %s",
                      ckpt.model_checkpoint_path)
      saver.restore(self._sess, ckpt.model_checkpoint_path)
    tf.logging.info("Finished initializing the environment")
Example #16
def main(argv):
  tf.logging.set_verbosity(tf.logging.INFO)

  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

  # If we just have to print the registry, do that and exit early.
  maybe_log_registry_and_exit()

  # Create HParams.
  if argv:
    set_hparams_from_args(argv[1:])
  if FLAGS.schedule != "run_std_server":
    hparams = create_hparams()
  if FLAGS.gpu_automatic_mixed_precision:
    setattr(hparams, "gpu_automatic_mixed_precision", True)

  if FLAGS.schedule == "train" or FLAGS.schedule == "train_eval_and_decode":
    mlperf_log.transformer_print(key=mlperf_log.RUN_START, hparams=hparams)
  if FLAGS.schedule == "run_std_server":
    run_std_server()
  mlperf_log.transformer_print(
      key=mlperf_log.RUN_SET_RANDOM_SEED, value=FLAGS.random_seed,
      hparams=hparams)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  for flag, val in FLAGS.__flags.items():
    print(flag, ": ", val.value)  

  if FLAGS.cloud_mlengine:
    cloud_mlengine.launch()
    return

  if FLAGS.generate_data:
    generate_data()

  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  exp_fn = create_experiment_fn()
  exp = exp_fn(create_run_config(hparams), hparams)
  if is_chief():
    save_metadata(hparams)
  execute_schedule(exp)
  if FLAGS.schedule != "train":
    mlperf_log.transformer_print(key=mlperf_log.RUN_FINAL,
                                 hparams=hparams)
Example #17
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    log_registry()

    if FLAGS.generate_data:
        generate_data()

    hparams = create_hparams()
    run_config = create_run_config(hparams)

    if is_chief():
        save_metadata(hparams)

    exp_fn = create_experiment_fn()
    exp = exp_fn(run_config, hparams)
    execute_schedule(exp)
Example #18
    def __init__(self,
                 params,
                 train_path="",
                 dev_path="",
                 test_path="",
                 cleanup=False):
        # Point out the current directory with t2t problem specified for g2p task.
        usr_dir.import_usr_dir(os.path.dirname(os.path.abspath(__file__)))
        self.params = params
        self.test_path = test_path
        if not os.path.exists(self.params.model_dir):
            os.makedirs(self.params.model_dir)

        # Register g2p problem.
        self.problem = registry._PROBLEMS[self.params.problem_name](
            self.params.model_dir,
            train_path=train_path,
            dev_path=dev_path,
            test_path=test_path,
            cleanup=cleanup)

        self.frozen_graph_filename = os.path.join(self.params.model_dir,
                                                  "frozen_model.pb")
        self.inputs, self.features, self.input_fn = None, None, None
        self.mon_sess, self.estimator_spec, self.g2p_gt_map = None, None, None
        self.first_ex = False
        if train_path:
            self.train_preprocess_file_path, self.dev_preprocess_file_path =\
                None, None
            self.estimator, self.decode_hp, self.hparams =\
                self.__prepare_model(train_mode=True)
            self.train_preprocess_file_path, self.dev_preprocess_file_path =\
                self.problem.generate_preprocess_data()

        elif os.path.exists(self.frozen_graph_filename):
            self.estimator, self.decode_hp, self.hparams =\
                self.__prepare_model()
            self.__load_graph()
            self.checkpoint_path = tf.train.latest_checkpoint(
                self.params.model_dir)

        else:
            self.estimator, self.decode_hp, self.hparams =\
                self.__prepare_model()
Example #19
    def __init__(self, processor_configuration):
        """Creates the Transformer estimator.

    Args:
      processor_configuration: A ProcessorConfiguration protobuffer with the
        transformer fields populated.
    """
        # Do the pre-setup tensor2tensor requires for flags and configurations.
        transformer_config = processor_configuration["transformer"]
        FLAGS.output_dir = transformer_config["model_dir"]
        usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
        data_dir = os.path.expanduser(transformer_config["data_dir"])

        # Create the basic hyper parameters.
        self.hparams = trainer_lib.create_hparams(
            transformer_config["hparams_set"],
            transformer_config["hparams"],
            data_dir=data_dir,
            problem_name=transformer_config["problem"])

        decode_hp = decoding.decode_hparams()
        decode_hp.add_hparam("shards", 1)
        decode_hp.add_hparam("shard_id", 0)

        # Create the estimator and final hyper parameters.
        self.estimator = trainer_lib.create_estimator(
            transformer_config["model"],
            self.hparams,
            t2t_trainer.create_run_config(self.hparams),
            decode_hparams=decode_hp,
            use_tpu=False)

        # Fetch the vocabulary and other helpful variables for decoding.
        self.source_vocab = self.hparams.problem_hparams.vocabulary["inputs"]
        self.targets_vocab = self.hparams.problem_hparams.vocabulary["targets"]
        self.const_array_size = 10000

        # Prepare the Transformer's debug data directory.
        run_dirs = sorted(
            glob.glob(os.path.join("/tmp/t2t_server_dump", "run_*")))
        for run_dir in run_dirs:
            shutil.rmtree(run_dir)
Example #20
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

    if FLAGS.score_file:
        # filename = os.path.expanduser(FLAGS.score_file)
        filename = FLAGS.score_file
        print(filename)
        if not tf.gfile.Exists(filename):
            raise ValueError("The file to score doesn't exist: %s" % filename)
        results = score_file(filename)
        if not FLAGS.decode_to_file:
            raise ValueError(
                "To score a file, specify --decode_to_file for results.")
        write_file = open(FLAGS.decode_to_file, "w")
        for score in results:
            write_file.write("%.6f\n" % score)
        write_file.close()
        return
Example #21
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

    ckpt_dir = os.path.expanduser(FLAGS.output_dir)

    hparams = create_hparams()
    hparams.no_data_parallelism = True  # To clear the devices
    run_config = t2t_trainer.create_run_config(hparams)

    estimator = create_estimator(run_config, hparams)

    problem = hparams.problem
    strategy = trainer_lib.create_export_strategy(problem, hparams)

    export_dir = os.path.join(ckpt_dir, "export", strategy.name)
    strategy.export(estimator,
                    export_dir,
                    checkpoint_path=tf.train.latest_checkpoint(ckpt_dir))
Example #22
def main(argv):
    tf.logging.set_verbosity(tf.logging.INFO)
    hvd.init()

    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    log_registry()

    if FLAGS.cloud_mlengine:
        return cloud_mlengine.launch()

    if FLAGS.generate_data:
        generate_data()

    if hasattr(FLAGS, "job_dir") and FLAGS.job_dir:
        FLAGS.output_dir = FLAGS.job_dir

    if argv:
        set_hparams_from_args(argv[1:])

    #
    hparams = create_hparams()

    if is_chief():
        save_metadata(hparams)

    # create_run_config calls trainer_lib.create_session_config, which initializes gpu_options
    config = create_run_config(hparams)
    decode_hparams = decoding.decode_hparams(FLAGS.decode_hparams)
    schedule = FLAGS.schedule

    estimator = create_estimator_fn(FLAGS.model, hparams, config, schedule, decode_hparams)

    # logging_hook = tf.train.LoggingTensorHook({"step": "test"}, every_n_iter=5)
    bcast_hook = hvd.BroadcastGlobalVariablesHook(0)

    estimator.train(
        input_fn=train_input_fn(hparams),
        steps=FLAGS.train_steps,
        hooks=[bcast_hook]
    )
Example #23
    def entry(self):

        tf.logging.set_verbosity(tf.logging.INFO)
        trainer_lib.set_random_seed(FLAGS.random_seed)
        usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

        print("###self defined hp###")
        print(str(FLAGS.data_dir))
        print(str(FLAGS.problem))
        print(str(FLAGS.model))
        print(str(FLAGS.hparams_set))
        print(str(FLAGS.output_dir))
        print(str(FLAGS.decode_hparams))

        hp = self.create_hparams()
        decode_hp = self.create_decode_hparams()
        estimator = self.create_new_estimator(hp, decode_hp)

        output_decode = self.my_decode(estimator, hp, decode_hp)
        print('output decode-res  = %s ' % str(output_decode))
        return output_decode
Example #24
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    problem = registry.problem(FLAGS.problem)
    hparams = tf.contrib.training.HParams(
        data_dir=os.path.expanduser(FLAGS.data_dir))
    problem.get_hparams(hparams)
    predict_fn = make_predict_fn(FLAGS.frozen_graph_filename)
    while True:
        inputs = input(">> ")
        outputs = predict([inputs], problem, predict_fn)
        outputs, = outputs
        output, score = outputs
        print_str = """
Input:
{inputs}

Output (Score {score:.3f}):
{output}
    """
        print(print_str.format(inputs=inputs, output=output, score=score))
Example #25
  def __init__(self, processor_configuration):
    """Creates the Transformer estimator.

    Args:
      processor_configuration: A ProcessorConfiguration protobuffer with the
        transformer fields populated.
    """
    # Do the pre-setup tensor2tensor requires for flags and configurations.
    transformer_config = processor_configuration["transformer"]
    FLAGS.output_dir = transformer_config["model_dir"]
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    data_dir = os.path.expanduser(transformer_config["data_dir"])

    # Create the basic hyper parameters.
    self.hparams = trainer_lib.create_hparams(
        transformer_config["hparams_set"],
        transformer_config["hparams"],
        data_dir=data_dir,
        problem_name=transformer_config["problem"])

    decode_hp = decoding.decode_hparams()
    decode_hp.add_hparam("shards", 1)
    decode_hp.add_hparam("shard_id", 0)

    # Create the estimator and final hyper parameters.
    self.estimator = trainer_lib.create_estimator(
        transformer_config["model"],
        self.hparams,
        t2t_trainer.create_run_config(self.hparams),
        decode_hparams=decode_hp, use_tpu=False)

    # Fetch the vocabulary and other helpful variables for decoding.
    self.source_vocab = self.hparams.problem_hparams.vocabulary["inputs"]
    self.targets_vocab = self.hparams.problem_hparams.vocabulary["targets"]
    self.const_array_size = 10000

    # Prepare the Transformer's debug data directory.
    run_dirs = sorted(glob.glob(os.path.join("/tmp/t2t_server_dump", "run_*")))
    for run_dir in run_dirs:
      shutil.rmtree(run_dir)
Example #26
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

    if FLAGS.checkpoint_path:
        checkpoint_path = FLAGS.checkpoint_path
        ckpt_dir = os.path.dirname(checkpoint_path)
    else:
        ckpt_dir = os.path.expanduser(FLAGS.output_dir)
        checkpoint_path = tf.train.latest_checkpoint(ckpt_dir)

    hparams = create_hparams()
    hparams.no_data_parallelism = True  # To clear the devices
    problem = hparams.problem
    decode_hparams = decoding.decode_hparams(FLAGS.decode_hparams)

    export_dir = FLAGS.export_dir or os.path.join(ckpt_dir, "export")

    if FLAGS.export_as_tfhub:
        checkpoint_path = tf.train.latest_checkpoint(ckpt_dir)
        export_as_tfhub_module(FLAGS.model, hparams, decode_hparams, problem,
                               checkpoint_path, export_dir)
        return

    run_config = t2t_trainer.create_run_config(hparams)

    estimator = create_estimator(run_config, hparams)

    exporter = tf_estimator.FinalExporter(
        "exporter",
        lambda: problem.serving_input_fn(hparams, decode_hparams, FLAGS.use_tpu),
        as_text=FLAGS.as_text)

    exporter.export(estimator,
                    export_dir,
                    checkpoint_path=checkpoint_path,
                    eval_result=None,
                    is_the_final_export=True)
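When --export_as_tfhub is set, the directory written by export_as_tfhub_module can later be consumed with tensorflow_hub. A minimal sketch; the path is a placeholder and the available signatures depend on the exported problem:

import tensorflow_hub as hub

# Sketch only: load the exported module and inspect its signatures.
module = hub.Module("/tmp/t2t_export/tfhub")  # placeholder path
print(module.get_signature_names())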
Example #27
  def __init__(self, data_dir, model_dir):
    """Creates the Transformer estimator.

    Args:
      data_dir: The training data directory.
      model_dir: The trained model directory.
    """
    # Do the pre-setup tensor2tensor requires for flags and configurations.
    FLAGS.output_dir = model_dir
    FLAGS.data_dir = data_dir
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    data_dir = os.path.expanduser(data_dir)

    # Create the basic hyper parameters.
    self.hparams = tpu_trainer_lib.create_hparams(
        FLAGS.hparams_set,
        FLAGS.hparams,
        data_dir=data_dir,
        problem_name=FLAGS.problems)

    decode_hp = decoding.decode_hparams(FLAGS.decode_hparams)
    decode_hp.add_hparam("shards", 1)
    decode_hp.add_hparam("shard_id", 0)

    # Create the estimator and final hyper parameters.
    self.estimator = tpu_trainer_lib.create_estimator(
        FLAGS.model,
        self.hparams,
        tpu_trainer.create_run_config(),
        decode_hp, use_tpu=False)

    # Fetch the vocabulary and other helpful variables for decoding.
    self.source_vocab = self.hparams.problems[0].vocabulary["inputs"]
    self.targets_vocab = self.hparams.problems[0].vocabulary["targets"]
    self.const_array_size = 10000

    # Prepare the Transformer's debug data directory.
    run_dirs = sorted(glob.glob(os.path.join("/tmp/t2t_server_dump", "run_*")))
    for run_dir in run_dirs:
      shutil.rmtree(run_dir)
Example #28
def entry(input_str):
  # global estimator
  # global hp
  # global decode_hp
  # flags.FLAGS(argv , known_only=True)

  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

  print("###self defined hp###")
  print(str(FLAGS.data_dir))
  print(str(FLAGS.problem))
  print(str(FLAGS.model))
  print(str(FLAGS.hparams_set))
  print(str(FLAGS.output_dir))
  print(str(FLAGS.decode_hparams))

  # if hp is None:
  #   print('hp is None !')
  #   hp = create_hparams()
  # if decode_hp is None:
  #   print('decode_hp is None !')
  #   decode_hp = create_decode_hparams()
  # if estimator is None:
  #   print('estimator is None !')
  #   estimator = my_trainer_lib.create_estimator(
  #     FLAGS.model,
  #     hp,
  #     t2t_trainer.create_run_config(hp),
  #     decode_hparams=decode_hp,
  #     use_tpu=FLAGS.use_tpu)

  hp=app.config['hp']
  decode_hp=app.config['decode_hp']
  estimator=app.config['estimator']

  output_decode = my_decode(estimator, hp, decode_hp,input_str)
  print('output-decode-res  = %s ' % str(output_decode))
  return output_decode
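The estimator, hp and decode_hp pulled from app.config here have to be created once at startup. A minimal sketch of that warm-up, following the commented-out construction above; the warm_up name and the Flask app object are assumptions:

def warm_up(app):
  # Sketch only: build the heavyweight objects once and stash them in app.config.
  hp = create_hparams()
  decode_hp = create_decode_hparams()
  estimator = trainer_lib.create_estimator(
      FLAGS.model,
      hp,
      t2t_trainer.create_run_config(hp),
      decode_hparams=decode_hp,
      use_tpu=FLAGS.use_tpu)
  app.config['hp'] = hp
  app.config['decode_hp'] = decode_hp
  app.config['estimator'] = estimator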
Example #29
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)
  validate_flags()
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  problem = registry.problem(FLAGS.problem)
  hparams = tf.contrib.training.HParams(
      data_dir=os.path.expanduser(FLAGS.data_dir))
  problem.get_hparams(hparams)
  request_fn = make_request_fn()
  while True:
    inputs = FLAGS.inputs_once if FLAGS.inputs_once else input(">> ")
    outputs = serving_utils.predict([inputs], problem, request_fn)
    print_str = """
Input:
{inputs}

Output:
{outputs}
    """
    print(print_str.format(inputs=inputs, outputs=outputs[0]))
    if FLAGS.inputs_once:
      break
Example #30
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

    if FLAGS.score_file:
        filename = os.path.expanduser(FLAGS.score_file)
        if not tf.gfile.Exists(filename):
            raise ValueError("The file to score doesn't exist: %s" % filename)
        results = score_file(filename)
        if not FLAGS.decode_to_file:
            raise ValueError(
                "To score a file, specify --decode_to_file for results.")
        write_file = tf.gfile.Open(os.path.expanduser(FLAGS.decode_to_file),
                                   "w")
        for score in results:
            write_file.write("%.6f\n" % score)
        write_file.close()
        return

    hp = create_hparams()
    decode_hp = create_decode_hparams()
    run_config = t2t_trainer.create_run_config(hp)
    if FLAGS.disable_grappler_optimizations:
        run_config.session_config.graph_options.rewrite_options.disable_meta_optimizer = True

    # summary-hook in tf.estimator.EstimatorSpec requires
    # hparams.model_dir to be set.
    hp.add_hparam("model_dir", run_config.model_dir)

    estimator = trainer_lib.create_estimator(FLAGS.model,
                                             hp,
                                             run_config,
                                             decode_hparams=decode_hp,
                                             use_tpu=FLAGS.use_tpu)

    input_output = decode(estimator, hp, decode_hp)
    pdd.to_pickle([input_output, yll], './tmp/output.pkl')
    print('')
Example #31
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

  if FLAGS.checkpoint_path:
    checkpoint_path = FLAGS.checkpoint_path
    ckpt_dir = os.path.dirname(checkpoint_path)
  else:
    ckpt_dir = os.path.expanduser(FLAGS.output_dir)
    checkpoint_path = tf.train.latest_checkpoint(ckpt_dir)

  hparams = create_hparams()
  hparams.no_data_parallelism = True  # To clear the devices
  problem = hparams.problem

  export_dir = FLAGS.export_dir or os.path.join(ckpt_dir, "export")

  if FLAGS.export_as_tfhub:
    checkpoint_path = tf.train.latest_checkpoint(ckpt_dir)
    decode_hparams = decoding.decode_hparams(FLAGS.decode_hparams)
    export_as_tfhub_module(FLAGS.model, hparams, decode_hparams, problem,
                           checkpoint_path, export_dir)
    return

  run_config = t2t_trainer.create_run_config(hparams)

  estimator = create_estimator(run_config, hparams)

  exporter = tf.estimator.FinalExporter(
      "exporter", lambda: problem.serving_input_fn(hparams), as_text=True)

  exporter.export(
      estimator,
      export_dir,
      checkpoint_path=checkpoint_path,
      eval_result=None,
      is_the_final_export=True)
Example #32
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    validate_flags()
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    problem = registry.problem(FLAGS.problem)
    hparams = HParams(data_dir=os.path.expanduser(FLAGS.data_dir))
    problem.get_hparams(hparams)
    request_fn = make_request_fn()
    while True:
        inputs = FLAGS.inputs_once if FLAGS.inputs_once else input(">> ")
        outputs = serving_utils.predict([inputs], problem, request_fn)
        outputs, = outputs
        output, score = outputs
        if len(score.shape) > 0:  # pylint: disable=g-explicit-length-test
            print_str = """
Input:
{inputs}

Output (Scores [{score}]):
{output}
        """
            score_text = ",".join(["{:.3f}".format(s) for s in score])
            print(
                print_str.format(inputs=inputs,
                                 output=output,
                                 score=score_text))
        else:
            print_str = """
Input:
{inputs}

Output (Score {score:.3f}):
{output}
        """
            print(print_str.format(inputs=inputs, output=output, score=score))

        if FLAGS.inputs_once:
            break
Example #33
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    validate_flags()
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    problem = registry.problem(FLAGS.problem)
    hparams = tf.contrib.training.HParams(
        data_dir=os.path.expanduser(FLAGS.data_dir))
    problem.get_hparams(hparams)
    request_fn = make_request_fn()
    if FLAGS.test_data:
        inputs = []
        with open(FLAGS.test_data, 'r') as f:
            with open(FLAGS.output, 'w+') as fout:
                print("Id,Prediction", file=fout)
                for line in tqdm(f):
                    num, text = line.rstrip().split(',', 1)
                    outputs = serving_utils.predict([text], problem,
                                                    request_fn)
                    print('{},{}'.format(
                        num, "-1" if outputs[0][0] == "neg" else "1"),
                          file=fout)
    else:
        print("Missing test_data nd output file")
Example #34
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    validate_flags()
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    problem = registry.problem(FLAGS.problem)
    hparams = HParams(data_dir=os.path.expanduser(FLAGS.data_dir))
    problem.get_hparams(hparams)
    request_fn = make_request_fn()
    while True:
        inputs = FLAGS.inputs_once if FLAGS.inputs_once else input(">> ")
        if FLAGS.json:
            inputs = json.loads(inputs)
            ret = serving_utils.predict_features([inputs], problem, request_fn)
            outputs = ret["outputs"]
        else:
            outputs = serving_utils.predict([inputs], problem, request_fn)
        outputs, = outputs
        output, score = outputs
        if problem.multi_targets:
            print_str = """
Input:
{inputs}

Output (Score {score}):
{output}
      """
        else:
            print_str = """
Input:
{inputs}

Output (Score {score:.3f}):
{output}
      """
        print(print_str.format(inputs=inputs, output=output, score=score))
        if FLAGS.inputs_once:
            break
Example #35
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

    problem = registry.problem(FLAGS.problem)
    hparams = tf.contrib.training.HParams(
        data_dir=os.path.expanduser(FLAGS.data_dir))
    problem.get_hparams(hparams)

    fname = "inputs" if problem.has_inputs else "targets"
    input_encoder = problem.feature_info[fname].encoder
    output_decoder = problem.feature_info["targets"].encoder

    stub = create_stub()

    while True:
        prompt = ">> "
        if FLAGS.inputs_once:
            inputs = FLAGS.inputs_once
        else:
            inputs = input(prompt)

        input_ids = encode(inputs, input_encoder)
        output_ids = query(stub, input_ids, feature_name=fname)

        outputs = decode(output_ids, output_decoder)

        print_str = """
Input:
{inputs}

Output:
{outputs}
    """
        print(print_str.format(inputs=inputs, outputs=outputs))
        if FLAGS.inputs_once:
            break
Example #36
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

  problem = registry.problem(FLAGS.problem)
  hparams = tf.contrib.training.HParams(
      data_dir=os.path.expanduser(FLAGS.data_dir))
  problem.get_hparams(hparams)

  fname = "inputs" if problem.has_inputs else "targets"
  input_encoder = problem.feature_info[fname].encoder
  output_decoder = problem.feature_info["targets"].encoder

  stub = create_stub()

  while True:
    prompt = ">> "
    if FLAGS.inputs_once:
      inputs = FLAGS.inputs_once
    else:
      inputs = input(prompt)

    input_ids = encode(inputs, input_encoder)
    output_ids = query(stub, input_ids, feature_name=fname)

    outputs = decode(output_ids, output_decoder)

    print_str = """
Input:
{inputs}

Output:
{outputs}
    """
    print(print_str.format(inputs=inputs, outputs=outputs))
    if FLAGS.inputs_once:
      break
Example #37
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

    if FLAGS.score_file:
        filename = os.path.expanduser(FLAGS.score_file)
        if not tf.gfile.Exists(filename):
            raise ValueError("The file to score doesn't exist: %s" % filename)
        results = score_file(filename)
        if not FLAGS.decode_to_file:
            raise ValueError(
                "To score a file, specify --decode_to_file for results.")
        write_file = tf.gfile.Open(os.path.expanduser(FLAGS.decode_to_file),
                                   "w")
        for score in results:
            write_file.write("%.6f\n" % score)

            print(
                "=================================================================================================================================="
            )

            print(score)
        write_file.close()
        return

    hp = create_hparams()
    decode_hp = create_decode_hparams()

    estimator = trainer_lib.create_estimator(FLAGS.model,
                                             hp,
                                             t2t_trainer.create_run_config(hp),
                                             decode_hparams=decode_hp,
                                             use_tpu=FLAGS.use_tpu)

    decode(estimator, hp, decode_hp)
Example #38
def Translation(input):
  start = time.time()
  tf.logging.set_verbosity(tf.logging.INFO)
  validate_flags()
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  problem = registry.problem(FLAGS.problem)
  hparams = tf.contrib.training.HParams(
      data_dir=os.path.expanduser(FLAGS.data_dir))
  problem.get_hparams(hparams)
  request_fn = make_request_fn()

  # if FLAGS.word_cut:
  #   input = " ".join(jieba.cut(input))
  outputs = serving_utils.predict(input, problem, request_fn)
  print('outputs:',outputs)
  # outputs = '.'.join(result for result,score in outputs)
  for result, _ in outputs:
      yield result

  # outputs, = outputs
  # output, score = outputs
  # end = time.time()
  # print('client time:',(end - start))
  print_str = """
Example #39
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

    ckpt_dir = os.path.expanduser(FLAGS.output_dir)

    hparams = create_hparams()
    hparams.no_data_parallelism = True  # To clear the devices
    run_config = t2t_trainer.create_run_config(hparams)

    estimator = create_estimator(run_config, hparams)

    problem = hparams.problem

    exporter = tf.estimator.FinalExporter(
        "exporter", lambda: problem.serving_input_fn(hparams), as_text=True)

    export_dir = os.path.join(ckpt_dir, "export")
    exporter.export(estimator,
                    export_dir,
                    checkpoint_path=tf.train.latest_checkpoint(ckpt_dir),
                    eval_result=None,
                    is_the_final_export=True)
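The SavedModel written by this exporter can be loaded back for local inference with the TF1 predictor API. A minimal sketch; the timestamped export path is a placeholder and the feed keys depend on problem.serving_input_fn:

from tensorflow.contrib import predictor

# Sketch only: point at one timestamped export directory and obtain a callable.
predict_fn = predictor.from_saved_model("/path/to/export/exporter/1556123456")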
Example #40
def create_teacher_experiment(run_config, hparams, argv):
    """Creates experiment function."""
    tf.logging.info("training teacher")
    tf.logging.set_verbosity(tf.logging.INFO)
    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    t2t_trainer.maybe_log_registry_and_exit()

    if FLAGS.cloud_mlengine:
        return cloud_mlengine.launch()

    if FLAGS.generate_data:
        t2t_trainer.generate_data()

    if cloud_mlengine.job_dir():
        FLAGS.output_dir = cloud_mlengine.job_dir()

    if argv:
        t2t_trainer.set_hparams_from_args(argv[1:])

    hparams.distill_phase = "train"
    exp_fn = t2t_trainer.create_experiment_fn()
    exp = exp_fn(run_config, hparams)
    return exp
Example #41
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    problem = registry.problem(FLAGS.problem)
    hparams = tf.contrib.training.HParams(
        data_dir=os.path.expanduser(FLAGS.data_dir))
    problem.get_hparams(hparams)
    predict_fn = make_predict_fn()
    tokenizer = MosesTokenizer("en")
    while True:
        inputs = input(">> ")
        inputs = tokenizer(inputs)
        inputs = html.unescape(" ".join(inputs).replace("@-@", "-"))
        outputs = predict([inputs], problem, predict_fn)
        outputs, = outputs
        output, score = outputs
        print_str = """
Input:
{inputs}

Output (Score {score:.3f}):
{output}
    """
        print(print_str.format(inputs=inputs, output=output, score=score))
Example #42
  def __init__(self, params, train_path="", dev_path="", test_path="",
               cleanup=False, p2g_mode=False):
    # Point out the current directory with t2t problem specified for g2p task.
    usr_dir.import_usr_dir(os.path.dirname(os.path.abspath(__file__)))
    self.params = params
    self.test_path = test_path
    if not os.path.exists(self.params.model_dir):
      os.makedirs(self.params.model_dir)

    # Register g2p problem.
    self.problem = registry._PROBLEMS[self.params.problem_name](
        self.params.model_dir, train_path=train_path, dev_path=dev_path,
        test_path=test_path, cleanup=cleanup, p2g_mode=p2g_mode)

    self.frozen_graph_filename = os.path.join(self.params.model_dir,
                                              "frozen_model.pb")
    self.inputs, self.features, self.input_fn = None, None, None
    self.mon_sess, self.estimator_spec, self.g2p_gt_map = None, None, None
    self.first_ex = False
    if train_path:
      self.train_preprocess_file_path, self.dev_preprocess_file_path =\
          None, None
      self.estimator, self.decode_hp, self.hparams =\
          self.__prepare_model(train_mode=True)
      self.train_preprocess_file_path, self.dev_preprocess_file_path =\
          self.problem.generate_preprocess_data()

    elif os.path.exists(self.frozen_graph_filename):
      self.estimator, self.decode_hp, self.hparams =\
          self.__prepare_model()
      self.__load_graph()
      self.checkpoint_path = tf.train.latest_checkpoint(self.params.model_dir)

    else:
      self.estimator, self.decode_hp, self.hparams =\
          self.__prepare_model()
Example #43
def main(argv):
    tf.logging.set_verbosity(tf.logging.INFO)
    hparams = create_hparams()
    if FLAGS.schedule == "train" or FLAGS.schedule == "train_eval_and_decode":
        mlperf_log.transformer_print(key=mlperf_log.RUN_START,
                                     mlperf_mode=hparams.mlperf_mode)
    if FLAGS.schedule == "run_std_server":
        run_std_server()
    mlperf_log.transformer_print(key=mlperf_log.RUN_SET_RANDOM_SEED,
                                 value=FLAGS.random_seed,
                                 mlperf_mode=hparams.mlperf_mode)
    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    maybe_log_registry_and_exit()

    if FLAGS.cloud_mlengine:
        cloud_mlengine.launch()
        return

    if FLAGS.generate_data:
        generate_data()

    if cloud_mlengine.job_dir():
        FLAGS.output_dir = cloud_mlengine.job_dir()

    if argv:
        set_hparams_from_args(argv[1:])

    exp_fn = create_experiment_fn()
    exp = exp_fn(create_run_config(hparams), hparams)
    if is_chief():
        save_metadata(hparams)
    execute_schedule(exp)
    if FLAGS.schedule != "train":
        mlperf_log.transformer_print(key=mlperf_log.RUN_FINAL,
                                     mlperf_mode=hparams.mlperf_mode)
Example #44
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)
  validate_flags()
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  problem = registry.problem(FLAGS.problem)
  hparams = tf.contrib.training.HParams(
      data_dir=os.path.expanduser(FLAGS.data_dir))
  problem.get_hparams(hparams)
  request_fn = make_request_fn()
  while True:
    inputs = FLAGS.inputs_once if FLAGS.inputs_once else input(">> ")
    outputs = serving_utils.predict([inputs], problem, request_fn)
    outputs, = outputs
    output, score = outputs
    print_str = """
Input:
{inputs}

Output (Score {score:.3f}):
{output}
    """
    print(print_str.format(inputs=inputs, output=output, score=score))
    if FLAGS.inputs_once:
      break
Example #45
def main(_):
    # Fathom

    tf.logging.set_verbosity(tf.logging.INFO)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

    # Fathom
    fathom_t2t_model_setup()

    # Calculate the list of problems to generate.
    problems = sorted(
        list(_SUPPORTED_PROBLEM_GENERATORS) + registry.list_problems())
    for exclude in FLAGS.exclude_problems.split(","):
        if exclude:
            problems = [p for p in problems if exclude not in p]
    if FLAGS.problem and FLAGS.problem[-1] == "*":
        problems = [p for p in problems if p.startswith(FLAGS.problem[:-1])]
    elif FLAGS.problem and "," in FLAGS.problem:
        problems = [p for p in problems if p in FLAGS.problem.split(",")]
    elif FLAGS.problem:
        problems = [p for p in problems if p == FLAGS.problem]
    else:
        problems = []

    # Remove TIMIT if paths are not given.
    if not getattr(FLAGS, "timit_paths", None):
        problems = [p for p in problems if "timit" not in p]
    # Remove parsing if paths are not given.
    if not getattr(FLAGS, "parsing_path", None):
        problems = [p for p in problems if "parsing_english_ptb" not in p]

    if not problems:
        problems_str = "\n  * ".join(
            sorted(
                list(_SUPPORTED_PROBLEM_GENERATORS) +
                registry.list_problems()))
        error_msg = ("You must specify one of the supported problems to "
                     "generate data for:\n  * " + problems_str + "\n")
        error_msg += ("TIMIT and parsing need data_sets specified with "
                      "--timit_paths and --parsing_path.")
        raise ValueError(error_msg)

    if not FLAGS.data_dir:
        FLAGS.data_dir = tempfile.gettempdir()
        tf.logging.warning(
            "It is strongly recommended to specify --data_dir. "
            "Data will be written to default data_dir=%s.", FLAGS.data_dir)

    tf.logging.info(
        "Generating problems:\n%s" %
        registry.display_list_by_prefix(problems, starting_spaces=4))
    if FLAGS.only_list:
        return
    for problem in problems:
        set_random_seed()

        if problem in _SUPPORTED_PROBLEM_GENERATORS:
            generate_data_for_problem(problem)
        else:
            generate_data_for_registered_problem(problem)

    # Fathom
    xcom.echo_yaml_for_xcom_ingest({'t2t_data_dir': FLAGS.data_dir})
Example #46
def main(argv):
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()


  if FLAGS.cloud_mlengine:
    cloud_mlengine.launch()
    return

  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])
  hparams = t2t_trainer.create_hparams()
  trainer_lib.add_problem_hparams(hparams, FLAGS.problem)
  attack_params = create_attack_params()
  attack_params.add_hparam("eps", 0.0)

  config = t2t_trainer.create_run_config(hparams)
  params = {"batch_size": hparams.batch_size}

  # add "_rev" as a hack to avoid image standardization
  problem = registry.problem(FLAGS.problem + "_rev")
  input_fn = problem.make_estimator_input_fn(
      tf.estimator.ModeKeys.EVAL, hparams)
  dataset = input_fn(params, config).repeat()
  features, _ = dataset.make_one_shot_iterator().get_next()
  inputs, labels = features["targets"], features["inputs"]
  inputs = tf.to_float(inputs)
  labels = tf.squeeze(labels)

  sess = tf.Session()

  model_fn = t2t_model.T2TModel.make_estimator_model_fn(
      FLAGS.model, hparams, use_tpu=FLAGS.use_tpu)
  ch_model = adv_attack_utils.T2TAttackModel(model_fn, params, config)

  acc_mask = None
  probs = ch_model.get_probs(inputs)
  if FLAGS.ignore_incorrect:
    preds = tf.argmax(probs, -1)
    preds = tf.squeeze(preds)
    acc_mask = tf.to_float(tf.equal(labels, preds))
  one_hot_labels = tf.one_hot(labels, probs.shape[-1])

  attack = create_attack(attack_params.attack)(ch_model, sess=sess)

  # Restore weights
  saver = tf.train.Saver()
  checkpoint_path = os.path.expanduser(FLAGS.output_dir or
                                       FLAGS.checkpoint_path)
  saver.restore(sess, tf.train.latest_checkpoint(checkpoint_path))

  # reuse variables
  tf.get_variable_scope().reuse_variables()

  def compute_accuracy(x, labels, mask):
    preds = ch_model.get_probs(x)
    preds = tf.squeeze(preds)
    preds = tf.argmax(preds, -1, output_type=labels.dtype)
    _, acc_update_op = tf.metrics.accuracy(
        labels=labels, predictions=preds, weights=mask)
    sess.run(tf.initialize_local_variables())
    for _ in range(FLAGS.eval_steps):
      acc = sess.run(acc_update_op)
    return acc

  acc = compute_accuracy(inputs, labels, acc_mask)
  epsilon_acc_pairs = [(0.0, acc)]
  for epsilon in attack_params.attack_epsilons:
    attack_params.eps = epsilon
    adv_x = attack.generate(inputs, y=one_hot_labels, **attack_params.values())
    acc = compute_accuracy(adv_x, labels, acc_mask)
    epsilon_acc_pairs.append((epsilon, acc))

  for epsilon, acc in epsilon_acc_pairs:
    tf.logging.info("Accuracy @ eps=%f: %f" % (epsilon, acc))
Example #47
    def _init_env(self):
        FLAGS.use_tpu = False
        tf.logging.set_verbosity(tf.logging.DEBUG)
        tf.logging.info("Import usr dir from %s", self._usr_dir)
        if self._usr_dir is not None:
            usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
        tf.logging.info("Start to create hparams for %s of %s", self._problem,
                        self._hparams_set)

        # Get the model hparams
        self._hparams = create_hparams()
        # Get the decoding hparams
        self._hparams_decode = create_decode_hparams(
            extra_length=self._extra_length,
            batch_size=self._batch_size,
            beam_size=self._beam_size,
            alpha=self._alpha,
            return_beams=self._return_beams,
            write_beam_scores=self._write_beam_scores,
            force_decode_length=self._force_decode_length)

        # self.estimator = trainer_lib.create_estimator(
        #     FLAGS.model,
        #     self._hparams,
        #     t2t_trainer.create_run_config(self._hparams),
        #     decode_hparams=self._hparams_decode,
        #     use_tpu=False)

        tf.logging.info("Finish intialize environment")

        #######

        ### make input placeholder
        self._inputs_ph = tf.placeholder(
            dtype=tf.int32)  # shape not specified, any shape

        x = tf.placeholder(dtype=tf.int32)
        x.set_shape([None, None])  # ? -> (?,?)
        x = tf.expand_dims(x, axis=[2])  # -> (?,?,1)
        x = tf.to_int32(x)
        self._inputs_ph = x

        #batch_inputs = tf.reshape(self._inputs_ph, [self._batch_size, -1, 1, 1])
        batch_inputs = x
        ###

        # batch_inputs = tf.reshape(self._inputs_ph, [-1, -1, 1, 1])

        #targets_ph = tf.placeholder(dtype=tf.int32)
        #batch_targets = tf.reshape(targets_ph, [1, -1, 1, 1])
        self._features = {
            "inputs": batch_inputs,
            "problem_choice": 0,  # We run on the first problem here.
            "input_space_id": self._hparams.problem_hparams.input_space_id,
            "target_space_id": self._hparams.problem_hparams.target_space_id
        }
        ### add decode_length (variable length; unused for classification)
        self.input_extra_length_ph = tf.placeholder(dtype=tf.int32)
        self._features['decode_length'] = self.input_extra_length_ph
        ## target
        self._targets_ph = tf.placeholder(tf.int32,
                                          shape=(None, None, None, None),
                                          name='targets')
        self._features['targets'] = self._targets_ph
        target_pretend = np.zeros((1, 1, 1, 1))

        ## drop the integer-valued features
        del self._features["problem_choice"]
        del self._features["input_space_id"]
        del self._features["target_space_id"]
        del self._features['decode_length']
        ####
        #mode = tf.estimator.ModeKeys.PREDICT # affect last_only  t2t_model._top_single  ,[1,?,1,512]->[1,1,1,1,64]
        # if self.predict_or_eval=='EVAL':
        #     mode = tf.estimator.ModeKeys.EVAL # affect last_only  t2t_model._top_single  ,[1,?,1,512]->[1,?,1,1,64]
        # # estimator_spec = model_builder.model_fn(self._model_name, features, mode, self._hparams,
        # #                                         problem_names=[self._problem], decode_hparams=self._hparams_dc)
        # if self.predict_or_eval=='PREDICT':
        #     mode = tf.estimator.ModeKeys.PREDICT

        if self.predict_or_eval == 'and':
            mode = tf.estimator.ModeKeys.EVAL

        ###########
        # registry.model
        ############
        translate_model = registry.model(self._model_name)(
            hparams=self._hparams,
            decode_hparams=self._hparams_decode,
            mode=mode)

        self.predict_dict = {}

        ### get logit ,EVAL mode
        self.logits, _ = translate_model(self._features)
        ### get infer result ,PREDICT mode
        translate_model.set_mode(tf.estimator.ModeKeys.PREDICT)
        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            self.outputs_scores = translate_model.infer(
                features=self._features,
                decode_length=50,
                beam_size=self._beam_size,
                top_beams=self._beam_size,
                alpha=self._alpha)

        ######

        tf.logging.info("Start to init tf session")
        if self._isGpu:
            print('Using GPU in Decoder')
            gpu_options = tf.GPUOptions(
                per_process_gpu_memory_fraction=self._fraction)
            self._sess = tf.Session(
                config=tf.ConfigProto(allow_soft_placement=True,
                                      log_device_placement=False,
                                      gpu_options=gpu_options))
        else:
            print('Using CPU in Decoder')
            gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0)
            config = tf.ConfigProto(gpu_options=gpu_options)
            config.allow_soft_placement = True
            config.log_device_placement = False
            self._sess = tf.Session(config=config)
        with self._sess.as_default():
            ckpt = saver_mod.get_checkpoint_state(self._model_dir)
            saver = tf.train.Saver()
            tf.logging.info("Start to restore the parameters from %s",
                            ckpt.model_checkpoint_path)
            saver.restore(self._sess, ckpt.model_checkpoint_path)
        tf.logging.info("Finish intialize environment")
    def _init_env(self):
        FLAGS.use_tpu = False
        tf.logging.set_verbosity(tf.logging.DEBUG)
        tf.logging.info("Import usr dir from %s", self._usr_dir)
        if self._usr_dir is not None:
            usr_dir.import_usr_dir(self._usr_dir)
        tf.logging.info("Start to create hparams,for %s of %s", self._problem,
                        self._hparams_set)

        self._hparams = create_hparams()
        self._hparams_decode = create_decode_hparams(
            extra_length=self._extra_length,
            batch_size=self._batch_size,
            beam_size=self._beam_size,
            alpha=self._alpha,
            return_beams=self._return_beams,
            write_beam_scores=self._write_beam_scores,
            force_decode_length=self._force_decode_length)

        self.estimator = trainer_lib.create_estimator(
            FLAGS.model,
            self._hparams,
            t2t_trainer.create_run_config(self._hparams),
            decode_hparams=self._hparams_decode,
            use_tpu=False)

        tf.logging.info("Finish intialize environment")

        #######

        ### make input placeholder
        #self._inputs_ph = tf.placeholder(dtype=tf.int32)  # shape not specified,any shape

        # x=tf.placeholder(dtype=tf.int32)
        # x.set_shape([None, None]) # ? -> (?,?)
        # x = tf.expand_dims(x, axis=[2])# -> (?,?,1)
        # x = tf.to_int32(x)
        # self._inputs_ph=x

        #batch_inputs = tf.reshape(self._inputs_ph, [self._batch_size, -1, 1, 1])
        #batch_inputs=x
        ###

        # batch_inputs = tf.reshape(self._inputs_ph, [-1, -1, 1, 1])

        #targets_ph = tf.placeholder(dtype=tf.int32)
        #batch_targets = tf.reshape(targets_ph, [1, -1, 1, 1])

        self.inputs_ph = tf.placeholder(tf.int32,
                                        shape=(None, None, 1, 1),
                                        name='inputs')
        self.targets_ph = tf.placeholder(tf.int32,
                                         shape=(None, None, None, None),
                                         name='targets')
        self.input_extra_length_ph = tf.placeholder(dtype=tf.int32, shape=[])

        self._features = {
            "inputs": self.inputs_ph,
            "problem_choice": 0,  # We run on the first problem here.
            "input_space_id": self._hparams.problem_hparams.input_space_id,
            "target_space_id": self._hparams.problem_hparams.target_space_id
        }
        ### add decode_length (variable length)
        self._features['decode_length'] = self.input_extra_length_ph
        ## target
        self._features['targets'] = self.targets_ph

        ## drop the integer-valued features
        del self._features["problem_choice"]
        del self._features["input_space_id"]
        del self._features["target_space_id"]
        #del self._features['decode_length']
        ####

        mode = tf.estimator.ModeKeys.EVAL

        translate_model = registry.model(self._model_name)(
            hparams=self._hparams,
            decode_hparams=self._hparams_decode,
            mode=mode)

        self.predict_dict = {}

        ### get logit  ,attention mats
        self.logits, _ = translate_model(self._features)  # [batch, length, 1, 1, vocab_size]
        #translate_model(features)
        from visualization import get_att_mats
        self.att_mats = get_att_mats(translate_model,
                                     self._model_name)  # enc, dec, encdec
        ### get infer
        translate_model.set_mode(tf.estimator.ModeKeys.PREDICT)
        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            self.outputs_scores = translate_model.infer(
                features=self._features,
                decode_length=self._extra_length,
                beam_size=self._beam_size,
                top_beams=self._beam_size,
                alpha=self._alpha)  #outputs 4,4,63

        ######
        tf.logging.info("Start to init tf session")
        if self._isGpu:
            print('Using GPU in Decoder')
            gpu_options = tf.GPUOptions(
                per_process_gpu_memory_fraction=self._fraction)
            self._sess = tf.Session(
                config=tf.ConfigProto(allow_soft_placement=True,
                                      log_device_placement=False,
                                      gpu_options=gpu_options))
        else:
            print('Using CPU in Decoder')
            gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0)
            config = tf.ConfigProto(gpu_options=gpu_options)
            config.allow_soft_placement = True
            config.log_device_placement = False
            self._sess = tf.Session(config=config)
        with self._sess.as_default():
            ckpt = saver_mod.get_checkpoint_state(self._model_dir)
            saver = tf.train.Saver()
            tf.logging.info("Start to restore the parameters from %s",
                            ckpt.model_checkpoint_path)
            saver.restore(self._sess, ckpt.model_checkpoint_path)
        tf.logging.info("Finish intialize environment")
Ejemplo n.º 49
0
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

  # Create hparams
  hparams = trainer_lib.create_hparams(
      FLAGS.hparams_set,
      FLAGS.hparams,
      data_dir=os.path.expanduser(FLAGS.data_dir),
      problem_name=FLAGS.problem)
  hparams.force_full_predict = True
  hparams.scheduled_sampling_k = -1

  # Params
  num_agents = 1  # TODO(mbz): fix the code for more agents
  num_steps = FLAGS.num_steps
  if hasattr(hparams.problem, "num_actions"):
    num_actions = hparams.problem.num_actions
  else:
    num_actions = None
  frame_shape = hparams.problem.frame_shape
  resized_frame = hparams.preprocess_resize_frames is not None
  if resized_frame:
    frame_shape = hparams.preprocess_resize_frames
    frame_shape += [hparams.problem.num_channels]

  dataset = registry.problem(FLAGS.problem).dataset(
      tf.estimator.ModeKeys.TRAIN,
      shuffle_files=True,
      data_dir=os.path.expanduser(FLAGS.data_dir),
      hparams=hparams)

  dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(num_agents))
  data = dataset.make_one_shot_iterator().get_next()
  # Setup input placeholders
  input_size = [num_agents, hparams.video_num_input_frames]
  if num_actions is None:
    placeholders = {
        "inputs": tf.placeholder(tf.float32, input_size + frame_shape)
    }
  else:
    placeholders = {
        "inputs": tf.placeholder(tf.float32, input_size + frame_shape),
        "input_action": tf.placeholder(tf.int64, input_size + [1]),
        "input_reward": tf.placeholder(tf.int64, input_size + [1]),
        "reset_internal_states": tf.placeholder(tf.float32, []),
    }
  # Create model.
  model_cls = registry.model(FLAGS.model)
  model = model_cls(hparams, tf.estimator.ModeKeys.PREDICT)
  prediction_ops = model.infer(placeholders)

  states_q = Queue(maxsize=hparams.video_num_input_frames)
  actions_q = Queue(maxsize=hparams.video_num_input_frames)
  rewards_q = Queue(maxsize=hparams.video_num_input_frames)
  if num_actions is not None:
    all_qs = [states_q, actions_q, rewards_q]
  else:
    all_qs = [states_q]

  writer = common_video.WholeVideoWriter(
      fps=FLAGS.fps, output_path=FLAGS.output_gif)

  saver = tf.train.Saver(tf.trainable_variables())
  with tf.train.SingularMonitoredSession() as sess:
    # Load latest checkpoint
    ckpt = tf.train.get_checkpoint_state(FLAGS.output_dir).model_checkpoint_path
    saver.restore(sess.raw_session(), ckpt)

    # get init frames from the dataset
    data_np = sess.run(data)

    frames = np.split(data_np["inputs"], hparams.video_num_input_frames, 1)
    for frame in frames:
      frame = np.squeeze(frame, 1)
      states_q.put(frame)
      writer.write(frame[0].astype(np.uint8))

    if num_actions is not None:
      actions = np.split(data_np["input_action"],
                         hparams.video_num_input_frames, 1)
      for action in actions:
        actions_q.put(np.squeeze(action, 1))

      rewards = np.split(data_np["input_reward"],
                         hparams.video_num_input_frames, 1)
      for reward in rewards:
        rewards_q.put(np.squeeze(reward, 1))

    for step in range(num_steps):
      print(">>>>>>> ", step)

      if num_actions is not None:
        random_actions = np.random.randint(num_actions-1)
        random_actions = np.expand_dims(random_actions, 0)
        random_actions = np.tile(random_actions, (num_agents, 1))

        # Shape inputs and targets
        inputs, input_action, input_reward = (
            np.stack(list(q.queue), axis=1) for q in all_qs)
      else:
        assert len(all_qs) == 1
        q = all_qs[0]
        elems = list(q.queue)
        # Need to adjust shapes sometimes.
        for i, e in enumerate(elems):
          if len(e.shape) < 4:
            elems[i] = np.expand_dims(e, axis=0)
        inputs = np.stack(elems, axis=1)

      # Predict next frames
      if num_actions is None:
        feed = {placeholders["inputs"]: inputs}
      else:
        feed = {
            placeholders["inputs"]: inputs,
            placeholders["input_action"]: input_action,
            placeholders["input_reward"]: input_reward,
            placeholders["reset_internal_states"]: float(step == 0),
        }
      predictions = sess.run(prediction_ops, feed_dict=feed)

      if num_actions is None:
        predicted_states = predictions[:, 0]
      else:
        predicted_states = predictions["targets"][:, 0]
        predicted_reward = predictions["target_reward"][:, 0]

      # Update queues
      if num_actions is None:
        new_data = (predicted_states,)  # single-element tuple so zip pairs it with states_q
      else:
        new_data = (predicted_states, random_actions, predicted_reward)
      for q, d in zip(all_qs, new_data):
        q.get()
        q.put(d.copy())

      writer.write(np.round(predicted_states[0]).astype(np.uint8))

    writer.finish_to_disk()
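
The rollout above keeps the model's conditioning window in bounded queues: each queue holds exactly video_num_input_frames entries, and every step drops the oldest entry and appends the newest prediction before the window is re-stacked and fed back to the model. A minimal self-contained sketch of that sliding-window pattern, with hypothetical shapes and a dummy predictor in place of the real model:

from queue import Queue
import numpy as np

window = 4                               # stands in for video_num_input_frames
frames_q = Queue(maxsize=window)
for t in range(window):                  # seed with the initial context frames
  frames_q.put(np.full((8, 8, 3), t, dtype=np.uint8))

for step in range(3):
  stacked = np.stack(list(frames_q.queue), axis=0)  # (window, H, W, C) model input
  new_frame = stacked[-1] + 1                       # dummy "prediction"
  frames_q.get()                                    # drop the oldest frame
  frames_q.put(new_frame.copy())                    # append the newest prediction
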
Ejemplo n.º 50
0
def main(argv):
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()


  if FLAGS.cloud_mlengine:
    cloud_mlengine.launch()
    return

  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])

  if FLAGS.surrogate_attack:
    tf.logging.warn("Performing surrogate model attack.")
    sur_hparams = create_surrogate_hparams()
    trainer_lib.add_problem_hparams(sur_hparams, FLAGS.problem)

  hparams = t2t_trainer.create_hparams()
  trainer_lib.add_problem_hparams(hparams, FLAGS.problem)

  attack_params = create_attack_params()
  attack_params.add_hparam(attack_params.epsilon_name, 0.0)

  if FLAGS.surrogate_attack:
    sur_config = create_surrogate_run_config(sur_hparams)
  config = t2t_trainer.create_run_config(hparams)
  params = {
      "batch_size": hparams.batch_size,
      "use_tpu": FLAGS.use_tpu,
  }

  # add "_rev" as a hack to avoid image standardization
  problem = registry.problem(FLAGS.problem + "_rev")

  inputs, labels, features = prepare_data(problem, hparams, params, config)

  sess = tf.Session()

  if FLAGS.surrogate_attack:
    sur_model_fn = t2t_model.T2TModel.make_estimator_model_fn(
        FLAGS.surrogate_model, sur_hparams, use_tpu=FLAGS.use_tpu)
    sur_ch_model = adv_attack_utils.T2TAttackModel(
        sur_model_fn, features, params, sur_config, scope="surrogate")
    # Dummy call to construct graph
    sur_ch_model.get_probs(inputs)

    checkpoint_path = os.path.expanduser(FLAGS.surrogate_output_dir)
    tf.contrib.framework.init_from_checkpoint(
        tf.train.latest_checkpoint(checkpoint_path), {"/": "surrogate/"})
    sess.run(tf.global_variables_initializer())

  other_vars = set(tf.global_variables())

  model_fn = t2t_model.T2TModel.make_estimator_model_fn(
      FLAGS.model, hparams)
  ch_model = adv_attack_utils.T2TAttackModel(model_fn, features, params, config)

  acc_mask = None
  probs = ch_model.get_probs(inputs)
  if FLAGS.ignore_incorrect:
    preds = tf.argmax(probs, -1, output_type=labels.dtype)
    preds = tf.reshape(preds, labels.shape)
    acc_mask = tf.to_float(tf.equal(labels, preds))
  one_hot_labels = tf.one_hot(labels, probs.shape[-1])

  if FLAGS.surrogate_attack:
    attack = create_attack(attack_params.attack)(sur_ch_model, sess=sess)
  else:
    attack = create_attack(attack_params.attack)(ch_model, sess=sess)

  new_vars = set(tf.global_variables()) - other_vars

  # Restore weights
  saver = tf.train.Saver(new_vars)
  checkpoint_path = os.path.expanduser(FLAGS.output_dir)
  saver.restore(sess, tf.train.latest_checkpoint(checkpoint_path))

  # reuse variables
  tf.get_variable_scope().reuse_variables()

  def compute_accuracy(x, l, mask):
    """Compute model accuracy."""
    preds = ch_model.get_probs(x)
    preds = tf.squeeze(preds)
    preds = tf.argmax(preds, -1, output_type=l.dtype)

    _, acc_update_op = tf.metrics.accuracy(l, preds, weights=mask)

    if FLAGS.surrogate_attack:
      preds = sur_ch_model.get_probs(x)
      preds = tf.squeeze(preds)
      preds = tf.argmax(preds, -1, output_type=l.dtype)
      acc_update_op = tf.tuple((acc_update_op,
                                tf.metrics.accuracy(l, preds, weights=mask)[1]))

    sess.run(tf.initialize_local_variables())
    for i in range(FLAGS.eval_steps):
      tf.logging.info(
          "\tEvaluating batch [%d / %d]" % (i + 1, FLAGS.eval_steps))
      acc = sess.run(acc_update_op)
    if FLAGS.surrogate_attack:
      tf.logging.info("\tFinal acc: (%.4f, %.4f)" % (acc[0], acc[1]))
    else:
      tf.logging.info("\tFinal acc: %.4f" % acc)
    return acc

  epsilon_acc_pairs = []
  for epsilon in attack_params.attack_epsilons:
    tf.logging.info("Attacking @ eps=%.4f" % epsilon)
    attack_params.set_hparam(attack_params.epsilon_name, epsilon)
    adv_x = attack.generate(inputs, y=one_hot_labels, **attack_params.values())
    acc = compute_accuracy(adv_x, labels, acc_mask)
    epsilon_acc_pairs.append((epsilon, acc))

  for epsilon, acc in epsilon_acc_pairs:
    if FLAGS.surrogate_attack:
      tf.logging.info(
          "Accuracy @ eps=%.4f: (%.4f, %.4f)" % (epsilon, acc[0], acc[1]))
    else:
      tf.logging.info("Accuracy @ eps=%.4f: %.4f" % (epsilon, acc))