def _test_save_load(model, X, optimizer_name, optimizer):
    saved_model_preds = model.predict(X[0])
    saved_model_weights = K.batch_get_value(model.trainable_weights)
    saved_optim_weights = K.batch_get_value(model.optimizer.weights)

    test_name = 'test__%f{}.h5'.format(np.random.random())
    modelpath = os.path.join(tempfile.gettempdir(), test_name)
    model.save(modelpath)
    del model
    if TF_2 and not TF_EAGER and not TF_KERAS:
        tf.compat.v1.experimental.output_all_intermediates(True)  # bug fix

    model = load_model(modelpath, custom_objects={optimizer_name: optimizer})
    loaded_model_preds = model.predict(X[0])
    loaded_model_weights = K.batch_get_value(model.trainable_weights)
    loaded_optim_weights = K.batch_get_value(model.optimizer.weights)

    assert np.allclose(saved_model_preds,
                       loaded_model_preds,
                       rtol=0,
                       atol=1e-8)
    for smw, lmw in zip(saved_model_weights, loaded_model_weights):
        assert np.allclose(smw, lmw, rtol=0, atol=1e-8)
    for sow, low in zip(saved_optim_weights, loaded_optim_weights):
        assert np.allclose(sow, low, rtol=0, atol=1e-8)
Beispiel #2
0
def _validate_save_load(tg, C):
    def _get_load_path(tg, logdir):
        for postfix in ('weights', 'model', 'model_noopt'):
            postfix += '.h5'
            path = [str(p) for p in Path(logdir).iterdir()
                    if p.name.endswith(postfix)]
            if path:
                return path[0]
        raise Exception(f"no model save file found in {logdir}")

    # get behavior before saving, to ensure no changes presave-to-postload
    data = np.random.randn(*tg.model.input_shape)

    Wm_save = tg.model.get_weights()
    Wo_save = K.batch_get_value(tg.model.optimizer.weights)
    preds_save = tg.model.predict(data, batch_size=len(data))

    tg.checkpoint()
    logdir = tg.logdir
    tg.destroy(confirm=True)

    C['traingen']['logdir'] = logdir
    path = _get_load_path(tg, logdir)
    if path.endswith('weights.h5'):
        model = make_classifier(**C['model'])
        model.load_weights(path)
    else:
        model = load_model(path)
    tg = init_session(C, model=model)
    tg.load()

    Wm_load = tg.model.get_weights()
    Wo_load = K.batch_get_value(tg.model.optimizer.weights)
    preds_load = tg.model.predict(data, batch_size=len(data))

    for s, l in zip(Wm_save, Wm_load):
        assert np.allclose(s, l), "max absdiff: %s" % np.max(np.abs(s - l))
    for s, l in zip(Wo_save, Wo_load):
        assert np.allclose(s, l), "max absdiff: %s" % np.max(np.abs(s - l))
    assert np.allclose(preds_save, preds_load), (
          "max absdiff: %s" % np.max(np.abs(preds_save - preds_load)))