示例#1
0
def testFlip():
    """Smoke-test the sliding-window iteration of ``RssraiTestOneImage``.

    Lists the ``*.tif`` files in ``<rssrai root>/test``, builds a
    ``RssraiTestOneImage`` for the first one, prints the shape of the
    selected image variant and of every patch yielded by
    ``get_slide_dataSet``, then calls ``saveResultRGBImage``.

    Raises:
        FileNotFoundError: if the test directory contains no ``*.tif`` file.
    """
    base_dir = Path.db_root_dir('rssrai')
    image_dir = os.path.join(base_dir, 'test')
    save_dir = save_path(os.path.join(base_dir, 'test_output'))

    # basename() strips the directory part on every platform, whereas
    # splitting on '/' only works with POSIX-style separators.
    img_name_list = [
        os.path.basename(img_path)
        for img_path in glob(os.path.join(image_dir, '*.tif'))
    ]
    pprint(img_name_list)
    pprint(image_dir)
    if not img_name_list:
        raise FileNotFoundError(f"no '*.tif' images found in {image_dir}")

    # NOTE(review): the meaning of the positional arguments 10 and 4 is not
    # visible from here - confirm against RssraiTestOneImage's signature.
    rssraiImage = RssraiTestOneImage(img_name_list[0], image_dir, save_dir, 10,
                                     4)

    # Key into rssraiImage.images; "vertical" and "horizontal" were the other
    # values tried during development.
    flip_type = "origin"
    print(rssraiImage.images[flip_type].shape)
    for data_set in rssraiImage.get_slide_dataSet(
            rssraiImage.images[flip_type]):
        print(data_set)
        for item in data_set:
            print(item['image'].shape)

    rssraiImage.saveResultRGBImage()
示例#2
0
def test_spilt_valid_image():
    """Split every validation label and its image into ``split_valid_256``.

    For each file found in ``<rssrai root>/split_valid/label`` the label is
    passed to ``split_image`` with ``mode="RGB"`` and the matching image
    (same file name with ``"_label"`` removed) with ``mode="CMYK"``.
    Results go to ``split_valid_256/label`` and ``split_valid_256/img``,
    which are created when missing.
    """
    base_path = Path.db_root_dir("rssrai")

    label_path = os.path.join(base_path, "split_valid", "label")
    image_path = os.path.join(base_path, "split_valid", "img")
    save_label_path = os.path.join(base_path, "split_valid_256", "label")
    save_image_path = os.path.join(base_path, "split_valid_256", "img")

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(save_label_path, exist_ok=True)
    os.makedirs(save_image_path, exist_ok=True)

    # basename() is separator-agnostic, unlike splitting on "/".
    label_name_list = [
        os.path.basename(path_name)
        for path_name in glob(os.path.join(label_path, "*"))
    ]

    print(len(label_name_list))

    for label_name in tqdm(label_name_list):
        # The image shares the label's file name minus the "_label" marker.
        image_name = label_name.replace("_label", "")
        split_image(label_path, label_name, save_label_path, mode="RGB")
        split_image(image_path, image_name, save_image_path, mode="CMYK")
示例#3
0
def test_one_merge_image():
    """Merge the tiles of one hard-coded test image back into a single file.

    Calls ``merge_image`` on ``<rssrai root>/split_test_256`` for
    ``GF2_PMS1__20150902_L1A0001015646-MSS1.tif`` and writes the result to
    ``<rssrai root>/merge_test/img`` (created when missing).
    """
    base_path = Path.db_root_dir("rssrai")

    # Images
    image_path = os.path.join(base_path, "split_test_256")
    save_image_path = os.path.join(base_path, "merge_test", "img")
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(save_image_path, exist_ok=True)
    merge_image(image_path, "GF2_PMS1__20150902_L1A0001015646-MSS1.tif",
                save_image_path, "CMYK")
