예제 #1
0
def visualize(args, epoch, model_mag, model_pha, model_vs, data_loader, writer):
    """Log example images for the first batch of ``data_loader``.

    The magnitude and phase networks are run on the batch, their outputs are
    combined with ``T.dc`` and, when ``epoch >= 2``, refined by ``model_vs``.
    Seven images are written to ``writer`` with ``epoch`` as the step:
    ``Img_mag``, ``Tgt_mag``, ``Img_pha``, ``Tgt_pha``, ``Error``, ``Recons``
    and ``Target``. The last three show only the first sample of the batch,
    passed through ``T.pad(..., [256, 256])``.

    Args:
        args: Run configuration; only ``args.device`` is read.
        epoch: Step used for the logged images; also gates ``model_vs``.
        model_mag: Network applied to the under-sampled magnitude input.
        model_pha: Network applied to the under-sampled phase input.
        model_vs: Network applied to the combined output from epoch 2 on.
        data_loader: Iterable yielding 13-element batches; only the first
            batch is consumed.
        writer: Object exposing ``add_image(tag, image, step)``.
    """
    def save_image(image, tag):
        # Min-max normalise out of place so the caller's tensor is untouched.
        image = image - image.min()
        peak = image.max()
        # A constant image has zero range; skip the division instead of
        # filling the logged picture with NaNs.
        if peak > 0:
            image = image / peak
        grid = torchvision.utils.make_grid(image, nrow=4, pad_value=1)
        writer.add_image(tag, grid, epoch)

    def magnitude(two_channel):
        # The last axis holds two components (presumably real/imaginary --
        # TODO confirm) that are combined as sqrt(a^2 + b^2).
        squared = two_channel[:, :, :, 0] ** 2 + two_channel[:, :, :, 1] ** 2
        return torch.sqrt(squared).unsqueeze(1)

    def first_sample_padded(image):
        # Only the first sample / first channel is displayed.
        padded = T.pad(image[0, 0, :, :], [256, 256])
        return padded.unsqueeze(0).unsqueeze(1).to(args.device)

    # All three networks must be in eval mode; the phase network was
    # previously skipped because model_mag.eval() was called twice.
    model_mag.eval()
    model_pha.eval()
    model_vs.eval()

    with torch.no_grad():
        for data in data_loader:
            mag_us, mag_gt, pha_us, pha_gt, ksp_us, _, img_gt, _, _, sens, mask, _, _ = data

            inp_mag = mag_us.unsqueeze(1).to(args.device)
            tgt_mag = mag_gt.unsqueeze(1).to(args.device)
            inp_pha = pha_us.unsqueeze(1).to(args.device)
            tgt_pha = pha_gt.unsqueeze(1).to(args.device)
            ksp_us = ksp_us.to(args.device)
            sens = sens.to(args.device)
            mask = mask.to(args.device)
            img_gt = img_gt.to(args.device)

            out_mag = model_mag(inp_mag)
            out_pha = model_pha(inp_pha)

            # Results are unused below; the calls are kept from the original
            # because T.unpad is defined outside this block.
            out_mag_unpad = T.unpad(out_mag, ksp_us)
            out_pha_unpad = T.unpad(out_pha, ksp_us)
            out_cmplx = T.dc(out_mag, out_pha, ksp_us, sens, mask)

            if epoch >= 2:
                out_cmplx = model_vs(out_cmplx, ksp_us, sens, mask)

            save_image(inp_mag, 'Img_mag')
            save_image(tgt_mag, 'Tgt_mag')
            save_image(inp_pha, 'Img_pha')
            save_image(tgt_pha, 'Tgt_pha')

            # Use the configured device rather than a hard-coded .cuda().
            error_cmplx = torch.abs(out_cmplx.to(args.device) - img_gt)

            save_image(first_sample_padded(magnitude(error_cmplx)), 'Error')
            save_image(first_sample_padded(magnitude(out_cmplx)), 'Recons')
            save_image(first_sample_padded(magnitude(img_gt)), 'Target')

            # Only the first batch is visualised.
            break
