示例#1
0
def test_generalized_dice():
    """Check `metrics.generalized_dice` on trivial and single-class inputs."""
    shape = (8, 32, 32, 32, 16)
    n_batch = shape[0]

    # Identical inputs (all zeros, then all ones) must score exactly one
    # for every item in the batch.
    for fill in (np.zeros, np.ones):
        y_true = fill(shape)
        y_pred = fill(shape)
        assert_array_equal(metrics.generalized_dice(y_true, y_pred),
                           np.ones(n_batch))

    # All ones against all zeros should score (approximately) zero.
    # Open question kept from the original author: why aren't the scores
    # exactly zero? Could it be the propagation of floating point
    # inaccuracies when summing?
    y_true = np.ones(shape)
    y_pred = np.zeros(shape)
    assert_allclose(metrics.generalized_dice(y_true, y_pred),
                    np.zeros(n_batch),
                    atol=1e-03)

    # Single-class volumes with partially overlapping zeroed regions.
    single_class_shape = (4, 32, 32, 32, 1)
    y_true = np.ones(single_class_shape, dtype=np.float64)
    y_pred = np.ones(single_class_shape, dtype=np.float64)
    y_true[:2, :10, 10:] = 0
    y_pred[:2, :3, 20:] = 0
    y_pred[3:, 10:] = 0

    # Dice is similar to generalized Dice for one class. The weight factor
    # makes the generalized form slightly different from Dice, hence the
    # relative tolerance instead of an exact comparison.
    generalized = metrics.generalized_dice(
        y_true, y_pred, axis=(1, 2, 3)).numpy()
    plain = metrics.dice(y_true, y_pred, axis=(1, 2, 3, 4)).numpy()
    assert_allclose(generalized, plain, rtol=1e-02)  # is this close enough?
示例#2
0
def dice_cce(y_true, y_pred, axis=(1, 2, 3), ignore_background=False):
    """Return the sum of a Dice term and categorical cross-entropy.

    Both `y_true` and `y_pred` should be one-hot encoded.

    Args:
        y_true: One-hot encoded ground-truth labels.
        y_pred: One-hot encoded (or probabilistic) predictions.
        axis: Axes forwarded to `generalized_dice` for the reduction.
        ignore_background: When true, the cross-entropy is multiplied by
            `1 - y_true[:, :, :, :, 0]`, i.e. it is zeroed wherever
            channel 0 of the labels is set.

    Returns:
        `(1 - generalized_dice(...)) + cross-entropy`.
    """
    # Forward `axis` so a caller-supplied reduction is actually respected.
    dice = 1 - generalized_dice(y_true, y_pred, axis=axis)

    cce = categorical_crossentropy(y_true, y_pred)
    if ignore_background:
        # The five-index slice requires rank-5 labels; channel 0 is
        # presumably the background class -- verify against the label map.
        mask = 1 - y_true[:, :, :, :, 0]
        cce = tf.math.multiply(cce, mask)

    # NOTE(review): `dice` and `cce` may have different shapes, in which
    # case this sum relies on broadcasting -- confirm this is intended.
    return dice + cce
示例#3
0
def calcualte_dice(label,
                   pred,
                   n_classes,
                   axis=(1, 2, 3),
                   one_hot_label=False):
    """Score network output against labels with `generalized_dice`.

    Args:
        label: Ground-truth labels; one-hot encoded here with depth
            `n_classes` unless `one_hot_label` is true.
        pred: Output probabilities of the network, passed through as-is
            (no argmax / re-encoding is applied).
        n_classes: Depth used when one-hot encoding `label`.
        axis: Axes forwarded to `generalized_dice`.
        one_hot_label: Set to true when `label` is already one-hot encoded.

    Returns:
        Whatever `generalized_dice(label, pred, axis=axis)` returns.
    """
    encoded = label if one_hot_label else tf.one_hot(label, depth=n_classes)
    return generalized_dice(encoded, pred, axis=axis)
