コード例 #1
0
ファイル: test_recipes.py プロジェクト: jakegrigsby/daze
def test_contractive():
    if dz.tracing.TRACE_GRAPHS:
        with pytest.raises(ValueError):
            model = dz.recipes.ContractiveAutoEncoder(
                ConvolutionalEncoder(), CifarDecoder(), gamma=0.1
            )
    else:
        model = dz.recipes.ContractiveAutoEncoder(
            ConvolutionalEncoder(), CifarDecoder(), gamma=0.1
        )
        cbs = make_callbacks(model)
        train(model, cbs)
コード例 #2
0
ファイル: test_enforce.py プロジェクト: jakegrigsby/daze
def test_ae_with_vae_forward_pass():
    with pytest.raises(DazeModelTypeError):
        model = dz.AutoEncoder(
            ConvolutionalEncoder(3),
            CifarDecoder(),
            forward_pass_func=dz.forward_pass.probabilistic_encode_decode(),
            loss_funcs=[dz.loss.latent_l1()])
コード例 #3
0
ファイル: test_enforce.py プロジェクト: jakegrigsby/daze
def test_gan_with_vae_forward_pass():
    with pytest.raises(DazeModelTypeError):
        model = dz.GAN(
            CifarDecoder(),
            ConvolutionalEncoder(),
            100,
            forward_pass_func=dz.forward_pass.probabilistic_encode_decode())
コード例 #4
0
ファイル: test_model.py プロジェクト: jakegrigsby/daze
def test_get_batch_encodings_np():
    x, _ = dz.data.cifar10.load(70, "f32")
    x /= 255
    model = dz.AutoEncoder(ConvolutionalEncoder(latent_dim=2), CifarDecoder())
    encodings = model.get_batch_encodings(x)
    assert isinstance(encodings, tf.Tensor)
    assert encodings.numpy().shape[0] == 70
    assert encodings.numpy().shape[1] == 2
コード例 #5
0
ファイル: test_enforce.py プロジェクト: jakegrigsby/daze
def test_gan_with_gen_loss_in_disc_loss():
    with pytest.raises(DazeModelTypeError):
        model = dz.GAN(CifarDecoder(),
                       ConvolutionalEncoder(),
                       100,
                       discriminator_loss=[dz.loss.vanilla_generator_loss()])
コード例 #6
0
ファイル: test_enforce.py プロジェクト: jakegrigsby/daze
def test_gan_with_disc_loss_in_gen_loss():
    with pytest.raises(DazeModelTypeError):
        model = dz.GAN(CifarDecoder(),
                       ConvolutionalEncoder(),
                       100,
                       generator_loss=[dz.loss.one_sided_label_smoothing()])
コード例 #7
0
ファイル: test_enforce.py プロジェクト: jakegrigsby/daze
def test_gan_with_ae_loss_in_disc_loss():
    with pytest.raises(DazeModelTypeError):
        model = dz.GAN(CifarDecoder(),
                       ConvolutionalEncoder(),
                       100,
                       discriminator_loss=[dz.loss.reconstruction()])
コード例 #8
0
ファイル: test_enforce.py プロジェクト: jakegrigsby/daze
def test_gan_with_ae_loss_in_gen_loss():
    with pytest.raises(DazeModelTypeError):
        model = dz.GAN(CifarDecoder(),
                       ConvolutionalEncoder(),
                       100,
                       generator_loss=[dz.loss.contractive(.1)])
コード例 #9
0
ファイル: test_enforce.py プロジェクト: jakegrigsby/daze
def test_ae_with_vae_loss_func():
    with pytest.raises(DazeModelTypeError):
        model = dz.AutoEncoder(ConvolutionalEncoder(3),
                               CifarDecoder(),
                               loss_funcs=[dz.loss.kl()])
コード例 #10
0
ファイル: test_recipes.py プロジェクト: jakegrigsby/daze
def test_gan_one_sided_labels():
    model = dz.GAN(CifarDecoder(), ConvolutionalEncoder(), noise_dim=100, discriminator_loss=[dz.loss.one_sided_label_smoothing()])
    train(model, None)
コード例 #11
0
ファイル: test_recipes.py プロジェクト: jakegrigsby/daze
def test_denoising():
    model = dz.recipes.DenoisingAutoEncoder(ConvolutionalEncoder(), CifarDecoder(), gamma=0.1)
    cbs = make_callbacks(model)
    train(model, cbs)
コード例 #12
0
ファイル: test_recipes.py プロジェクト: jakegrigsby/daze
def test_klsparse():
    model = dz.recipes.KlSparseAutoEncoder(
        ConvolutionalEncoder(), CifarDecoder(), rho=0.01, beta=0.1
    )
    cbs = make_callbacks(model)
    train(model, cbs)
コード例 #13
0
ファイル: test_recipes.py プロジェクト: jakegrigsby/daze
def test_vae():
    model = dz.recipes.VariationalAutoEncoder(ConvolutionalEncoder(), CifarDecoder())
    cbs = make_callbacks(model)
    train(model, cbs)
コード例 #14
0
ファイル: test_recipes.py プロジェクト: jakegrigsby/daze
def test_gan_instance_noise():
    model = dz.GAN(CifarDecoder(), ConvolutionalEncoder(), noise_dim=100, forward_pass_func=dz.forward_pass.generative_adversarial_instance_noise(.2, 0., 1000))
    train(model, None)
コード例 #15
0
ファイル: test_recipes.py プロジェクト: jakegrigsby/daze
def test_gan_feature_matching():
    model = dz.GAN(CifarDecoder(), ConvolutionalEncoder(), noise_dim=100, generator_loss=[dz.loss.feature_matching()])
    train(model, None)
コード例 #16
0
ファイル: test_enforce.py プロジェクト: jakegrigsby/daze
def test_ae_with_gan_forward_pass():
    with pytest.raises(DazeModelTypeError):
        model = dz.AutoEncoder(
            ConvolutionalEncoder(3),
            CifarDecoder(),
            forward_pass_func=dz.forward_pass.generative_adversarial())
コード例 #17
0
ファイル: test_recipes.py プロジェクト: jakegrigsby/daze
def test_default():
    model = dz.AutoEncoder(ConvolutionalEncoder(3), CifarDecoder())
    cbs = make_callbacks(model)
    train(model, cbs)
コード例 #18
0
ファイル: test_recipes.py プロジェクト: jakegrigsby/daze
def test_l1sparse():
    model = dz.recipes.L1SparseAutoEncoder(ConvolutionalEncoder(), CifarDecoder(), gamma=0.1)
    cbs = make_callbacks(model)
    train(model, cbs)
コード例 #19
0
ファイル: test_recipes.py プロジェクト: jakegrigsby/daze
def test_gan():
    model = dz.GAN(CifarDecoder(), ConvolutionalEncoder(), 100)
    cbs = [tensorboard_generative_sample(dz.math.random_normal([5, 100]))]
    train(model, cbs)
コード例 #20
0
ファイル: test_enforce.py プロジェクト: jakegrigsby/daze
def test_ae_with_gan_loss_func():
    with pytest.raises(DazeModelTypeError):
        model = dz.AutoEncoder(ConvolutionalEncoder(3),
                               CifarDecoder(),
                               loss_funcs=[dz.loss.feature_matching()])
コード例 #21
0
ファイル: test_model.py プロジェクト: jakegrigsby/daze
def test_get_batch_encodings_unknown():
    with pytest.raises(ValueError):
        model = dz.AutoEncoder(ConvolutionalEncoder(latent_dim=2),
                               CifarDecoder())
        encodings = model.get_batch_encodings([1.0, 2.0, 3.0])