示例#1
0
文件: evaluate.py 项目: SHX9610/voxel
def evaluate_npy(npyfile, csv_path, device, model):
    """
    Evaluate a registration model on pairs of ``.npy`` images (FIRE data).

    Every CSV row after the header names a fixed image (column 0) and a
    moving image (column 1). Both are loaded from
    ``npyfile + <id> + '.npy'``, registered with ``model``, and the Dice
    score between the warped moving image and the fixed image is added to a
    running total that is printed after every pair.

    :param npyfile: prefix prepended to every image id by plain string
        concatenation (include a trailing separator for a directory)
    :param csv_path: CSV file listing the image pairs; its first row is
        treated as a header and skipped
    :param device: torch device the input tensors are moved to
    :param model: callable invoked as ``model(moving, fixed)`` and returning
        ``(warp, flow)``
    """
    with open(csv_path, 'r') as f:
        dice_all = 0
        reader = csv.reader(f)
        # csv.reader is a one-shot iterator and cannot be sliced, so the
        # header row is consumed explicitly (no error on an empty file).
        next(reader, None)
        for row in reader:
            fixed_id_name = npyfile + row[0] + '.npy'
            moving_id_name = npyfile + row[1] + '.npy'
            # Add a leading and a trailing singleton axis to each array.
            refs = np.load(fixed_id_name)[np.newaxis, ..., np.newaxis]
            movs = np.load(moving_id_name)[np.newaxis, ..., np.newaxis]
            # Move the trailing axis to position 1
            # (assumes 2-D images, i.e. 4-D tensors -- TODO confirm).
            input_fixed = torch.from_numpy(refs).to(device).float()
            input_fixed = input_fixed.permute(0, 3, 1, 2)
            input_moving = torch.from_numpy(movs).to(device).float()
            input_moving = input_moving.permute(0, 3, 1, 2)
            # Pure inference: without no_grad every score added to the
            # running total would keep its autograd graph alive.
            with torch.no_grad():
                warp, flow = model(input_moving, input_fixed)
                dice_score = metrics.dice_score(warp, input_fixed)
            dice_all += dice_score
            print('总相似性度量dice:', dice_all)
示例#2
0
文件: evaluate.py 项目: SHX9610/voxel
def evaluate_hippocampusMRI(niifile, csv_path, device, model):
    """
    Evaluate a registration model on hippocampus MRI pairs (``.nii.gz``).

    Every CSV row after the header names a fixed image (column 0) and a
    moving image (column 1). Both are read with SimpleITK from
    ``niifile + 'hippocampus_' + <id> + '.nii.gz'``, registered with
    ``model``, and the Dice score between the warped moving image and the
    fixed image is added to a running total.

    :param niifile: prefix prepended to every file name by plain string
        concatenation (include a trailing separator for a directory)
    :param csv_path: CSV file listing the image pairs; its first row is
        treated as a header and skipped
    :param device: torch device the input tensors are moved to
    :param model: callable invoked as ``model(moving, fixed)`` and returning
        ``(warp, flow)``
    """
    with open(csv_path, 'r') as f:
        dice_all = 0
        reader = csv.reader(f)
        # csv.reader is a one-shot iterator and cannot be sliced, so the
        # header row is consumed explicitly (no error on an empty file).
        next(reader, None)
        for row in reader:
            fixed_id_name = niifile + 'hippocampus_' + row[0] + '.nii.gz'
            moving_id_name = niifile + 'hippocampus_' + row[1] + '.nii.gz'
            # Read each image and prepend two singleton axes.
            X = sitk.ReadImage(fixed_id_name)
            X = sitk.GetArrayFromImage(X)[np.newaxis, np.newaxis, ...]
            Y = sitk.ReadImage(moving_id_name)
            Y = sitk.GetArrayFromImage(Y)[np.newaxis, np.newaxis, ...]
            # NOTE(review): a 4-index permute only works on 4-D tensors, so
            # this presumes 2-D images; 3-D volumes would become 5-D and
            # fail here -- confirm the expected input shape.
            input_fixed = torch.from_numpy(X).to(device).float()
            input_fixed = input_fixed.permute(0, 3, 1, 2)
            input_moving = torch.from_numpy(Y).to(device).float()
            input_moving = input_moving.permute(0, 3, 1, 2)

            # Pure inference: without no_grad every score added to the
            # running total would keep its autograd graph alive.
            with torch.no_grad():
                warp, flow = model(input_moving, input_fixed)
                dice_score = metrics.dice_score(warp, input_fixed)
            dice_all += dice_score
            # NOTE(review): this prints the running total, not the score of
            # the current pair -- confirm that is intended.
            print('相似性度量dice:', row[0], 'and', row[1], dice_all)
