Ejemplo n.º 1
0
    def encode(self, obj: keras.Model, description: Optional[str],
               params: Optional[Dict]) -> FrameData:
        """Serialize a Keras model, or only its weights, into a FrameData.

        The model is written to a fresh file name inside ``self.tmp_dir``.
        When ``self.store_whole_model`` is true the whole model is saved
        (media subtype ``tensorflow/model``); otherwise only the weights are
        saved (``tensorflow/weights``). When ``self.save_format == "tf"`` the
        written output is packed into an archive of type ``self.archive``
        and that archive file is what gets wrapped.

        :param obj: the Keras model to serialize.
        :param description: optional description forwarded to FrameData.
        :param params: optional parameters forwarded to FrameData.
        :return: FrameData wrapping the written (possibly archived) file,
            with settings recording the model class, the TensorFlow version
            and ``self.archive``.
        """
        filename = util.create_filename(self.tmp_dir)

        if self.store_whole_model:
            obj.save(filename, save_format=self.save_format)
            application_type = "tensorflow/model"
        else:
            # The Keras API is ``save_weights``. The former ``save_wights``
            # spelling raised AttributeError on every weights-only encode.
            obj.save_weights(filename, save_format=self.save_format)
            application_type = "tensorflow/weights"

        settings = {
            "class": f"{obj.__class__.__module__}.{obj.__class__.__name__}",
            "tensorflow": tf.__version__,
            # Recorded unconditionally, even when no archive is produced below.
            "storage_format": self.archive
        }

        if self.save_format == "tf":
            # The "tf" output is archived so it can travel as a single file.
            # NOTE(review): for weights-only saves Keras writes checkpoint
            # files named with ``filename`` as a prefix rather than a
            # directory; confirm make_archive's root_dir handling here.
            archived = util.create_filename(self.tmp_dir)
            filename = shutil.make_archive(archived, self.archive, filename)

        return FrameData(
            FileContent(Path(filename)),
            MediaType("application/octet-stream", application_type),
            description, params, settings)
Ejemplo n.º 2
0
 def encode(self, obj: Path, description: Optional[str], params: Optional[Dict]) -> FrameData:
     """Wrap a file on disk as a streamed FrameData.

     The content type comes from looking up the file suffix in
     ``CONTENT_TYPE_MAP_REVERSED``, defaulting to
     ``application/octet-stream``. The file is opened in binary mode and
     handed, still open, to ``StreamContent`` together with its size.
     The settings hold the file's name, overlaid with ``self.settings``.
     """
     size_in_bytes = obj.stat().st_size
     content_type = CONTENT_TYPE_MAP_REVERSED.get(obj.suffix, "application/octet-stream")
     media_type = MediaType(content_type)

     # The handle is intentionally left open: StreamContent reads from it later.
     stream = StreamContent(obj.open("rb"), size_in_bytes)

     merged_settings = dict(filename=obj.name)
     merged_settings.update(self.settings)

     return FrameData(stream, media_type, description, params, merged_settings)
Ejemplo n.º 3
0
    def encode(self, obj: Module, description: Optional[str], params: Optional[Dict]) -> FrameData:
        """Serialize a torch Module into an in-memory FrameData.

        When ``self.store_whole_model`` is true the module object itself is
        pickled via ``torch.save`` (subtype ``torch/model``); otherwise only
        its ``state_dict()`` is saved (subtype ``torch/state``). The settings
        record the module's fully-qualified class name and the torch version.
        """
        # FIXME: add model summary here
        cls = obj.__class__
        settings = {
            "class": f"{cls.__module__}.{cls.__name__}",
            "torch": torch.version.__version__,
        }

        if self.store_whole_model:
            payload, application_type = obj, "torch/model"
        else:
            payload, application_type = obj.state_dict(), "torch/state"

        buffer = io.BytesIO()
        torch.save(payload, buffer)

        media_type = MediaType("application/binary", application_type)
        return FrameData(BytesContent(buffer), media_type, description, params, settings)