Example #1
def main():
    global best_acc
    #normalization transforms
    normalize = transforms.Normalize(mean=cfg.mean, std=cfg.std)
    trans = transforms.Compose([
        transforms.ToTensor(),
        normalize
    ])

    with procedure('Prepare Dataset'):
        with open(cfg.rnn_data_split) as f:
            data = json.load(f)
        shuffle = {'train': True, 'val': False}
        dataset = {x: Rnndata(data[x + '_pos'] + data[x + '_neg'], args.s,
                              args.patch_size, False, trans)
                   for x in ['train', 'val']}
        dataloader = {x: torch.utils.data.DataLoader(dataset[x], batch_size=args.batch_size,
                                                     shuffle=shuffle[x], num_workers=args.workers,
                                                     pin_memory=True)
                      for x in ['train', 'val']}
    #make model
    with procedure('Init Model'):
        embedder = ResNetEncoder(args.model)
        for param in embedder.parameters():
            param.requires_grad = False
        embedder = torch.nn.parallel.DataParallel(embedder.cuda())
        embedder.eval()

        rnn = rnn_single(args.ndims)
        rnn = torch.nn.parallel.DataParallel(rnn.cuda())
    
    #optimization
    with procedure('Optimization and Criterion'):
        if cfg.weights == 0.5:
            criterion = nn.CrossEntropyLoss().cuda()
        else:
            w = torch.Tensor([1 - cfg.weights, cfg.weights])
            criterion = nn.CrossEntropyLoss(w).cuda()
        optimizer = optim.SGD(rnn.parameters(), 0.1, momentum=0.9, dampening=0, weight_decay=1e-4, nesterov=True)
        cudnn.benchmark = True

    fconv = open(os.path.join(args.output, 'convergence.csv'), 'w')
    fconv.write('epoch,train.loss,train.fpr,train.fnr,val.loss,val.fpr,val.fnr\n')
    fconv.close()

    #training loop
    for epoch in range(args.nepochs):

        train_loss, train_fpr, train_fnr = train_single(epoch, embedder, rnn, dataloader['train'], criterion, optimizer)
        val_loss, val_fpr, val_fnr = test_single(epoch, embedder, rnn, dataloader['val'], criterion)

        fconv = open(os.path.join(args.output,'convergence.csv'), 'a')
        fconv.write('{},{},{},{},{},{},{}\n'.format(epoch+1, train_loss, train_fpr, train_fnr, val_loss, val_fpr, val_fnr))
        fconv.close()

        val_err = (val_fpr + val_fnr)/2
        if 1-val_err >= best_acc:
            best_acc = 1-val_err
            obj = {
                'epoch': epoch+1,
                'state_dict': rnn.state_dict()
            }
            torch.save(obj, os.path.join(args.output,'rnn_checkpoint_best.pth'))
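
All of these examples wrap their setup steps in a procedure(...) context manager defined elsewhere in the project. As a rough, assumed stand-in (not the repository's actual implementation), it can be read as a named timing/logging block:

import time
from contextlib import contextmanager

@contextmanager
def procedure(name):
    # Hypothetical sketch of the project's procedure helper: announce the
    # step, run the enclosed block, then report how long it took.
    start = time.time()
    print('[{}] start'.format(name))
    try:
        yield
    finally:
        print('[{}] done in {:.2f}s'.format(name, time.time() - start))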
Example #2
def main():

    #load model
    with procedure('load model'):
        model = models.resnet34(True)
        model.fc = nn.Linear(model.fc.in_features, cfg.n_classes)
        model = torch.nn.DataParallel(model.cuda())
        if args.resume:
            ch = torch.load(args.model)
            model.load_state_dict(ch['state_dict'])
        cudnn.benchmark = True

    with procedure('prepare dataset'):
        # normalization
        normalize = transforms.Normalize(mean=cfg.mean, std=cfg.std)
        trans = transforms.Compose([transforms.ToTensor(), normalize])
        #load data
        with open(cfg.data_split) as f:
            data = json.load(f)
        dset = MILdataset(data['train_pos'][:2], args.patch_size, trans)
        loader = torch.utils.data.DataLoader(dset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    dset.setmode(1)
    probs = inference(loader, model)
    maxs, index = group_topk(np.array(dset.slideIDX), probs, args.k)
    if not os.path.isdir(cfg.color_img_path):
        os.makedirs(cfg.color_img_path)
    assert len(dset.slidenames) == len(maxs), "length mismatch between slides and top-k results"
    for name, target, slide_probs, idxs in zip(dset.slidenames, dset.targets,
                                               maxs, index):
        flag = 'pos' if target == 1 else 'neg'
        orign_img_path = os.path.join(cfg.data_path, flag, name,
                                      name + '_orign.jpg')
        color_img_path = os.path.join(cfg.color_img_path, name + '.jpg')
        #print("orign_img_path: ",orign_img_path)
        patch_names = []
        orign_img = cv2.imread(orign_img_path)
        for i in range(args.k):
            idx = idxs[i]
            src = dset.grid[idx]
            dst = os.path.join(cfg.patch_predict, flag, src.split('/')[-2])
            if not os.path.isdir(dst):
                os.makedirs(dst)
            shutil.copy(src, dst)
            cp('(#r){}(#)\t(#g){}(#)'.format(src, dst))
            patch_names.append(src.split('/')[-1])

        plot_label(orign_img, patch_names, slide_probs, color_img_path)
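
group_topk is assumed to return, for each slide, the k highest patch probabilities and their indices into the flat patch list. A minimal sketch under that assumption (not the repository's actual helper):

import numpy as np

def group_topk(groups, probs, k):
    # Hypothetical stand-in: `groups` holds one slide id per patch. For each
    # slide, keep the k patches with the highest probability and return both
    # their probabilities and their indices into the flat patch list.
    # Assumes every slide has at least k patches and that slide ids run
    # 0..n_slides-1 in the same order as dset.slidenames.
    groups = np.asarray(groups)
    probs = np.asarray(probs)
    order = np.lexsort((probs, groups))       # sort by slide id, then by prob
    sorted_groups = groups[order]
    keep = np.empty(len(sorted_groups), dtype=bool)
    keep[-k:] = True
    keep[:-k] = sorted_groups[k:] != sorted_groups[:-k]
    idx = order[keep]                         # flat indices of each slide's top k
    return probs[idx].reshape(-1, k), idx.reshape(-1, k)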
Example #3
File: visual.py Project: GryhomShaw/Tissue
def main():
    args = get_args()
    #load model
    with procedure('load model'):
        model = get_model(config)
        #model.fc = nn.Linear(model.fc.in_features, config.NUMCLASSES)
        model = model.cuda()
        if config.TEST.RESUME:
            ch = torch.load(config.TEST.CHECKPOINT)
            model_dict = {}
            for key, val in ch['state_dict'].items():
                model_dict[key[7:]] = val
            model.load_state_dict(model_dict)
            print(ch['best_dsc'], ch['epoch'])

        cudnn.benchmark = True

    with procedure('prepare dataset'):
        # normalization
        normalize = transforms.Normalize(mean=config.DATASET.MEAN,
                                         std=config.DATASET.STD)
        trans = transforms.Compose([transforms.ToTensor(), normalize])
        #load data
        with open(config.DATASET.SPLIT) as f:
            data = json.load(f)
        data_root = '/home/gryhomshaw/SSD1T/xiaoguohong/MIL_Tissue/patch/pos'
        data_list = [
            os.path.join(data_root, each_slide)
            for each_slide in os.listdir(data_root)
        ]
        dset = MILdataset(data_list, trans)
        loader = torch.utils.data.DataLoader(dset,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=config.WORKERS,
                                             pin_memory=True)
    time_format = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    output_path = os.path.join(config.TEST.OUTPUT, config.MODEL,
                               config.TRAIN.MODE + '_' + time_format)

    for idx, each_scale in enumerate([1]):
        dset.setmode(idx)
        model.eval()
        for i, (input, _, _, _) in enumerate(loader):
            input = input.cuda()
            input.requires_grad = True
            print(input.requires_grad)
            output = model(input)
            loss = output[:, 1]
            model.zero_grad()
            loss.backward()
            prob = F.softmax(output, dim=1)
            prob = prob.squeeze()
            if prob[1] <= 0.5:
                continue
            cur_weights = input.grad.cpu().numpy()
            cur_weights = cur_weights.squeeze()
            cur_weights = cur_weights.transpose([1, 2, 0])
            print(cur_weights.shape)
            cur_weights = np.max(cur_weights, axis=2)
            cur_weights = np.clip(cur_weights * 5000, 0, 255).astype(np.uint8)
            cur_weights = cur_weights.squeeze()
            cur_slide = dset.grid[i].split('/')[-2]
            cur_dir = os.path.join(output_path, cur_slide)
            if not os.path.isdir(cur_dir):
                os.makedirs(cur_dir)
            cur_name = dset.grid[i].split('/')[-1]

            cv2.imwrite(os.path.join(cur_dir, cur_name), cur_weights)
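
The fixed *5000 scaling of the input gradients above can easily saturate or under-expose the saliency map. A common alternative (a suggestion, not what this repository does) is to normalize each map to its own range before writing it out:

import numpy as np

def gradient_to_image(grad):
    # Rescale a per-pixel gradient magnitude map to 0-255 for cv2.imwrite,
    # instead of multiplying by a hand-picked constant.
    grad = np.abs(grad)
    grad = (grad - grad.min()) / (grad.max() - grad.min() + 1e-8)
    return (grad * 255).astype(np.uint8)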
Example #4
File: test.py Project: GryhomShaw/Tissue
def main():
    args = get_args()
    #load model
    with procedure('load model'):
        model = get_model(config)
        model = nn.DataParallel(model.cuda())
        if config.TEST.RESUME:
            ch = torch.load(config.TEST.CHECKPOINT)
            model.load_state_dict(ch['state_dict'])
            print(ch['best_dsc'], ch['epoch'])

        cudnn.benchmark = True

    with procedure('prepare dataset'):
        # normalization
        normalize = transforms.Normalize(mean=config.DATASET.MEAN,
                                         std=config.DATASET.STD)
        trans = transforms.Compose([transforms.ToTensor(), normalize])
        #load data
        with open(config.DATASET.SPLIT) as f:
            data = json.load(f)
        data_root = '/home/gryhomshaw/SSD1T/xiaoguohong/MIL_Tissue/patch/pos'
        data_list = [
            os.path.join(data_root, each_slide)
            for each_slide in os.listdir(data_root)
        ][:10]
        dset = MILdataset(data_list, trans)
        loader = torch.utils.data.DataLoader(dset,
                                             batch_size=config.TEST.BATCHSIZE,
                                             shuffle=False,
                                             num_workers=config.WORKERS,
                                             pin_memory=True)
    time_format = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    output_path = os.path.join(config.TEST.OUTPUT, config.MODEL,
                               config.TRAIN.MODE + '_' + time_format)
    patch_info = {}

    for idx, each_scale in enumerate(config.DATASET.MULTISCALE):
        dset.setmode(idx)
        probs, img_idxs, rows, cols = inference_vt(0, loader, model)
        if config.TEST.CAM:
            inference_cam(0, loader, model)
        parse_start = time.time()
        res = probs_parser(probs, img_idxs, rows, cols, dset, each_scale)
        print("finish parse: {}".format(time.time() - parse_start))
        merge_start = time.time()
        for key, values in res.items():
            if key not in patch_info:
                patch_info[key] = values
            else:
                patch_info[key].extend(values)
        print("finish merge: {}".format(time.time() - merge_start))
        '''
        for img_path, labels in res.items():
            if len(labels) == 0:
                continue
            plot_label(img_path, labels, each_scale, output_path)
        '''
    masks_end = time.time()
    res = []
    dsc = []
    # Note: pool.apply() blocks until each task finishes, so this loop runs
    # the get_mask calls one at a time despite the worker pool.
    with multiprocessing.Pool(processes=16) as pool:
        for each_img, each_labels in patch_info.items():
            res.append(
                pool.apply(get_mask, (each_img, each_labels, output_path)))
    for each_res in res:
        print(type(each_res), each_res)
        dsc.extend([each_val for each_val in each_res.values()])

    print("finish get mask: {}".format(time.time() - masks_end))

    mean_dsc = np.array(dsc).mean()
    print(mean_dsc, dsc)
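
Because pool.apply is blocking, the mask computation above is effectively serial. A sketch of how the same loop could run in parallel, assuming get_mask and its arguments are picklable (drop-in replacement for the loop above, using the same patch_info and output_path):

import multiprocessing

with multiprocessing.Pool(processes=16) as pool:
    handles = [pool.apply_async(get_mask, (img, labels, output_path))
               for img, labels in patch_info.items()]
    res = [h.get() for h in handles]   # collect results after all tasks are queued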
Example #5
def work(args):
    img_path, size, scale, output_patch_path, patch_size, nums, bin, thresh = args
    '''
    img_path: path of the source tif (e.g. ./data_append/1/1.tif)
    size: block size used when converting the tiff to jpeg (e.g. 20000)
    scale: downscale factor for the tiff-to-jpeg conversion (e.g. 4)
    output_patch_path: directory for the extracted patches (e.g. ./Patch/pos/1)
    patch_size: patch side length used when cutting the image (e.g. 2048)
    nums, bin, thresh: forwarded to get_threshold to pick the tissue-area threshold
    '''
    output_mask_path = img_path[:-4] + '_mask.jpg'
    output_img_path = img_path[:-4] + '_orign.jpg'

    slide = opsl.OpenSlide(img_path)
    [n, m] = slide.dimensions

    with procedure('Tiff2jpeg'):
        if not os.path.isfile(output_mask_path) or not os.path.isfile(output_img_path) :
            blocks_pre_col = math.ceil(m / size)
            blocks_pre_row = math.ceil(n / size)
            row_cache = []
            img_cache = []
            for i in range(blocks_pre_col):
                for j in range(blocks_pre_row):
                    x = i * size
                    y = j * size
                    height = min(x + size, m) - x
                    width = min(y + size, n) - y
                    img = np.array(slide.read_region((y, x), 0, (width, height)))
                    img = cv2.resize(img, (width // scale, height // scale))
                    row_cache.append(img)
                img_cache.append(np.concatenate(row_cache, axis=1))
                row_cache = []
            img = np.concatenate(img_cache, axis=0)
            gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            ret1, th1 = cv2.threshold(gray, 0, 255, cv2.THRESH_OTSU)
            mask = 255 - th1
            cv2.imwrite(output_mask_path, mask)
            cv2.imwrite(output_img_path,img)
            cp('(#g)save_mask_path:{}(#)'.format(output_mask_path))
            cp('(#g)save_orign_path:{}(#)'.format(output_img_path))

    with procedure('Cut image'):
        if not os.path.isdir(output_patch_path):
            mask = cv2.imread(output_mask_path,0)
            assert len(mask.shape) == 2, 'mask must be single-channel'
            mask_patch_size = patch_size // scale
            step = mask_patch_size // 2
            os.makedirs(output_patch_path, exist_ok=True)
            data = {}
            #patch_overlap_count = []
            data['roi'] = []
            h, w = mask.shape[0], mask.shape[1]
            threshold = get_threshold(img_path.split('/')[-2], nums, bin, thresh)

            data['threshold'] = threshold
            cp('(#r)Processing:{}\tThreshold:{}(#)'.format(img_path.split('/')[-2], threshold))
            for i in range(0, h, step):
                for j in range(0, w, step):
                    si = min(i, h - mask_patch_size)
                    sj = min(j, w - mask_patch_size)
                    si = max(0, si)  # h may be smaller than the patch size
                    sj = max(0, sj)
                    x = min(scale * si, m - patch_size)
                    y = min(scale * sj, n - patch_size)
                    sub_img = mask[si: si + mask_patch_size, sj: sj + mask_patch_size]
                    cur_scale = (np.sum(sub_img) // 255) / (sub_img.shape[0] * sub_img.shape[1])
                    #patch_overlap_count.append([x, y, cur_threshold])
                    if cur_scale > threshold:
                        data['roi'].append([x, y,cur_scale])
                        patch = np.array(slide.read_region((y, x), 0, (patch_size, patch_size)).convert('RGB'))
                        patch_name = "{}_{}.jpg".format(x, y)
                        patch_path = os.path.join(output_patch_path, patch_name)
                        cv2.imwrite(patch_path, patch)
                        cp('(#y)save_path:\t{}(#)'.format(patch_path))
                    if sj != j:
                        break
                if si != i:
                    break
            json_path = output_mask_path[:-9] + '_mask.json'
            data['id'] = img_path.split('/')[-2]
            with open(json_path, 'w') as f:
                json.dump(data, f)
            cp('(#g)save_json:\t{}(#)'.format(json_path))
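
work() expects a single packed argument tuple, which makes it easy to fan out over a process pool. A hypothetical driver (the paths and the nums/bin/thresh values are placeholders, not taken from the repository):

import glob
import multiprocessing
import os

if __name__ == '__main__':
    tifs = glob.glob('./data_append/*/*.tif')
    # One argument tuple per slide, mirroring the examples in the docstring;
    # the last three values stand in for whatever get_threshold expects.
    tasks = [(tif, 20000, 4,
              os.path.join('./Patch/pos', os.path.basename(os.path.dirname(tif))),
              2048, 10, 20, 0.5)
             for tif in tifs]
    with multiprocessing.Pool(processes=8) as pool:
        pool.map(work, tasks)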
Example #6
def main():
    global args, best_acc
    args = get_args()

    #cnn
    with procedure('init model'):
        model = models.resnet34(True)
        model.fc = nn.Linear(model.fc.in_features, 2)
        model = torch.nn.parallel.DataParallel(model.cuda())

    with procedure('loss and optimizer'):
        if cfg.weights == 0.5:
            criterion = nn.CrossEntropyLoss().cuda()
        else:
            w = torch.Tensor([1 - cfg.weights, cfg.weights])
            criterion = nn.CrossEntropyLoss(w).cuda()
        optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)

    cudnn.benchmark = True

    #normalization
    normalize = transforms.Normalize(mean=cfg.mean, std=cfg.std)
    trans = transforms.Compose([transforms.ToTensor(), normalize])

    with procedure('prepare dataset'):
        #load data
        with open(cfg.data_split) as f:  #
            data = json.load(f)
        train_dset = MILdataset(data['train_neg'][:14] + data['train_pos'],
                                args.patch_size, trans)
        train_loader = torch.utils.data.DataLoader(train_dset,
                                                   batch_size=args.batch_size,
                                                   shuffle=False,
                                                   num_workers=args.workers,
                                                   pin_memory=True)
        if args.val:
            val_dset = MILdataset(data['val_pos'] + data['val_neg'],
                                  args.patch_size, trans)
            val_loader = torch.utils.data.DataLoader(
                val_dset,
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=args.workers,
                pin_memory=True)
    with procedure('init tensorboardX'):
        tensorboard_path = os.path.join(args.output, 'tensorboard')
        if not os.path.isdir(tensorboard_path):
            os.makedirs(tensorboard_path)
        summary = TensorboardSummary(tensorboard_path, args.dis_slide)
        writer = summary.create_writer()

    #open output file
    fconv = open(os.path.join(args.output, 'convergence.csv'), 'w')
    fconv.write('epoch,metric,value\n')
    fconv.close()

    #loop through epochs
    for epoch in range(args.nepochs):
        train_dset.setmode(1)
        probs = inference(epoch, train_loader, model)
        topk = group_argtopk(np.array(train_dset.slideIDX), probs, args.k)
        images, names, labels = train_dset.getpatchinfo(topk)
        summary.plot_calsses_pred(writer, images, names, labels,
                                  np.array([probs[k] for k in topk]), args.k,
                                  epoch)
        slidenames, length = train_dset.getslideinfo()
        summary.plot_histogram(writer, slidenames, probs, length, epoch)
        #print([probs[k] for k in topk ])
        train_dset.maketraindata(topk)
        train_dset.shuffletraindata()
        train_dset.setmode(2)
        loss = train(epoch, train_loader, model, criterion, optimizer, writer)
        cp('(#r)Training(#)\t(#b)Epoch: [{}/{}](#)\t(#g)Loss:{}(#)'.format(
            epoch + 1, args.nepochs, loss))
        fconv = open(os.path.join(args.output, 'convergence.csv'), 'a')
        fconv.write('{},loss,{}\n'.format(epoch + 1, loss))
        fconv.close()

        #Validation
        if args.val and (epoch + 1) % args.test_every == 0:
            val_dset.setmode(1)
            probs = inference(epoch, val_loader, model)
            maxs = group_max(np.array(val_dset.slideIDX), probs,
                             len(val_dset.targets))
            pred = [1 if x >= 0.5 else 0 for x in maxs]
            err, fpr, fnr = calc_err(pred, val_dset.targets)
            #print('Validation\tEpoch: [{}/{}]\tError: {}\tFPR: {}\tFNR: {}'.format(epoch+1, args.nepochs, err, fpr, fnr))
            cp('(#y)Validation\t(#)(#b)Epoch: [{}/{}]\t(#)(#g)Error: {}\tFPR: {}\tFNR: {}(#)'
               .format(epoch + 1, args.nepochs, err, fpr, fnr))
            fconv = open(os.path.join(args.output, 'convergence.csv'), 'a')
            fconv.write('{},error,{}\n'.format(epoch + 1, err))
            fconv.write('{},fpr,{}\n'.format(epoch + 1, fpr))
            fconv.write('{},fnr,{}\n'.format(epoch + 1, fnr))
            fconv.close()
            #Save best model
            err = (fpr + fnr) / 2.
            if 1 - err >= best_acc:
                best_acc = 1 - err
                obj = {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict()
                }
                torch.save(obj, os.path.join(args.output,
                                             'checkpoint_best.pth'))
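
calc_err is assumed to report the slide-level error rate together with false-positive and false-negative rates; a minimal sketch under that assumption (not necessarily the repository's implementation):

import numpy as np

def calc_err(pred, real):
    # Overall error plus FPR/FNR over slide-level predictions.
    pred = np.asarray(pred)
    real = np.asarray(real)
    wrong = pred != real
    err = wrong.mean()
    fpr = np.logical_and(pred == 1, wrong).sum() / max((real == 0).sum(), 1)
    fnr = np.logical_and(pred == 0, wrong).sum() / max((real == 1).sum(), 1)
    return err, fpr, fnr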
Example #7
def main():

    args = get_args()
    global best_dsc
    #cnn
    with procedure('init model'):
        model = get_model(config)
        model = torch.nn.parallel.DataParallel(model.cuda())

    with procedure('loss and optimizer'):
        criterion = FocalLoss(config.TRAIN.LOSS.GAMMA,
                              config.DATASET.ALPHA).cuda()
        optimizer = optim.Adam(model.parameters(),
                               lr=config.TRAIN.LR,
                               weight_decay=config.TRAIN.LR)  # weight decay reuses the LR value; likely meant to be a separate config entry
    start_epoch = 0

    if config.TRAIN.RESUME:
        with procedure('resume model'):
            # note: the returned best accuracy is assigned to best_acc, but
            # checkpoint selection below tracks the global best_dsc instead
            start_epoch, best_acc, model, optimizer = load_model(
                model, optimizer)

    cudnn.benchmark = True
    #normalization
    normalize = transforms.Normalize(mean=config.DATASET.MEAN,
                                     std=config.DATASET.STD)
    trans = transforms.Compose([transforms.ToTensor(), normalize])

    with procedure('prepare dataset'):
        #load data
        data_split = config.DATASET.SPLIT
        with open(data_split) as f:
            data = json.load(f)

        train_dset = MILdataset(data['train_neg'] + data['train_pos'], trans)
        train_loader = DataLoader(train_dset,
                                  batch_size=config.TRAIN.BATCHSIZE,
                                  shuffle=False,
                                  num_workers=config.WORKERS,
                                  pin_memory=True)
        if config.TRAIN.VAL:
            val_dset = MILdataset(data['val_pos'] + data['val_neg'], trans)
            val_loader = DataLoader(val_dset,
                                    batch_size=config.TEST.BATCHSIZE,
                                    shuffle=False,
                                    num_workers=config.WORKERS,
                                    pin_memory=True)

    with procedure('init tensorboardX'):
        train_log_path = os.path.join(
            config.TRAIN.OUTPUT,
            time.strftime('%Y%m%d_%H%M%S', time.localtime()))
        if not os.path.isdir(train_log_path):
            os.makedirs(train_log_path)
        tensorboard_path = os.path.join(train_log_path, 'tensorboard')
        with open(os.path.join(train_log_path, 'cfg.yaml'), 'w') as f:
            print(config, file=f)
        if not os.path.isdir(tensorboard_path):
            os.makedirs(tensorboard_path)
        summary = TensorboardSummary(tensorboard_path)
        writer = summary.create_writer()

    for epoch in range(start_epoch, config.TRAIN.EPOCHS):
        index = []
        for idx, each_scale in enumerate(config.DATASET.MULTISCALE):
            train_dset.setmode(idx)
            #print(len(train_loader), len(train_dset))
            probs = inference(epoch, train_loader, model)
            topk = group_argtopk(train_dset.ms_slideIDX[:], probs,
                                 train_dset.targets[:],
                                 train_dset.ms_slideLen[:], each_scale)
            index.extend([[each[0], each[1]]
                          for each in zip(topk, [idx] * len(topk))])
        train_dset.maketraindata(index)
        train_dset.shuffletraindata()
        train_dset.setmode(-1)
        loss = trainer(epoch, train_loader, model, criterion, optimizer,
                       writer)
        cp('(#r)Training(#)\t(#b)Epoch: [{}/{}](#)\t(#g)Loss:{}(#)'.format(
            epoch + 1, config.TRAIN.EPOCHS, loss))

        if config.TRAIN.VAL and (epoch + 1) % config.TRAIN.VALGAP == 0:
            patch_info = {}
            for idx, each_scale in enumerate(config.DATASET.MULTISCALE):
                val_dset.setmode(idx)
                probs, img_idxs, rows, cols = inference_vt(
                    epoch, val_loader, model)
                res = probs_parser(probs, img_idxs, rows, cols, val_dset,
                                   each_scale)

                for key, val in res.items():
                    if key not in patch_info:
                        patch_info[key] = val
                    else:
                        patch_info[key].extend(val)
            res = []
            dsc = []
            # pool.apply() blocks on each call, so these get_mask calls run serially
            with multiprocessing.Pool(processes=16) as pool:
                for each_img, each_labels in patch_info.items():
                    res.append(
                        pool.apply(get_mask,
                                   (each_img, each_labels, None, False)))
            for each_res in res:
                dsc.extend([each_val for each_val in each_res.values()])

            dsc = np.array(dsc).mean()
            '''
            maxs = group_max(np.array(val_dset.slideLen), probs, len(val_dset.targets), config.DATASET.MULTISCALE[-1])
            threshold = 0.5
            pred = [1 if x >= threshold else 0 for x in maxs]
            err, fpr, fnr, f1 = calc_err(pred, val_dset.targets)

            cp('(#y)Vaildation\t(#)(#b)Epoch: [{}/{}]\t(#)(#g)Error: {}\tFPR: {}\tFNR: {}\tF1: {}(#)'.format(epoch+1, config.TRAIN.EPOCHS, err, fpr, fnr, f1))
            '''
            cp('(#y)Validation\t(#)(#b)Epoch: [{}/{}]\t(#)(#g)DSC: {}(#)'.
               format(epoch + 1, config.TRAIN.EPOCHS, dsc))
            writer.add_scalar('Val/dsc', dsc, epoch)
            if dsc >= best_dsc:
                best_dsc = dsc
                obj = {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_dsc': best_dsc,
                    'optimizer': optimizer.state_dict()
                }
                torch.save(obj,
                           os.path.join(train_log_path, 'BestCheckpoint.pth'))
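
FocalLoss(gamma, alpha) comes from the project itself. A minimal two-class sketch of a standard focal loss with class weighting (an assumption about how the project's version behaves, not its actual code):

import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    # loss = -alpha_t * (1 - p_t)^gamma * log(p_t), with alpha applied to the
    # positive class and (1 - alpha) to the negative class.
    def __init__(self, gamma=2.0, alpha=0.5):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, logits, targets):
        log_probs = F.log_softmax(logits, dim=1)
        log_pt = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)
        pt = log_pt.exp()
        alpha_t = torch.where(targets == 1,
                              torch.full_like(pt, self.alpha),
                              torch.full_like(pt, 1.0 - self.alpha))
        return (-alpha_t * (1.0 - pt) ** self.gamma * log_pt).mean()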