예제 #1
0
    def __init__(self, nx, ny, numCirlces=20, numImages=100):
        """Set up a synthetic grayscale image set.

        Args:
            nx: first spatial size, forwarded positionally to
                ``image_gen.GrayScaleDataProvider``.
            ny: second spatial size, forwarded positionally as well.
            numCirlces: forwarded to the provider as ``cnt``. The name is
                misspelled, but it is kept as-is so existing keyword callers
                keep working.
            numImages: stored as ``self.numImages`` before ``_gen_imageset``
                runs (presumably it controls how many images that method
                builds — confirm against ``_gen_imageset``).
        """
        self.numImages = numImages

        provider = image_gen.GrayScaleDataProvider(nx, ny, cnt=numCirlces)
        self.generator = provider

        # Expose the provider's channel / class counts directly on this object.
        self.channels, self.n_class = provider.channels, provider.n_class

        # Build the image set last, once every attribute above is assigned.
        self._gen_imageset()
예제 #2
0
from __future__ import division, print_function

import matplotlib.pyplot as plt
import matplotlib
import numpy as np
# Plot colour map and a fixed numpy seed, applied before the tf_unet imports
# (same ordering as before, so any import-time RNG use is unaffected).
plt.rcParams['image.cmap'] = 'gist_earth'
np.random.seed(98765)

from tf_unet import image_gen
from tf_unet import unet
from tf_unet import util

# Both spatial sizes of the synthetic images.
nx = ny = 3072

# Synthetic grayscale provider; cnt=20 sets how many objects are drawn per
# image (presumably circles — confirm in image_gen).
generator = image_gen.GrayScaleDataProvider(nx=nx, ny=ny, cnt=20)

# Draw one sample (n=1) and report the provider's channel count.
x_test, y_test = generator(1)
print(generator.channels)

# Optional visualisation of that sample (disabled):
# fig, ax = plt.subplots(1,2, sharey=True, figsize=(8,4))
# ax[0].imshow(x_test[0,...,0], aspect="auto")
# ax[1].imshow(y_test[0,...,1], aspect="auto")

# 3-level U-Net whose input/output widths follow the provider.
net = unet.Unet(
    channels=generator.channels,
    n_class=generator.n_class,
    layers=3,
    features_root=16,
)
trainer = unet.Trainer(net, optimizer="momentum", opt_kwargs={"momentum": 0.2})
예제 #3
0
'''
Created on Jul 28, 2016

author: jakeret
'''
from __future__ import print_function, division, absolute_import, unicode_literals
import numpy as np
from tf_unet import image_gen
from tf_unet import unet
from tf_unet import util

# Script entry point: build a synthetic data provider, a U-Net and a
# momentum trainer, then start training.
if __name__ == '__main__':
    # Fixed numpy RNG seed for reproducible runs (presumably this also
    # governs the synthetic image generator — confirm in image_gen).
    np.random.seed(98765)

    # 572x572 synthetic grayscale images, cnt=20 objects each, with the
    # rectangles option turned off.
    generator = image_gen.GrayScaleDataProvider(nx=572,
                                                ny=572,
                                                cnt=20,
                                                rectangles=False)

    # 3-level U-Net; channel and class counts are taken from the provider so
    # the network matches the data it will be trained on.
    net = unet.Unet(channels=generator.channels,
                    n_class=generator.n_class,
                    layers=3,
                    features_root=16)

    # Momentum optimizer with momentum=0.2.
    trainer = unet.Trainer(net,
                           optimizer="momentum",
                           opt_kwargs=dict(momentum=0.2))
    # NOTE(review): this example is truncated — the remaining arguments and
    # the closing parenthesis of trainer.train(...) are missing here, so the
    # snippet does not parse as-is.
    path = trainer.train(
        generator,
        "./unet_trained",  # output directory (presumably checkpoints go here)
        training_iters=32,
        epochs=5,
예제 #4
0
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import cv2
# Plot colour map and a fixed numpy seed, set before the project imports
# (ordering kept identical to the original example).
plt.rcParams['image.cmap'] = 'gist_earth'
np.random.seed(98765)

# Project modules: U-Net implementation plus the data providers.
from tf_unet import image_gen
from tf_unet import unet
from img_provider import simple_data_provider
from tf_unet import util

# Synthetic 572x572 provider with 20 objects and rectangles enabled.
# NOTE(review): not used anywhere below in this example.
generator_artificial = image_gen.GrayScaleDataProvider(
    nx=572, ny=572, cnt=20, rectangles=True,
)

# Real-image provider: 572x572, 3 classes; channel='red' presumably selects
# the red channel and test=False the training split — confirm in img_provider.
generator_reallife = simple_data_provider(
    x=572, y=572, nclass=3, channel='red', test=False,
)

# 4-level U-Net sized to the real-image provider's channels and classes.
net_real = unet.Unet(
    channels=generator_reallife.channels,
    n_class=generator_reallife.n_class,
    layers=4,
    features_root=16,
)
trainer_real = unet.Trainer(
    net_real, optimizer="momentum", opt_kwargs={"momentum": 0.2},
)
예제 #5
0
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
from tf_unet import image_gen
from tf_unet import unet
from tf_unet import util

plt.rcParams['image.cmap'] = 'gist_earth'

# Spatial size of the synthetic data; depth_3d=8 presumably sets the number of
# slices per volume (confirm in image_gen).
nx = 256
ny = 256
generator = image_gen.GrayScaleDataProvider(nx, ny, cnt=20, depth_3d=8)
x_test, y_test = generator(1)

net = unet.Unet3D(channels=generator.channels, n_class=generator.n_class, layers=3, features_root=16)
trainer = unet.Trainer(net, optimizer="momentum", opt_kwargs=dict(momentum=0.2))
# Alternative optimizer setting:
# trainer = unet.Trainer(net,optimizer="adam", opt_kwargs=dict(learning_rate=0.1))

# train() returns the checkpoint path it wrote; restore from exactly that path.
# Previously this result was discarded and "./unet_trained/model.cpkt" was
# hard-coded, which repeats the output directory and looks like a misspelling
# of "model.ckpt", so predict() could look for a file that was never saved.
path = trainer.train(generator, "./unet_trained", training_iters=10, epochs=50, display_step=2)
prediction = net.predict(path, x_test)

# Show input, ground truth and thresholded prediction side by side.
# Index [0, 0, ..., c] = first sample, first entry of the next axis, last-axis
# channel c (assumed layout: batch, depth, y, x, channel — TODO confirm).
fig, ax = plt.subplots(1, 3, sharex=True, sharey=True, figsize=(12,5))
ax[0].imshow(x_test[0,0,...,0], aspect="auto")
ax[1].imshow(y_test[0,0,...,1], aspect="auto")
mask = prediction[0,0,...,1] > 0.9  # binarise class-1 output at 0.9
ax[2].imshow(mask, aspect="auto")
ax[0].set_title("Input")
ax[1].set_title("Ground truth")
ax[2].set_title("Prediction")
fig.tight_layout()
# NOTE(review): "../docs" must already exist relative to the working directory.
fig.savefig("../docs/toy_problem.png")