def create_network(input_shape, name):
    """Build the two-channel foreground-regression graph and save its description.

    Resets the default TensorFlow graph, then builds a U-Net followed by a
    1x1x1 sigmoid convolution whose single-channel output is trained with a
    weighted mean-squared-error loss and an Adam optimizer.

    Side effects: writes ``<name>.meta``, ``train_net.meta``,
    ``<name>_names.json`` and ``<name>_config.json``; prints the input and
    output shapes.

    Args:
        input_shape: spatial shape (d, h, w); it is concatenated with ``(2,)``
            so it must be a tuple.
        name: path prefix of the written files.
    """
    tf.reset_default_graph()

    # Two-channel input (c=2, d, h, w), batched as (b=1, c=2, d, h, w).
    raw_input = tf.placeholder(tf.float32, shape=(2,) + input_shape)
    batched_input = tf.reshape(raw_input, (1, 2) + input_shape)

    unet_result = unet(batched_input, 12, 5, [[1, 3, 3], [1, 2, 2], [2, 2, 2]])
    head_result = conv_pass(
        unet_result[0], kernel_sizes=(1,), num_fmaps=1, activation="sigmoid"
    )
    fg_batched = head_result[0]

    # Keep only (d, h, w): the batch axis and the single output channel are dropped.
    spatial_shape = tuple(fg_batched.get_shape().as_list()[2:])
    fg = tf.reshape(fg_batched, spatial_shape)

    labels_fg = tf.placeholder(tf.float32, shape=spatial_shape)
    loss_weights = tf.placeholder(tf.float32, shape=spatial_shape)
    loss = tf.losses.mean_squared_error(labels_fg, fg, loss_weights)

    adam = tf.train.AdamOptimizer(
        learning_rate=0.5e-4, beta1=0.95, beta2=0.999, epsilon=1e-8
    )
    train_op = adam.minimize(loss)

    print(f"input shape: {input_shape}")
    print(f"output shape: {spatial_shape}")

    # The graph is saved under the caller's prefix and, additionally, under
    # the fixed file name "train_net.meta".
    for meta_file in (name + ".meta", "train_net.meta"):
        tf.train.export_meta_graph(filename=meta_file)

    tensor_names = dict(
        raw=raw_input.name,
        fg=fg.name,
        loss_weights=loss_weights.name,
        loss=loss.name,
        optimizer=train_op.name,
        labels_fg=labels_fg.name,
    )
    shapes = dict(input_shape=input_shape, output_shape=spatial_shape, out_dims=1)

    for suffix, payload in (("_names.json", tensor_names), ("_config.json", shapes)):
        with open(name + suffix, "w") as f:
            json.dump(payload, f)
def create_network(input_shape, name, run):
    """Build the soft-mask U-Net training graph and export it under ./checkpoints.

    Resets the default TF graph. The network predicts a single-channel soft
    mask that is trained (weighted MSE, Adam) against channel 9 of a
    10-channel ground-truth placeholder ``gt_lsds``. Max-detection ops are
    built for both the ground-truth and the predicted soft mask.

    Writes ``<output_dir>/<name>.meta`` and ``<output_dir>/<name>.json`` where
    ``output_dir`` is ``./checkpoints/run_<run>`` (created if missing), and
    prints the input/output shapes.

    Args:
        input_shape: tuple spatial shape of the raw input.
        name: file stem of the meta-graph and JSON config.
        run: run identifier formatted into the output directory name.
    """
    tf.reset_default_graph()

    raw = tf.placeholder(tf.float32, shape=input_shape)
    # (b=1, c=1) + input_shape
    raw_batched = tf.reshape(raw, (1, 1) + input_shape)

    out, _, _ = models.unet(raw_batched, 12, 5, [[1, 3, 3], [1, 3, 3], [1, 3, 3]])
    soft_mask_batched, _ = models.conv_pass(
        out, kernel_sizes=[1], num_fmaps=1, activation=None)

    output_shape_batched = soft_mask_batched.get_shape().as_list()
    output_shape = output_shape_batched[1:]  # strip the batch dimension
    soft_mask = tf.reshape(soft_mask_batched, output_shape[1:])  # drop channel too

    # Only channel 9 of the 10-channel ground truth is used by this network.
    gt_lsds = tf.placeholder(tf.float32, shape=[10] + list(output_shape[1:]))
    gt_soft_mask = gt_lsds[9, :, :, :]

    # max_detection receives a (1,) + mask_shape + (1,) view; its result is
    # reshaped back to the mask's shape.
    gt_maxima = tf.reshape(
        max_detection(
            tf.reshape(gt_soft_mask,
                       [1] + gt_soft_mask.get_shape().as_list() + [1]),
            [1, 1, 5, 5, 1], 0.5),
        gt_soft_mask.get_shape())
    pred_maxima = tf.reshape(
        max_detection(
            tf.reshape(soft_mask,
                       [1] + gt_soft_mask.get_shape().as_list() + [1]),
            [1, 1, 5, 5, 1], 0.5),
        gt_soft_mask.get_shape())

    # Soft weights for binary mask: background voxels get weight 1, foreground
    # voxels get 1 + n_fg - 1 = n_fg (the total number of foreground voxels).
    # NOTE(review): this is not a fg/bg balancing ratio -- confirm it is intended.
    binary_mask = tf.cast(gt_soft_mask > 0, tf.float32)
    loss_weights_soft_mask = tf.ones(binary_mask.get_shape())
    loss_weights_soft_mask += tf.multiply(binary_mask, tf.reduce_sum(binary_mask))
    loss_weights_soft_mask -= binary_mask

    # tf.losses signature is (labels, predictions, weights); MSE is symmetric,
    # so the value is unchanged, but the arguments now read correctly.
    loss = tf.losses.mean_squared_error(gt_soft_mask, soft_mask,
                                        loss_weights_soft_mask)
    summary = tf.summary.scalar('loss', loss)

    opt = tf.train.AdamOptimizer(learning_rate=0.5e-4, beta1=0.95,
                                 beta2=0.999, epsilon=1e-8)
    optimizer = opt.minimize(loss)

    # Spatial output shape only (channel axis removed).
    output_shape = output_shape[1:]
    print("input shape : %s" % (input_shape, ))
    print("output shape: %s" % (output_shape, ))

    output_dir = "./checkpoints/run_{}".format(run)
    # exist_ok avoids the race between an exists() check and makedirs().
    os.makedirs(output_dir, exist_ok=True)

    tf.train.export_meta_graph(filename=os.path.join(output_dir, name + '.meta'))

    config = {
        'raw': raw.name,
        'soft_mask': soft_mask.name,
        'gt_lsds': gt_lsds.name,
        'gt_maxima': gt_maxima.name,
        'pred_maxima': pred_maxima.name,
        'loss_weights_soft_mask': loss_weights_soft_mask.name,
        'loss': loss.name,
        'optimizer': optimizer.name,
        'input_shape': input_shape,
        'output_shape': output_shape,
        'summary': summary.name,
    }
    with open(os.path.join(output_dir, name + '.json'), 'w') as f:
        json.dump(config, f)
