def train(train_loader, mask_net, pose_net, optimizer, epoch_size,
          train_writer):
    """Train the explainability-mask and pose networks for one epoch.

    Args:
        train_loader: iterable yielding batches of
            (rgb_tgt, rgb_refs, depth_tgt, depth_refs, mask_tgt, mask_refs,
            intrinsics, intrinsics_inv, pose_list). Depth maps arrive as
            (B, H, W) and get a channel dim added here — TODO confirm
            against the dataset class.
        mask_net: network producing multi-scale explainability masks.
        pose_net: network producing relative camera poses.
        optimizer: optimizer updating both networks.
        epoch_size: unused here; kept for interface compatibility.
        train_writer: tensorboard SummaryWriter for scalar/image logging.

    Returns:
        Mean total (weighted) loss over the processed batches, or 0.0 if
        the loader yielded nothing.
    """
    global args, n_iter
    # Loss weights come from the global CLI args.
    w1 = args.smooth_loss_weight
    w2 = args.mask_loss_weight
    w3 = args.consensus_loss_weight
    w4 = args.pose_loss_weight

    mask_net.train()
    pose_net.train()

    total_loss = 0.0
    num_batches = 0

    for i, (rgb_tgt_img, rgb_ref_imgs, depth_tgt_img, depth_ref_imgs,
            mask_tgt_img, mask_ref_imgs, intrinsics, intrinsics_inv,
            pose_list) in enumerate(tqdm(train_loader)):
        # NOTE: Variable() is a no-op wrapper on modern PyTorch; kept for
        # consistency with the rest of the project.
        rgb_tgt_img_var = Variable(rgb_tgt_img.cuda())
        rgb_ref_imgs_var = [Variable(img.cuda()) for img in rgb_ref_imgs]
        # Depth maps: (B, H, W) -> (B, 1, H, W).
        depth_tgt_img_var = Variable(depth_tgt_img.unsqueeze(1).cuda())
        depth_ref_imgs_var = [
            Variable(img.unsqueeze(1).cuda()) for img in depth_ref_imgs
        ]
        mask_tgt_img_var = Variable(mask_tgt_img.cuda())
        mask_ref_imgs_var = [Variable(img.cuda()) for img in mask_ref_imgs]

        # Binarize ground-truth masks: any positive value -> 1, else 0.
        mask_tgt_img_var = torch.where(mask_tgt_img_var > 0,
                                       torch.ones_like(mask_tgt_img_var),
                                       torch.zeros_like(mask_tgt_img_var))
        mask_ref_imgs_var = [
            torch.where(img > 0, torch.ones_like(img), torch.zeros_like(img))
            for img in mask_ref_imgs_var
        ]

        intrinsics_var = Variable(intrinsics.cuda())
        intrinsics_inv_var = Variable(intrinsics_inv.cuda())

        # explainability_mask is multi-scale; index 0 is the finest scale.
        explainability_mask = mask_net(rgb_tgt_img_var, rgb_ref_imgs_var)
        pose = pose_net(rgb_tgt_img_var, rgb_ref_imgs_var)

        # loss 1: smoothness loss on the predicted masks
        loss1 = smooth_loss(explainability_mask)
        # loss 2: explainability loss (regularizes masks toward 1)
        loss2 = explainability_loss(explainability_mask)
        # loss 3: consensus loss (network mask vs. mask from residual)
        loss3 = consensus_loss(explainability_mask[0], mask_ref_imgs_var)

        # loss 4: pose loss. Pixels where a reference depth is zero are
        # invalid (zero is invalid) and masked out of the photometric term.
        valid_pixel_mask = [
            torch.where(ref_depth == 0,
                        torch.zeros_like(depth_tgt_img_var),
                        torch.ones_like(depth_tgt_img_var))
            for ref_depth in (depth_ref_imgs_var[0], depth_ref_imgs_var[1])
        ]
        loss4, ref_img_warped, diff = pose_loss(
            valid_pixel_mask, mask_ref_imgs_var, rgb_tgt_img_var,
            rgb_ref_imgs_var, intrinsics_var, intrinsics_inv_var,
            depth_tgt_img_var, pose)

        # compute gradient and do Adam step
        loss = w1 * loss1 + w2 * loss2 + w3 * loss3 + w4 * loss4
        total_loss += loss.item()
        num_batches += 1

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # visualization in tensorboard
        if i > 0 and n_iter % args.print_freq == 0:
            train_writer.add_scalar('smoothness loss', loss1.item(), n_iter)
            train_writer.add_scalar('explainability loss', loss2.item(),
                                    n_iter)
            train_writer.add_scalar('consensus loss', loss3.item(), n_iter)
            train_writer.add_scalar('pose loss', loss4.item(), n_iter)
            train_writer.add_scalar('total loss', loss.item(), n_iter)

        if n_iter % (args.training_output_freq) == 0:
            train_writer.add_image('train Input',
                                   tensor2array(rgb_tgt_img_var[0]), n_iter)
            train_writer.add_image(
                'train Exp mask Outputs ',
                tensor2array(explainability_mask[0][0, 0].data.cpu(),
                             max_value=1, colormap='bone'), n_iter)
            train_writer.add_image(
                'train gt mask ',
                tensor2array(mask_tgt_img[0].data.cpu(), max_value=1,
                             colormap='bone'), n_iter)
            train_writer.add_image(
                'train depth ',
                tensor2array(depth_tgt_img[0].data.cpu(), max_value=1,
                             colormap='bone'), n_iter)
            train_writer.add_image(
                'train after mask',
                tensor2array(rgb_tgt_img_var[0] *
                             explainability_mask[0][0, 0]), n_iter)
            train_writer.add_image('train diff', tensor2array(diff[0]),
                                   n_iter)
            train_writer.add_image('train warped img',
                                   tensor2array(ref_img_warped[0]), n_iter)

        n_iter += 1

    # BUGFIX: original returned `average_loss / i`, which divides by
    # (num_batches - 1) and crashes on a one-batch epoch (i == 0) or an
    # empty loader (i undefined). Divide by the actual batch count.
    return total_loss / num_batches if num_batches else 0.0
