예제 #1
0
 def test_configs_conversion(self):
     """Checks that the flat config lists survive a round trip through blocks.

     `assemblenet.asn50_structure` and `assemblenet.asn_structure_weights` are
     converted to block configs with `flat_lists_to_blocks` and back with
     `blocks_to_flat_lists`; both recovered lists must equal the originals.
     """
     source_structure = assemblenet.asn50_structure
     source_weights = assemblenet.asn_structure_weights
     block_configs = assemblenet.flat_lists_to_blocks(source_structure,
                                                      source_weights)
     round_trip = assemblenet.blocks_to_flat_lists(block_configs)
     structure_back, weights_back = round_trip
     self.assertAllEqual(
         structure_back, source_structure, msg='asn50_structure')
     self.assertAllEqual(
         weights_back, source_weights, msg='asn_structure_weights')
예제 #2
0
def build_assemblenet_model(
    input_specs: tf.keras.layers.InputSpec,
    model_config: cfg.AssembleNetModel,
    num_classes: int,
    l2_regularizer: tf.keras.regularizers.Regularizer = None):
  """Builds an AssembleNet classification model.

  Args:
    input_specs: Input spec for the model input. It is passed to the backbone
      as-is and exposed to the model under the 'image' key.
    model_config: AssembleNet model config. Its `backbone` and
      `norm_activation` sub-configs drive the backbone, and
      `max_pool_preditions` is forwarded to the model.
    num_classes: Number of output classes.
    l2_regularizer: Forwarded to `build_assemblenet_v1`, which discards it.

  Returns:
    An `AssembleNetModel` wrapping the AssembleNet backbone.
  """
  input_specs_dict = {'image': input_specs}
  # `build_assemblenet_v1` takes the backbone and norm/activation sub-configs
  # as separate arguments. Passing the whole model config positionally would
  # bind `l2_regularizer` to `norm_activation_config`, so use keywords.
  backbone = build_assemblenet_v1(
      input_specs=input_specs,
      backbone_config=model_config.backbone,
      norm_activation_config=model_config.norm_activation,
      l2_regularizer=l2_regularizer)
  backbone_cfg = model_config.backbone.get()
  # Only the structure is needed by the model head; edge weights are consumed
  # by the backbone builder.
  model_structure, _ = cfg.blocks_to_flat_lists(backbone_cfg.blocks)
  model = AssembleNetModel(
      backbone,
      num_classes=num_classes,
      num_frames=backbone_cfg.num_frames,
      model_structure=model_structure,
      input_specs=input_specs_dict,
      # NOTE(review): spelled `preditions` on both the config field and the
      # constructor keyword; keep them in sync if this is ever renamed.
      max_pool_preditions=model_config.max_pool_preditions)
  return model
예제 #3
0
def build_assemblenet_plus(
    input_specs: tf.keras.layers.InputSpec,
    model_config: cfg.AssembleNetPlusModel,
    l2_regularizer: tf.keras.regularizers.Regularizer = None
) -> tf.keras.Model:
    """Builds the AssembleNet++ backbone.

    Args:
      input_specs: Input spec forwarded unchanged to `AssembleNetPlus`.
      model_config: Model-level config (not a bare backbone config). Its
        `backbone` must expose `.type` (required to be 'assemblenet++') and
        `.get()`, and its `norm_activation` supplies `use_sync_bn`,
        `norm_momentum` and `norm_epsilon`.
      l2_regularizer: Accepted for builder-signature compatibility but
        ignored.

    Returns:
      The constructed `AssembleNetPlus` backbone.

    Raises:
      AssertionError: If the backbone type is not 'assemblenet++'.
      ValueError: If `model_id` is not a key of `ASSEMBLENET_SPECS`.
    """
    del l2_regularizer

    backbone_type = model_config.backbone.type
    backbone_cfg = model_config.backbone.get()
    norm_activation_config = model_config.norm_activation
    assert backbone_type == 'assemblenet++', (
        f'Expected backbone type assemblenet++, got {backbone_type}')

    assemblenet_depth = int(backbone_cfg.model_id)
    if assemblenet_depth not in ASSEMBLENET_SPECS:
        # Build a single formatted message; passing two args to ValueError
        # would render the message as a tuple.
        raise ValueError(
            f'Not a valid assemblenet_depth: {assemblenet_depth}')
    model_structure, model_edge_weights = cfg.blocks_to_flat_lists(
        backbone_cfg.blocks)
    params = ASSEMBLENET_SPECS[assemblenet_depth]
    # Bind the batch-norm settings into the per-block constructor so every
    # block shares the same normalization configuration.
    block_fn = functools.partial(
        params['block'],
        use_sync_bn=norm_activation_config.use_sync_bn,
        bn_decay=norm_activation_config.norm_momentum,
        bn_epsilon=norm_activation_config.norm_epsilon)
    backbone = AssembleNetPlus(
        block_fn=block_fn,
        num_blocks=params['num_blocks'],
        num_frames=backbone_cfg.num_frames,
        model_structure=model_structure,
        input_specs=input_specs,
        model_edge_weights=model_edge_weights,
        use_sync_bn=norm_activation_config.use_sync_bn,
        bn_decay=norm_activation_config.norm_momentum,
        bn_epsilon=norm_activation_config.norm_epsilon,
        use_object_input=backbone_cfg.use_object_input,
        attention_mode=backbone_cfg.attention_mode)
    logging.info('Number of parameters in AssembleNet++ backbone: %f M.',
                 backbone.count_params() / 10.**6)
    return backbone
