def eval(cfg, net, loader, device):
    """Score ``net`` on every batch of ``loader`` and return the per-batch mean.

    When ``cfg['class_num'] > 1`` the per-batch score is the class-weighted
    cross-entropy of the raw network output (weights from ``cfg['weight']``).
    Otherwise it is ``dice_coeff`` of the sigmoid output thresholded at 0.5.

    NOTE: the function name shadows the builtin ``eval``; it is kept because
    callers use it.
    NOTE(review): this function never calls ``net.eval()`` -- presumably the
    caller switches the mode; confirm.

    Args:
        cfg: mapping providing ``'class_num'`` and, for multi-class, ``'weight'``.
        net: callable model, invoked as ``net(imgs)``.
        loader: sized iterable yielding ``(imgs, targets)`` batches.
        device: device the batches (and class weights) are moved to.

    Returns:
        Sum of the per-batch scores divided by ``len(loader)``.

    Raises:
        ZeroDivisionError: if ``loader`` is empty.
    """
    multi_class = cfg['class_num'] > 1
    # The class weights do not change between batches: build them once.
    weight = torch.tensor(cfg['weight']).to(device) if multi_class else None
    n_val = len(loader)
    tot = 0

    with tqdm(total=n_val, desc='Validation ', unit='batch',
              leave=False) as pbar:
        for imgs, targets in loader:
            imgs = imgs.to(device)
            targets = targets.to(device)

            with torch.no_grad():
                predict = net(imgs)

                if multi_class:
                    tot += F.cross_entropy(predict, targets,
                                           weight=weight).item()
                else:
                    mask = (torch.sigmoid(predict) > 0.5).float()
                    tot += dice_coeff(mask, targets).item()
            pbar.update()
    print('')  # for better display

    return tot / n_val
Пример #2
0
def train_epoch(net, loader, optimizer, cost):
    """Train ``net`` for one pass over ``loader`` and return the average loss.

    Every batch is moved to the GPU, pushed through ``net``, scored with
    ``cost`` and used for one optimizer step. Every 10th batch the current
    loss and the Dice of that batch are printed.

    Args:
        net: model to train; switched to training mode here.
        loader: iterable yielding ``(data, label)`` batches.
        optimizer: optimizer holding the parameters of ``net``.
        cost: loss callable, invoked as ``cost(output, label)``.

    Returns:
        ``batch_loss.avg`` of the ``AvgMeter`` fed with each batch loss.
    """
    # Switch the network to training mode.
    net.train()

    batch_loss = AvgMeter()
    for batch_idx, (data, label) in enumerate(loader):
        # Tensors carry autograd state themselves; the old ``Variable``
        # wrapper is a no-op, so the batches are moved to the GPU directly.
        data = data.cuda()
        label = label.cuda()

        output = net(data)
        loss = cost(output, label)

        # PyTorch accumulates gradients across backward passes, so clear them
        # before back-propagating this batch.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss.update(loss.item())
        if batch_idx % 10 == 0:
            # Dice is only reported here, so the GPU -> CPU -> NumPy round
            # trip is done for logged batches only.
            dice = dice_coeff(output.detach().squeeze().cpu().numpy(),
                              label.squeeze().cpu().numpy())
            print("Train Batch {} || Loss: {:.4f} | Training Dice: {:.4f}".
                  format(str(batch_idx).zfill(4), batch_loss.val, dice))
    return batch_loss.avg
 def loss_func(self):
     """Build the loss ops and store them on ``self``.

     One-hot encodes ``self.y`` along axis 4 with ``self.conf.num_cls``
     classes, computes the loss selected by ``self.conf.loss_type``
     (``'cross-entropy'`` or ``'dice'``) and, when ``self.conf.use_reg`` is
     set, adds ``lmbda``-scaled L2 penalties over the ``'weights'`` collection.

     Sets ``self.total_loss``, ``self.mean_loss`` and ``self.mean_loss_op``.

     NOTE(review): the ``'dice'`` branch uses the value of ``dice_coeff``
     directly as the loss -- confirm that helper returns a quantity to
     minimise.

     Raises:
         ValueError: if ``self.conf.loss_type`` is not a supported value.
     """
     with tf.name_scope('Loss'):
         y_one_hot = tf.one_hot(self.y,
                                depth=self.conf.num_cls,
                                axis=4,
                                name='y_one_hot')
         loss_type = self.conf.loss_type
         if loss_type == 'cross-entropy':
             with tf.name_scope('cross_entropy'):
                 loss = cross_entropy(y_one_hot, self.logits,
                                      self.conf.num_cls)
         elif loss_type == 'dice':
             with tf.name_scope('dice_coefficient'):
                 loss = dice_coeff(y_one_hot, self.logits)
         else:
             # Fail here with a clear message instead of an unbound ``loss``
             # further down.
             raise ValueError(
                 "Unsupported loss_type {!r}; expected 'cross-entropy' "
                 "or 'dice'".format(loss_type))
         with tf.name_scope('total'):
             if self.conf.use_reg:
                 with tf.name_scope('L2_loss'):
                     l2_loss = tf.reduce_sum(self.conf.lmbda * tf.stack([
                         tf.nn.l2_loss(v)
                         for v in tf.get_collection('weights')
                     ]))
                     self.total_loss = loss + l2_loss
             else:
                 self.total_loss = loss
             self.mean_loss, self.mean_loss_op = tf.metrics.mean(
                 self.total_loss)
Пример #4
0
def evaluate(net, loader, device, n_val):
    """Evaluate ``net`` on ``loader`` (no dense CRF) and return the mean score.

    For ``net.n_classes > 1`` the per-image score is the cross-entropy of the
    raw prediction; otherwise it is ``dice_coeff`` of the prediction
    thresholded at 0.5.

    Args:
        net: model exposing ``n_classes``; switched to eval mode here.
        loader: iterable of dicts with ``'image'`` and ``'mask'`` entries.
        device: device the batches are moved to.
        n_val: number of validation images; used as the progress-bar total and
            as the divisor of the accumulated score.

    Returns:
        Accumulated per-image score divided by ``n_val``.

    Raises:
        ZeroDivisionError: if ``n_val`` is 0.
    """
    net.eval()
    multi_class = net.n_classes > 1
    mask_type = torch.float32 if net.n_classes == 1 else torch.long
    tot = 0

    with tqdm(total=n_val, desc='Validation round', unit='img',
              leave=False) as pbar:
        for batch in loader:
            imgs = batch['image'].to(device=device, dtype=torch.float32)
            true_masks = batch['mask'].to(device=device, dtype=mask_type)

            # model forward
            with torch.no_grad():
                masks_pred = net(imgs)

            # evaluation stats, accumulated one image at a time
            for true_mask, pred in zip(true_masks, masks_pred):
                if multi_class:
                    # Cross-entropy needs the raw class scores; binarising
                    # them first (as was done before) corrupts the loss.
                    tot += F.cross_entropy(pred.unsqueeze(dim=0),
                                           true_mask.unsqueeze(dim=0)).item()
                else:
                    # NOTE(review): thresholding at 0.5 assumes the network
                    # already outputs probabilities -- confirm.
                    pred = (pred > 0.5).float()
                    tot += dice_coeff(pred, true_mask.squeeze(dim=1)).item()
            pbar.update(imgs.shape[0])

    return tot / n_val
