Example #1
    def setup(self, flags):
        # Model output directory
        out_dir = flags.out_dir
        if out_dir and not tf.gfile.Exists(out_dir):
            tf.gfile.MakeDirs(out_dir)

        # Load hparams.
        default_hparams = create_hparams(flags)
        loaded_hparams = False
        if flags.ckpt:  # Try to load hparams from the same directory as ckpt
            ckpt_dir = os.path.dirname(flags.ckpt)
            ckpt_hparams_file = os.path.join(ckpt_dir, "hparams")
            if tf.gfile.Exists(ckpt_hparams_file) or flags.hparams_path:
                # Note: for some reason this will create an empty "best_bleu" directory and copy vocab files
                hparams = create_or_load_hparams(ckpt_dir,
                                                 default_hparams,
                                                 flags.hparams_path,
                                                 save_hparams=False)
                loaded_hparams = True

        assert loaded_hparams

        # GPU device
        config_proto = utils.get_config_proto(
            allow_soft_placement=True,
            num_intra_threads=hparams.num_intra_threads,
            num_inter_threads=hparams.num_inter_threads)
        utils.print_out("# Devices visible to TensorFlow: %s" %
                        repr(tf.Session(config=config_proto).list_devices()))

        # Inference indices (inference_indices is broken; it must be None or we crash)
        hparams.inference_indices = None

        # Create the graph
        model_creator = get_model_creator(hparams)
        infer_model = model_helper.create_infer_model(model_creator,
                                                      hparams,
                                                      scope=None)
        sess, loaded_infer_model = start_sess_and_load_model(
            infer_model, flags.ckpt, hparams)

        # Parameters needed by TF GNMT
        self.hparams = hparams

        self.infer_model = infer_model
        self.sess = sess
        self.loaded_infer_model = loaded_infer_model
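
For context, here is a minimal sketch of what typically follows a setup() like this in the NMT infer loop: initialize the batched iterator from raw source strings, then decode a batch. The run_inference name and its infer_data argument are illustrative additions, not part of the original class.

    def run_inference(self, infer_data):
        # infer_data: a list of tokenized source sentences (strings)
        with self.infer_model.graph.as_default():
            self.sess.run(
                self.infer_model.iterator.initializer,
                feed_dict={
                    self.infer_model.src_placeholder: infer_data,
                    self.infer_model.batch_size_placeholder:
                        self.hparams.infer_batch_size,
                })
            # In the standard NMT code, Model.decode returns
            # (sample_words, infer_summary); sample_words holds the decoded words.
            sample_words, _ = self.loaded_infer_model.decode(self.sess)
        return sample_words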
Example #2
    def _createTestInferenceCheckpoint(self, hparams, name):
        # Prepare vocab files and the model output directory
        hparams.vocab_prefix = 'nmt/testdata/test_infer_vocab'
        hparams.src_vocab_file = hparams.vocab_prefix + '.' + hparams.src
        hparams.tgt_vocab_file = hparams.vocab_prefix + '.' + hparams.tgt
        out_dir = os.path.join(tf.test.get_temp_dir(), name)
        os.makedirs(out_dir)
        hparams.out_dir = out_dir

        # Create a checkpoint
        model_creator = inference.get_model_creator(hparams)
        infer_model = model_helper.create_infer_model(model_creator, hparams)
        with self.test_session(graph=infer_model.graph) as sess:
            loaded_model, global_step = model_helper.create_or_load_model(
                infer_model.model, out_dir, sess, 'infer_name')
            ckpt_path = loaded_model.saver.save(sess,
                                                os.path.join(
                                                    out_dir,
                                                    'translation.ckpt'),
                                                global_step=global_step)
        return ckpt_path
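
A sketch of how a test might consume this helper; create_test_hparams is assumed to come from the NMT repo's common_test_utils module (its exact signature varies across versions).

    def testCreatesLoadableCheckpoint(self):
        # Assumed helper producing small hparams suitable for a unit test
        hparams = common_test_utils.create_test_hparams()
        ckpt_path = self._createTestInferenceCheckpoint(hparams, 'ckpt_test')
        # saver.save returns the checkpoint prefix (translation.ckpt-<step>);
        # a V2 checkpoint writes <prefix>.index alongside the data shards.
        self.assertTrue(tf.gfile.Exists(ckpt_path + '.index'))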
Example #3
print(ckpt_path2)

hparams1 = create_or_load_hparams(out_dir1,
                                  default_hparams1,
                                  None,
                                  save_hparams=False)

hparams2 = create_or_load_hparams(out_dir2,
                                  default_hparams2,
                                  None,
                                  save_hparams=False)

hparams1.inference_indices = None
hparams2.inference_indices = None

model_creator1 = get_model_creator(hparams1)
model_creator2 = get_model_creator(hparams2)

infer_model1 = create_infer_model(model_creator1, hparams1, None)
infer_model2 = create_infer_model(model_creator2, hparams2, None)

sess1, loaded_infer_model1 = start_sess_and_load_model(infer_model1,
                                                       ckpt_path1)
sess2, loaded_infer_model2 = start_sess_and_load_model(infer_model2,
                                                       ckpt_path2)

jieba.load_userdict("字典.txt")  # load the user dictionary ("字典" = "dictionary") for jieba segmentation


def is_contains_chinese(strs):
    # True if any character falls in the CJK Unified Ideographs range
    for _char in strs:
        if '\u4e00' <= _char <= '\u9fa5':
            return True
    return False
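
A small usage sketch of the helper above; it should run once jieba and the user dictionary are in place, and the sample sentence is purely illustrative.

tokens = [t for t in jieba.cut("我爱 machine translation") if t.strip()]
chinese_tokens = [t for t in tokens if is_contains_chinese(t)]
print(tokens)          # e.g. ['我爱', 'machine', 'translation']
print(chinese_tokens)  # e.g. ['我爱']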
Example #4
    def setup(self, flags):
        # Model output directory
        out_dir = flags.out_dir
        if out_dir and not tf.gfile.Exists(out_dir):
            tf.gfile.MakeDirs(out_dir)

        # Load hparams.
        default_hparams = create_hparams(flags)
        loaded_hparams = False
        if flags.ckpt:  # Try to load hparams from the same directory as ckpt
            ckpt_dir = os.path.dirname(flags.ckpt)
            ckpt_hparams_file = os.path.join(ckpt_dir, "hparams")
            if tf.gfile.Exists(ckpt_hparams_file) or flags.hparams_path:
                # Note: for some reason this will create an empty "best_bleu" directory and copy vocab files
                hparams = create_or_load_hparams(ckpt_dir, default_hparams,
                                                 flags.hparams_path,
                                                 save_hparams=False)
                loaded_hparams = True

        assert loaded_hparams

        # GPU device
        config_proto = utils.get_config_proto(
            allow_soft_placement=True,
            num_intra_threads=hparams.num_intra_threads,
            num_inter_threads=hparams.num_inter_threads)
        utils.print_out(
            "# Devices visible to TensorFlow: %s"
            % repr(tf.Session(config=config_proto).list_devices()))

        # Inference indices (inference_indices is broken; it must be None or we crash)
        hparams.inference_indices = None

        # Create the graph
        model_creator = get_model_creator(hparams)
        infer_model = model_helper.create_infer_model(model_creator, hparams, scope=None)
        sess, loaded_infer_model = start_sess_and_load_model(
            infer_model, flags.ckpt, hparams)

        # FIXME (bryce): Set to False to disable inference from frozen graph and run fast again
        if True:
          frozen_graph = None
          with infer_model.graph.as_default():
            output_node_names = ['hash_table_Lookup_1/LookupTableFindV2']
            other_node_names  = ['MakeIterator', 'IteratorToStringHandle', 'init_all_tables', 'NoOp', 'dynamic_seq2seq/decoder/NoOp']
            frozen_graph = tf.graph_util.convert_variables_to_constants(
                sess, tf.get_default_graph().as_graph_def(),
                output_node_names=output_node_names + other_node_names)

            # FIXME (bryce): Comment out this block to disable the TensorRT conversion
            from tensorflow.python.compiler.tensorrt import trt_convert as trt
            converter = trt.TrtGraphConverter(
                input_graph_def=frozen_graph,
                nodes_blacklist=output_node_names,
                is_dynamic_op=True,
                max_batch_size=hparams.infer_batch_size,
                max_beam_size=hparams.beam_width,
                max_src_seq_len=hparams.src_max_len)
            frozen_graph = converter.convert()
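            # convert() returns a GraphDef in which TF-TRT has fused supported
            # subgraphs into TRTEngineOp nodes; the blacklisted output/lookup
            # nodes are left as regular TF ops so they can still be fetched by
            # name after the import below.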

          with tf.Graph().as_default():
            tf.graph_util.import_graph_def(frozen_graph, name="")
            sess = tf.Session(
                graph=tf.get_default_graph(),
                config=utils.get_config_proto(
                    num_intra_threads=hparams.num_intra_threads,
                    num_inter_threads=hparams.num_inter_threads))
            graph = tf.get_default_graph()
            iterator = iterator_utils.BatchedInput(
                initializer=graph.get_operation_by_name(
                    infer_model.iterator.initializer.name),
                source=graph.get_tensor_by_name(infer_model.iterator.source.name),
                target_input=None,
                target_output=None,
                source_sequence_length=graph.get_tensor_by_name(
                    infer_model.iterator.source_sequence_length.name),
                target_sequence_length=None)
            infer_model = model_helper.InferModel(
                graph=graph,
                model=infer_model.model,
                src_placeholder=graph.get_tensor_by_name(
                    infer_model.src_placeholder.name),
                batch_size_placeholder=graph.get_tensor_by_name(
                    infer_model.batch_size_placeholder.name),
                iterator=iterator)

        # Parameters needed by TF GNMT
        self.hparams = hparams

        self.infer_model = infer_model
        self.sess = sess
        self.loaded_infer_model = loaded_infer_model
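
When the frozen/TensorRT path above is active, decode fetches must come from the imported graph rather than the original model object. A hedged sketch of pushing one batch through after setup() returns, where runner is a hypothetical instance of the class defining setup() and the sample input is illustrative:

runner.sess.run(
    runner.infer_model.iterator.initializer,
    feed_dict={runner.infer_model.src_placeholder: ["hello world"],
               runner.infer_model.batch_size_placeholder: 1})
translations = runner.sess.run(
    runner.infer_model.graph.get_tensor_by_name(
        'hash_table_Lookup_1/LookupTableFindV2:0'))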