# Example #1
0
from implementations.support_scripts.common import h5_small_vgg_generator
import numpy as np
from skimage import io, color
import matplotlib.pyplot as plt

# Pull one batch from the training generator and visualise a single sample:
# first the reconstructed colour image, then the first channel of its VGG input.
sample_idx = 4  # which sample of the 16-element batch to display

g = h5_small_vgg_generator(16, "../h5_data", None)
b = next(g)

# b[0][0] and b[0][1] are indexed as two separate arrays and b[1] as a third;
# presumably b is (inputs, targets) as consumed by fit_generator — confirm.
vgg = b[0][1][sample_idx, :, :, :]
s = b[0][0][sample_idx, :, :, :]
c = b[1][sample_idx, :, :, :]

# s and c are joined along the channel axis and handed to lab2rgb, so together
# they must form a 3-channel Lab image (presumably s = L, c = ab — confirm).
im = np.concatenate((s, c), axis=2)
rgb = color.lab2rgb(im)
plt.imshow(rgb)
plt.show()

plt.imshow(vgg[:, :, 0])
plt.show()
# Example #2
0
# Adam with the standard betas/epsilon and a reduced learning rate of 1e-4.
opt = optimizers.Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
model.compile(optimizer=opt, loss=custom_mse)

model.summary()

# Training schedule: resume at epoch 75, checkpoint every 5 epochs,
# stop at 10000 epochs.
start_from, save_every_n_epoch, n_epochs = 75, 5, 10000

# Checkpoints are tagged with the epoch their 5-epoch chunk started at, so the
# file tagged 70 already contains the chunk 70..74 — hence resuming at 75.
model.load_weights("../weights/implementation7d-relu-70.h5")

# The background image downloader is disabled; to re-enable it:
#   ip = ImagePacker("../small_dataset", "../h5_data", "imp7d-relu-",
#                    num_images=1024, num_files=None)
#   ip.start()
ip = None

# Training batches (optionally fed by `ip`) and validation batches.
g, gval = (h5_small_vgg_generator(b_size, "../h5_data", ip),
           h5_small_vgg_generator(b_size, "../h5_validate", None))
for i in range(start_from // save_every_n_epoch,
               n_epochs // save_every_n_epoch):
    print("START", i * save_every_n_epoch, "/", n_epochs)
    history = model.fit_generator(g,
                                  steps_per_epoch=60000 / b_size,
                                  epochs=save_every_n_epoch,
                                  validation_data=gval,
                                  validation_steps=(128 // b_size))
    model.save_weights("../weights/implementation7d-relu-" +
                       str(i * save_every_n_epoch) + ".h5")

    # save sample images
    whole_image_check(model, 20,
# Two-input model: `main_input` plus the VGG16 network's input, ending at `last`.
# `outputs=` is the canonical Keras 2 / tf.keras keyword (the singular `output=`
# was only a deprecated legacy alias and is rejected by newer releases).
model = Model(inputs=[main_input, vgg16.input], outputs=last)
opt = optimizers.Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
model.compile(optimizer=opt, loss=custom_mse, metrics=[root_mean_squared_error, mean_squared_error])

model.summary()

# Training schedule: start from scratch and checkpoint after every epoch.
start_from = 0
save_every_n_epoch = 1
n_epochs = 10000
# To resume instead, load a checkpoint and set start_from to match, e.g.:
# model.load_weights("../weights/implementation9-bn-24.h5")

# start image downloader (disabled: no packer is passed to the training generator)
ip = None

g = h5_small_vgg_generator(b_size, "../data/h5_small_train", ip)
gval = h5_small_vgg_generator(b_size, "../data/h5_small_validation", None)


# Train one chunk of `save_every_n_epoch` epochs per iteration; after each
# chunk save weights, sample images and the Keras training history.
for i in range(start_from // save_every_n_epoch, n_epochs // save_every_n_epoch):
    epoch = i * save_every_n_epoch  # tag used for every artefact of this chunk
    print("START", epoch, "/", n_epochs)
    history = model.fit_generator(g, steps_per_epoch=100000//b_size, epochs=save_every_n_epoch,
                                  validation_data=gval, validation_steps=(10000//b_size))
    model.save_weights("../weights/implementation9-bn-" + str(epoch) + ".h5")

    # save sample images
    whole_image_check_overlapping(model, 80, "imp9-bn-" + str(epoch) + "-")

    # save history; `with` closes (and flushes) the file immediately instead of
    # leaving a small pickle sitting in the write buffer until the next epoch.
    with open('../history/imp9-bn-{:0=4d}.pkl'.format(epoch), 'wb') as output:
        pickle.dump(history.history, output)