コード例 #1
0
 def test_configs_conversion(self):
     """Round-trips the ASN-50 flat lists through the blocks representation.

     The flat structure/weights lists are converted to blocks and back; both
     recovered lists must equal the lists the conversion started from.
     """
     expected_structure = assemblenet.asn50_structure
     expected_weights = assemblenet.asn_structure_weights

     # Flat lists -> blocks -> flat lists.
     as_blocks = assemblenet.flat_lists_to_blocks(
         expected_structure, expected_weights)
     actual_structure, actual_weights = assemblenet.blocks_to_flat_lists(
         as_blocks)

     self.assertAllEqual(
         actual_structure, expected_structure, msg='asn50_structure')
     self.assertAllEqual(
         actual_weights, expected_weights, msg='asn_structure_weights')
コード例 #2
0
ファイル: assemblenet.py プロジェクト: vishalbelsare/models
def build_assemblenet_model(
    input_specs: tf.keras.layers.InputSpec,
    model_config: cfg.AssembleNetModel,
    num_classes: int,
    l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None):
  """Creates an `AssembleNetModel` on top of an AssembleNet v1 backbone.

  Args:
    input_specs: Input spec handed to the backbone builder; it is also passed
      to the model wrapped in a dict under the 'image' key.
    model_config: Model config; its `backbone`, `norm_activation` and
      `max_pool_predictions` fields are read here.
    num_classes: Number of classes forwarded to `AssembleNetModel`.
    l2_regularizer: Optional regularizer forwarded to the backbone builder.

  Returns:
    The constructed `AssembleNetModel`.
  """
  # Build the backbone first, then read the resolved backbone config for the
  # settings the model itself needs.
  backbone_network = build_assemblenet_v1(
      input_specs=input_specs,
      backbone_config=model_config.backbone,
      norm_activation_config=model_config.norm_activation,
      l2_regularizer=l2_regularizer)

  resolved_backbone_cfg = model_config.backbone.get()
  # Only the flat structure is needed here; the edge weights are discarded.
  flat_structure, _ = cfg.blocks_to_flat_lists(resolved_backbone_cfg.blocks)

  return AssembleNetModel(
      backbone_network,
      num_classes=num_classes,
      num_frames=resolved_backbone_cfg.num_frames,
      model_structure=flat_structure,
      input_specs={'image': input_specs},
      max_pool_predictions=model_config.max_pool_predictions)
コード例 #3
0
ファイル: assemblenet.py プロジェクト: vishalbelsare/models
def build_assemblenet_v1(
    input_specs: tf.keras.layers.InputSpec,
    backbone_config: hyperparams.Config,
    norm_activation_config: hyperparams.Config,
    l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None
) -> tf.keras.Model:
  """Builds assemblenet backbone.

  Args:
    input_specs: Input spec forwarded unchanged to `AssembleNet`.
    backbone_config: Backbone config. Its `type` must contain 'assemblenet';
      `get()` must return a config exposing `model_id`, `blocks`,
      `num_frames` and `combine_method`.
    norm_activation_config: Supplies `use_sync_bn`, `norm_momentum` and
      `norm_epsilon` for both the block function and the backbone.
    l2_regularizer: Unused; accepted for signature compatibility and
      discarded immediately.

  Returns:
    The constructed `AssembleNet` backbone.

  Raises:
    AssertionError: If `backbone_config.type` does not contain 'assemblenet'.
    ValueError: If `int(model_id)` is not a key of `ASSEMBLENET_SPECS`.
  """
  del l2_regularizer

  backbone_type = backbone_config.type
  backbone_cfg = backbone_config.get()
  assert 'assemblenet' in backbone_type, (
      f'Expected an assemblenet backbone type, got: {backbone_type!r}')

  assemblenet_depth = int(backbone_cfg.model_id)
  if assemblenet_depth not in ASSEMBLENET_SPECS:
    # A single formatted message: passing the value as a second positional
    # argument would make the exception render as a tuple repr.
    raise ValueError(
        f'Not a valid assemblenet_depth: {assemblenet_depth}')
  model_structure, model_edge_weights = cfg.blocks_to_flat_lists(
      backbone_cfg.blocks)
  params = ASSEMBLENET_SPECS[assemblenet_depth]

  # The same batch-norm settings go to the block function and the backbone.
  norm_kwargs = dict(
      use_sync_bn=norm_activation_config.use_sync_bn,
      bn_decay=norm_activation_config.norm_momentum,
      bn_epsilon=norm_activation_config.norm_epsilon)
  block_fn = functools.partial(params['block'], **norm_kwargs)
  backbone = AssembleNet(
      block_fn=block_fn,
      num_blocks=params['num_blocks'],
      num_frames=backbone_cfg.num_frames,
      model_structure=model_structure,
      input_specs=input_specs,
      model_edge_weights=model_edge_weights,
      combine_method=backbone_cfg.combine_method,
      **norm_kwargs)
  # Parameter count is reported in millions.
  logging.info('Number of parameters in AssembleNet backbone: %f M.',
               backbone.count_params() / 10.**6)
  return backbone