Esempio n. 1
0
 def test_save_to_file_and_load_from_file_roundtrip_complex_nontorch_root(
         self, complex_multi_layered_nontorch_root, pickle_only,
         compress_save_file):
     """Round-trip a non-torch-rooted, multi-layered object through a file.

     The state of a freshly built object is written with
     ``save_state_to_file`` (honouring the ``pickle_only`` and
     ``compress_save_file`` parametrizations), read back with
     ``load_state_from_file`` and compared with the original. The loaded
     state is then pushed into a second, independently built object whose
     resulting state must match the first object's.
     """
     tag_prefix = "torch"
     # Register torch.nn modules as components under the "torch" tag prefix
     # before the fixture builds objects from config.
     make_component(torch.nn.Module,
                    tag_prefix,
                    only_module='torch.nn')
     source = complex_multi_layered_nontorch_root(from_config=True)
     saved_state = source.get_state()
     with tempfile.TemporaryDirectory() as tmp_dir:
         base_path = os.path.join(tmp_dir, 'savefile.flambe')
         save_state_to_file(saved_state, base_path, compress_save_file,
                            pickle_only)
         list_files(base_path)
         # The test expects the written file to carry these suffixes when
         # the corresponding flags are set, so read back from that name.
         load_path = (base_path
                      + ('.pkl' if pickle_only else '')
                      + ('.tar.gz' if compress_save_file else ''))
         reloaded = load_state_from_file(load_path)
         check_mapping_equivalence(saved_state, reloaded)
         check_mapping_equivalence(saved_state._metadata, reloaded._metadata)
     target = complex_multi_layered_nontorch_root(from_config=True)
     # Snapshot the target before it is overwritten; its metadata is
     # compared with what was read from disk at the end.
     pristine_state = target.get_state()
     target.load_state(reloaded, strict=False)
     source_state = source.get_state()
     target_state = target.get_state()
     check_mapping_equivalence(target_state, source_state)
     check_mapping_equivalence(source_state._metadata, target_state._metadata)
     check_mapping_equivalence(pristine_state._metadata, reloaded._metadata)
Esempio n. 2
0
    def test_save_to_file_and_load_from_file_with_extensions(
            self, mock_add_extensions, mock_import_module,
            mock_installed_module, compress_save_file, pickle_only, schema):
        """Check that extensions survive a save/load round trip.

        Extensions attached to the schema must be written alongside the
        saved state, and must be registered again (via the patched
        ``add_extensions``) when the object is loaded back from disk.
        """
        mock_installed_module.return_value = True

        built_schema = schema()

        # Attach the extensions directly: add_extensions_metadata would add
        # none here, because the schema does not contain any prefix.
        built_schema._extensions = TestSerializationExtensions.EXTENSIONS

        instance = built_schema()
        snapshot = instance.get_state()

        with tempfile.TemporaryDirectory() as tmp_dir:
            target = os.path.join(tmp_dir, 'savefile.flambe')
            save_state_to_file(snapshot, target, compress_save_file,
                               pickle_only)

            list_files(target)
            # Read back from the name the test expects for these flags.
            target += (('.pkl' if pickle_only else '')
                       + ('.tar.gz' if compress_save_file else ''))
            restored = load_state_from_file(target)

            check_mapping_equivalence(snapshot, restored)
            check_mapping_equivalence(snapshot._metadata, restored._metadata)

            # Loading through the public entry point must re-register the
            # saved extensions exactly once.
            _ = Basic.load_from_path(target)
            mock_add_extensions.assert_called_once_with(
                TestSerializationExtensions.EXTENSIONS)
Esempio n. 3
0
 def test_save_to_file_and_load_from_file_roundtrip_complex(
         self, complex_multi_layered):
     """Round-trip a multi-layered object containing torch modules.

     A value deep inside the object is mutated after construction, so the
     test confirms that the *current* state, not just the config-time
     state, is persisted and restored. This includes a Component-only
     child nested beneath torch objects.
     """
     tag_prefix = "torch"
     make_component(torch.nn.Module,
                    tag_prefix,
                    only_module='torch.nn')
     original = complex_multi_layered(from_config=True)
     original.child.child.child.x = 24
     persisted = original.get_state()
     with tempfile.TemporaryDirectory() as save_dir:
         save_state_to_file(persisted, save_dir)
         list_files(save_dir)
         restored = load_state_from_file(save_dir)
         check_mapping_equivalence(persisted, restored)
     clone = complex_multi_layered(from_config=True)
     clone.load_state(restored, strict=False)
     original_state = original.get_state()
     clone_state = clone.get_state()
     check_mapping_equivalence(clone_state, original_state)
     check_mapping_equivalence(original_state._metadata,
                               clone_state._metadata,
                               exclude_config=False)
