def test_configs_conversion(self):
  """Checks flat lists -> blocks -> flat lists reproduces the inputs."""
  expected_structure = assemblenet.asn50_structure
  expected_weights = assemblenet.asn_structure_weights

  # Convert to the block representation and straight back again.
  round_tripped = assemblenet.blocks_to_flat_lists(
      assemblenet.flat_lists_to_blocks(expected_structure, expected_weights))
  re_structure, re_weights = round_tripped

  self.assertAllEqual(
      re_structure, expected_structure, msg='asn50_structure')
  self.assertAllEqual(
      re_weights, expected_weights, msg='asn_structure_weights')
def build_assemblenet_model(
    input_specs: tf.keras.layers.InputSpec,
    model_config: cfg.AssembleNetModel,
    num_classes: int,
    l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None):
  """Builds assemblenet model.

  Args:
    input_specs: Input specification; handed to the backbone builder and
      exposed to the model under the 'image' key.
    model_config: Model config providing `backbone`, `norm_activation` and
      `max_pool_preditions`.
    num_classes: Forwarded to `AssembleNetModel` as `num_classes`.
    l2_regularizer: Optional regularizer forwarded to the backbone builder.

  Returns:
    The constructed `AssembleNetModel`.
  """
  input_specs_dict = {'image': input_specs}

  # build_assemblenet_v1 expects the backbone config and the norm-activation
  # config as two separate arguments, followed by the regularizer. Passing the
  # whole model config would shift the regularizer into the norm-activation
  # slot.
  backbone = build_assemblenet_v1(input_specs, model_config.backbone,
                                  model_config.norm_activation, l2_regularizer)

  backbone_cfg = model_config.backbone.get()
  # Only the structure is needed here; the edge weights are discarded.
  model_structure, _ = cfg.blocks_to_flat_lists(backbone_cfg.blocks)

  # NOTE(review): 'max_pool_preditions' (sic) is kept exactly as spelled; it
  # has to match the config field and the constructor keyword, neither of
  # which is visible from this function.
  return AssembleNetModel(
      backbone,
      num_classes=num_classes,
      num_frames=backbone_cfg.num_frames,
      model_structure=model_structure,
      input_specs=input_specs_dict,
      max_pool_preditions=model_config.max_pool_preditions)
def build_assemblenet_plus(
    input_specs: tf.keras.layers.InputSpec,
    model_config: cfg.Backbone3D,
    l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None
) -> tf.keras.Model:
  """Builds assemblenet++ backbone.

  Args:
    input_specs: Input specification forwarded to `AssembleNetPlus`.
    model_config: Config read through `model_config.backbone` (its `type` must
      be 'assemblenet++' and its `get()` supplies the backbone settings) and
      `model_config.norm_activation`.
    l2_regularizer: Accepted for signature compatibility; deleted unused.

  Returns:
    The constructed `AssembleNetPlus` backbone.

  Raises:
    AssertionError: If the backbone type is not 'assemblenet++'.
    ValueError: If `int(model_id)` is not a key of `ASSEMBLENET_SPECS`.
  """
  del l2_regularizer  # Unused.

  backbone_type = model_config.backbone.type
  backbone_cfg = model_config.backbone.get()
  norm_activation_config = model_config.norm_activation
  assert backbone_type == 'assemblenet++', (
      'Expected backbone type assemblenet++, got: {}'.format(backbone_type))

  assemblenet_depth = int(backbone_cfg.model_id)
  if assemblenet_depth not in ASSEMBLENET_SPECS:
    # Format into a single string; passing two positional args to ValueError
    # makes the message render as a tuple.
    raise ValueError(
        'Not a valid assemblenet_depth: {}'.format(assemblenet_depth))

  model_structure, model_edge_weights = cfg.blocks_to_flat_lists(
      backbone_cfg.blocks)
  params = ASSEMBLENET_SPECS[assemblenet_depth]

  # The same batch-norm settings go to both the block function and the
  # backbone constructor, so collect them once.
  bn_kwargs = dict(
      use_sync_bn=norm_activation_config.use_sync_bn,
      bn_decay=norm_activation_config.norm_momentum,
      bn_epsilon=norm_activation_config.norm_epsilon)
  block_fn = functools.partial(params['block'], **bn_kwargs)

  backbone = AssembleNetPlus(
      block_fn=block_fn,
      num_blocks=params['num_blocks'],
      num_frames=backbone_cfg.num_frames,
      model_structure=model_structure,
      input_specs=input_specs,
      model_edge_weights=model_edge_weights,
      use_object_input=backbone_cfg.use_object_input,
      attention_mode=backbone_cfg.attention_mode,
      **bn_kwargs)
  logging.info('Number of parameters in AssembleNet++ backbone: %f M.',
               backbone.count_params() / 10.**6)
  return backbone
