def predict(args, test_loader, model):
    """Run inference on an unlabeled test set and save greyscale predictions.

    args:
      args: namespace providing ``dataset`` and ``save_seg_dir``
      test_loader: loader for a test dataset that provides no labels
      model: segmentation network (already on GPU)
    """
    # evaluation mode: fixes batch-norm statistics / disables dropout
    model.eval()
    total_batches = len(test_loader)
    for i, (input, size, name) in enumerate(test_loader):
        start_time = time.time()
        # no_grad must cover the forward pass, not just the .cuda() transfer;
        # the original built the autograd graph on every inference forward,
        # retaining activations for a backward pass that never happens.
        with torch.no_grad():
            input_var = input.cuda()
            output = model(input_var)
        torch.cuda.synchronize()  # wait for the GPU so the timing is real
        time_taken = time.time() - start_time
        print('[%d/%d]  time: %.2f' % (i + 1, total_batches, time_taken))
        # (C,H,W) logits -> (H,W) uint8 class-index map
        output = output.cpu().data[0].numpy()
        output = output.transpose(1, 2, 0)
        output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)

        # Save the predict greyscale output for Cityscapes official evaluation
        # Modify image name to meet official requirement
        name[0] = name[0].rsplit('_', 1)[0] + '*'
        save_predict(output,
                     None,
                     name[0],
                     args.dataset,
                     args.save_seg_dir,
                     output_grey=True,
                     output_color=False,
                     gt_color=False)
# Example #2
def test(args, test_loader, model):
    """Evaluate the model on a labeled test set.

    args:
      args: namespace providing ``classes``, ``dataset``, ``save_seg_dir``,
        ``save``
      test_loader: loader yielding (input, label, size, name)
      model: segmentation network (already on GPU)
    return: mean IoU and per-class IoU
    """
    # evaluation mode: fixes batch-norm statistics / disables dropout
    model.eval()
    total_batches = len(test_loader)

    data_list = []
    for i, (input, label, size, name) in enumerate(test_loader):
        start_time = time.time()
        # no_grad must cover the forward pass (the original only wrapped the
        # transfer); torch.autograd.Variable is deprecated — tensors suffice.
        with torch.no_grad():
            input_var = input.cuda()
            output = model(input_var)
        torch.cuda.synchronize()  # wait for the GPU so the timing is real
        time_taken = time.time() - start_time
        print('[%d/%d]  time: %.2f' % (i + 1, total_batches, time_taken))
        # (C,H,W) logits -> (H,W) uint8 class-index map; gt kept as uint8
        output = output.cpu().data[0].numpy()
        gt = np.asarray(label[0].numpy(), dtype=np.uint8)
        output = output.transpose(1, 2, 0)
        output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)
        data_list.append([gt.flatten(), output.flatten()])

        # save the predicted image
        if args.save:
            save_predict(output, gt, name[0], args.dataset, args.save_seg_dir,
                         output_grey=False, output_color=True, gt_color=True)

    meanIoU, per_class_iu = get_iou(data_list, args.classes)
    return meanIoU, per_class_iu
# Example #3
def test(args, test_loader, model):
    """Evaluate the model on a labeled test set with SegmentationMetric.

    args:
      args: namespace providing ``classes``, ``dataset``, ``save_seg_dir``,
        ``save``
      test_loader: loader yielding (input, gt, size, name)
      model: segmentation network (already on GPU)
    return: mean IoU, per-class IoU set, frequency-weighted IoU, pixel
      accuracy, mean pixel accuracy
    """
    # evaluation mode: fixes batch-norm statistics / disables dropout
    model.eval()
    total_batches = len(test_loader)

    metric = SegmentationMetric(numClass=args.classes)
    pbar = tqdm(iterable=enumerate(test_loader),
                total=total_batches,
                desc='Valing')
    for i, (input, gt, size, name) in pbar:
        # no_grad must cover the forward pass (the original only wrapped the
        # transfer); torch.autograd.Variable is deprecated — tensors suffice.
        with torch.no_grad():
            input_var = input.cuda()
            output = model(input_var)
        torch.cuda.synchronize()

        # (C,H,W) logits -> (H,W) uint8 class-index map; gt kept as uint8
        output = output.cpu().data[0].numpy()
        output = output.transpose(1, 2, 0)
        output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)
        gt = np.asarray(gt[0], dtype=np.uint8)

        # accumulate the confusion matrix for mIoU
        metric.addBatch(imgPredict=output.flatten(), imgLabel=gt.flatten())

        # save the predicted image
        if args.save:
            save_predict(output,
                         gt,
                         name[0],
                         args.dataset,
                         args.save_seg_dir,
                         output_grey=False,
                         output_color=True,
                         gt_color=True)

    pa = metric.pixelAccuracy()
    mpa = metric.meanPixelAccuracy()
    Miou, PerMiou_set = metric.meanIntersectionOverUnion()
    FWIoU = metric.Frequency_Weighted_Intersection_over_Union()

    return Miou, PerMiou_set, FWIoU, pa, mpa
# Example #4
def predict(args, test_loader, model):
    """Run inference on a labeled test set, save color maps, report IoU.

    args:
      args: namespace providing ``classes``, ``dataset``, ``save_seg_dir``
      test_loader: loader yielding (input, label, size, name)
      model: segmentation network (already on GPU)

    Writes mIoU and per-class IoU to ``<save_seg_dir>/results.txt``.
    """
    # evaluation mode: fixes batch-norm statistics / disables dropout
    model.eval()
    total_batches = len(test_loader)
    data_list = []
    pbar = tqdm(iterable=enumerate(test_loader),
                total=total_batches,
                desc='Predicting')
    for i, (input, label, size, name) in pbar:
        # no_grad must cover the forward pass, not just the .cuda() transfer;
        # the original built the autograd graph on every inference forward.
        with torch.no_grad():
            input_var = input.cuda().float()
            output = model(input_var)
        torch.cuda.synchronize()

        # (C,H,W) logits -> (H,W) uint8 class-index map; gt kept as uint8
        output = output.cpu().data[0].numpy()
        output = output.transpose(1, 2, 0)
        output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)
        gt = np.asarray(label[0].numpy(), dtype=np.uint8)

        # Save the predict greyscale output for Cityscapes official evaluation
        # Modify image name to meet official requirement
        save_predict(output,
                     None,
                     name[0],
                     args.dataset,
                     args.save_seg_dir,
                     output_grey=False,
                     output_color=True,
                     gt_color=False)
        data_list.append([gt.flatten(), output.flatten()])
    meanIoU, per_class_iu = get_iou(data_list, args.classes)
    print('miou {}\nclass iou {}'.format(meanIoU, per_class_iu))
    result = args.save_seg_dir + '/results.txt'
    with open(result, 'w') as f:
        f.write(str(meanIoU))
        f.write('\n{}'.format(str(per_class_iu)))
