Exemplo n.º 1
0
def test(args, test_loader, model):
    """Evaluate a segmentation model on a labelled test loader.

    args:
      args: namespace; reads ``classes``, ``save``, ``dataset`` and
        ``save_seg_dir``.
      test_loader: loader for the test dataset yielding
        ``(input, gt, size, name)`` batches.
      model: segmentation model. If it returns a tuple, the first element is
        used as the prediction.
    return: ``(Miou, PerMiou_set, FWIoU, pa, mpa)`` as computed by
      ``SegmentationMetric`` over every sample of every batch.
    """
    # evaluation or test mode
    model.eval()
    total_batches = len(test_loader)

    metric = SegmentationMetric(numClass=args.classes)
    pbar = tqdm(iterable=enumerate(test_loader),
                total=total_batches,
                desc='Valing')
    for _, (images, gt, size, name) in pbar:
        # The forward pass itself must run under no_grad; otherwise an
        # autograd graph is recorded for every test batch.
        with torch.no_grad():
            output = model(images.cuda())
        if isinstance(output, tuple):
            output = output[0]
        torch.cuda.synchronize()

        # Class scores -> label maps by arg-max over axis 1.
        # Assumes a 4-D (B, C, H, W) output -- TODO confirm for all models.
        preds = np.asarray(np.argmax(output.cpu().numpy(), axis=1),
                           dtype=np.uint8)

        # Score (and optionally save) every sample, not only the first one.
        for index, pred in enumerate(preds):
            label = np.asarray(gt[index], dtype=np.uint8)
            # accumulate statistics for mIoU
            metric.addBatch(imgPredict=pred.flatten(), imgLabel=label.flatten())

            # save the predicted image
            if args.save:
                save_predict(pred,
                             label,
                             name[index],
                             args.dataset,
                             args.save_seg_dir,
                             output_grey=False,
                             output_color=True,
                             gt_color=True)

    pa = metric.pixelAccuracy()
    mpa = metric.meanPixelAccuracy()
    Miou, PerMiou_set = metric.meanIntersectionOverUnion()
    FWIoU = metric.Frequency_Weighted_Intersection_over_Union()

    return Miou, PerMiou_set, FWIoU, pa, mpa
Exemplo n.º 2
0
def predict(args, test_loader, model):
    """Predict, save and score every sample of ``test_loader``.

    args:
      args: namespace; reads ``classes``, ``dataset`` and ``save_seg_dir``.
      test_loader: loader for the test dataset. Although meant for test sets
        without published labels, it must still yield
        ``(input, gt, size, name)`` batches; ``gt`` feeds the metric.
      model: segmentation model. If it returns a tuple, the first element is
        used as the prediction.
    return: ``(Miou, PerMiou_set)``. Both are also printed and written to
      ``<args.save_seg_dir>/results.txt``.
    """
    import os

    # evaluation or test mode
    model.eval()
    total_batches = len(test_loader)

    metric = SegmentationMetric(numClass=args.classes)
    pbar = tqdm(iterable=enumerate(test_loader), total=total_batches, desc='Valing')
    for _, (images, gt, size, name) in pbar:
        # Run the forward pass under no_grad so no autograd graph is kept.
        with torch.no_grad():
            output = model(images.cuda())
        if isinstance(output, tuple):
            output = output[0]
        torch.cuda.synchronize()

        # Class scores -> label maps by arg-max over axis 1.
        # Assumes a 4-D (B, C, H, W) output -- TODO confirm for all models.
        preds = np.asarray(np.argmax(output.cpu().numpy(), axis=1),
                           dtype=np.uint8)

        for index, pred in enumerate(preds):
            label = np.asarray(gt[index], dtype=np.uint8)
            # accumulate statistics for mIoU
            metric.addBatch(imgPredict=pred.flatten(), imgLabel=label.flatten())

            # save the predicted image
            save_predict(pred, label, name[index], args.dataset, args.save_seg_dir,
                         output_grey=False, output_color=True, gt_color=True)

    Miou, PerMiou_set = metric.meanIntersectionOverUnion()
    print('miou {}\nclass iou {}'.format(Miou, PerMiou_set))
    result = os.path.join(args.save_seg_dir, 'results.txt')
    with open(result, 'w', encoding='utf-8') as f:
        f.write(str(Miou))
        f.write('\n{}'.format(PerMiou_set))
    return Miou, PerMiou_set
