def temporal_offset_proto_to_params(temporal_offset_config):
  """Converts CenterNet.TemporalOffsetEstimation proto to param-tuple.

  Args:
    temporal_offset_config: the TemporalOffsetEstimation config message. Its
      `localization_loss` and `task_loss_weight` fields are read.

  Returns:
    A `center_net_meta_arch.TemporalOffsetParams` holding the localization
    loss object built by `losses_builder.build` and the configured task loss
    weight.
  """
  # A placeholder weighted-sigmoid classification loss is attached only so
  # that losses_builder.build does not raise; its built result is discarded.
  # TODO(yuhuic): update the loss builder to take the classification loss
  # directly.
  placeholder_cls_loss = losses_pb2.WeightedSigmoidClassificationLoss()
  loss_proto = losses_pb2.Loss()
  loss_proto.classification_loss.weighted_sigmoid.CopyFrom(placeholder_cls_loss)
  loss_proto.localization_loss.CopyFrom(
      temporal_offset_config.localization_loss)

  # build() yields a 7-tuple; only its second entry (the localization loss)
  # is needed here.
  built_losses = losses_builder.build(loss_proto)
  _, offset_loss, _, _, _, _, _ = built_losses

  return center_net_meta_arch.TemporalOffsetParams(
      task_loss_weight=temporal_offset_config.task_loss_weight,
      localization_loss=offset_loss)
def object_detection_proto_to_params(od_config):
  """Converts CenterNet.ObjectDetection proto to parameter namedtuple.

  Args:
    od_config: the ObjectDetection config message. Its `localization_loss`,
      `scale_loss_weight`, `offset_loss_weight` and `task_loss_weight` fields
      are read.

  Returns:
    A `center_net_meta_arch.ObjectDetectionParams` holding the localization
    loss object built by `losses_builder.build` together with the three
    configured loss weights.
  """
  # A placeholder weighted-sigmoid classification loss is attached only so
  # that losses_builder.build does not raise; its built result is discarded.
  # TODO(yuhuic): update the loss builder to take the classification loss
  # directly.
  placeholder_cls_loss = losses_pb2.WeightedSigmoidClassificationLoss()
  loss_proto = losses_pb2.Loss()
  loss_proto.classification_loss.weighted_sigmoid.CopyFrom(placeholder_cls_loss)
  loss_proto.localization_loss.CopyFrom(od_config.localization_loss)

  # build() yields a 7-tuple; only its second entry (the localization loss)
  # is needed here.
  built_losses = losses_builder.build(loss_proto)
  _, box_loss, _, _, _, _, _ = built_losses

  return center_net_meta_arch.ObjectDetectionParams(
      task_loss_weight=od_config.task_loss_weight,
      scale_loss_weight=od_config.scale_loss_weight,
      offset_loss_weight=od_config.offset_loss_weight,
      localization_loss=box_loss)