예제 #1
0
    def test_amp_lists_5(self):
        # 5. Custom bf16 list ("w") = None, custom fp32 list ("b") =
        # {'elementwise_add'}: the op is expected to leave the bf16 list and
        # join the fp32 list. The expected lists are updated here; comparing
        # them with self.amp_lists_ presumably happens outside this method.
        op_type = 'elementwise_add'
        self.fp32_list.add(op_type)
        self.bf16_list.remove(op_type)

        self.amp_lists_ = amp.AutoMixedPrecisionListsBF16(
            custom_fp32_list={op_type})
예제 #2
0
    def test_is_in_fp32_varnames(self):
        # Build the chain X -> abs -> Y -> abs -> Z in the default main program,
        # then check that an op is reported as "in fp32 varnames" when one of
        # the variables it reads or writes is listed in custom_fp32_varnames.
        block = fluid.default_main_program().global_block()

        var_x, var_y, var_z = (
            block.create_var(name=var_name, shape=[3], dtype='float32')
            for var_name in ("X", "Y", "Z"))
        first_abs = block.append_op(type="abs",
                                    inputs={"X": [var_x]},
                                    outputs={"Out": [var_y]})
        second_abs = block.append_op(type="abs",
                                     inputs={"X": [var_y]},
                                     outputs={"Out": [var_z]})

        # Variable 'X' feeds the first op.
        lists_with_x = amp.AutoMixedPrecisionListsBF16(
            custom_fp32_varnames={'X'})
        is_fp32_op = amp.bf16.amp_utils._is_in_fp32_varnames
        assert is_fp32_op(first_abs, lists_with_x)

        # Variable 'Y' is read by the second op and written by the first, so
        # both are expected to match (input side and output side).
        lists_with_y = amp.AutoMixedPrecisionListsBF16(
            custom_fp32_varnames={'Y'})
        assert is_fp32_op(second_abs, lists_with_y)
        assert is_fp32_op(first_abs, lists_with_y)
예제 #3
0
    def test_amp_list_8(self):
        # 8. Custom bf16 list = {'reshape2'}: the op is expected to move from
        # the gray list into the bf16 list. The expected lists are updated
        # here; comparing them with self.amp_lists_ presumably happens
        # outside this method.
        op_type = 'reshape2'
        self.bf16_list.add(op_type)
        self.gray_list.remove(op_type)

        self.amp_lists_ = amp.AutoMixedPrecisionListsBF16(
            custom_bf16_list={op_type})
예제 #4
0
    def test_amp_lists_6(self):
        # 6. Custom bf16 list ("w") = None, custom fp32 list ("b") = {'lstm'}:
        # only the expected fp32 list changes, gaining 'lstm'. Comparing it
        # with self.amp_lists_ presumably happens outside this method.
        op_type = 'lstm'
        self.fp32_list.add(op_type)

        self.amp_lists_ = amp.AutoMixedPrecisionListsBF16(
            custom_fp32_list={op_type})
예제 #5
0
    def test_amp_lists_3(self):
        # 3. Custom bf16 list ("w") = {'lstm'} (first positional argument),
        # custom fp32 list ("b") = None: only the expected bf16 list changes,
        # gaining 'lstm'.
        op_type = 'lstm'
        self.bf16_list.add(op_type)

        self.amp_lists_ = amp.AutoMixedPrecisionListsBF16({op_type})
예제 #6
0
    def test_amp_lists_2(self):
        # 2. Custom bf16 list ("w") = {'tanh'} (first positional argument),
        # custom fp32 list ("b") = None: 'tanh' is expected to move from the
        # fp32 list to the bf16 list.
        op_type = 'tanh'
        self.fp32_list.remove(op_type)
        self.bf16_list.add(op_type)

        self.amp_lists_ = amp.AutoMixedPrecisionListsBF16({op_type})
예제 #7
0
    def test_amp_lists_1(self):
        # 1. Custom bf16 list ("w") = {'exp'} (first positional argument),
        # custom fp32 list ("b") = None: 'exp' is expected to move from the
        # fp32 list to the bf16 list.
        op_type = 'exp'
        self.bf16_list.add(op_type)
        self.fp32_list.remove(op_type)

        self.amp_lists_ = amp.AutoMixedPrecisionListsBF16({op_type})
예제 #8
0
 def test_amp_lists(self):
     # Baseline: build the lists with no custom overrides and leave the
     # expected lists untouched.
     lists_cls = amp.AutoMixedPrecisionListsBF16
     self.amp_lists_ = lists_cls()