示例#4
0
def generalized_dice(y_true, y_pred, axis=(1, 2, 3)):
    """Return one minus `metrics.generalized_dice` for the given inputs.

    Args:
        y_true: Ground-truth labels, forwarded as `y_true`.
        y_pred: Predictions, forwarded as `y_pred`.
        axis: Axes forwarded to `metrics.generalized_dice`.
    """
    score = metrics.generalized_dice(y_true=y_true, y_pred=y_pred, axis=axis)
    return 1.0 - score
def calculate_dice(labels, preds):
    """Compute `generalized_dice` between labels and hard predictions.

    Args:
        labels: Integer class labels (not one-hot encoded).
        preds: Per-class scores with the class dimension last; the arg-max
            over that last axis is taken as the predicted class.

    Returns:
        Whatever `generalized_dice(labels, preds, axis=(1, 2, 3))` returns
        for the one-hot encoded inputs.
    """
    # `tf.one_hot` requires an explicit `depth`; without it the call raises
    # TypeError. The number of classes is the size of the last axis of
    # `preds`, which is the axis the arg-max is taken over.
    n_classes = preds.shape[-1]
    labels = tf.one_hot(labels, depth=n_classes)
    preds = tf.one_hot(np.argmax(preds, -1), depth=n_classes)
    return generalized_dice(labels, preds, axis=(1, 2, 3))