# Example #5
def predict(args, test_loader, model):
    """Run inference on a labeled set via SegmentationMetric, save results.

    args:
      args: namespace providing ``classes``, ``dataset``, ``save_seg_dir``
      test_loader: loader yielding (input, gt, size, name)
      model: segmentation network (already on GPU)

    Writes mIoU and per-class IoU to ``<save_seg_dir>/results.txt``.
    """
    # evaluation mode: fixes batch-norm statistics / disables dropout
    model.eval()
    total_batches = len(test_loader)

    metric = SegmentationMetric(numClass=args.classes)
    pbar = tqdm(iterable=enumerate(test_loader), total=total_batches, desc='Valing')
    for i, (input, gt, size, name) in pbar:
        # no_grad must cover the forward pass, not just the .cuda() transfer;
        # the original built the autograd graph on every inference forward.
        with torch.no_grad():
            input_var = input.cuda()
            output = model(input_var)
        torch.cuda.synchronize()

        # (C,H,W) logits -> (H,W) uint8 class-index map; gt kept as uint8
        output = output.cpu().data[0].numpy()
        output = output.transpose(1, 2, 0)
        output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)
        gt = np.asarray(gt[0], dtype=np.uint8)

        # accumulate the confusion matrix for mIoU
        metric.addBatch(imgPredict=output.flatten(), imgLabel=gt.flatten())

        # save the predicted image
        save_predict(output, gt, name[0], args.dataset, args.save_seg_dir,
                     output_grey=False, output_color=True, gt_color=True)

    pa = metric.pixelAccuracy()
    mpa = metric.meanPixelAccuracy()
    Miou, PerMiou_set = metric.meanIntersectionOverUnion()
    FWIoU = metric.Frequency_Weighted_Intersection_over_Union()
    print('miou {}\nclass iou {}'.format(Miou, PerMiou_set))
    result = args.save_seg_dir + '/results.txt'
    with open(result, 'w') as f:
        f.write(str(Miou))
        f.write('\n{}'.format(PerMiou_set))
# Example #6
def predict(args, test_loader, model):
    """Run inference on an unlabeled test set, saving per-dataset outputs.

    args:
      args: namespace providing ``dataset`` and ``save_seg_dir``
      test_loader: loader for a test dataset that provides no labels;
        camvid batches unpack as (input, size, _, name), others as
        (input, size, name)
      model: segmentation network (already on GPU)
    """
    # evaluation mode: fixes batch-norm statistics / disables dropout
    model.eval()
    total_batches = len(test_loader)

    def _segment(batch_input, batch_index):
        # Forward one batch under no_grad (the original wrapped only the
        # .cuda() transfer, building the autograd graph for nothing), time
        # it, and return the (H,W) uint8 class-index map.
        with torch.no_grad():
            input_var = batch_input.cuda()
            start_time = time.time()
            output = model(input_var)
            torch.cuda.synchronize()  # wait for GPU so the timing is real
        time_taken = time.time() - start_time
        print('[%d/%d]  time: %.2f' %
              (batch_index + 1, total_batches, time_taken))
        output = output.cpu().data[0].numpy().transpose(1, 2, 0)
        return np.asarray(np.argmax(output, axis=2), dtype=np.uint8)

    if args.dataset == "camvid":
        for i, (input, size, _, name) in enumerate(test_loader):
            output = _segment(input, i)
            # Modify image name to meet official requirement
            name[0] = name[0].rsplit('_', 1)[0] + '*'
            save_predict(output,
                         None,
                         name[0],
                         args.dataset,
                         args.save_seg_dir,
                         output_grey=False,
                         output_color=True,
                         gt_color=False)

    elif args.dataset == "cityscapes":
        for i, (input, size, name) in enumerate(test_loader):
            output = _segment(input, i)
            # Save for Cityscapes official evaluation;
            # modify image name to meet official requirement
            name[0] = name[0].rsplit('_', 1)[0] + '*'
            save_predict(output,
                         None,
                         name[0],
                         args.dataset,
                         args.save_seg_dir,
                         output_grey=False,
                         output_color=True,
                         gt_color=False)

    elif args.dataset == "remote":
        for i, (input, size, name) in enumerate(test_loader):
            output = _segment(input, i)
            # greyscale output required by the remote-sensing submission
            save_predict(output,
                         None,
                         name[0],
                         args.dataset,
                         args.save_seg_dir,
                         output_grey=True,
                         output_color=False,
                         gt_color=False)
        zipDir(args.save_seg_dir,
               "C:/Users/DELL/Desktop/ccf_baidu_remote_sense.zip")

    else:
        for i, (input, size, name) in enumerate(test_loader):
            output = _segment(input, i)
            # binary mask: map class 1 to 255 for a visible greyscale image
            output[output == 1] = 255
            # strip the Windows-style directory prefix from the sample name
            name = name[0].split('\\')
            name = name[1].split('/')
            save_predict(output,
                         None,
                         name[1],
                         args.dataset,
                         args.save_seg_dir,
                         output_grey=True,
                         output_color=False,
                         gt_color=False)
