def test_find_latest_checkpoint(self):
    """Exercise find_latest_checkpoint on an empty and a populated prefix.

    With nothing on disk, the default call returns None and passing
    fail_safe=False (positionally) raises ValueError. After creating files
    ``<prefix>.0``, ``.2``, ``.4``, ``.12`` plus the non-numeric
    ``._config.json`` and ``.ABC``, the expected result is ``<prefix>.12``.
    """
    prefix = os.path.join(self.tmp_dir, "test1")

    # No checkpoint files exist yet: the default call yields None ...
    self.assertEqual(None, train.find_latest_checkpoint(prefix))
    # ... whereas turning fail_safe off makes it raise instead.
    six.assertRaisesRegex(self, ValueError, "Checkpoint path",
                          train.find_latest_checkpoint, prefix, False)

    # Create empty files. ".12" must win even though "4", "ABC" and
    # "_config.json" sort after "12" as plain strings.
    suffixes = ("0", "2", "4", "12", "_config.json", "ABC")
    for suffix in suffixes:
        with open(prefix + '.' + suffix, 'a'):
            pass

    expected = prefix + ".12"
    self.assertEqual(expected, train.find_latest_checkpoint(prefix))
# Example #2
0
def test():
    """Print the latest weights file for the "resnet_unet_3" checkpoint prefix.

    Raises AssertionError("Checkpoint not found.") when the prefix's
    ``_config.json`` file does not exist.
    """
    prefix = "checkpoints/resnet_unet_3"
    config_file = prefix + "_config.json"
    assert os.path.isfile(config_file), "Checkpoint not found."

    print(find_latest_checkpoint(prefix))
# Example #3
0
def model_from_checkpoint_path(model, checkpoints_path):
    """Load the latest checkpointed weights into an already-built model.

    The caller supplies the architecture, so the prefix's ``_config.json`` is
    only required to exist; its contents are not needed here.

    Args:
        model: model object exposing ``load_weights(path)``.
        checkpoints_path: checkpoint prefix; weights are located with
            ``find_latest_checkpoint(checkpoints_path)``.

    Returns:
        The same ``model`` with weights loaded.

    Raises:
        AssertionError: if ``<checkpoints_path>_config.json`` is missing or no
            weights file is found for the prefix.
    """
    assert os.path.isfile(checkpoints_path + "_config.json"), \
        "Checkpoint not found."
    latest_weights = find_latest_checkpoint(checkpoints_path)
    assert latest_weights is not None, "Checkpoint not found."
    print("loaded weights ", latest_weights)
    model.load_weights(latest_weights)
    return model
# Example #4
0
def unet_model_from_checkpoint_path(checkpoints_path):
    """Rebuild the depth-segmentation U-Net stored under a checkpoint prefix.

    Reads ``n_classes``, ``input_height`` and ``input_width`` from
    ``<checkpoints_path>_config.json``, builds ``_unet_depth_segm`` with
    ``get_mobilenet_encoder`` as the encoder, then loads the weights file
    returned by ``find_latest_checkpoint``.

    Args:
        checkpoints_path: checkpoint prefix (without the ``_config.json``
            or weight-file suffix).

    Returns:
        The rebuilt model with weights loaded.

    Raises:
        AssertionError: if the config file or a weights file is missing.
    """
    config_path = checkpoints_path + "_config.json"
    assert os.path.isfile(config_path), "Checkpoint not found."
    # Context manager closes the handle deterministically instead of leaving
    # it to the garbage collector.
    with open(config_path, "r") as config_file:
        model_config = json.load(config_file)
    latest_weights = find_latest_checkpoint(checkpoints_path)
    assert latest_weights is not None, "Checkpoint not found."
    model = _unet_depth_segm(
        model_config['n_classes'], input_height=model_config['input_height'],
        input_width=model_config['input_width'], encoder=get_mobilenet_encoder)
    model.load_weights(latest_weights)
    return model
# Example #5
0
def model_from_checkpoint_path(checkpoints_path):
    """Rebuild a PSPNet from a checkpoint prefix and load its latest weights.

    ``n_classes``, ``input_height`` and ``input_width`` come from
    ``<checkpoints_path>_config.json``. ``model.summary()`` is called before
    the weights from ``find_latest_checkpoint`` are loaded.

    Args:
        checkpoints_path: checkpoint prefix (without suffixes).

    Returns:
        The PSPNet model with weights loaded.

    Raises:
        AssertionError: if the config file or a weights file is missing.
    """
    config_path = checkpoints_path + "_config.json"
    assert os.path.isfile(config_path), "Checkpoint not found."
    # Close the config file deterministically.
    with open(config_path, "r") as config_file:
        model_config = json.load(config_file)
    latest_weights = find_latest_checkpoint(checkpoints_path)
    assert latest_weights is not None, "Checkpoint not found."
    model = pspnet(model_config['n_classes'],
                   input_height=model_config['input_height'],
                   input_width=model_config['input_width'])
    model.summary()
    model.load_weights(latest_weights)
    return model
