def inference_net():
    """Build the two-class U-Net inference graph and export it.

    A float32 placeholder of shape ``(91, 862, 862)`` is reshaped to carry
    leading batch and channel axes of size one, fed through ``unet.unet`` and
    then through a 1x1x1 ``ops3d.conv_pass`` with two feature maps and no
    activation. Softmax is applied over axis 1, the single batch entry is
    selected, and an argmax over the remaining leading (channel) axis is
    added to the graph. The name of the probabilities tensor is printed and
    the graph is written to ``unet_inference.meta``.
    """
    volume_shape = (91, 862, 862)
    flat_kernel = (1, 3, 3)
    cube_kernel = (3, 3, 3)
    # One entry per U-Net level; every level uses two identical convolutions.
    level_kernels = (flat_kernel, flat_kernel, cube_kernel, cube_kernel)

    raw = tf.placeholder(tf.float32, shape=volume_shape)
    raw_bc = tf.reshape(raw, (1, 1) + volume_shape)

    last_fmap, fov, anisotropy = unet.unet(
        raw_bc,
        12,
        6,
        [[1, 3, 3], [1, 3, 3], [3, 3, 3]],
        [[k, k] for k in level_kernels],
        [[k, k] for k in level_kernels],
        voxel_size=(10, 1, 1),
        fov=(10, 1, 1),
    )

    logits_bc, _ = ops3d.conv_pass(
        last_fmap,
        kernel_size=[[1, 1, 1]],
        num_fmaps=2,
        activation=None,
        fov=fov,
        voxel_size=anisotropy,
    )

    # Static output shape without the batch axis.
    shape_c = logits_bc.get_shape().as_list()[1:]

    softmaxed = tf.nn.softmax(logits_bc, dim=1)
    probabilities = tf.reshape(softmaxed[0], shape_c)
    # The argmax op is created only so that it is part of the exported graph.
    tf.argmax(probabilities, axis=0)

    print(probabilities.name)
    tf.train.export_meta_graph(filename="unet_inference.meta")
def inference_net():
    """Build the autoencoder inference graph and export it.

    A float32 placeholder of shape ``(91, 862, 862)`` is reshaped to carry
    leading batch and channel axes of size one, fed through
    ``autoencoder.autoencoder`` and a 1x1x1 ``ops3d.conv_pass`` with a single
    feature map and no activation. The result is reshaped to its spatial
    shape and the graph is written to ``autoencoder_inference.meta``.
    """
    volume_shape = (91, 862, 862)
    flat_kernel = (1, 3, 3)
    cube_kernel = (3, 3, 3)
    # One entry per level; every level uses two identical convolutions.
    level_kernels = (flat_kernel, flat_kernel, cube_kernel, cube_kernel)

    raw = tf.placeholder(tf.float32, shape=volume_shape)
    raw_bc = tf.reshape(raw, (1, 1) + volume_shape)

    last_fmap, fov, anisotropy = autoencoder.autoencoder(
        raw_bc,
        12,
        6,
        [[1, 3, 3], [1, 3, 3], [3, 3, 3]],
        [[k, k] for k in level_kernels],
        [[k, k] for k in level_kernels],
        voxel_size=(10, 1, 1),
        fov=(10, 1, 1),
    )

    dist_bc, _ = ops3d.conv_pass(
        last_fmap,
        kernel_size=[[1, 1, 1]],
        num_fmaps=1,
        activation=None,
        fov=fov,
        voxel_size=anisotropy,
    )

    # Drop the batch and channel axes from the static shape.
    spatial_shape = dist_bc.get_shape().as_list()[2:]
    # The reshape op is created only so that it is part of the exported graph.
    tf.reshape(dist_bc, spatial_shape)

    tf.train.export_meta_graph(filename="autoencoder_inference.meta")
def inference_net(labels):
    """Build the multi-label U-Net inference graph and export it.

    Args:
        labels: Sized collection of labels; only ``len(labels)`` is used, to
            set the number of output feature maps.

    A float32 placeholder of shape ``(340, 340, 340)`` is reshaped to carry
    leading batch and channel axes of size one, fed through ``unet.unet`` and
    a 1x1x1 ``ops3d.conv_pass`` without activation. The batch axis is removed
    by a reshape, the result is unstacked along axis 0 into one tensor per
    label, and the graph is written to ``unet_inference.meta``.
    """
    volume_shape = (340, 340, 340)
    n_labels = len(labels)
    cube_kernel = (3, 3, 3)

    raw = tf.placeholder(tf.float32, shape=volume_shape)
    raw_bc = tf.reshape(raw, (1, 1) + volume_shape)

    # 12, 72, 432, 2592 feature maps going down; the top level going up uses 48.
    fmaps_down = [12 * 6 ** level for level in range(4)]
    fmaps_up = [48] + [12 * 6 ** level for level in range(1, 4)]

    last_fmap, fov, anisotropy = unet.unet(
        raw_bc,
        fmaps_down,
        fmaps_up,
        [[2, 2, 2], [2, 2, 2], [3, 3, 3]],
        [[cube_kernel, cube_kernel] for _ in range(4)],
        [[cube_kernel, cube_kernel] for _ in range(4)],
        voxel_size=(1, 1, 1),
        fov=(1, 1, 1),
    )

    dist_bc, _ = ops3d.conv_pass(
        last_fmap,
        kernel_size=[[1, 1, 1]],
        num_fmaps=n_labels,
        activation=None,
        fov=fov,
        voxel_size=anisotropy,
    )

    # Static output shape without the batch axis.
    shape_c = dist_bc.get_shape().as_list()[1:]
    dist_c = tf.reshape(dist_bc, shape_c)
    # The unstack ops are created only so that they are part of the exported graph.
    tf.unstack(dist_c, n_labels, axis=0)

    tf.train.export_meta_graph(filename="unet_inference.meta")
def inference_net():
    """Build the single-output U-Net inference graph and export it.

    A float32 placeholder of shape ``(400, 400, 400)`` is reshaped to carry
    leading batch and channel axes of size one, fed through ``unet.unet`` and
    a 1x1x1 ``ops3d.conv_pass`` with a single feature map and no activation.
    The result is reshaped to its spatial shape and the graph is written to
    ``unet_inference.meta``.
    """
    volume_shape = (400, 400, 400)
    cube_kernel = (3, 3, 3)

    raw = tf.placeholder(tf.float32, shape=volume_shape)
    raw_bc = tf.reshape(raw, (1, 1) + volume_shape)

    last_fmap, fov, anisotropy = unet.unet(
        raw_bc,
        12,
        6,
        [[2, 2, 2], [2, 2, 2], [3, 3, 3]],
        [[cube_kernel, cube_kernel] for _ in range(4)],
        [[cube_kernel, cube_kernel] for _ in range(4)],
        voxel_size=(1, 1, 1),
        fov=(1, 1, 1),
    )

    pred_raw_bc, _ = ops3d.conv_pass(
        last_fmap,
        kernel_size=[[1, 1, 1]],
        num_fmaps=1,
        activation=None,
        fov=fov,
        voxel_size=anisotropy,
    )

    # Drop the batch and channel axes from the static shape.
    spatial_shape = pred_raw_bc.get_shape().as_list()[2:]
    # The reshape op is created only so that it is part of the exported graph.
    tf.reshape(pred_raw_bc, spatial_shape)

    tf.train.export_meta_graph(filename="unet_inference.meta")