예제 #2
0
def visualize(args, epoch, model_mag, model_pha, model_vs, model_dun, data_loader, writer):
    """Log input, error, reconstruction and target images for one batch.

    The first batch of ``data_loader`` is pushed through ``model_mag`` and
    ``model_pha``, combined by ``T.dc``, refined by ``model_vs``, converted
    with ``T.inp_to_dun`` and finally passed to ``model_dun``. Four images are
    written to ``writer`` under the tags ``Input``, ``Error``, ``Recons`` and
    ``Target`` using ``epoch`` as the step.

    Args:
        args: Run configuration; only ``args.device`` is read.
        epoch: Step value attached to every logged image.
        model_mag, model_pha, model_vs, model_dun: Networks of the pipeline,
            all switched to eval mode here.
        data_loader: Iterable yielding 13-element batches.
        writer: Object exposing ``add_image(tag, image, step)``.
    """
    def log_normalised(tensor, tag):
        # In-place min-max scaling: the tensor handed in is modified.
        tensor.sub_(tensor.min())
        tensor.div_(tensor.max())
        tiled = torchvision.utils.make_grid(tensor, nrow=4, pad_value=1)
        writer.add_image(tag, tiled, epoch)

    for network in (model_mag, model_pha, model_vs, model_dun):
        network.eval()

    with torch.no_grad():
        print(':::::::::::::::: IN VISUALIZE :::::::::::::::::::::')
        for batch in data_loader:
            (mag_us, mag_gt, pha_us, pha_gt, ksp_us, img_us, img_gt,
             img_us_np, img_gt_np, sens, mask, _, _) = batch

            device = args.device
            mag_in = mag_us.unsqueeze(1).to(device)
            pha_in = pha_us.unsqueeze(1).to(device)
            img_us_np = img_us_np.unsqueeze(1).float().to(device)
            target = img_gt_np.unsqueeze(1).float().to(device)
            ksp_us = ksp_us.to(device)
            sens = sens.to(device)
            mask = mask.to(device)
            img_gt = img_gt.to(device)

            mag_pred = model_mag(mag_in)
            pha_pred = model_pha(pha_in)

            # Cropped copies are computed but not used further down.
            mag_pred_unpad = T.unpad(mag_pred, ksp_us)
            pha_pred_unpad = T.unpad(pha_pred, ksp_us)

            combined = T.dc(mag_pred, pha_pred, ksp_us, sens, mask)
            combined = model_vs(combined, ksp_us, sens, mask)

            # T.inp_to_dun is the active variant; an rss-based input was
            # used as an alternative in an earlier revision.
            dun_in = T.inp_to_dun(combined).float().to(device)
            dun_out = model_dun(dun_in)

            log_normalised(dun_in.float(), 'Input')

            # The error map is taken before the prediction and target are
            # rescaled in place by the logging helper.
            abs_error = torch.abs(dun_out - target)

            log_normalised(abs_error, 'Error')
            log_normalised(dun_out, 'Recons')
            log_normalised(target, 'Target')

            # Only the first batch is visualised.
            break
