Пример #1
0
def test_seg_model(args):
    """Evaluate a trained segmentation checkpoint on the validation split.

    Builds the network selected by ``args.model_name``, loads the weights
    stored at ``<args.model_dir>/<args.best_model>`` and reports the average
    Dice coefficient accumulated by ``calc_loss`` over the images served by
    ``gen_dloader(<args.data_dir>/val, ...)``.

    Args:
        args: Namespace providing ``model_name``, ``in_channels``,
            ``class_num``, ``model_dir``, ``best_model``, ``data_dir`` and
            ``batch_size``.

    Raises:
        NotImplementedError: if ``args.model_name`` is neither "UNet" nor
            "PSP".
    """
    if args.model_name == "UNet":
        model = UNet(n_channels=args.in_channels, n_classes=args.class_num)
    elif args.model_name == "PSP":
        model = pspnet.PSPNet(n_classes=19, input_size=(448, 448))
        # Replace the classification head so it emits ``class_num`` maps.
        model.classification = nn.Conv2d(512, args.class_num, kernel_size=1)
    else:
        # ``NotImplemented`` is a comparison sentinel, not an exception class;
        # calling it raised TypeError instead of this message.
        raise NotImplementedError("Unknown model {}".format(args.model_name))
    model_path = os.path.join(args.model_dir, args.best_model)
    # Wrap before loading so parameter names match a checkpoint that was
    # (presumably) saved from a DataParallel model.
    model = nn.DataParallel(model)
    model.load_state_dict(torch.load(model_path))
    model.cuda()
    model.eval()

    print('--------Start testing--------')
    since = time.time()
    dloader = gen_dloader(os.path.join(args.data_dir, "val"),
                          args.batch_size,
                          mode="val")

    metrics = defaultdict(float)
    ttl_samples = 0
    for batch_ind, (imgs, masks) in enumerate(dloader):
        if batch_ind != 0 and batch_ind % 100 == 0:
            print("Processing {}/{}".format(batch_ind, len(dloader)))
        inputs = imgs.cuda()
        masks = masks.cuda()

        with torch.no_grad():
            outputs = model(inputs)
            # calc_loss is the only writer of ``metrics`` visible here; it is
            # expected to accumulate the per-batch 'dice' sum used below.
            calc_loss(outputs, masks, metrics)

        ttl_samples += inputs.size(0)
    avg_dice = metrics['dice'] / ttl_samples
    time_elapsed = time.time() - since
    print('Testing takes {:.0f}m {:.2f}s'.format(time_elapsed // 60,
                                                 time_elapsed % 60))
    print("----Dice coefficient is: {:.3f}".format(avg_dice))
def load_seg_model(args):
    """Build the segmentation network and load its best checkpoint.

    The architecture is chosen by ``args.seg_model_name`` ("UNet" or "PSP").
    Weights are read from
    ``<args.model_dir>/SegBestModel/<args.best_seg_model>`` into an
    ``nn.DataParallel`` wrapper, which is then moved to CUDA and switched to
    eval mode.

    Args:
        args: Namespace providing ``seg_model_name``, ``in_channels``,
            ``seg_class_num``, ``patch_len``, ``model_dir`` and
            ``best_seg_model``.

    Returns:
        The ``nn.DataParallel``-wrapped model on CUDA, in eval mode.

    Raises:
        NotImplementedError: if ``args.seg_model_name`` is not supported.
    """
    if args.seg_model_name == "UNet":
        seg_model = UNet(n_channels=args.in_channels,
                         n_classes=args.seg_class_num)
    elif args.seg_model_name == "PSP":
        seg_model = pspnet.PSPNet(n_classes=19,
                                  input_size=(args.patch_len, args.patch_len))
        # Replace the classification head so it emits ``seg_class_num`` maps.
        seg_model.classification = nn.Conv2d(512,
                                             args.seg_class_num,
                                             kernel_size=1)
    else:
        # ``NotImplemented`` is a comparison sentinel, not an exception class;
        # calling it raised TypeError instead of this message.
        raise NotImplementedError(
            "Unknown model {}".format(args.seg_model_name))

    seg_model_path = os.path.join(args.model_dir, "SegBestModel",
                                  args.best_seg_model)
    # Wrap before loading so parameter names match a checkpoint that was
    # (presumably) saved from a DataParallel model.
    seg_model = nn.DataParallel(seg_model)
    seg_model.load_state_dict(torch.load(seg_model_path))
    seg_model.cuda()
    seg_model.eval()

    return seg_model
