class MDRGWordPolicy(SysPolicy):

    def __init__(self, archive_file=DEFAULT_ARCHIVE_FILE, cuda_device=DEFAULT_CUDA_DEVICE, model_file=None):
        if not os.path.isfile(archive_file):
            if not model_file:
                raise Exception("No model for MDRG is specified!")
            archive_file = cached_path(model_file)

        temp_path = tempfile.mkdtemp()
        zip_ref = zipfile.ZipFile(archive_file, 'r')
        zip_ref.extractall(temp_path)
        zip_ref.close()

        self.dic = pickle.load(open(os.path.join(temp_path, 'mdrg/svdic.pkl'), 'rb'))

        # Load dictionaries
        with open(os.path.join(temp_path, 'mdrg/input_lang.index2word.json')) as f:
            input_lang_index2word = json.load(f)
        with open(os.path.join(temp_path, 'mdrg/input_lang.word2index.json')) as f:
            input_lang_word2index = json.load(f)
        with open(os.path.join(temp_path, 'mdrg/output_lang.index2word.json')) as f:
            output_lang_index2word = json.load(f)
        with open(os.path.join(temp_path, 'mdrg/output_lang.word2index.json')) as f:
            output_lang_word2index = json.load(f)

        self.response_model = Model(args, input_lang_index2word, output_lang_index2word,
                                    input_lang_word2index, output_lang_word2index)
        self.response_model.loadModel(os.path.join(temp_path, 'mdrg/mdrg'))

        shutil.rmtree(temp_path)

        self.prev_state = init_state()
        self.prev_active_domain = None

    def predict(self, state):
        try:
            response, active_domain = predict(self.response_model, self.prev_state,
                                              self.prev_active_domain, state, self.dic)
        except Exception as e:
            print('Response generation error', e)
            response = 'What did you say?'
            active_domain = None
        self.prev_state = deepcopy(state)
        self.prev_active_domain = active_domain
        return response
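# Hedged usage sketch (not part of the original module). It assumes that
# DEFAULT_ARCHIVE_FILE points to the packaged mdrg archive (or that a valid
# model_file URL is passed instead) and that the module-level init_state()
# returns the MultiWOZ-style dialog state dict expected by predict(). The
# 'history' entry below is an illustrative placeholder, not a documented schema.
policy = MDRGWordPolicy()
state = init_state()
state['history'] = [['null', 'I need a cheap hotel in the north of town.']]  # hypothetical user turn
print(policy.predict(state))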
def loadModel(num):
    # Load dictionaries
    with open(os.path.join(DATA_PATH, 'input_lang.index2word.json')) as f:
        input_lang_index2word = json.load(f)
    with open(os.path.join(DATA_PATH, 'input_lang.word2index.json')) as f:
        input_lang_word2index = json.load(f)
    with open(os.path.join(DATA_PATH, 'output_lang.index2word.json')) as f:
        output_lang_index2word = json.load(f)
    with open(os.path.join(DATA_PATH, 'output_lang.word2index.json')) as f:
        output_lang_word2index = json.load(f)

    # Reload existing checkpoint
    model = Model(args, input_lang_index2word, output_lang_index2word,
                  input_lang_word2index, output_lang_word2index)
    model.loadModel(iter=num)

    return model
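# Illustrative call (an assumption, not from the original file): DATA_PATH, args and
# Model are assumed to be module-level names, and the checkpoint iteration passed here
# is a placeholder for whichever checkpoint was actually saved during training.
response_model = loadModel(args.epoch_load)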
def loadModelAndData(num):
    # Load dictionaries
    with open('data/input_lang.index2word.json') as f:
        input_lang_index2word = json.load(f)
    with open('data/input_lang.word2index.json') as f:
        input_lang_word2index = json.load(f)
    with open('data/output_lang.index2word.json') as f:
        output_lang_index2word = json.load(f)
    with open('data/output_lang.word2index.json') as f:
        output_lang_word2index = json.load(f)

    # Reload existing checkpoint
    model = Model(args, input_lang_index2word, output_lang_index2word,
                  input_lang_word2index, output_lang_word2index)
    if args.load_param:
        model.loadModel(iter=num)

    # Load data
    if os.path.exists(args.decode_output):
        shutil.rmtree(args.decode_output)
        os.makedirs(args.decode_output)
    else:
        os.makedirs(args.decode_output)

    if os.path.exists(args.valid_output):
        shutil.rmtree(args.valid_output)
        os.makedirs(args.valid_output)
    else:
        os.makedirs(args.valid_output)

    # Load validation file list:
    # with open('data/val_dials.json') as outfile:
    with open('data/x_dials.json') as outfile:
        val_dials = json.load(outfile)

    # Load test file list:
    # with open('data/test_dials.json') as outfile:
    with open('data/x_dials.json') as outfile:
        test_dials = json.load(outfile)

    return model, val_dials, test_dials
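# Sketch of how a decode entry point might call the helper above (an assumption; the
# original decode loop is not shown in these snippets). args.epoch_load is reused from
# the training entry point further below and may not be the exact flag used at decode time.
model, val_dials, test_dials = loadModelAndData(args.epoch_load)
print('validation dialogues:', len(val_dials), '| test dialogues:', len(test_dials))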
def loadDictionaries():
    # load data and dictionaries
    with open('data/input_lang.index2word.json') as f:
        input_lang_index2word = json.load(f)
    with open('data/input_lang.word2index.json') as f:
        input_lang_word2index = json.load(f)
    with open('data/output_lang.index2word.json') as f:
        output_lang_index2word = json.load(f)
    with open('data/output_lang.word2index.json') as f:
        output_lang_word2index = json.load(f)
    return input_lang_index2word, output_lang_index2word, input_lang_word2index, output_lang_word2index


if __name__ == '__main__':
    input_lang_index2word, output_lang_index2word, input_lang_word2index, output_lang_word2index = loadDictionaries()

    # Load training file list:
    with open('data/train_dials.json') as outfile:
        train_dials = json.load(outfile)

    # Load validation file list:
    with open('data/val_dials.json') as outfile:
        val_dials = json.load(outfile)

    model = Model(args, input_lang_index2word, output_lang_index2word,
                  input_lang_word2index, output_lang_word2index)
    if args.load_param:
        model.loadModel(args.epoch_load)

    trainIters(model, n_epochs=args.max_epochs, args=args)
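# Hedged sketch of the argument parsing that the entry points above rely on. Only the
# flags that actually appear in these snippets (load_param, epoch_load, max_epochs,
# decode_output, valid_output) are included; the defaults are assumptions, not the
# values used by the original code.
import argparse

parser = argparse.ArgumentParser(description='MDRG training / decoding options (sketch)')
parser.add_argument('--load_param', action='store_true', help='reload an existing checkpoint')
parser.add_argument('--epoch_load', type=int, default=0, help='checkpoint iteration to load')
parser.add_argument('--max_epochs', type=int, default=15, help='number of training epochs')
parser.add_argument('--decode_output', type=str, default='results/decode/', help='directory for decoded test output')
parser.add_argument('--valid_output', type=str, default='results/valid/', help='directory for decoded validation output')
args = parser.parse_args()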