def build_convline(text_proto):
  """Parses a ConvLine text proto and builds the convline in training mode.

  Args:
    text_proto: Text-format serialization of a convline_pb2.ConvLine message.

  Returns:
    Whatever convline_builder.build returns for the parsed config, built with
    hyperparams_builder.build / dynamic_hyperparams_builder.build as the
    (dynamic) argscope functions and is_training=True.
  """
  config = convline_pb2.ConvLine()
  text_format.Merge(text_proto, config)
  return convline_builder.build(hyperparams_builder.build,
                                dynamic_hyperparams_builder.build,
                                config, True)
def build(argscope_fn, negative_attention_config, is_training):
  """Builds a NegativeAttention module from its configuration proto.

  Args:
    argscope_fn: Function mapping a hyperparams proto and a training flag to
      an argscope.
    negative_attention_config: Proto with an optional `convline` field plus
      concat/similarity type enums and boolean options.
    is_training: Whether the model is being built for training.

  Returns:
    A NegativeAttention instance.
  """
  if negative_attention_config.HasField('convline'):
    convline = convline_builder.build(
        argscope_fn, None, negative_attention_config.convline, is_training)
  else:
    convline = None
  # Resolve enum values to their symbolic names before construction.
  concat_name = negative_attention_config.ConcatType.Name(
      negative_attention_config.concat_type)
  similarity_name = negative_attention_config.SimilarityType.Name(
      negative_attention_config.similarity_type)
  return NegativeAttention(
      convline=convline,
      concat_type=concat_name,
      similarity_type=similarity_name,
      use_gt_labels=negative_attention_config.use_gt_labels,
      add_loss=negative_attention_config.add_loss)
