def create_network(input_shape, name):
    """Build the affinity-prediction training graph and export it to disk.

    A U-Net maps a single-channel 3D raw volume to 3 sigmoid output maps
    (affinities in z, y and x). The maps are trained with a weighted mean
    squared error against ground-truth affinities, using Adam.

    Files written:
      * ``<name>.meta``: the exported meta-graph.
      * An event file in the current directory containing the graph.
      * ``<name>_config.json``: tensor/op names plus input and output shapes.

    Args:
        input_shape: Spatial shape of the raw input. It must be a tuple,
            because it is concatenated with ``(1, 1)`` to add the batch and
            channel axes.
        name: Path prefix for the ``.meta`` and ``_config.json`` files.
    """
    tf.reset_default_graph()

    # placeholder for the 3D raw input tensor
    raw = tf.placeholder(tf.float32, shape=input_shape)

    # add batch and channel axes (b=1, c=1) and build the U-Net
    raw_batched = tf.reshape(raw, (1, 1) + input_shape)
    unet_output = unet(raw_batched, 12, 5, [[1, 3, 3], [1, 3, 3], [1, 3, 3]])

    # 1x1 convolution creating 3 output maps: affinities in z, y and x
    pred_affs_batched = conv_pass(
        unet_output,
        kernel_size=1,
        num_fmaps=3,
        num_repetitions=1,
        activation='sigmoid')

    # strip the batch dimension; the result is the 4D output
    # (3, depth, height, width)
    output_shape_batched = pred_affs_batched.get_shape().as_list()
    output_shape = output_shape_batched[1:]
    pred_affs = tf.reshape(pred_affs_batched, output_shape)

    # ground-truth affinities and per-voxel loss weights share the
    # prediction's shape
    gt_affs = tf.placeholder(tf.float32, shape=output_shape)
    loss_weights = tf.placeholder(tf.float32, shape=output_shape)

    # weighted mean squared error between predicted and ground-truth affinities
    loss = tf.losses.mean_squared_error(gt_affs, pred_affs, loss_weights)

    opt = tf.train.AdamOptimizer(
        learning_rate=0.5e-4,
        beta1=0.95,
        beta2=0.999,
        epsilon=1e-8)
    optimizer = opt.minimize(loss)

    # store the network in a meta-graph file
    tf.train.export_meta_graph(filename=name + '.meta')

    # Write the graph to an event file. FileWriter writes asynchronously,
    # so close it explicitly. Otherwise the event may never be flushed to
    # disk and the file handle stays open.
    writer = tf.summary.FileWriter('.', graph=tf.get_default_graph())
    writer.close()

    # network configuration for use in the train and predict scripts
    config = {
        'raw': raw.name,
        'pred_affs': pred_affs.name,
        'gt_affs': gt_affs.name,
        'loss_weights': loss_weights.name,
        'loss': loss.name,
        'optimizer': optimizer.name,
        'input_shape': input_shape,
        # spatial shape only: drop the leading channel (3) axis
        'output_shape': output_shape[1:],
    }
    with open(name + '_config.json', 'w') as f:
        json.dump(config, f)
def create_network(input_shape, name, output_folder):
    """Build a 3-channel-input, single-output U-Net training graph.

    The raw input placeholder has shape ``(3,) + input_shape``. The
    prediction is one sigmoid map with the batch and channel axes removed.
    It is trained with a weighted mean squared error, using Adam.

    Files written into ``output_folder``:
      * ``<name>.meta``: the exported meta-graph.
      * ``<name>_names.json``: names of the tensors/ops used by other scripts.
      * ``<name>_config.json``: input/output shapes and ``out_dims``.

    Args:
        input_shape: Spatial input shape. It must be a tuple, because it is
            concatenated with other tuples.
        name: Base file name shared by all written files.
        output_folder: Directory the files are written into.
    """
    def _artifact(suffix):
        # every output file lives in the same folder under the same base name
        return os.path.join(output_folder, name + suffix)

    def _dump_json(payload, suffix):
        with open(_artifact(suffix), 'w') as handle:
            json.dump(payload, handle)

    tf.reset_default_graph()

    # NOTE: the order of TF op creation below determines the auto-generated
    # tensor names recorded in *_names.json, so it must not be rearranged.
    raw_shape = (3, ) + input_shape  # c=3, d, h, w
    raw = tf.placeholder(tf.float32, shape=raw_shape)
    batched_raw = tf.reshape(raw, (1, ) + raw_shape)  # b=1, c=3, d, h, w

    features = unet(batched_raw, 12, 5, [[2, 2, 2] for _ in range(3)])
    prediction_batched = conv_pass(
        features,
        kernel_size=1,
        num_fmaps=1,
        num_repetitions=1,
        activation='sigmoid',
    )

    # keep only the spatial axes (d, h, w): drop the leading batch and
    # channel axes
    spatial_shape = prediction_batched.get_shape().as_list()[2:]
    prediction = tf.reshape(prediction_batched, spatial_shape)

    gt = tf.placeholder(tf.float32, shape=spatial_shape)
    loss_weights = tf.placeholder(tf.float32, shape=spatial_shape)
    loss = tf.losses.mean_squared_error(gt, prediction, loss_weights)

    adam = tf.train.AdamOptimizer(
        learning_rate=0.5e-4, beta1=0.95, beta2=0.999, epsilon=1e-8)
    train_op = adam.minimize(loss)

    print("input shape: %s" % (input_shape, ))
    print("output shape: %s" % (spatial_shape, ))

    tf.train.export_meta_graph(filename=_artifact('.meta'))

    _dump_json(
        dict(
            raw=raw.name,
            pred=prediction.name,
            gt=gt.name,
            loss_weights=loss_weights.name,
            loss=loss.name,
            optimizer=train_op.name,
        ),
        '_names.json',
    )
    _dump_json(
        dict(input_shape=input_shape, output_shape=spatial_shape, out_dims=1),
        '_config.json',
    )
