Exemple #1
0
def export(args: ExportArgs, model: Module, val_loader: DataLoader,
           save_dir: str) -> None:
    """
    Utility method to export the model and data

    The PyTorch export is written first. The validation loader is then
    walked to export ONNX once (from the first batch) and, when
    ``args.num_samples > 0``, one sample input/label pair per batch.
    At most ``max(args.num_samples, 1)`` batches are read from the loader.

    :param args : An ExportArgs object containing config for export task.
        The fields read here are ``use_zipfile_serialization_if_available``,
        ``onnx_opset`` and ``num_samples``.
    :param model: loaded model architecture to export
    :param val_loader: A DataLoader for validation data; every item is
        indexed as ``item[0]`` (passed as the sample batch) and ``item[1]``
        (passed as the sample label)
    :param save_dir: Directory to store checkpoints at during exporting process
    """
    # Local import keeps this block self-contained without touching the
    # file-level import section.
    from itertools import islice

    exporter = ModuleExporter(model, save_dir)

    # export PyTorch state dict
    LOGGER.info(f"exporting pytorch in {save_dir}")

    exporter.export_pytorch(use_zipfile_serialization_if_available=(
        args.use_zipfile_serialization_if_available))

    # At least one batch is always needed: the ONNX export requires a
    # sample input even when no samples were requested.
    max_batches = args.num_samples if args.num_samples > 1 else 1
    should_export_samples = args.num_samples > 0

    # Bound the iteration so the loop never reads more batches than the
    # progress bar announces; islice stops without fetching an extra batch.
    bounded_batches = islice(enumerate(val_loader), max_batches)

    for batch, data in tqdm(
            bounded_batches,
            desc="Exporting samples",
            total=max_batches,
    ):
        if batch == 0:
            # export onnx file using first sample for graph freezing
            LOGGER.info(f"exporting onnx in {save_dir}")
            exporter.export_onnx(data[0],
                                 opset=args.onnx_opset,
                                 convert_qat=True)

        if should_export_samples:
            exporter.export_samples(sample_batches=[data[0]],
                                    sample_labels=[data[1]],
                                    exp_counter=batch)
Exemple #2
0
def test_export_batches(batch_size):
    """
    Smoke-test that ``ModuleExporter.export_samples`` accepts one random batch.

    :param batch_size: number of rows in the generated ``(batch_size, 8)``
        input tensor
    """
    # Draw the inputs before building the network so random-number
    # consumption happens in the same order as before.
    random_inputs = torch.randn(batch_size, 8)
    network = MLPNet()
    output_dir = tempfile.gettempdir()

    ModuleExporter(network, output_dir).export_samples([random_inputs])