예제 #4
0
def build_assemblenet_plus_model(
        input_specs: tf.keras.layers.InputSpec,
        model_config: cfg.AssembleNetPlusModel,
        num_classes: int,
        l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None):
    """Builds an AssembleNet++ classification model.

    Args:
      input_specs: Input spec for the model input. It is passed to the
        backbone as-is and exposed to the model under the 'image' key.
      model_config: AssembleNet++ model config. The backbone is built from
        it; its `backbone` sub-config provides `blocks`, `num_frames` and
        `use_object_input`, and `max_pool_predictions` is read from the model
        config itself.
      num_classes: Number of output classes.
      l2_regularizer: Forwarded to `build_assemblenet_plus`, which discards
        it.

    Returns:
      An `AssembleNetPlusModel` wrapping the AssembleNet++ backbone.
    """
    input_specs_dict = {'image': input_specs}
    # `build_assemblenet_plus` takes the full model config and reads
    # `.backbone` / `.norm_activation` from it itself; it has no separate
    # norm/activation parameter, so pass exactly its three arguments.
    backbone = build_assemblenet_plus(
        input_specs=input_specs,
        model_config=model_config,
        l2_regularizer=l2_regularizer)
    backbone_cfg = model_config.backbone.get()
    # Only the structure is needed by the model head; edge weights are
    # consumed by the backbone builder.
    model_structure, _ = cfg.blocks_to_flat_lists(backbone_cfg.blocks)
    model = AssembleNetPlusModel(
        backbone,
        num_classes=num_classes,
        num_frames=backbone_cfg.num_frames,
        model_structure=model_structure,
        input_specs=input_specs_dict,
        max_pool_predictions=model_config.max_pool_predictions,
        use_object_input=backbone_cfg.use_object_input)
    return model
예제 #5
0
def build_assemblenet_v1(
    input_specs: tf.keras.layers.InputSpec,
    backbone_config: hyperparams.Config,
    norm_activation_config: hyperparams.Config,
    l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None
) -> tf.keras.Model:
  """Builds the AssembleNet backbone.

  Args:
    input_specs: Input spec forwarded unchanged to `AssembleNet`.
    backbone_config: Backbone config exposing `.type` (required to be
      'assemblenet') and `.get()`, whose result provides `model_id`,
      `blocks`, `num_frames` and `combine_method`.
    norm_activation_config: Supplies `use_sync_bn`, `norm_momentum` and
      `norm_epsilon` for batch normalization.
    l2_regularizer: Accepted for builder-signature compatibility but ignored.

  Returns:
    The constructed `AssembleNet` backbone.

  Raises:
    AssertionError: If the backbone type is not 'assemblenet'.
    ValueError: If `model_id` is not a key of `ASSEMBLENET_SPECS`.
  """
  del l2_regularizer

  backbone_type = backbone_config.type
  backbone_cfg = backbone_config.get()
  assert backbone_type == 'assemblenet', (
      f'Expected backbone type assemblenet, got {backbone_type}')

  assemblenet_depth = int(backbone_cfg.model_id)
  if assemblenet_depth not in ASSEMBLENET_SPECS:
    # Build a single formatted message; passing two args to ValueError would
    # render the message as a tuple.
    raise ValueError(f'Not a valid assemblenet_depth: {assemblenet_depth}')
  model_structure, model_edge_weights = cfg.blocks_to_flat_lists(
      backbone_cfg.blocks)
  params = ASSEMBLENET_SPECS[assemblenet_depth]
  # Bind the batch-norm settings into the per-block constructor so every block
  # shares the same normalization configuration.
  block_fn = functools.partial(
      params['block'],
      use_sync_bn=norm_activation_config.use_sync_bn,
      bn_decay=norm_activation_config.norm_momentum,
      bn_epsilon=norm_activation_config.norm_epsilon)
  backbone = AssembleNet(
      block_fn=block_fn,
      num_blocks=params['num_blocks'],
      num_frames=backbone_cfg.num_frames,
      model_structure=model_structure,
      input_specs=input_specs,
      model_edge_weights=model_edge_weights,
      combine_method=backbone_cfg.combine_method,
      use_sync_bn=norm_activation_config.use_sync_bn,
      bn_decay=norm_activation_config.norm_momentum,
      bn_epsilon=norm_activation_config.norm_epsilon)
  logging.info('Number of parameters in AssembleNet backbone: %f M.',
               backbone.count_params() / 10.**6)
  return backbone