Example #1
File: nnio.py  Project: lclibardi/clab
    def map_location(storage: torch.Storage, location) -> torch.Storage:
        """
        Args:
            storage (torch.Storage) : the initial deserialization of a
                storage read by `torch.load`, residing on the CPU.
            location (str): tag identifying the location where the data
                read by `torch.load` was originally saved.

        Returns:
            torch.Storage : the storage, moved to the GPU when one is in use.
        """
        xpu = xpu_device.XPU(gpu_num)
        if xpu.is_gpu():
            return storage.cuda(xpu.num)
        else:
            return storage
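
For reference, `torch.load` accepts a callable like this and invokes it once per deserialized storage. A minimal sketch of that usage, assuming plain PyTorch (no clab `XPU` helper) and a hypothetical `snapshot.pt` path:

import torch

def map_location(storage, location):
    # Move each deserialized storage onto the first GPU when one is
    # available; otherwise keep it on the CPU.
    if torch.cuda.is_available():
        return storage.cuda(0)
    return storage

# torch.load calls map_location(storage, location) for every storage
# found in the checkpoint file.
snapshot = torch.load('snapshot.pt', map_location=map_location)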
Example #2
    def __init__(harn,
                 model,
                 datasets,
                 batch_size=4,
                 criterion_cls='cross_entropy',
                 hyper=None,
                 xpu=None,
                 train_dpath=None,
                 dry=False):

        harn.dry = dry
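        # In dry mode, stage outputs in a fresh scratch cache directory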
        if harn.dry:
            train_dpath = ub.ensure_app_cache_dir('clab/dry')
            ub.delete(train_dpath)
            train_dpath = ub.ensure_app_cache_dir('clab/dry')

        harn.train_dpath = train_dpath
        harn.snapshot_dpath = ub.ensuredir(
            (harn.train_dpath, 'torch_snapshots'))

        if harn.dry:
            harn.xpu = xpu_device.XPU(None)
        else:
            harn.xpu = xpu_device.XPU(xpu)

        data_kw = {'batch_size': batch_size}
        if harn.xpu.is_gpu():
            data_kw.update({'num_workers': 6, 'pin_memory': True})

        harn.loaders = {}
        harn.datasets = datasets
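        # One DataLoader per split: shuffle only the training data; vali/test use a reduced batch size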
        for tag, dset in datasets.items():
            assert tag in {'train', 'vali', 'test'}
            shuffle = tag == 'train'
            data_kw_ = data_kw.copy()
            if tag != 'train':
                data_kw_['batch_size'] = max(batch_size // 4, 1)
            loader = torch.utils.data.DataLoader(dset,
                                                 shuffle=shuffle,
                                                 **data_kw_)
            harn.loaders[tag] = loader

        harn.model = model

        harn.hyper = hyper

        harn.lr_scheduler = hyper.scheduler_cls(**hyper.scheduler_params)
        harn.criterion_cls = hyper.criterion_cls
        harn.optimizer_cls = hyper.optimizer_cls

        harn.criterion_params = hyper.criterion_params
        harn.optimizer_params = hyper.optimizer_params

        harn._metric_hooks = []
        harn._run_metrics = None

        harn._epoch_callbacks = []
        harn._iter_callbacks = []

        harn.intervals = {
            'display_train': 1,
            'display_vali': 1,
            'display_test': 1,
            'vali': 1,
            'test': 1,
            'snapshot': 1,
        }
        harn.config = {
            'max_iter': 1000,
        }
        harn.epoch = 0
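
A hypothetical construction sketch for this harness; `my_model`, the toy datasets, the `hyper` object, and the `./train_scratch` path are illustrative stand-ins for the kinds of objects built in the demo example further below:

harn = FitHarness(
    model=my_model,                      # any torch.nn.Module
    datasets={'train': train_dset,       # keys must be train/vali/test
              'vali': vali_dset},
    batch_size=8,                        # vali/test loaders get batch_size // 4
    hyper=hyper,                         # supplies criterion/optimizer/scheduler classes
    xpu=xpu_device.XPU(),                # auto-select a device, as in the demo
    train_dpath='./train_scratch',       # where torch_snapshots are written
)
harn.run()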
Example #3
    def __init__(harn,
                 model,
                 datasets,
                 batch_size=4,
                 criterion_cls='cross_entropy',
                 hyper=None,
                 xpu=None,
                 train_dpath=None,
                 dry=False):

        harn.dry = dry
        if harn.dry:
            train_dpath = ub.ensure_app_cache_dir('clab/dry')
            ub.delete(train_dpath)
            train_dpath = ub.ensure_app_cache_dir('clab/dry')

        harn.train_dpath = train_dpath
        harn.snapshot_dpath = ub.ensuredir(
            (harn.train_dpath, 'torch_snapshots'))

        if harn.dry:
            harn.xpu = xpu_device.XPU(None)
        else:
            harn.xpu = xpu_device.XPU(xpu)

        # Allow for command line override
        batch_size = int(ub.argval('--batch_size', default=batch_size))

        data_kw = {'batch_size': batch_size}
        if harn.xpu.is_gpu():
            num_workers = int(ub.argval('--num_workers', default=6))
            num_workers = 0 if ub.argflag('--serial') else num_workers
            pin_memory = False if ub.argflag('--nopin') else True
            data_kw.update({
                'num_workers': num_workers,
                'pin_memory': pin_memory
            })
            # data_kw.update({'num_workers': 0, 'pin_memory': False})

        harn.loaders = ub.odict()
        harn.datasets = datasets
        assert set(harn.datasets.keys()).issubset({'train', 'vali', 'test'})
        for tag in ['train', 'vali', 'test']:
            dset = harn.datasets.get(tag, None)
            if dset:
                shuffle = tag == 'train'
                data_kw_ = data_kw.copy()
                if tag != 'train':
                    tag_batch_size = max(batch_size // 4, 1)
                    tag_batch_size = int(
                        ub.argval('--{}-batch_size'.format(tag),
                                  default=tag_batch_size))
                    data_kw_['batch_size'] = tag_batch_size
                loader = torch.utils.data.DataLoader(dset,
                                                     shuffle=shuffle,
                                                     **data_kw_)
                harn.loaders[tag] = loader

        harn.model = model

        harn.hyper = hyper

        harn.lr_scheduler = hyper.scheduler_cls(**hyper.scheduler_params)
        harn.criterion_cls = hyper.criterion_cls
        harn.optimizer_cls = hyper.optimizer_cls

        harn.criterion_params = hyper.criterion_params
        harn.optimizer_params = hyper.optimizer_params

        harn._metric_hooks = []
        harn._run_metrics = None

        harn._epoch_callbacks = []
        harn._iter_callbacks = []

        harn.early_stop = EarlyStop()

        harn.intervals = {
            'display_train': 1,
            'display_vali': 1,
            'display_test': 1,
            'vali': 1,
            'test': 1,
            'snapshot': 1,
        }
        harn.config = {
            'max_iter': 500,
        }
        harn.epoch = 0

        harn.tlogger = None
        harn.prog = None
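
The `ub.argval` and `ub.argflag` calls above read overrides directly from `sys.argv`, so flags such as `--batch_size=8`, `--serial`, or `--nopin` can be appended to whatever command launches the harness. A small sketch of that behavior, assuming the ubelt helpers work as they are used in the listings above:

import sys
import ubelt as ub

# Simulate command-line flags of the kind the constructor looks for.
sys.argv += ['--batch_size=8', '--serial']

batch_size = int(ub.argval('--batch_size', default=4))    # -> 8
num_workers = 0 if ub.argflag('--serial') else 6          # -> 0
pin_memory = False if ub.argflag('--nopin') else True     # -> True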
Example #4
def demo():
    """
    python -m clab.live.fit_harn2 demo

    Example:
        >>> from clab.live.fit_harn2 import *
        >>> demo()
    """
    from clab.torch import hyperparams
    from clab.torch import layers
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
    ])
    n_classes = 2

    datakw = dict(
        image_size=(1, 16, 16),
        num_classes=n_classes,
        transform=transform,
    )

    SSEG = 1
    if SSEG:
        datakw['image_size'] = (1, 200, 200)

        def _gt_as_sseg_mask(gt):
            # transform groundtruth into a random semantic segmentation mask
            return (torch.rand(datakw['image_size'][1:]) * n_classes).long()

        datakw['target_transform'] = _gt_as_sseg_mask
        factor = 1
    else:
        factor = 200

    datasets = {
        'train':
        torchvision.datasets.FakeData(size=10 * factor,
                                      random_offset=0,
                                      **datakw),
        'vali':
        torchvision.datasets.FakeData(size=5 * factor,
                                      random_offset=110000,
                                      **datakw),
        'test':
        torchvision.datasets.FakeData(size=2 * factor,
                                      random_offset=1110000,
                                      **datakw),
    }

    class DummyModel(torch.nn.Module):
        def __init__(self, n_classes):
            super(DummyModel, self).__init__()
            self.seq = torch.nn.Sequential(*[
                layers.Flatten(),
                torch.nn.Linear(int(np.prod(datakw['image_size'])), 10),
                torch.nn.ReLU(inplace=True),
                torch.nn.Linear(10, n_classes),
            ])
            self.n_classes = n_classes

        def forward(self, inputs):
            if isinstance(inputs, (tuple, list)):
                assert len(inputs) == 1
                inputs = inputs[0]
            return self.seq(inputs)

    if SSEG:
        from clab.torch.models import unet
        from clab.torch import criterions
        model = unet.UNet(in_channels=1, n_classes=n_classes, feature_scale=64)
        hyper = hyperparams.HyperParams(
            criterion_cls=criterions.CrossEntropyLoss2D, )
    else:
        model = DummyModel(n_classes)
        hyper = hyperparams.HyperParams(
            criterion_cls=torch.nn.CrossEntropyLoss, )

    xpu = xpu_device.XPU()

    # hack
    def compute_loss(harn, outputs, labels):
        # Compute the loss
        if isinstance(outputs, list):
            outputs = outputs[0]
        if isinstance(labels, list):
            target = labels[0]
        else:
            target = labels
        if SSEG:
            loss = harn.criterion(outputs, target)
        else:
            target = target.long()
            pred = torch.nn.functional.log_softmax(outputs, dim=1)
            loss = torch.nn.functional.nll_loss(pred, target)
        return loss

    train_dpath = ub.ensure_app_cache_dir('clab/demo/fit_harness')
    ub.delete(train_dpath)

    dry = 0
    harn = FitHarness(
        model=model,
        datasets=datasets,
        xpu=xpu,
        hyper=hyper,
        train_dpath=train_dpath,
        dry=dry,
    )
    harn.compute_loss = compute_loss
    harn.run()

    if not dry:
        # reload and continue training the model
        # stopping criterion should trigger immediately
        harn = FitHarness(
            model=model,
            datasets=datasets,
            xpu=xpu,
            hyper=hyper,
            train_dpath=train_dpath,
            dry=dry,
        )
        harn.compute_loss = compute_loss
        harn.run()
Example #5
def urban_fit():
    """

    CommandLine:
        python -m clab.live.urban_train urban_fit --profile

        python -m clab.live.urban_train urban_fit --task=urban_mapper_3d --arch=segnet

        python -m clab.live.urban_train urban_fit --task=urban_mapper_3d --arch=unet --noaux
        python -m clab.live.urban_train urban_fit --task=urban_mapper_3d --arch=unet

        python -m clab.live.urban_train urban_fit --task=urban_mapper_3d --dry

        python -m clab.live.urban_train urban_fit --task=urban_mapper_3d --arch=unet --colorspace=RGB --combine


        python -m clab.live.urban_train urban_fit --task=urban_mapper_3d --arch=unet --dry

        python -m clab.live.urban_train urban_fit --task=urban_mapper_3d --arch=unet2 --colorspace=RGB --combine
        python -m clab.live.urban_train urban_fit --task=urban_mapper_3d --arch=unet2 --colorspace=RGB --use_aux_diff

        python -m clab.live.urban_train urban_fit --task=urban_mapper_3d --arch=dense_unet --colorspace=RGB --use_aux_diff


        # Train a variant of the dense net with more parameters
        python -m clab.live.urban_train urban_fit --task=urban_mapper_3d --arch=dense_unet --colorspace=RGB --use_aux_diff --combine \
                --pretrained '/home/local/KHQ/jon.crall/data/work/urban_mapper4/arch/dense_unet/train/input_25800-phpjjsqu/solver_25800-phpjjsqu_dense_unet_mmavmuou_zeosddyf_a=1,c=RGB,n_ch=6,n_cl=4/torch_snapshots/_epoch_00000030.pt' --gpu=1

        # Fine tune the model using all the available data
        python -m clab.live.urban_train urban_fit --task=urban_mapper_3d --arch=unet2 --colorspace=RGB --use_aux_diff --combine \
                --pretrained '/home/local/KHQ/jon.crall/data/work/urban_mapper2/arch/unet2/train/input_25800-hemanvft/solver_25800-hemanvft_unet2_mmavmuou_stuyuerd_a=1,c=RGB,n_ch=6,n_cl=4/torch_snapshots/_epoch_00000041.pt' --gpu=3 --finetune


        # Keep a bit of the data for validation but use more
        python -m clab.live.urban_train urban_fit --task=urban_mapper_3d --arch=unet2 --colorspace=RGB --use_aux_diff --halfcombo \
                --pretrained '/home/local/KHQ/jon.crall/data/work/urban_mapper2/arch/unet2/train/input_25800-hemanvft/solver_25800-hemanvft_unet2_mmavmuou_stuyuerd_a=1,c=RGB,n_ch=6,n_cl=4/torch_snapshots/_epoch_00000041.pt' --gpu=3

    Example:
        >>> from clab.live.urban_train import *
        >>> harn = urban_fit()
    """
    arch = ub.argval('--arch', default='unet')
    colorspace = ub.argval('--colorspace', default='RGB').upper()

    datasets = load_task_dataset('urban_mapper_3d',
                                 colorspace=colorspace,
                                 arch=arch)
    datasets['train'].augment = True

    # Make sure we use consistent normalization
    # TODO: give normalization a part of the hashid
    # TODO: save normalization type with the model
    # datasets['train'].center_inputs = datasets['train']._make_normalizer()

    # if ub.argflag('--combine'):
    #     # custom centering from the initialization point I'm going to use
    #     datasets['train'].center_inputs = datasets['train']._custom_urban_mapper_normalizer(
    #         0.3750553785198646, 1.026544662398811, 2.5136079110849674)
    # else:
    # datasets['train'].center_inputs = datasets['train']._make_normalizer(mode=2)
    datasets['train'].center_inputs = datasets['train']._make_normalizer(
        mode=3)
    # datasets['train'].center_inputs = _custom_urban_mapper_normalizer(0, 1, 2.5)

    datasets['test'].center_inputs = datasets['train'].center_inputs
    datasets['vali'].center_inputs = datasets['train'].center_inputs

    # Ensure normalization is the same for each dataset
    datasets['train'].augment = True

    # turn off aux layers
    if ub.argflag('--noaux'):
        for v in datasets.values():
            v.aux_keys = []

    batch_size = 14
    if arch == 'segnet':
        batch_size = 6
    elif arch == 'dense_unet':
        batch_size = 6
        # dense_unet batch memsizes
        # idle =   11 MiB
        # 0    =  438 MiB
        # 3   ~= 5000 MiB
        # 5    = 8280 MiB
        # 6    = 9758 MiB
        # each additional image adds roughly 1478-1568 MiB

    n_classes = datasets['train'].n_classes
    n_channels = datasets['train'].n_channels
    class_weights = datasets['train'].class_weights()
    ignore_label = datasets['train'].ignore_label

    print('n_classes = {!r}'.format(n_classes))
    print('n_channels = {!r}'.format(n_channels))
    print('batch_size = {!r}'.format(batch_size))

    hyper = hyperparams.HyperParams(
        criterion=(
            criterions.CrossEntropyLoss2D,
            {
                'ignore_label': ignore_label,
                # TODO: weight should be a FloatTensor
                'weight': class_weights,
            }),
        optimizer=(
            torch.optim.SGD,
            {
                # 'weight_decay': .0006,
                'weight_decay': .0005,
                'momentum': 0.99 if arch == 'dense_unet' else .9,
                'nesterov': True,
            }),
        scheduler=(
            'Exponential',
            {
                'gamma': 0.99,
                # 'base_lr': 0.0015,
                'base_lr': 0.001 if not ub.argflag('--halfcombo') else 0.0005,
                'stepsize': 2,
            }),
        other={
            'n_classes': n_classes,
            'n_channels': n_channels,
            'augment': datasets['train'].augment,
            'colorspace': datasets['train'].colorspace,
        })

    starting_points = {
        'unet_rgb_4k':
        ub.truepath(
            '~/remote/aretha/data/work/urban_mapper/arch/unet/train/input_4214-yxalqwdk/solver_4214-yxalqwdk_unet_vgg_nttxoagf_a=1,n_ch=5,n_cl=3/torch_snapshots/_epoch_00000236.pt'
        ),

        # 'unet_rgb_8k': ub.truepath('~/remote/aretha/data/work/urban_mapper/arch/unet/train/input_8438-haplmmpq/solver_8438-haplmmpq_unet_None_kvterjeu_a=1,c=RGB,n_ch=5,n_cl=3/torch_snapshots/_epoch_00000402.pt'),
        # "ImageCenterScale", {"im_mean": [[[0.3750553785198646]]], "im_scale": [[[1.026544662398811]]]}
        # "DTMCenterScale", "std": 2.5136079110849674, "nan_value": -32767.0 }

        # 'unet_rgb_8k': ub.truepath(
        #     '~/data/work/urban_mapper2/arch/unet/train/input_4214-guwsobde/'
        #     'solver_4214-guwsobde_unet_mmavmuou_eqnoygqy_a=1,c=RGB,n_ch=5,n_cl=4/torch_snapshots/_epoch_00000189.pt'
        # )
        'unet_rgb_8k':
        ub.truepath(
            '~/remote/aretha/data/work/urban_mapper2/arch/unet2/train/input_4214-guwsobde/'
            'solver_4214-guwsobde_unet2_mmavmuou_tqynysqo_a=1,c=RGB,n_ch=5,n_cl=4/torch_snapshots/_epoch_00000100.pt'
        )
    }

    pretrained = ub.argval('--pretrained', default=None)
    if pretrained is None:
        if arch == 'segnet':
            pretrained = 'vgg'
        else:
            pretrained = None
            if ub.argflag('--combine'):
                pretrained = starting_points['unet_rgb_8k']

                if arch == 'unet2':
                    pretrained = '/home/local/KHQ/jon.crall/data/work/urban_mapper2/arch/unet2/train/input_25800-hemanvft/solver_25800-hemanvft_unet2_mmavmuou_stuyuerd_a=1,c=RGB,n_ch=6,n_cl=4/torch_snapshots/_epoch_00000042.pt'
                elif arch == 'dense_unet2':
                    pretrained = '/home/local/KHQ/jon.crall/data/work/urban_mapper2/arch/unet2/train/input_25800-hemanvft/solver_25800-hemanvft_unet2_mmavmuou_stuyuerd_a=1,c=RGB,n_ch=6,n_cl=4/torch_snapshots/_epoch_00000042.pt'
            else:
                pretrained = starting_points['unet_rgb_4k']

    train_dpath = directory_structure(datasets['train'].task.workdir,
                                      arch,
                                      datasets,
                                      pretrained=pretrained,
                                      train_hyper_id=hyper.hyper_id(),
                                      suffix='_' + hyper.other_id())

    print('arch = {!r}'.format(arch))
    dry = ub.argflag('--dry')
    if dry:
        model = None
    elif arch == 'segnet':
        model = models.SegNet(in_channels=n_channels, n_classes=n_classes)
        model.init_he_normal()
        if pretrained == 'vgg':
            model.init_vgg16_params()
    elif arch == 'linknet':
        model = models.LinkNet(in_channels=n_channels, n_classes=n_classes)
    elif arch == 'unet':
        model = models.UNet(in_channels=n_channels,
                            n_classes=n_classes,
                            nonlinearity='leaky_relu')
        snapshot = xpu_device.XPU(None).load(pretrained)
        model_state_dict = snapshot['model_state_dict']
        model.load_partial_state(model_state_dict)
        # model.shock_outward()
    elif arch == 'unet2':
        from clab.live import unet2
        model = unet2.UNet2(n_alt_classes=3,
                            in_channels=n_channels,
                            n_classes=n_classes,
                            nonlinearity='leaky_relu')
        snapshot = xpu_device.XPU(None).load(pretrained)
        model_state_dict = snapshot['model_state_dict']
        model.load_partial_state(model_state_dict)

    elif arch == 'dense_unet':
        from clab.live import unet3
        model = unet3.DenseUNet(n_alt_classes=3,
                                in_channels=n_channels,
                                n_classes=n_classes)
        model.init_he_normal()
        snapshot = xpu_device.XPU(None).load(pretrained)
        model_state_dict = snapshot['model_state_dict']
        model.load_partial_state(model_state_dict)
    elif arch == 'dense_unet2':
        from clab.live import unet3
        model = unet3.DenseUNet2(n_alt_classes=3,
                                 in_channels=n_channels,
                                 n_classes=n_classes)
        # model.init_he_normal()
        snapshot = xpu_device.XPU(None).load(pretrained)
        model_state_dict = snapshot['model_state_dict']
        model.load_partial_state(model_state_dict, shock_partial=False)
    elif arch == 'dummy':
        model = models.SSegDummy(in_channels=n_channels, n_classes=n_classes)
    else:
        raise ValueError('unknown arch = {!r}'.format(arch))

    if ub.argflag('--finetune'):
        # Hack in a reduced learning rate
        hyper = hyperparams.HyperParams(
            criterion=(
                criterions.CrossEntropyLoss2D,
                {
                    'ignore_label': ignore_label,
                    # TODO: weight should be a FloatTensor
                    'weight': class_weights,
                }),
            optimizer=(
                torch.optim.SGD,
                {
                    # 'weight_decay': .0006,
                    'weight_decay': .0005,
                    'momentum': 0.99 if arch == 'dense_unet' else .9,
                    'nesterov': True,
                }),
            scheduler=('Constant', {
                'base_lr': 0.0001,
            }),
            other={
                'n_classes': n_classes,
                'n_channels': n_channels,
                'augment': datasets['train'].augment,
                'colorspace': datasets['train'].colorspace,
            })

    xpu = xpu_device.XPU.from_argv()

    if datasets['train'].use_aux_diff:
        # arch in ['unet2', 'dense_unet']:

        from clab.live import fit_harn2
        harn = fit_harn2.FitHarness(
            model=model,
            hyper=hyper,
            datasets=datasets,
            xpu=xpu,
            train_dpath=train_dpath,
            dry=dry,
            batch_size=batch_size,
        )
        harn.criterion2 = criterions.CrossEntropyLoss2D(
            weight=torch.FloatTensor([.1, 1, 0]), ignore_label=2)

        def compute_loss(harn, outputs, labels):

            output1, output2 = outputs
            label1, label2 = labels

            # Compute the loss
            loss1 = harn.criterion(output1, label1)
            loss2 = harn.criterion2(output2, label2)
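            # Weighted combination of the two output heads' losses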
            loss = (.45 * loss1 + .55 * loss2)
            return loss

        harn.compute_loss = compute_loss

        # z = harn.loaders['train']
        # b = next(iter(z))
        # print('b = {!r}'.format(b))
        # import sys
        # sys.exit(0)

        def custom_metrics(harn, output, label):
            ignore_label = datasets['train'].ignore_label
            labels = datasets['train'].task.labels

            metrics_dict = metrics._sseg_metrics(output[1],
                                                 label[1],
                                                 labels=labels,
                                                 ignore_label=ignore_label)
            return metrics_dict
    else:
        harn = fit_harness.FitHarness(
            model=model,
            hyper=hyper,
            datasets=datasets,
            xpu=xpu,
            train_dpath=train_dpath,
            dry=dry,
            batch_size=batch_size,
        )

        def custom_metrics(harn, output, label):
            ignore_label = datasets['train'].ignore_label
            labels = datasets['train'].task.labels

            metrics_dict = metrics._sseg_metrics(output,
                                                 label,
                                                 labels=labels,
                                                 ignore_label=ignore_label)
            return metrics_dict

    harn.add_metric_hook(custom_metrics)

    # HACK
    # im = datasets['train'][0][0]
    # w, h = im.shape[-2:]
    # single_output_shape = (n_classes, w, h)
    # harn.single_output_shape = single_output_shape
    # print('harn.single_output_shape = {!r}'.format(harn.single_output_shape))

    harn.run()
    return harn
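
The hooks above suggest the metric-hook contract: each hook receives (harn, output, label) and returns a dict of scalar metrics that the harness records. A hypothetical extra hook under that assumption, mirroring the single-output branch (the name pixel_accuracy_hook and its body are illustrative, and ignore_label masking is skipped for brevity):

def pixel_accuracy_hook(harn, output, label):
    # Argmax over the class channel gives the hard per-pixel prediction.
    pred = output.data.max(dim=1)[1]
    truth = label.data if hasattr(label, 'data') else label
    is_correct = (pred == truth).float()
    return {'pixel_acc': float(is_correct.mean())}

harn.add_metric_hook(pixel_accuracy_hook)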