示例#4
0
def testOneImage():
    """Convert one hard-coded training image to RGB and save it.

    Opens the image, drops index 0 of the last array axis, converts the
    remaining data to ``uint8``/RGB and saves it under the rssrai root
    directory with the same file name.
    """
    base_path = Path.db_root_dir("rssrai")
    path = '/home/arron/Documents/grey/Project_Rssrai/rssrai/train/img'
    name = 'GF2_PMS1__20150212_L1A0000647768-MSS1.tif'

    # The context manager closes the underlying file handle once the pixel
    # data has been copied into the array.
    with Image.open(os.path.join(path, name)) as file_image:
        # Keep every band except the first one.
        # NOTE(review): presumably a 4-band image whose last 3 bands map to
        # RGB - confirm the band order.
        np_image = np.array(file_image)[:, :, 1:]

    image = Image.fromarray(np_image.astype('uint8')).convert("RGB")
    image.save(os.path.join(base_path, name))
示例#5
0
def test_spilt_test_image():
    """Split all rssrai test images into 256x256 tiles.

    Collects every ``*.tif`` under ``<rssrai root>/test`` and hands the list
    to ``spilt_all_images`` with ``mode="CMYK"``; the output directory is
    ``<rssrai root>/split_test_256/img``.
    """
    root = Path.db_root_dir("rssrai")

    # Images
    tif_files = glob(os.path.join(root, "test", "*.tif"))
    target_dir = os.path.join(root, "split_test_256", "img")
    tile_h_w = (256, 256)

    spilt_all_images(tif_files,
                     target_dir,
                     mode="CMYK",
                     output_image_h_w=tile_h_w)
示例#6
0
def test_merge_images():
    """Merge the tiles of every image listed in ``test_name_list.csv``.

    Reads the ``name`` column of ``test_name_list.csv`` (relative to the
    current working directory) and calls ``merge_image`` for each entry,
    reading tiles from ``<rssrai root>/temp_test/img`` and writing to
    ``<rssrai root>/merge_test/img`` (created when missing).
    """
    import pandas as pd
    base_path = Path.db_root_dir("rssrai")
    df = pd.read_csv("test_name_list.csv")
    name_list = df['name'].tolist()

    # Images
    image_path = os.path.join(base_path, "temp_test", "img")
    save_image_path = os.path.join(base_path, "merge_test", "img")
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(save_image_path, exist_ok=True)
    for name in tqdm(name_list):
        merge_image(image_path, name, save_image_path, "CMYK")
示例#7
0
def test_one_spilt_test_image():
    """Split one hard-coded test image into 256x256 tiles.

    Calls ``split_image`` for
    ``GF2_PMS1__20150902_L1A0001015646-MSS1.tif`` from
    ``<rssrai root>/test`` with ``mode="CMYK"``; tiles are written to
    ``<rssrai root>/split_test_256`` (created when missing).
    """
    base_path = Path.db_root_dir("rssrai")
    image_path = os.path.join(base_path, "test")
    name = 'GF2_PMS1__20150902_L1A0001015646-MSS1.tif'
    save_image_path = os.path.join(base_path, "split_test_256")
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(save_image_path, exist_ok=True)

    # Image
    split_image(image_path,
                name,
                save_image_path,
                mode="CMYK",
                output_image_h_w=(256, 256))
