예제 #1
0
def testFlip():
    """Run the sliding-window dataset over one test image and save the RGB result.

    Takes the first ``*.tif`` found in ``<rssrai root>/test``. For the chosen
    image variant, it prints the shape of every sample yielded by
    ``RssraiTestOneImage.get_slide_dataSet``. It then calls
    ``saveResultRGBImage``.

    Raises:
        FileNotFoundError: if the test directory contains no ``*.tif`` files.
    """
    base_dir = Path.db_root_dir('rssrai')
    image_dir = os.path.join(base_dir, 'test')
    save_dir = save_path(os.path.join(base_dir, 'test_output'))

    img_path_list = glob(os.path.join(image_dir, '*.tif'))
    # os.path.basename is portable, unlike splitting on "/".
    img_name_list = [os.path.basename(p) for p in img_path_list]
    pprint(img_name_list)
    pprint(image_dir)
    if not img_name_list:
        raise FileNotFoundError(f"no *.tif images found in {image_dir}")

    # NOTE(review): the positional 10 and 4 are interpreted by
    # RssraiTestOneImage (not visible here) — confirm their meaning there.
    rssraiImage = RssraiTestOneImage(img_name_list[0], image_dir, save_dir, 10,
                                     4)

    # Variant key into rssraiImage.images; alternatives used during
    # development were "vertical" and "horizontal". Named image_type to avoid
    # shadowing the builtin type().
    image_type = "origin"
    print(rssraiImage.images[image_type].shape)
    for data_set in rssraiImage.get_slide_dataSet(
            rssraiImage.images[image_type]):
        print(data_set)
        for sample in data_set:
            print(sample['image'].shape)

    rssraiImage.saveResultRGBImage()
예제 #2
0
def test_one_merge_image():
    """Merge one test image from ``split_test_520`` into ``merge_test/img``.

    Calls ``merge_image`` in "CMYK" mode for
    ``GF2_PMS1__20150902_L1A0001015646-MSS1.tif``. It reads from
    ``<rssrai root>/split_test_520`` and writes to
    ``<rssrai root>/merge_test/img``, creating that directory if needed.
    """
    base_path = Path.db_root_dir("rssrai")

    # Images
    image_path = os.path.join(base_path, "split_test_520")
    save_image_path = os.path.join(base_path, "merge_test", "img")
    # exist_ok avoids the race between an exists() check and makedirs().
    os.makedirs(save_image_path, exist_ok=True)
    merge_image(image_path, "GF2_PMS1__20150902_L1A0001015646-MSS1.tif",
                save_image_path, "CMYK")
예제 #3
0
def testOneImage():
    """Drop channel 0 of one training image and save the remainder as RGB.

    Loads the image into numpy and keeps channels ``1:`` of the last axis.
    It then converts the result to an 8-bit RGB PIL image and saves it under
    the same file name in the rssrai root directory.
    """
    base_path = Path.db_root_dir("rssrai")
    # NOTE(review): machine-specific absolute path; presumably equivalent to
    # os.path.join(base_path, "train", "img") — confirm before reuse.
    path = '/home/arron/Documents/grey/Project_Rssrai/rssrai/train/img'
    name = 'GF2_PMS1__20150212_L1A0000647768-MSS1.tif'

    # Close the source file as soon as its pixels are copied into numpy.
    with Image.open(os.path.join(path, name)) as file_image:
        # Assumes an (H, W, C) array with C == 4, so 3 channels remain for
        # RGB — TODO confirm the channel layout of these GF2 tiles.
        np_image = np.array(file_image)[:, :, 1:]

    image = Image.fromarray(np_image.astype('uint8')).convert("RGB")
    image.save(os.path.join(base_path, name))
예제 #4
0
def test_spilt_test_image():
    """Split every ``*.tif`` under ``<rssrai root>/test`` with spilt_all_images.

    Uses "CMYK" mode and an output size of (520, 520). Output is written to
    ``<rssrai root>/split_test_520/img``.
    """
    root = Path.db_root_dir("rssrai")

    # Source images
    tif_files = glob(os.path.join(root, "test", "*.tif"))
    target_dir = os.path.join(root, "split_test_520", "img")

    spilt_all_images(tif_files, target_dir, mode="CMYK",
                     output_image_h_w=(520, 520))
