# NOTE: the import paths below are assumptions inferred from the names used
# in these tests (an OpenKiwi-style layout); adjust them if the actual
# package structure differs.
import numpy as np
import pytest

from kiwi import constants
from kiwi.data.builders import build_test_dataset, build_training_datasets
from kiwi.data.fieldsets.quality_estimation import build_fieldset
from kiwi.data.iterators import build_bucket_iterator
from kiwi.data.qe_dataset import QEDataset


def test_build_train_datasets_valid(data_opts_17):
    datasets = build_training_datasets(
        fieldset=build_fieldset(), **vars(data_opts_17)
    )
    assert len(datasets) == 2
    for dataset in datasets:
        assert isinstance(dataset, QEDataset)

def test_build_train_datasets_no_valid(data_opts_no_validation, atol):
    datasets = build_training_datasets(
        fieldset=build_fieldset(), **vars(data_opts_no_validation)
    )
    assert len(datasets) == 2
    for dataset in datasets:
        assert isinstance(dataset, QEDataset)
    # With no explicit validation set, the training data is split in two;
    # the resulting train/dev ratio should match the requested split.
    train_size, dev_size = len(datasets[0]), len(datasets[1])
    np.testing.assert_allclose(
        train_size / (train_size + dev_size),
        data_opts_no_validation.split,
        atol=atol,
    )

def check_qe_dataset(options):
    train_dataset, dev_dataset = build_training_datasets(
        fieldset=build_fieldset(), **vars(options)
    )
    train_iter = build_bucket_iterator(
        train_dataset, batch_size=8, is_train=True, device=None
    )
    dev_iter = build_bucket_iterator(
        dev_dataset, batch_size=8, is_train=False, device=None
    )
    for batch_train, batch_dev in zip(train_iter, dev_iter):
        train_source = getattr(batch_train, constants.SOURCE)
        if isinstance(train_source, tuple):
            train_source, lengths = train_source
        # .t() is not in-place: assign the transposed (batch, seq_len) view
        # so the loop below iterates over samples, not time steps.
        train_source = train_source.t()
        train_prev_len = train_source.shape[1]
        # Buckets should be sorted in decreasing length order
        # so we can use packed/padded sequences.
        for train_sample in train_source:
            train_mask = train_sample != constants.PAD_ID
            train_cur_len = train_mask.int().sum().item()
            assert train_cur_len <= train_prev_len
            train_prev_len = train_cur_len

    source_field = train_dataset.fields[constants.SOURCE]
    target_field = train_dataset.fields[constants.TARGET]
    target_tags_field = train_dataset.fields[constants.TARGET_TAGS]
    # Check that each token is in the corresponding vocabulary.
    for train_sample, dev_sample in zip(train_dataset, dev_dataset):
        for word in getattr(train_sample, constants.SOURCE):
            assert word in source_field.vocab.stoi
        for word in getattr(train_sample, constants.TARGET):
            assert word in target_field.vocab.stoi
        for tag in getattr(train_sample, constants.TARGET_TAGS):
            assert tag in target_tags_field.vocab.stoi
        for word in getattr(dev_sample, constants.SOURCE):
            assert word in source_field.vocab.stoi
        for word in getattr(dev_sample, constants.TARGET):
            assert word in target_field.vocab.stoi
        for tag in getattr(dev_sample, constants.TARGET_TAGS):
            assert tag in target_tags_field.vocab.stoi

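# check_qe_dataset is defined above but never invoked in this section; a
# minimal driver is sketched below. It assumes the data_opts_17 and
# data_opts_no_validation fixtures already used by the tests above; the
# test names themselves are hypothetical.
def test_qe_dataset_with_validation(data_opts_17):
    check_qe_dataset(data_opts_17)


def test_qe_dataset_without_validation(data_opts_no_validation):
    check_qe_dataset(data_opts_no_validation)
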
def fieldset(*args, **kwargs):
    return build_fieldset(*args, **kwargs)

def test_build_test_dataset(data_opts_17):
    dataset = build_test_dataset(
        fieldset=build_fieldset(), **vars(data_opts_17)
    )
    assert isinstance(dataset, QEDataset)
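
# Optional entry point so the module can be run directly; assumes pytest is
# available (imported at the top of the file).
if __name__ == '__main__':  # pragma: no cover
    pytest.main([__file__])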