Example #1
0
def create_network(input_shape, name):
    """Build the two-channel foreground-prediction training graph.

    Side effects: exports <name>.meta and train_net.meta, and writes
    <name>_names.json / <name>_config.json next to them.
    """
    tf.reset_default_graph()

    # Two-channel input volume: (c=2, d, h, w).
    raw = tf.placeholder(tf.float32, shape=(2,) + input_shape)

    # Prepend a singleton batch dimension: (b=1, c=2, d, h, w).
    raw_batched = tf.reshape(raw, (1, 2) + input_shape)

    unet_out = unet(raw_batched, 12, 5, [[1, 3, 3], [1, 2, 2], [2, 2, 2]])

    fg_conv = conv_pass(
        unet_out[0], kernel_sizes=(1,), num_fmaps=1, activation="sigmoid"
    )

    batched_shape = fg_conv[0].get_shape().as_list()

    # Keep only the spatial (d, h, w) extent.
    output_shape = tuple(batched_shape[2:])

    fg = tf.reshape(fg_conv[0], output_shape)

    labels_fg = tf.placeholder(tf.float32, shape=output_shape)
    loss_weights = tf.placeholder(tf.float32, shape=output_shape)

    loss = tf.losses.mean_squared_error(labels_fg, fg, loss_weights)

    optimizer = tf.train.AdamOptimizer(
        learning_rate=0.5e-4, beta1=0.95, beta2=0.999, epsilon=1e-8
    ).minimize(loss)

    print("input shape: %s" % (input_shape,))
    print("output shape: %s" % (output_shape,))

    tf.train.export_meta_graph(filename=name + ".meta")

    # NOTE(review): the graph is exported a second time under a fixed
    # filename — presumably a training script loads "train_net.meta";
    # confirm before removing this duplicate export.
    tf.train.export_meta_graph(filename="train_net.meta")

    tensor_names = {
        "raw": raw.name,
        "fg": fg.name,
        "loss_weights": loss_weights.name,
        "loss": loss.name,
        "optimizer": optimizer.name,
        "labels_fg": labels_fg.name,
    }
    with open(name + "_names.json", "w") as fh:
        json.dump(tensor_names, fh)

    net_config = {"input_shape": input_shape, "output_shape": output_shape, "out_dims": 1}
    with open(name + "_config.json", "w") as fh:
        json.dump(net_config, fh)
Example #2
0
def create_network(input_shape, name, run):
    """Build the soft-mask prediction graph for one training run.

    Side effects: creates ./checkpoints/run_<run>/ if needed, exports the
    meta graph there as <name>.meta, and writes <name>.json mapping tensor
    roles to their graph names.
    """
    tf.reset_default_graph()

    # Raw input volume and its (batch=1, channel=1) batched view.
    raw = tf.placeholder(tf.float32, shape=input_shape)
    raw_batched = tf.reshape(raw, (1, 1) + input_shape)

    unet_out, _, _ = models.unet(raw_batched, 12, 5,
                                 [[1, 3, 3], [1, 3, 3], [1, 3, 3]])

    soft_mask_batched, _ = models.conv_pass(unet_out,
                                            kernel_sizes=[1],
                                            num_fmaps=1,
                                            activation=None)

    batched_shape = soft_mask_batched.get_shape().as_list()
    output_shape = batched_shape[1:]  # drop the batch dimension

    # Drop the channel dimension as well for the spatial soft mask.
    soft_mask = tf.reshape(soft_mask_batched, output_shape[1:])

    # Ground truth carries 10 LSD channels; index 9 (the last) is the
    # soft mask.
    gt_lsds = tf.placeholder(tf.float32, shape=[10] + list(output_shape[1:]))
    gt_soft_mask = gt_lsds[9, :, :, :]

    # Maximum detection over (1, 5, 5) windows at threshold 0.5, applied
    # to the ground-truth and the predicted soft mask alike.
    gt_maxima = tf.reshape(
        max_detection(
            tf.reshape(gt_soft_mask,
                       [1] + gt_soft_mask.get_shape().as_list() + [1]),
            [1, 1, 5, 5, 1], 0.5),
        gt_soft_mask.get_shape())
    pred_maxima = tf.reshape(
        max_detection(
            tf.reshape(soft_mask,
                       [1] + gt_soft_mask.get_shape().as_list() + [1]),
            [1, 1, 5, 5, 1], 0.5),
        gt_soft_mask.get_shape())

    # Soft weights for the binary mask: background voxels keep weight 1,
    # foreground voxels are up-weighted by the total foreground count.
    binary_mask = tf.cast(gt_soft_mask > 0, tf.float32)
    loss_weights_soft_mask = tf.ones(binary_mask.get_shape())
    loss_weights_soft_mask += tf.multiply(binary_mask,
                                          tf.reduce_sum(binary_mask))
    loss_weights_soft_mask -= binary_mask

    loss = tf.losses.mean_squared_error(soft_mask, gt_soft_mask,
                                        loss_weights_soft_mask)

    summary = tf.summary.scalar('loss', loss)

    optimizer = tf.train.AdamOptimizer(learning_rate=0.5e-4,
                                       beta1=0.95,
                                       beta2=0.999,
                                       epsilon=1e-8).minimize(loss)

    output_shape = output_shape[1:]
    print("input shape : %s" % (input_shape, ))
    print("output shape: %s" % (output_shape, ))

    output_dir = "./checkpoints/run_{}".format(run)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    tf.train.export_meta_graph(filename=output_dir + "/" + name + '.meta')

    config = {
        'raw': raw.name,
        'soft_mask': soft_mask.name,
        'gt_lsds': gt_lsds.name,
        'gt_maxima': gt_maxima.name,
        'pred_maxima': pred_maxima.name,
        'loss_weights_soft_mask': loss_weights_soft_mask.name,
        'loss': loss.name,
        'optimizer': optimizer.name,
        'input_shape': input_shape,
        'output_shape': output_shape,
        'summary': summary.name,
    }

    with open(output_dir + "/" + name + '.json', 'w') as f:
        json.dump(config, f)
