Ejemplo n.º 1
0
class TorchAdd(nn.Module):
    """Element-wise addition packaged as a module.

    Instead of calling ``torch.add`` directly, the op is routed through a
    ``FloatFunctional`` instance held in ``self.add_func``. This keeps the
    op reachable as a module attribute so that it can be found at build.
    """

    def __init__(self):
        super(TorchAdd, self).__init__()
        self.add_func = FloatFunctional()

    def forward(self, x, y):
        """Return the sum of ``x`` and ``y`` as computed by ``self.add_func.add``."""
        summed = self.add_func.add(x, y)
        return summed
Ejemplo n.º 2
0
class TorchCat2T(nn.Module):
    """Join exactly two tensors along dimension 1 (the channel dimension).

    The concatenation is performed by a ``FloatFunctional`` stored in
    ``self.cat_func`` rather than by calling ``torch.cat`` directly.
    """

    def __init__(self):
        super(TorchCat2T, self).__init__()
        self.cat_func = FloatFunctional()

    def forward(self, x, y):
        """Concatenate ``x`` then ``y`` along dim 1 and return the result."""
        pair = [x, y]
        return self.cat_func.cat(pair, dim=1)
Ejemplo n.º 3
0
class TorchCat(nn.Module):
    """Wrapper around ``torch.cat`` so that all ops can be found at build.

    The concatenation is routed through a ``FloatFunctional`` instance held
    in ``self.cat_func`` instead of calling ``torch.cat`` directly.
    """

    def __init__(self):
        super().__init__()
        self.cat_func = FloatFunctional()

    def forward(self, tensors: List[torch.Tensor], dim: int = 0):
        """Concatenate ``tensors`` along ``dim``.

        Args:
            tensors: Tensors to join, in the given order.
            dim: Dimension to concatenate along. Defaults to 0, matching the
                default of ``torch.cat`` and ``FloatFunctional.cat``.

        Returns:
            The concatenated tensor produced by ``self.cat_func.cat``.
        """
        return self.cat_func.cat(tensors, dim=dim)
Ejemplo n.º 4
0
class TorchMultiply(nn.Module):
    """Element-wise multiplication packaged as a module.

    The product is computed by a ``FloatFunctional`` held in
    ``self.mul_func`` rather than by ``torch.mul`` directly, so that the op
    can be found at build.
    """

    def __init__(self):
        super(TorchMultiply, self).__init__()
        self.mul_func = FloatFunctional()

    def forward(self, x, y):
        """Return the product of ``x`` and ``y`` as computed by ``self.mul_func.mul``."""
        product = self.mul_func.mul(x, y)
        return product
Ejemplo n.º 5
0
class TorchMulScalar(nn.Module):
    """Multiply a tensor by a scalar through ``FloatFunctional.mul_scalar``.

    Wraps ``torch.mul`` so that all ops can be found at build. The second
    operand ``y`` must be a scalar (not a tensor); this is needed for
    quantization.
    """

    def __init__(self):
        super(TorchMulScalar, self).__init__()
        self.mul_func = FloatFunctional()

    def forward(self, x, y):
        """Return ``x`` scaled by the scalar ``y`` via ``self.mul_func.mul_scalar``."""
        scaled = self.mul_func.mul_scalar(x, y)
        return scaled
Ejemplo n.º 6
0
class TorchAddScalar(nn.Module):
    """Add a fixed scalar to a tensor through ``FloatFunctional.add_scalar``.

    Wraps ``torch.add`` so that all ops can be found at build. The scalar
    is given once at construction and kept in ``self.scalar``. It must be a
    scalar, which is needed for quantization.
    """

    def __init__(self, scalar):
        super(TorchAddScalar, self).__init__()
        self.add_func = FloatFunctional()
        self.scalar = scalar

    def forward(self, x):
        """Return ``x`` plus the stored ``self.scalar``."""
        offset = self.scalar
        return self.add_func.add_scalar(x, offset)
Ejemplo n.º 7
0
 def __init__(self, scalar):
     """Initialise the parent module, attach a ``FloatFunctional`` as
     ``self.add_func`` and keep ``scalar`` as ``self.scalar``."""
     super().__init__()
     functional = FloatFunctional()
     self.add_func = functional
     self.scalar = scalar
Ejemplo n.º 8
0
 def __init__(self):
     """Initialise the parent module and attach a ``FloatFunctional``
     instance as ``self.add_func``."""
     super().__init__()
     functional = FloatFunctional()
     self.add_func = functional