def inference_net():
    """Build and export the anisotropic two-class U-Net inference graph.

    A float32 volume of shape (91, 862, 862) is reshaped to a single-item,
    single-channel batch and passed through ``unet.unet`` (voxel size
    (10, 1, 1)), then through a 1x1x1 ``ops3d.conv_pass`` producing two
    logit maps without activation. A softmax over the channel axis gives
    per-voxel class probabilities and an argmax over classes is also added
    to the graph. The probabilities tensor name is printed and the graph is
    written to ``unet_inference.meta`` in the working directory.
    """

    def conv_kernels():
        # Fresh nested lists on every call: two convolutions per level, the
        # first two levels use (1, 3, 3) kernels, the last two (3, 3, 3).
        planar, cubic = (1, 3, 3), (3, 3, 3)
        return [[planar, planar], [planar, planar], [cubic, cubic], [cubic, cubic]]

    in_shape = (91, 862, 862)
    raw = tf.placeholder(tf.float32, shape=in_shape)
    raw_batched = tf.reshape(raw, (1, 1) + in_shape)

    features, receptive_field, voxel_size = unet.unet(
        raw_batched,
        12,
        6,
        [[1, 3, 3], [1, 3, 3], [3, 3, 3]],
        conv_kernels(),
        conv_kernels(),
        voxel_size=(10, 1, 1),
        fov=(10, 1, 1),
    )
    logits, receptive_field = ops3d.conv_pass(
        features,
        kernel_size=[[1, 1, 1]],
        num_fmaps=2,
        activation=None,
        fov=receptive_field,
        voxel_size=voxel_size,
    )

    # Drop only the batch axis; what remains is (channels, z, y, x).
    channel_first_shape = logits.get_shape().as_list()[1:]

    class_probs = tf.nn.softmax(logits, dim=1)[0]
    probabilities = tf.reshape(class_probs, channel_first_shape)
    # The argmax op is created so it is part of the exported graph; the
    # Python handle itself is not needed here.
    tf.argmax(probabilities, axis=0)

    print(probabilities.name)
    tf.train.export_meta_graph(filename="unet_inference.meta")
def inference_net():
    """Build and export the isotropic single-channel U-Net inference graph.

    A float32 volume of shape (400, 400, 400) is reshaped to a single-item,
    single-channel batch and passed through ``unet.unet`` with unit voxel
    size, then through a 1x1x1 ``ops3d.conv_pass`` with one output map and
    no activation. The prediction is reshaped to a plain spatial volume and
    the graph is written to ``unet_inference.meta`` in the working directory.
    """
    cube = (3, 3, 3)

    def conv_kernels():
        # Fresh nested lists on every call: two 3x3x3 convolutions per level.
        return [[cube, cube] for _ in range(4)]

    in_shape = (400,) * 3
    raw = tf.placeholder(tf.float32, shape=in_shape)
    raw_batched = tf.reshape(raw, (1, 1) + in_shape)

    features, receptive_field, voxel_size = unet.unet(
        raw_batched,
        12,
        6,
        [[2, 2, 2], [2, 2, 2], [3, 3, 3]],
        conv_kernels(),
        conv_kernels(),
        voxel_size=(1, 1, 1),
        fov=(1, 1, 1),
    )
    prediction_bc, receptive_field = ops3d.conv_pass(
        features,
        kernel_size=[[1, 1, 1]],
        num_fmaps=1,
        activation=None,
        fov=receptive_field,
        voxel_size=voxel_size,
    )

    # Strip the leading batch and channel axes -> (z, y, x).
    spatial_shape = prediction_bc.get_shape().as_list()[2:]
    # Created for its presence in the exported graph; the handle is unused.
    tf.reshape(prediction_bc, spatial_shape)

    tf.train.export_meta_graph(filename="unet_inference.meta")