Пример #5
0
def eval_net(net, dataset, device):
    """Return the mean Dice of ``net`` over the batches of ``dataset``.

    Each prediction is squeezed on dim 1, passed through a sigmoid and
    thresholded at 0.5 before being compared with the true masks on the CPU.

    Args:
        net: model; switched to eval mode here.
        dataset: sized iterable yielding ``(imgs, true_masks)`` batches.
        device: device the images are moved to for the forward pass.

    Returns:
        Sum of the per-batch Dice values divided by ``len(dataset)``.

    Raises:
        ZeroDivisionError: if ``dataset`` is empty.
    """
    net.eval()
    tot = 0.
    with torch.no_grad():
        for imgs, true_masks in tqdm.tqdm(dataset, total=len(dataset)):
            logits = net(imgs.to(device)).squeeze(
                1)  # (b, 1, h, w) -> (b, h, w)
            # torch.sigmoid replaces the deprecated F.sigmoid.
            masks_pred = (torch.sigmoid(logits) > 0.5).float()
            tot += dice_coeff(masks_pred.cpu(), true_masks).item()
    return tot / len(dataset)
Пример #6
0
def evaluate(dataloader, model, thresholds, device):
    """Measure utility (Dice) and image coverage at each threshold on ``B``.

    ``model(images)`` returns a pair whose second element ``B`` is compared
    with each threshold: pixels with ``B <= threshold`` are kept, the rest are
    zeroed, and ``model.util_model`` is scored on the masked images. A second
    pass repeats this with the median of all ``B`` values as the threshold.

    Args:
        dataloader: re-iterable yielding ``(images, masks)`` batches.
        model: model exposing ``util_model``; moved to ``device`` and put in
            eval mode here.
        thresholds: sequence of thresholds applied to ``B``.
        device: device used for the model and the batches.

    Returns:
        ``(dice, coverage, dice_at_half_coverage)``: two lists holding the
        per-threshold mean Dice and mean kept-pixel fraction, and the mean
        Dice at the median-``B`` threshold.
    """
    model.eval()
    model.to(device)
    dice = [[] for _ in thresholds]
    coverage = [[] for _ in thresholds]
    bs = []

    # Pure evaluation: without no_grad every B kept in ``bs`` would also keep
    # its autograd graph alive for the whole pass.
    with torch.no_grad():
        for images, masks in dataloader:
            images = images.to(device)
            masks = masks.to(device)

            _, B = model(images)
            bs.append(B)

            for i, threshold in enumerate(thresholds):
                keep = B <= threshold
                masks_pred = model.util_model(images * keep)
                dice[i].append(
                    float(dice_coeff(masks_pred > 0.0, masks).item()))
                coverage[i].append((keep.sum().float() / B.numel()).item())

        dice = [np.mean(d) for d in dice]
        coverage = [np.mean(c) for c in coverage]

        # The median of all B values keeps about half of the pixels overall.
        median_b = torch.median(torch.cat(bs))
        dice_at_half_coverage = []

        for images, masks in dataloader:
            images = images.to(device)
            masks = masks.to(device)

            _, B = model(images)
            masks_pred = model.util_model(images * (B <= median_b))
            dice_at_half_coverage.append(
                dice_coeff(masks_pred > 0.0, masks).item())

    return dice, coverage, np.mean(dice_at_half_coverage)
Пример #7
0
    def validation_step(self, batch, batch_idx):
        """Score one validation batch on the noise-perturbed images.

        Returns a dict with ``"val_dice"`` (Dice of the utility model's
        prediction thresholded at 0) and ``"val_loss"`` (criterion loss minus
        ``noise_coeff`` times the mean of ``log(B)``).
        """
        images, masks = batch
        noise, B = self.forward(images)

        # The utility model is only evaluated here, so put it in eval mode
        # before feeding it the perturbed images.
        self.util_model.eval()
        perturbed = images + noise
        masks_pred = self.util_model(perturbed)

        task_loss = self.criterion(masks_pred, masks)
        noise_bonus = self.noise_coeff * torch.mean(B.log())

        return {
            "val_dice": dice_coeff(masks_pred > 0.0, masks),
            "val_loss": task_loss - noise_bonus,
        }
Пример #8
0
def test_epoch(net, loader):
    """Run ``net`` over ``loader`` and return the average Dice.

    Each batch is moved to the GPU for the forward pass; prediction and label
    are squeezed and converted to NumPy before ``dice_coeff`` is applied. The
    meter's current value is printed after every batch.

    Args:
        net: model to test; switched to eval mode here.
        loader: iterable yielding ``(data, label)`` batches.

    Returns:
        ``test_dice_meter.avg`` of the ``AvgMeter`` fed with each batch Dice.
    """
    # Imported locally so the block does not rely on a module-level ``torch``.
    import torch

    # Switch the network to test mode.
    net.eval()
    test_dice_meter = AvgMeter()
    # Inference only: skip autograd bookkeeping for the forward passes.
    with torch.no_grad():
        for batch_idx, (data, label) in enumerate(loader):
            output = net(data.cuda())

            output = output.squeeze().cpu().numpy()
            label = label.squeeze().cpu().numpy()

            test_dice_meter.update(dice_coeff(output, label))

            print("Test {} || Dice: {:.4f}".format(str(batch_idx).zfill(4), test_dice_meter.val))
    return test_dice_meter.avg
Пример #9
0
 def validation_step(self, batch, batch_idx):
     """Score one validation batch.

     Returns a dict with the criterion loss (``"val_loss"``) and the Dice of
     the model output thresholded at 0 (``"val_dice"``).
     """
     inputs, targets = batch
     prediction = self.model(inputs)
     metrics = {"val_loss": self.criterion(prediction, targets)}
     metrics["val_dice"] = dice_coeff(prediction > 0.0, targets)
     return metrics
Пример #10
0
def _segmentation_metrics(truth, pred):
    """Return (jaccard, sensitivity, specificity, accuracy, dice) of ``pred`` vs ``truth``."""
    return (
        fjaccard(truth, pred),
        utils.sensitivity(truth, pred),
        utils.specificity(truth, pred),
        utils.accuracy(truth, pred),
        utils.dice_coeff(truth, pred),
    )


def _metrics_caption(heading, jaccard, sensitivity, specificity, accuracy, dice):
    """Format a subplot title: ``heading`` followed by one metric per line."""
    return (heading + "\n" +
            "Jaccard = " + str(jaccard) +
            '\nSensitivity = ' + str(sensitivity) +
            '\nSpecificity = ' + str(specificity) +
            '\nAccuracy = ' + str(accuracy) +
            '\nDice = ' + str(dice))


