Example #1
    def _restore_checkpoint(self,
                            master,
                            saver=None,
                            checkpoint_dir=None,
                            checkpoint_filename_with_path=None,
                            wait_for_checkpoint=False,
                            max_wait_secs=7200,
                            config=None):
        """Creates a `Session`, and tries to restore a checkpoint.

    Args:
      master: `String` representation of the TensorFlow master to use.
      saver: A `Saver` object used to restore a model.
      checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the
        dir will be used to restore.
      checkpoint_filename_with_path: Full file name path to the checkpoint file.
      wait_for_checkpoint: Whether to wait for checkpoint to become available.
      max_wait_secs: Maximum time to wait for checkpoints to become available.
      config: Optional `ConfigProto` proto used to configure the session.

    Returns:
      A pair (sess, is_restored) where 'is_restored' is `True` if
      the session could be restored, `False` otherwise.

    Raises:
      ValueError: If both checkpoint_dir and checkpoint_filename_with_path are
        set.
    """
        self._target = master
        sess = session.Session(self._target, graph=self._graph, config=config)

        if checkpoint_dir and checkpoint_filename_with_path:
            raise ValueError("Can not provide both checkpoint_dir and "
                             "checkpoint_filename_with_path.")
        # If either saver or checkpoint_* is not specified, cannot restore. Just
        # return.
        if not saver or not (checkpoint_dir or checkpoint_filename_with_path):
            return sess, False

        if checkpoint_filename_with_path:
            saver.restore(sess, checkpoint_filename_with_path)
            return sess, True

        # Waits up until max_wait_secs for checkpoint to become available.
        wait_time = 0
        ckpt = saver_mod.get_checkpoint_state(checkpoint_dir)
        while not ckpt or not ckpt.model_checkpoint_path:
            if wait_for_checkpoint and wait_time < max_wait_secs:
                logging.info("Waiting for checkpoint to be available.")
                time.sleep(self._recovery_wait_secs)
                wait_time += self._recovery_wait_secs
                ckpt = saver_mod.get_checkpoint_state(checkpoint_dir)
            else:
                return sess, False

        # Loads the checkpoint.
        saver.restore(sess, ckpt.model_checkpoint_path)
        saver.recover_last_checkpoints(ckpt.all_model_checkpoint_paths)
        return sess, True
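
A minimal usage sketch, assuming the TF 1.x API via tf.compat.v1: `_restore_checkpoint` is private, but `tf.train.SessionManager` drives the same logic through the public `recover_session`, which returns a (sess, initialized) pair rather than raising when no checkpoint exists. The checkpoint directory here is hypothetical.

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

graph = tf.Graph()
with graph.as_default():
    v = tf.get_variable("v", shape=[], initializer=tf.zeros_initializer())
    init_op = tf.global_variables_initializer()
    saver = tf.train.Saver()

sm = tf.train.SessionManager(graph=graph)
# initialized is False when nothing could be restored from checkpoint_dir.
sess, initialized = sm.recover_session(
    "", saver=saver, checkpoint_dir="/tmp/ckpts")  # hypothetical directory
if not initialized:
    sess.run(init_op)  # fall back to fresh initialization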
Example #2
  def recover_session(self, master, saver=None, checkpoint_dir=None,
                      wait_for_checkpoint=False, max_wait_secs=7200,
                      config=None):
    """Creates a `Session`, recovering if possible.

    Creates a new session on 'master'.  If the session is not initialized
    and can be recovered from a checkpoint, recover it.

    Args:
      master: `String` representation of the TensorFlow master to use.
      saver: A `Saver` object used to restore a model.
      checkpoint_dir: Path to the checkpoint files.
      wait_for_checkpoint: Whether to wait for checkpoint to become available.
      max_wait_secs: Maximum time to wait for checkpoints to become available.
      config: Optional `ConfigProto` proto used to configure the session.

    Returns:
      A pair (sess, initialized) where 'initialized' is `True` if
      the session could be recovered, `False` otherwise.
    """
    self._target = master
    sess = session.Session(self._target, graph=self._graph, config=config)
    if self._local_init_op:
      sess.run([self._local_init_op])

    # If either saver or checkpoint_dir is not specified, cannot restore. Just
    # return.
    if not saver or not checkpoint_dir:
      not_ready = self._model_not_ready(sess)
      return sess, not_ready is None

    # Waits up until max_wait_secs for checkpoint to become available.
    wait_time = 0
    ckpt = saver_mod.get_checkpoint_state(checkpoint_dir)
    while not ckpt or not ckpt.model_checkpoint_path:
      if wait_for_checkpoint and wait_time < max_wait_secs:
        logging.info("Waiting for checkpoint to be available.")
        time.sleep(self._recovery_wait_secs)
        wait_time += self._recovery_wait_secs
        ckpt = saver_mod.get_checkpoint_state(checkpoint_dir)
      else:
        return sess, False

    # Loads the checkpoint and verifies that it makes the model ready.
    saver.restore(sess, ckpt.model_checkpoint_path)
    last_checkpoints = []
    for fname in ckpt.all_model_checkpoint_paths:
      fnames = gfile.Glob(fname)
      if fnames:
        mtime = gfile.Stat(fnames[0]).mtime
        last_checkpoints.append((fname, mtime))
    saver.set_last_checkpoints_with_time(last_checkpoints)
    not_ready = self._model_not_ready(sess)
    if not_ready:
      logging.info("Restoring model from %s did not make model ready: %s",
                   ckpt.model_checkpoint_path, not_ready)
      return sess, False
    else:
      logging.info("Restored model from %s", ckpt.model_checkpoint_path)
      return sess, True
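
The readiness test above, `_model_not_ready`, conventionally evaluates a ready_op such as `tf.report_uninitialized_variables()`: an empty result means every variable is initialized. A sketch of that check in isolation (TF 1.x style via tf.compat.v1):

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

graph = tf.Graph()
with graph.as_default():
    w = tf.get_variable("w", shape=[2], initializer=tf.ones_initializer())
    ready_op = tf.report_uninitialized_variables()

with tf.Session(graph=graph) as sess:
    print(sess.run(ready_op))  # [b'w']: 'w' is uninitialized, model not ready
    sess.run(w.initializer)
    print(sess.run(ready_op))  # []: model ready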
Example #3
  def recover_session(self, master, saver=None, checkpoint_dir=None,
                      wait_for_checkpoint=False, max_wait_secs=7200,
                      config=None):
    """Creates a `Session`, recovering if possible.

    Creates a new session on 'master'.  If the session is not initialized
    and can be recovered from a checkpoint, recover it.

    Args:
      master: `String` representation of the TensorFlow master to use.
      saver: A `Saver` object used to restore a model.
      checkpoint_dir: Path to the checkpoint files.
      wait_for_checkpoint: Whether to wait for checkpoint to become available.
      max_wait_secs: Maximum time to wait for checkpoints to become available.
      config: Optional `ConfigProto` proto used to configure the session.

    Returns:
      A pair (sess, initialized) where 'initialized' is `True` if
      the session could be recovered, `False` otherwise.
    """
    target = self._maybe_launch_in_process_server(master)
    sess = session.Session(target, graph=self._graph, config=config)
    if self._local_init_op:
      sess.run([self._local_init_op])

    # If either saver or checkpoint_dir is not specified, cannot restore. Just
    # return.
    if not saver or not checkpoint_dir:
      not_ready = self._model_not_ready(sess)
      return sess, not_ready is None

    # Waits up until max_wait_secs for checkpoint to become available.
    wait_time = 0
    ckpt = saver_mod.get_checkpoint_state(checkpoint_dir)
    while not ckpt or not ckpt.model_checkpoint_path:
      if wait_for_checkpoint and wait_time < max_wait_secs:
        logging.info("Waiting for checkpoint to be available.")
        time.sleep(self._recovery_wait_secs)
        wait_time += self._recovery_wait_secs
        ckpt = saver_mod.get_checkpoint_state(checkpoint_dir)
      else:
        return sess, False

    # Loads the checkpoint and verifies that it makes the model ready.
    saver.restore(sess, ckpt.model_checkpoint_path)
    last_checkpoints = []
    for fname in ckpt.all_model_checkpoint_paths:
      fnames = gfile.Glob(fname)
      if fnames:
        mtime = gfile.Stat(fnames[0]).mtime
        last_checkpoints.append((fname, mtime))
    saver.set_last_checkpoints_with_time(last_checkpoints)
    not_ready = self._model_not_ready(sess)
    if not_ready:
      logging.info("Restoring model from %s did not make model ready: %s",
                   ckpt.model_checkpoint_path, not_ready)
      return sess, False
    else:
      logging.info("Restored model from %s", ckpt.model_checkpoint_path)
      return sess, True
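
This variant differs from the previous one only in resolving the target through `_maybe_launch_in_process_server(master)`, which can stand up a local in-process server. Roughly what that resolution amounts to, sketched with the public `tf.train.Server` API (TF 1.x style):

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

# A single-process gRPC server; its .target is a valid Session target.
server = tf.train.Server.create_local_server()

graph = tf.Graph()
with graph.as_default():
    c = tf.constant(42)

sess = tf.Session(server.target, graph=graph)
print(sess.run(c))  # 42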
Example #4
  def _restore_checkpoint(self,
                          master,
                          saver=None,
                          checkpoint_dir=None,
                          checkpoint_filename_with_path=None,
                          wait_for_checkpoint=False,
                          max_wait_secs=7200,
                          config=None):
    """Creates a `Session`, and tries to restore a checkpoint.

    Args:
      master: `String` representation of the TensorFlow master to use.
      saver: A `Saver` object used to restore a model.
      checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the
        dir will be used to restore.
      checkpoint_filename_with_path: Full file name path to the checkpoint file.
      wait_for_checkpoint: Whether to wait for checkpoint to become available.
      max_wait_secs: Maximum time to wait for checkpoints to become available.
      config: Optional `ConfigProto` proto used to configure the session.

    Returns:
      A pair (sess, is_restored) where 'is_restored' is `True` if
      the session could be restored, `False` otherwise.

    Raises:
      ValueError: If both checkpoint_dir and checkpoint_filename_with_path are
        set.
    """
    self._target = master
    sess = session.Session(self._target, graph=self._graph, config=config)

    if checkpoint_dir and checkpoint_filename_with_path:
      raise ValueError("Can not provide both checkpoint_dir and "
                       "checkpoint_filename_with_path.")
    # If either saver or checkpoint_* is not specified, cannot restore. Just
    # return.
    if not saver or not (checkpoint_dir or checkpoint_filename_with_path):
      return sess, False

    if checkpoint_filename_with_path:
      saver.restore(sess, checkpoint_filename_with_path)
      return sess, True

    # Waits up until max_wait_secs for checkpoint to become available.
    wait_time = 0
    ckpt = saver_mod.get_checkpoint_state(checkpoint_dir)
    while not ckpt or not ckpt.model_checkpoint_path:
      if wait_for_checkpoint and wait_time < max_wait_secs:
        logging.info("Waiting for checkpoint to be available.")
        time.sleep(self._recovery_wait_secs)
        wait_time += self._recovery_wait_secs
        ckpt = saver_mod.get_checkpoint_state(checkpoint_dir)
      else:
        return sess, False

    # Loads the checkpoint.
    saver.restore(sess, ckpt.model_checkpoint_path)
    saver.recover_last_checkpoints(ckpt.all_model_checkpoint_paths)
    return sess, True
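
This variant accepts either a checkpoint directory or an exact checkpoint filename, and raises if both are given. A caller who wants the filename form typically resolves it first; a short sketch with a hypothetical directory:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

ckpt_dir = "/tmp/ckpts"  # hypothetical directory

# Resolve the exact file up front, then pass only one of the two arguments.
ckpt_file = tf.train.latest_checkpoint(ckpt_dir)  # None if nothing saved yet
if ckpt_file is not None:
    print("would pass checkpoint_filename_with_path=%s" % ckpt_file)
else:
    print("would pass checkpoint_dir, or fall back to initialization")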
Example #5
    def _restore_checkpoint(self,
                            master,
                            saver=None,
                            checkpoint_dir=None,
                            wait_for_checkpoint=False,
                            max_wait_secs=7200,
                            config=None):
        """Creates a `Session`, and tries to restore a checkpoint.

    Args:
      master: `String` representation of the TensorFlow master to use.
      saver: A `Saver` object used to restore a model.
      checkpoint_dir: Path to the checkpoint files.
      wait_for_checkpoint: Whether to wait for checkpoint to become available.
      max_wait_secs: Maximum time to wait for checkpoints to become available.
      config: Optional `ConfigProto` proto used to configure the session.

    Returns:
      A pair (sess, is_restored) where 'is_restored' is `True` if
      the session could be restored, `False` otherwise.
    """
        self._target = master
        sess = session.Session(self._target, graph=self._graph, config=config)

        # If either saver or checkpoint_dir is not specified, cannot restore. Just
        # return.
        if not saver or not checkpoint_dir:
            return sess, False

        # Waits up until max_wait_secs for checkpoint to become available.
        wait_time = 0
        ckpt = saver_mod.get_checkpoint_state(checkpoint_dir)
        while not ckpt or not ckpt.model_checkpoint_path:
            if wait_for_checkpoint and wait_time < max_wait_secs:
                logging.info("Waiting for checkpoint to be available.")
                time.sleep(self._recovery_wait_secs)
                wait_time += self._recovery_wait_secs
                ckpt = saver_mod.get_checkpoint_state(checkpoint_dir)
            else:
                return sess, False

        # Loads the checkpoint.
        saver.restore(sess, ckpt.model_checkpoint_path)
        last_checkpoints = []
        for fname in ckpt.all_model_checkpoint_paths:
            fnames = gfile.Glob(fname)
            if fnames:
                mtime = gfile.Stat(fnames[0]).mtime
                last_checkpoints.append((fname, mtime))
        saver.set_last_checkpoints_with_time(last_checkpoints)
        return sess, True
Example #6
  def _restore_checkpoint(self,
                          master,
                          saver=None,
                          checkpoint_dir=None,
                          wait_for_checkpoint=False,
                          max_wait_secs=7200,
                          config=None):
    """Creates a `Session`, and tries to restore a checkpoint.

    Args:
      master: `String` representation of the TensorFlow master to use.
      saver: A `Saver` object used to restore a model.
      checkpoint_dir: Path to the checkpoint files.
      wait_for_checkpoint: Whether to wait for checkpoint to become available.
      max_wait_secs: Maximum time to wait for checkpoints to become available.
      config: Optional `ConfigProto` proto used to configure the session.

    Returns:
      A pair (sess, is_restored) where 'is_restored' is `True` if
      the session could be restored, `False` otherwise.
    """
    self._target = master
    sess = session.Session(self._target, graph=self._graph, config=config)

    # If either saver or checkpoint_dir is not specified, cannot restore. Just
    # return.
    if not saver or not checkpoint_dir:
      return sess, False

    # Waits up until max_wait_secs for checkpoint to become available.
    wait_time = 0
    ckpt = saver_mod.get_checkpoint_state(checkpoint_dir)
    while not ckpt or not ckpt.model_checkpoint_path:
      if wait_for_checkpoint and wait_time < max_wait_secs:
        logging.info("Waiting for checkpoint to be available.")
        time.sleep(self._recovery_wait_secs)
        wait_time += self._recovery_wait_secs
        ckpt = saver_mod.get_checkpoint_state(checkpoint_dir)
      else:
        return sess, False

    # Loads the checkpoint.
    saver.restore(sess, ckpt.model_checkpoint_path)
    last_checkpoints = []
    for fname in ckpt.all_model_checkpoint_paths:
      fnames = gfile.Glob(fname)
      if fnames:
        mtime = gfile.Stat(fnames[0]).mtime
        last_checkpoints.append((fname, mtime))
    saver.set_last_checkpoints_with_time(last_checkpoints)
    return sess, True
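
Examples 5 and 6 rebuild the saver's bookkeeping with `set_last_checkpoints_with_time` so that max_to_keep garbage collection resumes correctly after a restart. The same (path, mtime) pairing can be sketched with the stdlib for a local filesystem (unlike gfile, this sketch does not handle remote filesystems, and the helper name is hypothetical):

import glob
import os

def checkpoints_with_mtime(prefixes):
    # Mirror the loop above: keep only prefixes whose files still exist,
    # paired with the mtime of the first matching file.
    result = []
    for prefix in prefixes:
        matches = glob.glob(prefix + "*")  # V2 prefixes map to .index/.data files
        if matches:
            result.append((prefix, os.path.getmtime(matches[0])))
    return result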
Example #7
def run_eval(args):
    if not args.ckpt_path:
        run_name = args.name or args.model
        log_dir = os.path.join(args.base_dir,
                               'logs-%s-%s' % (run_name, args.description))
        print(
            "Trying to restore saved checkpoints from {} ...".format(log_dir))
        ckpt = get_checkpoint_state(log_dir)
        if ckpt:
            print("Checkpoint found: {}".format(ckpt.model_checkpoint_path))
            ckpt_path = ckpt.model_checkpoint_path
        else:
            raise FileNotFoundError(
                'No checkpoint found in {}'.format(log_dir))
    else:
        ckpt_path = args.ckpt_path
    print(hparams_debug_string())
    synth = Synthesizer()
    synth.load(ckpt_path)
    base_path = get_output_base_path(ckpt_path)
    os.makedirs(base_path, exist_ok=True)
    for i, text in enumerate(sentences):
        text = re.sub(r"[A-Za-z0-9!%\[\],,。…:“”]", "", text)
        text = text.strip()
        path = os.path.join(base_path,
                            '%d-identity-%d-%s.wav' % (i, args.identity, text))
        path_alignment = os.path.join(
            base_path, '%d-identity-%d.png' % (i, args.identity))
        print('Synthesizing: %s' % path)
        synth.synthesize(text, args.identity, path, path_alignment)
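
The substitution above strips ASCII alphanumerics and a list of (mostly full-width) punctuation marks from the text before synthesis. A quick standalone demonstration of the same character class:

import re

pattern = r"[A-Za-z0-9!%\[\],,。…:“”]"  # same class as above, as a raw string
print(re.sub(pattern, "", "Hello你好123。世界…!"))  # -> 你好世界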
Example #8
def latest_checkpoint(checkpoint_dir, latest_filename=None):
  """Finds the filename of latest saved checkpoint file.
  Args:
    checkpoint_dir: Directory where the variables were saved.
    latest_filename: Optional name for the protocol buffer file that
      contains the list of most recent checkpoint filenames.
      See the corresponding argument to `Saver.save()`.
  Returns:
    The full path to the latest checkpoint or `None` if no checkpoint was found.
  """
  # Pick the latest checkpoint based on checkpoint state.
  ckpt = get_checkpoint_state(checkpoint_dir, latest_filename)
  if ckpt and ckpt.model_checkpoint_path:
    # Strip the directory-identifying component from the recorded path and
    # re-anchor it under checkpoint_dir.
    suffix = re.search(r"model\.ckpt.+", ckpt.model_checkpoint_path).group(0)
    path = os.path.join(checkpoint_dir, suffix)
    # Look for either a V2 path or a V1 path, with priority for V2.
    v2_path = _prefix_to_checkpoint_path(path,
                                         saver_pb2.SaverDef.V2)
    v1_path = _prefix_to_checkpoint_path(path,
                                         saver_pb2.SaverDef.V1)
    if file_io.get_matching_files(v2_path) or file_io.get_matching_files(
        v1_path):
      return path
    else:
      logging.error("Couldn't match files for checkpoint %s",
                    path)
  return None
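
The V2/V1 probing above reduces to a question of file layout: a V2 checkpoint prefix such as model.ckpt-1000 is backed by sidecar files, while V1 used the prefix as the data file itself. A sketch of the paths being probed (a hypothetical stand-in, not the real _prefix_to_checkpoint_path helper):

def prefix_to_probe_paths(prefix):
    # V2 writes <prefix>.index plus <prefix>.data-XXXXX-of-YYYYY shards;
    # V1 wrote the checkpoint data at <prefix> directly.
    return {"v2": prefix + ".index", "v1": prefix}

print(prefix_to_probe_paths("/tmp/ckpts/model.ckpt-1000"))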
Example #9
 def loadParameters(self):
     print('Start to load parameters from {0}'.format(self.model_dir))
     gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0)
     config = tf.ConfigProto(gpu_options=gpu_options)
     config.allow_soft_placement = True
     config.log_device_placement = False
     self.sess = tf.Session(config=config)
     with self.sess.as_default():
         s = tf.train.Saver()
         ckpt = saver_mod.get_checkpoint_state(self.model_dir)
         if ckpt is None or not ckpt.model_checkpoint_path:
             raise ValueError(
                 'No checkpoint found in {0}'.format(self.model_dir))
         s.restore(self.sess, ckpt.model_checkpoint_path)
Example #10
 def _assert_ckpt(self, output_dir, expected=True):
   ckpt_state = saver_lib.get_checkpoint_state(output_dir)
   if expected:
     pattern = '%s/model.ckpt-.*' % output_dir
     primary_ckpt_path = ckpt_state.model_checkpoint_path
     self.assertRegexpMatches(primary_ckpt_path, pattern)
     all_ckpt_paths = ckpt_state.all_model_checkpoint_paths
     self.assertTrue(primary_ckpt_path in all_ckpt_paths)
     for ckpt_path in all_ckpt_paths:
       self.assertRegexpMatches(ckpt_path, pattern)
   else:
     self.assertTrue(ckpt_state is None)
Example #11
 def _assert_ckpt(self, output_dir, expected=True):
     ckpt_state = saver_lib.get_checkpoint_state(output_dir)
     if expected:
         pattern = '%s/model.ckpt-.*' % output_dir
         primary_ckpt_path = ckpt_state.model_checkpoint_path
         self.assertRegexpMatches(primary_ckpt_path, pattern)
         all_ckpt_paths = ckpt_state.all_model_checkpoint_paths
         self.assertTrue(primary_ckpt_path in all_ckpt_paths)
         for ckpt_path in all_ckpt_paths:
             self.assertRegexpMatches(ckpt_path, pattern)
     else:
         self.assertTrue(ckpt_state is None)
Example #12
    def after_create_session(self, session, coord):
        super(CheckpointRestorerHook,
              self).after_create_session(session, coord)

        if self._file:
            logging.info("Restoring params from file")
            self._saver.restore(session, self._file)
            logging.info("Finished restoring")
            return
        wait_time = 0
        ckpt = saver.get_checkpoint_state(self._dir)
        while not ckpt or not ckpt.model_checkpoint_path:
            if self._wait_for_checkpoint and wait_time < self._max_wait_secs:
                logging.info("Waiting for checkpoint to be available.")
                time.sleep(self._recovery_wait_secs)
                wait_time += self._recovery_wait_secs
                ckpt = saver.get_checkpoint_state(self._dir)
            else:
                return

        # Loads the checkpoint.
        self._saver.restore(session, ckpt.model_checkpoint_path)
        self._saver.recover_last_checkpoints(ckpt.all_model_checkpoint_paths)
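
Hooks like this restorer are attached to a monitored session, which invokes after_create_session once the underlying session exists. A minimal sketch of the wiring in TF 1.x style; the hook construction itself is elided because its constructor is not shown above:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

graph = tf.Graph()
with graph.as_default():
    step = tf.train.get_or_create_global_step()
    inc = tf.assign_add(step, 1)
    # hooks = [CheckpointRestorerHook(...)]  # construct per its class docs
    with tf.train.MonitoredTrainingSession(hooks=[]) as sess:
        print(sess.run(inc))  # 1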
Example #13
 def prepare_session(self, master, checkpoint_dir=None, saver=None, config=None, **_):
     logger = get_logger()
     logger.info('prepare_session')
     session = Session(master, graph=self._graph, config=config)
     self._session_init_fn(session)
     if saver and checkpoint_dir:
         ckpt = get_checkpoint_state(checkpoint_dir)
         if ckpt and ckpt.model_checkpoint_path:  # pylint: disable=no-member
             logger.info('restoring from %s',
                         ckpt.model_checkpoint_path)  # pylint: disable=no-member
             saver.restore(session, ckpt.model_checkpoint_path)  # pylint: disable=no-member
             saver.recover_last_checkpoints(
                 ckpt.all_model_checkpoint_paths)  # pylint: disable=no-member
         else:
             logger.info('no valid checkpoint in %s', checkpoint_dir)
     return session
Example #14
def create_restore_fn(checkpoint_path, saver, sess):
    if tf.gfile.IsDirectory(checkpoint_path):
        latest_checkpoint = tf.train.latest_checkpoint(checkpoint_path)
        if not latest_checkpoint:
            return

    tf.logging.info("Loading model from checkpoint: %s", checkpoint_path)

    ckpt = saver_mod.get_checkpoint_state(checkpoint_path)
    if not ckpt or not ckpt.model_checkpoint_path:
        return
    saver.restore(sess, ckpt.model_checkpoint_path)
    saver.recover_last_checkpoints(ckpt.all_model_checkpoint_paths)
    tf.logging.info("Successfully loaded checkpoint: %s",
                    os.path.basename(ckpt.model_checkpoint_path))
Example #15
  def _init_env(self):
    tf.logging.info("Import usr dir from %s", self._usr_dir)
    if self._usr_dir is not None:
      usr_dir.import_usr_dir(self._usr_dir)
    tf.logging.info("Start to create hparams for %s of %s", self._problem, self._hparams_set)
    self._hparams = trainer_utils.create_hparams(self._hparams_set, self._data_dir)
    trainer_utils.add_problem_hparams(self._hparams, self._problem)
    tf.logging.info("Build the model_fn of %s of %s", self._model_name, self._hparams)
    #self._model_fn = model_builder.build_model_fn(self._model_name,self._hparams)
    #self._model_fn = model_builder.build_model_fn(self._model_name)
    self._inputs_ph = tf.placeholder(dtype=tf.int32)  # shape not specified, any shape

    batch_inputs = tf.reshape(self._inputs_ph, [self._batch_size, -1, 1, 1])
    #batch_inputs = tf.reshape(self._inputs_ph, [-1, -1, 1, 1])

    targets_ph = tf.placeholder(dtype=tf.int32)
    batch_targets = tf.reshape(targets_ph, [1, -1, 1, 1])
    features = {"inputs": batch_inputs,
                "problem_choice": 0,  # We run on the first problem here.
                "input_space_id": self._hparams.problems[0].input_space_id,
                "target_space_id": self._hparams.problems[0].target_space_id}
    mode = tf.estimator.ModeKeys.PREDICT
    estimator_spec = model_builder.model_fn(self._model_name, features, mode, self._hparams,
                                            problem_names=[self._problem], decode_hparams=self._hparams_dc)
    predictions_dict = estimator_spec.predictions
    self._predictions = predictions_dict["outputs"]
    #self._scores = predictions_dict['scores']  # not returned when using greedy search
    tf.logging.info("Start to init tf session")
    if self._isGpu:
      print('Using GPU in Decoder')
      gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=self._fraction)
      self._sess = tf.Session(config=tf.ConfigProto(
          allow_soft_placement=True, log_device_placement=False,
          gpu_options=gpu_options))
    else:
      print('Using CPU in Decoder')
      gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0)
      config = tf.ConfigProto(gpu_options=gpu_options)
      config.allow_soft_placement = True
      config.log_device_placement = False
      self._sess = tf.Session(config=config)
    with self._sess.as_default():
        ckpt = saver_mod.get_checkpoint_state(self._model_dir)
        saver = tf.train.Saver()
        tf.logging.info("Start to restore the parameters from %s", ckpt.model_checkpoint_path)
        saver.restore(self._sess, ckpt.model_checkpoint_path)
    tf.logging.info("Finish initializing environment")
Example #16
 def get_checkpoint_path(self, model_dir, sec_to_now):
     ckpt = saver.get_checkpoint_state(model_dir, None)
     if ckpt and len(ckpt.all_model_checkpoint_paths) > 0:
         path2tm = dict()
         for i, p in enumerate(ckpt.all_model_checkpoint_paths):
             print("^" * 40)
             st = Stat(p + ".meta")
             tm = time.time() - st.mtime_nsec / 10**9
             path2tm[p] = tm
             print("NO", str(i), ": ", p, " tm:", str(st.mtime_nsec / 10**9))
         sorted_paths = sorted(path2tm.items(), key=lambda x: x[1])
         print("sorted_paths: ", sorted_paths, "lengths: ", len(sorted_paths))
         valid_paths = [x for x in sorted_paths if x[1] >= sec_to_now]
         print("valid_paths: ", valid_paths, "lengths: ", len(valid_paths))
         if len(valid_paths) > 0:
             return valid_paths[0]
         elif len(sorted_paths) > 0:
             return sorted_paths[0]
         return None
     return None
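
get_checkpoint_path selects, among checkpoints ranked youngest-first by age, the youngest one that is at least sec_to_now seconds old, falling back to the youngest overall. The selection rule restated in isolation (pick is a hypothetical name):

def pick(path_to_age, sec_to_now):
    # path_to_age maps checkpoint path -> age in seconds.
    ranked = sorted(path_to_age.items(), key=lambda kv: kv[1])
    old_enough = [kv for kv in ranked if kv[1] >= sec_to_now]
    if old_enough:
        return old_enough[0]
    return ranked[0] if ranked else None

print(pick({"ckpt-3": 10.0, "ckpt-2": 90.0}, 60))  # ('ckpt-2', 90.0)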
Example #17
    def _init_env(self):
        FLAGS.use_tpu = False
        tf.logging.set_verbosity(tf.logging.DEBUG)
        tf.logging.info("Import usr dir from %s", self._usr_dir)
        if self._usr_dir is not None:
            usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
        tf.logging.info("Start to create hparams for %s of %s", self._problem,
                        self._hparams_set)

        # Build the model hparams
        self._hparams = create_hparams()
        # Build the hparams used for decoding
        self._hparams_decode = create_decode_hparams(
            extra_length=self._extra_length,
            batch_size=self._batch_size,
            beam_size=self._beam_size,
            alpha=self._alpha,
            return_beams=self._return_beams,
            write_beam_scores=self._write_beam_scores,
            force_decode_length=self._force_decode_length)

        # self.estimator = trainer_lib.create_estimator(
        #     FLAGS.model,
        #     self._hparams,
        #     t2t_trainer.create_run_config(self._hparams),
        #     decode_hparams=self._hparams_decode,
        #     use_tpu=False)

        tf.logging.info("Finish initializing environment")

        #######

        ### make input placeholder
        self._inputs_ph = tf.placeholder(
            dtype=tf.int32)  # shape not specified,any shape

        x = tf.placeholder(dtype=tf.int32)
        x.set_shape([None, None])  # ? -> (?,?)
        x = tf.expand_dims(x, axis=[2])  # -> (?,?,1)
        x = tf.to_int32(x)
        self._inputs_ph = x

        #batch_inputs = tf.reshape(self._inputs_ph, [self._batch_size, -1, 1, 1])
        batch_inputs = x
        ###

        # batch_inputs = tf.reshape(self._inputs_ph, [-1, -1, 1, 1])

        #targets_ph = tf.placeholder(dtype=tf.int32)
        #batch_targets = tf.reshape(targets_ph, [1, -1, 1, 1])
        self._features = {
            "inputs": batch_inputs,
            "problem_choice": 0,  # We run on the first problem here.
            "input_space_id": self._hparams.problem_hparams.input_space_id,
            "target_space_id": self._hparams.problem_hparams.target_space_id
        }
        ### Add decode length (variable length; unused for classification)
        self.input_extra_length_ph = tf.placeholder(dtype=tf.int32)
        self._features['decode_length'] = self.input_extra_length_ph
        ## target
        self._targets_ph = tf.placeholder(tf.int32,
                                          shape=(None, None, None, None),
                                          name='targets')
        self._features['targets'] = self._targets_ph
        target_pretend = np.zeros((1, 1, 1, 1))

        ## Drop the integer-valued features
        del self._features["problem_choice"]
        del self._features["input_space_id"]
        del self._features["target_space_id"]
        del self._features['decode_length']
        ####
        #mode = tf.estimator.ModeKeys.PREDICT # affect last_only  t2t_model._top_single  ,[1,?,1,512]->[1,1,1,1,64]
        # if self.predict_or_eval=='EVAL':
        #     mode = tf.estimator.ModeKeys.EVAL # affect last_only  t2t_model._top_single  ,[1,?,1,512]->[1,?,1,1,64]
        # # estimator_spec = model_builder.model_fn(self._model_name, features, mode, self._hparams,
        # #                                         problem_names=[self._problem], decode_hparams=self._hparams_dc)
        # if self.predict_or_eval=='PREDICT':
        #     mode = tf.estimator.ModeKeys.PREDICT

        if self.predict_or_eval == 'and':
            mode = tf.estimator.ModeKeys.EVAL

        ###########
        # registry.model
        ############
        translate_model = registry.model(self._model_name)(
            hparams=self._hparams,
            decode_hparams=self._hparams_decode,
            mode=mode)

        self.predict_dict = {}

        ### get logits, EVAL mode
        self.logits, _ = translate_model(self._features)
        ### get infer result, PREDICT mode
        translate_model.set_mode(tf.estimator.ModeKeys.PREDICT)
        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            self.outputs_scores = translate_model.infer(
                features=self._features,
                decode_length=50,
                beam_size=self._beam_size,
                top_beams=self._beam_size,
                alpha=self._alpha)

        ######

        tf.logging.info("Start to init tf session")
        if self._isGpu:
            print('Using GPU in Decoder')
            gpu_options = tf.GPUOptions(
                per_process_gpu_memory_fraction=self._fraction)
            self._sess = tf.Session(
                config=tf.ConfigProto(allow_soft_placement=True,
                                      log_device_placement=False,
                                      gpu_options=gpu_options))
        else:
            print('Using CPU in Decoder')
            gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0)
            config = tf.ConfigProto(gpu_options=gpu_options)
            config.allow_soft_placement = True
            config.log_device_placement = False
            self._sess = tf.Session(config=config)
        with self._sess.as_default():
            ckpt = saver_mod.get_checkpoint_state(self._model_dir)
            saver = tf.train.Saver()
            tf.logging.info("Start to restore the parameters from %s",
                            ckpt.model_checkpoint_path)
            saver.restore(self._sess, ckpt.model_checkpoint_path)
        tf.logging.info("Finish initializing environment")
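
The example above builds the model twice over the same weights: once in EVAL mode for logits and once in PREDICT mode for infer, relying on variable reuse. The reuse mechanism in miniature (TF 1.x style, with a hypothetical layer function):

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

def layer(x):
    w = tf.get_variable("w", shape=[], initializer=tf.ones_initializer())
    return x * w

a = layer(tf.constant(2.0))
# The second construction reuses the existing "w" instead of creating a new one.
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
    b = layer(tf.constant(3.0))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run([a, b]))  # [2.0, 3.0]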
Example #18
    def _init_env(self):
        FLAGS.use_tpu = False
        tf.logging.set_verbosity(tf.logging.DEBUG)
        tf.logging.info("Import usr dir from %s", self._usr_dir)
        if self._usr_dir is not None:
            usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
        tf.logging.info("Start to create hparams for %s of %s", self._problem,
                        self._hparams_set)

        self._hparams = create_hparams()
        self._hparams_decode = create_decode_hparams(
            extra_length=self._extra_length,
            batch_size=self._batch_size,
            beam_size=self._beam_size,
            alpha=self._alpha,
            return_beams=self._return_beams,
            write_beam_scores=self._write_beam_scores,
            force_decode_length=self._force_decode_length)

        # self.estimator_spec = t2t_model.T2TModel.make_estimator_model_fn(
        #     self._model_name, self._hparams, decode_hparams=self._hparams_decode, use_tpu=False)

        self.estimator = trainer_lib.create_estimator(
            FLAGS.model,
            self._hparams,
            t2t_trainer.create_run_config(self._hparams),
            decode_hparams=self._hparams_decode,
            use_tpu=False)

        tf.logging.info("Finish initializing environment")

        #######

        ### make input placeholder
        self._inputs_ph = tf.placeholder(
            dtype=tf.int32)  # shape not specified,any shape

        x = tf.placeholder(dtype=tf.int32)
        x.set_shape([None, None])  # ? -> (?,?)
        x = tf.expand_dims(x, axis=[2])  # -> (?,?,1)
        x = tf.to_int32(x)
        self._inputs_ph = x

        #batch_inputs = tf.reshape(self._inputs_ph, [self._batch_size, -1, 1, 1])
        batch_inputs = x
        ###

        # batch_inputs = tf.reshape(self._inputs_ph, [-1, -1, 1, 1])

        #targets_ph = tf.placeholder(dtype=tf.int32)
        #batch_targets = tf.reshape(targets_ph, [1, -1, 1, 1])

        #self.inputs_ph = tf.placeholder(tf.int32, shape=(None, None, 1, 1), name='inputs')
        #self.targets_ph = tf.placeholder(tf.int32, shape=(None, None, None, None), name='targets')
        self.inputs_ph = tf.placeholder(tf.int32,
                                        shape=(None, None, 1, 1),
                                        name='inputs')
        self.targets_ph = tf.placeholder(tf.int32,
                                         shape=(None, None, 1, 1),
                                         name='targets')

        self._features = {
            "inputs": self.inputs_ph,
            "problem_choice": 0,  # We run on the first problem here.
            "input_space_id": self._hparams.problem_hparams.input_space_id,
            "target_space_id": self._hparams.problem_hparams.target_space_id
        }
        ### Add decode length (variable length)
        self.input_extra_length_ph = tf.placeholder(dtype=tf.int32)
        self._features['decode_length'] = self.input_extra_length_ph
        ## target
        #self._targets_ph= tf.placeholder(tf.int32, shape=(None, None, None, None), name='targets')
        self._features['targets'] = self.targets_ph
        target_pretend = np.zeros((1, 1, 1, 1))

        ## Drop the integer-valued features
        del self._features["problem_choice"]
        del self._features["input_space_id"]
        del self._features["target_space_id"]
        del self._features['decode_length']
        ####
        #mode = tf.estimator.ModeKeys.PREDICT # affect last_only  t2t_model._top_single  ,[1,?,1,512]->[1,1,1,1,64]
        # if self.predict_or_eval=='EVAL':
        #     mode = tf.estimator.ModeKeys.EVAL # affect last_only  t2t_model._top_single  ,[1,?,1,512]->[1,?,1,1,64]
        # # estimator_spec = model_builder.model_fn(self._model_name, features, mode, self._hparams,
        # #                                         problem_names=[self._problem], decode_hparams=self._hparams_dc)
        # if self.predict_or_eval=='PREDICT':
        #     mode = tf.estimator.ModeKeys.PREDICT

        if self.predict_or_eval == 'and':
            mode = tf.estimator.ModeKeys.TRAIN

        ###########
        # registry.model
        ############
        translate_model = registry.model(self._model_name)(
            hparams=self._hparams,
            decode_hparams=self._hparams_decode,
            mode=mode)

        self.predict_dict = {}
        # if self.predict_or_eval == 'EVAL':
        #     self.logits,_=translate_model(self._features)
        #     self.predict_dict['scores']=self.logits
        #
        # if self.predict_or_eval == 'PREDICT':
        #
        #     self.predict_dict=translate_model.infer(features=self._features,
        #                             decode_length=50,
        #                             beam_size=1,
        #                             top_beams=1)
        #     print ''
        if self.predict_or_eval == 'and':
            ### get logit EVAL mode
            #self._features['targets'] = [[self._targets_ph]] # function body()
            self.logits, self.ret2 = translate_model(self._features)

        ##################
        ##  model_fn fetch logits FAIL : key not found
        #############
        # logits,_=translate_model.model_fn(self._features)

        # self._beam_result = model_i._fast_decode(self._features, decode_length=5, beam_size=10, top_beams=10,
        #                                          alpha=0.6) #fail
        # self._beam_result = model_i._beam_decode(self._features,
        #                                          decode_length=5,
        #                                          beam_size=self._beam_size,
        #                                          top_beams=self._beam_size,
        #                                          alpha=0.6)

        ##########

        # logits,_=model_i.model_fn(self._features)
        # assert len(logits.shape) == 5
        # logits = tf.squeeze(logits, [2, 3])
        # # Compute the log probabilities
        # from tensor2tensor.layers import common_layers
        # self.log_probs = common_layers.log_prob_from_logits(logits)

        ######

        #self._predictions = self._predictions_dict["outputs"]
        # self._scores=predictions_dict['scores'] not return when greedy search
        tf.logging.info("Start to init tf session")
        if self._isGpu:
            print('Using GPU in Decoder')
            gpu_options = tf.GPUOptions(
                per_process_gpu_memory_fraction=self._fraction)
            self._sess = tf.Session(
                config=tf.ConfigProto(allow_soft_placement=True,
                                      log_device_placement=False,
                                      gpu_options=gpu_options))
        else:
            print('Using CPU in Decoder')
            gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0)
            config = tf.ConfigProto(gpu_options=gpu_options)
            config.allow_soft_placement = True
            config.log_device_placement = False
            self._sess = tf.Session(config=config)
        with self._sess.as_default():
            ckpt = saver_mod.get_checkpoint_state(self._model_dir)
            saver = tf.train.Saver()
            #tf.logging.info("Start to restore the parameters from %s", ckpt.model_checkpoint_path)
            #saver.restore(self._sess, ckpt.model_checkpoint_path)
            ########## Re-initialize the parameters instead of restoring them
            self._sess.run(tf.global_variables_initializer())
        tf.logging.info("Finish initializing environment")
Example #19
    def _init_env(self):
        FLAGS.use_tpu = False
        tf.logging.set_verbosity(tf.logging.DEBUG)
        tf.logging.info("Import usr dir from %s", self._usr_dir)
        if self._usr_dir is not None:
            usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
        tf.logging.info("Start to create hparams for %s of %s", self._problem,
                        self._hparams_set)

        self._hparams = create_hparams()
        self._hparams_decode = create_decode_hparams(
            extra_length=self._extra_length,
            batch_size=self._batch_size,
            beam_size=self._beam_size,
            alpha=self._alpha,
            return_beams=self._return_beams,
            write_beam_scores=self._write_beam_scores,
            force_decode_length=self._force_decode_length)

        self.estimator = trainer_lib.create_estimator(
            FLAGS.model,
            self._hparams,
            t2t_trainer.create_run_config(self._hparams),
            decode_hparams=self._hparams_decode,
            use_tpu=False)

        tf.logging.info("Finish initializing environment")

        #######

        ### make input placeholder
        #self._inputs_ph = tf.placeholder(dtype=tf.int32)  # shape not specified,any shape

        # x=tf.placeholder(dtype=tf.int32)
        # x.set_shape([None, None]) # ? -> (?,?)
        # x = tf.expand_dims(x, axis=[2])# -> (?,?,1)
        # x = tf.to_int32(x)
        # self._inputs_ph=x

        #batch_inputs = tf.reshape(self._inputs_ph, [self._batch_size, -1, 1, 1])
        #batch_inputs=x
        ###

        # batch_inputs = tf.reshape(self._inputs_ph, [-1, -1, 1, 1])

        #targets_ph = tf.placeholder(dtype=tf.int32)
        #batch_targets = tf.reshape(targets_ph, [1, -1, 1, 1])

        self.inputs_ph = tf.placeholder(tf.int32,
                                        shape=(None, None, 1, 1),
                                        name='inputs')
        self.targets_ph = tf.placeholder(tf.int32,
                                         shape=(None, None, None, None),
                                         name='targets')
        self.input_extra_length_ph = tf.placeholder(dtype=tf.int32, shape=[])

        self._features = {
            "inputs": self.inputs_ph,
            "problem_choice": 0,  # We run on the first problem here.
            "input_space_id": self._hparams.problem_hparams.input_space_id,
            "target_space_id": self._hparams.problem_hparams.target_space_id
        }
        ### Add decode length (variable length)
        self._features['decode_length'] = self.input_extra_length_ph
        ## target
        self._features['targets'] = self.targets_ph

        ## Drop the integer-valued features
        del self._features["problem_choice"]
        del self._features["input_space_id"]
        del self._features["target_space_id"]
        #del self._features['decode_length']
        ####

        mode = tf.estimator.ModeKeys.EVAL

        translate_model = registry.model(self._model_name)(
            hparams=self._hparams,
            decode_hparams=self._hparams_decode,
            mode=mode)

        self.predict_dict = {}

        ### get logits, attention mats
        self.logits, _ = translate_model(self._features)  # [?, ?, ?, 1, vocab_size]
        #translate_model(features)
        from visualization import get_att_mats
        self.att_mats = get_att_mats(translate_model,
                                     self._model_name)  # enc, dec, encdec
        ### get infer
        translate_model.set_mode(tf.estimator.ModeKeys.PREDICT)
        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            self.outputs_scores = translate_model.infer(
                features=self._features,
                decode_length=self._extra_length,
                beam_size=self._beam_size,
                top_beams=self._beam_size,
                alpha=self._alpha)  #outputs 4,4,63

        ######
        tf.logging.info("Start to init tf session")
        if self._isGpu:
            print('Using GPU in Decoder')
            gpu_options = tf.GPUOptions(
                per_process_gpu_memory_fraction=self._fraction)
            self._sess = tf.Session(
                config=tf.ConfigProto(allow_soft_placement=True,
                                      log_device_placement=False,
                                      gpu_options=gpu_options))
        else:
            print('Using CPU in Decoder')
            gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0)
            config = tf.ConfigProto(gpu_options=gpu_options)
            config.allow_soft_placement = True
            config.log_device_placement = False
            self._sess = tf.Session(config=config)
        with self._sess.as_default():
            ckpt = saver_mod.get_checkpoint_state(self._model_dir)
            saver = tf.train.Saver()
            tf.logging.info("Start to restore the parameters from %s",
                            ckpt.model_checkpoint_path)
            saver.restore(self._sess, ckpt.model_checkpoint_path)
        tf.logging.info("Finish initializing environment")
Example #20
    def _init_env(self):
        FLAGS.use_tpu = False
        tf.logging.set_verbosity(tf.logging.DEBUG)
        tf.logging.info("Import usr dir from %s", self._usr_dir)
        if self._usr_dir is not None:
            usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
        tf.logging.info("Start to create hparams for %s of %s", self._problem,
                        self._hparams_set)

        self._hparams = create_hparams()
        self._hparams_decode = create_decode_hparams(
            extra_length=self._extra_length,
            batch_size=self._batch_size,
            beam_size=self._beam_size,
            alpha=self._alpha,
            return_beams=self._return_beams,
            write_beam_scores=self._write_beam_scores)

        # self.estimator_spec = t2t_model.T2TModel.make_estimator_model_fn(
        #     self._model_name, self._hparams, decode_hparams=self._hparams_decode, use_tpu=False)

        self.estimator = trainer_lib.create_estimator(
            FLAGS.model,
            self._hparams,
            t2t_trainer.create_run_config(self._hparams),
            decode_hparams=self._hparams_decode,
            use_tpu=False)

        tf.logging.info("Finish initializing environment")

        #######

        ### make input placeholder
        self._inputs_ph = tf.placeholder(
            dtype=tf.int32)  # shape not specified,any shape

        x = tf.placeholder(dtype=tf.int32)
        x.set_shape([None, None])  # ? -> (?,?)
        x = tf.expand_dims(x, axis=[2])  # -> (?,?,1)
        x = tf.to_int32(x)
        self._inputs_ph = x

        #batch_inputs = tf.reshape(self._inputs_ph, [self._batch_size, -1, 1, 1])
        batch_inputs = x

        # batch_inputs = tf.reshape(self._inputs_ph, [-1, -1, 1, 1])

        #targets_ph = tf.placeholder(dtype=tf.int32)
        #batch_targets = tf.reshape(targets_ph, [1, -1, 1, 1])
        self._features = {
            "inputs": batch_inputs,
            "problem_choice": 0,  # We run on the first problem here.
            "input_space_id": self._hparams.problem_hparams.input_space_id,
            "target_space_id": self._hparams.problem_hparams.target_space_id
        }
        ### Add decode length (variable length)
        self.input_extra_length_ph = tf.placeholder(dtype=tf.int32)
        #self._features['decode_length'] = [self.input_extra_length_ph]
        #### Sample c(s)
        ###
        self.cache_ph = tf.placeholder(dtype=tf.int32)
        #self._features['cache_raw']=tf.reshape(self.cache_ph,[1,2,1])

        ## Drop the integer-valued features
        del self._features["problem_choice"]
        del self._features["input_space_id"]
        del self._features["target_space_id"]
        ####
        mode = tf.estimator.ModeKeys.PREDICT
        # estimator_spec = model_builder.model_fn(self._model_name, features, mode, self._hparams,
        #                                         problem_names=[self._problem], decode_hparams=self._hparams_dc)

        ######
        from tensor2tensor.models import transformer_vae
        model_i = transformer_vae.TransformerAE(
            hparams=self._hparams,
            mode=mode,
            decode_hparams=self._hparams_decode)
        # Transformer_(hparams=self._hparams,
        #                             mode=mode, decode_hparams=self._hparams_decode)
        #                             #problem_hparams=p_hparams,

        # self._beam_result = model_i._fast_decode(self._features, decode_length=5, beam_size=10, top_beams=10,
        #                                          alpha=0.6) #fail
        # self._beam_result = model_i._beam_decode(self._features,
        #                                          decode_length=5,
        #                                          beam_size=self._beam_size,
        #                                          top_beams=self._beam_size,
        #                                          alpha=0.6)

        self.result_dict = model_i.infer(self._features)

        print('')

        #### add target; some keys get lost, so model_fn cannot be invoked standalone
        # from tensor2tensor.layers import common_layers
        # features=self._features
        # batch_size = common_layers.shape_list(features["inputs"])[0]
        # length = common_layers.shape_list(features["inputs"])[1]
        # target_length = tf.to_int32(2.0 * tf.to_float(length))
        # initial_output = tf.zeros((batch_size, target_length, 1, 1),
        #                           dtype=tf.int64)
        # features["targets"] = initial_output
        # ### input
        # if "inputs" in features and len(features["inputs"].shape) < 4:
        #     inputs_old = features["inputs"]
        #     features["inputs"] = tf.expand_dims(features["inputs"], 2)
        # #### model_fn
        # self.result_dict=model_i.model_fn(features)

        print('')
        """
        ######
        predictions_dict = self.estimator._call_model_fn(self._features,None,mode,t2t_trainer.create_run_config(self._hparams))
        self._predictions_dict=predictions_dict.predictions
        """
        #self._predictions = self._predictions_dict["outputs"]
        # self._scores=predictions_dict['scores'] not return when greedy search
        tf.logging.info("Start to init tf session")
        if self._isGpu:
            print('Using GPU in Decoder')
            gpu_options = tf.GPUOptions(
                per_process_gpu_memory_fraction=self._fraction)
            self._sess = tf.Session(
                config=tf.ConfigProto(allow_soft_placement=True,
                                      log_device_placement=False,
                                      gpu_options=gpu_options))
        else:
            print('Using CPU in Decoder')
            gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0)
            config = tf.ConfigProto(gpu_options=gpu_options)
            config.allow_soft_placement = True
            config.log_device_placement = False
            self._sess = tf.Session(config=config)
        with self._sess.as_default():
            ckpt = saver_mod.get_checkpoint_state(self._model_dir)
            saver = tf.train.Saver()
            tf.logging.info("Start to restore the parameters from %s",
                            ckpt.model_checkpoint_path)
            saver.restore(self._sess, ckpt.model_checkpoint_path)
        tf.logging.info("Finish initializing environment")
Example #21
    def _init_env(self):
        FLAGS.use_tpu = False
        #tf.logging.set_verbosity(tf.logging.DEBUG)
        tf.logging.info("Import usr dir from %s", self._usr_dir)
        if self._usr_dir is not None:
            #usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
            usr_dir.import_usr_dir(self._usr_dir)
        tf.logging.info("Start to create hparams for %s of %s", self._problem, self._hparams_set)

        self._hparams = create_hparams()

        self._hparams_decode = create_decode_hparams(extra_length=self._extra_length,
                                                     batch_size=self._batch_size,
                                                     beam_size=self._beam_size,
                                                     alpha=self._alpha,
                                                     return_beams=self._return_beams,
                                                     write_beam_scores=self._write_beam_scores,
                                                     force_decode_length=self._force_decode_length)



        self.estimator = trainer_lib.create_estimator(
            FLAGS.model,
            self._hparams,
            t2t_trainer.create_run_config(self._hparams),
            decode_hparams=self._hparams_decode,
            use_tpu=False)

        tf.logging.info("Finish initializing environment")

        ####### problem type: classification, sequence, or language model
        #self.problem_type = self._hparams.problem_hparams[0].target_modality[0] #class? symbol
        self.problem_type = self._hparams.problem_hparams.target_modality[0]
        #self._whether_has_inputs = self._hparams.problem[0].has_inputs
        self._whether_has_inputs = self._hparams.problem.has_inputs
        self._beam_size = 1 if self._customer_problem_type == 'classification' else self._beam_size



        ### make input placeholder
        #self._inputs_ph = tf.placeholder(dtype=tf.int32)  # shape not specified,any shape

        # x=tf.placeholder(dtype=tf.int32)
        # x.set_shape([None, None]) # ? -> (?,?)
        # x = tf.expand_dims(x, axis=[2])# -> (?,?,1)
        # x = tf.to_int32(x)
        #self._inputs_ph=x

        #batch_inputs = tf.reshape(self._inputs_ph, [self._batch_size, -1, 1, 1])
        #batch_inputs=x

        # batch_inputs = tf.reshape(self._inputs_ph, [-1, -1, 1, 1])

        batch_inputs,self._targets_ph,self.input_extra_length_ph=get_ph(x_dim_3=True)

        #targets_ph = tf.placeholder(dtype=tf.int32)
        #batch_targets = tf.reshape(targets_ph, [1, -1, 1, 1])
        self._features = {"inputs": batch_inputs,
                    "problem_choice": 0,  # We run on the first problem here.
                    "input_space_id": self._hparams.problem_hparams.input_space_id,
                    "target_space_id": self._hparams.problem_hparams.target_space_id}
        ### Add decode length (variable length)
        #self.input_extra_length_ph = tf.placeholder(dtype=tf.int32,shape=[])
        self._features['decode_length'] = self.input_extra_length_ph  # total decode length = input length + extra length (extra is 0 for chunk problems)
        # real_decode_length = len(input) + extra_length
        ##
        #self._features['decode_length_decide_end'] = True

        #### If using the transformer_relative hparams
        if self._hparams_set=="transformer_relative":
            del self._features['problem_choice']
            del self._features['input_space_id']
            del self._features['target_space_id']

        if self._customer_problem_type=='languageModel_pp':
            del self._features['problem_choice']
            del self._features['input_space_id']
            del self._features['target_space_id']
        if self._model_name in ['slice_net','transformer_encoder']:
            del self._features['problem_choice']
            del self._features['input_space_id']
            del self._features['target_space_id']
        if self._model_name=='transformer' and self._customer_problem_type=='classification':
            del self._features['problem_choice']
            del self._features['input_space_id']
            del self._features['target_space_id']




        ###### target if transformer_scorer
        if self._customer_problem_type=='classification':
            self._targets_ph = tf.placeholder(tf.int32, shape=(None, None, None, None), name='targets')
            self._features['targets'] = self._targets_ph  # batch targets

        if self._customer_problem_type=='languageModel_pp':
            self._targets_ph = tf.placeholder(tf.int32, shape=(None, None, None, None), name='targets')
            self._features['targets'] = self._targets_ph


        #### mode
        mode = tf.estimator.ModeKeys.PREDICT
        if self._customer_problem_type == 'languageModel_pp':
            mode = tf.estimator.ModeKeys.EVAL
        elif self._customer_problem_type=='classification' and 'score' not in self._model_name:
            mode = tf.estimator.ModeKeys.EVAL
        # estimator_spec = model_builder.model_fn(self._model_name, features, mode, self._hparams,
        #                                         problem_names=[self._problem], decode_hparams=self._hparams_dc)
        predictions_dict = self.estimator._call_model_fn(self._features, None, mode,
                                                         t2t_trainer.create_run_config(self._hparams))
        self._predictions_dict = predictions_dict.predictions
        # score -> score_yr
        if self._customer_problem_type == 'classification' and 'score' in self._model_name:
            self._score = predictions_dict.predictions.get('scores')
            if self._score is not None:  # [batch, beam] or [batch,]
                self._predictions_dict['scores_class'] = tf.exp(common_layers.log_prob_from_logits(self._score))
        elif self._customer_problem_type == 'classification' and 'score' not in self._model_name:
            self._score = predictions_dict.predictions.get('predictions')
            if self._score is not None:  # [batch, beam] or [batch,]
                self._predictions_dict['scores_class'] = tf.exp(common_layers.log_prob_from_logits(self._score))
        #self._predictions = self._predictions_dict["outputs"]
        # self._scores=predictions_dict['scores'] not return when greedy search
        tf.logging.info("Start to init tf session")
        if self._isGpu:
            print('Using GPU in Decoder')
            gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=self._fraction)
            self._sess = tf.Session(
                config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False, gpu_options=gpu_options))
        else:
            print('Using CPU in Decoder')
            gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0)
            config = tf.ConfigProto(gpu_options=gpu_options)
            config.allow_soft_placement = True
            config.log_device_placement = False
            self._sess = tf.Session(config=config)
        with self._sess.as_default():
            ckpt = saver_mod.get_checkpoint_state(self._model_dir)
            saver = tf.train.Saver(allow_empty=True)
            tf.logging.info("Start to restore the parameters from %s", ckpt.model_checkpoint_path)
            saver.restore(self._sess, ckpt.model_checkpoint_path)
        tf.logging.info("Finish initializing environment")
Example #22
    def _init_env(self):
        FLAGS.use_tpu = False
        tf.logging.set_verbosity(tf.logging.DEBUG)
        tf.logging.info("Import usr dir from %s", self._usr_dir)
        if self._usr_dir is not None:
            usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
        tf.logging.info("Start to create hparams for %s of %s", self._problem,
                        self._hparams_set)

        self._hparams = create_hparams()
        self._hparams_decode = create_decode_hparams(
            extra_length=self._extra_length,
            batch_size=self._batch_size,
            beam_size=self._beam_size,
            alpha=self._alpha,
            return_beams=self._return_beams,
            write_beam_scores=self._write_beam_scores)

        self.estimator = trainer_lib.create_estimator(
            FLAGS.model,
            self._hparams,
            t2t_trainer.create_run_config(self._hparams),
            decode_hparams=self._hparams_decode,
            use_tpu=False)

        tf.logging.info("Finish initializing environment")
        ####### problem type: classification, sequence, or language model

        self.problem_type = self._hparams.problems[0].target_modality[
            0]  # class? symbol
        self._whether_has_inputs = self._hparams.problem_instances[
            0].has_inputs
        self._beam_size = 1 if self.problem_type == 'class_label' else self._beam_size
        #######

        ### make input placeholder
        self._inputs_ph = tf.placeholder(
            dtype=tf.int32)  # shape not specified,any shape

        x = tf.placeholder(dtype=tf.int32)
        x.set_shape([None, None])  # ? -> (?,?)
        x = tf.expand_dims(x, axis=[2])  # -> (?,?,1)
        # EVAL MODEL
        x = tf.expand_dims(x, axis=[3])  # -> (?,?,1,1)
        x = tf.to_int32(x)
        self._inputs_ph = x

        #batch_inputs = tf.reshape(self._inputs_ph, [self._batch_size, -1, 1, 1])
        batch_inputs = x  #[?,?,1,1]

        # batch_inputs = tf.reshape(self._inputs_ph, [-1, -1, 1, 1])

        #targets_ph = tf.placeholder(dtype=tf.int32)
        #batch_targets = tf.reshape(targets_ph, [1, -1, 1, 1])
        self._features = {
            "inputs": batch_inputs,
            "problem_choice": 0,  # We run on the first problem here.
            "input_space_id": self._hparams.problems[0].input_space_id,
            "target_space_id": self._hparams.problems[0].target_space_id
        }
        ### Add decode length (variable length)
        #self.input_extra_length_ph = tf.placeholder(dtype=tf.int32)
        #self._features['decode_length'] = self.input_extra_length_ph

        #### EVAL MODE target
        self._targets_ph = tf.placeholder(tf.int32,
                                          shape=(1, None, 1, 1),
                                          name='targets')
        self._features['targets'] = self._targets_ph  #batch targets
        del self._features['problem_choice']
        del self._features['input_space_id']
        del self._features['target_space_id']

        ####
        mode = tf.estimator.ModeKeys.EVAL

        predictions_dict = self.estimator._call_model_fn(
            self._features, None, mode,
            t2t_trainer.create_run_config(self._hparams))
        self._predictions_dict = predictions_dict.predictions
        #self._predictions = self._predictions_dict["outputs"]
        # self._scores=predictions_dict['scores'] not return when greedy search
        tf.logging.info("Start to init tf session")
        if self._isGpu:
            print('Using GPU in Decoder')
            gpu_options = tf.GPUOptions(
                per_process_gpu_memory_fraction=self._fraction)
            self._sess = tf.Session(
                config=tf.ConfigProto(allow_soft_placement=True,
                                      log_device_placement=False,
                                      gpu_options=gpu_options))
        else:
            print('Using CPU in Decoder')
            gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0)
            config = tf.ConfigProto(gpu_options=gpu_options)
            config.allow_soft_placement = True
            config.log_device_placement = False
            self._sess = tf.Session(config=config)
        with self._sess.as_default():
            ckpt = saver_mod.get_checkpoint_state(self._model_dir)
            saver = tf.train.Saver()
            tf.logging.info("Start to restore the parameters from %s",
                            ckpt.model_checkpoint_path)
            saver.restore(self._sess, ckpt.model_checkpoint_path)
        tf.logging.info("Finish initializing environment")
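
Every _init_env variant above ends with the same restore tail: look up the checkpoint state, build a Saver, and restore into the live session. Factored out as a sketch (restore_into and model_dir are hypothetical; allow_empty mirrors the Saver in the classification variant):

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

def restore_into(sess, model_dir):
    # Assumes sess.graph already contains the model's variables.
    ckpt = tf.train.get_checkpoint_state(model_dir)
    if ckpt is None or not ckpt.model_checkpoint_path:
        raise ValueError("No checkpoint found in %s" % model_dir)
    saver = tf.train.Saver(allow_empty=True)
    saver.restore(sess, ckpt.model_checkpoint_path)
    return ckpt.model_checkpoint_path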