Exemplo n.º 1
0
# Fine-tune the liver-segmentation U-Net on the training generator and save
# the results. `UNet`, `train_generator` and `train_path` are not defined in
# this snippet -- they must be provided elsewhere in the file.
import os

model_weights_name = 'unet_weight_model.hdf5'
# TODO: move to config .json files
img_height = 512
img_width = 512
img_size = (img_height, img_width)
batch_size = 2
# Number of images in the training set; steps per epoch should equal
# n_train_samples / batch_size so that one epoch sees every sample once.
n_train_samples = 560
steps_per_epoch = n_train_samples // batch_size  # 560 / 2 = 280
# Directory receiving the checkpoint and the final model files.
output_dir = '/home/new-ece/szc0173/liver'

# Only pass pretrained weights when the file exists, so a first run can start
# from scratch instead of trying to load a missing file.
# NOTE(review): assumes UNet treats pretrained_weights=None as "no weights".
pretrained_weights = (model_weights_name
                      if os.path.isfile(model_weights_name) else None)

# input_size is (height, width, channels): the same order as
# target_size=img_size below, and the Keras Input convention (assuming UNet
# forwards it to a Keras Input layer). Identical here because 512 x 512.
model = UNet(input_size=(img_height, img_width, 1),
             n_filters=64,
             pretrained_weights=pretrained_weights)
train_gen = train_generator(batch_size=batch_size,
                            train_path=train_path,
                            image_folder='liver',
                            mask_folder='mask',
                            target_size=img_size)
model.build()

# Checkpoint target. It must differ from the final model path below: sharing
# one path let the post-training model.save() overwrite the checkpointed file.
model_name2 = os.path.join(output_dir, 'unet_model2_best.hdf5')
model_checkpoint2 = model.checkpoint(model_name2)

# NOTE: Keras deprecated Model.fit_generator in favour of Model.fit, which
# accepts generators directly; switch if UNet exposes a fit() wrapper.
model.fit_generator(train_gen,
                    steps_per_epoch=steps_per_epoch,
                    epochs=10,
                    callbacks=[model_checkpoint2])

# Save the state after training finishes. save_model targets a file named as
# a weights file -- confirm against UNet.save_model what it actually writes.
model.save_model(os.path.join(output_dir, 'unet_weight_model2.hdf5'))
model.save(os.path.join(output_dir, 'unet_model2.hdf5'))
                                target_size=img_size)

    # Use the pretrained weights file only if it exists; otherwise build the
    # network with pretrained_weights=None.
    pretrained_weights = (model_weights_name
                          if is_file(file_name=model_weights_name) else None)

    # build model
    # input_size is (height, width, channels), the same order as
    # target_size=img_size passed to the training generator.
    unett = UNet(input_size=(img_height, img_width, 1),
                 n_filters=64,
                 pretrained_weights=pretrained_weights)
    unett.build()

    # Checkpoint callback writing to model_name (per the original author it
    # keeps the best weights configuration seen during training).
    model_checkpoint = unett.checkpoint(model_name)

    # model training
    # steps per epoch should be equal to number of samples in database divided
    # by batch size: 560 / 2 = 280 (was 281, contradicting this calculation).
    # NOTE(review): assumes batch_size=2 in the train_generator call above.
    unett.fit_generator(train_gen,
                        steps_per_epoch=280,
                        epochs=10,
                        callbacks=[model_checkpoint])
    # NOTE: Keras deprecated fit_generator in favour of fit(), which accepts
    # generators; switch if UNet exposes a fit() wrapper.