示例#6
0
def run(block_shape, dropout_typ, model_name):
    """Train and evaluate a variational MeshNet, then persist the results.

    The function builds training and evaluation datasets from TFRecord
    shards, trains the model batch by batch under a
    `tf.distribute.MirroredStrategy`, evaluates it, and writes the
    following under `training_files/<model_name>/`:

    * `training_checkpoints/ckpt_<epoch>*` -- one checkpoint per epoch,
    * `evalout-<i>` -- labels and predictions for every 20th eval batch,
    * `data-<model_name>.json` -- training / evaluation losses and Dice,
    * `saved_model/` -- the final model in TensorFlow SavedModel format.

    Args:
        block_shape: Tuple giving the spatial block shape; a channel
            dimension of 1 is appended to form the model input shape.
        dropout_typ: Forwarded as `dropout` to `variational_meshnet`.
        model_name: Name used to build all output paths.

    Returns:
        Difference between two `time()` readings taken around the training
        loop (presumably seconds, if `time` is `time.time` -- confirm).
    """

    # Constants
    root_path = '/om/user/satra/kwyk/tfrecords/'
    # to run the code on Satori
    #root_path = "/nobackup/users/abizeul/kwyk/tfrecords/"

    train_pattern = root_path + 'data-train_shard-*.tfrec'
    eval_pattern = root_path + "data-evaluate_shard-*.tfrec"

    n_classes = 115
    volume_shape = (256, 256, 256)
    EPOCHS = 10
    BATCH_SIZE_PER_REPLICA = 1

    #Setting up the multi gpu strategy
    strategy = tf.distribute.MirroredStrategy()
    print("Number of replicas {}".format(strategy.num_replicas_in_sync))
    # The global batch is split across replicas, so scale it by their count.
    GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

    # Create a `tf.data.Dataset` instance.
    dataset_train = get_dataset(train_pattern, volume_shape, GLOBAL_BATCH_SIZE,
                                block_shape, n_classes)
    dataset_eval = get_dataset(eval_pattern, volume_shape, GLOBAL_BATCH_SIZE,
                               block_shape, n_classes)

    # Distribute dataset.
    #train_dist_dataset = strategy.experimental_distribute_dataset(dataset_train)

    # Create a checkpoint directory to store the checkpoints.
    checkpoint_dir = os.path.join("training_files", model_name,
                                  "training_checkpoints")
    # Name of the checkpoint files
    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

    # Model, optimizer and loss are created inside the strategy scope; the
    # training and evaluation loops below also run inside it.
    with strategy.scope():
        optimizer = tf.keras.optimizers.Adam(1e-04)
        model = variational_meshnet(n_classes=n_classes,
                                    input_shape=block_shape + (1, ),
                                    filters=96,
                                    dropout=dropout_typ,
                                    is_monte_carlo=True,
                                    receptive_field=129)
        # NOTE(review): `num_examples` is set to the number of voxels in one
        # block, not the number of training samples -- confirm intended.
        loss_fn = ELBO(model=model,
                       num_examples=np.prod(block_shape),
                       reduction=tf.keras.losses.Reduction.NONE)

        checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
        model.compile(loss=loss_fn,
                      optimizer=optimizer,
                      metrics=[generalized_dice],
                      experimental_run_tf_function=False)

        # outfile= os.path.join("training_files",model_name,"out-{}")

        # training loop
        train_loss = []
        train_metrics = []
        start = time()
        for epoch in range(EPOCHS):
            print('Epoch number ', epoch)
            # `i` is the 1-based batch counter within the current epoch.
            i = 0
            for data in dataset_train:
                i += 1
                # One loss value and one metric value are unpacked, matching
                # the single metric passed to `compile` above.
                error, metric = model.train_on_batch(data)
                train_loss.append(error)
                train_metrics.append(metric)
                print('Batch {}, error : {}, dice:{}'.format(i, error, metric))

            # Save one checkpoint at the end of every epoch.
            checkpoint.save(checkpoint_prefix.format(epoch=epoch))
            # result = model.predict_on_batch(data)
            # (feat, label) = data
            # np.savez(outfile.format(epoch),label=label.numpy(),result=result)
        training_time = time() - start

        # evaluating loop
        print("---------- evaluating ----------")
        i = 0
        eval_loss = []
        dice_scores = []
        outfile_eval = os.path.join("training_files", model_name, "evalout-{}")
        for data in dataset_eval:
            i += 1
            # NOTE(review): the model is compiled with a metric, so this may
            # return [loss, metric] rather than a scalar loss -- confirm; the
            # value is stored and printed as-is.
            eval_error = model.test_on_batch(data)
            eval_loss.append(eval_error)
            print('Batch {}, eval_loss : {}'.format(i, eval_error))

            # calculate dice
            # Predictions are hardened (arg-max over the last axis) and
            # one-hot encoded again so labels and predictions are compared in
            # the same encoding.
            result = model.predict_on_batch(data)
            result = np.argmax(result, -1)
            result = tf.one_hot(result, depth=n_classes)
            (feat, label) = data
            label = tf.one_hot(label, depth=n_classes)
            dice_score = generalized_dice(label, result, axis=(1, 2, 3))
            # Average the score and convert it to a plain Python value so it
            # can be written to JSON below.
            dice_scores.append(tf.reduce_mean(dice_score).numpy().tolist())
            # Dump labels and predictions for every 20th evaluation batch.
            if i % 20 == 0:
                np.savez(outfile_eval.format(i),
                         label=label.numpy(),
                         result=result)

        # Save model and variables
        variables = {
            "train_loss": train_loss,
            "train_dice": train_metrics,
            "eval_loss": eval_loss,
            "eval_dice": dice_scores
        }
        file_path = os.path.join("training_files", model_name,
                                 "data-{}.json".format(model_name))
        # NOTE(review): the collected train/eval values must be
        # JSON-serializable for this dump to succeed -- verify the types
        # returned by `train_on_batch` / `test_on_batch`.
        with open(file_path, 'w') as fp:
            json.dump(variables, fp, indent=4)

        #model_name="kwyk_128_full.h5"
        #saved_model_path=os.path.join("./training_files",model_name,"saved_model/{}.h5".format(model_name))
        #model.save(saved_model_path, save_format='h5')
        saved_model_path = os.path.join("./training_files", model_name,
                                        "saved_model/")
        model.save(saved_model_path, save_format='tf')

    return training_time
示例#7
0
def dice_loss(y_true, y_pred, axis=(1, 2, 3)):
    """Return one minus `generalized_dice` of the inputs.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predictions.
        axis: Axes forwarded to `generalized_dice` for the reduction.

    Returns:
        `1 - generalized_dice(y_true, y_pred, axis=axis)`.
    """
    # Forward `axis` so a caller-supplied reduction is actually respected.
    return 1 - generalized_dice(y_true, y_pred, axis=axis)