def test_step(self, batchdata, **kwargs):
    """Run one test step: ensemble transform, pad, super-resolve, un-pad.

    Args:
        batchdata: sequence produced by the data pipeline. Only
            ``batchdata[0]`` is read here (annotated ``[B,N,C,H,W]`` in the
            original code -- TODO confirm against the dataset).
        **kwargs: optional settings.
            epoch: passed as ``Type`` to ``ensemble_forward`` and
                ``ensemble_back``; also embedded in the saved file name
                when it is not 0. Defaults to 0.
            save_image: when truthy, the restored images are written to disk.
            save_path: output directory, required when saving.
            sample_id: index of the first sample in this batch, required
                when saving.

    Returns:
        list: a one-element list holding the un-padded generator output,
        i.e. the value *before* ``ensemble_back`` is applied.

    Raises:
        RuntimeError: if ``save_image`` is set but ``save_path`` or
            ``sample_id`` is missing.
    """
    epoch = kwargs.get('epoch', 0)

    # The epoch value doubles as the ensemble transform selector.
    lr_frames = ensemble_forward(batchdata[0], Type=epoch)

    # Remember the spatial size before padding so the output can be cropped
    # back to the matching upscaled size afterwards.
    src_h = lr_frames.shape[-2]
    src_w = lr_frames.shape[-1]
    upscale = getattr(self.generator, 'upscale_factor', 4)

    pad_multiple = self.eval_cfg.get('padding_multi', 1)
    lr_frames = img_multi_padding(
        lr_frames, padding_multi=pad_multiple, pad_value=-0.5)

    output = test_generator_batch(
        lr_frames, get_mid_bicubic(lr_frames), netG=self.generator)
    output = img_de_multi_padding(
        output, origin_H=src_h * upscale, origin_W=src_w * upscale)

    # Undo the ensemble transform; only the saved images use this result.
    restored = ensemble_back(output, Type=epoch)

    if kwargs.get('save_image'):
        save_path = kwargs.get('save_path', None)
        start_id = kwargs.get('sample_id', None)
        if save_path is None or start_id is None:
            raise RuntimeError(
                "if save image in test_step, please set 'save_path' and 'sample_id' parameters"
            )
        for offset in range(restored.shape[0]):
            picture = tensor2img(restored[offset], min_max=(-0.5, 0.5))
            sample_no = start_id + offset
            # Epoch 0 keeps the plain name; other epochs are tagged.
            if epoch == 0:
                file_name = "idx_{}.png".format(sample_no)
            else:
                file_name = "idx_{}_epoch_{}.png".format(sample_no, epoch)
            imwrite(picture, file_path=os.path.join(save_path, file_name))

    return [output]
def test_step(self, batchdata, **kwargs):
    """Buffer the frames of one clip and super-resolve it once complete.

    Every call pads the low-quality frames of the batch and appends them to
    ``self.LR_list`` (ground truth, when present, goes to ``self.HR_list``).
    Both buffers are reset when the batch starts at frame id 0. When the
    batch ends at frame id 99, all buffered frames are forwarded through the
    generator under 8 ensemble transforms, the results are averaged into
    ``self.HR_G``, un-padded, and optionally written to disk.

    Args:
        batchdata (dict): reads ``'lq'`` (annotated ``[B,3,h,w]`` in the
            original code), ``'lq_path'`` (a one-element list holding the
            per-sample paths) and, optionally, ``'gt'``.
        **kwargs: optional settings.
            save_image: when truthy, write the restored frames to disk.
            save_path: output root directory; required when ``save_image``
                is set. Frames go to ``<save_path>/allframes`` and every
                10th frame additionally to ``<save_path>/partframes``.
            (The original docstring also listed ``ensemble``; it is not
            read by this method.)

    Returns:
        bool: True when this batch ended at frame id 99, i.e. the clip was
        forwarded during this call; False otherwise.

    Raises:
        RuntimeError: if ``save_image`` is set but ``save_path`` is missing.
        NotImplementedError: if ``self.eval_cfg.gap`` is not 1.
    """
    lq = batchdata['lq']  # [B,3,h,w]
    gt = batchdata.get('gt', None)  # if not None: [B,3,4*h,4*w]
    # Each sample carries an lq_path list of length 1 (only itself).
    assert len(batchdata['lq_path']) == 1
    lq_paths = batchdata['lq_path'][0]  # one entry per batch element
    now_start_id, clip = self.get_img_id(lq_paths[0])
    now_end_id, end_clip = self.get_img_id(lq_paths[-1])
    # First and last path of the batch must report the same clip.
    assert clip == end_clip

    if now_start_id == 0:
        # A new clip begins: drop whatever was buffered before.
        print("first frame: {}".format(lq_paths[0]))
        self.LR_list = []
        self.HR_list = []

    # Pad lq; the un-padded size is kept to crop the output back later.
    B, _, origin_H, origin_W = lq.shape
    lq = img_multi_padding(lq,
                           padding_multi=self.eval_cfg.multi_pad,
                           pad_method="edge")
    self.LR_list.append(lq)

    if gt is not None:
        for i in range(B):
            # Slice with i:i+1 so every stored item keeps its batch axis.
            self.HR_list.append(gt[i:i + 1, ...])

    if now_end_id == 99:
        save_image = kwargs.get('save_image', False)
        save_path = kwargs.get('save_path')
        # Fail before the expensive forward pass instead of crashing inside
        # os.path.join afterwards.
        if save_image and save_path is None:
            raise RuntimeError(
                "if save image in test_step, please set 'save_path' parameter"
            )

        print("start to forward all frames....")
        if self.eval_cfg.gap == 1:
            # Run the 8 ensemble transforms and average the results.
            ensemble_res = []
            self.LR_list = np.concatenate(self.LR_list, axis=0)  # [100,3,h,w]
            for item in tqdm(range(8)):
                inp = mge.tensor(ensemble_forward(self.LR_list, Type=item),
                                 dtype="float32")
                # Add a leading batch axis before feeding the generator.
                oup = test_generator_batch(F.expand_dims(inp, axis=0),
                                           netG=self.generator)
                ensemble_res.append(ensemble_back(oup.numpy(), Type=item))
            self.HR_G = sum(ensemble_res) / len(ensemble_res)
        elif self.eval_cfg.gap == 2:
            raise NotImplementedError("not implement gap != 1 now")
        else:
            raise NotImplementedError("do not support eval&test gap value")

        scale = self.generator.upscale_factor
        # Remove the padding from the restored clip, annotated [B,T,C,H,W].
        self.HR_G = img_de_multi_padding(self.HR_G,
                                         origin_H=origin_H * scale,
                                         origin_W=origin_W * scale)

        if save_image:
            print("saving images to disk ...")
            out_batch, out_frames, _, _, _ = self.HR_G.shape
            assert out_batch == 1
            assert out_frames == 100
            for i in range(out_frames):
                img = tensor2img(self.HR_G[0, i, ...], min_max=(0, 1))
                file_name = f"{clip}_{str(i).zfill(8)}.png"
                # Every 10th frame is additionally stored in "partframes".
                if (i + 1) % 10 == 0:
                    imwrite(img,
                            file_path=os.path.join(save_path, "partframes",
                                                   file_name))
                imwrite(img,
                        file_path=os.path.join(save_path, "allframes",
                                               file_name))

    return now_end_id == 99
