Exemplo n.º 1
0
def test_invalid_multiconv_options():
    """Invalid string option values must be rejected by MultiConv.

    Both ``partialsTypes`` and ``planType`` are expected to raise an
    ``AssertionError`` whose message matches the given pattern when
    passed a plain string instead of the corresponding option value.
    """
    multi_conv = poptorch.MultiConv()

    # Each entry: (expected error pattern, option setter, rejected argument).
    bad_options = (
        ("Invalid partials types", multi_conv.partialsTypes, "half"),
        ("Invalid plan type", multi_conv.planType, "parallel"),
    )
    for pattern, setter, bad_value in bad_options:
        with pytest.raises(AssertionError, match=pattern):
            setter(bad_value)
Exemplo n.º 2
0
def test_multiconv_options_broadcast():
    """Set each MultiConv option to a single value and run the harness.

    Every option receives one scalar/enum value rather than a sequence
    (presumably applied to all convolutions in the scope, as the test
    name suggests — confirm against ``poptorch.MultiConv``).
    """
    multiconv = poptorch.MultiConv()
    # Each setter returns the MultiConv object, so rebind after every call.
    multiconv = multiconv.availableMemoryProportions(0.8)
    multiconv = multiconv.partialsTypes(poptorch.MultiConvPartialsType.Float)
    multiconv = multiconv.planType(poptorch.MultiConvPlanType.Parallel)
    multiconv = multiconv.perConvReservedTiles(100)
    multiconv = multiconv.cycleBackOff(0.3)

    multiconv_harness(multiconv)
Exemplo n.º 3
0
def test_multiconv_options_per_conv():
    """Configure MultiConv with per-convolution option values and run the harness.

    Memory proportions and partials types are passed as two-element
    sequences (presumably one entry per convolution built by
    ``multiconv_harness`` — confirm there); the remaining options get
    single values.
    """
    partials_types = [poptorch.MultiConvPartialsType.Float] * 2
    memory_proportions = (0.8, 0.7)

    multiconv = poptorch.MultiConv()
    # Each setter returns the MultiConv object, so rebind after every call.
    multiconv = multiconv.availableMemoryProportions(memory_proportions)
    multiconv = multiconv.partialsTypes(partials_types)
    multiconv = multiconv.planType(poptorch.MultiConvPlanType.Parallel)
    multiconv = multiconv.perConvReservedTiles(120)
    multiconv = multiconv.cycleBackOff(0.4)

    multiconv_harness(multiconv)
Exemplo n.º 4
0
        def forward(self, x):
            """Sum two layers inside a MultiConv scope, then apply the rest.

            ``layer1A`` and ``layer1B`` both read the same input and are
            grouped in one ``poptorch.MultiConv`` scope. Their sum then goes
            through ``layer2``, is reshaped to ``(-1, 320)``, and passes
            through ``layer3`` (with ``layer3_act``), ``layer4`` and
            ``softmax``.
            """
            with poptorch.MultiConv():
                # Left-to-right evaluation keeps layer1A before layer1B.
                x = self.layer1A(x) + self.layer1B(x)

            x = self.layer2(x).view(-1, 320)
            x = self.layer3_act(self.layer3(x))
            return self.softmax(self.layer4(x))
Exemplo n.º 5
0
 def forward(self, x, y):
     """Apply ``convA`` to ``x`` and ``convB`` to ``y`` in one MultiConv scope.

     Returns the pair ``(convA(x), convB(y))``.
     """
     with poptorch.MultiConv():
         out_a = self.convA(x)
         out_b = self.convB(y)
         return out_a, out_b
Exemplo n.º 6
0
 def forward(self, x):
     """Square ``x`` inside a MultiConv scope.

     The scope holds only ``torch.pow`` and no convolution. This is
     presumably deliberate, to exercise a MultiConv scope without
     convolutions (confirm against the surrounding test).
     """
     with poptorch.MultiConv():
         squared = torch.pow(x, 2)
         return squared
Exemplo n.º 7
0
 def forward(self, x):
     """Run ``self.conv`` inside two nested MultiConv scopes.

     ``with A(), B():`` is equivalent to nesting ``with B():`` inside
     ``with A():``. The nesting is presumably intentional, to check how
     nested MultiConv scopes are handled (confirm against the test).
     """
     with poptorch.MultiConv(), poptorch.MultiConv():
         return self.conv(x)
Exemplo n.º 8
0
 def forwardWithMultiConv(*args, **kwargs):
     """Call ``forward_impl`` with all given arguments inside a MultiConv scope."""
     scope = poptorch.MultiConv()
     with scope:
         return forward_impl(*args, **kwargs)