def cremi_unet(name='unet', sample_to_isotropy=False):
    """Build and export the CREMI U-Net training graph (TensorFlow 1.x graph mode).

    A single unbatched float32 raw volume of shape (84, 268, 268) is fed
    through ``networks.unet`` and a 1x1 ``networks.conv_pass`` producing 12
    sigmoid-activated output maps (presumably affinities -- confirm against
    the training pipeline).

    Training setup:
      * Loss: weighted mean-squared error against a ground-truth placeholder.
      * Optimizer: Adam.

    Nothing is returned; the function works through side effects:
      * ``tf.train.export_meta_graph`` writes ``'<name>.meta'``.
      * The graph names of the I/O tensors and ops are dumped as a JSON object
        to ``net_io_names.json`` in the current working directory.
        NOTE(review): that file name does not depend on ``name``, so building
        several networks in one directory overwrites it.

    Args:
        name: Prefix of the exported meta-graph file.
        sample_to_isotropy: If true, the third downsampling level uses factor
            3 on every axis instead of ``[1, 3, 3]``.
    """
    input_shape = (84, 268, 268)
    num_output_maps = 12

    # U-Net hyper-parameters; per the original author these reproduce
    # "Jan's network".
    num_initial_fmaps = 12
    fmap_factor = 5
    if sample_to_isotropy:
        factors = [[1, 3, 3], [1, 3, 3], [3, 3, 3]]
    else:
        factors = [[1, 3, 3], [1, 3, 3], [1, 3, 3]]

    # The `.name` of several tensors/ops is written out at the end. TF's
    # default names depend on creation order, so the order of the
    # graph-building calls below is significant.
    raw = tf.placeholder(tf.float32, shape=input_shape)
    # Two leading size-1 axes (presumably batch and channel) for the network.
    raw_batched = tf.reshape(raw, (1, 1) + input_shape)

    features = networks.unet(raw_batched, num_initial_fmaps, fmap_factor,
                             factors)
    affs_batched = networks.conv_pass(
        features,
        kernel_size=1,
        num_fmaps=num_output_maps,
        num_repetitions=1,
        activation='sigmoid',
    )

    # Drop the leading (batch) axis. The resulting static shape is shared by
    # the prediction, the ground truth and the per-voxel loss weights.
    output_shape = affs_batched.get_shape().as_list()[1:]
    affs = tf.reshape(affs_batched, output_shape)
    gt_affs = tf.placeholder(tf.float32, shape=output_shape)
    loss_weights = tf.placeholder(tf.float32, shape=output_shape)

    loss = tf.losses.mean_squared_error(gt_affs, affs, loss_weights)
    tf.summary.scalar('loss_total', loss)

    train_op = tf.train.AdamOptimizer(
        learning_rate=0.5e-4,
        beta1=0.95,
        beta2=0.999,
        epsilon=1e-8,
    ).minimize(loss)
    summary_op = tf.summary.merge_all()

    tf.train.export_meta_graph(filename='%s.meta' % name)

    io_names = {
        'raw': raw.name,
        'affs': affs.name,
        'gt_affs': gt_affs.name,
        'loss_weights': loss_weights.name,
        'loss': loss.name,
        'optimizer': train_op.name,
        'summary': summary_op.name,
    }
    with open('net_io_names.json', 'w') as out_file:
        json.dump(io_names, out_file)
示例#2
0
        raw_0 = tf.placeholder(tf.float32, shape=shape_0)
        raw_0_batched = tf.reshape(raw_0, (1, 1) + shape_0)

        input_0 = tf.concat([raw_0_batched, affs_0_batched], 1)
        if ignore:
            keep_raw = tf.ones_like(raw_0_batched)
            ignore_aff = tf.zeros_like(affs_0_batched)
            ignore_mask = tf.concat([keep_raw, ignore_aff], 1)
            input_0 = networks.ignore(input_0, ignore_mask)

        unet = networks.unet(input_0, 24, 3, [[2, 2, 2], [2, 2, 2], [2, 2, 2]])

        affs_1_batched = networks.conv_pass(unet,
                                            kernel_size=1,
                                            num_fmaps=3,
                                            num_repetitions=1,
                                            activation='sigmoid')

        affs_1 = tf.reshape(affs_1_batched, (3, ) + shape_1)
        gt_affs_1 = tf.placeholder(tf.float32, shape=(3, ) + shape_1)
        loss_weights_1 = tf.placeholder(tf.float32, shape=(3, ) + shape_1)

        loss_1 = tf.losses.mean_squared_error(gt_affs_1, affs_1,
                                              loss_weights_1)

        # phase 2
        tf.summary.scalar('loss_pred0', loss_1)
        scope.reuse_variables()
        tf.stop_gradient(affs_1_batched)
        raw_1 = center_crop(raw_0, shape_1)
