def build_convline(text_proto):
    """Parses a text-format ConvLine proto and builds a convline from it.

    The convline is always built with is_training=True. hyperparams_builder.build
    is passed as the arg-scope function and dynamic_hyperparams_builder.build as
    the dynamic arg-scope function.

    Args:
      text_proto: text-format serialization of a convline_pb2.ConvLine message.

    Returns:
      Whatever convline_builder.build returns for the parsed config.
    """
    parsed = convline_pb2.ConvLine()
    text_format.Merge(text_proto, parsed)
    return convline_builder.build(
        hyperparams_builder.build,
        dynamic_hyperparams_builder.build,
        parsed,
        True)  # is_training
Code example #2
0
def build(argscope_fn, negative_attention_config, is_training):
    """Builds a NegativeAttention module from its proto config.

    Args:
      argscope_fn: arg-scope builder forwarded to convline_builder.build.
      negative_attention_config: proto with an optional `convline` field plus
        `concat_type`, `similarity_type`, `use_gt_labels` and `add_loss`.
      is_training: forwarded to convline_builder.build.

    Returns:
      A NegativeAttention instance. Its enum-valued options are passed as
      their symbolic names rather than their integer values.
    """
    cfg = negative_attention_config

    # The convline is optional; leave it as None when the field is unset.
    if cfg.HasField('convline'):
        convline = convline_builder.build(argscope_fn, None, cfg.convline,
                                          is_training)
    else:
        convline = None

    concat_name = cfg.ConcatType.Name(cfg.concat_type)
    similarity_name = cfg.SimilarityType.Name(cfg.similarity_type)
    return NegativeAttention(convline=convline,
                             concat_type=concat_name,
                             similarity_type=similarity_name,
                             use_gt_labels=cfg.use_gt_labels,
                             add_loss=cfg.add_loss)
