Example #1
0
def convert_linear_layer(node, params):
    """Convert linear layer from onnx node and params."""
    # Default Gemm attributes; node attributes override these defaults.
    gemm_attrs = {
        "transpose_weight": True,
        "transpose_activation": False,
        "weight_multiplier": 1,
        "bias_multiplier": 1,
    }
    gemm_attrs.update(extract_attributes(node))
    # transA != 0 (transposed activations) is not supported.
    for attr in node.attribute:
        if attr.name in ["transA"] and extract_attr_values(attr) != 0:
            raise NotImplementedError(
                "Not implemented for attr.name={} and value!=0.".format(
                    attr.name))

    weight, bias = extract_params(params)
    kwargs = {
        "bias": bias is not None,
        "in_features": weight.dims[1],
        "out_features": weight.dims[0],
    }

    # initialize layer and load weights
    layer = nn.Linear(**kwargs)
    load_params(layer, weight, bias)

    # apply onnx gemm attributes
    if gemm_attrs.get("transpose_weight"):
        layer.weight.data = layer.weight.data.t()

    layer.weight.data *= gemm_attrs.get("weight_multiplier")
    if layer.bias is not None:
        layer.bias.data *= gemm_attrs.get("bias_multiplier")

    return layer
def convert_instance_norm_layer(node, params):
    """Convert an InstanceNormalization onnx node into an InstanceNormWrapper."""
    kwargs = extract_attributes(node)
    # Input dimension check is skipped here; it is only possible at forward time.
    torch_params = []
    for param in params:
        torch_params.append(torch.from_numpy(numpy_helper.to_array(param)))

    # Initialize layer and load weights
    return InstanceNormWrapper(torch_params, **kwargs)
Example #3
0
def convert_batch_norm_layer(node, params):
    """Convert a BatchNormalization onnx node into a BatchNormWrapper."""
    kwargs = extract_attributes(node)
    # Input dimension check is skipped here; it is only possible at forward time.
    torch_params = [_deserialize_to_torch(p) for p in params]

    # Initialize layer and load weights
    return BatchNormWrapper(torch_params, **kwargs)
Example #4
0
def convert_batch_norm_layer(node, params):
    """Convert a BatchNormalization onnx node and params to nn.BatchNorm2d.

    Parameters are loaded in onnx order: weight (scale), bias, running_mean,
    running_var.
    """
    kwargs = extract_attributes(node)
    # Bug fix: keep a reference to the class here. The original called
    # nn.BatchNorm2d() without the required num_features argument, which
    # raises a TypeError; the layer is instantiated below once kwargs is
    # complete.
    layer = nn.BatchNorm2d

    kwargs["num_features"] = params[0].dims[0]
    # initialize layer and load weights
    layer = layer(**kwargs)
    # Renamed from `key` so the loop variable no longer shadows the list.
    param_names = ["weight", "bias", "running_mean", "running_var"]
    for name, value in zip(param_names, params):
        getattr(layer,
                name).data = torch.from_numpy(numpy_helper.to_array(value))

    return layer
Example #5
0
def convert_batch_norm_layer(node, params):
    """Convert a BatchNormalization onnx node and params to BatchNormUnsafe.

    Parameters are loaded in onnx order: weight (scale), bias, running_mean,
    running_var.
    """
    kwargs = extract_attributes(node)
    layer = BatchNormUnsafe  # Input dimension check missing, not possible before forward pass

    kwargs["num_features"] = params[0].dims[0]
    # initialize layer and load weights
    layer = layer(**kwargs)
    # Renamed from `key` so the loop variable no longer shadows the list.
    param_names = ["weight", "bias", "running_mean", "running_var"]
    for name, value in zip(param_names, params):
        getattr(layer,
                name).data = torch.from_numpy(numpy_helper.to_array(value))

    return layer
