Exemplo n.º 1
0
def test_net_dice(args, net, batch_size=1, gpu=False, site=None):
    """Evaluate ``net`` on one site, save predicted masks and log the mean Dice.

    Builds a ``spinalcordRealCropDataSet`` for ``site`` and runs the network on
    it one sample at a time. For each sample it writes the predicted label map
    as an image under ``args.savepath`` and accumulates the per-class Dice score
    for classes 0, 1 and 2. The mean Dice of class 1 is appended to
    ``args.logfile``.

    Args:
        args: Parsed options. Reads ``spinal_root``, ``img_size``, ``n_class``,
            ``nlabel``, ``set``, ``num_workers``, ``savepath`` and ``logfile``.
        net: Segmentation model. It is switched to eval mode and left there.
        batch_size: Unused; the loader always uses a batch size of 1. Kept for
            call compatibility.
        gpu: If true, inputs are moved to CUDA before the forward pass.
        site: Site list passed to the dataset. Defaults to
            ``["site3", "site3"]``.

    Returns:
        np.ndarray of shape (3,): per-class Dice summed over samples and
        divided by the dataset length.
    """
    # Avoid a shared mutable default; build a fresh list per call.
    if site is None:
        site = ["site3", "site3"]
    net.eval()
    # train_dataset = spinalcordGen100DataSet(args.spinal_root, img_size=args.img_size, site=site, batchsize=1, n_class=args.n_class,nlabel=args.nlabel,set=args.set,real_or_fake='real')
    train_dataset = spinalcordRealCropDataSet(args.spinal_root,
                                              img_size=args.img_size,
                                              site=site,
                                              batchsize=1,
                                              n_class=args.n_class,
                                              nlabel=args.nlabel,
                                              set=args.set)
    # train_dataset = spinalcordCenterCropDataSet(args.spinal_root,img_size=args.img_size,site=site,batchsize=1, n_class=args.n_class, nlabel=args.nlabel,set=args.set)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=1,
                                               shuffle=False,
                                               num_workers=args.num_workers,
                                               pin_memory=True,
                                               drop_last=True)

    N_train = len(train_dataset)
    s = np.zeros(3, dtype=np.float32)  # running sum of per-class Dice
    # Inference only: disable autograd so no graph is built per sample.
    with torch.no_grad():
        for data, target, _label, _size, name in train_loader:
            imgsavepath = os.path.join(args.savepath, name[0])
            imgs = torch.from_numpy(np.array(data).astype(np.float32))
            # Batch size is 1, so keep only the first sample's masks.
            true_masks = np.array(target).astype(np.float32)[0]
            if gpu:
                imgs = imgs.cuda()
            masks_pred = net(imgs).cpu().numpy()

            # Per-pixel class index (argmax over the channel axis of sample 0),
            # then resized by image_resize (defined elsewhere).
            pred = np.argmax(masks_pred[0], 0)
            pred = image_resize(pred, name)

            dice = np.zeros(3, dtype=np.float32)
            for i in range(3):
                class_mask = (pred == i).astype(np.float64)
                # NOTE(review): pred is resized but true_masks is not; confirm
                # calculate_dice receives arrays of matching shape.
                dice[i] = calculate_dice(class_mask, true_masks[i])
            s += dice

            # Scale class indices 0..2 to 0..255 for saving as an image.
            save_mask = np.ceil(pred / 2 * 255)
            # NOTE(review): scipy.misc.imsave was removed in SciPy 1.2; this
            # requires an old SciPy (a replacement needs an extra dependency).
            scipy.misc.imsave(imgsavepath, save_mask)

    # NOTE(review): an empty dataset gives N_train == 0 and a NaN result.
    dice_result = s / N_train
    print(site[0])
    print(dice_result)
    line4 = '{}:Dice result: {}\n'.format(site[0], dice_result[1])

    with open(args.logfile, 'a+') as logf:
        logf.write(line4)

    return dice_result
Exemplo n.º 2
0
# Load a trained UNet_GN5 checkpoint, attach a forward hook to ``net.up4.up``,
# and pick out sample number ``t`` of the site4 training data.
model_path = "/home/jjchu/Result/UNetsnapshots/Real_center100_mixedsite12_meanstd_imgs_b8_25l_1103/CP30.pth"
net = UNet_GN5(n_channels=3, n_classes=3)
net.eval()
net.cuda()
net.load_state_dict(torch.load(model_path))
args = get_args()
# for layer in layers:
# hook_feature is defined elsewhere; presumably it appends the hooked layer's
# output to features_blobs — TODO confirm.
features_blobs = []
net.up4.up.register_forward_hook(hook_feature)  # 2,5

# train_dataset = spinalcordGen100DataSet(args.spinal_root,img_size=args.img_size,site=['site3','site3'],batchsize=1, n_class=args.n_class, nlabel=True,set='train',real_or_fake='fake')
train_dataset = spinalcordRealCropDataSet(args.spinal_root,
                                          img_size=args.img_size,
                                          site=['site4', 'site4'],
                                          batchsize=1,
                                          n_class=args.n_class,
                                          nlabel=True,
                                          set='train')
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=1,
                                           shuffle=False,
                                           num_workers=args.num_workers,
                                           pin_memory=True,
                                           drop_last=True)
t = 526  # index of the sample to inspect
for batch_idx, (data, target, label, size, name) in enumerate(train_loader):
    if batch_idx == t:
        image = data
        imgpath = './mask/' + name[0] + '.jpg'
        savepath = './mask/' + 'ht_up4_up' + name[0] + '.jpg'
        # Stop here. Without break, the loop variables (target, label, size,
        # name) would end up holding the *last* batch instead of batch t, and
        # every remaining batch would be loaded for nothing.
        break