示例#8
0
def testData():
    """Visually check the ``Rssrai`` training set through a ``DataLoader``.

    Prints the shapes of the first sample, then iterates four batches
    (indices 0-3). For every sample of a batch the first image channel is
    dropped, the remaining channels are de-normalised with
    ``rssrai.std[1:]`` / ``rssrai.mean[1:]`` and scaled to ``uint8``, and the
    image is saved next to its decoded segmentation map as
    ``rssrai-<batch>-<sample>.jpg`` in ``<rssrai root>/测试输出``.
    """
    plt.rcParams['savefig.dpi'] = 500  # DPI of saved figures
    plt.rcParams['figure.dpi'] = 500  # DPI of displayed figures

    test_path = os.path.join(Path().db_root_dir("rssrai"), "测试输出")
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(test_path, exist_ok=True)

    rssrai = Rssrai(type="train")
    # Peek at the first sample only.
    for i in rssrai:
        pprint(i["image"].shape)
        pprint(i["label"].shape)
        break
    data_loader = DataLoader(rssrai, batch_size=4, shuffle=True, num_workers=4)

    for ii, sample in enumerate(data_loader):
        print(sample['image'].shape)
        # Drop channel 0 so that three channels remain for plotting.
        sample['image'] = sample['image'][:, 1:, :, :]

        # Convert the batch once; the conversion does not depend on the
        # per-sample index, so it is hoisted out of the inner loop.
        img = sample['image'].numpy()
        gt = sample['label'].numpy()

        for jj in range(img.shape[0]):
            # Channel-first -> channel-last for matplotlib.
            img_tmp = np.transpose(img[jj], axes=[1, 2, 0])
            tmp = gt[jj]
            segmap = decode_segmap(tmp)
            # Undo the normalisation, then scale for the uint8 cast.
            # NOTE(review): assumes values end up in [0, 1] before the
            # multiplication by 255 - confirm against the dataset transform.
            img_tmp *= rssrai.std[1:]
            img_tmp += rssrai.mean[1:]
            img_tmp *= 255.0
            img_tmp = img_tmp.astype(np.uint8)
            plt.figure()
            plt.title('display')
            plt.subplot(121)
            plt.imshow(img_tmp)
            plt.subplot(122)
            plt.imshow(segmap)
            plt.savefig(f"{test_path}/rssrai-{ii}-{jj}.jpg")
            plt.close('all')

        if ii == 3:
            break

    plt.show(block=True)
示例#9
0
    def __init__(self,
                 type,
                 base_size=(513, 513),
                 crop_size=(256, 256),
                 base_dir=Path.db_root_dir('pascal')):
        """
        Index the image / segmentation-mask file pairs of one VOC split.

        :param type: split name, 'train' or 'val'
        :param base_size: stored unchanged as ``self.base_size``
        :param crop_size: stored unchanged as ``self.crop_size``
        :param base_dir: path to VOC dataset directory
        :raises AssertionError: if ``type`` is not a known split, or an image
            or mask listed in the split file does not exist
        """
        assert type in ['train', 'val'], \
            f"type must be 'train' or 'val', got {type!r}"
        super().__init__()
        self._base_dir = base_dir
        self._image_dir = os.path.join(self._base_dir, 'JPEGImages')
        self._cat_dir = os.path.join(self._base_dir, 'SegmentationClass')

        self.split = type
        self.base_size = base_size
        self.crop_size = crop_size
        self.mean = (0.485, 0.456, 0.406)
        self.std = (0.229, 0.224, 0.225)

        self.im_ids = []
        self.images = []
        self.categories = []

        # Each line of <split>.txt is an image id without extension.
        _splits_dir = os.path.join(self._base_dir, 'ImageSets', 'Segmentation')
        with open(os.path.join(_splits_dir, self.split + '.txt'), "r") as f:
            lines = f.read().splitlines()
        for line in lines:
            _image = os.path.join(self._image_dir, line + ".jpg")
            _cat = os.path.join(self._cat_dir, line + ".png")
            assert os.path.isfile(_image), f"missing image file: {_image}"
            assert os.path.isfile(_cat), f"missing mask file: {_cat}"
            self.im_ids.append(line)
            self.images.append(_image)
            self.categories.append(_cat)

        assert (len(self.images) == len(self.categories))

        # Display stats
        print('Number of images in {}: {:d}'.format(type, len(self.images)))
