def add_malis_loss(graph):
    """Attach a MALIS loss and an Adam training op to an existing graph.

    Looks up the predicted and ground-truth affinity tensors by the names
    stored in the module-level ``config`` dict, creates placeholders for the
    ground-truth segmentation and the affinity mask, and builds the loss with
    ``malis.malis_loss_op`` using the module-level ``neighborhood``.

    Args:
        graph: the ``tf.Graph`` that contains the tensors named by
            ``config['pred_affs']`` and ``config['gt_affs']``.

    Returns:
        A ``(loss, optimizer)`` tuple: the MALIS loss tensor and the op
        returned by ``AdamOptimizer.minimize``.
    """
    pred_affs = graph.get_tensor_by_name(config['pred_affs'])
    gt_affs = graph.get_tensor_by_name(config['gt_affs'])

    # Hard-coded 44^3 output volume; the leading 3 of the mask presumably
    # indexes the affinity channels (TODO confirm against `neighborhood`).
    gt_seg = tf.placeholder(tf.int32, shape=(44, 44, 44), name='gt_seg')
    gt_affs_mask = tf.placeholder(
        tf.int32, shape=(3, 44, 44, 44), name='gt_affs_mask')

    mlo = malis.malis_loss_op(
        pred_affs, gt_affs, gt_seg, neighborhood, gt_affs_mask)

    # Called for its graph side effect (tf.summary.scalar adds the op to the
    # default summaries collection); the returned handle is not needed.
    tf.summary.scalar('malis_loss', mlo)

    opt = tf.train.AdamOptimizer(
        learning_rate=0.5e-4,
        beta1=0.95,
        beta2=0.999,
        epsilon=1e-8,
        name='malis_optimizer')
    optimizer = opt.minimize(mlo)

    return (mlo, optimizer)
def malis(self, target, logits):
    """Return the MALIS loss for a fixed 16x128x128 volume.

    ``target`` is sliced along its last (5th) axis: channels 0-2 are taken
    as the ground-truth affinities and everything from channel 3 onward as
    the label volume. The original author annotated the layout as ``bzxyc``;
    treat that as an assumption to confirm against the caller. ``logits[0]``
    is reshaped into the predicted affinities.

    NOTE(review): this takes ``self`` and is presumably a method. If it were
    a module-level function it would shadow the ``malis`` module it calls.

    Args:
        target: 5-D tensor holding ground-truth affinities and labels.
        logits: indexable whose first element holds the predictions.

    Returns:
        The tensor produced by ``malis.malis_loss_op`` with ``self.nhood``.
    """
    # Ground-truth affinities: first 3 channels -> (16, 128, 128, 3),
    # then moved channel-first -> (3, 16, 128, 128).
    true_affs = tf.transpose(
        tf.reshape(
            tf.slice(target, [0, 0, 0, 0, 0], [-1, -1, -1, -1, 3]),
            (16, 128, 128, 3)),
        [3, 0, 1, 2])

    # Labels: channel 3 to the end (size -1), flattened to (16, 128, 128).
    labels = tf.reshape(
        tf.slice(target, [0, 0, 0, 0, 3], [-1, -1, -1, -1, -1]),
        [16, 128, 128])

    # Predictions, brought into the same channel-first layout.
    predicted = tf.transpose(
        tf.reshape(logits[0], [16, 128, 128, 3]),
        [3, 0, 1, 2])

    return malis.malis_loss_op(predicted, true_affs, labels, self.nhood)
