def test_seg_model(args):
    """Evaluate a trained segmentation model on the ``val`` split and report Dice.

    Builds the network selected by ``args.model_name`` ("UNet" or "PSP"),
    wraps it in ``nn.DataParallel`` *before* loading
    ``<args.model_dir>/<args.best_model>`` (so the checkpoint keys must match a
    DataParallel-wrapped model), then runs it over the loader returned by
    ``gen_dloader(<args.data_dir>/val, args.batch_size, mode="val")``.

    Args:
        args: namespace providing ``model_name``, ``in_channels``,
            ``class_num``, ``model_dir``, ``best_model``, ``data_dir`` and
            ``batch_size``.

    Returns:
        ``metrics['dice'] / number_of_samples`` (presumably ``calc_loss``
        accumulates a per-sample Dice sum -- confirm), or ``None`` when the
        loader yielded no samples.

    Raises:
        NotImplementedError: if ``args.model_name`` is not a supported network.
    """
    if args.model_name == "UNet":
        model = UNet(n_channels=args.in_channels, n_classes=args.class_num)
    elif args.model_name == "PSP":
        model = pspnet.PSPNet(n_classes=19, input_size=(448, 448))
        # Swap the 19-class head for one sized to this task.
        model.classification = nn.Conv2d(512, args.class_num, kernel_size=1)
    else:
        # ``NotImplemented`` is a constant, not an exception class; calling it
        # raised a TypeError that hid this message.
        raise NotImplementedError("Unknown model {}".format(args.model_name))

    model_path = os.path.join(args.model_dir, args.best_model)
    model = nn.DataParallel(model)
    model.load_state_dict(torch.load(model_path))
    model.cuda()
    model.eval()

    print('--------Start testing--------')
    since = time.time()
    dloader = gen_dloader(os.path.join(args.data_dir, "val"), args.batch_size, mode="val")
    metrics = defaultdict(float)
    ttl_samples = 0
    with torch.no_grad():
        for batch_ind, (imgs, masks) in enumerate(dloader):
            if batch_ind != 0 and batch_ind % 100 == 0:
                print("Processing {}/{}".format(batch_ind, len(dloader)))
            inputs = imgs.cuda()
            masks = masks.cuda()
            outputs = model(inputs)
            # Only the statistics that calc_loss accumulates into ``metrics``
            # are needed here; the returned loss is discarded.
            calc_loss(outputs, masks, metrics)
            ttl_samples += inputs.size(0)

    time_elapsed = time.time() - since
    print('Testing takes {:.0f}m {:.2f}s'.format(time_elapsed // 60, time_elapsed % 60))
    if ttl_samples == 0:
        # Avoid ZeroDivisionError on an empty validation set.
        print("----No validation samples found; Dice coefficient is undefined")
        return None
    avg_dice = metrics['dice'] / ttl_samples
    print("----Dice coefficient is: {:.3f}".format(avg_dice))
    return avg_dice
def load_seg_model(args):
    """Build the segmentation network named by ``args.seg_model_name`` and load its weights.

    The network is wrapped in ``nn.DataParallel`` before
    ``<args.model_dir>/SegBestModel/<args.best_seg_model>`` is loaded, so the
    checkpoint keys must match a DataParallel-wrapped model. The returned model
    is on the GPU and in eval mode.

    Args:
        args: namespace providing ``seg_model_name`` ("UNet" or "PSP"),
            ``in_channels``, ``seg_class_num``, ``patch_len``, ``model_dir``
            and ``best_seg_model``.

    Returns:
        The loaded ``nn.DataParallel`` model.

    Raises:
        NotImplementedError: if ``args.seg_model_name`` is not supported.
    """
    if args.seg_model_name == "UNet":
        seg_model = UNet(n_channels=args.in_channels, n_classes=args.seg_class_num)
    elif args.seg_model_name == "PSP":
        seg_model = pspnet.PSPNet(n_classes=19, input_size=(args.patch_len, args.patch_len))
        # Swap the 19-class head for one sized to this task.
        seg_model.classification = nn.Conv2d(512, args.seg_class_num, kernel_size=1)
    else:
        # ``NotImplemented`` is not callable; the intended error is NotImplementedError.
        raise NotImplementedError("Unknown model {}".format(args.seg_model_name))

    seg_model_path = os.path.join(args.model_dir, "SegBestModel", args.best_seg_model)
    seg_model = nn.DataParallel(seg_model)
    seg_model.load_state_dict(torch.load(seg_model_path))
    seg_model.cuda()
    seg_model.eval()
    return seg_model
def train_seg_model(args):
    """Train a segmentation network, validating and checkpointing every epoch.

    Each epoch runs a ``train`` phase over ``<data_dir>/<tumor_type>/train``
    followed by a ``val`` phase over ``<data_dir>/<tumor_type>/val``.
    After the ``val`` phase the DataParallel state dict is saved to
    ``<model_dir>/<tumor_type>/<session>/<model_name>-<epoch:03d>-<dice:.3f>.pth``
    when the validation Dice improves on the best so far, and unconditionally
    for the final four epochs.

    Args:
        args: namespace providing ``model_name`` ("UNet"/"PSP"), ``in_channels``,
            ``class_num``, ``optim_name`` ("Adam"/"SGD"), ``init_lr``,
            ``maxepoch``, ``data_dir``, ``tumor_type``, ``batch_size``,
            ``normalize``, ``model_dir``, ``session`` and ``bce_weight``.

    Raises:
        AssertionError: if ``args.model_name`` or ``args.optim_name`` is not
            supported (kept as AssertionError for existing callers).
    """
    # model
    if args.model_name == "UNet":
        model = UNet(n_channels=args.in_channels, n_classes=args.class_num)
    elif args.model_name == "PSP":
        model = pspnet.PSPNet(n_classes=19, input_size=(512, 512))
        model.load_pretrained_model(
            model_path="./segnet/pspnet/pspnet101_cityscapes.caffemodel")
        # Swap the 19-class head for one sized to this task.
        model.classification = nn.Conv2d(512, args.class_num, kernel_size=1)
    else:
        raise AssertionError("Unknown model: {}".format(args.model_name))
    model = nn.DataParallel(model)
    model.cuda()

    # optimizer (only parameters that require gradients)
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    if args.optim_name == "Adam":
        # NOTE(review): Adam ignores args.init_lr and always starts at 1e-3;
        # confirm this is intentional.
        optimizer = optim.Adam(trainable_params, lr=1.0e-3)
    elif args.optim_name == "SGD":
        optimizer = optim.SGD(trainable_params, lr=args.init_lr,
                              momentum=0.9, weight_decay=0.0005)
    else:
        raise AssertionError("Unknown optimizer: {}".format(args.optim_name))
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=LambdaLR(args.maxepoch, 0, 0).step)

    # dataloaders
    train_data_dir = os.path.join(args.data_dir, args.tumor_type, "train")
    train_dloader = gen_dloader(train_data_dir, args.batch_size, mode="train",
                                normalize=args.normalize, tumor_type=args.tumor_type)
    test_data_dir = os.path.join(args.data_dir, args.tumor_type, "val")
    val_dloader = gen_dloader(test_data_dir, args.batch_size, mode="val",
                              normalize=args.normalize, tumor_type=args.tumor_type)
    dloaders = {'train': train_dloader, 'val': val_dloader}

    # training
    save_model_dir = os.path.join(args.model_dir, args.tumor_type, args.session)
    os.makedirs(save_model_dir, exist_ok=True)

    best_dice = 0.0
    for epoch in range(args.maxepoch):
        print('Epoch {}/{}'.format(epoch + 1, args.maxepoch))
        print('-' * 10)
        since = time.time()

        for phase in ['train', 'val']:
            if phase == 'train':
                for param_group in optimizer.param_groups:
                    print("Current LR: {:.8f}".format(param_group['lr']))
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            metrics = defaultdict(float)
            epoch_samples = 0
            for imgs, masks in dloaders[phase]:
                inputs = imgs.cuda()
                masks = masks.cuda()
                optimizer.zero_grad()
                # Track gradient history only while training.
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = calc_loss(outputs, masks, metrics, bce_weight=args.bce_weight)
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                epoch_samples += inputs.size(0)

            if phase == 'train':
                # Since PyTorch 1.1 the LR scheduler must step *after* the
                # optimizer; stepping it first skipped the schedule's first value.
                scheduler.step()

            print_metrics(metrics, epoch_samples, phase)
            epoch_dice = metrics['dice'] / epoch_samples

            # Save on improvement, and always during the last four epochs.
            if phase == 'val' and (epoch_dice > best_dice or epoch > args.maxepoch - 5):
                best_dice = max(best_dice, epoch_dice)
                best_model_name = "-".join([
                    args.model_name, "{:03d}-{:.3f}.pth".format(epoch, epoch_dice)
                ])
                # torch.save serializes immediately, so no deep copy is needed.
                torch.save(model.state_dict(), os.path.join(save_model_dir, best_model_name))

        time_elapsed = time.time() - since
        print('Epoch {:2d} takes {:.0f}m {:.0f}s'.format(
            epoch, time_elapsed // 60, time_elapsed % 60))
        print(
            "================================================================================"
        )
    print("Training finished...")
def test_slide_seg(args):
    """Predict a tumor mask for every whole-slide image and write it to disk.

    For each name returned by ``get_slide_filenames(args.slides_dir)`` the
    ``<name>.svs`` file (falling back to ``<name>.SVS``) is read at pyramid
    level ``args.slide_level`` and split into patches with
    ``wsi_stride_splitting``. Sigmoid predictions are summed per pixel and
    divided by the weight map from ``gen_patch_wmap``, thresholded at 0.5, and
    saved as ``<result_dir>/<tumor_type>/<name>_<tumor_type>.tif`` (0/255).
    When ``args.save_org`` is set and ``args.tumor_type == "viable"`` the
    probability map is also resized to level-0 dimensions, thresholded and
    saved under ``Level0/<last three characters of name>.tif``.

    Args:
        args: namespace providing ``model_name`` ("UNet"/"PSP"), ``in_channels``,
            ``class_num``, ``model_dir``, ``tumor_type``, ``split``,
            ``best_model``, ``result_dir``, ``slides_dir``, ``save_org``,
            ``slide_level``, ``patch_len``, ``stride_len``, ``normalize`` and
            ``batch_size``.

    Raises:
        AssertionError: if ``args.model_name`` is not supported.
    """
    if args.model_name == "UNet":
        model = UNet(n_channels=args.in_channels, n_classes=args.class_num)
    elif args.model_name == "PSP":
        model = pspnet.PSPNet(n_classes=19, input_size=(512, 512))
        model.classification = nn.Conv2d(512, args.class_num, kernel_size=1)
    else:
        raise AssertionError("Unknown model: {}".format(args.model_name))
    model_path = os.path.join(args.model_dir, args.tumor_type, args.split, args.best_model)
    model = nn.DataParallel(model)
    model.load_state_dict(torch.load(model_path))
    model.cuda()
    model.eval()

    since = time.time()
    result_dir = os.path.join(args.result_dir, args.tumor_type)
    filesystem.overwrite_dir(result_dir)
    slide_names = get_slide_filenames(args.slides_dir)
    save_org = args.save_org and args.tumor_type == "viable"
    if save_org:
        org_result_dir = os.path.join(result_dir, "Level0")
        filesystem.overwrite_dir(org_result_dir)

    plen = args.patch_len
    for num, cur_slide in enumerate(slide_names):
        print("--{:02d}/{:02d} Slide:{}".format(num + 1, len(slide_names), cur_slide))
        # Accept either a lower- or upper-case .svs extension.
        slide_path = os.path.join(args.slides_dir, cur_slide + ".svs")
        if not os.path.exists(slide_path):
            slide_path = os.path.join(args.slides_dir, cur_slide + ".SVS")
        wsi_head = pyramid.load_wsi_head(slide_path)
        p_level = args.slide_level
        level_dims = wsi_head.level_dimensions[p_level]
        pred_h, pred_w = level_dims[1], level_dims[0]
        slide_img = wsi_head.read_region((0, 0), p_level, level_dims)
        # Keep the first three channels (presumably dropping alpha -- confirm).
        slide_img = np.asarray(slide_img)[:, :, :3]

        coors_arr = wsi_stride_splitting(pred_h, pred_w, patch_len=plen,
                                         stride_len=args.stride_len)
        patch_arr, wmap = gen_patch_wmap(slide_img, coors_arr, plen=plen)
        patch_dset = PatchDataset(patch_arr, mask_arr=None, normalize=args.normalize,
                                  tumor_type=args.tumor_type)
        patch_loader = DataLoader(patch_dset, batch_size=args.batch_size, shuffle=False,
                                  num_workers=4, drop_last=False)
        pred_map = np.zeros_like(wmap).astype(np.float32)
        with torch.no_grad():
            for batch_ind, patches in enumerate(patch_loader):
                outputs = model(patches.cuda())
                # NOTE(review): squeezing dim 1 assumes a single output channel.
                preds = torch.squeeze(torch.sigmoid(outputs), dim=1).cpu().numpy()
                # shuffle=False keeps batches in coors_arr order; slicing past
                # the end simply returns the shorter final chunk.
                start = batch_ind * args.batch_size
                patch_coors = coors_arr[start:start + args.batch_size]
                for patch_ind, coor in enumerate(patch_coors):
                    ph, pw = coor[0], coor[1]
                    pred_map[ph:ph + plen, pw:pw + plen] += preds[patch_ind]

        # Average overlapping predictions by the per-pixel patch weight map.
        prob_pred = np.divide(pred_map, wmap)
        slide_pred = (prob_pred > 0.5).astype(np.uint8)
        pred_save_path = os.path.join(result_dir, cur_slide + "_" + args.tumor_type + ".tif")
        io.imsave(pred_save_path, slide_pred * 255)
        if save_org:
            org_w, org_h = wsi_head.level_dimensions[0]
            org_pred = transform.resize(prob_pred, (org_h, org_w))
            org_pred = (org_pred > 0.5).astype(np.uint8)
            org_save_path = os.path.join(org_result_dir, cur_slide[-3:] + ".tif")
            imsave(org_save_path, org_pred, compress=9)

    time_elapsed = time.time() - since
    print('Testing takes {:.0f}m {:.2f}s'.format(time_elapsed // 60, time_elapsed % 60))
def test_slide_seg(args):
    """Predict and save a binary mask for every ``jpg`` image in ``args.slides_dir``.

    Each image is split into patches with ``wsi_stride_splitting``. Patches are
    scaled to [0, 1] and run through the model ``args.batch_size`` at a time.
    Sigmoid outputs are summed per pixel and divided by the number of patches
    covering that pixel, then thresholded at 0.5. Connected components smaller
    than ``min_size=20480`` pixels are removed, and the result is written as
    ``<args.result_dir>/<image stem>.png`` (0/255).

    Args:
        args: namespace providing ``model_name`` ("UNet"/"PSP"), ``in_channels``,
            ``class_num``, ``patch_len``, ``model_dir``, ``best_model``,
            ``result_dir``, ``slides_dir``, ``stride_len`` and ``batch_size``.

    Raises:
        NotImplementedError: if ``args.model_name`` is not supported.
    """
    if args.model_name == "UNet":
        model = UNet(n_channels=args.in_channels, n_classes=args.class_num)
    elif args.model_name == "PSP":
        model = pspnet.PSPNet(n_classes=19, input_size=(args.patch_len, args.patch_len))
        model.classification = nn.Conv2d(512, args.class_num, kernel_size=1)
    else:
        # ``NotImplemented`` is not callable; the intended error is NotImplementedError.
        raise NotImplementedError("Unknown model {}".format(args.model_name))
    model_path = os.path.join(args.model_dir, args.best_model)
    model = nn.DataParallel(model)
    model.load_state_dict(torch.load(model_path))
    model.cuda()
    model.eval()

    since = time.time()
    pydaily.filesystem.overwrite_dir(args.result_dir)
    slide_names = [ele for ele in os.listdir(args.slides_dir) if "jpg" in ele]
    plen = args.patch_len
    for num, cur_slide in enumerate(slide_names):
        print("--{:2d}/{:2d} Slide:{}".format(num + 1, len(slide_names), cur_slide))
        start_time = timer()
        slide_path = os.path.join(args.slides_dir, cur_slide)
        slide_img = io.imread(slide_path)
        img_h, img_w = slide_img.shape[0], slide_img.shape[1]

        # split and predict
        coors_arr = wsi_stride_splitting(img_h, img_w, patch_len=plen,
                                         stride_len=args.stride_len)
        wmap = np.zeros((img_h, img_w), dtype=np.int32)  # patches covering each pixel
        pred_map = np.zeros_like(wmap).astype(np.float32)
        patch_list, coor_list = [], []
        for ic, coor in enumerate(coors_arr):
            ph, pw = coor[0], coor[1]
            patch_list.append(slide_img[ph:ph + plen, pw:pw + plen] / 255.0)
            coor_list.append([ph, pw])
            wmap[ph:ph + plen, pw:pw + plen] += 1
            # Flush when a full batch is buffered or the last patch is reached.
            if len(patch_list) == args.batch_size or ic + 1 == len(coors_arr):
                patch_arr = np.asarray(patch_list).astype(np.float32)
                patch_dset = PatchDataset(patch_arr)
                # The buffer holds at most one batch, so load it in the main
                # process rather than spawning worker processes per batch.
                patch_loader = DataLoader(patch_dset, batch_size=args.batch_size,
                                          shuffle=False, num_workers=0, drop_last=False)
                with torch.no_grad():
                    pred_list = []
                    for patches in patch_loader:
                        outputs = model(patches.cuda())
                        preds = torch.squeeze(torch.sigmoid(outputs), dim=1)
                        pred_list.append(preds.cpu().numpy())
                batch_preds = np.concatenate(pred_list, axis=0)
                for ind, (bh, bw) in enumerate(coor_list):
                    pred_map[bh:bh + plen, bw:bw + plen] += batch_preds[ind]
                patch_list, coor_list = [], []

        prob_pred = np.divide(pred_map, wmap)
        slide_pred = morphology.remove_small_objects(prob_pred > 0.5, min_size=20480).astype(
            np.uint8)
        pred_save_path = os.path.join(args.result_dir, os.path.splitext(cur_slide)[0] + ".png")
        io.imsave(pred_save_path, slide_pred * 255)
        end_time = timer()
        print("Takes {}".format(
            pydaily.tic.time_to_str(end_time - start_time, 'sec')))

    time_elapsed = time.time() - since
    print("stride-len: {} with batch-size: {}".format(args.stride_len, args.batch_size))
    print("Testing takes {:.0f}m {:.2f}s".format(time_elapsed // 60, time_elapsed % 60))
# forward # track history if only in train with torch.no_grad(): outputs = model(inputs) loss = calc_loss(outputs, labels, metrics) # statistics epoch_samples += inputs.size(0) print_metrics(metrics, epoch_samples, "test") if __name__ == '__main__': args = set_args() os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu # create model model = None if args.network == "UNet": elif args.network == "PSP": model = pspnet.PSPNet(n_classes=19, input_size=(160, 160)) model.load_pretrained_model(model_path="./segnet/pspnet/pspnet101_cityscapes.caffemodel") model.classification = nn.Conv2d(512, args.class_num, kernel_size=1) else: raise Exception("Unknow network: {}".format(args.network)) print("Net: {} session: {} model name: {}".format(args.network, args.session, args.model_name)) model_path = os.path.join(args.model_dir, args.simu_type+args.network, args.session, args.model_name) model.load_state_dict(torch.load(model_path)) model.cuda() # train model test_seg_model(model, args)
def test_slide_seg(args):
    """Segment every ``jpg`` image in ``args.slides_dir`` and score it against its mask.

    The ground-truth mask for ``<stem>.jpg`` is read from ``<stem>.png`` in the
    same directory, and both are scaled by 1/255. Prediction uses overlapping
    patches from ``wsi_stride_splitting``, averaged by per-pixel coverage and
    thresholded at 0.5. Connected components smaller than ``min_size=20480``
    pixels are then removed. The per-image score is
    ``|A*B| / (|A| + |B| - |A*B|)`` -- the Jaccard index (IoU) for binary
    masks -- and the mean over images is printed.

    Args:
        args: namespace providing ``patch_len``, ``class_num``, ``model_dir``,
            ``best_model``, ``slides_dir``, ``stride_len`` and ``batch_size``.

    Returns:
        The mean Jaccard index over all images, or ``None`` when no ``jpg``
        images were found.
    """
    model = pspnet.PSPNet(n_classes=19, input_size=(args.patch_len, args.patch_len))
    model.classification = nn.Conv2d(512, args.class_num, kernel_size=1)
    model_path = os.path.join(args.model_dir, args.best_model)
    model = nn.DataParallel(model)
    model.load_state_dict(torch.load(model_path))
    model.cuda()
    model.eval()

    since = time.time()
    slide_names = [ele for ele in os.listdir(args.slides_dir) if "jpg" in ele]
    plen = args.patch_len
    ttl_jaccard = 0.0
    for num, cur_slide in enumerate(slide_names):
        # load slide image and mask
        slide_path = os.path.join(args.slides_dir, cur_slide)
        slide_img = io.imread(slide_path) / 255.0
        mask_path = os.path.join(args.slides_dir, os.path.splitext(cur_slide)[0] + ".png")
        # NOTE(review): assumes a single-channel 0/255 mask with the image's H x W.
        mask_img = io.imread(mask_path) / 255.0

        # split and predict
        img_h, img_w = slide_img.shape[0], slide_img.shape[1]
        coors_arr = wsi_stride_splitting(img_h, img_w, patch_len=plen,
                                         stride_len=args.stride_len)
        wmap = np.zeros((img_h, img_w), dtype=np.int32)  # patches covering each pixel
        pred_map = np.zeros_like(wmap).astype(np.float32)
        patch_list, coor_list = [], []
        for ic, coor in enumerate(coors_arr):
            ph, pw = coor[0], coor[1]
            patch_list.append(slide_img[ph:ph + plen, pw:pw + plen])
            coor_list.append([ph, pw])
            wmap[ph:ph + plen, pw:pw + plen] += 1
            # Flush when a full batch is buffered or the last patch is reached.
            if len(patch_list) == args.batch_size or ic + 1 == len(coors_arr):
                patch_arr = np.asarray(patch_list).astype(np.float32)
                patch_dset = PatchDataset(patch_arr)
                # The buffer holds at most one batch, so load it in the main
                # process rather than spawning worker processes per batch.
                patch_loader = DataLoader(patch_dset, batch_size=args.batch_size,
                                          shuffle=False, num_workers=0, drop_last=False)
                with torch.no_grad():
                    pred_list = []
                    for patches in patch_loader:
                        outputs = model(patches.cuda())
                        preds = torch.squeeze(torch.sigmoid(outputs), dim=1)
                        pred_list.append(preds.cpu().numpy())
                batch_preds = np.concatenate(pred_list, axis=0)
                for ind, (bh, bw) in enumerate(coor_list):
                    pred_map[bh:bh + plen, bw:bw + plen] += batch_preds[ind]
                patch_list, coor_list = [], []

        prob_pred = np.divide(pred_map, wmap)
        slide_pred = morphology.remove_small_objects(prob_pred > 0.5, min_size=20480).astype(
            np.uint8)

        # intersection / union (Jaccard index, not Dice); 1e-8 guards 0/0.
        intersection = np.multiply(mask_img, slide_pred)
        jaccard = np.sum(intersection) / (
            np.sum(mask_img) + np.sum(slide_pred) - np.sum(intersection) + 1.0e-8)
        ttl_jaccard += jaccard
        print("--{:2d}/{:2d} Slide:{} JI:{:.3f}".format(
            num + 1, len(slide_names), cur_slide, jaccard))

    time_elapsed = time.time() - since
    print('Testing takes {:.0f}m {:.2f}s'.format(time_elapsed // 60, time_elapsed % 60))
    if not slide_names:
        # Avoid ZeroDivisionError when the directory has no jpg images.
        print("No jpg slides found in {}".format(args.slides_dir))
        return None
    avg_jaccard = ttl_jaccard / len(slide_names)
    # The metric computed above is the Jaccard index; the old message
    # mislabelled it as a Dice coefficient.
    print('Slide-level average Jaccard index is {:.3f}'.format(avg_jaccard))
    return avg_jaccard