예제 #1
0
def _load_negative_segment(root_path: str) -> Segment:
    """Create the "negative" segment from ``negatives/negativePics/*.png``.

    Arguments:
        root_path: Root directory of the dataset.

    Returns:
        A segment named "negative" with one ``Data`` per matching image, ordered
        by file path, each carrying an empty ``box2d`` label list.

    """
    segment = Segment("negative")
    pattern = os.path.join(root_path, "negatives", "negativePics", "*.png")
    # glob() yields files in filesystem order; sort so the segment is reproducible.
    for negative_image_path in sorted(glob(pattern)):
        data = Data(negative_image_path)
        # An empty (rather than unset) box2d list presumably records "no objects"
        # for this negative sample.
        data.label.box2d = []
        segment.append(data)
    return segment
예제 #2
0
def _get_segment(path: str, segment_name: str) -> Segment:
    """Create a segment holding every ``*.png`` image of one sub-folder.

    Arguments:
        path: Root directory of the dataset.
        segment_name: Name of the segment; also the sub-folder of ``path`` to scan.

    Returns:
        A segment with one ``Data`` per matching image, ordered by file path.

    """
    segment = Segment(segment_name)
    # glob() yields files in filesystem order; sort so the segment is reproducible.
    for image_path in sorted(glob(os.path.join(path, segment_name, "*.png"))):
        segment.append(Data(image_path))
    return segment
예제 #3
0
    def test_upload_segment_with_label(self, accesskey, url, tmp_path):
        """Upload a labeled segment into a draft and read it back from the server."""
        gas_client = GAS(access_key=accesskey, url=url)
        dataset_name = get_dataset_name()
        dataset_client = gas_client.create_dataset(dataset_name)
        # Delete the remote dataset even when an assertion fails, so failed runs
        # do not leave orphaned datasets behind.
        try:
            dataset_client.create_draft("draft-1")
            # When uploading label, upload catalog first.
            dataset_client.upload_catalog(Catalog.loads(CATALOG))

            segment = Segment("segment1")
            path = tmp_path / "sub"
            path.mkdir()
            for i in range(10):
                local_path = path / f"hello{i}.txt"
                local_path.write_text("CONTENT")
                data = Data(local_path=str(local_path))
                data.label = Label.loads(LABEL)
                segment.append(data)

            dataset_client.upload_segment(segment)
            segment1 = Segment(name="segment1", client=dataset_client)
            assert len(segment1) == 10
            assert segment1[0].path == "hello0.txt"
            assert segment1[0].path == segment[0].target_remote_path
            assert segment1[0].label
            # TODO: compare the uploaded label with the downloaded one.
        finally:
            gas_client.delete_dataset(dataset_name)
예제 #4
0
    def test_upload_segment_without_file(self, accesskey, url):
        """Uploading an empty segment creates a remote segment with no data."""
        gas_client = GAS(access_key=accesskey, url=url)
        dataset_name = get_dataset_name()
        dataset_client = gas_client.create_dataset(dataset_name)
        # Delete the remote dataset even when an assertion fails.
        try:
            dataset_client.create_draft("test")

            segment = Segment("segment1")

            dataset_client.upload_segment(segment)
            segment1 = Segment(name="segment1", client=dataset_client)
            assert len(segment1) == 0
        finally:
            gas_client.delete_dataset(dataset_name)
    def test_getitem(self):
        """Segments are reachable both by position and by name.

        Positions follow segment-name order ("test" before "train"), not the
        order in which the segments were added.
        """
        dataset = DatasetBase("test_name")
        by_name = {name: Segment(name) for name in ("train", "test")}
        for seg in by_name.values():
            dataset.add_segment(seg)

        for index, name in enumerate(("test", "train")):
            assert by_name[name] is dataset[index] is dataset[name]

        with pytest.raises(IndexError):
            dataset[2]
        with pytest.raises(KeyError):
            dataset["unknown"]
 def test_upload_segment(self, mocker):
     """upload_segment delegates to _upload_segment with jobs=1 and no skipping."""
     self.dataset_client._status.checkout(draft_number=1)
     segment_test = Segment(name="test1")
     for data_path in (f"data{i}.png" for i in range(5)):
         segment_test.append(Data(data_path))
     segment_client = SegmentClient(name="test1", data_client=self.dataset_client)
     patched = mocker.patch(
         f"{dataset.__name__}.DatasetClient._upload_segment", return_value=segment_client
     )

     returned = self.dataset_client.upload_segment(segment_test)
     assert returned.name == "test1"

     positional, named = patched.call_args
     assert positional[0] == segment_test
     assert named["jobs"] == 1
     assert not named["skip_uploaded_files"]
     patched.assert_called_once()
