Ejemplo n.º 1
0
    def infer_sequence(self, lr_data, device):
        """Super-resolve a low-resolution sequence frame by frame.

        Frames are processed in temporal order. Each step feeds the current
        LR frame together with the previous LR frame and the previous HR
        output into ``self.forward``; both recurrent inputs start as zero
        tensors.

            Parameters:
                :param lr_data: torch.FloatTensor in shape tchw
                :param device: torch.device

                :return hr_seq: uint8 np.ndarray in shape thwc
        """

        # setup params
        tot_frm, c, h, w = lr_data.size()
        s = self.scale

        # Recurrent state, allocated directly on the target device. The HR
        # buffer is upscaled by ``self.scale`` and has ``self.out_nc`` channels.
        lr_prev = torch.zeros(1, c, h, w, dtype=torch.float32, device=device)
        hr_prev = torch.zeros(1, self.out_nc, s * h, s * w,
                              dtype=torch.float32, device=device)

        # Switch to eval mode once and disable autograd for the whole loop
        # instead of re-entering both on every frame.
        self.eval()
        hr_seq = []
        with torch.no_grad():
            for i in range(tot_frm):
                lr_curr = lr_data[i:i + 1, ...].to(device)
                hr_curr = self.forward(lr_curr, lr_prev, hr_prev)
                lr_prev, hr_prev = lr_curr, hr_curr

                # chw, still floating point; converted to uint8 just below
                hr_frm = hr_curr.squeeze(0).cpu().numpy()
                hr_seq.append(float32_to_uint8(hr_frm))

        return np.stack(hr_seq).transpose(0, 2, 3, 1)  # tchw -> thwc
Ejemplo n.º 2
0
    def infer_sequence(self, lr_data, device):
        """Super-resolve a sequence using a sliding temporal window.

        Output frame ``i`` is produced from the ``2 * (self.depth // 2) + 1``
        LR frames centred on input frame ``i``. The first and last
        ``self.depth // 2`` frames have no full window and are skipped. The
        result therefore holds ``t - 2 * (self.depth // 2)`` frames; pad the
        input temporally beforehand if every frame is needed.

            Parameters:
                :param lr_data: torch.FloatTensor in shape tchw
                :param device: torch.device

                :return hr_seq: uint8 np.ndarray in shape thwc
        """

        # setup params (unpacking also checks that lr_data is 4-D)
        tot_frm, _, _, _ = lr_data.size()
        p = self.depth // 2  # half-width of the temporal window

        # Switch to eval mode once and disable autograd for the whole loop.
        self.eval()
        hr_seq = []
        with torch.no_grad():
            for i in range(p, tot_frm - p):
                lr_seq = lr_data[i - p:i + p + 1, ...].to(device)
                hr_curr = self.forward(lr_seq)

                # chw, still floating point; converted to uint8 just below
                hr_frm = hr_curr.squeeze(0).cpu().numpy()
                hr_seq.append(float32_to_uint8(hr_frm))

        return np.stack(hr_seq).transpose(0, 2, 3, 1)  # tchw -> thwc