def test_step(self, batchdata, **kwargs):
    """Run one recurrent test step on a video sample.

    The hidden state ``self.pre_SD`` is reset to zeros whenever the sample
    is flagged as the first frame of a video, and is overwritten with the
    generator's second output after every step.

    Args:
        batchdata: sequence produced by the data pipeline. ``batchdata[0]``
            is the input (annotated ``[B,T,C,H,W]`` in the original code)
            and ``batchdata[1]`` is a 1-D first-frame flag; a value above
            0.5 in its first entry marks the first frame.
        **kwargs: optional settings.
            epoch: passed as ``Type`` to ``ensemble_forward`` and
                ``ensemble_back``; also embedded in the saved file name
                when it is not 0. Defaults to 0.
            save_image: when truthy, the restored images are written to disk.
            save_path: output directory, required when saving.
            sample_id: index of the first sample in this batch, required
                when saving.

    Returns:
        list: the two generator outputs after ``.numpy()``; element 0 is
        the un-padded result (before ``ensemble_back``), element 1 is the
        new hidden state.

    Raises:
        RuntimeError: if ``save_image`` is set but ``save_path`` or
            ``sample_id`` is missing.
    """
    epoch = kwargs.get('epoch', 0)

    # The epoch value doubles as the ensemble transform selector.
    frames = ensemble_forward(batchdata[0], Type=epoch)

    # Size before padding, needed to crop the upscaled output afterwards.
    src_h = frames.shape[-2]
    src_w = frames.shape[-1]
    upscale = getattr(self.generator, 'upscale_factor', 4)

    pad_multiple = self.eval_cfg.get('padding_multi', 1)
    frames = img_multi_padding(
        frames, padding_multi=pad_multiple, pad_value=-0.5)

    assert frames.shape[0] == 1  # only batch size 1 is supported
    assert len(batchdata[1].shape) == 1  # first-frame flag must be 1-D

    if batchdata[1][0] > 0.5:
        # First frame of a video: restart the counter and the hidden state.
        print("first frame")
        self.now_test_num = 1
        batch, _, _, padded_h, padded_w = frames.shape
        print("use now_H : {} and now_W: {}".format(padded_h, padded_w))
        self.pre_SD = np.zeros(
            (batch, hidden_channels, padded_h, padded_w), dtype=np.float32)

    raw_outputs = list(
        test_generator_batch(frames, self.pre_SD, netG=self.generator))
    raw_outputs[0] = img_de_multi_padding(
        raw_outputs[0], origin_H=src_h * upscale, origin_W=src_w * upscale)
    outputs = [item.numpy() for item in raw_outputs]

    # Second output becomes the hidden state for the next step.
    restored, self.pre_SD = outputs

    # Undo the ensemble transform; only the saved images use this result.
    restored = ensemble_back(restored, Type=epoch)

    if kwargs.get('save_image'):
        save_path = kwargs.get('save_path', None)
        start_id = kwargs.get('sample_id', None)
        if save_path is None or start_id is None:
            raise RuntimeError(
                "if save image in test_step, please set 'save_path' and 'sample_id' parameters"
            )
        for offset in range(restored.shape[0]):
            picture = tensor2img(restored[offset], min_max=(-0.5, 0.5))
            sample_no = start_id + offset
            # Epoch 0 keeps the plain name; other epochs are tagged.
            if epoch == 0:
                file_name = "idx_{}.png".format(sample_no)
            else:
                file_name = "idx_{}_epoch_{}.png".format(sample_no, epoch)
            imwrite(picture, file_path=os.path.join(save_path, file_name))

    print("now test num: {}".format(self.now_test_num))
    self.now_test_num += 1
    return outputs