def test_state_dict_url(self, framework, should_be_available):
    """Check ``state_dict_url`` for one framework.

    When ``should_be_available`` is true the call must return a ``str``;
    otherwise it must raise ``RuntimeError``.
    """
    encoder = enc.alexnet_multi_layer_encoder(pretrained=False)

    if not should_be_available:
        with pytest.raises(RuntimeError):
            encoder.state_dict_url(framework)
        return

    url = encoder.state_dict_url(framework)
    assert isinstance(url, str)
def test_alexnet_multi_layer_encoder_smoke(subtests):
    """Smoke-test construction and ``repr`` of the AlexNet multi-layer encoder.

    The encoder is built with ``pretrained=False``. The top-level assertion
    checks the factory's return type; the ``"repr"`` subtest checks that
    ``repr`` of the encoder produces a string.
    """
    multi_layer_encoder = enc.alexnet_multi_layer_encoder(pretrained=False)
    assert isinstance(multi_layer_encoder, enc.alexnet.AlexNetMultiLayerEncoder)

    with subtests.test("repr"):
        # Exercise __repr__ itself rather than re-checking the type above.
        assert isinstance(repr(multi_layer_encoder), str)
def test_alexnet_multi_layer_encoder_weights():
    """Check that passing ``weights`` warns and is reflected in ``framework``.

    The factory call must emit a ``UserWarning``. The resulting encoder's
    ``framework`` attribute must equal the value passed as ``weights``.
    """
    expected_framework = "weights"

    with pytest.warns(UserWarning):
        encoder = enc.alexnet_multi_layer_encoder(
            pretrained=False,
            weights=expected_framework,
            internal_preprocessing=False,
        )

    assert encoder.framework == expected_framework
def test_alexnet_multi_layer_encoder_state_dict_url(subtests, frameworks):
    """Check ``state_dict_url`` for every framework in ``frameworks``.

    Only ``"torch"`` is expected to yield a URL string; every other framework
    must raise ``RuntimeError``. Each framework runs in its own subtest.
    """
    encoder = enc.alexnet_multi_layer_encoder(pretrained=False)

    for fw in frameworks:
        with subtests.test(framework=fw):
            available = fw == "torch"
            if not available:
                with pytest.raises(RuntimeError):
                    encoder.state_dict_url(fw)
            else:
                assert isinstance(encoder.state_dict_url(fw), str)
def test_AlexNetMultiLayerEncoder(self):
    """Compare per-layer encodings against the stored ``enc/alexnet`` asset.

    Every child layer of the encoder is run on the asset's input image under
    ``torch.no_grad()``. Each output is wrapped in a ``pystiche.TensorKey`` at
    the asset's precision and compared with the asset's ``enc_keys``.
    """
    asset = self.load_asset(path.join("enc", "alexnet"))

    multi_layer_encoder = enc.alexnet_multi_layer_encoder(
        weights="torch", preprocessing=False, allow_inplace=False
    )
    layers = tuple(multi_layer_encoder.children_names())

    with torch.no_grad():
        encodings = multi_layer_encoder(asset.input.image, layers)

    precision = asset.params.precision
    actual = {
        layer: pystiche.TensorKey(encoding, precision=precision)
        for layer, encoding in zip(layers, encodings)
    }

    self.assertDictEqual(actual, asset.output.enc_keys)
def test_alexnet_multi_layer_encoder(enc_asset_loader):
    """Compare per-layer encodings of the pretrained encoder with the asset.

    All child layers are encoded without gradient tracking. The outputs are
    keyed by layer name as ``pystiche.TensorKey`` objects and must match the
    asset's stored ``enc_keys``.
    """
    asset = enc_asset_loader("alexnet")

    multi_layer_encoder = enc.alexnet_multi_layer_encoder(
        pretrained=True,
        weights="torch",
        preprocessing=False,
        allow_inplace=False,
    )
    layers = tuple(multi_layer_encoder.children_names())

    with torch.no_grad():
        encodings = multi_layer_encoder(asset.input.image, layers)

    precision = asset.params.precision
    actual = {
        layer: pystiche.TensorKey(encoding, precision=precision)
        for layer, encoding in zip(layers, encodings)
    }

    assert actual == asset.output.enc_keys
def test_main(self, enc_asset_loader):
    """Compare per-layer encodings of the pretrained encoder with the asset.

    Every child layer's output on the asset image is computed without
    gradients, wrapped as a ``pystiche.TensorKey`` at the asset precision, and
    compared with ``asset.output.enc_keys``.
    """
    asset = enc_asset_loader("alexnet")

    multi_layer_encoder = enc.alexnet_multi_layer_encoder(
        pretrained=True,
        weights="torch",
        preprocessing=False,
        allow_inplace=False,
    )
    layers = tuple(multi_layer_encoder.children_names())

    with torch.no_grad():
        encodings = multi_layer_encoder(asset.input.image, layers)

    precision = asset.params.precision
    actual = {
        layer: pystiche.TensorKey(encoding, precision=precision)
        for layer, encoding in zip(layers, encodings)
    }

    assert actual == asset.output.enc_keys


@pytest.mark.parametrize(
    ("framework", "should_be_available"),
    [
        pytest.param("torch", True, id="torch"),
        pytest.param("caffe", False, id="caffe"),
    ],
)
def test_state_dict_url(self, framework, should_be_available):
    """Check ``state_dict_url`` per framework.

    The ``"torch"`` case must return a ``str``; the ``"caffe"`` case must
    raise ``RuntimeError``.
    """
    encoder = enc.alexnet_multi_layer_encoder(pretrained=False)

    if not should_be_available:
        with pytest.raises(RuntimeError):
            encoder.state_dict_url(framework)
        return

    assert isinstance(encoder.state_dict_url(framework), str)
def test_load_state_dict_smoke(self):
    """Smoke-test loading a torchvision AlexNet state dict into the encoder.

    The state dict comes from a torchvision ``alexnet`` built with
    ``pretrained=False``. The test only checks that ``load_state_dict``
    accepts it without raising.
    """
    state_dict = models.alexnet(pretrained=False).state_dict()

    # Build the encoder with pretrained=False, like the other smoke tests.
    # Its parameters are replaced by load_state_dict right away, so obtaining
    # pretrained weights first (presumably a download) is unnecessary.
    multi_layer_encoder = enc.alexnet_multi_layer_encoder(pretrained=False)
    multi_layer_encoder.load_state_dict(state_dict)
def test_repr_smoke(self):
    """``repr`` of an encoder built with ``pretrained=False`` is a string."""
    encoder = enc.alexnet_multi_layer_encoder(pretrained=False)
    text = repr(encoder)
    assert isinstance(text, str)
def test_main_smoke(self):
    """The factory returns an ``enc.alexnet.AlexNetMultiLayerEncoder``."""
    encoder = enc.alexnet_multi_layer_encoder(pretrained=False)
    expected_cls = enc.alexnet.AlexNetMultiLayerEncoder
    assert isinstance(encoder, expected_cls)
def test_alexnet_multi_layer_encoder_smoke(self):
    """The factory, called with defaults, returns a ``MultiLayerAlexNetEncoder``."""
    encoder = enc.alexnet_multi_layer_encoder()
    expected_type = enc.alexnet.MultiLayerAlexNetEncoder
    self.assertIsInstance(encoder, expected_type)