コード例 #1
0
    def forward(self, frames, F_kprime_to_k, F_n_to_k_s, F_k_to_n_s):
        # frames: base frame also included
        h0 = int(list(frames[0].size())[2])
        w0 = int(list(frames[0].size())[3])
        h6 = int(list(frames[-1].size())[2])
        w6 = int(list(frames[-1].size())[3])
        if h0 != h6 or w0 != w6:
            sys.exit('Frame sizes do not match')

        GAUSSIAN_FILTER_KSIZE = len(frames)
        gaussian_filter = cv2.getGaussianKernel(GAUSSIAN_FILTER_KSIZE, -1)

        if self.args.inference_with_frame_selection > 0:
            import numpy as np
            var_lap_values = []
            flow_values = []
            clear_frame_indicator = []

            def variance_of_laplacian(image):
                """Return the variance of the Laplacian of ``image``.

                ``image`` is first converted with ``cv2.COLOR_RGB2GRAY``; the
                Laplacian is then evaluated at ``cv2.CV_32F`` depth and its
                variance is returned as the focus measure.
                """
                grayscale = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
                laplacian = cv2.Laplacian(grayscale, cv2.CV_32F)
                return laplacian.var()

            for iii, tmp_img in enumerate(frames):
                var_lap_values.append(
                    variance_of_laplacian(
                        np.transpose(tmp_img.cpu().numpy()[0], (1, 2, 0))))
                flow_values.append(
                    np.sum((F_n_to_k_s[iii].cpu().numpy()[0])**2))
            mid_value = sorted(var_lap_values)[len(frames) // 2]
            mid_flow_value = sorted(flow_values)[len(frames) // 2]
            for iii, value in enumerate(var_lap_values):
                if value < mid_value and flow_values[iii] < mid_flow_value:
                    clear_frame_indicator.append(False)
                else:
                    clear_frame_indicator.append(True)
            clear_frames = []
            clear_gaussian_filter = []
            clear_F_n_to_k_s = []
            clear_F_k_to_n_s = []
            for idx, indicator in enumerate(clear_frame_indicator):
                if indicator:
                    clear_frames.append(frames[idx])
                    clear_gaussian_filter.append(gaussian_filter[idx, 0])
                    clear_F_n_to_k_s.append(F_n_to_k_s[idx])
                    clear_F_k_to_n_s.append(F_k_to_n_s[idx])
            frames = clear_frames
            F_n_to_k_s = clear_F_n_to_k_s
            F_k_to_n_s = clear_F_k_to_n_s
            gaussian_filter = np.copy(np.stack(clear_gaussian_filter, 0))
            gaussian_filter = np.expand_dims(gaussian_filter, -1)
            gaussian_filter = gaussian_filter / np.sum(gaussian_filter)

        h_padded = False
        w_padded = False
        minimum_size = 4
        if h0 % minimum_size != 0:
            pad_h = minimum_size - (h0 % minimum_size)
            F_kprime_to_k = F.pad(F_kprime_to_k, (0, 0, 0, pad_h),
                                  mode='replicate')
            for i in range(len(frames)):
                frames[i] = F.pad(frames[i], (0, 0, 0, pad_h),
                                  mode='replicate')
                F_n_to_k_s[i] = F.pad(F_n_to_k_s[i], (0, 0, 0, pad_h),
                                      mode='replicate')
            h_padded = True

        if w0 % minimum_size != 0:
            pad_w = minimum_size - (w0 % minimum_size)
            F_kprime_to_k = F.pad(F_kprime_to_k, (0, pad_w, 0, 0),
                                  mode='replicate')
            for i in range(len(frames)):
                frames[i] = F.pad(frames[i], (0, pad_w, 0, 0),
                                  mode='replicate')
                F_n_to_k_s[i] = F.pad(F_n_to_k_s[i], (0, pad_w, 0, 0),
                                      mode='replicate')
            w_padded = True

        if self.args.noDL_CNNAggregation > 0:
            """no learnable params"""
            """except final aggregation function"""
            W = 256
            H = 256
            tenOnes = torch.ones_like(frames[0])[:, 0:1, :, :]
            tenOnes = torch.nn.ZeroPad2d((W, W, H, H))(tenOnes).detach()
            F_kprime_to_k_pad = torch.nn.ZeroPad2d((W, W, H, H))(F_kprime_to_k)
            tenWarpedFeat = []
            tenWarpedMask = []
            tenWarpedFlow = []
            for idx, feat in enumerate(frames):
                """padding for forward warping"""
                ref_frame_flow = torch.nn.ReplicationPad2d(
                    (W, W, H, H))(F_n_to_k_s[idx])
                tenRef = torch.nn.ReplicationPad2d((W, W, H, H))(feat)
                """first forward warping"""
                tenWarpedFirst = softsplat.FunctionSoftsplat(
                    tenInput=tenRef,
                    tenFlow=ref_frame_flow,
                    tenMetric=None,
                    strType='average')
                tenMaskFirst = softsplat.FunctionSoftsplat(
                    tenInput=tenOnes,
                    tenFlow=ref_frame_flow,
                    tenMetric=None,
                    strType='average')
                tenFlowFirst = softsplat.FunctionSoftsplat(
                    tenInput=ref_frame_flow,
                    tenFlow=ref_frame_flow,
                    tenMetric=None,
                    strType='average')
                """second backward warping"""
                tenWarpedSecond = backwarp(tenInput=tenWarpedFirst,
                                           tenFlow=F_kprime_to_k_pad)
                tenMaskSecond = backwarp(tenInput=tenMaskFirst,
                                         tenFlow=F_kprime_to_k_pad)
                tenFlowSecond = backwarp(tenInput=tenFlowFirst,
                                         tenFlow=F_kprime_to_k_pad)
                """back to original resolution"""
                tenWarped = tenWarpedSecond[:, :, H:-H, W:-W]
                tenMask = tenMaskSecond[:, :, H:-H, W:-W]
                tenFlow = tenFlowSecond[:, :, H:-H, W:-W]
                tenWarpedFeat.append(tenWarped)
                tenWarpedMask.append(tenMask)
                tenWarpedFlow.append(tenFlow)
            color_tensor = []
            for i in range(len(tenWarpedFeat)):
                color_tensor.append(tenWarpedFeat[i])
            color_tensor = torch.stack(color_tensor, 0)
            weight_tensor = []
            for i in range(len(tenWarpedFeat)):
                weight_tensor.append(
                    self.aggregationWeightingNetwork(
                        torch.cat([
                            tenWarpedFeat[i], tenWarpedMask[i],
                            tenWarpedFlow[i], tenWarpedFeat[len(frames) // 2],
                            tenWarpedMask[len(frames) // 2],
                            torch.abs(tenWarpedFeat[i] -
                                      tenWarpedFeat[len(frames) // 2])
                        ], 1)))
            weight_tensor = torch.stack(weight_tensor, 0)

            if self.args.gumbel > 0:
                weight_tensor = gumbel_softmax(weight_tensor, hard=True, dim=0)
            else:
                weight_tensor = torch.softmax(weight_tensor, 0)
            total_mask = torch.sum(torch.stack(tenWarpedMask, 0), 0)
            global_average_pooled_feature = torch.sum(color_tensor *
                                                      weight_tensor,
                                                      dim=0)
            global_average_pooled_feature = global_average_pooled_feature * torch.where(
                total_mask > 0, torch.ones_like(total_mask),
                torch.zeros_like(total_mask))
            if h_padded:
                global_average_pooled_feature = global_average_pooled_feature[:, :,
                                                                              0:
                                                                              h0, :]
            if w_padded:
                global_average_pooled_feature = global_average_pooled_feature[:, :, :,
                                                                              0:
                                                                              w0]
            return global_average_pooled_feature

        features = []
        for i in range(len(frames)):
            features.append(self.spatialFeatureNetwork(frames[i]))
        # import numpy as np
        # nnn = feature0.cpu().numpy()
        # print(np.mean(nnn))
        # print(np.std(nnn))
        # print(np.min(nnn))
        # print(np.max(nnn))

        # forward warping + backward warping
        # W = list(features[0].size())[3]//2
        # H = list(features[0].size())[2]//2
        W = 256
        H = 256
        tenOnes = torch.ones_like(features[0])[:, 0:1, :, :]
        tenOnes = torch.nn.ZeroPad2d((W, W, H, H))(tenOnes).detach()
        if self.args.FOV_expansion > 0:
            F_kprime_to_k_pad = torch.nn.ReplicationPad2d(
                (W, W, H, H))(F_kprime_to_k)
        else:
            F_kprime_to_k_pad = torch.nn.ZeroPad2d((W, W, H, H))(F_kprime_to_k)
        tenWarpedFeat = []
        tenWarpedMask = []
        for idx, feat in enumerate(features):
            if self.args.all_backward > 0:
                ref_frame_flow = torch.nn.ReplicationPad2d(
                    (W, W, H, H))(F_k_to_n_s[idx])
                tenRef = torch.nn.ReplicationPad2d((W, W, H, H))(feat)
                tenWarpedFirst = backwarp(tenInput=tenRef,
                                          tenFlow=ref_frame_flow)
                tenMaskFirst = backwarp(tenInput=tenOnes,
                                        tenFlow=ref_frame_flow)
            else:
                """padding for forward warping"""
                ref_frame_flow = torch.nn.ReplicationPad2d(
                    (W, W, H, H))(F_n_to_k_s[idx])
                tenRef = torch.nn.ReplicationPad2d((W, W, H, H))(feat)
                """first forward warping"""
                tenWarpedFirst = softsplat.FunctionSoftsplat(
                    tenInput=tenRef,
                    tenFlow=ref_frame_flow,
                    tenMetric=None,
                    strType='average')
                tenMaskFirst = softsplat.FunctionSoftsplat(
                    tenInput=tenOnes,
                    tenFlow=ref_frame_flow,
                    tenMetric=None,
                    strType='average')
            """second backward warping"""
            if self.args.bundle_forward_flow > 0:
                tenWarpedSecond = softsplat.FunctionSoftsplat(
                    tenInput=tenWarpedFirst,
                    tenFlow=F_kprime_to_k_pad,
                    tenMetric=None,
                    strType='average')
                tenMaskSecond = softsplat.FunctionSoftsplat(
                    tenInput=tenMaskFirst,
                    tenFlow=F_kprime_to_k_pad,
                    tenMetric=None,
                    strType='average')
            else:
                tenWarpedSecond = backwarp(tenInput=tenWarpedFirst,
                                           tenFlow=F_kprime_to_k_pad)
                tenMaskSecond = backwarp(tenInput=tenMaskFirst,
                                         tenFlow=F_kprime_to_k_pad)
            """back to original resolution"""
            if self.args.FOV_expansion <= 0:
                tenWarped = tenWarpedSecond[:, :, H:-H, W:-W]
                tenMask = tenMaskSecond[:, :, H:-H, W:-W]
            else:
                tenWarped = tenWarpedSecond
                tenMask = tenMaskSecond
            tenWarpedFeat.append(tenWarped)
            tenWarpedMask.append(tenMask)

        # tenMetrics = []
        # tenSoftmaxs = []
        # for i in range(len(frames)):
        #     tenMetrics.append(torch.nn.functional.l1_loss(input=frames[i], target=backwarp(tenInput=target_frame, tenFlow=flows[i]), reduction='none').mean(1, True))
        #     tenSoftmaxs.append(softsplat.FunctionSoftsplat(tenInput=features[i], tenFlow=flows[i], tenMetric=self.beta * tenMetrics[i], strType=self.splatting_type))

        # # backward warping
        # tenSoftmaxs = []
        # for i in range(len(frames)):
        #     # tenSoftmaxs.append(backwarp(tenInput=features[i], tenFlow=flows[i]))
        #     tenSoftmaxs.append(tenWarped)

        # # TODO: mask of backward warping
        # if self.args.pooling_with_mask > 0 or self.args.decoder_with_mask > 0 or self.args.softargmax_with_mask > 0:
        #     def length_sq(x):
        #         return torch.sum(x**2, dim=1, keepdim=True)
        #     tenMasks = []
        #     for i in range(len(frames)):
        #         mag_sq = length_sq(forward_flows[i]) + length_sq(flows[i])
        #         flow_fw_warped = backwarp(forward_flows[i], flows[i])
        #         flow_diff_bw = flows[i] + flow_fw_warped
        #         occ_thresh = 0.01 * mag_sq + 0.5
        #         fb_occ_bw = (length_sq(flow_diff_bw) > occ_thresh).float()
        #         if self.args.mask_with_proxy_mask > 0:
        #             tenMasks.append(((1.0 - fb_occ_bw)*torch.mean(proxy_mask, dim=1, keepdim=True)).detach())
        #         else:
        #             tenMasks.append((1.0-fb_occ_bw).detach())
        #         # imwrite(fb_occ_bw, 'mask'+str(i)+'.png', range=(0, 1))
        #         # imwrite(frames[i], 'frame'+str(i)+'.png', range=(0, 1))

        # # mask of forward warping
        # if self.args.pooling_with_mask > 0 or self.args.decoder_with_mask > 0 or self.args.softargmax_with_mask > 0:
        #     tenOnes = torch.ones([self.args.batch_size, 1, int(list(frames[0].size())[2]), int(list(frames[0].size())[3])]).cuda().detach()
        #     tenMasks = []
        #     for i in range(len(frames)):
        #         tenMasks.append(softsplat.FunctionSoftsplat(tenInput=tenOnes, tenFlow=flows[i], tenMetric=self.beta * tenMetrics[i], strType=self.splatting_type))
        """Pooling"""
        if self.args.pooling_with_mask <= 0:
            global_average_pooled_feature = torch.stack(tenWarpedFeat,
                                                        -1).mean(-1)
        else:
            if self.args.pooling_with_center_bias <= 0:
                tmp_list = []
                for i in range(len(frames)):
                    tmp_list.append(tenWarpedFeat[i] * tenWarpedMask[i])
                global_summed_feature = torch.stack(tmp_list, -1).sum(-1)
                global_summed_weights = torch.stack(tenWarpedMask, -1).sum(-1)
                global_average_pooled_feature = global_summed_feature / torch.clamp(
                    global_summed_weights, min=1e-6)
            else:
                if self.args.pooling_type == 'gaussian':
                    # gaussian weights for center bias pooling
                    # GAUSSIAN_FILTER_KSIZE = len(frames)
                    # gaussian_filter = cv2.getGaussianKernel(GAUSSIAN_FILTER_KSIZE, -1)
                    color_tensor = []
                    weight_tensor = []
                    for i in range(len(frames)):
                        color_tensor.append(tenWarpedFeat[i])
                        weight_tensor.append(tenWarpedMask[i] *
                                             gaussian_filter[i, 0])
                    color_tensor = torch.stack(color_tensor, 0)
                    weight_tensor = torch.stack(weight_tensor, 0)
                    output_mask = torch.sum(weight_tensor, dim=0)
                    global_average_pooled_feature = torch.sum(
                        color_tensor * weight_tensor, dim=0) / torch.clamp(
                            torch.sum(weight_tensor, dim=0), min=1e-6)
                elif self.args.pooling_type == 'max':

                    def score_max(x, dim, score):
                        """Select from ``x`` the entries whose ``score`` is largest along ``dim``.

                        The arg-max of ``score`` along ``dim`` is re-inserted as a
                        size-1 axis, tiled to ``x.size(dim)`` along ``dim`` and to
                        ``x.size(2)`` along axis 2, and used as the gather index.
                        Every slice of the gathered tensor along ``dim`` is then the
                        same, so slice 0 is returned (``dim`` is dropped).
                        """
                        repeats = [1 for _ in range(x.dim())]
                        repeats[dim] = x.size(dim)
                        repeats[2] = x.size(2)  # axis 2 is treated as the channel axis
                        _, best = score.max(dim)
                        gather_index = best.unsqueeze(dim).repeat(tuple(repeats))
                        return torch.gather(x, dim, gather_index).select(dim, 0)

                    # GAUSSIAN_FILTER_KSIZE = len(frames)
                    # gaussian_filter = cv2.getGaussianKernel(GAUSSIAN_FILTER_KSIZE, -1)
                    color_tensor = []
                    weight_tensor = []
                    for i in range(len(frames)):
                        color_tensor.append(tenWarpedFeat[i])
                        weight_tensor.append(tenWarpedMask[i] *
                                             gaussian_filter[i, 0])
                    color_tensor = torch.stack(color_tensor, 0)
                    weight_tensor = torch.stack(weight_tensor, 0)
                    output_mask = torch.sum(weight_tensor, dim=0)
                    global_average_pooled_feature = score_max(
                        color_tensor, 0, weight_tensor)
                elif self.args.pooling_type == 'CNN':
                    # TODO: deep blend
                    color_tensor = []
                    for i in range(len(frames)):
                        color_tensor.append(tenWarpedFeat[i])
                    color_tensor = torch.stack(color_tensor, 0)
                    weight_tensor = []
                    if self.args.inference_with_frame_selection > 0:
                        middle_idx = np.argmax(np.array(gaussian_filter))
                    else:
                        middle_idx = len(frames) // 2
                    for i in range(len(frames)):
                        weight_tensor.append(
                            self.aggregationWeightingNetwork(
                                torch.cat([
                                    tenWarpedFeat[i], tenWarpedMask[i],
                                    tenWarpedFeat[middle_idx],
                                    tenWarpedMask[middle_idx]
                                ], 1)))
                    weight_tensor = torch.stack(weight_tensor, 0)
                    if self.args.gumbel > 0:
                        weight_tensor = gumbel_softmax(weight_tensor,
                                                       hard=True,
                                                       dim=0)
                    else:
                        weight_tensor = torch.softmax(weight_tensor, 0)
                    global_average_pooled_feature = torch.sum(color_tensor *
                                                              weight_tensor,
                                                              dim=0)
                elif self.args.pooling_type == 'CNN_flowError':
                    # consistency
                    def length_sq(x):
                        """Sum the squared entries of ``x`` over dim 1, keeping that dim."""
                        squared = x**2
                        return squared.sum(dim=1, keepdim=True)

                    flow_errors = []
                    for i in range(len(frames)):
                        flow_fw_warped = backwarp(F_n_to_k_s[i], F_k_to_n_s[i])
                        flow_diff_bw = F_k_to_n_s[i] + flow_fw_warped
                        if self.args.bundle_forward_flow > 0:
                            flow_errors.append(
                                softsplat.FunctionSoftsplat(
                                    tenInput=length_sq(flow_diff_bw),
                                    tenFlow=F_kprime_to_k,
                                    tenMetric=None,
                                    strType='average'))
                        else:
                            flow_errors.append(
                                backwarp(tenInput=length_sq(flow_diff_bw),
                                         tenFlow=F_kprime_to_k))
                        # flow_errors.append(length_sq(flow_diff_bw))
                    # TODO: deep blend
                    color_tensor = []
                    for i in range(len(frames)):
                        color_tensor.append(tenWarpedFeat[i])
                    color_tensor = torch.stack(color_tensor, 0)
                    weight_tensor = []
                    if self.args.inference_with_frame_selection > 0:
                        middle_idx = np.argmax(np.array(gaussian_filter))
                    else:
                        middle_idx = len(frames) // 2
                    for i in range(len(frames)):
                        if self.args.FOV_expansion > 0:
                            flow_errors[i] = torch.nn.ReplicationPad2d(
                                (W, W, H, H))(flow_errors[i])

                        weight_tensor.append(
                            self.aggregationWeightingNetwork(
                                torch.cat([
                                    tenWarpedFeat[i], tenWarpedMask[i],
                                    flow_errors[i], tenWarpedFeat[middle_idx],
                                    tenWarpedMask[middle_idx]
                                ], 1)))
                    weight_tensor = torch.stack(weight_tensor, 0)
                    if self.args.gumbel > 0:
                        weight_tensor = gumbel_softmax(weight_tensor,
                                                       hard=True,
                                                       dim=0)
                    else:
                        weight_tensor = torch.softmax(weight_tensor, 0)
                    global_average_pooled_feature = torch.sum(color_tensor *
                                                              weight_tensor,
                                                              dim=0)
        """decoder"""
        if self.args.no_pooling > 0 and self.args.single_decoder <= 0:
            I_preds = []
            Cs = []
            for i in range(len(frames)):
                I_pred, C = self.refinementNetwork(
                    torch.cat([tenWarpedFeat[i], tenWarpedMask[i]], 1))
                I_preds.append(I_pred)
                Cs.append(C)
        elif self.args.no_pooling <= 0 and self.args.single_decoder > 0:
            I_pred, C = self.refinementNetwork(global_average_pooled_feature)
            if h_padded:
                I_pred = I_pred[:, :, 0:h0, :]
            if w_padded:
                I_pred = I_pred[:, :, :, 0:w0]
            return I_pred
        else:
            if self.args.decoder_with_mask <= 0:
                I_preds = []
                Cs = []
                for i in range(len(frames)):
                    I_pred, C = self.refinementNetwork(
                        torch.cat(
                            [tenWarpedFeat[i], global_average_pooled_feature],
                            1))
                    I_preds.append(I_pred)
                    Cs.append(C)
            else:
                I_preds = []
                Cs = []
                for i in range(len(frames)):
                    I_pred, C = self.refinementNetwork(
                        torch.cat([
                            tenWarpedFeat[i], global_average_pooled_feature,
                            tenWarpedMask[i]
                        ], 1))
                    I_preds.append(I_pred)
                    Cs.append(C)

        # TODO: residual detail transfer
        # if self.args.residual_detail_transfer > 0:
        #     # print('with residual detail transfer')
        #     tenOnes = torch.ones([self.args.batch_size, 1, int(list(frames[0].size())[2]), int(list(frames[0].size())[3])]).cuda().detach()
        #     for i in range(len(frames)):
        #         reconstructed_frame, _ = self.refinementNetwork(torch.cat([features[i], features[i], tenOnes], 1))
        #         delta_frame = frames[i] - reconstructed_frame
        #         warped_delta_frame = backwarp(delta_frame, flows[i])
        #         # warped_delta_frame = softsplat.FunctionSoftsplat(tenInput=delta_frame, tenFlow=flows[i], tenMetric=self.beta * tenMetrics[i], strType=self.splatting_type)
        #         if self.args.residual_detail_transfer_with_mask > 0:
        #             I_preds[i] = I_preds[i] + warped_delta_frame * tenMasks[i]
        #         else:
        #             I_preds[i] = I_preds[i] + warped_delta_frame

        # center residual detail transfer (weird, worse PSNR due to shared weights decoder)
        if self.args.center_residual_detail_transfer > 0:
            tenOnes = torch.ones([
                self.args.batch_size, 1,
                int(list(frames[0].size())[2]),
                int(list(frames[0].size())[3])
            ]).cuda().detach()
            center_idx = len(frames) // 2
            reconstructed_frame, _ = self.refinementNetwork(
                torch.cat(
                    [features[center_idx], features[center_idx], tenOnes], 1))
            delta_frame = frames[center_idx] - reconstructed_frame
            warped_delta_frame = backwarp(delta_frame, F_kprime_to_k)
            # center residual detail transfer with mask
            I_preds[center_idx] = I_preds[
                center_idx] + warped_delta_frame * tenWarpedMask[center_idx]

        # all frames residual detail transfer
        else:
            if self.args.residual_detail_transfer > 0:
                tenOnes = torch.ones([
                    self.args.batch_size, 1,
                    int(list(frames[0].size())[2]),
                    int(list(frames[0].size())[3])
                ]).cuda().detach()
                for i in range(len(frames)):
                    reconstructed_frame, _ = self.refinementNetwork(
                        torch.cat([features[i], features[i], tenOnes], 1))
                    delta_frame = frames[i] - reconstructed_frame
                    # warped_delta_frame = backwarp(delta_frame, F_kprime_to_k)
                    if self.args.all_backward > 0:
                        ref_frame_flow = torch.nn.ReplicationPad2d(
                            (W, W, H, H))(F_k_to_n_s[idx])
                        tenRef = torch.nn.ReplicationPad2d(
                            (W, W, H, H))(delta_frame)
                        tenWarpedFirst = backwarp(tenInput=tenRef,
                                                  tenFlow=ref_frame_flow)
                    else:
                        """padding for forward warping"""
                        ref_frame_flow = torch.nn.ReplicationPad2d(
                            (W, W, H, H))(F_n_to_k_s[i])
                        tenRef = torch.nn.ReplicationPad2d(
                            (W, W, H, H))(delta_frame)
                        """first forward warping"""
                        tenWarpedFirst = softsplat.FunctionSoftsplat(
                            tenInput=tenRef,
                            tenFlow=ref_frame_flow,
                            tenMetric=None,
                            strType='average')
                    """second backward warping"""
                    if self.args.bundle_forward_flow > 0:
                        tenWarpedSecond = softsplat.FunctionSoftsplat(
                            tenInput=tenWarpedFirst,
                            tenFlow=F_kprime_to_k_pad,
                            tenMetric=None,
                            strType='average')
                    else:
                        tenWarpedSecond = backwarp(tenInput=tenWarpedFirst,
                                                   tenFlow=F_kprime_to_k_pad)
                    """back to original resolution"""
                    if self.args.FOV_expansion <= 0:
                        tenWarpedResidual = tenWarpedSecond[:, :, H:-H, W:-W]
                    else:
                        tenWarpedResidual = tenWarpedSecond
                    # center residual detail transfer with mask
                    # if self.args.masked_residual_detail_transfer > 0 and i != len(frames)//2:
                    #     I_preds[i] = I_preds[i] + tenWarpedResidual * torch.clamp((tenWarpedMask[i] - tenWarpedMask[len(frames)//2]), min=0.0)
                    # else:
                    I_preds[
                        i] = I_preds[i] + tenWarpedResidual * tenWarpedMask[i]

        # blending
        # if self.training:
        if not self.args.softargmax_with_mask:
            softmax_conf = torch.softmax(torch.stack(Cs, -1), -1)
        else:
            tmp_list = []
            for i in range(len(frames)):
                tmp_list.append(Cs[i] * tenWarpedMask[i])
            softmax_conf = torch.softmax(torch.stack(tmp_list, -1), -1)

        I_pred = (I_preds[0] * softmax_conf[..., 0])
        for i in range(1, len(frames)):
            I_pred += (I_preds[i] * softmax_conf[..., i])

        if self.args.seamless > 0:
            import numpy as np
            # print(gaussian_filter)
            # print(np.argmax(np.array(gaussian_filter)))
            # print(len(gaussian_filter))
            # print(len(frames))
            # import pdb
            # pdb.set_trace()
            if self.args.inference_with_frame_selection > 0:
                middle_idx = np.argmax(np.array(gaussian_filter))
            else:
                middle_idx = len(frames) // 2

            ref_frame_flow = torch.nn.ReplicationPad2d(
                (W, W, H, H))(F_n_to_k_s[middle_idx])
            tenRef = torch.nn.ReplicationPad2d(
                (W, W, H, H))(frames[middle_idx])
            """first forward warping"""
            tenWarpedFirst = softsplat.FunctionSoftsplat(
                tenInput=tenRef,
                tenFlow=ref_frame_flow,
                tenMetric=None,
                strType='average')
            """second backward warping"""
            if self.args.bundle_forward_flow > 0:
                tenWarpedSecond = softsplat.FunctionSoftsplat(
                    tenInput=tenWarpedFirst,
                    tenFlow=F_kprime_to_k_pad,
                    tenMetric=None,
                    strType='average')
            else:
                tenWarpedSecond = backwarp(tenInput=tenWarpedFirst,
                                           tenFlow=F_kprime_to_k_pad)
            """back to original resolution"""
            if self.args.FOV_expansion <= 0:
                tenWarpedKey = tenWarpedSecond[:, :, H:-H, W:-W]
            else:
                tenWarpedKey = tenWarpedSecond

            # """warped key frame"""
            # tenRef = torch.nn.ReplicationPad2d((W, W, H, H))(frames[len(frames)//2])
            # tenWarpedSecond = backwarp(tenInput=tenRef, tenFlow=F_kprime_to_k_pad)
            # """back to original resolution"""
            # if self.args.FOV_expansion <= 0:
            #     tenWarpedKey = tenWarpedSecond[:, :, H:-H, W:-W]
            # else:
            #     tenWarpedKey = tenWarpedSecond
            npWarpedKey = np.transpose(
                np.clip(tenWarpedKey.cpu().numpy(), 0.0, 1.0)[0, ::-1],
                (1, 2, 0))
            np_I_pred = np.transpose(
                np.clip(I_pred.cpu().numpy(), 0.0, 1.0)[0, ::-1], (1, 2, 0))
            np_mask_A = np.clip(tenWarpedMask[middle_idx].cpu().numpy(), 0.0,
                                1.0)[0, 0, :, :]

            np_mask_B = np.clip(1.0 - tenWarpedMask[middle_idx].cpu().numpy(),
                                0.0, 1.0)[0, 0, :, :]

            # """np_mask[0] = 0
            # np_mask[-1] = 0
            # np_mask[:, 0] = 0
            # np_mask[:, -1] = 0"""
            #
            # """erosion_size = 10
            # element = cv2.getStructuringElement(cv2.MORPH_RECT, (2 * erosion_size + 1, 2 * erosion_size + 1),
            #                            (erosion_size, erosion_size))
            #
            # np_mask_dst = cv2.erode(np_mask, element)
            # cv2.imwrite('np_mask_dst.png', np_mask_dst)"""
            #
            # cv2.imwrite('np_mask.png', np_mask)
            # cv2.imwrite('np_I_pred.png', np_I_pred)
            # cv2.imwrite('npWarpedKey.png', npWarpedKey)
            #
            # normal_clone = cv2.seamlessClone(np_I_pred, npWarpedKey, np_mask, (np_mask.shape[1] // 2, np_mask.shape[0] // 2), cv2.NORMAL_CLONE)
            # I_pred = torch.from_numpy(np.expand_dims(np.transpose(normal_clone[:, :, ::-1], (2, 0, 1)), 0).astype(np.float32) / 255.0).cuda()

            def GaussianPyramid(img, leveln):
                """Return ``[img, pyrDown(img), pyrDown(pyrDown(img)), ...]``.

                The list starts with ``img`` itself and is extended by
                ``leveln - 1`` successive ``cv2.pyrDown`` results; for
                ``leveln <= 1`` only ``img`` is returned.
                """
                levels = [img]
                for _ in range(leveln - 1):
                    # Each new level is the downsampled version of the previous one.
                    levels.append(cv2.pyrDown(levels[-1]))
                return levels

            def LaplacianPyramid(img, leveln):
                """Build a Laplacian pyramid of ``img`` with ``leveln`` entries.

                Each of the first ``leveln - 1`` entries is the current level
                minus the ``cv2.pyrUp`` of its ``cv2.pyrDown``; the last entry is
                the remaining downsampled image.  For ``leveln <= 1`` the result
                is ``[img]``.
                """
                LP = []
                for _ in range(leveln - 1):
                    next_img = cv2.pyrDown(img)
                    # The second positional parameter of cv2.pyrUp is ``dst``, so
                    # the target (width, height) must be passed as ``dstsize`` for
                    # the upsampled image to match ``img`` (needed for odd sizes).
                    upsampled = cv2.pyrUp(next_img, dstsize=img.shape[1::-1])
                    LP.append(img - upsampled)
                    img = next_img
                LP.append(img)
                return LP

            def blend_pyramid(LPA, LPB, MA, MB):
                """Blend two image pyramids level by level with weight masks.

                For every level ``i`` the result is
                ``LPA[i] * MA[i] + LPB[i] * MB[i]``.  When both masks of a
                level have fewer than three dimensions, a trailing axis is
                added to each so they broadcast over the image channels.
                """
                mixed = []
                for level, mask_a in enumerate(MA):
                    mask_b = MB[level]
                    if mask_a.ndim < 3 and mask_b.ndim < 3:
                        # 2-D masks: append a channel axis for broadcasting.
                        mask_a = np.expand_dims(mask_a, -1)
                        mask_b = np.expand_dims(mask_b, -1)
                    mixed.append(LPA[level] * mask_a + LPB[level] * mask_b)
                return mixed

            def reconstruct(LS):
                """Collapse a pyramid (finest level first) back into one image.

                Starts from the last (coarsest) entry of ``LS`` and, walking
                towards the first entry, upsamples with ``cv2.pyrUp`` and adds
                the residual stored at that level.
                """
                result = LS[-1]
                for residual in reversed(LS[:-1]):
                    # NOTE(review): the size tuple lands in the second
                    # positional slot of cv2.pyrUp (documented as ``dst``)
                    # -- confirm this is intended.
                    result = cv2.pyrUp(result, residual.shape[1::-1])
                    result += residual
                return result

            minimum_multi_band = 32
            pad_h_multi_band = 0
            pad_b_multi_band = 0
            if np_mask_A.shape[0] % minimum_multi_band != 0:
                pad_h_multi_band = (
                    minimum_multi_band -
                    (np_mask_A.shape[0] % minimum_multi_band)) // 2
                pad_b_multi_band = minimum_multi_band - (
                    np_mask_A.shape[0] % minimum_multi_band) - pad_h_multi_band
                print(pad_h_multi_band)
                print(pad_b_multi_band)
                np_mask_A = cv2.copyMakeBorder(np_mask_A, pad_h_multi_band,
                                               pad_b_multi_band, 0, 0,
                                               cv2.BORDER_REPLICATE)
                np_mask_B = cv2.copyMakeBorder(np_mask_B, pad_h_multi_band,
                                               pad_b_multi_band, 0, 0,
                                               cv2.BORDER_REPLICATE)
                npWarpedKey = cv2.copyMakeBorder(npWarpedKey, pad_h_multi_band,
                                                 pad_b_multi_band, 0, 0,
                                                 cv2.BORDER_REPLICATE)
                np_I_pred = cv2.copyMakeBorder(np_I_pred, pad_h_multi_band,
                                               pad_b_multi_band, 0, 0,
                                               cv2.BORDER_REPLICATE)

            pad_l_multi_band = 0
            pad_r_multi_band = 0
            if np_mask_A.shape[1] % minimum_multi_band != 0:
                pad_l_multi_band = (
                    minimum_multi_band -
                    (np_mask_A.shape[1] % minimum_multi_band)) // 2
                pad_r_multi_band = minimum_multi_band - (
                    np_mask_A.shape[1] % minimum_multi_band) - pad_l_multi_band
                np_mask_A = cv2.copyMakeBorder(np_mask_A, 0, 0,
                                               pad_l_multi_band,
                                               pad_r_multi_band,
                                               cv2.BORDER_REPLICATE)
                np_mask_B = cv2.copyMakeBorder(np_mask_B, 0, 0,
                                               pad_l_multi_band,
                                               pad_r_multi_band,
                                               cv2.BORDER_REPLICATE)
                npWarpedKey = cv2.copyMakeBorder(npWarpedKey, 0, 0,
                                                 pad_l_multi_band,
                                                 pad_r_multi_band,
                                                 cv2.BORDER_REPLICATE)
                np_I_pred = cv2.copyMakeBorder(np_I_pred, 0, 0,
                                               pad_l_multi_band,
                                               pad_r_multi_band,
                                               cv2.BORDER_REPLICATE)

            MA = GaussianPyramid(np_mask_A, 5)
            MB = GaussianPyramid(np_mask_B, 5)
            LPA = LaplacianPyramid(npWarpedKey, 5)
            LPB = LaplacianPyramid(np_I_pred, 5)
            # Blend two Laplacian pyramidspass
            blended = blend_pyramid(LPA, LPB, MA, MB)
            # Reconstruction process
            frame_A = reconstruct(blended)
            frame_A[frame_A > 1.0] = 1.0
            frame_A[frame_A < 0.0] = 0.0
            # print(frame_A.shape)

            if pad_h_multi_band != 0 or pad_b_multi_band != 0:
                frame_A = frame_A[pad_h_multi_band:-pad_b_multi_band]
            if pad_l_multi_band != 0 or pad_r_multi_band != 0:
                frame_A = frame_A[:, pad_l_multi_band:-pad_r_multi_band]

            # print(frame_A.shape)

            I_pred = torch.from_numpy(
                np.expand_dims(np.transpose(frame_A[:, :, ::-1], (2, 0, 1)),
                               0).astype(np.float32)).cuda()
            """cv2.imwrite('_np_mask_A.png', np.round(np_mask_A*255).astype(np.uint8))
            cv2.imwrite('_np_mask_B.png', np.round(np_mask_B*255).astype(np.uint8))
            cv2.imwrite('_frame_A.png', np.round(frame_A*255).astype(np.uint8))
            cv2.imwrite('_npWarpedKey.png', np.round(npWarpedKey*255).astype(np.uint8))
            cv2.imwrite('_np_I_pred.png', np.round(np_I_pred*255).astype(np.uint8))"""

        if h_padded:
            I_pred = I_pred[:, :, 0:h0, :]
        if w_padded:
            I_pred = I_pred[:, :, :, 0:w0]
        return I_pred
コード例 #2
0
        cv2.imread(filename='./images/second.png', flags=-1).transpose(
            2, 0, 1)[None, :, :, :].astype(numpy.float32) *
        (1.0 / 255.0))).cuda()
tenFlow = torch.FloatTensor(
    numpy.ascontiguousarray(
        read_flo('./images/flow.flo').transpose(2, 0,
                                                1)[None, :, :, :])).cuda()

tenMetric = torch.nn.functional.l1_loss(input=tenFirst,
                                        target=backwarp(tenInput=tenSecond,
                                                        tenFlow=tenFlow),
                                        reduction='none').mean(1, True)

for intTime, fltTime in enumerate(numpy.linspace(0.0, 1.0, 11).tolist()):
    tenSummation = softsplat.FunctionSoftsplat(tenInput=tenFirst,
                                               tenFlow=tenFlow * fltTime,
                                               tenMetric=None,
                                               strType='summation')
    tenAverage = softsplat.FunctionSoftsplat(tenInput=tenFirst,
                                             tenFlow=tenFlow * fltTime,
                                             tenMetric=None,
                                             strType='average')
    tenLinear = softsplat.FunctionSoftsplat(
        tenInput=tenFirst,
        tenFlow=tenFlow * fltTime,
        tenMetric=(0.3 - tenMetric).clamp(0.0000001, 1.0),
        strType='linear'
    )  # finding a good linearly metric is difficult, and it is not invariant to translations
    tenSoftmax = softsplat.FunctionSoftsplat(
        tenInput=tenFirst,
        tenFlow=tenFlow * fltTime,
        tenMetric=-20.0 * tenMetric,
コード例 #3
0
def splat_rgb_img(ret, ratio, R_w2t, t_w2t, j, H, W, focal, fwd_flow):
    """Forward-splat the ``j``-th sample layer of a render into a target view.

    Two RGBA layers are splatted with average softsplat:

    * the dynamic layer (``raw_rgb`` / ``raw_alpha``), whose reference points
      are first moved ``ratio`` of the way along the scene flow
      (``raw_sf_ref2post`` if ``fwd_flow`` is truthy, else
      ``raw_sf_ref2prev``);
    * the rigid layer (``raw_rgb_rigid`` / ``raw_alpha_rigid``), which uses
      the reference points unchanged.

    In both cases the points are transformed with ``R_w2t`` / ``t_w2t``,
    projected to 2-D, and the offset from a regular pixel grid is used as the
    splatting flow.

    Args:
        ret: mapping holding the ``raw_*`` tensors and ``pts_ref`` read above.
        ratio: interpolation factor between the reference and displaced
            points.
        R_w2t, t_w2t: rotation / translation handed to
            ``se3_transform_points`` (presumably world-to-target -- confirm).
        j: index selected along the third axis of the ``ret`` tensors.
        H, W, focal: image height, width and focal length forwarded to the
            projection helpers.
        fwd_flow: chooses which scene-flow entry of ``ret`` is used.

    Returns:
        ``(splat_alpha_dy, splat_rgb_dy, splat_alpha_rig, splat_rgb_rig)``;
        the alpha entries are channel ``3:4`` and the rgb entries channels
        ``0:3`` of the first batch item of each splat result.
    """
    import softsplat

    def _as_batched_cuda(hwc):
        # Channels-last 3-D tensor -> channels-first with a leading batch
        # axis, contiguous, on the GPU.
        return hwc.permute(2, 0, 1).unsqueeze(0).contiguous().cuda()

    def _average_splat(rgba, flow):
        return softsplat.FunctionSoftsplat(tenInput=rgba,
                                           tenFlow=flow,
                                           tenMetric=None,
                                           strType='average')

    # ---- dynamic layer ----------------------------------------------------
    rgba_all = torch.cat([ret['raw_rgb'], ret['raw_alpha'].unsqueeze(-1)],
                         dim=-1)
    rgba_dy = _as_batched_cuda(rgba_all[:, :, j, :])

    ref_pts = ret['pts_ref'][:, :, j, :3]
    ref_world = NDC2Euclidean(ref_pts, H, W, focal)

    flow_key = 'raw_sf_ref2post' if fwd_flow else 'raw_sf_ref2prev'
    moved_pts = ref_pts + ret[flow_key][:, :, j, :]
    moved_world = NDC2Euclidean(moved_pts, H, W, focal)

    # Linear interpolation between the reference and displaced positions.
    mid_world = (moved_world - ref_world) * ratio + ref_world

    rotation = R_w2t.unsqueeze(0).unsqueeze(0)
    translation = t_w2t.unsqueeze(0).unsqueeze(0)

    mid_local = se3_transform_points(mid_world, rotation, translation)
    mid_px = perspective_projection(mid_local, H, W, focal)

    # Regular pixel grid; the transposes turn the (W, H) meshgrid outputs
    # into (H, W) maps before stacking them as (x, y).
    grid_x, grid_y = torch.meshgrid(torch.linspace(0, W - 1, W),
                                    torch.linspace(0, H - 1, H))
    pixel_grid = torch.stack([grid_x.t(), grid_y.t()], -1)

    flow_dy = _as_batched_cuda(mid_px - pixel_grid)
    splat_dy = _average_splat(rgba_dy, flow_dy)

    # ---- rigid (static) layer ---------------------------------------------
    rig_local = se3_transform_points(ref_world, rotation, translation)
    rig_px = perspective_projection(rig_local, H, W, focal)
    flow_rig = _as_batched_cuda(rig_px - pixel_grid)

    rgba_rig_all = torch.cat(
        [ret['raw_rgb_rigid'], ret['raw_alpha_rigid'].unsqueeze(-1)], dim=-1)
    rgba_rig = _as_batched_cuda(rgba_rig_all[:, :, j, :])
    splat_rig = _average_splat(rgba_rig, flow_rig)

    # Split each splatted RGBA result of the first batch item into alpha/rgb.
    return (splat_dy[0, 3:4, :, :], splat_dy[0, 0:3, :, :],
            splat_rig[0, 3:4, :, :], splat_rig[0, 0:3, :, :])
コード例 #4
0
def img2tex_forwardwarp(img_tensor,
                        iuv_tensor,
                        class_num=25,
                        tex_patch_size=256,
                        img_space_size=512,
                        inpaint_func=None,
                        warp_mask=True,
                        use_prob=False):
    """Forward-warp an image into per-part texture patches via softsplat.

    For every part channel ``1 .. class_num - 1`` of ``iuv_tensor`` the image
    is weighted by that channel, splatted (``strType='average'``) with a flow
    of ``uv * (tex_patch_size - 1) - coords``, cropped to the top-left
    ``tex_patch_size`` square, optionally inpainted, and accumulated into the
    output slot ``PartInd - 1``.

    Args:
        img_tensor: 4-D image batch; its first two dims are read as
            ``(b, c)``.
        iuv_tensor: tensor with ``class_num + 2`` channels; the last two
            channels are the UV coordinates and must lie in ``[0, 1]``.
            Channel 0 is never splatted (presumably background -- confirm).
        class_num: number of part channels in ``iuv_tensor``.
        tex_patch_size: side length of each square texture patch.
        img_space_size: an int (used for both grid dimensions) or a 2-element
            tuple/list forwarded to ``make_meshgrid``.
        inpaint_func: optional callable applied to each splatted patch/mask.
        warp_mask: when true, the part channel itself is splatted as well and
            returned as a mask.
        use_prob: when true a pixel belongs to a part if its channel value
            exceeds ``1 / class_num``; otherwise if it exceeds ``0``.

    Returns:
        ``TextureIm`` as float32 with shape
        ``(b, max(24, class_num - 1), c, tex_patch_size, tex_patch_size)``,
        plus the matching single-channel ``TextureMask`` when ``warp_mask``
        is true.

    Raises:
        AssertionError: if the channel count or the UV range of
            ``iuv_tensor`` is invalid.
    """
    b, c, _, _ = img_tensor.shape
    assert iuv_tensor.shape[
        1] == class_num + 2, "input I channel should be onehot encoded"
    assert iuv_tensor[:, -2:].min() >= 0 and iuv_tensor[:, -2:].max(
    ) <= 1, "input UV channels should be normalized into [0,1]"

    device = iuv_tensor.device
    # Historically 24 slots were always allocated; keep that as the minimum
    # but grow when class_num needs more so PartInd - 1 never overflows.
    num_patches = max(24, class_num - 1)
    TextureIm = torch.zeros(
        [b, num_patches, c, tex_patch_size, tex_patch_size],
        dtype=torch.float64).to(device)
    if warp_mask:
        TextureMask = torch.zeros(
            [b, num_patches, 1, tex_patch_size, tex_patch_size],
            dtype=torch.float64).to(device)

    # The coordinate grid does not depend on the part, so build it once.
    if isinstance(img_space_size, (tuple, list)):
        size_0, size_1 = img_space_size[0], img_space_size[1]
    else:
        size_0 = size_1 = img_space_size
    coords = make_meshgrid(b, size_0, size_1, norm=False).to(device)

    def _splat_to_patch(source, flow):
        # Forward warp, crop to the texture patch, then optionally inpaint.
        # softsplat is always fed float32 input, whatever dtype came in.
        patch = softsplat.FunctionSoftsplat(tenInput=source.float(),
                                            tenFlow=flow.float(),
                                            tenMetric=None,
                                            strType='average')
        patch = patch[:, :, :tex_patch_size, :tex_patch_size]
        if inpaint_func is not None:
            patch = inpaint_func(patch)
        return patch

    for PartInd in range(
            1, class_num):  ## Set to xrange(1,23) to ignore the face part.
        prob = iuv_tensor[:, PartInd:PartInd + 1]
        # Soft inputs are thresholded at the uniform level 1/class_num,
        # hard (one-hot) inputs at zero.
        mask_threshold = 1. / class_num if use_prob else 0
        part_mask = (prob > mask_threshold).float()

        ## create flow
        uv = part_mask * iuv_tensor[:, -2:]
        uv = uv * (tex_patch_size - 1)
        flow = uv - coords

        ## create img and forward warp
        TextureIm[:, PartInd - 1] += _splat_to_patch(img_tensor * prob, flow)
        if warp_mask:
            TextureMask[:, PartInd - 1] += _splat_to_patch(prob, flow)

    if warp_mask:
        return TextureIm.float(), TextureMask.float()
    return TextureIm.float()
コード例 #5
0
            HHH = 256
            tenOnes = torch.ones_like(input_frames[0])[:, 0:1, :, :]
            tenOnes = torch.nn.ZeroPad2d(
                (WWW, WWW, HHH, HHH))(tenOnes).detach()
            F_kprime_to_k_pad = torch.nn.ReplicationPad2d(
                (WWW, WWW, HHH, HHH))(F_kprime_to_k)
            tenWarpedFeat = []
            tenWarpedMask = []
            for iii, feat in enumerate(input_frames):
                """padding for forward warping"""
                ref_frame_flow = torch.nn.ReplicationPad2d(
                    (WWW, WWW, HHH, HHH))(forward_flows[iii])
                """first forward warping"""
                tenMaskFirst = softsplat.FunctionSoftsplat(
                    tenInput=tenOnes,
                    tenFlow=ref_frame_flow,
                    tenMetric=None,
                    strType='average')
                """second backward warping"""
                tenMaskSecond = backwarp(tenInput=tenMaskFirst,
                                         tenFlow=F_kprime_to_k_pad)
                """back to original resolution"""
                tenMask = tenMaskSecond
                tenWarpedMask.append(tenMask)
            weight_tensor = torch.stack(tenWarpedMask, 0)
            output_mask = torch.sum(weight_tensor, dim=0)
            output_mask = torch.clamp(output_mask, max=1.0)
            # imwrite(output_mask, str(idx-GAUSSIAN_FILTER_KSIZE//2).zfill(5)+'_mask.png', range=(0, 1))

            large_mask_chain.append(output_mask.detach().cpu())
コード例 #6
0
                pic_B.transpose(2, 0, 1)[None, :, :, :].astype(numpy.float32) *
                (1.0 / 255.0))).cuda()
        tenFlow = torch.FloatTensor(
            numpy.ascontiguousarray(
                read_flo(arg_Flow).transpose(2, 0, 1)[None, :, :, :])).cuda()

        tenMetric = torch.nn.functional.l1_loss(
            input=tenFirst,
            target=backwarp(tenInput=tenSecond, tenFlow=tenFlow),
            reduction='none').mean(1, True)

        img_array.append(pic_A)
        for t in time:
            tenSoftmax = softsplat.FunctionSoftsplat(tenInput=tenFirst,
                                                     tenFlow=tenFlow * t,
                                                     tenMetric=-20.0 *
                                                     tenMetric,
                                                     strType='softmax')
            img = tenSoftmax[0, :, :, :].cpu().numpy().transpose(1, 2, 0)
            img = (img * 255).astype(numpy.uint8)
            img_array.append(img)
    # end

    print('-- processing %d / %d' % (filecounts, filecounts))
    print('Creating video...')

    out = cv2.VideoWriter('project.mp4', cv2.VideoWriter_fourcc(*'mp4v'),
                          arg_FPS * FPS, (arg_width, arg_height))

    print('Total images : ' + str(len(img_array)))
    for i in range(len(img_array)):