def test_invalid_multiconv_options():
    """Non-enum option values must be rejected with an AssertionError."""
    conv_opts = poptorch.MultiConv()

    # A plain string is not a MultiConvPartialsType enum member.
    with pytest.raises(AssertionError, match="Invalid partials types"):
        conv_opts.partialsTypes("half")

    # Likewise, the plan type must be a MultiConvPlanType enum member.
    with pytest.raises(AssertionError, match="Invalid plan type"):
        conv_opts.planType("parallel")
def test_multiconv_options_broadcast():
    """Configure a MultiConv with scalar option values and run the harness.

    Scalar values are applied (broadcast) across all convolutions in the
    group, as opposed to the per-conv sequences used elsewhere.
    """
    opts = poptorch.MultiConv()
    # Rebind after every fluent call so the final object passed to the
    # harness is exactly what the last setter returned.
    opts = opts.availableMemoryProportions(0.8)
    opts = opts.partialsTypes(poptorch.MultiConvPartialsType.Float)
    opts = opts.planType(poptorch.MultiConvPlanType.Parallel)
    opts = opts.perConvReservedTiles(100)
    opts = opts.cycleBackOff(0.3)
    multiconv_harness(opts)
def test_multiconv_options_per_conv():
    """Configure a MultiConv with per-convolution option sequences.

    Sequence-valued options supply one entry per convolution in the group
    (two here), rather than broadcasting a single scalar.
    """
    per_conv_partials = [
        poptorch.MultiConvPartialsType.Float,
        poptorch.MultiConvPartialsType.Float,
    ]
    opts = poptorch.MultiConv()
    # Rebind after every fluent call so the final object passed to the
    # harness is exactly what the last setter returned.
    opts = opts.availableMemoryProportions((0.8, 0.7))
    opts = opts.partialsTypes(per_conv_partials)
    opts = opts.planType(poptorch.MultiConvPlanType.Parallel)
    opts = opts.perConvReservedTiles(120)
    opts = opts.cycleBackOff(0.4)
    multiconv_harness(opts)
def forward(self, x):
    """Run the network, fusing the two parallel first-layer convs.

    The MultiConv scope groups layer1A and layer1B so they execute as a
    single multi-convolution; the remaining layers run outside the scope.
    """
    with poptorch.MultiConv():
        branch_a = self.layer1A(x)
        branch_b = self.layer1B(x)
        x = branch_a + branch_b
    x = self.layer2(x)
    # Flatten the feature maps before the fully-connected layers.
    x = x.view(-1, 320)
    x = self.layer3_act(self.layer3(x))
    x = self.layer4(x)
    return self.softmax(x)
def forward(self, x, y):
    """Apply one convolution to each input inside a single MultiConv scope."""
    with poptorch.MultiConv():
        out_a = self.convA(x)
        out_b = self.convB(y)
        return out_a, out_b
def forward(self, x):
    # A MultiConv scope containing no convolution at all — presumably
    # exercises the empty-multiconv error/warning path; confirm with callers.
    with poptorch.MultiConv():
        squared = torch.pow(x, 2)
        return squared
def forward(self, x):
    # Two MultiConv scopes entered in sequence (the combined with-statement
    # is equivalent to nesting) — presumably exercises the nested-scope
    # error path; confirm with the calling test.
    with poptorch.MultiConv(), poptorch.MultiConv():
        return self.conv(x)
def forwardWithMultiConv(*args, **kwargs):
    """Delegate to forward_impl, wrapping the call in a MultiConv scope."""
    with poptorch.MultiConv():
        result = forward_impl(*args, **kwargs)
    return result