Ejemplo n.º 1
0
def net(net_params, num_channels, inference=False):
    """Build the model named in the config and load its checkpoint, if any.

    Args:
        net_params: Nested config mapping. ``global`` supplies ``model_name``,
            ``number_of_bands`` and the optional ``coordconv_*`` keys;
            ``training`` supplies the optional ``state_dict_path``,
            ``pretrained``, ``dropout`` and ``dropout_prob`` keys;
            ``inference`` supplies ``state_dict_path`` when ``inference`` is
            true.
        num_channels: Number of output channels passed to the model builder.
        inference: When true, pretrained weights are disabled and the
            checkpoint is read from the ``inference`` section.

    Returns:
        Tuple ``(model, checkpoint, model_name)``. ``checkpoint`` is ``None``
        when neither an inference nor a training state dict path applies.

    Raises:
        AssertionError: If the band count is not supported by the chosen model
            or a configured checkpoint file does not exist.
        ValueError: If ``model_name`` is not one of the supported names.
    """
    model_name = net_params['global']['model_name'].lower()
    num_bands = int(net_params['global']['number_of_bands'])
    msg = 'Number of bands specified incompatible with this model. Requires 3 band data.'
    # deeplabv3 accepts 3 or 4 bands, so it gets its own (accurate) message.
    deeplab_msg = 'Number of bands specified incompatible with this model. Requires 3 or 4 band data.'
    train_state_dict_path = get_key_def('state_dict_path',
                                        net_params['training'], None)
    pretrained = get_key_def('pretrained', net_params['training'],
                             True) if not inference else False
    dropout = get_key_def('dropout', net_params['training'], False)
    dropout_prob = get_key_def('dropout_prob', net_params['training'], 0.5)

    # Band checks raise AssertionError explicitly (instead of using `assert`)
    # so they keep the same exception type but still run under `python -O`.
    if model_name == 'unetsmall':
        model = unet.UNetSmall(num_channels, num_bands, dropout, dropout_prob)
    elif model_name == 'unet':
        model = unet.UNet(num_channels, num_bands, dropout, dropout_prob)
    elif model_name == 'ternausnet':
        if num_bands != 3:
            raise AssertionError(msg)
        model = TernausNet.ternausnet(num_channels)
    elif model_name == 'checkpointed_unet':
        model = checkpointed_unet.UNetSmall(num_channels, num_bands, dropout,
                                            dropout_prob)
    elif model_name == 'inception':
        model = inception.Inception3(num_channels, num_bands)
    elif model_name == 'fcn_resnet101':
        if num_bands != 3:
            raise AssertionError(msg)
        model = models.segmentation.fcn_resnet101(pretrained=False,
                                                  progress=True,
                                                  num_classes=num_channels,
                                                  aux_loss=None)
    elif model_name == 'deeplabv3_resnet101':
        if num_bands not in (3, 4):
            raise AssertionError(deeplab_msg)
        if pretrained:
            print(f'Finetuning pretrained deeplabv3 with {num_bands} bands')
        else:
            print(f'Building deeplabv3 with {num_bands} bands')
        # Honor the configured `pretrained` flag (forced off at inference)
        # instead of unconditionally requesting the pretrained weights.
        model = models.segmentation.deeplabv3_resnet101(pretrained=pretrained,
                                                        progress=True,
                                                        aux_loss=None)
        if num_bands == 4:
            # Widen the first convolution from 3 to 4 input channels: keep the
            # existing kernels and append one randomly initialised kernel.
            conv1 = model.backbone._modules['conv1'].weight.detach().numpy()
            depth = np.random.uniform(low=-1, high=1, size=(64, 1, 7, 7))
            conv1 = np.append(conv1, depth, axis=1)
            conv1 = torch.from_numpy(conv1).float()
            model.backbone._modules['conv1'].weight = nn.Parameter(
                conv1, requires_grad=True)
        # Fresh head sized for the requested number of output channels.
        model.classifier = common.DeepLabHead(2048, num_channels)
    else:
        raise ValueError(
            f'The model name {model_name} in the config.yaml is not defined.')

    coordconv_convert = get_key_def('coordconv_convert', net_params['global'],
                                    False)
    if coordconv_convert:
        centered = get_key_def('coordconv_centered', net_params['global'],
                               True)
        normalized = get_key_def('coordconv_normalized', net_params['global'],
                                 True)
        noise = get_key_def('coordconv_noise', net_params['global'], None)
        radius_channel = get_key_def('coordconv_radius_channel',
                                     net_params['global'], False)
        scale = get_key_def('coordconv_scale', net_params['global'], 1.0)
        # note: this operation will not attempt to preserve already-loaded model parameters!
        model = coordconv.swap_coordconv_layers(model,
                                                centered=centered,
                                                normalized=normalized,
                                                noise=noise,
                                                radius_channel=radius_channel,
                                                scale=scale)

    # The inference checkpoint takes precedence over the training one.
    if inference:
        state_dict_path = net_params['inference']['state_dict_path']
        if not Path(state_dict_path).is_file():
            raise AssertionError(f"Could not locate {state_dict_path}")
        checkpoint = load_checkpoint(state_dict_path)
    elif train_state_dict_path is not None:
        if not Path(train_state_dict_path).is_file():
            raise AssertionError(
                f'Could not locate checkpoint at {train_state_dict_path}')
        checkpoint = load_checkpoint(train_state_dict_path)
    else:
        checkpoint = None

    return model, checkpoint, model_name
