def _get_checkpoint_score(
    self, checkpoint: _TrackedCheckpoint
) -> Tuple[bool, numbers.Number, int]:
    """Compute the ordering key used to rank ``checkpoint``.

    The metric named by ``checkpoint_score_attribute`` on the checkpoint
    strategy is read from ``checkpoint.metrics`` and oriented so that a
    larger score is always better: the metric keeps its sign when the
    strategy's ``checkpoint_score_order`` equals ``MAX`` and is negated
    otherwise.

    Args:
        checkpoint: Tracked checkpoint whose ``metrics`` and ``id`` are
            read.

    Returns:
        A 3-tuple ``(is_not_nan, score, checkpoint.id)``. ``is_not_nan``
        is ``False`` for NaN scores, so they compare below every non-NaN
        score; in that case ``score`` is replaced by ``0`` so that NaN
        itself never takes part in a comparison, and ``checkpoint.id``
        breaks the tie.

    Raises:
        ValueError: If the stored metric is not a ``numbers.Number``.
    """
    checkpoint_score_attribute = (
        self._checkpoint_strategy.checkpoint_score_attribute
    )

    if checkpoint_score_attribute not in checkpoint.metrics:
        logger.error(
            f"Result dict has no key: {checkpoint_score_attribute}. "
            f"checkpoint_score_attr must be set to a key in the "
            f"result dict. Valid keys are: {list(checkpoint.metrics.keys())}"
        )
        # A missing metric must rank as the worst score for *every* order.
        # The score is assigned in its already-oriented form: negating
        # -inf for a non-MAX order would turn it into +inf (the best).
        checkpoint_score = float("-inf")
    else:
        checkpoint_result = checkpoint.metrics[checkpoint_score_attribute]

        # Validate before doing arithmetic so that non-numeric metrics
        # (e.g. str, list) produce this descriptive error instead of a
        # bare TypeError raised by the multiplication below.
        if not isinstance(checkpoint_result, numbers.Number):
            raise ValueError(
                f"Unable to persist checkpoint for "
                f"checkpoint_score_attribute: "
                f"{checkpoint_score_attribute} with value "
                f"{checkpoint_result}. "
                f"This attribute must be numerical."
            )

        checkpoint_score_order = (
            self._checkpoint_strategy.checkpoint_score_order
        )
        order_factor = 1.0 if checkpoint_score_order == MAX else -1.0
        checkpoint_score = order_factor * checkpoint_result

    score_is_nan = is_nan(checkpoint_score)
    return (
        not score_is_nan,
        0 if score_is_nan else checkpoint_score,
        checkpoint.id,
    )
def priority(checkpoint_score_order, checkpoint_score):
    """Return the heap key ``(is_not_nan, oriented_score)`` for a score.

    The score is negated unless ``checkpoint_score_order`` equals ``MAX``.
    The leading flag is ``False`` for NaN scores; because ``False`` sorts
    before ``True``, the heap always treats NaN entries as the worst
    metrics.
    """
    oriented_score = (
        -checkpoint_score
        if checkpoint_score_order != MAX
        else checkpoint_score
    )
    is_valid = not is_nan(oriented_score)
    return (is_valid, oriented_score)