예제 #1
0
    def run(self, basename, output_dir, project=None, name=None, model_version=None, **kwargs):
        """Trace a PyTorch model with TorchScript and export it together with its metadata.

        The traced module is written to ``model.pt`` inside the server directory.
        The metadata built by ``create_metadata`` is handed to ``save_to_bundle`` for
        the client directory. Both directories come from ``get_output_paths``.

        :param basename: passed to ``self.load_model`` and ``save_to_bundle``
            (presumably the saved-model path prefix — confirm with callers).
        :param output_dir: root export directory, forwarded to ``get_output_paths``.
        :param project: forwarded to ``get_output_paths``.
        :param name: forwarded to ``get_output_paths``.
        :param model_version: forwarded to ``get_output_paths``.
        :param kwargs: only ``remote`` is consulted (default ``True``), and it is
            forwarded to ``get_output_paths``.
        """
        logger.warning("Pytorch exporting is experimental and is not guaranteed to work for plugin models.")
        is_remote = kwargs.get('remote', True)
        client_dir, server_dir = get_output_paths(output_dir, project, name, model_version, is_remote)
        logger.info("Saving vectorizers and vocabs to %s", client_dir)
        logger.info("Saving serialized model to %s", server_dir)

        model, vectorizers, model_name = self.load_model(basename)
        # The value returned by monkey_patch_embeddings is the feature ordering used
        # both to build the fake example inputs and to describe the inputs in metadata.
        feature_order = monkey_patch_embeddings(model)
        fake_inputs, fake_lengths = create_fake_data(VECTORIZER_SHAPE_MAP, vectorizers, feature_order)
        metadata = create_metadata(
            feature_order,
            ['output'],
            self.sig_name,
            model_name,
            model.lengths_key,
            exporter_type=self.preproc_type(),
        )

        wrapped_model = self.wrapper(model)
        logger.info("Tracing Model.")
        # torch.jit.trace records the operations executed on this example input pair.
        traced_module = torch.jit.trace(wrapped_model, (fake_inputs, fake_lengths))
        model_file = os.path.join(server_dir, 'model.pt')
        traced_module.save(model_file)

        logger.info("Saving metadata.")
        save_to_bundle(client_dir, basename, assets=metadata)
        logger.info('Successfully exported model to %s', output_dir)
예제 #2
0
def create_bundle(builder, output_path, basename, assets=None):
    """Create the output files for an exported model.

    :param builder: the TensorFlow saved_model builder; its ``save()`` is called first.
    :param output_path: the path to export the model to. This includes the
        model_version name.
    :param basename: path of the saved model. Only its containing directory,
        resolved with ``os.path.realpath``, is passed on to ``save_to_bundle``.
    :param assets: a dictionary of assets to save alongside the model.
    """
    builder.save()

    # Absolute, symlink-resolved directory that contains ``basename``.
    directory = os.path.realpath(os.path.dirname(basename))
    save_to_bundle(output_path, directory, assets)
예제 #3
0
def create_bundle(builder, output_path, basename, assets=None):
    """Create the output files for an exported model.

    :param builder: the TensorFlow saved_model builder; ``builder.save()`` runs first.
    :param output_path: the path to export the model to. This includes the
        model_version name.
    :param basename: path of the saved model. ``save_to_bundle`` receives the
        real path of its parent directory, not ``basename`` itself.
    :param assets: a dictionary of assets to save alongside the model.
    """
    builder.save()

    # Resolve the directory holding ``basename`` to an absolute, symlink-free path.
    directory = os.path.realpath(os.path.dirname(basename))
    save_to_bundle(output_path, directory, assets)
