Exemple #1
0
    def apart(test_size, dataset, classify):
        """
        Randomly split the files found under ``<dataset root>/masks`` into a
        train set and a test set, class by class.

        test_size: fraction (strictly between 0 and 1) of each class's files
            that goes to the test set; the per-class test count is
            ``int(number_of_files * test_size)``.
        dataset: dataset name handed to ``Path.db_root_dir``. Only 'all' and
            'mydataset' are supported; any other value raises
            NotImplementedError as soon as a first-level sub-folder of
            ``masks`` is visited.
        classify: when truthy, files are grouped into classes. For 'all' the
            class is 'false' if the sub-folder name contains 'fix', otherwise
            'true'; for 'mydataset' the class is the sub-folder name. When
            falsy, every file lands in the single class 'image'.

        Returns ``(train_data, test_data)``: two dicts mapping class name to a
        list of ``(file name without extension, sub-folder name)`` tuples.
        """
        assert 0 < test_size < 1, "比例啦"

        print("本次运行{}分类".format('使用' if classify else '只分割不'))
        # Collect the file names of every class
        file_names = {}
        classes = []
        masks_dir = os.path.join(Path.db_root_dir(dataset), 'masks')
        for root, _dirs, files in os.walk(masks_dir):
            parent, folder = os.path.split(root)
            # Only first-level sub-folders of /masks are of interest
            if parent != masks_dir:
                continue
            if dataset not in ('all', 'mydataset'):
                raise NotImplementedError

            if not classify:
                # Everything belongs to one class
                _class = 'image'
            elif dataset == 'all':
                # Distinguish "true bad" from "false bad" by the folder name
                _class = 'false' if 'fix' in folder else 'true'
            else:
                # The sub-folder name is the class name
                _class = folder

            if _class not in classes:
                classes.append(_class)
            # Register the class even when its folder holds no files, so the
            # split below never hits a missing key for an empty class.
            file_names.setdefault(_class, []).extend(
                (os.path.splitext(file)[0], folder) for file in files)

        def rand_select(data, size):
            # Draw `size` distinct positions for the test set; the remaining
            # items keep their original order and form the train set.
            picked = random.sample(range(len(data)), size)
            picked_lookup = set(picked)  # O(1) membership tests
            test = [data[i] for i in picked]
            train = [item for i, item in enumerate(data)
                     if i not in picked_lookup]
            return train, test

        train_data, test_data = {}, {}
        for _class in classes:
            names = file_names[_class]
            train_data[_class], test_data[_class] = rand_select(
                names, int(len(names) * test_size))
            print("{}:训练{}张,测试{}张".format(_class, len(train_data[_class]),
                                          len(test_data[_class])))

        return train_data, test_data
Exemple #2
0
 def __init__(self):
     """Set up the 'penn' dataset: root path, transform and file listings."""
     self.root = Path.db_root_dir('penn')
     pipeline = [tr.FixedResize(550), tr.ToTensor()]
     self.transform = transforms.Compose(pipeline)
     self.PNGImagesRoot = "PNGImages"
     self.PedMasksRoot = "PedMasks"
     # Load the image and mask file names; both listings are sorted so that
     # entries at the same index correspond to each other.
     image_dir = os.path.join(self.root, self.PNGImagesRoot)
     mask_dir = os.path.join(self.root, self.PedMasksRoot)
     self.imgs = sorted(os.listdir(image_dir))
     self.masks = sorted(os.listdir(mask_dir))
Exemple #3
0
 def __init__(self, file_names, dataset):
     """
     Flatten a per-class file listing into one sorted list of samples.

     file_names: mapping of class name to a list of entries, e.g.
         ``[(filename0, dir0), (filename1, dir1), ...]``.
     dataset: dataset name; stored and also passed to ``Path.db_root_dir``.
     """
     self.mean = (0.5071, 0.4867, 0.4408)
     self.stdv = (0.2675, 0.2565, 0.2761)
     self.dir = Path.db_root_dir(dataset)
     self.dataset = dataset
     self.classes = list(file_names.keys())
     # Each sample is [entry, class name]; sorting orders them by entry first.
     self.files = sorted(
         [entry, label]
         for label in self.classes
         for entry in file_names[label])
