Пример #1
0
Файл: helper.py Проект: zoq/onnx
def make_map(
    name,  # type: Text
    key_type,  # type: int
    keys,  # type: List[Any]
    values  # type: SequenceProto
):  # type: (...) -> MapProto
    '''
    Make a Map with specified key-value pair arguments.

    Criteria for conversion:
    - Keys and Values must have the same number of elements
    - Every key in keys must be of the same type
    - Every value in values must be of the same type

    These criteria are NOT verified here; the caller is responsible for them.

    Arguments:
        name: stored in ``MapProto.name``.
        key_type: a ``TensorProto`` data-type value; must be STRING or one of
            INT8, INT16, INT32, INT64, UINT8, UINT16, UINT32, UINT64.
        keys: the map keys. They are written to ``string_keys`` when
            ``key_type`` is STRING, otherwise to ``keys``.
        values: a SequenceProto holding the map values; it is copied into
            ``MapProto.values``.
    Returns:
        The populated MapProto.
    Raises:
        ValueError: if ``key_type`` is not STRING or a supported integer type.
    '''
    valid_key_int_types = [
        TensorProto.INT8, TensorProto.INT16, TensorProto.INT32,
        TensorProto.INT64, TensorProto.UINT8, TensorProto.UINT16,
        TensorProto.UINT32, TensorProto.UINT64
    ]
    # Reject unsupported key types up front. Otherwise the keys would be
    # silently dropped, leaving a map with values but no keys.
    if key_type != TensorProto.STRING and key_type not in valid_key_int_types:
        raise ValueError(
            "Unsupported map key type {}: only STRING and integer tensor "
            "types are allowed as map keys.".format(key_type))

    map_proto = MapProto()
    map_proto.name = name
    map_proto.key_type = key_type
    if key_type == TensorProto.STRING:
        map_proto.string_keys.extend(keys)
    else:
        # Every integer key type is stored in the shared `keys` field.
        map_proto.keys.extend(keys)
    map_proto.values.CopyFrom(values)
    return map_proto
Пример #2
0
def from_dict(dict,
              name=None):  # type: (Dict[Any, Any], Optional[Text]) -> MapProto
    """Converts a Python dictionary into a map def.

    Inputs:
        dict: Python dictionary. All keys must have the same numpy dtype
            (as given by ``np.array(key).dtype``), which must correspond to
            an integer or STRING tensor type. All values must share one
            Python type.
        name: (optional) the name of the map.
    Returns:
        map: the converted map def.
    Raises:
        ValueError: if ``dict`` is empty (the key type cannot be inferred).
        TypeError: if the keys or the values are of mixed types, or if the
            key type is not an integer or string tensor type.
    """
    keys = list(dict.keys())
    if not keys:
        raise ValueError("Cannot convert an empty dictionary into a map: "
                         "the key type cannot be inferred.")

    # The tensor key type is derived from the first key's numpy dtype.
    raw_key_type = np.array(keys[0]).dtype
    key_type = mapping.NP_TYPE_TO_TENSOR_TYPE[raw_key_type]

    valid_key_int_types = [
        TensorProto.INT8, TensorProto.INT16, TensorProto.INT32,
        TensorProto.INT64, TensorProto.UINT8, TensorProto.UINT16,
        TensorProto.UINT32, TensorProto.UINT64
    ]

    # Compare dtypes with `==`: a numpy dtype is an instance, not a class, so
    # it cannot be used as the second argument of isinstance(). Every key must
    # share the first key's dtype for `key_type` to describe all of them.
    if not all(np.array(key).dtype == raw_key_type for key in keys):
        raise TypeError("The key type in the input dictionary is not the same "
                        "for all keys and therefore is not valid as a map.")

    # Otherwise the keys would be silently dropped from the resulting map.
    if key_type != TensorProto.STRING and key_type not in valid_key_int_types:
        raise TypeError(
            "The key type {} of the input dictionary is not supported: map "
            "keys must be of an integer or string type.".format(raw_key_type))

    values = list(dict.values())
    raw_value_type = type(values[0])
    if not all(isinstance(val, raw_value_type) for val in values):
        raise TypeError(
            "The value type in the input dictionary is not the same "
            "for all values and therefore is not valid as a map.")

    value_seq = from_list(values)

    map_proto = MapProto()
    if name:
        map_proto.name = name
    map_proto.key_type = key_type
    if key_type == TensorProto.STRING:
        map_proto.string_keys.extend(keys)
    else:
        # Every integer key type is stored in the shared `keys` field.
        map_proto.keys.extend(keys)
    map_proto.values.CopyFrom(value_seq)
    return map_proto