Code example #1
def run_ner(
        lang: str = 'eng',
        log_dir: str = 'logs',
        task: str = NER,
        batch_size: int = 1,
        epochs: int = 1,
        dataset: str = 'data/conll-2003/',
        loss: str = 'cross',
        max_seq_len: int = 128,
        do_lower_case: bool = False,
        warmup_proportion: float = 0.1,
        rand_seed: int = None,
        ds_size: int = None,
        data_bunch_path: str = 'data/conll-2003/db',
        tuned_learner: str = None,
        do_train: bool = False,
        do_eval: bool = False,
        save: bool = False,
        nameX: str = 'ner',
        mask: tuple = ('s', 's'),
):
    name = "_".join(
        map(str, [
            nameX, task, lang, mask[0], mask[1], loss, batch_size, max_seq_len,
            do_train, do_eval
        ]))
    log_dir = Path(log_dir)
    log_dir.mkdir(parents=True, exist_ok=True)
    init_logger(log_dir, name)

    if rand_seed is not None:
        random.seed(rand_seed)
        np.random.seed(rand_seed)
        torch.manual_seed(rand_seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(rand_seed)

    trainset = dataset + lang + '/train.txt'
    devset = dataset + lang + '/dev.txt'
    testset = dataset + lang + '/test.txt'

    bert_model = 'bert-base-cased' if lang == 'eng' else 'bert-base-multilingual-cased'
    print(f'Lang: {lang}\nModel: {bert_model}\nRun: {name}')
    model = BertForTokenClassification.from_pretrained(bert_model,
                                                       num_labels=len(VOCAB),
                                                       cache_dir='bertm')
    if tuned_learner:
        print('Loading pretrained learner: ', tuned_learner)
        model.bert.load_state_dict(torch.load(tuned_learner))

    model = torch.nn.DataParallel(model)
    model_lr_group = bert_layer_list(model)
    layers = len(model_lr_group)
    kwargs = {'max_seq_len': max_seq_len, 'ds_size': ds_size, 'mask': mask}

    train_dl = DataLoader(dataset=NerDataset(trainset,
                                             bert_model,
                                             train=True,
                                             **kwargs),
                          batch_size=batch_size,
                          shuffle=True,
                          collate_fn=partial(pad, train=True))

    dev_dl = DataLoader(dataset=NerDataset(devset, bert_model, **kwargs),
                        batch_size=batch_size,
                        shuffle=False,
                        collate_fn=pad)

    test_dl = DataLoader(dataset=NerDataset(testset, bert_model, **kwargs),
                         batch_size=batch_size,
                         shuffle=False,
                         collate_fn=pad)

    data = DataBunch(train_dl=train_dl,
                     valid_dl=dev_dl,
                     test_dl=test_dl,
                     collate_fn=pad,
                     path=Path(data_bunch_path))

    train_opt_steps = int(len(train_dl.dataset) / batch_size) * epochs
    optim = BertAdam(model.parameters(),
                     lr=0.01,
                     warmup=warmup_proportion,
                     t_total=train_opt_steps)

    loss_fun = ner_loss_func if loss == 'cross' else partial(ner_loss_func,
                                                             zero=True)
    metrics = [Conll_F1()]

    learn = Learner(
        data,
        model,
        BertAdam,
        loss_func=loss_fun,
        metrics=metrics,
        true_wd=False,
        layer_groups=model_lr_group,
        path='learn' + nameX,
    )

    learn.opt = OptimWrapper(optim)

    lrm = 1.6

    # select set of starting lrs
    lrs_eng = [0.01, 5e-4, 3e-4, 3e-4, 1e-5]
    lrs_deu = [0.01, 5e-4, 5e-4, 3e-4, 2e-5]

    startlr = lrs_eng if lang == 'eng' else lrs_deu
    results = [['epoch', 'lr', 'f1', 'val_loss', 'train_loss', 'train_losses']]
    if do_train:
        learn.freeze()
        learn.fit_one_cycle(1, startlr[0], moms=(0.8, 0.7))
        learn.freeze_to(-3)
        lrs = learn.lr_range(slice(startlr[1] / (lrm**15), startlr[1]))
        learn.fit_one_cycle(1, lrs, moms=(0.8, 0.7))
        learn.freeze_to(-6)
        lrs = learn.lr_range(slice(startlr[2] / (lrm**15), startlr[2]))
        learn.fit_one_cycle(1, lrs, moms=(0.8, 0.7))
        learn.freeze_to(-12)
        lrs = learn.lr_range(slice(startlr[3] / (lrm**15), startlr[3]))
        learn.fit_one_cycle(1, lrs, moms=(0.8, 0.7))
        learn.unfreeze()
        lrs = learn.lr_range(slice(startlr[4] / (lrm**15), startlr[4]))
        learn.fit_one_cycle(1, lrs, moms=(0.8, 0.7))

    if do_eval:
        res = learn.validate(test_dl, metrics=metrics)
        met_res = [f'{m.__name__}: {r}' for m, r in zip(metrics, res[1:])]
        print(f'Validation on TEST SET:\nloss {res[0]}, {met_res}')
        results.append(['val', '-', res[1], res[0], '-', '-'])

    with open(log_dir / (name + '.csv'), 'a') as resultFile:
        wr = csv.writer(resultFile)
        wr.writerows(results)
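
As a usage sketch: a hypothetical invocation of the run_ner() above. The argument values are illustrative only, and the surrounding module is assumed to provide the names the function relies on (NerDataset, BertAdam, Conll_F1, init_logger, etc.).

# Hypothetical driver; argument values are illustrative only.
run_ner(
    lang='eng',
    batch_size=32,
    epochs=5,
    rand_seed=42,
    do_train=True,
    do_eval=True,
)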
Code example #2
def run_ner(
        lang: str = 'eng',
        log_dir: str = 'logs',
        task: str = NER,
        batch_size: int = 1,
        lr: float = 5e-5,
        epochs: int = 1,
        dataset: str = 'data/conll-2003/',
        loss: str = 'cross',
        max_seq_len: int = 128,
        do_lower_case: bool = False,
        warmup_proportion: float = 0.1,
        grad_acc_steps: int = 1,
        rand_seed: int = None,
        fp16: bool = False,
        loss_scale: float = None,
        ds_size: int = None,
        data_bunch_path: str = 'data/conll-2003/db',
        bertAdam: bool = False,
        freez: bool = False,
        one_cycle: bool = False,
        discr: bool = False,
        lrm: float = 2.6,
        div: int = None,
        tuned_learner: str = None,
        do_train: bool = False,
        do_eval: bool = False,
        save: bool = False,
        name: str = 'ner',
        mask: tuple = ('s', 's'),
):
    name = "_".join(
        map(str, [
            name, task, lang, mask[0], mask[1], loss, batch_size, lr,
            max_seq_len, do_train, do_eval
        ]))

    log_dir = Path(log_dir)
    log_dir.mkdir(parents=True, exist_ok=True)
    init_logger(log_dir, name)

    if rand_seed is not None:
        random.seed(rand_seed)
        np.random.seed(rand_seed)
        torch.manual_seed(rand_seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(rand_seed)

    trainset = dataset + lang + '/train.txt'
    devset = dataset + lang + '/dev.txt'
    testset = dataset + lang + '/test.txt'

    bert_model = 'bert-base-cased' if lang == 'eng' else 'bert-base-multilingual-cased'
    print(f'Lang: {lang}\nModel: {bert_model}\nRun: {name}')
    model = BertForTokenClassification.from_pretrained(bert_model,
                                                       num_labels=len(VOCAB),
                                                       cache_dir='bertm')

    model = torch.nn.DataParallel(model)
    model_lr_group = bert_layer_list(model)
    layers = len(model_lr_group)
    kwargs = {'max_seq_len': max_seq_len, 'ds_size': ds_size, 'mask': mask}

    train_dl = DataLoader(dataset=NerDataset(trainset,
                                             bert_model,
                                             train=True,
                                             **kwargs),
                          batch_size=batch_size,
                          shuffle=True,
                          collate_fn=partial(pad, train=True))

    dev_dl = DataLoader(dataset=NerDataset(devset, bert_model, **kwargs),
                        batch_size=batch_size,
                        shuffle=False,
                        collate_fn=pad)

    test_dl = DataLoader(dataset=NerDataset(testset, bert_model, **kwargs),
                         batch_size=batch_size,
                         shuffle=False,
                         collate_fn=pad)

    data = DataBunch(train_dl=train_dl,
                     valid_dl=dev_dl,
                     test_dl=test_dl,
                     collate_fn=pad,
                     path=Path(data_bunch_path))

    loss_fun = ner_loss_func if loss == 'cross' else partial(ner_loss_func,
                                                             zero=True)
    metrics = [Conll_F1()]

    learn = Learner(
        data,
        model,
        BertAdam,
        loss_func=loss_fun,
        metrics=metrics,
        true_wd=False,
        layer_groups=None if not freez else model_lr_group,
        path='learn',
    )

    # initialise bert adam optimiser
    train_opt_steps = int(len(train_dl.dataset) / batch_size) * epochs
    optim = BertAdam(model.parameters(),
                     lr=lr,
                     warmup=warmup_proportion,
                     t_total=train_opt_steps)

    if bertAdam: learn.opt = OptimWrapper(optim)
    else: print("No Bert Adam")

    # load fine-tuned learner
    if tuned_learner:
        print('Loading pretrained learner: ', tuned_learner)
        learn.load(tuned_learner)

    # Uncomment to graph learning rate plot
    # learn.lr_find()
    # learn.recorder.plot(skip_end=15)

    # set lr (discriminative learning rates)
    if div: layers = div
    lrs = lr if not discr else learn.lr_range(slice(lr / lrm**(layers), lr))

    results = [['epoch', 'lr', 'f1', 'val_loss', 'train_loss', 'train_losses']]

    if do_train:
        for epoch in range(epochs):
            if freez:
                # assumes epochs > 1; a single epoch would divide by zero here
                lay = (layers // (epochs - 1)) * epoch * -1
                if lay == 0:
                    print('Freeze')
                    learn.freeze()
                elif lay == layers:
                    print('unfreeze')
                    learn.unfreeze()
                else:
                    print('freeze2')
                    learn.freeze_to(lay)
                print('Freezing layers ', lay, ' of ', layers)

            # Fit Learner - eg train model
            if one_cycle: learn.fit_one_cycle(1, lrs, moms=(0.8, 0.7))
            else: learn.fit(1, lrs)

            results.append([
                epoch,
                lrs,
                learn.recorder.metrics[0][0],
                learn.recorder.val_losses[0],
                np.array(learn.recorder.losses).mean(),
                learn.recorder.losses,
            ])

            if save:
                m_path = learn.save(f"{lang}_{epoch}_model", return_path=True)
                print(f'Saved model to {m_path}')
    if save: learn.export(f'{lang}.pkl')

    if do_eval:
        res = learn.validate(test_dl, metrics=metrics)
        met_res = [f'{m.__name__}: {r}' for m, r in zip(metrics, res[1:])]
        print(f'Validation on TEST SET:\nloss {res[0]}, {met_res}')
        results.append(['val', '-', res[1], res[0], '-', '-'])

    with open(log_dir / (name + '.csv'), 'a') as resultFile:
        wr = csv.writer(resultFile)
        wr.writerows(results)
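
For reference, the discriminative-rate branch above (learn.lr_range(slice(lr / lrm**layers, lr))) spreads learning rates geometrically across the layer groups. A minimal standalone sketch of the same arithmetic, with illustrative values for lr, lrm and layers:

import numpy as np

# Geometric spread from lr / lrm**layers (earliest layers) up to lr (head),
# mirroring learn.lr_range(slice(lr / lrm**layers, lr)) above.
lr, lrm, layers = 5e-5, 2.6, 12
lrs = np.geomspace(lr / lrm**layers, lr, num=layers)
print(lrs)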
Code example #3
class DeepLab(ArcGISModel):
    """
    Creates a ``DeepLab`` semantic segmentation object

    =====================   ===========================================
    **Argument**            **Description**
    ---------------------   -------------------------------------------
    data                    Required fastai Databunch. Returned data object from
                            ``prepare_data`` function.
    ---------------------   -------------------------------------------
    backbone                Optional function. Backbone CNN model to be used for
                            creating the base of the `DeepLab`, which
                            is `resnet101` by default since it is pretrained in
                            torchvision. It supports the ResNet,
                            DenseNet, and VGG families.
    ---------------------   -------------------------------------------
    pretrained_path         Optional string. Path where pre-trained model is
                            saved.
    =====================   ===========================================

    **kwargs**

    =====================   ===========================================
    **Argument**            **Description**
    ---------------------   -------------------------------------------
    class_balancing         Optional boolean. If True, it will balance the
                            cross-entropy loss inversely to the frequency
                            of pixels per class. Default: False.
    ---------------------   -------------------------------------------
    mixup                   Optional boolean. If True, it will use mixup
                            augmentation and mixup loss. Default: False
    ---------------------   -------------------------------------------
    focal_loss              Optional boolean. If True, it will use focal loss.
                            Default: False
    ---------------------   -------------------------------------------
    ignore_classes          Optional list. The list of class values on which
                            the model will not incur loss. Default: []
    =====================   ===========================================

    :returns: ``DeepLab`` Object
    """
    def __init__(self,
                 data,
                 backbone=None,
                 pretrained_path=None,
                 *args,
                 **kwargs):
        # Set default backbone to be 'resnet101'
        if backbone is None:
            backbone = models.resnet101

        super().__init__(data, backbone)

        self._ignore_classes = kwargs.get('ignore_classes', [])
        if self._ignore_classes != [] and len(data.classes) <= 3:
            raise Exception(
                f"`ignore_classes` parameter can only be used when the dataset has more than 2 classes."
            )

        data_classes = list(self._data.class_mapping.keys())
        if 0 not in list(data.class_mapping.values()):
            self._ignore_mapped_class = [
                data_classes.index(k) + 1 for k in self._ignore_classes
                if k != 0
            ]
        else:
            self._ignore_mapped_class = [
                data_classes.index(k) + 1 for k in self._ignore_classes
            ]
        if self._ignore_classes != []:
            if 0 not in self._ignore_mapped_class:
                self._ignore_mapped_class.insert(0, 0)
            global accuracy
            accuracy = partial(accuracy,
                               ignore_mapped_class=self._ignore_mapped_class)

        self.mixup = kwargs.get('mixup', False)
        self.class_balancing = kwargs.get('class_balancing', False)
        self.focal_loss = kwargs.get('focal_loss', False)

        _backbone = self._backbone
        if hasattr(self, '_orig_backbone'):
            _backbone = self._orig_backbone

        if not self._check_backbone_support(_backbone):
            raise Exception(
                f"Enter only compatible backbones from {', '.join(self.supported_backbones)}"
            )

        self._code = image_classifier_prf
        if self._backbone.__name__ == 'resnet101':
            model = _create_deeplab(data.c)
            if self._is_multispectral:
                model = _change_tail(model, data)
        else:
            model = Deeplab(data.c, self._backbone, data.chip_size)

        if not _isnotebook() and os.name == 'posix':
            _set_ddp_multigpu(self)
            if self._multigpu_training:
                self.learn = Learner(data, model,
                                     metrics=accuracy).to_distributed(
                                         self._rank_distributed)
            else:
                self.learn = Learner(data, model, metrics=accuracy)
        else:
            self.learn = Learner(data, model, metrics=accuracy)

        self.learn.loss_func = self._deeplab_loss

        ## setting class_weight if present in data
        if self.class_balancing and self._data.class_weight is not None:
            class_weight = torch.tensor(
                [self._data.class_weight.mean()] +
                self._data.class_weight.tolist()).float().to(self._device)
        else:
            class_weight = None

        ## Raise a warning in the appropriate case
        if self.class_balancing:
            if self._data.class_weight is None:
                logger.warning(
                    "Could not find 'NumPixelsPerClass' in 'esri_accumulated_stats.json'. Ignoring `class_balancing` parameter."
                )
            elif getattr(data, 'overflow_encountered', False):
                logger.warning(
                    "Overflow Encountered. Ignoring `class_balancing` parameter."
                )
                class_weight = [1] * len(data.classes)

        ## Setting class weights for ignored classes
        if self._ignore_classes != []:
            if not self.class_balancing:
                class_weight = torch.tensor([1] * data.c).float().to(
                    self._device)
            class_weight[self._ignore_mapped_class] = 0.
        else:
            class_weight = None

        self._final_class_weight = class_weight

        if self.focal_loss:
            self.learn.loss_func = FocalLoss(self.learn.loss_func)
        if self.mixup:
            self.learn.callbacks.append(MixUpCallback(self.learn))

        self.learn.model = self.learn.model.to(self._device)
        self._freeze()
        self._arcgis_init_callback()  # make first conv weights learnable
        if pretrained_path is not None:
            self.load(pretrained_path)

    @property
    def supported_backbones(self):
        return DeepLab._supported_backbones()

    @staticmethod
    def _supported_backbones():
        return [*_resnet_family, *_densenet_family, *_vgg_family]

    @classmethod
    def from_model(cls, emd_path, data=None):
        """
        Creates a ``DeepLab`` semantic segmentation object from an Esri Model Definition (EMD) file.

        =====================   ===========================================
        **Argument**            **Description**
        ---------------------   -------------------------------------------
        emd_path                Required string. Path to Esri Model Definition
                                file.
        ---------------------   -------------------------------------------
        data                    Required fastai Databunch or None. Returned data
                                object from ``prepare_data`` function or None for
                                inferencing.

        =====================   ===========================================

        :returns: `DeepLab` Object
        """

        emd_path = Path(emd_path)
        with open(emd_path) as f:
            emd = json.load(f)

        model_file = Path(emd['ModelFile'])

        if not model_file.is_absolute():
            model_file = emd_path.parent / model_file

        model_params = emd['ModelParameters']

        try:
            class_mapping = {i['Value']: i['Name'] for i in emd['Classes']}
            color_mapping = {i['Value']: i['Color'] for i in emd['Classes']}
        except KeyError:
            class_mapping = {
                i['ClassValue']: i['ClassName']
                for i in emd['Classes']
            }
            color_mapping = {
                i['ClassValue']: i['Color']
                for i in emd['Classes']
            }

        if data is None:
            empty_data = _EmptyData(path=emd_path.parent.parent,
                                    loss_func=None,
                                    c=len(class_mapping) + 1,
                                    chip_size=emd['ImageHeight'])
            empty_data.class_mapping = class_mapping
            empty_data.color_mapping = color_mapping
            empty_data = get_multispectral_data_params_from_emd(
                empty_data, emd)
            empty_data.emd_path = emd_path
            empty_data.emd = emd
            return cls(empty_data,
                       **model_params,
                       pretrained_path=str(model_file))
        else:
            return cls(data, **model_params, pretrained_path=str(model_file))

    def _get_emd_params(self):
        import random
        _emd_template = {}
        _emd_template["Framework"] = "arcgis.learn.models._inferencing"
        _emd_template["ModelConfiguration"] = "_deeplab_infrencing"
        _emd_template["InferenceFunction"] = "ArcGISImageClassifier.py"
        _emd_template["ExtractBands"] = [0, 1, 2]
        _emd_template["ignore_mapped_class"] = self._ignore_mapped_class
        _emd_template['Classes'] = []
        class_data = {}
        for i, class_name in enumerate(
                self._data.classes[1:]):  # 0th index is background
            inverse_class_mapping = {
                v: k
                for k, v in self._data.class_mapping.items()
            }
            class_data["Value"] = inverse_class_mapping[class_name]
            class_data["Name"] = class_name
            if is_no_color(self._data.color_mapping):
                color = [random.choice(range(256)) for _ in range(3)]
            else:
                color = self._data.color_mapping[
                    inverse_class_mapping[class_name]]
            class_data["Color"] = color
            _emd_template['Classes'].append(class_data.copy())

        return _emd_template

    def accuracy(self):
        return self.learn.validate()[-1].tolist()

    @property
    def _model_metrics(self):
        return {'accuracy': '{0:1.4e}'.format(self._get_model_metrics())}

    def _get_model_metrics(self, **kwargs):
        checkpoint = kwargs.get('checkpoint', True)
        if not hasattr(self.learn, 'recorder'):
            return 0.0

        model_accuracy = self.learn.recorder.metrics[-1][0]
        if checkpoint:
            model_accuracy = np.max(self.learn.recorder.metrics)
        return float(model_accuracy)

    def _deeplab_loss(self, outputs, targets, **kwargs):
        targets = targets.squeeze(1).detach()

        criterion = nn.CrossEntropyLoss(weight=self._final_class_weight).to(
            self._device)
        if self.learn.model.training:
            out = outputs[0]
            aux = outputs[1]
        else:  # validation
            out = outputs
        main_loss = criterion(out, targets)

        if self.learn.model.training:
            aux_loss = criterion(aux, targets)
            total_loss = main_loss + 0.4 * aux_loss
            return total_loss
        else:
            return main_loss

    def _freeze(self):
        "Freezes the pretrained backbone."
        for idx, i in enumerate(flatten_model(self.learn.model)):
            if isinstance(i, (nn.BatchNorm2d)):
                continue
            if hasattr(i, 'dilation'):
                dilation = i.dilation
                dilation = dilation[0] if isinstance(dilation,
                                                     tuple) else dilation
                if dilation > 1:
                    break
            for p in i.parameters():
                p.requires_grad = False

        self.learn.layer_groups = split_model_idx(
            self.learn.model, [idx]
        )  ## Could also call self.learn.freeze after this line because layer groups are now present.
        self.learn.create_opt(lr=3e-3)

    def unfreeze(self):
        for _, param in self.learn.model.named_parameters():
            param.requires_grad = True

    def show_results(self, rows=5, **kwargs):
        """
        Displays the results of a trained model on a part of the validation set.
        """
        self._check_requisites()
        if rows > len(self._data.valid_ds):
            rows = len(self._data.valid_ds)
        self.learn.show_results(rows=rows,
                                ignore_mapped_class=self._ignore_mapped_class,
                                **kwargs)

    def _show_results_multispectral(self,
                                    rows=5,
                                    alpha=0.7,
                                    **kwargs):  # parameters adjusted in kwargs
        ax = show_results_multispectral(self,
                                        nrows=rows,
                                        alpha=alpha,
                                        **kwargs)

    def mIOU(self, mean=False, show_progress=True):
        """
        Computes mean IOU on the validation set for each class.

        =====================   ===========================================
        **Argument**            **Description**
        ---------------------   -------------------------------------------
        mean                    Optional bool. If False, returns class-wise
                                mean IOU, otherwise returns the mean IOU of
                                all classes combined.
        ---------------------   -------------------------------------------
        show_progress           Optional bool. Displays the progress bar if
                                True.
        =====================   ===========================================
        
        :returns: `dict` if mean is False otherwise `float`
        """
        num_classes = torch.arange(self._data.c)
        miou = compute_miou(self, self._data.valid_dl, mean, num_classes,
                            show_progress, self._ignore_mapped_class)
        if mean:
            return np.mean(miou)
        if self._ignore_mapped_class == []:
            return dict(zip(['0'] + self._data.classes[1:], miou))
        else:
            class_values = [0] + list(self._data.class_mapping.keys())
            return {
                class_values[i]: miou[i]
                for i in range(len(miou)) if i not in self._ignore_mapped_class
            }

    def per_class_metrics(self):
        """
        Computes per-class precision, recall and f1-score on the validation set.
        """
        ## Calling imported function `per_class_metrics`
        return per_class_metrics(self,
                                 ignore_mapped_class=self._ignore_mapped_class)
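
The _deeplab_loss above adds the auxiliary classifier's loss at a fixed weight of 0.4 during training. A self-contained sketch of that combination (tensor shapes and the class count are illustrative):

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
out = torch.randn(2, 5, 64, 64)             # main head logits (N, C, H, W)
aux = torch.randn(2, 5, 64, 64)             # auxiliary head logits
targets = torch.randint(0, 5, (2, 64, 64))  # per-pixel class indices

total_loss = criterion(out, targets) + 0.4 * criterion(aux, targets)
print(total_loss.item())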
Code example #4
File: _deeplab.py  Project: Samakwa/VRP-TCC-For-RSS
class DeepLab(ArcGISModel):
    """
    Creates a ``DeepLab`` semantic segmentation object

    =====================   ===========================================
    **Argument**            **Description**
    ---------------------   -------------------------------------------
    data                    Required fastai Databunch. Returned data object from
                            ``prepare_data`` function.
    ---------------------   -------------------------------------------
    backbone                Optional function. Backbone CNN model to be used for
                            creating the base of the `DeepLab`, which
                            is `resnet101` by default since it is pretrained in
                            torchvision. It supports the ResNet,
                            DenseNet, and VGG families.
    ---------------------   -------------------------------------------
    pretrained_path         Optional string. Path where pre-trained model is
                            saved.
    =====================   ===========================================

    :returns: ``DeepLab`` Object
    """
    def __init__(self, data, backbone=None, pretrained_path=None):
        # Set default backbone to be 'resnet101'
        if backbone is None:
            backbone = models.resnet101

        super().__init__(data, backbone)

        _backbone = self._backbone
        if hasattr(self, '_orig_backbone'):
            _backbone = self._orig_backbone

        if not self._check_backbone_support(_backbone):
            raise Exception(
                f"Enter only compatible backbones from {', '.join(self.supported_backbones)}"
            )

        self._code = image_classifier_prf
        if self._backbone.__name__ == 'resnet101':
            model = _create_deeplab(data.c)
            if self._is_multispectral:
                model = _change_tail(model, data)
        else:
            model = Deeplab(data.c, self._backbone, data.chip_size)

        self.learn = Learner(data, model, metrics=self._accuracy)
        self.learn.loss_func = self._deeplab_loss
        self.learn.model = self.learn.model.to(self._device)
        self._freeze()
        self._arcgis_init_callback()  # make first conv weights learnable
        if pretrained_path is not None:
            self.load(pretrained_path)

    @property
    def supported_backbones(self):
        return DeepLab._supported_backbones()

    @staticmethod
    def _supported_backbones():
        return [*_resnet_family, *_densenet_family, *_vgg_family]

    @classmethod
    def from_model(cls, emd_path, data=None):
        """
        Creates a ``DeepLab`` semantic segmentation object from an Esri Model Definition (EMD) file.

        =====================   ===========================================
        **Argument**            **Description**
        ---------------------   -------------------------------------------
        emd_path                Required string. Path to Esri Model Definition
                                file.
        ---------------------   -------------------------------------------
        data                    Required fastai Databunch or None. Returned data
                                object from ``prepare_data`` function or None for
                                inferencing.

        =====================   ===========================================

        :returns: `DeepLab` Object
        """

        emd_path = Path(emd_path)
        with open(emd_path) as f:
            emd = json.load(f)

        model_file = Path(emd['ModelFile'])

        if not model_file.is_absolute():
            model_file = emd_path.parent / model_file

        model_params = emd['ModelParameters']

        try:
            class_mapping = {i['Value']: i['Name'] for i in emd['Classes']}
            color_mapping = {i['Value']: i['Color'] for i in emd['Classes']}
        except KeyError:
            class_mapping = {
                i['ClassValue']: i['ClassName']
                for i in emd['Classes']
            }
            color_mapping = {
                i['ClassValue']: i['Color']
                for i in emd['Classes']
            }

        if data is None:
            empty_data = _EmptyData(path=emd_path.parent.parent,
                                    loss_func=None,
                                    c=len(class_mapping) + 1,
                                    chip_size=emd['ImageHeight'])
            empty_data.class_mapping = class_mapping
            empty_data.color_mapping = color_mapping
            empty_data = get_multispectral_data_params_from_emd(
                empty_data, emd)
            empty_data.emd_path = emd_path
            empty_data.emd = emd
            return cls(empty_data,
                       **model_params,
                       pretrained_path=str(model_file))
        else:
            return cls(data, **model_params, pretrained_path=str(model_file))

    def _get_emd_params(self):
        import random
        _emd_template = {}
        _emd_template["Framework"] = "arcgis.learn.models._inferencing"
        _emd_template["ModelConfiguration"] = "_deeplab_infrencing"
        _emd_template["InferenceFunction"] = "ArcGISImageClassifier.py"

        _emd_template["ExtractBands"] = [0, 1, 2]
        _emd_template['Classes'] = []
        class_data = {}
        for i, class_name in enumerate(
                self._data.classes[1:]):  # 0th index is background
            inverse_class_mapping = {
                v: k
                for k, v in self._data.class_mapping.items()
            }
            class_data["Value"] = inverse_class_mapping[class_name]
            class_data["Name"] = class_name
            if is_no_color(self._data.color_mapping):
                color = [random.choice(range(256)) for _ in range(3)]
            else:
                color = self._data.color_mapping[
                    inverse_class_mapping[class_name]]
            class_data["Color"] = color
            _emd_template['Classes'].append(class_data.copy())

        return _emd_template

    def accuracy(self):
        return self.learn.validate()[-1].tolist()

    @property
    def _model_metrics(self):
        return {'accuracy': '{0:1.4e}'.format(self._get_model_metrics())}

    def _get_model_metrics(self, **kwargs):
        checkpoint = kwargs.get('checkpoint', True)
        if not hasattr(self.learn, 'recorder'):
            return 0.0

        model_accuracy = self.learn.recorder.metrics[-1][0]
        if checkpoint:
            model_accuracy = np.max(self.learn.recorder.metrics)
        return float(model_accuracy)

    def _deeplab_loss(self, outputs, targets):
        targets = targets.squeeze(1).detach()
        criterion = nn.CrossEntropyLoss().to(self._device)
        if self.learn.model.training:
            out = outputs[0]
            aux = outputs[1]
        else:  # validation
            out = outputs
        main_loss = criterion(out, targets)

        if self.learn.model.training:
            aux_loss = criterion(aux, targets)
            total_loss = main_loss + 0.4 * aux_loss
            return total_loss
        else:
            return main_loss

    def _freeze(self):
        "Freezes the pretrained backbone."
        for idx, i in enumerate(flatten_model(self.learn.model)):
            if isinstance(i, (nn.BatchNorm2d)):
                continue
            if hasattr(i, 'dilation'):
                dilation = i.dilation
                dilation = dilation[0] if isinstance(dilation,
                                                     tuple) else dilation
                if dilation > 1:
                    break
            for p in i.parameters():
                p.requires_grad = False

        self.learn.layer_groups = split_model_idx(
            self.learn.model, [idx]
        )  ## Could also call self.learn.freeze after this line because layer groups are now present.
        self.learn.create_opt(lr=3e-3)

    def unfreeze(self):
        for _, param in self.learn.model.named_parameters():
            param.requires_grad = True

    def show_results(self, rows=5, **kwargs):
        """
        Displays the results of a trained model on a part of the validation set.
        """
        self._check_requisites()
        if rows > len(self._data.valid_ds):
            rows = len(self._data.valid_ds)
        self.learn.show_results(rows=rows, **kwargs)

    def _show_results_multispectral(self,
                                    rows=5,
                                    alpha=0.7,
                                    **kwargs):  # parameters adjusted in kwargs
        ax = show_results_multispectral(self,
                                        nrows=rows,
                                        alpha=alpha,
                                        **kwargs)

    def _accuracy(self, input, target, void_code=0, class_mapping=None):
        if self.learn.model.training:  # while training
            input = input[0]

        target = target.squeeze(1)
        return (input.argmax(dim=1) == target).float().mean()

    def mIOU(self, mean=False, show_progress=True):
        """
        Computes mean IOU on the validation set for each class.

        =====================   ===========================================
        **Argument**            **Description**
        ---------------------   -------------------------------------------
        mean                    Optional bool. If False, returns class-wise
                                mean IOU, otherwise returns the mean IOU of
                                all classes combined.
        ---------------------   -------------------------------------------
        show_progress           Optional bool. Displays the progress bar if
                                True.
        =====================   ===========================================
        
        :returns: `dict` if mean is False otherwise `float`
        """
        num_classes = torch.arange(self._data.c)
        miou = compute_miou(self, self._data.valid_dl, mean, num_classes,
                            show_progress)
        if mean:
            return np.mean(miou)
        return dict(zip(['0'] + self._data.classes[1:], miou))
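
The _accuracy metric above is plain per-pixel accuracy over the argmax prediction. An equivalent standalone computation (shapes are illustrative):

import torch

logits = torch.randn(2, 5, 64, 64)            # (N, C, H, W) model output
target = torch.randint(0, 5, (2, 1, 64, 64))  # (N, 1, H, W) label raster

pred = logits.argmax(dim=1)                   # (N, H, W) predicted classes
acc = (pred == target.squeeze(1)).float().mean()
print(acc.item())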
Code example #5
File: train.py  Project: lewfish/mlx
def train(config_path, opts):
    tmp_dir_obj = tempfile.TemporaryDirectory()
    tmp_dir = tmp_dir_obj.name

    cfg = load_config(config_path, opts)
    print(cfg)

    # Setup data
    databunch, full_databunch = build_databunch(cfg, tmp_dir)
    output_dir = setup_output_dir(cfg, tmp_dir)
    print(full_databunch)

    plotter = build_plotter(cfg)
    if not cfg.lr_find_mode and not cfg.predict_mode:
        plotter.plot_data(databunch, output_dir)

    # Setup model
    num_labels = databunch.c
    model = build_model(cfg, num_labels)
    metrics = [CocoMetric(num_labels)]
    learn = Learner(databunch, model, path=output_dir, metrics=metrics)
    fastai.basic_train.loss_batch = loss_batch
    best_model_path = join(output_dir, 'best_model.pth')
    last_model_path = join(output_dir, 'last_model.pth')

    # Train model
    callbacks = [
        MyCSVLogger(learn, filename='log'),
        SubLossMetric(learn, model.subloss_names)
    ]

    if cfg.output_uri.startswith('s3://'):
        callbacks.append(
            SyncCallback(output_dir, cfg.output_uri, cfg.solver.sync_interval))

    if cfg.model.init_weights:
        device = next(model.parameters()).device
        model.load_state_dict(
            torch.load(cfg.model.init_weights, map_location=device))

    if not cfg.predict_mode:
        if cfg.overfit_mode:
            learn.fit_one_cycle(cfg.solver.num_epochs, cfg.solver.lr, callbacks=callbacks)
            torch.save(learn.model.state_dict(), best_model_path)
            learn.model.eval()
            print('Validating on training set...')
            learn.validate(full_databunch.train_dl, metrics=metrics)
        else:
            tb_logger = TensorboardLogger(learn, 'run')
            tb_logger.set_extra_args(
                model.subloss_names, cfg.overfit_mode)

            extra_callbacks = [
                MySaveModelCallback(
                    learn, best_model_path, monitor='coco_metric', every='improvement'),
                MySaveModelCallback(learn, last_model_path, every='epoch'),
                TrackEpochCallback(learn),
            ]
            callbacks.extend(extra_callbacks)
            if cfg.lr_find_mode:
                learn.lr_find()
                learn.recorder.plot(suggestion=True, return_fig=True)
                lr = learn.recorder.min_grad_lr
                print('lr_find() found lr: {}'.format(lr))
                exit()

            learn.fit_one_cycle(cfg.solver.num_epochs, cfg.solver.lr, callbacks=callbacks)
            print('Validating on full validation set...')
            learn.validate(full_databunch.valid_dl, metrics=metrics)
    else:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        model.load_state_dict(
            torch.load(join(output_dir, 'best_model.pth'), map_location=device))
        model.eval()

    print('Plotting predictions...')
    plot_dataset = databunch.train_ds if cfg.overfit_mode else databunch.valid_ds
    plotter.make_debug_plots(plot_dataset, model, databunch.classes, output_dir)
    if cfg.output_uri.startswith('s3://'):
        sync_to_dir(output_dir, cfg.output_uri)
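
A hypothetical call into the train() above; the config path and the opts override list are illustrative, and their exact format is defined by the project's load_config:

# Hypothetical: train with two command-line style overrides.
train('configs/detection.yml', ['solver.num_epochs', '10', 'solver.lr', '1e-4'])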