def mknet(parameter, name):
    """Build the synaptic-partner U-Net training graph and export it.

    Reads hyper-parameters from ``parameter``, resets the default TF graph and
    builds a single- or dual-headed U-Net with two 1x1x1 heads: a 3-channel
    partner-vector regression and a 1-channel synapse-indicator logit map.
    The two losses are combined with ``m_loss_scale`` / ``d_loss_scale`` and
    minimised with Adam. Writes ``<name>.meta`` and ``<name>_config.json``, then
    prints the number of trainable parameters and an estimate of their size.

    Args:
        parameter: dict with keys 'learning_rate', 'input_size',
            'fmap_inc_factor', 'downsample_factors', 'fmap_num', 'unet_model'
            ('vanilla' or 'dh_unet'), 'm_loss_scale', 'd_loss_scale' and
            'voxel_size'.
        name: path prefix of the written files.

    Raises:
        AssertionError: if ``parameter['unet_model']`` is not a known model.
    """
    learning_rate = parameter['learning_rate']
    input_shape = tuple(parameter['input_size'])
    fmap_inc_factor = parameter['fmap_inc_factor']
    downsample_factors = parameter['downsample_factors']
    fmap_num = parameter['fmap_num']
    unet_model = parameter['unet_model']
    num_heads = 2 if unet_model == 'dh_unet' else 1
    m_loss_scale = parameter['m_loss_scale']
    d_loss_scale = parameter['d_loss_scale']
    # Only needed for computing the field of view; no impact on the
    # actual architecture.
    voxel_size = tuple(parameter['voxel_size'])

    assert unet_model == 'vanilla' or unet_model == 'dh_unet', \
        'unknown unetmodel {}'.format(unet_model)

    tf.reset_default_graph()

    # d, h, w
    raw = tf.placeholder(tf.float32, shape=input_shape)
    # b=1, c=1, d, h, w
    raw_batched = tf.reshape(raw, (1, 1) + input_shape)

    # b=1, c=fmap_num, d, h, w
    outputs, fov, voxel_size = models.unet(raw_batched, fmap_num,
                                           fmap_inc_factor,
                                           downsample_factors,
                                           num_heads=num_heads,
                                           voxel_size=voxel_size)
    if num_heads == 1:
        # A single head: both prediction branches read the same features.
        outputs = (outputs, outputs)
    print('unet has fov in nm: ', fov)

    # b=1, c=3, d, h, w
    partner_vectors_batched, fov = models.conv_pass(
        outputs[0],
        kernel_sizes=[1],
        num_fmaps=3,
        activation=None,  # Regression
        name='partner_vector')

    # b=1, c=1, d, h, w
    syn_indicator_batched, fov = models.conv_pass(outputs[1],
                                                  kernel_sizes=[1],
                                                  num_fmaps=1,
                                                  activation=None,
                                                  name='syn_indicator')
    print('fov in nm: ', fov)

    # d, h, w: strip batch and channel dimension.
    output_shape = tuple(syn_indicator_batched.get_shape().as_list()[2:])
    syn_indicator_shape = output_shape
    # c=3, d, h, w
    partner_vectors_shape = (3, ) + syn_indicator_shape

    # c=3, d, h, w
    pred_partner_vectors = tf.reshape(partner_vectors_batched,
                                      partner_vectors_shape)
    gt_partner_vectors = tf.placeholder(tf.float32,
                                        shape=partner_vectors_shape)
    # The fed placeholder is kept in its own variable so its name (not the
    # internal Cast tensor's) is exported under 'vectors_mask'.
    vectors_mask_placeholder = tf.placeholder(tf.float32,
                                              shape=syn_indicator_shape)  # d,h,w
    gt_mask = tf.placeholder(tf.bool, shape=syn_indicator_shape)  # d,h,w
    vectors_mask = tf.cast(vectors_mask_placeholder, tf.bool)

    # d, h, w: squeeze batch dimension
    pred_syn_indicator = tf.reshape(syn_indicator_batched, syn_indicator_shape)
    gt_syn_indicator = tf.placeholder(tf.float32, shape=syn_indicator_shape)
    indicator_weight = tf.placeholder(tf.float32, shape=syn_indicator_shape)

    # Vector MSE restricted to voxels where the mask is set.
    partner_vectors_loss_mask = tf.losses.mean_squared_error(
        gt_partner_vectors,
        pred_partner_vectors,
        tf.reshape(vectors_mask, (1, ) + syn_indicator_shape),
        reduction=tf.losses.Reduction.SUM_BY_NONZERO_WEIGHTS)

    syn_indicator_loss_weighted = tf.losses.sigmoid_cross_entropy(
        gt_syn_indicator,
        pred_syn_indicator,
        indicator_weight,
        reduction=tf.losses.Reduction.SUM_BY_NONZERO_WEIGHTS)
    pred_syn_indicator_out = tf.sigmoid(pred_syn_indicator)  # For output.

    # NOTE(review): float-valued step counter; kept as-is because changing its
    # dtype would break existing checkpoints.
    iteration = tf.Variable(1.0, name='training_iteration', trainable=False)
    loss = (m_loss_scale * syn_indicator_loss_weighted
            + d_loss_scale * partner_vectors_loss_mask)

    # Monitor in tensorboard.
    tf.summary.scalar('loss', loss)
    tf.summary.scalar('loss_vectors', partner_vectors_loss_mask)
    tf.summary.scalar('loss_indicator', syn_indicator_loss_weighted)
    summary = tf.summary.merge_all()

    print("input shape : %s" % (input_shape, ))
    print("output shape: %s" % (output_shape, ))

    # Train Ops.
    opt = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                 beta1=0.95,
                                 beta2=0.999,
                                 epsilon=1e-8)
    gvs_ = opt.compute_gradients(loss)
    optimizer = opt.apply_gradients(gvs_, global_step=iteration)

    tf.train.export_meta_graph(filename=name + '.meta')

    names = {
        'raw': raw.name,
        'gt_partner_vectors': gt_partner_vectors.name,
        'pred_partner_vectors': pred_partner_vectors.name,
        'gt_syn_indicator': gt_syn_indicator.name,
        'pred_syn_indicator': pred_syn_indicator.name,
        'pred_syn_indicator_out': pred_syn_indicator_out.name,
        'indicator_weight': indicator_weight.name,
        'vectors_mask': vectors_mask_placeholder.name,
        'gt_mask': gt_mask.name,
        'loss': loss.name,
        'optimizer': optimizer.name,
        'summary': summary.name,
        'input_shape': input_shape,
        'output_shape': output_shape
    }

    names['outputs'] = {
        'pred_syn_indicator_out': {
            "out_dims": 1,
            "out_dtype": "uint8"
        },
        'pred_partner_vectors': {
            "out_dims": 3,
            "out_dtype": "float32"
        }
    }
    # A head whose loss is disabled is not advertised as an output.
    if m_loss_scale == 0:
        names['outputs'].pop('pred_syn_indicator_out')
    if d_loss_scale == 0:
        names['outputs'].pop('pred_partner_vectors')

    with open(name + '_config.json', 'w') as f:
        json.dump(names, f)

    # Count trainable parameters and estimate their memory footprint from each
    # variable's actual element size (e.g. 4 bytes for float32) instead of a
    # hard-coded 8 bytes per parameter.
    total_parameters = 0
    total_bytes = 0
    for variable in tf.trainable_variables():
        num_elements = variable.get_shape().num_elements()
        total_parameters += num_elements
        total_bytes += num_elements * variable.dtype.base_dtype.size
    print("Number of parameters:", total_parameters)
    print("Estimated size of parameters in GB:",
          float(total_bytes) / (1024 * 1024 * 1024))