def add_malis_loss(graph):
    """Attach a MALIS + beta-weighted KL loss and an Adam training op.

    Looks up the predicted/ground-truth affinities and the prior/posterior
    tensors by the names in the module-level ``config`` dict. It converts
    the prior/posterior with the module-level ``z`` helper (presumably into
    ``tf.distributions`` objects, since they are passed to
    ``kl_divergence``). The loss is ``mlo + beta * kl``, using the
    module-level ``beta`` and ``neighborhood``.

    Args:
        graph: the ``tf.Graph`` containing the tensors named by
            ``config['pred_affs']``, ``config['gt_affs_out']``,
            ``config['prior']`` and ``config['posterior']``.

    Returns:
        A ``(loss, optimizer)`` tuple: the combined loss tensor and the op
        returned by ``AdamOptimizer.minimize``.
    """
    pred_affs = graph.get_tensor_by_name(config['pred_affs'])
    gt_affs = graph.get_tensor_by_name(config['gt_affs_out'])

    output_shape = config['output_shape']
    gt_seg = tf.placeholder(tf.int32, shape=output_shape, name='gt_seg')
    # list(...) so this also works when output_shape is a tuple
    # (list + tuple raises TypeError).
    gt_affs_mask = tf.placeholder(
        tf.int32, shape=[3] + list(output_shape), name='gt_affs_mask')

    prior = graph.get_tensor_by_name(config['prior'])
    posterior = graph.get_tensor_by_name(config['posterior'])
    p = z(prior)
    q = z(posterior)

    mlo = malis.malis_loss_op(
        pred_affs, gt_affs, gt_seg, neighborhood, gt_affs_mask)

    # NOTE(review): this computes KL(prior || posterior). The usual
    # conditional-VAE / probabilistic U-Net term is KL(posterior || prior).
    # Confirm the argument order is intended.
    kl = tf.distributions.kl_divergence(p, q)
    # Collapse to a rank-0 tensor; fails unless the KL has exactly one
    # element.
    kl = tf.reshape(kl, [], name="kl_loss")
    loss = mlo + beta * kl

    tf.summary.scalar('malis_loss', mlo)
    tf.summary.scalar('kl_loss', kl)

    # Optimizer op name kept as-is: graph/checkpoint names may depend on it.
    opt = tf.train.AdamOptimizer(
        learning_rate=0.5e-4,
        beta1=0.95,
        beta2=0.999,
        epsilon=1e-8,
        name='mse_optimizer')
    # Still adds the merged-summary op to the graph, as before; the handle
    # was never used.
    tf.summary.merge_all()
    optimizer = opt.minimize(loss)

    return (loss, optimizer)
def add_malis_loss(graph):
    """Build a MALIS loss over a 48x56x56 volume and an Adam step for it.

    The predicted and ground-truth affinities are fetched from ``graph``
    using the names in the module-level ``config`` dict. The segmentation
    and mask are new int64 placeholders. The module-level ``neighborhood``
    is passed to ``malis.malis_loss_op``.

    Args:
        graph: ``tf.Graph`` holding ``config['affs']`` and
            ``config['gt_affs']``.

    Returns:
        ``(loss, optimizer)``: the MALIS loss tensor and the Adam minimize
        op.
    """
    affs = graph.get_tensor_by_name(config['affs'])
    gt_affs = graph.get_tensor_by_name(config['gt_affs'])

    volume = (48, 56, 56)
    gt_seg = tf.placeholder(tf.int64, shape=volume, name='gt_seg')
    gt_affs_mask = tf.placeholder(
        tf.int64, shape=(3,) + volume, name='gt_affs_mask')

    loss = malis.malis_loss_op(
        affs, gt_affs, gt_seg, neighborhood, gt_affs_mask)
    tf.summary.scalar('setup13_malis_loss', loss)

    adam = tf.train.AdamOptimizer(
        learning_rate=0.5e-4,
        beta1=0.95,
        beta2=0.999,
        epsilon=1e-8,
        name='malis_optimizer')
    return loss, adam.minimize(loss)