Exemplo n.º 3
0
def val(args, val_loader, criteria, model):
    """Compute the validation loss and segmentation metrics.

    args:
      args: namespace; reads ``classes``.
      val_loader: loader for the validation dataset yielding
        ``(input, label, size, name)`` batches.
      criteria: loss callable, applied as ``criteria(output, label)``.
      model: segmentation model; if it returns a tuple, only the first
        element is used for loss and metrics.
    return: ``(val_loss, FWIoU, Miou, PerMiou_set)`` where ``val_loss`` is
      the mean of the per-batch losses.
    """
    # evaluation mode
    model.eval()
    total_batches = len(val_loader)

    val_loss = []
    metric = SegmentationMetric(args.classes)
    pbar = tqdm(iterable=enumerate(val_loader),
                total=total_batches,
                desc='Val')
    for _, (images, labels, size, name) in pbar:
        with torch.no_grad():
            output = model(images.cuda().float())
            if isinstance(output, tuple):
                output = output[0]
            val_loss.append(criteria(output, labels.long().cuda()))

        # Class scores -> label maps by arg-max over axis 1
        # (assumes a 4-D (B, C, H, W) output -- TODO confirm).
        preds = np.asarray(np.argmax(output.cpu().numpy(), axis=1),
                           dtype=np.uint8)
        gts = np.asarray(labels.numpy(), dtype=np.uint8)
        # The loss already covers the whole batch; so must the metric.
        for pred, gt in zip(preds, gts):
            metric.addBatch(imgPredict=pred.flatten(), imgLabel=gt.flatten())

    val_loss = sum(val_loss) / len(val_loss)

    Miou, PerMiou_set = metric.meanIntersectionOverUnion()
    FWIoU = metric.Frequency_Weighted_Intersection_over_Union()

    return val_loss, FWIoU, Miou, PerMiou_set
Exemplo n.º 4
0
def _overlap_forward(args, model, image):
    """Run ``model`` on ``image``, through ``flip_merge`` when ``args.flip_merge`` is set."""
    if args.flip_merge:
        return flip_merge(model, image)
    return model(image)


def _overlap_main_and_loss(criterion, output, target):
    """Split a model output into its main prediction and its loss.

    A tuple output is treated as ``(main, aux1, aux2, ...)``: the main loss
    is weighted 0.6 and the auxiliary losses share the remaining 0.4
    equally. When ``target`` is ``None`` no loss is computed and ``None`` is
    returned in its place.
    """
    main = output[0] if isinstance(output, tuple) else output
    if target is None:
        return main, None
    if not isinstance(output, tuple):
        return main, criterion(output, target)
    loss = 0
    for index, out in enumerate(output):
        weight = 0.6 if index == 0 else 0.4 / (len(output) - 1)
        loss = loss + criterion(out, target) * weight
    return main, loss


def _overlap_windows(height, width, tile_h, tile_w, overlap):
    """Return ``(y1, y2, x1, x2)`` windows covering a ``height`` x ``width`` image.

    Windows are ``tile_h`` x ``tile_w`` (smaller only when the image is) and
    advance by ``ceil(tile * (1 - overlap))`` along each axis separately; the
    last window of each row/column is shifted back to end on the border.
    """
    stride_h = ceil(tile_h * (1 - overlap))
    stride_w = ceil(tile_w * (1 - overlap))
    # At least one window per axis, even for images smaller than a tile.
    rows = max(int(ceil((height - tile_h) / stride_h)) + 1, 1)
    cols = max(int(ceil((width - tile_w) / stride_w)) + 1, 1)
    windows = []
    for row in range(rows):
        y2 = min(row * stride_h + tile_h, height)
        y1 = max(y2 - tile_h, 0)
        for col in range(cols):
            x2 = min(col * stride_w + tile_w, width)
            windows.append((y1, y2, max(x2 - tile_w, 0), x2))
    return windows