示例#3
0
def forward(model, loader, criterion, optimizer=None, force_cpu=False):
    """
    Run one pass over ``loader`` and return the mean of every tracked metric.

    When ``optimizer`` is given the model is put in training mode and updated
    after every batch; otherwise the model is put in eval mode and the pass
    runs with gradient tracking disabled.

    :param model: network called as ``model(inputs)``; its output is treated
        as logits (a sigmoid followed by a 0.5 threshold gives the hard mask)
    :param loader: iterable yielding ``(inputs, outputs)`` batches
    :param criterion: loss called as ``criterion(prediction, outputs)``
    :param optimizer: optimizer to step after each batch, or ``None`` for an
        evaluation-only pass
    :param force_cpu: when false, every batch is moved to the GPU with
        ``.cuda(non_blocking=True)``
    :return: ``OrderedDict`` mapping ``'loss'``, ``'volume_error'`` and
        ``'dice'`` to their mean over all batches
    """
    training = optimizer is not None
    if training:
        model.train()
    else:
        model.eval()

    metrics = DefaultOrderedDict(list)

    for inputs, outputs in tqdm(loader, position=0, leave=True):
        if not force_cpu:
            inputs = inputs.cuda(non_blocking=True)
            outputs = outputs.cuda(non_blocking=True)

        # No autograd graph is needed when nothing is back-propagated;
        # skipping it saves memory and time during evaluation.
        with torch.set_grad_enabled(training):
            prediction = model(inputs)
            loss = criterion(prediction, outputs)

        metrics['loss'].append(loss.item())

        hard_prediction = (torch.sigmoid(prediction) > 0.5).float()
        predicted_volumes = get_batch_volume(hard_prediction)
        true_volumes = get_batch_volume(outputs)

        # Relative volume error; the epsilon guards against empty targets.
        metrics['volume_error'].append(
            (torch.abs(predicted_volumes - true_volumes) /
             (true_volumes + 1e-5)).mean().item())

        dice = dice_score(hard_prediction, outputs).item()
        metrics['dice'].append(dice)

        if training:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    return OrderedDict({k: np.mean(v) for (k, v) in metrics.items()})
示例#4
0
def test(gpu, ref_dir, mov_dir, model, init_model_file):
    """
    Register reference/moving ``.npy`` images with a trained model and print
    the Dice score of every registration.

    :param gpu: id of the gpu to use (int or str); exported through
        ``CUDA_VISIBLE_DEVICES``
    :param ref_dir: directory searched for reference (fixed) ``*.npy`` files
    :param mov_dir: directory searched for moving files
    :param model: either vm1 or vm2 (based on CVPR 2018 paper)
    :param init_model_file: file holding the trained weights to load
    :raises ValueError: if ``model`` is neither ``"vm1"`` nor ``"vm2"``
    """
    # Prepare the vm1 or vm2 model configuration
    nf_enc = [16, 32, 32, 32]
    if model == "vm1":
        nf_dec = [32, 32, 32, 32, 8, 8]
    elif model == "vm2":
        nf_dec = [32, 32, 32, 32, 32, 16, 16]
    else:
        # Fail early instead of hitting an unbound nf_dec further down.
        raise ValueError(
            "model must be 'vm1' or 'vm2', got {!r}".format(model))

    # Environment values must be strings; accept an integer gpu id as well.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu)
    device = "cuda"

    # Set up model
    vol_size = [2912, 2912]
    model = cvpr2018_net(vol_size, nf_enc, nf_dec)
    model.to(device)
    model.load_state_dict(
        torch.load(init_model_file, map_location=lambda storage, loc: storage))

    # Collect the image files
    ref_vol_names = glob.glob(os.path.join(ref_dir, '*.npy'))
    # NOTE(review): this pattern has no dot, unlike the reference pattern.
    mov_vol_names = glob.glob(os.path.join(mov_dir, '*npy'))
    nums = len(ref_vol_names)

    for _ in range(nums):
        # presumably returns one (reference, moving) batch per call --
        # verify against datagenerators.example_gen
        refs, movs = datagenerators.example_gen(ref_vol_names,
                                                mov_vol_names,
                                                batch_size=1)
        # Move the trailing axis to position 1.
        input_fixed = torch.from_numpy(refs).to(device).float()
        input_fixed = input_fixed.permute(0, 3, 1, 2)
        input_moving = torch.from_numpy(movs).to(device).float()
        input_moving = input_moving.permute(0, 3, 1, 2)

        # Pure inference, so no autograd graph is needed. `flow` is the
        # second model output; it is not used or saved here.
        with torch.no_grad():
            warp, flow = model(input_moving, input_fixed)
            dice_score = metrics.dice_score(warp, input_fixed)
        print('相似性度量dice:', dice_score)
示例#5
0
def test_npy(gpu, npyfile, csv_path, model, init_model_file):
    """
    Register the ``.npy`` image pairs (FIRE data) listed in a CSV file with a
    trained model and print the Dice score of every pair.

    :param gpu: id of the gpu to use (int or str); exported through
        ``CUDA_VISIBLE_DEVICES``
    :param npyfile: prefix prepended to every image id by plain string
        concatenation (include a trailing separator for a directory)
    :param csv_path: CSV file whose rows name the fixed image (column 0) and
        the moving image (column 1); the first row is skipped as a header
    :param model: either vm1 or vm2 (based on CVPR 2018 paper)
    :param init_model_file: file holding the trained weights to load
    :raises ValueError: if ``model`` is neither ``"vm1"`` nor ``"vm2"``
    """
    # Prepare the vm1 or vm2 model configuration
    nf_enc = [16, 32, 32, 32]
    if model == "vm1":
        nf_dec = [32, 32, 32, 32, 8, 8]
    elif model == "vm2":
        nf_dec = [32, 32, 32, 32, 32, 16, 16]
    else:
        # Fail early instead of hitting an unbound nf_dec further down.
        raise ValueError(
            "model must be 'vm1' or 'vm2', got {!r}".format(model))

    # Environment values must be strings; accept an integer gpu id as well.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu)
    device = "cuda"

    # Set up model
    vol_size = [2912, 2912]
    model = cvpr2018_net(vol_size, nf_enc, nf_dec)
    model.to(device)
    model.load_state_dict(
        torch.load(init_model_file, map_location=lambda storage, loc: storage))

    with open(csv_path, 'r') as f:
        reader = csv.reader(f)
        # csv.reader is a one-shot iterator and cannot be sliced, so the
        # header row is consumed explicitly (no error on an empty file).
        next(reader, None)
        for row in reader:
            fixed_id_name = npyfile + row[0] + '.npy'
            moving_id_name = npyfile + row[1] + '.npy'
            # Add a leading and a trailing singleton axis to each array.
            refs = np.load(fixed_id_name)[np.newaxis, ..., np.newaxis]
            movs = np.load(moving_id_name)[np.newaxis, ..., np.newaxis]
            # Move the trailing axis to position 1.
            input_fixed = torch.from_numpy(refs).to(device).float()
            input_fixed = input_fixed.permute(0, 3, 1, 2)
            input_moving = torch.from_numpy(movs).to(device).float()
            input_moving = input_moving.permute(0, 3, 1, 2)
            # Pure inference, so no autograd graph is needed.
            with torch.no_grad():
                warp, flow = model(input_moving, input_fixed)
                dice_score = metrics.dice_score(warp, input_fixed)
            print('相似性度量dice:', row[0], 'and', row[1], dice_score)
