# ---- Example 1 ----
def train(args, logger, device_ids):
    """Full training driver for AdaMatting with task-uncertainty loss.

    Builds the model and optimizer (optionally restoring from
    ``args.resume``), then alternates one training epoch and one validation
    pass per epoch, logging scalars/images to TensorBoard and saving
    checkpoints whenever the overall or alpha loss improves.

    Args:
        args: parsed CLI namespace (lr, batch_size, epochs, resume, cuda,
            raw_data_path, valid_portion, save_ckpt, ckpt_path, ...).
        logger: ``logging.Logger`` used for progress messages.
        device_ids: CUDA device indices; the first entry hosts the model.
    """
    writer = SummaryWriter()

    logger.info("Loading network")
    model = AdaMatting(in_channel=4)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=0.0001)
    if args.resume != "":
        ckpt = torch.load(args.resume)
        model.load_state_dict(ckpt["state_dict"])
        optimizer.load_state_dict(ckpt["optimizer"])
    if args.cuda:
        # Optimizer state restored from a CPU checkpoint must be relocated to
        # the GPU by hand; load_state_dict does not move the tensors.
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()
        device = torch.device("cuda:{}".format(device_ids[0]))
        if len(device_ids) > 1:
            logger.info("Loading with multiple GPUs")
            model = torch.nn.DataParallel(model, device_ids=device_ids)
    else:
        device = torch.device("cpu")
    model = model.to(device)

    logger.info("Initializing data loaders")
    train_dataset = AdaMattingDataset(args.raw_data_path, "train")
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, 
                                               num_workers=16, pin_memory=True, drop_last=True)
    valid_dataset = AdaMattingDataset(args.raw_data_path, "valid")
    valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False, 
                                               num_workers=16, pin_memory=True, drop_last=True)

    if args.resume != "":
        logger.info("Start training from saved ckpt")
        start_epoch = ckpt["epoch"] + 1
        cur_iter = ckpt["cur_iter"]
        peak_lr = ckpt["peak_lr"]
        best_loss = ckpt["best_loss"]
        best_alpha_loss = ckpt["best_alpha_loss"]
    else:
        logger.info("Start training from scratch")
        start_epoch = 0
        cur_iter = 0
        peak_lr = args.lr
        best_loss = float('inf')
        best_alpha_loss = float('inf')

    # 43100 is presumably the full dataset size -- TODO confirm against
    # AdaMattingDataset. valid_portion is expressed in percent here.
    max_iter = 43100 * (1 - args.valid_portion / 100) / args.batch_size * args.epochs
    # Normalize the TensorBoard x-axis to a reference batch size of 16 so
    # curves from runs with different batch sizes line up.
    tensorboard_iter = cur_iter * (args.batch_size / 16)

    avg_lo = AverageMeter()
    avg_lt = AverageMeter()
    avg_la = AverageMeter()
    for epoch in range(start_epoch, args.epochs):
        # Training
        torch.set_grad_enabled(True)
        model.train()
        for index, (_, inputs, gts) in enumerate(train_loader):
            cur_lr = lr_scheduler(optimizer=optimizer, init_lr=args.lr, cur_iter=cur_iter, max_iter=max_iter, 
                                  max_decay_times=30, decay_rate=0.9)

            inputs = inputs.to(device)
            gt_alpha = (gts[:, 0, :, :].unsqueeze(1)).type(torch.FloatTensor).to(device) # [bs, 1, 320, 320]
            gt_trimap = gts[:, 1, :, :].type(torch.LongTensor).to(device) # [bs, 320, 320]

            # Zero once per step; the second zero_grad() that used to sit just
            # before backward() was redundant.
            optimizer.zero_grad()
            trimap_adaption, t_argmax, alpha_estimation, log_sigma_t_sqr, log_sigma_a_sqr = model(inputs)

            L_overall, L_t, L_a = task_uncertainty_loss(pred_trimap=trimap_adaption, pred_trimap_argmax=t_argmax, 
                                                        pred_alpha=alpha_estimation, gt_trimap=gt_trimap, gt_alpha=gt_alpha, 
                                                        log_sigma_t_sqr=log_sigma_t_sqr, log_sigma_a_sqr=log_sigma_a_sqr)

            # The network predicts log(sigma^2); recover sigma for logging.
            sigma_t, sigma_a = torch.exp(log_sigma_t_sqr.mean() / 2), torch.exp(log_sigma_a_sqr.mean() / 2)

            L_overall.backward()
            clip_gradient(optimizer, 5)
            optimizer.step()

            avg_lo.update(L_overall.item())
            avg_lt.update(L_t.item())
            avg_la.update(L_a.item())

            # Log averaged losses every 10 iterations, then reset the meters.
            if cur_iter % 10 == 0:
                logger.info("Epoch: {:03d} | Iter: {:05d}/{} | Loss: {:.4e} | L_t: {:.4e} | L_a: {:.4e}"
                            .format(epoch, index, len(train_loader), avg_lo.avg, avg_lt.avg, avg_la.avg))
                writer.add_scalar("loss/L_overall", avg_lo.avg, tensorboard_iter)
                writer.add_scalar("loss/L_t", avg_lt.avg, tensorboard_iter)
                writer.add_scalar("loss/L_a", avg_la.avg, tensorboard_iter)
                writer.add_scalar("other/sigma_t", sigma_t.item(), tensorboard_iter)
                writer.add_scalar("other/sigma_a", sigma_a.item(), tensorboard_iter)
                writer.add_scalar("other/lr", cur_lr, tensorboard_iter)

                avg_lo.reset()
                avg_lt.reset()
                avg_la.reset()

            cur_iter += 1
            tensorboard_iter = cur_iter * (args.batch_size / 16)

        # Validation
        logger.info("Validating after the {}th epoch".format(epoch))
        avg_loss = AverageMeter()
        avg_l_t = AverageMeter()
        avg_l_a = AverageMeter()
        torch.cuda.empty_cache()
        torch.set_grad_enabled(False)
        model.eval()
        with tqdm(total=len(valid_loader)) as pbar:
            for index, (display_rgb, inputs, gts) in enumerate(valid_loader):
                inputs = inputs.to(device) # [bs, 4, 320, 320]
                gt_alpha = (gts[:, 0, :, :].unsqueeze(1)).type(torch.FloatTensor).to(device) # [bs, 1, 320, 320]
                gt_trimap = gts[:, 1, :, :].type(torch.LongTensor).to(device) # [bs, 320, 320]

                trimap_adaption, t_argmax, alpha_estimation, log_sigma_t_sqr, log_sigma_a_sqr = model(inputs)
                L_overall_valid, L_t_valid, L_a_valid = task_uncertainty_loss(pred_trimap=trimap_adaption, pred_trimap_argmax=t_argmax, 
                                                            pred_alpha=alpha_estimation, gt_trimap=gt_trimap, gt_alpha=gt_alpha, 
                                                            log_sigma_t_sqr=log_sigma_t_sqr, log_sigma_a_sqr=log_sigma_a_sqr)

                avg_loss.update(L_overall_valid.item())
                avg_l_t.update(L_t_valid.item())
                avg_l_a.update(L_a_valid.item())

                # Log one batch of qualitative results per validation pass.
                if index == 0:
                    input_rbg = torchvision.utils.make_grid(display_rgb, normalize=False, scale_each=True)
                    writer.add_image('input/rbg_image', input_rbg, tensorboard_iter)

                    input_trimap = inputs[:, 3, :, :].unsqueeze(dim=1)
                    input_trimap = torchvision.utils.make_grid(input_trimap, normalize=False, scale_each=True)
                    writer.add_image('input/trimap', input_trimap, tensorboard_iter)

                    # Composite alpha: trust the adapted trimap in its definite
                    # background (class 0) / foreground (class 2) regions.
                    output_alpha = alpha_estimation.clone()
                    output_alpha[t_argmax.unsqueeze(dim=1) == 0] = 0.0
                    output_alpha[t_argmax.unsqueeze(dim=1) == 2] = 1.0
                    output_alpha = torchvision.utils.make_grid(output_alpha, normalize=False, scale_each=True)
                    writer.add_image('output/alpha', output_alpha, tensorboard_iter)

                    # Trimap classes {0, 1, 2} mapped to {0, 0.5, 1} for display.
                    trimap_adaption_res = (t_argmax.type(torch.FloatTensor) / 2).unsqueeze(dim=1)
                    trimap_adaption_res = torchvision.utils.make_grid(trimap_adaption_res, normalize=False, scale_each=True)
                    writer.add_image('pred/trimap_adaptation', trimap_adaption_res, tensorboard_iter)

                    alpha_estimation_res = torchvision.utils.make_grid(alpha_estimation, normalize=False, scale_each=True)
                    writer.add_image('pred/alpha_estimation', alpha_estimation_res, tensorboard_iter)

                    gt_alpha = torchvision.utils.make_grid(gt_alpha, normalize=False, scale_each=True)
                    writer.add_image('gt/alpha', gt_alpha, tensorboard_iter)

                    gt_trimap = (gt_trimap.type(torch.FloatTensor) / 2).unsqueeze(dim=1)
                    gt_trimap = torchvision.utils.make_grid(gt_trimap, normalize=False, scale_each=True)
                    writer.add_image('gt/trimap', gt_trimap, tensorboard_iter)

                pbar.update()

        logger.info("Average loss overall: {:.4e}".format(avg_loss.avg))
        logger.info("Average loss of trimap adaptation: {:.4e}".format(avg_l_t.avg))
        logger.info("Average loss of alpha estimation: {:.4e}".format(avg_l_a.avg))
        writer.add_scalar("valid_loss/L_overall", avg_loss.avg, tensorboard_iter)
        writer.add_scalar("valid_loss/L_t", avg_l_t.avg, tensorboard_iter)
        writer.add_scalar("valid_loss/L_a", avg_l_a.avg, tensorboard_iter)

        # Checkpoint on any improvement (overall or alpha-only), or always
        # when --save_ckpt is set.
        is_best = avg_loss.avg < best_loss
        best_loss = min(avg_loss.avg, best_loss)
        is_alpha_best = avg_l_a.avg < best_alpha_loss
        best_alpha_loss = min(avg_l_a.avg, best_alpha_loss)
        if is_best or is_alpha_best or args.save_ckpt:
            if not os.path.exists("ckpts"):
                os.makedirs("ckpts")
            save_checkpoint(ckpt_path=args.ckpt_path, is_best=is_best, is_alpha_best=is_alpha_best, logger=logger, model=model, optimizer=optimizer, 
                            epoch=epoch, cur_iter=cur_iter, peak_lr=peak_lr, best_loss=best_loss, best_alpha_loss=best_alpha_loss)

    writer.close()
