Ejemplo n.º 1
0
    def _get_type_val(self, value):
        """Normalize a constant ``value`` and infer its MIL builtin type.

        Coercions applied before type inference:

        * Python ``float`` / ``np.float64`` scalars become ``np.float32``.
        * Python ``bool`` becomes ``np.bool_``.
        * Python ``int`` / ``np.int64`` scalars become ``np.int32``.
        * tuples, lists and arrays become ``np.ndarray``; an int64 array is
          cast to int32 and a float64 array to float32.
        * a ``mil_list`` is mapped to a fixed-length MIL list type whose
          element type is inferred from its first element.

        Other NumPy scalars and strings are kept as-is.

        Args:
            value: the constant to normalize.

        Returns:
            A ``(builtin_type, value)`` pair, where ``value`` is the normalized
            constant (for a ``mil_list`` the original object is returned).

        Raises:
            ValueError: if ``value`` is an empty ``mil_list`` or has a type that
                is not accepted here.
        """
        if isinstance(value, (float, np.float64)):
            value = np.float32(value)
        elif isinstance(value, bool):
            # ``np.bool`` was merely an alias of the builtin ``bool``; it was
            # deprecated in NumPy 1.20 and removed in 1.24. ``np.bool_`` is the
            # NumPy boolean scalar type on every NumPy version.
            value = np.bool_(value)
        elif isinstance(value, (int, np.int64)):
            # ``bool`` is a subclass of ``int``, hence the bool branch above
            # has to come first.
            value = np.int32(value)
        elif isinstance(value, (tuple, list, np.ndarray)):
            value = np.array(value)
            if value.dtype == np.int64:
                # We use int32 by default.
                value = value.astype(np.int32)
            elif value.dtype == np.float64:
                # We use float32 by default.
                value = value.astype(np.float32)

        elif isinstance(value, mil_list):
            # if val that was passed in is of type mil_list, which is just a wrapper on top of python list
            # then construct the list type
            list_value = value.ls
            if len(list_value) == 0:
                raise ValueError("'mil_list' points to an empty list")
            builtin_elem_type, _ = self._get_type_val(list_value[0])
            from coremltools.converters.mil.mil.types.type_list import list as types_list
            builtin_type = types_list(builtin_elem_type, init_length=len(list_value), dynamic_length=False)
            return builtin_type, value

        if not isinstance(value, (np.generic, np.ndarray, str, bool, mil_list)):
            raise ValueError("Unknown value for constant: {}".format(value))

        _, builtin_type = numpy_val_to_builtin_val(value)
        return builtin_type, value