def malis_cross(self, target, logits):
    """Return MALIS loss plus ``self.cross.weighted_cross`` on one volume.

    ``target`` is sliced along its last axis: channels 0-2 are the
    ground-truth affinities and channel 3 onward the label volume. The
    original note gives the layout as ``bzxyc`` (TODO confirm with caller).
    Shapes are hard-coded to 16x128x128.

    Args:
        target: 5-D tensor with ground-truth affinities and labels.
        logits: indexable whose first element holds the predicted
            affinities. It is also passed whole to ``weighted_cross``.

    Returns:
        ``tf.add`` of the MALIS loss and the weighted cross term.
    """
    # Affinity channels in the original 5-D layout. Built once and reused
    # for both the MALIS input and the cross term (previously sliced twice).
    gt_affs_5d = tf.slice(target, [0, 0, 0, 0, 0], [-1, -1, -1, -1, 3])

    # MALIS expects channel-first affinities: (3, 16, 128, 128).
    aff = tf.reshape(gt_affs_5d, (16, 128, 128, 3))
    aff = tf.transpose(aff, [3, 0, 1, 2])

    # Labels: channel 3 to the end (size -1), flattened to (16, 128, 128).
    gt = tf.slice(target, [0, 0, 0, 0, 3], [-1, -1, -1, -1, -1])
    gt = tf.reshape(gt, [16, 128, 128])

    pred_affs = tf.reshape(logits[0], [16, 128, 128, 3])
    pred_affs = tf.transpose(pred_affs, [3, 0, 1, 2])

    mal = malis.malis_loss_op(pred_affs, aff, gt, self.nhood)
    cross = self.cross.weighted_cross(gt_affs_5d, logits)
    return tf.add(mal, cross)
