Пример #1
0
    def batch_gen(self, path, batch_size, crop_size, image_size=(2448, 2448)):
        """Yield an endless stream of randomly cropped (content, mask) batches.

        Every batch uses ``batch_size`` distinct indices chosen with
        ``random.sample`` over all files returned by ``get_file_paths(path)``.
        All contents and masks in a batch are cropped at one shared random
        offset. This keeps content and mask crops aligned, assuming
        ``get_file_paths`` returns index-aligned lists (TODO confirm).

        Args:
            path: Dataset location passed to ``get_file_paths``.
            batch_size: Number of samples per batch.
            crop_size: ``(height, width)`` of the crop window.
            image_size: ``(height, width)`` of the source images. The random
                offset is drawn so that the crop fits inside this size.
                Defaults to the previously hard-coded 2448x2448.

        Yields:
            ``(contents, masks)``: ``contents`` is a float32 array and
            ``masks`` is a uint8 array, both built with ``np.asarray`` from
            the per-sample crops.

        Raises:
            ValueError: If ``crop_size`` is larger than ``image_size``, or
                (from ``random.sample``) if the dataset has fewer than
                ``batch_size`` samples.
        """
        content_path, mask_path = get_file_paths(path)

        max_h = image_size[0] - crop_size[0]
        max_w = image_size[1] - crop_size[1]
        if max_h < 0 or max_w < 0:
            # Previously random.randint raised inside the try below. The error
            # was printed and swallowed, so the generator spun forever.
            raise ValueError('crop_size {} does not fit in image_size {}'.format(
                crop_size, image_size))

        while True:
            # Sample from the full index range; range(1, ...) silently
            # excluded the first file from every batch.
            index = random.sample(range(len(content_path)), batch_size)
            offset = (random.randint(0, max_h), random.randint(0, max_w))
            try:
                contents = [
                    vgg_sub_mean(
                        random_crop(get_image(content_path[i]), offset,
                                    crop_size)) for i in index
                ]
                masks = [
                    mask_preprocess(
                        random_crop(get_image(mask_path[i]), offset,
                                    crop_size)) for i in index
                ]

                contents = np.asarray(contents, dtype=np.float32)
                masks = np.asarray(masks, dtype=np.uint8)

            except Exception as err:
                # Best effort: skip a batch containing an unreadable or
                # malformed sample and draw a new one.
                print("\nError: {}".format(err))
                continue

            yield contents, masks
Пример #2
0
    def eval(self, path='dataset/train/998002_sat.jpg'):
        """Build the model and decode a single image.

        NOTE(review): no checkpoint is loaded here (the original
        ``self.model.load(...)`` call was commented out). Confirm whether
        ``build_model()`` alone is enough to get trained weights.

        Args:
            path: Image file to decode. Defaults to the previously hard-coded
                sample so that existing ``eval()`` calls behave the same.

        Returns:
            The value of ``self.model.decode`` for a one-image batch. The
            original code computed this and then discarded it.
        """
        self.model.build_model()

        # Preprocess with vgg_sub_mean, then add a leading batch axis of size 1.
        img = np.expand_dims(vgg_sub_mean(get_image(path)), axis=0)
        return self.model.decode(img)
Пример #3
0
    def eval(self, size=(612, 612)):
        """Decode every test image and write its predicted mask as a PNG.

        Inputs come from ``get_file_paths(self.test_path)``, processed in
        sorted order. Each output is written to ``self.output_path`` as
        ``<fileID>_mask.png``, where fileID is the file name up to the first
        underscore.

        Args:
            size: Target size passed to ``image_resize`` for every mask.
                Defaults to the previously hard-coded (612, 612).
        """
        self.model.build_model()
        content_path, _ = get_file_paths(self.test_path)
        content_path.sort()

        for path in content_path:
            # basename understands the platform's separator; split('/') broke
            # on Windows-style paths.
            file_id = os.path.basename(path).split('_')[0]
            output_name = os.path.join(self.output_path,
                                       '{}_mask.png'.format(file_id))
            print(output_name)

            # Add a leading batch axis of size 1; decode's first item is this image.
            img = np.expand_dims(vgg_sub_mean(get_image(path)), axis=0)
            reconst_mask = mask_postprocess(self.model.decode(img)[0])
            reconst_mask = image_resize(reconst_mask, size=size)
            skimage.io.imsave(output_name, reconst_mask)
Пример #4
0
def batch_gen(dir, batch_size):
    """Yield an endless stream of random (content, mask) batches.

    Args:
        dir: Dataset location passed to ``get_file_paths``. This name shadows
            the builtin ``dir``; it is kept for keyword-argument compatibility.
        batch_size: Number of distinct samples per batch.

    Yields:
        ``(contents, masks)``, both float32 arrays built with ``np.asarray``.

    Raises:
        ValueError: From ``random.sample``, if the dataset has fewer than
            ``batch_size`` samples.
    """
    content_path, mask_path = get_file_paths(dir)
    # Sort both lists independently so that content_path[i] and mask_path[i]
    # presumably name the same sample (assumes matching file-name prefixes;
    # TODO confirm).
    content_path.sort()
    mask_path.sort()

    while True:
        # Draw from the full index range; range(1, ...) never picked index 0.
        index = random.sample(range(len(content_path)), batch_size)
        try:
            contents = [
                vgg_sub_mean(get_image(content_path[i])) for i in index
            ]
            masks = [mask_preprocess(get_image(mask_path[i])) for i in index]

            contents = np.asarray(contents, dtype=np.float32)
            masks = np.asarray(masks, dtype=np.float32)

        except Exception as err:
            # Best effort: skip a batch containing a bad sample and retry.
            print("\nError: {}".format(err))
            continue

        yield contents, masks
