Code Example #1
    def load(self):

        run_config = t2t_trainer.create_run_config(self.hp)
        self.hp.add_hparam("model_dir", run_config.model_dir)
        self.estimator = trainer_lib.create_estimator(
            self.model,
            self.hp,
            run_config,
            decode_hparams=self.decode_hp,
            use_tpu=self.use_tpu)

        self.estimator_predictor = tf.contrib.predictor.from_estimator(
            self.estimator,
            self.input_fn,
            config=tf.ConfigProto(log_device_placement=True,
                                  allow_soft_placement=True))
        FLAGS.problem = "translate_enfr_wmt32k_rev"
        self.problem = "translate_enfr_wmt32k_rev"
        self.problem_name = self.problem
        FLAGS.checkpoint_path = os.path.join(
            os.getcwd(), "checkpoints/fren/model.ckpt-500000")
        run_config = t2t_trainer.create_run_config(self.hp)
        self.hp.model_dir = run_config.model_dir
        self.estimator = trainer_lib.create_estimator(
            self.model,
            self.hp,
            run_config,
            decode_hparams=self.decode_hp,
            use_tpu=self.use_tpu)

        self.estimator_decoder_predictor = tf.contrib.predictor.from_estimator(
            self.estimator,
            self.input_fn,
            config=tf.ConfigProto(log_device_placement=True,
                                  allow_soft_placement=True))
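
A note on the wrapper used twice above: in TF 1.x, tf.contrib.predictor.from_estimator(estimator, serving_input_receiver_fn, ...) builds the graph once and keeps a session open, so repeated prediction calls skip per-call graph construction. A minimal sketch of the same pattern, assuming TF 1.x; the feature name "inputs" and its shape are illustrative placeholders, not taken from the snippet:

import tensorflow as tf

def make_predictor(estimator):
    # Feed a raw int32 "inputs" tensor directly as the model feature.
    serving_input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(
        {"inputs": tf.placeholder(tf.int32, shape=[None, None, 1, 1])})
    # The predictor owns a long-lived session; call it like a function:
    # predictor({"inputs": batch}).
    return tf.contrib.predictor.from_estimator(estimator, serving_input_fn)
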
Code Example #2
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

    hparams = trainer_lib.create_hparams(FLAGS.hparams_set,
                                         FLAGS.hparams,
                                         data_dir=FLAGS.data_dir,
                                         problem_name=FLAGS.problem)

    # Use the test split when FLAGS.eval_use_test_set is set; otherwise the default.
    dataset_split = "test" if FLAGS.eval_use_test_set else None
    dataset_kwargs = {"dataset_split": dataset_split}
    eval_input_fn = hparams.problem.make_estimator_input_fn(
        tf.estimator.ModeKeys.EVAL, hparams, dataset_kwargs=dataset_kwargs)
    config = t2t_trainer.create_run_config(hparams)

    # summary-hook in tf.estimator.EstimatorSpec requires
    # hparams.model_dir to be set.
    hparams.add_hparam("model_dir", config.model_dir)

    estimator = trainer_lib.create_estimator(FLAGS.model,
                                             hparams,
                                             config,
                                             use_tpu=FLAGS.use_tpu)
    ckpt_iter = trainer_lib.next_checkpoint(hparams.model_dir,
                                            FLAGS.eval_timeout_mins)
    for ckpt_path in ckpt_iter:
        predictions = estimator.evaluate(eval_input_fn,
                                         steps=FLAGS.eval_steps,
                                         checkpoint_path=ckpt_path)
        tf.logging.info(predictions)
Code Example #3
File: export.py Project: qixiuai/tensor2tensor
def create_estimator(run_config, hparams):
  return trainer_lib.create_estimator(
      FLAGS.model,
      hparams,
      run_config,
      decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams),
      use_tpu=FLAGS.use_tpu)
Code Example #4
    def _init_env(self):
        FLAGS.use_tpu = False
        tf.logging.set_verbosity(tf.logging.DEBUG)
        tf.logging.info("Import usr dir from %s", self._usr_dir)
        if self._usr_dir is not None:
            usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
        tf.logging.info("Start to create hparams for %s of %s", self._problem,
                        self._hparams_set)

        self._hparams = create_hparams()
        self._hparams_decode = create_decode_hparams(
            extra_length=self._extra_length,
            batch_size=self._batch_size,
            beam_size=self._beam_size,
            alpha=self._alpha,
            return_beams=self._return_beams,
            write_beam_scores=self._write_beam_scores)

        self.estimator = trainer_lib.create_estimator(
            FLAGS.model,
            self._hparams,
            t2t_trainer.create_run_config(self._hparams),
            decode_hparams=self._hparams_decode,
            use_tpu=False)

        tf.logging.info("Finish intialize environment")
Code Example #5
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

    if FLAGS.score_file:
        filename = os.path.expanduser(FLAGS.score_file)
        if not tf.gfile.Exists(filename):
            raise ValueError("The file to score doesn't exist: %s" % filename)
        results = score_file(filename)
        # if not FLAGS.decode_to_file:
        #     raise ValueError("To score a file, specify --decode_to_file for results.")
        # write_file = tf.gfile.Open(os.path.expanduser(FLAGS.decode_to_file), "w")
        # for sentence, score in results:
        #     write_file.write(sentence + "\t" + "SCORE:" + "%.6f\n" % score)
        # write_file.close()
        return

    hp = create_hparams()
    decode_hp = create_decode_hparams()
    run_config = t2t_trainer.create_run_config(hp)
    if FLAGS.disable_grappler_optimizations:
        run_config.session_config.graph_options.rewrite_options.disable_meta_optimizer = True

    # summary-hook in tf.estimator.EstimatorSpec requires
    # hparams.model_dir to be set.
    hp.add_hparam("model_dir", run_config.model_dir)

    estimator = trainer_lib.create_estimator(FLAGS.model,
                                             hp,
                                             run_config,
                                             decode_hparams=decode_hp,
                                             use_tpu=FLAGS.use_tpu)

    decode(estimator, hp, decode_hp)
Code Example #6
File: dev_utils.py Project: projectclarify/clarify
  def export(self, hparams):
    """Run model export for serving."""

    hparams.no_data_parallelism = True

    problem = hparams.problem

    run_config = trainer_lib.create_run_config(model_name=self.model_name,
                                               model_dir=self.model_dir,
                                               num_gpus=0,
                                               use_tpu=False)

    estimator = trainer_lib.create_estimator(
        self.model_name,
        hparams,
        run_config,
        decode_hparams=decoding.decode_hparams(self.decode_hparams))

    exporter = tf.estimator.FinalExporter(
        "exporter", lambda: problem.serving_input_fn(hparams), as_text=True)

    exporter.export(estimator,
                    self.export_dir,
                    checkpoint_path=tf.train.latest_checkpoint(self.model_dir),
                    eval_result=None,
                    is_the_final_export=True)
Code Example #7
  def __prepare_model(self, train_mode=False):
    """Prepare utilities for decoding."""
    hparams = registry.hparams(self.params.hparams_set)
    hparams.problem = self.problem
    hparams.problem_hparams = self.problem.get_hparams(hparams)
    if self.params.hparams:
      tf.logging.info("Overriding hparams in %s with %s",
                      self.params.hparams_set,
                      self.params.hparams)
      hparams = hparams.parse(self.params.hparams)
    trainer_run_config = g2p_trainer_utils.create_run_config(hparams,
        self.params)
    if train_mode:
      exp_fn = g2p_trainer_utils.create_experiment_fn(self.params, self.problem)
      self.exp = exp_fn(trainer_run_config, hparams)

    decode_hp = decoding.decode_hparams(self.params.decode_hparams)
    estimator = trainer_lib.create_estimator(
        self.params.model_name,
        hparams,
        trainer_run_config,
        decode_hparams=decode_hp,
        use_tpu=False)

    return estimator, decode_hp, hparams
Code Example #8
def create_estimator_fn(model_name,
                        hparams,
                        run_config,
                        schedule="train_and_evaluate",
                        decode_hparams=None
                        ):
    return trainer_lib.create_estimator(model_name, hparams, run_config, schedule, decode_hparams, False)
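
Here schedule, decode_hparams, and use_tpu are passed positionally. Against tensor2tensor 1.x's signature, create_estimator(model_name, hparams, run_config, schedule="train_and_evaluate", decode_hparams=None, use_tpu=False), an equivalent keyword-style sketch is easier to audit:

from tensor2tensor.utils import trainer_lib

def create_estimator_fn(model_name, hparams, run_config,
                        schedule="train_and_evaluate", decode_hparams=None):
    # Keyword arguments make the mapping onto trainer_lib explicit.
    return trainer_lib.create_estimator(model_name, hparams, run_config,
                                        schedule=schedule,
                                        decode_hparams=decode_hparams,
                                        use_tpu=False)
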
Code Example #9
def t2t_decoder(problem_name, data_dir, decode_from_file, decode_to_file,
                checkpoint_path):
    trainer_lib.set_random_seed(FLAGS.random_seed)

    hp = trainer_lib.create_hparams(FLAGS.hparams_set,
                                    FLAGS.hparams,
                                    data_dir=os.path.expanduser(data_dir),
                                    problem_name=problem_name)

    decode_hp = decoding.decode_hparams(FLAGS.decode_hparams)
    decode_hp.shards = FLAGS.decode_shards
    decode_hp.shard_id = FLAGS.worker_id
    decode_in_memory = FLAGS.decode_in_memory or decode_hp.decode_in_memory
    decode_hp.decode_in_memory = decode_in_memory
    decode_hp.decode_to_file = decode_to_file
    decode_hp.decode_reference = None

    FLAGS.checkpoint_path = checkpoint_path
    estimator = trainer_lib.create_estimator(FLAGS.model,
                                             hp,
                                             t2t_trainer.create_run_config(hp),
                                             decode_hparams=decode_hp,
                                             use_tpu=FLAGS.use_tpu)

    decode_from_text_file(estimator,
                          problem_name,
                          decode_from_file,
                          hp,
                          decode_hp,
                          decode_to_file,
                          checkpoint_path=checkpoint_path)
Code Example #10
File: t2t_decoder.py Project: qixiuai/tensor2tensor
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)


  if FLAGS.score_file:
    filename = os.path.expanduser(FLAGS.score_file)
    if not tf.gfile.Exists(filename):
      raise ValueError("The file to score doesn't exist: %s" % filename)
    results = score_file(filename)
    if not FLAGS.decode_to_file:
      raise ValueError("To score a file, specify --decode_to_file for results.")
    write_file = tf.gfile.Open(os.path.expanduser(FLAGS.decode_to_file), "w")
    for score in results:
      write_file.write("%.6f\n" % score)
    write_file.close()
    return

  hp = create_hparams()
  decode_hp = create_decode_hparams()

  estimator = trainer_lib.create_estimator(
      FLAGS.model,
      hp,
      t2t_trainer.create_run_config(hp),
      decode_hparams=decode_hp,
      use_tpu=FLAGS.use_tpu)

  decode(estimator, hp, decode_hp)
Code Example #11
File: g2p.py Project: cmusphinx/g2p-seq2seq
  def __prepare_model(self, train_mode=False):
    """Prepare utilities for decoding."""
    hparams = registry.hparams(self.params.hparams_set)
    hparams.problem = self.problem
    hparams.problem_hparams = self.problem.get_hparams(hparams)
    if self.params.hparams:
      tf.logging.info("Overriding hparams in %s with %s",
                      self.params.hparams_set,
                      self.params.hparams)
      hparams = hparams.parse(self.params.hparams)
    trainer_run_config = g2p_trainer_utils.create_run_config(hparams,
        self.params)
    if train_mode:
      exp_fn = g2p_trainer_utils.create_experiment_fn(self.params, self.problem)
      self.exp = exp_fn(trainer_run_config, hparams)

    decode_hp = decoding.decode_hparams(self.params.decode_hparams)
    estimator = trainer_lib.create_estimator(
        self.params.model_name,
        hparams,
        trainer_run_config,
        decode_hparams=decode_hp,
        use_tpu=False)

    return estimator, decode_hp, hparams
Code Example #12
File: transformer.py Project: kwecht/cortex
def create_estimator(run_config, model_config):
    hparams = trainer_lib.create_hparams("transformer_base_single_gpu")

    # SentimentIMDBCortex subclasses SentimentIMDB
    problem = SentimentIMDBCortex(list(model_config["input"]["vocab"]))
    hparams.problem = problem
    hparams.problem_hparams = problem.get_hparams(hparams)

    # metrics specific to the sentiment problem
    problem.eval_metrics = lambda: [
        metrics.Metrics.ACC_TOP5,
        metrics.Metrics.ACC_PER_SEQ,
        metrics.Metrics.NEG_LOG_PERPLEXITY,
    ]

    # reduce memory load
    hparams.num_hidden_layers = 2
    hparams.hidden_size = 32
    hparams.filter_size = 32
    hparams.num_heads = 2

    # t2t expects these keys
    hparams.warm_start_from = None
    run_config.data_parallelism = None
    run_config.t2t_device_info = {"num_async_replicas": 1}

    return trainer_lib.create_estimator("transformer", hparams, run_config)
Code Example #13
def create_estimator(run_config, hparams):
    return trainer_lib.create_estimator(FLAGS.model,
                                        hparams,
                                        run_config,
                                        decode_hparams=decoding.decode_hparams(
                                            FLAGS.decode_hparams),
                                        use_tpu=FLAGS.use_tpu)
Code Example #14
File: t2t_decoder.py Project: anonymusNLP/EMNLP2019
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

    if FLAGS.score_file:
        filename = os.path.expanduser(FLAGS.score_file)
        if not tf.gfile.Exists(filename):
            raise ValueError("The file to score doesn't exist: %s" % filename)
        results = score_file(filename)
        if not FLAGS.decode_to_file:
            raise ValueError(
                "To score a file, specify --decode_to_file for results.")
        write_file = tf.gfile.Open(os.path.expanduser(FLAGS.decode_to_file),
                                   "w")
        for score in results:
            write_file.write("%.6f\n" % score)
        write_file.close()
        return

    hp = create_hparams()
    decode_hp = create_decode_hparams()

    estimator = trainer_lib.create_estimator(FLAGS.model,
                                             hp,
                                             t2t_trainer.create_run_config(hp),
                                             decode_hparams=decode_hp,
                                             use_tpu=FLAGS.use_tpu)

    decode(estimator, hp, decode_hp)
Code Example #15
File: t2t_transformer.py Project: ancardona/cortex
def create_estimator(run_config, model_config):
    # t2t expects these keys in run_config
    run_config.data_parallelism = None
    run_config.t2t_device_info = {"num_async_replicas": 1}

    hparams = trainer_lib.create_hparams("transformer_base_single_gpu")

    problem = SentimentIMDBCortex(
        list(model_config["aggregates"]["reviews_vocab"]))
    p_hparams = problem.get_hparams(hparams)
    hparams.problem = problem
    hparams.problem_hparams = p_hparams

    problem.eval_metrics = lambda: [
        metrics.Metrics.ACC_TOP5,
        metrics.Metrics.ACC_PER_SEQ,
        metrics.Metrics.NEG_LOG_PERPLEXITY,
    ]

    # t2t expects this key
    hparams.warm_start_from = None

    # reduce memory load
    hparams.num_hidden_layers = 2
    hparams.hidden_size = 32
    hparams.filter_size = 32
    hparams.num_heads = 2

    estimator = trainer_lib.create_estimator("transformer", hparams,
                                             run_config)
    return estimator
Code Example #16
def create_new_estimator(hp, decode_hp):
    estimator = trainer_lib.create_estimator(
        FLAGS.model,
        hp,
        t2t_trainer.create_run_config(hp),
        decode_hparams=decode_hp,
        use_tpu=FLAGS.use_tpu)
    return estimator
Code Example #17
File: generate.py Project: Hadryan/Moodzik
def run():
    """
    Load Transformer model according to flags and start sampling.
    :raises:
        ValueError: if required flags are missing or invalid.
    """
    if FLAGS.model_path is None:
        raise ValueError('Required Transformer pre-trained model path.')

    if FLAGS.output_dir is None:
        raise ValueError('Required Midi output directory.')

    if FLAGS.decode_length <= 0:
        raise ValueError('Decode length must be > 0.')

    problem = utils.PianoPerformanceLanguageModelProblem()
    unconditional_encoders = problem.get_feature_encoders()
    primer_ns = music_pb2.NoteSequence()
    if FLAGS.primer_path is None:
        targets = []
    else:
        if FLAGS.max_primer_second <= 0:
            raise ValueError('Max primer second must be > 0.')

        primer_ns = utils.get_primer_ns(FLAGS.primer_path,
                                        FLAGS.max_primer_second)
        targets = unconditional_encoders['targets'].encode_note_sequence(
            primer_ns)

        # Remove the end token from the encoded primer.
        targets = targets[:-1]
        if len(targets) >= FLAGS.decode_length:
            raise ValueError(
                'Primer has more or equal events than maximum sequence length:'
                ' %d >= %d; Aborting' % (len(targets), FLAGS.decode_length))
    decode_length = FLAGS.decode_length - len(targets)

    # Set up HParams.
    hparams = trainer_lib.create_hparams(hparams_set=FLAGS.hparams_set)
    trainer_lib.add_problem_hparams(hparams, problem)
    hparams.num_hidden_layers = FLAGS.layers
    hparams.sampling_method = FLAGS.sample

    # Set up decoding HParams.
    decode_hparams = decoding.decode_hparams()
    decode_hparams.alpha = FLAGS.alpha
    decode_hparams.beam_size = FLAGS.beam_size

    # Create Estimator.
    utils.LOGGER.info('Loading model')
    run_config = trainer_lib.create_run_config(hparams)
    estimator = trainer_lib.create_estimator(FLAGS.model_name,
                                             hparams,
                                             run_config,
                                             decode_hparams=decode_hparams)

    generate(estimator, unconditional_encoders, decode_length, targets,
             primer_ns)
Code Example #18
File: export.py Project: changlan/tensor2tensor
def create_estimator(run_config, hparams):
    return trainer_lib.create_estimator(
        FLAGS.model,
        hparams,
        run_config,
        decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams),
        use_tpu=FLAGS.use_tpu,
        export_saved_model_api_version=FLAGS.export_saved_model_api_version,
        use_guarantee_const_getter=FLAGS.use_guarantee_const_getter)
Code Example #19
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    # sess_dir = FLAGS.sess_dir
    # output_dir = os.path.expanduser(sess_dir+problem_name+'-'+model+'-'+hparams)
    output_dir = FLAGS.output_dir

    if FLAGS.score_file:
        filename = os.path.expanduser(FLAGS.score_file)
        if not tf.gfile.Exists(filename):
            raise ValueError("The file to score doesn't exist: %s" % filename)
        results = score_file(filename)
        if not FLAGS.decode_to_file:
            raise ValueError(
                "To score a file, specify --decode_to_file for results.")
        write_file = tf.gfile.Open(os.path.expanduser(FLAGS.decode_to_file),
                                   "w")
        for score in results:
            write_file.write("%.6f\n" % score)
        write_file.close()
        return

    hp = create_hparams()

    if FLAGS.global_steps:
        FLAGS.checkpoint_path = os.path.join(
            FLAGS.model_dir, f"model.ckpt-{FLAGS.global_steps}")
    else:
        FLAGS.checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)

    # Check if already exists
    dataset_split = "test" if FLAGS.split == "test" else "dev"
    decode_path = os.path.join(FLAGS.model_dir,
                               "decode_00000")  # default decoded_to_file
    decode_path = FLAGS.decode_to_file if FLAGS.decode_to_file else decode_path
    if os.path.isdir(decode_path):
        files = os.listdir(decode_path)
        for file in files:
            file_name = file.split(".")[0]
            file_name_to_be = f"{FLAGS.global_steps}{dataset_split}{FLAGS.test_shard:03d}"
            if file_name == file_name_to_be:
                print(f"Already {file_name_to_be} exists")
                return

    tf.reset_default_graph()
    decode_hp = create_decode_hparams(decode_path, FLAGS.test_shard)
    estimator = trainer_lib.create_estimator(FLAGS.model,
                                             hp,
                                             create_run_config(hp),
                                             decode_hparams=decode_hp,
                                             use_tpu=FLAGS.use_tpu)
    decode(estimator, hp, decode_hp)
    print("shard " + str(FLAGS.test_shard) + " completed")
Code Example #20
  def testCompatibility(self):
    model = "transformer"
    hp_set = "transformer_test"
    problem_name = "translate_ende_wmt8k"

    hp = trainer_lib.create_hparams(
        hp_set, data_dir=_DATA_DIR, problem_name=problem_name)
    run_config = trainer_lib.create_run_config(model, model_dir=_CKPT_DIR)
    estimator = trainer_lib.create_estimator(model, hp, run_config)

    for prediction in estimator.predict(self.input_fn):
      self.assertEqual(prediction["outputs"].dtype, np.int32)
Code Example #21
  def testCompatibility(self):
    model = "transformer"
    hp_set = "transformer_test"
    problem_name = "translate_ende_wmt8k"

    hp = trainer_lib.create_hparams(
        hp_set, data_dir=_DATA_DIR, problem_name=problem_name)
    run_config = trainer_lib.create_run_config(model_dir=_CKPT_DIR)
    estimator = trainer_lib.create_estimator(model, hp, run_config)

    for prediction in estimator.predict(self.input_fn):
      self.assertEqual(prediction["outputs"].dtype, np.int32)
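
Code Examples #20 and #21 differ only in how run_config is built, which appears to track a signature change across tensor2tensor releases: older versions took create_run_config(master="", model_dir=None, ...), while newer ones take the model name as the first argument. A hedged compatibility sketch (the introspection trick is illustrative, not t2t API):

import inspect
from tensor2tensor.utils import trainer_lib

def make_run_config(model, model_dir):
    params = inspect.signature(trainer_lib.create_run_config).parameters
    # Pass the model name only when the installed t2t expects it first.
    if "model_name" in params:
        return trainer_lib.create_run_config(model, model_dir=model_dir)
    return trainer_lib.create_run_config(model_dir=model_dir)
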
Code Example #22
def main(_):
    FLAGS.decode_interactive = True
    hp = create_hparams()
    decode_hp = create_decode_hparams()

    estimator = trainer_lib.create_estimator(FLAGS.model,
                                             hp,
                                             t2t_trainer.create_run_config(hp),
                                             decode_hparams=decode_hp,
                                             use_tpu=False)

    decode(estimator, hp, decode_hp)
Code Example #23
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)


  if FLAGS.score_file:
    filename = os.path.expanduser(FLAGS.score_file)
    if not tf.gfile.Exists(filename):
      raise ValueError("The file to score doesn't exist: %s" % filename)
    results = score_file(filename)
    if not FLAGS.decode_to_file:
      raise ValueError("To score a file, specify --decode_to_file for results.")
    write_file = tf.gfile.Open(os.path.expanduser(FLAGS.decode_to_file), "w")
    for score in results:
      write_file.write("%.6f\n" % score)
    write_file.close()
    return

  hp = create_hparams()
  decode_hp = create_decode_hparams()

  # eval_input_fn = hp.problem.make_estimator_input_fn(
  #   tf.estimator.ModeKeys.TRAIN, hp, dataset_kwargs={"dataset_split": "eval"})


  # print(eval_input_fn)
  # for foo in eval_input_fn(None, None):
  #   print(type(foo[0]['targets']))
  #   print(foo[0]['targets'].numpy())
  # exit()
  
  run_config = t2t_trainer.create_run_config(hp)
  if FLAGS.disable_grappler_optimizations:
    run_config.session_config.graph_options.rewrite_options.disable_meta_optimizer = True

  # summary-hook in tf.estimator.EstimatorSpec requires
  # hparams.model_dir to be set.
  hp.add_hparam("model_dir", run_config.model_dir)

  estimator = trainer_lib.create_estimator(
      FLAGS.model,
      hp,
      run_config,
      decode_hparams=decode_hp,
      use_tpu=FLAGS.use_tpu)

  decode(estimator, hp, decode_hp)
Code Example #24
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

    if FLAGS.score_file:
        filename = os.path.expanduser(FLAGS.score_file)
        if not tf.gfile.Exists(filename):
            raise ValueError("The file to score doesn't exist: %s" % filename)
        results = score_file(filename)
        if not FLAGS.decode_to_file:
            raise ValueError(
                "To score a file, specify --decode_to_file for results.")
        write_file = tf.gfile.Open(os.path.expanduser(FLAGS.decode_to_file),
                                   "w")
        for score in results:
            write_file.write("%.6f\n" % score)
        write_file.close()
        return

    hp = create_hparams()
    decode_hp = create_decode_hparams()

    estimator = trainer_lib.create_estimator(FLAGS.model,
                                             hp,
                                             t2t_trainer.create_run_config(hp),
                                             decode_hparams=decode_hp,
                                             use_tpu=FLAGS.use_tpu)

    decode(estimator, hp, decode_hp)

    # Post-process decodings (if necessary).
    if FLAGS.decode_to_file and FLAGS.output_line_prefix_tag:
        decode_filename_original = FLAGS.decode_to_file
        decode_filename_prefixed = "%s-%s" % (decode_filename_original,
                                              FLAGS.output_line_prefix_tag)
        tf.logging.info("Writing prefexed decodes into %s" %
                        decode_filename_prefixed)
        # Read original lines.
        with tf.gfile.Open(decode_filename_original, "r") as original_fp:
            original_lines = original_fp.readlines()
        # Write prefixed lines.
        prefix = "<%s> " % FLAGS.output_line_prefix_tag
        prefixed_fp = tf.gfile.Open(decode_filename_prefixed, "w")
        for line in original_lines:
            prefixed_fp.write(prefix + line)
        prefixed_fp.flush()
        prefixed_fp.close()
        tf.logging.info("Done.")
Code Example #25
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

    hp = t2t_decoder.create_hparams()
    decode_hp = t2t_decoder.create_decode_hparams()

    estimator = trainer_lib.create_estimator(FLAGS.model,
                                             hp,
                                             t2t_trainer.create_run_config(hp),
                                             decode_hparams=decode_hp,
                                             use_tpu=FLAGS.use_tpu)

    decode(estimator, hp, decode_hp)
Code Example #26
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    FLAGS.use_tpu = False  # decoding not supported on TPU

    hp = create_hparams()
    decode_hp = create_decode_hparams()

    estimator = trainer_lib.create_estimator(FLAGS.model,
                                             hp,
                                             t2t_trainer.create_run_config(hp),
                                             decode_hparams=decode_hp,
                                             use_tpu=False)

    decode(estimator, hp, decode_hp)
Code Example #27
File: t2t_decoder.py Project: chqiwang/tensor2tensor
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  FLAGS.use_tpu = False  # decoding not supported on TPU

  hp = create_hparams()
  decode_hp = create_decode_hparams()

  estimator = trainer_lib.create_estimator(
      FLAGS.model,
      hp,
      t2t_trainer.create_run_config(hp),
      decode_hparams=decode_hp,
      use_tpu=False)

  decode(estimator, hp, decode_hp)
Code Example #28
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    trainer_lib.set_random_seed(FLAGS.random_seed)
    # Fathom start
    checkpoint_path = fathom_t2t_model_setup()
    # Fathom end
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

    if FLAGS.score_file:
        filename = os.path.expanduser(FLAGS.score_file)
        if not tf.gfile.Exists(filename):
            raise ValueError("The file to score doesn't exist: %s" % filename)
        results = score_file(filename)
        if not FLAGS.decode_to_file:
            raise ValueError(
                "To score a file, specify --decode_to_file for results.")
        write_file = tf.gfile.Open(os.path.expanduser(FLAGS.decode_to_file),
                                   "w")
        for score in results:
            write_file.write("%.6f\n" % score)
        write_file.close()
        return

    hp = create_hparams()
    decode_hp = create_decode_hparams()

    estimator = trainer_lib.create_estimator(FLAGS.model,
                                             hp,
                                             t2t_trainer.create_run_config(hp),
                                             decode_hparams=decode_hp,
                                             use_tpu=FLAGS.use_tpu)

    decode(estimator, hp, decode_hp)

    # Fathom
    # This xcom is here so that tasks after decode know the local path to the
    # downloaded model. Train does this same xcom echo.
    # Decode, predict, and evaluate code should
    # converge to use the same fathom_t2t_model_setup.
    # TODO: since the truncation-boundary xcom value should be available in
    #  the hparams_set, we should probably have consumers access this via a
    #  SavedModel.hparams property rather than XCOM
    echo_yaml_for_xcom_ingest({
        'output-dir': os.path.dirname(checkpoint_path),
        'output-file': FLAGS.decode_output_file,
        'truncation-boundary': hp.max_input_seq_length
    })
Code Example #29
    def __init__(self, processor_configuration):
        """Creates the Transformer estimator.

        Args:
          processor_configuration: A ProcessorConfiguration protobuffer with the
            transformer fields populated.
        """
        # Do the pre-setup tensor2tensor requires for flags and configurations.
        transformer_config = processor_configuration["transformer"]
        FLAGS.output_dir = transformer_config["model_dir"]
        usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
        data_dir = os.path.expanduser(transformer_config["data_dir"])

        # Create the basic hyper parameters.
        self.hparams = trainer_lib.create_hparams(
            transformer_config["hparams_set"],
            transformer_config["hparams"],
            data_dir=data_dir,
            problem_name=transformer_config["problem"])

        decode_hp = decoding.decode_hparams()
        decode_hp.add_hparam("shards", 1)
        decode_hp.add_hparam("shard_id", 0)

        # Create the estimator and final hyper parameters.
        self.estimator = trainer_lib.create_estimator(
            transformer_config["model"],
            self.hparams,
            t2t_trainer.create_run_config(self.hparams),
            decode_hparams=decode_hp,
            use_tpu=False)

        # Fetch the vocabulary and other helpful variables for decoding.
        self.source_vocab = self.hparams.problem_hparams.vocabulary["inputs"]
        self.targets_vocab = self.hparams.problem_hparams.vocabulary["targets"]
        self.const_array_size = 10000

        # Prepare the Transformer's debug data directory.
        run_dirs = sorted(
            glob.glob(os.path.join("/tmp/t2t_server_dump", "run_*")))
        for run_dir in run_dirs:
            shutil.rmtree(run_dir)
Code Example #30
File: t2t.py Project: zyxyaoqi/cortex
def create_estimator(run_config, model_config):
    # t2t expects these keys in run_config
    run_config.data_parallelism = None
    run_config.t2t_device_info = {"num_async_replicas": 1}

    # t2t has its own set of hyperparameters we can use
    hparams = trainer_lib.create_hparams("basic_fc_small")
    problem = registry.problem("image_mnist")
    p_hparams = problem.get_hparams(hparams)
    hparams.problem = problem
    hparams.problem_hparams = p_hparams

    # don't need eval_metrics
    problem.eval_metrics = lambda: []

    # t2t expects this key
    hparams.warm_start_from = None

    estimator = trainer_lib.create_estimator("basic_fc_relu", hparams, run_config)
    return estimator
Code Example #31
File: g2p.py Project: jupinter/g2p-seq2seq
  def __prepare_model(self):
    """Prepare utilities for decoding."""
    hparams = trainer_lib.create_hparams(
        hparams_set=self.params.hparams_set,
        hparams_overrides_str=self.params.hparams)
    trainer_run_config = g2p_trainer_utils.create_run_config(hparams,
        self.params)
    exp_fn = g2p_trainer_utils.create_experiment_fn(self.params, self.problem)
    self.exp = exp_fn(trainer_run_config, hparams)

    decode_hp = decoding.decode_hparams(self.params.decode_hparams)
    decode_hp.add_hparam("shards", self.params.decode_shards)
    decode_hp.add_hparam("shard_id", self.params.worker_id)
    estimator = trainer_lib.create_estimator(
        self.params.model_name,
        hparams,
        trainer_run_config,
        decode_hparams=decode_hp,
        use_tpu=False)

    return estimator, decode_hp, hparams
Code Example #32
  def __init__(self, processor_configuration):
    """Creates the Transformer estimator.

    Args:
      processor_configuration: A ProcessorConfiguration protobuffer with the
        transformer fields populated.
    """
    # Do the pre-setup tensor2tensor requires for flags and configurations.
    transformer_config = processor_configuration["transformer"]
    FLAGS.output_dir = transformer_config["model_dir"]
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    data_dir = os.path.expanduser(transformer_config["data_dir"])

    # Create the basic hyper parameters.
    self.hparams = trainer_lib.create_hparams(
        transformer_config["hparams_set"],
        transformer_config["hparams"],
        data_dir=data_dir,
        problem_name=transformer_config["problem"])

    decode_hp = decoding.decode_hparams()
    decode_hp.add_hparam("shards", 1)
    decode_hp.add_hparam("shard_id", 0)

    # Create the estimator and final hyper parameters.
    self.estimator = trainer_lib.create_estimator(
        transformer_config["model"],
        self.hparams,
        t2t_trainer.create_run_config(self.hparams),
        decode_hparams=decode_hp, use_tpu=False)

    # Fetch the vocabulary and other helpful variables for decoding.
    self.source_vocab = self.hparams.problem_hparams.vocabulary["inputs"]
    self.targets_vocab = self.hparams.problem_hparams.vocabulary["targets"]
    self.const_array_size = 10000

    # Prepare the Transformer's debug data directory.
    run_dirs = sorted(glob.glob(os.path.join("/tmp/t2t_server_dump", "run_*")))
    for run_dir in run_dirs:
      shutil.rmtree(run_dir)
Code Example #33
    def initialize(self, is_conditioned=False):
        self.model_name = 'transformer'
        self.hparams_set = 'transformer_tpu'
        self.conditioned = is_conditioned
        if self.conditioned:
            self.ckpt_path = 'models/checkpoints/melody_conditioned_model_16.ckpt'
            problem = MelodyToPianoPerformanceProblem()
        else:
            self.ckpt_path = 'models/checkpoints/unconditional_model_16.ckpt'
            problem = PianoPerformanceLanguageModelProblem()

        self.encoders = problem.get_feature_encoders()

        # Set up hyperparams
        hparams = trainer_lib.create_hparams(hparams_set=self.hparams_set)
        trainer_lib.add_problem_hparams(hparams, problem)
        hparams.num_hidden_layers = 16
        hparams.sampling_method = 'random'

        # Set up decoding hyperparams
        decode_hparams = decoding.decode_hparams()
        decode_hparams.alpha = 0.0
        decode_hparams.beam_size = 1
        if self.conditioned:
            self.inputs = []
        else:
            self.targets = []

        self.decode_length = 0
        run_config = trainer_lib.create_run_config(hparams)
        estimator = trainer_lib.create_estimator(
            self.model_name, hparams, run_config,
            decode_hparams=decode_hparams)
        fnc = self.input_generation_conditional if self.conditioned else self.input_generator_unconditional
        input_fn = decoding.make_input_fn_from_generator(fnc())
        self.samples = estimator.predict(
            input_fn, checkpoint_path=self.ckpt_path)
        _ = next(self.samples)
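
The trailing _ = next(self.samples) above is a warm-up: estimator.predict returns a lazy generator, and the first next() pays the one-time cost of graph construction and checkpoint restore. A minimal sketch of the same pattern, with hypothetical feature values mirroring the score2perf layout:

import numpy as np
from tensor2tensor.utils import decoding

def warmed_up_samples(estimator, ckpt_path):
    def gen():
        while True:
            # Empty primer and a fixed decode length, purely illustrative.
            yield {"targets": np.zeros([1, 0], dtype=np.int32),
                   "decode_length": np.array(64, dtype=np.int32)}
    input_fn = decoding.make_input_fn_from_generator(gen())
    samples = estimator.predict(input_fn, checkpoint_path=ckpt_path)
    next(samples)  # trigger graph build + restore before real requests
    return samples
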
Code Example #34
File: decoding.py Project: khuongav/back_translate
def create_hp_and_estimator(problem_name, data_dir, checkpoint_path):
    tf.logging.set_verbosity(tf.logging.INFO)
    trainer_lib.set_random_seed(FLAGS.random_seed)

    hp = trainer_lib.create_hparams(FLAGS.hparams_set,
                                    FLAGS.hparams,
                                    data_dir=os.path.expanduser(data_dir),
                                    problem_name=problem_name)

    decode_hp = decoding.decode_hparams(FLAGS.decode_hparams)
    decode_hp.shards = FLAGS.decode_shards
    decode_hp.shard_id = FLAGS.worker_id
    decode_in_memory = FLAGS.decode_in_memory or decode_hp.decode_in_memory
    decode_hp.decode_in_memory = decode_in_memory
    decode_hp.decode_to_file = None
    decode_hp.decode_reference = None

    FLAGS.checkpoint_path = checkpoint_path
    estimator = trainer_lib.create_estimator(FLAGS.model,
                                             hp,
                                             t2t_trainer.create_run_config(hp),
                                             decode_hparams=decode_hp,
                                             use_tpu=FLAGS.use_tpu)
    return hp, decode_hp, estimator
Code Example #35
File: musicGen.py Project: Dkpalea/transformer-bard
def music_generator(primer='erik_gnossienne',
                    primer_begin_buffer=10,
                    primer_length=90,
                    output_path='.',
                    filename='./public/output'):
    SF2_PATH = './models/Yamaha-C5-Salamander-JNv5.1.sf2'
    SAMPLE_RATE = 16000

    # Upload a MIDI file and convert to NoteSequence.
    def upload_midi():
        data = list(files.upload().values())
        if len(data) > 1:
            print('Multiple files uploaded; using only one.')
        return mm.midi_to_note_sequence(data[0])

    # Decode a list of IDs.
    def decode(ids, encoder):
        ids = list(ids)
        if text_encoder.EOS_ID in ids:
            ids = ids[:ids.index(text_encoder.EOS_ID)]
        return encoder.decode(ids)

    model_name = 'transformer'
    hparams_set = 'transformer_tpu'
    ckpt_path = './models/checkpoints/unconditional_model_16.ckpt'

    class PianoPerformanceLanguageModelProblem(score2perf.Score2PerfProblem):
        @property
        def add_eos_symbol(self):
            return True

    problem = PianoPerformanceLanguageModelProblem()
    unconditional_encoders = problem.get_feature_encoders()

    # Set up HParams.
    hparams = trainer_lib.create_hparams(hparams_set=hparams_set)
    trainer_lib.add_problem_hparams(hparams, problem)
    hparams.num_hidden_layers = 16
    hparams.sampling_method = 'random'

    # Set up decoding HParams.
    decode_hparams = decoding.decode_hparams()
    decode_hparams.alpha = 0.0
    decode_hparams.beam_size = 1

    # Create Estimator.
    run_config = trainer_lib.create_run_config(hparams)
    estimator = trainer_lib.create_estimator(model_name,
                                             hparams,
                                             run_config,
                                             decode_hparams=decode_hparams)

    # These values will be changed by subsequent cells.
    targets = []
    decode_length = 0

    # Create input generator (so we can adjust priming and
    # decode length on the fly).
    def input_generator():
        # 'targets' and 'decode_length' live in the enclosing function, so
        # 'global' would raise NameError here; 'nonlocal' reads the right scope.
        nonlocal targets, decode_length
        while True:
            yield {
                'targets': np.array([targets], dtype=np.int32),
                'decode_length': np.array(decode_length, dtype=np.int32)
            }

    # Start the Estimator, loading from the specified checkpoint.
    input_fn = decoding.make_input_fn_from_generator(input_generator())
    unconditional_samples = estimator.predict(input_fn,
                                              checkpoint_path=ckpt_path)

    # "Burn" one.
    _ = next(unconditional_samples)

    filenames = {
        'C major arpeggio': './models/primers/c_major_arpeggio.mid',
        'C major scale': './models/primers/c_major_scale.mid',
        'Clair de Lune': './models/primers/clair_de_lune.mid',
        'Classical':
        'audio_midi/Classical_Piano_piano-midi.de_MIDIRip/bach/bach_846_format0.mid',
        'erik_gymnopedie': 'audio_midi/erik_satie/gymnopedie_1_(c)oguri.mid',
        'erik_gymnopedie_2': 'audio_midi/erik_satie/gymnopedie_2_(c)oguri.mid',
        'erik_gymnopedie_3': 'audio_midi/erik_satie/gymnopedie_3_(c)oguri.mid',
        'erik_gnossienne': 'audio_midi/erik_satie/gnossienne_1_(c)oguri.mid',
        'erik_gnossienne_2': 'audio_midi/erik_satie/gnossienne_2_(c)oguri.mid',
        'erik_gnossienne_3': 'audio_midi/erik_satie/gnossienne_3_(c)oguri.mid',
        'erik_gnossienne_dery':
        'audio_midi/erik_satie/gnossienne_1_(c)dery.mid',
        'erik_gnossienne_dery_2':
        'audio_midi/erik_satie/gnossienne_2_(c)dery.mid',
        'erik_gnossienne_dery_3':
        'audio_midi/erik_satie/gnossienne_3_(c)dery.mid',
        'erik_gnossienne_dery_5':
        'audio_midi/erik_satie/gnossienne_5_(c)dery.mid',
        'erik_gnossienne_dery_6':
        'audio_midi/erik_satie/gnossienne_6_(c)dery.mid',
        '1': 'audio_midi/erik_satie/1.mid',
        '2': 'audio_midi/erik_satie/2.mid',
        '3': 'audio_midi/erik_satie/3.mid',
        '4': 'audio_midi/erik_satie/4.mid',
        '5': 'audio_midi/erik_satie/5.mid',
        '6': 'audio_midi/erik_satie/6.mid',
        '7': 'audio_midi/erik_satie/7.mid',
        '8': 'audio_midi/erik_satie/8.mid',
        '9': 'audio_midi/erik_satie/9.mid',
        '10': 'audio_midi/erik_satie/10.mid',
    }
    # primer = 'C major scale'

    #if primer == 'Upload your own!':
    #  primer_ns = upload_midi()
    #else:
    #  # Use one of the provided primers.
    #  primer_ns = mm.midi_file_to_note_sequence(filenames[primer])
    primer_ns = mm.midi_file_to_note_sequence(filenames[primer])
    # Handle sustain pedal in the primer.
    primer_ns = mm.apply_sustain_control_changes(primer_ns)

    # Trim to desired number of seconds.
    max_primer_seconds = primer_length
    if primer_ns.total_time > max_primer_seconds:
        print('Primer is longer than %d seconds, truncating.' %
              max_primer_seconds)
        primer_ns = mm.extract_subsequence(
            primer_ns, primer_begin_buffer,
            max_primer_seconds + primer_begin_buffer)

    # Remove drums from primer if present.
    if any(note.is_drum for note in primer_ns.notes):
        print('Primer contains drums; they will be removed.')
        notes = [note for note in primer_ns.notes if not note.is_drum]
        del primer_ns.notes[:]
        primer_ns.notes.extend(notes)

    # Set primer instrument and program.
    for note in primer_ns.notes:
        note.instrument = 1
        note.program = 0

    ## Play and plot the primer.
    #mm.play_sequence(
    #    primer_ns,
    #    synth=mm.fluidsynth, sample_rate=SAMPLE_RATE, sf2_path=SF2_PATH)
    #mm.plot_sequence(primer_ns)
    mm.sequence_proto_to_midi_file(
        primer_ns, join(output_path, 'primer_{}.mid'.format(filename)))

    targets = unconditional_encoders['targets'].encode_note_sequence(primer_ns)

    # Remove the end token from the encoded primer.
    targets = targets[:-1]

    decode_length = max(0, 10000 - len(targets))
    if len(targets) >= 4096:
        print(
            'Primer has more events than maximum sequence length; nothing will be generated.'
        )

    # Generate sample events.
    sample_ids = next(unconditional_samples)['outputs']

    # Decode to NoteSequence.
    midi_filename = decode(sample_ids,
                           encoder=unconditional_encoders['targets'])
    ns = mm.midi_file_to_note_sequence(midi_filename)
    print('Sample IDs: {}'.format(sample_ids))
    print('Sample IDs length: {}'.format(len(sample_ids)))
    print('Encoder: {}'.format(unconditional_encoders['targets']))
    print('Unconditional Samples: {}'.format(unconditional_samples))
    # print('{}'.format(ns))

    # continuation_ns = mm.concatenate_sequences([primer_ns, ns])
    continuation_ns = ns
    # mm.play_sequence(
    #     continuation_ns,
    #     synth=mm.fluidsynth, sample_rate=SAMPLE_RATE, sf2_path=SF2_PATH)
    # mm.plot_sequence(continuation_ns)
    # try:
    audio = mm.fluidsynth(continuation_ns,
                          sample_rate=SAMPLE_RATE,
                          sf2_path=SF2_PATH)

    normalizer = float(np.iinfo(np.int16).max)
    array_of_ints = np.array(np.asarray(audio) * normalizer, dtype=np.int16)

    wavfile.write(join(output_path, filename + '.wav'), SAMPLE_RATE,
                  array_of_ints)
    print('[+] Output stored as {}'.format(filename + '.wav'))
    mm.sequence_proto_to_midi_file(
        continuation_ns,
        join(output_path, 'continuation_{}.mid'.format(filename)))
Code Example #36
    def _init_env(self):
        FLAGS.use_tpu = False
        tf.logging.set_verbosity(tf.logging.DEBUG)
        tf.logging.info("Import usr dir from %s", self._usr_dir)
        if self._usr_dir is not None:
            usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
        tf.logging.info("Start to create hparams for %s of %s", self._problem,
                        self._hparams_set)

        self._hparams = create_hparams()
        self._hparams_decode = create_decode_hparams(
            extra_length=self._extra_length,
            batch_size=self._batch_size,
            beam_size=self._beam_size,
            alpha=self._alpha,
            return_beams=self._return_beams,
            write_beam_scores=self._write_beam_scores,
            force_decode_length=self._force_decode_length)

        self.estimator = trainer_lib.create_estimator(
            FLAGS.model,
            self._hparams,
            t2t_trainer.create_run_config(self._hparams),
            decode_hparams=self._hparams_decode,
            use_tpu=False)

        tf.logging.info("Finish intialize environment")

        #######

        ### make input placeholder
        #self._inputs_ph = tf.placeholder(dtype=tf.int32)  # shape not specified,any shape

        # x=tf.placeholder(dtype=tf.int32)
        # x.set_shape([None, None]) # ? -> (?,?)
        # x = tf.expand_dims(x, axis=[2])# -> (?,?,1)
        # x = tf.to_int32(x)
        # self._inputs_ph=x

        #batch_inputs = tf.reshape(self._inputs_ph, [self._batch_size, -1, 1, 1])
        #batch_inputs=x
        ###

        # batch_inputs = tf.reshape(self._inputs_ph, [-1, -1, 1, 1])

        #targets_ph = tf.placeholder(dtype=tf.int32)
        #batch_targets = tf.reshape(targets_ph, [1, -1, 1, 1])

        self.inputs_ph = tf.placeholder(tf.int32,
                                        shape=(None, None, 1, 1),
                                        name='inputs')
        self.targets_ph = tf.placeholder(tf.int32,
                                         shape=(None, None, None, None),
                                         name='targets')
        self.input_extra_length_ph = tf.placeholder(dtype=tf.int32, shape=[])

        self._features = {
            "inputs": self.inputs_ph,
            "problem_choice": 0,  # We run on the first problem here.
            "input_space_id": self._hparams.problem_hparams.input_space_id,
            "target_space_id": self._hparams.problem_hparams.target_space_id
        }
        ### Add a variable-length decode_length
        self._features['decode_length'] = self.input_extra_length_ph
        ## target
        self._features['targets'] = self.targets_ph

        ## Drop the integer-valued features
        del self._features["problem_choice"]
        del self._features["input_space_id"]
        del self._features["target_space_id"]
        #del self._features['decode_length']
        ####

        mode = tf.estimator.ModeKeys.EVAL

        translate_model = registry.model(self._model_name)(
            hparams=self._hparams,
            decode_hparams=self._hparams_decode,
            mode=mode)

        self.predict_dict = {}

        ### get logit  ,attention mats
        self.logits, _ = translate_model(self._features)  #[? ? ? 1 vocabsz]
        #translate_model(features)
        from visualization import get_att_mats
        self.att_mats = get_att_mats(translate_model,
                                     self._model_name)  # enc, dec, encdec
        ### get infer
        translate_model.set_mode(tf.estimator.ModeKeys.PREDICT)
        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            self.outputs_scores = translate_model.infer(
                features=self._features,
                decode_length=self._extra_length,
                beam_size=self._beam_size,
                top_beams=self._beam_size,
                alpha=self._alpha)  #outputs 4,4,63

        ######
        tf.logging.info("Start to init tf session")
        if self._isGpu:
            print('Using GPU in Decoder')
            gpu_options = tf.GPUOptions(
                per_process_gpu_memory_fraction=self._fraction)
            self._sess = tf.Session(
                config=tf.ConfigProto(allow_soft_placement=True,
                                      log_device_placement=False,
                                      gpu_options=gpu_options))
        else:
            print('Using CPU in Decoder')
            gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0)
            config = tf.ConfigProto(gpu_options=gpu_options)
            config.allow_soft_placement = True
            config.log_device_placement = False
            self._sess = tf.Session(config=config)
        with self._sess.as_default():
            ckpt = saver_mod.get_checkpoint_state(self._model_dir)
            saver = tf.train.Saver()
            tf.logging.info("Start to restore the parameters from %s",
                            ckpt.model_checkpoint_path)
            saver.restore(self._sess, ckpt.model_checkpoint_path)
        tf.logging.info("Finish intialize environment")
Code Example #37
def create_experiment(run_config,
                      hparams,
                      model_name,
                      params,
                      problem_instance,
                      data_dir,
                      train_steps,
                      eval_steps,
                      min_eval_frequency=2000,
                      eval_throttle_seconds=600,
                      schedule="train_and_evaluate",
                      export=False,
                      decode_hparams=None,
                      use_tfdbg=False,
                      use_dbgprofile=False,
                      use_validation_monitor=False,
                      eval_early_stopping_steps=None,
                      eval_early_stopping_metric=None,
                      eval_early_stopping_metric_delta=None,
                      eval_early_stopping_metric_minimize=True,
                      autotune=False,
                      use_tpu=False):
  """Create Experiment."""
  # HParams
  hparams.add_hparam('model_dir', params.model_dir)
  hparams.add_hparam("data_dir", data_dir)
  hparams.add_hparam("train_steps", train_steps)
  hparams.add_hparam("eval_steps", eval_steps)
  hparams.add_hparam("schedule", schedule)
  add_problem_hparams(hparams, problem_instance)

  # Estimator
  estimator = trainer_lib.create_estimator(
      model_name,
      hparams,
      run_config,
      schedule=schedule,
      decode_hparams=decode_hparams,
      use_tpu=use_tpu)

  # Input fns from Problem
  problem = hparams.problem
  train_input_fn = problem.make_estimator_input_fn(
      tf.estimator.ModeKeys.TRAIN, hparams)
  eval_input_fn = problem.make_estimator_input_fn(
      tf.estimator.ModeKeys.EVAL, hparams)

  # Export
  if export:
    tf.logging.warn("Exporting from the trainer is deprecated. "
                    "See serving/export.py.")

  # Hooks
  validation_monitor_kwargs = dict(
      input_fn=eval_input_fn,
      eval_steps=eval_steps,
      every_n_steps=min_eval_frequency,
      early_stopping_rounds=eval_early_stopping_steps,
      early_stopping_metric=eval_early_stopping_metric,
      early_stopping_metric_minimize=eval_early_stopping_metric_minimize)
  dbgprofile_kwargs = {"output_dir": run_config.model_dir}
  early_stopping_kwargs = dict(
      events_dir=os.path.join(run_config.model_dir, "eval_continuous"),
      tag=eval_early_stopping_metric,
      num_plateau_steps=eval_early_stopping_steps,
      plateau_decrease=eval_early_stopping_metric_minimize,
      plateau_delta=eval_early_stopping_metric_delta,
      every_n_steps=min_eval_frequency)

  # In-process eval (and possible early stopping)
  if schedule == "continuous_train_and_eval" and min_eval_frequency:
    tf.logging.warn("ValidationMonitor only works with "
                    "--schedule=train_and_evaluate")
  use_validation_monitor = (
      schedule == "train_and_evaluate" and min_eval_frequency)
  # Distributed early stopping
  local_schedules = ["train_and_evaluate", "continuous_train_and_eval"]
  use_early_stopping = (
      schedule not in local_schedules and eval_early_stopping_steps)
  train_hooks, eval_hooks = trainer_lib.create_hooks(
      use_tfdbg=use_tfdbg,
      use_dbgprofile=use_dbgprofile,
      dbgprofile_kwargs=dbgprofile_kwargs,
      use_validation_monitor=use_validation_monitor,
      validation_monitor_kwargs=validation_monitor_kwargs,
      use_early_stopping=use_early_stopping,
      early_stopping_kwargs=early_stopping_kwargs)
  train_hooks += t2t_model.T2TModel.get_train_hooks(model_name)
  eval_hooks += t2t_model.T2TModel.get_eval_hooks(model_name)

  train_hooks = tf.contrib.learn.monitors.replace_monitors_with_hooks(
      train_hooks, estimator)
  eval_hooks = tf.contrib.learn.monitors.replace_monitors_with_hooks(
      eval_hooks, estimator)

  train_spec = tf.estimator.TrainSpec(
      train_input_fn, max_steps=train_steps, hooks=train_hooks)
  eval_spec = tf.estimator.EvalSpec(
      eval_input_fn,
      steps=eval_steps,
      hooks=eval_hooks,
      start_delay_secs=0 if hparams.schedule == "evaluate" else 120,
      throttle_secs=eval_throttle_seconds)

  if autotune:
    hooks_kwargs = {"train_monitors": train_hooks, "eval_hooks": eval_hooks}
    return tf.contrib.learn.Experiment(
        estimator=estimator,
        train_input_fn=train_input_fn,
        eval_input_fn=eval_input_fn,
        train_steps=train_steps,
        eval_steps=eval_steps,
        min_eval_frequency=min_eval_frequency,
        train_steps_per_iteration=min(min_eval_frequency, train_steps),
        eval_delay_secs=0 if schedule == "evaluate" else 120,
        **hooks_kwargs if not use_tpu else {})
  return trainer_lib.T2TExperiment(estimator, hparams, train_spec, eval_spec,
                                   use_validation_monitor, decode_hparams)
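
For completeness, a hedged sketch of how a caller typically consumes the returned experiment, mirroring t2t_trainer's execute_schedule dispatch (calling the method named by the schedule); the argument values are illustrative placeholders:

def run_experiment(run_config, hparams, params, problem, data_dir):
    exp = create_experiment(run_config, hparams, "transformer", params,
                            problem, data_dir,
                            train_steps=100000, eval_steps=100,
                            schedule="train_and_evaluate")
    # Same dispatch t2t_trainer.execute_schedule performs.
    getattr(exp, "train_and_evaluate")()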