예제 #7
0
    def test_copy_data(self, accesskey, url, tmp_path):
        """Copy data inside a segment; strategy="push" raises InvalidParamsError."""
        gas_client = GAS(access_key=accesskey, url=url)
        dataset_name = get_dataset_name()
        gas_client.create_dataset(dataset_name)
        # Delete the remote dataset even when an assertion fails.
        try:
            dataset = Dataset(name=dataset_name)
            segment = dataset.create_segment("Segment1")
            dataset._catalog = Catalog.loads(CATALOG)
            path = tmp_path / "sub"
            path.mkdir()
            for i in range(10):
                local_path = path / f"hello{i}.txt"
                local_path.write_text("CONTENT")
                data = Data(local_path=str(local_path))
                data.label = Label.loads(LABEL)
                segment.append(data)

            dataset_client = gas_client.upload_dataset(dataset)
            segment_client = dataset_client.get_segment("Segment1")
            segment_client.copy_data("hello0.txt", "goodbye0.txt")
            segment_client.copy_data("hello1.txt", "hello10.txt")

            with pytest.raises(InvalidParamsError):
                segment_client.copy_data("hello2.txt", "see_you.txt", strategy="push")

            # The indices below assume the remote listing is sorted by path:
            # goodbye0, hello0, hello1, hello10, ...
            segment2 = Segment("Segment1", client=dataset_client)
            assert segment2[0].path == "goodbye0.txt"
            assert segment2[3].path == "hello10.txt"
            assert segment2[1].label
        finally:
            gas_client.delete_dataset(dataset_name)
예제 #8
0
    def test_copy_segment_skip(self, accesskey, url, tmp_path):
        """copy_segment with strategy="skip" leaves the existing target segment as is."""
        gas_client = GAS(access_key=accesskey, url=url)
        dataset_name = get_dataset_name()
        gas_client.create_dataset(dataset_name)
        # Delete the remote dataset even when an assertion fails.
        try:
            dataset = Dataset(name=dataset_name)
            segment1 = dataset.create_segment("Segment1")
            dataset._catalog = Catalog.loads(CATALOG)
            path = tmp_path / "sub"
            path.mkdir()
            for i in range(10):
                local_path = path / f"hello{i}.txt"
                local_path.write_text("CONTENT")
                data = Data(local_path=str(local_path))
                data.label = Label.loads(LABEL)
                segment1.append(data)

            segment2 = dataset.create_segment("Segment2")
            for i in range(10, 20):
                local_path = path / f"hello{i}.txt"
                local_path.write_text("CONTENT")
                data = Data(local_path=str(local_path))
                data.label = Label.loads(LABEL)
                segment2.append(data)

            dataset_client = gas_client.upload_dataset(dataset)
            dataset_client.copy_segment("Segment1", "Segment2", strategy="skip")

            # Segment2 still starts with its own data, not the copied hello0.txt.
            segment_copied = Segment("Segment2", client=dataset_client)
            assert segment_copied[0].path == "hello10.txt"
            assert segment_copied[0].path == segment2[0].target_remote_path
            assert segment_copied[0].label
        finally:
            gas_client.delete_dataset(dataset_name)
예제 #9
0
    def test_move_data_skip(self, accesskey, url, tmp_path):
        """move_data with strategy="skip" keeps the content of an existing target file."""
        gas_client = GAS(access_key=accesskey, url=url)
        dataset_name = get_dataset_name()
        gas_client.create_dataset(dataset_name)
        # Delete the remote dataset even when an assertion fails.
        try:
            dataset = Dataset(name=dataset_name)
            segment = dataset.create_segment("Segment1")
            dataset._catalog = Catalog.loads(CATALOG)
            path = tmp_path / "sub"
            path.mkdir()
            for i in range(10):
                local_path = path / f"hello{i}.txt"
                # Distinct contents make it possible to tell which file survived.
                local_path.write_text(f"CONTENT_{i}")
                data = Data(local_path=str(local_path))
                data.label = Label.loads(LABEL)
                segment.append(data)

            dataset_client = gas_client.upload_dataset(dataset)
            segment_client = dataset_client.get_segment("Segment1")

            segment_client.move_data("hello0.txt", "hello1.txt", strategy="skip")

            segment_moved = Segment("Segment1", client=dataset_client)
            assert segment_moved[0].path == "hello1.txt"
            assert segment_moved[0].open().read() == b"CONTENT_1"
        finally:
            gas_client.delete_dataset(dataset_name)
예제 #10
0
    def test_upload_dataset_to_given_draft(self, accesskey, url, tmp_path):
        """upload_dataset(draft_number=...) uploads into that draft; unknown drafts fail."""
        gas_client = GAS(access_key=accesskey, url=url)
        dataset_name = get_random_dataset_name()
        dataset_client_1 = gas_client.create_dataset(dataset_name)
        # Delete the remote dataset even when an assertion fails.
        try:
            draft_number = dataset_client_1.create_draft("test")

            dataset = Dataset(name=dataset_name)
            segment = dataset.create_segment("Segment1")

            path = tmp_path / "sub"
            path.mkdir()
            for i in range(10):
                local_path = path / f"hello{i}.txt"
                local_path.write_text("CONTENT")
                segment.append(Data(local_path=str(local_path)))

            dataset_client_2 = gas_client.upload_dataset(dataset, draft_number=draft_number)
            segment1 = Segment("Segment1", client=dataset_client_2)
            assert len(segment1) == 10
            assert segment1[0].path == "hello0.txt"
            assert not segment1[0].label

            # Only one draft was created above, so the next number does not exist.
            with pytest.raises(GASResponseError):
                gas_client.upload_dataset(dataset, draft_number=draft_number + 1)
        finally:
            gas_client.delete_dataset(dataset_name)
