import cv2
import numpy as np
import torch

import utils


def test(cfg, model, post_processor, criterion, device, test_loader):
    """Evaluate on the test set.

    Return:
        A scalar validation metric where higher is better: percent accuracy
        (0-100) for classification, mean IoU (0-1) for semantic segmentation.
    """
    model.eval()
    post_processor.eval()
    test_loss = 0
    correct = 0
    # TODO: use a more consistent evaluation interface
    pixel_acc_list = []
    iou_list = []
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            feature = model(data)
            if cfg.task == "classification":
                output = post_processor(feature)
            elif cfg.task == "semantic_segmentation":
                # Upsample logits back to the input's spatial resolution
                ori_spatial_res = data.shape[-2:]
                output = post_processor(feature, ori_spatial_res)
            test_loss += criterion(output, target).item()  # sum up batch loss
            if cfg.task == "classification":
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()
            elif cfg.task == "semantic_segmentation":
                pred_map = output.max(dim=1)[1]
                batch_acc, _ = utils.compute_pixel_acc(
                    pred_map, target, fg_only=cfg.METRIC.SEGMENTATION.fg_only)
                pixel_acc_list.append(float(batch_acc))
                # IoU is computed per-image so that images with different
                # class sets contribute comparably to the mean
                for i in range(pred_map.shape[0]):
                    iou = utils.compute_iou(
                        np.array(pred_map[i].cpu()),
                        np.array(target[i].cpu(), dtype=np.int64),
                        cfg.num_classes,
                        fg_only=cfg.METRIC.SEGMENTATION.fg_only)
                    iou_list.append(float(iou))
            else:
                raise NotImplementedError

    test_loss /= len(test_loader.dataset)

    if cfg.task == "classification":
        acc = 100. * correct / len(test_loader.dataset)
        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
            test_loss, correct, len(test_loader.dataset), acc))
        return acc
    elif cfg.task == "semantic_segmentation":
        m_iou = np.mean(iou_list)
        print('\nTest set: Average loss: {:.4f}, Mean Pixel Accuracy: {:.4f}, Mean IoU: {:.4f}'.format(
            test_loss, np.mean(pixel_acc_list), m_iou))
        return m_iou
    else:
        raise NotImplementedError
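
# For reference, a minimal sketch of the two metric helpers called above. This
# is NOT the repo's actual `utils` module: the signatures are inferred from the
# call sites (prediction/target, `num_classes`, `fg_only`), and the convention
# that label 0 is background is an assumption.

def _sketch_compute_pixel_acc(pred_map, target, fg_only=False):
    # Pixel accuracy over a batch of label maps; optionally restricted to
    # foreground pixels (assumed convention: label 0 is background).
    if fg_only:
        mask = target > 0
    else:
        mask = torch.ones_like(target, dtype=torch.bool)
    num_valid = int(mask.sum())
    num_correct = int((pred_map[mask] == target[mask]).sum())
    return num_correct / max(num_valid, 1), num_valid


def _sketch_compute_iou(pred_np, target_np, num_classes, fg_only=False):
    # Mean IoU over classes for a single image, skipping classes absent from
    # both prediction and ground truth.
    start_cls = 1 if fg_only else 0  # assumed: class 0 is background
    ious = []
    for cls in range(start_cls, num_classes):
        pred_mask = pred_np == cls
        gt_mask = target_np == cls
        union = np.logical_or(pred_mask, gt_mask).sum()
        if union == 0:
            continue
        intersection = np.logical_and(pred_mask, gt_mask).sum()
        ious.append(intersection / union)
    return float(np.mean(ious)) if ious else 0.0
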
def train(cfg, model, post_processor, criterion, device, train_loader, optimizer, epoch):
    model.train()
    post_processor.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        feature = model(data)
        if cfg.task == "classification":
            output = post_processor(feature)
        elif cfg.task in ("semantic_segmentation", "few_shot_semantic_segmentation_fine_tuning"):
            ori_spatial_res = data.shape[-2:]
            output = post_processor(feature, ori_spatial_res)
        loss = criterion(output, target)
        optimizer.zero_grad()  # reset gradient
        loss.backward()
        optimizer.step()
        if cfg.task == "classification":
            if batch_idx % cfg.TRAIN.log_interval == 0:
                pred = output.argmax(dim=1, keepdim=True)
                correct_prediction = pred.eq(target.view_as(pred)).sum().item()
                batch_acc = correct_prediction / data.shape[0]
                print('Train Epoch: {0} [{1}/{2} ({3:.0f}%)]\tLoss: {4:.6f}\tBatch Acc: {5:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item(), batch_acc))
        elif cfg.task in ("semantic_segmentation", "few_shot_semantic_segmentation_fine_tuning"):
            if batch_idx % cfg.TRAIN.log_interval == 0:
                pred_map = output.max(dim=1)[1]
                batch_acc, _ = utils.compute_pixel_acc(
                    pred_map, target, fg_only=cfg.METRIC.SEGMENTATION.fg_only)
                print('Train Epoch: {0} [{1}/{2} ({3:.0f}%)]\tLoss: {4:.6f}\tBatch Pixel Acc: {5:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item(), batch_acc))
        else:
            raise NotImplementedError
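
# A minimal sketch of how `train` and the first `test` (the variant without
# `visfreq`) might be wired together in a main entry point. The config fields
# `cfg.TRAIN.max_epochs` and `cfg.TRAIN.initial_lr`, the optimizer choice, and
# the checkpoint path are hypothetical; only `train`/`test` above are real.

def _sketch_main(cfg, model, post_processor, criterion, device, train_loader, test_loader):
    # Optimize backbone and post-processor parameters jointly (assumed setup)
    optimizer = torch.optim.SGD(
        list(model.parameters()) + list(post_processor.parameters()),
        lr=cfg.TRAIN.initial_lr, momentum=0.9)
    best_metric = 0.0
    for epoch in range(1, cfg.TRAIN.max_epochs + 1):
        train(cfg, model, post_processor, criterion, device, train_loader, optimizer, epoch)
        metric = test(cfg, model, post_processor, criterion, device, test_loader)
        if metric > best_metric:
            # Keep the checkpoint with the best validation metric so far
            best_metric = metric
            torch.save({'model': model.state_dict(),
                        'post_processor': post_processor.state_dict()},
                       'best_checkpoint.pt')
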
def test(cfg, model, post_processor, criterion, device, test_loader, visfreq):
    model.eval()
    post_processor.eval()
    test_loss = 0
    correct = 0
    pixel_acc_list = []
    iou_list = []
    with torch.no_grad():
        for idx, (data, target) in enumerate(test_loader):
            data, target = data.to(device), target.to(device)
            feature = model(data)
            if cfg.task == "classification":
                output = post_processor(feature)
            elif cfg.task == "semantic_segmentation":
                # Upsample logits to the input resolution, matching train()
                ori_spatial_res = data.shape[-2:]
                output = post_processor(feature, ori_spatial_res)
            test_loss += criterion(output, target).item()  # sum up batch loss
            if cfg.task == "classification":
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()
                # TODO: save classified images with the raw image as content
                # and a human-readable label as the filename
            elif cfg.task == "semantic_segmentation":
                pred_map = output.max(dim=1)[1]
                batch_acc, _ = utils.compute_pixel_acc(
                    pred_map, target, fg_only=cfg.METRIC.SEGMENTATION.fg_only)
                pixel_acc_list.append(float(batch_acc))
                for i in range(pred_map.shape[0]):
                    pred_np = np.array(pred_map[i].cpu())
                    target_np = np.array(target[i].cpu(), dtype=np.int64)
                    iou = utils.compute_iou(
                        pred_np, target_np, cfg.num_classes,
                        fg_only=cfg.METRIC.SEGMENTATION.fg_only)
                    iou_list.append(float(iou))
                    if (i + 1) % visfreq == 0:
                        cv2.imwrite("{}_{}_pred.png".format(idx, i), pred_np)
                        cv2.imwrite("{}_{}_label.png".format(idx, i), target_np)
                        # Visualize the RGB image as well
                        ori_rgb_np = np.array(data[i].permute((1, 2, 0)).cpu())
                        if 'normalize' in cfg.DATASET.TRANSFORM.TEST.transforms:
                            # Undo test-time normalization before saving
                            rgb_mean = cfg.DATASET.TRANSFORM.TEST.TRANSFORMS_DETAILS.NORMALIZE.mean
                            rgb_sd = cfg.DATASET.TRANSFORM.TEST.TRANSFORMS_DETAILS.NORMALIZE.sd
                            ori_rgb_np = (ori_rgb_np * rgb_sd) + rgb_mean
                        assert ori_rgb_np.max() <= 1.1, "Max is {}".format(ori_rgb_np.max())
                        ori_rgb_np[ori_rgb_np >= 1] = 1
                        ori_rgb_np = (ori_rgb_np * 255).astype(np.uint8)
                        # Convert from RGB to OpenCV's BGR channel order
                        ori_rgb_np = cv2.cvtColor(ori_rgb_np, cv2.COLOR_RGB2BGR)
                        cv2.imwrite("{}_{}_ori.jpg".format(idx, i), ori_rgb_np)
            else:
                raise NotImplementedError

    test_loss /= len(test_loader.dataset)

    if cfg.task == "classification":
        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, len(test_loader.dataset),
            100. * correct / len(test_loader.dataset)))
    elif cfg.task == "semantic_segmentation":
        print('\nTest set: Average loss: {:.4f}, Mean Pixel Accuracy: {:.4f}, Mean IoU: {:.4f}\n'.format(
            test_loss, np.mean(pixel_acc_list), np.mean(iou_list)))
    else:
        raise NotImplementedError
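
# The inline tensor-to-image conversion above is a common pattern worth
# isolating. A minimal sketch of it as a standalone helper, assuming a CHW
# float tensor in normalized space and per-channel mean/sd sequences; this
# helper is an illustration, not part of the original code.

def _sketch_tensor_to_bgr(img_chw, rgb_mean=None, rgb_sd=None):
    # CHW tensor -> HWC numpy array
    img = np.array(img_chw.permute((1, 2, 0)).cpu())
    if rgb_mean is not None and rgb_sd is not None:
        # Undo channel-wise normalization: x = x_norm * sd + mean
        img = (img * np.array(rgb_sd)) + np.array(rgb_mean)
    # Clip to [0, 1], scale to 8-bit, and reorder RGB -> BGR for cv2.imwrite
    img = np.clip(img, 0.0, 1.0)
    img = (img * 255).astype(np.uint8)
    return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
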