def create_network(input_shape, name):
    """Build a 2-channel-input foreground-prediction training graph.

    A U-Net followed by a 1x1 sigmoid convolution predicts a single
    foreground map. It is trained with a weighted mean squared error
    against ``labels_fg``, using Adam.

    Files written (relative to the working directory):
      * ``<name>.meta`` and ``train_net.meta``: the same exported meta-graph.
      * ``<name>_names.json``: names of the tensors/ops used by other scripts.
      * ``<name>_config.json``: input/output shapes and ``out_dims``.

    Args:
        input_shape: Spatial input shape. It must be a tuple, because it is
            concatenated with other tuples.
        name: Base file name / path prefix for the written files.
    """
    tf.reset_default_graph()

    # NOTE: TF op creation order fixes the auto-generated tensor names that
    # end up in *_names.json; keep the sequence below unchanged.
    raw_shape = (2, ) + input_shape  # c=2, d, h, w
    raw = tf.placeholder(tf.float32, shape=raw_shape)
    raw_batched = tf.reshape(raw, (1, ) + raw_shape)  # b=1, c=2, d, h, w

    fg_features = unet(raw_batched, 12, 5, [[1, 2, 2], [1, 2, 2], [2, 2, 2]])
    fg_batched = conv_pass(
        fg_features,
        kernel_size=1,
        num_fmaps=1,
        num_repetitions=1,
        activation='sigmoid',
    )

    # d, h, w: strip the batch and channel dimensions
    out_shape = tuple(fg_batched.get_shape().as_list()[2:])
    fg = tf.reshape(fg_batched, out_shape)

    labels_fg = tf.placeholder(tf.float32, shape=out_shape)
    loss_weights = tf.placeholder(tf.float32, shape=out_shape)
    loss = tf.losses.mean_squared_error(labels_fg, fg, loss_weights)

    optimizer = tf.train.AdamOptimizer(
        learning_rate=0.5e-4, beta1=0.95, beta2=0.999, epsilon=1e-8,
    ).minimize(loss)

    print("input shape: %s" % (input_shape, ))
    print("output shape: %s" % (out_shape, ))

    # NOTE(review): the same graph is exported twice. The second export goes
    # to a hard-coded 'train_net.meta' that ignores `name`. It is kept as-is
    # in case another script loads that fixed path; confirm before removing.
    for meta_file in (name + '.meta', 'train_net.meta'):
        tf.train.export_meta_graph(filename=meta_file)

    names = {
        'raw': raw.name,
        'fg': fg.name,
        'loss_weights': loss_weights.name,
        'loss': loss.name,
        'optimizer': optimizer.name,
        'labels_fg': labels_fg.name,
    }
    config = {
        'input_shape': input_shape,
        'output_shape': out_shape,
        'out_dims': 1,
    }
    for suffix, payload in (('_names.json', names), ('_config.json', config)):
        with open(name + suffix, 'w') as handle:
            json.dump(payload, handle)
if __name__ == "__main__": input_shape = (200, 200) raw = tf.placeholder(tf.float32, shape=input_shape) raw_batched = tf.reshape(raw, (1, 1) + input_shape) with tf.variable_scope("embedding"): embedding_unet = unet(raw_batched, 6, 5, [[2, 2], [2, 2]]) with tf.variable_scope("fg"): fg_unet = unet(raw_batched, 3, 2, [[2, 2], [2, 2]]) embedding_batched = conv_pass( embedding_unet, kernel_size=1, num_fmaps=3, num_repetitions=1, activation=None, name="embedding", ) fg_batched = conv_pass( fg_unet, kernel_size=1, num_fmaps=1, num_repetitions=1, activation="sigmoid", name="fg", ) output_shape_batched = embedding_batched.get_shape().as_list() output_shape = tuple(
if __name__ == "__main__": input_shape = (200, 200) raw = tf.placeholder(tf.float32, shape=input_shape) raw_batched = tf.reshape(raw, (1, 1) + input_shape) with tf.variable_scope('embedding'): embedding_unet = unet(raw_batched, 6, 5, [[2, 2], [2, 2]]) with tf.variable_scope('fg'): fg_unet = unet(raw_batched, 3, 2, [[2, 2], [2, 2]]) embedding_batched = conv_pass(embedding_unet, kernel_size=1, num_fmaps=3, num_repetitions=1, activation=None, name='embedding') fg_batched = conv_pass(fg_unet, kernel_size=1, num_fmaps=1, num_repetitions=1, activation='sigmoid', name='fg') output_shape_batched = embedding_batched.get_shape().as_list() output_shape = tuple( output_shape_batched[2:]) # strip the batch and channel dimension embedding = tf.reshape(embedding_batched, (3, ) + output_shape)