示例#1
0

# Pick the hyper-parameter grid to sweep from --params.
# NOTE(review): 'few_betas' / 'small_beta' select the very same grid as
# 'beta'; confirm whether a reduced beta grid was intended here.
if args.params in ('beta', 'betas', 'b', 'few_betas', 'small_beta'):
    params_used = beta_params
elif args.params in ('mi', 'mis', 'm', 'constraint', 'constraints'):
    params_used = mi_params

# --dmax replaces the swept d_max values for layer 5 and turns off
# vary_together.
# NOTE(review): params_used is the same dict object as beta_params/mi_params,
# so this assignment also mutates that grid in place.
if args.dmax is not None:
    params_used["layers.5.layer_kwargs.d_max"] = list(args.dmax)
    vary_together = False

# Instantiate the dataset named by --dataset. An unrecognised name fails fast
# here rather than leaving `d` unbound and crashing later with a NameError.
if args.dataset == 'fmnist':
    d = dataset.fMNIST()
elif args.dataset == 'binary_mnist':
    d = dataset.MNIST(binary=True)
elif args.dataset == 'mnist':
    d = dataset.MNIST()
elif args.dataset == 'omniglot':
    d = dataset.Omniglot()
elif args.dataset == 'dsprites':
    d = dataset.DSprites()
else:
    raise ValueError('Unknown --dataset %r' % (args.dataset,))

# Optionally shrink the supervised split; --per_label is converted to int
# before being passed on.
# NOTE(review): presumably this keeps `per_label` labelled examples per
# class — confirm against dataset.shrink_supervised.
if args.per_label is not None:
    d.shrink_supervised(int(args.per_label))


# The run name matters: it determines the file-save location.
if args.time is not None:
示例#2
0
文件: run.py 项目: brekelma/echo
def _str2bool(value):
    """Parse a boolean command-line value.

    ``type=bool`` is a trap with argparse: ``bool('0')`` and ``bool('False')``
    are both True, so a flag could never be switched off from the command
    line. This accepts the usual spellings instead. An empty string still
    maps to False, as it did with ``bool``.

    Raises:
        ValueError: for an unrecognised spelling; argparse turns this into a
            usage error.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if text in ('0', 'false', 'f', 'no', 'n', 'off', ''):
        return False
    raise ValueError('expected a boolean value, got %r' % (value,))


parser.add_argument('--beta', type=float)
# Boolean switches; defaults keep the previous truth values (True == 1, False == 0).
parser.add_argument('--validate', type=_str2bool, default=True)
parser.add_argument('--verbose', type=_str2bool, default=False)
parser.add_argument('--fit_gen', type=_str2bool, default=True)
parser.add_argument('--fit_tf', type=_str2bool, default=False)
# Consumed as an integer count by shrink_supervised.
parser.add_argument('--per_label', type=int)
parser.add_argument('--dataset', type=str, default='binary_mnist')
# parse_known_args tolerates extra CLI arguments instead of erroring out.
args, _ = parser.parse_known_args()

# --config is either a path containing ".json" (kept as the raw string) or an
# inline JSON object written with single quotes, which are swapped for double
# quotes before json.loads.
# NOTE(review): the quote swap also rewrites apostrophes inside values, and
# ".json" is matched anywhere in the argument, not only as a suffix.
config = (
    args.config
    if ".json" in args.config
    else json.loads(args.config.replace("'", '"'))
)

# Instantiate the dataset named by --dataset. An unrecognised name fails fast
# here rather than leaving `d` unbound and crashing later with a NameError.
if args.dataset == 'fmnist':
    d = dataset.fMNIST()
elif args.dataset == 'binary_fmnist':
    d = dataset.fMNIST(binary=True)
elif args.dataset == 'binary_mnist':
    d = dataset.MNIST(binary=True)
elif args.dataset == 'mnist':
    d = dataset.MNIST()
elif args.dataset in ('omniglot', 'omni'):
    d = dataset.Omniglot()
elif args.dataset == 'dsprites':
    d = dataset.DSprites()
elif args.dataset in ('cifar10', 'cifar'):
    d = dataset.Cifar10()
else:
    raise ValueError('Unknown --dataset %r' % (args.dataset,))

# Optionally shrink the supervised split. argparse hands --per_label over as a
# string unless a `type=` is declared, so coerce it to int before use.
# NOTE(review): presumably this keeps `per_label` labelled examples per
# class — confirm against dataset.shrink_supervised.
if args.per_label is not None:
    d.shrink_supervised(int(args.per_label))