示例#1
0
class MDRGWordPolicy(SysPolicy):
    """MDRG word-level system policy.

    The response ``Model`` and its vocabularies are loaded from a zip archive
    when the policy is constructed; ``predict`` then maps each dialogue state
    to a system response, remembering the previous state and active domain
    between turns.
    """

    def __init__(self,
                 archive_file=DEFAULT_ARCHIVE_FILE,
                 cuda_device=DEFAULT_CUDA_DEVICE,
                 model_file=None):
        """Extract the model archive and load the response model from it.

        Args:
            archive_file: local path to the model zip archive.
            cuda_device: accepted for interface compatibility; it is not
                used by this constructor.
            model_file: location handed to ``cached_path`` to obtain the
                archive when ``archive_file`` is not an existing file.

        Raises:
            Exception: if ``archive_file`` does not exist and no
                ``model_file`` is given.
        """
        if not os.path.isfile(archive_file):
            if not model_file:
                raise Exception("No model for MDRG is specified!")
            archive_file = cached_path(model_file)

        temp_path = tempfile.mkdtemp()
        try:
            # Everything below is read from the extracted ``mdrg/`` folder.
            with zipfile.ZipFile(archive_file, 'r') as zip_ref:
                zip_ref.extractall(temp_path)

            # NOTE(review): unpickling can execute arbitrary code; only load
            # archives from a trusted source.
            with open(os.path.join(temp_path, 'mdrg/svdic.pkl'), 'rb') as f:
                self.dic = pickle.load(f)

            def _load_json(relative_path):
                # Read one JSON vocabulary file from the extracted archive.
                with open(os.path.join(temp_path, relative_path)) as f:
                    return json.load(f)

            # Load dictionaries
            input_lang_index2word = _load_json(
                'mdrg/input_lang.index2word.json')
            input_lang_word2index = _load_json(
                'mdrg/input_lang.word2index.json')
            output_lang_index2word = _load_json(
                'mdrg/output_lang.index2word.json')
            output_lang_word2index = _load_json(
                'mdrg/output_lang.word2index.json')

            self.response_model = Model(args, input_lang_index2word,
                                        output_lang_index2word,
                                        input_lang_word2index,
                                        output_lang_word2index)
            self.response_model.loadModel(
                os.path.join(temp_path, 'mdrg/mdrg'))
        finally:
            # Remove the extracted files even when loading fails, so a bad
            # archive does not leave a stray temp directory behind.
            shutil.rmtree(temp_path)

        # Conversation memory carried from one predict() call to the next.
        self.prev_state = init_state()
        self.prev_active_domain = None

    def predict(self, state):
        """Generate a system response for the current dialogue ``state``.

        The module-level ``predict`` function receives the model, the
        previous state, the previous active domain, ``state`` and the loaded
        ``dic``. Afterwards a deep copy of ``state`` and the returned active
        domain are stored for the next turn.

        Args:
            state: the current dialogue state (structure defined by the
                caller / ``init_state``).

        Returns:
            The generated response, or ``'What did you say?'`` if generation
            raised an exception.
        """
        try:
            # This resolves to the module-level ``predict`` function, not to
            # this method.
            response, active_domain = predict(self.response_model,
                                              self.prev_state,
                                              self.prev_active_domain, state,
                                              self.dic)
        except Exception as e:
            # Deliberate best-effort: a generation failure must not crash the
            # dialogue loop, so report it and fall back to a generic reply.
            print('Response generation error', e)
            response = 'What did you say?'
            active_domain = None
        # Deep-copy so later mutation of ``state`` by the caller does not
        # alter the stored history.
        self.prev_state = deepcopy(state)
        self.prev_active_domain = active_domain
        return response
示例#2
0
def loadModel(num):
    """Build a ``Model`` from the vocabularies in ``DATA_PATH`` and restore a checkpoint.

    Args:
        num: checkpoint identifier, forwarded as ``iter=num`` to
            ``Model.loadModel``.

    Returns:
        The constructed model with the checkpoint loaded.
    """
    def read_vocab(file_name):
        # Each vocabulary mapping is stored in its own JSON file.
        with open(os.path.join(DATA_PATH, file_name)) as handle:
            return json.load(handle)

    in_i2w = read_vocab('input_lang.index2word.json')
    in_w2i = read_vocab('input_lang.word2index.json')
    out_i2w = read_vocab('output_lang.index2word.json')
    out_w2i = read_vocab('output_lang.word2index.json')

    # Model takes the index2word maps first, then the word2index maps.
    restored = Model(args, in_i2w, out_i2w, in_w2i, out_w2i)
    restored.loadModel(iter=num)
    return restored
