Example 1
import typing

import torch.nn as nn
from fvcore.nn import activation_count, flop_count

from detectron2.export import TracingAdapter

# _IGNORED_OPS, FLOPS_MODE and ACTIVATIONS_MODE are module-level constants
# defined alongside this helper in detectron2.utils.analysis.
def _wrapper_count_operators(model: nn.Module, inputs: list, mode: str,
                             **kwargs) -> typing.DefaultDict[str, float]:
    # ignore some ops
    supported_ops = {k: lambda *args, **kwargs: {} for k in _IGNORED_OPS}
    supported_ops.update(kwargs.pop("supported_ops", {}))
    kwargs["supported_ops"] = supported_ops

    assert len(inputs) == 1, "Please use batch size=1"
    tensor_input = inputs[0]["image"]
    inputs = [{"image": tensor_input}]  # remove other keys, in case there are any

    old_train = model.training
    # unwrap DataParallel / DistributedDataParallel to trace the inner module
    if isinstance(model, (nn.parallel.DistributedDataParallel, nn.DataParallel)):
        model = model.module
    wrapper = TracingAdapter(model, inputs)
    wrapper.eval()
    if mode == FLOPS_MODE:
        ret = flop_count(wrapper, (tensor_input,), **kwargs)
    elif mode == ACTIVATIONS_MODE:
        ret = activation_count(wrapper, (tensor_input,), **kwargs)
    else:
        raise NotImplementedError(
            "Count for mode {} is not supported yet.".format(mode))
    # compatible with change in fvcore
    if isinstance(ret, tuple):
        ret = ret[0]
    model.train(old_train)
    return ret
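This helper is internal to detectron2's analysis module; the public entry points built on it are flop_count_operators and activation_count_operators. A minimal usage sketch, assuming `model` is a built detectron2 model and `data_loader` yields the usual list-of-dict batches (both names are placeholders):

from detectron2.utils.analysis import flop_count_operators

# `model` and `data_loader` are assumed to exist; slice to batch size 1,
# as required by the assert in the wrapper above.
inputs = next(iter(data_loader))
counts = flop_count_operators(model, inputs[:1])
for op, gflops in counts.items():
    print(f"{op}: {gflops:.3f} Gflops")  # fvcore reports flops in giga-flops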
Example 2
# From a detectron2-style export script: `args`, `logger`, PathManager, os, torch,
# GeneralizedRCNN, TracingAdapter and dump_torchscript_IR are module-level names.
def export_tracing(torch_model, data_loader):
    inputs = next(iter(data_loader))
    images = torch_model.preprocess_image(inputs)
    images_tensor = images.tensor  # to run with FP16, use images.tensor.half() instead

    if isinstance(torch_model, GeneralizedRCNN):

        def inference(model, inputs):
            # use do_postprocess=False so it returns ROI mask
            inst = model.inference(inputs, do_postprocess=False)[0]
            return [{"instances": inst}]

    else:
        inference = None  # assume that we just call the model directly

    # NOTE: unlike Example 4, the adapter input here is the raw tensor;
    # GeneralizedRCNN.inference normally expects the list-of-dict batched inputs.
    traceable_model = TracingAdapter(torch_model, images_tensor, inference)

    ts_model = torch.jit.trace(traceable_model, (images_tensor,))
    with PathManager.open(os.path.join(args.output, "model.ts"), "wb") as f:
        torch.jit.save(ts_model, f)
    dump_torchscript_IR(ts_model, args.output)

    logger.info("Inputs schema: " + str(traceable_model.inputs_schema))
    logger.info("Outputs schema: " + str(traceable_model.outputs_schema))

    return ts_model
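A minimal sketch of consuming the exported file, assuming the same args.output directory and an images_tensor preprocessed exactly as at trace time; the logged outputs_schema can rebuild the structured outputs from the flattened tensors:

import os
import torch

# torch.jit.load reloads the traced module without the Python modeling code.
reloaded = torch.jit.load(os.path.join(args.output, "model.ts"))
with torch.no_grad():
    flat_outputs = reloaded(images_tensor)  # flattened tensor outputs
# outputs_schema (logged above) rebuilds the structured outputs, e.g. [{"instances": ...}]
outputs = traceable_model.outputs_schema(flat_outputs)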
Example 3
def __init__(self, model, inputs):
    """
    Args:
        model (nn.Module): the model whose flops will be counted.
        inputs (Any): inputs of the given model. Does not have to be a tuple of tensors.
    """
    wrapper = TracingAdapter(model, inputs, allow_non_tensor=True)
    super().__init__(wrapper, wrapper.flattened_inputs)
    self.set_op_handle(**{k: None for k in _IGNORED_OPS})
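This constructor is detectron2's FlopCountAnalysis: it flattens arbitrary inputs through TracingAdapter and silences the ignored ops before handing off to fvcore. A minimal usage sketch, assuming `model` is a built detectron2 model and `inputs` is its usual list-of-dict input (both placeholders):

from detectron2.utils.analysis import FlopCountAnalysis

# The fvcore base class provides total(), by_operator() and by_module().
flops = FlopCountAnalysis(model, inputs)
print("total flops:", flops.total())
print("by operator:", dict(flops.by_operator()))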
Example 4
# Variant of detectron2's deployment export script: TORCH_VERSION, GeneralizedRCNN,
# RetinaNet, TracingAdapter, detector_postprocess, dump_torchscript_IR, PathManager,
# `args` and `logger` are module-level names.
def export_tracing(torch_model, inputs):
    assert TORCH_VERSION >= (1, 8)
    image = inputs[0]["image"]
    inputs = [{"image": image}]  # remove other unused keys

    if isinstance(torch_model, GeneralizedRCNN):

        def inference(model, inputs):
            # use do_postprocess=False so it returns ROI mask
            inst = model.inference(inputs, do_postprocess=False)[0]
            return [{"instances": inst}]

    else:
        inference = None  # assume that we just call the model directly

    traceable_model = TracingAdapter(torch_model, inputs, inference)

    if args.format == "torchscript":
        ts_model = torch.jit.trace(traceable_model, (image,))
        with PathManager.open(os.path.join(args.output, "model.ts"),
                              "wb") as f:
            torch.jit.save(ts_model, f)
        dump_torchscript_IR(ts_model, args.output)
    elif args.format == "onnx":
        # NOTE: ONNX export is currently failing in PyTorch
        with PathManager.open(os.path.join(args.output, "model.onnx"),
                              "wb") as f:
            torch.onnx.export(traceable_model, (image,), f, opset_version=11)
    logger.info("Inputs schema: " + str(traceable_model.inputs_schema))
    logger.info("Outputs schema: " + str(traceable_model.outputs_schema))

    if args.format != "torchscript":
        return None
    if not isinstance(torch_model, (GeneralizedRCNN, RetinaNet)):
        return None

    def eval_wrapper(inputs):
        """
        The exported model does not contain the final resize step, which is typically
        unused in deployment but needed for evaluation. We add it manually here.
        """
        input = inputs[0]
        flat_outputs = ts_model(input["image"])
        instances = traceable_model.outputs_schema(flat_outputs)[0]["instances"]
        postprocessed = detector_postprocess(instances, input["height"],
                                             input["width"])
        return [{"instances": postprocessed}]

    return eval_wrapper