Exemplo n.º 1
0
def extract_weigths():
    """Restore the latest checkpoint and dump its weights/biases to HDF5.

    Builds the inference graph (``train_phase=False``), restores the most
    recent checkpoint found in ``FLAGS.last_trained_checkpoint`` and writes
    every trainable variable whose last name component starts with
    ``weights`` or ``biases`` into ``FLAGS.model_saved_hdf5``. Each variable
    becomes a float32 dataset named after the variable, without its ``:N``
    output-index suffix.

    Raises:
        ValueError: if no checkpoint is found in
            ``FLAGS.last_trained_checkpoint``.
    """
    from model import create_model_infant_seg
    from main import prepare_dirs, setup_tensorflow

    prepare_dirs(delete_train_dir=False)
    sess, _ = setup_tensorflow()

    # Called only to build the graph so the Saver has variables to restore;
    # the returned tensors are not needed here.
    create_model_infant_seg(train_phase=False)

    saver = tf.train.Saver()
    model_path = tf.train.latest_checkpoint(FLAGS.last_trained_checkpoint)
    if model_path is None:
        # latest_checkpoint returns None when the directory has no checkpoint;
        # fail with a clear message instead of inside saver.restore.
        raise ValueError('no checkpoint found in: %s'
                         % FLAGS.last_trained_checkpoint)

    print('saver restore from:%s' % model_path)
    saver.restore(sess, model_path)

    print('** after restore..')
    # `with` guarantees the HDF5 file is closed even if sess.run or
    # create_dataset raises part-way through the loop.
    with h5py.File(FLAGS.model_saved_hdf5, 'w') as hdf5_data:
        for var in tf.trainable_variables():
            para_name = var.name.split('/')[-1]
            if not para_name.startswith(('weights', 'biases')):
                continue
            # Variable names carry an output-index suffix such as ':0'; drop
            # it so the dataset path mirrors the variable's scope path.
            dset_name = var.name.rsplit(':', 1)[0]
            print('>> create data value: %s' % dset_name)
            value = sess.run(var)
            hdf5_data.create_dataset(dset_name, data=value, dtype=np.float32)
Exemplo n.º 2
0
def train():
    """Build the training graph and run training.

    Prepares the output directories with ``delete_train_dir=True``. That
    presumably clears any previous training directory; confirm in
    ``prepare_dirs``. It then creates the session and summary writer, builds
    the model with ``train_phase=True``, creates the optimizer ops for
    ``final_loss``, and hands everything to ``train_model`` through a
    ``TrainData`` wrapper.
    """
    prepare_dirs(delete_train_dir=True)
    sess, summary_writer = setup_tensorflow()

    (tf_t1_input, tf_t2_input, tf_label, aux1_pred, aux2_pred, main_pred,
     aux1_loss, aux2_loss, main_loss, final_loss, gene_vars,
     main_possibility) = create_model_infant_seg(train_phase=True)

    train_minimize, learning_rate, global_step = create_optimizers(final_loss)

    # NOTE: locals() hands TrainData a dict of *every* local name defined
    # above: sess, summary_writer, the model tensors and the optimizer ops.
    # The local variable names are therefore part of the contract with
    # TrainData. Renaming, removing or adding locals before this line may
    # change what it receives. TODO confirm which keys TrainData reads.
    train_data = TrainData(locals())
    train_model(train_data)
Exemplo n.º 3
0
def test():
    """Restore the latest checkpoint and predict on the test images.

    Builds the model with ``train_phase=False`` and restores the most recent
    checkpoint from ``FLAGS.last_trained_checkpoint``. It then passes the
    session and model tensors to
    ``predict_multi_modality_test_images_in_nifti`` through a ``TestData``
    wrapper.
    """
    prepare_dirs(delete_train_dir=False)
    sess, summary_writer = setup_tensorflow()
    # here for test, batch_size of tf_input is 1

    (tf_t1_input, tf_t2_input, tf_label, aux1_pred, aux2_pred, main_pred,
     aux1_loss, aux2_loss, main_loss, final_loss, gene_vars,
     main_possibility) = create_model_infant_seg(train_phase=False)

    saver = tf.train.Saver()
    # NOTE(review): latest_checkpoint returns None when no checkpoint exists,
    # in which case saver.restore below fails without a clear message.
    model_path = tf.train.latest_checkpoint(FLAGS.last_trained_checkpoint)
    print('saver restore from:%s' % model_path)
    saver.restore(sess, model_path)

    # NOTE: locals() hands TestData a dict of *every* local name defined
    # above, including sess, the model tensors, saver and model_path. The
    # local variable names are therefore part of the contract with TestData.
    # TODO confirm which keys TestData reads before renaming anything here.
    test_data = TestData(locals())
    predict_multi_modality_test_images_in_nifti(test_data)