def generateData(batch_size, data=None):
    """Yield training batches of (images, labels) indefinitely.

    Relies on the module-level settings im_bands, im_type, train_data_path,
    img_w, img_h and n_label.
    """
    if data is None:
        data = []
    while True:
        train_data = []
        train_label = []
        batch = 0
        for url in data:
            batch += 1
            # Load and normalize the multi-band source image.
            _, img = load_img_normalization(im_bands,
                                            train_data_path + 'src/' + url,
                                            data_type=im_type)
            # Adapt dim_ordering automatically
            img = img_to_array(img)
            train_data.append(img)
            # Load the single-band label mask for the same sample.
            _, label = load_img_normalization(1, train_data_path + 'label/' + url)
            label = img_to_array(label)
            train_label.append(label)
            if batch % batch_size == 0:
                train_data = np.array(train_data)
                train_label = np.array(train_label)
                # Flatten the spatial dimensions so the labels match a
                # (batch_size, img_w * img_h, n_label) network output.
                train_label = train_label.reshape(
                    (batch_size, img_w * img_h, n_label))
                yield (train_data, train_label)
                train_data = []
                train_label = []
                batch = 0
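
How the data file list is built is not shown in this example. A minimal sketch, assuming the sample names are simply the files under train_data_path + 'src/' and that the helper name split_train_val and the 0.75 ratio are arbitrary choices:

import os
import random

def split_train_val(train_data_path, train_ratio=0.75):
    # Hypothetical helper: list sample file names and split them into
    # training and validation subsets for the generators in this listing.
    urls = os.listdir(train_data_path + 'src/')
    random.shuffle(urls)
    cut = int(len(urls) * train_ratio)
    return urls[:cut], urls[cut:]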
Example 2
def generateValidData(batch_size, data=None):
    """Yield validation batches of (images, labels) indefinitely."""
    if data is None:
        data = []
    while True:
        valid_data = []
        valid_label = []
        batch = 0
        for url in data:
            batch += 1
            _, img = load_img_normalization(
                config.im_bands, (config.train_data_path + '/src/' + url),
                data_type=im_type)
            # Adapt dim_ordering automatically
            img = img_to_array(img)
            valid_data.append(img)
            _, label = load_img_normalization(
                1, (config.train_data_path + '/label/' + url))
            label = img_to_array(label)
            valid_label.append(label)
            if batch % batch_size == 0:
                valid_data = np.array(valid_data)
                valid_label = np.array(valid_label)
                # One-hot encode labels when training a multi-class model.
                if config.nb_classes > 2:
                    valid_label = to_categorical(valid_label,
                                                 num_classes=config.nb_classes)
                yield (valid_data, valid_label)
                valid_data = []
                valid_label = []
                batch = 0
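
A minimal sketch of how the two generators above would typically be driven with the classic Keras generator API; train_urls, val_urls, EPOCHS and the compiled model are assumed to exist:

steps_per_epoch = max(1, len(train_urls) // batch_size)
validation_steps = max(1, len(val_urls) // batch_size)
model.fit_generator(generateData(batch_size, train_urls),
                    steps_per_epoch=steps_per_epoch,
                    epochs=EPOCHS,
                    validation_data=generateValidData(batch_size, val_urls),
                    validation_steps=validation_steps,
                    verbose=1)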
Example 3
    if not os.path.isdir(train_data_path):
        print("train data does not exist in the path:\n {}".format(
            train_data_path))

    # Select the network architecture indicated by FLAG_USING_NETWORK.
    if FLAG_USING_NETWORK == 0:
        model = binary_unet_jaccard_4orMore(input_bands, n_label)
    elif FLAG_USING_NETWORK == 1:
        model = binary_fcnnet_jaccard(n_label)
    elif FLAG_USING_NETWORK == 2:
        model = binary_segnet_jaccard(n_label)

    print("Train by : {}".format(dict_network[FLAG_USING_NETWORK]))
    train(model, model_save_path)

    if FLAG_MAKE_TEST:
        print("test ....................predict by trained model .....\n")
        test_img_path = '../../data/test/sample1.png'
        import sys

        if not os.path.isfile(test_img_path):
            print("no file: {}".format(test_img_path))
            sys.exit(-1)

        ret, input_img = load_img_normalization(test_img_path)
        # model_save_path ='../../data/models/unet_buildings_onehot.h5'

        new_model = load_model(model_save_path)

        test_predict(input_img, new_model)
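
One caveat about load_model above: if the jaccard-based networks were compiled with a custom loss or metric, Keras needs those callables passed back in via custom_objects. The name jaccard_coef_loss below is only an assumption about what that function might be called:

new_model = load_model(model_save_path,
                       custom_objects={'jaccard_coef_loss': jaccard_coef_loss})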
Example 4
def val_data_generator(config, sample_url):
    """Yield validation batches by tiling each sample into img_h x img_w windows."""
    norm_value = 255.0
    w = config.img_w
    h = config.img_h
    label_list, img_list = [], []
    for pic in sample_url:
        _, t_img = load_img_normalization(
            1, config.train_data_path + '/label/' + pic)
        tp = np.unique(t_img)
        if len(tp) < 2:
            print("Only one value {} in {}".format(
                tp, config.train_data_path + '/label/' + pic))
            if tp[0] == 0:
                print("no target value in {}".format(config.train_data_path +
                                                     '/label/' + pic))
                continue
        ret, s_img = load_img_bybandlist(
            (config.train_data_path + '/src/' + pic),
            bandlist=config.band_list)
        if ret != 0:
            continue
        s_img = img_to_array(s_img)
        img_list.append(s_img)
        label_list.append(t_img)

    assert len(label_list) == len(img_list)

    train_data = []
    train_label = []
    batch = 0
    while True:
        if batch == 0:  # reset after any partially filled batch
            train_data = []
            train_label = []
        for img, label in zip(img_list, label_list):
            img = np.asarray(img).astype("float") / norm_value
            label = np.asarray(label)
            assert img.shape[0:2] == label.shape[0:2]

            # Cut the full image into non-overlapping h x w tiles.
            for row in range(img.shape[0] // h):
                for col in range(img.shape[1] // w):
                    x = img[row * h:(row + 1) * h, col * w:(col + 1) * w, :]
                    y = label[row * h:(row + 1) * h, col * w:(col + 1) * w]

                    if config.label_nodata in np.unique(y):
                        continue
                    # Ignore most pure-background tiles.
                    if len(np.unique(y)) < 2:
                        if (0 in np.unique(y)) and (np.random.random() < 0.75):
                            continue
                    x = img_to_array(x)
                    y = img_to_array(y)
                    train_data.append(x)
                    train_label.append(y)

                    batch += 1
                    if batch % config.batch_size == 0:
                        train_data = np.array(train_data)
                        train_label = np.array(train_label)
                        if config.nb_classes > 2:
                            train_label = to_categorical(
                                train_label, num_classes=config.nb_classes)
                        yield (train_data, train_label)
                        train_data = []
                        train_label = []
                        batch = 0
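
Since val_data_generator tiles each sample into non-overlapping img_h x img_w windows and silently skips nodata and most pure-background tiles, the number of batches it can supply per pass is only approximate. A rough sizing sketch, assuming all validation images share the same img_height and img_width:

def approx_validation_steps(config, sample_url, img_height, img_width):
    # Upper bound: assumes every tile of every image survives the filtering.
    tiles_per_image = (img_height // config.img_h) * (img_width // config.img_w)
    return max(1, len(sample_url) * tiles_per_image // config.batch_size)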
Example 5
def train_data_generator(config, sample_url):
    """Yield training batches of randomly cropped (and optionally augmented) samples."""
    # Pick the normalization value and output bit depth from the image type.
    norm_value = 255.0
    bits_num = 8
    if '10' in config.im_type:
        norm_value = 1024.0
        bits_num = 16
    elif '16' in config.im_type:
        norm_value = 65535.0
        bits_num = 16
    label_list, img_list = [], []
    for pic in sample_url:
        _, t_img = load_img_normalization(
            1, config.train_data_path + '/label/' + pic)
        tp = np.unique(t_img)
        if len(tp) < 2:
            print("Only one value {} in {}".format(
                tp, config.train_data_path + '/label/' + pic))
            if tp[0] == 0:
                print("no target value in {}".format(config.train_data_path +
                                                     '/label/' + pic))
                continue

        ret, s_img = load_img_bybandlist(
            (config.train_data_path + '/src/' + pic),
            bandlist=config.band_list)
        if ret != 0:
            continue

        s_img = img_to_array(s_img)
        s_img = np.asarray(s_img, np.uint16)
        img_list.append(s_img)
        label_list.append(t_img)
    assert len(label_list) == len(img_list)

    train_data = []
    train_label = []
    batch = 0
    while True:
        if batch == 0:  # reset in case the training set has fewer images than batch_size
            train_data = []
            train_label = []

        for i in np.random.permutation(np.arange(len(img_list))):
            # Indices come from a permutation of the valid range, so plain
            # indexing is safe here.
            src_img = img_list[i]
            label_img = label_list[i]
            # The crop size is either img_w or 2 * img_w; oversized crops are
            # resampled back to (img_h, img_w) further down.
            random_size = random.randrange(config.img_w, config.img_w * 2 + 1,
                                           config.img_w)
            img, label = random_crop(src_img, label_img, random_size,
                                     random_size)

            if config.label_nodata in np.unique(label):
                continue
            """ignore pure background area"""
            if len(np.unique(label)) < 2:
                if (0 in np.unique(label)) and (np.random.random() < 0.75):
                    continue

            if img.shape[1] != config.img_w or img.shape[0] != config.img_h:
                # print("resize samples")
                img = resample_data(img,
                                    config.img_h,
                                    config.img_w,
                                    mode=Image.BILINEAR,
                                    bits=bits_num)
                label = resample_data(label,
                                      config.img_h,
                                      config.img_w,
                                      mode=Image.NEAREST)

            if config.augment:
                img, label = data_augment(img, label, config.img_w,
                                          config.img_h)

            img = np.asarray(img).astype(np.float32) / norm_value
            img = np.clip(img, 0.0, 1.0)

            batch += 1
            img = img_to_array(img)
            label = img_to_array(label)
            train_data.append(img)
            train_label.append(label)
            if batch % config.batch_size == 0:
                train_data = np.array(train_data)
                train_label = np.array(train_label)
                # print("img shape:{}".format(train_data.shape))
                # print("label shap:{}".format(train_label.shape))
                if config.nb_classes > 2:
                    train_label = to_categorical(train_label,
                                                 num_classes=config.nb_classes)
                yield (train_data, train_label)
                train_data = []
                train_label = []
                batch = 0
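
random_crop and the other helpers (load_img_bybandlist, resample_data, data_augment) are project utilities that are not part of this listing. A minimal sketch of what random_crop presumably does, i.e. cut the same random window out of the image and its label, assuming both arrays are at least as large as the requested crop:

def random_crop(src_img, label_img, crop_w, crop_h):
    # Hypothetical sketch; the real project helper may differ.
    h, w = label_img.shape[:2]
    y0 = np.random.randint(0, max(1, h - crop_h + 1))
    x0 = np.random.randint(0, max(1, w - crop_w + 1))
    img = src_img[y0:y0 + crop_h, x0:x0 + crop_w, :]
    label = label_img[y0:y0 + crop_h, x0:x0 + crop_w]
    return img, label

Wired up the same way as the earlier pair, train_data_generator(config, train_urls) and val_data_generator(config, val_urls) would then be passed to fit_generator as the training and validation generators.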