예제 #11
0
    def test_copy_data_between_datasets(self, accesskey, url, tmp_path):
        """Copy a data item from a committed dataset into another dataset's draft."""
        gas_client = GAS(access_key=accesskey, url=url)
        dataset_name_1 = get_dataset_name()
        gas_client.create_dataset(dataset_name_1)
        # Nested try/finally: every dataset that was actually created is deleted,
        # even when a later step or assertion fails.
        try:
            dataset_1 = Dataset(name=dataset_name_1)
            segment_1 = dataset_1.create_segment("Segment1")
            dataset_1._catalog = Catalog.loads(CATALOG)
            path = tmp_path / "sub"
            path.mkdir()
            for i in range(10):
                local_path = path / f"hello{i}.txt"
                local_path.write_text("CONTENT")
                data = Data(local_path=str(local_path))
                data.label = Label.loads(LABEL)
                segment_1.append(data)
            dataset_client_1 = gas_client.upload_dataset(dataset_1)
            dataset_client_1.commit("upload data")
            segment_client_1 = dataset_client_1.get_segment("Segment1")

            dataset_name_2 = dataset_name_1 + "_2"
            dataset_client_2 = gas_client.create_dataset(dataset_name_2)
            try:
                dataset_client_2.create_draft("draft_2")
                dataset_client_2.create_segment("Segment1")
                segment_client_2 = dataset_client_2.get_segment("Segment1")

                segment_client_2.copy_data(
                    "hello0.txt", "hello0.txt", source_client=segment_client_1
                )

                segment2 = Segment("Segment1", client=dataset_client_2)
                assert segment2[0].path == "hello0.txt"
                assert segment2[0].label
            finally:
                gas_client.delete_dataset(dataset_name_2)
        finally:
            gas_client.delete_dataset(dataset_name_1)
예제 #12
0
    def test_upload_dataset_only_with_file(self, accesskey, url, tmp_path):
        """Upload files only (no catalog, no labels) and check draft state and notes."""
        gas_client = GAS(access_key=accesskey, url=url)
        dataset_name = get_dataset_name()
        gas_client.create_dataset(dataset_name)
        # Delete the remote dataset even when an assertion fails.
        try:
            dataset = Dataset(name=dataset_name)
            dataset.notes.is_continuous = True
            segment = dataset.create_segment("Segment1")

            path = tmp_path / "sub"
            path.mkdir()
            for i in range(10):
                local_path = path / f"hello{i}.txt"
                local_path.write_text("CONTENT")
                segment.append(Data(local_path=str(local_path)))

            # Without a draft/branch argument, the upload lands in an uncommitted
            # draft on the default branch.
            dataset_client = gas_client.upload_dataset(dataset)
            assert dataset_client.status.branch_name == DEFAULT_BRANCH
            assert dataset_client.status.draft_number
            assert not dataset_client.status.commit_id

            assert dataset_client.get_notes().is_continuous is True
            assert not dataset_client.get_catalog()
            segment1 = Segment("Segment1", client=dataset_client)
            assert len(segment1) == 10
            assert segment1[0].path == "hello0.txt"
            assert not segment1[0].label
        finally:
            gas_client.delete_dataset(dataset_name)
예제 #13
0
    def test_move_segment(self, accesskey, url, tmp_path):
        """Rename a segment with move_segment; strategy="push" raises InvalidParamsError."""
        gas_client = GAS(access_key=accesskey, url=url)
        dataset_name = get_dataset_name()
        gas_client.create_dataset(dataset_name)
        # Delete the remote dataset even when an assertion fails.
        try:
            dataset = Dataset(name=dataset_name)
            segment = dataset.create_segment("Segment1")
            dataset._catalog = Catalog.loads(CATALOG)
            path = tmp_path / "sub"
            path.mkdir()
            for i in range(10):
                local_path = path / f"hello{i}.txt"
                local_path.write_text("CONTENT")
                data = Data(local_path=str(local_path))
                data.label = Label.loads(LABEL)
                segment.append(data)

            dataset_client = gas_client.upload_dataset(dataset)
            segment_client = dataset_client.move_segment("Segment1", "Segment2")
            assert segment_client.name == "Segment2"

            with pytest.raises(InvalidParamsError):
                dataset_client.move_segment("Segment1", "Segment3", strategy="push")

            segment2 = Segment("Segment2", client=dataset_client)
            assert segment2[0].path == "hello0.txt"
            assert segment2[0].path == segment[0].target_remote_path
            assert segment2[0].label
        finally:
            gas_client.delete_dataset(dataset_name)
    def test_create_and_upload_dataset_with_config(self, accesskey, url, tmp_path):
        """Create a dataset with the _LOCAL_CONFIG_NAME storage config and upload labels."""
        gas_client = GAS(access_key=accesskey, url=url)
        dataset_name = get_dataset_name()
        try:
            gas_client.get_auth_storage_config(name=_LOCAL_CONFIG_NAME)
        except ResourceNotExistError:
            pytest.skip(f"skip this case because there's no {_LOCAL_CONFIG_NAME} config")

        gas_client.create_dataset(dataset_name, config_name=_LOCAL_CONFIG_NAME)
        # Delete the remote dataset even when an assertion fails.
        try:
            dataset = Dataset(name=dataset_name)
            segment = dataset.create_segment("Segment1")
            # When uploading label, upload catalog first.
            dataset._catalog = Catalog.loads(CATALOG)

            path = tmp_path / "sub"
            path.mkdir()
            for i in range(5):
                local_path = path / f"hello{i}.txt"
                local_path.write_text("CONTENT")
                data = Data(local_path=str(local_path))
                data.label = Label.loads(LABEL)
                segment.append(data)

            dataset_client = gas_client.upload_dataset(dataset)
            assert dataset_client.get_catalog()
            segment1 = Segment("Segment1", client=dataset_client)
            assert len(segment1) == 5
            for i in range(5):
                assert segment1[i].path == f"hello{i}.txt"
                assert segment1[i].label
        finally:
            gas_client.delete_dataset(dataset_name)