# Example #7
def predict_overlap_sliding(args,
                            model,
                            testLoader,
                            scales,
                            criterion,
                            mode='predict',
                            save_result=True):
    """Multi-scale inference with an overlapping sliding window.

    args:
      args: namespace providing ``tile_hw_size``, ``center_hw_size``,
        ``classes``, ``flip_merge``, ``dataset``, ``save_seg_dir``
      model: segmentation network (already on GPU)
      testLoader: yields (image, gt, size, name) in 'validation' mode,
        (image, size, name) otherwise
      scales: iterable of scale factors applied to each image
      criterion: loss function (moved to GPU here)
      mode: 'validation' accumulates loss and metrics; any other value
        only saves predictions
      save_result: when True, write images via save_predict
    return: loss, FWIoU, Miou, MIoU_avg, PerCiou_set, Pa, PerCpa_set,
      Mpa, MF, F_set, F_avg (all zeros/empties outside validation mode)
    """
    loss = 0
    count_loss = 0
    model.eval()
    criterion = criterion.cuda()
    tile_h_size, tile_w_size = int(args.tile_hw_size.split(',')[0]), int(
        args.tile_hw_size.split(',')[1])
    center_h_size, center_w_size = int(args.center_hw_size.split(',')[0]), int(
        args.center_hw_size.split(',')[1])
    metric = SegmentationMetric(args.classes)  # args.classes is the number of classes
    pbar = tqdm(iterable=enumerate(testLoader),
                total=len(testLoader),
                desc='Predicting')
    if mode == 'validation':
        for i, (image, gt, size, name) in pbar:
            B, C, H, W = image.shape
            # image and gt scaled together [0.75, 1.0, 1.25, 1.5, 2.0]
            full_prob = torch.zeros(B, args.classes, H, W).cuda()
            for scale in scales:
                scale = float(scale)
                scale = float(scale)  # NOTE(review): duplicated line, harmless
                sh = int(H * scale)
                sw = int(W * scale)

                scale_image = F.interpolate(image, (sh, sw),
                                            mode='bilinear',
                                            align_corners=True)
                scale_gt = F.interpolate(gt.unsqueeze(1).float(), (sh, sw),
                                         mode='nearest').long()

                # is the scaled size smaller than the tile size?
                if H < sh and W < sw and (H < tile_h_size or W < tile_w_size):
                    # predict the whole image directly, restore to normal size
                    # NOTE(review): this path never assigns `outputs`, so the
                    # F.interpolate below would reuse a stale value from a
                    # previous scale (or raise NameError on the first one) —
                    # looks like `outputs = scale_output` is missing; confirm.
                    scale_image = scale_image.cuda()
                    scale_gt = scale_gt.cuda()
                    if args.flip_merge:
                        scale_output = flip_merge(model, scale_image)
                    else:
                        scale_output = model(scale_image)
                else:
                    # Based on the kept center size, pad the image so it is a
                    # multiple of the center size: multiples*center - 512 = 2*padding
                    scale_h, scale_w = scale_image.shape[2], scale_image.shape[
                        3]
                    if scale_h % center_h_size == 0 and scale_w % center_w_size == 0:
                        tile_rows = scale_h / center_h_size
                        tile_cols = scale_w / center_w_size
                    else:
                        h_times = scale_h // center_h_size + 1
                        w_times = scale_w // center_w_size + 1
                        scale_image = pad_image(
                            scale_image,
                            (h_times * center_h_size, w_times * center_w_size))
                        pad_scale_h, pad_scale_w = scale_image.shape[
                            2], scale_image.shape[3]
                        tile_rows = pad_scale_h / center_h_size
                        tile_cols = pad_scale_w / center_w_size
                    # (tile size - kept center size) // 2 == outer padding
                    outer_h_padding = int((tile_h_size - center_h_size) / 2)
                    outer_w_padding = int((tile_w_size - center_w_size) / 2)
                    scale_image = pad_image(scale_image,
                                            (outer_h_padding, outer_w_padding))
                    # NOTE(review): tile_rows/tile_cols computed above are
                    # floats and are overwritten just below — the center-size
                    # tiling above appears to be dead code; confirm intent.

                    scale_image_size = scale_image.shape  # (b,c,h,w)
                    overlap = 1 / 3  # consecutive windows overlap by 1/3
                    stride = ceil(
                        tile_h_size * (1 - overlap)
                    )  # sliding stride: 512*(1-1/3) = 342
                    tile_rows = int(
                        ceil((scale_image_size[2] - tile_h_size) / stride) +
                        1)  # row steps: (3072-512)/342+1=9
                    tile_cols = int(
                        ceil((scale_image_size[3] - tile_w_size) / stride) +
                        1)  # col steps: (3328-512)/342+1=10
                    outputs_prob = torch.zeros(B, args.classes, sh, sw).cuda()
                    count_prob = torch.zeros(B, 1, sh, sw).cuda()

                    for row in range(
                            tile_rows):  # row = 0,1     0,1,2,3,4,5,6,7,8
                        for col in range(
                                tile_cols
                        ):  # col = 0,1,2,3     0,1,2,3,4,5,6,7,8,9
                            x1 = int(col *
                                     stride)  # start position x1 = 0 * 342
                            y1 = int(row * stride)  # y1 = 0 * 342
                            x2 = min(x1 + tile_w_size, scale_image_size[3]
                                     )  # end position x2 = min(0+512, 3328)
                            y2 = min(
                                y1 + tile_h_size,
                                scale_image_size[2])  # y2 = min(0+512, 3072)
                            x1 = max(int(x2 - tile_w_size),
                                     0)  # re-anchor start x1 = max(512-512, 0)
                            y1 = max(int(y2 - tile_h_size),
                                     0)  # y1 = max(512-512, 0)

                            with torch.no_grad():
                                tile_image = scale_image[:, :, y1:y2,
                                                         x1:x2].cuda()
                                tile_gt = scale_gt[:, :, y1:y2,
                                                   x1:x2].long().cuda()
                                if args.flip_merge:
                                    tile_output = flip_merge(model, tile_image)
                                else:
                                    tile_output = model(tile_image)

                                    # output = (main_loss, aux_loss1, axu_loss2***)
                            # NOTE(review): this tuple branch iterates
                            # `scale_output`, which is only set by the
                            # whole-image path — it looks like `tile_output`
                            # was intended throughout; confirm before use.
                            if type(tile_output) is tuple:
                                length = len(scale_output)
                                for index, scale_out in enumerate(
                                        scale_output):
                                    criterion = criterion.cuda()
                                    loss_record = criterion(
                                        scale_out, scale_gt.squeeze(1))
                                    if index == 0:
                                        loss_record *= 0.6
                                    else:
                                        loss_record *= 0.4 / (length - 1)
                                    loss += loss_record
                                scale_output = scale_output[0]
                                count_loss += 1
                            elif type(tile_output) is not tuple:
                                loss += criterion(tile_output,
                                                  tile_gt.squeeze(1))
                                count_loss += 1

                            outputs_prob[:, :, y1:y2, x1:x2] += tile_output
                            count_prob[:, :, y1:y2, x1:x2] += 1

                    # after sliding over this scale, every pixel must have been
                    # covered at least once; normalise accumulated probabilities
                    assert ((count_prob == 0).sum() == 0)
                    outputs = outputs_prob / count_prob

                outputs = F.interpolate(outputs, (H, W),
                                        mode='bilinear',
                                        align_corners=True)
                full_prob += outputs

            # visualize normalization Weights
            # plt.imshow(np.mean(count_predictions, axis=2))
            # plt.show()
            gt = gt.cpu().numpy()
            full_prob = torch.argmax(full_prob, 1).long()
            full_prob = np.asarray(full_prob.cpu(),
                                   dtype=np.uint8)  # (B,C,H,W)

            # plt.imshow(gt[0])
            # plt.show()
            # accumulate confusion matrix for mIoU
            '''设置输出原图和预测图片的颜色灰度还是彩色'''
            for index in range(
                    full_prob.shape[0]):  # full_prob shape[0] is batch_size
                metric.addBatch(full_prob[index], gt[index])
                if save_result:
                    save_predict(full_prob[index],
                                 gt[index],
                                 name[index],
                                 args.dataset,
                                 args.save_seg_dir,
                                 output_grey=False,
                                 output_color=True,
                                 gt_color=True)

        loss, FWIoU, Miou, MIoU_avg, PerCiou_set, Pa, PerCpa_set, Mpa, MF, F_set, F_avg = eval_metric(
            args, metric, count_loss, loss)
    else:
        for i, (image, size, name) in pbar:
            B, C, H, W = image.shape
            # image scaled [0.75, 1.0, 1.25, 1.5, 2.0]
            full_prob = torch.zeros(B, args.classes, H, W).cuda()
            for scale in scales:
                sh = int(H * float(scale))
                sw = int(W * float(scale))
                scale_image = F.interpolate(image, (sh, sw),
                                            mode='bilinear',
                                            align_corners=True)

                # is the scaled size smaller than the tile size?
                if H < sh and W < sw and (H < tile_h_size or W < tile_w_size):
                    # predict the whole image directly, restore to normal size
                    scale_image = scale_image.cuda()
                    # NOTE(review): `scale_gt` is never defined in predict
                    # mode — this raises NameError if the branch executes.
                    scale_gt = scale_gt.cuda()
                    if args.flip_merge:
                        scale_output = flip_merge(model, scale_image)
                    else:
                        scale_output = model(scale_image)
                else:
                    scale_image_size = scale_image.shape  # (b,c,h,w)
                    overlap = 1 / 3  # consecutive windows overlap by 1/3
                    stride = ceil(tile_h_size *
                                  (1 - overlap))  # sliding stride: 512*(1-1/3)= 342
                    tile_rows = int(
                        ceil((scale_image_size[2] - tile_h_size) / stride) +
                        1)  # row steps: (3072-512)/342+1=9
                    tile_cols = int(
                        ceil((scale_image_size[3] - tile_w_size) / stride) +
                        1)  # col steps: (3328-512)/342+1=10
                    outputs_prob = torch.zeros(B, args.classes, sh, sw).cuda()
                    count_prob = torch.zeros(B, 1, sh, sw).cuda()

                    for row in range(
                            tile_rows):  # row = 0,1     0,1,2,3,4,5,6,7,8
                        for col in range(
                                tile_cols
                        ):  # col = 0,1,2,3     0,1,2,3,4,5,6,7,8,9
                            x1 = int(col *
                                     stride)  # start position x1 = 0 * 342
                            y1 = int(row * stride)  # y1 = 0 * 342
                            x2 = min(x1 + tile_w_size, scale_image_size[3]
                                     )  # end position x2 = min(0+512, 3328)
                            y2 = min(
                                y1 + tile_h_size,
                                scale_image_size[2])  # y2 = min(0+512, 3072)
                            x1 = max(int(x2 - tile_w_size),
                                     0)  # re-anchor start x1 = max(512-512, 0)
                            y1 = max(int(y2 - tile_h_size),
                                     0)  # y1 = max(512-512, 0)

                            with torch.no_grad():
                                tile_image = scale_image[:, :, y1:y2,
                                                         x1:x2].cuda()
                                # NOTE(review): `scale_gt` undefined in this
                                # mode as well — NameError risk; `tile_gt`
                                # is also never used afterwards.
                                tile_gt = scale_gt[:, :, y1:y2,
                                                   x1:x2].long().cuda()
                                if args.flip_merge:
                                    tile_output = flip_merge(model, tile_image)
                                else:
                                    tile_output = model(tile_image)

                            if type(tile_output) is tuple:
                                tile_output = tile_output[0]

                            outputs_prob[:, :, y1:y2, x1:x2] += tile_output
                            count_prob[:, :, y1:y2, x1:x2] += 1

                    # after sliding over this scale, every pixel must have been
                    # covered at least once; normalise accumulated probabilities
                    assert ((count_prob == 0).sum() == 0)
                    outputs = outputs_prob / count_prob

                outputs = F.interpolate(outputs, (H, W),
                                        mode='bilinear',
                                        align_corners=True)
                full_prob += outputs

            # visualize normalization Weights
            # plt.imshow(np.mean(count_predictions, axis=2))
            # plt.show()
            # NOTE(review): `gt` is not defined in predict mode — this line
            # raises NameError; looks like leftover validation code.
            gt = gt.cpu().numpy()
            full_prob = torch.argmax(full_prob, 1).long()
            full_prob = np.asarray(full_prob.cpu(),
                                   dtype=np.uint8)  # (B,C,H,W)

            # plt.imshow(gt[0])
            # plt.show()
            # save predictions only (no metric in predict mode)
            for index in range(
                    full_prob.shape[0]):  # gt shape[0] is batch_size
                if save_result:
                    save_predict(full_prob[index],
                                 None,
                                 name[index],
                                 args.dataset,
                                 args.save_seg_dir,
                                 output_grey=True,
                                 output_color=False,
                                 gt_color=False)

        loss, FWIoU, Miou, MIoU_avg, PerCiou_set, Pa, PerCpa_set, Mpa, MF, F_set, F_avg = 0, 0, 0, 0, {}, 0, {}, 0, 0, {}, 0

    return loss, FWIoU, Miou, MIoU_avg, PerCiou_set, Pa, PerCpa_set, Mpa, MF, F_set, F_avg
