Exemple #1
0
    def test_ordered_importer_whichmodule(self):
        """Verify that OrderedImporter.whichmodule asks its importers in order.

        The expected answer is the one given by the first wrapped importer
        whose ``whichmodule`` does not return "__main__".
        """
        class FixedAnswerImporter(Importer):
            """Importer stub whose ``whichmodule`` always gives a preset answer."""

            def __init__(self, answer):
                self._answer = answer

            def import_module(self, module_name):
                raise NotImplementedError()

            def whichmodule(self, obj, name):
                return self._answer

        class Probe:
            pass

        foo = FixedAnswerImporter("foo")
        bar = FixedAnswerImporter("bar")
        # CPython uses "__main__" as its stand-in for "not found".
        not_found = FixedAnswerImporter("__main__")

        # (importers in lookup order, module name the combination must report)
        scenarios = (
            ((foo, bar), "foo"),
            ((bar, foo), "bar"),
            ((not_found, foo), "foo"),
        )
        for importers, expected in scenarios:
            combined = OrderedImporter(*importers)
            self.assertEqual(combined.whichmodule(Probe(), ""), expected)
Exemple #2
0
def _load_storages(id, zip_reader, obj_bytes, serialized_storages,
                   serialized_dtypes):
    """Unpickle ``obj_bytes``, resolving persistent ids out-of-band.

    Args:
        id: Key under which the unpickled object is stored in
            ``_deploy_objects``.
        zip_reader: When not ``None``, the package returned by
            ``_get_package(zip_reader)`` is consulted for imports before
            ``sys_importer``. It is also used as the key into
            ``_raw_packages`` for "reduce_deploy" entries.
        obj_bytes: The pickle payload to load.
        serialized_storages: Sequence indexed by "storage" persistent ids.
        serialized_dtypes: Dtypes, index-aligned with ``serialized_storages``.

    Returns:
        The unpickled object (also recorded in ``_deploy_objects[id]``).
    """
    def persistent_load(saved_id):
        assert isinstance(saved_id, tuple)
        kind = _maybe_decode_ascii(saved_id[0])
        payload = saved_id[1:]

        if kind == "storage":
            # TODO: Once we decide to break serialization FC, we can
            # stop wrapping with TypedStorage
            index = payload[0]
            raw_storage = serialized_storages[index]
            storage_dtype = serialized_dtypes[index]
            return torch.storage.TypedStorage(
                wrap_storage=raw_storage.untyped(),
                dtype=storage_dtype,
            )

        if kind == "reduce_deploy":
            reduce_id, func, args = payload
            # Each reduce id is rebuilt at most once and then served from
            # the cache.
            if reduce_id in _loaded_reduces:
                return _loaded_reduces[reduce_id]
            _loaded_reduces[reduce_id] = func(_raw_packages[zip_reader],
                                              *args)
            return _loaded_reduces[reduce_id]

        return None

    importer: Importer
    if zip_reader is None:
        importer = sys_importer
    else:
        importer = OrderedImporter(_get_package(zip_reader), sys_importer)

    unpickler = PackageUnpickler(importer, io.BytesIO(obj_bytes))
    unpickler.persistent_load = persistent_load  # type: ignore[assignment]
    loaded = unpickler.load()
    _deploy_objects[id] = loaded
    return loaded
Exemple #3
0
def _save_storages(importer, obj):
    """Pickle ``obj`` while collecting the storages it references separately.

    Args:
        importer: Only honoured when it is a
            ``torch.package.PackageImporter``; in that case it is consulted
            before ``sys_importer``. Any other value means ``sys_importer``
            alone is used.
        obj: The object to pickle.

    Returns:
        A 4-tuple ``(data_value, serialized_storages, serialized_dtypes,
        zip_reader)`` where ``data_value`` is the pickle bytes, the two lists
        hold each encountered storage and its dtype (index-aligned), and
        ``zip_reader`` is the package importer's ``zip_reader`` or ``None``.
    """
    serialized_storages = []
    serialized_dtypes = []

    def persistent_id(candidate):
        # FIXME: the docs say that persistent_id should only return a string
        # but torch store returns tuples. This works only in the binary protocol
        # see
        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
        if not torch.is_storage(candidate):
            return None
        serialized_storages.append(candidate)
        serialized_dtypes.append(candidate.dtype)
        # The persistent id carries the storage's position in the list.
        return ('storage', len(serialized_storages) - 1)

    # Anything that is not a PackageImporter is treated as "no package".
    package_importer = None
    if isinstance(importer, torch.package.PackageImporter):
        package_importer = importer

    importers: Importer
    if package_importer is None:
        importers = sys_importer
    else:
        importers = OrderedImporter(package_importer, sys_importer)

    # Write the pickle data for `obj`
    data_buf = io.BytesIO()
    pickler = create_custom_import_pickler(data_buf, importers)
    pickler.persistent_id = persistent_id
    pickler.dump(obj)

    data_value = data_buf.getvalue()
    zip_reader = package_importer.zip_reader if package_importer else None
    return data_value, serialized_storages, serialized_dtypes, zip_reader