def train_net():
    """Build the two-pass ("autocontext") affinity training graph and export it.

    The same U-Net is applied twice with shared variables: the first pass
    sees the raw volume concatenated with a constant 0.5 affinity volume, the
    second pass sees a centre crop of the raw volume concatenated with the
    affinities predicted by the first pass. Each pass has its own weighted
    mean-squared-error loss; their sum is minimised with Adam. The graph is
    written to ``wnet.meta`` and the names of the relevant tensors/ops to
    ``net_io_names.json``.

    NOTE(review): indentation was lost in the reviewed copy of this function;
    the extent of the ``tf.variable_scope`` block below is a reconstruction
    and should be confirmed against the original file.
    """
    # z [1, 1, 1]: 66 -> 38 -> 10
    # y, x [2, 2, 2]: 228 -> 140 -> 52
    # NOTE(review): the numbers in the two comments above do not match
    # shape_0 / shape_1 / shape_2 below -- presumably stale; confirm.
    shape_0 = (220, ) * 3
    shape_1 = (132, ) * 3
    shape_2 = (44, ) * 3
    # Constant stand-in for "previous affinities" in the first pass.
    affs_0_bc = tf.ones((1, 3) + shape_0) * 0.5
    with tf.variable_scope("autocontext") as scope:
        # phase 1: raw + constant affinities -> affs_1
        raw_0 = tf.placeholder(tf.float32, shape=shape_0)
        raw_0_bc = tf.reshape(raw_0, (1, 1) + shape_0)
        # Concatenate along axis 1 (the channel axis of the (1, C, z, y, x) tensors).
        input_0 = tf.concat([raw_0_bc, affs_0_bc], 1)
        out_bc, fov, anisotropy = unet.unet(
            input_0,
            24,
            3,
            [[2, 2, 2], [2, 2, 2], [2, 2, 2]],
            [
                [(3, 3, 3), (3, 3, 3)],
                [(3, 3, 3), (3, 3, 3)],
                [(3, 3, 3), (3, 3, 3)],
                [(3, 3, 3), (3, 3, 3)],
            ],
            [
                [(3, 3, 3), (3, 3, 3)],
                [(3, 3, 3), (3, 3, 3)],
                [(3, 3, 3), (3, 3, 3)],
                [(3, 3, 3), (3, 3, 3)],
            ],
        )
        affs_1_bc, fov = ops3d.conv_pass(
            out_bc,
            kernel_size=[[1, 1, 1]],
            num_fmaps=3,
            activation="sigmoid",
            fov=fov,
            voxel_size=anisotropy,
        )
        affs_1_c = tf.reshape(affs_1_bc, (3, ) + shape_1)
        gt_affs_1_c = tf.placeholder(tf.float32, shape=(3, ) + shape_1)
        loss_weights_1_c = tf.placeholder(tf.float32, shape=(3, ) + shape_1)
        loss_1 = tf.losses.mean_squared_error(gt_affs_1_c, affs_1_c, loss_weights_1_c)
        tf.summary.scalar("loss_pred0", loss_1)
        # phase 2: cropped raw + affs_1 -> affs_2, reusing the phase-1 variables
        scope.reuse_variables()
        raw_1 = ops3d.center_crop(raw_0, shape_1)
        raw_1_bc = tf.reshape(raw_1, (1, 1) + shape_1)
        input_1 = tf.concat([raw_1_bc, affs_1_bc], 1)
        out_bc, fov, anisotropy = unet.unet(
            input_1,
            24,
            3,
            [[2, 2, 2], [2, 2, 2], [2, 2, 2]],
            [
                [(3, 3, 3), (3, 3, 3)],
                [(3, 3, 3), (3, 3, 3)],
                [(3, 3, 3), (3, 3, 3)],
                [(3, 3, 3), (3, 3, 3)],
            ],
            [
                [(3, 3, 3), (3, 3, 3)],
                [(3, 3, 3), (3, 3, 3)],
                [(3, 3, 3), (3, 3, 3)],
                [(3, 3, 3), (3, 3, 3)],
            ],
            fov=fov,
            voxel_size=anisotropy,
        )
        affs_2_bc, fov = ops3d.conv_pass(
            out_bc,
            kernel_size=[[1, 1, 1]],
            num_fmaps=3,
            activation="sigmoid",
            fov=fov,
            voxel_size=anisotropy,
        )
        affs_2_c = tf.reshape(affs_2_bc, (3, ) + shape_2)
        # Ground truth and weights for the second pass are centre crops of
        # the first-pass placeholders, so only one set has to be fed.
        gt_affs_2_c = ops3d.center_crop(gt_affs_1_c, (3, ) + shape_2)
        loss_weights_2_c = ops3d.center_crop(loss_weights_1_c, (3, ) + shape_2)
        loss_2 = tf.losses.mean_squared_error(gt_affs_2_c, affs_2_c,
                                              loss_weights_2_c)
        tf.summary.scalar("loss_pred1", loss_2)
    # Total objective: sum of both passes' losses.
    loss = loss_1 + loss_2
    tf.summary.scalar("loss_total", loss)
    tf.summary.scalar("loss_diff", loss_1 - loss_2)
    for trainable in tf.trainable_variables():
        custom_ops.tf_var_summary(trainable)
    merged = tf.summary.merge_all()
    opt = tf.train.AdamOptimizer(learning_rate=0.5e-4,
                                 beta1=0.95,
                                 beta2=0.999,
                                 epsilon=1e-8)
    optimizer = opt.minimize(loss)
    tf.train.export_meta_graph(filename="wnet.meta")
    # Names by which a training script can look the tensors/ops up again.
    names = {
        "raw": raw_0.name,
        "affs_1": affs_1_c.name,
        "affs_2": affs_2_c.name,
        "gt_affs": gt_affs_1_c.name,
        "loss_weights": loss_weights_1_c.name,
        "loss": loss.name,
        "optimizer": optimizer.name,
        "summary": merged.name,
    }
    with open("net_io_names.json", "w") as f:
        json.dump(names, f)
