コード例 #1
0
ファイル: test_passes.py プロジェクト: aseemw/coremltools
    def test_mixed_input_dtypes(self, op, x_dtype, y_dtype):
        """Check that homogenize_input_dtypes casts one input of a mixed-dtype op.

        Builds a program with two (10, 10) inputs whose dtypes are ``x_dtype``
        and ``y_dtype``, feeding a single ``op``. It then runs the
        ``mil_backend::homogenize_input_dtypes`` pass and verifies that exactly
        one ``cast`` to the promoted dtype now sits in front of ``op``.

        Args:
            op: name of an ``mb`` builder that accepts ``x`` and ``y`` keyword
                arguments (presumably supplied by pytest parametrization --
                confirm against the decorator).
            x_dtype: dtype string for the first input, as understood by
                ``string_to_builtin``.
            y_dtype: dtype string for the second input.
        """
        x_builtin = string_to_builtin(x_dtype)
        y_builtin = string_to_builtin(y_dtype)
        input_shape = (10, 10)

        # Keep the argument names x / y as they are; mb.program may derive the
        # program's input names from them (assumption -- confirm).
        @mb.program(input_specs=[
            mb.TensorSpec(shape=input_shape, dtype=x_builtin),
            mb.TensorSpec(shape=input_shape, dtype=y_builtin),
        ])
        def prog(x, y):
            return getattr(mb, op)(x=x, y=y)

        # Before the pass the program holds only the op under test.
        assert get_op_types_in_program(prog) == [op]

        _prev_prog, _prev_block, block = apply_pass_and_basic_check(
            prog, "mil_backend::homogenize_input_dtypes")

        # The pass rewrites prog: a single cast now precedes the op.
        assert get_op_types_in_program(prog) == ["cast", op]

        promoted_dtype = promote_types(x_builtin, y_builtin)
        expected_dtype_str = builtin_to_string(promoted_dtype)

        # The inserted cast must target the promoted dtype and feed only `op`.
        cast_op = block.find_ops(op_type="cast")[0]
        assert cast_op.dtype.val == expected_dtype_str
        cast_outputs = cast_op.outputs
        assert len(cast_outputs) == 1
        consumers = cast_outputs[0].child_ops
        assert len(consumers) == 1
        assert consumers[0].op_type == op
コード例 #2
0
    def type_inference(self):
        """Infer the list type produced by ``make_list``.

        The element type is a tensor of ``self.dtype`` whose shape is built
        from ``self.elem_shape``. Every shape entry must carry a constant
        value. String entries are treated as symbolic dimension names: each
        resolves to the already-registered symbol of that name, or to a newly
        created symbol when none exists.

        Returns:
            A ``types.list`` of that tensor type, using ``init_length`` and
            ``dynamic_length`` from the op's inputs.

        Raises:
            ValueError: if ``self.dtype`` does not map to a builtin type, or if
                an ``elem_shape`` entry has no constant value.
        """
        builtin_dtype = types.string_to_builtin(self.dtype.val)
        if builtin_dtype is None:
            raise ValueError("Unsupported dtype {}".format(self.dtype.val))

        def _resolve_dim(dim_var):
            # Turn one elem_shape entry into a concrete value or a symbol.
            dim = dim_var.val
            if dim is None:
                msg = ("make_list elem_shape must be tuple of const. "
                       "Tuple elem {} is not")
                raise ValueError(msg.format(dim_var.name))
            if not isinstance(dim, str):
                return dim
            # Reuse the symbol if this name was seen before; otherwise it
            # must be a new symbol.
            try:
                return get_existing_symbol(dim)
            except ValueError:
                return get_new_symbol(dim)

        elem_shape_sym = [_resolve_dim(dim_var) for dim_var in self.elem_shape]
        elem_type = types.tensor(builtin_dtype, elem_shape_sym)
        return types.list(
            elem_type,
            init_length=self.init_length.val,
            dynamic_length=self.dynamic_length.val,
        )
コード例 #3
0
 def type_inference(self):
     """Infer the list type produced by ``make_list``.

     The element type is a tensor of ``self.dtype``, shaped by the symbolic
     value (``sym_val``) of ``self.elem_shape``.

     Returns:
         A ``types.list`` of that tensor type, using ``init_length`` and
         ``dynamic_length`` from the op's inputs.

     Raises:
         ValueError: if ``self.dtype`` does not map to a builtin type.
     """
     dtype_name = self.dtype.val
     builtin_dtype = types.string_to_builtin(dtype_name)
     if builtin_dtype is None:
         raise ValueError("Unsupported dtype {}".format(dtype_name))

     element_type = types.tensor(builtin_dtype, self.elem_shape.sym_val)
     init_length = self.init_length.val
     dynamic_length = self.dynamic_length.val
     return types.list(
         element_type,
         init_length=init_length,
         dynamic_length=dynamic_length,
     )