# ---- Example 2 ----
def train(args, model, optimizer, train_loader, epoch, logger):
    """Run one training epoch over ``train_loader``.

    Unpacks each composed batch, moves tensors to GPU when ``args.cuda`` is
    set, dumps ground-truth/prediction images for visual debugging, steps the
    LR scheduler, and performs one optimization step per batch.

    Args:
        args: CLI namespace; mutated in place (``args.cur_iter`` is advanced).
        model: the AdaMatting network.
        optimizer: optimizer over ``model``'s parameters.
        train_loader: DataLoader yielding (img, alpha, fg, bg, trimap, _,
            img_norm, gts, ..., img_info) tuples.
        epoch: current epoch index, used only for logging.
        logger: ``logging.Logger`` for progress messages.
    """
    t0 = time.time()
    model.train()

    for iteration, batch in enumerate(train_loader, 1):
        torch.cuda.empty_cache()

        img = Variable(batch[0])
        alpha = Variable(batch[1])
        fg = Variable(batch[2])
        bg = Variable(batch[3])
        trimap = Variable(batch[4])
        img_norm = Variable(batch[6])
        gts = Variable(batch[7])
        img_info = batch[-1]  # metadata, currently unused

        # Derive the ground-truth tensors before the CUDA branch so they also
        # exist on the CPU path (previously they were only created when
        # args.cuda was set, raising NameError otherwise).
        gt_alpha = (gts[:, 0, :, :].unsqueeze(1)).type(
            torch.FloatTensor)  # [bs, 1, 320, 320]
        gt_trimap = gts[:, 1, :, :].type(
            torch.LongTensor)  # [bs, 320, 320]

        if args.cuda:
            img = img.cuda()
            gt_alpha = gt_alpha.cuda()
            gt_trimap = gt_trimap.cuda()
            alpha = alpha.cuda()
            fg = fg.cuda()
            bg = bg.cuda()
            trimap = trimap.cuda()
            img_norm = img_norm.cuda()

        # Debug dumps of the ground truth for visual inspection.
        for i in range(gt_alpha.size(0)):
            torchvision.utils.save_image(gt_alpha[i, :, :, :],
                                         '{}_gt_alpha.png'.format(i))
        for i in range(gt_trimap.size(0)):
            # gt_trimap is [bs, H, W]: add a channel dim and cast to float so
            # save_image accepts it (indexing a 3-D tensor with four indices
            # raised IndexError before).
            torchvision.utils.save_image(gt_trimap[i, :, :].unsqueeze(0).float(),
                                         '{}_gt_trimap.png'.format(i))

        lr_scheduler(args,
                     optimizer=optimizer,
                     init_lr=args.lr,
                     cur_iter=args.cur_iter,
                     max_decay_times=40,
                     decay_rate=0.9)
        optimizer.zero_grad()

        # Network input: normalized RGB concatenated with the trimap scaled
        # to [0, 1].
        trimap_adaption, t_argmax, alpha_estimation, log_sigma_t_sqr, log_sigma_a_sqr = model(
            torch.cat((img_norm, trimap / 255.), 1))
        for i in range(alpha_estimation.size(0)):
            torchvision.utils.save_image(alpha_estimation[i, :, :, :],
                                         '{}_pred_alpha.png'.format(i))
        for i in range(trimap_adaption.size(0)):
            torchvision.utils.save_image(trimap_adaption[i, :, :, :],
                                         '{}_pred_trimap.png'.format(i))

        L_overall, L_t, L_a = task_uncertainty_loss(
            pred_trimap=trimap_adaption,
            input_trimap_argmax=trimap,
            pred_alpha=alpha_estimation,
            gt_trimap=gt_trimap,
            gt_alpha=gt_alpha,
            log_sigma_t_sqr=log_sigma_t_sqr,
            log_sigma_a_sqr=log_sigma_a_sqr)

        # The gradients were already cleared above; the duplicate
        # optimizer.zero_grad() that preceded backward() was removed.
        L_overall.backward()
        optimizer.step()

        if args.cur_iter % args.printFreq == 0:
            logger.info(
                "Epoch: {:03d} | Iter: {:05d}/{} | Loss: {:.4e} | L_t: {:.4e} | L_a: {:.4e}"
                .format(epoch, args.cur_iter, len(train_loader),
                        L_overall.item(), L_t.item(), L_a.item()))
        args.cur_iter += 1