Exemple #4
0
    def __init__(
            self,
            args,
            base_dir=Path.db_root_dir('pascal'),
            split='train',
    ):
        """
        Index the images and segmentation masks of the requested split(s).

        :param args: stored unchanged on ``self.args``
        :param base_dir: path to VOC dataset directory
        :param split: a split name (train/val) or an iterable of split names;
            an iterable is processed in sorted order and is not modified
        :raises AssertionError: if a listed image or mask file does not exist
        """
        super().__init__()
        self._base_dir = base_dir
        self._image_dir = os.path.join(self._base_dir, 'JPEGImages')
        self._cat_dir = os.path.join(self._base_dir, 'SegmentationClass')

        if isinstance(split, str):
            self.split = [split]
            shown_split = split
        else:
            # sorted() builds a new list, leaving the caller's argument as is
            self.split = sorted(split)
            shown_split = self.split

        self.args = args

        _splits_dir = os.path.join(self._base_dir, 'ImageSets', 'Segmentation')

        self.im_ids = []
        self.images = []
        self.categories = []

        for splt in self.split:
            # Each split file lists one image id per line
            with open(os.path.join(_splits_dir, splt + '.txt'), "r") as f:
                lines = f.read().splitlines()

            for line in lines:
                _image = os.path.join(self._image_dir, line + ".jpg")
                _cat = os.path.join(self._cat_dir, line + ".png")
                assert os.path.isfile(_image)
                assert os.path.isfile(_cat)
                self.im_ids.append(line)
                self.images.append(_image)
                self.categories.append(_cat)

        assert (len(self.images) == len(self.categories))

        # Display stats
        print('Number of images in {}: {:d}'.format(shown_split,
                                                    len(self.images)))
Exemple #5
0
    def __init__(
            self,
            args,
            base_dir=Path.db_root_dir('sbd'),
            split='train',
    ):
        """
        Index the images and category files of the requested split(s).

        :param args: stored unchanged on ``self.args``
        :param base_dir: dataset root; images are read from ``dataset/img``,
            category ``.mat`` files from ``dataset/cls`` and split lists from
            ``dataset/<split>.txt`` below it
        :param split: a split name (train/val) or an iterable of split names;
            an iterable is processed in sorted order and is not modified
        :raises AssertionError: if a listed image or category file is missing
        """
        super().__init__()
        self._base_dir = base_dir
        self._dataset_dir = os.path.join(self._base_dir, 'dataset')
        self._image_dir = os.path.join(self._dataset_dir, 'img')
        self._cat_dir = os.path.join(self._dataset_dir, 'cls')

        if isinstance(split, str):
            self.split = [split]
        else:
            # sorted() builds a new list, leaving the caller's argument as is
            self.split = sorted(split)

        self.args = args

        # Get list of all images from the split and check that the files exist
        self.im_ids = []
        self.images = []
        self.categories = []
        for splt in self.split:
            # Each split file lists one image id per line
            with open(os.path.join(self._dataset_dir, splt + '.txt'),
                      "r") as f:
                lines = f.read().splitlines()

            for line in lines:
                _image = os.path.join(self._image_dir, line + ".jpg")
                _categ = os.path.join(self._cat_dir, line + ".mat")
                assert os.path.isfile(_image)
                assert os.path.isfile(_categ)
                self.im_ids.append(line)
                self.images.append(_image)
                self.categories.append(_categ)

        assert (len(self.images) == len(self.categories))

        # Display stats
        print('Number of images: {:d}'.format(len(self.images)))
