Example #1
0
#%%
x = load_image_data()

#%%
# Undo whatever batching load_image_data() applied, then re-batch into
# batches of 10 samples.
x = x.unbatch()

x = x.batch(batch_size=10)
# Shuffling happens *after* batching: whole batches are shuffled (the samples
# inside each batch stay together), and TEST_SIZE below counts batches,
# not samples.
# reshuffle_each_iteration=False keeps the shuffled order fixed across
# iterations. With the default (True) the order is redrawn on every pass, so
# the batches selected by take()/skip() change each time repeat() restarts
# the pipeline, and test batches leak into the training stream.
# NOTE(review): consider also passing a fixed `seed=` if the split must be
# reproducible across separate runs.
shuffled_data = x.shuffle(buffer_size=SHUFFLE_BUFFER_SIZE,
                          reshuffle_each_iteration=False)
# First TEST_SIZE batches -> held-out set; the rest -> training set.
# Both are repeated indefinitely, so fit() needs explicit step counts.
test = shuffled_data.take(TEST_SIZE).repeat()
train = shuffled_data.skip(TEST_SIZE).repeat()
#%%

import os  # imported here so this cell runs standalone

# Build and train a 17-layer DnCNN on the tf.data `train`/`test` streams.
model = DnCNN(depth=17)
model.compile(optimizer=keras.optimizers.Adam(),
              loss=dcnn_loss,
              metrics=[psnr])

now = datetime.now()
# Build the log directory portably. The former literal 'logs\log_from_{}'
# depended on '\l' being an unrecognised escape sequence (SyntaxWarning on
# Python 3.12+). It only produced a nested directory on Windows; elsewhere it
# created a single directory whose name contains a backslash.
tensorboard_callback = keras.callbacks.TensorBoard(
    log_dir=os.path.join(
        'logs',
        'log_from_{}'.format(now.strftime("%Y-%m-%d_at_%H-%M-%S"))),
    histogram_freq=1)

# `train` and `test` are built with .repeat() (infinite), so steps_per_epoch
# and validation_steps are what bound each epoch and each validation pass.
model.fit(x=train,
          steps_per_epoch=1000,
          validation_data=test,
          epochs=5,
          validation_steps=50,
          callbacks=[tensorboard_callback])
model.summary()
#                                 nomal=True,
#                                  fill_mode='constant')
# generator = train_datagen.flow_from_directory(file_path=train_file_path,
#             data_dir=data_dir, data_suffix=data_suffix,
#             label_dir=label_dir, label_suffix=label_suffix,
#             target_size=target_shape, color_mode='grayscale',
#             batch_size=batch_size, shuffle=True,
#             loss_shape=None)

# Learning-rate schedule callback; lr_scheduler is defined elsewhere.
scheduler = LearningRateScheduler(lr_scheduler)

# ---------------------------- checkpoint saver ----------------------------
# Weights only, always written to the same file under save_path, so each save
# replaces the previous one. Re-add an '.{epoch:d}' placeholder to the file
# name to keep one file per save instead.
checkpoint = ModelCheckpoint(
    filepath=os.path.join(save_path, 'checkpoint_weights.h5'),
    save_weights_only=True,
)

# Callbacks handed to model.fit() below: schedule first, then checkpointing.
callbacks = [scheduler, checkpoint]

# Second training run on in-memory arrays (input_data / input_label are
# defined elsewhere). This rebinds `model`, replacing the DnCNN(depth=17)
# instance trained in the earlier cell.
# model = srcnn(input_shape=input_shape, kernel_size=[3, 3])
model = DnCNN(input_shape=input_shape)
# model.load_weights('unet_optics_l2.h5')
# NOTE(review): in tf.keras the string 'adadelta' builds Adadelta with its
# default learning rate (0.001; old standalone Keras used 1.0). The
# LearningRateScheduler in `callbacks` presumably overrides it each epoch;
# confirm the values returned by lr_scheduler.
model.compile(loss=mean_squared_error, optimizer='adadelta')
model.summary()
# `epochs` replaces the Keras 1 keyword `nb_epoch`. Keras 2 only accepted that
# keyword with a deprecation warning, and tf.keras rejects it.
history = model.fit(input_data,
                    input_label,
                    batch_size=batch_size,
                    epochs=epochs,
                    callbacks=callbacks,
                    verbose=1)

# Relative path: final weights land in the current working directory,
# not in save_path (where the checkpoint callback writes).
model.save_weights('DnCNN_l2_mnist_combinenoise200_noise.h5')