# ---- Example 3 ----
def train(model, optimizer, device, args, logger, multi_gpu):
    """Train AdaMatting with a polynomial LR schedule.

    Builds the data loaders, optionally restores model/optimizer state from
    ``args.ckpt_path``, then alternates a training epoch and a validation
    pass, logging to TensorBoard and checkpointing on improvement.

    Args:
        model: the AdaMatting network (replaced wholesale when resuming).
        optimizer: optimizer over ``model``'s parameters (also replaced when
            resuming).
        device: target ``torch.device``.
        args: CLI namespace (lr, batch_size, epochs, resume, raw_data_path,
            valid_portion, save_ckpt, ckpt_path, ...).
        logger: ``logging.Logger`` for progress messages.
        multi_gpu: flag for DataParallel runs; currently unused here.
    """
    torch.manual_seed(7)
    writer = SummaryWriter()

    logger.info("Initializing data loaders")
    train_dataset = AdaMattingDataset(args.raw_data_path, 'train')
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, 
                                               num_workers=16, pin_memory=True)
    valid_dataset = AdaMattingDataset(args.raw_data_path, 'valid')
    valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False, 
                                               num_workers=16, pin_memory=True)

    if args.resume:
        logger.info("Start training from saved ckpt")
        ckpt = torch.load(args.ckpt_path)
        # Checkpoint stores the whole DataParallel-wrapped model; unwrap it.
        model = ckpt["model"].module
        model = model.to(device)
        optimizer = ckpt["optimizer"]

        start_epoch = ckpt["epoch"] + 1
        max_iter = ckpt["max_iter"]
        cur_iter = ckpt["cur_iter"]
        init_lr = ckpt["init_lr"]
        best_loss = ckpt["best_loss"]
    else:
        logger.info("Start training from scratch")
        start_epoch = 0
        # 43100 is presumably the full dataset size -- TODO confirm;
        # valid_portion is a fraction (not percent) in this variant.
        max_iter = 43100 * (1 - args.valid_portion) / args.batch_size * args.epochs
        cur_iter = 0
        init_lr = args.lr
        best_loss = float('inf')
    
    for epoch in range(start_epoch, args.epochs):
        # Training
        torch.set_grad_enabled(True)
        model.train()
        for index, (img, gt) in enumerate(train_loader):
            cur_lr = poly_lr_scheduler(optimizer=optimizer, init_lr=init_lr, iter=cur_iter, max_iter=max_iter)

            img = img.type(torch.FloatTensor).to(device) # [bs, 4, 320, 320]
            gt_alpha = (gt[:, 0, :, :].unsqueeze(1)).type(torch.FloatTensor).to(device) # [bs, 1, 320, 320]
            gt_trimap = gt[:, 1, :, :].type(torch.LongTensor).to(device) # [bs, 320, 320]

            # Zero once per step; the second zero_grad() that used to sit
            # between the loss computation and backward() was redundant.
            optimizer.zero_grad()
            trimap_adaption, t_argmax, alpha_estimation = model(img)
            L_overall, L_t, L_a = task_uncertainty_loss(pred_trimap=trimap_adaption, pred_trimap_argmax=t_argmax, 
                                                        pred_alpha=alpha_estimation, gt_trimap=gt_trimap, 
                                                        gt_alpha=gt_alpha, log_sigma_t_sqr=model.log_sigma_t_sqr, log_sigma_a_sqr=model.log_sigma_a_sqr)
            L_overall.backward()
            optimizer.step()

            if cur_iter % 10 == 0:
                logger.info("Epoch: {:03d} | Iter: {:05d}/{} | Loss: {:.4e} | L_t: {:.4e} | L_a: {:.4e}"
                            .format(epoch, index, len(train_loader), L_overall.item(), L_t.item(), L_a.item()))
                writer.add_scalar("loss/L_overall", L_overall.item(), cur_iter)
                writer.add_scalar("loss/L_t", L_t.item(), cur_iter)
                writer.add_scalar("loss/L_a", L_a.item(), cur_iter)
                # The model holds log(sigma^2) as parameters; recover sigma.
                sigma_t = torch.exp(model.log_sigma_t_sqr / 2)
                sigma_a = torch.exp(model.log_sigma_a_sqr / 2)
                writer.add_scalar("sigma/sigma_t", sigma_t, cur_iter)
                writer.add_scalar("sigma/sigma_a", sigma_a, cur_iter)
                writer.add_scalar("lr", cur_lr, cur_iter)
            
            cur_iter += 1
        
        # Validation
        logger.info("Validating after the {}th epoch".format(epoch))
        avg_loss = AverageMeter()
        avg_l_t = AverageMeter()
        avg_l_a = AverageMeter()
        torch.cuda.empty_cache()
        torch.set_grad_enabled(False)
        model.eval()
        with tqdm(total=len(valid_loader)) as pbar:
            for index, (img, gt) in enumerate(valid_loader):
                img = img.type(torch.FloatTensor).to(device) # [bs, 4, 320, 320]
                gt_alpha = (gt[:, 0, :, :].unsqueeze(1)).type(torch.FloatTensor).to(device) # [bs, 1, 320, 320]
                gt_trimap = gt[:, 1, :, :].type(torch.LongTensor).to(device) # [bs, 320, 320]

                trimap_adaption, t_argmax, alpha_estimation = model(img)
                L_overall_valid, L_t_valid, L_a_valid = task_uncertainty_loss(pred_trimap=trimap_adaption, pred_trimap_argmax=t_argmax, 
                                                            pred_alpha=alpha_estimation, gt_trimap=gt_trimap, 
                                                            gt_alpha=gt_alpha, log_sigma_t_sqr=model.log_sigma_t_sqr, log_sigma_a_sqr=model.log_sigma_a_sqr)
                avg_loss.update(L_overall_valid.item())
                avg_l_t.update(L_t_valid.item())
                avg_l_a.update(L_a_valid.item())

                # Log one batch of qualitative results per validation pass.
                if index == 0:
                    # Trimap classes {0, 1, 2} mapped to {0, 0.5, 1} for display.
                    trimap_adaption_res = torchvision.utils.make_grid(t_argmax.type(torch.FloatTensor) / 2, normalize=True, scale_each=True)
                    writer.add_image('valid_image/trimap_adaptation', trimap_adaption_res, cur_iter)
                    alpha_estimation_res = torchvision.utils.make_grid(alpha_estimation, normalize=True, scale_each=True)
                    writer.add_image('valid_image/alpha_estimation', alpha_estimation_res, cur_iter)
                
                pbar.update()

        logger.info("Average loss overall: {:.4e}".format(avg_loss.avg))
        logger.info("Average loss of trimap adaptation: {:.4e}".format(avg_l_t.avg))
        logger.info("Average loss of alpha estimation: {:.4e}".format(avg_l_a.avg))
        writer.add_scalar("valid_loss/L_overall", avg_loss.avg, cur_iter)
        writer.add_scalar("valid_loss/L_t", avg_l_t.avg, cur_iter)
        writer.add_scalar("valid_loss/L_a", avg_l_a.avg, cur_iter)

        # Save when the overall loss improves, or every 10th epoch with
        # --save_ckpt.
        is_best = avg_loss.avg < best_loss
        best_loss = min(avg_loss.avg, best_loss)
        if is_best or (args.save_ckpt and epoch % 10 == 0):
            if not os.path.exists("ckpts"):
                os.makedirs("ckpts")
            logger.info("Checkpoint saved")
            if (is_best):
                logger.info("Best checkpoint saved")
            save_checkpoint(epoch, model, optimizer, cur_iter, max_iter, init_lr, avg_loss.avg, is_best, args.ckpt_path)

    writer.export_scalars_to_json("./all_scalars.json")
    writer.close()