def train_net():
    """Build and export the two-phase autocontext ("W-net") affinity graph.

    Phase 1 concatenates a float32 raw volume of shape (220, 220, 220) with a
    constant 0.5 three-channel affinity prior and runs it through a U-Net plus
    a 1x1x1 sigmoid ``conv_pass``; the result is reshaped to (3, 132, 132, 132).

    Phase 2 calls the same U-Net and ``conv_pass`` again after
    ``scope.reuse_variables()``, so both phases share weights. Its input is the
    raw volume center-cropped to 132^3 concatenated with the phase-1
    affinities; its output is reshaped to (3, 44, 44, 44). No stop_gradient is
    applied, so the phase-2 loss also backpropagates through phase 1.

    Both phases use weighted mean squared error against the ground-truth
    affinities (center-cropped for phase 2). The Adam optimizer minimizes
    their sum.

    Side effects: writes ``wnet.meta`` and ``net_io_names.json`` to the
    current working directory.
    """
    # Input / output extents per phase (per axis), as fixed by the reshapes
    # below: 220 -> 132 (phase 1) -> 44 (phase 2).
    shape_0 = (220,) * 3
    shape_1 = (132,) * 3
    shape_2 = (44,) * 3

    def unet_architecture():
        # Both phases must build an identical U-Net, otherwise the reused
        # variables would not match. Return fresh lists on every call so the
        # two invocations never share mutable arguments.
        cube = (3, 3, 3)
        downsample_factors = [[2, 2, 2], [2, 2, 2], [2, 2, 2]]
        kernels_down = [[cube, cube] for _ in range(4)]
        kernels_up = [[cube, cube] for _ in range(4)]
        return downsample_factors, kernels_down, kernels_up

    # Constant "no information" affinity prior fed into phase 1.
    affs_0_bc = tf.ones((1, 3) + shape_0) * 0.5

    with tf.variable_scope("autocontext") as scope:
        # phase 1: raw + constant prior -> affs_1
        raw_0 = tf.placeholder(tf.float32, shape=shape_0)
        raw_0_bc = tf.reshape(raw_0, (1, 1) + shape_0)
        input_0 = tf.concat([raw_0_bc, affs_0_bc], 1)
        out_bc, fov, anisotropy = unet.unet(input_0, 24, 3, *unet_architecture())
        affs_1_bc, fov = ops3d.conv_pass(
            out_bc,
            kernel_size=[[1, 1, 1]],
            num_fmaps=3,
            activation="sigmoid",
            fov=fov,
            voxel_size=anisotropy,
        )
        affs_1_c = tf.reshape(affs_1_bc, (3,) + shape_1)
        gt_affs_1_c = tf.placeholder(tf.float32, shape=(3,) + shape_1)
        loss_weights_1_c = tf.placeholder(tf.float32, shape=(3,) + shape_1)
        loss_1 = tf.losses.mean_squared_error(gt_affs_1_c, affs_1_c, loss_weights_1_c)
        tf.summary.scalar("loss_pred0", loss_1)

        # phase 2: cropped raw + affs_1 -> affs_2, sharing phase-1 weights
        scope.reuse_variables()
        raw_1 = ops3d.center_crop(raw_0, shape_1)
        raw_1_bc = tf.reshape(raw_1, (1, 1) + shape_1)
        input_1 = tf.concat([raw_1_bc, affs_1_bc], 1)
        out_bc, fov, anisotropy = unet.unet(
            input_1,
            24,
            3,
            *unet_architecture(),
            fov=fov,
            voxel_size=anisotropy,
        )
        affs_2_bc, fov = ops3d.conv_pass(
            out_bc,
            kernel_size=[[1, 1, 1]],
            num_fmaps=3,
            activation="sigmoid",
            fov=fov,
            voxel_size=anisotropy,
        )
        affs_2_c = tf.reshape(affs_2_bc, (3,) + shape_2)
        gt_affs_2_c = ops3d.center_crop(gt_affs_1_c, (3,) + shape_2)
        loss_weights_2_c = ops3d.center_crop(loss_weights_1_c, (3,) + shape_2)
        loss_2 = tf.losses.mean_squared_error(gt_affs_2_c, affs_2_c, loss_weights_2_c)
        tf.summary.scalar("loss_pred1", loss_2)

    # Everything below runs outside the reused scope. In particular, the
    # optimizer creates its own slot variables here and must not inherit
    # reuse=True.
    loss = loss_1 + loss_2
    tf.summary.scalar("loss_total", loss)
    tf.summary.scalar("loss_diff", loss_1 - loss_2)
    for trainable in tf.trainable_variables():
        custom_ops.tf_var_summary(trainable)
    merged = tf.summary.merge_all()

    opt = tf.train.AdamOptimizer(
        learning_rate=0.5e-4, beta1=0.95, beta2=0.999, epsilon=1e-8
    )
    optimizer = opt.minimize(loss)

    tf.train.export_meta_graph(filename="wnet.meta")
    names = {
        "raw": raw_0.name,
        "affs_1": affs_1_c.name,
        "affs_2": affs_2_c.name,
        "gt_affs": gt_affs_1_c.name,
        "loss_weights": loss_weights_1_c.name,
        "loss": loss.name,
        "optimizer": optimizer.name,
        "summary": merged.name,
    }
    with open("net_io_names.json", "w") as f:
        json.dump(names, f)