def convert_layer(node, layer_type, params=None):
    """Use to convert Conv, MaxPool, AvgPool layers."""
    valid_types = ("Conv", "ConvTranspose", "MaxPool", "AvgPool")
    assert layer_type in valid_types, "Incorrect layer type: {}".format(layer_type)
    kwargs = extract_attributes(node)
    kernel_size_length = len(kwargs["kernel_size"])
    # Pick e.g. nn.Conv2d / nn.MaxPool1d based on the kernel dimensionality.
    layer_name = "{}{}d".format(layer_type, kernel_size_length)
    if not hasattr(nn, layer_name):
        raise ValueError(
            "Unexpected length of kernel_size dimension: {}".format(kernel_size_length)
        )
    layer = getattr(nn, layer_name)

    pad_layer = None
    if params:
        weight, bias = extract_params(params)
        kwargs["bias"] = bias is not None
        kwargs["in_channels"] = weight.dims[1] * kwargs.get("groups", 1)
        kwargs["out_channels"] = weight.dims[0]

        # Transposed convolutions store weights with in/out channels swapped.
        if layer_type == "ConvTranspose":
            kwargs["out_channels"], kwargs["in_channels"] = (
                kwargs["in_channels"],
                kwargs["out_channels"],
            )

        # if padding is a layer, remove from kwargs and prepend later
        if isinstance(kwargs.get("padding"), nn.Module):
            pad_layer = kwargs.pop("padding")

        # initialize layer and load weights
        layer = layer(**kwargs)
        load_params(layer, weight, bias)
    else:
        # initialize operations without parameters (MaxPool, AvgPool, etc.)
        if layer_type == "MaxPool":
            kwargs["return_indices"] = True

        # if padding is a layer, remove from kwargs and prepend later
        if isinstance(kwargs.get("padding"), nn.Module):
            pad_layer = kwargs.pop("padding")
        layer = layer(**kwargs)

    if pad_layer is None:
        return layer
    return nn.Sequential(pad_layer, layer)
Example #7
0
def convert_instance_norm_layer(node, params):
    """Convert an InstanceNormalization onnx node and params to nn.InstanceNorm2d.

    Parameters are loaded in onnx order: weight (scale), bias.
    """
    kwargs = extract_attributes(node)
    # Skips input dimension check, not possible before forward pass.
    # Bug fix: keep a reference to the class here. The original called
    # nn.InstanceNorm2d() without the required num_features argument, which
    # raises a TypeError; the layer is instantiated below once kwargs is
    # complete.
    layer = nn.InstanceNorm2d

    kwargs["num_features"] = params[0].dims[0]
    # initialize layer and load weights
    layer = layer(**kwargs)
    # Renamed from `key` so the loop variable no longer shadows the list.
    param_names = ["weight", "bias"]
    for name, value in zip(param_names, params):
        getattr(layer,
                name).data = torch.from_numpy(numpy_helper.to_array(value))

    return layer
