Example #1

import json
import tensorflow as tf
import unet

def train_net():
    input_shape = (84, 268, 268)
    raw = tf.placeholder(tf.float32, shape=input_shape)
    raw_batched = tf.reshape(raw, (1, 1,) + input_shape)

    last_fmap, fov, anisotropy = unet.unet(raw_batched, 12, 6, [[1, 3, 3], [1, 3, 3], [1, 3, 3]],
                                           [[(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)],
                                            [(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]],
                                           [[(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)],
                                            [(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]],
                                           voxel_size=(10, 1, 1), fov=(10, 1, 1))

    logits_batched, fov = unet.conv_pass(
        last_fmap,
        kernel_size=[[1, 1, 1]],
        num_fmaps=2,
        activation=None,
        fov=fov,
        voxel_size=anisotropy
    )

    output_shape_batched = logits_batched.get_shape().as_list()

    output_shape = output_shape_batched[1:]  # strip the batch dimension
    # Flatten to an (N, 2) matrix of per-voxel logits for the cross-entropy losses.
    flat_logits = tf.transpose(tf.reshape(tensor=logits_batched, shape=(2, -1)))

    gt_labels = tf.placeholder(tf.float32, shape=output_shape[1:])
    gt_labels_flat = tf.reshape(gt_labels, (-1,))

    # Two-channel one-hot target: channel 0 marks foreground (label == 1),
    # channel 1 marks background.
    gt_bg = tf.to_float(tf.not_equal(gt_labels_flat, 1))
    flat_ohe = tf.stack(values=[gt_labels_flat, gt_bg], axis=1)
    
    loss_weights = tf.placeholder(tf.float32, shape=output_shape[1:])
    loss_weights_flat = tf.reshape(loss_weights, (-1,))
    print(logits_batched.get_shape().as_list())
    probabilities = tf.reshape(tf.nn.softmax(logits_batched, dim=1)[0], output_shape)
    predictions = tf.argmax(probabilities, axis=0)

    ce_loss_balanced = tf.losses.softmax_cross_entropy(flat_ohe, flat_logits, weights=loss_weights_flat)
    ce_loss_unbalanced = tf.losses.softmax_cross_entropy(flat_ohe, flat_logits)
    tf.summary.scalar('loss_balanced_syn', ce_loss_balanced)
    tf.summary.scalar('loss_unbalanced_syn', ce_loss_unbalanced)

    opt = tf.train.AdamOptimizer(
        learning_rate=0.5e-4,
        beta1=0.95,
        beta2=0.999,
        epsilon=1e-8)

    optimizer = opt.minimize(ce_loss_balanced)
    merged = tf.summary.merge_all()

    tf.train.export_meta_graph(filename='unet.meta')

    names = {
        'raw': raw.name,
        'probabilities': probabilities.name,
        'predictions': predictions.name,
        'gt_labels': gt_labels.name,
        'loss_balanced_syn': ce_loss_balanced.name,
        'loss_unbalanced_syn': ce_loss_unbalanced.name,
        'loss_weights': loss_weights.name,
        'optimizer': optimizer.name,
        'summary': merged.name}

    with open('net_io_names.json', 'w') as f:
        json.dump(names, f)
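
The meta graph and the name map exist so that a separate process can rebuild and drive the network without re-running this definition. A minimal consumer sketch (assuming Example #1 above has been executed once to produce unet.meta and net_io_names.json; the zero-filled feed arrays are purely illustrative stand-ins for real data and weights):

import json

import numpy as np
import tensorflow as tf

with open('net_io_names.json', 'r') as f:
    names = json.load(f)

tf.train.import_meta_graph('unet.meta')
graph = tf.get_default_graph()

raw = graph.get_tensor_by_name(names['raw'])
gt_labels = graph.get_tensor_by_name(names['gt_labels'])
loss_weights = graph.get_tensor_by_name(names['loss_weights'])
loss = graph.get_tensor_by_name(names['loss_balanced_syn'])
# minimize() returned an op, not a tensor, so fetch it as an operation.
train_op = graph.get_operation_by_name(names['optimizer'])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    feed = {
        raw: np.zeros(raw.get_shape().as_list(), dtype=np.float32),
        gt_labels: np.zeros(gt_labels.get_shape().as_list(), dtype=np.float32),
        loss_weights: np.ones(loss_weights.get_shape().as_list(), dtype=np.float32),
    }
    _, loss_value = sess.run([train_op, loss], feed)
    print('balanced loss:', loss_value)
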
Example #2

import json
import tensorflow as tf
import unet

def train_net():
    input_shape = (43, 430, 430)
    raw = tf.placeholder(tf.float32, shape=input_shape)
    raw_batched = tf.reshape(raw, (1, 1) + input_shape)

    last_fmap, fov, anisotropy = unet.unet(
        raw_batched, 12, 6,
        [[1, 3, 3], [1, 3, 3], [3, 3, 3]],
        [[(1, 3, 3), (1, 3, 3)], [(1, 3, 3), (1, 3, 3)],
         [(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]],
        [[(1, 3, 3), (1, 3, 3)], [(1, 3, 3), (1, 3, 3)],
         [(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]],
        voxel_size=(10, 1, 1),
        fov=(10, 1, 1))

    dist_batched, fov = unet.conv_pass(last_fmap,
                                       kernel_size=[[1, 1, 1]],
                                       num_fmaps=2,
                                       activation=None,
                                       fov=fov,
                                       voxel_size=anisotropy)

    # Split the two output channels: predicted distance to synapses and to boundaries.
    syn_dist, bdy_dist = tf.unstack(dist_batched, 2, axis=1)

    output_shape = syn_dist.get_shape().as_list()

    gt_syn_dist = tf.placeholder(tf.float32, shape=output_shape)
    gt_bdy_dist = tf.placeholder(tf.float32, shape=output_shape)

    loss_weights = tf.placeholder(tf.float32, shape=output_shape[1:])
    loss_weights_batched = tf.reshape(loss_weights, shape=output_shape)

    loss_balanced_syn = tf.losses.mean_squared_error(gt_syn_dist, syn_dist,
                                                     loss_weights_batched)
    loss_bdy = tf.losses.mean_squared_error(gt_bdy_dist, bdy_dist)
    loss_total = loss_balanced_syn + loss_bdy
    tf.summary.scalar('loss_balanced_syn', loss_balanced_syn)
    tf.summary.scalar('loss_bdy', loss_bdy)
    tf.summary.scalar('loss_total', loss_total)

    loss_unbalanced = tf.losses.mean_squared_error(gt_syn_dist, syn_dist)
    tf.summary.scalar('loss_unbalanced_syn', loss_unbalanced)

    opt = tf.train.AdamOptimizer(learning_rate=0.5e-4,
                                 beta1=0.95,
                                 beta2=0.999,
                                 epsilon=1e-8)

    optimizer = opt.minimize(loss_total)
    merged = tf.summary.merge_all()

    tf.train.export_meta_graph(filename='unet.meta')

    names = {
        'raw': raw.name,
        'syn_dist': syn_dist.name,
        'bdy_dist': bdy_dist.name,
        'gt_syn_dist': gt_syn_dist.name,
        'gt_bdy_dist': gt_bdy_dist.name,
        'loss_balanced_syn': loss_balanced_syn.name,
        'loss_unbalanced_syn': loss_unbalanced.name,
        'loss_bdy': loss_bdy.name,
        'loss_total': loss_total.name,
        'loss_weights': loss_weights.name,
        'optimizer': optimizer.name,
        'summary': merged.name
    }

    with open('net_io_names.json', 'w') as f:
        json.dump(names, f)
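
Example #2 regresses two distance channels instead of class probabilities, so gt_syn_dist and gt_bdy_dist expect precomputed distance maps. The snippet does not show how those targets are made; one plausible construction (an assumption, not necessarily the authors' pipeline) is a signed Euclidean distance transform using the same anisotropic voxel size (10, 1, 1) the network is configured with, clipped so distant voxels do not dominate the MSE:

import numpy as np
from scipy.ndimage import distance_transform_edt

def signed_distance_target(mask, sampling=(10, 1, 1), clip=50.0):
    # Positive inside the mask, negative outside, in the physical units
    # given by `sampling`; `clip` is an illustrative cutoff, not a value
    # taken from the snippets above.
    mask = mask.astype(bool)
    inside = distance_transform_edt(mask, sampling=sampling)
    outside = distance_transform_edt(~mask, sampling=sampling)
    return np.clip(inside - outside, -clip, clip).astype(np.float32)
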
Example #3

import json
import tensorflow as tf
import unet

def train_net():
    input_shape = (84, 268, 268)
    raw = tf.placeholder(tf.float32, shape=input_shape)
    raw_batched = tf.reshape(raw, (1, 1) + input_shape)

    last_fmap, fov, anisotropy = unet.unet(
        raw_batched, 12, 6,
        [[1, 3, 3], [1, 3, 3], [1, 3, 3]],
        [[(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)],
         [(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]],
        [[(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)],
         [(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]],
        voxel_size=(10, 1, 1),
        fov=(10, 1, 1))

    dist_batched, fov = unet.conv_pass(last_fmap,
                                       kernel_size=[[1, 1, 1]],
                                       num_fmaps=1,
                                       activation=None,
                                       fov=fov,
                                       voxel_size=anisotropy)

    output_shape_batched = dist_batched.get_shape().as_list()

    output_shape = output_shape_batched[1:]  # strip the batch dimension

    dist = tf.reshape(dist_batched, output_shape)  # drop the batch dimension

    gt_dist = tf.placeholder(tf.float32, shape=output_shape)

    loss_weights = tf.placeholder(tf.float32, shape=output_shape[1:])
    loss_weights_batched = tf.reshape(loss_weights, shape=output_shape)

    loss_balanced = tf.losses.mean_squared_error(gt_dist, dist,
                                                 loss_weights_batched)
    tf.summary.scalar('loss_balanced_syn', loss_balanced)

    loss_unbalanced = tf.losses.mean_squared_error(gt_dist, dist)
    tf.summary.scalar('loss_unbalanced_syn', loss_unbalanced)

    opt = tf.train.AdamOptimizer(learning_rate=0.5e-4,
                                 beta1=0.95,
                                 beta2=0.999,
                                 epsilon=1e-8)

    optimizer = opt.minimize(loss_unbalanced)
    merged = tf.summary.merge_all()

    tf.train.export_meta_graph(filename='unet.meta')

    names = {
        'raw': raw.name,
        'dist': dist.name,
        'gt_dist': gt_dist.name,
        'loss_balanced_syn': loss_balanced.name,
        'loss_unbalanced_syn': loss_unbalanced.name,
        'loss_weights': loss_weights.name,
        'optimizer': optimizer.name,
        'summary': merged.name
    }

    with open('net_io_names.json', 'w') as f:
        json.dump(names, f)
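
All three variants feed a per-voxel loss_weights array but leave its construction to the caller. A common recipe, and only a guess at what is fed here, is inverse class-frequency balancing so that foreground and background voxels contribute equal total mass to the balanced loss:

import numpy as np

def balanced_weights(gt_labels):
    # Weight each voxel by the inverse frequency of its class, scaled so
    # foreground (label == 1) and background each carry half the loss.
    fg = gt_labels == 1
    frac_fg = np.clip(fg.mean(), 1e-6, 1.0 - 1e-6)  # guard against empty classes
    return np.where(fg, 0.5 / frac_fg, 0.5 / (1.0 - frac_fg)).astype(np.float32)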