def train(args, logger, device_ids):
    writer = SummaryWriter()

    logger.info("Loading network")
    model = AdaMatting(in_channel=4)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=0.0001)
    if args.resume != "":
        ckpt = torch.load(args.resume)
        model.load_state_dict(ckpt["state_dict"])
        optimizer.load_state_dict(ckpt["optimizer"])

    if args.cuda:
        # Move any optimizer state restored from the checkpoint onto the GPU
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()
        device = torch.device("cuda:{}".format(device_ids[0]))
        if len(device_ids) > 1:
            logger.info("Loading with multiple GPUs")
            model = torch.nn.DataParallel(model, device_ids=device_ids)
            # model = convert_model(model)
    else:
        device = torch.device("cpu")
    model = model.to(device)

    logger.info("Initializing data loaders")
    train_dataset = AdaMattingDataset(args.raw_data_path, "train")
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                               num_workers=16, pin_memory=True, drop_last=True)
    valid_dataset = AdaMattingDataset(args.raw_data_path, "valid")
    valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False,
                                               num_workers=16, pin_memory=True, drop_last=True)

    if args.resume != "":
        logger.info("Start training from saved ckpt")
        start_epoch = ckpt["epoch"] + 1
        cur_iter = ckpt["cur_iter"]
        peak_lr = ckpt["peak_lr"]
        best_loss = ckpt["best_loss"]
        best_alpha_loss = ckpt["best_alpha_loss"]
    else:
        logger.info("Start training from scratch")
        start_epoch = 0
        cur_iter = 0
        peak_lr = args.lr
        best_loss = float("inf")
        best_alpha_loss = float("inf")

    max_iter = 43100 * (1 - args.valid_portion / 100) / args.batch_size * args.epochs
    # Scale the tensorboard x-axis so runs with different batch sizes line up (16 is the reference batch size)
    tensorboard_iter = cur_iter * (args.batch_size / 16)

    avg_lo = AverageMeter()
    avg_lt = AverageMeter()
    avg_la = AverageMeter()
    for epoch in range(start_epoch, args.epochs):
        # Training
        torch.set_grad_enabled(True)
        model.train()
        for index, (_, inputs, gts) in enumerate(train_loader):
            cur_lr = lr_scheduler(optimizer=optimizer, init_lr=args.lr, cur_iter=cur_iter,
                                  max_iter=max_iter, max_decay_times=30, decay_rate=0.9)

            inputs = inputs.to(device)                                                      # [bs, 4, 320, 320]
            gt_alpha = gts[:, 0, :, :].unsqueeze(1).type(torch.FloatTensor).to(device)      # [bs, 1, 320, 320]
            gt_trimap = gts[:, 1, :, :].type(torch.LongTensor).to(device)                   # [bs, 320, 320]

            trimap_adaption, t_argmax, alpha_estimation, log_sigma_t_sqr, log_sigma_a_sqr = model(inputs)
            L_overall, L_t, L_a = task_uncertainty_loss(pred_trimap=trimap_adaption, pred_trimap_argmax=t_argmax,
                                                        pred_alpha=alpha_estimation, gt_trimap=gt_trimap,
                                                        gt_alpha=gt_alpha, log_sigma_t_sqr=log_sigma_t_sqr,
                                                        log_sigma_a_sqr=log_sigma_a_sqr)
            sigma_t = torch.exp(log_sigma_t_sqr.mean() / 2)
            sigma_a = torch.exp(log_sigma_a_sqr.mean() / 2)

            optimizer.zero_grad()
            L_overall.backward()
            clip_gradient(optimizer, 5)
            optimizer.step()

            avg_lo.update(L_overall.item())
            avg_lt.update(L_t.item())
            avg_la.update(L_a.item())

            if cur_iter % 10 == 0:
                logger.info("Epoch: {:03d} | Iter: {:05d}/{} | Loss: {:.4e} | L_t: {:.4e} | L_a: {:.4e}"
                            .format(epoch, index, len(train_loader), avg_lo.avg, avg_lt.avg, avg_la.avg))
                writer.add_scalar("loss/L_overall", avg_lo.avg, tensorboard_iter)
                writer.add_scalar("loss/L_t", avg_lt.avg, tensorboard_iter)
                writer.add_scalar("loss/L_a", avg_la.avg, tensorboard_iter)
                writer.add_scalar("other/sigma_t", sigma_t.item(), tensorboard_iter)
                writer.add_scalar("other/sigma_a", sigma_a.item(), tensorboard_iter)
                writer.add_scalar("other/lr", cur_lr, tensorboard_iter)
                avg_lo.reset()
                avg_lt.reset()
                avg_la.reset()

            cur_iter += 1
            tensorboard_iter = cur_iter * (args.batch_size / 16)

        # Validation
        logger.info("Validating after the {}th epoch".format(epoch))
        avg_loss = AverageMeter()
        avg_l_t = AverageMeter()
        avg_l_a = AverageMeter()
        torch.cuda.empty_cache()
        torch.set_grad_enabled(False)
        model.eval()
        with tqdm(total=len(valid_loader)) as pbar:
            for index, (display_rgb, inputs, gts) in enumerate(valid_loader):
                inputs = inputs.to(device)                                                  # [bs, 4, 320, 320]
                gt_alpha = gts[:, 0, :, :].unsqueeze(1).type(torch.FloatTensor).to(device)  # [bs, 1, 320, 320]
                gt_trimap = gts[:, 1, :, :].type(torch.LongTensor).to(device)               # [bs, 320, 320]

                trimap_adaption, t_argmax, alpha_estimation, log_sigma_t_sqr, log_sigma_a_sqr = model(inputs)
                L_overall_valid, L_t_valid, L_a_valid = task_uncertainty_loss(pred_trimap=trimap_adaption,
                                                                              pred_trimap_argmax=t_argmax,
                                                                              pred_alpha=alpha_estimation,
                                                                              gt_trimap=gt_trimap, gt_alpha=gt_alpha,
                                                                              log_sigma_t_sqr=log_sigma_t_sqr,
                                                                              log_sigma_a_sqr=log_sigma_a_sqr)

                avg_loss.update(L_overall_valid.item())
                avg_l_t.update(L_t_valid.item())
                avg_l_a.update(L_a_valid.item())

                if index == 0:
                    # Log one batch of qualitative results to tensorboard
                    input_rgb = torchvision.utils.make_grid(display_rgb, normalize=False, scale_each=True)
                    writer.add_image("input/rgb_image", input_rgb, tensorboard_iter)

                    input_trimap = inputs[:, 3, :, :].unsqueeze(dim=1)
                    input_trimap = torchvision.utils.make_grid(input_trimap, normalize=False, scale_each=True)
                    writer.add_image("input/trimap", input_trimap, tensorboard_iter)

                    # Final alpha: force definite background/foreground according to the adapted trimap
                    output_alpha = alpha_estimation.clone()
                    output_alpha[t_argmax.unsqueeze(dim=1) == 0] = 0.0
                    output_alpha[t_argmax.unsqueeze(dim=1) == 2] = 1.0
                    output_alpha = torchvision.utils.make_grid(output_alpha, normalize=False, scale_each=True)
                    writer.add_image("output/alpha", output_alpha, tensorboard_iter)

                    trimap_adaption_res = (t_argmax.type(torch.FloatTensor) / 2).unsqueeze(dim=1)
                    trimap_adaption_res = torchvision.utils.make_grid(trimap_adaption_res, normalize=False,
                                                                      scale_each=True)
                    writer.add_image("pred/trimap_adaptation", trimap_adaption_res, tensorboard_iter)

                    alpha_estimation_res = torchvision.utils.make_grid(alpha_estimation, normalize=False,
                                                                       scale_each=True)
                    writer.add_image("pred/alpha_estimation", alpha_estimation_res, tensorboard_iter)

                    gt_alpha_res = torchvision.utils.make_grid(gt_alpha, normalize=False, scale_each=True)
                    writer.add_image("gt/alpha", gt_alpha_res, tensorboard_iter)

                    gt_trimap_res = (gt_trimap.type(torch.FloatTensor) / 2).unsqueeze(dim=1)
                    gt_trimap_res = torchvision.utils.make_grid(gt_trimap_res, normalize=False, scale_each=True)
                    writer.add_image("gt/trimap", gt_trimap_res, tensorboard_iter)

                pbar.update()

        logger.info("Average loss overall: {:.4e}".format(avg_loss.avg))
        logger.info("Average loss of trimap adaptation: {:.4e}".format(avg_l_t.avg))
        logger.info("Average loss of alpha estimation: {:.4e}".format(avg_l_a.avg))
        writer.add_scalar("valid_loss/L_overall", avg_loss.avg, tensorboard_iter)
        writer.add_scalar("valid_loss/L_t", avg_l_t.avg, tensorboard_iter)
        writer.add_scalar("valid_loss/L_a", avg_l_a.avg, tensorboard_iter)

        is_best = avg_loss.avg < best_loss
        best_loss = min(avg_loss.avg, best_loss)
        is_alpha_best = avg_l_a.avg < best_alpha_loss
        best_alpha_loss = min(avg_l_a.avg, best_alpha_loss)
        if is_best or is_alpha_best or args.save_ckpt:
            if not os.path.exists("ckpts"):
                os.makedirs("ckpts")
            save_checkpoint(ckpt_path=args.ckpt_path, is_best=is_best, is_alpha_best=is_alpha_best, logger=logger,
                            model=model, optimizer=optimizer, epoch=epoch, cur_iter=cur_iter, peak_lr=peak_lr,
                            best_loss=best_loss, best_alpha_loss=best_alpha_loss)

    writer.close()
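
# The training loop above calls task_uncertainty_loss(), which is defined elsewhere in the repo and not shown
# in this snippet. The sketch below only illustrates the expected interface, assuming the loss follows the
# multi-task uncertainty weighting used by AdaMatting (Kendall-style):
#     L = L_t / (2 * sigma_t^2) + L_a / (2 * sigma_a^2) + log(sigma_t) + log(sigma_a)
# with cross-entropy for the trimap branch and an L1 alpha loss restricted to the predicted unknown region.
# The repo's actual implementation may mask or weight the terms differently; treat this as a hypothetical sketch.
def task_uncertainty_loss_sketch(pred_trimap, pred_trimap_argmax, pred_alpha, gt_trimap, gt_alpha,
                                 log_sigma_t_sqr, log_sigma_a_sqr, eps=1e-6):
    # Trimap adaptation: 3-way classification (0 = background, 1 = unknown, 2 = foreground)
    L_t = torch.nn.functional.cross_entropy(pred_trimap, gt_trimap)
    # Alpha estimation: L1 error, averaged over pixels the adapted trimap marks as unknown
    unknown = (pred_trimap_argmax == 1).unsqueeze(1).float()
    L_a = (torch.abs(pred_alpha - gt_alpha) * unknown).sum() / (unknown.sum() + eps)
    # Combine with learned homoscedastic uncertainties (log sigma^2 are scalars or per-sample tensors)
    lst, lsa = log_sigma_t_sqr.mean(), log_sigma_a_sqr.mean()
    L_overall = 0.5 * (L_t * torch.exp(-lst) + L_a * torch.exp(-lsa)) + 0.5 * (lst + lsa)
    return L_overall, L_t, L_a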
def train(args, model, optimizer, train_loader, epoch, logger):
    t0 = time.time()
    model.train()
    # fout = open("train_loss.txt", 'w')
    for iteration, batch in enumerate(train_loader, 1):
        torch.cuda.empty_cache()
        img = Variable(batch[0])
        alpha = Variable(batch[1])
        fg = Variable(batch[2])
        bg = Variable(batch[3])
        trimap = Variable(batch[4])
        img_norm = Variable(batch[6])
        gts = Variable(batch[7])
        img_info = batch[-1]
        if args.cuda:
            img = img.cuda()
            gt_alpha = gts[:, 0, :, :].unsqueeze(1).type(torch.FloatTensor).cuda()  # [bs, 1, 320, 320]
            gt_trimap = gts[:, 1, :, :].type(torch.LongTensor).cuda()               # [bs, 320, 320]
            alpha = alpha.cuda()
            fg = fg.cuda()
            bg = bg.cuda()
            trimap = trimap.cuda()
            img_norm = img_norm.cuda()

        # Dump the ground truths for visual debugging
        for i in range(gt_alpha.size(0)):
            torchvision.utils.save_image(gt_alpha[i, :, :, :], '{}_gt_alpha.png'.format(i))
        for i in range(gt_trimap.size(0)):
            torchvision.utils.save_image(gt_trimap[i, :, :].unsqueeze(0).float(), '{}_gt_trimap.png'.format(i))

        lr_scheduler(args, optimizer=optimizer, init_lr=args.lr, cur_iter=args.cur_iter,
                     max_decay_times=40, decay_rate=0.9)

        trimap_adaption, t_argmax, alpha_estimation, log_sigma_t_sqr, log_sigma_a_sqr = model(
            torch.cat((img_norm, trimap / 255.), 1))

        # Dump the predictions for visual debugging
        for i in range(alpha_estimation.size(0)):
            torchvision.utils.save_image(alpha_estimation[i, :, :, :], '{}_pred_alpha.png'.format(i))
        for i in range(trimap_adaption.size(0)):
            torchvision.utils.save_image(trimap_adaption[i, :, :, :], '{}_pred_trimap.png'.format(i))

        L_overall, L_t, L_a = task_uncertainty_loss(pred_trimap=trimap_adaption, input_trimap_argmax=trimap,
                                                    pred_alpha=alpha_estimation, gt_trimap=gt_trimap,
                                                    gt_alpha=gt_alpha, log_sigma_t_sqr=log_sigma_t_sqr,
                                                    log_sigma_a_sqr=log_sigma_a_sqr)
        sigma_t = torch.exp(log_sigma_t_sqr.mean() / 2)
        sigma_a = torch.exp(log_sigma_a_sqr.mean() / 2)

        optimizer.zero_grad()
        L_overall.backward()
        optimizer.step()

        if args.cur_iter % args.printFreq == 0:
            t1 = time.time()
            num_iter = len(train_loader)
            speed = (t1 - t0) / iteration
            # exp_time = format_second(speed * (num_iter * (args.epochs - epoch + 1) - iteration))
            logger.info("Epoch: {:03d} | Iter: {:05d}/{} | Loss: {:.4e} | L_t: {:.4e} | L_a: {:.4e}"
                        .format(epoch, args.cur_iter, len(train_loader), L_overall.item(), L_t.item(), L_a.item()))
        args.cur_iter += 1
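
# lr_scheduler() is imported from a utility module that is not part of this snippet. Judging from the keyword
# arguments used above (init_lr, cur_iter, max_iter, max_decay_times, decay_rate), it appears to be a stepwise
# decay that multiplies the learning rate by decay_rate every max_iter / max_decay_times iterations. The sketch
# below is a hypothetical implementation under that assumption, not the repo's actual helper.
def lr_scheduler_sketch(optimizer, init_lr, cur_iter, max_iter, max_decay_times, decay_rate):
    decay_every = max(1, int(max_iter / max_decay_times))
    cur_lr = init_lr * (decay_rate ** (int(cur_iter) // decay_every))
    for param_group in optimizer.param_groups:
        param_group["lr"] = cur_lr
    return cur_lr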
def train(model, optimizer, device, args, logger, multi_gpu):
    torch.manual_seed(7)
    writer = SummaryWriter()

    logger.info("Initializing data loaders")
    train_dataset = AdaMattingDataset(args.raw_data_path, 'train')
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                               num_workers=16, pin_memory=True)
    valid_dataset = AdaMattingDataset(args.raw_data_path, 'valid')
    valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False,
                                               num_workers=16, pin_memory=True)

    if args.resume:
        logger.info("Start training from saved ckpt")
        ckpt = torch.load(args.ckpt_path)
        model = ckpt["model"].module
        model = model.to(device)
        optimizer = ckpt["optimizer"]
        start_epoch = ckpt["epoch"] + 1
        max_iter = ckpt["max_iter"]
        cur_iter = ckpt["cur_iter"]
        init_lr = ckpt["init_lr"]
        best_loss = ckpt["best_loss"]
    else:
        logger.info("Start training from scratch")
        start_epoch = 0
        max_iter = 43100 * (1 - args.valid_portion) / args.batch_size * args.epochs
        cur_iter = 0
        init_lr = args.lr
        best_loss = float("inf")

    for epoch in range(start_epoch, args.epochs):
        # Training
        torch.set_grad_enabled(True)
        model.train()
        for index, (img, gt) in enumerate(train_loader):
            cur_lr = poly_lr_scheduler(optimizer=optimizer, init_lr=init_lr, iter=cur_iter, max_iter=max_iter)

            img = img.type(torch.FloatTensor).to(device)                               # [bs, 4, 320, 320]
            gt_alpha = gt[:, 0, :, :].unsqueeze(1).type(torch.FloatTensor).to(device)  # [bs, 1, 320, 320]
            gt_trimap = gt[:, 1, :, :].type(torch.LongTensor).to(device)               # [bs, 320, 320]

            trimap_adaption, t_argmax, alpha_estimation = model(img)
            L_overall, L_t, L_a = task_uncertainty_loss(pred_trimap=trimap_adaption, pred_trimap_argmax=t_argmax,
                                                        pred_alpha=alpha_estimation, gt_trimap=gt_trimap,
                                                        gt_alpha=gt_alpha, log_sigma_t_sqr=model.log_sigma_t_sqr,
                                                        log_sigma_a_sqr=model.log_sigma_a_sqr)
            # if multi_gpu:
            #     L_overall, L_t, L_a = L_overall.mean(), L_t.mean(), L_a.mean()

            optimizer.zero_grad()
            L_overall.backward()
            optimizer.step()

            if cur_iter % 10 == 0:
                logger.info("Epoch: {:03d} | Iter: {:05d}/{} | Loss: {:.4e} | L_t: {:.4e} | L_a: {:.4e}"
                            .format(epoch, index, len(train_loader), L_overall.item(), L_t.item(), L_a.item()))
                writer.add_scalar("loss/L_overall", L_overall.item(), cur_iter)
                writer.add_scalar("loss/L_t", L_t.item(), cur_iter)
                writer.add_scalar("loss/L_a", L_a.item(), cur_iter)
                sigma_t = torch.exp(model.log_sigma_t_sqr / 2)
                sigma_a = torch.exp(model.log_sigma_a_sqr / 2)
                writer.add_scalar("sigma/sigma_t", sigma_t, cur_iter)
                writer.add_scalar("sigma/sigma_a", sigma_a, cur_iter)
                writer.add_scalar("lr", cur_lr, cur_iter)
            cur_iter += 1

        # Validation
        logger.info("Validating after the {}th epoch".format(epoch))
        avg_loss = AverageMeter()
        avg_l_t = AverageMeter()
        avg_l_a = AverageMeter()
        torch.cuda.empty_cache()
        torch.set_grad_enabled(False)
        model.eval()
        with tqdm(total=len(valid_loader)) as pbar:
            for index, (img, gt) in enumerate(valid_loader):
                img = img.type(torch.FloatTensor).to(device)                               # [bs, 4, 320, 320]
                gt_alpha = gt[:, 0, :, :].unsqueeze(1).type(torch.FloatTensor).to(device)  # [bs, 1, 320, 320]
                gt_trimap = gt[:, 1, :, :].type(torch.LongTensor).to(device)               # [bs, 320, 320]

                trimap_adaption, t_argmax, alpha_estimation = model(img)
                L_overall_valid, L_t_valid, L_a_valid = task_uncertainty_loss(pred_trimap=trimap_adaption,
                                                                              pred_trimap_argmax=t_argmax,
                                                                              pred_alpha=alpha_estimation,
                                                                              gt_trimap=gt_trimap, gt_alpha=gt_alpha,
                                                                              log_sigma_t_sqr=model.log_sigma_t_sqr,
                                                                              log_sigma_a_sqr=model.log_sigma_a_sqr)
                # if multi_gpu:
                #     L_overall_valid, L_t_valid, L_a_valid = L_overall_valid.mean(), L_t_valid.mean(), L_a_valid.mean()

                avg_loss.update(L_overall_valid.item())
                avg_l_t.update(L_t_valid.item())
                avg_l_a.update(L_a_valid.item())

                if index == 0:
                    trimap_adaption_res = torchvision.utils.make_grid(t_argmax.type(torch.FloatTensor) / 2,
                                                                      normalize=True, scale_each=True)
                    writer.add_image('valid_image/trimap_adaptation', trimap_adaption_res, cur_iter)
                    alpha_estimation_res = torchvision.utils.make_grid(alpha_estimation, normalize=True,
                                                                       scale_each=True)
                    writer.add_image('valid_image/alpha_estimation', alpha_estimation_res, cur_iter)
                pbar.update()

        logger.info("Average loss overall: {:.4e}".format(avg_loss.avg))
        logger.info("Average loss of trimap adaptation: {:.4e}".format(avg_l_t.avg))
        logger.info("Average loss of alpha estimation: {:.4e}".format(avg_l_a.avg))
        writer.add_scalar("valid_loss/L_overall", avg_loss.avg, cur_iter)
        writer.add_scalar("valid_loss/L_t", avg_l_t.avg, cur_iter)
        writer.add_scalar("valid_loss/L_a", avg_l_a.avg, cur_iter)

        is_best = avg_loss.avg < best_loss
        best_loss = min(avg_loss.avg, best_loss)
        if is_best or (args.save_ckpt and epoch % 10 == 0):
            if not os.path.exists("ckpts"):
                os.makedirs("ckpts")
            logger.info("Checkpoint saved")
            if is_best:
                logger.info("Best checkpoint saved")
            save_checkpoint(epoch, model, optimizer, cur_iter, max_iter, init_lr, avg_loss.avg, is_best,
                            args.ckpt_path)

    writer.export_scalars_to_json("./all_scalars.json")
    writer.close()
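
# AverageMeter is a small bookkeeping helper used throughout these loops; its definition is not included in
# this snippet. The standard implementation (as popularised by the PyTorch ImageNet example) is shown below,
# which is what the .update()/.reset()/.avg usage above assumes.
class AverageMeter(object):
    """Tracks the latest value, running sum, count, and average of a scalar metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count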
def train(args, logger, device_ids):
    torch.manual_seed(7)
    writer = SummaryWriter()

    logger.info("Loading network")
    model = AdaMatting(in_channel=4)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=0)
    if args.resume != "":
        ckpt = torch.load(args.resume)
        model.load_state_dict(ckpt["state_dict"])
        optimizer.load_state_dict(ckpt["optimizer"])

    if args.cuda:
        # Move any optimizer state restored from the checkpoint onto the GPU
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()
        device = torch.device("cuda:{}".format(device_ids[0]))
        if len(device_ids) > 1:
            logger.info("Loading with multiple GPUs")
            model = torch.nn.DataParallel(model, device_ids=device_ids)
            # model = model.cuda(device=device_ids[0])
    else:
        device = torch.device("cpu")
    model = model.to(device)

    logger.info("Initializing data loaders")
    train_dataset = AdaMattingDataset(args.raw_data_path, "train")
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                               num_workers=16, pin_memory=True)
    valid_dataset = AdaMattingDataset(args.raw_data_path, "valid")
    valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False,
                                               num_workers=16, pin_memory=True)

    if args.resume != "":
        logger.info("Start training from saved ckpt")
        start_epoch = ckpt["epoch"] + 1
        cur_iter = ckpt["cur_iter"] + 1
        peak_lr = ckpt["peak_lr"]
        best_loss = ckpt["best_loss"]
    else:
        logger.info("Start training from scratch")
        start_epoch = 0
        cur_iter = 0
        peak_lr = args.lr
        best_loss = float("inf")

    avg_lo = AverageMeter()
    avg_lt = AverageMeter()
    avg_la = AverageMeter()
    for epoch in range(start_epoch, args.epochs):
        # Training
        torch.set_grad_enabled(True)
        model.train()
        for index, (img, gt) in enumerate(train_loader):
            cur_lr, peak_lr = lr_scheduler(optimizer=optimizer, cur_iter=cur_iter, peak_lr=peak_lr, end_lr=0.00001,
                                           decay_iters=args.decay_iters, decay_power=0.9, power=0.9)

            img = img.type(torch.FloatTensor).to(device)                               # [bs, 4, 320, 320]
            gt_alpha = gt[:, 0, :, :].unsqueeze(1).type(torch.FloatTensor).to(device)  # [bs, 1, 320, 320]
            gt_trimap = gt[:, 1, :, :].type(torch.LongTensor).to(device)               # [bs, 320, 320]

            trimap_adaption, t_argmax, alpha_estimation, log_sigma_t_sqr, log_sigma_a_sqr = model(img)
            L_overall, L_t, L_a = task_uncertainty_loss(pred_trimap=trimap_adaption, pred_trimap_argmax=t_argmax,
                                                        pred_alpha=alpha_estimation, gt_trimap=gt_trimap,
                                                        gt_alpha=gt_alpha, log_sigma_t_sqr=log_sigma_t_sqr,
                                                        log_sigma_a_sqr=log_sigma_a_sqr)
            # Reduce per-GPU losses when running under DataParallel
            L_overall, L_t, L_a = L_overall.mean(), L_t.mean(), L_a.mean()
            sigma_t, sigma_a = log_sigma_t_sqr.mean(), log_sigma_a_sqr.mean()

            optimizer.zero_grad()
            L_overall.backward()
            optimizer.step()

            avg_lo.update(L_overall.item())
            avg_lt.update(L_t.item())
            avg_la.update(L_a.item())

            if cur_iter % 10 == 0:
                logger.info("Epoch: {:03d} | Iter: {:05d}/{} | Loss: {:.4e} | L_t: {:.4e} | L_a: {:.4e}"
                            .format(epoch, index, len(train_loader), avg_lo.avg, avg_lt.avg, avg_la.avg))
                writer.add_scalar("loss/L_overall", avg_lo.avg, cur_iter)
                writer.add_scalar("loss/L_t", avg_lt.avg, cur_iter)
                writer.add_scalar("loss/L_a", avg_la.avg, cur_iter)
                sigma_t = torch.exp(sigma_t / 2)
                sigma_a = torch.exp(sigma_a / 2)
                writer.add_scalar("other/sigma_t", sigma_t.item(), cur_iter)
                writer.add_scalar("other/sigma_a", sigma_a.item(), cur_iter)
                writer.add_scalar("other/lr", cur_lr, cur_iter)
                avg_lo.reset()
                avg_lt.reset()
                avg_la.reset()
            cur_iter += 1

        # Validation
        logger.info("Validating after the {}th epoch".format(epoch))
        avg_loss = AverageMeter()
        avg_l_t = AverageMeter()
        avg_l_a = AverageMeter()
        torch.cuda.empty_cache()
        torch.set_grad_enabled(False)
        model.eval()
        with tqdm(total=len(valid_loader)) as pbar:
            for index, (img, gt) in enumerate(valid_loader):
                img = img.type(torch.FloatTensor).to(device)                               # [bs, 4, 320, 320]
                gt_alpha = gt[:, 0, :, :].unsqueeze(1).type(torch.FloatTensor).to(device)  # [bs, 1, 320, 320]
                gt_trimap = gt[:, 1, :, :].type(torch.LongTensor).to(device)               # [bs, 320, 320]

                trimap_adaption, t_argmax, alpha_estimation, log_sigma_t_sqr, log_sigma_a_sqr = model(img)
                L_overall_valid, L_t_valid, L_a_valid = task_uncertainty_loss(pred_trimap=trimap_adaption,
                                                                              pred_trimap_argmax=t_argmax,
                                                                              pred_alpha=alpha_estimation,
                                                                              gt_trimap=gt_trimap, gt_alpha=gt_alpha,
                                                                              log_sigma_t_sqr=log_sigma_t_sqr,
                                                                              log_sigma_a_sqr=log_sigma_a_sqr)
                L_overall_valid, L_t_valid, L_a_valid = (L_overall_valid.mean(), L_t_valid.mean(),
                                                         L_a_valid.mean())

                avg_loss.update(L_overall_valid.item())
                avg_l_t.update(L_t_valid.item())
                avg_l_a.update(L_a_valid.item())

                if index == 0:
                    trimap_adaption_res = (t_argmax.type(torch.FloatTensor) / 2).unsqueeze(dim=1)
                    trimap_adaption_res = torchvision.utils.make_grid(trimap_adaption_res, normalize=False,
                                                                      scale_each=True)
                    writer.add_image('valid_image/trimap_adaptation', trimap_adaption_res, cur_iter)
                    alpha_estimation_res = torchvision.utils.make_grid(alpha_estimation, normalize=True,
                                                                       scale_each=True)
                    writer.add_image('valid_image/alpha_estimation', alpha_estimation_res, cur_iter)
                pbar.update()

        logger.info("Average loss overall: {:.4e}".format(avg_loss.avg))
        logger.info("Average loss of trimap adaptation: {:.4e}".format(avg_l_t.avg))
        logger.info("Average loss of alpha estimation: {:.4e}".format(avg_l_a.avg))
        writer.add_scalar("valid_loss/L_overall", avg_loss.avg, cur_iter)
        writer.add_scalar("valid_loss/L_t", avg_l_t.avg, cur_iter)
        writer.add_scalar("valid_loss/L_a", avg_l_a.avg, cur_iter)

        is_best = avg_loss.avg < best_loss
        best_loss = min(avg_loss.avg, best_loss)
        if is_best or args.save_ckpt:
            if not os.path.exists("ckpts"):
                os.makedirs("ckpts")
            save_checkpoint(ckpt_path=args.ckpt_path, is_best=is_best, logger=logger, model=model,
                            optimizer=optimizer, epoch=epoch, cur_iter=cur_iter, peak_lr=peak_lr,
                            best_loss=best_loss)

    writer.close()
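
# clip_gradient() (used in the first training loop above) is another helper defined outside this snippet.
# A common implementation, and likely what is intended here, simply clamps every gradient element in the
# optimizer's parameter groups to [-grad_clip, grad_clip]; this is an assumption — the repo's own helper
# could instead call torch.nn.utils.clip_grad_norm_ on the model parameters.
def clip_gradient_sketch(optimizer, grad_clip):
    for group in optimizer.param_groups:
        for param in group["params"]:
            if param.grad is not None:
                param.grad.data.clamp_(-grad_clip, grad_clip)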