def build_cross_similarity(argscope_fn, cross_similarity_config, attention_tree, k, is_training):
  """Builds a cross-similarity module from its proto config.

  Dispatches on the `cross_similarity_oneof` field of
  `cross_similarity_config`. The composite variants (pairwise, k1, double)
  recursively build their nested cross-similarity configs with the same
  `argscope_fn`, `attention_tree`, `k` and `is_training`.

  Args:
    argscope_fn: callable taking (hyperparams proto, is_training) and returning
      an arg scope. It is used for fc hyperparameters and forwarded to the
      convline and negative-attention builders.
    cross_similarity_config: proto whose `cross_similarity_oneof` selects the
      variant to build.
    attention_tree: forwarded to PairwiseCrossSimilarity and to nested builds.
    k: forwarded to PairwiseCrossSimilarity, K1CrossSimilarity and nested
      builds.
    is_training: whether the model is being built for training.

  Returns:
    An instance of one of the `cross_similarity.*CrossSimilarity` classes.

  Raises:
    ValueError: if the oneof is unset (reported as `None`) or names a variant
      not handled here.
  """
  cross_similarity_oneof = cross_similarity_config.WhichOneof('cross_similarity_oneof')

  def _build_nested(nested_config):
    # Composite variants build their children with the same arguments.
    return build_cross_similarity(argscope_fn, nested_config, attention_tree,
                                  k, is_training)

  if cross_similarity_oneof == 'cosine_cross_similarity':
    return cross_similarity.CosineCrossSimilarity()
  elif cross_similarity_oneof == 'linear_cross_similarity':
    linear_config = cross_similarity_config.linear_cross_similarity
    fc_hyperparameters = argscope_fn(linear_config.fc_hyperparameters,
                                     is_training)
    return cross_similarity.LinearCrossSimilarity(fc_hyperparameters)
  elif cross_similarity_oneof == 'deep_cross_similarity':
    deep_config = cross_similarity_config.deep_cross_similarity
    fc_hyperparameters = argscope_fn(deep_config.fc_hyperparameters,
                                     is_training)
    convline = None
    if deep_config.HasField('convline'):
      convline = convline_builder.build(argscope_fn, None,
                                        deep_config.convline, is_training)
    negative_attention = None
    if deep_config.HasField('negative_attention'):
      negative_attention = negative_attention_builder.build(
          argscope_fn, deep_config.negative_attention, is_training)
    return cross_similarity.DeepCrossSimilarity(
        deep_config.stop_gradient,
        fc_hyperparameters, convline,
        negative_attention,
        sum_output=deep_config.sum_output)
  elif cross_similarity_oneof == 'average_cross_similarity':
    return cross_similarity.AverageCrossSimilarity()
  elif cross_similarity_oneof == 'euclidean_cross_similarity':
    return cross_similarity.EuclideanCrossSimilarity()
  elif cross_similarity_oneof == 'pairwise_cross_similarity':
    pairwise_config = cross_similarity_config.pairwise_cross_similarity
    base_cross_similarity = _build_nested(pairwise_config.cross_similarity)
    return cross_similarity.PairwiseCrossSimilarity(
        pairwise_config.stop_gradient,
        base_cross_similarity, k,
        attention_tree=attention_tree)
  elif cross_similarity_oneof == 'k1_cross_similarity':
    k1_config = cross_similarity_config.k1_cross_similarity
    base_cross_similarity = _build_nested(k1_config.cross_similarity)
    return cross_similarity.K1CrossSimilarity(
        base_cross_similarity, k,
        k1_config.share_weights_with_pairwise_cs,
        k1_config.mode,
        k1_config.topk)
  elif cross_similarity_oneof == 'double_cross_similarity':
    double_config = cross_similarity_config.double_cross_similarity
    main_cs = _build_nested(double_config.main)
    # Without an explicit `transfered` config, the transfered branch is built
    # from the same config as the main branch.
    if double_config.HasField('transfered'):
      transfered_config = double_config.transfered
    else:
      transfered_config = double_config.main
    transfered_cs = _build_nested(transfered_config)

    # PairwiseCrossSimilarity overrides the scope. Give the two branches
    # distinct scope keys so main and transfered cs do not share variables.
    if isinstance(main_cs, cross_similarity.PairwiseCrossSimilarity):
      main_cs._k2_scope_key = 'main_pairwise_cross_similarity'
    if isinstance(transfered_cs, cross_similarity.PairwiseCrossSimilarity):
      transfered_cs._k2_scope_key = 'transfered_pairwise_cross_similarity'

    return cross_similarity.DoubleCrossSimilarity(main_cs, transfered_cs,
                                                  double_config.main_weight,
                                                  double_config.fea_split_ind)

  raise ValueError('Unknown cross_similarity: {}'.format(cross_similarity_oneof))
  def _fn(k, tree, parall_iterations):
    """Builds one attention unit for the given `k` and `tree`.

    NOTE(review): this nested function reads `unit_config`, `num_classes`,
    `calibration_type`, `is_calibration`, `argscope_fn` and `is_training` from
    an enclosing scope. Here it sits directly after the final `raise` of
    `build_cross_similarity`, so it is unreachable. Only `argscope_fn` and
    `is_training` (and `attention_tree`, which is that function's parameter,
    not the module) resolve in that scope. It presumably belongs to a missing
    `build_attention_unit(...)` factory whose `def` line was lost. TODO
    confirm and restore the enclosing function.

    Args:
      k: forwarded to the cross-similarity and loss builders and to
        AttentionUnit.
      tree: forwarded to cross_similarity_builder.build.
      parall_iterations: forwarded to AttentionUnit.

    Returns:
      The result of attention_tree.AttentionUnit(...).
    """
    def _optional_convline(field_name):
      # Build the named convline sub-config only when the field is set.
      if unit_config.HasField(field_name):
        return convline_builder.build(argscope_fn, None,
                                      getattr(unit_config, field_name),
                                      is_training)
      return None

    pre_convline = _optional_convline('pre_convline')
    post_convline = _optional_convline('post_convline')
    negative_convline = _optional_convline('negative_convline')

    # Computed once. The original recomputed this later with a second
    # argscope_fn call whose result was never used.
    res_fc_hyperparams = None
    if unit_config.HasField('res_fc_hyperparams'):
      res_fc_hyperparams = argscope_fn(unit_config.res_fc_hyperparams,
                                       is_training)

    post_convline = _post_convline_builder(argscope_fn,
                                           post_convline,
                                           unit_config.use_tanh_sigmoid_in_post_convline,
                                           unit_config.post_convline_res,
                                           res_fc_hyperparams,
                                           unit_config.split_fea_in_res,
                                           is_training)

    cross_similarity = cross_similarity_builder.build(argscope_fn,
                                                      unit_config.cross_similarity,
                                                      tree,
                                                      k, is_training)
    loss = attention_loss_builder.build(unit_config.loss, num_classes, k)

    # In training, the subsampler config sets the proposal count and may cap
    # the maximum via `topk`. Otherwise both come from unit_config.ncobj_proposals.
    max_ncobj_proposals = unit_config.ncobj_proposals
    if is_training:
      ncobj_proposals = unit_config.training_subsampler.ncobj_proposals
      if unit_config.training_subsampler.HasField('topk'):
        max_ncobj_proposals = unit_config.training_subsampler.topk
    else:
      ncobj_proposals = max_ncobj_proposals

    positive_balance_fraction = None
    if unit_config.training_subsampler.HasField('positive_balance_fraction'):
      positive_balance_fraction = unit_config.training_subsampler.positive_balance_fraction

    return attention_tree.AttentionUnit(ncobj_proposals, max_ncobj_proposals,
                                        positive_balance_fraction,
                                        k, pre_convline, post_convline,
                                        cross_similarity,
                                        loss,
                                        is_training,
                                        unit_config.orig_fea_in_post_convline,
                                        unit_config.training_subsampler.sample_hard_examples,
                                        unit_config.training_subsampler.stratified,
                                        unit_config.loss.positive_balance_fraction,
                                        unit_config.loss.minibatch_size,
                                        parall_iterations,
                                        unit_config.loss.weight,
                                        unit_config.negative_example_weight,
                                        unit_config.compute_scores_after_matching,
                                        unit_config.overwrite_fea_by_scores,
                                        negative_convline,
                                        is_calibration,
                                        calibration_type,
                                        unit_config.unary_energy_scale,
                                        unit_config.transfered_objectness_weight)
