def _loop_main(main):
            """Wrap a main-part code snippet in nested per-dimension loops.

            For each of the ``self.ndim`` output dimensions this emits a
            ``for`` loop that runs that dimension's ``out_x`` index from the
            matching ``out_x0s`` bound (inclusive) to the matching
            ``out_x1s`` bound (exclusive). Inside each loop it emits the
            running ``offset`` for that level. The text returned by
            ``main(offset, xs, out_xs)`` is placed inside the deepest loop.
            ``main`` is called with the innermost offset name. One closing
            brace per dimension is written at the end.

            ``self``, ``xs``, ``out_x0s``, ``out_x1s`` and ``moves`` are free
            variables resolved from the enclosing scope.

            Roughly, for 2D (assuming ``conv_nd_kernel.vars`` yields names
            such as ``out_x_0``; TODO confirm):

                for (int out_x_0 = ...; out_x_0 < ...; ++out_x_0) {
                  int offset_0 = <outs[1]> * (out_x_0 + out_0 * c0);
                  for (int out_x_1 = ...; out_x_1 < ...; ++out_x_1) {
                    int offset_1 = 1 * (out_x_1 + offset_0);
                    ... main-part ...
                  }
                }

            Args:
                main: Callable ``(offset, xs, out_xs) -> str`` that produces
                    the innermost loop body. Its output is split on newlines
                    and written line by line.

            Returns:
                A one-element list containing ``Writer.get()`` of the
                generated code.
            """
            writer = conv_nd_kernel.Writer()

            index_names = conv_nd_kernel.vars('out_x', self.ndim)
            offset_names = conv_nd_kernel.vars('offset', self.ndim)
            # Multiplier applied at each level: the next entry of
            # ``self.outs``, or 1 for the innermost level.
            scales = self.outs[1:] + [1]
            # Term that each level's index is added to before scaling. The
            # outermost level uses the channel base; every other level uses
            # the offset of its enclosing level.
            bases = ['out_0 * c0'] + offset_names[:-1]

            levels = moves.zip(index_names, out_x0s, out_x1s,
                               offset_names, bases, scales)
            for idx, lo, hi, off, base, scale in levels:
                writer.write(
                    'for (int {} = {}; {} < {}; ++{}) {{'.format(
                        idx, lo, idx, hi, idx), 'inc')
                writer.write('int {} = {} * ({} + {});'.format(
                    off, scale, idx, base))

            # Innermost body, supplied by the caller.
            snippet = main(offset_names[-1], xs, index_names)
            for line in snippet.split('\n'):
                writer.write(line)

            # One closing brace per opened loop.
            for _ in index_names:
                writer.write('}', 'dec')

            return [writer.get()]
 def main(self, offset, xs):
     """Emit the innermost loop body that tracks the running maximum.

     The generated code reads ``in[offset]`` into ``v``. When ``v`` is
     strictly greater than ``maxval``, the code stores ``v`` in ``maxval``.
     It then records the current position by assigning each name in ``xs``
     to the matching name in ``self.argmaxs``.

     Args:
         offset: Expression used to index ``in``.
         xs: Coordinate names, paired positionally with ``self.argmaxs``.

     Returns:
         The generated code, as returned by ``Writer.get()``.
     """
     # 2D: T v = in[offset_1];
     #     if (maxval < v) {
     #       maxval   = v;
     #       argmax_0 = x_0;
     #       argmax_1 = x_1;
     #     }
     record_coords = ['{} = {};'.format(argmax, x)
                      for argmax, x in six.moves.zip(self.argmaxs, xs)]
     writer = conv_nd_kernel.Writer()
     writer.write('T v = in[{}];'.format(offset))
     writer.write('if (maxval < v) {', 'inc')
     for stmt in ['maxval = v;'] + record_coords:
         writer.write(stmt)
     writer.write('}', 'dec')
     return writer.get()
Example #3
0
 def main(self, offset, xs, out_xs):
     """Emit the innermost body that conditionally accumulates ``gy``.

     The generated code computes ``kx``, a single expression combining
     each axis's ``x - out_x * s`` term with the window sizes
     ``self.ks``. It compares ``kx`` with ``indexes[offset]`` and, on a
     match, adds ``gy[offset]`` to ``val``.

     Args:
         offset: Expression used to index ``indexes`` and ``gy``.
         xs: Per-axis coordinate names.
         out_xs: Per-axis output-index names, paired with ``xs`` and
             ``self.ss``.

     Returns:
         The generated code, as returned by ``Writer.get()``.
     """
     # 2D: int kx = (x_1 - out_x_1 * s_1 + k_1 *
     #              (x_0 - out_x_0 * s_0 + k_0 * 0));
     #     if (indexes[offset_1] == kx) {
     #       val = val + gy[offset_1];
     #     }
     writer = conv_nd_kernel.Writer()
     # Per-axis term ``x - out_x * s`` for every (x, out_x, s) triple.
     window_terms = conv_nd_kernel.map_(
         lambda x, out_x, s: '{} - {} * {}'.format(x, out_x, s),
         xs, out_xs, self.ss)
     # Nest the per-axis terms with the window sizes, starting from '0',
     # to get the single ``kx`` expression shown in the 2D example.
     kx_expr = conv_nd_kernel.muladdexp(self.ks, window_terms, '0')
     writer.write('int kx = {};'.format(kx_expr))
     writer.write('if (indexes[{}] == kx) {{'.format(offset), 'inc')
     writer.write('val = val + gy[{}];'.format(offset))
     writer.write('}', 'dec')
     return writer.get()