示例#6
0
def test_hippocampusMRI(gpu, niifile, csv_path, model, init_model_file):
    """
    Register the hippocampus MRI pairs (``.nii.gz``) listed in a CSV file
    with a trained model and print the Dice score of every pair.

    :param gpu: id of the gpu to use (int or str); exported through
        ``CUDA_VISIBLE_DEVICES``
    :param niifile: prefix prepended to every file name by plain string
        concatenation (include a trailing separator for a directory)
    :param csv_path: CSV file whose rows name the fixed image (column 0) and
        the moving image (column 1); the first row is skipped as a header
    :param model: either vm1 or vm2 (based on CVPR 2018 paper)
    :param init_model_file: file holding the trained weights to load
    :raises ValueError: if ``model`` is neither ``"vm1"`` nor ``"vm2"``
    """
    # Prepare the vm1 or vm2 model configuration
    nf_enc = [16, 32, 32, 32]
    if model == "vm1":
        nf_dec = [32, 32, 32, 32, 8, 8]
    elif model == "vm2":
        nf_dec = [32, 32, 32, 32, 32, 16, 16]
    else:
        # Fail early instead of hitting an unbound nf_dec further down.
        raise ValueError(
            "model must be 'vm1' or 'vm2', got {!r}".format(model))

    # Environment values must be strings; accept an integer gpu id as well.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu)
    device = "cuda"

    # Set up model
    vol_size = [2912, 2912]
    model = cvpr2018_net(vol_size, nf_enc, nf_dec)
    model.to(device)
    model.load_state_dict(
        torch.load(init_model_file, map_location=lambda storage, loc: storage))

    with open(csv_path, 'r') as f:
        reader = csv.reader(f)
        # csv.reader is a one-shot iterator and cannot be sliced, so the
        # header row is consumed explicitly (no error on an empty file).
        next(reader, None)
        for row in reader:
            fixed_id_name = niifile + 'hippocampus_' + row[0] + '.nii.gz'
            moving_id_name = niifile + 'hippocampus_' + row[1] + '.nii.gz'
            # Read each image and prepend two singleton axes.
            X = sitk.ReadImage(fixed_id_name)
            X = sitk.GetArrayFromImage(X)[np.newaxis, np.newaxis, ...]
            Y = sitk.ReadImage(moving_id_name)
            Y = sitk.GetArrayFromImage(Y)[np.newaxis, np.newaxis, ...]

            # NOTE(review): a 4-index permute only works on 4-D tensors, so
            # this presumes 2-D images; 3-D volumes would become 5-D and
            # fail here -- confirm the expected input shape.
            input_fixed = torch.from_numpy(X).to(device).float()
            input_fixed = input_fixed.permute(0, 3, 1, 2)
            input_moving = torch.from_numpy(Y).to(device).float()
            input_moving = input_moving.permute(0, 3, 1, 2)

            # Pure inference, so no autograd graph is needed.
            with torch.no_grad():
                warp, flow = model(input_moving, input_fixed)
                dice_score = metrics.dice_score(warp, input_fixed)
            print('相似性度量dice:', row[0], 'and', row[1], dice_score)
示例#7
0
    def on_batch_close(self, loss: torch.Tensor, np_probs: torch.Tensor,
                       targets: torch.Tensor):
        """
        Accumulate the loss, Dice and accuracy of one batch.

        Increments ``self.batch_num`` and adds each metric to
        ``self.metrics`` under ``'loss'``, ``'dice'`` and ``'acc'``. A metric
        whose value is NaN for this batch is skipped.

        :param loss: loss of the batch
        :param np_probs: per-class scores, documented as N*2*H*W; the
            predicted label is the argmax over dim 1
        :param targets: ground-truth labels, documented as N*H*W
        """
        # argmax over dim 1 already removes the class axis. A blanket
        # squeeze() would also drop the batch axis when N == 1 and break the
        # shape check, so squeeze is kept only as a fallback for callers
        # whose targets have their singleton axes removed.
        np_preds = torch.argmax(np_probs, dim=1)
        if np_preds.shape != targets.shape:
            np_preds = np_preds.squeeze()
        assert np_preds.shape == targets.shape
        self.batch_num += 1
        if not torch.isnan(loss):
            self.metrics['loss'] += float(loss)

        dice: torch.Tensor = metrics.dice_score(np_preds, targets)
        if not torch.isnan(dice):
            self.metrics['dice'] += float(dice)

        acc: torch.Tensor = metrics.accuracy_score(np_preds, targets)
        if not torch.isnan(acc):
            self.metrics['acc'] += float(acc)
示例#8
0
# Training script: batches of 4 are read in a fixed order (shuffle=False).
ship_train_loader = DataLoader(ship_train_dataset,
                               batch_size=4,
                               num_workers=32,
                               shuffle=False)