def build_assemblenet_plus_model(
    input_specs: tf.keras.layers.InputSpec,
    model_config: cfg.AssembleNetPlusModel,
    num_classes: int,
    l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None):
  """Builds assemblenet++ model.

  Args:
    input_specs: Input specification; handed to the backbone builder and
      exposed to the model under the 'image' key.
    model_config: Model config providing `backbone` and
      `max_pool_predictions`.
    num_classes: Forwarded to `AssembleNetPlusModel` as `num_classes`.
    l2_regularizer: Optional regularizer forwarded to the backbone builder.

  Returns:
    The constructed `AssembleNetPlusModel`.
  """
  input_specs_dict = {'image': input_specs}

  # build_assemblenet_plus takes (input_specs, model_config, l2_regularizer)
  # and reads `.backbone` / `.norm_activation` from the config itself, so the
  # whole model config is passed rather than its two sub-configs.
  backbone = build_assemblenet_plus(input_specs, model_config, l2_regularizer)

  backbone_cfg = model_config.backbone.get()
  # Only the structure is needed here; the edge weights are discarded.
  model_structure, _ = cfg.blocks_to_flat_lists(backbone_cfg.blocks)

  return AssembleNetPlusModel(
      backbone,
      num_classes=num_classes,
      num_frames=backbone_cfg.num_frames,
      model_structure=model_structure,
      input_specs=input_specs_dict,
      max_pool_predictions=model_config.max_pool_predictions,
      use_object_input=backbone_cfg.use_object_input)
def build_assemblenet_v1(
    input_specs: tf.keras.layers.InputSpec,
    backbone_config: hyperparams.Config,
    norm_activation_config: hyperparams.Config,
    l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None
) -> tf.keras.Model:
  """Builds assemblenet backbone.

  Args:
    input_specs: Input specification forwarded to `AssembleNet`.
    backbone_config: Backbone config; its `type` must be 'assemblenet' and its
      `get()` supplies the backbone settings.
    norm_activation_config: Supplies `use_sync_bn`, `norm_momentum` and
      `norm_epsilon`.
    l2_regularizer: Accepted for signature compatibility; deleted unused.

  Returns:
    The constructed `AssembleNet` backbone.

  Raises:
    AssertionError: If the backbone type is not 'assemblenet'.
    ValueError: If `int(model_id)` is not a key of `ASSEMBLENET_SPECS`.
  """
  del l2_regularizer  # Unused.

  backbone_type = backbone_config.type
  backbone_cfg = backbone_config.get()
  assert backbone_type == 'assemblenet', (
      'Expected backbone type assemblenet, got: {}'.format(backbone_type))

  assemblenet_depth = int(backbone_cfg.model_id)
  if assemblenet_depth not in ASSEMBLENET_SPECS:
    # Format into a single string; passing two positional args to ValueError
    # makes the message render as a tuple.
    raise ValueError(
        'Not a valid assemblenet_depth: {}'.format(assemblenet_depth))

  model_structure, model_edge_weights = cfg.blocks_to_flat_lists(
      backbone_cfg.blocks)
  params = ASSEMBLENET_SPECS[assemblenet_depth]

  # The same batch-norm settings go to both the block function and the
  # backbone constructor, so collect them once.
  bn_kwargs = dict(
      use_sync_bn=norm_activation_config.use_sync_bn,
      bn_decay=norm_activation_config.norm_momentum,
      bn_epsilon=norm_activation_config.norm_epsilon)
  block_fn = functools.partial(params['block'], **bn_kwargs)

  backbone = AssembleNet(
      block_fn=block_fn,
      num_blocks=params['num_blocks'],
      num_frames=backbone_cfg.num_frames,
      model_structure=model_structure,
      input_specs=input_specs,
      model_edge_weights=model_edge_weights,
      combine_method=backbone_cfg.combine_method,
      **bn_kwargs)
  logging.info('Number of parameters in AssembleNet backbone: %f M.',
               backbone.count_params() / 10.**6)
  return backbone