def evaluate_test(model, test_ds, num_test_examples, cspace, epochs, save_model_path=None, type_train='',write_images=True, it=0):
    """Evaluate a segmentation model on ``test_ds`` and write results to disk.

    For each of ``num_test_examples`` samples the prediction's channel 0 is
    saved as a PNG, thresholded at 0.55 and scored against the label; the same
    metrics are computed for the ``utils.dense_crf`` refinement. Per-sample
    metric vectors and summary files are written under
    ``pos_results/<type_train><cspace>/<epochs>/<it>/``.

    Args:
        model: model to evaluate; replaced by the one loaded from
            ``save_model_path`` when that is given.
        test_ds: dataset exposing ``make_one_shot_iterator()``; each element
            is an ``(images, label)`` pair of which sample 0 is used.
        num_test_examples: number of elements drawn from ``test_ds``.
        cspace: colour-space tag; ``'HSV'`` and ``'LAB'`` inputs are converted
            to RGB before the CRF step. Also part of the output path.
        epochs: only used in the output path.
        save_model_path: optional path of a saved model to load.
        type_train: prefix used in the output path.
        write_images: when true, a 4-panel comparison figure is saved per sample.
        it: iteration tag used in the output path.

    Returns:
        ``(mjccard, score)``: the mean Jaccard, and the mean of the Jaccard
        values where samples below 0.65 count as 0.

    Raises:
        ZeroDivisionError: if ``num_test_examples`` is 0.
    """
    if save_model_path is not None:
        model = models.load_model(
            save_model_path,
            custom_objects={
                'bce_dice_loss': bce_dice_loss,
                'dice_loss': dice_loss
            })

    mask_threshold = 0.55  # probability cut-off for the binary mask
    score_cutoff = 0.65  # Jaccard values below this contribute 0 to the score

    mjccard = 0
    score = 0
    v_jaccard = np.zeros(num_test_examples)
    v_sensitivity = np.zeros(num_test_examples)
    v_specificity = np.zeros(num_test_examples)
    v_accuracy = np.zeros(num_test_examples)
    v_dice = np.zeros(num_test_examples)

    crf_jaccard = np.zeros(num_test_examples)
    crf_sensitivity = np.zeros(num_test_examples)
    crf_specificity = np.zeros(num_test_examples)
    crf_accuracy = np.zeros(num_test_examples)
    crf_dice = np.zeros(num_test_examples)

    data_aug_iter = test_ds.make_one_shot_iterator()
    next_element = data_aug_iter.get_next()

    # Every output lives under one directory; build its path once.
    out_dir = 'pos_results/' + type_train + cspace + '/' + str(epochs) + '/' + str(it) + '/'
    predict_dir = out_dir + 'predict/'
    if not os.path.exists(predict_dir):
        os.makedirs(predict_dir)

    for j in range(num_test_examples):
        # Running next element in our graph will produce a batch of images
        batch_of_imgs, label = tf.keras.backend.get_session().run(next_element)
        img = batch_of_imgs[0]

        predicted_label = model.predict(batch_of_imgs)[0]
        probability_map = predicted_label[:, :, 0]
        # The raw probability map is written once (it used to be written a
        # second time, with identical content, after the figure was saved).
        mpimg.imsave(predict_dir + str(j) + '.png', probability_map)
        mask_pred = (probability_map > mask_threshold).astype(int)
        label = label.astype(int)
        true_mask = label[0, :, :, 0]

        (v_jaccard[j], v_sensitivity[j], v_specificity[j],
         v_accuracy[j], v_dice[j]) = _segmentation_metrics(true_mask, mask_pred)
        score += v_jaccard[j] if v_jaccard[j] >= score_cutoff else 0
        print(score)
        mjccard += v_jaccard[j]

        img_rgb = img[:, :, :3]

        # NOTE(review): these conversions add new ops to the graph on every
        # iteration -- consider building the op once outside the loop.
        if cspace == 'HSV':
            img_rgb = tf.keras.backend.get_session().run(tf.image.hsv_to_rgb(img_rgb))
        elif cspace == 'LAB':
            img_rgb = tf.keras.backend.get_session().run(Conv_img.lab_to_rgb(img_rgb))

        crf_mask = utils.dense_crf(np.array(img_rgb*255).astype(np.uint8), np.array(predicted_label[:, :, 0]).astype(np.float32))

        (crf_jaccard[j], crf_sensitivity[j], crf_specificity[j],
         crf_accuracy[j], crf_dice[j]) = _segmentation_metrics(true_mask, crf_mask)

        if write_images:
            fig = plt.figure(figsize=(25, 25))

            plt.subplot(1, 4, 1)
            plt.imshow(img[:, :, :3])
            plt.title("Input image")

            plt.subplot(1, 4, 2)
            plt.imshow(true_mask)
            plt.title("Actual Mask")

            plt.subplot(1, 4, 3)
            plt.imshow(probability_map > mask_threshold)
            plt.title(_metrics_caption(
                "Predicted Mask", v_jaccard[j], v_sensitivity[j],
                v_specificity[j], v_accuracy[j], v_dice[j]))

            plt.subplot(1, 4, 4)
            plt.imshow(crf_mask)
            plt.title(_metrics_caption(
                "CRF Mask", crf_jaccard[j], crf_sensitivity[j],
                crf_specificity[j], crf_accuracy[j], crf_dice[j]))

            fig.savefig(out_dir + str(j) + '.png', bbox_inches='tight')
            plt.close(fig)
            # Kept from the original: also closes whatever figure is current
            # after ``fig`` has been closed.
            plt.close()

    mjccard /= num_test_examples
    score /= num_test_examples

    for name, values in (('jaccard', v_jaccard),
                         ('sensitivity', v_sensitivity),
                         ('specificity', v_specificity),
                         ('accuracy', v_accuracy),
                         ('dice', v_dice)):
        np.savetxt(out_dir + name, values)
    with open(out_dir + 'score', 'w') as f:
        f.write('Score = ' + str(score) +
                '\nSensitivity = ' + str(np.mean(v_sensitivity)) +
                '\nSpecificity = ' + str(np.mean(v_specificity)) +
                '\nAccuracy = ' + str(np.mean(v_accuracy)) +
                '\nDice = ' + str(np.mean(v_dice)) +
                '\nJaccars = ' + str(np.mean(v_jaccard)))

    # NOTE(review): 'crf_crf_specificity' looks like a typo, but the file name
    # is kept because other tooling may read it.
    for name, values in (('crf_jaccard', crf_jaccard),
                         ('crf_sensitivity', crf_sensitivity),
                         ('crf_crf_specificity', crf_specificity),
                         ('crf_accuracy', crf_accuracy),
                         ('crf_dice', crf_dice)):
        np.savetxt(out_dir + name, values)
    with open(out_dir + 'crf_score', 'w') as f:
        f.write('Sensitivity = ' + str(np.mean(crf_sensitivity)) +
                '\nSpecificity = ' + str(np.mean(crf_specificity)) +
                '\nAccuracy = ' + str(np.mean(crf_accuracy)) +
                '\nDice = ' + str(np.mean(crf_dice)) +
                '\nJaccars = ' + str(np.mean(crf_jaccard)))

    print('Jccard = ' + str(mjccard))
    print('Score = ' + str(score))
    return mjccard, score
Пример #11
0
def _save_denormalized_image(directory, filename, tensor):
    """Denormalize ``tensor`` and write it to ``directory/filename`` with OpenCV."""
    # NOTE(review): the array is reshaped (not transposed) to 256x256x3 --
    # confirm ``denormalize`` already returns channel-last data.
    cv2.imwrite(
        os.path.join(directory, filename),
        np.array(denormalize(tensor).cpu().detach()).reshape(256, 256, 3))