for epoch in range(300):
    # Per-epoch running totals of the loss and the monitoring metrics.
    epoch_loss = 0
    Diceloss = 0
    Sen = 0
    print('Starting epoch {}/{}.'.format(epoch + 1, 300))
    for i, item in enumerate(ship_train_loader):
        data, label = item
        data = data.cuda()
        label = label.cuda()
        prediction = unet(data)
        # The metrics are only monitored, never back-propagated, so they are
        # computed on a detached prediction; otherwise the running totals
        # would keep every batch's autograd graph alive.
        dice1 = dice_score(prediction.detach(), label)
        sen1 = sen_score(prediction.detach(), label)
        # Total loss is the sum of the two criteria.
        loss = loss_func(prediction, label) + loss_func2(prediction, label)

        Diceloss = Diceloss + dice1
        Sen = Sen + sen1
        # Accumulate a plain Python number, not a graph-attached tensor.
        epoch_loss = epoch_loss + loss.item()
        optimizer.zero_grad()  # clear the gradients left from the last step
        loss.backward()  # back-propagate to compute the parameter gradients
        optimizer.step()  # apply the update to the network parameters

    print('Epoch finished ! Loss: {}'.format(epoch_loss / i) +
          '   Dice_Loss: {}'.format(Diceloss / i) +
        visual_add(
            np.squeeze(inputs[i_slice, channel, ...,
                              int(n_z / 2)].detach().numpy()), i_slice, i_col,
            gs, '')
        i_col += 1
    visual_add(
        np.squeeze(lesions[i_slice, ..., int(n_z / 2)].detach().numpy()),
        i_slice, i_col, gs, '')
    visual_add(np.squeeze(pred.detach().numpy()), i_slice, i_col + 1, gs, '')
    all_loss = FocalTverskyLoss().forward(pred, lesions[i_slice, ...,
                                                        int(n_z / 2)]).item()

    visual_add(np.squeeze(torch.sigmoid(pred).detach().numpy()), i_slice,
               i_col + 2, gs, 'FTL: ' + str(round(all_loss, 4)))
    hard_prediction = (pred > 0.5).float()
    dice = dice_score(hard_prediction,
                      np.squeeze(lesions[i_slice, ...,
                                         int(n_z / 2)])).item()
    print(str(dice))
    visual_add(np.squeeze(hard_prediction.detach().numpy()), i_slice,
               i_col + 3, gs, 'D: ' + str(round(dice, 4)))
    i_slice += 1

import os

# Turn interactive mode off and switch to the 'agg' backend before saving.
plt.ioff()
plt.switch_backend('agg')

# The output file is named after the current setting and stored in data_dir.
figure_name = setting + '_prediction_visualisation.png'
figure_path = os.path.join(data_dir, figure_name)
print(figure_path)

# Write the figure at its own dpi, then close it.
figure.savefig(figure_path, dpi='figure', format='png')
plt.close(figure)
示例#10
0
def _run_epoch(model, loader, criterion, optimizer, device, training):
    """Run one pass over ``loader``; return ``(mean loss, mean dice)``.

    When ``training`` is true the optimizer is zeroed before and stepped
    after every batch. Both means are taken over the number of batches.
    """
    loss_sum = 0.0
    dice_sum = 0.0
    for images, masks in tqdm.tqdm(loader):
        images = images.to(device)
        # presumably turns channel-last masks into channel-first ones --
        # TODO confirm against the dataset
        masks = masks.to(device).squeeze(1).transpose(1, 3).transpose(2, 3)

        if training:
            optimizer.zero_grad()

        out5, out4, out3 = model(images)
        # Bring the two auxiliary outputs to the spatial size of the masks.
        target_size = (masks.shape[2], masks.shape[3])
        out4 = F.interpolate(out4, size=target_size, mode="bilinear")
        out3 = F.interpolate(out3, size=target_size, mode="bilinear")

        prob5 = torch.sigmoid(out5)
        prob4 = torch.sigmoid(out4)
        prob3 = torch.sigmoid(out3)
        # All three outputs are supervised against the same masks.
        loss = criterion(prob5, masks) + criterion(
            prob4, masks) + criterion(prob3, masks)

        if training:
            loss.backward()
            optimizer.step()

        loss_sum += loss.item()
        # Dice is reported on the first output only.
        dice_sum += dice_score(masks.cpu().numpy(),
                               prob5.cpu().detach().numpy())

    batches = len(loader)
    return loss_sum / batches, dice_sum / batches


def train_fn(loaders, model, criterion, optimizer, lr_scheduler, start_epoch,
             total_epochs, device, device_ids, save_path):
    """
    Train ``model`` with deep supervision and keep the best checkpoint.

    Each epoch runs a training pass over ``loaders["train"]`` and an
    evaluation pass over ``loaders["eval"]``. Whenever the evaluation Dice
    improves, the model, optimizer and scheduler states are saved to
    ``save_path``. After the last epoch the per-epoch history is written as
    JSON next to the checkpoint.

    :param loaders: mapping with ``"train"`` and ``"eval"`` data loaders
        yielding ``(images, masks)`` batches
    :param model: network returning three outputs per forward pass
    :param criterion: loss applied to the sigmoid of each output
    :param optimizer: optimizer stepped after every training batch
    :param lr_scheduler: scheduler stepped once per epoch with the
        evaluation loss
    :param start_epoch: first epoch number (inclusive)
    :param total_epochs: last epoch number (inclusive)
    :param device: device the model and batches are moved to
    :param device_ids: device ids handed to ``nn.DataParallel``
    :param save_path: checkpoint file path
    """
    model = model.to(device)
    model = nn.DataParallel(model, device_ids=device_ids)

    print("Epochs: {}\n".format(total_epochs))
    best_epoch = 1
    best_dice = 0.0
    history = {
        "train": {"loss": [], "dice": []},
        "eval": {"loss": [], "dice": []},
        "lr": [],
    }

    for epoch in range(start_epoch, total_epochs + 1):
        head = "epoch {:3}/{:3}".format(epoch, total_epochs)
        print(head + "\n" + "-" * (len(head)))

        # Training pass
        model.train()
        epoch_loss, epoch_dice = _run_epoch(model, loaders["train"],
                                            criterion, optimizer, device,
                                            True)
        history["train"]["loss"].append(epoch_loss)
        history["train"]["dice"].append(epoch_dice)
        print("{:5} - loss: {:.6f} dice: {:.6f}".format(
            "train", epoch_loss, epoch_dice))

        # Evaluation pass, without gradient tracking
        with torch.no_grad():
            model.eval()
            epoch_loss, epoch_dice = _run_epoch(model, loaders["eval"],
                                                criterion, optimizer, device,
                                                False)
        history["eval"]["loss"].append(epoch_loss)
        history["eval"]["dice"].append(epoch_dice)
        print("{:5} - loss: {:.6f} dice: {:.6f}".format(
            "eval", epoch_loss, epoch_dice))

        # Record the learning rate used this epoch, then let the scheduler
        # react to the evaluation loss.
        history["lr"].append(optimizer.param_groups[0]["lr"])
        lr_scheduler.step(epoch_loss)

        # Checkpoint only when the evaluation Dice improves.
        if epoch_dice > best_dice:
            best_epoch = epoch
            best_dice = epoch_dice

            checkpoint = {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "lr_scheduler_state_dict": lr_scheduler.state_dict(),
                "best_metric": best_dice
            }
            torch.save(checkpoint, save_path)

    # The history file name is save_path minus its last three characters
    # (presumably a '.pt' suffix -- verify against callers) plus '.json'.
    history_path = "{}.json".format(save_path[:-3])
    with open(history_path, "w") as f:
        json.dump(history, f)
    print("\nFinish: - Best Epoch: {:3} - Best DICE: {:.6f}\n".format(
        best_epoch, best_dice))
def main_training_loop(base_path,
                       idlist_image_train,
                       idlist_seg_train,
                       idlist_image_val,
                       idlist_seg_val,
                       img_folder,
                       seg_folder,
                       train_unet_with_GAN=False,
                       number_of_gan_samples=BATCH_SIZE,
                       number_of_real_samples=BATCH_SIZE,
                       load_gan_model_path=None,
                       load_unet_model_path=None,
                       augmentation_functions_image=None,
                       augmentation_functions_segmentation=None,
                       augmentation_params=None,
                       threshold_val=None,
                       folder_path=FOLDER_PATH,
                       save_checkpoints=True,
                       only_test=False):
    # Check mutable default args

    if augmentation_params == None:
        augmentation_params = [{}]

    #Reset everything from before
    tf.reset_default_graph()
    graph = tf.Graph()
    lib.delete_all_params()
    lib.delete_param_aliases()
    iteration = 0
    lib.plot.set(iteration)
    CHECKPOINT_PATH = os.path.join(folder_path, 'checkpoints')
    with graph.as_default():
        with tf.Session(graph=graph) as session:

            if number_of_real_samples != 0:
                images, segmentation_images = images_segmentations_from_paths(
                    base_path,
                    idlist_image_train,
                    idlist_seg_train,
                    img_folder,
                    seg_folder,
                    batch_size=number_of_real_samples,
                    resized_image_size=[IMAGE_HEIGHT, IMAGE_WIDTH],
                    shift_params=IMAGE_INTENSITY_SHIFT,
                    rescale_params=IMAGE_INTENSITY_SCALE,
                    image_channels=3,
                    force_to_grayscale=(IMAGE_CHANNELS == 1),
                    shuffle=True)
            else:
                images, segmentation_images = images_segmentations_from_paths(
                    base_path,
                    idlist_image_train,
                    idlist_seg_train,
                    img_folder,
                    seg_folder,
                    batch_size=number_of_gan_samples,
                    resized_image_size=[IMAGE_HEIGHT, IMAGE_WIDTH],
                    shift_params=IMAGE_INTENSITY_SHIFT,
                    rescale_params=IMAGE_INTENSITY_SCALE,
                    image_channels=3,
                    force_to_grayscale=(IMAGE_CHANNELS == 1),
                    shuffle=True)

            if augmentation_functions_segmentation != None and augmentation_functions_image != None:
                images, segmentation_images = augment_images(
                    images, segmentation_images, augmentation_functions_image,
                    augmentation_functions_segmentation, augmentation_params)

            val_images, val_segmentation_images = images_segmentations_from_paths(
                base_path,
                idlist_image_val,
                idlist_seg_val,
                img_folder,
                seg_folder,
                batch_size=1,
                num_preprocess_threads=1,
                min_queue_examples=1,
                resized_image_size=[IMAGE_HEIGHT, IMAGE_WIDTH],
                shift_params=IMAGE_INTENSITY_SHIFT,
                rescale_params=IMAGE_INTENSITY_SCALE,
                image_channels=3,
                force_to_grayscale=(IMAGE_CHANNELS == 1),
                shuffle=True)
            # Get number of entries in idlist_image_val so we know how many images to process
            with open(os.path.join(base_path, idlist_image_val)) as f:
                file_names = f.readlines()
            # you may also want to remove whitespace characters like `\n` at the end of each line
            file_names = [folder_path + x.strip() for x in file_names]

            num_val_images = len(file_names)

            segmentation_max = tf.reduce_max(segmentation_images)
            segmentation_min = tf.reduce_min(segmentation_images)

            segmentations_normalized = tf.multiply(
                tf.add(
                    tf.div(
                        tf.subtract(segmentation_images,
                                    tf.reduce_min(segmentation_images)),
                        tf.subtract(segmentation_max, segmentation_min)),
                    -0.5), 2)

            gan_input_plus_segmentation = tf.concat(
                [images, segmentations_normalized], 3)

            Generator, Discriminator = gan_model.GeneratorAndDiscriminator()

            is_training = None  # tf.Variable(False, trainable=False)
            stats_iter_bn = tf.Variable(0, trainable=False)

            session.run(tf.initialize_variables([stats_iter_bn]))

            if number_of_gan_samples != 0:
                gen_train_op, disc_train_op, gen_cost, disc_cost = gan_model.build(
                    session, gan_input_plus_segmentation, Generator,
                    Discriminator, MODE, LAMBDA, number_of_gan_samples,
                    DEVICES, KERNEL_SIZE, IMAGE_WIDTH, IMAGE_HEIGHT,
                    IMAGE_CHANNELS, NOISE_DIM, OUTPUT_DIM, SMALLEST_IMAGE_DIM,
                    DIM, is_training, stats_iter_bn)
            else:
                gen_train_op, disc_train_op, gen_cost, disc_cost = gan_model.build(
                    session, gan_input_plus_segmentation, Generator,
                    Discriminator, MODE, LAMBDA, number_of_real_samples,
                    DEVICES, KERNEL_SIZE, IMAGE_WIDTH, IMAGE_HEIGHT,
                    IMAGE_CHANNELS, NOISE_DIM, OUTPUT_DIM, SMALLEST_IMAGE_DIM,
                    DIM, is_training, stats_iter_bn)

            if train_unet_with_GAN == True:
                gen_data, gen_data_orig_shape = Generator(
                    number_of_gan_samples,
                    KERNEL_SIZE,
                    IMAGE_WIDTH,
                    IMAGE_HEIGHT,
                    IMAGE_CHANNELS,
                    NOISE_DIM,
                    OUTPUT_DIM,
                    SMALLEST_IMAGE_DIM,
                    DIM,
                    MODE,
                    reuse=True,
                    stats_iter=stats_iter_bn,
                    is_training=is_training)

                # Make sure that Generator data (NCHW) is transposed correctly before it's used as NHWC again
                gen_data = tf.reshape(gen_data,
                                      (number_of_gan_samples, IMAGE_CHANNELS +
                                       1, IMAGE_HEIGHT, IMAGE_WIDTH))
                gen_data = util.NCHW_to_NHWC(gen_data)

                gen_img = tf.squeeze(gen_data[:, :, :, 0:IMAGE_CHANNELS])

                if len(gen_img.shape) == 3:
                    gen_img = tf.expand_dims(gen_img, -1)

                gen_seg = gen_data[:, :, :, IMAGE_CHANNELS]

                #unnormalize segmentation data for unet
                gen_seg = tf.add(
                    tf.multiply(
                        tf.multiply(tf.add(gen_seg, 1), 0.5),
                        tf.subtract(segmentation_max, segmentation_min)),
                    segmentation_min)

                #Threshold the generated GAN data.

                if threshold_val != None:
                    zeros_like_gen_seg = tf.zeros_like(gen_seg)
                    gen_seg = zeros_like_gen_seg + tf.to_float(
                        tf.greater_equal(gen_seg, threshold_val))

                if len(gen_seg.shape) == 3:
                    gen_seg = tf.expand_dims(gen_seg, -1)

                if augmentation_functions_segmentation != None and augmentation_functions_image != None:
                    gen_img, gen_seg = augment_images(
                        gen_img, gen_seg, augmentation_functions_image,
                        augmentation_functions_segmentation,
                        augmentation_params)

                if number_of_real_samples != 0:
                    images = tf.concat([images, gen_img], 0)
                    segmentation_images = tf.concat(
                        [segmentation_images, gen_seg], 0)
                else:
                    images = gen_img
                    segmentation_images = gen_seg

                gan_input_plus_segmentation = tf.concat(
                    [images, segmentation_images], 3)

            unet_train_op, unet_loss_train, pred_train, unet_out_nchw, probs_train = unet_train.unet_ops(
                images,
                segmentation_images,
                num_classes=NUMBER_OF_SEGMENTATION_CLASSES,
                is_training=True,
                reuse_vars=None)

            _, unet_loss_val, pred_val, _, _ = unet_train.unet_ops(
                val_images,
                val_segmentation_images,
                num_classes=NUMBER_OF_SEGMENTATION_CLASSES,
                is_training=False,
                reuse_vars=True)

            if only_test == True:
                test_images, test_segmentation_images = images_segmentations_from_paths(
                    base_path,
                    idlist_image_val,
                    idlist_seg_val,
                    img_folder,
                    seg_folder,
                    batch_size=1,
                    num_preprocess_threads=1,
                    min_queue_examples=1,
                    resized_image_size=[IMAGE_HEIGHT, IMAGE_WIDTH],
                    shift_params=IMAGE_INTENSITY_SHIFT,
                    rescale_params=IMAGE_INTENSITY_SCALE,
                    image_channels=3,
                    force_to_grayscale=(IMAGE_CHANNELS == 1),
                    shuffle=False)

            session.run(tf.initialize_all_variables())

            # Start training queue
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=session, coord=coord)

            # Train loop

            # only restore model stuff
            var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                         scope='g')
            var_list.extend(
                tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='d'))

            saver = tf.train.Saver(var_list=var_list)

            var_list_unet = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                              scope='unet')

            unet_saver = tf.train.Saver(var_list=var_list_unet)

            if load_gan_model_path != None:
                saver.restore(session, load_gan_model_path)
                iteration = 0
                print("GAN Model restored.")
            else:
                iteration = 0

            if load_unet_model_path != None:
                unet_saver.restore(session, load_unet_model_path)
                iteration = 0
                print("UNet Model restored.")
            else:
                iteration = 0

            imgs_in = session.run(gan_input_plus_segmentation)
            for k in range(BATCH_SIZE):
                imgs_folder = os.path.join(folder_path, 'in/iter_%d/') % 0
                if not os.path.exists(imgs_folder):
                    os.makedirs(imgs_folder)

                img_channel = imgs_in[k][:, :, 0:IMAGE_CHANNELS]
                img_seg = imgs_in[k][:, :, IMAGE_CHANNELS]

                imsave(
                    os.path.join(imgs_folder, 'img_%d.png') % k,
                    img_channel.reshape(IMAGE_HEIGHT, IMAGE_WIDTH,
                                        IMAGE_CHANNELS).squeeze())
                imsave(
                    os.path.join(imgs_folder, 'seg_%d.png') % k,
                    img_seg.reshape(IMAGE_HEIGHT, IMAGE_WIDTH))

            if not os.path.exists(folder_path):
                os.makedirs(folder_path)

            if only_test == True:
                #Get number of entries in idlist_image_val so we know how many images to process
                with open(os.path.join(base_path, idlist_image_val)) as f:
                    file_names = f.readlines()

                file_names = [os.path.basename(x.strip()) for x in file_names]

                num_images = len(file_names)

                _, unet_loss_test, pred_test, _, _ = unet_train.unet_ops(
                    test_images,
                    test_segmentation_images,
                    num_classes=NUMBER_OF_SEGMENTATION_CLASSES,
                    is_training=False,
                    reuse_vars=True)

                dice_scores_test = []

                seg_test_0_1 = tf.multiply(
                    tf.add(test_segmentation_images, 1.0),
                    0.5) * (NUMBER_OF_SEGMENTATION_CLASSES - 1)

                dice_unet_test = dice_score(
                    tf.squeeze(pred_test),
                    tf.squeeze(test_segmentation_images),
                    num_classes=NUMBER_OF_SEGMENTATION_CLASSES,
                    session=session)

                test_imgs_out = []
                test_segs_out = []
                test_preds_out = []
                with open(folder_path + '/test_out.txt', mode='wt') as myfile:

                    for img_index in range(num_images):
                        dice_score_test_value, test_pred, test_img, test_seg = session.run(
                            [
                                dice_unet_test, pred_test, test_images,
                                test_segmentation_images
                            ])

                        test_imgs_out.append(test_img)
                        test_segs_out.append(test_seg)
                        test_preds_out.append(test_pred)

                        dice_scores_test.append(dice_score_test_value)
                        out_string = 'Dice ' + str(
                            img_index + 1) + ' / ' + str(
                                num_images) + ': ' + str(dice_score_test_value)

                        print(out_string)
                        myfile.write(out_string + '\n')

                    print('Average Dice: ' +
                          str(np.asarray(dice_scores_test).mean()) +
                          ' (stddev: ' +
                          str(np.asarray(dice_scores_test).std()) + ')')

                    myfile.write('Average Dice: ' +
                                 str(np.asarray(dice_scores_test).mean()) +
                                 ' (stddev: ' +
                                 str(np.asarray(dice_scores_test).std()) +
                                 ')' + '\n')

                counter = 0

                if not os.path.exists(os.path.join(folder_path, 'test_out')):
                    os.makedirs(os.path.join(folder_path, 'test_out'))
                for img, seg, pred, filename in zip(test_imgs_out,
                                                    test_segs_out,
                                                    test_preds_out,
                                                    file_names):
                    img = ((img + 0) * (255.99)).astype('int32')
                    seg = seg.astype('int32')
                    pred = pred.astype('int32')

                    imsave(
                        os.path.join(os.path.join(folder_path, 'test_out'),
                                     'img_%d.png') % counter, np.squeeze(img))

                    scipy.misc.toimage(np.squeeze(seg), cmin=0, cmax=255).save(
                        os.path.join(os.path.join(folder_path, 'test_out'),
                                     'seg_%d.png') % counter)
                    scipy.misc.toimage(
                        np.squeeze(pred), cmin=0, cmax=255).save(os.path.join(
                            os.path.join(folder_path, 'test_out'), filename),
                                                                 format='png')

                    counter = counter + 1

                coord.request_stop()
                coord.join(threads)

                return _, _, _, _

            # The dice score expects labels in [0, 1], but the labels were rescaled to [-1, 1].
            # Map them back to [0, 1], then scale by (NUMBER_OF_SEGMENTATION_CLASSES - 1).
            seg_0_1 = tf.multiply(tf.add(segmentation_images, 1.0),
                                  0.5) * (NUMBER_OF_SEGMENTATION_CLASSES - 1)
            seg_val_0_1 = tf.multiply(
                tf.add(val_segmentation_images,
                       1.0), 0.5) * (NUMBER_OF_SEGMENTATION_CLASSES - 1)

            dice_unet_train = dice_score(
                pred_train,
                tf.squeeze(segmentation_images),
                num_classes=NUMBER_OF_SEGMENTATION_CLASSES)
            dice_unet_val = dice_score(
                pred_val,
                tf.squeeze(val_segmentation_images),
                num_classes=NUMBER_OF_SEGMENTATION_CLASSES)

            test_pred_1 = session.run(pred_train)
            test_seg_1 = session.run(segmentation_images)
            test_seg_0_1 = session.run(seg_0_1)
            test_rounded = session.run(tf.round(seg_0_1))

            loss_vals = []
            last_loss_avg = 9999999

            unet_val_loss_value = float('nan')
            dice_value_val = 0
            dice_value_train = 0
            unet_train_loss_value = float('nan')
            while iteration < ITERS:

                start_time = time.time()

                if TRAIN_GAN == True:
                    # Train generator
                    if iteration > 0:
                        _ = session.run(gen_train_op)

                    # Train critic
                    if (MODE == 'dcgan') or (MODE == 'lsgan'):
                        disc_iters = 1
                    else:
                        disc_iters = CRITIC_ITERS
                    for i in xrange(disc_iters):
                        _disc_cost, _ = session.run([disc_cost, disc_train_op])

                    lib.plot.plot('train disc cost', _disc_cost)

                if TRAIN_UNET == True:
                    dice_value_train = session.run(dice_unet_train)
                    unet_train_loss_value, _ = session.run(
                        [unet_loss_train, unet_train_op])

                    if iteration % 20 == 0:

                        dice_scores_val = []
                        val_losses = []
                        for img_index in range(num_val_images):
                            dice_value_val = session.run(dice_unet_val)
                            val_loss = session.run(unet_loss_val)
                            dice_scores_val.append(dice_value_val)
                            val_losses.append(val_loss)

                        print('Average Dice: ' +
                              str(np.asarray(dice_scores_val).mean()) +
                              ' (stddev: ' +
                              str(np.asarray(dice_scores_val).std()) + ')')
                        dice_value_val = np.asarray(dice_scores_val).mean()

                        unet_val_loss_value = np.asarray(val_losses).mean()

                        print('Average validation loss: ' +
                              str(unet_val_loss_value))

                lib.plot.plot('train unet dice', dice_value_train)
                lib.plot.plot('train unet cost', unet_train_loss_value)
                lib.plot.plot('val unet cost', unet_val_loss_value)
                lib.plot.plot('val unet dice', dice_value_val)

                if iteration % ITERS_BETWEEN_OUTPUTS == 1:
                    t = time.time()

                    # Save Checkpoints / Iteration count
                    if not os.path.exists(CHECKPOINT_PATH):
                        os.makedirs(CHECKPOINT_PATH)

                    np.save(os.path.join(CHECKPOINT_PATH, ITERATIONS_NAME),
                            iteration)

                    if TRAIN_GAN == True and save_checkpoints == True:
                        # Generate Samples
                        gan_model.generate_image(
                            iteration, session, Generator, KERNEL_SIZE,
                            OUTPUT_DIM, SMALLEST_IMAGE_DIM, DIM, BATCH_SIZE,
                            NOISE_DIM, DEVICES, IMAGE_WIDTH, IMAGE_HEIGHT,
                            IMAGE_CHANNELS, MODE, folder_path, stats_iter_bn,
                            is_training)
                        saver.save(
                            session,
                            os.path.join(CHECKPOINT_PATH, CHECKPOINT_NAME))

                    if TRAIN_UNET == True and save_checkpoints == True:
                        util.save_tensor_image_to_folder(
                            session, pred_train, BATCH_SIZE, iteration,
                            os.path.join(folder_path, 'pred_train'))
                        util.save_tensor_image_to_folder(
                            session, pred_val, 1, iteration,
                            os.path.join(folder_path, 'pred_val'))
                        unet_saver.save(
                            session,
                            os.path.join(CHECKPOINT_PATH,
                                         UNET_CHECKPOINT_NAME))

                if iteration > 10000 and TRAIN_GAN == True:

                    loss_vals.append(_disc_cost)

                    if len(loss_vals) > 100:
                        avg_loss = np.asarray(loss_vals).mean()

                        if avg_loss > last_loss_avg:
                            print(
                                'Loss increased over the last 100 iterations, stopping!'
                            )
                            break
                        else:
                            last_loss_avg = avg_loss

                        del loss_vals[:]

                if iteration > 500 and TRAIN_UNET == True:
                    # Once past 500 iterations, track the validation loss to check whether it is still decreasing
                    loss_vals.append(unet_val_loss_value)

                    if len(loss_vals) > 50:
                        avg_loss = np.asarray(loss_vals).mean()

                        if avg_loss > last_loss_avg:
                            print(
                                'Loss increased over the last 50 iterations, stopping!'
                            )
                            lib.plot.flush(
                                os.path.join(folder_path, CSV_LOG_NAME))
                            break
                        else:
                            last_loss_avg = avg_loss

                        del loss_vals[:]

                lib.plot.plot('time', time.time() - start_time)
                if (iteration < 5) or (iteration % 20 == 0):
                    #print('Before flush: dice: ' + str(dice_value_val) + ' loss: ' + str(unet_val_loss_value))
                    lib.plot.flush(os.path.join(folder_path, CSV_LOG_NAME))

                lib.plot.tick()
                iteration = iteration + 1

            coord.request_stop()
            coord.join(threads)

        return dice_value_train, dice_value_val, unet_val_loss_value, unet_train_loss_value
示例#12
0
                        elif model_type == 'BB_Unet':
                            var_bbox = batch["bboxes"].reshape(
                                -1, 1, size, size)
                            var_bbox = torch.tensor(var_bbox,
                                                    dtype=torch.float)
                            var_bbox = Variable(var_bbox, requires_grad=True)
                            preds = model(var_input, var_bbox)

                        loss = losses.dice_loss(preds.to('cpu'), gt_samples)
                        print('loss for batch {} on epoch {} is {}'.format(
                            i, epoch, loss))

                    predicted_output = threshold(preds.to('cpu'))
                    predicted_output = predicted_output.type(torch.float32)
                    d_metric, dict = dice_score(predicted_output.to('cpu'),
                                                gt_samples.to('cpu'))

                    if phase in ['train_ancillary', 'train']:
                        train_loss_total += loss
                        tr_dice_total += d_metric
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                        print('training on batch' + str(i) + 'with accuracy' +
                              str(d_metric))
                    else:
                        val_loss_total += loss
                        val_dice_total += d_metric
                        print('validating on batch : ' + str(i) +
                              'with accuracy' + str(d_metric))