Ejemplo n.º 2
0
def net(net_params, num_channels, inference=False):
    """Build the model named in the config and load its checkpoint, if any.

    Args:
        net_params: Nested config mapping. ``global`` supplies ``model_name``,
            ``number_of_bands``, the optional ``concatenate_depth`` and the
            optional ``coordconv_*`` keys; ``training`` supplies the optional
            ``state_dict_path``, ``pretrained``, ``dropout`` and
            ``dropout_prob`` keys; ``inference`` supplies ``state_dict_path``
            when ``inference`` is true.
        num_channels: Number of output channels passed to the model builder.
        inference: When true, pretrained weights are disabled and the
            checkpoint is read from the ``inference`` section.

    Returns:
        Tuple ``(model, checkpoint, model_name)``. ``checkpoint`` is ``None``
        when neither an inference nor a training state dict path applies.

    Raises:
        AssertionError: If the band count is not supported by the chosen model
            or a configured checkpoint file does not exist.
        ValueError: If ``model_name`` is not supported, or a 4-band
            deeplabv3_resnet101 is requested without ``concatenate_depth``.
    """
    model_name = net_params['global']['model_name'].lower()
    num_bands = int(net_params['global']['number_of_bands'])
    msg = 'Number of bands specified incompatible with this model. Requires 3 band data.'
    # deeplabv3 accepts 3 or 4 bands, so it gets its own (accurate) message.
    deeplab_msg = 'Number of bands specified incompatible with this model. Requires 3 or 4 band data.'
    train_state_dict_path = get_key_def('state_dict_path',
                                        net_params['training'], None)
    pretrained = get_key_def('pretrained', net_params['training'],
                             True) if not inference else False
    dropout = get_key_def('dropout', net_params['training'], False)
    dropout_prob = get_key_def('dropout_prob', net_params['training'], 0.5)

    # TODO: find a way to maybe implement it in classification one day
    # Always bind conc_point so the 4-band path cannot hit an unbound local
    # when the config omits the concatenation point.
    has_conc_point = 'concatenate_depth' in net_params['global']
    conc_point = (net_params['global']['concatenate_depth']
                  if has_conc_point else None)

    # Band checks raise AssertionError explicitly (instead of using `assert`)
    # so they keep the same exception type but still run under `python -O`.
    if model_name == 'unetsmall':
        model = unet.UNetSmall(num_channels, num_bands, dropout, dropout_prob)
    elif model_name == 'unet':
        model = unet.UNet(num_channels, num_bands, dropout, dropout_prob)
    elif model_name == 'ternausnet':
        if num_bands != 3:
            raise AssertionError(msg)
        model = TernausNet.ternausnet(num_channels)
    elif model_name == 'checkpointed_unet':
        model = checkpointed_unet.UNetSmall(num_channels, num_bands, dropout,
                                            dropout_prob)
    elif model_name == 'inception':
        model = inception.Inception3(num_channels, num_bands)
    elif model_name == 'fcn_resnet101':
        if num_bands != 3:
            raise AssertionError(msg)
        # NOTE(review): pretrained weights are requested together with a
        # custom num_classes; confirm the installed torchvision accepts that
        # combination when num_channels differs from the pretrained head.
        model = models.segmentation.fcn_resnet101(pretrained=pretrained,
                                                  progress=True,
                                                  num_classes=num_channels,
                                                  aux_loss=None)
    elif model_name == 'deeplabv3_resnet101':
        if num_bands not in (3, 4):
            raise AssertionError(deeplab_msg)
        if num_bands == 4 and not has_conc_point:
            raise ValueError(
                "A 4-band deeplabv3_resnet101 requires 'concatenate_depth' "
                "in the 'global' section of the config.")

        print(f'Finetuning pretrained deeplabv3 with {num_bands} bands')
        if num_bands == 4:
            print('Testing with 4 bands, concatenating at {}.'.format(
                conc_point))

        model = models.segmentation.deeplabv3_resnet101(
            pretrained=pretrained, progress=True)

        widen_conv1 = num_bands == 4 and conc_point == 'baseline'
        if widen_conv1:
            # Baseline: widen the first convolution from 3 to 4 input channels
            # by duplicating the kernel at channel index 1.
            conv1 = model.backbone._modules['conv1'].weight.detach().numpy()
            depth = np.expand_dims(
                conv1[:, 1, ...], axis=1)  # reuse green weights for infrared.
            conv1 = np.append(conv1, depth, axis=1)
            conv1 = torch.from_numpy(conv1).float()
            model.backbone._modules['conv1'].weight = nn.Parameter(
                conv1, requires_grad=True)

        # Replace the last classifier layer with a 1x1 convolution sized for
        # the requested number of output channels (same for every variant).
        classifier = list(model.classifier.children())
        model.classifier = nn.Sequential(*classifier[:-1])
        model.classifier.add_module(
            '4',
            nn.Conv2d(classifier[-1].in_channels,
                      num_channels,
                      kernel_size=(1, 1)))

        if num_bands == 4 and not widen_conv1:
            # conv1 is left untouched here; presumably LayersEnsemble feeds
            # the extra band in at `conc_point` -- TODO confirm.
            model = LayersEnsemble(model, conc_point=conc_point)

    elif model_name in lm_smp:
        lsmp = lm_smp[model_name]
        # TODO: add possibility of our own weights
        lsmp['params'][
            'encoder_weights'] = "imagenet" if 'pretrained' in model_name.split(
                "_") else None
        lsmp['params']['in_channels'] = num_bands
        lsmp['params']['classes'] = num_channels
        lsmp['params']['activation'] = None

        model = lsmp['fct'](**lsmp['params'])

    else:
        raise ValueError(
            f'The model name {model_name} in the config.yaml is not defined.')

    coordconv_convert = get_key_def('coordconv_convert', net_params['global'],
                                    False)
    if coordconv_convert:
        centered = get_key_def('coordconv_centered', net_params['global'],
                               True)
        normalized = get_key_def('coordconv_normalized', net_params['global'],
                                 True)
        noise = get_key_def('coordconv_noise', net_params['global'], None)
        radius_channel = get_key_def('coordconv_radius_channel',
                                     net_params['global'], False)
        scale = get_key_def('coordconv_scale', net_params['global'], 1.0)
        # note: this operation will not attempt to preserve already-loaded model parameters!
        model = coordconv.swap_coordconv_layers(model,
                                                centered=centered,
                                                normalized=normalized,
                                                noise=noise,
                                                radius_channel=radius_channel,
                                                scale=scale)

    # The inference checkpoint takes precedence over the training one.
    if inference:
        state_dict_path = net_params['inference']['state_dict_path']
        if not Path(state_dict_path).is_file():
            raise AssertionError(f"Could not locate {state_dict_path}")
        checkpoint = load_checkpoint(state_dict_path)
    elif train_state_dict_path is not None:
        if not Path(train_state_dict_path).is_file():
            raise AssertionError(
                f'Could not locate checkpoint at {train_state_dict_path}')
        checkpoint = load_checkpoint(train_state_dict_path)
    else:
        checkpoint = None

    return model, checkpoint, model_name
