def test_amp_lists_5(self):
    """Custom fp32 list {'elementwise_add'} with no custom bf16 list.

    Adjusts the expected op sets so 'elementwise_add' leaves the expected
    bf16 list and joins the expected fp32 list. The comparison against
    ``self.amp_lists_`` is not done in this method (presumably in tearDown
    -- confirm).
    """
    # 5. w=None, b={'elementwise_add'}
    op_name = 'elementwise_add'
    self.fp32_list.add(op_name)
    self.bf16_list.remove(op_name)
    self.amp_lists_ = amp.AutoMixedPrecisionListsBF16(
        custom_fp32_list={op_name})
def test_is_in_fp32_varnames(self):
    """Check ``_is_in_fp32_varnames`` on a two-op ``abs`` chain.

    Builds X -> abs -> Y -> abs -> Z in the default main program's global
    block. It then asserts that the helper reports a match in three cases:
    the custom fp32 var name is the op's input (X for the first op), the
    op's input (Y for the second op), or the op's output (Y for the
    first op).
    """
    global_block = fluid.default_main_program().global_block()

    # Three float32 vars of shape [3], created in the order X, Y, Z.
    x_var, y_var, z_var = [
        global_block.create_var(name=var_name, shape=[3], dtype='float32')
        for var_name in ("X", "Y", "Z")
    ]
    first_abs = global_block.append_op(
        type="abs", inputs={"X": [x_var]}, outputs={"Out": [y_var]})
    second_abs = global_block.append_op(
        type="abs", inputs={"X": [y_var]}, outputs={"Out": [z_var]})

    lists_with_x = amp.AutoMixedPrecisionListsBF16(
        custom_fp32_varnames={'X'})
    assert amp.bf16.amp_utils._is_in_fp32_varnames(first_abs, lists_with_x)

    lists_with_y = amp.AutoMixedPrecisionListsBF16(
        custom_fp32_varnames={'Y'})
    # 'Y' is the second op's input and the first op's output.
    assert amp.bf16.amp_utils._is_in_fp32_varnames(second_abs, lists_with_y)
    assert amp.bf16.amp_utils._is_in_fp32_varnames(first_abs, lists_with_y)
def test_amp_list_8(self):
    """Custom bf16 list {'reshape2'} given by keyword.

    'reshape2' is added to the expected bf16 list and removed from the
    expected gray list before the lists object is built.
    """
    # 8. custom_bf16_list={'reshape2'}
    target = 'reshape2'
    self.bf16_list.add(target)
    self.gray_list.remove(target)
    self.amp_lists_ = amp.AutoMixedPrecisionListsBF16(
        custom_bf16_list={target})
def test_amp_lists_6(self):
    """Custom fp32 list {'lstm'} only.

    Only the expected fp32 list is changed here: it gains 'lstm'.
    """
    # 6. w=None, b={'lstm'}
    custom_fp32 = {'lstm'}
    self.fp32_list.add('lstm')
    self.amp_lists_ = amp.AutoMixedPrecisionListsBF16(
        custom_fp32_list=custom_fp32)
def test_amp_lists_3(self):
    """Positional custom list {'lstm'} (presumably the custom bf16 list).

    Only the expected bf16 list is changed here: it gains 'lstm'.
    """
    # 3. w={'lstm'}, b=None
    custom_bf16 = {'lstm'}
    self.bf16_list.add('lstm')
    self.amp_lists_ = amp.AutoMixedPrecisionListsBF16(custom_bf16)
def test_amp_lists_2(self):
    """Positional custom list {'tanh'} (presumably the custom bf16 list).

    'tanh' is removed from the expected fp32 list and added to the expected
    bf16 list.
    """
    # 2. w={'tanh'}, b=None
    op = 'tanh'
    self.fp32_list.remove(op)
    self.bf16_list.add(op)
    self.amp_lists_ = amp.AutoMixedPrecisionListsBF16({op})
def test_amp_lists_1(self):
    """Positional custom list {'exp'} (presumably the custom bf16 list).

    'exp' is added to the expected bf16 list and removed from the expected
    fp32 list.
    """
    # 1. w={'exp'}, b=None
    op = 'exp'
    self.bf16_list.add(op)
    self.fp32_list.remove(op)
    self.amp_lists_ = amp.AutoMixedPrecisionListsBF16({op})
def test_amp_lists(self):
    """Default construction with no custom lists.

    This method leaves the expected lists untouched.
    """
    lists_cls = amp.AutoMixedPrecisionListsBF16
    self.amp_lists_ = lists_cls()