def _overlap_predict_scale(args, model, criterion, scale_image, scale_gt,
                           tile_h_size, tile_w_size, whole_image,
                           overlap=1 / 3):
    """Predict class scores for one resized batch.

    ``scale_gt`` is the resized ``(B, 1, h, w)`` label batch, or ``None`` when
    no loss should be computed. With ``whole_image`` the batch is fed to the
    model in one pass; otherwise overlapping tiles are predicted and their
    scores averaged per pixel.

    Returns ``(scores, loss, n_losses)``: the class scores, the summed loss
    of every forward pass (``0`` without labels) and the number of forward
    passes that contributed to that loss.

    Raises:
      RuntimeError: if some pixel was not covered by any tile.
    """
    loss, n_losses = 0, 0
    with torch.no_grad():
        if whole_image:
            target = None if scale_gt is None else scale_gt.cuda().squeeze(1)
            scores, step_loss = _overlap_main_and_loss(
                criterion, _overlap_forward(args, model, scale_image.cuda()),
                target)
            if step_loss is not None:
                loss, n_losses = step_loss, 1
            return scores, loss, n_losses

        B, _, sh, sw = scale_image.shape
        outputs_prob = torch.zeros(B, args.classes, sh, sw).cuda()
        count_prob = torch.zeros(B, 1, sh, sw).cuda()
        for y1, y2, x1, x2 in _overlap_windows(sh, sw, tile_h_size,
                                               tile_w_size, overlap):
            tile_image = scale_image[:, :, y1:y2, x1:x2].cuda()
            target = None
            if scale_gt is not None:
                target = scale_gt[:, :, y1:y2, x1:x2].long().cuda().squeeze(1)
            tile_output, step_loss = _overlap_main_and_loss(
                criterion, _overlap_forward(args, model, tile_image), target)
            if step_loss is not None:
                loss = loss + step_loss
                n_losses += 1
            outputs_prob[:, :, y1:y2, x1:x2] += tile_output
            count_prob[:, :, y1:y2, x1:x2] += 1

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if (count_prob == 0).any():
        raise RuntimeError('sliding window left pixels without a prediction')
    return outputs_prob / count_prob, loss, n_losses


