Пример #1
0
def test(model):
    """Evaluate ``model`` on the composited test set and return SAD / MSE.

    Each name produced by ``gen_test_names()`` is parsed as
    ``"<fg_index>_<bg_index>.<ext>"``. Those indices select a foreground
    file (``fg_test_files``) and a background file (``bg_test_files``).
    For each name the function:

    1. reads the matching trimap;
    2. gets the test image, ground-truth alpha and aligned trimap from
       ``process_test``;
    3. runs the model on a ``[1, 4, h, w]`` input (transformed RGB + trimap
       scaled to [0, 1]);
    4. forces the known trimap regions of the prediction to 0 / 1;
    5. writes the prediction (scaled by 255) to ``images/test/out/``;
    6. accumulates SAD and MSE.

    Args:
        model: network taking a ``[1, 4, h, w]`` float tensor; its output
            is reshaped to ``(h, w)``.

    Returns:
        Tuple ``(sad_avg, mse_avg)`` averaged over all test names.

    Raises:
        FileNotFoundError: if a trimap image cannot be read.
    """
    # The trimap suffix cycles 0..19 over consecutive names. This relies on
    # gen_test_names() ordering (presumably 20 trimaps per foreground) --
    # TODO confirm.
    trimaps_per_fg = 20
    trimap_dir = 'data/Combined_Dataset/Test_set/Adobe-licensed images/trimaps/'

    model.eval()

    transformer = data_transforms['valid']
    names = gen_test_names()

    mse_losses = AverageMeter()
    sad_losses = AverageMeter()

    for idx, name in enumerate(tqdm(names)):
        stem_parts = name.split('.')[0].split('_')
        fcount = int(stem_parts[0])
        bcount = int(stem_parts[1])
        im_name = fg_test_files[fcount]
        bg_name = bg_test_files[bcount]
        trimap_name = im_name.split('.')[0] + '_' + str(idx % trimaps_per_fg) + '.png'

        trimap = cv.imread(trimap_dir + trimap_name, 0)
        if trimap is None:
            # cv.imread returns None instead of raising on a missing or
            # unreadable file; fail here rather than deep inside process_test.
            raise FileNotFoundError(trimap_dir + trimap_name)

        img, alpha, fg, bg, new_trimap = process_test(im_name, bg_name, trimap, trimap_name)
        h, w = img.shape[:2]

        # 4-channel network input: transformed RGB + trimap scaled to [0, 1].
        x = torch.zeros((1, 4, h, w), dtype=torch.float)
        img = img[..., ::-1]  # reverse channel order to RGB (assumes BGR from OpenCV)
        img = transforms.ToPILImage()(img)
        img = transformer(img)
        x[0, 0:3, :, :] = img
        x[0, 3, :, :] = torch.from_numpy(new_trimap.copy() / 255.)

        # x is already float32; only the device move is needed.
        x = x.to(device)
        alpha = alpha / 255.

        with torch.no_grad():
            pred = model(x)

        pred = pred.cpu().numpy().reshape((h, w))

        # Known regions of the trimap override the prediction:
        # 0 -> background (alpha 0), 255 -> foreground (alpha 1).
        pred[new_trimap == 0] = 0.0
        pred[new_trimap == 255] = 1.0
        cv.imwrite('images/test/out/' + trimap_name, pred * 255)

        # new_trimap is the trimap aligned with pred/alpha (it is the mask used
        # just above), so the metric uses it rather than the raw input trimap.
        mse_loss = compute_mse(pred, alpha, new_trimap)
        sad_loss = compute_sad(pred, alpha)

        mse_losses.update(mse_loss.item())
        sad_losses.update(sad_loss.item())

    return sad_losses.avg, mse_losses.avg