Пример #5
0
    def eval(self, output_mode):
        """Decode every test image and save the predicted masks.

        Args:
            output_mode: Selects how masks are saved.
                'img': writes one mask per input file, resized to 612x612, as
                ``<fileID>_mask.png``. fileID is the file name up to the
                first '_'.
                'npz': treats each run of 169 consecutive (sorted) inputs as
                a 13x13 grid of overlapping tiles of one image. Each tile's
                decoder output is zoomed by 612/512, the tiles are stitched
                with ``merge``, and the result is written as
                ``<fileID>_mask.png``. fileID is the file name up to the
                first '-'. Despite the mode's name, the output is a ``.png``
                saved through PIL.

        Raises:
            ValueError: If ``output_mode`` is neither 'img' nor 'npz'.
        """
        def merge(group, size=(612, 612), num_per_side=13):
            """Stitch tile predictions into a 2448x2448x7 array and post-process it.

            Tile ``idx`` goes to grid row ``idx // num_per_side`` and column
            ``idx % num_per_side``. Its pixel offset is ``row * size // 4``
            (and likewise for the column), so neighbouring tiles overlap and
            their values are summed. Returns ``mask_postprocess`` of the sum
            as a PIL image.
            """
            np_mask = np.zeros((2448, 2448, 7))
            for idx, chunk in enumerate(group):
                # Same arithmetic as before: ((idx // n) * size) // 4.
                offset_x = (idx // num_per_side) * size[0] // 4
                offset_y = (idx % num_per_side) * size[1] // 4
                np_mask[offset_x:offset_x + size[0],
                        offset_y:offset_y + size[1], :] += chunk
            return Image.fromarray(mask_postprocess(np_mask))

        if output_mode not in ('img', 'npz'):
            # Fail fast instead of decoding every image and writing nothing.
            raise ValueError(
                "output_mode must be 'img' or 'npz', got {!r}".format(output_mode))

        self.model.build_model()
        content_path, _ = get_file_paths(self.test_path)
        content_path.sort()
        total = len(content_path)

        tiles_per_image = 169  # 13 x 13 tiles per stitched image (see merge)
        group = []
        output_name = None

        for cnt, path in enumerate(content_path):
            img = np.expand_dims(vgg_sub_mean(get_image(path)), axis=0)
            reconst_mask = self.model.decode(img)
            # basename understands the platform's separator; split('/') did not.
            file_name = os.path.basename(path)

            if output_mode == 'img':
                file_id = file_name.split('_')[0]
                output_name = os.path.join(self.output_path,
                                           '{}_mask.png'.format(file_id))
                # 1-based progress counter, matching the 'npz' branch.
                print("({}/{}) {}".format(cnt + 1, total, output_name))

                reconst_mask = mask_postprocess(reconst_mask[0])
                reconst_mask = image_resize(reconst_mask, size=(612, 612))
                skimage.io.imsave(output_name, reconst_mask)
            else:  # 'npz'
                # The first tile of a new group means the previous group is
                # complete; output_name still holds that group's name.
                if cnt % tiles_per_image == 0 and cnt != 0:
                    print('Saving {}'.format(output_name))
                    merge(group).save(output_name)
                    group = []

                file_id = file_name.split('-')[0]
                output_name = os.path.join(self.output_path,
                                           '{}_mask.png'.format(file_id))
                print("({}/{}) {}".format(cnt + 1, total, output_name))

                # Zoom spatial axes by 612/512 and leave channels unchanged
                # (presumably decoder tiles are 512 px per side; confirm).
                scale = 612. / 512.
                group.append(scipy.ndimage.zoom(
                    reconst_mask[0], zoom=(scale, scale, 1), mode='reflect'))

        # Save the last (possibly partial) group. Skip it when there were no
        # inputs; before, output_name was undefined there and raised NameError.
        if output_mode == 'npz' and group:
            print('Saving {}'.format(output_name))
            merge(group).save(output_name)
Пример #6
0
if __name__ == '__main__':
    import os

    test_path = args.input_path
    output_path = args.output_path
    model_path = args.model_path

    # Pick the model class from args.net.
    if args.net == 'fcn32':
        model = VGG_FCN32(mode='test', model_path=model_path)
    elif args.net == 'fcn8':
        model = VGG_FCN8(mode='test', model_path=model_path)
    else:
        # Previously fell through and later crashed with NameError on `model`.
        raise ValueError('Unsupported net: {!r}'.format(args.net))

    content_path = sorted(glob.glob(os.path.join(test_path, '*.jpg')))

    for i, path in enumerate(content_path):
        # os.path.join inserts the separator. Plain concatenation produced
        # e.g. "out0000_mask.png" when output_path had no trailing slash.
        output_name = os.path.join(output_path, '{:04d}_mask.png'.format(i))

        # Add a leading batch axis of size 1; decode's first item is this image.
        img = np.expand_dims(vgg_sub_mean(get_image(path)), axis=0)
        reconstructed_mask = mask_postprocess(model.decode(img)[0])
        skimage.io.imsave(output_name, reconstructed_mask)