Code example #1
File: export.py Project: qixiuai/tensor2tensor
def create_estimator(run_config, hparams):
  return trainer_lib.create_estimator(
      FLAGS.model,
      hparams,
      run_config,
      decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams),
      use_tpu=FLAGS.use_tpu)
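For context: decoding.decode_hparams builds the default decoding HParams and then applies the comma-separated overrides string, so a flag like --decode_hparams behaves roughly as in this sketch (the override values here are illustrative, not taken from any example on this page):

from tensor2tensor.utils import decoding

decode_hp = decoding.decode_hparams("beam_size=4,alpha=0.6")
# The overrides are parsed onto the defaults, so the fields are now set.
assert decode_hp.beam_size == 4
assert decode_hp.alpha == 0.6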
Code example #2
File: g2p.py Project: cmusphinx/g2p-seq2seq
  def __prepare_model(self, train_mode=False):
    """Prepare utilities for decoding."""
    hparams = registry.hparams(self.params.hparams_set)
    hparams.problem = self.problem
    hparams.problem_hparams = self.problem.get_hparams(hparams)
    if self.params.hparams:
      tf.logging.info("Overriding hparams in %s with %s",
                      self.params.hparams_set,
                      self.params.hparams)
      hparams = hparams.parse(self.params.hparams)
    trainer_run_config = g2p_trainer_utils.create_run_config(hparams,
        self.params)
    if train_mode:
      exp_fn = g2p_trainer_utils.create_experiment_fn(self.params, self.problem)
      self.exp = exp_fn(trainer_run_config, hparams)

    decode_hp = decoding.decode_hparams(self.params.decode_hparams)
    estimator = trainer_lib.create_estimator(
        self.params.model_name,
        hparams,
        trainer_run_config,
        decode_hparams=decode_hp,
        use_tpu=False)

    return estimator, decode_hp, hparams
Code example #3
def create_decode_hparams():
    decode_hp = decoding.decode_hparams(FLAGS.decode_hparams)
    decode_hp.shards = FLAGS.decode_shards
    decode_hp.shard_id = FLAGS.worker_id
    decode_in_memory = FLAGS.decode_in_memory or decode_hp.decode_in_memory
    decode_hp.decode_in_memory = decode_in_memory
    decode_hp.decode_to_file = FLAGS.decode_to_file
    decode_hp.decode_reference = FLAGS.decode_reference
    return decode_hp
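The decode_hp built above is typically handed to an estimator and then to a decode routine. A minimal sketch of that downstream use, assuming t2t_decoder-style flags and an hparams/run_config constructed as in the other examples on this page:

def decode(hparams, run_config):
    decode_hp = create_decode_hparams()
    estimator = trainer_lib.create_estimator(
        FLAGS.model,
        hparams,
        run_config,
        decode_hparams=decode_hp,
        use_tpu=FLAGS.use_tpu)
    # Read prompts from --decode_from_file and write decodes to
    # decode_hp.decode_to_file.
    decoding.decode_from_file(estimator, FLAGS.decode_from_file, hparams,
                              decode_hp, FLAGS.decode_to_file)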
Code example #4
 def testDecodeInMemoryTrue(self):
   predictions, problem = self.get_predictions()
   decode_hparams = decoding.decode_hparams()
   decode_hparams.decode_in_memory = True
   decode_hooks = decoding.DecodeHookArgs(
       estimator=None, problem=problem, output_dirs=None,
       hparams=decode_hparams, decode_hparams=decode_hparams,
       predictions=predictions)
   metrics = video_utils.summarize_video_metrics(decode_hooks)
Code example #5
File: export.py Project: xinjianlv/tensor2tensor
def create_estimator(run_config, hparams):
    return trainer_lib.create_estimator(
        FLAGS.model,
        hparams,
        run_config,
        decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams),
        use_tpu=FLAGS.use_tpu,
        export_saved_model_api_version=FLAGS.export_saved_model_api_version,
        use_guarantee_const_getter=FLAGS.use_guarantee_const_getter)
Code example #6
File: t2t_model.py Project: yukeyi/tensor2tensor
  def __init__(self,
               hparams,
               mode=tf.estimator.ModeKeys.TRAIN,
               problem_hparams=None,
               data_parallelism=None,
               decode_hparams=None):
    """Create a T2TModel.

    Args:
      hparams: tf.contrib.training.HParams, model hyperparameters.
      mode: tf.estimator.ModeKeys, the execution mode.
      problem_hparams: tf.contrib.training.HParams, hyperparameters for the
        Problem. If provided here or in hparams.problems, the model will
        automatically determine bottom, top, and loss methods. If not provided,
        calling the model will only invoke body.
      data_parallelism: an expert_utils.Parallelism object that specifies
        devices for data parallelism.
      decode_hparams: a hyperparameter object with decoding parameters.
        See decoding.decode_hparams.

    Returns:
      a T2TModel
    """
    # Determine name first: use the registered name if possible,
    # else the class name.
    default_name = registry.default_name(type(self))
    name = self.REGISTERED_NAME or default_name
    super(T2TModel, self).__init__(
        trainable=mode == tf.estimator.ModeKeys.TRAIN, name=name)

    if not problem_hparams and hasattr(hparams, "problems"):
      problem_hparams = hparams.problems[0]
    print(problem_hparams)
    self._problem_hparams = problem_hparams

    # Setup hparams
    # If vocabularies differ, unset shared_embedding_and_softmax_weights.
    hparams = copy.copy(hparams)
    if self._problem_hparams and hparams.shared_embedding_and_softmax_weights:
      same_vocab_sizes = True
      if "inputs" in self._problem_hparams.input_modality:
        if (self._problem_hparams.input_modality["inputs"] !=
            self._problem_hparams.target_modality):
          same_vocab_sizes = False
      if not same_vocab_sizes:
        log_info("Unsetting shared_embedding_and_softmax_weights.")
        hparams.shared_embedding_and_softmax_weights = 0
    self._original_hparams = hparams
    self.set_mode(mode)

    self._decode_hparams = copy.copy(decode_hparams or
                                     decoding.decode_hparams())
    self._data_parallelism = data_parallelism or eu.Parallelism([""])
    self._num_datashards = self._data_parallelism.n
    self._ps_devices = self._data_parallelism.ps_devices
    self._eager_var_store = create_eager_var_store()
    if self._problem_hparams:
      self._create_modalities(self._problem_hparams, self._hparams)
Code example #7
File: t2t_decoder.py Project: qixiuai/tensor2tensor
def create_decode_hparams():
  decode_hp = decoding.decode_hparams(FLAGS.decode_hparams)
  decode_hp.shards = FLAGS.decode_shards
  decode_hp.shard_id = FLAGS.worker_id
  decode_in_memory = FLAGS.decode_in_memory or decode_hp.decode_in_memory
  decode_hp.decode_in_memory = decode_in_memory
  decode_hp.decode_to_file = FLAGS.decode_to_file
  decode_hp.decode_reference = FLAGS.decode_reference
  return decode_hp
Code example #8
File: t2t_model.py Project: chqiwang/tensor2tensor
  def __init__(self,
               hparams,
               mode=tf.estimator.ModeKeys.TRAIN,
               problem_hparams=None,
               data_parallelism=None,
               decode_hparams=None):
    """Create a T2TModel.

    Args:
      hparams: tf.contrib.training.HParams, model hyperparameters.
      mode: tf.estimator.ModeKeys, the execution mode.
      problem_hparams: tf.contrib.training.HParams, hyperparameters for the
        Problem. If provided here or in hparams.problems, the model will
        automatically determine bottom, top, and loss methods. If not provided,
        calling the model will only invoke body.
      data_parallelism: an expert_utils.Parallelism object that specifies
        devices for data parallelism.
      decode_hparams: a hyperparameter object with decoding parameters.
        See decoding.decode_hparams.

    Returns:
      a T2TModel
    """
    # Determine name first: use the registered name if possible,
    # else the class name.
    default_name = registry.default_name(type(self))
    name = self.REGISTERED_NAME or default_name
    super(T2TModel, self).__init__(
        trainable=mode == tf.estimator.ModeKeys.TRAIN, name=name)

    if not problem_hparams and hasattr(hparams, "problems"):
      problem_hparams = hparams.problems[0]
    self._problem_hparams = problem_hparams

    # Setup hparams
    # If vocabularies differ, unset shared_embedding_and_softmax_weights.
    hparams = copy.copy(hparams)
    if self._problem_hparams and hparams.shared_embedding_and_softmax_weights:
      same_vocab_sizes = True
      if "inputs" in self._problem_hparams.input_modality:
        if (self._problem_hparams.input_modality["inputs"] !=
            self._problem_hparams.target_modality):
          same_vocab_sizes = False
      if not same_vocab_sizes:
        log_info("Unsetting shared_embedding_and_softmax_weights.")
        hparams.shared_embedding_and_softmax_weights = 0
    self._original_hparams = hparams
    self.set_mode(mode)

    self._decode_hparams = copy.copy(decode_hparams or
                                     decoding.decode_hparams())
    self._data_parallelism = data_parallelism or eu.Parallelism([""])
    self._num_datashards = self._data_parallelism.n
    self._ps_devices = self._data_parallelism.ps_devices
    self._eager_var_store = create_eager_var_store()
    if self._problem_hparams:
      self._create_modalities(self._problem_hparams, self._hparams)
Code example #9
def create_experiment_components(hparams, output_dir, data_dir, model_name):
    """Constructs and returns Estimator and train/eval input functions."""
    tf.logging.info("Creating experiment, storing model files in %s",
                    output_dir)

    num_datashards = devices.data_parallelism().n
    train_input_fn = input_fn_builder.build_input_fn(
        mode=tf.estimator.ModeKeys.TRAIN,
        hparams=hparams,
        data_file_patterns=get_data_filepatterns(data_dir,
                                                 tf.estimator.ModeKeys.TRAIN),
        num_datashards=num_datashards,
        worker_replicas=FLAGS.worker_replicas,
        worker_id=FLAGS.worker_id)

    eval_input_fn = input_fn_builder.build_input_fn(
        mode=tf.estimator.ModeKeys.EVAL,
        hparams=hparams,
        data_file_patterns=get_data_filepatterns(data_dir,
                                                 tf.estimator.ModeKeys.EVAL),
        num_datashards=num_datashards,
        worker_replicas=FLAGS.worker_replicas,
        worker_id=FLAGS.worker_id)

    autotune = False
    objective = None
    if hasattr(FLAGS, "autotune"):
        autotune = FLAGS.autotune
        objective = FLAGS.objective
    model_fn = model_builder.build_model_fn(
        model_name,
        problem_names=FLAGS.problems.split("-"),
        train_steps=FLAGS.train_steps,
        worker_id=FLAGS.worker_id,
        worker_replicas=FLAGS.worker_replicas,
        eval_run_autoregressive=FLAGS.eval_run_autoregressive,
        decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams),
        autotune=autotune,
        objective=objective)
    estimator = tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=output_dir,
        params=hparams,
        config=tf.contrib.learn.RunConfig(
            master=FLAGS.master,
            gpu_memory_fraction=FLAGS.worker_gpu_memory_fraction,
            session_config=session_config(),
            keep_checkpoint_max=FLAGS.keep_checkpoint_max,
            keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours,
            save_checkpoints_secs=FLAGS.save_checkpoints_secs))

    return estimator, {
        tf.estimator.ModeKeys.TRAIN: train_input_fn,
        tf.estimator.ModeKeys.EVAL: eval_input_fn
    }
Code example #10
    def __init__(self, model_dir, config):
        self._signatures = dict()

        self._graph = tf.Graph()
        with self._graph.as_default():
            tf.set_random_seed(1234)

            # initialize the hparams, problem and model
            self._hparams = trainer_lib.create_hparams(
                config['hparams_set'], config.get('hparams_overrides', ''),
                os.path.join(model_dir, 'assets.extra'), config['problem'])
            problem = self._hparams.problem

            decode_hp = decoding.decode_hparams(
                config.get('decode_hparams', ''))

            run_config = trainer_lib.create_run_config(self._hparams,
                                                       model_dir=model_dir,
                                                       schedule="decode")

            model_fn = t2t_model.T2TModel.make_estimator_model_fn(
                config['model'], self._hparams, decode_hparams=decode_hp)

            # create the prediction signatures (input/output ops)
            serving_receiver = problem.direct_serving_input_fn(self._hparams)
            estimator_spec = model_fn(serving_receiver.features,
                                      None,
                                      mode=tf.estimator.ModeKeys.PREDICT,
                                      params=None,
                                      config=run_config)

            for key, sig_spec in estimator_spec.export_outputs.items():
                # only PredictOutputs are supported, ClassificationOutput
                # and RegressionOutputs are weird artifacts of Google shipping
                # almost unmodified Tensorflow graphs through their Cloud ML
                # platform
                assert isinstance(sig_spec, tf.estimator.export.PredictOutput)

                sig = Signature(key, serving_receiver.receiver_tensors,
                                sig_spec.outputs)
                self._signatures[key] = sig

            # load the model & init the session

            scaffold = tf.train.Scaffold()
            checkpoint_filename = os.path.join(
                model_dir, tf.saved_model.constants.VARIABLES_DIRECTORY,
                tf.saved_model.constants.VARIABLES_FILENAME)
            session_creator = tf.train.ChiefSessionCreator(
                scaffold,
                config=run_config.session_config,
                checkpoint_filename_with_path=checkpoint_filename)
            self._session = tf.train.MonitoredSession(
                session_creator=session_creator)
Code example #11
  def hub_module_fn():
    """Creates the TF graph for the hub module."""
    model_fn = t2t_model.T2TModel.make_estimator_model_fn(
        FLAGS.model,
        hparams,
        decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams))
    features = problem.serving_input_fn(hparams).features
    spec = model_fn(features, labels=None, mode=tf.estimator.ModeKeys.PREDICT)

    # Currently only supports a single input and single output.
    hub.add_signature(
        inputs=features, outputs=spec.export_outputs["serving_default"].outputs)
Code example #12
def create_experiment_components(params,
                                 hparams,
                                 run_config,
                                 problem_instance,
                                 train_preprocess_file_path=None,
                                 dev_preprocess_file_path=None):
    """Constructs and returns Estimator and train/eval input functions."""
    tf.logging.info("Creating experiment, storing model files in %s",
                    run_config.model_dir)

    add_problem_hparams(hparams, params.problem_name, params.model_dir,
                        problem_instance)

    # hparams.batch_size is the minibatch size, instead of tokens per batch.
    batch_size = (hparams.use_fixed_batch_size and hparams.batch_size) or None
    num_datashards = 1
    train_input_fn = input_fn_builder.build_input_fn(
        mode=tf.estimator.ModeKeys.TRAIN,
        hparams=hparams,
        data_dir=params.data_dir,
        num_datashards=num_datashards,
        worker_replicas=FLAGS.worker_replicas,
        worker_id=FLAGS.worker_id,
        batch_size=batch_size,
        dataset_split=train_preprocess_file_path)

    eval_input_fn = input_fn_builder.build_input_fn(
        mode=tf.estimator.ModeKeys.EVAL,
        hparams=hparams,
        data_dir=params.data_dir,
        num_datashards=num_datashards,
        worker_replicas=FLAGS.worker_replicas,
        worker_id=FLAGS.worker_id,
        dataset_split=dev_preprocess_file_path)

    model_fn = model_builder.build_model_fn(
        params.model_name,
        problem_names=[params.problem_name],
        train_steps=params.train_steps,
        worker_id=FLAGS.worker_id,
        worker_replicas=FLAGS.worker_replicas,
        eval_run_autoregressive=FLAGS.eval_run_autoregressive,
        decode_hparams=decoding.decode_hparams(params.decode_hparams))

    estimator = tf.estimator.Estimator(model_fn=model_fn,
                                       model_dir=run_config.model_dir,
                                       params=hparams,
                                       config=run_config)

    return estimator, {
        tf.estimator.ModeKeys.TRAIN: train_input_fn,
        tf.estimator.ModeKeys.EVAL: eval_input_fn
    }
Code example #13
def decode_hparams(overrides=""):
    """Hparams for decoding."""
    hparams = decoding.decode_hparams()
    # Number of interpolations between [0.0, 1.0].
    hparams.add_hparam("num_interp", 11)
    # Which level(s) to interpolate.
    hparams.add_hparam("level_interp", [0, 1, 2])
    # "all" or "ranked", interpolate all channels or a "ranked".
    hparams.add_hparam("channel_interp", "all")
    # interpolate channels ranked according to squared L2 norm.
    hparams.add_hparam("rank_interp", 1)
    # Whether on not to save frames as summaries
    hparams.add_hparam("save_frames", True)
    hparams.parse(overrides)
    return hparams
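Because decoding.decode_hparams returns an HParams object, a wrapper like the one above can register extra keys with add_hparam and still accept a comma-separated override string through parse. A usage sketch of that wrapper (the override values are illustrative):

hp = decode_hparams(overrides="num_interp=5,channel_interp=ranked")
print(hp.num_interp, hp.channel_interp)  # -> 5 ranked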
Code example #14
  def __prepare_decode_model(self):
    """Prepare utilities for decoding."""
    hparams = trainer_utils.create_hparams(
        self.params.hparams_set,
        self.params.data_dir,
        passed_hparams=self.params.hparams)
    estimator, _ = g2p_trainer_utils.create_experiment_components(
        params=self.params,
        hparams=hparams,
        run_config=trainer_utils.create_run_config(self.params.model_dir),
        problem_instance=self.problem)

    decode_hp = decoding.decode_hparams(self.params.decode_hparams)
    decode_hp.add_hparam("shards", 1)
    return estimator, decode_hp
Code example #15
def create_experiment_components(data_dir, model_name, hparams, run_config):
    """Constructs and returns Estimator and train/eval input functions."""
    tf.logging.info("Creating experiment, storing model files in %s",
                    run_config.model_dir)

    add_problem_hparams(hparams, FLAGS.problems)

    # hparams.batch_size is the minibatch size, instead of tokens per batch.
    batch_size = (hparams.use_fixed_batch_size and hparams.batch_size) or None
    num_datashards = devices.data_parallelism().n
    train_input_fn = input_fn_builder.build_input_fn(
        mode=tf.estimator.ModeKeys.TRAIN,
        hparams=hparams,
        data_dir=data_dir,
        num_datashards=num_datashards,
        worker_replicas=FLAGS.worker_replicas,
        worker_id=FLAGS.worker_id,
        batch_size=batch_size)  # return feature_map, feature_map["targets"]

    eval_input_fn = input_fn_builder.build_input_fn(
        mode=tf.estimator.ModeKeys.EVAL,
        hparams=hparams,
        data_dir=data_dir,
        num_datashards=num_datashards,
        worker_replicas=FLAGS.worker_replicas,
        worker_id=FLAGS.worker_id,
        dataset_split="test"
        if FLAGS.eval_use_test_set else None)  # evaluate on test dataset
    # input_fn returns the feature_map

    model_fn = model_builder.build_model_fn(
        model_name,
        problem_names=FLAGS.problems.split("-"),
        train_steps=FLAGS.train_steps,
        worker_id=FLAGS.worker_id,
        worker_replicas=FLAGS.worker_replicas,
        eval_run_autoregressive=FLAGS.eval_run_autoregressive,
        decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams))

    estimator = tf.estimator.Estimator(model_fn=model_fn,
                                       model_dir=run_config.model_dir,
                                       params=hparams,
                                       config=run_config)

    return estimator, {
        tf.estimator.ModeKeys.TRAIN: train_input_fn,
        tf.estimator.ModeKeys.EVAL: eval_input_fn
    }
Code example #16
File: trainer_utils.py Project: zeyu-h/tensor2tensor
def create_experiment_components(data_dir, model_name, hparams, run_config):
  """Constructs and returns Estimator and train/eval input functions."""
  tf.logging.info("Creating experiment, storing model files in %s",
                  run_config.model_dir)

  add_problem_hparams(hparams, FLAGS.problems)

  # hparams.batch_size is the minibatch size, instead of tokens per batch.
  batch_size = (hparams.use_fixed_batch_size and hparams.batch_size) or None
  num_datashards = devices.data_parallelism(hparams).n
  train_input_fn = input_fn_builder.build_input_fn(
      mode=tf.estimator.ModeKeys.TRAIN,
      hparams=hparams,
      data_dir=data_dir,
      num_datashards=num_datashards,
      worker_replicas=FLAGS.worker_replicas,
      worker_id=FLAGS.worker_id,
      batch_size=batch_size)

  eval_input_fn = input_fn_builder.build_input_fn(
      mode=tf.estimator.ModeKeys.EVAL,
      hparams=hparams,
      data_dir=data_dir,
      num_datashards=num_datashards,
      worker_replicas=FLAGS.worker_replicas,
      worker_id=FLAGS.worker_id,
      dataset_split="test" if FLAGS.eval_use_test_set else None)

  model_fn = model_builder.build_model_fn(
      model_name,
      problem_names=FLAGS.problems.split("-"),
      train_steps=FLAGS.train_steps,
      worker_id=FLAGS.worker_id,
      worker_replicas=FLAGS.worker_replicas,
      eval_run_autoregressive=FLAGS.eval_run_autoregressive,
      decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams))

  estimator = tf.estimator.Estimator(
      model_fn=model_fn,
      model_dir=run_config.model_dir,
      params=hparams,
      config=run_config)

  return estimator, {
      tf.estimator.ModeKeys.TRAIN: train_input_fn,
      tf.estimator.ModeKeys.EVAL: eval_input_fn
  }
Code example #17
def create_decode_hparams(decode_path, shard):
    decode_hp = decoding.decode_hparams("beam_size=1")
    decode_hp.shards = FLAGS.decode_shards
    decode_hp.shard_id = shard
    decode_in_memory = FLAGS.decode_in_memory or decode_hp.decode_in_memory
    decode_hp.decode_in_memory = decode_in_memory
    if FLAGS.global_steps:
        decode_hp.decode_to_file = os.path.join(
            decode_path, f"{FLAGS.global_steps}{FLAGS.split}")
    else:
        print("Set a global step to be decoded")
        1 / 0
    decode_hp.decode_reference = FLAGS.decode_reference
    decode_hp.log_results = True
    decode_hp.batch_size = 16
    # decode_hp.batch_size = 128
    return decode_hp
Code example #18
    def __init__(self, str_tokens, eval_tokens=None, batch_size=1000):
        """
        Args:
            batch_size: used for encoding
            str_tokens: the original token inputs, as the format of ['t1', 't2'...]. The items within should be strings
            eval_tokens: if not None, then should be the same length as tokens, for similarity comparisons.
        """
        assert type(str_tokens) is list
        assert len(str_tokens) > 0
        assert type(str_tokens[0]) is str
        self.str_tokens = str_tokens
        if eval_tokens is not None:
            assert (len(eval_tokens) == len(str_tokens)
                    and type(eval_tokens[0]) is str)
        self.eval_tokens = eval_tokens
        tf.logging.set_verbosity(tf.logging.INFO)
        tf.logging.info('tf logging set to INFO by: %s' %
                        self.__class__.__name__)

        usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
        trainer_utils.log_registry()
        trainer_utils.validate_flags()
        assert FLAGS.schedule == "train_and_evaluate"
        data_dir = os.path.expanduser(FLAGS.data_dir)
        out_dir = os.path.expanduser(FLAGS.output_dir)

        hparams = trainer_utils.create_hparams(FLAGS.hparams_set,
                                               data_dir,
                                               passed_hparams=FLAGS.hparams)

        trainer_utils.add_problem_hparams(hparams, FLAGS.problems)
        # print(hparams)
        hparams.eval_use_test_set = True

        self.estimator, _ = trainer_utils.create_experiment_components(
            data_dir=data_dir,
            model_name=FLAGS.model,
            hparams=hparams,
            run_config=trainer_utils.create_run_config(out_dir))

        decode_hp = decoding.decode_hparams(FLAGS.decode_hparams)
        decode_hp.add_hparam("shards", FLAGS.decode_shards)
        decode_hp.batch_size = batch_size
        self.decode_hp = decode_hp
        self.arr_results = None
        self._encoding_len = 1
Code example #19
def create_experiment_components(data_dir, model_name, hparams, run_config):
    """Constructs and returns Estimator and train/eval input functions."""
    tf.logging.info("Creating experiment, storing model files in %s",
                    run_config.model_dir)

    hparams = add_problem_hparams(hparams, FLAGS.problems)

    num_datashards = devices.data_parallelism().n
    train_input_fn = input_fn_builder.build_input_fn(
        mode=tf.estimator.ModeKeys.TRAIN,
        hparams=hparams,
        data_file_patterns=get_data_filepatterns(data_dir,
                                                 tf.estimator.ModeKeys.TRAIN),
        num_datashards=num_datashards,
        worker_replicas=FLAGS.worker_replicas,
        worker_id=FLAGS.worker_id)

    eval_input_fn = input_fn_builder.build_input_fn(
        mode=tf.estimator.ModeKeys.EVAL,
        hparams=hparams,
        data_file_patterns=get_data_filepatterns(data_dir,
                                                 tf.estimator.ModeKeys.EVAL),
        num_datashards=num_datashards,
        worker_replicas=FLAGS.worker_replicas,
        worker_id=FLAGS.worker_id)

    model_fn = model_builder.build_model_fn(
        model_name,
        problem_names=FLAGS.problems.split("-"),
        train_steps=FLAGS.train_steps,
        worker_id=FLAGS.worker_id,
        worker_replicas=FLAGS.worker_replicas,
        eval_run_autoregressive=FLAGS.eval_run_autoregressive,
        decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams))

    estimator = tf.estimator.Estimator(model_fn=model_fn,
                                       model_dir=run_config.model_dir,
                                       params=hparams,
                                       config=run_config)

    return estimator, {
        tf.estimator.ModeKeys.TRAIN: train_input_fn,
        tf.estimator.ModeKeys.EVAL: eval_input_fn
    }
Code example #20
File: t2t_trainer.py Project: yynst2/tensor2tensor
def create_experiment_fn():
  return trainer_lib.create_experiment_fn(
      model_name=FLAGS.model,
      problem_name=get_problem_name(),
      data_dir=os.path.expanduser(FLAGS.data_dir),
      train_steps=FLAGS.train_steps,
      eval_steps=FLAGS.eval_steps,
      min_eval_frequency=FLAGS.local_eval_frequency,
      schedule=FLAGS.schedule,
      export=FLAGS.export_saved_model,
      decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams),
      use_tfdbg=FLAGS.tfdbg,
      use_dbgprofile=FLAGS.dbgprofile,
      eval_early_stopping_steps=FLAGS.eval_early_stopping_steps,
      eval_early_stopping_metric=FLAGS.eval_early_stopping_metric,
      eval_early_stopping_metric_delta=FLAGS.eval_early_stopping_metric_delta,
      eval_early_stopping_metric_minimize=FLAGS.
      eval_early_stopping_metric_minimize,
      use_tpu=FLAGS.use_tpu)
Code example #21
def create_experiment_fn():
  return tpu_trainer_lib.create_experiment_fn(
      model_name=FLAGS.model,
      problem_name=get_problem_name(),
      data_dir=os.path.expanduser(FLAGS.data_dir),
      train_steps=FLAGS.train_steps,
      eval_steps=FLAGS.eval_steps,
      min_eval_frequency=FLAGS.local_eval_frequency,
      schedule=FLAGS.schedule,
      export=FLAGS.export_saved_model,
      decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams),
      use_tfdbg=FLAGS.tfdbg,
      use_dbgprofile=FLAGS.dbgprofile,
      eval_early_stopping_steps=FLAGS.eval_early_stopping_steps,
      eval_early_stopping_metric=FLAGS.eval_early_stopping_metric,
      eval_early_stopping_metric_delta=FLAGS.eval_early_stopping_metric_delta,
      eval_early_stopping_metric_minimize=FLAGS.
      eval_early_stopping_metric_minimize,
      use_tpu=FLAGS.use_tpu)
Code example #22
    def __init__(self, processor_configuration):
        """Creates the Transformer estimator.

    Args:
      processor_configuration: A ProcessorConfiguration protobuffer with the
        transformer fields populated.
    """
        # Do the pre-setup tensor2tensor requires for flags and configurations.
        transformer_config = processor_configuration["transformer"]
        FLAGS.output_dir = transformer_config["model_dir"]
        usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
        data_dir = os.path.expanduser(transformer_config["data_dir"])

        # Create the basic hyper parameters.
        self.hparams = trainer_lib.create_hparams(
            transformer_config["hparams_set"],
            transformer_config["hparams"],
            data_dir=data_dir,
            problem_name=transformer_config["problem"])

        decode_hp = decoding.decode_hparams()
        decode_hp.add_hparam("shards", 1)
        decode_hp.add_hparam("shard_id", 0)

        # Create the estimator and final hyper parameters.
        self.estimator = trainer_lib.create_estimator(
            transformer_config["model"],
            self.hparams,
            t2t_trainer.create_run_config(self.hparams),
            decode_hparams=decode_hp,
            use_tpu=False)

        # Fetch the vocabulary and other helpful variables for decoding.
        self.source_vocab = self.hparams.problem_hparams.vocabulary["inputs"]
        self.targets_vocab = self.hparams.problem_hparams.vocabulary["targets"]
        self.const_array_size = 10000

        # Prepare the Transformer's debug data directory.
        run_dirs = sorted(
            glob.glob(os.path.join("/tmp/t2t_server_dump", "run_*")))
        for run_dir in run_dirs:
            shutil.rmtree(run_dir)
Code example #23
  def testConvertPredictionsToImageSummaries(self):
    # Initialize predictions.
    rng = np.random.RandomState(0)
    x = rng.randint(0, 255, (32, 32, 3))
    predictions = [[{"outputs": x, "inputs": x}] * 50]

    decode_hparams = decoding.decode_hparams()
    # should return 20 summaries of images, 10 outputs and 10 inputs if
    # display_decoded_images is set to True.
    for display, summaries_length in zip([True, False], [20, 0]):
      decode_hparams.display_decoded_images = display
      decode_hooks = decoding.DecodeHookArgs(
          estimator=None, problem=None, output_dirs=None,
          hparams=decode_hparams, decode_hparams=decode_hparams,
          predictions=predictions)
      summaries = image_utils.convert_predictions_to_image_summaries(
          decode_hooks)
      self.assertEqual(len(summaries), summaries_length)
      if summaries:
        self.assertTrue(isinstance(summaries[0], tf.Summary.Value))
Code example #25
def main(argv):
    tf.logging.set_verbosity(tf.logging.INFO)
    hvd.init()

    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    log_registry()

    if FLAGS.cloud_mlengine:
        return cloud_mlengine.launch()

    if FLAGS.generate_data:
        generate_data()

    if hasattr(FLAGS, "job_dir") and FLAGS.job_dir:
        FLAGS.output_dir = FLAGS.job_dir

    if argv:
        set_hparams_from_args(argv[1:])

    hparams = create_hparams()

    if is_chief():
        save_metadata(hparams)

    # create_run_config calls trainer_lib.create_session_config, which
    # initializes gpu_options.
    config = create_run_config(hparams)
    decode_hparams = decoding.decode_hparams(FLAGS.decode_hparams)
    schedule = FLAGS.schedule

    estimator = create_estimator_fn(FLAGS.model, hparams, config, schedule, decode_hparams)

    # logging_hook = tf.train.LoggingTensorHook({"step": "test"}, every_n_iter=5)
    bcast_hook = hvd.BroadcastGlobalVariablesHook(0)

    estimator.train(
        input_fn=train_input_fn(hparams),
        steps=FLAGS.train_steps,
        hooks=[bcast_hook]
    )
Code example #26
File: g2p.py Project: jupinter/g2p-seq2seq
  def __prepare_model(self):
    """Prepare utilities for decoding."""
    hparams = trainer_lib.create_hparams(
        hparams_set=self.params.hparams_set,
        hparams_overrides_str=self.params.hparams)
    trainer_run_config = g2p_trainer_utils.create_run_config(hparams,
        self.params)
    exp_fn = g2p_trainer_utils.create_experiment_fn(self.params, self.problem)
    self.exp = exp_fn(trainer_run_config, hparams)

    decode_hp = decoding.decode_hparams(self.params.decode_hparams)
    decode_hp.add_hparam("shards", self.params.decode_shards)
    decode_hp.add_hparam("shard_id", self.params.worker_id)
    estimator = trainer_lib.create_estimator(
        self.params.model_name,
        hparams,
        trainer_run_config,
        decode_hparams=decode_hp,
        use_tpu=False)

    return estimator, decode_hp, hparams
Code example #27
  def testConvertPredictionsToVideoSummaries(self):
    # Initialize predictions.
    rng = np.random.RandomState(0)
    inputs = rng.randint(0, 255, (2, 32, 32, 3))
    outputs = rng.randint(0, 255, (5, 32, 32, 3))
    targets = rng.randint(0, 255, (5, 32, 32, 3))

    # batch it up.
    prediction = [{"outputs": outputs, "inputs": inputs, "targets": targets}]*50
    predictions = [prediction]
    decode_hparams = decoding.decode_hparams()

    decode_hooks = decoding.DecodeHookArgs(
        estimator=None, problem=None, output_dirs=None,
        hparams=decode_hparams, decode_hparams=decode_hparams,
        predictions=predictions)
    summaries = video_utils.display_video_hooks(decode_hooks)
    # ground_truth + output.
    self.assertEqual(len(summaries), 20)
    for summary in summaries:
      self.assertTrue(isinstance(summary, tf.Summary.Value))
Code example #28
  def __init__(self, processor_configuration):
    """Creates the Transformer estimator.

    Args:
      processor_configuration: A ProcessorConfiguration protobuffer with the
        transformer fields populated.
    """
    # Do the pre-setup tensor2tensor requires for flags and configurations.
    transformer_config = processor_configuration["transformer"]
    FLAGS.output_dir = transformer_config["model_dir"]
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    data_dir = os.path.expanduser(transformer_config["data_dir"])

    # Create the basic hyper parameters.
    self.hparams = trainer_lib.create_hparams(
        transformer_config["hparams_set"],
        transformer_config["hparams"],
        data_dir=data_dir,
        problem_name=transformer_config["problem"])

    decode_hp = decoding.decode_hparams()
    decode_hp.add_hparam("shards", 1)
    decode_hp.add_hparam("shard_id", 0)

    # Create the estimator and final hyper parameters.
    self.estimator = trainer_lib.create_estimator(
        transformer_config["model"],
        self.hparams,
        t2t_trainer.create_run_config(self.hparams),
        decode_hparams=decode_hp, use_tpu=False)

    # Fetch the vocabulary and other helpful variables for decoding.
    self.source_vocab = self.hparams.problem_hparams.vocabulary["inputs"]
    self.targets_vocab = self.hparams.problem_hparams.vocabulary["targets"]
    self.const_array_size = 10000

    # Prepare the Transformer's debug data directory.
    run_dirs = sorted(glob.glob(os.path.join("/tmp/t2t_server_dump", "run_*")))
    for run_dir in run_dirs:
      shutil.rmtree(run_dir)
Code example #29
File: export.py Project: changlan/tensor2tensor
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

    if FLAGS.checkpoint_path:
        checkpoint_path = FLAGS.checkpoint_path
        ckpt_dir = os.path.dirname(checkpoint_path)
    else:
        ckpt_dir = os.path.expanduser(FLAGS.output_dir)
        checkpoint_path = tf.train.latest_checkpoint(ckpt_dir)

    hparams = create_hparams()
    hparams.no_data_parallelism = True  # To clear the devices
    problem = hparams.problem
    decode_hparams = decoding.decode_hparams(FLAGS.decode_hparams)

    export_dir = FLAGS.export_dir or os.path.join(ckpt_dir, "export")

    if FLAGS.export_as_tfhub:
        checkpoint_path = tf.train.latest_checkpoint(ckpt_dir)
        export_as_tfhub_module(FLAGS.model, hparams, decode_hparams, problem,
                               checkpoint_path, export_dir)
        return

    run_config = t2t_trainer.create_run_config(hparams)

    estimator = create_estimator(run_config, hparams)

    exporter = tf_estimator.FinalExporter(
        "exporter",
        lambda: problem.serving_input_fn(hparams, decode_hparams, FLAGS.use_tpu),
        as_text=FLAGS.as_text)

    exporter.export(estimator,
                    export_dir,
                    checkpoint_path=checkpoint_path,
                    eval_result=None,
                    is_the_final_export=True)
Code example #30
  def __init__(self, data_dir, model_dir):
    """Creates the Transformer estimator.

    Args:
      data_dir: The training data directory.
      model_dir: The trained model directory.
    """
    # Do the pre-setup tensor2tensor requires for flags and configurations.
    FLAGS.output_dir = model_dir
    FLAGS.data_dir = data_dir
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    data_dir = os.path.expanduser(data_dir)

    # Create the basic hyper parameters.
    self.hparams = tpu_trainer_lib.create_hparams(
        FLAGS.hparams_set,
        FLAGS.hparams,
        data_dir=data_dir,
        problem_name=FLAGS.problems)

    decode_hp = decoding.decode_hparams(FLAGS.decode_hparams)
    decode_hp.add_hparam("shards", 1)
    decode_hp.add_hparam("shard_id", 0)

    # Create the estimator and final hyper parameters.
    self.estimator = tpu_trainer_lib.create_estimator(
        FLAGS.model,
        self.hparams,
        tpu_trainer.create_run_config(),
        decode_hp, use_tpu=False)

    # Fetch the vocabulary and other helpful variables for decoding.
    self.source_vocab = self.hparams.problems[0].vocabulary["inputs"]
    self.targets_vocab = self.hparams.problems[0].vocabulary["targets"]
    self.const_array_size = 10000

    # Prepare the Transformer's debug data directory.
    run_dirs = sorted(glob.glob(os.path.join("/tmp/t2t_server_dump", "run_*")))
    for run_dir in run_dirs:
      shutil.rmtree(run_dir)
Code example #31
File: tpu_trainer.py Project: wqh17101/tensor2tensor
def create_experiment_fn():
    use_validation_monitor = (
        FLAGS.schedule in ["train_and_evaluate", "continuous_train_and_eval"]
        and FLAGS.local_eval_frequency)
    return tpu_trainer_lib.create_experiment_fn(
        model_name=FLAGS.model,
        problem_name=get_problem_name(),
        data_dir=os.path.expanduser(FLAGS.data_dir),
        train_steps=FLAGS.train_steps,
        eval_steps=FLAGS.eval_steps,
        min_eval_frequency=FLAGS.local_eval_frequency,
        schedule=FLAGS.schedule,
        export=FLAGS.export_saved_model,
        decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams),
        use_tfdbg=FLAGS.tfdbg,
        use_dbgprofile=FLAGS.dbgprofile,
        use_validation_monitor=use_validation_monitor,
        eval_early_stopping_steps=FLAGS.eval_early_stopping_steps,
        eval_early_stopping_metric=FLAGS.eval_early_stopping_metric,
        eval_early_stopping_metric_minimize=FLAGS.
        eval_early_stopping_metric_minimize,
        use_tpu=FLAGS.use_tpu)
Code example #32
  def testConvertPredictionsToVideoSummaries(self, num_decodes=5,
                                             max_output_steps=5):
    # Initialize predictions.
    rng = np.random.RandomState(0)
    inputs = rng.randint(0, 255, (2, 32, 32, 3))
    outputs = rng.randint(0, 255, (max_output_steps, 32, 32, 3))
    targets = rng.randint(0, 255, (5, 32, 32, 3))

    # batch it up.
    prediction = [{"outputs": outputs, "inputs": inputs, "targets": targets}]*5
    predictions = [prediction] * num_decodes
    decode_hparams = decoding.decode_hparams(
        overrides="max_display_decodes=5")

    decode_hooks = decoding.DecodeHookArgs(
        estimator=None, problem=None, output_dirs=None,
        hparams=decode_hparams, decode_hparams=decode_hparams,
        predictions=predictions)
    summaries = video_utils.display_video_hooks(decode_hooks)

    for summary in summaries:
      self.assertIsInstance(summary, tf.Summary.Value)
Code example #33
File: export.py Project: qixiuai/tensor2tensor
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

  if FLAGS.checkpoint_path:
    checkpoint_path = FLAGS.checkpoint_path
    ckpt_dir = os.path.dirname(checkpoint_path)
  else:
    ckpt_dir = os.path.expanduser(FLAGS.output_dir)
    checkpoint_path = tf.train.latest_checkpoint(ckpt_dir)

  hparams = create_hparams()
  hparams.no_data_parallelism = True  # To clear the devices
  problem = hparams.problem

  export_dir = FLAGS.export_dir or os.path.join(ckpt_dir, "export")

  if FLAGS.export_as_tfhub:
    checkpoint_path = tf.train.latest_checkpoint(ckpt_dir)
    decode_hparams = decoding.decode_hparams(FLAGS.decode_hparams)
    export_as_tfhub_module(FLAGS.model, hparams, decode_hparams, problem,
                           checkpoint_path, export_dir)
    return

  run_config = t2t_trainer.create_run_config(hparams)

  estimator = create_estimator(run_config, hparams)

  exporter = tf.estimator.FinalExporter(
      "exporter", lambda: problem.serving_input_fn(hparams), as_text=True)

  exporter.export(
      estimator,
      export_dir,
      checkpoint_path=checkpoint_path,
      eval_result=None,
      is_the_final_export=True)
Code example #34
    def initialize(self, is_conditioned=False):
        self.model_name = 'transformer'
        self.hparams_set = 'transformer_tpu'
        self.conditioned = is_conditioned
        if self.conditioned:
            self.ckpt_path = 'models/checkpoints/melody_conditioned_model_16.ckpt'
            problem = MelodyToPianoPerformanceProblem()
        else:
            self.ckpt_path = 'models/checkpoints/unconditional_model_16.ckpt'
            problem = PianoPerformanceLanguageModelProblem()

        self.encoders = problem.get_feature_encoders()

        # Set up hyperparams
        hparams = trainer_lib.create_hparams(hparams_set=self.hparams_set)
        trainer_lib.add_problem_hparams(hparams, problem)
        hparams.num_hidden_layers = 16
        hparams.sampling_method = 'random'

        # Set up decoding hyperparams
        decode_hparams = decoding.decode_hparams()
        decode_hparams.alpha = 0.0
        decode_hparams.beam_size = 1
        if self.conditioned:
            self.inputs = []
        else:
            self.targets = []

        self.decode_length = 0
        run_config = trainer_lib.create_run_config(hparams)
        estimator = trainer_lib.create_estimator(
            self.model_name, hparams, run_config,
            decode_hparams=decode_hparams)
        fnc = self.input_generation_conditional if self.conditioned else self.input_generator_unconditional
        input_fn = decoding.make_input_fn_from_generator(fnc())
        self.samples = estimator.predict(
            input_fn, checkpoint_path=self.ckpt_path)
        _ = next(self.samples)
Code example #35
def create_experiment_fn(params, problem_instance):
  use_validation_monitor = (params.schedule in
                            ["train_and_evaluate", "continuous_train_and_eval"]
                            and params.local_eval_frequency)
  return create_experiment_func(
      model_name=params.model_name,
      params=params,
      problem_instance=problem_instance,
      data_dir=os.path.expanduser(params.data_dir_name),
      train_steps=params.train_steps,
      eval_steps=params.eval_steps,
      min_eval_frequency=params.local_eval_frequency,
      schedule=params.schedule,
      export=params.export_saved_model,
      decode_hparams=decoding.decode_hparams(params.decode_hparams),
      use_tfdbg=params.tfdbg,
      use_dbgprofile=params.dbgprofile,
      use_validation_monitor=use_validation_monitor,
      eval_early_stopping_steps=params.eval_early_stopping_steps,
      eval_early_stopping_metric=params.eval_early_stopping_metric,
      eval_early_stopping_metric_minimize=\
        params.eval_early_stopping_metric_minimize,
      use_tpu=params.use_tpu)
Code example #36
    def __prepare_model(self):
        """Prepare utilities for decoding."""
        hparams = registry.hparams(self.params.hparams_set)
        hparams.problem = self.problem
        hparams.problem_hparams = self.problem.get_hparams(hparams)
        if self.params.hparams:
            tf.logging.info("Overriding hparams in %s with %s",
                            self.params.hparams_set, self.params.hparams)
            hparams = hparams.parse(self.params.hparams)
        trainer_run_config = g2p_trainer_utils.create_run_config(
            hparams, self.params)
        exp_fn = g2p_trainer_utils.create_experiment_fn(
            self.params, self.problem)
        self.exp = exp_fn(trainer_run_config, hparams)

        decode_hp = decoding.decode_hparams(self.params.decode_hparams)
        estimator = trainer_lib.create_estimator(self.params.model_name,
                                                 hparams,
                                                 trainer_run_config,
                                                 decode_hparams=decode_hp,
                                                 use_tpu=False)

        return estimator, decode_hp, hparams
Code example #37
def create_hp_and_estimator(problem_name, data_dir, checkpoint_path):
    trainer_lib.set_random_seed(FLAGS.random_seed)

    hp = trainer_lib.create_hparams(FLAGS.hparams_set,
                                    FLAGS.hparams,
                                    data_dir=os.path.expanduser(data_dir),
                                    problem_name=problem_name)

    decode_hp = decoding.decode_hparams(FLAGS.decode_hparams)
    decode_hp.shards = FLAGS.decode_shards
    decode_hp.shard_id = FLAGS.worker_id
    decode_in_memory = FLAGS.decode_in_memory or decode_hp.decode_in_memory
    decode_hp.decode_in_memory = decode_in_memory
    decode_hp.decode_to_file = None
    decode_hp.decode_reference = None

    FLAGS.checkpoint_path = checkpoint_path
    estimator = trainer_lib.create_estimator(FLAGS.model,
                                             hp,
                                             t2t_trainer.create_run_config(hp),
                                             decode_hparams=decode_hp,
                                             use_tpu=FLAGS.use_tpu)
    return hp, decode_hp, estimator
Code example #38
def create_experiment_fn(params, problem_instance):
  use_validation_monitor = (params.schedule in
                            ["train_and_evaluate", "continuous_train_and_eval"]
                            and params.local_eval_frequency)
  return create_experiment_func(
      model_name=params.model_name,
      params=params,
      problem_instance=problem_instance,
      data_dir=os.path.expanduser(params.data_dir_name),
      train_steps=params.train_steps,
      eval_steps=params.eval_steps,
      min_eval_frequency=params.local_eval_frequency,
      schedule=params.schedule,
      export=params.export_saved_model,
      decode_hparams=decoding.decode_hparams(params.decode_hparams),
      use_tfdbg=params.tfdbg,
      use_dbgprofile=params.dbgprofile,
      use_validation_monitor=use_validation_monitor,
      eval_early_stopping_steps=params.eval_early_stopping_steps,
      eval_early_stopping_metric=params.eval_early_stopping_metric,
      eval_early_stopping_metric_minimize=\
        params.eval_early_stopping_metric_minimize,
      use_tpu=params.use_tpu)
Code example #39
def create_experiment_fn(**kwargs):
    return trainer_lib.create_experiment_fn(
        model_name=FLAGS.model,
        problem_name=FLAGS.problem,
        data_dir=os.path.expanduser(FLAGS.data_dir),
        train_steps=FLAGS.train_steps,
        eval_steps=FLAGS.eval_steps,
        min_eval_frequency=FLAGS.local_eval_frequency,
        schedule=FLAGS.schedule,
        eval_throttle_seconds=FLAGS.eval_throttle_seconds,
        export=FLAGS.export_saved_model,
        decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams),
        use_tfdbg=FLAGS.tfdbg,
        use_dbgprofile=FLAGS.dbgprofile,
        eval_early_stopping_steps=FLAGS.eval_early_stopping_steps,
        eval_early_stopping_metric=FLAGS.eval_early_stopping_metric,
        eval_early_stopping_metric_delta=FLAGS.
        eval_early_stopping_metric_delta,
        eval_early_stopping_metric_minimize=FLAGS.
        eval_early_stopping_metric_minimize,
        use_tpu=FLAGS.use_tpu,
        use_tpu_estimator=FLAGS.use_tpu_estimator,
        use_xla=FLAGS.xla_compile,
        **kwargs)
Code example #40
def create_experiment_fn():
    return trainer_lib.create_experiment_fn(
        model_name=FLAGS.model,
        problem_name=FLAGS.problem,
        data_dir=os.path.expanduser(FLAGS.data_dir),
        train_steps=FLAGS.train_steps,
        eval_steps=FLAGS.eval_steps,
        min_eval_frequency=FLAGS.local_eval_frequency,
        schedule=FLAGS.schedule,
        eval_throttle_seconds=FLAGS.eval_throttle_seconds,
        export=FLAGS.export_saved_model,
        decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams),
        use_tfdbg=FLAGS.tfdbg,
        use_dbgprofile=FLAGS.dbgprofile,
        eval_early_stopping_steps=FLAGS.eval_early_stopping_steps,
        eval_early_stopping_metric=FLAGS.eval_early_stopping_metric,
        eval_early_stopping_metric_delta=FLAGS.
        eval_early_stopping_metric_delta,
        eval_early_stopping_metric_minimize=FLAGS.
        eval_early_stopping_metric_minimize,
        eval_timeout_mins=FLAGS.eval_timeout_mins,
        eval_use_test_set=FLAGS.eval_use_test_set,
        use_tpu=FLAGS.use_tpu,
        use_tpu_estimator=FLAGS.use_tpu_estimator,
        use_xla=FLAGS.xla_compile,
        warm_start_from=FLAGS.warm_start_from,
        decode_from_file=FLAGS.decode_from_file,
        decode_to_file=FLAGS.decode_to_file,
        decode_reference=FLAGS.decode_reference,
        std_server_protocol=FLAGS.std_server_protocol,
        weight_lower_bound=FLAGS.weight_lower_bound,
        original_shake_shake=FLAGS.original_shake_shake,
        switch_grads=FLAGS.switch_grads,
        relu_first=FLAGS.relu_first,
        is_switchable=FLAGS.is_switchable,
        original_switchable=FLAGS.original_switchable)
Code example #41
File: t2t_decoder.py Project: kltony/tensor2tensor
def create_decode_hparams():
  decode_hp = decoding.decode_hparams(FLAGS.decode_hparams)
  decode_hp.shards = FLAGS.decode_shards
  decode_hp.shard_id = FLAGS.worker_id
  return decode_hp
Code example #42
File: t2t_decoder.py Project: chqiwang/tensor2tensor
def create_decode_hparams():
  decode_hp = decoding.decode_hparams(FLAGS.decode_hparams)
  decode_hp.add_hparam("shards", FLAGS.decode_shards)
  decode_hp.add_hparam("shard_id", FLAGS.worker_id)
  return decode_hp