def __init__(self, nx, ny, numCirlces=20, numImages=100):
    """Create the gray-scale data provider and generate the image set.

    Parameters
    ----------
    nx, ny : int
        Image dimensions, forwarded positionally to
        ``image_gen.GrayScaleDataProvider``.
    numCirlces : int
        Forwarded to the provider as ``cnt``. The misspelled name is kept
        as-is because callers may pass it by keyword.
    numImages : int
        Stored on ``self.numImages`` before ``self._gen_imageset()`` runs
        (presumably the size of the generated set -- confirm in
        ``_gen_imageset``).
    """
    # Must be set before _gen_imageset() is called below.
    self.numImages = numImages

    provider = image_gen.GrayScaleDataProvider(nx, ny, cnt=numCirlces)
    self.generator = provider

    # Expose the provider's channel and output-class counts on this object.
    self.channels, self.n_class = provider.channels, provider.n_class

    # Generate the image set (_gen_imageset is defined elsewhere on the class).
    self._gen_imageset()
from __future__ import division, print_function

# Setup script: builds a 3072 x 3072 synthetic-image generator, draws one
# sample, and constructs a 3-layer U-Net with a momentum trainer.

import matplotlib.pyplot as plt
import matplotlib
import numpy as np

# Plot colormap default and a fixed NumPy seed (presumably so the
# synthetic data is reproducible). The seed is set before the tf_unet
# imports, as in the original ordering.
plt.rcParams['image.cmap'] = 'gist_earth'
np.random.seed(98765)

from tf_unet import image_gen, unet, util

nx = ny = 3072

generator = image_gen.GrayScaleDataProvider(nx, ny, cnt=20)
x_test, y_test = generator(1)
print(generator.channels)

# Optional visual check of the drawn sample (disabled):
# fig, ax = plt.subplots(1,2, sharey=True, figsize=(8,4))
# ax[0].imshow(x_test[0,...,0], aspect="auto")
# ax[1].imshow(y_test[0,...,1], aspect="auto")

net = unet.Unet(
    channels=generator.channels,
    n_class=generator.n_class,
    layers=3,
    features_root=16,
)
trainer = unet.Trainer(
    net,
    optimizer="momentum",
    opt_kwargs={"momentum": 0.2},
)
'''
Created on Jul 28, 2016

author: jakeret

Toy training script for tf_unet: trains a 3-layer U-Net on synthetic
gray-scale images produced by ``image_gen.GrayScaleDataProvider``.
'''
from __future__ import print_function, division, absolute_import, unicode_literals

import numpy as np

from tf_unet import image_gen
from tf_unet import unet
from tf_unet import util

if __name__ == '__main__':
    # Fixed NumPy seed; presumably makes the synthetic data reproducible.
    np.random.seed(98765)

    # 572 x 572 synthetic images. `cnt` is presumably the number of shapes
    # drawn per image; `rectangles=False` is forwarded as-is -- confirm
    # both in image_gen.
    generator = image_gen.GrayScaleDataProvider(nx=572, ny=572, cnt=20, rectangles=False)

    # Size the network's input channels and output classes from the provider.
    net = unet.Unet(channels=generator.channels, n_class=generator.n_class, layers=3, features_root=16)

    trainer = unet.Trainer(net, optimizer="momentum", opt_kwargs=dict(momentum=0.2))
    # "./unet_trained" is presumably the output directory for checkpoints.
    path = trainer.train(
        generator,
        "./unet_trained",
        training_iters=32,
        epochs=5,
        # NOTE(review): SOURCE is truncated here; the remaining arguments
        # and the closing parenthesis of trainer.train(...) are not visible.
# Setup script: builds a synthetic generator and a real-image data
# provider, then a 4-layer U-Net with a momentum trainer sized from the
# real-image provider.

import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import cv2  # NOTE(review): not used in this visible section.

plt.rcParams['image.cmap'] = 'gist_earth'
np.random.seed(98765)

# U-Net module, the synthetic generator, and the project-local provider
# for real images (import order preserved).
from tf_unet import image_gen
from tf_unet import unet
from img_provider import simple_data_provider
from tf_unet import util

# NOTE(review): generator_artificial is not used below; kept for parity.
generator_artificial = image_gen.GrayScaleDataProvider(
    572, 572, cnt=20, rectangles=True,
)
generator_reallife = simple_data_provider(
    x=572, y=572, nclass=3, channel='red', test=False,
)

net_real = unet.Unet(
    channels=generator_reallife.channels,
    n_class=generator_reallife.n_class,
    layers=4,
    features_root=16,
)
trainer_real = unet.Trainer(
    net_real,
    optimizer="momentum",
    opt_kwargs={"momentum": 0.2},
)
# 3-D toy problem: train a 3-layer Unet3D on synthetic gray-scale volumes,
# predict on one held-out sample, and save an input / ground-truth /
# prediction comparison figure.

import matplotlib.pyplot as plt
import matplotlib
import numpy as np

from tf_unet import image_gen
from tf_unet import unet
from tf_unet import util

plt.rcParams['image.cmap'] = 'gist_earth'

nx = 256
ny = 256

# depth_3d=8 presumably gives each sample 8 slices of nx x ny -- confirm
# in image_gen.
generator = image_gen.GrayScaleDataProvider(nx, ny, cnt=20, depth_3d=8)
x_test, y_test = generator(1)

net = unet.Unet3D(channels=generator.channels, n_class=generator.n_class, layers=3, features_root=16)

trainer = unet.Trainer(net, optimizer="momentum", opt_kwargs=dict(momentum=0.2))
# Alternative optimizer configuration:
# trainer = unet.Trainer(net, optimizer="adam", opt_kwargs=dict(learning_rate=0.1))

path = trainer.train(generator, "./unet_trained", training_iters=10, epochs=50, display_step=2)

# Restore from the checkpoint path returned by train() rather than the
# previously hard-coded "./unet_trained/model.cpkt" (likely a misspelling of
# the usual ".ckpt" suffix), so prediction uses the model that was just saved.
prediction = net.predict(path, x_test)

fig, ax = plt.subplots(1, 3, sharex=True, sharey=True, figsize=(12, 5))

# Index [0, 0, ..., c]: first sample, first entry along axis 1, last-axis
# channel/class c. Assumes a (batch, depth, ny, nx, channels) layout --
# TODO confirm against GrayScaleDataProvider / Unet3D.
ax[0].imshow(x_test[0, 0, ..., 0], aspect="auto")
ax[1].imshow(y_test[0, 0, ..., 1], aspect="auto")

# Binary mask from class-1 output (presumably a probability) above 0.9.
mask = prediction[0, 0, ..., 1] > 0.9
ax[2].imshow(mask, aspect="auto")

ax[0].set_title("Input")
ax[1].set_title("Ground truth")
ax[2].set_title("Prediction")

fig.tight_layout()
fig.savefig("../docs/toy_problem.png")