Example #1
def build_net(steps=steps_inference, mode='train'):
    # Two SerialUNets at different input voxel sizes form one multi-scale net;
    # `scale_net`, `make_any_scale_net`, `labels`, `steps_inference` are module-level.
    unet0 = scale_net.SerialUNet(
        [12, 12 * 6, 12 * 6 ** 2],
        [12 * 6, 12 * 6, 12 * 6 ** 2],
        [(2, 2, 2), (3, 3, 3)],
        [[(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]],
        [[(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]],
        input_voxel_size=(4, 4, 4))
    unet1 = scale_net.SerialUNet(
        [12, 12 * 6, 12 * 6 ** 2],
        [12 * 6 ** 2, 12 * 6 ** 2, 12 * 6 ** 2],
        [(3, 3, 3), (3, 3, 3)],
        [[(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]],
        [[(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]],
        input_voxel_size=(36, 36, 36))
    net = make_any_scale_net([unet0, unet1], labels, steps, mode=mode)
    return net
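
A minimal call-site sketch, assuming `labels` is defined at module scope and using an arbitrary step count:

inference_net = build_net(steps=4, mode='inference')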
Example #2
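    # Each Label presumably pairs an organelle name with its integer id(s) in
    # the ground-truth annotations (e.g. "microtubules" spans ids 30 and 31).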
    labels.append(
        Label("microtubules", (30, 31), data_sources=data_sources, data_dir=data_dir)
    )
    labels.append(
        Label("centrosome", (31, 32, 33), data_sources=data_sources, data_dir=data_dir)
    )
    labels.append(Label("distal_app", 32, data_sources=data_sources, data_dir=data_dir))
    labels.append(
        Label("subdistal_app", 33, data_sources=data_sources, data_dir=data_dir)
    )
    labels.append(Label("ribosomes", 1, data_sources=ribo_sources, data_dir=data_dir))

    unet0 = scale_net.SerialUNet(
        [12, 12 * 6, 12 * 6 ** 2],
        [48, 12 * 6, 12 * 6 ** 2],
        [(3, 3, 3), (3, 3, 3)],
        [[(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]],
        [[(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]],
        input_voxel_size=(4, 4, 4),
    )
    unet1 = scale_net.SerialUNet(
        [12, 12 * 6, 12 * 6 ** 2],
        [12 * 6 ** 2, 12 * 6 ** 2, 12 * 6 ** 2],
        [(3, 3, 3), (3, 3, 3)],
        [[(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]],
        [[(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]],
        input_voxel_size=(36, 36, 36),
    )
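    # Build the inference graph first, then reset the default graph before
    # constructing the training graph.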
    make_any_scale_net([unet0, unet1], labels, 4, mode="inference")
    tf.reset_default_graph()
    train_sc_net = make_any_scale_net(
        [unet0, unet1], labels, 5, mode="train", loss_name=loss_name
    )
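Example #3

import json

import numpy as np
import tensorflow as tf

# `scale_net` and `ops3d` are project-local modules; these import paths are
# assumed.
import ops3d
import scale_net
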
def make_net(labels, added_steps, mode='train', loss_name='loss_total'):
    unet0 = scale_net.SerialUNet(
        [12, 12 * 6, 12 * 6 ** 2],
        [48, 12 * 6, 12 * 6 ** 2],
        [(3, 3, 3), (3, 3, 3)],
        [[(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]],
        [[(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]],
        input_voxel_size=(4, 4, 4))
    unet1 = scale_net.SerialUNet(
        [12, 12 * 6, 12 * 6 ** 2],
        [12 * 6 ** 2, 12 * 6 ** 2, 12 * 6 ** 2],
        [(3, 3, 3), (3, 3, 3)],
        [[(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]],
        [[(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]],
        input_voxel_size=(36, 36, 36))
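    # Grow the minimal valid input shape by `added_steps` multiples of the
    # network's valid step size.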
    input_size = unet0.min_input_shape
    input_size_actual = input_size + added_steps * unet0.step_valid_shape
    scnet = scale_net.ScaleNet([unet0, unet1],
                               input_size_actual,
                               name='scnet_' + mode)
    inputs = []
    names = dict()
    # One raw-data placeholder per scale, keyed by its voxel size
    # (e.g. "raw_4", "raw_36").
    for inp, vs in zip(scnet.input_shapes, scnet.voxel_sizes):
        raw = tf.placeholder(tf.float32, shape=inp)
        raw_bc = tf.reshape(raw, (1, 1) + tuple(inp.astype(int)))
        inputs.append(raw_bc)
        names['raw_{0:}'.format(vs[0])] = raw.name

    # Assemble the multi-scale network over the per-scale inputs.
    last_fmap, fov, anisotropy = scnet.build(inputs)

    # Final 1x1x1 convolution: one distance channel per label.
    dist_bc, fov = ops3d.conv_pass(last_fmap,
                                   kernel_size=[(1, 1, 1)],
                                   num_fmaps=len(labels),
                                   activation=None,
                                   fov=fov,
                                   voxel_size=anisotropy)
    output_shape_bc = dist_bc.get_shape().as_list()
    output_shape_c = output_shape_bc[1:]  # strip batch dimension
    output_shape = output_shape_c[1:]  # strip channel dimension

    dist_c = tf.reshape(dist_bc, output_shape_c)
    names['dist'] = dist_c.name
    # One predicted distance map per label.
    network_outputs = tf.unstack(dist_c, len(labels), axis=0)
    if mode.lower() == 'train' or mode.lower() == 'training':
        # Per-label placeholders for ground truth, per-voxel weights and
        # masks, plus the scalar class weights.
        gt = []
        w = []
        cw = []
        masks = []
        for l in labels:
            masks.append(tf.placeholder(tf.float32, shape=output_shape))
            gt.append(tf.placeholder(tf.float32, shape=output_shape))
            w.append(tf.placeholder(tf.float32, shape=output_shape))
            cw.append(l.class_weight)
        lb = []
        lub = []
        # The balanced loss (`lb`) applies the per-voxel weights on top of the
        # mask; the unbalanced loss (`lub`) uses the mask alone.
        for output_it, gt_it, w_it, m_it, l in zip(network_outputs, gt, w,
                                                   masks, labels):
            lb.append(
                tf.losses.mean_squared_error(gt_it, output_it, w_it * m_it))
            lub.append(tf.losses.mean_squared_error(gt_it, output_it, m_it))
            names[l.labelname] = output_it.name
            names['gt_' + l.labelname] = gt_it.name
            names['w_' + l.labelname] = w_it.name
            names['mask_' + l.labelname] = m_it.name
        for l, lb_it, lub_it in zip(labels, lb, lub):
            tf.summary.scalar('lb_' + l.labelname, lb_it)
            tf.summary.scalar('lub_' + l.labelname, lub_it)
            names['lb_' + l.labelname] = lb_it.name
            names['lub_' + l.labelname] = lub_it.name

        # Four scalar losses: balanced/unbalanced, each also in a
        # class-weighted variant.
        loss_total = tf.add_n(lb)
        loss_total_unbalanced = tf.add_n(lub)
        loss_total_classweighted = tf.tensordot(lb, cw, axes=1)
        loss_total_unbalanced_classweighted = tf.tensordot(lub, cw, axes=1)

        tf.summary.scalar('loss_total', loss_total)
        names['loss_total'] = loss_total.name
        tf.summary.scalar('loss_total_unbalanced', loss_total_unbalanced)
        names['loss_total_unbalanced'] = loss_total_unbalanced.name
        tf.summary.scalar('loss_total_classweighted', loss_total_classweighted)
        names['loss_total_classweighted'] = loss_total_classweighted.name
        tf.summary.scalar('loss_total_unbalanced_classweighted',
                          loss_total_unbalanced_classweighted)
        names['loss_total_unbalanced_classweighted'] = loss_total_unbalanced_classweighted.name

        opt = tf.train.AdamOptimizer(learning_rate=0.5e-4,
                                     beta1=0.95,
                                     beta2=0.999,
                                     epsilon=1e-8)
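        # Pick which of the four loss variants the optimizer minimizes.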
        if loss_name == 'loss_total':
            optimizer = opt.minimize(loss_total)
        elif loss_name == 'loss_total_unbalanced':
            optimizer = opt.minimize(loss_total_unbalanced)
        elif loss_name == 'loss_total_unbalanced_classweighted':
            optimizer = opt.minimize(loss_total_unbalanced_classweighted)
        elif loss_name == 'loss_total_classweighted':
            optimizer = opt.minimize(loss_total_classweighted)
        else:
            raise ValueError(loss_name + " not defined")
        names['optimizer'] = optimizer.name
        merged = tf.summary.merge_all()
        names['summary'] = merged.name
        # Persist the tensor names so training/inference scripts can look up
        # the graph's inputs and outputs.
        with open('net_io_names.json', 'w') as f:
            json.dump(names, f)
    elif mode.lower() == 'inference' or mode.lower() == 'prediction':
        pass
    else:
        raise ValueError(
            "unknown mode for network construction: {0:}".format(mode))
    tf.train.export_meta_graph(filename=scnet.name + '.meta')
    return scnet
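
A minimal driver sketch, assuming `labels` is built as in Example #2; it mirrors that example's pattern of exporting the inference graph before building the training graph:

make_net(labels, added_steps=4, mode='inference')  # writes scnet_inference.meta
tf.reset_default_graph()
train_net = make_net(labels, added_steps=5, mode='train', loss_name='loss_total')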