Пример #2
0
def val(val_loader, model):
    """Evaluate ``model`` on ``val_loader``, print metrics and save alpha images.

    For every batch this computes SAD, MSE, gradient and connectivity errors
    of the predicted alpha against the ground truth. It prints them, stamps
    them onto the predicted alpha image and writes that image to
    ``images/test/out/<output_folder>/<file name of img_path[0]>``. The
    per-batch metrics are accumulated and their averages printed at the end.

    Args:
        val_loader: iterable of ``(img, alpha_label, trimap_label, img_path)``
            batches.
        model: network returning ``(trimap_out, alpha_out)``; only the alpha
            prediction is used here.
    """
    mse_losses = AverageMeter()
    sad_losses = AverageMeter()
    gradient_losses = AverageMeter()
    connectivity_losses = AverageMeter()

    model.eval()

    # Evaluation only: skip building the autograd graph.
    with torch.no_grad():
        for img, alpha_label, trimap_label, img_path in val_loader:
            # Move to GPU, if available
            img = img.type(torch.FloatTensor).to(device)
            alpha_label = alpha_label.type(torch.FloatTensor).to(device)
            # Add a channel axis to the ground truth; presumably to match the
            # layout of alpha_out -- TODO confirm against the model.
            alpha_label = alpha_label.unsqueeze(1)
            trimap_label = trimap_label.to(device)

            # Forward prop. The predicted trimap is not used by this routine.
            _, alpha_out = model(img)

            mse_loss = compute_mse(alpha_out, alpha_label, trimap_label)
            sad_loss = compute_sad(alpha_out, alpha_label)
            gradient_loss = compute_gradient_loss(alpha_out, alpha_label,
                                                  trimap_label)
            connectivity_loss = compute_connectivity_error(alpha_out, alpha_label,
                                                           trimap_label)

            msg = "sad:{} mse:{} gradient: {} connectivity: {}".format(
                sad_loss.item(), mse_loss.item(), gradient_loss,
                connectivity_loss)
            print(msg)

            # Accumulate so the final summary reports real averages.
            mse_losses.update(mse_loss.item())
            sad_losses.update(sad_loss.item())
            gradient_losses.update(float(gradient_loss))
            connectivity_losses.update(float(connectivity_loss))

            # draw_str/cv.imwrite need a NumPy array; torch tensors have no
            # .copy()/.astype(), so convert first.
            if torch.is_tensor(alpha_out):
                alpha_out = alpha_out.detach().cpu().numpy()
            # Drop singleton axes; assumes batch size 1, since only
            # img_path[0] names the output file -- TODO confirm.
            alpha_img = np.clip(np.squeeze(alpha_out) * 255, 0, 255).astype(np.uint8)
            draw_str(alpha_img, (10, 20), msg)
            cv.imwrite(
                os.path.join('images/test/out/', output_folder,
                             img_path[0].split('/')[-1]), alpha_img)

    print("sad_avg:{} mse_avg:{} gradient_avg: {} connectivity_avg: {}".format(
        sad_losses.avg, mse_losses.avg, gradient_losses.avg,
        connectivity_losses.avg))
Пример #3
0
        x = x.type(torch.FloatTensor).to(device)
        alpha = alpha / 255.

        with torch.no_grad():
            pred = model(x)

        pred = pred.cpu().numpy()
        pred = pred.reshape((h, w))

        pred[trimap == 0] = 0.0
        pred[trimap == 255] = 1.0

        # Calculate loss
        # loss = criterion(alpha_out, alpha_label)
        mse_loss = compute_mse(pred, alpha, trimap)
        sad_loss = compute_sad(pred, alpha)
        str_msg = 'sad: %.4f, mse: %.4f' % (sad_loss, mse_loss)
        print(str_msg)

        out = (pred.copy() * 255).astype(np.uint8)
        draw_str(out, (10, 20), str_msg)
        cv.imwrite('images/{}_out.png'.format(i), out)

        new_bg = new_bgs[i]
        new_bg = cv.imread(os.path.join(bg_test, new_bg))
        bh, bw = new_bg.shape[:2]
        wratio = w / bw
        hratio = h / bh
        ratio = wratio if wratio > hratio else hratio
        print('ratio: ' + str(ratio))
        if ratio > 1:
Пример #4
0
        print(x_test.size())

        with torch.no_grad():
            y_pred = model(x_test)

        y_pred = y_pred.cpu().numpy()
        # print('y_pred.shape: ' + str(y_pred.shape))
        y_pred = np.reshape(y_pred, (im_size, im_size))
        # print(y_pred.shape)

        y_pred[trimap == 0] = 0.0
        y_pred[trimap == 255] = 1.0

        alpha = alpha / 255.  # [0., 1.]

        sad = compute_sad(y_pred, alpha)
        mse = compute_mse(y_pred, alpha, trimap)
        str_msg = 'sad: %.4f, mse: %.4f, size: %s' % (sad, mse, str(crop_size))
        print(str_msg)

        out = (y_pred * 255).astype(np.uint8)
        draw_str(out, (10, 20), str_msg)
        cv.imwrite('images/{}_out.png'.format(i), out)

        sample_bg = sample_bgs[i]
        bg = cv.imread(os.path.join(bg_test, sample_bg))
        bh, bw = bg.shape[:2]
        wratio = im_size / bw
        hratio = im_size / bh
        ratio = wratio if wratio > hratio else hratio
        if ratio > 1: