def load_snapshot(pharn, load_path):
    """
    Restore model weights from a saved torch snapshot onto a prediction
    harness, rebuilding the model architecture when the snapshot lacks
    constructor metadata.

    Args:
        pharn: harness object providing ``xpu`` (device wrapper with
            ``map_location()`` and ``to_xpu()``) and ``dataset`` (used to
            recover ``n_classes`` / ``n_channels``). The restored model is
            stored on ``pharn.model``.
        load_path (str): path to a ``torch.save``-ed snapshot dict containing
            at least ``model_class_name`` and ``model_state_dict``.

    Raises:
        NotImplementedError: if ``model_class_name`` names an architecture
            this loader does not know how to reconstruct.
    """
    print('Loading snapshot onto {}'.format(pharn.xpu))
    # map_location keeps loaded tensors on the harness's device (CPU or GPU)
    snapshot = torch.load(load_path, map_location=pharn.xpu.map_location())

    if 'model_kw' not in snapshot:
        # FIXME: we should be able to get information from the snapshot
        print('warning snapshot not saved with modelkw')
        # Fall back to the dataset to recover the constructor arguments
        n_classes = pharn.dataset.n_classes
        n_channels = pharn.dataset.n_channels

        # Infer which model this belongs to
        # FIXME: The model must be constructed with the EXACT same kwargs This
        # will be easier when onnx supports model serialization.
        if snapshot['model_class_name'] == 'UNet':
            pharn.model = models.UNet(in_channels=n_channels,
                                      n_classes=n_classes,
                                      nonlinearity='leaky_relu')
        elif snapshot['model_class_name'] == 'UNet2':
            pharn.model = unet2.UNet2(
                in_channels=n_channels, n_classes=n_classes, n_alt_classes=3,
                nonlinearity='leaky_relu'
            )
        elif snapshot['model_class_name'] == 'DenseUNet':
            pharn.model = unet3.DenseUNet(
                in_channels=n_channels, n_classes=n_classes, n_alt_classes=3,
            )
        else:
            raise NotImplementedError(snapshot['model_class_name'])

    # NOTE(review): when 'model_kw' IS present, no model is constructed here;
    # presumably pharn.model was already built elsewhere — confirm with callers.
    pharn.model = pharn.xpu.to_xpu(pharn.model)
    pharn.model.load_state_dict(snapshot['model_state_dict'])