def mk_net(**kwargs):
    """Build a patch-logit U-Net training graph and write it to disk.

    Resets the default graph. A U-Net followed by a 1x1x1 convolution named
    "output" yields ``prod(kwargs['patchshape'])`` logit channels. These are
    trained with unweighted sigmoid cross-entropy against the ``gt_affs``
    placeholder using Adam.

    Writes ``<output_folder>/<name>.meta``, ``<name>_names.json`` (tensor
    names) and ``<name>_config.json`` (shapes). Prints the U-Net and output
    tensors.

    Required kwargs: input_shape, num_channels, num_fmaps, fmap_inc_factors,
    fmap_dec_factors, downsample_factors, activation, padding, kernel_size,
    num_repetitions, upsampling, patchshape, lr, output_folder, name.
    Optional: crop_factor (defaults to True).
    """
    tf.reset_default_graph()

    shape_arg = kwargs['input_shape']
    input_shape = shape_arg if isinstance(shape_arg, tuple) else tuple(shape_arg)

    # 3D raw input, batched as (1, num_channels) + input_shape for the U-Net.
    raw = tf.placeholder(tf.float32, shape=input_shape, name="raw")
    raw_batched = tf.reshape(raw, (1, kwargs['num_channels']) + input_shape)

    unet_keys = ('num_fmaps', 'fmap_inc_factors', 'fmap_dec_factors',
                 'downsample_factors', 'activation', 'padding', 'kernel_size',
                 'num_repetitions', 'upsampling')
    unet_args = {key: kwargs[key] for key in unet_keys}
    unet_args['crop_factor'] = kwargs.get('crop_factor', True)
    features, _, _ = unet(raw_batched, **unet_args)
    print(features)

    # One output channel per element of the patch.
    patch_channels = np.prod(kwargs['patchshape'])
    features, _ = conv_pass(features,
                            kernel_sizes=[1],
                            num_fmaps=patch_channels,
                            padding=kwargs['padding'],
                            activation=None,
                            name="output")
    print(features)

    # Drop the batch axis -> (channels, d, h, w); spatial part is the output shape.
    logits = tf.squeeze(features, axis=0)
    output_shape = logits.get_shape().as_list()[1:]
    pred_affs = tf.sigmoid(logits)
    raw_cropped = crop(raw, output_shape)

    # Ground-truth placeholders sized to match the prediction.
    gt_affs_shape = pred_affs.get_shape().as_list()
    gt_affs = tf.placeholder(tf.float32, shape=gt_affs_shape, name="gt_affs")
    anchor = tf.placeholder(tf.float32, shape=[1] + output_shape, name="anchor")

    loss = tf.losses.sigmoid_cross_entropy(gt_affs, logits)
    summaries = tf.summary.merge([tf.summary.scalar('loss_sum', loss)],
                                 name="summaries")

    # Defaults to kwargs['lr'] unless a value is fed for this tensor.
    learning_rate = tf.placeholder_with_default(kwargs['lr'], shape=(),
                                                name="learning-rate")
    adam = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=0.95,
                                  beta2=0.999, epsilon=1e-8)
    optimizer = adam.minimize(loss)

    prefix = os.path.join(kwargs['output_folder'], kwargs['name'])
    tf.train.export_meta_graph(filename=prefix + '.meta')

    tensor_names = {
        'raw': raw.name,
        'raw_cropped': raw_cropped.name,
        'gt_affs': gt_affs.name,
        'pred_affs': pred_affs.name,
        'anchor': anchor.name,
        'loss': loss.name,
        'optimizer': optimizer.name,
        'summaries': summaries.name,
    }
    shape_info = {
        'input_shape': input_shape,
        'gt_affs_shape': gt_affs_shape,
        'output_shape': output_shape,
    }
    for suffix, payload in (('_names.json', tensor_names),
                            ('_config.json', shape_info)):
        with open(prefix + suffix, 'w') as f:
            json.dump(payload, f)
