def testFlip():
    """Walk the sliding-window datasets of the first ``*.tif`` test image, then save a result.

    Loads the first ``*.tif`` found in ``<rssrai root>/test`` into
    ``RssraiTestOneImage`` and prints every dataset yielded by
    ``get_slide_dataSet`` together with the ``'image'`` shape of each of its
    items. Finally calls ``saveResultRGBImage()``. Output goes under
    ``<rssrai root>/test_output`` (via the project helper ``save_path``).

    Raises:
        FileNotFoundError: if the test directory contains no ``*.tif`` file.
    """
    base_dir = Path.db_root_dir('rssrai')
    image_dir = os.path.join(base_dir, 'test')
    save_dir = save_path(os.path.join(base_dir, 'test_output'))
    # basename() respects the OS path separator; split('/') breaks on Windows paths.
    img_name_list = [os.path.basename(p) for p in glob(os.path.join(image_dir, '*.tif'))]
    pprint(img_name_list)
    pprint(image_dir)
    if not img_name_list:
        raise FileNotFoundError(f"no *.tif images found in {image_dir}")
    # The trailing 10 and 4 are forwarded unchanged; their meaning is defined by RssraiTestOneImage.
    rssraiImage = RssraiTestOneImage(img_name_list[0], image_dir, save_dir, 10, 4)
    # Other variants that were tried during development: "vertical", "horizontal".
    image_type = "origin"
    print(rssraiImage.images[image_type].shape)
    for data_set in rssraiImage.get_slide_dataSet(rssraiImage.images[image_type]):
        print(data_set)
        for sample in data_set:
            print(sample['image'].shape)
            # NOTE(review): prediction/fill step (rssraiImage.fill_image) is disabled here,
            # so saveResultRGBImage() saves whatever the object was initialised with.
    rssraiImage.saveResultRGBImage()
def test_spilt_valid_image():
    """Run ``split_image`` over every validation label/image pair in ``split_valid``.

    Each file in ``<rssrai root>/split_valid/label`` is split in "RGB" mode into
    ``split_valid_256/label``; the matching image in ``split_valid/img`` (label
    file name with ``"_label"`` removed) is split in "CMYK" mode into
    ``split_valid_256/img``. Both output directories are created if missing.
    """
    base_path = Path.db_root_dir("rssrai")
    label_path = os.path.join(base_path, "split_valid", "label")
    image_path = os.path.join(base_path, "split_valid", "img")
    save_label_path = os.path.join(base_path, "split_valid_256", "label")
    save_image_path = os.path.join(base_path, "split_valid_256", "img")
    # exist_ok avoids the race between an exists() check and makedirs().
    os.makedirs(save_label_path, exist_ok=True)
    os.makedirs(save_image_path, exist_ok=True)

    # basename() respects the OS path separator; split('/') breaks on Windows paths.
    label_name_list = [os.path.basename(p) for p in glob(os.path.join(label_path, "*"))]
    print(len(label_name_list))
    for label_name in tqdm(label_name_list):
        image_name = label_name.replace("_label", "")
        split_image(label_path, label_name, save_label_path, mode="RGB")
        split_image(image_path, image_name, save_image_path, mode="CMYK")
def test_one_merge_image():
    """Merge the tiles of one hard-coded test image back into a single image.

    Tiles are read from ``<rssrai root>/split_test_256``; ``merge_image`` writes
    the result (mode "CMYK") into ``<rssrai root>/merge_test/img``, which is
    created first if it does not exist.
    """
    root_dir = Path.db_root_dir("rssrai")
    target_name = "GF2_PMS1__20150902_L1A0001015646-MSS1.tif"
    tiles_dir = os.path.join(root_dir, "split_test_256")
    merged_dir = os.path.join(root_dir, "merge_test", "img")
    if not os.path.exists(merged_dir):
        os.makedirs(merged_dir)
    merge_image(tiles_dir, target_name, merged_dir, "CMYK")