# Example #8
def predict_sliding(args, net, image, tile_size, classes):
    """Sliding-window inference over a labeled loader, then report IoU.

    args:
      args: namespace providing ``dataset``, ``save_seg_dir``, ``classes``
      net: segmentation network (already on GPU)
      image: the data loader yielding (input, gt, size, name) — despite
        the name this is the loader, not a single image
      tile_size: (tile_h, tile_w) sliding-window size
      classes: number of classes for the probability accumulators

    Writes mIoU and per-class IoU to ``<save_seg_dir>/results.txt``.
    """
    total_batches = len(image)
    data_list = []
    pbar = tqdm(iterable=enumerate(image),
                total=total_batches,
                desc='Predicting')
    for i, (input, gt, size, name) in pbar:
        image_size = input.shape  # (1,3,3328,3072)
        overlap = 1 / 3  # consecutive windows overlap by 1/3
        # print(image_size, tile_size)
        stride = ceil(
            tile_size[0] *
            (1 - overlap))  # sliding stride: 512*(1-1/3) = 342
        tile_rows = int(ceil((image_size[2] - tile_size[0]) / stride) +
                        1)  # row steps: (3072-512)/342+1=9
        tile_cols = int(ceil((image_size[3] - tile_size[1]) / stride) +
                        1)  # col steps: (3328-512)/342+1=10
        full_probs = np.zeros((image_size[2], image_size[3],
                               classes))  # accumulated probabilities (H,W,classes)
        count_predictions = np.zeros((image_size[2], image_size[3],
                                      classes))  # per-pixel window counts (H,W,classes)

        for row in range(tile_rows):  # row = 0,1     0,1,2,3,4,5,6,7,8
            for col in range(
                    tile_cols):  # col = 0,1,2,3     0,1,2,3,4,5,6,7,8,9
                x1 = int(col * stride)  # start position x1 = 0 * 342
                y1 = int(row * stride)  # y1 = 0 * 342
                x2 = min(x1 + tile_size[1],
                         image_size[3])  # end position x2 = min(0+512, 3328)
                y2 = min(y1 + tile_size[0],
                         image_size[2])  # y2 = min(0+512, 3072)
                x1 = max(int(x2 - tile_size[1]),
                         0)  # re-anchor start x1 = max(512-512, 0)
                y1 = max(int(y2 - tile_size[0]), 0)  # y1 = max(512-512, 0)

                img = input[:, :, y1:y2,
                            x1:x2]  # window crop, e.g. image[:, :, 0:512, 0:512]
                padded_img = pad_image(img,
                                       tile_size)  # pad so the crop is tile-sized
                # plt.imshow(padded_img)
                # plt.show()

                # feed the cropped tile to the network; it outputs a probability map
                with torch.no_grad():
                    # NOTE(review): torch.from_numpy assumes pad_image returns
                    # a numpy array even though `input` is a tensor slice —
                    # confirm pad_image's return type.
                    input_var = torch.from_numpy(padded_img).cuda().float()
                    padded_prediction = net(input_var)

                    if type(padded_prediction) is tuple:
                        padded_prediction = padded_prediction[0]

                    torch.cuda.synchronize()

                if isinstance(padded_prediction, list):
                    padded_prediction = padded_prediction[
                        0]  # shape(1,3,512,512)

                padded_prediction = padded_prediction.cpu().data[0].numpy(
                ).transpose(1, 2, 0)  # to channels-last (512,512,3)
                prediction = padded_prediction[
                    0:img.shape[2],
                    0:img.shape[3], :]  # crop back to the unpadded tile area
                count_predictions[y1:y2, x1:x2] += 1  # bump counts inside the window
                full_probs[y1:y2, x1:x2] += prediction  # accumulate probabilities

        # average the predictions in the overlapping regions
        full_probs /= count_predictions  # accumulated probs / counts = mean probs
        # visualize normalization Weights
        # plt.imshow(np.mean(count_predictions, axis=2))
        # plt.show()
        full_probs = np.asarray(np.argmax(full_probs, axis=2), dtype=np.uint8)
        '''设置输出原图和预测图片的颜色灰度还是彩色'''
        gt = gt[0].numpy()
        # accumulate (gt, prediction) pairs for mIoU
        data_list.append([gt.flatten(), full_probs.flatten()])
        save_predict(full_probs,
                     gt,
                     name[0],
                     args.dataset,
                     args.save_seg_dir,
                     output_grey=False,
                     output_color=True,
                     gt_color=True)

    meanIoU, per_class_iu = get_iou(data_list, args.classes)
    print('miou {}\nclass iou {}'.format(meanIoU, per_class_iu))
    result = args.save_seg_dir + '/results.txt'
    with open(result, 'w') as f:
        f.write(str(meanIoU))
        f.write('\n{}'.format(str(per_class_iu)))
