def build_net(steps=steps_inference, mode='train'):
    """Assemble the two-level scale net and hand it to ``make_any_scale_net``.

    Two ``scale_net.SerialUNet`` instances are created, one with
    ``input_voxel_size=(4, 4, 4)`` and one with ``input_voxel_size=(36, 36, 36)``.
    They are passed, together with the module-level ``labels``, to
    ``make_any_scale_net``.

    Args:
        steps: forwarded unchanged to ``make_any_scale_net``. The default is the
            module-level ``steps_inference``, bound when this ``def`` executes.
        mode: forwarded as ``mode=`` to ``make_any_scale_net`` (default ``'train'``).

    Returns:
        Whatever ``make_any_scale_net`` returns.
    """
    k3 = (3, 3, 3)
    base, factor = 12, 6

    def conv_levels(count):
        # Build fresh inner lists per level, mirroring the original literals.
        return [[k3, k3] for _ in range(count)]

    # NOTE(review): if the third SerialUNet argument is the per-level
    # downsampling factors, (2, 2, 2) then (3, 3, 3) puts this net's coarsest
    # level at 4 * 2 * 3 = 24 per axis, not at the coarse net's input voxel
    # size of 36. The other U-Net setups in this file use (3, 3, 3) here.
    # Confirm this is intended.
    fine = scale_net.SerialUNet(
        [base, base * factor, base * factor ** 2],
        [base * factor, base * factor, base * factor ** 2],
        [(2, 2, 2), k3],
        conv_levels(3),
        conv_levels(2),
        input_voxel_size=(4, 4, 4),
    )
    coarse = scale_net.SerialUNet(
        [base, base * factor, base * factor ** 2],
        [base * factor ** 2] * 3,
        [k3, k3],
        conv_levels(3),
        conv_levels(2),
        input_voxel_size=(36, 36, 36),
    )
    return make_any_scale_net([fine, coarse], labels, steps, mode=mode)
