def _test_save_load(model, X, optimizer_name, optimizer):
    """Round-trip ``model`` through an .h5 file and assert nothing changed.

    Records predictions on ``X[0]``, the trainable weights and the optimizer
    weights. It then saves the model to a temporary .h5 file and reloads it
    with ``load_model(..., custom_objects={optimizer_name: optimizer})``.
    Finally it asserts that all three match to ``atol=1e-8`` (``rtol=0``).
    The temporary file is removed afterwards on a best-effort basis.

    Args:
        model: compiled model to round-trip. The local reference is deleted
            after saving so that only the reloaded model is inspected.
        X: indexable input; only ``X[0]`` is fed to ``predict``.
        optimizer_name: name under which ``optimizer`` is registered in
            ``custom_objects`` for deserialization.
        optimizer: custom optimizer object passed to ``load_model``
            (presumably the optimizer class; confirm against callers).

    Raises:
        AssertionError: if predictions or any weight differ, or if the number
            of trainable/optimizer weights changed across the round-trip.
    """
    import contextlib

    saved_model_preds = model.predict(X[0])
    saved_model_weights = K.batch_get_value(model.trainable_weights)
    saved_optim_weights = K.batch_get_value(model.optimizer.weights)

    # The random suffix keeps repeated runs from reusing the same file.
    # The original template 'test__%f{}.h5' left a literal '%f' in the name.
    test_name = 'test__{}.h5'.format(np.random.random())
    modelpath = os.path.join(tempfile.gettempdir(), test_name)
    model.save(modelpath)
    del model

    try:
        if TF_2 and not TF_EAGER and not TF_KERAS:
            tf.compat.v1.experimental.output_all_intermediates(True)  # bug fix
        model = load_model(modelpath,
                           custom_objects={optimizer_name: optimizer})
        loaded_model_preds = model.predict(X[0])
        loaded_model_weights = K.batch_get_value(model.trainable_weights)
        loaded_optim_weights = K.batch_get_value(model.optimizer.weights)
    finally:
        # Best-effort cleanup: a failure to delete must not mask the result.
        with contextlib.suppress(OSError):
            os.remove(modelpath)

    assert np.allclose(saved_model_preds, loaded_model_preds,
                       rtol=0, atol=1e-8)

    # zip() stops at the shorter sequence, so check counts first; otherwise
    # dropped or extra weights would pass unnoticed.
    assert len(saved_model_weights) == len(loaded_model_weights), (
        "trainable weight count mismatch: %s saved vs %s loaded"
        % (len(saved_model_weights), len(loaded_model_weights)))
    for smw, lmw in zip(saved_model_weights, loaded_model_weights):
        assert np.allclose(smw, lmw, rtol=0, atol=1e-8)

    assert len(saved_optim_weights) == len(loaded_optim_weights), (
        "optimizer weight count mismatch: %s saved vs %s loaded"
        % (len(saved_optim_weights), len(loaded_optim_weights)))
    for sow, low in zip(saved_optim_weights, loaded_optim_weights):
        assert np.allclose(sow, low, rtol=0, atol=1e-8)
def _validate_save_load(tg, C):
    """Checkpoint ``tg``, rebuild a session from the saved file, and verify it.

    Captures model weights, optimizer weights and predictions on random
    input. It then calls ``tg.checkpoint()`` and destroys ``tg``. Next it
    loads the saved file from the same logdir and builds a new session via
    ``init_session(C, model=model)``, followed by ``.load()``. Finally it
    asserts that all captured values match the reloaded ones.

    A file ending in 'weights.h5' is loaded into a fresh
    ``make_classifier(**C['model'])``. Any other match is loaded with
    ``load_model``.

    Args:
        tg: training-session object exposing ``model``, ``checkpoint()``,
            ``logdir``, ``destroy(confirm=...)`` and ``load()``.
        C: config dict with at least ``C['traingen']`` (a dict) and
            ``C['model']``. NOTE: ``C['traingen']['logdir']`` is overwritten
            in place.

    Raises:
        FileNotFoundError: if no save file with a recognized suffix exists in
            the logdir. This is a subclass of the ``Exception`` raised before.
        AssertionError: if any weight count, weight value or prediction
            differs between save and load.
    """
    def _get_load_path(logdir):
        """Return a saved-model path in ``logdir``, by suffix priority.

        Suffixes are tried in the order 'weights.h5', 'model.h5',
        'model_noopt.h5'. The directory is listed once and sorted, so when
        several files share a suffix the choice is deterministic.
        """
        files = sorted(Path(logdir).iterdir())
        for postfix in ('weights', 'model', 'model_noopt'):
            suffix = postfix + '.h5'
            matches = [str(p) for p in files if p.name.endswith(suffix)]
            if matches:
                return matches[0]
        raise FileNotFoundError(f"no model save file found in {logdir}")

    def _assert_all_close(saved, loaded, what):
        """Assert equal counts, then elementwise closeness, of two lists."""
        # zip() would silently truncate on a count mismatch.
        assert len(saved) == len(loaded), (
            "%s count mismatch: %s saved vs %s loaded"
            % (what, len(saved), len(loaded)))
        for s, ld in zip(saved, loaded):
            assert np.allclose(s, ld), "max absdiff: %s" % np.max(np.abs(s - ld))

    # get behavior before saving, to ensure no changes presave-to-postload
    # assumes model.input_shape is fully defined (no None dims); TODO confirm
    data = np.random.randn(*tg.model.input_shape)
    Wm_save = tg.model.get_weights()
    Wo_save = K.batch_get_value(tg.model.optimizer.weights)
    preds_save = tg.model.predict(data, batch_size=len(data))

    tg.checkpoint()
    logdir = tg.logdir
    tg.destroy(confirm=True)
    # Point the new session at the directory the checkpoint was written to.
    C['traingen']['logdir'] = logdir

    path = _get_load_path(logdir)
    if path.endswith('weights.h5'):
        # Weights-only file: rebuild the architecture, then load weights.
        model = make_classifier(**C['model'])
        model.load_weights(path)
    else:
        model = load_model(path)

    tg = init_session(C, model=model)
    tg.load()

    Wm_load = tg.model.get_weights()
    Wo_load = K.batch_get_value(tg.model.optimizer.weights)
    preds_load = tg.model.predict(data, batch_size=len(data))

    _assert_all_close(Wm_save, Wm_load, "model weight")
    _assert_all_close(Wo_save, Wo_load, "optimizer weight")
    assert np.allclose(preds_save, preds_load), (
        "max absdiff: %s" % np.max(np.abs(preds_save - preds_load)))