Exemple #6
0
 def __init__(self,
              args,
              base_dir=Path.db_root_dir('coco'),
              split='train',
              year='2017'):
     """
     Open the COCO annotations for ``split``/``year`` and obtain image ids.

     The ids are loaded with ``torch.load`` from
     ``annotations/<split>_ids_<year>.pth`` when that file exists; otherwise
     all image ids of the annotation file are handed to ``self._preprocess``
     together with that path (presumably to filter and cache them - confirm
     against ``_preprocess``).
     """
     super().__init__()
     annotation_rel = 'annotations/instances_{}{}.json'.format(split, year)
     ids_rel = 'annotations/{}_ids_{}.pth'.format(split, year)
     images_rel = 'images/{}{}'.format(split, year)

     ann_file = os.path.join(base_dir, annotation_rel)
     ids_file = os.path.join(base_dir, ids_rel)
     self.img_dir = os.path.join(base_dir, images_rel)
     self.split = split
     self.coco = COCO(ann_file)
     self.coco_mask = mask
     if not os.path.exists(ids_file):
         # No stored id list yet: start from every image id in the annotations
         all_ids = list(self.coco.imgs.keys())
         self.ids = self._preprocess(all_ids, ids_file)
     else:
         self.ids = torch.load(ids_file)
     self.args = args
Exemple #7
0
    def __init__(self,
                 args,
                 root=Path.db_root_dir('cityscapes'),
                 split="train"):
        """
        Locate the images of one split and set up the label-id bookkeeping.

        Image paths come from ``self.recursive_glob`` called on
        ``<root>/leftImg8bit/<split>`` with suffix '.png'. Raises ``Exception``
        when that call returns nothing for the split.
        """
        self.root = root
        self.split = split
        self.args = args
        self.files = {}

        self.images_base = os.path.join(self.root, 'leftImg8bit', self.split)
        self.annotations_base = os.path.join(
            self.root, 'gtFine_trainvaltest', 'gtFine', self.split)

        found = self.recursive_glob(rootdir=self.images_base, suffix='.png')
        self.files[split] = found

        # Raw label ids that are ignored vs. the ones that are kept
        self.void_classes = [
            0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1,
        ]
        self.valid_classes = [
            7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31,
            32, 33,
        ]
        self.class_names = [
            'unlabelled', 'road', 'sidewalk', 'building', 'wall', 'fence',
            'pole', 'traffic_light', 'traffic_sign', 'vegetation', 'terrain',
            'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train',
            'motorcycle', 'bicycle',
        ]

        self.ignore_index = 255
        # Map each kept raw id to a consecutive index 0..NUM_CLASSES-1
        self.class_map = {
            raw_id: index
            for raw_id, index in zip(self.valid_classes,
                                     range(self.NUM_CLASSES))
        }

        if not self.files[split]:
            raise Exception("No files for split=[%s] found in %s" %
                            (split, self.images_base))

        print("Found %d %s images" % (len(self.files[split]), split))
Exemple #8
0
def calculate_weigths_labels(dataset, dataloader, num_classes):
    """Compute per-class weights from label frequencies and save them.

    Every batch of ``dataloader`` must be a mapping whose 'label' entry
    supports ``.detach().cpu().numpy()``. Label values outside
    ``[0, num_classes)`` are ignored. The weight of a class with relative
    frequency ``p`` is ``1 / log(1.02 + p)``.

    Args:
        dataset: dataset name; used to locate the output directory through
            ``Path.db_root_dir`` and as the file-name prefix.
        dataloader: iterable of batches to count labels over.
        num_classes: number of classes to compute weights for.

    Returns:
        Array of ``num_classes`` weights. The same array is written to
        ``<dataset root>/<dataset>_classes_weights.npy``.
    """
    # Running count of how often each class id occurs
    z = np.zeros((num_classes,))
    # Initialize tqdm
    tqdm_batch = tqdm(dataloader)
    print('Calculating classes weights')
    try:
        for sample in tqdm_batch:
            y = sample['label'].detach().cpu().numpy()
            mask = (y >= 0) & (y < num_classes)
            # Cast to np.intp rather than uint8: an 8-bit cast would wrap
            # label ids >= 256 onto other classes when num_classes > 256.
            labels = y[mask].astype(np.intp)
            z += np.bincount(labels, minlength=num_classes)
    finally:
        # Release the progress bar even if a batch fails
        tqdm_batch.close()
    total_frequency = np.sum(z)
    class_weights = [
        1 / (np.log(1.02 + (frequency / total_frequency)))
        for frequency in z
    ]
    ret = np.array(class_weights)
    classes_weights_path = os.path.join(Path.db_root_dir(dataset), dataset + '_classes_weights.npy')
    np.save(classes_weights_path, ret)

    return ret