def make_net(labels, added_steps, mode="train", loss_name="loss_total"):
    """Build a two-scale ``ScaleNet`` graph and export it as a meta graph.

    Two ``scale_net.SerialUNet`` instances (input voxel sizes ``(4, 4, 4)``
    and ``(36, 36, 36)``) are combined into a ``scale_net.ScaleNet``. One raw
    placeholder is created per network input, the network is built, and a
    1x1x1 convolution with one output channel per label is appended.

    In training mode, per-label ground-truth, weight and mask placeholders
    and the balanced / unbalanced (optionally class-weighted) MSE losses are
    added, the loss selected by ``loss_name`` is minimised with Adam, and the
    tensor/op names are written to ``net_io_names.json``.

    Args:
        labels: Sequence of label objects providing ``labelname`` and
            ``class_weight`` attributes.
        added_steps: Multiplier for ``unet0.step_valid_shape`` that is added
            to the minimal input shape.
        mode: ``"train"``/``"training"`` to add losses and optimizer,
            ``"inference"``/``"prediction"`` to build the forward pass only
            (case-insensitive).
        loss_name: One of ``"loss_total"``, ``"loss_total_unbalanced"``,
            ``"loss_total_classweighted"``,
            ``"loss_total_unbalanced_classweighted"``.

    Returns:
        The constructed ``scale_net.ScaleNet``.

    Raises:
        ValueError: If ``mode`` or ``loss_name`` is not recognised.
    """
    double_conv = [(3, 3, 3), (3, 3, 3)]
    unet0 = scale_net.SerialUNet(
        [12, 12 * 6, 12 * 6**2],
        [48, 12 * 6, 12 * 6**2],
        [(3, 3, 3), (3, 3, 3)],
        [list(double_conv) for _ in range(3)],
        [list(double_conv) for _ in range(2)],
        input_voxel_size=(4, 4, 4),
    )
    unet1 = scale_net.SerialUNet(
        [12, 12 * 6, 12 * 6**2],
        [12 * 6**2, 12 * 6**2, 12 * 6**2],
        [(3, 3, 3), (3, 3, 3)],
        [list(double_conv) for _ in range(3)],
        [list(double_conv) for _ in range(2)],
        input_voxel_size=(36, 36, 36),
    )

    input_size = unet0.min_input_shape
    input_size_actual = input_size + added_steps * unet0.step_valid_shape

    scnet = scale_net.ScaleNet([unet0, unet1],
                               input_size_actual,
                               name="scnet_" + mode)

    inputs = []
    names = dict()
    for inp, vs in zip(scnet.input_shapes, scnet.voxel_sizes):
        raw = tf.placeholder(tf.float32, shape=inp)
        # The ``np.int`` alias was removed in NumPy 1.24; the builtin ``int``
        # is the type it used to refer to.
        raw_bc = tf.reshape(raw, (1, 1) + tuple(inp.astype(int)))
        inputs.append(raw_bc)
        names["raw_{0:}".format(vs[0])] = raw.name

    last_fmap, fov, anisotropy = scnet.build(inputs)

    dist_bc, fov = ops3d.conv_pass(
        last_fmap,
        kernel_size=[(1, 1, 1)],
        num_fmaps=len(labels),
        activation=None,
        fov=fov,
        voxel_size=anisotropy,
    )

    output_shape_bc = dist_bc.get_shape().as_list()
    output_shape_c = output_shape_bc[1:]  # strip the batch dimension
    output_shape = output_shape_c[1:]  # strip the channel dimension

    dist_c = tf.reshape(dist_bc, output_shape_c)
    names["dist"] = dist_c.name
    network_outputs = tf.unstack(dist_c, len(labels), axis=0)

    mode_lc = mode.lower()
    if mode_lc in ("train", "training"):
        gt = []
        w = []
        cw = []
        masks = []
        # Placeholders are created mask, gt, weight per label so that their
        # auto-generated names stay in a fixed order.
        for label in labels:
            masks.append(tf.placeholder(tf.float32, shape=output_shape))
            gt.append(tf.placeholder(tf.float32, shape=output_shape))
            w.append(tf.placeholder(tf.float32, shape=output_shape))
            cw.append(label.class_weight)

        lb = []
        lub = []
        for output_it, gt_it, w_it, m_it, label in zip(network_outputs, gt, w,
                                                       masks, labels):
            # Balanced loss: per-voxel weights times mask; unbalanced: mask only.
            lb.append(
                tf.losses.mean_squared_error(gt_it, output_it, w_it * m_it))
            lub.append(tf.losses.mean_squared_error(gt_it, output_it, m_it))
            names[label.labelname] = output_it.name
            names["gt_" + label.labelname] = gt_it.name
            names["w_" + label.labelname] = w_it.name
            names["mask_" + label.labelname] = m_it.name

        for label, lb_it, lub_it in zip(labels, lb, lub):
            tf.summary.scalar("lb_" + label.labelname, lb_it)
            tf.summary.scalar("lub_" + label.labelname, lub_it)
            names["lb_" + label.labelname] = lb_it.name
            names["lub_" + label.labelname] = lub_it.name

        loss_total = tf.add_n(lb)
        loss_total_unbalanced = tf.add_n(lub)
        loss_total_classweighted = tf.tensordot(lb, cw, axes=1)
        loss_total_unbalanced_classweighted = tf.tensordot(lub, cw, axes=1)

        tf.summary.scalar("loss_total", loss_total)
        names["loss_total"] = loss_total.name
        tf.summary.scalar("loss_total_unbalanced", loss_total_unbalanced)
        names["loss_total_unbalanced"] = loss_total_unbalanced.name
        tf.summary.scalar("loss_total_classweighted", loss_total_classweighted)
        names["loss_total_classweighted"] = loss_total_classweighted.name
        tf.summary.scalar("loss_total_unbalanced_classweighted",
                          loss_total_unbalanced_classweighted)
        names["loss_total_unbalanced_classweighted"] = (
            loss_total_unbalanced_classweighted.name)

        selectable_losses = {
            "loss_total": loss_total,
            "loss_total_unbalanced": loss_total_unbalanced,
            "loss_total_unbalanced_classweighted":
                loss_total_unbalanced_classweighted,
            "loss_total_classweighted": loss_total_classweighted,
        }

        opt = tf.train.AdamOptimizer(learning_rate=0.5e-4,
                                     beta1=0.95,
                                     beta2=0.999,
                                     epsilon=1e-8)
        if loss_name not in selectable_losses:
            raise ValueError(loss_name + " not defined")
        # Only the selected loss gets a minimise op.
        optimizer = opt.minimize(selectable_losses[loss_name])
        names["optimizer"] = optimizer.name

        merged = tf.summary.merge_all()
        names["summary"] = merged.name

        with open("net_io_names.json", "w") as f:
            json.dump(names, f)
    elif mode_lc in ("inference", "prediction"):
        pass
    else:
        raise ValueError(
            "unknown mode for network construction: {0:}".format(mode))

    tf.train.export_meta_graph(filename=scnet.name + ".meta")
    return scnet
def train_net():
    """Build the autoencoder distance-regression training graph and export it.

    A float32 placeholder of shape ``(43, 430, 430)`` is reshaped to carry
    leading batch and channel axes of size one, fed through
    ``autoencoder.autoencoder`` and a 1x1x1 ``ops3d.conv_pass`` with a single
    feature map and no activation. Two mean-squared-error losses are defined
    against a ground-truth placeholder, one weighted by ``loss_weights`` and
    one by ``mask``; Adam minimises the former. The graph is written to
    ``autoencoder.meta`` and the tensor/op names to ``net_io_names.json``.
    """
    volume_shape = (43, 430, 430)
    flat_kernel = (1, 3, 3)
    cube_kernel = (3, 3, 3)
    # One entry per level; every level uses two identical convolutions.
    level_kernels = (flat_kernel, flat_kernel, cube_kernel, cube_kernel)

    raw = tf.placeholder(tf.float32, shape=volume_shape)
    raw_bc = tf.reshape(raw, (1, 1) + volume_shape)

    last_fmap, fov, anisotropy = autoencoder.autoencoder(
        raw_bc,
        12,
        6,
        [[1, 3, 3], [1, 3, 3], [3, 3, 3]],
        [[k, k] for k in level_kernels],
        [[k, k] for k in level_kernels],
        voxel_size=(10, 1, 1),
        fov=(10, 1, 1),
    )

    dist_bc, _ = ops3d.conv_pass(
        last_fmap,
        kernel_size=[[1, 1, 1]],
        num_fmaps=1,
        activation=None,
        fov=fov,
        voxel_size=anisotropy,
    )

    # Drop the batch and channel axes from the static shape.
    spatial_shape = dist_bc.get_shape().as_list()[2:]
    dist = tf.reshape(dist_bc, spatial_shape)

    gt_dist = tf.placeholder(tf.float32, shape=spatial_shape)
    loss_weights = tf.placeholder(tf.float32, shape=spatial_shape)
    mask = tf.placeholder(tf.float32, shape=spatial_shape)

    loss_balanced = tf.losses.mean_squared_error(gt_dist, dist, loss_weights)
    tf.summary.scalar("loss_balanced_syn", loss_balanced)
    loss_unbalanced = tf.losses.mean_squared_error(gt_dist, dist, mask)
    tf.summary.scalar("loss_unbalanced_syn", loss_unbalanced)

    adam = tf.train.AdamOptimizer(
        learning_rate=0.5e-4, beta1=0.95, beta2=0.999, epsilon=1e-8
    )
    optimizer = adam.minimize(loss_balanced)
    merged = tf.summary.merge_all()

    tf.train.export_meta_graph(filename="autoencoder.meta")

    io_names = dict(
        raw=raw.name,
        dist=dist.name,
        gt_dist=gt_dist.name,
        loss_balanced_syn=loss_balanced.name,
        loss_unbalanced_syn=loss_unbalanced.name,
        loss_weights=loss_weights.name,
        mask=mask.name,
        optimizer=optimizer.name,
        summary=merged.name,
    )
    with open("net_io_names.json", "w") as handle:
        json.dump(io_names, handle)