示例#3
0
                       (1, 3, 3)], [(3, 3, 3),
                                    (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]],
        [[(1, 3, 3),
          (1, 3, 3)], [(1, 3, 3),
                       (1, 3, 3)], [(3, 3, 3),
                                    (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]],
        anisotropy=[10, 1, 1],
        fov=[10, 1, 1])
    # raw = tf.placeholder(tf.float32, shape=(132,)*3)
    # raw_batched = tf.reshape(raw, (1, 1,) + (132,)*3)
    #
    # unet = networks.unet(raw_batched, 24, 3, [[2, 2, 2], [2, 2, 2], [2, 2, 2]])

    dist_batched, fov = networks.conv_pass(unet,
                                           kernel_size=[[1, 1, 1]],
                                           num_fmaps=1,
                                           activation=None,
                                           fov=fov,
                                           anisotropy=anisotropy)

    output_shape_batched = dist_batched.get_shape().as_list()

    output_shape = output_shape_batched[1:]  # strip the batch dimension

    dist = tf.reshape(dist_batched, output_shape)

    gt_dist = tf.placeholder(tf.float32, shape=output_shape)

    loss_weights = tf.placeholder(tf.float32, shape=output_shape[1:])
    loss_weights_batched = tf.reshape(loss_weights, shape=output_shape)

    loss_eucl = tf.losses.mean_squared_error(gt_dist, dist,
示例#4
0
    raw = tf.placeholder(tf.float32, shape=(196, ) * 3)
    raw_batched = tf.reshape(raw, (
        1,
        1,
    ) + (196, ) * 3)

    unet = networks.unet(raw_batched, 12, 6, [[2, 2, 2], [2, 2, 2], [3, 3, 3]])

    # raw = tf.placeholder(tf.float32, shape=(132,)*3)
    # raw_batched = tf.reshape(raw, (1, 1,) + (132,)*3)
    #
    # unet = networks.unet(raw_batched, 24, 3, [[2, 2, 2], [2, 2, 2], [2, 2, 2]])

    dist_batched = networks.conv_pass(unet,
                                      kernel_size=1,
                                      num_fmaps=1,
                                      num_repetitions=1,
                                      activation='tanh')

    output_shape_batched = dist_batched.get_shape().as_list()

    output_shape = output_shape_batched[1:]  # strip the batch dimension

    dist = tf.reshape(dist_batched, output_shape)

    gt_dist = tf.placeholder(tf.float32, shape=output_shape)

    #loss_weights = tf.placeholder(tf.float32, shape=output_shape)

    loss_eucl = tf.losses.mean_squared_error(gt_dist, dist)
    tf.summary.scalar('loss_total', loss_eucl)
                       (1, 3, 3)], [(3, 3, 3),
                                    (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]],
        [[(1, 3, 3),
          (1, 3, 3)], [(1, 3, 3),
                       (1, 3, 3)], [(3, 3, 3),
                                    (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]],
        voxel_size=[10, 1, 1],
        fov=[10, 1, 1])
    # raw = tf.placeholder(tf.float32, shape=(132,)*3)
    # raw_batched = tf.reshape(raw, (1, 1,) + (132,)*3)
    #
    # unet = networks.unet(raw_batched, 24, 3, [[2, 2, 2], [2, 2, 2], [2, 2, 2]])

    dist_batched, fov1 = networks.conv_pass(unet,
                                            kernel_size=[[1, 1, 1]],
                                            num_fmaps=1,
                                            activation=None,
                                            fov=fov,
                                            voxel_size=anisotropy)
    aff_batched, fov = networks.conv_pass(unet,
                                          kernel_size=[[1, 1, 1]],
                                          num_fmaps=3,
                                          activation=None,
                                          name='aff_conv',
                                          fov=fov,
                                          voxel_size=anisotropy)
    print("distbatched", dist_batched.get_shape().as_list())
    dist_output_shape_batched = dist_batched.get_shape().as_list()
    dist_output_shape = dist_output_shape_batched[1:]
    syn_dist = tf.reshape(dist_batched, dist_output_shape)
    gt_syn_dist = tf.placeholder(tf.float32, shape=dist_output_shape)
    loss_weights = tf.placeholder(tf.float32, shape=dist_output_shape[1:])