Example #3
0
def mknet(parameter, name):
    """Build a synaptic-partner prediction graph and export it.

    Produces two output branches from a (possibly dual-headed) U-Net:
    3-channel partner vectors (regression) and a 1-channel synapse
    indicator (sigmoid cross-entropy). Exports <name>.meta and writes
    tensor names / shapes to <name>_config.json.

    Args:
        parameter: dict with keys learning_rate, input_size,
            fmap_inc_factor, downsample_factors, fmap_num, unet_model
            ('vanilla' or 'dh_unet'), m_loss_scale, d_loss_scale,
            voxel_size.
        name: basename for the exported .meta and _config.json files.
    """
    learning_rate = parameter['learning_rate']
    input_shape = tuple(parameter['input_size'])
    fmap_inc_factor = parameter['fmap_inc_factor']
    downsample_factors = parameter['downsample_factors']
    fmap_num = parameter['fmap_num']
    unet_model = parameter['unet_model']
    # 'dh_unet' builds two decoder heads, one per output branch.
    num_heads = 2 if unet_model == 'dh_unet' else 1
    m_loss_scale = parameter['m_loss_scale']
    d_loss_scale = parameter['d_loss_scale']
    voxel_size = tuple(parameter['voxel_size'])  # only needed for computing
    # field of view. No impact on the actual architecture.

    assert unet_model == 'vanilla' or unet_model == 'dh_unet', \
        'unknown unetmodel {}'.format(unet_model)

    tf.reset_default_graph()

    # d, h, w
    raw = tf.placeholder(tf.float32, shape=input_shape)

    # b=1, c=1, d, h, w
    raw_batched = tf.reshape(raw, (1, 1) + input_shape)

    # b=1, c=fmap_num, d, h, w
    # NOTE: voxel_size is rebound here to the value returned by the unet.
    outputs, fov, voxel_size = models.unet(raw_batched,
                                           fmap_num,
                                           fmap_inc_factor,
                                           downsample_factors,
                                           num_heads=num_heads,
                                           voxel_size=voxel_size)
    if num_heads == 1:
        # Single head: both branches read from the same feature map.
        outputs = (outputs, outputs)
    print('unet has fov in nm: ', fov)

    # b=1, c=3, d, h, w
    partner_vectors_batched, fov = models.conv_pass(
        outputs[0],
        kernel_sizes=[1],
        num_fmaps=3,
        activation=None,  # Regression
        name='partner_vector')

    # b=1, c=1, d, h, w
    syn_indicator_batched, fov = models.conv_pass(outputs[1],
                                                  kernel_sizes=[1],
                                                  num_fmaps=1,
                                                  activation=None,
                                                  name='syn_indicator')
    print('fov in nm: ', fov)

    # d, h, w
    output_shape = tuple(syn_indicator_batched.get_shape().as_list()
                         [2:])  # strip batch and channel dimension.
    syn_indicator_shape = output_shape

    # c=3, d, h, w
    partner_vectors_shape = (3, ) + syn_indicator_shape

    # c=3, d, h, w
    pred_partner_vectors = tf.reshape(partner_vectors_batched,
                                      partner_vectors_shape)
    gt_partner_vectors = tf.placeholder(tf.float32,
                                        shape=partner_vectors_shape)
    vectors_mask = tf.placeholder(tf.float32,
                                  shape=syn_indicator_shape)  # d,h,w
    gt_mask = tf.placeholder(tf.bool, shape=syn_indicator_shape)  # d,h,w
    vectors_mask = tf.cast(vectors_mask, tf.bool)

    # d, h, w
    pred_syn_indicator = tf.reshape(
        syn_indicator_batched, syn_indicator_shape)  # squeeze batch dimension
    gt_syn_indicator = tf.placeholder(tf.float32, shape=syn_indicator_shape)
    indicator_weight = tf.placeholder(tf.float32, shape=syn_indicator_shape)

    # Masked MSE over the partner vectors, averaged over non-zero weights.
    # NOTE(review): vectors_mask was cast to bool above but is used as the
    # loss weights here — confirm the implicit conversion is intended.
    partner_vectors_loss_mask = tf.losses.mean_squared_error(
        gt_partner_vectors,
        pred_partner_vectors,
        tf.reshape(vectors_mask, (1, ) + syn_indicator_shape),
        reduction=tf.losses.Reduction.SUM_BY_NONZERO_WEIGHTS)

    # Weighted sigmoid cross-entropy on the (logit) synapse indicator.
    syn_indicator_loss_weighted = tf.losses.sigmoid_cross_entropy(
        gt_syn_indicator,
        pred_syn_indicator,
        indicator_weight,
        reduction=tf.losses.Reduction.SUM_BY_NONZERO_WEIGHTS)
    pred_syn_indicator_out = tf.sigmoid(pred_syn_indicator)  # For output.

    # Global training-step counter, incremented by apply_gradients below.
    iteration = tf.Variable(1.0, name='training_iteration', trainable=False)
    loss = m_loss_scale * syn_indicator_loss_weighted + d_loss_scale * partner_vectors_loss_mask

    # Monitor in tensorboard.

    tf.summary.scalar('loss', loss)
    tf.summary.scalar('loss_vectors', partner_vectors_loss_mask)
    tf.summary.scalar('loss_indicator', syn_indicator_loss_weighted)
    summary = tf.summary.merge_all()

    # l=1, d, h, w
    print("input shape : %s" % (input_shape, ))
    print("output shape: %s" % (output_shape, ))

    # Train Ops.
    opt = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                 beta1=0.95,
                                 beta2=0.999,
                                 epsilon=1e-8)

    gvs_ = opt.compute_gradients(loss)
    optimizer = opt.apply_gradients(gvs_, global_step=iteration)

    tf.train.export_meta_graph(filename=name + '.meta')

    names = {
        'raw': raw.name,
        'gt_partner_vectors': gt_partner_vectors.name,
        'pred_partner_vectors': pred_partner_vectors.name,
        'gt_syn_indicator': gt_syn_indicator.name,
        'pred_syn_indicator': pred_syn_indicator.name,
        'pred_syn_indicator_out': pred_syn_indicator_out.name,
        'indicator_weight': indicator_weight.name,
        'vectors_mask': vectors_mask.name,
        'gt_mask': gt_mask.name,
        'loss': loss.name,
        'optimizer': optimizer.name,
        'summary': summary.name,
        'input_shape': input_shape,
        'output_shape': output_shape
    }

    names['outputs'] = {
        'pred_syn_indicator_out': {
            "out_dims": 1,
            "out_dtype": "uint8"
        },
        'pred_partner_vectors': {
            "out_dims": 3,
            "out_dtype": "float32"
        }
    }
    # Drop outputs whose loss term is disabled (scale 0) from the config.
    if m_loss_scale == 0:
        names['outputs'].pop('pred_syn_indicator_out')
    if d_loss_scale == 0:
        names['outputs'].pop('pred_partner_vectors')

    with open(name + '_config.json', 'w') as f:
        json.dump(names, f)

    # Count trainable parameters (TF1: dimensions expose .value).
    total_parameters = 0
    for variable in tf.trainable_variables():
        # shape is an array of tf.Dimension
        shape = variable.get_shape()
        variable_parameters = 1
        for dim in shape:
            variable_parameters *= dim.value
        total_parameters += variable_parameters
    print("Number of parameters:", total_parameters)
    print("Estimated size of parameters in GB:",
          float(total_parameters) * 8 / (1024 * 1024 * 1024))
def mk_net(**kwargs):
    """Build a patch-affinity U-Net graph and export it.

    Side effects: exports <output_folder>/<name>.meta and writes
    <name>_names.json / <name>_config.json there.

    Expected kwargs: input_shape, num_channels, num_fmaps,
    fmap_inc_factors, fmap_dec_factors, downsample_factors, activation,
    padding, kernel_size, num_repetitions, upsampling, patchshape, lr,
    output_folder, name; optional crop_factor.
    """
    tf.reset_default_graph()

    input_shape = kwargs['input_shape']
    if not isinstance(input_shape, tuple):
        input_shape = tuple(input_shape)

    # Placeholder for the raw input volume.
    raw = tf.placeholder(tf.float32, shape=input_shape, name="raw")

    # Add batch/channel dimensions and run the U-Net backbone.
    raw_batched = tf.reshape(raw, (1, kwargs['num_channels']) + input_shape)
    net, _, _ = unet(raw_batched,
                     num_fmaps=kwargs['num_fmaps'],
                     fmap_inc_factors=kwargs['fmap_inc_factors'],
                     fmap_dec_factors=kwargs['fmap_dec_factors'],
                     downsample_factors=kwargs['downsample_factors'],
                     activation=kwargs['activation'],
                     padding=kwargs['padding'],
                     kernel_size=kwargs['kernel_size'],
                     num_repetitions=kwargs['num_repetitions'],
                     upsampling=kwargs['upsampling'],
                     crop_factor=kwargs.get('crop_factor', True))
    print(net)

    # One output feature map per voxel of the patch shape.
    patch_fmaps = np.prod(kwargs['patchshape'])
    net, _ = conv_pass(net,
                       kernel_sizes=[1],
                       num_fmaps=patch_fmaps,
                       padding=kwargs['padding'],
                       activation=None,
                       name="output")
    print(net)

    logits = tf.squeeze(net, axis=0)
    output_shape = logits.get_shape().as_list()[1:]
    logitspatch = logits

    pred_affs = tf.sigmoid(logitspatch)

    raw_cropped = crop(raw, output_shape)

    # Ground-truth affinities and a constant anchor placeholder.
    gt_affs_shape = pred_affs.get_shape().as_list()
    gt_affs = tf.placeholder(tf.float32, shape=gt_affs_shape, name="gt_affs")
    anchor = tf.placeholder(tf.float32,
                            shape=[1] + output_shape,
                            name="anchor")

    # Unweighted sigmoid cross-entropy on the patch logits.
    loss = tf.losses.sigmoid_cross_entropy(gt_affs, logitspatch)

    summaries = tf.summary.merge([tf.summary.scalar('loss_sum', loss)],
                                 name="summaries")

    # Learning rate can be overridden at feed time.
    learning_rate = tf.placeholder_with_default(kwargs['lr'],
                                                shape=(),
                                                name="learning-rate")
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                       beta1=0.95,
                                       beta2=0.999,
                                       epsilon=1e-8).minimize(loss)

    tf.train.export_meta_graph(
        filename=os.path.join(kwargs['output_folder'],
                              kwargs['name'] + '.meta'))

    base = os.path.join(kwargs['output_folder'], kwargs['name'])
    names = {
        'raw': raw.name,
        'raw_cropped': raw_cropped.name,
        'gt_affs': gt_affs.name,
        'pred_affs': pred_affs.name,
        'anchor': anchor.name,
        'loss': loss.name,
        'optimizer': optimizer.name,
        'summaries': summaries.name
    }
    with open(base + '_names.json', 'w') as f:
        json.dump(names, f)

    config = {
        'input_shape': input_shape,
        'gt_affs_shape': gt_affs_shape,
        'output_shape': output_shape,
    }
    with open(base + '_config.json', 'w') as f:
        json.dump(config, f)