示例#10
0
def test_spilt_train_image():
    """Split the rssrai training images and labels into 680x720 tiles.

    Images from ``train/img`` are processed first with ``mode="CMYK"``, then
    labels from ``train/label`` with ``mode="RGB"``; the tiles are written to
    the matching sub-directory of ``split_train``.
    """
    root = Path.db_root_dir("rssrai")
    tile_h_w = (680, 720)

    # (sub-directory, mode) pairs: images first, labels second.
    jobs = (("img", "CMYK"), ("label", "RGB"))
    for sub_dir, tile_mode in jobs:
        tif_files = glob(os.path.join(root, "train", sub_dir, "*.tif"))
        target_dir = os.path.join(root, "split_train", sub_dir)
        spilt_all_images(tif_files,
                         target_dir,
                         mode=tile_mode,
                         output_image_h_w=tile_h_w)
示例#11
0
def testGetValid():
    """Move a random subset of ``split_train`` into a fresh ``split_valid``.

    The ``split_valid/label`` and ``split_valid/img`` directories are
    emptied (or created when they do not exist yet). The label file names in
    ``split_train/label`` are shuffled and every entry from index 850 onward
    is moved, together with its image (same name with ``"_label"``
    removed), into the validation directories.
    """
    base_path = Path.db_root_dir("rssrai")

    label_path = os.path.join(base_path, "split_train", "label")
    img_path = os.path.join(base_path, "split_train", "img")
    valid_label_path = os.path.join(base_path, "split_valid", "label")
    valid_img_path = os.path.join(base_path, "split_valid", "img")

    # Start from empty directories. rmtree() raises when the directory is
    # missing, so it is only called for directories that already exist.
    for directory in (valid_label_path, valid_img_path):
        if os.path.isdir(directory):
            shutil.rmtree(directory)
        os.makedirs(directory)

    # basename() is separator-agnostic, unlike splitting on "/".
    label_name_list = [
        os.path.basename(path_name)
        for path_name in glob(os.path.join(label_path, "*"))
    ]

    random.shuffle(label_name_list)

    print(len(label_name_list))

    # The first 850 shuffled entries stay in the training set.
    valid_label_name_list = label_name_list[850:]

    pprint(len(valid_label_name_list))

    for label_name in valid_label_name_list:
        img_name = label_name.replace("_label", "")
        print(valid_label_path, label_name)
        print(valid_img_path, img_name)
        shutil.move(os.path.join(label_path, label_name),
                    os.path.join(valid_label_path, label_name))
        shutil.move(os.path.join(img_path, img_name),
                    os.path.join(valid_img_path, img_name))
示例#12
0
def testSplitLabel():
    """Collect per-label statistics for every split label image into a CSV.

    Runs ``statistic_label(directory, file_name)`` for each ``*.tif`` in
    ``<rssrai root>/split/label`` on a 16-process pool, assembles the
    results into a DataFrame whose columns are ``"文件名"`` followed by the
    values of ``color_name_map``, prints it and writes it to
    ``<rssrai root>/split_label.csv``.
    """
    from glob import glob

    print(color_name_map)
    name_list = ["文件名"]
    name_list.extend(color_name_map.values())
    print(name_list)

    base_path = Path.db_root_dir("rssrai")
    image_path = os.path.join(base_path, "split", "label")
    image_list = glob(os.path.join(image_path, "*.tif"))

    # Multi-processing: submit every task first and only then wait for the
    # results. Calling get() right after apply_async() would block on each
    # task in turn and run them one at a time.
    pool = Pool(16)
    try:
        pending = []
        for image in image_list:
            path, name = os.path.split(image)
            pending.append(
                pool.apply_async(statistic_label, args=(path, name)))
        # Results are gathered in submission order, so rows keep the order
        # of image_list.
        all_statistic_list = [result.get() for result in tqdm(pending)]
    finally:
        # Release the worker processes even when a task failed.
        pool.close()
        pool.join()

    df = pd.DataFrame(all_statistic_list, columns=name_list)
    print(df)
    df.to_csv(os.path.join(base_path, "split_label.csv"))
