def train_net():
    """Build and export the two-class U-Net training graph (TF1 graph mode).

    Creates placeholders for the raw input volume, the ground-truth labels
    and per-voxel loss weights. Attaches a softmax cross-entropy objective
    (weighted and unweighted variants, both logged as summaries) and an
    Adam optimizer that minimizes the weighted variant.

    Side effects:
        * prints the static shape of the logits tensor;
        * writes the graph definition to ``unet.meta``;
        * writes ``net_io_names.json``, a JSON object mapping a role name
          (e.g. ``'raw'``, ``'optimizer'``) to the graph name of the
          corresponding tensor/op.

    NOTE: statement order matters here. TF auto-names ops by creation order
    (``Placeholder``, ``Placeholder_1``, ...), and those names are what end
    up in ``net_io_names.json``.
    """
    in_shape = (84, 268, 268)
    raw = tf.placeholder(tf.float32, shape=in_shape)
    # Add two leading singleton axes. Axis 0 is treated as the batch axis
    # (it is stripped below); axis 1 is presumably the channel axis.
    raw_5d = tf.reshape(raw, (1, 1) + in_shape)

    # NOTE(review): the roles of these arguments are inferred from the call
    # pattern (3 factor triples, 4 levels of kernel pairs); confirm against
    # unet.unet. Down/up lists are built separately so they are distinct
    # objects, as in a literal-per-argument call.
    level_factors = [[1, 3, 3], [1, 3, 3], [1, 3, 3]]
    kernels_down = [[(3, 3, 3), (3, 3, 3)] for _ in range(4)]
    kernels_up = [[(3, 3, 3), (3, 3, 3)] for _ in range(4)]
    features, receptive, anisotropy = unet.unet(
        raw_5d, 12, 6, level_factors, kernels_down, kernels_up,
        voxel_size=(10, 1, 1), fov=(10, 1, 1))

    # 1x1x1 convolution down to two output maps (the class scores).
    logits_5d, receptive = unet.conv_pass(
        features,
        kernel_size=[[1, 1, 1]],
        num_fmaps=2,
        activation=None,
        fov=receptive,
        voxel_size=anisotropy,
    )
    net_shape = logits_5d.get_shape().as_list()[1:]  # strip the batch dimension
    spatial_shape = net_shape[1:]

    # Reshape to (2, N), then transpose to (N, 2): one row of two scores per
    # voxel. Assumes batch size 1, which holds for raw_5d above.
    logits_rows = tf.reshape(tensor=logits_5d, shape=(2, -1))
    flat_logits = tf.transpose(logits_rows)

    gt_labels = tf.placeholder(tf.float32, shape=spatial_shape)
    labels_vec = tf.reshape(gt_labels, (-1,))
    # Target columns are [label, label != 1].
    # NOTE(review): this is a valid one-hot pair only if gt_labels holds
    # exclusively 0 and 1; confirm with the data pipeline.
    not_one = tf.to_float(tf.not_equal(labels_vec, 1))
    targets = tf.stack(values=[labels_vec, not_one], axis=1)

    loss_weights = tf.placeholder(tf.float32, shape=spatial_shape)
    weights_vec = tf.reshape(loss_weights, (-1,))

    print(logits_5d.get_shape().as_list())

    # Softmax over the two score channels, with the batch axis dropped.
    # Argmax index 0 pairs with target column 0, i.e. presumably label 1.
    class_probs = tf.nn.softmax(logits_5d, dim=1)[0]
    probabilities = tf.reshape(class_probs, net_shape)
    predictions = tf.argmax(probabilities, axis=0)

    ce_loss_balanced = tf.losses.softmax_cross_entropy(
        targets, flat_logits, weights=weights_vec)
    ce_loss_unbalanced = tf.losses.softmax_cross_entropy(targets, flat_logits)
    tf.summary.scalar('loss_balanced_syn', ce_loss_balanced)
    tf.summary.scalar('loss_unbalanced_syn', ce_loss_unbalanced)

    adam = tf.train.AdamOptimizer(
        learning_rate=0.5e-4, beta1=0.95, beta2=0.999, epsilon=1e-8)
    train_op = adam.minimize(ce_loss_balanced)
    summaries = tf.summary.merge_all()

    tf.train.export_meta_graph(filename='unet.meta')

    io_names = dict(
        raw=raw.name,
        probabilities=probabilities.name,
        predictions=predictions.name,
        gt_labels=gt_labels.name,
        loss_balanced_syn=ce_loss_balanced.name,
        loss_unbalanced_syn=ce_loss_unbalanced.name,
        loss_weights=loss_weights.name,
        optimizer=train_op.name,
        summary=summaries.name,
    )
    with open('net_io_names.json', 'w') as out:
        json.dump(io_names, out)