def testOneImage():
    """Save an RGB copy of one hard-coded training image with its first channel dropped.

    Reads ``name`` from a hard-coded absolute training directory, keeps every
    channel except index 0 (``[:, :, 1:]``), converts the result to an 8-bit RGB
    PIL image and writes it into the rssrai root directory under the same name.
    """
    base_path = Path.db_root_dir("rssrai")
    # NOTE(review): machine-specific absolute path; only valid on the original author's machine.
    path = '/home/arron/Documents/grey/Project_Rssrai/rssrai/train/img'
    name = 'GF2_PMS1__20150212_L1A0000647768-MSS1.tif'
    # Image.open is lazy and keeps the file handle open; the context manager closes it
    # once the pixel data has been copied into the numpy array.
    with Image.open(os.path.join(path, name)) as file_image:
        np_image = np.array(file_image)[:, :, 1:]
    image = Image.fromarray(np_image.astype('uint8')).convert("RGB")
    image.save(os.path.join(base_path, name))
def test_spilt_test_image():
    """Split every ``*.tif`` in ``<rssrai root>/test`` with ``spilt_all_images``.

    Results go to ``<rssrai root>/split_test_256/img`` using mode "CMYK" and
    ``output_image_h_w=(256, 256)``.
    """
    root_dir = Path.db_root_dir("rssrai")
    source_pattern = os.path.join(root_dir, "test", "*.tif")
    sources = glob(source_pattern)
    target_dir = os.path.join(root_dir, "split_test_256", "img")
    spilt_all_images(sources, target_dir, mode="CMYK", output_image_h_w=(256, 256))
def test_merge_images():
    """Merge the tiles of every image listed in ``test_name_list.csv``.

    Image names come from the ``name`` column of ``test_name_list.csv`` (resolved
    relative to the current working directory). Tiles are read from
    ``<rssrai root>/temp_test/img`` and ``merge_image`` writes results (mode
    "CMYK") into ``<rssrai root>/merge_test/img``, created first if missing.
    """
    import pandas as pd

    root_dir = Path.db_root_dir("rssrai")
    names = pd.read_csv("test_name_list.csv")['name'].tolist()
    tiles_dir = os.path.join(root_dir, "temp_test", "img")
    merged_dir = os.path.join(root_dir, "merge_test", "img")
    if not os.path.exists(merged_dir):
        os.makedirs(merged_dir)
    for image_name in tqdm(names):
        merge_image(tiles_dir, image_name, merged_dir, "CMYK")
def test_one_spilt_test_image():
    """Split one hard-coded image from ``<rssrai root>/test`` with ``split_image``.

    Output goes to ``<rssrai root>/split_test_256`` (created first if missing)
    using mode "CMYK" and ``output_image_h_w=(256, 256)``.
    """
    root_dir = Path.db_root_dir("rssrai")
    source_dir = os.path.join(root_dir, "test")
    target_name = 'GF2_PMS1__20150902_L1A0001015646-MSS1.tif'
    target_dir = os.path.join(root_dir, "split_test_256")
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)
    split_image(source_dir, target_name, target_dir, mode="CMYK", output_image_h_w=(256, 256))
