# Example no. 1
def train(train_loader, mask_net, pose_net, optimizer, epoch_size,
          train_writer):
    """Run one training epoch for the explainability-mask and pose networks.

    Optimizes a weighted sum of four losses (smoothness, explainability,
    consensus, pose) with a single Adam step per batch, and periodically
    logs scalars/images to TensorBoard.

    Args:
        train_loader: iterable yielding (rgb_tgt, rgb_refs, depth_tgt,
            depth_refs, mask_tgt, mask_refs, intrinsics, intrinsics_inv,
            pose_list) batches. Must yield at least one batch.
        mask_net: network producing multi-scale explainability masks.
        pose_net: network producing relative camera poses.
        optimizer: optimizer over both networks' parameters.
        epoch_size: unused here; kept for interface compatibility.
        train_writer: TensorBoard SummaryWriter.

    Returns:
        Mean total loss over all batches of the epoch.

    Side effects: reads globals ``args`` and ``n_iter``; increments
    ``n_iter`` once per batch.
    """
    global args, n_iter
    w1 = args.smooth_loss_weight
    w2 = args.mask_loss_weight
    w3 = args.consensus_loss_weight
    w4 = args.pose_loss_weight

    mask_net.train()
    pose_net.train()
    average_loss = 0
    for i, (rgb_tgt_img, rgb_ref_imgs, depth_tgt_img, depth_ref_imgs,
            mask_tgt_img, mask_ref_imgs, intrinsics, intrinsics_inv,
            pose_list) in enumerate(tqdm(train_loader)):
        # NOTE(review): torch.autograd.Variable is a no-op wrapper on
        # modern PyTorch; kept for compatibility with the rest of the file.
        rgb_tgt_img_var = Variable(rgb_tgt_img.cuda())
        rgb_ref_imgs_var = [Variable(img.cuda()) for img in rgb_ref_imgs]
        # Depth tensors gain an explicit channel dim: (B, H, W) -> (B, 1, H, W).
        depth_tgt_img_var = Variable(depth_tgt_img.unsqueeze(1).cuda())
        depth_ref_imgs_var = [
            Variable(img.unsqueeze(1).cuda()) for img in depth_ref_imgs
        ]
        mask_tgt_img_var = Variable(mask_tgt_img.cuda())
        mask_ref_imgs_var = [Variable(img.cuda()) for img in mask_ref_imgs]

        # Binarize ground-truth masks: any positive value becomes 1.
        mask_tgt_img_var = torch.where(mask_tgt_img_var > 0,
                                       torch.ones_like(mask_tgt_img_var),
                                       torch.zeros_like(mask_tgt_img_var))
        mask_ref_imgs_var = [
            torch.where(img > 0, torch.ones_like(img), torch.zeros_like(img))
            for img in mask_ref_imgs_var
        ]

        intrinsics_var = Variable(intrinsics.cuda())
        intrinsics_inv_var = Variable(intrinsics_inv.cuda())

        # Multi-scale explainability masks; index 0 is the finest scale.
        explainability_mask = mask_net(rgb_tgt_img_var, rgb_ref_imgs_var)
        pose = pose_net(rgb_tgt_img_var, rgb_ref_imgs_var)

        # loss 1: smoothness loss on the predicted masks
        loss1 = smooth_loss(explainability_mask)

        # loss 2: explainability (mask regularization) loss
        loss2 = explainability_loss(explainability_mask)

        # loss 3: consensus between predicted mask and ground-truth masks
        loss3 = consensus_loss(explainability_mask[0], mask_ref_imgs_var)

        # loss 4: pose (photometric warping) loss.
        # Pixels with zero reference depth are invalid (mask value 0).
        valid_pixle_mask = [
            torch.where(depth_ref_imgs_var[0] == 0,
                        torch.zeros_like(depth_tgt_img_var),
                        torch.ones_like(depth_tgt_img_var)),
            torch.where(depth_ref_imgs_var[1] == 0,
                        torch.zeros_like(depth_tgt_img_var),
                        torch.ones_like(depth_tgt_img_var))
        ]  # zero is invalid

        loss4, ref_img_warped, diff = pose_loss(
            valid_pixle_mask, mask_ref_imgs_var, rgb_tgt_img_var,
            rgb_ref_imgs_var, intrinsics_var, intrinsics_inv_var,
            depth_tgt_img_var, pose)

        # compute gradient and do Adam step
        loss = w1 * loss1 + w2 * loss2 + w3 * loss3 + w4 * loss4
        average_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # visualization in tensorboard
        if i > 0 and n_iter % args.print_freq == 0:
            train_writer.add_scalar('smoothness loss', loss1.item(), n_iter)
            train_writer.add_scalar('explainability loss', loss2.item(),
                                    n_iter)
            train_writer.add_scalar('consensus loss', loss3.item(), n_iter)
            train_writer.add_scalar('pose loss', loss4.item(), n_iter)
            train_writer.add_scalar('total loss', loss.item(), n_iter)
        if n_iter % (args.training_output_freq) == 0:
            train_writer.add_image('train Input',
                                   tensor2array(rgb_tgt_img_var[0]), n_iter)
            train_writer.add_image(
                'train Exp mask Outputs ',
                tensor2array(explainability_mask[0][0, 0].data.cpu(),
                             max_value=1,
                             colormap='bone'), n_iter)
            train_writer.add_image(
                'train gt mask ',
                tensor2array(mask_tgt_img[0].data.cpu(),
                             max_value=1,
                             colormap='bone'), n_iter)
            train_writer.add_image(
                'train depth ',
                tensor2array(depth_tgt_img[0].data.cpu(),
                             max_value=1,
                             colormap='bone'), n_iter)
            train_writer.add_image(
                'train after mask',
                tensor2array(rgb_tgt_img_var[0] *
                             explainability_mask[0][0, 0]), n_iter)
            train_writer.add_image('train diff', tensor2array(diff[0]), n_iter)
            train_writer.add_image('train warped img',
                                   tensor2array(ref_img_warped[0]), n_iter)

        n_iter += 1

    # BUGFIX: i is the last 0-based batch index, so the batch count is
    # i + 1. Dividing by i was off by one and raised ZeroDivisionError
    # for single-batch epochs.
    return average_loss / (i + 1)