def fit_networks(datasets, xpu):
    """
    Train each candidate architecture ('unet2' and 'dense_unet') on the same
    dataset splits and record, per arch, the training directory and the best
    epochs selected by early stopping.

    Args:
        datasets (dict): split name -> dataset; the 'train' split must expose
            ``n_classes``, ``n_channels``, ``class_weights()``,
            ``ignore_label``, ``augment``, ``colorspace``, ``input_id``,
            ``task.workdir`` and ``center_inputs.transforms``.
        xpu: device wrapper handed to the FitHarness.

    Returns:
        tuple: ``(arch_to_train_dpath, arch_to_best_epochs)`` — dicts keyed
        by architecture name.
    """
    print('datasets = {}'.format(datasets))
    n_classes = datasets['train'].n_classes
    n_channels = datasets['train'].n_channels
    class_weights = datasets['train'].class_weights()
    ignore_label = datasets['train'].ignore_label
    print('n_classes = {!r}'.format(n_classes))
    print('n_channels = {!r}'.format(n_channels))

    # Candidate architectures; each gets its own hyperparams and harness
    arches = [
        'unet2',
        'dense_unet',
    ]
    arch_to_train_dpath = {}
    arch_to_best_epochs = {}

    for arch in arches:
        hyper = hyperparams.HyperParams(
            criterion=(criterions.CrossEntropyLoss2D, {
                'ignore_label': ignore_label,
                # TODO: weight should be a FloatTensor
                'weight': class_weights,
            }),
            optimizer=(torch.optim.SGD, {
                # 'weight_decay': .0006,
                'weight_decay': .0005,
                'momentum': .9,
                'nesterov': True,
            }),
            scheduler=('Exponential', {
                'gamma': 0.99,
                'base_lr': 0.001,
                'stepsize': 2,
            }),
            other={
                'n_classes': n_classes,
                'n_channels': n_channels,
                'augment': datasets['train'].augment,
                'colorspace': datasets['train'].colorspace,
            }
        )

        train_dpath = ub.ensuredir((datasets['train'].task.workdir,
                                    'train', arch))
        # Persist enough metadata to identify / reproduce this run later
        train_info = {
            'arch': arch,
            'train_id': datasets['train'].input_id,
            'train_hyper_id': hyper.hyper_id(),
            'colorspace': datasets['train'].colorspace,
            # Hack in centering information
            'hack_centers': [
                (t.__class__.__name__, t.__getstate__())
                for t in datasets['train'].center_inputs.transforms
            ]
        }
        util.write_json(join(train_dpath, 'train_info.json'), train_info)

        arch_to_train_dpath[arch] = train_dpath

        # dense_unet uses more memory per image, hence the smaller batch
        if arch == 'unet2':
            batch_size = 14
            model = unet2.UNet2(n_alt_classes=3, in_channels=n_channels,
                                n_classes=n_classes,
                                nonlinearity='leaky_relu')
        elif arch == 'dense_unet':
            batch_size = 6
            model = unet3.DenseUNet(n_alt_classes=3, in_channels=n_channels,
                                    n_classes=n_classes)

        dry = 0
        harn = fit_harn2.FitHarness(
            model=model, hyper=hyper, datasets=datasets, xpu=xpu,
            train_dpath=train_dpath, dry=dry,
            batch_size=batch_size,
        )
        # Secondary criterion for the auxiliary output head; class 2 is
        # ignored and the background class (index 0) is down-weighted
        harn.criterion2 = criterions.CrossEntropyLoss2D(
            weight=torch.FloatTensor([0.1, 1.0, 0.0]),
            ignore_label=2
        )
        if DEBUG:
            harn.config['max_iter'] = 30
        else:
            # Note on aretha we can do 140 epochs in 7 days, so
            # be careful with how long we take to train.
            # With a reduction of 16, we can take a few more epochs
            # Unet2 take ~10 minutes to get through one
            # with num_workers=0, we have 374.00s/it = 6.23 m/it
            # this comes down to 231 epochs per day
            # harn.config['max_iter'] = 432  # 3 days max
            harn.config['max_iter'] = 200  # ~1 day max (if multiprocessing works)
            harn.early_stop.patience = 10

        def compute_loss(harn, outputs, labels):
            # Weighted combination of primary and auxiliary head losses.
            # NOTE: the `harn` parameter deliberately shadows the outer harn.
            output1, output2 = outputs
            label1, label2 = labels
            # Compute the loss
            loss1 = harn.criterion(output1, label1)
            loss2 = harn.criterion2(output2, label2)
            loss = (.45 * loss1 + .55 * loss2)
            return loss
        harn.compute_loss = compute_loss

        def custom_metrics(harn, output, label):
            # Segmentation metrics on the auxiliary (index 1) output/label
            ignore_label = datasets['train'].ignore_label
            labels = datasets['train'].task.labels
            metrics_dict = metrics._sseg_metrics(output[1], label[1],
                                                 labels=labels,
                                                 ignore_label=ignore_label)
            return metrics_dict
        harn.add_metric_hook(custom_metrics)

        harn.run()
        arch_to_best_epochs[arch] = harn.early_stop.best_epochs()

    # Select model and hyperparams
    print('arch_to_train_dpath = {}'.format(ub.repr2(arch_to_train_dpath, nl=1)))
    print('arch_to_best_epochs = {}'.format(ub.repr2(arch_to_best_epochs, nl=1)))
    return arch_to_train_dpath, arch_to_best_epochs