Example #8
0
def convert_operations(onnx_model, batch_dim=0):
    """
    Convert onnx model operations. Yields onnx's operator_id, operator_name and
    converted pytorch operator.

    Parameters
    ----------
    onnx_model: onnx.ModelProto
        Loaded onnx model.
    batch_dim: int
        Usually 0 for computer vision models and 1 for NLP models.

    Returns
    -------
    iterator: (op_id, op_name, op)
    """
    # Map initializer name -> TensorProto for fast weight lookup per node.
    weights = {tensor.name: tensor for tensor in onnx_model.graph.initializer}

    for i, node in enumerate(onnx_model.graph.node):
        # extract only useful inputs
        params = [
            weights[par_name] for par_name in node.input if par_name in weights
        ]

        if node.op_type == "Conv":
            op = convert_layer(node, "Conv", params)
        elif node.op_type == "Relu":
            op = nn.ReLU(inplace=True)
        elif node.op_type == "LeakyRelu":
            op = nn.LeakyReLU(**extract_attributes(node), inplace=True)
        elif node.op_type == "Sigmoid":
            op = nn.Sigmoid()
        elif node.op_type == "MaxPool":
            op = convert_layer(node, "MaxPool")
        elif node.op_type == "AveragePool":
            op = convert_layer(node, "AvgPool")
        elif node.op_type == "Flatten":
            op = Flatten(**extract_attributes(node))
        elif node.op_type == "Gemm":
            op = convert_linear_layer(node, params)
            op.feature_dim = batch_dim + 1  # Necessary for transformers
        elif node.op_type == "BatchNormalization":
            op = convert_batch_norm_layer(node, params=params)
        elif node.op_type == "InstanceNormalization":
            op = convert_instance_norm_layer(node, params=params)
        elif node.op_type == "Concat":
            op = Concat(**extract_attributes(node))
        elif node.op_type == "Constant":
            # Constant ops are handled by wrapping the value in a callable.
            op = value_wrapper(
                torch.from_numpy(extract_attributes(node)["constant"]))
        elif node.op_type == "Reshape":
            # The target shape is the node's second input; look it up among
            # the graph initializers (None if the shape is computed at runtime).
            shape = list(
                filter(lambda x: x.name == node.input[1],
                       onnx_model.graph.initializer))
            shape = numpy_helper.to_array(shape[0]) if shape else None
            op = Reshape(tuple(shape))
        elif node.op_type == "Shape":
            op = Shape()
        elif node.op_type == "Gather":
            op = Gather(**extract_attributes(node))
        elif node.op_type == "Squeeze":
            op = Squeeze(**extract_attributes(node))
        elif node.op_type == "Unsqueeze":
            op = partial(torch.unsqueeze, **extract_attributes(node))
        elif node.op_type == "ConstantOfShape":
            op = ConstantOfShape(**extract_attributes(node))
        elif node.op_type == "Slice":
            op = Slice(**extract_attributes(node))
        elif node.op_type == "Cast":
            op = Cast(**extract_attributes(node))
        elif node.op_type == "Where":
            op = Where()
        elif node.op_type == "Equal":
            op = torch.eq
        elif node.op_type == "Mul":
            op = Mul(**extract_attributes(node))
        elif node.op_type == "Div":
            op = torch.true_divide
        elif node.op_type == "MatMul":
            if params:
                # MatMul with a constant operand becomes an nn.Linear.
                weight = torch.from_numpy(numpy_helper.to_array(params[0]))
                op = nn.Linear(weight.shape[0], weight.shape[1], bias=False)
                op.weight.data = weight.t()

                # check if next node Add to add bias
                next_node = onnx_model.graph.node[i + 1]
                next_params = [
                    weights[par_name] for par_name in next_node.input
                    if par_name in weights
                ]
                if next_params and next_node.op_type == "Add":
                    # Fuse the following Add into this Linear's bias, then
                    # splice the Add node out of the graph by rewiring outputs.
                    bias = torch.from_numpy(
                        numpy_helper.to_array(next_params[0]))
                    op.bias = nn.Parameter(bias)
                    node.output.pop()
                    node.output.extend(next_node.output)
                    onnx_model.graph.node.pop(i + 1)  # remove next node
            else:
                op = Matmul()
        elif node.op_type == "Sub":
            op = torch.sub
        elif node.op_type == "Pow":
            op = torch.pow
        elif node.op_type == "Sqrt":
            op = torch.sqrt
        elif node.op_type == "Softmax":
            op = nn.Softmax(dim=1)
        elif node.op_type == "Transpose":
            op = partial(torch.Tensor.permute, **extract_attributes(node))
        elif node.op_type == "Split":
            kwargs = extract_attributes(node)
            # if the split_size_or_sections is not in node attributes,
            # the number_of_splits becomes the number of node outputs
            if "split_size_or_sections" not in kwargs:
                kwargs["number_of_splits"] = len(node.output)
            op = Split(**kwargs)
        elif node.op_type == "ReduceMean":
            kwargs = dict(keepdim=True)
            kwargs.update(extract_attributes(node))
            op = partial(torch.mean, **kwargs)
        elif node.op_type == "Add":
            op = Add()
        elif node.op_type == "GlobalAveragePool":
            op = GlobalAveragePool()
        elif node.op_type == "ConvTranspose":
            op = convert_layer(node, "ConvTranspose", params)
        elif node.op_type == "Identity":
            op = nn.Identity()
        elif node.op_type == "Resize":
            op = Resize(**extract_attributes(node))
        elif node.op_type == "Upsample":
            op = Upsample(**extract_attributes(node))
        elif node.op_type == "OneHot":
            op = OneHot(**extract_attributes(node))
        elif node.op_type == "Pad":
            op = Pad(**extract_attributes(node))
        elif node.op_type == "Clip":
            op = Clamp(**extract_attributes(node))
        elif node.op_type == "Tanh":
            op = torch.tanh
        elif node.op_type == "Erf":
            op = torch.erf
        elif node.op_type == "Log":
            op = torch.log
        elif node.op_type == "Exp":
            op = torch.exp
        elif node.op_type == "LRN":
            op = nn.LocalResponseNorm(**extract_attributes(node))
        elif node.op_type == "Dropout":
            # NOTE(review): p=1.0 zeroes all activations in training mode;
            # presumably the converted model is only run in eval mode, where
            # Dropout is the identity — confirm this is intended.
            op = nn.Dropout(p=1.0)
        else:
            # Fallback: try a same-named function in the torch namespace.
            op = getattr(torch, node.op_type.lower(), None)
            if op is None:
                raise NotImplementedError(
                    "Conversion not implemented for op_type={}.".format(
                        node.op_type))
            else:
                print("Automatic inference of operator: {}".format(
                    node.op_type.lower()))

        op_name = "{}_{}".format(node.op_type, node.output[0])
        op_id = node.output[0]
        yield op_id, op_name, op