# Example no. 2
def train(train_loader, mask_net, pose_net, flow_net, optimizer, epoch_size,
          train_writer):
    """Run one training epoch for the mask, pose, and flow networks.

    Optimizes a weighted sum of four losses (smoothness, explainability,
    flow/depth consensus, flow photometric) with a single Adam step per
    batch, and periodically logs scalars/images to TensorBoard.

    Args:
        train_loader: iterable yielding (rgb_tgt, rgb_refs, depth_tgt,
            depth_refs, intrinsics, intrinsics_inv, pose_list) batches.
            Must yield at least one batch.
        mask_net: network producing multi-scale explainability masks.
        pose_net: network producing relative camera poses.
        flow_net: network producing forward/backward optical flow.
        optimizer: optimizer over the networks' parameters.
        epoch_size: unused here; kept for interface compatibility.
        train_writer: TensorBoard SummaryWriter.

    Returns:
        Mean total loss over all batches of the epoch.

    Side effects: reads globals ``args`` and ``n_iter``; increments
    ``n_iter`` once per batch.
    """
    global args, n_iter
    w1 = args.smooth_loss_weight
    w2 = args.mask_loss_weight
    w3 = args.consensus_loss_weight
    w4 = args.flow_loss_weight

    mask_net.train()
    pose_net.train()
    flow_net.train()
    average_loss = 0
    for i, (rgb_tgt_img, rgb_ref_imgs, depth_tgt_img, depth_ref_imgs,
            intrinsics, intrinsics_inv,
            pose_list) in enumerate(tqdm(train_loader)):
        # NOTE(review): torch.autograd.Variable is a no-op wrapper on
        # modern PyTorch; kept for compatibility with the rest of the file.
        rgb_tgt_img_var = Variable(rgb_tgt_img.cuda())
        rgb_ref_imgs_var = [Variable(img.cuda()) for img in rgb_ref_imgs]
        # Depth tensors gain an explicit channel dim: (B, H, W) -> (B, 1, H, W).
        depth_tgt_img_var = Variable(depth_tgt_img.unsqueeze(1).cuda())
        depth_ref_imgs_var = [
            Variable(img.unsqueeze(1).cuda()) for img in depth_ref_imgs
        ]
        intrinsics_var = Variable(intrinsics.cuda())
        intrinsics_inv_var = Variable(intrinsics_inv.cuda())

        # Multi-scale explainability masks; index 0 is the finest scale.
        explainability_mask = mask_net(rgb_tgt_img_var, rgb_ref_imgs_var)
        # Pixels with zero target depth are invalid (mask value 0).
        valid_pixle_mask = torch.where(
            depth_tgt_img_var == 0, torch.zeros_like(depth_tgt_img_var),
            torch.ones_like(depth_tgt_img_var))  # zero is invalid

        pose = pose_net(rgb_tgt_img_var, rgb_ref_imgs_var)

        # Generate rigid flow from camera pose + depth, and compare it
        # with the flow network's prediction to build rigidity masks.
        flow_fwd, flow_bwd, _ = flow_net(rgb_tgt_img_var, rgb_ref_imgs_var)
        flows_cam_fwd = pose2flow(depth_ref_imgs_var[1].squeeze(1), pose[:, 1],
                                  intrinsics_var, intrinsics_inv_var)
        flows_cam_bwd = pose2flow(depth_ref_imgs_var[0].squeeze(1), pose[:, 0],
                                  intrinsics_var, intrinsics_inv_var)
        rigidity_mask_fwd = (flows_cam_fwd - flow_fwd[0]).abs()
        rigidity_mask_bwd = (flows_cam_bwd - flow_bwd[0]).abs()

        # loss 1: smoothness loss over masks and both flow directions
        loss1 = smooth_loss(explainability_mask) + smooth_loss(
            flow_bwd) + smooth_loss(flow_fwd)

        # loss 2: explainability (mask regularization) loss
        loss2 = explainability_loss(explainability_mask)

        # loss 3: consensus loss (network mask vs. depth-residual mask).
        # depth_Res_mask / depth_diff are also used for visualization below.
        depth_Res_mask, depth_ref_img_warped, depth_diff = depth_residual_mask(
            valid_pixle_mask, explainability_mask[0], rgb_tgt_img_var,
            rgb_ref_imgs_var, intrinsics_var, intrinsics_inv_var,
            depth_tgt_img_var, pose)

        loss3 = consensus_loss(explainability_mask[0], rigidity_mask_bwd,
                               rigidity_mask_fwd, args.THRESH, args.wbce)

        # loss 4: flow photometric (warping) loss
        loss4, flow_ref_img_warped, flow_diff = flow_loss(
            rgb_tgt_img_var, rgb_ref_imgs_var, [flow_bwd, flow_fwd],
            explainability_mask)

        # compute gradient and do Adam step
        loss = w1 * loss1 + w2 * loss2 + w3 * loss3 + w4 * loss4
        average_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # visualization in tensorboard
        if i > 0 and n_iter % args.print_freq == 0:
            train_writer.add_scalar('smoothness loss', loss1.item(), n_iter)
            train_writer.add_scalar('explainability loss', loss2.item(),
                                    n_iter)
            train_writer.add_scalar('consensus loss', loss3.item(), n_iter)
            train_writer.add_scalar('flow loss', loss4.item(), n_iter)
            train_writer.add_scalar('total loss', loss.item(), n_iter)
        if n_iter % (args.training_output_freq) == 0:
            train_writer.add_image('train Input',
                                   tensor2array(rgb_tgt_img_var[0]), n_iter)
            train_writer.add_image(
                'train Exp mask Outputs ',
                tensor2array(explainability_mask[0][0, 0].data.cpu(),
                             max_value=1,
                             colormap='bone'), n_iter)
            train_writer.add_image(
                'train depth Res mask ',
                tensor2array(depth_Res_mask[0][0].data.cpu(),
                             max_value=1,
                             colormap='bone'), n_iter)
            train_writer.add_image(
                'train depth ',
                tensor2array(depth_tgt_img_var[0].data.cpu(),
                             max_value=1,
                             colormap='bone'), n_iter)
            train_writer.add_image(
                'train valid pixel ',
                tensor2array(valid_pixle_mask[0].data.cpu(),
                             max_value=1,
                             colormap='bone'), n_iter)
            train_writer.add_image(
                'train after mask',
                tensor2array(rgb_tgt_img_var[0] *
                             explainability_mask[0][0, 0]), n_iter)
            train_writer.add_image('train depth diff',
                                   tensor2array(depth_diff[0]), n_iter)
            train_writer.add_image('train flow diff',
                                   tensor2array(flow_diff[0]), n_iter)
            train_writer.add_image('train depth warped img',
                                   tensor2array(depth_ref_img_warped[0]),
                                   n_iter)
            train_writer.add_image('train flow warped img',
                                   tensor2array(flow_ref_img_warped[0]),
                                   n_iter)
            train_writer.add_image(
                'train Cam Flow Output',
                flow_to_image(tensor2array(flow_fwd[0].data[0].cpu())), n_iter)
            train_writer.add_image(
                'train Flow from Depth Output',
                flow_to_image(tensor2array(flows_cam_fwd.data[0].cpu())),
                n_iter)
            train_writer.add_image(
                'train Flow and Depth diff',
                flow_to_image(tensor2array(rigidity_mask_fwd.data[0].cpu())),
                n_iter)

        n_iter += 1

    # BUGFIX: i is the last 0-based batch index, so the batch count is
    # i + 1. Dividing by i was off by one and raised ZeroDivisionError
    # for single-batch epochs.
    return average_loss / (i + 1)