class ImageClassificationModel(hyperparams.Config):
    """Image classification model config.

    All fields have defaults, so the config can be built with no arguments
    and selectively overridden.
    """
    # NOTE(review): the `dataclasses.field(...)` defaults below only take
    # effect if this class is processed by `@dataclasses.dataclass`; the
    # decorator is not visible in this view -- confirm it is applied.

    # Number of output classes; 0 is a placeholder default.
    num_classes: int = 0
    # Input size as a list of ints; empty by default (a fresh list per
    # instance).
    input_size: List[int] = dataclasses.field(default_factory=list)
    # Backbone selection; defaults to the 'darknet' backbone. Built per
    # instance through a factory: a config object used directly as a default
    # would be shared by every instance, and dataclasses on Python 3.11+
    # reject unhashable defaults outright.
    backbone: backbones.Backbone = dataclasses.field(
        default_factory=lambda: backbones.Backbone(
            type='darknet', darknet=backbones.Darknet()))
    dropout_rate: float = 0.0
    # Normalization/activation settings; built per instance for the same
    # reason as `backbone`.
    norm_activation: common.NormActivation = dataclasses.field(
        default_factory=common.NormActivation)
    # Adds a Batch Normalization layer pre-GlobalAveragePooling in classification.
    add_head_batch_norm: bool = False
# The decorator is required: the `dataclasses.field(...)` defaults below are
# only turned into real per-instance defaults when the class is processed by
# `dataclasses.dataclass`.
@dataclasses.dataclass
class ImageClassificationModel(hyperparams.Config):
  """Image classification model config.

  All fields have defaults, so the config can be built with no arguments and
  selectively overridden.
  """
  # Number of output classes; 0 is a placeholder default.
  num_classes: int = 0
  # Input size as a list of ints; defaults to a fresh [224, 224] per instance.
  input_size: List[int] = dataclasses.field(default_factory=lambda: [224, 224])
  # Backbone selection; defaults to the 'darknet' backbone. Built per instance
  # through a factory: a config object used directly as a default would be
  # shared by every instance, and dataclasses on Python 3.11+ reject
  # unhashable defaults outright.
  backbone: backbones.Backbone = dataclasses.field(
      default_factory=lambda: backbones.Backbone(
          type='darknet', darknet=backbones.Darknet()))
  dropout_rate: float = 0.0
  # Normalization/activation settings; built per instance for the same reason
  # as `backbone`.
  norm_activation: common.NormActivation = dataclasses.field(
      default_factory=common.NormActivation)
  # Adds a Batch Normalization layer pre-GlobalAveragePooling in classification.
  add_head_batch_norm: bool = False
  # Name of the kernel initializer, given as a string.
  kernel_initializer: str = 'VarianceScaling'
# Example #3
# 0
# The decorator is required: the `dataclasses.field(...)` defaults below are
# only turned into real per-instance defaults when the class is processed by
# `dataclasses.dataclass`.
@dataclasses.dataclass
class Yolo(hyperparams.Config):
  """Yolo model config.

  Groups the backbone, decoder, head, detection generator, loss,
  normalization/activation and anchor-box sub-configs with top-level model
  settings. Every sub-config default is produced by a factory so each `Yolo`
  instance owns its own copy: a config object used directly as a default
  would be shared by all instances, and dataclasses on Python 3.11+ reject
  unhashable defaults outright.
  """
  # Input size as a list of ints (or None); defaults to a fresh [512, 512, 3].
  input_size: Optional[List[int]] = dataclasses.field(
      default_factory=lambda: [512, 512, 3])
  # Defaults to the 'darknet' backbone with model_id 'cspdarknet53'.
  backbone: backbones.Backbone = dataclasses.field(
      default_factory=lambda: backbones.Backbone(
          type='darknet',
          darknet=backbones.Darknet(model_id='cspdarknet53')))
  # Defaults to the 'yolo_decoder' decoder, version 'v4', type 'regular'.
  decoder: decoders.Decoder = dataclasses.field(
      default_factory=lambda: decoders.Decoder(
          type='yolo_decoder',
          yolo_decoder=decoders.YoloDecoder(version='v4', type='regular')))
  head: YoloHead = dataclasses.field(default_factory=YoloHead)
  detection_generator: YoloDetectionGenerator = dataclasses.field(
      default_factory=YoloDetectionGenerator)
  loss: YoloLoss = dataclasses.field(default_factory=YoloLoss)
  # Defaults to 'mish' activation with synchronized batch norm enabled.
  norm_activation: common.NormActivation = dataclasses.field(
      default_factory=lambda: common.NormActivation(
          activation='mish',
          use_sync_bn=True,
          norm_momentum=0.99,
          norm_epsilon=0.001))
  num_classes: int = 80
  anchor_boxes: AnchorBoxes = dataclasses.field(default_factory=AnchorBoxes)
  # NOTE(review): how this flag is consumed is not visible here -- confirm
  # against the model builder.
  darknet_based_model: bool = False