예제 #1
0
def get_learn(data):
    """Assemble the fastai ``Learner`` used to train the ResNet-18 U-Net.

    The network is created by ``nb_resnet_unet.get_unet_res18(1, True)`` and
    initialised from ``./models/unet_res18_allres_init.pth``.  The learner's
    layer groups come from ``split_model``.  When ``data.device`` is a CUDA
    device the model is wrapped in ``DataParallel`` over GPU ids 0-3.

    The loss is ``balance_bce`` with ``balance_ratio=1``; dice loss, the same
    balanced BCE and mask IoU are appended as metrics.

    :param data: data bunch handed to ``Learner``; must expose ``device``.
    :returns: the configured ``Learner``.
    """
    unet = nb_resnet_unet.get_unet_res18(1, True)
    initial_weights = torch.load('./models/unet_res18_allres_init.pth')
    unet.load_state_dict(initial_weights)

    learn = Learner(data, unet)
    learn.layer_groups = split_model(learn.model)

    # Multi-GPU training; the device list is fixed, so this assumes that
    # GPUs 0..3 are visible to the process.
    running_on_gpu = data.device.type == 'cuda'
    if running_on_gpu:
        learn.model = torch.nn.DataParallel(learn.model,
                                            device_ids=[0, 1, 2, 3])

    # Alternatives that were tried for the loss: nb_loss_metrics.combo_loss
    # (balance_ratio=1) and nb_loss_metrics.dice_loss.
    learn.loss_func = partial(nb_loss_metrics.balance_bce, balance_ratio=1)

    # Register the metrics in reporting order.
    learn.metrics += [
        nb_loss_metrics.dice_loss,
        partial(nb_loss_metrics.balance_bce, balance_ratio=1),
        nb_loss_metrics.mask_iou,
    ]

    return learn
예제 #2
0
def get_learn_detectsym_17clas(data, gaf, clas_weights=weights):
    '''
    Build the learner for the 17-class symbol-detection dataset.

    A ResNet-18 single-shot detector (``num_classes=17``) is created and
    initialised from ``./models/pretrained_res18_1ssd.pth``.  On a CUDA
    device the model is wrapped in ``DataParallel`` using the module-level
    ``device_ids``.  The loss is ``anchors_loss_metrics.yolo_L``; its
    individual terms plus classification accuracy and centre distance are
    appended as metrics.

    :param data: data bunch handed to ``Learner``; must expose ``device``.
    :param gaf: forwarded as ``gaf=`` to the loss and to every metric.
    :param clas_weights: forwarded as ``clas_weights=`` to the loss and to
        the class / centre / positive-confidence metrics.
    :returns: the configured ``Learner``.
    '''
    detector = resnet_ssd.get_resnet18_1ssd(num_classes=17)
    pretrained = torch.load('./models/pretrained_res18_1ssd.pth')
    detector.load_state_dict(pretrained)

    learn = Learner(data, detector)
    learn.layer_groups = split_model(learn.model)

    # Multi-GPU: replicate the model over the configured devices.
    if data.device.type == 'cuda':
        learn.model = torch.nn.DataParallel(learn.model,
                                            device_ids=device_ids)

    learn.loss_func = partial(anchors_loss_metrics.yolo_L,
                              gaf=gaf,
                              conf_th=1,
                              clas_weights=clas_weights,
                              lambda_nconf=10)

    # Metrics that take both the anchor helper and the class weights.
    weighted_metric_fns = (
        anchors_loss_metrics.clas_L,
        anchors_loss_metrics.cent_L,
        anchors_loss_metrics.pConf_L,
    )
    learn.metrics += [
        partial(metric_fn, gaf=gaf, clas_weights=clas_weights)
        for metric_fn in weighted_metric_fns
    ]

    # Remaining metrics: negative-confidence loss, accuracy, centre distance.
    learn.metrics += [
        partial(anchors_loss_metrics.nConf_L, gaf=gaf, conf_th=1),
        partial(anchors_loss_metrics.clas_acc, gaf=gaf),
        partial(anchors_loss_metrics.cent_d, gaf=gaf),
    ]

    return learn
예제 #3
0
파일: model.py 프로젝트: zeta1999/keraTorch
class Sequential:
    """Keras-like front end that stacks layers and trains them with fastai.

    Layers are collected with :meth:`add`, turned into an ``nn.Sequential``
    by :meth:`compile`, and trained through a fastai ``Learner`` in
    :meth:`fit`.  A ready-made model may be supplied to the constructor
    instead of adding layers.
    """

    def __init__(self, model=None):
        """Create an empty stack.

        :param model: optional pre-built model; kept as-is unless layers are
            added before :meth:`compile`.
        """
        self.layers = []
        self.last_dim = None
        self.model = model
        # Device used to place inputs in `predict`: CUDA when available.
        self.device = torch.device('cpu')
        if torch.cuda.is_available():
            self.device = torch.device('cuda')

    def add(self, layer):
        """Append a layer description to the stack.

        ``layer.get_layer(last_dim)`` must return a dict with the keys
        ``'output_dim'`` (remembered as input size for the next layer) and
        ``'layers'`` (an iterable of modules appended to the stack).
        """
        layer = layer.get_layer(self.last_dim)
        self.last_dim = layer['output_dim']
        self.layers.extend(layer['layers'])

    def compile(self, loss, optimizer=None):
        """Freeze the stack into an ``nn.Sequential`` and store the loss.

        :param loss: loss function later passed to the ``Learner``.
        :param optimizer: accepted for API compatibility; currently unused.
        """
        if len(self.layers) > 0:
            self.model = nn.Sequential(*self.layers)
        self.loss = loss

    def fit(self, x, y, bs, epochs, lr=1e-3, one_cycle=True, get_lr=True):
        """Train the model on ``(x, y)``.

        :param bs: batch size passed to ``create_db``.
        :param epochs: number of epochs.
        :param lr: learning rate.
        :param one_cycle: use ``fit_one_cycle`` when true, else ``fit``.
        :param get_lr: accepted for API compatibility; currently unused.
        """
        db = create_db(x, y, bs=bs)
        self.learn = Learner(db, self.model, loss_func=self.loss)
        if one_cycle:
            self.learn.fit_one_cycle(epochs, lr)
        else:
            self.learn.fit(epochs, lr)

    def lr_find(self, x, y, bs):
        """Run the learning-rate finder on a throw-away learner and plot it."""
        db = create_db(x, y, bs=bs)
        learn = Learner(db, self.model, loss_func=self.loss)
        learn.lr_find()
        clear_output()
        learn.recorder.plot(suggestion=True)

    def predict(self, x):
        """Return the model output for ``x`` as a NumPy array.

        Requires :meth:`fit` to have been called, because the model is read
        from ``self.learn``.
        """
        self.learn.model.eval()
        with torch.no_grad():
            # Use the instance's device; there is no module-level `device`.
            inputs = torch.Tensor(x).to(self.device)
            y_preds = self.learn.model(inputs)
        return y_preds.cpu().numpy()