def mk_net(**kwargs):
    """Build a MALIS-affinity + foreground/background training graph and export it.

    Resets the default TF graph and builds a U-Net whose 4-channel output is
    split into 3 affinity channels and 1 fg/bg channel. It attaches a MALIS
    loss on the affinities and a weighted fg/bg loss (via
    ``util.get_loss_weighted``), then minimises their weighted sum with Adam.
    It writes ``<output_folder>/<name>.meta``, ``..._names.json`` and
    ``..._config.json``.

    Keyword Args (read with ``kwargs[...]``, so a missing key raises KeyError):
        input_shape: spatial shape of the raw input; converted to a tuple.
        num_fmaps, fmap_inc_factors, fmap_dec_factors, downsample_factors,
        activation, padding, kernel_size, num_repetitions, upsampling:
            forwarded to ``unet`` (``padding`` also to ``conv_pass``).
        loss: loss identifier forwarded to ``util.get_loss_weighted``.
        loss_malis_coeff, loss_fgbg_coeff: weights of the two loss terms.
        lr: default learning rate of the "learning-rate" placeholder.
        output_folder, name: location/stem of the written files.
    """
    tf.reset_default_graph()

    input_shape = kwargs['input_shape']
    if not isinstance(input_shape, tuple):
        # a tuple is needed for the `(1, 1) + input_shape` concatenation below
        input_shape = tuple(input_shape)

    # create a placeholder for the 3D raw input tensor
    raw = tf.placeholder(tf.float32, shape=input_shape, name="raw")

    # create a U-Net on the input reshaped to (b=1, c=1) + input_shape
    raw_batched = tf.reshape(raw, (1, 1) + input_shape)
    # unet_output = unet(raw_batched, 14, 4, [[1,3,3],[1,3,3],[1,3,3]])
    model, _, _ = unet(raw_batched,
                       num_fmaps=kwargs['num_fmaps'],
                       fmap_inc_factors=kwargs['fmap_inc_factors'],
                       fmap_dec_factors=kwargs['fmap_dec_factors'],
                       downsample_factors=kwargs['downsample_factors'],
                       activation=kwargs['activation'],
                       padding=kwargs['padding'],
                       kernel_size=kwargs['kernel_size'],
                       num_repetitions=kwargs['num_repetitions'],
                       upsampling=kwargs['upsampling'])
    print(model)

    # add a 1x1x1 convolution layer that creates 4 output maps: 3 affinity
    # channels (one per offset in `neighborhood` below) plus 1 fg/bg channel;
    # they are separated by the tf.split further down
    model, _ = conv_pass(
        model,
        kernel_sizes=[1],
        num_fmaps=4,
        # num_repetitions=1,
        padding=kwargs['padding'],
        activation=None,
        name="output")
    print(model)

    # the 4D output tensor (channels, depth, height, width)
    pred = tf.squeeze(model, axis=0)
    output_shape = pred.get_shape().as_list()[1:]
    # channels 0-2 -> affinities, channel 3 -> fg/bg
    pred_affs, pred_fgbg = tf.split(pred, [3, 1], 0)

    # raw cropped to the output's spatial extent, with a leading channel axis
    raw_cropped = crop(raw, output_shape)
    raw_cropped = tf.expand_dims(raw_cropped, 0)

    # create a placeholder for the corresponding ground-truth affinities
    gt_affs = tf.placeholder(tf.float32, shape=pred_affs.get_shape(),
                             name="gt_affs")
    # ground-truth label image; squeezed below into the MALIS segmentation
    gt_labels = tf.placeholder(tf.float32, shape=pred_fgbg.get_shape(),
                               name="gt_labels")
    # ground-truth fg/bg target for the fg/bg loss
    gt_fgbg = tf.placeholder(tf.float32, shape=pred_fgbg.get_shape(),
                             name="gt_fgbg")
    # not used by any op built here; only its name is exported
    anchor = tf.placeholder(tf.float32, shape=pred_fgbg.get_shape(),
                            name="anchor")
    # gt_fgbg = tf.clip_by_value(gt_labels, 0, 1)

    # create a placeholder for per-voxel loss weights
    # loss_weights_affs = tf.placeholder(
    #     tf.float32,
    #     shape=pred_affs.get_shape(),
    #     name="loss_weights_affs")
    loss_weights_fgbg = tf.placeholder(tf.float32, shape=pred_fgbg.get_shape(),
                                       name="loss_weights_fgbg")

    # compute the weighted fg/bg loss (loss type selected by kwargs['loss'])
    # between predicted and ground-truth fg/bg. get_loss_weighted also returns
    # a tensor that replaces pred_fgbg; that tensor is what gets exported as
    # 'pred_fgbg' (presumably the activated prediction -- TODO confirm in util).
    loss_fgbg, pred_fgbg, loss_fgbg_print = \
        util.get_loss_weighted(gt_fgbg, pred_fgbg, loss_weights_fgbg,
                               kwargs['loss'], "fgbg", True)

    # affinity offsets along z, y and x, one per affinity channel
    neighborhood = [[-1, 0, 0], [0, -1, 0], [0, 0, -1]]
    gt_seg = tf.squeeze(gt_labels, axis=0, name="gt_seg")
    # give the affinity prediction a stable, explicit tensor name
    pred_affs = tf.identity(pred_affs, name="pred_affs")
    loss_malis = malis.malis_loss_op(pred_affs, gt_affs, gt_seg,
                                     neighborhood, name="malis_loss")
    loss = (kwargs['loss_malis_coeff'] * loss_malis +
            kwargs['loss_fgbg_coeff'] * loss_fgbg)

    # the summaries log the already-scaled loss terms
    loss_malis_sum = tf.summary.scalar('loss_malis_sum',
                                       kwargs['loss_malis_coeff'] * loss_malis)
    loss_fgbg_sum = tf.summary.scalar('loss_fgbg_sum',
                                      kwargs['loss_fgbg_coeff'] * loss_fgbg)
    loss_sum = tf.summary.scalar('loss_sum', loss)
    summaries = tf.summary.merge([loss_malis_sum, loss_fgbg_sum, loss_sum],
                                 name="summaries")

    # defaults to kwargs['lr'] unless a value is fed for this tensor
    learning_rate = tf.placeholder_with_default(kwargs['lr'], shape=(),
                                                name="learning-rate")
    # use the Adam optimizer to minimize the loss
    opt = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                 beta1=0.95,
                                 beta2=0.999,
                                 epsilon=1e-8)
    optimizer = opt.minimize(loss)

    # store the network in a meta-graph file
    tf.train.export_meta_graph(
        filename=os.path.join(kwargs['output_folder'],
                              kwargs['name'] + '.meta'))

    # store network configuration for use in train and predict scripts
    fn = os.path.join(kwargs['output_folder'], kwargs['name'])
    names = {
        'raw': raw.name,
        'raw_cropped': raw_cropped.name,
        'pred_affs': pred_affs.name,
        'gt_affs': gt_affs.name,
        'gt_labels': gt_labels.name,
        # 'loss_weights_affs': loss_weights_affs.name,
        'pred_fgbg': pred_fgbg.name,
        'gt_fgbg': gt_fgbg.name,
        'anchor': anchor.name,
        'loss_weights_fgbg': loss_weights_fgbg.name,
        'loss': loss.name,
        'optimizer': optimizer.name,
        'summaries': summaries.name
    }
    with open(fn + '_names.json', 'w') as f:
        json.dump(names, f)

    config = {
        'input_shape': input_shape,
        'output_shape': output_shape,
    }
    with open(fn + '_config.json', 'w') as f:
        json.dump(config, f)
