Code example #1
File: evaluate.py — Project: coderzbx/chainer-pspnet
(Note: this first example is a truncated excerpt — it begins mid-function and is cut off inside a loop.)
        n_blocks = [3, 4, 6, 3]
        feat_size = 60
        mid_stride = False
        param_fn = 'weights/pspnet101_ADE20K_473_reference.chainer'
        base_size = 512
        crop_size = 473

    dataset = TransformDataset(dataset, preprocess)
    print('dataset:', len(dataset))

    chainer.config.train = False
    model = PSPNet(n_class, n_blocks, feat_size, mid_stride=mid_stride)
    serializers.load_npz(param_fn, model)
    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu)
        model.to_gpu(args.gpu)

    for i in tqdm(range(args.start_i, args.end_i + 1)):
        img = dataset[i]
        out_fn = os.path.join(args.out_dir,
                              os.path.basename(dataset._dataset.img_fns[i]))
        pred = inference(model, n_class, base_size, crop_size, img,
                         args.scales)
        assert pred.ndim == 2

        if args.model == 'Cityscapes':
            if args.color_out_dir is not None:
                color_out = np.zeros((pred.shape[0], pred.shape[1], 3),
                                     dtype=np.uint8)
            label_out = np.zeros_like(pred)
            for label in cityscapes_labels:
Code example #2
class ModelPSPNet:
    """Serve semantic segmentation with a pretrained PSPNet-101.

    The ``model`` name string ('VOC', 'Cityscapes', or 'ADE20K') selects
    the network configuration and weight file.  :meth:`do` runs inference
    on raw encoded image bytes and returns the colorized prediction as
    PNG bytes.
    """

    def __init__(self, model):
        # Global chainer flags consumed by the inference helpers.
        chainer.config.stride_rate = float(2.0 / 3.0)
        chainer.config.save_test_image = False
        chainer.config.train = False

        self.scales = None          # no multi-scale inference by default
        self.gpu = 0                # GPU id; >= 0 moves the net to GPU
        self.model = model          # dataset/model name string
        self.img_fn = 'test.jpg'    # kept for compatibility (unused here)

        # Defaults are the Cityscapes configuration; an unrecognized
        # model name therefore falls back to Cityscapes, matching the
        # original behavior.
        self.n_class = 19
        self.n_blocks = [3, 4, 23, 3]
        self.feat_size = 90
        self.mid_stride = True
        self.param_fn = 'weights/pspnet101_cityscapes_713_reference.chainer'
        self.base_size = 2048
        self.crop_size = 713
        self.labels = cityscapes_label_names
        self.colors = cityscapes_label_colors

        if self.model == 'VOC':
            self.n_class = 21
            self.n_blocks = [3, 4, 23, 3]
            self.feat_size = 60
            self.mid_stride = True
            self.param_fn = 'weights/pspnet101_VOC2012_473_reference.chainer'
            self.base_size = 512
            self.crop_size = 473
            self.labels = voc_semantic_segmentation_label_names
            self.colors = voc_semantic_segmentation_label_colors
        elif self.model == 'ADE20K':
            self.n_class = 150
            self.n_blocks = [3, 4, 6, 3]
            self.feat_size = 60
            self.mid_stride = False
            self.param_fn = 'weights/pspnet101_ADE20K_473_reference.chainer'
            self.base_size = 512
            self.crop_size = 473
            # NOTE(review): ADE20K deliberately keeps the Cityscapes
            # labels/colors set above -- the original code did the same;
            # confirm this is intended before serving ADE20K results.

        # Build the network and load the pretrained weights.
        self.net = PSPNet(self.n_class,
                          self.n_blocks,
                          self.feat_size,
                          mid_stride=self.mid_stride)
        serializers.load_npz(self.param_fn, self.net)
        if self.gpu >= 0:
            chainer.cuda.get_device_from_id(self.gpu).use()
            self.net.to_gpu(self.gpu)

    @staticmethod
    def load_image(image_data):
        """Decode encoded image bytes into a float32 CHW array.

        Args:
            image_data (bytes): contents of an encoded image file.

        Returns:
            numpy.ndarray: float32 array of shape (C, H, W); grayscale
            input comes back as (1, H, W).
        """
        import io

        # BUGFIX: the original called StringIO.BytesIO, which does not
        # exist (the StringIO module only defines StringIO).  io.BytesIO
        # is the correct in-memory binary buffer for Image.open.
        image = Image.open(io.BytesIO(image_data))
        if image.mode not in ('L', 'RGB'):
            image = image.convert('RGB')

        image = np.asarray(image, dtype=np.float32)

        if image.ndim == 2:
            # grayscale (H, W) -> (1, H, W)
            return image[np.newaxis]
        # color (H, W, C) -> (C, H, W)
        return image.transpose(2, 0, 1)

    def do(self, image_data):
        """Run segmentation on raw image bytes.

        Args:
            image_data (bytes): contents of an encoded image file.

        Returns:
            bytes: PNG-encoded color label image.
        """
        import io

        img = preprocess(self.load_image(image_data=image_data))

        # Inference over the preprocessed CHW image.
        pred = inference(self.net, self.n_class, self.base_size,
                         self.crop_size, img, self.scales)

        # Colorize: paint each predicted trainId with its Cityscapes
        # color.  NOTE(review): the Cityscapes palette is used regardless
        # of self.model -- presumably only Cityscapes is served here;
        # verify before enabling other models.
        label_out = np.zeros((img.shape[1], img.shape[2], 3), dtype=np.uint8)
        for label in cityscapes_labels:
            label_out[np.where(pred == label.trainId)] = label.color

        # Encode the color image to PNG in memory.
        # BUGFIX: io.BytesIO replaces the nonexistent StringIO.BytesIO.
        buf = io.BytesIO()
        pred_img = Image.fromarray(label_out)
        pred_img.save(buf, format="PNG")
        pred_data = buf.getvalue()

        print('finish')
        return pred_data