예제 #4
0
        num_workers=NUM_WORKERS,
        normalization=STATISTICS,
    )

    # init model
    swa_model = MODEL(num_classes=N_CLASSES, dropout_p=DROPOUT)
    model = MODEL(num_classes=N_CLASSES, dropout_p=DROPOUT)

    # nullify all swa model parameters
    swa_params = swa_model.parameters()
    for swa_param in swa_params:
        swa_param.data = torch.zeros_like(swa_param.data)

    # average model
    n_swa = len(os.listdir(MODELS_FOLDER))
    print(f"Averaging {n_swa} models")
    for file in os.listdir(MODELS_FOLDER):
        model.load_state_dict(torch.load(f'{MODELS_FOLDER}/{file}')['model'])
        model_params = model.parameters()
        for model_param, swa_param in zip(model_params, swa_params):
            swa_param.data += model_param.data / n_swa

    # fix batch norm
    print("Fixing batch norm")
    swa_model.to(DEVICE)
    learn = Learner(data, model, model_dir=MODELS_FOLDER, loss_func=CRITERION, opt_func=OPTIMIZER, wd=WD)
    learn.model = convert_model(learn.model)
    learn.model = nn.DataParallel(learn.model).to(DEVICE)
    fix_batchnorm(learn.model, learn.data.train_dl)
    learn.save('swa_model')
