Beispiel #1
0
class TorchAdd(nn.Module):
    """Element-wise ``x + y`` expressed through a ``FloatFunctional``.

    The addition is held by a submodule instead of calling ``torch.add``
    directly, so that all ops can be found at build time.
    """

    def __init__(self):
        super().__init__()
        self.add_func = FloatFunctional()

    def forward(self, x, y):
        """Return the sum of ``x`` and ``y`` via ``FloatFunctional.add``."""
        summed = self.add_func.add(x, y)
        return summed
Beispiel #2
0
class TorchCat2T(nn.Module):
    """Concatenate two tensors along a fixed dimension (dim 1 by default).

    The concatenation goes through a ``FloatFunctional`` submodule so that the
    op can be found at build time. The default ``dim=1`` is the channel
    dimension for channel-first inputs (e.g. NCHW). That layout is assumed,
    not checked.

    Args:
        dim: Dimension to concatenate along. Defaults to 1, matching the
            previous hard-coded behavior.
    """

    def __init__(self, dim: int = 1):
        super().__init__()
        self.cat_func = FloatFunctional()
        # Plain int attribute: not part of state_dict, so checkpoints are unaffected.
        self.dim = dim

    def forward(self, x, y):
        """Return ``x`` and ``y`` concatenated along ``self.dim``."""
        return self.cat_func.cat([x, y], dim=self.dim)
Beispiel #3
0
class TorchCat(nn.Module):
    """Wrapper around ``torch.cat`` so that all ops can be found at build time.

    The concatenation is routed through a ``FloatFunctional`` submodule. Like
    ``torch.cat``, ``dim`` defaults to 0, so this module can replace direct
    ``torch.cat(tensors)`` calls.
    """

    def __init__(self):
        super().__init__()
        self.cat_func = FloatFunctional()

    def forward(self, tensors: List[torch.Tensor], dim: int = 0) -> torch.Tensor:
        """Concatenate ``tensors`` along ``dim``.

        Args:
            tensors: Tensors to join. They must agree in every dimension
                except ``dim``, as required by ``torch.cat``.
            dim: Dimension to concatenate along. Defaults to 0, the same as
                ``torch.cat``. Callers that already pass ``dim`` are unaffected.

        Returns:
            The concatenated tensor.
        """
        return self.cat_func.cat(tensors, dim)
Beispiel #4
0
class TorchMultiply(nn.Module):
    """Element-wise ``x * y`` expressed through a ``FloatFunctional``.

    The multiplication is held by a submodule instead of calling
    ``torch.mul`` directly, so that all ops can be found at build time.
    """

    def __init__(self):
        super().__init__()
        self.mul_func = FloatFunctional()

    def forward(self, x, y):
        """Return the product of ``x`` and ``y`` via ``FloatFunctional.mul``."""
        product = self.mul_func.mul(x, y)
        return product
class TorchMulScalar(nn.Module):
    """Multiply a tensor by a scalar through a ``FloatFunctional``.

    Wraps ``torch.mul`` so that all ops can be found at build time.
    The second operand ``y`` must be a scalar, as needed for quantization.
    """

    def __init__(self):
        super().__init__()
        self.mul_func = FloatFunctional()

    def forward(self, x, y):
        """Return ``x`` scaled by the scalar ``y`` via ``mul_scalar``."""
        scaled = self.mul_func.mul_scalar(x, y)
        return scaled
Beispiel #6
0
class TorchAddScalar(nn.Module):
    """Add a constant scalar, fixed at construction, through a ``FloatFunctional``.

    Wraps ``torch.add`` so that all ops can be found at build time. The
    operand must be a scalar, as needed for quantization.

    Args:
        scalar: The value added to every input passed to ``forward``.
    """

    def __init__(self, scalar):
        super().__init__()
        self.add_func = FloatFunctional()
        self.scalar = scalar

    def forward(self, x):
        """Return ``x + self.scalar`` via ``FloatFunctional.add_scalar``."""
        offset = self.scalar
        return self.add_func.add_scalar(x, offset)
Beispiel #7
0
 def __init__(self, scalar):
     """Initialize the module, its ``FloatFunctional`` helper, and ``self.scalar``.

     Args:
         scalar: Stored unchanged on ``self.scalar``.

     NOTE(review): the enclosing class is not visible here. Presumably
     ``scalar`` is the constant operand passed to ``add_func`` in ``forward``
     (suggested by the attribute name) — confirm against the class.
     """
     super().__init__()
     self.add_func = FloatFunctional()
     self.scalar = scalar
Beispiel #8
0
 def __init__(self):
     """Initialize the module and create the ``FloatFunctional`` helper ``add_func``.

     NOTE(review): the enclosing class is not visible here, and this method
     may continue beyond the lines shown — confirm against the full source.
     """
     super().__init__()
     self.add_func = FloatFunctional()