[(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)], ], [ [(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)], ], ) affs_bc, fov = ops3d.conv_pass( last_fmap, kernel_size=[[1, 1, 1]], num_fmaps=3, activation="sigmoid", fov=fov, voxel_size=anisotropy, ) output_shape_bc = affs_bc.get_shape().as_list() output_shape_c = output_shape_bc[1:] # strip the batch dimension affs_c = tf.reshape(affs_bc, output_shape_c) gt_affs_c = tf.placeholder(tf.float32, shape=output_shape_c) loss_weights_c = tf.placeholder(tf.float32, shape=output_shape_c) loss = tf.losses.mean_squared_error(gt_affs_c, affs_c, loss_weights_c)
def train_net():
    """Build the raw-reconstruction U-Net training graph and export it.

    A float32 placeholder of shape ``(196, 196, 196)`` is reshaped to carry
    leading batch and channel axes of size one, fed through ``unet.unet`` and
    a 1x1x1 ``ops3d.conv_pass`` with a single feature map and no activation.
    The target is the input itself, cropped with ``ops3d.crop_zyx`` to the
    output shape; the mean-squared error between the two is minimised with
    Adam. The graph is written to ``unet.meta`` and the tensor/op names to
    ``net_io_names.json``.
    """
    volume_shape = (196, 196, 196)
    cube_kernel = (3, 3, 3)

    raw = tf.placeholder(tf.float32, shape=volume_shape)
    raw_bc = tf.reshape(raw, (1, 1) + volume_shape)

    last_fmap, fov, anisotropy = unet.unet(
        raw_bc,
        12,
        6,
        [[2, 2, 2], [2, 2, 2], [3, 3, 3]],
        [[cube_kernel, cube_kernel] for _ in range(4)],
        [[cube_kernel, cube_kernel] for _ in range(4)],
        voxel_size=(1, 1, 1),
        fov=(1, 1, 1),
    )

    pred_raw_bc, _ = ops3d.conv_pass(
        last_fmap,
        kernel_size=[[1, 1, 1]],
        num_fmaps=1,
        activation=None,
        fov=fov,
        voxel_size=anisotropy,
    )

    full_shape = pred_raw_bc.get_shape().as_list()
    # Drop the batch and channel axes from the static shape.
    spatial_shape = full_shape[2:]

    pred_raw = tf.reshape(pred_raw_bc, spatial_shape)
    # The reconstruction target is the input cropped to the output extent.
    gt_raw_bc = ops3d.crop_zyx(raw_bc, full_shape)
    gt_raw = tf.reshape(gt_raw_bc, spatial_shape)

    loss = tf.losses.mean_squared_error(gt_raw, pred_raw)
    tf.summary.scalar("loss", loss)

    adam = tf.train.AdamOptimizer(
        learning_rate=0.5e-4, beta1=0.95, beta2=0.999, epsilon=1e-8
    )
    optimizer = adam.minimize(loss)
    merged = tf.summary.merge_all()

    tf.train.export_meta_graph(filename="unet.meta")

    io_names = dict(
        raw=raw.name,
        pred_raw=pred_raw.name,
        optimizer=optimizer.name,
        summary=merged.name,
        loss=loss.name,
    )
    with open("net_io_names.json", "w") as handle:
        json.dump(io_names, handle)
def make_net(unet, labels, added_steps, loss_name="loss_total", mode="train"):
    """Build a multi-label distance-regression graph around ``unet`` and export it.

    The input size is ``unet.min_input_shape + added_steps *
    unet.step_valid_shape``. After ``unet.build`` a 1x1x1 convolution with
    one output channel per label is appended. In training mode a global mask
    plus per-label ground-truth, weight and mask placeholders are created,
    balanced and unbalanced MSE losses are summed over labels, the loss
    selected by ``loss_name`` is minimised with Adam, and the tensor/op names
    are written to ``net_io_names.json``. The graph is always exported to
    ``"unet_" + mode + ".meta"``.

    Args:
        unet: Network object providing ``min_input_shape``,
            ``step_valid_shape``, ``padding`` and ``build``.
        labels: Sequence of label objects providing a ``labelname`` attribute.
        added_steps: Multiplier for ``unet.step_valid_shape``.
        loss_name: ``"loss_total"`` or ``"loss_total_unbalanced"``.
        mode: ``"train"``/``"training"`` or
            ``"inference"``/``"prediction"``/``"pred"`` (case-insensitive).

    Returns:
        Tuple ``(net_name, input_size_actual, output_shape)``.

    Raises:
        ValueError: If ``mode`` or ``loss_name`` is not recognised.
    """
    names = dict()
    input_size = unet.min_input_shape
    # The ``np.int`` alias was removed in NumPy 1.24; the builtin ``int`` is
    # the type it used to refer to.
    input_size_actual = (
        input_size + added_steps * unet.step_valid_shape
    ).astype(int)

    raw = tf.placeholder(tf.float32, shape=tuple(input_size_actual))
    names["raw"] = raw.name
    raw_bc = tf.reshape(raw, (1, 1) + tuple(input_size_actual))

    last_fmap, fov, anisotropy = unet.build(raw_bc)

    dist_bc, fov = ops3d.conv_pass(
        last_fmap,
        kernel_size=[[1, 1, 1]],
        num_fmaps=len(labels),
        activation=None,
        padding=unet.padding,
        fov=fov,
        voxel_size=anisotropy,
    )

    output_shape_bc = dist_bc.get_shape().as_list()
    output_shape_c = output_shape_bc[1:]  # strip the batch dimension
    output_shape = output_shape_c[1:]  # strip the channel dimension

    dist_c = tf.reshape(dist_bc, output_shape_c)
    names["dist"] = dist_c.name
    network_outputs = tf.unstack(dist_c, len(labels), axis=0)

    mode_lc = mode.lower()
    if mode_lc in ("train", "training"):
        # Mask shared by all labels; multiplied into every per-label loss.
        mask = tf.placeholder(tf.float32, shape=output_shape)
        names["mask"] = mask.name

        gt = []
        w = []
        masks = []
        # Placeholders are created mask, gt, weight per label so that their
        # auto-generated names stay in a fixed order.
        for _ in labels:
            masks.append(tf.placeholder(tf.float32, shape=output_shape))
            gt.append(tf.placeholder(tf.float32, shape=output_shape))
            w.append(tf.placeholder(tf.float32, shape=output_shape))

        lb = []
        lub = []
        for output_it, gt_it, w_it, m_it, label in zip(
            network_outputs, gt, w, masks, labels
        ):
            lb.append(
                tf.losses.mean_squared_error(gt_it, output_it, w_it * m_it * mask)
            )
            lub.append(tf.losses.mean_squared_error(gt_it, output_it, m_it * mask))
            names[label.labelname] = output_it.name
            names["gt_" + label.labelname] = gt_it.name
            names["w_" + label.labelname] = w_it.name
            names["mask_" + label.labelname] = m_it.name

        for label, lb_it, lub_it in zip(labels, lb, lub):
            tf.summary.scalar("lb_" + label.labelname, lb_it)
            tf.summary.scalar("lub_" + label.labelname, lub_it)
            names["lb_" + label.labelname] = lb_it.name
            names["lub_" + label.labelname] = lub_it.name

        loss_total = tf.add_n(lb)
        loss_total_unbalanced = tf.add_n(lub)

        tf.summary.scalar("loss_total", loss_total)
        names["loss_total"] = loss_total.name
        tf.summary.scalar("loss_total_unbalanced", loss_total_unbalanced)
        names["loss_total_unbalanced"] = loss_total_unbalanced.name

        selectable_losses = {
            "loss_total": loss_total,
            "loss_total_unbalanced": loss_total_unbalanced,
        }

        opt = tf.train.AdamOptimizer(
            learning_rate=0.5e-4, beta1=0.95, beta2=0.999, epsilon=1e-8
        )
        if loss_name not in selectable_losses:
            raise ValueError(loss_name + " not defined")
        optimizer = opt.minimize(selectable_losses[loss_name])
        names["optimizer"] = optimizer.name

        merged = tf.summary.merge_all()
        names["summary"] = merged.name

        with open("net_io_names.json", "w") as f:
            json.dump(names, f)
    elif mode_lc in ("inference", "prediction", "pred"):
        pass
    else:
        raise ValueError("unknown mode for network construction {0:}".format(mode))

    net_name = "unet_" + mode
    tf.train.export_meta_graph(filename=net_name + ".meta")
    return net_name, input_size_actual, output_shape