def train(opt, model_name):
    """Train the model selected by ``model_name``.

    Supported values:

    * ``'pix2pix'``  -- ``Pix2Pix(opt)`` trained for ``opt.epochs`` epochs on
      ``dataloader[0]``; after each training phase ``dataloader[1]`` is used
      for validation, predictions are written to ``../results`` and the mean
      validation Dice is printed.
    * ``'CycleGAN'`` -- ``cycleGan(cg_opt)`` trained for 200 epochs on
      ``cg_train_loader``; the learning rate is halved at epochs 35, 80, 130.
    * ``'P2P'``      -- ``Pix2Pix(p2p_opt)`` trained for 3000 epochs on
      ``p2p_train_loader``; the learning rate is halved at epochs 299 and 499.

    The last two branches save the whole model and the loss histories after
    every epoch and, every 10th epoch, write the generator output for one
    fixed validation batch.

    Any other ``model_name`` does nothing. The loaders, option objects,
    ``device``, ``normalization`` and ``denormalize`` are taken from module
    scope.

    Args:
        opt: options for the ``'pix2pix'`` branch (``opt.epochs`` is read here).
        model_name: one of the names listed above.
    """
    if model_name == 'pix2pix':  # This is for Arpit's version of Pix2Pix
        model = Pix2Pix(opt)
        for epoch in range(opt.epochs):
            since = time.time()
            print('Epoch ' + str(epoch) + ' running')
            for phase in range(2):  # 0: training, 1: validation
                training = phase == 0
                val_dice = 0
                count = 0
                for i, Data in enumerate(dataloader[phase]):
                    inputs, masks = Data
                    inputs, masks = inputs.to(device), masks.to(device)
                    inputs = normalization(inputs)
                    masks = normalization(masks)
                    # Tensors track gradients themselves, so the no-op
                    # ``Variable`` wrapper is not needed.
                    Data = inputs, masks
                    with torch.set_grad_enabled(training):
                        model.get_input(Data)
                        if training:
                            model.optimize()
                        else:
                            pred_mask = model.forward(inputs)

                            for j in range(pred_mask.size()[0]):
                                _save_denormalized_image(
                                    '../results/pred_masks',
                                    'mask_{}_{}_{}.png'.format(i, j, epoch),
                                    pred_mask[j])
                                _save_denormalized_image(
                                    '../results/inputs',
                                    'input_{}_{}_{}.png'.format(i, j, epoch),
                                    inputs[j])

                            val_dice += dice_coeff(
                                denormalize(pred_mask, flag=1),
                                denormalize(masks, flag=1))
                            count += 1
            # NOTE(review): raises ZeroDivisionError when the validation
            # loader yields no batch.
            print("Validation Dice Coefficient is " + str(val_dice / count))
            time_elapsed = time.time() - since
            print('Epoch completed in {:.0f}m {:.0f}s'.format(
                time_elapsed // 60, time_elapsed % 60))

    elif model_name == 'CycleGAN':
        model = cycleGan(cg_opt)
        train_iter = iter(cg_train_loader)
        val_iter = iter(cg_val_loader)
        # ``next(it)`` is the iterator protocol; ``it.next()`` is not
        # available on current DataLoader iterators.
        fixed_X, fixed_Y = next(val_iter)
        fixed_X = normalization(fixed_X).to(device)
        fixed_Y = normalization(fixed_Y).to(device)
        loss_Gl = []
        loss_DXl = []
        loss_DYl = []

        num_batches = len(train_iter)
        for epoch in range(200):
            # Halve the learning rate at these epochs.
            if epoch in (35, 80, 130):
                model.change_lr(model.opt.lr / 2)

            since = time.time()
            print("Epoch ", epoch, " entering ")
            train_iter = iter(cg_train_loader)
            for batch in range(num_batches):
                print("Epoch ", epoch, "Batch ", batch,
                      " running with learning rate ", model.opt.lr)
                inputX, inputY = next(train_iter)
                inputX = normalization(inputX).to(device)
                inputY = normalization(inputY).to(device)
                model.get_input(inputX, inputY)
                model.optimize()
                print("Model dx loss ", float(model.loss_D_X), "Model dy loss",
                      float(model.loss_D_Y), "model_gen_loss",
                      float(model.loss_G))

            if (epoch + 1) % 10 == 0:
                # Output folders depend on the generator depth.
                if cg_opt.n_blocks == 6:
                    mask_dir = '../cgresults/pred_masks'
                    input_dir = '../cgresults/inputs'
                else:
                    mask_dir = '../cgresults/r-9-pred_masks'
                    input_dir = '../cgresults/r-9-inputs'

                depth_map = model.G_XtoY.forward(fixed_X)
                for j in range(depth_map.size()[0]):
                    # ``batch`` is the index of the last training batch.
                    _save_denormalized_image(
                        mask_dir,
                        'mask_{}_{}_{}.png'.format(batch, j, epoch),
                        depth_map[j])
                    if epoch == 9:
                        # The fixed inputs never change: write them once.
                        _save_denormalized_image(
                            input_dir,
                            'input_{}_{}_{}.png'.format(batch, j, epoch),
                            fixed_X[j])

            print("Time to finish epoch ", time.time() - since)

            torch.save(model, '../CGmodel/best_model5.pt')
            loss_Gl.append(float(model.loss_G))
            loss_DXl.append(float(model.loss_D_X))
            loss_DYl.append(float(model.loss_D_Y))
            with open('../CGloss/lossG5.pk', 'wb') as f:
                pickle.dump(loss_Gl, f)
            with open('../CGloss/lossD_X5.pk', 'wb') as f:
                pickle.dump(loss_DXl, f)
            with open('../CGloss/lossd_Y5.pk', 'wb') as f:
                pickle.dump(loss_DYl, f)

    elif model_name == 'P2P':  # This is for Khem's version of Pix2Pix
        model = Pix2Pix(p2p_opt)
        train_iter = iter(p2p_train_loader)
        val_iter = iter(p2p_val_loader)
        fixed_X, fixed_Y = next(val_iter)
        fixed_X = normalization(fixed_X).to(device)
        fixed_Y = normalization(fixed_Y).to(device)
        loss_Gl = []
        loss_Dl = []

        num_batches = len(train_iter)
        for epoch in range(3000):
            # Halve the learning rate at these epochs.
            if epoch in (299, 499):
                model.change_lr(model.opt.lr / 2)

            since = time.time()
            print("Epoch ", epoch, " entering ")
            train_iter = iter(p2p_train_loader)
            for batch in range(num_batches):
                print("Epoch ", epoch, "Batch ", batch,
                      " running with learning rate ", model.opt.lr)
                inputX, inputY = next(train_iter)
                inputX = normalization(inputX).to(device)
                inputY = normalization(inputY).to(device)
                model.get_input(inputX, inputY)
                model.optimize()
                print("Model D Loss ", float(model.loss_D), "Model G loss",
                      float(model.loss_G))

            if (epoch + 1) % 10 == 0:
                depth_map = model.G.forward(fixed_X)
                for j in range(depth_map.size()[0]):
                    # ``batch`` is the index of the last training batch.
                    _save_denormalized_image(
                        '../p2presults/pred_masks',
                        'mask_{}_{}_{}.png'.format(batch, j, epoch),
                        depth_map[j])
                    if epoch == 9:
                        # Inputs and ground truth never change: write once.
                        _save_denormalized_image(
                            '../p2presults/inputs',
                            'input_{}_{}_{}.png'.format(batch, j, epoch),
                            fixed_X[j])
                        _save_denormalized_image(
                            '../p2presults/inputs',
                            'ground_depth_{}_{}_{}.png'.format(
                                batch, j, epoch),
                            fixed_Y[j])

            print("Time to finish epoch ", time.time() - since)

            torch.save(model, '../P2Pmodel/best_model8.pt')
            loss_Gl.append(float(model.loss_G))
            loss_Dl.append(float(model.loss_D))
            with open('../P2Ploss/lossG8.pk', 'wb') as f:
                pickle.dump(loss_Gl, f)
            with open('../P2Ploss/lossD8.pk', 'wb') as f:
                pickle.dump(loss_Dl, f)
Пример #12
0
def main():
    """Segment one CT volume with a pre-trained network and save the labels.

    Command-line options: ``-dataset`` (``tcia`` or ``visceral``), ``-model``
    (checkpoint file; the network type is the part of its base name before
    the first ``_``), ``-input`` (nii.gz volume), ``-output`` (nii.gz label
    file) and the optional ``-groundtruth`` (nii.gz segmentation, used to
    print Dice scores).

    The volume is scaled / pooled as required by the selected network, the
    checkpoint is loaded, the prediction is resized back to the input size
    for ``visceral`` and its arg-max over dim 1 is written to ``-output``.
    An unsupported dataset/model combination prints a message and exits.
    """
    # read/parse user command line input
    parser = argparse.ArgumentParser()

    # The option string used to be "-dataset " (trailing space), which only
    # matched through argparse's prefix abbreviation.
    parser.add_argument("-dataset", dest="dataset", help="either tcia or visceral", default='tcia', required=True)
    parser.add_argument("-model", dest="model", help="filename of pytorch pth model", default='obeliskhybrid', required=True)
    parser.add_argument("-input", dest="input",  help="nii.gz CT volume to segment", required=True)
    parser.add_argument("-output", dest="output",  help="nii.gz label output prediction", default=None, required=True)
    parser.add_argument("-groundtruth", dest="groundtruth",  help="nii.gz groundtruth segmentation", default=None, required=False)

    options = parser.parse_args()
    d_options = vars(options)
    dataset = d_options['dataset']
    modelfilename = os.path.basename(d_options['model'])
    modelname = split_at(modelfilename, '_', 1)[0]
    print('input CT image',d_options['input'],'\n   and model name',modelname,'for dataset',d_options['dataset'])

    # np.asanyarray(img.dataobj) is the documented replacement for the
    # deprecated nibabel get_data().
    input_volume = np.asanyarray(nib.load(d_options['input']).dataobj)
    img_val = torch.from_numpy(input_volume).float().unsqueeze(0).unsqueeze(0)

    # ``net`` stays None when the dataset/model combination is not supported.
    net = None
    if dataset == 'tcia':
        if modelname == 'obeliskhybrid':
            img_val = img_val/1024.0 + 1.0 #scale data
            net = obeliskhybrid_tcia(9) #has 8 anatomical foreground labels
        elif modelname == 'allconvunet':
            #no scaling done for unet models
            net = allconvunet_tcia(9) #has 8 anatomical foreground labels
        elif modelname == 'globalfcnet':
            net = globalfcnet_tcia(9) #has 8 anatomical foreground labels

    elif dataset == 'visceral':
        # Remember the input size: the prediction is resized back to it.
        _,_,D_in0,H_in0,W_in0 = img_val.size()
        if modelname == 'obeliskhybrid':
            img_val = img_val/1000.0
            with torch.no_grad():
                #subsample by factor of 2 (higher resolution in our original data)
                img_val = F.avg_pool3d(img_val,3,padding=1,stride=2)
            _,_,D_in1,H_in1,W_in1 = img_val.size()
            full_res = torch.Tensor([D_in1,H_in1,W_in1]).long()
            net = obeliskhybrid_visceral(8,full_res) #has 7 anatomical foreground labels
        elif modelname == 'obelisk':
            img_val = img_val/500.0
            img_val = F.avg_pool3d(img_val,5,stride=1,padding=2)
            img_val = F.avg_pool3d(img_val,5,stride=1,padding=2)
            img_val = F.avg_pool3d(img_val,3,stride=1,padding=1)
            _,_,D_in1,H_in1,W_in1 = img_val.size()
            full_res = torch.Tensor([D_in1,H_in1,W_in1]).long()
            net = obelisk_visceral(8,full_res) #has 7 anatomical foreground labels
        elif modelname == 'allconvunet':
            with torch.no_grad():
                #subsample by factor of 2 (higher resolution in our original data)
                img_val = F.avg_pool3d(img_val,3,padding=1,stride=2)
            net = allconvunet_visceral() #has 7 anatomical foreground labels
        elif modelname == 'globalfcnet':
            net = globalfcnet_visceral() #has 7 anatomical foreground labels

    if net is None:
        print('model',modelname,'for dataset',d_options['dataset'],'not yet supported. exit()')
        exit()

    # map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only
    # machine; the network is moved to the GPU below when one is available.
    # SECURITY: torch.load unpickles the file -- only load trusted checkpoints.
    net.load_state_dict(torch.load(d_options['model'], map_location='cpu'))
    print('read in model with',countParam(net),'parameters')

    net.eval()

    if torch.cuda.is_available():
        print('using GPU acceleration')
        img_val = img_val.cuda()
        net.cuda()
    with torch.no_grad():
        predict = net(img_val)
        if dataset == 'visceral':
            predict = F.interpolate(predict,size=[D_in0,H_in0,W_in0], mode='trilinear', align_corners=False)

    argmax = torch.argmax(predict,dim=1)
    # NOTE(review): the labels are saved with an identity affine, not the
    # affine of the input volume -- confirm this is intended.
    seg_img = nib.Nifti1Image(argmax.cpu().short().squeeze().numpy(), np.eye(4))
    print('saving nifti file with labels')
    nib.save(seg_img, d_options['output'])

    if d_options['groundtruth'] is not None:
        truth_volume = np.asanyarray(nib.load(d_options['groundtruth']).dataobj)
        seg_val = torch.from_numpy(truth_volume).long().unsqueeze(0)
        dice = dice_coeff(argmax.cpu(), seg_val, predict.size(1)).numpy()
        np.set_printoptions(formatter={'float': '{: 0.3f}'.format})
        print('Dice validation:',dice,'Avg.','%0.3f'%(dice.mean()))
Пример #13
0
def main():
    """Train the `obeliskhybrid_tcia` network on scans given on the command line.

    Command line arguments (all required):
        -folder       directory that holds the scans and segmentations
        -scannumbers  whitespace separated scan numbers, e.g. "1 2 3"
        -filescan     scan filename pattern, "?" is replaced by the scan number
        -fileseg      segmentation filename pattern, "?" is replaced likewise
        -output       output filename stem (no extension)

    Side effects:
        * ``sys.stdout`` is replaced by a ``Logger`` writing to
          ``<output>_log.txt``.
        * every 3rd epoch the per-scan Dice scores on the training scans are
          written to ``<output>.mat``.
        * every 6th epoch the network weights are written to ``<output>.pth``.
    """
    #read/parse user command line input
    parser = argparse.ArgumentParser()

    parser.add_argument("-folder",
                        dest="folder",
                        help="training dataset folder",
                        default='TCIA_CT',
                        required=True)
    parser.add_argument(
        "-scannumbers",
        dest="scannumbers",
        help="list of integers indicating which scans to use, \"1 2 3\" ",
        default=1,
        required=True,
        type=lambda s: [int(n) for n in s.split()])
    parser.add_argument(
        "-filescan",
        dest="filescan",
        help="prototype scan filename i.e. pancreas_ct?.nii.gz",
        default='pancreas_ct?.nii.gz',
        required=True)
    parser.add_argument(
        "-fileseg",
        dest="fileseg",
        help="prototype segmentation name i.e. label_ct?.nii.gz",
        required=True)
    parser.add_argument("-output",
                        dest="output",
                        help="filename (without extension) for output",
                        default=None,
                        required=True)

    options = parser.parse_args()
    d_options = vars(options)

    sys.stdout = Logger(d_options['output'] + '_log.txt')

    # load train images and segmentations
    imgs = []
    segs = []
    scannumbers = d_options['scannumbers']
    print('scannumbers', scannumbers)
    if "?" not in d_options['filescan']:
        print('error filescan must contain \"?\" to insert numbers')
        sys.exit(1)
    # the segmentation pattern is split the same way, so it needs the
    # placeholder as well
    if "?" not in d_options['fileseg']:
        print('error fileseg must contain \"?\" to insert numbers')
        sys.exit(1)
    filesplit = split_at(d_options['filescan'], '?', 1)
    filesegsplit = split_at(d_options['fileseg'], '?', 1)

    for number in scannumbers:
        filescan1 = filesplit[0] + str(number) + filesplit[1]
        fileseg1 = filesegsplit[0] + str(number) + filesegsplit[1]
        # np.asanyarray(img.dataobj) is the documented drop-in replacement for
        # the deprecated nibabel `get_data()`
        img = np.asanyarray(
            nib.load(os.path.join(d_options['folder'], filescan1)).dataobj)
        seg = np.asanyarray(
            nib.load(os.path.join(d_options['folder'], fileseg1)).dataobj)
        # images get batch + channel dimensions, segmentations only batch
        imgs.append(torch.from_numpy(img).unsqueeze(0).unsqueeze(0).float())
        segs.append(torch.from_numpy(seg).unsqueeze(0).long())

    imgs = torch.cat(imgs, 0)
    segs = torch.cat(segs, 0)
    imgs = imgs / 1024.0 + 1.0  #scale data

    numEpoches = 300  #1000
    batchSize = 4

    print('data loaded')

    fold_size = imgs.size(0)
    if fold_size < batchSize:
        # otherwise the per-epoch index tensor is empty and cannot be
        # reshaped into batches
        print('error need at least', batchSize, 'scans to form a batch, got',
              fold_size)
        sys.exit(1)

    # inverse square-root frequency weighting of the labels; background
    # (label 0) is set to a fixed weight afterwards
    class_weight = torch.sqrt(1.0 / (torch.bincount(segs.view(-1)).float()))
    class_weight = class_weight / class_weight.mean()
    class_weight[0] = 0.5
    np.set_printoptions(formatter={'float': '{: 0.2f}'.format})
    print('inv sqrt class_weight', class_weight.data.cpu().numpy())

    num_labels = int(class_weight.numel())

    net = obeliskhybrid_tcia(num_labels)
    net.apply(init_weights)
    print('obelisk params', countParam(net))

    print('initial offset std', '%.3f' % (torch.std(net.offset1.data).item()))
    net.cuda(cuda_idx)

    my_criterion = my_ohem(.25, class_weight.cuda())  #0.25

    optimizer = optim.Adam(net.parameters(), lr=0.002, weight_decay=0.00001)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.99)

    run_loss = np.zeros(numEpoches)

    dice_epoch = np.zeros((fold_size, num_labels - 1, numEpoches))
    # number of scans usable per epoch (largest multiple of the batch size)
    fold_size4 = fold_size - fold_size % batchSize
    print('fold/batch sizes', fold_size, fold_size4, imgs.size(0))
    #for loop over iterations and epochs
    for epoch in range(numEpoches):

        net.train()

        run_loss[epoch] = 0.0

        # random scan order; each column of idx_epoch is one mini-batch
        idx_epoch = torch.randperm(fold_size)[:fold_size4].view(batchSize, -1)
        t0 = time.time()

        for it in range(idx_epoch.size(1)):
            idx = idx_epoch[:, it]

            # augmentation needs no gradients
            with torch.no_grad():
                imgs_cuda, y_label = augmentAffine(
                    imgs[idx, :, :, :, :].cuda(),
                    segs[idx, :, :, :].cuda(),
                    strength=0.075)
                torch.cuda.empty_cache()

            optimizer.zero_grad()

            #forward path and loss
            predict = net(imgs_cuda)

            loss = my_criterion(F.log_softmax(predict, dim=1), y_label)
            loss.backward()

            run_loss[epoch] += loss.item()
            optimizer.step()
            del loss
            del predict
            torch.cuda.empty_cache()
            del imgs_cuda
            del y_label
            torch.cuda.empty_cache()
        scheduler.step()

        #evaluation on training images
        t1 = time.time() - t0
        net.eval()

        if (epoch % 3 == 0):
            # inference only: do not build autograd graphs for full volumes
            with torch.no_grad():
                for testNo in range(fold_size):
                    imgs_cuda = (imgs[testNo:testNo + 1, :, :, :, :]).cuda()

                    t0 = time.time()
                    predict = net(imgs_cuda)

                    argmax = torch.max(predict, dim=1)[1]
                    torch.cuda.synchronize()
                    time_i = (time.time() - t0)
                    dice_all = dice_coeff(argmax.cpu(),
                                          segs[testNo:testNo + 1, :, :, :],
                                          num_labels)
                    dice_epoch[testNo, :, epoch] = dice_all.numpy()
                    del predict
                    del imgs_cuda
                    torch.cuda.empty_cache()

            #print some feedback information
            # 'time inf' is the inference time of the last evaluated scan
            print('epoch', epoch, 'time train', '%.3f' % t1, 'time inf',
                  '%.3f' % time_i, 'loss', '%.3f' % (run_loss[epoch]),
                  'stddev', '%.3f' % (torch.std(net.offset1.data).item()))
            np.set_printoptions(formatter={'float': '{: 0.3f}'.format})
            print('dice_avgs (training)',
                  (np.nanmean(dice_epoch[:, :, epoch], 0) * 100.0))
            sys.stdout.saveCurrentResults()
            arr = {}
            arr['dice_epoch'] = dice_epoch

            scipy.io.savemat(d_options['output'] + '.mat', arr)

        if (epoch % 6 == 0):
            # store CPU weights, then return the model to the device it was
            # originally placed on
            net.cpu()

            torch.save(net.state_dict(), d_options['output'] + '.pth')

            net.cuda(cuda_idx)
Пример #14
0
def val(model, dataloader):
    """Evaluate `model` on `dataloader` and pick the best decision threshold.

    The class-1 softmax probability of every sample is thresholded at each
    value in 0.1 ... 0.9 to build one 2x2 confusion matrix per threshold
    (indexed ``[true label][predicted label]``). The threshold that maximises
    ``rate_0 + rate_1 - 1`` (the per-class correct rates) is selected.

    Args:
        model: callable as ``model(x=image)`` returning ``(score, score_mask)``.
        dataloader: iterable yielding ``(image, label, edge_mask, image_path)``.

    Returns:
        tuple ``(Best_T, cm, val_spse, val_accuracy, AUC, val_dice)`` where
        ``Best_T`` is the chosen threshold as a string key, ``cm`` its
        confusion matrix, ``val_spse`` the per-class correct rates in percent
        for label 0 and label 1, ``val_accuracy`` the accuracy in percent,
        ``AUC`` the trapezoidal area under the sampled ROC points and
        ``val_dice`` the mean Dice over all batches.
    """
    # ============================ Prepare Metrics ==========================
    T = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
    # cm[str(threshold)][true label][predicted label]
    cm = {str(t): [[0, 0], [0, 0]] for t in T}
    dice = []
    softmax = functional.softmax

    def _rate(hits, total):
        # A class that never occurs has an undefined rate; report 0.0 instead
        # of aborting the whole run with ZeroDivisionError.
        return hits / total if total else 0.0

    # ================================ Validate ==============================
    for image, label, edge_mask, _image_path in tqdm(dataloader):
        # ******************* prepare input and go through the model *******************
        if config.use_gpu:
            image = image.cuda()
            edge_mask = edge_mask.cuda()

        # pure inference: no autograd graph is needed
        with torch.no_grad():
            score, score_mask = model(x=image)
            # move once to the CPU so the per-sample comparisons below do not
            # trigger a device synchronisation each
            probs = softmax(score, dim=1).cpu()
        labels = label.cpu()

        # *********************** confusion matrix and dice ***********************
        for p, l in zip(probs, labels):
            true_cls = int(l)
            for t in T:
                pred_cls = 1 if p[1] >= t else 0
                cm[str(t)][true_cls][pred_cls] += 1
        # NOTE(review): score_mask is thresholded at 0.5 as-is; if the model
        # emits raw logits rather than probabilities this is not the 0.5
        # probability cut-off - confirm against the model definition.
        dice.append(
            dice_coeff(input=(score_mask > 0.5).float(),
                       target=edge_mask[:, 0, :, :]).item())

    # ============================ Calculate ROC Curve and Best Threshold ==========================
    # ROC[key] = [correct rate of label 0, correct rate of label 1]
    ROC = {
        key: [
            _rate(m[0][0], m[0][0] + m[0][1]),
            _rate(m[1][1], m[1][0] + m[1][1]),
        ]
        for key, m in cm.items()
    }
    # first threshold with the maximal score (ties resolve to the lowest one)
    Best_T = max(ROC, key=lambda key: ROC[key][0] + ROC[key][1] - 1)
    best_cm = cm[Best_T]
    val_accuracy = 100. * sum(
        best_cm[c][c] for c in range(config.num_classes)) / np.sum(best_cm)
    val_spse = [
        _rate(100. * best_cm[0][0], best_cm[0][0] + best_cm[0][1]),
        _rate(100. * best_cm[1][1], best_cm[1][0] + best_cm[1][1]),
    ]

    # ==================================== Calculate AUC ===========================================
    # trapezoids between neighbouring thresholds ...
    AUC = 0
    for i in range(len(T) - 1):
        AUC += (ROC[str(T[i])][1] + ROC[str(T[i + 1])][1]) / 2 * (
            ROC[str(T[i + 1])][0] - ROC[str(T[i])][0])
    # ... plus the two end pieces that connect the sampled curve to the
    # corner points (0, 1) and (1, 0)
    AUC += (1 + ROC[str(T[0])][1]) * ROC[str(T[0])][0] / 2
    AUC += ROC[str(T[-1])][1] * (1 - ROC[str(T[-1])][0]) / 2

    # an empty dataloader yields no Dice values
    val_dice = sum(dice) / len(dice) if dice else 0.0

    return Best_T, cm[Best_T], val_spse, val_accuracy, AUC, val_dice
Пример #15
0
def train(**kwargs):
    """Train `UNet_Classifier` with a joint classification + edge-mask loss.

    Keyword arguments are forwarded to ``config.parse`` and override the
    global configuration.

    Per epoch the function trains on ``config.train_paths``, validates on
    ``config.test_paths`` through ``val`` and reports metrics to the
    ``Visualizer`` and to stdout. After the 5th epoch, whenever the validation
    AUC improves, the model is saved to
    ``checkpoints/<save_model_dir>/<save_model_name[:-4]>/<save_model_name>``;
    once that directory exists, the metric history is written next to it as
    ``process_record.json`` after every validation.
    """
    config.parse(kwargs)

    # ============================================ Visualization =============================================
    vis = Visualizer(port=2333, env=config.env)
    vis.log('Use config:')
    for k in config.__class__.__dict__:
        if not k.startswith('__'):
            vis.log(f"{k}: {getattr(config, k)}")

    # ============================================= Prepare Data =============================================
    train_data = SlideWindowDataset(config.train_paths,
                                    phase='train',
                                    useRGB=config.useRGB,
                                    usetrans=config.usetrans,
                                    balance=config.data_balance)
    val_data = SlideWindowDataset(config.test_paths,
                                  phase='val',
                                  useRGB=config.useRGB,
                                  usetrans=config.usetrans,
                                  balance=False)
    print('Training Images:', len(train_data), 'Validation Images:',
          len(val_data))
    dist = train_data.dist()
    print('Train Data Distribution:', dist)

    train_dataloader = DataLoader(train_data,
                                  batch_size=config.batch_size,
                                  shuffle=True,
                                  num_workers=config.num_workers)
    val_dataloader = DataLoader(val_data,
                                batch_size=config.batch_size,
                                shuffle=False,
                                num_workers=config.num_workers)

    # ============================================= Prepare Model ============================================
    model = UNet_Classifier(num_classes=config.num_classes)
    print(model)

    if config.load_model_path:
        model.load(config.load_model_path)
        print('Model loaded')
    if config.use_gpu:
        model.cuda()
    if config.parallel:
        model = torch.nn.DataParallel(
            model, device_ids=list(range(config.num_of_gpu)))
    # the underlying network (unwrapped from DataParallel) owns
    # `model_name` and `save`
    base_model = model.module if config.parallel else model

    # =========================================== Criterion and Optimizer =====================================
    # Class weights for the classification loss. Previously tried
    # alternatives: [1, 1], [1, 3.5], [1, 5] and weights derived from the
    # class distribution `dist` (swapped between the two classes; for more
    # than two classes the reciprocal frequency can be used).
    weight = torch.Tensor([1, 7])

    vis.log(f'loss weight: {weight}')
    print('loss weight:', weight)
    if config.use_gpu:
        # keep the weights on the same device as the model outputs
        weight = weight.cuda()
    criterion = torch.nn.CrossEntropyLoss(weight=weight)
    lr = config.lr
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr,
                                 weight_decay=config.weight_decay)

    # ================================================== Metrics ===============================================
    softmax = functional.softmax
    loss_meter_edge = meter.AverageValueMeter()
    epoch_loss_edge = meter.AverageValueMeter()
    loss_meter_cls = meter.AverageValueMeter()
    epoch_loss_cls = meter.AverageValueMeter()
    loss_meter = meter.AverageValueMeter()
    epoch_loss = meter.AverageValueMeter()
    train_cm = meter.ConfusionMeter(config.num_classes)

    # ====================================== Saving and Recording Configuration =================================
    previous_auc = 0
    save_model_dir = config.save_model_dir if config.save_model_dir else base_model.model_name
    save_model_name = config.save_model_name if config.save_model_name else base_model.model_name + '_best_model.pth'
    # single source of truth for where the model and the record file go
    checkpoint_dir = os.path.join('checkpoints', save_model_dir,
                                  save_model_name[:-4])
    save_epoch = 1  # epoch of the model that performed best on the validation set
    process_record = {
        'epoch_loss': [],
        'epoch_loss_edge': [],
        'epoch_loss_cls': [],
        'train_avg_se': [],
        'train_se_0': [],
        'train_se_1': [],
        'val_avg_se': [],
        'val_se_0': [],
        'val_se_1': [],
        'AUC': [],
        'DICE': []
    }  # curves collected during the experiment, kept for plotting
    val_every = 1  # validate every `val_every` epochs

    # ================================================== Training ===============================================
    for epoch in range(config.max_epoch):
        # use the resolved name: config.save_model_name may be unset
        print(
            f"epoch: [{epoch + 1}/{config.max_epoch}] {save_model_name[:-4]} =================================="
        )
        train_cm.reset()
        # all three epoch meters must start fresh, otherwise the edge/cls
        # values are averages over every epoch so far
        epoch_loss.reset()
        epoch_loss_edge.reset()
        epoch_loss_cls.reset()
        dice = []

        # ****************************************** train ****************************************
        model.train()
        for i, (image, label, edge_mask,
                image_path) in enumerate(tqdm(train_dataloader)):
            # the batch meters hold the current batch only
            loss_meter.reset()
            loss_meter_edge.reset()
            loss_meter_cls.reset()

            # ------------------------------------ prepare input ------------------------------------
            if config.use_gpu:
                image = image.cuda()
                label = label.cuda()
                edge_mask = edge_mask.cuda()

            # ---------------------------------- go through the model --------------------------------
            score, score_mask = model(x=image)

            # ----------------------------------- backpropagate -------------------------------------
            optimizer.zero_grad()

            # classification loss
            loss_cls = criterion(score, label)
            # loss on the pixels that belong to the edge
            log_prob_mask = functional.logsigmoid(score_mask)
            count_edge = torch.sum(edge_mask, dim=(1, 2, 3), keepdim=True)
            loss_edge = -1 * torch.mean(
                torch.sum(
                    edge_mask * log_prob_mask, dim=(1, 2, 3), keepdim=True) /
                (count_edge + 1e-8))

            # loss on the pixels that do not belong to the edge
            r_prob_mask = 1.0 - torch.sigmoid(score_mask)
            r_edge_mask = 1.0 - edge_mask
            log_rprob_mask = torch.log(r_prob_mask + 1e-5)
            count_redge = torch.sum(r_edge_mask, dim=(1, 2, 3), keepdim=True)
            loss_redge = -1 * torch.mean(
                torch.sum(r_edge_mask * log_rprob_mask,
                          dim=(1, 2, 3),
                          keepdim=True) / (count_redge + 1e-8))

            # weights follow the number of foreground / background pixels
            n_edge = torch.sum(count_edge).item()
            n_redge = torch.sum(count_redge).item()
            w1 = n_edge / (n_edge + n_redge)
            w2 = n_redge / (n_edge + n_redge)
            loss = loss_cls + w1 * loss_edge + w2 * loss_redge

            loss.backward()
            optimizer.step()

            # ------------------------------------ record loss ------------------------------------
            edge_loss_value = (w1 * loss_edge + w2 * loss_redge).item()
            cls_loss_value = loss_cls.item()
            loss_value = loss.item()
            loss_meter_edge.add(edge_loss_value)
            epoch_loss_edge.add(edge_loss_value)
            loss_meter_cls.add(cls_loss_value)
            epoch_loss_cls.add(cls_loss_value)
            loss_meter.add(loss_value)
            epoch_loss.add(loss_value)
            train_cm.add(softmax(score, dim=1).detach(), label.detach())
            dice.append(
                dice_coeff(input=(score_mask > 0.5).float(),
                           target=edge_mask[:, 0, :, :]).item())

            if (i + 1) % config.print_freq == 0:
                vis.plot_many({
                    'loss': loss_meter.value()[0],
                    'loss_edge': loss_meter_edge.value()[0],
                    'loss_cls': loss_meter_cls.value()[0]
                })

        # per-class correct rate (in percent) on the training data
        cm_train = train_cm.value()
        train_se = [
            100. * cm_train[0][0] / (cm_train[0][0] + cm_train[0][1]),
            100. * cm_train[1][1] / (cm_train[1][0] + cm_train[1][1])
        ]
        train_dice = sum(dice) / len(dice)

        # *************************************** validate ***************************************
        model.eval()
        if (epoch + 1) % val_every == 0:
            # validation is inference only
            with torch.no_grad():
                Best_T, val_cm, val_spse, val_accuracy, AUC, val_dice = val(
                    model, val_dataloader)

            # ------------------------------------ save model ------------------------------------
            # after 5 epochs, save the model whenever the validation AUC improves
            if AUC > previous_auc and epoch + 1 > 5:
                os.makedirs(checkpoint_dir, exist_ok=True)
                base_model.save(os.path.join(checkpoint_dir, save_model_name))
                previous_auc = AUC
                save_epoch = epoch + 1

            # ---------------------------------- record and print ---------------------------------
            process_record['epoch_loss'].append(epoch_loss.value()[0])
            process_record['epoch_loss_edge'].append(
                epoch_loss_edge.value()[0])
            process_record['epoch_loss_cls'].append(epoch_loss_cls.value()[0])
            process_record['train_avg_se'].append(np.average(train_se))
            process_record['train_se_0'].append(train_se[0])
            process_record['train_se_1'].append(train_se[1])
            process_record['val_avg_se'].append(np.average(val_spse))
            process_record['val_se_0'].append(val_spse[0])
            process_record['val_se_1'].append(val_spse[1])
            process_record['AUC'].append(AUC)
            process_record['DICE'].append(val_dice)

            vis.plot_many({
                'epoch_loss': epoch_loss.value()[0],
                'epoch_loss_edge': epoch_loss_edge.value()[0],
                'epoch_loss_cls': epoch_loss_cls.value()[0],
                'train_avg_se': np.average(train_se),
                'train_se_0': train_se[0],
                'train_se_1': train_se[1],
                'val_avg_se': np.average(val_spse),
                'val_se_0': val_spse[0],
                'val_se_1': val_spse[1],
                'AUC': AUC,
                'train_dice': train_dice,
                'val_dice': val_dice
            })
            vis.log(
                f"epoch: [{epoch + 1}/{config.max_epoch}] ==============================================="
            )
            # report the epoch mean, the same value that is printed below
            vis.log(
                f"lr: {optimizer.param_groups[0]['lr']}, loss: {round(epoch_loss.value()[0], 5)}"
            )
            vis.log(
                f"train_avg_se: {round(np.average(train_se), 4)}, train_se_0: {round(train_se[0], 4)}, train_se_1: {round(train_se[1], 4)}"
            )
            vis.log(f"train_dice: {round(train_dice, 4)}")
            vis.log(
                f"val_avg_se: {round(sum(val_spse) / len(val_spse), 4)}, val_se_0: {round(val_spse[0], 4)}, val_se_1: {round(val_spse[1], 4)}"
            )
            vis.log(f"val_dice: {round(val_dice, 4)}")
            vis.log(f"AUC: {AUC}")
            vis.log(f'train_cm: {train_cm.value()}')
            vis.log(f'Best Threshold: {Best_T}')
            vis.log(f'val_cm: {val_cm}')
            print("lr:", optimizer.param_groups[0]['lr'], "loss:",
                  round(epoch_loss.value()[0], 5))
            print('train_avg_se:', round(np.average(train_se), 4),
                  'train_se_0:', round(train_se[0], 4), 'train_se_1:',
                  round(train_se[1], 4))
            print('train_dice:', train_dice)
            print('val_avg_se:', round(np.average(val_spse), 4), 'val_se_0:',
                  round(val_spse[0], 4), 'val_se_1:', round(val_spse[1], 4))
            print('val_dice:', val_dice)
            print('AUC:', AUC)
            print('train_cm:')
            print(train_cm.value())
            print('Best Threshold:', Best_T, 'val_cm:')
            print(val_cm)

            # ------------------------------------ save record ------------------------------------
            # the directory only exists once a model has been saved
            if os.path.exists(checkpoint_dir):
                write_json(file=os.path.join(checkpoint_dir,
                                             'process_record.json'),
                           content=process_record)
        # learning-rate decay is currently disabled:
        # if (epoch+1) % 20 == 0:
        #     lr = lr * config.lr_decay
        #     for param_group in optimizer.param_groups:
        #         param_group['lr'] = lr

    vis.log(f"Best Epoch: {save_epoch}")
    print("Best Epoch:", save_epoch)