Example #5
0
def mk_net(**kwargs):
    """Build an affinity + foreground/background U-Net with a MALIS loss.

    The 4-channel conv output is split into 3 affinity channels and one
    fg/bg channel. The total loss combines a MALIS loss on affinities and
    a weighted fg/bg loss. Exports <output_folder>/<name>.meta and writes
    <name>_names.json / <name>_config.json.

    Expected kwargs: input_shape, num_fmaps, fmap_inc_factors,
    fmap_dec_factors, downsample_factors, activation, padding,
    kernel_size, num_repetitions, upsampling, loss, loss_malis_coeff,
    loss_fgbg_coeff, lr, output_folder, name.
    """
    tf.reset_default_graph()

    input_shape = kwargs['input_shape']
    if not isinstance(input_shape, tuple):
        input_shape = tuple(input_shape)

    # create a placeholder for the 3D raw input tensor
    raw = tf.placeholder(tf.float32, shape=input_shape, name="raw")

    # create a U-Net
    raw_batched = tf.reshape(raw, (1, 1) + input_shape)
    # unet_output = unet(raw_batched, 14, 4, [[1,3,3],[1,3,3],[1,3,3]])
    model, _, _ = unet(raw_batched,
                       num_fmaps=kwargs['num_fmaps'],
                       fmap_inc_factors=kwargs['fmap_inc_factors'],
                       fmap_dec_factors=kwargs['fmap_dec_factors'],
                       downsample_factors=kwargs['downsample_factors'],
                       activation=kwargs['activation'],
                       padding=kwargs['padding'],
                       kernel_size=kwargs['kernel_size'],
                       num_repetitions=kwargs['num_repetitions'],
                       upsampling=kwargs['upsampling'])
    print(model)

    # add a convolution layer to create 3 output maps representing affinities
    # in z, y, and x
    model, _ = conv_pass(
        model,
        kernel_sizes=[1],
        num_fmaps=4,
        # num_repetitions=1,
        padding=kwargs['padding'],
        activation=None,
        name="output")
    print(model)

    # the 4D output tensor (channels, depth, height, width)
    pred = tf.squeeze(model, axis=0)
    output_shape = pred.get_shape().as_list()[1:]
    # Split channels: 3 affinity maps and 1 fg/bg map.
    pred_affs, pred_fgbg = tf.split(pred, [3, 1], 0)

    raw_cropped = crop(raw, output_shape)
    raw_cropped = tf.expand_dims(raw_cropped, 0)

    # create a placeholder for the corresponding ground-truth affinities
    gt_affs = tf.placeholder(tf.float32,
                             shape=pred_affs.get_shape(),
                             name="gt_affs")
    gt_labels = tf.placeholder(tf.float32,
                               shape=pred_fgbg.get_shape(),
                               name="gt_labels")
    gt_fgbg = tf.placeholder(tf.float32,
                             shape=pred_fgbg.get_shape(),
                             name="gt_fgbg")
    anchor = tf.placeholder(tf.float32,
                            shape=pred_fgbg.get_shape(),
                            name="anchor")
    # gt_fgbg = tf.clip_by_value(gt_labels, 0, 1)

    # create a placeholder for per-voxel loss weights
    # loss_weights_affs = tf.placeholder(
    #     tf.float32,
    #     shape=pred_affs.get_shape(),
    #     name="loss_weights_affs")
    loss_weights_fgbg = tf.placeholder(tf.float32,
                                       shape=pred_fgbg.get_shape(),
                                       name="loss_weights_fgbg")

    # compute the loss as the weighted mean squared error between the
    # predicted and the ground-truth affinities
    # NOTE: pred_fgbg is rebound here to whatever util.get_loss_weighted
    # returns as its second value (presumably a transformed prediction).
    loss_fgbg, pred_fgbg, loss_fgbg_print = \
        util.get_loss_weighted(gt_fgbg, pred_fgbg, loss_weights_fgbg,
                               kwargs['loss'], "fgbg", True)

    # MALIS loss on the affinities against the ground-truth segmentation.
    neighborhood = [[-1, 0, 0], [0, -1, 0], [0, 0, -1]]
    gt_seg = tf.squeeze(gt_labels, axis=0, name="gt_seg")
    pred_affs = tf.identity(pred_affs, name="pred_affs")
    loss_malis = malis.malis_loss_op(pred_affs,
                                     gt_affs,
                                     gt_seg,
                                     neighborhood,
                                     name="malis_loss")

    # Total loss: coefficient-weighted sum of the two terms.
    loss = (kwargs['loss_malis_coeff'] * loss_malis +
            kwargs['loss_fgbg_coeff'] * loss_fgbg)

    loss_malis_sum = tf.summary.scalar('loss_malis_sum',
                                       kwargs['loss_malis_coeff'] * loss_malis)
    loss_fgbg_sum = tf.summary.scalar('loss_fgbg_sum',
                                      kwargs['loss_fgbg_coeff'] * loss_fgbg)
    loss_sum = tf.summary.scalar('loss_sum', loss)
    summaries = tf.summary.merge([loss_malis_sum, loss_fgbg_sum, loss_sum],
                                 name="summaries")

    # Learning rate can be overridden at feed time.
    learning_rate = tf.placeholder_with_default(kwargs['lr'],
                                                shape=(),
                                                name="learning-rate")
    # use the Adam optimizer to minimize the loss
    opt = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                 beta1=0.95,
                                 beta2=0.999,
                                 epsilon=1e-8)
    optimizer = opt.minimize(loss)

    # store the network in a meta-graph file
    tf.train.export_meta_graph(
        filename=os.path.join(kwargs['output_folder'], kwargs['name'] +
                              '.meta'))

    # store network configuration for use in train and predict scripts
    fn = os.path.join(kwargs['output_folder'], kwargs['name'])
    names = {
        'raw': raw.name,
        'raw_cropped': raw_cropped.name,
        'pred_affs': pred_affs.name,
        'gt_affs': gt_affs.name,
        'gt_labels': gt_labels.name,
        # 'loss_weights_affs': loss_weights_affs.name,
        'pred_fgbg': pred_fgbg.name,
        'gt_fgbg': gt_fgbg.name,
        'anchor': anchor.name,
        'loss_weights_fgbg': loss_weights_fgbg.name,
        'loss': loss.name,
        'optimizer': optimizer.name,
        'summaries': summaries.name
    }

    with open(fn + '_names.json', 'w') as f:
        json.dump(names, f)

    config = {
        'input_shape': input_shape,
        'output_shape': output_shape,
    }
    with open(fn + '_config.json', 'w') as f:
        json.dump(config, f)