Ejemplo n.º 3
0
def net(net_params, num_channels, inference=False):
    """Build the model named in the config and pick the checkpoint to load.

    Args:
        net_params: Nested config mapping. ``global`` supplies ``model_name``,
            ``number_of_bands`` and the optional ``coordconv_*`` keys;
            ``training`` supplies the optional ``state_dict_path``,
            ``pretrained``, ``dropout`` and ``dropout_prob`` keys;
            ``inference`` supplies ``state_dict_path`` when ``inference`` is
            true.
        num_channels: Number of output channels passed to the model builder.
        inference: When true, pretrained weights are disabled and the
            checkpoint is read from the ``inference`` section.

    Returns:
        Tuple ``(model, checkpoint, model_name)``. ``checkpoint`` is, in order
        of precedence: the inference checkpoint, the training checkpoint, a
        ``{'model': state_dict}`` mapping built from the torchvision
        pretrained (21-class) weights for ``deeplabv3_resnet101`` /
        ``fcn_resnet101`` when ``pretrained`` is set, or ``None``.

    Raises:
        AssertionError: If the band count is not supported by the chosen model
            or a configured checkpoint file does not exist.
        ValueError: If ``model_name`` is not one of the supported names.
    """
    model_name = net_params['global']['model_name'].lower()
    num_bands = int(net_params['global']['number_of_bands'])
    msg = 'Number of bands specified incompatible with this model. Requires 3 band data.'
    train_state_dict_path = get_key_def('state_dict_path', net_params['training'], None)
    pretrained = get_key_def('pretrained', net_params['training'], True) if not inference else False
    dropout = get_key_def('dropout', net_params['training'], False)
    dropout_prob = get_key_def('dropout_prob', net_params['training'], 0.5)

    # Band checks raise AssertionError explicitly (instead of using `assert`)
    # so they keep the same exception type but still run under `python -O`.
    if model_name == 'unetsmall':
        model = unet.UNetSmall(num_channels, num_bands, dropout, dropout_prob)
    elif model_name == 'unet':
        model = unet.UNet(num_channels, num_bands, dropout, dropout_prob)
    elif model_name == 'ternausnet':
        if num_bands != 3:
            raise AssertionError(msg)
        model = TernausNet.ternausnet(num_channels)
    elif model_name == 'checkpointed_unet':
        model = checkpointed_unet.UNetSmall(num_channels, num_bands, dropout, dropout_prob)
    elif model_name == 'inception':
        model = inception.Inception3(num_channels, num_bands)
    elif model_name == 'fcn_resnet101':
        if num_bands != 3:
            raise AssertionError(msg)
        model = models.segmentation.fcn_resnet101(pretrained=False, progress=True, num_classes=num_channels,
                                                  aux_loss=None)
    elif model_name == 'deeplabv3_resnet101':
        try:
            # Only works when the builder accepts an `in_channels` argument.
            model = models.segmentation.deeplabv3_resnet101(pretrained=False, progress=True, in_channels=num_bands,
                                                            num_classes=num_channels, aux_loss=None)
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # are no longer swallowed. Fall back to the 3-band-only builder.
            if num_bands != 3:
                raise AssertionError('Edit torchvision scripts segmentation.py and resnet.py to build '
                                     'deeplabv3_resnet with more or less than 3 bands')
            model = models.segmentation.deeplabv3_resnet101(pretrained=False, progress=True,
                                                            num_classes=num_channels, aux_loss=None)
    else:
        raise ValueError(f'The model name {model_name} in the config.yaml is not defined.')

    coordconv_convert = get_key_def('coordconv_convert', net_params['global'], False)
    if coordconv_convert:
        centered = get_key_def('coordconv_centered', net_params['global'], True)
        normalized = get_key_def('coordconv_normalized', net_params['global'], True)
        noise = get_key_def('coordconv_noise', net_params['global'], None)
        radius_channel = get_key_def('coordconv_radius_channel', net_params['global'], False)
        scale = get_key_def('coordconv_scale', net_params['global'], 1.0)
        # note: this operation will not attempt to preserve already-loaded model parameters!
        model = coordconv.swap_coordconv_layers(model, centered=centered, normalized=normalized, noise=noise,
                                                radius_channel=radius_channel, scale=scale)

    if inference:
        state_dict_path = net_params['inference']['state_dict_path']
        if not Path(state_dict_path).is_file():
            raise AssertionError(f"Could not locate {state_dict_path}")
        checkpoint = load_checkpoint(state_dict_path)
    elif train_state_dict_path is not None:
        if not Path(train_state_dict_path).is_file():
            raise AssertionError(f'Could not locate checkpoint at {train_state_dict_path}')
        checkpoint = load_checkpoint(train_state_dict_path)
    elif pretrained and model_name in ('deeplabv3_resnet101', 'fcn_resnet101'):
        # Membership test: the previous `== ('a' or 'b')` only ever compared
        # against 'deeplabv3_resnet101', so fcn_resnet101 never got here.
        print(f'Retrieving coco checkpoint for {model_name}...\n')
        if model_name == 'deeplabv3_resnet101':  # default to pretrained on coco (21 classes)
            coco_model = models.segmentation.deeplabv3_resnet101(pretrained=True, progress=True, num_classes=21,
                                                                 aux_loss=None)
        else:
            coco_model = models.segmentation.fcn_resnet101(pretrained=True, progress=True, num_classes=21,
                                                           aux_loss=None)
        # Place entire state_dict inside 'model' key for compatibility with the rest of GDL workflow
        checkpoint = {'model': dict(coco_model.state_dict())}
        del coco_model
    elif pretrained:
        warnings.warn(f'No pretrained checkpoint found for {model_name}.')
        checkpoint = None
    else:
        checkpoint = None

    return model, checkpoint, model_name
