def train(args, epoch):
    """Run one epoch of training.

    Iterates over ``train_loader``, computes the criterion loss, steps the
    optimizer, and every ``args.log_iter`` batches prints progress and logs
    loss/PSNR/SSIM/lr to TensorBoard, then resets the running meters.

    Args:
        args: namespace with at least ``loss`` (criterion spec for the
            meters) and ``log_iter`` (logging period in batches).
        epoch: current epoch index, used for the TensorBoard global step.
    """
    losses, psnrs, ssims = myutils.init_meters(args.loss)
    model.train()
    criterion.train()

    t = time.time()
    for i, (images, gt_image) in enumerate(train_loader):
        # Build input batch — move every frame (and the ground truth) to GPU
        images = [img_.cuda() for img_ in images]
        gt = [gt_.cuda() for gt_ in gt_image]

        # Forward
        optimizer.zero_grad()
        out = model(images)
        out = torch.cat(out)
        gt = torch.cat(gt)
        loss, loss_specific = criterion(out, gt)

        # Save loss values (per-component meters plus the total)
        for k, v in losses.items():
            if k != 'total':
                v.update(loss_specific[k].item())
        losses['total'].update(loss.item())

        loss.backward()
        optimizer.step()

        # Calc metrics & print logs
        if i % args.log_iter == 0:
            myutils.eval_metrics(out, gt, psnrs, ssims)

            # BUG FIX: flush=True was previously passed to str.format()
            # (where extra kwargs are silently ignored), so stdout was
            # never actually flushed. It belongs to print().
            print('Train Epoch: {} [{}/{}]\tLoss: {:.6f}\tPSNR: {:.4f}'.format(
                epoch, i, len(train_loader), losses['total'].avg, psnrs.avg),
                flush=True)

            # Log to TensorBoard
            timestep = epoch * len(train_loader) + i
            writer.add_scalar('Loss/train', loss.item(), timestep)
            writer.add_scalar('PSNR/train', psnrs.avg, timestep)
            writer.add_scalar('SSIM/train', ssims.avg, timestep)
            writer.add_scalar('lr', optimizer.param_groups[-1]['lr'], timestep)

            # Reset metrics so each log window reports a fresh average
            losses, psnrs, ssims = myutils.init_meters(args.loss)
            t = time.time()
def test(args):
    """Evaluate the model on ``test_loader`` and report PSNR/SSIM and timing.

    Runs inference without gradients, timing each forward pass with CUDA
    synchronization so GPU work is fully counted.

    Args:
        args: namespace with ``loss`` (criterion spec for the meters).

    Returns:
        Average PSNR over the test set.
    """
    time_taken = []
    losses, psnrs, ssims = myutils.init_meters(args.loss)
    model.eval()

    with torch.no_grad():
        for i, (images, gt_image) in enumerate(tqdm(test_loader)):
            images = [img_.cuda() for img_ in images]
            gt = [g_.cuda() for g_ in gt_image]

            # Synchronize before/after so the wall-clock time covers the
            # whole (asynchronous) CUDA forward pass.
            torch.cuda.synchronize()
            start_time = time.time()
            out = model(images)
            out = torch.cat(out)
            gt = torch.cat(gt)
            torch.cuda.synchronize()
            time_taken.append(time.time() - start_time)

            myutils.eval_metrics(out, gt, psnrs, ssims)

    # BUG FIX: the format string was "%fn" — it printed a literal 'n'
    # instead of ending the line with a newline ("\n").
    print("PSNR: %f, SSIM: %f\n" % (psnrs.avg, ssims.avg))
    # Guard against an empty loader to avoid ZeroDivisionError.
    if time_taken:
        print("Time , ", sum(time_taken) / len(time_taken))
    return psnrs.avg