def mk_net(**kwargs):
    """Build a three-class (softmax) U-Net training graph and export it.

    Resets the default TF graph. A U-Net plus a 1x1x1 "output" convolution
    produces 3 class-logit maps. The loss comes from
    ``util.get_loss_weighted`` on channels-last logits and weights, and is
    minimised with Adam. Writes ``<output_folder>/<name>.meta``,
    ``..._names.json`` and ``..._config.json``.

    Keyword Args:
        input_shape: spatial shape of the raw input; converted to a tuple.
        num_fmaps, fmap_inc_factors, fmap_dec_factors, downsample_factors,
        activation, padding, kernel_size, num_repetitions, upsampling:
            forwarded to ``unet`` (``padding`` also to ``conv_pass``).
        crop_factor: optional, forwarded to ``unet`` (default True).
        loss: loss identifier forwarded to the ``util`` loss helpers.
        debug: optional (default False); when true, print ops returned by the
            loss helpers become control dependencies of the loss.
        lr: default learning rate of the "learning-rate" placeholder.
        output_folder, name: location/stem of the written files.
    """
    tf.reset_default_graph()

    input_shape = kwargs['input_shape']
    if not isinstance(input_shape, tuple):
        input_shape = tuple(input_shape)

    # create a placeholder for the 3D raw input tensor
    raw = tf.placeholder(tf.float32, shape=input_shape, name="raw")

    # create a U-Net on the input reshaped to (b=1, c=1) + input_shape
    raw_batched = tf.reshape(raw, (1, 1) + input_shape)
    model, _, _ = unet(raw_batched,
                       num_fmaps=kwargs['num_fmaps'],
                       fmap_inc_factors=kwargs['fmap_inc_factors'],
                       fmap_dec_factors=kwargs['fmap_dec_factors'],
                       downsample_factors=kwargs['downsample_factors'],
                       activation=kwargs['activation'],
                       padding=kwargs['padding'],
                       kernel_size=kwargs['kernel_size'],
                       num_repetitions=kwargs['num_repetitions'],
                       upsampling=kwargs['upsampling'],
                       crop_factor=kwargs.get('crop_factor', True))
    print(model)

    # 1x1x1 convolution producing one logit map per class (3 classes)
    model, _ = conv_pass(
        model,
        kernel_sizes=[1],
        num_fmaps=3,
        padding=kwargs['padding'],
        activation=None,
        name="output")
    print(model)

    # the 4D output tensor (channels, depth, height, width)
    pred_threeclass = tf.squeeze(model, axis=0)
    output_shape = pred_threeclass.get_shape().as_list()[1:]

    # hard class prediction, with a leading singleton channel axis
    pred_class_max = tf.argmax(pred_threeclass, axis=0, output_type=tf.int32)
    pred_class_max = tf.expand_dims(pred_class_max, 0)

    # softmax probability of class 0, with a leading singleton channel axis
    pred_fgbg = tf.nn.softmax(pred_threeclass, dim=0)[0]
    pred_fgbg = tf.expand_dims(pred_fgbg, 0)

    # raw cropped to the output's spatial extent, with a leading channel axis
    raw_cropped = crop(raw, output_shape)
    raw_cropped = tf.expand_dims(raw_cropped, 0)

    # create a placeholder for the corresponding ground-truth class labels
    gt_threeclass = tf.placeholder(tf.int32, shape=[1] + output_shape,
                                   name="gt_threeclass")
    gt_threeclass_squeezed = tf.squeeze(gt_threeclass, 0)
    # not used by any op built here; only its name is exported
    anchor = tf.placeholder(tf.float32, shape=gt_threeclass.get_shape(),
                            name="anchor")

    # create a placeholder for per-voxel loss weights
    loss_weights_threeclass = tf.placeholder(
        tf.float32,
        shape=gt_threeclass.get_shape(),
        name="loss_weights_threeclass")

    # the loss helper receives logits and weights transposed to channels-last
    loss_threeclass, _, loss_threeclass_print = \
        util.get_loss_weighted(gt_threeclass_squeezed,
                               tf.transpose(pred_threeclass, [1, 2, 3, 0]),
                               tf.transpose(loss_weights_threeclass,
                                            [1, 2, 3, 0]),
                               kwargs['loss'], "threeclass", False)

    # 'debug' is optional: omitting it no longer raises KeyError.
    if kwargs.get('debug', False):
        # also compute the unweighted loss, only for its print ops
        _, _, loss_threeclass_print2 = \
            util.get_loss(gt_threeclass_squeezed,
                          tf.transpose(pred_threeclass, [1, 2, 3, 0]),
                          kwargs['loss'], "threeclass", False)
        # presumably both helpers return lists of ops -- TODO confirm in util
        print_ops = loss_threeclass_print + loss_threeclass_print2
    else:
        print_ops = None

    # in debug mode the print ops run whenever the loss is evaluated
    with tf.control_dependencies(print_ops):
        loss = (1.0 * loss_threeclass)

    loss_sum = tf.summary.scalar('loss_sum', loss)
    summaries = tf.summary.merge([loss_sum], name="summaries")

    # defaults to kwargs['lr'] unless a value is fed for this tensor
    learning_rate = tf.placeholder_with_default(kwargs['lr'], shape=(),
                                                name="learning-rate")
    # use the Adam optimizer to minimize the loss
    opt = tf.train.AdamOptimizer(
        learning_rate=learning_rate,
        beta1=0.95,
        beta2=0.999,
        epsilon=1e-8)
    optimizer = opt.minimize(loss)

    # store the network in a meta-graph file
    tf.train.export_meta_graph(filename=os.path.join(kwargs['output_folder'],
                                                     kwargs['name'] + '.meta'))

    # store network configuration for use in train and predict scripts
    fn = os.path.join(kwargs['output_folder'], kwargs['name'])
    names = {
        'raw': raw.name,
        'raw_cropped': raw_cropped.name,
        'gt_threeclass': gt_threeclass.name,
        'pred_threeclass': pred_threeclass.name,
        'pred_class_max': pred_class_max.name,
        'pred_fgbg': pred_fgbg.name,
        'anchor': anchor.name,
        'loss_weights_threeclass': loss_weights_threeclass.name,
        'loss': loss.name,
        'optimizer': optimizer.name,
        'summaries': summaries.name
    }
    with open(fn + '_names.json', 'w') as f:
        json.dump(names, f)

    config = {
        'input_shape': input_shape,
        'output_shape': output_shape,
    }
    with open(fn + '_config.json', 'w') as f:
        json.dump(config, f)