예제 #15
0
    def test_import_cloud_files(self, accesskey, url, config_name):
        """Import cloud-storage files (auth data) into a dataset and commit them."""
        gas_client = GAS(access_key=accesskey, url=url)
        try:
            cloud_client = gas_client.get_cloud_client(config_name)
        except ResourceNotExistError:
            pytest.skip(f"skip this case because there's no {config_name} config")

        auth_data = cloud_client.list_auth_data("tests")
        dataset_name = get_dataset_name()
        # The returned client is not needed: upload_dataset returns the one used below.
        gas_client.create_dataset(dataset_name, config_name=config_name)
        # Delete the remote dataset even when an assertion fails.
        try:
            dataset = Dataset(name=dataset_name)
            segment = dataset.create_segment("Segment1")
            for data in auth_data:
                segment.append(data)

            dataset_client = gas_client.upload_dataset(dataset, jobs=5)
            dataset_client.commit("import data")

            segment1 = Segment("Segment1", client=dataset_client)
            assert len(segment1) == len(segment)
            # The remote path is expected to be the last component of the cloud path.
            assert segment1[0].path == segment[0].path.split("/")[-1]
            assert not segment1[0].label

            assert len(auth_data) == len(segment)
        finally:
            gas_client.delete_dataset(dataset_name)
    def test_delitem(self):
        """Deleting by slice, index and name mirrors the same deletions on a list."""
        dataset = DatasetBase("test_name")
        expected = [Segment(str(i)) for i in range(5)]
        for seg in expected:
            dataset.add_segment(seg)

        def assert_same_segments():
            # The dataset must hold exactly the segments left in ``expected``, in order.
            assert len(dataset) == len(expected)
            for actual_segment, expected_segment in zip(dataset, expected):
                assert actual_segment is expected_segment

        del expected[1:3]
        del dataset[1:3]
        assert_same_segments()

        del expected[1]
        del dataset[1]
        assert_same_segments()

        del expected[expected.index(dataset["4"])]
        del dataset["4"]
        assert_same_segments()

        # An out-of-range slice deletes nothing, just like on a list.
        del dataset[100:200]
        assert_same_segments()

        with pytest.raises(IndexError):
            del dataset[100]

        with pytest.raises(KeyError):
            del dataset["100"]
예제 #17
0
    def test_upload_dataset_with_label(self, accesskey, url, tmp_path):
        """Upload a dataset with a catalog and labeled data, then read it back."""
        gas_client = GAS(access_key=accesskey, url=url)
        dataset_name = get_dataset_name()
        gas_client.create_dataset(dataset_name)
        # Delete the remote dataset even when an assertion fails.
        try:
            dataset = Dataset(name=dataset_name)
            segment = dataset.create_segment("Segment1")
            # When uploading label, upload catalog first.
            dataset._catalog = Catalog.loads(CATALOG)

            path = tmp_path / "sub"
            path.mkdir()
            for i in range(10):
                local_path = path / f"hello{i}.txt"
                local_path.write_text("CONTENT")
                data = Data(local_path=str(local_path))
                data.label = Label.loads(LABEL)
                segment.append(data)

            dataset_client = gas_client.upload_dataset(dataset)
            assert dataset_client.get_catalog()
            segment1 = Segment("Segment1", client=dataset_client)
            assert len(segment1) == 10
            assert segment1[0].path == "hello0.txt"
            assert segment1[0].label
        finally:
            gas_client.delete_dataset(dataset_name)
