Example #1
def types_to_proto(valuetype):
    if types.is_tensor(valuetype):
        primitive = types_to_proto_primitive(valuetype.get_primitive())
        return create_valuetype_tensor(valuetype.get_shape(), primitive)
    elif types.is_tuple(valuetype):
        v_type = pm.ValueType()
        t_type = v_type.tupleType
        for t in valuetype.T:
            new_v_type = t_type.types.add()
            new_v_type.CopyFrom(types_to_proto(t))
        return v_type
    elif types.is_list(valuetype):
        elem = valuetype.T[0]
        length = valuetype.T[1]
        if types.is_tensor(elem):
            dtype = types_to_proto_primitive(elem.get_primitive())
            elem_shape = elem.get_shape()
        elif types.is_scalar(elem):
            dtype = types_to_proto_primitive(elem)
            elem_shape = ()
        elif types.is_str(elem):
            dtype = types_to_proto_primitive(elem)
            elem_shape = ()
        else:
            raise NotImplementedError(
                "Only list of either tensors or scalars supported. "
                "Got element of type {}".format(elem.__type_info__()))
        return create_valuetype_list(length=length,
                                     elem_shape=elem_shape,
                                     dtype=dtype)
    elif types.is_dict(valuetype):
        return create_valuetype_dict(valuetype.T[0], valuetype.T[1])
    else:
        return create_valuetype_scalar(types_to_proto_primitive(valuetype))
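
For readers unfamiliar with the MIL type system, the following self-contained sketch mirrors the same recursive dispatch over tensor, tuple, list, dict, and scalar types using plain Python and NumPy values; describe_type and the returned dict layout are hypothetical illustrations, not coremltools APIs:

import numpy as np

def describe_type(value):
    # Tensor-like values: record shape and primitive dtype.
    if isinstance(value, np.ndarray):
        return {"kind": "tensor", "shape": value.shape, "dtype": str(value.dtype)}
    # Tuples: recurse into every member, as the tupleType branch does.
    if isinstance(value, tuple):
        return {"kind": "tuple", "types": [describe_type(v) for v in value]}
    # Lists: record the length plus a description of the element type.
    if isinstance(value, list):
        elem = value[0] if value else None
        return {"kind": "list", "length": len(value), "elem": describe_type(elem)}
    # Dicts: describe one key/value pair, mirroring the dict branch.
    if isinstance(value, dict):
        k, v = next(iter(value.items()))
        return {"kind": "dict", "key": describe_type(k), "value": describe_type(v)}
    # Everything else is treated as a scalar.
    return {"kind": "scalar", "dtype": type(value).__name__}

print(describe_type([np.zeros((2, 3), dtype=np.float32)]))
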
Example #2
    def type_inference(self):
        # check the type of "classes"
        if not types.is_list(self.classes.sym_type):
            msg = "'classes' in the op 'classify' must be of type list. Instead it is {}."
            raise ValueError(msg.format(self.classes.sym_type.__type_info__()))

        # check the type of "probabilities"
        if self.probabilities.dtype != types.fp32:
            msg = "classify op: input probabilities must be of type fp32. Instead it is of type {}"
            raise TypeError(
                msg.format(self.probabilities.sym_type.get_primitive().
                           __type_info__()))

        classes_elem_type = self.classes.elem_type
        if classes_elem_type not in {types.str, types.int64}:
            msg = "Type of elements in 'classes' in the op 'classify' must be either str or int64. Instead it is {}."
            raise ValueError(msg.format(classes_elem_type.__type_info__()))

        # check that the size of "classes" is compatible with the size of "probabilities"
        if not any_symbolic(self.probabilities.shape):
            size = np.prod(self.probabilities.shape)
            if len(self.classes.val) != size:
                msg = "In op 'classify', number of classes must match the size of the tensor corresponding to 'probabilities'."
                raise ValueError(msg)

        return classes_elem_type, types.dict(classes_elem_type, types.double)
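
The checks above follow a common validation pattern: confirm the container type, the element type, and the size agreement before committing to output types. A minimal, self-contained sketch of that pattern, with hypothetical names and outside the MIL op framework:

import numpy as np

def infer_classify_output(classes, probabilities):
    # "classes" must be a list of labels.
    if not isinstance(classes, list):
        raise ValueError("'classes' must be a list, got {}".format(type(classes).__name__))
    # Probabilities must be fp32.
    if probabilities.dtype != np.float32:
        raise TypeError("probabilities must be float32, got {}".format(probabilities.dtype))
    # Labels must be str or int.
    elem_type = type(classes[0])
    if elem_type not in (str, int):
        raise ValueError("class labels must be str or int, got {}".format(elem_type.__name__))
    # The number of labels must match the flattened probability tensor.
    if len(classes) != int(np.prod(probabilities.shape)):
        raise ValueError("number of classes must match the size of 'probabilities'")
    # Outputs: the predicted label type and a {label: probability} mapping.
    return elem_type, dict

print(infer_classify_output(["cat", "dog"], np.array([0.3, 0.7], dtype=np.float32)))
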
Example #3
def _does_block_contain_symbolic_shape(block):
    for op in block.operations:
        for b in op.blocks:
            if _does_block_contain_symbolic_shape(b):
                return True
        for out in op.outputs:
            if types.is_tensor(out.sym_type):
                shape = out.sym_type.get_shape()
                if any_symbolic(shape):
                    return True
            elif types.is_scalar(out.sym_type) or types.is_str(out.sym_type):
                if is_symbolic(out.val):
                    return True
            elif types.is_list(out.sym_type):
                if types.is_tensor(out.elem_type):
                    if any_symbolic(out.elem_type.get_shape()):
                        return True
                else:
                    raise NotImplementedError(
                        "'{}' type in a list not handled".format(out.elem_type))
            else:
                raise NotImplementedError(
                    "'{}' type is not handled".format(out.sym_type))
    return False
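
The helper recurses into an op's child blocks before inspecting its outputs, so a symbolic dimension anywhere in a nested program (e.g. inside a cond or while_loop body) is detected. A toy version of the same depth-first walk over plain nested dictionaries (a hypothetical data layout, not the MIL Block/Operation classes):

def contains_symbolic(block):
    for op in block["ops"]:
        # Recurse into nested blocks first.
        for child in op.get("blocks", []):
            if contains_symbolic(child):
                return True
        # Then check the op's own outputs.
        for out in op.get("outputs", []):
            if out == "symbolic":  # stand-in for any_symbolic(...) / is_symbolic(...)
                return True
    return False

inner = {"ops": [{"outputs": ["symbolic"]}]}
outer = {"ops": [{"blocks": [inner], "outputs": ["concrete"]}]}
print(contains_symbolic(outer))  # True
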
Example #4
def type_str(self):
    is_tensor = types.is_tensor(self.sym_type)
    is_list = types.is_list(self.sym_type)
    if is_tensor:
        type_string = "(Tensor)"
    elif is_list:
        type_string = "(List)"
    else:
        type_string = "(Scalar)"
    return type_string
Example #5
def create_immediate_value(var):
    if types.is_tensor(var.sym_type):
        return create_tensor_value(var.val)
    elif types.is_list(var.sym_type):
        if var.elem_type == types.str:
            return create_list_scalarvalue(var.val, str)
        elif var.elem_type == types.int64:
            return create_list_scalarvalue(var.val, np.int64)
        else:
            raise NotImplementedError(
                "List element type, {}, not supported yet.".format(
                    var.sym_type.__type_info__()))
    else:
        return create_scalar_value(var.val)
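
The same type-driven branching appears here for immediate (constant) values: tensors, lists of str/int64 scalars, and plain scalars each map to a different value constructor. A self-contained sketch of the pattern on ordinary Python values (immediate_value and the returned dicts are hypothetical, not coremltools APIs):

import numpy as np

def immediate_value(value):
    # Tensor branch: keep the array as-is.
    if isinstance(value, np.ndarray):
        return {"kind": "tensor", "value": value}
    # List branch: only homogeneous lists of str or int are supported.
    if isinstance(value, list):
        elem_type = type(value[0])
        if elem_type not in (str, int):
            raise NotImplementedError(
                "List element type, {}, not supported yet.".format(elem_type.__name__))
        return {"kind": "list", "elem_type": elem_type.__name__, "value": value}
    # Scalar branch: everything else.
    return {"kind": "scalar", "value": value}

print(immediate_value([1, 2, 3]))
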
Example #6
File: load.py  Project: apple/coremltools
def _create_var_from_spec(spec):
    """
    This helper function is used for creating a PyMIL Var/ListVar from the proto spec.
    Mainly used for the construction of the control flow ops.
    """
    assert isinstance(spec, pm.NamedValueType)
    sym_type = proto_to_types(spec.type)
    name = spec.name
    if types.is_list(sym_type):
        var = ListVar(
            name, 
            elem_type=sym_type.T[0], 
            init_length=sym_type.T[1], 
            dynamic_length=sym_type.T[2])
    else:
        var = Var(name, sym_type, None, op=None, op_output_idx=None)
    return var
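
The branch on types.is_list decides between two constructors: list types carry their element type, initial length, and dynamic-length flag in T, while everything else becomes a plain Var. A minimal stand-alone sketch of that split (ListSpec and VarSpec are hypothetical stand-ins, not the real ListVar/Var classes):

from collections import namedtuple

# Hypothetical stand-ins mirroring the two construction paths above.
ListSpec = namedtuple("ListSpec", ["name", "elem_type", "init_length", "dynamic_length"])
VarSpec = namedtuple("VarSpec", ["name", "sym_type"])

def var_from_spec(name, sym_type, is_list, T=None):
    if is_list:
        # T follows the (elem_type, init_length, dynamic_length) ordering.
        return ListSpec(name, elem_type=T[0], init_length=T[1], dynamic_length=T[2])
    return VarSpec(name, sym_type)

print(var_from_spec("xs", "list", True, T=("fp32", 10, True)))
print(var_from_spec("x", "fp32", False))
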
Example #7
def _is_compatible(self, v):
    return (types.is_list(v.sym_type) or types.is_scalar(v.dtype)
            or types.is_tensor(v.dtype))
Example #8
def _is_compatible(self, v):
    return types.is_list(v.sym_type)
Example #9
    def type_value_inference(self, overwrite_output=False):
        """
        Perform type inference and auto_val computation based on new input Vars
        in kwargs. If self._output_vars is None, generate _output_vars;
        otherwise no new Var is created, but the type inference result is
        verified against the existing _output_vars when overwrite_output is
        False.

        If overwrite_output is True, the type inference result overwrites the
        existing _output_vars.
        """
        output_types = self.type_inference()
        if not isinstance(output_types, tuple):
            output_types = (output_types, )
        output_vals = self._auto_val(output_types)
        try:
            output_names = self.output_names()
            if not isinstance(output_names, tuple):
                output_names = (output_names, )
        except NotImplementedError:
            if len(output_types) > 1:
                output_names = tuple(
                    str(i) for i, _ in enumerate(output_types))
            else:
                output_names = ("", )  # output name same as op name.

        # Combine (output_names, output_types, output_vals) to create output
        # Vars.
        if self._output_vars is None:
            self._output_vars = []
            for i, (n, sym_type, sym_val) in enumerate(
                    zip(output_names, output_types, output_vals)):
                name = self.name + "_" + n if n != "" else self.name
                if types.is_list(sym_type):
                    new_var = ListVar(
                        name,
                        elem_type=sym_type.T[0],
                        init_length=sym_type.T[1],
                        dynamic_length=sym_type.T[2],
                        sym_val=sym_val if
                        (sym_val is not None
                         and isinstance(sym_val.val, list)) else None,
                        op=self,
                        op_output_idx=i,
                    )
                else:
                    new_var = Var(name,
                                  sym_type,
                                  sym_val,
                                  op=self,
                                  op_output_idx=i)
                self._output_vars.append(new_var)
        else:
            # Check new inference result against existing self._output_vars.
            for i, (n, sym_type, sym_val) in enumerate(
                    zip(output_names, output_types, output_vals)):
                out_var = self._output_vars[i]
                # Check type inference
                if overwrite_output:
                    out_var._sym_type = sym_type
                elif not is_compatible_type(sym_type, out_var.sym_type):
                    msg = "Output Var {} in op {} type changes with new input Vars"
                    raise ValueError(msg.format(out_var.name, self.name))

                # Check value inference
                if overwrite_output:
                    out_var._sym_val = sym_val

                if sym_val is not None and out_var.sym_val is not None:
                    if np.any(sym_val.val != out_var.sym_val):
                        if overwrite_output:
                            out_var._sym_val = sym_val
                        else:
                            msg = 'value_inference differs for var {} in op {}'
                            if not _is_compatible_symbolic_array(
                                    sym_val.val, out_var.sym_val):
                                raise ValueError(
                                    msg.format(out_var.name, self.name))
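
Before building output Vars, the method normalizes both the inferred types and the output names to tuples and zips them together with the inferred values; names default to indices when there are several outputs and to the op name otherwise. A minimal sketch of just that normalization step (normalize_outputs is a hypothetical helper, not part of the Operation class):

def normalize_outputs(op_name, output_types, output_names=None):
    # Accept either a single type or a tuple of types.
    if not isinstance(output_types, tuple):
        output_types = (output_types,)
    # Default names: indices when there are several outputs, "" otherwise.
    if output_names is None:
        output_names = (tuple(str(i) for i in range(len(output_types)))
                        if len(output_types) > 1 else ("",))
    # Output var names are "<op>_<name>" unless the name is empty.
    names = [op_name + "_" + n if n != "" else op_name for n in output_names]
    return list(zip(names, output_types))

print(normalize_outputs("classify", ("str", "dict")))  # [('classify_0', 'str'), ('classify_1', 'dict')]
print(normalize_outputs("relu", "tensor"))             # [('relu', 'tensor')]
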