예제 #1
0
 def test_error(self):
     """Check which input types ``F.prelu`` rejects and which it accepts.

     In a fresh static program, ``F.prelu`` must raise ``TypeError`` for a
     plain Python int input and for an int32 placeholder. It must accept a
     float16 placeholder without raising.
     """
     with paddle.static.program_guard(paddle.static.Program()):
         alpha = paddle.data(
             name='weight_fp32', shape=[1], dtype='float32')

         # A bare Python int is not a Variable, so it must be rejected.
         with self.assertRaises(TypeError):
             F.prelu(x=1, weight=alpha)

         # Only float16 / float32 / float64 inputs are allowed; int32 must fail.
         int_input = paddle.data(name='x_int32', shape=[2, 3], dtype='int32')
         with self.assertRaises(TypeError):
             F.prelu(x=int_input, weight=alpha)

         # float16 input is supported: this call must complete without raising.
         half_input = paddle.data(name='x_fp16', shape=[2, 3], dtype='float16')
         F.prelu(x=half_input, weight=alpha)
예제 #2
0
 def dygraph_check(self, weight_np):
     """Compare dynamic-graph ``F.prelu`` against the NumPy reference.

     Switches to dynamic-graph mode on ``self.place`` and runs ``F.prelu``
     eagerly, using ``self.x_np`` as input and ``weight_np`` as the weight.
     It then asserts that the result is ``np.allclose`` (default tolerances)
     to ``ref_prelu(self.x_np, weight_np)``.

     Static mode is re-enabled in a ``finally`` clause. A failing assertion
     therefore cannot leave the process stuck in dynamic-graph mode.

     Args:
         weight_np: NumPy array used as the PReLU weight.
     """
     paddle.disable_static(self.place)
     try:
         x = paddle.to_tensor(self.x_np)
         weight = paddle.to_tensor(weight_np)
         out = F.prelu(x, weight)
         out_ref = ref_prelu(self.x_np, weight_np)
         self.assertTrue(np.allclose(out_ref, out.numpy()))
     finally:
         # Restore static mode even if anything above raised.
         paddle.enable_static()
예제 #3
0
 def static_check(self, weight_np):
     """Compare static-graph ``F.prelu`` against the NumPy reference.

     Builds a fresh program with float32 placeholders ``X`` and ``Alpha``.
     These are shaped like ``self.x_np`` and ``weight_np`` and feed
     ``F.prelu``. The program runs on ``self.place`` via an Executor, and
     the fetched output must be ``np.allclose`` to ``ref_prelu``.

     Args:
         weight_np: NumPy array fed into the ``Alpha`` placeholder.
     """
     main_program = paddle.static.Program()
     with paddle.static.program_guard(main_program):
         inp = paddle.data('X', self.x_np.shape, 'float32')
         alpha = paddle.data('Alpha', weight_np.shape, 'float32')
         prelu_out = F.prelu(inp, alpha)
         executor = paddle.static.Executor(self.place)
         # run() is called inside the guard, so it executes the guarded program.
         feed_dict = {'X': self.x_np, 'Alpha': weight_np}
         fetched = executor.run(feed=feed_dict, fetch_list=[prelu_out])
     expected = ref_prelu(self.x_np, weight_np)
     self.assertEqual(np.allclose(expected, fetched[0]), True)
예제 #4
0
파일: iresnet.py 프로젝트: bilylee/DCQ
 def forward(self, x):
     """Apply parametric ReLU to ``x`` using ``self.weight`` as the weight.

     Args:
         x: Input tensor.

     Returns:
         The output of ``F.prelu(x, self.weight)``.
     """
     slope = self.weight
     return F.prelu(x, slope)