def train(train_loader, mask_net, pose_net, flow_net, optimizer, epoch_size,
          train_writer):
    """Train the mask, pose and flow networks jointly for one epoch.

    Args:
        train_loader: iterable yielding batches of
            (rgb_tgt, rgb_refs, depth_tgt, depth_refs, intrinsics,
            intrinsics_inv, pose_list). Depth maps arrive as (B, H, W)
            and get a channel dim added here — TODO confirm against the
            dataset class.
        mask_net: network producing multi-scale explainability masks.
        pose_net: network producing relative camera poses.
        flow_net: network producing forward/backward optical flow.
        optimizer: optimizer updating all three networks.
        epoch_size: unused here; kept for interface compatibility.
        train_writer: tensorboard SummaryWriter for scalar/image logging.

    Returns:
        Mean total (weighted) loss over the processed batches, or 0.0 if
        the loader yielded nothing.
    """
    global args, n_iter
    # Loss weights come from the global CLI args.
    w1 = args.smooth_loss_weight
    w2 = args.mask_loss_weight
    w3 = args.consensus_loss_weight
    w4 = args.flow_loss_weight

    mask_net.train()
    pose_net.train()
    flow_net.train()

    total_loss = 0.0
    num_batches = 0

    for i, (rgb_tgt_img, rgb_ref_imgs, depth_tgt_img, depth_ref_imgs,
            intrinsics, intrinsics_inv,
            pose_list) in enumerate(tqdm(train_loader)):
        # NOTE: Variable() is a no-op wrapper on modern PyTorch; kept for
        # consistency with the rest of the project.
        rgb_tgt_img_var = Variable(rgb_tgt_img.cuda())
        rgb_ref_imgs_var = [Variable(img.cuda()) for img in rgb_ref_imgs]
        # Depth maps: (B, H, W) -> (B, 1, H, W).
        depth_tgt_img_var = Variable(depth_tgt_img.unsqueeze(1).cuda())
        depth_ref_imgs_var = [
            Variable(img.unsqueeze(1).cuda()) for img in depth_ref_imgs
        ]
        intrinsics_var = Variable(intrinsics.cuda())
        intrinsics_inv_var = Variable(intrinsics_inv.cuda())

        # explainability_mask is multi-scale; index 0 is the finest scale.
        explainability_mask = mask_net(rgb_tgt_img_var, rgb_ref_imgs_var)

        # Pixels with zero target depth are invalid (zero is invalid).
        valid_pixel_mask = torch.where(depth_tgt_img_var == 0,
                                       torch.zeros_like(depth_tgt_img_var),
                                       torch.ones_like(depth_tgt_img_var))

        pose = pose_net(rgb_tgt_img_var, rgb_ref_imgs_var)

        # Flow predicted directly by the flow network...
        flow_fwd, flow_bwd, _ = flow_net(rgb_tgt_img_var, rgb_ref_imgs_var)
        # ...and rigid flow induced by camera pose + depth.
        flows_cam_fwd = pose2flow(depth_ref_imgs_var[1].squeeze(1),
                                  pose[:, 1], intrinsics_var,
                                  intrinsics_inv_var)
        flows_cam_bwd = pose2flow(depth_ref_imgs_var[0].squeeze(1),
                                  pose[:, 0], intrinsics_var,
                                  intrinsics_inv_var)
        # Residual between rigid and predicted flow; large values suggest
        # independently moving (non-rigid) regions.
        rigidity_mask_fwd = (flows_cam_fwd - flow_fwd[0]).abs()
        rigidity_mask_bwd = (flows_cam_bwd - flow_bwd[0]).abs()

        # loss 1: smoothness loss on masks and both flows
        loss1 = (smooth_loss(explainability_mask) + smooth_loss(flow_bwd) +
                 smooth_loss(flow_fwd))
        # loss 2: explainability loss
        loss2 = explainability_loss(explainability_mask)

        # loss 3: consensus loss (the mask from networks and the mask from
        # residual). The depth residual mask is also computed here, but only
        # used for visualization below.
        depth_Res_mask, depth_ref_img_warped, depth_diff = depth_residual_mask(
            valid_pixel_mask, explainability_mask[0], rgb_tgt_img_var,
            rgb_ref_imgs_var, intrinsics_var, intrinsics_inv_var,
            depth_tgt_img_var, pose)
        loss3 = consensus_loss(explainability_mask[0], rigidity_mask_bwd,
                               rigidity_mask_fwd, args.THRESH, args.wbce)

        # loss 4: flow loss (photometric, masked by explainability)
        loss4, flow_ref_img_warped, flow_diff = flow_loss(
            rgb_tgt_img_var, rgb_ref_imgs_var, [flow_bwd, flow_fwd],
            explainability_mask)

        # compute gradient and do Adam step
        loss = w1 * loss1 + w2 * loss2 + w3 * loss3 + w4 * loss4
        total_loss += loss.item()
        num_batches += 1

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # visualization in tensorboard
        if i > 0 and n_iter % args.print_freq == 0:
            train_writer.add_scalar('smoothness loss', loss1.item(), n_iter)
            train_writer.add_scalar('explainability loss', loss2.item(),
                                    n_iter)
            train_writer.add_scalar('consensus loss', loss3.item(), n_iter)
            train_writer.add_scalar('flow loss', loss4.item(), n_iter)
            train_writer.add_scalar('total loss', loss.item(), n_iter)

        if n_iter % (args.training_output_freq) == 0:
            train_writer.add_image('train Input',
                                   tensor2array(rgb_tgt_img_var[0]), n_iter)
            train_writer.add_image(
                'train Exp mask Outputs ',
                tensor2array(explainability_mask[0][0, 0].data.cpu(),
                             max_value=1, colormap='bone'), n_iter)
            train_writer.add_image(
                'train depth Res mask ',
                tensor2array(depth_Res_mask[0][0].data.cpu(), max_value=1,
                             colormap='bone'), n_iter)
            train_writer.add_image(
                'train depth ',
                tensor2array(depth_tgt_img_var[0].data.cpu(), max_value=1,
                             colormap='bone'), n_iter)
            train_writer.add_image(
                'train valid pixel ',
                tensor2array(valid_pixel_mask[0].data.cpu(), max_value=1,
                             colormap='bone'), n_iter)
            train_writer.add_image(
                'train after mask',
                tensor2array(rgb_tgt_img_var[0] *
                             explainability_mask[0][0, 0]), n_iter)
            train_writer.add_image('train depth diff',
                                   tensor2array(depth_diff[0]), n_iter)
            train_writer.add_image('train flow diff',
                                   tensor2array(flow_diff[0]), n_iter)
            train_writer.add_image('train depth warped img',
                                   tensor2array(depth_ref_img_warped[0]),
                                   n_iter)
            train_writer.add_image('train flow warped img',
                                   tensor2array(flow_ref_img_warped[0]),
                                   n_iter)
            # NOTE(review): these two labels look swapped — 'Cam Flow' shows
            # the flow-network output and 'Flow from Depth' shows the
            # pose-induced flow; tags kept as-is to preserve existing
            # tensorboard runs. Confirm intended naming.
            train_writer.add_image(
                'train Cam Flow Output',
                flow_to_image(tensor2array(flow_fwd[0].data[0].cpu())),
                n_iter)
            train_writer.add_image(
                'train Flow from Depth Output',
                flow_to_image(tensor2array(flows_cam_fwd.data[0].cpu())),
                n_iter)
            train_writer.add_image(
                'train Flow and Depth diff',
                flow_to_image(tensor2array(rigidity_mask_fwd.data[0].cpu())),
                n_iter)

        n_iter += 1

    # BUGFIX: original returned `average_loss / i`, which divides by
    # (num_batches - 1) and crashes on a one-batch epoch (i == 0) or an
    # empty loader (i undefined). Divide by the actual batch count.
    return total_loss / num_batches if num_batches else 0.0