def predict_multiscale_sliding(args,
                               model,
                               class_dict_df,
                               testLoader,
                               scales,
                               overlap,
                               criterion,
                               mode='predict',
                               save_result=True):
    """Multi-scale, sliding-window segmentation inference.

    Each batch is rescaled by every factor in ``scales``.  A rescaled image
    that still fits inside one tile is predicted in a single forward pass;
    otherwise it is covered by overlapping tiles whose class scores are
    accumulated and averaged.  Per-scale outputs are resized back to the
    original (H, W) and summed into ``full_prob`` before the final argmax.

    args:
      args: needs classes, tile_hw_size ("H,W" string), flip_merge,
            dataset and save_seg_dir attributes
      model: segmentation network (moved to GPU by the caller)
      class_dict_df: class-name table forwarded to eval_metric
      testLoader: yields (image, gt, size, name) in validation mode,
                  (image, size, name) otherwise
      scales: iterable of scale factors, e.g. [0.75, 1.0, 1.25, 1.5, 2.0]
      overlap: fractional overlap between neighbouring tiles (0..1)
      criterion: loss used for the validation report
      mode: 'validation' computes loss and metrics; any other value only
            writes predictions
      save_result: when True, predictions are written via save_predict
    return: (loss, FWIoU, Miou, Miou_Noback, PerCiou_set, Pa, PerCpa_set,
             Mpa, MF, F_set, F1_Noback); all metrics are zero/empty unless
             mode == 'validation'
    """
    loss = 0
    count_loss = 0
    model.eval()
    # fix: move the criterion to the GPU once here; the per-batch
    # re-invocations in the original loops were redundant no-ops
    criterion = criterion.cuda()
    tile_h_size, tile_w_size = int(args.tile_hw_size.split(',')[0]), int(
        args.tile_hw_size.split(',')[1])
    metric = SegmentationMetric(args.classes)
    pbar = tqdm(iterable=enumerate(testLoader),
                total=len(testLoader),
                desc='Predicting')
    if mode == 'validation':
        for i, (image, gt, size, name) in pbar:
            B, C, H, W = image.shape
            # image and gt are scaled together for every factor in scales
            full_prob = torch.zeros(B, args.classes, H, W).cuda()
            for scale in scales:
                scale = float(scale)  # fix: duplicated conversion removed
                sh = int(H * scale)
                sw = int(W * scale)

                scale_image = F.interpolate(image, (sh, sw),
                                            mode='bilinear',
                                            align_corners=True).float()
                # nearest-neighbour keeps the integer label ids intact
                scale_gt = F.interpolate(gt.unsqueeze(1).float(), (sh, sw),
                                         mode='nearest').long()

                # Whether the rescaled image still fits inside one tile
                if (H > sh or W > sw) and (H < tile_h_size or W < tile_w_size):
                    # Directly predict the entire image and restore it to normal size
                    with torch.no_grad():
                        scale_image = scale_image.cuda()
                        scale_gt = scale_gt.cuda()
                        if args.flip_merge:
                            outputs = flip_merge(model, scale_image)
                        else:
                            outputs = model(scale_image)

                        # outputs = (main_out, aux_out1, ...): the main head
                        # gets weight 0.6, the aux heads share the other 0.4
                        if type(outputs) is tuple:
                            length = len(outputs)
                            for index, out in enumerate(outputs):
                                loss_record = criterion(
                                    out, scale_gt.squeeze(1))
                                if index == 0:
                                    loss_record *= 0.6
                                else:
                                    loss_record *= 0.4 / (length - 1)
                                loss += loss_record
                            outputs = outputs[0]
                            count_loss += 1
                        else:  # fix: 'elif ... is not tuple' was always true
                            loss += criterion(outputs, scale_gt.squeeze(1))
                            count_loss += 1
                else:
                    scale_image_size = scale_image.shape  # (b,c,h,w)
                    # overlap stands for coverage per slide
                    stride = ceil(tile_h_size * (1 - overlap))
                    tile_rows = int(
                        ceil((scale_image_size[2] - tile_h_size) / stride) + 1)
                    tile_cols = int(
                        ceil((scale_image_size[3] - tile_w_size) / stride) + 1)
                    outputs_prob = torch.zeros(B, args.classes, sh, sw).cuda()
                    count_prob = torch.zeros(B, 1, sh, sw).cuda()

                    for row in range(tile_rows):
                        for col in range(tile_cols):
                            # clamp the tile to the image border, then shift
                            # it back so every tile keeps the full tile size
                            x1 = int(col * stride)
                            y1 = int(row * stride)
                            x2 = min(x1 + tile_w_size, scale_image_size[3])
                            y2 = min(y1 + tile_h_size, scale_image_size[2])
                            x1 = max(int(x2 - tile_w_size), 0)
                            y1 = max(int(y2 - tile_h_size), 0)

                            with torch.no_grad():
                                tile_image = scale_image[:, :, y1:y2,
                                                         x1:x2].float().cuda()
                                tile_gt = scale_gt[:, :, y1:y2,
                                                   x1:x2].long().cuda()
                                if args.flip_merge:
                                    tile_output = flip_merge(model, tile_image)
                                else:
                                    tile_output = model(tile_image)

                            # tile_output = (main_out, aux_out1, aux_out2, ...)
                            if type(tile_output) is tuple:
                                length = len(tile_output)
                                for index, out in enumerate(tile_output):
                                    loss_record = criterion(
                                        out, tile_gt.squeeze(1))
                                    if index == 0:
                                        loss_record *= 0.6
                                    else:
                                        loss_record *= 0.4 / (length - 1)
                                    loss += loss_record
                                tile_output = tile_output[0]
                                count_loss += 1
                            else:  # fix: redundant 'is not tuple' check
                                loss += criterion(tile_output,
                                                  tile_gt.squeeze(1))
                                count_loss += 1

                            outputs_prob[:, :, y1:y2, x1:x2] += tile_output
                            count_prob[:, :, y1:y2, x1:x2] += 1

                    # every pixel must be covered by at least one tile
                    assert ((count_prob == 0).sum() == 0)
                    # average the predictions in the overlapping regions
                    outputs = outputs_prob / count_prob

                outputs = F.interpolate(outputs, (H, W),
                                        mode='bilinear',
                                        align_corners=True)
                full_prob += outputs

            gt = np.asarray(gt.cpu(), dtype=np.uint8)
            full_prob = torch.argmax(full_prob, 1).long()
            full_prob = np.asarray(full_prob.cpu(),
                                   dtype=np.uint8)  # (B,H,W) after argmax
            '''Sets the color of the output image and predict image to be grayscale or color'''
            for index in range(
                    full_prob.shape[0]):  # full_prob shape[0] is batch_size
                metric.addBatch(full_prob[index], gt[index])
                if save_result:
                    save_predict(full_prob[index],
                                 gt[index],
                                 name[index],
                                 args.dataset,
                                 args.save_seg_dir,
                                 output_grey=False,
                                 output_color=True,
                                 gt_color=True)

        loss, FWIoU, Miou, Miou_Noback, PerCiou_set, Pa, PerCpa_set, Mpa, MF, F_set, F1_Noback = \
            eval_metric(args, class_dict_df, metric, count_loss, loss)
    else:
        for i, (image, size, name) in pbar:
            B, C, H, W = image.shape
            # image scaled by every factor in scales
            full_prob = torch.zeros(B, args.classes, H, W).cuda()
            for scale in scales:
                scale = float(scale)
                sh = int(H * scale)
                sw = int(W * scale)

                scale_image = F.interpolate(image, (sh, sw),
                                            mode='bilinear',
                                            align_corners=True).float()

                # Whether the rescaled image still fits inside one tile
                if (H > sh or W > sw) and (H < tile_h_size or W < tile_w_size):
                    # Directly predict the entire image and restore it to normal size
                    with torch.no_grad():
                        scale_image = scale_image.cuda()
                        if args.flip_merge:
                            outputs = flip_merge(model, scale_image)
                        else:
                            outputs = model(scale_image)
                        if type(outputs) is tuple:
                            outputs = outputs[0]

                else:
                    scale_image_size = scale_image.shape  # (b,c,h,w)
                    # overlap stands for coverage per slide
                    stride = ceil(tile_h_size * (1 - overlap))
                    tile_rows = int(
                        ceil((scale_image_size[2] - tile_h_size) / stride) + 1)
                    tile_cols = int(
                        ceil((scale_image_size[3] - tile_w_size) / stride) + 1)
                    outputs_prob = torch.zeros(B, args.classes, sh, sw).cuda()
                    count_prob = torch.zeros(B, 1, sh, sw).cuda()

                    for row in range(tile_rows):
                        for col in range(tile_cols):
                            x1 = int(col * stride)
                            y1 = int(row * stride)
                            x2 = min(x1 + tile_w_size, scale_image_size[3])
                            y2 = min(y1 + tile_h_size, scale_image_size[2])
                            x1 = max(int(x2 - tile_w_size), 0)
                            y1 = max(int(y2 - tile_h_size), 0)

                            with torch.no_grad():
                                tile_image = scale_image[:, :, y1:y2,
                                                         x1:x2].float().cuda()
                                if args.flip_merge:
                                    tile_output = flip_merge(model, tile_image)
                                else:
                                    tile_output = model(tile_image)

                            if type(tile_output) is tuple:
                                tile_output = tile_output[0]
                            outputs_prob[:, :, y1:y2, x1:x2] += tile_output
                            count_prob[:, :, y1:y2, x1:x2] += 1

                    assert ((count_prob == 0).sum() == 0)
                    outputs = outputs_prob / count_prob

                outputs = F.interpolate(outputs, (H, W),
                                        mode='bilinear',
                                        align_corners=True)
                full_prob += outputs

            full_prob = torch.argmax(full_prob, 1).long()
            full_prob = np.asarray(full_prob.cpu(),
                                   dtype=np.uint8)  # (B,H,W) after argmax
            '''Sets the color of the output image and predict image to be grayscale or color'''
            # save results (no ground truth available in predict mode)
            for index in range(
                    full_prob.shape[0]):  # full_prob shape[0] is batch_size
                if save_result:
                    save_predict(full_prob[index],
                                 None,
                                 name[index],
                                 args.dataset,
                                 args.save_seg_dir,
                                 output_grey=True,
                                 output_color=False,
                                 gt_color=False)

        # no labels -> no metrics in predict mode
        loss, FWIoU, Miou, Miou_Noback, PerCiou_set, Pa, PerCpa_set, Mpa, MF, F_set, F1_Noback = 0, 0, 0, 0, {}, 0, {}, 0, 0, {}, 0

    return loss, FWIoU, Miou, Miou_Noback, PerCiou_set, Pa, PerCpa_set, Mpa, MF, F_set, F1_Noback
