def _create_models(self, backbone_retinanet, num_classes, lr=1e-5):
    """
    Creates three models (model, training_model, prediction_model).

    Parameters
    ----------
    backbone_retinanet : A function to call to create a retinanet model with a given backbone.
    num_classes : The number of classes to train.
    lr : Learning rate passed to the Adam optimizer used to compile the training model.

    Returns
    -------
    model : The base model.
    training_model : The compiled training model. No multi-GPU wrapping is done here,
        so this is the very same object as ``model``.
    prediction_model : The model wrapped with utility functions to perform object
        detection (applies regression values and performs NMS).
    """
    # NOTE(review): another ``_create_models`` definition with a different signature
    # (adds ``weights`` / ``freeze_backbone``) appears right after this one. If both
    # live in the same class, the later one shadows this one -- confirm which is intended.

    # Anchor settings are left as None (presumably the library defaults -- confirm);
    # no backbone modifier is applied.
    model = backbone_retinanet(num_classes, num_anchors=None, modifier=None)

    # The training model is the base model itself, so compiling it also compiles ``model``.
    training_model = model
    prediction_model = retinanet_bbox(model=model, anchor_params=None)

    training_model.compile(
        loss={
            "regression": losses.smooth_l1(),
            "classification": losses.focal(),
        },
        # Use the ``Adam`` class name: it exists in both standalone Keras and tf.keras,
        # whereas the lowercase ``adam`` alias only exists in standalone Keras.
        optimizer=keras.optimizers.Adam(lr=lr, clipnorm=0.001),
    )
    return model, training_model, prediction_model
def _create_models(self, backbone_retinanet, num_classes, weights, freeze_backbone=False, lr=1e-5):
    """
    Creates three models (model, training_model, prediction_model).

    Parameters
    ----------
    backbone_retinanet : A function to call to create a retinanet model with a given backbone.
    num_classes : The number of classes to train.
    weights : The weights to load into the model. They are loaded through
        ``self._model_with_weights`` with ``skip_mismatch=True`` (presumably
        layers whose shapes do not match are skipped -- confirm).
    freeze_backbone : If True, ``freeze_model`` is passed as the backbone modifier
        to disable learning for the backbone.
    lr : Learning rate passed to the Adam optimizer used to compile the training model.

    Returns
    -------
    model : The base model.
    training_model : The compiled training model. No multi-GPU wrapping is done here,
        so this is the very same object as ``model``.
    prediction_model : The model wrapped with utility functions to perform object
        detection (applies regression values and performs NMS).
    """
    modifier = freeze_model if freeze_backbone else None

    # Anchor settings are left as None (presumably the library defaults -- confirm).
    model = self._model_with_weights(
        backbone_retinanet(num_classes, num_anchors=None, modifier=modifier),
        weights=weights,
        skip_mismatch=True,
    )

    # The training model is the base model itself, so compiling it also compiles ``model``.
    training_model = model
    prediction_model = retinanet_bbox(model=model, anchor_params=None)

    training_model.compile(
        loss={
            'regression': losses.smooth_l1(),
            'classification': losses.focal(),
        },
        # Use the ``Adam`` class name: it exists in both standalone Keras and tf.keras,
        # whereas the lowercase ``adam`` alias only exists in standalone Keras.
        optimizer=keras.optimizers.Adam(lr=lr, clipnorm=0.001),
    )
    return model, training_model, prediction_model