def testData():
    """Save preview plots of training samples to check the Rssrai dataset visually.

    Prints the image/label shapes of the first dataset sample, then iterates a
    shuffled ``DataLoader`` (batch size 4, 4 workers). For each of the first four
    batches it drops image channel 0, undoes the ``rssrai.mean`` / ``rssrai.std``
    normalisation of the remaining channels, and saves the image next to its
    ``decode_segmap`` label rendering as
    ``<rssrai root>/测试输出/rssrai-<batch>-<index>.jpg`` ("测试输出" = "test output").
    """
    plt.rcParams['savefig.dpi'] = 500  # resolution of saved figures
    plt.rcParams['figure.dpi'] = 500  # resolution of displayed figures
    # Called on the class, like every other db_root_dir caller in this file.
    test_path = os.path.join(Path.db_root_dir("rssrai"), "测试输出")
    os.makedirs(test_path, exist_ok=True)

    rssrai = Rssrai(type="train")
    # Peek at one sample only, to confirm the shapes the dataset returns.
    for first in rssrai:
        pprint(first["image"].shape)
        pprint(first["label"].shape)
        break

    data_loader = DataLoader(rssrai, batch_size=4, shuffle=True, num_workers=4)
    for ii, sample in enumerate(data_loader):
        print(sample['image'].shape)
        # Convert once per batch instead of once per image. Channel 0 is dropped so the
        # kept channels line up with mean[1:] / std[1:] below.
        img = sample['image'][:, 1:, :, :].numpy()
        gt = sample['label'].numpy()
        for jj in range(img.shape[0]):
            img_tmp = np.transpose(img[jj], axes=[1, 2, 0])
            segmap = decode_segmap(gt[jj])
            # Undo the normalisation, then scale to the 0-255 display range.
            img_tmp *= rssrai.std[1:]
            img_tmp += rssrai.mean[1:]
            img_tmp *= 255.0
            img_tmp = img_tmp.astype(np.uint8)
            plt.figure()
            # suptitle rather than title: plt.title() would create a full-figure Axes
            # that the subplot(121)/subplot(122) calls below then overlap.
            plt.suptitle('display')
            plt.subplot(121)
            plt.imshow(img_tmp)
            plt.subplot(122)
            plt.imshow(segmap)
            plt.savefig(f"{test_path}/rssrai-{ii}-{jj}.jpg")
            plt.close('all')
        if ii == 3:
            break
def __init__(self, type, base_size=(513, 513), crop_size=(256, 256), base_dir=Path.db_root_dir('pascal')):
    """Index the image/segmentation-mask file pairs of one VOC split.

    Reads ``<base_dir>/ImageSets/Segmentation/<type>.txt`` and, for every id in
    it, records ``JPEGImages/<id>.jpg`` and ``SegmentationClass/<id>.png``.

    :param type: split name, either 'train' or 'val'
    :param base_size: stored unchanged as ``self.base_size``
    :param crop_size: stored unchanged as ``self.crop_size``
    :param base_dir: path to the VOC dataset directory (default evaluated once, at definition time)
    :raises AssertionError: if ``type`` is not a known split, or a listed image/mask file is missing
    """
    assert type in ['train', 'val'], f"unknown split: {type!r}"
    super().__init__()
    self._base_dir = base_dir
    self._image_dir = os.path.join(self._base_dir, 'JPEGImages')
    self._cat_dir = os.path.join(self._base_dir, 'SegmentationClass')
    self.split = type
    self.base_size = base_size
    self.crop_size = crop_size
    # Per-channel normalisation constants (the commonly used ImageNet values).
    self.mean = (0.485, 0.456, 0.406)
    self.std = (0.229, 0.224, 0.225)
    self.im_ids = []
    self.images = []
    self.categories = []

    split_file = os.path.join(self._base_dir, 'ImageSets', 'Segmentation', self.split + '.txt')
    with open(split_file, "r") as f:
        lines = f.read().splitlines()

    for line in lines:
        _image = os.path.join(self._image_dir, line + ".jpg")
        _cat = os.path.join(self._cat_dir, line + ".png")
        assert os.path.isfile(_image), f"missing image: {_image}"
        assert os.path.isfile(_cat), f"missing mask: {_cat}"
        self.im_ids.append(line)
        self.images.append(_image)
        self.categories.append(_cat)

    assert len(self.images) == len(self.categories)

    # Display stats
    print('Number of images in {}: {:d}'.format(type, len(self.images)))
