def slot_ok(self):
        """Polygonize every raster in the input directory into shapefiles.

        Reads the input and output directories from ``lineEdit_input`` and
        ``lineEdit_output``, creates the output directory if needed, then calls
        ``polygonize(file, <output_dir>/<stem>.shp)`` for each file returned by
        ``get_file``. The outcome is reported with message boxes.

        The window is made application-modal while working and is always
        restored to non-modal afterwards, including on early returns and
        unexpected errors.
        """
        self.setWindowModality(Qt.ApplicationModal)
        try:
            input_dir = self.lineEdit_input.text()
            if not os.path.isdir(input_dir):
                QMessageBox.warning(self, "Prompt",
                                    self.tr("Please check input directory!"))
                # Return instead of sys.exit(): a typo in a path must not
                # terminate the whole GUI application.
                return

            output_dir = self.lineEdit_output.text()
            if not os.path.isdir(output_dir):
                QMessageBox.warning(self, "Prompt",
                                    self.tr("Output directory is not existed!"))
                # makedirs also creates missing parent directories.
                os.makedirs(output_dir)

            try:
                files, nb = get_file(input_dir)
                if nb == 0:
                    QMessageBox.warning(self, "Prompt",
                                        self.tr("No image found!"))
                    # Returning from the try body skips the ``else`` branch,
                    # so no "successfully!" box follows. (The original
                    # sys.exit here was swallowed by a bare ``except:`` and
                    # shown as "Failed!".)
                    return

                for file in tqdm(files):
                    stem = os.path.split(file)[1].split('.')[0]
                    shp_file = os.path.join(output_dir, stem + '.shp')
                    polygonize(file, shp_file)
            except Exception:
                # Catch ordinary errors only; SystemExit/KeyboardInterrupt
                # must not be reported as a polygonization failure.
                QMessageBox.warning(self, "Prompt", self.tr("Failed!"))
            else:
                QMessageBox.information(self, "Prompt",
                                        self.tr("successfully!"))
        finally:
            self.setWindowModality(Qt.NonModal)
Example #2
0
def batchbinarize_masks(inputdict):
    """Binarize every grayscale mask in a directory with a fixed threshold.

    Args:
        inputdict: dict with the keys
            ``'threshold'`` -- pixels strictly greater than it become 1,
            all others 0;
            ``'inputdir'``  -- directory scanned with ``get_file``;
            ``'outputdir'`` -- destination directory (created if missing).
            Output files keep the input file names.

    Returns:
        0 on success, -1 if ``inputdir`` is not a directory.
    """
    threshold = inputdict['threshold']
    inputdir = inputdict['inputdir']
    outputdir = inputdict['outputdir']
    if not os.path.isdir(inputdir):
        print("Warning: input directory does not exist: {}".format(inputdir))
        return -1

    # cv2.imwrite only returns False (no exception) when the target folder
    # is missing, which would silently drop every mask.
    if not os.path.isdir(outputdir):
        os.makedirs(outputdir)

    files, _ = get_file(inputdir)

    for file in tqdm(files):
        img = cv2.imread(file, cv2.IMREAD_GRAYSCALE)
        if img is None:
            # imread returns None for unreadable / non-image files.
            print("Warning: cannot read image, skipped: {}".format(file))
            continue

        # Foreground (> threshold) -> 1, background -> 0, stored as uint8.
        result = (img > threshold).astype(np.uint8)

        mask_saving_path = os.path.join(outputdir, os.path.split(file)[1])
        cv2.imwrite(mask_saving_path, result)

    return 0
def creat_dataset_multiclass(labelfile,
                             srcfile,
                             base_path,
                             image_num=5000,
                             mode='original'):
    """Randomly crop ``image_num`` src/label patch pairs from one image pair.

    Patches of size ``img_h`` x ``img_w`` (module-level settings) are cropped
    at random positions. A patch is rejected when its label contains a value
    not in ``valid_labels`` (treated as NoData) or when it is entirely
    background (only value 0). Accepted patches are optionally augmented with
    ``data_augment`` and written as ``<base_path>/{visualize,src,label}/<n>.png``;
    numbering continues after the number of files already in
    ``<base_path>/label``.

    Args:
        labelfile: path of the label image (read as grayscale).
        srcfile: path of the source image (read with OpenCV defaults).
        base_path: root of the sample folders; ``<base_path>/label`` must
            already exist, otherwise the process exits with -3.
        image_num: number of patches to produce.
        mode: ``'augment'`` enables ``data_augment``; any other value disables it.
    """
    print('\ncreating dataset...')
    target_dir = os.path.join(base_path, 'label')
    if not os.path.isdir(target_dir):
        print("samples save path does not exist: {}".format(target_dir))
        sys.exit(-3)

    _, baseNO = get_file(target_dir)
    g_count = baseNO + 1
    count = 0

    # Use the function's own parameters (the original referenced the
    # undefined names ``src_file`` / ``label_file`` here -> NameError).
    src_img = cv2.imread(srcfile)
    label_img = cv2.imread(labelfile, cv2.IMREAD_GRAYSCALE)

    # Check image size (invalid-label check is done per patch below).
    check_src_label_size(src_img, label_img)

    X_height, X_width, _ = src_img.shape

    while count < image_num:
        random_width = random.randint(0, X_width - img_w - 1)
        random_height = random.randint(0, X_height - img_h - 1)
        src_roi = src_img[random_height:random_height + img_h,
                          random_width:random_width + img_w, :]
        label_roi = label_img[random_height:random_height + img_h,
                              random_width:random_width + img_w]

        present = np.unique(label_roi)
        # Ignore patches that touch NoData (any value outside valid_labels).
        if any(value not in valid_labels for value in present):
            continue
        # Ignore patches that are pure background.
        if len(present) < 2 and 0 in present:
            continue

        if mode == 'augment':
            src_roi, label_roi = data_augment(src_roi, label_roi)

        # Scale class ids so the visualization image is viewable.
        visualize = label_roi * 50

        cv2.imwrite((base_path + '/visualize/%d.png' % g_count), visualize)
        cv2.imwrite((base_path + '/src/%d.png' % g_count), src_roi)
        cv2.imwrite((base_path + '/label/%d.png' % g_count), label_roi)
        count += 1
        g_count += 1
Example #4
0
def save_hist_to_csv(in_dir, csv_file, bands, scale):
    """Compute band histograms for all images in ``in_dir`` and save them as CSV.

    ``get_hist`` is called once on the full list of files returned by
    ``get_file(in_dir)`` with the given ``bands`` and ``scale``. Its result
    is wrapped in a pandas DataFrame and written to ``csv_file``, including
    the default index column.
    """
    file_list = get_file(in_dir)[0]
    histogram = get_hist(file_list, bands, scale)
    pd.DataFrame(histogram).to_csv(csv_file)
Example #5
0
def creat_dataset(image_num=50000,
                  mode='original',
                  in_path=input_path,
                  out_path=output_path):
    """Randomly crop src/label patch pairs from every image under ``in_path``.

    Source images are read from ``<in_path>/src/``. Each must have a label
    with the same file name in ``<in_path>/label/``; if one is missing, the
    process exits with -1. Roughly ``image_num / <number of sources>`` patches
    of size ``img_h`` x ``img_w`` are taken per image.

    Every label patch is passed to ``check_invalid_labels``. Patches that are
    pure background (only value 0) are dropped with 80% probability. Patches
    are numbered consecutively and written as
    ``out_path + '{visualize,src,label}/<n>.png'``. This is plain string
    concatenation, so ``out_path`` is expected to end with a separator.

    Args:
        image_num: total number of patches to produce.
        mode: ``'augment'`` enables ``data_augment``; any other value disables it.
        in_path: dataset root containing ``src/`` and ``label/``.
        out_path: output root (see above).
    """
    print('\ncreating dataset...')

    src_files, tt = get_file(os.path.join(in_path, 'src/'))
    assert (tt != 0)

    # Spread the requested total evenly over the source images.
    image_each = image_num / len(src_files)

    g_count = 0
    for scr_file in tqdm(src_files):
        count = 0
        src_img = cv2.imread(scr_file)
        label_file = os.path.join(in_path,
                                  'label/') + os.path.split(scr_file)[1]

        if not os.path.isfile(label_file):
            # The original format string had no placeholder, so the
            # missing file name was never printed.
            print("Have no file:{}".format(label_file))
            sys.exit(-1)

        label_img = cv2.imread(label_file, cv2.IMREAD_GRAYSCALE)

        check_src_label_size(src_img, label_img)

        X_height, X_width, _ = src_img.shape

        while count < image_each:
            random_width = random.randint(0, X_width - img_w - 1)
            random_height = random.randint(0, X_height - img_h - 1)
            src_roi = src_img[random_height:random_height + img_h,
                              random_width:random_width + img_w, :]
            label_roi = label_img[random_height:random_height + img_h,
                                  random_width:random_width + img_w]

            # Check for invalid labels or NoData values.
            check_invalid_labels(label_roi)

            # Drop pure-background patches with 80% probability.
            present = np.unique(label_roi)
            if len(present) < 2 and present[0] == 0 \
                    and np.random.random() < 0.8:
                continue

            if mode == 'augment':
                src_roi, label_roi = data_augment(src_roi, label_roi)

            # Scale class ids so the visualization image is viewable.
            visualize = label_roi * 50

            cv2.imwrite((out_path + 'visualize/%d.png' % g_count), visualize)
            cv2.imwrite((out_path + 'src/%d.png' % g_count), src_roi)
            cv2.imwrite((out_path + 'label/%d.png' % g_count), label_roi)
            count += 1
            g_count += 1
Example #6
0
    def select_invalid_values(self, filepath):
        """Scan label images for values outside ``self.valid_values``.

        Each file returned by ``get_file(filepath)`` is loaded as grayscale
        with ``load_img_by_gdal``. For every value not in ``self.valid_values``
        a warning and the file name are printed.

        If any invalid values were found, they are replaced by
        ``self.make_invalid_to_zeros``. The cleaned image is written next to
        the original as ``new_<name>`` and is the one displayed. Every label
        is then shown with matplotlib.

        ``self.HAS_INVALID_VALUE`` is reset to False at the start of each file
        and again after a cleaned image has been written.
        """
        label_files, total = get_file(filepath)
        assert (total != 0)

        for path in tqdm(label_files):
            label = np.array(load_img_by_gdal(path, grayscale=True))

            self.HAS_INVALID_VALUE = False
            bad_values = []
            for value in np.unique(label):
                if value in self.valid_values:
                    continue
                bad_values.append(value)
                print("\nWarning: some label is not valid value")
                print("\nFile: {}".format(path))
                self.HAS_INVALID_VALUE = True

            if self.HAS_INVALID_VALUE:
                cleaned = self.make_invalid_to_zeros(label, bad_values)
                folder, name = os.path.split(path)
                cv2.imwrite(folder + '/new_' + name, cleaned)
                self.HAS_INVALID_VALUE = False
                label = cleaned

            plt.imshow(label, cmap='gray')
            plt.show()

        print("Check completely!\n")