예제 #3
0
def evaluate(args, epoch, model_mag, model_pha, model_vs, data_loader, writer):
    """Compute mean validation MSE losses for the mag/pha/combined outputs.

    Each batch is passed through ``model_mag`` and ``model_pha``; the outputs
    are combined by ``T.dc`` and, when ``epoch >= 2``, refined by
    ``model_vs``. The per-batch MSE losses are averaged over the loader and
    the averages are written to ``writer`` under ``Dev_Loss_Mag``,
    ``Dev_Loss_Pha`` and ``Dev_Loss_cmplx``.

    Args:
        args: Run configuration; only ``args.device`` is read.
        epoch: Step used for the logged scalars; also gates ``model_vs``.
        model_mag, model_pha, model_vs: Networks, switched to eval mode here.
        data_loader: Iterable yielding 13-element batches.
        writer: Object exposing ``add_scalar(tag, value, step)``.

    Returns:
        Tuple ``(mean_mag, mean_pha, mean_cmplx, elapsed_seconds)``. The means
        are NaN when ``data_loader`` yields no batches.
    """
    model_mag.eval()
    model_pha.eval()
    model_vs.eval()

    losses_mag = []
    losses_pha = []
    losses_cmplx = []
    start = time.perf_counter()

    with torch.no_grad():
        for data in tqdm(data_loader):
            mag_us, mag_gt, pha_us, pha_gt, ksp_us, _, img_gt, _, _, sens, mask, _, _ = data

            inp_mag = mag_us.unsqueeze(1).to(args.device)
            tgt_mag = mag_gt.unsqueeze(1).to(args.device)
            inp_pha = pha_us.unsqueeze(1).to(args.device)
            tgt_pha = pha_gt.unsqueeze(1).to(args.device)
            ksp_us = ksp_us.to(args.device)
            sens = sens.to(args.device)
            mask = mask.to(args.device)
            img_gt = img_gt.to(args.device)

            out_mag = model_mag(inp_mag)
            out_pha = model_pha(inp_pha)

            # Results are unused below; the calls are kept from the original
            # because T.unpad is defined outside this block.
            out_mag_unpad = T.unpad(out_mag, ksp_us)
            out_pha_unpad = T.unpad(out_pha, ksp_us)
            out_cmplx = T.dc(out_mag, out_pha, ksp_us, sens, mask)

            if epoch >= 2:
                out_cmplx = model_vs(out_cmplx, ksp_us, sens, mask)

            losses_mag.append(F.mse_loss(out_mag, tgt_mag).item())
            losses_pha.append(F.mse_loss(out_pha, tgt_pha).item())
            losses_cmplx.append(F.mse_loss(out_cmplx, img_gt).item())

    if losses_mag:
        mean_mag = np.mean(losses_mag)
        mean_pha = np.mean(losses_pha)
        mean_cmplx = np.mean(losses_cmplx)
        # Log the epoch means (the values that are returned), not the loss
        # of whichever batch happened to come last.
        writer.add_scalar('Dev_Loss_Mag', mean_mag, epoch)
        writer.add_scalar('Dev_Loss_Pha', mean_pha, epoch)
        writer.add_scalar('Dev_Loss_cmplx', mean_cmplx, epoch)
    else:
        # Nothing was evaluated: report NaN instead of failing on an
        # undefined per-batch loss.
        mean_mag = mean_pha = mean_cmplx = float('nan')

    return mean_mag, mean_pha, mean_cmplx, time.perf_counter() - start
예제 #4
0
def run_submission(model_mag, model_pha, model_vs, model_dun, data_loader):
    """Run the full pipeline over ``data_loader`` and collect predictions.

    Every batch goes through ``model_mag``/``model_pha``, ``T.dc``,
    ``model_vs``, ``T.inp_to_dun`` and ``model_dun``; the result is passed
    through ``T.unpad``, moved to the CPU and multiplied by ``max_mag``.

    Args:
        model_mag, model_pha, model_vs, model_dun: Networks of the pipeline,
            all switched to eval mode here. Inputs are moved with ``.cuda()``.
        data_loader: Iterable yielding 8-element batches
            ``(mag_us, pha_us, ksp_us, sens, mask, fname, slice, max_mag)``.

    Returns:
        Dict mapping each file name to the ``np.stack`` of its predictions,
        ordered by slice value.
    """
    for network in (model_mag, model_pha, model_vs, model_dun):
        network.eval()

    reconstructions = defaultdict(list)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for data in tqdm(data_loader):
            mag_us, pha_us, ksp_us, sens, mask, fname, slice_idx, max_mag = data

            inp_mag = mag_us.unsqueeze(1).cuda()
            inp_pha = pha_us.unsqueeze(1).cuda()
            ksp_us = ksp_us.cuda()
            sens = sens.cuda()
            mask = mask.cuda()

            out_mag = model_mag(inp_mag)
            out_pha = model_pha(inp_pha)

            # Results are unused below; the calls are kept from the original
            # because T.unpad is defined outside this block.
            out_mag_unpad = T.unpad(out_mag, ksp_us)
            out_pha_unpad = T.unpad(out_pha, ksp_us)
            out_cmplx = T.dc(out_mag, out_pha, ksp_us, sens, mask)
            out_cmplx = model_vs(out_cmplx, ksp_us, sens, mask)

            inp_dun = T.inp_to_dun(out_cmplx).float().cuda()

            out_dun = model_dun(inp_dun)
            out_dun = T.unpad(out_dun, ksp_us)

            out_dun = out_dun.detach().cpu().squeeze(1)
            out_dun = out_dun * max_mag.float()

            # NOTE(review): only the first element of each batch is stored,
            # so this presumably expects a batch size of 1 -- confirm.
            reconstructions[fname[0]].append(
                (slice_idx[0].numpy(), out_dun[0].numpy()))

    # Sort on the slice value only, so that equal slice values never fall
    # through to comparing the prediction arrays themselves.
    return {
        name: np.stack(
            [pred for _, pred in sorted(slice_preds, key=lambda item: item[0])])
        for name, slice_preds in reconstructions.items()
    }
