def compute_flow(self, ref, supp):
    """Estimate the optical flow from ``ref`` to ``supp`` coarse-to-fine.

    Both images are normalized with ``self.mean`` / ``self.std`` and
    average-pooled five times, giving a six-level pyramid. Starting from a
    zero flow at 1/32 resolution, each level upsamples the previous flow,
    warps the supporting image with it and adds the residue predicted by
    ``self.basic_module[level]``.

    Note that in this function, the images are expected to be already
    resized to a multiple of 32.

    Args:
        ref (Tensor): Reference image with shape of (n, 3, h, w).
        supp (Tensor): Supporting image with shape of (n, 3, h, w).

    Returns:
        Tensor: Estimated optical flow: (n, 2, h, w).
    """
    batch, _, height, width = ref.size()

    # level 0 of each pyramid is the coarsest scale, the last is full size
    ref_pyramid = [(ref - self.mean) / self.std]
    supp_pyramid = [(supp - self.mean) / self.std]
    for _ in range(5):
        ref_pyramid.insert(
            0,
            F.avg_pool2d(
                input=ref_pyramid[0],
                kernel_size=2,
                stride=2,
                count_include_pad=False))
        supp_pyramid.insert(
            0,
            F.avg_pool2d(
                input=supp_pyramid[0],
                kernel_size=2,
                stride=2,
                count_include_pad=False))

    # the coarsest level starts from an all-zero flow
    flow = ref_pyramid[0].new_zeros(batch, 2, height // 32, width // 32)
    for level, (ref_lvl, supp_lvl) in enumerate(
            zip(ref_pyramid, supp_pyramid)):
        if level > 0:
            # doubling the resolution also doubles the displacement values
            flow = F.interpolate(
                input=flow,
                scale_factor=2,
                mode='bilinear',
                align_corners=True) * 2.0

        warped_supp = flow_warp(
            supp_lvl, flow.permute(0, 2, 3, 1), padding_mode='border')
        residue = self.basic_module[level](
            torch.cat([ref_lvl, warped_supp, flow], 1))

        # add the residue to the upsampled flow
        flow = flow + residue

    return flow
def forward(self, ref, supp):
    """Estimate the optical flow between ``ref`` and ``supp``.

    A four-level image pyramid is built by repeated 2x average pooling.
    The flow starts as zeros at 1/16 resolution and, at every level, is
    upsampled by two and refined with the residue predicted by
    ``self.basic_module[i]``.

    Args:
        ref (Tensor): Reference image with shape of (b, 3, h, w).
        supp: The supporting image to be warped: (b, 3, h, w).

    Returns:
        Tensor: Estimated optical flow: (b, 2, h, w).
    """
    num_batches, _, h, w = ref.size()

    # build the pyramids fine-to-coarse, then flip them to coarse-to-fine
    ref_levels = [ref]
    supp_levels = [supp]
    for _ in range(3):
        ref_levels.append(
            F.avg_pool2d(
                input=ref_levels[-1],
                kernel_size=2,
                stride=2,
                count_include_pad=False))
        supp_levels.append(
            F.avg_pool2d(
                input=supp_levels[-1],
                kernel_size=2,
                stride=2,
                count_include_pad=False))
    ref_levels.reverse()
    supp_levels.reverse()

    # the initial flow is one scale below the coarsest pyramid level; the
    # first upsampling brings it to that level's size
    flow = ref_levels[0].new_zeros(num_batches, 2, h // 16, w // 16)
    for i, (ref_i, supp_i) in enumerate(zip(ref_levels, supp_levels)):
        # doubling the resolution also doubles the displacement values
        flow_up = F.interpolate(
            input=flow,
            scale_factor=2,
            mode='bilinear',
            align_corners=True) * 2.0
        warped = flow_warp(supp_i, flow_up.permute(0, 2, 3, 1))
        net_input = torch.cat([ref_i, warped, flow_up], 1)
        flow = flow_up + self.basic_module[i](net_input)

    return flow
def forward(self, lrs):
    """Align seven LR frames to the reference frame and reconstruct it.

    Every non-reference frame is warped towards the reference frame with
    the flow estimated by ``self.spynet``; the aligned frames are stacked
    along the channel axis and passed through four convolutions, with the
    reference frame added back as a residual.

    Args:
        lrs: Input lr frames: (b, 7, 3, h, w).

    Returns:
        Tensor: SR frame: (b, 3, h, w).
    """
    # In the official implementation, the 0-th frame is the reference frame
    if self.adapt_official_weights:
        lrs = lrs[:, [3, 0, 1, 2, 4, 5, 6]]

    num_batches, num_lrs, _, h, w = lrs.size()

    # normalization works on 4D input, so fold the frame axis into batch
    flat = self.normalize(lrs.view(-1, 3, h, w))
    lrs = flat.view(num_batches, num_lrs, 3, h, w)

    lr_ref = lrs[:, self.ref_idx]
    lr_aligned = []
    for i in range(7):  # 7 frames
        if i == self.ref_idx:
            # the reference frame needs no alignment
            lr_aligned.append(lr_ref)
            continue
        lr_supp = lrs[:, i]
        flow = self.spynet(lr_ref, lr_supp)
        lr_aligned.append(flow_warp(lr_supp, flow.permute(0, 2, 3, 1)))

    # reconstruction
    stacked = torch.stack(lr_aligned, dim=1)
    hr = stacked.view(num_batches, -1, h, w)
    for conv in (self.conv1, self.conv2, self.conv3):
        hr = self.relu(conv(hr))
    hr = self.conv4(hr) + lr_ref

    return self.denormalize(hr)
def forward(self, lrs):
    """Forward function for IconVSR.

    The sequence is propagated backward in time and then forward in time.
    At keyframes, the propagated feature is fused with the refill features
    returned by ``self.compute_refill_features``. The forward pass also
    upsamples each frame and adds ``self.img_upsample`` of the LR frame as
    a residual.

    Args:
        lrs (Tensor): Input LR tensor with shape (n, t, c, h, w).

    Returns:
        Tensor: Output HR tensor with shape (n, t, c, 4h, 4w).

    Raises:
        AssertionError: If the input height or width is smaller than 64.
    """
    n, t, _, h_input, w_input = lrs.size()

    # Raised explicitly rather than via an ``assert`` statement so that the
    # check is not stripped when Python runs with ``-O``. The exception
    # type is kept so existing callers see the same error.
    if h_input < 64 or w_input < 64:
        raise AssertionError(
            'The height and width of inputs should be at least 64, '
            f'but got {h_input} and {w_input}.')

    # check whether the input is an extended sequence
    self.check_if_mirror_extended(lrs)

    lrs = self.spatial_padding(lrs)
    h, w = lrs.size(3), lrs.size(4)

    # get the keyframe indices for information-refill
    keyframe_idx = list(range(0, t, self.keyframe_stride))
    if keyframe_idx[-1] != t - 1:
        keyframe_idx.append(t - 1)  # the last frame must be a keyframe
    # set copy for O(1) membership tests inside the per-frame loops
    keyframe_set = set(keyframe_idx)

    # compute optical flow and compute features for information-refill
    flows_forward, flows_backward = self.compute_flow(lrs)
    feats_refill = self.compute_refill_features(lrs, keyframe_idx)

    # backward-time propagation
    outputs = []
    feat_prop = lrs.new_zeros(n, self.mid_channels, h, w)
    for i in range(t - 1, -1, -1):
        lr_curr = lrs[:, i, :, :, :]
        if i < t - 1:  # no warping for the last timestep
            flow = flows_backward[:, i, :, :, :]
            feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
        if i in keyframe_set:  # information-refill
            feat_prop = torch.cat([feat_prop, feats_refill[i]], dim=1)
            feat_prop = self.backward_fusion(feat_prop)
        feat_prop = torch.cat([lr_curr, feat_prop], dim=1)
        feat_prop = self.backward_resblocks(feat_prop)

        outputs.append(feat_prop)
    # collected from last frame to first; restore temporal order
    outputs = outputs[::-1]

    # forward-time propagation and upsampling
    feat_prop = torch.zeros_like(feat_prop)
    for i in range(0, t):
        lr_curr = lrs[:, i, :, :, :]
        if i > 0:  # no warping for the first timestep
            if flows_forward is not None:
                flow = flows_forward[:, i - 1, :, :, :]
            else:
                # no separate forward flows were returned; reuse the
                # backward flows indexed from the end
                flow = flows_backward[:, -i, :, :, :]
            feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
        if i in keyframe_set:  # information-refill
            feat_prop = torch.cat([feat_prop, feats_refill[i]], dim=1)
            feat_prop = self.forward_fusion(feat_prop)

        # outputs[i] still holds the backward-branch feature of frame i
        feat_prop = torch.cat([lr_curr, outputs[i], feat_prop], dim=1)
        feat_prop = self.forward_resblocks(feat_prop)

        out = self.lrelu(self.upsample1(feat_prop))
        out = self.lrelu(self.upsample2(out))
        out = self.lrelu(self.conv_hr(out))
        out = self.conv_last(out)
        base = self.img_upsample(lr_curr)
        out += base
        outputs[i] = out

    # crop away the region introduced by the spatial padding
    return torch.stack(outputs, dim=1)[:, :, :, :4 * h_input, :4 * w_input]
def propagate(self, feats, flows, module_name):
    """Propagate the latent features throughout the sequence.

    Frames are visited in temporal order, or in reverse order when
    ``module_name`` contains ``'backward'``. For every frame the propagated
    feature is optionally aligned with ``self.deform_align[module_name]``
    (second-order, flow-guided), then concatenated with the current
    spatial feature and the features of the other branches, and refined
    by ``self.backbone[module_name]``. The result is appended to
    ``feats[module_name]``.

    Args:
        feats dict(list[tensor]): Features from previous branches. Each
            component is a list of tensors with shape (n, c, h, w).
        flows (tensor): Optical flows with shape (n, t - 1, 2, h, w).
        module_name (str): The name of the propagation branches. Can
            either be 'backward_1', 'forward_1', 'backward_2',
            'forward_2'.

    Return:
        dict(list[tensor]): A dictionary containing all the propagated
            features. Each key in the dictionary corresponds to a
            propagation branch, which is represented by a list of
            tensors.
    """
    n, t, _, h, w = flows.size()

    frame_idx = range(0, t + 1)
    flow_idx = range(-1, t)
    # spatial indices followed by their mirror image, so that frame
    # indices beyond the number of spatial features still resolve
    mapping_idx = list(range(0, len(feats['spatial'])))
    mapping_idx += mapping_idx[::-1]

    is_backward = 'backward' in module_name
    if is_backward:
        frame_idx = frame_idx[::-1]
        flow_idx = frame_idx

    use_cpu_cache = self.cpu_cache
    branch_feats = feats[module_name]

    feat_prop = flows.new_zeros(n, self.mid_channels, h, w)
    for step, idx in enumerate(frame_idx):
        feat_current = feats['spatial'][mapping_idx[idx]]
        if use_cpu_cache:
            feat_current = feat_current.cuda()
            feat_prop = feat_prop.cuda()

        # second-order deformable alignment
        if step > 0 and self.is_with_alignment:
            flow_n1 = flows[:, flow_idx[step], :, :, :]
            if use_cpu_cache:
                flow_n1 = flow_n1.cuda()
            grid_n1 = flow_n1.permute(0, 2, 3, 1)

            cond_n1 = flow_warp(feat_prop, grid_n1)

            if step > 1:
                # second-order features: the output from two steps ago
                feat_n2 = branch_feats[-2]
                if use_cpu_cache:
                    feat_n2 = feat_n2.cuda()

                flow_n2 = flows[:, flow_idx[step - 1], :, :, :]
                if use_cpu_cache:
                    flow_n2 = flow_n2.cuda()

                # chain the two one-step flows into a two-step flow
                flow_n2 = flow_n1 + flow_warp(flow_n2, grid_n1)
                cond_n2 = flow_warp(feat_n2, flow_n2.permute(0, 2, 3, 1))
            else:
                # no second-order history yet; fall back to zeros
                feat_n2 = torch.zeros_like(feat_prop)
                flow_n2 = torch.zeros_like(flow_n1)
                cond_n2 = torch.zeros_like(cond_n1)

            # flow-guided deformable convolution
            cond = torch.cat([cond_n1, feat_current, cond_n2], dim=1)
            feat_prop = torch.cat([feat_prop, feat_n2], dim=1)
            feat_prop = self.deform_align[module_name](
                feat_prop, cond, flow_n1, flow_n2)

        # concatenate and residual blocks
        other_branches = [
            feats[k][idx] for k in feats
            if k not in ['spatial', module_name]
        ]
        feat = [feat_current] + other_branches + [feat_prop]
        if use_cpu_cache:
            feat = [f.cuda() for f in feat]

        feat = torch.cat(feat, dim=1)
        feat_prop = feat_prop + self.backbone[module_name](feat)
        branch_feats.append(feat_prop)

        if use_cpu_cache:
            # keep the stored copy on the CPU and release cached GPU memory
            branch_feats[-1] = branch_feats[-1].cpu()
            torch.cuda.empty_cache()

    if is_backward:
        # features were collected last-to-first; restore temporal order
        feats[module_name] = feats[module_name][::-1]

    return feats
def forward(self, lrs):
    """Forward function for BasicVSR.

    Features are propagated backward in time and then forward in time.
    In the forward pass, the two branch features of each frame are fused
    and upsampled, and ``self.img_upsample`` of the LR frame is added as a
    residual.

    Args:
        lrs (Tensor): Input LR sequence with shape (n, t, c, h, w).

    Returns:
        Tensor: Output HR sequence with shape (n, t, c, 4h, 4w).

    Raises:
        AssertionError: If the input height or width is smaller than 64.
    """
    n, t, _, h, w = lrs.size()

    # Raised explicitly rather than via an ``assert`` statement so that the
    # check is not stripped when Python runs with ``-O``. The exception
    # type is kept so existing callers see the same error.
    if h < 64 or w < 64:
        raise AssertionError(
            'The height and width of inputs should be at least 64, '
            f'but got {h} and {w}.')

    # check whether the input is an extended sequence
    self.check_if_mirror_extended(lrs)

    # compute optical flow
    flows_forward, flows_backward = self.compute_flow(lrs)

    # backward-time propagation
    outputs = []
    feat_prop = lrs.new_zeros(n, self.mid_channels, h, w)
    for i in range(t - 1, -1, -1):
        if i < t - 1:  # no warping required for the last timestep
            flow = flows_backward[:, i, :, :, :]
            feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))

        feat_prop = torch.cat([lrs[:, i, :, :, :], feat_prop], dim=1)
        feat_prop = self.backward_resblocks(feat_prop)

        outputs.append(feat_prop)
    # collected from last frame to first; restore temporal order
    outputs = outputs[::-1]

    # forward-time propagation and upsampling
    feat_prop = torch.zeros_like(feat_prop)
    for i in range(0, t):
        lr_curr = lrs[:, i, :, :, :]
        if i > 0:  # no warping required for the first timestep
            if flows_forward is not None:
                flow = flows_forward[:, i - 1, :, :, :]
            else:
                # no separate forward flows were returned; reuse the
                # backward flows indexed from the end
                flow = flows_backward[:, -i, :, :, :]
            feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))

        feat_prop = torch.cat([lr_curr, feat_prop], dim=1)
        feat_prop = self.forward_resblocks(feat_prop)

        # upsampling given the backward and forward features;
        # outputs[i] still holds the backward-branch feature of frame i
        out = torch.cat([outputs[i], feat_prop], dim=1)
        out = self.lrelu(self.fusion(out))
        out = self.lrelu(self.upsample1(out))
        out = self.lrelu(self.upsample2(out))
        out = self.lrelu(self.conv_hr(out))
        out = self.conv_last(out)
        base = self.img_upsample(lr_curr)
        out += base
        outputs[i] = out

    return torch.stack(outputs, dim=1)