Example #1
def test(cfg, model, post_processor, criterion, device, test_loader):
    """
    Return: a validation metric between 0-1 where 1 is perfect
    """
    model.eval()
    post_processor.eval()
    test_loss = 0
    correct = 0
    # TODO: use a more consistent evaluation interface
    pixel_acc_list = []
    iou_list = []
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            feature = model(data)
            if cfg.task == "classification":
                output = post_processor(feature)
            elif cfg.task == "semantic_segmentation":
                ori_spatial_res = data.shape[-2:]
                output = post_processor(feature, ori_spatial_res)
            else:
                raise NotImplementedError
            test_loss += criterion(output, target).item()  # sum up batch loss
            if cfg.task == "classification":
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()
            elif cfg.task == "semantic_segmentation":
                pred_map = output.max(dim=1)[1]
                batch_acc, _ = utils.compute_pixel_acc(
                    pred_map, target, fg_only=cfg.METRIC.SEGMENTATION.fg_only)
                pixel_acc_list.append(float(batch_acc))
                for i in range(pred_map.shape[0]):
                    iou = utils.compute_iou(
                        np.array(pred_map[i].cpu()),
                        np.array(target[i].cpu(), dtype=np.int64),
                        cfg.num_classes,
                        fg_only=cfg.METRIC.SEGMENTATION.fg_only)
                    iou_list.append(float(iou))
            else:
                raise NotImplementedError

    test_loss /= len(test_loader.dataset)

    if cfg.task == "classification":
        acc = 100. * correct / len(test_loader.dataset)
        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.
              format(test_loss, correct, len(test_loader.dataset), acc))
        return acc
    elif cfg.task == "semantic_segmentation":
        m_iou = np.mean(iou_list)
        print(
            '\nTest set: Average loss: {:.4f}, Mean Pixel Accuracy: {:.4f}, Mean IoU {:.4f}'
            .format(test_loss, np.mean(pixel_acc_list), m_iou))
        return m_iou
    else:
        raise NotImplementedError
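
The classification path of this `test` routine only reads `cfg.task`, so a small hypothetical harness is enough to exercise it. The backbone, post-processor, and `SimpleNamespace` config below are illustrative stand-ins, not the project's actual builders.

import torch
import torch.nn as nn
from types import SimpleNamespace
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical backbone producing a feature vector, and a linear post-processor
# mapping features to class logits (stand-ins for the project's own modules).
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 64), nn.ReLU()).to(device)
post_processor = nn.Linear(64, 10).to(device)
criterion = nn.CrossEntropyLoss()

test_loader = DataLoader(
    datasets.MNIST("./data", train=False, download=True,
                   transform=transforms.ToTensor()),
    batch_size=256)

cfg = SimpleNamespace(task="classification")  # only field the classification path reads
acc = test(cfg, model, post_processor, criterion, device, test_loader)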
Example #2
def train(cfg, model, post_processor, criterion, device, train_loader,
          optimizer, epoch):
    model.train()
    post_processor.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        feature = model(data)
        if cfg.task == "classification":
            output = post_processor(feature)
        elif cfg.task in ("semantic_segmentation",
                          "few_shot_semantic_segmentation_fine_tuning"):
            ori_spatial_res = data.shape[-2:]
            output = post_processor(feature, ori_spatial_res)
        else:
            raise NotImplementedError
        loss = criterion(output, target)
        optimizer.zero_grad()  # reset gradient
        loss.backward()
        optimizer.step()
        if cfg.task == "classification":
            if batch_idx % cfg.TRAIN.log_interval == 0:
                pred = output.argmax(dim=1, keepdim=True)
                correct_prediction = pred.eq(target.view_as(pred)).sum().item()
                batch_acc = correct_prediction / data.shape[0]
                print(
                    'Train Epoch: {0} [{1}/{2} ({3:.0f}%)]\tLoss: {4:.6f}\tBatch Acc: {5:.6f}'
                    .format(epoch, batch_idx * len(data),
                            len(train_loader.dataset),
                            100. * batch_idx / len(train_loader), loss.item(),
                            batch_acc))
        elif cfg.task in ("semantic_segmentation",
                          "few_shot_semantic_segmentation_fine_tuning"):
            if batch_idx % cfg.TRAIN.log_interval == 0:
                pred_map = output.max(dim=1)[1]
                batch_acc, _ = utils.compute_pixel_acc(
                    pred_map, target, fg_only=cfg.METRIC.SEGMENTATION.fg_only)
                print(
                    'Train Epoch: {0} [{1}/{2} ({3:.0f}%)]\tLoss: {4:.6f}\tBatch Pixel Acc: {5:.6f}'
                    .format(epoch, batch_idx * len(data),
                            len(train_loader.dataset),
                            100. * batch_idx / len(train_loader), loss.item(),
                            batch_acc))
        else:
            raise NotImplementedError
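
As with the evaluation routine, the classification branch of `train` only reads `cfg.task` and `cfg.TRAIN.log_interval`, so a minimal hypothetical driver can run it end to end; the modules and config object here are again illustrative stand-ins rather than the project's real setup.

import torch
import torch.nn as nn
from types import SimpleNamespace
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 64), nn.ReLU()).to(device)
post_processor = nn.Linear(64, 10).to(device)
criterion = nn.CrossEntropyLoss()

train_loader = DataLoader(
    datasets.MNIST("./data", train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=64, shuffle=True)

# Optimize the backbone and the post-processor jointly, matching how the
# training loop back-propagates through both modules.
optimizer = torch.optim.SGD(
    list(model.parameters()) + list(post_processor.parameters()), lr=0.01)

cfg = SimpleNamespace(task="classification",
                      TRAIN=SimpleNamespace(log_interval=100))

for epoch in range(1, 3):
    train(cfg, model, post_processor, criterion, device, train_loader,
          optimizer, epoch)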
Example #3
def test(cfg, model, post_processor, criterion, device, test_loader, visfreq):
    model.eval()
    post_processor.eval()
    test_loss = 0
    correct = 0
    pixel_acc_list = []
    iou_list = []
    with torch.no_grad():
        for idx, (data, target) in enumerate(test_loader):
            data, target = data.to(device), target.to(device)
            feature = model(data)
            output = post_processor(feature)
            test_loss += criterion(output, target).item()  # sum up batch loss
            if cfg.task == "classification":
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()
                # TODO: save classified images with raw image as content and
                # human readable label as filenames
            elif cfg.task == "semantic_segmentation":
                pred_map = output.max(dim=1)[1]
                batch_acc, _ = utils.compute_pixel_acc(
                    pred_map, target, fg_only=cfg.METRIC.SEGMENTATION.fg_only)
                pixel_acc_list.append(float(batch_acc))
                for i in range(pred_map.shape[0]):
                    pred_np = np.array(pred_map[i].cpu())
                    target_np = np.array(target[i].cpu(), dtype=np.int64)
                    iou = utils.compute_iou(
                        pred_np,
                        target_np,
                        cfg.num_classes,
                        fg_only=cfg.METRIC.SEGMENTATION.fg_only)
                    iou_list.append(float(iou))
                    if (i + 1) % visfreq == 0:
                        cv2.imwrite("{}_{}_pred.png".format(idx, i), pred_np)
                        cv2.imwrite("{}_{}_label.png".format(idx, i),
                                    target_np)
                        # Visualize RGB image as well
                        ori_rgb_np = np.array(data[i].permute((1, 2, 0)).cpu())
                        if 'normalize' in cfg.DATASET.TRANSFORM.TEST.transforms:
                            rgb_mean = cfg.DATASET.TRANSFORM.TEST.TRANSFORMS_DETAILS.NORMALIZE.mean
                            rgb_sd = cfg.DATASET.TRANSFORM.TEST.TRANSFORMS_DETAILS.NORMALIZE.sd
                            ori_rgb_np = (ori_rgb_np * rgb_sd) + rgb_mean
                        assert ori_rgb_np.max() <= 1.1, "Max is {}".format(
                            ori_rgb_np.max())
                        ori_rgb_np[ori_rgb_np >= 1] = 1
                        ori_rgb_np = (ori_rgb_np * 255).astype(np.uint8)
                        # Convert to OpenCV BGR
                        ori_rgb_np = cv2.cvtColor(ori_rgb_np,
                                                  cv2.COLOR_RGB2BGR)
                        cv2.imwrite("{}_{}_ori.jpg".format(idx, i), ori_rgb_np)
            else:
                raise NotImplementedError

    test_loss /= len(test_loader.dataset)

    if cfg.task == "classification":
        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.
              format(test_loss, correct, len(test_loader.dataset),
                     100. * correct / len(test_loader.dataset)))
    elif cfg.task == "semantic_segmentation":
        print(
            '\nTest set: Average loss: {:.4f}, Mean Pixel Accuracy: {:.4f}, Mean IoU {:.4f}\n'
            .format(test_loss, np.mean(pixel_acc_list), np.mean(iou_list)))
    else:
        raise NotImplementedError
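
The metric helpers come from the project's `utils` module and are not reproduced in these examples. For reference, a plausible per-image mean-IoU implementation matching the call signature `utils.compute_iou(pred, target, num_classes, fg_only=...)` might look like the sketch below; the real helper may treat absent classes differently.

import numpy as np

def compute_iou_sketch(pred, target, num_classes, fg_only=False):
    # Hypothetical stand-in for utils.compute_iou: mean IoU over the classes
    # present in either the prediction or the label of a single image.
    # fg_only=True skips class 0, conventionally the background.
    ious = []
    for cls in range(1 if fg_only else 0, num_classes):
        pred_mask = (pred == cls)
        target_mask = (target == cls)
        union = np.logical_or(pred_mask, target_mask).sum()
        if union == 0:
            continue  # class absent from both maps; skip rather than score it
        intersection = np.logical_and(pred_mask, target_mask).sum()
        ious.append(intersection / union)
    return float(np.mean(ious)) if ious else 0.0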