def testVocabularyFromFile(self):
    # write vocabs to file and create new ones from those files
    self.word_vocab.to_file(self.temp_file_word)
    self.char_vocab.to_file(self.temp_file_char)
    word_vocab2 = Vocabulary.from_file(self.temp_file_word)
    char_vocab2 = Vocabulary.from_file(self.temp_file_char)
    self.assertEqual(self.word_vocab.itos, word_vocab2.itos)
    self.assertEqual(self.char_vocab.itos, char_vocab2.itos)
    os.remove(self.temp_file_char)
    os.remove(self.temp_file_word)
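# Standalone round-trip sketch of the Vocabulary file API exercised by the
# test above. This is a sketch only: the Vocabulary(tokens=...) constructor
# and the joeynmt.vocabulary import path are assumptions, not from the source;
# only to_file/from_file/itos appear in the original test.
import os
from joeynmt.vocabulary import Vocabulary  # import path assumed


def demo_vocab_roundtrip():
    vocab = Vocabulary(tokens=["the", "quick", "brown", "fox"])  # assumed ctor
    path = "tmp_demo_vocab.txt"
    vocab.to_file(path)                    # serialize the vocabulary to disk
    restored = Vocabulary.from_file(path)  # rebuild it from the saved file
    assert vocab.itos == restored.itos     # same index-to-token mapping
    os.remove(path)                        # clean up the temporary file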
def translate(cfg_file, ckpt: str, output_path: str = None) -> None:
    """
    Interactive translation function.
    Loads model from checkpoint and translates either the stdin input or
    asks for input to translate interactively.
    The input has to be pre-processed according to the data that the model
    was trained on, i.e. tokenized or split into subwords.
    Translations are printed to stdout.

    :param cfg_file: path to configuration file
    :param ckpt: path to checkpoint to load
    :param output_path: path to output file; if given, translations of
        stdin input are written there instead of stdout
    """

    def _load_line_as_data(line):
        """ Create a dataset from one line via a temporary file. """
        # write src input to temporary file
        tmp_name = "tmp"
        tmp_suffix = ".src"
        tmp_filename = tmp_name + tmp_suffix
        with open(tmp_filename, "w") as tmp_file:
            tmp_file.write("{}\n".format(line))

        test_data = MonoDataset(path=tmp_name, ext=tmp_suffix,
                                field=src_field)

        # remove temporary file
        if os.path.exists(tmp_filename):
            os.remove(tmp_filename)

        return test_data

    def _translate_data(test_data):
        """ Translates given dataset, using parameters from outer scope. """
        hypotheses = validate_on_data(
            model, data=test_data, batch_size=batch_size, trg_level=level,
            max_output_length=max_output_length, eval_metrics=[],
            use_cuda=use_cuda, loss_function=None, beam_size=beam_size,
            beam_alpha=beam_alpha)[2]
        return hypotheses

    cfg = load_config(cfg_file)

    # when checkpoint is not specified, take latest one from model dir
    if ckpt is None:
        model_dir = cfg["training"]["model_dir"]
        ckpt = get_latest_checkpoint(model_dir)

    batch_size = cfg["training"].get("batch_size", 1)
    use_cuda = cfg["training"].get("use_cuda", False)
    max_output_length = cfg["training"].get("max_output_length", None)

    # read vocabs
    src_vocab_file = cfg["data"].get(
        "src_vocab", cfg["training"]["model_dir"] + "/src_vocab.txt")
    trg_vocab_file = cfg["data"].get(
        "trg_vocab", cfg["training"]["model_dir"] + "/trg_vocab.txt")
    src_vocab = Vocabulary.from_file(src_vocab_file)
    trg_vocab = Vocabulary.from_file(trg_vocab_file)

    data_cfg = cfg["data"]
    level = data_cfg["level"]
    lowercase = data_cfg["lowercase"]

    tok_fun = list if level == "char" else str.split

    src_field = Field(init_token=None, eos_token=EOS_TOKEN,
                      pad_token=PAD_TOKEN, tokenize=tok_fun,
                      batch_first=True, lower=lowercase,
                      unk_token=UNK_TOKEN, include_lengths=True)
    src_field.vocab = src_vocab

    # load model state from disk
    model_checkpoint = load_checkpoint(ckpt, use_cuda=use_cuda)

    # build model and load parameters into it
    model = build_model(cfg["model"],
                        vocabs={"src": src_vocab, "trg": trg_vocab})
    model.load_state_dict(model_checkpoint["model_state"])

    if use_cuda:
        model.cuda()

    # whether to use beam search for decoding, 0: greedy decoding
    if "testing" in cfg.keys():
        beam_size = cfg["testing"].get("beam_size", 0)
        beam_alpha = cfg["testing"].get("alpha", 0)
    else:
        beam_size = 0
        beam_alpha = 0
    if beam_alpha < 0:
        raise ConfigurationError("alpha for length penalty should be >= 0")

    if not sys.stdin.isatty():
        # input file given on stdin
        test_data = MonoDataset(path=sys.stdin, ext="", field=src_field)
        hypotheses = _translate_data(test_data)

        if output_path is not None:
            with open(output_path, mode="w", encoding="utf-8") as out_file:
                for hyp in hypotheses:
                    out_file.write(hyp + "\n")
            print("Translations saved to: {}".format(output_path))
        else:
            for hyp in hypotheses:
                print(hyp)
    else:
        # enter interactive mode
        batch_size = 1
        while True:
            try:
                src_input = input("\nPlease enter a source sentence "
                                  "(pre-processed): \n")
                if not src_input.strip():
                    break

                # every line has to be made into a dataset
                test_data = _load_line_as_data(line=src_input)
                hypotheses = _translate_data(test_data)
                print("JoeyNMT: {}".format(hypotheses[0]))

            except (KeyboardInterrupt, EOFError):
                print("\nBye.")
                break
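# Hypothetical usage sketch (not from the source; all paths are
# placeholders). In non-interactive mode, pre-processed source text is
# piped on stdin, e.g.:
#
#   python translate.py configs/my_model.yaml < test.bpe.src > hyps.txt
#
# The equivalent direct call from Python would look like this:
def _demo_translate():
    translate(cfg_file="configs/my_model.yaml",  # placeholder config path
              ckpt=None,               # None -> latest checkpoint in model_dir
              output_path="hyps.txt")  # used only when input comes from stdin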
def load_data(data_cfg: dict):
    """
    Load train, dev and optionally test data as specified in configuration.
    Vocabularies are created from the training set with a limit of
    `voc_limit` tokens and a minimum token frequency of `voc_min_freq`
    (specified in the configuration dictionary).

    The training data is filtered to include sentences up to
    `max_sent_length` on source and target side.

    :param data_cfg: configuration dictionary for data
        ("data" part of configuration file)
    :return:
        - train_data: training dataset
        - dev_data: development dataset
        - test_data: test dataset if given, otherwise None
        - vocabs: dictionary from src and trg (and possibly other fields)
            to their corresponding vocab objects
    """
    data_format = data_cfg.get("format", "bitext")
    formats = {"bitext", "tsv"}
    assert data_format in formats

    train_path = data_cfg["train"]
    dev_path = data_cfg["dev"]
    test_path = data_cfg.get("test", None)

    lowercase = data_cfg["lowercase"]
    max_sent_length = data_cfg["max_sent_length"]
    default_level = data_cfg.get("level", "word")
    voc_limit = data_cfg.get("voc_limit", sys.maxsize)
    voc_min_freq = data_cfg.get("voc_min_freq", 1)

    tokenizers = {
        "word": str.split,
        "bpe": str.split,
        "char": list,
        "char-decompose": decompose_tokenize,
        "tag": partial(str.split, sep=";")
    }

    # "columns" lists all field names; "label_fields" names the subset
    # holding non-sequential labels
    all_fields = data_cfg.get("columns", ["src", "trg"])
    label_fields = data_cfg.get("label_fields", [])
    assert all(label in all_fields for label in label_fields)
    sequential_fields = [
        field for field in all_fields if field not in label_fields
    ]
    src_fields = [f_name for f_name in all_fields if f_name != "trg"]
    trg_fields = ["trg"]

    suffixes = {
        f_name: data_cfg.get(f_name, "") for f_name in sequential_fields
    }

    seq_field_cls = partial(Field, eos_token=EOS_TOKEN, pad_token=PAD_TOKEN,
                            unk_token=UNK_TOKEN, batch_first=True,
                            lower=lowercase, include_lengths=True)

    fields = dict()
    for f_name in sequential_fields:
        # only the target field gets a BOS token
        bos = BOS_TOKEN if f_name == "trg" else None
        current_level = data_cfg.get(f_name + "_level", default_level)
        assert current_level in tokenizers, "Invalid tokenization level"
        tok_fun = tokenizers[current_level]
        fields[f_name] = seq_field_cls(init_token=bos, tokenize=tok_fun)

    for f_name in label_fields:
        fields[f_name] = Field(sequential=False)

    filter_ex = partial(filter_example,
                        max_sent_length=max_sent_length,
                        keys=tuple(fields.keys()))

    if data_format == "bitext":
        dataset_cls = partial(TranslationDataset,
                              exts=("." + suffixes["src"],
                                    "." + suffixes["trg"]),
                              fields=(fields["src"], fields["trg"]))
    else:
        dataset_cls = partial(TSVDataset,
                              fields=fields,
                              columns=all_fields,
                              label_columns=label_fields)

    train_data = dataset_cls(path=train_path, filter_pred=filter_ex)

    vocabs = dict()

    # each field name needs a vocab, but not necessarily a *distinct* one:
    # source fields may share a single vocabulary
    share_src_vocabs = data_cfg.get("share_src_vocabs", False)
    if share_src_vocabs:
        field_groups = [src_fields, trg_fields]
    else:
        field_groups = [[f] for f in sequential_fields]

    for f_group in field_groups:
        if len(f_group) == 1:
            f_name = f_group[0]
            max_size = data_cfg.get("{}_voc_limit".format(f_name), voc_limit)
            min_freq = data_cfg.get("{}_voc_min_freq".format(f_name),
                                    voc_min_freq)
            vocab_file = data_cfg.get("{}_vocab".format(f_name), None)
        else:
            # multiple fields sharing a vocabulary
            max_size = voc_limit
            min_freq = voc_min_freq
            vocab_file = None

        if vocab_file is not None:
            f_vocab = Vocabulary.from_file(vocab_file)
        else:
            f_vocab = Vocabulary.from_dataset(train_data, f_group,
                                              max_size, min_freq)
        for f_name in f_group:
            vocabs[f_name] = f_vocab

    # label fields never share vocabularies, so each gets its own
    for f_name in label_fields:
        max_size = data_cfg.get("{}_voc_limit".format(f_name), voc_limit)
        min_freq = data_cfg.get("{}_voc_min_freq".format(f_name),
                                voc_min_freq)
        vocab_file = data_cfg.get("{}_vocab".format(f_name), None)

        if vocab_file is not None:
            f_vocab = Vocabulary.from_file(vocab_file)
        else:
            f_vocab = Vocabulary.from_dataset(train_data, [f_name],
                                              max_size, min_freq,
                                              sequential=False)
        vocabs[f_name] = f_vocab

    dev_data = dataset_cls(path=dev_path)

    if test_path is not None:
        # if the test set has no target side, load it as a mono dataset
        trg_suffix = suffixes["trg"]
        if data_format != "bitext" or isfile(test_path + "." + trg_suffix):
            test_dataset_cls = dataset_cls
        else:
            test_dataset_cls = partial(MonoDataset,
                                       ext="." + suffixes["src"],
                                       field=fields["src"])
        test_data = test_dataset_cls(path=test_path)
    else:
        test_data = None

    for field_name in fields:
        fields[field_name].vocab = vocabs[field_name]

    ret = {
        "train_data": train_data,
        "dev_data": dev_data,
        "test_data": test_data,
        "vocabs": vocabs
    }
    return ret