Exemple #4
0
def _load_storages(id, zip_reader, obj_bytes, serialized_storages):
    """Unpickle ``obj_bytes``, substituting storages and deploy reductions.

    Args:
        id: Key under which the unpickled object is stored in
            ``_deploy_objects``.
        zip_reader: When not ``None``, the package returned by
            ``_get_package(zip_reader)`` is consulted for imports before
            ``sys_importer``. It is also used as the key into
            ``_raw_packages`` for 'reduce_deploy' entries.
        obj_bytes: The pickle payload to load.
        serialized_storages: Sequence indexed by 'storage' persistent ids.

    Returns:
        The unpickled object (also recorded in ``_deploy_objects[id]``).
    """
    def persistent_load(saved_id):
        assert isinstance(saved_id, tuple)
        kind = _maybe_decode_ascii(saved_id[0])
        payload = saved_id[1:]

        if kind == 'storage':
            storage_index = payload[0]
            return serialized_storages[storage_index]

        if kind == 'reduce_deploy':
            reduce_id, func, args = payload
            # Rebuild each reduced object only once; later references reuse
            # the cached result.
            if reduce_id in _loaded_reduces:
                return _loaded_reduces[reduce_id]
            _loaded_reduces[reduce_id] = func(_raw_packages[zip_reader], *args)
            return _loaded_reduces[reduce_id]

        return None

    importer: Importer
    if zip_reader is None:
        importer = sys_importer
    else:
        importer = OrderedImporter(_get_package(zip_reader), sys_importer)

    unpickler = PackageUnpickler(importer, io.BytesIO(obj_bytes))
    unpickler.persistent_load = persistent_load
    loaded = unpickler.load()
    _deploy_objects[id] = loaded
    return loaded
Exemple #5
0
    def test_single_ordered_importer(self):
        """An OrderedImporter wrapping one PackageImporter resolves only from it."""
        import module_a  # noqa: F401
        import package_a

        archive = BytesIO()
        with PackageExporter(archive) as exporter:
            exporter.save_module(package_a.__name__)
        archive.seek(0)

        packaged = PackageImporter(archive)
        # Construct an importer-only environment.
        packaged_only = OrderedImporter(packaged)

        # The environment must hand back the very module object that the
        # package importer itself produces.
        from_ordered = packaged_only.import_module("package_a")
        from_package = packaged.import_module("package_a")
        self.assertIs(from_ordered, from_package)

        # It should not be the copy available in the outer Python environment.
        self.assertIsNot(packaged_only.import_module("package_a"), package_a)

        # module_a was never saved into the package, so it must not resolve.
        with self.assertRaises(ModuleNotFoundError):
            packaged_only.import_module("module_a")
Exemple #6
0
def _save_storages(importer, obj):
    serialized_storages = []
    serialized_dtypes = []

    importer = importer if isinstance(importer,
                                      torch.package.PackageImporter) else None
    importers: Importer
    if importer is not None:
        importers = OrderedImporter(importer, sys_importer)
    else:
        importers = sys_importer

    def persistent_id(obj):
        if torch.is_storage(obj) or isinstance(obj,
                                               torch.storage.TypedStorage):
            if isinstance(obj, torch.storage.TypedStorage):
                # TODO: Once we decide to break serialization FC, we can
                # remove this case
                storage = obj._storage
                dtype = obj.dtype
            else:
                storage = obj
                dtype = torch.uint8

            serialized_storages.append(obj)
            serialized_dtypes.append(dtype)
            return ("storage", len(serialized_storages) - 1)

        if hasattr(obj, "__reduce_deploy__"):
            if _serialized_reduces.get(id(obj)) is None:
                _serialized_reduces[id(obj)] = (
                    "reduce_deploy",
                    id(obj),
                    *obj.__reduce_deploy__(importers),
                )
            return _serialized_reduces[id(obj)]

        return None

    # Write the pickle data for `obj`
    data_buf = io.BytesIO()
    pickler = create_pickler(data_buf, importers)
    pickler.persistent_id = persistent_id
    pickler.dump(obj)
    data_value = data_buf.getvalue()
    return (
        data_value,
        serialized_storages,
        serialized_dtypes,
        importer.zip_reader if importer else None,
    )
Exemple #7
0
    def test_ordered_importer_basic(self):
        """The importer listed first in an OrderedImporter supplies the module."""
        import package_a

        archive = BytesIO()
        with PackageExporter(archive, verbose=False) as exporter:
            exporter.save_module(package_a.__name__)
        archive.seek(0)

        packaged = PackageImporter(archive)

        # sys_importer first: the module from the outer environment is returned.
        sys_first = OrderedImporter(sys_importer, packaged)
        self.assertIs(sys_first.import_module('package_a'), package_a)

        # Package importer first: the packaged copy is returned.
        package_first = OrderedImporter(packaged, sys_importer)
        from_ordered = package_first.import_module('package_a')
        from_package = packaged.import_module('package_a')
        self.assertIs(from_ordered, from_package)
Exemple #8
0
def _load_storages(id, zip_reader, obj_bytes, serialized_storages):
    """Unpickle ``obj_bytes``, resolving 'storage' persistent ids by index.

    Args:
        id: Key under which the unpickled object is stored in
            ``_deploy_objects``.
        zip_reader: When not ``None``, the package returned by
            ``_get_package(zip_reader)`` is consulted for imports before
            ``sys_importer``.
        obj_bytes: The pickle payload to load.
        serialized_storages: Sequence indexed by 'storage' persistent ids.

    Returns:
        The unpickled object (also recorded in ``_deploy_objects[id]``).
    """
    def persistent_load(saved_id):
        assert isinstance(saved_id, tuple)
        typename = _maybe_decode_ascii(saved_id[0])
        payload = saved_id[1:]

        # 'storage' is the only persistent id kind this variant understands.
        assert typename == 'storage', \
            f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
        storage_index = payload[0]
        return serialized_storages[storage_index]

    importer: Importer
    if zip_reader is None:
        importer = sys_importer
    else:
        importer = OrderedImporter(_get_package(zip_reader), sys_importer)

    unpickler = PackageUnpickler(importer, io.BytesIO(obj_bytes))
    unpickler.persistent_load = persistent_load
    loaded = unpickler.load()
    _deploy_objects[id] = loaded
    return loaded