예제 #5
0
def test_merge_images():
    """Merge every image listed in ``test_name_list.csv`` into ``merge_test/img``.

    Reads the ``name`` column of ``test_name_list.csv``. That path is
    relative, so it is resolved against the current working directory. For
    each name, it calls ``merge_image`` in "CMYK" mode, reading from
    ``<rssrai root>/temp_test/img``.
    """
    import pandas as pd
    base_path = Path.db_root_dir("rssrai")
    df = pd.read_csv("test_name_list.csv")
    name_list = df['name'].tolist()

    # Images
    image_path = os.path.join(base_path, "temp_test", "img")
    save_image_path = os.path.join(base_path, "merge_test", "img")
    # exist_ok avoids the race between an exists() check and makedirs().
    os.makedirs(save_image_path, exist_ok=True)
    for name in tqdm(name_list):
        merge_image(image_path, name, save_image_path, "CMYK")
예제 #6
0
def spilt_image(split, output_image_h_w):
    """Split the label/image pairs listed in ``<split>_set.csv``.

    Label file names come from the "文件名" ("file name") column of
    ``<rssrai root>/<split>_set.csv``. The matching image name is the label
    name with ``"_label"`` removed. Sources are always read from
    ``split_train/label`` and ``split_train/img``; the CSV only selects which
    files to process. Labels are split in "RGB" mode and images in "CMYK"
    mode. Output goes to ``split_<split>_<output_image_h_w[0]>/label`` and
    ``split_<split>_<output_image_h_w[0]>/img``.

    Args:
        split: prefix of the CSV file and of the output directory name.
        output_image_h_w: passed through to ``split_image``. Only the first
            element appears in the output directory name.
    """
    import pandas as pd

    base_path = Path.db_root_dir("rssrai")
    df = pd.read_csv(os.path.join(base_path, f"{split}_set.csv"))
    label_name_list = df["文件名"].values.tolist()
    print(label_name_list)
    print(len(label_name_list))

    out_root = os.path.join(base_path, f"split_{split}_{output_image_h_w[0]}")
    save_label_path = os.path.join(out_root, "label")
    save_image_path = os.path.join(out_root, "img")
    make_sure_path_exists(save_label_path)
    make_sure_path_exists(save_image_path)

    label_path = os.path.join(base_path, "split_train", "label")
    image_path = os.path.join(base_path, "split_train", "img")

    for label_name in tqdm(label_name_list):
        image_name = label_name.replace("_label", "")
        split_image(label_path,
                    label_name,
                    save_label_path,
                    mode="RGB",
                    output_image_h_w=output_image_h_w)
        split_image(image_path,
                    image_name,
                    save_image_path,
                    mode="CMYK",
                    output_image_h_w=output_image_h_w)
예제 #7
0
def test_one_spilt_test_image():
    """Split one test image with split_image into ``<rssrai root>/split_test_520``.

    Uses "CMYK" mode and an output size of (520, 520). The output directory
    is created if it does not exist.
    """
    base_path = Path.db_root_dir("rssrai")
    image_path = os.path.join(base_path, "test")
    name = 'GF2_PMS1__20150902_L1A0001015646-MSS1.tif'
    save_image_path = os.path.join(base_path, "split_test_520")
    # exist_ok avoids the race between an exists() check and makedirs().
    os.makedirs(save_image_path, exist_ok=True)

    # Images
    split_image(image_path,
                name,
                save_image_path,
                mode="CMYK",
                output_image_h_w=(520, 520))
예제 #8
0
def test_spilt_train_image():
    """Split all training images and labels with spilt_all_images.

    ``train/img/*.tif`` is split in "CMYK" mode into ``split_train/img``.
    After that, ``train/label/*.tif`` is split in "RGB" mode into
    ``split_train/label``. Both use an output size of (680, 720).
    """
    root = Path.db_root_dir("rssrai")

    # (sub-directory, colour mode) — images first, then labels.
    for sub_dir, colour_mode in (("img", "CMYK"), ("label", "RGB")):
        sources = glob(os.path.join(root, "train", sub_dir, "*.tif"))
        target = os.path.join(root, "split_train", sub_dir)
        spilt_all_images(sources,
                         target,
                         mode=colour_mode,
                         output_image_h_w=(680, 720))
