예제 #1
0
def test():
    """Run the selected model over the hong-zhang test images and display results.

    Configuration is read from the module-level ``args`` namespace
    (``devices``, ``model``, ``image_channels``, ``resume_model``). Each input
    image is shown next to the model output via ``CV2_showTensors``.
    """
    device = torch.device(args.devices if torch.cuda.is_available() else "cpu")
    test_dataset = HongZhang_TestDataset("/data_1/data/红章图片/test/hongzhang", (256, 256))
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    # Choose the model; "unet" and any unrecognised name both fall back to UNet.
    if args.model == "srresnet":
        model = SRResnet(args.image_channels, args.image_channels)
    elif args.model == "eesp":
        model = EESPNet_Seg(args.image_channels, 2)
    else:
        model = UNet(in_channels=args.image_channels, out_channels=args.image_channels)

    print('loading model')
    if args.resume_model:
        resume_model(model, args.resume_model)
    # Always switch to inference mode and move to ``device``, even without a
    # checkpoint; otherwise CUDA inputs would meet CPU-resident weights.
    model.eval()
    model.to(device)

    # No autograd graph is needed for inference.
    with torch.no_grad():
        for image in test_loader:
            image = image.to(device)
            denoised_img = model(image).cpu()
            CV2_showTensors(image.cpu(), denoised_img, timeout=5000)
예제 #2
0
def test(args):
    """Animate the sigmoid output of a 3->1 channel UNet over the Verse dataset.

    Args:
        args: namespace providing ``weight``, the path to a saved state dict.

    Relies on module-level ``dir_img``, ``dir_mask``, ``x_transform`` and
    ``y_transform`` to build the dataset.
    """
    import matplotlib.pyplot as plt

    model = UNet(3, 1)
    model.load_state_dict(torch.load(args.weight, map_location='cpu'))
    model.eval()
    verse_data = DatasetVerse(dir_img,
                              dir_mask,
                              transform=x_transform,
                              target_transform=y_transform)
    dataloaders = DataLoader(verse_data, batch_size=1)

    plt.ion()
    with torch.no_grad():
        for x, _ in dataloaders:
            y = model(x).sigmoid()
            img_y = torch.squeeze(y).numpy()
            # Clear the previous frame so images do not pile up on the axes.
            plt.cla()
            plt.imshow(img_y)
            plt.pause(0.01)
    # In interactive mode show() does not block; turn it off so the last
    # frame stays on screen until the user closes the window.
    plt.ioff()
    plt.show()
예제 #3
0
파일: predict.py 프로젝트: ozanpkr/FracNet
def predict(args):
    """Run FracNet inference on every NIfTI file in ``args.image_dir``.

    For each volume, writes ``<image_id>_pred.nii.gz`` to ``args.pred_dir``
    and finally a combined ``pred_info.csv`` built from the per-image info.

    Args:
        args: namespace with ``model_path`` (or None), ``postprocess`` (the
            string ``"True"`` enables it), ``prob_thresh``, ``bone_thresh``,
            ``size_thresh``, ``image_dir`` and ``pred_dir``.
    """
    batch_size = 16
    num_workers = 4
    postprocess = args.postprocess == "True"

    model = UNet(1, 1, first_out_channels=16)
    model.eval()
    if args.model_path is not None:
        model_weights = torch.load(args.model_path)
        model.load_state_dict(model_weights)
    model = nn.DataParallel(model).cuda()

    # Window then min-max normalise to the same [-200, 1000] range.
    transforms = [tsfm.Window(-200, 1000), tsfm.MinMaxNorm(-200, 1000)]

    image_path_list = sorted(
        os.path.join(args.image_dir, fname)
        for fname in os.listdir(args.image_dir) if "nii" in fname
    )
    # The image id is the part of the file name before the first "-".
    image_id_list = [
        os.path.basename(path).split("-")[0] for path in image_path_list
    ]

    # nib.save does not create missing directories.
    os.makedirs(args.pred_dir, exist_ok=True)

    pred_info_list = []
    # Context manager guarantees the progress bar is closed.
    with tqdm(total=len(image_id_list)) as progress:
        for image_id, image_path in zip(image_id_list, image_path_list):
            dataset = FracNetInferenceDataset(image_path, transforms=transforms)
            dataloader = FracNetInferenceDataset.get_dataloader(
                dataset, batch_size, num_workers)
            pred_arr = _predict_single_image(model, dataloader, postprocess,
                                             args.prob_thresh, args.bone_thresh,
                                             args.size_thresh)
            pred_image, pred_info = _make_submission_files(pred_arr, image_id,
                                                           dataset.image_affine)
            pred_info_list.append(pred_info)
            pred_path = os.path.join(args.pred_dir, f"{image_id}_pred.nii.gz")
            nib.save(pred_image, pred_path)

            progress.update()

    pred_info = pd.concat(pred_info_list, ignore_index=True)
    pred_info.to_csv(os.path.join(args.pred_dir, "pred_info.csv"), index=False)
