示例#1
0
                              default='ssvae_aspire',
                              type=str,
                              help="model file prefix to store")
    ssvae_parser.add_argument('--continue-from',
                              default=None,
                              type=str,
                              help="model file path to make continued from")

    args = parser.parse_args()

    if args.model is None:
        parser.print_help()
        sys.exit(1)

    # some assertions to make sure that batching math assumptions are met
    assert parse_torch_version() >= (0, 2,
                                     1), "you need pytorch 0.2.1 or later"

    set_logfile(Path(args.log_dir, "train.log"))

    logger.info(f"Training started with command: {' '.join(sys.argv)}")
    args_str = [f"{k}={v}" for (k, v) in vars(args).items()]
    logger.info(f"args: {' '.join(args_str)}")

    if args.use_cuda:
        logger.info("using cuda")
        torch.set_default_tensor_type("torch.cuda.FloatTensor")

    if args.seed is not None:
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)
示例#2
0
from torch.autograd import Variable
import pyro.distributions as dist
from utils.mnist_cached import MNISTCached, setup_data_loaders
from pyro.infer import SVI
from pyro.optim import Adam
from pyro.nn import ClippedSoftmax, ClippedSigmoid
from pyro.shim import parse_torch_version
from utils.custom_mlp import MLP, Exp
from utils.vae_plots import plot_conditional_samples_ssvae, mnist_test_tsne_ssvae
from util import set_seed, print_and_log, mkdir_p
import torch.nn as nn

# Notice shown (on stdout) when the installed PyTorch is judged too old to run
# this example.
version_warning = '''
11/02/2017: This example does not work with PyTorch 0.2, please install PyTorch 0.3.
'''
# parse_torch_version() yields a tuple whose last element is a string suffix
# (it is tested with .startswith below).  Tuple comparison is lexicographic,
# so this gate rejects anything strictly below (0, 2, 1).
torch_version = parse_torch_version()
if torch_version < (0, 2, 1):
    # A suffix beginning with "+" exempts the build from the gate (presumably
    # locally built / source versions - confirm against pyro.shim).  The suffix
    # is only examined once the version has already been found too old.
    if not torch_version[-1].startswith("+"):
        print(version_warning)
        # Exit status 0: the example bows out quietly instead of reporting an
        # error.
        sys.exit(0)


class SSVAE(nn.Module):
    """
    This class encapsulates the parameters (neural networks) and models & guides needed to train a
    semi-supervised variational auto-encoder on the MNIST image dataset

    :param output_size: size of the tensor representing the class label (10 for MNIST since
                        we represent the class labels as a one-hot vector with 10 components)
    :param input_size: size of the tensor representing the image (28*28 = 784 for our MNIST dataset
                       since we flatten the images and scale the pixels to be in [0,1])
    :param z_dim: size of the tensor representing the latent random variable z