def urban_fit():
    """
    CLI entry point: build datasets, model, and hyperparameters from command
    line flags, then train a segmentation network for the urban_mapper_3d
    task and return the harness.

    CommandLine:
        python -m clab.live.urban_train urban_fit --profile
        python -m clab.live.urban_train urban_fit --task=urban_mapper_3d --arch=segnet
        python -m clab.live.urban_train urban_fit --task=urban_mapper_3d --arch=unet --noaux
        python -m clab.live.urban_train urban_fit --task=urban_mapper_3d --arch=unet
        python -m clab.live.urban_train urban_fit --task=urban_mapper_3d --dry
        python -m clab.live.urban_train urban_fit --task=urban_mapper_3d --arch=unet --colorspace=RGB --combine
        python -m clab.live.urban_train urban_fit --task=urban_mapper_3d --arch=unet --dry
        python -m clab.live.urban_train urban_fit --task=urban_mapper_3d --arch=unet2 --colorspace=RGB --combine
        python -m clab.live.urban_train urban_fit --task=urban_mapper_3d --arch=unet2 --colorspace=RGB --use_aux_diff
        python -m clab.live.urban_train urban_fit --task=urban_mapper_3d --arch=dense_unet --colorspace=RGB --use_aux_diff

        # Train a variant of the dense net with more parameters
        python -m clab.live.urban_train urban_fit --task=urban_mapper_3d --arch=dense_unet --colorspace=RGB --use_aux_diff --combine \
            --pretrained '/home/local/KHQ/jon.crall/data/work/urban_mapper4/arch/dense_unet/train/input_25800-phpjjsqu/solver_25800-phpjjsqu_dense_unet_mmavmuou_zeosddyf_a=1,c=RGB,n_ch=6,n_cl=4/torch_snapshots/_epoch_00000030.pt' --gpu=1

        # Fine tune the model using all the available data
        python -m clab.live.urban_train urban_fit --task=urban_mapper_3d --arch=unet2 --colorspace=RGB --use_aux_diff --combine \
            --pretrained '/home/local/KHQ/jon.crall/data/work/urban_mapper2/arch/unet2/train/input_25800-hemanvft/solver_25800-hemanvft_unet2_mmavmuou_stuyuerd_a=1,c=RGB,n_ch=6,n_cl=4/torch_snapshots/_epoch_00000041.pt' --gpu=3 --finetune

        # Keep a bit of the data for validation but use more
        python -m clab.live.urban_train urban_fit --task=urban_mapper_3d --arch=unet2 --colorspace=RGB --use_aux_diff --halfcombo \
            --pretrained '/home/local/KHQ/jon.crall/data/work/urban_mapper2/arch/unet2/train/input_25800-hemanvft/solver_25800-hemanvft_unet2_mmavmuou_stuyuerd_a=1,c=RGB,n_ch=6,n_cl=4/torch_snapshots/_epoch_00000041.pt' --gpu=3

    Example:
        >>> from clab.torch.fit_harness import *
        >>> harn = urban_fit()
    """
    arch = ub.argval('--arch', default='unet')
    colorspace = ub.argval('--colorspace', default='RGB').upper()

    datasets = load_task_dataset('urban_mapper_3d', colorspace=colorspace,
                                 arch=arch)
    datasets['train'].augment = True

    # Make sure we use consistent normalization
    # TODO: give normalization a part of the hashid
    # TODO: save normalization type with the model
    # datasets['train'].center_inputs = datasets['train']._make_normalizer()

    # if ub.argflag('--combine'):
    #     # custom centering from the initialization point I'm going to use
    #     datasets['train'].center_inputs = datasets['train']._custom_urban_mapper_normalizer(
    #         0.3750553785198646, 1.026544662398811, 2.5136079110849674)
    # else:
    #     datasets['train'].center_inputs = datasets['train']._make_normalizer(mode=2)
    datasets['train'].center_inputs = datasets['train']._make_normalizer(
        mode=3)
    # datasets['train'].center_inputs = _custom_urban_mapper_normalizer(0, 1, 2.5)

    # Share the train normalizer so all splits are centered identically
    datasets['test'].center_inputs = datasets['train'].center_inputs
    datasets['vali'].center_inputs = datasets['train'].center_inputs

    # Ensure normalization is the same for each dataset
    datasets['train'].augment = True

    # turn off aux layers
    if ub.argflag('--noaux'):
        for v in datasets.values():
            v.aux_keys = []

    # Larger models need smaller batches to fit in GPU memory
    batch_size = 14
    if arch == 'segnet':
        batch_size = 6
    elif arch == 'dense_unet':
        batch_size = 6
        # dense_unet batch memsizes
        # idle = 11 MiB
        # 0 = 438 MiB
        # 3 ~= 5000 MiB
        # 5 = 8280 MiB
        # 6 = 9758 MiB
        # each image adds (1478 - 1568.4) MiB

    n_classes = datasets['train'].n_classes
    n_channels = datasets['train'].n_channels
    class_weights = datasets['train'].class_weights()
    ignore_label = datasets['train'].ignore_label

    print('n_classes = {!r}'.format(n_classes))
    print('n_channels = {!r}'.format(n_channels))
    print('batch_size = {!r}'.format(batch_size))

    hyper = hyperparams.HyperParams(
        criterion=(criterions.CrossEntropyLoss2D, {
            'ignore_label': ignore_label,
            # TODO: weight should be a FloatTensor
            'weight': class_weights,
        }),
        optimizer=(torch.optim.SGD, {
            # 'weight_decay': .0006,
            'weight_decay': .0005,
            'momentum': 0.99 if arch == 'dense_unet' else .9,
            'nesterov': True,
        }),
        scheduler=('Exponential', {
            'gamma': 0.99,
            # 'base_lr': 0.0015,
            'base_lr': 0.001 if not ub.argflag('--halfcombo') else 0.0005,
            'stepsize': 2,
        }),
        other={
            'n_classes': n_classes,
            'n_channels': n_channels,
            'augment': datasets['train'].augment,
            'colorspace': datasets['train'].colorspace,
        })

    # Known snapshots usable as initialization points (paths are
    # machine-specific; see --pretrained to override)
    starting_points = {
        'unet_rgb_4k': ub.truepath(
            '~/remote/aretha/data/work/urban_mapper/arch/unet/train/input_4214-yxalqwdk/solver_4214-yxalqwdk_unet_vgg_nttxoagf_a=1,n_ch=5,n_cl=3/torch_snapshots/_epoch_00000236.pt'
        ),

        # 'unet_rgb_8k': ub.truepath('~/remote/aretha/data/work/urban_mapper/arch/unet/train/input_8438-haplmmpq/solver_8438-haplmmpq_unet_None_kvterjeu_a=1,c=RGB,n_ch=5,n_cl=3/torch_snapshots/_epoch_00000402.pt'),

        # "ImageCenterScale", {"im_mean": [[[0.3750553785198646]]], "im_scale": [[[1.026544662398811]]]}
        # "DTMCenterScale", "std": 2.5136079110849674, "nan_value": -32767.0

        # 'unet_rgb_8k': ub.truepath(
        #     '~/data/work/urban_mapper2/arch/unet/train/input_4214-guwsobde/'
        #     'solver_4214-guwsobde_unet_mmavmuou_eqnoygqy_a=1,c=RGB,n_ch=5,n_cl=4/torch_snapshots/_epoch_00000189.pt'
        # )

        'unet_rgb_8k': ub.truepath(
            '~/remote/aretha/data/work/urban_mapper2/arch/unet2/train/input_4214-guwsobde/'
            'solver_4214-guwsobde_unet2_mmavmuou_tqynysqo_a=1,c=RGB,n_ch=5,n_cl=4/torch_snapshots/_epoch_00000100.pt'
        )
    }

    # Resolve the pretrained weights: explicit --pretrained wins, otherwise
    # pick a default based on the architecture and the --combine flag
    pretrained = ub.argval('--pretrained', default=None)
    if pretrained is None:
        if arch == 'segnet':
            pretrained = 'vgg'
        else:
            pretrained = None
            if ub.argflag('--combine'):
                pretrained = starting_points['unet_rgb_8k']
                # NOTE(review): 'unet2' and 'dense_unet2' point at the same
                # unet2 snapshot here — looks intentional (partial-state
                # loading), but confirm.
                if arch == 'unet2':
                    pretrained = '/home/local/KHQ/jon.crall/data/work/urban_mapper2/arch/unet2/train/input_25800-hemanvft/solver_25800-hemanvft_unet2_mmavmuou_stuyuerd_a=1,c=RGB,n_ch=6,n_cl=4/torch_snapshots/_epoch_00000042.pt'
                elif arch == 'dense_unet2':
                    pretrained = '/home/local/KHQ/jon.crall/data/work/urban_mapper2/arch/unet2/train/input_25800-hemanvft/solver_25800-hemanvft_unet2_mmavmuou_stuyuerd_a=1,c=RGB,n_ch=6,n_cl=4/torch_snapshots/_epoch_00000042.pt'
            else:
                pretrained = starting_points['unet_rgb_4k']

    train_dpath = directory_structure(
        datasets['train'].task.workdir, arch, datasets,
        pretrained=pretrained,
        train_hyper_id=hyper.hyper_id(),
        suffix='_' + hyper.other_id())

    print('arch = {!r}'.format(arch))

    dry = ub.argflag('--dry')
    # Build the model for the requested arch; most branches also warm-start
    # from a pretrained snapshot via load_partial_state
    if dry:
        model = None
    elif arch == 'segnet':
        model = models.SegNet(in_channels=n_channels, n_classes=n_classes)
        model.init_he_normal()
        if pretrained == 'vgg':
            model.init_vgg16_params()
    elif arch == 'linknet':
        model = models.LinkNet(in_channels=n_channels, n_classes=n_classes)
    elif arch == 'unet':
        model = models.UNet(in_channels=n_channels, n_classes=n_classes,
                            nonlinearity='leaky_relu')
        snapshot = xpu_device.XPU(None).load(pretrained)
        model_state_dict = snapshot['model_state_dict']
        model.load_partial_state(model_state_dict)
        # model.shock_outward()
    elif arch == 'unet2':
        from clab.live import unet2
        model = unet2.UNet2(n_alt_classes=3, in_channels=n_channels,
                            n_classes=n_classes, nonlinearity='leaky_relu')
        snapshot = xpu_device.XPU(None).load(pretrained)
        model_state_dict = snapshot['model_state_dict']
        model.load_partial_state(model_state_dict)
    elif arch == 'dense_unet':
        from clab.live import unet3
        model = unet3.DenseUNet(n_alt_classes=3, in_channels=n_channels,
                                n_classes=n_classes)
        model.init_he_normal()
        snapshot = xpu_device.XPU(None).load(pretrained)
        model_state_dict = snapshot['model_state_dict']
        model.load_partial_state(model_state_dict)
    elif arch == 'dense_unet2':
        from clab.live import unet3
        model = unet3.DenseUNet2(n_alt_classes=3, in_channels=n_channels,
                                 n_classes=n_classes)
        # model.init_he_normal()
        snapshot = xpu_device.XPU(None).load(pretrained)
        model_state_dict = snapshot['model_state_dict']
        model.load_partial_state(model_state_dict, shock_partial=False)
    elif arch == 'dummy':
        model = models.SSegDummy(in_channels=n_channels, n_classes=n_classes)
    else:
        raise ValueError('unknown arch')

    if ub.argflag('--finetune'):
        # Hack in a reduced learning rate
        hyper = hyperparams.HyperParams(
            criterion=(criterions.CrossEntropyLoss2D, {
                'ignore_label': ignore_label,
                # TODO: weight should be a FloatTensor
                'weight': class_weights,
            }),
            optimizer=(torch.optim.SGD, {
                # 'weight_decay': .0006,
                'weight_decay': .0005,
                'momentum': 0.99 if arch == 'dense_unet' else .9,
                'nesterov': True,
            }),
            scheduler=('Constant', {
                'base_lr': 0.0001,
            }),
            other={
                'n_classes': n_classes,
                'n_channels': n_channels,
                'augment': datasets['train'].augment,
                'colorspace': datasets['train'].colorspace,
            })

    xpu = xpu_device.XPU.from_argv()

    if datasets['train'].use_aux_diff:
        # arch in ['unet2', 'dense_unet']:
        # Two-headed models need the fit_harn2 harness with a combined loss
        from clab.live import fit_harn2
        harn = fit_harn2.FitHarness(
            model=model, hyper=hyper, datasets=datasets, xpu=xpu,
            train_dpath=train_dpath, dry=dry,
            batch_size=batch_size,
        )
        harn.criterion2 = criterions.CrossEntropyLoss2D(
            weight=torch.FloatTensor([.1, 1, 0]),
            ignore_label=2)

        def compute_loss(harn, outputs, labels):
            # Weighted combination of primary and auxiliary head losses
            output1, output2 = outputs
            label1, label2 = labels
            # Compute the loss
            loss1 = harn.criterion(output1, label1)
            loss2 = harn.criterion2(output2, label2)
            loss = (.45 * loss1 + .55 * loss2)
            return loss
        harn.compute_loss = compute_loss

        # z = harn.loaders['train']
        # b = next(iter(z))
        # print('b = {!r}'.format(b))
        # import sys
        # sys.exit(0)

        def custom_metrics(harn, output, label):
            # Metrics on the auxiliary (index 1) output/label pair
            ignore_label = datasets['train'].ignore_label
            labels = datasets['train'].task.labels
            metrics_dict = metrics._sseg_metrics(output[1], label[1],
                                                 labels=labels,
                                                 ignore_label=ignore_label)
            return metrics_dict
    else:
        # Single-headed models use the plain fit_harness
        harn = fit_harness.FitHarness(
            model=model, hyper=hyper, datasets=datasets, xpu=xpu,
            train_dpath=train_dpath, dry=dry,
            batch_size=batch_size,
        )

        def custom_metrics(harn, output, label):
            ignore_label = datasets['train'].ignore_label
            labels = datasets['train'].task.labels
            metrics_dict = metrics._sseg_metrics(output, label, labels=labels,
                                                 ignore_label=ignore_label)
            return metrics_dict

    # Whichever branch ran, hook in the matching metrics closure
    harn.add_metric_hook(custom_metrics)

    # HACK
    # im = datasets['train'][0][0]
    # w, h = im.shape[-2:]
    # single_output_shape = (n_classes, w, h)
    # harn.single_output_shape = single_output_shape
    # print('harn.single_output_shape = {!r}'.format(harn.single_output_shape))

    harn.run()
    return harn