예제 #18
0
    def test__upload_segment(self, mocker):
        """_upload_segment filters already-uploaded paths only when skipping is requested."""
        segment_test = Segment(name="test1")
        for i in range(5):
            segment_test.append(Data(f"data{i}.png"))
        segment_client = SegmentClient(name="test1", data_client=self.dataset_client)
        get_or_create_segment = mocker.patch(
            f"{dataset.__name__}.DatasetClient.get_or_create_segment",
            return_value=segment_client,
        )
        list_data_paths = mocker.patch(
            f"{segment.__name__}.SegmentClient.list_data_paths",
            return_value=["data1.png", "data2.png"],
        )
        multithread_upload = mocker.patch(f"{dataset.__name__}.multithread_upload")

        def check_upload_call(progress_bar, expected_paths):
            # Inspect the most recent multithread_upload call.
            args, keywords = multithread_upload.call_args
            assert args[0] == segment_client._upload_or_import_data
            assert [item.path for item in args[1]] == expected_paths
            assert keywords["callback"] == segment_client._synchronize_upload_info
            assert keywords["jobs"] == 1
            assert keywords["pbar"] == progress_bar

        # Skipping enabled: paths reported by list_data_paths are left out.
        with Tqdm(5, disable=False) as pbar:
            self.dataset_client._upload_segment(
                segment_test, skip_uploaded_files=True, pbar=pbar
            )
            get_or_create_segment.assert_called_once_with(segment_test.name)
            list_data_paths.assert_called_once_with()
            check_upload_call(pbar, ["data0.png", "data3.png", "data4.png"])
            multithread_upload.assert_called_once()

        # Skipping disabled: every data item is handed to the uploader.
        with Tqdm(5, disable=False) as pbar:
            self.dataset_client._upload_segment(
                segment_test, skip_uploaded_files=False, pbar=pbar
            )
            get_or_create_segment.assert_called_with(segment_test.name)
            list_data_paths.assert_called_with()
            check_upload_call(pbar, [f"data{i}.png" for i in range(5)])
            multithread_upload.assert_called()
예제 #19
0
def _load_positive_segment(segment_name: str, segment_path: str) -> Segment:
    """Load one positive segment whose images carry 2D box annotations.

    Annotations are read from ``frameAnnotations-*/frameAnnotations.csv`` under
    ``segment_path`` (";"-delimited). The images live in the same folder as that
    CSV file. Consecutive rows with the same ``Filename`` are merged into one
    ``Data`` whose ``label.box2d`` receives one ``LabeledBox2D`` per row.

    Arguments:
        segment_name: The segment name. Names starting with "vid" get their
            numeric suffix zero-padded to two digits ("vid0" -> "vid00").
        segment_path: The directory that contains the annotation folder.

    Returns:
        The loaded segment.

    Raises:
        IndexError: When no ``frameAnnotations-*/frameAnnotations.csv`` file is
            found under ``segment_path``.

    """
    if segment_name.startswith("vid"):
        # Pad zero for segment name to change "vid0" to "vid00"
        segment_name = f"{segment_name[:3]}{int(segment_name[3:]):02}"
    segment = Segment(segment_name)
    # glob() order is filesystem dependent; sort so that, if several annotation
    # folders match, the same one is picked on every run.
    annotation_file = sorted(
        glob(os.path.join(segment_path, "frameAnnotations-*", "frameAnnotations.csv"))
    )[0]
    image_folder = os.path.dirname(annotation_file)
    # None never equals a CSV string, so the first row always creates a Data.
    # With "" as the sentinel, an empty first Filename would leave ``data`` unbound.
    pre_filename = None
    # The csv module requires newline="" so it can handle line endings itself.
    with open(annotation_file, "r", encoding="utf-8", newline="") as fp:
        for annotation in csv.DictReader(fp, delimiter=";"):
            filename = annotation["Filename"]

            # Only consecutive rows for the same file are grouped into one Data.
            if filename != pre_filename:
                data = Data(os.path.join(image_folder, filename))
                data.label.box2d = []
                segment.append(data)
                pre_filename = filename

            # This single column holds two comma-separated integer flags.
            occluded, on_another_road = annotation["Occluded,On another road"].split(",", 1)
            data.label.box2d.append(
                LabeledBox2D(
                    int(annotation["Upper left corner X"]),
                    int(annotation["Upper left corner Y"]),
                    int(annotation["Lower right corner X"]),
                    int(annotation["Lower right corner Y"]),
                    category=annotation["Annotation tag"],
                    attributes={
                        "Occluded": bool(int(occluded)),
                        "On another road": bool(int(on_another_road)),
                        "Origin file": annotation["Origin file"],
                        "Origin frame number": int(annotation["Origin frame number"]),
                        "Origin track": annotation["Origin track"],
                        "Origin track frame number": int(
                            annotation["Origin track frame number"]
                        ),
                    },
                )
            )
    return segment
    def test_cache_dataset(self, accesskey, url, tmp_path):
        """Opening data and masks with the cache enabled writes them under the cache path."""
        gas_client = GAS(access_key=accesskey, url=url)
        dataset_name = get_dataset_name()
        gas_client.create_dataset(dataset_name)
        # Delete the remote dataset even when an assertion fails.
        try:
            dataset = Dataset(name=dataset_name)
            segment = dataset.create_segment("Segment1")
            # When uploading label, upload catalog first.
            dataset._catalog = Catalog.loads(_CATALOG)

            path = tmp_path / "sub"
            semantic_path = tmp_path / "semantic_mask"
            instance_path = tmp_path / "instance_mask"
            path.mkdir()
            semantic_path.mkdir()
            instance_path.mkdir()
            for i in range(_SEGMENT_LENGTH):
                local_path = path / f"hello{i}.txt"
                local_path.write_text("CONTENT")
                data = Data(local_path=str(local_path))
                data.label = Label.loads(_LABEL)

                semantic_mask = semantic_path / f"semantic_mask{i}.png"
                semantic_mask.write_text("SEMANTIC_MASK")
                data.label.semantic_mask = SemanticMask(str(semantic_mask))

                instance_mask = instance_path / f"instance_mask{i}.png"
                instance_mask.write_text("INSTANCE_MASK")
                data.label.instance_mask = InstanceMask(str(instance_mask))
                segment.append(data)

            dataset_client = gas_client.upload_dataset(dataset)
            dataset_client.commit("commit-1")
            cache_path = tmp_path / "cache_test"
            dataset_client.enable_cache(str(cache_path))
            # Opening each file through the client presumably fills the cache.
            segment1 = Segment("Segment1", client=dataset_client)
            for data in segment1:
                data.open()
                data.label.semantic_mask.open()
                data.label.instance_mask.open()

            # Expected layout: <cache>/<dataset_id>/<commit_id>/Segment1/, with the
            # masks in "semantic_mask" and "instance_mask" sub-directories.
            segment_cache_path = (
                cache_path
                / dataset_client.dataset_id
                / dataset_client.status.commit_id
                / "Segment1"
            )
            semantic_mask_cache_path = segment_cache_path / "semantic_mask"
            instance_mask_cache_path = segment_cache_path / "instance_mask"

            # Cached masks are expected to be named after their data file
            # (hello{i}.png), not after the local mask file.
            for cache_dir, extension in (
                (segment_cache_path, "txt"),
                (semantic_mask_cache_path, "png"),
                (instance_mask_cache_path, "png"),
            ):
                assert set(cache_dir.glob(f"*.{extension}")) == {
                    cache_dir / f"hello{i}.{extension}" for i in range(_SEGMENT_LENGTH)
                }
        finally:
            gas_client.delete_dataset(dataset_name)
    def test_contains(self):
        """``in`` is true for added segment names, false for unknown names or ints."""
        dataset = DatasetBase("test_name")
        names = ("test", "train")
        for name in names:
            dataset.add_segment(Segment(name))

        assert all(name in dataset for name in names)
        assert "val" not in dataset
        assert 100 not in dataset
