예제 #1
0
파일: config.py 프로젝트: jydennis/thexp
    under the GNU General Public License as published by the Free 
    Software Foundation, either version 3 of the License, or (at your option)
    any later version, if this derivative work is distributed to a third party.

    The copyright for the program is owned by Shandong University. 
    For commercial projects that require the ability to distribute 
    the code of this program as part of a program that cannot be 
    distributed under the GNU General Public License, please contact
    [email protected] to purchase a commercial license.
"""
from thexp import Params

# Training configuration for the first example.
# Plain defaults are assigned first, bind rules are registered before the
# dataset is chosen, and command-line overrides are parsed at the very end.
params = Params()
params.device = 'cuda:1'
params.epoch = 5
params.batch_size = 128
params.topk = (1, 4)
params.root = '/home/share/yanghaozhe/pytorchdataset'
# NOTE(review): this loader batch_size (32) differs from params.batch_size
# (128) above; confirm which one the training loop actually reads.
params.dataloader = dict(shuffle=True, batch_size=32, drop_last=True)
params.optim = dict(lr=0.01, weight_decay=0.09, momentum=0.9)
params.choice('dataset', 'mnist', 'fmnist')
params.bind('dataset', 'mnist', 'arch', 'simple')
params.bind('dataset', 'fmnist', 'arch', 'simple')
# NOTE(review): 'cifar10' is not among the choices declared above, so this
# rule presumably never fires unless the choice list is extended — confirm.
params.bind('dataset', 'cifar10', 'arch', 'cnn13')
# Choose the dataset only after the bind rules exist (the same order used in
# the second example), so the matching 'arch' rule has a chance to apply.
params.dataset = 'mnist'
params.ema = True
# Parse command-line overrides last: from_args() presumably reads sys.argv,
# and running it here means no later default assignment in this block can
# overwrite a value supplied on the command line.
params.from_args()
예제 #2
0

class MyParams(Params):
    """Params subclass that presets batch size, top-k and optimizer settings."""

    def __init__(self):
        super().__init__()
        self.batch_size = 50
        self.topk = (1, 2, 3, 4)
        # 'momentum' (not 'moment') matches the optim dict of the first example.
        # Presumably this dict is unpacked into an optimizer such as
        # torch.optim.SGD, whose keyword is `momentum`.
        self.optim = dict(lr=0.009, momentum=0.9)


# Print the defaults preset by the MyParams subclass.
params = MyParams()
print(params)

from thexp import Params

# A fresh, plain Params object: 'dataset' is declared with four choices and
# 'thresh' is declared through arange with the arguments (5, 0, 20).
params = Params()
params.choice("dataset", "mnist", "cifar10", "cifar100", "svhn")
params.arange("thresh", 5, 0, 20)
print(params)

# Disabled example: nested grid search over 'thresh' and 'dataset'.
# for outer in params.grid_search("thresh", range(0, 20)):
#     for inner in outer.grid_search("dataset", ['cifar10', 'cifar100', 'svhn']):
#         print(inner.dataset, inner.thresh)

# Each tuple is passed positionally to params.bind, in this order.
_bind_rules = (
    ('dataset', 'mnist', 'arch', 'simplenet'),
    ('dataset', 'cifar10', 'arch', 'cnn13'),
    ('arch', 'simplenet', 'arch_param', dict(feature=128)),
    ('arch', 'cnn13', 'arch_param', dict(feature=256)),
)
for _rule in _bind_rules:
    params.bind(*_rule)
del _bind_rules, _rule

# Switch the dataset and show the values that the bind rules presumably
# derived from it ('arch', then 'arch_param').
params.dataset = 'cifar10'
for _key in ('arch', 'arch_param'):
    print(getattr(params, _key))
del _key
params.dataset = 'mnist'