def accumulate(space, func, w_arr, axis, calc_dtype, w_out, identity):
    """Apply `func` cumulatively along `axis` of `w_arr`, storing into `w_out`.

    Each element read from `w_arr` is converted to `calc_dtype`.  Elements
    whose output index along `axis` is 0 are combined with `identity` when
    one is given; otherwise they are stored unchanged.  Every other element
    is combined, as ``func(calc_dtype, previous, current)``, with the value
    kept in a scratch array that has `w_arr`'s shape minus `axis`.

    Returns `w_out`.

    NOTE(review): the loop is driven only by `w_out`'s iterator, and
    `w_arr`'s iterator is advanced in lockstep without its own `done` check.
    This presumably requires `w_out` to have the same shape and iteration
    order as `w_arr`; confirm with callers.
    """
    dst_it, dst_st = w_out.create_iter()
    src_shape = w_arr.get_shape()
    # Scratch buffer holding the running value for every slot "across" axis.
    carry = W_NDimArray.from_shape(space,
                                   src_shape[:axis] + src_shape[axis + 1:],
                                   calc_dtype, w_instance=w_arr)
    # NOTE(review): AxisIter is given the full input shape plus `axis`,
    # presumably so it revisits the same `carry` slot for each step along
    # `axis`. Confirm against the AxisIter implementation.
    carry_it = AxisIter(carry.implementation, w_arr.get_shape(), axis)
    carry_st = carry_it.reset()
    src_it, src_st = w_arr.create_iter()
    # Index tracking on the input iterator is switched off here;
    # only the output iterator's indices are consulted below.
    src_it.track_index = False
    if identity is not None:
        identity = identity.convert_to(space, calc_dtype)
    shapelen = len(src_shape)
    while not dst_it.done(dst_st):
        accumulate_driver.jit_merge_point(shapelen=shapelen, func=func,
                                          calc_dtype=calc_dtype)
        w_item = src_it.getitem(src_st).convert_to(space, calc_dtype)
        src_st = src_it.next(src_st)
        position = dst_it.indices(dst_st)
        if position[axis] != 0:
            # Not the first step along `axis`: fold in the saved running value.
            w_item = func(calc_dtype, carry_it.getitem(carry_st), w_item)
        elif identity is not None:
            # First step along `axis`: seed with the identity if there is one.
            w_item = func(calc_dtype, identity, w_item)
        dst_it.setitem(dst_st, w_item)
        dst_st = dst_it.next(dst_st)
        carry_it.setitem(carry_st, w_item)
        carry_st = carry_it.next(carry_st)
    return w_out
def do_axis_reduce(space, shape, func, arr, dtype, axis, out, identity,
                   cumulative, temp):
    """Reduce `arr` along `axis` with `func`, writing results into `out`.

    Each element of `arr` is converted to `dtype`.  When the output
    iterator's index along `axis` is 0, the element is combined with
    `identity` (if given) or stored as-is.  Otherwise it is combined, as
    ``func(dtype, previous, current)``, with the previously accumulated
    value:

    * cumulative: the previous value is read from `temp`.
    * non-cumulative: the previous value is read back from `out` itself.

    Returns `out`.

    NOTE(review): this module contains a second ``def do_axis_reduce`` with
    the same signature further down.  That later definition rebinds the name
    when the module is imported, so this copy is shadowed and appears to be
    dead code; consider deleting it.  This copy also reads
    ``out_state.indices`` directly, which looks like an older iterator-state
    API than the one the later definition uses (``out_iter.indices(...)``).
    Confirm before reviving it.
    """
    out_iter = AxisIter(out.implementation, arr.get_shape(), axis, cumulative)
    out_state = out_iter.reset()
    if cumulative:
        temp_iter = AxisIter(temp.implementation, arr.get_shape(), axis, False)
        temp_state = temp_iter.reset()
    else:
        # Non-cumulative: the running value lives in `out`, so "previous
        # value" reads go through the output iterator itself.
        temp_iter = out_iter  # hack
        temp_state = out_state
    arr_iter, arr_state = arr.create_iter()
    if identity is not None:
        identity = identity.convert_to(space, dtype)
    shapelen = len(shape)
    while not out_iter.done(out_state):
        # Use the same JIT driver name as the other do_axis_reduce in this
        # module; the previous `axis_reduce__driver` spelling did not match it.
        axis_reduce_driver.jit_merge_point(shapelen=shapelen, func=func,
                                           dtype=dtype)
        assert not arr_iter.done(arr_state)
        w_val = arr_iter.getitem(arr_state).convert_to(space, dtype)
        if out_state.indices[axis] == 0:
            if identity is not None:
                w_val = func(dtype, identity, w_val)
        else:
            cur = temp_iter.getitem(temp_state)
            w_val = func(dtype, cur, w_val)
        out_iter.setitem(out_state, w_val)
        out_state = out_iter.next(out_state)
        if cumulative:
            temp_iter.setitem(temp_state, w_val)
            temp_state = temp_iter.next(temp_state)
        else:
            # Keep the aliased "previous value" cursor on the output cursor.
            temp_state = out_state
        arr_state = arr_iter.next(arr_state)
    return out
def do_axis_reduce(space, shape, func, arr, dtype, axis, out, identity,
                   cumulative, temp):
    """Reduce `arr` along `axis` with `func`, writing results into `out`.

    Every element of `arr` is converted to `dtype`.  When the output
    iterator's index along `axis` is 0, the element is combined with
    `identity` if one is given, otherwise stored unchanged.  Any other
    element is combined, as ``func(dtype, previous, current)``, with the
    previously accumulated value.

    * cumulative=True: the previous value comes from the scratch array
      `temp`, which is updated after every step.
    * cumulative=False: the previous value is read back from `out` itself.

    Returns `out`.

    NOTE(review): the loop is driven by `out`'s iterator only; `arr`'s
    iterator is advanced in lockstep without a `done` check.  This presumably
    relies on AxisIter visiting exactly as many positions as `arr` has
    elements; confirm against AxisIter.
    """
    dst_it = AxisIter(out.implementation, arr.get_shape(), axis, cumulative)
    dst_st = dst_it.reset()
    if not cumulative:
        # The running value is stored in `out`, so the "previous value"
        # cursor is simply an alias of the output cursor.
        prev_it = dst_it
        prev_st = dst_st
    else:
        prev_it = AxisIter(temp.implementation, arr.get_shape(), axis, False)
        prev_st = prev_it.reset()
    src_it, src_st = arr.create_iter()
    # Index tracking on the input iterator is switched off here;
    # only the output iterator's indices are consulted below.
    src_it.track_index = False
    if identity is not None:
        identity = identity.convert_to(space, dtype)
    shapelen = len(shape)
    while not dst_it.done(dst_st):
        axis_reduce_driver.jit_merge_point(shapelen=shapelen, func=func,
                                           dtype=dtype)
        w_val = src_it.getitem(src_st).convert_to(space, dtype)
        src_st = src_it.next(src_st)
        position = dst_it.indices(dst_st)
        if position[axis] != 0:
            # Not the first step along `axis`: fold in the running value.
            w_val = func(dtype, prev_it.getitem(prev_st), w_val)
        elif identity is not None:
            # First step along `axis`: seed with the identity if there is one.
            w_val = func(dtype, identity, w_val)
        dst_it.setitem(dst_st, w_val)
        dst_st = dst_it.next(dst_st)
        if not cumulative:
            # Keep the aliased cursor in sync with the advanced output cursor.
            prev_st = dst_st
        else:
            prev_it.setitem(prev_st, w_val)
            prev_st = prev_it.next(prev_st)
    return out