from deliravision.models.gans import BoundarySeekingGAN
from deliravision.losses import BoundarySeekingLoss
import os
from training.gans._basic import train, predict

if __name__ == '__main__':
    # Train a BoundarySeekingGAN and then run prediction with the trained
    # weights.

    # os.path.abspath() does NOT expand "~": without expanduser() these would
    # resolve to "<cwd>/~/data/" and "<cwd>/~/GanExperiments" instead of
    # locations inside the user's home directory.
    img_path = os.path.abspath(os.path.expanduser("~/data/"))
    outpath = os.path.abspath(os.path.expanduser("~/GanExperiments"))
    num_epochs = 1000
    # Maps the name "x" to the batch key "data" -- presumably model input
    # name -> data-loader key; confirm against train().
    key_mapping = {"x": "data"}

    # Model kwargs: img_shape is presumably (channels, height, width).
    model, weight_path = train(BoundarySeekingGAN, {"latent_dim": 100,
                                                    "img_shape": (1, 28, 28)},
                               os.path.join(outpath, "train"), img_path,
                               num_epochs=num_epochs,
                               key_mapping=key_mapping,
                               additional_losses={"boundary_seeking":
                                                  BoundarySeekingLoss()})

    # num_epochs is passed positionally; presumably it selects the
    # checkpoint/epoch to predict with -- verify in predict().
    predict(model, weight_path, os.path.join(outpath, "preds"), num_epochs)
# ---- Example No. 2 (0) ----
    # Train an AuxiliaryClassifierGANPyTorch and run class-conditioned
    # prediction.
    # NOTE(review): img_path, torch and AuxiliaryClassifierGANPyTorch are not
    # defined in the visible lines of this example -- presumably defined or
    # imported above; confirm.

    # Expand "~" explicitly: os.path.abspath() alone would produce
    # "<cwd>/~/GanExperiments".
    outpath = os.path.abspath(os.path.expanduser("~/GanExperiments"))
    num_epochs = 1000
    # Presumably model argument name -> batch key; confirm against train().
    key_mapping = {"real_imgs": "data", "real_labels": "label"}
    latent_dim = 100
    batchsize = 64
    n_classes = 10

    model, weight_path = train(
        AuxiliaryClassifierGANPyTorch, {
            "latent_dim": latent_dim,
            "img_size": 28,
            "n_channels": 1,
            "n_classes": n_classes
        },
        os.path.join(outpath, "train"),
        img_path,
        num_epochs=num_epochs,
        # Cross-entropy loss registered under the "auxiliary" key.
        additional_losses={"auxiliary": torch.nn.CrossEntropyLoss()},
        key_mapping=key_mapping,
        batchsize=batchsize)

    # gen_fns / gen_args / gen_kwargs are parallel lists; presumably predict()
    # calls gen_fns[i](*gen_args[i], **gen_kwargs[i]). That yields a
    # (batchsize, latent_dim) normal-noise tensor and batchsize int64 labels
    # drawn from [0, n_classes) (torch.randint's upper bound is exclusive).
    predict(model,
            weight_path,
            os.path.join(outpath, "preds"),
            num_epochs,
            gen_fns=[torch.randn, torch.randint],
            gen_args=[(batchsize, latent_dim), (0, n_classes, (batchsize, ))],
            gen_kwargs=[{}, {
                "dtype": torch.long
            }])
# ---- Example No. 3 (0) ----
        "code_dim": code_dim
    },
                               os.path.join(outpath, "train"),
                               img_path,
                               num_epochs=num_epochs,
                               key_mapping=key_mapping,
                               additional_losses={
                                   "categorical":
                                   torch.nn.CrossEntropyLoss(),
                                   "continuous":
                                   torch.nn.MSELoss(),
                                   "adversarial":
                                   AdversarialLoss(torch.nn.MSELoss())
                               },
                               create_optim_fn=create_optims,
                               batchsize=batchsize)

    # Run prediction with the trained weights, passing <outpath>/preds as
    # (presumably) the output location.
    # NOTE(review): onehot_ints, uniform, code_dim, n_classes, latent_dim and
    # batchsize are defined outside this view. gen_fns / gen_args / gen_kwargs
    # look like parallel lists -- presumably predict() builds each generator
    # input as gen_fns[i](*gen_args[i], **gen_kwargs[i]); confirm in
    # training.gans._basic.predict.
    predict(model,
            weight_path,
            os.path.join(outpath, "preds"),
            num_epochs,
            # normal noise, integer codes via onehot_ints, uniform codes
            gen_fns=[torch.randn, onehot_ints, uniform],
            gen_args=[(batchsize, latent_dim), (n_classes, (batchsize, 1)),
                      ((batchsize, code_dim), )],
            # min_val / max_val presumably bound the uniform draw; check
            # uniform()'s definition for endpoint inclusivity.
            gen_kwargs=[{}, {
                "dtype": torch.long
            }, {
                "min_val": -1,
                "max_val": 1
            }])
# ---- Example No. 4 (0) ----
from deliravision.models.gans import AdversarialAutoEncoderPyTorch
import os
from training.gans._basic import train, predict
import torch

if __name__ == '__main__':
    # Train an AdversarialAutoEncoderPyTorch, then run prediction using the
    # submodule named "generator.decoder" as the generative network.

    # os.path.abspath() does NOT expand "~": expanduser() is required so the
    # paths point into the user's home directory rather than "<cwd>/~/...".
    img_path = os.path.abspath(os.path.expanduser("~/data/"))
    outpath = os.path.abspath(os.path.expanduser("~/GanExperiments"))
    num_epochs = 1500
    # Presumably model input name -> batch key; confirm against train().
    key_mapping = {"x": "data"}

    model, weight_path = train(
        AdversarialAutoEncoderPyTorch, {
            "latent_dim": 100,
            "img_shape": (1, 28, 28)
        },
        os.path.join(outpath, "train"),
        img_path,
        num_epochs=num_epochs,
        # L1 (mean absolute error) loss under the "pixelwise" key --
        # presumably the reconstruction term.
        additional_losses={"pixelwise": torch.nn.L1Loss()},
        key_mapping=key_mapping)

    # "generator.decoder" is presumably an attribute path on the model
    # selecting the network to sample from; confirm in predict().
    predict(model,
            weight_path,
            os.path.join(outpath, "preds"),
            num_epochs,
            generative_network="generator.decoder")