else:
    raise IndexError('sample index t={} is out of range for a loader of {} batches'.format(t, len(train_loader)))
Exemplo n.º 3
0
def train_net(args,
              net,
              epochs=5,
              batch_size=1,
              lr=0.1,
              val_percent=0.05,
              save_cp=True,
              gpu=False):
    """Train ``net`` with a multiclass Dice loss and evaluate it periodically.

    Each epoch:
      - builds a shuffled ``spinalcordRealCropDataSet`` loader;
      - optimises with Adam;
      - calls ``adjust_learning_rate``;
      - appends the mean epoch loss and the current learning rate to
        ``args.logfile``.

    The model is evaluated with ``test_net_dice`` on site1 every 30 epochs and
    on site3 and site4 every 5 epochs. A checkpoint is written every 10 epochs
    when ``save_cp`` is true.

    Args:
        args: Parsed options. Supplies:
            - the learning rate (``lr``) and Adam betas (``beta1``/``beta2``);
            - dataset/loader settings;
            - the loss ``weight``;
            - ``gpu`` (used for evaluation);
            - ``logfile`` and ``snapshots``.
        net: Model to train (updated in place).
        epochs: Number of epochs to run.
        batch_size: Only shown in the banner; the loader uses ``args.batchsize``.
        lr: Only shown in the banner; the optimiser uses ``args.lr``.
        val_percent: Unused.
        save_cp: Whether to save a checkpoint every 10 epochs.
        gpu: If true, training batches are moved to CUDA.

    Raises:
        RuntimeError: If an epoch's loader yields no batches, for example when
            the dataset is smaller than ``args.batchsize`` with
            ``drop_last=True``.
    """
    print('''
    Starting training:
        Epochs: {}
        Batch size: {}
        Learning rate: {}
        Checkpoints: {}
        CUDA: {}
    '''.format(epochs, batch_size, lr, str(save_cp), str(gpu)))
    # lossdice = DiceLoss()
    lossdice = MulticlassDiceLoss()
    optimizer = optim.Adam(net.parameters(),
                           lr=args.lr,
                           betas=(args.beta1, args.beta2))

    for epoch in range(epochs):
        print('Starting epoch {}/{}.'.format(epoch + 1, epochs))
        line1 = 'Starting epoch {}/{}.\n'.format(epoch + 1, epochs)
        # test_net_dice switches the model to eval mode, so re-enable training
        # mode at the start of every epoch.
        net.train()

        epoch_loss = 0
        # NOTE(review): the dataset and loader are rebuilt every epoch; they
        # could be hoisted if construction is deterministic (not visible here).
        # train_dataset = spinalcordGen100DataSet(args.spinal_root, img_size=args.img_size, site=args.site, batchsize=args.batchsize, n_class=args.n_class,nlabel=args.nlabel,set=args.set,real_or_fake=args.real_or_fake)
        train_dataset = spinalcordRealCropDataSet(args.spinal_root,
                                                  img_size=args.img_size,
                                                  site=args.site,
                                                  batchsize=args.batchsize,
                                                  n_class=args.n_class,
                                                  nlabel=args.nlabel,
                                                  set=args.set)
        # train_dataset = spinalcordCenterCropDataSet(args.spinal_root,img_size=args.img_size,site=args.site,batchsize=args.batchsize, n_class=args.n_class, nlabel=args.nlabel,set=args.set)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batchsize,
            shuffle=True,
            num_workers=args.num_workers,
            pin_memory=True,
            drop_last=True)

        # Count batches explicitly. Relying on the loop variable would raise
        # UnboundLocalError (or reuse last epoch's count) for an empty loader.
        n_batches = 0
        for n_batches, (data, target, _label, _size,
                        _name) in enumerate(train_loader, start=1):
            imgs = torch.from_numpy(np.array(data).astype(np.float32))
            true_masks = torch.from_numpy(np.array(target).astype(np.float32))

            if gpu:
                imgs = imgs.cuda()
                true_masks = true_masks.cuda()

            masks_pred = net(imgs)
            # loss = dice_coeff(masks_pred, true_masks)
            loss = lossdice(masks_pred, true_masks, args.weight)

            epoch_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if n_batches == 0:
            raise RuntimeError(
                'epoch {}: training loader yielded no batches (dataset size {}, '
                'batchsize {}, drop_last=True)'.format(
                    epoch + 1, len(train_dataset), args.batchsize))

        adjust_learning_rate(optimizer, epoch)
        print(n_batches)
        mean_loss = epoch_loss / n_batches
        print('Epoch finished ! Loss: {}'.format(mean_loss))

        line2 = 'Epoch finished ! Loss: {}\n'.format(mean_loss)
        line3 = str(optimizer.param_groups[0]['lr'])
        print(line3)
        with open(args.logfile, 'a+') as logf:
            logf.write(line1)
            logf.write(line2)
            logf.write(line3 + '\n')

        if (epoch + 1) % 30 == 0:
            site = ['site1', 'site1']
            print(
                "######################## test site1 ###############################"
            )
            print(site[0])
            dice = test_net_dice(args=args,
                                 net=net,
                                 batch_size=1,
                                 gpu=args.gpu,
                                 site=site)
            print(dice)

        if (epoch + 1) % 5 == 0:
            # Same evaluation for site3 then site4.
            for site_name in ('site3', 'site4'):
                print(
                    "######################## test {} ###############################".format(site_name)
                )
                test_net_dice(args=args,
                              net=net,
                              batch_size=1,
                              gpu=args.gpu,
                              site=[site_name, site_name])
                print(args.weight)

        if save_cp and (epoch + 1) % 10 == 0:
            torch.save(net.state_dict(),
                       args.snapshots + 'CP{}.pth'.format(epoch + 1))
            print('Checkpoint {} saved !'.format(epoch + 1))