from functools import partial
from pathlib import Path

# NOTE: `build_dataset`, `build_vocabulary`, `load_vocabularies_to_datasets`,
# `filter_len`, `Fieldset`, and `extend_vocabs_fieldset` are assumed to be
# defined elsewhere in this module/package.


def build_test_dataset(fieldset, load_vocab=None, **kwargs):
    """Build a test QE dataset.

    Args:
        fieldset (Fieldset): specific set of fields to be used (depends on
                             the model to be used).
        load_vocab: A path to a saved vocabulary.

    Returns:
        A Dataset object.
    """
    test_dataset = build_dataset(fieldset, prefix=Fieldset.TEST, **kwargs)
    fields_vocab_options = fieldset.fields_vocab_options(**kwargs)
    if load_vocab:
        vocab_path = Path(load_vocab)
        load_vocabularies_to_datasets(vocab_path, test_dataset)
    else:
        build_vocabulary(fields_vocab_options, test_dataset)
    return test_dataset
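
# Minimal usage sketch for `build_test_dataset`. The keyword names
# `test_source` and `test_target` are assumptions that follow the
# `<prefix>_<field>` convention of the `train_*`/`valid_*` arguments of
# `build_training_datasets` below; all paths are hypothetical placeholders.
def _example_build_test_dataset(fieldset):
    """Usage sketch, assuming `test_source`/`test_target` keyword names."""
    return build_test_dataset(
        fieldset,
        load_vocab='runs/train/vocab.torch',  # reuse a saved training vocab
        test_source='data/test.src',          # hypothetical path
        test_target='data/test.mt',           # hypothetical path
    )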
def build_training_datasets(
    fieldset,
    split=0.0,
    valid_source=None,
    valid_target=None,
    load_vocab=None,
    **kwargs,
):
    """Build training and validation QE datasets.

    Required Args:
        fieldset (Fieldset): specific set of fields to be used (depends on
                             the model to be used).
        train_source: Train Source
        train_target: Train Target (MT)

    Optional Args (depend on the model):
        train_pe: Train Post-edited
        train_target_tags: Train Target Tags
        train_source_tags: Train Source Tags
        train_sentence_scores: Train HTER scores
        valid_source: Valid Source
        valid_target: Valid Target (MT)
        valid_pe: Valid Post-edited
        valid_target_tags: Valid Target Tags
        valid_source_tags: Valid Source Tags
        valid_sentence_scores: Valid HTER scores
        split (float): If no validation sets are provided, the training data
                       is split: a fraction ``split`` is kept for training
                       and the remaining ``1 - split`` of examples are
                       randomly sampled as the validation set.
        target_vocab_size: Maximum size of the target vocabulary
        source_vocab_size: Maximum size of the source vocabulary
        target_max_length: Maximum length for the target field
        target_min_length: Minimum length for the target field
        source_max_length: Maximum length for the source field
        source_min_length: Minimum length for the source field
        target_vocab_min_freq: Minimum word frequency for the target field
        source_vocab_min_freq: Minimum word frequency for the source field
        load_vocab: Path to an existing vocabulary file

    Returns:
        A training and a validation Dataset.
    """
    # TODO: improve handling of these length options (defaults are set
    # multiple times).
    filter_pred = partial(
        filter_len,
        source_min_length=kwargs.get('source_min_length', 1),
        source_max_length=kwargs.get('source_max_length', float('inf')),
        target_min_length=kwargs.get('target_min_length', 1),
        target_max_length=kwargs.get('target_max_length', float('inf')),
    )
    train_dataset = build_dataset(
        fieldset, prefix=Fieldset.TRAIN, filter_pred=filter_pred, **kwargs
    )
    if valid_source and valid_target:
        valid_dataset = build_dataset(
            fieldset,
            prefix=Fieldset.VALID,
            filter_pred=filter_pred,
            valid_source=valid_source,
            valid_target=valid_target,
            **kwargs,
        )
    elif split:
        if not 0.0 < split < 1.0:
            raise ValueError(
                'Invalid data split value: {}; it must be in the '
                '(0, 1) interval.'.format(split)
            )
        train_dataset, valid_dataset = train_dataset.split(split)
    else:
        raise ValueError('Validation data not provided.')

    if load_vocab:
        vocab_path = Path(load_vocab)
        load_vocabularies_to_datasets(vocab_path, train_dataset, valid_dataset)

    # Even if a vocabulary was loaded, build vocabularies anyway in case some
    # fields are missing from the loaded file.
    datasets_for_vocab = [train_dataset]
    if kwargs.get('extend_source_vocab') or kwargs.get('extend_target_vocab'):
        vocabs_fieldset = extend_vocabs_fieldset.build_fieldset(fieldset)
        extend_vocabs_ds = build_dataset(vocabs_fieldset, **kwargs)
        datasets_for_vocab.append(extend_vocabs_ds)

    fields_vocab_options = fieldset.fields_vocab_options(**kwargs)
    build_vocabulary(fields_vocab_options, *datasets_for_vocab)

    return train_dataset, valid_dataset
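
# Minimal usage sketch for `build_training_datasets` using the random-split
# path (no explicit validation files). The keyword names follow the docstring
# above; the paths and the 90/10 ratio are hypothetical placeholders.
def _example_build_training_datasets(fieldset):
    """Usage sketch: keep 90% for training, sample 10% as validation."""
    train_dataset, valid_dataset = build_training_datasets(
        fieldset,
        train_source='data/train.src',  # hypothetical path
        train_target='data/train.mt',   # hypothetical path
        split=0.9,                      # train gets 90%, validation gets 10%
        source_max_length=50,           # drop overly long source sentences
        target_max_length=50,           # drop overly long target sentences
    )
    return train_dataset, valid_dataset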