def build(argscope_fn, attention_tree_config,
          k_shot, num_classes, num_negative_bags, is_training, is_calibration):
  """Builds an attention_tree.AttentionTree from its proto configuration.

  Args:
    argscope_fn: A function that takes the following inputs:
        * hyperparams_pb2.Hyperparams proto
        * a boolean indicating if the model is in training mode.
      and returns a tf slim argscope for Conv and FC hyperparameters.
    attention_tree_config: attention_tree_pb2.AttentionTree proto.
    k_shot: forwarded to build_attention_unit for every unit and to
      AttentionTree (presumably the number of support examples; confirm).
    num_classes: forwarded to build_attention_unit and to AttentionTree.
    num_negative_bags: forwarded to AttentionTree.
    is_training: Whether the model is in training mode.
    is_calibration: forwarded to build_attention_unit for every unit.

  Returns:
    attention_tree: attention.attention_tree.AttentionTree object.

  Raises:
    ValueError: If attention_tree_config is not an
      attention_tree_pb2.AttentionTree proto.
  """
  if not isinstance(attention_tree_config, attention_tree_pb2.AttentionTree):
    raise ValueError('attention_tree_config not of type '
                     'attention_tree_pb2.AttentionTree.')

  def _optional_convline(field_name):
    # Builds the named convline sub-config, or returns None when it is unset.
    if not attention_tree_config.HasField(field_name):
      return None
    return convline_builder.build(argscope_fn,
                                  None,
                                  getattr(attention_tree_config, field_name),
                                  is_training)

  fea_split_ind = None
  if attention_tree_config.HasField('fea_split_ind'):
    fea_split_ind = attention_tree_config.fea_split_ind

  preprocess_convline = _optional_convline('preprocess_convline')

  # Rescoring sub-modules are only built when rescore_instances is enabled.
  rescore_convline = None
  rescore_fc_hyperparams = None
  if attention_tree_config.rescore_instances:
    rescore_convline = _optional_convline('rescore_convline')
    if attention_tree_config.HasField('rescore_fc_hyperparams'):
      rescore_fc_hyperparams = argscope_fn(
          attention_tree_config.rescore_fc_hyperparams,
          is_training)

  negative_preprocess_convline = _optional_convline(
      'negative_preprocess_convline')
  negative_postprocess_convline = _optional_convline(
      'negative_postprocess_convline')

  # The calibration type is passed to each unit by its enum name.
  calibration_type = attention_tree_config.CalibrationType.Name(
      attention_tree_config.calibration_type)
  units = [build_attention_unit(argscope_fn,
                                unit_config,
                                num_classes,
                                k_shot,
                                is_training,
                                calibration_type,
                                is_calibration)
           for unit_config in attention_tree_config.unit]

  # Tree-level subsampler settings; unset fields keep their None/False defaults.
  subsampler_ncobj, subsampler_pos_frac = None, None
  subsampler_agnostic = False
  if attention_tree_config.HasField('training_subsampler'):
    subsampler = attention_tree_config.training_subsampler
    if subsampler.HasField('positive_balance_fraction'):
      subsampler_pos_frac = subsampler.positive_balance_fraction
    if subsampler.HasField('ncobj_proposals'):
      subsampler_ncobj = subsampler.ncobj_proposals
    if subsampler.HasField('agnostic'):
      subsampler_agnostic = subsampler.agnostic

  return attention_tree.AttentionTree(units, k_shot,
                                      num_negative_bags, is_training,
                                      attention_tree_config.stop_features_gradient,
                                      preprocess_convline,
                                      num_classes,
                                      attention_tree_config.rescore_instances,
                                      attention_tree_config.rescore_min_match_frac,
                                      rescore_convline,
                                      rescore_fc_hyperparams,
                                      negative_preprocess_convline,
                                      negative_postprocess_convline,
                                      subsampler_ncobj,
                                      subsampler_pos_frac,
                                      subsampler_agnostic,
                                      fea_split_ind)