Example #7
0
    return tp_img



    #
    # for i in range(height):
    #     for j in range(width):
    #         tmp = img[i,j]
    #         if not tmp in true_values:
    #             print("img[{},{}]: {}".format(i,j,tmp))
    #             img[i,j]=0
    # return img


if __name__ == '__main__':
    # Script entry point: iterate over every file found in
    # `input_label_path` (a name defined outside this excerpt).
    files,num = get_file(input_label_path)
    # Abort if the directory contained no files.
    assert (num!=0)

    # valid_labels = []
    # if FLAG_USING_UNET:
    #     valid_labels = unet_labels
    # else:
    #     valid_labels = segnet_labels

    for label_file in tqdm(files):
        # label_file = input_label_path + os.path.split(src_file)[1]
        #
        # ret,src_img = load_img(src_file)
        # assert(ret==0)

        # Load each label as a grayscale image; `ret` presumably is a status
        # code from load_img_by_cv2 -- TODO confirm.
        # NOTE(review): the loop body looks truncated in this excerpt;
        # `ret` and `label_img` are not used in the visible code.
        ret,label_img = load_img_by_cv2(label_file, grayscale=True)
Example #8
0
def predict_multiclass_for_batch_image(input_dict=None):
    """Run a multiclass segmentation model over every image in a directory.

    Keys read from ``input_dict``:
        'GPUID'      -- value assigned to CUDA_VISIBLE_DEVICES.
        'windsize'   -- window size passed to the predictors.
        'im_bands'   -- number of image bands (original strategy only).
        'dtype'      -- UINT8 / UINT10 / UINT16; selects the divisor that
                        scales pixels into [0, 1].
        'strategy'   -- 0: original windowed prediction (mask saved as
                        ``<stem>.png``, multiplied by 128);
                        1: smooth windowing (mask saved as
                        ``<stem>_smooth_pred.png`` with values 0..n_classes).
        'image_dir'  -- directory of input images.
        'model_file' -- model path; the model must have a layer named 'softmax'.
        'mask_dir'   -- output directory for predicted masks.
        'target_num' -- expected number of foreground classes; replaced by
                        (softmax channels - 1) when it disagrees.

    Returns:
        0 after all files were processed. The process exits with -1 when the
        image directory contains no files.
    """
    if input_dict is None:
        input_dict = {}
    gpu_id = input_dict['GPUID']
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id
    window_size = input_dict['windsize']
    im_bands = input_dict['im_bands']
    im_type = input_dict['dtype']
    FLAG_APPROACH_PREDICT = input_dict['strategy']
    input_path = input_dict['image_dir']
    model_file = input_dict['model_file']
    output_path = input_dict['mask_dir']

    out_bands = input_dict['target_num']

    all_files, num = get_file(input_path)
    if num == 0:
        print("There is no file in path:{}".format(input_path))
        sys.exit(-1)

    # The model does not depend on the image: load it once instead of
    # reloading it (and reprinting its summary) for every file.
    model = load_model(model_file)
    print(model.summary())
    layer_dict = {layer.name: layer for layer in model.layers}
    layer_name = 'softmax'  # sigmoid, softmax
    nb_classes = layer_dict[layer_name].output.shape[-1]
    # Foreground classes = softmax channels minus the background channel;
    # the model wins when it disagrees with 'target_num'.
    if out_bands != nb_classes - 1:
        out_bands = nb_classes - 1

    for img_file in all_files:
        print("[INFO] opening image: {}".format(img_file))
        input_img = load_img_by_gdal(img_file)
        # Scale raw pixel values into [0, 1] according to the declared bit depth.
        if im_type == UINT8:
            input_img = input_img / 255.0
        elif im_type == UINT10:
            input_img = input_img / 1024.0
        elif im_type == UINT16:
            input_img = input_img / 65535.0

        input_img = np.clip(input_img, 0.0, 1.0)
        input_img = input_img.astype(np.float16)

        abs_filename = os.path.split(img_file)[1].split(".")[0]

        if FLAG_APPROACH_PREDICT == 0:
            print("[INFO] predict image by orignal approach\n")
            result = orignal_predict_onehot(input_img, im_bands, model,
                                            window_size)
            output_file = ''.join([output_path, '/', abs_filename, '.png'])
            print("result save as to: {}".format(output_file))
            cv2.imwrite(output_file, result * 128)

        elif FLAG_APPROACH_PREDICT == 1:
            print("[INFO] predict image by smooth approach\n")
            result = predict_img_with_smooth_windowing_multiclassbands(
                input_img,
                model,
                window_size=window_size,
                subdivisions=2,
                # output channels = real classes (total classes - background)
                real_classes=out_bands,
                pred_func=smooth_predict_for_multiclass,
                PLOT_PROGRESS=False)

            H, W, _ = input_img.shape
            output_file = ''.join(
                [output_path, '/', abs_filename, '_smooth_pred.png'])
            # Class i+1 wherever channel i is at least 127; later classes
            # overwrite earlier ones where several channels pass.
            output_mask = np.zeros((H, W), np.uint8)
            for i in range(out_bands):
                indx = np.where(result[:, :, i] >= 127)
                output_mask[indx] = i + 1
            print(np.unique(result))
            cv2.imwrite(output_file, output_mask)
            print("Saved to:{}".format(output_file))

        gc.collect()

    return 0
Example #9
0
    FLAG_APPROACH_PREDICT = 1
else:
    pass

date_time = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
output_dir = ''.join([config.mask_dir, '/', date_time])
os.mkdir(output_dir)