def test_spilt_train_image():
    """Split every training image and label ``*.tif`` with ``spilt_all_images``.

    Images (``train/img``, mode "CMYK") are processed first, then labels
    (``train/label``, mode "RGB"). Results go to the matching sub-directory of
    ``<rssrai root>/split_train``, with ``output_image_h_w=(680, 720)``.
    """
    root_dir = Path.db_root_dir("rssrai")
    # (sub-directory, mode) pairs, in the order they are processed.
    for sub_dir, mode in (("img", "CMYK"), ("label", "RGB")):
        sources = glob(os.path.join(root_dir, "train", sub_dir, "*.tif"))
        target_dir = os.path.join(root_dir, "split_train", sub_dir)
        spilt_all_images(sources, target_dir, mode=mode, output_image_h_w=(680, 720))
def testGetValid():
    """Move a random subset of ``split_train`` image/label pairs into ``split_valid``.

    Re-creates empty ``split_valid/label`` and ``split_valid/img`` directories,
    shuffles the label file names in ``split_train/label`` and moves every label
    after the first ``num_train`` (850) into the validation directory. Each label
    is moved together with its image, whose name is the label name with
    ``"_label"`` removed.
    """
    num_train = 850  # shuffled labels kept for training; the rest become validation
    base_path = Path.db_root_dir("rssrai")
    label_path = os.path.join(base_path, "split_train", "label")
    img_path = os.path.join(base_path, "split_train", "img")
    valid_label_path = os.path.join(base_path, "split_valid", "label")
    valid_img_path = os.path.join(base_path, "split_valid", "img")
    # NOTE(review): files already in split_valid from a previous run are deleted here,
    # not returned to split_train, so re-running shrinks the dataset — confirm intended.
    for valid_dir in (valid_label_path, valid_img_path):
        # rmtree raises FileNotFoundError when the directory does not exist yet (first run).
        if os.path.isdir(valid_dir):
            shutil.rmtree(valid_dir)
        os.makedirs(valid_dir)

    # basename() respects the OS path separator; split('/') breaks on Windows paths.
    label_name_list = [os.path.basename(p) for p in glob(os.path.join(label_path, "*"))]
    random.shuffle(label_name_list)
    print(len(label_name_list))
    valid_label_name_list = label_name_list[num_train:]
    pprint(len(valid_label_name_list))
    for label_name in valid_label_name_list:
        img_name = label_name.replace("_label", "")
        print(valid_label_path, label_name)
        print(valid_img_path, img_name)
        shutil.move(os.path.join(label_path, label_name), os.path.join(valid_label_path, label_name))
        shutil.move(os.path.join(img_path, img_name), os.path.join(valid_img_path, img_name))
def testSplitLabel():
    """Compute ``statistic_label`` for every label in ``split/label`` and save the rows as CSV.

    Runs ``statistic_label(directory, file_name)`` for each ``*.tif`` in
    ``<rssrai root>/split/label`` in a pool of 16 worker processes. One row per
    file is written to ``<rssrai root>/split_label.csv``, with columns
    ``"文件名"`` ("file name") followed by the values of ``color_name_map``.
    Row order follows the order of the glob result.
    """
    print(color_name_map)
    name_list = ["文件名", *color_name_map.values()]
    print(name_list)
    base_path = Path.db_root_dir("rssrai")
    image_path = os.path.join(base_path, "split", "label")
    from glob import glob
    image_list = glob(os.path.join(image_path, "*.tif"))

    # Submit every task before waiting on any result: calling .get() right after
    # each apply_async (as before) processed the files one at a time. The with-block
    # terminates the worker processes once all results have been collected.
    with Pool(16) as pool:
        pending = [
            pool.apply_async(statistic_label, args=os.path.split(image))
            for image in image_list
        ]
        all_statistic_list = [result.get() for result in tqdm(pending)]

    df = pd.DataFrame(all_statistic_list, columns=name_list)
    print(df)
    df.to_csv(os.path.join(base_path, "split_label.csv"))