Esempio n. 4
0
 def test_module_save_and_load_roundtrip_pytorch_only_bridge(self):
     """Save an ``Org`` holding a torch-only composite child and load it back.

     ``old_obj`` and ``new_obj`` are built from the same classes but with
     different leaf values (``x`` 3 vs 4, ``b`` 100 vs 101, ``y`` 0 vs 1).
     ``old_obj``'s state is written to a file, read back, and loaded into
     ``new_obj``. Afterwards the test compares:

     - the saved state with the state read back from disk, and
     - the current states of the two objects.
     """
     a = BasicStateful.compile(x=3)
     b = 100
     c = BasicStatefulTwo.compile(y=0)
     item = ComposableTorchStatefulTorchOnlyChild.compile(a=a, b=b, c=c)
     # Both Org instances are compiled with extra=None, so no extra torch
     # module needs to be constructed for either of them.
     old_obj = Org.compile(item=item, extra=None)
     # x for a2 should be different from instance a
     a2 = BasicStateful.compile(x=4)
     b2 = 101
     # y for c2 should be different from instance c
     c2 = BasicStatefulTwo.compile(y=1)
     item2 = ComposableTorchStatefulTorchOnlyChild.compile(a=a2, b=b2, c=c2)
     new_obj = Org.compile(item=item2, extra=None)
     with tempfile.TemporaryDirectory() as root_path:
         path = os.path.join(root_path, 'asavefile2.flambe')
         old_state = old_obj.get_state()
         save_state_to_file(old_state, path)
         new_state = load_state_from_file(path)
         new_obj.load_state(new_state)
     old_state_get = old_obj.get_state()
     new_state_get = new_obj.get_state()
     # Saved state vs. state read back from disk, config included.
     check_mapping_equivalence(new_state, old_state)
     check_mapping_equivalence(old_state._metadata,
                               new_state._metadata,
                               exclude_config=False)
     # Live objects after loading. The config is excluded here, presumably
     # because the two objects were compiled with different arguments.
     check_mapping_equivalence(new_state_get, old_state_get)
     check_mapping_equivalence(old_state_get._metadata,
                               new_state_get._metadata,
                               exclude_config=True)
 def test_save_to_file_and_load_from_file_roundtrip_pytorch(
         self, alternating_nn_module_with_state):
     """Round-trip a plain (non-config) torch module hierarchy through disk.

     The state of one instance is saved into a temporary directory, read
     back, and loaded non-strictly into a second instance. The two
     instances' states, including their config metadata, must then match.
     """
     source = alternating_nn_module_with_state(from_config=False)
     snapshot = source.get_state()
     with tempfile.TemporaryDirectory() as save_dir:
         save_state_to_file(snapshot, save_dir)
         restored = load_state_from_file(save_dir)
     target = alternating_nn_module_with_state(from_config=False)
     target.load_state(restored, strict=False)
     expected, actual = source.get_state(), target.get_state()
     check_mapping_equivalence(actual, expected)
     check_mapping_equivalence(expected._metadata,
                               actual._metadata,
                               exclude_config=False)
 def test_save_to_file_and_load_from_file_roundtrip(self, basic_object):
     """Round-trip a basic object's state through a temporary directory.

     The source object is built from config and the target is not, so the
     metadata comparison at the end excludes the config.
     """
     source = basic_object(from_config=True)
     snapshot = source.get_state()
     with tempfile.TemporaryDirectory() as save_dir:
         save_state_to_file(snapshot, save_dir)
         restored = load_state_from_file(save_dir)
     target = basic_object(from_config=False)
     target.load_state(restored, strict=False)
     expected, actual = source.get_state(), target.get_state()
     check_mapping_equivalence(actual, expected)
     check_mapping_equivalence(expected._metadata,
                               actual._metadata,
                               exclude_config=True)
    def test_module_save_and_load_single_instance_appears_twice(self, make_classes_2):
        txt = """
!C
one: !A
  akw2: &theb !B
    bkw2: test
    bkw1: 1
  akw1: 8
two: !A
  akw1: 8
  # Comment Here
  akw2: *theb
"""
        c = yaml.load(txt)()
        c.one.akw2.bkw1 = 6
        assert c.one.akw2 is c.two.akw2
        assert c.one.akw2.bkw1 == c.two.akw2.bkw1
        with tempfile.TemporaryDirectory() as path:
            save(c, path)
            state = load_state_from_file(path)
            loaded_c = load(path)
        assert loaded_c.one.akw2 is loaded_c.two.akw2
        assert loaded_c.one.akw2.bkw1 == loaded_c.two.akw2.bkw1