def make_net(
    net_name,
    unet,
    n_out,
    added_context,
    sigma=1.0,
    lamb=1.0,
    input_name="raw",
    output_names=None,
    loss_name="loss_total",
    mode="train",
):
    """Build a regression graph with optional Gaussian-blur regularisation.

    The input size is ``unet.min_input_shape + added_context``. After
    ``unet.build`` a 1x1x1 convolution with ``n_out`` channels is appended.
    A Gaussian-blurred copy of the input (``ops3d.gaussian_blur`` with
    ``sigma``), cropped to the output shape, serves as an additional target.

    In ``"train"``/``"training"`` and ``"forward"`` modes, per-output target
    placeholders and L1/L2 losses against both the targets and the blurred
    input are created and the tensor names are written to
    ``"<net_name>_io_names.json"``. An Adam minimise op is added only in
    training mode. The graph is always exported to
    ``net_name + "_" + mode + ".meta"``.

    Args:
        net_name: Prefix for the JSON and meta-graph file names.
        unet: Network object providing ``min_input_shape``,
            ``step_valid_shape``, ``padding`` and ``build``.
        n_out: Number of output channels.
        added_context: Per-axis amount added to the minimal input shape. For
            ``"valid"`` padding it must be a multiple of
            ``unet.step_valid_shape``.
        sigma: Passed to ``ops3d.gaussian_blur``.
        lamb: Weight of the blur term in the combined losses.
        input_name: Key under which the input placeholder name is stored.
        output_names: Optional list of ``n_out`` names; defaults to
            ``output_0 .. output_{n_out-1}``.
        loss_name: One of ``"L2"``, ``"L1"``, ``"L2+L2gauss"``,
            ``"L2+L1gauss"``, ``"L1+L2gauss"``, ``"L1+L1gauss"``. Note that
            the default ``"loss_total"`` is not among them, so callers must
            pass this explicitly in training/forward mode.
        mode: ``"train"``/``"training"``, ``"forward"``, or
            ``"inference"``/``"prediction"``/``"pred"`` (case-insensitive).
            ``"train"`` is accepted as an alias for ``"training"`` because it
            is this function's own default and was previously rejected.

    Returns:
        Tuple ``(net_name, input_size_actual, output_shape)``.

    Raises:
        ValueError: If ``mode`` or ``loss_name`` is not recognised.
        AssertionError: If ``added_context`` does not fit valid padding or
            ``output_names`` has the wrong length.
    """
    mode_lc = mode.lower()
    is_training = mode_lc in ("train", "training")

    names = dict()
    input_size = unet.min_input_shape
    if unet.padding == "valid":
        assert np.all(
            np.array(added_context) % np.array(unet.step_valid_shape) == 0
        ), "input shape not suitable for valid padding"
    elif not np.all(np.array(added_context) > 0):
        logging.warning(
            "Small input shape does not generate any output elements free of influence from padding"
        )

    # The ``np.int`` alias was removed in NumPy 1.24; the builtin ``int`` is
    # the type it used to refer to.
    input_size_actual = (np.array(input_size) + np.array(added_context)).astype(int)

    # Named ``net_input`` rather than ``input`` to avoid shadowing the builtin.
    net_input = tf.placeholder(tf.float32, shape=tuple(input_size_actual))
    names[input_name] = net_input.name
    input_bc = tf.reshape(net_input, (1, 1) + tuple(input_size_actual))

    last_fmap, fov, anisotropy = unet.build(input_bc)

    output_bc, fov = ops3d.conv_pass(
        last_fmap,
        kernel_size=[[1, 1, 1]],
        num_fmaps=n_out,
        activation=None,
        padding=unet.padding,
        fov=fov,
        voxel_size=anisotropy,
    )

    output_shape_bc = output_bc.get_shape().as_list()
    output_shape_c = output_shape_bc[1:]  # strip the batch dimension
    output_shape = output_shape_c[1:]  # strip the channel dimension

    output_c = tf.reshape(output_bc, output_shape_c)
    names["output"] = output_c.name
    network_outputs = tf.unstack(output_c, n_out, axis=0)

    # Blurred copy of the input, cropped to the output extent.
    blurred_full = ops3d.gaussian_blur(input_bc, sigma)
    blurred_bc = ops3d.crop_zyx(blurred_full, output_shape_bc)
    blurred_c = tf.reshape(blurred_bc, output_shape_c)
    blurred = tf.reshape(blurred_c, output_shape)
    names["blurred"] = blurred_c.name

    if output_names is None:
        output_names = ["output_{0:}".format(n) for n in range(n_out)]
    assert len(output_names) == n_out

    if is_training or mode_lc == "forward":
        target = [
            tf.placeholder(tf.float32, shape=output_shape) for _ in range(n_out)
        ]

        loss_l2 = []
        loss_l1 = []
        loss_l2_gauss = []
        loss_l1_gauss = []
        for output_it, tgt_it, out_name in zip(network_outputs, target, output_names):
            names[out_name + "_predicted"] = output_it.name
            names[out_name + "_target"] = tgt_it.name

            l2 = tf.losses.mean_squared_error(tgt_it, output_it)
            loss_l2.append(l2)
            tf.summary.scalar("l2_" + out_name, l2)
            names[out_name + "_l2"] = l2.name

            l1 = tf.losses.absolute_difference(tgt_it, output_it)
            loss_l1.append(l1)
            tf.summary.scalar("l1_" + out_name, l1)
            names[out_name + "_l1"] = l1.name

            l2_gauss = tf.losses.mean_squared_error(blurred, output_it)
            loss_l2_gauss.append(l2_gauss)
            tf.summary.scalar("l2_gauss_" + out_name, l2_gauss)
            names[out_name + "_l2_gauss"] = l2_gauss.name

            l1_gauss = tf.losses.absolute_difference(blurred, output_it)
            loss_l1_gauss.append(l1_gauss)
            tf.summary.scalar("l1_gauss_" + out_name, l1_gauss)
            names[out_name + "_l1_gauss"] = l1_gauss.name

        # Each total is also exposed as a (1, 1, 1)-shaped readout tensor.
        l2_total = tf.add_n(loss_l2)
        tf.summary.scalar("l2_total", l2_total)
        l2_gp_readout = tf.reshape(l2_total, (1,) * 3)
        names["L2"] = l2_gp_readout.name

        l1_total = tf.add_n(loss_l1)
        tf.summary.scalar("l1_total", l1_total)
        l1_gp_readout = tf.reshape(l1_total, (1,) * 3)
        names["L1"] = l1_gp_readout.name

        l2_gauss_total = tf.add_n(loss_l2_gauss)
        tf.summary.scalar("l2_gauss_total", l2_gauss_total)
        l2_gauss_gp_readout = tf.reshape(l2_gauss_total, (1,) * 3)
        names["L2gauss"] = l2_gauss_gp_readout.name

        l1_gauss_total = tf.add_n(loss_l1_gauss)
        tf.summary.scalar("l1_gauss_total", l1_gauss_total)
        l1_gauss_gp_readout = tf.reshape(l1_gauss_total, (1,) * 3)
        names["L1gauss"] = l1_gauss_gp_readout.name

        # (data term, blur term or None); the sum is only built for the
        # selected combination so no extra ops end up in the graph.
        loss_terms = {
            "L2": (l2_total, None),
            "L1": (l1_total, None),
            "L2+L2gauss": (l2_total, l2_gauss_total),
            "L2+L1gauss": (l2_total, l1_gauss_total),
            "L1+L2gauss": (l1_total, l2_gauss_total),
            "L1+L1gauss": (l1_total, l1_gauss_total),
        }
        if loss_name not in loss_terms:
            raise ValueError(loss_name + " not defined")
        data_term, blur_term = loss_terms[loss_name]
        if blur_term is None:
            loss_opt = data_term
        else:
            loss_opt = data_term + lamb * blur_term
        names["loss"] = loss_opt.name

        if is_training:
            opt = tf.train.AdamOptimizer(
                learning_rate=0.5e-4, beta1=0.95, beta2=0.999, epsilon=1e-8
            )
            optimizer = opt.minimize(loss_opt)
            names["optimizer"] = optimizer.name

        merged = tf.summary.merge_all()
        names["summary"] = merged.name

        with open("{0:}_io_names.json".format(net_name), "w") as f:
            json.dump(names, f)
    elif mode_lc in ("inference", "prediction", "pred"):
        pass
    else:
        raise ValueError(
            "unknown mode for network construction {0:}".format(mode))

    tf.train.export_meta_graph(filename=net_name + "_" + mode + ".meta")
    return net_name, input_size_actual, output_shape