예제 #9
0
def testGetValid():
    """Move a random subset of split training pairs into a validation split.

    Steps:

    1. Recreate ``split_valid/label`` and ``split_valid/img`` as empty
       directories.
    2. Shuffle the files in ``split_train/label``.
    3. Move every label from index 850 onward, together with its image (the
       label name minus ``"_label"``), from ``split_train`` into
       ``split_valid``.

    NOTE(review): files moved by a previous run are deleted (not returned to
    split_train) when the validation directories are recreated, so re-running
    this loses data.
    """
    base_path = Path.db_root_dir("rssrai")
    label_path = os.path.join(base_path, "split_train", "label")
    img_path = os.path.join(base_path, "split_train", "img")
    valid_label_path = os.path.join(base_path, "split_valid", "label")
    valid_img_path = os.path.join(base_path, "split_valid", "img")

    # ignore_errors: on a first run the directories do not exist yet and a
    # plain rmtree would raise FileNotFoundError.
    for directory in (valid_label_path, valid_img_path):
        shutil.rmtree(directory, ignore_errors=True)
        os.makedirs(directory)

    # os.path.basename is portable, unlike splitting on "/".
    label_name_list = [
        os.path.basename(path_name)
        for path_name in glob(os.path.join(label_path, "*"))
    ]

    random.shuffle(label_name_list)

    print(len(label_name_list))

    num_train = 850  # shuffled pairs that stay in split_train
    valid_label_name_list = label_name_list[num_train:]

    pprint(len(valid_label_name_list))

    for label_name in valid_label_name_list:
        img_name = label_name.replace("_label", "")
        print(valid_label_path, label_name)
        print(valid_img_path, img_name)
        shutil.move(os.path.join(label_path, label_name),
                    os.path.join(valid_label_path, label_name))
        shutil.move(os.path.join(img_path, img_name),
                    os.path.join(valid_img_path, img_name))
def testSplitLabel():
    """Compute ``statistic_label`` for every split label tile and save a CSV.

    Runs ``statistic_label(directory, file_name)`` on a 16-process pool for
    each ``*.tif`` in ``<rssrai root>/split/label``. The results are
    collected, in input order, into a DataFrame. Its columns are "文件名"
    ("file name") followed by the values of ``color_name_map``. The frame is
    written to ``<rssrai root>/split_label.csv``.
    """
    import pandas as pd
    from glob import glob

    print(color_name_map)
    name_list = ["文件名", *color_name_map.values()]
    print(name_list)

    base_path = Path.db_root_dir("rssrai")
    image_path = os.path.join(base_path, "split", "label")
    image_list = glob(os.path.join(image_path, "*.tif"))

    # Multiprocessing: submit every job before waiting on any result. Calling
    # .get() immediately after each apply_async would run the jobs one at a
    # time. The with-block shuts the pool down once all results are in.
    with Pool(16) as pool:
        pending = [
            pool.apply_async(statistic_label,
                             args=(os.path.dirname(image),
                                   os.path.basename(image)))
            for image in image_list
        ]
        all_statistic_list = [result.get() for result in tqdm(pending)]

    df = pd.DataFrame(all_statistic_list, columns=name_list)
    print(df)
    df.to_csv(os.path.join(base_path, "split_label.csv"))
import os
from glob import glob

import numpy as np
import torch

from experiments.datasets.path import Path

# Embedded in the name of the numpy cache directory that path_list globs
# (test_train_numpy_<crop_size>); presumably the crop size of the cached
# samples — confirm.
crop_size = 256
# Root directory of the rssrai dataset, as resolved by the project's Path helper.
base_dir = Path.db_root_dir('rssrai')


def load_numpy(name_list, index):
    """Load one ``.npz`` sample as torch tensors, skipping unreadable files.

    Tries ``name_list[index]`` first. If that file cannot be loaded, it tries
    the following entries in order, wrapping around to the start of the
    list, and returns the first one that loads.

    Args:
        name_list: sequence of ``.npz`` file paths. Each archive must contain
            ``image`` and ``label`` arrays.
        index: starting position; negative indices follow normal Python
            indexing.

    Returns:
        dict with ``'image'`` (``torch.from_numpy`` of the stored array) and
        ``'label'`` (converted to a ``torch.long`` tensor).

    Raises:
        IndexError: if ``name_list`` is empty or ``index`` is out of range.
        RuntimeError: if no file in ``name_list`` can be loaded.
    """
    num = len(name_list)
    if not -num <= index < num:
        raise IndexError(
            f"load_numpy: index {index} out of range for {num} files")

    # The original retried the same index forever on failure; walk the list
    # instead, so a bad file is skipped and the loop always terminates.
    for offset in range(num):
        i = (index + offset) % num
        try:
            # The context manager closes the NpzFile after the arrays are read.
            with np.load(name_list[i]) as sample:
                return {
                    'image': torch.from_numpy(sample['image']),
                    "label": torch.from_numpy(sample['label']).long()
                }
        except Exception:  # corrupt npz can raise many unrelated types
            print(f"{name_list[i]} is bad! skipping it.")
    raise RuntimeError("load_numpy: none of the given files could be loaded")


# All cached *.npz samples under <rssrai root>/test_train_numpy_<crop_size>.
# This glob runs at import time, every time the module is imported.
path_list = glob(
    os.path.join(base_dir, f'test_train_numpy_{crop_size}', '*.npz'))