Exemple #9
0
    def __init__(self, args, train=False):
        """
        Build the saver, summary writer, data loaders, model, optimizers,
        losses, evaluators and LR schedulers, and optionally resume from a
        checkpoint.

        args: configuration object. Attributes read here: cuda, sync_train,
            workers, backbone, out_stride, sync_bn, freeze_bn, lr, momentum,
            weight_decay, nesterov, use_balanced_weights, dataset, loss_type,
            lr_scheduler, epochs, gpu_ids, resume, ft. ``args.start_epoch``
            is written when resuming or fine-tuning.
        train: when True the data loaders are created and the class count
            comes from ``make_data_loader``; when False no loader is created
            and the class count is fixed to 3.

        Raises RuntimeError if ``args.resume`` is set but is not a file.
        """
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        self.save_dir = os.path.join('../save', 'checkpoint.pth.tar')
        self.cuda = args.cuda
        self.sync_train = args.sync_train
        # Message: "train both branches simultaneously" if sync_train,
        # otherwise "train the two branches alternately"
        print("同时训练两个分支" if self.sync_train else "交替训练两个分支")

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        if train:
            self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)
        else:
            # No loaders are created on this path; self.train_loader stays unset
            self.nclass = 3

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)
        # Three parameter-group lists are collected under different
        # requires_grad settings; the statement order below matters.
        # NOTE(review): presumably ids 1 and 2 name the two branches and the
        # get_*_lr_params results depend on requires_grad - confirm in DeepLab.
        # All gradients enabled
        train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                        {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]
        # Disable gradients for 2
        model.set_requires_grad([2], False)
        train_params1 = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                         {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]
        # Re-enable gradients for 2, disable gradients for 1
        model.set_requires_grad([1, 2], True)
        model.set_requires_grad([1], False)
        train_params2 = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                         {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]
        # Leave the model with gradients enabled for both 1 and 2
        model.set_requires_grad([1, 2], True)

        # Define Optimizer
        # One SGD optimizer per parameter-group list built above
        self.optimizer = torch.optim.SGD(train_params, momentum=args.momentum,
                                         weight_decay=args.weight_decay, nesterov=args.nesterov)
        self.optimizer1 = torch.optim.SGD(train_params1, momentum=args.momentum,
                                          weight_decay=args.weight_decay, nesterov=args.nesterov)
        self.optimizer2 = torch.optim.SGD(train_params2, momentum=args.momentum,
                                          weight_decay=args.weight_decay, nesterov=args.nesterov)
        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(Path.db_root_dir(args.dataset), args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                # NOTE(review): self.train_loader only exists when train=True,
                # so this fallback raises AttributeError when train is False.
                weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
            # NOTE(review): weights are cast to float16 here - confirm the
            # loss built below accepts half-precision class weights.
            weight = torch.from_numpy(weight.astype(np.float16))
        else:
            weight = None
        # Two separately built losses sharing the same weights and mode
        self.criterion1 = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.criterion2 = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model = model

        # Define Evaluator
        self.evaluator1 = Evaluator(self.nclass)
        self.evaluator2 = Evaluator(self.nclass)
        # Define lr scheduler
        # The last argument is len(train_loader) when training, else the
        # placeholder value 4 (no loader exists on that path)
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                      args.epochs, len(self.train_loader) if train else 4)
        self.scheduler1 = LR_Scheduler(args.lr_scheduler, args.lr,
                                       args.epochs, len(self.train_loader) if train else 4)
        self.scheduler2 = LR_Scheduler(args.lr_scheduler, args.lr,
                                       args.epochs, len(self.train_loader) if train else 4)

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                # The model was wrapped in DataParallel above, so the weights
                # are loaded into the wrapped module
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                # Only self.optimizer is restored; optimizer1 and optimizer2
                # keep their fresh state
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0