import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as tvs  # assumed alias for the ``tvs.resize`` calls below
import flow_vis

# ``float32_to_uint8`` and ``data_utils`` are repo-local helpers assumed to be
# in scope (a sketch of ``float32_to_uint8`` follows the first function).


def infer_sequence(self, lr_data, device):
    """
    Parameters:
        :param lr_data: torch.FloatTensor in shape tchw
        :param device: torch.device

        :return hr_seq: uint8 np.ndarray in shape thwc
    """
    # setup params
    tot_frm, c, h, w = lr_data.size()
    s = self.scale

    # forward
    self.eval()
    hr_seq = []
    lr_prev = torch.zeros(1, c, h, w, dtype=torch.float32).to(device)
    hr_prev = torch.zeros(
        1, self.out_nc, s * h, s * w, dtype=torch.float32).to(device)

    for i in range(tot_frm):
        with torch.no_grad():
            lr_curr = lr_data[i: i + 1, ...].to(device)
            hr_curr = self.forward(lr_curr, lr_prev, hr_prev)
            lr_prev, hr_prev = lr_curr, hr_curr

        hr_frm = hr_curr.squeeze(0).cpu().numpy()  # chw|rgb|float32
        hr_seq.append(float32_to_uint8(hr_frm))

    return np.stack(hr_seq).transpose(0, 2, 3, 1)  # thwc
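
# ``float32_to_uint8`` is used throughout but never defined in this section.
# A minimal sketch, assuming the usual [0, 1] float -> [0, 255] uint8 mapping:
def float32_to_uint8(inputs):
    """ Convert a float32 array in [0, 1] to a uint8 array in [0, 255]. """
    return np.uint8(np.clip(np.round(inputs * 255), 0, 255))
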
def infer_sequence(self, lr_data, device):
    """
    Parameters:
        :param lr_data: torch.FloatTensor in shape tchw
        :param device: torch.device

        :return hr_seq: uint8 np.ndarray in shape thwc
    """
    # setup params
    tot_frm, c, h, w = lr_data.size()
    p = self.depth // 2  # temporal radius of the sliding window

    # forward
    self.eval()
    hr_seq = []
    for i in range(p, tot_frm - p):
        with torch.no_grad():
            lr_seq = lr_data[i - p: i + p + 1, ...].to(device)
            hr_curr = self.forward(lr_seq)

        hr_frm = hr_curr.squeeze(0).cpu().numpy()  # chw|rgb|float32
        hr_seq.append(float32_to_uint8(hr_frm))

    return np.stack(hr_seq).transpose(0, 2, 3, 1)  # thwc
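
# A hedged usage sketch for the sliding-window variant above. ``model`` and
# the 4x / 3-channel / depth=3 configuration are placeholders, not values
# fixed by this code.
#
#   lr_data = torch.rand(10, 3, 64, 64)  # t=10 frames, tchw, values in [0, 1]
#   hr_seq = model.infer_sequence(lr_data, torch.device('cuda'))
#   # with depth = 3 (p = 1) the two border frames are skipped, so
#   # hr_seq.shape == (8, 256, 256, 3) and hr_seq.dtype == np.uint8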
def infer(self, lr_data):
    lr_data = data_utils.canonicalize(lr_data)  # to torch.FloatTensor, thwc
    _, h, w, _ = lr_data.size()

    # split into Y/Cb/Cr channels
    lr_yuv = data_utils.rgb2yCbCr(lr_data)
    lr_yuv = lr_yuv.permute(0, 3, 1, 2)  # thwc -> tchw
    lr_y = lr_yuv[:, 0:1, :, :]
    lr_u = lr_yuv[:, 1:2, :, :]
    lr_v = lr_yuv[:, 2:3, :, :]

    # dual-direction temporal padding
    lr_y_seq, n_pad_front = self.pad_sequence(lr_y)

    # infer the Y channel with the network; chroma is upscaled bicubically
    hr_y_seq = self.net_G.infer_sequence(lr_y_seq, self.device)  # thwc|uint8
    hr_y_seq = hr_y_seq[n_pad_front:]  # drop the padded front frames
    hr_y_seq = torch.from_numpy(
        hr_y_seq.transpose(0, 3, 1, 2).astype(np.float32) / 255.)  # back to tchw|float32 for cat
    hr_u_seq = tvs.resize(
        lr_u, [self.scale * h, self.scale * w],
        interpolation=3)  # PIL codes: bilinear=2 (default), bicubic=3
    hr_v_seq = tvs.resize(
        lr_v, [self.scale * h, self.scale * w],
        interpolation=3)

    # merge channels and convert back to RGB
    hr_yuv = torch.cat((hr_y_seq, hr_u_seq, hr_v_seq), dim=1)
    hr_yuv = hr_yuv.permute(0, 2, 3, 1)  # tchw -> thwc
    hr_rgb = data_utils.yCbCr2rgb(hr_yuv).numpy()
    hr_seq = data_utils.float32_to_uint8(hr_rgb)  # thwc|rgb|uint8

    return hr_seq
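
# ``pad_sequence`` is referenced above but not shown in this section. A
# minimal sketch, assuming the "temporal padding" reflects the opening frames
# to the front so the recurrent state can warm up before the first real
# frame; ``n_pad_front`` is a placeholder value, not one fixed by the repo:
def pad_sequence(self, lr_data, n_pad_front=5):
    # lr_data: tchw; prepend frames 1..n_pad_front in reversed order (reflection)
    pad = lr_data[1:1 + n_pad_front].flip(0)
    return torch.cat([pad, lr_data], dim=0), n_pad_front
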
def infer_sequence_generator(self, seq_gen, device):
    """
    Parameters:
        :param seq_gen: generator yielding dicts with keys 'lr', 'gt', 'frm_idx'
        :param device: torch.device

        :yield: dict with 'frm_idx' and uint8 np.ndarrays 'lr', 'gt', 'hr' in shape hwc
    """
    first_frame = True
    s = self.scale

    # forward
    self.eval()
    for item in seq_gen:
        with torch.no_grad():
            lr_curr = item['lr']
            gt_curr = item['gt']
            frm_idx = item['frm_idx']

            # initialize the recurrent states from the first frame
            if first_frame:
                _, c, h, w = lr_curr.size()
                first_frame = False
                lr_prev = torch.zeros(
                    1, c, h, w, dtype=torch.float32).to(device)
                hr_prev = torch.zeros(
                    1, self.out_nc, s * h, s * w,
                    dtype=torch.float32).to(device)

            lr_curr = lr_curr.to(device)
            hr_curr = self.forward(lr_curr, lr_prev, hr_prev)
            lr_prev, hr_prev = lr_curr, hr_curr

        hr_frm = hr_curr.squeeze(0).cpu().numpy()  # chw|rgb|float32
        gt_frm = gt_curr.squeeze(0).cpu().numpy()
        lr_frm = lr_curr.squeeze(0).cpu().numpy()[:3]  # keep only the rgb channels

        yield {
            'frm_idx': frm_idx,
            'lr': float32_to_uint8(lr_frm).transpose(1, 2, 0),
            'gt': float32_to_uint8(gt_frm).transpose(1, 2, 0),
            'hr': float32_to_uint8(hr_frm).transpose(1, 2, 0)
        }
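
# Hedged usage sketch for the generator variant: ``model``, ``seq_gen`` and
# ``device`` are placeholders for objects built elsewhere in the repo.
#
#   for out in model.infer_sequence_generator(seq_gen, device):
#       # each item carries aligned LR / GT / HR frames for one index
#       print(out['frm_idx'], out['lr'].shape, out['hr'].shape, out['hr'].dtype)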
def upscale_sequence(data, gt_h, gt_w, batch_size=10):
    """ Bilinearly upscale a thwc float sequence to (gt_h, gt_w), in batches. """
    data = data.permute(0, 3, 1, 2)  # thwc -> tchw
    t, c, h, w = data.size()

    result = []
    for idx_start in range(0, t, batch_size):
        idx_end = min(idx_start + batch_size, t)
        data_item = F.interpolate(
            data[idx_start:idx_end], size=(gt_h, gt_w),
            mode='bilinear', align_corners=False)
        result.append(data_item.cpu().numpy())

    result = np.concatenate(result)
    result = result.transpose(0, 2, 3, 1)  # tchw -> thwc
    return float32_to_uint8(result)
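
# Usage sketch for ``upscale_sequence``; the sizes are illustrative only and
# ``float32_to_uint8`` is the helper sketched above.
#
#   lr_seq = torch.rand(24, 96, 96, 3)  # thwc, float32 in [0, 1]
#   up = upscale_sequence(lr_seq, gt_h=384, gt_w=384, batch_size=8)
#   assert up.shape == (24, 384, 384, 3) and up.dtype == np.uint8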
def data_processing(model, data, test_loader, input_data_type):
    if input_data_type == 'BD':
        # BD degradation: generate the LR input from GT on the fly
        lr_data = test_loader.dataset.apply_BD(data['gt'])['lr'][0]
    else:
        lr_data = data['lr'][0]
    seq_idx = data['seq_idx'][0]
    frm_idx = [frm_idx[0] for frm_idx in data['frm_idx']]

    output_seq = model.infer(lr_data)  # thwc|rgb|uint8
    _, h, w, _ = output_seq.shape
    if input_data_type != 'Style':
        input_seq = upscale_sequence(lr_data, h, w)  # thwc|rgb|uint8
    else:
        input_seq = data_utils.float32_to_uint8(
            lr_data.detach().cpu().numpy())

    return input_seq, output_seq, seq_idx, frm_idx
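
# A hedged sketch of how ``data_processing`` is presumably driven by a test
# loop; ``save_dir``, the loader contents, and the saving step are assumptions
# for illustration, not part of this repo's code.
#
#   import os.path as osp
#   import cv2
#
#   for data in test_loader:
#       input_seq, output_seq, seq_idx, frm_idx = data_processing(
#           model, data, test_loader, input_data_type='BD')
#       for frm, name in zip(output_seq, frm_idx):
#           # OpenCV expects bgr, contiguous memory
#           cv2.imwrite(osp.join(save_dir, seq_idx, f'{name}.png'),
#                       np.ascontiguousarray(frm[..., ::-1]))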
def infer_sequence(self, lr_data, device):
    """
    Parameters:
        :param lr_data: torch.FloatTensor in shape tchw
        :param device: torch.device

        :return hr_seq: uint8 np.ndarray in shape thwc
    """
    # setup params
    tot_frm, c, h, w = lr_data.size()
    s = self.scale

    # forward
    self.eval()
    hr_seq = []
    lr_prev = torch.zeros(1, c, h, w, dtype=torch.float32).to(device)
    hr_prev = torch.zeros(1, c, s * h, s * w, dtype=torch.float32).to(device)

    for i in range(tot_frm):
        with torch.no_grad():
            lr_curr = lr_data[i: i + 1, ...].to(device)
            hr_curr, hr_flow, hr_warp = self.forward(
                lr_curr, lr_prev, hr_prev)
            lr_prev, hr_prev = lr_curr, hr_curr

        hr_frm = hr_warp.squeeze(0).cpu().numpy().transpose(1, 2, 0)  # hwc|rgb|float32

        # optionally visualize the estimated flow instead of the warped frame
        flow_uv = hr_flow.squeeze(0).cpu().numpy().transpose(1, 2, 0)  # hwc|uv
        flow_color = flow_vis.flow_to_color(flow_uv, convert_to_bgr=False)

        hr_seq.append(float32_to_uint8(hr_frm))
        # hr_seq.append(flow_color)  # flow_to_color already returns uint8

    return np.stack(hr_seq)  # thwc
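
# Standalone sketch of the ``flow_vis`` colorization used above (the package
# is on PyPI as ``flow-vis``); the synthetic flow field is illustrative only.
#
#   flow_uv = np.dstack(np.meshgrid(np.linspace(-1, 1, 128),
#                                   np.linspace(-1, 1, 128))).astype(np.float32)
#   flow_color = flow_vis.flow_to_color(flow_uv, convert_to_bgr=False)  # hwc|uint8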