예제 #5
0
class MaskRCNN(ArcGISModel):
    """
    Creates a ``MaskRCNN`` Instance segmentation object

    =====================   ===========================================
    **Argument**            **Description**
    ---------------------   -------------------------------------------
    data                    Required fastai Databunch. Returned data object from
                            ``prepare_data`` function.
    ---------------------   -------------------------------------------
    backbone                Optional function. Backbone CNN model to be used for
                            creating the base of the `MaskRCNN`, which
                            is `resnet50` by default.
                            Compatible backbones: 'resnet50'
    ---------------------   -------------------------------------------
    pretrained_path         Optional string. Path where pre-trained model is
                            saved.
    =====================   ===========================================

    :returns: ``MaskRCNN`` Object
    """
    def __init__(self, data, backbone=None, pretrained_path=None):

        super().__init__(data, backbone)

        # The network built below is always torchvision's ResNet-50 FPN
        # Mask R-CNN, so the recorded backbone is fixed to resnet50 whatever
        # `backbone` was passed.
        self._backbone = models.resnet50

        self._code = instance_detector_prf

        model = models.detection.maskrcnn_resnet50_fpn(pretrained=True,
                                                       min_size=data.chip_size)

        # Replace both prediction heads with heads sized for `data.c` classes.
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        model.roi_heads.box_predictor = FastRCNNPredictor(in_features, data.c)
        in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
        hidden_layer = 256
        model.roi_heads.mask_predictor = MaskRCNNPredictor(
            in_features_mask, hidden_layer, data.c)

        self.learn = Learner(data, model, loss_func=mask_rcnn_loss)
        self.learn.callbacks.append(train_callback(self.learn))
        self.learn.model = self.learn.model.to(self._device)

        # fixes for zero division error when slice is passed
        self.learn.layer_groups = split_model_idx(self.learn.model, [28])
        self.learn.create_opt(lr=3e-3)

        if pretrained_path is not None:
            self.load(pretrained_path)

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        return '<%s>' % (type(self).__name__)

    @property
    def supported_backbones(self):
        """Names of the backbones this model can be built with."""
        return [models.detection.maskrcnn_resnet50_fpn.__name__]

    @classmethod
    def from_model(cls, emd_path, data=None):
        """
        Creates a ``MaskRCNN`` Instance segmentation object from an Esri Model Definition (EMD) file.

        =====================   ===========================================
        **Argument**            **Description**
        ---------------------   -------------------------------------------
        emd_path                Required string. Path to Esri Model Definition
                                file.
        ---------------------   -------------------------------------------
        data                    Required fastai Databunch or None. Returned data
                                object from ``prepare_data`` function or None for
                                inferencing.

        =====================   ===========================================

        :returns: `MaskRCNN` Object
        """

        emd_path = Path(emd_path)
        with open(emd_path) as f:
            emd = json.load(f)

        # A relative model path is resolved against the EMD file's folder.
        model_file = Path(emd['ModelFile'])
        if not model_file.is_absolute():
            model_file = emd_path.parent / model_file

        model_params = emd['ModelParameters']

        # Two EMD layouts are accepted: 'Value'/'Name' keys, or the
        # 'ClassValue'/'ClassName' keys used as a fallback.
        try:
            class_mapping = {i['Value']: i['Name'] for i in emd['Classes']}
            color_mapping = {i['Value']: i['Color'] for i in emd['Classes']}
        except KeyError:
            class_mapping = {
                i['ClassValue']: i['ClassName']
                for i in emd['Classes']
            }
            color_mapping = {
                i['ClassValue']: i['Color']
                for i in emd['Classes']
            }

        if data is None:
            # NOTE(review): only the *name* of the temporary directory is
            # kept; the TemporaryDirectory object is dropped immediately, so
            # the directory itself may already be removed -- confirm this is
            # intended.
            data = _EmptyData(path=tempfile.TemporaryDirectory().name,
                              loss_func=None,
                              c=len(class_mapping) + 1,
                              chip_size=emd['ImageHeight'])
            data.class_mapping = class_mapping
            data.color_mapping = color_mapping

        return cls(data, **model_params, pretrained_path=str(model_file))

    def _create_emd(self, path):
        """Write the EMD file next to ``path`` and return ``path.stem``.

        The base-class template is extended with the inference settings and
        one ``Classes`` entry (value, name, colour) per non-background class.
        """
        import random
        super()._create_emd(path)
        self._emd_template["Framework"] = "arcgis.learn.models._inferencing"
        self._emd_template["ModelConfiguration"] = "_maskrcnn_inferencing"
        self._emd_template["InferenceFunction"] = "ArcGISInstanceDetector.py"

        self._emd_template["ExtractBands"] = [0, 1, 2]
        self._emd_template['Classes'] = []

        # name -> value lookup, built once instead of once per class
        inverse_class_mapping = {
            v: k
            for k, v in self._data.class_mapping.items()
        }
        for class_name in self._data.classes[1:]:  # 0th index is background
            class_value = inverse_class_mapping[class_name]
            if is_no_color(self._data.color_mapping):
                color = [random.choice(range(256)) for _ in range(3)]
            else:
                color = self._data.color_mapping[class_value]
            self._emd_template['Classes'].append({
                "Value": class_value,
                "Name": class_name,
                "Color": color,
            })

        # Context manager so the file is flushed and closed deterministically.
        with open(path.with_suffix('.emd'), 'w') as emd_file:
            json.dump(self._emd_template, emd_file, indent=4)
        return path.stem

    @property
    def _model_metrics(self):
        return {}

    def _predict_results(self, xb):
        """Run the model on batch ``xb`` and return NumPy predictions.

        :returns: one dict per image with the keys ``'masks'``, ``'boxes'``,
            ``'labels'`` and ``'scores'``, each a NumPy array.
        """
        self.learn.model.eval()
        # The model lives on self._device (see __init__), so the batch must
        # go to the same device rather than unconditionally to CUDA.
        with torch.no_grad():
            predictions = self.learn.model(xb.to(self._device))

        keys = ('masks', 'boxes', 'labels', 'scores')
        predictionsf = []
        for prediction in predictions:
            predictionsf.append(
                {key: prediction[key].detach().cpu().numpy() for key in keys})
            # Drop the device tensors so their memory can be released below.
            for key in keys:
                del prediction[key]
        torch.cuda.empty_cache()

        return predictionsf

    def _predict_postprocess(self,
                             predictions,
                             threshold=0.5,
                             box_threshold=0.5):
        """Turn raw predictions into label masks and filtered boxes.

        :param predictions: output of :meth:`_predict_results`.
        :param threshold: a pixel belongs to instance ``j`` when its mask
            value exceeds this; the pixel then receives the label ``j + 1``.
        :param box_threshold: boxes whose score exceeds this are kept.
        :returns: ``(pred_mask, pred_box)`` -- one 2-D label array and one
            list of boxes per image.
        """
        pred_mask = []
        pred_box = []

        for prediction in predictions:
            out = prediction['masks'].squeeze()
            boxes = []

            if out.shape[0] != 0:  # handle for prediction with n masks
                if len(out.shape) == 2:
                    # squeeze() collapsed a single predicted mask to h x w
                    out = out[None]
                ymask = np.where(out[0] > threshold, 1, 0)
                for j in range(1, out.shape[0]):
                    ymask += np.where(out[j] > threshold, j + 1, 0)
                for j in range(out.shape[0]):
                    if prediction['scores'][j] > box_threshold:
                        boxes.append(prediction['boxes'][j])
            else:
                # handle for not predicted masks
                ymask = np.zeros(
                    (self._data.chip_size, self._data.chip_size))
            pred_mask.append(ymask)
            pred_box.append(boxes)
        return pred_mask, pred_box

    def show_results(self,
                     mode='mask',
                     mask_threshold=0.5,
                     box_threshold=0.7,
                     nrows=None,
                     imsize=5,
                     index=0,
                     alpha=0.5,
                     cmap='tab20'):
        """
        Displays the results of a trained model on a part of the validation set.

        =====================   ===========================================
        **Argument**            **Description**
        ---------------------   -------------------------------------------
        mode                    Required arguments within ['bbox', 'mask', 'bbox_mask'].
                                    * ``bbox`` - For visualizing only bounding boxes.
                                    * ``mask`` - For visualizing only mask
                                    * ``bbox_mask`` - For visualizing both mask and bounding boxes.
        ---------------------   -------------------------------------------
        mask_threshold          Optional float. The probability above which
                                a pixel will be considered mask.
        ---------------------   -------------------------------------------
        box_threshold           Optional float. The probability above which
                                a detection will be considered valid.
        ---------------------   -------------------------------------------
        nrows                   Optional int. Number of rows of results
                                to be displayed. Defaults to the batch size
                                and is capped at the size of the batch.
        ---------------------   -------------------------------------------
        imsize                  Optional number. Size of each image cell in
                                the figure.
        ---------------------   -------------------------------------------
        index                   Currently unused.
        ---------------------   -------------------------------------------
        alpha                   Optional float. Opacity of the mask overlay.
        ---------------------   -------------------------------------------
        cmap                    Optional string. Matplotlib colormap used for
                                the mask overlay.
        =====================   ===========================================
        """

        if mode not in ['bbox', 'mask', 'bbox_mask']:
            raise Exception("mode can be only ['bbox', 'mask', 'bbox_mask']")

        show_masks = mode in ['mask', 'bbox_mask']
        # 'bbox_mask' must draw boxes as well as masks.
        show_boxes = mode in ['bbox', 'bbox_mask']

        # Get Batch
        # NOTE(review): the dataset type is passed as a string literal, not
        # as an enum member -- confirm one_batch() accepts this.
        xb, yb = self._data.one_batch('DatasetType.Valid')

        # Get Number of items; never ask for more rows than the batch holds
        if nrows is None:
            nrows = self._data.batch_size
        nrows = min(nrows, len(xb))
        ncols = 2

        predictions = self._predict_results(xb)

        pred_mask, pred_box = self._predict_postprocess(
            predictions, mask_threshold, box_threshold)

        # squeeze=False keeps `ax` two-dimensional even for a single row.
        fig, ax = plt.subplots(nrows=nrows,
                               ncols=ncols,
                               figsize=(ncols * imsize, nrows * imsize),
                               squeeze=False)
        fig.suptitle('Ground Truth / Predictions', fontsize=20)

        for i in range(nrows):
            # Left column: image with ground truth
            ax[i][0].imshow(xb[i].numpy().transpose(1, 2, 0))
            if show_masks:
                # Merge the per-instance channels into one label image,
                # offsetting each channel by the largest label used so far.
                yb_mask = yb[i][0].numpy()
                for j in range(1, yb[i].shape[0]):
                    max_unique = np.max(np.unique(yb_mask))
                    yb_j = np.where(yb[i][j] > 0, yb[i][j] + max_unique,
                                    yb[i][j])
                    yb_mask += yb_j
                ax[i][0].imshow(yb_mask, cmap=cmap, alpha=alpha)
            ax[i][0].axis('off')

            # Right column: image with predictions
            ax[i][1].imshow(xb[i].numpy().transpose(1, 2, 0))
            if show_masks:
                ax[i][1].imshow(pred_mask[i], cmap=cmap, alpha=alpha)
            if show_boxes:
                for box in pred_box[i]:
                    rect = patches.Rectangle((box[0], box[1]),
                                             box[2] - box[0],
                                             box[3] - box[1],
                                             linewidth=1,
                                             edgecolor='r',
                                             facecolor='none')
                    ax[i][1].add_patch(rect)
            ax[i][1].axis('off')
        plt.subplots_adjust(top=0.95)
        torch.cuda.empty_cache()