Ejemplo n.º 4
0
def net(net_params, num_channels, inference=False):
    """Build the model named in the config and prepare it for inference or training.

    Args:
        net_params: Nested config mapping. ``global`` supplies ``model_name``,
            ``number_of_bands``, ``num_gpus``, the optional
            ``concatenate_depth`` and the optional ``coordconv_*`` keys;
            ``training`` supplies the optional ``state_dict_path``,
            ``pretrained``, ``dropout``, ``dropout_prob`` and ``ignore_index``
            keys; ``inference`` supplies ``state_dict_path`` when
            ``inference`` is true. ``training.ignore_index`` is rewritten to
            -1 in place when it is configured as 0.
        num_channels: Number of output channels passed to the model builder.
        inference: Selects the inference path (see Returns).

    Returns:
        The return shape depends on ``inference``:

        * inference: ``(model, checkpoint, model_name)``; the model is not
          moved to a device here.
        * training: ``(model, model_name, criterion, optimizer,
          lr_scheduler)``; the model (wrapped in ``nn.DataParallel`` when more
          than one device is used) and the criterion are moved to the selected
          device.

    Raises:
        AssertionError: If the band count is not supported by the chosen
            model, a configured checkpoint file does not exist, or
            ``num_gpus`` is missing/negative on the training path.
        ValueError: If ``model_name`` is not supported, or a 4-band
            deeplabv3_resnet101 is requested without ``concatenate_depth``.
    """
    model_name = net_params['global']['model_name'].lower()
    num_bands = int(net_params['global']['number_of_bands'])
    msg = 'Number of bands specified incompatible with this model. Requires 3 band data.'
    # deeplabv3 accepts 3 or 4 bands, so it gets its own (accurate) message.
    deeplab_msg = 'Number of bands specified incompatible with this model. Requires 3 or 4 band data.'
    train_state_dict_path = get_key_def('state_dict_path', net_params['training'], None)
    pretrained = get_key_def('pretrained', net_params['training'], True) if not inference else False
    dropout = get_key_def('dropout', net_params['training'], False)
    dropout_prob = get_key_def('dropout_prob', net_params['training'], 0.5)
    dontcare_val = get_key_def("ignore_index", net_params["training"], -1)
    num_devices = net_params['global']['num_gpus']

    if dontcare_val == 0:
        warnings.warn("The 'dontcare' value (or 'ignore_index') used in the loss function cannot be zero;"
                      " all valid class indices should be consecutive, and start at 0. The 'dontcare' value"
                      " will be remapped to -1 while loading the dataset, and inside the config from now on.")
        net_params["training"]["ignore_index"] = -1
        # NOTE(review): only the config entry is remapped; the local
        # `dontcare_val` handed to set_hyperparameters below is still 0.
        # Confirm that this is intended.

    # TODO: find a way to maybe implement it in classification one day
    # Always bind conc_point so the 4-band path cannot hit an unbound local
    # when the config omits the concatenation point.
    has_conc_point = 'concatenate_depth' in net_params['global']
    conc_point = net_params['global']['concatenate_depth'] if has_conc_point else None

    # Band checks raise AssertionError explicitly (instead of using `assert`)
    # so they keep the same exception type but still run under `python -O`.
    if model_name == 'unetsmall':
        model = unet.UNetSmall(num_channels, num_bands, dropout, dropout_prob)
    elif model_name == 'unet':
        model = unet.UNet(num_channels, num_bands, dropout, dropout_prob)
    elif model_name == 'ternausnet':
        if num_bands != 3:
            raise AssertionError(msg)
        model = TernausNet.ternausnet(num_channels)
    elif model_name == 'checkpointed_unet':
        model = checkpointed_unet.UNetSmall(num_channels, num_bands, dropout, dropout_prob)
    elif model_name == 'inception':
        model = inception.Inception3(num_channels, num_bands)
    elif model_name == 'fcn_resnet101':
        if num_bands != 3:
            raise AssertionError(msg)
        model = models.segmentation.fcn_resnet101(pretrained=False, progress=True, num_classes=num_channels,
                                                  aux_loss=None)
    elif model_name == 'deeplabv3_resnet101':
        if num_bands not in (3, 4):
            raise AssertionError(deeplab_msg)
        if num_bands == 4 and not has_conc_point:
            raise ValueError("A 4-band deeplabv3_resnet101 requires 'concatenate_depth' "
                             "in the 'global' section of the config.")

        print(f'Finetuning pretrained deeplabv3 with {num_bands} bands')
        if num_bands == 4:
            print('Testing with 4 bands, concatenating at {}.'.format(conc_point))

        model = models.segmentation.deeplabv3_resnet101(pretrained=pretrained, progress=True)

        widen_conv1 = num_bands == 4 and conc_point == 'baseline'
        if widen_conv1:
            # Baseline: widen the first convolution from 3 to 4 input channels
            # by duplicating the kernel at channel index 1.
            conv1 = model.backbone._modules['conv1'].weight.detach().numpy()
            depth = np.expand_dims(conv1[:, 1, ...], axis=1)  # reuse green weights for infrared.
            conv1 = np.append(conv1, depth, axis=1)
            conv1 = torch.from_numpy(conv1).float()
            model.backbone._modules['conv1'].weight = nn.Parameter(conv1, requires_grad=True)

        # Replace the last classifier layer with a 1x1 convolution sized for
        # the requested number of output channels (same for every variant).
        classifier = list(model.classifier.children())
        model.classifier = nn.Sequential(*classifier[:-1])
        model.classifier.add_module('4', nn.Conv2d(classifier[-1].in_channels, num_channels, kernel_size=(1, 1)))

        if num_bands == 4 and not widen_conv1:
            # conv1 is left untouched here; presumably LayersEnsemble feeds
            # the extra band in at `conc_point` -- TODO confirm.
            model = LayersEnsemble(model, conc_point=conc_point)

    elif model_name in lm_smp:
        lsmp = lm_smp[model_name]
        # TODO: add possibility of our own weights
        lsmp['params']['encoder_weights'] = "imagenet" if 'pretrained' in model_name.split("_") else None
        lsmp['params']['in_channels'] = num_bands
        lsmp['params']['classes'] = num_channels
        lsmp['params']['activation'] = None

        model = lsmp['fct'](**lsmp['params'])

    else:
        raise ValueError(f'The model name {model_name} in the config.yaml is not defined.')

    coordconv_convert = get_key_def('coordconv_convert', net_params['global'], False)
    if coordconv_convert:
        centered = get_key_def('coordconv_centered', net_params['global'], True)
        normalized = get_key_def('coordconv_normalized', net_params['global'], True)
        noise = get_key_def('coordconv_noise', net_params['global'], None)
        radius_channel = get_key_def('coordconv_radius_channel', net_params['global'], False)
        scale = get_key_def('coordconv_scale', net_params['global'], 1.0)
        # note: this operation will not attempt to preserve already-loaded model parameters!
        model = coordconv.swap_coordconv_layers(model, centered=centered, normalized=normalized, noise=noise,
                                                radius_channel=radius_channel, scale=scale)

    if inference:
        state_dict_path = net_params['inference']['state_dict_path']
        if not Path(state_dict_path).is_file():
            raise AssertionError(f"Could not locate {state_dict_path}")
        checkpoint = load_checkpoint(state_dict_path)

        return model, checkpoint, model_name

    # ---- training path ----
    if train_state_dict_path is not None:
        if not Path(train_state_dict_path).is_file():
            raise AssertionError(f'Could not locate checkpoint at {train_state_dict_path}')
        checkpoint = load_checkpoint(train_state_dict_path)
    else:
        checkpoint = None

    if not (num_devices is not None and num_devices >= 0):
        raise AssertionError("missing mandatory num gpus parameter")
    # list of GPU devices that are available and unused. If no GPUs, returns empty list
    lst_device_ids = get_device_ids(num_devices) if torch.cuda.is_available() else []
    num_devices = len(lst_device_ids) if lst_device_ids else 0
    device = torch.device(f'cuda:{lst_device_ids[0]}' if torch.cuda.is_available() and lst_device_ids else 'cpu')
    print(f"Number of cuda devices requested: {net_params['global']['num_gpus']}. "
          f"Cuda devices available: {lst_device_ids}\n")
    if num_devices == 1:
        print(f"Using Cuda device {lst_device_ids[0]}\n")
    elif num_devices > 1:
        # str(list)[1:-1] only strips the surrounding brackets for display.
        print(f"Using data parallel on devices: {str(lst_device_ids)[1:-1]}. Main device: {lst_device_ids[0]}\n")
        try:  # For HPC when device 0 not available. Error: Invalid device id (in torch/cuda/__init__.py).
            # DataParallel adds prefix 'module.' to state_dict keys
            model = nn.DataParallel(model, device_ids=lst_device_ids)
        except AssertionError:
            warnings.warn(f"Unable to use devices {lst_device_ids}. Trying devices {list(range(len(lst_device_ids)))}")
            device = torch.device('cuda:0')
            # Fall back to ids 0..n-1; kept as a list like the original ids.
            lst_device_ids = list(range(len(lst_device_ids)))
            # DataParallel adds prefix 'module.' to state_dict keys
            model = nn.DataParallel(model, device_ids=lst_device_ids)
    else:
        warnings.warn("No Cuda device available. This process will only run on CPU\n")
    tqdm.write('Setting model, criterion, optimizer and learning rate scheduler...\n')
    try:  # For HPC when device 0 not available. Error: Cuda invalid device ordinal.
        model.to(device)
    except RuntimeError:
        warnings.warn("Unable to use device. Trying device 0...\n")
        device = torch.device('cuda:0' if torch.cuda.is_available() and lst_device_ids else 'cpu')
        model.to(device)

    model, criterion, optimizer, lr_scheduler = set_hyperparameters(net_params, num_channels, model, checkpoint,
                                                                    dontcare_val)
    criterion = criterion.to(device)

    return model, model_name, criterion, optimizer, lr_scheduler
