Example #1
0
    def __init__(self,
                 data_dir='',
                 model_dir='',
                 output_dir='',
                 batch_size=128,
                 max_seq_len=32,
                 num_train_epochs=10,
                 learning_rate=0.00005,
                 gpu_memory_fraction=0.9):
        """Store training hyper-parameters and prepare the tokenizer/processor.

        :param data_dir: directory containing the training data
        :param model_dir: directory holding the pretrained BERT files
                          (vocab.txt, bert_config.json, bert_model.ckpt)
        :param output_dir: directory for training output
        :param batch_size: examples per training batch
        :param max_seq_len: maximum sequence length per example
        :param num_train_epochs: number of passes over the training data
        :param learning_rate: optimizer learning rate
        :param gpu_memory_fraction: fraction of GPU memory to allocate
        """
        # Plain hyper-parameter bookkeeping.
        self.data_dir = data_dir
        self.output_dir = output_dir
        self.batch_size = batch_size
        self.max_seq_len = max_seq_len
        self.num_train_epochs = num_train_epochs
        self.learning_rate = learning_rate
        self.gpu_memory_fraction = gpu_memory_fraction

        # Pretrained BERT artifacts are expected inside model_dir.
        self.model_dir = model_dir
        self.vocab_file = os.path.join(model_dir, 'vocab.txt')
        self.config_name = os.path.join(model_dir, 'bert_config.json')
        self.ckpt_name = os.path.join(model_dir, 'bert_model.ckpt')

        # mode/estimator are filled in later by other methods.
        self.mode = None
        self.estimator = None
        self.tokenizer = tokenization.FullTokenizer(vocab_file=self.vocab_file,
                                                    do_lower_case=True)
        self.processor = TrainProcessor()
        tf.logging.set_verbosity(tf.logging.INFO)
Example #2
0
 def __init__(self,
              max_seq_len=32,
              batch_size=32,
              layer_indexes=None,
              model_dir='',
              output_dir=''):
     """
     init BertVector
     :param max_seq_len:    maximum sequence length fed to the model
     :param batch_size:     Depending on your memory default is 32
     :param layer_indexes:  encoder layers to extract; defaults to [-2]
     :param model_dir:      directory holding vocab.txt / bert_config.json /
                            bert_model.ckpt
     :param output_dir:     directory the optimized graph is written to
     """
     # Fix: the original default was the mutable literal [-2], which is
     # evaluated once and shared across every call/instance. Use a None
     # sentinel and materialize a fresh list instead (backward-compatible).
     if layer_indexes is None:
         layer_indexes = [-2]
     self.max_seq_length = max_seq_len
     self.layer_indexes = layer_indexes
     # Hard-coded to the full GPU here; not exposed as a parameter.
     self.gpu_memory_fraction = 1
     self.model_dir = model_dir
     # Pretrained BERT artifacts are expected inside model_dir.
     vocab_file = os.path.join(model_dir, 'vocab.txt')
     config_name = os.path.join(model_dir, 'bert_config.json')
     ckpt_name = os.path.join(model_dir, 'bert_model.ckpt')
     self.graph_path = optimize_graph(layer_indexes=layer_indexes,
                                      config_name=config_name,
                                      ckpt_name=ckpt_name,
                                      max_seq_len=max_seq_len,
                                      output_dir=output_dir)
     self.tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file,
                                                 do_lower_case=True)
     self.batch_size = batch_size
     self.estimator = self.get_estimator()
     # Single-slot queues feed a background daemon thread that serves
     # predictions; daemon=True so it does not block interpreter exit.
     self.input_queue = Queue(maxsize=1)
     self.output_queue = Queue(maxsize=1)
     self.predict_thread = Thread(target=self.predict_from_queue,
                                  daemon=True)
     self.predict_thread.start()
     self.sentence_len = 0