def test(args, epoch):
    """Evaluate the model for one validation epoch.

    Computes criterion loss and PSNR/SSIM over ``test_loader``, appends the
    epoch results to ``results.txt`` under ``save_loc``, and logs the epoch
    averages to TensorBoard.

    Args:
        args: namespace with ``loss`` (criterion spec for the meters).
        epoch: epoch index being evaluated (used for logging/bookkeeping).

    Returns:
        Tuple of (average loss, average PSNR, average SSIM) for the epoch.
    """
    print('Evaluating for epoch = %d' % epoch)
    losses, psnrs, ssims = myutils.init_meters(args.loss)
    model.eval()
    criterion.eval()

    with torch.no_grad():
        for i, (images, gt_image) in enumerate(tqdm(test_loader)):
            images = [img_.cuda() for img_ in images]
            gt = [gt_.cuda() for gt_ in gt_image]

            out = model(images)  # images is a list of neighboring frames
            out = torch.cat(out)
            gt = torch.cat(gt)

            # Save loss values (per-component meters plus the total)
            loss, loss_specific = criterion(out, gt)
            for k, v in losses.items():
                if k != 'total':
                    v.update(loss_specific[k].item())
            losses['total'].update(loss.item())

            # Evaluate metrics
            myutils.eval_metrics(out, gt, psnrs, ssims)

    # Print progress
    print("Loss: %f, PSNR: %f, SSIM: %f\n" %
          (losses['total'].avg, psnrs.avg, ssims.avg))

    # Save psnr & ssim
    save_fn = os.path.join(save_loc, 'results.txt')
    with open(save_fn, 'a') as f:
        f.write('For epoch=%d\t' % epoch)
        f.write("PSNR: %f, SSIM: %f\n" % (psnrs.avg, ssims.avg))

    # Log to TensorBoard
    timestep = epoch + 1
    # BUG FIX: previously logged only the LAST batch's loss
    # (loss.data.item()), which disagreed with the printed/saved/returned
    # epoch average. Log the epoch-average loss instead.
    writer.add_scalar('Loss/test', losses['total'].avg, timestep)
    writer.add_scalar('PSNR/test', psnrs.avg, timestep)
    writer.add_scalar('SSIM/test', ssims.avg, timestep)

    return losses['total'].avg, psnrs.avg, ssims.avg
def test(args):
    """Run inference on the Middlebury folders and save interpolated frames.

    For each test sample whose folder name is in ``folderList``, resizes the
    input frames so both spatial dims are divisible by 8 (a model
    requirement, presumably — inferred from the rounding below; TODO
    confirm), runs the model, resizes the output back, and writes the
    interpolated middle frame to ``Middleburry/<name>/frame10i11.png``.

    Args:
        args: unused here; kept for signature consistency with the other
            ``test`` entry points.
    """
    # Hoisted out of the loop: previously re-imported on every iteration
    # (the loop's `import torchvision` was dead code — Resize comes from a
    # module-level import).
    import imageio

    model.eval()
    with torch.no_grad():
        for i, (images, name) in enumerate(test_loader):
            if name[0] not in folderList:
                continue

            # (B, T, C, H, W) -> (T, C, H, W) so Resize acts per frame
            images = torch.stack(images, dim=1).squeeze(0)
            H, W = images[0].shape[-2:]
            # Round dims down to the nearest multiple of 8.
            resizes = 8 * (H // 8), 8 * (W // 8)
            transform = Resize(resizes)
            rev_transforms = Resize((H, W))

            images = transform(images).unsqueeze(0).cuda()
            images = torch.unbind(images, dim=1)

            start_time = time.time()
            out = model(images)
            # NOTE(review): no torch.cuda.synchronize() around this timing,
            # so it may undercount asynchronous GPU work — confirm intent.
            print("Time Taken", time.time() - start_time)

            out = torch.cat(out)
            out = rev_transforms(out)
            output_image = make_image(out.squeeze(0))

            # BUG FIX: exist_ok=True — previously raised FileExistsError
            # whenever the script was re-run on the same folder.
            os.makedirs("Middleburry/%s/" % name[0], exist_ok=True)
            imageio.imwrite("Middleburry/%s/frame10i11.png" % name[0],
                            output_image)
    return