def __init__(self, args):
    """Initialise the label lookup, options, normalizer and path list.

    ``args`` is expected to be an EasyDict instance; the only option read
    directly here is ``args.converse_gray``.
    """
    # Letters 'A'..'S' are mapped to consecutive class values starting at 0.
    class_letters = 'ABCDEFGHIJKLMNOPQRS'
    self.label2clsval = dict(zip(class_letters, range(len(class_letters))))

    self.args = args
    self.gray = args.converse_gray
    self.image_normalizer = ImageNormalizer()
    # read_paths() runs after the attributes above have been assigned.
    self.pairs = self.read_paths()
    self.counter = 0
    # [height, width] for the current batch; unknown until set elsewhere.
    self.image_size_in_batch = [None] * 2
Beispiel #2
0
 def __init__(self, args):
     """Store the options and prepare normalizer, path list and counters."""
     self.args = args
     self.image_normalizer = ImageNormalizer()
     # read_paths() is invoked once self.args is available.
     self.pairs = self.read_paths()
     self.counter = 0
     # [height, width] for the current batch; unknown until set elsewhere.
     self.image_size_in_batch = [None] * 2
Beispiel #3
0
 def create_normalized_dataset(self,
                               target_width,
                               target_height,
                               root_path,
                               max_db_gb=2,
                               image_limit=None):
     """Normalize every source blob and store the results in an LMDB env.

     For each blob in ``self.original_image_blob_container`` the blob is
     downloaded to a local file named after the blob, opened as an image,
     passed through ``ImageNormalizer.normalize_image`` and every resulting
     entry is written to three named LMDB databases ("mono", "colour" and
     "index") under a shared 4-byte big-endian running key. A
     ``metadata.json`` describing the tables is written to ``root_path``.

     Args:
         target_width: Width handed to ``ImageNormalizer`` and recorded in
             the metadata.
         target_height: Height handed to ``ImageNormalizer`` and recorded
             in the metadata.
         root_path: Directory holding the LMDB environment and
             ``metadata.json``; created if missing.
         max_db_gb: Multiplied by the module-level ``one_gb`` to obtain the
             LMDB ``map_size``.
         image_limit: Forwarded as ``num_results`` to the blob listing.
     """
     os.makedirs(root_path, exist_ok=True)
     normalizer = ImageNormalizer(target_width, target_height)
     mono_table_name = "mono"
     colour_table_name = "colour"
     index_table_name = "index"
     metadata = {
         mono_table_name: {
             "width": target_width,
             "height": target_height,
             "channels": 1,
             "dtype": "float32"
         },
         colour_table_name: {
             "width": target_width,
             "height": target_height,
             "channels": 3,
             "dtype": "float32"
         },
         index_table_name: {
             "encoding": "ascii",
             "dtype": "string"
         }
     }
     with open(os.path.join(root_path, "metadata.json"),
               "w") as metadata_file:
         json.dump(metadata, metadata_file)
     env = lmdb.open(root_path,
                     map_size=(max_db_gb * one_gb),
                     create=True,
                     max_dbs=4)
     try:
         mono_db = env.open_db(mono_table_name.encode(encoding="ascii"))
         colour_db = env.open_db(
             colour_table_name.encode(encoding="ascii"))
         index_db = env.open_db(index_table_name.encode(encoding="ascii"))
         i = 0
         blob_service = self._get_or_create_blob_service()
         # Materialize the listing exactly once: counting a lazy/paged
         # listing and then iterating it again would either re-query the
         # service or yield an incomplete second pass.
         blob_names = list(
             blob_service.list_blob_names(
                 self.original_image_blob_container,
                 num_results=image_limit))
         source_count = len(blob_names)
         for source_index, original_blob_name in enumerate(blob_names,
                                                           start=1):
             print("Normalizing {} of {} : {}".format(source_index,
                                                      source_count,
                                                      original_blob_name))
             # Use the same service handle that produced the listing.
             blob_service.get_blob_to_path(
                 self.original_image_blob_container, original_blob_name,
                 original_blob_name)
             try:
                 image_id = original_blob_name.split(".")[0]
                 # Keep the source image open only while its normalized
                 # variants are being consumed, then release the file
                 # handle before the temporary file is deleted.
                 with Image.open(original_blob_name) as im:
                     normalized = normalizer.normalize_image(im)
                     for normal_image in normalized:
                         i_key = i.to_bytes(4, byteorder="big")
                         i += 1
                         new_image_id = "{0}.{1}.png".format(
                             image_id, str(uuid.uuid4()))
                         mono_image_array = ToTensor()(
                             normal_image["mono"]).flatten().numpy()
                         colour_image_array = ToTensor()(
                             normal_image["colour"]).flatten().numpy()
                         # insert in db
                         with env.begin(db=index_db, write=True,
                                        buffers=True) as txn:
                             txn.put(i_key,
                                     new_image_id.encode(encoding="ascii"))
                         with env.begin(db=mono_db, write=True,
                                        buffers=True) as txn:
                             txn.put(i_key, mono_image_array.data)
                         with env.begin(db=colour_db, write=True,
                                        buffers=True) as txn:
                             txn.put(i_key, colour_image_array.data)
                 print("{} normalized images for {}".format(
                     len(normalized), original_blob_name))
             finally:
                 # Always drop the downloaded temporary file, even when
                 # normalization or the database write fails.
                 os.remove(original_blob_name)
     finally:
         # Close the environment on both the success and the error path.
         env.close()
     # Re-open and close the environment once more; the original author's
     # note says this is meant to compress the file.
     # NOTE(review): the environment is not opened with readonly=True here
     # -- confirm the intent.
     env2 = lmdb.open(root_path, max_dbs=1)
     env2.close()
