def test_error(self):
    """Check that ``F.prelu`` validates its inputs in static-graph mode.

    A non-Variable input and an int32 input must both raise ``TypeError``;
    a float16 input paired with a float32 weight must be accepted.
    """
    with paddle.static.program_guard(paddle.static.Program()):
        alpha = paddle.data(
            name='weight_fp32', shape=[1], dtype='float32')

        # A plain Python int is not a Variable, so it must be rejected.
        with self.assertRaises(TypeError):
            F.prelu(x=1, weight=alpha)

        # Only float16 / float32 / float64 inputs are allowed; int32 fails.
        int_input = paddle.data(
            name='x_int32', shape=[2, 3], dtype='int32')
        with self.assertRaises(TypeError):
            F.prelu(x=int_input, weight=alpha)

        # float16 input is supported, so this call must not raise.
        half_input = paddle.data(
            name='x_fp16', shape=[2, 3], dtype='float16')
        F.prelu(x=half_input, weight=alpha)
def dygraph_check(self, weight_np):
    """Check ``F.prelu`` in dynamic-graph mode against ``ref_prelu``.

    Switches Paddle to dynamic-graph mode on ``self.place``, runs the op on
    tensors built from ``self.x_np`` and ``weight_np``, and compares the
    result with ``ref_prelu``. Static mode is re-enabled afterwards, even
    on failure.

    Args:
        weight_np: numpy array used as the PReLU weight; converted to a
            tensor for ``F.prelu`` and passed unchanged to ``ref_prelu``.
    """
    paddle.disable_static(self.place)
    # Restore static mode even if the op or the assertion raises; otherwise
    # the process stays in dynamic-graph mode for every test that runs later.
    try:
        x = paddle.to_tensor(self.x_np)
        weight = paddle.to_tensor(weight_np)
        out = F.prelu(x, weight)
        out_ref = ref_prelu(self.x_np, weight_np)
        self.assertEqual(np.allclose(out_ref, out.numpy()), True)
    finally:
        paddle.enable_static()
def static_check(self, weight_np):
    """Check ``F.prelu`` in static-graph mode against ``ref_prelu``.

    Builds a fresh program with float32 placeholders ``X`` (shaped like
    ``self.x_np``) and ``Alpha`` (shaped like ``weight_np``), executes it on
    ``self.place``, and compares the fetched output with the reference.
    """
    main_prog = paddle.static.Program()
    with paddle.static.program_guard(main_prog):
        x_var = paddle.data('X', self.x_np.shape, 'float32')
        alpha_var = paddle.data('Alpha', weight_np.shape, 'float32')
        prelu_out = F.prelu(x_var, alpha_var)
        executor = paddle.static.Executor(self.place)
        # No program is passed to run(): program_guard makes main_prog the
        # default main program, which is what the executor runs.
        feed_dict = {'X': self.x_np, 'Alpha': weight_np}
        fetched = executor.run(feed=feed_dict, fetch_list=[prelu_out])
    expected = ref_prelu(self.x_np, weight_np)
    self.assertEqual(np.allclose(expected, fetched[0]), True)
def forward(self, x):
    """Apply ``F.prelu`` to ``x``, using ``self.weight`` as the weight."""
    return F.prelu(x=x, weight=self.weight)