def setUp(self):
    self.file = "test/data/toy/train.de"
    sent = "Die Wahrheit ist, dass die Titanic – obwohl sie alle " \
           "Kinokassenrekorde bricht – nicht gerade die aufregendste " \
           "Geschichte vom Meer ist."
    self.word_list = sent.split()
    self.char_list = list(sent)
    self.temp_file_char = "tmp.src.char"
    self.temp_file_word = "tmp.src.word"
    # only unique tokens go into the vocabulary
    self.word_vocab = Vocabulary(tokens=sorted(set(self.word_list)))
    self.char_vocab = Vocabulary(tokens=sorted(set(self.char_list)))
def testVocabularyFromFile(self):
    # write vocabs to file and create new ones from those files
    self.word_vocab.to_file(self.temp_file_word)
    self.char_vocab.to_file(self.temp_file_char)
    word_vocab2 = Vocabulary.from_file(self.temp_file_word)
    char_vocab2 = Vocabulary.from_file(self.temp_file_char)
    self.assertEqual(self.word_vocab.itos, word_vocab2.itos)
    self.assertEqual(self.char_vocab.itos, char_vocab2.itos)
    os.remove(self.temp_file_char)
    os.remove(self.temp_file_word)
class TestVocabulary(unittest.TestCase):

    def setUp(self):
        self.file = "test/data/toy/train.de"
        sent = "Die Wahrheit ist, dass die Titanic – obwohl sie alle " \
               "Kinokassenrekorde bricht – nicht gerade die aufregendste " \
               "Geschichte vom Meer ist. GROẞ"  # uppercase ẞ requires Unicode
        self.word_list = sent.split()
        self.char_list = list(sent)
        self.temp_file_char = "tmp.src.char"
        self.temp_file_word = "tmp.src.word"
        # only unique tokens go into the vocabulary
        self.word_vocab = Vocabulary(tokens=sorted(set(self.word_list)))
        self.char_vocab = Vocabulary(tokens=sorted(set(self.char_list)))

    def testVocabularyFromList(self):
        self.assertEqual(
            len(self.word_vocab) - len(self.word_vocab.specials),
            len(set(self.word_list)))
        self.assertEqual(
            len(self.char_vocab) - len(self.char_vocab.specials),
            len(set(self.char_list)))
        expected_char_itos = [
            '<unk>', '<pad>', '<s>', '</s>', ' ', ',', '.', 'D', 'G', 'K',
            'M', 'O', 'R', 'T', 'W', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
            'i', 'k', 'l', 'm', 'n', 'o', 'r', 's', 't', 'u', 'v', 'w', 'ẞ',
            '–'
        ]
        self.assertEqual(self.char_vocab.itos, expected_char_itos)
        expected_word_itos = [
            '<unk>', '<pad>', '<s>', '</s>', 'Die', 'GROẞ', 'Geschichte',
            'Kinokassenrekorde', 'Meer', 'Titanic', 'Wahrheit', 'alle',
            'aufregendste', 'bricht', 'dass', 'die', 'gerade', 'ist,',
            'ist.', 'nicht', 'obwohl', 'sie', 'vom', '–'
        ]
        self.assertEqual(self.word_vocab.itos, expected_word_itos)

    def testVocabularyFromFile(self):
        # write vocabs to file and create new ones from those files
        self.word_vocab.to_file(self.temp_file_word)
        self.char_vocab.to_file(self.temp_file_char)
        word_vocab2 = Vocabulary(file=self.temp_file_word)
        char_vocab2 = Vocabulary(file=self.temp_file_char)
        self.assertEqual(self.word_vocab.itos, word_vocab2.itos)
        self.assertEqual(self.char_vocab.itos, char_vocab2.itos)
        os.remove(self.temp_file_char)
        os.remove(self.temp_file_word)

    def testIsUnk(self):
        self.assertTrue(self.word_vocab.is_unk("BLA"))
        self.assertFalse(self.word_vocab.is_unk("Die"))
        self.assertFalse(self.word_vocab.is_unk("GROẞ"))
        self.assertTrue(self.char_vocab.is_unk("x"))
        self.assertFalse(self.char_vocab.is_unk("d"))
        self.assertFalse(self.char_vocab.is_unk("ẞ"))
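# A minimal, hedged usage sketch (not part of the original file): running the
# test class above with unittest's standard runner. Assumes os, unittest and
# Vocabulary are importable in this module, as the tests already require.
if __name__ == "__main__":
    unittest.main()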
def setUp(self):
    self.seed = 42
    vocab_size = 30
    tokens = ["tok{:02d}".format(i) for i in range(vocab_size)]
    self.vocab = Vocabulary(tokens=tokens)
    self.cfg = {
        "model": {
            "tied_embeddings": False,
            "tied_softmax": False,
            "encoder": {
                "type": "recurrent",
                "hidden_size": 64,
                "embeddings": {"embedding_dim": 32},
                "num_layers": 1,
            },
            "decoder": {
                "type": "recurrent",
                "hidden_size": 64,
                "embeddings": {"embedding_dim": 32},
                "num_layers": 1,
            },
        }
    }
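# Hedged usage sketch (an assumption, not part of the original test): building
# a model from a config of this shape. The build_model(cfg["model"],
# src_vocab=..., trg_vocab=...) call matches the signature used in the
# translate() functions further down; seeding first makes the random
# parameter initialization reproducible.
import torch

torch.manual_seed(42)
vocab = Vocabulary(tokens=["tok{:02d}".format(i) for i in range(30)])
rnn_cfg = {"type": "recurrent", "hidden_size": 64,
           "embeddings": {"embedding_dim": 32}, "num_layers": 1}
cfg = {"model": {"tied_embeddings": False, "tied_softmax": False,
                 "encoder": dict(rnn_cfg), "decoder": dict(rnn_cfg)}}
model = build_model(cfg["model"], src_vocab=vocab, trg_vocab=vocab)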
def build_vocab(field, max_size, min_freq, data, vocab_file=None):
    """
    Builds vocabulary for a torchtext `field`.

    :param field: field name ("src" or "trg") to build the vocabulary for
    :param max_size: maximum number of non-special tokens in the vocabulary
    :param min_freq: minimum token frequency; rarer tokens are discarded
    :param data: dataset whose examples supply the tokens
    :param vocab_file: optional path to an existing vocabulary file
    :return: Vocabulary object
    """
    # special symbols
    specials = [UNK_TOKEN, PAD_TOKEN, BOS_TOKEN, EOS_TOKEN]

    if vocab_file is not None:
        # load it from file
        vocab = Vocabulary(file=vocab_file)
        vocab.add_tokens(specials)
    else:
        # create newly
        def filter_min(counter, min_freq):
            """ Filter counter by min frequency """
            filtered_counter = Counter({t: c for t, c in counter.items()
                                        if c >= min_freq})
            return filtered_counter

        def sort_and_cut(counter, limit):
            """ Cut counter to the most frequent tokens,
            sorted by frequency and then alphabetically """
            # sort alphabetically first, then (stably) by frequency
            tokens_and_frequencies = sorted(counter.items(),
                                            key=lambda tup: tup[0])
            tokens_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)
            vocab_tokens = [i[0] for i in tokens_and_frequencies[:limit]]
            return vocab_tokens

        tokens = []
        for i in data.examples:
            if field == "src":
                tokens.extend(i.src)
            elif field == "trg":
                tokens.extend(i.trg)

        counter = Counter(tokens)
        if min_freq > -1:
            counter = filter_min(counter, min_freq)
        vocab_tokens = specials + sort_and_cut(counter, max_size)
        assert vocab_tokens[DEFAULT_UNK_ID()] == UNK_TOKEN
        assert len(vocab_tokens) <= max_size + len(specials)
        vocab = Vocabulary(tokens=vocab_tokens)

    # all specials except UNK must be known to the vocabulary
    for s in specials[1:]:
        assert not vocab.is_unk(s)

    return vocab
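# Hedged demo (standalone, not part of the module): why sort_and_cut's two
# passes produce a frequency-major, alphabetical tie-break ordering. Python's
# sort is stable, so sorting alphabetically first and then by descending
# frequency preserves alphabetical order among equally frequent tokens.
from collections import Counter

counter = Counter({"b": 2, "a": 2, "c": 3})
items = sorted(counter.items(), key=lambda tup: tup[0])  # a, b, c
items.sort(key=lambda tup: tup[1], reverse=True)         # c first, then a, b
assert [t for t, _ in items] == ["c", "a", "b"]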
def setUp(self):
    self.emb_size = 12
    self.num_layers = 3
    self.hidden_size = 12
    self.ff_size = 24
    self.num_heads = 4
    self.dropout = 0.
    self.encoder_hidden_size = 3
    self.vocab = Vocabulary(tokens=['word'])
    self.vocab_size = len(self.vocab)  # = 5 (4 specials + 1 token)
    seed = 42
    torch.manual_seed(seed)
    # self.bos_index = 2
    self.pad_index = 1
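# Hedged sketch (an assumption: a TransformerDecoder class accepting these
# keyword parameters exists in this codebase; the exact signature may
# differ). Shows how the hyperparameters from setUp above would typically
# be wired into a decoder under test.
decoder = TransformerDecoder(num_layers=3, num_heads=4, hidden_size=12,
                             ff_size=24, dropout=0., vocab_size=5)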
def translate(cfg_file: str,
              ckpt: str,
              output_path: str = None,
              batch_class: Batch = Batch,
              n_best: int = 1) -> None:
    """
    Interactive translation function.
    Loads model from checkpoint and translates either the stdin input or
    asks for input to translate interactively.
    The input has to be pre-processed according to the data that the model
    was trained on, i.e. tokenized or split into subwords.
    Translations are printed to stdout.

    :param cfg_file: path to configuration file
    :param ckpt: path to checkpoint to load
    :param output_path: path to output file
    :param batch_class: class type of batch
    :param n_best: number of candidates to display
    """

    def _load_line_as_data(line):
        """ Create a dataset from one line via a temporary file. """
        # write src input to temporary file
        tmp_name = "tmp"
        tmp_suffix = ".src"
        tmp_filename = tmp_name + tmp_suffix
        with open(tmp_filename, "w") as tmp_file:
            tmp_file.write("{}\n".format(line))

        test_data = MonoDataset(path=tmp_name, ext=tmp_suffix,
                                field=src_field)

        # remove temporary file
        if os.path.exists(tmp_filename):
            os.remove(tmp_filename)

        return test_data

    def _translate_data(test_data):
        """ Translates given dataset, using parameters from outer scope. """
        # pylint: disable=unused-variable
        score, loss, ppl, sources, sources_raw, references, hypotheses, \
            hypotheses_raw, attention_scores = validate_on_data(
                model, data=test_data, batch_size=batch_size,
                batch_class=batch_class, batch_type=batch_type, level=level,
                max_output_length=max_output_length, eval_metric="",
                use_cuda=use_cuda, compute_loss=False, beam_size=beam_size,
                beam_alpha=beam_alpha, postprocess=postprocess,
                bpe_type=bpe_type, sacrebleu=sacrebleu, n_gpu=n_gpu,
                n_best=n_best)
        return hypotheses

    cfg = load_config(cfg_file)
    model_dir = cfg["training"]["model_dir"]

    _ = make_logger(model_dir, mode="translate")  # version string returned

    # when checkpoint is not specified, take latest from model dir
    if ckpt is None:
        ckpt = get_latest_checkpoint(model_dir)

    # read vocabs
    src_vocab_file = cfg["data"].get("src_vocab", model_dir + "/src_vocab.txt")
    trg_vocab_file = cfg["data"].get("trg_vocab", model_dir + "/trg_vocab.txt")
    src_vocab = Vocabulary(file=src_vocab_file)
    trg_vocab = Vocabulary(file=trg_vocab_file)

    data_cfg = cfg["data"]
    level = data_cfg["level"]
    lowercase = data_cfg["lowercase"]

    tok_fun = lambda s: list(s) if level == "char" else s.split()

    src_field = Field(init_token=None, eos_token=EOS_TOKEN,
                      pad_token=PAD_TOKEN, tokenize=tok_fun,
                      batch_first=True, lower=lowercase,
                      unk_token=UNK_TOKEN, include_lengths=True)
    src_field.vocab = src_vocab

    # parse test args
    batch_size, batch_type, use_cuda, device, n_gpu, level, _, \
        max_output_length, beam_size, beam_alpha, postprocess, \
        bpe_type, sacrebleu, _, _ = parse_test_args(cfg, mode="translate")

    # load model state from disk
    model_checkpoint = load_checkpoint(ckpt, use_cuda=use_cuda)

    # build model and load parameters into it
    model = build_model(cfg["model"], src_vocab=src_vocab, trg_vocab=trg_vocab)
    model.load_state_dict(model_checkpoint["model_state"])

    if use_cuda:
        model.to(device)

    if not sys.stdin.isatty():
        # input file given
        test_data = MonoDataset(path=sys.stdin, ext="", field=src_field)
        all_hypotheses = _translate_data(test_data)

        if output_path is not None:
            # write to output file if given

            def write_to_file(output_path_set, hypotheses):
                with open(output_path_set, mode="w", encoding="utf-8") \
                        as out_file:
                    for hyp in hypotheses:
                        out_file.write(hyp + "\n")
                logger.info("Translations saved to: %s.", output_path_set)

            if n_best > 1:
                for n in range(n_best):
                    file_name, file_extension = os.path.splitext(output_path)
                    write_to_file(
                        "{}-{}{}".format(
                            file_name, n,
                            file_extension if file_extension else ""),
                        [
                            all_hypotheses[i]
                            for i in range(n, len(all_hypotheses), n_best)
                        ])
            else:
                write_to_file("{}".format(output_path), all_hypotheses)
        else:
            # print to stdout
            for hyp in all_hypotheses:
                print(hyp)
    else:
        # enter interactive mode
        batch_size = 1
        batch_type = "sentence"
        while True:
            try:
                src_input = input("\nPlease enter a source sentence "
                                  "(pre-processed): \n")
                if not src_input.strip():
                    break

                # every line has to be made into dataset
                test_data = _load_line_as_data(line=src_input)
                hypotheses = _translate_data(test_data)

                print("JoeyNMT: Hypotheses ranked by score")
                for i, hyp in enumerate(hypotheses):
                    print("JoeyNMT #{}: {}".format(i + 1, hyp))
            except (KeyboardInterrupt, EOFError):
                print("\nBye.")
                break
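# Hedged usage sketch (not part of the module; all paths are placeholders):
# calling translate() directly. With ckpt=None the latest checkpoint in
# model_dir is used; with n_best=2 the ranked hypotheses are split by
# write_to_file above into "out/hyps-0.txt" and "out/hyps-1.txt".
translate(cfg_file="configs/my_model.yaml",
          ckpt=None,
          output_path="out/hyps.txt",
          n_best=2)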
def translate(cfg_file, ckpt: str, output_path: str = None) -> None:
    """
    Interactive translation function.
    Loads model from checkpoint and translates either the stdin input or
    asks for input to translate interactively.
    The input has to be pre-processed according to the data that the model
    was trained on, i.e. tokenized or split into subwords.
    Translations are printed to stdout.

    :param cfg_file: path to configuration file
    :param ckpt: path to checkpoint to load
    :param output_path: path to output file
    """

    def _load_line_as_data(line):
        """ Create a dataset from one line via a temporary file. """
        # write src input to temporary file
        tmp_name = "tmp"
        tmp_suffix = ".src"
        tmp_filename = tmp_name + tmp_suffix
        with open(tmp_filename, "w") as tmp_file:
            tmp_file.write("{}\n".format(line))

        test_data = MonoDataset(path=tmp_name, ext=tmp_suffix,
                                field=src_field)

        # remove temporary file
        if os.path.exists(tmp_filename):
            os.remove(tmp_filename)

        return test_data

    logger = make_logger()

    def _translate_data(test_data):
        """ Translates given dataset, using parameters from outer scope. """
        # pylint: disable=unused-variable
        score, loss, ppl, sources, sources_raw, references, hypotheses, \
            hypotheses_raw, attention_scores = validate_on_data(
                model, data=test_data, batch_size=batch_size,
                batch_type=batch_type, level=level,
                max_output_length=max_output_length, eval_metric="",
                use_cuda=use_cuda, loss_function=None, beam_size=beam_size,
                beam_alpha=beam_alpha, logger=logger,
                postprocess=postprocess)
        return hypotheses

    cfg = load_config(cfg_file)

    # when checkpoint is not specified, take latest from model dir
    if ckpt is None:
        model_dir = cfg["training"]["model_dir"]
        ckpt = get_latest_checkpoint(model_dir)

    batch_size = cfg["training"].get("eval_batch_size",
                                     cfg["training"].get("batch_size", 1))
    batch_type = cfg["training"].get(
        "eval_batch_type", cfg["training"].get("batch_type", "sentence"))
    use_cuda = cfg["training"].get("use_cuda", False)
    max_output_length = cfg["training"].get("max_output_length", None)

    # read vocabs
    src_vocab_file = cfg["data"].get(
        "src_vocab", cfg["training"]["model_dir"] + "/src_vocab.txt")
    trg_vocab_file = cfg["data"].get(
        "trg_vocab", cfg["training"]["model_dir"] + "/trg_vocab.txt")
    src_vocab = Vocabulary(file=src_vocab_file)
    trg_vocab = Vocabulary(file=trg_vocab_file)

    data_cfg = cfg["data"]
    level = data_cfg["level"]
    lowercase = data_cfg["lowercase"]

    tok_fun = lambda s: list(s) if level == "char" else s.split()

    src_field = Field(init_token=None, eos_token=EOS_TOKEN,
                      pad_token=PAD_TOKEN, tokenize=tok_fun,
                      batch_first=True, lower=lowercase,
                      unk_token=UNK_TOKEN, include_lengths=True)
    src_field.vocab = src_vocab

    # load model state from disk
    model_checkpoint = load_checkpoint(ckpt, use_cuda=use_cuda)

    # build model and load parameters into it
    model = build_model(cfg["model"], src_vocab=src_vocab, trg_vocab=trg_vocab)
    model.load_state_dict(model_checkpoint["model_state"])

    if use_cuda:
        model.cuda()

    # whether to use beam search for decoding, <2: greedy decoding
    if "testing" in cfg.keys():
        beam_size = cfg["testing"].get("beam_size", 1)
        beam_alpha = cfg["testing"].get("alpha", -1)
        postprocess = cfg["testing"].get("postprocess", True)
    else:
        beam_size = 1
        beam_alpha = -1
        postprocess = True

    if not sys.stdin.isatty():
        # input file given
        test_data = MonoDataset(path=sys.stdin, ext="", field=src_field)
        hypotheses = _translate_data(test_data)

        if output_path is not None:
            # write to output file if given
            output_path_set = "{}".format(output_path)
            with open(output_path_set, mode="w", encoding="utf-8") \
                    as out_file:
                for hyp in hypotheses:
                    out_file.write(hyp + "\n")
            logger.info("Translations saved to: %s.", output_path_set)
        else:
            # print to stdout
            for hyp in hypotheses:
                print(hyp)
    else:
        # enter interactive mode
        batch_size = 1
        batch_type = "sentence"
        while True:
            try:
                src_input = input("\nPlease enter a source sentence "
                                  "(pre-processed): \n")
                if not src_input.strip():
                    break

                # every line has to be made into dataset
                test_data = _load_line_as_data(line=src_input)
                hypotheses = _translate_data(test_data)
                print("JoeyNMT: {}".format(hypotheses[0]))
            except (KeyboardInterrupt, EOFError):
                print("\nBye.")
                break
def translate(cfg_file, ckpt: str, output_path: str = None) -> None:
    """
    Interactive translation function.
    Loads model from checkpoint and translates either the stdin input or
    asks for input to translate interactively.
    The input has to be pre-processed according to the data that the model
    was trained on, i.e. tokenized or split into subwords.
    Translations are printed to stdout.

    :param cfg_file: path to configuration file
    :param ckpt: path to checkpoint to load
    :param output_path: path to output file
    """

    def _load_line_as_data(line):
        """ Create a dataset from one line via a temporary file. """
        # write src input to temporary file
        tmp_name = "tmp"
        tmp_suffix = ".src"
        tmp_filename = tmp_name + tmp_suffix
        with open(tmp_filename, "w") as tmp_file:
            tmp_file.write("{}\n".format(line))

        test_data = MonoDataset(path=tmp_name, ext=tmp_suffix,
                                field=src_field)

        # remove temporary file
        if os.path.exists(tmp_filename):
            os.remove(tmp_filename)

        return test_data

    def _translate_data(test_data):
        """ Translates given dataset, using parameters from outer scope. """
        # pylint: disable=unused-variable
        _, _, _, _, hypotheses, _, _, _, _ = validate_on_data(
            model, data=test_data, batch_size=batch_size, level=level,
            max_output_length=max_output_length, eval_metrics=[],
            use_cuda=use_cuda, loss_function=None, beam_size=beam_size,
            beam_alpha=beam_alpha)
        return hypotheses

    cfg = load_config(cfg_file)

    # when checkpoint is not specified, take latest from model dir
    if ckpt is None:
        model_dir = cfg["training"]["model_dir"]
        ckpt = get_latest_checkpoint(model_dir)

    data_cfg = cfg["data"]

    batch_size = cfg["training"].get("batch_size", 1)
    use_cuda = cfg["training"].get("use_cuda", False)
    max_output_length = cfg["training"].get("max_output_length", None)

    # read vocabs
    # This will need to change: currently translate does not support inflection
    src_vocab_file = data_cfg.get(
        "src_vocab", cfg["training"]["model_dir"] + "/src_vocab.txt")
    trg_vocab_file = data_cfg.get(
        "trg_vocab", cfg["training"]["model_dir"] + "/trg_vocab.txt")
    src_vocab = Vocabulary(file=src_vocab_file)
    trg_vocab = Vocabulary(file=trg_vocab_file)
    vocabs = {"src": src_vocab, "trg": trg_vocab}

    level = data_cfg["level"]
    lowercase = data_cfg["lowercase"]

    tok_fun = list if level == "char" else str.split

    src_field = Field(init_token=None, eos_token=EOS_TOKEN,
                      pad_token=PAD_TOKEN, tokenize=tok_fun,
                      batch_first=True, lower=lowercase,
                      unk_token=UNK_TOKEN, include_lengths=True)
    src_field.vocab = src_vocab

    # load model state from disk
    model_checkpoint = load_checkpoint(ckpt, use_cuda=use_cuda)

    # build model and load parameters into it
    model = build_model(cfg["model"], vocabs=vocabs)
    model.load_state_dict(model_checkpoint["model_state"])

    if use_cuda:
        model.cuda()

    # whether to use beam search for decoding, 0: greedy decoding
    if "testing" in cfg.keys():
        beam_size = cfg["testing"].get("beam_size", 0)
        beam_alpha = cfg["testing"].get("alpha", 0)
    else:
        beam_size = 0
        beam_alpha = 0

    if beam_alpha < 0:
        raise ConfigurationError("alpha for length penalty should be >= 0")

    if not sys.stdin.isatty():
        # input file given
        test_data = MonoDataset(path=sys.stdin, ext="", field=src_field)
        hypotheses = _translate_data(test_data)

        if output_path is not None:
            # write to output file if given
            output_path_set = "{}".format(output_path)
            with open(output_path_set, mode="w", encoding="utf-8") \
                    as out_file:
                for hyp in hypotheses:
                    out_file.write(hyp + "\n")
            print("Translations saved to: {}".format(output_path_set))
        else:
            # print to stdout
            for hyp in hypotheses:
                print(hyp)
    else:
        # enter interactive mode
        batch_size = 1
        while True:
            try:
                src_input = input("\nPlease enter a source sentence "
                                  "(pre-processed): \n")
                if not src_input.strip():
                    break

                # every line has to be made into dataset
                test_data = _load_line_as_data(line=src_input)
                hypotheses = _translate_data(test_data)
                print("JoeyNMT: {}".format(hypotheses[0]))
            except (KeyboardInterrupt, EOFError):
                print("\nBye.")
                break
def __init__(self,
             data_path: str,
             split_file_name: str,
             ann_file_name: str,
             transform=None,
             fix_length: int = None,
             max_vocab_size: int = None,
             all_lower: bool = False):
    """
    Flickr dataset class to use with a dataloader.

    :param data_path: dataset directory
    :param split_file_name: file listing all images used in this split -
        vary this parameter for the train/dev/test split
    :param ann_file_name: file containing the annotation tokens
    :param transform: torchvision transforms object to be applied to the
        images
    :param fix_length: pads captions to fix_length if provided, otherwise
        pads to the length of the longest example in the batch
    :param max_vocab_size: maximum size of the vocabulary, or None for no
        maximum
    :param all_lower: set this to convert all tokens to lower case
    """
    self.root = os.path.expanduser(data_path)
    self.ann_file = os.path.expanduser(ann_file_name)
    self.transform = transform
    self.max_vocab_size = max_vocab_size
    self.idx2image = []
    self.idx2caption = []
    self.idx2caption_no_padding = []
    self.image_name2idxs = defaultdict(list)
    self.lengths = dict()

    # get image file names for the chosen train/dev/test split
    with open(split_file_name, 'r') as split_file:
        valid_image_file_names = set(line.rstrip() for line in split_file)
    with open(ann_file_name, 'r') as ann_file:
        annotations = [line.rstrip() for line in ann_file]

    valid_counter = 0
    # loop through all annotations, as they are not separated by split
    # (train/dev/test)
    for annotation in annotations:
        image_file_name, caption = annotation.split('\t')
        # only keep the captions of images that belong to the current split
        if image_file_name[:-2] in valid_image_file_names:
            # if enabled, convert all tokens to lower case
            if all_lower:
                caption = caption.lower().split()
            else:
                caption = caption.split()
            self.idx2caption.append(caption)

            # store the caption ids grouped by caption length; this can be
            # used to sample batches of equal length
            if len(caption) not in self.lengths:
                self.lengths[len(caption)] = [valid_counter]
            else:
                self.lengths[len(caption)].append(valid_counter)

            self.idx2caption_no_padding.append(caption)
            self.idx2image.append(image_file_name[:-2])
            self.image_name2idxs[image_file_name[:-2]].append(
                len(self.idx2image) - 1)
            valid_counter += 1

    self.corpus = data.Field(init_token=BOS_TOKEN, eos_token=EOS_TOKEN,
                             pad_token=PAD_TOKEN, unk_token=UNK_TOKEN,
                             fix_length=fix_length)
    self.max_length = max(self.lengths.keys())

    # pad captions
    self.idx2caption = self.corpus.pad(self.idx2caption)

    # select the most frequent tokens (top max_vocab_size) and build the
    # vocabulary
    counter = Counter(itertools.chain(*self.idx2caption))
    vocab_tokens = sort_and_cut(counter, self.max_vocab_size)
    self.corpus.vocab = Vocabulary(tokens=vocab_tokens)
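# Hedged usage sketch (assumptions: the class name FlickrDataset and all
# paths below are placeholders, and the class defines __getitem__/__len__
# elsewhere). Shows a typical construction with torchvision transforms.
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
dataset = FlickrDataset(data_path="data/flickr8k",
                        split_file_name="data/flickr8k/train_images.txt",
                        ann_file_name="data/flickr8k/captions.txt",
                        transform=transform,
                        max_vocab_size=10000,
                        all_lower=True)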
def load_data(data_cfg: dict):
    """
    Load train, dev and optionally test data as specified in configuration.
    Vocabularies are created from the training set with a limit of
    `voc_limit` tokens and a minimum token frequency of `voc_min_freq`
    (specified in the configuration dictionary).
    The training data is filtered to include sentences up to
    `max_sent_length` on source and target side.

    :param data_cfg: configuration dictionary for data
        ("data" part of configuration file)
    :return:
        - train_data: training dataset
        - dev_data: development dataset
        - test_data: test dataset if given, otherwise None
        - vocabs: dictionary from src and trg (and possibly other fields)
          to their corresponding vocab objects
    """
    data_format = data_cfg.get("format", "bitext")
    formats = {"bitext", "tsv"}
    assert data_format in formats

    train_path = data_cfg["train"]
    dev_path = data_cfg["dev"]
    test_path = data_cfg.get("test", None)

    lowercase = data_cfg["lowercase"]
    max_sent_length = data_cfg["max_sent_length"]
    default_level = data_cfg.get("level", "word")
    voc_limit = data_cfg.get("voc_limit", sys.maxsize)
    voc_min_freq = data_cfg.get("voc_min_freq", 1)

    tokenizers = {
        "word": str.split,
        "bpe": str.split,
        "char": list,
        "char-decompose": decompose_tokenize,
        "tag": partial(str.split, sep=";")
    }

    # names of all data columns; label fields are the non-sequential ones
    all_fields = data_cfg.get("columns", ["src", "trg"])
    label_fields = data_cfg.get("label_fields", [])
    assert all(label in all_fields for label in label_fields)
    sequential_fields = [
        field for field in all_fields if field not in label_fields
    ]
    src_fields = [f_name for f_name in all_fields if f_name != "trg"]
    trg_fields = ["trg"]

    suffixes = {
        f_name: data_cfg.get(f_name, "") for f_name in sequential_fields
    }

    seq_field_cls = partial(Field,
                            eos_token=EOS_TOKEN,
                            pad_token=PAD_TOKEN,
                            unk_token=UNK_TOKEN,
                            batch_first=True,
                            lower=lowercase,
                            include_lengths=True)

    fields = dict()
    # only the target field gets a BOS token; the tokenization level can be
    # overridden per field
    for f_name in sequential_fields:
        bos = BOS_TOKEN if f_name == "trg" else None
        current_level = data_cfg.get(f_name + "_level", default_level)
        assert current_level in tokenizers, "Invalid tokenization level"
        tok_fun = tokenizers[current_level]
        fields[f_name] = seq_field_cls(init_token=bos, tokenize=tok_fun)

    for f_name in label_fields:
        fields[f_name] = Field(sequential=False)

    filter_ex = partial(filter_example,
                        max_sent_length=max_sent_length,
                        keys=tuple(fields.keys()))

    if data_format == "bitext":
        dataset_cls = partial(TranslationDataset,
                              exts=("." + suffixes["src"],
                                    "." + suffixes["trg"]),
                              fields=(fields["src"], fields["trg"]))
    else:
        dataset_cls = partial(TSVDataset,
                              fields=fields,
                              columns=all_fields,
                              label_columns=label_fields)

    train_data = dataset_cls(path=train_path, filter_pred=filter_ex)

    vocabs = dict()

    # every field name needs a vocab, but not necessarily a *distinct* one:
    # source fields may share a vocabulary
    share_src_vocabs = data_cfg.get("share_src_vocabs", False)
    if share_src_vocabs:
        field_groups = [src_fields, trg_fields]
    else:
        field_groups = [[f] for f in sequential_fields]

    for f_group in field_groups:
        if len(f_group) == 1:
            f_name = f_group[0]
            max_size = data_cfg.get("{}_voc_limit".format(f_name), voc_limit)
            min_freq = data_cfg.get("{}_voc_min_freq".format(f_name),
                                    voc_min_freq)
            vocab_file = data_cfg.get("{}_vocab".format(f_name), None)
        else:
            # multiple fields sharing a vocabulary
            max_size = voc_limit
            min_freq = voc_min_freq
            vocab_file = None
        if vocab_file is not None:
            f_vocab = Vocabulary.from_file(vocab_file)
        else:
            f_vocab = Vocabulary.from_dataset(train_data, f_group,
                                              max_size, min_freq)
        for f_name in f_group:
            vocabs[f_name] = f_vocab

    label_field_groups = [[lf] for lf in label_fields]
    for f_group in label_field_groups:
        if len(f_group) == 1:
            f_name = f_group[0]
            max_size = data_cfg.get("{}_voc_limit".format(f_name), voc_limit)
            min_freq = data_cfg.get("{}_voc_min_freq".format(f_name),
                                    voc_min_freq)
            vocab_file = data_cfg.get("{}_vocab".format(f_name), None)
        else:
            # multiple fields sharing a vocabulary
            max_size = voc_limit
            min_freq = voc_min_freq
            vocab_file = None
        if vocab_file is not None:
            f_vocab = Vocabulary.from_file(vocab_file)
        else:
            f_vocab = Vocabulary.from_dataset(train_data, f_group,
                                              max_size, min_freq,
                                              sequential=False)
        for f_name in f_group:
            vocabs[f_name] = f_vocab

    dev_data = dataset_cls(path=dev_path)

    if test_path is not None:
        # if no target side exists for a bitext test set, load it as
        # source-only monolingual data
        trg_suffix = suffixes["trg"]
        if data_format != "bitext" or isfile(test_path + "." + trg_suffix):
            test_dataset_cls = dataset_cls
        else:
            test_dataset_cls = partial(MonoDataset,
                                       ext="." + suffixes["src"],
                                       field=fields["src"])
        test_data = test_dataset_cls(path=test_path)
    else:
        test_data = None

    for field_name in fields:
        fields[field_name].vocab = vocabs[field_name]

    ret = {
        "train_data": train_data,
        "dev_data": dev_data,
        "test_data": test_data,
        "vocabs": vocabs
    }
    return ret