Пример #1
0
    def test_step(self, batchdata, **kwargs):
        """Run one evaluation step on a batch of low-resolution frames.

        The frames are transformed with ``ensemble_forward`` for the variant
        selected by ``epoch``. They are then padded via ``img_multi_padding``
        (``eval_cfg['padding_multi']``, pad value -0.5) and fed to the
        generator. Finally the result is cropped back with
        ``img_de_multi_padding`` to ``scale`` times the pre-padding size.

        Args:
            batchdata: list built by the Collect pipeline step;
                ``batchdata[0]`` holds the input frames (commented as
                [B, N, C, H, W]; the shape is not checked here).
            **kwargs:
                epoch (int): ensemble variant, passed as ``Type`` to
                    ``ensemble_forward`` / ``ensemble_back``. Default 0.
                save_image: if truthy, write every back-ensembled output
                    image to ``save_path``.
                save_path (str): output directory (required with save_image).
                sample_id (int): id of the first sample of the batch, used in
                    the output file names (required with save_image).

        Returns:
            list: single-element list holding the de-padded generator output.
            NOTE(review): this is the output *before* ``ensemble_back``;
            presumably the caller undoes the ensemble transform — confirm.

        Raises:
            RuntimeError: if ``save_image`` is set but ``save_path`` or
                ``sample_id`` is missing.
        """
        ensemble_type = kwargs.get('epoch', 0)
        frames = ensemble_forward(batchdata[0], Type=ensemble_type)  # [B,N,C,H,W]

        # Size before padding, needed to crop the HR result afterwards.
        in_h, in_w = frames.shape[-2], frames.shape[-1]
        upscale = getattr(self.generator, 'upscale_factor', 4)
        frames = img_multi_padding(
            frames,
            padding_multi=self.eval_cfg.get('padding_multi', 1),
            pad_value=-0.5,
        )  # [B,N,C,H',W']

        # Second generator input: presumably a bicubic upsampling of the
        # middle frame (from the helper's name) — confirm.
        hr = test_generator_batch(frames,
                                  get_mid_bicubic(frames),
                                  netG=self.generator)  # HR [B,C,4H,4W]
        hr = img_de_multi_padding(hr,
                                  origin_H=in_h * upscale,
                                  origin_W=in_w * upscale)

        # back ensemble, used only for the images written to disk below
        restored = ensemble_back(hr, Type=ensemble_type)

        if kwargs.get('save_image'):
            save_path = kwargs.get('save_path', None)
            start_id = kwargs.get('sample_id', None)
            if save_path is None or start_id is None:
                raise RuntimeError(
                    "if save image in test_step, please set 'save_path' and 'sample_id' parameters"
                )
            for offset in range(restored.shape[0]):
                sample = start_id + offset
                # Variant 0 keeps the plain name; other variants are tagged.
                if ensemble_type == 0:
                    file_name = "idx_{}.png".format(sample)
                else:
                    file_name = "idx_{}_epoch_{}.png".format(sample,
                                                             ensemble_type)
                imwrite(tensor2img(restored[offset], min_max=(-0.5, 0.5)),
                        file_path=os.path.join(save_path, file_name))

        return [hr]