def predict_overlap_sliding(args,
                            model,
                            testLoader,
                            scales,
                            criterion,
                            mode='predict',
                            save_result=True):
    """Multi-scale sliding-window inference, optionally with loss and metrics.

    Each batch is resized by every factor in ``scales``. A resized batch is
    predicted in one pass when it was upscaled and the original image is
    smaller than the tile in some dimension; otherwise it is split into
    overlapping ``args.tile_hw_size`` tiles (1/3 overlap) whose class scores
    are averaged. Per-scale scores are resized back to the input size and
    summed; the arg-max over classes is the prediction.

    ``args.center_hw_size`` is not read: centre-crop stitching is not
    implemented, overlapping tile scores are simply averaged.

    args:
      args: namespace; reads ``tile_hw_size`` ("h,w"), ``classes``,
        ``flip_merge``, ``dataset`` and ``save_seg_dir``.
      model: segmentation model; a tuple output is ``(main, aux1, ...)``.
      testLoader: yields ``(image, gt, size, name)`` when
        ``mode == 'validation'``, otherwise ``(image, size, name)``.
      scales: scale factors (anything ``float()`` accepts).
      criterion: loss module; only used in validation mode.
      mode: ``'validation'`` computes loss and metrics via ``eval_metric``;
        anything else only predicts.
      save_result: whether to save every prediction with ``save_predict``.
    return: ``(loss, FWIoU, Miou, MIoU_avg, PerCiou_set, Pa, PerCpa_set, Mpa,
      MF, F_set, F_avg)``; zeros / empty dicts when not validating.
    """
    loss = 0
    count_loss = 0
    model.eval()
    criterion = criterion.cuda()
    tile_hw = args.tile_hw_size.split(',')
    tile_h_size, tile_w_size = int(tile_hw[0]), int(tile_hw[1])
    validating = mode == 'validation'
    metric = SegmentationMetric(args.classes)  # args.classes categories
    pbar = tqdm(iterable=enumerate(testLoader),
                total=len(testLoader),
                desc='Predicting')
    for _, batch in pbar:
        if validating:
            image, gt, size, name = batch
        else:
            image, size, name = batch
            gt = None
        B, C, H, W = image.shape
        # Per-scale class scores, resized back to (H, W) and summed.
        full_prob = torch.zeros(B, args.classes, H, W).cuda()
        for scale in scales:
            scale = float(scale)
            sh = int(H * scale)
            sw = int(W * scale)
            scale_image = F.interpolate(image, (sh, sw),
                                        mode='bilinear',
                                        align_corners=True)
            scale_gt = None
            if gt is not None:
                # nearest-neighbour keeps label values integral
                scale_gt = F.interpolate(gt.unsqueeze(1).float(), (sh, sw),
                                         mode='nearest').long()

            whole_image = (H < sh and W < sw
                           and (H < tile_h_size or W < tile_w_size))
            outputs, scale_loss, n_losses = _overlap_predict_scale(
                args, model, criterion, scale_image, scale_gt,
                tile_h_size, tile_w_size, whole_image)
            loss = loss + scale_loss
            count_loss += n_losses

            full_prob += F.interpolate(outputs, (H, W),
                                       mode='bilinear',
                                       align_corners=True)

        # arg-max over classes -> (B, H, W) label maps
        full_prob = np.asarray(torch.argmax(full_prob, 1).cpu(),
                               dtype=np.uint8)
        if validating:
            gt = gt.cpu().numpy()
        for index in range(full_prob.shape[0]):  # shape[0] is batch size
            if validating:
                metric.addBatch(full_prob[index], gt[index])
                if save_result:
                    save_predict(full_prob[index],
                                 gt[index],
                                 name[index],
                                 args.dataset,
                                 args.save_seg_dir,
                                 output_grey=False,
                                 output_color=True,
                                 gt_color=True)
            elif save_result:
                save_predict(full_prob[index],
                             None,
                             name[index],
                             args.dataset,
                             args.save_seg_dir,
                             output_grey=True,
                             output_color=False,
                             gt_color=False)

    if validating:
        loss, FWIoU, Miou, MIoU_avg, PerCiou_set, Pa, PerCpa_set, Mpa, MF, F_set, F_avg = eval_metric(
            args, metric, count_loss, loss)
    else:
        loss, FWIoU, Miou, MIoU_avg, PerCiou_set, Pa, PerCpa_set, Mpa, MF, F_set, F_avg = 0, 0, 0, 0, {}, 0, {}, 0, 0, {}, 0

    return loss, FWIoU, Miou, MIoU_avg, PerCiou_set, Pa, PerCpa_set, Mpa, MF, F_set, F_avg