Ejemplo n.º 5
0
def net(net_params, inference=False):
    """Build the model named in the config and load its checkpoint, if any.

    Args:
        net_params: Nested config mapping. ``global`` supplies ``model_name``,
            ``num_classes``, ``number_of_bands`` and the optional
            ``coordconv_*`` keys; ``training`` supplies ``dropout`` /
            ``dropout_prob`` (for the UNet variants) and the optional
            ``state_dict_path``; ``inference`` supplies ``state_dict_path``
            when ``inference`` is true.
        inference: When true, the checkpoint is read from the ``inference``
            section, taking precedence over the training checkpoint.

    Returns:
        Tuple ``(model, checkpoint, model_name)``. ``checkpoint`` is ``None``
        when neither an inference nor a training state dict path applies.

    Raises:
        AssertionError: If the chosen model requires 3-band data and the
            config specifies a different band count.
        ValueError: If ``model_name`` is not one of the supported names.
    """
    global_cfg = net_params['global']
    model_name = global_cfg['model_name'].lower()
    num_classes = global_cfg['num_classes']
    if num_classes == 1:
        warnings.warn(
            "config specified that number of classes is 1, but model will be instantiated"
            " with a minimum of two regardless (will assume that 'background' exists)"
        )
        num_classes = 2
    msg = 'Number of bands specified incompatible with this model. Requires 3 band data.'

    # Band checks raise AssertionError explicitly (instead of using `assert`)
    # so they keep the same exception type but still run under `python -O`.
    if model_name == 'unetsmall':
        model = unet.UNetSmall(num_classes,
                               global_cfg['number_of_bands'],
                               net_params['training']['dropout'],
                               net_params['training']['dropout_prob'])
    elif model_name == 'unet':
        model = unet.UNet(num_classes, global_cfg['number_of_bands'],
                          net_params['training']['dropout'],
                          net_params['training']['dropout_prob'])
    elif model_name == 'ternausnet':
        if not global_cfg['number_of_bands'] == 3:
            raise AssertionError(msg)
        model = TernausNet.ternausnet(num_classes)
    elif model_name == 'checkpointed_unet':
        model = checkpointed_unet.UNetSmall(
            num_classes, global_cfg['number_of_bands'],
            net_params['training']['dropout'],
            net_params['training']['dropout_prob'])
    elif model_name == 'inception':
        model = inception.Inception3(num_classes,
                                     global_cfg['number_of_bands'])
    elif model_name in ('fcn_resnet101', 'deeplabv3_resnet101'):
        if not global_cfg['number_of_bands'] == 3:
            raise AssertionError(msg)
        # The torchvision builder shares its name with the configured model.
        builder = getattr(models.segmentation, model_name)
        # pretrained on coco (21 classes)
        coco_model = builder(pretrained=True,
                             progress=True,
                             num_classes=21,
                             aux_loss=None)
        model = builder(pretrained=False,
                        progress=True,
                        num_classes=num_classes,
                        aux_loss=None)
        # Drop the 21-class output layer before transferring the weights.
        chopped_dict = chop_layer(coco_model.state_dict(),
                                  layer_names=['classifier.4'])
        del coco_model
        # load the new state dict
        # When strict=False, allows to load only the variables that are identical between the two models irrespective of
        # whether one is subset/superset of the other.
        model.load_state_dict(chopped_dict, strict=False)
    else:
        raise ValueError(
            f'The model name {model_name} in the config.yaml is not defined.')

    coordconv_convert = get_key_def('coordconv_convert', global_cfg, False)
    if coordconv_convert:
        centered = get_key_def('coordconv_centered', global_cfg, True)
        normalized = get_key_def('coordconv_normalized', global_cfg, True)
        noise = get_key_def('coordconv_noise', global_cfg, None)
        radius_channel = get_key_def('coordconv_radius_channel', global_cfg,
                                     False)
        scale = get_key_def('coordconv_scale', global_cfg, 1.0)
        # note: this operation will not attempt to preserve already-loaded model parameters!
        model = coordconv.swap_coordconv_layers(model,
                                                centered=centered,
                                                normalized=normalized,
                                                noise=noise,
                                                radius_channel=radius_channel,
                                                scale=scale)

    # Inference must load the inference checkpoint even when the config also
    # carries a training `state_dict_path`, so it is checked first.
    if inference:
        checkpoint = load_checkpoint(net_params['inference']['state_dict_path'])
    else:
        # A missing or empty training path means "no checkpoint".
        train_state_dict_path = get_key_def('state_dict_path',
                                            net_params['training'], None)
        if train_state_dict_path:
            checkpoint = load_checkpoint(train_state_dict_path)
        else:
            checkpoint = None

    return model, checkpoint, model_name