def mk_net(**kwargs):
    """Build the U-Net affinity + fg/bg training graph and export it.

    Resets the default TensorFlow graph and builds a U-Net on a single raw
    volume. It then adds a kernel-size-1 output convolution with 4 maps:
    3 affinity channels and 1 foreground/background channel. The loss
    combines a MALIS term on the affinities with a weighted fg/bg term.
    Writes ``<name>.meta`` (meta-graph), ``<name>_names.json`` (tensor/op
    names) and ``<name>_config.json`` (input/output shapes) to
    ``output_folder`` for the train and predict scripts.

    Keyword Args (all are read unconditionally, so all are required):
        input_shape: spatial shape of the raw volume; converted to a tuple.
        num_fmaps, fmap_inc_factors, fmap_dec_factors, downsample_factors,
        activation, padding, kernel_size, num_repetitions, upsampling:
            forwarded to ``unet``. ``padding`` is also used by the output
            convolution.
        loss: loss type forwarded to ``util.get_loss_weighted``.
        loss_malis_coeff, loss_fgbg_coeff: weights of the two loss terms.
        lr: default learning rate. It can be overridden at run time through
            the ``learning-rate`` placeholder.
        output_folder, name: destination directory and file-name stem.
    """
    tf.reset_default_graph()

    input_shape = kwargs['input_shape']
    if not isinstance(input_shape, tuple):
        input_shape = tuple(input_shape)

    # Raw volume without batch/channel dimensions.
    raw = tf.placeholder(tf.float32, shape=input_shape, name="raw")

    # Add size-1 batch and channel dimensions for the U-Net.
    raw_batched = tf.reshape(raw, (1, 1) + input_shape)

    model, _, _ = unet(raw_batched,
                       num_fmaps=kwargs['num_fmaps'],
                       fmap_inc_factors=kwargs['fmap_inc_factors'],
                       fmap_dec_factors=kwargs['fmap_dec_factors'],
                       downsample_factors=kwargs['downsample_factors'],
                       activation=kwargs['activation'],
                       padding=kwargs['padding'],
                       kernel_size=kwargs['kernel_size'],
                       num_repetitions=kwargs['num_repetitions'],
                       upsampling=kwargs['upsampling'])
    print(model)

    # 4 output maps: 3 affinity channels followed by 1 fg/bg channel
    # (they are separated by the tf.split below).
    model, _ = conv_pass(
        model,
        kernel_sizes=[1],
        num_fmaps=4,
        padding=kwargs['padding'],
        activation=None,
        name="output")
    print(model)

    # Drop the batch axis; the original comment gives the result as
    # (channels, depth, height, width).
    pred = tf.squeeze(model, axis=0)
    output_shape = pred.get_shape().as_list()[1:]
    pred_affs, pred_fgbg = tf.split(pred, [3, 1], 0)

    raw_cropped = crop(raw, output_shape)
    raw_cropped = tf.expand_dims(raw_cropped, 0)

    # Ground-truth and auxiliary inputs, shaped like the matching outputs.
    gt_affs = tf.placeholder(tf.float32, shape=pred_affs.get_shape(),
                             name="gt_affs")
    gt_labels = tf.placeholder(tf.float32, shape=pred_fgbg.get_shape(),
                               name="gt_labels")
    gt_fgbg = tf.placeholder(tf.float32, shape=pred_fgbg.get_shape(),
                             name="gt_fgbg")
    # NOTE(review): `anchor` is exported in the names file but no loss in
    # this function uses it.
    anchor = tf.placeholder(tf.float32, shape=pred_fgbg.get_shape(),
                            name="anchor")

    # Per-voxel weights for the fg/bg loss.
    loss_weights_fgbg = tf.placeholder(tf.float32,
                                       shape=pred_fgbg.get_shape(),
                                       name="loss_weights_fgbg")

    # Weighted fg/bg loss of type kwargs['loss']. The helper also returns a
    # pred_fgbg tensor that replaces ours and is the one exported below.
    # Its third return value is not used here.
    loss_fgbg, pred_fgbg, _ = util.get_loss_weighted(
        gt_fgbg, pred_fgbg, loss_weights_fgbg, kwargs['loss'], "fgbg", True)

    # One-voxel offsets along each of the 3 spatial axes.
    neighborhood = [[-1, 0, 0], [0, -1, 0], [0, 0, -1]]
    # NOTE(review): gt_labels is float32; confirm malis_loss_op accepts a
    # non-integer segmentation.
    gt_seg = tf.squeeze(gt_labels, axis=0, name="gt_seg")
    pred_affs = tf.identity(pred_affs, name="pred_affs")
    loss_malis = malis.malis_loss_op(pred_affs, gt_affs, gt_seg,
                                     neighborhood, name="malis_loss")

    # Weighted terms are built once and shared by the loss and the summaries.
    weighted_malis = kwargs['loss_malis_coeff'] * loss_malis
    weighted_fgbg = kwargs['loss_fgbg_coeff'] * loss_fgbg
    loss = weighted_malis + weighted_fgbg

    loss_malis_sum = tf.summary.scalar('loss_malis_sum', weighted_malis)
    loss_fgbg_sum = tf.summary.scalar('loss_fgbg_sum', weighted_fgbg)
    loss_sum = tf.summary.scalar('loss_sum', loss)
    summaries = tf.summary.merge([loss_malis_sum, loss_fgbg_sum, loss_sum],
                                 name="summaries")

    learning_rate = tf.placeholder_with_default(kwargs['lr'], shape=(),
                                                name="learning-rate")
    opt = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                 beta1=0.95,
                                 beta2=0.999,
                                 epsilon=1e-8)
    optimizer = opt.minimize(loss)

    tf.train.export_meta_graph(
        filename=os.path.join(kwargs['output_folder'],
                              kwargs['name'] + '.meta'))

    # Tensor/op names and shapes used by the train and predict scripts.
    fn = os.path.join(kwargs['output_folder'], kwargs['name'])
    names = {
        'raw': raw.name,
        'raw_cropped': raw_cropped.name,
        'pred_affs': pred_affs.name,
        'gt_affs': gt_affs.name,
        'gt_labels': gt_labels.name,
        'pred_fgbg': pred_fgbg.name,
        'gt_fgbg': gt_fgbg.name,
        'anchor': anchor.name,
        'loss_weights_fgbg': loss_weights_fgbg.name,
        'loss': loss.name,
        'optimizer': optimizer.name,
        'summaries': summaries.name,
    }
    with open(fn + '_names.json', 'w') as f:
        json.dump(names, f)

    # Local name chosen so it does not shadow the module-level `config`.
    net_config = {
        'input_shape': input_shape,
        'output_shape': output_shape,
    }
    with open(fn + '_config.json', 'w') as f:
        json.dump(net_config, f)