if __name__ == '__main__':
    input_files = []
    if os.path.isfile(config.img_input):
        print("[INFO] input is one file...")
        input_files.append(config.img_input)
    elif os.path.isdir(config.img_input):
        print("[INFO] input is a directory...")
        in_files, _ = get_file(config.img_input)
        for file in in_files:
            input_files.append(file)
    print("{} images will be classified".format(len(input_files)))

    # sys.exit(-1)

    out_bands = config.mask_classes
    model = load_model(config.model_path)
    print(model.summary())
    layer_dict = dict([(layer.name, layer) for layer in model.layers])
    layer_name = config.activation  # sigmoid, softmax
    nb_classes = layer_dict[layer_name].output.shape[-1]
    if "sigmoid" in config.activation:
        if target_class != nb_classes:
            print(
def image_normalize(input_dict):
    """Z-score stretch every band of every raster in a directory.

    Keys read from ``input_dict``:
        'input_dir'    -- directory scanned with ``get_file``.
        'output_dir'   -- where GeoTIFFs named ``<stem>.tif`` are written.
        'NoData'       -- a pixel contributes to the band statistics only if
                          0 < value < NoData.
        'OutBits'      -- string containing '8' (Byte output, factor 6) or
                          '16' (UInt16 output, factor 4).
        'StretchRange' -- width of the target value range.
        'CutValue'     -- offset subtracted after stretching.

    For each band, mean and std are computed over the valid pixels, and every
    pixel is mapped to
    ``((x - mean) / std + factor) * StretchRange / (2 * factor) - CutValue``.
    The result is clipped at 0 (and at the uint16 maximum) before the uint16
    cast, and pixels that were 0 in the input stay 0. A band with no valid
    pixels, or with zero spread, is written as all zeros with a warning.
    Geotransform and projection are copied from the input.

    Returns:
        0 after all files have been processed.

    Raises:
        ValueError: if 'OutBits' contains neither '8' nor '16'.
    """
    input_dir = input_dict["input_dir"]
    output_dir = input_dict["output_dir"]
    nodata = input_dict["NoData"]
    result_bits = input_dict["OutBits"]
    valid_range = input_dict["StretchRange"]
    cut_value = input_dict["CutValue"]

    src_files, tt = get_file(input_dir)
    assert (tt != 0)

    if '8' in result_bits:
        assert (valid_range < 256)
        factor = 6.0
        gdal_type = gdal.GDT_Byte
    elif '16' in result_bits:
        assert (valid_range < 65536)
        factor = 4.0
        gdal_type = gdal.GDT_UInt16
    else:
        # Previously this fell through and failed later with an
        # UnboundLocalError on the output dataset.
        raise ValueError(
            "OutBits must contain '8' or '16', got: {}".format(result_bits))

    uint16_max = np.iinfo(np.uint16).max
    driver = gdal.GetDriverByName("GTiff")

    for file in tqdm(src_files):
        absname = os.path.split(file)[1].split('.')[0] + '.tif'
        print(absname)
        if not os.path.isfile(file):
            print("input file dose not exist:{}\n".format(file))
            continue

        dataset = gdal.Open(file)
        if dataset is None:
            print("Open file failed: {}".format(file))
            continue

        height = dataset.RasterYSize
        width = dataset.RasterXSize
        im_bands = dataset.RasterCount
        img = dataset.ReadAsArray(0, 0, width, height)
        geotransform = dataset.GetGeoTransform()
        projection = dataset.GetProjection()
        del dataset

        img = np.array(img, np.float32)
        # ReadAsArray returns (H, W) for single-band rasters; add a band axis
        # so img[i] is always a whole band.
        if img.ndim == 2:
            img = img[np.newaxis, ...]

        result = []
        for i in range(im_bands):
            data = img[i]
            print("\nOriginal max, min, mean,std:[{},{},{},{}]".format(
                data.max(), data.min(), data.mean(), data.std()))
            data = data.reshape(height * width)

            valid = data[(data > 0) & (data < nodata)]
            print("valid value number: {}".format(valid.size))
            if valid.size == 0 or valid.std() == 0:
                print("Warning: band {} has no usable valid pixels, "
                      "writing zeros".format(i + 1))
                result.append(np.zeros((height, width), np.uint16))
                continue

            tmean = valid.mean()
            tstd = valid.std()
            stretched = (data - tmean) / tstd  # first Z-score normalization
            stretched = ((stretched + factor) * valid_range / (2 * factor)
                         - cut_value)
            # Negative floats wrap around when cast to an unsigned type;
            # clip into the representable range first.
            stretched = np.clip(stretched, 0, uint16_max)
            out = stretched.astype(np.uint16)
            out[data == 0] = 0  # keep the input's zero (background) pixels

            print("New max, min, mean,std:[{},{},{},{}]".format(
                out.max(), out.min(), out.mean(), out.std()))
            result.append(out.reshape((height, width)))

        outputfile = os.path.join(output_dir, absname)
        outdataset = driver.Create(outputfile, width, height, im_bands,
                                   gdal_type)
        if outdataset is None:
            print("Create file failed: {}".format(outputfile))
            continue
        outdataset.SetGeoTransform(geotransform)
        if projection:
            outdataset.SetProjection(projection)

        for i, band in enumerate(result):
            outdataset.GetRasterBand(i + 1).WriteArray(band)

        del outdataset

    return 0
Example #11
0
def predict_binary_for_batch_image(input_dict=None):
    """Run a binary segmentation model over every image in a directory.

    Keys read from ``input_dict``:
        'GPUID'      -- value assigned to CUDA_VISIBLE_DEVICES.
        'windsize'   -- window size passed to the predictors.
        'im_bands'   -- number of image bands (original strategy only).
        'dtype'      -- UINT8 / UINT10 / UINT16; selects the divisor that
                        scales pixels into [0, 1].
        'strategy'   -- 0: original windowed prediction (``<stem>.png``,
                        multiplied by 128); 1: smooth windowing
                        (``<stem>smooth.png``).
        'image_dir'  -- directory of input images.
        'model_file' -- model path; the model must have a layer named 'sigmoid'.
        'mask_dir'   -- output directory for predicted masks.
        'onehot'     -- truthy if the model output is one-hot encoded.

    Returns:
        0 after all files were processed. The process exits with -1 when the
        image directory contains no files.

    Raises:
        KeyError: if the model has no layer named 'sigmoid'.
    """
    if input_dict is None:
        input_dict = {}
    gpu_id = input_dict['GPUID']
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id
    window_size = input_dict['windsize']
    im_bands = input_dict['im_bands']
    im_type = input_dict['dtype']
    FLAG_APPROACH_PREDICT = input_dict['strategy']
    input_path = input_dict['image_dir']
    model_file = input_dict['model_file']
    output_path = input_dict['mask_dir']

    out_bands = 1
    FLAG_ONEHOT = 1 if input_dict['onehot'] else 0

    all_files, num = get_file(input_path)
    if num == 0:
        print("There is no file in path:{}".format(input_path))
        sys.exit(-1)

    # The model does not depend on the image: load it once instead of
    # reloading it (and reprinting its summary) for every file.
    model = load_model(model_file)
    print(model.summary())
    # Fail fast, as the original layer lookup did, if there is no 'sigmoid'
    # layer.
    if 'sigmoid' not in {layer.name for layer in model.layers}:
        raise KeyError('sigmoid')

    for img_file in all_files:
        print("[INFO] opening image...")
        print("FileName:{}".format(img_file))
        input_img = load_img_by_gdal(img_file)
        # Scale raw pixel values into [0, 1] according to the declared bit depth.
        if im_type == UINT8:
            input_img = input_img / 255.0
        elif im_type == UINT10:
            input_img = input_img / 1024.0
        elif im_type == UINT16:
            input_img = input_img / 65535.0

        input_img = np.clip(input_img, 0.0, 1.0)
        input_img = input_img.astype(np.float16)

        abs_filename = os.path.split(img_file)[1].split(".")[0]

        if FLAG_APPROACH_PREDICT == 0:
            print("[INFO] predict image by orignal approach\n")
            predictor = (orignal_predict_onehot if FLAG_ONEHOT
                         else orignal_predict_notonehot)
            result = predictor(input_img, im_bands, model, window_size)

            output_file = ''.join([output_path, '/', abs_filename, '.png'])
            print("result save as to: {}".format(output_file))
            cv2.imwrite(output_file, result * 128)

        elif FLAG_APPROACH_PREDICT == 1:
            print("[INFO] predict image by smooth approach\n")
            pred_func = (smooth_predict_for_binary_onehot if FLAG_ONEHOT
                         else smooth_predict_for_binary_notonehot)
            result = predict_img_with_smooth_windowing_multiclassbands(
                input_img,
                model,
                window_size=window_size,
                subdivisions=2,
                # output channels = real classes (total classes - background)
                real_classes=out_bands,
                pred_func=pred_func,
                PLOT_PROGRESS=False)

            output_file = ''.join(
                [output_path, '/', abs_filename, 'smooth.png'])
            cv2.imwrite(output_file, result)
            print("Saved to: {}".format(output_file))

        gc.collect()

    return 0
def creat_dataset_binary(labelfile,
                         srcfile,
                         base_path,
                         image_num=5000,
                         mode='original'):
    """Produce binary road and building patch datasets from one image pair.

    Two binary masks are derived from the label image: value 1 (roads) and
    value 2 (buildings). For each mask, ``image_num`` random
    ``img_h`` x ``img_w`` patches are written below ``<base_path>/roads`` or
    ``<base_path>/buildings`` (sub-folders ``visualize``, ``src``, ``label``).
    Numbering in each class continues after the files already present in its
    ``label`` folder. If a class's ``label`` folder is missing, the process
    exits with -3.

    Args:
        labelfile: path of the multiclass label image (read as grayscale).
        srcfile: path of the source image (read with OpenCV defaults).
        base_path: root that contains the ``roads`` and ``buildings`` folders.
        image_num: number of patches per class.
        mode: ``'augment'`` enables ``data_augment``; any other value disables it.
    """
    print('\ncreating dataset...')
    road_dir = os.path.join(base_path, 'roads')
    # Check the roads folder before spending time reading the images.
    road_start = _binary_samples_start_index(road_dir)

    src_img = cv2.imread(srcfile)
    label_img = cv2.imread(labelfile, cv2.IMREAD_GRAYSCALE)
    # Check image size (invalid labels are checked per patch).
    check_src_label_size(src_img, label_img)

    print("\n1: produce road labels---------------------")
    road_label = (label_img == 1).astype(np.uint8)  # 1: roads
    print(np.unique(road_label))
    _write_binary_class_patches(src_img, label_img, road_label, road_dir,
                                road_start, image_num, mode)

    print("\n2: produce buildings labels---------------------")
    building_label = (label_img == 2).astype(np.uint8)  # 2: buildings
    building_dir = os.path.join(base_path, 'buildings')
    building_start = _binary_samples_start_index(building_dir)
    _write_binary_class_patches(src_img, label_img, building_label,
                                building_dir, building_start, image_num, mode)


def _binary_samples_start_index(class_dir):
    """Return the first free sample number for ``class_dir`` (exit -3 if no label dir)."""
    target_dir = os.path.join(class_dir, 'label')
    if not os.path.isdir(target_dir):
        print("samples save path does not exist: {}".format(target_dir))
        sys.exit(-3)
    _, base_no = get_file(target_dir)
    return base_no + 1


def _write_binary_class_patches(src_img, label_img, class_mask, class_dir,
                                start_index, image_num, mode):
    """Write ``image_num`` random src/mask patches to ``class_dir``.

    A patch is rejected when the original ``label_img`` contains a value
    outside ``valid_labels`` in that window (NoData), or when ``class_mask``
    is all background there. Files are written as
    ``class_dir/{visualize,src,label}/<n>.png`` numbered from ``start_index``.
    """
    X_height, X_width = src_img.shape[:2]
    g_count = start_index
    count = 0
    while count < image_num:
        random_width = random.randint(0, X_width - img_w - 1)
        random_height = random.randint(0, X_height - img_h - 1)
        rows = slice(random_height, random_height + img_h)
        cols = slice(random_width, random_width + img_w)
        src_roi = src_img[rows, cols, :]
        label_roi = class_mask[rows, cols]

        # Ignore NoData areas, judged on the full multiclass label.
        if any(value not in valid_labels
               for value in np.unique(label_img[rows, cols])):
            continue
        # Ignore patches that are pure background for this class.
        present = np.unique(label_roi)
        if len(present) < 2 and 0 in present:
            continue

        if mode == 'augment':
            src_roi, label_roi = data_augment(src_roi, label_roi)

        visualize = label_roi * 50

        cv2.imwrite((class_dir + '/visualize/%d.png' % g_count), visualize)
        cv2.imwrite((class_dir + '/src/%d.png' % g_count), src_roi)
        cv2.imwrite((class_dir + '/label/%d.png' % g_count), label_roi)
        count += 1
        g_count += 1
def creat_dataset_binary(in_path, out_path, image_num=50000, mode='original'):
    """Produce binary road/building patch datasets from all pairs under ``in_path``.

    Each label image in ``<in_path>/label/`` is paired with the same-named
    source image in ``<in_path>/src/``; pairs without a source are skipped.
    From each label, a roads mask (value 1) and a buildings mask (value 2) are
    derived. Roughly ``image_num / <number of labels>`` patches per class are
    written below ``out_path + '/roads'`` and ``out_path + '/buildings'``
    (sub-folders ``visualize``, ``src``, ``label``). One running counter is
    shared across both classes and all pairs.

    Args:
        in_path: dataset root containing ``src/`` and ``label/``.
        out_path: output root containing ``roads`` and ``buildings``.
        image_num: total number of patches per class.
        mode: ``'augment'`` enables ``data_augment``; any other value disables it.
    """
    print('\ncreating dataset...')

    label_files, tt = get_file(os.path.join(in_path, 'label/'))
    assert (tt != 0)

    image_each = image_num / len(label_files)

    g_count = 0
    for label_file in tqdm(label_files):

        src_file = os.path.join(in_path, 'src/') + os.path.split(label_file)[1]
        if not os.path.isfile(src_file):
            # The original format string had no placeholder, so the
            # missing file name was never printed.
            print("Have no file:{}".format(src_file))
            continue

        print("src file:{}".format(os.path.split(src_file)[1]))
        src_img = cv2.imread(src_file)
        label_img = cv2.imread(label_file, cv2.IMREAD_GRAYSCALE)
        # Check image size (invalid labels are checked per patch).
        check_src_label_size(src_img, label_img)

        print("\n1: produce road labels---------------------")
        road_label = (label_img == 1).astype(np.uint8)  # 1: roads
        print(np.unique(road_label))
        g_count = _sample_binary_patches_to(out_path + '/roads', src_img,
                                            label_img, road_label,
                                            image_each, mode, g_count)

        print("\n2: produce buildings labels---------------------")
        building_label = (label_img == 2).astype(np.uint8)  # 2: buildings
        g_count = _sample_binary_patches_to(out_path + '/buildings', src_img,
                                            label_img, building_label,
                                            image_each, mode, g_count)


def _sample_binary_patches_to(class_dir, src_img, label_img, class_mask,
                              image_each, mode, g_count):
    """Write random patches for one binary class and return the next sample number.

    Patches are drawn until ``image_each`` have been accepted. A patch is
    rejected when ``label_img`` holds a value outside ``valid_labels`` in that
    window (NoData), or when ``class_mask`` is all background there. Files
    go to ``class_dir/{visualize,src,label}/<g_count>.png``.
    """
    X_height, X_width = src_img.shape[:2]
    count = 0
    while count < image_each:
        random_width = random.randint(0, X_width - img_w - 1)
        random_height = random.randint(0, X_height - img_h - 1)
        rows = slice(random_height, random_height + img_h)
        cols = slice(random_width, random_width + img_w)
        src_roi = src_img[rows, cols, :]
        label_roi = class_mask[rows, cols]

        # Ignore NoData areas, judged on the full multiclass label.
        if any(value not in valid_labels
               for value in np.unique(label_img[rows, cols])):
            continue
        # Ignore pure-background patches for this class.
        present = np.unique(label_roi)
        if len(present) < 2 and 0 in present:
            continue

        if mode == 'augment':
            src_roi, label_roi = data_augment(src_roi, label_roi)

        visualize = label_roi * 50

        cv2.imwrite((class_dir + '/visualize/%d.png' % g_count), visualize)
        cv2.imwrite((class_dir + '/src/%d.png' % g_count), src_roi)
        cv2.imwrite((class_dir + '/label/%d.png' % g_count), label_roi)
        count += 1
        g_count += 1
    return g_count
Example #14
0
def produce_training_samples_multiclass(in_path,
                                        out_path,
                                        image_num=50000,
                                        mode='original'):
    """Randomly crop multiclass training patches from paired src/label images.

    Each label image in ``<in_path>/label/`` is paired with the same-named
    source raster in ``<in_path>/src/``, which is read through GDAL so any
    band count and data type is kept. Pairs with a missing or unreadable file,
    or with mismatching sizes, are skipped. About
    ``image_num / <number of label files>`` patches of size ``img_h`` x
    ``img_w`` are taken per pair. Patches whose label contains values outside
    ``valid_labels`` (NoData), or that are only background (0), are rejected.

    Outputs, numbered consecutively over all pairs:
        <out_path>/label/<n>.png      label patch (OpenCV)
        <out_path>/visualize/<n>.png  label patch * 50
        <out_path>/src/<n>.png        source patch written with the GTiff
                                      driver (despite the .png extension)
    The process exits with -2 if an output raster cannot be created.
    """
    print('\ncreating dataset...')

    label_files, tt = get_file(os.path.join(in_path, 'label/'))
    assert (tt != 0)

    image_each = image_num / len(label_files)
    driver = gdal.GetDriverByName("GTiff")

    g_count = 0
    for label_file in tqdm(label_files):

        src_file = os.path.join(in_path, 'src/') + os.path.split(label_file)[1]
        if not os.path.isfile(src_file):
            # The original format string had no placeholder.
            print("Have no file:{}".format(src_file))
            continue

        print("src file:{}".format(os.path.split(src_file)[1]))

        label_img = cv2.imread(label_file, cv2.IMREAD_GRAYSCALE)
        if label_img is None:
            print("open failed: {}\n".format(label_file))
            continue

        dataset = gdal.Open(src_file)
        if dataset is None:
            print("open failed!\n")
            continue

        X_height = dataset.RasterYSize
        X_width = dataset.RasterXSize
        im_bands = dataset.RasterCount
        data_type = dataset.GetRasterBand(1).DataType

        # check size of label and src images
        x, y = label_img.shape
        print("Heigh, width of label is :{}, {}".format(x, y))
        print("Heigh, width of src is :{}, {}".format(X_height, X_width))
        if x != X_height or y != X_width:
            print("Warning: src and label have different size!")
            continue

        src_img = np.array(dataset.ReadAsArray(0, 0, X_width, X_height))
        del dataset

        count = 0
        while count < image_each:
            random_width = random.randint(0, X_width - img_w - 1)
            random_height = random.randint(0, X_height - img_h - 1)
            rows = slice(random_height, random_height + img_h)
            cols = slice(random_width, random_width + img_w)
            # ``...`` keeps an optional leading band axis: GDAL returns
            # (H, W) for single-band rasters and (bands, H, W) otherwise.
            # The original ``[:, rows, cols]`` raised IndexError on 2-D input.
            src_roi = src_img[..., rows, cols]
            label_roi = label_img[rows, cols]

            present = np.unique(label_roi)
            # Ignore NoData areas (values outside valid_labels).
            if any(value not in valid_labels for value in present):
                continue
            # Ignore pure-background patches.
            if len(present) < 2 and 0 in present:
                continue

            if mode == 'augment':
                src_roi, label_roi = data_augment(src_roi, label_roi,
                                                  data_type)

            visualize = label_roi * 50

            cv2.imwrite((out_path + '/visualize/%d.png' % g_count), visualize)
            cv2.imwrite((out_path + '/label/%d.png' % g_count), label_roi)

            src_sample_file = out_path + '/src/%d.png' % g_count
            outdataset = driver.Create(src_sample_file, img_w, img_h, im_bands,
                                       data_type)
            if outdataset is None:
                print("create dataset failed!\n")
                sys.exit(-2)
            if im_bands == 1:
                outdataset.GetRasterBand(1).WriteArray(src_roi)
            else:
                for i in range(im_bands):
                    outdataset.GetRasterBand(i + 1).WriteArray(src_roi[i])
            del outdataset

            count += 1
            g_count += 1
Example #15
0
HAS_INVALID_VALUE = False


def make_label_valid(img, true_values):
    """Set every pixel whose value is not in ``true_values`` to 0, in place.

    Each offending pixel is reported as ``img[row,col]: value`` in row-major
    order before it is cleared.

    Args:
        img: 2-D array of label values (modified in place).
        true_values: collection of allowed pixel values (list, tuple, set...).

    Returns:
        The same ``img`` object, after modification.
    """
    # Vectorized membership test instead of a Python loop over every pixel.
    # list() so that sets are treated element-wise (np.isin would otherwise
    # see a set as a single object).
    invalid = ~np.isin(img, list(true_values))
    # argwhere yields coordinates in row-major order, matching the original
    # nested i/j loop and its print order.
    for i, j in np.argwhere(invalid):
        print("img[{},{}]: {}".format(i, j, img[i, j]))
        img[i, j] = 0
    return img


if __name__ == '__main__':
    # Script entry: iterate over the source images and load each one together
    # with the label image of the same file name.
    # NOTE(review): input_src_path, input_label_path, FLAG_USING_UNET,
    # unet_labels and segnet_labels are module-level names not visible here.
    src_files,num = get_file(input_src_path)
    assert (num!=0)

    # Pick the set of allowed label values for the network type in use.
    valid_labels = []
    if FLAG_USING_UNET:
        valid_labels = unet_labels
    else:
        valid_labels = segnet_labels

    for src_file in tqdm(src_files):
        # Label file = label directory + the source file's base name
        # (plain string concatenation, so input_label_path presumably ends with '/').
        label_file = input_label_path + os.path.split(src_file)[1]

        ret,src_img = load_img(src_file)
        assert(ret==0)

        # NOTE(review): the loop body appears truncated here; label_img is
        # loaded but not used in the visible lines.
        ret,label_img = load_img(label_file, grayscale=True)
def predict_backend(input_dict):
    """Run segmentation prediction over one image or every image of a directory.

    Keys read from ``input_dict``:
      - 'gpu': value assigned to the CUDA_VISIBLE_DEVICES environment variable,
      - 'config': path of a JSON file whose contents build a ``Config``,
      - 'input': an image file or a directory of images,
      - 'output': parent directory; a timestamped sub-directory is created in it,
      - 'model': model path handed to ``load_model``.

    Each image is predicted in horizontal strips of full-width rows whose
    height is derived from ``config.block_size``. The resulting class mask is
    written as a single-band GTiff (Byte) carrying the source geotransform,
    and is optionally polygonized to a shapefile when ``config.tovector`` is set.

    Returns:
        0 when all images have been processed. Calls ``sys.exit`` on
        unrecoverable errors (no input images, model load failure, invalid
        band list, output dataset creation failure).
    """
    gpu_id = input_dict['gpu']
    print("gpu_id:{}".format(gpu_id))
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id

    config_file = input_dict['config']
    print("cofig file:{}".format(config_file))
    with open(config_file, 'r') as f:
        cfgl = json.load(f)

    config = Config(**cfgl)
    print(config)

    input_path = input_dict['input']
    output = input_dict['output']

    # Bit depth of the source imagery selects the normalisation divisor below.
    im_type = UINT8
    if "10" in config.im_type:
        im_type = UINT10
    elif "16" in config.im_type:
        im_type = UINT16

    # Output channels of the model: for multiclass the background is not
    # predicted (total classes - 1); a 2-class setup is handled as binary.
    target_class = config.nb_classes
    if target_class > 1:
        if target_class == 2:
            print("Warning: target classes should not be 2, this must be binary classification!")
            target_class = 1
        else:
            target_class -= 1

    # 0: original predict, 1: smooth predict
    FLAG_APPROACH_PREDICT = 1 if "smooth" in config.strategy else 0

    date_time = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
    output_dir = ''.join([output, '/', date_time])
    os.mkdir(output_dir)

    block_size = config.block_size

    input_files = []
    if os.path.isfile(input_path):
        print("[INFO] input is one file...")
        input_files.append(input_path)
    elif os.path.isdir(input_path):
        print("[INFO] input is a directory...")
        in_files, _ = get_file(input_path)
        input_files.extend(in_files)
    if not input_files:
        print("no input images")
        sys.exit(-1)
    print("{} images will be classified".format(len(input_files)))

    # Store the configuration used for this run next to the results.
    csv_file = os.path.join(output_dir, 'readme.csv')
    df = pd.DataFrame(list(config))
    df.to_csv(csv_file)

    out_bands = target_class
    try:
        model = load_model(input_dict['model'], compile=False)
    except ValueError:
        # Presumably raised for models containing custom layers (DeepLab V3+).
        print("Warning: there are several custom objects in model")
        print("For deeplab V3+, load model with parameters of custom_objects\n")
        model = load_model(input_dict['model'],
                           custom_objects={'relu6': relu6, 'BilinearUpsampling': BilinearUpsampling},
                           compile=False)
    except Exception:
        print("Error: failde to load model!\n")
        sys.exit(-1)
    else:
        print("model is not deeplab V3+!\n")

    for img_file in tqdm(input_files):
        print("\n[INFO] opening image:{}...".format(img_file))
        abs_filename = os.path.split(img_file)[1]
        H, W, C, geoinf = load_img_by_gdal_info(img_file)
        if H == 0:
            print("Open failed:{}".format(abs_filename))
            continue
        gc.collect()

        # The band selection only depends on the image's band count, so it is
        # validated once per image instead of once per strip.
        band_list = config.band_list
        if len(band_list) == 0:
            band_list = range(C)
        if len(band_list) > C or max(band_list) >= C:
            print("input bands should not be bigger than image bands!")
            sys.exit(-2)
        band_index = list(band_list)

        # Strip layout: block_h full-width rows per strip (at least one row).
        # The strip count is derived from block_h itself; deriving it from
        # H*W/block_size left trailing rows unpredicted whenever block_h was
        # truncated.
        block_h = max(1, int(block_size / W))
        nb_blocks = (H + block_h - 1) // block_h
        print("single block size :[{},{}]".format(block_h, W))
        result_mask = np.zeros((H, W), np.uint8)
        for blk in tqdm(list(range(nb_blocks))):
            print("[INFO] predict image for {} block".format(blk))
            start = block_h * blk
            this_h = min(block_h, H - start)
            end = start + this_h
            # window_size extra rows below the strip are read as context.
            b_img = load_img_by_gdal_blocks(img_file, 0, start, W, this_h + config.window_size)

            if blk == nb_blocks - 1:
                # NOTE(review): assumes the helper returns only the existing
                # this_h rows at the bottom edge; they are zero-padded here.
                tmp_img = np.zeros((this_h + config.window_size, W, C), np.uint16)
                tmp_img[:this_h, :, :] = b_img
            else:
                tmp_img = b_img

            # Keep only the configured bands (channel axis is the last one).
            input_img = tmp_img[:, :, band_index].astype(np.float16)

            if im_type == UINT8:
                input_img = input_img / 255.0
            elif im_type == UINT10:
                input_img = input_img / 1024.0
            elif im_type == UINT16:
                input_img = input_img / 65535.0

            input_img = np.clip(input_img, 0.0, 1.0).astype(np.float16)

            if FLAG_APPROACH_PREDICT == 0:
                print("[INFO] predict image by orignal approach ...")
                num_of_bands = input_img.shape[2]
                result = core_orignal_predict(input_img, num_of_bands, model, config.window_size, config.img_w,
                                              mask_bands=config.nb_classes)
                result_mask[start:end, :] = result[:this_h, :]

            elif FLAG_APPROACH_PREDICT == 1:
                print("[INFO] predict image by smooth approach... ")
                output_mask = np.zeros((this_h + config.window_size, W), np.uint8)
                if out_bands > 1:
                    result = predict_img_with_smooth_windowing(
                        input_img,
                        model,
                        window_size=config.window_size,
                        subdivisions=config.subdivisions,
                        slices=config.slices,
                        real_classes=target_class,  # output channels = real classes, i.e. total classes - background
                        pred_func=core_smooth_predict_multiclass,
                        PLOT_PROGRESS=False
                    )
                    # Channel k >= 127 marks class k+1; later channels overwrite earlier ones.
                    for cls in range(target_class):
                        output_mask[np.where(result[:, :, cls] >= 127)] = cls + 1
                    del result
                else:
                    result = predict_img_with_smooth_windowing(
                        input_img,
                        model,
                        window_size=config.window_size,
                        subdivisions=config.subdivisions,
                        slices=config.slices,
                        real_classes=target_class,
                        pred_func=core_smooth_predict_binary,
                        PLOT_PROGRESS=False
                    )
                    output_mask[np.where(result[:, :, 0] >= 127)] = 1
                gc.collect()

                # Drop the context rows before storing the strip.
                result_mask[start:end, :] = output_mask[:this_h, :]
                gc.collect()

            del b_img
            gc.collect()

        print(np.unique(result_mask))
        output_file = ''.join([output_dir, '/', abs_filename])
        driver = gdal.GetDriverByName("GTiff")
        outdataset = driver.Create(output_file, W, H, 1, gdal.GDT_Byte)
        # Check before touching the dataset: Create returns None on failure.
        if outdataset is None:
            print("create dataset failed!\n")
            sys.exit(-2)
        outdataset.SetGeoTransform(geoinf)
        outdataset.GetRasterBand(1).WriteArray(result_mask)
        del outdataset
        del result_mask
        gc.collect()
        print("Saved to:{}".format(output_file))

        # Optional vector output derived from the raster mask.
        if config.tovector:
            shp_file = ''.join([output_dir, '/', abs_filename, '.shp'])
            polygonize(output_file, shp_file)
    return 0
Example #17
0
def convert_all_image_to_16bits():
    """Stretch every image in ``input_path`` into a 16-bit GTiff in ``output_path``.

    Per band, the mean/std are taken over valid pixels (``0 < value < NoData``).
    Each pixel is Z-scored with them and mapped by ``(z + 4) * 1024 / 8 - 100``.
    The result is clipped to the uint16 range, and pixels that were exactly 0
    stay 0. The output keeps the input base name (text before the first '.')
    with a '.png' extension, although it is written with the GTiff driver.

    Uses the module-level names ``input_path``, ``output_path`` and ``NoData``.
    """
    src_files, nb_files = get_file(input_path)
    assert (nb_files != 0)

    for file in tqdm(src_files):
        absname = os.path.split(file)[1].split('.')[0] + '.png'
        print(absname)
        if not os.path.isfile(file):
            print("input file dose not exist:{}\n".format(file))
            continue

        dataset = gdal.Open(file)
        if dataset is None:
            print("Open file failed: {}".format(file))
            continue

        height = dataset.RasterYSize
        width = dataset.RasterXSize
        im_bands = dataset.RasterCount
        img = dataset.ReadAsArray(0, 0, width, height)
        del dataset
        img = np.array(img, np.float32)

        result = []
        for i in range(im_bands):
            # A single-band read is already 2-D; otherwise take band i.
            data = img if im_bands == 1 else np.array(img[i])
            print(data.max(), data.min(), data.mean(), data.std())
            data = data.reshape(height * width)

            valid = data[(data > 0) & (data < NoData)]
            print("valid value number: {}\n".format(valid.size))
            if valid.size == 0:
                # No statistics can be computed: emit an all-nodata (0) band.
                print("no valid pixel in band {} of {}, band set to 0".format(i + 1, file))
                result.append(np.zeros((height, width), np.uint16))
                continue

            tmean = valid.mean()
            tstd = valid.std()
            print(valid.max(), valid.min(), tmean, tstd)
            if tstd == 0:
                tstd = 1.0  # constant band: avoid dividing by zero

            stretched = (data - tmean) / tstd  # Z-score normalization
            stretched = (stretched + 4) * 1024 / 8.0 - 100
            # Clip before the cast: negative/oversized floats would otherwise
            # wrap around (or be undefined) when converted to uint16.
            stretched = np.clip(stretched, 0, np.iinfo(np.uint16).max).astype(np.uint16)
            stretched[data == 0] = 0
            print(stretched.max(), stretched.min(), stretched.mean(), stretched.std())

            result.append(stretched.reshape((height, width)))

        outputfile = os.path.join(output_path, absname)
        driver = gdal.GetDriverByName("GTiff")
        outdataset = driver.Create(outputfile, width, height, im_bands,
                                   gdal.GDT_UInt16)
        if outdataset is None:
            print("create dataset failed: {}".format(outputfile))
            continue

        for band_no, band in enumerate(result, start=1):
            outdataset.GetRasterBand(band_no).WriteArray(band)

        del outdataset
Example #18
0
        cmd.append('{}'.format(bmin))
        cmd.append('{}'.format(bmax))
        cmd.append('{}'.format(0))
        cmd.append('{}'.format(255))

    cmd.append(inputRaster)
    cmd.append(outputRaster)
    print("Conversin command:", cmd)
    subprocess.call(cmd)


if __name__ == '__main__':
    # Convert every image of `inputdir` to 8-bit, keeping the file name,
    # into `outputdir` (created on demand).
    if not os.path.isdir(inputdir):
        print("Please check input directory:{}".format(inputdir))
        sys.exit(-1)

    if not os.path.isdir(outputdir):
        print('Warning: output directory is not existed')
        os.mkdir(outputdir)

    input_images = get_file(inputdir)[0]
    for src_path in input_images:
        dst_path = os.path.join(outputdir, os.path.basename(src_path))
        convert_to_8Bit2(
            src_path,
            dst_path,
            outputDataType='Byte',
            stretch_type='rescale',
            nodata=65535,
            percentiles=[1, 99.9],
        )
Example #19
0
def creat_dataset_binary(in_path, out_path, image_num=50000, mode='original'):
    """Produce binary road and building training samples from label/src pairs.

    Label images are read from ``in_path/label/`` and the matching source
    images (same file name) from ``in_path/src/``. Label value 1 yields the
    road samples (``out_path/roads/``) and label value 2 the building samples
    (``out_path/buildings/``). About ``image_num / len(label_files)`` random
    crops are cut per image and per class. ``mode='augment'`` passes each crop
    through ``data_augment``.

    Uses the module-level names ``img_w``, ``img_h`` and ``valid_labels``.
    """
    print('\ncreating dataset...')

    label_files, nb_files = get_file(os.path.join(in_path, 'label/'))
    assert (nb_files != 0)

    image_each = image_num / len(label_files)

    print("\n1: produce road labels---------------------")
    _produce_binary_class_samples(label_files, in_path, out_path, image_each,
                                  mode, class_value=1, sub_dir='roads')

    print("\n2: produce buildings labels---------------------")
    _produce_binary_class_samples(label_files, in_path, out_path, image_each,
                                  mode, class_value=2, sub_dir='buildings')


def _produce_binary_class_samples(label_files, in_path, out_path, image_each,
                                  mode, class_value, sub_dir):
    """Cut random crops for one class (pixels == class_value -> 1, else 0) into out_path/sub_dir."""
    g_count = 0
    for label_file in tqdm(label_files):
        src_file = os.path.join(in_path, 'src/') + os.path.split(label_file)[1]
        if not os.path.isfile(src_file):
            print("Have no file:{}".format(src_file))
            continue

        print("src file:{}".format(os.path.split(src_file)[1]))

        label_img = cv2.imread(label_file, cv2.IMREAD_GRAYSCALE)
        if label_img is None:
            print("open failed:{}".format(label_file))
            continue

        dataset = gdal.Open(src_file)
        if dataset is None:
            print("open failed!\n")
            continue

        X_height = dataset.RasterYSize
        X_width = dataset.RasterXSize
        if label_img.shape != (X_height, X_width):
            print("label and source image have different size:{}".format(label_file))
            continue
        if X_width < img_w or X_height < img_h:
            print("image is smaller than the sample window:{}".format(src_file))
            continue

        im_bands = dataset.RasterCount
        data_type = dataset.GetRasterBand(1).DataType
        src_img = np.array(dataset.ReadAsArray(0, 0, X_width, X_height))
        del dataset

        class_label = (label_img == class_value).astype(np.uint8)
        print(np.unique(class_label))
        if not class_label.any():
            # Every crop would be pure background and the sampling loop
            # below could never terminate.
            print("no pixel of class {} in {}, skipped".format(class_value, label_file))
            continue

        count = 0
        while count < image_each:
            random_width = random.randint(0, X_width - img_w)
            random_height = random.randint(0, X_height - img_h)
            rows = slice(random_height, random_height + img_h)
            cols = slice(random_width, random_width + img_w)
            # Single-band rasters are read as 2-D arrays.
            if src_img.ndim == 2:
                src_roi = src_img[rows, cols]
            else:
                src_roi = src_img[:, rows, cols]
            label_roi = class_label[rows, cols]

            # Ignore crops touching a label value outside valid_labels (nodata).
            if any(v not in valid_labels for v in np.unique(label_img[rows, cols])):
                continue
            # Ignore pure background crops.
            uniq = np.unique(label_roi)
            if len(uniq) < 2 and 0 in uniq:
                continue

            if mode == 'augment':
                src_roi, label_roi = data_augment(src_roi, label_roi,
                                                  data_type)

            cv2.imwrite('{}/{}/label/{}.png'.format(out_path, sub_dir, g_count),
                        label_roi)

            # Source crop: GTiff data despite the .png name (kept for compatibility).
            src_sample_file = '{}/{}/src/{}.png'.format(out_path, sub_dir, g_count)
            driver = gdal.GetDriverByName("GTiff")
            outdataset = driver.Create(src_sample_file, img_w, img_h, im_bands,
                                       data_type)
            if outdataset is None:
                print("create dataset failed!\n")
                sys.exit(-2)
            if im_bands == 1:
                outdataset.GetRasterBand(1).WriteArray(src_roi)
            else:
                for i in range(im_bands):
                    outdataset.GetRasterBand(i + 1).WriteArray(src_roi[i])
            del outdataset

            count += 1
            g_count += 1
Example #20
0
def convert_all_image_to_8bits():
    """Stretch every image in ``input_path`` into an 8-bit GTiff in ``output_path``.

    Per band, the mean/std are taken over valid pixels (``0 < value < NoData``).
    Each pixel is Z-scored and mapped by ``(z + 4) * 255 / 8 - 255``, then
    converted to uint8. Pixels that were exactly 0 stay 0. The output file
    keeps the input file name.

    Uses the module-level names ``input_path``, ``output_path`` and ``NoData``.
    """
    src_files, nb_files = get_file(input_path)
    assert (nb_files != 0)

    for file in tqdm(src_files):
        absname = os.path.split(file)[1]
        print(absname)
        if not os.path.isfile(file):
            print("input file dose not exist:{}\n".format(file))
            continue

        dataset = gdal.Open(file)
        if dataset is None:
            print("Open file failed: {}".format(file))
            continue
        height = dataset.RasterYSize
        width = dataset.RasterXSize
        im_bands = dataset.RasterCount
        img = dataset.ReadAsArray(0, 0, width, height)
        del dataset
        img = np.array(img, np.uint16)

        result = []
        for i in range(im_bands):
            # A single-band read is 2-D, so img[i] would be a row, not a band.
            data = img if im_bands == 1 else np.array(img[i])
            data = data.reshape(height * width)

            valid = data[(data > 0) & (data < NoData)]
            print(valid.size)
            if valid.size == 0:
                print("no valid pixel in band {} of {}, band set to 0".format(i + 1, file))
                result.append(np.zeros((height, width), np.uint8))
                continue

            tmean = valid.mean()
            tstd = valid.std()
            if tstd == 0:
                tstd = 1.0  # constant band: avoid dividing by zero

            stretched = (data - tmean) / tstd  # first Z-score normalization
            stretched = (stretched + 4) * 255 / 8.0 - 255  # second min-max normalization to 255
            # NOTE(review): for z in [-4, 4] this yields [-255, 0], so the result
            # only lands in 0..255 through modulo-256 wrap-around. A direct
            # float->uint8 cast of negatives is undefined in numpy, so the wrap
            # is made explicit via int64 (same result as the usual x86 behaviour).
            stretched = stretched.astype(np.int64).astype(np.uint8)
            stretched[data == 0] = 0

            result.append(stretched.reshape((height, width)))

        outputfile = os.path.join(output_path, absname)
        driver = gdal.GetDriverByName("GTiff")

        outdataset = driver.Create(outputfile, width, height, im_bands,
                                   gdal.GDT_Byte)
        if outdataset is None:
            print("create dataset failed: {}".format(outputfile))
            continue
        for band_no, band in enumerate(result, start=1):
            outdataset.GetRasterBand(band_no).WriteArray(band)

        del outdataset
Example #21
0
def produce_training_samples_binary(in_path,
                                    out_path,
                                    image_num=50000,
                                    mode='original'):
    """Cut random binary samples (target_label -> 1, everything else -> 0).

    Label images come from ``in_path/label/`` and the matching source images
    (same file name) from ``in_path/src/``. About
    ``image_num / len(label_files)`` crops are cut per image. For each crop,
    three files are written: ``out_path/visualize/<n>.png`` (label * 50),
    ``out_path/label/<n>.png`` (label) and ``out_path/src/<n>.png`` (source
    bands, written with the GTiff driver). ``mode='augment'`` passes each crop
    through ``data_augment``.

    Uses the module-level names ``target_label``, ``valid_labels``, ``img_w``
    and ``img_h``.
    """
    print('\ncreating dataset...')

    label_files, nb_files = get_file(os.path.join(in_path, 'label/'))
    assert (nb_files != 0)

    image_each = image_num / len(label_files)

    print("\n[INFO] produce samples---------------------")
    g_count = 0
    for label_file in tqdm(label_files):

        src_file = os.path.join(in_path, 'src/') + os.path.split(label_file)[1]
        if not os.path.isfile(src_file):
            print("Have no file:{}".format(src_file))
            continue

        print("src file:{}".format(os.path.split(src_file)[1]))

        label_img = load_img_by_gdal(label_file, grayscale=True)
        label_img = label_img.astype(np.uint8)

        dataset = gdal.Open(src_file)
        if dataset is None:
            print("open failed!\n")
            continue

        Y_height = dataset.RasterYSize
        X_width = dataset.RasterXSize

        # numpy shape is (rows, cols) = (height, width); both must match.
        label_h, label_w = label_img.shape
        if X_width != label_w or Y_height != label_h:
            print("label and source image have different size:{}".format(
                label_file))
            continue
        if X_width < img_w or Y_height < img_h:
            print("image is smaller than the sample window:{}".format(src_file))
            continue

        im_bands = dataset.RasterCount
        data_type = dataset.GetRasterBand(1).DataType

        src_img = np.array(dataset.ReadAsArray(0, 0, X_width, Y_height))

        del dataset

        all_label = np.zeros((Y_height, X_width), np.uint8)
        all_label[label_img == target_label] = 1

        print(np.unique(all_label))
        if not all_label.any():
            # Every crop would be pure background and the sampling loop
            # below could never terminate.
            print("no pixel of target label {} in {}, skipped".format(target_label, label_file))
            continue

        count = 0
        while count < image_each:
            random_width = random.randint(0, X_width - img_w)
            random_height = random.randint(0, Y_height - img_h)
            rows = slice(random_height, random_height + img_h)
            cols = slice(random_width, random_width + img_w)
            # Single-band rasters are read as 2-D arrays.
            if src_img.ndim == 2:
                src_roi = src_img[rows, cols]
            else:
                src_roi = src_img[:, rows, cols]
            label_roi = all_label[rows, cols]

            # Ignore crops touching a label value outside valid_labels (nodata).
            if any(v not in valid_labels for v in np.unique(label_img[rows, cols])):
                continue
            # Ignore pure background crops.
            uniq = np.unique(label_roi)
            if len(uniq) < 2 and 0 in uniq:
                continue

            if mode == 'augment':
                src_roi, label_roi = data_augment(src_roi, label_roi,
                                                  data_type)

            visualize = label_roi * 50

            cv2.imwrite((out_path + '/visualize/%d.png' % g_count), visualize)
            cv2.imwrite((out_path + '/label/%d.png' % g_count), label_roi)

            src_sample_file = out_path + '/src/%d.png' % g_count
            driver = gdal.GetDriverByName("GTiff")
            outdataset = driver.Create(src_sample_file, img_w, img_h, im_bands,
                                       data_type)
            if outdataset is None:
                print("create dataset failed!\n")
                sys.exit(-2)
            if im_bands == 1:
                outdataset.GetRasterBand(1).WriteArray(src_roi)
            else:
                for i in range(im_bands):
                    outdataset.GetRasterBand(i + 1).WriteArray(src_roi[i])
            del outdataset

            count += 1
            g_count += 1
Example #22
0
import cv2
from tqdm import tqdm

from ulitities.base_functions import load_img, get_file

input_path = '../../data//traindata/segnet/label_old/'
output_path = '../../data//traindata/segnet/label/'
label_class=[0,1,2] # remain only road, buildings

# Legacy sample size; the remapping below uses each image's real shape.
height=256
width=256


if __name__=='__main__':
    # Remap old label values: 4 -> 1 (presumably road), 2 stays 2
    # (presumably buildings), every other value -> 0.
    import os
    import numpy as np

    if not os.path.isdir(output_path):
        print("No output directory:{}".format(output_path))
        os.mkdir(output_path)

    # get_file returns (file_list, count).
    files, _ = get_file(input_path)
    for filename in tqdm(files):
        ret, img = load_img(filename, grayscale=True)
        assert(ret == 0)

        unexpected = ~np.isin(img, (0, 2, 4))
        if unexpected.any():
            print("\n{}: {} pixel(s) with unexpected values {} reset to 0".format(
                filename, int(unexpected.sum()), np.unique(img[unexpected])))

        remapped = np.zeros_like(img)
        remapped[img == 4] = 1
        remapped[img == 2] = 2

        cv2.imwrite(output_path+os.path.split(filename)[1], remapped)
import os
import sys

import cv2
import numpy as np

from ulitities.base_functions import load_img, get_file

# Brighten label images for viewing: every label value is multiplied by
# VISUAL_SCALE (0, 1, 2 -> 0, 100, 200) and saturated to 8 bits.

input_path = '../../data/traindata/unet/buildings/label/'
output_path = '../../data/traindata/unet/buildings/visulize/'

VISUAL_SCALE = 100

if __name__ == '__main__':

    if not os.path.isdir(input_path):
        print("No input directory:{}".format(input_path))
        sys.exit(-1)
    if not os.path.isdir(output_path):
        print("No output directory:{}".format(output_path))
        # makedirs also creates missing parent directories
        os.makedirs(output_path)

    srcfiles, tt = get_file(input_path)
    # explicit checks: asserts would be stripped under `python -O`
    if tt == 0:
        print("No file found in input directory:{}".format(input_path))
        sys.exit(-2)

    for file in srcfiles:
        ret, img = load_img(file, grayscale=True)
        if ret != 0:
            print("Failed to load image:{}".format(file))
            sys.exit(-3)

        # Multiply in a wide type and saturate: `img * 100` on a uint8 array
        # wraps around for label values >= 3 (e.g. 3 -> 44).
        img = np.clip(np.asarray(img, dtype=np.float64) * VISUAL_SCALE, 0, 255).astype(np.uint8)
        filename = os.path.split(file)[1]
        outfile = os.path.join(output_path, filename)
        print(outfile)

        cv2.imwrite(outfile, img)
    def produce_training_samples_multiclass_selfAdapt(self):
        """Randomly crop multiclass training samples from src/label image pairs.

        Every file in ``<input_dir>/label/`` is paired with the same-named
        file in ``<input_dir>/src/``. The number of windows drawn from an
        image scales with its area:
        ``int(W * H * sample_scaleRate / window_size**2 + 0.5)``, multiplied
        by 6 when ``mode`` contains ``'augment'``. A window is kept only if
        every label value in it lies in ``[min, max]`` and it is not entirely
        0 (pure background). Kept windows are written to
        ``<output_dir>/visualize/``, ``<output_dir>/label/`` and
        ``<output_dir>/src/`` as ``<g_count>_<name>.png``; the src window is
        written with the GTiff driver despite the ``.png`` suffix.

        Reads ``input_dir``, ``output_dir``, ``min``, ``max``,
        ``sample_scaleRate``, ``window_size`` and ``mode`` from
        ``self.input_dict``. Exits the process with code -2 if an output
        raster cannot be created.
        """
        # Upper bound on random draws per requested sample, so an image whose
        # windows are all nodata / background cannot loop forever.
        max_attempts_per_sample = 100

        print('\ncreating dataset...')
        in_path = self.input_dict['input_dir']
        out_path = self.input_dict['output_dir']
        # Convert 'min' and 'max' separately: the values may be strings, and
        # int(max + 1) would raise TypeError for a string value.
        label_min = int(self.input_dict['min'])
        label_max = int(self.input_dict['max'])

        label_dir = os.path.join(in_path, 'label/')
        label_files, file_num = get_file(label_dir)
        if file_num == 0:
            print("No label file found in {}".format(label_dir))
            return

        image_num_rate = int(self.input_dict['sample_scaleRate'])
        img_w = int(self.input_dict['window_size'])
        img_h = int(self.input_dict['window_size'])
        augment = 'augment' in self.input_dict['mode']
        driver = gdal.GetDriverByName("GTiff")

        g_count = 0
        for label_file in tqdm(label_files):
            name = os.path.split(label_file)[1]
            src_file = os.path.join(in_path, 'src/') + name
            if not os.path.isfile(src_file):
                print("Have no file:{}".format(src_file))
                continue

            print("src file:{}".format(name))

            label_img = cv2.imread(label_file, cv2.IMREAD_GRAYSCALE)
            if label_img is None:
                print("Failed to read label file:{}".format(label_file))
                continue
            absname = name.split('.')[0]

            dataset = gdal.Open(src_file)
            if dataset is None:
                print("open failed!\n")
                continue

            Y_height = dataset.RasterYSize
            X_width = dataset.RasterXSize
            im_bands = dataset.RasterCount
            data_type = dataset.GetRasterBand(1).DataType

            # check size of label and src images (array shape is rows, cols)
            label_h, label_w = label_img.shape
            print("Heigh, width of label is :{}, {}".format(label_h, label_w))
            print("Heigh, width of src is :{}, {}".format(Y_height, X_width))
            if label_h != Y_height or label_w != X_width:
                print("Warning: src and label have different size!")
                continue
            if X_width < img_w or Y_height < img_h:
                print("Warning: image is smaller than the window, skipped!")
                continue

            # 2-D for a single band, (bands, rows, cols) otherwise
            src_img = np.array(dataset.ReadAsArray(0, 0, X_width, Y_height))
            del dataset

            # Evaluate samples numbers according to the image_size, window_size and sample_scaleRate
            samples_num_of_current_image = int((X_width * Y_height * image_num_rate) / (img_w * img_h) + 0.5)
            if augment:
                samples_num_of_current_image = 6 * samples_num_of_current_image
            print("Extract {} samples from {}".format(samples_num_of_current_image, name))

            max_attempts = max_attempts_per_sample * samples_num_of_current_image
            attempts = 0
            count = 0
            while count < samples_num_of_current_image and attempts < max_attempts:
                attempts += 1
                # randint is inclusive: X_width - img_w is the last valid offset
                col = random.randint(0, X_width - img_w)
                row = random.randint(0, Y_height - img_h)
                # `...` keeps the band axis when present, so single-band
                # (2-D) rasters are sliced correctly as well
                src_roi = src_img[..., row:row + img_h, col:col + img_w]
                label_roi = label_img[row:row + img_h, col:col + img_w]

                # ignore windows with labels outside [min, max] (nodata area)
                if label_roi.min() < label_min or label_roi.max() > label_max:
                    continue

                # ignore pure background windows
                if not label_roi.any():
                    continue

                if augment:
                    src_roi, label_roi = self.data_augment(src_roi, label_roi, data_type)

                visualize = label_roi * 50

                cv2.imwrite((out_path + '/visualize/%d_%s.png' % (g_count, absname)), visualize)
                cv2.imwrite((out_path + '/label/%d_%s.png' % (g_count, absname)), label_roi)

                src_sample_file = out_path + '/src/%d_%s.png' % (g_count, absname)
                outdataset = driver.Create(src_sample_file, img_w, img_h, im_bands, data_type)
                if outdataset is None:
                    print("create dataset failed!\n")
                    sys.exit(-2)
                if im_bands == 1:
                    outdataset.GetRasterBand(1).WriteArray(src_roi)
                else:
                    for i in range(im_bands):
                        outdataset.GetRasterBand(i + 1).WriteArray(src_roi[i])
                del outdataset

                count += 1
                g_count += 1

            if count < samples_num_of_current_image:
                print("Warning: only {} of {} samples found in {}".format(
                    count, samples_num_of_current_image, name))
Example #25
0
    def produce_training_samples_binary(self):
        """Randomly crop binary (target vs. rest) training samples.

        Every file in ``<input_dir>/label/`` is paired with the same-named
        file in ``<input_dir>/src/``. The label is turned into a 0/1 mask
        (1 where it equals ``target_label``); a label file whose mask has no
        target pixel is skipped. About ``sample_num / number_of_label_files``
        windows are drawn per image. A window is kept only if every original
        label value in it lies in ``[min, max]`` and its mask contains at
        least one target pixel. Kept windows are written to
        ``<output_dir>/visualize/``, ``<output_dir>/label/`` and
        ``<output_dir>/src/`` as ``<g_count>.png``; the src window is written
        with the GTiff driver despite the ``.png`` suffix.

        Reads ``input_dir``, ``output_dir``, ``min``, ``max``,
        ``target_label``, ``sample_num``, ``window_size`` and ``mode`` from
        ``self.input_dict``. Exits the process with code -2 if an output
        raster cannot be created.
        """
        # Upper bound on random draws per requested sample, so an image whose
        # windows are all nodata / background cannot loop forever.
        max_attempts_per_sample = 100

        print('\ncreating dataset...')
        in_path = self.input_dict['input_dir']
        out_path = self.input_dict['output_dir']
        # Convert 'min' and 'max' separately: the values may be strings, and
        # int(max + 1) would raise TypeError for a string value.
        label_min = int(self.input_dict['min'])
        label_max = int(self.input_dict['max'])
        target_label = int(self.input_dict['target_label'])

        label_dir = os.path.join(in_path, 'label/')
        label_files, file_num = get_file(label_dir)
        if file_num == 0:
            print("No label file found in {}".format(label_dir))
            return

        image_num = int(self.input_dict['sample_num'])

        image_each = image_num / len(label_files)
        img_w = int(self.input_dict['window_size'])
        img_h = int(self.input_dict['window_size'])
        augment = 'augment' in self.input_dict['mode']
        driver = gdal.GetDriverByName("GTiff")
        # +1 so a fractional per-image quota still gets some attempts
        max_attempts = max_attempts_per_sample * (int(image_each) + 1)

        print("\n[INFO] produce samples---------------------")
        g_count = 0
        for label_file in tqdm(label_files):
            name = os.path.split(label_file)[1]
            src_file = os.path.join(in_path, 'src/') + name
            if not os.path.isfile(src_file):
                print("Have no file:{}".format(src_file))
                continue

            print("src file:{}".format(name))
            label_img = load_img_by_gdal(label_file, grayscale=True)
            # load_img_by_gdal appears to return an empty result on failure
            if label_img is None or len(label_img) == 0:
                print("Failed to load label file:{}".format(label_file))
                continue
            label_img = label_img.astype(np.uint8)
            y, x = label_img.shape

            dataset = gdal.Open(src_file)
            if dataset is None:
                print("open failed!\n")
                continue

            Y_height = dataset.RasterYSize
            X_width = dataset.RasterXSize
            # a mismatch in EITHER dimension misaligns src and label windows
            if X_width != x or Y_height != y:
                print("label and source image have different size:{}".format(label_file))
                continue
            if X_width < img_w or Y_height < img_h:
                print("Warning: image is smaller than the window, skipped!")
                continue

            im_bands = dataset.RasterCount
            data_type = dataset.GetRasterBand(1).DataType

            # 2-D for a single band, (bands, rows, cols) otherwise
            src_img = np.array(dataset.ReadAsArray(0, 0, X_width, Y_height))
            del dataset

            # binary mask: 1 where the label equals target_label, 0 elsewhere
            all_label = (label_img == target_label).astype(np.uint8)

            # if no pixel has the target value, ignore this label file
            tp = np.unique(all_label)
            print(tp)
            if len(tp) < 2:
                print("Only one value {} in {}".format(tp, label_file))
                if tp[0] == 0:
                    print("no target value in {}".format(label_file))
                    continue

            attempts = 0
            count = 0
            while count < image_each and attempts < max_attempts:
                attempts += 1
                # randint is inclusive: X_width - img_w is the last valid offset
                col = random.randint(0, X_width - img_w)
                row = random.randint(0, Y_height - img_h)
                # `...` keeps the band axis when present, so single-band
                # (2-D) rasters are sliced correctly as well
                src_roi = src_img[..., row:row + img_h, col:col + img_w]
                label_roi = all_label[row:row + img_h, col:col + img_w]

                # ignore windows whose ORIGINAL labels fall outside [min, max]
                orig_window = label_img[row:row + img_h, col:col + img_w]
                if orig_window.min() < label_min or orig_window.max() > label_max:
                    continue

                # ignore windows with no target pixel (pure background)
                if not label_roi.any():
                    continue

                if augment:
                    src_roi, label_roi = self.data_augment(src_roi, label_roi, img_w, img_h, data_type)

                visualize = label_roi * 50

                cv2.imwrite((out_path + '/visualize/%d.png' % g_count), visualize)
                cv2.imwrite((out_path + '/label/%d.png' % g_count), label_roi)

                src_sample_file = out_path + '/src/%d.png' % g_count
                outdataset = driver.Create(src_sample_file, img_w, img_h, im_bands, data_type)
                if outdataset is None:
                    print("create dataset failed!\n")
                    sys.exit(-2)
                if im_bands == 1:
                    outdataset.GetRasterBand(1).WriteArray(src_roi)
                else:
                    for i in range(im_bands):
                        outdataset.GetRasterBand(i + 1).WriteArray(src_roi[i])
                del outdataset

                count += 1
                g_count += 1

            if count < image_each:
                print("Warning: only {} samples found in {}".format(count, name))
Example #26
0
    else:
        for i in range(bands):
            outdataset.GetRasterBand(i + 1).WriteArray(data[i])
    del outdataset


# Hard-coded directory of the original label images to read.
input_dir = '/media/omnisky/e0331d4a-a3ea-4c31-90ab-41f5b0ee2663/rice/original_label/'
# Hard-coded directory for the converted labels; it is not created here.
output_dir = '/media/omnisky/e0331d4a-a3ea-4c31-90ab-41f5b0ee2663/rice/label/'

if __name__ == '__main__':

    if not os.path.isdir(input_dir):
        print("input dir is not existed: {}".format(input_dir))
        sys.exit(-1)

    files, numb = get_file(input_dir)

    for file in tqdm(files):
        # Output name: output_dir + "<basename up to the first '.'>.png".
        # output_dir is concatenated directly, so it must end with '/'.
        absname = os.path.split(file)[1]
        absname = absname.split('.')[0]
        tmp_file = ''.join([output_dir, absname, '.png'])
        # img = load_img_by_gdal(file, grayscale=True)
        # Flag 0 reads the image as single-channel grayscale. cv2.imread
        # returns None for an unreadable file, which is not checked here.
        img = cv2.imread(file, 0)
        # img = np.array(img, np.uint8)
        print("original value: {}".format(np.unique(img)))

        size = img.shape

        # Binary mask: 1 where the original label is exactly 1, 0 elsewhere.
        result = np.zeros(size, np.uint8)
        ind_targt = np.where(img == 1)
        result[ind_targt] = 1
        # NOTE(review): `result` and `tmp_file` are never written in this
        # snippet; the save step appears to be missing (snippet likely
        # truncated) - confirm before relying on this script's output.
Example #27
0
    def stretch_all_image_from_dict(self):
        """Z-score stretch every image in ``in_dict['input_dir']`` and save it.

        For each band, mean and std are computed over the pixels with
        ``0 < value < NoData``. The whole band is then mapped with
        ``(z + 4) * StretchRange / 8 - CutValue``, clipped to the output bit
        depth, and pixels that were 0 in the input are forced back to 0.
        Each result is written with the GTiff driver to
        ``<output_dir>/<name>.png``, as 8-bit data when ``OutBits`` contains
        '8' or 16-bit data when it contains '16'.

        Reads ``input_dir``, ``output_dir``, ``NoData``, ``StretchRange``,
        ``CutValue`` and ``OutBits`` from ``self.in_dict``. Exits the process
        with code -1 if ``in_dict`` is None; shows a warning and returns on
        other invalid settings.
        """
        if self.in_dict is None:
            QMessageBox.warning(self, "Warning", self.tr("input dict errors!"))
            sys.exit(-1)
        src_files, file_num = get_file(self.in_dict['input_dir'])
        if file_num == 0:
            QMessageBox.warning(self, "Warning", self.tr("No file found in input directory!"))
            return
        NoData = int(self.in_dict['NoData'])
        valid_range = float(self.in_dict['StretchRange'])
        print(valid_range)
        cut_value = float(self.in_dict['CutValue'])

        # Resolve the output depth once, up front, instead of failing on the
        # first file with an unbound dataset when OutBits is unsupported.
        if '8' in self.in_dict['OutBits']:
            out_type, range_limit = gdal.GDT_Byte, 256
        elif '16' in self.in_dict['OutBits']:
            out_type, range_limit = gdal.GDT_UInt16, 65536
        else:
            QMessageBox.warning(self, "Warning", self.tr("Unsupported output bits!"))
            return
        if not valid_range < range_limit:
            QMessageBox.warning(self, "Warning", self.tr("Stretch range exceeds output bits!"))
            return

        driver = gdal.GetDriverByName("GTiff")
        for file in tqdm(src_files):

            absname = os.path.split(file)[1]
            absname = ''.join([absname.split('.')[0], '.png'])
            print(absname)
            if not os.path.isfile(file):
                print("input file dose not exist:{}\n".format(file))
                continue

            dataset = gdal.Open(file)
            if dataset is None:
                print("Open file failed: {}".format(file))
                continue

            height = dataset.RasterYSize
            width = dataset.RasterXSize
            im_bands = dataset.RasterCount
            img = dataset.ReadAsArray(0, 0, width, height)
            del dataset
            img = np.array(img, np.float32)
            if img.ndim == 2:
                # a single-band raster is read as 2-D; add a band axis so
                # img[i] below is a whole band rather than one row
                img = img[np.newaxis, :, :]

            result = []
            for i in range(im_bands):
                data = img[i].reshape(height * width)
                print(data.max(), data.min(), data.mean(), data.std())

                # statistics only over valid pixels: 0 < value < NoData
                valid = data[(data > 0) & (data < NoData)]
                print("valid value number: {}\n".format(valid.size))
                if valid.size == 0:
                    # no statistics can be estimated; emit an all-zero band
                    print("no valid value in band {} of {}".format(i + 1, file))
                    result.append(np.zeros((height, width), np.uint16))
                    continue
                tmean = valid.mean()
                tstd = valid.std()
                print(valid.max(), valid.min(), tmean, tstd)
                if tstd == 0:
                    tstd = 1.0  # constant band: avoid division by zero

                stretched = (data - tmean) / tstd  # first Z-score normalization
                stretched = (stretched + 4) * valid_range / 8.0 - cut_value
                # clip before the integer cast: negative floats cast to
                # uint16 wrap around instead of saturating at 0
                stretched = np.clip(stretched, 0, range_limit - 1).astype(np.uint16)
                stretched[data == 0] = 0

                print(stretched.max(), stretched.min(), stretched.mean(), stretched.std())
                result.append(stretched.reshape((height, width)))

            outputfile = os.path.join(self.in_dict['output_dir'], absname)
            outdataset = driver.Create(outputfile, width, height, im_bands, out_type)
            if outdataset is None:
                print("create dataset failed: {}".format(outputfile))
                continue
            for i in range(im_bands):
                outdataset.GetRasterBand(i + 1).WriteArray(result[i])
            del outdataset
Example #28
0
# Output root; label/ and src/ sub-directory paths are derived from it below.
outputdir = '/home/omnisky/PycharmProjects/data/samples/global/miandian/train_10bits_2000'
# Patch size in pixels - not used in the lines visible here.
patch_size=2000

if __name__=='__main__':
    # NOTE(review): `inputdir` is not defined in the visible code; presumably
    # it is defined just above this snippet - confirm.
    if not os.path.isdir(inputdir):
        print("Error: input directory is not existed")
        sys.exit(-1)
    if not os.path.isdir(outputdir):
        print("Warning: output directory is not existed")
        os.mkdir(outputdir)
    out_label_dir=outputdir+'/label/'
    out_src_dir = outputdir + '/src/'

    label_list, img_list =[], []

    # Pair each label file with a src file looked up by the label's base name
    # (text before the first '.') via find_file.
    label_files, _=get_file(inputdir+'/label')
    img_files =[]
    for file in label_files:
        absname = os.path.split(file)[1]
        absname = absname.split('.')[0]
        img_f = find_file(inputdir+'/src',absname)
        img_files.append(img_f)
    # img_files = list()
    # img_files, _=get_file(inputdir+'/src')
    # NOTE(review): this always holds, because img_files gets exactly one
    # entry per label file; it does not detect a missing src file.
    assert(len(label_files)==len(img_files))
    name_list =[]
    for i,file in enumerate(label_files):
        l_img = load_img_by_gdal(file, grayscale=True)
        # skip labels for which load_img_by_gdal returns an empty result
        # (presumably a load failure)
        if len(l_img)==0:
            continue
        label_list.append(l_img)
    # NOTE(review): img_list, name_list, out_label_dir and out_src_dir are not
    # used in the visible part; the snippet appears to be truncated here.
Example #29
0
            model,
            window_size=window_size,
            subdivisions=2,
            real_classes=target_class,  # output channels = 是真的类别,总类别-背景
            pred_func=smooth_predict_for_binary_notonehot,
            PLOT_PROGRESS=False)

        cv2.imwrite(output_file, result)
        print("Saved to: {}".format(output_file))

    gc.collect()