def test(args, test_loader, model):
    """Evaluate the model and write a per-image CSV report.

    args:
      test_loader: loader for the labelled test dataset, yielding
        (input, label, size, name) batches
      model: segmentation network
    return: (mean IoU, per-class IoU array) computed over the whole set
    Side effect: writes one CSV row per image (overall accuracy, mean IoU,
    frequency-weighted accuracy, plus per-class accuracy and IoU) to
    args.checkpoint + 'test.csv'.
    """
    # evaluation or test mode
    model.eval()
    total_batches = len(test_loader)
    metric = IOUMetric(args.classes)
    # column names are the 11 CamVid classes; args.classes must match them
    # for the per-row star-expansion below to line up
    df = pd.DataFrame(columns=['name','acc','mean_iu', 'fwavacc','Sky', 'Building', 'Pole', 'Road', 'Sidewalk', 'Tree', 'Sign', 'Fence', 'Car',
                     'Pedestrian', 'Bicyclist','Sky_iu', 'Building_iu', 'Pole_iu', 'Road_iu', 'Sidewalk_iu', 'Tree_iu', 'Sign_iu', 'Fence_iu', 'Car_iu',
                     'Pedestrian_iu', 'Bicyclist_iu'])

    for i, (input, label, size, name) in enumerate(test_loader):
        # a fresh metric per image gives the per-row scores; `metric`
        # keeps accumulating over the whole dataset
        metric_each = IOUMetric(args.classes)
        with torch.no_grad():
            if args.cuda:
                input_var = input.cuda()
            else:
                input_var = input
            if args.cuda:
                # make sure pending GPU work does not pollute the timing
                torch.cuda.synchronize()
            start_time = time.time()
            output = model(input_var)
            # BiSeNet-style models return a tuple; score the main head only
            if args.model.startswith('BiSeNet'):
                _, predicted = torch.max(output[0].data, 1)
            else:
                _, predicted = torch.max(output.data, 1)
            if args.cuda:
                torch.cuda.synchronize()
        time_taken = time.time() - start_time
        print('[%d/%d]  time: %.2f' % (i + 1, total_batches, time_taken))

        np_pred = predicted.cpu().numpy()
        np_label = label.cpu().numpy()

        metric.add_batch(np_pred, np_label)
        metric_each.add_batch(np_pred, np_label)
        acc, acc_cls, iu, mean_iu, fwavacc = metric_each.evaluate()
        # fix: star-expand the per-class vectors instead of unpacking them
        # into 22 hard-coded locals (same row contents, far less fragile)
        df.loc[df.shape[0]] = [name, acc, mean_iu, fwavacc, *acc_cls, *iu]

        # save the predicted image
        if args.save:
            for j in range(len(name)):
                save_predict(np_pred[j], np_label[j], name[j], args.dataset, args.save_seg_dir,
                             output_grey=False, output_color=True, gt_color=True)

    acc, acc_cls, iu, mean_iu, fwavacc = metric.evaluate()
    # NOTE(review): plain concatenation assumes args.checkpoint ends with a
    # path separator — confirm, or switch to os.path.join
    df.to_csv(args.checkpoint+'test.csv')
    return mean_iu, iu
    def detect(self,
               source,
               qt_input=None,
               qt_output=None,
               qt_mask_output=None):
        """Run segmentation inference on *source* and save/stream results.

        args:
          source: image/video path (or directory) handed to
            build_dataset_predict
          qt_input / qt_output / qt_mask_output: optional Qt widgets; when
            given and the source is a video, frames are shown in real time
        return: (save_path, save_mask_path) of the last written outputs
        NOTE(review): if the loader yields nothing, save_path/save_mask_path
        are unbound at the return — confirm callers guard against that.
        """

        # load the test set
        _, dataset_loader = build_dataset_predict(source,
                                                  self.dataset_type,
                                                  self.num_workers,
                                                  none_gt=True)

        show_count = 0

        # evaluation or test mode
        self.model.eval()
        total_batches = len(dataset_loader)
        # video-writer state; a new writer is opened whenever the target
        # path changes.  NOTE(review): neither writer is released after the
        # loop, so the last video file may be left unfinalized — confirm.
        vid_writer = None
        vid_path = None
        vid_mask_writer = None
        vid_mask_path = None

        self.input_windows_width = 0
        self.input_windows_height = 0
        self.output_windows_height = 0

        for i, (input, size, name, mode, frame_count, img_original, vid_cap,
                info_str) in enumerate(dataset_loader):
            with torch.no_grad():
                input = input[None, ...]  # add a leading batch dimension
                input = torch.tensor(input)  # [1, 3, 224, 224]
                input_var = input.cuda()
            start_time = time.time()
            output = self.model(input_var)
            torch.cuda.synchronize()
            time_taken = time.time() - start_time
            print(
                f'[{i + 1}/{total_batches}]  time: {time_taken * 1000:.4f} ms = {1 / time_taken:.1f} FPS'
            )
            # (C,H,W) logits -> (H,W) uint8 class-id map
            output = output.cpu().data[0].numpy()
            output = output.transpose(1, 2, 0)
            output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)

            save_name = Path(name).stem + f'_predict'
            if mode == 'images':
                # save the raw per-image prediction
                save_predict(output,
                             None,
                             save_name,
                             self.dataset_type,
                             self.save_seg_dir,
                             output_grey=True,
                             output_color=False,
                             gt_color=False)

            # overlay the mask on the original image
            img = img_original
            mask = output
            mask[mask == 1] = 255  # map class 1 to 255 so it can act as the red channel
            zeros = np.zeros(mask.shape[:2],
                             dtype="uint8")  # all-zero plane for the blue and green channels
            mask_final = cv2.merge([zeros, zeros, mask])  # merge into 3-channel BGR
            img = cv2.addWeighted(img, 1, mask_final, 1, 0)  # blend overlay onto image

            # record inference info
            image_shape = f'{img_original.shape[0]}x{img_original.shape[1]} '
            self.predict_info = info_str + '%sDone. (%.3fs)' % (image_shape,
                                                                time_taken)
            print(self.predict_info)
            # Qt real-time display
            if qt_input is not None and qt_output is not None and dataset_loader.mode == 'video':
                video_count, vid_total = info_str.split(" ")[2][1:-1].split(
                    "/")  # current frame index / total frame count
                # NOTE(review): not frames-per-second (that would be
                # 1/time_taken) — confirm the intended formula
                fps = (time_taken / 1) * 100
                fps_threshold = 25  # FPS threshold
                show_flag = True
                if fps > fps_threshold:  # above the threshold, skip frames
                    fps_interval = 15  # target display frame rate
                    show_unit = math.ceil(fps / fps_interval)  # show 1 of every show_unit frames, rounded up
                    if int(video_count) % show_unit != 0:  # frame skipping
                        show_flag = False
                    else:
                        show_count += 1

                if show_flag:
                    # img_original: frame before inference; img: composited result
                    self.show_real_time_image("input", qt_input,
                                              img_original)  # original frame
                    self.show_real_time_image("output", qt_output,
                                              img)  # final composited prediction
                    self.show_real_time_image("output", qt_mask_output,
                                              mask_final)  # segmentation mask
            if mode == 'images':
                # save prediction + original composite
                save_path = os.path.join(self.save_seg_dir,
                                         save_name + '_img.jpg')
                cv2.imwrite(f"{save_path}", img)

                save_mask_path = os.path.join(self.save_seg_dir,
                                              save_name + '_mask_img.jpg')
                cv2.imwrite(f"{save_mask_path}", mask_final)
            else:
                # write the composited frame into the output video
                save_path = os.path.join(self.save_seg_dir,
                                         save_name + '_predict.mp4')
                if vid_path != save_path:  # new video
                    vid_path = save_path
                    if isinstance(vid_writer, cv2.VideoWriter):
                        vid_writer.release()  # release previous video writer

                    fourcc = 'mp4v'  # output video codec
                    fps = vid_cap.get(cv2.CAP_PROP_FPS)
                    w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                    h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                    vid_writer = cv2.VideoWriter(
                        save_path, cv2.VideoWriter_fourcc(*fourcc), fps,
                        (w, h))
                vid_writer.write(img)

                # write the mask frame into the mask video
                save_mask_path = os.path.join(self.save_seg_dir,
                                              save_name + '_mask_predict.mp4')
                if vid_mask_path != save_mask_path:  # new video
                    vid_mask_path = save_mask_path
                    if isinstance(vid_mask_writer, cv2.VideoWriter):
                        vid_mask_writer.release(
                        )  # release previous video writer

                    fourcc = 'mp4v'  # output video codec
                    fps = vid_cap.get(cv2.CAP_PROP_FPS)
                    w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                    h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                    vid_mask_writer = cv2.VideoWriter(
                        save_mask_path, cv2.VideoWriter_fourcc(*fourcc), fps,
                        (w, h))
                vid_mask_writer.write(mask_final)

        return save_path, save_mask_path
