def init_op(self):
    """Prepare the C++ snippets used to emit the assignment ``in1 = in2``.

    Reads ``self.fmt_base`` (which the constructor sets before calling this)
    and fills in ``fmt``/``fmt_with_idx``, the sanitized input names and the
    optional ``.t()`` transpose suffix. When the element carries slices it
    also computes ``self.slice_tuples`` (one inclusive ``(first, last)`` pair
    per axis, with negative Python indices resolved against the axis length)
    and picks ``self.slice_fmt``.

    Raises:
        NotImplementedError: if the sliced region's shape matches neither the
            expected ``slice_shape`` nor its transpose.
    """
    import numbers

    # NOTE(review): the placeholder is spelled ``tranpose``; whatever formats
    # ``self.fmt`` must pass that same key, so it is intentionally unchanged.
    self.fmt = '{in1} = {in2}{tranpose};' + self.fmt_base
    self.fmt_with_idx = '{in1}({idx}) = {in2};'
    self.at_idx = self.el.at_idx
    self.transpose = ''
    self.slices = self.el.slices
    self.input1 = sanitize_name(self.el.inputs[0].name)
    self.input2 = sanitize_name(self.el.inputs[1].name)
    shape0 = self.el.inputs[0].shape
    shape1 = self.el.inputs[1].shape
    # Differing leading dimensions -> assign the transposed right-hand side.
    if len(shape0) and len(shape1) and shape0[0] != shape1[0]:
        self.transpose = '.t()'

    if self.el.slices:
        self.slice_tuples = []
        slices = self.el.slices
        input_shape = self.el.inputs[0].shape
        nslices = len(slices)
        slice_ndim = len(input_shape)
        # Armadillo has no 1-D arrays: treat a length-n vector as an (n, 1)
        # column vector.
        if slice_ndim == 1:
            input_shape = (input_shape[0], 1)
            slice_ndim += 1
        # Pad missing trailing axes with ':' (slice(None)). This appends to
        # self.el.slices (and therefore self.slices) in place.
        while nslices < slice_ndim:
            slices.append(slice(None))
            nslices += 1

        def _non_negative(value, axis):
            # Python counts negative indices from the end of the axis; the
            # generated Armadillo bounds must be non-negative.
            return value + input_shape[axis] if value < 0 else value

        # Convert every axis to an inclusive (first, last) index pair.
        for axis, item in enumerate(slices):
            if isinstance(item, numbers.Integral) and not isinstance(item, bool):
                pos = _non_negative(int(item), axis)
                self.slice_tuples.append((pos, pos))
            else:
                start = _non_negative(item.start, axis) if item.start else 0
                if item.stop:
                    end = _non_negative(item.stop, axis) - 1
                else:
                    # No (or zero) stop: run to the end of the axis.
                    end = input_shape[axis] - 1
                self.slice_tuples.append((start, end))

        out_shape = self.el.output.kwargs['slice_shape']
        # Promote a 1-D result shape to (n, 1) so it is comparable with the
        # 2-D (rows, cols) extent of the Armadillo view.
        slice_shape = out_shape if len(out_shape) != 1 else (out_shape[0], 1)
        rows = self.slice_tuples[0][1] - self.slice_tuples[0][0] + 1
        cols = self.slice_tuples[1][1] - self.slice_tuples[1][0] + 1
        arma_shape = (rows, cols)
        if slice_shape == arma_shape or (slice_shape == () and arma_shape == (1, 1)):
            self.slice_fmt = 'set_items({lhs}, {rhs}, {{{start}}}, {{{end}}});'
        elif slice_shape == (arma_shape[1], arma_shape[0]):
            # The expected shape is the transpose of the view.
            self.slice_fmt = 'set_items({lhs}, {rhs}, {{{start}}}, {{{end}}}, true);'
        else:
            raise NotImplementedError("You should not end up here.")
def __init__(self, el):
    """Bind the graph element ``el`` and derive its C++-safe name.

    When ``config.debug`` is on, ``self.fmt_caller`` is the module-level
    ``fmt_caller`` template filled with the element's caller info (empty
    strings if ``el.caller_info`` is falsy). Otherwise it is ''. Finishes
    by calling ``self.init_op()``.
    """
    self.el = el
    self.name = sanitize_name(el.name)
    if not config.debug:
        self.fmt_caller = ''
    else:
        info = el.caller_info
        if info:
            caller_fields = dict(caller_class=info[0],
                                 caller_fun=info[2],
                                 caller_line=info[3])
        else:
            caller_fields = dict(caller_class='', caller_fun='', caller_line='')
        self.fmt_caller = fmt_caller.format(name=self.name, **caller_fields)
    self.init_op()
