Esempio n. 1
0
def create_network(input_shape, name):
    """Build a U-Net affinity-prediction training graph and save it to disk.

    The graph maps a single-channel raw volume to 3 sigmoid output maps
    (affinities in z, y and x) and minimizes a weighted mean squared error
    against ground-truth affinities with Adam.

    Args:
        input_shape: shape of the raw input volume. Any sequence is accepted;
            it is converted to a tuple because it is concatenated with tuples.
        name: base file name of the artifacts written to the current
            directory.

    Side effects:
        - resets the default TensorFlow graph;
        - writes ``<name>.meta`` (the meta-graph);
        - writes an event file with the graph into the current directory;
        - writes ``<name>_config.json`` with tensor/op names and shapes.
    """
    # accept lists as well as tuples: the shape is concatenated with tuples
    input_shape = tuple(input_shape)

    tf.reset_default_graph()

    # create a placeholder for the 3D raw input tensor
    raw = tf.placeholder(tf.float32, shape=input_shape)

    # create a U-Net on the input with batch and channel axes of size 1 added
    raw_batched = tf.reshape(raw, (1, 1) + input_shape)
    unet_output = unet(raw_batched, 12, 5, [[1, 3, 3], [1, 3, 3], [1, 3, 3]])

    # add a convolution layer to create 3 output maps representing affinities
    # in z, y, and x
    pred_affs_batched = conv_pass(unet_output,
                                  kernel_size=1,
                                  num_fmaps=3,
                                  num_repetitions=1,
                                  activation='sigmoid')

    # get the shape of the output
    output_shape_batched = pred_affs_batched.get_shape().as_list()
    output_shape = output_shape_batched[1:]  # strip the batch dimension

    # the 4D output tensor (3, depth, height, width)
    pred_affs = tf.reshape(pred_affs_batched, output_shape)

    # create a placeholder for the corresponding ground-truth affinities
    gt_affs = tf.placeholder(tf.float32, shape=output_shape)

    # create a placeholder for per-voxel loss weights
    loss_weights = tf.placeholder(tf.float32, shape=output_shape)

    # compute the loss as the weighted mean squared error between the
    # predicted and the ground-truth affinities
    loss = tf.losses.mean_squared_error(gt_affs, pred_affs, loss_weights)

    # use the Adam optimizer to minimize the loss
    opt = tf.train.AdamOptimizer(learning_rate=0.5e-4,
                                 beta1=0.95,
                                 beta2=0.999,
                                 epsilon=1e-8)
    optimizer = opt.minimize(loss)

    # store the network in a meta-graph file
    tf.train.export_meta_graph(filename=name + '.meta')

    # write the graph to an event file; close the writer so the event file
    # is flushed to disk and its file handle is released
    writer = tf.summary.FileWriter('.', graph=tf.get_default_graph())
    writer.close()

    # store network configuration for use in train and predict scripts
    config = {
        'raw': raw.name,
        'pred_affs': pred_affs.name,
        'gt_affs': gt_affs.name,
        'loss_weights': loss_weights.name,
        'loss': loss.name,
        'optimizer': optimizer.name,
        'input_shape': input_shape,
        # spatial output shape, i.e. without the 3 affinity channels
        'output_shape': output_shape[1:]
    }
    with open(name + '_config.json', 'w') as f:
        json.dump(config, f)