# Example #6
0
def model_from_checkpoint_path(checkpoints_path):
    """Rebuild any registered model from a checkpoint prefix.

    ``<checkpoints_path>_config.json`` supplies ``model_class`` (a key into
    ``model_from_name``) plus ``n_classes``, ``input_height`` and
    ``input_width``. The weights returned by ``find_latest_checkpoint`` are
    then loaded into the freshly built model.

    Args:
        checkpoints_path: checkpoint prefix (without suffixes).

    Returns:
        The rebuilt model with weights loaded.

    Raises:
        AssertionError: if the config file or a weights file is missing.
        KeyError: if ``model_class`` is not present in ``model_from_name``.
    """
    config_path = checkpoints_path + "_config.json"
    assert os.path.isfile(config_path), "Checkpoint not found."
    # Close the config file deterministically.
    with open(config_path, "r") as config_file:
        model_config = json.load(config_file)
    latest_weights = find_latest_checkpoint(checkpoints_path)
    assert latest_weights is not None, "Checkpoint not found."

    # Imported inside the function, presumably to avoid an import cycle with
    # the models package (TODO confirm). Deferred until after validation so a
    # bad path fails with the checkpoint message, not an import error.
    from .models.all_models import model_from_name
    model = model_from_name[model_config['model_class']](
        model_config['n_classes'], input_height=model_config['input_height'],
        input_width=model_config['input_width'])
    print("loaded weights ", latest_weights)
    model.load_weights(latest_weights)
    return model
# Example #7
0
    cv2.imwrite(img, binary)

from keras_segmentation.models.pspnet import pspnet
from keras_segmentation.pretrained import model_from_checkpoint_path
from keras_segmentation.train import find_latest_checkpoint

# Single checkpoint prefix shared by training, reloading and the CLI example
# below. Numbered weight files "<prefix>.<N>" are written next to it.
CHECKPOINTS_PATH = "/usr/code/tmp/checkpoints"

# Train a two-class PSPNet.
model = pspnet(n_classes=2)
model.train(train_images="dataset/image",
            train_annotations="dataset/mask",
            checkpoints_path=CHECKPOINTS_PATH,
            epochs=5)

# Rebuild the model from its saved weights. This config must describe the
# architecture trained above; 384x576 is assumed to be pspnet's default input
# size (TODO confirm if pspnet(...) is ever called with explicit sizes).
model_config = {
    'model_class': 'pspnet',
    'n_classes': 2,
    "input_height": 384,
    "input_width": 576,
}
# Second argument False turns off fail-safe mode, so a missing checkpoint
# raises ValueError here instead of passing None on to load_weights.
latest_weight = find_latest_checkpoint(CHECKPOINTS_PATH, False)
model = model_from_checkpoint_path(model_config, latest_weight)

# To produce prediction images from the command line, point the CLI at the
# same checkpoint prefix used above:
#
#   python -m keras_segmentation predict \
#    --checkpoints_path="/usr/code/tmp/checkpoints" \
#    --input_path="/usr/code/dataset_test/image/" \
#    --output_path="/usr/code/dataset_test/predict/"
# Example #8
0
    val_images="dataset/block_val",
    val_annotations="dataset/block_val_anno",
    val_steps_per_epoch=32,
    val_batch_size=2,
    checkpoints_path="new_model",
    load_weights=None,  #"new_model.80", # load_weights='new_model.10' 继续训练
    batch_size=2,  # 默认2。batch_size太大的话,会导致 GPU 内存不够
    steps_per_epoch=128,  # 默认512。每一代用多少个batch, None代表自动分割,即数据集样本数/batch样本数
    epochs=140,
)
#OOM when allocating tensor with shape[2,4096,90,90]

# Report wall-clock training time. total_seconds() counts whole days too,
# whereas timedelta.seconds is only the 0-86399 remainder and would wrap for
# runs longer than a day.
duration = datetime.datetime.now() - start_time
print('-------- Training time (s): {}'.format(int(duration.total_seconds())))

# Reload the latest checkpoint under the "new_model" prefix (fail-safe off, so
# a missing checkpoint raises instead of returning None), then segment one
# validation image; out_fname presumably saves the rendered result as out.png.
new_model_checkpoint = find_latest_checkpoint("new_model", False)
new_model.load_weights(new_model_checkpoint)
out = new_model.predict_segmentation(inp="dataset/block_val/003_h3_w6.jpg",
                                     out_fname="out.png")

# Convert to float so plt.imshow treats the image as lying in 0..1,
# i.e. 1 is shown as white.
out = np.array(out).astype(np.float32)
plt.imshow(out)

##out_color = cv2.imread("out2.png")
#out_color = get_colored_segmentation_image(out,n_classes=2)
#out_color = np.array(out_color).astype(np.int32)
#plt.imshow(out_color)

#import sys
#sys.exit()