예제 #1
0
def get_problem(dataset, K, p, lmbd, rho, batch_size, save_dir, corr=0):
    """Build the problem generator and its dictionary for ``dataset``.

    Parameters
    ----------
    dataset : str
        Name of the problem to set up: 'artificial', 'mnist' or 'images'.
    K :
        Forwarded to the dictionary constructor ('artificial' and 'mnist').
    p :
        Forwarded to ``create_dictionary`` for 'artificial'. For 'images',
        ``int(np.sqrt(p))`` is forwarded to ``create_dictionary_haar``.
    lmbd :
        Forwarded to every problem generator, and to
        ``create_dictionary_dl`` for 'mnist'.
    rho :
        Forwarded to ``SimpleProblemGenerator`` ('artificial' only).
    batch_size :
        Forwarded to every problem generator.
    save_dir :
        Passed as ``dir_mnist`` to the MNIST helpers ('mnist' only).
    corr : optional
        Forwarded to ``SimpleProblemGenerator`` ('artificial' only).
        Defaults to 0, the same default the config-based variant of this
        function uses.

    Returns
    -------
    tuple
        ``(pb, D)``: the problem generator and the dictionary it was
        built with.

    Raises
    ------
    NameError
        If ``dataset`` is not one of the supported names.
    """
    # The project modules are imported lazily, so only the dependencies of
    # the selected dataset have to be importable.
    if dataset == 'artificial':
        from Lcod.simple_problem_generator import SimpleProblemGenerator
        from Lcod.simple_problem_generator import create_dictionary
        D = create_dictionary(K, p, seed=290890)
        # `corr` is now an explicit argument; it used to be read from a
        # name that was not defined anywhere in this function.
        pb = SimpleProblemGenerator(D, lmbd, rho=rho, batch_size=batch_size,
                                    corr=corr, seed=422742)
    elif dataset == 'mnist':
        from Lcod.mnist_problem_generator import MnistProblemGenerator
        from Lcod.mnist_problem_generator import create_dictionary_dl
        D = create_dictionary_dl(lmbd, K, N=10000, dir_mnist=save_dir)
        pb = MnistProblemGenerator(D, lmbd, batch_size=batch_size,
                                   dir_mnist=save_dir, seed=42242)
    elif dataset == 'images':
        from Lcod.image_problem_generator import ImageProblemGenerator
        from Lcod.image_problem_generator import create_dictionary_haar
        # The Haar dictionary helper receives the integer square root of p.
        p = int(np.sqrt(p))
        D = create_dictionary_haar(p, wavelet='haar')
        pb = ImageProblemGenerator(D, lmbd, batch_size=batch_size,
                                   seed=1234)
    else:
        # NameError (and the exact message) is kept so that existing
        # callers catching it keep working.
        raise NameError("dataset {} not reconized by the script"
                        "".format(dataset))
    return pb, D
def get_problem(config):
    """Create the problem generator and dictionary described by ``config``.

    ``config`` is indexed with ``[]`` for required entries and queried with
    ``.get`` for optional ones.

    Entries read for every dataset:
        'data' (required): 'artificial', 'adverse', 'mnist' or 'images'.
        'batch_size', 'lmbd' (required): forwarded to the generator.
        'seed' (optional, None if absent): forwarded to the generator.

    Dataset-specific entries:
        'artificial' / 'adverse': 'K', 'p', 'rho' (required), 'seed_D'
            (optional, None if absent) and 'corr' (optional, 0 if absent).
        'mnist': 'K' and 'save_dir' (required).
        'images': 'p' (required).

    Returns the tuple ``(pb, D)``: the problem generator and the dictionary
    it was built with.

    Raises NameError when 'data' is not one of the supported names.
    """
    # Entries common to all problems
    dataset = config['data']
    batch_size = config['batch_size']
    lmbd = config['lmbd']
    seed = config.get('seed')

    if dataset in ('artificial', 'adverse'):
        # Both synthetic problems share the generator and differ only in
        # the function that builds the dictionary.
        from Lcod.simple_problem_generator import SimpleProblemGenerator
        if dataset == 'artificial':
            from Lcod.simple_problem_generator import (
                create_dictionary as build_dictionary)
        else:
            from data_handlers.dictionaries import (
                create_adversarial_dictionary as build_dictionary)

        K = config['K']
        p = config['p']
        rho = config['rho']
        seed_D = config.get('seed_D')
        corr = config.get('corr', 0)

        D = build_dictionary(K, p, seed=seed_D)
        pb = SimpleProblemGenerator(
            D, lmbd, rho=rho, batch_size=batch_size, corr=corr, seed=seed)
    elif dataset == 'mnist':
        from Lcod.mnist_problem_generator import MnistProblemGenerator
        from Lcod.mnist_problem_generator import create_dictionary_dl

        K = config['K']
        save_dir = config['save_dir']
        D = create_dictionary_dl(lmbd, K, N=10000, dir_mnist=save_dir)
        pb = MnistProblemGenerator(
            D, lmbd, batch_size=batch_size, dir_mnist=save_dir, seed=seed)
    elif dataset == 'images':
        from Lcod.image_problem_generator import ImageProblemGenerator
        from Lcod.image_problem_generator import create_dictionary_haar

        D = create_dictionary_haar(config['p'])
        pb = ImageProblemGenerator(
            D, lmbd, batch_size=batch_size, seed=seed)
    else:
        message = "dataset {} not reconized by the script".format(dataset)
        raise NameError(message)

    return pb, D
