Example #1
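All snippets on this page assume a common import header, which the source listing omits. A plausible minimal version is sketched below (TF 1.x graph-mode API); `unet`, `ops3d`, `autoencoder`, `scale_net`, and `custom_ops` are modules local to the repository these examples were taken from, not pip packages.

# Shared imports assumed by the examples below (TF 1.x graph mode). The
# modules `unet`, `ops3d`, `autoencoder`, `scale_net`, and `custom_ops` are
# local to the source repository.
import json
import logging

import numpy as np
import tensorflow as tf

import ops3d
import unet
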
def inference_net():
    input_shape = (91, 862, 862)
    raw = tf.placeholder(tf.float32, shape=input_shape)
    raw_bc = tf.reshape(raw, (1, 1) + input_shape)

    last_fmap, fov, anisotropy = unet.unet(
        raw_bc,
        12,
        6,
        [[1, 3, 3], [1, 3, 3], [3, 3, 3]],
        [
            [(1, 3, 3), (1, 3, 3)],
            [(1, 3, 3), (1, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
        ],
        [
            [(1, 3, 3), (1, 3, 3)],
            [(1, 3, 3), (1, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
        ],
        voxel_size=(10, 1, 1),
        fov=(10, 1, 1),
    )

    logits_bc, fov = ops3d.conv_pass(
        last_fmap,
        kernel_size=[[1, 1, 1]],
        num_fmaps=2,
        activation=None,
        fov=fov,
        voxel_size=anisotropy,
    )

    output_shape_bc = logits_bc.get_shape().as_list()

    output_shape_c = output_shape_bc[1:]  # strip the batch dimension
    output_shape = output_shape_c[1:]  # strip the channel dimension

    probabilities = tf.reshape(tf.nn.softmax(logits_bc, axis=1)[0], output_shape_c)
    predictions = tf.argmax(probabilities, axis=0)
    print(probabilities.name)

    tf.train.export_meta_graph(filename="unet_inference.meta")
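
The exported meta graph can later be re-imported by name in a fresh session. A hedged sketch of the consuming side (TF 1.x); the tensor names and checkpoint path are hypothetical placeholders, in practice they come from the printed `probabilities.name` and from training:

# Sketch: re-import the exported inference graph (TF 1.x). Names below are
# placeholders; use the tensor names recorded at graph-construction time.
tf.train.import_meta_graph("unet_inference.meta")
with tf.Session() as sess:
    tf.train.Saver().restore(sess, "unet_checkpoint")  # assumed checkpoint path
    # raw_volume: float32 array of shape (91, 862, 862)
    probs = sess.run("probabilities:0", feed_dict={"raw:0": raw_volume})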
Example #2
def inference_net():
    input_shape = (91, 862, 862)
    raw = tf.placeholder(tf.float32, shape=input_shape)
    raw_bc = tf.reshape(raw, (1, 1) + input_shape)

    last_fmap, fov, anisotropy = autoencoder.autoencoder(
        raw_bc,
        12,
        6,
        [[1, 3, 3], [1, 3, 3], [3, 3, 3]],
        [
            [(1, 3, 3), (1, 3, 3)],
            [(1, 3, 3), (1, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
        ],
        [
            [(1, 3, 3), (1, 3, 3)],
            [(1, 3, 3), (1, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
        ],
        voxel_size=(10, 1, 1),
        fov=(10, 1, 1),
    )

    dist_bc, fov = ops3d.conv_pass(
        last_fmap,
        kernel_size=[[1, 1, 1]],
        num_fmaps=1,
        activation=None,
        fov=fov,
        voxel_size=anisotropy,
    )

    output_shape_bc = dist_bc.get_shape().as_list()

    output_shape_c = output_shape_bc[1:]
    output_shape = output_shape_c[1:]

    dist = tf.reshape(dist_bc, output_shape)

    tf.train.export_meta_graph(filename="autoencoder_inference.meta")
Example #3
def inference_net(labels):
    input_shape = (340, 340, 340)
    raw = tf.placeholder(tf.float32, shape=input_shape)
    raw_bc = tf.reshape(raw, (1, 1) + input_shape)

    last_fmap, fov, anisotropy = unet.unet(
        raw_bc,
        [12, 12 * 6, 12 * 6 * 6, 12 * 6 * 6 * 6],
        [48, 12 * 6, 12 * 6 * 6, 12 * 6 * 6 * 6],
        [[2, 2, 2], [2, 2, 2], [3, 3, 3]],
        [
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
        ],
        [
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
        ],
        voxel_size=(1, 1, 1),
        fov=(1, 1, 1),
    )

    dist_bc, fov = ops3d.conv_pass(
        last_fmap,
        kernel_size=[[1, 1, 1]],
        num_fmaps=len(labels),
        activation=None,
        fov=fov,
        voxel_size=anisotropy,
    )
    output_shape_bc = dist_bc.get_shape().as_list()
    output_shape_c = output_shape_bc[1:]  # strip the batch dimension
    output_shape = output_shape_c[1:]

    dist_c = tf.reshape(dist_bc, output_shape_c)
    network_outputs = tf.unstack(dist_c, len(labels), axis=0)
    tf.train.export_meta_graph(filename="unet_inference.meta")
Example #4
def inference_net():
    input_shape = (400, 400, 400)
    raw = tf.placeholder(tf.float32, shape=input_shape)
    raw_bc = tf.reshape(raw, (1, 1) + input_shape)

    last_fmap, fov, anisotropy = unet.unet(
        raw_bc,
        12,
        6,
        [[2, 2, 2], [2, 2, 2], [3, 3, 3]],
        [
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
        ],
        [
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
        ],
        voxel_size=(1, 1, 1),
        fov=(1, 1, 1),
    )
    pred_raw_bc, fov = ops3d.conv_pass(
        last_fmap,
        kernel_size=[[1, 1, 1]],
        num_fmaps=1,
        activation=None,
        fov=fov,
        voxel_size=anisotropy,
    )
    output_shape_bc = pred_raw_bc.get_shape().as_list()
    output_shape_c = output_shape_bc[1:]
    output_shape = output_shape_c[1:]

    pred_raw = tf.reshape(pred_raw_bc, output_shape)

    tf.train.export_meta_graph(filename="unet_inference.meta")
Example #5
def train_net():

    # isotropic [2, 2, 2] downsampling; each valid U-Net pass shrinks the
    # volume: 220 -> 132 -> 44
    shape_0 = (220, ) * 3
    shape_1 = (132, ) * 3
    shape_2 = (44, ) * 3

    affs_0_bc = tf.ones((1, 3) + shape_0) * 0.5

    with tf.variable_scope("autocontext") as scope:

        # phase 1
        raw_0 = tf.placeholder(tf.float32, shape=shape_0)
        raw_0_bc = tf.reshape(raw_0, (1, 1) + shape_0)

        input_0 = tf.concat([raw_0_bc, affs_0_bc], 1)

        out_bc, fov, anisotropy = unet.unet(
            input_0,
            24,
            3,
            [[2, 2, 2], [2, 2, 2], [2, 2, 2]],
            [
                [(3, 3, 3), (3, 3, 3)],
                [(3, 3, 3), (3, 3, 3)],
                [(3, 3, 3), (3, 3, 3)],
                [(3, 3, 3), (3, 3, 3)],
            ],
            [
                [(3, 3, 3), (3, 3, 3)],
                [(3, 3, 3), (3, 3, 3)],
                [(3, 3, 3), (3, 3, 3)],
                [(3, 3, 3), (3, 3, 3)],
            ],
        )

        affs_1_bc, fov = ops3d.conv_pass(
            out_bc,
            kernel_size=[[1, 1, 1]],
            num_fmaps=3,
            activation="sigmoid",
            fov=fov,
            voxel_size=anisotropy,
        )

        affs_1_c = tf.reshape(affs_1_bc, (3, ) + shape_1)
        gt_affs_1_c = tf.placeholder(tf.float32, shape=(3, ) + shape_1)
        loss_weights_1_c = tf.placeholder(tf.float32, shape=(3, ) + shape_1)

        loss_1 = tf.losses.mean_squared_error(gt_affs_1_c, affs_1_c,
                                              loss_weights_1_c)

        # phase 2
        tf.summary.scalar("loss_pred0", loss_1)
        scope.reuse_variables()

        raw_1 = ops3d.center_crop(raw_0, shape_1)
        raw_1_bc = tf.reshape(raw_1, (1, 1) + shape_1)

        input_1 = tf.concat([raw_1_bc, affs_1_bc], 1)

        out_bc, fov, anisotropy = unet.unet(
            input_1,
            24,
            3,
            [[2, 2, 2], [2, 2, 2], [2, 2, 2]],
            [
                [(3, 3, 3), (3, 3, 3)],
                [(3, 3, 3), (3, 3, 3)],
                [(3, 3, 3), (3, 3, 3)],
                [(3, 3, 3), (3, 3, 3)],
            ],
            [
                [(3, 3, 3), (3, 3, 3)],
                [(3, 3, 3), (3, 3, 3)],
                [(3, 3, 3), (3, 3, 3)],
                [(3, 3, 3), (3, 3, 3)],
            ],
            fov=fov,
            voxel_size=anisotropy,
        )

        affs_2_bc, fov = ops3d.conv_pass(
            out_bc,
            kernel_size=[[1, 1, 1]],
            num_fmaps=3,
            activation="sigmoid",
            fov=fov,
            voxel_size=anisotropy,
        )

        affs_2_c = tf.reshape(affs_2_bc, (3, ) + shape_2)
        gt_affs_2_c = ops3d.center_crop(gt_affs_1_c, (3, ) + shape_2)
        loss_weights_2_c = ops3d.center_crop(loss_weights_1_c, (3, ) + shape_2)

        loss_2 = tf.losses.mean_squared_error(gt_affs_2_c, affs_2_c,
                                              loss_weights_2_c)
        tf.summary.scalar("loss_pred1", loss_2)
    loss = loss_1 + loss_2
    tf.summary.scalar("loss_total", loss)
    tf.summary.scalar("loss_diff", loss_1 - loss_2)
    for trainable in tf.trainable_variables():
        custom_ops.tf_var_summary(trainable)
    merged = tf.summary.merge_all()

    opt = tf.train.AdamOptimizer(learning_rate=0.5e-4,
                                 beta1=0.95,
                                 beta2=0.999,
                                 epsilon=1e-8)
    optimizer = opt.minimize(loss)

    tf.train.export_meta_graph(filename="wnet.meta")

    names = {
        "raw": raw_0.name,
        "affs_1": affs_1_c.name,
        "affs_2": affs_2_c.name,
        "gt_affs": gt_affs_1_c.name,
        "loss_weights": loss_weights_1_c.name,
        "loss": loss.name,
        "optimizer": optimizer.name,
        "summary": merged.name,
    }
    with open("net_io_names.json", "w") as f:
        json.dump(names, f)
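
The `net_io_names.json` file written at the end is how the training loop finds its feed and fetch tensors after re-importing the meta graph. A minimal sketch of that consuming side, with hypothetical `batch_*` arrays:

with open("net_io_names.json") as f:
    io_names = json.load(f)
# After tf.train.import_meta_graph("wnet.meta") and restoring into `sess`,
# tensors can be fed and fetched by their recorded names:
# sess.run(
#     [io_names["optimizer"], io_names["loss"], io_names["summary"]],
#     feed_dict={
#         io_names["raw"]: batch_raw,
#         io_names["gt_affs"]: batch_gt_affs,
#         io_names["loss_weights"]: batch_loss_weights,
#     },
# )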
Example #6
def make_net(labels, added_steps, mode="train", loss_name="loss_total"):
    unet0 = scale_net.SerialUNet(
        [12, 12 * 6, 12 * 6**2],
        [48, 12 * 6, 12 * 6**2],
        [(3, 3, 3), (3, 3, 3)],
        [[(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)], [(3, 3, 3),
                                                          (3, 3, 3)]],
        [[(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]],
        input_voxel_size=(4, 4, 4),
    )
    unet1 = scale_net.SerialUNet(
        [12, 12 * 6, 12 * 6**2],
        [12 * 6**2, 12 * 6**2, 12 * 6**2],
        [(3, 3, 3), (3, 3, 3)],
        [[(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)], [(3, 3, 3),
                                                          (3, 3, 3)]],
        [[(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]],
        input_voxel_size=(36, 36, 36),
    )
    input_size = unet0.min_input_shape
    input_size_actual = input_size + added_steps * unet0.step_valid_shape
    scnet = scale_net.ScaleNet([unet0, unet1],
                               input_size_actual,
                               name="scnet_" + mode)
    inputs = []
    names = dict()
    for k, (inp, vs) in enumerate(zip(scnet.input_shapes, scnet.voxel_sizes)):
        raw = tf.placeholder(tf.float32, shape=inp)
        raw_bc = tf.reshape(raw, (1, 1) + tuple(inp.astype(int)))
        inputs.append(raw_bc)
        names["raw_{0:}".format(vs[0])] = raw.name

    last_fmap, fov, anisotropy = scnet.build(inputs)

    dist_bc, fov = ops3d.conv_pass(
        last_fmap,
        kernel_size=[(1, 1, 1)],
        num_fmaps=len(labels),
        activation=None,
        fov=fov,
        voxel_size=anisotropy,
    )
    output_shape_bc = dist_bc.get_shape().as_list()
    output_shape_c = output_shape_bc[1:]
    output_shape = output_shape_c[1:]

    dist_c = tf.reshape(dist_bc, output_shape_c)
    names["dist"] = dist_c.name
    network_outputs = tf.unstack(dist_c, len(labels), axis=0)
    if mode.lower() == "train" or mode.lower() == "training":
        # mask = tf.placeholder(tf.float32, shape=output_shape)
        # names['mask'] = mask.name
        # ribo_mask = tf.placeholder(tf.float32, shape=output_shape)
        # names['ribo_mask'] = ribo_mask.name
        gt = []
        w = []
        cw = []
        masks = []
        for l in labels:
            masks.append(tf.placeholder(tf.float32, shape=output_shape))
            gt.append(tf.placeholder(tf.float32, shape=output_shape))
            w.append(tf.placeholder(tf.float32, shape=output_shape))
            cw.append(l.class_weight)
        lb = []
        lub = []
        for output_it, gt_it, w_it, m_it, l in zip(network_outputs, gt, w,
                                                   masks, labels):
            lb.append(
                tf.losses.mean_squared_error(gt_it, output_it, w_it * m_it))
            lub.append(tf.losses.mean_squared_error(gt_it, output_it, m_it))
            # if l.labelname != 'ribosomes':
            #    lub.append(tf.losses.mean_squared_error(gt_it, output_it, mask))
            # else:
            #    lub.append(tf.losses.mean_squared_error(gt_it, output_it, ribo_mask))
            names[l.labelname] = output_it.name
            names["gt_" + l.labelname] = gt_it.name
            names["w_" + l.labelname] = w_it.name
            names["mask_" + l.labelname] = m_it.name
        for l, lb_it, lub_it in zip(labels, lb, lub):
            tf.summary.scalar("lb_" + l.labelname, lb_it)
            tf.summary.scalar("lub_" + l.labelname, lub_it)
            names["lb_" + l.labelname] = lb_it.name
            names["lub_" + l.labelname] = lub_it.name

        loss_total = tf.add_n(lb)
        loss_total_unbalanced = tf.add_n(lub)
        loss_total_classweighted = tf.tensordot(lb, cw, axes=1)
        loss_total_unbalanced_classweighted = tf.tensordot(lub, cw, axes=1)

        tf.summary.scalar("loss_total", loss_total)
        names["loss_total"] = loss_total.name
        tf.summary.scalar("loss_total_unbalanced", loss_total_unbalanced)
        names["loss_total_unbalanced"] = loss_total_unbalanced.name
        tf.summary.scalar("loss_total_classweighted", loss_total_classweighted)
        names["loss_total_classweighted"] = loss_total_classweighted.name
        tf.summary.scalar("loss_total_unbalanced_classweighted",
                          loss_total_unbalanced_classweighted)
        names[
            "loss_total_unbalanced_classweighted"] = loss_total_unbalanced_classweighted.name

        opt = tf.train.AdamOptimizer(learning_rate=0.5e-4,
                                     beta1=0.95,
                                     beta2=0.999,
                                     epsilon=1e-8)
        if loss_name == "loss_total":
            optimizer = opt.minimize(loss_total)
        elif loss_name == "loss_total_unbalanced":
            optimizer = opt.minimize(loss_total_unbalanced)
        elif loss_name == "loss_total_unbalanced_classweighted":
            optimizer = opt.minimize(loss_total_unbalanced_classweighted)
        elif loss_name == "loss_total_classweighted":
            optimizer = opt.minimize(loss_total_classweighted)
        else:
            raise ValueError(loss_name + " not defined")
        names["optimizer"] = optimizer.name
        merged = tf.summary.merge_all()
        names["summary"] = merged.name
        with open("net_io_names.json", "w") as f:
            json.dump(names, f)
    elif mode.lower() == "inference" or mode.lower() == "prediction":
        pass
    else:
        raise ValueError(
            "unknown mode for network construction: {0:}".format(mode))
    tf.train.export_meta_graph(filename=scnet.name + ".meta")
    return scnet
Example #7
def train_net():
    input_shape = (43, 430, 430)
    raw = tf.placeholder(tf.float32, shape=input_shape)
    raw_bc = tf.reshape(raw, (1, 1) + input_shape)

    last_fmap, fov, anisotropy = autoencoder.autoencoder(
        raw_bc,
        12,
        6,
        [[1, 3, 3], [1, 3, 3], [3, 3, 3]],
        [
            [(1, 3, 3), (1, 3, 3)],
            [(1, 3, 3), (1, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
        ],
        [
            [(1, 3, 3), (1, 3, 3)],
            [(1, 3, 3), (1, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
        ],
        voxel_size=(10, 1, 1),
        fov=(10, 1, 1),
    )

    dist_bc, fov = ops3d.conv_pass(
        last_fmap,
        kernel_size=[[1, 1, 1]],
        num_fmaps=1,
        activation=None,
        fov=fov,
        voxel_size=anisotropy,
    )

    output_shape_bc = dist_bc.get_shape().as_list()

    output_shape_c = output_shape_bc[1:]  # strip the batch dimension
    output_shape = output_shape_c[1:]  # strip the channel dimension

    dist = tf.reshape(dist_bc, output_shape)

    gt_dist = tf.placeholder(tf.float32, shape=output_shape)

    loss_weights = tf.placeholder(tf.float32, shape=output_shape)
    mask = tf.placeholder(tf.float32, shape=output_shape)

    loss_balanced = tf.losses.mean_squared_error(gt_dist, dist, loss_weights)
    tf.summary.scalar("loss_balanced_syn", loss_balanced)

    loss_unbalanced = tf.losses.mean_squared_error(gt_dist, dist, mask)
    tf.summary.scalar("loss_unbalanced_syn", loss_unbalanced)

    opt = tf.train.AdamOptimizer(learning_rate=0.5e-4,
                                 beta1=0.95,
                                 beta2=0.999,
                                 epsilon=1e-8)

    optimizer = opt.minimize(loss_balanced)
    merged = tf.summary.merge_all()

    tf.train.export_meta_graph(filename="autoencoder.meta")

    names = {
        "raw": raw.name,
        "dist": dist.name,
        "gt_dist": gt_dist.name,
        "loss_balanced_syn": loss_balanced.name,
        "loss_unbalanced_syn": loss_unbalanced.name,
        "loss_weights": loss_weights.name,
        "mask": mask.name,
        "optimizer": optimizer.name,
        "summary": merged.name,
    }

    with open("net_io_names.json", "w") as f:
        json.dump(names, f)
Example #8
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
        ],
        [
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
        ],
    )

    affs_bc, fov = ops3d.conv_pass(
        last_fmap,
        kernel_size=[[1, 1, 1]],
        num_fmaps=3,
        activation="sigmoid",
        fov=fov,
        voxel_size=anisotropy,
    )

    output_shape_bc = affs_bc.get_shape().as_list()
    output_shape_c = output_shape_bc[1:]  # strip the batch dimension

    affs_c = tf.reshape(affs_bc, output_shape_c)

    gt_affs_c = tf.placeholder(tf.float32, shape=output_shape_c)

    loss_weights_c = tf.placeholder(tf.float32, shape=output_shape_c)

    loss = tf.losses.mean_squared_error(gt_affs_c, affs_c, loss_weights_c)
Example #9
def train_net():
    input_shape = (196, 196, 196)
    raw = tf.placeholder(tf.float32, shape=input_shape)
    raw_bc = tf.reshape(raw, (1, 1) + input_shape)

    last_fmap, fov, anisotropy = unet.unet(
        raw_bc,
        12,
        6,
        [[2, 2, 2], [2, 2, 2], [3, 3, 3]],
        [
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
        ],
        [
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
        ],
        voxel_size=(1, 1, 1),
        fov=(1, 1, 1),
    )
    pred_raw_bc, fov = ops3d.conv_pass(
        last_fmap,
        kernel_size=[[1, 1, 1]],
        num_fmaps=1,
        activation=None,
        fov=fov,
        voxel_size=anisotropy,
    )
    output_shape_bc = pred_raw_bc.get_shape().as_list()
    output_shape_c = output_shape_bc[1:]
    output_shape = output_shape_c[1:]

    pred_raw = tf.reshape(pred_raw_bc, output_shape)

    gt_raw_bc = ops3d.crop_zyx(raw_bc, output_shape_bc)
    gt_raw = tf.reshape(gt_raw_bc, output_shape)

    loss = tf.losses.mean_squared_error(gt_raw, pred_raw)
    tf.summary.scalar("loss", loss)

    opt = tf.train.AdamOptimizer(learning_rate=0.5e-4,
                                 beta1=0.95,
                                 beta2=0.999,
                                 epsilon=1e-8)
    optimizer = opt.minimize(loss)

    merged = tf.summary.merge_all()
    tf.train.export_meta_graph(filename="unet.meta")

    names = {
        "raw": raw.name,
        "pred_raw": pred_raw.name,
        "optimizer": optimizer.name,
        "summary": merged.name,
        "loss": loss.name,
    }
    with open("net_io_names.json", "w") as f:
        json.dump(names, f)
Example #10
def make_net(unet, labels, added_steps, loss_name="loss_total", mode="train"):
    names = dict()
    input_size = unet.min_input_shape
    input_size_actual = (input_size + added_steps * unet.step_valid_shape).astype(int)

    raw = tf.placeholder(tf.float32, shape=tuple(input_size_actual))
    names["raw"] = raw.name
    raw_bc = tf.reshape(raw, (1, 1) + tuple(input_size_actual))
    last_fmap, fov, anisotropy = unet.build(raw_bc)
    dist_bc, fov = ops3d.conv_pass(
        last_fmap,
        kernel_size=[[1, 1, 1]],
        num_fmaps=len(labels),
        activation=None,
        padding=unet.padding,
        fov=fov,
        voxel_size=anisotropy,
    )
    output_shape_bc = dist_bc.get_shape().as_list()
    output_shape_c = output_shape_bc[1:]  # strip the batch dimension
    output_shape = output_shape_c[1:]

    dist_c = tf.reshape(dist_bc, output_shape_c)
    names["dist"] = dist_c.name
    network_outputs = tf.unstack(dist_c, len(labels), axis=0)
    if mode.lower() == "train" or mode.lower() == "training":
        mask = tf.placeholder(tf.float32, shape=output_shape)
        names["mask"] = mask.name
        # ribo_mask = tf.placeholder(tf.float32, shape=output_shape)

        gt = []
        w = []
        # cw = []
        masks = []
        for l in labels:
            masks.append(tf.placeholder(tf.float32, shape=output_shape))
            gt.append(tf.placeholder(tf.float32, shape=output_shape))
            w.append(tf.placeholder(tf.float32, shape=output_shape))
            #cw.append(l.class_weight)

        lb = []
        lub = []
        for output_it, gt_it, w_it, m_it, label in zip(
            network_outputs, gt, w, masks, labels
        ):
            lb.append(tf.losses.mean_squared_error(gt_it, output_it, w_it * m_it * mask))
            lub.append(tf.losses.mean_squared_error(gt_it, output_it, m_it * mask))
            names[label.labelname] = output_it.name
            names["gt_" + label.labelname] = gt_it.name
            names["w_" + label.labelname] = w_it.name
            names["mask_" + label.labelname] = m_it.name
        for label, lb_it, lub_it in zip(labels, lb, lub):
            tf.summary.scalar("lb_" + label.labelname, lb_it)
            tf.summary.scalar("lub_" + label.labelname, lub_it)
            names["lb_" + label.labelname] = lb_it.name
            names["lub_" + label.labelname] = lub_it.name

        loss_total = tf.add_n(lb)
        loss_total_unbalanced = tf.add_n(lub)
        # loss_total_classweighted = tf.tensordot(lb, cw, axes=1)
        # loss_total_unbalanced_classweighted = tf.tensordot(lub, cw, axes=1)

        tf.summary.scalar("loss_total", loss_total)
        names["loss_total"] = loss_total.name
        tf.summary.scalar("loss_total_unbalanced", loss_total_unbalanced)
        names["loss_total_unbalanced"] = loss_total_unbalanced.name
        # tf.summary.scalar("loss_total_classweighted", loss_total_classweighted)
        # names["loss_total_classweighted"] = loss_total_classweighted.name
        # tf.summary.scalar(
        #     "loss_total_unbalanced_classweighted", loss_total_unbalanced_classweighted
        # )
        # names[
        #     "loss_total_unbalanced_classweighted"
        # ] = loss_total_unbalanced_classweighted.name
        #
        opt = tf.train.AdamOptimizer(
            learning_rate=0.5e-4, beta1=0.95, beta2=0.999, epsilon=1e-8
        )
        if loss_name == "loss_total":
            optimizer = opt.minimize(loss_total)
        elif loss_name == "loss_total_unbalanced":
            optimizer = opt.minimize(loss_total_unbalanced)
        # elif loss_name == "loss_total_unbalanced_classweighted":
        #     optimizer = opt.minimize(loss_total_unbalanced_classweighted)
        # elif loss_name == "loss_total_classweighted":
        #     optimizer = opt.minimize(loss_total_classweighted)
        else:
            raise ValueError(loss_name + " not defined")
        names["optimizer"] = optimizer.name
        merged = tf.summary.merge_all()
        names["summary"] = merged.name

        with open("net_io_names.json", "w") as f:
            json.dump(names, f)
    elif (
        mode.lower() == "inference"
        or mode.lower() == "prediction"
        or mode.lower() == "pred"
    ):
        pass
    else:
        raise ValueError("unknown mode for network construction {0:}".format(mode))
    net_name = "unet_" + mode
    tf.train.export_meta_graph(filename=net_name + ".meta")
    return net_name, input_size_actual, output_shape
Example #11
def make_net(
    net_name,
    unet,
    n_out,
    added_context,
    sigma=1.0,
    lamb=1.0,
    input_name="raw",
    output_names=None,
    loss_name="loss_total",
    mode="train",
):

    names = dict()
    input_size = unet.min_input_shape
    if unet.padding == "valid":
        assert np.all(
            np.array(added_context) % np.array(unet.step_valid_shape) == 0
        ), "input shape not suitable for valid padding"
    else:
        if not np.all(np.array(added_context) > 0):
            logging.warning(
                "Small input shape does not generate any output elements free of influence from padding"
            )

    input_size_actual = (np.array(input_size) + np.array(added_context)).astype(int)

    input = tf.placeholder(tf.float32, shape=tuple(input_size_actual))
    names[input_name] = input.name
    input_bc = tf.reshape(input, (1, 1) + tuple(input_size_actual))
    last_fmap, fov, anisotropy = unet.build(input_bc)
    output_bc, fov = ops3d.conv_pass(
        last_fmap,
        kernel_size=[[1, 1, 1]],
        num_fmaps=n_out,
        activation=None,
        padding=unet.padding,
        fov=fov,
        voxel_size=anisotropy,
    )
    output_shape_bc = output_bc.get_shape().as_list()
    output_shape_c = output_shape_bc[1:]  # strip the batch dimension
    output_shape = output_shape_c[1:]

    output_c = tf.reshape(output_bc, output_shape_c)
    names["output"] = output_c.name
    network_outputs = tf.unstack(output_c, n_out, axis=0)

    blurred_full = ops3d.gaussian_blur(input_bc, sigma)
    blurred_bc = ops3d.crop_zyx(blurred_full, output_shape_bc)
    blurred_c = tf.reshape(blurred_bc, output_shape_c)
    blurred = tf.reshape(blurred_c, output_shape)
    names["blurred"] = blurred_c.name

    if output_names is None:
        output_names = ["output_{0:}".format(n) for n in range(n_out)]
    assert len(output_names) == n_out
    if mode.lower() == "training" or mode.lower() == "forward":
        target = []
        for tgt in range(n_out):
            target.append(tf.placeholder(tf.float32, shape=output_shape))

        loss_l2 = []
        loss_l1 = []
        loss_l2_gauss = []
        loss_l1_gauss = []

        for output_it, tgt_it, out_name in zip(network_outputs, target,
                                               output_names):
            names[out_name + "_predicted"] = output_it.name
            names[out_name + "_target"] = tgt_it.name

            l2 = tf.losses.mean_squared_error(tgt_it, output_it)
            loss_l2.append(l2)
            tf.summary.scalar("l2_" + out_name, l2)
            names[out_name + "_l2"] = l2.name

            l1 = tf.losses.absolute_difference(tgt_it, output_it)
            loss_l1.append(l1)
            tf.summary.scalar("l1_" + out_name, l1)
            names[out_name + "_l1"] = l1.name

            l2_gauss = tf.losses.mean_squared_error(blurred, output_it)
            loss_l2_gauss.append(l2_gauss)
            tf.summary.scalar("l2_gauss_" + out_name, l2_gauss)
            names[out_name + "_l2_gauss"] = l2_gauss.name

            l1_gauss = tf.losses.absolute_difference(blurred, output_it)
            loss_l1_gauss.append(l1_gauss)
            tf.summary.scalar("l1_gauss_" + out_name, l1_gauss)
            names[out_name + "_l1_gauss"] = l1_gauss.name

        l2_total = tf.add_n(loss_l2)
        tf.summary.scalar("l2_total", l2_total)
        l2_gp_readout = tf.reshape(l2_total, (1, ) * 3)
        names["L2"] = l2_gp_readout.name

        l1_total = tf.add_n(loss_l1)
        tf.summary.scalar("l1_total", l1_total)
        l1_gp_readout = tf.reshape(l1_total, (1, ) * 3)
        names["L1"] = l1_gp_readout.name

        l2_gauss_total = tf.add_n(loss_l2_gauss)
        tf.summary.scalar("l2_gauss_total", l2_gauss_total)
        l2_gauss_gp_readout = tf.reshape(l2_gauss_total, (1, ) * 3)
        names["L2gauss"] = l2_gauss_gp_readout.name

        l1_gauss_total = tf.add_n(loss_l1_gauss)
        tf.summary.scalar("l1_gauss_total", l1_gauss_total)
        l1_gauss_gp_readout = tf.reshape(l1_gauss_total, (1, ) * 3)
        names["L1gauss"] = l1_gauss_gp_readout.name

        if loss_name == "L2":
            loss_opt = l2_total
        elif loss_name == "L1":
            loss_opt = l1_total
        elif loss_name == "L2+L2gauss":
            loss_opt = l2_total + lamb * l2_gauss_total
        elif loss_name == "L2+L1gauss":
            loss_opt = l2_total + lamb * l1_gauss_total
        elif loss_name == "L1+L2gauss":
            loss_opt = l1_total + lamb * l2_gauss_total
        elif loss_name == "L1+L1gauss":
            loss_opt = l1_total + lamb * l1_gauss_total
        else:
            raise ValueError(loss_name + " not defined")
        names["loss"] = loss_opt.name
        if mode.lower() == "training":
            opt = tf.train.AdamOptimizer(learning_rate=0.5e-4,
                                         beta1=0.95,
                                         beta2=0.999,
                                         epsilon=1e-8)
            optimizer = opt.minimize(loss_opt)
            names["optimizer"] = optimizer.name
            merged = tf.summary.merge_all()
            names["summary"] = merged.name

            with open("{0:}_io_names.json".format(net_name), "w") as f:
                json.dump(names, f)
    elif (mode.lower() == "inference" or mode.lower() == "prediction"
          or mode.lower() == "pred"):
        pass
    else:
        raise ValueError(
            "unknown mode for network construction {0:}".format(mode))

    tf.train.export_meta_graph(filename=net_name + "_" + mode + ".meta")
    return net_name, input_size_actual, output_shape
Example #12
def train_net():
    input_shape = (43, 430, 430)
    raw = tf.placeholder(tf.float32, shape=input_shape)
    raw_bc = tf.reshape(raw, (1, 1) + input_shape)

    last_fmap, fov, anisotropy = unet.unet(
        raw_bc,
        12,
        6,
        [[1, 3, 3], [1, 3, 3], [3, 3, 3]],
        [
            [(1, 3, 3), (1, 3, 3)],
            [(1, 3, 3), (1, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
        ],
        [
            [(1, 3, 3), (1, 3, 3)],
            [(1, 3, 3), (1, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
        ],
        voxel_size=(10, 1, 1),
        fov=(10, 1, 1),
    )

    logits_bc, fov = ops3d.conv_pass(
        last_fmap,
        kernel_size=[[1, 1, 1]],
        num_fmaps=2,
        activation=None,
        fov=fov,
        voxel_size=anisotropy,
    )

    output_shape_bc = logits_bc.get_shape().as_list()

    output_shape_c = output_shape_bc[1:]  # strip the batch dimension
    output_shape = output_shape_c[1:]  # strip the channel dimension
    flat_logits = tf.transpose(tf.reshape(tensor=logits_bc, shape=(2, -1)))

    gt_labels = tf.placeholder(tf.float32, shape=output_shape)
    gt_labels_flat = tf.reshape(gt_labels, (-1,))

    gt_bg = tf.to_float(tf.not_equal(gt_labels_flat, 1))
    flat_ohe = tf.stack(values=[gt_labels_flat, gt_bg], axis=1)

    loss_weights = tf.placeholder(tf.float32, shape=output_shape)
    loss_weights_flat = tf.reshape(loss_weights, (-1,))

    mask = tf.placeholder(tf.float32, shape=output_shape)
    mask_flat = tf.reshape(mask, (-1,))

    probabilities = tf.reshape(tf.nn.softmax(logits_bc, axis=1)[0], output_shape_c)
    predictions = tf.argmax(probabilities, axis=0)

    ce_loss_balanced = tf.losses.softmax_cross_entropy(
        flat_ohe, flat_logits, weights=loss_weights_flat
    )
    ce_loss_unbalanced = tf.losses.softmax_cross_entropy(
        flat_ohe, flat_logits, weights=mask_flat
    )
    tf.summary.scalar("loss_balanced_syn", ce_loss_balanced)
    tf.summary.scalar("loss_unbalanced_syn", ce_loss_unbalanced)

    opt = tf.train.AdamOptimizer(
        learning_rate=0.5e-4, beta1=0.95, beta2=0.999, epsilon=1e-8
    )

    optimizer = opt.minimize(ce_loss_balanced)
    merged = tf.summary.merge_all()

    tf.train.export_meta_graph(filename="unet.meta")

    names = {
        "raw": raw.name,
        "probabilities": probabilities.name,
        "predictions": predictions.name,
        "gt_labels": gt_labels.name,
        "loss_balanced_syn": ce_loss_balanced.name,
        "loss_unbalanced_syn": ce_loss_unbalanced.name,
        "loss_weights": loss_weights.name,
        "mask": mask.name,
        "optimizer": optimizer.name,
        "summary": merged.name,
    }

    with open("net_io_names.json", "w") as f:
        json.dump(names, f)
Example #13
def train_net(labels):
    input_shape = (196, 196, 196)
    raw = tf.placeholder(tf.float32, shape=input_shape)
    raw_bc = tf.reshape(raw, (1, 1) + input_shape)

    last_fmap, fov, anisotropy = unet.unet(
        raw_bc,
        [12, 12 * 6, 12 * 6 * 6, 12 * 6 * 6 * 6],
        [48, 12 * 6, 12 * 6 * 6, 12 * 6 * 6 * 6],
        [[2, 2, 2], [2, 2, 2], [3, 3, 3]],
        [
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
        ],
        [
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
        ],
        voxel_size=(1, 1, 1),
        fov=(1, 1, 1),
    )

    dist_bc, fov = ops3d.conv_pass(
        last_fmap,
        kernel_size=[[1, 1, 1]],
        num_fmaps=len(labels),
        activation=None,
        fov=fov,
        voxel_size=anisotropy,
    )
    output_shape_bc = dist_bc.get_shape().as_list()
    output_shape_c = output_shape_bc[1:]  # strip the batch dimension
    output_shape = output_shape_c[1:]

    dist_c = tf.reshape(dist_bc, output_shape_c)
    network_outputs = tf.unstack(dist_c, len(labels), axis=0)
    mask = tf.placeholder(tf.float32, shape=output_shape)

    gt = []
    w = []
    for l in range(len(labels)):
        gt.append(tf.placeholder(tf.float32, shape=output_shape))
        w.append(tf.placeholder(tf.float32, shape=output_shape))
    lb = []
    lub = []
    for output_it, gt_it, w_it in zip(network_outputs, gt, w):
        lb.append(tf.losses.mean_squared_error(gt_it, output_it, w_it))
        lub.append(tf.losses.mean_squared_error(gt_it, output_it, mask))
    for label, lb_it, lub_it in zip(labels, lb, lub):
        tf.summary.scalar("lb_" + label, lb_it)
        tf.summary.scalar("lub_" + label, lub_it)
    loss_total = tf.add_n(lb)
    loss_total_unbalanced = tf.add_n(lub)
    tf.summary.scalar("loss_total", loss_total)
    tf.summary.scalar("loss_total_unbalanced", loss_total_unbalanced)
    opt = tf.train.AdamOptimizer(learning_rate=0.5e-4,
                                 beta1=0.95,
                                 beta2=0.999,
                                 epsilon=1e-8)

    optimizer = opt.minimize(loss_total)
    merged = tf.summary.merge_all()

    tf.train.export_meta_graph(filename="unet.meta")

    names = {
        "raw": raw.name,
        "dist": dist_c.name,
        "loss_total": loss_total.name,
        "loss_total_unbalanced": loss_total_unbalanced.name,
        "mask": mask.name,
        "optimizer": optimizer.name,
        "summary": merged.name,
    }
    for label, output_it, gt_it, w_it, lb_it, lub_it in zip(
            labels, network_outputs, gt, w, lb, lub):
        names[label] = output_it.name
        names["gt_" + label] = gt_it.name
        names["w_" + label] = w_it.name
        names["lb_" + label] = lb_it.name
        names["lub_" + label] = lub_it.name

    with open("net_io_names.json", "w") as f:
        json.dump(names, f)
Example #14
def make_net(unet, added_steps, loss_name="loss_total", padding="valid", mode="train"):
    # input_shape = (43, 430, 430)
    names = dict()
    if padding == "valid":
        input_size = unet.min_input_shape
    else:
        input_size = np.array((0, 0, 0))
    input_size_actual = (input_size + added_steps * unet.step_valid_shape).astype(int)
    raw = tf.placeholder(tf.float32, shape=tuple(input_size_actual))
    names["raw"] = raw.name
    raw_bc = tf.reshape(raw, (1, 1) + tuple(input_size_actual))

    last_fmap, fov, anisotropy = unet.build(raw_bc)

    dist_bc, fov = ops3d.conv_pass(
        last_fmap,
        kernel_size=[[1, 1, 1]],
        num_fmaps=3,
        activation=None,
        padding=padding,
        fov=fov,
        voxel_size=anisotropy,
    )
    output_shape_bc = dist_bc.get_shape().as_list()
    output_shape_c = output_shape_bc[1:]
    output_shape = output_shape_c[1:]

    dist_c = tf.reshape(dist_bc, shape=output_shape_c)
    names["dist"] = dist_c.name
    cleft_dist, pre_dist, post_dist = tf.unstack(dist_c, 3, axis=0)
    names["cleft_dist"] = cleft_dist.name
    names["pre_dist"] = pre_dist.name
    names["post_dist"] = post_dist.name

    if mode.lower() == "train" or mode.lower() == "training":
        gt_cleft_dist = tf.placeholder(tf.float32, shape=output_shape)
        gt_pre_dist = tf.placeholder(tf.float32, shape=output_shape)
        gt_post_dist = tf.placeholder(tf.float32, shape=output_shape)

        names["gt_cleft_dist"] = gt_cleft_dist.name
        names["gt_pre_dist"] = gt_pre_dist.name
        names["gt_post_dist"] = gt_post_dist.name

        loss_weights_cleft = tf.placeholder(tf.float32, shape=output_shape)
        loss_weights_pre = tf.placeholder(tf.float32, shape=output_shape)
        loss_weights_post = tf.placeholder(tf.float32, shape=output_shape)

        names["loss_weights_cleft"] = loss_weights_cleft.name
        names["loss_weights_pre"] = loss_weights_pre.name
        names["loss_weights_post"] = loss_weights_post.name

        cleft_mask = tf.placeholder(tf.float32, shape=output_shape)
        pre_mask = tf.placeholder(tf.float32, shape=output_shape)
        post_mask = tf.placeholder(tf.float32, shape=output_shape)

        names["cleft_mask"] = cleft_mask.name
        names["pre_mask"] = pre_mask.name
        names["post_mask"] = post_mask.name

        loss_balanced_cleft = tf.losses.mean_squared_error(
            gt_cleft_dist, cleft_dist, loss_weights_cleft * cleft_mask
        )
        loss_balanced_pre = tf.losses.mean_squared_error(
            gt_pre_dist, pre_dist, loss_weights_pre * pre_mask
        )
        loss_balanced_post = tf.losses.mean_squared_error(
            gt_post_dist, post_dist, loss_weights_post * post_mask
        )

        names["loss_balanced_cleft"] = loss_balanced_cleft.name
        names["loss_balanced_pre"] = loss_balanced_pre.name
        names["loss_balanced_post"] = loss_balanced_post.name

        loss_unbalanced_cleft = tf.losses.mean_squared_error(
            gt_cleft_dist, cleft_dist, cleft_mask
        )
        loss_unbalanced_pre = tf.losses.mean_squared_error(
            gt_pre_dist, pre_dist, pre_mask
        )
        loss_unbalanced_post = tf.losses.mean_squared_error(
            gt_post_dist, post_dist, post_mask
        )

        names["loss_unbalanced_cleft"] = loss_unbalanced_cleft.name
        names["loss_unbalanced_pre"] = loss_unbalanced_pre.name
        names["loss_unbalanced_post"] = loss_unbalanced_post.name

        loss_total = loss_balanced_cleft + loss_balanced_pre + loss_balanced_post
        loss_total_unbalanced = (
            loss_unbalanced_cleft + loss_unbalanced_pre + loss_unbalanced_post
        )
        names["loss_total"] = loss_total.name
        names["loss_total_unbalanced"] = loss_total_unbalanced.name

        tf.summary.scalar("loss_balanced_cleft", loss_balanced_cleft)
        tf.summary.scalar("loss_balanced_pre", loss_balanced_pre)
        tf.summary.scalar("loss_balanced_post", loss_balanced_post)

        tf.summary.scalar("loss_unbalanced_cleft", loss_unbalanced_cleft)
        tf.summary.scalar("loss_unbalanced_pre", loss_unbalanced_pre)
        tf.summary.scalar("loss_unbalanced_post", loss_unbalanced_post)
        tf.summary.scalar("loss_total", loss_total)
        tf.summary.scalar("loss_total_unbalanced", loss_total_unbalanced)

        opt = tf.train.AdamOptimizer(
            learning_rate=0.5e-4, beta1=0.95, beta2=0.999, epsilon=1e-8
        )
        if loss_name == "loss_total":
            optimizer = opt.minimize(loss_total)
        elif loss_name == "loss_total_unbalanced":
            optimizer = opt.minimize(loss_total_unbalanced)
        else:
            raise ValueError(loss_name + " not defined")
        names["optimizer"] = optimizer.name

        merged = tf.summary.merge_all()
        names["summary"] = merged.name

        with open("net_io_names.json", "w") as f:
            json.dump(names, f)
    elif (
        mode.lower() == "inference"
        or mode.lower() == "prediction"
        or mode.lower() == "pred"
    ):
        pass
    else:
        raise ValueError("unknown mode for netowrk construction: {0:}".format(mode))
    net_name = "unet_" + mode
    tf.train.export_meta_graph(filename=net_name + ".meta")
    return net_name, input_size_actual, output_shape
Example #15
def unet(
    fmaps_in,
    num_fmaps_down,
    num_fmaps_up,
    downsample_factors,
    kernel_size_down,
    kernel_size_up,
    activation="relu",
    layer=0,
    fov=(1, 1, 1),
    voxel_size=(1, 1, 1),
    constant_upsample=False,
):

    """Create a U-Net::
        f_in --> f_left --------------------------->> f_right--> f_out
                    |                                   ^
                    v                                   |
                 g_in --> g_left ------->> g_right --> g_out
                             |               ^
                             v               |
                                   ...
    where each ``-->`` is a convolution pass (see ``conv_pass``), each ``-->>``
    a crop, and down and up arrows are max-pooling and transposed convolutions,
    respectively.
    The U-Net expects tensors to have shape ``(batch=1, channels, depth, height,
    width)``.
    This U-Net performs only "valid" convolutions, i.e., sizes of the feature
    maps decrease after each convolution.
    Args:
        fmaps_in:
            The input tensor.
        num_fmaps_down:
            List of the number of feature maps per level on the down
            (contracting) side; ``num_fmaps_down[0]`` is the number of feature
            maps in the first layer.
        num_fmaps_up:
            List of the number of feature maps per level on the up (expanding)
            side; ``num_fmaps_up[0]`` is also the number of output feature
            maps.
        downsample_factors:
            List of lists ``[z, y, x]`` to use to down- and up-sample the
            feature maps between layers.
        kernel_size_down:
            List of lists of tuples ``(z, y, x)`` of kernel sizes. The number of
            tuples in a list determines the number of convolutional layers in the
            corresponding level of the build on the left side.
        kernel_size_up:
            List of lists of tuples ``(z, y, x)`` of kernel sizes. The number of
            tuples in a list determines the number of convolutional layers in the
            corresponding level of the build on the right side. Within one of the
            lists, the kernel sizes are applied to the convolutions from left to
            right.
        activation:
            Which activation to use after a convolution. Accepts the name of any
            tensorflow activation function (e.g., ``relu`` for ``tf.nn.relu``).
        layer:
            Used internally to build the U-Net recursively.
        fov:
            Initial field of view, in physical units.
        voxel_size:
            Size of a voxel in the input data, in physical units.
        constant_upsample:
            Whether to use a constant, non-trainable upsampling instead of a
            learned transposed convolution.

    """

    prefix = "    " * layer
    print(prefix + "Creating U-Net layer %i" % layer)
    print(prefix + "f_in: " + str(fmaps_in.shape))
    # if isinstance(fmap_inc_factors, int):
    #    fmap_inc_factors = [fmap_inc_factors]*len(downsample_factors)
    assert (
        len(num_fmaps_down) - 1
        == len(num_fmaps_up) - 1
        == len(downsample_factors)
        == len(kernel_size_down) - 1
        == len(kernel_size_up) - 1
    )
    # convolve
    with tf.name_scope("lev%i" % layer):

        f_left, fov = ops3d.conv_pass(
            fmaps_in,
            kernel_size=kernel_size_down[layer],
            num_fmaps=num_fmaps_down[layer],
            activation=activation,
            name="unet_layer_%i_left" % layer,
            fov=fov,
            voxel_size=voxel_size,
            prefix=prefix,
        )

        # last layer does not recurse
        bottom_layer = layer == len(downsample_factors)

        if bottom_layer:
            print(prefix + "bottom layer")
            print(prefix + "f_out: " + str(f_left.shape))
            return f_left, fov, voxel_size

        # downsample

        g_in, fov, voxel_size = ops3d.downsample(
            f_left,
            downsample_factors[layer],
            "unet_down_%i_to_%i" % (layer, layer + 1),
            fov=fov,
            voxel_size=voxel_size,
            prefix=prefix,
        )

        # recursive U-net
        g_out, fov, voxel_size = unet(
            g_in,
            num_fmaps_down=num_fmaps_down,
            num_fmaps_up=num_fmaps_up,
            downsample_factors=downsample_factors,
            kernel_size_down=kernel_size_down,
            kernel_size_up=kernel_size_up,
            activation=activation,
            layer=layer + 1,
            fov=fov,
            voxel_size=voxel_size,
            constant_upsample=constant_upsample,
        )

        print(prefix + "g_out: " + str(g_out.shape))

        # upsample
        g_out_upsampled, voxel_size = ops3d.upsample(
            g_out,
            downsample_factors[layer],
            num_fmaps_up[layer],
            activation=activation,
            name="unet_up_%i_to_%i" % (layer + 1, layer),
            fov=fov,
            voxel_size=voxel_size,
            prefix=prefix,
            constant_upsample=constant_upsample,
        )

        print(prefix + "g_out_upsampled: " + str(g_out_upsampled.shape))

        # copy-crop
        f_left_cropped = ops3d.crop_zyx(f_left, g_out_upsampled.get_shape().as_list())

        print(prefix + "f_left_cropped: " + str(f_left_cropped.shape))

        # concatenate along channel dimension
        f_right = tf.concat([f_left_cropped, g_out_upsampled], 1)

        print(prefix + "f_right: " + str(f_right.shape))

        # convolve
        f_out, fov = ops3d.conv_pass(
            f_right,
            kernel_size=kernel_size_up[layer],
            num_fmaps=num_fmaps_up[layer],
            name="unet_layer_%i_right" % layer,
            fov=fov,
            voxel_size=voxel_size,
            prefix=prefix,
        )

        print(prefix + "f_out: " + str(f_out.shape))

    return f_out, fov, voxel_size
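
For reference, a usage sketch of this function, mirroring the per-level feature-map configuration of the isotropic training example above (assumes the shared imports from Example #1 and calls the function defined here directly):

# Sketch: build the recursive U-Net with explicit per-level feature-map counts
# (same configuration as the isotropic (196, 196, 196) example earlier).
raw = tf.placeholder(tf.float32, shape=(196, 196, 196))
raw_bc = tf.reshape(raw, (1, 1, 196, 196, 196))
last_fmap, fov, voxel_size = unet(
    raw_bc,
    num_fmaps_down=[12, 12 * 6, 12 * 6 ** 2, 12 * 6 ** 3],
    num_fmaps_up=[48, 12 * 6, 12 * 6 ** 2, 12 * 6 ** 3],
    downsample_factors=[[2, 2, 2], [2, 2, 2], [3, 3, 3]],
    kernel_size_down=[[(3, 3, 3), (3, 3, 3)]] * 4,
    kernel_size_up=[[(3, 3, 3), (3, 3, 3)]] * 4,
)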
Example #16
            [(1, 3, 3), (1, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
        ],
        [
            [(1, 3, 3), (1, 3, 3)],
            [(1, 3, 3), (1, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
        ],
        voxel_size=(10, 1, 1),
        fov=(10, 1, 1),
    )

    output, full_fov = ops3d.conv_pass(
        model,
        kernel_size=[(1, 1, 1)],
        num_fmaps=1,
        activation=None,
        fov=ll_fov,
        voxel_size=vx,
    )

    tf.train.export_meta_graph(filename="build.meta")

    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        tf.summary.FileWriter(".", graph=tf.get_default_graph())

    print(model.shape)