# NOTE(review): this chunk is the interior of a script/function whose start
# and end lie outside this view. `labels`, `Label`, `data_sources`,
# `ribo_sources`, `data_dir`, `loss_name`, `scale_net`, `make_any_scale_net`
# and `tf` are all defined elsewhere.
#
# Register the labels to train on. Each Label gets a name, one or more label
# ids, the data sources it is drawn from, and the data directory.
# Note that id 31 appears in both "microtubules" and "centrosome", and ids 32/33
# appear both inside "centrosome" and as their own labels below — presumably
# intentional grouping; confirm.
labels.append(
    Label("microtubules", (30, 31), data_sources=data_sources, data_dir=data_dir)
)
labels.append(
    Label("centrosome", (31, 32, 33), data_sources=data_sources, data_dir=data_dir)
)
labels.append(Label("distal_app", 32, data_sources=data_sources, data_dir=data_dir))
labels.append(
    Label("subdistal_app", 33, data_sources=data_sources, data_dir=data_dir)
)
# Ribosomes are drawn from a separate set of sources (`ribo_sources`).
labels.append(Label("ribosomes", 1, data_sources=ribo_sources, data_dir=data_dir))
# U-Net with input voxel size 4. If the third argument is the per-level
# downsampling factors, 4 * 3 * 3 = 36 equals unet1's input voxel size.
unet0 = scale_net.SerialUNet(
    [12, 12 * 6, 12 * 6 ** 2],
    [48, 12 * 6, 12 * 6 ** 2],
    [(3, 3, 3), (3, 3, 3)],
    [[(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]],
    [[(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]],
    input_voxel_size=(4, 4, 4),
)
# U-Net with input voxel size 36.
unet1 = scale_net.SerialUNet(
    [12, 12 * 6, 12 * 6 ** 2],
    [12 * 6 ** 2, 12 * 6 ** 2, 12 * 6 ** 2],
    [(3, 3, 3), (3, 3, 3)],
    [[(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]],
    [[(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]],
    input_voxel_size=(36, 36, 36),
)
# Build the inference graph first (step count 4). Its return value is discarded,
# so presumably it is wanted for a side effect such as exporting a meta graph
# (TODO confirm). The default graph is then cleared before the training graph
# is built with step count 5.
make_any_scale_net([unet0, unet1], labels, 4, mode="inference")
tf.reset_default_graph()
# NOTE(review): this call is truncated at the end of the visible chunk.
train_sc_net = make_any_scale_net(
    [unet0, unet1], labels, 5, mode="train", loss_name=loss_name
def make_net(labels, added_steps, mode='train', loss_name='loss_total'):
    """Build the two-level scale-net graph, export its meta graph, return the net.

    The first U-Net takes input at voxel size 4 and the second at voxel size 36.
    A 1x1x1 convolution maps the last feature map to one output channel per
    label. In training mode the function also adds:

    * per-label MSE losses;
    * four total losses, with their summaries;
    * an Adam optimizer on the selected total loss.

    It then writes the tensor-name map to ``net_io_names.json``. In every mode
    the meta graph is exported to ``<scnet.name>.meta`` in the current working
    directory.

    Args:
        labels: label objects. Each must expose ``labelname`` and, in training
            mode, ``class_weight``.
        added_steps: number of ``unet0.step_valid_shape`` increments added to
            ``unet0.min_input_shape`` to obtain the actual input size.
        mode: ``'train'``/``'training'`` or ``'inference'``/``'prediction'``
            (case-insensitive).
        loss_name: total loss to minimise in training mode. One of
            ``'loss_total'``, ``'loss_total_unbalanced'``,
            ``'loss_total_classweighted'`` or
            ``'loss_total_unbalanced_classweighted'``.

    Returns:
        The constructed ``scale_net.ScaleNet``.

    Raises:
        ValueError: if ``mode`` or (in training mode) ``loss_name`` is unknown.
    """
    # If the third argument is the per-level downsampling factors,
    # 4 * 3 * 3 = 36 equals unet1's input voxel size.
    unet0 = scale_net.SerialUNet(
        [12, 12 * 6, 12 * 6**2],
        [48, 12 * 6, 12 * 6**2],
        [(3, 3, 3), (3, 3, 3)],
        [[(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]],
        [[(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]],
        input_voxel_size=(4, 4, 4))
    unet1 = scale_net.SerialUNet(
        [12, 12 * 6, 12 * 6**2],
        [12 * 6**2, 12 * 6**2, 12 * 6**2],
        [(3, 3, 3), (3, 3, 3)],
        [[(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]],
        [[(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]],
        input_voxel_size=(36, 36, 36))

    # Grow the minimal input by `added_steps` increments of the valid step size.
    input_size = unet0.min_input_shape
    input_size_actual = input_size + added_steps * unet0.step_valid_shape
    scnet = scale_net.ScaleNet([unet0, unet1], input_size_actual, name='scnet_' + mode)

    # One raw placeholder per scale. Each is reshaped with two leading singleton
    # axes and registered as "raw_<first voxel-size component>".
    inputs = []
    names = {}
    for inp, vs in zip(scnet.input_shapes, scnet.voxel_sizes):
        raw = tf.placeholder(tf.float32, shape=inp)
        # `np.int` was just an alias of the builtin `int`. It was removed in
        # NumPy 1.24 (AttributeError), so use `int` directly.
        raw_bc = tf.reshape(raw, (1, 1) + tuple(inp.astype(int)))
        inputs.append(raw_bc)
        names['raw_{0:}'.format(vs[0])] = raw.name

    last_fmap, fov, anisotropy = scnet.build(inputs)
    # 1x1x1 convolution with no activation: one output channel per label.
    dist_bc, fov = ops3d.conv_pass(last_fmap, kernel_size=[(1, 1, 1)], num_fmaps=len(labels),
                                   activation=None, fov=fov, voxel_size=anisotropy)
    output_shape_bc = dist_bc.get_shape().as_list()
    output_shape_c = output_shape_bc[1:]  # drop the leading (batch) axis
    output_shape = output_shape_c[1:]  # ...and then the channel axis
    dist_c = tf.reshape(dist_bc, output_shape_c)
    names['dist'] = dist_c.name
    network_outputs = tf.unstack(dist_c, len(labels), axis=0)  # one tensor per label

    mode_key = mode.lower()
    if mode_key in ('train', 'training'):
        gt = []
        w = []
        cw = []
        masks = []
        # Placeholder creation order (mask, gt, weight per label) is kept so
        # the auto-generated op names match previous graphs.
        for l in labels:
            masks.append(tf.placeholder(tf.float32, shape=output_shape))
            gt.append(tf.placeholder(tf.float32, shape=output_shape))
            w.append(tf.placeholder(tf.float32, shape=output_shape))
            cw.append(l.class_weight)

        # Per-label MSE. "lb" is weighted by weight * mask; "lub" by the mask only.
        lb = []
        lub = []
        for output_it, gt_it, w_it, m_it, l in zip(network_outputs, gt, w, masks, labels):
            lb.append(tf.losses.mean_squared_error(gt_it, output_it, w_it * m_it))
            lub.append(tf.losses.mean_squared_error(gt_it, output_it, m_it))
            names[l.labelname] = output_it.name
            names['gt_' + l.labelname] = gt_it.name
            names['w_' + l.labelname] = w_it.name
            names['mask_' + l.labelname] = m_it.name
        for l, lb_it, lub_it in zip(labels, lb, lub):
            tf.summary.scalar('lb_' + l.labelname, lb_it)
            tf.summary.scalar('lub_' + l.labelname, lub_it)
            names['lb_' + l.labelname] = lb_it.name
            names['lub_' + l.labelname] = lub_it.name

        # Total losses, created in the original order: plain sums, then
        # class-weighted dot products.
        # NOTE(review): tensordot assumes the class weights convert to the
        # losses' float dtype — TODO confirm.
        total_losses = [
            ('loss_total', tf.add_n(lb)),
            ('loss_total_unbalanced', tf.add_n(lub)),
            ('loss_total_classweighted', tf.tensordot(lb, cw, axes=1)),
            ('loss_total_unbalanced_classweighted', tf.tensordot(lub, cw, axes=1)),
        ]
        for key, tensor in total_losses:
            tf.summary.scalar(key, tensor)
            names[key] = tensor.name

        opt = tf.train.AdamOptimizer(learning_rate=0.5e-4, beta1=0.95, beta2=0.999, epsilon=1e-8)
        loss_by_name = dict(total_losses)
        if loss_name not in loss_by_name:
            raise ValueError(loss_name + " not defined")
        optimizer = opt.minimize(loss_by_name[loss_name])
        names['optimizer'] = optimizer.name
        merged = tf.summary.merge_all()
        names['summary'] = merged.name
        with open('net_io_names.json', 'w') as f:
            json.dump(names, f)
    elif mode_key in ('inference', 'prediction'):
        pass  # no extra ops needed for inference
    else:
        raise ValueError(
            "unknown mode for network construction: {0:}".format(mode))
    tf.train.export_meta_graph(filename=scnet.name + '.meta')
    return scnet