def test_mixed_input_dtypes(self, op, x_dtype, y_dtype):
    """Check that the homogenize_input_dtypes pass inserts a single cast
    so both operands of an elementwise op share the promoted dtype."""

    @mb.program(
        input_specs=[
            mb.TensorSpec(shape=(10, 10), dtype=string_to_builtin(x_dtype)),
            mb.TensorSpec(shape=(10, 10), dtype=string_to_builtin(y_dtype)),
        ]
    )
    def prog(x, y):
        return getattr(mb, op)(x=x, y=y)

    # Before the pass, the program holds only the elementwise op itself.
    assert get_op_types_in_program(prog) == [op]

    _, _, block = apply_pass_and_basic_check(
        prog, "mil_backend::homogenize_input_dtypes"
    )

    # After the pass, exactly one cast has been prepended.
    assert get_op_types_in_program(prog) == ["cast", op]

    promoted_dtype = promote_types(
        string_to_builtin(x_dtype), string_to_builtin(y_dtype)
    )

    # The inserted cast must target the promoted dtype and feed only `op`.
    cast = block.find_ops(op_type="cast")[0]
    assert cast.dtype.val == builtin_to_string(promoted_dtype)
    assert len(cast.outputs) == 1
    assert len(cast.outputs[0].child_ops) == 1
    assert cast.outputs[0].child_ops[0].op_type == op
def type_inference(self):
    """Infer the list type: tensor elements of the requested builtin dtype,
    with string shape entries resolved to (new or existing) symbols.

    Raises:
        ValueError: if the dtype string is unsupported, or if any
            elem_shape entry is not a compile-time constant.
    """
    elem_dtype = types.string_to_builtin(self.dtype.val)
    if elem_dtype is None:
        raise ValueError("Unsupported dtype {}".format(self.dtype.val))

    def _resolve_dim(dim_var):
        # Each dim is expected to be a const: a str (symbolic) or an int.
        dim = dim_var.val
        if dim is None:
            msg = 'make_list elem_shape must be tuple of const. ' +\
                'Tuple elem {} is not'
            raise ValueError(msg.format(dim_var.name))
        if not isinstance(dim, str):
            return dim
        try:
            return get_existing_symbol(dim)
        except ValueError:
            # Name not registered yet — must be a new symbol.
            return get_new_symbol(dim)

    shape_syms = [_resolve_dim(d) for d in self.elem_shape]
    return types.list(
        types.tensor(elem_dtype, shape_syms),
        init_length=self.init_length.val,
        dynamic_length=self.dynamic_length.val,
    )
def type_inference(self):
    """Infer the list type from the dtype string and the (possibly
    symbolic) element shape.

    Raises:
        ValueError: if the dtype string is unsupported.
    """
    elem_dtype = types.string_to_builtin(self.dtype.val)
    if elem_dtype is None:
        raise ValueError("Unsupported dtype {}".format(self.dtype.val))
    return types.list(
        types.tensor(elem_dtype, self.elem_shape.sym_val),
        init_length=self.init_length.val,
        dynamic_length=self.dynamic_length.val,
    )