예제 #22
0
    def _generate_segments(
        self, offset: int = 0, limit: int = 128
    ) -> Generator[Segment, None, int]:
        """Yield the segments of one page of the remote segment listing.

        Arguments:
            offset: Index of the first segment to request.
            limit: Maximum number of segments to request.

        Yields:
            A client-backed ``Segment`` per listed item, with its description set.

        Returns:
            The ``totalCount`` reported by the listing response (the generator's
            return value).

        """
        page = self._list_segments(offset, limit)

        for info in page["segments"]:
            client = SegmentClient(info["name"], self)
            result = Segment._from_client(client)  # pylint: disable=protected-access
            result.description = info["description"]
            yield result

        return page["totalCount"]  # type: ignore[no-any-return]
예제 #23
0
    def test_data_in_draft(self, accesskey, url, tmp_path):
        """Data stays listable and has a URL both after a commit and in a new draft."""
        gas_client = GAS(access_key=accesskey, url=url)
        dataset_name = get_dataset_name()
        dataset_client = gas_client.create_dataset(dataset_name)
        # Delete the remote dataset even when an assertion fails.
        try:
            dataset_client.create_draft("draft-1")
            segment = Segment("segment1")
            path = tmp_path / "sub"
            path.mkdir()
            for i in range(10):
                local_path = path / f"hello{i}.txt"
                local_path.write_text("CONTENT")
                data = Data(local_path=str(local_path))
                segment.append(data)

            dataset_client.upload_segment(segment)
            dataset_client.commit("commit-1")
            segment1 = Segment(name="segment1", client=dataset_client)
            assert len(segment1) == 10
            assert segment1[0].get_url()
            assert segment1[0].path == segment[0].target_remote_path

            dataset_client.create_draft("draft-2")
            segment1 = Segment(name="segment1", client=dataset_client)
            assert len(segment1) == 10
            assert segment1[0].get_url()
            assert segment1[0].path == segment[0].target_remote_path
        finally:
            gas_client.delete_dataset(dataset_name)
예제 #24
0
    def test_sort(self):
        """Segment.sort honours the ``key`` and ``reverse`` arguments."""
        segment = Segment("train")
        for file_path in ("file1", "file2"):
            segment.append(Data(file_path))

        assert segment[0].path == "file1"

        segment.sort(key=lambda item: item.path, reverse=True)
        assert segment[0].path == "file2"
