def test_backward(self):
    """Run numeric gradient checks on ``ad.map_fn`` for three call forms."""
    # Single input, single output: the mapped function squares its argument.
    values = ad.variable([1, 2, 3, 4, 5, 6])
    squared = ad.map_fn(lambda v: v * v, values)
    self.numeric_gradient_check(squared, {}, [values])

    # Tuple of two inputs, single output: the mapped function multiplies
    # the two entries it receives.
    lhs = ad.variable([1, 2, 3])
    rhs = ad.variable([-1, 1, -1])
    product = ad.map_fn(lambda pair: pair[0] * pair[1], (lhs, rhs))
    self.numeric_gradient_check(product, {}, [lhs, rhs])

    # Single input, tuple output: both outputs are multiplied together so
    # the gradient check covers each of them.
    source = ad.variable([1, 2, 3])
    mirrored = ad.map_fn(lambda v: (v, -v), source)
    combined = mirrored[0] * mirrored[1]
    self.numeric_gradient_check(combined, {}, [source])
def call(self, inputs, **kwargs):
    """Apply the layer to ``inputs`` and return the resulting operation.

    The inputs are padded, ``call_batch`` is mapped over every index along
    the first axis, and the mapped result is multiplied with ``self.w``.
    The bias ``self.b`` is added when ``self.use_bias`` is truthy, and
    ``self.activation`` is applied when it is not ``None``.

    ``kwargs`` is accepted but not used here.
    """
    # Four-entry pad spec: no padding on the first and last axes,
    # ``pad_width[0]`` / ``pad_width[1]`` on the two middle axes.
    # NOTE(review): presumably inputs are (batch, height, width, channels)
    # -- confirm against callers.
    pad_spec = ((0, ), (self.pad_width[0], ), (self.pad_width[1], ), (0, ))
    padded_inputs = ad.pad(inputs, pad_spec)
    num_samples = ad.shape(inputs)[0]

    def per_sample(index):
        return self.call_batch(padded_inputs, index)

    mapped = ad.map_fn(per_sample, ad.arange(num_samples))
    output = ad.dot(mapped, self.w)
    if self.use_bias:
        output += self.b
    if self.activation is None:
        return output
    return self.activation(output)
def test_forward(self):
    """Check forward values and shapes of ``ad.map_fn`` for three call forms."""
    # Single input, single output.
    values = ad.variable([1, 2, 3, 4, 5, 6])
    squared = ad.map_fn(lambda v: v * v, values)
    got = squared.forward()
    want = np.array([1, 4, 9, 16, 25, 36])
    self.assertEqual(want.shape, squared.shape)
    self.assertTrue(np.allclose(want, got), (want, got))

    # Tuple of two inputs, single output.
    lhs = ad.variable([1, 2, 3])
    rhs = ad.variable([-1, 1, -1])
    product = ad.map_fn(lambda pair: pair[0] * pair[1], (lhs, rhs))
    got = product.forward()
    want = np.array([-1, 2, -3])
    self.assertEqual(want.shape, product.shape)
    self.assertTrue(np.allclose(want, got), (want, got))

    # Single input, tuple output: both outputs are evaluated up front and
    # then compared one by one.
    source = ad.variable([1, 2, 3])
    mirrored = ad.map_fn(lambda v: (v, -v), source)
    got_pair = (mirrored[0].forward(), mirrored[1].forward())
    want_pair = (np.array([1, 2, 3]), np.array([-1, -2, -3]))
    for idx, (want_item, got_item) in enumerate(zip(want_pair, got_pair)):
        self.assertEqual(want_item.shape, mirrored[idx].shape)
        self.assertTrue(
            np.allclose(want_item, got_item),
            (idx, want_item, got_item),
        )
def call_batch(self, padded: ad.Operation, i: int):
    """Map ``call_row`` over every output row index for entry ``i``.

    The number of rows is derived from the size of axis 1 of ``padded``,
    ``self.dilated_kernel_size[0]`` and ``self.strides[0]``.

    Args:
        padded: The padded input operation.
        i: Index forwarded unchanged to ``call_row``.

    Returns:
        The result of ``ad.map_fn`` over ``ad.arange`` of the row count.
    """
    padded_height = ad.shape(padded)[1]
    kernel_extent = self.dilated_kernel_size[0]
    row_stride = self.strides[0]
    # Count of kernel positions that fit along axis 1 at the given stride.
    row_count = (padded_height - kernel_extent) // row_stride + 1

    def per_row(row):
        return self.call_row(padded, i, row)

    return ad.map_fn(per_row, ad.arange(row_count))
def call_row(self, padded: ad.Operation, i: int, r: int):
    """Map ``call_column`` over every output column index for ``(i, r)``.

    The number of columns is derived from the size of axis 2 of ``padded``,
    ``self.dilated_kernel_size[1]`` and ``self.strides[1]``.

    Args:
        padded: The padded input operation.
        i: Index forwarded unchanged to ``call_column``.
        r: Row index forwarded unchanged to ``call_column``.

    Returns:
        The result of ``ad.map_fn`` over ``ad.arange`` of the column count.
    """
    padded_width = ad.shape(padded)[2]
    kernel_extent = self.dilated_kernel_size[1]
    column_stride = self.strides[1]
    # Count of kernel positions that fit along axis 2 at the given stride.
    column_count = (padded_width - kernel_extent) // column_stride + 1

    def per_column(column):
        return self.call_column(padded, i, r, column)

    return ad.map_fn(per_column, ad.arange(column_count))