def test_find_latest_checkpoint(self):
    """Verify how find_latest_checkpoint resolves the newest checkpoint file.

    With no files on disk, the default (fail-safe) lookup must return None.
    Passing ``False`` as the second argument disables fail-safe mode, and
    the lookup must then raise ValueError mentioning "Checkpoint path".
    Once files with the suffixes 0, 2, 4, 12, _config.json and ABC exist,
    the ``.12`` file must be returned. Since "4" > "12" as strings, this
    pins a numeric (not lexicographic) choice among the epoch suffixes, and
    the non-numeric suffixes must not be selected.
    """
    checkpoints_path = os.path.join(self.tmp_dir, "test1")

    # No checkpoint files yet: the fail-safe lookup yields None ...
    self.assertEqual(None, train.find_latest_checkpoint(checkpoints_path))
    # ... while the strict lookup raises instead of returning None.
    six.assertRaisesRegex(self, ValueError, "Checkpoint path",
                          train.find_latest_checkpoint, checkpoints_path, False)

    # Touch empty files named "<prefix>.<suffix>", mixing epoch numbers with
    # suffixes that must never be picked as the latest checkpoint.
    suffixes = ("0", "2", "4", "12", "_config.json", "ABC")
    for suffix in suffixes:
        with open("{}.{}".format(checkpoints_path, suffix), "a"):
            pass

    expected = checkpoints_path + ".12"
    self.assertEqual(expected, train.find_latest_checkpoint(checkpoints_path))
def test():
    """Smoke-check the ``checkpoints/resnet_unet_3`` checkpoint.

    Asserts that the checkpoint's ``_config.json`` exists, then prints
    whatever ``find_latest_checkpoint`` returns for that prefix. With the
    default fail-safe lookup, this may be None when no weights file exists.

    Raises:
        AssertionError: "Checkpoint not found." if the config file is missing.
    """
    ckpt_prefix = "checkpoints/resnet_unet_3"
    config_file = ckpt_prefix + "_config.json"
    assert os.path.isfile(config_file), "Checkpoint not found."
    print(find_latest_checkpoint(ckpt_prefix))
def model_from_checkpoint_path(model, checkpoints_path):
    """Load the newest saved weights for *checkpoints_path* into *model*.

    The caller supplies the already-built architecture, so the
    ``<checkpoints_path>_config.json`` file only has to exist (it marks the
    prefix as a checkpoint). Its contents are not needed here.

    Args:
        model: An already-constructed model exposing ``load_weights``.
        checkpoints_path: Checkpoint prefix passed to
            ``find_latest_checkpoint``.

    Returns:
        The same *model* object, after ``load_weights`` has been called.

    Raises:
        AssertionError: "Checkpoint not found." if the config file is missing
            or no weights file is found for the prefix.
    """
    assert os.path.isfile(checkpoints_path + "_config.json"), \
        "Checkpoint not found."
    # The config is not parsed: the architecture comes from `model`, so
    # reading it would only leave an unused value (and an open file handle).
    latest_weights = find_latest_checkpoint(checkpoints_path)
    assert latest_weights is not None, "Checkpoint not found."
    print("loaded weights ", latest_weights)
    model.load_weights(latest_weights)
    return model
def unet_model_from_checkpoint_path(checkpoints_path):
    """Rebuild a model from a checkpoint's config and load its newest weights.

    The model is constructed with ``_unet_depth_segm`` using
    ``get_mobilenet_encoder`` as the encoder. Its sizes come from
    ``<checkpoints_path>_config.json``.

    Args:
        checkpoints_path: Checkpoint prefix. The config file must provide
            ``n_classes``, ``input_height`` and ``input_width``.

    Returns:
        The rebuilt model with the latest weights loaded.

    Raises:
        AssertionError: "Checkpoint not found." if the config file is missing
            or no weights file is found for the prefix.
        KeyError: if one of the required config keys is absent.
    """
    config_path = checkpoints_path + "_config.json"
    assert os.path.isfile(config_path), "Checkpoint not found."
    # Context manager guarantees the handle is closed, even if parsing fails.
    with open(config_path, "r") as config_file:
        model_config = json.load(config_file)
    latest_weights = find_latest_checkpoint(checkpoints_path)
    assert latest_weights is not None, "Checkpoint not found."
    model = _unet_depth_segm(
        model_config['n_classes'],
        input_height=model_config['input_height'],
        input_width=model_config['input_width'],
        encoder=get_mobilenet_encoder)
    model.load_weights(latest_weights)
    return model
def model_from_checkpoint_path(checkpoints_path):
    """Rebuild a ``pspnet`` model from a checkpoint's config and load weights.

    Args:
        checkpoints_path: Checkpoint prefix. ``<prefix>_config.json`` must
            provide ``n_classes``, ``input_height`` and ``input_width``.

    Returns:
        The model returned by ``pspnet(...)`` with the latest weights loaded.

    Raises:
        AssertionError: "Checkpoint not found." if the config file is missing
            or no weights file is found for the prefix.
        KeyError: if one of the required config keys is absent.
    """
    config_path = checkpoints_path + "_config.json"
    assert os.path.isfile(config_path), "Checkpoint not found."
    # Context manager guarantees the handle is closed, even if parsing fails.
    with open(config_path, "r") as config_file:
        model_config = json.load(config_file)
    latest_weights = find_latest_checkpoint(checkpoints_path)
    assert latest_weights is not None, "Checkpoint not found."
    model = pspnet(model_config['n_classes'],
                   input_height=model_config['input_height'],
                   input_width=model_config['input_width'])
    model.summary()  # print the architecture before loading weights
    model.load_weights(latest_weights)
    return model