예제 #25
0
    def test_upload_dataset_to_given_branch(self, accesskey, url, tmp_path):
        """upload_dataset(branch_name=...) uses that branch's draft, creating one if needed."""
        gas_client = GAS(access_key=accesskey, url=url)
        dataset_name = get_dataset_name()
        dataset_client_1 = gas_client.create_dataset(dataset_name)
        # Delete the remote dataset even when an assertion fails.
        try:
            dataset_client_1.create_draft("test")
            dataset_client_1.commit("test1")
            dataset_client_1.create_branch("dev")

            dataset = Dataset(name=dataset_name)
            segment = dataset.create_segment("Segment1")

            path = tmp_path / "sub"
            path.mkdir()
            for i in range(10):
                local_path = path / f"hello{i}.txt"
                local_path.write_text("CONTENT")
                segment.append(Data(local_path=str(local_path)))

            # "dev" has no open draft yet, so the upload is expected to open one.
            dataset_client_2 = gas_client.upload_dataset(dataset, branch_name="dev")
            assert dataset_client_2.status.branch_name == "dev"
            assert dataset_client_2.status.draft_number
            assert not dataset_client_2.status.commit_id

            segment1 = Segment("Segment1", client=dataset_client_2)
            assert len(segment1) == 10
            assert segment1[0].path == "hello0.txt"
            assert not segment1[0].label

            dataset_client_2.commit("test2")
            draft_number = dataset_client_2.create_draft("test2")

            for i in range(10):
                local_path = path / f"goodbye{i}.txt"
                local_path.write_text("CONTENT")
                segment.append(Data(local_path=str(local_path)))

            # Now a draft is open on "dev", so the upload must reuse it.
            dataset_client_2 = gas_client.upload_dataset(dataset, branch_name="dev")
            assert dataset_client_2.status.branch_name == "dev"
            assert dataset_client_2.status.draft_number == draft_number
            assert not dataset_client_2.status.commit_id

            with pytest.raises(ResourceNotExistError):
                gas_client.upload_dataset(dataset, branch_name="wrong")
        finally:
            gas_client.delete_dataset(dataset_name)
예제 #26
0
    def test_copy_data_from_commits(self, accesskey, url, tmp_path):
        """Copy a data item from a historical commit into a new draft of the same dataset."""
        gas_client = GAS(access_key=accesskey, url=url)
        dataset_name = get_dataset_name()
        gas_client.create_dataset(dataset_name)
        # Delete the remote dataset even when an assertion fails.
        try:
            dataset = Dataset(name=dataset_name)
            segment = dataset.create_segment("Segment1")
            dataset._catalog = Catalog.loads(CATALOG)
            path = tmp_path / "sub"
            path.mkdir()
            for i in range(10):
                local_path = path / f"hello{i}.txt"
                local_path.write_text("CONTENT")
                data = Data(local_path=str(local_path))
                data.label = Label.loads(LABEL)
                segment.append(data)

            dataset_client = gas_client.upload_dataset(dataset)
            dataset_client.commit("commit_1")

            # Second commit: the local segment now holds hello0 .. hello19.
            for i in range(10, 20):
                local_path = path / f"hello{i}.txt"
                local_path.write_text("CONTENT")
                data = Data(local_path=str(local_path))
                data.label = Label.loads(LABEL)
                segment.append(data)
            dataset_client = gas_client.upload_dataset(dataset)
            dataset_client.commit("commit_2")

            dataset_client_1 = gas_client.get_dataset(dataset_name)
            # NOTE(review): assumes list_commits() is newest-first, so [-1] is
            # "commit_1" — confirm.
            commit_id = dataset_client_1.list_commits()[-1].commit_id
            dataset_client_1.checkout(revision=commit_id)
            dataset_client.create_draft("draft_3")
            segment_client_1 = dataset_client_1.get_segment("Segment1")
            segment_client_2 = dataset_client.get_segment("Segment1")
            segment_client_2.copy_data(
                "hello0.txt", "goodbye0.txt", source_client=segment_client_1
            )

            # 20 uploaded items plus the copied "goodbye0.txt".
            segment2 = Segment("Segment1", client=dataset_client)
            assert segment2[0].path == "goodbye0.txt"
            assert segment2[0].path != segment[0].target_remote_path
            assert segment2[0].label
            assert len(segment2) == 21
        finally:
            gas_client.delete_dataset(dataset_name)
예제 #27
0
    def test_copy_between_datasets_override(self, accesskey, url, tmp_path):
        """copy_segment from another dataset with strategy="override" replaces the target."""
        gas_client = GAS(access_key=accesskey, url=url)
        dataset_name_1 = get_dataset_name()
        gas_client.create_dataset(dataset_name_1)
        # Nested try/finally: every dataset that was actually created is deleted,
        # even when a later step or assertion fails.
        try:
            dataset_1 = Dataset(name=dataset_name_1)
            segment_1 = dataset_1.create_segment("Segment")
            path = tmp_path / "sub"
            path.mkdir()
            for i in range(10):
                local_path = path / f"hello{i}.txt"
                local_path.write_text("CONTENT")
                data = Data(local_path=str(local_path))
                segment_1.append(data)
            dataset_client_1 = gas_client.upload_dataset(dataset_1)
            dataset_client_1.commit("upload data")

            dataset_name_2 = dataset_name_1 + "_2"
            # The returned client is not needed: upload_dataset returns the one used below.
            gas_client.create_dataset(dataset_name_2)
            try:
                dataset_2 = Dataset(name=dataset_name_2)
                segment_2 = dataset_2.create_segment("Segment")
                for i in range(10, 15):
                    local_path = path / f"hello{i}.txt"
                    local_path.write_text("CONTENT")
                    data = Data(local_path=str(local_path))
                    segment_2.append(data)
                dataset_client_2 = gas_client.upload_dataset(dataset_2)
                dataset_client_2.commit("upload data")
                dataset_client_2.create_draft("draft 2")
                dataset_client_2.copy_segment(
                    "Segment", source_client=dataset_client_1, strategy="override"
                )
                dataset_client_2.commit("copy segment")

                # The 5 original items were replaced by the 10 copied ones.
                segment = Segment("Segment", client=dataset_client_2)
                assert len(segment) == 10
                assert segment[0].path == "hello0.txt"
            finally:
                gas_client.delete_dataset(dataset_name_2)
        finally:
            gas_client.delete_dataset(dataset_name_1)