def train_net():
    """Build the two-class cross-entropy U-Net training graph and export it.

    A float32 placeholder of shape ``(43, 430, 430)`` is reshaped to carry
    leading batch and channel axes of size one, fed through ``unet.unet`` and
    a 1x1x1 ``ops3d.conv_pass`` with two feature maps and no activation. The
    logits are flattened to ``(voxels, 2)``; the one-hot target is built from
    the ground-truth placeholder (column 0) and its ``!= 1`` complement
    (column 1). Two softmax cross-entropy losses are defined, weighted by
    ``loss_weights`` and by ``mask`` respectively; Adam minimises the former.
    The graph is written to ``unet.meta`` and the tensor/op names to
    ``net_io_names.json``.
    """
    volume_shape = (43, 430, 430)
    flat_kernel = (1, 3, 3)
    cube_kernel = (3, 3, 3)
    # One entry per U-Net level; every level uses two identical convolutions.
    level_kernels = (flat_kernel, flat_kernel, cube_kernel, cube_kernel)

    raw = tf.placeholder(tf.float32, shape=volume_shape)
    raw_bc = tf.reshape(raw, (1, 1) + volume_shape)

    last_fmap, fov, anisotropy = unet.unet(
        raw_bc,
        12,
        6,
        [[1, 3, 3], [1, 3, 3], [3, 3, 3]],
        [[k, k] for k in level_kernels],
        [[k, k] for k in level_kernels],
        voxel_size=(10, 1, 1),
        fov=(10, 1, 1),
    )

    logits_bc, _ = ops3d.conv_pass(
        last_fmap,
        kernel_size=[[1, 1, 1]],
        num_fmaps=2,
        activation=None,
        fov=fov,
        voxel_size=anisotropy,
    )

    # Static shapes without the batch axis / without batch and channel axes.
    shape_c = logits_bc.get_shape().as_list()[1:]
    spatial_shape = shape_c[1:]

    # (2, voxels) -> (voxels, 2) so that each row holds one voxel's logits.
    logits_by_class = tf.reshape(tensor=logits_bc, shape=(2, -1))
    flat_logits = tf.transpose(logits_by_class)

    gt_labels = tf.placeholder(tf.float32, shape=spatial_shape)
    gt_labels_flat = tf.reshape(gt_labels, (-1,))
    is_background = tf.not_equal(gt_labels_flat, 1)
    gt_bg = tf.to_float(is_background)
    flat_ohe = tf.stack(values=[gt_labels_flat, gt_bg], axis=1)

    loss_weights = tf.placeholder(tf.float32, shape=spatial_shape)
    loss_weights_flat = tf.reshape(loss_weights, (-1,))
    mask = tf.placeholder(tf.float32, shape=spatial_shape)
    mask_flat = tf.reshape(mask, (-1,))

    softmaxed = tf.nn.softmax(logits_bc, dim=1)
    probabilities = tf.reshape(softmaxed[0], shape_c)
    predictions = tf.argmax(probabilities, axis=0)

    ce_loss_balanced = tf.losses.softmax_cross_entropy(
        flat_ohe, flat_logits, weights=loss_weights_flat
    )
    ce_loss_unbalanced = tf.losses.softmax_cross_entropy(
        flat_ohe, flat_logits, weights=mask_flat
    )
    tf.summary.scalar("loss_balanced_syn", ce_loss_balanced)
    tf.summary.scalar("loss_unbalanced_syn", ce_loss_unbalanced)

    adam = tf.train.AdamOptimizer(
        learning_rate=0.5e-4, beta1=0.95, beta2=0.999, epsilon=1e-8
    )
    optimizer = adam.minimize(ce_loss_balanced)
    merged = tf.summary.merge_all()

    tf.train.export_meta_graph(filename="unet.meta")

    io_names = dict(
        raw=raw.name,
        probabilities=probabilities.name,
        predictions=predictions.name,
        gt_labels=gt_labels.name,
        loss_balanced_syn=ce_loss_balanced.name,
        loss_unbalanced_syn=ce_loss_unbalanced.name,
        loss_weights=loss_weights.name,
        mask=mask.name,
        optimizer=optimizer.name,
        summary=merged.name,
    )
    with open("net_io_names.json", "w") as handle:
        json.dump(io_names, handle)
