コード例 #1
0
ファイル: test_Network.py プロジェクト: rwth-i6/returnn
def test_enc_dec1_init():
  """Check that a LayerNetwork can be built from `config_enc_dec1_json`.

  Both the JSON topology extracted from the config and the network
  constructed from it must be truthy.
  """
  cfg = Config()
  cfg.load_file(StringIO(config_enc_dec1_json))

  # Step 1: extract the network topology description from the config.
  topology = LayerNetwork.json_from_config(cfg)
  assert_true(topology)

  # Step 2: instantiate the network from that topology plus the config.
  built = LayerNetwork.from_json_and_config(topology, cfg)
  assert_true(built)
コード例 #2
0
def test_enc_dec1_init():
    """Check that a LayerNetwork can be built from `config_enc_dec1_json`.

    Asserts that the JSON topology derived from the config is truthy and
    that constructing the network from it yields a truthy object.
    """
    cfg = Config()
    cfg.load_file(StringIO(config_enc_dec1_json))

    topology = LayerNetwork.json_from_config(cfg)
    assert_true(topology)

    # Construct the network from the topology and check it in one step.
    assert_true(LayerNetwork.from_json_and_config(topology, cfg))
コード例 #3
0
ファイル: test_Network.py プロジェクト: rwth-i6/returnn
def test_enc_dec1_hdf():
  """Round-trip the network built from `config_enc_dec1_json` through HDF5.

  Saves the network with `save_hdf` (epoch=42), reloads only its topology
  via `LayerNetwork.from_hdf_model_topology`, and checks that the key sets of
  `hidden`, `y` and `j` match between the original and the reloaded network.
  The temporary file is always removed, even when an assertion fails.
  """
  # mkstemp creates the file atomically (mktemp only returns a name and is
  # race-prone). h5py reopens by path, so release the OS-level handle now.
  fd, filename = tempfile.mkstemp(prefix="crnn-model-test")
  os.close(fd)
  try:
    config = Config()
    config.load_file(StringIO(config_enc_dec1_json))
    network_json = LayerNetwork.json_from_config(config)
    assert_true(network_json)
    network = LayerNetwork.from_json_and_config(network_json, config)
    assert_true(network)

    # "w" truncates the empty file created by mkstemp.
    with h5py.File(filename, "w") as model:
      network.save_hdf(model, epoch=42)

    # Keep the reader open only while comparing; closing it before removal
    # avoids a leaked handle (and a failing os.remove on Windows).
    with h5py.File(filename, "r") as loaded_model:
      loaded_net = LayerNetwork.from_hdf_model_topology(loaded_model)
      assert_true(loaded_net)
      assert_equal(sorted(network.hidden.keys()), sorted(loaded_net.hidden.keys()))
      assert_equal(sorted(network.y.keys()), sorted(loaded_net.y.keys()))
      assert_equal(sorted(network.j.keys()), sorted(loaded_net.j.keys()))
  finally:
    os.remove(filename)
コード例 #4
0
def test_enc_dec1_hdf():
    """Round-trip the network built from `config_enc_dec1_json` through HDF5.

    Saves the network with `save_hdf` (epoch=42), reloads only its topology
    via `LayerNetwork.from_hdf_model_topology`, and checks that the key sets
    of `hidden`, `y` and `j` match between the original and reloaded network.
    The temporary file is always removed, even when an assertion fails.
    """
    # mkstemp creates the file atomically (mktemp only returns a name and is
    # race-prone). h5py reopens by path, so release the OS-level handle now.
    fd, filename = tempfile.mkstemp(prefix="crnn-model-test")
    os.close(fd)
    try:
        config = Config()
        config.load_file(StringIO(config_enc_dec1_json))
        network_json = LayerNetwork.json_from_config(config)
        assert_true(network_json)
        network = LayerNetwork.from_json_and_config(network_json, config)
        assert_true(network)

        # "w" truncates the empty file created by mkstemp.
        with h5py.File(filename, "w") as model:
            network.save_hdf(model, epoch=42)

        # Keep the reader open only while comparing; closing it before
        # removal avoids a leaked handle (and a failing os.remove on Windows).
        with h5py.File(filename, "r") as loaded_model:
            loaded_net = LayerNetwork.from_hdf_model_topology(loaded_model)
            assert_true(loaded_net)
            assert_equal(sorted(network.hidden.keys()),
                         sorted(loaded_net.hidden.keys()))
            assert_equal(sorted(network.y.keys()),
                         sorted(loaded_net.y.keys()))
            assert_equal(sorted(network.j.keys()),
                         sorted(loaded_net.j.keys()))
    finally:
        os.remove(filename)