예제 #4
0
    def run(self,
            basename,
            output_dir,
            project=None,
            name=None,
            model_version=None,
            **kwargs):
        """Export a model as a traced TorchScript module plus a client metadata bundle.

        Steps, in order:
        resolve the client/server output directories; load the model; build fake
        example inputs; trace the wrapped model; save the traced module to
        ``model.pt`` under the server directory; save the metadata via
        ``save_to_bundle``.

        :param basename: forwarded to ``self.load_model`` and ``save_to_bundle``
            (presumably the saved-model path prefix — TODO confirm).
        :param output_dir: root export directory.
        :param project: forwarded to ``get_output_paths``.
        :param name: forwarded to ``get_output_paths``.
        :param model_version: forwarded to ``get_output_paths``.
        :param kwargs: ``remote`` (default ``True``) is the only key read.
        """
        logger.warning(
            "Pytorch exporting is experimental and is not guaranteed to work for plugin models."
        )
        remote = kwargs.get('remote', True)
        paths = get_output_paths(output_dir, project, name, model_version, remote)
        client_output, server_output = paths
        logger.info("Saving vectorizers and vocabs to %s", client_output)
        logger.info("Saving serialized model to %s", server_output)

        loaded = self.load_model(basename)
        model, vectorizers, model_name = loaded
        # One ordering drives both the synthetic inputs and the metadata description.
        input_order = monkey_patch_embeddings(model)
        example_data, example_lengths = create_fake_data(
            VECTORIZER_SHAPE_MAP, vectorizers, input_order)
        exporter_type = self.preproc_type()
        meta = create_metadata(input_order, ['output'], self.sig_name, model_name,
                               model.lengths_key, exporter_type=exporter_type)

        module_to_trace = self.wrapper(model)
        logger.info("Tracing Model.")
        example_args = (example_data, example_lengths)
        traced = torch.jit.trace(module_to_trace, example_args)
        traced.save(os.path.join(server_output, 'model.pt'))

        logger.info("Saving metadata.")
        save_to_bundle(client_output, basename, assets=meta)
        logger.info('Successfully exported model to %s', output_dir)
예제 #5
0
    def _run(self,
             basename,
             output_dir,
             project=None,
             name=None,
             model_version=None,
             use_version=False,
             zip_results=True,
             **kwargs):
        """Export a model to ONNX and write the client-side metadata bundle.

        The ONNX file is written to ``<server_output>/<model_name>.onnx``. The
        metadata produced by ``create_metadata`` is saved to ``client_output``
        with ``save_to_bundle``.

        :param basename: passed to ``self.load_model`` and ``save_to_bundle``
            (presumably the saved-model path prefix — confirm with callers).
        :param output_dir: root export directory, forwarded to ``get_output_paths``.
        :param project: forwarded to ``get_output_paths``.
        :param name: forwarded to ``get_output_paths``.
        :param model_version: forwarded to ``get_output_paths``.
        :param use_version: forwarded to ``get_output_paths``.
        :param zip_results: forwarded to ``save_to_bundle``.
        :param kwargs: only ``remote`` is read (default ``False``), and it is
            forwarded to ``get_output_paths``.
        :return: ``(client_output, server_output)``, the directories written to.
        """
        logger.warning(
            "Pytorch exporting is experimental and is not guaranteed to work for plugin models."
        )
        client_output, server_output = get_output_paths(
            output_dir,
            project,
            name,
            model_version,
            kwargs.get('remote', False),
            use_version=use_version)
        logger.info("Saving vectorizers and vocabs to %s", client_output)
        logger.info("Saving serialized model to %s", server_output)

        model, vectorizers, vocabs, model_name = self.load_model(basename)
        model = self.apply_model_patches(model)

        data = self.create_example_input(vocabs, vectorizers)
        example_output = self.create_example_output(model)

        inputs = self.create_model_inputs(model)
        outputs = self.create_model_outputs(model)

        dynamics = self.create_dynamic_axes(model, vectorizers, inputs,
                                            outputs)

        meta = create_metadata(inputs, outputs, self.sig_name, model_name,
                               model.lengths_key)

        # Unless the exporter is configured for tracing, compile with TorchScript
        # scripting before handing the model to the ONNX exporter.
        if not self.tracing:
            model = torch.jit.script(model)

        logger.info("Exporting Model.")
        logger.info("Model inputs: %s", inputs)
        logger.info("Model outputs: %s", outputs)

        # Let os.path pick the separator instead of hand-concatenating '/'.
        onnx_path = os.path.join(server_output, f'{model_name}.onnx')
        torch.onnx.export(model,
                          data,
                          verbose=True,
                          dynamic_axes=dynamics,
                          f=onnx_path,
                          input_names=inputs,
                          output_names=outputs,
                          opset_version=self.onnx_opset,
                          # NOTE(review): `example_outputs` is deprecated/removed in
                          # newer torch releases — confirm against the pinned version.
                          example_outputs=example_output)

        logger.info("Saving metadata.")
        save_to_bundle(client_output,
                       basename,
                       assets=meta,
                       zip_results=zip_results)
        logger.info('Successfully exported model to %s', output_dir)
        return client_output, server_output