Beispiel #4
0
class DatasetPreProcessor(chainer.dataset.DatasetMixin):
    """Chainer dataset yielding preprocessed (image, label map) examples.

    Image file names are read from ``args.image_pointer_path``; each image
    is paired with an ``*_L.npz`` label file in ``args.image_dir_path``.
    ``get_example`` resizes/crops, optionally augments and normalizes the
    image and returns it in channel-first order together with its label map.
    """

    def __init__(self, args):
        """Store the options and eagerly load all (path, label) pairs."""
        self.args = args
        self.image_normalizer = ImageNormalizer()
        self.pairs = self.read_paths()
        self.counter = 0
        self.image_size_in_batch = [None, None]  # height, width

    def __len__(self):
        """Return the number of (image path, label) pairs."""
        return len(self.pairs)

    def read_paths(self):
        """Return a list of ``(image_path, label_array)`` tuples."""
        return list(self.__path_label_pair_generator())

    def __path_label_pair_generator(self):
        """Yield ``(image_full_path, label_array)`` for each listed image.

        Raises:
            AssertionError: If the image or its label file does not exist.
        """
        with open(self.args.image_pointer_path, 'r') as f_image:
            for image_file_name in f_image:
                image_file_name = image_file_name.rstrip()
                image_full_path = os.path.join(self.args.image_dir_path,
                                               image_file_name)

                label_file_name = image_file_name.replace('.png', '_L.npz')
                label_full_path = os.path.join(self.args.image_dir_path,
                                               label_file_name)
                if not (os.path.isfile(image_full_path)
                        and os.path.isfile(label_full_path)):
                    # Raised explicitly (rather than ``assert False``) so
                    # the check is not stripped when running with ``-O``.
                    raise AssertionError(
                        "error occured at path_label_pair_generator(file is not fined)."
                    )
                # Close the archive once the array has been extracted.
                with np.load(label_full_path) as label_archive:
                    label = label_archive['data']
                yield image_full_path, label

    def __init_batch_counter(self):
        """Reset per-batch state once ``batch_size`` examples were served.

        Only applies when ``args.train`` is truthy.
        """
        if self.args.train and self.counter == self.args.training_params.batch_size:
            self.counter = 0
            self.image_size_in_batch = [None, None]

    def __set_image_size_in_batch(self, image):
        """Record the size of the first image of a batch (counter == 1)."""
        if self.counter == 1:
            resized_h, resized_w = image.shape[:2]
            self.image_size_in_batch = [resized_h, resized_w]

    def load_data(self, indices):
        """Return ``(images, labels)`` arrays for the given indices.

        Images are stacked as float32 and labels as int32.
        """
        examples = [self.get_example(idx) for idx in indices]
        xs = [example[0] for example in examples]
        ys = [example[1] for example in examples]
        return np.array(xs, dtype=np.float32), np.array(ys, np.int32),

    def get_example(self, index):
        """Load, resize, augment and normalize the example at ``index``.

        Returns:
            Tuple ``(image, gt)``: ``image`` is float32 in channel-first
            order and ``gt`` is the int32 label map.

        Raises:
            AssertionError: In debug mode, after more than 15 examples.
            RuntimeError: If the image could not be read.
        """
        self.counter += 1
        if self.args.debug_mode and self.counter > 15:
            # Deliberate hard stop for debugging; raised explicitly so it
            # still fires when Python runs with ``-O``.
            raise AssertionError('stop test')

        path, gt = self.pairs[index]
        image = io.imread(path)

        if image is None:
            raise RuntimeError("invalid image: {}".format(path))

        if self.args.debug_mode:
            plt.figure()
            io.imshow(image)
            plt.show()

        # The values are unused, but the unpacking rejects arrays that are
        # not 3-dimensional.
        h, w, ch = image.shape

        # First pass: resize_image derives the scale from
        # chainer.config.user_multiple; crop gt with the same offsets.
        image, ms, ds = self.resize_image(image)
        gt = gt[ds[0]:(gt.shape[0] - ms[0] + ds[0]),
                ds[1]:(gt.shape[1] - ms[1] + ds[1])]

        # Second pass: shrink the image to a quarter and subsample gt by
        # taking every 4th row/column, then apply the same crop.
        image, ms, ds = self.resize_image(image, 0.25)
        gt = gt[::4, ::4]
        gt = gt[ds[0]:(gt.shape[0] - ms[0] + ds[0]),
                ds[1]:(gt.shape[1] - ms[1] + ds[1])]

        if self.args.debug_mode:
            plt.figure()
            io.imshow(image)
            plt.show()

            color_data = pd.read_csv(self.args.label_path)
            restoration([gt], color_data, './', index, 32)

        # augmentation
        if self.args.aug_params.do_augment:
            image, gt = self.augment_image(image, gt)

        # Keep the data size uniform within each batch.
        self.__set_image_size_in_batch(image)

        # Normalize the image.
        image = self.image_normalizer.GCN(image)

        if self.args.debug_mode:
            show_image = image.astype(np.uint8)
            plt.figure()
            io.imshow(show_image)
            plt.show()

        # Reorder the axes to channel-first to match Chainer's input layout.
        image = image.transpose(2, 0, 1)
        # Reset the batch counter when a batch is complete.
        self.__init_batch_counter()

        batch_inputs = image.astype(np.float32), np.array(gt, dtype=np.int32)
        return batch_inputs

    def augment_image(self, image, gt):
        """Apply the augmentations enabled in ``args.aug_params.params``.

        Scaling is only attempted for the first example of a batch
        (``self.counter == 1``); flip, brightness and contrast are applied
        whenever their flags are set.

        Returns:
            Tuple ``(image, gt)`` after augmentation.
        """
        orig_h, orig_w, _ = image.shape
        if self.args.aug_params.params.do_scale and self.counter == 1:
            image, ms, ds = self.scaling(image)
            h, w = image.shape[:2]
            # Subsample gt by the integer shrink factor of the image.
            # NOTE(review): a scale > 1 yields a step of 0 here, and the
            # crop offsets (ms, ds) are not applied to gt -- confirm.
            inv_scale = orig_h // h, orig_w // w
            gt = gt[::inv_scale[0], ::inv_scale[1]]

        if self.args.aug_params.params.do_flip:
            image, gt = self.flip(image, gt)

        if self.args.aug_params.params.change_britghtness:
            image = self.random_brightness(image)

        if self.args.aug_params.params.change_contrast:
            image = self.random_contrast(image)

        return image, gt

    def resize_image(self, image, scale=None):
        """Resize ``image`` and crop it to a multiple of ``user_multiple``.

        Args:
            image: Array whose first two axes are height and width.
            scale: ``None`` to derive per-axis scales that bring each side
                to a multiple of ``chainer.config.user_multiple``, a number
                applied to both axes, or a ``(h_scale, w_scale)`` pair.

        Returns:
            Tuple ``(image, (m0, m1), (d0, d1))`` where ``m*`` are the
            remainders removed per axis and ``d*`` the random crop offsets.
            A 2-D result is reshaped to ``(H, W, 1)``.
        """
        xh, xw = image.shape[:2]

        if scale is None:
            # Define the scale.
            h_scale = (xh // chainer.config.user_multiple
                       ) * chainer.config.user_multiple / xh
            w_scale = (xw // chainer.config.user_multiple
                       ) * chainer.config.user_multiple / xw
            scale = h_scale, w_scale
        elif isinstance(scale, numbers.Number):
            scale = scale, scale
        elif isinstance(scale, tuple) and len(scale) > 2:
            raise InvalidArgumentError

        new_sz = (int(xh * scale[0]), int(xw * scale[1]))
        image = transform.resize(image, new_sz, mode='constant')

        xh, xw = image.shape[:2]
        m0, m1 = xh % chainer.config.user_multiple, xw % chainer.config.user_multiple
        # Random offset so the crop window is not always anchored top-left.
        d0, d1 = np.random.randint(m0 + 1), np.random.randint(m1 + 1)

        image = image[d0:(image.shape[0] - m0 + d0),
                      d1:(image.shape[1] - m1 + d1)]

        if len(image.shape) == 2:
            # Add an explicit channel axis; the crop info is still returned
            # so that callers can always unpack three values.
            image = image.reshape((image.shape[0], image.shape[1], 1))
        return image, (m0, m1), (d0, d1)

    def flip(self, image, gt):
        """Randomly flip ``image`` and ``gt`` along both, one or no axis."""
        do_flip_xy = np.random.randint(0, 2)
        do_flip_x = np.random.randint(0, 2)
        do_flip_y = np.random.randint(0, 2)

        if do_flip_xy:  # reverse both the first and the second axis
            image = image[::-1, ::-1, :]
            gt = gt[::-1, ::-1]
        elif do_flip_x:  # reverse the first axis only
            image = image[::-1, :, :]
            gt = gt[::-1, :]
        elif do_flip_y:  # reverse the second axis only
            image = image[:, ::-1, :]
            gt = gt[:, ::-1]
        return image, gt,

    def scaling(self, image):
        """Resize with a random entry of ``aug_params.params.scale``.

        With probability 1/2 the image is returned unchanged.

        Returns:
            The result of ``resize_image`` or ``(image, None, None)``.
        """
        do_scale = np.random.randint(0, 2)
        if do_scale:
            candidates = self.args.aug_params.params.scale
            scale = candidates[np.random.randint(0, len(candidates))]
            return self.resize_image(image, scale)
        else:
            return image, None, None

    def random_brightness(self, image, max_delta=63, seed=None):
        """Add a uniform offset in ``[-max_delta, max_delta)`` half the time.

        ``seed`` is accepted for interface compatibility and is unused.
        """
        brightness_flag = np.random.randint(0, 2)
        if brightness_flag:
            delta = np.random.uniform(-max_delta, max_delta)
            return image + delta
        else:
            return image

    def random_contrast(self, image, lower=0.2, upper=1.8, seed=None):
        """Scale the deviation from the per-pixel channel mean half the time.

        The contrast factor is drawn uniformly from ``[lower, upper)``.
        ``seed`` is accepted for interface compatibility and is unused.
        The adjusted image is cast to uint8.
        """
        contrast_flag = np.random.randint(0, 2)
        if contrast_flag:
            # Sample from [lower, upper); a negative lower bound would allow
            # negative factors, which invert the image around its mean.
            factor = np.random.uniform(lower, upper)
            im_mean = image.mean(axis=2)
            return ((image.transpose(2, 0, 1) - im_mean) * factor +
                    im_mean).transpose(1, 2, 0).astype(np.uint8)
        else:
            return image