# Example #1
    def load_snapshot(pharn, load_path):
        """
        Restore a trained model from a torch snapshot file.

        Reconstructs the model architecture recorded in the snapshot,
        moves it onto pharn's device, and loads the saved weights.

        Args:
            pharn: harness-like object providing .xpu, .dataset and .model
            load_path (str): path to a torch snapshot (.pt) file

        Raises:
            NotImplementedError: if the snapshot's model class is unknown
        """
        print('Loading snapshot onto {}'.format(pharn.xpu))
        snapshot = torch.load(load_path, map_location=pharn.xpu.map_location())

        if 'model_kw' in snapshot:
            # Prefer the construction kwargs recorded in the snapshot itself;
            # fall back to the dataset if a key is missing.
            # (Previously this branch left n_classes/n_channels unassigned,
            # causing a NameError below.)
            n_classes = snapshot['model_kw'].get('n_classes',
                                                 pharn.dataset.n_classes)
            n_channels = snapshot['model_kw'].get('in_channels',
                                                  pharn.dataset.n_channels)
        else:
            # Older snapshots were saved without model kwargs; infer the
            # architecture parameters from the dataset instead.
            print('warning snapshot not saved with modelkw')
            n_classes = pharn.dataset.n_classes
            n_channels = pharn.dataset.n_channels

        # Infer which model this belongs to
        # FIXME: The model must be constructed with the EXACT same kwargs This
        # will be easier when onnx supports model serialization.
        if snapshot['model_class_name'] == 'UNet':
            pharn.model = models.UNet(in_channels=n_channels,
                                      n_classes=n_classes,
                                      nonlinearity='leaky_relu')
        elif snapshot['model_class_name'] == 'UNet2':
            pharn.model = unet2.UNet2(
                in_channels=n_channels, n_classes=n_classes, n_alt_classes=3,
                nonlinearity='leaky_relu'
            )
        elif snapshot['model_class_name'] == 'DenseUNet':
            pharn.model = unet3.DenseUNet(
                in_channels=n_channels, n_classes=n_classes, n_alt_classes=3,
            )
        else:
            raise NotImplementedError(snapshot['model_class_name'])

        # Move to the target device first, then load the matching state dict.
        pharn.model = pharn.xpu.to_xpu(pharn.model)
        pharn.model.load_state_dict(snapshot['model_state_dict'])
