def _record_image(self, image_packet, global_cnt, string=None):
    if string is None:
        string = ""
    flow2, flow_label, image = image_packet
    image_index = 0
    b, c, h, w = flow_label.size()
    # nn.functional.upsample is deprecated; interpolate is the current API
    upsampled_flow = nn.functional.interpolate(flow2, size=(h, w), mode="bilinear", align_corners=False)
    upsampled_flow = upsampled_flow.cpu().detach().numpy()
    orig_image = image[image_index].cpu().numpy()
    orig_flow = flow2rgb(flow_label[image_index].cpu().detach().numpy(), max_value=None)
    pred_flow = flow2rgb(upsampled_flow[image_index], max_value=None)
    # stack original image, ground-truth flow, and predicted flow side by side
    concat_image = np.concatenate([orig_image, orig_flow, pred_flow], 1)
    concat_image = (concat_image * 255).astype(np.uint8)
    concat_image = concat_image.transpose(2, 0, 1)
    self.logger.tb.add_image(string + "predicted_flow", concat_image, global_cnt)
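# flow2rgb is used throughout this file but defined elsewhere. A hedged sketch of
# what it is assumed to do (map a [2, H, W] flow field to a float RGB array in
# [0, 1], normalized either by max_value or by the flow's own maximum magnitude);
# the repo's utility may differ in color coding or output layout:
def flow2rgb_sketch(flow_map, max_value=None):
    # flow_map: numpy array [2, H, W] holding the (u, v) components
    _, h, w = flow_map.shape
    rgb_map = np.ones((3, h, w), dtype=np.float32)
    if max_value is not None:
        normalized = flow_map / max_value
    else:
        normalized = flow_map / (np.abs(flow_map).max() + 1e-8)
    rgb_map[0] += normalized[0]                          # u tints red
    rgb_map[1] -= 0.5 * (normalized[0] + normalized[1])  # mixed component in green
    rgb_map[2] += normalized[1]                          # v tints blue
    return rgb_map.clip(0, 1)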
def plot_flow(flow, flow_gt):
    # plot predicted flow next to ground truth
    f, (ax1, ax2) = plt.subplots(1, 2)
    ax1.axis("off")
    ax1.set_title("Flow GT")
    ax1.imshow(flow2rgb(torch2numpy(flow_gt.squeeze(0))))
    ax2.axis("off")
    ax2.set_title("Flow")
    ax2.imshow(flow2rgb(torch2numpy(flow.squeeze(0))))
    plt.show()
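# Example usage of plot_flow (illustrative only; assumes the 4-D [1, 2, H, W]
# shape implied by the squeeze(0) calls above, and that torch2numpy converts a
# tensor to a numpy array):
#     example_pred = torch.randn(1, 2, 64, 64)
#     example_gt = torch.randn(1, 2, 64, 64)
#     plot_flow(example_pred, example_gt)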
def val(self, nb_epoch):
    self.model.eval()
    # if self.val_loader is None: return self.test()
    valstream = tqdm(self.dataloader.val())
    self.avg_loss = AverageMeter()
    self.avg_epe = AverageMeter()
    valstream.set_description('VALIDATING')
    with torch.no_grad():
        for i, data in enumerate(valstream):
            frame = data['frame'].to(self.device)
            flow = data['flow'].cpu()
            finalflow = self.model(frame)
            occlu_final, frame_final = self.warpframes(*finalflow, frame)
            loss = self.getcost(*frame_final, *occlu_final, frame)
            eper_final = self.epe(flow.cpu().detach(), finalflow[1].cpu().detach())
            self.avg_loss.update(loss.item(), i + 1)
            self.avg_epe.update(eper_final.item(), i + 1)
            self.writer.add_scalar('Loss/val', self.avg_loss.avg, self.global_step)
            self.writer.add_scalar('EPE/val', self.avg_epe.avg, self.global_step)
            # keep the last batch's backward-flow outputs for logging
            fb_frame_final = frame_final[1]
            fb_final = finalflow[1]
            fb_occlu_final = occlu_final[1]
    valstream.close()
    self.val_end({'VLloss': self.avg_loss.avg,
                  'VLepe': self.avg_epe.avg,
                  'epoch': nb_epoch,
                  'pred_frame': fb_frame_final[0, :, 0:2, :].permute(1, 0, 2, 3),
                  'gt_frame': frame[0, :, 0:2, :].permute(1, 0, 2, 3),
                  'pred_flow': flow2rgb(fb_final[0, :, 0:2, :].permute(1, 0, 2, 3), False),
                  'gt_flow': flow2rgb(flow[0, :, 0:2, :].permute(1, 0, 2, 3), False),
                  'pred_occ': fb_occlu_final[0, :, 0:2, :].permute(1, 0, 2, 3),
                  'gt_occ': data['occlusion'][0, :, 0:2, :].permute(1, 0, 2, 3)})
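# self.epe is not shown in this file. A minimal sketch of average end-point error,
# the standard optical-flow metric (mean L2 distance between predicted and
# ground-truth flow vectors); the project's version may reduce differently:
def epe_sketch(flow_gt, flow_pred):
    # tensors shaped [B, 2, H, W] (or [B, 2, T, H, W]); the norm is taken over
    # the flow-component dimension
    return torch.norm(flow_gt - flow_pred, p=2, dim=1).mean()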
def test(self, nb_epoch):
    self.model.eval()
    teststream = tqdm(self.dataloader.test())
    self.avg_loss = AverageMeter()
    teststream.set_description('TESTING')
    with torch.no_grad():
        for i, data in enumerate(teststream):
            # move input to the model's device, as val() and train() do
            frame = data['frame'].to(self.device)
            finalflow = self.model(frame)
            occlu_final, frame_final = self.warpframes(*finalflow, frame)
            loss = self.getcost(*frame_final, *occlu_final, frame)
            self.avg_loss.update(loss.item(), i + 1)
            self.writer.add_scalar('Loss/test', self.avg_loss.avg, self.global_step)
            # keep the last batch's backward-flow outputs for logging
            fb_frame_final = frame_final[1]
            fb_final = finalflow[1]
            fb_occlu_final = occlu_final[1]
    teststream.close()
    self.test_end({'VLloss': self.avg_loss.avg,
                   'epoch': nb_epoch,
                   'pred_frame': fb_frame_final[0, :, 0:4, :].permute(1, 0, 2, 3),
                   'gt_frame': frame[0, :, 0:4, :].permute(1, 0, 2, 3),
                   'pred_flow': flow2rgb(fb_final[0, :, 0:4, :].permute(1, 0, 2, 3), False),
                   'pred_occ': fb_occlu_final[0, :, 0:4, :].permute(1, 0, 2, 3)})
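# self.getcost is opaque in this file. Unsupervised-flow objectives typically pair
# an occlusion-weighted photometric term with smoothness; a hedged sketch of the
# photometric piece only, using a Charbonnier penalty (the names
# photometric_cost_sketch, eps and q are illustrative, not the repo's):
def photometric_cost_sketch(warped, target, occ_mask, eps=1e-3, q=0.4):
    # occ_mask ~ 1 where a pixel is visible in both frames, so occluded pixels
    # contribute nothing to the residual
    diff = (warped - target) * occ_mask
    return ((diff ** 2 + eps ** 2) ** q).mean()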
def test(val_loader, disp_net, mask_net, pose_net, flow_net, tb_writer, global_vars_dict=None):
    # data preparation
    device = global_vars_dict['device']
    n_iter_val = global_vars_dict['n_iter_val']
    args = global_vars_dict['args']

    data_time = AverageMeter()

    # to eval model
    disp_net.eval()
    pose_net.eval()
    mask_net.eval()
    flow_net.eval()

    end = time.time()
    poses = np.zeros(((len(val_loader) - 1) * 1 * (args.sequence_length - 1), 6))  # init
    disp_list = []
    flow_list = []
    mask_list = []

    # 3. validation cycle
    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv) in tqdm(enumerate(val_loader)):
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics, intrinsics_inv = intrinsics.to(device), intrinsics_inv.to(device)

        # 3.1 forward pass
        # disp
        disp = disp_net(tgt_img)
        if args.spatial_normalize:
            disp = spatial_normalize(disp)
        depth = 1 / disp

        # pose
        pose = pose_net(tgt_img, ref_imgs)

        # flow: build flow between the target frame and its neighbours
        if args.flownet == 'Back2Future':
            flow_fwd, flow_bwd, _ = flow_net(tgt_img, ref_imgs[1:3])
        elif args.flownet == 'FlowNetC6':
            flow_fwd = flow_net(tgt_img, ref_imgs[2])
            flow_bwd = flow_net(tgt_img, ref_imgs[1])
        # flow_fwd: [B, 2, H, W]

        # camera-induced (rigid background) flow: tensor [b, 2, h, w]
        flow_cam = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics, intrinsics_inv)
        flows_cam_fwd = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics, intrinsics_inv)
        flows_cam_bwd = pose2flow(depth.squeeze(1), pose[:, 1], intrinsics, intrinsics_inv)

        # exp_masks_target = consensus_exp_masks(flows_cam_fwd, flows_cam_bwd, flow_fwd, flow_bwd, tgt_img,
        #                                        ref_imgs[2], ref_imgs[1], wssim=args.wssim, wrig=args.wrig,
        #                                        ws=args.smooth_loss_weight)
        rigidity_mask_fwd = (flows_cam_fwd - flow_fwd).abs()  # [b, 2, h, w]
        rigidity_mask_bwd = (flows_cam_bwd - flow_bwd).abs()

        # 4. explainability mask (valid region? 4 channels??)
        explainability_mask = mask_net(tgt_img, ref_imgs)
        # list(5): item: tensor [4, 4, 128, 512] ... [4, 4, 4, 16], values ~0.33-0.63

        end = time.time()

        # 3.4 check log: inspect the forward-pass outputs
        # 2. disp
        disp_to_show = tensor2array(disp[0].cpu(), max_value=None, colormap='bone')  # [1, h, w], ~0.5-10
        tb_writer.add_image('Disp/disp0', disp_to_show, i)
        disp_list.append(disp_to_show)
        if i == 0:
            disp_arr = np.expand_dims(disp_to_show, axis=0)
        else:
            disp_to_show = np.expand_dims(disp_to_show, axis=0)
            disp_arr = np.concatenate([disp_arr, disp_to_show], 0)

        # 3. flow
        tb_writer.add_image('Flow/Flow Output', flow2rgb(flow_fwd[0], max_value=6), i)
        tb_writer.add_image('Flow/cam_Flow Output', flow2rgb(flow_cam[0], max_value=6), i)
        tb_writer.add_image('Flow/rigid_Flow Output', flow2rgb(rigidity_mask_fwd[0], max_value=6), i)
        tb_writer.add_image('Flow/rigidity_mask_fwd', flow2rgb(rigidity_mask_fwd[0], max_value=6), i)
        flow_list.append(flow2rgb(flow_fwd[0], max_value=6))

        # 4. mask
        tb_writer.add_image('Mask /mask0', tensor2array(explainability_mask[0][0], max_value=None, colormap='magma'), i)
        # tb_writer.add_image('Mask Output/mask1 sample{}'.format(i), tensor2array(explainability_mask[1][0], max_value=None, colormap='magma'), epoch)
        # tb_writer.add_image('Mask Output/mask2 sample{}'.format(i), tensor2array(explainability_mask[2][0], max_value=None, colormap='magma'), epoch)
        # tb_writer.add_image('Mask Output/mask3 sample{}'.format(i), tensor2array(explainability_mask[3][0], max_value=None, colormap='magma'), epoch)
        mask_list.append(tensor2array(explainability_mask[0][0], max_value=None, colormap='magma'))

    # return disp_list, disp_arr, flow_list, mask_list
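# spatial_normalize is imported from elsewhere in the repo. A hedged sketch of a
# common definition (dividing each disparity map by its own spatial mean, which
# guards against global scale drift in unsupervised depth training); the actual
# implementation may normalize differently:
def spatial_normalize_sketch(disp):
    # disp: [B, 1, H, W]
    return disp / disp.mean(dim=(2, 3), keepdim=True).clamp(min=1e-8)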
def train(self, nb_epoch):
    trainstream = tqdm(self.dataloader.train())
    self.avg_loss = AverageMeter()
    self.avg_epe = AverageMeter()
    self.model.train()
    for i, data in enumerate(trainstream):
        self.global_step += 1
        trainstream.set_description('TRAINING')

        # GET X and Frame 2
        # wdt = data['displacement'].to(self.device)
        frame = data['frame'].to(self.device)
        flow = data['flow'].cpu()
        # frame.requires_grad = True
        flow.requires_grad = False
        """
        NOTE: THIS MUST BE ADJUSTED AT THE DATA LOADER SIDE
        torch.Size([1, 2, 9, 436, 1024]) -> finalflow size
        torch.Size([1, 2, 9, 108, 256]) -> pyraflow1 size
        torch.Size([1, 2, 9, 54, 128]) -> pyraflow2 size
        torch.Size([1, 2, 9, 27, 64]) -> pyraflow3 size
        """
        pyra1_frame = data['pyra1_frame'].to(self.device)
        # pyra1_frame.requires_grad = True
        pyra2_frame = data['pyra2_frame'].to(self.device)
        # pyra2_frame.requires_grad = True
        laten_frame = data['laten_frame'].to(self.device)
        # laten_frame.requires_grad = True

        self.optimizer.zero_grad()

        # forward
        with torch.set_grad_enabled(True):
            finalflow, pyraflow1, pyraflow2, latenflow = self.model(frame)
            occlu_final, frame_final = self.warpframes(*finalflow, frame)
            occlu_pyra1, frame_pyra1 = self.warpframes(*pyraflow1, pyra1_frame)
            occlu_pyra2, frame_pyra2 = self.warpframes(*pyraflow2, pyra2_frame)
            occlu_laten, frame_laten = self.warpframes(*latenflow, laten_frame)
            # print(occlu_final[0].shape)
            cost_final = self.getcost(*frame_final, *occlu_final, frame)
            cost_pyra1 = self.getcost(*frame_pyra1, *occlu_pyra1, pyra1_frame)
            cost_pyra2 = self.getcost(*frame_pyra2, *occlu_pyra2, pyra2_frame)
            cost_laten = self.getcost(*frame_laten, *occlu_laten, laten_frame)
            eper_final = self.epe(finalflow[1].cpu().detach(), flow.cpu().detach())
            loss = cost_final + cost_pyra1 + cost_pyra2 + cost_laten
            self.avg_loss.update(loss.item(), i + 1)
            self.avg_epe.update(eper_final.item(), i + 1)
            loss.backward()
            self.optimizer.step()
        self.writer.add_scalar('Loss/train', self.avg_loss.avg, self.global_step)
        self.writer.add_scalar('EPE/train', self.avg_epe.avg, self.global_step)
        trainstream.set_postfix({'epoch': nb_epoch,
                                 'loss': self.avg_loss.avg,
                                 'epe': self.avg_epe.avg})
    self.scheduler.step(loss)
    trainstream.close()
    fb_frame_final = frame_final[1]
    fb_final = finalflow[1]
    fb_occlu_final = occlu_final[1]
    self.writer.add_histogram('REAL/flow_u', flow[0, 0, :].view(-1), nb_epoch)
    self.writer.add_histogram('REAL/flow_v', flow[0, 1, :].view(-1), nb_epoch)
    self.writer.add_histogram('PRED/flow_u_ff', finalflow[0][0, 0, :].view(-1), nb_epoch)
    self.writer.add_histogram('PRED/flow_v_ff', finalflow[0][0, 1, :].view(-1), nb_epoch)
    self.writer.add_histogram('PRED/flow_u_fb', finalflow[1][0, 0, :].view(-1), nb_epoch)
    self.writer.add_histogram('PRED/flow_v_fb', finalflow[1][0, 1, :].view(-1), nb_epoch)
    self.writer.add_histogram('REAL/occ', data['occlusion'][0, :].view(-1), nb_epoch)
    self.writer.add_histogram('PRED/occ_ff', occlu_final[0][0, :].view(-1), nb_epoch)
    self.writer.add_histogram('PRED/occ_fb', occlu_final[1][0, :].view(-1), nb_epoch)
    return self.train_epoch_end({'TRloss': self.avg_loss.avg,
                                 'epoch': nb_epoch,
                                 'pred_frame': fb_frame_final[0, :, 0:4, :].permute(1, 0, 2, 3),
                                 'gt_frame': frame[0, :, 0:4, :].permute(1, 0, 2, 3),
                                 'pred_flow': flow2rgb(fb_final[0, :, 0:4, :].permute(1, 0, 2, 3), False),
                                 'gt_flow': flow2rgb(flow[0, :, 0:4, :].permute(1, 0, 2, 3), False),
                                 'pred_occ': fb_occlu_final[0, :, 0:4, :].permute(1, 0, 2, 3),
                                 'gt_occ': data['occlusion'][0, :, 0:4, :].permute(1, 0, 2, 3)})
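# AverageMeter is the usual running-average helper, consistent with the
# update(value, n) calls above; a minimal sketch (the repo's version evidently
# also takes a `precision` argument and can track vector averages, given the
# AverageMeter(precision=4) and losses.avg[0] usages elsewhere in this section):
class AverageMeterSketch:
    def __init__(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        # n weights the sample, e.g. a batch size
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count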
def validate_without_gt(val_loader, disp_net, pose_net, mask_net, flow_net, epoch,
                        logger, tb_writer, nb_writers, global_vars_dict=None):
    # data preparation
    device = global_vars_dict['device']
    n_iter_val = global_vars_dict['n_iter_val']
    args = global_vars_dict['args']

    show_samples = copy.deepcopy(args.show_samples)
    for i in range(len(show_samples)):
        show_samples[i] *= len(val_loader)
        show_samples[i] = show_samples[i] // 1  # floor to a whole-number sample index

    batch_time = AverageMeter()
    data_time = AverageMeter()
    log_outputs = nb_writers > 0
    losses = AverageMeter(precision=4)
    w1, w2, w3, w4 = (args.cam_photo_loss_weight, args.mask_loss_weight,
                      args.smooth_loss_weight, args.flow_photo_loss_weight)
    w5 = args.consensus_loss_weight
    loss_camera = photometric_reconstruction_loss
    loss_flow = photometric_flow_loss

    # to eval model
    disp_net.eval()
    pose_net.eval()
    mask_net.eval()
    flow_net.eval()

    end = time.time()
    poses = np.zeros(((len(val_loader) - 1) * 1 * (args.sequence_length - 1), 6))  # init

    # 3. validation cycle
    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv) in enumerate(val_loader):
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics, intrinsics_inv = intrinsics.to(device), intrinsics_inv.to(device)

        # 3.1 forward pass
        # disp
        disp = disp_net(tgt_img)
        if args.spatial_normalize:
            disp = spatial_normalize(disp)
        depth = 1 / disp

        # pose
        pose = pose_net(tgt_img, ref_imgs)  # [b,3,h,w]; list

        # flow: build flow between the target frame and its neighbours
        if args.flownet == 'Back2Future':
            flow_fwd, flow_bwd, _ = flow_net(tgt_img, ref_imgs[1:3])
        elif args.flownet == 'FlowNetC6':
            flow_fwd = flow_net(tgt_img, ref_imgs[2])
            flow_bwd = flow_net(tgt_img, ref_imgs[1])

        flow_cam = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics, intrinsics_inv)
        flows_cam_fwd = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics, intrinsics_inv)
        flows_cam_bwd = pose2flow(depth.squeeze(1), pose[:, 1], intrinsics, intrinsics_inv)

        exp_masks_target = consensus_exp_masks(flows_cam_fwd, flows_cam_bwd, flow_fwd,
                                               flow_bwd, tgt_img, ref_imgs[2], ref_imgs[1],
                                               wssim=args.wssim, wrig=args.wrig,
                                               ws=args.smooth_loss_weight)

        no_rigid_flow = flow_fwd - flows_cam_fwd
        rigidity_mask_fwd = (flows_cam_fwd - flow_fwd).abs()  # [b, 2, h, w]
        rigidity_mask_bwd = (flows_cam_bwd - flow_bwd).abs()

        # 4. explainability mask (valid region? 4 channels??)
        explainability_mask = mask_net(tgt_img, ref_imgs)
        # list(5): item: tensor [4, 4, 128, 512] ... [4, 4, 4, 16], values ~0.33-0.63

        if args.joint_mask_for_depth:  # false
            explainability_mask_for_depth = explainability_mask
            # explainability_mask_for_depth = compute_joint_mask_for_depth(explainability_mask, rigidity_mask_bwd,
            #                                                              rigidity_mask_fwd, THRESH=args.THRESH)
        else:
            explainability_mask_for_depth = explainability_mask  # change

        if args.no_non_rigid_mask:
            flow_exp_mask = None
            if args.DEBUG:
                print('Using no masks for flow')
        else:
            flow_exp_mask = 1 - explainability_mask[:, 1:3]

        # 3.2 loss computation
        if w1 > 0:
            loss_1 = loss_camera(tgt_img, ref_imgs, intrinsics, intrinsics_inv, depth,
                                 explainability_mask_for_depth, pose,
                                 lambda_oob=args.lambda_oob, qch=args.qch, wssim=args.wssim)
        else:
            loss_1 = torch.tensor([0.]).to(device)

        # E_M
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask)
            # + 0.2 * gaussian_explainability_loss(explainability_mask)
        else:
            loss_2 = 0

        # if args.smoothness_type == "regular":
        if w3 > 0:
            loss_3 = (smooth_loss(depth) + smooth_loss(explainability_mask)
                      + smooth_loss(flow_fwd) + smooth_loss(flow_bwd))
        else:
            loss_3 = torch.tensor([0.]).to(device)

        if w4 > 0:
            loss_4 = loss_flow(tgt_img, ref_imgs[1:3], [flow_bwd, flow_fwd], flow_exp_mask,
                               lambda_oob=args.lambda_oob, qch=args.qch, wssim=args.wssim)
        else:
            loss_4 = torch.tensor([0.]).to(device)

        if w5 > 0:
            loss_5 = consensus_depth_flow_mask(explainability_mask, rigidity_mask_bwd,
                                               rigidity_mask_fwd, exp_masks_target,
                                               exp_masks_target, THRESH=args.THRESH,
                                               wbce=args.wbce)
        else:
            loss_5 = torch.tensor([0.]).to(device)

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3 + w4 * loss_4 + w5 * loss_5

        # 3.3 data update
        losses.update(loss.item(), args.batch_size)
        batch_time.update(time.time() - end)
        end = time.time()

        # 3.4 check log: inspect the forward-pass outputs
        if args.img_freq > 0 and i in show_samples:  # output_writers list(3)
            if epoch == 0:  # validation before any training, to get a baseline for the untrained networks
                # 1. img
                # only runs once; note that for ref_imgs, axis 0 of each tensor
                # indexes the batch, while the list index selects the adjacent frame!
                tb_writer.add_image('epoch 0 Input/sample{}(img{} to img{})'.format(
                    i, i + 1, i + args.sequence_length), tensor2array(ref_imgs[0][0]), 0)
                tb_writer.add_image('epoch 0 Input/sample{}(img{} to img{})'.format(
                    i, i + 1, i + args.sequence_length), tensor2array(ref_imgs[1][0]), 1)
                tb_writer.add_image('epoch 0 Input/sample{}(img{} to img{})'.format(
                    i, i + 1, i + args.sequence_length), tensor2array(tgt_img[0]), 2)
                tb_writer.add_image('epoch 0 Input/sample{}(img{} to img{})'.format(
                    i, i + 1, i + args.sequence_length), tensor2array(ref_imgs[2][0]), 3)
                tb_writer.add_image('epoch 0 Input/sample{}(img{} to img{})'.format(
                    i, i + 1, i + args.sequence_length), tensor2array(ref_imgs[3][0]), 4)

                depth_to_show = depth[0].cpu()  # tensor [1, h, w], ~0.5-10
                tb_writer.add_image('Disp Output/sample{}'.format(i),
                                    tensor2array(depth_to_show, max_value=None, colormap='bone'), 0)
            else:
                # 2. disp
                depth_to_show = disp[0].cpu()  # tensor [1, h, w], ~0.5-10
                tb_writer.add_image('Disp Output/sample{}'.format(i),
                                    tensor2array(depth_to_show, max_value=None, colormap='bone'), epoch)

                # 3. flow
                tb_writer.add_image('Flow/Flow Output sample {}'.format(i),
                                    flow2rgb(flow_fwd[0], max_value=6), epoch)
                tb_writer.add_image('Flow/cam_Flow Output sample {}'.format(i),
                                    flow2rgb(flow_cam[0], max_value=6), epoch)
                tb_writer.add_image('Flow/no rigid flow Output sample {}'.format(i),
                                    flow2rgb(no_rigid_flow[0], max_value=6), epoch)
                tb_writer.add_image('Flow/rigidity_mask_fwd{}'.format(i),
                                    flow2rgb(rigidity_mask_fwd[0], max_value=6), epoch)

                # 4. mask
                tb_writer.add_image('Mask Output/mask0 sample{}'.format(i),
                                    tensor2array(explainability_mask[0][0], max_value=None, colormap='magma'), epoch)
                # tb_writer.add_image('Mask Output/mask1 sample{}'.format(i), tensor2array(explainability_mask[1][0], max_value=None, colormap='magma'), epoch)
                # tb_writer.add_image('Mask Output/mask2 sample{}'.format(i), tensor2array(explainability_mask[2][0], max_value=None, colormap='magma'), epoch)
                # tb_writer.add_image('Mask Output/mask3 sample{}'.format(i), tensor2array(explainability_mask[3][0], max_value=None, colormap='magma'), epoch)
                tb_writer.add_image('Mask Output/exp_masks_target sample{}'.format(i),
                                    tensor2array(exp_masks_target[0][0], max_value=None, colormap='magma'), epoch)
                # tb_writer.add_image('Mask Output/mask0 sample{}'.format(i),
                #                     tensor2array(explainability_mask[0][0], max_value=None, colormap='magma'), epoch)

        # output_writers[index].add_image('val Depth Output', tensor2array(depth.data[0].cpu(), max_value=10), epoch)
        # errors.update(compute_errors(depth, output_depth.data.squeeze(1)))

        # add scalars
        if args.scalar_freq > 0 and n_iter_val % args.scalar_freq == 0:
            tb_writer.add_scalar('val/E_R', loss_1.item(), n_iter_val)
            if w2 > 0:
                tb_writer.add_scalar('val/E_M', loss_2.item(), n_iter_val)
            tb_writer.add_scalar('val/E_S', loss_3.item(), n_iter_val)
            tb_writer.add_scalar('val/E_F', loss_4.item(), n_iter_val)
            tb_writer.add_scalar('val/E_C', loss_5.item(), n_iter_val)
            tb_writer.add_scalar('val/total_loss', loss.item(), n_iter_val)

        # terminal output
        if args.log_terminal:
            logger.valid_bar.update(i + 1)  # progress within the current epoch
            if i % args.print_freq == 0:
                logger.valid_bar_writer.write('Valid: Time {} Data {} Loss {}'.format(
                    batch_time, data_time, losses))

        n_iter_val += 1

    global_vars_dict['n_iter_val'] = n_iter_val
    return losses.avg[0]  # epoch validation loss
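# explainability_loss is called above but defined elsewhere. In SfMLearner-style
# code it is a binary cross-entropy pulling every predicted mask toward 1, so the
# network pays a cost for discounting pixels; a hedged sketch of that form:
def explainability_loss_sketch(mask):
    if not isinstance(mask, (tuple, list)):
        mask = [mask]
    loss = 0
    for mask_scaled in mask:
        ones = torch.ones_like(mask_scaled)
        loss = loss + torch.nn.functional.binary_cross_entropy(mask_scaled, ones)
    return loss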
def train(self, nb_epoch):
    trainstream = tqdm(self.dataloader.train())
    self.avg_loss = AverageMeter()
    self.avg_epe = AverageMeter()
    self.model.train()
    for i, data in enumerate(trainstream):
        self.global_step += 1
        trainstream.set_description('TRAINING')

        # GET X and Frame 2
        # wdt = data['displacement'].to(self.device)
        frame = data['frame'].to(self.device)
        pyra1_frame = F.interpolate(frame, size=(108, 256))
        pyra2_frame = F.interpolate(frame, size=(54, 128))
        laten_frame = F.interpolate(frame, size=(27, 64))
        flow = data['flow'].cpu()
        # frame.requires_grad = True
        flow.requires_grad = False
        """
        NOTE: THIS MUST BE ADJUSTED AT THE DATA LOADER SIDE
        torch.Size([1, 2, 9, 436, 1024]) -> finalflow size
        torch.Size([1, 2, 9, 108, 256]) -> pyraflow1 size
        torch.Size([1, 2, 9, 54, 128]) -> pyraflow2 size
        torch.Size([1, 2, 9, 27, 64]) -> pyraflow3 size
        """
        # pyra1_frame = data['pyra1_frame'].to(self.device)
        # pyra1_frame.requires_grad = True
        # pyra2_frame = data['pyra2_frame'].to(self.device)
        # pyra2_frame.requires_grad = True
        # laten_frame = data['laten_frame'].to(self.device)
        # laten_frame.requires_grad = True
        # ff = data['ff'].to(self.device)
        # fb = data['fb'].to(self.device)
        motion = data['motion'].to(self.device)
        subf = data['subf'].to(self.device)

        # forward
        with torch.set_grad_enabled(True):
            self.optimizer.zero_grad()
            # Earlier per-scale variant, kept for reference:
            # flows = self.model(frame, ff, fb)
            # for i, finalflow in enumerate(flows):
            #     flowdim = finalflow[0].size()[2:]
            #     framedim = frame.size()[2:]
            #     if flowdim != framedim:
            #         frame_ = F.interpolate(frame, size=(finalflow[0].size(2), finalflow[0].size(3)))
            #         occlu_final, frame_final = self.warpframes(*finalflow, frame_)
            #         loss = self.getcost(*frame_final, *occlu_final, frame_)
            #     else:
            #         occlu_final, frame_final = self.warpframes(*finalflow, frame)
            #         loss = self.getcost(*frame_final, *occlu_final, frame)
            #     self.optimizer.zero_grad()
            #     if i < 3:
            #         loss.backward(retain_graph=True)
            #     else:
            #         loss.backward()
            #     self.writer.add_figure('Activations', plot_grad_flow(self.model.named_parameters()),
            #                            self.global_step)
            #     self.optimizer.step()
            #     self.avg_loss.update(loss.item(), i + 1)
            #     eper_final = self.epe(finalflow[1].detach(), flow.detach())
            #     self.avg_epe.update(eper_final.item(), i + 1)
            #     self.writer.add_scalar('Loss/train', self.avg_loss.avg, self.global_step)
            #     self.writer.add_scalar('EPE/train', self.avg_epe.avg, self.global_step)
            #     trainstream.set_postfix({'epoch': nb_epoch,
            #                              'loss': self.avg_loss.avg,
            #                              'epe': self.avg_epe.avg})

            # Different approach from the commented-out loop above
            # print(motion.shape, frame.shape)
            latenflow, pyraflow2, pyraflow1, finalflow = self.model(subf, motion)
            occlu_final, frame_final = self.warpframes(*finalflow, frame)
            occlu_pyra1, frame_pyra1 = self.warpframes(*pyraflow1, pyra1_frame)
            occlu_pyra2, frame_pyra2 = self.warpframes(*pyraflow2, pyra2_frame)
            occlu_laten, frame_laten = self.warpframes(*latenflow, laten_frame)
            # print(occlu_final[0].shape)
            cost_final = self.getcost(*frame_final, *occlu_final, frame, *finalflow)
            cost_pyra1 = self.getcost(*frame_pyra1, *occlu_pyra1, pyra1_frame, *pyraflow1)
            cost_pyra2 = self.getcost(*frame_pyra2, *occlu_pyra2, pyra2_frame, *pyraflow2)
            cost_laten = self.getcost(*frame_laten, *occlu_laten, laten_frame, *latenflow)
            loss = cost_final + cost_pyra1 + cost_pyra2 + cost_laten
            # self.optimizer.zero_grad()
            loss.backward()
            self.avg_loss.update(loss.item(), i + 1)
            self.writer.add_figure('Activations', plot_grad_flow(self.model.named_parameters()),
                                   self.global_step)
            self.optimizer.step()
            eper_final = self.epe(finalflow[1].detach(), flow.detach())
            self.avg_epe.update(eper_final.item(), i + 1)
        self.writer.add_scalar('Loss/train', self.avg_loss.avg, self.global_step)
        self.writer.add_scalar('EPE/train', self.avg_epe.avg, self.global_step)
        trainstream.set_postfix({'epoch': nb_epoch,
                                 'loss': self.avg_loss.avg,
                                 'epe': self.avg_epe.avg})
    self.scheduler.step(loss)
    trainstream.close()
    fb_frame_final = frame_final[1]
    fb_final = finalflow[1]
    fb_occlu_final = occlu_final[1]
    self.writer.add_histogram('REAL/flow_u', flow[0, 0].view(-1), nb_epoch)
    self.writer.add_histogram('REAL/flow_v', flow[0, 1].view(-1), nb_epoch)
    self.writer.add_histogram('PRED/flow_u_ff', finalflow[0][0, 0].view(-1), nb_epoch)
    self.writer.add_histogram('PRED/flow_v_ff', finalflow[0][0, 1].view(-1), nb_epoch)
    self.writer.add_histogram('PRED/flow_u_fb', finalflow[1][0, 0].view(-1), nb_epoch)
    self.writer.add_histogram('PRED/flow_v_fb', finalflow[1][0, 1].view(-1), nb_epoch)
    self.writer.add_histogram('REAL/occ', data['occlusion'][0].view(-1), nb_epoch)
    self.writer.add_histogram('PRED/occ_ff', occlu_final[0][0].view(-1), nb_epoch)
    self.writer.add_histogram('PRED/occ_fb', occlu_final[1][0].view(-1), nb_epoch)
    return self.train_epoch_end({'TRloss': self.avg_loss.avg,
                                 'epoch': nb_epoch,
                                 'pred_frame': fb_frame_final[0:4],
                                 'gt_frame': frame[0:4, :3],
                                 'subf': subf[0:4, :3],
                                 'pred_flow': flow2rgb(fb_final[0:4], False),
                                 'gt_flow': flow2rgb(flow[0:4], False),
                                 # 'ff_in': flow2rgb(ff[0:4], False),
                                 # 'fb_in': flow2rgb(fb[0:4], False),
                                 'pred_occ': 1. - fb_occlu_final[0:4],
                                 'gt_occ': data['occlusion'][0:4]})
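# self.warpframes is not defined in this file. The usual building block behind
# such photometric losses is a backward warp with F.grid_sample; a minimal sketch
# under that assumption (the project's warpframes also derives occlusion maps,
# which this omits):
def warp_sketch(img, flow):
    # img: [B, C, H, W]; flow: [B, 2, H, W] in pixels; samples img at x + flow
    b, _, h, w = flow.shape
    ys, xs = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
    grid = torch.stack((xs, ys), dim=0).float().to(img.device)  # [2, H, W], (x, y)
    new_pos = grid.unsqueeze(0) + flow
    # normalize to [-1, 1] for grid_sample (x by width, y by height)
    new_pos[:, 0] = 2.0 * new_pos[:, 0] / max(w - 1, 1) - 1.0
    new_pos[:, 1] = 2.0 * new_pos[:, 1] / max(h - 1, 1) - 1.0
    return torch.nn.functional.grid_sample(img, new_pos.permute(0, 2, 3, 1),
                                           align_corners=True)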