def create_network(input_shape, name, setup, voxel_size=None,
                   nms_window=None, nms_threshold=0.5):
    """Build the 10-channel LSD U-Net graph (9 derivative channels + soft mask).

    Resets the default TF graph. The network output is trained with a
    weighted MSE against the ``gt_lsds`` placeholder, and max-detection ops
    are built on the ground-truth and the (clipped) predicted soft mask.
    Writes ``<name>.meta`` and ``<name>.json`` (tensor names, shapes, output
    descriptions, voxel size) and prints several shapes.

    Args:
        input_shape: tuple spatial shape of the raw input.
        name: path prefix of the written files.
        setup: unused; kept for call compatibility (see the comment about
            the removed variable scope below).
        voxel_size: stored verbatim in the JSON config; defaults to
            ``[40, 4, 4]``.
        nms_window: window forwarded to ``max_detection``; defaults to
            ``[1, 1, 10, 10, 1]``.
        nms_threshold: threshold forwarded to ``max_detection``.
    """
    # Fresh lists per call instead of shared mutable default arguments.
    if voxel_size is None:
        voxel_size = [40, 4, 4]
    if nms_window is None:
        nms_window = [1, 1, 10, 10, 1]

    tf.reset_default_graph()

    # Useless and causes incompatibility - reinsert for legacy code
    # with tf.variable_scope('setup_{}'.format(setup)):
    raw = tf.placeholder(tf.float32, shape=input_shape)
    raw_batched = tf.reshape(raw, (1, 1) + input_shape)

    out, _, _ = models.unet(raw_batched, 12, 5,
                            [[1, 3, 3], [1, 3, 3], [1, 3, 3]])
    lsds_batched, _ = models.conv_pass(out, kernel_sizes=[1], num_fmaps=10,
                                       activation=None)

    output_shape_batched = lsds_batched.get_shape().as_list()
    output_shape = output_shape_batched[1:]  # strip the batch dimension
    lsds = tf.reshape(lsds_batched, output_shape)

    # Channel 9 is the soft mask (clipped to [0, 1] for output only; the loss
    # below uses the unclipped `lsds`), channels 0-8 the derivatives.
    soft_mask = lsds[9, :, :, :]
    soft_mask = tf.clip_by_value(soft_mask, 0, 1.0)
    derivatives = lsds[:9, :, :, :]

    gt_lsds = tf.placeholder(tf.float32, shape=output_shape)
    gt_soft_mask = gt_lsds[9, :, :, :]
    gt_derivatives = gt_lsds[:9, :, :, :]  # not used further

    print(gt_soft_mask.get_shape().as_list())
    print(soft_mask.get_shape().as_list())
    print(list(output_shape))
    print(list(output_shape_batched))

    # max_detection receives a (1,) + mask_shape + (1,) view of each mask.
    gt_maxima, gt_reduced_maxima = max_detection(
        tf.reshape(gt_soft_mask,
                   [1] + gt_soft_mask.get_shape().as_list() + [1]),
        nms_window, nms_threshold)
    pred_maxima, pred_reduced_maxima = max_detection(
        tf.reshape(soft_mask,
                   [1] + gt_soft_mask.get_shape().as_list() + [1]),
        nms_window, nms_threshold)

    # Soft weights for binary mask: 1 for background voxels, n_fg (the total
    # number of foreground voxels) for foreground voxels; repeated for all
    # 10 channels.
    binary_mask = tf.cast(gt_soft_mask > 0, tf.float32)
    loss_weights_soft_mask = tf.ones(binary_mask.get_shape())
    loss_weights_soft_mask += tf.multiply(binary_mask,
                                          tf.reduce_sum(binary_mask))
    loss_weights_soft_mask -= binary_mask
    loss_weights_lsds = tf.stack([loss_weights_soft_mask] * 10)

    loss = tf.losses.mean_squared_error(lsds, gt_lsds, loss_weights_lsds)
    summary = tf.summary.scalar('loss', loss)

    opt = tf.train.AdamOptimizer(learning_rate=0.5e-4, beta1=0.95,
                                 beta2=0.999, epsilon=1e-8)
    optimizer = opt.minimize(loss)

    # Spatial output shape only (channel axis removed).
    output_shape = output_shape[1:]
    print("input shape : %s" % (input_shape, ))
    print("output shape: %s" % (output_shape, ))

    tf.train.export_meta_graph(filename=name + '.meta')

    config = {
        'raw': raw.name,
        'derivatives': derivatives.name,
        'soft_mask': soft_mask.name,
        'gt_lsds': gt_lsds.name,
        'gt_maxima': gt_maxima.name,
        'gt_reduced_maxima': gt_reduced_maxima.name,
        'pred_maxima': pred_maxima.name,
        'pred_reduced_maxima': pred_reduced_maxima.name,
        # NOTE(review): stores the per-voxel (unstacked) soft-mask weights,
        # not the 10-channel `loss_weights_lsds` the key suggests -- confirm.
        'loss_weights_lsds': loss_weights_soft_mask.name,
        'loss': loss.name,
        'optimizer': optimizer.name,
        'input_shape': input_shape,
        'output_shape': output_shape,
        'summary': summary.name,
    }

    config['outputs'] = {
        'soft_mask': {
            "out_dims": 1,
            "out_dtype": "uint8"
        },
        'derivatives': {
            "out_dims": 9,
            "out_dtype": "uint8"
        },
        'reduced_maxima': {
            "out_dims": 1,
            "out_dtype": "uint8"
        }
    }
    # Every other key in 'outputs' also appears as a top-level tensor-name
    # key; expose the predicted reduced maxima under 'reduced_maxima' so the
    # same lookup works for it (existing keys are unchanged).
    config['reduced_maxima'] = pred_reduced_maxima.name
    config['voxel_size'] = voxel_size

    with open(name + '.json', 'w') as f:
        json.dump(config, f)
