Example 1
    def one_scale(cam_flow_fwd, cam_flow_bwd, flow_fwd, flow_bwd, tgt_img, ref_img_fwd, ref_img_bwd, ws):
        b, _, h, w = cam_flow_fwd.size()
        tgt_img_scaled = nn.functional.adaptive_avg_pool2d(tgt_img, (h, w))
        ref_img_scaled_fwd = nn.functional.adaptive_avg_pool2d(ref_img_fwd, (h, w))
        ref_img_scaled_bwd = nn.functional.adaptive_avg_pool2d(ref_img_bwd, (h, w))

        cam_warped_im_fwd = flow_warp(ref_img_scaled_fwd, cam_flow_fwd)
        cam_warped_im_bwd = flow_warp(ref_img_scaled_bwd, cam_flow_bwd)

        flow_warped_im_fwd = flow_warp(ref_img_scaled_fwd, flow_fwd)
        flow_warped_im_bwd = flow_warp(ref_img_scaled_bwd, flow_bwd)

        valid_pixels_cam_fwd = 1 - (cam_warped_im_fwd == 0).prod(1, keepdim=True).type_as(cam_warped_im_fwd)
        valid_pixels_cam_bwd = 1 - (cam_warped_im_bwd == 0).prod(1, keepdim=True).type_as(cam_warped_im_bwd)
        valid_pixels_cam = logical_or(valid_pixels_cam_fwd, valid_pixels_cam_bwd)  # if one of them is valid, then valid

        valid_pixels_flow_fwd = 1 - (flow_warped_im_fwd == 0).prod(1, keepdim=True).type_as(flow_warped_im_fwd)
        valid_pixels_flow_bwd = 1 - (flow_warped_im_bwd == 0).prod(1, keepdim=True).type_as(flow_warped_im_bwd)
        valid_pixels_flow = logical_or(valid_pixels_flow_fwd, valid_pixels_flow_bwd)  # if one of them is valid, then valid

        cam_err_fwd = ((1 - wssim) * robust_l1_per_pix(tgt_img_scaled - cam_warped_im_fwd).mean(1, keepdim=True)
                       + wssim * (1 - ssim(tgt_img_scaled, cam_warped_im_fwd)).mean(1, keepdim=True))
        cam_err_bwd = ((1 - wssim) * robust_l1_per_pix(tgt_img_scaled - cam_warped_im_bwd).mean(1, keepdim=True)
                       + wssim * (1 - ssim(tgt_img_scaled, cam_warped_im_bwd)).mean(1, keepdim=True))
        cam_err = torch.min(cam_err_fwd, cam_err_bwd) * valid_pixels_cam

        flow_err = (1-wssim)*robust_l1_per_pix(tgt_img_scaled - flow_warped_im_fwd).mean(1, keepdim=True) \
                    + wssim*(1 - ssim(tgt_img_scaled, flow_warped_im_fwd)).mean(1, keepdim=True)
        # flow_err_bwd = (1-wssim)*robust_l1_per_pix(tgt_img_scaled - flow_warped_im_bwd).mean(1, keepdim=True) \
        #             + wssim*(1 - ssim(tgt_img_scaled, flow_warped_im_bwd)).mean(1, keepdim=True)
        # flow_err = torch.min(flow_err_fwd, flow_err_bwd)

        exp_target = (wrig*cam_err <= (flow_err+epsilon)).type_as(cam_err)

        return exp_target
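
For reference, minimal sketches of two helpers this example assumes (`logical_or` on {0, 1} float masks and the per-pixel robust penalty); the actual repository implementations may differ:

import torch

def logical_or(a, b):
    # OR on {0, 1} float masks: 1 wherever either input is 1 (assumed semantics).
    return torch.clamp(a + b, max=1)

def robust_l1_per_pix(x, q=0.5, eps=1e-2):
    # Per-pixel robust (Charbonnier-style) L1; q and eps are assumed defaults.
    return torch.pow(x ** 2 + eps ** 2, q)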
Example 2
    def one_scale(explainability_mask, occ_masks, flows):
        assert(explainability_mask is None or flows[0].size()[2:] == explainability_mask.size()[2:])
        assert(len(flows) == len(ref_imgs))

        reconstruction_loss = 0
        b, _, h, w = flows[0].size()
        downscale = tgt_img.size(2)/h

        tgt_img_scaled = nn.functional.adaptive_avg_pool2d(tgt_img, (h, w))
        ref_imgs_scaled = [nn.functional.adaptive_avg_pool2d(ref_img, (h, w)) for ref_img in ref_imgs]

        weight = 1.

        for i, ref_img in enumerate(ref_imgs_scaled):
            current_flow = flows[i]

            ref_img_warped = flow_warp(ref_img, current_flow)  # equation (48): w_c
            valid_pixels = 1 - (ref_img_warped == 0).prod(1, keepdim=True).type_as(ref_img_warped)
            diff = (tgt_img_scaled - ref_img_warped) * valid_pixels
            ssim_loss = (1 - ssim(tgt_img_scaled, ref_img_warped)) * valid_pixels
            oob_normalization_const = valid_pixels.nelement()/valid_pixels.sum()

            if explainability_mask is not None:
                diff = diff * explainability_mask[:,i:i+1].expand_as(diff)
                ssim_loss = ssim_loss * explainability_mask[:,i:i+1].expand_as(ssim_loss)

            if occ_masks is not None:
                diff = diff * (1 - occ_masks[:, i:i+1]).expand_as(diff)
                ssim_loss = ssim_loss * (1 - occ_masks[:, i:i+1]).expand_as(ssim_loss)

            reconstruction_loss += weight*oob_normalization_const*((1 - wssim)*robust_l1(diff, q=qch) + wssim*ssim_loss.mean()) + lambda_oob*robust_l1(1 - valid_pixels, q=qch)
            #weight /= 2.83
            assert((reconstruction_loss == reconstruction_loss).item() == 1)  # NaN guard

        return reconstruction_loss
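
Example 2 reduces the robust penalty to a scalar via `robust_l1`; a minimal sketch under the same Charbonnier assumption as above:

def robust_l1(x, q=0.5, eps=1e-2):
    # Scalar (mean-reduced) counterpart of robust_l1_per_pix (assumed form).
    return torch.pow(x ** 2 + eps ** 2, q).mean()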
Example 3
    def one_scale(flows):
        #assert(explainability_mask is None or flows[0].size()[2:] == explainability_mask.size()[2:])
        assert (len(flows) == len(ref_imgs))

        reconstruction_loss = 0
        b, _, h, w = flows[0].size()

        tgt_img_scaled = nn.functional.adaptive_avg_pool2d(tgt_img, (h, w))
        ref_imgs_scaled = [
            nn.functional.adaptive_avg_pool2d(ref_img, (h, w))
            for ref_img in ref_imgs
        ]

        loss = 0.0
        for i, ref_img in enumerate(ref_imgs_scaled):
            current_flow = flows[i]
            ref_img_warped = flow_warp(ref_img, current_flow)
            valid_pixels = 1 - (ref_img_warped == 0).prod(
                1, keepdim=True).type_as(ref_img_warped)
            diff = (tgt_img_scaled - ref_img_warped)
            if wssim:
                ssim_loss = 1 - ssim(tgt_img_scaled, ref_img_warped)
                reconstruction_loss = (1 - wssim) * robust_l1_per_pix(
                    diff.mean(1, True),
                    q=qch) * valid_pixels + wssim * ssim_loss.mean(1, True)
            else:
                reconstruction_loss = robust_l1_per_pix(diff.mean(1, True),
                                                        q=qch) * valid_pixels
            loss += reconstruction_loss.sum() / valid_pixels.sum()

        return loss
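
A hypothetical pyramid wiring for this variant, assuming `flows_pyramid` holds one list of per-reference flows per scale:

# Hypothetical usage: sum the per-scale losses over a flow pyramid.
total_loss = sum(one_scale(flows_at_scale) for flows_at_scale in flows_pyramid)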
Example 4
    def one_scale(flows):
        #assert(explainability_mask is None or flows[0].size()[2:] == explainability_mask.size()[2:])
        assert (len(flows) == len(ref_imgs))

        reconstruction_loss = 0
        _, _, h, w = flows[0].size()

        tgt_img_scaled = nn.functional.adaptive_avg_pool2d(tgt_img, (h, w))
        ref_imgs_scaled = [
            nn.functional.adaptive_avg_pool2d(ref_img, (h, w))
            for ref_img in ref_imgs
        ]

        reconstruction_loss_all = 0.0
        for i, ref_img in enumerate(ref_imgs_scaled):
            current_flow = flows[i]
            ref_img_warped = flow_warp(ref_img, current_flow)
            valid_pixels = 1 - (ref_img_warped == 0).prod(
                1, keepdim=True).type_as(ref_img_warped)
            reconstruction_loss = gradient_photometric_loss(
                tgt_img_scaled, ref_img_warped,
                qch) * valid_pixels[:, :, :-1, :-1]
            # reconstruction_loss = gradient_photometric_all_direction_loss(tgt_img_scaled, ref_img_warped, qch)*valid_pixels[:,:,1:-1,1:-1]

            reconstruction_loss_all += reconstruction_loss.sum(
            ) / valid_pixels[:, :, :-1, :-1].sum()

        return reconstruction_loss_all
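
A plausible sketch of `gradient_photometric_loss`, assuming it applies the robust penalty to image-gradient differences; its output is one pixel smaller in each dimension, which explains the `[:, :, :-1, :-1]` crop of `valid_pixels` above:

def gradient_photometric_loss(tgt, warped, q=0.5, eps=1e-2):
    # Robust penalty on gradient differences (sketch; assumed implementation).
    dx_t = tgt[:, :, :, 1:] - tgt[:, :, :, :-1]      # horizontal gradients
    dy_t = tgt[:, :, 1:, :] - tgt[:, :, :-1, :]      # vertical gradients
    dx_w = warped[:, :, :, 1:] - warped[:, :, :, :-1]
    dy_w = warped[:, :, 1:, :] - warped[:, :, :-1, :]
    gx = torch.pow((dx_t - dx_w) ** 2 + eps ** 2, q)
    gy = torch.pow((dy_t - dy_w) ** 2 + eps ** 2, q)
    # Crop both to (h-1, w-1) so they align, then average over channels.
    return (gx[:, :, :-1, :] + gy[:, :, :, :-1]).mean(1, keepdim=True)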
Example 5
    def one_scale(depth, flow_fwd, flow_bwd):

        b, _, h, w = depth.size()
        downscale = tgt_img.size(2) / h
        intrinsics_scaled = torch.cat(
            (intrinsics[:, 0:2] / downscale, intrinsics[:, 2:]), dim=1)
        intrinsics_scaled_inv = torch.cat(
            (intrinsics_inv[:, :, 0:2] * downscale, intrinsics_inv[:, :, 2:]),
            dim=2)
        ref_img_scaled_fwd = nn.functional.adaptive_avg_pool2d(
            ref_imgs[1], (h, w))
        ref_img_scaled_bwd = nn.functional.adaptive_avg_pool2d(
            ref_imgs[0], (h, w))

        depth_warped_im_fwd = inverse_warp(ref_img_scaled_fwd, depth[:, 0],
                                           poses[1], intrinsics_scaled,
                                           intrinsics_scaled_inv,
                                           rotation_mode, padding_mode)
        depth_warped_im_bwd = inverse_warp(ref_img_scaled_bwd, depth[:, 0],
                                           poses[0], intrinsics_scaled,
                                           intrinsics_scaled_inv,
                                           rotation_mode, padding_mode)
        valid_pixels_depth_fwd = 1 - (depth_warped_im_fwd == 0).prod(
            1, keepdim=True).type_as(depth_warped_im_fwd)
        valid_pixels_depth_bwd = 1 - (depth_warped_im_bwd == 0).prod(
            1, keepdim=True).type_as(depth_warped_im_bwd)
        valid_pixels_depth = logical_and(
            valid_pixels_depth_fwd,
            valid_pixels_depth_bwd)  # valid only if both are valid

        flow_warped_im_fwd = flow_warp(ref_img_scaled_fwd, flow_fwd)
        flow_warped_im_bwd = flow_warp(ref_img_scaled_bwd, flow_bwd)

        valid_pixels_flow_fwd = 1 - (flow_warped_im_fwd == 0).prod(
            1, keepdim=True).type_as(flow_warped_im_fwd)
        valid_pixels_flow_bwd = 1 - (flow_warped_im_bwd == 0).prod(
            1, keepdim=True).type_as(flow_warped_im_bwd)
        valid_pixels_flow = logical_and(
            valid_pixels_flow_fwd,
            valid_pixels_flow_bwd)  # valid only if both are valid

        valid_pixel = logical_or(valid_pixels_depth, valid_pixels_flow)

        return valid_pixel
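
A quick numeric check of the intrinsics scaling used above: rows 0-1 of K scale with resolution (row 2 stays [0, 0, 1]), while for K⁻¹ it is columns 0-1 that get multiplied by the downscale factor:

# Toy check (batch of one 3x3 intrinsics matrix, downscale = 2).
K = torch.tensor([[[100., 0., 64.],
                   [0., 100., 32.],
                   [0., 0., 1.]]])
K_scaled = torch.cat((K[:, 0:2] / 2.0, K[:, 2:]), dim=1)
# -> [[50., 0., 32.], [0., 50., 16.], [0., 0., 1.]]
K_inv = torch.inverse(K)
K_inv_scaled = torch.cat((K_inv[:, :, 0:2] * 2.0, K_inv[:, :, 2:]), dim=2)
# K_inv_scaled equals torch.inverse(K_scaled)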
Example 6
    def one_scale(flows):
        #assert(explainability_mask is None or flows[0].size()[2:] == explainability_mask.size()[2:])
        assert (len(flows) == len(ref_imgs))

        reconstruction_loss = 0
        b, _, h, w = flows[0].size()

        tgt_img_scaled = nn.functional.adaptive_avg_pool2d(tgt_img, (h, w))
        ref_imgs_scaled = [
            nn.functional.adaptive_avg_pool2d(ref_img, (h, w))
            for ref_img in ref_imgs
        ]

        reconstruction_loss_all = []
        for i, ref_img in enumerate(ref_imgs_scaled):
            current_flow = flows[i]
            ref_img_warped = flow_warp(ref_img, current_flow)
            # valid_pixels = 1 - (ref_img_warped == 0).prod(1, keepdim=True).type_as(ref_img_warped)
            diff = (tgt_img_scaled - ref_img_warped)
            # ssim_loss = 1 - ssim(tgt_img_scaled, ref_img_warped)
            # if wssim:
            #     reconstruction_loss = (1- wssim)*robust_l1_per_pix(diff.mean(1, True), q=qch) + wssim*ssim_loss.mean(1, True)
            # else:

            reconstruction_loss = robust_l1_per_pix(diff.mean(1, True), q=qch)

            reconstruction_loss_all.append(reconstruction_loss)

        reconstruction_loss = torch.cat(reconstruction_loss_all, 1)
        reconstruction_weight = reconstruction_loss
        # reconstruction_loss_min,_ = reconstruction_loss.min(1,keepdim=True)
        # reconstruction_loss_min = reconstruction_loss_min.repeat(1,2,1,1)
        # loss_weight = reconstruction_loss_min/reconstruction_loss
        # loss_weight = torch.pow(loss_weight,4)

        loss_weight = 1 - torch.nn.functional.softmax(reconstruction_weight, 1)
        loss_weight = Variable(loss_weight.data, requires_grad=False)  # stop gradients through the weights
        loss = reconstruction_loss * loss_weight
        # loss = torch.mean(loss,3)
        # loss = torch.mean(loss,2)
        # loss = torch.mean(loss,0)
        return loss.sum() / loss_weight.sum()
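
A toy illustration of the softmax weighting above: per-pixel errors across the two reference frames become weights that favour the smaller error:

# errs: per-reference errors at one pixel, shape (1, 2, 1, 1).
errs = torch.tensor([[[[1.0]], [[3.0]]]])
w = 1 - torch.nn.functional.softmax(errs, dim=1)
# w ≈ [0.881, 0.119]: the frame with the smaller error gets the larger weight.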
Example 7
def train(train_loader,
          disp_net,
          pose_net,
          mask_net,
          flow_net,
          optimizer,
          logger=None,
          train_writer=None,
          global_vars_dict=None):
    # 0. setup
    args = global_vars_dict['args']
    n_iter = global_vars_dict['n_iter']
    device = global_vars_dict['device']

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    w1, w2, w3, w4 = args.cam_photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight, args.flow_photo_loss_weight
    w5 = args.consensus_loss_weight

    if args.robust:
        loss_camera = photometric_reconstruction_loss_robust
        loss_flow = photometric_flow_loss_robust
    else:
        loss_camera = photometric_reconstruction_loss
        loss_flow = photometric_flow_loss


    # 2. switch to train mode
    disp_net.train()
    pose_net.train()
    mask_net.train()
    flow_net.train()

    end = time.time()
    #3. train cycle
    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)
        intrinsics_inv = intrinsics_inv.to(device)

        # 3.1 compute outputs and loss-function inputs -------------------

        #1. disp->depth(none)
        disparities = disp_net(tgt_img)
        if args.spatial_normalize:
            disparities = [spatial_normalize(disp) for disp in disparities]

        depth = [1 / disp for disp in disparities]

        #2. pose(none)
        pose = pose_net(tgt_img, ref_imgs)
        #pose:[4,4,6]

        # 3. flow_fwd, flow_bwd: full optical flow (depth, pose)
        # slightly modified from the original
        if args.flownet == 'Back2Future':  # uses the three neighbouring frames for training/inference
            flow_fwd, flow_bwd, _ = flow_net(tgt_img, ref_imgs[1:3])
        elif args.flownet == 'FlowNetC6':
            flow_fwd = flow_net(tgt_img, ref_imgs[2])
            flow_bwd = flow_net(tgt_img, ref_imgs[1])
        elif args.flownet == 'FlowNetS':
            print(' ')

        # flow_cam is the rigid (background) flow induced by camera motion
        # flow - flow_s = flow_o
        flow_cam = pose2flow(
            depth[0].squeeze(), pose[:, 2], intrinsics,
            intrinsics_inv)  # pose[:,2] belongs to forward frame
        flows_cam_fwd = [
            pose2flow(depth_.squeeze(1), pose[:, 2], intrinsics,
                      intrinsics_inv) for depth_ in depth
        ]
        flows_cam_bwd = [
            pose2flow(depth_.squeeze(1), pose[:, 1], intrinsics,
                      intrinsics_inv) for depth_ in depth
        ]

        exp_masks_target = consensus_exp_masks(flows_cam_fwd,
                                               flows_cam_bwd,
                                               flow_fwd,
                                               flow_bwd,
                                               tgt_img,
                                               ref_imgs[2],
                                               ref_imgs[1],
                                               wssim=args.wssim,
                                               wrig=args.wrig,
                                               ws=args.smooth_loss_weight)
        rigidity_mask_fwd = [
            (flows_cam_fwd_i - flow_fwd_i).abs()
            for flows_cam_fwd_i, flow_fwd_i in zip(flows_cam_fwd, flow_fwd)
        ]  # .normalize()
        rigidity_mask_bwd = [
            (flows_cam_bwd_i - flow_bwd_i).abs()
            for flows_cam_bwd_i, flow_bwd_i in zip(flows_cam_bwd, flow_bwd)
        ]  # .normalize()
        #v_u

        # 4. explainability_mask (none)
        explainability_mask = mask_net(tgt_img, ref_imgs)  # valid-region (explainability) mask
        # list(5) of tensors: [4,4,128,512] ... [4,4,4,16]; values roughly 0.33~0.63
        #-------------------------------------------------

        if args.joint_mask_for_depth:
            explainability_mask_for_depth = compute_joint_mask_for_depth(
                explainability_mask, rigidity_mask_bwd, rigidity_mask_fwd,
                args.THRESH)
        else:
            explainability_mask_for_depth = explainability_mask
        #explainability_mask_for_depth list(5) [b,2,h/ , w/]
        if args.no_non_rigid_mask:
            flow_exp_mask = [None for exp_mask in explainability_mask]
            if args.DEBUG:
                print('Using no masks for flow')
        else:
            flow_exp_mask = [
                1 - exp_mask[:, 1:3] for exp_mask in explainability_mask
            ]
            # the explainability mask is originally a background mask (background pixels = 1);
            # invert it into a moving-object mask, keeping only the previous and next frames
            # list(4) [4,2,256,512]

        # 3.2 compute losses

        # E_R minimizes the photometric loss on the static scene
        if w1 > 0:
            loss_1 = loss_camera(tgt_img,
                                 ref_imgs,
                                 intrinsics,
                                 intrinsics_inv,
                                 depth,
                                 explainability_mask_for_depth,
                                 pose,
                                 lambda_oob=args.lambda_oob,
                                 qch=args.qch,
                                 wssim=args.wssim)
        else:
            loss_1 = torch.tensor([0.]).to(device)
        # E_M
        if w2 > 0:
            loss_2 = explainability_loss(
                explainability_mask
            )  #+ 0.2*gaussian_explainability_loss(explainability_mask)
        else:
            loss_2 = 0
        # E_S
        if w3 > 0:
            if args.smoothness_type == "regular":
                loss_3 = smooth_loss(depth) + smooth_loss(
                    flow_fwd) + smooth_loss(flow_bwd) + smooth_loss(
                        explainability_mask)
            elif args.smoothness_type == "edgeaware":
                loss_3 = edge_aware_smoothness_loss(
                    tgt_img, depth) + edge_aware_smoothness_loss(
                        tgt_img, flow_fwd)
                loss_3 += edge_aware_smoothness_loss(
                    tgt_img, flow_bwd) + edge_aware_smoothness_loss(
                        tgt_img, explainability_mask)
        else:
            loss_3 = torch.tensor([0.]).to(device)
        # E_F
        # minimizes photometric loss on moving regions

        if w4 > 0:
            loss_4 = loss_flow(tgt_img,
                               ref_imgs[1:3], [flow_bwd, flow_fwd],
                               flow_exp_mask,
                               lambda_oob=args.lambda_oob,
                               qch=args.qch,
                               wssim=args.wssim)
        else:
            loss_4 = torch.tensor([0.]).to(device)
        # E_C
        # drives the collaboration
        # explainability_mask: list(6) of [4,4,4,16]; rigidity_mask: list(4) of [4,2,128,512]
        if w5 > 0:
            loss_5 = consensus_depth_flow_mask(explainability_mask,
                                               rigidity_mask_bwd,
                                               rigidity_mask_fwd,
                                               exp_masks_target,
                                               exp_masks_target,
                                               THRESH=args.THRESH,
                                               wbce=args.wbce)
        else:
            loss_5 = torch.tensor([0.]).to(device)

        #3.2.6
        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3 + w4 * loss_4 + w5 * loss_5
        #end of loss

        #3.3
        # record loss and EPE
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        #3.4 log data

        # add scalar
        if args.scalar_freq > 0 and n_iter % args.scalar_freq == 0:
            train_writer.add_scalar('batch/cam_photometric_error',
                                    loss_1.item(), n_iter)
            if w2 > 0:
                train_writer.add_scalar('batch/explanability_loss',
                                        loss_2.item(), n_iter)
            train_writer.add_scalar('batch/disparity_smoothness_loss',
                                    loss_3.item(), n_iter)
            train_writer.add_scalar('batch/flow_photometric_error',
                                    loss_4.item(), n_iter)
            train_writer.add_scalar('batch/consensus_error', loss_5.item(),
                                    n_iter)
            train_writer.add_scalar('batch/total_loss', loss.item(), n_iter)

        # if training_output_freq is 0, no images are logged
        if args.training_output_freq > 0 and n_iter % args.training_output_freq == 0:

            train_writer.add_image('train Input', tensor2array(tgt_img[0]),
                                   n_iter)
            train_writer.add_image(
                'train Cam Flow Output',
                flow_to_image(tensor2array(flow_cam.data[0].cpu())), n_iter)

            for k, scaled_depth in enumerate(depth):
                train_writer.add_image(
                    'train Dispnet Output Normalized {}'.format(k),
                    tensor2array(disparities[k].data[0].cpu(),
                                 max_value=None,
                                 colormap='bone'), n_iter)
                train_writer.add_image(
                    'train Depth Output {}'.format(k),
                    tensor2array(1 / disparities[k].data[0].cpu(),
                                 max_value=10), n_iter)
                train_writer.add_image(
                    'train Non Rigid Flow Output {}'.format(k),
                    flow_to_image(tensor2array(flow_fwd[k].data[0].cpu())),
                    n_iter)
                train_writer.add_image(
                    'train Target Rigidity {}'.format(k),
                    tensor2array((rigidity_mask_fwd[k] > args.THRESH).type_as(
                        rigidity_mask_fwd[k]).data[0].cpu(),
                                 max_value=1,
                                 colormap='bone'), n_iter)

                b, _, h, w = scaled_depth.size()
                downscale = tgt_img.size(2) / h

                tgt_img_scaled = nn.functional.adaptive_avg_pool2d(
                    tgt_img, (h, w))
                ref_imgs_scaled = [
                    nn.functional.adaptive_avg_pool2d(ref_img, (h, w))
                    for ref_img in ref_imgs
                ]

                intrinsics_scaled = torch.cat(
                    (intrinsics[:, 0:2] / downscale, intrinsics[:, 2:]), dim=1)
                intrinsics_scaled_inv = torch.cat(
                    (intrinsics_inv[:, :, 0:2] * downscale,
                     intrinsics_inv[:, :, 2:]),
                    dim=2)

                train_writer.add_image(
                    'train Non Rigid Warped Image {}'.format(k),
                    tensor2array(
                        flow_warp(ref_imgs_scaled[2],
                                  flow_fwd[k]).data[0].cpu()), n_iter)

                # log warped images along with explainability mask
                for j, ref in enumerate(ref_imgs_scaled):
                    ref_warped = inverse_warp(
                        ref,
                        scaled_depth[:, 0],
                        pose[:, j],
                        intrinsics_scaled,
                        intrinsics_scaled_inv,
                        rotation_mode=args.rotation_mode,
                        padding_mode=args.padding_mode)[0]
                    train_writer.add_image(
                        'train Warped Outputs {} {}'.format(k, j),
                        tensor2array(ref_warped.data.cpu()), n_iter)
                    train_writer.add_image(
                        'train Diff Outputs {} {}'.format(k, j),
                        tensor2array(
                            0.5 *
                            (tgt_img_scaled[0] - ref_warped).abs().data.cpu()),
                        n_iter)
                    if explainability_mask[k] is not None:
                        train_writer.add_image(
                            'train Exp mask Outputs {} {}'.format(k, j),
                            tensor2array(explainability_mask[k][0,
                                                                j].data.cpu(),
                                         max_value=1,
                                         colormap='bone'), n_iter)

        # csv file write
        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([
                loss.item(),
                loss_1.item(),
                loss_2.item() if w2 > 0 else 0,
                loss_3.item(),
                loss_4.item()
            ])
        #terminal output
        if args.log_terminal:
            logger.train_bar.update(i + 1)  # progress within the current epoch
            if i % args.print_freq == 0:
                logger.train_writer.write(
                    'Train: Time {} Data {} Loss {}'.format(
                        batch_time, data_time, losses))

        # 3.5 end-of-epoch condition
        epoch_size = len(train_loader)
        if i >= epoch_size - 1:
            break

        n_iter += 1

    global_vars_dict['n_iter'] = n_iter
    return losses.avg[0]  #epoch loss
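
A minimal sketch of how this `train` function might be driven; `args.lr`, `args.epochs`, and the surrounding setup are illustrative assumptions, not the repository's actual entry point:

# Hypothetical wiring (names illustrative).
optimizer = torch.optim.Adam(
    [{'params': net.parameters()} for net in (disp_net, pose_net, mask_net, flow_net)],
    lr=args.lr)
global_vars_dict = {'args': args, 'n_iter': 0, 'device': device}
for epoch in range(args.epochs):
    epoch_loss = train(train_loader, disp_net, pose_net, mask_net, flow_net,
                       optimizer, logger=logger, train_writer=train_writer,
                       global_vars_dict=global_vars_dict)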
Example 8
def validate_flow_with_gt(val_loader,
                          disp_net,
                          pose_net,
                          mask_net,
                          flow_net,
                          epoch,
                          logger,
                          output_writers=[]):
    global args
    batch_time = AverageMeter()
    error_names = [
        'epe_total', 'epe_rigid', 'epe_non_rigid', 'outliers',
        'epe_total_with_gt_mask', 'epe_rigid_with_gt_mask',
        'epe_non_rigid_with_gt_mask', 'outliers_gt_mask'
    ]
    errors = AverageMeter(i=len(error_names))
    log_outputs = len(output_writers) > 0

    # switch to evaluate mode
    disp_net.eval()
    pose_net.eval()
    mask_net.eval()
    flow_net.eval()

    end = time.time()

    poses = np.zeros(
        ((len(val_loader) - 1) * 1 * (args.sequence_length - 1), 6))

    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv, flow_gt,
            obj_map_gt) in enumerate(val_loader):
        tgt_img = Variable(tgt_img.cuda(), volatile=True)
        ref_imgs = [Variable(img.cuda(), volatile=True) for img in ref_imgs]
        intrinsics_var = Variable(intrinsics.cuda(), volatile=True)
        intrinsics_inv_var = Variable(intrinsics_inv.cuda(), volatile=True)

        flow_gt_var = Variable(flow_gt.cuda(), volatile=True)
        obj_map_gt_var = Variable(obj_map_gt.cuda(), volatile=True)

        # compute output-------------------------

        #1. disp fwd
        disp = disp_net(tgt_img)
        if args.spatial_normalize:
            disp = spatial_normalize(disp)

        depth = 1 / disp

        #2. pose fwd
        pose = pose_net(tgt_img, ref_imgs)

        #3. mask fwd
        explainability_mask = mask_net(tgt_img, ref_imgs)

        #4. flow fwd
        if args.flownet == 'Back2Future':
            flow_fwd, flow_bwd, _ = flow_net(tgt_img, ref_imgs[1:3])  # previous and next frames
        elif args.flownet == 'FlowNetC6':
            flow_fwd = flow_net(tgt_img, ref_imgs[2])
            flow_bwd = flow_net(tgt_img, ref_imgs[1])
        # compute output-------------------------

        if args.DEBUG:
            flow_fwd_x = flow_fwd[:, 0].view(-1).abs().data
            print("Flow Fwd Median: ", flow_fwd_x.median())
            flow_gt_var_x = flow_gt_var[:, 0].view(-1).abs().data
            print(
                "Flow GT Median: ",
                flow_gt_var_x.index_select(
                    0,
                    flow_gt_var_x.nonzero().view(-1)).median())

        flow_cam = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics_var,
                             intrinsics_inv_var)
        oob_rigid = flow2oob(flow_cam)
        oob_non_rigid = flow2oob(flow_fwd)

        rigidity_mask = (1 - (1 - explainability_mask[:, 1]) *
                         (1 - explainability_mask[:, 2])).unsqueeze(1) > 0.5

        rigidity_mask_census_soft = (flow_cam - flow_fwd).abs()  #.normalize()
        rigidity_mask_census_u = rigidity_mask_census_soft[:, 0] < args.THRESH
        rigidity_mask_census_v = rigidity_mask_census_soft[:, 1] < args.THRESH
        rigidity_mask_census = (rigidity_mask_census_u).type_as(flow_fwd) * (
            rigidity_mask_census_v).type_as(flow_fwd)

        rigidity_mask_combined = 1 - (
            1 - rigidity_mask.type_as(explainability_mask)) * (
                1 - rigidity_mask_census.type_as(explainability_mask))

        #get flow
        flow_fwd_non_rigid = (rigidity_mask_combined <= args.THRESH).type_as(
            flow_fwd).expand_as(flow_fwd) * flow_fwd
        flow_fwd_rigid = (rigidity_mask_combined > args.THRESH
                          ).type_as(flow_fwd).expand_as(flow_fwd) * flow_cam
        total_flow = flow_fwd_rigid + flow_fwd_non_rigid

        obj_map_gt_var_expanded = obj_map_gt_var.unsqueeze(1).type_as(flow_fwd)

        if log_outputs and i % 10 == 0 and i / 10 < len(output_writers):
            index = int(i // 10)
            if epoch == 0:
                output_writers[index].add_image('val flow Input',
                                                tensor2array(tgt_img[0]), 0)
                flow_to_show = flow_gt[0][:2, :, :].cpu()
                output_writers[index].add_image(
                    'val target Flow',
                    flow_to_image(tensor2array(flow_to_show)), epoch)

            output_writers[index].add_image(
                'val Total Flow Output',
                flow_to_image(tensor2array(total_flow.data[0].cpu())), epoch)
            output_writers[index].add_image(
                'val Rigid Flow Output',
                flow_to_image(tensor2array(flow_fwd_rigid.data[0].cpu())),
                epoch)
            output_writers[index].add_image(
                'val Non-rigid Flow Output',
                flow_to_image(tensor2array(flow_fwd_non_rigid.data[0].cpu())),
                epoch)
            output_writers[index].add_image(
                'val Out of Bound (Rigid)',
                tensor2array(oob_rigid.type(torch.FloatTensor).data[0].cpu(),
                             max_value=1,
                             colormap='bone'), epoch)
            output_writers[index].add_scalar(
                'val Mean oob (Rigid)',
                oob_rigid.type(torch.FloatTensor).sum(), epoch)
            output_writers[index].add_image(
                'val Out of Bound (Non-Rigid)',
                tensor2array(oob_non_rigid.type(
                    torch.FloatTensor).data[0].cpu(),
                             max_value=1,
                             colormap='bone'), epoch)
            output_writers[index].add_scalar(
                'val Mean oob (Non-Rigid)',
                oob_non_rigid.type(torch.FloatTensor).sum(), epoch)
            output_writers[index].add_image(
                'val Cam Flow Errors',
                tensor2array(flow_diff(flow_gt_var, flow_cam).data[0].cpu()),
                epoch)
            output_writers[index].add_image(
                'val Rigidity Mask',
                tensor2array(rigidity_mask.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), epoch)
            output_writers[index].add_image(
                'val Rigidity Mask Census',
                tensor2array(rigidity_mask_census.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), epoch)

            for j, ref in enumerate(ref_imgs):
                ref_warped = inverse_warp(ref[:1],
                                          depth[:1, 0],
                                          pose[:1, j],
                                          intrinsics_var[:1],
                                          intrinsics_inv_var[:1],
                                          rotation_mode=args.rotation_mode,
                                          padding_mode=args.padding_mode)[0]

                output_writers[index].add_image(
                    'val Warped Outputs {}'.format(j),
                    tensor2array(ref_warped.data.cpu()), epoch)
                output_writers[index].add_image(
                    'val Diff Outputs {}'.format(j),
                    tensor2array(0.5 *
                                 (tgt_img[0] - ref_warped).abs().data.cpu()),
                    epoch)
                if explainability_mask is not None:
                    output_writers[index].add_image(
                        'val Exp mask Outputs {}'.format(j),
                        tensor2array(explainability_mask[0, j].data.cpu(),
                                     max_value=1,
                                     colormap='bone'), epoch)

            if args.DEBUG:
                # Check if pose2flow is consistant with inverse warp
                ref_warped_from_depth = inverse_warp(
                    ref_imgs[2][:1],
                    depth[:1, 0],
                    pose[:1, 2],
                    intrinsics_var[:1],
                    intrinsics_inv_var[:1],
                    rotation_mode=args.rotation_mode,
                    padding_mode=args.padding_mode)[0]
                ref_warped_from_cam_flow = flow_warp(ref_imgs[2][:1],
                                                     flow_cam)[0]
                print(
                    "DEBUG_INFO: Inverse_warp vs pose2flow",
                    torch.mean(
                        torch.abs(ref_warped_from_depth -
                                  ref_warped_from_cam_flow)).item())
                output_writers[index].add_image(
                    'val Warped Outputs from Cam Flow',
                    tensor2array(ref_warped_from_cam_flow.data.cpu()), epoch)
                output_writers[index].add_image(
                    'val Warped Outputs from inverse warp',
                    tensor2array(ref_warped_from_depth.data.cpu()), epoch)

        if log_outputs and i < len(val_loader) - 1:
            step = args.sequence_length - 1
            poses[i * step:(i + 1) * step] = pose.data.cpu().view(-1,
                                                                  6).numpy()

        if np.isnan(flow_gt.sum().item()) or np.isnan(
                total_flow.data.sum().item()):
            print('NaN encountered')
        _epe_errors = compute_all_epes(
            flow_gt_var, flow_cam,
            flow_fwd, rigidity_mask_combined) + compute_all_epes(
                flow_gt_var, flow_cam, flow_fwd, (1 - obj_map_gt_var_expanded))
        errors.update(_epe_errors)

        if args.DEBUG:
            print("DEBUG_INFO: EPE errors: ", _epe_errors)
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

    if log_outputs:
        output_writers[0].add_histogram('val poses_tx', poses[:, 0], epoch)
        output_writers[0].add_histogram('val poses_ty', poses[:, 1], epoch)
        output_writers[0].add_histogram('val poses_tz', poses[:, 2], epoch)
        if args.rotation_mode == 'euler':
            rot_coeffs = ['rx', 'ry', 'rz']
        elif args.rotation_mode == 'quat':
            rot_coeffs = ['qx', 'qy', 'qz']
        output_writers[0].add_histogram('val poses_{}'.format(rot_coeffs[0]),
                                        poses[:, 3], epoch)
        output_writers[0].add_histogram('val poses_{}'.format(rot_coeffs[1]),
                                        poses[:, 4], epoch)
        output_writers[0].add_histogram('val poses_{}'.format(rot_coeffs[2]),
                                        poses[:, 5], epoch)

    if args.DEBUG:
        print("DEBUG_INFO =================>")
        print("DEBUG_INFO: Average EPE : ", errors.avg)
        print("DEBUG_INFO =================>")
        print("DEBUG_INFO =================>")
        print("DEBUG_INFO =================>")

    return errors.avg, error_names
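
For orientation, a sketch of the end-point-error metric behind `epe_total` and friends; the repository's `compute_all_epes` presumably adds the rigid/non-rigid splits and outlier ratios on top of this:

def flow_epe(flow_pred, flow_gt):
    # Mean L2 distance between predicted and ground-truth flow vectors,
    # averaged over pixels where ground truth is available (sketch).
    valid = (flow_gt.abs().sum(dim=1, keepdim=True) > 0).float()
    epe_map = torch.norm(flow_pred - flow_gt, p=2, dim=1, keepdim=True)
    return (epe_map * valid).sum() / valid.sum().clamp(min=1)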
Example 9
def train(train_loader, flow_net, optimizer, epoch_size, logger=None, train_writer=None):
    global args, n_iter
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)

    # switch to train mode
    flow_net.train()

    end = time.time()

    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        tgt_img_var = Variable(tgt_img.cuda())
        ref_imgs_var = [Variable(img.cuda()) for img in ref_imgs]

        if args.flownet == 'Back2Future':
            flow_fwd, flow_bwd = flow_net(tgt_img_var, ref_imgs_var)
        else:
            flow_fwd = flow_net(tgt_img_var, ref_imgs_var[1])
            flow_bwd = flow_net(tgt_img_var, ref_imgs_var[0])
            
        loss_smooth = torch.zeros(1).cuda() 
        loss_flow_recon = torch.zeros(1).cuda()
        loss_velocity_consis = torch.zeros(1).cuda()

        if args.flow_photo_loss_weight_first: 
            if args.min:
                loss_flow_recon += args.flow_photo_loss_weight_first*photometric_flow_min_loss(tgt_img_var, ref_imgs_var, [flow_bwd, flow_fwd],
                                                lambda_oob=args.lambda_oob, qch=args.qch, wssim=args.wssim)
            else:
                loss_flow_recon += args.flow_photo_loss_weight_first*photometric_flow_loss(tgt_img_var, ref_imgs_var, [flow_bwd, flow_fwd],
                                                lambda_oob=args.lambda_oob, qch=args.qch, wssim=args.wssim)

        if args.flow_photo_loss_weight_second: 
            if args.min:
                
                loss_per, loss_weight= photometric_flow_gradient_min_loss(tgt_img_var, ref_imgs_var, [flow_bwd, flow_fwd],
                                                lambda_oob=args.lambda_oob, qch=args.qch, wssim=args.wssim)
                loss_flow_recon += args.flow_photo_loss_weight_second * loss_per
            else:
                loss_flow_recon += args.flow_photo_loss_weight_second*photometric_flow_gradient_loss(tgt_img_var, ref_imgs_var, [flow_bwd, flow_fwd],
                                                lambda_oob=args.lambda_oob, qch=args.qch, wssim=args.wssim)
            

        if args.smooth_loss_weight_first:
            if args.smoothness_type == "regular":
                loss_smooth += args.smooth_loss_weight_first*(smooth_loss(flow_fwd) + smooth_loss(flow_bwd))
            elif args.smoothness_type == "edgeaware":
                loss_smooth += args.smooth_loss_weight_first*(edge_aware_smoothness_loss(tgt_img_var, flow_fwd)+edge_aware_smoothness_loss(tgt_img_var, flow_bwd))

        if args.smooth_loss_weight_second:
            if args.smoothness_type == "regular":
                loss_smooth += args.smooth_loss_weight_second*(smooth_loss(flow_fwd) + smooth_loss(flow_bwd))
            elif args.smoothness_type == "edgeaware":
                loss_smooth += args.smooth_loss_weight_second*(edge_aware_smoothness_second_order_loss_change_weight(tgt_img_var, flow_bwd, args.alpha)
                    + edge_aware_smoothness_second_order_loss_change_weight(tgt_img_var, flow_fwd, args.alpha))


        if args.velocity_consis_loss_weight:
            loss_velocity_consis = args.velocity_consis_loss_weight*flow_velocity_consis_loss( [flow_bwd, flow_fwd])


        loss = loss_smooth + loss_flow_recon + loss_velocity_consis
        
        if i > 0 and n_iter % args.print_freq == 0:
            train_writer.add_scalar('flow_photometric_error', loss_flow_recon.item(), n_iter)
            train_writer.add_scalar('flow_smoothness_loss', loss_smooth.item(), n_iter)
            train_writer.add_scalar('velocity_consis_loss', loss_velocity_consis.item(), n_iter)
            train_writer.add_scalar('total_loss', loss.item(), n_iter)


        if args.training_output_freq > 0 and n_iter % args.training_output_freq == 0:

            train_writer.add_image('train Input', tensor2array(tgt_img[0]), n_iter)
            train_writer.add_image('train Flow FWD Output', flow_to_image(tensor2array(flow_fwd[0].data[0].cpu())), n_iter)
            train_writer.add_image('train Flow BWD Output', flow_to_image(tensor2array(flow_bwd[0].data[0].cpu())), n_iter)

            # loss_weight is only produced by the gradient-min photometric loss above
            if args.flow_photo_loss_weight_second and args.min:
                loss_weight_bwd = loss_weight[0][0, 0, :, :].unsqueeze(0)
                loss_weight_fwd = loss_weight[0][0, 1, :, :].unsqueeze(0)

                train_writer.add_image('loss_weight_bwd', tensor2array(loss_weight_bwd.data[0].cpu(), max_value=None, colormap='bone'), n_iter)
                train_writer.add_image('loss_weight_fwd', tensor2array(loss_weight_fwd.data[0].cpu(), max_value=None, colormap='bone'), n_iter)

            train_writer.add_image('train Flow FWD error Image', tensor2array(flow_warp(tgt_img_var - ref_imgs_var[1], flow_fwd[0]).data[0].cpu()), n_iter)
            train_writer.add_image('train Flow BWD error Image', tensor2array(flow_warp(tgt_img_var - ref_imgs_var[0], flow_bwd[0]).data[0].cpu()), n_iter)

            
        # record loss and EPE
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if args.log_terminal:
            logger.train_bar.update(i+1)
            if i % args.print_freq == 0:
                logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(batch_time, data_time, losses))
        if i >= epoch_size - 1:
            break

        n_iter += 1

    return losses.avg[0]
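
A sketch of the per-pixel minimum that the `*_min_loss` variants above presumably take over the two reference reconstructions (the min-reprojection idea), reusing the `robust_l1_per_pix` sketch from Example 1:

def photometric_min_error(tgt, warped_bwd, warped_fwd, qch=0.5):
    # Score each pixel by whichever reference frame explains it best (assumed).
    err_bwd = robust_l1_per_pix((tgt - warped_bwd).mean(1, True), q=qch)
    err_fwd = robust_l1_per_pix((tgt - warped_fwd).mean(1, True), q=qch)
    return torch.min(err_bwd, err_fwd)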
Example 10
    def one_scale(flows):
        assert (len(flows) == len(ref_imgs))

        # reconstruction_loss = 0
        b, _, h, w = flows[0].size()

        tgt_img_scaled = nn.functional.adaptive_avg_pool2d(tgt_img, (h, w))
        ref_imgs_scaled = [
            nn.functional.adaptive_avg_pool2d(ref_img, (h, w))
            for ref_img in ref_imgs
        ]

        reconstruction_loss_all = []
        reconstruction_weight_all = []
        # consistancy_loss_all = []
        ssim_loss = 0.0
        for i, ref_img in enumerate(ref_imgs_scaled):
            current_flow = flows[i]
            ref_img_warped = flow_warp(ref_img, current_flow)
            diff = (tgt_img_scaled - ref_img_warped)

            if wssim:
                ssim_loss += wssim * (
                    1 - ssim(tgt_img_scaled, ref_img_warped)).mean()

            # reconstruction_loss = gradient_photometric_loss(tgt_img_scaled, ref_img_warped, qch)
            reconstruction_loss = gradient_photometric_all_direction_loss(
                tgt_img_scaled, ref_img_warped, qch)
            reconstruction_weight = robust_l1_per_pix(diff.mean(1, True),
                                                      q=qch)
            # reconstruction_weight = reconstruction_loss
            reconstruction_loss_all.append(reconstruction_loss)
            reconstruction_weight_all.append(reconstruction_weight)
            # consistancy_loss_all.append(reconstruction_loss)

        reconstruction_loss = torch.cat(reconstruction_loss_all, 1)
        reconstruction_weight = torch.cat(reconstruction_weight_all, 1)
        # consistancy_loss = torch.cat(consistancy_loss_all,1)

        # reconstruction_weight_min,_ = reconstruction_weight.min(1,keepdim=True)
        # reconstruction_weight_min = reconstruction_weight_min.repeat(1,2,1,1)
        # reconstruction_weight_sum = reconstruction_weight.sum(1,keepdim=True)
        # reconstruction_weight_sum = reconstruction_weight_sum.repeat(1,2,1,1)

        # consistancy_loss = consistancy_loss[:,0,:,:]-consistancy_loss[:,1,:,:]
        # consistancy_loss = wconsis*torch.mean(torch.abs(consistancy_loss))

        # loss_weight = reconstruction_weight_min/(reconstruction_weight)
        # loss_weight = reconstruction_weight/reconstruction_weight_sum
        loss_weight = 1 - torch.nn.functional.softmax(reconstruction_weight, 1)
        # loss_weight = (loss_weight >= 0.4).type_as(reconstruction_loss)
        # print(loss_weight.size())
        # loss_weight = loss_weight[:,:,:-1,:-1]
        loss_weight = loss_weight[:, :, 1:-1, 1:-1]
        # loss_weight = scale_weight(loss_weight,0.3,10)

        # # loss_weight = torch.pow(loss_weight,4)
        loss_weight = Variable(loss_weight.data, requires_grad=False)
        loss = reconstruction_loss * loss_weight
        # loss, _ = torch.min(reconstruction_loss, dim=1)
        # # loss = torch.mean(loss,3)
        # # loss = torch.mean(loss,2)
        # # loss = torch.mean(loss,0)
        # loss, _ = torch.min(reconstruction_loss, dim=1)
        loss = loss.sum() / loss_weight.sum()
        return loss + ssim_loss, loss_weight
def flow_loss(tgt_img,
              ref_imgs,
              flows,
              explainability_mask,
              lambda_oob=0,
              qch=0.5,
              wssim=0.5):
    def one_scale(explainability_mask, occ_masks, flows):
        reconstruction_loss = 0
        b, _, h, w = flows[0].size()
        downscale = tgt_img.size(2) / h

        tgt_img_scaled = nn.functional.adaptive_avg_pool2d(tgt_img, (h, w))
        ref_imgs_scaled = [
            nn.functional.adaptive_avg_pool2d(ref_img, (h, w))
            for ref_img in ref_imgs
        ]

        weight = 1.

        for i, ref_img in enumerate(ref_imgs_scaled):
            current_flow = flows[i]

            ref_img_warped = flow_warp(ref_img, current_flow)
            valid_pixels = 1 - (ref_img_warped == 0).prod(
                1, keepdim=True).type_as(ref_img_warped)
            diff = (tgt_img_scaled - ref_img_warped) * valid_pixels
            ssim_loss = (1 - ssim(tgt_img_scaled, ref_img_warped)) * valid_pixels
            oob_normalization_const = valid_pixels.nelement(
            ) / valid_pixels.sum()

            if explainability_mask is not None:
                diff = diff * explainability_mask[:, i:i + 1].expand_as(diff)
                ssim_loss = ssim_loss * explainability_mask[:, i:i + 1].expand_as(ssim_loss)

            if occ_masks is not None:
                diff = diff * (1 - occ_masks[:, i:i + 1]).expand_as(diff)
                ssim_loss = ssim_loss * (
                    1 - occ_masks[:, i:i + 1]).expand_as(ssim_loss)

            reconstruction_loss += weight * oob_normalization_const * (
                (1 - wssim) * robust_l1(diff, q=qch) +
                wssim * ssim_loss.mean()
            ) + lambda_oob * robust_l1(1 - valid_pixels, q=qch)
            #weight /= 2.83
            assert ((reconstruction_loss == reconstruction_loss).item() == 1)  # NaN guard

        return reconstruction_loss

    if type(flows[0]) not in [tuple, list]:
        if explainability_mask is not None:
            explainability_mask = [explainability_mask]
        flows = [[uv] for uv in flows]

    loss = 0
    for i in range(len(flows[0])):
        flow_at_scale = [uv[i] for uv in flows]
        occ_mask_at_scale_bw, occ_mask_at_scale_fw = occlusion_masks(
            flow_at_scale[0], flow_at_scale[1])
        occ_mask_at_scale = torch.stack(
            (occ_mask_at_scale_bw, occ_mask_at_scale_fw), dim=1)
        # occ_mask_at_scale = None
        loss += one_scale(explainability_mask[i], occ_mask_at_scale,
                          flow_at_scale)
    ref_img_warped = flow_warp(ref_imgs[0], flows[0][0])
    diff = (tgt_img - ref_img_warped)
    return loss, ref_img_warped, diff
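
A hypothetical call, assuming two reference frames and per-reference flow pyramids; the names are illustrative:

# Hypothetical usage (shapes illustrative).
loss, warped, diff = flow_loss(tgt_img, ref_imgs,
                               [flow_bwd_pyramid, flow_fwd_pyramid],
                               explainability_mask_pyramid,
                               lambda_oob=0.1, qch=0.5, wssim=0.85)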