예제 #5
0
def reconstruct(args, model_mag, model_pha, model_vs, model_dun, data_loader):
    """Reconstruct every batch and write each result to an HDF5 file.

    Every batch goes through ``model_mag``/``model_pha``, ``T.dc``,
    ``model_vs``, ``T.inp_to_dun`` and ``model_dun``. The output path is the
    input file path with its parent directory component replaced by
    ``Recons/acc_<args.acceleration>x``; the result is stored there as the
    dataset ``"Recons"``.

    Args:
        args: Run configuration; ``args.device`` and ``args.acceleration``
            are read.
        model_mag, model_pha, model_vs, model_dun: Networks of the pipeline,
            all switched to eval mode here.
        data_loader: Iterable yielding 12-element batches whose last element
            is the source file name (or a list/tuple holding it).
    """
    for network in (model_mag, model_pha, model_vs, model_dun):
        network.eval()

    with torch.no_grad():
        for data in tqdm(data_loader):
            mag_us, _, pha_us, _, ksp_us, _, _, _, _, sens, mask, fname = data

            inp_mag = mag_us.unsqueeze(1).to(args.device)
            inp_pha = pha_us.unsqueeze(1).to(args.device)
            ksp_us = ksp_us.to(args.device)
            sens = sens.to(args.device)
            mask = mask.to(args.device)

            out_mag = model_mag(inp_mag)
            out_pha = model_pha(inp_pha)

            # Results are unused below; the calls are kept from the original
            # because T.unpad is defined outside this block.
            out_mag_unpad = T.unpad(out_mag, ksp_us)
            out_pha_unpad = T.unpad(out_pha, ksp_us)
            out_cmplx = T.dc(out_mag, out_pha, ksp_us, sens, mask)
            out_cmplx = model_vs(out_cmplx, ksp_us, sens, mask)

            inp_dun = T.inp_to_dun(out_cmplx).float().to(args.device)

            out_dun = model_dun(inp_dun).squeeze(0).cpu()

            # Take the file name itself rather than slicing the repr of the
            # collated list, which breaks for tuples, plain strings and any
            # path containing characters that repr() escapes.
            # NOTE(review): only the first name is used, so this presumably
            # expects a batch size of 1 -- confirm.
            source = fname[0] if isinstance(fname, (list, tuple)) else fname
            parts = list(pathlib.Path(str(source)).parts)
            parts[-2] = 'Recons/acc_' + str(args.acceleration) + 'x'
            path = pathlib.Path(*parts)

            # Opening in 'w' mode fails when the target directory is missing.
            path.parent.mkdir(parents=True, exist_ok=True)

            with h5py.File(str(path), 'w') as f:
                f.create_dataset("Recons", data=out_dun.numpy())
