def deepmac_proto_to_params(deepmac_config):
    """Builds a DeepMACParams named tuple from a DeepMAC config proto.

    Args:
      deepmac_config: The DeepMAC config proto message. Its
        `classification_loss` sub-message is turned into a loss object via
        `losses_builder.build`; every other field listed in
        `copied_fields` below is passed through unchanged.

    Returns:
      A `DeepMACParams` instance.
    """
    loss_proto = losses_pb2.Loss()
    # The loss builder throws an error when no localization loss is set, so a
    # default weighted-L2 one is attached as a placeholder. Only the
    # classification loss returned by the builder is kept.
    loss_proto.localization_loss.weighted_l2.CopyFrom(
        losses_pb2.WeightedL2LocalizationLoss())
    loss_proto.classification_loss.CopyFrom(
        deepmac_config.classification_loss)
    (built_classification_loss,
     _, _, _, _, _, _) = losses_builder.build(loss_proto)

    # Config fields whose values are forwarded as-is, keyed by the same name
    # in DeepMACParams.
    copied_fields = (
        'dim',
        'task_loss_weight',
        'pixel_embedding_dim',
        'allowed_masked_classes_ids',
        'mask_size',
        'mask_num_subsamples',
        'use_xy',
        'network_type',
        'use_instance_embedding',
        'num_init_channels',
        'predict_full_resolution_masks',
        'postprocess_crop_size',
    )
    params_kwargs = {
        field_name: getattr(deepmac_config, field_name)
        for field_name in copied_fields
    }
    params_kwargs['classification_loss'] = built_classification_loss
    return DeepMACParams(**params_kwargs)
Beispiel #2
0
def mask_proto_to_params(mask_config):
    """Translates a CenterNet.MaskEstimation proto into MaskParams.

    Args:
      mask_config: The CenterNet.MaskEstimation proto message.

    Returns:
      A `center_net_meta_arch.MaskParams` namedtuple. Its classification loss
      is built by `losses_builder.build`; the other fields are copied straight
      from `mask_config`.
    """
    dummy_loc_loss = losses_pb2.WeightedL2LocalizationLoss()
    loss_proto = losses_pb2.Loss()
    # A localization loss must be present or the loss builder throws, so a
    # default weighted-L2 placeholder is set; its built result is discarded.
    loss_proto.localization_loss.weighted_l2.CopyFrom(dummy_loc_loss)
    loss_proto.classification_loss.CopyFrom(mask_config.classification_loss)
    built_losses = losses_builder.build(loss_proto)
    classification_loss, _, _, _, _, _, _ = built_losses

    return center_net_meta_arch.MaskParams(
        mask_height=mask_config.mask_height,
        mask_width=mask_config.mask_width,
        score_threshold=mask_config.score_threshold,
        heatmap_bias_init=mask_config.heatmap_bias_init,
        task_loss_weight=mask_config.task_loss_weight,
        classification_loss=classification_loss)
Beispiel #3
0
def tracking_proto_to_params(tracking_config):
    """Maps a CenterNet.TrackEstimation proto onto a TrackParams namedtuple.

    Args:
      tracking_config: The CenterNet.TrackEstimation proto message.

    Returns:
      A `center_net_meta_arch.TrackParams` whose classification loss comes
      from `losses_builder.build` and whose remaining fields are read
      directly from `tracking_config`.
    """
    # TODO(yuhuic): update the loss builder to take the localization loss
    # directly.
    # Until then, a default weighted-L2 localization loss is attached as a
    # stand-in so the loss builder does not throw an error.
    wrapper_loss = losses_pb2.Loss()
    wrapper_loss.localization_loss.weighted_l2.CopyFrom(
        losses_pb2.WeightedL2LocalizationLoss())
    wrapper_loss.classification_loss.CopyFrom(
        tracking_config.classification_loss)
    (classification_loss,
     _, _, _, _, _, _) = losses_builder.build(wrapper_loss)

    num_track_ids = tracking_config.num_track_ids
    reid_embed_size = tracking_config.reid_embed_size
    num_fc_layers = tracking_config.num_fc_layers
    task_loss_weight = tracking_config.task_loss_weight

    return center_net_meta_arch.TrackParams(
        num_track_ids=num_track_ids,
        reid_embed_size=reid_embed_size,
        classification_loss=classification_loss,
        num_fc_layers=num_fc_layers,
        task_loss_weight=task_loss_weight)
Beispiel #4
0
def object_center_proto_to_params(oc_config):
    """Turns a CenterNet.ObjectCenter proto into ObjectCenterParams.

    Args:
      oc_config: The CenterNet.ObjectCenter proto message.

    Returns:
      A `center_net_meta_arch.ObjectCenterParams` namedtuple holding the
      classification loss produced by `losses_builder.build` plus the
      object-center settings copied from `oc_config`.
    """
    # TODO(yuhuic): update the loss builder to take the localization loss
    # directly.
    # For now a default weighted-L2 localization loss is set purely to keep
    # the loss builder from throwing an error.
    placeholder_loc = losses_pb2.WeightedL2LocalizationLoss()
    combined_loss = losses_pb2.Loss()
    combined_loss.localization_loss.weighted_l2.CopyFrom(placeholder_loc)
    combined_loss.classification_loss.CopyFrom(oc_config.classification_loss)
    cls_loss, _, _, _, _, _, _ = losses_builder.build(combined_loss)

    center_kwargs = dict(
        object_center_loss_weight=oc_config.object_center_loss_weight,
        heatmap_bias_init=oc_config.heatmap_bias_init,
        min_box_overlap_iou=oc_config.min_box_overlap_iou,
        max_box_predictions=oc_config.max_box_predictions,
    )
    return center_net_meta_arch.ObjectCenterParams(
        classification_loss=cls_loss, **center_kwargs)