示例#3
0
    def __init__(self,
                 archive_file=DEFAULT_ARCHIVE_FILE,
                 cuda_device=DEFAULT_CUDA_DEVICE,
                 model_file=None):
        """Extract the MDRG model archive and load the response model from it.

        Args:
            archive_file: local path to the model zip archive.
            cuda_device: accepted for interface compatibility; it is not
                used by this constructor.
            model_file: location handed to ``cached_path`` to obtain the
                archive when ``archive_file`` is not an existing file.

        Raises:
            Exception: if ``archive_file`` does not exist and no
                ``model_file`` is given.
        """
        if not os.path.isfile(archive_file):
            if not model_file:
                raise Exception("No model for MDRG is specified!")
            archive_file = cached_path(model_file)

        temp_path = tempfile.mkdtemp()
        try:
            # Everything below is read from the extracted ``mdrg/`` folder.
            with zipfile.ZipFile(archive_file, 'r') as zip_ref:
                zip_ref.extractall(temp_path)

            # NOTE(review): unpickling can execute arbitrary code; only load
            # archives from a trusted source.
            with open(os.path.join(temp_path, 'mdrg/svdic.pkl'), 'rb') as f:
                self.dic = pickle.load(f)

            def _load_json(relative_path):
                # Read one JSON vocabulary file from the extracted archive.
                with open(os.path.join(temp_path, relative_path)) as f:
                    return json.load(f)

            # Load dictionaries
            input_lang_index2word = _load_json(
                'mdrg/input_lang.index2word.json')
            input_lang_word2index = _load_json(
                'mdrg/input_lang.word2index.json')
            output_lang_index2word = _load_json(
                'mdrg/output_lang.index2word.json')
            output_lang_word2index = _load_json(
                'mdrg/output_lang.word2index.json')

            self.response_model = Model(args, input_lang_index2word,
                                        output_lang_index2word,
                                        input_lang_word2index,
                                        output_lang_word2index)
            self.response_model.loadModel(
                os.path.join(temp_path, 'mdrg/mdrg'))
        finally:
            # Remove the extracted files even when loading fails, so a bad
            # archive does not leave a stray temp directory behind.
            shutil.rmtree(temp_path)

        # Conversation memory carried from one predict() call to the next.
        self.prev_state = init_state()
        self.prev_active_domain = None
示例#4
0
def loadModelAndData(num):
    """Build the model, reset the output folders and load the dialogue lists.

    Args:
        num: checkpoint identifier, forwarded as ``iter=num`` to
            ``Model.loadModel`` when ``args.load_param`` is set.

    Returns:
        A ``(model, val_dials, test_dials)`` tuple.
    """
    def read_json(path):
        with open(path) as handle:
            return json.load(handle)

    def reset_dir(path):
        # Wipe any previous run's output, then recreate the folder empty.
        if os.path.exists(path):
            shutil.rmtree(path)
        os.makedirs(path)

    # Vocabularies, read in the same order the model constructor lists them
    # by language (input first, then output).
    in_i2w = read_json('data/input_lang.index2word.json')
    in_w2i = read_json('data/input_lang.word2index.json')
    out_i2w = read_json('data/output_lang.index2word.json')
    out_w2i = read_json('data/output_lang.word2index.json')

    model = Model(args, in_i2w, out_i2w, in_w2i, out_w2i)
    # Only restore weights when requested on the command line.
    if args.load_param:
        model.loadModel(iter=num)

    reset_dir(args.decode_output)
    reset_dir(args.valid_output)

    # NOTE(review): validation and test both read data/x_dials.json; the
    # previously commented-out paths were data/val_dials.json and
    # data/test_dials.json -- confirm which split is intended.
    val_dials = read_json('data/x_dials.json')
    test_dials = read_json('data/x_dials.json')
    return model, val_dials, test_dials
示例#5
0
    # load data and dictionaries
    # Each of the four vocabulary mappings lives in its own JSON file under
    # data/ (paths are relative to the current working directory).
    with open('data/input_lang.index2word.json') as f:
        input_lang_index2word = json.load(f)
    with open('data/input_lang.word2index.json') as f:
        input_lang_word2index = json.load(f)
    with open('data/output_lang.index2word.json') as f:
        output_lang_index2word = json.load(f)
    with open('data/output_lang.word2index.json') as f:
        output_lang_word2index = json.load(f)

    # Return order: index2word (input, output), then word2index (input,
    # output) -- the order the Model constructor takes them in.
    return input_lang_index2word, output_lang_index2word, input_lang_word2index, output_lang_word2index


if __name__ == '__main__':
    # Training entry point: load vocabularies and dialogue lists, build the
    # model, optionally restore a checkpoint, then train.
    # NOTE(review): every name bound here is a module-level global, so
    # trainIters (defined elsewhere) may read train_dials / val_dials / the
    # vocabularies as globals -- confirm before renaming or removing any.
    input_lang_index2word, output_lang_index2word, input_lang_word2index, output_lang_word2index = loadDictionaries(
    )
    # Load training file list:
    with open('data/train_dials.json') as outfile:
        train_dials = json.load(outfile)

    # Load validation file list:
    with open('data/val_dials.json') as outfile:
        val_dials = json.load(outfile)

    model = Model(args, input_lang_index2word, output_lang_index2word,
                  input_lang_word2index, output_lang_word2index)
    # Resume from a saved checkpoint only when --load_param is set.
    if args.load_param:
        model.loadModel(args.epoch_load)

    trainIters(model, n_epochs=args.max_epochs, args=args)