예제 #1
0
 def _info(self):
     """Return the dataset metadata describing the two array features.

     ``channels`` is declared as a 4D float32 array of shape
     (4, 240, 240, 155); ``segmentation`` is a 3D float32 array of shape
     (240, 240, 155).
     """
     # Both features share the same trailing three dimensions.
     volume_shape = (240, 240, 155)
     schema = datasets.Features({
         "channels":
         datasets.Array4D(shape=(4, ) + volume_shape, dtype='float32'),
         "segmentation":
         datasets.Array3D(shape=volume_shape, dtype='float32'),
     })
     return datasets.DatasetInfo(features=schema)
예제 #2
0
 def _info(self):
     """Return the dataset metadata: a 32x32x3 uint8 image and a 10-way label.

     ``("img", "label")`` is declared as the supervised (input, target) pair.
     """
     class_names = [
         "airplane",
         "automobile",
         "bird",
         "cat",
         "deer",
         "dog",
         "frog",
         "horse",
         "ship",
         "truck",
     ]
     schema = datasets.Features({
         "img": datasets.Array3D(shape=(32, 32, 3), dtype="uint8"),
         "label": datasets.features.ClassLabel(names=class_names),
     })
     return datasets.DatasetInfo(
         description=_DESCRIPTION,
         features=schema,
         supervised_keys=("img", "label"),
         homepage="https://www.cs.toronto.edu/~kriz/cifar.html",
         citation=_CITATION,
     )
예제 #3
0
    def __init__(self, config):
        """Load the raw dataset described by ``config`` and prepare it.

        Steps, in order:
          1. Copy the column names and ``channels_first_input`` from config.
          2. Instantiate every transform listed in ``config.transform_args``
             and compose them (``self.transform`` is ``None`` if none exist).
          3. Load the dataset via ``load_dataset``, optionally drop columns,
             and switch it to the "torch" format.
          4. Map ``self.prepare_features`` over it with an explicit feature
             schema, then rename the image/label columns to "image"/"label".
        """
        self.config = config
        self.image_column_name = config.image_column_name
        self.label_column_name = config.label_column_name
        self.channels_first_input = config.channels_first_input

        # Build each transform from its registered type and optional params.
        built = []
        for spec in config.transform_args:
            raw_params = spec["params"]
            kwargs = {} if raw_params is None else dict(raw_params)
            factory = configmapper.get_object("transforms", spec["type"])
            built.append(factory(**kwargs))
        self.transform = transforms.Compose(built) if built else None

        self.raw_dataset = load_dataset(**config.load_dataset_args)
        columns_to_drop = config.remove_columns
        if columns_to_drop is not None:
            self.raw_dataset = self.raw_dataset.remove_columns(
                columns_to_drop)
        # Column names are taken from the "train" split.
        train_columns = self.raw_dataset["train"].column_names
        self.raw_dataset.set_format("torch", columns=train_columns)

        # Output schema for the mapped dataset.
        feature_cfg = self.config.features
        image_feature = datasets.Array3D(
            shape=tuple(feature_cfg.image_output_shape),
            dtype="float32",
        )
        label_feature = datasets.features.ClassLabel(
            names=list(feature_cfg.label_names))
        features = datasets.Features({
            self.image_column_name: image_feature,
            self.label_column_name: label_feature,
        })

        self.train_dataset = self.raw_dataset.map(
            self.prepare_features,
            features=features,
            batched=True,
            batch_size=64,
        )

        # Normalise column names so downstream code can rely on them.
        renames = (
            (self.image_column_name, "image"),
            (self.label_column_name, "label"),
        )
        for current_name, canonical_name in renames:
            if current_name != canonical_name:
                self.train_dataset = self.train_dataset.rename_column(
                    current_name, canonical_name)

        self.train_dataset.set_format("torch", columns=["image", "label"])
