示例#1
0
class TestDataset:
    """Dataset wrapper that yields preprocessed images for evaluation.

    The backing database is chosen by ``opt.data``: ``'markers'`` builds a
    ``MarkersDataset`` and ``'voc'`` builds a ``VOCBboxDataset``; both are
    constructed from ``opt.data_dir``.

    Args:
        opt: Options object; this class reads ``opt.data`` and
            ``opt.data_dir``.
        split: Accepted for interface compatibility; currently unused.
        use_difficult: Accepted for interface compatibility; currently unused.

    Raises:
        ValueError: If ``opt.data`` is neither ``'markers'`` nor ``'voc'``.
    """

    def __init__(self, opt, split='test', use_difficult=True):
        self.opt = opt
        # NOTE(review): ``split`` and ``use_difficult`` are never forwarded to
        # the backing dataset, so whatever defaults the backend constructor
        # has are what apply. Confirm whether VOCBboxDataset should receive
        # them (e.g. to evaluate on the intended split).
        if opt.data == 'markers':
            self.db = MarkersDataset(opt.data_dir)
        elif opt.data == 'voc':
            self.db = VOCBboxDataset(opt.data_dir)
        else:
            # ValueError subclasses Exception, so existing ``except Exception``
            # handlers keep working while callers can now catch it precisely.
            raise ValueError('database type not recognised: {}'.format(
                opt.data))
        self.label_names = self.db.get_label_names()

    def __getitem__(self, idx):
        """Return ``(img, original_size, bbox, label, difficult)`` for ``idx``.

        ``img`` is ``preprocess(ori_img)``. ``original_size`` is
        ``ori_img.shape[1:]``, which is presumably ``(H, W)`` for a
        channel-first ``(C, H, W)`` image (TODO confirm against the backend
        datasets). ``bbox``, ``label`` and ``difficult`` are passed through
        unchanged from the backend's ``get_example``.
        """
        ori_img, bbox, label, difficult = self.db.get_example(idx)
        img = preprocess(ori_img)
        return img, ori_img.shape[1:], bbox, label, difficult

    def __len__(self):
        """Return the number of examples in the backing database."""
        return len(self.db)
示例#2
0
class Dataset:
    """Dataset wrapper that applies ``Transform`` to each example.

    The backing database is chosen by ``opt.data``: ``'markers'`` builds a
    ``MarkersDataset`` and ``'voc'`` builds a ``VOCBboxDataset``; both are
    constructed from ``opt.data_dir``. Each example is then passed through
    ``Transform(opt.min_size, opt.max_size)``.

    Args:
        opt: Options object; this class reads ``opt.data``, ``opt.data_dir``,
            ``opt.min_size`` and ``opt.max_size``.

    Raises:
        ValueError: If ``opt.data`` is neither ``'markers'`` nor ``'voc'``.
    """

    def __init__(self, opt):
        self.opt = opt
        if opt.data == 'markers':
            self.db = MarkersDataset(opt.data_dir)
        elif opt.data == 'voc':
            self.db = VOCBboxDataset(opt.data_dir)
        else:
            # ValueError subclasses Exception, so existing ``except Exception``
            # handlers keep working while callers can now catch it precisely.
            raise ValueError('database type not recognised: {}'.format(
                opt.data))

        self.tsf = Transform(opt.min_size, opt.max_size)
        self.label_names = self.db.get_label_names()

    def __getitem__(self, idx):
        """Return ``(img, bbox, label, scale)`` for example ``idx``.

        ``img``, ``bbox`` and ``label`` come from ``self.tsf`` and are
        returned as copies. ``scale`` is whatever ``self.tsf`` reports. The
        backend's ``difficult`` flags are discarded.
        """
        ori_img, bbox, label, _difficult = self.db.get_example(idx)

        img, bbox, label, scale = self.tsf((ori_img, bbox, label))
        # Some arrays coming out of the transform can have negative strides.
        # For NumPy arrays, ``.copy()`` yields fresh C-ordered arrays, which
        # downstream consumers (presumably ``torch.from_numpy``) accept.
        # TODO: find which step produces negative strides and fix it there
        # instead of copying all three arrays.
        return img.copy(), bbox.copy(), label.copy(), scale

    def __len__(self):
        """Return the number of examples in the backing database."""
        return len(self.db)