def __init__(self, el):
    """Bind the graph element ``el`` and build the debug trailer ``fmt_base``.

    With ``config.debug`` enabled, ``self.fmt_base`` is ``fmt_caller``
    followed by a print template (``fmt_print_double`` for a ``()``-shaped
    output, ``fmt_print_mat`` otherwise), formatted with the element's name
    and caller info. Without debug it is ''. Finishes by calling
    ``self.init_op()``.
    """
    self.el = el
    self.name = sanitize_name(el.name)
    if config.debug:
        if el.output.shape == ():
            print_template = fmt_print_double
        else:
            print_template = fmt_print_mat
        template = ' ' + fmt_caller + '\n' + print_template
        info = el.caller_info
        if info:
            caller_fields = dict(caller_class=info[0],
                                 caller_fun=info[2],
                                 caller_line=info[3])
        else:
            caller_fields = dict(caller_class='', caller_fun='', caller_line='')
        self.fmt_base = template.format(name=self.name, **caller_fields)
    else:
        self.fmt_base = ''
    self.init_op()
def wrapper(*args):
    """Call ``func`` directly, or through a lazily compiled jet version.

    Outside ``jet.jet_mode`` this just forwards to ``func``. In jet mode,
    the first call does the following:

    * traces ``func`` on placeholders, using shapes from the cache entry or
      inferred from the arguments;
    * builds the result with ``JetBuilder``;
    * stores the compiled callable in ``_func_cached_dict[id(func)]['func']``.

    Later calls reuse the cached callable.

    Raises:
        TypeError: if ``func`` declares named positional parameters and the
            number of ``args`` differs from them.
        ValueError: if pre-declared shapes do not match the argument count.
    """
    if not jet.jet_mode:
        return func(*args)

    func_id = id(func)
    func_cached = _func_cached_dict[func_id]['func']
    if func_cached is not None:
        return func_cached(*args)

    shapes = _func_cached_dict[func_id]['shapes']
    code = func.__code__
    # A bound method's first parameter (self) is not part of ``args``.
    if inspect.ismethod(func):
        arg_names = code.co_varnames[1:code.co_argcount]
    else:
        arg_names = code.co_varnames[:code.co_argcount]

    if len(arg_names) != len(args):
        # Only a function without named positional parameters (e.g. one that
        # takes ``*args``) may receive a different number of arguments. An
        # explicit raise (instead of ``assert``) keeps the check under ``-O``.
        if arg_names:
            raise TypeError(
                '{}() takes {} positional argument(s) but {} were given'.format(
                    code.co_name, len(arg_names), len(args)))
        arg_names = [get_unique_name('ph') for _ in args]

    if len(shapes) != len(arg_names) and shapes:
        raise ValueError(
            'Shapes length does not match the arguments length.')
    if not shapes:
        # Infer shapes from this call's arguments; objects without a
        # ``shape`` attribute get ().
        shapes = [arg.shape if hasattr(arg, 'shape') else () for arg in args]
        _func_cached_dict[func_id]['shapes'] = shapes

    ph = [placeholder(name=arg_name, shape=shapes[pos])
          for pos, arg_name in enumerate(arg_names)]

    fun_name = code.co_name
    if fun_name == '<lambda>':
        fun_name = get_unique_name('lambda')
    jb = JetBuilder(args=ph,
                    out=func(*ph),
                    file_name=get_unique_name(sanitize_name(
                        '{}_{}_{func_name}'.format(
                            *get_caller_info('jit.py')[1:-1],
                            func_name=fun_name))),
                    fun_name=get_unique_name(fun_name))
    jet_class = getattr(jb.build(), jb.class_name)
    jet_func = getattr(jet_class(), jb.fun_name)
    _func_cached_dict[func_id]['func'] = jet_func
    return jet_func(*args)
def __repr__(self):
    """Render this operation's C++ statement by filling ``self.fmt``.

    The template receives the dtype (from ``get_dtype()``), the element
    name, the sanitized name of the first input and the operator symbol
    looked up in ``op_map`` by ``self.el.op``.
    """
    operand = sanitize_name(self.el.inputs[0].name)
    fields = {}
    fields['dtype'] = self.get_dtype()
    fields['name'] = self.name
    fields['in1'] = operand
    fields['operator'] = self.op_map[self.el.op]
    return self.fmt.format(**fields)