コード例 #1
0
ファイル: test_enforce.py プロジェクト: jakegrigsby/daze
def test_ae_with_vae_forward_pass():
    with pytest.raises(DazeModelTypeError):
        model = dz.AutoEncoder(
            ConvolutionalEncoder(3),
            CifarDecoder(),
            forward_pass_func=dz.forward_pass.probabilistic_encode_decode(),
            loss_funcs=[dz.loss.latent_l1()])
コード例 #2
0
ファイル: test_model.py プロジェクト: jakegrigsby/daze
def test_get_batch_encodings_np():
    x, _ = dz.data.cifar10.load(70, "f32")
    x /= 255
    model = dz.AutoEncoder(ConvolutionalEncoder(latent_dim=2), CifarDecoder())
    encodings = model.get_batch_encodings(x)
    assert isinstance(encodings, tf.Tensor)
    assert encodings.numpy().shape[0] == 70
    assert encodings.numpy().shape[1] == 2
コード例 #3
0
ファイル: test_enforce.py プロジェクト: jakegrigsby/daze
def test_ae_with_gan_loss_func():
    with pytest.raises(DazeModelTypeError):
        model = dz.AutoEncoder(ConvolutionalEncoder(3),
                               CifarDecoder(),
                               loss_funcs=[dz.loss.feature_matching()])
コード例 #4
0
ファイル: test_enforce.py プロジェクト: jakegrigsby/daze
def test_ae_with_vae_loss_func():
    with pytest.raises(DazeModelTypeError):
        model = dz.AutoEncoder(ConvolutionalEncoder(3),
                               CifarDecoder(),
                               loss_funcs=[dz.loss.kl()])
コード例 #5
0
ファイル: test_enforce.py プロジェクト: jakegrigsby/daze
def test_ae_with_gan_forward_pass():
    with pytest.raises(DazeModelTypeError):
        model = dz.AutoEncoder(
            ConvolutionalEncoder(3),
            CifarDecoder(),
            forward_pass_func=dz.forward_pass.generative_adversarial())
コード例 #6
0
ファイル: test_model.py プロジェクト: jakegrigsby/daze
def test_get_batch_encodings_unknown():
    with pytest.raises(ValueError):
        model = dz.AutoEncoder(ConvolutionalEncoder(latent_dim=2),
                               CifarDecoder())
        encodings = model.get_batch_encodings([1.0, 2.0, 3.0])
コード例 #7
0
ファイル: test_recipes.py プロジェクト: jakegrigsby/daze
def test_default():
    model = dz.AutoEncoder(ConvolutionalEncoder(3), CifarDecoder())
    cbs = make_callbacks(model)
    train(model, cbs)