예제 #4
0
# Resume: load the trained checkpoint and report its best validation score.
if osp.isfile(resume_path):
    checkpoint = torch.load(resume_path)
    model.load_state_dict(checkpoint["model_state"])
    best_iou = checkpoint['best_iou']
    print(
        "=====>",
        "Loaded checkpoint '{}' (iter {})".format(resume_path,
                                                  checkpoint["epoch"]))
    # Dice is derived from IoU via D = 2J / (J + 1). Applied to a *mean* IoU
    # this only approximates the mean Dice.
    print("=====> best mIoU: %.4f best mean dice: %.4f" %
          (best_iou, (best_iou * 2) / (best_iou + 1)))
else:
    raise ValueError("can't find model")

print(">>>Test After Dense CRF: ")
model.eval()
running_metrics.reset()
with torch.no_grad():
    for i, (img, mask) in tqdm(enumerate(val_loader)):
        img = img.to(device)
        output = model(img)  #[-1, 9, 256, 256]
        probs = F.softmax(output, dim=1)
        # Only the first sample of each batch is used (batch size presumably 1).
        pred = probs.cpu().data[0].numpy()
        label = mask.cpu().data[0].numpy()
        # Dense-CRF refinement; scaling by 255 assumes img is in [0, 1] — TODO confirm.
        img = img.cpu().data[0].numpy()
        pred = dense_crf(img * 255, pred)
        # np.int was removed in NumPy 1.24; it was an alias of builtin int.
        pred = np.asarray(pred, dtype=int)
        label = np.asarray(label, dtype=int)