class MaskRCNN(ArcGISModel):
    """
    Creates a ``MaskRCNN`` Instance segmentation object

    =====================   ===========================================
    **Argument**            **Description**
    ---------------------   -------------------------------------------
    data                    Required fastai Databunch. Returned data object from
                            ``prepare_data`` function.
    ---------------------   -------------------------------------------
    backbone                Optional function. Backbone CNN model to be used for
                            creating the base of the `MaskRCNN`, which
                            is `resnet50` by default. 
                            Compatible backbones: 'resnet50'
    ---------------------   -------------------------------------------
    pretrained_path         Optional string. Path where pre-trained model is
                            saved.
    =====================   ===========================================

    :returns: ``MaskRCNN`` Object
    """
    def __init__(self, data, backbone=None, pretrained_path=None):
        """Build the Mask R-CNN network and its fastai learner.

        :param data: data object; ``chip_size`` and ``c`` are read from it,
            plus the scaled band statistics when the model is multispectral.
        :param backbone: ``None`` (resnet50), the name of an attribute of
            ``models``, or a backbone function.
        :param pretrained_path: optional path of saved weights to load.
        :raises Exception: when the backbone is not a supported one.
        """

        super().__init__(data, backbone)

        if self._is_multispectral:
            # Keep the multispectral backbone aside; it is restored at the
            # end of the constructor.
            self._backbone_ms = self._backbone
            self._backbone = self._orig_backbone
            scaled_mean_values = data._scaled_mean_values[
                data._extract_bands].tolist()
            scaled_std_values = data._scaled_std_values[
                data._extract_bands].tolist()

        if backbone is None:
            self._backbone = models.resnet50
        elif isinstance(backbone, str):
            self._backbone = getattr(models, backbone)
        else:
            self._backbone = backbone

        if not self._check_backbone_support(self._backbone):
            raise Exception(
                f"Enter only compatible backbones from {', '.join(self.supported_backbones)}"
            )

        self._code = instance_detector_prf

        backbone_name = self._backbone.__name__
        # Image resizing limits shared by every variant built below.
        size_kwargs = {
            'min_size': 1.5 * data.chip_size,
            'max_size': 2 * data.chip_size,
        }
        # Band statistics are only passed for multispectral data.
        norm_kwargs = {}
        if self._is_multispectral:
            norm_kwargs = {
                'image_mean': scaled_mean_values,
                'image_std': scaled_std_values,
            }

        # Compare names by value; identity comparison against a string
        # literal is not reliable.
        if backbone_name == 'resnet50':
            model = models.detection.maskrcnn_resnet50_fpn(pretrained=True,
                                                           **size_kwargs)
            if self._is_multispectral:
                model.backbone = _change_tail(model.backbone, data)
                model.transform.image_mean = scaled_mean_values
                model.transform.image_std = scaled_std_values
        elif backbone_name in ['resnet18', 'resnet34']:
            if self._is_multispectral:
                # The cut point is looked up by the name of the selected
                # backbone (the previously referenced `backbone_fn` does not
                # exist in this scope).
                backbone_small = create_body(
                    self._backbone_ms,
                    cut=_get_backbone_meta(backbone_name)['cut'])
            else:
                backbone_small = create_body(self._backbone)
            backbone_small.out_channels = 512
            model = models.detection.MaskRCNN(backbone_small, 91,
                                              **size_kwargs, **norm_kwargs)
        else:
            backbone_fpn = resnet_fpn_backbone(backbone_name, True)
            if self._is_multispectral:
                backbone_fpn = _change_tail(backbone_fpn, data)
            model = models.detection.MaskRCNN(backbone_fpn, 91,
                                              **size_kwargs, **norm_kwargs)

        # Replace both prediction heads with heads sized for `data.c` classes.
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        model.roi_heads.box_predictor = FastRCNNPredictor(in_features, data.c)
        in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
        hidden_layer = 256
        model.roi_heads.mask_predictor = MaskRCNNPredictor(
            in_features_mask, hidden_layer, data.c)

        self.learn = Learner(data, model, loss_func=mask_rcnn_loss)
        self.learn.callbacks.append(train_callback(self.learn))
        self.learn.model = self.learn.model.to(self._device)
        self.learn.c_device = self._device

        # fixes for zero division error when slice is passed
        idx = 27
        if backbone_name in ['resnet18', 'resnet34']:
            idx = self._freeze()
        self.learn.layer_groups = split_model_idx(self.learn.model, [idx])
        self.learn.create_opt(lr=3e-3)

        # make first conv weights learnable
        self._arcgis_init_callback()

        if pretrained_path is not None:
            self.load(pretrained_path)

        if self._is_multispectral:
            # Restore the multispectral backbone set aside at the top.
            self._orig_backbone = self._backbone
            self._backbone = self._backbone_ms

    def unfreeze(self):
        for _, param in self.learn.model.named_parameters():
            param.requires_grad = True

    def _freeze(self):
        """Freeze the backbone and return the index of its last flat layer.

        Every layer of the flattened backbone has its parameters excluded
        from training, except ``BatchNorm2d`` layers, which are left as-is.
        """
        backbone_layers = flatten_model(self.learn.model.backbone)
        for idx, layer in enumerate(backbone_layers):
            if not isinstance(layer, torch.nn.BatchNorm2d):
                for weight in layer.parameters():
                    weight.requires_grad = False
        return idx

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        return '<%s>' % (type(self).__name__)

    @property
    def supported_backbones(self):
        """
        Supported torchvision backbones for this model.
        """
        return [*self._resnet_family]

    @classmethod
    def from_model(cls, emd_path, data=None):
        """
        Builds a ``MaskRCNN`` object from an Esri Model Definition (EMD) file.

        =====================   ===========================================
        **Argument**            **Description**
        ---------------------   -------------------------------------------
        emd_path                Required string. Path to the Esri Model
                                Definition file.
        ---------------------   -------------------------------------------
        data                    Required fastai Databunch or None. Data
                                object returned by ``prepare_data``, or None
                                to build a placeholder for inferencing.
        =====================   ===========================================

        :returns: `MaskRCNN` Object
        """
        emd_path = Path(emd_path)
        with open(emd_path) as emd_file:
            emd = json.load(emd_file)

        # A relative model path is resolved against the EMD file's folder.
        model_file = Path(emd['ModelFile'])
        if not model_file.is_absolute():
            model_file = emd_path.parent / model_file

        model_params = emd['ModelParameters']

        def _mappings(value_key, name_key):
            # value -> name and value -> colour lookups for one key layout
            entries = emd['Classes']
            names = {entry[value_key]: entry[name_key] for entry in entries}
            colors = {entry[value_key]: entry['Color'] for entry in entries}
            return names, colors

        # Try the 'Value'/'Name' layout first, then the
        # 'ClassValue'/'ClassName' layout.
        try:
            class_mapping, color_mapping = _mappings('Value', 'Name')
        except KeyError:
            class_mapping, color_mapping = _mappings('ClassValue',
                                                     'ClassName')

        if data is None:
            # No data supplied: describe the model with a placeholder object
            # filled from the EMD (one extra class for the background).
            data = _EmptyData(path=emd_path.parent.parent,
                              loss_func=None,
                              c=len(class_mapping) + 1,
                              chip_size=emd['ImageHeight'])
            data.class_mapping = class_mapping
            data.color_mapping = color_mapping
            data.emd_path = emd_path
            data.emd = emd
            data = get_multispectral_data_params_from_emd(data, emd)

        return cls(data, **model_params, pretrained_path=str(model_file))

    def _get_emd_params(self):
        import random

        _emd_template = {}
        _emd_template["Framework"] = "arcgis.learn.models._inferencing"
        _emd_template["ModelConfiguration"] = "_maskrcnn_inferencing"
        _emd_template["InferenceFunction"] = "ArcGISInstanceDetector.py"

        _emd_template["ExtractBands"] = [0, 1, 2]
        _emd_template['Classes'] = []
        class_data = {}
        for i, class_name in enumerate(
                self._data.classes[1:]):  # 0th index is background
            inverse_class_mapping = {
                v: k
                for k, v in self._data.class_mapping.items()
            }
            class_data["Value"] = inverse_class_mapping[class_name]
            class_data["Name"] = class_name
            color = [random.choice(range(256)) for i in range(3)] if is_no_color(self._data.color_mapping) else \
            self._data.color_mapping[inverse_class_mapping[class_name]]
            class_data["Color"] = color
            _emd_template['Classes'].append(class_data.copy())

        return _emd_template

    @property
    def _model_metrics(self):
        return {
            'average_precision_score':
            self.average_precision_score(show_progress=False)
        }

    def _predict_results(self, xb):

        self.learn.model.eval()
        xb_l = xb.to(self._device)
        predictions = self.learn.model(list(xb_l))
        xb_l = xb_l.detach().cpu()
        del xb_l
        predictionsf = []
        for i in range(len(predictions)):
            predictionsf.append({})
            predictionsf[i]['masks'] = predictions[i]['masks'].detach().cpu(
            ).numpy()
            predictionsf[i]['boxes'] = predictions[i]['boxes'].detach().cpu(
            ).numpy()
            predictionsf[i]['labels'] = predictions[i]['labels'].detach().cpu(
            ).numpy()
            predictionsf[i]['scores'] = predictions[i]['scores'].detach().cpu(
            ).numpy()
            del predictions[i]['masks']
            del predictions[i]['boxes']
            del predictions[i]['labels']
            del predictions[i]['scores']
        if self._device == torch.device('cuda'):
            torch.cuda.empty_cache()
        return predictionsf

    def _predict_postprocess(self,
                             predictions,
                             threshold=0.5,
                             box_threshold=0.5):

        pred_mask = []
        pred_box = []

        for i in range(len(predictions)):
            out = predictions[i]['masks'].squeeze()
            pred_box.append([])

            if out.shape[0] != 0:  # handle for prediction with n masks
                if len(
                        out.shape
                ) == 2:  # for out dimension hxw (in case of only one predicted mask)
                    out = out[None]
                ymask = np.where(out[0] > threshold, 1, 0)
                if predictions[i]['scores'][0] > box_threshold:
                    pred_box[i].append(predictions[i]['boxes'][0])
                for j in range(1, out.shape[0]):
                    ym1 = np.where(out[j] > threshold, j + 1, 0)
                    ymask += ym1
                    if predictions[i]['scores'][j] > box_threshold:
                        pred_box[i].append(predictions[i]['boxes'][j])
            else:
                ymask = np.zeros(
                    (self._data.chip_size,
                     self._data.chip_size))  # handle for not predicted masks
            pred_mask.append(ymask)
        return pred_mask, pred_box

    def show_results(self,
                     rows=4,
                     mode='mask',
                     mask_threshold=0.5,
                     box_threshold=0.7,
                     imsize=5,
                     index=0,
                     alpha=0.5,
                     cmap='tab20',
                     **kwargs):
        """
        Displays the results of a trained model on a part of the validation set.

        =====================   ===========================================
        **Argument**            **Description**
        ---------------------   -------------------------------------------
        mode                    Required arguments within ['bbox', 'mask', 'bbox_mask'].
                                    * ``bbox`` - For visualizing only boundig boxes.
                                    * ``mask`` - For visualizing only mask
                                    * ``bbox_mask`` - For visualizing both mask and bounding boxes.
        ---------------------   -------------------------------------------
        mask_threshold          Optional float. The probabilty above which
                                a pixel will be considered mask.
        ---------------------   -------------------------------------------
        box_threshold           Optional float. The pobabilty above which
                                a detection will be considered valid.
        ---------------------   -------------------------------------------
        nrows                   Optional int. Number of rows of results
                                to be displayed.
        =====================   ===========================================
        """
        self._check_requisites()
        if mode not in ['bbox', 'mask', 'bbox_mask']:
            raise Exception("mode can be only ['bbox', 'mask', 'bbox_mask']")

        # Get Number of items
        nrows = rows
        ncols = 2

        type_data_loader = kwargs.get(
            'data_loader',
            'validation')  # options : traininig, validation, testing
        if type_data_loader == 'training':
            data_loader = self._data.train_dl
        elif type_data_loader == 'validation':
            data_loader = self._data.valid_dl
        elif type_data_loader == 'testing':
            data_loader = self._data.test_dl
        else:
            e = Exception(f'could not find {type_data_loader} in data.')
            raise (e)

        statistics_type = kwargs.get(
            'statistics_type', 'dataset')  # Accepted Values `dataset`, `DRA`

        cmap_fn = getattr(matplotlib.cm, cmap)

        title_font_size = 16
        if kwargs.get('top', None) is not None:
            top = kwargs.get('top')
        else:
            top = 1 - (math.sqrt(title_font_size) /
                       math.sqrt(100 * nrows * imsize))

        x_batch, y_batch = [], []
        i = 0
        dl_iterater = iter(data_loader)
        while i < nrows:
            x, y = next(dl_iterater)
            x_batch.append(x)
            y_batch.append(y)
            i += self._data.batch_size
        x_batch = torch.cat(x_batch)
        y_batch = torch.cat(y_batch)

        # Get Predictions
        prediction_store = []
        for i in range(0, x_batch.shape[0], self._data.batch_size):
            prediction_store.extend(
                self._predict_results(x_batch[i:i + self._data.batch_size]))
        pred_mask, pred_box = self._predict_postprocess(
            prediction_store, mask_threshold, box_threshold)

        if self._is_multispectral:
            rgb_bands = kwargs.get('rgb_bands',
                                   self._data._symbology_rgb_bands)

            e = Exception(
                '`rgb_bands` should be a valid band_order, list or tuple of length 3 or 1.'
            )
            symbology_bands = []
            if not (len(rgb_bands) == 3 or len(rgb_bands) == 1):
                raise (e)
            for b in rgb_bands:
                if type(b) == str:
                    b_index = self._bands.index(b)
                elif type(b) == int:
                    self._bands[
                        b]  # To check if the band index specified by the user really exists.
                    b_index = b
                else:
                    raise (e)
                b_index = self._data._extract_bands.index(b_index)
                symbology_bands.append(b_index)

            # Denormalize X
            if self._data._do_normalize:
                x_batch = (self._data._scaled_std_values[
                    self._data._extract_bands].view(1, -1, 1, 1).to(x_batch) *
                           x_batch) + self._data._scaled_mean_values[
                               self._data._extract_bands].view(1, -1, 1,
                                                               1).to(x_batch)

            # Extract RGB Bands
            symbology_x_batch = x_batch[:, symbology_bands]
            if statistics_type == 'DRA':
                shp = symbology_x_batch.shape
                min_vals = symbology_x_batch.view(shp[0], shp[1],
                                                  -1).min(dim=2)[0]
                max_vals = symbology_x_batch.view(shp[0], shp[1],
                                                  -1).max(dim=2)[0]
                symbology_x_batch = symbology_x_batch / (
                    max_vals.view(shp[0], shp[1], 1, 1) -
                    min_vals.view(shp[0], shp[1], 1, 1) + .001)

            # Channel first to channel last for plotting
            symbology_x_batch = symbology_x_batch.permute(0, 2, 3, 1)
            # Clamp float values to range 0 - 1
            if symbology_x_batch.mean() < 1:
                symbology_x_batch = symbology_x_batch.clamp(0, 1)
        else:
            symbology_x_batch = x_batch.permute(0, 2, 3, 1)

        # Squeeze channels if single channel (1, 224, 224) -> (224, 224)
        if symbology_x_batch.shape[-1] == 1:
            symbology_x_batch = symbology_x_batch.squeeze()

        fig, ax = plt.subplots(nrows=nrows,
                               ncols=ncols,
                               figsize=(ncols * imsize, nrows * imsize))
        fig.suptitle('Ground Truth / Predictions', fontsize=title_font_size)
        for i in range(nrows):
            if nrows == 1:
                ax_i = ax
            else:
                ax_i = ax[i]

            # Ground Truth
            ax_i[0].imshow(symbology_x_batch[i])
            ax_i[0].axis('off')
            if mode in ['mask', 'bbox_mask']:
                n_instance = y_batch[i].unique().shape[0]
                y_merged = y_batch[i].max(dim=0)[0].cpu().numpy()
                y_rgba = cmap_fn._resample(n_instance)(y_merged)
                y_rgba[y_merged == 0] = 0
                y_rgba[:, :, -1] = alpha
                ax_i[0].imshow(y_rgba)
            ax_i[0].axis('off')

            # Predictions
            ax_i[1].imshow(symbology_x_batch[i])
            ax_i[1].axis('off')
            if mode in ['mask', 'bbox_mask']:
                n_instance = np.unique(pred_mask[i]).shape[0]
                p_rgba = cmap_fn._resample(n_instance)(pred_mask[i])
                p_rgba[pred_mask[i] == 0] = 0
                p_rgba[:, :, -1] = alpha
                ax_i[1].imshow(p_rgba)
            if mode in ['bbox_mask', 'bbox']:
                if pred_box[i] != []:
                    for num_boxes in pred_box[i]:
                        rect = patches.Rectangle((num_boxes[0], num_boxes[1]),
                                                 num_boxes[2] - num_boxes[0],
                                                 num_boxes[3] - num_boxes[1],
                                                 linewidth=1,
                                                 edgecolor='r',
                                                 facecolor='none')
                        ax_i[1].add_patch(rect)
            ax_i[1].axis('off')
        plt.subplots_adjust(top=top)
        if self._device == torch.device('cuda'):
            torch.cuda.empty_cache()

    def average_precision_score(self,
                                detect_thresh=0.5,
                                iou_thresh=0.5,
                                mean=False,
                                show_progress=True):
        """
        Computes average precision on the validation set, either per class
        or as a single mean value.

        =====================   ===========================================
        **Argument**            **Description**
        ---------------------   -------------------------------------------
        detect_thresh           Optional float. The probability above which
                                a detection will be considered for computing
                                average precision.
        ---------------------   -------------------------------------------
        iou_thresh              Optional float. The intersection over union
                                threshold with the ground truth mask, above
                                which a predicted mask will be
                                considered a true positive.
        ---------------------   -------------------------------------------
        mean                    Optional bool. If False returns class-wise
                                average precision otherwise returns mean
                                average precision.
        ---------------------   -------------------------------------------
        show_progress           Optional bool. Forwarded unchanged to
                                ``compute_class_AP`` (presumably toggles
                                its progress display).
        =====================   ===========================================

        :returns: `dict` if mean is False otherwise `float`
        """
        self._check_requisites()
        valid_dl = self._data.valid_dl

        if not mean:
            # Class-wise mode: request ``c - 1`` values and pair them with
            # every class name except the first one (presumably background).
            per_class_ap = compute_class_AP(self, valid_dl, self._data.c - 1,
                                            show_progress, detect_thresh,
                                            iou_thresh)
            return dict(zip(self._data.classes[1:], per_class_ap))

        # Mean mode: a class count of 1 and the ``mean`` flag are handed to
        # the helper, whose result is returned untouched.
        return compute_class_AP(self, valid_dl, 1, show_progress,
                                detect_thresh, iou_thresh, mean)