Пример #3
0
def train_seg_model(args):
    """Train a segmentation network, checkpointing on validation Dice.

    Each epoch runs a training pass followed by a validation pass.  A
    checkpoint named ``<model_name>-<epoch>-<dice>.pth`` is written to
    ``<args.model_dir>/<args.tumor_type>/<args.session>`` whenever the
    validation Dice beats the best seen so far, and unconditionally for each
    of the last four epochs.

    Args:
        args: Namespace providing model options (``model_name``,
            ``in_channels``, ``class_num``), optimizer options
            (``optim_name``, ``init_lr``), the schedule length
            (``maxepoch``), data options (``data_dir``, ``tumor_type``,
            ``batch_size``, ``normalize``, ``bce_weight``) and the output
            location (``model_dir``, ``session``).

    Raises:
        AssertionError: if ``model_name`` or ``optim_name`` is unsupported.
    """
    # model
    if args.model_name == "UNet":
        model = UNet(n_channels=args.in_channels, n_classes=args.class_num)
    elif args.model_name == "PSP":
        model = pspnet.PSPNet(n_classes=19, input_size=(512, 512))
        model.load_pretrained_model(
            model_path="./segnet/pspnet/pspnet101_cityscapes.caffemodel")
        # Replace the pretrained head so it emits ``class_num`` maps.
        model.classification = nn.Conv2d(512, args.class_num, kernel_size=1)
    else:
        raise AssertionError("Unknow modle: {}".format(args.model_name))
    model = nn.DataParallel(model)
    model.cuda()

    # optimizer (only over parameters that require gradients)
    trainable_params = filter(lambda p: p.requires_grad, model.parameters())
    if args.optim_name == "Adam":
        # NOTE(review): Adam uses a fixed 1e-3 and ignores ``args.init_lr``.
        optimizer = optim.Adam(trainable_params, lr=1.0e-3)
    elif args.optim_name == "SGD":
        optimizer = optim.SGD(trainable_params,
                              lr=args.init_lr,
                              momentum=0.9,
                              weight_decay=0.0005)
    else:
        raise AssertionError("Unknow optimizer: {}".format(args.optim_name))
    # The bare ``LambdaLR(...)`` is a project helper distinct from
    # ``lr_scheduler.LambdaLR``; its ``step`` serves as the per-epoch
    # multiplicative LR factor.
    scheduler = lr_scheduler.LambdaLR(
        optimizer, lr_lambda=LambdaLR(args.maxepoch, 0, 0).step)

    # dataloader
    train_data_dir = os.path.join(args.data_dir, args.tumor_type, "train")
    train_dloader = gen_dloader(train_data_dir,
                                args.batch_size,
                                mode="train",
                                normalize=args.normalize,
                                tumor_type=args.tumor_type)
    test_data_dir = os.path.join(args.data_dir, args.tumor_type, "val")
    val_dloader = gen_dloader(test_data_dir,
                              args.batch_size,
                              mode="val",
                              normalize=args.normalize,
                              tumor_type=args.tumor_type)

    # training
    save_model_dir = os.path.join(args.model_dir, args.tumor_type,
                                  args.session)
    os.makedirs(save_model_dir, exist_ok=True)
    best_dice = 0.0
    for epoch in range(args.maxepoch):
        print('Epoch {}/{}'.format(epoch + 1, args.maxepoch))
        print('-' * 10)
        since = time.time()
        for phase in ['train', 'val']:
            if phase == 'train':
                dloader = train_dloader
                for param_group in optimizer.param_groups:
                    print("Current LR: {:.8f}".format(param_group['lr']))
                model.train()  # Set model to training mode
            else:
                dloader = val_dloader
                model.eval()  # Set model to evaluate mode

            metrics = defaultdict(float)
            epoch_samples = 0
            for imgs, masks in dloader:
                inputs = imgs.cuda()
                masks = masks.cuda()
                optimizer.zero_grad()

                # Track gradients only during the training pass.
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = calc_loss(outputs,
                                     masks,
                                     metrics,
                                     bce_weight=args.bce_weight)
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                # statistics
                epoch_samples += inputs.size(0)

            if phase == 'train':
                # Since PyTorch 1.1 the scheduler must step after
                # optimizer.step(); stepping first skipped the initial LR.
                scheduler.step()

            print_metrics(metrics, epoch_samples, phase)
            epoch_dice = metrics['dice'] / epoch_samples

            # Checkpoint on a new best Dice, and always for the last 4 epochs.
            if phase == 'val' and (epoch_dice > best_dice
                                   or epoch > args.maxepoch - 5):
                # A forced late-epoch save must not lower the running best.
                best_dice = max(best_dice, epoch_dice)
                best_model_name = "-".join([
                    args.model_name,
                    "{:03d}-{:.3f}.pth".format(epoch, epoch_dice)
                ])
                torch.save(model.state_dict(),
                           os.path.join(save_model_dir, best_model_name))
        time_elapsed = time.time() - since
        print('Epoch {:2d} takes {:.0f}m {:.0f}s'.format(
            epoch, time_elapsed // 60, time_elapsed % 60))
    print(
        "================================================================================"
    )
    print("Training finished...")
Пример #4
0
def test_slide_seg(args):
    """Segment whole slides patch-by-patch and save binary masks as TIFF.

    For each slide name from ``get_slide_filenames(args.slides_dir)``:

    - The ``.svs`` (or ``.SVS``) file is read at pyramid level
      ``args.slide_level``.
    - The image is split by ``wsi_stride_splitting`` into ``patch_len``
      patches with step ``stride_len``, and every patch is predicted.
    - The summed sigmoid outputs are divided by the weight map from
      ``gen_patch_wmap`` (presumably per-pixel patch coverage counts) and
      thresholded at 0.5.

    Masks are saved to
    ``<args.result_dir>/<tumor_type>/<slide>_<tumor_type>.tif``.  For the
    "viable" tumor type with ``args.save_org`` set, a copy resized to level-0
    dimensions is also saved under ``Level0``.

    Raises:
        AssertionError: if ``args.model_name`` is not "UNet" or "PSP".
    """
    if args.model_name == "UNet":
        model = UNet(n_channels=args.in_channels, n_classes=args.class_num)
    elif args.model_name == "PSP":
        model = pspnet.PSPNet(n_classes=19, input_size=(512, 512))
        # Replace the classification head so it emits ``class_num`` maps.
        model.classification = nn.Conv2d(512, args.class_num, kernel_size=1)
    else:
        raise AssertionError("Unknow modle: {}".format(args.model_name))
    model_path = os.path.join(args.model_dir, args.tumor_type, args.split,
                              args.best_model)
    model = nn.DataParallel(model)
    model.load_state_dict(torch.load(model_path))
    model.cuda()
    model.eval()

    since = time.time()
    result_dir = os.path.join(args.result_dir, args.tumor_type)
    filesystem.overwrite_dir(result_dir)
    slide_names = get_slide_filenames(args.slides_dir)
    save_level0 = args.save_org and args.tumor_type == "viable"
    if save_level0:
        org_result_dir = os.path.join(result_dir, "Level0")
        filesystem.overwrite_dir(org_result_dir)

    for num, cur_slide in enumerate(slide_names):
        print("--{:02d}/{:02d} Slide:{}".format(num + 1, len(slide_names),
                                               cur_slide))
        # Slide files may carry either a lower- or an upper-case extension.
        slide_path = os.path.join(args.slides_dir, cur_slide + ".svs")
        if not os.path.exists(slide_path):
            slide_path = os.path.join(args.slides_dir, cur_slide + ".SVS")
        wsi_head = pyramid.load_wsi_head(slide_path)
        p_level = args.slide_level
        # level_dimensions entries are (width, height); arrays are (row, col).
        pred_w, pred_h = wsi_head.level_dimensions[p_level]
        slide_img = wsi_head.read_region((0, 0), p_level,
                                         wsi_head.level_dimensions[p_level])
        # Keep only the first three channels (presumably RGB out of RGBA).
        slide_img = np.asarray(slide_img)[:, :, :3]

        coors_arr = wsi_stride_splitting(pred_h, pred_w,
                                         patch_len=args.patch_len,
                                         stride_len=args.stride_len)
        patch_arr, wmap = gen_patch_wmap(slide_img, coors_arr,
                                         plen=args.patch_len)
        patch_dset = PatchDataset(patch_arr, mask_arr=None,
                                  normalize=args.normalize,
                                  tumor_type=args.tumor_type)
        patch_loader = DataLoader(patch_dset, batch_size=args.batch_size,
                                  shuffle=False, num_workers=4,
                                  drop_last=False)
        pred_map = np.zeros_like(wmap).astype(np.float32)
        # The loader is unshuffled, so batches follow coors_arr order;
        # ``start`` is the coordinate index of the first patch in each batch.
        start = 0
        with torch.no_grad():
            for patches in patch_loader:
                preds = torch.sigmoid(model(patches.cuda()))
                preds = torch.squeeze(preds, dim=1).cpu().numpy()
                batch_coors = coors_arr[start:start + len(preds)]
                for coor, pred in zip(batch_coors, preds):
                    ph, pw = coor[0], coor[1]
                    pred_map[ph:ph + args.patch_len,
                             pw:pw + args.patch_len] += pred
                start += len(preds)

        # Average overlapping predictions. Where wmap is 0 the division gives
        # NaN/inf (with a numpy RuntimeWarning).
        prob_pred = np.divide(pred_map, wmap)
        slide_pred = (prob_pred > 0.5).astype(np.uint8)
        pred_save_path = os.path.join(
            result_dir, cur_slide + "_" + args.tumor_type + ".tif")
        io.imsave(pred_save_path, slide_pred * 255)

        if save_level0:
            org_w, org_h = wsi_head.level_dimensions[0]
            org_pred = transform.resize(prob_pred, (org_h, org_w))
            org_pred = (org_pred > 0.5).astype(np.uint8)
            # Named after the last three characters of the slide name.
            org_save_path = os.path.join(org_result_dir,
                                         cur_slide[-3:] + ".tif")
            imsave(org_save_path, org_pred, compress=9)

    time_elapsed = time.time() - since
    print('Testing takes {:.0f}m {:.2f}s'.format(time_elapsed // 60,
                                                 time_elapsed % 60))
def test_slide_seg(args):
    """Segment every image in ``args.slides_dir`` and save PNG masks.

    Only files whose name contains "jpg" are processed.  For each file:

    - The image is tiled by ``wsi_stride_splitting`` into ``patch_len``
      patches with step ``stride_len``.
    - Patch pixel values are divided by 255, and patches are predicted in
      groups of ``batch_size``.
    - Overlapping sigmoid outputs are averaged using a per-pixel coverage
      count.
    - The 0.5-thresholded mask, with objects smaller than 20480 pixels
      removed, is written to ``<args.result_dir>/<image stem>.png``.

    Raises:
        NotImplementedError: if ``args.model_name`` is not "UNet" or "PSP".
    """
    if args.model_name == "UNet":
        model = UNet(n_channels=args.in_channels, n_classes=args.class_num)
    elif args.model_name == "PSP":
        model = pspnet.PSPNet(n_classes=19,
                              input_size=(args.patch_len, args.patch_len))
        # Replace the classification head so it emits ``class_num`` maps.
        model.classification = nn.Conv2d(512, args.class_num, kernel_size=1)
    else:
        # ``NotImplemented`` is a comparison sentinel, not an exception class;
        # calling it raised TypeError instead of this message.
        raise NotImplementedError("Unknown model {}".format(args.model_name))

    model_path = os.path.join(args.model_dir, args.best_model)
    model = nn.DataParallel(model)
    model.load_state_dict(torch.load(model_path))
    model.cuda()
    model.eval()

    since = time.time()
    pydaily.filesystem.overwrite_dir(args.result_dir)
    slide_names = [ele for ele in os.listdir(args.slides_dir) if "jpg" in ele]

    for num, cur_slide in enumerate(slide_names):
        print("--{:2d}/{:2d} Slide:{}".format(num + 1, len(slide_names),
                                              cur_slide))
        start_time = timer()
        # load slide image
        slide_path = os.path.join(args.slides_dir, cur_slide)
        slide_img = io.imread(slide_path)
        img_h, img_w = slide_img.shape[0], slide_img.shape[1]
        # split and predict
        coors_arr = wsi_stride_splitting(img_h,
                                         img_w,
                                         patch_len=args.patch_len,
                                         stride_len=args.stride_len)
        # wmap counts patches covering each pixel; pred_map sums their probs.
        wmap = np.zeros((img_h, img_w), dtype=np.int32)
        pred_map = np.zeros_like(wmap).astype(np.float32)

        patch_list, coor_list = [], []
        for ic, coor in enumerate(coors_arr):
            ph, pw = coor[0], coor[1]
            patch_list.append(
                slide_img[ph:ph + args.patch_len, pw:pw + args.patch_len] /
                255.0)
            coor_list.append([ph, pw])
            wmap[ph:ph + args.patch_len, pw:pw + args.patch_len] += 1
            # Run inference once a full batch is queued or at the last patch.
            if len(patch_list) == args.batch_size or ic + 1 == len(coors_arr):
                patch_arr = np.asarray(patch_list).astype(np.float32)
                patch_loader = DataLoader(PatchDataset(patch_arr),
                                          batch_size=args.batch_size,
                                          shuffle=False,
                                          num_workers=4,
                                          drop_last=False)
                with torch.no_grad():
                    pred_list = []
                    for patches in patch_loader:
                        preds = torch.sigmoid(model(patches.cuda()))
                        pred_list.append(
                            torch.squeeze(preds, dim=1).cpu().numpy())
                    batch_preds = np.concatenate(pred_list, axis=0)
                for (bh, bw), pred in zip(coor_list, batch_preds):
                    pred_map[bh:bh + args.patch_len,
                             bw:bw + args.patch_len] += pred
                patch_list, coor_list = [], []

        prob_pred = np.divide(pred_map, wmap)
        slide_pred = morphology.remove_small_objects(prob_pred > 0.5,
                                                     min_size=20480).astype(
                                                         np.uint8)
        pred_save_path = os.path.join(args.result_dir,
                                      os.path.splitext(cur_slide)[0] + ".png")
        io.imsave(pred_save_path, slide_pred * 255)
        end_time = timer()
        print("Takes {}".format(
            pydaily.tic.time_to_str(end_time - start_time, 'sec')))

    time_elapsed = time.time() - since
    print("stride-len: {} with batch-size: {}".format(args.stride_len,
                                                      args.batch_size))
    print("Testing takes {:.0f}m {:.2f}s".format(time_elapsed // 60,
                                                 time_elapsed % 60))
Пример #6
0
        # forward
        # track history if only in train
        with torch.no_grad():
            outputs = model(inputs)
            loss = calc_loss(outputs, labels, metrics)
        # statistics
        epoch_samples += inputs.size(0)
    print_metrics(metrics, epoch_samples, "test")


if __name__ == '__main__':
    args = set_args()
    # Restrict visible GPUs; this is set before the first .cuda() call below.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    # create model
    if args.network == "UNet":
        # NOTE(review): this branch was empty in the source (a SyntaxError).
        # The constructor mirrors the UNet calls used elsewhere in this
        # collection — confirm against the original script.
        model = UNet(n_channels=args.in_channels, n_classes=args.class_num)
    elif args.network == "PSP":
        model = pspnet.PSPNet(n_classes=19, input_size=(160, 160))
        model.load_pretrained_model(
            model_path="./segnet/pspnet/pspnet101_cityscapes.caffemodel")
        # Replace the pretrained head so it emits ``class_num`` maps.
        model.classification = nn.Conv2d(512, args.class_num, kernel_size=1)
    else:
        raise Exception("Unknow network: {}".format(args.network))
    print("Net: {} session: {} model name: {}".format(args.network,
                                                      args.session,
                                                      args.model_name))
    model_path = os.path.join(args.model_dir, args.simu_type + args.network,
                              args.session, args.model_name)
    model.load_state_dict(torch.load(model_path))
    model.cuda()

    # evaluate model
    # NOTE(review): this expects a test_seg_model(model, args) signature; the
    # test_seg_model defined earlier in this collection takes only ``args``.
    test_seg_model(model, args)
def test_slide_seg(args):
    """Evaluate a PSPNet checkpoint on whole images against their masks.

    Every file in ``args.slides_dir`` whose name contains "jpg" is processed:

    - The image, divided by 255, is tiled into ``patch_len`` patches with
      step ``stride_len``.
    - Patches are predicted in groups of ``batch_size``, and overlapping
      sigmoid outputs are averaged.
    - The 0.5-thresholded prediction, with objects smaller than 20480 pixels
      removed, is compared with ``<image stem>.png`` from the same directory.

    The comparison uses the Jaccard index (intersection over union), which
    is printed per image and averaged over all images.
    """
    model = pspnet.PSPNet(n_classes=19,
                          input_size=(args.patch_len, args.patch_len))
    # Replace the classification head so it emits ``class_num`` maps.
    model.classification = nn.Conv2d(512, args.class_num, kernel_size=1)

    model_path = os.path.join(args.model_dir, args.best_model)
    model = nn.DataParallel(model)
    model.load_state_dict(torch.load(model_path))
    model.cuda()
    model.eval()

    since = time.time()
    slide_names = [ele for ele in os.listdir(args.slides_dir) if "jpg" in ele]

    ttl_pred_ji = 0.0
    for num, cur_slide in enumerate(slide_names):
        # load slide image and mask
        slide_path = os.path.join(args.slides_dir, cur_slide)
        slide_img = io.imread(slide_path) / 255.0
        mask_path = os.path.join(args.slides_dir,
                                 os.path.splitext(cur_slide)[0] + ".png")
        # Assumes a single-channel 0/255 mask with the image's height/width —
        # TODO confirm.
        mask_img = io.imread(mask_path) / 255.0
        # split and predict
        coors_arr = wsi_stride_splitting(slide_img.shape[0],
                                         slide_img.shape[1],
                                         patch_len=args.patch_len,
                                         stride_len=args.stride_len)
        # wmap counts patches covering each pixel; pred_map sums their probs.
        wmap = np.zeros((slide_img.shape[0], slide_img.shape[1]),
                        dtype=np.int32)
        pred_map = np.zeros_like(wmap).astype(np.float32)

        patch_list, coor_list = [], []
        for ic, coor in enumerate(coors_arr):
            ph, pw = coor[0], coor[1]
            patch_list.append(slide_img[ph:ph + args.patch_len,
                                        pw:pw + args.patch_len])
            coor_list.append([ph, pw])
            wmap[ph:ph + args.patch_len, pw:pw + args.patch_len] += 1
            # Run inference once a full batch is queued or at the last patch.
            if len(patch_list) == args.batch_size or ic + 1 == len(coors_arr):
                patch_arr = np.asarray(patch_list).astype(np.float32)
                patch_loader = DataLoader(PatchDataset(patch_arr),
                                          batch_size=args.batch_size,
                                          shuffle=False,
                                          num_workers=4,
                                          drop_last=False)
                with torch.no_grad():
                    pred_list = []
                    for patches in patch_loader:
                        preds = torch.sigmoid(model(patches.cuda()))
                        pred_list.append(
                            torch.squeeze(preds, dim=1).cpu().numpy())
                    batch_preds = np.concatenate(pred_list, axis=0)
                for (bh, bw), pred in zip(coor_list, batch_preds):
                    pred_map[bh:bh + args.patch_len,
                             bw:bw + args.patch_len] += pred
                patch_list, coor_list = [], []

        prob_pred = np.divide(pred_map, wmap)
        slide_pred = morphology.remove_small_objects(prob_pred > 0.5,
                                                     min_size=20480).astype(
                                                         np.uint8)
        # Jaccard index |A∩B| / |A∪B|; the epsilon avoids 0/0 on empty masks.
        intersection = np.multiply(mask_img, slide_pred)
        pred_ji = np.sum(intersection) / (np.sum(mask_img) +
                                          np.sum(slide_pred) -
                                          np.sum(intersection) + 1.0e-8)
        ttl_pred_ji += pred_ji
        print("--{:2d}/{:2d} Slide:{} JI:{:.3f}".format(
            num + 1, len(slide_names), cur_slide, pred_ji))

    time_elapsed = time.time() - since
    print('Testing takes {:.0f}m {:.2f}s'.format(time_elapsed // 60,
                                                 time_elapsed % 60))
    if not slide_names:
        # Avoid a ZeroDivisionError when the directory holds no images.
        print('No jpg slides found in {}'.format(args.slides_dir))
        return
    # The metric computed above is intersection-over-union, not Dice.
    print('Slide-level average Jaccard index is {:.3f}'.format(
        ttl_pred_ji / len(slide_names)))