Пример #2
0
    def test_step(self, batchdata, **kwargs):
        """Buffer one batch of frames and super-resolve the clip once complete.

        The padded low-quality frames are accumulated across calls in
        ``self.LR_list``. Ground-truth frames, when present, go one frame at a
        time into ``self.HR_list``. Both buffers are reset when the batch
        starts at frame id 0.

        When the batch ends at frame id 99, the buffered clip is concatenated.
        It is then forwarded through the generator once per ensemble variant
        (``Type=0..7``), and the back-transformed results are averaged. The
        average is de-padded and stored in ``self.HR_G``.

        Args:
            batchdata (dict):
                ``'lq'``: low-quality frames (commented as [B,3,h,w]).
                ``'lq_path'``: length-1 list whose only element is the list of
                    frame paths of this batch.
                ``'gt'`` (optional): ground truth (commented as [B,3,4h,4w]).
            **kwargs:
                save_image: if truthy, write the finished clip to disk when
                    the last frame has been processed.
                save_path (str): output root. Every frame is written to
                    ``allframes/`` and every 10th frame also to
                    ``partframes/``. Required with ``save_image``.
                (``ensemble`` was listed in older docs but is not read here.)

        Returns:
            bool: ``now_end_id == 99``, i.e. whether this batch finished the
            clip and triggered the forward pass.

        Raises:
            NotImplementedError: if ``eval_cfg.gap`` is not 1.
            RuntimeError: if ``save_image`` is set but ``save_path`` is not.
        """
        lq = batchdata['lq']  # [B,3,h,w]
        gt = batchdata.get('gt', None)  # if not None: [B,3,4*h,4*w]
        # Each sample carries a length-1 lq_path list (itself); its single
        # element is the list of frame paths of the whole batch.
        assert len(batchdata['lq_path']) == 1
        lq_paths = batchdata['lq_path'][0]  # length == batch size
        now_start_id, clip = self.get_img_id(lq_paths[0])
        now_end_id, end_clip = self.get_img_id(lq_paths[-1])
        # The first and last frame of the batch must come from the same clip.
        assert clip == end_clip

        if now_start_id == 0:
            # A new clip begins: drop whatever was buffered before.
            print("first frame: {}".format(lq_paths[0]))
            self.LR_list = []
            self.HR_list = []

        # Pad lq (edge mode); keep the original size for de-padding later.
        B, _, origin_H, origin_W = lq.shape
        lq = img_multi_padding(lq,
                               padding_multi=self.eval_cfg.multi_pad,
                               pad_method="edge")  # edge  constant
        self.LR_list.append(lq)

        if gt is not None:
            for i in range(B):
                self.HR_list.append(gt[i:i + 1, ...])

        # Frame id 99 is treated as the last frame of a clip (100-frame
        # clips; see the T == 100 check when saving).
        is_last = now_end_id == 99
        if not is_last:
            return is_last

        print("start to forward all frames....")
        gap = self.eval_cfg.gap
        if gap == 2:
            # TODO: gap == 2 (forward even/odd frames separately, then
            # interleave the two results) is not implemented.
            raise NotImplementedError("not implement gap != 1 now")
        if gap != 1:
            raise NotImplementedError("do not support eval&test gap value")

        self.LR_list = np.concatenate(self.LR_list, axis=0)  # [100, 3,h,w]
        # Ensemble over 8 variants. A running sum (added in variant order)
        # replaces a list of all 8 full-clip HR predictions. It yields the
        # same average while keeping only a couple of them alive at once.
        num_variants = 8
        hr_sum = None
        for item in tqdm(range(num_variants)):
            inp = mge.tensor(ensemble_forward(self.LR_list, Type=item),
                             dtype="float32")
            oup = test_generator_batch(F.expand_dims(inp, axis=0),
                                       netG=self.generator)
            restored = ensemble_back(oup.numpy(), Type=item)
            hr_sum = restored if hr_sum is None else hr_sum + restored
        self.HR_G = hr_sum / num_variants  # average of the ensemble results

        scale = self.generator.upscale_factor
        self.HR_G = img_de_multi_padding(
            self.HR_G,
            origin_H=origin_H * scale,
            origin_W=origin_W * scale)  # depad for HR_G   [B,T,C,H,W]

        if kwargs.get('save_image', False):
            print("saving images to disk ...")
            save_path = kwargs.get('save_path')
            if save_path is None:
                # Fail clearly instead of a TypeError from os.path.join(None).
                raise RuntimeError(
                    "if save image in test_step, please set 'save_path' parameter"
                )
            batch, num_frames, _, _, _ = self.HR_G.shape
            assert batch == 1
            assert num_frames == 100
            for i in range(num_frames):
                img = tensor2img(self.HR_G[0, i, ...], min_max=(0, 1))
                file_name = f"{clip}_{str(i).zfill(8)}.png"
                if (i + 1) % 10 == 0:
                    imwrite(img,
                            file_path=os.path.join(save_path, "partframes",
                                                   file_name))
                imwrite(img,
                        file_path=os.path.join(save_path, "allframes",
                                               file_name))

        return is_last
