def make_net(labels, added_steps, mode='train', loss_name='loss_total'):
    """Build the fixed two-scale network and export its graph.

    Constructs two ``scale_net.SerialUNet`` instances, one with an input
    voxel size of (4, 4, 4) and one with (36, 36, 36), and passes them to
    ``make_any_scale_net``, which does all remaining graph construction.
    This function used to carry a verbatim copy of that logic; delegating
    keeps the two builders from drifting apart.

    Args:
        labels: Sequence of label objects, forwarded unchanged. The shared
            builder creates one output channel per label and, in training
            mode, reads ``labelname`` and ``class_weight`` from each.
        added_steps: Forwarded unchanged; scales the first U-Net's
            ``step_valid_shape`` that is added to its ``min_input_shape``.
        mode: 'train'/'training' or 'inference'/'prediction'
            (case-insensitive), forwarded unchanged.
        loss_name: Name of the total loss to minimize in training mode,
            forwarded unchanged.

    Returns:
        The ``scale_net.ScaleNet`` instance returned by
        ``make_any_scale_net``.

    Raises:
        ValueError: Propagated from ``make_any_scale_net`` for an unknown
            ``mode`` or ``loss_name``.
    """
    # NOTE(review): the positional arguments look like per-level feature-map
    # counts followed by kernel-size specifications -- confirm against the
    # scale_net.SerialUNet signature before changing any of them.
    unet0 = scale_net.SerialUNet(
        [12, 12 * 6, 12 * 6**2],
        [48, 12 * 6, 12 * 6**2],
        [(3, 3, 3), (3, 3, 3)],
        [
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
        ],
        [
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
        ],
        input_voxel_size=(4, 4, 4),
    )
    unet1 = scale_net.SerialUNet(
        [12, 12 * 6, 12 * 6**2],
        [12 * 6**2, 12 * 6**2, 12 * 6**2],
        [(3, 3, 3), (3, 3, 3)],
        [
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
        ],
        [
            [(3, 3, 3), (3, 3, 3)],
            [(3, 3, 3), (3, 3, 3)],
        ],
        input_voxel_size=(36, 36, 36),
    )

    # Everything past this point (ScaleNet assembly, output convolution,
    # losses, optimizer, name export, meta-graph export) is identical for any
    # list of U-Nets, so reuse the generic builder instead of duplicating it.
    return make_any_scale_net(
        [unet0, unet1], labels, added_steps, mode=mode, loss_name=loss_name
    )
def _add_training_ops(network_outputs, labels, output_shape, loss_name, names):
    """Add placeholders, losses, summaries and the optimizer for training.

    Args:
        network_outputs: One output tensor per label, in the same order as
            ``labels``.
        labels: Sequence of label objects providing ``labelname`` and
            ``class_weight``.
        output_shape: Shape used for every ground-truth, weight and mask
            placeholder.
        loss_name: Key of the total loss to minimize; one of 'loss_total',
            'loss_total_unbalanced', 'loss_total_classweighted',
            'loss_total_unbalanced_classweighted'.
        names: Dict mapping logical names to graph tensor/op names; updated
            in place with every placeholder, loss, the optimizer and the
            merged summary.

    Raises:
        ValueError: If ``loss_name`` is not one of the known total losses.
    """
    class_weights = [label.class_weight for label in labels]

    # One (mask, ground truth, weight) placeholder triple per label.
    masks, gt, w = [], [], []
    for _ in labels:
        masks.append(tf.placeholder(tf.float32, shape=output_shape))
        gt.append(tf.placeholder(tf.float32, shape=output_shape))
        w.append(tf.placeholder(tf.float32, shape=output_shape))

    # "balanced" losses weight the error by weight * mask, "unbalanced" ones
    # by the mask alone.
    balanced, unbalanced = [], []
    for output_it, gt_it, w_it, m_it, label in zip(
        network_outputs, gt, w, masks, labels
    ):
        balanced.append(
            tf.losses.mean_squared_error(gt_it, output_it, w_it * m_it)
        )
        unbalanced.append(tf.losses.mean_squared_error(gt_it, output_it, m_it))
        names[label.labelname] = output_it.name
        names["gt_" + label.labelname] = gt_it.name
        names["w_" + label.labelname] = w_it.name
        names["mask_" + label.labelname] = m_it.name

    for label, lb_it, lub_it in zip(labels, balanced, unbalanced):
        tf.summary.scalar("lb_" + label.labelname, lb_it)
        tf.summary.scalar("lub_" + label.labelname, lub_it)
        names["lb_" + label.labelname] = lb_it.name
        names["lub_" + label.labelname] = lub_it.name

    loss_total = tf.add_n(balanced)
    loss_total_unbalanced = tf.add_n(unbalanced)
    loss_total_classweighted = tf.tensordot(balanced, class_weights, axes=1)
    loss_total_unbalanced_classweighted = tf.tensordot(
        unbalanced, class_weights, axes=1
    )

    # A single table drives the summaries, the exported names and the choice
    # of the loss that is actually optimized.
    total_losses = [
        ("loss_total", loss_total),
        ("loss_total_unbalanced", loss_total_unbalanced),
        ("loss_total_classweighted", loss_total_classweighted),
        (
            "loss_total_unbalanced_classweighted",
            loss_total_unbalanced_classweighted,
        ),
    ]
    for key, tensor in total_losses:
        tf.summary.scalar(key, tensor)
        names[key] = tensor.name

    opt = tf.train.AdamOptimizer(
        learning_rate=0.5e-4, beta1=0.95, beta2=0.999, epsilon=1e-8
    )
    losses_by_name = dict(total_losses)
    if loss_name not in losses_by_name:
        raise ValueError(loss_name + " not defined")
    optimizer = opt.minimize(losses_by_name[loss_name])
    names["optimizer"] = optimizer.name

    merged = tf.summary.merge_all()
    names["summary"] = merged.name


