Example #1
0
 def type_inference(self):
     """Infer this op's output list type.

     When no concrete element shape is available the element type is left
     unknown; otherwise the element type is a tensor of `dtype` with shape
     `elem_shape`.

     Raises:
         ValueError: if `dtype` does not name a supported builtin type.
     """
     length = self.init_length.val
     # No usable element shape: the element type stays unknown.
     if self.elem_shape is None or self.elem_shape.sym_val is None:
         return types.list(types.unknown,
                           init_length=length,
                           dynamic_length=self.dynamic_length.val)
     builtin_dtype = types.string_to_builtin(self.dtype.val)
     if builtin_dtype is None:
         raise ValueError("Unsupported dtype {}".format(self.dtype.val))
     elem_type = types.tensor(builtin_dtype, self.elem_shape.sym_val)
     return types.list(elem_type,
                       init_length=length,
                       dynamic_length=self.dynamic_length.val)
Example #2
0
    def __init__(self,
                 name,
                 elem_type=None,
                 init_length=None,
                 dynamic_length=True,
                 sym_val=None,
                 **kwargs):
        """
        A variable holding a list value.

        elem_type (builtin.tensor): type of each list element.

        init_length (int): initial length.

        dynamic_length (bool): True to allow list to grow. False uses
        init_length as the fixed size (init_length is runtime length).

        sym_val: value of the list, if available.
        """
        # The list's symbolic type is fully determined by the element
        # type, initial length, and growability.
        list_type = types.list(elem_type, init_length, dynamic_length)
        super(ListVar, self).__init__(name=name,
                                      sym_type=list_type,
                                      sym_val=sym_val,
                                      **kwargs)
        self._elem_type = elem_type
        self.init_length = init_length
        self.dynamic_length = dynamic_length
Example #3
0
    def type_inference(self):
        """Infer the list type produced by make_list.

        The element type is a tensor of `dtype` with shape `elem_shape`.
        Each entry of `elem_shape` must be a const (int or str); string
        entries are turned into symbols, reusing an existing symbol of the
        same name when one exists.

        Raises:
            ValueError: if `dtype` is unsupported, or any `elem_shape`
                entry is not a const.
        """
        builtin_dtype = types.string_to_builtin(self.dtype.val)
        if builtin_dtype is None:
            raise ValueError("Unsupported dtype {}".format(self.dtype.val))

        # Replace string with symbol
        elem_shape_sym = []
        for s_var in self.elem_shape:
            # s is str or int
            s = s_var.val
            if s is None:
                msg = ('make_list elem_shape must be tuple of const. '
                       'Tuple elem {} is not')
                raise ValueError(msg.format(s_var.name))

            if isinstance(s, str):
                try:
                    symbol = get_existing_symbol(s)
                except ValueError:
                    # Must be a new symbol
                    symbol = get_new_symbol(s)
                elem_shape_sym.append(symbol)
            else:
                elem_shape_sym.append(s)

        elem_type = types.tensor(builtin_dtype, elem_shape_sym)
        return types.list(
            elem_type,
            init_length=self.init_length.val,
            dynamic_length=self.dynamic_length.val,
        )
    def type_inference(self):
        """Infer the output list type after writing `value` into `ls`.

        If the list has no element type (or an unknown one), the value's
        type is adopted as the element type; otherwise the value's type
        must be a subtype of the list's element type.

        Raises:
            ValueError: if the value's type is not a subtype of the
                list's element type.
        """
        elem_type = self.ls.elem_type
        value_type = self.value.sym_type
        growable = self.ls.dynamic_length
        length = self.ls.init_length

        # No element type recorded yet: adopt the value's type.
        if elem_type is None:
            return types.list(value_type,
                              init_length=length,
                              dynamic_length=growable)

        # Unknown element type: warn, then adopt the value's type.
        if elem_type == types.unknown:
            logging.warning(
                "Input ls elem type unknown. Override with {}".format(
                    value_type))
            return types.list(value_type,
                              init_length=length,
                              dynamic_length=growable)

        if not types.is_subtype(value_type, elem_type):
            raise ValueError(
                "Elem type mismatch: ls elem type {} vs value type {}".format(
                    elem_type, value_type))
        return self.ls.sym_type
Example #5
0
def proto_to_types(valuetype):
    """
    A helper function that maps the proto value type to PyMIL types.

    Supports tensorType (rank-0 tensors become scalars), listType, and
    dictionaryType protos.

    Raises:
        ValueError: for negative rank, rank/dimension mismatch, or
            tensorType attributes.
        NotImplementedError: for any other proto type.
    """
    kind = valuetype.WhichOneof("type")

    if kind == "tensorType":
        tensortype = valuetype.tensorType
        dtype = types.proto_to_builtin_types[tensortype.dataType]

        if tensortype.rank < 0:
            raise ValueError("Negative or Dynamic ranks not supported")
        if tensortype.rank != len(tensortype.dimensions):
            raise ValueError("Rank doesn't match the number of dimensions")
        if tensortype.attributes != {}:
            raise ValueError("Attributes on tensorType not supported")

        # For the zero rank tensor, we always convert it back to scalar in
        # PyMIL first.
        if tensortype.rank == 0:
            return dtype

        shape = [get_proto_dim(tensortype.dimensions[i])
                 for i in range(tensortype.rank)]
        return types.tensor(dtype, shape)

    if kind == "listType":
        listtype = valuetype.listType
        elem_type = proto_to_types(listtype.type)

        if listtype.length.unknown:
            init_length = None
        else:
            init_length = listtype.length.constant.size

        # In the MIL proto, there is no such thing of "dynamic_length",
        # hence we set it to True when converting back to PyMIL.
        return types.list(elem_type, init_length, dynamic_length=True)

    if kind == "dictionaryType":
        dicttype = valuetype.dictionaryType
        # Avoid rebinding the `valuetype` parameter (the original shadowed
        # it with the converted value type).
        key_type = proto_to_types(dicttype.keyType)
        value_type = proto_to_types(dicttype.valueType)
        return types.dict(key_type, value_type)

    raise NotImplementedError("Types {} not yet implemented".format(kind))
    def type_inference(self):
        """Infer the output list type after scattering `value` rows into
        `ls` at `indices`.

        The element type written per index is `value`'s type with its
        leading axis dropped. An unknown list element type is overridden
        by that type; otherwise it must be a compatible supertype.

        Raises:
            ValueError: if the number of values and indices differ, or
                the written element type is not a subtype of the list's
                element type.
        """
        n_indices = self.indices.shape[0]
        n_values = self.value.shape[0]
        if n_values != n_indices:
            raise ValueError("Cannot scatter {} values to {} indices".format(
                n_values, n_indices))

        value_type = self.value.sym_type
        # Each scattered element is one row of `value` (leading axis dropped).
        elem_type = types.tensor(value_type.get_primitive(),
                                 value_type.get_shape()[1:])

        list_elem_type = self.ls.elem_type
        if list_elem_type == types.unknown:
            # Fill in the elem type using value's type info.
            return types.list(elem_type,
                              dynamic_length=self.ls.dynamic_length,
                              init_length=self.ls.init_length)
        if not types.is_subtype(elem_type, list_elem_type):
            raise ValueError(
                "Elem type mismatch: ls elem type {} vs value type {}".format(
                    list_elem_type, elem_type))
        return self.ls.sym_type