if __name__ == '__main__':

    all_files, num = get_file(input_path)
    if num == 0:
        print("There is no file in path:{}".format(input_path))
        sys.exit(-1)
    """checke model file"""
    print("model file: {}".format(model_file))
    if not os.path.isfile(model_file):
        print("model does not exist:{}".format(model_file))
        sys.exit(-2)

    for in_file in all_files:
        abs_filename = os.path.split(in_file)[1]
        abs_filename = abs_filename.split(".")[0]
        print(abs_filename)
        out_file = ''.join([
            output_path, '/mask_binary_', abs_filename, '_',
# Each run writes into a fresh, timestamped sub-directory of config.mask_dir.
date_time = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
output_dir = os.path.join(config.mask_dir, date_time)
# makedirs also creates a missing config.mask_dir (os.mkdir would raise).
# It still raises if this exact timestamp directory already exists, so a
# run started in the same second never silently overwrites another.
os.makedirs(output_dir)

block_size = config.block_size
nodata = config.nodata

if __name__ == '__main__':
    input_files = []
    if os.path.isfile(config.img_input):
        print("[INFO] input is one file...")
        input_files.append(config.img_input)
    elif os.path.isdir(config.img_input):
        print("[INFO] input is a directory...")
        in_files, _ = get_file(config.img_input, config.suffix)
        for file in in_files:
            input_files.append(file)
    print("{} images will be classified".format(len(input_files)))

    # sys.exit(-1)
    csv_file = os.path.join(output_dir, 'readme.csv')
    df = pd.DataFrame(list(config))
    df.to_csv(csv_file)

    out_bands = config.mask_classes
    model = load_model(config.model_path)
    # print(model.summary())

    for img_file in tqdm(input_files):
        print("\n[INFO] opening image:{}...".format(img_file))