예제 #4
0
 def _info(self):
     """Return the dataset metadata for token-level form understanding.

     Each example has a string id, a sequence of string tokens, one int64
     sequence per token under ``bboxes``, a sequence of NER class labels
     and a uint8 image array of shape (3, 224, 224).
     """
     tag_names = [
         "O", "B-HEADER", "I-HEADER", "B-QUESTION",
         "I-QUESTION", "B-ANSWER", "I-ANSWER"
     ]
     int_sequence = datasets.Sequence(datasets.Value("int64"))
     schema = datasets.Features({
         "id": datasets.Value("string"),
         "tokens": datasets.Sequence(datasets.Value("string")),
         "bboxes": datasets.Sequence(int_sequence),
         "ner_tags": datasets.Sequence(
             datasets.features.ClassLabel(names=tag_names)),
         "image": datasets.Array3D(shape=(3, 224, 224), dtype="uint8"),
     })
     return datasets.DatasetInfo(
         description=_DESCRIPTION,
         features=schema,
         supervised_keys=None,
         homepage="https://guillaumejaume.github.io/FUNSD/",
         citation=_CITATION,
     )
예제 #5
0
파일: xfun.py 프로젝트: microsoft/unilm
 def _info(self):
     """Return the dataset metadata for the token, entity and relation fields.

     Besides the token-level fields (``input_ids``, ``bbox``, ``labels``)
     and a (3, 224, 224) uint8 image, each example declares a sequence of
     ``entities`` (start, end, label) and a sequence of ``relations``
     (head, tail, start_index, end_index).
     """
     token_label_names = [
         "O", "B-QUESTION", "B-ANSWER", "B-HEADER", "I-ANSWER",
         "I-QUESTION", "I-HEADER"
     ]
     entity_schema = {
         "start": datasets.Value("int64"),
         "end": datasets.Value("int64"),
         "label": datasets.ClassLabel(
             names=["HEADER", "QUESTION", "ANSWER"]),
     }
     relation_schema = {
         "head": datasets.Value("int64"),
         "tail": datasets.Value("int64"),
         "start_index": datasets.Value("int64"),
         "end_index": datasets.Value("int64"),
     }
     schema = datasets.Features({
         "id": datasets.Value("string"),
         "input_ids": datasets.Sequence(datasets.Value("int64")),
         "bbox": datasets.Sequence(
             datasets.Sequence(datasets.Value("int64"))),
         "labels": datasets.Sequence(
             datasets.ClassLabel(names=token_label_names)),
         "image": datasets.Array3D(shape=(3, 224, 224), dtype="uint8"),
         "entities": datasets.Sequence(entity_schema),
         "relations": datasets.Sequence(relation_schema),
     })
     return datasets.DatasetInfo(features=schema, supervised_keys=None)
