Example no. 1
0
    def bufferize(
            worker: AbstractWorker,
            script_module: torch.jit.TopLevelTracedModule) -> TracedModulePB:
        """
            Serialize a torch.jit.TopLevelTracedModule into a protobuf message.

            Args:
                worker (AbstractWorker): worker on whose behalf serialization
                    happens (not used by this method).
                script_module (torch.jit.TopLevelTracedModule): module whose
                    saved buffer is placed into the protobuf ``obj`` field.

            Returns:
                protobuf_script (TracedModulePB): the serialized module.
        """
        # NOTE(review): the return annotation says TracedModulePB but a
        # ScriptModulePB is constructed — presumably an alias; confirm.
        serialized = ScriptModulePB()
        serialized.obj = script_module.save_to_buffer()
        return serialized
Example no. 2
0
def _bufferize_traced_module(
        worker: AbstractWorker,
        script_module: torch.jit.TopLevelTracedModule) -> TracedModulePB:
    """Serialize a torch.jit.TopLevelTracedModule into a protobuf message.

    Args:
        worker: worker on whose behalf serialization happens (unused here).
        script_module: traced module to serialize.

    Returns:
        TracedModulePB: message whose ``obj`` field holds the module's
        saved buffer.
    """
    # NOTE(review): annotated to return TracedModulePB but constructs a
    # ScriptModulePB — presumably an alias; confirm.
    message = ScriptModulePB()
    message.obj = script_module.save_to_buffer()
    return message
Example no. 3
0
def get_node_to_io_tensor_names_map(model: torch.nn.Module,
                                    trace: torch.jit.TopLevelTracedModule,
                                    inputs: List[torch.Tensor]) -> \
        (Dict[str, Union[OpToIOTensors, List[OpToIOTensors]]], set):
    """
    Given a Torch model, gets the input and output tensor names for each node in the model.
    :param model: The user provided model instance
    :param trace: the model in torch script format
    :param inputs: sample tensor inputs (a single tensor or a list of tensors)
    :return: Dictionary of torch script node name and corresponding input and output tensor names and
        a set with all valid param names in model
    """
    # pylint: disable=too-many-locals

    # Generates a look up dictionary for getting modules from their names.
    model_name = type(model).__name__
    module_to_name = {}
    for name, module in model.named_modules(prefix=model_name):
        module_to_name[module] = name
    # graph_for takes positional sample inputs, so unpack unless a single
    # tensor was passed directly.
    if isinstance(inputs, torch.Tensor):
        graph = trace.graph_for(inputs)
    else:
        graph = trace.graph_for(*inputs)
    # Lower the TorchScript graph to the project's IR node list, then merge
    # add+mm node pairs (helpers defined elsewhere in this module).
    ir_nodes_list = _parse_graph(graph, model)
    _coalesce_add_and_mm_nodes(ir_nodes_list)

    node_to_io_tensor_name_map = {}
    valid_param_set = set()
    # Length of "ModelName." — used to strip the prefix that named_modules
    # added to every module name above.
    prefix_len = len(model_name) + 1

    # Modules collected in forward-execution order by the hook below.
    modules = []

    def forward_hook(curr_module: torch.nn.Module, *_):
        """
        Custom forward hook function to add every module to module list.
        :param curr_module: Current module being traversed during forward pass.
        """
        # PassThroughOp modules are skipped so they never consume an IR node.
        if not isinstance(curr_module, PassThroughOp):
            modules.append(curr_module)

    run_hook_for_layers_with_given_input(model, inputs, forward_hook)
    index = 0
    module_types = [types for types in ConnectedGraph.op_type_map.values()]
    for node in ir_nodes_list:
        # IR nodes without an attached module are matched positionally to the
        # hooked modules list.
        # NOTE(review): this assumes ir_nodes_list order matches forward
        # execution order of the hooked modules — confirm with _parse_graph.
        if node.module is None:
            if node.node_type in module_types:
                node.module = modules[index]
                assert ConnectedGraph.op_type_map[type(
                    node.module)] == node.node_type
            else:
                # Functional/unmapped op: no module, no entry in the map.
                continue
        module_name = module_to_name[node.module][prefix_len:]
        # Advance past the hooked module corresponding to this node (also
        # incremented when node.module was pre-assigned by _parse_graph).
        index = index + 1

        node_to_io_tensor_name_map[module_name] = \
            OpToIOTensors(
                [inp.debugName() for inp in node.inputs],
                [output.debugName() for output in node.outputs])

        # Record fully-qualified parameter names (e.g. "layer1.weight").
        for param_name, _ in node.module.named_parameters():
            valid_param_set.add(module_name + '.' + param_name)

    # Every hooked (non-passthrough) module must have been consumed.
    assert index == len(modules)

    return node_to_io_tensor_name_map, valid_param_set