Esempio n. 1
0
    def type_inference(self):
        """Infer the list type produced by writing ``self.value`` into ``self.ls``.

        When the input list has no element type yet (``None``) or its element
        type is ``types.unknown``, a fresh list type is built whose element type
        is the value's type, carrying over the input list's ``init_length`` and
        ``dynamic_length``. In the ``types.unknown`` case a warning is logged
        first. Otherwise the value's type must be a subtype of the list's
        element type, and the input list's own type is returned unchanged.

        Raises:
            ValueError: if the value's type is not a subtype of the list's
                element type.
        """
        ls = self.ls
        current_elem = ls.elem_type
        incoming = self.value.sym_type
        dyn = ls.dynamic_length
        init = ls.init_length

        # An unset element type is silently filled in; an unknown one is
        # overridden too, but with a warning so the override is visible.
        needs_fill = current_elem is None
        if not needs_fill and current_elem == types.unknown:
            logging.warning(
                "Input ls elem type unknown. Override with {}".format(incoming))
            needs_fill = True
        if needs_fill:
            # fill in the elem type using value's type info.
            return types.list(incoming, init_length=init, dynamic_length=dyn)

        if types.is_subtype(incoming, current_elem):
            return ls.sym_type
        raise ValueError(
            "Elem type mismatch: ls elem type {} vs value type {}".format(
                current_elem, incoming))
Esempio n. 2
0
    def type_inference(self):
        """Infer the list type produced by scattering ``self.value`` into ``self.ls``.

        The first dimension of ``self.value`` must match the first dimension of
        ``self.indices``. Each list element is given the type of ``self.value``
        with that leading dimension dropped, i.e. a tensor with the value's
        primitive type and ``value.get_shape()[1:]`` as shape.

        If the input list's element type is unset (``None``) or
        ``types.unknown``, a new list type with that slice type as element type
        is returned, keeping the list's ``dynamic_length`` and ``init_length``.
        Otherwise the slice type must be a subtype of the list's element type,
        and the input list's own type is returned unchanged.

        Raises:
            ValueError: if the number of values (``self.value.shape[0]``)
                differs from the number of indices (``self.indices.shape[0]``),
                or if the slice type is not a subtype of the list's element type.
        """
        num_indices = self.indices.shape[0]
        num_values = self.value.shape[0]
        if num_values != num_indices:
            raise ValueError("Cannot scatter {} values to {} indices".format(
                num_values, num_indices))
        list_elem_type = self.ls.elem_type
        value_type = self.value.sym_type
        dynamic_length = self.ls.dynamic_length
        init_length = self.ls.init_length

        # Element type of one scattered slice: value's type minus axis 0.
        elem_type = types.tensor(value_type.get_primitive(),
                                 value_type.get_shape()[1:])
        # An unset (None) elem type is treated like types.unknown, so it is
        # inferred from the value instead of being reported as a
        # "ls elem type None" mismatch by is_subtype below.
        if list_elem_type is None or list_elem_type == types.unknown:
            # fill in the elem type using value's type info.
            return types.list(elem_type,
                              dynamic_length=dynamic_length,
                              init_length=init_length)
        if not types.is_subtype(elem_type, list_elem_type):
            msg = "Elem type mismatch: ls elem type {} vs " + "value type {}"
            raise ValueError(msg.format(list_elem_type, elem_type))
        return self.ls.sym_type
Esempio n. 3
0
 def _check_is_compatible_type(type1, type2):
     """Check whether ``type1`` is compatible with ``type2``.

     ``type1`` is accepted outright when ``types.is_subtype(type1, type2)``
     holds. Otherwise the answer is the first element of the pair returned by
     ``types.is_tensor_and_is_compatible(type1, type2)``.
     """
     if types.is_subtype(type1, type2):
         return True
     compatible, _unused = types.is_tensor_and_is_compatible(type1, type2)
     return compatible