def train():
    """
    Train a DenseNet classifier on CIFAR from the command line.

    The input colorspace / normalization mode is selected via CLI flags
    (``--lab``, ``--rgb``, or ``--rgb-dep``); other knobs such as
    ``--drop_rate`` and ``--weight_decay`` are read with ``ub.argval``.
    Builds the hyperparameter spec, the train/vali/test data loaders, and a
    ``FitHarness``, then runs the fit loop to completion.

    Raises:
        AssertionError: if none of the colorspace flags is given.

    Example:
        >>> train()
    """
    import random
    # Fixed seeds for reproducibility (values reduced into the valid
    # 32-bit range accepted by the seeding APIs).
    np.random.seed(1031726816 % 4294967295)
    torch.manual_seed(137852547 % 4294967295)
    random.seed(2497950049 % 4294967295)

    xpu = xpu_device.XPU.from_argv()
    print('Chosen xpu = {!r}'.format(xpu))

    cifar_num = 10

    # Colorspace / normalization mode is mutually exclusive and required.
    if ub.argflag('--lab'):
        datasets = cifar_training_datasets(
            output_colorspace='LAB', norm_mode='independent',
            cifar_num=cifar_num)
    elif ub.argflag('--rgb'):
        datasets = cifar_training_datasets(
            output_colorspace='RGB', norm_mode='independent',
            cifar_num=cifar_num)
    elif ub.argflag('--rgb-dep'):
        datasets = cifar_training_datasets(
            output_colorspace='RGB', norm_mode='dependant',
            cifar_num=cifar_num)
    else:
        # BUGFIX: the old message omitted the accepted --rgb-dep flag
        raise AssertionError('specify --rgb / --lab / --rgb-dep')

    import netharn.models.densenet

    batch_size = 64
    lr = 0.1
    initializer_ = (initializers.LSUV, {})

    hyper = hyperparams.HyperParams(
        workdir=ub.ensuredir('train_cifar_work'),
        model=(
            netharn.models.densenet.DenseNet, {
                'cifar': True,
                'block_config': (32, 32, 32),  # 100 layer depth
                'num_classes': datasets['train'].n_classes,
                'drop_rate': float(ub.argval('--drop_rate', default=.2)),
                'groups': 1,
            }),
        optimizer=(
            torch.optim.SGD, {
                'weight_decay': float(
                    ub.argval('--weight_decay', default=.0005)),
                'momentum': 0.9,
                'nesterov': True,
                # BUGFIX: was hard-coded 0.1; use the shared `lr` so the
                # optimizer stays in sync with the scheduler points below.
                'lr': lr,
            }),
        scheduler=(nh.schedulers.ListedLR, {
            # Classic step schedule: drop the rate 10x at epochs 150 / 250.
            'points': {
                0: lr,
                150: lr * 0.1,
                250: lr * 0.01,
            },
            'interpolate': False
        }),
        monitor=(nh.Monitor, {
            'minimize': ['loss'],
            'maximize': ['mAP'],
            'patience': 314,
            'max_epoch': 314,
        }),
        initializer=initializer_,
        criterion=(torch.nn.CrossEntropyLoss, {}),
        # Record the augmentation pipeline as part of the hyperparam id.
        augment=str(datasets['train'].augmenter),
        # Anything else that distinguishes this run; the center/normalizer
        # state is folded in so it participates in the input id.
        other=ub.dict_union(
            {
                'batch_size': batch_size,
                'colorspace': datasets['train'].output_colorspace,
                'n_classes': datasets['train'].n_classes,
            },
            datasets['train'].center_inputs.__dict__),
    )

    hyper.input_ids['train'] = datasets['train'].input_id

    xpu = xpu_device.XPU.cast('auto')
    print('xpu = {}'.format(xpu))

    data_kw = {'batch_size': batch_size}
    if xpu.is_gpu():
        data_kw.update({'num_workers': 8, 'pin_memory': True})

    # Build one loader per split; only the training split is shuffled, and
    # the eval splits use a smaller batch size.
    tags = ['train', 'vali', 'test']
    loaders = ub.odict()
    for tag in tags:
        dset = datasets[tag]
        shuffle = tag == 'train'
        data_kw_ = data_kw.copy()
        if tag != 'train':
            data_kw_['batch_size'] = max(batch_size // 4, 1)
        loader = torch.utils.data.DataLoader(dset, shuffle=shuffle,
                                             **data_kw_)
        loaders[tag] = loader

    harn = fit_harness.FitHarness(
        hyper=hyper, datasets=datasets, xpu=xpu,
        loaders=loaders,
    )
    harn.monitor = monitor.Monitor(min_keys=['loss'],
                                   max_keys=['global_acc', 'class_acc'],
                                   patience=40)
    harn.initialize()
    harn.run()
def __init__(harn, hyper=None, train_dpath=None):
    # Accept either a HyperParams instance or a plain dict spec.
    hyper = hyperparams.HyperParams(**hyper) if isinstance(hyper, dict) else hyper
    harn.hyper = hyper
    harn.train_dpath = train_dpath

    # Components and state that are populated later (by initialize / setup);
    # start them all out as None.
    for _attr in ('main_prog', 'datasets', 'loaders', 'model', 'optimizer',
                  'scheduler', 'monitor', 'criterion', 'paths', 'nice_dpath',
                  'flog', 'tlog', 'dry', 'current_tag'):
        setattr(harn, _attr, None)

    harn.dynamics = {'batch_step': 1}
    harn._initialized = False

    # Track current epoch number
    harn.epoch = 0

    # Track current iteration within an epoch
    harn.bxs = {
        'train': 0,
        'vali': 0,
        'test': 0,
    }

    harn.intervals = {
        'display_train': 1,
        'display_vali': 1,
        'display_test': 1,

        'log_iter_train': None,
        'log_iter_test': None,
        'log_iter_vali': None,

        'vali': 1,
        'test': 1,

        # how often to take a snapshot
        'snapshot': 1,

        # how often to remove old snapshots
        'cleanup': 10,
    }

    harn.config = {
        'show_prog': True,
        'use_tqdm': None,
        'prog_backend': 'tqdm',

        # A loss that would be considered large
        'large_loss': 100,

        # number of recent / best snapshots to keep
        'num_keep': 10,
        'keep_freq': 10,
    }