# Example #2
def fit_networks(datasets, xpu):
    """
    Train each candidate architecture and record its best epochs.

    For every arch in the hard-coded candidate list this builds the
    hyperparameters, writes a ``train_info.json`` into a per-arch training
    directory, constructs the model, and runs a ``FitHarness`` with early
    stopping and a custom two-headed loss.

    Args:
        datasets (dict): task datasets; must contain a 'train' split that
            exposes n_classes, n_channels, class_weights(), ignore_label,
            task, input_id, colorspace, augment and center_inputs.
        xpu: device wrapper the harness trains on.

    Returns:
        tuple[dict, dict]: (arch_to_train_dpath, arch_to_best_epochs)
            mapping each arch name to its training directory and to the
            best epochs found by early stopping.
    """
    print('datasets = {}'.format(datasets))
    n_classes = datasets['train'].n_classes
    n_channels = datasets['train'].n_channels
    class_weights = datasets['train'].class_weights()
    ignore_label = datasets['train'].ignore_label

    print('n_classes = {!r}'.format(n_classes))
    print('n_channels = {!r}'.format(n_channels))

    # Candidate architectures to train, one harness run each.
    arches = [
        'unet2',
        'dense_unet',
    ]

    arch_to_train_dpath = {}
    arch_to_best_epochs = {}

    for arch in arches:

        hyper = hyperparams.HyperParams(
            criterion=(criterions.CrossEntropyLoss2D, {
                'ignore_label': ignore_label,
                # TODO: weight should be a FloatTensor
                'weight': class_weights,
            }),
            optimizer=(torch.optim.SGD, {
                # 'weight_decay': .0006,
                'weight_decay': .0005,
                'momentum': .9,
                'nesterov': True,
            }),
            scheduler=('Exponential', {
                'gamma': 0.99,
                'base_lr': 0.001,
                'stepsize': 2,
            }),
            other={
                'n_classes': n_classes,
                'n_channels': n_channels,
                'augment': datasets['train'].augment,
                'colorspace': datasets['train'].colorspace,
            }
        )

        train_dpath = ub.ensuredir((datasets['train'].task.workdir, 'train', arch))

        # Record enough metadata to identify / reproduce this training run.
        train_info = {
            'arch': arch,
            'train_id': datasets['train'].input_id,
            'train_hyper_id': hyper.hyper_id(),
            'colorspace': datasets['train'].colorspace,
            # Hack in centering information
            'hack_centers': [
                (t.__class__.__name__, t.__getstate__())
                for t in datasets['train'].center_inputs.transforms
            ]
        }
        util.write_json(join(train_dpath, 'train_info.json'), train_info)

        arch_to_train_dpath[arch] = train_dpath

        # Batch sizes are tuned per-arch to fit GPU memory.
        if arch == 'unet2':
            batch_size = 14
            model = unet2.UNet2(n_alt_classes=3, in_channels=n_channels,
                                n_classes=n_classes, nonlinearity='leaky_relu')
        elif arch == 'dense_unet':
            batch_size = 6
            model = unet3.DenseUNet(n_alt_classes=3, in_channels=n_channels,
                                    n_classes=n_classes)
        else:
            # Fail fast: previously an unknown arch silently reused the
            # model / batch_size from the prior loop iteration.
            raise KeyError('unrecognized arch = {!r}'.format(arch))

        dry = 0
        harn = fit_harn2.FitHarness(
            model=model, hyper=hyper, datasets=datasets, xpu=xpu,
            train_dpath=train_dpath, dry=dry,
            batch_size=batch_size,
        )
        # Second criterion for the auxiliary output head: class weights
        # [0.1, 1.0, 0.0] with label 2 ignored.
        harn.criterion2 = criterions.CrossEntropyLoss2D(
            weight=torch.FloatTensor([0.1, 1.0, 0.0]),
            ignore_label=2
        )
        if DEBUG:
            harn.config['max_iter'] = 30
        else:

            # Note on aretha we can do 140 epochs in 7 days, so
            # be careful with how long we take to train.
            # With a reduction of 16, we can take a few more epochs
            # Unet2 take ~10 minutes to get through one

            # with num_workers=0, we have 374.00s/it = 6.23 m/it
            # this comes down to 231 epochs per day
            # harn.config['max_iter'] = 432  # 3 days max
            harn.config['max_iter'] = 200  # ~1 day max (if multiprocessing works)
        harn.early_stop.patience = 10

        def compute_loss(harn, outputs, labels):
            # Weighted sum of the primary (.45) and auxiliary (.55) head
            # losses; both models above emit two outputs.
            output1, output2 = outputs
            label1, label2 = labels

            # Compute the loss
            loss1 = harn.criterion(output1, label1)
            loss2 = harn.criterion2(output2, label2)
            loss = (.45 * loss1 + .55 * loss2)
            return loss

        harn.compute_loss = compute_loss

        def custom_metrics(harn, output, label):
            # Report segmentation metrics on the auxiliary head (index 1).
            ignore_label = datasets['train'].ignore_label
            labels = datasets['train'].task.labels

            metrics_dict = metrics._sseg_metrics(output[1], label[1],
                                                 labels=labels,
                                                 ignore_label=ignore_label)
            return metrics_dict

        harn.add_metric_hook(custom_metrics)

        harn.run()
        arch_to_best_epochs[arch] = harn.early_stop.best_epochs()

    # Select model and hyperparams
    print('arch_to_train_dpath = {}'.format(ub.repr2(arch_to_train_dpath, nl=1)))
    print('arch_to_best_epochs = {}'.format(ub.repr2(arch_to_best_epochs, nl=1)))
    return arch_to_train_dpath, arch_to_best_epochs