Ejemplo n.º 3
0
    def infer(self, lr_data):
        """Super-resolve an RGB sequence by running the network on luma only.

        The input is converted to YCbCr. The Y channel is temporally padded
        and passed through ``self.net_G.infer_sequence``, while the Cb and Cr
        channels are resized with ``tvs.resize`` (interpolation code 3,
        bicubic). The three channels are then concatenated and converted
        back to RGB.

        :param lr_data: LR sequence; ``data_utils.canonicalize`` turns it
            into a torch.FloatTensor in thwc layout.
        :return: uint8 np.ndarray in thwc layout (RGB).
        """

        lr_data = data_utils.canonicalize(lr_data)  # to torch.FloatTensor thwc
        _, h, w, _ = lr_data.size()
        out_size = [self.scale * h, self.scale * w]

        lr_yuv = data_utils.rgb2yCbCr(lr_data)
        lr_yuv = lr_yuv.permute(0, 3, 1, 2)  # thwc -> tchw

        lr_y = lr_yuv[:, 0:1, :, :]
        lr_u = lr_yuv[:, 1:2, :, :]
        lr_v = lr_yuv[:, 2:3, :, :]

        # dual direct temporal padding (the front-pad count is not used here)
        # NOTE(review): the concatenation below assumes
        # net_G.infer_sequence drops the padded frames, so its output has
        # as many frames as lr_u / lr_v — confirm.
        lr_y_seq, _ = self.pad_sequence(lr_y)

        # infer
        hr_y_seq = self.net_G.infer_sequence(lr_y_seq, self.device)
        hr_u_seq = tvs.resize(lr_u, out_size,
                              interpolation=3)  # bilinear:2(default) bicubic:3
        hr_v_seq = tvs.resize(lr_v, out_size,
                              interpolation=3)  # bilinear:2(default) bicubic:3

        hr_yuv = torch.cat((hr_y_seq, hr_u_seq, hr_v_seq), dim=1)  # tchw
        hr_yuv = hr_yuv.permute(0, 2, 3, 1)  # tchw -> thwc

        hr_rgb = data_utils.yCbCr2rgb(hr_yuv).numpy()
        return data_utils.float32_to_uint8(hr_rgb)  # thwc|rgb|uint8
Ejemplo n.º 4
0
    def infer_sequence_generator(self, seq_gen, device):
        """Lazily super-resolve frames pulled from ``seq_gen``.

        Each step feeds the current LR frame, the previous LR frame and the
        previous HR output into ``self.forward``. The recurrent buffers are
        zero-initialised from the shape of the first LR frame.

            Parameters:
                :param seq_gen: iterable of dicts with keys ``'lr'`` (LR
                    frame tensor, presumably 1chw), ``'gt'`` (ground-truth
                    frame tensor) and ``'frm_idx'`` (passed through as-is)
                :param device: torch.device

                :yield: dict with ``'frm_idx'`` unchanged and ``'lr'``,
                    ``'gt'``, ``'hr'`` as uint8 np.ndarray in hwc layout
                    (``'lr'`` keeps only its first 3 channels)
        """
        s = self.scale
        lr_prev = hr_prev = None  # created from the first frame's shape

        for item in seq_gen:
            with torch.no_grad():
                # Re-applied on every step: the caller runs between yields
                # and may have switched the module back to train mode.
                self.eval()
                lr_curr = item['lr']
                gt_curr = item['gt']
                frm_idx = item['frm_idx']

                if lr_prev is None:
                    _, c, h, w = lr_curr.size()
                    lr_prev = torch.zeros(1, c, h, w, dtype=torch.float32,
                                          device=device)
                    hr_prev = torch.zeros(1, self.out_nc, s * h, s * w,
                                          dtype=torch.float32, device=device)

                lr_curr = lr_curr.to(device)
                hr_curr = self.forward(lr_curr, lr_prev, hr_prev)
                lr_prev, hr_prev = lr_curr, hr_curr

                # chw, still floating point; converted to uint8 on yield
                hr_frm = hr_curr.squeeze(0).cpu().numpy()
                gt_frm = gt_curr.squeeze(0).cpu().numpy()
                lr_frm = lr_curr.squeeze(0).cpu().numpy()[:3]

            # Yield outside no_grad so the caller's autograd state is untouched.
            yield {
                'frm_idx': frm_idx,
                'lr': float32_to_uint8(lr_frm).transpose(1, 2, 0),
                'gt': float32_to_uint8(gt_frm).transpose(1, 2, 0),
                'hr': float32_to_uint8(hr_frm).transpose(1, 2, 0)
            }
