def test_forward(self):
    """Check the forward result of ``ad.while_loop`` on three small loops.

    Scenarios: a plain counter, a running factorial and Fibonacci numbers
    obtained from repeated matrix products. Every scenario expects one
    output entry per executed iteration and a reported shape whose leading
    dimension is ``None``.
    """

    def verify(node, want_shape, want, select=None):
        # Run the node first, then compare its declared shape and its values.
        got = node.forward()
        self.assertEqual(want_shape, node.shape)
        compared = got if select is None else select(got)
        # The failure message carries the full result, not only the compared slice.
        self.assertTrue(np.allclose(want, compared), (want, got))

    # Scenario 1: increment a counter while it is below 10.
    def counter_below_ten(state):
        return ad.less(state[0], ad.constant(10))

    def increment(state):
        return [state[0] + 1]

    counting = ad.while_loop(
        cond=counter_below_ten,
        body=increment,
        loop_vars=[ad.variable(0.0)],
    )
    verify(counting, (None, ), np.arange(1, 11))

    # Scenario 2: second loop variable accumulates the factorial of the counter.
    def counter_below_five(state):
        return ad.less(state[0], ad.constant(5))

    def factorial_step(state):
        return [state[0] + 1, (state[0] + 1) * state[1]]

    factorial = ad.while_loop(
        cond=counter_below_five,
        body=factorial_step,
        loop_vars=[ad.variable(0.0), ad.variable(1.0)],
        output_index=1,
    )
    verify(factorial, (None, ), np.array([1, 2, 6, 24, 120]))

    # Scenario 3: multiply by [[1, 1], [1, 0]] while a doubling counter is
    # below 64; the [0, 0] entries of the products are Fibonacci numbers.
    def counter_below_sixty_four(state):
        return ad.less(state[0], ad.constant(64))

    def fibonacci_step(state):
        # The step matrix is deliberately created inside the body, as before.
        return [state[0] * 2, ad.dot(state[1], ad.variable([[1, 1], [1, 0]]))]

    fibonacci = ad.while_loop(
        cond=counter_below_sixty_four,
        body=fibonacci_step,
        loop_vars=[ad.variable(1), ad.variable([[1, 0], [0, 1]])],
        output_index=1,
    )
    verify(
        fibonacci,
        (None, 2, 2),
        np.array([1, 2, 3, 5, 8, 13]),
        select=lambda result: result[:, 0, 0],
    )
def _gen_random_and_result(x_shape, y_shape, call_type=True):
    """Build a random ``x < y`` comparison node and its NumPy reference.

    Args:
        x_shape: Shape handed to ``np.random.random`` for the left operand.
        y_shape: Shape handed to ``np.random.random`` for the right operand.
        call_type: When truthy the node is built with the ``<`` operator on
            the two variables; otherwise ``ad.less`` is called directly.

    Returns:
        A tuple ``(node, [x, y], expected)`` where ``expected`` is the
        element-wise NumPy comparison of the drawn values cast to float64.
    """
    # Draw order (left operand first) is kept so seeded runs stay reproducible.
    lhs_data = np.random.random(x_shape)
    rhs_data = np.random.random(y_shape)

    lhs = ad.variable(lhs_data, name='X%s' % str(x_shape))
    rhs = ad.variable(rhs_data, name='Y%s' % str(y_shape))

    node = (lhs < rhs) if call_type else ad.less(lhs, rhs)
    reference = np.less(lhs_data, rhs_data).astype(dtype=np.float64)
    return node, [lhs, rhs], reference
def test_backward(self):
    """Run the numeric gradient check on a ``while_loop`` output.

    The loop multiplies its second loop variable by a matrix captured from
    the enclosing scope, and the gradient is checked with respect to that
    captured matrix.
    """
    step_matrix = ad.variable([[1, 1], [1, 0]])

    def counter_below_sixty_four(state):
        return ad.less(state[0], ad.constant(64))

    def multiply_step(state):
        # Double the counter and right-multiply the accumulator.
        return [state[0] * 2, ad.dot(state[1], step_matrix)]

    looped = ad.while_loop(
        cond=counter_below_sixty_four,
        body=multiply_step,
        loop_vars=[ad.variable(1), ad.variable([[1, 0], [0, 1]])],
        output_index=1,
    )
    self.numeric_gradient_check(looped, {}, [step_matrix])
def call(self, inputs, **kwargs):
    """Apply ``self.step`` repeatedly over ``inputs`` inside an ``ad.while_loop``.

    The loop keeps going while its first loop variable is less than
    ``ad.shape(inputs)[1]``. The two remaining loop variables both start
    from the same all-zero product built below.

    Args:
        inputs: Input node; it is indexed as ``[:, 0, :]`` here.
            NOTE(review): presumably laid out as (batch, time, features) —
            confirm against callers.
        **kwargs: Accepted but not used by this method.

    Returns:
        ``outputs`` transposed with ``axes=[1, 0, 2]`` when
        ``self.return_sequences`` is truthy, otherwise ``outputs[-1]``.
    """
    # Zero-valued operands whose product supplies the starting loop values.
    zero_slice = ad.zeros_like(inputs)[:, 0, :]
    zero_kernel = ad.zeros_like(self.wx[:, :self.units])
    initial_val = ad.dot(zero_slice, zero_kernel)

    def has_more_steps(state):
        return ad.less(state[0], ad.shape(inputs)[1])

    def advance(state):
        return self.step(inputs, state)

    outputs = ad.while_loop(
        has_more_steps,
        advance,
        [ad.variable(0.0), initial_val, initial_val],
        output_index=-1,
    )

    if not self.return_sequences:
        return outputs[-1]
    return outputs.transpose(axes=[1, 0, 2])