def build_cross_similarity(argscope_fn, cross_similarity_config,
                           attention_tree, k, is_training):
  """Builds a cross-similarity module based on the configuration.

  Dispatches on the `cross_similarity_oneof` field of the config and
  constructs the corresponding cross_similarity.* class. Composite variants
  (pairwise, k1, double) recurse to build their base cross-similarities.

  Args:
    argscope_fn: Function mapping a hyperparams proto and a training flag to
      an argscope (used for FC hyperparameters).
    cross_similarity_config: Proto with the `cross_similarity_oneof` field.
    attention_tree: Attention tree object, forwarded to pairwise variants.
    k: Current tree level / number of bags combined at this unit.
    is_training: Whether the model is being built for training.

  Returns:
    A cross_similarity.* instance matching the selected oneof.

  Raises:
    ValueError: If the oneof value is not recognized.
  """
  cross_similarity_oneof = cross_similarity_config.WhichOneof(
      'cross_similarity_oneof')
  if cross_similarity_oneof == 'cosine_cross_similarity':
    # Config carries no parameters for this variant.
    return cross_similarity.CosineCrossSimilarity()
  elif cross_similarity_oneof == 'linear_cross_similarity':
    linear_cross_similarity = cross_similarity_config.linear_cross_similarity
    fc_hyperparameters = argscope_fn(
        linear_cross_similarity.fc_hyperparameters, is_training)
    return cross_similarity.LinearCrossSimilarity(fc_hyperparameters)
  elif cross_similarity_oneof == 'deep_cross_similarity':
    deep_cross_similarity = cross_similarity_config.deep_cross_similarity
    fc_hyperparameters = argscope_fn(
        deep_cross_similarity.fc_hyperparameters, is_training)
    convline = None
    if deep_cross_similarity.HasField('convline'):
      convline = convline_builder.build(argscope_fn, None,
                                        deep_cross_similarity.convline,
                                        is_training)
    negative_attention = None
    if deep_cross_similarity.HasField('negative_attention'):
      negative_attention = negative_attention_builder.build(
          argscope_fn, deep_cross_similarity.negative_attention, is_training)
    return cross_similarity.DeepCrossSimilarity(
        deep_cross_similarity.stop_gradient, fc_hyperparameters, convline,
        negative_attention, sum_output=deep_cross_similarity.sum_output)
  elif cross_similarity_oneof == 'average_cross_similarity':
    # Config carries no parameters for this variant.
    return cross_similarity.AverageCrossSimilarity()
  elif cross_similarity_oneof == 'euclidean_cross_similarity':
    return cross_similarity.EuclideanCrossSimilarity()
  elif cross_similarity_oneof == 'pairwise_cross_similarity':
    pairwise_cross_similarity = (
        cross_similarity_config.pairwise_cross_similarity)
    base_cross_similarity = build_cross_similarity(
        argscope_fn, pairwise_cross_similarity.cross_similarity,
        attention_tree, k, is_training)
    return cross_similarity.PairwiseCrossSimilarity(
        pairwise_cross_similarity.stop_gradient, base_cross_similarity, k,
        attention_tree=attention_tree)
  elif cross_similarity_oneof == 'k1_cross_similarity':
    k1_cross_similarity = cross_similarity_config.k1_cross_similarity
    base_cross_similarity = build_cross_similarity(
        argscope_fn, k1_cross_similarity.cross_similarity, attention_tree, k,
        is_training)
    return cross_similarity.K1CrossSimilarity(
        base_cross_similarity, k,
        k1_cross_similarity.share_weights_with_pairwise_cs,
        k1_cross_similarity.mode, k1_cross_similarity.topk)
  elif cross_similarity_oneof == 'double_cross_similarity':
    double_cross_similarity = cross_similarity_config.double_cross_similarity
    main_cs = build_cross_similarity(argscope_fn,
                                     double_cross_similarity.main,
                                     attention_tree, k, is_training)
    # Fall back to the main config when no transfered config is given.
    if double_cross_similarity.HasField('transfered'):
      transfered_config = double_cross_similarity.transfered
    else:
      transfered_config = double_cross_similarity.main
    transfered_cs = build_cross_similarity(argscope_fn, transfered_config,
                                           attention_tree, k, is_training)
    ## PairwiseCrossSimilarity overrides the scope
    ## We need these to ensure main and transfered cs does not share variables
    if isinstance(main_cs, cross_similarity.PairwiseCrossSimilarity):
      main_cs._k2_scope_key = 'main_pairwise_cross_similarity'
    if isinstance(transfered_cs, cross_similarity.PairwiseCrossSimilarity):
      transfered_cs._k2_scope_key = 'transfered_pairwise_cross_similarity'
    return cross_similarity.DoubleCrossSimilarity(
        main_cs, transfered_cs, double_cross_similarity.main_weight,
        double_cross_similarity.fea_split_ind)
  raise ValueError('Unknown cross_similarity: {}'.format(
      cross_similarity_oneof))
