コード例 #1
0
 def test_training_validation_split_working(self):
     """A 0.4 validation_percent splits the 3 datasets into 2 training and 1 validation."""
     loader = FileLoader(**{**self.base_kwargs, 'validation_percent': 0.4})
     assert len(loader) == 3
     assert len(loader.get_training_datasets()) == 2
     assert len(loader.get_validation_datasets()) == 1
コード例 #2
0
    def test_training_validation_loader_retrieval(self):
        """get_dataset_loader() returns a (training, validation) pair sized 2 and 1."""
        loader = FileLoader(**{**self.base_kwargs, 'validation_percent': 0.4})
        train_loader, val_loader = loader.get_dataset_loader()

        assert len(train_loader) == 2
        assert len(val_loader) == 1
コード例 #3
0
    def test_support_for_custom_attributes(self):
        """A custom attribute callable can be referenced from id_format as {foo}."""
        def ten_times_third(**components):
            # Presumably receives the matched pattern components as keyword
            # arguments (it reads 'THIRD'); returns that value times 10.
            return int(components['THIRD']) * 10

        options = dict(self.base_kwargs)
        options['custom_attributes'] = {'foo': ten_times_third}
        options['id_format'] = options['id_format'] + "-{foo}"

        loader = FileLoader(**options)
        assert 'C-XYZ-10' in loader.get_full_datasets()
コード例 #4
0
 def test_dynamic_types(self):
     """A dynamic type 'FOO' is attached to each training dataset as a callable."""
     options = {**self.base_kwargs, 'input_source': self.tempf.name}
     options['dynamic_types'] = {
         'FOO': CloneComponentGenerator(base_component='1'),
     }
     loader = FileLoader(**options)
     training = loader.get_training_datasets()
     assert len(loader) == 3
     # NOTE(review): this loop passes vacuously if there are no training
     # datasets. The two-name unpacking assumes get_training_datasets() yields
     # (id, dataset) pairs rather than bare dict keys. Confirm against FileLoader.
     for _, dataset in training:
         assert 'FOO' in dataset
         assert callable(dataset['FOO'])
コード例 #5
0
 def test_filter(self):
     """A filter rejecting ids that start with "C-" leaves 2 datasets."""
     def skip_c_prefixed(dataset_id, match_components, dataset):
         # Keep only datasets whose id does not begin with "C-".
         return not dataset_id.startswith("C-")

     loader = FileLoader(**{**self.base_kwargs, 'filter': skip_c_prefixed})
     assert len(loader) == 2
コード例 #6
0
 def test_pattern_partially_matching_input(self):
     """A data_pattern matching only part of the input yields 2 datasets."""
     # NOTE(review): the 'name' class uses A-B rather than A-Z. This looks
     # deliberate, so that a name starting with e.g. 'C' does not match.
     # Confirm before "fixing" it.
     partial_pattern = r'(?P<name>[0-9A-Ba-z]+)_(?P<SECOND>[A-Za-z0-9]+)_(?P<THIRD>[A-Za-z0-9]+)\.txt$'
     loader = FileLoader(**{**self.base_kwargs, 'data_pattern': partial_pattern})
     assert len(loader) == 2
コード例 #7
0
 def test_missing_pattern_raises_error(self):
     """Constructing FileLoader without 'data_pattern' raises ValueError."""
     kwargs5 = self.base_kwargs.copy()
     # Delete from the copy so self.base_kwargs itself keeps its 'data_pattern'.
     del kwargs5['data_pattern']
     # The exception info was never inspected, so no `as` binding is needed.
     with pytest.raises(ValueError):
         FileLoader(**kwargs5)
コード例 #8
0
 def test_detected_from_input_descriptor(self):
     """An open file object given as input_source is read, yielding 3 datasets."""
     # Rewind first; an earlier read of the shared temp file may have moved
     # its position.
     self.tempf.seek(0)
     loader = FileLoader(**{**self.base_kwargs, 'input_source': self.tempf})
     assert len(loader) == 3
コード例 #9
0
 def test_detected_from_input_file(self):
     """A file path given as input_source is read, yielding 3 datasets."""
     loader = FileLoader(**{**self.base_kwargs, 'input_source': self.tempf.name})
     assert len(loader) == 3