def convert_operations(onnx_graph, opset_version, batch_dim=0, enable_pruning=True):
    """
    Convert onnx model operations. Yields onnx's operator_id, operator_name and
    converted pytorch operator.

    Parameters
    ----------
    onnx_graph: onnx.GraphProto
        Loaded onnx model's GraphProto.
    opset_version: int
        ONNX model's opset version.
    batch_dim: int
        Usually 0 for computer vision models and 1 for NLP models.
    enable_pruning: bool
        Track kept/pruned indices between different calls to forward pass.

    Returns
    -------
    iterator: (op_id, op_name, op)
    """
    # Map initializer name -> TensorProto for fast weight lookup per node.
    weights = {tensor.name: tensor for tensor in onnx_graph.initializer}

    for i, node in enumerate(onnx_graph.node):
        # extract only useful inputs
        params = [weights[par_name] for par_name in node.input if par_name in weights]

        if node.op_type == "Add":
            op = Add(feature_dim=batch_dim + 1)  # 0 for CV models and 1 for NLP
        elif node.op_type == "And":
            op = OperatorWrapper(torch.logical_and)
        elif node.op_type == "AveragePool":
            op = convert_layer(node, "AvgPool")
        elif node.op_type == "BatchNormalization":
            op = convert_batch_norm_layer(node, params=params)
        elif node.op_type == "Cast":
            op = Cast(**extract_attributes(node))
        elif node.op_type == "Ceil":
            op = OperatorWrapper(torch.ceil)
        elif node.op_type == "Clip":
            op = Clip(**extract_attributes(node))
        elif node.op_type == "Concat":
            op = partial(torch.cat, **extract_attributes(node))
        elif node.op_type == "Constant":
            op = Constant(**extract_attributes(node))
        elif node.op_type == "ConstantOfShape":
            op = ConstantOfShape(**extract_attributes(node))
        elif node.op_type == "Conv":
            op = convert_layer(node, "Conv", params)
        elif node.op_type == "ConvTranspose":
            op = convert_layer(node, "ConvTranspose", params)
        elif node.op_type == "Div":
            op = Div()
        elif node.op_type == "Elu":
            op = nn.ELU(**extract_attributes(node), inplace=True)
        elif node.op_type == "Equal":
            op = OperatorWrapper(torch.eq)
        elif node.op_type == "Erf":
            op = OperatorWrapper(torch.erf)
        elif node.op_type == "Exp":
            op = OperatorWrapper(torch.exp)
        elif node.op_type == "Expand":
            op = Expand()
        elif node.op_type == "Flatten":
            op = Flatten(**extract_attributes(node))
            op.feature_dim = batch_dim + 1  # Necessary for transformers
        elif node.op_type == "Floor":
            op = OperatorWrapper(torch.floor)
        elif node.op_type == "Gather":
            op = Gather(**extract_attributes(node))
        elif node.op_type == "GatherND":
            op = GatherND(**extract_attributes(node))
        elif node.op_type == "Gemm":
            op = convert_linear_layer(node, params)
        elif node.op_type == "GlobalAveragePool":
            op = GlobalAveragePool()
        elif node.op_type == "Greater":
            op = OperatorWrapper(torch.greater)
        elif node.op_type == "Identity":
            op = nn.Identity()
        elif node.op_type == "InstanceNormalization":
            op = convert_instance_norm_layer(node, params=params)
        elif node.op_type == "LeakyRelu":
            op = nn.LeakyReLU(**extract_attributes(node), inplace=True)
        elif node.op_type == "Less":
            op = OperatorWrapper(torch.less)
        elif node.op_type == "Log":
            op = OperatorWrapper(torch.log)
        elif node.op_type == "Loop":
            # Loop bodies are subgraphs and need the opset/batch context.
            op = Loop(
                opset_version=opset_version,
                batch_dim=batch_dim,
                **extract_attributes(node),
            )
        elif node.op_type == "LSTM":
            op = convert_lstm_layer(node, weights)
        elif node.op_type == "MatMul":
            if params:
                # MatMul with a constant operand becomes an nn.Linear.
                weight = torch.from_numpy(numpy_helper.to_array(params[0]))
                op = nn.Linear(weight.shape[0], weight.shape[1], bias=False)
                op.weight.data = weight.t()

                # check if next node Add to add bias
                next_node = onnx_graph.node[i + 1]
                next_params = [
                    weights[par_name]
                    for par_name in next_node.input
                    if par_name in weights
                ]
                if next_params and next_node.op_type == "Add":
                    # Fuse the following Add into this Linear's bias, then
                    # splice the Add node out of the graph by rewiring outputs.
                    bias = torch.from_numpy(numpy_helper.to_array(next_params[0]))
                    op.bias = nn.Parameter(bias)
                    node.output.pop()
                    node.output.extend(next_node.output)
                    onnx_graph.node.pop(i + 1)  # remove next node
            else:
                op = MatMul()
        elif node.op_type == "Max":
            op = OperatorWrapper(torch.max)
        elif node.op_type == "MaxPool":
            op = convert_layer(node, "MaxPool")
        elif node.op_type == "Min":
            op = OperatorWrapper(torch.min)
        elif node.op_type == "Mul":
            op = OperatorWrapper(torch.mul)
        elif node.op_type == "NonMaxSuppression":
            op = NonMaxSuppression(**extract_attributes(node))
        elif node.op_type == "Not":
            op = OperatorWrapper(torch.logical_not)
        elif node.op_type == "OneHot":
            op = OneHot(**extract_attributes(node))
        elif node.op_type == "Or":
            op = OperatorWrapper(torch.logical_or)
        elif node.op_type == "Pad":
            op = Pad(**extract_attributes(node))
        elif node.op_type == "Pow":
            op = OperatorWrapper(torch.pow)
        elif node.op_type == "PRelu":
            op = PRelu()
        elif node.op_type == "Range":
            op = Range()
        elif node.op_type == "Reciprocal":
            op = OperatorWrapper(torch.reciprocal)
        elif node.op_type == "ReduceMax":
            kwargs = dict(keepdim=True)
            kwargs.update(extract_attributes(node))
            op = partial(torch.max, **kwargs)
        elif node.op_type == "ReduceMean":
            kwargs = dict(keepdim=True)
            kwargs.update(extract_attributes(node))
            op = partial(torch.mean, **kwargs)
        elif node.op_type == "ReduceMin":
            kwargs = dict(keepdim=True)
            kwargs.update(extract_attributes(node))
            op = partial(torch.min, **kwargs)
        elif node.op_type == "ReduceProd":
            kwargs = dict(keepdim=True)
            kwargs.update(extract_attributes(node))
            op = partial(torch.prod, **kwargs)
        elif node.op_type == "ReduceSum":
            op = ReduceSum(opset_version=opset_version, **extract_attributes(node))
        elif node.op_type == "Relu":
            op = nn.ReLU(inplace=True)
        elif node.op_type == "Reshape":
            # The target shape is the node's second input; look it up among
            # the graph initializers (None if the shape is computed at runtime).
            shape = list(
                filter(lambda x: x.name == node.input[1], onnx_graph.initializer)
            )
            shape = np.copy(numpy_helper.to_array(shape[0])) if shape else None
            op = Reshape(enable_pruning, shape)
        elif node.op_type == "Resize":
            op = Resize(**extract_attributes(node))
        elif node.op_type == "Scatter":
            op = Scatter(**extract_attributes(node))
        elif node.op_type == "ScatterElements":
            op = ScatterElements(**extract_attributes(node))
        elif node.op_type == "ScatterND":
            op = ScatterND()
        elif node.op_type == "Shape":
            op = Shape()
        elif node.op_type == "Sigmoid":
            op = nn.Sigmoid()
        elif node.op_type == "Slice":
            op = Slice(**extract_attributes(node))
        elif node.op_type == "Softmax":
            kwargs = dict(dim=-1)
            kwargs.update(extract_attributes(node))
            op = nn.Softmax(**kwargs)
        elif node.op_type == "Softplus":
            op = nn.Softplus(beta=1)
        elif node.op_type == "Softsign":
            op = nn.Softsign()
        elif node.op_type == "Split":
            kwargs = extract_attributes(node)
            # if the split_size_or_sections is not in node attributes,
            # the number_of_splits becomes the number of node outputs
            if "split_size_or_sections" not in kwargs:
                kwargs["number_of_splits"] = len(node.output)
            op = Split(enable_pruning, **kwargs)
        elif node.op_type == "Sqrt":
            op = OperatorWrapper(torch.sqrt)
        elif node.op_type == "Squeeze":
            op = Squeeze(opset_version=opset_version, **extract_attributes(node))
        elif node.op_type == "Sub":
            op = OperatorWrapper(torch.sub)
        elif node.op_type == "Tanh":
            op = OperatorWrapper(torch.tanh)
        elif node.op_type == "ThresholdedRelu":
            op = ThresholdedRelu(**extract_attributes(node))
        elif node.op_type == "Tile":
            op = Tile()
        elif node.op_type == "TopK":
            op = TopK()
        elif node.op_type == "Transpose":
            op = Transpose(**extract_attributes(node))
        elif node.op_type == "Unsqueeze":
            op = Unsqueeze(opset_version=opset_version, **extract_attributes(node))
        elif node.op_type == "Upsample":
            op = Upsample(**extract_attributes(node))
        elif node.op_type == "Where":
            op = Where()
        else:
            # Fallback: try a same-named function in the torch namespace.
            op = getattr(torch, node.op_type.lower(), None)
            if op is None:
                raise NotImplementedError(
                    "Conversion not implemented for op_type={}.".format(node.op_type)
                )
            else:
                print(
                    "Automatic inference of operator: {}".format(node.op_type.lower())
                )

        op_name = "{}_{}".format(node.op_type, node.output[0])
        op_id = node.output[0]
        yield op_id, op_name, op