Ejemplo n.º 2
0
    def _get_type_val(self, value):
        """Normalize a constant ``value`` and infer its MIL builtin type.

        Coercions applied before type inference:

        * Python ``float`` / ``np.float64`` scalars become ``np.float32``.
        * Python ``bool`` becomes ``np.bool_``.
        * Python ``int`` / ``np.int64`` scalars become ``np.int32``.
        * tuples, lists and arrays become ``np.ndarray``; arrays with dtype
          uint8, int8, uint16, int16, uint32, uint64 or int64 are cast to
          int32, and float64 arrays to float32 (64-bit downcasts are logged at
          DEBUG level using ``self.name``).
        * a ``mil_list`` is mapped to a fixed-length MIL list type whose
          element type is inferred from its first element; an ``np.int64``
          first element keeps the ``types.int64`` element type.

        Other NumPy scalars and strings are kept as-is.

        Args:
            value: the constant to normalize.

        Returns:
            A ``(builtin_type, value)`` pair, where ``value`` is the normalized
            constant (for a ``mil_list`` the original object is returned).

        Raises:
            ValueError: if ``value`` is an empty ``mil_list`` or has a type that
                is not accepted here.
        """
        if isinstance(value, (float, np.float64)):
            value = np.float32(value)
        elif isinstance(value, bool):
            # ``np.bool`` was merely an alias of the builtin ``bool``; it was
            # deprecated in NumPy 1.20 and removed in 1.24. ``np.bool_`` is the
            # NumPy boolean scalar type on every NumPy version.
            value = np.bool_(value)
        elif isinstance(value, (int, np.int64)):
            # ``bool`` is a subclass of ``int``, hence the bool branch above
            # has to come first.
            value = np.int32(value)
        elif isinstance(value, (tuple, list, np.ndarray)):
            value = np.array(value)

            # For the int type, we use int32 by default.
            # NOTE: ``astype`` casts unsafely, so uint32/uint64/int64 values
            # outside the int32 range wrap around rather than raise.
            if value.dtype in (
                    np.uint8, np.int8, np.uint16, np.int16, np.uint32,
                    np.uint64, np.int64,
            ):
                if value.dtype in (np.uint64, np.int64):
                    logging.debug(
                        "Downcast const op %s data int64 as int32", self.name)
                value = value.astype(np.int32)

            # For the float type, we use float32 by default
            elif value.dtype == np.float64:
                logging.debug(
                    "Downcast const op %s data fp64 as fp32", self.name)
                value = value.astype(np.float32)

        elif isinstance(value, mil_list):
            # if val that was passed in is of type mil_list, which is just a wrapper on top of python list
            # then construct the list type
            list_value = value.ls
            if len(list_value) == 0:
                raise ValueError("'mil_list' points to an empty list")
            builtin_elem_type, _ = self._get_type_val(list_value[0])
            # mil_list is a special case that we want to preserve the int64 element type
            if isinstance(list_value[0], np.int64):
                builtin_elem_type = types.int64
            from coremltools.converters.mil.mil.types.type_list import list as types_list
            builtin_type = types_list(builtin_elem_type,
                                      init_length=len(list_value),
                                      dynamic_length=False)
            return builtin_type, value

        if not isinstance(value,
                          (np.generic, np.ndarray, str, bool, mil_list)):
            raise ValueError("Unknown value for constant: {}".format(value))

        _, builtin_type = numpy_val_to_builtin_val(value)
        return builtin_type, value
Ejemplo n.º 3
0
    def _get_type_val(self, value):
        """Normalize a constant ``value`` and infer its MIL builtin type.

        Coercions applied before type inference:

        * Python ``float`` / ``np.float64`` scalars become ``np.float32``.
        * Python ``bool`` becomes ``np.bool_``.
        * ``six.integer_types`` / ``np.int64`` scalars become ``np.int32``.
        * tuples, lists and arrays become ``np.ndarray``; an int64 array is
          cast to int32 and a float64 array to float32.

        Other NumPy scalars and strings are kept as-is.

        Args:
            value: the constant to normalize.

        Returns:
            A ``(builtin_type, value)`` pair, where ``value`` is the normalized
            constant.

        Raises:
            ValueError: if ``value`` has a type that is not accepted here.
        """
        if isinstance(value, (float, np.float64)):
            value = np.float32(value)
        elif isinstance(value, bool):
            # ``np.bool`` was merely an alias of the builtin ``bool``; it was
            # deprecated in NumPy 1.20 and removed in 1.24. ``np.bool_`` is the
            # NumPy boolean scalar type on every NumPy version.
            value = np.bool_(value)
        elif isinstance(value, (six.integer_types, np.int64)):
            # ``bool`` is a subclass of ``int``, hence the bool branch above
            # has to come first.
            value = np.int32(value)
        elif isinstance(value, (tuple, list, np.ndarray)):
            value = np.array(value)
            if value.dtype == np.int64:
                # We use int32 by default.
                value = value.astype(np.int32)
            elif value.dtype == np.float64:
                # We use float32 by default.
                value = value.astype(np.float32)

        if not isinstance(value, (np.generic, np.ndarray, six.string_types, bool)):
            raise ValueError("Unknown value for constant: {}".format(value))

        _, builtin_type = numpy_val_to_builtin_val(value)
        return builtin_type, value