예제 #6
0
    def _run(self,
             basename,
             output_dir,
             project=None,
             name=None,
             model_version=None,
             use_version=False,
             zip_results=True,
             remote=False,
             **kwargs):
        """Export a model to ONNX and write the client-side metadata bundle.

        The ONNX file is written into ``server_output``. Its file name is
        ``REMOTE_MODEL_NAME`` when ``remote`` is true, otherwise the loaded model's
        name. Metadata from ``create_metadata`` is saved to ``client_output`` with
        ``save_to_bundle``.

        :param basename: passed to ``self.load_model`` and ``save_to_bundle``
            (presumably the saved-model path prefix — confirm with callers).
        :param output_dir: root export directory, forwarded to ``get_output_paths``.
        :param project: forwarded to ``get_output_paths``.
        :param name: forwarded to ``get_output_paths``.
        :param model_version: forwarded to ``get_output_paths``.
        :param use_version: forwarded to ``get_output_paths``.
        :param zip_results: forwarded to ``save_to_bundle``.
        :param remote: forwarded to ``get_output_paths``. It also selects the ONNX
            file name.
        :param kwargs: accepted but not read here.
        :return: ``(client_output, server_output)``, the directories written to.
        """
        client_output, server_output = get_output_paths(
            output_dir,
            project,
            name,
            model_version,
            remote,
            use_version=use_version)
        logger.info("Saving vectorizers and vocabs to %s", client_output)
        logger.info("Saving serialized model to %s", server_output)

        model, vectorizers, vocabs, model_name = self.load_model(basename)
        model = self.apply_model_patches(model)

        data = self.create_example_input(vocabs, vectorizers)
        example_output = self.create_example_output(model)

        inputs = self.create_model_inputs(model)
        outputs = self.create_model_outputs(model)

        dynamics = self.create_dynamic_axes(model, vectorizers, inputs,
                                            outputs)

        meta = create_metadata(inputs, outputs, self.sig_name, model_name,
                               model.lengths_key)

        # Unless the exporter is configured for tracing, compile with TorchScript
        # scripting before handing the model to the ONNX exporter.
        if not self.tracing:
            model = torch.jit.script(model)

        logger.info("Exporting Model.")
        logger.info("Model inputs: %s", inputs)
        logger.info("Model outputs: %s", outputs)

        # Triton server wants to see a specific name, so remote exports use
        # REMOTE_MODEL_NAME instead of the model's own name.
        onnx_model_name = REMOTE_MODEL_NAME if remote else model_name
        onnx_path = os.path.join(server_output, f'{onnx_model_name}.onnx')

        torch.onnx.export(
            model,
            data,
            verbose=True,
            dynamic_axes=dynamics,
            f=onnx_path,
            input_names=inputs,
            output_names=outputs,
            opset_version=self.onnx_opset,
            # NOTE(review): `example_outputs` is deprecated/removed in newer torch
            # releases — confirm against the pinned torch version.
            example_outputs=example_output)

        logger.info("Saving metadata.")
        save_to_bundle(client_output,
                       basename,
                       assets=meta,
                       zip_results=zip_results)
        logger.info('Successfully exported model to %s', output_dir)
        return client_output, server_output