Esempio n. 2
0
def create_network(input_shape, name, output_folder):
    """Build a single-output U-Net training graph and save its artifacts.

    A 3-channel raw volume is passed through a U-Net and a 1x1 sigmoid
    convolution producing one output map. The graph is trained with Adam on
    a weighted mean squared error against a ground-truth placeholder.

    Args:
        input_shape: spatial shape ``(d, h, w)`` of the raw input. It must be
            a tuple, because it is concatenated with tuples below.
        name: base file name shared by all written artifacts.
        output_folder: directory that receives the artifacts.

    Side effects:
        Resets the default graph and prints the input and output shapes.
        Writes these files into ``output_folder``:
        ``<name>.meta`` (the meta-graph), ``<name>_names.json`` (tensor/op
        names) and ``<name>_config.json`` (shapes and output dimensionality).
    """

    def artifact_path(suffix):
        # every artifact lives in the same folder under the same base name
        return os.path.join(output_folder, name + suffix)

    tf.reset_default_graph()

    # NOTE: ops are created in a fixed order; TF derives default tensor
    # names (Placeholder, Placeholder_1, ...) from that order, and those
    # names are written to the names file below.

    # raw input (c=3, d, h, w), fed to the U-Net as (b=1, c=3, d, h, w)
    raw = tf.placeholder(tf.float32, shape=(3,) + input_shape)
    network_input = tf.reshape(raw, (1, 3) + input_shape)

    features = unet(network_input, 12, 5, [[2, 2, 2], [2, 2, 2], [2, 2, 2]])
    prediction_batched = conv_pass(
        features,
        kernel_size=1,
        num_fmaps=1,
        num_repetitions=1,
        activation='sigmoid',
    )

    # keep only the spatial axes (d, h, w): drop batch and channel
    spatial_shape = prediction_batched.get_shape().as_list()[2:]
    prediction = tf.reshape(prediction_batched, spatial_shape)

    target = tf.placeholder(tf.float32, shape=spatial_shape)
    voxel_weights = tf.placeholder(tf.float32, shape=spatial_shape)
    loss = tf.losses.mean_squared_error(target, prediction, voxel_weights)

    adam = tf.train.AdamOptimizer(
        learning_rate=0.5e-4, beta1=0.95, beta2=0.999, epsilon=1e-8)
    train_op = adam.minimize(loss)

    print("input shape: {}".format(input_shape))
    print("output shape: {}".format(spatial_shape))

    tf.train.export_meta_graph(filename=artifact_path('.meta'))

    # key order is kept identical so the JSON output is unchanged
    tensor_names = dict(
        raw=raw.name,
        pred=prediction.name,
        gt=target.name,
        loss_weights=voxel_weights.name,
        loss=loss.name,
        optimizer=train_op.name,
    )
    with open(artifact_path('_names.json'), 'w') as names_file:
        json.dump(tensor_names, names_file)

    shape_config = dict(
        input_shape=input_shape,
        output_shape=spatial_shape,
        out_dims=1,
    )
    with open(artifact_path('_config.json'), 'w') as config_file:
        json.dump(shape_config, config_file)
Esempio n. 3
0
def create_network(input_shape, name, extra_meta_filename='train_net.meta'):
    """Build a U-Net foreground-prediction training graph and save it to disk.

    A 2-channel raw volume is passed through a U-Net and a 1x1 sigmoid
    convolution producing a single foreground map. The graph is trained with
    Adam on a weighted mean squared error against ``labels_fg``.

    Args:
        input_shape: spatial shape ``(d, h, w)`` of the raw input. Any
            sequence is accepted; it is converted to a tuple.
        name: base file name of the artifacts written to the current
            directory.
        extra_meta_filename: additional path that the meta-graph is also
            exported to. It defaults to ``'train_net.meta'``, the path that
            was previously hard-coded, so existing callers see the same
            files. Pass ``None`` to export only ``<name>.meta``. The extra
            export is skipped when it equals ``<name>.meta``, because that
            file has already been written.

    Side effects:
        - resets the default graph and prints the input/output shapes;
        - writes ``<name>.meta`` and, optionally, ``extra_meta_filename``;
        - writes ``<name>_names.json`` and ``<name>_config.json``.
    """
    # accept lists as well as tuples: the shape is concatenated with tuples
    input_shape = tuple(input_shape)

    tf.reset_default_graph()

    # c=2, d, h, w
    raw = tf.placeholder(tf.float32, shape=(2, ) + input_shape)

    # b=1, c=2, d, h, w
    raw_batched = tf.reshape(raw, (1, 2) + input_shape)

    fg_unet = unet(raw_batched, 12, 5, [[1, 2, 2], [1, 2, 2], [2, 2, 2]])

    fg_batched = conv_pass(fg_unet,
                           kernel_size=1,
                           num_fmaps=1,
                           num_repetitions=1,
                           activation='sigmoid')

    output_shape_batched = fg_batched.get_shape().as_list()

    # d, h, w, strip the batch and channel dimension
    output_shape = tuple(output_shape_batched[2:])

    fg = tf.reshape(fg_batched, output_shape)

    labels_fg = tf.placeholder(tf.float32, shape=output_shape)
    loss_weights = tf.placeholder(tf.float32, shape=output_shape)

    loss = tf.losses.mean_squared_error(labels_fg, fg, loss_weights)

    opt = tf.train.AdamOptimizer(learning_rate=0.5e-4,
                                 beta1=0.95,
                                 beta2=0.999,
                                 epsilon=1e-8)
    optimizer = opt.minimize(loss)

    print("input shape: %s" % (input_shape, ))
    print("output shape: %s" % (output_shape, ))

    meta_filename = name + '.meta'
    tf.train.export_meta_graph(filename=meta_filename)

    # The second export used to be unconditional and hard-coded; it is now
    # configurable and skipped when it would just rewrite the same file.
    if extra_meta_filename is not None and extra_meta_filename != meta_filename:
        tf.train.export_meta_graph(filename=extra_meta_filename)

    names = {
        'raw': raw.name,
        'fg': fg.name,
        'loss_weights': loss_weights.name,
        'loss': loss.name,
        'optimizer': optimizer.name,
        'labels_fg': labels_fg.name
    }

    with open(name + '_names.json', 'w') as f:
        json.dump(names, f)

    config = {
        'input_shape': input_shape,
        'output_shape': output_shape,
        'out_dims': 1
    }
    with open(name + '_config.json', 'w') as f:
        json.dump(config, f)
