示例#1
0
文件: main.py 项目: whjzsy/CAIN
def train(args, epoch):
    """Run one training epoch over the module-level ``train_loader``.

    Relies on module globals ``model``, ``criterion``, ``optimizer``,
    ``train_loader``, ``lpips_model`` and ``writer``. Every ``args.log_iter``
    iterations it computes PSNR/SSIM/LPIPS on the current batch, prints a
    progress line, logs to TensorBoard and resets the running meters.

    An iteration whose loss exceeds 10x the first observed loss (``LOSS_0``)
    is skipped: gradients are computed but no optimizer step is taken.

    Args:
        args: parsed options; reads ``loss`` and ``log_iter``.
        epoch: current epoch index, used in the log line and TensorBoard step.
    """
    global LOSS_0
    losses, psnrs, ssims, lpips = utils.init_meters(args.loss)
    model.train()
    criterion.train()

    t = time.time()
    for i, (images, imgpaths) in enumerate(train_loader):

        # Build input batch
        im1, im2, gt = utils.build_input(images, imgpaths)

        # Forward
        optimizer.zero_grad()
        out, feats = model(im1, im2)
        loss, loss_specific = criterion(out, gt, None, feats)
        loss_value = loss.item()  # read the scalar once

        # Save loss values
        for k, v in losses.items():
            if k != 'total':
                v.update(loss_specific[k].item())
        # The first observed loss becomes the reference for the explosion check.
        if LOSS_0 == 0:
            LOSS_0 = loss_value
        losses['total'].update(loss_value)

        # Backward (+ grad clip) - if loss explodes, skip current iteration
        loss.backward()
        if loss_value > 10.0 * LOSS_0:
            # Diagnostic: largest absolute gradient entry. Parameters that got
            # no gradient (p.grad is None) are ignored rather than crashing.
            grad_maxes = [p.grad.detach().abs().max()
                          for p in model.parameters() if p.grad is not None]
            print(max(grad_maxes) if grad_maxes else None)
            continue
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()

        # Calc metrics & print logs
        if i % args.log_iter == 0:
            utils.eval_metrics(out, gt, psnrs, ssims, lpips, lpips_model)

            print('Train Epoch: {} [{}/{}]\tLoss: {:.6f}\tPSNR: {:.4f}\tTime({:.2f})'.format(
                epoch, i, len(train_loader), losses['total'].avg, psnrs.avg, time.time() - t))

            # Log to TensorBoard
            utils.log_tensorboard(writer, losses, psnrs.avg, ssims.avg, lpips.avg,
                optimizer.param_groups[-1]['lr'], epoch * len(train_loader) + i)

            # Reset metrics so the next log reports only the following window
            losses, psnrs, ssims, lpips = utils.init_meters(args.loss)
            t = time.time()
示例#2
0
文件: main.py 项目: whjzsy/CAIN
def test(args, epoch, eval_alpha=0.5):
    """Evaluate the model on the module-level ``test_loader`` and record results.

    Relies on module globals ``model``, ``criterion``, ``test_loader``,
    ``train_loader``, ``lpips_model``, ``optimizer`` and ``writer``.
    Predictions for the first 20 batches (or all batches when
    ``args.mode == 'test'``) are saved as PNGs under
    ``checkpoint/<exp_name>/test<epoch>/<dataset>[/<test_mode>]/``. The
    averaged metrics are appended to ``results.txt`` in that folder.

    Args:
        args: parsed options; reads ``loss``, ``dataset``, ``test_mode``,
            ``exp_name`` and ``mode``.
        epoch: current epoch index.
        eval_alpha: unused here; kept for interface compatibility.

    Returns:
        tuple: (avg total loss, avg PSNR, avg SSIM, avg LPIPS).
    """
    print('Evaluating for epoch = %d' % epoch)
    losses, psnrs, ssims, lpips = utils.init_meters(args.loss)
    model.eval()
    criterion.eval()

    save_folder = 'test%03d' % epoch
    if args.dataset == 'snufilm':
        save_folder = os.path.join(save_folder, args.dataset, args.test_mode)
    else:
        save_folder = os.path.join(save_folder, args.dataset)
    save_dir = os.path.join('checkpoint', args.exp_name, save_folder)
    utils.makedirs(save_dir)
    save_fn = os.path.join(save_dir, 'results.txt')
    if not os.path.exists(save_fn):
        with open(save_fn, 'w') as f:
            f.write('For epoch=%d\n' % epoch)

    t = time.time()
    # Index of the last processed batch; stays -1 if test_loader is empty so
    # the summary code below does not hit an unbound name.
    i = -1
    with torch.no_grad():
        for i, (images, imgpaths) in enumerate(tqdm(test_loader)):

            # Build input batch
            im1, im2, gt = utils.build_input(images, imgpaths, is_training=False)

            # Forward
            out, feats = model(im1, im2)

            # Save loss values
            loss, loss_specific = criterion(out, gt, None, feats)
            for k, v in losses.items():
                if k != 'total':
                    v.update(loss_specific[k].item())
            losses['total'].update(loss.item())

            # Evaluate metrics; pass lpips_model (as train() does) so that the
            # reported LPIPS is actually computed.
            utils.eval_metrics(out, gt, psnrs, ssims, lpips, lpips_model)

            # Log examples that have bad performance
            if (ssims.val < 0.9 or psnrs.val < 25) and epoch > 50:
                print(imgpaths)
                print("\nLoss: %f, PSNR: %f, SSIM: %f, LPIPS: %f" %
                      (losses['total'].val, psnrs.val, ssims.val, lpips.val))
                print(imgpaths[1][-1])

            # Save result images
            if i < 20 or args.mode == 'test':
                for b in range(images[0].size(0)):
                    # Mirror the last two directory levels of the GT path.
                    # NOTE(review): assumes '/'-separated paths with at least
                    # three components.
                    paths = imgpaths[1][b].split('/')
                    fp = os.path.join(save_dir, paths[-3], paths[-2])
                    os.makedirs(fp, exist_ok=True)
                    # Replace the original extension with '.png'
                    stem = os.path.splitext(paths[-1])[0]
                    utils.save_image(out[b], "%s.png" % os.path.join(fp, stem))

    # Print progress
    print('im_processed: {:d}/{:d} {:.3f}s   \r'.format(i + 1, len(test_loader), time.time() - t))
    print("Loss: %f, PSNR: %f, SSIM: %f, LPIPS: %f\n" %
          (losses['total'].avg, psnrs.avg, ssims.avg, lpips.avg))

    # Save psnr & ssim
    with open(save_fn, 'a') as f:
        f.write("PSNR: %f, SSIM: %f, LPIPS: %f\n" %
                (psnrs.avg, ssims.avg, lpips.avg))

    # Log to TensorBoard
    # NOTE(review): the step combines train_loader length with the last *test*
    # batch index; kept as in the original.
    if args.mode != 'test':
        utils.log_tensorboard(writer, losses, psnrs.avg, ssims.avg, lpips.avg,
            optimizer.param_groups[-1]['lr'], epoch * len(train_loader) + max(i, 0), mode='test')

    return losses['total'].avg, psnrs.avg, ssims.avg, lpips.avg