# +
# Build the network; `Dnet_1ch` is not defined in this section of the file.
model = Dnet_1ch()

# Bundle data and model into a Learner, using the combined loss, the
# `Over9000` optimizer function and the four `Metric_*` metric objects.
learn = Learner(data,
                model,
                loss_func=Loss_combine(),
                opt_func=Over9000,
                metrics=[
                    Metric_grapheme(),
                    Metric_vowel(),
                    Metric_consonant(),
                    Metric_tot()
                ])

# Replace the learner's model with an nn.DataParallel wrapper around it.
learn.model = nn.DataParallel(learn.model)

# CSV logger bound to this learner; its name is built from `fold`, which is
# defined outside this section. It is only referenced by the commented-out
# fit call further down.
logger = CSVLogger(learn, f'log{fold}')
# NOTE(review): this stores the float 1.0 in an attribute named `clip_grad`.
# If `Learner` is fastai v1's class, `clip_grad` is a method there, so this
# assignment would replace the method instead of enabling gradient clipping
# -- confirm whether `learn.clip_grad(1.0)` was intended.
learn.clip_grad = 1.0
# Split at `model.head1` (the un-wrapped model's attribute), then unfreeze.
learn.split([model.head1])
learn.unfreeze()
# -

# Show the learner summary as the notebook cell output.
learn.summary()