def train_net(labels):
    """Build the multi-label distance-regression U-Net training graph.

    Args:
        labels: Sequence of label names (strings); one output channel,
            ground-truth placeholder and weight placeholder is created per
            entry, and the names are used as summary tags and JSON keys.

    A float32 placeholder of shape ``(196, 196, 196)`` is reshaped to carry
    leading batch and channel axes of size one, fed through ``unet.unet`` and
    a 1x1x1 ``ops3d.conv_pass`` without activation. Per label, a balanced MSE
    (weighted by the label's weight placeholder) and an unbalanced MSE
    (weighted by the shared mask) are defined; Adam minimises the sum of the
    balanced losses. The graph is written to ``unet.meta`` and the tensor/op
    names to ``net_io_names.json``.
    """
    volume_shape = (196, 196, 196)
    n_labels = len(labels)
    cube_kernel = (3, 3, 3)

    raw = tf.placeholder(tf.float32, shape=volume_shape)
    raw_bc = tf.reshape(raw, (1, 1) + volume_shape)

    # 12, 72, 432, 2592 feature maps going down; the top level going up uses 48.
    fmaps_down = [12 * 6 ** level for level in range(4)]
    fmaps_up = [48] + [12 * 6 ** level for level in range(1, 4)]

    last_fmap, fov, anisotropy = unet.unet(
        raw_bc,
        fmaps_down,
        fmaps_up,
        [[2, 2, 2], [2, 2, 2], [3, 3, 3]],
        [[cube_kernel, cube_kernel] for _ in range(4)],
        [[cube_kernel, cube_kernel] for _ in range(4)],
        voxel_size=(1, 1, 1),
        fov=(1, 1, 1),
    )

    dist_bc, _ = ops3d.conv_pass(
        last_fmap,
        kernel_size=[[1, 1, 1]],
        num_fmaps=n_labels,
        activation=None,
        fov=fov,
        voxel_size=anisotropy,
    )

    # Static shapes without the batch axis / without batch and channel axes.
    shape_c = dist_bc.get_shape().as_list()[1:]
    spatial_shape = shape_c[1:]

    dist_c = tf.reshape(dist_bc, shape_c)
    network_outputs = tf.unstack(dist_c, n_labels, axis=0)

    mask = tf.placeholder(tf.float32, shape=spatial_shape)

    # Ground-truth and weight placeholders are created alternately per label.
    gt, w = [], []
    for _ in range(n_labels):
        gt.append(tf.placeholder(tf.float32, shape=spatial_shape))
        w.append(tf.placeholder(tf.float32, shape=spatial_shape))

    lb, lub = [], []
    for prediction, target, weights in zip(network_outputs, gt, w):
        lb.append(tf.losses.mean_squared_error(target, prediction, weights))
        lub.append(tf.losses.mean_squared_error(target, prediction, mask))

    for label, balanced, unbalanced in zip(labels, lb, lub):
        tf.summary.scalar("lb_" + label, balanced)
        tf.summary.scalar("lub_" + label, unbalanced)

    loss_total = tf.add_n(lb)
    loss_total_unbalanced = tf.add_n(lub)
    tf.summary.scalar("loss_total", loss_total)
    tf.summary.scalar("loss_total_unbalanced", loss_total_unbalanced)

    adam = tf.train.AdamOptimizer(
        learning_rate=0.5e-4, beta1=0.95, beta2=0.999, epsilon=1e-8
    )
    optimizer = adam.minimize(loss_total)
    merged = tf.summary.merge_all()

    tf.train.export_meta_graph(filename="unet.meta")

    io_names = dict(
        raw=raw.name,
        dist=dist_c.name,
        loss_total=loss_total.name,
        loss_total_unbalanced=loss_total_unbalanced.name,
        mask=mask.name,
        optimizer=optimizer.name,
        summary=merged.name,
    )
    per_label = zip(labels, network_outputs, gt, w, lb, lub)
    for label, prediction, target, weights, balanced, unbalanced in per_label:
        io_names[label] = prediction.name
        io_names["gt_" + label] = target.name
        io_names["w_" + label] = weights.name
        io_names["lb_" + label] = balanced.name
        io_names["lub_" + label] = unbalanced.name

    with open("net_io_names.json", "w") as handle:
        json.dump(io_names, handle)
def make_net(unet, added_steps, loss_name="loss_total", padding="valid", mode="train"):
    """Build a three-channel (cleft / pre / post) distance-regression graph.

    The input size is ``unet.min_input_shape`` (or zero when ``padding`` is
    not ``"valid"``) plus ``added_steps * unet.step_valid_shape``. After
    ``unet.build`` a 1x1x1 convolution with three output channels is
    appended and unstacked into ``cleft_dist``, ``pre_dist`` and
    ``post_dist``. In training mode, ground-truth, weight and mask
    placeholders plus balanced / unbalanced MSE losses are created for each
    channel, the loss selected by ``loss_name`` is minimised with Adam, and
    the tensor/op names are written to ``net_io_names.json``. The graph is
    always exported to ``"unet_" + mode + ".meta"``.

    Args:
        unet: Network object providing ``min_input_shape``,
            ``step_valid_shape`` and ``build``.
        added_steps: Multiplier for ``unet.step_valid_shape``.
        loss_name: ``"loss_total"`` or ``"loss_total_unbalanced"``.
        padding: Padding mode forwarded to the final ``ops3d.conv_pass``.
        mode: ``"train"``/``"training"`` or
            ``"inference"``/``"prediction"``/``"pred"`` (case-insensitive).

    Returns:
        Tuple ``(net_name, input_size_actual, output_shape)``.

    Raises:
        ValueError: If ``mode`` or ``loss_name`` is not recognised.
    """
    channels = ("cleft", "pre", "post")
    names = dict()

    if padding == "valid":
        input_size = unet.min_input_shape
    else:
        input_size = np.array((0, 0, 0))
    # The ``np.int`` alias was removed in NumPy 1.24; the builtin ``int`` is
    # the type it used to refer to.
    input_size_actual = (
        input_size + added_steps * unet.step_valid_shape
    ).astype(int)

    raw = tf.placeholder(tf.float32, shape=tuple(input_size_actual))
    names["raw"] = raw.name
    raw_bc = tf.reshape(raw, (1, 1) + tuple(input_size_actual))

    last_fmap, fov, anisotropy = unet.build(raw_bc)

    dist_bc, fov = ops3d.conv_pass(
        last_fmap,
        kernel_size=[[1, 1, 1]],
        num_fmaps=3,
        activation=None,
        padding=padding,
        fov=fov,
        voxel_size=anisotropy,
    )

    output_shape_bc = dist_bc.get_shape().as_list()
    output_shape_c = output_shape_bc[1:]  # strip the batch dimension
    output_shape = output_shape_c[1:]  # strip the channel dimension

    dist_c = tf.reshape(dist_bc, shape=output_shape_c)
    names["dist"] = dist_c.name
    dist = dict(zip(channels, tf.unstack(dist_c, 3, axis=0)))
    for ch in channels:
        names[ch + "_dist"] = dist[ch].name

    mode_lc = mode.lower()
    if mode_lc in ("train", "training"):
        # Placeholders are created group by group (all ground truths, then
        # all weights, then all masks) so auto-generated names keep their order.
        gt = {ch: tf.placeholder(tf.float32, shape=output_shape) for ch in channels}
        for ch in channels:
            names["gt_" + ch + "_dist"] = gt[ch].name

        weights = {
            ch: tf.placeholder(tf.float32, shape=output_shape) for ch in channels
        }
        for ch in channels:
            names["loss_weights_" + ch] = weights[ch].name

        masks = {ch: tf.placeholder(tf.float32, shape=output_shape) for ch in channels}
        for ch in channels:
            names[ch + "_mask"] = masks[ch].name

        # Balanced loss: per-voxel weights times mask.
        balanced = {
            ch: tf.losses.mean_squared_error(
                gt[ch], dist[ch], weights[ch] * masks[ch]
            )
            for ch in channels
        }
        for ch in channels:
            names["loss_balanced_" + ch] = balanced[ch].name

        # Unbalanced loss: mask only.
        unbalanced = {
            ch: tf.losses.mean_squared_error(gt[ch], dist[ch], masks[ch])
            for ch in channels
        }
        for ch in channels:
            names["loss_unbalanced_" + ch] = unbalanced[ch].name

        loss_total = balanced["cleft"] + balanced["pre"] + balanced["post"]
        loss_total_unbalanced = (
            unbalanced["cleft"] + unbalanced["pre"] + unbalanced["post"]
        )
        names["loss_total"] = loss_total.name
        names["loss_total_unbalanced"] = loss_total_unbalanced.name

        for ch in channels:
            tf.summary.scalar("loss_balanced_" + ch, balanced[ch])
        for ch in channels:
            tf.summary.scalar("loss_unbalanced_" + ch, unbalanced[ch])
        tf.summary.scalar("loss_total", loss_total)
        tf.summary.scalar("loss_total_unbalanced", loss_total_unbalanced)

        selectable_losses = {
            "loss_total": loss_total,
            "loss_total_unbalanced": loss_total_unbalanced,
        }

        opt = tf.train.AdamOptimizer(
            learning_rate=0.5e-4, beta1=0.95, beta2=0.999, epsilon=1e-8
        )
        if loss_name not in selectable_losses:
            raise ValueError(loss_name + " not defined")
        optimizer = opt.minimize(selectable_losses[loss_name])
        names["optimizer"] = optimizer.name

        merged = tf.summary.merge_all()
        names["summary"] = merged.name

        with open("net_io_names.json", "w") as f:
            json.dump(names, f)
    elif mode_lc in ("inference", "prediction", "pred"):
        pass
    else:
        raise ValueError("unknown mode for network construction: {0:}".format(mode))

    net_name = "unet_" + mode
    tf.train.export_meta_graph(filename=net_name + ".meta")
    return net_name, input_size_actual, output_shape