Example #6
0
def mk_net(**kwargs):
    """Build a three-class segmentation U-Net graph and export it.

    The conv output has 3 class channels; a weighted classification loss
    is computed against integer ground-truth labels. Exports
    <output_folder>/<name>.meta and writes <name>_names.json /
    <name>_config.json.

    Expected kwargs: input_shape, num_fmaps, fmap_inc_factors,
    fmap_dec_factors, downsample_factors, activation, padding,
    kernel_size, num_repetitions, upsampling, loss, debug, lr,
    output_folder, name; optional crop_factor.
    """
    tf.reset_default_graph()

    input_shape = kwargs['input_shape']
    if not isinstance(input_shape, tuple):
        input_shape = tuple(input_shape)

    # create a placeholder for the 3D raw input tensor
    raw = tf.placeholder(tf.float32, shape=input_shape, name="raw")

    # create a U-Net
    raw_batched = tf.reshape(raw, (1, 1) + input_shape)
    # unet_output = unet(raw_batched, 14, 4, [[1,3,3],[1,3,3],[1,3,3]])
    model, _, _ = unet(raw_batched,
                       num_fmaps=kwargs['num_fmaps'],
                       fmap_inc_factors=kwargs['fmap_inc_factors'],
                       fmap_dec_factors=kwargs['fmap_dec_factors'],
                       downsample_factors=kwargs['downsample_factors'],
                       activation=kwargs['activation'],
                       padding=kwargs['padding'],
                       kernel_size=kwargs['kernel_size'],
                       num_repetitions=kwargs['num_repetitions'],
                       upsampling=kwargs['upsampling'],
                       crop_factor=kwargs.get('crop_factor', True))
    print(model)

    # 1x1 convolution down to the 3 class channels.
    model, _ = conv_pass(
        model,
        kernel_sizes=[1],
        num_fmaps=3,
        padding=kwargs['padding'],
        activation=None,
        name="output")
    print(model)

    # the 4D output tensor (channels, depth, height, width)
    pred_threeclass = tf.squeeze(model, axis=0)
    output_shape = pred_threeclass.get_shape().as_list()[1:]

    # Hard class map: argmax over the class axis.
    pred_class_max = tf.argmax(pred_threeclass, axis=0, output_type=tf.int32)
    pred_class_max = tf.expand_dims(pred_class_max, 0)
    # Softmax over the class axis; channel 0 is kept as the fg/bg map.
    # (dim= is the deprecated TF1 spelling of axis=.)
    pred_fgbg = tf.nn.softmax(pred_threeclass, dim=0)[0]
    pred_fgbg = tf.expand_dims(pred_fgbg, 0)

    raw_cropped = crop(raw, output_shape)
    raw_cropped = tf.expand_dims(raw_cropped, 0)

    # create a placeholder for the corresponding ground-truth
    gt_threeclass = tf.placeholder(tf.int32, shape=[1]+output_shape,
                               name="gt_threeclass")
    gt_threeclassTmp = tf.squeeze(gt_threeclass, 0)
    anchor = tf.placeholder(tf.float32, shape=gt_threeclass.get_shape(),
                             name="anchor")

    # create a placeholder for per-voxel loss weights
    loss_weights_threeclass = tf.placeholder(
        tf.float32,
        shape=gt_threeclass.get_shape(),
        name="loss_weights_threeclass")

    # Weighted classification loss; predictions/weights are transposed to
    # channels-last (d, h, w, c) as expected by the loss helper.
    loss_threeclass, _, loss_threeclass_print = \
        util.get_loss_weighted(gt_threeclassTmp,
                               tf.transpose(pred_threeclass, [1, 2, 3, 0]),
                               tf.transpose(loss_weights_threeclass, [1, 2, 3, 0]),
                               kwargs['loss'], "threeclass", False)

    # In debug mode, also compute the unweighted loss and force both sets
    # of print ops to run before the loss via control_dependencies.
    if kwargs['debug']:
        _, _, loss_threeclass_print2 = \
        util.get_loss(gt_threeclassTmp, tf.transpose(pred_threeclass, [1, 2, 3, 0]),
                      kwargs['loss'], "threeclass", False)
        print_ops = loss_threeclass_print + loss_threeclass_print2
    else:
        print_ops = None
    with tf.control_dependencies(print_ops):
        loss = (1.0 * loss_threeclass)

    loss_sum = tf.summary.scalar('loss_sum', loss)
    summaries = tf.summary.merge([loss_sum], name="summaries")

    # Learning rate can be overridden at feed time.
    learning_rate = tf.placeholder_with_default(kwargs['lr'], shape=(),
                                                name="learning-rate")
    # use the Adam optimizer to minimize the loss
    opt = tf.train.AdamOptimizer(
        learning_rate=learning_rate,
        beta1=0.95,
        beta2=0.999,
        epsilon=1e-8)
    optimizer = opt.minimize(loss)

    # store the network in a meta-graph file
    tf.train.export_meta_graph(filename=os.path.join(kwargs['output_folder'],
                                                     kwargs['name'] +'.meta'))

    # store network configuration for use in train and predict scripts
    fn = os.path.join(kwargs['output_folder'], kwargs['name'])
    names = {
        'raw': raw.name,
        'raw_cropped': raw_cropped.name,
        'gt_threeclass': gt_threeclass.name,
        'pred_threeclass': pred_threeclass.name,
        'pred_class_max': pred_class_max.name,
        'pred_fgbg': pred_fgbg.name,
        'anchor': anchor.name,
        'loss_weights_threeclass': loss_weights_threeclass.name,
        'loss': loss.name,
        'optimizer': optimizer.name,
        'summaries': summaries.name
    }

    with open(fn + '_names.json', 'w') as f:
        json.dump(names, f)

    config = {
        'input_shape': input_shape,
        'output_shape': output_shape,
    }
    with open(fn + '_config.json', 'w') as f:
        json.dump(config, f)