Ejemplo n.º 4
0
def _constant_propagation(fn, new_graph, constant_nodes, constant_node_num_outputs):
    """Evaluate constant nodes with TensorFlow and record their values in ``fn.graph``.

    ``new_graph`` is imported into a fresh ``tf.Graph`` and every output of
    every node in ``constant_nodes`` is fetched. Most outputs are fetched in a
    single ``Session.run`` call. Outputs whose name contains "switch" or "cond"
    are fetched one at a time: if fetching one of them fails, only that output
    is lost (its value is recorded as ``None``). The fetched values are
    converted with ``numpy_val_to_builtin_val`` and stored on the
    ``value``/``datatype`` attributes of the matching nodes of ``fn.graph``.
    Multi-output nodes get a tuple type. Finally, each ``get_tuple`` node whose
    input has a known value takes the element at its ``attr["index"]``.

    The pass is best-effort. A node whose values cannot be looked up or
    converted is logged and skipped. Any other failure is logged via
    ``logging.exception`` instead of being raised.

    Args:
        fn: object whose ``graph`` dict maps node names to nodes exposing
            ``op``, ``inputs``, ``attr``, ``value`` and ``datatype``.
        new_graph: graph definition passed to ``tf.import_graph_def``.
        constant_nodes: names of the nodes whose outputs are computed.
        constant_node_num_outputs: maps a node name to its number of outputs.
    """
    try:
        # Fetched tensor values keyed by "<node>:<output index>". Bound up front
        # so the assignment loop below never reads an unbound local.
        result = {}
        if len(constant_nodes) > 0:
            with tf.Graph().as_default() as graph:
                tf.import_graph_def(new_graph, name="")

                # We're only making one call to `sess.run()` in order to compute constant values.
                # In this context, the default optimization settings make everything dramatically
                # slower and more memory-intensive.
                # NOTE(review): if _StrictVersion is distutils' StrictVersion,
                # pre-release strings such as "1.15.0rc0" cannot be parsed and
                # this comparison raises, ending the pass via the outer handler.
                if tf.__version__ < _StrictVersion("1.13.1"):
                    session_config = tf.ConfigProto()
                    session_config.graph_options.optimizer_options.opt_level = (
                        tf.OptimizerOptions.L0
                    )
                    sess = tf.Session(graph=graph, config=session_config)
                else:
                    session_config = tf.compat.v1.ConfigProto()
                    session_config.graph_options.optimizer_options.opt_level = (
                        tf.compat.v1.OptimizerOptions.L0
                    )
                    session_config.graph_options.rewrite_options.disable_meta_optimizer = (
                        True
                    )
                    sess = tf.compat.v1.Session(graph=graph, config=session_config)

                # Always release the session, even when a fetch below raises.
                try:
                    query_list = []
                    control_flow_ops = []
                    for c in constant_nodes:
                        for j in range(constant_node_num_outputs[c]):
                            query = c + ":" + str(j)
                            lower_query = query.lower()
                            if "switch" in lower_query or "cond" in lower_query:
                                control_flow_ops.append(query)
                            else:
                                query_list.append(query)
                    result_list = sess.run(query_list)
                    result = dict(zip(query_list, result_list))
                    # propagate switch one by one
                    for op in control_flow_ops:
                        try:
                            res = sess.run([op])
                            result[op] = res[0]
                        except Exception:
                            logging.warning(
                                '[Constant Propagation] Skip "dead" tensor: {}'.format(
                                    op
                                )
                            )
                            result[op] = None
                finally:
                    sess.close()

        for k, v in fn.graph.items():
            if k not in constant_node_num_outputs:
                continue
            if constant_node_num_outputs[k] == 1:
                result_entry = k + ":0"
                try:
                    v.value, v.datatype = numpy_val_to_builtin_val(
                        result[result_entry]
                    )
                except Exception:
                    # Log and move on. ``.get`` keeps this handler from raising
                    # again when the entry itself is missing.
                    logging.error(result_entry)
                    logging.error(result.get(result_entry))
            else:
                entries = [
                    k + ":" + str(i) for i in range(constant_node_num_outputs[k])
                ]
                try:
                    values = [result[entry] for entry in entries]
                    npval = [numpy_val_to_builtin_val(i) for i in values]
                    v.datatype = types.tuple(tuple(val[1] for val in npval))
                    v.value = v.datatype()
                    for idx, val in enumerate(npval):
                        v.value.val[idx] = val[0]
                except Exception:
                    logging.error([result.get(entry) for entry in entries])

        for k, v in fn.graph.items():
            if v.op == "get_tuple":
                inp = fn.graph[v.inputs[0]]
                idx = v.attr["index"]
                if inp.value is not None:
                    v.value = inp.value.val[idx]
                    v.datatype = inp.datatype.T[idx]

    except Exception as e:
        logging.exception("Constant Propagation pass failed: {}".format(e))