Exemplo n.º 5
0
def predict_sliding(args, net, image, tile_size, classes):
    """Sliding-window prediction over a labelled loader, reporting mIoU.

    Each image is split into overlapping tiles (1/3 overlap) of
    ``tile_size``. Partial tiles are padded with ``pad_image``, and the class
    scores are averaged where tiles overlap.

    args:
      args: namespace; reads ``classes``, ``dataset`` and ``save_seg_dir``.
      net: segmentation network; put into eval mode here.
      image: despite the name, an iterable loader yielding
        ``(input, gt, size, name)``. Only the first sample of each batch is
        evaluated.
      tile_size: ``(tile_h, tile_w)``.
      classes: number of score channels produced by ``net``.
    return: ``(Miou, PerMiou_set)``. Both are also printed and written to
      ``<args.save_seg_dir>/results.txt``.
    """
    import os

    net.eval()  # consistent with the other evaluation entry points
    total_batches = len(image)
    metric = SegmentationMetric(args.classes)  # args.classes categories
    pbar = tqdm(iterable=enumerate(image),
                total=total_batches,
                desc='Predicting')
    overlap = 1 / 3  # fraction of a tile shared with its neighbour
    tile_h, tile_w = tile_size[0], tile_size[1]
    # One stride per axis so non-square tiles leave no uncovered columns.
    stride_h = ceil(tile_h * (1 - overlap))
    stride_w = ceil(tile_w * (1 - overlap))
    for _, (images, gt, size, name) in pbar:
        height, width = images.shape[2], images.shape[3]
        # At least one tile per axis, even when the image is smaller.
        tile_rows = max(int(ceil((height - tile_h) / stride_h)) + 1, 1)
        tile_cols = max(int(ceil((width - tile_w) / stride_w)) + 1, 1)
        full_probs = np.zeros((height, width, classes))  # summed scores
        count_predictions = np.zeros((height, width, 1))  # tiles per pixel

        for row in range(tile_rows):
            for col in range(tile_cols):
                # shift the last tile back so it ends on the image border
                y2 = min(row * stride_h + tile_h, height)
                x2 = min(col * stride_w + tile_w, width)
                y1 = max(int(y2 - tile_h), 0)
                x1 = max(int(x2 - tile_w), 0)

                img = images[:, :, y1:y2, x1:x2]
                # pad partial tiles up to tile_size
                padded_img = pad_image(img, tile_size)

                with torch.no_grad():
                    input_var = torch.from_numpy(padded_img).cuda().float()
                    padded_prediction = net(input_var)

                    if isinstance(padded_prediction, tuple):
                        padded_prediction = padded_prediction[0]

                    torch.cuda.synchronize()

                if isinstance(padded_prediction, list):
                    padded_prediction = padded_prediction[0]

                # first sample, channels last: (tile_h, tile_w, classes)
                padded_prediction = padded_prediction.cpu().data[0].numpy(
                ).transpose(1, 2, 0)
                # drop the padding again
                prediction = padded_prediction[0:img.shape[2],
                                               0:img.shape[3], :]
                count_predictions[y1:y2, x1:x2] += 1
                full_probs[y1:y2, x1:x2] += prediction

        # average the predictions in the overlapping regions
        full_probs /= count_predictions
        full_probs = np.asarray(np.argmax(full_probs, axis=2), dtype=np.uint8)
        gt = gt[0].numpy()
        # accumulate statistics for mIoU
        metric.addBatch(full_probs, gt)
        save_predict(full_probs,
                     gt,
                     name[0],
                     args.dataset,
                     args.save_seg_dir,
                     output_grey=False,
                     output_color=True,
                     gt_color=True)

    Miou, PerMiou_set = metric.meanIntersectionOverUnion()

    print('miou {}\nclass iou {}'.format(Miou, PerMiou_set))
    result = os.path.join(args.save_seg_dir, 'results.txt')
    with open(result, 'w', encoding='utf-8') as f:
        f.write(str(Miou))
        f.write('\n{}'.format(PerMiou_set))
    return Miou, PerMiou_set