Example #7
0
def create_network(input_shape,
                   name,
                   setup,
                   voxel_size=[40, 4, 4],
                   nms_window=[1, 1, 10, 10, 1],
                   nms_threshold=0.5):
    """Build a 10-channel LSD prediction graph with maxima detection.

    Channel 9 of the LSD output is treated as a soft mask; maxima are
    detected on both the predicted and ground-truth soft mask. Exports
    <name>.meta and writes the tensor-name/config JSON to <name>.json.

    Args:
        input_shape: spatial (d, h, w) shape of the raw input.
        name: basename for the exported .meta and .json files.
        setup: unused here (kept for the legacy variable_scope, see below).
        voxel_size: stored in the config; no effect on the graph.
        nms_window: window for max_detection, in (b, d, h, w, c) layout.
        nms_threshold: detection threshold for max_detection.
    """
    tf.reset_default_graph()

    # Useless and causes incompatibility - reinsert for legacy code
    # with tf.variable_scope('setup_{}'.format(setup)):
    raw = tf.placeholder(tf.float32, shape=input_shape)
    raw_batched = tf.reshape(raw, (1, 1) + input_shape)

    out, _, _ = models.unet(raw_batched, 12, 5,
                            [[1, 3, 3], [1, 3, 3], [1, 3, 3]])

    lsds_batched, _ = models.conv_pass(out,
                                       kernel_sizes=[1],
                                       num_fmaps=10,
                                       activation=None)

    output_shape_batched = lsds_batched.get_shape().as_list()
    output_shape = output_shape_batched[1:]  # strip the batch dimension

    # Channel 9 is the soft mask (clipped to [0, 1]); channels 0-8 are the
    # derivative channels.
    lsds = tf.reshape(lsds_batched, output_shape)
    soft_mask = lsds[9, :, :, :]
    soft_mask = tf.clip_by_value(soft_mask, 0, 1.0)
    derivatives = lsds[:9, :, :, :]

    gt_lsds = tf.placeholder(tf.float32, shape=output_shape)
    gt_soft_mask = gt_lsds[9, :, :, :]
    gt_derivatives = gt_lsds[:9, :, :, :]

    print(gt_soft_mask.get_shape().as_list())
    print(soft_mask.get_shape().as_list())
    print(list(output_shape))
    print(list(output_shape_batched))

    # Maxima detection on ground truth and prediction, each reshaped to
    # the 5D (b, d, h, w, c) layout expected by max_detection.
    gt_maxima, gt_reduced_maxima = max_detection(
        tf.reshape(gt_soft_mask,
                   [1] + gt_soft_mask.get_shape().as_list() + [1]), nms_window,
        nms_threshold)
    pred_maxima, pred_reduced_maxima = max_detection(
        tf.reshape(soft_mask, [1] + gt_soft_mask.get_shape().as_list() + [1]),
        nms_window, nms_threshold)

    # Soft weights for binary mask
    # Background voxels keep weight 1; foreground voxels are up-weighted
    # by the total foreground count.
    binary_mask = tf.cast(gt_soft_mask > 0, tf.float32)
    loss_weights_soft_mask = tf.ones(binary_mask.get_shape())
    loss_weights_soft_mask += tf.multiply(binary_mask,
                                          tf.reduce_sum(binary_mask))
    loss_weights_soft_mask -= binary_mask

    # Same per-voxel weights replicated over all 10 LSD channels.
    loss_weights_lsds = tf.stack([loss_weights_soft_mask] * 10)

    loss = tf.losses.mean_squared_error(lsds, gt_lsds, loss_weights_lsds)

    summary = tf.summary.scalar('loss', loss)

    opt = tf.train.AdamOptimizer(learning_rate=0.5e-4,
                                 beta1=0.95,
                                 beta2=0.999,
                                 epsilon=1e-8)
    optimizer = opt.minimize(loss)

    output_shape = output_shape[1:]
    print("input shape : %s" % (input_shape, ))
    print("output shape: %s" % (output_shape, ))

    tf.train.export_meta_graph(filename=name + '.meta')

    config = {
        'raw': raw.name,
        'derivatives': derivatives.name,
        'soft_mask': soft_mask.name,
        'gt_lsds': gt_lsds.name,
        'gt_maxima': gt_maxima.name,
        'gt_reduced_maxima': gt_reduced_maxima.name,
        'pred_maxima': pred_maxima.name,
        'pred_reduced_maxima': pred_reduced_maxima.name,
        # NOTE(review): this key stores loss_weights_soft_mask.name, not
        # loss_weights_lsds.name — possibly intentional (per-voxel weights
        # without the channel axis), but confirm against the consumer.
        'loss_weights_lsds': loss_weights_soft_mask.name,
        'loss': loss.name,
        'optimizer': optimizer.name,
        'input_shape': input_shape,
        'output_shape': output_shape,
        'summary': summary.name,
    }

    config['outputs'] = {
        'soft_mask': {
            "out_dims": 1,
            "out_dtype": "uint8"
        },
        'derivatives': {
            "out_dims": 9,
            "out_dtype": "uint8"
        },
        'reduced_maxima': {
            "out_dims": 1,
            "out_dtype": "uint8"
        }
    }

    config['voxel_size'] = voxel_size

    with open(name + '.json', 'w') as f:
        json.dump(config, f)
