def map_location(storage: torch.Storage, location) -> torch.Storage:
    """
    Args:
        storage (torch.Storage): the initial deserialization of the storage
            of the data read by `torch.load`, residing on the CPU.
        location (str): tag identifying the location the data being read by
            `torch.load` was originally saved from.

    Returns:
        torch.Storage: the storage, moved to the target device if needed
    """
    # `gpu_num` comes from the enclosing scope (the device requested by the
    # caller); move the storage onto that GPU if one was requested.
    xpu = xpu_device.XPU(gpu_num)
    if xpu.is_gpu():
        return storage.cuda(xpu.num)
    else:
        return storage
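def _load_snapshot_sketch(snapshot_path, gpu_num=None):
    """
    Usage sketch (not part of the original module; the helper name and
    ``snapshot_path`` argument are illustrative).  A callable with the
    ``(storage, location)`` signature above can be handed to ``torch.load``
    as its ``map_location`` argument; torch calls it once per deserialized
    storage, which lets us remap CPU storages onto the requested device.
    """
    def map_location(storage, location):
        # move each storage to the GPU identified by gpu_num, if any
        xpu = xpu_device.XPU(gpu_num)
        if xpu.is_gpu():
            return storage.cuda(xpu.num)
        return storage

    snapshot = torch.load(snapshot_path, map_location=map_location)
    return snapshot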
def __init__(harn, model, datasets, batch_size=4,
             criterion_cls='cross_entropy', hyper=None, xpu=None,
             train_dpath=None, dry=False):
    harn.dry = dry
    if harn.dry:
        train_dpath = ub.ensure_app_cache_dir('clab/dry')
        ub.delete(train_dpath)
        train_dpath = ub.ensure_app_cache_dir('clab/dry')
    harn.dry = dry
    harn.train_dpath = train_dpath
    harn.snapshot_dpath = ub.ensuredir((harn.train_dpath, 'torch_snapshots'))

    if harn.dry:
        harn.xpu = xpu_device.XPU(None)
    else:
        harn.xpu = xpu_device.XPU(xpu)

    data_kw = {'batch_size': batch_size}
    if harn.xpu.is_gpu():
        data_kw.update({'num_workers': 6, 'pin_memory': True})

    harn.loaders = {}
    harn.datasets = datasets
    for tag, dset in datasets.items():
        assert tag in {'train', 'vali', 'test'}
        shuffle = tag == 'train'
        data_kw_ = data_kw.copy()
        if tag != 'train':
            data_kw_['batch_size'] = max(batch_size // 4, 1)
        loader = torch.utils.data.DataLoader(dset, shuffle=shuffle,
                                             **data_kw_)
        harn.loaders[tag] = loader

    harn.model = model
    harn.hyper = hyper

    harn.lr_scheduler = hyper.scheduler_cls(**hyper.scheduler_params)
    harn.criterion_cls = hyper.criterion_cls
    harn.optimizer_cls = hyper.optimizer_cls
    harn.criterion_params = hyper.criterion_params
    harn.optimizer_params = hyper.optimizer_params

    harn._metric_hooks = []
    harn._run_metrics = None

    harn._epoch_callbacks = []
    harn._iter_callbacks = []

    harn.intervals = {
        'display_train': 1,
        'display_vali': 1,
        'display_test': 1,
        'vali': 1,
        'test': 1,
        'snapshot': 1,
    }
    harn.config = {
        'max_iter': 1000,
    }
    harn.epoch = 0
def __init__(harn, model, datasets, batch_size=4,
             criterion_cls='cross_entropy', hyper=None, xpu=None,
             train_dpath=None, dry=False):
    harn.dry = dry
    if harn.dry:
        train_dpath = ub.ensure_app_cache_dir('clab/dry')
        ub.delete(train_dpath)
        train_dpath = ub.ensure_app_cache_dir('clab/dry')
    harn.dry = dry
    harn.train_dpath = train_dpath
    harn.snapshot_dpath = ub.ensuredir((harn.train_dpath, 'torch_snapshots'))

    if harn.dry:
        harn.xpu = xpu_device.XPU(None)
    else:
        harn.xpu = xpu_device.XPU(xpu)

    # Allow for command line override
    batch_size = int(ub.argval('--batch_size', default=batch_size))

    data_kw = {'batch_size': batch_size}
    if harn.xpu.is_gpu():
        num_workers = int(ub.argval('--num_workers', default=6))
        num_workers = 0 if ub.argflag('--serial') else num_workers
        pin_memory = False if ub.argflag('--nopin') else True
        data_kw.update({
            'num_workers': num_workers,
            'pin_memory': pin_memory,
        })
        # data_kw.update({'num_workers': 0, 'pin_memory': False})

    harn.loaders = ub.odict()
    harn.datasets = datasets
    assert set(harn.datasets.keys()).issubset({'train', 'vali', 'test'})
    for tag in ['train', 'vali', 'test']:
        dset = harn.datasets.get(tag, None)
        if dset:
            shuffle = tag == 'train'
            data_kw_ = data_kw.copy()
            if tag != 'train':
                tag_batch_size = max(batch_size // 4, 1)
                tag_batch_size = int(
                    ub.argval('--{}-batch_size'.format(tag),
                              default=tag_batch_size))
                data_kw_['batch_size'] = tag_batch_size
            loader = torch.utils.data.DataLoader(dset, shuffle=shuffle,
                                                 **data_kw_)
            harn.loaders[tag] = loader

    harn.model = model
    harn.hyper = hyper

    harn.lr_scheduler = hyper.scheduler_cls(**hyper.scheduler_params)
    harn.criterion_cls = hyper.criterion_cls
    harn.optimizer_cls = hyper.optimizer_cls
    harn.criterion_params = hyper.criterion_params
    harn.optimizer_params = hyper.optimizer_params

    harn._metric_hooks = []
    harn._run_metrics = None

    harn._epoch_callbacks = []
    harn._iter_callbacks = []

    harn.early_stop = EarlyStop()

    harn.intervals = {
        'display_train': 1,
        'display_vali': 1,
        'display_test': 1,
        'vali': 1,
        'test': 1,
        'snapshot': 1,
    }
    harn.config = {
        'max_iter': 500,
    }
    harn.epoch = 0

    harn.tlogger = None
    harn.prog = None
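# Example invocation (a sketch; the module path comes from the ``demo``
# docstring below, and the flag names from the ``ub.argval`` / ``ub.argflag``
# calls above):
#
#     python -m clab.live.fit_harn2 demo --batch_size=8 --vali-batch_size=2 --num_workers=2 --nopin
#
# ``--serial`` forces ``num_workers=0``, which makes DataLoader worker
# crashes easier to debug, and ``--nopin`` disables pinned host memory.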
def demo():
    """
    python -m clab.live.fit_harn2 demo

    Example:
        >>> from clab.live.fit_harn2 import *
        >>> demo()
    """
    from clab.torch import hyperparams
    from clab.torch import layers
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
    ])
    n_classes = 2
    datakw = dict(
        image_size=(1, 16, 16),
        num_classes=n_classes,
        transform=transform,
    )
    SSEG = 1
    if SSEG:
        datakw['image_size'] = (1, 200, 200)

        def _gt_as_sseg_mask(gt):
            # transform groundtruth into a random semantic segmentation mask
            return (torch.rand(datakw['image_size'][1:]) * n_classes).long()

        datakw['target_transform'] = _gt_as_sseg_mask
        factor = 1
    else:
        factor = 200

    datasets = {
        'train': torchvision.datasets.FakeData(size=10 * factor,
                                               random_offset=0, **datakw),
        'vali': torchvision.datasets.FakeData(size=5 * factor,
                                              random_offset=110000, **datakw),
        'test': torchvision.datasets.FakeData(size=2 * factor,
                                              random_offset=1110000, **datakw),
    }

    class DummyModel(torch.nn.Module):
        def __init__(self, n_classes):
            super(DummyModel, self).__init__()
            self.seq = torch.nn.Sequential(*[
                layers.Flatten(),
                torch.nn.Linear(int(np.prod(datakw['image_size'])), 10),
                torch.nn.ReLU(inplace=True),
                torch.nn.Linear(10, n_classes),
            ])
            self.n_classes = n_classes

        def forward(self, inputs):
            if isinstance(inputs, (tuple, list)):
                assert len(inputs) == 1
                inputs = inputs[0]
            return self.seq(inputs)

    if SSEG:
        from clab.torch.models import unet
        from clab.torch import criterions
        model = unet.UNet(in_channels=1, n_classes=n_classes,
                          feature_scale=64)
        hyper = hyperparams.HyperParams(
            criterion_cls=criterions.CrossEntropyLoss2D,
        )
    else:
        model = DummyModel(n_classes)
        hyper = hyperparams.HyperParams(
            criterion_cls=torch.nn.CrossEntropyLoss,
        )

    xpu = xpu_device.XPU()

    # hack
    def compute_loss(harn, outputs, labels):
        # Compute the loss
        if isinstance(outputs, list):
            outputs = outputs[0]
        if isinstance(labels, list):
            target = labels[0]
        if SSEG:
            loss = harn.criterion(outputs, target)
        else:
            target = labels[0].long()
            pred = torch.nn.functional.log_softmax(outputs, dim=1)
            loss = torch.nn.functional.nll_loss(pred, target)
        return loss

    train_dpath = ub.ensure_app_cache_dir('clab/demo/fit_harness')
    ub.delete(train_dpath)

    dry = 0
    harn = FitHarness(
        model=model, datasets=datasets, xpu=xpu,
        hyper=hyper, train_dpath=train_dpath, dry=dry,
    )
    harn.compute_loss = compute_loss
    harn.run()

    if not dry:
        # reload and continue training the model
        # stopping criterion should trigger immediately
        harn = FitHarness(
            model=model, datasets=datasets, xpu=xpu,
            hyper=hyper, train_dpath=train_dpath, dry=dry,
        )
        harn.compute_loss = compute_loss
        harn.run()
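def _add_demo_metric_hook(harn):
    """
    Sketch (an assumption, not part of the original demo): in addition to
    overriding ``compute_loss``, a ``FitHarness`` can report extra per-batch
    metrics through ``add_metric_hook`` (used by ``urban_fit`` below).  A
    hook receives ``(harn, output, label)`` and returns a dict of scalar
    metrics; the helper name and the accuracy hook here are illustrative.
    """
    def demo_accuracy(harn, output, label):
        # fraction of correctly classified examples in the batch
        pred = output.data.max(dim=1)[1]
        acc = (pred == label.data).float().mean()
        return {'accuracy': float(acc)}

    harn.add_metric_hook(demo_accuracy)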
def urban_fit():
    """
    CommandLine:
        python -m clab.live.urban_train urban_fit --profile
        python -m clab.live.urban_train urban_fit --task=urban_mapper_3d --arch=segnet
        python -m clab.live.urban_train urban_fit --task=urban_mapper_3d --arch=unet --noaux
        python -m clab.live.urban_train urban_fit --task=urban_mapper_3d --arch=unet
        python -m clab.live.urban_train urban_fit --task=urban_mapper_3d --dry

        python -m clab.live.urban_train urban_fit --task=urban_mapper_3d --arch=unet --colorspace=RGB --combine
        python -m clab.live.urban_train urban_fit --task=urban_mapper_3d --arch=unet --dry
        python -m clab.live.urban_train urban_fit --task=urban_mapper_3d --arch=unet2 --colorspace=RGB --combine
        python -m clab.live.urban_train urban_fit --task=urban_mapper_3d --arch=unet2 --colorspace=RGB --use_aux_diff
        python -m clab.live.urban_train urban_fit --task=urban_mapper_3d --arch=dense_unet --colorspace=RGB --use_aux_diff

        # Train a variant of the dense net with more parameters
        python -m clab.live.urban_train urban_fit --task=urban_mapper_3d --arch=dense_unet --colorspace=RGB --use_aux_diff --combine \
            --pretrained '/home/local/KHQ/jon.crall/data/work/urban_mapper4/arch/dense_unet/train/input_25800-phpjjsqu/solver_25800-phpjjsqu_dense_unet_mmavmuou_zeosddyf_a=1,c=RGB,n_ch=6,n_cl=4/torch_snapshots/_epoch_00000030.pt' --gpu=1

        # Fine tune the model using all the available data
        python -m clab.live.urban_train urban_fit --task=urban_mapper_3d --arch=unet2 --colorspace=RGB --use_aux_diff --combine \
            --pretrained '/home/local/KHQ/jon.crall/data/work/urban_mapper2/arch/unet2/train/input_25800-hemanvft/solver_25800-hemanvft_unet2_mmavmuou_stuyuerd_a=1,c=RGB,n_ch=6,n_cl=4/torch_snapshots/_epoch_00000041.pt' --gpu=3 --finetune

        # Keep a bit of the data for validation but use more
        python -m clab.live.urban_train urban_fit --task=urban_mapper_3d --arch=unet2 --colorspace=RGB --use_aux_diff --halfcombo \
            --pretrained '/home/local/KHQ/jon.crall/data/work/urban_mapper2/arch/unet2/train/input_25800-hemanvft/solver_25800-hemanvft_unet2_mmavmuou_stuyuerd_a=1,c=RGB,n_ch=6,n_cl=4/torch_snapshots/_epoch_00000041.pt' --gpu=3

    Example:
        >>> from clab.torch.fit_harness import *
        >>> harn = urban_fit()
    """
    arch = ub.argval('--arch', default='unet')
    colorspace = ub.argval('--colorspace', default='RGB').upper()

    datasets = load_task_dataset('urban_mapper_3d', colorspace=colorspace,
                                 arch=arch)
    datasets['train'].augment = True

    # Make sure we use consistent normalization
    # TODO: give normalization a part of the hashid
    # TODO: save normalization type with the model
    # datasets['train'].center_inputs = datasets['train']._make_normalizer()

    # if ub.argflag('--combine'):
    #     # custom centering from the initialization point I'm going to use
    #     datasets['train'].center_inputs = datasets['train']._custom_urban_mapper_normalizer(
    #         0.3750553785198646, 1.026544662398811, 2.5136079110849674)
    # else:
    #     datasets['train'].center_inputs = datasets['train']._make_normalizer(mode=2)
    datasets['train'].center_inputs = datasets['train']._make_normalizer(
        mode=3)
    # datasets['train'].center_inputs = _custom_urban_mapper_normalizer(0, 1, 2.5)

    # Ensure normalization is the same for each dataset
    datasets['test'].center_inputs = datasets['train'].center_inputs
    datasets['vali'].center_inputs = datasets['train'].center_inputs

    datasets['train'].augment = True

    # turn off aux layers
    if ub.argflag('--noaux'):
        for v in datasets.values():
            v.aux_keys = []

    batch_size = 14
    if arch == 'segnet':
        batch_size = 6
    elif arch == 'dense_unet':
        batch_size = 6
        # dense_unet batch memsizes
        # idle = 11 MiB
        # 0    = 438 MiB
        # 3   ~= 5000 MiB
        # 5    = 8280 MiB
        # 6    = 9758 MiB
        # each image adds (1478 - 1568.4) MiB

    n_classes = datasets['train'].n_classes
    n_channels = datasets['train'].n_channels
    class_weights = datasets['train'].class_weights()
    ignore_label = datasets['train'].ignore_label

    print('n_classes = {!r}'.format(n_classes))
    print('n_channels = {!r}'.format(n_channels))
    print('batch_size = {!r}'.format(batch_size))

    hyper = hyperparams.HyperParams(
        criterion=(criterions.CrossEntropyLoss2D, {
            'ignore_label': ignore_label,
            # TODO: weight should be a FloatTensor
            'weight': class_weights,
        }),
        optimizer=(torch.optim.SGD, {
            # 'weight_decay': .0006,
            'weight_decay': .0005,
            'momentum': 0.99 if arch == 'dense_unet' else .9,
            'nesterov': True,
        }),
        scheduler=('Exponential', {
            'gamma': 0.99,
            # 'base_lr': 0.0015,
            'base_lr': 0.001 if not ub.argflag('--halfcombo') else 0.0005,
            'stepsize': 2,
        }),
        other={
            'n_classes': n_classes,
            'n_channels': n_channels,
            'augment': datasets['train'].augment,
            'colorspace': datasets['train'].colorspace,
        })

    starting_points = {
        'unet_rgb_4k': ub.truepath(
            '~/remote/aretha/data/work/urban_mapper/arch/unet/train/input_4214-yxalqwdk/solver_4214-yxalqwdk_unet_vgg_nttxoagf_a=1,n_ch=5,n_cl=3/torch_snapshots/_epoch_00000236.pt'),

        # 'unet_rgb_8k': ub.truepath('~/remote/aretha/data/work/urban_mapper/arch/unet/train/input_8438-haplmmpq/solver_8438-haplmmpq_unet_None_kvterjeu_a=1,c=RGB,n_ch=5,n_cl=3/torch_snapshots/_epoch_00000402.pt'),

        # "ImageCenterScale", {"im_mean": [[[0.3750553785198646]]], "im_scale": [[[1.026544662398811]]]}
        # "DTMCenterScale", "std": 2.5136079110849674, "nan_value": -32767.0}

        # 'unet_rgb_8k': ub.truepath(
        #     '~/data/work/urban_mapper2/arch/unet/train/input_4214-guwsobde/'
        #     'solver_4214-guwsobde_unet_mmavmuou_eqnoygqy_a=1,c=RGB,n_ch=5,n_cl=4/torch_snapshots/_epoch_00000189.pt'
        # )

        'unet_rgb_8k': ub.truepath(
            '~/remote/aretha/data/work/urban_mapper2/arch/unet2/train/input_4214-guwsobde/'
            'solver_4214-guwsobde_unet2_mmavmuou_tqynysqo_a=1,c=RGB,n_ch=5,n_cl=4/torch_snapshots/_epoch_00000100.pt'
        )
    }

    pretrained = ub.argval('--pretrained', default=None)
    if pretrained is None:
        if arch == 'segnet':
            pretrained = 'vgg'
        else:
            pretrained = None
            if ub.argflag('--combine'):
                pretrained = starting_points['unet_rgb_8k']
                if arch == 'unet2':
                    pretrained = '/home/local/KHQ/jon.crall/data/work/urban_mapper2/arch/unet2/train/input_25800-hemanvft/solver_25800-hemanvft_unet2_mmavmuou_stuyuerd_a=1,c=RGB,n_ch=6,n_cl=4/torch_snapshots/_epoch_00000042.pt'
                elif arch == 'dense_unet2':
                    pretrained = '/home/local/KHQ/jon.crall/data/work/urban_mapper2/arch/unet2/train/input_25800-hemanvft/solver_25800-hemanvft_unet2_mmavmuou_stuyuerd_a=1,c=RGB,n_ch=6,n_cl=4/torch_snapshots/_epoch_00000042.pt'
            else:
                pretrained = starting_points['unet_rgb_4k']

    train_dpath = directory_structure(
        datasets['train'].task.workdir, arch, datasets,
        pretrained=pretrained,
        train_hyper_id=hyper.hyper_id(),
        suffix='_' + hyper.other_id())

    print('arch = {!r}'.format(arch))

    dry = ub.argflag('--dry')
    if dry:
        model = None
    elif arch == 'segnet':
        model = models.SegNet(in_channels=n_channels, n_classes=n_classes)
        model.init_he_normal()
        if pretrained == 'vgg':
            model.init_vgg16_params()
    elif arch == 'linknet':
        model = models.LinkNet(in_channels=n_channels, n_classes=n_classes)
    elif arch == 'unet':
        model = models.UNet(in_channels=n_channels, n_classes=n_classes,
                            nonlinearity='leaky_relu')
        snapshot = xpu_device.XPU(None).load(pretrained)
        model_state_dict = snapshot['model_state_dict']
        model.load_partial_state(model_state_dict)
        # model.shock_outward()
    elif arch == 'unet2':
        from clab.live import unet2
        model = unet2.UNet2(n_alt_classes=3, in_channels=n_channels,
                            n_classes=n_classes, nonlinearity='leaky_relu')
        snapshot = xpu_device.XPU(None).load(pretrained)
        model_state_dict = snapshot['model_state_dict']
        model.load_partial_state(model_state_dict)
    elif arch == 'dense_unet':
        from clab.live import unet3
        model = unet3.DenseUNet(n_alt_classes=3, in_channels=n_channels,
                                n_classes=n_classes)
        model.init_he_normal()
        snapshot = xpu_device.XPU(None).load(pretrained)
        model_state_dict = snapshot['model_state_dict']
        model.load_partial_state(model_state_dict)
    elif arch == 'dense_unet2':
        from clab.live import unet3
        model = unet3.DenseUNet2(n_alt_classes=3, in_channels=n_channels,
                                 n_classes=n_classes)
        # model.init_he_normal()
        snapshot = xpu_device.XPU(None).load(pretrained)
        model_state_dict = snapshot['model_state_dict']
        model.load_partial_state(model_state_dict, shock_partial=False)
    elif arch == 'dummy':
        model = models.SSegDummy(in_channels=n_channels, n_classes=n_classes)
    else:
        raise ValueError('unknown arch')

    if ub.argflag('--finetune'):
        # Hack in a reduced learning rate
        hyper = hyperparams.HyperParams(
            criterion=(criterions.CrossEntropyLoss2D, {
                'ignore_label': ignore_label,
                # TODO: weight should be a FloatTensor
                'weight': class_weights,
            }),
            optimizer=(torch.optim.SGD, {
                # 'weight_decay': .0006,
                'weight_decay': .0005,
                'momentum': 0.99 if arch == 'dense_unet' else .9,
                'nesterov': True,
            }),
            scheduler=('Constant', {
                'base_lr': 0.0001,
            }),
            other={
                'n_classes': n_classes,
                'n_channels': n_channels,
                'augment': datasets['train'].augment,
                'colorspace': datasets['train'].colorspace,
            })

    xpu = xpu_device.XPU.from_argv()

    if datasets['train'].use_aux_diff:
        # arch in ['unet2', 'dense_unet']:
        from clab.live import fit_harn2
        harn = fit_harn2.FitHarness(
            model=model, hyper=hyper, datasets=datasets, xpu=xpu,
            train_dpath=train_dpath, dry=dry,
            batch_size=batch_size,
        )
        harn.criterion2 = criterions.CrossEntropyLoss2D(
            weight=torch.FloatTensor([.1, 1, 0]),
            ignore_label=2)

        def compute_loss(harn, outputs, labels):
            output1, output2 = outputs
            label1, label2 = labels
            # Compute the loss
            loss1 = harn.criterion(output1, label1)
            loss2 = harn.criterion2(output2, label2)
            loss = (.45 * loss1 + .55 * loss2)
            return loss

        harn.compute_loss = compute_loss

        # z = harn.loaders['train']
        # b = next(iter(z))
        # print('b = {!r}'.format(b))
        # import sys
        # sys.exit(0)

        def custom_metrics(harn, output, label):
            ignore_label = datasets['train'].ignore_label
            labels = datasets['train'].task.labels
            metrics_dict = metrics._sseg_metrics(output[1], label[1],
                                                 labels=labels,
                                                 ignore_label=ignore_label)
            return metrics_dict
    else:
        harn = fit_harness.FitHarness(
            model=model, hyper=hyper, datasets=datasets, xpu=xpu,
            train_dpath=train_dpath, dry=dry,
            batch_size=batch_size,
        )

        def custom_metrics(harn, output, label):
            ignore_label = datasets['train'].ignore_label
            labels = datasets['train'].task.labels
            metrics_dict = metrics._sseg_metrics(output, label, labels=labels,
                                                 ignore_label=ignore_label)
            return metrics_dict

    harn.add_metric_hook(custom_metrics)

    # HACK
    # im = datasets['train'][0][0]
    # w, h = im.shape[-2:]
    # single_output_shape = (n_classes, w, h)
    # harn.single_output_shape = single_output_shape
    # print('harn.single_output_shape = {!r}'.format(harn.single_output_shape))

    harn.run()
    return harn