def train_net():
    """Build and export a U-Net graph trained to reproduce its own input.

    A float32 volume of shape (196, 196, 196) is reshaped to a single-item,
    single-channel batch and passed through ``unet.unet`` with unit voxel
    size, then through a 1x1x1 ``ops3d.conv_pass`` with one output map and
    no activation. The target is the input batch cropped by
    ``ops3d.crop_zyx`` to the prediction's shape. Unweighted mean squared
    error between the two is minimized with Adam.

    Side effects: writes ``unet.meta`` and ``net_io_names.json`` to the
    current working directory.
    """
    cube = (3, 3, 3)

    def conv_kernels():
        # Fresh nested lists on every call: two 3x3x3 convolutions per level.
        return [[cube, cube] for _ in range(4)]

    in_shape = (196,) * 3
    raw = tf.placeholder(tf.float32, shape=in_shape)
    raw_batched = tf.reshape(raw, (1, 1) + in_shape)

    features, receptive_field, voxel_size = unet.unet(
        raw_batched,
        12,
        6,
        [[2, 2, 2], [2, 2, 2], [3, 3, 3]],
        conv_kernels(),
        conv_kernels(),
        voxel_size=(1, 1, 1),
        fov=(1, 1, 1),
    )
    prediction_bc, receptive_field = ops3d.conv_pass(
        features,
        kernel_size=[[1, 1, 1]],
        num_fmaps=1,
        activation=None,
        fov=receptive_field,
        voxel_size=voxel_size,
    )

    # Full (batch, channel, z, y, x) shape, and the spatial part alone.
    full_out_shape = prediction_bc.get_shape().as_list()
    spatial_shape = full_out_shape[2:]

    pred_raw = tf.reshape(prediction_bc, spatial_shape)
    target_bc = ops3d.crop_zyx(raw_batched, full_out_shape)
    target = tf.reshape(target_bc, spatial_shape)

    loss = tf.losses.mean_squared_error(target, pred_raw)
    tf.summary.scalar("loss", loss)

    adam = tf.train.AdamOptimizer(
        learning_rate=0.5e-4, beta1=0.95, beta2=0.999, epsilon=1e-8
    )
    train_op = adam.minimize(loss)
    summary_op = tf.summary.merge_all()

    tf.train.export_meta_graph(filename="unet.meta")
    io_names = {
        "raw": raw.name,
        "pred_raw": pred_raw.name,
        "optimizer": train_op.name,
        "summary": summary_op.name,
        "loss": loss.name,
    }
    with open("net_io_names.json", "w") as handle:
        json.dump(io_names, handle)
