Ejemplo n.º 1
def _extract_observed_pareto_2d(
    Y: np.ndarray,
    reference_point: Optional[Tuple[float, float]],
    minimize: Union[bool, Tuple[bool, bool]] = True,
) -> np.ndarray:
    """Extract the observed Pareto frontier of a two-objective outcome array.

    Args:
        Y: ``(n, 2)`` array of observed outcomes. It is not modified.
        reference_point: Optional ``(r0, r1)``. When given, only points
            strictly better than the reference in both dimensions are kept
            (below it for minimized dimensions, above it for maximized ones).
        minimize: Whether each objective is minimized. A single bool applies
            to both objectives.

    Returns:
        ``(k, 2)`` array of the Pareto-optimal rows, ordered from best to
        worst in the first objective.

    Raises:
        NotImplementedError: If ``Y`` does not have exactly two columns.
    """
    if Y.shape[1] != 2:
        raise NotImplementedError("Currently only the 2-dim case is handled.")
    # If `minimize` is a bool, apply to both dimensions
    if isinstance(minimize, bool):
        minimize = (minimize, minimize)
    # `torch.from_numpy` shares memory with `Y`. Clone it so the in-place
    # sign flips below never modify the caller's array.
    Y_copy = torch.from_numpy(Y).clone()
    if reference_point is not None:
        ref_point = torch.tensor(reference_point, dtype=Y_copy.dtype)
        for i in range(2):
            # Keep only points strictly better than the reference in dim i.
            Y_copy = (
                Y_copy[Y_copy[:, i] < ref_point[i]]
                if minimize[i]
                else Y_copy[Y_copy[:, i] > ref_point[i]]
            )
    # Negate minimized objectives so every dimension is "larger is better"
    # before the non-dominance check. An int factor keeps integer dtypes valid.
    signs = [-1 if m else 1 for m in minimize]
    for i, sign in enumerate(signs):
        Y_copy[:, i] *= sign
    Y_pareto = Y_copy[is_non_dominated(Y_copy)]
    # Descending order in the flipped space means best-to-worst in objective 0.
    Y_pareto = Y_pareto[torch.argsort(input=Y_pareto[:, 0], descending=True)]
    for i, sign in enumerate(signs):
        # Flip sign back to restore the original units.
        Y_pareto[:, i] *= sign

    assert Y_pareto.shape[1] == 2  # Y_pareto should have two outcomes.
    return Y_pareto.detach().cpu().numpy()
Ejemplo n.º 2
def _extract_observed_pareto_2d(Y: torch.Tensor,
                                reference_point: Tuple[float, float],
                                minimize: bool = True) -> torch.Tensor:
    """Return the observed two-objective Pareto frontier that beats a reference.

    Args:
        Y: ``(n, 2)`` tensor of observed outcomes.
        reference_point: ``(r0, r1)``. Only Pareto-optimal points strictly
            better than it in both objectives are returned (below it when
            minimizing, above it when maximizing).
        minimize: If True, both objectives are minimized. Otherwise both are
            maximized.

    Returns:
        ``(k, 2)`` tensor of the retained Pareto-optimal rows of ``Y``, sorted
        in ascending order of the first objective.

    Raises:
        NotImplementedError: If ``Y`` does not have exactly two columns.
        ValueError: If no Pareto-optimal point beats the reference point.
    """
    if Y.shape[1] != 2:
        raise NotImplementedError("Currently only the 2-dim case is handled.")
    # Match Y's device as well as its dtype. Comparing a CUDA tensor against
    # a 1-D CPU tensor raises a device-mismatch error.
    ref_point = torch.tensor(reference_point, dtype=Y.dtype, device=Y.device)
    # `is_non_dominated` is fed "larger is better" values, so negate when
    # minimizing.
    Y_pareto = Y[is_non_dominated(-1 * Y if minimize else Y)]
    beats_ref = (Y_pareto < ref_point) if minimize else (Y_pareto > ref_point)
    Y_pareto = Y_pareto[beats_ref.all(dim=1)]
    Y_pareto = Y_pareto[torch.argsort(Y_pareto[:, 0])]
    if Y_pareto.shape[0] == 0:
        better = "below" if minimize else "above"
        raise ValueError(
            f"No Pareto-optimal points in `Y` were {better} the reference point."
        )

    assert Y_pareto.shape[1] == 2  # Y_pareto should have two outcomes.
    return Y_pareto