コード例 #1
0
ファイル: core.py プロジェクト: yaoxinbin/detecto
    def _get_raw_predictions(self, images):
        """Run the wrapped model on one or more images and return its raw output.

        Puts ``self._model`` into evaluation mode (``self._model.eval()``) and
        runs inference under ``torch.no_grad()``.

        :param images: A single image, or a collection of images recognized by
            ``_is_iterable``. Elements that are already ``torch.Tensor`` are
            used as-is. Every other element is converted with
            ``transforms.ToTensor()`` alone when ``self._disable_normalize`` is
            truthy, otherwise with ``default_transforms()``.
        :return: A list built from ``self._model``'s output, where each entry is
            a dict with every value moved to the CPU. Returns an empty list if
            ``images`` is an empty collection.
        """
        self._model.eval()

        with torch.no_grad():
            # Wrap a single image so the rest of the method handles one shape
            if not _is_iterable(images):
                images = [images]
            images = list(images)

            # Nothing to run on; indexing or feeding an empty batch would fail
            if not images:
                return []

            # Only build the conversion pipeline if some element needs it
            if not all(isinstance(img, torch.Tensor) for img in images):
                # This is a temporary workaround to the bad accuracy
                # when normalizing on default weights. Will need to
                # investigate further TODO
                if self._disable_normalize:
                    defaults = transforms.Compose([transforms.ToTensor()])
                else:
                    defaults = default_transforms()
                # Convert per element so a mix of tensors and non-tensors works.
                # Tensors are assumed to be preprocessed already, as before.
                images = [img if isinstance(img, torch.Tensor) else defaults(img)
                          for img in images]

            # Send images to the specified device
            images = [img.to(self._device) for img in images]

            preds = self._model(images)
            # Move every prediction value back to the CPU
            cpu = torch.device('cpu')
            return [{k: v.to(cpu) for k, v in p.items()} for p in preds]
コード例 #2
0
    def _get_raw_predictions(self, images):
        """Run the wrapped model on one or more images and return its raw output.

        Puts ``self._model`` into evaluation mode (``self._model.eval()``) and
        runs inference under ``torch.no_grad()``.

        :param images: A single image, or a collection of images recognized by
            ``_is_iterable``. Elements that are already ``torch.Tensor`` are
            used as-is. Every other element is converted with
            ``default_transforms()``.
        :return: A list built from ``self._model``'s output, where each entry is
            a dict with every value moved to the CPU. Returns an empty list if
            ``images`` is an empty collection.
        """
        self._model.eval()

        with torch.no_grad():
            # Wrap a single image so the rest of the method handles one shape
            if not _is_iterable(images):
                images = [images]
            images = list(images)

            # Nothing to run on; indexing or feeding an empty batch would fail
            if not images:
                return []

            # Convert per element so a mix of tensors and non-tensors works.
            # Tensors are assumed to be preprocessed already, as before.
            if not all(isinstance(img, torch.Tensor) for img in images):
                defaults = default_transforms()
                images = [img if isinstance(img, torch.Tensor) else defaults(img)
                          for img in images]

            # Send images to the specified device
            images = [img.to(self._device) for img in images]

            preds = self._model(images)
            # Move every prediction value back to the CPU
            cpu = torch.device('cpu')
            return [{k: v.to(cpu) for k, v in p.items()} for p in preds]
コード例 #3
0
ファイル: core.py プロジェクト: yaoxinbin/detecto
    def __init__(self, label_data, image_folder=None, transform=None):
        """Takes in the path to the label data and images and creates
        an indexable dataset over all of the data. Applies optional
        transforms over the data. Extends PyTorch's `Dataset
        <https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset>`_.

        :param label_data: Can either contain the path to a folder storing
            the XML label files or a CSV file containing the label data.
            If a CSV file, the file should have the following columns in
            order: ``filename``, ``width``, ``height``, ``class``, ``xmin``,
            ``ymin``, ``xmax``, and ``ymax``. See
            :func:`detecto.utils.xml_to_csv` to generate CSV files in this
            format from XML label files.
        :type label_data: str
        :param image_folder: (Optional) The path to the folder containing the
            images. If not specified, the images are assumed to be in the
            folder given by `label_data` (XML labels) or in the directory
            that contains the CSV file (CSV labels). Defaults to None.
        :type image_folder: str
        :param transform: (Optional) A torchvision `transforms.Compose
            <https://pytorch.org/docs/stable/torchvision/transforms.html#torchvision.transforms.Compose>`__
            object containing transformations to apply on all elements in
            the dataset. See `PyTorch docs
            <https://pytorch.org/docs/stable/torchvision/transforms.html>`_
            for a list of possible transforms. When using transforms.Resize
            and transforms.RandomHorizontalFlip, all box coordinates are
            automatically adjusted to match the modified image. If None,
            defaults to the transforms returned by
            :func:`detecto.utils.default_transforms`.
        :type transform: torchvision.transforms.Compose or None

        **Indexing**:

        A Dataset object can be indexed like any other Python sequence.
        Doing so returns a tuple of length 2. The first element is the
        image and the second element is a dict containing a 'boxes' and
        'labels' key. ``dict['boxes']`` is a torch.Tensor of size
        ``(1, 4)`` containing ``xmin``, ``ymin``, ``xmax``, and ``ymax``
        of the box and ``dict['labels']`` is the string label of the
        detected object.

        **Example**::

            >>> from detecto.core import Dataset

            >>> # Create dataset from separate XML and image folders
            >>> dataset = Dataset('xml_labels/', 'images/')
            >>> # Create dataset from a combined XML and image folder
            >>> dataset1 = Dataset('images_and_labels/')
            >>> # Create dataset from a CSV file and image folder
            >>> dataset2 = Dataset('labels.csv', 'images/')

            >>> print(len(dataset))
            >>> image, target = dataset[0]
            >>> print(image.shape)
            >>> print(target)
            4
            torch.Size([3, 720, 1280])
            {'boxes': tensor([[564, 43, 736, 349]]), 'labels': 'balloon'}
        """

        # CSV file contains: filename, width, height, class, xmin, ymin, xmax, ymax
        is_csv = os.path.isfile(label_data)
        if is_csv:
            self._csv = pd.read_csv(label_data)
        else:
            self._csv = xml_to_csv(label_data)

        # If no image folder is given, look next to the labels. For an XML
        # folder, that is the folder itself. For a CSV file, it is the file's
        # directory: the file path is not a folder, and image_folder is
        # documented as a folder path.
        if image_folder is not None:
            self._root_dir = image_folder
        elif is_csv:
            self._root_dir = os.path.dirname(label_data)
        else:
            self._root_dir = label_data

        if transform is None:
            self.transform = default_transforms()
        else:
            self.transform = transform