def __init__(self, engine_path, is_admin):
    self.engine_path = engine_path
    self.is_admin = is_admin
    self.model_path = os.path.join(engine_path, 'model')
    self.config_path = os.path.join(engine_path, 'config.yaml')
    self.gpu_id = None

    # Load parameters from configuration file
    config = load_config(self.config_path)
    if "load_model" in config['training'].keys():
        self.ckpt = os.path.realpath(
            os.path.join(app.config['JOEYNMT_FOLDER'],
                         config['training']["load_model"]))
    else:
        self.ckpt = get_latest_checkpoint(self.model_path)
    self.use_cuda = config["training"].get("use_cuda", False)
    self.level = config["data"]["level"]
    self.max_output_length = config["training"].get("max_output_length", None)
    self.lowercase = config["data"].get("lowercase", False)
    self.model_data = config["model"]

    # load the vocabularies
    src_vocab_file = os.path.realpath(
        os.path.join(app.config['JOEYNMT_FOLDER'],
                     config["data"]["src_vocab"]))
    trg_vocab_file = os.path.realpath(
        os.path.join(app.config['JOEYNMT_FOLDER'],
                     config["data"]["trg_vocab"]))
    self.src_vocab = build_vocab(field="src", vocab_file=src_vocab_file,
                                 dataset=None, max_size=-1, min_freq=0)
    self.trg_vocab = build_vocab(field="trg", vocab_file=trg_vocab_file,
                                 dataset=None, max_size=-1, min_freq=0)

    # whether to use beam search for decoding, 0: greedy decoding
    if "testing" in config.keys():
        self.beam_size = config["testing"].get("beam_size", 0)
        self.beam_alpha = config["testing"].get("alpha", -1)
    else:
        self.beam_size = 1
        self.beam_alpha = -1

    self.logger = logging.getLogger(__name__)
def load_data(data_cfg: dict, get_test: bool = True, trg_vocab: object = None) \
        -> (object, object, Optional[object], Vocabulary):
    src_lang = data_cfg["src"]
    train_path = data_cfg["train"]
    dev_path = data_cfg["dev"]
    test_path = data_cfg.get("test", None)
    max_src_length = data_cfg["max_src_length"]
    max_trg_length = data_cfg["max_trg_length"]

    trg_max_size = data_cfg.get("trg_voc_limit", sys.maxsize)
    trg_min_freq = data_cfg.get("trg_voc_min_freq", 1)
    trg_vocab_file = data_cfg.get("trg_vocab", None)

    train_data = None
    dev_data = None
    test_data = None

    if trg_vocab is None:
        print('Getting train data...')
        train_data = get_dataset(train_path + "." + src_lang,
                                 max_src_length, max_trg_length)
        trg_vocab = build_vocab(field="trg", min_freq=trg_min_freq,
                                max_size=trg_max_size, dataset=train_data,
                                vocab_file=trg_vocab_file)

        print('Getting dev data...')
        dev_data = get_dataset(dev_path + "." + src_lang,
                               max_src_length, max_trg_length)

        print('Generating train labels')
        train_data_labels = label_to_int(train_data[1], trg_vocab)
        train_data = (train_data[0], train_data_labels)

        print('Generating dev labels')
        dev_data_labels = label_to_int(dev_data[1], trg_vocab)
        dev_data = (dev_data[0], dev_data_labels)

    if test_path is not None and get_test:
        print('Getting test data...')
        test_data = get_dataset(test_path + "." + src_lang,
                                max_src_length, max_trg_length)
        test_data_labels = label_to_int(test_data[1], trg_vocab)
        test_data = (test_data[0], test_data_labels)

    return train_data, dev_data, test_data, trg_vocab
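# Usage sketch for the load_data above (paths and config values are
# hypothetical). On the first call the target vocabulary is built from the
# train set; passing it back in later skips train/dev loading and only
# prepares the test split.
def _example_asr_load_data():
    data_cfg = {
        "src": "en",
        "train": "data/train", "dev": "data/dev", "test": "data/test",
        "max_src_length": 1000, "max_trg_length": 200,
    }
    train_data, dev_data, test_data, trg_vocab = load_data(data_cfg)
    # later, e.g. for evaluation only: reuse the vocabulary
    _, _, test_only, _ = load_data(data_cfg, get_test=True,
                                   trg_vocab=trg_vocab)
    return test_only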
def load_data(data_cfg: dict) -> (object, object, Optional[object], Vocabulary):
    src_lang = data_cfg["src"]
    trg_lang = data_cfg["trg"]
    train_path = data_cfg["train"]
    dev_path = data_cfg["dev"]
    test_path = data_cfg.get("test", None)
    level = data_cfg["level"]
    lowercase = data_cfg["lowercase"]
    max_src_length = data_cfg["max_src_length"]
    max_trg_length = data_cfg["max_trg_length"]
    tok_fun = lambda s: list(s) if level == "char" else s.split()

    trg_max_size = data_cfg.get("trg_voc_limit", sys.maxsize)
    trg_min_freq = data_cfg.get("trg_voc_min_freq", 1)
    trg_vocab_file = data_cfg.get("trg_vocab", None)

    train_data = get_dataset(train_path + "." + src_lang,
                             max_src_length, max_trg_length)
    trg_vocab = build_vocab(field="trg", min_freq=trg_min_freq,
                            max_size=trg_max_size, dataset=train_data,
                            vocab_file=trg_vocab_file)
    dev_data = get_dataset(dev_path + "." + src_lang,
                           max_src_length, max_trg_length)

    test_data = None
    if test_path is not None:
        test_data = get_dataset(test_path + "." + src_lang,
                                max_src_length, max_trg_length)
        # label the test targets, not the train targets
        test_data_labels = label_to_int(test_data[1], trg_vocab)
        test_data = (test_data[0], test_data_labels)

    train_data_labels = label_to_int(train_data[1], trg_vocab)
    dev_data_labels = label_to_int(dev_data[1], trg_vocab)
    train_data = (train_data[0], train_data_labels)
    dev_data = (dev_data[0], dev_data_labels)

    return train_data, dev_data, test_data, trg_vocab
def reformat_data(data, data_torchaudio, trg_min_freq, trg_max_size, tok_fun,
                  trg_vocab_file=None, trg_vocab=None, lowercase=True):
    train_iter = data
    src_field = Noprocessfield(sequential=False, use_vocab=False,
                               dtype=torch.double, include_lengths=True)
    trg_field = Field(init_token=BOS_TOKEN, eos_token=EOS_TOKEN,
                      pad_token=PAD_TOKEN, tokenize=tok_fun,
                      unk_token=UNK_TOKEN, batch_first=True,
                      lower=lowercase, include_lengths=True)
    if trg_vocab is None:
        trg_vocab = build_vocab(min_freq=trg_min_freq, max_size=trg_max_size,
                                dataset=data_torchaudio, trg_field=trg_field,
                                vocab_file=trg_vocab_file)
    trg_field.vocab = trg_vocab

    # wrap each (audio tensor, transcript) pair of the iterator into an Entry
    entry_list = []
    for i, batch in enumerate(iter(train_iter)):
        entry_list.append(Entry(batch[0][0].squeeze(), batch[0][1]))
    train_data = Dataset(entry_list, [('src', src_field), ('trg', trg_field)])

    return train_data, trg_vocab, src_field, trg_field
def load_model(model_dir, bpe_src_code=None, tokenize=None):
    """
    Load the Joey NMT model and its pre-/post-processing pipeline according
    to the config file.

    :param model_dir: Model directory of trained Joey NMT model.
    :param bpe_src_code: BPE codes for source side processing (optional).
    :param tokenize: If True, tokenize inputs with Moses tokenizer.
    :return: dictionary with model, vocabularies and decoding parameters
    """
    conf = {}
    cfg_file = model_dir + "/config.yaml"
    logger = logging.getLogger(__name__)
    conf["logger"] = logger

    # load the Joey configuration
    cfg = load_config(cfg_file)

    # load the checkpoint
    if "load_model" in cfg['training'].keys():
        ckpt = cfg['training']["load_model"]
    else:
        ckpt = get_latest_checkpoint(model_dir)
        if ckpt is None:
            raise FileNotFoundError("No checkpoint found in directory {}."
                                    .format(model_dir))

    # prediction parameters from config
    conf["use_cuda"] = cfg["training"].get("use_cuda", False)
    conf["level"] = cfg["data"]["level"]
    conf["max_output_length"] = cfg["training"].get("max_output_length", None)
    conf["lowercase"] = cfg["data"].get("lowercase", False)

    # load the vocabularies
    src_vocab_file = cfg["training"]["model_dir"] + "/src_vocab.txt"
    trg_vocab_file = cfg["training"]["model_dir"] + "/trg_vocab.txt"
    conf["src_vocab"] = build_vocab(field="src", vocab_file=src_vocab_file,
                                    dataset=None, max_size=-1, min_freq=0)
    conf["trg_vocab"] = build_vocab(field="trg", vocab_file=trg_vocab_file,
                                    dataset=None, max_size=-1, min_freq=0)

    # whether to use beam search for decoding, 0: greedy decoding
    if "testing" in cfg.keys():
        conf["beam_size"] = cfg["testing"].get("beam_size", 0)
        conf["beam_alpha"] = cfg["testing"].get("alpha", -1)
    else:
        conf["beam_size"] = 1
        conf["beam_alpha"] = -1

    # pre-processing
    if tokenize is not None:
        src_tokenizer = MosesTokenizer(lang=cfg["data"]["src"])
        trg_tokenizer = MosesDetokenizer(lang=cfg["data"]["trg"])
        # tokenize input
        tokenizer = lambda x: src_tokenizer.tokenize(x, return_str=True)
        detokenizer = lambda x: trg_tokenizer.detokenize(
            x.split(), return_str=True)
    else:
        tokenizer = lambda x: x
        detokenizer = lambda x: x

    if bpe_src_code is not None and conf["level"] == "bpe":
        # load bpe merge file
        merge_file = open(bpe_src_code, "r")
        bpe = apply_bpe.BPE(codes=merge_file)
        segmenter = lambda x: bpe.process_line(x.strip())
    elif conf["level"] == "char":
        # split to chars
        segmenter = lambda x: list(x.strip())
    else:
        segmenter = lambda x: x.strip()

    conf["preprocess"] = [tokenizer, segmenter]
    conf["postprocess"] = [detokenizer]

    # build model and load parameters into it
    model_checkpoint = load_checkpoint(ckpt, conf["use_cuda"])
    model = build_model(cfg["model"], src_vocab=conf["src_vocab"],
                        trg_vocab=conf["trg_vocab"])
    model.load_state_dict(model_checkpoint["model_state"])
    if conf["use_cuda"]:
        model.cuda()
    conf["model"] = model
    print("Joey NMT model loaded successfully.")
    return conf
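# Usage sketch (hypothetical model directory and BPE file): load a trained
# model once and reuse the returned dict for translation. Only keys that
# load_model above sets are accessed here.
def _example_load_model():
    conf = load_model("models/en-de", bpe_src_code="models/en-de/bpe.codes",
                      tokenize=True)
    tokenizer, segmenter = conf["preprocess"]
    model = conf["model"]
    print("vocab sizes:", len(conf["src_vocab"]), len(conf["trg_vocab"]))
    return model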
def load_data(data_cfg: dict) \
        -> (Dataset, Dataset, Optional[Dataset], Vocabulary, Vocabulary):
    """
    Load train, dev and optionally test data as specified in configuration.
    Vocabularies are created from the training set with a limit of `voc_limit`
    tokens and a minimum token frequency of `voc_min_freq`
    (specified in the configuration dictionary).

    The training data is filtered to include sentences up to `max_sent_length`
    on source and target side.

    :param data_cfg: configuration dictionary for data
        ("data" part of configuration file)
    :return:
        - train_data: training dataset
        - dev_data: development dataset
        - test_data: test dataset if given, otherwise None
        - src_vocab: source vocabulary extracted from training data
        - trg_vocab: target vocabulary extracted from training data
    """
    # load data from files
    src_lang = data_cfg["src"]
    trg_lang = data_cfg["trg"]
    train_path = data_cfg["train"]
    dev_path = data_cfg["dev"]
    test_path = data_cfg.get("test", None)
    level = data_cfg["level"]
    lowercase = data_cfg["lowercase"]
    max_sent_length = data_cfg["max_sent_length"]

    tok_fun = lambda s: list(s) if level == "char" else s.split()

    src_field = data.Field(init_token=None, eos_token=EOS_TOKEN,
                           pad_token=PAD_TOKEN, tokenize=tok_fun,
                           batch_first=True, lower=lowercase,
                           unk_token=UNK_TOKEN, include_lengths=True)

    trg_field = data.Field(init_token=BOS_TOKEN, eos_token=EOS_TOKEN,
                           pad_token=PAD_TOKEN, tokenize=tok_fun,
                           unk_token=UNK_TOKEN, batch_first=True,
                           lower=lowercase, include_lengths=True)

    train_data = TranslationDataset(
        path=train_path, exts=("." + src_lang, "." + trg_lang),
        fields=(src_field, trg_field),
        filter_pred=lambda x: len(vars(x)['src']) <= max_sent_length
        and len(vars(x)['trg']) <= max_sent_length)

    src_max_size = data_cfg.get("src_voc_limit", sys.maxsize)
    src_min_freq = data_cfg.get("src_voc_min_freq", 1)
    trg_max_size = data_cfg.get("trg_voc_limit", sys.maxsize)
    trg_min_freq = data_cfg.get("trg_voc_min_freq", 1)

    src_vocab_file = data_cfg.get("src_vocab", None)
    trg_vocab_file = data_cfg.get("trg_vocab", None)

    src_vocab = build_vocab(field="src", min_freq=src_min_freq,
                            max_size=src_max_size, dataset=train_data,
                            vocab_file=src_vocab_file)
    trg_vocab = build_vocab(field="trg", min_freq=trg_min_freq,
                            max_size=trg_max_size, dataset=train_data,
                            vocab_file=trg_vocab_file)

    dev_data = TranslationDataset(path=dev_path,
                                  exts=("." + src_lang, "." + trg_lang),
                                  fields=(src_field, trg_field))

    test_data = None
    if test_path is not None:
        # check if target exists
        if os.path.isfile(test_path + "." + trg_lang):
            test_data = TranslationDataset(
                path=test_path, exts=("." + src_lang, "." + trg_lang),
                fields=(src_field, trg_field))
        else:
            # no target is given -> create dataset from src only
            test_data = MonoDataset(path=test_path, ext="." + src_lang,
                                    field=src_field)

    src_field.vocab = src_vocab
    trg_field.vocab = trg_vocab

    return train_data, dev_data, test_data, src_vocab, trg_vocab
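# Minimal "data" configuration sketch for the load_data above (paths and
# language codes are hypothetical; only keys the function reads are included).
def _example_bitext_load_data():
    data_cfg = {
        "src": "de", "trg": "en",
        "train": "data/train", "dev": "data/dev", "test": "data/test",
        "level": "word", "lowercase": True,
        "max_sent_length": 50,
    }
    train_data, dev_data, test_data, src_vocab, trg_vocab = load_data(data_cfg)
    print(len(train_data), len(src_vocab), len(trg_vocab))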
def load_audio_data(cfg: dict) \
        -> (Dataset, Dataset, Optional[Dataset], Vocabulary, Vocabulary):
    """
    Load train, dev and optionally test data as specified in configuration.
    Vocabularies are created from the training set with a limit of `voc_limit`
    tokens and a minimum token frequency of `voc_min_freq`
    (specified in the configuration dictionary).

    The training data is filtered to include sentences up to `max_sent_length`
    on text side and audios up to `max_audio_length`.

    :param cfg: configuration dictionary
    :return:
        - train_data: training dataset
        - dev_data: development dataset
        - test_data: test dataset if given, otherwise None
        - src_vocab: copy of trg_vocab
        - trg_vocab: target vocabulary extracted from training data
    """
    # load data from files
    data_cfg = cfg["data"]
    src_lang = data_cfg["src"]
    trg_lang = data_cfg["trg"]
    if data_cfg["audio"] == "src":
        audio_lang = src_lang
    else:
        audio_lang = trg_lang
    train_path = data_cfg["train"]
    dev_path = data_cfg["dev"]
    test_path = data_cfg.get("test", None)
    level = data_cfg["level"]
    lowercase = data_cfg["lowercase"]
    max_sent_length = data_cfg.get("max_sent_length", sys.maxsize)
    max_audio_length = data_cfg.get("max_audio_length", sys.maxsize)
    number = cfg["model"]["encoder"]["embeddings"]["embedding_dim"]
    assert number <= 80, \
        "The number of used audio features cannot be higher than the number " \
        "of Mel bands. Change the encoder's embedding_dim."
    check_ratio = data_cfg.get("input_length_ratio", sys.maxsize)
    audio_features = data_cfg["audio_features_level"]
    htk = data_cfg["use_htk"]
    scale = data_cfg.get("scale", None)

    #pylint: disable=unnecessary-lambda
    if level == "char":
        tok_fun = lambda s: list(s)
        char = True
    else:  # bpe or word, pre-tokenized
        tok_fun = lambda s: s.split()
        char = False

    src_field = data.Field(init_token=None, eos_token=EOS_TOKEN,
                           pad_token=PAD_TOKEN, tokenize=tok_fun,
                           batch_first=True, lower=lowercase,
                           unk_token=UNK_TOKEN, include_lengths=True)

    trg_field = data.Field(init_token=BOS_TOKEN, eos_token=EOS_TOKEN,
                           pad_token=PAD_TOKEN, tokenize=tok_fun,
                           unk_token=UNK_TOKEN, batch_first=True,
                           lower=lowercase, include_lengths=True)

    train_data = AudioDataset(
        path=train_path, text_ext="." + audio_lang, audio_ext=".txt",
        sfield=src_field, tfield=trg_field, num=number, char_level=char,
        train=True, check=check_ratio, audio_level=audio_features,
        htk=htk, scale=scale,
        filter_pred=lambda x: len(vars(x)['src']) <= max_audio_length
        and len(vars(x)['trg']) <= max_sent_length)

    src_max_size = data_cfg.get("src_voc_limit", sys.maxsize)
    src_min_freq = data_cfg.get("src_voc_min_freq", 1)
    trg_max_size = data_cfg.get("trg_voc_limit", sys.maxsize)
    trg_min_freq = data_cfg.get("trg_voc_min_freq", 1)

    trg_vocab_file = data_cfg.get(audio_lang + "_vocab", None)
    src_vocab_file = None

    trg_vocab = build_vocab(field="trg", min_freq=trg_min_freq,
                            max_size=trg_max_size, dataset=train_data,
                            vocab_file=trg_vocab_file)
    src_vocab = build_vocab(field="src", min_freq=src_min_freq,
                            max_size=src_max_size, dataset=train_data,
                            vocab_file=src_vocab_file)
    #src_vocab = trg_vocab

    dev_data = AudioDataset(path=dev_path, text_ext="." + audio_lang,
                            audio_ext=".txt", sfield=src_field,
                            tfield=trg_field, num=number, char_level=char,
                            train=False, check=check_ratio,
                            audio_level=audio_features, htk=htk, scale=scale)

    test_data = None
    if test_path is not None:
        # check if target exists
        if os.path.isfile(test_path + "." + audio_lang):
            test_data = AudioDataset(path=test_path, text_ext="." + audio_lang,
                                     audio_ext=".txt", sfield=src_field,
                                     tfield=trg_field, num=number,
                                     char_level=char, train=False,
                                     check=check_ratio,
                                     audio_level=audio_features,
                                     htk=htk, scale=scale)
        else:
            # no target is given -> create dataset from src only
            test_data = MonoAudioDataset(path=test_path, audio_ext=".txt",
                                         field=src_field, num=number,
                                         char_level=char)

    trg_field.vocab = trg_vocab
    src_field.vocab = src_vocab

    return train_data, dev_data, test_data, src_vocab, trg_vocab
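# Sketch of the nested config that load_audio_data above expects (values are
# hypothetical). The encoder's embedding_dim doubles as the number of audio
# features per frame and must not exceed 80 Mel bands.
def _example_load_audio_data():
    cfg = {
        "model": {"encoder": {"embeddings": {"embedding_dim": 80}}},
        "data": {
            "src": "en", "trg": "de", "audio": "src",
            "train": "data/train", "dev": "data/dev",
            "level": "char", "lowercase": True,
            "audio_features_level": "mel", "use_htk": False,
        },
    }
    return load_audio_data(cfg)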
def load_data(data_cfg: dict, datasets: list = None) \
        -> (Dataset, Dataset, Optional[Dataset], Vocabulary, Vocabulary):
    """
    Load train, dev and optionally test data as specified in configuration.
    Vocabularies are created from the training set with a limit of `voc_limit`
    tokens and a minimum token frequency of `voc_min_freq`
    (specified in the configuration dictionary).

    The training data is filtered to include sentences up to `max_sent_length`
    on source and target side.

    If you set ``random_train_subset``, a random selection of this size is
    used from the training set instead of the full training set.

    :param data_cfg: configuration dictionary for data
        ("data" part of configuration file)
    :param datasets: list of dataset names to load
    :return:
        - train_data: training dataset
        - dev_data: development dataset
        - test_data: test dataset if given, otherwise None
        - src_vocab: source vocabulary extracted from training data
        - trg_vocab: target vocabulary extracted from training data
    """
    if datasets is None:
        datasets = ["train", "dev", "test"]

    # load data from files
    src_lang = data_cfg["src"]
    trg_lang = data_cfg["trg"]
    train_path = data_cfg.get("train", None)
    dev_path = data_cfg.get("dev", None)
    test_path = data_cfg.get("test", None)

    if train_path is None and dev_path is None and test_path is None:
        raise ValueError('Please specify at least one data source path.')

    level = data_cfg["level"]
    lowercase = data_cfg["lowercase"]
    max_sent_length = data_cfg["max_sent_length"]

    tok_fun = lambda s: list(s) if level == "char" else s.split()

    src_field = data.Field(init_token=None, eos_token=EOS_TOKEN,
                           pad_token=PAD_TOKEN, tokenize=tok_fun,
                           batch_first=True, lower=lowercase,
                           unk_token=UNK_TOKEN, include_lengths=True)

    trg_field = data.Field(init_token=BOS_TOKEN, eos_token=EOS_TOKEN,
                           pad_token=PAD_TOKEN, tokenize=tok_fun,
                           unk_token=UNK_TOKEN, batch_first=True,
                           lower=lowercase, include_lengths=True)

    train_data = None
    if "train" in datasets and train_path is not None:
        logger.info("Loading training data...")
        train_data = TranslationDataset(
            path=train_path, exts=("." + src_lang, "." + trg_lang),
            fields=(src_field, trg_field),
            filter_pred=lambda x: len(vars(x)['src']) <= max_sent_length
            and len(vars(x)['trg']) <= max_sent_length)

        random_train_subset = data_cfg.get("random_train_subset", -1)
        if random_train_subset > -1:
            # select this many training examples randomly and discard the rest
            keep_ratio = random_train_subset / len(train_data)
            keep, _ = train_data.split(
                split_ratio=[keep_ratio, 1 - keep_ratio],
                random_state=random.getstate())
            train_data = keep

    src_max_size = data_cfg.get("src_voc_limit", sys.maxsize)
    src_min_freq = data_cfg.get("src_voc_min_freq", 1)
    trg_max_size = data_cfg.get("trg_voc_limit", sys.maxsize)
    trg_min_freq = data_cfg.get("trg_voc_min_freq", 1)

    src_vocab_file = data_cfg.get("src_vocab", None)
    trg_vocab_file = data_cfg.get("trg_vocab", None)

    assert (train_data is not None) or (src_vocab_file is not None)
    assert (train_data is not None) or (trg_vocab_file is not None)

    logger.info("Building vocabulary...")
    src_vocab = build_vocab(field="src", min_freq=src_min_freq,
                            max_size=src_max_size, dataset=train_data,
                            vocab_file=src_vocab_file)
    trg_vocab = build_vocab(field="trg", min_freq=trg_min_freq,
                            max_size=trg_max_size, dataset=train_data,
                            vocab_file=trg_vocab_file)

    dev_data = None
    if "dev" in datasets and dev_path is not None:
        logger.info("Loading dev data...")
        dev_data = TranslationDataset(path=dev_path,
                                      exts=("." + src_lang, "." + trg_lang),
                                      fields=(src_field, trg_field))

    test_data = None
    if "test" in datasets and test_path is not None:
        logger.info("Loading test data...")
        # check if target exists
        if os.path.isfile(test_path + "." + trg_lang):
            test_data = TranslationDataset(
                path=test_path, exts=("." + src_lang, "." + trg_lang),
                fields=(src_field, trg_field))
        else:
            # no target is given -> create dataset from src only
            test_data = MonoDataset(path=test_path, ext="." + src_lang,
                                    field=src_field)

    src_field.vocab = src_vocab
    trg_field.vocab = trg_vocab
    logger.info("Data loaded.")
    return train_data, dev_data, test_data, src_vocab, trg_vocab
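# Sketch: with the `datasets` argument above, vocabularies can be read from
# files and only the test split is loaded (paths are hypothetical).
def _example_test_only_load_data():
    data_cfg = {
        "src": "de", "trg": "en",
        "test": "data/test",
        "src_vocab": "model/src_vocab.txt",
        "trg_vocab": "model/trg_vocab.txt",
        "level": "bpe", "lowercase": False, "max_sent_length": 100,
    }
    _, _, test_data, src_vocab, trg_vocab = load_data(data_cfg,
                                                      datasets=["test"])
    return test_data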
def run_bot(model_dir, bpe_src_code=None, tokenize=None):
    """
    Start the bot. This means loading the model according to the config file.

    :param model_dir: Model directory of trained Joey NMT model.
    :param bpe_src_code: BPE codes for source side processing (optional).
    :param tokenize: If True, tokenize inputs with Moses tokenizer.
    :return:
    """
    cfg_file = model_dir + "/config.yaml"
    logger = logging.getLogger(__name__)

    # load the Joey configuration
    cfg = load_config(cfg_file)

    # load the checkpoint
    if "load_model" in cfg['training'].keys():
        ckpt = cfg['training']["load_model"]
    else:
        ckpt = get_latest_checkpoint(model_dir)
        if ckpt is None:
            raise FileNotFoundError(
                "No checkpoint found in directory {}.".format(model_dir))

    # prediction parameters from config
    use_cuda = cfg["training"].get("use_cuda", False)
    level = cfg["data"]["level"]
    max_output_length = cfg["training"].get("max_output_length", None)
    lowercase = cfg["data"].get("lowercase", False)

    # load the vocabularies
    src_vocab_file = cfg["training"]["model_dir"] + "/src_vocab.txt"
    trg_vocab_file = cfg["training"]["model_dir"] + "/trg_vocab.txt"
    src_vocab = build_vocab(field="src", vocab_file=src_vocab_file,
                            dataset=None, max_size=-1, min_freq=0)
    trg_vocab = build_vocab(field="trg", vocab_file=trg_vocab_file,
                            dataset=None, max_size=-1, min_freq=0)

    # whether to use beam search for decoding, 0: greedy decoding
    if "testing" in cfg.keys():
        beam_size = cfg["testing"].get("beam_size", 0)
        beam_alpha = cfg["testing"].get("alpha", -1)
    else:
        beam_size = 1
        beam_alpha = -1

    # pre-processing
    if tokenize is not None:
        src_tokenizer = MosesTokenizer(lang=cfg["data"]["src"])
        trg_tokenizer = MosesDetokenizer(lang=cfg["data"]["trg"])
        # tokenize input
        tokenizer = lambda x: src_tokenizer.tokenize(x, return_str=True)
        detokenizer = lambda x: trg_tokenizer.detokenize(x.split(),
                                                         return_str=True)
    else:
        tokenizer = lambda x: x
        detokenizer = lambda x: x

    if bpe_src_code is not None and level == "bpe":
        # load bpe merge file
        merge_file = open(bpe_src_code, "r")
        bpe = apply_bpe.BPE(codes=merge_file)
        segmenter = lambda x: bpe.process_line(x.strip())
    elif level == "char":
        # split to chars
        segmenter = lambda x: list(x.strip())
    else:
        segmenter = lambda x: x.strip()

    # build model and load parameters into it
    model_checkpoint = load_checkpoint(ckpt, use_cuda)
    model = build_model(cfg["model"], src_vocab=src_vocab, trg_vocab=trg_vocab)
    model.load_state_dict(model_checkpoint["model_state"])
    if use_cuda:
        model.cuda()
    print("Joey NMT model loaded successfully.")

    web_client = slack.WebClient(TOKEN, timeout=30)

    # get bot id
    bot_id = (web_client.api_call("auth.test")["user_id"].upper())

    # find bot channel id
    all_channels = web_client.api_call("conversations.list")["channels"]
    for c in all_channels:
        if c["name"] == BOT_CHANNEL:
            bot_channel_id = c["id"]

    slack_events_adapter = SlackEventAdapter(BOT_SIGNIN,
                                             endpoint="/slack/events")

    @slack_events_adapter.on("message")
    def handle_message(event_data):
        message = event_data["event"]
        if message.get("subtype") is None:
            channel = message["channel"]
            user = message["user"]
            text = message["text"].strip()
            if user != bot_id and message.get("subtype") is None:
                # translates all messages in its channel and mentions
                if channel == bot_channel_id or bot_id in text:
                    mention = "<@{}>".format(bot_id)
                    # TODO remove all possible mentions with regex
                    if mention in text:
                        parts = text.split(mention)
                        text = parts[0].strip() + parts[1].strip()
                    message = translate(text, beam_size=beam_size,
                                        beam_alpha=beam_alpha, level=level,
                                        lowercase=lowercase,
                                        max_output_length=max_output_length,
                                        model=model,
                                        postprocess=[detokenizer],
                                        preprocess=[tokenizer, segmenter],
                                        src_vocab=src_vocab,
                                        trg_vocab=trg_vocab,
                                        use_cuda=use_cuda, logger=logger)
                    web_client.chat_postMessage(text=message, token=TOKEN,
                                                channel=channel)

    # Error events
    @slack_events_adapter.on("error")
    def error_handler(err):
        print("ERROR: " + str(err))

    slack_events_adapter.start(port=3000)
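# Launch sketch for the Slack bot above (assumes TOKEN, BOT_SIGNIN and
# BOT_CHANNEL are defined at module level, e.g. read from environment
# variables; the model path is hypothetical).
if __name__ == "__main__":
    run_bot(model_dir="models/en-de",
            bpe_src_code="models/en-de/bpe.codes",
            tokenize=True)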
def load_model(self, src_language, trg_language, domain, bpe_src_code=None,
               tokenize=None):
    """
    Load model for the given source language, target language and domain.
    """
    model_dir = f"{self._model_dir_prefix}{src_language}-{trg_language}-{domain}"

    # Load the checkpoint.
    ckpt_path = os.path.join(model_dir, 'model.ckpt')

    # Load the vocabularies.
    src_vocab_path = os.path.join(model_dir, 'src_vocab.txt')
    trg_vocab_path = os.path.join(model_dir, 'trg_vocab.txt')

    # Load the config.
    config_path = os.path.join(model_dir, 'config_orig.yaml')

    # Adjust config.
    config = load_config(config_path)
    new_config_file = os.path.join(model_dir, 'config.yaml')
    config = self._update_config(config, src_vocab_path, trg_vocab_path,
                                 model_dir, ckpt_path)
    with open(new_config_file, 'w') as cfile:
        yaml.dump(config, cfile)

    print('Loaded model for {}-{}.'.format(src_language, trg_language))

    conf = {}
    logger = logging.getLogger(__name__)
    conf["logger"] = logger

    # load the Joey configuration
    cfg = load_config(new_config_file)

    # load the checkpoint
    if "load_model" in cfg['training'].keys():
        ckpt = cfg['training']["load_model"]
    else:
        ckpt = get_latest_checkpoint(model_dir)
        if ckpt is None:
            raise FileNotFoundError("No checkpoint found in directory {}."
                                    .format(model_dir))

    # prediction parameters from config
    conf["use_cuda"] = cfg["training"].get("use_cuda", False) \
        if torch.cuda.is_available() else False
    conf["level"] = cfg["data"]["level"]
    conf["max_output_length"] = cfg["training"].get("max_output_length", None)
    conf["lowercase"] = cfg["data"].get("lowercase", False)

    # load the vocabularies
    src_vocab_file = cfg["training"]["model_dir"] + "/src_vocab.txt"
    trg_vocab_file = cfg["training"]["model_dir"] + "/trg_vocab.txt"
    conf["src_vocab"] = build_vocab(field="src", vocab_file=src_vocab_file,
                                    dataset=None, max_size=-1, min_freq=0)
    conf["trg_vocab"] = build_vocab(field="trg", vocab_file=trg_vocab_file,
                                    dataset=None, max_size=-1, min_freq=0)

    # whether to use beam search for decoding, 0: greedy decoding
    if "testing" in cfg.keys():
        conf["beam_size"] = cfg["testing"].get("beam_size", 0)
        conf["beam_alpha"] = cfg["testing"].get("alpha", -1)
    else:
        conf["beam_size"] = 1
        conf["beam_alpha"] = -1

    # pre-processing
    if tokenize is not None:
        src_tokenizer = MosesTokenizer(lang=cfg["data"]["src"])
        trg_tokenizer = MosesDetokenizer(lang=cfg["data"]["trg"])
        # tokenize input
        tokenizer = lambda x: src_tokenizer.tokenize(x, return_str=True)
        detokenizer = lambda x: trg_tokenizer.detokenize(
            x.split(), return_str=True)
    else:
        tokenizer = lambda x: x
        detokenizer = lambda x: x

    if bpe_src_code is not None and conf["level"] == "bpe":
        # load bpe merge file
        merge_file = open(bpe_src_code, "r")
        bpe = apply_bpe.BPE(codes=merge_file)
        segmenter = lambda x: bpe.process_line(x.strip())
    elif conf["level"] == "char":
        # split to chars
        segmenter = lambda x: list(x.strip())
    else:
        segmenter = lambda x: x.strip()

    conf["preprocess"] = [tokenizer, segmenter]
    conf["postprocess"] = [detokenizer]

    # build model and load parameters into it
    model_checkpoint = load_checkpoint(ckpt, conf["use_cuda"])
    model = build_model(cfg["model"], src_vocab=conf["src_vocab"],
                        trg_vocab=conf["trg_vocab"])
    model.load_state_dict(model_checkpoint["model_state"])
    if conf["use_cuda"]:
        model.cuda()
    conf["model"] = model
    print("Joey NMT model loaded successfully.")
    return conf
def load_unsupervised_data(data_cfg: dict) \
        -> (Dataset, Dataset, Dataset, Dataset, Dataset, Dataset,
            Optional[Dataset], Optional[Dataset],
            Vocabulary, Vocabulary, dict):
    """
    Load train, dev and optionally test data as specified in configuration.
    Expected file extensions for train data are
    `.(noised|denoised).(src|trg)`.
    Vocabularies are created from the training set with a limit of `voc_limit`
    tokens and a minimum token frequency of `voc_min_freq`
    (specified in the configuration dictionary).

    The training data is filtered to include sentences up to `max_sent_length`
    on source and target side. All four resulting training corpora have to
    have the same length. Selecting a random subset of the training data is
    not supported.

    :param data_cfg: configuration dictionary for data
    :return:
        - src2src: Dataset for src to src denoising task
        - trg2trg: Dataset for trg to trg denoising task
        - BTsrc: Monolingual dataset containing denoised src data for BT
        - BTtrg: Monolingual dataset containing denoised trg data for BT
        - dev_src2trg: Dataset for src to trg validation
        - dev_trg2src: Dataset for trg to src validation
        - test_src2trg: Dataset for testing src to trg translation, optional
        - test_trg2src: Dataset for testing trg to src translation, optional
        - src_vocab: Vocabulary of src language
        - trg_vocab: Vocabulary of trg language
        - fields: Dictionary containing source and target fields for src and
          trg language, needed for on-the-fly BT
    """
    src_lang = data_cfg["src"]
    trg_lang = data_cfg["trg"]
    noised_ext = data_cfg["noised"]
    denoised_ext = data_cfg["denoised"]
    assert noised_ext != denoised_ext

    train_path = data_cfg["train"]
    src2trg_dev_path = data_cfg["src2trg_dev"]
    trg2src_dev_path = data_cfg["trg2src_dev"]
    src2trg_test_path = data_cfg.get("src2trg_test", None)
    trg2src_test_path = data_cfg.get("trg2src_test", None)

    level = data_cfg["level"]
    lowercase = data_cfg["lowercase"]
    max_sent_length = data_cfg["max_sent_length"]

    tok_fun = lambda s: list(s) if level == "char" else s.split()

    # Make four fields:
    # src and trg language each get a source and a target field,
    # because the field vocabulary needs to come once from the src language
    # and once from the trg language.

    # Source fields:
    # for src language
    src_src_field = data.Field(init_token=None, eos_token=EOS_TOKEN,
                               pad_token=PAD_TOKEN, tokenize=tok_fun,
                               batch_first=True, lower=lowercase,
                               unk_token=UNK_TOKEN, include_lengths=True)
    # for trg language
    trg_src_field = data.Field(init_token=None, eos_token=EOS_TOKEN,
                               pad_token=PAD_TOKEN, tokenize=tok_fun,
                               batch_first=True, lower=lowercase,
                               unk_token=UNK_TOKEN, include_lengths=True)

    # Target fields:
    # for src language
    src_trg_field = data.Field(init_token=BOS_TOKEN, eos_token=EOS_TOKEN,
                               pad_token=PAD_TOKEN, tokenize=tok_fun,
                               unk_token=UNK_TOKEN, batch_first=True,
                               lower=lowercase, include_lengths=True)
    # for trg language
    trg_trg_field = data.Field(init_token=BOS_TOKEN, eos_token=EOS_TOKEN,
                               pad_token=PAD_TOKEN, tokenize=tok_fun,
                               unk_token=UNK_TOKEN, batch_first=True,
                               lower=lowercase, include_lengths=True)

    fields = {'src': {src_lang: src_src_field, trg_lang: trg_src_field},
              'trg': {src_lang: src_trg_field, trg_lang: trg_trg_field}}

    # datasets for denoising:
    # 'translate' from noised input to denoised output
    src2src = TranslationDataset(
        path=train_path,
        exts=("." + noised_ext + "." + src_lang,
              "." + denoised_ext + "." + src_lang),
        fields=(fields['src'][src_lang], fields['trg'][src_lang]),
        filter_pred=lambda x: len(vars(x)['src']) <= max_sent_length
        and len(vars(x)['trg']) <= max_sent_length)

    trg2trg = TranslationDataset(
        path=train_path,
        exts=("." + noised_ext + "." + trg_lang,
              "." + denoised_ext + "." + trg_lang),
        fields=(fields['src'][trg_lang], fields['trg'][trg_lang]),
        filter_pred=lambda x: len(vars(x)['src']) <= max_sent_length
        and len(vars(x)['trg']) <= max_sent_length)

    # datasets for BT:
    # need denoised sources in order to create back-translations on-the-fly,
    # then use (BT, denoised sources) tuples as training examples,
    # so for now, create monolingual datasets of the denoised sources
    BTsrc = MonoDataset(path=train_path,
                        ext="." + denoised_ext + "." + src_lang,
                        field=fields['src'][src_lang],
                        filter_pred=lambda x:
                        len(vars(x)['src']) <= max_sent_length)

    BTtrg = MonoDataset(path=train_path,
                        ext="." + denoised_ext + "." + trg_lang,
                        field=fields['src'][trg_lang],
                        filter_pred=lambda x:
                        len(vars(x)['src']) <= max_sent_length)

    src_max_size = data_cfg.get("src_voc_limit", sys.maxsize)
    src_min_freq = data_cfg.get("src_voc_min_freq", 1)
    trg_max_size = data_cfg.get("trg_voc_limit", sys.maxsize)
    trg_min_freq = data_cfg.get("trg_voc_min_freq", 1)

    src_vocab_file = data_cfg.get("src_vocab", None)
    trg_vocab_file = data_cfg.get("trg_vocab", None)

    # build vocab based on denoised data (field="trg")
    src_vocab = build_vocab(field="trg", min_freq=src_min_freq,
                            max_size=src_max_size, dataset=src2src,
                            vocab_file=src_vocab_file)
    trg_vocab = build_vocab(field="trg", min_freq=trg_min_freq,
                            max_size=trg_max_size, dataset=trg2trg,
                            vocab_file=trg_vocab_file)

    assert len(src2src) == len(trg2trg) == len(BTsrc) == len(BTtrg), \
        "All training sets must have equal length for unsupervised NMT."

    dev_src2trg = TranslationDataset(
        path=src2trg_dev_path, exts=("." + src_lang, "." + trg_lang),
        fields=(fields['src'][src_lang], fields['trg'][trg_lang]))

    dev_trg2src = TranslationDataset(
        path=trg2src_dev_path, exts=("." + trg_lang, "." + src_lang),
        fields=(fields['src'][trg_lang], fields['trg'][src_lang]))

    def _make_test_set(test_path: str, src_lang: str, trg_lang: str) \
            -> Optional[Dataset]:
        if test_path is not None:
            if os.path.isfile(test_path + "." + trg_lang):
                return TranslationDataset(
                    path=test_path, exts=("." + src_lang, "." + trg_lang),
                    fields=(fields['src'][src_lang], fields['trg'][trg_lang]))
            else:
                return MonoDataset(path=test_path, ext="." + src_lang,
                                   field=fields['src'][src_lang])
        else:
            return None

    test_src2trg = _make_test_set(src2trg_test_path, src_lang, trg_lang)
    test_trg2src = _make_test_set(trg2src_test_path, trg_lang, src_lang)

    # set vocab of all fields
    # this is why we need four fields in total
    src_src_field.vocab = src_vocab
    trg_src_field.vocab = trg_vocab
    src_trg_field.vocab = src_vocab
    trg_trg_field.vocab = trg_vocab

    return src2src, trg2trg, BTsrc, BTtrg, \
        dev_src2trg, dev_trg2src, \
        test_src2trg, test_trg2src, \
        src_vocab, trg_vocab, \
        fields
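# Call sketch for the unsupervised loader above (file names are hypothetical).
# The train prefix must exist with the four extensions .noised.<src>,
# .denoised.<src>, .noised.<trg> and .denoised.<trg>.
def _example_load_unsupervised_data():
    data_cfg = {
        "src": "de", "trg": "en",
        "noised": "noised", "denoised": "denoised",
        "train": "data/mono",
        "src2trg_dev": "data/dev_deen", "trg2src_dev": "data/dev_ende",
        "level": "bpe", "lowercase": False, "max_sent_length": 60,
    }
    (src2src, trg2trg, bt_src, bt_trg, dev_s2t, dev_t2s,
     test_s2t, test_t2s, src_vocab, trg_vocab, fields) = \
        load_unsupervised_data(data_cfg)
    return fields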
def load_data(data_cfg: dict) \
        -> (Dataset, Dataset, Optional[Dataset], Vocabulary, Vocabulary):
    """
    Load train, dev and optionally test data as specified in configuration.
    Vocabularies are created from the training set with a limit of `voc_limit`
    tokens and a minimum token frequency of `voc_min_freq`
    (specified in the configuration dictionary).

    The training data is filtered to include sentences up to `max_sent_length`
    on source and target side.

    If you set ``random_train_subset``, a random selection of this size is
    used from the training set instead of the full training set.

    :param data_cfg: configuration dictionary for data
        ("data" part of configuration file)
    :return:
        - train_data: training dataset
        - dev_data: development dataset
        - test_data: test dataset if given, otherwise None
        - src_vocab: source vocabulary extracted from training data
        - trg_vocab: target vocabulary extracted from training data
        - train_kb: TranslationDataset from train KB
        - dev_kb: TranslationDataset from dev KB
        - test_kb: TranslationDataset from test KB
        - train_kb_lookup: List of KB association lookup indices from data to KBs
        - dev_kb_lookup: List of KB association lookup indices from data to KBs
        - test_kb_lookup: List of KB association lookup indices from data to KBs
        - train_kb_lengths: List of KB lengths
        - dev_kb_lengths: List of KB lengths
        - test_kb_lengths: List of KB lengths
        - train_kb_truvals: MonoDataset of true (non-canonical) train KB values
        - dev_kb_truvals: MonoDataset of true (non-canonical) dev KB values
        - test_kb_truvals: MonoDataset of true (non-canonical) test KB values
        - trv_vocab: vocabulary over the true KB values (None without kb_task)
        - Canonizer: callable class that canonizes token sequences
        - dev_data_canon: canonized dev data for loss reporting
        - test_data_canon: canonized test data for loss reporting
    """
    # load data from files
    src_lang = data_cfg["src"]
    trg_lang = data_cfg["trg"]
    train_path = data_cfg["train"]
    dev_path = data_cfg["dev"]
    test_path = data_cfg.get("test", None)
    level = data_cfg["level"]
    lowercase = data_cfg["lowercase"]
    max_sent_length = data_cfg.get("max_sent_length", sys.maxsize * .1)

    # kb stuff
    kb_task = bool(data_cfg.get("kb_task", False))

    if kb_task:
        kb_src = data_cfg.get("kb_src", "kbk")
        kb_trg = data_cfg.get("kb_trg", "kbv")
        kb_lkp = data_cfg.get("kb_lkp", "lkp")
        kb_len = data_cfg.get("kb_len", "len")
        kb_trv = data_cfg.get("kb_truvals", "trv")
        global_trv = data_cfg.get("global_trv", "")
        if global_trv:
            print("UserWarning: the global_trv parameter is deprecated "
                  "and can simply be omitted.")
        trutrg = data_cfg.get("trutrg", "car")
        canonization_mode = data_cfg.get("canonization_mode", "canonize")
        assert canonization_mode in ["canonize", "hash"], canonization_mode
        # TODO FIXME the following is hardcoded; add it to the configs
        pnctprepro = True
    else:
        # the remaining KB variables are set at the end of load_data
        # for the non-KB case
        pnctprepro = False

    # default joeyNMT behaviour for sentences
    tok_fun = list if level == "char" else (
        pkt_tokenize if pnctprepro else tokenize)

    src_field = data.Field(init_token=None, eos_token=EOS_TOKEN,
                           pad_token=PAD_TOKEN, tokenize=tok_fun,
                           batch_first=True, lower=lowercase,
                           unk_token=UNK_TOKEN, include_lengths=True)

    trg_field = data.Field(init_token=BOS_TOKEN, eos_token=EOS_TOKEN,
                           pad_token=PAD_TOKEN, tokenize=tok_fun,
                           unk_token=UNK_TOKEN, batch_first=True,
                           lower=lowercase, include_lengths=True)

    if kb_task:
        # NOTE lowercase MUST be False for datasets with tokens that may
        # include whitespace! The torchtext lowercase pipeline seems to
        # operate not just on the first char of a token (dataset field level)
        # but lowercases individual words separated by whitespace WITHIN a
        # specified token, which leads to the vocab not recognizing tokens
        # even though they were added to field.vocab via
        # joeynmt.vocabulary.build_vocab. Another way to circumvent this may
        # be to lowercase in the same manner before calling field.process.
        trv_field = data.Field(init_token=BOS_TOKEN, eos_token=EOS_TOKEN,
                               pad_token=PAD_TOKEN,
                               tokenize=lambda entire_line: [entire_line],
                               unk_token=UNK_TOKEN, batch_first=True,
                               lower=False, include_lengths=False)

    train_data = TranslationDataset(
        path=train_path, exts=("." + src_lang, "." + trg_lang),
        fields=(src_field, trg_field),
        filter_pred=lambda x: len(vars(x)['src']) <= max_sent_length
        and len(vars(x)['trg']) <= max_sent_length)

    if kb_task:
        # load train_kb and metadata
        # NOTE change trg_lang to trutrg for dev/test:
        # train_data has been loaded with the normal extension
        # (canonized files, e.g. train.carno);
        # dev/test_data will be loaded from non-canonized files
        canon_trg = trg_lang  # keep this for loss reporting
        trg_lang = trutrg

        train_kb_truvals = MonoDataset(path=train_path, ext=("." + kb_trv),
                                       field=("kbtrv", trv_field),
                                       filter_pred=lambda x: True)
        train_kb = TranslationDataset(path=train_path,
                                      exts=("." + kb_src, "." + kb_trg),
                                      fields=(("kbsrc", src_field),
                                              ("kbtrg", trg_field)),
                                      filter_pred=lambda x: True)
        with open(train_path + "." + kb_lkp, "r") as lkp:
            lookup = lkp.readlines()
        train_kb_lookup = [int(elem[:-1]) for elem in lookup if elem[:-1]]
        with open(train_path + "." + kb_len, "r") as lens:
            lengths = lens.readlines()
        train_kb_lengths = [int(elem[:-1]) for elem in lengths if elem[:-1]]

    # now that we have train data, build the vocabulary from it;
    # dev and test data are handled further below
    src_max_size = data_cfg.get("src_voc_limit", sys.maxsize)
    src_min_freq = data_cfg.get("src_voc_min_freq", 1)
    trg_max_size = data_cfg.get("trg_voc_limit", sys.maxsize)
    trg_min_freq = data_cfg.get("trg_voc_min_freq", 1)

    src_vocab_file = data_cfg.get("src_vocab", None)  # NOTE unused
    trg_vocab_file = data_cfg.get("trg_vocab", None)
    trg_kb_vocab_file = data_cfg.get("trg_kb_vocab", None)
    # prefer to use the joint trg_kb_vocab_file if specified
    trg_vocab_file = trg_vocab_file if not trg_kb_vocab_file \
        else trg_kb_vocab_file

    vocab_building_datasets = train_data if not kb_task \
        else (train_data, train_kb)
    vocab_building_src_fields = "src" if not kb_task else ("src", "kbsrc")
    vocab_building_trg_fields = "trg" if not kb_task else ("trg", "kbtrg")

    src_vocab = build_vocab(fields=vocab_building_src_fields,
                            min_freq=src_min_freq, max_size=src_max_size,
                            dataset=vocab_building_datasets,
                            vocab_file=src_vocab_file)
    trg_vocab = build_vocab(fields=vocab_building_trg_fields,
                            min_freq=trg_min_freq, max_size=trg_max_size,
                            dataset=vocab_building_datasets,
                            vocab_file=trg_vocab_file)

    random_train_subset = data_cfg.get("random_train_subset", -1)
    if random_train_subset > -1:
        # select this many training examples randomly and discard the rest
        keep_ratio = random_train_subset / len(train_data)
        keep, _ = train_data.split(split_ratio=[keep_ratio, 1 - keep_ratio],
                                   random_state=random.getstate())
        train_data = keep

    dev_data = TranslationDataset(path=dev_path,
                                  exts=("." + src_lang, "." + trg_lang),
                                  fields=(src_field, trg_field))

    if kb_task:
        # load dev kb and metadata; load canonized dev data for loss reporting
        dev_data_canon = TranslationDataset(path=dev_path,
                                            exts=("." + src_lang,
                                                  "." + canon_trg),
                                            fields=(src_field, trg_field))
        dev_kb = TranslationDataset(path=dev_path,
                                    exts=("." + kb_src, "." + kb_trg),
                                    fields=(("kbsrc", src_field),
                                            ("kbtrg", trg_field)),
                                    filter_pred=lambda x: True)
        dev_kb_truvals = MonoDataset(path=dev_path, ext=("." + kb_trv),
                                     field=("kbtrv", trv_field),
                                     filter_pred=lambda x: True)
        with open(dev_path + "." + kb_lkp, "r") as lkp:
            lookup = lkp.readlines()
        dev_kb_lookup = [int(elem[:-1]) for elem in lookup if elem[:-1]]
        with open(dev_path + "." + kb_len, "r") as lens:
            lengths = lens.readlines()
        dev_kb_lengths = [int(elem[:-1]) for elem in lengths if elem[:-1]]

    test_data = None
    if test_path is not None:
        # check if target exists
        if os.path.isfile(test_path + "." + trg_lang):
            test_data = TranslationDataset(
                path=test_path, exts=("." + src_lang, "." + trg_lang),
                fields=(src_field, trg_field))
        else:
            # no target is given -> create dataset from src only
            test_data = MonoDataset(path=test_path, ext="." + src_lang,
                                    field=src_field)

        if kb_task:
            # load test kb and metadata
            test_data_canon = TranslationDataset(path=test_path,
                                                 exts=("." + src_lang,
                                                       "." + canon_trg),
                                                 fields=(src_field, trg_field))
            test_kb = TranslationDataset(path=test_path,
                                         exts=("." + kb_src, "." + kb_trg),
                                         fields=(("kbsrc", src_field),
                                                 ("kbtrg", trg_field)),
                                         filter_pred=lambda x: True)
            test_kb_truvals = MonoDataset(path=test_path, ext=("." + kb_trv),
                                          field=("kbtrv", trv_field),
                                          filter_pred=lambda x: True)
            with open(test_path + "." + kb_lkp, "r") as lkp:
                lookup = lkp.readlines()
            test_kb_lookup = [int(elem[:-1]) for elem in lookup if elem[:-1]]
            with open(test_path + "." + kb_len, "r") as lens:
                lengths = lens.readlines()
            test_kb_lengths = [int(elem[:-1]) for elem in lengths
                               if elem[:-1]]

    # finally actually set the .vocab field attributes
    # (also sets kb_src_field.vocab if they're the same object)
    src_field.vocab = src_vocab
    trg_field.vocab = trg_vocab

    if kb_task:
        # NOTE this vocab is built from the concatenation of the
        # train+dev+test trv files!
        trv_ext = "." + kb_trv
        trv_train_path = train_path + trv_ext
        trv_dev_path = dev_path + trv_ext
        trv_test_path = test_path + trv_ext
        assert os.path.isfile(trv_train_path)

        # try to make the vocabulary exactly as large as needed
        trv_vocab = deepcopy(trg_vocab)
        # FIXME only add this for source copying?
        trv_vocab._from_list(src_vocab.itos)
        trv_vocab._from_file(trv_train_path)
        trv_vocab._from_file(trv_dev_path)
        trv_vocab._from_file(trv_test_path)
        if canonization_mode == "canonize":  # stanford data
            assert "schedule" in trv_vocab.itos
        # NOTE really important for model.postprocess:
        # trv_vocab must begin with trg_vocab
        # to look up canonical tokens correctly
        assert trg_vocab.itos == trv_vocab.itos[:len(trg_vocab)]
        print(f"Added true value lines as tokens to trv_vocab of "
              f"length={len(trv_vocab)}")
        trv_field.vocab = trv_vocab

    if kb_task:
        # make canonization function to create a KB from the source
        # for batches without one
        entities_path = "data/kvr/kvret_entities_altered.json"  # TODO FIXME add to config
        entities = load_json(fp=entities_path)
        efficient_entities = preprocess_entity_dict(entities,
                                                    lower=lowercase,
                                                    tok_fun=tok_fun)

        if canonization_mode == "hash":
            # initialize with train knowledgebases
            hash_vocab = build_vocab(max_size=4096, dataset=train_kb,
                                     fields=vocab_building_trg_fields,
                                     min_freq=1, vocab_file=trv_train_path)
            hash_vocab._from_file(trv_train_path)
            hash_vocab._from_file(trv_dev_path)
            hash_vocab._from_file(trv_test_path)

        if canonization_mode == "canonize":
            class Canonizer:
                def __init__(self, copy_from_source: bool = False):
                    self.copy_from_source = bool(copy_from_source)

                def __call__(self, seq):
                    processed, indices, matches = canonize_sequence(
                        seq, efficient_entities)
                    return processed, indices, matches
        elif canonization_mode == "hash":
            class Canonizer:
                def __init__(self, copy_from_source: bool = False):
                    self.copy_from_source = bool(copy_from_source)

                def __call__(self, seq):
                    processed, indices, matches = hash_canons(
                        seq, hash_vocab.itos)
                    return processed, indices, matches
        else:
            raise ValueError(
                f"canonization mode {canonization_mode} not implemented")

    if not kb_task:
        # default values for the normal pipeline
        train_kb, dev_kb, test_kb = None, None, None
        train_kb_lookup, dev_kb_lookup, test_kb_lookup = [], [], []
        train_kb_lengths, dev_kb_lengths, test_kb_lengths = [], [], []
        train_kb_truvals, dev_kb_truvals, test_kb_truvals = [], [], []
        trv_vocab = None
        dev_data_canon, test_data_canon = [], []
        Canonizer = None

    # FIXME: return a dict here instead
    return train_data, dev_data, test_data, \
        src_vocab, trg_vocab, \
        train_kb, dev_kb, test_kb, \
        train_kb_lookup, dev_kb_lookup, test_kb_lookup, \
        train_kb_lengths, dev_kb_lengths, test_kb_lengths, \
        train_kb_truvals, dev_kb_truvals, test_kb_truvals, \
        trv_vocab, Canonizer, \
        dev_data_canon, test_data_canon
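# Sketch: unpacking the long return tuple of the KB-augmented load_data above
# (paths are hypothetical). With kb_task=False the KB slots come back as None
# or empty lists; with kb_task=True the train/dev/test prefixes also need the
# .kbk/.kbv/.lkp/.len/.trv companion files.
def _example_kb_load_data():
    data_cfg = {"src": "de", "trg": "en",
                "train": "data/train", "dev": "data/dev", "test": "data/test",
                "level": "word", "lowercase": False, "kb_task": False}
    (train_data, dev_data, test_data, src_vocab, trg_vocab,
     train_kb, dev_kb, test_kb,
     train_kb_lookup, dev_kb_lookup, test_kb_lookup,
     train_kb_lengths, dev_kb_lengths, test_kb_lengths,
     train_kb_truvals, dev_kb_truvals, test_kb_truvals,
     trv_vocab, canonizer_cls,
     dev_data_canon, test_data_canon) = load_data(data_cfg)
    return train_data, src_vocab, trg_vocab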
def load_data(data_cfg: dict):
    """
    Load train, dev and optionally test data as specified in configuration.
    Vocabularies are created from the training set with a limit of `voc_limit`
    tokens and a minimum token frequency of `voc_min_freq`
    (specified in the configuration dictionary).

    The training data is filtered to include sentences up to `max_sent_length`
    on source and target side.

    :param data_cfg: configuration dictionary for data
        ("data" part of configuration file)
    :return:
        - train_data: training dataset
        - dev_data: development dataset
        - test_data: test dataset if given, otherwise None
        - vocabs: dictionary from src and trg (and possibly other fields) to
          their corresponding vocab objects
    """
    data_format = data_cfg.get("format", "bitext")
    formats = {"bitext", "sigmorphon", "sigmorphon-simple", "sigmorphon-g2p"}
    assert data_format in formats

    # load data from files
    train_path = data_cfg["train"]
    dev_path = data_cfg["dev"]
    test_path = data_cfg.get("test", None)

    lowercase = data_cfg["lowercase"]
    max_sent_length = data_cfg["max_sent_length"]
    default_level = data_cfg.get("level", "word")
    voc_limit = data_cfg.get("voc_limit", sys.maxsize)
    voc_min_freq = data_cfg.get("voc_min_freq", 1)

    main_fields = ["src", "trg"] if data_format != "sigmorphon" \
        else ["src", "trg", "inflection"]
    label_fields = []

    multilingual = data_cfg.get("multilingual", False)
    if multilingual:
        assert data_format in {"sigmorphon", "sigmorphon-g2p"}
        label_fields.append("language")

    suffixes = {f_name: data_cfg.get(f_name, "") for f_name in main_fields}

    sequential_field = partial(Field, eos_token=EOS_TOKEN,
                               pad_token=PAD_TOKEN, unk_token=UNK_TOKEN,
                               batch_first=True, lower=lowercase,
                               include_lengths=True)

    fields = dict()
    for f_name in main_fields:
        init_token = BOS_TOKEN if f_name == "trg" else None
        if f_name != "inflection":
            current_level = data_cfg.get(f_name + "_level", default_level)
            tok_fun = list if current_level == "char" else str.split
        else:
            tok_fun = partial(str.split, sep=';')
        fields[f_name] = sequential_field(init_token=init_token,
                                          tokenize=tok_fun)

    for f_name in label_fields:
        fields[f_name] = Field(sequential=False)

    filter_ex = partial(filter_example, max_sent_length=max_sent_length)

    if data_format == "bitext":
        dataset_cls = partial(TranslationDataset,
                              exts=("." + suffixes["src"],
                                    "." + suffixes["trg"]),
                              fields=(fields["src"], fields["trg"]))
    else:
        sigmorphon_fields = {k: [(k, v)] for k, v in fields.items()}
        if data_format == "sigmorphon":
            class_name = SigmorphonDataset
        elif data_format == "sigmorphon-g2p":
            class_name = SigmorphonG2PDataset
        else:
            class_name = SimpleSigmorphonDataset
        dataset_cls = partial(class_name, fields=sigmorphon_fields)

    train_data = dataset_cls(path=train_path, filter_pred=filter_ex)

    vocabs = dict()
    vocab_counts = dict()  # language-specific vocab subsets
    for f_name in main_fields:
        vocab_file = data_cfg.get("{}_vocab".format(f_name), None)
        max_size = data_cfg.get("{}_voc_limit".format(f_name), voc_limit)
        min_freq = data_cfg.get("{}_voc_min_freq".format(f_name),
                                voc_min_freq)
        f_vocab, f_vocab_counts = build_vocab(
            field=f_name, min_freq=min_freq, max_size=max_size,
            dataset=train_data, vocab_file=vocab_file,
            multilingual=multilingual)
        vocabs[f_name] = f_vocab
        vocab_counts[f_name] = f_vocab_counts

    for f_name in label_fields:
        vocab_file = data_cfg.get("{}_vocab".format(f_name), None)
        max_size = data_cfg.get("{}_voc_limit".format(f_name), voc_limit)
        min_freq = data_cfg.get("{}_voc_min_freq".format(f_name),
                                voc_min_freq)
        f_vocab, _ = build_vocab(field=f_name, min_freq=min_freq,
                                 max_size=max_size, dataset=train_data,
                                 vocab_file=vocab_file, sequential=False)
        vocabs[f_name] = f_vocab

    dev_data = dataset_cls(path=dev_path)

    if test_path is not None:
        trg_suffix = suffixes["trg"]
        if data_format != "bitext" or isfile(test_path + "." + trg_suffix):
            test_dataset_cls = dataset_cls
        else:
            test_dataset_cls = partial(MonoDataset,
                                       ext="." + suffixes["src"],
                                       field=fields["src"])
        test_data = test_dataset_cls(path=test_path)
    else:
        test_data = None

    for field_name in fields:
        fields[field_name].vocab = vocabs[field_name]

    ret = {
        "train_data": train_data,
        "dev_data": dev_data,
        "test_data": test_data,
        "vocabs": vocabs
    }
    return ret
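# Sketch of a sigmorphon-style config for the load_data above (file locations
# are hypothetical; key names follow the code).
def _example_sigmorphon_load_data():
    data_cfg = {
        "format": "sigmorphon",
        "train": "data/deu.trn", "dev": "data/deu.dev", "test": "data/deu.tst",
        "lowercase": False, "max_sent_length": 50,
        "src_level": "char", "trg_level": "char",
    }
    ret = load_data(data_cfg)
    print({name: len(v) for name, v in ret["vocabs"].items()})
    return ret["train_data"]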
def train(cfg_file: str) -> None:
    """
    Main training function. After training, also test on test data if given.

    :param cfg_file: path to configuration yaml file
    """
    cfg = load_config(cfg_file)

    # make logger
    model_dir = make_model_dir(cfg["training"]["model_dir"],
                               overwrite=cfg["training"].get("overwrite",
                                                             False))
    _ = make_logger(model_dir, mode="train")  # version string returned
    # TODO: save version number in model checkpoints

    # set the random seed
    set_seed(seed=cfg["training"].get("random_seed", 42))

    # load the data: one (train, dev) pair per task
    train_tasks_list = []
    valid_tasks_list = []
    src_tasks = cfg["data"].get("src")
    trg_tasks = cfg["data"].get("trg")
    for x in range(len(src_tasks)):
        src_lang = src_tasks[x]
        trg_lang = trg_tasks[x]
        train_data, dev_data, _, _, _ = load_data(data_cfg=cfg["data"],
                                                  src_lang=src_lang,
                                                  trg_lang=trg_lang)
        train_tasks_list.append(train_data)
        valid_tasks_list.append(dev_data)

    # build vocabulary
    logger.info("Building vocabulary...")
    src_max_size = cfg["data"].get("src_voc_limit", sys.maxsize)
    src_min_freq = cfg["data"].get("src_voc_min_freq", 1)
    trg_max_size = cfg["data"].get("trg_voc_limit", sys.maxsize)
    trg_min_freq = cfg["data"].get("trg_voc_min_freq", 1)
    src_vocab_file = cfg["data"].get("src_vocab", None)
    trg_vocab_file = cfg["data"].get("trg_vocab", None)
    src_vocab = build_vocab(field="src", min_freq=src_min_freq,
                            max_size=src_max_size,
                            dataset=train_tasks_list[0],
                            vocab_file=src_vocab_file)
    trg_vocab = build_vocab(field="trg", min_freq=trg_min_freq,
                            max_size=trg_max_size,
                            dataset=train_tasks_list[0],
                            vocab_file=trg_vocab_file)

    # build an encoder-decoder model
    model = build_model(cfg["model"], src_vocab=src_vocab,
                        trg_vocab=trg_vocab)

    # for training management, e.g. early stopping and model selection
    trainer = TrainManager(model=model, config=cfg)

    # store copy of original training config in model dir
    shutil.copy2(cfg_file, model_dir + "/config.yaml")

    # log all entries of config
    log_cfg(cfg)

    logger.info(str(model))

    # store the vocabs
    src_vocab_file = "{}/src_vocab.txt".format(cfg["training"]["model_dir"])
    src_vocab.to_file(src_vocab_file)
    trg_vocab_file = "{}/trg_vocab.txt".format(cfg["training"]["model_dir"])
    trg_vocab.to_file(trg_vocab_file)

    # train the model
    trainer.maml_train_and_validate(train_tasks=train_tasks_list,
                                    valid_tasks=valid_tasks_list)

    # predict with the best model on validation and test
    # (if test data is available)
    ckpt = "{}/{}.ckpt".format(model_dir, trainer.stats.best_ckpt_iter)
    output_name = "{:08d}.hyps".format(trainer.stats.best_ckpt_iter)
    output_path = os.path.join(model_dir, output_name)
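# Invocation sketch (hypothetical config path). This MAML-style train()
# expects cfg["data"]["src"] and cfg["data"]["trg"] to be parallel lists of
# language codes, one (train, dev) task pair per entry.
if __name__ == "__main__":
    train("configs/maml_multilingual.yaml")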