def mknet_binary(config: Dict[str, Any], output_path: Path = Path()):
    """Build the foreground (binary) training graph and export it.

    Builds a U-Net inside the "fg" variable scope plus a 1x1x1 convolution
    giving single-channel foreground logits. These are trained with weighted
    sigmoid cross-entropy against ``gt_labels != 0`` using Adam.

    Unlike the other builders visible in this file, this function does not
    call ``tf.reset_default_graph()``; ops are added to the current default
    graph.

    Args:
        config: dict read for INPUT_SHAPE, OUTPUT_SHAPE, NUM_FMAPS_FOREGROUND,
            FMAP_INC_FACTORS_FOREGROUND, DOWNSAMPLE_FACTORS and KERNEL_SIZE_UP.
        output_path: directory receiving ``train_net_foreground.meta`` and
            ``tensor_names.json``.

    Raises:
        AssertionError: if the network's spatial output shape differs from
            ``config["OUTPUT_SHAPE"]``.
    """
    # 2 input channels prepended to the configured spatial shape
    input_shape = (2, ) + tuple(config["INPUT_SHAPE"])
    num_fmaps_foreground = config["NUM_FMAPS_FOREGROUND"]
    fmap_inc_factors_foreground = config["FMAP_INC_FACTORS_FOREGROUND"]
    downsample_factors = config["DOWNSAMPLE_FACTORS"]
    kernel_size_up = config["KERNEL_SIZE_UP"]
    # expected spatial output shape; sizes the target placeholders below
    output_shape = np.array(config["OUTPUT_SHAPE"])

    raw = tf.placeholder(tf.float32, shape=input_shape, name="raw_input")
    loss_weights = tf.placeholder(tf.float32,
                                  shape=output_shape,
                                  name="loss_weights")
    gt_labels = tf.placeholder(tf.int64, shape=output_shape, name="gt_labels")

    # add the batch axis: (b=1, c=2) + spatial shape
    raw_batched = tf.reshape(raw, (1, ) + input_shape)

    with tf.variable_scope("fg"):
        fg_unet = unet(
            raw_batched,
            num_fmaps=num_fmaps_foreground,
            fmap_inc_factors=fmap_inc_factors_foreground,
            downsample_factors=downsample_factors,
            kernel_size_up=kernel_size_up,
            constant_upsample=True,
        )
    # NOTE(review): confirm whether this conv_pass is meant to sit inside the
    # "fg" variable scope above; its placement determines the variable names
    # stored in checkpoints.
    fg_batched = conv_pass(fg_unet[0],
                           kernel_sizes=[1],
                           num_fmaps=1,
                           activation=None)[0]

    output_shape_batched = fg_batched.get_shape().as_list()
    output_shape = tuple(
        output_shape_batched[2:])  # strip the batch and channel dimension
    # NOTE(review): np.isclose broadcasts, so shapes of different rank may
    # raise ValueError (or compare loosely) instead of failing this assert.
    assert all(
        np.isclose(np.array(output_shape), np.array(
            config["OUTPUT_SHAPE"]))), "output shapes don't match"

    # batch element 0, reshaped to the spatial shape
    fg_logits = tf.reshape(fg_batched[0], output_shape, name="fg_logits")
    fg = tf.sigmoid(fg_logits, name="fg")

    # boolean foreground target: voxels whose label id is non-zero
    gt_fg = tf.not_equal(gt_labels, 0, name="gt_fg")

    fg_loss = tf.losses.sigmoid_cross_entropy(gt_fg,
                                              fg_logits,
                                              weights=loss_weights,
                                              scope="fg_loss")

    opt = tf.train.AdamOptimizer(learning_rate=0.5e-5,
                                 beta1=0.95,
                                 beta2=0.999,
                                 epsilon=1e-8)
    optimizer = opt.minimize(fg_loss)

    tf.summary.scalar("fg_loss", fg_loss)
    summaries = tf.summary.merge_all()

    tf.train.export_meta_graph(filename=output_path /
                               "train_net_foreground.meta")

    names = {
        "raw": raw.name,
        "gt_labels": gt_labels.name,
        "gt_fg": gt_fg.name,
        "fg_pred": fg.name,
        "fg_logits": fg_logits.name,
        "loss_weights": loss_weights.name,
        "fg_loss": fg_loss.name,
        "optimizer": optimizer.name,
        "summaries": summaries.name,
    }

    with (output_path / "tensor_names.json").open("w") as f:
        json.dump(names, f)