예제 #6
0
def evaluate(args, epoch, model_mag, model_pha, model_vs, model_dun, data_loader, writer):
    """Compute the mean validation ``ssim_loss`` of the full pipeline.

    Every batch goes through ``model_mag``/``model_pha``, ``T.dc``,
    ``model_vs``, ``T.inp_to_dun`` and ``model_dun``; the output is compared
    with ``img_gt_np`` using ``ssim_loss``. The mean over all batches is
    written to ``writer`` under the tag ``Dev_Loss_cmplx``.

    Args:
        args: Run configuration; only ``args.device`` is read.
        epoch: Step used for the logged scalar.
        model_mag, model_pha, model_vs, model_dun: Networks of the pipeline,
            all switched to eval mode here.
        data_loader: Iterable yielding 13-element batches.
        writer: Object exposing ``add_scalar(tag, value, step)``.

    Returns:
        Tuple ``(mean_loss, elapsed_seconds)``; ``mean_loss`` is NaN when
        ``data_loader`` yields no batches.
    """
    for network in (model_mag, model_pha, model_vs, model_dun):
        network.eval()

    losses_dun = []
    start = time.perf_counter()

    # Third argument of ssim_loss. It is loop-invariant and is placed on the
    # configured device instead of a hard-coded .cuda().
    data_range = torch.tensor(1.0).unsqueeze(0).to(args.device)

    with torch.no_grad():
        print(':::::::::::::::: IN EVALUATE :::::::::::::::::::::')
        for data in tqdm(data_loader):
            mag_us, _, pha_us, _, ksp_us, _, _, _, img_gt_np, sens, mask, _, _ = data

            inp_mag = mag_us.unsqueeze(1).to(args.device)
            inp_pha = pha_us.unsqueeze(1).to(args.device)
            img_gt_np = img_gt_np.unsqueeze(1).float().to(args.device)
            ksp_us = ksp_us.to(args.device)
            sens = sens.to(args.device)
            mask = mask.to(args.device)

            out_mag = model_mag(inp_mag)
            out_pha = model_pha(inp_pha)

            # Results are unused below; the calls are kept from the original
            # because T.unpad is defined outside this block.
            out_mag_unpad = T.unpad(out_mag, ksp_us)
            out_pha_unpad = T.unpad(out_pha, ksp_us)
            out_cmplx = T.dc(out_mag, out_pha, ksp_us, sens, mask)
            out_cmplx = model_vs(out_cmplx, ksp_us, sens, mask)

            inp_dun = T.inp_to_dun(out_cmplx).float().to(args.device)
            out_dun = model_dun(inp_dun)

            loss_dun = ssim_loss(out_dun, img_gt_np, data_range)
            losses_dun.append(loss_dun.item())

    # Avoid numpy's empty-mean warning; the value is NaN either way.
    mean_loss = np.mean(losses_dun) if losses_dun else float('nan')
    writer.add_scalar('Dev_Loss_cmplx', mean_loss, epoch)

    return mean_loss, time.perf_counter() - start