def train_net():
    """Build and export a two-output distance-regression U-Net graph (TF1).

    The network emits two maps, split into ``syn_dist`` and ``bdy_dist``.
    Each is regressed with mean squared error against its own ground-truth
    placeholder. The synapse loss is weighted by the ``loss_weights``
    placeholder. Adam minimizes the sum of the weighted synapse loss and the
    boundary loss. An unweighted synapse loss is also built and logged.

    Side effects:
        * writes the graph definition to ``unet.meta``;
        * writes ``net_io_names.json``, a JSON object mapping a role name to
          the graph name of the corresponding tensor/op.

    NOTE: statement order matters here. TF auto-names ops by creation order,
    and those names are what end up in ``net_io_names.json``.
    """
    in_shape = (43, 430, 430)
    raw = tf.placeholder(tf.float32, shape=in_shape)
    # Add two leading singleton axes (presumably batch and channel).
    raw_5d = tf.reshape(raw, (1, 1) + in_shape)

    # NOTE(review): the roles of these arguments are inferred from the call
    # pattern (3 factor triples, 4 levels of kernel pairs); confirm against
    # unet.unet. Down/up lists are built separately so they are distinct
    # objects.
    level_factors = [[1, 3, 3], [1, 3, 3], [3, 3, 3]]
    level_kernels = ((1, 3, 3), (1, 3, 3), (3, 3, 3), (3, 3, 3))
    kernels_down = [[k, k] for k in level_kernels]
    kernels_up = [[k, k] for k in level_kernels]
    features, receptive, anisotropy = unet.unet(
        raw_5d, 12, 6, level_factors, kernels_down, kernels_up,
        voxel_size=(10, 1, 1), fov=(10, 1, 1))

    # 1x1x1 convolution down to two output maps.
    dist_maps, receptive = unet.conv_pass(
        features,
        kernel_size=[[1, 1, 1]],
        num_fmaps=2,
        activation=None,
        fov=receptive,
        voxel_size=anisotropy,
    )
    # Unstacking removes axis 1; the leading (batch) axis is kept.
    syn_dist, bdy_dist = tf.unstack(dist_maps, 2, axis=1)
    map_shape = syn_dist.get_shape().as_list()

    gt_syn_dist = tf.placeholder(tf.float32, shape=map_shape)
    gt_bdy_dist = tf.placeholder(tf.float32, shape=map_shape)
    # Weights are fed without the leading axis and reshaped to match the map.
    loss_weights = tf.placeholder(tf.float32, shape=map_shape[1:])
    weights_full = tf.reshape(loss_weights, shape=map_shape)

    loss_balanced_syn = tf.losses.mean_squared_error(
        gt_syn_dist, syn_dist, weights_full)
    loss_bdy = tf.losses.mean_squared_error(gt_bdy_dist, bdy_dist)
    loss_total = loss_balanced_syn + loss_bdy
    tf.summary.scalar('loss_balanced_syn', loss_balanced_syn)
    tf.summary.scalar('loss_bdy', loss_bdy)
    tf.summary.scalar('loss_total', loss_total)

    # Logged for comparison only; not part of the optimized objective.
    loss_unbalanced = tf.losses.mean_squared_error(gt_syn_dist, syn_dist)
    tf.summary.scalar('loss_unbalanced_syn', loss_unbalanced)

    adam = tf.train.AdamOptimizer(
        learning_rate=0.5e-4, beta1=0.95, beta2=0.999, epsilon=1e-8)
    train_op = adam.minimize(loss_total)
    summaries = tf.summary.merge_all()

    tf.train.export_meta_graph(filename='unet.meta')

    io_names = dict(
        raw=raw.name,
        syn_dist=syn_dist.name,
        bdy_dist=bdy_dist.name,
        gt_syn_dist=gt_syn_dist.name,
        gt_bdy_dist=gt_bdy_dist.name,
        loss_balanced_syn=loss_balanced_syn.name,
        loss_unbalanced_syn=loss_unbalanced.name,
        loss_bdy=loss_bdy.name,
        loss_total=loss_total.name,
        loss_weights=loss_weights.name,
        optimizer=train_op.name,
        summary=summaries.name,
    )
    with open('net_io_names.json', 'w') as out:
        json.dump(io_names, out)
