コード例 #1
0
ファイル: test_file_group.py プロジェクト: hubayirp/kaishi
def test_configure_pipeline():
    """Check that a configured component shows up in the pipeline."""
    group = FileGroup(recursive=True)
    group.configure_pipeline(["FilterDuplicateFiles"])
    # Components are identified by the name of their class.
    assert any(
        step.__class__.__name__ == "FilterDuplicateFiles"
        for step in group.pipeline.components
    )
コード例 #2
0
def test_by_regex():
    """Check that FilterByRegex removes exactly one file for 'sample.jpg'."""
    group = FileGroup(recursive=True)
    group.load_dir("tests/data/image", File, recursive=True)
    group.configure_pipeline(["FilterByRegex"])
    regex_filter = group.pipeline.components[0]
    regex_filter.configure(pattern="sample.jpg")
    count_before = len(group.files)
    group.run_pipeline()
    count_after = len(group.files)
    assert count_before - count_after == 1
コード例 #3
0
def test_subsample():
    """Check that FilterSubsample with N=2 leaves exactly two files."""
    keep = 2
    group = FileGroup(recursive=True)
    group.load_dir("tests/data/image", File, recursive=True)
    group.configure_pipeline(["FilterSubsample"])
    # The fixture must hold more files than we keep, or the test proves nothing.
    assert len(group.files) > keep
    subsampler = group.pipeline.components[0]
    subsampler.configure(N=keep)
    group.run_pipeline()
    assert len(group.files) == keep
コード例 #4
0
    def __new__(self, source: str = None, recursive: bool = False):
        """Build a ``FileGroup`` loaded from the path given by ``source``.

        The object returned is a ``FileGroup``, not an instance of the
        enclosing class. The first argument (named ``self`` here) is the
        class object Python passes to ``__new__``; it is not used.

        :param source: string identifying source (e.g. directory)
        :type source: str
        :param recursive: flag to set recursion
        :type recursive: bool
        :raises NotImplementedError: if ``source`` is ``None`` or does not
            exist as a path
        :return: a ``FileGroup`` on which ``load_dir`` has been called
            for ``source``
        """
        # os.path.exists(None) raises TypeError, so the default source=None
        # is rejected here with the same error as any other unsupported input.
        if source is None or not os.path.exists(source):
            raise NotImplementedError(
                "Currently only supports a valid path as input")

        file_dataset = FileGroup(recursive=recursive)
        file_dataset.load_dir(source, File, file_dataset.recursive)
        return file_dataset
コード例 #5
0
def test_by_label():
    """Check that FilterByLabel removes the single file labelled 'TRAIN'."""
    label = "TRAIN"
    group = FileGroup(recursive=True)
    group.load_dir("tests/data/image", File, recursive=True)
    group.configure_pipeline(["FilterByLabel"])
    label_filter = group.pipeline.components[0]
    label_filter.configure(label_to_filter=label)
    # Label exactly one file so exactly one should be filtered out.
    group.files[0].add_label(label)
    count_before = len(group.files)
    group.run_pipeline()
    assert count_before - len(group.files) == 1
コード例 #6
0
ファイル: test_file_group.py プロジェクト: hubayirp/kaishi
def test_file_report():
    """Check that the file report printed to stdout mentions 'sample.jpg'."""
    from contextlib import redirect_stdout

    test = FileGroup(recursive=True)
    test.load_dir("tests/data/image", File, True)
    print_capture = StringIO()
    # redirect_stdout restores the previous sys.stdout even if file_report()
    # raises, and puts back whatever stream was installed before (for
    # example a test runner's capture) rather than forcing sys.__stdout__.
    with redirect_stdout(print_capture):
        test.file_report()
    assert "sample.jpg" in print_capture.getvalue()
コード例 #7
0
def test_validation_and_test():
    """Check the TRAIN / VALIDATE / TEST split from LabelerValidationAndTest.

    With val_frac and test_frac both 0.2, the validation and test sets are
    each expected to hold round(0.2 * total) files, and the rest are TRAIN.
    """
    group = FileGroup(recursive=True)
    group.load_dir("tests/data", File, recursive=True)
    group.configure_pipeline(["LabelerValidationAndTest"])
    group.pipeline.components[0].configure(val_frac=0.2, test_frac=0.2)
    group.run_pipeline()

    # Each file is counted once, under the first label it carries in this
    # order.
    label_order = ("TRAIN", "VALIDATE", "TEST")
    tally = dict.fromkeys(label_order, 0)
    for entry in group.files:
        for label in label_order:
            if entry.has_label(label):
                tally[label] += 1
                break

    total = len(group.files)
    assert tally["TRAIN"] == total - 2 * round(total * 0.2)
    assert tally["VALIDATE"] == round(total * 0.2)
    assert tally["TEST"] == round(total * 0.2)
コード例 #8
0
ファイル: test_file_group.py プロジェクト: hubayirp/kaishi
def test_init_and_load_dir():
    """Check that loading the image fixture directory yields at least one file."""
    group = FileGroup(recursive=True)
    group.load_dir("tests/data/image", File, True)
    loaded_count = len(group.files)
    assert loaded_count > 0
コード例 #9
0
ファイル: test_file_group.py プロジェクト: hubayirp/kaishi
def test_get_pipeline_options():
    """Check that a FileGroup reports at least one pipeline option."""
    group = FileGroup(recursive=True)
    options = group.get_pipeline_options()
    assert len(options) > 0
コード例 #10
0
def test_duplicates():
    """Check that FilterDuplicateFiles records at least one duplicate."""
    group = FileGroup(recursive=True)
    group.load_dir("tests/data/image", File, recursive=True)
    group.configure_pipeline(["FilterDuplicateFiles"])
    group.run_pipeline()
    # Filtered files are read from the "duplicates" key of `filtered`.
    duplicates = group.filtered["duplicates"]
    assert len(duplicates) > 0