def after(self, out_xs):
    """Build the epilogue source that stores the max value and its flat index.

    For the 2D case the emitted text is::

        out = maxval;
        int argmax_k_0 = argmax_0 + p_0 - out_x_0 * s_0;
        int argmax_k_1 = argmax_1 + p_1 - out_x_1 * s_1;
        indexes = (argmax_k_1 + k_1 * argmax_k_0);

    Args:
        out_xs: Per-dimension variable names for the output coordinates,
            zipped with ``self.argmaxs``, ``self.ps`` and ``self.ss``.
            (Presumably ``ps`` are paddings and ``ss`` strides, judging by
            the ``p_i`` / ``s_i`` names in the example above; confirm
            against the class constructor.)

    Returns:
        str: The generated lines joined with ``'\\n'``.
    """
    def declare_argmax_k(name, argmax, pad, out_x, stride):
        # One declaration per dimension: the argmax position relative to
        # the window origin of this output element.
        return 'int {} = {} + {} - {} * {};'.format(
            name, argmax, pad, out_x, stride)

    argmax_k_names = conv_nd_kernel.vars('argmax_k', self.ndim)
    declarations = conv_nd_kernel.map_(
        declare_argmax_k, argmax_k_names, self.argmaxs, self.ps, out_xs,
        self.ss)
    # Fold the per-dimension offsets into a single flat index, using the
    # first dimension's offset as the seed and ks[1:] as multipliers.
    flat_index = conv_nd_kernel.muladdexp(
        self.ks[1:], argmax_k_names[1:], argmax_k_names[0])
    store_index = 'indexes = {};'.format(flat_index)
    return '\n'.join(['out = maxval;'] + declarations + [store_index])
def main(self, offset, xs, out_xs):
    """Build the inner-loop source that routes ``gy`` back to the argmax slot.

    For the 2D case the emitted text is::

        int kx = (x_1 - out_x_1 * s_1 + k_1 *
                  (x_0 - out_x_0 * s_0 + k_0 * 0));
        if (indexes[offset_1] == kx) {
          val = val + gy[offset_1];
        }

    Args:
        offset: Name of the index variable used to read ``indexes`` and
            ``gy``.
        xs: Per-dimension variable names for the input coordinates.
        out_xs: Per-dimension variable names for the output coordinates.

    Returns:
        Whatever ``Writer.get()`` returns. Presumably this is the
        accumulated source string; confirm against ``conv_nd_kernel``.
    """
    def window_offset(x, out_x, stride):
        # Position of input coordinate ``x`` inside the window that
        # produced output coordinate ``out_x``.
        return '{} - {} * {}'.format(x, out_x, stride)

    writer = conv_nd_kernel.Writer()
    per_dim_offsets = conv_nd_kernel.map_(
        window_offset, xs, out_xs, self.ss)
    flat_kx = conv_nd_kernel.muladdexp(self.ks, per_dim_offsets, '0')
    writer.write('int kx = {};'.format(flat_kx))
    # The 'inc'/'dec' arguments appear to control the Writer's indentation
    # around the if-body (assumption; verify against Writer.write).
    writer.write('if (indexes[{}] == kx) {{'.format(offset), 'inc')
    writer.write('val = val + gy[{}];'.format(offset))
    writer.write('}', 'dec')
    return writer.get()