# ---- Example 4 ----
def train(args, logger, device_ids):
    """Train AdaMatting with a peak-decay LR schedule.

    Builds the model and optimizer (optionally restoring from
    ``args.resume``), then alternates a training epoch and a validation pass
    per epoch, logging scalars/images to TensorBoard and checkpointing when
    the overall validation loss improves.

    Args:
        args: parsed CLI namespace (lr, batch_size, epochs, resume, cuda,
            raw_data_path, decay_iters, save_ckpt, ckpt_path, ...).
        logger: ``logging.Logger`` used for progress messages.
        device_ids: CUDA device indices; the first entry hosts the model.
    """
    torch.manual_seed(7)
    writer = SummaryWriter()

    logger.info("Loading network")
    model = AdaMatting(in_channel=4)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=0)
    if args.resume != "":
        ckpt = torch.load(args.resume)
        model.load_state_dict(ckpt["state_dict"])
        optimizer.load_state_dict(ckpt["optimizer"])
    if args.cuda:
        # Optimizer state restored from a CPU checkpoint must be relocated to
        # the GPU by hand; load_state_dict does not move the tensors.
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()
        device = torch.device("cuda:{}".format(device_ids[0]))
        if len(device_ids) > 1:
            logger.info("Loading with multiple GPUs")
            model = torch.nn.DataParallel(model, device_ids=device_ids)
    else:
        device = torch.device("cpu")
    model = model.to(device)

    logger.info("Initializing data loaders")
    train_dataset = AdaMattingDataset(args.raw_data_path, "train")
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=16,
                                               pin_memory=True)
    valid_dataset = AdaMattingDataset(args.raw_data_path, "valid")
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=16,
                                               pin_memory=True)

    if args.resume != "":
        logger.info("Start training from saved ckpt")
        start_epoch = ckpt["epoch"] + 1
        cur_iter = ckpt["cur_iter"] + 1
        peak_lr = ckpt["peak_lr"]
        best_loss = ckpt["best_loss"]
    else:
        logger.info("Start training from scratch")
        start_epoch = 0
        cur_iter = 0
        peak_lr = args.lr
        best_loss = float('inf')

    avg_lo = AverageMeter()
    avg_lt = AverageMeter()
    avg_la = AverageMeter()
    for epoch in range(start_epoch, args.epochs):
        # Training
        torch.set_grad_enabled(True)
        model.train()
        for index, (img, gt) in enumerate(train_loader):
            cur_lr, peak_lr = lr_scheduler(optimizer=optimizer,
                                           cur_iter=cur_iter,
                                           peak_lr=peak_lr,
                                           end_lr=0.00001,
                                           decay_iters=args.decay_iters,
                                           decay_power=0.9,
                                           power=0.9)

            img = img.type(torch.FloatTensor).to(device)  # [bs, 4, 320, 320]
            gt_alpha = (gt[:,
                           0, :, :].unsqueeze(1)).type(torch.FloatTensor).to(
                               device)  # [bs, 1, 320, 320]
            gt_trimap = gt[:, 1, :, :].type(torch.LongTensor).to(
                device)  # [bs, 320, 320]

            # Zero once per step; the second zero_grad() that used to sit just
            # before backward() was redundant.
            optimizer.zero_grad()
            trimap_adaption, t_argmax, alpha_estimation, log_sigma_t_sqr, log_sigma_a_sqr = model(
                img)
            L_overall, L_t, L_a = task_uncertainty_loss(
                pred_trimap=trimap_adaption,
                pred_trimap_argmax=t_argmax,
                pred_alpha=alpha_estimation,
                gt_trimap=gt_trimap,
                gt_alpha=gt_alpha,
                log_sigma_t_sqr=log_sigma_t_sqr,
                log_sigma_a_sqr=log_sigma_a_sqr)

            # DataParallel gathers per-GPU scalars into vectors; reduce them.
            L_overall, L_t, L_a = L_overall.mean(), L_t.mean(), L_a.mean()
            sigma_t, sigma_a = log_sigma_t_sqr.mean(), log_sigma_a_sqr.mean()

            L_overall.backward()
            optimizer.step()

            avg_lo.update(L_overall.item())
            avg_lt.update(L_t.item())
            avg_la.update(L_a.item())

            # Log averaged losses every 10 iterations, then reset the meters.
            if cur_iter % 10 == 0:
                logger.info(
                    "Epoch: {:03d} | Iter: {:05d}/{} | Loss: {:.4e} | L_t: {:.4e} | L_a: {:.4e}"
                    .format(epoch, index, len(train_loader), avg_lo.avg,
                            avg_lt.avg, avg_la.avg))
                writer.add_scalar("loss/L_overall", avg_lo.avg, cur_iter)
                writer.add_scalar("loss/L_t", avg_lt.avg, cur_iter)
                writer.add_scalar("loss/L_a", avg_la.avg, cur_iter)
                # The network predicts log(sigma^2); recover sigma for logging.
                sigma_t = torch.exp(sigma_t / 2)
                sigma_a = torch.exp(sigma_a / 2)
                writer.add_scalar("other/sigma_t", sigma_t.item(), cur_iter)
                writer.add_scalar("other/sigma_a", sigma_a.item(), cur_iter)
                writer.add_scalar("other/lr", cur_lr, cur_iter)

                avg_lo.reset()
                avg_lt.reset()
                avg_la.reset()

            cur_iter += 1

        # Validation
        logger.info("Validating after the {}th epoch".format(epoch))
        avg_loss = AverageMeter()
        avg_l_t = AverageMeter()
        avg_l_a = AverageMeter()
        torch.cuda.empty_cache()
        torch.set_grad_enabled(False)
        model.eval()
        with tqdm(total=len(valid_loader)) as pbar:
            for index, (img, gt) in enumerate(valid_loader):
                img = img.type(torch.FloatTensor).to(
                    device)  # [bs, 4, 320, 320]
                gt_alpha = (gt[:, 0, :, :].unsqueeze(1)).type(
                    torch.FloatTensor).to(device)  # [bs, 1, 320, 320]
                gt_trimap = gt[:, 1, :, :].type(torch.LongTensor).to(
                    device)  # [bs, 320, 320]

                trimap_adaption, t_argmax, alpha_estimation, log_sigma_t_sqr, log_sigma_a_sqr = model(
                    img)
                L_overall_valid, L_t_valid, L_a_valid = task_uncertainty_loss(
                    pred_trimap=trimap_adaption,
                    pred_trimap_argmax=t_argmax,
                    pred_alpha=alpha_estimation,
                    gt_trimap=gt_trimap,
                    gt_alpha=gt_alpha,
                    log_sigma_t_sqr=log_sigma_t_sqr,
                    log_sigma_a_sqr=log_sigma_a_sqr)

                # Reduce DataParallel-gathered vectors to scalars.
                L_overall_valid, L_t_valid, L_a_valid = L_overall_valid.mean(
                ), L_t_valid.mean(), L_a_valid.mean()

                avg_loss.update(L_overall_valid.item())
                avg_l_t.update(L_t_valid.item())
                avg_l_a.update(L_a_valid.item())

                # Log one batch of qualitative results per validation pass.
                if index == 0:
                    # Trimap classes {0, 1, 2} mapped to {0, 0.5, 1} for display.
                    trimap_adaption_res = (t_argmax.type(torch.FloatTensor) /
                                           2).unsqueeze(dim=1)
                    trimap_adaption_res = torchvision.utils.make_grid(
                        trimap_adaption_res, normalize=False, scale_each=True)
                    writer.add_image('valid_image/trimap_adaptation',
                                     trimap_adaption_res, cur_iter)
                    alpha_estimation_res = torchvision.utils.make_grid(
                        alpha_estimation, normalize=True, scale_each=True)
                    writer.add_image('valid_image/alpha_estimation',
                                     alpha_estimation_res, cur_iter)

                pbar.update()

        logger.info("Average loss overall: {:.4e}".format(avg_loss.avg))
        logger.info("Average loss of trimap adaptation: {:.4e}".format(
            avg_l_t.avg))
        logger.info("Average loss of alpha estimation: {:.4e}".format(
            avg_l_a.avg))
        writer.add_scalar("valid_loss/L_overall", avg_loss.avg, cur_iter)
        writer.add_scalar("valid_loss/L_t", avg_l_t.avg, cur_iter)
        writer.add_scalar("valid_loss/L_a", avg_l_a.avg, cur_iter)

        is_best = avg_loss.avg < best_loss
        best_loss = min(avg_loss.avg, best_loss)
        if is_best or args.save_ckpt:
            if not os.path.exists("ckpts"):
                os.makedirs("ckpts")
            save_checkpoint(ckpt_path=args.ckpt_path,
                            is_best=is_best,
                            logger=logger,
                            model=model,
                            optimizer=optimizer,
                            epoch=epoch,
                            cur_iter=cur_iter,
                            peak_lr=peak_lr,
                            best_loss=best_loss)

    writer.close()