예제 #3
0
def get_problem(config):
    """Set up the problem generator and dictionary selected by ``config``.

    Required entries: 'data', 'batch_size' and 'lmbd'. The optional 'seed'
    entry (None when absent) is forwarded to the problem generator.

    Depending on ``config['data']`` further entries are read:

    - 'artificial' and 'adverse': 'K', 'p', 'rho', plus the optional
      'seed_D' (None when absent) and 'corr' (0 when absent).
    - 'mnist': 'K' and 'save_dir'.
    - 'images': 'p'.

    Returns
    -------
    tuple
        ``(pb, D)``: the problem generator and the dictionary used to
        build it.

    Raises
    ------
    NameError
        If ``config['data']`` is not a supported dataset name.
    """
    dataset = config['data']
    batch_size = config['batch_size']
    lmbd = config['lmbd']

    # Keyword arguments that every problem generator receives
    shared = dict(batch_size=batch_size, seed=config.get('seed'))

    if dataset == 'artificial':
        from Lcod.simple_problem_generator import SimpleProblemGenerator
        from Lcod.simple_problem_generator import create_dictionary

        K, p, rho = config['K'], config['p'], config['rho']
        seed_D = config.get('seed_D')
        corr = config.get('corr', 0)
        D = create_dictionary(K, p, seed=seed_D)
        return SimpleProblemGenerator(D, lmbd, rho=rho, corr=corr,
                                      **shared), D

    if dataset == 'adverse':
        from Lcod.simple_problem_generator import SimpleProblemGenerator
        from data_handlers.dictionaries import create_adversarial_dictionary

        K, p, rho = config['K'], config['p'], config['rho']
        seed_D = config.get('seed_D')
        corr = config.get('corr', 0)
        D = create_adversarial_dictionary(K, p, seed=seed_D)
        return SimpleProblemGenerator(D, lmbd, rho=rho, corr=corr,
                                      **shared), D

    if dataset == 'mnist':
        from Lcod.mnist_problem_generator import MnistProblemGenerator
        from Lcod.mnist_problem_generator import create_dictionary_dl

        K, save_dir = config['K'], config['save_dir']
        D = create_dictionary_dl(lmbd, K, N=10000, dir_mnist=save_dir)
        return MnistProblemGenerator(D, lmbd, dir_mnist=save_dir,
                                     **shared), D

    if dataset == 'images':
        from Lcod.image_problem_generator import ImageProblemGenerator
        from Lcod.image_problem_generator import create_dictionary_haar

        D = create_dictionary_haar(config['p'])
        return ImageProblemGenerator(D, lmbd, **shared), D

    # None of the supported dataset names matched
    raise NameError(
        "dataset {} not reconized by the script".format(dataset))
예제 #4
0
     p = 64                 # Dimension of the data
     D = create_dictionary(K, p, seed=290890)
     pb = SimpleProblemGenerator(D, lmbd, rho=rho, batch_size=batch_size,
                                 corr=corr, seed=422742)
 elif dataset == 'mnist':
     from Lcod.mnist_problem_generator import MnistProblemGenerator
     from Lcod.mnist_problem_generator import create_dictionary_dl
     D = create_dictionary_dl(lmbd, K, N=10000, dir_mnist=save_dir)
     pb = MnistProblemGenerator(D, lmbd, batch_size=batch_size,
                                dir_mnist=save_dir, seed=42242)
 elif dataset == 'images':
     from Lcod.image_problem_generator import ImageProblemGenerator
     from Lcod.image_problem_generator import create_dictionary_haar
     p = 8
     reg_scale = 1e-4
     D = create_dictionary_haar(p)
     pb = ImageProblemGenerator(D, lmbd, batch_size=batch_size,
                                data_dir='data/VOC', seed=1234)
 elif dataset == 'cifar':
     from Lcod.cifar_generator import CifarProblemGenerator
     from Lcod.cifar_generator import create_dictionary_dl
     p = 8
     reg_scale = 1e-4
     D = create_dictionary_dl(lmbd,K)
     # D = np.random.random([1024,3072])
     pb = CifarProblemGenerator(D, lmbd, batch_size=batch_size,
                                 seed=1234)
 elif dataset == 'fruit':
     from Lcod.fruit_generator import FruitProblemGenerator
     from Lcod.fruit_generator import create_dictionary_dl
     p = 8