def mknet(config: Dict[str, Any], output_path: Path = Path()):
    """Build the unit-normalised embedding graph and export it.

    Builds a U-Net inside the "embedding" variable scope plus a 1x1x1
    convolution with ``EMBEDDING_DIMS`` channels. Each voxel's embedding
    vector is divided by its L2 norm over the channel axis. No loss or
    optimizer is built here. Writes ``<NETWORK_NAME>.meta`` and
    ``tensor_names.json`` into ``output_path``.

    Unlike the other builders visible in this file, this function does not
    call ``tf.reset_default_graph()``.

    Args:
        config: dict read for NETWORK_NAME, INPUT_SHAPE, OUTPUT_SHAPE,
            EMBEDDING_DIMS, NUM_FMAPS_EMBEDDING, FMAP_INC_FACTORS_EMBEDDING,
            DOWNSAMPLE_FACTORS and KERNEL_SIZE_UP.
        output_path: directory receiving the written files.

    Raises:
        AssertionError: if the network's spatial output shape differs from
            ``config["OUTPUT_SHAPE"]``.
    """
    network_name = config["NETWORK_NAME"]
    # 2 input channels prepended to the configured spatial shape
    input_shape = (2, ) + tuple(config["INPUT_SHAPE"])
    output_shape = tuple(config["OUTPUT_SHAPE"])
    embedding_dims = config["EMBEDDING_DIMS"]
    num_fmaps_embedding = config["NUM_FMAPS_EMBEDDING"]
    fmap_inc_factors_embedding = config["FMAP_INC_FACTORS_EMBEDDING"]
    downsample_factors = config["DOWNSAMPLE_FACTORS"]
    kernel_size_up = config["KERNEL_SIZE_UP"]

    raw = tf.placeholder(tf.float32, shape=input_shape, name="raw_input")
    # add the batch axis: (b=1, c=2) + spatial shape
    raw_batched = tf.reshape(raw, (1, ) + input_shape)

    # These placeholders are not consumed by any op built here; only their
    # names are exported (presumably for a loss assembled elsewhere -- TODO
    # confirm).
    fg_pred = tf.placeholder(tf.float32, shape=output_shape, name="fg_pred")
    gt_labels = tf.placeholder(tf.int64, shape=output_shape, name="gt_labels")
    loss_weights = tf.placeholder(tf.float32,
                                  shape=output_shape,
                                  name="loss_weights")

    with tf.variable_scope("embedding"):
        embedding_unet = unet(
            raw_batched,
            num_fmaps=num_fmaps_embedding,
            fmap_inc_factors=fmap_inc_factors_embedding,
            downsample_factors=downsample_factors,
            kernel_size_up=kernel_size_up,
            constant_upsample=True,
        )
    # NOTE(review): confirm whether this conv_pass is meant to sit inside the
    # "embedding" variable scope above; its placement determines the variable
    # names stored in checkpoints.
    embedding_batched = conv_pass(
        embedding_unet[0],
        kernel_sizes=[1],
        num_fmaps=embedding_dims,
        activation=None,
        name="embedding",
    )

    # Divide each voxel's embedding vector by its L2 norm along axis 1 (the
    # channel axis of the batched tensor).
    # NOTE(review): a zero-norm vector produces NaN/inf here; consider adding
    # a small epsilon to the denominator.
    embedding_norms = tf.norm(embedding_batched[0], axis=1, keep_dims=True)
    embedding_scaled = embedding_batched[0] / embedding_norms

    output_shape_batched = embedding_scaled.get_shape().as_list()
    output_shape = tuple(
        output_shape_batched[2:])  # strip the batch and channel dimension
    assert all(
        np.isclose(np.array(output_shape), np.array(
            config["OUTPUT_SHAPE"]))), "output shapes don't match"

    # drop the batch axis: (embedding_dims,) + spatial shape
    embedding = tf.reshape(embedding_scaled, (embedding_dims, ) + output_shape)

    tf.train.export_meta_graph(filename=output_path / f"{network_name}.meta")

    names = {
        "raw": raw.name,
        "embedding": embedding.name,
        "fg_pred": fg_pred.name,
        "gt_labels": gt_labels.name,
        "loss_weights": loss_weights.name,
    }

    with (output_path / "tensor_names.json").open("w") as f:
        json.dump(names, f)