def _fn(k, tree, parall_iterations):
  """Builds one AttentionUnit for level k of the attention tree.

  Closure over `unit_config`, `argscope_fn`, `num_classes`, `is_training`,
  `is_calibration`, and `calibration_type` from the enclosing builder.

  Args:
    k: Tree level / number of bags combined by this unit.
    tree: The attention tree object the unit belongs to.
    parall_iterations: Parallel iterations forwarded to the unit.

  Returns:
    An attention_tree.AttentionUnit configured from `unit_config`.
  """
  pre_convline, post_convline, negative_convline = None, None, None
  if unit_config.HasField('pre_convline'):
    pre_convline = convline_builder.build(argscope_fn, None,
                                          unit_config.pre_convline,
                                          is_training)
  if unit_config.HasField('post_convline'):
    post_convline = convline_builder.build(argscope_fn, None,
                                           unit_config.post_convline,
                                           is_training)
  if unit_config.HasField('negative_convline'):
    negative_convline = convline_builder.build(argscope_fn, None,
                                               unit_config.negative_convline,
                                               is_training)
  res_fc_hyperparams = None
  if unit_config.HasField('res_fc_hyperparams'):
    res_fc_hyperparams = argscope_fn(unit_config.res_fc_hyperparams,
                                     is_training)
  # Wrap the raw post convline with tanh/sigmoid gating and residual options.
  post_convline = _post_convline_builder(
      argscope_fn, post_convline,
      unit_config.use_tanh_sigmoid_in_post_convline,
      unit_config.post_convline_res, res_fc_hyperparams,
      unit_config.split_fea_in_res, is_training)
  cross_similarity = cross_similarity_builder.build(
      argscope_fn, unit_config.cross_similarity, tree, k, is_training)
  loss = attention_loss_builder.build(unit_config.loss, num_classes, k)
  max_ncobj_proposals = unit_config.ncobj_proposals
  if is_training:
    # During training, the subsampler controls how many proposals survive;
    # `topk`, when present, caps the maximum.
    ncobj_proposals = unit_config.training_subsampler.ncobj_proposals
    if unit_config.training_subsampler.HasField('topk'):
      max_ncobj_proposals = unit_config.training_subsampler.topk
  else:
    ncobj_proposals = max_ncobj_proposals
  # NOTE(review): the original code rebuilt res_fc_hyperparams a second time
  # here with identical inputs and never used the result; that dead
  # recomputation has been removed.
  positive_balance_fraction = None
  if unit_config.training_subsampler.HasField('positive_balance_fraction'):
    positive_balance_fraction = (
        unit_config.training_subsampler.positive_balance_fraction)
  return attention_tree.AttentionUnit(
      ncobj_proposals, max_ncobj_proposals, positive_balance_fraction, k,
      pre_convline, post_convline, cross_similarity, loss, is_training,
      unit_config.orig_fea_in_post_convline,
      unit_config.training_subsampler.sample_hard_examples,
      unit_config.training_subsampler.stratified,
      unit_config.loss.positive_balance_fraction,
      unit_config.loss.minibatch_size, parall_iterations,
      unit_config.loss.weight, unit_config.negative_example_weight,
      unit_config.compute_scores_after_matching,
      unit_config.overwrite_fea_by_scores, negative_convline,
      is_calibration, calibration_type, unit_config.unary_energy_scale,
      unit_config.transfered_objectness_weight)
def build(argscope_fn, attention_tree_config, k_shot, num_classes,
          num_negative_bags, is_training, is_calibration):
  """Builds attention_tree based on the configuration.

  Args:
    argscope_fn: A function that takes the following inputs:
        * hyperparams_pb2.Hyperparams proto
        * a boolean indicating if the model is in training mode.
      and returns a tf slim argscope for Conv and FC hyperparameters.
    attention_tree_config: attention_tree_pb2.AttentionTree proto.
    k_shot: Number of shots (bags combined by the tree).
    num_classes: Number of object classes.
    num_negative_bags: Number of negative bags fed to the tree.
    is_training: Whether the model is in training mode.
    is_calibration: Whether the model is being built for calibration.

  Returns:
    attention_tree: attention.attention_tree.AttentionTree object.

  Raises:
    ValueError: If attention_tree_config is not of type
      attention_tree_pb2.AttentionTree.
  """
  if not isinstance(attention_tree_config, attention_tree_pb2.AttentionTree):
    raise ValueError('attention_tree_config not of type '
                     'attention_tree_pb2.AttentionTree.')
  fea_split_ind = None
  if attention_tree_config.HasField('fea_split_ind'):
    fea_split_ind = attention_tree_config.fea_split_ind
  preprocess_convline = None
  if attention_tree_config.HasField('preprocess_convline'):
    preprocess_convline = convline_builder.build(
        argscope_fn, None, attention_tree_config.preprocess_convline,
        is_training)
  # Rescoring components are only built when instance rescoring is enabled.
  rescore_convline = None
  rescore_fc_hyperparams = None
  if attention_tree_config.rescore_instances:
    if attention_tree_config.HasField('rescore_convline'):
      rescore_convline = convline_builder.build(
          argscope_fn, None, attention_tree_config.rescore_convline,
          is_training)
    if attention_tree_config.HasField('rescore_fc_hyperparams'):
      rescore_fc_hyperparams = argscope_fn(
          attention_tree_config.rescore_fc_hyperparams, is_training)
  negative_preprocess_convline = None
  if attention_tree_config.HasField('negative_preprocess_convline'):
    negative_preprocess_convline = convline_builder.build(
        argscope_fn, None,
        attention_tree_config.negative_preprocess_convline, is_training)
  negative_postprocess_convline = None
  if attention_tree_config.HasField('negative_postprocess_convline'):
    negative_postprocess_convline = convline_builder.build(
        argscope_fn, None,
        attention_tree_config.negative_postprocess_convline, is_training)
  calibration_type = attention_tree_config.CalibrationType.Name(
      attention_tree_config.calibration_type)
  units = [
      build_attention_unit(argscope_fn, unit_config, num_classes, k_shot,
                           is_training, calibration_type, is_calibration)
      for unit_config in attention_tree_config.unit
  ]
  subsampler_ncobj, subsampler_pos_frac = None, None
  subsampler_agnostic = False
  if attention_tree_config.HasField('training_subsampler'):
    subsampler = attention_tree_config.training_subsampler
    if subsampler.HasField('positive_balance_fraction'):
      subsampler_pos_frac = subsampler.positive_balance_fraction
    if subsampler.HasField('ncobj_proposals'):
      subsampler_ncobj = subsampler.ncobj_proposals
    if subsampler.HasField('agnostic'):
      subsampler_agnostic = subsampler.agnostic
  return attention_tree.AttentionTree(
      units, k_shot, num_negative_bags, is_training,
      attention_tree_config.stop_features_gradient, preprocess_convline,
      num_classes, attention_tree_config.rescore_instances,
      attention_tree_config.rescore_min_match_frac, rescore_convline,
      rescore_fc_hyperparams, negative_preprocess_convline,
      negative_postprocess_convline, subsampler_ncobj, subsampler_pos_frac,
      subsampler_agnostic, fea_split_ind)