# The line below starts with the tail of a method whose beginning is outside this view: it drops the `rssraiImage` and `self.model` references and calls gc.collect(). The rest is the script entry point: it parses Options, forces dataset='rssrai', model='FCN', backbone='resnet50', check_point_id=1, batch_size=500, builds a Tester, and calls tester.test on every *.tif in <rssrai root>/train/img, saving into a time-stamped 'test_output_model=<model>_time=<now>' directory. NOTE(review): `now` calls time.localtime() six separate times, so its fields can straddle a second/minute boundary; it also reads from train/img despite being a test run, and split('/') on glob results breaks on Windows — confirm both.
del rssraiImage del self.model gc.collect() if __name__ == "__main__": args = Options().parse() args.dataset = 'rssrai' args.model = 'FCN' args.backbone = 'resnet50' args.check_point_id = 1 args.batch_size = 500 print(args) now = f"{time.localtime(time.time()).tm_year}-{time.localtime(time.time()).tm_mon}-{time.localtime(time.time()).tm_mday}-" now += f"{time.localtime(time.time()).tm_hour}-{time.localtime(time.time()).tm_min}-{time.localtime(time.time()).tm_sec}" tester = Tester() base_dir = Path.db_root_dir('rssrai') image_dir = os.path.join(base_dir, 'train', 'img') save_dir = save_path( os.path.join(base_dir, f'test_output_model={args.model}_time={now}')) _img_path_list = glob(os.path.join(image_dir, '*.tif')) img_name_list = [name.split('/')[-1] for name in _img_path_list] pprint(img_name_list) for index, name in enumerate(img_name_list): print(f"{index}:{name}") tester.test(name, image_dir, save_dir)
def __init__(self, type='train', base_size=(512, 512), crop_size=(256, 256), base_dir=Path.db_root_dir('rssrai')):
    """Collect the file lists for one split of the rssrai dataset.

    :param type: 'train' (``split_train``), 'valid' (``split_valid_256``) or 'test' (``split_test_256``)
    :param base_size: stored unchanged as ``self.base_size``
    :param crop_size: stored unchanged as ``self.crop_size``
    :param base_dir: rssrai root directory (default evaluated once, at definition time)
    :raises AssertionError: if ``type`` is not one of the three split names

    For 'train' and 'valid' this sets ``_label_path_list``, ``_label_name_list``,
    ``_image_dir`` and ``_label_dir``; for 'test' it sets ``_img_path_list``,
    ``_img_name_list`` and ``_image_dir``. ``self.len`` is fixed at 20000 for
    'train' and equals the number of files found otherwise.
    """
    assert type in ['train', 'valid', 'test']
    super().__init__()
    self._base_dir = base_dir
    self.type = type
    self.in_c = 4  # presumably the number of input channels — confirm against the model
    self.mean = mean
    self.std = std
    self.crop_size = crop_size
    self.base_size = base_size
    self.im_ids = []
    self.images = []
    self.categories = []

    # Load the file lists for the requested split. basename() respects the OS path
    # separator; split('/') breaks on Windows paths.
    if self.type == 'train':
        self._label_path_list = glob(
            os.path.join(self._base_dir, 'split_train', 'label', '*.tif'))
        self._label_name_list = [os.path.basename(p) for p in self._label_path_list]
        self._image_dir = os.path.join(self._base_dir, 'split_train', 'img')
        self._label_dir = os.path.join(self._base_dir, 'split_train', 'label')
        # NOTE(review): fixed length independent of the number of files found — confirm intended.
        self.len = 20000
    elif self.type == 'valid':
        self._label_path_list = glob(
            os.path.join(self._base_dir, 'split_valid_256', 'label', '*.tif'))
        self._label_name_list = [os.path.basename(p) for p in self._label_path_list]
        self._image_dir = os.path.join(self._base_dir, 'split_valid_256', 'img')
        self._label_dir = os.path.join(self._base_dir, 'split_valid_256', 'label')
        self.len = len(self._label_name_list)
    elif self.type == 'test':
        self._img_path_list = glob(
            os.path.join(self._base_dir, 'split_test_256', 'img', '*.tif'))
        self._img_name_list = [os.path.basename(p) for p in self._img_path_list]
        self._image_dir = os.path.join(self._base_dir, 'split_test_256', 'img')
        self.len = len(self._img_path_list)