def convert_lstm_layer(node, weights):
    """Convert LSTM layer from onnx node and params.

    Builds an nn.LSTM, copies the onnx W/R/B tensors into it (reordering the
    gates from onnx layout to torch layout), and returns it wrapped in
    LSTMWrapper.
    """
    params_tuple = extract_and_load_params_lstm(node, weights)
    (X, W, R, B, sequence_lens, initial_h, initial_c, P) = params_tuple
    # Only the stateless, non-peephole variant is supported.
    if initial_h is not None:
        raise NotImplementedError("LSTM initial_h not yet implemented.")
    if initial_c is not None:
        raise NotImplementedError("LSTM initial_c not yet implemented.")
    if P is not None:
        raise NotImplementedError("LSTM P not yet implemented.")

    # Defaults for the onnx LSTM attributes; overridden by the node below.
    dc = dict(
        activation_alpha=None,
        activation_beta=None,
        activations=None,
        clip=None,
        direction="forward",
        hidden_size=None,
        input_forget=0,
        layout=0,
    )
    dc.update(extract_attributes(node))
    # Reject every attribute combination torch's nn.LSTM cannot express.
    if dc["activation_alpha"] is not None:
        raise NotImplementedError(
            "LSTM activation_alpha {}.".format(dc["activation_alpha"])
        )
    if dc["activation_beta"] is not None:
        raise NotImplementedError(
            "LSTM activation_beta {}.".format(dc["activation_beta"])
        )
    if dc["activations"] is not None:
        # TODO allow if torch-compatible activations are set explicitly
        raise NotImplementedError("LSTM activations {}.".format(dc["activations"]))
    if dc["clip"] is not None:
        raise NotImplementedError("LSTM clip {}".format(dc["clip"]))
    if dc["direction"] not in ("forward", "bidirectional"):
        raise ValueError("LSTM direction {}.".format(dc["direction"]))
    if dc["hidden_size"] is None:
        raise ValueError("LSTM hidden_size is None.")
    if dc["input_forget"] != 0:
        raise NotImplementedError("LSTM input_forget {}.".format(dc["input_forget"]))
    if dc["layout"] != 0:
        raise NotImplementedError(
            "LSTM not implemented for layout={}".format(dc["layout"])
        )

    kwargs = {
        "input_size": W.shape[2],
        "hidden_size": dc["hidden_size"],
        "num_layers": 1,
        "bias": True,
        "batch_first": False,
        "dropout": 0,
        "bidirectional": dc["direction"] == "bidirectional",
    }
    lstm_layer = nn.LSTM(**kwargs)

    input_size = kwargs["input_size"]
    hidden_size = kwargs["hidden_size"]
    num_directions = kwargs["bidirectional"] + 1
    num_layers = 1
    # The *_iofc -> *_ifco blocks below reorder gate weights/biases from the
    # onnx gate order (input, output, forget, cell) to torch's order
    # (input, forget, cell, output) by slicing per-gate chunks of hidden_size.
    if kwargs["bidirectional"]:
        # Set input-hidden weights
        W_iofc = W.transpose(0, 1).view(4 * hidden_size, num_directions, input_size)
        for dir_dim, dir_str in [(0, ""), (1, "_reverse")]:
            W_ifco = torch.cat(
                tensors=(
                    W_iofc[0:hidden_size, dir_dim, :],
                    W_iofc[2 * hidden_size : 4 * hidden_size, dir_dim, :],
                    W_iofc[hidden_size : 2 * hidden_size, dir_dim, :],
                ),
                dim=0,
            )
            getattr(lstm_layer, "weight_ih_l0{}".format(dir_str)).data = W_ifco

        # Set hidden-hidden weights
        R_iofc = R.transpose(0, 1).view(4 * hidden_size, num_directions, hidden_size)
        for dir_dim, dir_str in [(0, ""), (1, "_reverse")]:
            R_ifco = torch.cat(
                tensors=(
                    R_iofc[0:hidden_size, dir_dim, :],
                    R_iofc[2 * hidden_size : 4 * hidden_size, dir_dim, :],
                    R_iofc[hidden_size : 2 * hidden_size, dir_dim, :],
                ),
                dim=0,
            )
            getattr(lstm_layer, "weight_hh_l0{}".format(dir_str)).data = R_ifco

        # Set input-hidden biases
        for dir_dim, dir_str in [(0, ""), (1, "_reverse")]:
            # First half of B holds the input-hidden bias Wb.
            Wb_iofc = B[dir_dim, 0 : 4 * hidden_size]
            Wb_ifco = torch.cat(
                tensors=(
                    Wb_iofc[0:hidden_size],
                    Wb_iofc[2 * hidden_size : 4 * hidden_size],
                    Wb_iofc[hidden_size : 2 * hidden_size],
                ),
                dim=0,
            )
            getattr(lstm_layer, "bias_ih_l0{}".format(dir_str)).data = Wb_ifco

        # Set hidden-hidden biases
        for dir_dim, dir_str in [(0, ""), (1, "_reverse")]:
            # Second half of B holds the hidden-hidden bias Rb.
            Rb_iofc = B[dir_dim, 4 * hidden_size :]
            Rb_ifco = torch.cat(
                tensors=(
                    Rb_iofc[0:hidden_size],
                    Rb_iofc[2 * hidden_size : 4 * hidden_size],
                    Rb_iofc[hidden_size : 2 * hidden_size],
                ),
                dim=0,
            )
            getattr(lstm_layer, "bias_hh_l0{}".format(dir_str)).data = Rb_ifco
    else:
        # Set input-hidden weights
        W_iofc = W.transpose(0, 1).view(4 * hidden_size, input_size)
        W_ifco = torch.cat(
            tensors=(
                W_iofc[0:hidden_size, :],
                W_iofc[2 * hidden_size : 4 * hidden_size, :],
                W_iofc[hidden_size : 2 * hidden_size, :],
            ),
            dim=0,
        )
        getattr(lstm_layer, "weight_ih_l0").data = W_ifco

        # Set hidden-hidden weights
        R_iofc = R.transpose(0, 1).view(4 * hidden_size, hidden_size)
        R_ifco = torch.cat(
            tensors=(
                R_iofc[0:hidden_size, :],
                R_iofc[2 * hidden_size : 4 * hidden_size, :],
                R_iofc[hidden_size : 2 * hidden_size, :],
            ),
            dim=0,
        )
        getattr(lstm_layer, "weight_hh_l0").data = R_ifco

        # Set input-hidden biases
        Wb_iofc = B[0, 0 : 4 * hidden_size]
        Wb_ifco = torch.cat(
            tensors=(
                Wb_iofc[0:hidden_size],
                Wb_iofc[2 * hidden_size : 4 * hidden_size],
                Wb_iofc[hidden_size : 2 * hidden_size],
            ),
            dim=0,
        )
        getattr(lstm_layer, "bias_ih_l0").data = Wb_ifco

        # Set hidden-hidden biases
        Rb_iofc = B[0, 4 * hidden_size :]
        Rb_ifco = torch.cat(
            tensors=(
                Rb_iofc[0:hidden_size],
                Rb_iofc[2 * hidden_size : 4 * hidden_size],
                Rb_iofc[hidden_size : 2 * hidden_size],
            ),
            dim=0,
        )
        getattr(lstm_layer, "bias_hh_l0").data = Rb_ifco

    layer = LSTMWrapper(lstm_layer)
    return layer