コード例 #1
0
ファイル: test_collect.py プロジェクト: mehrdad-shokri/ludwig
def test_collect_activations(csv_filename):
    """Check that activations can be collected from a freshly trained model.

    Trains a model on data prepared from ``csv_filename`` (via the
    ``_prepare_data`` / ``_train`` helpers) and asserts that ``_get_layers``
    finds at least one layer in the saved model. It then runs
    ``collect_activations`` on those layers into a temporary directory and
    asserts that more output files than layers are returned. If training
    returned a model, its experiment directory is removed in ``finally``,
    even when an assertion fails.

    :param csv_filename: path to the CSV file used both for training and as
        ``data_csv`` for activation collection (presumably supplied by a
        pytest fixture -- confirm).
    """
    model = None
    try:
        # This will reset the layer numbering scheme TensorFlow uses.
        # Otherwise, when we load the model, its layer names will be appended
        # with "_1".
        tf.keras.backend.reset_uids()

        model = _train(*_prepare_data(csv_filename))
        model_path = os.path.join(model.exp_dir_name, 'model')

        layers = _get_layers(model_path)
        assert len(layers) > 0

        # Reset again before collect_activations loads the model, presumably
        # for the same reason as above: so the reloaded layer names match the
        # names returned in `layers`.
        tf.keras.backend.reset_uids()
        with tempfile.TemporaryDirectory() as output_directory:
            filenames = collect_activations(model_path,
                                            layers,
                                            data_csv=csv_filename,
                                            output_directory=output_directory)
            # NOTE(review): `>` rather than `==` suggests collect_activations
            # may write more than one file per layer -- confirm.
            assert len(filenames) > len(layers)
    finally:
        # Explicit None check (PEP 8): only "training never produced a model"
        # should skip cleanup, not a model object that happens to be falsy.
        if model is not None and model.exp_dir_name:
            shutil.rmtree(model.exp_dir_name, ignore_errors=True)
コード例 #2
0
def _collect_activations(model_path, layers, csv_filename, output_directory):
    """Run ``collect_activations`` for ``layers`` of the model at ``model_path``.

    Thin wrapper: ``model_path`` and ``layers`` are passed positionally,
    ``csv_filename`` is forwarded as the ``data_csv`` keyword, and
    ``output_directory`` is forwarded unchanged under the same keyword.

    :returns: whatever ``collect_activations`` returns.
    """
    forwarded_options = {
        'data_csv': csv_filename,
        'output_directory': output_directory,
    }
    return collect_activations(model_path, layers, **forwarded_options)