예제 #6
0
 def _info(self):
     """Return the dataset metadata for the CIFAR-100 style schema.

     Each example has a 32x32x3 uint8 ``img`` array, a ``fine_label``
     drawn from 100 class names and a ``coarse_label`` drawn from 20
     superclass names. No supervised key pair is declared.
     """
     return datasets.DatasetInfo(
         description=_DESCRIPTION,
         features=datasets.Features(
             {
                 "img": datasets.Array3D(shape=(32, 32, 3), dtype="uint8"),
                 "fine_label": datasets.features.ClassLabel(
                     names=[
                         "apple",
                         "aquarium_fish",
                         "baby",
                         "bear",
                         "beaver",
                         "bed",
                         "bee",
                         "beetle",
                         "bicycle",
                         "bottle",
                         "bowl",
                         "boy",
                         "bridge",
                         "bus",
                         "butterfly",
                         "camel",
                         "can",
                         "castle",
                         "caterpillar",
                         "cattle",
                         "chair",
                         "chimpanzee",
                         "clock",
                         "cloud",
                         "cockroach",
                         "couch",
                         # Was the truncated "cra"; the official class
                         # name is "crab".
                         "crab",
                         "crocodile",
                         "cup",
                         "dinosaur",
                         "dolphin",
                         "elephant",
                         "flatfish",
                         "forest",
                         "fox",
                         "girl",
                         "hamster",
                         "house",
                         "kangaroo",
                         "keyboard",
                         "lamp",
                         "lawn_mower",
                         "leopard",
                         "lion",
                         "lizard",
                         "lobster",
                         "man",
                         "maple_tree",
                         "motorcycle",
                         "mountain",
                         "mouse",
                         "mushroom",
                         "oak_tree",
                         "orange",
                         "orchid",
                         "otter",
                         "palm_tree",
                         "pear",
                         "pickup_truck",
                         "pine_tree",
                         "plain",
                         "plate",
                         "poppy",
                         "porcupine",
                         "possum",
                         "rabbit",
                         "raccoon",
                         "ray",
                         "road",
                         "rocket",
                         "rose",
                         "sea",
                         "seal",
                         "shark",
                         "shrew",
                         "skunk",
                         "skyscraper",
                         "snail",
                         "snake",
                         "spider",
                         "squirrel",
                         "streetcar",
                         "sunflower",
                         "sweet_pepper",
                         "table",
                         "tank",
                         "telephone",
                         "television",
                         "tiger",
                         "tractor",
                         "train",
                         "trout",
                         "tulip",
                         "turtle",
                         "wardrobe",
                         "whale",
                         "willow_tree",
                         "wolf",
                         "woman",
                         "worm",
                     ]
                 ),
                 "coarse_label": datasets.features.ClassLabel(
                     names=[
                         "aquatic_mammals",
                         "fish",
                         "flowers",
                         "food_containers",
                         "fruit_and_vegetables",
                         "household_electrical_devices",
                         "household_furniture",
                         "insects",
                         "large_carnivores",
                         "large_man-made_outdoor_things",
                         "large_natural_outdoor_scenes",
                         "large_omnivores_and_herbivores",
                         "medium_mammals",
                         "non-insect_invertebrates",
                         "people",
                         "reptiles",
                         "small_mammals",
                         "trees",
                         "vehicles_1",
                         "vehicles_2",
                     ]
                 ),
             }
         ),
         supervised_keys=None,  # Probably needs to be fixed.
         homepage="https://www.cs.toronto.edu/~kriz/cifar.html",
         citation=_CITATION,
     )
예제 #7
0
    np.testing.assert_equal(arr[0], dummy_array)
    np.testing.assert_equal(arr[2], dummy_array)
    assert np.isnan(arr[1])  # a single np.nan value - np.all not needed


@pytest.mark.parametrize(
    "data, feature, expected",
    [
        # No explicit feature: the schema argument is left as None.
        (
            np.zeros((2, 2)),
            None,
            [[0.0, 0.0], [0.0, 0.0]],
        ),
        # 2D array with a matching Array2D feature.
        (
            np.zeros((2, 3)),
            datasets.Array2D(shape=(2, 3), dtype="float32"),
            [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
        ),
        # List holding one 1D array, declared as a (1, 2) Array2D.
        (
            [np.zeros(2)],
            datasets.Array2D(shape=(1, 2), dtype="float32"),
            [[0.0, 0.0]],
        ),
        # List holding one 2D array, declared as a (1, 2, 3) Array3D.
        (
            [np.zeros((2, 3))],
            datasets.Array3D(shape=(1, 2, 3), dtype="float32"),
            [[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]],
        ),
    ],
)
def test_array_xd_with_np(data, feature, expected):
    """Check that NumPy input stored in a Dataset reads back as nested lists.

    A single-row dataset with column ``col`` is built from ``data``, using
    an explicit feature schema when ``feature`` is given, and its first row
    must equal ``expected``.
    """
    schema = None
    if feature:
        schema = datasets.Features({"col": feature})
    ds = datasets.Dataset.from_dict({"col": [data]}, features=schema)
    first_row = ds[0]
    assert first_row["col"] == expected


@pytest.mark.parametrize("with_none", [False, True])
def test_dataset_map(with_none):
    ds = datasets.Dataset.from_dict({"path": ["path1", "path2"]})