예제 #7
0
def train_epoch(args, epoch, model_mag, model_pha, model_vs, data_loader, optimizer_mag, optimizer_pha, optimizer_vs, writer):
    """Run one training epoch of the staged mag/pha/vs schedule.

    The stage is selected by ``epoch``:

    * ``epoch < 1``: ``model_mag`` and ``model_pha`` are stepped on their own
      MSE losses (two separate backward passes); the combined loss on the
      ``T.dc`` output is computed for logging only.
    * ``1 <= epoch < 2``: ``model_mag`` and ``model_pha`` are stepped on the
      combined loss.
    * ``2 <= epoch < 3``: the ``T.dc`` output is refined by ``model_vs`` and
      only ``optimizer_vs`` is stepped; ``model_mag``/``model_pha`` are in
      eval mode.
    * otherwise: all three optimizers are stepped on the combined loss after
      ``model_vs``.

    Per-iteration losses are written to ``writer`` (``TrainLoss_mag``,
    ``TrainLoss_pha``, ``TrainLoss_cmplx``) and a summary line is logged
    every ``args.report_interval`` iterations.

    Args:
        args: Run configuration; ``device``, ``report_interval`` and
            ``num_epochs`` are read.
        epoch: Current epoch index; selects the stage and the global step.
        model_mag, model_pha, model_vs: Networks being trained.
        data_loader: Sized iterable yielding 13-element batches.
        optimizer_mag, optimizer_pha, optimizer_vs: Optimizers of the
            respective networks.
        writer: Object exposing ``add_scalar(tag, value, step)``.

    Returns:
        Tuple ``(avg_loss_mag, avg_loss_pha, avg_loss_cmplx, elapsed_seconds)``
        where each average is an exponential moving average
        (0.99 / 0.01) seeded with the first iteration's loss.
    """
    separate_losses = epoch < 1
    train_mag_pha = not (2 <= epoch < 3)
    use_vs = epoch >= 2

    # Set train/eval modes once, before any forward pass. Switching inside
    # the loop (after the forward pass) let the first batch of the
    # model_vs-only stage run model_mag/model_pha in training mode.
    for network, trainable in ((model_mag, train_mag_pha),
                               (model_pha, train_mag_pha),
                               (model_vs, use_vs)):
        if trainable:
            network.train()
        else:
            network.eval()

    optimizers = []
    if train_mag_pha:
        optimizers += [optimizer_mag, optimizer_pha]
    if use_vs:
        optimizers.append(optimizer_vs)

    avg_loss_mag = 0.
    avg_loss_pha = 0.
    avg_loss_cmplx = 0.
    start_epoch = start_iter = time.perf_counter()
    global_step = epoch * len(data_loader)

    for step, data in enumerate(tqdm(data_loader)):
        mag_us, mag_gt, pha_us, pha_gt, ksp_us, _, img_gt, _, _, sens, mask, _, _ = data

        inp_mag = mag_us.unsqueeze(1).to(args.device)
        tgt_mag = mag_gt.unsqueeze(1).to(args.device)
        inp_pha = pha_us.unsqueeze(1).to(args.device)
        tgt_pha = pha_gt.unsqueeze(1).to(args.device)
        ksp_us = ksp_us.to(args.device)
        sens = sens.to(args.device)
        mask = mask.to(args.device)
        img_gt = img_gt.to(args.device)

        out_mag = model_mag(inp_mag)
        out_pha = model_pha(inp_pha)

        # Results are unused below; the calls are kept from the original
        # because T.unpad is defined outside this block.
        out_mag_unpad = T.unpad(out_mag, ksp_us)
        out_pha_unpad = T.unpad(out_pha, ksp_us)

        out_cmplx = T.dc(out_mag, out_pha, ksp_us, sens, mask)
        if use_vs:
            out_cmplx = model_vs(out_cmplx, ksp_us, sens, mask)

        loss_mag = F.mse_loss(out_mag, tgt_mag)
        loss_pha = F.mse_loss(out_pha, tgt_pha)
        # img_gt is already on args.device; no hard-coded .cuda() needed.
        loss_cmplx = F.mse_loss(out_cmplx, img_gt)

        for optimizer in optimizers:
            optimizer.zero_grad()

        if separate_losses:
            loss_mag.backward()
            loss_pha.backward()
        else:
            # NOTE(review): in the model_vs-only stage this presumably also
            # fills .grad on model_mag/model_pha even though their optimizers
            # do not step -- left unchanged.
            loss_cmplx.backward()

        for optimizer in optimizers:
            optimizer.step()

        avg_loss_mag = 0.99 * avg_loss_mag + 0.01 * loss_mag.item() if step > 0 else loss_mag.item()
        avg_loss_pha = 0.99 * avg_loss_pha + 0.01 * loss_pha.item() if step > 0 else loss_pha.item()
        avg_loss_cmplx = 0.99 * avg_loss_cmplx + 0.01 * loss_cmplx.item() if step > 0 else loss_cmplx.item()

        writer.add_scalar('TrainLoss_mag', loss_mag.item(), global_step + step)
        writer.add_scalar('TrainLoss_pha', loss_pha.item(), global_step + step)
        writer.add_scalar('TrainLoss_cmplx', loss_cmplx.item(), global_step + step)

        if step % args.report_interval == 0:
            logging.info(
                f'Epoch = [{epoch:3d}/{args.num_epochs:3d}] '
                f'Iter = [{step:4d}/{len(data_loader):4d}] '
                f'Loss_mag = {loss_mag.item():.4g} Avg Loss Mag = {avg_loss_mag:.4g} '
                f'Loss_pha = {loss_pha.item():.4g} Avg Loss Pha = {avg_loss_pha:.4g} '
                f'Loss_cmplx = {loss_cmplx.item():.4g} Avg Loss cmplx = {avg_loss_cmplx:.4g} '
                f'Time = {time.perf_counter() - start_iter:.4f}s',
            )
        # Reset every iteration, so the logged time covers one iteration.
        start_iter = time.perf_counter()

    return avg_loss_mag, avg_loss_pha, avg_loss_cmplx, time.perf_counter() - start_epoch