Ejemplo n.º 5
0
def upscale_sequence(data, gt_h, gt_w, batch_size=10):
    """Bilinearly resize every frame of a sequence to ``(gt_h, gt_w)``.

    Frames are resized in chunks of at most ``batch_size`` frames
    (presumably to limit peak memory).

    :param data: torch.FloatTensor in thwc layout.
    :param gt_h: target frame height.
    :param gt_w: target frame width.
    :param batch_size: number of frames passed to each ``F.interpolate`` call.
    :return: uint8 np.ndarray in thwc layout.
    """
    frames = data.permute(0, 3, 1, 2)  # thwc -> tchw, as F.interpolate expects
    resized = [
        F.interpolate(chunk,
                      size=(gt_h, gt_w),
                      mode='bilinear',
                      align_corners=False).cpu().numpy()
        for chunk in frames.split(batch_size)  # chunks along the time axis
    ]
    result = np.concatenate(resized).transpose(0, 2, 3, 1)  # tchw -> thwc
    return float32_to_uint8(result)
Ejemplo n.º 6
0
def data_processing(model, data, test_loader, input_data_type):
    """Run ``model`` on the first sequence of a batch and build its input view.

    :param model: object exposing ``infer(lr_data)``, whose result is used
        as the thwc output sequence.
    :param data: batch dict; the first entry of ``'lr'`` (or of the
        ``apply_BD`` result), of ``'seq_idx'`` and of each ``'frm_idx'``
        element is used.
    :param test_loader: loader whose ``dataset.apply_BD`` is called on
        ``data['gt']`` when ``input_data_type == 'BD'``.
    :param input_data_type: ``'BD'`` derives the LR input from ``data['gt']``
        via ``apply_BD``; ``'Style'`` returns the LR input without resizing;
        anything else reads ``data['lr']``. Except for ``'Style'``, the LR
        input is upscaled to the output's spatial size.
    :return: tuple ``(input_seq, output_seq, seq_idx, frm_idx)``.
    """
    lr_data = (test_loader.dataset.apply_BD(data['gt'])['lr'][0]
               if input_data_type == 'BD' else data['lr'][0])

    seq_idx = data['seq_idx'][0]
    frm_idx = [entry[0] for entry in data['frm_idx']]

    output_seq = model.infer(lr_data)  # thwc|rgb|uint8
    _, out_h, out_w, _ = output_seq.shape

    if input_data_type == 'Style':
        # No resizing for this mode: only convert the raw LR frames to uint8.
        raw_frames = lr_data.detach().cpu().numpy()
        input_seq = data_utils.float32_to_uint8(raw_frames)
    else:
        input_seq = upscale_sequence(lr_data, out_h, out_w)  # thwc|rgb|uint8

    return input_seq, output_seq, seq_idx, frm_idx
Ejemplo n.º 7
0
    def infer_sequence(self, lr_data, device):
        """Run recurrent inference and collect the warped HR frames.

        Each step calls ``self.forward(lr_curr, lr_prev, hr_prev)``, which
        returns ``(hr_curr, hr_flow, hr_warp)``. ``hr_curr`` is fed back as
        the next ``hr_prev``. The frame stored in the output is ``hr_warp``,
        not ``hr_curr``, and ``hr_flow`` is not used.

            Parameters:
                :param lr_data: torch.FloatTensor in shape tchw
                :param device: torch.device

                :return hr_seq: uint8 np.ndarray in shape thwc
        """

        # setup params
        tot_frm, c, h, w = lr_data.size()
        s = self.scale

        lr_prev = torch.zeros(1, c, h, w, dtype=torch.float32, device=device)
        # NOTE(review): the HR buffer is sized with the LR channel count ``c``,
        # i.e. this assumes the network outputs as many channels as it
        # takes in — confirm.
        hr_prev = torch.zeros(1, c, s * h, s * w,
                              dtype=torch.float32, device=device)

        # Switch to eval mode once and disable autograd for the whole loop.
        self.eval()
        hr_seq = []
        with torch.no_grad():
            for i in range(tot_frm):
                lr_curr = lr_data[i:i + 1, ...].to(device)
                hr_curr, _hr_flow, hr_warp = self.forward(
                    lr_curr, lr_prev, hr_prev)
                lr_prev, hr_prev = lr_curr, hr_curr

                # chw -> hwc, still floating point; converted to uint8 below
                hr_frm = hr_warp.squeeze(0).cpu().numpy().transpose(1, 2, 0)
                hr_seq.append(float32_to_uint8(hr_frm))

        return np.stack(hr_seq)  # thwc