# Example #3
def urban_fit():
    """

    CommandLine:
        python -m clab.live.urban_train urban_fit --profile

        python -m clab.live.urban_train urban_fit --task=urban_mapper_3d --arch=segnet

        python -m clab.live.urban_train urban_fit --task=urban_mapper_3d --arch=unet --noaux
        python -m clab.live.urban_train urban_fit --task=urban_mapper_3d --arch=unet

        python -m clab.live.urban_train urban_fit --task=urban_mapper_3d --dry

        python -m clab.live.urban_train urban_fit --task=urban_mapper_3d --arch=unet --colorspace=RGB --combine


        python -m clab.live.urban_train urban_fit --task=urban_mapper_3d --arch=unet --dry

        python -m clab.live.urban_train urban_fit --task=urban_mapper_3d --arch=unet2 --colorspace=RGB --combine
        python -m clab.live.urban_train urban_fit --task=urban_mapper_3d --arch=unet2 --colorspace=RGB --use_aux_diff

        python -m clab.live.urban_train urban_fit --task=urban_mapper_3d --arch=dense_unet --colorspace=RGB --use_aux_diff


        # Train a variant of the dense net with more parameters
        python -m clab.live.urban_train urban_fit --task=urban_mapper_3d --arch=dense_unet --colorspace=RGB --use_aux_diff --combine \
                --pretrained '/home/local/KHQ/jon.crall/data/work/urban_mapper4/arch/dense_unet/train/input_25800-phpjjsqu/solver_25800-phpjjsqu_dense_unet_mmavmuou_zeosddyf_a=1,c=RGB,n_ch=6,n_cl=4/torch_snapshots/_epoch_00000030.pt' --gpu=1

        # Fine tune the model using all the available data
        python -m clab.live.urban_train urban_fit --task=urban_mapper_3d --arch=unet2 --colorspace=RGB --use_aux_diff --combine \
                --pretrained '/home/local/KHQ/jon.crall/data/work/urban_mapper2/arch/unet2/train/input_25800-hemanvft/solver_25800-hemanvft_unet2_mmavmuou_stuyuerd_a=1,c=RGB,n_ch=6,n_cl=4/torch_snapshots/_epoch_00000041.pt' --gpu=3 --finetune


        # Keep a bit of the data for validation but use more
        python -m clab.live.urban_train urban_fit --task=urban_mapper_3d --arch=unet2 --colorspace=RGB --use_aux_diff --halfcombo \
                --pretrained '/home/local/KHQ/jon.crall/data/work/urban_mapper2/arch/unet2/train/input_25800-hemanvft/solver_25800-hemanvft_unet2_mmavmuou_stuyuerd_a=1,c=RGB,n_ch=6,n_cl=4/torch_snapshots/_epoch_00000041.pt' --gpu=3

    Example:
        >>> from clab.torch.fit_harness import *
        >>> harn = urban_fit()
    """
    # Architecture and colorspace are driven by command-line flags.
    arch = ub.argval('--arch', default='unet')
    colorspace = ub.argval('--colorspace', default='RGB').upper()

    datasets = load_task_dataset('urban_mapper_3d',
                                 colorspace=colorspace,
                                 arch=arch)
    datasets['train'].augment = True

    # Make sure we use consistent normalization
    # TODO: give normalization a part of the hashid
    # TODO: save normalization type with the model
    # datasets['train'].center_inputs = datasets['train']._make_normalizer()

    # if ub.argflag('--combine'):
    #     # custom centering from the initialization point I'm going to use
    #     datasets['train'].center_inputs = datasets['train']._custom_urban_mapper_normalizer(
    #         0.3750553785198646, 1.026544662398811, 2.5136079110849674)
    # else:
    # datasets['train'].center_inputs = datasets['train']._make_normalizer(mode=2)
    datasets['train'].center_inputs = datasets['train']._make_normalizer(
        mode=3)
    # datasets['train'].center_inputs = _custom_urban_mapper_normalizer(0, 1, 2.5)

    # Share the train split's normalizer with test/vali so all splits are
    # centered identically.
    datasets['test'].center_inputs = datasets['train'].center_inputs
    datasets['vali'].center_inputs = datasets['train'].center_inputs

    # Ensure normalization is the same for each dataset
    datasets['train'].augment = True

    # turn off aux layers
    if ub.argflag('--noaux'):
        for v in datasets.values():
            v.aux_keys = []

    # Batch size is tuned per-arch to fit GPU memory (see measurements below).
    batch_size = 14
    if arch == 'segnet':
        batch_size = 6
    elif arch == 'dense_unet':
        batch_size = 6
        # dense_unet batch memsizes
        # idle =   11 MiB
        # 0    =  438 MiB
        # 3   ~= 5000 MiB
        # 5    = 8280 MiB
        # 6    = 9758 MiB
        # each image adds (1478 - 1568.4) MiB

    n_classes = datasets['train'].n_classes
    n_channels = datasets['train'].n_channels
    class_weights = datasets['train'].class_weights()
    ignore_label = datasets['train'].ignore_label

    print('n_classes = {!r}'.format(n_classes))
    print('n_channels = {!r}'.format(n_channels))
    print('batch_size = {!r}'.format(batch_size))

    hyper = hyperparams.HyperParams(
        criterion=(
            criterions.CrossEntropyLoss2D,
            {
                'ignore_label': ignore_label,
                # TODO: weight should be a FloatTensor
                'weight': class_weights,
            }),
        optimizer=(
            torch.optim.SGD,
            {
                # 'weight_decay': .0006,
                'weight_decay': .0005,
                'momentum': 0.99 if arch == 'dense_unet' else .9,
                'nesterov': True,
            }),
        scheduler=(
            'Exponential',
            {
                'gamma': 0.99,
                # 'base_lr': 0.0015,
                'base_lr': 0.001 if not ub.argflag('--halfcombo') else 0.0005,
                'stepsize': 2,
            }),
        other={
            'n_classes': n_classes,
            'n_channels': n_channels,
            'augment': datasets['train'].augment,
            'colorspace': datasets['train'].colorspace,
        })

    # Known checkpoints usable as warm-start initialization points.
    starting_points = {
        'unet_rgb_4k':
        ub.truepath(
            '~/remote/aretha/data/work/urban_mapper/arch/unet/train/input_4214-yxalqwdk/solver_4214-yxalqwdk_unet_vgg_nttxoagf_a=1,n_ch=5,n_cl=3/torch_snapshots/_epoch_00000236.pt'
        ),

        # 'unet_rgb_8k': ub.truepath('~/remote/aretha/data/work/urban_mapper/arch/unet/train/input_8438-haplmmpq/solver_8438-haplmmpq_unet_None_kvterjeu_a=1,c=RGB,n_ch=5,n_cl=3/torch_snapshots/_epoch_00000402.pt'),
        # "ImageCenterScale", {"im_mean": [[[0.3750553785198646]]], "im_scale": [[[1.026544662398811]]]}
        # "DTMCenterScale", "std": 2.5136079110849674, "nan_value": -32767.0 }

        # 'unet_rgb_8k': ub.truepath(
        #     '~/data/work/urban_mapper2/arch/unet/train/input_4214-guwsobde/'
        #     'solver_4214-guwsobde_unet_mmavmuou_eqnoygqy_a=1,c=RGB,n_ch=5,n_cl=4/torch_snapshots/_epoch_00000189.pt'
        # )
        'unet_rgb_8k':
        ub.truepath(
            '~/remote/aretha/data/work/urban_mapper2/arch/unet2/train/input_4214-guwsobde/'
            'solver_4214-guwsobde_unet2_mmavmuou_tqynysqo_a=1,c=RGB,n_ch=5,n_cl=4/torch_snapshots/_epoch_00000100.pt'
        )
    }

    # Resolve pretrained weights: an explicit --pretrained path wins,
    # otherwise pick a starting point based on arch and --combine.
    pretrained = ub.argval('--pretrained', default=None)
    if pretrained is None:
        if arch == 'segnet':
            pretrained = 'vgg'
        else:
            pretrained = None
            if ub.argflag('--combine'):
                pretrained = starting_points['unet_rgb_8k']

                if arch == 'unet2':
                    pretrained = '/home/local/KHQ/jon.crall/data/work/urban_mapper2/arch/unet2/train/input_25800-hemanvft/solver_25800-hemanvft_unet2_mmavmuou_stuyuerd_a=1,c=RGB,n_ch=6,n_cl=4/torch_snapshots/_epoch_00000042.pt'
                # NOTE(review): 'dense_unet2' reuses the same unet2 snapshot
                # path as the branch above — confirm this is intentional.
                elif arch == 'dense_unet2':
                    pretrained = '/home/local/KHQ/jon.crall/data/work/urban_mapper2/arch/unet2/train/input_25800-hemanvft/solver_25800-hemanvft_unet2_mmavmuou_stuyuerd_a=1,c=RGB,n_ch=6,n_cl=4/torch_snapshots/_epoch_00000042.pt'
            else:
                pretrained = starting_points['unet_rgb_4k']

    train_dpath = directory_structure(datasets['train'].task.workdir,
                                      arch,
                                      datasets,
                                      pretrained=pretrained,
                                      train_hyper_id=hyper.hyper_id(),
                                      suffix='_' + hyper.other_id())

    print('arch = {!r}'.format(arch))
    dry = ub.argflag('--dry')
    # Build the requested architecture; most arches warm-start by loading a
    # partial state dict from the resolved pretrained snapshot.
    if dry:
        model = None
    elif arch == 'segnet':
        model = models.SegNet(in_channels=n_channels, n_classes=n_classes)
        model.init_he_normal()
        if pretrained == 'vgg':
            model.init_vgg16_params()
    elif arch == 'linknet':
        model = models.LinkNet(in_channels=n_channels, n_classes=n_classes)
    elif arch == 'unet':
        model = models.UNet(in_channels=n_channels,
                            n_classes=n_classes,
                            nonlinearity='leaky_relu')
        snapshot = xpu_device.XPU(None).load(pretrained)
        model_state_dict = snapshot['model_state_dict']
        model.load_partial_state(model_state_dict)
        # model.shock_outward()
    elif arch == 'unet2':
        from clab.live import unet2
        model = unet2.UNet2(n_alt_classes=3,
                            in_channels=n_channels,
                            n_classes=n_classes,
                            nonlinearity='leaky_relu')
        snapshot = xpu_device.XPU(None).load(pretrained)
        model_state_dict = snapshot['model_state_dict']
        model.load_partial_state(model_state_dict)

    elif arch == 'dense_unet':
        from clab.live import unet3
        model = unet3.DenseUNet(n_alt_classes=3,
                                in_channels=n_channels,
                                n_classes=n_classes)
        model.init_he_normal()
        snapshot = xpu_device.XPU(None).load(pretrained)
        model_state_dict = snapshot['model_state_dict']
        model.load_partial_state(model_state_dict)
    elif arch == 'dense_unet2':
        from clab.live import unet3
        model = unet3.DenseUNet2(n_alt_classes=3,
                                 in_channels=n_channels,
                                 n_classes=n_classes)
        # model.init_he_normal()
        snapshot = xpu_device.XPU(None).load(pretrained)
        model_state_dict = snapshot['model_state_dict']
        model.load_partial_state(model_state_dict, shock_partial=False)
    elif arch == 'dummy':
        model = models.SSegDummy(in_channels=n_channels, n_classes=n_classes)
    else:
        raise ValueError('unknown arch')

    if ub.argflag('--finetune'):
        # Hack in a reduced learning rate
        # (rebuilds hyperparams with a constant small base_lr)
        hyper = hyperparams.HyperParams(
            criterion=(
                criterions.CrossEntropyLoss2D,
                {
                    'ignore_label': ignore_label,
                    # TODO: weight should be a FloatTensor
                    'weight': class_weights,
                }),
            optimizer=(
                torch.optim.SGD,
                {
                    # 'weight_decay': .0006,
                    'weight_decay': .0005,
                    'momentum': 0.99 if arch == 'dense_unet' else .9,
                    'nesterov': True,
                }),
            scheduler=('Constant', {
                'base_lr': 0.0001,
            }),
            other={
                'n_classes': n_classes,
                'n_channels': n_channels,
                'augment': datasets['train'].augment,
                'colorspace': datasets['train'].colorspace,
            })

    # Select the compute device from command-line flags.
    xpu = xpu_device.XPU.from_argv()

    if datasets['train'].use_aux_diff:
        # arch in ['unet2', 'dense_unet']:

        from clab.live import fit_harn2
        harn = fit_harn2.FitHarness(
            model=model,
            hyper=hyper,
            datasets=datasets,
            xpu=xpu,
            train_dpath=train_dpath,
            dry=dry,
            batch_size=batch_size,
        )
        # Criterion for the auxiliary output head: class weights [.1, 1, 0]
        # with label 2 ignored.
        harn.criterion2 = criterions.CrossEntropyLoss2D(
            weight=torch.FloatTensor([.1, 1, 0]), ignore_label=2)

        def compute_loss(harn, outputs, labels):
            # Weighted sum of the primary (.45) and auxiliary (.55) losses.
            output1, output2 = outputs
            label1, label2 = labels

            # Compute the loss
            loss1 = harn.criterion(output1, label1)
            loss2 = harn.criterion2(output2, label2)
            loss = (.45 * loss1 + .55 * loss2)
            return loss

        harn.compute_loss = compute_loss

        # z = harn.loaders['train']
        # b = next(iter(z))
        # print('b = {!r}'.format(b))
        # import sys
        # sys.exit(0)

        def custom_metrics(harn, output, label):
            # Two-headed model: report metrics on the auxiliary head (index 1).
            ignore_label = datasets['train'].ignore_label
            labels = datasets['train'].task.labels

            metrics_dict = metrics._sseg_metrics(output[1],
                                                 label[1],
                                                 labels=labels,
                                                 ignore_label=ignore_label)
            return metrics_dict
    else:
        harn = fit_harness.FitHarness(
            model=model,
            hyper=hyper,
            datasets=datasets,
            xpu=xpu,
            train_dpath=train_dpath,
            dry=dry,
            batch_size=batch_size,
        )

        def custom_metrics(harn, output, label):
            # Single-headed model: report metrics on the sole output.
            ignore_label = datasets['train'].ignore_label
            labels = datasets['train'].task.labels

            metrics_dict = metrics._sseg_metrics(output,
                                                 label,
                                                 labels=labels,
                                                 ignore_label=ignore_label)
            return metrics_dict

    harn.add_metric_hook(custom_metrics)

    # HACK
    # im = datasets['train'][0][0]
    # w, h = im.shape[-2:]
    # single_output_shape = (n_classes, w, h)
    # harn.single_output_shape = single_output_shape
    # print('harn.single_output_shape = {!r}'.format(harn.single_output_shape))

    harn.run()
    return harn