예제 #28
0
    def test_move_segment_override(self, accesskey, url, tmp_path):
        """move_segment with strategy="override" replaces the target and removes the source."""
        gas_client = GAS(access_key=accesskey, url=url)
        dataset_name = get_dataset_name()
        gas_client.create_dataset(dataset_name)
        # Delete the remote dataset even when an assertion fails.
        try:
            dataset = Dataset(name=dataset_name)
            segment1 = dataset.create_segment("Segment1")
            dataset._catalog = Catalog.loads(CATALOG)
            path = tmp_path / "sub"
            path.mkdir()
            for i in range(10):
                local_path = path / f"hello{i}.txt"
                local_path.write_text("CONTENT_1")
                data = Data(local_path=str(local_path))
                data.label = Label.loads(LABEL)
                segment1.append(data)

            segment2 = dataset.create_segment("Segment2")
            for i in range(10, 20):
                local_path = path / f"hello{i}.txt"
                local_path.write_text("CONTENT_2")
                data = Data(local_path=str(local_path))
                data.label = Label.loads(LABEL)
                segment2.append(data)

            dataset_client = gas_client.upload_dataset(dataset)
            dataset_client.move_segment("Segment1", "Segment2", strategy="override")

            with pytest.raises(ResourceNotExistError):
                dataset_client.get_segment("Segment1")

            # Segment2 now holds Segment1's former data (content "CONTENT_1").
            segment_moved = Segment("Segment2", client=dataset_client)
            assert segment_moved[0].path == "hello0.txt"
            assert segment_moved[0].path == segment1[0].target_remote_path
            assert segment_moved[0].open().read() == b"CONTENT_1"
            assert segment_moved[0].label
        finally:
            gas_client.delete_dataset(dataset_name)
예제 #29
0
    def test_upload_dataset_only_with_file(self, accesskey, url, tmp_path):
        """Upload a dataset with files only: no catalog and no labels come back."""
        gas_client = GAS(access_key=accesskey, url=url)
        dataset_name = get_random_dataset_name()
        gas_client.create_dataset(dataset_name)
        # Delete the remote dataset even when an assertion fails.
        try:
            dataset = Dataset(name=dataset_name)
            segment = dataset.create_segment("Segment1")

            path = tmp_path / "sub"
            path.mkdir()
            for i in range(10):
                local_path = path / f"hello{i}.txt"
                local_path.write_text("CONTENT")
                segment.append(Data(local_path=str(local_path)))

            dataset_client = gas_client.upload_dataset(dataset)
            assert not dataset_client.get_catalog()
            segment1 = Segment("Segment1", client=dataset_client)
            assert len(segment1) == 10
            assert segment1[0].path == "hello0.txt"
            assert not segment1[0].label
        finally:
            gas_client.delete_dataset(dataset_name)
예제 #30
0
def _get_segment(
    segment_name: str,
    local_abspaths: Iterable[str],
    remote_path: str,
    is_recursive: bool,
) -> Segment:
    """Build a segment that maps local files to their remote object paths.

    A plain file is placed directly under ``remote_path``. A directory is
    walked and its files are placed under ``remote_path/<directory name>/``,
    keeping their relative sub-directory structure. Directories require
    ``is_recursive``; otherwise ``error`` is called with a hint to use ``-r``.

    Arguments:
        segment_name: The name of the segment these data belong to.
        local_abspaths: A list of local absolute paths, each a folder or a file.
        remote_path: The remote object path, not necessarily ending with '/'.
        is_recursive: Whether to copy directories recursively.

    Returns:
        A segment whose data carry the computed ``target_remote_path``.

    """
    segment = Segment(segment_name)
    for local_abspath in local_abspaths:
        if os.path.isdir(local_abspath):
            if not is_recursive:
                error("Local paths include directories, please use -r option")

            top = os.path.normpath(local_abspath)
            folder_name = os.path.basename(top)
            for root, _, filenames in os.walk(top):
                relative_dir = "" if root == top else os.path.relpath(root, top)
                # Path() splits on the native separator; converting to PurePosixPath
                # then gives a "/"-separated remote key.
                remote_dir = Path(remote_path, folder_name, relative_dir)
                for filename in filenames:
                    segment.append(
                        Data(
                            os.path.join(root, filename),
                            target_remote_path=str(PurePosixPath(remote_dir / filename)),
                        )
                    )
        else:
            target = PurePosixPath(remote_path, os.path.basename(local_abspath))
            segment.append(Data(local_abspath, target_remote_path=str(target)))
    return segment