def type_inference(self):
    """Infer the list type produced by writing ``self.value`` into ``self.ls``.

    If the input list carries no element type (``None``) or an unknown one
    (``types.unknown``), a new list type is built whose element type is the
    value's ``sym_type``. That new type carries over the input list's
    ``init_length`` and ``dynamic_length``. Only the ``types.unknown`` case
    emits a warning. Otherwise the input list's own ``sym_type`` is returned.

    Returns:
        The list type of the op's output.

    Raises:
        ValueError: if the value's type is not a subtype of the list's
            existing element type.
    """
    ls = self.ls
    existing_elem = ls.elem_type
    incoming_type = self.value.sym_type
    dynamic_length = ls.dynamic_length
    init_length = ls.init_length

    # The input list has no usable element type, so adopt the written value's type.
    if existing_elem is None or existing_elem == types.unknown:
        if existing_elem is not None:
            # Only an explicitly unknown elem type is worth warning about.
            logging.warning(
                "Input ls elem type unknown. Override with {}".format(incoming_type)
            )
        return types.list(
            incoming_type, init_length=init_length, dynamic_length=dynamic_length
        )

    if types.is_subtype(incoming_type, existing_elem):
        return ls.sym_type

    raise ValueError(
        "Elem type mismatch: ls elem type {} vs value type {}".format(
            existing_elem, incoming_type
        )
    )
def type_inference(self):
    """Infer the list type produced by scattering ``self.value`` into ``self.ls``.

    The leading dimension of ``self.value`` must equal the leading dimension
    of ``self.indices``. The list element type is derived from the value as a
    tensor with the value's primitive dtype and the value's shape with its
    leading dimension removed.

    If the input list carries no element type (``None``) or an unknown one
    (``types.unknown``), a new list type is built from that derived element
    type. That new type carries over the input list's ``dynamic_length`` and
    ``init_length``. Otherwise the input list's own ``sym_type`` is returned.

    Returns:
        The list type of the op's output.

    Raises:
        ValueError: if the value's leading dimension differs from the
            indices' leading dimension, or if the derived element type is not
            a subtype of the list's existing element type.
    """
    num_indices = self.indices.shape[0]
    num_values = self.value.shape[0]
    # NOTE(review): if shapes can hold symbolic dims, this `!=` compares them
    # structurally rather than at runtime value -- confirm that is intended.
    if num_values != num_indices:
        raise ValueError(
            "Cannot scatter {} values to {} indices".format(num_values, num_indices)
        )

    list_elem_type = self.ls.elem_type
    value_type = self.value.sym_type
    dynamic_length = self.ls.dynamic_length
    init_length = self.ls.init_length

    # Element type = value's dtype with the leading (scatter) axis dropped.
    elem_type = types.tensor(value_type.get_primitive(), value_type.get_shape()[1:])

    # Treat a missing element type (None) the same as types.unknown, as
    # list_write does, rather than handing None to types.is_subtype.
    if list_elem_type is None or list_elem_type == types.unknown:
        # fill in the elem type using value's type info.
        return types.list(
            elem_type, dynamic_length=dynamic_length, init_length=init_length
        )
    if not types.is_subtype(elem_type, list_elem_type):
        msg = "Elem type mismatch: ls elem type {} vs " + "value type {}"
        raise ValueError(msg.format(list_elem_type, elem_type))
    return self.ls.sym_type
def _check_is_compatible_type(type1, type2):
    """Return True if ``type1`` is a subtype of ``type2``.

    When ``types.is_subtype(type1, type2)`` does not hold, fall back to the
    compatibility flag (first element of the returned pair) reported by
    ``types.is_tensor_and_is_compatible(type1, type2)``.
    """
    if types.is_subtype(type1, type2):
        return True
    compatible, _unused = types.is_tensor_and_is_compatible(type1, type2)
    return compatible