Code example #6
0
def _build_rcnn_attention_model(rcnna_config, is_training, is_calibration):
    """Builds an R-CNN attention model (or a plain Faster R-CNN) from its config.

    Args:
      rcnna_config: A rcnn_attention.proto object containing the config for the
        desired RCNNAttention model.
      is_training: True if this model is being built for training purposes.
      is_calibration: forwarded to attention_tree_builder.build when the config
        has an `attention_tree` field; otherwise unused.

    Returns:
      A faster_rcnn_meta_arch.FasterRCNNMetaArch when
      `rcnna_config.build_faster_rcnn_arch` is set, otherwise a
      rcnn_attention_meta_arch.RCNNAttentionMetaArch.

    Raises:
      ValueError: If the second-stage box predictor is a
        box_predictor.RfcnBoxPredictor, which is not supported.
    """
    num_classes = rcnna_config.num_classes
    k_shot = rcnna_config.k_shot
    image_resizer_fn = image_resizer_builder.build(rcnna_config.image_resizer)

    feature_extractor = _build_faster_rcnn_feature_extractor(
        rcnna_config.feature_extractor, is_training)

    # First stage (region proposal network) settings.
    first_stage_only = rcnna_config.first_stage_only
    first_stage_anchor_generator = anchor_generator_builder.build(
        rcnna_config.first_stage_anchor_generator)

    first_stage_atrous_rate = rcnna_config.first_stage_atrous_rate
    first_stage_box_predictor_arg_scope = hyperparams_builder.build(
        rcnna_config.first_stage_box_predictor_conv_hyperparams, is_training)
    first_stage_box_predictor_kernel_size = (
        rcnna_config.first_stage_box_predictor_kernel_size)
    first_stage_box_predictor_depth = rcnna_config.first_stage_box_predictor_depth
    first_stage_minibatch_size = rcnna_config.first_stage_minibatch_size
    first_stage_positive_balance_fraction = (
        rcnna_config.first_stage_positive_balance_fraction)
    first_stage_nms_score_threshold = rcnna_config.first_stage_nms_score_threshold
    first_stage_nms_iou_threshold = rcnna_config.first_stage_nms_iou_threshold
    first_stage_max_proposals = rcnna_config.first_stage_max_proposals
    first_stage_loc_loss_weight = (
        rcnna_config.first_stage_localization_loss_weight)
    first_stage_obj_loss_weight = rcnna_config.first_stage_objectness_loss_weight

    initial_crop_size = rcnna_config.initial_crop_size
    maxpool_kernel_size = rcnna_config.maxpool_kernel_size
    maxpool_stride = rcnna_config.maxpool_stride

    # Second stage settings.
    second_stage_box_predictor = build_box_predictor(
        hyperparams_builder.build,
        rcnna_config.second_stage_box_predictor,
        is_training=is_training,
        num_classes=num_classes)
    second_stage_batch_size = rcnna_config.second_stage_batch_size
    second_stage_balance_fraction = rcnna_config.second_stage_balance_fraction
    (second_stage_non_max_suppression_fn,
     second_stage_score_conversion_fn) = post_processing_builder.build(
         rcnna_config.second_stage_post_processing)
    second_stage_localization_loss_weight = (
        rcnna_config.second_stage_localization_loss_weight)
    second_stage_classification_loss_weight = (
        rcnna_config.second_stage_classification_loss_weight)

    hard_example_miner = None
    if rcnna_config.HasField('hard_example_miner'):
        hard_example_miner = losses_builder.build_hard_example_miner(
            rcnna_config.hard_example_miner,
            second_stage_classification_loss_weight,
            second_stage_localization_loss_weight)

    # Attention-specific components (only consumed by RCNNAttentionMetaArch).
    attention_tree = None
    if rcnna_config.HasField('attention_tree'):
        attention_tree = attention_tree_builder.build(
            hyperparams_builder.build, rcnna_config.attention_tree,
            k_shot, num_classes, rcnna_config.num_negative_bags,
            is_training, is_calibration)

    second_stage_convline = None
    if rcnna_config.HasField('second_stage_convline'):
        second_stage_convline = convline_builder.build(
            hyperparams_builder.build, None,
            rcnna_config.second_stage_convline, is_training)

    common_kwargs = {
        'is_training': is_training,
        'image_resizer_fn': image_resizer_fn,
        'feature_extractor': feature_extractor,
        'first_stage_only': first_stage_only,
        'first_stage_anchor_generator': first_stage_anchor_generator,
        'first_stage_atrous_rate': first_stage_atrous_rate,
        'first_stage_box_predictor_arg_scope':
        first_stage_box_predictor_arg_scope,
        'first_stage_box_predictor_kernel_size':
        first_stage_box_predictor_kernel_size,
        'first_stage_box_predictor_depth': first_stage_box_predictor_depth,
        'first_stage_minibatch_size': first_stage_minibatch_size,
        'first_stage_positive_balance_fraction':
        first_stage_positive_balance_fraction,
        'first_stage_nms_score_threshold': first_stage_nms_score_threshold,
        'first_stage_nms_iou_threshold': first_stage_nms_iou_threshold,
        'first_stage_max_proposals': first_stage_max_proposals,
        'first_stage_localization_loss_weight': first_stage_loc_loss_weight,
        'first_stage_objectness_loss_weight': first_stage_obj_loss_weight,
        'second_stage_batch_size': second_stage_batch_size,
        'second_stage_balance_fraction': second_stage_balance_fraction,
        'second_stage_non_max_suppression_fn':
        second_stage_non_max_suppression_fn,
        'second_stage_score_conversion_fn': second_stage_score_conversion_fn,
        'second_stage_localization_loss_weight':
        second_stage_localization_loss_weight,
        'second_stage_classification_loss_weight':
        second_stage_classification_loss_weight,
        'hard_example_miner': hard_example_miner,
        'initial_crop_size': initial_crop_size,
        'maxpool_kernel_size': maxpool_kernel_size,
        'maxpool_stride': maxpool_stride,
        'second_stage_mask_rcnn_box_predictor': second_stage_box_predictor,
        'num_classes': num_classes
    }

    if isinstance(second_stage_box_predictor, box_predictor.RfcnBoxPredictor):
        raise ValueError('RFCNBoxPredictor is not supported.')
    elif rcnna_config.build_faster_rcnn_arch:
        # Plain Faster R-CNN: the attention tree (if built) is not used.
        # The two attributes below are attached presumably so callers can treat
        # this model like RCNNAttentionMetaArch -- TODO confirm.
        model = faster_rcnn_meta_arch.FasterRCNNMetaArch(**common_kwargs)
        model._k_shot = k_shot
        model._tree_debug_tensors = lambda: {}
        return model
    else:
        return rcnn_attention_meta_arch.RCNNAttentionMetaArch(
            k_shot=k_shot,
            attention_tree=attention_tree,
            second_stage_convline=second_stage_convline,
            attention_tree_only=rcnna_config.attention_tree_only,
            add_gt_boxes_to_rpn=rcnna_config.add_gt_boxes_to_rpn,
            **common_kwargs)