def mknet_binary(config: Dict[str, Any], output_path: Path = Path()):
    """Build the binary foreground network and export it.

    Reads INPUT_SHAPE, OUTPUT_SHAPE, NUM_FMAPS_FOREGROUND,
    FMAP_INC_FACTORS_FOREGROUND, DOWNSAMPLE_FACTORS and KERNEL_SIZE_UP
    from `config`; writes train_net_foreground.meta and tensor_names.json
    into `output_path`.
    """
    input_shape = (2, ) + tuple(config["INPUT_SHAPE"])
    num_fmaps_foreground = config["NUM_FMAPS_FOREGROUND"]
    fmap_inc_factors_foreground = config["FMAP_INC_FACTORS_FOREGROUND"]
    downsample_factors = config["DOWNSAMPLE_FACTORS"]
    kernel_size_up = config["KERNEL_SIZE_UP"]

    output_shape = np.array(config["OUTPUT_SHAPE"])

    # Inputs: two-channel raw volume, per-voxel weights, integer labels.
    raw = tf.placeholder(tf.float32, shape=input_shape, name="raw_input")
    loss_weights = tf.placeholder(tf.float32,
                                  shape=output_shape,
                                  name="loss_weights")
    gt_labels = tf.placeholder(tf.int64, shape=output_shape, name="gt_labels")

    raw_batched = tf.reshape(raw, (1, ) + input_shape)

    with tf.variable_scope("fg"):
        fg_net = unet(
            raw_batched,
            num_fmaps=num_fmaps_foreground,
            fmap_inc_factors=fmap_inc_factors_foreground,
            downsample_factors=downsample_factors,
            kernel_size_up=kernel_size_up,
            constant_upsample=True,
        )

    fg_conv = conv_pass(fg_net[0],
                        kernel_sizes=[1],
                        num_fmaps=1,
                        activation=None)[0]

    batched_shape = fg_conv.get_shape().as_list()
    # Strip the batch and channel dimension.
    output_shape = tuple(batched_shape[2:])

    assert all(
        np.isclose(np.array(output_shape),
                   np.array(config["OUTPUT_SHAPE"]))), "output shapes don't match"

    fg_logits = tf.reshape(fg_conv[0], output_shape, name="fg_logits")
    fg = tf.sigmoid(fg_logits, name="fg")
    # Foreground ground truth: any non-zero label.
    gt_fg = tf.not_equal(gt_labels, 0, name="gt_fg")

    fg_loss = tf.losses.sigmoid_cross_entropy(gt_fg,
                                              fg_logits,
                                              weights=loss_weights,
                                              scope="fg_loss")

    optimizer = tf.train.AdamOptimizer(learning_rate=0.5e-5,
                                       beta1=0.95,
                                       beta2=0.999,
                                       epsilon=1e-8).minimize(fg_loss)

    tf.summary.scalar("fg_loss", fg_loss)
    summaries = tf.summary.merge_all()

    tf.train.export_meta_graph(filename=output_path /
                               "train_net_foreground.meta")
    names = {
        "raw": raw.name,
        "gt_labels": gt_labels.name,
        "gt_fg": gt_fg.name,
        "fg_pred": fg.name,
        "fg_logits": fg_logits.name,
        "loss_weights": loss_weights.name,
        "fg_loss": fg_loss.name,
        "optimizer": optimizer.name,
        "summaries": summaries.name,
    }

    with (output_path / "tensor_names.json").open("w") as f:
        json.dump(names, f)
