def load_model_from_run(backbone_name, load_model_from, load_weights_from=None, skip_mismatch=True):
    """Rebuild the model saved under a run name, optionally loading other weights.

    The architecture JSON of ``load_model_from`` is used when it exists;
    otherwise the run's weights file is loaded as a complete model
    (uncompiled).

    Args:
        backbone_name: Name passed to ``backbone()``; the returned object's
            ``custom_objects`` are supplied to the deserializer.
        load_model_from: Run name whose JSON / weights file defines the model.
        load_weights_from: Optional run name whose weights file is loaded
            into the model by layer name.
        skip_mismatch: Forwarded to ``model.load_weights``.

    Returns:
        The constructed model.

    Raises:
        ValueError: If neither the JSON file nor the weights file of
            ``load_model_from`` exists.
    """
    net_backbone = backbone(backbone_name)

    architecture_path = get_json_filename(load_model_from)
    if os.path.exists(architecture_path):
        # Architecture only: the model is rebuilt from its JSON description.
        with open(architecture_path, 'r') as architecture_file:
            architecture = architecture_file.read()
        model = model_from_json(architecture, custom_objects=net_backbone.custom_objects)
    else:
        # No JSON for this run: fall back to loading the whole model file.
        full_model_path = get_weights_filename(load_model_from)
        if not os.path.exists(full_model_path):
            raise ValueError("run with name %s doesn't exist" % load_model_from)
        model = load_model(full_model_path, custom_objects=net_backbone.custom_objects, compile=False)

    if load_weights_from:
        weights_path = get_weights_filename(load_weights_from)
        # NOTE(review): a missing weights file is skipped silently here,
        # unlike load_model_weights_from which raises - confirm intended.
        if os.path.exists(weights_path):
            model.load_weights(weights_path, by_name=True, skip_mismatch=skip_mismatch)

    return model
def save_model_to_run(model, run_name):
    """Save a model's architecture (JSON) and weights under a run name.

    Args:
        model: Object providing ``to_json()`` and ``save_weights(path)``.
        run_name: Run name; the target paths come from
            ``get_json_filename`` and ``get_weights_filename``.
    """
    json_path = get_json_filename(run_name)
    h5_path = get_weights_filename(run_name)

    # Serialize before opening the file: mode 'w' truncates on open, so a
    # failing to_json() would otherwise leave an empty JSON file behind, which
    # load_model_from_run would then prefer over the weights-file fallback.
    architecture = model.to_json()
    with open(json_path, 'w') as json_file:
        json_file.write(architecture)

    model.save_weights(h5_path)
def load_model_weights_from(model, weights, skip_mismatch):
    """Load weights into ``model`` from a file path or a run name.

    ``weights`` is first tried as a filesystem path; if nothing exists there
    it is treated as a run name and resolved with ``get_weights_filename``.
    Weights are loaded by layer name.

    Args:
        model: Object providing ``load_weights``.
        weights: Path or run name, or ``None`` to do nothing.
        skip_mismatch: Forwarded to ``model.load_weights``.

    Raises:
        ValueError: If neither the given path nor the resolved run weights
            file exists.
    """
    if weights is None:
        return

    if os.path.exists(weights):
        source = weights
    else:
        # Not an existing path: interpret the value as a run name.
        run_weights = get_weights_filename(weights)
        if not os.path.exists(run_weights):
            raise ValueError('Unknown weights to load from!', weights, run_weights)
        source = run_weights

    model.load_weights(source, by_name=True, skip_mismatch=skip_mismatch)
def config_seg_callbacks(run_name=None):
    """Build the list of training callbacks.

    Always includes a ``ValidationPrediction`` callback and a
    ``ReduceLROnPlateau`` on ``val_loss``. When ``run_name`` is truthy, a
    best-only weights ``ModelCheckpoint`` and a ``CSVLogger`` targeting that
    run's files are appended.

    Args:
        run_name: Optional run name used to derive the checkpoint and CSV
            log paths.

    Returns:
        A list of callback instances.
    """
    prediction_cb = ValidationPrediction(show_confusion_matrix=False)
    lr_schedule_cb = ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=2,
        verbose=1,
        mode='auto',
        min_lr=1e-7,
    )
    result = [prediction_cb, lr_schedule_cb]

    if not run_name:
        return result

    # Persisting callbacks are only meaningful when a run name is given.
    checkpoint_cb = ModelCheckpoint(
        get_weights_filename(run_name),
        monitor='val_loss',
        save_best_only=True,
        save_weights_only=True,
        verbose=True,
    )
    csv_cb = CSVLogger(filename=get_csv_filename(run_name))
    result += [checkpoint_cb, csv_cb]
    return result