コード例 #1
0
    def __init__(self, site, **kwargs):
        """Pick the class count for ``site``, build its palette, then defer to the base class.

        Args:
            site: Site identifier. ``'ISBI'`` and ``'ISBI_15'`` use three
                classes; any other value uses two. Also forwarded to the
                parent constructor as its first positional argument.
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        three_class_sites = ('ISBI', 'ISBI_15')
        self.num_classes = 3 if site in three_class_sites else 2

        # Palette is sized to the class count chosen above.
        self.palette = pallete.get_voc_pallete(self.num_classes)
        super(MRI_dataset, self).__init__(site, **kwargs)
コード例 #2
0
def main():
    """Run multi-scale CCT inference over a set of images and save colourised masks.

    Options come from ``parse_arguments()``:
      * ``config`` - path to a JSON config providing ``num_classes`` and a
        ``model`` section;
      * ``images`` - passed to ``testDataset``;
      * ``model``  - checkpoint file whose ``'state_dict'`` entry is loaded;
      * ``save``   - output directory (created if missing).

    Each prediction is written as ``<save>/<image_id>.png``.

    Raises:
        ValueError: if no config path was supplied.
    """
    args = parse_arguments()

    # CONFIG
    # Explicit check instead of ``assert``, which is removed under ``python -O``.
    if not args.config:
        raise ValueError('A config file must be provided (args.config is empty).')
    # Context manager closes the config file instead of leaking the handle.
    with open(args.config) as config_file:
        config = json.load(config_file)
    scales = [0.5, 0.75, 1.0, 1.25, 1.5]

    # DATA
    testdataset = testDataset(args.images)
    loader = DataLoader(testdataset, batch_size=1, shuffle=False, num_workers=1)
    num_classes = config['num_classes']
    palette = get_voc_pallete(num_classes)

    # MODEL
    # Inference uses the fully supervised branch only.
    config['model']['supervised'] = True
    config['model']['semi'] = False
    model = models.CCT(num_classes=num_classes,
                       conf=config['model'],
                       testing=True)
    checkpoint = torch.load(args.model)
    # Wrapped before loading so parameter names match a checkpoint saved from a
    # DataParallel model (presumably how training saved it - TODO confirm).
    model = torch.nn.DataParallel(model)
    try:
        model.load_state_dict(checkpoint['state_dict'], strict=True)
    except RuntimeError as e:
        # load_state_dict raises RuntimeError on missing/unexpected keys;
        # fall back to a best-effort partial load.
        print(f'Some modules are missing: {e}')
        model.load_state_dict(checkpoint['state_dict'], strict=False)
    model.eval()
    model.cuda()

    os.makedirs(args.save, exist_ok=True)

    # LOOP OVER THE DATA
    for image, image_id in tqdm(loader, ncols=100):
        image = image.cuda()

        # PREDICT
        with torch.no_grad():
            output = multi_scale_predict(model, image, scales, num_classes)
        # Per-pixel label = argmax over axis 0 (assumed to be the class axis - TODO confirm).
        prediction = np.asarray(np.argmax(output, axis=0), dtype=np.uint8)

        # SAVE RESULTS
        # batch_size=1, so image_id holds a single id.
        prediction_im = colorize_mask(prediction, palette)
        prediction_im.save(os.path.join(args.save, image_id[0] + '.png'))
コード例 #3
0
    def __init__(self, **kwargs):
        """Set up the paired 19-class dataset and its training augmentation pipeline.

        Args:
            **kwargs: Must include ``datalist``, ``stride`` and ``iou_bound``.
                These are removed and stored on the instance. Everything else
                is forwarded to the parent constructor.

        Raises:
            KeyError: if any of the three required keys is missing.
        """
        self.num_classes = 19

        # Subclass-only options: strip them so the base constructor never sees them.
        for option in ('datalist', 'stride', 'iou_bound'):
            setattr(self, option, kwargs.pop(option))

        self.palette = pallete.get_voc_pallete(self.num_classes)
        super(PairCityDataset, self).__init__(**kwargs)

        # Built after the parent constructor because it reads self.normalize
        # (presumably set by the base class - confirm).
        augmentations = [
            transforms.ToPILImage(),
            RandomGaussianBlur(),
            transforms.RandomApply(
                [transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8
            ),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
            self.normalize,
        ]
        self.train_transform = transforms.Compose(augmentations)
コード例 #4
0
    def __init__(self, **kwargs):
        """Configure one view of a multi-view dataset, then defer to the base class.

        Args:
            **kwargs: Must contain integer ``view_idx`` and ``number_views``.
                They are read, not removed, so they are forwarded to the
                parent constructor along with everything else.

        Raises:
            KeyError: if ``view_idx`` or ``number_views`` is missing.
            ValueError: if ``view_idx`` or ``number_views`` is not an int.
        """
        view_idx = kwargs['view_idx']
        number_views = kwargs['number_views']
        # Real exceptions rather than ``assert``, which is stripped under ``python -O``.
        if not isinstance(view_idx, int):
            raise ValueError('view_idx: {}'.format(view_idx))
        if not isinstance(number_views, int):
            raise ValueError('number_views: {}'.format(number_views))

        # NOTE(review): the 2 + 10 split is not explained in this code; confirm
        # against the label definitions.
        self.num_classes = 2 + 10
        self.palette = pallete.get_voc_pallete(self.num_classes)
        self.number_views = number_views
        self.view_idx = view_idx
        # Keys presumably used to look up this view's frames / segmentation in the data.
        self.view_key_img = "frames views " + str(self.view_idx)
        self.view_key_seg = "seg " + str(self.view_idx)
        super(MuiltivwDataset, self).__init__(**kwargs)
        # Every argument needs a placeholder; number_views was previously dropped.
        print('data dir {}, view idx {}, num views {}'.format(
            self.root, view_idx, number_views))
コード例 #5
0
    def __init__(self, **kwargs):
        """Set up the 19-class city dataset, then defer to the base class.

        Args:
            **kwargs: Must include ``datalist``, which is removed and stored on
                the instance. The rest is forwarded to the parent constructor.

        Raises:
            KeyError: if ``datalist`` is missing.
        """
        class_count = 19
        self.num_classes = class_count

        # Consumed here so the base constructor never receives it.
        self.datalist = kwargs.pop("datalist")
        self.palette = pallete.get_voc_pallete(class_count)
        super(CityDataset, self).__init__(**kwargs)
コード例 #6
0
ファイル: voc.py プロジェクト: Pele324/TissueSeg
    def __init__(self, **kwargs):
        """Set up the 21-class VOC dataset palette, then defer to the base class.

        Args:
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        class_count = 21
        self.num_classes = class_count
        self.palette = pallete.get_voc_pallete(class_count)
        super(VOCDataset, self).__init__(**kwargs)
コード例 #7
0
ファイル: custom.py プロジェクト: saramsv/TCT
    def __init__(self, **kwargs):
        """Set up the custom 7-class dataset palette, then defer to the base class.

        Args:
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        # Author's note: "7 body parts and background" (the VOC default was 21).
        # NOTE(review): 7 body parts plus background would be 8 classes; confirm
        # whether the 7 already includes background.
        class_count = 7
        self.num_classes = class_count
        self.palette = pallete.get_voc_pallete(class_count)
        super(CUS_Dataset, self).__init__(**kwargs)