def train_net():
    """Build and export the anisotropic two-class U-Net training graph.

    * ``raw``: float32 placeholder of shape (43, 430, 430). It is reshaped to
      a single-item, single-channel batch and passed through ``unet.unet``
      (voxel size (10, 1, 1)). A 1x1x1 ``conv_pass`` then produces two logit
      maps without activation.
    * ``gt_labels``, ``loss_weights``, ``mask``: float32 placeholders with the
      spatial output shape. Voxels whose label equals 1 form logit class 0;
      all other voxels form class 1.
    * Two softmax cross-entropy losses share these targets. The one weighted
      by ``loss_weights`` is minimized with Adam; the one weighted by
      ``mask`` is only logged and exported.

    Side effects: writes ``unet.meta`` and ``net_io_names.json`` to the
    current working directory.
    """

    def conv_kernels():
        # Fresh nested lists on every call: two convolutions per level, the
        # first two levels use (1, 3, 3) kernels, the last two (3, 3, 3).
        planar, cubic = (1, 3, 3), (3, 3, 3)
        return [[planar, planar], [planar, planar], [cubic, cubic], [cubic, cubic]]

    input_shape = (43, 430, 430)
    raw = tf.placeholder(tf.float32, shape=input_shape)
    raw_bc = tf.reshape(raw, (1, 1) + input_shape)
    last_fmap, fov, anisotropy = unet.unet(
        raw_bc,
        12,
        6,
        [[1, 3, 3], [1, 3, 3], [3, 3, 3]],
        conv_kernels(),
        conv_kernels(),
        voxel_size=(10, 1, 1),
        fov=(10, 1, 1),
    )
    logits_bc, fov = ops3d.conv_pass(
        last_fmap,
        kernel_size=[[1, 1, 1]],
        num_fmaps=2,
        activation=None,
        fov=fov,
        voxel_size=anisotropy,
    )

    output_shape_bc = logits_bc.get_shape().as_list()
    output_shape_c = output_shape_bc[1:]  # strip the batch dimension
    output_shape = output_shape_c[1:]  # strip the channel dimension

    # (1, 2, z, y, x) -> (2, N) -> (N, 2): one row of class logits per voxel.
    # This flattening is only correct because the batch size is 1.
    flat_logits = tf.transpose(tf.reshape(tensor=logits_bc, shape=(2, -1)))

    gt_labels = tf.placeholder(tf.float32, shape=output_shape)
    gt_labels_flat = tf.reshape(gt_labels, (-1,))
    # Both target columns come from the same predicate, so every row is a
    # valid one-hot vector even if labels other than 0/1 are fed.
    is_fg = tf.equal(gt_labels_flat, 1)
    gt_fg = tf.to_float(is_fg)
    gt_bg = tf.to_float(tf.logical_not(is_fg))
    flat_ohe = tf.stack(values=[gt_fg, gt_bg], axis=1)

    loss_weights = tf.placeholder(tf.float32, shape=output_shape)
    loss_weights_flat = tf.reshape(loss_weights, (-1,))
    mask = tf.placeholder(tf.float32, shape=output_shape)
    mask_flat = tf.reshape(mask, (-1,))

    probabilities = tf.reshape(tf.nn.softmax(logits_bc, dim=1)[0], output_shape_c)
    # NOTE(review): class index 0 is the label==1 class, so predictions == 0
    # marks label-1 voxels (inverted w.r.t. gt_labels). Confirm that
    # consumers of "predictions" expect this.
    predictions = tf.argmax(probabilities, axis=0)

    ce_loss_balanced = tf.losses.softmax_cross_entropy(
        flat_ohe, flat_logits, weights=loss_weights_flat
    )
    ce_loss_unbalanced = tf.losses.softmax_cross_entropy(
        flat_ohe, flat_logits, weights=mask_flat
    )
    tf.summary.scalar("loss_balanced_syn", ce_loss_balanced)
    tf.summary.scalar("loss_unbalanced_syn", ce_loss_unbalanced)

    opt = tf.train.AdamOptimizer(
        learning_rate=0.5e-4, beta1=0.95, beta2=0.999, epsilon=1e-8
    )
    optimizer = opt.minimize(ce_loss_balanced)
    merged = tf.summary.merge_all()

    tf.train.export_meta_graph(filename="unet.meta")
    names = {
        "raw": raw.name,
        "probabilities": probabilities.name,
        "predictions": predictions.name,
        "gt_labels": gt_labels.name,
        "loss_balanced_syn": ce_loss_balanced.name,
        "loss_unbalanced_syn": ce_loss_unbalanced.name,
        "loss_weights": loss_weights.name,
        "mask": mask.name,
        "optimizer": optimizer.name,
        "summary": merged.name,
    }
    with open("net_io_names.json", "w") as f:
        json.dump(names, f)