Example #1
def train(odometry_net, depth_net, train_loader, epoch, optimizer):
    global device
    odometry_net.set_fix_method(nfp.FIX_AUTO)
    odometry_net.train()
    depth_net.train()
    total_loss = 0
    lr_total = 0
    r12_total = 0
    smooth_total = 0
    for batch_idx, (img_R1, img_L2, img_R2, intrinsics, inv_intrinsics, raw_K,
                    T_R2L) in tqdm(enumerate(train_loader),
                                   desc='Train epoch %d' % epoch,
                                   leave=False,
                                   ncols=80):
        img_R1 = img_R1.type(torch.FloatTensor).to(device)
        img_R2 = img_R2.type(torch.FloatTensor).to(device)
        img_L2 = img_L2.type(torch.FloatTensor).to(device)
        intrinsics = intrinsics.type(torch.FloatTensor).to(device)
        inv_intrinsics = inv_intrinsics.type(torch.FloatTensor).to(device)
        raw_K = raw_K.type(torch.FloatTensor).to(device)
        T_R2L = T_R2L.type(torch.FloatTensor).to(device)

        batch_size = img_R1.size(0)

        img_R = torch.cat((img_R2, img_R1), dim=1)
        K = torch.cat((raw_K, raw_K), dim=0)

        # scale raw images to roughly [0, 1] (0.004 ≈ 1/255)
        norm_img_L2 = 0.004 * img_L2
        norm_img_R1 = 0.004 * img_R1
        norm_img_R2 = 0.004 * img_R2

        inv_depth_img_R2 = depth_net(img_R2)
        T_2to1, _ = odometry_net(img_R)

        T = torch.cat((T_R2L, T_2to1), dim=0)

        SE3 = generate_se3(T)
        inv_depth = torch.cat((inv_depth_img_R2, inv_depth_img_R2), dim=0)
        depth = 1 / (inv_depth + 1e-4)  # epsilon guards against division by zero

        pts3D = geo_transform(depth, SE3, K)
        proj_coords = pin_hole_project(pts3D, K)

        Isrc = torch.cat((norm_img_L2, norm_img_R1), dim=0)
        warp_Itgt = inverse_warp(Isrc, proj_coords)

        warp_Itgt_LR = warp_Itgt[:batch_size, :, :, :]
        warp_Itgt_R12 = warp_Itgt[batch_size:, :, :, :]

        # 1 where the warped image has content, 0 where the warp sampled
        # out of bounds, so invalid pixels do not contribute to the loss
        in_bound = 1 - (warp_Itgt_LR == 0).prod(
            1, keepdim=True).type_as(warp_Itgt_LR)
        diff_LR = (norm_img_R2 - warp_Itgt_LR) * in_bound
        LR_error = diff_LR.abs().mean()

        in_bound = 1 - (warp_Itgt_R12 == 0).prod(
            1, keepdim=True).type_as(warp_Itgt_R12)
        diff_R12 = (norm_img_R2 - warp_Itgt_R12) * in_bound
        R12_error = diff_R12.abs().mean()

        smooth_error = smooth_loss(depth)
        loss = LR_error + R12_error + 10 * smooth_error

        total_loss += loss.item()
        lr_total += LR_error.item()
        r12_total += R12_error.item()
        smooth_total += smooth_error.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(
        "Train epoch {}: loss: {:.6f} LR-loss: {:.6f} R12-loss: {:.6f} smooth-loss: {:.6f}"
        .format(epoch, total_loss / len(train_loader),
                lr_total / len(train_loader), r12_total / len(train_loader),
                smooth_total / len(train_loader)))
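
Note: the `smooth_loss` helper used above is imported elsewhere in the repository. As a minimal sketch of what a first-order smoothness penalty of this kind typically computes (an assumption; the real helper may also use second-order gradients or multiple scales):

import torch

def smooth_loss_sketch(depth):
    # Mean absolute spatial gradient of a [B, 1, H, W] depth map:
    # large depth jumps between neighbouring pixels are penalized.
    dx = (depth[:, :, :, 1:] - depth[:, :, :, :-1]).abs().mean()
    dy = (depth[:, :, 1:, :] - depth[:, :, :-1, :]).abs().mean()
    return dx + dy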
Example #2
def train(train_loader, mask_net, pose_net, optimizer, epoch_size,
          train_writer):
    global args, n_iter
    w1 = args.smooth_loss_weight
    w2 = args.mask_loss_weight
    w3 = args.consensus_loss_weight
    w4 = args.pose_loss_weight

    mask_net.train()
    pose_net.train()
    average_loss = 0
    for i, (rgb_tgt_img, rgb_ref_imgs, depth_tgt_img, depth_ref_imgs,
            mask_tgt_img, mask_ref_imgs, intrinsics, intrinsics_inv,
            pose_list) in enumerate(tqdm(train_loader)):
        rgb_tgt_img_var = Variable(rgb_tgt_img.cuda())
        rgb_ref_imgs_var = [Variable(img.cuda()) for img in rgb_ref_imgs]
        depth_tgt_img_var = Variable(depth_tgt_img.unsqueeze(1).cuda())
        depth_ref_imgs_var = [
            Variable(img.unsqueeze(1).cuda()) for img in depth_ref_imgs
        ]
        mask_tgt_img_var = Variable(mask_tgt_img.cuda())
        mask_ref_imgs_var = [Variable(img.cuda()) for img in mask_ref_imgs]

        mask_tgt_img_var = torch.where(mask_tgt_img_var > 0,
                                       torch.ones_like(mask_tgt_img_var),
                                       torch.zeros_like(mask_tgt_img_var))
        mask_ref_imgs_var = [
            torch.where(img > 0, torch.ones_like(img), torch.zeros_like(img))
            for img in mask_ref_imgs_var
        ]

        intrinsics_var = Variable(intrinsics.cuda())
        intrinsics_inv_var = Variable(intrinsics_inv.cuda())
        # pose_list_var = [Variable(one_pose.float().cuda()) for one_pose in pose_list]

        explainability_mask = mask_net(rgb_tgt_img_var, rgb_ref_imgs_var)

        # print(explainability_mask[0].size()) #torch.Size([4, 2, 384, 512])
        # print()
        pose = pose_net(rgb_tgt_img_var, rgb_ref_imgs_var)
        # loss 1: smoothness loss
        loss1 = smooth_loss(explainability_mask)

        # loss 2: explainability loss
        loss2 = explainability_loss(explainability_mask)

        # loss 3 consensus loss (the mask from networks and the mask from residual)
        loss3 = consensus_loss(explainability_mask[0], mask_ref_imgs_var)

        # loss 4 pose loss
        valid_pixel_mask = [
            torch.where(depth_ref_imgs_var[0] == 0,
                        torch.zeros_like(depth_tgt_img_var),
                        torch.ones_like(depth_tgt_img_var)),
            torch.where(depth_ref_imgs_var[1] == 0,
                        torch.zeros_like(depth_tgt_img_var),
                        torch.ones_like(depth_tgt_img_var))
        ]  # zero marks invalid (missing-depth) pixels

        loss4, ref_img_warped, diff = pose_loss(
            valid_pixel_mask, mask_ref_imgs_var, rgb_tgt_img_var,
            rgb_ref_imgs_var, intrinsics_var, intrinsics_inv_var,
            depth_tgt_img_var, pose)

        # compute gradient and do Adam step
        loss = w1 * loss1 + w2 * loss2 + w3 * loss3 + w4 * loss4
        average_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # visualization in tensorboard
        if i > 0 and n_iter % args.print_freq == 0:
            train_writer.add_scalar('smoothness loss', loss1.item(), n_iter)
            train_writer.add_scalar('explainability loss', loss2.item(),
                                    n_iter)
            train_writer.add_scalar('consensus loss', loss3.item(), n_iter)
            train_writer.add_scalar('pose loss', loss4.item(), n_iter)
            train_writer.add_scalar('total loss', loss.item(), n_iter)
        if n_iter % (args.training_output_freq) == 0:
            train_writer.add_image('train Input',
                                   tensor2array(rgb_tgt_img_var[0]), n_iter)
            train_writer.add_image(
                'train Exp mask Outputs ',
                tensor2array(explainability_mask[0][0, 0].data.cpu(),
                             max_value=1,
                             colormap='bone'), n_iter)
            train_writer.add_image(
                'train gt mask ',
                tensor2array(mask_tgt_img[0].data.cpu(),
                             max_value=1,
                             colormap='bone'), n_iter)
            train_writer.add_image(
                'train depth ',
                tensor2array(depth_tgt_img[0].data.cpu(),
                             max_value=1,
                             colormap='bone'), n_iter)
            train_writer.add_image(
                'train after mask',
                tensor2array(rgb_tgt_img_var[0] *
                             explainability_mask[0][0, 0]), n_iter)
            train_writer.add_image('train diff', tensor2array(diff[0]), n_iter)
            train_writer.add_image('train warped img',
                                   tensor2array(ref_img_warped[0]), n_iter)

        n_iter += 1

    return average_loss / (i + 1)  # average over the number of batches
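
Note: `explainability_loss` is imported from the project's loss module. A common formulation for SfmLearner-style explainability masks, given here as a hedged sketch (the actual helper may differ), is a binary cross-entropy that pulls every mask value toward 1, so the network pays a price for discounting pixels:

import torch
import torch.nn.functional as F

def explainability_loss_sketch(masks):
    # Accept a single mask or a list of multi-scale masks in [0, 1].
    if not isinstance(masks, (list, tuple)):
        masks = [masks]
    loss = 0
    for mask in masks:
        # BCE against an all-ones target: masking out a pixel costs loss.
        loss = loss + F.binary_cross_entropy(mask, torch.ones_like(mask))
    return loss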
Example #3
def validate_without_gt(args,
                        val_loader,
                        disp_net,
                        pose_exp_net,
                        epoch,
                        logger,
                        output_writers=[]):
    batch_time = AverageMeter()
    losses = AverageMeter(i=3, precision=4)
    log_outputs = len(output_writers) > 0
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight
    poses = np.zeros(
        ((len(val_loader) - 1) * args.batch_size * (args.sequence_length - 1),
         6))
    disp_values = np.zeros(((len(val_loader) - 1) * args.batch_size * 3))

    # switch to evaluate mode
    disp_net.eval()
    pose_exp_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(val_loader):
        tgt_img_var = Variable(tgt_img.cuda(), volatile=True)
        ref_imgs_var = [
            Variable(img.cuda(), volatile=True) for img in ref_imgs
        ]
        intrinsics_var = Variable(intrinsics.cuda(), volatile=True)
        intrinsics_inv_var = Variable(intrinsics_inv.cuda(), volatile=True)

        # compute output
        disp = disp_net(tgt_img_var)
        depth = 1 / disp
        explainability_mask, pose = pose_exp_net(tgt_img_var, ref_imgs_var)

        loss_1 = photometric_reconstruction_loss(tgt_img_var, ref_imgs_var,
                                                 intrinsics_var,
                                                 intrinsics_inv_var, depth,
                                                 explainability_mask, pose,
                                                 args.rotation_mode,
                                                 args.padding_mode)
        loss_1 = loss_1.data[0]
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask).data[0]
        else:
            loss_2 = 0
        loss_3 = smooth_loss(disp).data[0]

        # log the first output of every 100th batch
        if log_outputs and i % 100 == 0 and i // 100 < len(output_writers):
            index = i // 100
            if epoch == 0:
                for j, ref in enumerate(ref_imgs):
                    output_writers[index].add_image('val Input {}'.format(j),
                                                    tensor2array(tgt_img[0]),
                                                    0)
                    output_writers[index].add_image('val Input {}'.format(j),
                                                    tensor2array(ref[0]), 1)

            output_writers[index].add_image(
                'val Dispnet Output Normalized',
                tensor2array(disp.data[0].cpu(),
                             max_value=None,
                             colormap='bone'), epoch)
            output_writers[index].add_image(
                'val Depth Output',
                tensor2array(1. / disp.data[0].cpu(), max_value=10), epoch)
            # log warped images along with explainability mask
            for j, ref in enumerate(ref_imgs_var):
                ref_warped = inverse_warp(ref[:1],
                                          depth[:1, 0],
                                          pose[:1, j],
                                          intrinsics_var[:1],
                                          intrinsics_inv_var[:1],
                                          rotation_mode=args.rotation_mode,
                                          padding_mode=args.padding_mode)[0]

                output_writers[index].add_image(
                    'val Warped Outputs {}'.format(j),
                    tensor2array(ref_warped.data.cpu()), epoch)
                output_writers[index].add_image(
                    'val Diff Outputs {}'.format(j),
                    tensor2array(
                        0.5 * (tgt_img_var[0] - ref_warped).abs().data.cpu()),
                    epoch)
                if explainability_mask is not None:
                    output_writers[index].add_image(
                        'val Exp mask Outputs {}'.format(j),
                        tensor2array(explainability_mask[0, j].data.cpu(),
                                     max_value=1,
                                     colormap='bone'), epoch)

        if log_outputs and i < len(val_loader) - 1:
            step = args.batch_size * (args.sequence_length - 1)
            poses[i * step:(i + 1) * step] = pose.data.cpu().view(-1,
                                                                  6).numpy()
            step = args.batch_size * 3
            disp_unraveled = disp.data.cpu().view(args.batch_size, -1)
            disp_values[i * step:(i + 1) * step] = torch.cat([
                disp_unraveled.min(-1)[0],
                disp_unraveled.median(-1)[0],
                disp_unraveled.max(-1)[0]
            ]).numpy()

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3
        losses.update([loss, loss_1, loss_2])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write('valid: Time {} Loss {}'.format(
                batch_time, losses))
    if log_outputs:
        prefix = 'valid poses'
        coeffs_names = ['tx', 'ty', 'tz']
        if args.rotation_mode == 'euler':
            coeffs_names.extend(['rx', 'ry', 'rz'])
        elif args.rotation_mode == 'quat':
            coeffs_names.extend(['qx', 'qy', 'qz'])
        for i in range(poses.shape[1]):
            output_writers[0].add_histogram(
                '{} {}'.format(prefix, coeffs_names[i]), poses[:, i], epoch)
        output_writers[0].add_histogram('disp_values', disp_values, epoch)
    logger.valid_bar.update(len(val_loader))
    return losses.avg, ['Total loss', 'Photo loss', 'Exp loss']
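
Note: `AverageMeter(i=3, precision=4)` tracks running averages of several scalars at once. Below is a minimal sketch consistent with how it is called above (`update([loss, loss_1, loss_2])`, `.avg`, string formatting); the repository's class may carry extra features:

class AverageMeterSketch:
    # Running value/average tracker for `i` parallel scalars.
    def __init__(self, i=1, precision=3):
        self.i, self.precision = i, precision
        self.reset()

    def reset(self):
        self.val = [0.0] * self.i
        self.sum = [0.0] * self.i
        self.avg = [0.0] * self.i
        self.count = 0

    def update(self, val, n=1):
        if not isinstance(val, (list, tuple)):
            val = [val]
        self.val = list(val)
        self.count += n
        self.sum = [s + v * n for s, v in zip(self.sum, val)]
        self.avg = [s / self.count for s in self.sum]

    def __repr__(self):
        return ' '.join('{:.{p}f} ({:.{p}f})'.format(v, a, p=self.precision)
                        for v, a in zip(self.val, self.avg))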
Example #4
def train(train_loader, disp_net, pose_exp_net, optimizer, epoch_size, logger,
          train_writer):
    global args, n_iter
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight

    # switch to train mode
    disp_net.train()
    pose_exp_net.train()

    end = time.time()

    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        tgt_img_var = Variable(tgt_img.cuda())
        ref_imgs_var = [Variable(img.cuda()) for img in ref_imgs]
        intrinsics_var = Variable(intrinsics.cuda())
        intrinsics_inv_var = Variable(intrinsics_inv.cuda())

        # compute output
        disparities = disp_net(tgt_img_var)
        depth = [1 / disp for disp in disparities]
        explainability_mask, pose = pose_exp_net(tgt_img_var, ref_imgs_var)

        loss_1 = photometric_reconstruction_loss(tgt_img_var, ref_imgs_var,
                                                 intrinsics_var,
                                                 intrinsics_inv_var, depth,
                                                 explainability_mask, pose)
        loss_2 = explainability_loss(explainability_mask)
        loss_3 = smooth_loss(disparities)

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3

        train_writer.add_scalar('photometric_error', loss_1.data[0], n_iter)
        train_writer.add_scalar('explanability_loss', loss_2.data[0], n_iter)
        train_writer.add_scalar('disparity_smoothness_loss', loss_3.data[0],
                                n_iter)
        train_writer.add_scalar('total_loss', loss.data[0], n_iter)

        if n_iter % 200 == 0 and args.log_output:

            train_writer.add_image('train Input', tensor2array(ref_imgs[0][0]),
                                   n_iter - 1)
            train_writer.add_image('train Input', tensor2array(tgt_img[0]),
                                   n_iter)
            train_writer.add_image('train Input', tensor2array(ref_imgs[1][0]),
                                   n_iter + 1)

            for k, scaled_depth in enumerate(depth):
                train_writer.add_image(
                    'train Dispnet Output {}'.format(k),
                    tensor2array(disparities[k].data[0].cpu(),
                                 max_value=10,
                                 colormap='bone'), n_iter)
                train_writer.add_image(
                    'train Depth Output Normalized {}'.format(k),
                    tensor2array(1 / disparities[k].data[0].cpu(),
                                 max_value=None), n_iter)
                b, _, h, w = scaled_depth.size()
                downscale = tgt_img_var.size(2) / h

                tgt_img_scaled = nn.functional.adaptive_avg_pool2d(
                    tgt_img_var, (h, w))
                ref_imgs_scaled = [
                    nn.functional.adaptive_avg_pool2d(ref_img, (h, w))
                    for ref_img in ref_imgs_var
                ]

                intrinsics_scaled = torch.cat(
                    (intrinsics_var[:, 0:2] / downscale, intrinsics_var[:, 2:]),
                    dim=1)
                intrinsics_scaled_inv = torch.cat(
                    (intrinsics_inv_var[:, :, 0:2] * downscale,
                     intrinsics_inv_var[:, :, 2:]),
                    dim=2)

                # log warped images along with explainability mask
                for j, ref in enumerate(ref_imgs_scaled):
                    ref_warped = inverse_warp(ref, scaled_depth[:, 0],
                                              pose[:, j], intrinsics_scaled,
                                              intrinsics_scaled_inv)[0]
                    train_writer.add_image(
                        'train Warped Outputs {} {}'.format(k, j),
                        tensor2array(ref_warped.data.cpu(), max_value=1),
                        n_iter)
                    train_writer.add_image(
                        'train Diff Outputs {} {}'.format(k, j),
                        tensor2array(
                            0.5 *
                            (tgt_img_scaled[0] - ref_warped).abs().data.cpu()),
                        n_iter)
                    train_writer.add_image(
                        'train Exp mask Outputs {} {}'.format(k, j),
                        tensor2array(explainability_mask[k][0, j].data.cpu(),
                                     max_value=1,
                                     colormap='bone'), n_iter)

        # record loss and EPE
        losses.update(loss.data[0], args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow(
                [loss.data[0], loss_1.data[0], loss_2.data[0], loss_3.data[0]])
        logger.train_bar.update(i)
        if i % args.print_freq == 0:
            logger.train_writer.write(
                'Train: Time {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                'Loss {loss.val:.4f} ({loss.avg:.4f}) '.format(
                    batch_time=batch_time, data_time=data_time, loss=losses))
        if i >= epoch_size - 1:
            break

        n_iter += 1

    return losses.avg
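
Note on the intrinsics rescaling above: when the image is pooled down by `downscale`, the focal lengths and principal point (the first two rows of K) shrink by the same factor, while the inverse intrinsics grow by it. A standalone sketch of the forward case (hypothetical helper name):

import torch

def scale_intrinsics_sketch(K, downscale):
    # K: [B, 3, 3] pinhole intrinsics; rows 0-1 hold fx, cx and fy, cy.
    # Dividing those rows rescales focal lengths and principal point
    # for an image whose height/width were divided by `downscale`.
    K_scaled = K.clone()
    K_scaled[:, 0:2] = K_scaled[:, 0:2] / downscale
    return K_scaled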
Example #5
def train(train_loader, mask_net, pose_net, flow_net, optimizer, epoch_size,
          train_writer):
    global args, n_iter
    w1 = args.smooth_loss_weight
    w2 = args.mask_loss_weight
    w3 = args.consensus_loss_weight
    w4 = args.flow_loss_weight

    mask_net.train()
    pose_net.train()
    flow_net.train()
    average_loss = 0
    for i, (rgb_tgt_img, rgb_ref_imgs, depth_tgt_img, depth_ref_imgs,
            intrinsics, intrinsics_inv,
            pose_list) in enumerate(tqdm(train_loader)):
        rgb_tgt_img_var = Variable(rgb_tgt_img.cuda())
        # print(rgb_tgt_img_var.size())
        rgb_ref_imgs_var = [Variable(img.cuda()) for img in rgb_ref_imgs]
        # rgb_ref_imgs_var = [rgb_ref_imgs_var[0], rgb_ref_imgs_var[0], rgb_ref_imgs_var[1], rgb_ref_imgs_var[1]]
        depth_tgt_img_var = Variable(depth_tgt_img.unsqueeze(1).cuda())
        depth_ref_imgs_var = [
            Variable(img.unsqueeze(1).cuda()) for img in depth_ref_imgs
        ]
        intrinsics_var = Variable(intrinsics.cuda())
        intrinsics_inv_var = Variable(intrinsics_inv.cuda())
        # pose_list_var = [Variable(one_pose.float().cuda()) for one_pose in pose_list]

        explainability_mask = mask_net(rgb_tgt_img_var, rgb_ref_imgs_var)
        valid_pixel_mask = torch.where(
            depth_tgt_img_var == 0, torch.zeros_like(depth_tgt_img_var),
            torch.ones_like(depth_tgt_img_var))  # zero marks invalid pixels
        # print(depth_test[0].sum())

        # print(explainability_mask[0].size()) #torch.Size([4, 2, 384, 512])
        # print()
        pose = pose_net(rgb_tgt_img_var, rgb_ref_imgs_var)

        # generate flow from camera pose and depth
        flow_fwd, flow_bwd, _ = flow_net(rgb_tgt_img_var, rgb_ref_imgs_var)
        flows_cam_fwd = pose2flow(depth_ref_imgs_var[1].squeeze(1), pose[:, 1],
                                  intrinsics_var, intrinsics_inv_var)
        flows_cam_bwd = pose2flow(depth_ref_imgs_var[0].squeeze(1), pose[:, 0],
                                  intrinsics_var, intrinsics_inv_var)
        rigidity_mask_fwd = (flows_cam_fwd - flow_fwd[0]).abs()
        rigidity_mask_bwd = (flows_cam_bwd - flow_bwd[0]).abs()

        # loss 1: smoothness loss
        loss1 = smooth_loss(explainability_mask) + smooth_loss(
            flow_bwd) + smooth_loss(flow_fwd)

        # loss 2: explainability loss
        loss2 = explainability_loss(explainability_mask)

        # loss 3 consensus loss (the mask from networks and the mask from residual)
        depth_Res_mask, depth_ref_img_warped, depth_diff = depth_residual_mask(
            valid_pixel_mask, explainability_mask[0], rgb_tgt_img_var,
            rgb_ref_imgs_var, intrinsics_var, intrinsics_inv_var,
            depth_tgt_img_var, pose)
        # print(depth_Res_mask[0].size(), explainability_mask[0].size())

        loss3 = consensus_loss(explainability_mask[0], rigidity_mask_bwd,
                               rigidity_mask_fwd, args.THRESH, args.wbce)

        # loss 4: flow loss
        loss4, flow_ref_img_warped, flow_diff = flow_loss(
            rgb_tgt_img_var, rgb_ref_imgs_var, [flow_bwd, flow_fwd],
            explainability_mask)

        # compute gradient and do Adam step
        loss = w1 * loss1 + w2 * loss2 + w3 * loss3 + w4 * loss4
        average_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # visualization in tensorboard
        if i > 0 and n_iter % args.print_freq == 0:
            train_writer.add_scalar('smoothness loss', loss1.item(), n_iter)
            train_writer.add_scalar('explainability loss', loss2.item(),
                                    n_iter)
            train_writer.add_scalar('consensus loss', loss3.item(), n_iter)
            train_writer.add_scalar('flow loss', loss4.item(), n_iter)
            train_writer.add_scalar('total loss', loss.item(), n_iter)
        if n_iter % (args.training_output_freq) == 0:
            train_writer.add_image('train Input',
                                   tensor2array(rgb_tgt_img_var[0]), n_iter)
            train_writer.add_image(
                'train Exp mask Outputs ',
                tensor2array(explainability_mask[0][0, 0].data.cpu(),
                             max_value=1,
                             colormap='bone'), n_iter)
            train_writer.add_image(
                'train depth Res mask ',
                tensor2array(depth_Res_mask[0][0].data.cpu(),
                             max_value=1,
                             colormap='bone'), n_iter)
            train_writer.add_image(
                'train depth ',
                tensor2array(depth_tgt_img_var[0].data.cpu(),
                             max_value=1,
                             colormap='bone'), n_iter)
            train_writer.add_image(
                'train valid pixel ',
                tensor2array(valid_pixle_mask[0].data.cpu(),
                             max_value=1,
                             colormap='bone'), n_iter)
            train_writer.add_image(
                'train after mask',
                tensor2array(rgb_tgt_img_var[0] *
                             explainability_mask[0][0, 0]), n_iter)
            train_writer.add_image('train depth diff',
                                   tensor2array(depth_diff[0]), n_iter)
            train_writer.add_image('train flow diff',
                                   tensor2array(flow_diff[0]), n_iter)
            train_writer.add_image('train depth warped img',
                                   tensor2array(depth_ref_img_warped[0]),
                                   n_iter)
            train_writer.add_image('train flow warped img',
                                   tensor2array(flow_ref_img_warped[0]),
                                   n_iter)
            train_writer.add_image(
                'train Cam Flow Output',
                flow_to_image(tensor2array(flow_fwd[0].data[0].cpu())), n_iter)
            train_writer.add_image(
                'train Flow from Depth Output',
                flow_to_image(tensor2array(flows_cam_fwd.data[0].cpu())),
                n_iter)
            train_writer.add_image(
                'train Flow and Depth diff',
                flow_to_image(tensor2array(rigidity_mask_fwd.data[0].cpu())),
                n_iter)

        n_iter += 1

    return average_loss / (i + 1)  # average over the number of batches
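
Note: `flow_loss` warps each reference image by the predicted flow before comparing it with the target. The helper itself is not shown; here is a minimal sketch of the underlying flow-based warp using `grid_sample` (an assumption about its internals):

import torch
import torch.nn.functional as F

def flow_warp_sketch(img, flow):
    # img: [B, 3, H, W]; flow: [B, 2, H, W] in pixels (x, y displacement).
    b, _, h, w = img.size()
    ys, xs = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
    base = torch.stack((xs, ys), dim=0).float().to(img.device)  # [2, H, W]
    coords = base.unsqueeze(0) + flow  # absolute sampling coordinates
    # normalize to [-1, 1] as grid_sample expects
    gx = 2.0 * coords[:, 0] / (w - 1) - 1.0
    gy = 2.0 * coords[:, 1] / (h - 1) - 1.0
    grid = torch.stack((gx, gy), dim=3)  # [B, H, W, 2]
    return F.grid_sample(img, grid, align_corners=True)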
Example #6
def train(args, train_loader, disp_net, pose_exp_net, optimizer, epoch_size,
          logger, tb_writer):
    global n_iter, device
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight

    # switch to train mode
    disp_net.train()
    pose_exp_net.train()

    end = time.time()
    logger.train_bar.update(0)

    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(train_loader):
        log_losses = i > 0 and n_iter % args.print_freq == 0
        log_output = args.training_output_freq > 0 and n_iter % args.training_output_freq == 0

        # measure data loading time
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)

        # compute output
        disparities = disp_net(tgt_img)
        depth = [1 / disp for disp in disparities]
        explainability_mask, pose = pose_exp_net(tgt_img, ref_imgs)

        loss_1, warped, diff = photometric_reconstruction_loss(
            tgt_img, ref_imgs, intrinsics, depth, explainability_mask, pose,
            args.rotation_mode, args.padding_mode)
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask)
        else:
            loss_2 = 0
        loss_3 = smooth_loss(depth)

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3

        if log_losses:
            tb_writer.add_scalar('photometric_error', loss_1.item(), n_iter)
            if w2 > 0:
                tb_writer.add_scalar('explanability_loss', loss_2.item(),
                                     n_iter)
            tb_writer.add_scalar('disparity_smoothness_loss', loss_3.item(),
                                 n_iter)
            tb_writer.add_scalar('total_loss', loss.item(), n_iter)

        if log_output:
            tb_writer.add_image('train Input', tensor2array(tgt_img[0]),
                                n_iter)
            for k, scaled_maps in enumerate(
                    zip(depth, disparities, warped, diff,
                        explainability_mask)):
                log_output_tensorboard(tb_writer, "train", 0, " {}".format(k),
                                       n_iter, *scaled_maps)

        # record loss and EPE
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([
                loss.item(),
                loss_1.item(),
                loss_2.item() if w2 > 0 else 0,
                loss_3.item()
            ])
        logger.train_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(
                batch_time, data_time, losses))
        if i >= epoch_size - 1:
            break

        n_iter += 1

    return losses.avg[0]
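
Note: `photometric_reconstruction_loss` returns the warped references and per-pixel differences alongside the scalar loss. The core term it averages, reduced here to a single scale and a single reference as a sketch (the real function iterates over scales, references, and poses):

import torch

def photometric_term_sketch(tgt, warped, valid_points, exp_mask=None):
    # tgt, warped: [B, 3, H, W]; valid_points: [B, H, W] in-bounds mask;
    # exp_mask: optional [B, H, W] explainability weights in [0, 1].
    diff = (tgt - warped) * valid_points.unsqueeze(1).float()
    if exp_mask is not None:
        diff = diff * exp_mask.unsqueeze(1)
    return diff.abs().mean()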
Example #7
def validate_without_gt(args,
                        val_loader,
                        disp_net,
                        pose_exp_net,
                        epoch,
                        logger,
                        tb_writer,
                        sample_nb_to_log=3):
    global device
    batch_time = AverageMeter()
    losses = AverageMeter(i=3, precision=4)
    log_outputs = sample_nb_to_log > 0
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight
    poses = np.zeros(
        ((len(val_loader) - 1) * args.batch_size * (args.sequence_length - 1),
         6))
    disp_values = np.zeros(((len(val_loader) - 1) * args.batch_size * 3))

    # switch to evaluate mode
    disp_net.eval()
    pose_exp_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(val_loader):
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)
        intrinsics_inv = intrinsics_inv.to(device)

        # compute output
        disp = disp_net(tgt_img)
        depth = 1 / disp
        explainability_mask, pose = pose_exp_net(tgt_img, ref_imgs)

        loss_1, warped, diff = photometric_reconstruction_loss(
            tgt_img, ref_imgs, intrinsics, depth, explainability_mask, pose,
            args.rotation_mode, args.padding_mode)
        loss_1 = loss_1.item()
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask).item()
        else:
            loss_2 = 0
        loss_3 = smooth_loss(depth).item()

        if log_outputs and i < sample_nb_to_log - 1:  # log first output of first batches
            if epoch == 0:
                for j, ref in enumerate(ref_imgs):
                    tb_writer.add_image('val Input {}/{}'.format(j, i),
                                        tensor2array(tgt_img[0]), 0)
                    tb_writer.add_image('val Input {}/{}'.format(j, i),
                                        tensor2array(ref[0]), 1)

            log_output_tensorboard(tb_writer, 'val', i, '', epoch, 1. / disp,
                                   disp, warped[0], diff[0],
                                   explainability_mask)

        if log_outputs and i < len(val_loader) - 1:
            step = args.batch_size * (args.sequence_length - 1)
            poses[i * step:(i + 1) * step] = pose.cpu().view(-1, 6).numpy()
            step = args.batch_size * 3
            disp_unraveled = disp.cpu().view(args.batch_size, -1)
            disp_values[i * step:(i + 1) * step] = torch.cat([
                disp_unraveled.min(-1)[0],
                disp_unraveled.median(-1)[0],
                disp_unraveled.max(-1)[0]
            ]).numpy()

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3

        if args.with_photocon_loss:
            batch_size = pose.size()[0]
            homo_row = torch.tensor([[0, 0, 0, 1]],
                                    dtype=torch.float).to(device)
            homo_row = homo_row.unsqueeze(0).expand(batch_size, -1, -1)
            T21 = pose_vec2mat(pose[:, 0])
            T21 = torch.cat((T21, homo_row), 1)
            T12 = torch.inverse(T21)
            T23 = pose_vec2mat(pose[:, 1])
            T23 = torch.cat((T23, homo_row), 1)
            T13 = torch.matmul(T23, T12)  #[B,4,4]
            #             print("----",T13.size())
            # target = 1(ref_imgs[0]) and ref = 3(ref_imgs[1])
            ref_img_warped, valid_points = inverse_warp_posemat(
                ref_imgs[1], depth[:, 0], T13, intrinsics, args.rotation_mode,
                args.padding_mode)
            diff = (ref_imgs[0] -
                    ref_img_warped) * valid_points.unsqueeze(1).float()
            loss_4 = diff.abs().mean()

            loss += loss_4

        losses.update([loss, loss_1, loss_2])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write('valid: Time {} Loss {}'.format(
                batch_time, losses))
    if log_outputs:
        prefix = 'valid poses'
        coeffs_names = ['tx', 'ty', 'tz']
        if args.rotation_mode == 'euler':
            coeffs_names.extend(['rx', 'ry', 'rz'])
        elif args.rotation_mode == 'quat':
            coeffs_names.extend(['qx', 'qy', 'qz'])
        for i in range(poses.shape[1]):
            tb_writer.add_histogram('{} {}'.format(prefix, coeffs_names[i]),
                                    poses[:, i], epoch)
        tb_writer.add_histogram('disp_values', disp_values, epoch)
    logger.valid_bar.update(len(val_loader))
    return losses.avg, [
        'Validation Total loss', 'Validation Photo loss', 'Validation Exp loss'
    ]
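
Note on the photo-consistency branch above: `pose_vec2mat` yields [B, 3, 4] relative transforms, which are padded to homogeneous [B, 4, 4] matrices so the 1-to-3 motion can be chained as T13 = T23 · inv(T21). A standalone sketch of that composition (hypothetical helper name):

import torch

def compose_poses_sketch(T21_34, T23_34):
    # T21_34, T23_34: [B, 3, 4] transforms (frame 2 -> 1 and frame 2 -> 3).
    b = T21_34.size(0)
    bottom = torch.tensor([[0., 0., 0., 1.]]).expand(b, 1, 4).to(T21_34.device)
    T21 = torch.cat((T21_34, bottom), dim=1)  # -> [B, 4, 4]
    T23 = torch.cat((T23_34, bottom), dim=1)
    # invert 2->1 to get 1->2, then chain with 2->3: frame 1 -> frame 3
    return T23 @ torch.inverse(T21)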
Example #8
def validate_without_gt(args,
                        val_loader,
                        disp_net,
                        pose_exp_net,
                        epoch,
                        tb_writer,
                        sample_nb_to_log=3):
    global device
    batch_time = AverageMeter()
    losses = AverageMeter(i=3, precision=4)
    log_outputs = sample_nb_to_log > 0
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight
    poses = np.zeros(
        ((len(val_loader) - 1) * args.batch_size * (args.sequence_length - 1),
         6))
    disp_values = np.zeros(((len(val_loader) - 1) * args.batch_size * 3))

    # switch to evaluate mode
    disp_net.eval()
    pose_exp_net.eval()

    end = time.time()

    validate_pbar = tqdm(
        total=len(val_loader),
        bar_format='{desc} {percentage:3.0f}%|{bar}| {postfix}')
    validate_pbar.set_description(
        'valid: Loss *.**** *.*****.****(*.**** *.**** *.****)')
    validate_pbar.set_postfix_str('<Time *.***(*.***)>')

    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(val_loader):
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)
        intrinsics_inv = intrinsics_inv.to(device)

        # compute output
        disp = disp_net(tgt_img)
        depth = 1 / disp

        explainability_mask, pose = pose_exp_net(tgt_img, ref_imgs)
        loss_1, warped, diff = photometric_reconstruction_loss(
            tgt_img, ref_imgs, intrinsics, depth, explainability_mask, pose,
            args.rotation_mode, args.padding_mode)
        loss_1 = loss_1.item()
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask).item()
        else:
            loss_2 = 0
        loss_3 = smooth_loss(depth).item()

        if log_outputs and i < sample_nb_to_log - 1:  # log first output of first batches
            if epoch == 0:
                for j, ref in enumerate(ref_imgs):
                    tb_writer.add_image('val Input {}/{}'.format(j, i),
                                        tensor2array(tgt_img[0]), 0)
                    tb_writer.add_image('val Input {}/{}'.format(j, i),
                                        tensor2array(ref[0]), 1)

            log_output_tensorboard(tb_writer, 'val', i, '', epoch, 1. / disp,
                                   disp, warped, diff, explainability_mask)

        if log_outputs and i < len(val_loader) - 1:
            step = args.batch_size * (args.sequence_length - 1)
            poses[i * step:(i + 1) * step] = pose.cpu().view(-1, 6).numpy()
            step = args.batch_size * 3
            disp_unraveled = disp.cpu().view(args.batch_size, -1)
            disp_values[i * step:(i + 1) * step] = torch.cat([
                disp_unraveled.min(-1)[0],
                disp_unraveled.median(-1)[0],
                disp_unraveled.max(-1)[0]
            ]).numpy()

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3
        losses.update([loss, loss_1, loss_2])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        validate_pbar.clear()
        validate_pbar.update(1)
        validate_pbar.set_description('valid: Loss {}'.format(losses))
        validate_pbar.set_postfix_str('<Time {}>'.format(batch_time))
    validate_pbar.close()

    if log_outputs:
        prefix = 'valid poses'
        coeffs_names = ['tx', 'ty', 'tz']
        if args.rotation_mode == 'euler':
            coeffs_names.extend(['rx', 'ry', 'rz'])
        elif args.rotation_mode == 'quat':
            coeffs_names.extend(['qx', 'qy', 'qz'])
        for i in range(poses.shape[1]):
            tb_writer.add_histogram('{} {}'.format(prefix, coeffs_names[i]),
                                    poses[:, i], epoch)
        tb_writer.add_histogram('disp_values', disp_values, epoch)
        time.sleep(0.2)
    else:
        time.sleep(1)
    return losses.avg, ['Total loss', 'Photo loss', 'Exp loss']
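
Note: the `disp_values` buffer stores per-sample minimum, median, and maximum disparity for the histogram written at the end. The gathering step, factored into a standalone sketch:

import torch

def disp_stats_sketch(disp):
    # disp: [B, 1, H, W] -> concatenated [3 * B] vector of per-sample
    # min / median / max, matching the layout written into disp_values.
    flat = disp.detach().cpu().view(disp.size(0), -1)
    return torch.cat(
        (flat.min(-1)[0], flat.median(-1)[0], flat.max(-1)[0]))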
Example #9
def validate_without_gt(args,
                        val_loader,
                        disp_net,
                        pose_exp_net,
                        epoch,
                        logger,
                        tb_writer,
                        sample_nb_to_log=3):
    global device
    mse_l = torch.nn.MSELoss(reduction='mean')
    batch_time = AverageMeter()
    losses = AverageMeter(i=3, precision=4)
    log_outputs = sample_nb_to_log > 0
    # Spread the logged batches across the whole dataset
    batches_to_log = list(
        np.linspace(0, len(val_loader) - 1, sample_nb_to_log).astype(int))
    w1, w2, w3, wf, wp = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight, args.flow_loss_weight, args.prior_loss_weight
    poses = np.zeros(
        ((len(val_loader) - 1) * args.batch_size * (args.sequence_length - 1),
         6))
    disp_values = np.zeros(((len(val_loader) - 1) * args.batch_size * 3))

    # switch to evaluate mode
    disp_net.eval()
    pose_exp_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv,
            flow_maps) in enumerate(val_loader):
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)
        intrinsics_inv = intrinsics_inv.to(device)
        flow_maps = [flow_map.to(device) for flow_map in flow_maps]

        # compute output
        disp = disp_net(tgt_img)
        depth = 1 / disp
        explainability_mask, pose = pose_exp_net(tgt_img, ref_imgs)

        loss_1, warped, diff, grid = photometric_reconstruction_loss(
            tgt_img, ref_imgs, intrinsics, depth, explainability_mask, pose,
            args.rotation_mode, args.padding_mode)
        loss_1 = loss_1.item()

        if wf > 0:
            loss_f = flow_consistency_loss(grid, flow_maps, mse_l)
        else:
            loss_f = 0

        if wp > 0:
            loss_p = ground_prior_loss(disp)
        else:
            loss_p = 0

        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask).item()
        else:
            loss_2 = 0
        loss_3 = smooth_loss(depth).item()

        if log_outputs and i in batches_to_log:  # log first output of wanted batches
            index = batches_to_log.index(i)
            if epoch == 0:
                for j, ref in enumerate(ref_imgs):
                    tb_writer.add_image('val Input {}/{}'.format(j, index),
                                        tensor2array(tgt_img[0]), 0)
                    tb_writer.add_image('val Input {}/{}'.format(j, index),
                                        tensor2array(ref[0]), 1)

            log_output_tensorboard(tb_writer, 'val', index, '', epoch,
                                   1. / disp, disp, warped[0], diff[0],
                                   explainability_mask)

        if log_outputs and i < len(val_loader) - 1:
            step = args.batch_size * (args.sequence_length - 1)
            poses[i * step:(i + 1) * step] = pose.cpu().view(-1, 6).numpy()
            step = args.batch_size * 3
            disp_unraveled = disp.cpu().view(args.batch_size, -1)
            disp_values[i * step:(i + 1) * step] = torch.cat([
                disp_unraveled.min(-1)[0],
                disp_unraveled.median(-1)[0],
                disp_unraveled.max(-1)[0]
            ]).numpy()

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3 + wf * loss_f + wp * loss_p
        losses.update([loss, loss_1, loss_2])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write('valid: Time {} Loss {}'.format(
                batch_time, losses))
    if log_outputs:
        prefix = 'valid poses'
        coeffs_names = ['tx', 'ty', 'tz']
        if args.rotation_mode == 'euler':
            coeffs_names.extend(['rx', 'ry', 'rz'])
        elif args.rotation_mode == 'quat':
            coeffs_names.extend(['qx', 'qy', 'qz'])
        for i in range(poses.shape[1]):
            tb_writer.add_histogram('{} {}'.format(prefix, coeffs_names[i]),
                                    poses[:, i], epoch)
        tb_writer.add_histogram('disp_values', disp_values, epoch)
    logger.valid_bar.update(len(val_loader))
    return losses.avg, [
        'Validation Total loss', 'Validation Photo loss', 'Validation Exp loss'
    ]
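
Note: `flow_consistency_loss(grid, flow_maps, mse_l)` is not shown in this listing. One plausible reading, given that it receives the reprojection grid and precomputed flow maps together with an MSE criterion, is sketched below; this is an assumption about its internals, not the repository's actual code:

import torch

def flow_consistency_sketch(pose_flows, flow_maps, criterion):
    # pose_flows: flows induced by depth + pose (one [B, 2, H, W] per ref);
    # flow_maps: externally supplied flow maps with matching shapes.
    total = 0.0
    for pred, ref in zip(pose_flows, flow_maps):
        total = total + criterion(pred, ref)  # e.g. torch.nn.MSELoss()
    return total / len(flow_maps)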
Example #10
def validate_without_gt(val_loader,
                        disp_net,
                        pose_net,
                        mask_net,
                        flow_net,
                        epoch,
                        logger,
                        tb_writer,
                        nb_writers,
                        global_vars_dict=None):
    #data prepared
    device = global_vars_dict['device']
    n_iter_val = global_vars_dict['n_iter_val']
    args = global_vars_dict['args']
    show_samples = copy.deepcopy(args.show_samples)
    for i in range(len(show_samples)):
        show_samples[i] *= len(val_loader)
        show_samples[i] = show_samples[i] // 1

    batch_time = AverageMeter()
    data_time = AverageMeter()
    log_outputs = nb_writers > 0
    losses = AverageMeter(precision=4)

    w1, w2, w3, w4 = args.cam_photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight, args.flow_photo_loss_weight
    w5 = args.consensus_loss_weight

    loss_camera = photometric_reconstruction_loss
    loss_flow = photometric_flow_loss

    # to eval model
    disp_net.eval()
    pose_net.eval()
    mask_net.eval()
    flow_net.eval()

    end = time.time()
    poses = np.zeros(
        ((len(val_loader) - 1) * 1 * (args.sequence_length - 1), 6))  #init

    #3. validation cycle
    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(val_loader):
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics, intrinsics_inv = intrinsics.to(device), intrinsics_inv.to(
            device)
        #3.1 forwardpass
        #disp
        disp = disp_net(tgt_img)
        if args.spatial_normalize:
            disp = spatial_normalize(disp)
        depth = 1 / disp

        #pose
        pose = pose_net(tgt_img, ref_imgs)  #[b,3,h,w]; list

        #flow----
        # build flow for the preceding and following frames
        if args.flownet == 'Back2Future':
            flow_fwd, flow_bwd, _ = flow_net(tgt_img, ref_imgs[1:3])
        elif args.flownet == 'FlowNetC6':
            flow_fwd = flow_net(tgt_img, ref_imgs[2])
            flow_bwd = flow_net(tgt_img, ref_imgs[1])
        flow_cam = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics,
                             intrinsics_inv)

        flows_cam_fwd = flow_cam  # identical to the flow_cam computation above
        flows_cam_bwd = pose2flow(depth.squeeze(1), pose[:, 1], intrinsics,
                                  intrinsics_inv)

        exp_masks_target = consensus_exp_masks(flows_cam_fwd,
                                               flows_cam_bwd,
                                               flow_fwd,
                                               flow_bwd,
                                               tgt_img,
                                               ref_imgs[2],
                                               ref_imgs[1],
                                               wssim=args.wssim,
                                               wrig=args.wrig,
                                               ws=args.smooth_loss_weight)
        no_rigid_flow = flow_fwd - flows_cam_fwd

        rigidity_mask_fwd = (flows_cam_fwd - flow_fwd).abs()  #[b,2,h,w]
        rigidity_mask_bwd = (flows_cam_bwd - flow_bwd).abs()

        # mask
        # 4.explainability_mask(none)
        explainability_mask = mask_net(tgt_img, ref_imgs)  # valid region? 4??

        # list(5):item:tensor:[4,4,128,512]...[4,4,4,16] value:[0.33~0.48~0.63]

        if args.joint_mask_for_depth:  # false
            explainability_mask_for_depth = explainability_mask

            #explainability_mask_for_depth = compute_joint_mask_for_depth(explainability_mask, rigidity_mask_bwd,
            #                                                            rigidity_mask_fwd,THRESH=args.THRESH)
        else:
            explainability_mask_for_depth = explainability_mask

        if args.no_non_rigid_mask:
            flow_exp_mask = None
            if args.DEBUG:
                print('Using no masks for flow')
        else:
            flow_exp_mask = 1 - explainability_mask[:, 1:3]

        #3.2loss-compute
        if w1 > 0:
            loss_1 = loss_camera(tgt_img,
                                 ref_imgs,
                                 intrinsics,
                                 intrinsics_inv,
                                 depth,
                                 explainability_mask_for_depth,
                                 pose,
                                 lambda_oob=args.lambda_oob,
                                 qch=args.qch,
                                 wssim=args.wssim)
        else:
            loss_1 = torch.tensor([0.]).to(device)

        # E_M
        if w2 > 0:
            loss_2 = explainability_loss(
                explainability_mask
            )  # + 0.2*gaussian_explainability_loss(explainability_mask)
        else:
            loss_2 = 0

        #if args.smoothness_type == "regular":
        if w3 > 0:
            loss_3 = smooth_loss(depth) + smooth_loss(
                explainability_mask) + smooth_loss(flow_fwd) + smooth_loss(
                    flow_bwd)
        else:
            loss_3 = torch.tensor([0.]).to(device)
        if w4 > 0:
            loss_4 = loss_flow(tgt_img,
                               ref_imgs[1:3], [flow_bwd, flow_fwd],
                               flow_exp_mask,
                               lambda_oob=args.lambda_oob,
                               qch=args.qch,
                               wssim=args.wssim)
        else:
            loss_4 = torch.tensor([0.]).to(device)
        if w5 > 0:
            loss_5 = consensus_depth_flow_mask(explainability_mask,
                                               rigidity_mask_bwd,
                                               rigidity_mask_fwd,
                                               exp_masks_target,
                                               exp_masks_target,
                                               THRESH=args.THRESH,
                                               wbce=args.wbce)
        else:
            loss_5 = torch.tensor([0.]).to(device)

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3 + w4 * loss_4 + w5 * loss_5

        #3.3 data update
        losses.update(loss.item(), args.batch_size)
        batch_time.update(time.time() - end)
        end = time.time()

        #3.4 check log

        # inspect the forward-pass outputs
        if args.img_freq > 0 and i in show_samples:  #output_writers list(3)
            if epoch == 0:  # validate before any training, to get a baseline
                #1.img
                # runs only once; note that in ref_imgs, axis 0 of each tensor
                # indexes the batch, while the list index selects the adjacent frame!
                tb_writer.add_image(
                    'epoch 0 Input/sample{}(img{} to img{})'.format(
                        i, i + 1, i + args.sequence_length),
                    tensor2array(ref_imgs[0][0]), 0)
                tb_writer.add_image(
                    'epoch 0 Input/sample{}(img{} to img{})'.format(
                        i, i + 1, i + args.sequence_length),
                    tensor2array(ref_imgs[1][0]), 1)
                tb_writer.add_image(
                    'epoch 0 Input/sample{}(img{} to img{})'.format(
                        i, i + 1, i + args.sequence_length),
                    tensor2array(tgt_img[0]), 2)
                tb_writer.add_image(
                    'epoch 0 Input/sample{}(img{} to img{})'.format(
                        i, i + 1, i + args.sequence_length),
                    tensor2array(ref_imgs[2][0]), 3)
                tb_writer.add_image(
                    'epoch 0 Input/sample{}(img{} to img{})'.format(
                        i, i + 1, i + args.sequence_length),
                    tensor2array(ref_imgs[3][0]), 4)

                depth_to_show = depth[0].cpu()  # tensor [1,h,w], values ~0.5~3.1~10
                tb_writer.add_image(
                    'Disp Output/sample{}'.format(i),
                    tensor2array(depth_to_show,
                                 max_value=None,
                                 colormap='bone'), 0)

            else:
                #2.disp
                depth_to_show = disp[0].cpu()  # tensor [1,h,w], values ~0.5~3.1~10
                tb_writer.add_image(
                    'Disp Output/sample{}'.format(i),
                    tensor2array(depth_to_show,
                                 max_value=None,
                                 colormap='bone'), epoch)
                #3. flow
                tb_writer.add_image('Flow/Flow Output sample {}'.format(i),
                                    flow2rgb(flow_fwd[0], max_value=6), epoch)
                tb_writer.add_image('Flow/cam_Flow Output sample {}'.format(i),
                                    flow2rgb(flow_cam[0], max_value=6), epoch)
                tb_writer.add_image(
                    'Flow/no rigid flow Output sample {}'.format(i),
                    flow2rgb(no_rigid_flow[0], max_value=6), epoch)
                tb_writer.add_image(
                    'Flow/rigidity_mask_fwd{}'.format(i),
                    flow2rgb(rigidity_mask_fwd[0], max_value=6), epoch)

                #4. mask
                tb_writer.add_image(
                    'Mask Output/mask0 sample{}'.format(i),
                    tensor2array(explainability_mask[0][0],
                                 max_value=None,
                                 colormap='magma'), epoch)
                #tb_writer.add_image('Mask Output/mask1 sample{}'.format(i),tensor2array(explainability_mask[1][0], max_value=None, colormap='magma'), epoch)
                #tb_writer.add_image('Mask Output/mask2 sample{}'.format(i),tensor2array(explainability_mask[2][0], max_value=None, colormap='magma'), epoch)
                #tb_writer.add_image('Mask Output/mask3 sample{}'.format(i),tensor2array(explainability_mask[3][0], max_value=None, colormap='magma'), epoch)
                tb_writer.add_image(
                    'Mask Output/exp_masks_target sample{}'.format(i),
                    tensor2array(exp_masks_target[0][0],
                                 max_value=None,
                                 colormap='magma'), epoch)
                #tb_writer.add_image('Mask Output/mask0 sample{}'.format(i),
                #            tensor2array(explainability_mask[0][0], max_value=None, colormap='magma'), epoch)

        #

        #output_writers[index].add_image('val Depth Output', tensor2array(depth.data[0].cpu(), max_value=10),
        #                               epoch)

        # errors.update(compute_errors(depth, output_depth.data.squeeze(1)))
        # add scalar
        if args.scalar_freq > 0 and n_iter_val % args.scalar_freq == 0:
            tb_writer.add_scalar('val/E_R', loss_1.item(), n_iter_val)
            if w2 > 0:
                tb_writer.add_scalar('val/E_M', loss_2.item(), n_iter_val)
            tb_writer.add_scalar('val/E_S', loss_3.item(), n_iter_val)
            tb_writer.add_scalar('val/E_F', loss_4.item(), n_iter_val)
            tb_writer.add_scalar('val/E_C', loss_5.item(), n_iter_val)
            tb_writer.add_scalar('val/total_loss', loss.item(), n_iter_val)

        # terminal output
        if args.log_terminal:
            logger.valid_bar.update(i + 1)  # progress within the current epoch
            if i % args.print_freq == 0:
                logger.valid_bar_writer.write(
                    'Valid: Time {} Data {} Loss {}'.format(
                        batch_time, data_time, losses))

        n_iter_val += 1

    global_vars_dict['n_iter_val'] = n_iter_val
    return losses.avg[0]  #epoch validate loss
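
Note: `spatial_normalize(disp)` is applied before inverting disparity into depth. A common definition, given as a hedged sketch (the repository's helper may differ), divides each disparity map by its own spatial mean so the predicted scale stays bounded:

import torch

def spatial_normalize_sketch(disp):
    # disp: [B, 1, H, W]; divide every map by its spatial mean.
    mean = disp.mean(dim=(2, 3), keepdim=True)
    return disp / (mean + 1e-7)  # epsilon avoids division by zero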
Example #11
def train(args, train_loader, disp_net, pose_exp_net, optimizer, epoch_size,
          tb_writer):
    global n_iter, device
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight

    # switch to train mode
    disp_net.train()
    pose_exp_net.train()

    end = time.time()

    train_pbar = tqdm(total=min(len(train_loader), args.epoch_size),
                      bar_format='{desc} {percentage:3.0f}%|{bar}| {postfix}')
    train_pbar.set_description('Train: Total Loss=#.####(#.####)')
    train_pbar.set_postfix_str('<TIME: op=#.###(#.###) DataFlow=#.###(#.###)>')

    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(train_loader):
        log_losses = i > 0 and n_iter % args.print_freq == 0
        log_output = args.training_output_freq > 0 and n_iter % args.training_output_freq == 0

        # measure DataFlow loading time
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)

        # compute output
        disparities = disp_net(tgt_img)
        depth = [1 / disp for disp in disparities]
        explainability_mask, pose = pose_exp_net(tgt_img, ref_imgs)

        loss_1, warped, diff = photometric_reconstruction_loss(
            tgt_img, ref_imgs, intrinsics, depth, explainability_mask, pose,
            args.rotation_mode, args.padding_mode)
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask)
        else:
            loss_2 = 0
        loss_3 = smooth_loss(depth)

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3

        if log_losses:
            tb_writer.add_scalar('photometric_error', loss_1.item(), n_iter)
            if w2 > 0:
                tb_writer.add_scalar('explanability_loss', loss_2.item(),
                                     n_iter)
            tb_writer.add_scalar('disparity_smoothness_loss', loss_3.item(),
                                 n_iter)
            tb_writer.add_scalar('total_loss', loss.item(), n_iter)

        if log_output:
            tb_writer.add_image('train Input', tensor2array(tgt_img[0]),
                                n_iter)
            for k, scaled_maps in enumerate(
                    zip(depth, disparities, warped, diff,
                        explainability_mask)):
                log_output_tensorboard(tb_writer, "train", 0, k, n_iter,
                                       *scaled_maps)

        # record loss and EPE
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([
                loss.item(),
                loss_1.item(),
                loss_2.item() if w2 > 0 else 0,
                loss_3.item()
            ])
        train_pbar.clear()
        train_pbar.update(1)
        train_pbar.set_description('Train: Total Loss={}'.format(losses))
        train_pbar.set_postfix_str('<TIME: op={} DataFlow={}>'.format(
            batch_time, data_time))
        if i >= epoch_size - 1:
            break

        n_iter += 1
    train_pbar.close()
    time.sleep(1)
    return losses.avg[0]
Example #12
def validate_without_gt(val_loader,
                        disp_net,
                        pose_net,
                        mask_net,
                        epoch,
                        logger,
                        output_writers=[]):
    # data preparation
    global args, device, n_iter_val
    batch_time = AverageMeter()
    data_time = AverageMeter()
    log_outputs = len(output_writers) > 0
    losses = AverageMeter(precision=4)

    w1, w2, w3, w4 = args.cam_photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight, args.flow_photo_loss_weight

    loss_camera = photometric_reconstruction_loss
    loss_flow = photometric_flow_loss

    # switch to evaluate mode
    disp_net.eval()
    pose_net.eval()
    mask_net.eval()
    #flow_net.eval()

    end = time.time()
    poses = np.zeros(
        ((len(val_loader) - 1) * 1 * (args.sequence_length - 1), 6))  # init (never filled in this function)

    #3. validation cycle
    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(val_loader):
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics, intrinsics_inv = intrinsics.to(device), intrinsics_inv.to(
            device)
        # 3.1 forward pass
        # disparity
        disp = disp_net(tgt_img)
        if args.spatial_normalize:
            disp = spatial_normalize(disp)
        depth = 1 / disp

        #pose
        pose = pose_net(tgt_img, ref_imgs)

        #mask
        # 4. explainability mask
        explainability_mask = mask_net(tgt_img, ref_imgs)  # valid-region weighting
        # list(5) of tensors [4,4,128,512] ... [4,4,4,16], values roughly 0.33~0.63
        # -------------------------------------------------

        if args.joint_mask_for_depth:
            # note: rigidity_mask_bwd / rigidity_mask_fwd come from the flow
            # branch (see Example #16) and are never computed in this function,
            # so this path only works with joint_mask_for_depth disabled
            explainability_mask_for_depth = compute_joint_mask_for_depth(
                explainability_mask, rigidity_mask_bwd, rigidity_mask_fwd)
        else:
            explainability_mask_for_depth = explainability_mask

        # 3.2 loss computation
        loss_1 = loss_camera(tgt_img,
                             ref_imgs,
                             intrinsics,
                             intrinsics_inv,
                             depth,
                             explainability_mask_for_depth,
                             pose,
                             lambda_oob=args.lambda_oob,
                             qch=args.qch,
                             wssim=args.wssim)

        # E_M
        if w2 > 0:
            loss_2 = explainability_loss(
                explainability_mask
            )  # + 0.2*gaussian_explainability_loss(explainability_mask)
        else:
            loss_2 = 0

        #if args.smoothness_type == "regular":
        loss_3 = smooth_loss(depth) + smooth_loss(explainability_mask)

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3

        #3.3 data update
        losses.update(loss.item(), args.batch_size)
        batch_time.update(time.time() - end)
        end = time.time()

        # 3.4 logging

        # inspect forward-pass outputs
        if log_outputs and i % 40 == 0 and i // 40 < len(
                output_writers):  # output_writers: list(3)
            index = i // 40
            #disp = disp.data.cpu()[0]
            #disp = (255 * tensor2array(disp, max_value=None, colormap='bone')).astype(np.uint8)
            #disp = disp.transpose(1, 2, 0)

            if epoch == 0:
                output_writers[index].add_image('val Input',
                                                tensor2array(tgt_img[0]), 0)

                disp_to_show = disp[0].cpu()  # tensor [1, h, w], values roughly 0.5~10
                output_writers[index].add_image(
                    'val target disp',
                    tensor2array(disp_to_show,
                                 max_value=None,
                                 colormap='magma'), epoch)
                save = (255 * tensor2array(
                    disp_to_show, max_value=None, colormap='magma')).astype(
                        np.uint8)
                save = save.transpose(1, 2, 0)
                plt.imsave('ep1_test.jpg', save, cmap='plasma')


                # depth_to_show[depth_to_show == 0] = 1000
                # disp_to_show = (1 / depth_to_show).clamp(0, 10)
                # output_writers[index].add_image('val target Disparity Normalized',
                #                                 tensor2array(disp_to_show, max_value=None, colormap='bone'), epoch)

            output_writers[index].add_image(
                'val Dispnet Output Normalized',
                tensor2array(disp.data[0].cpu(),
                             max_value=None,
                             colormap='bone'), epoch)
            #output_writers[index].add_image('val Depth Output', tensor2array(depth.data[0].cpu(), max_value=10),
            #                               epoch)

        # errors.update(compute_errors(depth, output_depth.data.squeeze(1)))
        # add scalar
        if args.scalar_freq > 0 and n_iter_val % args.scalar_freq == 0:
            output_writers[0].add_scalar('val/cam_photometric_error',
                                         loss_1.item(), n_iter_val)
            if w2 > 0:
                output_writers[0].add_scalar('val/explanability_loss',
                                             loss_2.item(), n_iter_val)
            output_writers[0].add_scalar('val/disparity_smoothness_loss',
                                         loss_3.item(), n_iter_val)
            #output_writers[0].add_scalar('batch/flow_photometric_error', loss_4.item(), n_iter)
            #output_writers.add_scalar('batch/consensus_error', loss_5.item(), n_iter)
            output_writers[0].add_scalar('val/total_loss', loss.item(),
                                         n_iter_val)

        # terminal output
        if args.log_terminal:
            logger.valid_bar.update(i + 1)  # progress within the current epoch
            if i % args.print_freq == 0:
                logger.valid_bar_writer.write(
                    'Valid: Time {} Data {} Loss {}'.format(
                        batch_time, data_time, losses))

        n_iter_val += 1

    return losses.avg[0]  # epoch validation loss
Example #13
def train(odometry_net, depth_net, feat_extractor, train_loader, epoch,
          optimizer):
    global device
    global data_parallel
    if data_parallel:
        odometry_net.module.set_fix_method(nfp.FIX_AUTO)
    else:
        odometry_net.set_fix_method(nfp.FIX_AUTO)
    odometry_net.train()
    depth_net.train()
    feat_extractor.train()
    total_loss = 0
    img_reconstruction_total = 0
    f_reconstruction_total = 0
    smooth_total = 0
    for batch_idx, (img_R1, img_L2, img_R2, intrinsics, inv_intrinsics, raw_K,
                    T_R2L) in tqdm(enumerate(train_loader),
                                   desc='Train epoch %d' % epoch,
                                   leave=False,
                                   ncols=80):
        img_R1 = img_R1.type(torch.FloatTensor).to(device)
        img_R2 = img_R2.type(torch.FloatTensor).to(device)
        img_L2 = img_L2.type(torch.FloatTensor).to(device)
        intrinsics = intrinsics.type(torch.FloatTensor).to(device)
        inv_intrinsics = inv_intrinsics.type(torch.FloatTensor).to(device)
        raw_K = raw_K.type(torch.FloatTensor).to(device)
        T_R2L = T_R2L.type(torch.FloatTensor).to(device)

        img_R = torch.cat((img_R2, img_R1), dim=1)

        inv_depth_img_R2 = depth_net(img_R2)
        T_2to1, _ = odometry_net(img_R)
        T_2to1 = T_2to1.view(T_2to1.size(0), -1)
        T_R2L = T_R2L.view(T_R2L.size(0), -1)

        depth = (1 / (inv_depth_img_R2 + 1e-4)).squeeze(1)

        img_reconstruction_error = photometric_reconstruction_loss(
            0.004 * img_R2, 0.004 * img_R1, 0.004 * img_L2, depth, T_2to1,
            T_R2L, intrinsics, inv_intrinsics)
        smooth_error = smooth_loss(depth.unsqueeze(1))

        imgs = torch.cat((img_L2, img_R2, img_R1), dim=0)
        feat = feat_extractor(imgs)
        batch_size = img_R1.size(0)
        f_L2, f_R2, f_R1 = feat[:batch_size, :, :, :], feat[
            batch_size:batch_size * 2, :, :, :], feat[2 * batch_size:, :, :, :]

        feat_reconstruction_error = photometric_reconstruction_loss(
            f_R2, f_R1, f_L2, depth, T_2to1, T_R2L, intrinsics, inv_intrinsics)

        loss = img_reconstruction_error + 0.1 * feat_reconstruction_error + 10 * smooth_error

        total_loss += loss.item()
        img_reconstruction_total += img_reconstruction_error.item()
        f_reconstruction_total += feat_reconstruction_error.item()
        smooth_total += smooth_error.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(
        "Train epoch {}: loss: {:.9f} img-recon-loss: {:.9f} f-recon-loss: {:.9f} smooth-loss: {:.9f}"
        .format(epoch, total_loss / len(train_loader),
                img_reconstruction_total / len(train_loader),
                f_reconstruction_total / len(train_loader),
                smooth_total / len(train_loader)))
Example #14
def train(args, train_loader, disp_net, pose_exp_net, optimizer, epoch_size,
          logger, train_writer):
    global n_iter, device
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight

    # switch to train mode
    disp_net.train()
    pose_exp_net.train()

    end = time.time()
    logger.train_bar.update(0)
    # main training loop
    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(train_loader):
        #for (i, data) in enumerate(train_loader):#data(list): [tensor(B,3,H,W),list(B),(B,H,W),(b,h,w)]
        log_losses = i > 0 and n_iter % args.print_freq == 0
        log_output = args.training_output_freq > 0 and n_iter % args.training_output_freq == 0

        #1 measure data loading time
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)  #(4,3,128,416)
        ref_imgs = [img.to(device) for img in ref_imgs]  # previous and next frames for each image in the batch
        intrinsics = intrinsics.to(device)  #(4,3,3)
        """forward and loss"""
        #2 compute output
        disparities = disp_net(
            tgt_img
        )  # multi-scale list of tensors: (4,1,128,416), (4,1,64,208), (4,1,32,104), (4,1,16,52)

        explainability_mask, pose = pose_exp_net(
            tgt_img,
            ref_imgs)  # pose: tensor (bs, seq_length-1, 6), relative camera poses

        depth = [1 / disp
                 for disp in disparities]  # depth is inversely proportional to disparity, so simply take the reciprocal

        # 3. compute losses
        loss_1, warped, diff = photometric_reconstruction_loss(
            tgt_img, ref_imgs, intrinsics, depth, explainability_mask, pose,
            args.rotation_mode, args.padding_mode)
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask)
        else:
            loss_2 = 0

        loss_3 = smooth_loss(depth)

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3

        # 4. log per-batch data to TensorBoard; scalar tags are created automatically, no pre-registration needed
        if log_losses:
            train_writer.add_scalar('photometric_error', loss_1.item(), n_iter)
            if w2 > 0:
                train_writer.add_scalar('explanability_loss',
                                        loss_2.item(), n_iter)
            train_writer.add_scalar('disparity_smoothness_loss', loss_3.item(),
                                    n_iter)
            train_writer.add_scalar('total_loss', loss.item(), n_iter)

        if log_output:  # write images into TensorBoard event files (filenames start with "events" by default)
            train_writer.add_image('train Input', tensor2array(tgt_img[0]),
                                   n_iter)
            for k, scaled_maps in enumerate(
                    zip(depth, disparities, warped, diff,
                        explainability_mask)):
                log_output_tensorboard(train_writer, "train", k, n_iter,
                                       *scaled_maps)

        # record loss and EPE
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        #csv record
        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([
                loss.item(),
                loss_1.item(),
                loss_2.item() if w2 > 0 else 0,
                loss_3.item()
            ])
        logger.train_bar.update(i + 1)

        if i % args.print_freq == 0:
            logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(
                batch_time, data_time, losses))
        if i >= epoch_size - 1:
            break

        n_iter += 1

    return losses.avg[0]
Example #15
def validate(val_loader,
             disp_net,
             pose_exp_net,
             epoch,
             logger,
             output_writers=[]):
    global args
    batch_time = AverageMeter()
    losses = AverageMeter(i=3, precision=4)
    log_outputs = len(output_writers) > 0
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight
    poses = np.zeros(
        ((len(val_loader) - 1) * args.batch_size * (args.sequence_length - 1),
         6))

    # switch to evaluate mode
    disp_net.eval()
    pose_exp_net.eval()

    end = time.time()

    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(val_loader):
        tgt_img_var = Variable(tgt_img.cuda(), volatile=True)
        ref_imgs_var = [
            Variable(img.cuda(), volatile=True) for img in ref_imgs
        ]
        intrinsics_var = Variable(intrinsics.cuda(), volatile=True)
        intrinsics_inv_var = Variable(intrinsics_inv.cuda(), volatile=True)

        # compute output
        disp = disp_net(tgt_img_var)
        depth = 1 / disp
        explainability_mask, pose = pose_exp_net(tgt_img_var, ref_imgs_var)

        loss_1 = photometric_reconstruction_loss(tgt_img_var, ref_imgs_var,
                                                 intrinsics_var,
                                                 intrinsics_inv_var, depth,
                                                 explainability_mask, pose)
        loss_1 = loss_1.data[0]
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask).data[0]
        else:
            loss_2 = 0
        loss_3 = smooth_loss(disp).data[0]

        if log_outputs and i % 100 == 0 and i / 100 < len(
                output_writers):  # log first output of every 100 batch
            index = int(i // 100)
            if epoch == 0:
                for j, ref in enumerate(ref_imgs):
                    output_writers[index].add_image('val Input {}'.format(j),
                                                    tensor2array(tgt_img[0]),
                                                    0)
                    output_writers[index].add_image('val Input {}'.format(j),
                                                    tensor2array(ref[0]), 1)

            output_writers[index].add_image(
                'val Dispnet Output Normalized',
                tensor2array(disp.data[0].cpu(),
                             max_value=None,
                             colormap='bone'), epoch)
            output_writers[index].add_image(
                'val Depth Output',
                tensor2array(1. / disp.data[0].cpu(), max_value=10), epoch)
            # log warped images along with explainability mask
            for j, ref in enumerate(ref_imgs_var):
                ref_warped = inverse_warp(ref[:1], depth[:1, 0], pose[:1, j],
                                          intrinsics_var[:1],
                                          intrinsics_inv_var[:1])[0]
                output_writers[index].add_image(
                    'val Warped Outputs {}'.format(j),
                    tensor2array(ref_warped.data.cpu()), epoch)
                output_writers[index].add_image(
                    'val Diff Outputs {}'.format(j),
                    tensor2array(
                        0.5 * (tgt_img_var[0] - ref_warped).abs().data.cpu()),
                    epoch)
                if explainability_mask is not None:
                    output_writers[index].add_image(
                        'val Exp mask Outputs {}'.format(j),
                        tensor2array(explainability_mask[0, j].data.cpu(),
                                     max_value=1,
                                     colormap='bone'), epoch)

        if log_outputs and i < len(val_loader) - 1:
            step = args.batch_size * (args.sequence_length - 1)
            poses[i * step:(i + 1) * step] = pose.data.cpu().view(-1,
                                                                  6).numpy()

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3
        losses.update([loss, loss_1, loss_2])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i)
        if i % args.print_freq == 0:
            logger.valid_writer.write('valid: Time {} Loss {}'.format(
                batch_time, losses))
    if log_outputs:
        output_writers[0].add_histogram('val poses_tx', poses[:, 0], epoch)
        output_writers[0].add_histogram('val poses_ty', poses[:, 1], epoch)
        output_writers[0].add_histogram('val poses_tz', poses[:, 2], epoch)
        output_writers[0].add_histogram('val poses_rx', poses[:, 3], epoch)
        output_writers[0].add_histogram('val poses_ry', poses[:, 4], epoch)
        output_writers[0].add_histogram('val poses_rz', poses[:, 5], epoch)

    return losses.avg
Example #16
def train(train_loader,
          disp_net,
          pose_net,
          mask_net,
          flow_net,
          optimizer,
          logger=None,
          train_writer=None,
          global_vars_dict=None):
    # 0. setup
    args = global_vars_dict['args']
    n_iter = global_vars_dict['n_iter']
    device = global_vars_dict['device']

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    w1, w2, w3, w4 = args.cam_photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight, args.flow_photo_loss_weight
    w5 = args.consensus_loss_weight

    if args.robust:
        loss_camera = photometric_reconstruction_loss_robust
        loss_flow = photometric_flow_loss_robust
    else:
        loss_camera = photometric_reconstruction_loss
        loss_flow = photometric_flow_loss


    # 2. switch to train mode
    disp_net.train()
    pose_net.train()
    mask_net.train()
    flow_net.train()

    end = time.time()
    #3. train cycle
    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)
        intrinsics_inv = intrinsics_inv.to(device)

        # 3.1 compute network outputs and loss-function inputs ---------------------

        #1. disp->depth(none)
        disparities = disp_net(tgt_img)
        if args.spatial_normalize:
            disparities = [spatial_normalize(disp) for disp in disparities]

        depth = [1 / disp for disp in disparities]

        #2. pose(none)
        pose = pose_net(tgt_img, ref_imgs)
        #pose:[4,4,6]

        # 3. flow_fwd, flow_bwd: full optical flow (vs. the rigid flow from depth and pose)
        # slightly modified from the original
        if args.flownet == 'Back2Future':  # trains/infers on three neighboring frames
            flow_fwd, flow_bwd, _ = flow_net(tgt_img, ref_imgs[1:3])
        elif args.flownet == 'FlowNetC6':
            flow_fwd = flow_net(tgt_img, ref_imgs[2])
            flow_bwd = flow_net(tgt_img, ref_imgs[1])
        elif args.flownet == 'FlowNetS':
            print(' ')  # placeholder: the FlowNetS branch is not implemented here

        # flow_cam is the rigid (background) flow
        # flow - flow_s = flow_o
        flow_cam = pose2flow(
            depth[0].squeeze(), pose[:, 2], intrinsics,
            intrinsics_inv)  # pose[:,2] belongs to forward frame
        flows_cam_fwd = [
            pose2flow(depth_.squeeze(1), pose[:, 2], intrinsics,
                      intrinsics_inv) for depth_ in depth
        ]
        flows_cam_bwd = [
            pose2flow(depth_.squeeze(1), pose[:, 1], intrinsics,
                      intrinsics_inv) for depth_ in depth
        ]

        exp_masks_target = consensus_exp_masks(flows_cam_fwd,
                                               flows_cam_bwd,
                                               flow_fwd,
                                               flow_bwd,
                                               tgt_img,
                                               ref_imgs[2],
                                               ref_imgs[1],
                                               wssim=args.wssim,
                                               wrig=args.wrig,
                                               ws=args.smooth_loss_weight)
        rigidity_mask_fwd = [
            (flows_cam_fwd_i - flow_fwd_i).abs()
            for flows_cam_fwd_i, flow_fwd_i in zip(flows_cam_fwd, flow_fwd)
        ]  # .normalize()
        rigidity_mask_bwd = [
            (flows_cam_bwd_i - flow_bwd_i).abs()
            for flows_cam_bwd_i, flow_bwd_i in zip(flows_cam_bwd, flow_bwd)
        ]  # .normalize()

        # 4. explainability mask
        explainability_mask = mask_net(tgt_img, ref_imgs)  # valid-region weighting
        # list(5) of tensors [4,4,128,512] ... [4,4,4,16], values roughly 0.33~0.63
        #-------------------------------------------------

        if args.joint_mask_for_depth:
            explainability_mask_for_depth = compute_joint_mask_for_depth(
                explainability_mask, rigidity_mask_bwd, rigidity_mask_fwd,
                args.THRESH)
        else:
            explainability_mask_for_depth = explainability_mask
        # explainability_mask_for_depth: list(5) of [b, 2, h/s, w/s]
        if args.no_non_rigid_mask:
            flow_exp_mask = [None for exp_mask in explainability_mask]
            if args.DEBUG:
                print('Using no masks for flow')
        else:
            flow_exp_mask = [
                1 - exp_mask[:, 1:3] for exp_mask in explainability_mask
            ]
            # the explainability mask is originally a background mask (background pixels = 1);
            # invert it into a moving-object mask, keeping only the previous/next frames
            # list(4) of [4, 2, 256, 512]

        # 3.2 compute losses

        # E_R minimizes the photometric loss on the static scene
        if w1 > 0:
            loss_1 = loss_camera(tgt_img,
                                 ref_imgs,
                                 intrinsics,
                                 intrinsics_inv,
                                 depth,
                                 explainability_mask_for_depth,
                                 pose,
                                 lambda_oob=args.lambda_oob,
                                 qch=args.qch,
                                 wssim=args.wssim)
        else:
            loss_1 = torch.tensor([0.]).to(device)
        # E_M
        if w2 > 0:
            loss_2 = explainability_loss(
                explainability_mask
            )  #+ 0.2*gaussian_explainability_loss(explainability_mask)
        else:
            loss_2 = 0
        # E_S
        if w3 > 0:
            if args.smoothness_type == "regular":
                loss_3 = smooth_loss(depth) + smooth_loss(
                    flow_fwd) + smooth_loss(flow_bwd) + smooth_loss(
                        explainability_mask)
            elif args.smoothness_type == "edgeaware":
                loss_3 = edge_aware_smoothness_loss(
                    tgt_img, depth) + edge_aware_smoothness_loss(
                        tgt_img, flow_fwd)
                loss_3 += edge_aware_smoothness_loss(
                    tgt_img, flow_bwd) + edge_aware_smoothness_loss(
                        tgt_img, explainability_mask)
        else:
            loss_3 = torch.tensor([0.]).to(device)
        # E_F
        # minimizes photometric loss on moving regions

        if w4 > 0:
            loss_4 = loss_flow(tgt_img,
                               ref_imgs[1:3], [flow_bwd, flow_fwd],
                               flow_exp_mask,
                               lambda_oob=args.lambda_oob,
                               qch=args.qch,
                               wssim=args.wssim)
        else:
            loss_4 = torch.tensor([0.]).to(device)
        # E_C
        # drives the collaboration
        # explainability_mask: list(6) of [4,4,4,16]; rigidity_mask: list(4) of [4,2,128,512]
        if w5 > 0:
            loss_5 = consensus_depth_flow_mask(explainability_mask,
                                               rigidity_mask_bwd,
                                               rigidity_mask_fwd,
                                               exp_masks_target,
                                               exp_masks_target,
                                               THRESH=args.THRESH,
                                               wbce=args.wbce)
        else:
            loss_5 = torch.tensor([0.]).to(device)

        # 3.2.6 total loss
        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3 + w4 * loss_4 + w5 * loss_5
        #end of loss

        # 3.3 record loss and EPE
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        #3.4 log data

        # add scalar
        if args.scalar_freq > 0 and n_iter % args.scalar_freq == 0:
            train_writer.add_scalar('batch/cam_photometric_error',
                                    loss_1.item(), n_iter)
            if w2 > 0:
                train_writer.add_scalar('batch/explanability_loss',
                                        loss_2.item(), n_iter)
            train_writer.add_scalar('batch/disparity_smoothness_loss',
                                    loss_3.item(), n_iter)
            train_writer.add_scalar('batch/flow_photometric_error',
                                    loss_4.item(), n_iter)
            train_writer.add_scalar('batch/consensus_error', loss_5.item(),
                                    n_iter)
            train_writer.add_scalar('batch/total_loss', loss.item(), n_iter)

        # if training_output_freq is 0, no images are written
        if args.training_output_freq > 0 and n_iter % args.training_output_freq == 0:

            train_writer.add_image('train Input', tensor2array(tgt_img[0]),
                                   n_iter)
            train_writer.add_image(
                'train Cam Flow Output',
                flow_to_image(tensor2array(flow_cam.data[0].cpu())), n_iter)

            for k, scaled_depth in enumerate(depth):
                train_writer.add_image(
                    'train Dispnet Output Normalized {}'.format(k),
                    tensor2array(disparities[k].data[0].cpu(),
                                 max_value=None,
                                 colormap='bone'), n_iter)
                train_writer.add_image(
                    'train Depth Output {}'.format(k),
                    tensor2array(1 / disparities[k].data[0].cpu(),
                                 max_value=10), n_iter)
                train_writer.add_image(
                    'train Non Rigid Flow Output {}'.format(k),
                    flow_to_image(tensor2array(flow_fwd[k].data[0].cpu())),
                    n_iter)
                train_writer.add_image(
                    'train Target Rigidity {}'.format(k),
                    tensor2array((rigidity_mask_fwd[k] > args.THRESH).type_as(
                        rigidity_mask_fwd[k]).data[0].cpu(),
                                 max_value=1,
                                 colormap='bone'), n_iter)

                b, _, h, w = scaled_depth.size()
                downscale = tgt_img.size(2) / h

                tgt_img_scaled = nn.functional.adaptive_avg_pool2d(
                    tgt_img, (h, w))
                ref_imgs_scaled = [
                    nn.functional.adaptive_avg_pool2d(ref_img, (h, w))
                    for ref_img in ref_imgs
                ]

                intrinsics_scaled = torch.cat(
                    (intrinsics[:, 0:2] / downscale, intrinsics[:, 2:]), dim=1)
                intrinsics_scaled_inv = torch.cat(
                    (intrinsics_inv[:, :, 0:2] * downscale,
                     intrinsics_inv[:, :, 2:]),
                    dim=2)

                train_writer.add_image(
                    'train Non Rigid Warped Image {}'.format(k),
                    tensor2array(
                        flow_warp(ref_imgs_scaled[2],
                                  flow_fwd[k]).data[0].cpu()), n_iter)

                # log warped images along with explainability mask
                for j, ref in enumerate(ref_imgs_scaled):
                    ref_warped = inverse_warp(
                        ref,
                        scaled_depth[:, 0],
                        pose[:, j],
                        intrinsics_scaled,
                        intrinsics_scaled_inv,
                        rotation_mode=args.rotation_mode,
                        padding_mode=args.padding_mode)[0]
                    train_writer.add_image(
                        'train Warped Outputs {} {}'.format(k, j),
                        tensor2array(ref_warped.data.cpu()), n_iter)
                    train_writer.add_image(
                        'train Diff Outputs {} {}'.format(k, j),
                        tensor2array(
                            0.5 *
                            (tgt_img_scaled[0] - ref_warped).abs().data.cpu()),
                        n_iter)
                    if explainability_mask[k] is not None:
                        train_writer.add_image(
                            'train Exp mask Outputs {} {}'.format(k, j),
                            tensor2array(explainability_mask[k][0,
                                                                j].data.cpu(),
                                         max_value=1,
                                         colormap='bone'), n_iter)

        # csv file write
        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([
                loss.item(),
                loss_1.item(),
                loss_2.item() if w2 > 0 else 0,
                loss_3.item(),
                loss_4.item()
            ])
        # terminal output
        if args.log_terminal:
            logger.train_bar.update(i + 1)  # progress within the current epoch
            if i % args.print_freq == 0:
                logger.valid_bar_writer.write(
                    'Train: Time {} Data {} Loss {}'.format(
                        batch_time, data_time, losses))

        # 3.5 epoch boundary
        epoch_size = len(train_loader)
        if i >= epoch_size - 1:
            break

        n_iter += 1

    global_vars_dict['n_iter'] = n_iter
    return losses.avg[0]  #epoch loss
Example #17
def train(args, train_loader, disp_net, pose_exp_net, optimizer, epoch_size,
          logger, train_writer):
    global n_iter, device
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight

    # switch to train mode
    disp_net.train()
    pose_exp_net.train()

    end = time.time()
    logger.train_bar.update(0)

    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)

        # compute output
        disparities = disp_net(tgt_img)
        depth = [1 / disp for disp in disparities]
        explainability_mask, pose = pose_exp_net(tgt_img, ref_imgs)

        loss_1 = photometric_reconstruction_loss(tgt_img, ref_imgs, intrinsics,
                                                 depth, explainability_mask,
                                                 pose, args.rotation_mode,
                                                 args.padding_mode)
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask)
        else:
            loss_2 = 0
        loss_3 = smooth_loss(depth)

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3

        if i > 0 and n_iter % args.print_freq == 0:
            train_writer.add_scalar('photometric_error', loss_1.item(), n_iter)
            if w2 > 0:
                train_writer.add_scalar('explanability_loss', loss_2.item(),
                                        n_iter)
            train_writer.add_scalar('disparity_smoothness_loss', loss_3.item(),
                                    n_iter)
            train_writer.add_scalar('total_loss', loss.item(), n_iter)

        if args.training_output_freq > 0 and n_iter % args.training_output_freq == 0:
            train_writer.add_image('train Input', tensor2array(tgt_img[0]),
                                   n_iter)

            with torch.no_grad():
                for k, scaled_depth in enumerate(depth):
                    train_writer.add_image(
                        'train Dispnet Output Normalized {}'.format(k),
                        tensor2array(disparities[k][0],
                                     max_value=None,
                                     colormap='magma'), n_iter)
                    train_writer.add_image(
                        'train Depth Output Normalized {}'.format(k),
                        tensor2array(1 / disparities[k][0], max_value=None),
                        n_iter)
                    b, _, h, w = scaled_depth.size()
                    downscale = tgt_img.size(2) / h

                    tgt_img_scaled = F.interpolate(tgt_img, (h, w),
                                                   mode='area')
                    ref_imgs_scaled = [
                        F.interpolate(ref_img, (h, w), mode='area')
                        for ref_img in ref_imgs
                    ]

                    intrinsics_scaled = torch.cat(
                        (intrinsics[:, 0:2] / downscale, intrinsics[:, 2:]),
                        dim=1)

                    # log warped images along with explainability mask
                    for j, ref in enumerate(ref_imgs_scaled):
                        ref_warped = inverse_warp(
                            ref,
                            scaled_depth[:, 0],
                            pose[:, j],
                            intrinsics_scaled,
                            rotation_mode=args.rotation_mode,
                            padding_mode=args.padding_mode)[0]
                        train_writer.add_image(
                            'train Warped Outputs {} {}'.format(k, j),
                            tensor2array(ref_warped), n_iter)
                        train_writer.add_image(
                            'train Diff Outputs {} {}'.format(k, j),
                            tensor2array(
                                0.5 * (tgt_img_scaled[0] - ref_warped).abs()),
                            n_iter)
                        if explainability_mask[k] is not None:
                            train_writer.add_image(
                                'train Exp mask Outputs {} {}'.format(k, j),
                                tensor2array(explainability_mask[k][0, j],
                                             max_value=1,
                                             colormap='bone'), n_iter)

        # record loss and EPE
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([
                loss.item(),
                loss_1.item(),
                loss_2.item() if w2 > 0 else 0,
                loss_3.item()
            ])
        logger.train_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(
                batch_time, data_time, losses))
        if i >= epoch_size - 1:
            break

        n_iter += 1

    return losses.avg[0]
Example #18
def train_depth_gt(train_loader,
                   disp_net,
                   optimizer,
                   criterion,
                   logger=None,
                   train_writer=None,
                   global_vars_dict=None):
    # 0. setup
    args = global_vars_dict['args']
    n_iter = global_vars_dict['n_iter']
    device = global_vars_dict['device']

    batch_time = AverageMeter()
    data_time = AverageMeter()
    loss_names = ['total_loss', 'l1_loss', 'smooth']
    losses = AverageMeter(precision=4, i=len(loss_names))
    w1, w2 = args.gt_loss_weight, args.smooth_loss_weight
    loss_l1 = MaskedL1Loss().to(device)

    #2. switch to train mode
    disp_net.train()
    #pose_net.train()
    #mask_net.train()
    #flow_net.train()

    end = time.time()

    # 3. train cycle
    numel = args.batch_size * 1 * 256 * 512  # (unused below)

    #main cycle
    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv,
            gt_depth) in enumerate(train_loader):
        # measure data loading time

        data_time.update(time.time() - end)
        # move data to device
        tgt_img = tgt_img.to(device)
        ref_imgs = [(img.to(device)) for img in ref_imgs]
        intrinsics = intrinsics.to(device)
        intrinsics_inv = intrinsics_inv.to(device)
        gt_depth = gt_depth.to(device)  #[0~1]

        # forward pass

        disparities = disp_net(tgt_img)
        if args.spatial_normalize:
            disparities = [spatial_normalize(disp)
                           for disp in disparities]  #[0.4,2.7,8.7]

        output_depth = [1 / disp for disp in disparities]

        # output_depth = output_depth[0]  # keep only the finest scale

        # compute gradient and do Adam step
        # pre_histcs=[]
        # gt_histcs=[]
        # for depth in output_depth:
        #     pre_histcs.append(torch.histc(depth,bins=100,min=0,max=1))

        loss1 = loss_l1(gt_depth, output_depth)
        loss2 = smooth_loss(output_depth)
        loss = w1 * loss1 + w2 * loss2


        losses.update([loss.item(), loss1.item(),
                       loss2.item()], args.batch_size)
        #plt.imshow(tensor2array(output_depth[0],out_shape='HWC',colormap='bone'))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # terminal logging
        if args.log_terminal:
            logger.train_logger_update(batch=i,
                                       time=batch_time,
                                       names=loss_names,
                                       values=losses)

        # 3.4 log batch data only here in train, to check early whether learning happens
        train_writer.add_scalar('batch/total_loss', loss.item(), n_iter)

        # 3.5 epoch boundary
        epoch_size = len(train_loader)
        if i >= epoch_size - 1:
            break

        n_iter += 1

    global_vars_dict['n_iter'] = n_iter
    return loss_names, losses  #epoch loss
Example #19
def validate_without_gt(args,
                        val_loader,
                        disp_net,
                        pose_exp_net,
                        epoch,
                        logger,
                        output_writers=[]):
    global device
    batch_time = AverageMeter()
    losses = AverageMeter(i=3, precision=4)
    log_outputs = len(output_writers) > 0
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight
    poses = np.zeros(
        ((len(val_loader) - 1) * args.batch_size * (args.sequence_length - 1),
         6))
    disp_values = np.zeros(((len(val_loader) - 1) * args.batch_size * 3))

    # switch to evaluate mode
    disp_net.eval()
    pose_exp_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(val_loader):
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)
        intrinsics_inv = intrinsics_inv.to(device)

        # compute output
        disp = disp_net(tgt_img)
        depth = 1 / disp
        explainability_mask, pose = pose_exp_net(tgt_img, ref_imgs)

        loss_1, warped, diff = photometric_reconstruction_loss(
            tgt_img, ref_imgs, intrinsics, depth, explainability_mask, pose,
            args.rotation_mode, args.padding_mode)
        loss_1 = loss_1.item()
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask).item()
        else:
            loss_2 = 0
        loss_3 = smooth_loss(depth).item()

        if log_outputs and i < len(
                output_writers):  # log first output of first batches
            if epoch == 0:
                for j, ref in enumerate(ref_imgs):
                    output_writers[i].add_image('val Input {}'.format(j),
                                                tensor2array(tgt_img[0]), 0)
                    output_writers[i].add_image('val Input {}'.format(j),
                                                tensor2array(ref[0]), 1)

            log_output_tensorboard(output_writers[i], 'val', '', epoch,
                                   1. / disp, disp, warped, diff,
                                   explainability_mask)

        if log_outputs and i < len(val_loader) - 1:
            step = args.batch_size * (args.sequence_length - 1)
            poses[i * step:(i + 1) * step] = pose.cpu().view(-1, 6).numpy()
            step = args.batch_size * 3
            disp_unraveled = disp.cpu().view(args.batch_size, -1)
            disp_values[i * step:(i + 1) * step] = torch.cat([
                disp_unraveled.min(-1)[0],
                disp_unraveled.median(-1)[0],
                disp_unraveled.max(-1)[0]
            ]).numpy()

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3
        losses.update([loss, loss_1, loss_2])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write('valid: Time {} Loss {}'.format(
                batch_time, losses))
    if log_outputs:
        prefix = 'valid poses'
        coeffs_names = ['tx', 'ty', 'tz']
        if args.rotation_mode == 'euler':
            coeffs_names.extend(['rx', 'ry', 'rz'])
        elif args.rotation_mode == 'quat':
            coeffs_names.extend(['qx', 'qy', 'qz'])
        for i in range(poses.shape[1]):
            output_writers[0].add_histogram(
                '{} {}'.format(prefix, coeffs_names[i]), poses[:, i], epoch)
        output_writers[0].add_histogram('disp_values', disp_values, epoch)
    logger.valid_bar.update(len(val_loader))
    return losses.avg, ['Total loss', 'Photo loss', 'Exp loss']
Example #20
def train(args, train_loader, disp_net, pose_exp_net, optimizer, epoch_size,
          logger, tb_writer):
    global n_iter, device
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight

    # switch to train mode
    disp_net.train()
    pose_exp_net.train()

    end = time.time()
    logger.train_bar.update(0)

    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(train_loader):
        log_losses = i > 0 and n_iter % args.print_freq == 0
        log_output = args.training_output_freq > 0 and n_iter % args.training_output_freq == 0

        # measure data loading time
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)

        # compute output
        disparities = disp_net(tgt_img)
        depth = [1 / disp for disp in disparities]
        #         print("***",len(depth),depth[0].size())
        explainability_mask, pose = pose_exp_net(tgt_img, ref_imgs)

        loss_1, warped, diff = photometric_reconstruction_loss(
            tgt_img, ref_imgs, intrinsics, depth, explainability_mask, pose,
            args.rotation_mode, args.padding_mode)
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask)
        else:
            loss_2 = 0
        loss_3 = smooth_loss(depth)

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3

        if args.with_photocon_loss:
            batch_size = pose.size()[0]
            homo_row = torch.tensor([[0, 0, 0, 1]],
                                    dtype=torch.float).to(device)
            homo_row = homo_row.unsqueeze(0).expand(batch_size, -1, -1)
            T21 = pose_vec2mat(pose[:, 0])
            T21 = torch.cat((T21, homo_row), 1)
            T12 = torch.inverse(T21)
            T23 = pose_vec2mat(pose[:, 1])
            T23 = torch.cat((T23, homo_row), 1)
            T13 = torch.matmul(T23, T12)  #[B, 4, 4]
            #             print("----",T13.size())
            # target = 1 and ref = 3
            ref_img_warped, valid_points = inverse_warp_posemat(
                ref_imgs[1], depth[0][:, 0], T13, intrinsics,
                args.rotation_mode, args.padding_mode)
            diff = (ref_imgs[0] -
                    ref_img_warped) * valid_points.unsqueeze(1).float()
            loss_4 = diff.abs().mean()

            loss += loss_4

        if log_losses:
            tb_writer.add_scalar('photometric_error', loss_1.item(), n_iter)
            if w2 > 0:
                tb_writer.add_scalar('explanability_loss', loss_2.item(),
                                     n_iter)
            tb_writer.add_scalar('disparity_smoothness_loss', loss_3.item(),
                                 n_iter)
            tb_writer.add_scalar('total_loss', loss.item(), n_iter)

        if log_output:
            tb_writer.add_image('train Input', tensor2array(tgt_img[0]),
                                n_iter)
            for k, scaled_maps in enumerate(
                    zip(depth, disparities, warped, diff,
                        explainability_mask)):
                log_output_tensorboard(tb_writer, "train", 0, " {}".format(k),
                                       n_iter, *scaled_maps)

        # record loss and EPE
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([
                loss.item(),
                loss_1.item(),
                loss_2.item() if w2 > 0 else 0,
                loss_3.item()
            ])
        logger.train_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(
                batch_time, data_time, losses))
        if i >= epoch_size - 1:
            break

        n_iter += 1

    return losses.avg[0]