default='ssvae_aspire', type=str, help="model file prefix to store") ssvae_parser.add_argument('--continue-from', default=None, type=str, help="model file path to make continued from") args = parser.parse_args() if args.model is None: parser.print_help() sys.exit(1) # some assertions to make sure that batching math assumptions are met assert parse_torch_version() >= (0, 2, 1), "you need pytorch 0.2.1 or later" set_logfile(Path(args.log_dir, "train.log")) logger.info(f"Training started with command: {' '.join(sys.argv)}") args_str = [f"{k}={v}" for (k, v) in vars(args).items()] logger.info(f"args: {' '.join(args_str)}") if args.use_cuda: logger.info("using cuda") torch.set_default_tensor_type("torch.cuda.FloatTensor") if args.seed is not None: torch.manual_seed(args.seed) np.random.seed(args.seed)
from torch.autograd import Variable import pyro.distributions as dist from utils.mnist_cached import MNISTCached, setup_data_loaders from pyro.infer import SVI from pyro.optim import Adam from pyro.nn import ClippedSoftmax, ClippedSigmoid from pyro.shim import parse_torch_version from utils.custom_mlp import MLP, Exp from utils.vae_plots import plot_conditional_samples_ssvae, mnist_test_tsne_ssvae from util import set_seed, print_and_log, mkdir_p import torch.nn as nn version_warning = ''' 11/02/2017: This example does not work with PyTorch 0.2, please install PyTorch 0.3. ''' torch_version = parse_torch_version() if (torch_version < (0, 2, 1) and not torch_version[-1].startswith("+")): print(version_warning) sys.exit(0) class SSVAE(nn.Module): """ This class encapsulates the parameters (neural networks) and models & guides needed to train a semi-supervised variational auto-encoder on the MNIST image dataset :param output_size: size of the tensor representing the class label (10 for MNIST since we represent the class labels as a one-hot vector with 10 components) :param input_size: size of the tensor representing the image (28*28 = 784 for our MNIST dataset since we flatten the images and scale the pixels to be in [0,1]) :param z_dim: size of the tensor representing the latent random variable z