def mknet(config: Dict[str, Any], output_path: Path = Path()):
    """Build the voxel-embedding network and export it.

    Reads NETWORK_NAME, INPUT_SHAPE, OUTPUT_SHAPE, EMBEDDING_DIMS,
    NUM_FMAPS_EMBEDDING, FMAP_INC_FACTORS_EMBEDDING, DOWNSAMPLE_FACTORS
    and KERNEL_SIZE_UP from `config`; writes <NETWORK_NAME>.meta and
    tensor_names.json into `output_path`.
    """
    network_name = config["NETWORK_NAME"]
    input_shape = (2, ) + tuple(config["INPUT_SHAPE"])
    output_shape = tuple(config["OUTPUT_SHAPE"])
    embedding_dims = config["EMBEDDING_DIMS"]
    num_fmaps_embedding = config["NUM_FMAPS_EMBEDDING"]
    fmap_inc_factors_embedding = config["FMAP_INC_FACTORS_EMBEDDING"]
    downsample_factors = config["DOWNSAMPLE_FACTORS"]
    kernel_size_up = config["KERNEL_SIZE_UP"]

    raw = tf.placeholder(tf.float32, shape=input_shape, name="raw_input")
    raw_batched = tf.reshape(raw, (1, ) + input_shape)

    # Placeholders not consumed in this graph — presumably read by the
    # training script via their stored names; verify against the caller.
    fg_pred = tf.placeholder(tf.float32, shape=output_shape, name="fg_pred")
    gt_labels = tf.placeholder(tf.int64, shape=output_shape, name="gt_labels")
    loss_weights = tf.placeholder(tf.float32,
                                  shape=output_shape,
                                  name="loss_weights")

    with tf.variable_scope("embedding"):
        emb_net = unet(
            raw_batched,
            num_fmaps=num_fmaps_embedding,
            fmap_inc_factors=fmap_inc_factors_embedding,
            downsample_factors=downsample_factors,
            kernel_size_up=kernel_size_up,
            constant_upsample=True,
        )

    emb_conv = conv_pass(
        emb_net[0],
        kernel_sizes=[1],
        num_fmaps=embedding_dims,
        activation=None,
        name="embedding",
    )

    # Normalize each voxel's embedding vector to unit length.
    emb_norms = tf.norm(emb_conv[0], axis=1, keep_dims=True)
    emb_scaled = emb_conv[0] / emb_norms

    batched_shape = emb_scaled.get_shape().as_list()
    # Strip the batch and channel dimension.
    output_shape = tuple(batched_shape[2:])

    assert all(
        np.isclose(np.array(output_shape),
                   np.array(config["OUTPUT_SHAPE"]))), "output shapes don't match"

    embedding = tf.reshape(emb_scaled, (embedding_dims, ) + output_shape)

    tf.train.export_meta_graph(filename=output_path / f"{network_name}.meta")
    names = {
        "raw": raw.name,
        "embedding": embedding.name,
        "fg_pred": fg_pred.name,
        "gt_labels": gt_labels.name,
        "loss_weights": loss_weights.name,
    }
    with (output_path / "tensor_names.json").open("w") as f:
        json.dump(names, f)