Пример #12
0
def test(args, test_loader, model):
    """Evaluate the model with per-image metrics averaged over the set.

    args:
      test_loader: loader for the labelled test dataset, yielding
        (input, gt, size, name) batches
      model: segmentation network
    return: (mean IoU, {class: IoU}, frequency-weighted IoU, pixel
             accuracy, mean pixel accuracy) — each the mean of the
             per-image scores
    """
    # evaluation or test mode
    model.eval()
    total_batches = len(test_loader)

    Miou_list = []
    Iou_list = []
    Pa_list = []
    Mpa_list = []
    Fmiou_list = []
    pbar = tqdm(iterable=enumerate(test_loader),
                total=total_batches,
                desc='Valing')
    for i, (input, gt, size, name) in pbar:
        with torch.no_grad():
            # fix: torch.autograd.Variable is a deprecated no-op since
            # torch 0.4 — moving the tensor to the GPU directly is identical
            input_var = input.cuda()
        start_time = time.time()
        output = model(input_var)
        torch.cuda.synchronize()
        time_taken = time.time() - start_time
        pbar.set_postfix(cost_time='%.3f' % time_taken)
        # (C,H,W) logits -> (H,W) uint8 class-id map
        output = output.cpu().data[0].numpy()
        output = output.transpose(1, 2, 0)
        output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)
        gt = np.asarray(gt[0], dtype=np.uint8)

        # per-image metrics; the per-image scores are averaged below
        metric = SegmentationMetric(numClass=args.classes)
        metric.addBatch(imgPredict=output, imgLabel=gt)
        miou, iou = metric.meanIntersectionOverUnion()
        fmiou = metric.Frequency_Weighted_Intersection_over_Union()
        pa = metric.pixelAccuracy()
        mpa = metric.meanPixelAccuracy()
        Miou_list.append(miou)
        Fmiou_list.append(fmiou)
        Pa_list.append(pa)
        Mpa_list.append(mpa)
        iou = np.array(iou)
        Iou_list.append(iou)

        # save the predicted image
        if args.save:
            save_predict(output,
                         gt,
                         name[0],
                         args.dataset,
                         args.save_seg_dir,
                         output_grey=False,
                         output_color=True,
                         gt_color=True)

    miou = np.mean(Miou_list)
    fmiou = np.mean(Fmiou_list)
    pa = np.mean(Pa_list)
    mpa = np.mean(Mpa_list)
    Iou_list = np.asarray(Iou_list)
    iou = np.mean(Iou_list, axis=0)
    cls_iu = dict(zip(range(args.classes), iou))
    return miou, cls_iu, fmiou, pa, mpa