def unet(
    fmaps_in,
    num_fmaps_down,
    num_fmaps_up,
    downsample_factors,
    kernel_size_down,
    kernel_size_up,
    activation="relu",
    layer=0,
    fov=(1, 1, 1),
    voxel_size=(1, 1, 1),
    constant_upsample=False,
):
    """Create a U-Net::

        f_in --> f_left --------------------------->> f_right--> f_out
                    |                                   ^
                    v                                   |
                 g_in --> g_left ------->> g_right --> g_out
                             |               ^
                             v               |
                                   ...

    where each ``-->`` is a convolution pass (see ``ops3d.conv_pass``), each
    ``-->>`` a crop, and down and up arrows are ``ops3d.downsample`` and
    ``ops3d.upsample``, respectively.

    The U-Net expects tensors to have shape ``(batch=1, channels, depth,
    height, width)``; the cropped left feature maps and the upsampled
    feature maps are concatenated along axis 1.

    This U-Net performs only "valid" convolutions, i.e., sizes of the
    feature maps decrease after each convolution.

    Args:

        fmaps_in:

            The input tensor.

        num_fmaps_down:

            List with one entry per level (``len(downsample_factors) + 1``
            entries). Entry ``l`` is the number of feature maps produced by
            the convolution pass on the left side of level ``l``.

        num_fmaps_up:

            List with one entry per level (same length as
            ``num_fmaps_down``). Entry ``l`` is the number of feature maps
            produced by the upsampling into level ``l`` and by the
            convolution pass on the right side of level ``l``. The entry for
            the bottom level is required by the length check but is never
            read, because the bottom level has no right side.

        downsample_factors:

            List of lists ``[z, y, x]`` to use to down- and up-sample the
            feature maps between levels. Its length determines the depth of
            the network.

        kernel_size_down:

            List of lists of tuples ``(z, y, x)`` of kernel sizes. The
            number of tuples in a list determines the number of
            convolutional layers in the corresponding level of the build on
            the left side.

        kernel_size_up:

            List of lists of tuples ``(z, y, x)`` of kernel sizes. The
            number of tuples in a list determines the number of
            convolutional layers in the corresponding level of the build on
            the right side. Within one of the lists going from left to
            right. The entry for the bottom level is never read.

        activation:

            Which activation to use after a convolution. Accepts the name of
            any tensorflow activation function (e.g., ``relu`` for
            ``tf.nn.relu``). It is applied to the left and right convolution
            passes and forwarded to the upsampling.

        layer:

            Used internally to build the U-Net recursively.

        fov:

            Initial field of view in physical units.

        voxel_size:

            Size of a voxel in the input data, in physical units.

        constant_upsample:

            Forwarded unchanged to ``ops3d.upsample`` at every level.
            Presumably switches to a constant (non-learned) upsampling;
            confirm against ``ops3d.upsample``.

    Returns:

        A tuple ``(f_out, fov, voxel_size)``: the output feature maps of
        this level together with the field of view and voxel size reported
        by the ``ops3d`` operations for that output.
    """
    prefix = " " * layer
    print(prefix + "Creating U-Net layer %i" % layer)
    print(prefix + "f_in: " + str(fmaps_in.shape))

    # Every per-level list needs one entry per level, i.e. one more than the
    # number of down-/up-sampling steps between levels.
    num_levels = len(downsample_factors) + 1
    assert (
        len(num_fmaps_down)
        == len(num_fmaps_up)
        == len(kernel_size_down)
        == len(kernel_size_up)
        == num_levels
    ), (
        "num_fmaps_down, num_fmaps_up, kernel_size_down and kernel_size_up "
        "must each have len(downsample_factors) + 1 = %i entries, "
        "got %i, %i, %i and %i"
        % (
            num_levels,
            len(num_fmaps_down),
            len(num_fmaps_up),
            len(kernel_size_down),
            len(kernel_size_up),
        )
    )

    # All ops of this level (including the recursion into deeper levels) are
    # created inside the level's name scope.
    with tf.name_scope("lev%i" % layer):
        # convolve
        f_left, fov = ops3d.conv_pass(
            fmaps_in,
            kernel_size=kernel_size_down[layer],
            num_fmaps=num_fmaps_down[layer],
            activation=activation,
            name="unet_layer_%i_left" % layer,
            fov=fov,
            voxel_size=voxel_size,
            prefix=prefix,
        )

        # last layer does not recurse
        bottom_layer = layer == len(downsample_factors)
        if bottom_layer:
            print(prefix + "bottom layer")
            print(prefix + "f_out: " + str(f_left.shape))
            return f_left, fov, voxel_size

        # downsample
        g_in, fov, voxel_size = ops3d.downsample(
            f_left,
            downsample_factors[layer],
            "unet_down_%i_to_%i" % (layer, layer + 1),
            fov=fov,
            voxel_size=voxel_size,
            prefix=prefix,
        )

        # recursive U-net
        g_out, fov, voxel_size = unet(
            g_in,
            num_fmaps_down=num_fmaps_down,
            num_fmaps_up=num_fmaps_up,
            downsample_factors=downsample_factors,
            kernel_size_down=kernel_size_down,
            kernel_size_up=kernel_size_up,
            activation=activation,
            layer=layer + 1,
            fov=fov,
            voxel_size=voxel_size,
            constant_upsample=constant_upsample,
        )
        print(prefix + "g_out: " + str(g_out.shape))

        # upsample
        g_out_upsampled, voxel_size = ops3d.upsample(
            g_out,
            downsample_factors[layer],
            num_fmaps_up[layer],
            activation=activation,
            name="unet_up_%i_to_%i" % (layer + 1, layer),
            fov=fov,
            voxel_size=voxel_size,
            prefix=prefix,
            constant_upsample=constant_upsample,
        )
        print(prefix + "g_out_upsampled: " + str(g_out_upsampled.shape))

        # copy-crop: valid convolutions make the upsampled maps smaller than
        # f_left, so f_left is cropped to the upsampled shape before merging
        f_left_cropped = ops3d.crop_zyx(
            f_left, g_out_upsampled.get_shape().as_list()
        )
        print(prefix + "f_left_cropped: " + str(f_left_cropped.shape))

        # concatenate along channel dimension
        f_right = tf.concat([f_left_cropped, g_out_upsampled], 1)
        print(prefix + "f_right: " + str(f_right.shape))

        # convolve; the requested activation is passed here as well so that
        # the right side matches the left side and the upsampling
        f_out, fov = ops3d.conv_pass(
            f_right,
            kernel_size=kernel_size_up[layer],
            num_fmaps=num_fmaps_up[layer],
            activation=activation,
            name="unet_layer_%i_right" % layer,
            fov=fov,
            voxel_size=voxel_size,
            prefix=prefix,
        )
        print(prefix + "f_out: " + str(f_out.shape))

        return f_out, fov, voxel_size
[(1, 3, 3), (1, 3, 3)], [(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)], ], [ [(1, 3, 3), (1, 3, 3)], [(1, 3, 3), (1, 3, 3)], [(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)], ], voxel_size=(10, 1, 1), fov=(10, 1, 1), ) output, full_fov = ops3d.conv_pass( model, kernel_size=[(1, 1, 1)], num_fmaps=1, activation=None, fov=ll_fov, voxel_size=vx, ) tf.train.export_meta_graph(filename="build.meta") with tf.Session() as session: session.run(tf.initialize_all_variables()) tf.summary.FileWriter(".", graph=tf.get_default_graph()) print(model.shape)