def _build_rcnn_attention_model(rcnna_config, is_training, is_calibration):
  """Builds a R-CNN attention model based on the model config.

  Args:
    rcnna_config: A rcnn_attention.proto object containing the config for the
      desired RCNNAttention model.
    is_training: True if this model is being built for training purposes.
    is_calibration: True if the model is being built for calibration
      (forwarded to the attention tree builder).

  Returns:
    RCNNAttentionMetaArch based on the config.

  Raises:
    ValueError: If rcnna_config.type is not recognized (i.e. not registered in
      model_class_map).
  """
  num_classes = rcnna_config.num_classes
  k_shot = rcnna_config.k_shot
  image_resizer_fn = image_resizer_builder.build(rcnna_config.image_resizer)
  feature_extractor = _build_faster_rcnn_feature_extractor(
      rcnna_config.feature_extractor, is_training)
  first_stage_only = rcnna_config.first_stage_only
  first_stage_anchor_generator = anchor_generator_builder.build(
      rcnna_config.first_stage_anchor_generator)
  first_stage_atrous_rate = rcnna_config.first_stage_atrous_rate
  first_stage_box_predictor_arg_scope = hyperparams_builder.build(
      rcnna_config.first_stage_box_predictor_conv_hyperparams, is_training)
  first_stage_box_predictor_kernel_size = (
      rcnna_config.first_stage_box_predictor_kernel_size)
  first_stage_box_predictor_depth = rcnna_config.first_stage_box_predictor_depth
  first_stage_minibatch_size = rcnna_config.first_stage_minibatch_size
  first_stage_positive_balance_fraction = (
      rcnna_config.first_stage_positive_balance_fraction)
  first_stage_nms_score_threshold = rcnna_config.first_stage_nms_score_threshold
  first_stage_nms_iou_threshold = rcnna_config.first_stage_nms_iou_threshold
  first_stage_max_proposals = rcnna_config.first_stage_max_proposals
  first_stage_loc_loss_weight = (
      rcnna_config.first_stage_localization_loss_weight)
  first_stage_obj_loss_weight = rcnna_config.first_stage_objectness_loss_weight
  initial_crop_size = rcnna_config.initial_crop_size
  maxpool_kernel_size = rcnna_config.maxpool_kernel_size
  maxpool_stride = rcnna_config.maxpool_stride
  second_stage_box_predictor = build_box_predictor(
      hyperparams_builder.build,
      rcnna_config.second_stage_box_predictor,
      is_training=is_training,
      num_classes=num_classes)
  second_stage_batch_size = rcnna_config.second_stage_batch_size
  second_stage_balance_fraction = rcnna_config.second_stage_balance_fraction
  (second_stage_non_max_suppression_fn,
   second_stage_score_conversion_fn) = post_processing_builder.build(
       rcnna_config.second_stage_post_processing)
  second_stage_localization_loss_weight = (
      rcnna_config.second_stage_localization_loss_weight)
  second_stage_classification_loss_weight = (
      rcnna_config.second_stage_classification_loss_weight)
  # Optional components: only built when the corresponding field is set.
  hard_example_miner = None
  if rcnna_config.HasField('hard_example_miner'):
    hard_example_miner = losses_builder.build_hard_example_miner(
        rcnna_config.hard_example_miner,
        second_stage_classification_loss_weight,
        second_stage_localization_loss_weight)
  attention_tree = None
  if rcnna_config.HasField('attention_tree'):
    attention_tree = attention_tree_builder.build(
        hyperparams_builder.build, rcnna_config.attention_tree,
        rcnna_config.k_shot, num_classes, rcnna_config.num_negative_bags,
        is_training, is_calibration)
  second_stage_convline = None
  if rcnna_config.HasField('second_stage_convline'):
    second_stage_convline = convline_builder.build(
        hyperparams_builder.build, None, rcnna_config.second_stage_convline,
        is_training)
  # Arguments shared by both the plain Faster R-CNN arch and the
  # RCNNAttention arch below.
  common_kwargs = {
      'is_training': is_training,
      'image_resizer_fn': image_resizer_fn,
      'feature_extractor': feature_extractor,
      'first_stage_only': first_stage_only,
      'first_stage_anchor_generator': first_stage_anchor_generator,
      'first_stage_atrous_rate': first_stage_atrous_rate,
      'first_stage_box_predictor_arg_scope':
          first_stage_box_predictor_arg_scope,
      'first_stage_box_predictor_kernel_size':
          first_stage_box_predictor_kernel_size,
      'first_stage_box_predictor_depth': first_stage_box_predictor_depth,
      'first_stage_minibatch_size': first_stage_minibatch_size,
      'first_stage_positive_balance_fraction':
          first_stage_positive_balance_fraction,
      'first_stage_nms_score_threshold': first_stage_nms_score_threshold,
      'first_stage_nms_iou_threshold': first_stage_nms_iou_threshold,
      'first_stage_max_proposals': first_stage_max_proposals,
      'first_stage_localization_loss_weight': first_stage_loc_loss_weight,
      'first_stage_objectness_loss_weight': first_stage_obj_loss_weight,
      'second_stage_batch_size': second_stage_batch_size,
      'second_stage_balance_fraction': second_stage_balance_fraction,
      'second_stage_non_max_suppression_fn':
          second_stage_non_max_suppression_fn,
      'second_stage_score_conversion_fn': second_stage_score_conversion_fn,
      'second_stage_localization_loss_weight':
          second_stage_localization_loss_weight,
      'second_stage_classification_loss_weight':
          second_stage_classification_loss_weight,
      'hard_example_miner': hard_example_miner,
      'initial_crop_size': initial_crop_size,
      'maxpool_kernel_size': maxpool_kernel_size,
      'maxpool_stride': maxpool_stride,
      'second_stage_mask_rcnn_box_predictor': second_stage_box_predictor,
      'num_classes': num_classes
  }
  if isinstance(second_stage_box_predictor, box_predictor.RfcnBoxPredictor):
    raise ValueError('RFCNBoxPredictor is not supported.')
  elif rcnna_config.build_faster_rcnn_arch:
    # Plain Faster R-CNN arch, patched with the attributes the surrounding
    # code expects on the attention arch (k_shot and empty debug tensors).
    model = faster_rcnn_meta_arch.FasterRCNNMetaArch(**common_kwargs)
    model._k_shot = k_shot
    model._tree_debug_tensors = lambda: {}
    return model
  else:
    return rcnn_attention_meta_arch.RCNNAttentionMetaArch(
        k_shot=k_shot,
        attention_tree=attention_tree,
        second_stage_convline=second_stage_convline,
        attention_tree_only=rcnna_config.attention_tree_only,
        add_gt_boxes_to_rpn=rcnna_config.add_gt_boxes_to_rpn,
        **common_kwargs)