# +
# learn.fit_one_cycle(32, max_lr=slice(0.2e-2,1e-2), wd=[1e-3,0.1e-1], pct_start=0.0,
#     div_factor=100, callbacks = [logger, SaveModelCallback(learn,monitor='metric_tot',
#     mode='max',name=f'model_{fold}'),MixUpCallback(learn)])

# changed config
예제 #8
0
class ModelExtension(ArcGISModel):
    """
    Creates a ``ModelExtension`` object, object detection model to train a model from your own source.

    =====================   ============================================================
    **Argument**            **Description**
    ---------------------   ------------------------------------------------------------
    data                    Required fastai Databunch. Returned data object from
                            ``prepare_data`` function.
    ---------------------   ------------------------------------------------------------
    model_conf              A class definition contains the following methods:

                                * ``get_model(self, data, backbone=None)``: for model definition,
                                
                                * ``on_batch_begin(self, learn, model_input_batch, model_target_batch)``: for 
                                  feeding input to the model during training, 

                                * ``transform_input(self, xb)``: for feeding input to the model during
                                  inferencing/validation,

                                * ``transform_input_multispectral(self, xb)``: for feeding input to the
                                  model during inferencing/validation in case of multispectral data,

                                * ``loss(self, model_output, *model_target)``: to return loss value of the model, and 

                                * ``post_process(self, pred, nms_overlap, thres, chip_size, device)``: to post-process
                                  the output of the model.
    ---------------------   ------------------------------------------------------------
    backbone                Optional function. If custom model requires any backbone.
    ---------------------   ------------------------------------------------------------
    pretrained_path         Optional string. Path where pre-trained model is
                            saved.
    =====================   ============================================================

    :return: ``ModelExtension`` Object
    """
    def __init__(self, data, model_conf, backbone=None, pretrained_path=None):
        """
        Builds the network described by ``model_conf`` and wraps it in a
        ``Learner`` whose loss is ``model_conf.loss``.

        ``model_conf`` is the configuration *class*; it is instantiated here
        and both the class and the instance are kept on the object. When
        ``pretrained_path`` is given, ``self.load`` is called with it once
        the learner has been set up.
        """
        super().__init__(data, backbone)

        self.model_conf = model_conf()
        self.model_conf_class = model_conf
        self._backend = 'pytorch'

        network = self.model_conf.get_model(data, backbone)
        if self._is_multispectral:
            network = _change_tail(network, data)

        # Distributed setup is only attempted outside notebooks on POSIX
        # systems; everywhere else a plain learner is used.
        use_distributed = False
        if not _isnotebook() and os.name == 'posix':
            _set_ddp_multigpu(self)
            use_distributed = bool(self._multigpu_training)

        learner = Learner(data, network, loss_func=self.model_conf.loss)
        if use_distributed:
            learner = learner.to_distributed(self._rank_distributed)
        self.learn = learner

        # Let the user configuration rewrite each batch before it is used.
        batch_hook = self.train_callback(self.learn,
                                         self.model_conf.on_batch_begin)
        self.learn.callbacks.append(batch_hook)

        self._code = code
        self._arcgis_init_callback()  # make first conv weights learnable

        if pretrained_path is not None:
            self.load(pretrained_path)

    if HAS_FASTAI:

        class train_callback(LearnerCallback):
            """
            Learner callback that passes every batch through a user supplied
            ``on_batch_begin`` function before it is used.
            """
            def __init__(self, learn, on_batch_begin_fn):
                super().__init__(learn)
                # Called as fn(learn, input_batch, target_batch) and expected
                # to return the (input, target) pair to use instead.
                self.on_batch_begin_fn = on_batch_begin_fn

            def on_batch_begin(self, last_input, last_target, train, **kwargs):
                new_input, new_target = self.on_batch_begin_fn(
                    self.learn, last_input, last_target)
                # Returned under the same keys the values arrived with.
                return dict(last_input=new_input, last_target=new_target)

    def _analyze_pred(self,
                      pred,
                      thresh=0.5,
                      nms_overlap=0.1,
                      ret_scores=True,
                      device=None):
        return self.model_conf.post_process(pred, nms_overlap, thresh,
                                            self.learn.data.chip_size, device)

    def _get_emd_params(self):
        import random
        _emd_template = {}
        _emd_template["Framework"] = "arcgis.learn.models._inferencing"
        _emd_template["InferenceFunction"] = "ArcGISObjectDetector.py"
        _emd_template["ModelConfiguration"] = "_model_extension_inferencing"
        _emd_template["ModelType"] = "ObjectDetection"
        _emd_template["ExtractBands"] = [0, 1, 2]
        _emd_template['Classes'] = []
        _emd_template['ModelConfigurationFile'] = "ModelConfiguration.py"
        _emd_template['ModelFileConfigurationClass'] = type(
            self.model_conf).__name__

        class_data = {}
        for i, class_name in enumerate(
                self._data.classes[1:]):  # 0th index is background
            inverse_class_mapping = {
                v: k
                for k, v in self._data.class_mapping.items()
            }
            class_data["Value"] = inverse_class_mapping[class_name]
            class_data["Name"] = class_name
            color = [random.choice(range(256)) for i in range(3)]
            class_data["Color"] = color
            _emd_template['Classes'].append(class_data.copy())

        return _emd_template

    @property
    def _is_model_extension(self):
        """Marker property: always ``True`` for ``ModelExtension`` objects."""
        return True

    @classmethod
    def from_model(cls, emd_path, data=None):
        """
        Creates a ``ModelExtension`` object from an Esri Model Definition (EMD) file.

        =====================   ===========================================
        **Argument**            **Description**
        ---------------------   -------------------------------------------
        emd_path                Required string. Path to Esri Model Definition
                                file.
        ---------------------   -------------------------------------------
        data                    Optional fastai Databunch. Returned data
                                object from ``prepare_data`` function. When
                                None (the default), a placeholder data object
                                is built from the EMD for inferencing.
        =====================   ===========================================

        :returns: `ModelExtension` Object
        """

        emd_path = Path(emd_path)

        with open(emd_path) as f:
            emd = json.load(f)

        def _resolve(entry):
            # Relative entries in the EMD are relative to the EMD's folder.
            path = Path(entry)
            return path if path.is_absolute() else emd_path.parent / path

        model_file = _resolve(emd['ModelFile'])
        modelconf = _resolve(emd['ModelConfigurationFile'])
        modelconfclass = emd['ModelFileConfigurationClass']

        # Make the configuration module importable by name. The folder is
        # only added once so repeated calls do not keep growing ``sys.path``.
        conf_dir = os.path.dirname(modelconf)
        if conf_dir not in sys.path:
            sys.path.append(conf_dir)
        # ``stem`` drops the file suffix, whatever its length.
        model_configuration = getattr(importlib.import_module(modelconf.stem),
                                      modelconfclass)

        backbone = emd['ModelParameters']['backbone']

        # Class entries use either Value/Name or ClassValue/ClassName keys;
        # the second spelling is tried only when the first is missing.
        classes = emd['Classes']
        try:
            class_mapping = {c['Value']: c['Name'] for c in classes}
            color_mapping = {c['Value']: c['Color'] for c in classes}
        except KeyError:
            class_mapping = {c['ClassValue']: c['ClassName'] for c in classes}
            color_mapping = {c['ClassValue']: c['Color'] for c in classes}

        if data is None:
            # No training data supplied: build a placeholder data object
            # from the EMD contents.
            data = _EmptyData(path=emd_path.parent.parent,
                              loss_func=None,
                              c=len(class_mapping) + 1,
                              chip_size=emd['ImageHeight'])
            data.class_mapping = class_mapping
            data.color_mapping = color_mapping
            data.emd_path = emd_path
            data.emd = emd
            data.classes = ['background', *class_mapping.values()]
            data = get_multispectral_data_params_from_emd(data, emd)

        return cls(data,
                   model_configuration,
                   backbone,
                   pretrained_path=str(model_file))

    @property
    def _model_metrics(self):
        return {
            'average_precision_score':
            self.average_precision_score(show_progress=False)
        }

    def _get_y(self, bbox, clas):
        try:
            bbox = bbox.view(-1, 4)
        except Exception:
            bbox = torch.zeros(size=[0, 4])
        bb_keep = ((bbox[:, 2] - bbox[:, 0]) > 0).nonzero()[:, 0]
        return bbox[bb_keep], clas[bb_keep]

    def _intersect(self, box_a, box_b):
        max_xy = torch.min(box_a[:, None, 2:], box_b[None, :, 2:])
        min_xy = torch.max(box_a[:, None, :2], box_b[None, :, :2])
        inter = torch.clamp((max_xy - min_xy), min=0)
        return inter[:, :, 0] * inter[:, :, 1]

    def _box_sz(self, b):
        return (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])

    def _jaccard(self, box_a, box_b):
        inter = self._intersect(box_a, box_b)
        union = self._box_sz(box_a).unsqueeze(1) + self._box_sz(
            box_b).unsqueeze(0) - inter
        return inter / union

    def show_results(self, rows=5, thresh=0.5, nms_overlap=0.1):
        """
        Displays the results of a trained model on a part of the validation set.

        =====================   ===========================================
        **Argument**            **Description**
        ---------------------   -------------------------------------------
        rows                    Optional int. Number of rows of results to
                                be displayed; capped at the size of the
                                validation set.
        ---------------------   -------------------------------------------
        thresh                  Optional float. Forwarded to the display
                                routine (presumably the detection
                                confidence threshold).
        ---------------------   -------------------------------------------
        nms_overlap             Optional float. Forwarded to the display
                                routine (presumably the overlap threshold
                                for non-maximum suppression).
        =====================   ===========================================
        """
        self._check_requisites()
        available = len(self._data.valid_ds)
        n_rows = min(rows, available)
        self._show_results_modified(rows=n_rows,
                                    thresh=thresh,
                                    nms_overlap=nms_overlap,
                                    model=self)

    def _show_results_multispectral(self,
                                    rows=5,
                                    thresh=0.3,
                                    nms_overlap=0.1,
                                    alpha=1,
                                    **kwargs):
        """
        Forwards to the module-level ``show_results_multispectral`` helper,
        passing ``rows`` as ``nrows`` and every other argument under its own
        name. The helper's return value is discarded.
        """
        show_results_multispectral(self,
                                   nrows=rows,
                                   thresh=thresh,
                                   nms_overlap=nms_overlap,
                                   alpha=alpha,
                                   **kwargs)

    def _show_results_modified(self, rows=5, **kwargs):
        """
        Runs the model on one validation batch and shows inputs, targets and
        predictions side by side via ``ds.x.show_xyzs``.

        ``kwargs`` are split in two: those accepted by the dataset's
        ``analyze_pred`` go there, the rest go to ``show_xyzs``.
        """

        # Never ask for more rows than the validation set holds.
        if rows > len(self._data.valid_ds):
            rows = len(self._data.valid_ds)

        ds_type = DatasetType.Valid
        # The training inputs decide whether a rows x rows grid or a single
        # column of `rows` items is shown.
        n_items = rows**2 if self.learn.data.train_ds.x._square_show_res else rows
        # Only one batch is drawn below, so cap at the loader's batch size.
        if self.learn.dl(ds_type).batch_size < n_items:
            n_items = self.learn.dl(ds_type).batch_size
        ds = self.learn.dl(ds_type).dataset
        xb, yb = self.learn.data.one_batch(ds_type, detach=False, denorm=False)
        # Forward pass in eval mode on inputs prepared by the user config.
        # NOTE(review): not wrapped in torch.no_grad() -- confirm intent.
        self.learn.model.eval()
        preds = self.learn.model(self.model_conf.transform_input(xb))
        x, y = to_cpu(xb), to_cpu(yb)
        # Undo normalisation when the data object carries a `norm` attribute;
        # `.keywords` suggests it is a functools.partial -- TODO confirm.
        norm = getattr(self.learn.data, 'norm', False)
        if norm:
            x = self.learn.data.denorm(x)
            if norm.keywords.get('do_y', False):
                y = self.learn.data.denorm(y, do_x=True)
                preds = self.learn.data.denorm(preds, do_x=True)
        # Presumably separates the kwargs `analyze_pred` accepts from the
        # remainder, which is passed on to `show_xyzs` below.
        analyze_kwargs, kwargs = split_kwargs_by_func(kwargs,
                                                      ds.y.analyze_pred)
        preds = ds.y.analyze_pred(preds, **analyze_kwargs)
        # Rebuild displayable items for the first `n_items` samples.
        xs = [ds.x.reconstruct(grab_idx(x, i)) for i in range(n_items)]
        if has_arg(ds.y.reconstruct, 'x'):
            # Target reconstruction takes the matching input item as `x`.
            ys = [
                ds.y.reconstruct(grab_idx(y, i), x=x) for i, x in enumerate(xs)
            ]
            # zip() stops at the shorter sequence, so at most len(xs) items.
            zs = [ds.y.reconstruct(z, x=x) for z, x in zip(preds, xs)]
        else:
            ys = [ds.y.reconstruct(grab_idx(y, i)) for i in range(n_items)]
            # NOTE(review): built from every prediction, not just the first
            # `n_items` -- `zs` may be longer than `xs`/`ys` here.
            zs = [ds.y.reconstruct(z) for z in preds]
        ds.x.show_xyzs(xs, ys, zs, **kwargs)

    def average_precision_score(self,
                                detect_thresh=0.2,
                                iou_thresh=0.1,
                                mean=False,
                                show_progress=True):
        """
        Computes average precision on the validation set, per class or
        averaged over all classes.

        =====================   ===========================================
        **Argument**            **Description**
        ---------------------   -------------------------------------------
        detect_thresh           Optional float. The probability above which
                                a detection will be considered for computing
                                average precision.
        ---------------------   -------------------------------------------
        iou_thresh              Optional float. The intersection over union
                                threshold with the ground truth labels, above
                                which a predicted bounding box will be
                                considered a true positive.
        ---------------------   -------------------------------------------
        mean                    Optional bool. If False returns class-wise
                                average precision otherwise returns mean
                                average precision.
        ---------------------   -------------------------------------------
        show_progress           Optional bool. Forwarded unchanged to
                                ``compute_class_AP`` (presumably toggles
                                its progress display).
        =====================   ===========================================

        :returns: `dict` if mean is False otherwise `float`
        """
        self._check_requisites()

        # One value per class, excluding the background class at index 0.
        class_aps = compute_class_AP(self,
                                     self._data.valid_dl,
                                     self._data.c - 1,
                                     show_progress,
                                     detect_thresh=detect_thresh,
                                     iou_thresh=iou_thresh)

        if not mean:
            return dict(zip(self._data.classes[1:], class_aps))

        # Imported as a module: `from statistics import mean` would shadow
        # the `mean` parameter.
        import statistics
        return statistics.mean(class_aps)