示例#1
0
 def test_schema_check_mode_functionality_with_multiple_outputs(self):
     """frexp with ``out=`` under SchemaCheckMode matches the plain frexp call."""
     source = torch.arange(9.)
     # Reference result, computed outside the dispatch mode.
     reference = torch.frexp(source)
     mantissa_out = torch.arange(9.)
     exponent_out = torch.zeros([9], dtype=torch.int32)
     with enable_torch_dispatch_mode(SchemaCheckMode()):
         torch.frexp(source, out=(mantissa_out, exponent_out))
     # The out= tensors written under the mode must equal the reference pair.
     self.assertEqual(reference[0], mantissa_out)
     self.assertEqual(reference[1], exponent_out)
示例#2
0
 def test_schema_check_mode_mutated_aliasing_multiple_outputs(self):
     """SchemaCheckMode records mutation and aliasing for both frexp outputs."""
     source = torch.arange(9.)
     out_pair = (torch.arange(9.), torch.zeros([9], dtype=torch.int32))
     mode = SchemaCheckMode()
     with enable_torch_dispatch_mode(mode):
         torch.frexp(source, out=out_pair)

     # Both out= arguments are reported as mutated ...
     expected_mutated = [
         ('aten::frexp', 'mantissa'),
         ('aten::frexp', 'exponent'),
     ]
     # ... and each one is reported as aliasing the matching output slot.
     expected_aliasing = [
         ('aten::frexp', 'mantissa', 'output_0'),
         ('aten::frexp', 'exponent', 'output_1'),
     ]
     self.assertEqual(expected_mutated, mode.mutated)
     self.assertEqual(expected_aliasing, mode.aliasing)
示例#3
0
    def forward(ctx,
                input,
                weight,
                bias=None,
                temporal="i",
                width=8,
                widtht=4,
                degree=2,
                delta=0,
                cycle_pos=16,
                cycle_neg=-16):
        """Compute ``matmul(input, weight.t())`` (+ bias) with one operand re-encoded.

        The operand chosen by ``temporal`` is split with ``torch.frexp`` into a
        mantissa and an exponent. The mantissa is then consumed in ``degree``
        steps of ``widtht`` shift positions each; every step's fractional part
        is clamped and accumulated into a new mantissa, which is recombined
        with the original exponent via ``torch.ldexp``. The other operand is
        used unchanged.

        Args:
            ctx: Context object; ``save_for_backward(input, weight, bias)`` is
                called on it.
            input: Left operand of the matmul.
            weight: Right operand; it is transposed with ``.t()`` before the
                matmul.
            bias: Optional tensor; when given it is unsqueezed at dim 0,
                expanded to the output's shape and added in place.
            temporal: ``"i"``/``"input"`` to re-encode ``input``,
                ``"w"``/``"weight"`` to re-encode ``weight``.
            width: Shift applied to the mantissa before the chunk loop.
            widtht: Shift applied per chunk inside the loop.
            degree: Number of chunk iterations.
            delta: Final shift applied to the accumulated mantissa.
            cycle_pos: Upper clamp bound is ``cycle_pos - 1``.
            cycle_neg: Lower clamp bound is ``cycle_neg + 1``.

        Returns:
            The matmul result, with ``bias`` added when it is not ``None``.

        Raises:
            ValueError: If ``temporal`` is not one of the accepted selectors.
        """
        # Reject unknown selectors before touching ctx; previously such a value
        # skipped both branches and failed later with an UnboundLocalError.
        on_input = temporal in ("i", "input")
        if not on_input and temporal not in ("w", "weight"):
            raise ValueError(
                "temporal must be one of 'i', 'input', 'w', 'weight'; "
                f"got {temporal!r}")

        ctx.save_for_backward(input, weight, bias)

        # The re-encoded operand is always cast back to the type of `input`,
        # even when the re-encoded operand is `weight`.
        dtype = input.type()

        source = input if on_input else weight
        source_fp32 = source.detach().clone().type(torch.float)
        mantissa, exponent = torch.frexp(source_fp32)
        frac = torch.zeros_like(source_fp32)
        mantissa_new = torch.zeros_like(source_fp32)

        # NOTE(review): `<<` / `>>` are applied to float tensors here; this
        # presumably relies on torch scaling by powers of two for floating
        # dtypes -- confirm the installed torch version still supports it.
        mantissa = mantissa << width
        for _ in range(degree):
            mantissa = mantissa >> widtht
            # Split into the fractional part (this chunk) and the integer
            # part (what remains for the following iterations).
            torch.frac(mantissa, out=frac)
            torch.trunc(mantissa, out=mantissa)
            torch.clamp(frac << widtht, cycle_neg + 1, cycle_pos - 1, out=frac)
            torch.add(frac >> widtht, mantissa_new >> widtht, out=mantissa_new)

        mantissa_new = mantissa_new << delta

        requantized = torch.ldexp(mantissa_new, exponent).type(dtype)
        if on_input:
            input_new, weight_new = requantized, weight
        else:
            input_new, weight_new = input, requantized

        output = torch.matmul(input_new, weight_new.t())

        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output
示例#4
0
def decompose(t):
    """Split ``t`` into a half-precision mantissa and an int8 exponent.

    ``t`` is converted to float32 and passed to ``torch.frexp``.

    Args:
        t: Tensor to decompose.

    Returns:
        A ``(mantissa, exponent)`` tuple where the mantissa has been converted
        with ``.half()`` and the exponent with ``.to(torch.int8)``.

    Raises:
        Exception: If the installed torch version is older than 1.9.
    """
    # Compare (major, minor) as a tuple. The previous check tested
    # `minor < 9` for every major >= 1, which wrongly rejected releases such
    # as 2.0.
    if (TORCH_VERSION_MAJOR, TORCH_VERSION_MINOR) < (1, 9):
        raise Exception(
            'Torch version >= 1.9.0 needed for 24_bit_allreduce.decompose')
    mantissa, exponent = torch.frexp(t.float())
    # NOTE(review): int8 only holds [-128, 127]; confirm the exponents of the
    # tensors passed here always fit in that range.
    return mantissa.half(), exponent.to(torch.int8)