示例#1
0
文件: io.py 项目: zectre/delta
def save_model(model, filename):
    """
    Write a model to disk as HDF5 and embed the DELTA configuration in it.

    Keras writes the file first. The file is then reopened read/write, and the
    string returned by ``config.export()`` is stored in the root attribute
    ``'delta'``.

    Parameters
    ----------
    model
        Object with a Keras-style ``save(filename, save_format=...)`` method.
    filename
        Destination path of the HDF5 file.
    """
    model.save(filename, save_format='h5')
    with h5py.File(filename, 'r+') as h5_file:
        exported_config = config.export()
        h5_file.attrs['delta'] = exported_config
示例#2
0
def test_save(tmp_path):
    """
    Saving a small conv model as HDF5 must embed the exported DELTA config.
    """
    model_file = tmp_path / 'temp.h5'
    image_input = keras.layers.Input((None, None, 1))
    conv_output = keras.layers.Conv2D(filters=9, kernel_size=2)(image_input)
    net = keras.Model(image_input, conv_output)
    io.save_model(net, model_file)
    with h5py.File(model_file, 'r') as saved:
        stored_config = saved.attrs['delta']
        assert stored_config == config.export()
示例#3
0
文件: train.py 项目: zectre/delta
def _mlflow_train_setup(model, dataset, training_spec):
    """
    Start an MLflow run and record the DELTA configuration as an artifact.

    Points MLflow at the configured tracking URI and experiment, starts a run,
    logs training parameters, and uploads the output of ``config.export()``
    as ``config.yaml``.

    Parameters
    ----------
    model
        Forwarded to ``_log_mlflow_params``.
    dataset
        Forwarded to ``_log_mlflow_params``.
    training_spec
        Forwarded to ``_log_mlflow_params``.

    Returns
    -------
    _MLFlowCallback
        Callback constructed with the temporary directory created here; it
        presumably owns that directory from then on (verify in _MLFlowCallback).
    """
    import shutil  # local import: only needed to clean up after a failure

    mlflow.set_tracking_uri(config.mlflow.uri())
    mlflow.set_experiment(config.mlflow.experiment())
    mlflow.start_run()
    _log_mlflow_params(model, dataset, training_spec)

    temp_dir = tempfile.mkdtemp()
    try:
        fname = os.path.join(temp_dir, 'config.yaml')
        # Explicit encoding so the artifact does not depend on the platform locale.
        with open(fname, 'w', encoding='utf-8') as f:
            f.write(config.export())
        mlflow.log_artifact(fname)
        os.remove(fname)
    except BaseException:
        # The callback never receives temp_dir if we fail here, so nothing else
        # would remove it; clean up, then propagate the original error.
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise

    return _MLFlowCallback(temp_dir)
示例#4
0
def save_model(model, filename):
    """
    Persist a Keras model in HDF5 format together with the DELTA configuration.

    The exported configuration string (``config.export()``) is attached to the
    saved file as the root-level HDF5 attribute ``'delta'``.

    Parameters
    ----------
    model: tensorflow.keras.models.Model
        Model to be written.
    filename: str
        Path of the HDF5 file to create.
    """
    model.save(filename, save_format='h5')
    with h5py.File(filename, 'r+') as handle:
        root_attrs = handle.attrs
        root_attrs['delta'] = config.export()
示例#5
0
文件: io.py 项目: nasa/delta
def save_model(model, filename):
    """
    Save a model. Includes DELTA configuration.

    Paths ending in ``.h5`` or ``.hdf5`` are saved as a single HDF5 file, and the
    configuration is stored in the file's root attribute ``'delta'``. Any other
    path is saved in the SavedModel format (a directory). In that case the
    configuration is written to ``<filename>/assets.extra/delta_config.yaml``.

    Parameters
    ----------
    model: tensorflow.keras.models.Model
        The model to save.
    filename: str or os.PathLike
        Output filename (HDF5 file) or output directory (SavedModel).
    """
    # Keras infers HDF5 from both '.h5' and '.hdf5'. Treat both as HDF5 here.
    # Otherwise an '.hdf5' path falls into the SavedModel branch, where Keras
    # still writes a file and creating 'assets.extra' inside it fails.
    if str(filename).endswith(('.h5', '.hdf5')):
        model.save(filename, save_format='h5')
        with h5py.File(filename, 'r+') as f:
            f.attrs['delta'] = config.export()
    else:  # SavedModel format
        model.save(filename)
        # Record the config values into a subfolder of the SavedModel output folder.
        config_copy_folder = os.path.join(filename, 'assets.extra')
        config_copy_path = os.path.join(config_copy_folder, 'delta_config.yaml')
        # exist_ok replaces the racy os.path.exists() + os.mkdir() pair.
        os.makedirs(config_copy_folder, exist_ok=True)
        with open(config_copy_path, 'w', encoding='utf-8') as f:
            f.write(config.export())
示例#6
0
def test_dump():
    """
    Round-trip check: parsing the exported YAML must reproduce ``config.to_dict()``.
    """
    config_reset()

    # yaml.load() without an explicit Loader is deprecated since PyYAML 5.1 and a
    # TypeError since PyYAML 6.0. safe_load assumes the export uses only standard
    # YAML types (no Python-specific tags).
    assert config.to_dict() == yaml.safe_load(config.export())