Esempio n. 1
0
def extract_attributes(node):
    """Convert the attributes of an ONNX node into PyTorch keyword arguments.

    ONNX attribute names are translated to the keyword names PyTorch uses for
    the same setting (e.g. ``dilations`` -> ``dilation``, ``kernel_shape`` ->
    ``kernel_size``). ``pads``, ``axis`` and ``alpha`` are interpreted
    according to ``node.op_type``. Any attribute of a ``Resize`` node that is
    not otherwise recognised is passed through under its ONNX name.

    Args:
        node: an ONNX node; only ``node.attribute`` (items exposing ``.name``)
            and ``node.op_type`` are read here.

    Returns:
        dict: the translated keyword arguments.

    Raises:
        NotImplementedError: if ``auto_pad`` is anything other than
            ``"NOTSET"`` on a non-Resize node, or if an attribute name is not
            recognised.
    """
    # ONNX attributes whose value is copied as-is under a PyTorch keyword.
    plain_renames = {
        "dilations": "dilation",
        "group": "groups",
        "kernel_shape": "kernel_size",
        "strides": "stride",
        "epsilon": "eps",
        "momentum": "momentum",
        "value": "constant",
        "perm": "dims",
        "split": "split_size_or_sections",
        "spatial": "spatial",  # Batch norm parameter
        "mode": "mode",
        "beta": "bias_multiplier",
        "starts": "starts",
        "ends": "ends",
    }
    # ONNX attributes whose value is cast to bool under a PyTorch keyword.
    bool_renames = {
        "keepdims": "keepdim",
        "ceil_mode": "ceil_mode",
        "transA": "transpose_activation",
    }

    result = {}
    for attribute in node.attribute:
        key = attribute.name

        if key in plain_renames:
            result[plain_renames[key]] = extract_attr_values(attribute)
            continue
        if key in bool_renames:
            result[bool_renames[key]] = bool(extract_attr_values(attribute))
            continue

        if key == "pads":
            pads = extract_attr_values(attribute)
            # Pad nodes use extract_padding_params; everything else (Conv,
            # MaxPooling and other layers from convert_layer) uses the
            # conv-layer variant.
            if node.op_type == "Pad":
                converter = extract_padding_params
            else:
                converter = extract_padding_params_for_conv_layer
            result["padding"] = converter(pads)
        elif key in ("axis", "axes"):
            if key == "axis" and node.op_type == "Flatten":
                result["start_dim"] = extract_attr_values(attribute)
            else:
                axes = extract_attr_values(attribute)
                # A single-element axis list is unwrapped to a scalar dim.
                is_singleton = isinstance(axes, (tuple, list)) and len(axes) == 1
                result["dim"] = axes[0] if is_singleton else axes
        elif key == "transB":
            # The ONNX flag is negated when stored as transpose_weight.
            result["transpose_weight"] = not extract_attr_values(attribute)
        elif key == "alpha":
            if node.op_type == "LeakyRelu":
                target = "negative_slope"
            else:
                target = "weight_multiplier"
            result[target] = extract_attr_values(attribute)
        elif key == "to":
            type_id = extract_attr_values(attribute)
            result["dtype"] = TENSOR_PROTO_MAPPING[type_id].lower()
        elif key == "coordinate_transformation_mode":
            mode = extract_attr_values(attribute)
            if mode != "align_corners":
                warnings.warn(
                    "Pytorch's interpolate uses no coordinate_transformation_mode={}. "
                    "Result might differ.".format(mode))
            else:
                result["align_corners"] = True
        elif node.op_type == "Resize":
            # Not used here; per the original note, the Resize operator
            # warns about these leftover parameters.
            result[key] = extract_attr_values(attribute)
        elif key == "auto_pad":
            padding_mode = extract_attr_values(attribute)
            if padding_mode != "NOTSET":
                raise NotImplementedError(
                    "auto_pad={} functionality not implemented.".format(
                        padding_mode))
        else:
            raise NotImplementedError(
                "Extraction of attribute {} not implemented.".format(key))
    return result
Esempio n. 2
0
def test_extract_padding_params(weight, onnx_pads, torch_pads):
    """ONNX ``pads`` must convert to the expected PyTorch padding.

    ``weight`` is not used in the body; presumably it is supplied by a
    parametrize decorator outside this view, so it stays in the signature.
    """
    assert extract_padding_params(onnx_pads) == torch_pads