def __init__(self, session, corpus_dir, knbase_dir, result_dir, hparams_dir=None):
    self.session = session
    hparams = HParams(hparams_dir).hparams if hparams_dir else None

    # Prepare data and hyper parameters
    print("# Prepare dataset placeholder and hyper parameters ...")
    self.tokenized_data = TokenizedData(corpus_dir=corpus_dir, hparams=hparams,
                                        knbase_dir=knbase_dir, training=False)

    self.hparams = self.tokenized_data.hparams
    self.src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
    src_dataset = tf.contrib.data.Dataset.from_tensor_slices(self.src_placeholder)
    self.infer_batch = self.tokenized_data.get_inference_batch(src_dataset)

    # Create model
    print("# Creating inference model ...")
    self.model = ModelCreator(training=False, tokenized_data=self.tokenized_data,
                              batch_input=self.infer_batch)

    latest_ckpt = tf.train.latest_checkpoint(result_dir)
    print("# Restoring model weights ...")
    self.model.saver.restore(session, latest_ckpt)

    self.session.run(tf.tables_initializer())
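The src_placeholder / from_tensor_slices pair is what lets the same inference graph be re-initialized with arbitrary user input at request time. Below is a minimal, standalone sketch of that pattern using the same tf.contrib.data API as the snippet (the sample sentences are illustrative only; in newer TF 1.x releases tf.data.Dataset is the equivalent entry point):

import tensorflow as tf

# A string dataset sliced from a placeholder can be (re)initialized with new
# sentences on every request, without rebuilding the graph.
src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
src_dataset = tf.contrib.data.Dataset.from_tensor_slices(src_placeholder)
iterator = src_dataset.make_initializable_iterator()
next_sentence = iterator.get_next()

with tf.Session() as sess:
    sess.run(iterator.initializer,
             feed_dict={src_placeholder: ["How are you?", "Tell me a joke."]})
    print(sess.run(next_sentence))  # b'How are you?'
    print(sess.run(next_sentence))  # b'Tell me a joke.'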
def __init__(self, corpus_dir, hparams=None, knbase_dir=None, training=True,
             augment_factor=3, buffer_size=8192):
    """
    Args:
        corpus_dir: Name of the folder storing corpus files for training.
        hparams: The object containing the loaded hyper parameters. If None,
            it will be initialized here.
        knbase_dir: Name of the folder storing data files for the knowledge
            base. Used for inference only.
        training: Whether to use this object for training.
        augment_factor: How many times the training data is repeated. If 1 or
            less, no augmentation is applied.
        buffer_size: The buffer size used for the mapping process during data
            processing.
    """
    if hparams is None:
        self.hparams = HParams(corpus_dir).hparams
    else:
        self.hparams = hparams

    self.src_max_len = self.hparams.src_max_len
    self.tgt_max_len = self.hparams.tgt_max_len

    self.training = training
    self.text_set = None
    self.id_set = None

    vocab_file = os.path.join(corpus_dir, VOCAB_FILE)
    self.vocab_size, _ = check_vocab(vocab_file)
    self.vocab_table = lookup_ops.index_table_from_file(
        vocab_file, default_value=self.hparams.unk_id)
    # print("vocab_size = {}".format(self.vocab_size))

    if training:
        self.case_table = prepare_case_table()
        self.reverse_vocab_table = None
        self._load_corpus(corpus_dir, augment_factor)
        self._convert_to_tokens(buffer_size)
        self.upper_words = {}
        self.stories = {}
        self.jokes = []
    else:
        self.case_table = None
        self.reverse_vocab_table = \
            lookup_ops.index_to_string_table_from_file(
                vocab_file, default_value=self.hparams.unk_token)

        assert knbase_dir is not None
        knbs = KnowledgeBase()
        knbs.load_knbase(knbase_dir)

        self.upper_words = knbs.upper_words
        self.stories = knbs.stories
        self.jokes = knbs.jokes
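The constructor above leans on TensorFlow lookup tables to map tokens to ids (for training) and ids back to tokens (for inference). A minimal sketch of that mechanism, assuming a toy vocab file with one token per line (the file name and tokens below are illustrative):

import tensorflow as tf
from tensorflow.python.ops import lookup_ops

vocab_file = "vocab.txt"  # e.g. <unk>, <s>, </s>, hello, world -- one token per line

# token -> id; unknown tokens fall back to id 0
vocab_table = lookup_ops.index_table_from_file(vocab_file, default_value=0)
# id -> token; only needed at inference time to turn output ids back into words
reverse_vocab_table = lookup_ops.index_to_string_table_from_file(
    vocab_file, default_value="<unk>")

ids = vocab_table.lookup(tf.constant(["hello", "never_seen"]))
words = reverse_vocab_table.lookup(ids)

with tf.Session() as sess:
    sess.run(tf.tables_initializer())  # the same initializer the predictor runs
    print(sess.run(ids))               # e.g. [3 0]
    print(sess.run(words))             # e.g. [b'hello' b'<unk>']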
def __init__(self, corpus_dir, hparams_dir=None):
    hparams = HParams(hparams_dir).hparams if hparams_dir else None

    self.graph = tf.Graph()
    with self.graph.as_default():
        tokenized_data = TokenizedData(corpus_dir=corpus_dir, hparams=hparams)

        self.hparams = tokenized_data.hparams
        self.train_batch = tokenized_data.get_training_batch()
        self.model = ModelCreator(training=True, tokenized_data=tokenized_data,
                                  batch_input=self.train_batch)
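The trainer builds everything inside its own tf.Graph, which keeps its ops isolated from any other graph in the process (for example a separately constructed inference model). A minimal sketch of the same pattern with toy ops (class and op names are illustrative, not from the project):

import tensorflow as tf

class ToyTrainer:
    def __init__(self):
        self.graph = tf.Graph()
        with self.graph.as_default():
            # Everything created here lives only in self.graph.
            self.inputs = tf.placeholder(tf.float32, shape=[None])
            self.loss = tf.reduce_mean(tf.square(self.inputs))

    def eval_loss(self, values):
        # The session must be bound to the same graph the ops were built in.
        with tf.Session(graph=self.graph) as sess:
            return sess.run(self.loss, feed_dict={self.inputs: values})

print(ToyTrainer().eval_loss([1.0, 2.0, 3.0]))  # 4.666...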
def __init__(self, corpus_dir, hparams=None, training=True, buffer_size=8192):
    """
    Args:
        corpus_dir: Name of the folder storing corpus files for training.
        hparams: The object containing the loaded hyper parameters. If None,
            it will be initialized here.
        training: Whether to use this object for training.
        buffer_size: The buffer size used for the mapping process during data
            processing.
    """
    if hparams is None:
        self.hparams = HParams(corpus_dir).hparams
    else:
        self.hparams = hparams

    self.src_max_len = self.hparams.src_max_len
    self.tgt_max_len = self.hparams.tgt_max_len

    self.training = training
    self.text_set = None
    self.id_set = None

    vocab_file = os.path.join(corpus_dir, VOCAB_FILE)
    self.vocab_size, _ = check_vocab(vocab_file)
    self.vocab_table = lookup_ops.index_table_from_file(
        vocab_file, default_value=self.hparams.unk_id)
    # print("vocab_size = {}".format(self.vocab_size))

    if training:
        self.case_table = prepare_case_table()
        self.reverse_vocab_table = None
        self._load_corpus(corpus_dir)
        self._convert_to_tokens(buffer_size)
    else:
        self.case_table = None
        self.reverse_vocab_table = \
            lookup_ops.index_to_string_table_from_file(
                vocab_file, default_value=self.hparams.unk_token)
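check_vocab and prepare_case_table are helpers whose implementations are not shown in these snippets. As a rough illustration only, a check_vocab-style helper could simply count the tokens in a one-token-per-line vocab file; treat this as an assumption, not the project's actual code:

import codecs
import os

def check_vocab(vocab_file):
    """Hypothetical helper: return (vocab_size, vocab_list) for a one-token-per-line file."""
    if not os.path.exists(vocab_file):
        raise ValueError("Vocab file {} does not exist.".format(vocab_file))
    with codecs.open(vocab_file, "r", "utf-8") as f:
        vocab = [line.strip() for line in f]
    return len(vocab), vocab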