示例#13
0
        del rssraiImage
        del self.model
        gc.collect()


# Script entry point: run the tester over every training image and store the
# output in a directory named after the model and the start time.
if __name__ == "__main__":
    args = Options().parse()

    # Hard-coded overrides of the parsed options.
    args.dataset = 'rssrai'
    args.model = 'FCN'
    args.backbone = 'resnet50'
    args.check_point_id = 1
    args.batch_size = 500

    print(args)
    # Take a single clock snapshot so that all six fields belong to the same
    # instant; separate time.time() calls could straddle a second boundary.
    started = time.localtime()
    now = f"{started.tm_year}-{started.tm_mon}-{started.tm_mday}-"
    now += f"{started.tm_hour}-{started.tm_min}-{started.tm_sec}"
    tester = Tester()

    base_dir = Path.db_root_dir('rssrai')
    image_dir = os.path.join(base_dir, 'train', 'img')
    save_dir = save_path(
        os.path.join(base_dir, f'test_output_model={args.model}_time={now}'))

    # basename() is separator-agnostic, unlike splitting on '/'.
    img_name_list = [
        os.path.basename(img_path)
        for img_path in glob(os.path.join(image_dir, '*.tif'))
    ]
    pprint(img_name_list)
    for index, name in enumerate(img_name_list):
        print(f"{index}:{name}")
        tester.test(name, image_dir, save_dir)
示例#14
0
    def __init__(self,
                 type='train',
                 base_size=(512, 512),
                 crop_size=(256, 256),
                 base_dir=Path.db_root_dir('rssrai')):
        """
        Locate the files of one rssrai split and record the dataset length.

        :param type: 'train', 'valid' or 'test'
        :param base_size: stored unchanged as ``self.base_size``
        :param crop_size: stored unchanged as ``self.crop_size``
        :param base_dir: root directory of the rssrai dataset
        :raises AssertionError: if ``type`` is not one of the known splits
        """
        assert type in ['train', 'valid', 'test'], \
            f"type must be 'train', 'valid' or 'test', got {type!r}"
        super().__init__()
        self._base_dir = base_dir
        self.type = type
        self.in_c = 4
        self.mean = mean
        self.std = std
        self.crop_size = crop_size
        self.base_size = base_size
        self.im_ids = []
        self.images = []
        self.categories = []

        # Load data. The three splits are mutually exclusive; file names are
        # taken with basename() so that they do not depend on the separator.
        if self.type == 'train':
            self._image_dir = os.path.join(self._base_dir, 'split_train',
                                           'img')
            self._label_dir = os.path.join(self._base_dir, 'split_train',
                                           'label')
            self._label_path_list = glob(
                os.path.join(self._label_dir, '*.tif'))
            self._label_name_list = [
                os.path.basename(name) for name in self._label_path_list
            ]

            # Fixed length, independent of the number of label files.
            # NOTE(review): presumably samples are drawn at random from the
            # training tiles - confirm in __getitem__.
            self.len = 20000

        elif self.type == 'valid':
            self._image_dir = os.path.join(self._base_dir, 'split_valid_256',
                                           'img')
            self._label_dir = os.path.join(self._base_dir, 'split_valid_256',
                                           'label')
            self._label_path_list = glob(
                os.path.join(self._label_dir, '*.tif'))
            self._label_name_list = [
                os.path.basename(name) for name in self._label_path_list
            ]

            self.len = len(self._label_name_list)

        elif self.type == 'test':
            # The test split has images only, no labels.
            self._image_dir = os.path.join(self._base_dir, 'split_test_256',
                                           'img')
            self._img_path_list = glob(
                os.path.join(self._image_dir, '*.tif'))
            self._img_name_list = [
                os.path.basename(name) for name in self._img_path_list
            ]
            self.len = len(self._img_path_list)