def deepmac_proto_to_params(deepmac_config):
    """Build a `DeepMACParams` named tuple from a DeepMAC config proto.

    Args:
        deepmac_config: DeepMAC config proto. It must expose a
            `classification_loss` sub-message and every field named in
            `copied_fields` below.

    Returns:
        A `DeepMACParams` whose `classification_loss` is the loss object
        produced by `losses_builder.build`; all remaining fields are passed
        through unchanged from `deepmac_config`.
    """
    loss_proto = losses_pb2.Loss()
    # A placeholder localization loss is set only so that the loss builder
    # does not raise an error; the localization loss it builds is discarded.
    loss_proto.localization_loss.weighted_l2.CopyFrom(
        losses_pb2.WeightedL2LocalizationLoss())
    loss_proto.classification_loss.CopyFrom(deepmac_config.classification_loss)

    # The builder returns seven values; only the first (the classification
    # loss) is used here.
    built_losses = losses_builder.build(loss_proto)
    classification_loss, _, _, _, _, _, _ = built_losses

    # Fields whose `DeepMACParams` name matches the proto field name exactly.
    copied_fields = (
        "dim",
        "task_loss_weight",
        "pixel_embedding_dim",
        "allowed_masked_classes_ids",
        "mask_size",
        "mask_num_subsamples",
        "use_xy",
        "network_type",
        "use_instance_embedding",
        "num_init_channels",
        "predict_full_resolution_masks",
        "postprocess_crop_size",
    )
    params = {name: getattr(deepmac_config, name) for name in copied_fields}
    params["classification_loss"] = classification_loss
    return DeepMACParams(**params)
def mask_proto_to_params(mask_config):
    """Translate a CenterNet.MaskEstimation proto into `MaskParams`.

    Args:
        mask_config: CenterNet.MaskEstimation proto providing
            `classification_loss`, `task_loss_weight`, `mask_height`,
            `mask_width`, `score_threshold` and `heatmap_bias_init`.

    Returns:
        A `center_net_meta_arch.MaskParams` holding the built classification
        loss together with the remaining fields copied from `mask_config`.
    """
    loss_proto = losses_pb2.Loss()
    # The loss builder throws an error without a localization loss, so a
    # dummy weighted-L2 one is supplied; its built counterpart is ignored.
    placeholder_localization = losses_pb2.WeightedL2LocalizationLoss()
    loss_proto.localization_loss.weighted_l2.CopyFrom(placeholder_localization)
    loss_proto.classification_loss.CopyFrom(mask_config.classification_loss)

    # Only the first of the seven values returned by the builder is needed.
    (classification_loss, _, _, _, _, _, _) = losses_builder.build(loss_proto)

    mask_params = center_net_meta_arch.MaskParams(
        classification_loss=classification_loss,
        task_loss_weight=mask_config.task_loss_weight,
        mask_height=mask_config.mask_height,
        mask_width=mask_config.mask_width,
        score_threshold=mask_config.score_threshold,
        heatmap_bias_init=mask_config.heatmap_bias_init,
    )
    return mask_params
def tracking_proto_to_params(tracking_config):
    """Translate a CenterNet.TrackEstimation proto into `TrackParams`.

    Args:
        tracking_config: CenterNet.TrackEstimation proto providing
            `classification_loss`, `num_track_ids`, `reid_embed_size`,
            `num_fc_layers` and `task_loss_weight`.

    Returns:
        A `center_net_meta_arch.TrackParams` holding the built classification
        loss together with the remaining fields copied from
        `tracking_config`.
    """
    loss_proto = losses_pb2.Loss()
    # A dummy localization loss is filled in so the loss builder does not
    # throw an error.
    # TODO(yuhuic): update the loss builder to take the localization loss
    # directly.
    loss_proto.localization_loss.weighted_l2.CopyFrom(
        losses_pb2.WeightedL2LocalizationLoss())
    loss_proto.classification_loss.CopyFrom(
        tracking_config.classification_loss)

    # Keep just the classification loss out of the seven returned values.
    classification_loss, _, _, _, _, _, _ = losses_builder.build(loss_proto)

    track_kwargs = {
        "num_track_ids": tracking_config.num_track_ids,
        "reid_embed_size": tracking_config.reid_embed_size,
        "classification_loss": classification_loss,
        "num_fc_layers": tracking_config.num_fc_layers,
        "task_loss_weight": tracking_config.task_loss_weight,
    }
    return center_net_meta_arch.TrackParams(**track_kwargs)
def object_center_proto_to_params(oc_config):
    """Translate a CenterNet.ObjectCenter proto into `ObjectCenterParams`.

    Args:
        oc_config: CenterNet.ObjectCenter proto providing
            `classification_loss`, `object_center_loss_weight`,
            `heatmap_bias_init`, `min_box_overlap_iou` and
            `max_box_predictions`.

    Returns:
        A `center_net_meta_arch.ObjectCenterParams` holding the built
        classification loss together with the remaining fields copied from
        `oc_config`.
    """
    loss_proto = losses_pb2.Loss()
    # Supply a dummy localization loss; without one the loss builder throws
    # an error.
    # TODO(yuhuic): update the loss builder to take the localization loss
    # directly.
    dummy_l2_loss = losses_pb2.WeightedL2LocalizationLoss()
    loss_proto.localization_loss.weighted_l2.CopyFrom(dummy_l2_loss)
    loss_proto.classification_loss.CopyFrom(oc_config.classification_loss)

    # The builder yields seven values; everything but the classification
    # loss is dropped.
    build_result = losses_builder.build(loss_proto)
    classification_loss, _, _, _, _, _, _ = build_result

    return center_net_meta_arch.ObjectCenterParams(
        classification_loss=classification_loss,
        object_center_loss_weight=oc_config.object_center_loss_weight,
        heatmap_bias_init=oc_config.heatmap_bias_init,
        min_box_overlap_iou=oc_config.min_box_overlap_iou,
        max_box_predictions=oc_config.max_box_predictions,
    )