def predict_multiscale_sliding(args,
                               model,
                               class_dict_df,
                               testLoader,
                               scales,
                               overlap,
                               criterion,
                               mode='predict',
                               save_result=True):
    """Multi-scale, sliding-window segmentation inference.

    Each batch is resized by every factor in ``scales``. A resized batch is
    predicted in one pass when it was downscaled and the original image is
    smaller than the tile in some dimension; otherwise it is split into
    overlapping ``args.tile_hw_size`` tiles whose class scores are averaged.
    Per-scale scores are resized back to the input size and summed; the
    arg-max over classes is the prediction.

    args:
      args: namespace; reads ``tile_hw_size`` ("h,w"), ``classes``,
        ``flip_merge``, ``dataset`` and ``save_seg_dir``.
      model: segmentation model; a tuple output is ``(main, aux1, ...)``.
      class_dict_df: forwarded to ``eval_metric`` in validation mode.
      testLoader: yields ``(image, gt, size, name)`` in validation mode,
        ``(image, size, name)`` otherwise.
      scales: scale factors (anything ``float()`` accepts).
      overlap: fraction of a tile shared with its neighbour, in ``[0, 1)``.
      criterion: loss module; only used in validation mode.
      mode: ``'validation'`` computes loss and metrics; anything else only
        predicts.
      save_result: whether to save every prediction with ``save_predict``.
    return: ``(loss, FWIoU, Miou, Miou_Noback, PerCiou_set, Pa, PerCpa_set,
      Mpa, MF, F_set, F1_Noback)`` from ``eval_metric``; zeros / empty dicts
      when not validating.
    raises:
      ValueError: if ``overlap`` is outside ``[0, 1)``.
      RuntimeError: if the tiling leaves a pixel without a prediction.
    """
    if not 0 <= overlap < 1:
        raise ValueError('overlap must be in [0, 1), got {}'.format(overlap))
    loss = 0
    count_loss = 0
    model.eval()
    criterion = criterion.cuda()
    tile_h_size, tile_w_size = int(args.tile_hw_size.split(',')[0]), int(
        args.tile_hw_size.split(',')[1])
    # One stride per axis so non-square tiles leave no uncovered gaps.
    stride_h = ceil(tile_h_size * (1 - overlap))
    stride_w = ceil(tile_w_size * (1 - overlap))

    def forward(batch):
        """Run the model, through ``flip_merge`` when ``args.flip_merge`` is set."""
        if args.flip_merge:
            return flip_merge(model, batch)
        return model(batch)

    def windows(height, width):
        """Yield ``(y1, y2, x1, x2)`` tiles covering a ``height`` x ``width`` image."""
        # at least one tile per axis, even for images smaller than a tile
        rows = max(int(ceil((height - tile_h_size) / stride_h)) + 1, 1)
        cols = max(int(ceil((width - tile_w_size) / stride_w)) + 1, 1)
        for row in range(rows):
            for col in range(cols):
                # shift the last tile back so it ends on the image border
                y2 = min(row * stride_h + tile_h_size, height)
                x2 = min(col * stride_w + tile_w_size, width)
                yield max(y2 - tile_h_size, 0), y2, max(x2 - tile_w_size, 0), x2

    def main_and_loss(output, target):
        """Return the main prediction of ``output`` and its loss against ``target``.

        A tuple output is ``(main, aux1, aux2, ...)``: the main loss is
        weighted 0.6 and the auxiliary losses share the remaining 0.4.
        """
        if isinstance(output, tuple):
            loss_sum = 0
            for index, out in enumerate(output):
                weight = 0.6 if index == 0 else 0.4 / (len(output) - 1)
                loss_sum = loss_sum + criterion(out, target) * weight
            return output[0], loss_sum
        return output, criterion(output, target)

    def average(scores_sum, counts):
        """Divide accumulated tile scores by the number of tiles per pixel."""
        # explicit check: `assert` is stripped under `python -O`
        if (counts == 0).any():
            raise RuntimeError('sliding window left pixels without a prediction')
        return scores_sum / counts

    metric = SegmentationMetric(args.classes)
    pbar = tqdm(iterable=enumerate(testLoader),
                total=len(testLoader),
                desc='Predicting')
    if mode == 'validation':
        for _, (image, gt, size, name) in pbar:
            B, C, H, W = image.shape
            # image and gt are scaled together, e.g. [0.75, 1.0, 1.25, 1.5, 2.0]
            full_prob = torch.zeros(B, args.classes, H, W).cuda()
            for scale in scales:
                scale = float(scale)
                sh = int(H * scale)
                sw = int(W * scale)

                scale_image = F.interpolate(image, (sh, sw),
                                            mode='bilinear',
                                            align_corners=True).float()
                scale_gt = F.interpolate(gt.unsqueeze(1).float(), (sh, sw),
                                         mode='nearest').long()

                if (H > sh or W > sw) and (H < tile_h_size or W < tile_w_size):
                    # Predict the entire downscaled image in one pass.
                    with torch.no_grad():
                        outputs, step_loss = main_and_loss(
                            forward(scale_image.cuda()),
                            scale_gt.cuda().squeeze(1))
                    loss += step_loss
                    count_loss += 1
                else:
                    outputs_prob = torch.zeros(B, args.classes, sh, sw).cuda()
                    count_prob = torch.zeros(B, 1, sh, sw).cuda()
                    for y1, y2, x1, x2 in windows(sh, sw):
                        with torch.no_grad():
                            tile_image = scale_image[:, :, y1:y2,
                                                     x1:x2].float().cuda()
                            tile_gt = scale_gt[:, :, y1:y2,
                                               x1:x2].long().cuda()
                            tile_output, step_loss = main_and_loss(
                                forward(tile_image), tile_gt.squeeze(1))
                        loss += step_loss
                        count_loss += 1
                        outputs_prob[:, :, y1:y2, x1:x2] += tile_output
                        count_prob[:, :, y1:y2, x1:x2] += 1
                    outputs = average(outputs_prob, count_prob)

                full_prob += F.interpolate(outputs, (H, W),
                                           mode='bilinear',
                                           align_corners=True)

            gt = np.asarray(gt.cpu(), dtype=np.uint8)
            # arg-max over classes -> (B, H, W) label maps
            full_prob = np.asarray(torch.argmax(full_prob, 1).cpu(),
                                   dtype=np.uint8)
            # Ground truth and prediction are saved in colour.
            for index in range(full_prob.shape[0]):  # shape[0] is batch size
                metric.addBatch(full_prob[index], gt[index])
                if save_result:
                    save_predict(full_prob[index],
                                 gt[index],
                                 name[index],
                                 args.dataset,
                                 args.save_seg_dir,
                                 output_grey=False,
                                 output_color=True,
                                 gt_color=True)

        loss, FWIoU, Miou, Miou_Noback, PerCiou_set, Pa, PerCpa_set, Mpa, MF, F_set, F1_Noback = \
            eval_metric(args, class_dict_df, metric, count_loss, loss)
    else:
        for _, (image, size, name) in pbar:
            B, C, H, W = image.shape
            # image scaled, e.g. [0.75, 1.0, 1.25, 1.5, 2.0]
            full_prob = torch.zeros(B, args.classes, H, W).cuda()
            for scale in scales:
                scale = float(scale)
                sh = int(H * scale)
                sw = int(W * scale)

                scale_image = F.interpolate(image, (sh, sw),
                                            mode='bilinear',
                                            align_corners=True).float()

                if (H > sh or W > sw) and (H < tile_h_size or W < tile_w_size):
                    # Predict the entire downscaled image in one pass.
                    with torch.no_grad():
                        outputs = forward(scale_image.cuda())
                        if isinstance(outputs, tuple):
                            outputs = outputs[0]
                else:
                    outputs_prob = torch.zeros(B, args.classes, sh, sw).cuda()
                    count_prob = torch.zeros(B, 1, sh, sw).cuda()
                    for y1, y2, x1, x2 in windows(sh, sw):
                        with torch.no_grad():
                            tile_image = scale_image[:, :, y1:y2,
                                                     x1:x2].float().cuda()
                            tile_output = forward(tile_image)
                        if isinstance(tile_output, tuple):
                            tile_output = tile_output[0]
                        outputs_prob[:, :, y1:y2, x1:x2] += tile_output
                        count_prob[:, :, y1:y2, x1:x2] += 1
                    outputs = average(outputs_prob, count_prob)

                full_prob += F.interpolate(outputs, (H, W),
                                           mode='bilinear',
                                           align_corners=True)

            # arg-max over classes -> (B, H, W) label maps
            full_prob = np.asarray(torch.argmax(full_prob, 1).cpu(),
                                   dtype=np.uint8)
            # Predictions are saved as grey-scale label images.
            for index in range(full_prob.shape[0]):  # shape[0] is batch size
                if save_result:
                    save_predict(full_prob[index],
                                 None,
                                 name[index],
                                 args.dataset,
                                 args.save_seg_dir,
                                 output_grey=True,
                                 output_color=False,
                                 gt_color=False)

        loss, FWIoU, Miou, Miou_Noback, PerCiou_set, Pa, PerCpa_set, Mpa, MF, F_set, F1_Noback = 0, 0, 0, 0, {}, 0, {}, 0, 0, {}, 0

    return loss, FWIoU, Miou, Miou_Noback, PerCiou_set, Pa, PerCpa_set, Mpa, MF, F_set, F1_Noback
