def train_fn(train_loader, model, optimizer, device, epoch):
    total_loss = AverageMeter()
    accuracies = AverageMeter()

    model.train()

    t = tqdm(train_loader)
    for step, d in enumerate(t):
        spect = d["spect"].to(device)
        targets = d["target"].to(device)

        outputs = model(spect)

        loss = utility.loss_fn(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        # xm.optimizer_step(optimizer, barrier=True)
        optimizer.step()

        acc, n_position = utility.get_position_accuracy(outputs, targets)

        total_loss.update(loss.item(), n_position)
        accuracies.update(acc, n_position)

        t.set_description(
            f"Train E:{epoch + 1} - Loss:{total_loss.avg:0.4f} - Acc:{accuracies.avg:0.4f}"
        )

    return total_loss.avg


def valid_fn(valid_loader, model, device, epoch):
    total_loss = AverageMeter()
    accuracies = AverageMeter()

    model.eval()

    t = tqdm(valid_loader)
    for step, d in enumerate(t):
        with torch.no_grad():
            spect = d["spect"].to(device)
            targets = d["target"].to(device)

            outputs = model(spect)

            loss = utility.loss_fn(outputs, targets)

            acc, n_position = utility.get_position_accuracy(outputs, targets)

            total_loss.update(loss.item(), n_position)
            accuracies.update(acc, n_position)

            t.set_description(
                f"Eval E:{epoch + 1} - Loss:{total_loss.avg:0.4f} - Acc:{accuracies.avg:0.4f}"
            )

    return total_loss.avg, accuracies.avg
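
Every snippet on this page relies on an AverageMeter helper that the excerpts never define. Below is a minimal sketch of the usual implementation, written to match the val / avg / update(value, n) interface these examples call; the weighted count argument is inferred from the call sites.

class AverageMeter:
    """Track the latest value and a running (optionally weighted) average."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        # n is the number of samples that `val` was computed over
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count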
    def train(self):
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()

        self.model.train()
        end = time.time()

        for batch, (inputs, labels, _) in enumerate(tqdm(self.loader_train)):
            data_time.update(time.time() - end)

            inputs = inputs.cuda()
            labels = labels.cuda()

            r = np.random.rand()
            if r < self.args.prob_mix and self.args.mix_type != 'none':
                outputs, loss, labels = utility.mix_regularization(
                    inputs, labels, self.model, self.loss,
                    self.args.mix_type, self.args.mix_beta)
            else:
                # no mixing: plain forward pass
                outputs = self.model(inputs)
                loss = self.loss(outputs, labels.long())

            prec1, prec5 = utility.accuracy(outputs.data, labels, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1.item(), inputs.size(0))
            top5.update(prec5.item(), inputs.size(0))

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            batch_time.update(time.time() - end)
            end = time.time()

            if (batch + 1) % self.args.print_every == 0:
                print('-------------------------------------------------------')
                print_string = 'Epoch: [{0}][{1}/{2}]'.format(self.current_epoch + 1, batch + 1, len(self.loader_train))
                print(print_string)
                print_string = 'data_time: {data_time:.3f}, batch time: {batch_time:.3f}'.format(
                    data_time=data_time.val,
                    batch_time=batch_time.val)
                print(print_string)
                print_string = 'loss: {loss:.5f}'.format(loss=losses.avg)
                print(print_string)
                print_string = '[Training] Top-1 accuracy: {top1_acc:.2f}%, Top-5 accuracy: {top5_acc:.2f}%'.format(
                    top1_acc=top1.avg,
                    top5_acc=top5.avg)
                print(print_string)

        self.current_epoch += 1
        self.load_epoch += 1
        if self.current_epoch > self.warmup:
            self.scheduler.step()

        self.metrics['train_loss'].append(losses.avg)
        self.metrics['train_acc'].append(top1.avg)
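
utility.mix_regularization is not included in this excerpt. Here is a minimal mixup-style sketch consistent with the call site above; the name, signature, and return values come from the snippet, but the interpolation scheme (standard mixup, Zhang et al. 2018) is an assumption, and mix_type is left unused.

import numpy as np
import torch

def mix_regularization(inputs, labels, model, loss_fn, mix_type, beta):
    # Hypothetical sketch: interpolate pairs of inputs and mix their losses.
    # mix_type presumably selects a variant (e.g. mixup vs. cutmix);
    # only plain mixup is sketched here.
    lam = np.random.beta(beta, beta)
    perm = torch.randperm(inputs.size(0), device=inputs.device)
    mixed = lam * inputs + (1 - lam) * inputs[perm]
    outputs = model(mixed)
    loss = lam * loss_fn(outputs, labels.long()) \
        + (1 - lam) * loss_fn(outputs, labels[perm].long())
    return outputs, loss, labels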
Example #4
    def train_epoch(self, epoch=None):

        tt = tqdm(range(self.data_loader.train_iterations),
                  total=self.data_loader.train_iterations,
                  desc="epoch-{}-".format(epoch))

        loss_per_epoch = AverageMeter()

        for cur_it in tt:
            # run one training step on the current batch
            loss = self.train_step()
            # update metrics returned from train_step func
            loss_per_epoch.update(loss)

        self.sess.run(self.model.global_epoch_inc)

        self.model.save(self.sess, self.config.checkpoint_dir)

        print("""
        Epoch-{}  loss:{:.4f}
                """.format(epoch, loss_per_epoch.val))

        tt.close()
Example #5
    def test(self, epoch):

        # initialize tqdm
        tt = tqdm(range(self.data_loader.test_iterations),
                  total=self.data_loader.test_iterations,
                  desc="Val-{}-".format(epoch))

        loss_per_epoch = AverageMeter()

        # Iterate over batches
        for cur_it in tt:
            # one evaluation step on the current batch
            feed_dict = {}
            for step in range(self.model.time_steps):
                inputs, label = self.data_loader.next_batch()

                feed_dict[self.model.train_inputs[step]] = inputs.reshape(-1, 1)
                feed_dict[self.model.train_labels[step]] = label.reshape(-1, 1)

            feed_dict.update({self.is_training: False})

            loss = self.sess.run(self.loss_node, feed_dict=feed_dict)

            # update the running loss for this epoch
            loss_per_epoch.update(loss)

        # summarize
        # summaries_dict = {'test/loss_per_epoch': loss_per_epoch.val,
        #                   'test/acc_per_epoch': acc_per_epoch.val}
        # self.summarizer.summarize(self.model.global_step_tensor.eval(self.sess), summaries_dict)

        print("""
            Val-{}  loss:{:.4f}
                    """.format(epoch, loss_per_epoch.val))

        tt.close()
Example #6
def test(args, logger, device_ids):
    logger.info("Loading network")
    model = AdaMatting(in_channel=4)
    ckpt = torch.load("./ckpts/ckpt_best_alpha.tar")
    model.load_state_dict(ckpt["state_dict"])
    if args.cuda:
        device = torch.device("cuda:{}".format(device_ids[0]))
        if len(device_ids) > 1:
            logger.info("Loading with multiple GPUs")
            model = torch.nn.DataParallel(model, device_ids=device_ids)
        # model = model.cuda(device=device_ids[0])
    else:
        device = torch.device("cpu")
    model = model.to(device)
    torch.set_grad_enabled(False)
    model.eval()

    test_names = gen_test_names()

    with open(os.path.join(args.raw_data_path, "Combined_Dataset/Test_set/test_fg_names.txt")) as f:
        fg_files = f.read().splitlines()
    with open(os.path.join(args.raw_data_path, "Combined_Dataset/Test_set/test_bg_names.txt")) as f:
        bg_files = f.read().splitlines()

    out_path = os.path.join(args.raw_data_path, "pred/")
    if not os.path.exists(out_path):
        os.makedirs(out_path)

    logger.info("Start testing")
    avg_sad = AverageMeter()
    avg_mse = AverageMeter()
    for index, name in enumerate(test_names):
        # file names
        fcount, bcount = [int(x) for x in name.split('.')[0].split('_')]
        img_name = fg_files[fcount]
        bg_name = bg_files[bcount]
        merged_name = bg_name.split(".")[0] + "!" + img_name.split(".")[0] + "!" + str(fcount) + "!" + str(index) + ".png"
        trimap_name = img_name.split(".")[0] + "_" + str(index % 20) + ".png"

        # read files
        merged = os.path.join(args.raw_data_path, "test/merged/", merged_name)
        alpha = os.path.join(args.raw_data_path, "test/mask/", img_name)
        trimap = os.path.join(args.raw_data_path, "Combined_Dataset/Test_set/Adobe-licensed images/trimaps/", trimap_name)
        merged = cv.imread(merged)
        # merged = cv.resize(merged, None, fx=0.75, fy=0.75)
        merged = cv.cvtColor(merged, cv.COLOR_BGR2RGB)
        trimap = cv.imread(trimap)
        # trimap = cv.resize(trimap, None, fx=0.75, fy=0.75)
        alpha = cv.imread(alpha, 0)
        # alpha = cv.resize(alpha, None, fx=0.75, fy=0.75)

        # process merged image
        merged = transforms.ToPILImage()(merged)
        out_merged = merged.copy()
        merged = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])(merged)
        h, w = merged.shape[1:3]
        # crop to the largest size divisible by 16 so the strided layers divide evenly
        h_crop = h - h % 16
        h_margin = (h - h_crop) // 2
        w_crop = w - w % 16
        w_margin = (w - w_crop) // 2

        # write cropped gt alpha
        alpha = alpha[h_margin : h_margin + h_crop, w_margin : w_margin + w_crop]
        cv.imwrite(out_path + "{:04d}_gt_alpha.png".format(index), alpha)

        # generate and write cropped gt trimap
        gt_trimap = np.zeros(alpha.shape)
        gt_trimap.fill(128)
        gt_trimap[alpha <= 0] = 0
        gt_trimap[alpha >= 255] = 255
        cv.imwrite(out_path + "{:04d}_gt_trimap.png".format(index), gt_trimap)

        # concat the 4-d input and crop to feed the network properly
        x = torch.zeros((1, 4, h, w), dtype=torch.float)
        x[0, 0:3, :, :] = merged
        x[0, 3, :, :] = torch.from_numpy(trimap[:, :, 0] / 255.)
        x = x[:, :, h_margin : h_margin + h_crop, w_margin : w_margin + w_crop]

        # write cropped input images
        out_merged = transforms.ToTensor()(out_merged)
        out_merged = out_merged[:, h_margin : h_margin + h_crop, w_margin : w_margin + w_crop]
        out_merged = transforms.ToPILImage()(out_merged)
        out_merged.save(out_path + "{:04d}_input_merged.png".format(index))
        out_trimap = transforms.ToPILImage()(x[0, 3, :, :])
        out_trimap.save(out_path + "{:04d}_input_trimap.png".format(index))

        # test
        x = x.type(torch.FloatTensor).to(device)
        _, pred_trimap, pred_alpha, _, _ = model(x)

        cropped_trimap = x[0, 3, :, :].unsqueeze(dim=0).unsqueeze(dim=0)
        pred_alpha[cropped_trimap <= 0] = 0.0
        pred_alpha[cropped_trimap >= 1] = 1.0

        # output predicted images
        pred_trimap = (pred_trimap.type(torch.FloatTensor) / 2).unsqueeze(dim=1)
        pred_trimap = transforms.ToPILImage()(pred_trimap[0, :, :, :])
        pred_trimap.save(out_path + "{:04d}_pred_trimap.png".format(index))
        out_pred_alpha = transforms.ToPILImage()(pred_alpha[0, :, :, :].cpu())
        out_pred_alpha.save(out_path + "{:04d}_pred_alpha.png".format(index))
        
        sad = compute_sad(pred_alpha, alpha)
        mse = compute_mse(pred_alpha, alpha, trimap)
        avg_sad.update(sad.item())
        avg_mse.update(mse.item())
        logger.info("{:04d}/{} | SAD: {:.1f} | MSE: {:.3f} | Avg SAD: {:.1f} | Avg MSE: {:.3f}".format(index, len(test_names), sad.item(), mse.item(), avg_sad.avg, avg_mse.avg))
    
    logger.info("Average SAD: {:.1f} | Average MSE: {:.3f}".format(avg_sad.avg, avg_mse.avg))
Example #7
def train(args, logger, device_ids):
    writer = SummaryWriter()

    logger.info("Loading network")
    model = AdaMatting(in_channel=4)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=0.0001)
    if args.resume != "":
        ckpt = torch.load(args.resume)
        model.load_state_dict(ckpt["state_dict"])
        optimizer.load_state_dict(ckpt["optimizer"])
    if args.cuda:
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()
        device = torch.device("cuda:{}".format(device_ids[0]))
        if len(device_ids) > 1:
            logger.info("Loading with multiple GPUs")
            model = torch.nn.DataParallel(model, device_ids=device_ids)
            # model = convert_model(model)
    else:
        device = torch.device("cpu")
    model = model.to(device)

    logger.info("Initializing data loaders")
    train_dataset = AdaMattingDataset(args.raw_data_path, "train")
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, 
                                               num_workers=16, pin_memory=True, drop_last=True)
    valid_dataset = AdaMattingDataset(args.raw_data_path, "valid")
    valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False, 
                                               num_workers=16, pin_memory=True, drop_last=True)

    if args.resume != "":
        logger.info("Start training from saved ckpt")
        start_epoch = ckpt["epoch"] + 1
        cur_iter = ckpt["cur_iter"]
        peak_lr = ckpt["peak_lr"]
        best_loss = ckpt["best_loss"]
        best_alpha_loss = ckpt["best_alpha_loss"]
    else:
        logger.info("Start training from scratch")
        start_epoch = 0
        cur_iter = 0
        peak_lr = args.lr
        best_loss = float('inf')
        best_alpha_loss = float('inf')

    max_iter = 43100 * (1 - args.valid_portion / 100) / args.batch_size * args.epochs  # 43100 = 431 foregrounds x 100 backgrounds in the Adobe composition dataset
    tensorboard_iter = cur_iter * (args.batch_size / 16)

    avg_lo = AverageMeter()
    avg_lt = AverageMeter()
    avg_la = AverageMeter()
    for epoch in range(start_epoch, args.epochs):
        # Training
        torch.set_grad_enabled(True)
        model.train()
        for index, (_, inputs, gts) in enumerate(train_loader):
            # cur_lr, peak_lr = lr_scheduler(optimizer=optimizer, cur_iter=cur_iter, peak_lr=peak_lr, end_lr=0.000001, 
            #                                decay_iters=args.decay_iters, decay_power=0.8, power=0.5)
            cur_lr = lr_scheduler(optimizer=optimizer, init_lr=args.lr, cur_iter=cur_iter, max_iter=max_iter, 
                                  max_decay_times=30, decay_rate=0.9)
            
            # img = img.type(torch.FloatTensor).to(device) # [bs, 4, 320, 320]
            inputs = inputs.to(device)
            gt_alpha = (gts[:, 0, :, :].unsqueeze(1)).type(torch.FloatTensor).to(device) # [bs, 1, 320, 320]
            gt_trimap = gts[:, 1, :, :].type(torch.LongTensor).to(device) # [bs, 320, 320]

            trimap_adaption, t_argmax, alpha_estimation, log_sigma_t_sqr, log_sigma_a_sqr = model(inputs)

            L_overall, L_t, L_a = task_uncertainty_loss(pred_trimap=trimap_adaption, pred_trimap_argmax=t_argmax, 
                                                        pred_alpha=alpha_estimation, gt_trimap=gt_trimap, gt_alpha=gt_alpha, 
                                                        log_sigma_t_sqr=log_sigma_t_sqr, log_sigma_a_sqr=log_sigma_a_sqr)

            sigma_t, sigma_a = torch.exp(log_sigma_t_sqr.mean() / 2), torch.exp(log_sigma_a_sqr.mean() / 2)

            optimizer.zero_grad()
            L_overall.backward()
            clip_gradient(optimizer, 5)
            optimizer.step()

            avg_lo.update(L_overall.item())
            avg_lt.update(L_t.item())
            avg_la.update(L_a.item())

            if cur_iter % 10 == 0:
                logger.info("Epoch: {:03d} | Iter: {:05d}/{} | Loss: {:.4e} | L_t: {:.4e} | L_a: {:.4e}"
                            .format(epoch, index, len(train_loader), avg_lo.avg, avg_lt.avg, avg_la.avg))
                writer.add_scalar("loss/L_overall", avg_lo.avg, tensorboard_iter)
                writer.add_scalar("loss/L_t", avg_lt.avg, tensorboard_iter)
                writer.add_scalar("loss/L_a", avg_la.avg, tensorboard_iter)
                writer.add_scalar("other/sigma_t", sigma_t.item(), tensorboard_iter)
                writer.add_scalar("other/sigma_a", sigma_a.item(), tensorboard_iter)
                writer.add_scalar("other/lr", cur_lr, tensorboard_iter)

                avg_lo.reset()
                avg_lt.reset()
                avg_la.reset()
                
            cur_iter += 1
            tensorboard_iter = cur_iter * (args.batch_size / 16)

        # Validation
        logger.info("Validating after the {}th epoch".format(epoch))
        avg_loss = AverageMeter()
        avg_l_t = AverageMeter()
        avg_l_a = AverageMeter()
        torch.cuda.empty_cache()
        torch.set_grad_enabled(False)
        model.eval()
        with tqdm(total=len(valid_loader)) as pbar:
            for index, (display_rgb, inputs, gts) in enumerate(valid_loader):
                inputs = inputs.to(device) # [bs, 4, 320, 320]
                gt_alpha = (gts[:, 0, :, :].unsqueeze(1)).type(torch.FloatTensor).to(device) # [bs, 1, 320, 320]
                gt_trimap = gts[:, 1, :, :].type(torch.LongTensor).to(device) # [bs, 320, 320]

                trimap_adaption, t_argmax, alpha_estimation, log_sigma_t_sqr, log_sigma_a_sqr = model(inputs)
                L_overall_valid, L_t_valid, L_a_valid = task_uncertainty_loss(pred_trimap=trimap_adaption, pred_trimap_argmax=t_argmax, 
                                                            pred_alpha=alpha_estimation, gt_trimap=gt_trimap, gt_alpha=gt_alpha, 
                                                            log_sigma_t_sqr=log_sigma_t_sqr, log_sigma_a_sqr=log_sigma_a_sqr)

                avg_loss.update(L_overall_valid.item())
                avg_l_t.update(L_t_valid.item())
                avg_l_a.update(L_a_valid.item())

                if index == 0:
                    input_rgb = torchvision.utils.make_grid(display_rgb, normalize=False, scale_each=True)
                    writer.add_image('input/rgb_image', input_rgb, tensorboard_iter)

                    input_trimap = inputs[:, 3, :, :].unsqueeze(dim=1)
                    input_trimap = torchvision.utils.make_grid(input_trimap, normalize=False, scale_each=True)
                    writer.add_image('input/trimap', input_trimap, tensorboard_iter)

                    output_alpha = alpha_estimation.clone()
                    output_alpha[t_argmax.unsqueeze(dim=1) == 0] = 0.0
                    output_alpha[t_argmax.unsqueeze(dim=1) == 2] = 1.0
                    output_alpha = torchvision.utils.make_grid(output_alpha, normalize=False, scale_each=True)
                    writer.add_image('output/alpha', output_alpha, tensorboard_iter)

                    trimap_adaption_res = (t_argmax.type(torch.FloatTensor) / 2).unsqueeze(dim=1)
                    trimap_adaption_res = torchvision.utils.make_grid(trimap_adaption_res, normalize=False, scale_each=True)
                    writer.add_image('pred/trimap_adaptation', trimap_adaption_res, tensorboard_iter)

                    alpha_estimation_res = torchvision.utils.make_grid(alpha_estimation, normalize=False, scale_each=True)
                    writer.add_image('pred/alpha_estimation', alpha_estimation_res, tensorboard_iter)

                    gt_alpha = torchvision.utils.make_grid(gt_alpha, normalize=False, scale_each=True)
                    writer.add_image('gt/alpha', gt_alpha, tensorboard_iter)

                    gt_trimap = (gt_trimap.type(torch.FloatTensor) / 2).unsqueeze(dim=1)
                    gt_trimap = torchvision.utils.make_grid(gt_trimap, normalize=False, scale_each=True)
                    writer.add_image('gt/trimap', gt_trimap, tensorboard_iter)
                    
                pbar.update()

        logger.info("Average loss overall: {:.4e}".format(avg_loss.avg))
        logger.info("Average loss of trimap adaptation: {:.4e}".format(avg_l_t.avg))
        logger.info("Average loss of alpha estimation: {:.4e}".format(avg_l_a.avg))
        writer.add_scalar("valid_loss/L_overall", avg_loss.avg, tensorboard_iter)
        writer.add_scalar("valid_loss/L_t", avg_l_t.avg, tensorboard_iter)
        writer.add_scalar("valid_loss/L_a", avg_l_a.avg, tensorboard_iter)

        is_best = avg_loss.avg < best_loss
        best_loss = min(avg_loss.avg, best_loss)
        is_alpha_best = avg_l_a.avg < best_alpha_loss
        best_alpha_loss = min(avg_l_a.avg, best_alpha_loss)
        if is_best or is_alpha_best or args.save_ckpt:
            if not os.path.exists("ckpts"):
                os.makedirs("ckpts")
            save_checkpoint(ckpt_path=args.ckpt_path, is_best=is_best, is_alpha_best=is_alpha_best, logger=logger, model=model, optimizer=optimizer, 
                            epoch=epoch, cur_iter=cur_iter, peak_lr=peak_lr, best_loss=best_loss, best_alpha_loss=best_alpha_loss)

    writer.close()
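
The lr_scheduler used above is not part of the excerpt. Here is a sketch of a stepwise decay matching the keyword arguments at the call site (the decay rule itself is an assumption): the learning rate is multiplied by decay_rate at max_decay_times evenly spaced points over max_iter iterations.

def lr_scheduler(optimizer, init_lr, cur_iter, max_iter, max_decay_times, decay_rate):
    # number of decay milestones passed so far
    decay_steps = int(cur_iter * max_decay_times / max(max_iter, 1))
    lr = init_lr * (decay_rate ** min(decay_steps, max_decay_times))
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr
    return lr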
Example #8
def train(model, optimizer, device, args, logger, multi_gpu):
    torch.manual_seed(7)
    writer = SummaryWriter()

    logger.info("Initializing data loaders")
    train_dataset = AdaMattingDataset(args.raw_data_path, 'train')
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, 
                                               num_workers=16, pin_memory=True)
    valid_dataset = AdaMattingDataset(args.raw_data_path, 'valid')
    valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False, 
                                               num_workers=16, pin_memory=True)

    if args.resume:
        logger.info("Start training from saved ckpt")
        ckpt = torch.load(args.ckpt_path)
        model = ckpt["model"].module
        model = model.to(device)
        optimizer = ckpt["optimizer"]

        start_epoch = ckpt["epoch"] + 1
        max_iter = ckpt["max_iter"]
        cur_iter = ckpt["cur_iter"]
        init_lr = ckpt["init_lr"]
        best_loss = ckpt["best_loss"]
    else:
        logger.info("Start training from scratch")
        start_epoch = 0
        max_iter = 43100 * (1 - args.valid_portion) / args.batch_size * args.epochs
        cur_iter = 0
        init_lr = args.lr
        best_loss = float('inf')
    
    for epoch in range(start_epoch, args.epochs):
        # Training
        torch.set_grad_enabled(True)
        model.train()
        for index, (img, gt) in enumerate(train_loader):
            cur_lr = poly_lr_scheduler(optimizer=optimizer, init_lr=init_lr, iter=cur_iter, max_iter=max_iter)

            img = img.type(torch.FloatTensor).to(device) # [bs, 4, 320, 320]
            gt_alpha = (gt[:, 0, :, :].unsqueeze(1)).type(torch.FloatTensor).to(device) # [bs, 1, 320, 320]
            gt_trimap = gt[:, 1, :, :].type(torch.LongTensor).to(device) # [bs, 320, 320]

            trimap_adaption, t_argmax, alpha_estimation = model(img)
            L_overall, L_t, L_a = task_uncertainty_loss(pred_trimap=trimap_adaption, pred_trimap_argmax=t_argmax, 
                                                        pred_alpha=alpha_estimation, gt_trimap=gt_trimap, 
                                                        gt_alpha=gt_alpha, log_sigma_t_sqr=model.log_sigma_t_sqr, log_sigma_a_sqr=model.log_sigma_a_sqr)
            # if multi_gpu:
            #     L_overall, L_t, L_a = L_overall.mean(), L_t.mean(), L_a.mean()
            optimizer.zero_grad()
            L_overall.backward()
            optimizer.step()

            if cur_iter % 10 == 0:
                logger.info("Epoch: {:03d} | Iter: {:05d}/{} | Loss: {:.4e} | L_t: {:.4e} | L_a: {:.4e}"
                            .format(epoch, index, len(train_loader), L_overall.item(), L_t.item(), L_a.item()))
                writer.add_scalar("loss/L_overall", L_overall.item(), cur_iter)
                writer.add_scalar("loss/L_t", L_t.item(), cur_iter)
                writer.add_scalar("loss/L_a", L_a.item(), cur_iter)
                sigma_t = torch.exp(model.log_sigma_t_sqr / 2)
                sigma_a = torch.exp(model.log_sigma_a_sqr / 2)
                writer.add_scalar("sigma/sigma_t", sigma_t.item(), cur_iter)
                writer.add_scalar("sigma/sigma_a", sigma_a.item(), cur_iter)
                writer.add_scalar("lr", cur_lr, cur_iter)
            
            cur_iter += 1
        
        # Validation
        logger.info("Validating after the {}th epoch".format(epoch))
        avg_loss = AverageMeter()
        avg_l_t = AverageMeter()
        avg_l_a = AverageMeter()
        torch.cuda.empty_cache()
        torch.set_grad_enabled(False)
        model.eval()
        with tqdm(total=len(valid_loader)) as pbar:
            for index, (img, gt) in enumerate(valid_loader):
                img = img.type(torch.FloatTensor).to(device) # [bs, 4, 320, 320]
                gt_alpha = (gt[:, 0, :, :].unsqueeze(1)).type(torch.FloatTensor).to(device) # [bs, 1, 320, 320]
                gt_trimap = gt[:, 1, :, :].type(torch.LongTensor).to(device) # [bs, 320, 320]

                trimap_adaption, t_argmax, alpha_estimation = model(img)
                L_overall_valid, L_t_valid, L_a_valid = task_uncertainty_loss(pred_trimap=trimap_adaption, pred_trimap_argmax=t_argmax, 
                                                            pred_alpha=alpha_estimation, gt_trimap=gt_trimap, 
                                                            gt_alpha=gt_alpha, log_sigma_t_sqr=model.log_sigma_t_sqr, log_sigma_a_sqr=model.log_sigma_a_sqr)
                # if multi_gpu:
                #     L_overall, L_t, L_a = L_overall.mean(), L_t.mean(), L_a.mean()
                avg_loss.update(L_overall_valid.item())
                avg_l_t.update(L_t_valid.item())
                avg_l_a.update(L_a_valid.item())

                if index == 0:
                    trimap_adaption_res = torchvision.utils.make_grid(t_argmax.type(torch.FloatTensor) / 2, normalize=True, scale_each=True)
                    writer.add_image('valid_image/trimap_adaptation', trimap_adaption_res, cur_iter)
                    alpha_estimation_res = torchvision.utils.make_grid(alpha_estimation, normalize=True, scale_each=True)
                    writer.add_image('valid_image/alpha_estimation', alpha_estimation_res, cur_iter)
                
                pbar.update()

        logger.info("Average loss overall: {:.4e}".format(avg_loss.avg))
        logger.info("Average loss of trimap adaptation: {:.4e}".format(avg_l_t.avg))
        logger.info("Average loss of alpha estimation: {:.4e}".format(avg_l_a.avg))
        writer.add_scalar("valid_loss/L_overall", avg_loss.avg, cur_iter)
        writer.add_scalar("valid_loss/L_t", avg_l_t.avg, cur_iter)
        writer.add_scalar("valid_loss/L_a", avg_l_a.avg, cur_iter)

        is_best = avg_loss.avg < best_loss
        best_loss = min(avg_loss.avg, best_loss)
        if is_best or (args.save_ckpt and epoch % 10 == 0):
            if not os.path.exists("ckpts"):
                os.makedirs("ckpts")
            logger.info("Checkpoint saved")
            if is_best:
                logger.info("Best checkpoint saved")
            save_checkpoint(epoch, model, optimizer, cur_iter, max_iter, init_lr, avg_loss.avg, is_best, args.ckpt_path)

    writer.export_scalars_to_json("./all_scalars.json")
    writer.close()
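
poly_lr_scheduler (used in the training loop above) is likewise not shown. A sketch of the standard polynomial decay schedule the name suggests, with the power exponent assumed to be 0.9; the parameter name `iter` matches the keyword used at the call site.

def poly_lr_scheduler(optimizer, init_lr, iter, max_iter, power=0.9):
    # lr decays from init_lr toward 0 along (1 - iter / max_iter) ** power
    lr = init_lr * (1 - min(iter, max_iter) / max_iter) ** power
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr
    return lr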
Example #9
def train(args, logger, device_ids):
    torch.manual_seed(7)
    writer = SummaryWriter()

    logger.info("Loading network")
    model = AdaMatting(in_channel=4)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=0)
    if args.resume != "":
        ckpt = torch.load(args.resume)
        # for key, _ in ckpt.items():
        #     print(key)
        model.load_state_dict(ckpt["state_dict"])
        optimizer.load_state_dict(ckpt["optimizer"])
    if args.cuda:
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()
        device = torch.device("cuda:{}".format(device_ids[0]))
        if len(device_ids) > 1:
            logger.info("Loading with multiple GPUs")
            model = torch.nn.DataParallel(model, device_ids=device_ids)
        # model = model.cuda(device=device_ids[0])
    else:
        device = torch.device("cpu")
    model = model.to(device)

    logger.info("Initializing data loaders")
    train_dataset = AdaMattingDataset(args.raw_data_path, "train")
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=16,
                                               pin_memory=True)
    valid_dataset = AdaMattingDataset(args.raw_data_path, "valid")
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=16,
                                               pin_memory=True)

    if args.resume != "":
        logger.info("Start training from saved ckpt")
        start_epoch = ckpt["epoch"] + 1
        cur_iter = ckpt["cur_iter"] + 1
        peak_lr = ckpt["peak_lr"]
        best_loss = ckpt["best_loss"]
    else:
        logger.info("Start training from scratch")
        start_epoch = 0
        cur_iter = 0
        peak_lr = args.lr
        best_loss = float('inf')

    avg_lo = AverageMeter()
    avg_lt = AverageMeter()
    avg_la = AverageMeter()
    for epoch in range(start_epoch, args.epochs):
        # Training
        torch.set_grad_enabled(True)
        model.train()
        for index, (img, gt) in enumerate(train_loader):
            cur_lr, peak_lr = lr_scheduler(optimizer=optimizer,
                                           cur_iter=cur_iter,
                                           peak_lr=peak_lr,
                                           end_lr=0.00001,
                                           decay_iters=args.decay_iters,
                                           decay_power=0.9,
                                           power=0.9)

            img = img.type(torch.FloatTensor).to(device)  # [bs, 4, 320, 320]
            gt_alpha = gt[:, 0, :, :].unsqueeze(1).type(torch.FloatTensor).to(device)  # [bs, 1, 320, 320]
            gt_trimap = gt[:, 1, :, :].type(torch.LongTensor).to(device)  # [bs, 320, 320]

            trimap_adaption, t_argmax, alpha_estimation, log_sigma_t_sqr, log_sigma_a_sqr = model(img)
            L_overall, L_t, L_a = task_uncertainty_loss(
                pred_trimap=trimap_adaption,
                pred_trimap_argmax=t_argmax,
                pred_alpha=alpha_estimation,
                gt_trimap=gt_trimap,
                gt_alpha=gt_alpha,
                log_sigma_t_sqr=log_sigma_t_sqr,
                log_sigma_a_sqr=log_sigma_a_sqr)

            L_overall, L_t, L_a = L_overall.mean(), L_t.mean(), L_a.mean()
            log_sigma_t_sqr_mean = log_sigma_t_sqr.mean()
            log_sigma_a_sqr_mean = log_sigma_a_sqr.mean()

            optimizer.zero_grad()
            L_overall.backward()
            optimizer.step()

            avg_lo.update(L_overall.item())
            avg_lt.update(L_t.item())
            avg_la.update(L_a.item())

            if cur_iter % 10 == 0:
                logger.info(
                    "Epoch: {:03d} | Iter: {:05d}/{} | Loss: {:.4e} | L_t: {:.4e} | L_a: {:.4e}"
                    .format(epoch, index, len(train_loader), avg_lo.avg,
                            avg_lt.avg, avg_la.avg))
                writer.add_scalar("loss/L_overall", avg_lo.avg, cur_iter)
                writer.add_scalar("loss/L_t", avg_lt.avg, cur_iter)
                writer.add_scalar("loss/L_a", avg_la.avg, cur_iter)
                sigma_t = torch.exp(log_sigma_t_sqr_mean / 2)
                sigma_a = torch.exp(log_sigma_a_sqr_mean / 2)
                writer.add_scalar("other/sigma_t", sigma_t.item(), cur_iter)
                writer.add_scalar("other/sigma_a", sigma_a.item(), cur_iter)
                writer.add_scalar("other/lr", cur_lr, cur_iter)

                avg_lo.reset()
                avg_lt.reset()
                avg_la.reset()

            cur_iter += 1

        # Validation
        logger.info("Validating after the {}th epoch".format(epoch))
        avg_loss = AverageMeter()
        avg_l_t = AverageMeter()
        avg_l_a = AverageMeter()
        torch.cuda.empty_cache()
        torch.set_grad_enabled(False)
        model.eval()
        with tqdm(total=len(valid_loader)) as pbar:
            for index, (img, gt) in enumerate(valid_loader):
                img = img.type(torch.FloatTensor).to(device)  # [bs, 4, 320, 320]
                gt_alpha = gt[:, 0, :, :].unsqueeze(1).type(torch.FloatTensor).to(device)  # [bs, 1, 320, 320]
                gt_trimap = gt[:, 1, :, :].type(torch.LongTensor).to(device)  # [bs, 320, 320]

                trimap_adaption, t_argmax, alpha_estimation, log_sigma_t_sqr, log_sigma_a_sqr = model(img)
                L_overall_valid, L_t_valid, L_a_valid = task_uncertainty_loss(
                    pred_trimap=trimap_adaption,
                    pred_trimap_argmax=t_argmax,
                    pred_alpha=alpha_estimation,
                    gt_trimap=gt_trimap,
                    gt_alpha=gt_alpha,
                    log_sigma_t_sqr=log_sigma_t_sqr,
                    log_sigma_a_sqr=log_sigma_a_sqr)

                L_overall_valid, L_t_valid, L_a_valid = (
                    L_overall_valid.mean(), L_t_valid.mean(), L_a_valid.mean())

                avg_loss.update(L_overall_valid.item())
                avg_l_t.update(L_t_valid.item())
                avg_l_a.update(L_a_valid.item())

                if index == 0:
                    trimap_adaption_res = (t_argmax.type(torch.FloatTensor) /
                                           2).unsqueeze(dim=1)
                    trimap_adaption_res = torchvision.utils.make_grid(
                        trimap_adaption_res, normalize=False, scale_each=True)
                    writer.add_image('valid_image/trimap_adaptation',
                                     trimap_adaption_res, cur_iter)
                    alpha_estimation_res = torchvision.utils.make_grid(
                        alpha_estimation, normalize=True, scale_each=True)
                    writer.add_image('valid_image/alpha_estimation',
                                     alpha_estimation_res, cur_iter)

                pbar.update()

        logger.info("Average loss overall: {:.4e}".format(avg_loss.avg))
        logger.info("Average loss of trimap adaptation: {:.4e}".format(
            avg_l_t.avg))
        logger.info("Average loss of alpha estimation: {:.4e}".format(
            avg_l_a.avg))
        writer.add_scalar("valid_loss/L_overall", avg_loss.avg, cur_iter)
        writer.add_scalar("valid_loss/L_t", avg_l_t.avg, cur_iter)
        writer.add_scalar("valid_loss/L_a", avg_l_a.avg, cur_iter)

        is_best = avg_loss.avg < best_loss
        best_loss = min(avg_loss.avg, best_loss)
        if is_best or args.save_ckpt:
            if not os.path.exists("ckpts"):
                os.makedirs("ckpts")
            save_checkpoint(ckpt_path=args.ckpt_path,
                            is_best=is_best,
                            logger=logger,
                            model=model,
                            optimizer=optimizer,
                            epoch=epoch,
                            cur_iter=cur_iter,
                            peak_lr=peak_lr,
                            best_loss=best_loss)

    writer.close()
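
task_uncertainty_loss appears in the last three examples without its definition. Below is a sketch of the homoscedastic-uncertainty weighting its name and arguments suggest (after Kendall et al., 2018); the individual trimap and alpha losses are assumptions, and only the weighting form is taken from that paper.

import torch
import torch.nn.functional as F

def task_uncertainty_loss(pred_trimap, pred_trimap_argmax, pred_alpha,
                          gt_trimap, gt_alpha, log_sigma_t_sqr, log_sigma_a_sqr):
    # trimap adaptation: 3-way cross-entropy over {background, unknown, foreground}
    L_t = F.cross_entropy(pred_trimap, gt_trimap)
    # alpha estimation: L1 loss inside the predicted unknown region (assumed)
    unknown = (pred_trimap_argmax == 1).unsqueeze(1).float()
    L_a = (torch.abs(pred_alpha - gt_alpha) * unknown).sum() / unknown.sum().clamp(min=1)
    # uncertainty-weighted combination (Kendall et al., 2018):
    # L = L_t * exp(-log_sigma_t_sqr) / 2 + L_a * exp(-log_sigma_a_sqr) / 2
    #     + (log_sigma_t_sqr + log_sigma_a_sqr) / 2
    L_overall = (L_t * torch.exp(-log_sigma_t_sqr).mean() / 2
                 + L_a * torch.exp(-log_sigma_a_sqr).mean() / 2
                 + (log_sigma_t_sqr.mean() + log_sigma_a_sqr.mean()) / 2)
    return L_overall, L_t, L_a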
    def test(self):
        if self.current_epoch % self.args.test_every == 0:
            batch_time = AverageMeter()
            data_time = AverageMeter()
            top1 = AverageMeter()
            top5 = AverageMeter()
            self.model.eval()

            end = time.time()

            with torch.no_grad():
                for batch, (inputs, labels, filename) in enumerate(tqdm(self.loader_test)):
                    data_time.update(time.time() - end)

                    _, _, len_of_frame, height, width = inputs.size()
                    spatial_stride = (width - self.args.crop_size) // 2
                    stride = len_of_frame / 10
                    if len_of_frame <= self.args.clip_len:
                        avail_number = 0
                        new_stride = 0  # single sub-clip, so no temporal stride is needed
                        new_len = len_of_frame
                    else:
                        last_subclip_end_idx = round(min(len_of_frame, len_of_frame - (stride / 2) + (self.args.clip_len / 2)))
                        last_subclip_begin_idx = last_subclip_end_idx - self.args.clip_len

                        avail_number = min(last_subclip_begin_idx, 9)
                        new_stride = last_subclip_begin_idx / float(max(avail_number, 1))  # guard against zero sub-clips
                        new_len = self.args.clip_len

                    # Per View Test
                    begin_idx = 0
                    for t in range(avail_number + 1):
                        end_idx = begin_idx + new_len

                        sub_inputs_t = inputs[:, :, begin_idx:end_idx, :, :]
                        if self.args.test_view == 30:
                            begin_spatial_idx = 0
                            for st in range(3):
                                end_spatial_idx = begin_spatial_idx + self.args.crop_size
                                sub_inputs_st = sub_inputs_t[:, :, :, :, begin_spatial_idx:end_spatial_idx]
                                begin_spatial_idx = begin_spatial_idx + spatial_stride

                                sub_inputs_st = sub_inputs_st.cuda()
                                if t == 0 and st == 0:
                                    outputs = torch.nn.Softmax(dim=1)(self.model(sub_inputs_st))
                                else:
                                    outputs = outputs + torch.nn.Softmax(dim=1)(self.model(sub_inputs_st))
                        else:
                            sub_inputs_t = sub_inputs_t.cuda()
                            if t == 0:
                                outputs = torch.nn.Softmax(dim=1)(self.model(sub_inputs_t))
                            else:
                                outputs = outputs + torch.nn.Softmax(dim=1)(self.model(sub_inputs_t))

                        # idx update
                        begin_idx = round(begin_idx + new_stride)

                    if self.args.test_view == 10:
                        outputs = outputs / (avail_number + 1)
                    else:
                        outputs = outputs / (3 * (avail_number + 1))
                    labels = labels.cuda()

                    if self.final_test:
                        # Write Prediction into Text File Here
                        final_array = utility.inference(outputs.data)

                        # write [filename final_array] and [newline]
                        self.logfile.write(filename[0][-22:])
                        for tops in range(5):
                            data_msg = ' {0}'.format(final_array[tops])
                            self.logfile.write(data_msg)
                        self.logfile.write('\n')
                    else:
                        # measure accuracy and record loss
                        prec1, prec5 = utility.accuracy(outputs.data, labels, topk=(1, 5))
                        top1.update(prec1.item(), inputs.size(0))
                        top5.update(prec5.item(), inputs.size(0))
                        batch_time.update(time.time() - end)
                        end = time.time()

            if self.args.is_validate:
                print('----Validation Results Summary----')
                print_string = 'Epoch: [{0}]'.format(self.current_epoch)
                print(print_string)
                print_string = '----------------------------- Top-1 accuracy: {top1_acc:.2f}%'.format(top1_acc=top1.avg)
                print(print_string)
                print_string = '----------------------------- Top-5 accuracy: {top5_acc:.2f}%'.format(top5_acc=top5.avg)
                print(print_string)

            # save model per epoch
            if not self.args.test_only:
                if self.current_epoch % self.args.save_every == 0:
                    torch.save({'epoch': self.current_epoch, 'state_dict': self.model.state_dict(),
                                'optimizer_state_dict': self.optimizer.state_dict(),
                                'scheduler_state_dict': self.scheduler.state_dict()},
                               self.ckpt_dir + '/model_epoch' + str(self.current_epoch).zfill(3) + '.pth')
            self.metrics['val_acc'].append(top1.avg)
            self.metrics['val_acc_top5'].append(top5.avg)
        else:
            self.metrics['val_acc'].append(0.)
            self.metrics['val_acc_top5'].append(0.)

        # Write logs
        if not self.args.test_only:
            with open(self.args.out_dir + '/log_epoch.csv', 'a') as epoch_log:
                if not self.args.load:
                    epoch_log.write('{}, {:.5f}, {:.5f}, {:.5f}, {:.5f}\n'.format(
                            self.current_epoch, self.metrics['train_loss'][self.current_epoch-1], 
                            self.metrics['train_acc'][self.current_epoch-1], 
                            self.metrics['val_acc'][self.current_epoch-1], self.metrics['val_acc_top5'][self.current_epoch-1]))
                    plot_learning_curves(self.metrics, self.current_epoch, self.args)
                else:
                    epoch_log.write('{}, {:.5f}, {:.5f}, {:.5f}, {:.5f}\n'.format(
                            self.current_epoch, self.metrics['train_loss'][self.load_epoch-1], 
                            self.metrics['train_acc'][self.load_epoch-1], 
                            self.metrics['val_acc'][self.load_epoch-1], self.metrics['val_acc_top5'][self.load_epoch-1]))
                    plot_learning_curves(self.metrics, self.load_epoch, self.args)
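
Finally, the utility.accuracy helper used by the classification examples above is the standard top-k accuracy function; a sketch, assuming logits of shape [batch, classes] and percentage-scaled results to match the call sites.

import torch

def accuracy(output, target, topk=(1,)):
    # fraction of samples whose true class appears in the top-k predictions
    maxk = max(topk)
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()                                   # [maxk, batch]
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / target.size(0)))
    return res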