def model_from_checkpoint_path(checkpoints_path):
    """Rebuild the model named in a checkpoint's config and load its weights.

    The architecture is looked up in ``model_from_name`` by the config's
    ``model_class`` entry. It is built with ``n_classes``,
    ``input_height`` and ``input_width`` from the same config.

    Args:
        checkpoints_path: Checkpoint prefix; ``<prefix>_config.json`` holds
            the model configuration.

    Returns:
        The rebuilt model with the latest weights loaded.

    Raises:
        AssertionError: "Checkpoint not found." if the config file is missing
            or no weights file is found for the prefix.
        KeyError: if a required config key is absent or ``model_class`` is
            not a key of ``model_from_name``.
    """
    # Function-scope import, kept as in the original (possibly to avoid an
    # import cycle with the models package -- not verifiable from here).
    from .models.all_models import model_from_name
    config_path = checkpoints_path + "_config.json"
    assert os.path.isfile(config_path), "Checkpoint not found."
    # Context manager guarantees the handle is closed, even if parsing fails.
    with open(config_path, "r") as config_file:
        model_config = json.load(config_file)
    latest_weights = find_latest_checkpoint(checkpoints_path)
    assert latest_weights is not None, "Checkpoint not found."
    model = model_from_name[model_config['model_class']](
        model_config['n_classes'],
        input_height=model_config['input_height'],
        input_width=model_config['input_width'])
    print("loaded weights ", latest_weights)
    model.load_weights(latest_weights)
    return model
# NOTE(review): `img` and `binary` come from code not shown in this excerpt;
# cv2.imwrite takes (filename, image), so `img` is presumably a path -- confirm.
cv2.imwrite(img, binary)

from keras_segmentation.models.pspnet import pspnet
from keras_segmentation.pretrained import model_from_checkpoint_path
from keras_segmentation.train import find_latest_checkpoint

# Single source of truth so training and reloading use the same prefix.
CHECKPOINTS_PATH = "/usr/code/tmp/checkpoints"

# Train a two-class PSPNet, writing checkpoints under CHECKPOINTS_PATH.
model = pspnet(n_classes=2)
model.train(train_images="dataset/image",
            train_annotations="dataset/mask",
            checkpoints_path=CHECKPOINTS_PATH,
            epochs=5)

# Rebuild the model from an explicit config and load the newest weights.
# NOTE(review): input_height/input_width must match what pspnet(n_classes=2)
# was built with above -- confirm against pspnet's defaults.
model_config = {
    'model_class': 'pspnet',
    'n_classes': 2,
    "input_height": 384,
    "input_width": 576
}
# Second argument False disables fail-safe mode, so a missing checkpoint
# raises ValueError here instead of returning None and passing it on to
# model_from_checkpoint_path.
latest_weight = find_latest_checkpoint(CHECKPOINTS_PATH, False)
model = model_from_checkpoint_path(model_config, latest_weight)

# To produce prediction images from the command line:
#   python -m keras_segmentation predict \
#       --checkpoints_path="/usr/code/tmp_2/checkpoints" \
#       --input_path="/usr/code/dataset_test/image/" \
#       --output_path="/usr/code/dataset_test/predict/"
# NOTE(review): this command points at tmp_2/, not the tmp/ prefix trained
# above -- confirm which run is intended.
# Tail of a `new_model.train(...)` call; the opening of the call (and the
# training-data arguments) is not visible in this excerpt.
val_images="dataset/block_val",
val_annotations="dataset/block_val_anno",
val_steps_per_epoch=32,
val_batch_size=2,
# Same prefix is later passed to find_latest_checkpoint("new_model") below.
checkpoints_path="new_model",
load_weights=None,  # "new_model.80",
# load_weights='new_model.10'  to resume training from that checkpoint
# Default 2. If batch_size is too large, the GPU runs out of memory.
batch_size=2,
# Default 512. Number of batches per epoch; None means split automatically,
# i.e. number of dataset samples / batch size.
steps_per_epoch=128,
epochs=140,
)  # Recorded GPU error: OOM when allocating tensor with shape[2,4096,90,90]

# Report wall-clock training time. Assumes start_time was captured with
# datetime.datetime.now() before the train() call (not visible here).
# NOTE(review): timedelta.seconds excludes whole days; total_seconds() would be
# accurate for runs longer than 24 hours.
duration = datetime.datetime.now() - start_time
print('-------- Training time (s): {}'.format(duration.seconds))

# Reload the newest checkpoint written under the "new_model" prefix.
# NOTE(review): with its default fail-safe behaviour find_latest_checkpoint
# can return None when no checkpoint exists -- confirm one was written.
new_model_checkpoint = find_latest_checkpoint("new_model")
new_model.load_weights(new_model_checkpoint)

# Segment one validation image; the prediction is also written to out.png.
out = new_model.predict_segmentation(inp="dataset/block_val/003_h3_w6.jpg", out_fname="out.png")
# Convert to float so that plt.imshow treats the image as lying in [0, 1],
# i.e. 1 is shown as white.
out = np.array(out).astype(np.float32)
plt.imshow(out)
# Disabled experiments kept by the author (colorized display, early exit):
##out_color = cv2.imread("out2.png")
#out_color = get_colored_segmentation_image(out,n_classes=2)
#out_color = np.array(out_color).astype(np.int32)
#plt.imshow(out_color)
#import sys
#sys.exit()