Esempio n. 4
0
if __name__ == "__main__":
    # Build a 2D graph with two U-Nets that read the same raw input: an
    # "embedding" branch (3 output maps, activation=None) and an "fg" branch
    # (1 sigmoid output map).
    # NOTE(review): this example is truncated in the source right after
    # `output_shape = tuple(`; the rest of the script is not visible here.

    input_shape = (200, 200)

    # single 200x200 raw image, reshaped to (b=1, c=1, 200, 200)
    raw = tf.placeholder(tf.float32, shape=input_shape)
    raw_batched = tf.reshape(raw, (1, 1) + input_shape)

    # separate variable scopes give each U-Net's variables a distinct prefix
    with tf.variable_scope("embedding"):
        embedding_unet = unet(raw_batched, 6, 5, [[2, 2], [2, 2]])
    with tf.variable_scope("fg"):
        fg_unet = unet(raw_batched, 3, 2, [[2, 2], [2, 2]])

    # 1x1 convolution to 3 embedding channels, no activation
    embedding_batched = conv_pass(
        embedding_unet,
        kernel_size=1,
        num_fmaps=3,
        num_repetitions=1,
        activation=None,
        name="embedding",
    )

    # 1x1 convolution to a single sigmoid foreground channel
    fg_batched = conv_pass(
        fg_unet,
        kernel_size=1,
        num_fmaps=1,
        num_repetitions=1,
        activation="sigmoid",
        name="fg",
    )

    # the output shape is read from the embedding branch
    output_shape_batched = embedding_batched.get_shape().as_list()
    output_shape = tuple(
Esempio n. 5
0
if __name__ == "__main__":
    # Build a 2D graph with two U-Nets that read the same raw input: an
    # "embedding" branch (3 output maps, activation=None) and an "fg" branch
    # (1 sigmoid output map).
    # NOTE(review): the script appears to continue past the visible source;
    # only the graph construction up to the embedding reshape is shown here.

    input_shape = (200, 200)

    # single 200x200 raw image, reshaped to (b=1, c=1, 200, 200)
    raw = tf.placeholder(tf.float32, shape=input_shape)
    raw_batched = tf.reshape(raw, (1, 1) + input_shape)

    # separate variable scopes give each U-Net's variables a distinct prefix
    with tf.variable_scope('embedding'):
        embedding_unet = unet(raw_batched, 6, 5, [[2, 2], [2, 2]])
    with tf.variable_scope('fg'):
        fg_unet = unet(raw_batched, 3, 2, [[2, 2], [2, 2]])

    # 1x1 convolution to 3 embedding channels, no activation
    embedding_batched = conv_pass(embedding_unet,
                                  kernel_size=1,
                                  num_fmaps=3,
                                  num_repetitions=1,
                                  activation=None,
                                  name='embedding')

    # 1x1 convolution to a single sigmoid foreground channel
    fg_batched = conv_pass(fg_unet,
                           kernel_size=1,
                           num_fmaps=1,
                           num_repetitions=1,
                           activation='sigmoid',
                           name='fg')

    # the spatial output shape is read from the embedding branch
    output_shape_batched = embedding_batched.get_shape().as_list()
    output_shape = tuple(
        output_shape_batched[2:])  # strip the batch and channel dimension

    # drop the batch axis: (3,) + spatial output shape
    embedding = tf.reshape(embedding_batched, (3, ) + output_shape)