Example #1
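# Assumed imports for this snippet (not shown in the original). SummaryWriter may
# come from tensorboardX rather than torch.utils.tensorboard, depending on the repo.
import os

import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

# AdaMatting, AdaMattingDataset, lr_scheduler, task_uncertainty_loss, AverageMeter,
# clip_gradient, and save_checkpoint are project-local helpers; their module paths
# are not shown here, so import them from wherever the surrounding repo defines them.
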
def train(args, logger, device_ids):
    writer = SummaryWriter()

    logger.info("Loading network")
    model = AdaMatting(in_channel=4)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=0.0001)
    if args.resume != "":
        ckpt = torch.load(args.resume)
        model.load_state_dict(ckpt["state_dict"])
        optimizer.load_state_dict(ckpt["optimizer"])
    if args.cuda:
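        # Optimizer state restored from a checkpoint may live on the CPU; move
        # its tensors to the GPU so optimizer.step() does not mix devices.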
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()
        device = torch.device("cuda:{}".format(device_ids[0]))
        if len(device_ids) > 1:
            logger.info("Loading with multiple GPUs")
            model = torch.nn.DataParallel(model, device_ids=device_ids)
    else:
        device = torch.device("cpu")
    model = model.to(device)

    logger.info("Initializing data loaders")
    train_dataset = AdaMattingDataset(args.raw_data_path, "train")
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, 
                                               num_workers=16, pin_memory=True, drop_last=True)
    valid_dataset = AdaMattingDataset(args.raw_data_path, "valid")
    valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False, 
                                               num_workers=16, pin_memory=True, drop_last=True)

    if args.resume != "":
        logger.info("Start training from saved ckpt")
        start_epoch = ckpt["epoch"] + 1
        cur_iter = ckpt["cur_iter"]
        peak_lr = ckpt["peak_lr"]
        best_loss = ckpt["best_loss"]
        best_alpha_loss = ckpt["best_alpha_loss"]
    else:
        logger.info("Start training from scratch")
        start_epoch = 0
        cur_iter = 0
        peak_lr = args.lr
        best_loss = float('inf')
        best_alpha_loss = float('inf')

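    # 43100 is presumably the size of the composited training set (e.g. 431
    # foregrounds x 100 backgrounds in the Adobe matting data). tensorboard_iter
    # rescales the step axis to a reference batch size of 16 so that runs with
    # different batch sizes line up in TensorBoard.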
    max_iter = 43100 * (1 - args.valid_portion / 100) / args.batch_size * args.epochs
    tensorboard_iter = cur_iter * (args.batch_size / 16)

    avg_lo = AverageMeter()
    avg_lt = AverageMeter()
    avg_la = AverageMeter()
    for epoch in range(start_epoch, args.epochs):
        # Training
        torch.set_grad_enabled(True)
        model.train()
        for index, (_, inputs, gts) in enumerate(train_loader):
            cur_lr = lr_scheduler(optimizer=optimizer, init_lr=args.lr, cur_iter=cur_iter, max_iter=max_iter,
                                  max_decay_times=30, decay_rate=0.9)
            
            inputs = inputs.to(device)
            gt_alpha = gts[:, 0, :, :].unsqueeze(1).float().to(device)  # [bs, 1, 320, 320]
            gt_trimap = gts[:, 1, :, :].long().to(device)  # [bs, 320, 320]

            optimizer.zero_grad()
            trimap_adaption, t_argmax, alpha_estimation, log_sigma_t_sqr, log_sigma_a_sqr = model(inputs)

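            # Uncertainty-weighted multi-task loss in the style of Kendall et al.:
            # the learned log-variances trade off the trimap and alpha terms.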
            L_overall, L_t, L_a = task_uncertainty_loss(pred_trimap=trimap_adaption, pred_trimap_argmax=t_argmax, 
                                                        pred_alpha=alpha_estimation, gt_trimap=gt_trimap, gt_alpha=gt_alpha, 
                                                        log_sigma_t_sqr=log_sigma_t_sqr, log_sigma_a_sqr=log_sigma_a_sqr)

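            # Recover the sigmas from the learned log-variances: sigma = exp(log(sigma^2) / 2).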
            sigma_t, sigma_a = torch.exp(log_sigma_t_sqr.mean() / 2), torch.exp(log_sigma_a_sqr.mean() / 2)

            L_overall.backward()
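            # clip_gradient presumably clamps each gradient to [-5, 5] (the exact
            # behavior depends on the repo's helper) to avoid exploding gradients.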
            clip_gradient(optimizer, 5)
            optimizer.step()

            avg_lo.update(L_overall.item())
            avg_lt.update(L_t.item())
            avg_la.update(L_a.item())

            if cur_iter % 10 == 0:
                logger.info("Epoch: {:03d} | Iter: {:05d}/{} | Loss: {:.4e} | L_t: {:.4e} | L_a: {:.4e}"
                            .format(epoch, index, len(train_loader), avg_lo.avg, avg_lt.avg, avg_la.avg))
                writer.add_scalar("loss/L_overall", avg_lo.avg, tensorboard_iter)
                writer.add_scalar("loss/L_t", avg_lt.avg, tensorboard_iter)
                writer.add_scalar("loss/L_a", avg_la.avg, tensorboard_iter)
                writer.add_scalar("other/sigma_t", sigma_t.item(), tensorboard_iter)
                writer.add_scalar("other/sigma_a", sigma_a.item(), tensorboard_iter)
                writer.add_scalar("other/lr", cur_lr, tensorboard_iter)

                avg_lo.reset()
                avg_lt.reset()
                avg_la.reset()
                
            cur_iter += 1
            tensorboard_iter = cur_iter * (args.batch_size / 16)

        # Validation
        logger.info("Validating after the {}th epoch".format(epoch))
        avg_loss = AverageMeter()
        avg_l_t = AverageMeter()
        avg_l_a = AverageMeter()
        torch.cuda.empty_cache()
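        # Disable autograd and switch to eval mode so no graph is built and
        # BatchNorm/Dropout run in inference mode during validation.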
        torch.set_grad_enabled(False)
        model.eval()
        with tqdm(total=len(valid_loader)) as pbar:
            for index, (display_rgb, inputs, gts) in enumerate(valid_loader):
                inputs = inputs.to(device) # [bs, 4, 320, 320]
                gt_alpha = gts[:, 0, :, :].unsqueeze(1).float().to(device)  # [bs, 1, 320, 320]
                gt_trimap = gts[:, 1, :, :].long().to(device)  # [bs, 320, 320]

                trimap_adaption, t_argmax, alpha_estimation, log_sigma_t_sqr, log_sigma_a_sqr = model(inputs)
                L_overall_valid, L_t_valid, L_a_valid = task_uncertainty_loss(pred_trimap=trimap_adaption, pred_trimap_argmax=t_argmax, 
                                                            pred_alpha=alpha_estimation, gt_trimap=gt_trimap, gt_alpha=gt_alpha, 
                                                            log_sigma_t_sqr=log_sigma_t_sqr, log_sigma_a_sqr=log_sigma_a_sqr)

                avg_loss.update(L_overall_valid.item())
                avg_l_t.update(L_t_valid.item())
                avg_l_a.update(L_a_valid.item())

                if index == 0:
                    input_rgb = torchvision.utils.make_grid(display_rgb, normalize=False, scale_each=True)
                    writer.add_image('input/rgb_image', input_rgb, tensorboard_iter)

                    input_trimap = inputs[:, 3, :, :].unsqueeze(dim=1)
                    input_trimap = torchvision.utils.make_grid(input_trimap, normalize=False, scale_each=True)
                    writer.add_image('input/trimap', input_trimap, tensorboard_iter)

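                    # Composite the final alpha with the adapted trimap: force alpha
                    # to 0 in predicted background (class 0) and to 1 in predicted
                    # foreground (class 2), keeping the network's estimate only in
                    # the unknown region (class 1).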
                    output_alpha = alpha_estimation.clone()
                    output_alpha[t_argmax.unsqueeze(dim=1) == 0] = 0.0
                    output_alpha[t_argmax.unsqueeze(dim=1) == 2] = 1.0
                    output_alpha = torchvision.utils.make_grid(output_alpha, normalize=False, scale_each=True)
                    writer.add_image('output/alpha', output_alpha, tensorboard_iter)

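                    # Map trimap classes {0, 1, 2} to grayscale {0, 0.5, 1} for display.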
                    trimap_adaption_res = (t_argmax.type(torch.FloatTensor) / 2).unsqueeze(dim=1)
                    trimap_adaption_res = torchvision.utils.make_grid(trimap_adaption_res, normalize=False, scale_each=True)
                    writer.add_image('pred/trimap_adaptation', trimap_adaption_res, tensorboard_iter)

                    alpha_estimation_res = torchvision.utils.make_grid(alpha_estimation, normalize=False, scale_each=True)
                    writer.add_image('pred/alpha_estimation', alpha_estimation_res, tensorboard_iter)

                    gt_alpha = torchvision.utils.make_grid(gt_alpha, normalize=False, scale_each=True)
                    writer.add_image('gt/alpha', gt_alpha, tensorboard_iter)

                    gt_trimap = (gt_trimap.type(torch.FloatTensor) / 2).unsqueeze(dim=1)
                    gt_trimap = torchvision.utils.make_grid(gt_trimap, normalize=False, scale_each=True)
                    writer.add_image('gt/trimap', gt_trimap, tensorboard_iter)
                    
                pbar.update()

        logger.info("Average loss overall: {:.4e}".format(avg_loss.avg))
        logger.info("Average loss of trimap adaptation: {:.4e}".format(avg_l_t.avg))
        logger.info("Average loss of alpha estimation: {:.4e}".format(avg_l_a.avg))
        writer.add_scalar("valid_loss/L_overall", avg_loss.avg, tensorboard_iter)
        writer.add_scalar("valid_loss/L_t", avg_l_t.avg, tensorboard_iter)
        writer.add_scalar("valid_loss/L_a", avg_l_a.avg, tensorboard_iter)

        is_best = avg_loss.avg < best_loss
        best_loss = min(avg_loss.avg, best_loss)
        is_alpha_best = avg_l_a.avg < best_alpha_loss
        best_alpha_loss = min(avg_l_a.avg, best_alpha_loss)
        if is_best or is_alpha_best or args.save_ckpt:
            if not os.path.exists("ckpts"):
                os.makedirs("ckpts")
            save_checkpoint(ckpt_path=args.ckpt_path, is_best=is_best, is_alpha_best=is_alpha_best, logger=logger, model=model, optimizer=optimizer, 
                            epoch=epoch, cur_iter=cur_iter, peak_lr=peak_lr, best_loss=best_loss, best_alpha_loss=best_alpha_loss)

    writer.close()
Example #2
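# Assumes the same imports as Example #1 (this variant does not use clip_gradient).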
def train(args, logger, device_ids):
    torch.manual_seed(7)
    writer = SummaryWriter()

    logger.info("Loading network")
    model = AdaMatting(in_channel=4)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=0)
    if args.resume != "":
        ckpt = torch.load(args.resume)
        model.load_state_dict(ckpt["state_dict"])
        optimizer.load_state_dict(ckpt["optimizer"])
    if args.cuda:
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()
        device = torch.device("cuda:{}".format(device_ids[0]))
        if len(device_ids) > 1:
            logger.info("Loading with multiple GPUs")
            model = torch.nn.DataParallel(model, device_ids=device_ids)
    else:
        device = torch.device("cpu")
    model = model.to(device)

    logger.info("Initializing data loaders")
    train_dataset = AdaMattingDataset(args.raw_data_path, "train")
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=16,
                                               pin_memory=True)
    valid_dataset = AdaMattingDataset(args.raw_data_path, "valid")
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=16,
                                               pin_memory=True)

    if args.resume != "":
        logger.info("Start training from saved ckpt")
        start_epoch = ckpt["epoch"] + 1
        cur_iter = ckpt["cur_iter"] + 1
        peak_lr = ckpt["peak_lr"]
        best_loss = ckpt["best_loss"]
    else:
        logger.info("Start training from scratch")
        start_epoch = 0
        cur_iter = 0
        peak_lr = args.lr
        best_loss = float('inf')

    avg_lo = AverageMeter()
    avg_lt = AverageMeter()
    avg_la = AverageMeter()
    for epoch in range(start_epoch, args.epochs):
        # Training
        torch.set_grad_enabled(True)
        model.train()
        for index, (img, gt) in enumerate(train_loader):
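            # This lr_scheduler variant appears to warm up to peak_lr and then
            # decay toward end_lr; the exact schedule depends on the repo's helper.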
            cur_lr, peak_lr = lr_scheduler(optimizer=optimizer,
                                           cur_iter=cur_iter,
                                           peak_lr=peak_lr,
                                           end_lr=0.00001,
                                           decay_iters=args.decay_iters,
                                           decay_power=0.9,
                                           power=0.9)

            img = img.float().to(device)  # [bs, 4, 320, 320]
            gt_alpha = gt[:, 0, :, :].unsqueeze(1).float().to(device)  # [bs, 1, 320, 320]
            gt_trimap = gt[:, 1, :, :].long().to(device)  # [bs, 320, 320]

            optimizer.zero_grad()
            trimap_adaption, t_argmax, alpha_estimation, log_sigma_t_sqr, log_sigma_a_sqr = model(img)
            L_overall, L_t, L_a = task_uncertainty_loss(
                pred_trimap=trimap_adaption,
                pred_trimap_argmax=t_argmax,
                pred_alpha=alpha_estimation,
                gt_trimap=gt_trimap,
                gt_alpha=gt_alpha,
                log_sigma_t_sqr=log_sigma_t_sqr,
                log_sigma_a_sqr=log_sigma_a_sqr)

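            # Reduce to scalars in case the loss terms come back per-element
            # (e.g. per-GPU vectors when the model runs under DataParallel).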
            L_overall, L_t, L_a = L_overall.mean(), L_t.mean(), L_a.mean()
            sigma_t, sigma_a = log_sigma_t_sqr.mean(), log_sigma_a_sqr.mean()  # still log-variances here; converted below

            L_overall.backward()
            optimizer.step()

            avg_lo.update(L_overall.item())
            avg_lt.update(L_t.item())
            avg_la.update(L_a.item())

            if cur_iter % 10 == 0:
                logger.info(
                    "Epoch: {:03d} | Iter: {:05d}/{} | Loss: {:.4e} | L_t: {:.4e} | L_a: {:.4e}"
                    .format(epoch, index, len(train_loader), avg_lo.avg,
                            avg_lt.avg, avg_la.avg))
                writer.add_scalar("loss/L_overall", avg_lo.avg, cur_iter)
                writer.add_scalar("loss/L_t", avg_lt.avg, cur_iter)
                writer.add_scalar("loss/L_a", avg_la.avg, cur_iter)
                sigma_t = torch.exp(sigma_t / 2)
                sigma_a = torch.exp(sigma_a / 2)
                writer.add_scalar("other/sigma_t", sigma_t.item(), cur_iter)
                writer.add_scalar("other/sigma_a", sigma_a.item(), cur_iter)
                writer.add_scalar("other/lr", cur_lr, cur_iter)

                avg_lo.reset()
                avg_lt.reset()
                avg_la.reset()

            cur_iter += 1

        # Validation
        logger.info("Validating after the {}th epoch".format(epoch))
        avg_loss = AverageMeter()
        avg_l_t = AverageMeter()
        avg_l_a = AverageMeter()
        torch.cuda.empty_cache()
        torch.set_grad_enabled(False)
        model.eval()
        with tqdm(total=len(valid_loader)) as pbar:
            for index, (img, gt) in enumerate(valid_loader):
                img = img.float().to(device)  # [bs, 4, 320, 320]
                gt_alpha = gt[:, 0, :, :].unsqueeze(1).float().to(device)  # [bs, 1, 320, 320]
                gt_trimap = gt[:, 1, :, :].long().to(device)  # [bs, 320, 320]

                trimap_adaption, t_argmax, alpha_estimation, log_sigma_t_sqr, log_sigma_a_sqr = model(img)
                L_overall_valid, L_t_valid, L_a_valid = task_uncertainty_loss(
                    pred_trimap=trimap_adaption,
                    pred_trimap_argmax=t_argmax,
                    pred_alpha=alpha_estimation,
                    gt_trimap=gt_trimap,
                    gt_alpha=gt_alpha,
                    log_sigma_t_sqr=log_sigma_t_sqr,
                    log_sigma_a_sqr=log_sigma_a_sqr)

                L_overall_valid, L_t_valid, L_a_valid = L_overall_valid.mean(), L_t_valid.mean(), L_a_valid.mean()

                avg_loss.update(L_overall_valid.item())
                avg_l_t.update(L_t_valid.item())
                avg_l_a.update(L_a_valid.item())

                if index == 0:
                    trimap_adaption_res = (t_argmax.type(torch.FloatTensor) /
                                           2).unsqueeze(dim=1)
                    trimap_adaption_res = torchvision.utils.make_grid(
                        trimap_adaption_res, normalize=False, scale_each=True)
                    writer.add_image('valid_image/trimap_adaptation',
                                     trimap_adaption_res, cur_iter)
                    alpha_estimation_res = torchvision.utils.make_grid(
                        alpha_estimation, normalize=True, scale_each=True)
                    writer.add_image('valid_image/alpha_estimation',
                                     alpha_estimation_res, cur_iter)

                pbar.update()

        logger.info("Average loss overall: {:.4e}".format(avg_loss.avg))
        logger.info("Average loss of trimap adaptation: {:.4e}".format(
            avg_l_t.avg))
        logger.info("Average loss of alpha estimation: {:.4e}".format(
            avg_l_a.avg))
        writer.add_scalar("valid_loss/L_overall", avg_loss.avg, cur_iter)
        writer.add_scalar("valid_loss/L_t", avg_l_t.avg, cur_iter)
        writer.add_scalar("valid_loss/L_a", avg_l_a.avg, cur_iter)

        is_best = avg_loss.avg < best_loss
        best_loss = min(avg_loss.avg, best_loss)
        if is_best or args.save_ckpt:
            if not os.path.exists("ckpts"):
                os.makedirs("ckpts")
            save_checkpoint(ckpt_path=args.ckpt_path,
                            is_best=is_best,
                            logger=logger,
                            model=model,
                            optimizer=optimizer,
                            epoch=epoch,
                            cur_iter=cur_iter,
                            peak_lr=peak_lr,
                            best_loss=best_loss)

    writer.close()