Пример #3
0
    def test_step(self, batchdata, **kwargs):
        """Run one recurrent evaluation step, carrying a hidden state.

        The method needs to know whether the current input is the first frame
        of a video. When ``batchdata[1][0] > 0.5`` the hidden state
        ``self.pre_SD`` is reset to zeros sized to the padded input, and the
        step counter ``self.now_test_num`` restarts at 1. Otherwise the hidden
        state returned by the previous call is fed back in.

        Args:
            batchdata: list built by the Collect pipeline step.
                ``batchdata[0]``: input frames (commented as [B,T,C,H,W]);
                only batch size 1 is supported.
                ``batchdata[1]``: 1-D first-frame flag array.
            **kwargs:
                epoch (int): ensemble variant, passed as ``Type`` to
                    ``ensemble_forward`` / ``ensemble_back``. Default 0.
                save_image: if truthy, write the back-ensembled output images
                    to ``save_path``.
                save_path (str): output directory (required with save_image).
                sample_id (int): id of the first sample, used in file names
                    (required with save_image).

        Returns:
            list: ``[output, hidden_state]`` as numpy arrays. ``output`` is
            de-padded but taken *before* ``ensemble_back``. The hidden state
            is also stored in ``self.pre_SD`` for the next call.

        Raises:
            RuntimeError: if ``save_image`` is set but ``save_path`` or
                ``sample_id`` is missing.
        """
        ensemble_type = kwargs.get('epoch', 0)
        frames = ensemble_forward(batchdata[0], Type=ensemble_type)  # [B,T,C,H,W]

        # Size before padding, needed to crop the HR result afterwards.
        in_h, in_w = frames.shape[-2], frames.shape[-1]
        upscale = getattr(self.generator, 'upscale_factor', 4)
        frames = img_multi_padding(
            frames,
            padding_multi=self.eval_cfg.get('padding_multi', 1),
            pad_value=-0.5,
        )  # [B,T,C,H',W']

        assert frames.shape[0] == 1  # only support batchsize 1
        first_frame_flag = batchdata[1]
        assert len(first_frame_flag.shape) == 1  # first frame flag
        if first_frame_flag[0] > 0.5:
            # New video: reset the counter and zero the hidden state at the
            # padded resolution. ``hidden_channels`` is resolved outside this
            # method (presumably a module-level constant).
            print("first frame")
            self.now_test_num = 1
            batch, _, _, pad_h, pad_w = frames.shape
            print("use now_H : {} and now_W: {}".format(pad_h, pad_w))
            self.pre_SD = np.zeros((batch, hidden_channels, pad_h, pad_w),
                                   dtype=np.float32)

        raw = list(test_generator_batch(frames, self.pre_SD,
                                        netG=self.generator))
        # Only the image output is cropped; the hidden state stays padded.
        raw[0] = img_de_multi_padding(raw[0],
                                      origin_H=in_h * upscale,
                                      origin_W=in_w * upscale)
        outputs = [item.numpy() for item in raw]

        # update hidden state for the next call
        G, self.pre_SD = outputs

        # back ensemble, used only for the images written to disk below
        G = ensemble_back(G, Type=ensemble_type)

        if kwargs.get('save_image'):
            save_path = kwargs.get('save_path', None)
            start_id = kwargs.get('sample_id', None)
            if save_path is None or start_id is None:
                raise RuntimeError(
                    "if save image in test_step, please set 'save_path' and 'sample_id' parameters"
                )
            for offset in range(G.shape[0]):
                sample = start_id + offset
                # Variant 0 keeps the plain name; other variants are tagged.
                if ensemble_type == 0:
                    file_name = "idx_{}.png".format(sample)
                else:
                    file_name = "idx_{}_epoch_{}.png".format(sample,
                                                             ensemble_type)
                imwrite(tensor2img(G[offset], min_max=(-0.5, 0.5)),
                        file_path=os.path.join(save_path, file_name))

        print("now test num: {}".format(self.now_test_num))
        self.now_test_num += 1
        return outputs