def make_any_scale_net(
    serial_unet_list, labels, added_steps, mode="train", loss_name="loss_total"
):
    """Assemble a ScaleNet from the given U-Nets and export its graph.

    Builds a ``scale_net.ScaleNet`` from ``serial_unet_list``, creates one
    raw-input placeholder per scale, and appends a 1x1x1 convolution with one
    output feature map per label. In training mode it also adds the losses,
    summaries and Adam optimizer and writes the collected tensor/op names to
    ``net_io_names.json`` in the current working directory. In every valid
    mode the graph is exported to ``<scnet.name>.meta``.

    Args:
        serial_unet_list: U-Nets handed to ``scale_net.ScaleNet``. The first
            entry's ``min_input_shape`` and ``step_valid_shape`` determine
            the input size.
        labels: Sequence of label objects; its length sets the number of
            output feature maps. In training mode each label must provide
            ``labelname`` and ``class_weight``.
        added_steps: Multiplier applied to the first U-Net's
            ``step_valid_shape`` before adding it to ``min_input_shape``.
        mode: 'train'/'training' or 'inference'/'prediction', compared
            case-insensitively. The value is appended verbatim to 'scnet_'
            to form the network name.
        loss_name: Total loss to minimize in training mode; one of
            'loss_total', 'loss_total_unbalanced',
            'loss_total_classweighted',
            'loss_total_unbalanced_classweighted'.

    Returns:
        The constructed ``scale_net.ScaleNet`` instance.

    Raises:
        ValueError: If ``mode`` is not recognized, or if ``loss_name`` is not
            a known loss in training mode.
    """
    first_unet = serial_unet_list[0]
    input_size_actual = (
        first_unet.min_input_shape + added_steps * first_unet.step_valid_shape
    )
    scnet = scale_net.ScaleNet(
        serial_unet_list, input_size_actual, name="scnet_" + mode
    )

    inputs = []
    names = dict()
    for inp, vs in zip(scnet.input_shapes, scnet.voxel_sizes):
        raw = tf.placeholder(tf.float32, shape=inp)
        # Builtin int replaces the np.int alias, which was deprecated in
        # NumPy 1.20 and removed in 1.24; the resulting dtype is the same.
        # Two leading singleton axes are prepended (presumably batch and
        # channel -- confirm against scnet.build).
        raw_bc = tf.reshape(raw, (1, 1) + tuple(inp.astype(int)))
        inputs.append(raw_bc)
        names["raw_{0:}".format(vs[0])] = raw.name

    last_fmap, fov, anisotropy = scnet.build(inputs)

    # The updated field of view returned by the final convolution is unused.
    dist_bc, _ = ops3d.conv_pass(
        last_fmap,
        kernel_size=[(1, 1, 1)],
        num_fmaps=len(labels),
        activation=None,
        fov=fov,
        voxel_size=anisotropy,
    )

    # Strip the leading axis for the exported tensor, then one more axis to
    # get the per-label shape used by the training placeholders.
    output_shape_bc = dist_bc.get_shape().as_list()
    output_shape_c = output_shape_bc[1:]
    output_shape = output_shape_c[1:]

    dist_c = tf.reshape(dist_bc, output_shape_c)
    names["dist"] = dist_c.name
    network_outputs = tf.unstack(dist_c, len(labels), axis=0)

    mode_key = mode.lower()
    if mode_key in ("train", "training"):
        _add_training_ops(
            network_outputs, labels, output_shape, loss_name, names
        )
        with open("net_io_names.json", "w") as f:
            json.dump(names, f)
    elif mode_key not in ("inference", "prediction"):
        raise ValueError(
            "unknown mode for network construction: {0:}".format(mode)
        )

    tf.train.export_meta_graph(filename=scnet.name + ".meta")
    return scnet