def get_problem(dataset, K, p, lmbd, rho, batch_size, save_dir, corr=0):
    """Build a problem generator and its dictionary for ``dataset``.

    Parameters
    ----------
    dataset : str
        One of ``'artificial'``, ``'mnist'`` or ``'images'``.
    K : int
        Forwarded to the dictionary factory for ``'artificial'``
        (``create_dictionary``) and ``'mnist'`` (``create_dictionary_dl``);
        unused for ``'images'``.
    p : int
        Forwarded to ``create_dictionary`` for ``'artificial'``. For
        ``'images'`` it is replaced by ``int(sqrt(p))`` before building the
        Haar dictionary (presumably ``p`` is the pixel count of a square
        patch -- TODO confirm). Unused for ``'mnist'``.
    lmbd : float
        Forwarded to every problem generator, and to ``create_dictionary_dl``
        for ``'mnist'``.
    rho : float
        Forwarded to ``SimpleProblemGenerator``; only used for ``'artificial'``.
    batch_size : int
        Forwarded to the problem generator.
    save_dir : str
        Forwarded as ``dir_mnist``; only used for ``'mnist'``.
    corr : float, optional
        Forwarded to ``SimpleProblemGenerator``; only used for
        ``'artificial'``. Defaults to 0.

    Returns
    -------
    (pb, D)
        The problem generator and the dictionary it was built from.

    Raises
    ------
    NameError
        If ``dataset`` is not one of the supported names.
    """
    # Setup the training constant and a test set. Seeds are fixed constants,
    # presumably so that dictionaries/problems are reproducible across runs.
    if dataset == 'artificial':
        from Lcod.simple_problem_generator import SimpleProblemGenerator
        from Lcod.simple_problem_generator import create_dictionary
        D = create_dictionary(K, p, seed=290890)
        # `corr` is now an explicit parameter: it used to be an unbound name
        # here, so this branch raised NameError unless a global existed.
        pb = SimpleProblemGenerator(D, lmbd, rho=rho, batch_size=batch_size,
                                    corr=corr, seed=422742)
    elif dataset == 'mnist':
        from Lcod.mnist_problem_generator import MnistProblemGenerator
        from Lcod.mnist_problem_generator import create_dictionary_dl
        D = create_dictionary_dl(lmbd, K, N=10000, dir_mnist=save_dir)
        pb = MnistProblemGenerator(D, lmbd, batch_size=batch_size,
                                   dir_mnist=save_dir, seed=42242)
    elif dataset == 'images':
        # Local import so the block does not depend on a module-level `np`.
        import numpy as np
        from Lcod.image_problem_generator import ImageProblemGenerator
        from Lcod.image_problem_generator import create_dictionary_haar
        p = int(np.sqrt(p))
        D = create_dictionary_haar(p, wavelet='haar')
        pb = ImageProblemGenerator(D, lmbd, batch_size=batch_size,
                                   seed=1234)
    else:
        raise NameError("dataset {} not reconized by the script"
                        "".format(dataset))
    return pb, D
def get_problem(config):
    """Create the problem generator and dictionary described by ``config``.

    ``config`` is a mapping that must provide ``'data'``, ``'batch_size'``
    and ``'lmbd'``; ``'seed'`` is optional (defaults to ``None``). Extra
    keys are read per dataset:

    * ``'artificial'`` / ``'adverse'``: ``'K'``, ``'p'``, ``'rho'`` required;
      ``'seed_D'`` (default ``None``) and ``'corr'`` (default ``0``) optional.
    * ``'mnist'``: ``'K'`` and ``'save_dir'`` required.
    * ``'images'``: ``'p'`` required.

    Returns the pair ``(pb, D)``: the problem generator and its dictionary.
    Raises ``NameError`` for an unsupported ``'data'`` value.
    """
    # Keys common to every dataset.
    dataset = config['data']
    batch_size, lmbd = config['batch_size'], config['lmbd']
    seed = config.get('seed')

    if dataset in ('artificial', 'adverse'):
        # Both synthetic datasets share the generator; only the dictionary
        # factory differs.
        from Lcod.simple_problem_generator import SimpleProblemGenerator
        if dataset == 'artificial':
            from Lcod.simple_problem_generator import (
                create_dictionary as build_dictionary)
        else:
            from data_handlers.dictionaries import (
                create_adversarial_dictionary as build_dictionary)
        K, p, rho = config['K'], config['p'], config['rho']
        seed_D, corr = config.get('seed_D'), config.get('corr', 0)
        dictionary = build_dictionary(K, p, seed=seed_D)
        generator = SimpleProblemGenerator(
            dictionary, lmbd, rho=rho, batch_size=batch_size,
            corr=corr, seed=seed)
        return generator, dictionary

    if dataset == 'mnist':
        from Lcod.mnist_problem_generator import (
            MnistProblemGenerator, create_dictionary_dl)
        K, save_dir = config['K'], config['save_dir']
        dictionary = create_dictionary_dl(lmbd, K, N=10000,
                                          dir_mnist=save_dir)
        generator = MnistProblemGenerator(
            dictionary, lmbd, batch_size=batch_size,
            dir_mnist=save_dir, seed=seed)
        return generator, dictionary

    if dataset == 'images':
        from Lcod.image_problem_generator import (
            ImageProblemGenerator, create_dictionary_haar)
        p = config['p']
        dictionary = create_dictionary_haar(p)
        generator = ImageProblemGenerator(dictionary, lmbd,
                                          batch_size=batch_size, seed=seed)
        return generator, dictionary

    raise NameError("dataset {} not reconized by the script".format(dataset))
save_dir = _assert_exist('save_exp', NAME_EXP) _assert_exist(save_dir, 'ckpt') save_curve = os.path.join(save_dir, "curve_cost.npy") # Setup the training constant and a test set if dataset == 'artificial': from Lcod.simple_problem_generator import SimpleProblemGenerator from Lcod.simple_problem_generator import create_dictionary p = 64 # Dimension of the data D = create_dictionary(K, p, seed=290890) pb = SimpleProblemGenerator(D, lmbd, rho=rho, batch_size=batch_size, corr=corr, seed=422742) elif dataset == 'mnist': from Lcod.mnist_problem_generator import MnistProblemGenerator from Lcod.mnist_problem_generator import create_dictionary_dl D = create_dictionary_dl(lmbd, K, N=10000, dir_mnist=save_dir) pb = MnistProblemGenerator(D, lmbd, batch_size=batch_size, dir_mnist=save_dir, seed=42242) elif dataset == 'images': from Lcod.image_problem_generator import ImageProblemGenerator from Lcod.image_problem_generator import create_dictionary_haar p = 8 reg_scale = 1e-4 D = create_dictionary_haar(p) pb = ImageProblemGenerator(D, lmbd, batch_size=batch_size, data_dir='data/VOC', seed=1234) elif dataset == 'cifar': from Lcod.cifar_generator import CifarProblemGenerator from Lcod.cifar_generator import create_dictionary_dl p = 8 reg_scale = 1e-4