예제 #8
0
def train_epoch(args, epoch, model_mag, model_pha, model_vs, model_dun, data_loader, optimizer_dun, writer):
    """Train ``model_dun`` for one epoch on top of the mag/pha/vs pipeline.

    Every batch goes through ``model_mag``/``model_pha``, ``T.dc``,
    ``model_vs`` and ``T.inp_to_dun``; ``model_dun`` is applied to the result
    and optimised with ``ssim_loss`` against ``img_gt_np``. Only
    ``optimizer_dun`` is stepped. The per-iteration loss is written to
    ``writer`` under ``TrainLoss_cmplx`` and a summary line is logged every
    ``args.report_interval`` iterations.

    Args:
        args: Run configuration; ``device``, ``report_interval`` and
            ``num_epochs`` are read.
        epoch: Current epoch index; used for the global step and log line.
        model_mag, model_pha, model_vs: Upstream networks, kept in eval mode.
        model_dun: Network being trained.
        data_loader: Sized iterable yielding 13-element batches.
        optimizer_dun: Optimizer stepped after each backward pass.
        writer: Object exposing ``add_scalar(tag, value, step)``.

    Returns:
        Tuple ``(avg_loss_dun, elapsed_seconds)`` where the average is an
        exponential moving average (0.99 / 0.01) seeded with the first
        iteration's loss.
    """
    # The upstream networks are not stepped here and stay in eval mode.
    model_mag.eval()
    model_pha.eval()
    model_vs.eval()
    # model_dun is the network being optimised, so it must be in train mode
    # (it was previously switched to eval like the others).
    model_dun.train()

    avg_loss_dun = 0.

    # Third argument of ssim_loss. It is loop-invariant and is placed on the
    # configured device instead of a hard-coded .cuda().
    data_range = torch.tensor(1.0).unsqueeze(0).to(args.device)

    start_epoch = start_iter = time.perf_counter()
    global_step = epoch * len(data_loader)

    for step, data in enumerate(tqdm(data_loader)):
        mag_us, _, pha_us, _, ksp_us, _, _, _, img_gt_np, sens, mask, _, _ = data

        inp_mag = mag_us.unsqueeze(1).to(args.device)
        inp_pha = pha_us.unsqueeze(1).to(args.device)
        img_gt_np = img_gt_np.unsqueeze(1).float().to(args.device)
        ksp_us = ksp_us.to(args.device)
        sens = sens.to(args.device)
        mask = mask.to(args.device)

        out_mag = model_mag(inp_mag)
        out_pha = model_pha(inp_pha)

        # Results are unused below; the calls are kept from the original
        # because T.unpad is defined outside this block.
        out_mag_unpad = T.unpad(out_mag, ksp_us)
        out_pha_unpad = T.unpad(out_pha, ksp_us)
        out_cmplx = T.dc(out_mag, out_pha, ksp_us, sens, mask)
        out_cmplx = model_vs(out_cmplx, ksp_us, sens, mask)

        inp_dun = T.inp_to_dun(out_cmplx).float().to(args.device)
        out_dun = model_dun(inp_dun)

        loss_dun = ssim_loss(out_dun, img_gt_np, data_range)

        optimizer_dun.zero_grad()
        loss_dun.backward()
        optimizer_dun.step()

        avg_loss_dun = 0.99 * avg_loss_dun + 0.01 * loss_dun.item() if step > 0 else loss_dun.item()

        writer.add_scalar('TrainLoss_cmplx', loss_dun.item(), global_step + step)

        if step % args.report_interval == 0:
            logging.info(
                f'Epoch = [{epoch:3d}/{args.num_epochs:3d}] '
                f'Iter = [{step:4d}/{len(data_loader):4d}] '
                f'Loss_dun = {loss_dun.item():.4g} Avg Loss dun = {avg_loss_dun:.4g} '
                f'Time = {time.perf_counter() - start_iter:.4f}s',
            )
        # Reset every iteration, so the logged time covers one iteration.
        start_iter = time.perf_counter()

    return avg_loss_dun, time.perf_counter() - start_epoch