    def _record_image(self, image_packet, global_cnt, string=None):
        # Log a side-by-side panel (input image | ground-truth flow | predicted flow) to TensorBoard.
        if string is None:
            string = ""

        flow2, flow_label, image = image_packet
        image_index = 0

        b, c, h, w = flow_label.size()

        upsampled_flow = nn.functional.interpolate(flow2,
                                                   size=(h, w),
                                                   mode="bilinear",
                                                   align_corners=False)
        upsampled_flow = upsampled_flow.cpu().detach().numpy()
        orig_image = image[image_index].cpu().numpy()

        orig_flow = flow2rgb(flow_label[image_index].cpu().detach().numpy(),
                             max_value=None)
        pred_flow = flow2rgb(upsampled_flow[image_index], max_value=None)

        concat_image = np.concatenate([orig_image, orig_flow, pred_flow], 1)

        concat_image = concat_image * 255
        concat_image = concat_image.astype(np.uint8)
        concat_image = concat_image.transpose(2, 0, 1)

        self.logger.tb.add_image(string + "predicted_flow", concat_image,
                                 global_cnt)
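For context, flow2rgb itself is not part of this listing. The sketch below is a minimal guess at what such a helper does: it maps a [2, H, W] flow field to a [3, H, W] RGB array, normalizing by max_value when given and by the flow's own maximum magnitude otherwise. The helpers used by the repositories above may differ (some accept tensors or an extra boolean flag), so treat this only as an illustration.

import numpy as np

def flow2rgb_sketch(flow_map, max_value=None):
    # flow_map: array-like of shape [2, H, W] holding the (u, v) flow components.
    flow = np.asarray(flow_map, dtype=np.float32)
    norm = max_value if max_value is not None else max(float(np.abs(flow).max()), 1e-8)
    flow = flow / norm
    _, h, w = flow.shape
    rgb = np.ones((3, h, w), dtype=np.float32)
    rgb[0] += flow[0]                    # u component tints red
    rgb[1] -= 0.5 * (flow[0] + flow[1])  # a mix of u and v tints green
    rgb[2] += flow[1]                    # v component tints blue
    return rgb.clip(0, 1)                # [3, H, W], values in [0, 1]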
Example #2
def plot_flow(flow, flow_gt):
    # plot flow
    f, (ax1, ax2) = plt.subplots(1, 2)
    ax1.axis("off")
    ax1.set_title("Flow GT")
    ax1.imshow(flow2rgb(torch2numpy(flow_gt.squeeze(0))))
    ax2.axis("off")
    ax2.set_title("Flow")
    ax2.imshow(flow2rgb(torch2numpy(flow.squeeze(0))))
    plt.show()
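torch2numpy is not defined in the listing either; presumably it just detaches the tensor and moves it to host memory, as in the assumed one-liner below (any channel reordering is left to the caller).

import torch

def torch2numpy(tensor: torch.Tensor):
    # Detach from the autograd graph and move to CPU before handing data to numpy/matplotlib.
    return tensor.detach().cpu().numpy()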
Example #3
    def val(self, nb_epoch):
        self.model.eval()
        # if self.val_loader is None: return self.test()
        # DO VAL STUFF HERE
        valstream = tqdm(self.dataloader.val())
        self.avg_loss = AverageMeter()
        self.avg_epe = AverageMeter()
        valstream.set_description('VALIDATING')
        with torch.no_grad():
            for i, data in enumerate(valstream):
                frame = data['frame'].to(self.device)
                flow = data['flow'].cpu()
                finalflow = self.model(frame)
                occlu_final, frame_final = self.warpframes(*finalflow, frame)
                loss = self.getcost(*frame_final, *occlu_final, frame)
                eper_final = self.epe(flow.cpu().detach(),
                                      finalflow[1].cpu().detach())
                self.avg_loss.update(loss.item(), i + 1)
                self.avg_epe.update(eper_final.item(), i + 1)

        self.writer.add_scalar('Loss/val', self.avg_loss.avg, self.global_step)

        self.writer.add_scalar('EPE/val', self.avg_epe.avg, self.global_step)

        fb_frame_final = frame_final[1]
        fb_final = finalflow[1]
        fb_occlu_final = occlu_final[1]

        valstream.close()

        self.val_end({
            'VLloss': self.avg_loss.avg,
            'VLepe': self.avg_epe.avg,
            'epoch': nb_epoch,
            'pred_frame': fb_frame_final[0, :, 0:2, :].permute(1, 0, 2, 3),
            'gt_frame': frame[0, :, 0:2, :].permute(1, 0, 2, 3),
            'pred_flow': flow2rgb(fb_final[0, :, 0:2, :].permute(1, 0, 2, 3), False),
            'gt_flow': flow2rgb(flow[0, :, 0:2, :].permute(1, 0, 2, 3), False),
            'pred_occ': fb_occlu_final[0, :, 0:2, :].permute(1, 0, 2, 3),
            'gt_occ': data['occlusion'][0, :, 0:2, :].permute(1, 0, 2, 3)
        })
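AverageMeter appears throughout these trainers as a running-average tracker. Below is a minimal sketch compatible with the update(value, n) / .avg usage above; the precision argument and list-valued .avg seen in a later example are omitted.

class AverageMeter:
    """Track the running average of a scalar metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val, self.sum, self.count, self.avg = 0.0, 0.0, 0, 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count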
Example #4
    def test(self, nb_epoch):
        self.model.eval()
        teststream = tqdm(self.dataloader.test())
        self.avg_loss = AverageMeter()
        teststream.set_description('TESTING')
        with torch.no_grad():
            for i, data in enumerate(teststream):
                frame = data['frame'].to(self.device)
                finalflow = self.model(frame)

                occlu_final, frame_final = self.warpframes(*finalflow, frame)
                loss = self.getcost(*frame_final, *occlu_final, frame)

                self.avg_loss.update(loss.item(), i + 1)

        self.writer.add_scalar('Loss/test',
                               self.avg_loss.avg, self.global_step)

        fb_frame_final = frame_final[1]
        fb_final = finalflow[1]
        fb_occlu_final = occlu_final[1]

        teststream.close()

        self.test_end({'VLloss': self.avg_loss.avg,
                       'epoch': nb_epoch,
                       'pred_frame': fb_frame_final[0, :, 0:4, :].permute(1, 0, 2, 3),
                       'gt_frame': frame[0, :, 0:4, :].permute(1, 0, 2, 3),
                       'pred_flow': flow2rgb(fb_final[0, :, 0:4, :].permute(1, 0, 2, 3), False),
                       'pred_occ': fb_occlu_final[0, :, 0:4, :].permute(1, 0, 2, 3)})
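self.epe is also project-specific; it presumably computes the average end-point error, i.e. the mean Euclidean distance between predicted and ground-truth flow vectors. A minimal sketch under the assumption that the flow components sit on dimension 1:

import torch

def endpoint_error(flow_gt: torch.Tensor, flow_pred: torch.Tensor) -> torch.Tensor:
    # Mean L2 distance between flow vectors, averaged over batch, time and pixels.
    return torch.norm(flow_pred - flow_gt, p=2, dim=1).mean()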
Example #5
def test(val_loader, disp_net, mask_net, pose_net, flow_net, tb_writer, global_vars_dict=None):
    # data prepared
    device = global_vars_dict['device']
    n_iter_val = global_vars_dict['n_iter_val']
    args = global_vars_dict['args']


    data_time = AverageMeter()


    # switch all sub-networks to eval mode
    disp_net.eval()
    pose_net.eval()
    mask_net.eval()
    flow_net.eval()

    end = time.time()
    poses = np.zeros(((len(val_loader) - 1) * 1 * (args.sequence_length - 1), 6))  # init (unused here)

    disp_list = []

    flow_list = []
    mask_list = []

#3. validation cycle
    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv) in tqdm(enumerate(val_loader)):
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics, intrinsics_inv = intrinsics.to(device), intrinsics_inv.to(device)
    #3.1 forwardpass
        #disp
        disp = disp_net(tgt_img)
        if args.spatial_normalize:
            disp = spatial_normalize(disp)
        depth = 1 / disp

        #pose
        pose = pose_net(tgt_img, ref_imgs)
        # flow ----
        # build flow w.r.t. the previous and next frames
        if args.flownet == 'Back2Future':
            flow_fwd, flow_bwd, _ = flow_net(tgt_img, ref_imgs[1:3])
        elif args.flownet == 'FlowNetC6':
            flow_fwd = flow_net(tgt_img, ref_imgs[2])
            flow_bwd = flow_net(tgt_img, ref_imgs[1])
        #FLOW FWD [B,2,H,W]
        #flow cam :tensor[b,2,h,w]
        #flow_background
        flow_cam = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics, intrinsics_inv)

        flows_cam_fwd = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics, intrinsics_inv)
        flows_cam_bwd = pose2flow(depth.squeeze(1), pose[:, 1], intrinsics, intrinsics_inv)

        #exp_masks_target = consensus_exp_masks(flows_cam_fwd, flows_cam_bwd, flow_fwd, flow_bwd, tgt_img,
        #                                       ref_imgs[2], ref_imgs[1], wssim=args.wssim, wrig=args.wrig,
        #                                       ws=args.smooth_loss_weight)

        rigidity_mask_fwd = (flows_cam_fwd - flow_fwd).abs()#[b,2,h,w]
        rigidity_mask_bwd = (flows_cam_bwd - flow_bwd).abs()

        # mask
        # 4.explainability_mask(none)
        explainability_mask = mask_net(tgt_img, ref_imgs)  # valid regions? 4??

        # list(5):item:tensor:[4,4,128,512]...[4,4,4,16] value:[0.33~0.48~0.63]
        end = time.time()


        # 3.4 logging

        # inspect the forward-pass results
        # 2. disp
        disp_to_show = tensor2array(disp[0].cpu(), max_value=None, colormap='bone')  # disp_to_show: [1, h, w], values roughly 0.5~3.1~10
        tb_writer.add_image('Disp/disp0', disp_to_show, i)
        disp_list.append(disp_to_show)

        if i == 0:
            disp_arr = np.expand_dims(disp_to_show, axis=0)
        else:
            disp_to_show = np.expand_dims(disp_to_show, axis=0)
            disp_arr = np.concatenate([disp_arr, disp_to_show], 0)


        # 3. flow
        tb_writer.add_image('Flow/Flow Output', flow2rgb(flow_fwd[0], max_value=6), i)
        tb_writer.add_image('Flow/cam_Flow Output', flow2rgb(flow_cam[0], max_value=6), i)
        tb_writer.add_image('Flow/rigid_Flow Output', flow2rgb(rigidity_mask_fwd[0], max_value=6), i)
        tb_writer.add_image('Flow/rigidity_mask_fwd', flow2rgb(rigidity_mask_fwd[0], max_value=6), i)
        flow_list.append(flow2rgb(flow_fwd[0], max_value=6))
        # 4. mask
        tb_writer.add_image('Mask/mask0', tensor2array(explainability_mask[0][0], max_value=None, colormap='magma'), i)
        # tb_writer.add_image('Mask Output/mask1 sample{}'.format(i), tensor2array(explainability_mask[1][0], max_value=None, colormap='magma'), epoch)
        # tb_writer.add_image('Mask Output/mask2 sample{}'.format(i), tensor2array(explainability_mask[2][0], max_value=None, colormap='magma'), epoch)
        # tb_writer.add_image('Mask Output/mask3 sample{}'.format(i), tensor2array(explainability_mask[3][0], max_value=None, colormap='magma'), epoch)
        mask_list.append(tensor2array(explainability_mask[0][0], max_value=None, colormap='magma'))
    #

    return disp_list, disp_arr, flow_list, mask_list
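tensor2array comes from the surrounding SfMLearner-style codebase and converts a single-channel tensor into a colour-mapped [3, H, W] array that add_image can consume. Below is a rough, assumed sketch of that behaviour; the real helper handles more colormaps and edge cases.

import numpy as np
import matplotlib.pyplot as plt

def tensor2array_sketch(tensor, max_value=None, colormap='bone'):
    # tensor: [1, H, W] or [H, W]; returns a float [3, H, W] array in [0, 1].
    arr = tensor.detach().cpu().numpy().squeeze()
    if max_value is None:
        max_value = arr.max() if arr.max() > 0 else 1.0
    arr = np.clip(arr / max_value, 0, 1)
    colored = plt.get_cmap(colormap)(arr)[..., :3]  # [H, W, 3]
    return colored.transpose(2, 0, 1).astype(np.float32)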
Example #6
    def train(self, nb_epoch):
        trainstream = tqdm(self.dataloader.train())
        self.avg_loss = AverageMeter()
        self.avg_epe = AverageMeter()
        self.model.train()
        for i, data in enumerate(trainstream):
            self.global_step += 1
            trainstream.set_description('TRAINING')

            # GET X and Frame 2
            # wdt = data['displacement'].to(self.device)
            frame = data['frame'].to(self.device)
            flow = data['flow'].cpu()
            # frame.requires_grad = True
            flow.requires_grad = False
            """
            NOTE : THIS MUST BE ADJUSTED AT DATA LOADER SIDE 
            torch.Size([1, 2, 9, 436, 1024])    -> finalflow size
            torch.Size([1, 2, 9, 108, 256])     -> pyraflow1 size
            torch.Size([1, 2, 9, 54, 128])      -> pyraflow2 size
            torch.Size([1, 2, 9, 27, 64])       -> pyraflow3 size
            """
            pyra1_frame = data['pyra1_frame'].to(self.device)
            # pyra1_frame.requires_grad = True
            pyra2_frame = data['pyra2_frame'].to(self.device)
            # pyra2_frame.requires_grad = True
            laten_frame = data['laten_frame'].to(self.device)
            # laten_frame.requires_grad = True

            self.optimizer.zero_grad()
            # forward
            with torch.set_grad_enabled(True):
                finalflow, pyraflow1, pyraflow2, latenflow = self.model(frame)
                occlu_final, frame_final = self.warpframes(*finalflow, frame)
                occlu_pyra1, frame_pyra1 = self.warpframes(*pyraflow1, pyra1_frame)
                occlu_pyra2, frame_pyra2 = self.warpframes(*pyraflow2, pyra2_frame)
                occlu_laten, frame_laten = self.warpframes(*latenflow, laten_frame)

                # print(occlu_final[0].shape)

                cost_final = self.getcost(*frame_final, *occlu_final, frame)
                cost_pyra1 = self.getcost(*frame_pyra1, *occlu_pyra1, pyra1_frame)
                cost_pyra2 = self.getcost(*frame_pyra2, *occlu_pyra2, pyra2_frame)
                cost_laten = self.getcost(*frame_laten, *occlu_laten, laten_frame)

                eper_final = self.epe(finalflow[1].cpu().detach(), flow.cpu().detach())

                loss = cost_final + cost_pyra1 + cost_pyra2 + cost_laten

                self.avg_loss.update(loss.item(), i + 1)
                self.avg_epe.update(eper_final.item(), i + 1)

                loss.backward()

                self.optimizer.step()

                self.writer.add_scalar('Loss/train',
                                       self.avg_loss.avg, self.global_step)

                self.writer.add_scalar('EPE/train',
                                       self.avg_epe.avg, self.global_step)

                trainstream.set_postfix({'epoch': nb_epoch,
                                         'loss': self.avg_loss.avg,
                                         'epe': self.avg_epe.avg})
        self.scheduler.step(loss)
        trainstream.close()

        fb_frame_final = frame_final[1]
        fb_final = finalflow[1]
        fb_occlu_final = occlu_final[1]

        self.writer.add_histogram('REAL/flow_u', flow[0, 0, :].view(-1), nb_epoch)
        self.writer.add_histogram('REAL/flow_v', flow[0, 1, :].view(-1), nb_epoch)

        self.writer.add_histogram('PRED/flow_u_ff', finalflow[0][0, 0, :].view(-1), nb_epoch)
        self.writer.add_histogram('PRED/flow_v_ff', finalflow[0][0, 1, :].view(-1), nb_epoch)

        self.writer.add_histogram('PRED/flow_u_fb', finalflow[1][0, 0, :].view(-1), nb_epoch)
        self.writer.add_histogram('PRED/flow_v_fb', finalflow[1][0, 1, :].view(-1), nb_epoch)

        self.writer.add_histogram('REAL/occ', data['occlusion'][0, :].view(-1), nb_epoch)

        self.writer.add_histogram('PRED/occ_ff', occlu_final[0][0, :].view(-1), nb_epoch)
        self.writer.add_histogram('PRED/occ_fb', occlu_final[1][0, :].view(-1), nb_epoch)

        return self.train_epoch_end({'TRloss': self.avg_loss.avg,
                                     'epoch': nb_epoch,
                                     'pred_frame': fb_frame_final[0, :, 0:4, :].permute(1, 0, 2, 3),
                                     'gt_frame': frame[0, :, 0:4, :].permute(1, 0, 2, 3),
                                     'pred_flow': flow2rgb(fb_final[0, :, 0:4, :].permute(1, 0, 2, 3), False),
                                     'gt_flow': flow2rgb(flow[0, :, 0:4, :].permute(1, 0, 2, 3), False),
                                     'pred_occ': fb_occlu_final[0, :, 0:4, :].permute(1, 0, 2, 3),
                                     'gt_occ': data['occlusion'][0, :, 0:4, :].permute(1, 0, 2, 3)})
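self.warpframes and self.getcost are not shown anywhere in this listing. The warping step is presumably a backward warp of one frame into the other using the predicted flow; the generic sketch below shows that operation with grid_sample for 4-D [B, C, H, W] inputs (the trainer above works on 5-D clips, so this is only an illustration, not the authors' routine).

import torch
import torch.nn.functional as F

def warp_with_flow(image: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
    # Backward-warp `image` [B, C, H, W] by `flow` [B, 2, H, W] given in pixels.
    b, _, h, w = image.shape
    yy, xx = torch.meshgrid(torch.arange(h, device=image.device),
                            torch.arange(w, device=image.device), indexing='ij')
    base = torch.stack((xx, yy), dim=0).float().unsqueeze(0)  # [1, 2, H, W]
    coords = base + flow
    # Normalize sampling coordinates to [-1, 1], as grid_sample expects.
    gx = 2.0 * coords[:, 0] / max(w - 1, 1) - 1.0
    gy = 2.0 * coords[:, 1] / max(h - 1, 1) - 1.0
    grid = torch.stack((gx, gy), dim=3)  # [B, H, W, 2]
    return F.grid_sample(image, grid, align_corners=True)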
Example #7
def validate_without_gt(val_loader,
                        disp_net,
                        pose_net,
                        mask_net,
                        flow_net,
                        epoch,
                        logger,
                        tb_writer,
                        nb_writers,
                        global_vars_dict=None):
    #data prepared
    device = global_vars_dict['device']
    n_iter_val = global_vars_dict['n_iter_val']
    args = global_vars_dict['args']
    # scale the fractional positions in args.show_samples to batch indices
    show_samples = copy.deepcopy(args.show_samples)
    for i in range(len(show_samples)):
        show_samples[i] *= len(val_loader)
        show_samples[i] = show_samples[i] // 1  # truncate to a whole index

    batch_time = AverageMeter()
    data_time = AverageMeter()
    log_outputs = nb_writers > 0
    losses = AverageMeter(precision=4)

    w1, w2, w3, w4 = args.cam_photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight, args.flow_photo_loss_weight
    w5 = args.consensus_loss_weight

    loss_camera = photometric_reconstruction_loss
    loss_flow = photometric_flow_loss

    # to eval model
    disp_net.eval()
    pose_net.eval()
    mask_net.eval()
    flow_net.eval()

    end = time.time()
    poses = np.zeros(
        ((len(val_loader) - 1) * 1 * (args.sequence_length - 1), 6))  #init

    #3. validation cycle
    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(val_loader):
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics, intrinsics_inv = intrinsics.to(device), intrinsics_inv.to(
            device)
        #3.1 forwardpass
        #disp
        disp = disp_net(tgt_img)
        if args.spatial_normalize:
            disp = spatial_normalize(disp)
        depth = 1 / disp

        #pose
        pose = pose_net(tgt_img, ref_imgs)  #[b,3,h,w]; list

        # flow ----
        # build flow w.r.t. the previous and next frames
        if args.flownet == 'Back2Future':
            flow_fwd, flow_bwd, _ = flow_net(tgt_img, ref_imgs[1:3])
        elif args.flownet == 'FlowNetC6':
            flow_fwd = flow_net(tgt_img, ref_imgs[2])
            flow_bwd = flow_net(tgt_img, ref_imgs[1])
        flow_cam = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics,
                             intrinsics_inv)

        flows_cam_fwd = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics,
                                  intrinsics_inv)
        flows_cam_bwd = pose2flow(depth.squeeze(1), pose[:, 1], intrinsics,
                                  intrinsics_inv)

        exp_masks_target = consensus_exp_masks(flows_cam_fwd,
                                               flows_cam_bwd,
                                               flow_fwd,
                                               flow_bwd,
                                               tgt_img,
                                               ref_imgs[2],
                                               ref_imgs[1],
                                               wssim=args.wssim,
                                               wrig=args.wrig,
                                               ws=args.smooth_loss_weight)
        no_rigid_flow = flow_fwd - flows_cam_fwd

        rigidity_mask_fwd = (flows_cam_fwd - flow_fwd).abs()  #[b,2,h,w]
        rigidity_mask_bwd = (flows_cam_bwd - flow_bwd).abs()

        # mask
        # 4.explainability_mask(none)
        explainability_mask = mask_net(tgt_img, ref_imgs)  # valid regions? 4??

        # list(5):item:tensor:[4,4,128,512]...[4,4,4,16] value:[0.33~0.48~0.63]

        if args.joint_mask_for_depth:  # false
            explainability_mask_for_depth = explainability_mask

            #explainability_mask_for_depth = compute_joint_mask_for_depth(explainability_mask, rigidity_mask_bwd,
            #                                                            rigidity_mask_fwd,THRESH=args.THRESH)
        else:
            explainability_mask_for_depth = explainability_mask

        # change

        if args.no_non_rigid_mask:
            flow_exp_mask = None
            if args.DEBUG:
                print('Using no masks for flow')
        else:
            flow_exp_mask = 1 - explainability_mask[:, 1:3]

        #3.2loss-compute
        if w1 > 0:
            loss_1 = loss_camera(tgt_img,
                                 ref_imgs,
                                 intrinsics,
                                 intrinsics_inv,
                                 depth,
                                 explainability_mask_for_depth,
                                 pose,
                                 lambda_oob=args.lambda_oob,
                                 qch=args.qch,
                                 wssim=args.wssim)
        else:
            loss_1 = torch.tensor([0.]).to(device)

        # E_M
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask)  # + 0.2*gaussian_explainability_loss(explainability_mask)
        else:
            loss_2 = 0

        #if args.smoothness_type == "regular":
        if w3 > 0:
            loss_3 = smooth_loss(depth) + smooth_loss(
                explainability_mask) + smooth_loss(flow_fwd) + smooth_loss(
                    flow_bwd)
        else:
            loss_3 = torch.tensor([0.]).to(device)
        if w4 > 0:
            loss_4 = loss_flow(tgt_img,
                               ref_imgs[1:3], [flow_bwd, flow_fwd],
                               flow_exp_mask,
                               lambda_oob=args.lambda_oob,
                               qch=args.qch,
                               wssim=args.wssim)
        else:
            loss_4 = torch.tensor([0.]).to(device)
        if w5 > 0:
            loss_5 = consensus_depth_flow_mask(explainability_mask,
                                               rigidity_mask_bwd,
                                               rigidity_mask_fwd,
                                               exp_masks_target,
                                               exp_masks_target,
                                               THRESH=args.THRESH,
                                               wbce=args.wbce)
        else:
            loss_5 = torch.tensor([0.]).to(device)

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3 + w4 * loss_4 + w5 * loss_5

        #3.3 data update
        losses.update(loss.item(), args.batch_size)
        batch_time.update(time.time() - end)
        end = time.time()

        # 3.4 logging

        # inspect the forward-pass results
        if args.img_freq > 0 and i in show_samples:  # output_writers list(3)
            if epoch == 0:  # validation pass before training, to get a baseline of the untrained network
                # 1. img
                # this branch is not executed a second time; note that ref_imgs axis 0 is the batch index and axis 1 is the list (adjacent-frame) index!
                tb_writer.add_image(
                    'epoch 0 Input/sample{}(img{} to img{})'.format(
                        i, i + 1, i + args.sequence_length),
                    tensor2array(ref_imgs[0][0]), 0)
                tb_writer.add_image(
                    'epoch 0 Input/sample{}(img{} to img{})'.format(
                        i, i + 1, i + args.sequence_length),
                    tensor2array(ref_imgs[1][0]), 1)
                tb_writer.add_image(
                    'epoch 0 Input/sample{}(img{} to img{})'.format(
                        i, i + 1, i + args.sequence_length),
                    tensor2array(tgt_img[0]), 2)
                tb_writer.add_image(
                    'epoch 0 Input/sample{}(img{} to img{})'.format(
                        i, i + 1, i + args.sequence_length),
                    tensor2array(ref_imgs[2][0]), 3)
                tb_writer.add_image(
                    'epoch 0 Input/sample{}(img{} to img{})'.format(
                        i, i + 1, i + args.sequence_length),
                    tensor2array(ref_imgs[3][0]), 4)

                depth_to_show = depth[0].cpu()  # depth_to_show: [1, h, w], values roughly 0.5~3.1~10
                tb_writer.add_image(
                    'Disp Output/sample{}'.format(i),
                    tensor2array(depth_to_show,
                                 max_value=None,
                                 colormap='bone'), 0)

            else:
                #2.disp
                depth_to_show = disp[0].cpu()  # disp_to_show: [1, h, w], values roughly 0.5~3.1~10
                tb_writer.add_image(
                    'Disp Output/sample{}'.format(i),
                    tensor2array(depth_to_show,
                                 max_value=None,
                                 colormap='bone'), epoch)
                #3. flow
                tb_writer.add_image('Flow/Flow Output sample {}'.format(i),
                                    flow2rgb(flow_fwd[0], max_value=6), epoch)
                tb_writer.add_image('Flow/cam_Flow Output sample {}'.format(i),
                                    flow2rgb(flow_cam[0], max_value=6), epoch)
                tb_writer.add_image(
                    'Flow/no rigid flow Output sample {}'.format(i),
                    flow2rgb(no_rigid_flow[0], max_value=6), epoch)
                tb_writer.add_image(
                    'Flow/rigidity_mask_fwd{}'.format(i),
                    flow2rgb(rigidity_mask_fwd[0], max_value=6), epoch)

                #4. mask
                tb_writer.add_image(
                    'Mask Output/mask0 sample{}'.format(i),
                    tensor2array(explainability_mask[0][0],
                                 max_value=None,
                                 colormap='magma'), epoch)
                #tb_writer.add_image('Mask Output/mask1 sample{}'.format(i),tensor2array(explainability_mask[1][0], max_value=None, colormap='magma'), epoch)
                #tb_writer.add_image('Mask Output/mask2 sample{}'.format(i),tensor2array(explainability_mask[2][0], max_value=None, colormap='magma'), epoch)
                #tb_writer.add_image('Mask Output/mask3 sample{}'.format(i),tensor2array(explainability_mask[3][0], max_value=None, colormap='magma'), epoch)
                tb_writer.add_image(
                    'Mask Output/exp_masks_target sample{}'.format(i),
                    tensor2array(exp_masks_target[0][0],
                                 max_value=None,
                                 colormap='magma'), epoch)
                #tb_writer.add_image('Mask Output/mask0 sample{}'.format(i),
                #            tensor2array(explainability_mask[0][0], max_value=None, colormap='magma'), epoch)

        #

        #output_writers[index].add_image('val Depth Output', tensor2array(depth.data[0].cpu(), max_value=10),
        #                               epoch)

        # errors.update(compute_errors(depth, output_depth.data.squeeze(1)))
        # add scalar
        if args.scalar_freq > 0 and n_iter_val % args.scalar_freq == 0:
            tb_writer.add_scalar('val/E_R', loss_1.item(), n_iter_val)
            if w2 > 0:
                tb_writer.add_scalar('val/E_M', loss_2.item(), n_iter_val)
            tb_writer.add_scalar('val/E_S', loss_3.item(), n_iter_val)
            tb_writer.add_scalar('val/E_F', loss_4.item(), n_iter_val)
            tb_writer.add_scalar('val/E_C', loss_5.item(), n_iter_val)
            tb_writer.add_scalar('val/total_loss', loss.item(), n_iter_val)

        # terminal output
        if args.log_terminal:
            logger.valid_bar.update(i + 1)  # progress within the current epoch
            if i % args.print_freq == 0:
                logger.valid_bar_writer.write(
                    'Valid: Time {} Data {} Loss {}'.format(
                        batch_time, data_time, losses))

        n_iter_val += 1

    global_vars_dict['n_iter_val'] = n_iter_val
    return losses.avg[0]  # average validation loss for this epoch
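spatial_normalize is only called, never defined, in these examples. It presumably rescales each predicted disparity map by its own spatial mean so that depth = 1 / disp stays in a stable range; a minimal sketch under that assumption:

import torch

def spatial_normalize_sketch(disp: torch.Tensor) -> torch.Tensor:
    # disp: [B, 1, H, W]; divide each map by its own spatial mean.
    mean = disp.mean(dim=(2, 3), keepdim=True).clamp(min=1e-8)
    return disp / mean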
Example #8
    def train(self, nb_epoch):
        trainstream = tqdm(self.dataloader.train())
        self.avg_loss = AverageMeter()
        self.avg_epe = AverageMeter()
        self.model.train()
        for i, data in enumerate(trainstream):
            self.global_step += 1
            trainstream.set_description('TRAINING')

            # GET X and Frame 2
            # wdt = data['displacement'].to(self.device)
            frame = data['frame'].to(self.device)

            pyra1_frame = F.interpolate(frame, size=(108, 256))
            pyra2_frame = F.interpolate(frame, size=(54, 128))
            laten_frame = F.interpolate(frame, size=(27, 64))

            flow = data['flow'].cpu()
            # frame.requires_grad = True
            flow.requires_grad = False
            """
            NOTE : THIS MUST BE ADJUSTED AT DATA LOADER SIDE 
            torch.Size([1, 2, 9, 436, 1024])    -> finalflow size
            torch.Size([1, 2, 9, 108, 256])     -> pyraflow1 size
            torch.Size([1, 2, 9, 54, 128])      -> pyraflow2 size
            torch.Size([1, 2, 9, 27, 64])       -> pyraflow3 size
            """
            # pyra1_frame = data['pyra1_frame'].to(self.device)
            # pyra1_frame.requires_grad = True
            # pyra2_frame = data['pyra2_frame'].to(self.device)
            # pyra2_frame.requires_grad = True
            # laten_frame = data['laten_frame'].to(self.device)
            # laten_frame.requires_grad = True
            # ff = data['ff'].to(self.device)
            # fb = data['fb'].to(self.device)
            motion = data['motion'].to(self.device)
            subf = data['subf'].to(self.device)
            # forward
            with torch.set_grad_enabled(True):
                self.optimizer.zero_grad()
                # flows = self.model(frame, ff, fb)
                # for i, finalflow in enumerate(flows):
                #     flowdim = finalflow[0].size()[2:]
                #     framedim = frame.size()[2:]
                #
                #     if flowdim != framedim:
                #         frame_ = F.interpolate(frame, size=(finalflow[0].size(2),finalflow[0].size(3)))
                #         occlu_final, frame_final = self.warpframes(*finalflow, frame_)
                #         loss = self.getcost(*frame_final, *occlu_final, frame_)
                #     else:
                #         occlu_final, frame_final = self.warpframes(*finalflow, frame)
                #         loss = self.getcost(*frame_final, *occlu_final, frame)
                #
                #     self.optimizer.zero_grad()
                #     if i < 3:
                #         loss.backward(retain_graph=True)
                #     else:
                #         loss.backward()
                #         self.writer.add_figure('Activations', plot_grad_flow(self.model.named_parameters()),
                #                                self.global_step)
                #     self.optimizer.step()
                #
                # self.avg_loss.update(loss.item(), i + 1)
                # eper_final = self.epe(finalflow[1].detach(), flow.detach())
                # self.avg_epe.update(eper_final.item(), i + 1)
                #
                # self.writer.add_scalar('Loss/train',
                #                        self.avg_loss.avg, self.global_step)
                #
                # self.writer.add_scalar('EPE/train',
                #                        self.avg_epe.avg, self.global_step)
                #
                # trainstream.set_postfix({'epoch': nb_epoch,
                #                          'loss': self.avg_loss.avg,
                #                          'epe': self.avg_epe.avg})


                # Different approach above, kept commented out for reference

                # print(motion.shape, frame.shape)

                latenflow, pyraflow2, pyraflow1, finalflow = self.model(subf, motion)

                occlu_final, frame_final = self.warpframes(*finalflow, frame)
                occlu_pyra1, frame_pyra1 = self.warpframes(*pyraflow1, pyra1_frame)
                occlu_pyra2, frame_pyra2 = self.warpframes(*pyraflow2, pyra2_frame)
                occlu_laten, frame_laten = self.warpframes(*latenflow, laten_frame)

                # print(occlu_final[0].shape)

                cost_final = self.getcost(*frame_final, *occlu_final, frame, *finalflow)
                cost_pyra1 = self.getcost(*frame_pyra1, *occlu_pyra1, pyra1_frame, *pyraflow1)
                cost_pyra2 = self.getcost(*frame_pyra2, *occlu_pyra2, pyra2_frame, *pyraflow2)
                cost_laten = self.getcost(*frame_laten, *occlu_laten, laten_frame, *latenflow)

                loss = cost_final + cost_pyra1 + cost_pyra2 + cost_laten


                # self.optimizer.zero_grad()
                loss.backward()
                self.avg_loss.update(loss.item(), i + 1)

                self.writer.add_figure('Activations', plot_grad_flow(self.model.named_parameters()), self.global_step)

                self.optimizer.step()

                eper_final = self.epe(finalflow[1].detach(), flow.detach())
                self.avg_epe.update(eper_final.item(), i + 1)

                self.writer.add_scalar('Loss/train',
                                       self.avg_loss.avg, self.global_step)

                self.writer.add_scalar('EPE/train',
                                       self.avg_epe.avg, self.global_step)

                trainstream.set_postfix({'epoch': nb_epoch,
                                         'loss': self.avg_loss.avg,
                                         'epe': self.avg_epe.avg})


        self.scheduler.step(loss)
        trainstream.close()

        fb_frame_final = frame_final[1]
        fb_final = finalflow[1]
        fb_occlu_final = occlu_final[1]

        self.writer.add_histogram('REAL/flow_u', flow[0, 0].view(-1), nb_epoch)
        self.writer.add_histogram('REAL/flow_v', flow[0, 1].view(-1), nb_epoch)

        self.writer.add_histogram('PRED/flow_u_ff', finalflow[0][0, 0].view(-1), nb_epoch)
        self.writer.add_histogram('PRED/flow_v_ff', finalflow[0][0, 1].view(-1), nb_epoch)

        self.writer.add_histogram('PRED/flow_u_fb', finalflow[1][0, 0].view(-1), nb_epoch)
        self.writer.add_histogram('PRED/flow_v_fb', finalflow[1][0, 1].view(-1), nb_epoch)

        self.writer.add_histogram('REAL/occ', data['occlusion'][0].view(-1), nb_epoch)

        self.writer.add_histogram('PRED/occ_ff', occlu_final[0][0].view(-1), nb_epoch)
        self.writer.add_histogram('PRED/occ_fb', occlu_final[1][0].view(-1), nb_epoch)

        return self.train_epoch_end({'TRloss': self.avg_loss.avg,
                                     'epoch': nb_epoch,
                                     'pred_frame': fb_frame_final[0:4],
                                     'gt_frame': frame[0:4,:3],
                                     'subf': subf[0:4, :3],
                                     'pred_flow': flow2rgb(fb_final[0:4], False),
                                     'gt_flow': flow2rgb(flow[0:4],False),
                                     # 'ff_in': flow2rgb(ff[0:4], False),
                                     # 'fb_in': flow2rgb(fb[0:4], False),
                                     'pred_occ': 1. - fb_occlu_final[0:4],
                                     'gt_occ': data['occlusion'][0:4]})
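plot_grad_flow, used with add_figure above, is the usual gradient-flow diagnostic. Below is a minimal assumed sketch that returns a matplotlib figure of the mean absolute gradient per trainable, non-bias parameter; it is not necessarily the exact helper used in the repository.

import matplotlib.pyplot as plt

def plot_grad_flow_sketch(named_parameters):
    # Collect the mean absolute gradient of every trainable, non-bias parameter.
    names, grads = [], []
    for name, p in named_parameters:
        if p.requires_grad and 'bias' not in name and p.grad is not None:
            names.append(name)
            grads.append(p.grad.abs().mean().item())
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.bar(range(len(grads)), grads)
    ax.set_xticks(range(len(names)))
    ax.set_xticklabels(names, rotation=90, fontsize=6)
    ax.set_ylabel('mean |grad|')
    ax.set_title('Gradient flow')
    fig.tight_layout()
    return fig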