예제 #1
0
def model_loading():
    """Build the inference-mode chatbot model for a chat session.

    Initializes a "chat" session through ``general_utils``, loads the
    vocabulary (a single shared file, or separate input and output files,
    depending on ``hparams.model_hparams.share_embedding``), constructs a
    ``ChatbotModel`` in "infer" mode and loads the checkpoint into it.
    Progress messages are printed to stdout along the way.

    Returns:
        A ``(model, chatlog_filepath, chat_settings)`` tuple.
        ``chatlog_filepath`` is a timestamped path under
        ``<model_dir>/chat_logs``; this function only builds the path and
        does not create the file or its directory.
    """
    session = general_utils.initialize_session("chat")
    # Only the model directory, hyperparameters and checkpoint are used here.
    _, model_dir, hparams, checkpoint, _, _ = session

    print()
    print("Loading vocabulary...")
    if not hparams.model_hparams.share_embedding:
        # Separate vocabulary files for the input and the output side.
        input_vocabulary = Vocabulary.load(
            path.join(model_dir, Vocabulary.INPUT_VOCAB_FILENAME))
        output_vocabulary = Vocabulary.load(
            path.join(model_dir, Vocabulary.OUTPUT_VOCAB_FILENAME))
    else:
        # Shared embedding: both sides use the very same vocabulary object.
        input_vocabulary = Vocabulary.load(
            path.join(model_dir, Vocabulary.SHARED_VOCAB_FILENAME))
        output_vocabulary = input_vocabulary

    # Chat log path, stamped with the current local time.
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    chatlog_filepath = path.join(model_dir, "chat_logs",
                                 "chatlog_{0}.txt".format(timestamp))
    chat_settings = ChatSettings(hparams.model_hparams,
                                 hparams.inference_hparams)

    # This function always performs a first-time load, so the flag is fixed.
    reload_model = False
    if reload_model:
        init_message = "Re-initializing model..."
    else:
        init_message = "Initializing model..."

    print()
    print(init_message)
    print()
    model = ChatbotModel(
        mode="infer",
        model_hparams=chat_settings.model_hparams,
        input_vocabulary=input_vocabulary,
        output_vocabulary=output_vocabulary,
        model_dir=model_dir,
    )

    # Restore the trained weights from the session checkpoint.
    print()
    print("Loading model weights...")
    print()
    model.load(checkpoint)

    if not reload_model:
        # The chat command list is deliberately not printed here; call
        # chat_command_handler.print_commands() at this point to show it.
        print('Model Reload!')
    return model, chatlog_filepath, chat_settings
예제 #2
0
"""
Script for chatting with a trained chatbot model
"""
import datetime
from os import path

import general_utils
import chat_command_handler
from chat_settings import ChatSettings
from chatbot_model import ChatbotModel
from vocabulary import Vocabulary

# Read the hyperparameters and configure paths for a "chat" session
_, model_dir, hparams, checkpoint = general_utils.initialize_session("chat")

# Load the vocabulary: separate input/output files, or one shared file when
# the model shares its embedding between both sides
print()
print("Loading vocabulary...")
if not hparams.model_hparams.share_embedding:
    input_vocab_filepath = path.join(
        model_dir, Vocabulary.INPUT_VOCAB_FILENAME)
    input_vocabulary = Vocabulary.load(input_vocab_filepath)
    output_vocab_filepath = path.join(
        model_dir, Vocabulary.OUTPUT_VOCAB_FILENAME)
    output_vocabulary = Vocabulary.load(output_vocab_filepath)
else:
    shared_vocab_filepath = path.join(
        model_dir, Vocabulary.SHARED_VOCAB_FILENAME)
    # Both names refer to the same loaded vocabulary object
    output_vocabulary = input_vocabulary = Vocabulary.load(
        shared_vocab_filepath)
예제 #3
0
"""
import time
import math
from os import path
from shutil import copytree  # Recursively copy an entire directory tree rooted at src.

import general_utils
import train_console_helper
from dataset_readers import dataset_reader_factory
from vocabulary_importers import vocabulary_importer_factory
from vocabulary import Vocabulary
from chatbot_model import ChatbotModel
from training_stats import TrainingStats

# Read the hyperparameters and paths for a "train" session.
# initialize_session is unpacked into six values here: dataset and model
# directories, the hyperparameters, a checkpoint to resume from, and the
# encoder / decoder embeddings directories.
# NOTE(review): the embeddings directories are compared against None further
# down, so they are presumably optional - confirm in general_utils.
dataset_dir, model_dir, hparams, resume_checkpoint, encoder_embeddings_dir, decoder_embeddings_dir = general_utils.initialize_session(
    "train")
# Training statistics are kept in a JSON file inside the model directory.
training_stats_filepath = path.join(model_dir, "training_stats.json")

# Read the chatbot dataset and generate / import the vocabulary.
# The reader is selected by the factory from the dataset directory.
dataset_reader = dataset_reader_factory.get_dataset_reader(dataset_dir)

print()
print("Reading dataset '{0}'...".format(dataset_reader.dataset_name))
# read_dataset returns a pair: the dataset itself and statistics about the
# read. share_vocab follows the model's share_embedding hyperparameter.
dataset, dataset_read_stats = dataset_reader.read_dataset(
    dataset_dir=dataset_dir,
    model_dir=model_dir,
    training_hparams=hparams.training_hparams,
    share_vocab=hparams.model_hparams.share_embedding,
    encoder_embeddings_dir=encoder_embeddings_dir,
    decoder_embeddings_dir=decoder_embeddings_dir)
if encoder_embeddings_dir is not None:
"""
Script for training the chatbot model
"""
import time
import math
from os import path

import general_utils
import dataset_reader_factory
from chatbot_model import ChatbotModel
from training_stats import TrainingStats

# Read the hyperparameters and paths for a "train" session.
# initialize_session is unpacked into four values here: dataset and model
# directories, the hyperparameters, and a checkpoint to resume from.
dataset_dir, model_dir, hparams, resume_checkpoint = general_utils.initialize_session(
    "train")
# Training statistics are kept in a JSON file inside the model directory.
training_stats_filepath = path.join(model_dir, "training_stats.json")

# Read the chatbot dataset.
# The reader is selected by the factory from the dataset directory.
dataset_reader = dataset_reader_factory.get_dataset_reader(dataset_dir)

print()
print("Reading dataset '{0}'...".format(dataset_reader.dataset_name))
# share_vocab follows the model's share_embedding hyperparameter.
dataset = dataset_reader.read_dataset(
    dataset_dir=dataset_dir,
    model_dir=model_dir,
    training_hparams=hparams.training_hparams,
    share_vocab=hparams.model_hparams.share_embedding)

#Split the chatbot dataset into training & validation datasets
print(
    "Splitting {0} samples into training & validation sets ({1}% used for validation)..."