Example #1
def create_assets(basename,
                  sig_input,
                  sig_output,
                  sig_name,
                  lengths_key=None,
                  beam=None,
                  return_labels=False,
                  preproc='client'):
    """Save required variables for running an exported model from baseline's services.

    :basename the base model name. e.g. /path/to/tagger-26075
    :sig_input the input dictionary
    :sig_output the output namedTuple
    :lengths_key the lengths_key from the model.
        used to translate batch output. Exported models will return a flat list,
        and it needs to be reshaped into per-example lists. We use this key to tell
        us which feature holds the sequence lengths.

    """
    inputs = list(sig_input)
    outputs = sig_output._fields
    # Derive the model name from the final path component.
    model_name = basename.split("/")[-1]
    directory = basename.split("/")[:-1]  # note: unused in this snippet
    metadata = create_metadata(inputs,
                               outputs,
                               sig_name,
                               model_name,
                               lengths_key,
                               beam=beam,
                               return_labels=return_labels,
                               preproc=preproc)
    return metadata
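
For context, a minimal usage sketch; the path, feature names, and the SignatureOutput namedtuple below are hypothetical, and create_metadata is assumed to be importable from baseline's export utilities:

from collections import namedtuple

# Hypothetical output signature: a namedtuple whose fields name the output tensors.
SignatureOutput = namedtuple('SignatureOutput', ['classes', 'scores'])

meta = create_assets(
    '/path/to/tagger-26075',                        # hypothetical exported model path
    sig_input={'word': None, 'char': None},         # assumed input feature names
    sig_output=SignatureOutput('classes', 'scores'),
    sig_name='tag_text',                            # hypothetical signature name
    lengths_key='word',                             # feature carrying sequence lengths
)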
Example #2
    def run(self, basename, output_dir, project=None, name=None, model_version=None, **kwargs):
        logger.warning("Pytorch exporting is experimental and is not guaranteed to work for plugin models.")
        client_output, server_output = get_output_paths(
            output_dir,
            project, name,
            model_version,
            kwargs.get('remote', True),
        )
        logger.info("Saving vectorizers and vocabs to %s", client_output)
        logger.info("Saving serialized model to %s", server_output)
        model, vectorizers, model_name = self.load_model(basename)
        order = monkey_patch_embeddings(model)
        data, lengths = create_fake_data(VECTORIZER_SHAPE_MAP, vectorizers, order)
        meta = create_metadata(
            order, ['output'],
            self.sig_name,
            model_name, model.lengths_key,
            exporter_type=self.preproc_type()
        )

        exportable = self.wrapper(model)
        logger.info("Tracing Model.")
        traced = torch.jit.trace(exportable, (data, lengths))
        traced.save(os.path.join(server_output, 'model.pt'))

        logger.info("Saving metadata.")
        save_to_bundle(client_output, basename, assets=meta)
        logger.info('Successfully exported model to %s', output_dir)
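
A quick way to sanity-check the artifact: a traced TorchScript model can be reloaded and called without the original Python class. A minimal sketch, using a hypothetical path and assumed input shapes:

import torch

# Reload the TorchScript archive written by traced.save(...) above.
traced = torch.jit.load('/path/to/server_output/model.pt')   # hypothetical path
data = torch.zeros((1, 100), dtype=torch.long)                # assumed token-id batch
lengths = torch.tensor([100], dtype=torch.long)               # assumed lengths tensor
with torch.no_grad():
    output = traced(data, lengths)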
Example #3
    def _run(self,
             basename,
             output_dir,
             project=None,
             name=None,
             model_version=None,
             use_version=False,
             zip_results=True,
             **kwargs):
        logger.warning(
            "Pytorch exporting is experimental and is not guaranteed to work for plugin models."
        )
        client_output, server_output = get_output_paths(
            output_dir,
            project,
            name,
            model_version,
            kwargs.get('remote', False),
            use_version=use_version)
        logger.info("Saving vectorizers and vocabs to %s", client_output)
        logger.info("Saving serialized model to %s", server_output)

        model, vectorizers, vocabs, model_name = self.load_model(basename)
        model = self.apply_model_patches(model)

        data = self.create_example_input(vocabs, vectorizers)
        example_output = self.create_example_output(model)

        inputs = self.create_model_inputs(model)
        outputs = self.create_model_outputs(model)

        dynamics = self.create_dynamic_axes(model, vectorizers, inputs,
                                            outputs)

        meta = create_metadata(inputs, outputs, self.sig_name, model_name,
                               model.lengths_key)

        if not self.tracing:
            model = torch.jit.script(model)

        logger.info("Exporting Model.")
        logger.info("Model inputs: %s", inputs)
        logger.info("Model outputs: %s", outputs)

        torch.onnx.export(model,
                          data,
                          verbose=True,
                          dynamic_axes=dynamics,
                          f=f'{server_output}/{model_name}.onnx',
                          input_names=inputs,
                          output_names=outputs,
                          opset_version=self.onnx_opset,
                          example_outputs=example_output)

        logger.info("Saving metadata.")
        save_to_bundle(client_output,
                       basename,
                       assets=meta,
                       zip_results=zip_results)
        logger.info('Successfully exported model to %s', output_dir)
        return client_output, server_output
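
For reference, the dynamic_axes mapping passed to torch.onnx.export names the dimensions that may vary at runtime. A minimal sketch of the shape create_dynamic_axes might return for a single token-id input; the feature and axis labels are illustrative:

# Hypothetical dynamic_axes: dimension 0 is the batch, dimension 1 the sequence.
dynamics = {
    'word': {0: 'batch', 1: 'sequence'},       # assumed input name
    'output': {0: 'batch', 1: 'sequence'},     # output name used above
}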
Example #4
    def _run(self,
             basename,
             output_dir,
             project=None,
             name=None,
             model_version=None,
             use_version=False,
             zip_results=True,
             remote=False,
             **kwargs):
        client_output, server_output = get_output_paths(
            output_dir,
            project,
            name,
            model_version,
            remote,
            use_version=use_version)
        logger.info("Saving vectorizers and vocabs to %s", client_output)
        logger.info("Saving serialized model to %s", server_output)

        model, vectorizers, vocabs, model_name = self.load_model(basename)

        model = self.apply_model_patches(model)

        data = self.create_example_input(vocabs, vectorizers)
        example_output = self.create_example_output(model)

        inputs = self.create_model_inputs(model)
        outputs = self.create_model_outputs(model)

        dynamics = self.create_dynamic_axes(model, vectorizers, inputs,
                                            outputs)

        meta = create_metadata(inputs, outputs, self.sig_name, model_name,
                               model.lengths_key)

        if not self.tracing:
            model = torch.jit.script(model)

        logger.info("Exporting Model.")
        logger.info("Model inputs: %s", inputs)
        logger.info("Model outputs: %s", outputs)

        # Triton server wants to see a specific name
        onnx_model_name = REMOTE_MODEL_NAME if remote else model_name

        torch.onnx.export(
            model,
            data,
            verbose=True,
            dynamic_axes=dynamics,
            f=f'{server_output}/{onnx_model_name}.onnx',
            input_names=inputs,
            output_names=outputs,
            opset_version=self.onnx_opset,
            example_outputs=example_output)

        logger.info("Saving metadata.")
        save_to_bundle(client_output,
                       basename,
                       assets=meta,
                       zip_results=zip_results)
        logger.info('Successfully exported model to %s', output_dir)
        return client_output, server_output
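
After export, the ONNX graph can be verified without PyTorch. A minimal sketch, assuming onnxruntime is installed, that the inputs are int64 token ids, and that the model was written to a hypothetical path:

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession('/path/to/server_output/model.onnx')  # hypothetical path
# Feed zeros for every declared input, substituting 1 for dynamic dimensions.
feed = {
    inp.name: np.zeros(
        [d if isinstance(d, int) else 1 for d in inp.shape],
        dtype=np.int64,  # assumed: token-id inputs
    )
    for inp in session.get_inputs()
}
outputs = session.run(None, feed)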