def test_module_save_and_load_roundtrip(self, basic_object, pickle_only, compress_save_file):
    """Save an object, load it back, and compare the two states.

    The object is built with ``from_config=True`` and saved into a
    temporary directory, passing the ``compress_save_file`` and
    ``pickle_only`` flags through to ``save``. It is then reloaded from the
    path the test expects for that flag combination. Both the state
    mappings and their ``_metadata`` (including config) must match.
    """
    original = basic_object(from_config=True)
    with tempfile.TemporaryDirectory() as tmp_dir:
        target = os.path.join(tmp_dir, 'savefile.flambe')
        save(original, target, compress_save_file, pickle_only)
        # The test expects the save flags to change the artifact's file
        # name, so append the matching suffixes (pickle first, then
        # archive) before loading it back.
        suffixes = []
        if pickle_only:
            suffixes.append('.pkl')
        if compress_save_file:
            suffixes.append('.tar.gz')
        restored = load(target + ''.join(suffixes))
    expected_state = original.get_state()
    actual_state = restored.get_state()
    check_mapping_equivalence(actual_state, expected_state)
    check_mapping_equivalence(expected_state._metadata, actual_state._metadata,
                              exclude_config=False)
    def test_module_save_and_load_single_instance_appears_twice(self, make_classes_2):
        txt = """
!C
one: !A
  akw2: &theb !B
    bkw2: test
    bkw1: 1
  akw1: 8
two: !A
  akw1: 8
  # Comment Here
  akw2: *theb
"""
        c = yaml.load(txt)()
        c.one.akw2.bkw1 = 6
        assert c.one.akw2 is c.two.akw2
        assert c.one.akw2.bkw1 == c.two.akw2.bkw1
        with tempfile.TemporaryDirectory() as path:
            save(c, path)
            state = load_state_from_file(path)
            loaded_c = load(path)
        assert loaded_c.one.akw2 is loaded_c.two.akw2
        assert loaded_c.one.akw2.bkw1 == loaded_c.two.akw2.bkw1
Example #3
0
def test_exporter_builder():
    """Export a model with an experiment, then build an inference engine from it.

    1. Write an ``!Experiment`` config (SST dataset, ``TextClassifier``
       model, ``!Exporter`` of that model) whose ``save_path`` is a
       temporary directory, and run it through the ``flambe`` CLI.
    2. Locate the exported model checkpoint and write a ``!Builder`` config
       that wraps it in the ``DummyInferenceEngine`` test extension, with
       a second temporary directory as its destination. Run that through
       the CLI too.
    3. Load the built engine with both ``flambe.load`` and
       ``DummyInferenceEngine.load_from_path``. Check the engine and model
       types, the recorded extension path, and that both loads produce
       equal models.
    """
    with tmpdir() as d, tmpdir() as d2, \
            tmpfile(mode="w", suffix=".yaml") as f, \
            tmpfile(mode="w", suffix=".yaml") as f2:
        # First run an experiment
        exp = """
!Experiment

name: exporter
save_path: {}

pipeline:
  dataset: !SSTDataset
    transform:
      text: !TextField
      label: !LabelField
  model: !TextClassifier
    embedder: !Embedder
      embedding: !torch.Embedding
        num_embeddings: !@ dataset.text.vocab_size
        embedding_dim: 30
      encoder: !PooledRNNEncoder
        input_size: 30
        rnn_type: lstm
        n_layers: 1
        hidden_size: 16
    output_layer: !SoftmaxLayer
      input_size: !@ model[embedder].encoder.rnn.hidden_size
      output_size: !@ dataset.label.vocab_size

  exporter: !Exporter
    model: !@ model
    text: !@ dataset.text
"""

        exp = exp.format(d)
        f.write(exp)
        # Flush so the CLI subprocess, which reopens the file by name,
        # sees the full config.
        f.flush()
        ret = subprocess.run(['flambe', f.name, '-i'])
        assert ret.returncode == 0

        # Then run a builder

        builder = """
flambe_inference: tests/data/dummy_extensions/inference/
---

!Builder

destination: {0}

component: !flambe_inference.DummyInferenceEngine
  model: !TextClassifier.load_from_path
    path: {1}
"""
        base = os.path.join(d, "output__exporter", "exporter")
        # Without variants the exporter stage should produce exactly one run
        # folder. os.listdir order is arbitrary, so blindly taking [0] would
        # pick an arbitrary folder if that assumption broke (or raise a bare
        # IndexError if none existed). Check the assumption explicitly.
        run_dirs = [
            x for x in os.listdir(base) if os.path.isdir(os.path.join(base, x))
        ]
        assert len(run_dirs) == 1, (
            "expected exactly one run folder in {}, found {}".format(base, run_dirs))
        model_path = os.path.join(base, run_dirs[0], "checkpoint",
                                  "checkpoint.flambe", "model")

        builder = builder.format(d2, model_path)
        f2.write(builder)
        f2.flush()

        ret = subprocess.run(['flambe', f2.name, '-i'])
        assert ret.returncode == 0

        # The extensions needs to be imported using extensions.py module
        extensions.import_modules(["flambe_inference"])

        # Import the module after import_modules (which registered tags already)
        from flambe_inference import DummyInferenceEngine

        eng1 = flambe.load(d2)

        # Exact type checks are deliberate: a subclass would indicate the
        # wrong component was built.
        assert type(eng1) is DummyInferenceEngine
        assert type(eng1.model) is TextClassifier

        extension_path = os.path.join(
            os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
            "tests/data/dummy_extensions/inference")
        assert eng1._extensions == {"flambe_inference": extension_path}

        eng2 = DummyInferenceEngine.load_from_path(d2)

        assert type(eng2) is DummyInferenceEngine
        assert type(eng2.model) is TextClassifier

        assert eng2._extensions == {"flambe_inference": extension_path}

        assert module_equals(eng1.model, eng2.model)