Пример #13
0
def predict(args, test_loader, model):
    """Predict on an unlabelled image/video source and save overlays.

    args:
      test_loader: loader for the test dataset (no labels); yields the
        preprocessed frame plus the original frame and the video capture
      model: model
    return: None; writes prediction images and mask/overlay composites
      (image input) or an overlay video (video input) to args.save_seg_dir
    """
    # evaluation or test mode
    model.eval()
    total_batches = len(test_loader)
    vid_writer = None
    vid_path = None

    for i, (input, size, name, mode, frame_count, img_original,
            vid_cap) in enumerate(test_loader):
        with torch.no_grad():
            input = input[None, ...]  # add a leading batch dimension
            input = torch.tensor(input)  # [1, 3, 224, 224]
            input_var = input.cuda()
        start_time = time.time()
        output = model(input_var)
        torch.cuda.synchronize()
        time_taken = time.time() - start_time
        print(
            f'[{i + 1}/{total_batches}]  time: {time_taken * 1000:.4f} ms = {1 / time_taken:.1f} FPS'
        )
        # (C,H,W) logits -> (H,W) uint8 class-id map
        output = output.cpu().data[0].numpy()
        output = output.transpose(1, 2, 0)
        output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)

        save_name = Path(name).stem + '_predict'  # fix: placeholder-free f-string
        if mode == 'images':
            # save the raw prediction
            # NOTE(review): grey AND color outputs are both enabled here,
            # unlike the other predict paths — confirm this is intended
            save_predict(output,
                         None,
                         save_name,
                         args.dataset,
                         args.save_seg_dir,
                         output_grey=True,
                         output_color=True,
                         gt_color=False)

        # overlay the mask on the original image
        img = img_original
        mask = output
        mask[mask == 1] = 255  # map class 1 to 255 so it can act as the red channel
        zeros = np.zeros(mask.shape[:2],
                         dtype="uint8")  # all-zero plane for the blue and green channels
        mask_final = cv2.merge([zeros, zeros, mask])  # merge into 3-channel BGR
        img = cv2.addWeighted(img, 1, mask_final, 1, 0)  # blend overlay onto image

        if mode == 'images':
            # save prediction + original composite
            cv2.imwrite(
                f"{os.path.join(args.save_seg_dir, save_name + '_img.png')}",
                img)
        else:
            # write the composited frame into the output video
            save_path = os.path.join(args.save_seg_dir,
                                     save_name + '_predict.mp4')
            if vid_path != save_path:  # new video
                vid_path = save_path
                if isinstance(vid_writer, cv2.VideoWriter):
                    vid_writer.release()  # release previous video writer

                fourcc = 'mp4v'  # output video codec
                fps = vid_cap.get(cv2.CAP_PROP_FPS)
                w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                vid_writer = cv2.VideoWriter(save_path,
                                             cv2.VideoWriter_fourcc(*fourcc),
                                             fps, (w, h))
            vid_writer.write(img)

    # fix: release the last writer so the final video file is finalized
    # (the original leaked it, which can leave a corrupt/empty .mp4)
    if isinstance(vid_writer, cv2.VideoWriter):
        vid_writer.release()
def test(args, test_loader, model):
    """
    args:
      test_loader: loaded for test dataset
      model: model
    return: class IoU and mean IoU
    """
    # switch the network into evaluation mode
    model.eval()
    total_batches = len(test_loader)
    metric = IOUMetric(args.classes)

    for batch_idx, (images, labels, size, names) in enumerate(test_loader):
        with torch.no_grad():
            batch = images.cuda() if args.cuda else images
            if args.cuda:
                # drain pending GPU work so the timing is meaningful
                torch.cuda.synchronize()
            start_time = time.time()
            logits = model(batch)
            # BiSeNet-style models return a tuple; score the main head only
            head = logits[0] if args.model.startswith('BiSeNet') else logits
            predicted = torch.argmax(head.data, 1)
            if args.cuda:
                torch.cuda.synchronize()
        time_taken = time.time() - start_time
        print('[%d/%d]  time: %.2f' % (batch_idx + 1, total_batches,
                                       time_taken))

        pred_np = predicted.cpu().numpy()
        label_np = labels.cpu().numpy()
        metric.add_batch(pred_np, label_np)

        # optionally dump colorized prediction / ground-truth pairs
        if args.save:
            for j, img_name in enumerate(names):
                save_predict(pred_np[j],
                             label_np[j],
                             img_name,
                             args.dataset,
                             args.save_seg_dir,
                             output_grey=False,
                             output_color=True,
                             gt_color=True)

    acc, acc_cls, iu, mean_iu, fwavacc = metric.evaluate()
    return mean_iu, iu