def train_net():
    """Build and export a single-output distance-regression U-Net graph (TF1).

    The network emits one map, ``dist``, regressed with mean squared error
    against the ``gt_dist`` placeholder. Two losses are built and logged: one
    weighted by the ``loss_weights`` placeholder and one unweighted. Adam
    minimizes the unweighted loss.

    Side effects:
        * writes the graph definition to ``unet.meta``;
        * writes ``net_io_names.json``, a JSON object mapping a role name to
          the graph name of the corresponding tensor/op.

    NOTE: statement order matters here. TF auto-names ops by creation order,
    and those names are what end up in ``net_io_names.json``.
    """
    in_shape = (84, 268, 268)
    raw = tf.placeholder(tf.float32, shape=in_shape)
    # Add two leading singleton axes (presumably batch and channel).
    raw_5d = tf.reshape(raw, (1, 1) + in_shape)

    # NOTE(review): the roles of these arguments are inferred from the call
    # pattern (3 factor triples, 4 levels of kernel pairs); confirm against
    # unet.unet. Down/up lists are built separately so they are distinct
    # objects.
    level_factors = [[1, 3, 3], [1, 3, 3], [1, 3, 3]]
    kernels_down = [[(3, 3, 3), (3, 3, 3)] for _ in range(4)]
    kernels_up = [[(3, 3, 3), (3, 3, 3)] for _ in range(4)]
    features, receptive, anisotropy = unet.unet(
        raw_5d, 12, 6, level_factors, kernels_down, kernels_up,
        voxel_size=(10, 1, 1), fov=(10, 1, 1))

    # 1x1x1 convolution down to a single output map.
    dist_batched, receptive = unet.conv_pass(
        features,
        kernel_size=[[1, 1, 1]],
        num_fmaps=1,
        activation=None,
        fov=receptive,
        voxel_size=anisotropy,
    )
    dist_shape = dist_batched.get_shape().as_list()[1:]  # strip the batch dimension
    dist = tf.reshape(dist_batched, dist_shape)

    gt_dist = tf.placeholder(tf.float32, shape=dist_shape)
    # Weights are fed without the leading axis and reshaped to match dist.
    loss_weights = tf.placeholder(tf.float32, shape=dist_shape[1:])
    weights_full = tf.reshape(loss_weights, shape=dist_shape)

    loss_balanced = tf.losses.mean_squared_error(gt_dist, dist, weights_full)
    tf.summary.scalar('loss_balanced_syn', loss_balanced)
    loss_unbalanced = tf.losses.mean_squared_error(gt_dist, dist)
    tf.summary.scalar('loss_unbalanced_syn', loss_unbalanced)

    adam = tf.train.AdamOptimizer(
        learning_rate=0.5e-4, beta1=0.95, beta2=0.999, epsilon=1e-8)
    # NOTE(review): the *unweighted* loss is optimized. As a result,
    # loss_weights only affects the logged 'loss_balanced_syn' summary, not
    # training. Confirm this is intentional.
    train_op = adam.minimize(loss_unbalanced)
    summaries = tf.summary.merge_all()

    tf.train.export_meta_graph(filename='unet.meta')

    io_names = dict(
        raw=raw.name,
        dist=dist.name,
        gt_dist=gt_dist.name,
        loss_balanced_syn=loss_balanced.name,
        loss_unbalanced_syn=loss_unbalanced.name,
        loss_weights=loss_weights.name,
        optimizer=train_op.name,
        summary=summaries.name,
    )
    with open('net_io_names.json', 'w') as out:
        json.dump(io_names, out)