Ejemplo n.º 6
0
def _replace_classifier_head(model, num_channels):
    """Swap the last layer of ``model.classifier`` for a 1x1 conv with ``num_channels`` outputs."""
    head_layers = list(model.classifier.children())
    model.classifier = nn.Sequential(*head_layers[:-1])
    model.classifier.add_module(
        '4',
        nn.Conv2d(head_layers[-1].in_channels,
                  num_channels,
                  kernel_size=(1, 1)))


def _append_band_to_first_conv(model):
    """Add one input channel to the backbone's ``conv1`` by duplicating its channel-1 weights."""
    weights = model.backbone._modules['conv1'].weight.detach().numpy()
    # reuse green weights for infrared.
    extra_band = np.expand_dims(weights[:, 1, ...], axis=1)
    weights = np.append(weights, extra_band, axis=1)
    model.backbone._modules['conv1'].weight = nn.Parameter(
        torch.from_numpy(weights).float(), requires_grad=True)


def net(model_name: str,
        num_bands: int,
        num_channels: int,
        dontcare_val: int,
        num_devices: int,
        train_state_dict_path: str = None,
        pretrained: bool = True,
        dropout_prob: float = False,
        loss_fn: str = None,
        optimizer: str = None,
        class_weights: Sequence = None,
        net_params=None,
        conc_point: str = None,
        coordconv_params=None,
        inference_state_dict: str = None):
    """Define the neural net.

    Builds the model selected by ``model_name``, optionally converts it with
    ``coordconv.swap_coordconv_layers``, then either returns it for inference
    or moves it to a device and sets up training objects.

    Args:
        model_name: Architecture key. Either one of the names handled
            explicitly below or a key of ``lm_smp``.
        num_bands: Number of input imagery bands.
        num_channels: Number of output channels; passed on as the number of
            classes.
        dontcare_val: Forwarded to ``set_hyperparameters``.
        num_devices: Number of cuda devices requested; passed to
            ``get_device_ids``.
        train_state_dict_path: If not None, a checkpoint is loaded from this
            path on the training path.
        pretrained: Whether deeplabv3 loads pretrained weights. Forced to
            False when either state dict path is provided.
        dropout_prob: Dropout probability; any truthy value enables dropout
            for the unet variants.
        loss_fn: Forwarded to ``set_hyperparameters``.
        optimizer: Forwarded to ``set_hyperparameters``.
        class_weights: Forwarded to ``set_hyperparameters``.
        net_params: Forwarded to ``set_hyperparameters`` as ``params``.
        conc_point: Only used by 4-band deeplabv3. ``'baseline'`` widens the
            first convolution; any other value wraps the model in
            ``LayersEnsemble`` (defaulting to ``'conv1'`` when falsy).
        coordconv_params: Mapping read with ``get_key_def`` for the
            ``coordconv_*`` options.
        inference_state_dict: If truthy, a checkpoint is loaded from this
            path and the inference return shape is used.

    Returns:
        ``(model, checkpoint, model_name)`` when ``inference_state_dict`` is
        truthy; otherwise ``(model, model_name, criterion, optimizer,
        lr_scheduler, device, gpu_devices_dict)``.

    Raises:
        NotImplementedError: If ``num_bands`` is not supported by the
            requested model.
        ValueError: If ``model_name`` is not a known model.
    """
    msg = 'Number of bands specified incompatible with this model. Requires 3 band data.'
    # deeplabv3 also accepts 4 bands, so it needs its own message.
    msg_3_or_4 = 'Number of bands specified incompatible with this model. Requires 3 or 4 band data.'
    pretrained = False if train_state_dict_path or inference_state_dict else pretrained
    dropout = bool(dropout_prob)
    model = None

    if model_name == 'unetsmall':
        model = unet.UNetSmall(num_channels, num_bands, dropout, dropout_prob)
    elif model_name == 'unet':
        model = unet.UNet(num_channels, num_bands, dropout, dropout_prob)
    elif model_name == 'ternausnet':
        if num_bands != 3:
            raise NotImplementedError(msg)
        model = TernausNet.ternausnet(num_channels)
    elif model_name == 'checkpointed_unet':
        model = checkpointed_unet.UNetSmall(num_channels, num_bands, dropout,
                                            dropout_prob)
    elif model_name == 'inception':
        model = inception.Inception3(num_channels, num_bands)
    elif model_name == 'fcn_resnet101':
        if num_bands != 3:
            raise NotImplementedError(msg)
        model = models.segmentation.fcn_resnet101(pretrained=False,
                                                  progress=True,
                                                  num_classes=num_channels,
                                                  aux_loss=None)
    elif model_name == 'deeplabv3_resnet101':
        if num_bands not in (3, 4):
            raise NotImplementedError(msg_3_or_4)
        model = models.segmentation.deeplabv3_resnet101(
            pretrained=pretrained, progress=True)
        if num_bands == 3:
            _replace_classifier_head(model, num_channels)
        elif conc_point == 'baseline':
            logging.info(
                'Testing with 4 bands, concatenating at {}.'.format(
                    conc_point))
            _append_band_to_first_conv(model)
            _replace_classifier_head(model, num_channels)
        else:
            _replace_classifier_head(model, num_channels)
            conc_point = 'conv1' if not conc_point else conc_point
            model = LayersEnsemble(model, conc_point=conc_point)

        logging.info(
            f'Finetuning pretrained deeplabv3 with {num_bands} input channels (imagery bands). '
            f'Concatenation point: "{conc_point}"')

    elif model_name in lm_smp:
        lsmp = lm_smp[model_name]
        # Work on a copy so the shared lm_smp registry is not modified by this call.
        smp_params = dict(lsmp['params'])
        # TODO: add possibility of our own weights
        smp_params['encoder_weights'] = (
            "imagenet" if 'pretrained' in model_name.split("_") else None)
        smp_params['in_channels'] = num_bands
        smp_params['classes'] = num_channels
        smp_params['activation'] = None

        model = lsmp['fct'](**smp_params)

    else:
        raise ValueError(
            f'The model name {model_name} in the config.yaml is not defined.')

    coordconv_convert = get_key_def('coordconv_convert', coordconv_params,
                                    False)
    if coordconv_convert:
        centered = get_key_def('coordconv_centered', coordconv_params, True)
        normalized = get_key_def('coordconv_normalized', coordconv_params,
                                 True)
        noise = get_key_def('coordconv_noise', coordconv_params, None)
        radius_channel = get_key_def('coordconv_radius_channel',
                                     coordconv_params, False)
        scale = get_key_def('coordconv_scale', coordconv_params, 1.0)
        # note: this operation will not attempt to preserve already-loaded model parameters!
        model = coordconv.swap_coordconv_layers(model,
                                                centered=centered,
                                                normalized=normalized,
                                                noise=noise,
                                                radius_channel=radius_channel,
                                                scale=scale)

    if inference_state_dict:
        checkpoint = load_checkpoint(inference_state_dict)
        return model, checkpoint, model_name

    if train_state_dict_path is not None:
        checkpoint = load_checkpoint(train_state_dict_path)
    else:
        checkpoint = None

    # Keep the caller's request: num_devices is rebound below to the number
    # of devices actually obtained.
    requested_devices = num_devices
    # list of GPU devices that are available and unused. If no GPUs, returns empty list
    gpu_devices_dict = get_device_ids(num_devices)
    device_ids = list(gpu_devices_dict.keys())
    num_devices = len(device_ids)
    logging.info(
        f"Number of cuda devices requested: {requested_devices}. "
        f"Cuda devices available: {device_ids}\n")
    if num_devices == 1:
        logging.info(f"Using Cuda device 'cuda:{device_ids[0]}'")
    elif num_devices > 1:
        logging.info(
            f"Using data parallel on devices: {device_ids[1:]}. "
            f"Main device: 'cuda:{device_ids[0]}'")
        try:  # For HPC when device 0 not available. Error: Invalid device id (in torch/cuda/__init__.py).
            # DataParallel adds prefix 'module.' to state_dict keys
            model = nn.DataParallel(model, device_ids=device_ids)
        except AssertionError:
            fallback_ids = list(range(num_devices))
            logging.warning(
                f"Unable to use devices with ids {device_ids}. "
                f"Trying devices with ids {fallback_ids}")
            model = nn.DataParallel(model, device_ids=fallback_ids)
    else:
        logging.warning(
            "No Cuda device available. This process will only run on CPU\n")
    logging.info(
        'Setting model, criterion, optimizer and learning rate scheduler...\n')

    # The main device is always index 0 when any GPU was obtained; it is not
    # derived from the ids in gpu_devices_dict.
    device = torch.device('cuda:0' if gpu_devices_dict else 'cpu')
    try:  # For HPC when device 0 not available. Error: Cuda invalid device ordinal.
        model.to(device)
    except AssertionError:
        logging.exception("Unable to use device. Trying device 0...\n")
        device = torch.device('cuda' if gpu_devices_dict else 'cpu')
        model.to(device)

    model, criterion, optimizer, lr_scheduler = set_hyperparameters(
        params=net_params,
        num_classes=num_channels,
        model=model,
        checkpoint=checkpoint,
        dontcare_val=dontcare_val,
        loss_fn=loss_fn,
        optimizer=optimizer,
        class_weights=class_weights,
        inference=inference_state_dict)
    criterion = criterion.to(device)

    return model, model_name, criterion, optimizer, lr_scheduler, device, gpu_devices_dict