示例#3
0
def train(args, epoch):
    """Run one training epoch with a tqdm progress bar and optional GUI updates.

    GUI variant of the training loop. Besides forward/backward/step, it:
    - publishes ``psnrs``, ``out``, ``gt`` and the global iteration counter
      ``it`` as module globals (presumably read by ``updategui`` — TODO confirm);
    - calls ``QApplication.processEvents()`` every iteration;
    - at every logging step, when ``args.gui == "True"``, starts a thread
      running ``updategui``.

    An iteration whose loss exceeds 10x the first observed loss (``LOSS_0``)
    is skipped: gradients are computed but no optimizer step is taken.

    Args:
        args: parsed options; reads ``loss``, ``log_iter`` and ``gui``.
        epoch: current epoch index, used in the bar postfix and TensorBoard step.
    """
    ### GUI things
    global psnrs
    global out
    global gt
    global it
    global apptr
    ### progress bar
    global startedpbar
    startedpbar = 0
    ### loss
    global LOSS_0

    losses, psnrs, ssims, lpips = utils.init_meters(args.loss)
    model.train()
    criterion.train()

    # One bar per epoch, advanced once per batch. It is closed on exit so bars
    # from successive epochs do not accumulate.
    with tqdm(total=len(train_loader)) as pbar:
        startedpbar = 1
        for i, (images, imgpaths) in enumerate(train_loader):

            # Build input batch
            im1, im2, gt = utils.build_input(images, imgpaths)

            # Forward
            optimizer.zero_grad()
            out, feats = model(im1, im2)
            it += 1

            loss, loss_specific = criterion(out, gt, None, feats)
            loss_value = loss.item()  # read the scalar once
            # Keep the Qt GUI responsive while training runs in this thread.
            QApplication.processEvents()

            # Save loss values
            for k, v in losses.items():
                if k != 'total':
                    v.update(loss_specific[k].item())
            # The first observed loss becomes the reference for the explosion check.
            if LOSS_0 == 0:
                LOSS_0 = loss_value
            losses['total'].update(loss_value)

            # Backward (+ grad clip) - if loss explodes, skip current iteration
            loss.backward()
            # Count every batch, including ones skipped below.
            pbar.update(1)
            if loss_value > 10.0 * LOSS_0:
                # Diagnostic: largest absolute gradient entry. Parameters that
                # got no gradient (p.grad is None) are ignored rather than crashing.
                grad_maxes = [p.grad.detach().abs().max()
                              for p in model.parameters() if p.grad is not None]
                print(max(grad_maxes) if grad_maxes else None)
                continue
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
            optimizer.step()

            # Calc metrics & print logs
            if i % args.log_iter == 0:
                utils.eval_metrics(out, gt, psnrs, ssims, lpips, lpips_model)

                pbar.set_postfix({'psnr': psnrs.avg, 'Loss': losses['total'].avg, 'epoch': epoch, 'iterations': it })
                # Log to TensorBoard
                utils.log_tensorboard(writer, losses, psnrs.avg, ssims.avg, lpips.avg,
                    optimizer.param_groups[-1]['lr'], epoch * len(train_loader) + i)

                # update gui
                if args.gui == "True":
                    tgui = threading.Thread(target=updategui)
                    tgui.start()
                # Reset metrics.
                # NOTE(review): this rebinds the global ``psnrs`` right after the
                # GUI thread starts; if updategui reads ``psnrs`` it may observe
                # the freshly reset meter. Confirm against updategui.
                losses, psnrs, ssims, lpips = utils.init_meters(args.loss)