Example #1
def inference_net():
    input_shape = (91, 862, 862)
    raw = tf.placeholder(tf.float32, shape=input_shape)
    raw_bc = tf.reshape(raw, (1, 1) + input_shape)

    last_fmap, fov, anisotropy = unet.unet(
        raw_bc,
        12,  # initial number of feature maps (presumed unet.unet signature)
        6,   # feature-map increase factor per level (presumed)
        [[1, 3, 3], [1, 3, 3], [3, 3, 3]],  # per-level downsampling factors
        [  # kernel sizes per level, downsampling path
            [(1, 3, 3), (1, 3, 3)],
            [(1, 3, 3), (1, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
        ],
        [  # kernel sizes per level, upsampling path
            [(1, 3, 3), (1, 3, 3)],
            [(1, 3, 3), (1, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
        ],
        voxel_size=(10, 1, 1),
        fov=(10, 1, 1),
    )

    logits_bc, fov = ops3d.conv_pass(
        last_fmap,
        kernel_size=[[1, 1, 1]],
        num_fmaps=2,
        activation=None,
        fov=fov,
        voxel_size=anisotropy,
    )

    output_shape_bc = logits_bc.get_shape().as_list()

    output_shape_c = output_shape_bc[1:]  # strip the batch dimension
    output_shape = output_shape_c[1:]  # strip the channel dimension

    probabilities = tf.reshape(tf.nn.softmax(logits_bc, axis=1)[0], output_shape_c)
    predictions = tf.argmax(probabilities, axis=0)
    print(probabilities.name)  # printed so downstream scripts can fetch the tensor by name

    tf.train.export_meta_graph(filename="unet_inference.meta")
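All of these snippets assume the surrounding project's imports (tensorflow as tf, json, numpy, and the project-local unet and ops3d modules), which the excerpts omit. As a minimal sketch of how the exported inference graph might be consumed (the tensor names below are hypothetical placeholders; the real output name is whatever the print statement above emits):

import numpy as np
import tensorflow as tf

tf.train.import_meta_graph("unet_inference.meta")
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # or restore trained weights instead
    probs = sess.run(
        "Reshape_1:0",  # hypothetical name; use the one printed by inference_net()
        feed_dict={
            # "Placeholder:0" is likewise a guess at the input tensor's name
            "Placeholder:0": np.zeros((91, 862, 862), dtype=np.float32)
        },
    )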
Example #2
def inference_net():
    input_shape = (400, 400, 400)
    raw = tf.placeholder(tf.float32, shape=input_shape)
    raw_bc = tf.reshape(raw, (1, 1) + input_shape)

    last_fmap, fov, anisotropy = unet.unet(
        raw_bc,
        12,
        6,
        [[2, 2, 2], [2, 2, 2], [3, 3, 3]],
        [
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
        ],
        [
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
        ],
        voxel_size=(1, 1, 1),
        fov=(1, 1, 1),
    )
    pred_raw_bc, fov = ops3d.conv_pass(
        last_fmap,
        kernel_size=[[1, 1, 1]],
        num_fmaps=1,
        activation=None,
        fov=fov,
        voxel_size=anisotropy,
    )
    output_shape_bc = pred_raw_bc.get_shape().as_list()
    output_shape_c = output_shape_bc[1:]  # strip the batch dimension
    output_shape = output_shape_c[1:]  # strip the channel dimension

    pred_raw = tf.reshape(pred_raw_bc, output_shape)

    tf.train.export_meta_graph(filename="unet_inference.meta")
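The reshapes at the end work because valid convolutions shrink the volume deterministically: a kernel of size k trims k - 1 voxels per axis. A quick sketch of that arithmetic (a hypothetical helper, not part of ops3d):

def valid_trim(extent, kernel_sizes):
    # each valid convolution with kernel size k removes (k - 1) voxels per axis
    for k in kernel_sizes:
        extent -= k - 1
    return extent

# top level of this net: two 3x3x3 convolutions before the first 2x2x2 downsample
print(valid_trim(400, [3, 3]))  # 400 -> 396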
Example #3
def train_net():

    # [2, 2, 2] downsampling at every level: 220 -> 132 -> 44 per spatial axis
    shape_0 = (220, ) * 3
    shape_1 = (132, ) * 3
    shape_2 = (44, ) * 3

    affs_0_bc = tf.ones((1, 3) + shape_0) * 0.5  # phase-1 affinities: flat 0.5 prior

    with tf.variable_scope("autocontext") as scope:

        # phase 1
        raw_0 = tf.placeholder(tf.float32, shape=shape_0)
        raw_0_bc = tf.reshape(raw_0, (1, 1) + shape_0)

        input_0 = tf.concat([raw_0_bc, affs_0_bc], 1)

        out_bc, fov, anisotropy = unet.unet(
            input_0,
            24,
            3,
            [[2, 2, 2], [2, 2, 2], [2, 2, 2]],
            [
                [(3, 3, 3), (3, 3, 3)],
                [(3, 3, 3), (3, 3, 3)],
                [(3, 3, 3), (3, 3, 3)],
                [(3, 3, 3), (3, 3, 3)],
            ],
            [
                [(3, 3, 3), (3, 3, 3)],
                [(3, 3, 3), (3, 3, 3)],
                [(3, 3, 3), (3, 3, 3)],
                [(3, 3, 3), (3, 3, 3)],
            ],
        )

        affs_1_bc, fov = ops3d.conv_pass(
            out_bc,
            kernel_size=[[1, 1, 1]],
            num_fmaps=3,
            activation="sigmoid",
            fov=fov,
            voxel_size=anisotropy,
        )

        affs_1_c = tf.reshape(affs_1_bc, (3, ) + shape_1)
        gt_affs_1_c = tf.placeholder(tf.float32, shape=(3, ) + shape_1)
        loss_weights_1_c = tf.placeholder(tf.float32, shape=(3, ) + shape_1)

        loss_1 = tf.losses.mean_squared_error(gt_affs_1_c, affs_1_c,
                                              loss_weights_1_c)

        # phase 2
        tf.summary.scalar("loss_pred0", loss_1)
        scope.reuse_variables()

        raw_1 = ops3d.center_crop(raw_0, shape_1)
        raw_1_bc = tf.reshape(raw_1, (1, 1) + shape_1)

        input_1 = tf.concat([raw_1_bc, affs_1_bc], 1)

        out_bc, fov, anisotropy = unet.unet(
            input_1,
            24,
            3,
            [[2, 2, 2], [2, 2, 2], [2, 2, 2]],
            [
                [(3, 3, 3), (3, 3, 3)],
                [(3, 3, 3), (3, 3, 3)],
                [(3, 3, 3), (3, 3, 3)],
                [(3, 3, 3), (3, 3, 3)],
            ],
            [
                [(3, 3, 3), (3, 3, 3)],
                [(3, 3, 3), (3, 3, 3)],
                [(3, 3, 3), (3, 3, 3)],
                [(3, 3, 3), (3, 3, 3)],
            ],
            fov=fov,
            voxel_size=anisotropy,
        )

        affs_2_bc, fov = ops3d.conv_pass(
            out_bc,
            kernel_size=[[1, 1, 1]],
            num_fmaps=3,
            activation="sigmoid",
            fov=fov,
            voxel_size=anisotropy,
        )

        affs_2_c = tf.reshape(affs_2_bc, (3, ) + shape_2)
        gt_affs_2_c = ops3d.center_crop(gt_affs_1_c, (3, ) + shape_2)
        loss_weights_2_c = ops3d.center_crop(loss_weights_1_c, (3, ) + shape_2)

        loss_2 = tf.losses.mean_squared_error(gt_affs_2_c, affs_2_c,
                                              loss_weights_2_c)
        tf.summary.scalar("loss_pred1", loss_2)
    loss = loss_1 + loss_2
    tf.summary.scalar("loss_total", loss)
    tf.summary.scalar("loss_diff", loss_1 - loss_2)
    for trainable in tf.trainable_variables():
        custom_ops.tf_var_summary(trainable)
    merged = tf.summary.merge_all()

    opt = tf.train.AdamOptimizer(learning_rate=0.5e-4,
                                 beta1=0.95,
                                 beta2=0.999,
                                 epsilon=1e-8)
    optimizer = opt.minimize(loss)

    tf.train.export_meta_graph(filename="wnet.meta")

    names = {
        "raw": raw_0.name,
        "affs_1": affs_1_c.name,
        "affs_2": affs_2_c.name,
        "gt_affs": gt_affs_1_c.name,
        "loss_weights": loss_weights_1_c.name,
        "loss": loss.name,
        "optimizer": optimizer.name,
        "summary": merged.name,
    }
    with open("net_io_names.json", "w") as f:
        json.dump(names, f)
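A minimal sketch of how the exported names might drive one training step, assuming the TF1 session API and random numpy arrays standing in for a real data pipeline:

import json

import numpy as np
import tensorflow as tf

with open("net_io_names.json") as f:
    names = json.load(f)

tf.train.import_meta_graph("wnet.meta")
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    feed = {
        # shapes follow train_net() above; random data stands in for real volumes
        names["raw"]: np.random.rand(220, 220, 220).astype(np.float32),
        names["gt_affs"]: np.random.rand(3, 132, 132, 132).astype(np.float32),
        names["loss_weights"]: np.ones((3, 132, 132, 132), dtype=np.float32),
    }
    _, loss = sess.run([names["optimizer"], names["loss"]], feed_dict=feed)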
Example #4
def train_net():
    input_shape = (196, 196, 196)
    raw = tf.placeholder(tf.float32, shape=input_shape)
    raw_bc = tf.reshape(raw, (1, 1) + input_shape)

    last_fmap, fov, anisotropy = unet.unet(
        raw_bc,
        12,
        6,
        [[2, 2, 2], [2, 2, 2], [3, 3, 3]],
        [
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
        ],
        [
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
        ],
        voxel_size=(1, 1, 1),
        fov=(1, 1, 1),
    )
    pred_raw_bc, fov = ops3d.conv_pass(
        last_fmap,
        kernel_size=[[1, 1, 1]],
        num_fmaps=1,
        activation=None,
        fov=fov,
        voxel_size=anisotropy,
    )
    output_shape_bc = pred_raw_bc.get_shape().as_list()
    output_shape_c = output_shape_bc[1:]  # strip the batch dimension
    output_shape = output_shape_c[1:]  # strip the channel dimension

    pred_raw = tf.reshape(pred_raw_bc, output_shape)

    gt_raw_bc = ops3d.crop_zyx(raw_bc, output_shape_bc)
    gt_raw = tf.reshape(gt_raw_bc, output_shape)

    loss = tf.losses.mean_squared_error(gt_raw, pred_raw)
    tf.summary.scalar("loss", loss)

    opt = tf.train.AdamOptimizer(learning_rate=0.5e-4,
                                 beta1=0.95,
                                 beta2=0.999,
                                 epsilon=1e-8)
    optimizer = opt.minimize(loss)

    merged = tf.summary.merge_all()
    tf.train.export_meta_graph(filename="unet.meta")

    names = {
        "raw": raw.name,
        "pred_raw": pred_raw.name,
        "optimizer": optimizer.name,
        "summary": merged.name,
        "loss": loss.name,
    }
    with open("net_io_names.json", "w") as f:
        json.dump(names, f)
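ops3d.crop_zyx is project-local, so its exact behavior is not shown here; a plausible stand-in (an assumption, not the project's implementation) that center-crops a (b, c, z, y, x) tensor to a target shape could look like:

import tensorflow as tf

def center_crop_zyx(tensor_bc, target_shape_bc):
    # hypothetical stand-in for ops3d.crop_zyx: center-crop each axis to the
    # target extent; batch/channel offsets come out as 0 when those match
    shape = tensor_bc.get_shape().as_list()
    offsets = [(s - t) // 2 for s, t in zip(shape, target_shape_bc)]
    slices = tuple(slice(o, o + t) for o, t in zip(offsets, target_shape_bc))
    return tensor_bc[slices]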
Example #5
def train_net():
    input_shape = (43, 430, 430)
    raw = tf.placeholder(tf.float32, shape=input_shape)
    raw_bc = tf.reshape(raw, (1, 1) + input_shape)

    last_fmap, fov, anisotropy = unet.unet(
        raw_bc,
        12,
        6,
        [[1, 3, 3], [1, 3, 3], [3, 3, 3]],
        [
            [(1, 3, 3), (1, 3, 3)],
            [(1, 3, 3), (1, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
        ],
        [
            [(1, 3, 3), (1, 3, 3)],
            [(1, 3, 3), (1, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
        ],
        voxel_size=(10, 1, 1),
        fov=(10, 1, 1),
    )

    logits_bc, fov = ops3d.conv_pass(
        last_fmap,
        kernel_size=[[1, 1, 1]],
        num_fmaps=2,
        activation=None,
        fov=fov,
        voxel_size=anisotropy,
    )

    output_shape_bc = logits_bc.get_shape().as_list()

    output_shape_c = output_shape_bc[1:]  # strip the batch dimension
    output_shape = output_shape_c[1:]  # strip the channel dimension
    flat_logits = tf.transpose(tf.reshape(tensor=logits_bc, shape=(2, -1)))

    gt_labels = tf.placeholder(tf.float32, shape=output_shape)
    gt_labels_flat = tf.reshape(gt_labels, (-1,))

    gt_bg = tf.to_float(tf.not_equal(gt_labels_flat, 1))
    flat_ohe = tf.stack(values=[gt_labels_flat, gt_bg], axis=1)

    loss_weights = tf.placeholder(tf.float32, shape=output_shape)
    loss_weights_flat = tf.reshape(loss_weights, (-1,))

    mask = tf.placeholder(tf.float32, shape=output_shape)
    mask_flat = tf.reshape(mask, (-1,))

    probabilities = tf.reshape(tf.nn.softmax(logits_bc, axis=1)[0], output_shape_c)
    predictions = tf.argmax(probabilities, axis=0)

    ce_loss_balanced = tf.losses.softmax_cross_entropy(
        flat_ohe, flat_logits, weights=loss_weights_flat
    )
    ce_loss_unbalanced = tf.losses.softmax_cross_entropy(
        flat_ohe, flat_logits, weights=mask_flat
    )
    tf.summary.scalar("loss_balanced_syn", ce_loss_balanced)
    tf.summary.scalar("loss_unbalanced_syn", ce_loss_unbalanced)

    opt = tf.train.AdamOptimizer(
        learning_rate=0.5e-4, beta1=0.95, beta2=0.999, epsilon=1e-8
    )

    optimizer = opt.minimize(ce_loss_balanced)
    merged = tf.summary.merge_all()

    tf.train.export_meta_graph(filename="unet.meta")

    names = {
        "raw": raw.name,
        "probabilities": probabilities.name,
        "predictions": predictions.name,
        "gt_labels": gt_labels.name,
        "loss_balanced_syn": ce_loss_balanced.name,
        "loss_unbalanced_syn": ce_loss_unbalanced.name,
        "loss_weights": loss_weights.name,
        "mask": mask.name,
        "optimizer": optimizer.name,
        "summary": merged.name,
    }

    with open("net_io_names.json", "w") as f:
        json.dump(names, f)
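The one-hot tensor stacked from gt_labels_flat and gt_bg pairs each voxel's foreground label with its background indicator. A small numpy illustration of the same construction:

import numpy as np

gt = np.array([0., 1., 1., 0.], dtype=np.float32)  # flattened foreground labels
bg = (gt != 1).astype(np.float32)                  # 1 where the voxel is background
ohe = np.stack([gt, bg], axis=1)                   # shape (N, 2): [foreground, background]
print(ohe)
# [[0. 1.]
#  [1. 0.]
#  [1. 0.]
#  [0. 1.]]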