예제 #1
파일: pio.py 프로젝트: feitianyiren/hpat
    def _infer_h5_typ(self, rhs):
        """Try to infer the array type produced by reading an HDF5 dataset.

        Handles expressions of the form ``f['A']['B'][:]`` or ``f['A'][b, :]``.
        Starting at ``rhs``, the chain of ``getitem``/``static_getitem``
        definitions is followed backwards through ``self.func_ir`` until a
        call expression is reached; that call is then handed to
        ``_get_h5_type_file`` together with ``obj_name_list``.

        Each ``require(...)`` is a precondition of the pattern match; what
        happens when one fails is decided by ``require``, which is defined
        outside this block.

        Args:
            rhs: IR expression to inspect; its ``op`` must be ``'getitem'`` or
                ``'static_getitem'``.

        Returns:
            Whatever ``_get_h5_type_file`` returns for the call expression
            found at the root of the getitem chain.
        """
        # infer the type if it is of the form f['A']['B'][:] or f['A'][b,:]
        # with constant filename
        # TODO: static_getitem has index_var for sure?
        # make sure it's slice, TODO: support non-slice like integer
        require(rhs.op in ('getitem', 'static_getitem'))
        # XXX can't know the type of index here especially if it is bool arr
        # make sure it is not string (we're not in the middle of a select chain)
        index_var = rhs.index if rhs.op == 'getitem' else rhs.index_var
        # guard() wraps the constant lookup; presumably it yields None when the
        # index is not a compile-time constant -- TODO confirm
        index_val = guard(find_const, self.func_ir, index_var)
        require(not isinstance(index_val, str))
        # index_def = get_definition(self.func_ir, index_var)
        # require(isinstance(index_def, ir.Expr) and index_def.op == 'call')
        # require(find_callname(self.func_ir, index_def) == ('slice', 'builtins'))
        # collect object names until the call
        val_def = rhs
        obj_name_list = []
        while True:
            # step one level inward: look up the definition of the object
            # being indexed by the current expression
            val_def = get_definition(self.func_ir, val_def.value)
            require(isinstance(val_def, ir.Expr))
            if val_def.op == 'call':
                # reached the root of the chain; the callee is validated by
                # _get_h5_type_file
                return self._get_h5_type_file(val_def, obj_name_list)

            # object_name should be constant str
            require(val_def.op in ('getitem', 'static_getitem'))
            val_index_var = val_def.index if val_def.op == 'getitem' else val_def.index_var
            obj_name = find_str_const(self.func_ir, val_index_var)
            # NOTE(review): obj_name is not used after this assignment and
            # nothing visible appends to obj_name_list -- the loop body looks
            # cut off here; confirm against the full file.
예제 #2
파일: pio.py 프로젝트: feitianyiren/hpat
    def _get_h5_type_file(self, val_def, obj_name_list):
        """Return the Numba array type of a dataset inside a constant-named HDF5 file.

        The file is opened read-only at compile time purely to inspect the
        dataset's metadata (rank and dtype); no dataset contents are read.

        Args:
            val_def: IR call expression; it must resolve to ``h5py.File`` and
                have at least one positional argument, the first of which
                must be a constant string (the file name).
            obj_name_list: non-empty list of keys applied one after another,
                in list order, starting from the file object.

        Returns:
            ``types.Array(dtype, ndims, 'C')`` where ``dtype`` is the Numba
            equivalent of the dataset's dtype and ``ndims`` is its rank.

        Each ``require(...)`` is a precondition; how a failing one is reported
        is decided by ``require`` itself (defined outside this block,
        presumably Numba's guard/require helper -- TODO confirm).
        """
        require(len(obj_name_list) > 0)
        require(find_callname(self.func_ir, val_def) == ('File', 'h5py'))
        require(len(val_def.args) > 0)
        f_name = find_str_const(self.func_ir, val_def.args[0])

        # imported lazily so h5py is only needed when this pattern is matched
        import h5py
        # The context manager closes the file handle on every exit path,
        # including a failing require() or a missing key; the previous code
        # never closed the file.
        with h5py.File(f_name, 'r') as f:
            obj = f
            # NOTE(review): keys are applied in list order, so obj_name_list
            # must run from the file root down to the dataset -- confirm the
            # caller builds it in that order.
            for obj_name in obj_name_list:
                obj = obj[obj_name]
            require(isinstance(obj, h5py.Dataset))
            # shape and dtype have to be read while the file is still open
            ndims = len(obj.shape)
            numba_dtype = numba.numpy_support.from_dtype(obj.dtype)
        return types.Array(numba_dtype, ndims, 'C')