Exemplo n.º 7
0
def test(args, test_loader, model):
    """Evaluate ``model`` on a labelled test set, averaging metrics per image.

    For every batch the model is run once, the prediction for the first
    sample is arg-maxed over the class axis, and a fresh
    ``SegmentationMetric`` is computed for that single image. The returned
    scores are the means of those per-image scores (not metrics over one
    confusion matrix accumulated across the dataset).

    Args:
        args: namespace providing ``classes`` (number of classes), ``save``
            (whether to write prediction images), and ``dataset`` /
            ``save_seg_dir`` (forwarded to ``save_predict``).
        test_loader: iterable yielding ``(input, gt, size, name)`` tuples.
        model: segmentation network; it is switched to eval mode.

    Returns:
        tuple ``(miou, cls_iu, fmiou, pa, mpa)``: mean IoU, a dict mapping
        class index -> IoU averaged over images, frequency-weighted IoU,
        pixel accuracy and mean pixel accuracy, each averaged over images.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    # evaluation or test mode
    model.eval()
    total_batches = len(test_loader)

    miou_list = []
    iou_list = []
    pa_list = []
    mpa_list = []
    fmiou_list = []
    pbar = tqdm(iterable=enumerate(test_loader),
                total=total_batches,
                desc='Valing')
    for _, (images, gt, size, name) in pbar:
        # The forward pass itself must be inside no_grad; previously only the
        # host->device copy was, so autograd built a graph for every batch.
        with torch.no_grad():
            input_var = images.cuda()
            start_time = time.time()
            output = model(input_var)
            # CUDA launches are asynchronous; wait so the timing covers the
            # whole forward pass.
            torch.cuda.synchronize()
            time_taken = time.time() - start_time
        pbar.set_postfix(cost_time='%.3f' % time_taken)

        # Only the first sample of each batch is scored (batch_size 1 is
        # presumably expected — TODO confirm). Assumes the model output is
        # (B, C, H, W), so argmax over axis 0 of output[0] gives the label map.
        output = output[0].cpu().numpy()
        output = np.asarray(np.argmax(output, axis=0), dtype=np.uint8)
        gt = np.asarray(gt[0], dtype=np.uint8)

        # Compute mIoU and related scores for this single image.
        metric = SegmentationMetric(numClass=args.classes)
        metric.addBatch(imgPredict=output, imgLabel=gt)
        miou, iou = metric.meanIntersectionOverUnion()
        miou_list.append(miou)
        fmiou_list.append(metric.Frequency_Weighted_Intersection_over_Union())
        pa_list.append(metric.pixelAccuracy())
        mpa_list.append(metric.meanPixelAccuracy())
        iou_list.append(np.array(iou))

        # save the predicted image
        if args.save:
            save_predict(output,
                         gt,
                         name[0],
                         args.dataset,
                         args.save_seg_dir,
                         output_grey=False,
                         output_color=True,
                         gt_color=True)

    if not iou_list:
        # Without this the averaging below fails with an obscure TypeError.
        raise ValueError("test_loader yielded no batches")

    miou = np.mean(miou_list)
    fmiou = np.mean(fmiou_list)
    pa = np.mean(pa_list)
    mpa = np.mean(mpa_list)
    # NOTE(review): if per-image IoU is NaN for classes absent from an image,
    # this plain mean propagates NaN; np.nanmean may be intended — confirm.
    iou = np.mean(np.asarray(iou_list), axis=0)
    cls_iu = dict(zip(range(args.classes), iou))
    return miou, cls_iu, fmiou, pa, mpa