예제 #5
0
class BaseModel:
    """Train / validate / test wrapper around a configurable segmentation net.

    All configuration comes from ``args`` (an argparse-style namespace). The
    constructor builds the network, the loss (``args.loss``) and an SGD
    optimiser. Per-phase metric buffers are dicts keyed by ``'train'`` and
    ``'val'``.
    """

    # Class-level defaults kept only for backward compatibility. ``__init__``
    # gives every instance its own dicts; as mutable class attributes these
    # would otherwise be shared (and clobbered) across instances.
    losses = {'train': [], 'val': []}
    acces = {'train': [], 'val': []}
    scores = {'train': [], 'val': []}
    pred = {'train': [], 'val': []}
    true = {'train': [], 'val': []}

    def __init__(self, args):
        """Build network, loss and optimiser from ``args``.

        Raises:
            NotImplementedError: multi-scale (``args.ms``) ``deeplab50_v2``.
            ValueError: unknown ``args.model_name``.
            RuntimeError: unknown ``args.loss``.
        """
        self.args = args
        self.net = None
        # Per-instance metric buffers (see note on the class attributes).
        self.losses = {'train': [], 'val': []}
        self.acces = {'train': [], 'val': []}
        self.scores = {'train': [], 'val': []}
        self.pred = {'train': [], 'val': []}
        self.true = {'train': [], 'val': []}
        print(args.model_name)
        if args.model_name == 'UNet':
            self.net = UNet(args.in_channels, args.num_classes)
            self.net.apply(weights_init)
        elif args.model_name == 'UNetResNet34':
            self.net = UNetResNet34(args.num_classes, dropout_2d=0.2)
        elif args.model_name == 'UNetResNet152':
            self.net = UNetResNet152(args.num_classes, dropout_2d=0.2)
        elif args.model_name == 'UNet11':
            self.net = UNet11(args.num_classes, pretrained=True)
        elif args.model_name == 'UNetVGG16':
            self.net = UNetVGG16(args.num_classes,
                                 pretrained=True,
                                 dropout_2d=0.0,
                                 is_deconv=True)
        elif args.model_name == 'deeplab50_v2':
            if args.ms:
                # NotImplementedError, not the NotImplemented constant (raising
                # the latter produces an unrelated TypeError).
                raise NotImplementedError(
                    'multi-scale deeplab50_v2 is not implemented')
            self.net = deeplab50_v2(args.num_classes,
                                    pretrained=args.pretrained)
        elif args.model_name == 'deeplab_v2':
            if args.ms:
                self.net = ms_deeplab_v2(args.num_classes,
                                         pretrained=args.pretrained,
                                         scales=args.ms_scales)
            else:
                self.net = deeplab_v2(args.num_classes,
                                      pretrained=args.pretrained)
        elif args.model_name == 'deeplab_v3':
            if args.ms:
                self.net = ms_deeplab_v3(args.num_classes,
                                         out_stride=args.out_stride,
                                         pretrained=args.pretrained,
                                         scales=args.ms_scales)
            else:
                self.net = deeplab_v3(args.num_classes,
                                      out_stride=args.out_stride,
                                      pretrained=args.pretrained)
        elif args.model_name == 'deeplab_v3_plus':
            if args.ms:
                self.net = ms_deeplab_v3_plus(args.num_classes,
                                              out_stride=args.out_stride,
                                              pretrained=args.pretrained,
                                              scales=args.ms_scales)
            else:
                self.net = deeplab_v3_plus(args.num_classes,
                                           out_stride=args.out_stride,
                                           pretrained=args.pretrained)
        else:
            # Fail early instead of crashing later on ``None.parameters()``.
            raise ValueError('unknown model_name: {}'.format(args.model_name))

        # Used to resize network outputs to ``args.size`` when they differ.
        self.interp = nn.Upsample(size=args.size, mode='bilinear')

        self.iterations = args.epochs
        self.lr_current = args.lr
        self.cuda = args.cuda
        self.phase = args.phase
        self.lr_policy = args.lr_policy
        self.cyclic_m = args.cyclic_m
        if self.lr_policy == 'cyclic':
            print('using cyclic')
            assert self.iterations % self.cyclic_m == 0
        if args.loss == 'CELoss':
            # reduction='mean' is the non-deprecated spelling of size_average=True.
            self.criterion = nn.CrossEntropyLoss(reduction='mean')
        elif args.loss == 'DiceLoss':
            self.criterion = DiceLoss(num_classes=args.num_classes)
        elif args.loss == 'MixLoss':
            self.criterion = MixLoss(args.num_classes,
                                     weights=args.loss_weights)
        elif args.loss == 'LovaszLoss':
            self.criterion = LovaszSoftmax(per_image=args.loss_per_img)
        elif args.loss == 'FocalLoss':
            self.criterion = FocalLoss(args.num_classes, alpha=None, gamma=2)
        else:
            raise RuntimeError('must define loss')

        if 'deeplab' in args.model_name:
            # Two parameter groups: base lr and 10x lr (see _adjust_learning_rate).
            self.optimizer = optim.SGD(
                [{
                    'params': get_1x_lr_params_NOscale(self.net),
                    'lr': args.lr
                }, {
                    'params': get_10x_lr_params(self.net),
                    'lr': 10 * args.lr
                }],
                lr=args.lr,
                momentum=args.momentum,
                weight_decay=args.weight_decay)
        else:
            self.optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                              self.net.parameters()),
                                       lr=args.lr,
                                       momentum=args.momentum,
                                       weight_decay=args.weight_decay)
        self.iters = 0
        self.best_val = 0.0
        self.count = 0

    def init_model(self):
        """Warm-start from ``args.resume_model`` (if set), then wrap/move the net.

        Parameters whose top-level module is ``layer5`` or ``decoder`` are
        skipped, so only the remaining (presumably backbone) weights are
        reused. In multi-scale mode the weights go into ``self.net.Scale``.
        """
        if self.args.resume_model:
            saved_state_dict = torch.load(
                self.args.resume_model,
                map_location=lambda storage, loc: storage)
            target = self.net.Scale if self.args.ms else self.net
            new_params = target.state_dict().copy()
            for key, value in saved_state_dict.items():
                # e.g. "layer5.conv2d_list.3.weight" -> top-level "layer5".
                # The same filter now applies in both modes; the multi-scale
                # path previously had a double negation that loaded ONLY layer5.
                if key.split('.')[0] not in ('layer5', 'decoder'):
                    new_params[key] = value
            target.load_state_dict(new_params)

            print('Resuming training, image net loading {}...'.format(
                self.args.resume_model))

        if self.args.mGPUs:
            self.net = nn.DataParallel(self.net)

        if self.args.cuda:
            self.net = self.net.cuda()
            cudnn.benchmark = True

    def _adjust_learning_rate(self, epoch):
        """Decay the LR by ``args.gamma``, with a floor that drops over training.

        The floor is 1e-4 before 50% of ``self.iterations``, 1e-5 before 85%,
        then 1e-6. Adapted from the PyTorch ImageNet example:
        https://github.com/pytorch/examples/blob/master/imagenet/main.py
        """
        if epoch < int(self.iterations * 0.5):
            self.lr_current = max(self.lr_current * self.args.gamma, 1e-4)
        elif epoch < int(self.iterations * 0.85):
            self.lr_current = max(self.lr_current * self.args.gamma, 1e-5)
        else:
            self.lr_current = max(self.lr_current * self.args.gamma, 1e-6)
        groups = self.optimizer.param_groups
        groups[0]['lr'] = self.lr_current
        # Only the deeplab optimiser has the second (10x) group.
        if len(groups) > 1:
            groups[1]['lr'] = self.lr_current * 10

    def _net_to_save(self):
        """Return the bare network to checkpoint (unwrap DataParallel / Scale)."""
        net = self.net.module if self.args.mGPUs else self.net
        return net.Scale if self.args.ms else net

    def save_network(self, net, net_name, epoch, label=''):
        """Save ``net``'s state dict to ``<save_folder>/<exp_name>/<epoch>_<name>_<label>.pth``."""
        save_fname = '%s_%s_%s.pth' % (epoch, net_name, label)
        save_path = os.path.join(self.args.save_folder, self.args.exp_name,
                                 save_fname)
        torch.save(net.state_dict(), save_path)

    def load_weights(self, net, base_file):
        """Load a ``.pth``/``.pkl`` state dict into ``net``; other extensions are refused."""
        _, ext = os.path.splitext(base_file)
        # The old ``ext == '.pkl' or '.pth'`` was always true.
        if ext in ('.pkl', '.pth'):
            print('Loading weights into state dict...')
            net.load_state_dict(
                torch.load(base_file,
                           map_location=lambda storage, loc: storage))
            print('Finished!')
        else:
            print('Sorry only .pth and .pkl files supported.')

    def load_trained_model(self):
        """Load ``args.trained_model`` from ``<save_folder>/<exp_name>``."""
        path = os.path.join(self.args.save_folder, self.args.exp_name,
                            self.args.trained_model)
        print('eval cls, image net loading {}...'.format(path))
        if self.args.ms:
            self.load_weights(self.net.Scale, path)
        else:
            self.load_weights(self.net, path)

    def eval(self, dataloader):
        """Predict class-1 probability maps (resized to 101x101) for every image.

        Returns:
            np.ndarray stacking one 101x101 map per input image.
        """
        assert self.phase == 'test', "Command arg phase should be 'test'. "
        from tqdm import tqdm
        self.net.eval()
        output = []

        with torch.no_grad():
            for image in tqdm(dataloader):
                if self.cuda:
                    image = image.cuda()

                out = self.net(image)
                if isinstance(out, list):
                    # Multi-output nets: keep only the last output. Previously
                    # the list itself leaked through when no resize was needed.
                    out = out[-1]
                if out.size(2) != image.size(2):
                    out = self.interp(out)
                # out [bs * num_tta, c, h, w]
                if self.args.use_tta:
                    num_tta = len(tta_config)
                    out = detta_score(
                        out.view(num_tta, -1, self.args.num_classes, out.size(2),
                                 out.size(3)))  # [num_tta, bs, nclass, H, W]
                    out = out.mean(dim=0)  # [bs, nclass, H, W]
                out = F.softmax(out, dim=1)
                output.extend(
                    resize(pred[1].cpu().numpy(), (101, 101)) for pred in out)
        return np.array(output)

    def tta(self, dataloaders):
        """Argmax over classes of the TTA-summed outputs (see ``tta_output``)."""
        return np.argmax(self.tta_output(dataloaders), 1)

    def tta_output(self, dataloaders):
        """Sum ``eval`` outputs over several dataloaders (one per augmentation).

        NOTE(review): the accumulator is (N, num_classes) while ``eval`` returns
        per-image 101x101 maps; confirm the shapes agree for your setup.
        """
        results = np.zeros(shape=(len(dataloaders[0].dataset),
                                  self.args.num_classes))
        for dataloader in dataloaders:
            results += self.eval(dataloader)
        return results

    def test_val(self, dataloader):
        """Predict on a labelled set and print mean IoU over a threshold sweep.

        Returns:
            (predict, true): lists of 101x101 class-1 probability maps and of
            the corresponding label arrays.
        """
        assert self.phase == 'test', "Command arg phase should be 'test'. "
        from tqdm import tqdm
        self.net.eval()
        predict = []
        true = []

        with torch.no_grad():
            for image, mask in tqdm(dataloader):
                if self.cuda:
                    image = image.cuda()
                    mask = mask.cuda()
                label_image = mask

                out = self.net(image)
                if isinstance(out, list):
                    out = out[-1]
                    if out.size(2) != label_image.size(2):
                        out = self.interp(out)
                elif out.size(2) != image.size(2):
                    out = self.interp(out)
                # out [bs * num_tta, c, h, w]
                if self.args.use_tta:
                    num_tta = len(tta_config)
                    out = detta_score(
                        out.view(num_tta, -1, self.args.num_classes, out.size(2),
                                 out.size(3)))  # [num_tta, bs, nclass, H, W]
                    out = out.mean(dim=0)  # [bs, nclass, H, W]
                out = F.softmax(out, dim=1)
                if self.args.aug == 'heng':
                    # Crop the 202x202 centre (offset 11) used by this augmentation.
                    out = out[:, :, 11:11 + 202, 11:11 + 202]
                predict.extend(
                    resize(pred[1].cpu().numpy(), (101, 101)) for pred in out)
                true.extend(label_image.cpu().numpy())

        # Loop-invariant conversions hoisted out of the threshold sweep.
        predict_arr = np.array(predict)
        true_all = np.array(true).astype(int)  # np.int was removed in NumPy 1.24
        for t in np.arange(0.05, 0.51, 0.01):
            pred_all = predict_arr > t
            mean_iou, iou_t = mIoU(true_all, pred_all)
            print('threshold : {:.4f}'.format(t))
            print('mean IoU : {:.4f}, IoU threshold : {:.4f}'.format(
                mean_iou, iou_t))

        return predict, true

    def run_epoch(self, dataloader, writer, epoch, train=True, metrics=True):
        """Run one pass over ``dataloader``, training if ``train`` is True.

        When ``metrics`` is True, writes loss/acc/IoU/lr scalars to ``writer``,
        resets the per-phase buffers and, on validation, saves the best model
        and drives the 'step' LR policy.
        """
        if train:
            self.net.train()
            flag = 'train'
        else:
            self.net.eval()
            flag = 'val'
        t2 = time.time()
        for image, mask in dataloader:
            if train and self.lr_policy != 'step':
                adjust_learning_rate(self.args.lr, self.optimizer, self.iters,
                                     self.iterations * len(dataloader), 0.9,
                                     self.cyclic_m, self.lr_policy)
                self.iters += 1

            if self.cuda:
                image = image.cuda()
                mask = mask.cuda()
            label_image = mask

            # Build the autograd graph only when training (replaces the
            # no-op ``Variable(volatile=...)``).
            with torch.set_grad_enabled(train):
                out = self.net(image)
                if isinstance(out, list):
                    # Deep supervision: sum the loss over every output; the
                    # last one is used for the metrics.
                    out_max = None
                    loss = 0.0
                    for out_scale in out:
                        if out_scale.size(2) != label_image.size(2):
                            out_scale = self.interp(out_scale)
                        out_max = out_scale
                        loss += self.criterion(out_scale, label_image)
                else:
                    if out.size(-1) != label_image.size(-1):
                        out = self.interp(out)
                    loss = self.criterion(out, label_image)
                    out_max = out

            label_image_np = label_image.cpu().numpy()
            sig_out_np = out_max.detach().cpu().numpy()
            acc = accuracy(label_image_np, np.argmax(sig_out_np, 1))

            self.pred[flag].extend(sig_out_np)
            self.true[flag].extend(label_image_np)
            # .item(): indexing a 0-dim loss with [0] fails on PyTorch >= 0.5.
            self.losses[flag].append(loss.item())
            self.acces[flag].append(acc)

            if train:
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

        if not metrics:
            return

        n = len(self.losses[flag])
        loss = sum(self.losses[flag]) / n
        write_scalars(writer, [loss], ['loss'], epoch, tag=flag + '_loss')

        all_acc = sum(self.acces[flag]) / n
        write_scalars(writer, [all_acc], ['all_acc'], epoch, tag=flag + '_acc')

        pred_all = np.argmax(np.array(self.pred[flag]), 1)
        true_all = np.array(self.true[flag]).astype(int)
        mean_iou, iou_t = mIoU(true_all, pred_all)
        write_scalars(writer, [mean_iou, iou_t], ['mIoU', 'mIoU_threshold'],
                      epoch, tag=flag + '_IoU')

        current_lr = self.optimizer.param_groups[0]['lr']
        write_scalars(writer, [current_lr], ['learning_rate'], epoch,
                      tag=flag + '_lr')

        print(
            '{} loss: {:.4f} | acc: {:.4f} | mIoU: {:.4f} | mIoU_threshold: {:.4f} |  n_iter: {} |  learning_rate: {} | time: {:.2f}'
            .format(flag, loss, all_acc, mean_iou, iou_t, epoch,
                    current_lr,
                    time.time() - t2))

        self.losses[flag] = []
        self.pred[flag] = []
        self.true[flag] = []
        self.acces[flag] = []
        self.scores[flag] = []

        if (not train) and (iou_t >= self.best_val):
            self.save_network(self._net_to_save(),
                              self.args.model_name,
                              epoch=epoch,
                              label='best')
            print(
                'val improve from {:.4f} to {:.4f} saving in best val_iteration {}'
                .format(self.best_val, iou_t, epoch))
            self.best_val = iou_t
            self.count = 0

        # 'step' policy: decay the LR after 10 validations that fall more than
        # 0.003 below the best score.
        if (not train) and self.lr_policy == 'step':
            if (self.best_val - iou_t > 0.003) and (self.count < 10):
                self.count += 1
            if self.count >= 10:
                self._adjust_learning_rate(epoch)
                self.count = 0

    def train_val(self, dataloader_train, dataloader_val, writer):
        """Alternate train and val epochs, checkpointing every ``args.save_freq``."""
        val_epoch = 0
        for epoch in range(self.iterations):
            if (self.lr_policy == 'cyclic') and (
                    epoch % int(self.iterations / self.cyclic_m) == 0):
                print('-------start cycle {}------------'.format(
                    epoch // int(self.iterations / self.cyclic_m)))
                # Each cycle tracks its own best model.
                self.best_val = 0.0
            self.run_epoch(dataloader_train,
                           writer,
                           epoch,
                           train=True,
                           metrics=True)
            self.run_epoch(dataloader_val,
                           writer,
                           val_epoch,
                           train=False,
                           metrics=True)
            val_epoch += 1
            if (epoch + 1) % self.args.save_freq == 0:
                self.save_network(
                    self._net_to_save(),
                    self.args.model_name,
                    epoch=val_epoch,
                )
                print('saving in val_iteration {}'.format(val_epoch))