def test_pending(self):
    """Evaluate a product, and its gradient, after one input name is deleted.

    Builds ``c = a * b`` through chains of ``.float()`` casts, deletes the
    Python name ``a`` before ``c`` is ever read, then checks both the value
    of ``c`` and the gradient of ``c`` w.r.t. ``b``.
    NOTE(review): presumably "pending" refers to ``c`` not yet having been
    executed (lazy evaluation) at the time of ``del a`` -- confirm.
    """
    a = jt.float([1,2,3])
    b = jt.float([1,2,3])
    # Repeated .float() casts presumably just lengthen the op chain feeding c.
    c = a.float().float().float() * b.float().float().float()
    # Drop the only Python reference to `a` before c's value is requested.
    del a
    # Bare attribute access kept for its side effect; presumably forces c to
    # be computed -- confirm.
    c.data
    assert (c.data==[1,4,9]).all(), c.data
    # c is a*b elementwise, so the gradient w.r.t. b is expected to equal the
    # values of the deleted `a`.
    d, = jt.grad(c, [b])
    # Same side-effect-only access as above, now for the gradient.
    d.data
    assert (d.data==[1,2,3]).all(), d.data
def test_int_grad(self):
    """Check gradients that flow through int/float casts of a float scalar.

    1. At x = 2.0, the gradient of x*x*x*x*x w.r.t. x must be 5 * 2**4.
    2. For z = x * x * int(x) * int(x) * float(x):
       - requesting the gradient w.r.t. the int-cast var must raise;
       - the gradient w.r.t. x must be 48, the value obtained when int(x)
         is treated as a constant: 2*x*4*2 + x*x*4*1 at x = 2.
    """
    val = jt.array(2.0)

    quintic = val * val * val * val * val
    grad_val, = jt.grad(quintic, [val])
    self.assertEqual(grad_val.data, 5 * 2**4)

    as_int = jt.int(val)
    as_float = jt.float(val)
    mixed = val * val * as_int * as_int * as_float

    def grad_wrt_int():
        # Differentiating w.r.t. an integer-typed var is expected to fail.
        return jt.grad(mixed, [as_int])

    expect_error(grad_wrt_int)

    grad_val, = jt.grad(mixed, [val])
    self.assertEqual(grad_val.data, 48)
def resize_and_crop(x, bbox, interpolation="nearest"):
    """Crop N boxes out of a 2-D image and resample each to the image's size.

    Args:
        x: 2-D Var of shape ``(H, W)``, the source image.
        bbox: Var of shape ``(N, 4)``. Row n is ``(x0, y0, x1, y1)``, where the
            "x" entries address the first (row, length H) axis and the "y"
            entries the second (column, length W) axis. Presumably normalized
            to [0, 1] -- confirm with callers.
        interpolation: ``"nearest"`` or ``"bilinear"``.

    Returns:
        Var of shape ``(N, H, W)``. Output pixel (n, i, j) samples the image at
        row    ``bbox[n,0]*(H-1) + i*(bbox[n,2]-bbox[n,0])`` and
        column ``bbox[n,1]*(W-1) + j*(bbox[n,3]-bbox[n,1])``,
        so i = 0 / i = H-1 hit ``bbox[n,0]*(H-1)`` / ``bbox[n,2]*(H-1)``.

    Raises:
        ValueError: if ``interpolation`` is not supported. (The previous
            ``raise("...")`` raised a str, which Python 3 turns into a
            TypeError and loses the message.)
    """
    # Fail fast, before any graph is built.
    if interpolation not in ("nearest", "bilinear"):
        raise ValueError(f"Not support {interpolation}")

    N, k = bbox.shape
    H, W = x.shape
    assert k == 4, f"bbox must have 4 columns, got {k}"
    shape = [N, H, W]
    img = x

    # Per-output-pixel box corners: bb[i][n, h, w] == bbox[n, i].
    bb = [bbox.reindex(shape, ["i0", str(i)]) for i in range(4)]
    hid = jt.index(shape, 1)
    wid = jt.index(shape, 2)
    # Fractional source coordinates for every output pixel (row, column).
    px = bb[0] * jt.float(H - 1) + hid * (bb[2] - bb[0])
    py = bb[1] * jt.float(W - 1) + wid * (bb[3] - bb[1])

    if interpolation == "nearest":
        return img.reindex_var([px.round(), py.round()])

    # Bilinear: blend the four neighbours of (px, py).
    #        fx  px   cx
    #      +------------>
    #   fy | a   dx  | b
    #      |     dy
    #   py | -   o   -
    #      |
    #   cy | c       | d
    #      v
    # NOTE(review): at the exact far edge, cx == H (or cy == W) lies out of
    # range; presumably reindex_var yields its fill value (0) there, and that
    # tap gets weight dx == 0 (or dy == 0) -- confirm.
    one = jt.float(1).broadcast(shape)
    fx, fy = px.floor(), py.floor()
    cx, cy = fx + one, fy + one
    dx, dy = px - fx, py - fy
    a = img.reindex_var([fx, fy])
    b = img.reindex_var([cx, fy])
    c = img.reindex_var([fx, cy])
    d = img.reindex_var([cx, cy])
    dnx, dny = one - dx, one - dy
    ab = dx * b + dnx * a  # interpolate along rows at column fy
    cd = dx * d + dnx * c  # interpolate along rows at column cy
    return ab * dny + cd * dy